diff --git a/.jenkins/rules/pylint/pylintrc b/.jenkins/rules/pylint/pylintrc old mode 100755 new mode 100644 diff --git a/akg b/akg deleted file mode 160000 index e53ad90026a..00000000000 --- a/akg +++ /dev/null @@ -1 +0,0 @@ -Subproject commit e53ad90026ad849732b9a1876914bd0e73755be9 diff --git a/build.sh b/build.sh old mode 100755 new mode 100644 diff --git a/cmake/gencode.cmake b/cmake/gencode.cmake old mode 100755 new mode 100644 diff --git a/config/hccl_multi_machine_multi_rank.json b/config/hccl_multi_machine_multi_rank.json index 4b494f7d71b..7ed882e30c8 100644 --- a/config/hccl_multi_machine_multi_rank.json +++ b/config/hccl_multi_machine_multi_rank.json @@ -1,175 +1,175 @@ -{ - "board_id": "0x0000", - "chip_info": "910", - "deploy_mode": "lab", - "group_count": "1", - "group_list": [{ - "device_num": "16", - "server_num": "2", - "group_name": "", - "instance_count": "16", - "instance_list": [{ - "devices": [{ - "device_id": "0", - "device_ip": "[A_device_ip_0]" - }], - "rank_id": "0", - "server_id": "[server_id_A]" - }, - { - "devices": [{ - "device_id": "1", - "device_ip": "[A_device_ip_1]" - }], - "rank_id": "1", - "server_id": "[server_id_A]" - }, - { - "devices": [{ - "device_id": "2", - "device_ip": "[A_device_ip_2]" - }], - "rank_id": "2", - "server_id": "[server_id_A]" - }, - { - "devices": [{ - "device_id": "3", - "device_ip": "[A_device_ip_3]" - }], - "rank_id": "3", - "server_id": "[server_id_A]" - }, - { - "devices": [{ - "device_id": "4", - "device_ip": "[A_device_ip_4]" - }], - "rank_id": "4", - "server_id": "[server_id_A]" - }, - { - "devices": [{ - "device_id": "5", - "device_ip": "[A_device_ip_5]" - }], - "rank_id": "5", - "server_id": "[server_id_A]" - }, - { - "devices": [{ - "device_id": "6", - "device_ip": "[A_device_ip_6]" - }], - "rank_id": "6", - "server_id": "[server_id_A]" - }, - { - "devices": [{ - "device_id": "7", - "device_ip": "[A_device_ip_7]" - }], - "rank_id": "7", - "server_id": "[server_id_A]" - }, - { - "devices": [{ - "device_id": "0", - "device_ip": "[B_device_ip_0]" - }], - "rank_id": "8", - "server_id": "[server_id_B]" - }, - { - "devices": [{ - "device_id": "1", - "device_ip": "[B_device_ip_1]" - }], - "rank_id": "9", - "server_id": "[server_id_B]" - }, - { - "devices": [{ - "device_id": "2", - "device_ip": "[B_device_ip_2]" - }], - "rank_id": "10", - "server_id": "[server_id_B]" - }, - { - "devices": [{ - "device_id": "3", - "device_ip": "[B_device_ip_3]" - }], - "rank_id": "11", - "server_id": "[server_id_B]" - }, - { - "devices": [{ - "device_id": "4", - "device_ip": "[B_device_ip_4]" - }], - "rank_id": "12", - "server_id": "[server_id_B]" - }, - { - "devices": [{ - "device_id": "5", - "device_ip": "[B_device_ip_5]" - }], - "rank_id": "13", - "server_id": "[server_id_B]" - }, - { - "devices": [{ - "device_id": "6", - "device_ip": "[B_device_ip_6]" - }], - "rank_id": "14", - "server_id": "[server_id_B]" - }, - { - "devices": [{ - "device_id": "7", - "device_ip": "[B_device_ip_7]" - }], - "rank_id": "15", - "server_id": "[server_id_B]" - } - ] - }], - "para_plane_nic_location": "device", - "para_plane_nic_name": [ - "eth0", - "eth1", - "eth2", - "eth3", - "eth4", - "eth5", - "eth6", - "eth7" - ], - "para_plane_nic_num": "8", - "status": "completed", - - "hccl_config_json_spec": { - "board_id": "board id, current support x0000 or 0x3000", - "chip_info": "chip info, current is 910", - "deploy_mode": "current use lab", - "group_count": "number of groups used", - "group_list": "detailed group information", - "device_num": "number of devices used, the value is the nth power of 2", - "server_num": "number of multiple machines, single machine is 1", - "group_name": "default is hccl_world_group or specified", - "instance_count": "number of instance used, generally equal to device_num", - "instance_list": "detailed instance information", - "device_id": "designated davinic device id to use, values start from 0, but no more than single machine total device num.if server_num greater than 1, the id can be restart from 0", - "device_ip": "ip corresponding to device_id", - "rank_id": "the first device must be 0 and then increase in order", - "server_id": "can be specified as the machine's ip address", - "para_plane_nic_location": "current use device", - "para_plane_nic_name": "network card corresponding to device ip", - "para_plane_nic_num": "number of network cards used", - "status": "current use completed" - } +{ + "board_id": "0x0000", + "chip_info": "910", + "deploy_mode": "lab", + "group_count": "1", + "group_list": [{ + "device_num": "16", + "server_num": "2", + "group_name": "", + "instance_count": "16", + "instance_list": [{ + "devices": [{ + "device_id": "0", + "device_ip": "[A_device_ip_0]" + }], + "rank_id": "0", + "server_id": "[server_id_A]" + }, + { + "devices": [{ + "device_id": "1", + "device_ip": "[A_device_ip_1]" + }], + "rank_id": "1", + "server_id": "[server_id_A]" + }, + { + "devices": [{ + "device_id": "2", + "device_ip": "[A_device_ip_2]" + }], + "rank_id": "2", + "server_id": "[server_id_A]" + }, + { + "devices": [{ + "device_id": "3", + "device_ip": "[A_device_ip_3]" + }], + "rank_id": "3", + "server_id": "[server_id_A]" + }, + { + "devices": [{ + "device_id": "4", + "device_ip": "[A_device_ip_4]" + }], + "rank_id": "4", + "server_id": "[server_id_A]" + }, + { + "devices": [{ + "device_id": "5", + "device_ip": "[A_device_ip_5]" + }], + "rank_id": "5", + "server_id": "[server_id_A]" + }, + { + "devices": [{ + "device_id": "6", + "device_ip": "[A_device_ip_6]" + }], + "rank_id": "6", + "server_id": "[server_id_A]" + }, + { + "devices": [{ + "device_id": "7", + "device_ip": "[A_device_ip_7]" + }], + "rank_id": "7", + "server_id": "[server_id_A]" + }, + { + "devices": [{ + "device_id": "0", + "device_ip": "[B_device_ip_0]" + }], + "rank_id": "8", + "server_id": "[server_id_B]" + }, + { + "devices": [{ + "device_id": "1", + "device_ip": "[B_device_ip_1]" + }], + "rank_id": "9", + "server_id": "[server_id_B]" + }, + { + "devices": [{ + "device_id": "2", + "device_ip": "[B_device_ip_2]" + }], + "rank_id": "10", + "server_id": "[server_id_B]" + }, + { + "devices": [{ + "device_id": "3", + "device_ip": "[B_device_ip_3]" + }], + "rank_id": "11", + "server_id": "[server_id_B]" + }, + { + "devices": [{ + "device_id": "4", + "device_ip": "[B_device_ip_4]" + }], + "rank_id": "12", + "server_id": "[server_id_B]" + }, + { + "devices": [{ + "device_id": "5", + "device_ip": "[B_device_ip_5]" + }], + "rank_id": "13", + "server_id": "[server_id_B]" + }, + { + "devices": [{ + "device_id": "6", + "device_ip": "[B_device_ip_6]" + }], + "rank_id": "14", + "server_id": "[server_id_B]" + }, + { + "devices": [{ + "device_id": "7", + "device_ip": "[B_device_ip_7]" + }], + "rank_id": "15", + "server_id": "[server_id_B]" + } + ] + }], + "para_plane_nic_location": "device", + "para_plane_nic_name": [ + "eth0", + "eth1", + "eth2", + "eth3", + "eth4", + "eth5", + "eth6", + "eth7" + ], + "para_plane_nic_num": "8", + "status": "completed", + + "hccl_config_json_spec": { + "board_id": "board id, current support x0000 or 0x3000", + "chip_info": "chip info, current is 910", + "deploy_mode": "current use lab", + "group_count": "number of groups used", + "group_list": "detailed group information", + "device_num": "number of devices used, the value is the nth power of 2", + "server_num": "number of multiple machines, single machine is 1", + "group_name": "default is hccl_world_group or specified", + "instance_count": "number of instance used, generally equal to device_num", + "instance_list": "detailed instance information", + "device_id": "designated davinic device id to use, values start from 0, but no more than single machine total device num.if server_num greater than 1, the id can be restart from 0", + "device_ip": "ip corresponding to device_id", + "rank_id": "the first device must be 0 and then increase in order", + "server_id": "can be specified as the machine's ip address", + "para_plane_nic_location": "current use device", + "para_plane_nic_name": "network card corresponding to device ip", + "para_plane_nic_num": "number of network cards used", + "status": "current use completed" + } } \ No newline at end of file diff --git a/docs/api/api_python/dataset/mindspore.dataset.AGNewsDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.AGNewsDataset.rst index fea14419aeb..e712224c3e9 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.AGNewsDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.AGNewsDataset.rst @@ -1,66 +1,66 @@ -mindspore.dataset.AGNewsDataset -=============================== - -.. py:class:: mindspore.dataset.AGNewsDataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None) - - AG News数据集。 - - 生成的数据集有三列 `[index, title, description]` ,三列的数据类型均为string。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录路径。 - - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 或 ``'all'`` 。默认值: ``None`` ,读取全部样本。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取所有样本。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (Union[bool, :class:`~.dataset.Shuffle`], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定。默认值: ``Shuffle.GLOBAL`` 。 - 如果 `shuffle` 为 ``False`` ,则不混洗,如果 `shuffle` 为 ``True`` ,等同于将 `shuffle` 设置为 ``Shuffle.GLOBAL`` 。 - 通过传入枚举变量设置数据混洗的模式: - - - ``Shuffle.GLOBAL`` :混洗文件和样本。 - - ``Shuffle.FILES`` :仅混洗文件。 - - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 异常: - - **RuntimeError** - `dataset_dir` 参数所指向的文件目录不存在或缺少数据集文件。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - **关于AGNews数据集:** - - AG是一个大型合集,具有超过100万篇新闻文章。这些新闻文章是由ComeToMyHead在持续1年多的活动中,从2000多个新闻来源收集的。ComeToMyHead是一个学术新闻搜索引擎,自2004年7月以来一直在运营。 - 数据集由学者提供,用于研究目的,如数据挖掘(聚类、分类等)、信息检索(排名、搜索等)、xml、数据压缩、数据流和任何其他非商业活动。 - AG的新闻主题类别来自于原始语料库中四个最大的类别。每个分类包含30000个训练样本和1900个测试样本。train.csv中的训练样本总数为12万,test.csv中的测试样本总数为7600。 - - 可以将数据集文件解压缩到以下结构中,并通过MindSpore的API读取: - - .. code-block:: - - . - └── ag_news_dataset_dir - ├── classes.txt - ├── train.csv - ├── test.csv - └── readme.txt - - **引用:** - - .. code-block:: - - @misc{zhang2015characterlevel, - title={Character-level Convolutional Networks for Text Classification}, - author={Xiang Zhang and Junbo Zhao and Yann LeCun}, - year={2015}, - eprint={1509.01626}, - archivePrefix={arXiv}, - primaryClass={cs.LG} - } - -.. include:: mindspore.dataset.api_list_nlp.rst +mindspore.dataset.AGNewsDataset +=============================== + +.. py:class:: mindspore.dataset.AGNewsDataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None) + + AG News数据集。 + + 生成的数据集有三列 `[index, title, description]` ,三列的数据类型均为string。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录路径。 + - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 或 ``'all'`` 。默认值: ``None`` ,读取全部样本。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取所有样本。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (Union[bool, :class:`~.dataset.Shuffle`], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定。默认值: ``Shuffle.GLOBAL`` 。 + 如果 `shuffle` 为 ``False`` ,则不混洗,如果 `shuffle` 为 ``True`` ,等同于将 `shuffle` 设置为 ``Shuffle.GLOBAL`` 。 + 通过传入枚举变量设置数据混洗的模式: + + - ``Shuffle.GLOBAL`` :混洗文件和样本。 + - ``Shuffle.FILES`` :仅混洗文件。 + + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 异常: + - **RuntimeError** - `dataset_dir` 参数所指向的文件目录不存在或缺少数据集文件。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + **关于AGNews数据集:** + + AG是一个大型合集,具有超过100万篇新闻文章。这些新闻文章是由ComeToMyHead在持续1年多的活动中,从2000多个新闻来源收集的。ComeToMyHead是一个学术新闻搜索引擎,自2004年7月以来一直在运营。 + 数据集由学者提供,用于研究目的,如数据挖掘(聚类、分类等)、信息检索(排名、搜索等)、xml、数据压缩、数据流和任何其他非商业活动。 + AG的新闻主题类别来自于原始语料库中四个最大的类别。每个分类包含30000个训练样本和1900个测试样本。train.csv中的训练样本总数为12万,test.csv中的测试样本总数为7600。 + + 可以将数据集文件解压缩到以下结构中,并通过MindSpore的API读取: + + .. code-block:: + + . + └── ag_news_dataset_dir + ├── classes.txt + ├── train.csv + ├── test.csv + └── readme.txt + + **引用:** + + .. code-block:: + + @misc{zhang2015characterlevel, + title={Character-level Convolutional Networks for Text Classification}, + author={Xiang Zhang and Junbo Zhao and Yann LeCun}, + year={2015}, + eprint={1509.01626}, + archivePrefix={arXiv}, + primaryClass={cs.LG} + } + +.. include:: mindspore.dataset.api_list_nlp.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.AmazonReviewDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.AmazonReviewDataset.rst index 6ae530b08ce..fbba12c872e 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.AmazonReviewDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.AmazonReviewDataset.rst @@ -1,71 +1,71 @@ -mindspore.dataset.AmazonReviewDataset -===================================== - -.. py:class:: mindspore.dataset.AmazonReviewDataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None) - - Amazon Review Full和Amazon Review Polarity数据集。 - - 生成的数据集有三列 `[label, title, content]` ,三列的数据类型均为string。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录路径。 - - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 或 ``'all'`` 。 - 对于Polarity数据集, ``'train'`` 将读取360万个训练样本, ``'test'`` 将读取40万个测试样本, ``'all'`` 将读取所有400万个样本。 - 对于Full数据集, ``'train'`` 将读取300万个训练样本, ``'test'`` 将读取65万个测试样本, ``'all'`` 将读取所有365万个样本。默认值: ``None`` ,读取所有样本。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取所有样本。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (Union[bool, :class:`~.dataset.Shuffle`], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定。默认值: ``Shuffle.GLOBAL`` 。 - 如果 `shuffle` 为 ``False`` ,则不混洗,如果 `shuffle` 为 ``True`` ,等同于将 `shuffle` 设置为 ``mindspore.dataset.Shuffle.GLOBAL`` 。 - 通过传入枚举变量设置数据混洗的模式: - - - ``Shuffle.GLOBAL`` :混洗文件和样本。 - - ``Shuffle.FILES`` :仅混洗文件。 - - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 异常: - - **RuntimeError** - `dataset_dir` 参数所指向的文件目录不存在或缺少数据集文件。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - **关于AmazonReview数据集:** - - Amazon Review Full数据集包括来自亚马逊的评论数据。这些数据跨越18年,包括截止至2013年3月的约3500万条评论。评论数据包括产品和用户信息、产品评级和产品评论。 - 数据集主要用于文本分类,给定内容和标题,预测正确的星级评定。 - - Amazon Review Polarity数据集对产品评分进行了分级,评论分数1和2视为负面评论,4和5视为正面评论。 - 评分3的样本则被忽略。 - - Amazon Reviews Polarity和Amazon Reviews Full datasets具有相同的目录结构。 - 可以将数据集文件解压缩到以下结构,并通过MindSpore的API读取: - - .. code-block:: - - . - └── amazon_review_dir - ├── train.csv - ├── test.csv - └── readme.txt - - **引用:** - - .. code-block:: - - @article{zhang2015character, - title={Character-level convolutional networks for text classification}, - author={Zhang, Xiang and Zhao, Junbo and LeCun, Yann}, - journal={Advances in neural information processing systems}, - volume={28}, - pages={649--657}, - year={2015} - } - - -.. include:: mindspore.dataset.api_list_nlp.rst +mindspore.dataset.AmazonReviewDataset +===================================== + +.. py:class:: mindspore.dataset.AmazonReviewDataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None) + + Amazon Review Full和Amazon Review Polarity数据集。 + + 生成的数据集有三列 `[label, title, content]` ,三列的数据类型均为string。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录路径。 + - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 或 ``'all'`` 。 + 对于Polarity数据集, ``'train'`` 将读取360万个训练样本, ``'test'`` 将读取40万个测试样本, ``'all'`` 将读取所有400万个样本。 + 对于Full数据集, ``'train'`` 将读取300万个训练样本, ``'test'`` 将读取65万个测试样本, ``'all'`` 将读取所有365万个样本。默认值: ``None`` ,读取所有样本。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取所有样本。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (Union[bool, :class:`~.dataset.Shuffle`], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定。默认值: ``Shuffle.GLOBAL`` 。 + 如果 `shuffle` 为 ``False`` ,则不混洗,如果 `shuffle` 为 ``True`` ,等同于将 `shuffle` 设置为 ``mindspore.dataset.Shuffle.GLOBAL`` 。 + 通过传入枚举变量设置数据混洗的模式: + + - ``Shuffle.GLOBAL`` :混洗文件和样本。 + - ``Shuffle.FILES`` :仅混洗文件。 + + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 异常: + - **RuntimeError** - `dataset_dir` 参数所指向的文件目录不存在或缺少数据集文件。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + **关于AmazonReview数据集:** + + Amazon Review Full数据集包括来自亚马逊的评论数据。这些数据跨越18年,包括截止至2013年3月的约3500万条评论。评论数据包括产品和用户信息、产品评级和产品评论。 + 数据集主要用于文本分类,给定内容和标题,预测正确的星级评定。 + + Amazon Review Polarity数据集对产品评分进行了分级,评论分数1和2视为负面评论,4和5视为正面评论。 + 评分3的样本则被忽略。 + + Amazon Reviews Polarity和Amazon Reviews Full datasets具有相同的目录结构。 + 可以将数据集文件解压缩到以下结构,并通过MindSpore的API读取: + + .. code-block:: + + . + └── amazon_review_dir + ├── train.csv + ├── test.csv + └── readme.txt + + **引用:** + + .. code-block:: + + @article{zhang2015character, + title={Character-level convolutional networks for text classification}, + author={Zhang, Xiang and Zhao, Junbo and LeCun, Yann}, + journal={Advances in neural information processing systems}, + volume={28}, + pages={649--657}, + year={2015} + } + + +.. include:: mindspore.dataset.api_list_nlp.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.BatchInfo.rst b/docs/api/api_python/dataset/mindspore.dataset.BatchInfo.rst index c931e5a26ae..772b2253602 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.BatchInfo.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.BatchInfo.rst @@ -1,14 +1,14 @@ -mindspore.dataset.BatchInfo -=========================== - -.. py:class:: mindspore.dataset.BatchInfo - - 当 `batch` 操作中参数 `batch_size` 或 `per_batch_map` 的传入对象是回调函数时,可以通过此类提供的方法获取数据集信息。 - - .. py:method:: get_batch_num() - - 返回当前epoch已经处理的batch数,数值从0开始。 - - .. py:method:: get_epoch_num() - - 返回当前的epoch数,数值从0开始。 +mindspore.dataset.BatchInfo +=========================== + +.. py:class:: mindspore.dataset.BatchInfo + + 当 `batch` 操作中参数 `batch_size` 或 `per_batch_map` 的传入对象是回调函数时,可以通过此类提供的方法获取数据集信息。 + + .. py:method:: get_batch_num() + + 返回当前epoch已经处理的batch数,数值从0开始。 + + .. py:method:: get_epoch_num() + + 返回当前的epoch数,数值从0开始。 diff --git a/docs/api/api_python/dataset/mindspore.dataset.CLUEDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.CLUEDataset.rst index 1ce47b599f3..2430dc3f3e9 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.CLUEDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.CLUEDataset.rst @@ -1,212 +1,212 @@ -mindspore.dataset.CLUEDataset -============================= - -.. py:class:: mindspore.dataset.CLUEDataset(dataset_files, task='AFQMC', usage='train', num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None) - - CLUE(Chinese Language Understanding Evaluation)数据集。 - - 目前支持的CLUE分类任务包括: ``'AFQMC'`` 、 ``'TNEWS'`` 、 ``'IFLYTEK'`` 、 ``'CMNLI'`` 、 ``'WSC'`` 和 ``'CSL'`` 。更多CLUE数据集的说明详见 `CLUE GitHub `_ 。 - - 参数: - - **dataset_files** (Union[str, list[str]]) - 数据集文件路径,支持单文件路径字符串、多文件路径字符串列表或可被glob库模式匹配的字符串,文件列表将在内部进行字典排序。 - - **task** (str, 可选) - 任务类型,可取值为 ``'AFQMC'`` 、 ``'TNEWS'`` 、 ``'IFLYTEK'`` 、 ``'CMNLI'`` 、 ``'WSC'`` 或 ``'CSL'`` 。默认值: ``'AFQMC'`` 。 - - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 或 ``'eval'`` 。默认值: ``'train'`` 。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取所有样本。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (Union[bool, :class:`~.dataset.Shuffle`], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定。默认值: ``Shuffle.GLOBAL`` 。 - 如果 `shuffle` 为 ``False`` ,则不混洗,如果 `shuffle` 为 ``True`` ,等同于将 `shuffle` 设置为 ``mindspore.dataset.Shuffle.GLOBAL`` 。 - 通过传入枚举变量设置数据混洗的模式: - - - ``Shuffle.GLOBAL`` :混洗文件和样本。 - - ``Shuffle.FILES`` :仅混洗文件。 - - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 根据给定的 `task` 参数 和 `usage` 配置,数据集会生成不同的输出列: - - +-------------------------+------------------------------+-----------------------------+ - | `task` | `usage` | 输出列 | - +=========================+==============================+=============================+ - | AFQMC | train | [sentence1, dtype=string] | - | | | | - | | | [sentence2, dtype=string] | - | | | | - | | | [label, dtype=string] | - | +------------------------------+-----------------------------+ - | | test | [id, dtype=uint32] | - | | | | - | | | [sentence1, dtype=string] | - | | | | - | | | [sentence2, dtype=string] | - | +------------------------------+-----------------------------+ - | | eval | [sentence1, dtype=string] | - | | | | - | | | [sentence2, dtype=string] | - | | | | - | | | [label, dtype=string] | - +-------------------------+------------------------------+-----------------------------+ - | TNEWS | train | [label, dtype=string] | - | | | | - | | | [label_des, dtype=string] | - | | | | - | | | [sentence, dtype=string] | - | | | | - | | | [keywords, dtype=string] | - | +------------------------------+-----------------------------+ - | | test | [label, dtype=uint32] | - | | | | - | | | [keywords, dtype=string] | - | | | | - | | | [sentence, dtype=string] | - | +------------------------------+-----------------------------+ - | | eval | [label, dtype=string] | - | | | | - | | | [label_des, dtype=string] | - | | | | - | | | [sentence, dtype=string] | - | | | | - | | | [keywords, dtype=string] | - +-------------------------+------------------------------+-----------------------------+ - | IFLYTEK | train | [label, dtype=string] | - | | | | - | | | [label_des, dtype=string] | - | | | | - | | | [sentence, dtype=string] | - | +------------------------------+-----------------------------+ - | | test | [id, dtype=uint32] | - | | | | - | | | [sentence, dtype=string] | - | +------------------------------+-----------------------------+ - | | eval | [label, dtype=string] | - | | | | - | | | [label_des, dtype=string] | - | | | | - | | | [sentence, dtype=string] | - +-------------------------+------------------------------+-----------------------------+ - | CMNLI | train | [sentence1, dtype=string] | - | | | | - | | | [sentence2, dtype=string] | - | | | | - | | | [label, dtype=string] | - | +------------------------------+-----------------------------+ - | | test | [id, dtype=uint32] | - | | | | - | | | [sentence1, dtype=string] | - | | | | - | | | [sentence2, dtype=string] | - | +------------------------------+-----------------------------+ - | | eval | [sentence1, dtype=string] | - | | | | - | | | [sentence2, dtype=string] | - | | | | - | | | [label, dtype=string] | - +-------------------------+------------------------------+-----------------------------+ - | WSC | train | [span1_index, dtype=uint32]| - | | | | - | | | [span2_index, dtype=uint32]| - | | | | - | | | [span1_text, dtype=string] | - | | | | - | | | [span2_text, dtype=string] | - | | | | - | | | [idx, dtype=uint32] | - | | | | - | | | [text, dtype=string] | - | | | | - | | | [label, dtype=string] | - | +------------------------------+-----------------------------+ - | | test | [span1_index, dtype=uint32]| - | | | | - | | | [span2_index, dtype=uint32]| - | | | | - | | | [span1_text, dtype=string] | - | | | | - | | | [span2_text, dtype=string] | - | | | | - | | | [idx, dtype=uint32] | - | | | | - | | | [text, dtype=string] | - | +------------------------------+-----------------------------+ - | | eval | [span1_index, dtype=uint32]| - | | | | - | | | [span2_index, dtype=uint32]| - | | | | - | | | [span1_text, dtype=string] | - | | | | - | | | [span2_text, dtype=string] | - | | | | - | | | [idx, dtype=uint32] | - | | | | - | | | [text, dtype=string] | - | | | | - | | | [label, dtype=string] | - +-------------------------+------------------------------+-----------------------------+ - | CSL | train | [id, dtype=uint32] | - | | | | - | | | [abst, dtype=string] | - | | | | - | | | [keyword, dtype=string] | - | | | | - | | | [label, dtype=string] | - | +------------------------------+-----------------------------+ - | | test | [id, dtype=uint32] | - | | | | - | | | [abst, dtype=string] | - | | | | - | | | [keyword, dtype=string] | - | +------------------------------+-----------------------------+ - | | eval | [id, dtype=uint32] | - | | | | - | | | [abst, dtype=string] | - | | | | - | | | [keyword, dtype=string] | - | | | | - | | | [label, dtype=string] | - +-------------------------+------------------------------+-----------------------------+ - - 异常: - - **ValueError** - `dataset_files` 参数所指向的文件无效或不存在。 - - **ValueError** - `task` 参数不为 ``'AFQMC'`` 、 ``'TNEWS'`` 、 ``'IFLYTEK'`` 、 ``'CMNLI'`` 、 ``'WSC'`` 或 ``'CSL'`` 。 - - **ValueError** - `usage` 参数不为 ``'train'`` 、 ``'test'`` 或 ``'eval'`` 。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - **关于CLUE数据集:** - - CLUE,又名中文语言理解测评基准,包含许多有代表性的数据集,涵盖单句分类、句对分类和机器阅读理解等任务。 - - 您可以将数据集解压成如下的文件结构,并通过MindSpore的API进行读取,以 'afqmc' 数据集为例: - - .. code-block:: - - . - └── afqmc_public - ├── train.json - ├── test.json - └── dev.json - - **引用:** - - .. code-block:: - - @article{CLUEbenchmark, - title = {CLUE: A Chinese Language Understanding Evaluation Benchmark}, - author = {Liang Xu, Xuanwei Zhang, Lu Li, Hai Hu, Chenjie Cao, Weitang Liu, Junyi Li, Yudong Li, - Kai Sun, Yechen Xu, Yiming Cui, Cong Yu, Qianqian Dong, Yin Tian, Dian Yu, Bo Shi, Jun Zeng, - Rongzhao Wang, Weijian Xie, Yanting Li, Yina Patterson, Zuoyu Tian, Yiwen Zhang, He Zhou, - Shaoweihua Liu, Qipeng Zhao, Cong Yue, Xinrui Zhang, Zhengliang Yang, Zhenzhong Lan}, - journal = {arXiv preprint arXiv:2004.05986}, - year = {2020}, - howpublished = {https://github.com/CLUEbenchmark/CLUE} - } - - -.. include:: mindspore.dataset.api_list_nlp.rst +mindspore.dataset.CLUEDataset +============================= + +.. py:class:: mindspore.dataset.CLUEDataset(dataset_files, task='AFQMC', usage='train', num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None) + + CLUE(Chinese Language Understanding Evaluation)数据集。 + + 目前支持的CLUE分类任务包括: ``'AFQMC'`` 、 ``'TNEWS'`` 、 ``'IFLYTEK'`` 、 ``'CMNLI'`` 、 ``'WSC'`` 和 ``'CSL'`` 。更多CLUE数据集的说明详见 `CLUE GitHub `_ 。 + + 参数: + - **dataset_files** (Union[str, list[str]]) - 数据集文件路径,支持单文件路径字符串、多文件路径字符串列表或可被glob库模式匹配的字符串,文件列表将在内部进行字典排序。 + - **task** (str, 可选) - 任务类型,可取值为 ``'AFQMC'`` 、 ``'TNEWS'`` 、 ``'IFLYTEK'`` 、 ``'CMNLI'`` 、 ``'WSC'`` 或 ``'CSL'`` 。默认值: ``'AFQMC'`` 。 + - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 或 ``'eval'`` 。默认值: ``'train'`` 。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取所有样本。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (Union[bool, :class:`~.dataset.Shuffle`], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定。默认值: ``Shuffle.GLOBAL`` 。 + 如果 `shuffle` 为 ``False`` ,则不混洗,如果 `shuffle` 为 ``True`` ,等同于将 `shuffle` 设置为 ``mindspore.dataset.Shuffle.GLOBAL`` 。 + 通过传入枚举变量设置数据混洗的模式: + + - ``Shuffle.GLOBAL`` :混洗文件和样本。 + - ``Shuffle.FILES`` :仅混洗文件。 + + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 根据给定的 `task` 参数 和 `usage` 配置,数据集会生成不同的输出列: + + +-------------------------+------------------------------+-----------------------------+ + | `task` | `usage` | 输出列 | + +=========================+==============================+=============================+ + | AFQMC | train | [sentence1, dtype=string] | + | | | | + | | | [sentence2, dtype=string] | + | | | | + | | | [label, dtype=string] | + | +------------------------------+-----------------------------+ + | | test | [id, dtype=uint32] | + | | | | + | | | [sentence1, dtype=string] | + | | | | + | | | [sentence2, dtype=string] | + | +------------------------------+-----------------------------+ + | | eval | [sentence1, dtype=string] | + | | | | + | | | [sentence2, dtype=string] | + | | | | + | | | [label, dtype=string] | + +-------------------------+------------------------------+-----------------------------+ + | TNEWS | train | [label, dtype=string] | + | | | | + | | | [label_des, dtype=string] | + | | | | + | | | [sentence, dtype=string] | + | | | | + | | | [keywords, dtype=string] | + | +------------------------------+-----------------------------+ + | | test | [label, dtype=uint32] | + | | | | + | | | [keywords, dtype=string] | + | | | | + | | | [sentence, dtype=string] | + | +------------------------------+-----------------------------+ + | | eval | [label, dtype=string] | + | | | | + | | | [label_des, dtype=string] | + | | | | + | | | [sentence, dtype=string] | + | | | | + | | | [keywords, dtype=string] | + +-------------------------+------------------------------+-----------------------------+ + | IFLYTEK | train | [label, dtype=string] | + | | | | + | | | [label_des, dtype=string] | + | | | | + | | | [sentence, dtype=string] | + | +------------------------------+-----------------------------+ + | | test | [id, dtype=uint32] | + | | | | + | | | [sentence, dtype=string] | + | +------------------------------+-----------------------------+ + | | eval | [label, dtype=string] | + | | | | + | | | [label_des, dtype=string] | + | | | | + | | | [sentence, dtype=string] | + +-------------------------+------------------------------+-----------------------------+ + | CMNLI | train | [sentence1, dtype=string] | + | | | | + | | | [sentence2, dtype=string] | + | | | | + | | | [label, dtype=string] | + | +------------------------------+-----------------------------+ + | | test | [id, dtype=uint32] | + | | | | + | | | [sentence1, dtype=string] | + | | | | + | | | [sentence2, dtype=string] | + | +------------------------------+-----------------------------+ + | | eval | [sentence1, dtype=string] | + | | | | + | | | [sentence2, dtype=string] | + | | | | + | | | [label, dtype=string] | + +-------------------------+------------------------------+-----------------------------+ + | WSC | train | [span1_index, dtype=uint32]| + | | | | + | | | [span2_index, dtype=uint32]| + | | | | + | | | [span1_text, dtype=string] | + | | | | + | | | [span2_text, dtype=string] | + | | | | + | | | [idx, dtype=uint32] | + | | | | + | | | [text, dtype=string] | + | | | | + | | | [label, dtype=string] | + | +------------------------------+-----------------------------+ + | | test | [span1_index, dtype=uint32]| + | | | | + | | | [span2_index, dtype=uint32]| + | | | | + | | | [span1_text, dtype=string] | + | | | | + | | | [span2_text, dtype=string] | + | | | | + | | | [idx, dtype=uint32] | + | | | | + | | | [text, dtype=string] | + | +------------------------------+-----------------------------+ + | | eval | [span1_index, dtype=uint32]| + | | | | + | | | [span2_index, dtype=uint32]| + | | | | + | | | [span1_text, dtype=string] | + | | | | + | | | [span2_text, dtype=string] | + | | | | + | | | [idx, dtype=uint32] | + | | | | + | | | [text, dtype=string] | + | | | | + | | | [label, dtype=string] | + +-------------------------+------------------------------+-----------------------------+ + | CSL | train | [id, dtype=uint32] | + | | | | + | | | [abst, dtype=string] | + | | | | + | | | [keyword, dtype=string] | + | | | | + | | | [label, dtype=string] | + | +------------------------------+-----------------------------+ + | | test | [id, dtype=uint32] | + | | | | + | | | [abst, dtype=string] | + | | | | + | | | [keyword, dtype=string] | + | +------------------------------+-----------------------------+ + | | eval | [id, dtype=uint32] | + | | | | + | | | [abst, dtype=string] | + | | | | + | | | [keyword, dtype=string] | + | | | | + | | | [label, dtype=string] | + +-------------------------+------------------------------+-----------------------------+ + + 异常: + - **ValueError** - `dataset_files` 参数所指向的文件无效或不存在。 + - **ValueError** - `task` 参数不为 ``'AFQMC'`` 、 ``'TNEWS'`` 、 ``'IFLYTEK'`` 、 ``'CMNLI'`` 、 ``'WSC'`` 或 ``'CSL'`` 。 + - **ValueError** - `usage` 参数不为 ``'train'`` 、 ``'test'`` 或 ``'eval'`` 。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + **关于CLUE数据集:** + + CLUE,又名中文语言理解测评基准,包含许多有代表性的数据集,涵盖单句分类、句对分类和机器阅读理解等任务。 + + 您可以将数据集解压成如下的文件结构,并通过MindSpore的API进行读取,以 'afqmc' 数据集为例: + + .. code-block:: + + . + └── afqmc_public + ├── train.json + ├── test.json + └── dev.json + + **引用:** + + .. code-block:: + + @article{CLUEbenchmark, + title = {CLUE: A Chinese Language Understanding Evaluation Benchmark}, + author = {Liang Xu, Xuanwei Zhang, Lu Li, Hai Hu, Chenjie Cao, Weitang Liu, Junyi Li, Yudong Li, + Kai Sun, Yechen Xu, Yiming Cui, Cong Yu, Qianqian Dong, Yin Tian, Dian Yu, Bo Shi, Jun Zeng, + Rongzhao Wang, Weijian Xie, Yanting Li, Yina Patterson, Zuoyu Tian, Yiwen Zhang, He Zhou, + Shaoweihua Liu, Qipeng Zhao, Cong Yue, Xinrui Zhang, Zhengliang Yang, Zhenzhong Lan}, + journal = {arXiv preprint arXiv:2004.05986}, + year = {2020}, + howpublished = {https://github.com/CLUEbenchmark/CLUE} + } + + +.. include:: mindspore.dataset.api_list_nlp.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.CSVDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.CSVDataset.rst index d03e216c8cf..6651a1ed832 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.CSVDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.CSVDataset.rst @@ -1,40 +1,40 @@ -mindspore.dataset.CSVDataset -============================= - -.. py:class:: mindspore.dataset.CSVDataset(dataset_files, field_delim=',', column_defaults=None, column_names=None, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None) - - CSV(Comma-Separated Values)文件数据集。 - - 生成的数据集的列名和列类型取决于输入的CSV文件。 - - 参数: - - **dataset_files** (Union[str, list[str]]) - 数据集文件路径,支持单文件路径字符串、多文件路径字符串列表或可被glob库模式匹配的字符串,文件列表将在内部进行字典排序。 - - **field_delim** (str, 可选) - 指定用于分隔字段的分隔符。默认值: ``','`` 。 - - **column_defaults** (list, 可选) - 指定每个数据列的数据类型,有效的类型包括float、int或string。默认值: ``None`` ,不指定。如果未指定该参数,则所有列的数据类型将被视为string。 - - **column_names** (list[str], 可选) - 指定数据集生成的列名。默认值: ``None`` ,不指定。如果未指定该列表,则将CSV文件首行提供的字段作为列名生成。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取全部样本。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (Union[bool, :class:`~.dataset.Shuffle`], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定。默认值: ``Shuffle.GLOBAL`` 。 - 如果 `shuffle` 为 ``False`` ,则不混洗,如果 `shuffle` 为 ``True`` ,等同于将 `shuffle` 设置为 ``mindspore.dataset.Shuffle.GLOBAL`` 。 - 通过传入枚举变量设置数据混洗的模式: - - - ``Shuffle.GLOBAL``:混洗文件和文件中的数据。 - - ``Shuffle.FILES``:仅混洗文件。 - - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 异常: - - **RuntimeError** - `dataset_files` 参数所指向的文件无效或不存在。 - - **ValueError** - `field_delim` 参数无效。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - -.. include:: mindspore.dataset.api_list_nlp.rst +mindspore.dataset.CSVDataset +============================= + +.. py:class:: mindspore.dataset.CSVDataset(dataset_files, field_delim=',', column_defaults=None, column_names=None, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None) + + CSV(Comma-Separated Values)文件数据集。 + + 生成的数据集的列名和列类型取决于输入的CSV文件。 + + 参数: + - **dataset_files** (Union[str, list[str]]) - 数据集文件路径,支持单文件路径字符串、多文件路径字符串列表或可被glob库模式匹配的字符串,文件列表将在内部进行字典排序。 + - **field_delim** (str, 可选) - 指定用于分隔字段的分隔符。默认值: ``','`` 。 + - **column_defaults** (list, 可选) - 指定每个数据列的数据类型,有效的类型包括float、int或string。默认值: ``None`` ,不指定。如果未指定该参数,则所有列的数据类型将被视为string。 + - **column_names** (list[str], 可选) - 指定数据集生成的列名。默认值: ``None`` ,不指定。如果未指定该列表,则将CSV文件首行提供的字段作为列名生成。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取全部样本。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (Union[bool, :class:`~.dataset.Shuffle`], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定。默认值: ``Shuffle.GLOBAL`` 。 + 如果 `shuffle` 为 ``False`` ,则不混洗,如果 `shuffle` 为 ``True`` ,等同于将 `shuffle` 设置为 ``mindspore.dataset.Shuffle.GLOBAL`` 。 + 通过传入枚举变量设置数据混洗的模式: + + - ``Shuffle.GLOBAL``:混洗文件和文件中的数据。 + - ``Shuffle.FILES``:仅混洗文件。 + + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 异常: + - **RuntimeError** - `dataset_files` 参数所指向的文件无效或不存在。 + - **ValueError** - `field_delim` 参数无效。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + +.. include:: mindspore.dataset.api_list_nlp.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.Caltech101Dataset.rst b/docs/api/api_python/dataset/mindspore.dataset.Caltech101Dataset.rst index e3df1cd86d2..b5534f7cdc9 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.Caltech101Dataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.Caltech101Dataset.rst @@ -1,94 +1,94 @@ -mindspore.dataset.Caltech101Dataset -=================================== - -.. py:class:: mindspore.dataset.Caltech101Dataset(dataset_dir, target_type=None, num_samples=None, num_parallel_workers=1, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None) - - Caltech 101数据集。 - - 根据不同的 `target_type` 配置,数据集会生成不同的输出列。 - - - `target_type` 为 ``'category'``,输出列为 `[image, category]` 。 - - `target_type` 为 ``'annotation'``,输出列为 `[image, annotation]` 。 - - `target_type` 为 ``'all'``,输出列为 `[image, category, annotation]` 。 - - 列 'image' 为 uint8 类型。列 'category' 为 uint32 类型。列 'annotation' 是一个二维的ndarray,存储了图像的轮廓,由一系列的点组成。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录路径,该路径下将包含2个子目录,目录101_ObjectCategories用于存储图像, - 目录Annotations用于存储图像的标注。 - - **target_type** (str, 可选) - 指定数据集的子集,可取值为 ``'category'`` 、 ``'annotation'`` 或 ``'all'`` 。 - 取值为 ``'category'`` 时将读取图像的类别标注作为label,取值为 ``'annotation'`` 时将读取图像的轮廓标注作为label, - 取值为 ``'all'`` 时将同时输出图像的类别标注和轮廓标注。默认值: ``None`` ,表示 ``'category'`` 。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数,可以小于数据集总数。默认值: ``None`` ,读取全部样本图片。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作进程数。默认值: ``1`` 。 - - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 - - **decode** (bool, 可选) - 是否对读取的图片进行解码操作。默认值: ``False`` ,不解码。 - - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - 异常: - - **RuntimeError** - `dataset_dir` 路径下不包含任何数据文件。 - - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 - - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - `shard_id` 参数错误,小于0或者大于等于 `num_shards` 。 - - **ValueError** - `target_type` 参数取值不为 ``'category'`` 、 ``'annotation'`` 或 ``'all'`` 。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 - - .. include:: mindspore.dataset.sampler.rst - - **关于Caltech101数据集:** - - Caltech101数据集包含 101 种类别的图片。每种类别大约 40 到 800 张图像,大多数类别有大约 50 张图像。 - 每张图像的大小约为 300 x 200 像素。数据集中也提供了每张图片中每个物体的轮廓数据,用于检测和定位。 - - 您可以解压缩原始Caltech101数据集文件到如下目录结构,并通过MindSpore的API进行读取。 - - .. code-block:: - - . - └── caltech101_dataset_directory - ├── 101_ObjectCategories - │ ├── Faces - │ │ ├── image_0001.jpg - │ │ ├── image_0002.jpg - │ │ ... - │ ├── Faces_easy - │ │ ├── image_0001.jpg - │ │ ├── image_0002.jpg - │ │ ... - │ ├── ... - └── Annotations - ├── Airplanes_Side_2 - │ ├── annotation_0001.mat - │ ├── annotation_0002.mat - │ ... - ├── Faces_2 - │ ├── annotation_0001.mat - │ ├── annotation_0002.mat - │ ... - ├── ... - - **引用:** - - .. code-block:: - - @article{FeiFei2004LearningGV, - author = {Li Fei-Fei and Rob Fergus and Pietro Perona}, - title = {Learning Generative Visual Models from Few Training Examples: - An Incremental Bayesian Approach Tested on 101 Object Categories}, - journal = {Computer Vision and Pattern Recognition Workshop}, - year = {2004}, - url = {http://data.caltech.edu/records/20086}, - } - - -.. include:: mindspore.dataset.api_list_vision.rst +mindspore.dataset.Caltech101Dataset +=================================== + +.. py:class:: mindspore.dataset.Caltech101Dataset(dataset_dir, target_type=None, num_samples=None, num_parallel_workers=1, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None) + + Caltech 101数据集。 + + 根据不同的 `target_type` 配置,数据集会生成不同的输出列。 + + - `target_type` 为 ``'category'``,输出列为 `[image, category]` 。 + - `target_type` 为 ``'annotation'``,输出列为 `[image, annotation]` 。 + - `target_type` 为 ``'all'``,输出列为 `[image, category, annotation]` 。 + + 列 'image' 为 uint8 类型。列 'category' 为 uint32 类型。列 'annotation' 是一个二维的ndarray,存储了图像的轮廓,由一系列的点组成。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录路径,该路径下将包含2个子目录,目录101_ObjectCategories用于存储图像, + 目录Annotations用于存储图像的标注。 + - **target_type** (str, 可选) - 指定数据集的子集,可取值为 ``'category'`` 、 ``'annotation'`` 或 ``'all'`` 。 + 取值为 ``'category'`` 时将读取图像的类别标注作为label,取值为 ``'annotation'`` 时将读取图像的轮廓标注作为label, + 取值为 ``'all'`` 时将同时输出图像的类别标注和轮廓标注。默认值: ``None`` ,表示 ``'category'`` 。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数,可以小于数据集总数。默认值: ``None`` ,读取全部样本图片。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作进程数。默认值: ``1`` 。 + - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 + - **decode** (bool, 可选) - 是否对读取的图片进行解码操作。默认值: ``False`` ,不解码。 + - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + + 异常: + - **RuntimeError** - `dataset_dir` 路径下不包含任何数据文件。 + - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 + - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - `shard_id` 参数错误,小于0或者大于等于 `num_shards` 。 + - **ValueError** - `target_type` 参数取值不为 ``'category'`` 、 ``'annotation'`` 或 ``'all'`` 。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 + + .. include:: mindspore.dataset.sampler.rst + + **关于Caltech101数据集:** + + Caltech101数据集包含 101 种类别的图片。每种类别大约 40 到 800 张图像,大多数类别有大约 50 张图像。 + 每张图像的大小约为 300 x 200 像素。数据集中也提供了每张图片中每个物体的轮廓数据,用于检测和定位。 + + 您可以解压缩原始Caltech101数据集文件到如下目录结构,并通过MindSpore的API进行读取。 + + .. code-block:: + + . + └── caltech101_dataset_directory + ├── 101_ObjectCategories + │ ├── Faces + │ │ ├── image_0001.jpg + │ │ ├── image_0002.jpg + │ │ ... + │ ├── Faces_easy + │ │ ├── image_0001.jpg + │ │ ├── image_0002.jpg + │ │ ... + │ ├── ... + └── Annotations + ├── Airplanes_Side_2 + │ ├── annotation_0001.mat + │ ├── annotation_0002.mat + │ ... + ├── Faces_2 + │ ├── annotation_0001.mat + │ ├── annotation_0002.mat + │ ... + ├── ... + + **引用:** + + .. code-block:: + + @article{FeiFei2004LearningGV, + author = {Li Fei-Fei and Rob Fergus and Pietro Perona}, + title = {Learning Generative Visual Models from Few Training Examples: + An Incremental Bayesian Approach Tested on 101 Object Categories}, + journal = {Computer Vision and Pattern Recognition Workshop}, + year = {2004}, + url = {http://data.caltech.edu/records/20086}, + } + + +.. include:: mindspore.dataset.api_list_vision.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.Caltech256Dataset.rst b/docs/api/api_python/dataset/mindspore.dataset.Caltech256Dataset.rst index 75bcd11eabc..57f007f7426 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.Caltech256Dataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.Caltech256Dataset.rst @@ -1,81 +1,81 @@ -mindspore.dataset.Caltech256Dataset -=================================== - -.. py:class:: mindspore.dataset.Caltech256Dataset(dataset_dir, num_samples=None, num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None, cache=None) - - Caltech 256数据集。 - - 生成的数据集有两列 `[image, label]` 。 `image` 列的数据类型为uint8。 `label` 列的数据类型为uint32。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录路径。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数,可以小于数据集总数。默认值: ``None`` ,读取全部样本图片。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 - - **decode** (bool, 可选) - 是否对读取的图片进行解码操作。默认值: ``False`` ,不解码。 - - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 异常: - - **RuntimeError** - `dataset_dir` 路径下不包含任何数据文件。 - - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 - - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - **ValueError** - `target_type` 参数取值不为 ``'category'`` 、 ``'annotation'`` 或 ``'all'`` 。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 - - .. include:: mindspore.dataset.sampler.rst - - **关于Caltech256数据集:** - - Caltech-256 是一个对象识别数据集,包含 30,607 张不同大小的真实世界图像,共有 257 个类别(256类物体和1个其他类), - 每个类别由至少 80 张图像。该数据集是 Caltech101 数据集的超集。 - - 您可以解压缩原始Caltech256数据集文件到如下目录结构,并通过MindSpore的API进行读取。 - - .. code-block:: - - . - └── caltech256_dataset_directory - ├── 001.ak47 - │ ├── 001_0001.jpg - │ ├── 001_0002.jpg - │ ... - ├── 002.american-flag - │ ├── 002_0001.jpg - │ ├── 002_0002.jpg - │ ... - ├── 003.backpack - │ ├── 003_0001.jpg - │ ├── 003_0002.jpg - │ ... - ├── ... - - **引用:** - - .. code-block:: - - @article{griffin2007caltech, - title = {Caltech-256 object category dataset}, - added-at = {2021-01-21T02:54:42.000+0100}, - author = {Griffin, Gregory and Holub, Alex and Perona, Pietro}, - biburl = {https://www.bibsonomy.org/bibtex/21f746f23ff0307826cca3e3be45f8de7/s364315}, - interhash = {bfe1e648c1778c04baa60f23d1223375}, - intrahash = {1f746f23ff0307826cca3e3be45f8de7}, - publisher = {California Institute of Technology}, - timestamp = {2021-01-21T02:54:42.000+0100}, - year = {2007} - } - - -.. include:: mindspore.dataset.api_list_vision.rst +mindspore.dataset.Caltech256Dataset +=================================== + +.. py:class:: mindspore.dataset.Caltech256Dataset(dataset_dir, num_samples=None, num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None, cache=None) + + Caltech 256数据集。 + + 生成的数据集有两列 `[image, label]` 。 `image` 列的数据类型为uint8。 `label` 列的数据类型为uint32。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录路径。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数,可以小于数据集总数。默认值: ``None`` ,读取全部样本图片。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 + - **decode** (bool, 可选) - 是否对读取的图片进行解码操作。默认值: ``False`` ,不解码。 + - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 异常: + - **RuntimeError** - `dataset_dir` 路径下不包含任何数据文件。 + - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 + - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + - **ValueError** - `target_type` 参数取值不为 ``'category'`` 、 ``'annotation'`` 或 ``'all'`` 。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 + + .. include:: mindspore.dataset.sampler.rst + + **关于Caltech256数据集:** + + Caltech-256 是一个对象识别数据集,包含 30,607 张不同大小的真实世界图像,共有 257 个类别(256类物体和1个其他类), + 每个类别由至少 80 张图像。该数据集是 Caltech101 数据集的超集。 + + 您可以解压缩原始Caltech256数据集文件到如下目录结构,并通过MindSpore的API进行读取。 + + .. code-block:: + + . + └── caltech256_dataset_directory + ├── 001.ak47 + │ ├── 001_0001.jpg + │ ├── 001_0002.jpg + │ ... + ├── 002.american-flag + │ ├── 002_0001.jpg + │ ├── 002_0002.jpg + │ ... + ├── 003.backpack + │ ├── 003_0001.jpg + │ ├── 003_0002.jpg + │ ... + ├── ... + + **引用:** + + .. code-block:: + + @article{griffin2007caltech, + title = {Caltech-256 object category dataset}, + added-at = {2021-01-21T02:54:42.000+0100}, + author = {Griffin, Gregory and Holub, Alex and Perona, Pietro}, + biburl = {https://www.bibsonomy.org/bibtex/21f746f23ff0307826cca3e3be45f8de7/s364315}, + interhash = {bfe1e648c1778c04baa60f23d1223375}, + intrahash = {1f746f23ff0307826cca3e3be45f8de7}, + publisher = {California Institute of Technology}, + timestamp = {2021-01-21T02:54:42.000+0100}, + year = {2007} + } + + +.. include:: mindspore.dataset.api_list_vision.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.CelebADataset.rst b/docs/api/api_python/dataset/mindspore.dataset.CelebADataset.rst index 491ce8599f8..c6c17346b52 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.CelebADataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.CelebADataset.rst @@ -1,105 +1,105 @@ -mindspore.dataset.CelebADataset -=============================== - -.. py:class:: mindspore.dataset.CelebADataset(dataset_dir, num_parallel_workers=None, shuffle=None, usage='all', sampler=None, decode=False, extensions=None, num_samples=None, num_shards=None, shard_id=None, cache=None, decrypt=None) - - CelebA(CelebFaces Attributes)数据集。 - - 目前仅支持解析CelebA数据集中的 `list_attr_celeba.txt` 文件作为数据集的label。 - 生成的数据集有两列 `[image, attr]` 。 `image` 列的数据类型为uint8。`attr` 列的数据类型为uint32,并以one-hot编码的形式生成。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录路径。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 - - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'valid'`` 、 ``'test'`` 或 ``'all'`` 。默认值: ``'all'`` ,全部样本图片。 - - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 - - **decode** (bool, 可选) - 是否对读取的图片进行解码操作。默认值: ``False`` ,不解码。 - - **extensions** (list[str], 可选) - 指定文件的扩展名,仅读取与指定扩展名匹配的文件到数据集中。默认值: ``None`` 。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数,可以小于数据集总数。默认值: ``None`` ,读取全部样本图片。 - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - **decrypt** (callable, 可选) - 图像解密函数,接受加密的图片路径并返回bytes类型的解密数据。默认值: ``None`` ,不进行解密。 - - 异常: - - **RuntimeError** - `dataset_dir` 路径下不包含任何数据文件。 - - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 - - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - **ValueError** - `usage` 参数取值不为 ``'train'`` 、 ``'valid'`` 、 ``'test'`` 或 ``'all'`` 。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 - - .. include:: mindspore.dataset.sampler.rst - - **关于CelebA数据集:** - - CelebFaces Attributes Dataset(CelebA)数据集是一个大规模数据集,拥有超过20万张名人图像,每个图像都有40个属性标注。此数据集包含了大量不同姿态、各种背景的图像,种类丰富、数量庞大、标注充分。数据集总体包含: - - - 10177个不同的身份 - - 202599张图像 - - 每张图像拥有5个五官位置标注,40个属性标签 - - 此数据集可用于各种计算机视觉任务的训练和测试,包括属性识别、检测和五官定位等。 - - 原始CelebA数据集结构: - - .. code-block:: - - . - └── CelebA - ├── README.md - ├── Img - │ ├── img_celeba.7z - │ ├── img_align_celeba_png.7z - │ └── img_align_celeba.zip - ├── Eval - │ └── list_eval_partition.txt - └── Anno - ├── list_landmarks_celeba.txt - ├── list_landmarks_align_celeba.txt - ├── list_bbox_celeba.txt - ├── list_attr_celeba.txt - └── identity_CelebA.txt - - 您可以将上述Anno目录下的txt文件与Img目录下的文件解压放至同一目录,并通过MindSpore的API进行读取。 - - .. code-block:: - - . - └── celeba_dataset_directory - ├── list_attr_celeba.txt - ├── 000001.jpg - ├── 000002.jpg - ├── 000003.jpg - ├── ... - - **引用:** - - .. code-block:: - - @article{DBLP:journals/corr/LiuLWT14, - author = {Ziwei Liu and Ping Luo and Xiaogang Wang and Xiaoou Tang}, - title = {Deep Learning Attributes in the Wild}, - journal = {CoRR}, - volume = {abs/1411.7766}, - year = {2014}, - url = {http://arxiv.org/abs/1411.7766}, - archivePrefix = {arXiv}, - eprint = {1411.7766}, - timestamp = {Tue, 10 Dec 2019 15:37:26 +0100}, - biburl = {https://dblp.org/rec/journals/corr/LiuLWT14.bib}, - bibsource = {dblp computer science bibliography, https://dblp.org}, - howpublished = {http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html} - } - - -.. include:: mindspore.dataset.api_list_vision.rst +mindspore.dataset.CelebADataset +=============================== + +.. py:class:: mindspore.dataset.CelebADataset(dataset_dir, num_parallel_workers=None, shuffle=None, usage='all', sampler=None, decode=False, extensions=None, num_samples=None, num_shards=None, shard_id=None, cache=None, decrypt=None) + + CelebA(CelebFaces Attributes)数据集。 + + 目前仅支持解析CelebA数据集中的 `list_attr_celeba.txt` 文件作为数据集的label。 + 生成的数据集有两列 `[image, attr]` 。 `image` 列的数据类型为uint8。`attr` 列的数据类型为uint32,并以one-hot编码的形式生成。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录路径。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 + - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'valid'`` 、 ``'test'`` 或 ``'all'`` 。默认值: ``'all'`` ,全部样本图片。 + - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 + - **decode** (bool, 可选) - 是否对读取的图片进行解码操作。默认值: ``False`` ,不解码。 + - **extensions** (list[str], 可选) - 指定文件的扩展名,仅读取与指定扩展名匹配的文件到数据集中。默认值: ``None`` 。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数,可以小于数据集总数。默认值: ``None`` ,读取全部样本图片。 + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + - **decrypt** (callable, 可选) - 图像解密函数,接受加密的图片路径并返回bytes类型的解密数据。默认值: ``None`` ,不进行解密。 + + 异常: + - **RuntimeError** - `dataset_dir` 路径下不包含任何数据文件。 + - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 + - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + - **ValueError** - `usage` 参数取值不为 ``'train'`` 、 ``'valid'`` 、 ``'test'`` 或 ``'all'`` 。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 + + .. include:: mindspore.dataset.sampler.rst + + **关于CelebA数据集:** + + CelebFaces Attributes Dataset(CelebA)数据集是一个大规模数据集,拥有超过20万张名人图像,每个图像都有40个属性标注。此数据集包含了大量不同姿态、各种背景的图像,种类丰富、数量庞大、标注充分。数据集总体包含: + + - 10177个不同的身份 + - 202599张图像 + - 每张图像拥有5个五官位置标注,40个属性标签 + + 此数据集可用于各种计算机视觉任务的训练和测试,包括属性识别、检测和五官定位等。 + + 原始CelebA数据集结构: + + .. code-block:: + + . + └── CelebA + ├── README.md + ├── Img + │ ├── img_celeba.7z + │ ├── img_align_celeba_png.7z + │ └── img_align_celeba.zip + ├── Eval + │ └── list_eval_partition.txt + └── Anno + ├── list_landmarks_celeba.txt + ├── list_landmarks_align_celeba.txt + ├── list_bbox_celeba.txt + ├── list_attr_celeba.txt + └── identity_CelebA.txt + + 您可以将上述Anno目录下的txt文件与Img目录下的文件解压放至同一目录,并通过MindSpore的API进行读取。 + + .. code-block:: + + . + └── celeba_dataset_directory + ├── list_attr_celeba.txt + ├── 000001.jpg + ├── 000002.jpg + ├── 000003.jpg + ├── ... + + **引用:** + + .. code-block:: + + @article{DBLP:journals/corr/LiuLWT14, + author = {Ziwei Liu and Ping Luo and Xiaogang Wang and Xiaoou Tang}, + title = {Deep Learning Attributes in the Wild}, + journal = {CoRR}, + volume = {abs/1411.7766}, + year = {2014}, + url = {http://arxiv.org/abs/1411.7766}, + archivePrefix = {arXiv}, + eprint = {1411.7766}, + timestamp = {Tue, 10 Dec 2019 15:37:26 +0100}, + biburl = {https://dblp.org/rec/journals/corr/LiuLWT14.bib}, + bibsource = {dblp computer science bibliography, https://dblp.org}, + howpublished = {http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html} + } + + +.. include:: mindspore.dataset.api_list_vision.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.Cifar100Dataset.rst b/docs/api/api_python/dataset/mindspore.dataset.Cifar100Dataset.rst index 722a9659ef0..309e64dbaa8 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.Cifar100Dataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.Cifar100Dataset.rst @@ -1,68 +1,68 @@ -mindspore.dataset.Cifar100Dataset -================================= - -.. py:class:: mindspore.dataset.Cifar100Dataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None) - - CIFAR-100数据集。 - - 生成的数据集有三列: `[image, coarse_label, fine_label]` 。 `image` 列的数据类型为uint8。 `coarse_label` 和 `fine_labels` 列的数据类型为uint32。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录路径。 - - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'``、 ``'test'`` 或 ``'all'`` 。 - 取值为 ``'train'`` 时将会读取50,000个训练样本,取值为 ``'test'`` 时将会读取10,000个测试样本,取值为 ``'all'`` 时将会读取全部60,000个样本。默认值: ``None`` ,读取全部样本图片。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数,可以小于数据集总数。默认值: ``None`` ,读取全部样本图片。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 - - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - 异常: - - **RuntimeError** - `dataset_dir` 路径下不包含数据文件。 - - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 - - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - **ValueError** - `usage` 参数取值不为 ``'train'``、 ``'test'`` 或 ``'all'`` 。 - - .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 - - .. include:: mindspore.dataset.sampler.rst - - **关于CIFAR-100数据集:** - - CIFAR-100数据集和CIFAR-10数据集非常相似,CIFAR-100有100个类别,每类包含600张图片。其中500张训练图片和100张测试图片。这100个类别又被分成20个超类。每个图片都有一个"fine"标签(所属子类)和一个"coarse"标签(所属超类)。 - - 以下为原始CIFAR-100数据集的结构。您可以将数据集文件解压得到如下的文件结构,并通过MindSpore的API进行读取。 - - .. code-block:: - - . - └── cifar-100-binary - ├── train.bin - ├── test.bin - ├── fine_label_names.txt - └── coarse_label_names.txt - - **引用:** - - .. code-block:: - - @techreport{Krizhevsky09, - author = {Alex Krizhevsky}, - title = {Learning multiple layers of features from tiny images}, - institution = {}, - year = {2009}, - howpublished = {http://www.cs.toronto.edu/~kriz/cifar.html} - } - - -.. include:: mindspore.dataset.api_list_vision.rst +mindspore.dataset.Cifar100Dataset +================================= + +.. py:class:: mindspore.dataset.Cifar100Dataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None) + + CIFAR-100数据集。 + + 生成的数据集有三列: `[image, coarse_label, fine_label]` 。 `image` 列的数据类型为uint8。 `coarse_label` 和 `fine_labels` 列的数据类型为uint32。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录路径。 + - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'``、 ``'test'`` 或 ``'all'`` 。 + 取值为 ``'train'`` 时将会读取50,000个训练样本,取值为 ``'test'`` 时将会读取10,000个测试样本,取值为 ``'all'`` 时将会读取全部60,000个样本。默认值: ``None`` ,读取全部样本图片。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数,可以小于数据集总数。默认值: ``None`` ,读取全部样本图片。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 + - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + 异常: + - **RuntimeError** - `dataset_dir` 路径下不包含数据文件。 + - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 + - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + - **ValueError** - `usage` 参数取值不为 ``'train'``、 ``'test'`` 或 ``'all'`` 。 + + .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 + + .. include:: mindspore.dataset.sampler.rst + + **关于CIFAR-100数据集:** + + CIFAR-100数据集和CIFAR-10数据集非常相似,CIFAR-100有100个类别,每类包含600张图片。其中500张训练图片和100张测试图片。这100个类别又被分成20个超类。每个图片都有一个"fine"标签(所属子类)和一个"coarse"标签(所属超类)。 + + 以下为原始CIFAR-100数据集的结构。您可以将数据集文件解压得到如下的文件结构,并通过MindSpore的API进行读取。 + + .. code-block:: + + . + └── cifar-100-binary + ├── train.bin + ├── test.bin + ├── fine_label_names.txt + └── coarse_label_names.txt + + **引用:** + + .. code-block:: + + @techreport{Krizhevsky09, + author = {Alex Krizhevsky}, + title = {Learning multiple layers of features from tiny images}, + institution = {}, + year = {2009}, + howpublished = {http://www.cs.toronto.edu/~kriz/cifar.html} + } + + +.. include:: mindspore.dataset.api_list_vision.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.Cifar10Dataset.rst b/docs/api/api_python/dataset/mindspore.dataset.Cifar10Dataset.rst index 6b55079bcd9..f4b5e9bff36 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.Cifar10Dataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.Cifar10Dataset.rst @@ -1,72 +1,72 @@ -mindspore.dataset.Cifar10Dataset -================================ - -.. py:class:: mindspore.dataset.Cifar10Dataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None) - - CIFAR-10数据集。 - - 该API目前仅支持解析二进制版本的CIFAR-10文件(CIFAR-10 binary version)。 - 生成的数据集有两列: `[image, label]` 。 `image` 列的数据类型是uint8。`label` 列的数据类型是uint32。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录路径。 - - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 或 ``'all'`` 。 - 取值为 ``'train'`` 时将会读取50,000个训练样本,取值为 ``'test'`` 时将会读取10,000个测试样本,取值为 ``'all'`` 时将会读取全部60,000个样本。默认值: ``None`` ,读取全部样本图片。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数,可以小于数据集总数。默认值: ``None`` ,读取全部样本图片。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 - - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - 异常: - - **RuntimeError** - `dataset_dir` 路径下不包含数据文件。 - - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 - - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - **ValueError** - `usage` 参数取值不为 ``'train'`` 、 ``'test'`` 或 ``'all'`` 。 - - .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 - - .. include:: mindspore.dataset.sampler.rst - - **关于CIFAR-10数据集:** - - CIFAR-10数据集由60000张32x32彩色图片组成,总共有10个类别,每类6000张图片。有50000个训练样本和10000个测试样本。10个类别包含飞机、汽车、鸟类、猫、鹿、狗、青蛙、马、船和卡车。 - - 以下为原始CIFAR-10数据集的结构。您可以将数据集文件解压得到如下的文件结构,并通过MindSpore的API进行读取。 - - .. code-block:: - - . - └── cifar-10-batches-bin - ├── data_batch_1.bin - ├── data_batch_2.bin - ├── data_batch_3.bin - ├── data_batch_4.bin - ├── data_batch_5.bin - ├── test_batch.bin - ├── readme.html - └── batches.meta.text - - **引用:** - - .. code-block:: - - @techreport{Krizhevsky09, - author = {Alex Krizhevsky}, - title = {Learning multiple layers of features from tiny images}, - institution = {}, - year = {2009}, - howpublished = {http://www.cs.toronto.edu/~kriz/cifar.html} - } - -.. include:: mindspore.dataset.api_list_vision.rst +mindspore.dataset.Cifar10Dataset +================================ + +.. py:class:: mindspore.dataset.Cifar10Dataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None) + + CIFAR-10数据集。 + + 该API目前仅支持解析二进制版本的CIFAR-10文件(CIFAR-10 binary version)。 + 生成的数据集有两列: `[image, label]` 。 `image` 列的数据类型是uint8。`label` 列的数据类型是uint32。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录路径。 + - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 或 ``'all'`` 。 + 取值为 ``'train'`` 时将会读取50,000个训练样本,取值为 ``'test'`` 时将会读取10,000个测试样本,取值为 ``'all'`` 时将会读取全部60,000个样本。默认值: ``None`` ,读取全部样本图片。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数,可以小于数据集总数。默认值: ``None`` ,读取全部样本图片。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 + - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + 异常: + - **RuntimeError** - `dataset_dir` 路径下不包含数据文件。 + - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 + - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + - **ValueError** - `usage` 参数取值不为 ``'train'`` 、 ``'test'`` 或 ``'all'`` 。 + + .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 + + .. include:: mindspore.dataset.sampler.rst + + **关于CIFAR-10数据集:** + + CIFAR-10数据集由60000张32x32彩色图片组成,总共有10个类别,每类6000张图片。有50000个训练样本和10000个测试样本。10个类别包含飞机、汽车、鸟类、猫、鹿、狗、青蛙、马、船和卡车。 + + 以下为原始CIFAR-10数据集的结构。您可以将数据集文件解压得到如下的文件结构,并通过MindSpore的API进行读取。 + + .. code-block:: + + . + └── cifar-10-batches-bin + ├── data_batch_1.bin + ├── data_batch_2.bin + ├── data_batch_3.bin + ├── data_batch_4.bin + ├── data_batch_5.bin + ├── test_batch.bin + ├── readme.html + └── batches.meta.text + + **引用:** + + .. code-block:: + + @techreport{Krizhevsky09, + author = {Alex Krizhevsky}, + title = {Learning multiple layers of features from tiny images}, + institution = {}, + year = {2009}, + howpublished = {http://www.cs.toronto.edu/~kriz/cifar.html} + } + +.. include:: mindspore.dataset.api_list_vision.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.CityscapesDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.CityscapesDataset.rst index 647d69a6f10..84083b6bb1e 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.CityscapesDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.CityscapesDataset.rst @@ -1,105 +1,105 @@ -mindspore.dataset.CityscapesDataset -=================================== - -.. py:class:: mindspore.dataset.CityscapesDataset(dataset_dir, usage="train", quality_mode="fine", task="instance", num_samples=None, num_parallel_workers=None, shuffle=None, decode=None, sampler=None, num_shards=None, shard_id=None, cache=None) - - Cityscapes数据集。 - - 生成的数据集有两列 `[image, task]` 。 - `image` 列的数据类型为uint8。`task` 列的数据类型根据参数 `task` 的值而定,当参数 `task` 取值为 ``'polygon'`` ,列的数据类型为string,其他取值下,列的数据类型为uint8。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录路径。 - - **usage** (str, 可选) - 指定数据集的子集。当参数 `quality_mode` 取值为 ``'fine'`` 时,此参数可取值为 ``'train'`` 、 ``'test'`` 、 ``'val'`` 或 ``'all'`` 。 - 当参数 `quality_mode` 取值为 ``'coarse'`` 时,此参数可取值为 ``'train'`` 、 ``'train_extra'`` 、 ``'val'`` 或 ``'all'`` 。默认值: ``'train'`` ,全部样本图片。 - - **quality_mode** (str, 可选) - 指定数据集的质量模式,可取值为 ``'fine'`` 或 ``'coarse'`` 。默认值: ``'fine'`` 。 - - **task** (str, 可选) - 指定数据集的任务类型,可取值为 ``'instance'`` 、 ``'semantic'`` 、 ``'polygon'`` 或 ``'color'`` 。默认值: ``'instance'`` 。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数,可以小于数据集总数。默认值: ``None`` ,读取全部样本图片。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 - - **decode** (bool, 可选) - 是否对读取的图片进行解码操作。默认值: ``None`` ,默认为 ``False`` ,不解码。 - - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - 异常: - - **RuntimeError** - `dataset_dir` 路径下不包含任何数据文件。 - - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 - - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - **ValueError** - `dataset_dir` 路径非法或不存在。 - - **ValueError** - `task` 参数取值不为 ``'instance'`` 、 ``'semantic'``、 ``'polygon'`` 或 ``'color'`` 。 - - **ValueError** - `quality_mode` 参数取值不为 ``'fine'`` 或 ``'coarse'`` 。 - - **ValueError** - `usage` 参数取值不在给定的字段中。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 - - .. include:: mindspore.dataset.sampler.rst - - **关于Cityscapes数据集:** - - Cityscapes 数据集由来自 50 个城市的 24998 张彩色图像组成。 - 其中 5000 张图像具有高质量的密集像素标注,19998 张图像具有粗糙的多边形标注。 - 该数据集共有 30 个类,多边形标注包括密集语义分割,以及车辆和人的实例分割。 - - 您可以解压缩原始数据集文件到如下目录结构,并通过MindSpore的API进行读取。 - - .. code-block:: - - . - └── Cityscapes - ├── leftImg8bit - | ├── train - | | ├── aachen - | | | ├── aachen_000000_000019_leftImg8bit.png - | | | ├── aachen_000001_000019_leftImg8bit.png - | | | ├── ... - | | ├── bochum - | | | ├── ... - | | ├── ... - | ├── test - | | ├── ... - | ├── val - | | ├── ... - └── gtFine - ├── train - | ├── aachen - | | ├── aachen_000000_000019_gtFine_color.png - | | ├── aachen_000000_000019_gtFine_instanceIds.png - | | ├── aachen_000000_000019_gtFine_labelIds.png - | | ├── aachen_000000_000019_gtFine_polygons.json - | | ├── aachen_000001_000019_gtFine_color.png - | | ├── aachen_000001_000019_gtFine_instanceIds.png - | | ├── aachen_000001_000019_gtFine_labelIds.png - | | ├── aachen_000001_000019_gtFine_polygons.json - | | ├── ... - | ├── bochum - | | ├── ... - | ├── ... - ├── test - | ├── ... - └── val - ├── ... - - **引用:** - - .. code-block:: - - @inproceedings{Cordts2016Cityscapes, - title = {The Cityscapes Dataset for Semantic Urban Scene Understanding}, - author = {Cordts, Marius and Omran, Mohamed and Ramos, Sebastian and Rehfeld, Timo and Enzweiler, - Markus and Benenson, Rodrigo and Franke, Uwe and Roth, Stefan and Schiele, Bernt}, - booktitle = {Proc. of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, - year = {2016} - } - - -.. include:: mindspore.dataset.api_list_vision.rst +mindspore.dataset.CityscapesDataset +=================================== + +.. py:class:: mindspore.dataset.CityscapesDataset(dataset_dir, usage="train", quality_mode="fine", task="instance", num_samples=None, num_parallel_workers=None, shuffle=None, decode=None, sampler=None, num_shards=None, shard_id=None, cache=None) + + Cityscapes数据集。 + + 生成的数据集有两列 `[image, task]` 。 + `image` 列的数据类型为uint8。`task` 列的数据类型根据参数 `task` 的值而定,当参数 `task` 取值为 ``'polygon'`` ,列的数据类型为string,其他取值下,列的数据类型为uint8。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录路径。 + - **usage** (str, 可选) - 指定数据集的子集。当参数 `quality_mode` 取值为 ``'fine'`` 时,此参数可取值为 ``'train'`` 、 ``'test'`` 、 ``'val'`` 或 ``'all'`` 。 + 当参数 `quality_mode` 取值为 ``'coarse'`` 时,此参数可取值为 ``'train'`` 、 ``'train_extra'`` 、 ``'val'`` 或 ``'all'`` 。默认值: ``'train'`` ,全部样本图片。 + - **quality_mode** (str, 可选) - 指定数据集的质量模式,可取值为 ``'fine'`` 或 ``'coarse'`` 。默认值: ``'fine'`` 。 + - **task** (str, 可选) - 指定数据集的任务类型,可取值为 ``'instance'`` 、 ``'semantic'`` 、 ``'polygon'`` 或 ``'color'`` 。默认值: ``'instance'`` 。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数,可以小于数据集总数。默认值: ``None`` ,读取全部样本图片。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 + - **decode** (bool, 可选) - 是否对读取的图片进行解码操作。默认值: ``None`` ,默认为 ``False`` ,不解码。 + - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + 异常: + - **RuntimeError** - `dataset_dir` 路径下不包含任何数据文件。 + - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 + - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + - **ValueError** - `dataset_dir` 路径非法或不存在。 + - **ValueError** - `task` 参数取值不为 ``'instance'`` 、 ``'semantic'``、 ``'polygon'`` 或 ``'color'`` 。 + - **ValueError** - `quality_mode` 参数取值不为 ``'fine'`` 或 ``'coarse'`` 。 + - **ValueError** - `usage` 参数取值不在给定的字段中。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + + .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 + + .. include:: mindspore.dataset.sampler.rst + + **关于Cityscapes数据集:** + + Cityscapes 数据集由来自 50 个城市的 24998 张彩色图像组成。 + 其中 5000 张图像具有高质量的密集像素标注,19998 张图像具有粗糙的多边形标注。 + 该数据集共有 30 个类,多边形标注包括密集语义分割,以及车辆和人的实例分割。 + + 您可以解压缩原始数据集文件到如下目录结构,并通过MindSpore的API进行读取。 + + .. code-block:: + + . + └── Cityscapes + ├── leftImg8bit + | ├── train + | | ├── aachen + | | | ├── aachen_000000_000019_leftImg8bit.png + | | | ├── aachen_000001_000019_leftImg8bit.png + | | | ├── ... + | | ├── bochum + | | | ├── ... + | | ├── ... + | ├── test + | | ├── ... + | ├── val + | | ├── ... + └── gtFine + ├── train + | ├── aachen + | | ├── aachen_000000_000019_gtFine_color.png + | | ├── aachen_000000_000019_gtFine_instanceIds.png + | | ├── aachen_000000_000019_gtFine_labelIds.png + | | ├── aachen_000000_000019_gtFine_polygons.json + | | ├── aachen_000001_000019_gtFine_color.png + | | ├── aachen_000001_000019_gtFine_instanceIds.png + | | ├── aachen_000001_000019_gtFine_labelIds.png + | | ├── aachen_000001_000019_gtFine_polygons.json + | | ├── ... + | ├── bochum + | | ├── ... + | ├── ... + ├── test + | ├── ... + └── val + ├── ... + + **引用:** + + .. code-block:: + + @inproceedings{Cordts2016Cityscapes, + title = {The Cityscapes Dataset for Semantic Urban Scene Understanding}, + author = {Cordts, Marius and Omran, Mohamed and Ramos, Sebastian and Rehfeld, Timo and Enzweiler, + Markus and Benenson, Rodrigo and Franke, Uwe and Roth, Stefan and Schiele, Bernt}, + booktitle = {Proc. of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, + year = {2016} + } + + +.. include:: mindspore.dataset.api_list_vision.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.CocoDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.CocoDataset.rst index 070354786a5..60d533cbb33 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.CocoDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.CocoDataset.rst @@ -1,141 +1,141 @@ -mindspore.dataset.CocoDataset -============================== - -.. py:class:: mindspore.dataset.CocoDataset(dataset_dir, annotation_file, task='Detection', num_samples=None, num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None, cache=None, extra_metadata=False, decrypt=None) - - COCO(Common Objects in Context)数据集。 - - 该API支持解析COCO2017数据集,支持四种类型的机器学习任务,分别是目标检测、关键点检测、物体分割、全景分割和图片注解。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录路径。 - - **annotation_file** (str) - 数据集标注JSON文件的路径。 - - **task** (str, 可选) - 指定COCO数据的任务类型。支持的任务类型包括: ``'Detection'`` (目标检测) 、 ``'Stuff'`` (物体分割) 、 ``'Panoptic'`` (全景分割) 、 ``'Keypoint'`` (关键点检测)和 ``'Captioning'`` (图片注解) 。默认值: ``'Detection'`` 。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数,可以小于数据集总数。默认值: ``None`` ,读取全部样本图片。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` ,表2中会展示不同参数配置的预期行为。 - - **decode** (bool, 可选) - 是否对读取的图片进行解码操作。默认值: ``False`` ,不解码。 - - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` ,表2中会展示不同配置的预期行为。 - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - **extra_metadata** (bool, 可选) - 用于指定是否额外输出一个数据列用于表示图片元信息。如果为True,则将额外输出一个名为 `[_meta-filename, dtype=string]` 的数据列。默认值: ``False`` 。 - - **decrypt** (callable, 可选) - 图像解密函数,接受加密的图片路径并返回bytes类型的解密数据。默认值: ``None`` ,不进行解密。 - - 根据不同 `task` 参数设置,生成数据集具有不同的输出列: - - +-------------------------+----------------------------------------------+ - | `task` | 输出列 | - +=========================+==============================================+ - | Detection | [image, dtype=uint8] | - | | | - | | [bbox, dtype=float32] | - | | | - | | [category_id, dtype=uint32] | - | | | - | | [iscrowd, dtype=uint32] | - +-------------------------+----------------------------------------------+ - | Stuff | [image, dtype=uint8] | - | | | - | | [segmentation, dtype=float32] | - | | | - | | [iscrowd, dtype=uint32] | - +-------------------------+----------------------------------------------+ - | Keypoint | [image, dtype=uint8] | - | | | - | | [keypoints, dtype=float32] | - | | | - | | [num_keypoints, dtype=uint32] | - +-------------------------+----------------------------------------------+ - | Panoptic | [image, dtype=uint8] | - | | | - | | [bbox, dtype=float32] | - | | | - | | [category_id, dtype=uint32] | - | | | - | | [iscrowd, dtype=uint32] | - | | | - | | [area, dtype=uint32] | - +-------------------------+----------------------------------------------+ - | Captioning | [image, dtype=uint8] | - | | | - | | [captions, dtype=string] | - +-------------------------+----------------------------------------------+ - - 异常: - - **RuntimeError** - `dataset_dir` 路径下不包含任何数据文件。 - - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 - - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **RuntimeError** - 解析 `annotation_file` 指定的JSON文件失败。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - **ValueError** - `task` 参数取值不为 ``'Detection'`` 、 ``'Stuff'`` 、 ``'Panoptic'`` 、 ``'Keypoint'`` 或 ``'Captioning'`` 。 - - **ValueError** - `annotation_file` 参数对应的文件不存在。 - - **ValueError** - `dataset_dir` 参数路径不存在。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - .. note:: - - 当参数 `extra_metadata` 为 ``True`` 时,还需使用 `rename` 操作删除额外数据列 '_meta-filename'的前缀 '_meta-', - 否则迭代得到的数据行中不会出现此额外数据列。 - - 暂不支持指定 `sampler` 参数为 :class:`mindspore.dataset.PKSampler`。 - - 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 - - .. include:: mindspore.dataset.sampler.rst - - **关于COCO数据集:** - - Microsoft Common Objects in Context(COCO)是一个大型数据集,该数据集专门为目标检测,语义分割和字幕生成任务而设计。它拥有330K张图像(标记数量大于200K个)、1500000个目标实例、80个目标类别、91个对象类别、每张图片均有5个字幕、带关键点标注的人有250000个。与流行的ImageNet数据集相比,COCO的类别较少,但每个类别中的图片样本非常多。 - - 您可以解压缩原始COCO-2017数据集文件得到如下目录结构,并通过MindSpore的API读取。 - - .. code-block:: - - . - └── coco_dataset_directory - ├── train2017 - │ ├── 000000000009.jpg - │ ├── 000000000025.jpg - │ ├── ... - ├── test2017 - │ ├── 000000000001.jpg - │ ├── 000000058136.jpg - │ ├── ... - ├── val2017 - │ ├── 000000000139.jpg - │ ├── 000000057027.jpg - │ ├── ... - └── annotation - ├── captions_train2017.json - ├── captions_val2017.json - ├── instances_train2017.json - ├── instances_val2017.json - ├── person_keypoints_train2017.json - └── person_keypoints_val2017.json - - **引用:** - - .. code-block:: - - @article{DBLP:journals/corr/LinMBHPRDZ14, - author = {Tsung{-}Yi Lin and Michael Maire and Serge J. Belongie and - Lubomir D. Bourdev and Ross B. Girshick and James Hays and - Pietro Perona and Deva Ramanan and Piotr Doll{\'{a}}r and C. Lawrence Zitnick}, - title = {Microsoft {COCO:} Common Objects in Context}, - journal = {CoRR}, - volume = {abs/1405.0312}, - year = {2014}, - url = {http://arxiv.org/abs/1405.0312}, - archivePrefix = {arXiv}, - eprint = {1405.0312}, - timestamp = {Mon, 13 Aug 2018 16:48:13 +0200}, - biburl = {https://dblp.org/rec/journals/corr/LinMBHPRDZ14.bib}, - bibsource = {dblp computer science bibliography, https://dblp.org} - } - - -.. include:: mindspore.dataset.api_list_vision.rst +mindspore.dataset.CocoDataset +============================== + +.. py:class:: mindspore.dataset.CocoDataset(dataset_dir, annotation_file, task='Detection', num_samples=None, num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None, cache=None, extra_metadata=False, decrypt=None) + + COCO(Common Objects in Context)数据集。 + + 该API支持解析COCO2017数据集,支持四种类型的机器学习任务,分别是目标检测、关键点检测、物体分割、全景分割和图片注解。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录路径。 + - **annotation_file** (str) - 数据集标注JSON文件的路径。 + - **task** (str, 可选) - 指定COCO数据的任务类型。支持的任务类型包括: ``'Detection'`` (目标检测) 、 ``'Stuff'`` (物体分割) 、 ``'Panoptic'`` (全景分割) 、 ``'Keypoint'`` (关键点检测)和 ``'Captioning'`` (图片注解) 。默认值: ``'Detection'`` 。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数,可以小于数据集总数。默认值: ``None`` ,读取全部样本图片。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` ,表2中会展示不同参数配置的预期行为。 + - **decode** (bool, 可选) - 是否对读取的图片进行解码操作。默认值: ``False`` ,不解码。 + - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` ,表2中会展示不同配置的预期行为。 + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + - **extra_metadata** (bool, 可选) - 用于指定是否额外输出一个数据列用于表示图片元信息。如果为True,则将额外输出一个名为 `[_meta-filename, dtype=string]` 的数据列。默认值: ``False`` 。 + - **decrypt** (callable, 可选) - 图像解密函数,接受加密的图片路径并返回bytes类型的解密数据。默认值: ``None`` ,不进行解密。 + + 根据不同 `task` 参数设置,生成数据集具有不同的输出列: + + +-------------------------+----------------------------------------------+ + | `task` | 输出列 | + +=========================+==============================================+ + | Detection | [image, dtype=uint8] | + | | | + | | [bbox, dtype=float32] | + | | | + | | [category_id, dtype=uint32] | + | | | + | | [iscrowd, dtype=uint32] | + +-------------------------+----------------------------------------------+ + | Stuff | [image, dtype=uint8] | + | | | + | | [segmentation, dtype=float32] | + | | | + | | [iscrowd, dtype=uint32] | + +-------------------------+----------------------------------------------+ + | Keypoint | [image, dtype=uint8] | + | | | + | | [keypoints, dtype=float32] | + | | | + | | [num_keypoints, dtype=uint32] | + +-------------------------+----------------------------------------------+ + | Panoptic | [image, dtype=uint8] | + | | | + | | [bbox, dtype=float32] | + | | | + | | [category_id, dtype=uint32] | + | | | + | | [iscrowd, dtype=uint32] | + | | | + | | [area, dtype=uint32] | + +-------------------------+----------------------------------------------+ + | Captioning | [image, dtype=uint8] | + | | | + | | [captions, dtype=string] | + +-------------------------+----------------------------------------------+ + + 异常: + - **RuntimeError** - `dataset_dir` 路径下不包含任何数据文件。 + - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 + - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **RuntimeError** - 解析 `annotation_file` 指定的JSON文件失败。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + - **ValueError** - `task` 参数取值不为 ``'Detection'`` 、 ``'Stuff'`` 、 ``'Panoptic'`` 、 ``'Keypoint'`` 或 ``'Captioning'`` 。 + - **ValueError** - `annotation_file` 参数对应的文件不存在。 + - **ValueError** - `dataset_dir` 参数路径不存在。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + .. note:: + - 当参数 `extra_metadata` 为 ``True`` 时,还需使用 `rename` 操作删除额外数据列 '_meta-filename'的前缀 '_meta-', + 否则迭代得到的数据行中不会出现此额外数据列。 + - 暂不支持指定 `sampler` 参数为 :class:`mindspore.dataset.PKSampler`。 + - 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 + + .. include:: mindspore.dataset.sampler.rst + + **关于COCO数据集:** + + Microsoft Common Objects in Context(COCO)是一个大型数据集,该数据集专门为目标检测,语义分割和字幕生成任务而设计。它拥有330K张图像(标记数量大于200K个)、1500000个目标实例、80个目标类别、91个对象类别、每张图片均有5个字幕、带关键点标注的人有250000个。与流行的ImageNet数据集相比,COCO的类别较少,但每个类别中的图片样本非常多。 + + 您可以解压缩原始COCO-2017数据集文件得到如下目录结构,并通过MindSpore的API读取。 + + .. code-block:: + + . + └── coco_dataset_directory + ├── train2017 + │ ├── 000000000009.jpg + │ ├── 000000000025.jpg + │ ├── ... + ├── test2017 + │ ├── 000000000001.jpg + │ ├── 000000058136.jpg + │ ├── ... + ├── val2017 + │ ├── 000000000139.jpg + │ ├── 000000057027.jpg + │ ├── ... + └── annotation + ├── captions_train2017.json + ├── captions_val2017.json + ├── instances_train2017.json + ├── instances_val2017.json + ├── person_keypoints_train2017.json + └── person_keypoints_val2017.json + + **引用:** + + .. code-block:: + + @article{DBLP:journals/corr/LinMBHPRDZ14, + author = {Tsung{-}Yi Lin and Michael Maire and Serge J. Belongie and + Lubomir D. Bourdev and Ross B. Girshick and James Hays and + Pietro Perona and Deva Ramanan and Piotr Doll{\'{a}}r and C. Lawrence Zitnick}, + title = {Microsoft {COCO:} Common Objects in Context}, + journal = {CoRR}, + volume = {abs/1405.0312}, + year = {2014}, + url = {http://arxiv.org/abs/1405.0312}, + archivePrefix = {arXiv}, + eprint = {1405.0312}, + timestamp = {Mon, 13 Aug 2018 16:48:13 +0200}, + biburl = {https://dblp.org/rec/journals/corr/LinMBHPRDZ14.bib}, + bibsource = {dblp computer science bibliography, https://dblp.org} + } + + +.. include:: mindspore.dataset.api_list_vision.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.DBpediaDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.DBpediaDataset.rst index 3e65b321ad4..8178026cf4e 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.DBpediaDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.DBpediaDataset.rst @@ -1,69 +1,69 @@ -mindspore.dataset.DBpediaDataset -================================ - -.. py:class:: mindspore.dataset.DBpediaDataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None) - - DBpedia数据集。 - - 生成的数据集有三列 `[class, title, content]` ,三列的数据类型均为string。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录路径。 - - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` , ``'test'`` 或 ``'all'`` 。 - ``'train'`` 将读取560,000个训练样本, ``'test'`` 将读取70,000个测试样本中, ``'all'`` 将读取所有630,000个样本。默认值: ``None`` ,读取全部样本。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取所有样本。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (Union[bool, :class:`~.dataset.Shuffle`], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定。默认值: ``Shuffle.GLOBAL`` 。 - 如果 `shuffle` 为 ``False`` ,则不混洗,如果 `shuffle` 为 ``True`` ,等同于将 `shuffle` 设置为 ``mindspore.dataset.Shuffle.GLOBAL`` 。 - 通过传入枚举变量设置数据混洗的模式: - - - ``Shuffle.GLOBAL`` :混洗文件和样本。 - - ``Shuffle.FILES`` :仅混洗文件。 - - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 异常: - - **RuntimeError** - `dataset_dir` 参数所指向的文件目录不存在或缺少数据集文件。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - **关于DBpedia数据集:** - - DBpedia数据集包括14个类,超过63万个文本样本,train.csv中有56万样本,test.csv中有7万测试样本。 - 14个不同的类别分别是:公司、教育学院、艺术家、运动员、文员,交通,建筑,自然场所,村庄,动物,植物,专辑,电影,书面工作。 - - 以下是原始DBpedia数据集结构。 - 可以将数据集文件解压缩到此目录结构中,并通过Mindspore的API读取。 - - .. code-block:: - - . - └── dbpedia_dataset_dir - ├── train.csv - ├── test.csv - ├── classes.txt - └── readme.txt - - **引用:** - - .. code-block:: - - @article{DBpedia, - title = {DBPedia Ontology Classification Dataset}, - author = {Jens Lehmann, Robert Isele, Max Jakob, Anja Jentzsch, Dimitris Kontokostas, - Pablo N. Mendes, Sebastian Hellmann, Mohamed Morsey, Patrick van Kleef, - Sören Auer, Christian Bizer}, - year = {2015}, - howpublished = {http://dbpedia.org} - } - - -.. include:: mindspore.dataset.api_list_nlp.rst +mindspore.dataset.DBpediaDataset +================================ + +.. py:class:: mindspore.dataset.DBpediaDataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None) + + DBpedia数据集。 + + 生成的数据集有三列 `[class, title, content]` ,三列的数据类型均为string。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录路径。 + - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` , ``'test'`` 或 ``'all'`` 。 + ``'train'`` 将读取560,000个训练样本, ``'test'`` 将读取70,000个测试样本中, ``'all'`` 将读取所有630,000个样本。默认值: ``None`` ,读取全部样本。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取所有样本。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (Union[bool, :class:`~.dataset.Shuffle`], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定。默认值: ``Shuffle.GLOBAL`` 。 + 如果 `shuffle` 为 ``False`` ,则不混洗,如果 `shuffle` 为 ``True`` ,等同于将 `shuffle` 设置为 ``mindspore.dataset.Shuffle.GLOBAL`` 。 + 通过传入枚举变量设置数据混洗的模式: + + - ``Shuffle.GLOBAL`` :混洗文件和样本。 + - ``Shuffle.FILES`` :仅混洗文件。 + + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 异常: + - **RuntimeError** - `dataset_dir` 参数所指向的文件目录不存在或缺少数据集文件。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + **关于DBpedia数据集:** + + DBpedia数据集包括14个类,超过63万个文本样本,train.csv中有56万样本,test.csv中有7万测试样本。 + 14个不同的类别分别是:公司、教育学院、艺术家、运动员、文员,交通,建筑,自然场所,村庄,动物,植物,专辑,电影,书面工作。 + + 以下是原始DBpedia数据集结构。 + 可以将数据集文件解压缩到此目录结构中,并通过Mindspore的API读取。 + + .. code-block:: + + . + └── dbpedia_dataset_dir + ├── train.csv + ├── test.csv + ├── classes.txt + └── readme.txt + + **引用:** + + .. code-block:: + + @article{DBpedia, + title = {DBPedia Ontology Classification Dataset}, + author = {Jens Lehmann, Robert Isele, Max Jakob, Anja Jentzsch, Dimitris Kontokostas, + Pablo N. Mendes, Sebastian Hellmann, Mohamed Morsey, Patrick van Kleef, + Sören Auer, Christian Bizer}, + year = {2015}, + howpublished = {http://dbpedia.org} + } + + +.. include:: mindspore.dataset.api_list_nlp.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.DIV2KDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.DIV2KDataset.rst index 390897bd71e..3b230fa3f43 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.DIV2KDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.DIV2KDataset.rst @@ -1,122 +1,122 @@ -mindspore.dataset.DIV2KDataset -============================== - -.. py:class:: mindspore.dataset.DIV2KDataset(dataset_dir, usage="train", downgrade="bicubic", scale=2, num_samples=None, num_parallel_workers=None, shuffle=None, decode=None, sampler=None, num_shards=None, shard_id=None, cache=None) - - DIV2K(DIVerse 2K resolution image)数据集。 - - 生成的数据集有两列 `[hr_image, lr_image]` 。 `hr_image` 列和 `lr_image` 列的数据类型都为uint8。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录路径。 - - **usage** (str, 可选) - 指定数据集的子集。可取值为 ``'train'`` 、 ``'valid'`` 或 ``'all'`` 。默认值: ``'train'`` 。 - - **downgrade** (str, 可选) - 指定数据集的下采样的模式,可取值为 ``'bicubic'`` 、 ``'unknown'`` 、 ``'mild'`` 、 ``'difficult'`` 或 ``'wild'`` 。默认值: ``'bicubic'`` 。 - - **scale** (int, 可选) - 指定数据集的缩放尺度。当参数 `downgrade` 取值为 ``'bicubic'`` 时,此参数可以取值为 ``2`` 、 ``3`` 、 ``4`` 、``8`` 。 - 当参数 `downgrade` 取值为 ``'unknown'`` 时,此参数可以取值为 ``2`` 、 ``3`` 、 ``4`` 。当参数 `downgrade` 取值为 ``'mild'`` 、 ``'difficult'`` 、 ``'wild'`` 时,此参数仅可以取值为 ``4`` 。默认值: ``2`` 。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数,可以小于数据集总数。默认值: ``None`` ,读取全部样本图片。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 - - **decode** (bool, 可选) - 是否对读取的图片进行解码操作。默认值: ``None`` ,默认为 ``False`` ,不解码。 - - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 异常: - - **RuntimeError** - `dataset_dir` 路径下不包含任何数据文件。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 - - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - `dataset_dir` 路径非法或不存在。 - - **ValueError** - `usage` 参数取值不为 ``'train'`` 、 ``'valid'`` 或 ``'all'`` 。 - - **ValueError** - `downgrade` 参数取值不为 ``'bicubic'`` 、 ``'unknown'`` 、 ``'mild'`` 、 ``'difficult'`` 或 ``'wild'`` 。 - - **ValueError** - `scale` 参数取值不在给定的字段中,或与 `downgrade` 参数的值不匹配。 - - **ValueError** - `scale` 参数取值为8,但 `downgrade` 参数的值不为 ``'bicubic'`` 。 - - **ValueError** - `downgrade` 参数取值为 ``'mild'`` 、 ``'difficult'`` 或 ``'wild'`` ,但 `scale` 参数的值不为 ``4`` 。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 - - .. include:: mindspore.dataset.sampler.rst - - **关于DIV2K数据集:** - - DIV2K数据集由1000张2K分辨率图像组成,其中800张用于训练,100张用于验证,100张用于测试。 - 作为NTIRE比赛的数据集,NTIRE 2017 和 NTIRE 2018 仅包括DIV2K的训练数据集和验证数据集。 - - 您可以解压缩原始DIV2K数据集文件到如下目录结构,并通过MindSpore的API进行读取。 - - 以训练数据集作为例子。 - - .. code-block:: - - . - └── DIV2K - ├── DIV2K_train_HR - | ├── 0001.png - | ├── 0002.png - | ├── ... - ├── DIV2K_train_LR_bicubic - | ├── X2 - | | ├── 0001x2.png - | | ├── 0002x2.png - | | ├── ... - | ├── X3 - | | ├── 0001x3.png - | | ├── 0002x3.png - | | ├── ... - | └── X4 - | ├── 0001x4.png - | ├── 0002x4.png - | ├── ... - ├── DIV2K_train_LR_unknown - | ├── X2 - | | ├── 0001x2.png - | | ├── 0002x2.png - | | ├── ... - | ├── X3 - | | ├── 0001x3.png - | | ├── 0002x3.png - | | ├── ... - | └── X4 - | ├── 0001x4.png - | ├── 0002x4.png - | ├── ... - ├── DIV2K_train_LR_mild - | ├── 0001x4m.png - | ├── 0002x4m.png - | ├── ... - ├── DIV2K_train_LR_difficult - | ├── 0001x4d.png - | ├── 0002x4d.png - | ├── ... - ├── DIV2K_train_LR_wild - | ├── 0001x4w.png - | ├── 0002x4w.png - | ├── ... - └── DIV2K_train_LR_x8 - ├── 0001x8.png - ├── 0002x8.png - ├── ... - - **引用:** - - .. code-block:: - - @InProceedings{Agustsson_2017_CVPR_Workshops, - author = {Agustsson, Eirikur and Timofte, Radu}, - title = {NTIRE 2017 Challenge on Single Image Super-Resolution: Dataset and Study}, - booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops}, - url = "http://www.vision.ee.ethz.ch/~timofter/publications/Agustsson-CVPRW-2017.pdf", - month = {July}, - year = {2017} - } - - -.. include:: mindspore.dataset.api_list_vision.rst +mindspore.dataset.DIV2KDataset +============================== + +.. py:class:: mindspore.dataset.DIV2KDataset(dataset_dir, usage="train", downgrade="bicubic", scale=2, num_samples=None, num_parallel_workers=None, shuffle=None, decode=None, sampler=None, num_shards=None, shard_id=None, cache=None) + + DIV2K(DIVerse 2K resolution image)数据集。 + + 生成的数据集有两列 `[hr_image, lr_image]` 。 `hr_image` 列和 `lr_image` 列的数据类型都为uint8。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录路径。 + - **usage** (str, 可选) - 指定数据集的子集。可取值为 ``'train'`` 、 ``'valid'`` 或 ``'all'`` 。默认值: ``'train'`` 。 + - **downgrade** (str, 可选) - 指定数据集的下采样的模式,可取值为 ``'bicubic'`` 、 ``'unknown'`` 、 ``'mild'`` 、 ``'difficult'`` 或 ``'wild'`` 。默认值: ``'bicubic'`` 。 + - **scale** (int, 可选) - 指定数据集的缩放尺度。当参数 `downgrade` 取值为 ``'bicubic'`` 时,此参数可以取值为 ``2`` 、 ``3`` 、 ``4`` 、``8`` 。 + 当参数 `downgrade` 取值为 ``'unknown'`` 时,此参数可以取值为 ``2`` 、 ``3`` 、 ``4`` 。当参数 `downgrade` 取值为 ``'mild'`` 、 ``'difficult'`` 、 ``'wild'`` 时,此参数仅可以取值为 ``4`` 。默认值: ``2`` 。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数,可以小于数据集总数。默认值: ``None`` ,读取全部样本图片。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 + - **decode** (bool, 可选) - 是否对读取的图片进行解码操作。默认值: ``None`` ,默认为 ``False`` ,不解码。 + - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 异常: + - **RuntimeError** - `dataset_dir` 路径下不包含任何数据文件。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 + - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - `dataset_dir` 路径非法或不存在。 + - **ValueError** - `usage` 参数取值不为 ``'train'`` 、 ``'valid'`` 或 ``'all'`` 。 + - **ValueError** - `downgrade` 参数取值不为 ``'bicubic'`` 、 ``'unknown'`` 、 ``'mild'`` 、 ``'difficult'`` 或 ``'wild'`` 。 + - **ValueError** - `scale` 参数取值不在给定的字段中,或与 `downgrade` 参数的值不匹配。 + - **ValueError** - `scale` 参数取值为8,但 `downgrade` 参数的值不为 ``'bicubic'`` 。 + - **ValueError** - `downgrade` 参数取值为 ``'mild'`` 、 ``'difficult'`` 或 ``'wild'`` ,但 `scale` 参数的值不为 ``4`` 。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 + + .. include:: mindspore.dataset.sampler.rst + + **关于DIV2K数据集:** + + DIV2K数据集由1000张2K分辨率图像组成,其中800张用于训练,100张用于验证,100张用于测试。 + 作为NTIRE比赛的数据集,NTIRE 2017 和 NTIRE 2018 仅包括DIV2K的训练数据集和验证数据集。 + + 您可以解压缩原始DIV2K数据集文件到如下目录结构,并通过MindSpore的API进行读取。 + + 以训练数据集作为例子。 + + .. code-block:: + + . + └── DIV2K + ├── DIV2K_train_HR + | ├── 0001.png + | ├── 0002.png + | ├── ... + ├── DIV2K_train_LR_bicubic + | ├── X2 + | | ├── 0001x2.png + | | ├── 0002x2.png + | | ├── ... + | ├── X3 + | | ├── 0001x3.png + | | ├── 0002x3.png + | | ├── ... + | └── X4 + | ├── 0001x4.png + | ├── 0002x4.png + | ├── ... + ├── DIV2K_train_LR_unknown + | ├── X2 + | | ├── 0001x2.png + | | ├── 0002x2.png + | | ├── ... + | ├── X3 + | | ├── 0001x3.png + | | ├── 0002x3.png + | | ├── ... + | └── X4 + | ├── 0001x4.png + | ├── 0002x4.png + | ├── ... + ├── DIV2K_train_LR_mild + | ├── 0001x4m.png + | ├── 0002x4m.png + | ├── ... + ├── DIV2K_train_LR_difficult + | ├── 0001x4d.png + | ├── 0002x4d.png + | ├── ... + ├── DIV2K_train_LR_wild + | ├── 0001x4w.png + | ├── 0002x4w.png + | ├── ... + └── DIV2K_train_LR_x8 + ├── 0001x8.png + ├── 0002x8.png + ├── ... + + **引用:** + + .. code-block:: + + @InProceedings{Agustsson_2017_CVPR_Workshops, + author = {Agustsson, Eirikur and Timofte, Radu}, + title = {NTIRE 2017 Challenge on Single Image Super-Resolution: Dataset and Study}, + booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops}, + url = "http://www.vision.ee.ethz.ch/~timofter/publications/Agustsson-CVPRW-2017.pdf", + month = {July}, + year = {2017} + } + + +.. include:: mindspore.dataset.api_list_vision.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.DistributedSampler.rst b/docs/api/api_python/dataset/mindspore.dataset.DistributedSampler.rst index dfa003a7504..50ec22bdfbf 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.DistributedSampler.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.DistributedSampler.rst @@ -1,28 +1,28 @@ -mindspore.dataset.DistributedSampler -==================================== - -.. py:class:: mindspore.dataset.DistributedSampler(num_shards, shard_id, shuffle=True, num_samples=None, offset=-1) - - 分布式采样器,将数据集进行分片用于分布式训练。 - - 参数: - - **num_shards** (int) - 数据集分片数量。 - - **shard_id** (int) - 当前分片的分片ID,应在[0, num_shards-1]范围内。 - - **shuffle** (bool, 可选) - 是否混洗采样得到的样本。默认值: ``True`` ,混洗样本。 - - **num_samples** (int, 可选) - 获取的样本数,可用于部分获取采样得到的样本。默认值: ``None`` ,获取采样到的所有样本。 - - **offset** (int, 可选) - 分布式采样结果进行分配时的起始分片ID号,值不能大于参数 `num_shards` 。从不同的分片ID开始分配数据可能会影响每个分片的最终样本数。仅当ConcatDataset以 :class:`mindspore.dataset.DistributedSampler` 为采样器时,此参数才有效。默认值: ``-1`` ,每个分片具有相同的样本数。 - - 异常: - - **TypeError** - `num_shards` 的类型不是int。 - - **TypeError** - `shard_id` 的类型不是int。 - - **TypeError** - `shuffle` 的类型不是bool。 - - **TypeError** - `num_samples` 的类型不是int。 - - **TypeError** - `offset` 的类型不是int。 - - **ValueError** - `num_samples` 为负值。 - - **RuntimeError** - `num_shards` 不是正值。 - - **RuntimeError** - `shard_id` 小于0或大于等于 `num_shards` 。 - - **RuntimeError** - `offset` 大于 `num_shards` 。 - - .. include:: mindspore.dataset.BuiltinSampler.rst - +mindspore.dataset.DistributedSampler +==================================== + +.. py:class:: mindspore.dataset.DistributedSampler(num_shards, shard_id, shuffle=True, num_samples=None, offset=-1) + + 分布式采样器,将数据集进行分片用于分布式训练。 + + 参数: + - **num_shards** (int) - 数据集分片数量。 + - **shard_id** (int) - 当前分片的分片ID,应在[0, num_shards-1]范围内。 + - **shuffle** (bool, 可选) - 是否混洗采样得到的样本。默认值: ``True`` ,混洗样本。 + - **num_samples** (int, 可选) - 获取的样本数,可用于部分获取采样得到的样本。默认值: ``None`` ,获取采样到的所有样本。 + - **offset** (int, 可选) - 分布式采样结果进行分配时的起始分片ID号,值不能大于参数 `num_shards` 。从不同的分片ID开始分配数据可能会影响每个分片的最终样本数。仅当ConcatDataset以 :class:`mindspore.dataset.DistributedSampler` 为采样器时,此参数才有效。默认值: ``-1`` ,每个分片具有相同的样本数。 + + 异常: + - **TypeError** - `num_shards` 的类型不是int。 + - **TypeError** - `shard_id` 的类型不是int。 + - **TypeError** - `shuffle` 的类型不是bool。 + - **TypeError** - `num_samples` 的类型不是int。 + - **TypeError** - `offset` 的类型不是int。 + - **ValueError** - `num_samples` 为负值。 + - **RuntimeError** - `num_shards` 不是正值。 + - **RuntimeError** - `shard_id` 小于0或大于等于 `num_shards` 。 + - **RuntimeError** - `offset` 大于 `num_shards` 。 + + .. include:: mindspore.dataset.BuiltinSampler.rst + .. include:: mindspore.dataset.BuiltinSampler.b.rst \ No newline at end of file diff --git a/docs/api/api_python/dataset/mindspore.dataset.EMnistDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.EMnistDataset.rst index 6e773bc37b4..2f0d78caf79 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.EMnistDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.EMnistDataset.rst @@ -1,80 +1,80 @@ -mindspore.dataset.EMnistDataset -=============================== - -.. py:class:: mindspore.dataset.EMnistDataset(dataset_dir, name, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None) - - EMNIST(Extended MNIST)数据集。 - - 生成的数据集有两列: `[image, label]` 。 `image` 列的数据类型为uint8。 `label` 列的数据类型为uint32。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录路径。 - - **name** (str) - 按给定规则对数据集进行拆分,可以是 ``'byclass'`` 、 ``'bymerge'`` 、 ``'balanced'`` 、 ``'letters'`` 、 ``'digits'`` 或 ``'mnist'`` 。 - - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 或 ``'all'`` 。 - 取值为 ``'train'`` 时将会读取60,000个训练样本,取值为 ``'test'`` 时将会读取10,000个测试样本,取值为 ``'all'`` 时将会读取全部70,000个样本。默认值: ``None`` ,读取全部样本图片。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取全部样本图片。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 - - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 异常: - - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 - - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 - - .. include:: mindspore.dataset.sampler.rst - - **关于EMNIST数据集:** - - EMNIST数据集由一组手写字符数字组成,源自NIST特别版数据库19,并转换为与MNIST数据集直接匹配的28x28像素图像格式和数据集结构。 - 有关数据集内容和转换过程的更多信息可在 https://arxiv.org/abs/1702.05373v1 上查阅。 - - EMNIST按照不同的规则拆分成不同的子数据集的样本数和类数如下: - - 按类拆分:814,255个样本和62个样本不平衡类。 - 按合并拆分:814,255个样本和47个样本不平衡类。 - 平衡拆分:131,600个样本和47个样本平衡类。 - 按字母拆分:145,600个样本和26个样本平衡类。 - 按数字拆分:280,000个样本和10个样本平衡类。 - MNIST: 70,000个样本符和10个样本平衡类。 - - 以下是原始EMNIST数据集结构。 - 可以将数据集文件解压缩到此目录结构中,并由MindSpore的API读取。 - - .. code-block:: - - . - └── mnist_dataset_dir - ├── emnist-mnist-train-images-idx3-ubyte - ├── emnist-mnist-train-labels-idx1-ubyte - ├── emnist-mnist-test-images-idx3-ubyte - ├── emnist-mnist-test-labels-idx1-ubyte - ├── ... - - **引用:** - - .. code-block:: - - @article{cohen_afshar_tapson_schaik_2017, - title = {EMNIST: Extending MNIST to handwritten letters}, - DOI = {10.1109/ijcnn.2017.7966217}, - journal = {2017 International Joint Conference on Neural Networks (IJCNN)}, - author = {Cohen, Gregory and Afshar, Saeed and Tapson, Jonathan and Schaik, Andre Van}, - year = {2017}, - howpublished = {https://www.westernsydney.edu.au/icns/reproducible_research/ - publication_support_materials/emnist} - } - - -.. include:: mindspore.dataset.api_list_vision.rst +mindspore.dataset.EMnistDataset +=============================== + +.. py:class:: mindspore.dataset.EMnistDataset(dataset_dir, name, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None) + + EMNIST(Extended MNIST)数据集。 + + 生成的数据集有两列: `[image, label]` 。 `image` 列的数据类型为uint8。 `label` 列的数据类型为uint32。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录路径。 + - **name** (str) - 按给定规则对数据集进行拆分,可以是 ``'byclass'`` 、 ``'bymerge'`` 、 ``'balanced'`` 、 ``'letters'`` 、 ``'digits'`` 或 ``'mnist'`` 。 + - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 或 ``'all'`` 。 + 取值为 ``'train'`` 时将会读取60,000个训练样本,取值为 ``'test'`` 时将会读取10,000个测试样本,取值为 ``'all'`` 时将会读取全部70,000个样本。默认值: ``None`` ,读取全部样本图片。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取全部样本图片。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 + - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 异常: + - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 + - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 + + .. include:: mindspore.dataset.sampler.rst + + **关于EMNIST数据集:** + + EMNIST数据集由一组手写字符数字组成,源自NIST特别版数据库19,并转换为与MNIST数据集直接匹配的28x28像素图像格式和数据集结构。 + 有关数据集内容和转换过程的更多信息可在 https://arxiv.org/abs/1702.05373v1 上查阅。 + + EMNIST按照不同的规则拆分成不同的子数据集的样本数和类数如下: + + 按类拆分:814,255个样本和62个样本不平衡类。 + 按合并拆分:814,255个样本和47个样本不平衡类。 + 平衡拆分:131,600个样本和47个样本平衡类。 + 按字母拆分:145,600个样本和26个样本平衡类。 + 按数字拆分:280,000个样本和10个样本平衡类。 + MNIST: 70,000个样本符和10个样本平衡类。 + + 以下是原始EMNIST数据集结构。 + 可以将数据集文件解压缩到此目录结构中,并由MindSpore的API读取。 + + .. code-block:: + + . + └── mnist_dataset_dir + ├── emnist-mnist-train-images-idx3-ubyte + ├── emnist-mnist-train-labels-idx1-ubyte + ├── emnist-mnist-test-images-idx3-ubyte + ├── emnist-mnist-test-labels-idx1-ubyte + ├── ... + + **引用:** + + .. code-block:: + + @article{cohen_afshar_tapson_schaik_2017, + title = {EMNIST: Extending MNIST to handwritten letters}, + DOI = {10.1109/ijcnn.2017.7966217}, + journal = {2017 International Joint Conference on Neural Networks (IJCNN)}, + author = {Cohen, Gregory and Afshar, Saeed and Tapson, Jonathan and Schaik, Andre Van}, + year = {2017}, + howpublished = {https://www.westernsydney.edu.au/icns/reproducible_research/ + publication_support_materials/emnist} + } + + +.. include:: mindspore.dataset.api_list_vision.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.EnWik9Dataset.rst b/docs/api/api_python/dataset/mindspore.dataset.EnWik9Dataset.rst index f681929fc37..abc0786fc77 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.EnWik9Dataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.EnWik9Dataset.rst @@ -1,62 +1,62 @@ -mindspore.dataset.EnWik9Dataset -=============================== - -.. py:class:: mindspore.dataset.EnWik9Dataset(dataset_dir, num_samples=None, num_parallel_workers=None, shuffle=True, num_shards=None, shard_id=None, cache=None) - - EnWik9数据集。 - - 生成的数据集有一列 `[text]` ,数据类型为string。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录路径。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取所有样本。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (Union[bool, :class:`~.dataset.Shuffle`], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定。默认值: ``True`` 。 - 如果 `shuffle` 为 ``False`` ,则不混洗,如果 `shuffle` 为 ``True`` ,等同于将 `shuffle` 设置为 ``mindspore.dataset.Shuffle.GLOBAL`` 。 - 通过传入枚举变量设置数据混洗的模式: - - - ``Shuffle.GLOBAL`` :混洗文件和样本。 - - ``Shuffle.FILES`` :仅混洗文件。 - - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 异常: - - **RuntimeError** - `dataset_dir` 参数所指向的文件目录不存在或缺少数据集文件。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - **关于EnWik9数据集:** - - EnWik9的数据是一系列UTF-8编码的XML,主要由英文文本组成。数据集包含243,426篇文章标题,其中85,560个被重定向以修复丢失的网页链接,其余是常规文章。 - - 数据是UTF-8格式。所有字符都在U'0000到U'10FFFF范围内,有效编码为1到4字节。字节值0xC0、0xC1和0xF5-0xFF从未出现。此外,在维基百科转储中,除了0x09(制表符)和0x0A(换行符)外,没有范围为0x00-0x1F的控制字符。 - 断行符只出现在段落边界上,因此整体是有语义目的。 - - 可以将数据集文件解压缩到以下目录结构中,并由MindSpore的API读取。 - - .. code-block:: - - . - └── EnWik9 - ├── enwik9 - - **引用:** - - .. code-block:: - - @NetworkResource{Hutter_prize, - author = {English Wikipedia}, - url = "https://cs.fit.edu/~mmahoney/compression/textdata.html", - month = {March}, - year = {2006} - } - - -.. include:: mindspore.dataset.api_list_nlp.rst +mindspore.dataset.EnWik9Dataset +=============================== + +.. py:class:: mindspore.dataset.EnWik9Dataset(dataset_dir, num_samples=None, num_parallel_workers=None, shuffle=True, num_shards=None, shard_id=None, cache=None) + + EnWik9数据集。 + + 生成的数据集有一列 `[text]` ,数据类型为string。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录路径。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取所有样本。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (Union[bool, :class:`~.dataset.Shuffle`], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定。默认值: ``True`` 。 + 如果 `shuffle` 为 ``False`` ,则不混洗,如果 `shuffle` 为 ``True`` ,等同于将 `shuffle` 设置为 ``mindspore.dataset.Shuffle.GLOBAL`` 。 + 通过传入枚举变量设置数据混洗的模式: + + - ``Shuffle.GLOBAL`` :混洗文件和样本。 + - ``Shuffle.FILES`` :仅混洗文件。 + + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 异常: + - **RuntimeError** - `dataset_dir` 参数所指向的文件目录不存在或缺少数据集文件。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + **关于EnWik9数据集:** + + EnWik9的数据是一系列UTF-8编码的XML,主要由英文文本组成。数据集包含243,426篇文章标题,其中85,560个被重定向以修复丢失的网页链接,其余是常规文章。 + + 数据是UTF-8格式。所有字符都在U'0000到U'10FFFF范围内,有效编码为1到4字节。字节值0xC0、0xC1和0xF5-0xFF从未出现。此外,在维基百科转储中,除了0x09(制表符)和0x0A(换行符)外,没有范围为0x00-0x1F的控制字符。 + 断行符只出现在段落边界上,因此整体是有语义目的。 + + 可以将数据集文件解压缩到以下目录结构中,并由MindSpore的API读取。 + + .. code-block:: + + . + └── EnWik9 + ├── enwik9 + + **引用:** + + .. code-block:: + + @NetworkResource{Hutter_prize, + author = {English Wikipedia}, + url = "https://cs.fit.edu/~mmahoney/compression/textdata.html", + month = {March}, + year = {2006} + } + + +.. include:: mindspore.dataset.api_list_nlp.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.FakeImageDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.FakeImageDataset.rst index ece650b8a98..4c5cbc7ade6 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.FakeImageDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.FakeImageDataset.rst @@ -1,40 +1,40 @@ -mindspore.dataset.FakeImageDataset -================================== - -.. py:class:: mindspore.dataset.FakeImageDataset(num_images=1000, image_size=(224, 224, 3), num_classes=10, base_seed=0, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None) - - 生成虚假图像构建数据集。 - - 生成的数据集有两列: `[image, label]` 。 `image` 列的数据类型为uint8。 `label` 列的数据类型为uint32。 - - 参数: - - **num_images** (int, 可选) - 要生成的虚假图像数。默认值: ``1000`` 。 - - **image_size** (tuple, 可选) - 虚假图像的尺寸。默认值: ``(224, 224, 3)`` 。 - - **num_classes** (int, 可选) - 数据集的类别数。默认值: ``10`` 。 - - **base_seed** (int, 可选) - 生成随机图像的随机种子。默认值: ``0`` 。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数,可以小于数据集总数。默认值: ``None`` ,读取全部样本图片。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 - - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 异常: - - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 - - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 - - .. include:: mindspore.dataset.sampler.rst - :parser: reStructuredText - -.. include:: mindspore.dataset.api_list_vision.rst +mindspore.dataset.FakeImageDataset +================================== + +.. py:class:: mindspore.dataset.FakeImageDataset(num_images=1000, image_size=(224, 224, 3), num_classes=10, base_seed=0, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None) + + 生成虚假图像构建数据集。 + + 生成的数据集有两列: `[image, label]` 。 `image` 列的数据类型为uint8。 `label` 列的数据类型为uint32。 + + 参数: + - **num_images** (int, 可选) - 要生成的虚假图像数。默认值: ``1000`` 。 + - **image_size** (tuple, 可选) - 虚假图像的尺寸。默认值: ``(224, 224, 3)`` 。 + - **num_classes** (int, 可选) - 数据集的类别数。默认值: ``10`` 。 + - **base_seed** (int, 可选) - 生成随机图像的随机种子。默认值: ``0`` 。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数,可以小于数据集总数。默认值: ``None`` ,读取全部样本图片。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 + - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 异常: + - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 + - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 + + .. include:: mindspore.dataset.sampler.rst + :parser: reStructuredText + +.. include:: mindspore.dataset.api_list_vision.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.FashionMnistDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.FashionMnistDataset.rst index 9153228542b..79501ad53ec 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.FashionMnistDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.FashionMnistDataset.rst @@ -1,70 +1,70 @@ -mindspore.dataset.FashionMnistDataset -===================================== - -.. py:class:: mindspore.dataset.FashionMnistDataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None) - - Fashion-MNIST数据集。 - - 生成的数据集有两列: `[image, label]` 。 `image` 列的数据类型为uint8。 `label` 列的数据类型为uint32。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录路径。 - - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 或 ``'all'`` 。 - 取值为 ``'train'`` 时将会读取60,000个训练样本,取值为 ``'test'`` 时将会读取10,000个测试样本,取值为 ``'all'`` 时将会读取全部70,000个样本。默认值: ``None`` ,读取全部样本图片。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数,可以小于数据集总数。默认值: ``None`` ,读取全部样本图片。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 - - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 异常: - - **RuntimeError** - `dataset_dir` 路径下不包含数据文件。 - - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 - - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 - - .. include:: mindspore.dataset.sampler.rst - - **关于Fashion-MNIST数据集:** - - Fashion-MNIST是网络电子商城Zalando推出的数据集,包括60,000个样本的训练集和10,000个样本的测试集。每个示例都是一个28x28灰度图像,分别与10个类的标签关联。 - Fashion-MNIST是原始MNIST数据集的变种,用于对机器学习算法进行基准测试。它的训练集和测试集的图像尺寸和结构相同。 - - 可以将数据集文件解压缩到此目录结构中,并由MindSpore的API读取。 - - .. code-block:: - - . - └── fashionmnist_dataset_dir - ├── t10k-images-idx3-ubyte - ├── t10k-labels-idx1-ubyte - ├── train-images-idx3-ubyte - └── train-labels-idx1-ubyte - - **引用:** - - .. code-block:: - - @online{xiao2017/online, - author = {Han Xiao and Kashif Rasul and Roland Vollgraf}, - title = {Fashion-MNIST: a Novel Image Dataset for Benchmarking Machine Learning Algorithms}, - date = {2017-08-28}, - year = {2017}, - eprintclass = {cs.LG}, - eprinttype = {arXiv}, - eprint = {cs.LG/1708.07747}, - } - - -.. include:: mindspore.dataset.api_list_vision.rst +mindspore.dataset.FashionMnistDataset +===================================== + +.. py:class:: mindspore.dataset.FashionMnistDataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None) + + Fashion-MNIST数据集。 + + 生成的数据集有两列: `[image, label]` 。 `image` 列的数据类型为uint8。 `label` 列的数据类型为uint32。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录路径。 + - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 或 ``'all'`` 。 + 取值为 ``'train'`` 时将会读取60,000个训练样本,取值为 ``'test'`` 时将会读取10,000个测试样本,取值为 ``'all'`` 时将会读取全部70,000个样本。默认值: ``None`` ,读取全部样本图片。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数,可以小于数据集总数。默认值: ``None`` ,读取全部样本图片。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 + - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 异常: + - **RuntimeError** - `dataset_dir` 路径下不包含数据文件。 + - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 + - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 + + .. include:: mindspore.dataset.sampler.rst + + **关于Fashion-MNIST数据集:** + + Fashion-MNIST是网络电子商城Zalando推出的数据集,包括60,000个样本的训练集和10,000个样本的测试集。每个示例都是一个28x28灰度图像,分别与10个类的标签关联。 + Fashion-MNIST是原始MNIST数据集的变种,用于对机器学习算法进行基准测试。它的训练集和测试集的图像尺寸和结构相同。 + + 可以将数据集文件解压缩到此目录结构中,并由MindSpore的API读取。 + + .. code-block:: + + . + └── fashionmnist_dataset_dir + ├── t10k-images-idx3-ubyte + ├── t10k-labels-idx1-ubyte + ├── train-images-idx3-ubyte + └── train-labels-idx1-ubyte + + **引用:** + + .. code-block:: + + @online{xiao2017/online, + author = {Han Xiao and Kashif Rasul and Roland Vollgraf}, + title = {Fashion-MNIST: a Novel Image Dataset for Benchmarking Machine Learning Algorithms}, + date = {2017-08-28}, + year = {2017}, + eprintclass = {cs.LG}, + eprinttype = {arXiv}, + eprint = {cs.LG/1708.07747}, + } + + +.. include:: mindspore.dataset.api_list_vision.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.FlickrDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.FlickrDataset.rst index ab7a9f3d197..d0aa8c5d4bf 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.FlickrDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.FlickrDataset.rst @@ -1,110 +1,110 @@ -mindspore.dataset.FlickrDataset -================================ - -.. py:class:: mindspore.dataset.FlickrDataset(dataset_dir, annotation_file, num_samples=None, num_parallel_workers=None, shuffle=None, decode=None, sampler=None, num_shards=None, shard_id=None, cache=None) - - Flickr8k和Flickr30k数据集。 - - 生成的数据集有两列: `[image, annotation]`。 `image` 列的数据类型为uint8。 `annotation` 列是一个包含5个标注字符的张量,如["a", "b", "c", "d", "e"]。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录路径。 - - **annotation_file** (str) - 数据集标注JSON文件的路径。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数,可以小于数据集总数。默认值: ``None`` ,读取全部样本图片。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` ,表2中会展示不同参数配置的预期行为。 - - **decode** (bool, 可选) - 是否对读取的图片进行解码操作。默认值: ``None`` ,不解码。 - - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` ,表2中会展示不同配置的预期行为。 - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 异常: - - **RuntimeError** - `dataset_dir` 路径下不包含任何数据文件。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 - - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - `annotation_file` 参数对应的文件不存在。 - - **ValueError** - `dataset_dir` 参数路径不存在。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 - - .. include:: mindspore.dataset.sampler.rst - - **关于Flickr8k数据集:** - - Flickr8k数据集由8092张彩色图像组成。Flickr8k.token.txt中有40460个标注,每张图像有5个标注。 - - 可以将数据集文件解压缩到以下目录结构中,并由MindSpore的API读取。 - - .. code-block:: - - . - └── Flickr8k - ├── Flickr8k_Dataset - │ ├── 1000268201_693b08cb0e.jpg - │ ├── 1001773457_577c3a7d70.jpg - │ ├── ... - └── Flickr8k.token.txt - - **引用:** - - .. code-block:: - - @article{DBLP:journals/jair/HodoshYH13, - author = {Micah Hodosh and Peter Young and Julia Hockenmaier}, - title = {Framing Image Description as a Ranking Task: Data, Models and Evaluation Metrics}, - journal = {J. Artif. Intell. Res.}, - volume = {47}, - pages = {853--899}, - year = {2013}, - url = {https://doi.org/10.1613/jair.3994}, - doi = {10.1613/jair.3994}, - timestamp = {Mon, 21 Jan 2019 15:01:17 +0100}, - biburl = {https://dblp.org/rec/journals/jair/HodoshYH13.bib}, - bibsource = {dblp computer science bibliography, https://dblp.org} - } - - **关于Flickr30k数据集:** - - Flickr30k数据集由31783张彩色图像组成。results_20130124.token中有158915个标注,每个图像有5个标注。 - - 可以将数据集文件解压缩到以下目录结构中,并由MindSpore的API读取。 - - .. code-block:: - - . - └── Flickr30k - ├── flickr30k-images - │ ├── 1000092795.jpg - │ ├── 10002456.jpg - │ ├── ... - └── results_20130124.token - - **引用:** - - .. code-block:: - - @article{DBLP:journals/tacl/YoungLHH14, - author = {Peter Young and Alice Lai and Micah Hodosh and Julia Hockenmaier}, - title = {From image descriptions to visual denotations: New similarity metrics - for semantic inference over event descriptions}, - journal = {Trans. Assoc. Comput. Linguistics}, - volume = {2}, - pages = {67--78}, - year = {2014}, - url = {https://tacl2013.cs.columbia.edu/ojs/index.php/tacl/article/view/229}, - timestamp = {Wed, 17 Feb 2021 21:55:25 +0100}, - biburl = {https://dblp.org/rec/journals/tacl/YoungLHH14.bib}, - bibsource = {dblp computer science bibliography, https://dblp.org} - } - - -.. include:: mindspore.dataset.api_list_vision.rst +mindspore.dataset.FlickrDataset +================================ + +.. py:class:: mindspore.dataset.FlickrDataset(dataset_dir, annotation_file, num_samples=None, num_parallel_workers=None, shuffle=None, decode=None, sampler=None, num_shards=None, shard_id=None, cache=None) + + Flickr8k和Flickr30k数据集。 + + 生成的数据集有两列: `[image, annotation]`。 `image` 列的数据类型为uint8。 `annotation` 列是一个包含5个标注字符的张量,如["a", "b", "c", "d", "e"]。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录路径。 + - **annotation_file** (str) - 数据集标注JSON文件的路径。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数,可以小于数据集总数。默认值: ``None`` ,读取全部样本图片。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` ,表2中会展示不同参数配置的预期行为。 + - **decode** (bool, 可选) - 是否对读取的图片进行解码操作。默认值: ``None`` ,不解码。 + - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` ,表2中会展示不同配置的预期行为。 + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 异常: + - **RuntimeError** - `dataset_dir` 路径下不包含任何数据文件。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 + - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - `annotation_file` 参数对应的文件不存在。 + - **ValueError** - `dataset_dir` 参数路径不存在。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 + + .. include:: mindspore.dataset.sampler.rst + + **关于Flickr8k数据集:** + + Flickr8k数据集由8092张彩色图像组成。Flickr8k.token.txt中有40460个标注,每张图像有5个标注。 + + 可以将数据集文件解压缩到以下目录结构中,并由MindSpore的API读取。 + + .. code-block:: + + . + └── Flickr8k + ├── Flickr8k_Dataset + │ ├── 1000268201_693b08cb0e.jpg + │ ├── 1001773457_577c3a7d70.jpg + │ ├── ... + └── Flickr8k.token.txt + + **引用:** + + .. code-block:: + + @article{DBLP:journals/jair/HodoshYH13, + author = {Micah Hodosh and Peter Young and Julia Hockenmaier}, + title = {Framing Image Description as a Ranking Task: Data, Models and Evaluation Metrics}, + journal = {J. Artif. Intell. Res.}, + volume = {47}, + pages = {853--899}, + year = {2013}, + url = {https://doi.org/10.1613/jair.3994}, + doi = {10.1613/jair.3994}, + timestamp = {Mon, 21 Jan 2019 15:01:17 +0100}, + biburl = {https://dblp.org/rec/journals/jair/HodoshYH13.bib}, + bibsource = {dblp computer science bibliography, https://dblp.org} + } + + **关于Flickr30k数据集:** + + Flickr30k数据集由31783张彩色图像组成。results_20130124.token中有158915个标注,每个图像有5个标注。 + + 可以将数据集文件解压缩到以下目录结构中,并由MindSpore的API读取。 + + .. code-block:: + + . + └── Flickr30k + ├── flickr30k-images + │ ├── 1000092795.jpg + │ ├── 10002456.jpg + │ ├── ... + └── results_20130124.token + + **引用:** + + .. code-block:: + + @article{DBLP:journals/tacl/YoungLHH14, + author = {Peter Young and Alice Lai and Micah Hodosh and Julia Hockenmaier}, + title = {From image descriptions to visual denotations: New similarity metrics + for semantic inference over event descriptions}, + journal = {Trans. Assoc. Comput. Linguistics}, + volume = {2}, + pages = {67--78}, + year = {2014}, + url = {https://tacl2013.cs.columbia.edu/ojs/index.php/tacl/article/view/229}, + timestamp = {Wed, 17 Feb 2021 21:55:25 +0100}, + biburl = {https://dblp.org/rec/journals/tacl/YoungLHH14.bib}, + bibsource = {dblp computer science bibliography, https://dblp.org} + } + + +.. include:: mindspore.dataset.api_list_vision.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.Flowers102Dataset.rst b/docs/api/api_python/dataset/mindspore.dataset.Flowers102Dataset.rst index 461b18bcb16..90a270a9dd9 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.Flowers102Dataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.Flowers102Dataset.rst @@ -1,77 +1,77 @@ -mindspore.dataset.Flowers102Dataset -=================================== - -.. py:class:: mindspore.dataset.Flowers102Dataset(dataset_dir, task='Classification', usage='all', num_samples=None, num_parallel_workers=1, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None) - - Oxfird 102 Flower数据集。 - - 根据给定的 `task` 配置,生成数据集具有不同的输出列: - - - `task` 为 ``'Classification'`` ,输出列: `[image, dtype=uint8]` 、 `[label, dtype=uint32]` 。 - - `task` 为 ``'Segmentation'`` ,输出列: `[image, dtype=uint8]` 、 `[segmentation, dtype=uint8]` 、 `[label, dtype=uint32]` 。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录的路径。 - - **task** (str, 可选) - 指定读取数据的任务类型,支持 ``'Classification'`` 和 ``'Segmentation'``。默认值: ``'Classification'`` 。 - - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'valid'`` 、 ``'test'`` 或 ``'all'`` 。默认值: ``'all'`` ,读取全部样本。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,所有图像样本。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作进程数。默认值: ``1`` 。 - - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 - - **decode** (bool, 可选) - 是否对读取的图片进行解码操作。默认值: ``False`` ,不解码。 - - **sampler** (Union[Sampler, Iterable], 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - 异常: - - **RuntimeError** - `dataset_dir` 路径下不包含任何数据文件。 - - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 - - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 - - .. include:: mindspore.dataset.sampler.rst - - **关于Flowers102数据集:** - - Flowers102数据集由102个花类别组成,每个类由40到258张图像组成,这些花常见于英国。 - - 以下是原始的Flowers102数据集结构。 - 可以将数据集文件解压缩到此目录结构中,并通过MindSpore的API读取。 - - .. code-block:: - - . - └── flowes102_dataset_dir - ├── imagelabels.mat - ├── setid.mat - ├── jpg - ├── image_00001.jpg - ├── image_00002.jpg - ├── ... - ├── segmim - ├── segmim_00001.jpg - ├── segmim_00002.jpg - ├── ... - - **引用:** - - .. code-block:: - - @InProceedings{Nilsback08, - author = "Maria-Elena Nilsback and Andrew Zisserman", - title = "Automated Flower Classification over a Large Number of Classes", - booktitle = "Indian Conference on Computer Vision, Graphics and Image Processing", - month = "Dec", - year = "2008", - } - - -.. include:: mindspore.dataset.api_list_vision.rst +mindspore.dataset.Flowers102Dataset +=================================== + +.. py:class:: mindspore.dataset.Flowers102Dataset(dataset_dir, task='Classification', usage='all', num_samples=None, num_parallel_workers=1, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None) + + Oxfird 102 Flower数据集。 + + 根据给定的 `task` 配置,生成数据集具有不同的输出列: + + - `task` 为 ``'Classification'`` ,输出列: `[image, dtype=uint8]` 、 `[label, dtype=uint32]` 。 + - `task` 为 ``'Segmentation'`` ,输出列: `[image, dtype=uint8]` 、 `[segmentation, dtype=uint8]` 、 `[label, dtype=uint32]` 。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录的路径。 + - **task** (str, 可选) - 指定读取数据的任务类型,支持 ``'Classification'`` 和 ``'Segmentation'``。默认值: ``'Classification'`` 。 + - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'valid'`` 、 ``'test'`` 或 ``'all'`` 。默认值: ``'all'`` ,读取全部样本。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,所有图像样本。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作进程数。默认值: ``1`` 。 + - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 + - **decode** (bool, 可选) - 是否对读取的图片进行解码操作。默认值: ``False`` ,不解码。 + - **sampler** (Union[Sampler, Iterable], 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + + 异常: + - **RuntimeError** - `dataset_dir` 路径下不包含任何数据文件。 + - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 + - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 + + .. include:: mindspore.dataset.sampler.rst + + **关于Flowers102数据集:** + + Flowers102数据集由102个花类别组成,每个类由40到258张图像组成,这些花常见于英国。 + + 以下是原始的Flowers102数据集结构。 + 可以将数据集文件解压缩到此目录结构中,并通过MindSpore的API读取。 + + .. code-block:: + + . + └── flowes102_dataset_dir + ├── imagelabels.mat + ├── setid.mat + ├── jpg + ├── image_00001.jpg + ├── image_00002.jpg + ├── ... + ├── segmim + ├── segmim_00001.jpg + ├── segmim_00002.jpg + ├── ... + + **引用:** + + .. code-block:: + + @InProceedings{Nilsback08, + author = "Maria-Elena Nilsback and Andrew Zisserman", + title = "Automated Flower Classification over a Large Number of Classes", + booktitle = "Indian Conference on Computer Vision, Graphics and Image Processing", + month = "Dec", + year = "2008", + } + + +.. include:: mindspore.dataset.api_list_vision.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.IMDBDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.IMDBDataset.rst index 2ceaaad0b19..c332d45344b 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.IMDBDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.IMDBDataset.rst @@ -1,88 +1,88 @@ -mindspore.dataset.IMDBDataset -============================= - -.. py:class:: mindspore.dataset.IMDBDataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None) - - IMDb(Internet Movie Database)数据集。 - - 生成的数据集有两列 `[text, label]` 。 `text` 列的数据类型是string。 `label` 列的数据类型是uint32。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录路径。 - - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 或 ``'all'`` 。默认值: ``None`` ,读取全部样本。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取所有样本。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 - - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 异常: - - **RuntimeError** - `dataset_dir` 参数所指向的文件目录不存在或缺少数据集文件。 - - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 - - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 - - .. include:: mindspore.dataset.sampler.rst - - **关于IMDB数据集:** - - IMDB数据集包含来自互联网电影数据库(IMDB)的50000条高度两极分化的评论。 - 数据集分为25,000条用于训练的评论和25,000条用于测试的评论,训练集和测试集都包含50%的积极评论和50%的消极评论。 - 训练标签和测试标签分别是0和1,其中0代表负样本,1代表正样本。 - - 可以将数据集文件解压缩到此目录结构中,并通过MindSpore的API读取。 - - .. code-block:: - - . - └── imdb_dataset_directory - ├── train - │ ├── pos - │ │ ├── 0_9.txt - │ │ ├── 1_7.txt - │ │ ├── ... - │ ├── neg - │ │ ├── 0_3.txt - │ │ ├── 1_1.txt - │ │ ├── ... - ├── test - │ ├── pos - │ │ ├── 0_10.txt - │ │ ├── 1_10.txt - │ │ ├── ... - │ ├── neg - │ │ ├── 0_2.txt - │ │ ├── 1_3.txt - │ │ ├── ... - - **引用:** - - .. code-block:: - - @InProceedings{maas-EtAl:2011:ACL-HLT2011, - author = {Maas, Andrew L. and Daly, Raymond E. and Pham, Peter T. and Huang, Dan - and Ng, Andrew Y. and Potts, Christopher}, - title = {Learning Word Vectors for Sentiment Analysis}, - booktitle = {Proceedings of the 49th Annual Meeting of the Association for Computational Linguistics: - Human Language Technologies}, - month = {June}, - year = {2011}, - address = {Portland, Oregon, USA}, - publisher = {Association for Computational Linguistics}, - pages = {142--150}, - url = {http://www.aclweb.org/anthology/P11-1015} - } - - -.. include:: mindspore.dataset.api_list_nlp.rst +mindspore.dataset.IMDBDataset +============================= + +.. py:class:: mindspore.dataset.IMDBDataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None) + + IMDb(Internet Movie Database)数据集。 + + 生成的数据集有两列 `[text, label]` 。 `text` 列的数据类型是string。 `label` 列的数据类型是uint32。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录路径。 + - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 或 ``'all'`` 。默认值: ``None`` ,读取全部样本。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取所有样本。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 + - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 异常: + - **RuntimeError** - `dataset_dir` 参数所指向的文件目录不存在或缺少数据集文件。 + - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 + - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 + + .. include:: mindspore.dataset.sampler.rst + + **关于IMDB数据集:** + + IMDB数据集包含来自互联网电影数据库(IMDB)的50000条高度两极分化的评论。 + 数据集分为25,000条用于训练的评论和25,000条用于测试的评论,训练集和测试集都包含50%的积极评论和50%的消极评论。 + 训练标签和测试标签分别是0和1,其中0代表负样本,1代表正样本。 + + 可以将数据集文件解压缩到此目录结构中,并通过MindSpore的API读取。 + + .. code-block:: + + . + └── imdb_dataset_directory + ├── train + │ ├── pos + │ │ ├── 0_9.txt + │ │ ├── 1_7.txt + │ │ ├── ... + │ ├── neg + │ │ ├── 0_3.txt + │ │ ├── 1_1.txt + │ │ ├── ... + ├── test + │ ├── pos + │ │ ├── 0_10.txt + │ │ ├── 1_10.txt + │ │ ├── ... + │ ├── neg + │ │ ├── 0_2.txt + │ │ ├── 1_3.txt + │ │ ├── ... + + **引用:** + + .. code-block:: + + @InProceedings{maas-EtAl:2011:ACL-HLT2011, + author = {Maas, Andrew L. and Daly, Raymond E. and Pham, Peter T. and Huang, Dan + and Ng, Andrew Y. and Potts, Christopher}, + title = {Learning Word Vectors for Sentiment Analysis}, + booktitle = {Proceedings of the 49th Annual Meeting of the Association for Computational Linguistics: + Human Language Technologies}, + month = {June}, + year = {2011}, + address = {Portland, Oregon, USA}, + publisher = {Association for Computational Linguistics}, + pages = {142--150}, + url = {http://www.aclweb.org/anthology/P11-1015} + } + + +.. include:: mindspore.dataset.api_list_nlp.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.IWSLT2016Dataset.rst b/docs/api/api_python/dataset/mindspore.dataset.IWSLT2016Dataset.rst index 6207084bd5e..45c06cc4439 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.IWSLT2016Dataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.IWSLT2016Dataset.rst @@ -1,93 +1,93 @@ -mindspore.dataset.IWSLT2016Dataset -================================== - -.. py:class:: mindspore.dataset.IWSLT2016Dataset(dataset_dir, usage=None, language_pair=None, valid_set=None, test_set=None, num_samples=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, num_parallel_workers=None, cache=None) - - IWSLT2016(International Workshop on Spoken Language Translation)数据集。 - - 生成的数据集有两列 `[text, translation]` 。 `text` 列的数据类型是string。 `translation` 列的数据类型是string。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录路径。 - - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'valid'`` 、 ``'test'`` 或 ``'all'`` 。默认值: ``None`` ,读取全部样本。 - - **language_pair** (sequence, 可选) - 包含源语言和目标语言的序列,支持的值为 ``('en', 'fr')`` 、 ``('en', 'de')`` 、 ``('en', 'cs')`` 、 ``('en', 'ar')`` 、 ``('de', 'en')`` 、 ``('cs', 'en')`` 、 ``('ar', 'en')`` 。默认值: ``None``,默认为 ``('de', 'en')`` 。 - - **valid_set** (str, 可选) - 标识验证集的字符串,支持的值为 ``'dev2010'`` 、 ``'tst2010'`` 、 ``'tst2011'`` 、 ``'tst2012'`` 、 ``'tst2013'`` 和 ``'tst2014'`` 。默认值: ``None``,默认为 ``'tst2013'`` 。 - - **test_set** (str, 可选) - 识别测试集的字符串,支持的值为 ``'dev2010'`` 、 ``'tst2010'`` 、 ``'tst2011'`` 、 ``'tst2012'`` 、 ``'tst2013'`` 和 ``'tst2014'`` 。默认值: ``None``,默认为 ``'tst2014'`` 。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取所有样本。 - - **shuffle** (Union[bool, :class:`~.dataset.Shuffle`], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定。默认值: ``Shuffle.GLOBAL`` 。 - 如果 `shuffle` 为 ``False`` ,则不混洗,如果 `shuffle` 为 ``True`` ,等同于将 `shuffle` 设置为 ``mindspore.dataset.Shuffle.GLOBAL`` 。 - 通过传入枚举变量设置数据混洗的模式: - - - ``Shuffle.GLOBAL`` :混洗文件和样本。 - - ``Shuffle.FILES`` :仅混洗文件。 - - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 异常: - - **RuntimeError** - `dataset_dir` 参数所指向的文件目录不存在或缺少数据集文件。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - **关于IWSLT2016数据集:** - - IWSLT是一个专门讨论口译各个方面的重要年度科学会议。IWSLT评估活动中的MT任务被构成一个数据集,该数据集可通过 `wit3 `_ 公开获取。 - IWSLT2016数据集包括从英语到阿拉伯语、捷克、法语和德语的翻译,以及从阿拉伯语、捷克、法语和德语到英语的翻译。 - - 可以将原始IWSLT2016数据集文件解压缩到此目录结构中,并由MindSpore的API读取。解压后,还需要将要读取的数据集解压到指定文件夹中。例如,如果要读取de-en的数据集,则需要解压缩de/en目录下的tgz文件,数据集位于解压缩文件夹中。 - - .. code-block:: - - . - └── iwslt2016_dataset_directory - ├── subeval_files - └── texts - ├── ar - │ └── en - │ └── ar-en - ├── cs - │ └── en - │ └── cs-en - ├── de - │ └── en - │ └── de-en - │ ├── IWSLT16.TED.dev2010.de-en.de.xml - │ ├── train.tags.de-en.de - │ ├── ... - ├── en - │ ├── ar - │ │ └── en-ar - │ ├── cs - │ │ └── en-cs - │ ├── de - │ │ └── en-de - │ └── fr - │ └── en-fr - └── fr - └── en - └── fr-en - - **引用:** - - .. code-block:: - - @inproceedings{cettoloEtAl:EAMT2012, - Address = {Trento, Italy}, - Author = {Mauro Cettolo and Christian Girardi and Marcello Federico}, - Booktitle = {Proceedings of the 16$^{th}$ Conference of the European Association for Machine Translation - (EAMT)}, - Date = {28-30}, - Month = {May}, - Pages = {261--268}, - Title = {WIT$^3$: Web Inventory of Transcribed and Translated Talks}, - Year = {2012}} - - -.. include:: mindspore.dataset.api_list_nlp.rst +mindspore.dataset.IWSLT2016Dataset +================================== + +.. py:class:: mindspore.dataset.IWSLT2016Dataset(dataset_dir, usage=None, language_pair=None, valid_set=None, test_set=None, num_samples=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, num_parallel_workers=None, cache=None) + + IWSLT2016(International Workshop on Spoken Language Translation)数据集。 + + 生成的数据集有两列 `[text, translation]` 。 `text` 列的数据类型是string。 `translation` 列的数据类型是string。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录路径。 + - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'valid'`` 、 ``'test'`` 或 ``'all'`` 。默认值: ``None`` ,读取全部样本。 + - **language_pair** (sequence, 可选) - 包含源语言和目标语言的序列,支持的值为 ``('en', 'fr')`` 、 ``('en', 'de')`` 、 ``('en', 'cs')`` 、 ``('en', 'ar')`` 、 ``('de', 'en')`` 、 ``('cs', 'en')`` 、 ``('ar', 'en')`` 。默认值: ``None``,默认为 ``('de', 'en')`` 。 + - **valid_set** (str, 可选) - 标识验证集的字符串,支持的值为 ``'dev2010'`` 、 ``'tst2010'`` 、 ``'tst2011'`` 、 ``'tst2012'`` 、 ``'tst2013'`` 和 ``'tst2014'`` 。默认值: ``None``,默认为 ``'tst2013'`` 。 + - **test_set** (str, 可选) - 识别测试集的字符串,支持的值为 ``'dev2010'`` 、 ``'tst2010'`` 、 ``'tst2011'`` 、 ``'tst2012'`` 、 ``'tst2013'`` 和 ``'tst2014'`` 。默认值: ``None``,默认为 ``'tst2014'`` 。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取所有样本。 + - **shuffle** (Union[bool, :class:`~.dataset.Shuffle`], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定。默认值: ``Shuffle.GLOBAL`` 。 + 如果 `shuffle` 为 ``False`` ,则不混洗,如果 `shuffle` 为 ``True`` ,等同于将 `shuffle` 设置为 ``mindspore.dataset.Shuffle.GLOBAL`` 。 + 通过传入枚举变量设置数据混洗的模式: + + - ``Shuffle.GLOBAL`` :混洗文件和样本。 + - ``Shuffle.FILES`` :仅混洗文件。 + + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 异常: + - **RuntimeError** - `dataset_dir` 参数所指向的文件目录不存在或缺少数据集文件。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + **关于IWSLT2016数据集:** + + IWSLT是一个专门讨论口译各个方面的重要年度科学会议。IWSLT评估活动中的MT任务被构成一个数据集,该数据集可通过 `wit3 `_ 公开获取。 + IWSLT2016数据集包括从英语到阿拉伯语、捷克、法语和德语的翻译,以及从阿拉伯语、捷克、法语和德语到英语的翻译。 + + 可以将原始IWSLT2016数据集文件解压缩到此目录结构中,并由MindSpore的API读取。解压后,还需要将要读取的数据集解压到指定文件夹中。例如,如果要读取de-en的数据集,则需要解压缩de/en目录下的tgz文件,数据集位于解压缩文件夹中。 + + .. code-block:: + + . + └── iwslt2016_dataset_directory + ├── subeval_files + └── texts + ├── ar + │ └── en + │ └── ar-en + ├── cs + │ └── en + │ └── cs-en + ├── de + │ └── en + │ └── de-en + │ ├── IWSLT16.TED.dev2010.de-en.de.xml + │ ├── train.tags.de-en.de + │ ├── ... + ├── en + │ ├── ar + │ │ └── en-ar + │ ├── cs + │ │ └── en-cs + │ ├── de + │ │ └── en-de + │ └── fr + │ └── en-fr + └── fr + └── en + └── fr-en + + **引用:** + + .. code-block:: + + @inproceedings{cettoloEtAl:EAMT2012, + Address = {Trento, Italy}, + Author = {Mauro Cettolo and Christian Girardi and Marcello Federico}, + Booktitle = {Proceedings of the 16$^{th}$ Conference of the European Association for Machine Translation + (EAMT)}, + Date = {28-30}, + Month = {May}, + Pages = {261--268}, + Title = {WIT$^3$: Web Inventory of Transcribed and Translated Talks}, + Year = {2012}} + + +.. include:: mindspore.dataset.api_list_nlp.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.IWSLT2017Dataset.rst b/docs/api/api_python/dataset/mindspore.dataset.IWSLT2017Dataset.rst index b6c8fe02bdc..bb2b73d976c 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.IWSLT2017Dataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.IWSLT2017Dataset.rst @@ -1,94 +1,94 @@ -mindspore.dataset.IWSLT2017Dataset -================================== - -.. py:class:: mindspore.dataset.IWSLT2017Dataset(dataset_dir, usage=None, language_pair=None, num_samples=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, num_parallel_workers=None, cache=None) - - IWSLT2017(International Workshop on Spoken Language Translation)数据集。 - - 生成的数据集有两列 `[text, translation]` 。 `text` 列和 `translation` 列的数据类型均为string。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录路径。 - - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'valid'`` 、 ``'test'`` 或 ``'all'`` 。默认值: ``None`` ,读取全部样本。 - - **language_pair** (sequence, 可选) - 包含源语和目标语的语言列表,支持的语言对有 ``('en', 'nl')`` 、 - ``('en', 'de')`` 、 ``('en', 'it')`` 、 ``('en', 'ro')`` 、 ``('nl', 'en')`` 、 ``('nl', 'de')`` 、 ``('nl', 'it')`` 、 ``('nl', 'ro')`` 、 - ``('de', 'en')`` 、 ``('de', 'nl')`` 、 ``('de', 'it')`` 、 ``('de', 'ro')`` 、 ``('it', 'en')`` 、 ``('it', 'nl')`` 、 ``('it', 'de')`` 、 - ``('it', 'ro')`` 、 ``('ro', 'en')`` 、 ``('ro', 'nl')`` 、 ``('ro', 'de')`` 、 ``('ro', 'it')`` 。默认值: ``None`` ,默认为 ``('de', 'en')`` 。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取所有样本。 - - **shuffle** (Union[bool, :class:`~.dataset.Shuffle`], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定。默认值: ``Shuffle.GLOBAL`` 。 - 如果 `shuffle` 为 ``False`` ,则不混洗,如果 `shuffle` 为 ``True`` ,等同于将 `shuffle` 设置为 ``mindspore.dataset.Shuffle.GLOBAL`` 。 - 通过传入枚举变量设置数据混洗的模式: - - - ``Shuffle.GLOBAL`` :混洗文件和样本。 - - ``Shuffle.FILES`` :仅混洗文件。 - - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 异常: - - **RuntimeError** - `dataset_dir` 参数所指向的文件目录不存在或缺少数据集文件。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - **关于IWSLT2017数据集:** - - IWSLT是一个专门讨论口译各个方面的重要年度科学会议。IWSLT评估活动中的MT任务被构成一个数据集,该数据集可通过 `wit3 `_ 公开获取。 - IWSLT2017数据集中有德语、英语、意大利语、荷兰语和罗马尼亚语,数据集包括其中任何两种语言的翻译。 - - 可以将原始IWSLT2017数据集文件解压缩到此目录结构中,并由MindSpore的API读取。解压后,还需要将要读取的数据集解压到指定文件夹中。例如,如果要读取de-en的数据集,则需要解压缩de/en目录下的tgz文件,数据集位于解压缩文件夹中。 - - .. code-block:: - - . - └── iwslt2017_dataset_directory - ├── subeval_files - └── texts - ├── ar - │ └── en - │ └── ar-en - ├── cs - │ └── en - │ └── cs-en - ├── de - │ └── en - │ └── de-en - │ ├── IWSLT16.TED.dev2010.de-en.de.xml - │ ├── train.tags.de-en.de - │ ├── ... - ├── en - │ ├── ar - │ │ └── en-ar - │ ├── cs - │ │ └── en-cs - │ ├── de - │ │ └── en-de - │ └── fr - │ └── en-fr - └── fr - └── en - └── fr-en - - **引用:** - - .. code-block:: - - @inproceedings{cettoloEtAl:EAMT2012, - Address = {Trento, Italy}, - Author = {Mauro Cettolo and Christian Girardi and Marcello Federico}, - Booktitle = {Proceedings of the 16$^{th}$ Conference of the European Association for Machine Translation - (EAMT)}, - Date = {28-30}, - Month = {May}, - Pages = {261--268}, - Title = {WIT$^3$: Web Inventory of Transcribed and Translated Talks}, - Year = {2012}} - - -.. include:: mindspore.dataset.api_list_nlp.rst +mindspore.dataset.IWSLT2017Dataset +================================== + +.. py:class:: mindspore.dataset.IWSLT2017Dataset(dataset_dir, usage=None, language_pair=None, num_samples=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, num_parallel_workers=None, cache=None) + + IWSLT2017(International Workshop on Spoken Language Translation)数据集。 + + 生成的数据集有两列 `[text, translation]` 。 `text` 列和 `translation` 列的数据类型均为string。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录路径。 + - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'valid'`` 、 ``'test'`` 或 ``'all'`` 。默认值: ``None`` ,读取全部样本。 + - **language_pair** (sequence, 可选) - 包含源语和目标语的语言列表,支持的语言对有 ``('en', 'nl')`` 、 + ``('en', 'de')`` 、 ``('en', 'it')`` 、 ``('en', 'ro')`` 、 ``('nl', 'en')`` 、 ``('nl', 'de')`` 、 ``('nl', 'it')`` 、 ``('nl', 'ro')`` 、 + ``('de', 'en')`` 、 ``('de', 'nl')`` 、 ``('de', 'it')`` 、 ``('de', 'ro')`` 、 ``('it', 'en')`` 、 ``('it', 'nl')`` 、 ``('it', 'de')`` 、 + ``('it', 'ro')`` 、 ``('ro', 'en')`` 、 ``('ro', 'nl')`` 、 ``('ro', 'de')`` 、 ``('ro', 'it')`` 。默认值: ``None`` ,默认为 ``('de', 'en')`` 。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取所有样本。 + - **shuffle** (Union[bool, :class:`~.dataset.Shuffle`], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定。默认值: ``Shuffle.GLOBAL`` 。 + 如果 `shuffle` 为 ``False`` ,则不混洗,如果 `shuffle` 为 ``True`` ,等同于将 `shuffle` 设置为 ``mindspore.dataset.Shuffle.GLOBAL`` 。 + 通过传入枚举变量设置数据混洗的模式: + + - ``Shuffle.GLOBAL`` :混洗文件和样本。 + - ``Shuffle.FILES`` :仅混洗文件。 + + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 异常: + - **RuntimeError** - `dataset_dir` 参数所指向的文件目录不存在或缺少数据集文件。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + **关于IWSLT2017数据集:** + + IWSLT是一个专门讨论口译各个方面的重要年度科学会议。IWSLT评估活动中的MT任务被构成一个数据集,该数据集可通过 `wit3 `_ 公开获取。 + IWSLT2017数据集中有德语、英语、意大利语、荷兰语和罗马尼亚语,数据集包括其中任何两种语言的翻译。 + + 可以将原始IWSLT2017数据集文件解压缩到此目录结构中,并由MindSpore的API读取。解压后,还需要将要读取的数据集解压到指定文件夹中。例如,如果要读取de-en的数据集,则需要解压缩de/en目录下的tgz文件,数据集位于解压缩文件夹中。 + + .. code-block:: + + . + └── iwslt2017_dataset_directory + ├── subeval_files + └── texts + ├── ar + │ └── en + │ └── ar-en + ├── cs + │ └── en + │ └── cs-en + ├── de + │ └── en + │ └── de-en + │ ├── IWSLT16.TED.dev2010.de-en.de.xml + │ ├── train.tags.de-en.de + │ ├── ... + ├── en + │ ├── ar + │ │ └── en-ar + │ ├── cs + │ │ └── en-cs + │ ├── de + │ │ └── en-de + │ └── fr + │ └── en-fr + └── fr + └── en + └── fr-en + + **引用:** + + .. code-block:: + + @inproceedings{cettoloEtAl:EAMT2012, + Address = {Trento, Italy}, + Author = {Mauro Cettolo and Christian Girardi and Marcello Federico}, + Booktitle = {Proceedings of the 16$^{th}$ Conference of the European Association for Machine Translation + (EAMT)}, + Date = {28-30}, + Month = {May}, + Pages = {261--268}, + Title = {WIT$^3$: Web Inventory of Transcribed and Translated Talks}, + Year = {2012}} + + +.. include:: mindspore.dataset.api_list_nlp.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.ImageFolderDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.ImageFolderDataset.rst index f02f977977b..d8293ac88db 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.ImageFolderDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.ImageFolderDataset.rst @@ -1,68 +1,68 @@ -mindspore.dataset.ImageFolderDataset -===================================== - -.. py:class:: mindspore.dataset.ImageFolderDataset(dataset_dir, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, extensions=None, class_indexing=None, decode=False, num_shards=None, shard_id=None, cache=None, decrypt=None) - - 从树状结构的文件目录中读取图片构建源数据集。同一个文件夹中的所有图片将被分配相同的label。 - - 生成的数据集有两列:`[image, label]`。`image` 列的数据类型为uint8。 `label` 列的数据类型为uint32。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录的路径。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数,可以小于数据集总数。默认值: ``None`` ,读取全部样本图片。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 - - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 - - **extensions** (list[str], 可选) - 指定文件的扩展名,仅读取与指定扩展名匹配的文件到数据集中。默认值: ``None`` 。 - - **class_indexing** (dict, 可选) - 指定文件夹名称到label索引的映射,要求映射规则为string到int。文件夹名称将按字母顺序排列,索引值从0开始,并且要求每个文件夹名称对应的索引值唯一。默认值: ``None`` ,不指定。 - - **decode** (bool, 可选) - 是否对读取的图片进行解码操作。默认值: ``False`` ,不解码。 - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - **decrypt** (callable, 可选) - 图像解密函数,接受加密的图片路径并返回bytes类型的解密数据。默认值: ``None`` ,不进行解密。 - - 异常: - - **RuntimeError** - `dataset_dir` 不包含任何数据文件。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 - - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **RuntimeError** - `class_indexing` 参数的类型不是dict。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - .. note:: - - 如果 `decode` 参数的值为 ``False`` ,则得到的 `image` 列的shape为[undecoded_image_size],如果为True则 `image` 列的shape为[H,W,C]。 - - 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 - - .. include:: mindspore.dataset.sampler.rst - - **关于ImageFolderDataset:** - - 您可以将图片数据文件构建成如下目录结构,并通过MindSpore的API进行读取。 - - .. code-block:: - - . - └── image_folder_dataset_directory - ├── class1 - │ ├── 000000000001.jpg - │ ├── 000000000002.jpg - │ ├── ... - ├── class2 - │ ├── 000000000001.jpg - │ ├── 000000000002.jpg - │ ├── ... - ├── class3 - │ ├── 000000000001.jpg - │ ├── 000000000002.jpg - │ ├── ... - ├── classN - ├── ... - - -.. include:: mindspore.dataset.api_list_vision.rst +mindspore.dataset.ImageFolderDataset +===================================== + +.. py:class:: mindspore.dataset.ImageFolderDataset(dataset_dir, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, extensions=None, class_indexing=None, decode=False, num_shards=None, shard_id=None, cache=None, decrypt=None) + + 从树状结构的文件目录中读取图片构建源数据集。同一个文件夹中的所有图片将被分配相同的label。 + + 生成的数据集有两列:`[image, label]`。`image` 列的数据类型为uint8。 `label` 列的数据类型为uint32。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录的路径。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数,可以小于数据集总数。默认值: ``None`` ,读取全部样本图片。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 + - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 + - **extensions** (list[str], 可选) - 指定文件的扩展名,仅读取与指定扩展名匹配的文件到数据集中。默认值: ``None`` 。 + - **class_indexing** (dict, 可选) - 指定文件夹名称到label索引的映射,要求映射规则为string到int。文件夹名称将按字母顺序排列,索引值从0开始,并且要求每个文件夹名称对应的索引值唯一。默认值: ``None`` ,不指定。 + - **decode** (bool, 可选) - 是否对读取的图片进行解码操作。默认值: ``False`` ,不解码。 + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + - **decrypt** (callable, 可选) - 图像解密函数,接受加密的图片路径并返回bytes类型的解密数据。默认值: ``None`` ,不进行解密。 + + 异常: + - **RuntimeError** - `dataset_dir` 不包含任何数据文件。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 + - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **RuntimeError** - `class_indexing` 参数的类型不是dict。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + .. note:: + - 如果 `decode` 参数的值为 ``False`` ,则得到的 `image` 列的shape为[undecoded_image_size],如果为True则 `image` 列的shape为[H,W,C]。 + - 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 + + .. include:: mindspore.dataset.sampler.rst + + **关于ImageFolderDataset:** + + 您可以将图片数据文件构建成如下目录结构,并通过MindSpore的API进行读取。 + + .. code-block:: + + . + └── image_folder_dataset_directory + ├── class1 + │ ├── 000000000001.jpg + │ ├── 000000000002.jpg + │ ├── ... + ├── class2 + │ ├── 000000000001.jpg + │ ├── 000000000002.jpg + │ ├── ... + ├── class3 + │ ├── 000000000001.jpg + │ ├── 000000000002.jpg + │ ├── ... + ├── classN + ├── ... + + +.. include:: mindspore.dataset.api_list_vision.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.KMnistDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.KMnistDataset.rst index 7374d11524f..f2981154397 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.KMnistDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.KMnistDataset.rst @@ -1,69 +1,69 @@ -mindspore.dataset.KMnistDataset -=============================== - -.. py:class:: mindspore.dataset.KMnistDataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None) - - KMNIST(Kuzushiji-MNIST)数据集。 - - 生成的数据集有两列: `[image, label]` 。 `image` 列的数据类型为uint8。 `label` 列的数据类型为uint32。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录路径。 - - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 或 ``'all'`` 。 - 取值为 ``'train'`` 时将会读取60,000个训练样本,取值为 ``'test'`` 时将会读取10,000个测试样本,取值为 ``'all'`` 时将会读取全部70,000个样本。默认值: ``None`` ,读取全部样本图片。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数,可以小于数据集总数。默认值: ``None`` ,读取全部样本图片。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 - - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 异常: - - **RuntimeError** - `dataset_dir` 路径下不包含数据文件。 - - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 - - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 - - .. include:: mindspore.dataset.sampler.rst - - **关于KMNIST数据集:** - - KMNIST是一个数据集,改编自Kuzushiji数据集,作为MNIST数据集的替代数据集(MNIST数据集是机器学习社区中著名的数据集)。 - 以下是原始KMNIST数据集结构。可以将数据集文件解压缩到此目录结构中,并由MindSpore的API读取。 - - .. code-block:: - - . - └── kmnist_dataset_dir - ├── t10k-images-idx3-ubyte - ├── t10k-labels-idx1-ubyte - ├── train-images-idx3-ubyte - └── train-labels-idx1-ubyte - - **引用:** - - .. code-block:: - - @online{clanuwat2018deep, - author = {Tarin Clanuwat and Mikel Bober-Irizar and Asanobu Kitamoto and - Alex Lamb and Kazuaki Yamamoto and David Ha}, - title = {Deep Learning for Classical Japanese Literature}, - date = {2018-12-03}, - year = {2018}, - eprintclass = {cs.CV}, - eprinttype = {arXiv}, - eprint = {cs.CV/1812.01718}, - } - - -.. include:: mindspore.dataset.api_list_vision.rst +mindspore.dataset.KMnistDataset +=============================== + +.. py:class:: mindspore.dataset.KMnistDataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None) + + KMNIST(Kuzushiji-MNIST)数据集。 + + 生成的数据集有两列: `[image, label]` 。 `image` 列的数据类型为uint8。 `label` 列的数据类型为uint32。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录路径。 + - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 或 ``'all'`` 。 + 取值为 ``'train'`` 时将会读取60,000个训练样本,取值为 ``'test'`` 时将会读取10,000个测试样本,取值为 ``'all'`` 时将会读取全部70,000个样本。默认值: ``None`` ,读取全部样本图片。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数,可以小于数据集总数。默认值: ``None`` ,读取全部样本图片。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 + - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 异常: + - **RuntimeError** - `dataset_dir` 路径下不包含数据文件。 + - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 + - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 + + .. include:: mindspore.dataset.sampler.rst + + **关于KMNIST数据集:** + + KMNIST是一个数据集,改编自Kuzushiji数据集,作为MNIST数据集的替代数据集(MNIST数据集是机器学习社区中著名的数据集)。 + 以下是原始KMNIST数据集结构。可以将数据集文件解压缩到此目录结构中,并由MindSpore的API读取。 + + .. code-block:: + + . + └── kmnist_dataset_dir + ├── t10k-images-idx3-ubyte + ├── t10k-labels-idx1-ubyte + ├── train-images-idx3-ubyte + └── train-labels-idx1-ubyte + + **引用:** + + .. code-block:: + + @online{clanuwat2018deep, + author = {Tarin Clanuwat and Mikel Bober-Irizar and Asanobu Kitamoto and + Alex Lamb and Kazuaki Yamamoto and David Ha}, + title = {Deep Learning for Classical Japanese Literature}, + date = {2018-12-03}, + year = {2018}, + eprintclass = {cs.CV}, + eprinttype = {arXiv}, + eprint = {cs.CV/1812.01718}, + } + + +.. include:: mindspore.dataset.api_list_vision.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.LJSpeechDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.LJSpeechDataset.rst index 92176a5a5ea..a271c55a22c 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.LJSpeechDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.LJSpeechDataset.rst @@ -1,79 +1,79 @@ -mindspore.dataset.LJSpeechDataset -================================= - -.. py:class:: mindspore.dataset.LJSpeechDataset(dataset_dir, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None) - - LJSpeech数据集。 - - 生成的数据集有四列: `[waveform, sample_rate, transcription, normalized_transcript]` 。 - `waveform` 列的数据类型为float32。 `sample_rate` 列的数据类型为int32。 `transcription` 列的数据类型为string。 `normalized_transcript` 列的数据类型为string。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录路径。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数,可以小于数据集总数。默认值: ``None`` ,读取全部样本音频。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 - - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 异常: - - **RuntimeError** - `dataset_dir` 路径下不包含数据文件。 - - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 - - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 - - .. include:: mindspore.dataset.sampler.rst - - **关于LJSPEECH数据集:** - - LJSPEECH是一个公共领域的语音数据集,由13,100个来自7部非小说类书籍的段落短音频片段组成。 - 为每个剪辑片段都进行转录。剪辑的长度从1秒到10秒不等,总长度约为24小时。 - - 这些被阅读的文本于1884年至1964年间出版,属于公共领域。这些音频由LibriVox项目于2016-17年录制。 - - 以下是原始的LJSPEECH数据集结构。 - 可以将数据集文件解压缩到以下目录结构中,并由MindSpore的API读取。 - - .. code-block:: - - . - └── LJSpeech-1.1 - ├── README - ├── metadata.csv - └── wavs - ├── LJ001-0001.wav - ├── LJ001-0002.wav - ├── LJ001-0003.wav - ├── LJ001-0004.wav - ├── LJ001-0005.wav - ├── LJ001-0006.wav - ├── LJ001-0007.wav - ├── LJ001-0008.wav - ... - ├── LJ050-0277.wav - └── LJ050-0278.wav - - **引用:** - - .. code-block:: - - @misc{lj_speech17, - author = {Keith Ito and Linda Johnson}, - title = {The LJ Speech Dataset}, - howpublished = {url{https://keithito.com/LJ-Speech-Dataset}}, - year = 2017 - } - - -.. include:: mindspore.dataset.api_list_audio.rst +mindspore.dataset.LJSpeechDataset +================================= + +.. py:class:: mindspore.dataset.LJSpeechDataset(dataset_dir, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None) + + LJSpeech数据集。 + + 生成的数据集有四列: `[waveform, sample_rate, transcription, normalized_transcript]` 。 + `waveform` 列的数据类型为float32。 `sample_rate` 列的数据类型为int32。 `transcription` 列的数据类型为string。 `normalized_transcript` 列的数据类型为string。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录路径。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数,可以小于数据集总数。默认值: ``None`` ,读取全部样本音频。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 + - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 异常: + - **RuntimeError** - `dataset_dir` 路径下不包含数据文件。 + - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 + - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 + + .. include:: mindspore.dataset.sampler.rst + + **关于LJSPEECH数据集:** + + LJSPEECH是一个公共领域的语音数据集,由13,100个来自7部非小说类书籍的段落短音频片段组成。 + 为每个剪辑片段都进行转录。剪辑的长度从1秒到10秒不等,总长度约为24小时。 + + 这些被阅读的文本于1884年至1964年间出版,属于公共领域。这些音频由LibriVox项目于2016-17年录制。 + + 以下是原始的LJSPEECH数据集结构。 + 可以将数据集文件解压缩到以下目录结构中,并由MindSpore的API读取。 + + .. code-block:: + + . + └── LJSpeech-1.1 + ├── README + ├── metadata.csv + └── wavs + ├── LJ001-0001.wav + ├── LJ001-0002.wav + ├── LJ001-0003.wav + ├── LJ001-0004.wav + ├── LJ001-0005.wav + ├── LJ001-0006.wav + ├── LJ001-0007.wav + ├── LJ001-0008.wav + ... + ├── LJ050-0277.wav + └── LJ050-0278.wav + + **引用:** + + .. code-block:: + + @misc{lj_speech17, + author = {Keith Ito and Linda Johnson}, + title = {The LJ Speech Dataset}, + howpublished = {url{https://keithito.com/LJ-Speech-Dataset}}, + year = 2017 + } + + +.. include:: mindspore.dataset.api_list_audio.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.ManifestDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.ManifestDataset.rst index 15be2b57c69..eadbf57486d 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.ManifestDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.ManifestDataset.rst @@ -1,64 +1,64 @@ -mindspore.dataset.ManifestDataset -================================== - -.. py:class:: mindspore.dataset.ManifestDataset(dataset_file, usage='train', num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, class_indexing=None, decode=False, num_shards=None, shard_id=None, cache=None) - - 读取和解析Manifest数据文件构建数据集。 - - 生成的数据集有两列: `[image, label]` 。 `image` 列的数据类型为uint8类型。 `label` 列的数据类型为uint64类型。 - - 参数: - - **dataset_file** (str) - 数据集文件的目录路径。 - - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'eval'`` 或 ``'inference'`` 。默认值: ``'train'`` 。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数,可以小于数据集总数。默认值: ``None`` ,读取全部样本图片。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 - - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 - - **class_indexing** (dict, 可选) - 指定一个从label名称到label索引的映射,要求映射规则为string到int。索引值从0开始,并且要求每个label名称对应的索引值唯一。默认值: ``None`` ,不指定。 - - **decode** (bool, 可选) - 是否对读取的图片进行解码操作。默认值: ``False`` ,不解码。 - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 异常: - - **RuntimeError** - `dataset_files` 路径下不包含任何数据文件。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 - - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **RuntimeError** - `class_indexing` 参数的类型不是dict。 - - **ValueError** - `shard_id` 参数值错误(小于0或者大于等于 `num_shards`)。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - .. note:: - - 如果 `decode` 为 ``False`` ,`image` 列返回图像的一维原始字节。否则,将返回 shape 为 :math:`[H,W,C]` 的解码图像。 - - 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 - - .. include:: mindspore.dataset.sampler.rst - - **关于Manifest数据集:** - - Manifest文件包含数据集中包含的文件列表,包括文件名和文件ID等基本文件信息,以及扩展文件元数据。 - Manifest是华为ModelArts支持的数据格式文件,详细说明请参见 `Manifest文档 `_ 。 - - 以下是原始Manifest数据集结构。可以将数据集文件解压缩到此目录结构中,并由MindSpore的API读取。 - - .. code-block:: - - . - └── manifest_dataset_directory - ├── train - │ ├── 1.JPEG - │ ├── 2.JPEG - │ ├── ... - ├── eval - │ ├── 1.JPEG - │ ├── 2.JPEG - │ ├── ... - - -.. include:: mindspore.dataset.api_list_vision.rst +mindspore.dataset.ManifestDataset +================================== + +.. py:class:: mindspore.dataset.ManifestDataset(dataset_file, usage='train', num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, class_indexing=None, decode=False, num_shards=None, shard_id=None, cache=None) + + 读取和解析Manifest数据文件构建数据集。 + + 生成的数据集有两列: `[image, label]` 。 `image` 列的数据类型为uint8类型。 `label` 列的数据类型为uint64类型。 + + 参数: + - **dataset_file** (str) - 数据集文件的目录路径。 + - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'eval'`` 或 ``'inference'`` 。默认值: ``'train'`` 。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数,可以小于数据集总数。默认值: ``None`` ,读取全部样本图片。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 + - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 + - **class_indexing** (dict, 可选) - 指定一个从label名称到label索引的映射,要求映射规则为string到int。索引值从0开始,并且要求每个label名称对应的索引值唯一。默认值: ``None`` ,不指定。 + - **decode** (bool, 可选) - 是否对读取的图片进行解码操作。默认值: ``False`` ,不解码。 + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 异常: + - **RuntimeError** - `dataset_files` 路径下不包含任何数据文件。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 + - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **RuntimeError** - `class_indexing` 参数的类型不是dict。 + - **ValueError** - `shard_id` 参数值错误(小于0或者大于等于 `num_shards`)。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + .. note:: + - 如果 `decode` 为 ``False`` ,`image` 列返回图像的一维原始字节。否则,将返回 shape 为 :math:`[H,W,C]` 的解码图像。 + - 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 + + .. include:: mindspore.dataset.sampler.rst + + **关于Manifest数据集:** + + Manifest文件包含数据集中包含的文件列表,包括文件名和文件ID等基本文件信息,以及扩展文件元数据。 + Manifest是华为ModelArts支持的数据格式文件,详细说明请参见 `Manifest文档 `_ 。 + + 以下是原始Manifest数据集结构。可以将数据集文件解压缩到此目录结构中,并由MindSpore的API读取。 + + .. code-block:: + + . + └── manifest_dataset_directory + ├── train + │ ├── 1.JPEG + │ ├── 2.JPEG + │ ├── ... + ├── eval + │ ├── 1.JPEG + │ ├── 2.JPEG + │ ├── ... + + +.. include:: mindspore.dataset.api_list_vision.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.MindDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.MindDataset.rst index 0bec30c7de9..3315627758d 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.MindDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.MindDataset.rst @@ -1,44 +1,44 @@ -mindspore.dataset.MindDataset -============================== - -.. py:class:: mindspore.dataset.MindDataset(dataset_files, columns_list=None, num_parallel_workers=None, shuffle=None, num_shards=None, shard_id=None, sampler=None, padded_sample=None, num_padded=None, num_samples=None, cache=None) - - 读取和解析MindRecord数据文件构建数据集。生成的数据集的列名和列类型取决于MindRecord文件中的保存的列名与类型。 - - 参数: - - **dataset_files** (Union[str, list[str]]) - MindRecord文件路径,支持单文件路径字符串、多文件路径字符串列表。如果 `dataset_files` 的类型是字符串,则它代表一组具有相同前缀名的MindRecord文件,同一路径下具有相同前缀名的其他MindRecord文件将会被自动寻找并加载。如果 `dataset_files` 的类型是列表,则它表示所需读取的MindRecord数据文件。 - - **columns_list** (list[str],可选) - 指定从MindRecord文件中读取的数据列。默认值: ``None`` ,读取所有列。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (Union[bool, :class:`~.dataset.Shuffle`], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定。默认值: ``None`` ,采用 ``mindspore.dataset.Shuffle.GLOBAL`` 。 - 如果 `shuffle` 为 ``False`` ,则不混洗,如果 `shuffle` 为 ``True`` ,等同于将 `shuffle` 设置为 ``mindspore.dataset.Shuffle.GLOBAL`` 。 - 通过传入枚举变量设置数据混洗的模式: - - - ``Shuffle.GLOBAL`` :混洗文件和文件中的数据。 - - ``Shuffle.FILES`` :仅混洗文件,当数据集样本量大于1亿条时不支持。 - - ``Shuffle.INFILE`` :保持读入文件的序列,仅混洗每个文件中的数据,当数据集样本量大于1亿条时不支持。 - - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。当前此数据集仅支持以下采样器: :class:`mindspore.dataset.SubsetRandomSampler` 、 :class:`mindspore.dataset.PKSampler` 、 :class:`mindspore.dataset.RandomSampler` 、 :class:`mindspore.dataset.SequentialSampler` 和 :class:`mindspore.dataset.DistributedSampler` 。 - - **padded_sample** (dict, 可选) - 指定额外添加到数据集的样本,可用于在分布式训练时补齐分片数据,注意字典的键名需要与 `columns_list` 指定的列名相同。默认值: ``None`` ,不添加样本。需要与 `num_padded` 参数同时使用。 - - **num_padded** (int, 可选) - 指定额外添加的数据集样本的数量。在分布式训练时可用于为数据集补齐样本,使得总样本数量可被 `num_shards` 整除。默认值: ``None`` ,不添加样本。需要与 `padded_sample` 参数同时使用。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取所有样本。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 异常: - - **ValueError** - `dataset_files` 参数所指向的文件无效或不存在。 - - **ValueError** - `num_parallel_workers` 参数超过最大线程数。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 - - .. include:: mindspore.dataset.sampler.rst - :parser: reStructuredText - -.. include:: mindspore.dataset.api_list_nlp.rst +mindspore.dataset.MindDataset +============================== + +.. py:class:: mindspore.dataset.MindDataset(dataset_files, columns_list=None, num_parallel_workers=None, shuffle=None, num_shards=None, shard_id=None, sampler=None, padded_sample=None, num_padded=None, num_samples=None, cache=None) + + 读取和解析MindRecord数据文件构建数据集。生成的数据集的列名和列类型取决于MindRecord文件中的保存的列名与类型。 + + 参数: + - **dataset_files** (Union[str, list[str]]) - MindRecord文件路径,支持单文件路径字符串、多文件路径字符串列表。如果 `dataset_files` 的类型是字符串,则它代表一组具有相同前缀名的MindRecord文件,同一路径下具有相同前缀名的其他MindRecord文件将会被自动寻找并加载。如果 `dataset_files` 的类型是列表,则它表示所需读取的MindRecord数据文件。 + - **columns_list** (list[str],可选) - 指定从MindRecord文件中读取的数据列。默认值: ``None`` ,读取所有列。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (Union[bool, :class:`~.dataset.Shuffle`], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定。默认值: ``None`` ,采用 ``mindspore.dataset.Shuffle.GLOBAL`` 。 + 如果 `shuffle` 为 ``False`` ,则不混洗,如果 `shuffle` 为 ``True`` ,等同于将 `shuffle` 设置为 ``mindspore.dataset.Shuffle.GLOBAL`` 。 + 通过传入枚举变量设置数据混洗的模式: + + - ``Shuffle.GLOBAL`` :混洗文件和文件中的数据。 + - ``Shuffle.FILES`` :仅混洗文件,当数据集样本量大于1亿条时不支持。 + - ``Shuffle.INFILE`` :保持读入文件的序列,仅混洗每个文件中的数据,当数据集样本量大于1亿条时不支持。 + + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。当前此数据集仅支持以下采样器: :class:`mindspore.dataset.SubsetRandomSampler` 、 :class:`mindspore.dataset.PKSampler` 、 :class:`mindspore.dataset.RandomSampler` 、 :class:`mindspore.dataset.SequentialSampler` 和 :class:`mindspore.dataset.DistributedSampler` 。 + - **padded_sample** (dict, 可选) - 指定额外添加到数据集的样本,可用于在分布式训练时补齐分片数据,注意字典的键名需要与 `columns_list` 指定的列名相同。默认值: ``None`` ,不添加样本。需要与 `num_padded` 参数同时使用。 + - **num_padded** (int, 可选) - 指定额外添加的数据集样本的数量。在分布式训练时可用于为数据集补齐样本,使得总样本数量可被 `num_shards` 整除。默认值: ``None`` ,不添加样本。需要与 `padded_sample` 参数同时使用。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取所有样本。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 异常: + - **ValueError** - `dataset_files` 参数所指向的文件无效或不存在。 + - **ValueError** - `num_parallel_workers` 参数超过最大线程数。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 + + .. include:: mindspore.dataset.sampler.rst + :parser: reStructuredText + +.. include:: mindspore.dataset.api_list_nlp.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.MnistDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.MnistDataset.rst index 6fd4cde961e..866de79aa39 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.MnistDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.MnistDataset.rst @@ -1,69 +1,69 @@ -mindspore.dataset.MnistDataset -=============================== - -.. py:class:: mindspore.dataset.MnistDataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None) - - MNIST数据集。 - - 生成的数据集有两列: `[image, label]`。 `image` 列的数据类型为uint8。 `label` 列的数据类型为uint32。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录路径。 - - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 或 ``'all'`` 。 - 取值为 ``'train'`` 时将会读取60,000个训练样本,取值为 ``'test'`` 时将会读取10,000个测试样本,取值为 ``'all'`` 时将会读取全部70,000个样本。默认值: ``None`` ,读取全部样本图片。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数,可以小于数据集总数。默认值: ``None`` ,读取全部样本图片。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 - - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 异常: - - **RuntimeError** - `dataset_dir` 路径下不包含数据文件。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - **ValueError** - `usage` 参数取值不为 ``'train'`` 、 ``'test'`` 或 ``'all'`` 。 - - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 - - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 - - .. include:: mindspore.dataset.sampler.rst - - **关于MNIST数据集:** - - MNIST手写数字数据集是NIST数据集的子集,共有60,000个训练样本和10,000个测试样本。此数据集是NIST数据集的子集。数字已经预先进行了尺寸归一化和中心化处理。 - - 以下为原始MNIST数据集的结构。您可以将数据集文件解压得到如下的文件结构,并通过MindSpore的API进行读取。 - - .. code-block:: - - . - └── mnist_dataset_dir - ├── t10k-images-idx3-ubyte - ├── t10k-labels-idx1-ubyte - ├── train-images-idx3-ubyte - └── train-labels-idx1-ubyte - - **引用:** - - .. code-block:: - - @article{lecun2010mnist, - title = {MNIST handwritten digit database}, - author = {LeCun, Yann and Cortes, Corinna and Burges, CJ}, - journal = {ATT Labs [Online]}, - volume = {2}, - year = {2010}, - howpublished = {http://yann.lecun.com/exdb/mnist} - } - - -.. include:: mindspore.dataset.api_list_vision.rst +mindspore.dataset.MnistDataset +=============================== + +.. py:class:: mindspore.dataset.MnistDataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None) + + MNIST数据集。 + + 生成的数据集有两列: `[image, label]`。 `image` 列的数据类型为uint8。 `label` 列的数据类型为uint32。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录路径。 + - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 或 ``'all'`` 。 + 取值为 ``'train'`` 时将会读取60,000个训练样本,取值为 ``'test'`` 时将会读取10,000个测试样本,取值为 ``'all'`` 时将会读取全部70,000个样本。默认值: ``None`` ,读取全部样本图片。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数,可以小于数据集总数。默认值: ``None`` ,读取全部样本图片。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 + - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 异常: + - **RuntimeError** - `dataset_dir` 路径下不包含数据文件。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + - **ValueError** - `usage` 参数取值不为 ``'train'`` 、 ``'test'`` 或 ``'all'`` 。 + - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 + - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 + + .. include:: mindspore.dataset.sampler.rst + + **关于MNIST数据集:** + + MNIST手写数字数据集是NIST数据集的子集,共有60,000个训练样本和10,000个测试样本。此数据集是NIST数据集的子集。数字已经预先进行了尺寸归一化和中心化处理。 + + 以下为原始MNIST数据集的结构。您可以将数据集文件解压得到如下的文件结构,并通过MindSpore的API进行读取。 + + .. code-block:: + + . + └── mnist_dataset_dir + ├── t10k-images-idx3-ubyte + ├── t10k-labels-idx1-ubyte + ├── train-images-idx3-ubyte + └── train-labels-idx1-ubyte + + **引用:** + + .. code-block:: + + @article{lecun2010mnist, + title = {MNIST handwritten digit database}, + author = {LeCun, Yann and Cortes, Corinna and Burges, CJ}, + journal = {ATT Labs [Online]}, + volume = {2}, + year = {2010}, + howpublished = {http://yann.lecun.com/exdb/mnist} + } + + +.. include:: mindspore.dataset.api_list_vision.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.NumpySlicesDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.NumpySlicesDataset.rst index a6748d95d3c..a21111b5156 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.NumpySlicesDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.NumpySlicesDataset.rst @@ -1,39 +1,39 @@ -mindspore.dataset.NumpySlicesDataset -===================================== - -.. py:class:: mindspore.dataset.NumpySlicesDataset(data, column_names=None, num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None) - - 由Python数据构建数据集。生成的数据集的列名和列类型取决于用户传入的Python数据。 - - 参数: - - **data** (Union[list, tuple, dict]) - 输入的Python数据。支持的数据类型包括:list、tuple、dict和其他NumPy格式。 - 输入数据将沿着第一个维度切片,并生成额外的行。如果输入是单个list,则将生成一个数据列,若是嵌套多个list,则生成多个数据列。不建议通过这种方式加载大量的数据,因为可能会在数据加载到内存时等待较长时间。 - - **column_names** (list[str], 可选) - 指定数据集生成的列名。默认值: ``None`` ,不指定。 - 如果未指定该参数,且当输入数据的类型是dict时,输出列名称将被命名为dict的键名,否则它们将被统一命名为column_0,column_1...。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,所有样本。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作进程数。默认值: ``1`` 。 - - **shuffle** (bool, 可选) - 是否混洗数据集。 - 只有输入的 `data` 参数带有可随机访问属性(`__getitem__`)时,才可以指定该参数。默认值: ``None`` 。下表中会展示不同配置的预期行为。 - - **sampler** (Union[Sampler, Iterable], 可选) - 指定从数据集中选取样本的采样器。 - 只有输入的 `data` 参数带有可随机访问属性(`__getitem__`)时,才可以指定该参数。默认值: ``None`` 。下表中会展示不同配置的预期行为。 - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 - - .. include:: mindspore.dataset.sampler.rst - - 异常: - - **RuntimeError** - `column_names` 列表的长度与数据的输出列表长度不匹配。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - **ValueError** - 同时指定了 `sampler` 和 `shuffle` 参数。 - - **ValueError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 - - **ValueError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **ValueError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - -.. include:: mindspore.dataset.api_list_nlp.rst +mindspore.dataset.NumpySlicesDataset +===================================== + +.. py:class:: mindspore.dataset.NumpySlicesDataset(data, column_names=None, num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None) + + 由Python数据构建数据集。生成的数据集的列名和列类型取决于用户传入的Python数据。 + + 参数: + - **data** (Union[list, tuple, dict]) - 输入的Python数据。支持的数据类型包括:list、tuple、dict和其他NumPy格式。 + 输入数据将沿着第一个维度切片,并生成额外的行。如果输入是单个list,则将生成一个数据列,若是嵌套多个list,则生成多个数据列。不建议通过这种方式加载大量的数据,因为可能会在数据加载到内存时等待较长时间。 + - **column_names** (list[str], 可选) - 指定数据集生成的列名。默认值: ``None`` ,不指定。 + 如果未指定该参数,且当输入数据的类型是dict时,输出列名称将被命名为dict的键名,否则它们将被统一命名为column_0,column_1...。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,所有样本。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作进程数。默认值: ``1`` 。 + - **shuffle** (bool, 可选) - 是否混洗数据集。 + 只有输入的 `data` 参数带有可随机访问属性(`__getitem__`)时,才可以指定该参数。默认值: ``None`` 。下表中会展示不同配置的预期行为。 + - **sampler** (Union[Sampler, Iterable], 可选) - 指定从数据集中选取样本的采样器。 + 只有输入的 `data` 参数带有可随机访问属性(`__getitem__`)时,才可以指定该参数。默认值: ``None`` 。下表中会展示不同配置的预期行为。 + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + + .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 + + .. include:: mindspore.dataset.sampler.rst + + 异常: + - **RuntimeError** - `column_names` 列表的长度与数据的输出列表长度不匹配。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + - **ValueError** - 同时指定了 `sampler` 和 `shuffle` 参数。 + - **ValueError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 + - **ValueError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **ValueError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + +.. include:: mindspore.dataset.api_list_nlp.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.OBSMindDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.OBSMindDataset.rst index 081484aa836..eb1a6b9973e 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.OBSMindDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.OBSMindDataset.rst @@ -1,48 +1,48 @@ -mindspore.dataset.OBSMindDataset -================================== - -.. py:class:: mindspore.dataset.OBSMindDataset(dataset_files, server, ak, sk, sync_obs_path, columns_list=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, shard_equal_rows=True) - - 读取和解析存放在华为云OBS、Minio以及AWS S3等云存储上的MindRecord格式数据集。生成的数据集的列名和列类型取决于MindRecord文件中的保存的列名与类型。 - - 参数: - - **dataset_files** (list[str]) - 云存储上MindRecord格式数据集文件的路径列表,每个文件的路径格式为s3://bucketName/objectKey。 - - **server** (str) - 连接云存储的服务地址。可包含协议类型、域名、端口号。 - 假如为华为云OBS,服务地址为: ```` 。 - 假如为Minio,服务地址为: ```` 。 - - **ak** (str) - 用于访问OBS数据的访问密钥ID。 - - **sk** (str) - 用于访问OBS数据的私有访问密钥。 - - **sync_obs_path** (str) - 用于同步操作云存储上的路径,用户需要提前创建,目录路径的格式为s3://bucketName/objectKey。 - - **columns_list** (list[str],可选) - 指定从MindRecord文件中读取的数据列。默认值: ``None`` ,读取所有列。 - - **shuffle** (Union[bool, :class:`~.dataset.Shuffle`], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定。默认值: ``Shuffle.GLOBAL`` 。 - 如果 `shuffle` 为 ``False`` ,则不混洗,如果 `shuffle` 为 ``True`` ,等同于将 `shuffle` 设置为 ``mindspore.dataset.Shuffle.GLOBAL`` 。 - 通过传入枚举变量设置数据混洗的模式: - - - ``Shuffle.GLOBAL`` :混洗文件和文件中的数据。 - - ``Shuffle.FILES`` :仅混洗文件。 - - ``Shuffle.INFILE`` :保持读入文件的序列,仅混洗每个文件中的数据。 - - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **shard_equal_rows** (bool, 可选) - 分布式训练时,为所有分片获取等量的数据行数。默认值: ``True`` 。 - 如果 `shard_equal_rows` 为False,则可能会使得每个分片的数据条目不相等,从而导致分布式训练失败。 - 因此当每个MindRecord文件的数据数量不相等时,建议将此参数设置为 ``True`` 。注意,只有当指定了 `num_shards` 时才能指定此参数。 - - 异常: - - **RuntimeError** - `sync_obs_path` 参数指定的目录不存在。 - - **ValueError** - `columns_list` 参数无效。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - .. note:: - - 需要用户提前在云存储上创建同步用的目录,然后通过 `sync_obs_path` 指定。 - - 如果线下训练,建议为每次训练设置 `BATCH_JOB_ID` 环境变量。 - - 分布式训练中,假如使用多个节点(服务器),则必须使用每个节点全部的8张卡。如果只有一个节点(服务器),则没有这样的限制。 - - -.. include:: mindspore.dataset.api_list_nlp.rst +mindspore.dataset.OBSMindDataset +================================== + +.. py:class:: mindspore.dataset.OBSMindDataset(dataset_files, server, ak, sk, sync_obs_path, columns_list=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, shard_equal_rows=True) + + 读取和解析存放在华为云OBS、Minio以及AWS S3等云存储上的MindRecord格式数据集。生成的数据集的列名和列类型取决于MindRecord文件中的保存的列名与类型。 + + 参数: + - **dataset_files** (list[str]) - 云存储上MindRecord格式数据集文件的路径列表,每个文件的路径格式为s3://bucketName/objectKey。 + - **server** (str) - 连接云存储的服务地址。可包含协议类型、域名、端口号。 + 假如为华为云OBS,服务地址为: ```` 。 + 假如为Minio,服务地址为: ```` 。 + - **ak** (str) - 用于访问OBS数据的访问密钥ID。 + - **sk** (str) - 用于访问OBS数据的私有访问密钥。 + - **sync_obs_path** (str) - 用于同步操作云存储上的路径,用户需要提前创建,目录路径的格式为s3://bucketName/objectKey。 + - **columns_list** (list[str],可选) - 指定从MindRecord文件中读取的数据列。默认值: ``None`` ,读取所有列。 + - **shuffle** (Union[bool, :class:`~.dataset.Shuffle`], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定。默认值: ``Shuffle.GLOBAL`` 。 + 如果 `shuffle` 为 ``False`` ,则不混洗,如果 `shuffle` 为 ``True`` ,等同于将 `shuffle` 设置为 ``mindspore.dataset.Shuffle.GLOBAL`` 。 + 通过传入枚举变量设置数据混洗的模式: + + - ``Shuffle.GLOBAL`` :混洗文件和文件中的数据。 + - ``Shuffle.FILES`` :仅混洗文件。 + - ``Shuffle.INFILE`` :保持读入文件的序列,仅混洗每个文件中的数据。 + + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **shard_equal_rows** (bool, 可选) - 分布式训练时,为所有分片获取等量的数据行数。默认值: ``True`` 。 + 如果 `shard_equal_rows` 为False,则可能会使得每个分片的数据条目不相等,从而导致分布式训练失败。 + 因此当每个MindRecord文件的数据数量不相等时,建议将此参数设置为 ``True`` 。注意,只有当指定了 `num_shards` 时才能指定此参数。 + + 异常: + - **RuntimeError** - `sync_obs_path` 参数指定的目录不存在。 + - **ValueError** - `columns_list` 参数无效。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + .. note:: + - 需要用户提前在云存储上创建同步用的目录,然后通过 `sync_obs_path` 指定。 + - 如果线下训练,建议为每次训练设置 `BATCH_JOB_ID` 环境变量。 + - 分布式训练中,假如使用多个节点(服务器),则必须使用每个节点全部的8张卡。如果只有一个节点(服务器),则没有这样的限制。 + + +.. include:: mindspore.dataset.api_list_nlp.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.PKSampler.rst b/docs/api/api_python/dataset/mindspore.dataset.PKSampler.rst index aa67ba94ef0..4dc771fdfc4 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.PKSampler.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.PKSampler.rst @@ -1,25 +1,25 @@ -mindspore.dataset.PKSampler -============================== - -.. py:class:: mindspore.dataset.PKSampler(num_val, num_class=None, shuffle=False, class_column='label', num_samples=None) - - 为数据集中每P个类别各采样K个样本。 - - 参数: - - **num_val** (int) - 每个类要采样的元素数量。 - - **num_class** (int, 可选) - 要采样的类数量。默认值为 ``None`` ,采样所有类。当前不支持指定该参数。 - - **shuffle** (bool, 可选) - 是否混洗采样得到的样本。默认值: ``False`` ,不混洗样本。 - - **class_column** (str, 可选) - 指定label所属数据列的名称,将基于此列作为数据标签进行采样。默认值: ``'label'`` 。 - - **num_samples** (int, 可选) - 获取的样本数,可用于部分获取采样得到的样本。默认值: ``None`` ,获取采样到的所有样本。 - - 异常: - - **TypeError** - `shuffle` 的类型不是bool。 - - **TypeError** - `class_column` 的类型不是str。 - - **TypeError** - `num_samples` 的类型不是int。 - - **NotImplementedError** - `num_class` 不为 ``None`` 。 - - **RuntimeError** - `num_val` 不是正值。 - - **ValueError** - `num_samples` 为负值。 - - .. include:: mindspore.dataset.BuiltinSampler.rst - +mindspore.dataset.PKSampler +============================== + +.. py:class:: mindspore.dataset.PKSampler(num_val, num_class=None, shuffle=False, class_column='label', num_samples=None) + + 为数据集中每P个类别各采样K个样本。 + + 参数: + - **num_val** (int) - 每个类要采样的元素数量。 + - **num_class** (int, 可选) - 要采样的类数量。默认值为 ``None`` ,采样所有类。当前不支持指定该参数。 + - **shuffle** (bool, 可选) - 是否混洗采样得到的样本。默认值: ``False`` ,不混洗样本。 + - **class_column** (str, 可选) - 指定label所属数据列的名称,将基于此列作为数据标签进行采样。默认值: ``'label'`` 。 + - **num_samples** (int, 可选) - 获取的样本数,可用于部分获取采样得到的样本。默认值: ``None`` ,获取采样到的所有样本。 + + 异常: + - **TypeError** - `shuffle` 的类型不是bool。 + - **TypeError** - `class_column` 的类型不是str。 + - **TypeError** - `num_samples` 的类型不是int。 + - **NotImplementedError** - `num_class` 不为 ``None`` 。 + - **RuntimeError** - `num_val` 不是正值。 + - **ValueError** - `num_samples` 为负值。 + + .. include:: mindspore.dataset.BuiltinSampler.rst + .. include:: mindspore.dataset.BuiltinSampler.b.rst \ No newline at end of file diff --git a/docs/api/api_python/dataset/mindspore.dataset.PaddedDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.PaddedDataset.rst index db6a8565cbf..33304871279 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.PaddedDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.PaddedDataset.rst @@ -1,20 +1,20 @@ -mindspore.dataset.PaddedDataset -================================ - -.. py:class:: mindspore.dataset.PaddedDataset(padded_samples) - - 由用户提供的填充数据构建数据集。可用于在分布式训练时给原始数据集添加样本,使数据集样本能平均分配给不同的分片。 - - 参数: - - **padded_samples** (list(dict)) - 用户提供的样本数据。 - - 异常: - - **TypeError** - `padded_samples` 的类型不为list。 - - **TypeError** - `padded_samples` 的元素类型不为dict。 - - **ValueError** - `padded_samples` 为空的list。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - -.. include:: mindspore.dataset.api_list_nlp.rst +mindspore.dataset.PaddedDataset +================================ + +.. py:class:: mindspore.dataset.PaddedDataset(padded_samples) + + 由用户提供的填充数据构建数据集。可用于在分布式训练时给原始数据集添加样本,使数据集样本能平均分配给不同的分片。 + + 参数: + - **padded_samples** (list(dict)) - 用户提供的样本数据。 + + 异常: + - **TypeError** - `padded_samples` 的类型不为list。 + - **TypeError** - `padded_samples` 的元素类型不为dict。 + - **ValueError** - `padded_samples` 为空的list。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + +.. include:: mindspore.dataset.api_list_nlp.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.PennTreebankDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.PennTreebankDataset.rst index 84253845aa0..95198fed402 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.PennTreebankDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.PennTreebankDataset.rst @@ -1,74 +1,74 @@ -mindspore.dataset.PennTreebankDataset -===================================== - -.. py:class:: mindspore.dataset.PennTreebankDataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None) - - PennTreebank数据集。 - - 生成的数据集有一列 `[text]`。 `text` 列的数据类型为string。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录路径。 - - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 、 ``'valid'`` 或 ``'all'`` 。 - 取值为 ``'train'`` 将读取42,068个样本, ``'test'`` 将读取3,370个样本, ``'valid'`` 将读取3,761个样本, ``'all'`` 将读取所有49,199个样本。默认值: ``None`` ,读取全部样本。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取所有样本。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (Union[bool, :class:`~.dataset.Shuffle`], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定。默认值: ``Shuffle.GLOBAL`` 。 - 如果 `shuffle` 为 ``False`` ,则不混洗,如果 `shuffle` 为 ``True`` ,等同于将 `shuffle` 设置为 ``mindspore.dataset.Shuffle.GLOBAL`` 。 - 通过传入枚举变量设置数据混洗的模式: - - - ``Shuffle.GLOBAL`` :混洗文件和样本。 - - ``Shuffle.FILES`` :仅混洗文件。 - - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 异常: - - **RuntimeError** - `dataset_dir` 参数所指向的文件目录不存在或缺少数据集文件。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - **关于PennTreebank数据集:** - - Penn Treebank (PTB) 数据集,广泛用于 NLP(自然语言处理)的机器学习研究。 - PTB 不包含大写字母、数字和标点符号,其词汇表上限为10k个不重复词,与大多数现代数据集相比相对较小,可能会导致出现大量超出词汇表外的token。 - - 以下是原始的PennTreebank数据集结构。 - 可以将数据集文件解压缩到此目录结构中,并通过MindSpore的API读取。 - - .. code-block:: - - . - └── PennTreebank_dataset_dir - ├── ptb.test.txt - ├── ptb.train.txt - └── ptb.valid.txt - - **引用:** - - .. code-block:: - - @techreport{Santorini1990, - added-at = {2014-03-26T23:25:56.000+0100}, - author = {Santorini, Beatrice}, - biburl = {https://www.bibsonomy.org/bibtex/234cdf6ddadd89376090e7dada2fc18ec/butonic}, - file = {:Santorini - Penn Treebank tag definitions.pdf:PDF}, - institution = {Department of Computer and Information Science, University of Pennsylvania}, - interhash = {818e72efd9e4b5fae3e51e88848100a0}, - intrahash = {34cdf6ddadd89376090e7dada2fc18ec}, - keywords = {dis pos tagging treebank}, - number = {MS-CIS-90-47}, - timestamp = {2014-03-26T23:25:56.000+0100}, - title = {Part-of-speech tagging guidelines for the {P}enn {T}reebank {P}roject}, - url = {ftp://ftp.cis.upenn.edu/pub/treebank/doc/tagguide.ps.gz}, - year = 1990 - } - - -.. include:: mindspore.dataset.api_list_nlp.rst +mindspore.dataset.PennTreebankDataset +===================================== + +.. py:class:: mindspore.dataset.PennTreebankDataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None) + + PennTreebank数据集。 + + 生成的数据集有一列 `[text]`。 `text` 列的数据类型为string。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录路径。 + - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 、 ``'valid'`` 或 ``'all'`` 。 + 取值为 ``'train'`` 将读取42,068个样本, ``'test'`` 将读取3,370个样本, ``'valid'`` 将读取3,761个样本, ``'all'`` 将读取所有49,199个样本。默认值: ``None`` ,读取全部样本。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取所有样本。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (Union[bool, :class:`~.dataset.Shuffle`], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定。默认值: ``Shuffle.GLOBAL`` 。 + 如果 `shuffle` 为 ``False`` ,则不混洗,如果 `shuffle` 为 ``True`` ,等同于将 `shuffle` 设置为 ``mindspore.dataset.Shuffle.GLOBAL`` 。 + 通过传入枚举变量设置数据混洗的模式: + + - ``Shuffle.GLOBAL`` :混洗文件和样本。 + - ``Shuffle.FILES`` :仅混洗文件。 + + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 异常: + - **RuntimeError** - `dataset_dir` 参数所指向的文件目录不存在或缺少数据集文件。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + **关于PennTreebank数据集:** + + Penn Treebank (PTB) 数据集,广泛用于 NLP(自然语言处理)的机器学习研究。 + PTB 不包含大写字母、数字和标点符号,其词汇表上限为10k个不重复词,与大多数现代数据集相比相对较小,可能会导致出现大量超出词汇表外的token。 + + 以下是原始的PennTreebank数据集结构。 + 可以将数据集文件解压缩到此目录结构中,并通过MindSpore的API读取。 + + .. code-block:: + + . + └── PennTreebank_dataset_dir + ├── ptb.test.txt + ├── ptb.train.txt + └── ptb.valid.txt + + **引用:** + + .. code-block:: + + @techreport{Santorini1990, + added-at = {2014-03-26T23:25:56.000+0100}, + author = {Santorini, Beatrice}, + biburl = {https://www.bibsonomy.org/bibtex/234cdf6ddadd89376090e7dada2fc18ec/butonic}, + file = {:Santorini - Penn Treebank tag definitions.pdf:PDF}, + institution = {Department of Computer and Information Science, University of Pennsylvania}, + interhash = {818e72efd9e4b5fae3e51e88848100a0}, + intrahash = {34cdf6ddadd89376090e7dada2fc18ec}, + keywords = {dis pos tagging treebank}, + number = {MS-CIS-90-47}, + timestamp = {2014-03-26T23:25:56.000+0100}, + title = {Part-of-speech tagging guidelines for the {P}enn {T}reebank {P}roject}, + url = {ftp://ftp.cis.upenn.edu/pub/treebank/doc/tagguide.ps.gz}, + year = 1990 + } + + +.. include:: mindspore.dataset.api_list_nlp.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.PhotoTourDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.PhotoTourDataset.rst index c7e371479e3..52128b70d7f 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.PhotoTourDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.PhotoTourDataset.rst @@ -1,98 +1,98 @@ -mindspore.dataset.PhotoTourDataset -================================== - -.. py:class:: mindspore.dataset.PhotoTourDataset(dataset_dir, name, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None) - - PhotoTour数据集。 - - 根据给定的 `usage` 配置,生成数据集具有不同的输出列: - - - `usage` = 'train',输出列: `[image, dtype=uint8]` 。 - - `usage` ≠ 'train',输出列: `[image1, dtype=uint8]` 、 `[image2, dtype=uint8]` 、 `[matches, dtype=uint32]` 。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录路径。 - - **name** (str) - 要加载的数据集内容名称,可以取值为 ``'notredame'`` 、 ``'yosemite'`` 、 ``'liberty'`` 、 ``'notredame_harris'`` 、 ``'yosemite_harris'`` 或 ``'liberty_harris'`` 。 - - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 或 ``'test'``。默认值: ``None`` ,将被设置为 ``'train'`` 。 - 取值为 ``'train'`` 时,每个 `name` 的数据集样本数分别为{'notredame': 468159, 'yosemite': 633587, 'liberty': 450092, 'liberty_harris': 379587, 'yosemite_harris': 450912, 'notredame_harris': 325295}。 - 取值为 ``'test'`` 时,将读取100,000个测试样本。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取所有样本。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 - - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 异常: - - **RuntimeError** - `dataset_dir` 路径下不包含数据文件。 - - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 - - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - `dataset_dir` 不存在。 - - **ValueError** - `usage` 不是 ``'train'`` 或 ``'test'`` 。 - - **ValueError** - `name` 不是 ``''notredame'`` 、 ``'yosemite'`` 、 ``'liberty'`` 、 ``'notredame_harris'`` 、 ``'yosemite_harris'`` 或 ``'liberty_harris'`` 。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 - - .. include:: mindspore.dataset.sampler.rst - - **关于PhotoTour数据集:** - - 数据取自许愿池(罗马)、巴黎圣母院(巴黎)和半圆顶(美国约塞米蒂国家公园)的旅游圣地照片。 - 每个数据集包括一系列相应的图像块,是通过将旅游圣地的照片中的3D点投影回到原始图像而获得的。 - - 数据集由1024 x 1024位图(.bmp)图像组成,每个图像都包含16 x 16的图像修补数组。 - 每个图像块都以64 x 64灰度采样,具有规范的比例和方向。有关如何确定比例和方向的详细信息,请参见论文。 - 关联的元数据文件info.txt包含匹配信息。info.txt的每一行对应一个单独的图像块,图像块在每个位图图像中从左到右、从上到下顺序排列。 - info.txt每行上的第一个数字是采样该图像块的3D点ID——具有相同3D点ID的图像块从同一3D点投影(到不同的图像中)。 - info.txt中的第二个数字代表图像块是从哪个原始图像采样得到,目前未使用。 - - 可以将原始PhotoTour数据集文件解压缩到此目录结构中,并通过MindSpore的API读取。 - - .. code-block:: - - . - └── photo_tour_dataset_directory - ├── liberty/ - │ ├── info.txt // two columns: 3D_point_ID, unused - │ ├── m50_100000_100000_0.txt // seven columns: patch_ID1, 3D_point_ID1, unused1, - │ │ // patch_ID2, 3D_point_ID2, unused2, unused3 - │ ├── patches0000.bmp // 1024*1024 pixels, with 16 * 16 patches. - │ ├── patches0001.bmp - │ ├── ... - ├── yosemite/ - │ ├── ... - ├── notredame/ - │ ├── ... - ├── liberty_harris/ - │ ├── ... - ├── yosemite_harris/ - │ ├── ... - ├── notredame_harris/ - │ ├── ... - - **引用:** - - .. code-block:: - - @INPROCEEDINGS{4269996, - author={Winder, Simon A. J. and Brown, Matthew}, - booktitle={2007 IEEE Conference on Computer Vision and Pattern Recognition}, - title={Learning Local Image Descriptors}, - year={2007}, - volume={}, - number={}, - pages={1-8}, - doi={10.1109/CVPR.2007.382971} - } - - -.. include:: mindspore.dataset.api_list_vision.rst +mindspore.dataset.PhotoTourDataset +================================== + +.. py:class:: mindspore.dataset.PhotoTourDataset(dataset_dir, name, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None) + + PhotoTour数据集。 + + 根据给定的 `usage` 配置,生成数据集具有不同的输出列: + + - `usage` = 'train',输出列: `[image, dtype=uint8]` 。 + - `usage` ≠ 'train',输出列: `[image1, dtype=uint8]` 、 `[image2, dtype=uint8]` 、 `[matches, dtype=uint32]` 。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录路径。 + - **name** (str) - 要加载的数据集内容名称,可以取值为 ``'notredame'`` 、 ``'yosemite'`` 、 ``'liberty'`` 、 ``'notredame_harris'`` 、 ``'yosemite_harris'`` 或 ``'liberty_harris'`` 。 + - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 或 ``'test'``。默认值: ``None`` ,将被设置为 ``'train'`` 。 + 取值为 ``'train'`` 时,每个 `name` 的数据集样本数分别为{'notredame': 468159, 'yosemite': 633587, 'liberty': 450092, 'liberty_harris': 379587, 'yosemite_harris': 450912, 'notredame_harris': 325295}。 + 取值为 ``'test'`` 时,将读取100,000个测试样本。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取所有样本。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 + - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 异常: + - **RuntimeError** - `dataset_dir` 路径下不包含数据文件。 + - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 + - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - `dataset_dir` 不存在。 + - **ValueError** - `usage` 不是 ``'train'`` 或 ``'test'`` 。 + - **ValueError** - `name` 不是 ``''notredame'`` 、 ``'yosemite'`` 、 ``'liberty'`` 、 ``'notredame_harris'`` 、 ``'yosemite_harris'`` 或 ``'liberty_harris'`` 。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 + + .. include:: mindspore.dataset.sampler.rst + + **关于PhotoTour数据集:** + + 数据取自许愿池(罗马)、巴黎圣母院(巴黎)和半圆顶(美国约塞米蒂国家公园)的旅游圣地照片。 + 每个数据集包括一系列相应的图像块,是通过将旅游圣地的照片中的3D点投影回到原始图像而获得的。 + + 数据集由1024 x 1024位图(.bmp)图像组成,每个图像都包含16 x 16的图像修补数组。 + 每个图像块都以64 x 64灰度采样,具有规范的比例和方向。有关如何确定比例和方向的详细信息,请参见论文。 + 关联的元数据文件info.txt包含匹配信息。info.txt的每一行对应一个单独的图像块,图像块在每个位图图像中从左到右、从上到下顺序排列。 + info.txt每行上的第一个数字是采样该图像块的3D点ID——具有相同3D点ID的图像块从同一3D点投影(到不同的图像中)。 + info.txt中的第二个数字代表图像块是从哪个原始图像采样得到,目前未使用。 + + 可以将原始PhotoTour数据集文件解压缩到此目录结构中,并通过MindSpore的API读取。 + + .. code-block:: + + . + └── photo_tour_dataset_directory + ├── liberty/ + │ ├── info.txt // two columns: 3D_point_ID, unused + │ ├── m50_100000_100000_0.txt // seven columns: patch_ID1, 3D_point_ID1, unused1, + │ │ // patch_ID2, 3D_point_ID2, unused2, unused3 + │ ├── patches0000.bmp // 1024*1024 pixels, with 16 * 16 patches. + │ ├── patches0001.bmp + │ ├── ... + ├── yosemite/ + │ ├── ... + ├── notredame/ + │ ├── ... + ├── liberty_harris/ + │ ├── ... + ├── yosemite_harris/ + │ ├── ... + ├── notredame_harris/ + │ ├── ... + + **引用:** + + .. code-block:: + + @INPROCEEDINGS{4269996, + author={Winder, Simon A. J. and Brown, Matthew}, + booktitle={2007 IEEE Conference on Computer Vision and Pattern Recognition}, + title={Learning Local Image Descriptors}, + year={2007}, + volume={}, + number={}, + pages={1-8}, + doi={10.1109/CVPR.2007.382971} + } + + +.. include:: mindspore.dataset.api_list_vision.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.Places365Dataset.rst b/docs/api/api_python/dataset/mindspore.dataset.Places365Dataset.rst index e3633856e61..9c0c21ffea0 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.Places365Dataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.Places365Dataset.rst @@ -1,88 +1,88 @@ -mindspore.dataset.Places365Dataset -================================== - -.. py:class:: mindspore.dataset.Places365Dataset(dataset_dir, usage=None, small=True, decode=False, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None) - - Places365数据集。 - - 生成的数据集有两列: `[image, label]`。 - `image` 列的数据类型为uint8。 `label` 列的数据类型为uint32。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录路径。 - - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train-standard'`` 、 ``'train-challenge'`` 或 ``'val'`` 。默认值: ``None``,将使用 ``'train-standard'`` 。 - - **small** (bool, 可选) - 是否使用256*256的低分辨率图像(True)或高分辨率图像(False)。默认值: ``True`` ,使用低分辨率图像。 - - **decode** (bool, 可选) - 是否对读取的图片进行解码操作。默认值: ``False`` ,不解码。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取所有样本。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 - - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 异常: - - **RuntimeError** - `dataset_dir` 路径下不包含数据文件。 - - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 - - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - **ValueError** - `shard_id` 参数错误,参数小于0或者大于等于 `num_shards` 。 - - **ValueError** - `usage` 不是 ``'train-standard'`` 、 ``'train-challenge'`` 或 ``'val'`` 。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 - - .. include:: mindspore.dataset.sampler.rst - - **关于Places365数据集:** - - 在Places2数据库上训练的卷积神经网络(CNN)可用于场景识别,也可用于视觉识别的通用深度场景特征。 - - Places作者向公众发布了Places365-Standard数据集和Places365-Challenge数据集。 - Places365-Standard数据集是Places2数据库的核心集,该数据库已用于训练Places365-CNN。 - Places作者将在未来的Places365-Standard数据集上添加其他类型的标注。 - Places365-Challenge数据集是Places2数据库的竞赛数据集,与Places365-Standard数据集相比,该数据库有620万张额外的图像。此数据集用于2016年的Places挑战赛。 - - 可以将原始的Places365数据集文件解压缩到此目录结构中,并由MindSpore的API读取。 - - .. code-block:: - - . - └── categories_places365 - ├── places365_train-standard.txt - ├── places365_train-challenge.txt - ├── val_large/ - │ ├── Places365_val_00000001.jpg - │ ├── Places365_val_00000002.jpg - │ ├── Places365_val_00000003.jpg - │ ├── ... - ├── val_256/ - │ ├── ... - ├── data_large_standard/ - │ ├── ... - ├── data_256_standard/ - │ ├── ... - ├── data_large_challenge/ - │ ├── ... - ├── data_256_challenge / - │ ├── ... - - **引用:** - - .. code-block:: - - article{zhou2017places, - title={Places: A 10 million Image Database for Scene Recognition}, - author={Zhou, Bolei and Lapedriza, Agata and Khosla, Aditya and Oliva, Aude and Torralba, Antonio}, - journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, - year={2017}, - publisher={IEEE} - } - - -.. include:: mindspore.dataset.api_list_vision.rst +mindspore.dataset.Places365Dataset +================================== + +.. py:class:: mindspore.dataset.Places365Dataset(dataset_dir, usage=None, small=True, decode=False, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None) + + Places365数据集。 + + 生成的数据集有两列: `[image, label]`。 + `image` 列的数据类型为uint8。 `label` 列的数据类型为uint32。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录路径。 + - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train-standard'`` 、 ``'train-challenge'`` 或 ``'val'`` 。默认值: ``None``,将使用 ``'train-standard'`` 。 + - **small** (bool, 可选) - 是否使用256*256的低分辨率图像(True)或高分辨率图像(False)。默认值: ``True`` ,使用低分辨率图像。 + - **decode** (bool, 可选) - 是否对读取的图片进行解码操作。默认值: ``False`` ,不解码。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取所有样本。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 + - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 异常: + - **RuntimeError** - `dataset_dir` 路径下不包含数据文件。 + - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 + - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + - **ValueError** - `shard_id` 参数错误,参数小于0或者大于等于 `num_shards` 。 + - **ValueError** - `usage` 不是 ``'train-standard'`` 、 ``'train-challenge'`` 或 ``'val'`` 。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 + + .. include:: mindspore.dataset.sampler.rst + + **关于Places365数据集:** + + 在Places2数据库上训练的卷积神经网络(CNN)可用于场景识别,也可用于视觉识别的通用深度场景特征。 + + Places作者向公众发布了Places365-Standard数据集和Places365-Challenge数据集。 + Places365-Standard数据集是Places2数据库的核心集,该数据库已用于训练Places365-CNN。 + Places作者将在未来的Places365-Standard数据集上添加其他类型的标注。 + Places365-Challenge数据集是Places2数据库的竞赛数据集,与Places365-Standard数据集相比,该数据库有620万张额外的图像。此数据集用于2016年的Places挑战赛。 + + 可以将原始的Places365数据集文件解压缩到此目录结构中,并由MindSpore的API读取。 + + .. code-block:: + + . + └── categories_places365 + ├── places365_train-standard.txt + ├── places365_train-challenge.txt + ├── val_large/ + │ ├── Places365_val_00000001.jpg + │ ├── Places365_val_00000002.jpg + │ ├── Places365_val_00000003.jpg + │ ├── ... + ├── val_256/ + │ ├── ... + ├── data_large_standard/ + │ ├── ... + ├── data_256_standard/ + │ ├── ... + ├── data_large_challenge/ + │ ├── ... + ├── data_256_challenge / + │ ├── ... + + **引用:** + + .. code-block:: + + article{zhou2017places, + title={Places: A 10 million Image Database for Scene Recognition}, + author={Zhou, Bolei and Lapedriza, Agata and Khosla, Aditya and Oliva, Aude and Torralba, Antonio}, + journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, + year={2017}, + publisher={IEEE} + } + + +.. include:: mindspore.dataset.api_list_vision.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.QMnistDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.QMnistDataset.rst index 5d818c2da0d..89e8d10b536 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.QMnistDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.QMnistDataset.rst @@ -1,71 +1,71 @@ -mindspore.dataset.QMnistDataset -=============================== - -.. py:class:: mindspore.dataset.QMnistDataset(dataset_dir, usage=None, compat=True, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None) - - QMNIST数据集。 - - 生成的数据集有两列: `[image, label]`。 `image` 列的数据类型为uint8。 `label` 列的数据类型为uint32。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录路径。 - - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 、 ``'test10k'`` 、 ``'test50k'`` 、 ``'nist'`` 或 ``'all'`` 。默认值: ``None`` ,读取所有子集。 - - **compat** (bool, 可选) - 若为 ``True`` ,指定每个样本的标签是类别号,否则指定标签是完整的QMNIST信息。默认值: ``True`` ,标签为类别号。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取所有样本。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 - - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 异常: - - **RuntimeError** - `dataset_dir` 路径下不包含数据文件。 - - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 - - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 - - .. include:: mindspore.dataset.sampler.rst - - **关于QMNIST数据集:** - - QMNIST 数据集是从 NIST Special Database 19 中的原始数据生成的,目的是尽可能地匹配 MNIST 预处理。 - 研究人员试图生成额外的 50k 类似 MNIST 数据的图像。在QMNIST论文中,作者给出了重建过程,并使用匈牙利算法来找到原始 MNIST 样本与其重建样本之间的最佳匹配。 - - 以下是原始的QMNIST数据集结构。 - 可以将数据集文件解压缩到此目录结构中,并通过MindSpore的API读取。 - - .. code-block:: - - . - └── qmnist_dataset_dir - ├── qmnist-train-images-idx3-ubyte - ├── qmnist-train-labels-idx2-int - ├── qmnist-test-images-idx3-ubyte - ├── qmnist-test-labels-idx2-int - ├── xnist-images-idx3-ubyte - └── xnist-labels-idx2-int - - **引用:** - - .. code-block:: - - @incollection{qmnist-2019, - title = "Cold Case: The Lost MNIST Digits", - author = "Chhavi Yadav and L\'{e}on Bottou",\ - booktitle = {Advances in Neural Information Processing Systems 32}, - year = {2019}, - publisher = {Curran Associates, Inc.}, - } - - -.. include:: mindspore.dataset.api_list_vision.rst +mindspore.dataset.QMnistDataset +=============================== + +.. py:class:: mindspore.dataset.QMnistDataset(dataset_dir, usage=None, compat=True, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None) + + QMNIST数据集。 + + 生成的数据集有两列: `[image, label]`。 `image` 列的数据类型为uint8。 `label` 列的数据类型为uint32。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录路径。 + - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 、 ``'test10k'`` 、 ``'test50k'`` 、 ``'nist'`` 或 ``'all'`` 。默认值: ``None`` ,读取所有子集。 + - **compat** (bool, 可选) - 若为 ``True`` ,指定每个样本的标签是类别号,否则指定标签是完整的QMNIST信息。默认值: ``True`` ,标签为类别号。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取所有样本。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 + - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 异常: + - **RuntimeError** - `dataset_dir` 路径下不包含数据文件。 + - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 + - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 + + .. include:: mindspore.dataset.sampler.rst + + **关于QMNIST数据集:** + + QMNIST 数据集是从 NIST Special Database 19 中的原始数据生成的,目的是尽可能地匹配 MNIST 预处理。 + 研究人员试图生成额外的 50k 类似 MNIST 数据的图像。在QMNIST论文中,作者给出了重建过程,并使用匈牙利算法来找到原始 MNIST 样本与其重建样本之间的最佳匹配。 + + 以下是原始的QMNIST数据集结构。 + 可以将数据集文件解压缩到此目录结构中,并通过MindSpore的API读取。 + + .. code-block:: + + . + └── qmnist_dataset_dir + ├── qmnist-train-images-idx3-ubyte + ├── qmnist-train-labels-idx2-int + ├── qmnist-test-images-idx3-ubyte + ├── qmnist-test-labels-idx2-int + ├── xnist-images-idx3-ubyte + └── xnist-labels-idx2-int + + **引用:** + + .. code-block:: + + @incollection{qmnist-2019, + title = "Cold Case: The Lost MNIST Digits", + author = "Chhavi Yadav and L\'{e}on Bottou",\ + booktitle = {Advances in Neural Information Processing Systems 32}, + year = {2019}, + publisher = {Curran Associates, Inc.}, + } + + +.. include:: mindspore.dataset.api_list_vision.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.RandomDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.RandomDataset.rst index 232c571ab88..40e80d2d295 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.RandomDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.RandomDataset.rst @@ -1,35 +1,35 @@ -mindspore.dataset.RandomDataset -=============================== - -.. py:class:: mindspore.dataset.RandomDataset(total_rows=None, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None, cache=None, shuffle=None, num_shards=None, shard_id=None) - - 生成随机数据的源数据集。 - - 参数: - - **total_rows** (int, 可选) - 随机生成样本数据的数量。默认值: ``None`` ,生成随机数量的样本。 - - **schema** (Union[str, :class:`~.dataset.Schema`], 可选) - 数据格式策略,用于指定读取数据列的数据类型、数据维度等信息。 - 支持传入JSON文件路径或 :class:`mindspore.dataset.Schema` 构造的对象。默认值: ``None`` 。 - - **columns_list** (list[str], 可选) - 指定生成数据集的列名。默认值: ``None`` ,生成的数据列将以"c0"、"c1"、"c2" ... "cn"的规则命名。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取所有样本。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - 异常: - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - **TypeError** - `total_rows` 的类型不是int。 - - **TypeError** - `num_shards` 的类型不是int。 - - **TypeError** - `num_parallel_workers` 的类型不是int。 - - **TypeError** - `shuffle` 的类型不是bool。 - - **TypeError** - `columns_list` 的类型不是list。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - -.. include:: mindspore.dataset.api_list_nlp.rst +mindspore.dataset.RandomDataset +=============================== + +.. py:class:: mindspore.dataset.RandomDataset(total_rows=None, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None, cache=None, shuffle=None, num_shards=None, shard_id=None) + + 生成随机数据的源数据集。 + + 参数: + - **total_rows** (int, 可选) - 随机生成样本数据的数量。默认值: ``None`` ,生成随机数量的样本。 + - **schema** (Union[str, :class:`~.dataset.Schema`], 可选) - 数据格式策略,用于指定读取数据列的数据类型、数据维度等信息。 + 支持传入JSON文件路径或 :class:`mindspore.dataset.Schema` 构造的对象。默认值: ``None`` 。 + - **columns_list** (list[str], 可选) - 指定生成数据集的列名。默认值: ``None`` ,生成的数据列将以"c0"、"c1"、"c2" ... "cn"的规则命名。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取所有样本。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + + 异常: + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + - **TypeError** - `total_rows` 的类型不是int。 + - **TypeError** - `num_shards` 的类型不是int。 + - **TypeError** - `num_parallel_workers` 的类型不是int。 + - **TypeError** - `shuffle` 的类型不是bool。 + - **TypeError** - `columns_list` 的类型不是list。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + +.. include:: mindspore.dataset.api_list_nlp.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.RandomSampler.rst b/docs/api/api_python/dataset/mindspore.dataset.RandomSampler.rst index 91bef467188..6ad3124425b 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.RandomSampler.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.RandomSampler.rst @@ -1,19 +1,19 @@ -mindspore.dataset.RandomSampler -================================ - -.. py:class:: mindspore.dataset.RandomSampler(replacement=False, num_samples=None) - - 随机采样器。 - - 参数: - - **replacement** (bool, 可选) - 是否将样本ID放回下一次采样。默认值: ``False`` ,无放回采样。 - - **num_samples** (int, 可选) - 获取的样本数,可用于部分获取采样得到的样本。默认值: ``None`` ,获取采样到的所有样本。 - - 异常: - - **TypeError** - `replacement` 不是bool值。 - - **TypeError** - `num_samples` 不是整数值。 - - **ValueError** - `num_samples` 为负值。 - - .. include:: mindspore.dataset.BuiltinSampler.rst - +mindspore.dataset.RandomSampler +================================ + +.. py:class:: mindspore.dataset.RandomSampler(replacement=False, num_samples=None) + + 随机采样器。 + + 参数: + - **replacement** (bool, 可选) - 是否将样本ID放回下一次采样。默认值: ``False`` ,无放回采样。 + - **num_samples** (int, 可选) - 获取的样本数,可用于部分获取采样得到的样本。默认值: ``None`` ,获取采样到的所有样本。 + + 异常: + - **TypeError** - `replacement` 不是bool值。 + - **TypeError** - `num_samples` 不是整数值。 + - **ValueError** - `num_samples` 为负值。 + + .. include:: mindspore.dataset.BuiltinSampler.rst + .. include:: mindspore.dataset.BuiltinSampler.b.rst \ No newline at end of file diff --git a/docs/api/api_python/dataset/mindspore.dataset.SBDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.SBDataset.rst index 1fcdf8b02ee..dfcd6982fab 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.SBDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.SBDataset.rst @@ -1,86 +1,86 @@ -mindspore.dataset.SBDataset -=========================== - -.. py:class:: mindspore.dataset.SBDataset(dataset_dir, task='Boundaries', usage='all', num_samples=None, num_parallel_workers=1, shuffle=None, decode=None, sampler=None, num_shards=None, shard_id=None) - - SB(Semantic Boundaries)数据集。 - - 通过配置 `task` 参数,生成的数据集具有不同的输出列: - - - `task` 为 ``'Boundaries'`` ,有两个输出列: `image` 列的数据类型为uint8, `label` 列包含1个的数据类型为uint8的图像。 - - `task` 为 ``'Segmentation'`` ,有两个输出列: `image` 列的数据类型为uint8。 `label` 列包含20个的数据类型为uint8的图像。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录的路径。 - - **task** (str, 可选) - 指定读取SB数据集的任务类型,支持 ``'Boundaries'`` 和 ``'Segmentation'``。默认值: ``'Boundaries'`` 。 - - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'val'`` 、 ``'train_noval'`` 和 ``'all'`` 。默认值: ``'all'`` 。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,所有图像样本。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作进程数。默认值: ``1`` 。 - - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 - - **decode** (bool, 可选) - 是否对读取的图片进行解码操作。默认值: ``None`` ,默认为 ``False`` ,不解码。 - - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - 异常: - - **RuntimeError** - `dataset_dir` 路径下不包含任何数据文件。 - - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 - - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - `dataset_dir` 不存在。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - **ValueError** - `task` 不是 ``'Boundaries'`` 或 ``'Segmentation'`` 。 - - **ValueError** - `usage` 不是 ``'train'`` 、 ``'val'`` 、 ``'train_noval'`` 或 ``'all'`` 。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 - - .. include:: mindspore.dataset.sampler.rst - - **关于Semantic Boundaries数据集:** - - Semantic Boundaries(语义边界)数据集由11355张彩色图像组成。 - train.txt中有8498个图像,val.txt中有2857个图像,train_noval.txt中有5623个图像。 - 目录cls中包含类别的分割和边界标注,目录inst中包含实例级的分割和边界标注。 - - 可以将数据集文件解压缩为以下结构,并通过MindSpore的API读取: - - .. code-block:: - - . - └── benchmark_RELEASE - ├── dataset - ├── img - │ ├── 2008_000002.jpg - │ ├── 2008_000003.jpg - │ ├── ... - ├── cls - │ ├── 2008_000002.mat - │ ├── 2008_000003.mat - │ ├── ... - ├── inst - │ ├── 2008_000002.mat - │ ├── 2008_000003.mat - │ ├── ... - ├── train.txt - └── val.txt - - **引用:** - - .. code-block:: - - @InProceedings{BharathICCV2011, - author = "Bharath Hariharan and Pablo Arbelaez and Lubomir Bourdev and - Subhransu Maji and Jitendra Malik", - title = "Semantic Contours from Inverse Detectors", - booktitle = "International Conference on Computer Vision (ICCV)", - year = "2011", - } - - -.. include:: mindspore.dataset.api_list_vision.rst +mindspore.dataset.SBDataset +=========================== + +.. py:class:: mindspore.dataset.SBDataset(dataset_dir, task='Boundaries', usage='all', num_samples=None, num_parallel_workers=1, shuffle=None, decode=None, sampler=None, num_shards=None, shard_id=None) + + SB(Semantic Boundaries)数据集。 + + 通过配置 `task` 参数,生成的数据集具有不同的输出列: + + - `task` 为 ``'Boundaries'`` ,有两个输出列: `image` 列的数据类型为uint8, `label` 列包含1个的数据类型为uint8的图像。 + - `task` 为 ``'Segmentation'`` ,有两个输出列: `image` 列的数据类型为uint8。 `label` 列包含20个的数据类型为uint8的图像。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录的路径。 + - **task** (str, 可选) - 指定读取SB数据集的任务类型,支持 ``'Boundaries'`` 和 ``'Segmentation'``。默认值: ``'Boundaries'`` 。 + - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'val'`` 、 ``'train_noval'`` 和 ``'all'`` 。默认值: ``'all'`` 。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,所有图像样本。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作进程数。默认值: ``1`` 。 + - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 + - **decode** (bool, 可选) - 是否对读取的图片进行解码操作。默认值: ``None`` ,默认为 ``False`` ,不解码。 + - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + + 异常: + - **RuntimeError** - `dataset_dir` 路径下不包含任何数据文件。 + - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 + - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - `dataset_dir` 不存在。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + - **ValueError** - `task` 不是 ``'Boundaries'`` 或 ``'Segmentation'`` 。 + - **ValueError** - `usage` 不是 ``'train'`` 、 ``'val'`` 、 ``'train_noval'`` 或 ``'all'`` 。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 + + .. include:: mindspore.dataset.sampler.rst + + **关于Semantic Boundaries数据集:** + + Semantic Boundaries(语义边界)数据集由11355张彩色图像组成。 + train.txt中有8498个图像,val.txt中有2857个图像,train_noval.txt中有5623个图像。 + 目录cls中包含类别的分割和边界标注,目录inst中包含实例级的分割和边界标注。 + + 可以将数据集文件解压缩为以下结构,并通过MindSpore的API读取: + + .. code-block:: + + . + └── benchmark_RELEASE + ├── dataset + ├── img + │ ├── 2008_000002.jpg + │ ├── 2008_000003.jpg + │ ├── ... + ├── cls + │ ├── 2008_000002.mat + │ ├── 2008_000003.mat + │ ├── ... + ├── inst + │ ├── 2008_000002.mat + │ ├── 2008_000003.mat + │ ├── ... + ├── train.txt + └── val.txt + + **引用:** + + .. code-block:: + + @InProceedings{BharathICCV2011, + author = "Bharath Hariharan and Pablo Arbelaez and Lubomir Bourdev and + Subhransu Maji and Jitendra Malik", + title = "Semantic Contours from Inverse Detectors", + booktitle = "International Conference on Computer Vision (ICCV)", + year = "2011", + } + + +.. include:: mindspore.dataset.api_list_vision.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.SBUDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.SBUDataset.rst index 9b9106938ed..790bce9e08e 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.SBUDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.SBUDataset.rst @@ -1,67 +1,67 @@ -mindspore.dataset.SBUDataset -============================ - -.. py:class:: mindspore.dataset.SBUDataset(dataset_dir, num_samples=None, num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None, cache=None) - - SBU(SBU Captioned Photo)数据集。 - - 生成的数据集有两列:`[image, caption]`。`image` 列的数据类型为uint8。`caption` 列的数据类型为string。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录的路径。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,所有图像样本。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 - - **decode** (bool, 可选) - 是否对读取的图片进行解码操作。默认值: ``False`` ,不解码。 - - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 异常: - - **RuntimeError** - `dataset_dir` 路径下不包含任何数据文件。 - - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 - - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 - - .. include:: mindspore.dataset.sampler.rst - - **关于SBU数据集:** - - SBU数据集是一个带字幕的大型照片集。它包含一百万张带有视觉相关标注的图像。 - - 你需要使用官方的download.m手动下载图片,将 'urls{i}(24, end)'替换为 'urls{i}(24:1:end)',并将目录保持如下。 - - .. code-block:: - - . - └─ dataset_dir - ├── SBU_captioned_photo_dataset_captions.txt - ├── SBU_captioned_photo_dataset_urls.txt - └── sbu_images - ├── m_3326_3596303505_3ce4c20529.jpg - ├── ...... - └── m_2522_4182181099_c3c23ab1cc.jpg - - **引用:** - - .. code-block:: - - @inproceedings{Ordonez:2011:im2text, - Author = {Vicente Ordonez and Girish Kulkarni and Tamara L. Berg}, - Title = {Im2Text: Describing Images Using 1 Million Captioned Photographs}, - Booktitle = {Neural Information Processing Systems ({NIPS})}, - Year = {2011}, - } - - -.. include:: mindspore.dataset.api_list_vision.rst +mindspore.dataset.SBUDataset +============================ + +.. py:class:: mindspore.dataset.SBUDataset(dataset_dir, num_samples=None, num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None, cache=None) + + SBU(SBU Captioned Photo)数据集。 + + 生成的数据集有两列:`[image, caption]`。`image` 列的数据类型为uint8。`caption` 列的数据类型为string。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录的路径。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,所有图像样本。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 + - **decode** (bool, 可选) - 是否对读取的图片进行解码操作。默认值: ``False`` ,不解码。 + - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 异常: + - **RuntimeError** - `dataset_dir` 路径下不包含任何数据文件。 + - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 + - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 + + .. include:: mindspore.dataset.sampler.rst + + **关于SBU数据集:** + + SBU数据集是一个带字幕的大型照片集。它包含一百万张带有视觉相关标注的图像。 + + 你需要使用官方的download.m手动下载图片,将 'urls{i}(24, end)'替换为 'urls{i}(24:1:end)',并将目录保持如下。 + + .. code-block:: + + . + └─ dataset_dir + ├── SBU_captioned_photo_dataset_captions.txt + ├── SBU_captioned_photo_dataset_urls.txt + └── sbu_images + ├── m_3326_3596303505_3ce4c20529.jpg + ├── ...... + └── m_2522_4182181099_c3c23ab1cc.jpg + + **引用:** + + .. code-block:: + + @inproceedings{Ordonez:2011:im2text, + Author = {Vicente Ordonez and Girish Kulkarni and Tamara L. Berg}, + Title = {Im2Text: Describing Images Using 1 Million Captioned Photographs}, + Booktitle = {Neural Information Processing Systems ({NIPS})}, + Year = {2011}, + } + + +.. include:: mindspore.dataset.api_list_vision.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.STL10Dataset.rst b/docs/api/api_python/dataset/mindspore.dataset.STL10Dataset.rst index 848c58315cd..99f9c221eab 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.STL10Dataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.STL10Dataset.rst @@ -1,79 +1,79 @@ -mindspore.dataset.STL10Dataset -============================== - -.. py:class:: mindspore.dataset.STL10Dataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None) - - STL-10数据集。 - - 生成的数据集有两列:`[image, label]`。`image` 列的数据类型是uint8。`label` 列的数据类型是uint32。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录路径。 - - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 、 ``'unlabeled'`` 、 ``'train+unlabeled'`` 或 ``'all'`` 。 - 取值为 ``'train'`` 时将会读取5,000个样本,取值为 ``'test'`` 时将会读取8,000个样本,取值为 ``'unlabeled'`` 时将会读取100,000个样本,取值为 ``'train+unlabeled'`` 时将会读取10,5000个样本, - 取值为 ``'all'`` 时将会读取全部类型的样本。默认值: ``None`` ,读取全部样本图片。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数,可以小于数据集总数。默认值: ``None`` ,读取全部样本图片。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 - - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 异常: - - **RuntimeError** - `dataset_dir` 路径下不包含数据文件。 - - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 - - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - `usage` 参数无效。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 - - .. include:: mindspore.dataset.sampler.rst - - **关于STL10数据集:** - - STL10数据集由10类组成:飞机、鸟、汽车、猫、鹿、狗、马、猴子、船、卡车。 - 数据集样本均为96x96的彩色图像。 - 每个类别分别有500张训练图像和800张测试图像,以及100000张没有标签的图像。 - 标签索引从0开始标记,没有标签的的图像以-1作为标记。 - - 以下是原始STL10数据集结构。 - 可以将数据集文件解压缩到此目录结构中,并由MindSpore的API读取。 - - .. code-block:: - - . - └── stl10_dataset_dir - ├── train_X.bin - ├── train_y.bin - ├── test_X.bin - ├── test_y.bin - └── unlabeled_X.bin - - **引用:** - - .. code-block:: - - @techreport{Coates10, - author = {Adam Coates}, - title = {Learning multiple layers of features from tiny images}, - year = {20010}, - howpublished = {https://cs.stanford.edu/~acoates/stl10/}, - description = {The STL-10 dataset consists of 96x96 RGB images in 10 classes, - with 500 training images and 800 testing images per class. - There are 5000 training images and 8000 test images. - It also has 100000 unlabeled images for unsupervised learning. - These examples are extracted from a similar but broader distribution of images. - } - } - - -.. include:: mindspore.dataset.api_list_vision.rst +mindspore.dataset.STL10Dataset +============================== + +.. py:class:: mindspore.dataset.STL10Dataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None) + + STL-10数据集。 + + 生成的数据集有两列:`[image, label]`。`image` 列的数据类型是uint8。`label` 列的数据类型是uint32。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录路径。 + - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 、 ``'unlabeled'`` 、 ``'train+unlabeled'`` 或 ``'all'`` 。 + 取值为 ``'train'`` 时将会读取5,000个样本,取值为 ``'test'`` 时将会读取8,000个样本,取值为 ``'unlabeled'`` 时将会读取100,000个样本,取值为 ``'train+unlabeled'`` 时将会读取10,5000个样本, + 取值为 ``'all'`` 时将会读取全部类型的样本。默认值: ``None`` ,读取全部样本图片。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数,可以小于数据集总数。默认值: ``None`` ,读取全部样本图片。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 + - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 异常: + - **RuntimeError** - `dataset_dir` 路径下不包含数据文件。 + - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 + - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - `usage` 参数无效。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 + + .. include:: mindspore.dataset.sampler.rst + + **关于STL10数据集:** + + STL10数据集由10类组成:飞机、鸟、汽车、猫、鹿、狗、马、猴子、船、卡车。 + 数据集样本均为96x96的彩色图像。 + 每个类别分别有500张训练图像和800张测试图像,以及100000张没有标签的图像。 + 标签索引从0开始标记,没有标签的的图像以-1作为标记。 + + 以下是原始STL10数据集结构。 + 可以将数据集文件解压缩到此目录结构中,并由MindSpore的API读取。 + + .. code-block:: + + . + └── stl10_dataset_dir + ├── train_X.bin + ├── train_y.bin + ├── test_X.bin + ├── test_y.bin + └── unlabeled_X.bin + + **引用:** + + .. code-block:: + + @techreport{Coates10, + author = {Adam Coates}, + title = {Learning multiple layers of features from tiny images}, + year = {20010}, + howpublished = {https://cs.stanford.edu/~acoates/stl10/}, + description = {The STL-10 dataset consists of 96x96 RGB images in 10 classes, + with 500 training images and 800 testing images per class. + There are 5000 training images and 8000 test images. + It also has 100000 unlabeled images for unsupervised learning. + These examples are extracted from a similar but broader distribution of images. + } + } + + +.. include:: mindspore.dataset.api_list_vision.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.SVHNDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.SVHNDataset.rst index 216f357ba68..d964200cbf6 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.SVHNDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.SVHNDataset.rst @@ -1,66 +1,66 @@ -mindspore.dataset.SVHNDataset -============================= - -.. py:class:: mindspore.dataset.SVHNDataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None) - - SVHN(Street View House Numbers)数据集。 - - 生成的数据集有两列:`[image, label]`。`image` 列的数据类型是uint8。`label` 列的数据类型是uint32。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录路径。 - - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 、 ``'extra'`` 或 ``'all'`` 。默认值: ``None`` ,读取全部样本图片。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数,可以小于数据集总数。默认值: ``None`` ,读取全部样本图片。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作进程数。默认值: ``1`` 。 - - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 - - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - 异常: - - **RuntimeError** - `dataset_dir` 路径下不包含数据文件。 - - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 - - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - `usage` 参数无效。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 - - .. include:: mindspore.dataset.sampler.rst - - **关于SVHN数据集:** - - SVHN数据集是从谷歌街景图像中的门牌号码中获得的,由10位数字组成。 - - 以下是原始SVHN数据集结构。可以将数据集文件解压缩到此目录结构中,并由MindSpore的API读取。 - - .. code-block:: - - . - └── svhn_dataset_dir - ├── train_32x32.mat - ├── test_32x32.mat - └── extra_32x32.mat - - **引用:** - - .. code-block:: - - @article{ - title={Reading Digits in Natural Images with Unsupervised Feature Learning}, - author={Yuval Netzer, Tao Wang, Adam Coates, Alessandro Bissacco, Bo Wu, Andrew Y. Ng}, - conference={NIPS Workshop on Deep Learning and Unsupervised Feature Learning 2011.}, - year={2011}, - publisher={NIPS} - url={http://ufldl.stanford.edu/housenumbers} - } - - -.. include:: mindspore.dataset.api_list_vision.rst +mindspore.dataset.SVHNDataset +============================= + +.. py:class:: mindspore.dataset.SVHNDataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None) + + SVHN(Street View House Numbers)数据集。 + + 生成的数据集有两列:`[image, label]`。`image` 列的数据类型是uint8。`label` 列的数据类型是uint32。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录路径。 + - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 、 ``'extra'`` 或 ``'all'`` 。默认值: ``None`` ,读取全部样本图片。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数,可以小于数据集总数。默认值: ``None`` ,读取全部样本图片。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作进程数。默认值: ``1`` 。 + - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 + - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + + 异常: + - **RuntimeError** - `dataset_dir` 路径下不包含数据文件。 + - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 + - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - `usage` 参数无效。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 + + .. include:: mindspore.dataset.sampler.rst + + **关于SVHN数据集:** + + SVHN数据集是从谷歌街景图像中的门牌号码中获得的,由10位数字组成。 + + 以下是原始SVHN数据集结构。可以将数据集文件解压缩到此目录结构中,并由MindSpore的API读取。 + + .. code-block:: + + . + └── svhn_dataset_dir + ├── train_32x32.mat + ├── test_32x32.mat + └── extra_32x32.mat + + **引用:** + + .. code-block:: + + @article{ + title={Reading Digits in Natural Images with Unsupervised Feature Learning}, + author={Yuval Netzer, Tao Wang, Adam Coates, Alessandro Bissacco, Bo Wu, Andrew Y. Ng}, + conference={NIPS Workshop on Deep Learning and Unsupervised Feature Learning 2011.}, + year={2011}, + publisher={NIPS} + url={http://ufldl.stanford.edu/housenumbers} + } + + +.. include:: mindspore.dataset.api_list_vision.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.Schema.rst b/docs/api/api_python/dataset/mindspore.dataset.Schema.rst index e1f54e23b3f..347e7650441 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.Schema.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.Schema.rst @@ -1,59 +1,59 @@ -mindspore.dataset.Schema -========================= - -.. py:class:: mindspore.dataset.Schema(schema_file=None) - - 用于解析和存储数据列属性的类。 - - 参数: - - **schema_file** (str) - schema文件的路径。默认值: ``None`` 。 - - 异常: - - **RuntimeError** - schema文件加载失败。 - - .. py:method:: add_column(name, de_type, shape=None) - - 向schema中添加新列。 - - 参数: - - **name** (str) - 列的新名称。 - - **de_type** (str) - 列的数据类型。 - - **shape** (list[int], 可选) - 列shape。默认值: ``None`` , ``-1`` 表示该维度的shape是未知的。 - - 异常: - - **ValueError** - 列类型未知。 - - .. py:method:: from_json(json_obj) - - 从JSON对象获取schema文件。 - - 参数: - - **json_obj** (dictionary) - 解析的JSON对象。 - - 异常: - - **RuntimeError** - 对象中存在未知的项。 - - **RuntimeError** - 对象中缺少数据集类型。 - - **RuntimeError** - 对象中缺少列。 - - .. py:method:: parse_columns(columns) - - 解析传入的数据列的属性并将其添加到自身的schema中。 - - 参数: - - **columns** (Union[dict, list[dict], tuple[dict]]) - 数据集属性信息,从schema文件解码。 - - - **list** [dict]:'name'和 'type'必须为key值, 'shape'可选。 - - **dict** :columns.keys()作为名称,columns.values()是dict,其中包含 'type', 'shape'可选。 - - 异常: - - **RuntimeError** - 解析列失败。 - - **RuntimeError** - 列name字段缺失。 - - **RuntimeError** - 列type字段缺失。 - - .. py:method:: to_json() - - 获取schema的JSON字符串。 - - 返回: - str,模式的JSON字符串。 +mindspore.dataset.Schema +========================= + +.. py:class:: mindspore.dataset.Schema(schema_file=None) + + 用于解析和存储数据列属性的类。 + + 参数: + - **schema_file** (str) - schema文件的路径。默认值: ``None`` 。 + + 异常: + - **RuntimeError** - schema文件加载失败。 + + .. py:method:: add_column(name, de_type, shape=None) + + 向schema中添加新列。 + + 参数: + - **name** (str) - 列的新名称。 + - **de_type** (str) - 列的数据类型。 + - **shape** (list[int], 可选) - 列shape。默认值: ``None`` , ``-1`` 表示该维度的shape是未知的。 + + 异常: + - **ValueError** - 列类型未知。 + + .. py:method:: from_json(json_obj) + + 从JSON对象获取schema文件。 + + 参数: + - **json_obj** (dictionary) - 解析的JSON对象。 + + 异常: + - **RuntimeError** - 对象中存在未知的项。 + - **RuntimeError** - 对象中缺少数据集类型。 + - **RuntimeError** - 对象中缺少列。 + + .. py:method:: parse_columns(columns) + + 解析传入的数据列的属性并将其添加到自身的schema中。 + + 参数: + - **columns** (Union[dict, list[dict], tuple[dict]]) - 数据集属性信息,从schema文件解码。 + + - **list** [dict]:'name'和 'type'必须为key值, 'shape'可选。 + - **dict** :columns.keys()作为名称,columns.values()是dict,其中包含 'type', 'shape'可选。 + + 异常: + - **RuntimeError** - 解析列失败。 + - **RuntimeError** - 列name字段缺失。 + - **RuntimeError** - 列type字段缺失。 + + .. py:method:: to_json() + + 获取schema的JSON字符串。 + + 返回: + str,模式的JSON字符串。 \ No newline at end of file diff --git a/docs/api/api_python/dataset/mindspore.dataset.SemeionDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.SemeionDataset.rst index d9941783a44..78a53e7e347 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.SemeionDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.SemeionDataset.rst @@ -1,61 +1,61 @@ -mindspore.dataset.SemeionDataset -================================ - -.. py:class:: mindspore.dataset.SemeionDataset(dataset_dir, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None) - - Semeion数据集。 - - 生成的数据集有两列:`[image, label]`。`image` 列的数据类型为uint8。`label` 列的数据类型为uint32。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录的路径。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,所有图像样本。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 - - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 异常: - - **RuntimeError** - `dataset_dir` 路径下不包含任何数据文件。 - - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 - - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 - - .. include:: mindspore.dataset.sampler.rst - - **关于SEMEION数据集:** - - 该数据集由意大利布雷西亚Tactile Srl创建(http://www.tattil.it),并于1994年捐赠给意大利罗马Semeion通信科学研究中心(http://www.semeion.it),用于机器学习研究。 - 此数据集由1593条样本记录(行)和256个属性(列)组成。每条记录代表一个手写数字,最初扫描的分辨率为256灰度。 - 数据集拉伸了每个原始扫描图像的每个像素,然后在0和1之间缩放(将值低于灰度值127的每个像素(包括127)设置为0,并将灰度值超过127的每个像素设置为1)。 - 最后,每个二进制图像再次缩放为一个16x16的方形图像。 - - .. code-block:: - - . - └── semeion_dataset_dir - └──semeion.data - └──semeion.names - - **引用:** - - .. code-block:: - - @article{ - title={The Theory of Independent Judges, in Substance Use & Misuse 33(2)1998, pp 439-461}, - author={M Buscema, MetaNet}, - } - - -.. include:: mindspore.dataset.api_list_vision.rst +mindspore.dataset.SemeionDataset +================================ + +.. py:class:: mindspore.dataset.SemeionDataset(dataset_dir, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None) + + Semeion数据集。 + + 生成的数据集有两列:`[image, label]`。`image` 列的数据类型为uint8。`label` 列的数据类型为uint32。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录的路径。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,所有图像样本。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 + - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 异常: + - **RuntimeError** - `dataset_dir` 路径下不包含任何数据文件。 + - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 + - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 + + .. include:: mindspore.dataset.sampler.rst + + **关于SEMEION数据集:** + + 该数据集由意大利布雷西亚Tactile Srl创建(http://www.tattil.it),并于1994年捐赠给意大利罗马Semeion通信科学研究中心(http://www.semeion.it),用于机器学习研究。 + 此数据集由1593条样本记录(行)和256个属性(列)组成。每条记录代表一个手写数字,最初扫描的分辨率为256灰度。 + 数据集拉伸了每个原始扫描图像的每个像素,然后在0和1之间缩放(将值低于灰度值127的每个像素(包括127)设置为0,并将灰度值超过127的每个像素设置为1)。 + 最后,每个二进制图像再次缩放为一个16x16的方形图像。 + + .. code-block:: + + . + └── semeion_dataset_dir + └──semeion.data + └──semeion.names + + **引用:** + + .. code-block:: + + @article{ + title={The Theory of Independent Judges, in Substance Use & Misuse 33(2)1998, pp 439-461}, + author={M Buscema, MetaNet}, + } + + +.. include:: mindspore.dataset.api_list_vision.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.SequentialSampler.rst b/docs/api/api_python/dataset/mindspore.dataset.SequentialSampler.rst index 260a1942c2d..ae2e5adc48a 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.SequentialSampler.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.SequentialSampler.rst @@ -1,20 +1,20 @@ -mindspore.dataset.SequentialSampler -=================================== - -.. py:class:: mindspore.dataset.SequentialSampler(start_index=None, num_samples=None) - - 按数据集的读取顺序采样数据集样本,相当于不使用采样器。 - - 参数: - - **start_index** (int, 可选) - 采样的起始样本ID。默认值: ``None`` ,从数据集第一个样本开始采样。 - - **num_samples** (int, 可选) - 获取的样本数,可用于部分获取采样得到的样本。默认值: ``None`` ,获取采样到的所有样本。 - - 异常: - - **TypeError** - `start_index` 的类型不是int。 - - **TypeError** - `num_samples` 的类型不是int。 - - **RuntimeError** - `start_index` 为负值。 - - **ValueError** - `num_samples` 为负值。 - - .. include:: mindspore.dataset.BuiltinSampler.rst - +mindspore.dataset.SequentialSampler +=================================== + +.. py:class:: mindspore.dataset.SequentialSampler(start_index=None, num_samples=None) + + 按数据集的读取顺序采样数据集样本,相当于不使用采样器。 + + 参数: + - **start_index** (int, 可选) - 采样的起始样本ID。默认值: ``None`` ,从数据集第一个样本开始采样。 + - **num_samples** (int, 可选) - 获取的样本数,可用于部分获取采样得到的样本。默认值: ``None`` ,获取采样到的所有样本。 + + 异常: + - **TypeError** - `start_index` 的类型不是int。 + - **TypeError** - `num_samples` 的类型不是int。 + - **RuntimeError** - `start_index` 为负值。 + - **ValueError** - `num_samples` 为负值。 + + .. include:: mindspore.dataset.BuiltinSampler.rst + .. include:: mindspore.dataset.BuiltinSampler.b.rst \ No newline at end of file diff --git a/docs/api/api_python/dataset/mindspore.dataset.Shuffle.rst b/docs/api/api_python/dataset/mindspore.dataset.Shuffle.rst index 91b6bddddbb..06f9007d9f8 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.Shuffle.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.Shuffle.rst @@ -1,10 +1,10 @@ -mindspore.dataset.Shuffle -========================= - -.. py:class:: mindspore.dataset.Shuffle - - 指定混洗模式的枚举类。 - - - **Shuffle.GLOBAL** - 混洗文件和文件中的数据。 - - **Shuffle.FILES** - 仅混洗文件。 - - **Shuffle.INFILE** - 保持读入文件的序列,仅混洗每个文件中的数据。 +mindspore.dataset.Shuffle +========================= + +.. py:class:: mindspore.dataset.Shuffle + + 指定混洗模式的枚举类。 + + - **Shuffle.GLOBAL** - 混洗文件和文件中的数据。 + - **Shuffle.FILES** - 仅混洗文件。 + - **Shuffle.INFILE** - 保持读入文件的序列,仅混洗每个文件中的数据。 diff --git a/docs/api/api_python/dataset/mindspore.dataset.SogouNewsDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.SogouNewsDataset.rst index 3ca9e93d1d6..a7d9394664c 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.SogouNewsDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.SogouNewsDataset.rst @@ -1,68 +1,68 @@ -mindspore.dataset.SogouNewsDataset -================================== - -.. py:class:: mindspore.dataset.SogouNewsDataset(dataset_dir, usage=None, num_samples=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, num_parallel_workers=None, cache=None) - - Sogou New数据集。 - - 生成的数据集有三列 `[index, title, content]`,三列的数据类型均为string。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录路径。 - - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 或 ``'all'`` 。默认值: ``None`` ,读取全部样本。 - 取值为 ``'train'`` 时将会读取45万个训练样本,取值为 ``'test'`` 时将会读取6万个测试样本,取值为 ``'all'`` 时将会读取全部51万个样本。默认值: ``None`` ,读取全部样本。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` , 读取全部样本。 - - **shuffle** (Union[bool, :class:`~.dataset.Shuffle`], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定。默认值: ``Shuffle.GLOBAL`` 。 - 如果 `shuffle` 为 ``False`` ,则不混洗,如果 `shuffle` 为 ``True`` ,等同于将 `shuffle` 设置为 ``mindspore.dataset.Shuffle.GLOBAL`` 。 - 通过传入枚举变量设置数据混洗的模式: - - - ``Shuffle.GLOBAL`` :混洗文件和样本。 - - ``Shuffle.FILES`` :仅混洗文件。 - - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 异常: - - **RuntimeError** - `dataset_dir` 参数所指向的文件目录不存在或缺少数据集文件。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - **关于SogouNew数据集:** - - SogouNews 数据集包括3列,分别对应类别索引(1到5)、标题和内容。 - 标题和内容使用双引号(")进行转义,任何内部双引号都使用2个双引号("")进行转义。 - 新行使用反斜杠进行转义,后跟“n”字符,即 "\n"。 - - 以下是原始SogouNew数据集结构,可以将数据集文件解压缩到此目录结构中,并由MindSpore的API读取: - - .. code-block:: - - . - └── sogou_news_dir - ├── classes.txt - ├── readme.txt - ├── test.csv - └── train.csv - - **引用:** - - .. code-block:: - - @misc{zhang2015characterlevel, - title={Character-level Convolutional Networks for Text Classification}, - author={Xiang Zhang and Junbo Zhao and Yann LeCun}, - year={2015}, - eprint={1509.01626}, - archivePrefix={arXiv}, - primaryClass={cs.LG} - } - - -.. include:: mindspore.dataset.api_list_nlp.rst +mindspore.dataset.SogouNewsDataset +================================== + +.. py:class:: mindspore.dataset.SogouNewsDataset(dataset_dir, usage=None, num_samples=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, num_parallel_workers=None, cache=None) + + Sogou New数据集。 + + 生成的数据集有三列 `[index, title, content]`,三列的数据类型均为string。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录路径。 + - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 或 ``'all'`` 。默认值: ``None`` ,读取全部样本。 + 取值为 ``'train'`` 时将会读取45万个训练样本,取值为 ``'test'`` 时将会读取6万个测试样本,取值为 ``'all'`` 时将会读取全部51万个样本。默认值: ``None`` ,读取全部样本。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` , 读取全部样本。 + - **shuffle** (Union[bool, :class:`~.dataset.Shuffle`], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定。默认值: ``Shuffle.GLOBAL`` 。 + 如果 `shuffle` 为 ``False`` ,则不混洗,如果 `shuffle` 为 ``True`` ,等同于将 `shuffle` 设置为 ``mindspore.dataset.Shuffle.GLOBAL`` 。 + 通过传入枚举变量设置数据混洗的模式: + + - ``Shuffle.GLOBAL`` :混洗文件和样本。 + - ``Shuffle.FILES`` :仅混洗文件。 + + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 异常: + - **RuntimeError** - `dataset_dir` 参数所指向的文件目录不存在或缺少数据集文件。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + **关于SogouNew数据集:** + + SogouNews 数据集包括3列,分别对应类别索引(1到5)、标题和内容。 + 标题和内容使用双引号(")进行转义,任何内部双引号都使用2个双引号("")进行转义。 + 新行使用反斜杠进行转义,后跟“n”字符,即 "\n"。 + + 以下是原始SogouNew数据集结构,可以将数据集文件解压缩到此目录结构中,并由MindSpore的API读取: + + .. code-block:: + + . + └── sogou_news_dir + ├── classes.txt + ├── readme.txt + ├── test.csv + └── train.csv + + **引用:** + + .. code-block:: + + @misc{zhang2015characterlevel, + title={Character-level Convolutional Networks for Text Classification}, + author={Xiang Zhang and Junbo Zhao and Yann LeCun}, + year={2015}, + eprint={1509.01626}, + archivePrefix={arXiv}, + primaryClass={cs.LG} + } + + +.. include:: mindspore.dataset.api_list_nlp.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.SpeechCommandsDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.SpeechCommandsDataset.rst index 41824e4e1b1..23c67992fb3 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.SpeechCommandsDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.SpeechCommandsDataset.rst @@ -1,71 +1,71 @@ -mindspore.dataset.SpeechCommandsDataset -======================================= - -.. py:class:: mindspore.dataset.SpeechCommandsDataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None) - - Speech Commands数据集。 - - 生成的数据集有五列 `[waveform, sample_rate, label, speaker_id, utterance_number]` 。 - 列 `waveform` 的数据类型为float32。列 `sample_rate` 的数据类型为int32。列 `label` 的数据类型为string。列 `speaker_id` 的数据类型为string。列 `utterance_number` 的数据类型为int32。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录路径。 - - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 、 ``'valid'`` 或 ``'all'`` 。默认值: ``None`` ,读取全部样本。 - 取值为 ``'train'`` 时将会读取84,843个训练样本,取值为 ``'test'`` 时将会读取11,005个测试样本,取值为 ``'valid'`` 时将会读取9,981个测试样本,取值为 ``'all'`` 时将会读取全部105,829个样本。默认值: ``None`` ,读取全部样本。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取全部样本。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 - - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 异常: - - **RuntimeError** - `dataset_dir` 路径下不包含任何数据文件。 - - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 - - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 - - .. include:: mindspore.dataset.sampler.rst - - **关于SpeechCommands数据集:** - - SpeechCommands(语音命令)数据是用于有限词汇语音识别的数据集,包含105,829个 '.wav'格式的音频样本。 - - 以下是原始SpeechCommands的数据集结构。可以将数据集文件解压缩成此目录结构,并由MindSpore的API读取。 - - .. code-block:: - - . - └── speech_commands_dataset_dir - ├── cat - ├── b433eff_nohash_0.wav - ├── 5a33edf_nohash_1.wav - └──.... - ├── dog - ├── b433w2w_nohash_0.wav - └──.... - ├── four - └── .... - - **引用:** - - .. code-block:: - - @article{2018Speech, - title={Speech Commands: A Dataset for Limited-Vocabulary Speech Recognition}, - author={Warden, P.}, - year={2018} - } - - -.. include:: mindspore.dataset.api_list_audio.rst +mindspore.dataset.SpeechCommandsDataset +======================================= + +.. py:class:: mindspore.dataset.SpeechCommandsDataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None) + + Speech Commands数据集。 + + 生成的数据集有五列 `[waveform, sample_rate, label, speaker_id, utterance_number]` 。 + 列 `waveform` 的数据类型为float32。列 `sample_rate` 的数据类型为int32。列 `label` 的数据类型为string。列 `speaker_id` 的数据类型为string。列 `utterance_number` 的数据类型为int32。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录路径。 + - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 、 ``'valid'`` 或 ``'all'`` 。默认值: ``None`` ,读取全部样本。 + 取值为 ``'train'`` 时将会读取84,843个训练样本,取值为 ``'test'`` 时将会读取11,005个测试样本,取值为 ``'valid'`` 时将会读取9,981个测试样本,取值为 ``'all'`` 时将会读取全部105,829个样本。默认值: ``None`` ,读取全部样本。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取全部样本。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 + - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 异常: + - **RuntimeError** - `dataset_dir` 路径下不包含任何数据文件。 + - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 + - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 + + .. include:: mindspore.dataset.sampler.rst + + **关于SpeechCommands数据集:** + + SpeechCommands(语音命令)数据是用于有限词汇语音识别的数据集,包含105,829个 '.wav'格式的音频样本。 + + 以下是原始SpeechCommands的数据集结构。可以将数据集文件解压缩成此目录结构,并由MindSpore的API读取。 + + .. code-block:: + + . + └── speech_commands_dataset_dir + ├── cat + ├── b433eff_nohash_0.wav + ├── 5a33edf_nohash_1.wav + └──.... + ├── dog + ├── b433w2w_nohash_0.wav + └──.... + ├── four + └── .... + + **引用:** + + .. code-block:: + + @article{2018Speech, + title={Speech Commands: A Dataset for Limited-Vocabulary Speech Recognition}, + author={Warden, P.}, + year={2018} + } + + +.. include:: mindspore.dataset.api_list_audio.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.SubsetRandomSampler.rst b/docs/api/api_python/dataset/mindspore.dataset.SubsetRandomSampler.rst index f74acb51f74..c7f36450013 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.SubsetRandomSampler.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.SubsetRandomSampler.rst @@ -1,17 +1,17 @@ -mindspore.dataset.SubsetRandomSampler -====================================== - -.. py:class:: mindspore.dataset.SubsetRandomSampler(indices, num_samples=None) - - 给定样本的索引序列,从序列中随机获取索引对数据集进行采样。 - - 参数: - - **indices** (Iterable) - 样本索引的序列(除了string类型外的任意Python可迭代对象类型)。 - - **num_samples** (int, 可选) - 获取的样本数,可用于部分获取采样得到的样本。默认值: ``None`` ,获取采样到的所有样本。 - - 异常: - - **TypeError** - `indices` 的类型不是int。 - - **TypeError** - `num_samples` 的类型不是int。 - - **ValueError** - `num_samples` 为负值。 - - .. include:: mindspore.dataset.BuiltinSampler.rst +mindspore.dataset.SubsetRandomSampler +====================================== + +.. py:class:: mindspore.dataset.SubsetRandomSampler(indices, num_samples=None) + + 给定样本的索引序列,从序列中随机获取索引对数据集进行采样。 + + 参数: + - **indices** (Iterable) - 样本索引的序列(除了string类型外的任意Python可迭代对象类型)。 + - **num_samples** (int, 可选) - 获取的样本数,可用于部分获取采样得到的样本。默认值: ``None`` ,获取采样到的所有样本。 + + 异常: + - **TypeError** - `indices` 的类型不是int。 + - **TypeError** - `num_samples` 的类型不是int。 + - **ValueError** - `num_samples` 为负值。 + + .. include:: mindspore.dataset.BuiltinSampler.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.SubsetSampler.rst b/docs/api/api_python/dataset/mindspore.dataset.SubsetSampler.rst index 9e091befaa0..1914da4111d 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.SubsetSampler.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.SubsetSampler.rst @@ -1,17 +1,17 @@ -mindspore.dataset.SubsetSampler -==================================== - -.. py:class:: mindspore.dataset.SubsetSampler(indices, num_samples=None) - - 给定样本的索引序列,对数据集采样指定索引的样本。 - - 参数: - - **indices** (Iterable) - 索引的序列(包括除了string类型的任意Python可迭代对象类型)。 - - **num_samples** (int, 可选) - 获取的样本数,可用于部分获取采样得到的样本。默认值: ``None`` ,获取采样到的所有样本。 - - 异常: - - **TypeError** - `indices` 的类型不是int。 - - **TypeError** - `num_samples` 的类型不是int。 - - **ValueError** - `num_samples` 为负值。 - - .. include:: mindspore.dataset.BuiltinSampler.rst +mindspore.dataset.SubsetSampler +==================================== + +.. py:class:: mindspore.dataset.SubsetSampler(indices, num_samples=None) + + 给定样本的索引序列,对数据集采样指定索引的样本。 + + 参数: + - **indices** (Iterable) - 索引的序列(包括除了string类型的任意Python可迭代对象类型)。 + - **num_samples** (int, 可选) - 获取的样本数,可用于部分获取采样得到的样本。默认值: ``None`` ,获取采样到的所有样本。 + + 异常: + - **TypeError** - `indices` 的类型不是int。 + - **TypeError** - `num_samples` 的类型不是int。 + - **ValueError** - `num_samples` 为负值。 + + .. include:: mindspore.dataset.BuiltinSampler.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.TFRecordDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.TFRecordDataset.rst index f1bf6056c0c..e0f0d6d2ac6 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.TFRecordDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.TFRecordDataset.rst @@ -1,51 +1,51 @@ -mindspore.dataset.TFRecordDataset -================================= - -.. py:class:: mindspore.dataset.TFRecordDataset(dataset_files, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, shard_equal_rows=False, cache=None, compression_type=None) - - 读取和解析TFData格式的数据文件构建数据集。生成的数据集的列名和列类型取决于TFRecord文件中的保存的列名与类型。 - - .. note:: Windows平台尚不支持 `TFRecordDataset` 。 - - 参数: - - **dataset_files** (Union[str, list[str]]) - 数据集文件路径,支持单文件路径字符串、多文件路径字符串列表或可被glob库模式匹配的字符串,文件列表将在内部进行字典排序。 - - **schema** (Union[str, :class:`~.dataset.Schema`], 可选) - 数据格式策略,用于指定读取数据列的数据类型、数据维度等信息。 - 支持传入JSON文件路径或 :class:`mindspore.dataset.Schema` 构造的对象。默认值: ``None`` 。 - - **columns_list** (list[str], 可选) - 指定从TFRecord文件中读取的数据列。默认值: ``None`` ,读取所有列。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取全部样本。 - 当指定了 `num_shards` 和 `shard_id` 参数时,`num_samples` 或numRows字段(由参数 `schema` 定义)将表示每个分片读取的数据量。 `num_samples` 的处理优先级如下: - - - 指定了 `num_samples` 参数,且值大于0,则读取 `num_samples` 条数据。此时 `schema` 参数的numRows字段会失效。 - - 不指定 `num_samples` 参数,指定了 `schema` 参数并定义了numRows字段,且值大于0,则读取numRows条数据。 - - 不指定 `num_samples` 参数 与 `schema` 参数,读取所有样本数据。 - - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (Union[bool, :class:`~.dataset.Shuffle`], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定。默认值: ``Shuffle.GLOBAL`` 。 - 如果 `shuffle` 为 ``False`` ,则不混洗,如果 `shuffle` 为 ``True`` ,等同于将 `shuffle` 设置为 ``mindspore.dataset.Shuffle.GLOBAL`` 。 - 通过传入枚举变量设置数据混洗的模式: - - - ``Shuffle.GLOBAL`` :混洗文件和样本。 - - ``Shuffle.FILES`` :仅混洗文件。 - - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后,`num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **shard_equal_rows** (bool, 可选) - 分布式训练时,为所有分片获取等量的数据行数。默认值: ``False`` 。如果 `shard_equal_rows` 为 ``False`` ,则可能会使得每个分片的数据条目不相等,从而导致分布式训练失败。因此当每个TFRecord文件的数据数量不相等时,建议将此参数设置为 ``True`` 。注意,只有当指定了 `num_shards` 时才能指定此参数。当 `compression_type` 非 ``None`` ,且指定了 `num_samples` 或numRows字段(由参数 `schema` 定义)时,`shard_equal_rows` 会被视为 ``True`` 。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - **compression_type** (str, 可选) - 用于所有文件的压缩类型,必须是 ``“”`` , ``“GZIP”`` 或 ``“ZLIB”`` 。默认值: ``None`` ,即空字符串。 - 建议在 `compression_type` 为 ``"GZIP"`` 或 ``"ZLIB"`` 时,指定 `num_samples` 或numRows字段(由参数 `schema` 定义)以避免出现为了获取文件大小对同一个文件进行多次解压而导致性能下降的问题。 - - 异常: - - **ValueError** - `dataset_files` 参数所指向的文件无效或不存在。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - **ValueError** - `compression_type` 不是 ``''`` 、 ``'GZIP'`` 、 ``'ZLIB'`` 三者之一。 - - **ValueError** - `compression_type` 有效但是数据集文件数量小于 `num_shards` 。 - - **ValueError** - `num_samples` 小于0。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - -.. include:: mindspore.dataset.api_list_nlp.rst +mindspore.dataset.TFRecordDataset +================================= + +.. py:class:: mindspore.dataset.TFRecordDataset(dataset_files, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, shard_equal_rows=False, cache=None, compression_type=None) + + 读取和解析TFData格式的数据文件构建数据集。生成的数据集的列名和列类型取决于TFRecord文件中的保存的列名与类型。 + + .. note:: Windows平台尚不支持 `TFRecordDataset` 。 + + 参数: + - **dataset_files** (Union[str, list[str]]) - 数据集文件路径,支持单文件路径字符串、多文件路径字符串列表或可被glob库模式匹配的字符串,文件列表将在内部进行字典排序。 + - **schema** (Union[str, :class:`~.dataset.Schema`], 可选) - 数据格式策略,用于指定读取数据列的数据类型、数据维度等信息。 + 支持传入JSON文件路径或 :class:`mindspore.dataset.Schema` 构造的对象。默认值: ``None`` 。 + - **columns_list** (list[str], 可选) - 指定从TFRecord文件中读取的数据列。默认值: ``None`` ,读取所有列。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取全部样本。 + 当指定了 `num_shards` 和 `shard_id` 参数时,`num_samples` 或numRows字段(由参数 `schema` 定义)将表示每个分片读取的数据量。 `num_samples` 的处理优先级如下: + + - 指定了 `num_samples` 参数,且值大于0,则读取 `num_samples` 条数据。此时 `schema` 参数的numRows字段会失效。 + - 不指定 `num_samples` 参数,指定了 `schema` 参数并定义了numRows字段,且值大于0,则读取numRows条数据。 + - 不指定 `num_samples` 参数 与 `schema` 参数,读取所有样本数据。 + + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (Union[bool, :class:`~.dataset.Shuffle`], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定。默认值: ``Shuffle.GLOBAL`` 。 + 如果 `shuffle` 为 ``False`` ,则不混洗,如果 `shuffle` 为 ``True`` ,等同于将 `shuffle` 设置为 ``mindspore.dataset.Shuffle.GLOBAL`` 。 + 通过传入枚举变量设置数据混洗的模式: + + - ``Shuffle.GLOBAL`` :混洗文件和样本。 + - ``Shuffle.FILES`` :仅混洗文件。 + + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后,`num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **shard_equal_rows** (bool, 可选) - 分布式训练时,为所有分片获取等量的数据行数。默认值: ``False`` 。如果 `shard_equal_rows` 为 ``False`` ,则可能会使得每个分片的数据条目不相等,从而导致分布式训练失败。因此当每个TFRecord文件的数据数量不相等时,建议将此参数设置为 ``True`` 。注意,只有当指定了 `num_shards` 时才能指定此参数。当 `compression_type` 非 ``None`` ,且指定了 `num_samples` 或numRows字段(由参数 `schema` 定义)时,`shard_equal_rows` 会被视为 ``True`` 。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + - **compression_type** (str, 可选) - 用于所有文件的压缩类型,必须是 ``“”`` , ``“GZIP”`` 或 ``“ZLIB”`` 。默认值: ``None`` ,即空字符串。 + 建议在 `compression_type` 为 ``"GZIP"`` 或 ``"ZLIB"`` 时,指定 `num_samples` 或numRows字段(由参数 `schema` 定义)以避免出现为了获取文件大小对同一个文件进行多次解压而导致性能下降的问题。 + + 异常: + - **ValueError** - `dataset_files` 参数所指向的文件无效或不存在。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + - **ValueError** - `compression_type` 不是 ``''`` 、 ``'GZIP'`` 、 ``'ZLIB'`` 三者之一。 + - **ValueError** - `compression_type` 有效但是数据集文件数量小于 `num_shards` 。 + - **ValueError** - `num_samples` 小于0。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + +.. include:: mindspore.dataset.api_list_nlp.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.TedliumDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.TedliumDataset.rst index 2040c4c4a54..562a98192ae 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.TedliumDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.TedliumDataset.rst @@ -1,133 +1,133 @@ -mindspore.dataset.TedliumDataset -================================ - -.. py:class:: mindspore.dataset.TedliumDataset(dataset_dir, release, usage=None, extensions=None, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None) - - Tedlium数据集。生成的数据集的列取决于源SPH文件和相应的STM文件。 - - 生成的数据集有六列 `[waveform, sample_rate, transcript, talk_id, speaker_id, identifier]`。 - 列 `waveform` 的数据类型为float32,列 `sample_rate` 的数据类型为int32,列 `transcript`、列 `talk_id`、列 `speaker_id` 和列 `identifier` 的数据类型为string。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录路径。 - - **release** (str) - 指定数据集的发布版本,可以取值为 ``'release1'`` 、 ``'release2'`` 或 ``'release3'`` 。 - - **usage** (str, 可选) - 指定数据集的子集。 - 对于 `release` 为 ``'release1'`` 或 ``'release2'``,`usage` 可以是 ``'train'`` 、 ``'test'`` 、 ``'dev'`` 或 ``'all'`` 。 - 对于 `release` 为 ``'release3'`` , `usage` 只能是 ``'all'`` 。默认值: ``None`` ,读取全部样本。 - - **extensions** (str, 可选) - 指定SPH文件的扩展名。默认值: ``None`` ,默认指定为 ``'.sph'`` 。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取全部样本。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 - - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 异常: - - **RuntimeError** - `dataset_dir` 路径下不包含任何数据文件。 - - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 - - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 - - .. include:: mindspore.dataset.sampler.rst - - **关于TEDLIUM数据集:** - - TEDLIUM_release1数据集:TED-LUM语料库是英语TED演讲,有转录,采样频率为16kHz。包含了大约118小时的演讲。 - - TEDLIUM_release2数据集:这是TED-LIUM语料库版本2,根据知识共享BY-NC-ND 3.0授权。所有会谈和文本均为TED会议有限责任公司的财产。TED-LIUM语料库是由音频谈话和他们的转录在TED网站上提供的。我们准备并过滤了这些数据,以便训练声学模型参加2011年口语翻译国际研讨会(LIUM英语/法语SLT系统在SLT任务中排名第一)。 - - TEDLIUM_release-3数据集:这是TED-LIUM语料库版本3,根据知识共享BY-NC-ND 3.0授权。所有会谈和文本均为TED会议有限责任公司的财产。这个新的TED-LIUM版本是通过Ubiqus公司和LIUM(法国勒芒大学)的合作发布的。 - - 可以将数据集文件解压缩到以下目录结构中,并由MindSpore的API读取。 - - TEDLIUM release1与TEDLIUM release2的结构相同,只是数据不同。 - - .. code-block:: - - . - └──TEDLIUM_release1 - └── dev - ├── sph - ├── AlGore_2009.sph - ├── BarrySchwartz_2005G.sph - ├── stm - ├── AlGore_2009.stm - ├── BarrySchwartz_2005G.stm - └── test - ├── sph - ├── AimeeMullins_2009P.sph - ├── BillGates_2010.sph - ├── stm - ├── AimeeMullins_2009P.stm - ├── BillGates_2010.stm - └── train - ├── sph - ├── AaronHuey_2010X.sph - ├── AdamGrosser_2007.sph - ├── stm - ├── AaronHuey_2010X.stm - ├── AdamGrosser_2007.stm - └── readme - └── TEDLIUM.150k.dic - - TEDLIUM release3目录结构稍有不同。 - - .. code-block:: - - . - └──TEDLIUM_release-3 - └── data - ├── ctl - ├── sph - ├── 911Mothers_2010W.sph - ├── AalaElKhani.sph - ├── stm - ├── 911Mothers_2010W.stm - ├── AalaElKhani.stm - └── doc - └── legacy - └── LM - └── speaker-adaptation - └── readme - └── TEDLIUM.150k.dic - - **引用:** - - .. code-block:: - - @article{ - title={TED-LIUM: an automatic speech recognition dedicated corpus}, - author={A. Rousseau, P. Deléglise, Y. Estève}, - journal={Proceedings of the Eighth International Conference on Language Resources and Evaluation (LREC'12)}, - year={May 2012}, - biburl={https://www.openslr.org/7/} - } - - @article{ - title={Enhancing the TED-LIUM Corpus with Selected Data for Language Modeling and More TED Talks}, - author={A. Rousseau, P. Deléglise, and Y. Estève}, - journal={Proceedings of the Eighth International Conference on Language Resources and Evaluation (LREC'12)}, - year={May 2014}, - biburl={https://www.openslr.org/19/} - } - - @article{ - title={TED-LIUM 3: twice as much data and corpus repartition for experiments on speaker adaptation}, - author={François Hernandez, Vincent Nguyen, Sahar Ghannay, Natalia Tomashenko, and Yannick Estève}, - journal={the 20th International Conference on Speech and Computer (SPECOM 2018)}, - year={September 2018}, - biburl={https://www.openslr.org/51/} - } - - -.. include:: mindspore.dataset.api_list_audio.rst +mindspore.dataset.TedliumDataset +================================ + +.. py:class:: mindspore.dataset.TedliumDataset(dataset_dir, release, usage=None, extensions=None, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None) + + Tedlium数据集。生成的数据集的列取决于源SPH文件和相应的STM文件。 + + 生成的数据集有六列 `[waveform, sample_rate, transcript, talk_id, speaker_id, identifier]`。 + 列 `waveform` 的数据类型为float32,列 `sample_rate` 的数据类型为int32,列 `transcript`、列 `talk_id`、列 `speaker_id` 和列 `identifier` 的数据类型为string。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录路径。 + - **release** (str) - 指定数据集的发布版本,可以取值为 ``'release1'`` 、 ``'release2'`` 或 ``'release3'`` 。 + - **usage** (str, 可选) - 指定数据集的子集。 + 对于 `release` 为 ``'release1'`` 或 ``'release2'``,`usage` 可以是 ``'train'`` 、 ``'test'`` 、 ``'dev'`` 或 ``'all'`` 。 + 对于 `release` 为 ``'release3'`` , `usage` 只能是 ``'all'`` 。默认值: ``None`` ,读取全部样本。 + - **extensions** (str, 可选) - 指定SPH文件的扩展名。默认值: ``None`` ,默认指定为 ``'.sph'`` 。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取全部样本。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 + - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 异常: + - **RuntimeError** - `dataset_dir` 路径下不包含任何数据文件。 + - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 + - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 + + .. include:: mindspore.dataset.sampler.rst + + **关于TEDLIUM数据集:** + + TEDLIUM_release1数据集:TED-LUM语料库是英语TED演讲,有转录,采样频率为16kHz。包含了大约118小时的演讲。 + + TEDLIUM_release2数据集:这是TED-LIUM语料库版本2,根据知识共享BY-NC-ND 3.0授权。所有会谈和文本均为TED会议有限责任公司的财产。TED-LIUM语料库是由音频谈话和他们的转录在TED网站上提供的。我们准备并过滤了这些数据,以便训练声学模型参加2011年口语翻译国际研讨会(LIUM英语/法语SLT系统在SLT任务中排名第一)。 + + TEDLIUM_release-3数据集:这是TED-LIUM语料库版本3,根据知识共享BY-NC-ND 3.0授权。所有会谈和文本均为TED会议有限责任公司的财产。这个新的TED-LIUM版本是通过Ubiqus公司和LIUM(法国勒芒大学)的合作发布的。 + + 可以将数据集文件解压缩到以下目录结构中,并由MindSpore的API读取。 + + TEDLIUM release1与TEDLIUM release2的结构相同,只是数据不同。 + + .. code-block:: + + . + └──TEDLIUM_release1 + └── dev + ├── sph + ├── AlGore_2009.sph + ├── BarrySchwartz_2005G.sph + ├── stm + ├── AlGore_2009.stm + ├── BarrySchwartz_2005G.stm + └── test + ├── sph + ├── AimeeMullins_2009P.sph + ├── BillGates_2010.sph + ├── stm + ├── AimeeMullins_2009P.stm + ├── BillGates_2010.stm + └── train + ├── sph + ├── AaronHuey_2010X.sph + ├── AdamGrosser_2007.sph + ├── stm + ├── AaronHuey_2010X.stm + ├── AdamGrosser_2007.stm + └── readme + └── TEDLIUM.150k.dic + + TEDLIUM release3目录结构稍有不同。 + + .. code-block:: + + . + └──TEDLIUM_release-3 + └── data + ├── ctl + ├── sph + ├── 911Mothers_2010W.sph + ├── AalaElKhani.sph + ├── stm + ├── 911Mothers_2010W.stm + ├── AalaElKhani.stm + └── doc + └── legacy + └── LM + └── speaker-adaptation + └── readme + └── TEDLIUM.150k.dic + + **引用:** + + .. code-block:: + + @article{ + title={TED-LIUM: an automatic speech recognition dedicated corpus}, + author={A. Rousseau, P. Deléglise, Y. Estève}, + journal={Proceedings of the Eighth International Conference on Language Resources and Evaluation (LREC'12)}, + year={May 2012}, + biburl={https://www.openslr.org/7/} + } + + @article{ + title={Enhancing the TED-LIUM Corpus with Selected Data for Language Modeling and More TED Talks}, + author={A. Rousseau, P. Deléglise, and Y. Estève}, + journal={Proceedings of the Eighth International Conference on Language Resources and Evaluation (LREC'12)}, + year={May 2014}, + biburl={https://www.openslr.org/19/} + } + + @article{ + title={TED-LIUM 3: twice as much data and corpus repartition for experiments on speaker adaptation}, + author={François Hernandez, Vincent Nguyen, Sahar Ghannay, Natalia Tomashenko, and Yannick Estève}, + journal={the 20th International Conference on Speech and Computer (SPECOM 2018)}, + year={September 2018}, + biburl={https://www.openslr.org/51/} + } + + +.. include:: mindspore.dataset.api_list_audio.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.TextFileDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.TextFileDataset.rst index 6bd31236bd8..cd46fa4856f 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.TextFileDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.TextFileDataset.rst @@ -1,34 +1,34 @@ -mindspore.dataset.TextFileDataset -================================== - -.. py:class:: mindspore.dataset.TextFileDataset(dataset_files, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None) - - 读取和解析文本文件构建数据集。生成的数据集有一个数据列:`[text]`,类型为string。 - - 参数: - - **dataset_files** (Union[str, list[str]]) - 数据集文件路径,支持单文件路径字符串、多文件路径字符串列表或可被glob库模式匹配的字符串,文件列表将在内部进行字典排序。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取所有样本。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (Union[bool, :class:`~.dataset.Shuffle`], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定。默认值: ``Shuffle.GLOBAL`` 。 - 如果 `shuffle` 为 ``False`` ,则不混洗,如果 `shuffle` 为 ``True`` ,等同于将 `shuffle` 设置为 ``mindspore.dataset.Shuffle.GLOBAL`` 。 - 通过传入枚举变量设置数据混洗的模式: - - - ``Shuffle.GLOBAL`` :混洗文件和样本。 - - ``Shuffle.FILES`` :仅混洗文件。 - - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 异常: - - **ValueError** - `dataset_files` 参数所指向的文件无效或不存在。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - -.. include:: mindspore.dataset.api_list_nlp.rst +mindspore.dataset.TextFileDataset +================================== + +.. py:class:: mindspore.dataset.TextFileDataset(dataset_files, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None) + + 读取和解析文本文件构建数据集。生成的数据集有一个数据列:`[text]`,类型为string。 + + 参数: + - **dataset_files** (Union[str, list[str]]) - 数据集文件路径,支持单文件路径字符串、多文件路径字符串列表或可被glob库模式匹配的字符串,文件列表将在内部进行字典排序。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取所有样本。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (Union[bool, :class:`~.dataset.Shuffle`], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定。默认值: ``Shuffle.GLOBAL`` 。 + 如果 `shuffle` 为 ``False`` ,则不混洗,如果 `shuffle` 为 ``True`` ,等同于将 `shuffle` 设置为 ``mindspore.dataset.Shuffle.GLOBAL`` 。 + 通过传入枚举变量设置数据混洗的模式: + + - ``Shuffle.GLOBAL`` :混洗文件和样本。 + - ``Shuffle.FILES`` :仅混洗文件。 + + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 异常: + - **ValueError** - `dataset_files` 参数所指向的文件无效或不存在。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + +.. include:: mindspore.dataset.api_list_nlp.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.UDPOSDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.UDPOSDataset.rst index 05940d0f855..9e742226d1b 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.UDPOSDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.UDPOSDataset.rst @@ -1,56 +1,56 @@ -mindspore.dataset.UDPOSDataset -============================== - -.. py:class:: mindspore.dataset.UDPOSDataset(dataset_dir, usage=None, num_samples=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, num_parallel_workers=None, cache=None) - - UDPOS(Universal Dependencies dataset for Part of Speech)数据集。 - - 生成的数据集有三列 `[word, universal, stanford]`,三列的数据类型均为string。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录路径。 - - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 、 ``'valid'`` 或 ``'all'`` 。 - 取值为 ``'train'`` 时将会读取12,543个样本,取值为 ``'test'`` 时将会读取2,077个测试样本,取值为 ``'valid'`` 时将会读取2,002个样本,取值为 ``'all'`` 时将会读取全部16,622个样本。默认值: ``None`` ,读取全部样本。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取全部样本。 - - **shuffle** (Union[bool, :class:`~.dataset.Shuffle`], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定。默认值: `Shuffle.GLOBAL` 。 - 如果 `shuffle` 为 ``False`` ,则不混洗,如果 `shuffle` 为 ``True`` ,等同于将 `shuffle` 设置为 ``mindspore.dataset.Shuffle.GLOBAL`` 。 - 通过传入枚举变量设置数据混洗的模式: - - - ``Shuffle.GLOBAL`` :混洗文件和样本。 - - ``Shuffle.FILES`` :仅混洗文件。 - - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 异常: - - **RuntimeError** - `dataset_dir` 参数所指向的文件目录不存在或缺少数据集文件。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - **关于UDPOS数据集:** - - UDPOS是一个解析的文本语料库数据集,用于阐明句法或者语义句子结构。 - 该语料库包含254,830个单词和16,622个句子,取自各种网络媒体,包括博客、新闻组、电子邮件和评论。 - - **引用:** - - .. code-block:: - - @inproceedings{silveira14gold, - year = {2014}, - author = {Natalia Silveira and Timothy Dozat and Marie-Catherine de Marneffe and Samuel Bowman - and Miriam Connor and John Bauer and Christopher D. Manning}, - title = {A Gold Standard Dependency Corpus for {E}nglish}, - booktitle = {Proceedings of the Ninth International Conference on Language - Resources and Evaluation (LREC-2014)} - } - - -.. include:: mindspore.dataset.api_list_nlp.rst +mindspore.dataset.UDPOSDataset +============================== + +.. py:class:: mindspore.dataset.UDPOSDataset(dataset_dir, usage=None, num_samples=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, num_parallel_workers=None, cache=None) + + UDPOS(Universal Dependencies dataset for Part of Speech)数据集。 + + 生成的数据集有三列 `[word, universal, stanford]`,三列的数据类型均为string。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录路径。 + - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 、 ``'valid'`` 或 ``'all'`` 。 + 取值为 ``'train'`` 时将会读取12,543个样本,取值为 ``'test'`` 时将会读取2,077个测试样本,取值为 ``'valid'`` 时将会读取2,002个样本,取值为 ``'all'`` 时将会读取全部16,622个样本。默认值: ``None`` ,读取全部样本。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取全部样本。 + - **shuffle** (Union[bool, :class:`~.dataset.Shuffle`], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定。默认值: `Shuffle.GLOBAL` 。 + 如果 `shuffle` 为 ``False`` ,则不混洗,如果 `shuffle` 为 ``True`` ,等同于将 `shuffle` 设置为 ``mindspore.dataset.Shuffle.GLOBAL`` 。 + 通过传入枚举变量设置数据混洗的模式: + + - ``Shuffle.GLOBAL`` :混洗文件和样本。 + - ``Shuffle.FILES`` :仅混洗文件。 + + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 异常: + - **RuntimeError** - `dataset_dir` 参数所指向的文件目录不存在或缺少数据集文件。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + **关于UDPOS数据集:** + + UDPOS是一个解析的文本语料库数据集,用于阐明句法或者语义句子结构。 + 该语料库包含254,830个单词和16,622个句子,取自各种网络媒体,包括博客、新闻组、电子邮件和评论。 + + **引用:** + + .. code-block:: + + @inproceedings{silveira14gold, + year = {2014}, + author = {Natalia Silveira and Timothy Dozat and Marie-Catherine de Marneffe and Samuel Bowman + and Miriam Connor and John Bauer and Christopher D. Manning}, + title = {A Gold Standard Dependency Corpus for {E}nglish}, + booktitle = {Proceedings of the Ninth International Conference on Language + Resources and Evaluation (LREC-2014)} + } + + +.. include:: mindspore.dataset.api_list_nlp.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.USPSDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.USPSDataset.rst index 1b89bab1e1f..62d5aa31c3a 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.USPSDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.USPSDataset.rst @@ -1,69 +1,69 @@ -mindspore.dataset.USPSDataset -============================= - -.. py:class:: mindspore.dataset.USPSDataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None) - - USPS(U.S. Postal Service)数据集。 - - 生成的数据集有两列:`[image, label]`。`image` 列的数据类型为uint8。`label` 列的数据类型为uint32。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录路径。 - - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 或 ``'all'`` 。 - 取值为 ``'train'`` 时将会读取7,291个样本,取值为 ``'test'`` 时将会读取2,007个测试样本,取值为 ``'all'`` 时将会读取全部9,298个样本。默认值: ``None`` ,读取全部样本。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取所有数据集。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (Union[bool, :class:`~.dataset.Shuffle`], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定。默认值: ``Shuffle.GLOBAL`` 。 - 如果 `shuffle` 为 ``False`` ,则不混洗,如果 `shuffle` 为 ``True`` ,等同于将 `shuffle` 设置为 ``mindspore.dataset.Shuffle.GLOBAL`` 。 - 通过传入枚举变量设置数据混洗的模式: - - - ``Shuffle.GLOBAL`` :混洗文件和样本。 - - ``Shuffle.FILES`` :仅混洗文件。 - - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 异常: - - **RuntimeError** - `dataset_dir` 路径下不包含数据文件。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - `usage` 参数无效。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - **关于USPS数据集:** - - USPS是美国邮政服务公司从信封中自动扫描的数字数据集,包含总共9,298个16×16像素灰度样本。 - 数据集中的图片内容已被预处理为居中和归一化,并集中了多种样式的字体。 - - 以下是原始的USPS数据集结构。可以将数据集文件下载并解压缩到此目录结构中,并通过MindSpore的API读取。 - - .. code-block:: - - . - └── usps_dataset_dir - ├── usps - ├── usps.t - - **引用:** - - .. code-block:: - - @article{hull1994database, - title={A database for handwritten text recognition research}, - author={Hull, Jonathan J.}, - journal={IEEE Transactions on pattern analysis and machine intelligence}, - volume={16}, - number={5}, - pages={550--554}, - year={1994}, - publisher={IEEE} - } - - -.. include:: mindspore.dataset.api_list_vision.rst +mindspore.dataset.USPSDataset +============================= + +.. py:class:: mindspore.dataset.USPSDataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None) + + USPS(U.S. Postal Service)数据集。 + + 生成的数据集有两列:`[image, label]`。`image` 列的数据类型为uint8。`label` 列的数据类型为uint32。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录路径。 + - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 或 ``'all'`` 。 + 取值为 ``'train'`` 时将会读取7,291个样本,取值为 ``'test'`` 时将会读取2,007个测试样本,取值为 ``'all'`` 时将会读取全部9,298个样本。默认值: ``None`` ,读取全部样本。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取所有数据集。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (Union[bool, :class:`~.dataset.Shuffle`], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定。默认值: ``Shuffle.GLOBAL`` 。 + 如果 `shuffle` 为 ``False`` ,则不混洗,如果 `shuffle` 为 ``True`` ,等同于将 `shuffle` 设置为 ``mindspore.dataset.Shuffle.GLOBAL`` 。 + 通过传入枚举变量设置数据混洗的模式: + + - ``Shuffle.GLOBAL`` :混洗文件和样本。 + - ``Shuffle.FILES`` :仅混洗文件。 + + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 异常: + - **RuntimeError** - `dataset_dir` 路径下不包含数据文件。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - `usage` 参数无效。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + **关于USPS数据集:** + + USPS是美国邮政服务公司从信封中自动扫描的数字数据集,包含总共9,298个16×16像素灰度样本。 + 数据集中的图片内容已被预处理为居中和归一化,并集中了多种样式的字体。 + + 以下是原始的USPS数据集结构。可以将数据集文件下载并解压缩到此目录结构中,并通过MindSpore的API读取。 + + .. code-block:: + + . + └── usps_dataset_dir + ├── usps + ├── usps.t + + **引用:** + + .. code-block:: + + @article{hull1994database, + title={A database for handwritten text recognition research}, + author={Hull, Jonathan J.}, + journal={IEEE Transactions on pattern analysis and machine intelligence}, + volume={16}, + number={5}, + pages={550--554}, + year={1994}, + publisher={IEEE} + } + + +.. include:: mindspore.dataset.api_list_vision.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.VOCDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.VOCDataset.rst index 3c2373bbea7..78aaef5e400 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.VOCDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.VOCDataset.rst @@ -1,110 +1,110 @@ -mindspore.dataset.VOCDataset -============================= - -.. py:class:: mindspore.dataset.VOCDataset(dataset_dir, task='Segmentation', usage='train', class_indexing=None, num_samples=None, num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None, cache=None, extra_metadata=False, decrypt=None) - - VOC(Visual Object Classes)数据集。 - - 根据给定的 `task` 配置,生成数据集具有不同的输出列: - - - `task` 为 ``'Detection'`` ,输出列: `[image, dtype=uint8]` , `[bbox, dtype=float32]` , `[label, dtype=uint32]` , `[difficult, dtype=uint32]` , `[truncate, dtype=uint32]` 。 - - `task` 为 ``'Segmentation'`` ,输出列: `[image, dtype=uint8]` , `[target, dtype=uint8]` 。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录的路径。 - - **task** (str, 可选) - 指定读取VOC数据的任务类型,现在只支持 ``'Segmentation'`` 和 ``'Detection'`` 。默认值: ``'Segmentation'`` 。 - - **usage** (str, 可选) - 指定数据集的子集。默认值: ``'train'`` 。 - - - 如果 `task` 的值为 ``'Segmentation'`` ,则读取 'ImageSets/Segmentation/' 目录下定义的图片和label信息; - - 如果 `task` 的值为 ``'Detection'`` ,则读取 'ImageSets/Main/' 目录下定义的图片和label信息。 - - - **class_indexing** (dict, 可选) - 指定一个从label名称到label索引的映射,要求映射规则为string到int。索引值从0开始,并且要求每个label名称对应的索引值唯一。 - 仅在 'Detection' 任务中有效。默认值: ``None`` ,不指定。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,所有图像样本。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 - - **decode** (bool, 可选) - 是否对读取的图片进行解码操作。默认值: ``False`` ,不解码。 - - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - **extra_metadata** (bool, 可选) - 用于指定是否额外输出一个数据列用于表示图片元信息。如果为 ``True`` ,则将额外输出一个名为 `[_meta-filename, dtype=string]` 的数据列。默认值: ``False`` 。 - - **decrypt** (callable, 可选) - 图像解密函数,接受加密的图片路径并返回bytes类型的解密数据。默认值: ``None`` ,不进行解密。 - - 异常: - - **RuntimeError** - `dataset_dir` 路径下不包含任何数据文件。 - - **RuntimeError** - 读取的xml文件格式异常或无效。 - - **RuntimeError** - 读取的xml文件缺失 `object` 属性。 - - **RuntimeError** - 读取的xml文件缺失 `bndbox` 属性。 - - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 - - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - **ValueError** - 指定的任务不为 ``'Segmentation'`` 或 ``'Detection'`` 。 - - **ValueError** - 指定任务为 ``'Segmentation'`` 时, `class_indexing` 参数不为 ``None`` 。 - - **ValueError** - 与 `usage` 参数相关的txt文件不存在。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - .. note:: - - 当参数 `extra_metadata` 为True时,还需使用 `rename` 操作删除额外数据列 '_meta-filename'的前缀 '_meta-', - 否则迭代得到的数据行中不会出现此额外数据列。 - - 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 - - .. include:: mindspore.dataset.sampler.rst - - **关于VOC数据集:** - - PASCAL Visual Object Classes(VOC)是视觉目标识别和检测的挑战赛,它为视觉和机器学习社区提供了图像和标注的标准数据集,称为VOC数据集。 - - 您可以解压缩原始VOC-2012数据集文件到如下目录结构,并通过MindSpore的API进行读取。 - - .. code-block:: - - . - └── voc2012_dataset_dir - ├── Annotations - │ ├── 2007_000027.xml - │ ├── 2007_000032.xml - │ ├── ... - ├── ImageSets - │ ├── Action - │ ├── Layout - │ ├── Main - │ └── Segmentation - ├── JPEGImages - │ ├── 2007_000027.jpg - │ ├── 2007_000032.jpg - │ ├── ... - ├── SegmentationClass - │ ├── 2007_000032.png - │ ├── 2007_000033.png - │ ├── ... - └── SegmentationObject - ├── 2007_000032.png - ├── 2007_000033.png - ├── ... - - **引用:** - - .. code-block:: - - @article{Everingham10, - author = {Everingham, M. and Van~Gool, L. and Williams, C. K. I. and Winn, J. and Zisserman, A.}, - title = {The Pascal Visual Object Classes (VOC) Challenge}, - journal = {International Journal of Computer Vision}, - volume = {88}, - year = {2012}, - number = {2}, - month = {jun}, - pages = {303--338}, - biburl = {http://host.robots.ox.ac.uk/pascal/VOC/pubs/everingham10.html#bibtex}, - howpublished = {http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html} - } - - -.. include:: mindspore.dataset.api_list_vision.rst +mindspore.dataset.VOCDataset +============================= + +.. py:class:: mindspore.dataset.VOCDataset(dataset_dir, task='Segmentation', usage='train', class_indexing=None, num_samples=None, num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None, cache=None, extra_metadata=False, decrypt=None) + + VOC(Visual Object Classes)数据集。 + + 根据给定的 `task` 配置,生成数据集具有不同的输出列: + + - `task` 为 ``'Detection'`` ,输出列: `[image, dtype=uint8]` , `[bbox, dtype=float32]` , `[label, dtype=uint32]` , `[difficult, dtype=uint32]` , `[truncate, dtype=uint32]` 。 + - `task` 为 ``'Segmentation'`` ,输出列: `[image, dtype=uint8]` , `[target, dtype=uint8]` 。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录的路径。 + - **task** (str, 可选) - 指定读取VOC数据的任务类型,现在只支持 ``'Segmentation'`` 和 ``'Detection'`` 。默认值: ``'Segmentation'`` 。 + - **usage** (str, 可选) - 指定数据集的子集。默认值: ``'train'`` 。 + + - 如果 `task` 的值为 ``'Segmentation'`` ,则读取 'ImageSets/Segmentation/' 目录下定义的图片和label信息; + - 如果 `task` 的值为 ``'Detection'`` ,则读取 'ImageSets/Main/' 目录下定义的图片和label信息。 + + - **class_indexing** (dict, 可选) - 指定一个从label名称到label索引的映射,要求映射规则为string到int。索引值从0开始,并且要求每个label名称对应的索引值唯一。 + 仅在 'Detection' 任务中有效。默认值: ``None`` ,不指定。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,所有图像样本。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 + - **decode** (bool, 可选) - 是否对读取的图片进行解码操作。默认值: ``False`` ,不解码。 + - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + - **extra_metadata** (bool, 可选) - 用于指定是否额外输出一个数据列用于表示图片元信息。如果为 ``True`` ,则将额外输出一个名为 `[_meta-filename, dtype=string]` 的数据列。默认值: ``False`` 。 + - **decrypt** (callable, 可选) - 图像解密函数,接受加密的图片路径并返回bytes类型的解密数据。默认值: ``None`` ,不进行解密。 + + 异常: + - **RuntimeError** - `dataset_dir` 路径下不包含任何数据文件。 + - **RuntimeError** - 读取的xml文件格式异常或无效。 + - **RuntimeError** - 读取的xml文件缺失 `object` 属性。 + - **RuntimeError** - 读取的xml文件缺失 `bndbox` 属性。 + - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 + - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + - **ValueError** - 指定的任务不为 ``'Segmentation'`` 或 ``'Detection'`` 。 + - **ValueError** - 指定任务为 ``'Segmentation'`` 时, `class_indexing` 参数不为 ``None`` 。 + - **ValueError** - 与 `usage` 参数相关的txt文件不存在。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + .. note:: + - 当参数 `extra_metadata` 为True时,还需使用 `rename` 操作删除额外数据列 '_meta-filename'的前缀 '_meta-', + 否则迭代得到的数据行中不会出现此额外数据列。 + - 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 + + .. include:: mindspore.dataset.sampler.rst + + **关于VOC数据集:** + + PASCAL Visual Object Classes(VOC)是视觉目标识别和检测的挑战赛,它为视觉和机器学习社区提供了图像和标注的标准数据集,称为VOC数据集。 + + 您可以解压缩原始VOC-2012数据集文件到如下目录结构,并通过MindSpore的API进行读取。 + + .. code-block:: + + . + └── voc2012_dataset_dir + ├── Annotations + │ ├── 2007_000027.xml + │ ├── 2007_000032.xml + │ ├── ... + ├── ImageSets + │ ├── Action + │ ├── Layout + │ ├── Main + │ └── Segmentation + ├── JPEGImages + │ ├── 2007_000027.jpg + │ ├── 2007_000032.jpg + │ ├── ... + ├── SegmentationClass + │ ├── 2007_000032.png + │ ├── 2007_000033.png + │ ├── ... + └── SegmentationObject + ├── 2007_000032.png + ├── 2007_000033.png + ├── ... + + **引用:** + + .. code-block:: + + @article{Everingham10, + author = {Everingham, M. and Van~Gool, L. and Williams, C. K. I. and Winn, J. and Zisserman, A.}, + title = {The Pascal Visual Object Classes (VOC) Challenge}, + journal = {International Journal of Computer Vision}, + volume = {88}, + year = {2012}, + number = {2}, + month = {jun}, + pages = {303--338}, + biburl = {http://host.robots.ox.ac.uk/pascal/VOC/pubs/everingham10.html#bibtex}, + howpublished = {http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html} + } + + +.. include:: mindspore.dataset.api_list_vision.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.WIDERFaceDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.WIDERFaceDataset.rst index 36f4bf41310..a615223a76b 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.WIDERFaceDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.WIDERFaceDataset.rst @@ -1,93 +1,93 @@ -mindspore.dataset.WIDERFaceDataset -================================== - -.. py:class:: mindspore.dataset.WIDERFaceDataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None, cache=None) - - WIDERFace数据集。 - - 当 `usage` 为 "train"、"valid" 或 "all" 时,生成的数据集有八列 `["image", "bbox", "blur", "expression", "illumination", "occlusion", "pose", "invalid"]` 。其中 `image` 列的数据类型为uint8,其他列均为uint32。 - 当 `usage` 为 "test" 时,生成的数据集只有一列 `["image"]`,数据类型为uint8。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录路径。 - - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 、 ``'valid'`` 或 ``'all'`` 。 - 取值为 ``'train'`` 时将会读取12,880个样本,取值为 ``'test'`` 时将会读取16,097个样本,取值为 ``'valid'`` 时将会读取3,226个样本,取值为 ``'all'`` 时将会读取全部类别样本。默认值: ``None`` ,读取全部样本。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取全部样本。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 - - **decode** (bool, 可选) - 是否对读取的图片进行解码操作。默认值: ``False`` ,不解码。 - - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 异常: - - **RuntimeError** - `dataset_dir` 不包含任何数据文件。 - - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 - - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - **ValueError** - `usage` 不为 ``'train'`` 、 ``'test'`` 、 ``'valid'`` 或 ``'all'`` 。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - **ValueError** - `annotation_file` 不存在。 - - **ValueError** - `dataset_dir` 不存在。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 - - .. include:: mindspore.dataset.sampler.rst - - **关于WIDERFace数据集:** - - WIDER FACE数据集具有12,880个训练样本,16,097个测试样本,以及3,226个验证样本。此数据集是WIDER数据集的子集。其中图片已经预先进行了尺寸归一化和人像中心化处理。 - - 以下是原始的WIDERFace数据集结构。可以将数据集文件解压缩到此目录结构中,并由MindSpore的API读取。 - - .. code-block:: - - . - └── wider_face_dir - ├── WIDER_test - │ └── images - │ ├── 0--Parade - │ │ ├── 0_Parade_marchingband_1_9.jpg - │ │ ├── ... - │ ├──1--Handshaking - │ ├──... - ├── WIDER_train - │ └── images - │ ├── 0--Parade - │ │ ├── 0_Parade_marchingband_1_11.jpg - │ │ ├── ... - │ ├──1--Handshaking - │ ├──... - ├── WIDER_val - │ └── images - │ ├── 0--Parade - │ │ ├── 0_Parade_marchingband_1_102.jpg - │ │ ├── ... - │ ├──1--Handshaking - │ ├──... - └── wider_face_split - ├── wider_face_test_filelist.txt - ├── wider_face_train_bbx_gt.txt - └── wider_face_val_bbx_gt.txt - - **引用:** - - .. code-block:: - - @inproceedings{2016WIDER, - title={WIDERFACE: A Detection Benchmark}, - author={Yang, S. and Luo, P. and Loy, C. C. and Tang, X.}, - booktitle={IEEE}, - pages={5525-5533}, - year={2016}, - } - - -.. include:: mindspore.dataset.api_list_vision.rst +mindspore.dataset.WIDERFaceDataset +================================== + +.. py:class:: mindspore.dataset.WIDERFaceDataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None, cache=None) + + WIDERFace数据集。 + + 当 `usage` 为 "train"、"valid" 或 "all" 时,生成的数据集有八列 `["image", "bbox", "blur", "expression", "illumination", "occlusion", "pose", "invalid"]` 。其中 `image` 列的数据类型为uint8,其他列均为uint32。 + 当 `usage` 为 "test" 时,生成的数据集只有一列 `["image"]`,数据类型为uint8。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录路径。 + - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 、 ``'valid'`` 或 ``'all'`` 。 + 取值为 ``'train'`` 时将会读取12,880个样本,取值为 ``'test'`` 时将会读取16,097个样本,取值为 ``'valid'`` 时将会读取3,226个样本,取值为 ``'all'`` 时将会读取全部类别样本。默认值: ``None`` ,读取全部样本。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取全部样本。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 + - **decode** (bool, 可选) - 是否对读取的图片进行解码操作。默认值: ``False`` ,不解码。 + - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 异常: + - **RuntimeError** - `dataset_dir` 不包含任何数据文件。 + - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 + - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + - **ValueError** - `usage` 不为 ``'train'`` 、 ``'test'`` 、 ``'valid'`` 或 ``'all'`` 。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + - **ValueError** - `annotation_file` 不存在。 + - **ValueError** - `dataset_dir` 不存在。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 + + .. include:: mindspore.dataset.sampler.rst + + **关于WIDERFace数据集:** + + WIDER FACE数据集具有12,880个训练样本,16,097个测试样本,以及3,226个验证样本。此数据集是WIDER数据集的子集。其中图片已经预先进行了尺寸归一化和人像中心化处理。 + + 以下是原始的WIDERFace数据集结构。可以将数据集文件解压缩到此目录结构中,并由MindSpore的API读取。 + + .. code-block:: + + . + └── wider_face_dir + ├── WIDER_test + │ └── images + │ ├── 0--Parade + │ │ ├── 0_Parade_marchingband_1_9.jpg + │ │ ├── ... + │ ├──1--Handshaking + │ ├──... + ├── WIDER_train + │ └── images + │ ├── 0--Parade + │ │ ├── 0_Parade_marchingband_1_11.jpg + │ │ ├── ... + │ ├──1--Handshaking + │ ├──... + ├── WIDER_val + │ └── images + │ ├── 0--Parade + │ │ ├── 0_Parade_marchingband_1_102.jpg + │ │ ├── ... + │ ├──1--Handshaking + │ ├──... + └── wider_face_split + ├── wider_face_test_filelist.txt + ├── wider_face_train_bbx_gt.txt + └── wider_face_val_bbx_gt.txt + + **引用:** + + .. code-block:: + + @inproceedings{2016WIDER, + title={WIDERFACE: A Detection Benchmark}, + author={Yang, S. and Luo, P. and Loy, C. C. and Tang, X.}, + booktitle={IEEE}, + pages={5525-5533}, + year={2016}, + } + + +.. include:: mindspore.dataset.api_list_vision.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.WeightedRandomSampler.rst b/docs/api/api_python/dataset/mindspore.dataset.WeightedRandomSampler.rst index a3253d197f4..a9b0b836699 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.WeightedRandomSampler.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.WeightedRandomSampler.rst @@ -1,22 +1,22 @@ -mindspore.dataset.WeightedRandomSampler -======================================= - -.. py:class:: mindspore.dataset.WeightedRandomSampler(weights, num_samples=None, replacement=True) - - 给定样本的权重列表,根据权重决定样本的采样概率,随机采样[0,len(weights) - 1]中的样本。 - - 参数: - - **weights** (list[float, int]) - 权重序列,总和不一定为1。 - - **num_samples** (int, 可选) - 获取的样本数,可用于部分获取采样得到的样本。默认值: ``None`` ,获取采样到的所有样本。 - - **replacement** (bool) - 是否将样本ID放回下一次采样。默认值: ``True`` ,有放回采样。 - - 异常: - - **TypeError** - `weights` 元素的类型不是数值类型。 - - **TypeError** - `num_samples` 的类型不是int。 - - **TypeError** - `replacement` 的类型不是bool。 - - **RuntimeError** - `weights` 为空或全为零。 - - **ValueError** - `num_samples` 为负值。 - - .. include:: mindspore.dataset.BuiltinSampler.rst - +mindspore.dataset.WeightedRandomSampler +======================================= + +.. py:class:: mindspore.dataset.WeightedRandomSampler(weights, num_samples=None, replacement=True) + + 给定样本的权重列表,根据权重决定样本的采样概率,随机采样[0,len(weights) - 1]中的样本。 + + 参数: + - **weights** (list[float, int]) - 权重序列,总和不一定为1。 + - **num_samples** (int, 可选) - 获取的样本数,可用于部分获取采样得到的样本。默认值: ``None`` ,获取采样到的所有样本。 + - **replacement** (bool) - 是否将样本ID放回下一次采样。默认值: ``True`` ,有放回采样。 + + 异常: + - **TypeError** - `weights` 元素的类型不是数值类型。 + - **TypeError** - `num_samples` 的类型不是int。 + - **TypeError** - `replacement` 的类型不是bool。 + - **RuntimeError** - `weights` 为空或全为零。 + - **ValueError** - `num_samples` 为负值。 + + .. include:: mindspore.dataset.BuiltinSampler.rst + .. include:: mindspore.dataset.BuiltinSampler.b.rst \ No newline at end of file diff --git a/docs/api/api_python/dataset/mindspore.dataset.WikiTextDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.WikiTextDataset.rst index 6a8a16b1620..dc0a84fadb8 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.WikiTextDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.WikiTextDataset.rst @@ -1,67 +1,67 @@ -mindspore.dataset.WikiTextDataset -================================= - -.. py:class:: mindspore.dataset.WikiTextDataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None) - - WikiText2和WikiText103数据集。 - - 生成的数据集有一列 `[text]` ,数据类型为string。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录路径。 - - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 、 ``'valid'`` 或 ``'all'`` 。默认值: ``None`` ,读取全部样本。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取全部样本。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (Union[bool, :class:`~.dataset.Shuffle`], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定。默认值: ``Shuffle.GLOBAL`` 。 - 如果 `shuffle` 为 ``False`` ,则不混洗,如果 `shuffle` 为 ``True`` ,等同于将 `shuffle` 设置为 ``mindspore.dataset.Shuffle.GLOBAL`` 。 - 通过传入枚举变量设置数据混洗的模式: - - - ``Shuffle.GLOBAL`` :混洗文件和样本。 - - ``Shuffle.FILES`` :仅混洗文件。 - - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 异常: - - **RuntimeError** - `dataset_dir` 参数所指向的文件目录不存在或缺少数据集文件。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - **ValueError** - `num_samples` 参数值错误,小于0。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - **关于WikiText数据集:** - - WikiText数据集是一个包含1亿字的英语词典。 - 这些样本术语来自维基百科的高级和基础文章,包括Wikitext2和Wikitext103的版本。 - 对于WikiText2,分别在wiki.train.tokens中有36718个样本,在wiki.test.tokens中有4358个样本,在wiki.valid.tokens中有3760个样本。 - 对于WikiText103,分别在wiki.train.tokens中有1801350个样本,wiki.test.tokens中的4358个样本,Wiki.valid.tokens中的3760个样本。 - - 以下是原始的WikiText数据集结构。可以将数据集文件解压缩到此目录结构中,并由MindSpore的API读取。 - - .. code-block:: - - . - └── WikiText2/WikiText103 - ├── wiki.train.tokens - ├── wiki.test.tokens - ├── wiki.valid.tokens - - **引用:** - - .. code-block:: - - @article{merity2016pointer, - title={Pointer sentinel mixture models}, - author={Merity, Stephen and Xiong, Caiming and Bradbury, James and Socher, Richard}, - journal={arXiv preprint arXiv:1609.07843}, - year={2016} - } - - -.. include:: mindspore.dataset.api_list_nlp.rst +mindspore.dataset.WikiTextDataset +================================= + +.. py:class:: mindspore.dataset.WikiTextDataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None) + + WikiText2和WikiText103数据集。 + + 生成的数据集有一列 `[text]` ,数据类型为string。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录路径。 + - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 、 ``'valid'`` 或 ``'all'`` 。默认值: ``None`` ,读取全部样本。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取全部样本。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (Union[bool, :class:`~.dataset.Shuffle`], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定。默认值: ``Shuffle.GLOBAL`` 。 + 如果 `shuffle` 为 ``False`` ,则不混洗,如果 `shuffle` 为 ``True`` ,等同于将 `shuffle` 设置为 ``mindspore.dataset.Shuffle.GLOBAL`` 。 + 通过传入枚举变量设置数据混洗的模式: + + - ``Shuffle.GLOBAL`` :混洗文件和样本。 + - ``Shuffle.FILES`` :仅混洗文件。 + + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 异常: + - **RuntimeError** - `dataset_dir` 参数所指向的文件目录不存在或缺少数据集文件。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + - **ValueError** - `num_samples` 参数值错误,小于0。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + **关于WikiText数据集:** + + WikiText数据集是一个包含1亿字的英语词典。 + 这些样本术语来自维基百科的高级和基础文章,包括Wikitext2和Wikitext103的版本。 + 对于WikiText2,分别在wiki.train.tokens中有36718个样本,在wiki.test.tokens中有4358个样本,在wiki.valid.tokens中有3760个样本。 + 对于WikiText103,分别在wiki.train.tokens中有1801350个样本,wiki.test.tokens中的4358个样本,Wiki.valid.tokens中的3760个样本。 + + 以下是原始的WikiText数据集结构。可以将数据集文件解压缩到此目录结构中,并由MindSpore的API读取。 + + .. code-block:: + + . + └── WikiText2/WikiText103 + ├── wiki.train.tokens + ├── wiki.test.tokens + ├── wiki.valid.tokens + + **引用:** + + .. code-block:: + + @article{merity2016pointer, + title={Pointer sentinel mixture models}, + author={Merity, Stephen and Xiong, Caiming and Bradbury, James and Socher, Richard}, + journal={arXiv preprint arXiv:1609.07843}, + year={2016} + } + + +.. include:: mindspore.dataset.api_list_nlp.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.YahooAnswersDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.YahooAnswersDataset.rst index 48a2966630e..3c82fa8f2c6 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.YahooAnswersDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.YahooAnswersDataset.rst @@ -1,67 +1,67 @@ -mindspore.dataset.YahooAnswersDataset -===================================== - -.. py:class:: mindspore.dataset.YahooAnswersDataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None) - - YahooAnswers数据集。 - - 生成的数据集有四列 `[class, title, content, answer]` ,数据类型均为string。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录路径。 - - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 或 ``'all'`` 。 - 取值为 ``'train'`` 时将会读取1,400,000个训练样本,取值为 ``'test'`` 时将会读取60,000个测试样本,取值为 ``'all'`` 时将会读取全部1,460,000个样本。默认值: ``None`` ,读取全部样本。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取全部样本。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (Union[bool, :class:`~.dataset.Shuffle`], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定。默认值: ``Shuffle.GLOBAL`` 。 - 如果 `shuffle` 为 ``False`` ,则不混洗,如果 `shuffle` 为 ``True`` ,等同于将 `shuffle` 设置为 ``mindspore.dataset.Shuffle.GLOBAL`` 。 - 通过传入枚举变量设置数据混洗的模式: - - - ``Shuffle.GLOBAL`` :混洗文件和样本。 - - ``Shuffle.FILES`` :仅混洗文件。 - - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 异常: - - **RuntimeError** - `dataset_dir` 参数所指向的文件目录不存在或缺少数据集文件。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - **关于YahooAnswers数据集:** - - YahooAnswers数据集包含10个类的63万个文本样本。 - train.csv中有56万个样本,test.csv中有7万个样本。 - 这10个不同的类代表社会与文化、科学与数学、健康、教育与参考、计算机与互联网、体育、商业与金融、娱乐与音乐、家庭与关系、政治与政府。 - - 以下是原始的YahooAnswers数据集结构,可以将数据集文件解压缩到此目录结构中,并由Mindspore的API读取。 - - .. code-block:: - - . - └── yahoo_answers_dataset_dir - ├── train.csv - ├── test.csv - ├── classes.txt - └── readme.txt - - **引用:** - - .. code-block:: - - @article{YahooAnswers, - title = {Yahoo! Answers Topic Classification Dataset}, - author = {Xiang Zhang}, - year = {2015}, - howpublished = {} - } - - -.. include:: mindspore.dataset.api_list_nlp.rst +mindspore.dataset.YahooAnswersDataset +===================================== + +.. py:class:: mindspore.dataset.YahooAnswersDataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None) + + YahooAnswers数据集。 + + 生成的数据集有四列 `[class, title, content, answer]` ,数据类型均为string。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录路径。 + - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 或 ``'all'`` 。 + 取值为 ``'train'`` 时将会读取1,400,000个训练样本,取值为 ``'test'`` 时将会读取60,000个测试样本,取值为 ``'all'`` 时将会读取全部1,460,000个样本。默认值: ``None`` ,读取全部样本。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取全部样本。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (Union[bool, :class:`~.dataset.Shuffle`], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定。默认值: ``Shuffle.GLOBAL`` 。 + 如果 `shuffle` 为 ``False`` ,则不混洗,如果 `shuffle` 为 ``True`` ,等同于将 `shuffle` 设置为 ``mindspore.dataset.Shuffle.GLOBAL`` 。 + 通过传入枚举变量设置数据混洗的模式: + + - ``Shuffle.GLOBAL`` :混洗文件和样本。 + - ``Shuffle.FILES`` :仅混洗文件。 + + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 异常: + - **RuntimeError** - `dataset_dir` 参数所指向的文件目录不存在或缺少数据集文件。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + **关于YahooAnswers数据集:** + + YahooAnswers数据集包含10个类的63万个文本样本。 + train.csv中有56万个样本,test.csv中有7万个样本。 + 这10个不同的类代表社会与文化、科学与数学、健康、教育与参考、计算机与互联网、体育、商业与金融、娱乐与音乐、家庭与关系、政治与政府。 + + 以下是原始的YahooAnswers数据集结构,可以将数据集文件解压缩到此目录结构中,并由Mindspore的API读取。 + + .. code-block:: + + . + └── yahoo_answers_dataset_dir + ├── train.csv + ├── test.csv + ├── classes.txt + └── readme.txt + + **引用:** + + .. code-block:: + + @article{YahooAnswers, + title = {Yahoo! Answers Topic Classification Dataset}, + author = {Xiang Zhang}, + year = {2015}, + howpublished = {} + } + + +.. include:: mindspore.dataset.api_list_nlp.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.YelpReviewDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.YelpReviewDataset.rst index f544284eba4..8742ff792c0 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.YelpReviewDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.YelpReviewDataset.rst @@ -1,96 +1,96 @@ -mindspore.dataset.YelpReviewDataset -=================================== - -.. py:class:: mindspore.dataset.YelpReviewDataset(dataset_dir, usage=None, num_samples=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, num_parallel_workers=None, cache=None) - - Yelp Review Full和Yelp Review Polarity数据集。 - - 生成的数据集有两列 `[label, text]`,两列的数据类型均为string。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录路径。 - - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 或 ``'all'`` 。默认值: ``None`` ,读取全部样本。 - 对于Polarity数据集, ``'train'`` 将读取560,000个训练样本, ``'test'`` 将读取38,000个测试样本, ``'all'`` 将读取所有598,000个样本。 - 对于Full数据集, ``'train'`` 将读取650,000个训练样本, ``'test'`` 将读取50,000个测试样本, ``'all'`` 将读取所有700,000个样本。默认值: ``None`` ,读取所有样本。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取全部样本。 - - **shuffle** (Union[bool, :class:`~.dataset.Shuffle`], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定。默认值: ``Shuffle.GLOBAL`` 。 - 如果 `shuffle` 为 ``False`` ,则不混洗,如果 `shuffle` 为 ``True`` ,等同于将 `shuffle` 设置为 ``mindspore.dataset.Shuffle.GLOBAL`` 。 - 通过传入枚举变量设置数据混洗的模式: - - - ``Shuffle.GLOBAL`` :混洗文件和样本。 - - ``Shuffle.FILES`` :仅混洗文件。 - - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 异常: - - **RuntimeError** - `dataset_dir` 参数所指向的文件目录不存在或缺少数据集文件。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - **关于YelpReview数据集:** - - Yelp Review Full数据集包括来自Yelp的评论数据。这些数据时从2015年的Yelp数据集挑战赛数据中提取的,主要用于文本分类。 - - Yelp Review Polarity数据集在Full数据集的基础上,对产品评分进行了分级,评论分数1和2视为负面评论,4和5视为正面评论。 - - Yelp Reviews Polarity和Yelp Reviews Full datasets具有相同的目录结构。 - 可以将数据集文件解压缩到以下结构,并通过MindSpore的API读取: - - .. code-block:: - - . - └── yelp_review_dir - ├── train.csv - ├── test.csv - └── readme.txt - - **引用:** - - .. code-block:: - - @article{zhangCharacterlevelConvolutionalNetworks2015, - archivePrefix = {arXiv}, - eprinttype = {arxiv}, - eprint = {1509.01626}, - primaryClass = {cs}, - title = {Character-Level {{Convolutional Networks}} for {{Text Classification}}}, - abstract = {This article offers an empirical exploration on the use of character-level convolutional networks - (ConvNets) for text classification. We constructed several large-scale datasets to show that - character-level convolutional networks could achieve state-of-the-art or competitive results. - Comparisons are offered against traditional models such as bag of words, n-grams and their TFIDF - variants, and deep learning models such as word-based ConvNets and recurrent neural networks.}, - journal = {arXiv:1509.01626 [cs]}, - author = {Zhang, Xiang and Zhao, Junbo and LeCun, Yann}, - month = sep, - year = {2015}, - } - - .. code-block:: - - @article{zhangCharacterlevelConvolutionalNetworks2015, - archivePrefix = {arXiv}, - eprinttype = {arxiv}, - eprint = {1509.01626}, - primaryClass = {cs}, - title = {Character-Level {{Convolutional Networks}} for {{Text Classification}}}, - abstract = {This article offers an empirical exploration on the use of character-level convolutional networks - (ConvNets) for text classification. We constructed several large-scale datasets to show that - character-level convolutional networks could achieve state-of-the-art or competitive results. - Comparisons are offered against traditional models such as bag of words, n-grams and their TFIDF - variants, and deep learning models such as word-based ConvNets and recurrent neural networks.}, - journal = {arXiv:1509.01626 [cs]}, - author = {Zhang, Xiang and Zhao, Junbo and LeCun, Yann}, - month = sep, - year = {2015}, - } - - -.. include:: mindspore.dataset.api_list_nlp.rst +mindspore.dataset.YelpReviewDataset +=================================== + +.. py:class:: mindspore.dataset.YelpReviewDataset(dataset_dir, usage=None, num_samples=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, num_parallel_workers=None, cache=None) + + Yelp Review Full和Yelp Review Polarity数据集。 + + 生成的数据集有两列 `[label, text]`,两列的数据类型均为string。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录路径。 + - **usage** (str, 可选) - 指定数据集的子集,可取值为 ``'train'`` 、 ``'test'`` 或 ``'all'`` 。默认值: ``None`` ,读取全部样本。 + 对于Polarity数据集, ``'train'`` 将读取560,000个训练样本, ``'test'`` 将读取38,000个测试样本, ``'all'`` 将读取所有598,000个样本。 + 对于Full数据集, ``'train'`` 将读取650,000个训练样本, ``'test'`` 将读取50,000个测试样本, ``'all'`` 将读取所有700,000个样本。默认值: ``None`` ,读取所有样本。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取全部样本。 + - **shuffle** (Union[bool, :class:`~.dataset.Shuffle`], 可选) - 每个epoch中数据混洗的模式,支持传入bool类型与枚举类型进行指定。默认值: ``Shuffle.GLOBAL`` 。 + 如果 `shuffle` 为 ``False`` ,则不混洗,如果 `shuffle` 为 ``True`` ,等同于将 `shuffle` 设置为 ``mindspore.dataset.Shuffle.GLOBAL`` 。 + 通过传入枚举变量设置数据混洗的模式: + + - ``Shuffle.GLOBAL`` :混洗文件和样本。 + - ``Shuffle.FILES`` :仅混洗文件。 + + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 异常: + - **RuntimeError** - `dataset_dir` 参数所指向的文件目录不存在或缺少数据集文件。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + **关于YelpReview数据集:** + + Yelp Review Full数据集包括来自Yelp的评论数据。这些数据时从2015年的Yelp数据集挑战赛数据中提取的,主要用于文本分类。 + + Yelp Review Polarity数据集在Full数据集的基础上,对产品评分进行了分级,评论分数1和2视为负面评论,4和5视为正面评论。 + + Yelp Reviews Polarity和Yelp Reviews Full datasets具有相同的目录结构。 + 可以将数据集文件解压缩到以下结构,并通过MindSpore的API读取: + + .. code-block:: + + . + └── yelp_review_dir + ├── train.csv + ├── test.csv + └── readme.txt + + **引用:** + + .. code-block:: + + @article{zhangCharacterlevelConvolutionalNetworks2015, + archivePrefix = {arXiv}, + eprinttype = {arxiv}, + eprint = {1509.01626}, + primaryClass = {cs}, + title = {Character-Level {{Convolutional Networks}} for {{Text Classification}}}, + abstract = {This article offers an empirical exploration on the use of character-level convolutional networks + (ConvNets) for text classification. We constructed several large-scale datasets to show that + character-level convolutional networks could achieve state-of-the-art or competitive results. + Comparisons are offered against traditional models such as bag of words, n-grams and their TFIDF + variants, and deep learning models such as word-based ConvNets and recurrent neural networks.}, + journal = {arXiv:1509.01626 [cs]}, + author = {Zhang, Xiang and Zhao, Junbo and LeCun, Yann}, + month = sep, + year = {2015}, + } + + .. code-block:: + + @article{zhangCharacterlevelConvolutionalNetworks2015, + archivePrefix = {arXiv}, + eprinttype = {arxiv}, + eprint = {1509.01626}, + primaryClass = {cs}, + title = {Character-Level {{Convolutional Networks}} for {{Text Classification}}}, + abstract = {This article offers an empirical exploration on the use of character-level convolutional networks + (ConvNets) for text classification. We constructed several large-scale datasets to show that + character-level convolutional networks could achieve state-of-the-art or competitive results. + Comparisons are offered against traditional models such as bag of words, n-grams and their TFIDF + variants, and deep learning models such as word-based ConvNets and recurrent neural networks.}, + journal = {arXiv:1509.01626 [cs]}, + author = {Zhang, Xiang and Zhao, Junbo and LeCun, Yann}, + month = sep, + year = {2015}, + } + + +.. include:: mindspore.dataset.api_list_nlp.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.YesNoDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.YesNoDataset.rst index eb69a6a6d52..24be41e5041 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.YesNoDataset.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.YesNoDataset.rst @@ -1,63 +1,63 @@ -mindspore.dataset.YesNoDataset -============================== - -.. py:class:: mindspore.dataset.YesNoDataset(dataset_dir, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None) - - YesNo数据集。 - - 生成的数据集有三列 `[waveform, sample_rate, labels]` 。 - 列 `waveform` 的数据类型为float32。列 `sample_rate` 的数据类型为int32。列 `labels` 的数据类型为int32。 - - 参数: - - **dataset_dir** (str) - 包含数据集文件的根目录路径。 - - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取全部样本。 - - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 - - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 - - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 - - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 - - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 - - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 - - 异常: - - **RuntimeError** - `dataset_dir` 路径下不包含任何数据文件。 - - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 - - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 - - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 - - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 - - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 - - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 - - 教程样例: - - `使用数据Pipeline加载 & 处理数据集 - `_ - - .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 - - .. include:: mindspore.dataset.sampler.rst - - **关于YesNo数据集:** - - Yesno是一个音频数据集,由60个录音组成,由一个人用希伯来语说是或不是;每个录音都有8个字长。 - - 以下是原始的YesNo数据集结构。可以将数据集文件解压缩到此目录结构中,并由MindSpore的API读取。 - - .. code-block:: - - . - └── yes_no_dataset_dir - ├── 1_1_0_0_1_1_0_0.wav - ├── 1_0_0_0_1_1_0_0.wav - ├── 1_1_0_0_1_1_0_0.wav - └──.... - - **引用:** - - .. code-block:: - - @NetworkResource{Kaldi_audio_project, - author = {anonymous}, - url = "http://wwww.openslr.org/1/" - } - - -.. include:: mindspore.dataset.api_list_audio.rst +mindspore.dataset.YesNoDataset +============================== + +.. py:class:: mindspore.dataset.YesNoDataset(dataset_dir, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None) + + YesNo数据集。 + + 生成的数据集有三列 `[waveform, sample_rate, labels]` 。 + 列 `waveform` 的数据类型为float32。列 `sample_rate` 的数据类型为int32。列 `labels` 的数据类型为int32。 + + 参数: + - **dataset_dir** (str) - 包含数据集文件的根目录路径。 + - **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值: ``None`` ,读取全部样本。 + - **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值: ``None`` ,使用全局默认线程数(8),也可以通过 :func:`mindspore.dataset.config.set_num_parallel_workers` 配置全局线程数。 + - **shuffle** (bool, 可选) - 是否混洗数据集。默认值: ``None`` 。下表中会展示不同参数配置的预期行为。 + - **sampler** (Sampler, 可选) - 指定从数据集中选取样本的采样器。默认值: ``None`` 。下表中会展示不同配置的预期行为。 + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值: ``None`` 。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值: ``None`` 。只有当指定了 `num_shards` 时才能指定此参数。 + - **cache** (:class:`~.dataset.DatasetCache`, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 `_ 。默认值: ``None`` ,不使用缓存。 + + 异常: + - **RuntimeError** - `dataset_dir` 路径下不包含任何数据文件。 + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + - **RuntimeError** - 同时指定了 `sampler` 和 `shuffle` 参数。 + - **RuntimeError** - 同时指定了 `sampler` 和 `num_shards` 参数或同时指定了 `sampler` 和 `shard_id` 参数。 + - **RuntimeError** - 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。 + - **RuntimeError** - 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。 + - **ValueError** - 如果 `shard_id` 取值不在[0, `num_shards` )范围。 + + 教程样例: + - `使用数据Pipeline加载 & 处理数据集 + `_ + + .. note:: 入参 `num_samples` 、 `shuffle` 、 `num_shards` 、 `shard_id` 可用于控制数据集所使用的采样器,其与入参 `sampler` 搭配使用的效果如下。 + + .. include:: mindspore.dataset.sampler.rst + + **关于YesNo数据集:** + + Yesno是一个音频数据集,由60个录音组成,由一个人用希伯来语说是或不是;每个录音都有8个字长。 + + 以下是原始的YesNo数据集结构。可以将数据集文件解压缩到此目录结构中,并由MindSpore的API读取。 + + .. code-block:: + + . + └── yes_no_dataset_dir + ├── 1_1_0_0_1_1_0_0.wav + ├── 1_0_0_0_1_1_0_0.wav + ├── 1_1_0_0_1_1_0_0.wav + └──.... + + **引用:** + + .. code-block:: + + @NetworkResource{Kaldi_audio_project, + author = {anonymous}, + url = "http://wwww.openslr.org/1/" + } + + +.. include:: mindspore.dataset.api_list_audio.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.config.ErrorSamplesMode.rst b/docs/api/api_python/dataset/mindspore.dataset.config.ErrorSamplesMode.rst index 5f54cc4272c..a6c96ce6752 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.config.ErrorSamplesMode.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.config.ErrorSamplesMode.rst @@ -1,10 +1,10 @@ -mindspore.dataset.config.ErrorSamplesMode -=========================================== - -.. py:class:: mindspore.dataset.config.ErrorSamplesMode - - 指定数据管道中处理错误样本的策略。 - - - **ErrorSamplesMode.RETURN** - 表示处理过程中遇到错误样本时将报错并抛出异常。 - - **ErrorSamplesMode.REPLACE** - 表示处理过程中遇到错误样本时将使用正确的样本替换处理。 - - **ErrorSamplesMode.SKIP** - 表示处理过程中遇到错误样本时将直接跳过此样本。 +mindspore.dataset.config.ErrorSamplesMode +=========================================== + +.. py:class:: mindspore.dataset.config.ErrorSamplesMode + + 指定数据管道中处理错误样本的策略。 + + - **ErrorSamplesMode.RETURN** - 表示处理过程中遇到错误样本时将报错并抛出异常。 + - **ErrorSamplesMode.REPLACE** - 表示处理过程中遇到错误样本时将使用正确的样本替换处理。 + - **ErrorSamplesMode.SKIP** - 表示处理过程中遇到错误样本时将直接跳过此样本。 diff --git a/docs/api/api_python/dataset/mindspore.dataset.sampler.rst b/docs/api/api_python/dataset/mindspore.dataset.sampler.rst index e3480642d26..ae3c9f95d7a 100644 --- a/docs/api/api_python/dataset/mindspore.dataset.sampler.rst +++ b/docs/api/api_python/dataset/mindspore.dataset.sampler.rst @@ -1,49 +1,49 @@ -.. list-table:: 参数 `sampler` 和 `num_samples` , `shuffle` , `num_shards` , `shard_id` 的不同组合得到的采样器 - :widths: 150 150 50 50 350 - :header-rows: 1 - - * - 参数 `sampler` - - 参数 `num_shards` / `shard_id` - - 参数 `shuffle` - - 参数 `num_samples` - - **使用的采样器** - * - `mindspore.dataset.Sampler` 类型 - - *None* - - *None* - - *None* - - **sampler** - * - `numpy.ndarray,list,tuple,int` 类型 - - / - - / - - *num_samples* - - *SubsetSampler(indices =* **sampler** *, num_samples =* **num_samples** *)* - * - `iterable` 类型 - - / - - / - - *num_samples* - - *IterSampler(sampler =* **sampler** *, num_samples =* **num_samples** *)* - * - *None* - - *num_shards* / *shard_id* - - *None* / *True* - - *num_samples* - - *DistributedSampler(num_shards =* **num_shards** *, shard_id =* **shard_id** *, shuffle =* **True** *, num_samples =* **num_samples** *)* - * - *None* - - *num_shards* / *shard_id* - - *False* - - *num_samples* - - *DistributedSampler(num_shards =* **num_shards** *, shard_id =* **shard_id** *, shuffle =* **False** *, num_samples =* **num_samples** *)* - * - *None* - - *None* - - *None* / *True* - - *None* - - *RandomSampler(num_samples =* **num_samples** *)* - * - *None* - - *None* - - *None* / *True* - - *num_samples* - - *RandomSampler(replacement =* **True** *, num_samples =* **num_samples** *)* - * - *None* - - *None* - - *False* - - *num_samples* - - *SequentialSampler(num_samples =* **num_samples** *)* +.. list-table:: 参数 `sampler` 和 `num_samples` , `shuffle` , `num_shards` , `shard_id` 的不同组合得到的采样器 + :widths: 150 150 50 50 350 + :header-rows: 1 + + * - 参数 `sampler` + - 参数 `num_shards` / `shard_id` + - 参数 `shuffle` + - 参数 `num_samples` + - **使用的采样器** + * - `mindspore.dataset.Sampler` 类型 + - *None* + - *None* + - *None* + - **sampler** + * - `numpy.ndarray,list,tuple,int` 类型 + - / + - / + - *num_samples* + - *SubsetSampler(indices =* **sampler** *, num_samples =* **num_samples** *)* + * - `iterable` 类型 + - / + - / + - *num_samples* + - *IterSampler(sampler =* **sampler** *, num_samples =* **num_samples** *)* + * - *None* + - *num_shards* / *shard_id* + - *None* / *True* + - *num_samples* + - *DistributedSampler(num_shards =* **num_shards** *, shard_id =* **shard_id** *, shuffle =* **True** *, num_samples =* **num_samples** *)* + * - *None* + - *num_shards* / *shard_id* + - *False* + - *num_samples* + - *DistributedSampler(num_shards =* **num_shards** *, shard_id =* **shard_id** *, shuffle =* **False** *, num_samples =* **num_samples** *)* + * - *None* + - *None* + - *None* / *True* + - *None* + - *RandomSampler(num_samples =* **num_samples** *)* + * - *None* + - *None* + - *None* / *True* + - *num_samples* + - *RandomSampler(replacement =* **True** *, num_samples =* **num_samples** *)* + * - *None* + - *None* + - *False* + - *num_samples* + - *SequentialSampler(num_samples =* **num_samples** *)* diff --git a/docs/api/api_python/dataset_audio/mindspore.dataset.audio.DBToAmplitude.rst b/docs/api/api_python/dataset_audio/mindspore.dataset.audio.DBToAmplitude.rst index a4927207162..581aba7318e 100644 --- a/docs/api/api_python/dataset_audio/mindspore.dataset.audio.DBToAmplitude.rst +++ b/docs/api/api_python/dataset_audio/mindspore.dataset.audio.DBToAmplitude.rst @@ -1,18 +1,18 @@ -mindspore.dataset.audio.DBToAmplitude -===================================== - -.. py:class:: mindspore.dataset.audio.DBToAmplitude(ref, power) - - 将音频波形从分贝转换为功率或振幅。 - - 参数: - - **ref** (float) - 输出波形的缩放系数。 - - **power** (float) - 如果 `power` 等于 ``1`` ,则将分贝值转为功率;如果为 ``0.5`` ,则将分贝值转为振幅。 - - 异常: - - **TypeError** - 如果 `ref` 不是float类型。 - - **TypeError** - 如果 `power` 不是float类型。 - - 教程样例: - - `音频变换样例库 - `_ +mindspore.dataset.audio.DBToAmplitude +===================================== + +.. py:class:: mindspore.dataset.audio.DBToAmplitude(ref, power) + + 将音频波形从分贝转换为功率或振幅。 + + 参数: + - **ref** (float) - 输出波形的缩放系数。 + - **power** (float) - 如果 `power` 等于 ``1`` ,则将分贝值转为功率;如果为 ``0.5`` ,则将分贝值转为振幅。 + + 异常: + - **TypeError** - 如果 `ref` 不是float类型。 + - **TypeError** - 如果 `power` 不是float类型。 + + 教程样例: + - `音频变换样例库 + `_ diff --git a/docs/api/api_python/dataset_audio/mindspore.dataset.audio.DCShift.rst b/docs/api/api_python/dataset_audio/mindspore.dataset.audio.DCShift.rst index f7ae4dd9636..303fb1d2f7f 100644 --- a/docs/api/api_python/dataset_audio/mindspore.dataset.audio.DCShift.rst +++ b/docs/api/api_python/dataset_audio/mindspore.dataset.audio.DCShift.rst @@ -1,19 +1,19 @@ -mindspore.dataset.audio.DCShift -=============================== - -.. py:class:: mindspore.dataset.audio.DCShift(shift, limiter_gain=None) - - 对输入音频波形施加直流移位。可以从音频中删除直流偏移(DC Offset)。 - - 参数: - - **shift** (float) - 音频的移位量,值必须在[-2.0, 2.0]范围内。 - - **limiter_gain** (float, 可选) - 防止截断,仅在波峰生效。值应远小于1,如0.05或0.02。默认值: ``None`` ,将被设置为 `shift` 。 - - 异常: - - **TypeError** - 如果 `shift` 不是float类型。 - - **ValueError** - 如果 `shift` 不在[-2.0, 2.0]范围内。 - - **TypeError** - 如果 `limiter_gain` 不是float类型。 - - 教程样例: - - `音频变换样例库 - `_ +mindspore.dataset.audio.DCShift +=============================== + +.. py:class:: mindspore.dataset.audio.DCShift(shift, limiter_gain=None) + + 对输入音频波形施加直流移位。可以从音频中删除直流偏移(DC Offset)。 + + 参数: + - **shift** (float) - 音频的移位量,值必须在[-2.0, 2.0]范围内。 + - **limiter_gain** (float, 可选) - 防止截断,仅在波峰生效。值应远小于1,如0.05或0.02。默认值: ``None`` ,将被设置为 `shift` 。 + + 异常: + - **TypeError** - 如果 `shift` 不是float类型。 + - **ValueError** - 如果 `shift` 不在[-2.0, 2.0]范围内。 + - **TypeError** - 如果 `limiter_gain` 不是float类型。 + + 教程样例: + - `音频变换样例库 + `_ diff --git a/docs/api/api_python/dataset_audio/mindspore.dataset.audio.DetectPitchFrequency.rst b/docs/api/api_python/dataset_audio/mindspore.dataset.audio.DetectPitchFrequency.rst index 78b46fb1e27..c33421d280b 100644 --- a/docs/api/api_python/dataset_audio/mindspore.dataset.audio.DetectPitchFrequency.rst +++ b/docs/api/api_python/dataset_audio/mindspore.dataset.audio.DetectPitchFrequency.rst @@ -1,30 +1,30 @@ -mindspore.dataset.audio.DetectPitchFrequency -============================================ - -.. py:class:: mindspore.dataset.audio.DetectPitchFrequency(sample_rate, frame_time=0.01, win_length=30, freq_low=85, freq_high=3400) - - 检测音调频率。 - 基于归一化互相关函数和中位平滑来实现。 - - 参数: - - **sample_rate** (int) - 波形的采样频率,如44100 (单位:Hz),值不能为0。 - - **frame_time** (float, 可选) - 帧的持续时间,值必须大于零。默认值: ``0.01`` 。 - - **win_length** (int, 可选) - 中位平滑的窗口长度(以帧数为单位),该值必须大于零。默认值: ``30`` 。 - - **freq_low** (int, 可选) - 可检测的最低频率(Hz),该值必须大于零。默认值: ``85`` 。 - - **freq_high** (int, 可选) - 可检测的最高频率(Hz),该值必须大于零。默认值: ``3400`` 。 - - 异常: - - **TypeError** - 如果 `sample_rate` 不是int类型。 - - **ValueError** - 如果 `sample_rate` 为0。 - - **TypeError** - 如果 `frame_time` 不是float类型。 - - **ValueError** - 如果 `frame_time` 不为正数。 - - **TypeError** - 如果 `win_length` 不是int类型。 - - **ValueError** - 如果 `win_length` 不为正数。 - - **TypeError** - 如果 `freq_low` 不是int类型。 - - **ValueError** - 如果 `freq_low` 不为正数。 - - **TypeError** - 如果 `freq_high` 不是int类型。 - - **ValueError** - 如果 `freq_high` 不为正数。 - - 教程样例: - - `音频变换样例库 - `_ +mindspore.dataset.audio.DetectPitchFrequency +============================================ + +.. py:class:: mindspore.dataset.audio.DetectPitchFrequency(sample_rate, frame_time=0.01, win_length=30, freq_low=85, freq_high=3400) + + 检测音调频率。 + 基于归一化互相关函数和中位平滑来实现。 + + 参数: + - **sample_rate** (int) - 波形的采样频率,如44100 (单位:Hz),值不能为0。 + - **frame_time** (float, 可选) - 帧的持续时间,值必须大于零。默认值: ``0.01`` 。 + - **win_length** (int, 可选) - 中位平滑的窗口长度(以帧数为单位),该值必须大于零。默认值: ``30`` 。 + - **freq_low** (int, 可选) - 可检测的最低频率(Hz),该值必须大于零。默认值: ``85`` 。 + - **freq_high** (int, 可选) - 可检测的最高频率(Hz),该值必须大于零。默认值: ``3400`` 。 + + 异常: + - **TypeError** - 如果 `sample_rate` 不是int类型。 + - **ValueError** - 如果 `sample_rate` 为0。 + - **TypeError** - 如果 `frame_time` 不是float类型。 + - **ValueError** - 如果 `frame_time` 不为正数。 + - **TypeError** - 如果 `win_length` 不是int类型。 + - **ValueError** - 如果 `win_length` 不为正数。 + - **TypeError** - 如果 `freq_low` 不是int类型。 + - **ValueError** - 如果 `freq_low` 不为正数。 + - **TypeError** - 如果 `freq_high` 不是int类型。 + - **ValueError** - 如果 `freq_high` 不为正数。 + + 教程样例: + - `音频变换样例库 + `_ diff --git a/docs/api/api_python/dataset_audio/mindspore.dataset.audio.EqualizerBiquad.rst b/docs/api/api_python/dataset_audio/mindspore.dataset.audio.EqualizerBiquad.rst index b51ad559863..93cd0e40791 100644 --- a/docs/api/api_python/dataset_audio/mindspore.dataset.audio.EqualizerBiquad.rst +++ b/docs/api/api_python/dataset_audio/mindspore.dataset.audio.EqualizerBiquad.rst @@ -1,26 +1,26 @@ -mindspore.dataset.audio.EqualizerBiquad -======================================= - -.. py:class:: mindspore.dataset.audio.EqualizerBiquad(sample_rate, center_freq, gain, Q=0.707) - - 给音频波形施加双二次均衡器滤波器。 - - 接口实现方式类似于 `SoX库 `_ 。 - - 参数: - - **sample_rate** (int) - 波形的采样频率,如 ``44100`` (单位:Hz),值不能为0。 - - **center_freq** (float) - 中心频率(单位:Hz)。 - - **gain** (float) - 期望提升(或衰减)的音频增益(单位:dB)。 - - **Q** (float, 可选) - `品质因子 `_ ,能够反映带宽与采样频率和中心频率的关系,取值范围为(0, 1]。默认值: ``0.707`` 。 - - 异常: - - **TypeError** - 当 `sample_rate` 的类型不为int。 - - **ValueError** - 当 `sample_rate` 的数值为0。 - - **TypeError** - 当 `center_freq` 的类型不为float。 - - **TypeError** - 当 `gain` 的类型不为float。 - - **TypeError** - 当 `Q` 的类型不为float。 - - **ValueError** - 当 `Q` 取值不在(0, 1]范围内。 - - 教程样例: - - `音频变换样例库 - `_ +mindspore.dataset.audio.EqualizerBiquad +======================================= + +.. py:class:: mindspore.dataset.audio.EqualizerBiquad(sample_rate, center_freq, gain, Q=0.707) + + 给音频波形施加双二次均衡器滤波器。 + + 接口实现方式类似于 `SoX库 `_ 。 + + 参数: + - **sample_rate** (int) - 波形的采样频率,如 ``44100`` (单位:Hz),值不能为0。 + - **center_freq** (float) - 中心频率(单位:Hz)。 + - **gain** (float) - 期望提升(或衰减)的音频增益(单位:dB)。 + - **Q** (float, 可选) - `品质因子 `_ ,能够反映带宽与采样频率和中心频率的关系,取值范围为(0, 1]。默认值: ``0.707`` 。 + + 异常: + - **TypeError** - 当 `sample_rate` 的类型不为int。 + - **ValueError** - 当 `sample_rate` 的数值为0。 + - **TypeError** - 当 `center_freq` 的类型不为float。 + - **TypeError** - 当 `gain` 的类型不为float。 + - **TypeError** - 当 `Q` 的类型不为float。 + - **ValueError** - 当 `Q` 取值不在(0, 1]范围内。 + + 教程样例: + - `音频变换样例库 + `_ diff --git a/docs/api/api_python/dataset_audio/mindspore.dataset.audio.Gain.rst b/docs/api/api_python/dataset_audio/mindspore.dataset.audio.Gain.rst index 8d858b51fff..ff8522970c0 100644 --- a/docs/api/api_python/dataset_audio/mindspore.dataset.audio.Gain.rst +++ b/docs/api/api_python/dataset_audio/mindspore.dataset.audio.Gain.rst @@ -1,16 +1,16 @@ -mindspore.dataset.audio.Gain -============================ - -.. py:class:: mindspore.dataset.audio.Gain(gain_db=1.0) - - 放大或衰减整个音频波形。 - - 参数: - - **gain_db** (float) - 增益调整,单位为分贝(dB)。默认值: ``1.0`` 。 - - 异常: - - **TypeError** - 当 `gain_db` 的类型不为float。 - - 教程样例: - - `音频变换样例库 - `_ +mindspore.dataset.audio.Gain +============================ + +.. py:class:: mindspore.dataset.audio.Gain(gain_db=1.0) + + 放大或衰减整个音频波形。 + + 参数: + - **gain_db** (float) - 增益调整,单位为分贝(dB)。默认值: ``1.0`` 。 + + 异常: + - **TypeError** - 当 `gain_db` 的类型不为float。 + + 教程样例: + - `音频变换样例库 + `_ diff --git a/docs/api/api_python/dataset_audio/mindspore.dataset.audio.GriffinLim.rst b/docs/api/api_python/dataset_audio/mindspore.dataset.audio.GriffinLim.rst index 1caa092baa9..e60de97a50e 100644 --- a/docs/api/api_python/dataset_audio/mindspore.dataset.audio.GriffinLim.rst +++ b/docs/api/api_python/dataset_audio/mindspore.dataset.audio.GriffinLim.rst @@ -1,46 +1,46 @@ -mindspore.dataset.audio.GriffinLim -================================== - -.. py:class:: mindspore.dataset.audio.GriffinLim(n_fft=400, n_iter=32, win_length=None, hop_length=None, window_type=WindowType.HANN, power=2.0, momentum=0.99, length=None, rand_init=True) - - 使用Griffin-Lim算法从线性幅度频谱图中计算信号波形。 - - 有关Griffin-Lim算法更多的描述,详见论文 `A fast Griffin-Lim algorithm `_ - 与 `Signal estimation from modified short-time Fourier transform `_ 。 - - 参数: - - **n_fft** (int, 可选) - FFT的长度。默认值: ``400`` 。 - - **n_iter** (int, 可选) - 相位恢复的迭代次数。默认值: ``32`` 。 - - **win_length** (int, 可选) - GriffinLim的窗口大小。默认值: ``None`` ,将设置为 `n_fft` 的值。 - - **hop_length** (int, 可选) - STFT窗口之间的跳数长度。默认值: ``None`` ,将设置为 `win_length//2` 。 - - **window_type** (:class:`~.audio.WindowType`, 可选) - GriffinLim的窗口类型,可以是 ``WindowType.BARTLETT`` , - ``WindowType.BLACKMAN`` , ``WindowType.HAMMING`` , ``WindowType.HANN`` 或 ``WindowType.KAISER`` 。 - 默认值: ``WindowType.HANN`` ,目前macOS上不支持kaiser窗口。 - - **power** (float, 可选) - 幅度谱图的指数。默认值: ``2.0`` 。 - - **momentum** (float, 可选) - 快速Griffin-Lim的动量。默认值: ``0.99`` 。 - - **length** (int, 可选) - 预期输出波形的长度。默认值: ``None`` ,将设置为stft矩阵的最后一个维度的值。 - - **rand_init** (bool, 可选) - 随机相位初始化或全零相位初始化标志。默认值: ``True`` 。 - - 异常: - - **TypeError** - 如果 `n_fft` 的类型不为int。 - - **ValueError** - 如果 `n_ftt` 不为正数。 - - **TypeError** - 如果 `n_iter` 的类型不为int。 - - **ValueError** - 如果 `n_iter` 不为正数。 - - **TypeError** - 如果 `win_length` 的类型不为int。 - - **ValueError** - 如果 `win_length` 为负数。 - - **TypeError** - 如果 `hop_length` 的类型不为int。 - - **ValueError** - 如果 `hop_length` 为负数。 - - **TypeError** - 如果 `window_type` 的类型不为 :class:`mindspore.dataset.audio.WindowType` 。 - - **TypeError** - 如果 `power` 的类型不为float。 - - **ValueError** - 如果 `power` 不为正数。 - - **TypeError** - 如果 `momentum` 的类型不为float。 - - **ValueError** - 如果 `momentum` 为负数。 - - **TypeError** - 如果 `length` 的类型不为int。 - - **ValueError** - 如果 `length` 为负数。 - - **TypeError** - 如果 `rand_init` 的类型不为bool。 - - **RuntimeError** - 当 `n_fft` 指定的FFT长度不小于 `length` 指定的输出波形长度。 - - **RuntimeError** - 当 `win_length` 指定的窗口长度不小于 `n_fft` 指定的FFT长度。 - - 教程样例: - - `音频变换样例库 - `_ +mindspore.dataset.audio.GriffinLim +================================== + +.. py:class:: mindspore.dataset.audio.GriffinLim(n_fft=400, n_iter=32, win_length=None, hop_length=None, window_type=WindowType.HANN, power=2.0, momentum=0.99, length=None, rand_init=True) + + 使用Griffin-Lim算法从线性幅度频谱图中计算信号波形。 + + 有关Griffin-Lim算法更多的描述,详见论文 `A fast Griffin-Lim algorithm `_ + 与 `Signal estimation from modified short-time Fourier transform `_ 。 + + 参数: + - **n_fft** (int, 可选) - FFT的长度。默认值: ``400`` 。 + - **n_iter** (int, 可选) - 相位恢复的迭代次数。默认值: ``32`` 。 + - **win_length** (int, 可选) - GriffinLim的窗口大小。默认值: ``None`` ,将设置为 `n_fft` 的值。 + - **hop_length** (int, 可选) - STFT窗口之间的跳数长度。默认值: ``None`` ,将设置为 `win_length//2` 。 + - **window_type** (:class:`~.audio.WindowType`, 可选) - GriffinLim的窗口类型,可以是 ``WindowType.BARTLETT`` , + ``WindowType.BLACKMAN`` , ``WindowType.HAMMING`` , ``WindowType.HANN`` 或 ``WindowType.KAISER`` 。 + 默认值: ``WindowType.HANN`` ,目前macOS上不支持kaiser窗口。 + - **power** (float, 可选) - 幅度谱图的指数。默认值: ``2.0`` 。 + - **momentum** (float, 可选) - 快速Griffin-Lim的动量。默认值: ``0.99`` 。 + - **length** (int, 可选) - 预期输出波形的长度。默认值: ``None`` ,将设置为stft矩阵的最后一个维度的值。 + - **rand_init** (bool, 可选) - 随机相位初始化或全零相位初始化标志。默认值: ``True`` 。 + + 异常: + - **TypeError** - 如果 `n_fft` 的类型不为int。 + - **ValueError** - 如果 `n_ftt` 不为正数。 + - **TypeError** - 如果 `n_iter` 的类型不为int。 + - **ValueError** - 如果 `n_iter` 不为正数。 + - **TypeError** - 如果 `win_length` 的类型不为int。 + - **ValueError** - 如果 `win_length` 为负数。 + - **TypeError** - 如果 `hop_length` 的类型不为int。 + - **ValueError** - 如果 `hop_length` 为负数。 + - **TypeError** - 如果 `window_type` 的类型不为 :class:`mindspore.dataset.audio.WindowType` 。 + - **TypeError** - 如果 `power` 的类型不为float。 + - **ValueError** - 如果 `power` 不为正数。 + - **TypeError** - 如果 `momentum` 的类型不为float。 + - **ValueError** - 如果 `momentum` 为负数。 + - **TypeError** - 如果 `length` 的类型不为int。 + - **ValueError** - 如果 `length` 为负数。 + - **TypeError** - 如果 `rand_init` 的类型不为bool。 + - **RuntimeError** - 当 `n_fft` 指定的FFT长度不小于 `length` 指定的输出波形长度。 + - **RuntimeError** - 当 `win_length` 指定的窗口长度不小于 `n_fft` 指定的FFT长度。 + + 教程样例: + - `音频变换样例库 + `_ diff --git a/docs/api/api_python/dataset_audio/mindspore.dataset.audio.InverseMelScale.rst b/docs/api/api_python/dataset_audio/mindspore.dataset.audio.InverseMelScale.rst index 97ea853e0bd..09c58010e20 100644 --- a/docs/api/api_python/dataset_audio/mindspore.dataset.audio.InverseMelScale.rst +++ b/docs/api/api_python/dataset_audio/mindspore.dataset.audio.InverseMelScale.rst @@ -1,44 +1,44 @@ -mindspore.dataset.audio.InverseMelScale -======================================= - -.. py:class:: mindspore.dataset.audio.InverseMelScale(n_stft, n_mels=128, sample_rate=16000, f_min=0.0, f_max=None, max_iter=100000, tolerance_loss=1e-5, tolerance_change=1e-8, sgdargs=None, norm=NormType.NONE, mel_type=MelType.HTK) - - 使用转换矩阵从梅尔频率STFT求解普通频率的STFT。 - - 参数: - - **n_stft** (int) - STFT中的频段数。 - - **n_mels** (int, 可选) - mel滤波器的数量。默认值: ``128`` 。 - - **sample_rate** (int, 可选) - 音频信号采样频率。默认值: ``16000`` 。 - - **f_min** (float, 可选) - 最小频率。默认值: ``0.0`` 。 - - **f_max** (float, 可选) - 最大频率。默认值: ``None`` ,将设置为 `sample_rate//2` 。 - - **max_iter** (int, 可选) - 最大优化迭代次数。默认值: ``100000`` 。 - - **tolerance_loss** (float, 可选) - 当达到损失值时停止优化。默认值: ``1e-5`` 。 - - **tolerance_change** (float, 可选) - 指定损失差异,当达到损失差异时停止优化。默认值: ``1e-8`` 。 - - **sgdargs** (dict, 可选) - SGD优化器的参数。默认值: ``None`` ,将设置为{'sgd_lr': 0.1, 'sgd_momentum': 0.9}。 - - **norm** (:class:`~.audio.NormType`, 可选) - 标准化方法,可以是 ``NormType.SLANEY`` 或 ``NormType.NONE`` 。默认值: ``NormType.NONE`` ,不使用标准化。 - - **mel_type** (:class:`~.audio.MelType`, 可选) - 要使用的Mel比例,可以是 ``MelType.SLAN`` 或 ``MelType.HTK`` 。默认值: ``MelType.HTK`` 。 - - 异常: - - **TypeError** - 如果 `n_fft` 的类型不为int。 - - **ValueError** - 如果 `n_ftt` 不为正数。 - - **TypeError** - 如果 `n_mels` 的类型不为int。 - - **ValueError** - 如果 `n_mels` 不为正数。 - - **TypeError** - 如果 `sample_rate` 的类型不为int。 - - **ValueError** - 如果 `sample_rate` 不为正数。 - - **TypeError** - 如果 `f_min` 的类型不为float。 - - **ValueError** - 如果 `f_min` 大于等于 `f_max` 。 - - **TypeError** - 如果 `f_max` 的类型不为float。 - - **ValueError** - 如果 `f_max` 为负数。 - - **TypeError** - 如果 `max_iter` 的类型不为int。 - - **ValueError** - 如果 `max_iter` 为负数。 - - **TypeError** - 如果 `tolerance_loss` 的类型不为float。 - - **ValueError** - 如果 `tolerance_loss` 为负数。 - - **TypeError** - 如果 `tolerance_change` 的类型不为float。 - - **ValueError** - 如果 `tolerance_change` 为负数。 - - **TypeError** - 如果 `sgdargs` 的类型不为dict。 - - **TypeError** - 如果 `norm` 的类型不为 :class:`mindspore.dataset.audio.NormType` 。 - - **TypeError** - 如果 `mel_type` 的类型不为 :class:`mindspore.dataset.audio.MelType` 。 - - 教程样例: - - `音频变换样例库 - `_ +mindspore.dataset.audio.InverseMelScale +======================================= + +.. py:class:: mindspore.dataset.audio.InverseMelScale(n_stft, n_mels=128, sample_rate=16000, f_min=0.0, f_max=None, max_iter=100000, tolerance_loss=1e-5, tolerance_change=1e-8, sgdargs=None, norm=NormType.NONE, mel_type=MelType.HTK) + + 使用转换矩阵从梅尔频率STFT求解普通频率的STFT。 + + 参数: + - **n_stft** (int) - STFT中的频段数。 + - **n_mels** (int, 可选) - mel滤波器的数量。默认值: ``128`` 。 + - **sample_rate** (int, 可选) - 音频信号采样频率。默认值: ``16000`` 。 + - **f_min** (float, 可选) - 最小频率。默认值: ``0.0`` 。 + - **f_max** (float, 可选) - 最大频率。默认值: ``None`` ,将设置为 `sample_rate//2` 。 + - **max_iter** (int, 可选) - 最大优化迭代次数。默认值: ``100000`` 。 + - **tolerance_loss** (float, 可选) - 当达到损失值时停止优化。默认值: ``1e-5`` 。 + - **tolerance_change** (float, 可选) - 指定损失差异,当达到损失差异时停止优化。默认值: ``1e-8`` 。 + - **sgdargs** (dict, 可选) - SGD优化器的参数。默认值: ``None`` ,将设置为{'sgd_lr': 0.1, 'sgd_momentum': 0.9}。 + - **norm** (:class:`~.audio.NormType`, 可选) - 标准化方法,可以是 ``NormType.SLANEY`` 或 ``NormType.NONE`` 。默认值: ``NormType.NONE`` ,不使用标准化。 + - **mel_type** (:class:`~.audio.MelType`, 可选) - 要使用的Mel比例,可以是 ``MelType.SLAN`` 或 ``MelType.HTK`` 。默认值: ``MelType.HTK`` 。 + + 异常: + - **TypeError** - 如果 `n_fft` 的类型不为int。 + - **ValueError** - 如果 `n_ftt` 不为正数。 + - **TypeError** - 如果 `n_mels` 的类型不为int。 + - **ValueError** - 如果 `n_mels` 不为正数。 + - **TypeError** - 如果 `sample_rate` 的类型不为int。 + - **ValueError** - 如果 `sample_rate` 不为正数。 + - **TypeError** - 如果 `f_min` 的类型不为float。 + - **ValueError** - 如果 `f_min` 大于等于 `f_max` 。 + - **TypeError** - 如果 `f_max` 的类型不为float。 + - **ValueError** - 如果 `f_max` 为负数。 + - **TypeError** - 如果 `max_iter` 的类型不为int。 + - **ValueError** - 如果 `max_iter` 为负数。 + - **TypeError** - 如果 `tolerance_loss` 的类型不为float。 + - **ValueError** - 如果 `tolerance_loss` 为负数。 + - **TypeError** - 如果 `tolerance_change` 的类型不为float。 + - **ValueError** - 如果 `tolerance_change` 为负数。 + - **TypeError** - 如果 `sgdargs` 的类型不为dict。 + - **TypeError** - 如果 `norm` 的类型不为 :class:`mindspore.dataset.audio.NormType` 。 + - **TypeError** - 如果 `mel_type` 的类型不为 :class:`mindspore.dataset.audio.MelType` 。 + + 教程样例: + - `音频变换样例库 + `_ diff --git a/docs/api/api_python/dataset_audio/mindspore.dataset.audio.Magphase.rst b/docs/api/api_python/dataset_audio/mindspore.dataset.audio.Magphase.rst index a9a8af59292..aec1bd9ada1 100644 --- a/docs/api/api_python/dataset_audio/mindspore.dataset.audio.Magphase.rst +++ b/docs/api/api_python/dataset_audio/mindspore.dataset.audio.Magphase.rst @@ -1,16 +1,16 @@ -mindspore.dataset.audio.Magphase -================================ - -.. py:class:: mindspore.dataset.audio.Magphase(power=1.0) - - 将shape为 :math:`(..., 2)` 的复值光谱图分离,输出幅度和相位。 - - 参数: - - **power** (float) - 范数的幂,必须是非负的。默认值: ``1.0`` 。 - - 异常: - - **RuntimeError** - 当输入音频的shape不为(..., 2)。 - - 教程样例: - - `音频变换样例库 - `_ +mindspore.dataset.audio.Magphase +================================ + +.. py:class:: mindspore.dataset.audio.Magphase(power=1.0) + + 将shape为 :math:`(..., 2)` 的复值光谱图分离,输出幅度和相位。 + + 参数: + - **power** (float) - 范数的幂,必须是非负的。默认值: ``1.0`` 。 + + 异常: + - **RuntimeError** - 当输入音频的shape不为(..., 2)。 + + 教程样例: + - `音频变换样例库 + `_ diff --git a/docs/api/api_python/dataset_audio/mindspore.dataset.audio.MaskAlongAxis.rst b/docs/api/api_python/dataset_audio/mindspore.dataset.audio.MaskAlongAxis.rst index b072d07e416..904b7b98015 100644 --- a/docs/api/api_python/dataset_audio/mindspore.dataset.audio.MaskAlongAxis.rst +++ b/docs/api/api_python/dataset_audio/mindspore.dataset.audio.MaskAlongAxis.rst @@ -1,21 +1,21 @@ -mindspore.dataset.audio.MaskAlongAxis -===================================== - -.. py:class:: mindspore.dataset.audio.MaskAlongAxis(mask_start, mask_width, mask_value, axis) - - 对音频波形应用掩码。掩码的起始和长度由 `[mask_start, mask_start + mask_width)` 决定。 - - 参数: - - **mask_start** (int) - 掩码的起始位置,必须是非负的。 - - **mask_width** (int) - 掩码的宽度,必须是大于0。 - - **mask_value** (float) - 填充到掩码区间的值。 - - **axis** (int) - 要应用掩码的轴( ``1`` 表示频率, ``2`` 表示时间)。 - - 异常: - - **ValueError** - `mask_start` 参数值错误(小于0)。 - - **ValueError** - `mask_width` 参数值错误(小于1)。 - - **ValueError** - `axis` 参数类型错误或者值错误,不属于 [1, 2]。 - - 教程样例: - - `音频变换样例库 - `_ +mindspore.dataset.audio.MaskAlongAxis +===================================== + +.. py:class:: mindspore.dataset.audio.MaskAlongAxis(mask_start, mask_width, mask_value, axis) + + 对音频波形应用掩码。掩码的起始和长度由 `[mask_start, mask_start + mask_width)` 决定。 + + 参数: + - **mask_start** (int) - 掩码的起始位置,必须是非负的。 + - **mask_width** (int) - 掩码的宽度,必须是大于0。 + - **mask_value** (float) - 填充到掩码区间的值。 + - **axis** (int) - 要应用掩码的轴( ``1`` 表示频率, ``2`` 表示时间)。 + + 异常: + - **ValueError** - `mask_start` 参数值错误(小于0)。 + - **ValueError** - `mask_width` 参数值错误(小于1)。 + - **ValueError** - `axis` 参数类型错误或者值错误,不属于 [1, 2]。 + + 教程样例: + - `音频变换样例库 + `_ diff --git a/docs/api/api_python/dataset_audio/mindspore.dataset.audio.MaskAlongAxisIID.rst b/docs/api/api_python/dataset_audio/mindspore.dataset.audio.MaskAlongAxisIID.rst index fd2afbe2eef..cb3228d3c90 100644 --- a/docs/api/api_python/dataset_audio/mindspore.dataset.audio.MaskAlongAxisIID.rst +++ b/docs/api/api_python/dataset_audio/mindspore.dataset.audio.MaskAlongAxisIID.rst @@ -1,24 +1,24 @@ -mindspore.dataset.audio.MaskAlongAxisIID -======================================== - -.. py:class:: mindspore.dataset.audio.MaskAlongAxisIID(mask_param, mask_value, axis) - - 对音频波形沿 `axis` 轴应用掩码。掩码的起始和长度由 `[mask_start, mask_start + mask_width)` 决定,其中 `mask_width` 从 `uniform[0, mask_param]` 中采样, `mask_start` 从 `uniform[0, max_length - mask_width]` 中采样, - `max_length` 是光谱图中特定轴的列数。 - - 参数: - - **mask_param** (int) - 要屏蔽的列数,将从[0, mask_param]统一采样,必须是非负数。 - - **mask_value** (float) - 填充到掩码区间的值。 - - **axis** (int) - 要应用掩码的轴( ``1`` 表示频率, ``2`` 表示时间)。 - - 异常: - - **TypeError** - 当 `mask_param` 的类型不为int。 - - **ValueError** - 当 `mask_param` 为负数。 - - **TypeError** - 当 `mask_value` 的类型不为float。 - - **TypeError** - 当 `axis` 的类型不为int。 - - **ValueError** - 当 `axis` 取值不在[1, 2]范围内。 - - **RuntimeError** - 当输入音频的shape不为<..., freq, time>。 - - 教程样例: - - `音频变换样例库 - `_ +mindspore.dataset.audio.MaskAlongAxisIID +======================================== + +.. py:class:: mindspore.dataset.audio.MaskAlongAxisIID(mask_param, mask_value, axis) + + 对音频波形沿 `axis` 轴应用掩码。掩码的起始和长度由 `[mask_start, mask_start + mask_width)` 决定,其中 `mask_width` 从 `uniform[0, mask_param]` 中采样, `mask_start` 从 `uniform[0, max_length - mask_width]` 中采样, + `max_length` 是光谱图中特定轴的列数。 + + 参数: + - **mask_param** (int) - 要屏蔽的列数,将从[0, mask_param]统一采样,必须是非负数。 + - **mask_value** (float) - 填充到掩码区间的值。 + - **axis** (int) - 要应用掩码的轴( ``1`` 表示频率, ``2`` 表示时间)。 + + 异常: + - **TypeError** - 当 `mask_param` 的类型不为int。 + - **ValueError** - 当 `mask_param` 为负数。 + - **TypeError** - 当 `mask_value` 的类型不为float。 + - **TypeError** - 当 `axis` 的类型不为int。 + - **ValueError** - 当 `axis` 取值不在[1, 2]范围内。 + - **RuntimeError** - 当输入音频的shape不为<..., freq, time>。 + + 教程样例: + - `音频变换样例库 + `_ diff --git a/docs/api/api_python/dataset_audio/mindspore.dataset.audio.MelScale.rst b/docs/api/api_python/dataset_audio/mindspore.dataset.audio.MelScale.rst index 21af17f2b42..0cce18ba575 100644 --- a/docs/api/api_python/dataset_audio/mindspore.dataset.audio.MelScale.rst +++ b/docs/api/api_python/dataset_audio/mindspore.dataset.audio.MelScale.rst @@ -1,34 +1,34 @@ -mindspore.dataset.audio.MelScale -================================ - -.. py:class:: mindspore.dataset.audio.MelScale(n_mels=128, sample_rate=16000, f_min=0.0, f_max=None, n_stft=201, norm=NormType.NONE, mel_type=MelType.HTK) - - 将普通STFT转换为梅尔尺度的STFT。 - - 参数: - - **n_mels** (int, 可选) - 梅尔滤波器的数量。默认值: ``128`` 。 - - **sample_rate** (int, 可选) - 音频信号采样速率。默认值: ``16000`` (单位:Hz)。 - - **f_min** (float, 可选) - 最小频率。默认值: ``0.0`` 。 - - **f_max** (float, 可选) - 最大频率。默认值: ``None`` ,将设置为 `sample_rate//2` 。 - - **n_stft** (int, 可选) - STFT中的频段数。默认值: ``201`` 。 - - **norm** (:class:`~.audio.NormType`, 可选) - 标准化方法,可以是 ``NormType.SLANEY`` 或 ``NormType.NONE`` 。默认值: ``NormType.NONE`` ,不使用标准化。 - 若采用 ``NormType.SLANEY`` ,则三角梅尔权重将被除以梅尔频带的宽度。 - - **mel_type** (:class:`~.audio.MelType`, 可选) - 要使用的Mel比例,可以是 ``MelType.SLAN`` 或 ``MelType.HTK`` 。默认值: ``MelType.HTK`` 。 - - 异常: - - **TypeError** - 如果 `n_mels` 的类型不为int。 - - **ValueError** - 如果 `n_mels` 不为正数。 - - **TypeError** - 如果 `sample_rate` 的类型不为int。 - - **ValueError** - 如果 `sample_rate` 不为正数。 - - **TypeError** - 如果 `f_min` 的类型不为float。 - - **ValueError** - 如果 `f_min` 大于等于 `f_max` 。 - - **TypeError** - 如果 `f_max` 的类型不为float。 - - **ValueError** - 如果 `f_max` 为负数。 - - **TypeError** - 如果 `n_stft` 的类型不为int。 - - **ValueError** - 如果 `n_stft` 不为正数。 - - **TypeError** - 如果 `norm` 的类型不为 :class:`mindspore.dataset.audio.NormType` 。 - - **TypeError** - 如果 `mel_type` 的类型不为 :class:`mindspore.dataset.audio.MelType` 。 - - 教程样例: - - `音频变换样例库 - `_ +mindspore.dataset.audio.MelScale +================================ + +.. py:class:: mindspore.dataset.audio.MelScale(n_mels=128, sample_rate=16000, f_min=0.0, f_max=None, n_stft=201, norm=NormType.NONE, mel_type=MelType.HTK) + + 将普通STFT转换为梅尔尺度的STFT。 + + 参数: + - **n_mels** (int, 可选) - 梅尔滤波器的数量。默认值: ``128`` 。 + - **sample_rate** (int, 可选) - 音频信号采样速率。默认值: ``16000`` (单位:Hz)。 + - **f_min** (float, 可选) - 最小频率。默认值: ``0.0`` 。 + - **f_max** (float, 可选) - 最大频率。默认值: ``None`` ,将设置为 `sample_rate//2` 。 + - **n_stft** (int, 可选) - STFT中的频段数。默认值: ``201`` 。 + - **norm** (:class:`~.audio.NormType`, 可选) - 标准化方法,可以是 ``NormType.SLANEY`` 或 ``NormType.NONE`` 。默认值: ``NormType.NONE`` ,不使用标准化。 + 若采用 ``NormType.SLANEY`` ,则三角梅尔权重将被除以梅尔频带的宽度。 + - **mel_type** (:class:`~.audio.MelType`, 可选) - 要使用的Mel比例,可以是 ``MelType.SLAN`` 或 ``MelType.HTK`` 。默认值: ``MelType.HTK`` 。 + + 异常: + - **TypeError** - 如果 `n_mels` 的类型不为int。 + - **ValueError** - 如果 `n_mels` 不为正数。 + - **TypeError** - 如果 `sample_rate` 的类型不为int。 + - **ValueError** - 如果 `sample_rate` 不为正数。 + - **TypeError** - 如果 `f_min` 的类型不为float。 + - **ValueError** - 如果 `f_min` 大于等于 `f_max` 。 + - **TypeError** - 如果 `f_max` 的类型不为float。 + - **ValueError** - 如果 `f_max` 为负数。 + - **TypeError** - 如果 `n_stft` 的类型不为int。 + - **ValueError** - 如果 `n_stft` 不为正数。 + - **TypeError** - 如果 `norm` 的类型不为 :class:`mindspore.dataset.audio.NormType` 。 + - **TypeError** - 如果 `mel_type` 的类型不为 :class:`mindspore.dataset.audio.MelType` 。 + + 教程样例: + - `音频变换样例库 + `_ diff --git a/docs/api/api_python/dataset_audio/mindspore.dataset.audio.MuLawEncoding.rst b/docs/api/api_python/dataset_audio/mindspore.dataset.audio.MuLawEncoding.rst index 51f29442d44..de633a65c9f 100644 --- a/docs/api/api_python/dataset_audio/mindspore.dataset.audio.MuLawEncoding.rst +++ b/docs/api/api_python/dataset_audio/mindspore.dataset.audio.MuLawEncoding.rst @@ -1,17 +1,17 @@ -mindspore.dataset.audio.MuLawEncoding -===================================== - -.. py:class:: mindspore.dataset.audio.MuLawEncoding(quantization_channels=256) - - 基于mu-law压缩的信号编码。 - - 参数: - - **quantization_channels** (int, 可选) - 通道数,必须为正数。默认值: ``256`` 。 - - 异常: - - **TypeError** - 当 `quantization_channels` 的类型不为int。 - - **ValueError** - 当 `quantization_channels` 不为正数。 - - 教程样例: - - `音频变换样例库 - `_ +mindspore.dataset.audio.MuLawEncoding +===================================== + +.. py:class:: mindspore.dataset.audio.MuLawEncoding(quantization_channels=256) + + 基于mu-law压缩的信号编码。 + + 参数: + - **quantization_channels** (int, 可选) - 通道数,必须为正数。默认值: ``256`` 。 + + 异常: + - **TypeError** - 当 `quantization_channels` 的类型不为int。 + - **ValueError** - 当 `quantization_channels` 不为正数。 + + 教程样例: + - `音频变换样例库 + `_ diff --git a/docs/api/api_python/dataset_audio/mindspore.dataset.audio.RiaaBiquad.rst b/docs/api/api_python/dataset_audio/mindspore.dataset.audio.RiaaBiquad.rst index c1c69f0f4bc..bb15e662f9a 100644 --- a/docs/api/api_python/dataset_audio/mindspore.dataset.audio.RiaaBiquad.rst +++ b/docs/api/api_python/dataset_audio/mindspore.dataset.audio.RiaaBiquad.rst @@ -1,19 +1,19 @@ -mindspore.dataset.audio.RiaaBiquad -================================== - -.. py:class:: mindspore.dataset.audio.RiaaBiquad(sample_rate) - - 对输入音频波形施加RIAA均衡。 - - 接口实现方式类似于 `SoX库 `_ 。 - - 参数: - - **sample_rate** (int) - 波形的采样率,例如 ``44100`` (Hz),只能是 ``44100`` 、 ``48000`` 、 ``88200`` 、 ``96000`` 中的一个。 - - 异常: - - **TypeError** - 当 `sample_rate` 的类型不为int。 - - **ValueError** - 当 `sample_rate` 不为 ``44100`` 、 ``48000`` 、 ``88200`` 、 ``96000`` 中的任何一个。 - - 教程样例: - - `音频变换样例库 - `_ +mindspore.dataset.audio.RiaaBiquad +================================== + +.. py:class:: mindspore.dataset.audio.RiaaBiquad(sample_rate) + + 对输入音频波形施加RIAA均衡。 + + 接口实现方式类似于 `SoX库 `_ 。 + + 参数: + - **sample_rate** (int) - 波形的采样率,例如 ``44100`` (Hz),只能是 ``44100`` 、 ``48000`` 、 ``88200`` 、 ``96000`` 中的一个。 + + 异常: + - **TypeError** - 当 `sample_rate` 的类型不为int。 + - **ValueError** - 当 `sample_rate` 不为 ``44100`` 、 ``48000`` 、 ``88200`` 、 ``96000`` 中的任何一个。 + + 教程样例: + - `音频变换样例库 + `_ diff --git a/docs/api/api_python/dataset_audio/mindspore.dataset.audio.SlidingWindowCmn.rst b/docs/api/api_python/dataset_audio/mindspore.dataset.audio.SlidingWindowCmn.rst index b0ddc93e964..102c785a7b6 100644 --- a/docs/api/api_python/dataset_audio/mindspore.dataset.audio.SlidingWindowCmn.rst +++ b/docs/api/api_python/dataset_audio/mindspore.dataset.audio.SlidingWindowCmn.rst @@ -1,25 +1,25 @@ -mindspore.dataset.audio.SlidingWindowCmn -======================================== - -.. py:class:: mindspore.dataset.audio.SlidingWindowCmn(cmn_window=600, min_cmn_window=100, center=False, norm_vars=False) - - 对每个话语应用滑动窗口倒谱均值(和可选方差)归一化。 - - 参数: - - **cmn_window** (int, 可选) - 用于运行平均CMN计算的帧中窗口。默认值: ``600`` 。 - - **min_cmn_window** (int, 可选) - 解码开始时使用的最小CMN窗口(仅在开始时增加延迟)。 - 仅在 `center` 为 ``False`` 时适用,在 `center` 为 ``True`` 时忽略。默认值: ``100`` 。 - - **center** (bool, 可选) - 如果为 ``True`` ,则使用以当前帧为中心的窗口。如果为 ``False`` ,则窗口在左侧。默认值: ``False`` 。 - - **norm_vars** (bool, 可选) - 如果为 ``True`` ,则将方差规范化为1。默认值: ``False`` 。 - - 异常: - - **TypeError** - 当 `cmn_window` 的类型不为int。 - - **ValueError** - 当 `cmn_window` 为负数。 - - **TypeError** - 当 `min_cmn_window` 的类型不为int。 - - **ValueError** - 当 `min_cmn_window` 为负数。 - - **TypeError** - 当 `center` 的类型不为bool。 - - **TypeError** - 当 `norm_vars` 的类型不为bool。 - - 教程样例: - - `音频变换样例库 - `_ +mindspore.dataset.audio.SlidingWindowCmn +======================================== + +.. py:class:: mindspore.dataset.audio.SlidingWindowCmn(cmn_window=600, min_cmn_window=100, center=False, norm_vars=False) + + 对每个话语应用滑动窗口倒谱均值(和可选方差)归一化。 + + 参数: + - **cmn_window** (int, 可选) - 用于运行平均CMN计算的帧中窗口。默认值: ``600`` 。 + - **min_cmn_window** (int, 可选) - 解码开始时使用的最小CMN窗口(仅在开始时增加延迟)。 + 仅在 `center` 为 ``False`` 时适用,在 `center` 为 ``True`` 时忽略。默认值: ``100`` 。 + - **center** (bool, 可选) - 如果为 ``True`` ,则使用以当前帧为中心的窗口。如果为 ``False`` ,则窗口在左侧。默认值: ``False`` 。 + - **norm_vars** (bool, 可选) - 如果为 ``True`` ,则将方差规范化为1。默认值: ``False`` 。 + + 异常: + - **TypeError** - 当 `cmn_window` 的类型不为int。 + - **ValueError** - 当 `cmn_window` 为负数。 + - **TypeError** - 当 `min_cmn_window` 的类型不为int。 + - **ValueError** - 当 `min_cmn_window` 为负数。 + - **TypeError** - 当 `center` 的类型不为bool。 + - **TypeError** - 当 `norm_vars` 的类型不为bool。 + + 教程样例: + - `音频变换样例库 + `_ diff --git a/docs/api/api_python/dataset_audio/mindspore.dataset.audio.create_dct.rst b/docs/api/api_python/dataset_audio/mindspore.dataset.audio.create_dct.rst index c3879827057..1237d1b4d7c 100644 --- a/docs/api/api_python/dataset_audio/mindspore.dataset.audio.create_dct.rst +++ b/docs/api/api_python/dataset_audio/mindspore.dataset.audio.create_dct.rst @@ -1,21 +1,21 @@ -mindspore.dataset.audio.create_dct -================================== - -.. py:function:: mindspore.dataset.audio.create_dct(n_mfcc, n_mels, norm=NormMode.NONE) - - 创建一个shape为( `n_mels` , `n_mfcc` )的DCT变换矩阵,并根据范数进行标准化。 - - 参数: - - **n_mfcc** (int) - 要保留mfc系数的数量,该值必须大于0。 - - **n_mels** (int) - mel滤波器的数量,该值必须大于0。 - - **norm** (:class:`~.audio.NormMode`, 可选) - 标准化模式,可以是 ``NormMode.NONE`` 或 ``NormMode.ORTHO`` 。默认值: ``NormMode.NONE`` 。 - - 返回: - numpy.ndarray,shape为 ( `n_mels` , `n_mfcc` ) 的DCT转换矩阵。 - - 异常: - - **TypeError** - 如果 `n_mfcc` 的类型不为int。 - - **ValueError** - 如果 `n_mfcc` 不为正数。 - - **TypeError** - 如果 `n_mels` 的类型不为int。 - - **ValueError** - 如果 `n_mels` 不为正数。 - - **TypeError** - 如果 `norm` 的类型不为 :class:`mindspore.dataset.audio.NormType` 。 +mindspore.dataset.audio.create_dct +================================== + +.. py:function:: mindspore.dataset.audio.create_dct(n_mfcc, n_mels, norm=NormMode.NONE) + + 创建一个shape为( `n_mels` , `n_mfcc` )的DCT变换矩阵,并根据范数进行标准化。 + + 参数: + - **n_mfcc** (int) - 要保留mfc系数的数量,该值必须大于0。 + - **n_mels** (int) - mel滤波器的数量,该值必须大于0。 + - **norm** (:class:`~.audio.NormMode`, 可选) - 标准化模式,可以是 ``NormMode.NONE`` 或 ``NormMode.ORTHO`` 。默认值: ``NormMode.NONE`` 。 + + 返回: + numpy.ndarray,shape为 ( `n_mels` , `n_mfcc` ) 的DCT转换矩阵。 + + 异常: + - **TypeError** - 如果 `n_mfcc` 的类型不为int。 + - **ValueError** - 如果 `n_mfcc` 不为正数。 + - **TypeError** - 如果 `n_mels` 的类型不为int。 + - **ValueError** - 如果 `n_mels` 不为正数。 + - **TypeError** - 如果 `norm` 的类型不为 :class:`mindspore.dataset.audio.NormType` 。 diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.AdjustGamma.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.AdjustGamma.rst index 64870c51667..94a3b3e70aa 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.AdjustGamma.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.AdjustGamma.rst @@ -1,25 +1,25 @@ -mindspore.dataset.vision.AdjustGamma -==================================== - -.. py:class:: mindspore.dataset.vision.AdjustGamma(gamma, gain=1) - - 对输入图像应用伽马校正。输入图片shape应该为 <..., H, W, C>或。 - - .. math:: - I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma} - - 更多详细信息,请参见 `Gamma矫正 `_ 。 - - 参数: - - **gamma** (float) - 非负实数。输出图像像素值与输入图像像素值呈指数相关。 `gamma` 大于 ``1`` 使阴影更暗,而 `gamma` 小于 ``1`` 使黑暗区域更亮。 - - **gain** (float, 可选) - 常数乘数。默认值: ``1.0`` 。 - - 异常: - - **TypeError** - 如果 `gain` 不是浮点类型。 - - **TypeError** - 如果 `gamma` 不是浮点类型。 - - **ValueError** - 如果 `gamma` 小于0。 - - **RuntimeError** - 如果给定的张量形状不是或<..., H, W, C>。 - - 教程样例: - - `视觉变换样例库 - `_ +mindspore.dataset.vision.AdjustGamma +==================================== + +.. py:class:: mindspore.dataset.vision.AdjustGamma(gamma, gain=1) + + 对输入图像应用伽马校正。输入图片shape应该为 <..., H, W, C>或。 + + .. math:: + I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma} + + 更多详细信息,请参见 `Gamma矫正 `_ 。 + + 参数: + - **gamma** (float) - 非负实数。输出图像像素值与输入图像像素值呈指数相关。 `gamma` 大于 ``1`` 使阴影更暗,而 `gamma` 小于 ``1`` 使黑暗区域更亮。 + - **gain** (float, 可选) - 常数乘数。默认值: ``1.0`` 。 + + 异常: + - **TypeError** - 如果 `gain` 不是浮点类型。 + - **TypeError** - 如果 `gamma` 不是浮点类型。 + - **ValueError** - 如果 `gamma` 小于0。 + - **RuntimeError** - 如果给定的张量形状不是或<..., H, W, C>。 + + 教程样例: + - `视觉变换样例库 + `_ diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.AutoAugment.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.AutoAugment.rst index 90dd24befd9..3e9858e572c 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.AutoAugment.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.AutoAugment.rst @@ -1,31 +1,31 @@ -mindspore.dataset.vision.AutoAugment -==================================== - -.. py:class:: mindspore.dataset.vision.AutoAugment(policy=AutoAugmentPolicy.IMAGENET, interpolation=Inter.NEAREST, fill_value=0) - - 应用AutoAugment数据增强方法,基于论文 `AutoAugment: Learning Augmentation Strategies from Data `_ 。 - 此操作仅适用于3通道RGB图像。 - - 参数: - - **policy** (:class:`~.vision.AutoAugmentPolicy`, 可选) - 在不同数据集上学习的AutoAugment策略。默认值: ``AutoAugmentPolicy.IMAGENET`` 。 - 可以是 ``AutoAugmentPolicy.IMAGENET`` 、 ``AutoAugmentPolicy.CIFAR10`` 、 ``AutoAugmentPolicy.SVHN`` 。 - - - **AutoAugmentPolicy.IMAGENET**:表示应用在ImageNet数据集上学习的AutoAugment。 - - **AutoAugmentPolicy.CIFAR10**:表示应用在Cifar10数据集上学习的AutoAugment。 - - **AutoAugmentPolicy.SVHN**:表示应用在SVHN数据集上学习的AutoAugment。 - - - **interpolation** (:class:`~.vision.Inter`, 可选) - 图像插值方法。可选值详见 :class:`mindspore.dataset.vision.Inter` 。 - 默认值: ``Inter.NEAREST``。 - - **fill_value** (Union[int, tuple[int]], 可选) - 填充的像素值。 - 如果是3元素元组,则分别用于填充R、G、B通道。 - 如果是整数,则用于所有 RGB 通道。 `fill_value` 值必须在 [0, 255] 范围内。默认值: ``0`` 。 - - 异常: - - **TypeError** - 如果 `policy` 不是 :class:`mindspore.dataset.vision.AutoAugmentPolicy` 类型。 - - **TypeError** - 如果 `interpolation` 不是 :class:`mindspore.dataset.vision.Inter` 类型。 - - **TypeError** - 如果 `fill_value` 不是整数或长度为3的元组。 - - **RuntimeError** - 如果给定的张量shape不是。 - - 教程样例: - - `视觉变换样例库 - `_ +mindspore.dataset.vision.AutoAugment +==================================== + +.. py:class:: mindspore.dataset.vision.AutoAugment(policy=AutoAugmentPolicy.IMAGENET, interpolation=Inter.NEAREST, fill_value=0) + + 应用AutoAugment数据增强方法,基于论文 `AutoAugment: Learning Augmentation Strategies from Data `_ 。 + 此操作仅适用于3通道RGB图像。 + + 参数: + - **policy** (:class:`~.vision.AutoAugmentPolicy`, 可选) - 在不同数据集上学习的AutoAugment策略。默认值: ``AutoAugmentPolicy.IMAGENET`` 。 + 可以是 ``AutoAugmentPolicy.IMAGENET`` 、 ``AutoAugmentPolicy.CIFAR10`` 、 ``AutoAugmentPolicy.SVHN`` 。 + + - **AutoAugmentPolicy.IMAGENET**:表示应用在ImageNet数据集上学习的AutoAugment。 + - **AutoAugmentPolicy.CIFAR10**:表示应用在Cifar10数据集上学习的AutoAugment。 + - **AutoAugmentPolicy.SVHN**:表示应用在SVHN数据集上学习的AutoAugment。 + + - **interpolation** (:class:`~.vision.Inter`, 可选) - 图像插值方法。可选值详见 :class:`mindspore.dataset.vision.Inter` 。 + 默认值: ``Inter.NEAREST``。 + - **fill_value** (Union[int, tuple[int]], 可选) - 填充的像素值。 + 如果是3元素元组,则分别用于填充R、G、B通道。 + 如果是整数,则用于所有 RGB 通道。 `fill_value` 值必须在 [0, 255] 范围内。默认值: ``0`` 。 + + 异常: + - **TypeError** - 如果 `policy` 不是 :class:`mindspore.dataset.vision.AutoAugmentPolicy` 类型。 + - **TypeError** - 如果 `interpolation` 不是 :class:`mindspore.dataset.vision.Inter` 类型。 + - **TypeError** - 如果 `fill_value` 不是整数或长度为3的元组。 + - **RuntimeError** - 如果给定的张量shape不是。 + + 教程样例: + - `视觉变换样例库 + `_ diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.AutoAugmentPolicy.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.AutoAugmentPolicy.rst index ab551a50777..06b24714afa 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.AutoAugmentPolicy.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.AutoAugmentPolicy.rst @@ -1,67 +1,67 @@ -mindspore.dataset.vision.AutoAugmentPolicy -========================================== - -.. py:class:: mindspore.dataset.vision.AutoAugmentPolicy - - 不同数据集的自动增强策略。 - 可能的枚举值包括: ``AutoAugmentPolicy.IMAGENET`` 、 ``AutoAugmentPolicy.CIFAR10`` 、 ``AutoAugmentPolicy.SVHN`` 。 - 每个策略包含25对增强操作。使用AutoAugment时,每个图像都会使用这些操作对中的一个随机转换。每对有2个不同的操作。下面显示了所有这些增强操作,包括操作名称及其概率和随机参数。 - - - ``AutoAugmentPolicy.IMAGENET``:ImageNet的数据集自动增强策略。 - - .. code-block:: - - Augmentation operations pair: - [(("Posterize", 0.4, 8), ("Rotate", 0.6, 9)), (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), - (("Equalize", 0.8, None), ("Equalize", 0.6, None)), (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)), - (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), (("Equalize", 0.4, None), ("Rotate", 0.8, 8)), - (("Solarize", 0.6, 3), ("Equalize", 0.6, None)), (("Posterize", 0.8, 5), ("Equalize", 1.0, None)), - (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)), (("Equalize", 0.6, None), ("Posterize", 0.4, 6)), - (("Rotate", 0.8, 8), ("Color", 0.4, 0)), (("Rotate", 0.4, 9), ("Equalize", 0.6, None)), - (("Equalize", 0.0, None), ("Equalize", 0.8, None)), (("Invert", 0.6, None), ("Equalize", 1.0, None)), - (("Color", 0.6, 4), ("Contrast", 1.0, 8)), (("Rotate", 0.8, 8), ("Color", 1.0, 2)), - (("Color", 0.8, 8), ("Solarize", 0.8, 7)), (("Sharpness", 0.4, 7), ("Invert", 0.6, None)), - (("ShearX", 0.6, 5), ("Equalize", 1.0, None)), (("Color", 0.4, 0), ("Equalize", 0.6, None)), - (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), - (("Invert", 0.6, None), ("Equalize", 1.0, None)), (("Color", 0.6, 4), ("Contrast", 1.0, 8)), - (("Equalize", 0.8, None), ("Equalize", 0.6, None))] - - - ``AutoAugmentPolicy.CIFAR10``:Cifar10的数据集自动增强策略。 - - .. code-block:: - - Augmentation operations pair: - [(("Invert", 0.1, None), ("Contrast", 0.2, 6)), (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)), - (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)), (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)), - (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)), (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)), - (("Color", 0.4, 3), ("Brightness", 0.6, 7)), (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)), - (("Equalize", 0.6, None), ("Equalize", 0.5, None)), (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)), - (("Color", 0.7, 7), ("TranslateX", 0.5, 8)), (("Equalize", 0.8, None), ("Invert", 0.1, None)), - (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)), (("Brightness", 0.9, 6), ("Color", 0.2, 8)), - (("Solarize", 0.5, 2), ("Invert", 0.0, None)), (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)), - (("Equalize", 0.2, None), ("Equalize", 0.6, None)), (("Color", 0.9, 9), ("Equalize", 0.6, None)), - (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)), (("Brightness", 0.1, 3), ("Color", 0.7, 0)), - (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)), - (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)), - (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)), - (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)), - (("Equalize", 0.2, None), ("AutoContrast", 0.6, None))] - - - ``AutoAugmentPolicy.SVHN``:SVHN的数据集自动增强策略。 - - .. code-block:: - - Augmentation operations pair: - [(("ShearX", 0.9, 4), ("Invert", 0.2, None)), (("ShearY", 0.9, 8), ("Invert", 0.7, None)), - (("Equalize", 0.6, None), ("Solarize", 0.6, 6)), (("Invert", 0.9, None), ("Equalize", 0.6, None)), - (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)), - (("ShearY", 0.9, 8), ("Invert", 0.4, None)), (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)), - (("Invert", 0.9, None), ("AutoContrast", 0.8, None)), (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), - (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)), (("ShearY", 0.8, 8), ("Invert", 0.7, None)), - (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)), (("Invert", 0.9, None), ("Equalize", 0.6, None)), - (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)), (("Invert", 0.8, None), ("TranslateY", 0.0, 2)), - (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)), (("Invert", 0.6, None), ("Rotate", 0.8, 4)), - (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)), (("ShearX", 0.1, 6), ("Invert", 0.6, None)), - (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)), (("ShearY", 0.8, 4), ("Invert", 0.8, None)), - (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)), (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)), - (("ShearX", 0.7, 2), ("Invert", 0.1, None))] +mindspore.dataset.vision.AutoAugmentPolicy +========================================== + +.. py:class:: mindspore.dataset.vision.AutoAugmentPolicy + + 不同数据集的自动增强策略。 + 可能的枚举值包括: ``AutoAugmentPolicy.IMAGENET`` 、 ``AutoAugmentPolicy.CIFAR10`` 、 ``AutoAugmentPolicy.SVHN`` 。 + 每个策略包含25对增强操作。使用AutoAugment时,每个图像都会使用这些操作对中的一个随机转换。每对有2个不同的操作。下面显示了所有这些增强操作,包括操作名称及其概率和随机参数。 + + - ``AutoAugmentPolicy.IMAGENET``:ImageNet的数据集自动增强策略。 + + .. code-block:: + + Augmentation operations pair: + [(("Posterize", 0.4, 8), ("Rotate", 0.6, 9)), (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), + (("Equalize", 0.8, None), ("Equalize", 0.6, None)), (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)), + (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), (("Equalize", 0.4, None), ("Rotate", 0.8, 8)), + (("Solarize", 0.6, 3), ("Equalize", 0.6, None)), (("Posterize", 0.8, 5), ("Equalize", 1.0, None)), + (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)), (("Equalize", 0.6, None), ("Posterize", 0.4, 6)), + (("Rotate", 0.8, 8), ("Color", 0.4, 0)), (("Rotate", 0.4, 9), ("Equalize", 0.6, None)), + (("Equalize", 0.0, None), ("Equalize", 0.8, None)), (("Invert", 0.6, None), ("Equalize", 1.0, None)), + (("Color", 0.6, 4), ("Contrast", 1.0, 8)), (("Rotate", 0.8, 8), ("Color", 1.0, 2)), + (("Color", 0.8, 8), ("Solarize", 0.8, 7)), (("Sharpness", 0.4, 7), ("Invert", 0.6, None)), + (("ShearX", 0.6, 5), ("Equalize", 1.0, None)), (("Color", 0.4, 0), ("Equalize", 0.6, None)), + (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), + (("Invert", 0.6, None), ("Equalize", 1.0, None)), (("Color", 0.6, 4), ("Contrast", 1.0, 8)), + (("Equalize", 0.8, None), ("Equalize", 0.6, None))] + + - ``AutoAugmentPolicy.CIFAR10``:Cifar10的数据集自动增强策略。 + + .. code-block:: + + Augmentation operations pair: + [(("Invert", 0.1, None), ("Contrast", 0.2, 6)), (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)), + (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)), (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)), + (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)), (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)), + (("Color", 0.4, 3), ("Brightness", 0.6, 7)), (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)), + (("Equalize", 0.6, None), ("Equalize", 0.5, None)), (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)), + (("Color", 0.7, 7), ("TranslateX", 0.5, 8)), (("Equalize", 0.8, None), ("Invert", 0.1, None)), + (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)), (("Brightness", 0.9, 6), ("Color", 0.2, 8)), + (("Solarize", 0.5, 2), ("Invert", 0.0, None)), (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)), + (("Equalize", 0.2, None), ("Equalize", 0.6, None)), (("Color", 0.9, 9), ("Equalize", 0.6, None)), + (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)), (("Brightness", 0.1, 3), ("Color", 0.7, 0)), + (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)), + (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)), + (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)), + (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)), + (("Equalize", 0.2, None), ("AutoContrast", 0.6, None))] + + - ``AutoAugmentPolicy.SVHN``:SVHN的数据集自动增强策略。 + + .. code-block:: + + Augmentation operations pair: + [(("ShearX", 0.9, 4), ("Invert", 0.2, None)), (("ShearY", 0.9, 8), ("Invert", 0.7, None)), + (("Equalize", 0.6, None), ("Solarize", 0.6, 6)), (("Invert", 0.9, None), ("Equalize", 0.6, None)), + (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)), + (("ShearY", 0.9, 8), ("Invert", 0.4, None)), (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)), + (("Invert", 0.9, None), ("AutoContrast", 0.8, None)), (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), + (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)), (("ShearY", 0.8, 8), ("Invert", 0.7, None)), + (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)), (("Invert", 0.9, None), ("Equalize", 0.6, None)), + (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)), (("Invert", 0.8, None), ("TranslateY", 0.0, 2)), + (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)), (("Invert", 0.6, None), ("Rotate", 0.8, 4)), + (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)), (("ShearX", 0.1, 6), ("Invert", 0.6, None)), + (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)), (("ShearY", 0.8, 4), ("Invert", 0.8, None)), + (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)), (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)), + (("ShearX", 0.7, 2), ("Invert", 0.1, None))] diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.AutoContrast.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.AutoContrast.rst index f743517b760..e3b32d7ba4b 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.AutoContrast.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.AutoContrast.rst @@ -1,36 +1,36 @@ -mindspore.dataset.vision.AutoContrast -===================================== - -.. py:class:: mindspore.dataset.vision.AutoContrast(cutoff=0.0, ignore=None) - - 在输入图像上应用自动对比度。首先计算图像的直方图,将直方图中最亮像素的值映射为255,将直方图中最暗像素的值映射为0。 - - 支持 Ascend 硬件加速,需要通过 `.device("Ascend")` 方式开启。 - - 参数: - - **cutoff** (float, 可选) - 输入图像直方图中需要剔除的最亮和最暗像素的百分比。该值必须在 [0.0, 50.0) 范围内。默认值: ``0.0`` 。 - - **ignore** (Union[int, sequence], 可选) - 要忽略的背景像素值,忽略值必须在 [0, 255] 范围内。默认值: ``None`` 。 - - 异常: - - **TypeError** - 如果 `cutoff` 不是float类型。 - - **TypeError** - 如果 `ignore` 不是int或sequence类型。 - - **ValueError** - 如果 `cutoff` 不在[0, 50.0) 范围内。 - - **ValueError** - 如果 `ignore` 不在[0, 255] 范围内。 - - **RuntimeError** - 如果输入图像的shape不是 。 - - 教程样例: - - `视觉变换样例库 - `_ - - .. py:method:: device(device_target="CPU") - - 指定该变换执行的设备。 - - - 当执行设备是 Ascend 时,输入数据支持 `uint8` 或者 `float32` 类型,输入数据的通道仅支持1和3。如果数据类型是float32,期望输入的值的范围为[0,1]。输入数据的高度限制范围为[4, 8192]、宽度限制范围为[6, 4096]。 - - 参数: - - **device_target** (str, 可选) - 算子将在指定的设备上运行。当前支持 ``CPU`` 和 ``Ascend`` 。默认值: ``CPU`` 。 - - 异常: - - **TypeError** - 当 `device_target` 的类型不为str。 - - **ValueError** - 当 `device_target` 的取值不为 ``CPU`` / ``Ascend`` 。 +mindspore.dataset.vision.AutoContrast +===================================== + +.. py:class:: mindspore.dataset.vision.AutoContrast(cutoff=0.0, ignore=None) + + 在输入图像上应用自动对比度。首先计算图像的直方图,将直方图中最亮像素的值映射为255,将直方图中最暗像素的值映射为0。 + + 支持 Ascend 硬件加速,需要通过 `.device("Ascend")` 方式开启。 + + 参数: + - **cutoff** (float, 可选) - 输入图像直方图中需要剔除的最亮和最暗像素的百分比。该值必须在 [0.0, 50.0) 范围内。默认值: ``0.0`` 。 + - **ignore** (Union[int, sequence], 可选) - 要忽略的背景像素值,忽略值必须在 [0, 255] 范围内。默认值: ``None`` 。 + + 异常: + - **TypeError** - 如果 `cutoff` 不是float类型。 + - **TypeError** - 如果 `ignore` 不是int或sequence类型。 + - **ValueError** - 如果 `cutoff` 不在[0, 50.0) 范围内。 + - **ValueError** - 如果 `ignore` 不在[0, 255] 范围内。 + - **RuntimeError** - 如果输入图像的shape不是 。 + + 教程样例: + - `视觉变换样例库 + `_ + + .. py:method:: device(device_target="CPU") + + 指定该变换执行的设备。 + + - 当执行设备是 Ascend 时,输入数据支持 `uint8` 或者 `float32` 类型,输入数据的通道仅支持1和3。如果数据类型是float32,期望输入的值的范围为[0,1]。输入数据的高度限制范围为[4, 8192]、宽度限制范围为[6, 4096]。 + + 参数: + - **device_target** (str, 可选) - 算子将在指定的设备上运行。当前支持 ``CPU`` 和 ``Ascend`` 。默认值: ``CPU`` 。 + + 异常: + - **TypeError** - 当 `device_target` 的类型不为str。 + - **ValueError** - 当 `device_target` 的取值不为 ``CPU`` / ``Ascend`` 。 diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.Border.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.Border.rst index ff4ce3fbe37..d5d32703258 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.Border.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.Border.rst @@ -1,18 +1,18 @@ -mindspore.dataset.vision.Border -=============================== - -.. py:class:: mindspore.dataset.vision.Border - - 边界填充方式枚举类。 - - 可选枚举值为: ``Border.CONSTANT`` 、 ``Border.EDGE`` 、 ``Border.REFLECT`` 、 ``Border.SYMMETRIC`` 。 - - - **Border.CONSTANT** - 使用常量值进行填充。 - - **Border.EDGE** - 使用各边的边界像素值进行填充。 - - **Border.REFLECT** - 以各边的边界为轴进行镜像填充,忽略边界像素值。 - 例如,对 [1,2,3,4] 的两侧分别填充2个元素,结果为 [3,2,1,2,3,4,3,2]。 - - **Border.SYMMETRIC** - 以各边的边界为轴进行对称填充,包括边界像素值。 - 例如,对 [1,2,3,4] 的两侧分别填充2个元素,结果为 [2,1,1,2,3,4,4,3]。 - - .. note:: - 该类派生自 `str` 以支持 JSON 可序列化。 +mindspore.dataset.vision.Border +=============================== + +.. py:class:: mindspore.dataset.vision.Border + + 边界填充方式枚举类。 + + 可选枚举值为: ``Border.CONSTANT`` 、 ``Border.EDGE`` 、 ``Border.REFLECT`` 、 ``Border.SYMMETRIC`` 。 + + - **Border.CONSTANT** - 使用常量值进行填充。 + - **Border.EDGE** - 使用各边的边界像素值进行填充。 + - **Border.REFLECT** - 以各边的边界为轴进行镜像填充,忽略边界像素值。 + 例如,对 [1,2,3,4] 的两侧分别填充2个元素,结果为 [3,2,1,2,3,4,3,2]。 + - **Border.SYMMETRIC** - 以各边的边界为轴进行对称填充,包括边界像素值。 + 例如,对 [1,2,3,4] 的两侧分别填充2个元素,结果为 [2,1,1,2,3,4,4,3]。 + + .. note:: + 该类派生自 `str` 以支持 JSON 可序列化。 diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.BoundingBoxAugment.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.BoundingBoxAugment.rst index 0263f467691..00b8446cbbb 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.BoundingBoxAugment.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.BoundingBoxAugment.rst @@ -1,20 +1,20 @@ -mindspore.dataset.vision.BoundingBoxAugment -=========================================== - -.. py:class:: mindspore.dataset.vision.BoundingBoxAugment(transform, ratio=0.3) - - 对图像的随机标注边界框区域,应用给定的图像变换处理。 - - 参数: - - **transform** (TensorOperation) - 对图像的随机标注边界框区域应用的变换处理。 - - **ratio** (float, 可选) - 要应用变换的边界框的比例。范围:[0.0, 1.0]。默认值: ``0.3`` 。 - - 异常: - - **TypeError** - 如果 `transform` 不是 `mindspore.dataset.vision` 模块中的图像变换处理。 - - **TypeError** - 如果 `ratio` 不是float类型。 - - **ValueError** - 如果 `ratio` 不在 [0.0, 1.0] 范围内。 - - **RuntimeError** - 如果给定的边界框无效。 - - 教程样例: - - `视觉变换样例库 - `_ +mindspore.dataset.vision.BoundingBoxAugment +=========================================== + +.. py:class:: mindspore.dataset.vision.BoundingBoxAugment(transform, ratio=0.3) + + 对图像的随机标注边界框区域,应用给定的图像变换处理。 + + 参数: + - **transform** (TensorOperation) - 对图像的随机标注边界框区域应用的变换处理。 + - **ratio** (float, 可选) - 要应用变换的边界框的比例。范围:[0.0, 1.0]。默认值: ``0.3`` 。 + + 异常: + - **TypeError** - 如果 `transform` 不是 `mindspore.dataset.vision` 模块中的图像变换处理。 + - **TypeError** - 如果 `ratio` 不是float类型。 + - **ValueError** - 如果 `ratio` 不在 [0.0, 1.0] 范围内。 + - **RuntimeError** - 如果给定的边界框无效。 + + 教程样例: + - `视觉变换样例库 + `_ diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.CenterCrop.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.CenterCrop.rst index b83107b5491..3fa06e94486 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.CenterCrop.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.CenterCrop.rst @@ -1,21 +1,21 @@ -mindspore.dataset.vision.CenterCrop -=================================== - -.. py:class:: mindspore.dataset.vision.CenterCrop(size) - - 对输入图像应用中心区域裁剪。如果输入图像尺寸小于输出尺寸,则在裁剪前对输入图像边界填充0像素。 - - 参数: - - **size** (Union[int, sequence]) - 裁剪区域尺寸大小。 - 如果 `size` 是整数,则返回一个裁剪尺寸大小为 (size, size) 的正方形。 - 如果 `size` 是一个长度为 2 的序列,则以2个元素分别为高和宽放缩至(高度, 宽度)大小。 - 值必须大于 0。 - - 异常: - - **TypeError** - 如果 `size` 不是int或sequence类型。 - - **ValueError** - 如果 `size` 小于或等于 0。 - - **RuntimeError** - 如果输入图像的shape不是 或 <..., H, W, C>。 - - 教程样例: - - `视觉变换样例库 - `_ +mindspore.dataset.vision.CenterCrop +=================================== + +.. py:class:: mindspore.dataset.vision.CenterCrop(size) + + 对输入图像应用中心区域裁剪。如果输入图像尺寸小于输出尺寸,则在裁剪前对输入图像边界填充0像素。 + + 参数: + - **size** (Union[int, sequence]) - 裁剪区域尺寸大小。 + 如果 `size` 是整数,则返回一个裁剪尺寸大小为 (size, size) 的正方形。 + 如果 `size` 是一个长度为 2 的序列,则以2个元素分别为高和宽放缩至(高度, 宽度)大小。 + 值必须大于 0。 + + 异常: + - **TypeError** - 如果 `size` 不是int或sequence类型。 + - **ValueError** - 如果 `size` 小于或等于 0。 + - **RuntimeError** - 如果输入图像的shape不是 或 <..., H, W, C>。 + + 教程样例: + - `视觉变换样例库 + `_ diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.ConvertColor.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.ConvertColor.rst index 747b0cf406e..bd7f8d0f62d 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.ConvertColor.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.ConvertColor.rst @@ -1,53 +1,53 @@ -mindspore.dataset.vision.ConvertColor -===================================== - -.. py:class:: mindspore.dataset.vision.ConvertColor(convert_mode) - - 更改图像的色彩空间。 - - 支持 Ascend 硬件加速,需要通过 `.device("Ascend")` 方式开启。 - - 参数: - - **convert_mode** (:class:`~.vision.ConvertMode`) - 图像色彩空间转换的模式。 - - - **ConvertMode.COLOR_BGR2BGRA**: 将 BGR 图像转换为 BGRA 图像。 - - **ConvertMode.COLOR_RGB2RGBA**: 将 RGB 图像转换为 RGBA 图像。 - - **ConvertMode.COLOR_BGRA2BGR**: 将 BGRA 图像转换为 BGR 图像。 - - **ConvertMode.COLOR_RGBA2RGB**: 将 RGBA 图像转换为 RGB 图像。 - - **ConvertMode.COLOR_BGR2RGBA**: 将 BGR 图像转换为 RGBA 图像。 - - **ConvertMode.COLOR_RGB2BGRA**: 将 RGB 图像转换为 BGRA 图像。 - - **ConvertMode.COLOR_RGBA2BGR**: 将 RGBA 图像转换为 BGR 图像。 - - **ConvertMode.COLOR_BGRA2RGB**: 将 BGRA 图像转换为 RGB 图像。 - - **ConvertMode.COLOR_BGR2RGB**: 将 BGR 图像转换为 RGB 图像。 - - **ConvertMode.COLOR_RGB2BGR**: 将 RGB 图像转换为 BGR 图像。 - - **ConvertMode.COLOR_BGRA2RGBA**: 将 BGRA 图像转换为 RGBA 图像。 - - **ConvertMode.COLOR_RGBA2BGRA**: 将 RGBA 图像转换为 BGRA 图像。 - - **ConvertMode.COLOR_BGR2GRAY**: 将 BGR 图像转换为 GRAY 图像。 - - **ConvertMode.COLOR_RGB2GRAY**: 将 RGB 图像转换为 GRAY 图像。 - - **ConvertMode.COLOR_GRAY2BGR**: 将 GRAY 图像转换为 BGR 图像。 - - **ConvertMode.COLOR_GRAY2RGB**: 将 GRAY 图像转换为 RGB 图像。 - - **ConvertMode.COLOR_GRAY2BGRA**: 将 GRAY 图像转换为 BGRA 图像。 - - **ConvertMode.COLOR_GRAY2RGBA**: 将 GRAY 图像转换为 RGBA 图像。 - - **ConvertMode.COLOR_BGRA2GRAY**: 将 BGRA 图像转换为 GRAY 图像。 - - **ConvertMode.COLOR_RGBA2GRAY**: 将 RGBA 图像转换为 GRAY 图像。 - - 异常: - - **TypeError** - 如果 `convert_mode` 不是类 :class:`mindspore.dataset.vision.ConvertMode` 的类型。 - - **RuntimeError** - 如果输入图像的shape不是 。 - - 教程样例: - - `视觉变换样例库 - `_ - - .. py:method:: device(device_target="CPU") - - 指定该变换执行的设备。 - - - 当执行设备是 Ascend 时,输入数据支持 `uint8` 或者 `float32` 类型,数据格式支持NHWC,Channels: [1, 3, 4], N只支持1。输入数据的高度限制范围为[4, 8192]、宽度限制范围为[6, 4096]。 - - 参数: - - **device_target** (str, 可选) - 算子将在指定的设备上运行。当前支持 ``CPU`` 和 ``Ascend`` 。默认值: ``CPU`` 。 - - 异常: - - **TypeError** - 当 `device_target` 的类型不为str。 - - **ValueError** - 当 `device_target` 的取值不为 ``CPU`` / ``Ascend`` 。 +mindspore.dataset.vision.ConvertColor +===================================== + +.. py:class:: mindspore.dataset.vision.ConvertColor(convert_mode) + + 更改图像的色彩空间。 + + 支持 Ascend 硬件加速,需要通过 `.device("Ascend")` 方式开启。 + + 参数: + - **convert_mode** (:class:`~.vision.ConvertMode`) - 图像色彩空间转换的模式。 + + - **ConvertMode.COLOR_BGR2BGRA**: 将 BGR 图像转换为 BGRA 图像。 + - **ConvertMode.COLOR_RGB2RGBA**: 将 RGB 图像转换为 RGBA 图像。 + - **ConvertMode.COLOR_BGRA2BGR**: 将 BGRA 图像转换为 BGR 图像。 + - **ConvertMode.COLOR_RGBA2RGB**: 将 RGBA 图像转换为 RGB 图像。 + - **ConvertMode.COLOR_BGR2RGBA**: 将 BGR 图像转换为 RGBA 图像。 + - **ConvertMode.COLOR_RGB2BGRA**: 将 RGB 图像转换为 BGRA 图像。 + - **ConvertMode.COLOR_RGBA2BGR**: 将 RGBA 图像转换为 BGR 图像。 + - **ConvertMode.COLOR_BGRA2RGB**: 将 BGRA 图像转换为 RGB 图像。 + - **ConvertMode.COLOR_BGR2RGB**: 将 BGR 图像转换为 RGB 图像。 + - **ConvertMode.COLOR_RGB2BGR**: 将 RGB 图像转换为 BGR 图像。 + - **ConvertMode.COLOR_BGRA2RGBA**: 将 BGRA 图像转换为 RGBA 图像。 + - **ConvertMode.COLOR_RGBA2BGRA**: 将 RGBA 图像转换为 BGRA 图像。 + - **ConvertMode.COLOR_BGR2GRAY**: 将 BGR 图像转换为 GRAY 图像。 + - **ConvertMode.COLOR_RGB2GRAY**: 将 RGB 图像转换为 GRAY 图像。 + - **ConvertMode.COLOR_GRAY2BGR**: 将 GRAY 图像转换为 BGR 图像。 + - **ConvertMode.COLOR_GRAY2RGB**: 将 GRAY 图像转换为 RGB 图像。 + - **ConvertMode.COLOR_GRAY2BGRA**: 将 GRAY 图像转换为 BGRA 图像。 + - **ConvertMode.COLOR_GRAY2RGBA**: 将 GRAY 图像转换为 RGBA 图像。 + - **ConvertMode.COLOR_BGRA2GRAY**: 将 BGRA 图像转换为 GRAY 图像。 + - **ConvertMode.COLOR_RGBA2GRAY**: 将 RGBA 图像转换为 GRAY 图像。 + + 异常: + - **TypeError** - 如果 `convert_mode` 不是类 :class:`mindspore.dataset.vision.ConvertMode` 的类型。 + - **RuntimeError** - 如果输入图像的shape不是 。 + + 教程样例: + - `视觉变换样例库 + `_ + + .. py:method:: device(device_target="CPU") + + 指定该变换执行的设备。 + + - 当执行设备是 Ascend 时,输入数据支持 `uint8` 或者 `float32` 类型,数据格式支持NHWC,Channels: [1, 3, 4], N只支持1。输入数据的高度限制范围为[4, 8192]、宽度限制范围为[6, 4096]。 + + 参数: + - **device_target** (str, 可选) - 算子将在指定的设备上运行。当前支持 ``CPU`` 和 ``Ascend`` 。默认值: ``CPU`` 。 + + 异常: + - **TypeError** - 当 `device_target` 的类型不为str。 + - **ValueError** - 当 `device_target` 的取值不为 ``CPU`` / ``Ascend`` 。 diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.Crop.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.Crop.rst index 7103e2a7b19..9de741408ba 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.Crop.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.Crop.rst @@ -1,39 +1,39 @@ -mindspore.dataset.vision.Crop -============================= - -.. py:class:: mindspore.dataset.vision.Crop(coordinates, size) - - 在输入图像上裁剪出指定区域。 - - 支持 Ascend 硬件加速,需要通过 `.device("Ascend")` 方式开启。 - - 参数: - - **coordinates** (sequence) - 裁剪区域的起始左上角坐标。必须是两个值的序列,形式为(上,左)。 - - **size** (Union[int, sequence]) - 裁剪区域的尺寸大小。 - 如果 `size` 是整数,则返回一个裁剪尺寸大小为 (size, size) 的正方形。 - 如果 `size` 是一个长度为 2 的序列,则以2个元素分别为高和宽放缩至(高度, 宽度)大小。 - 值必须大于 0。 - - 异常: - - **TypeError** - 如果 `coordinates` 不是sequence类型。 - - **TypeError** - 如果 `size` 不是int或sequence类型。 - - **ValueError** - 如果 `coordinates` 小于 0。 - - **ValueError** - 如果 `size` 小于或等于 0。 - - **RuntimeError** - 如果输入图像的shape不是 。 - - 教程样例: - - `视觉变换样例库 - `_ - - .. py:method:: device(device_target="CPU") - - 指定该变换执行的设备。 - - - 当执行设备是 Ascend 时,输入/输出数据的维度限制为[4, 6]和[32768, 32768]之间。 - - 参数: - - **device_target** (str, 可选) - 算子将在指定的设备上运行。当前支持 ``CPU`` 和 ``Ascend`` 。默认值: ``CPU`` 。 - - 异常: - - **TypeError** - 当 `device_target` 的类型不为str。 - - **ValueError** - 当 `device_target` 的取值不为 ``CPU`` / ``Ascend`` 。 +mindspore.dataset.vision.Crop +============================= + +.. py:class:: mindspore.dataset.vision.Crop(coordinates, size) + + 在输入图像上裁剪出指定区域。 + + 支持 Ascend 硬件加速,需要通过 `.device("Ascend")` 方式开启。 + + 参数: + - **coordinates** (sequence) - 裁剪区域的起始左上角坐标。必须是两个值的序列,形式为(上,左)。 + - **size** (Union[int, sequence]) - 裁剪区域的尺寸大小。 + 如果 `size` 是整数,则返回一个裁剪尺寸大小为 (size, size) 的正方形。 + 如果 `size` 是一个长度为 2 的序列,则以2个元素分别为高和宽放缩至(高度, 宽度)大小。 + 值必须大于 0。 + + 异常: + - **TypeError** - 如果 `coordinates` 不是sequence类型。 + - **TypeError** - 如果 `size` 不是int或sequence类型。 + - **ValueError** - 如果 `coordinates` 小于 0。 + - **ValueError** - 如果 `size` 小于或等于 0。 + - **RuntimeError** - 如果输入图像的shape不是 。 + + 教程样例: + - `视觉变换样例库 + `_ + + .. py:method:: device(device_target="CPU") + + 指定该变换执行的设备。 + + - 当执行设备是 Ascend 时,输入/输出数据的维度限制为[4, 6]和[32768, 32768]之间。 + + 参数: + - **device_target** (str, 可选) - 算子将在指定的设备上运行。当前支持 ``CPU`` 和 ``Ascend`` 。默认值: ``CPU`` 。 + + 异常: + - **TypeError** - 当 `device_target` 的类型不为str。 + - **ValueError** - 当 `device_target` 的取值不为 ``CPU`` / ``Ascend`` 。 diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.CutMixBatch.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.CutMixBatch.rst index ce8ab817b9b..0229dcdba46 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.CutMixBatch.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.CutMixBatch.rst @@ -1,24 +1,24 @@ -mindspore.dataset.vision.CutMixBatch -================================================= - -.. py:class:: mindspore.dataset.vision.CutMixBatch(image_batch_format, alpha=1.0, prob=1.0) - - 对输入批次的图像和标注应用剪切混合转换。 - 请注意,在调用此操作符之前,您需要将标注制作为 one-hot 格式并进行批处理。 - - 参数: - - **image_batch_format** (:class:`~.vision.ImageBatchFormat`) - 图像批处理输出格式。可以是 ``ImageBatchFormat.NHWC`` 或 ``ImageBatchFormat.NCHW`` 。 - - **alpha** (float, 可选) - β分布的超参数,必须大于0。默认值: ``1.0`` 。 - - **prob** (float, 可选) - 对每个图像应用剪切混合处理的概率,取值范围:[0.0, 1.0]。默认值: ``1.0`` 。 - - 异常: - - **TypeError** - 如果 `image_batch_format` 不是 :class:`mindspore.dataset.vision.ImageBatchFormat` 的类型。 - - **TypeError** - 如果 `alpha` 不是float类型。 - - **TypeError** - 如果 `prob` 不是 float 类型。 - - **ValueError** - 如果 `alpha` 小于或等于 0。 - - **ValueError** - 如果 `prob` 不在 [0.0, 1.0] 范围内。 - - **RuntimeError** - 如果输入图像的shape不是 。 - - 教程样例: - - `视觉变换样例库 - `_ +mindspore.dataset.vision.CutMixBatch +================================================= + +.. py:class:: mindspore.dataset.vision.CutMixBatch(image_batch_format, alpha=1.0, prob=1.0) + + 对输入批次的图像和标注应用剪切混合转换。 + 请注意,在调用此操作符之前,您需要将标注制作为 one-hot 格式并进行批处理。 + + 参数: + - **image_batch_format** (:class:`~.vision.ImageBatchFormat`) - 图像批处理输出格式。可以是 ``ImageBatchFormat.NHWC`` 或 ``ImageBatchFormat.NCHW`` 。 + - **alpha** (float, 可选) - β分布的超参数,必须大于0。默认值: ``1.0`` 。 + - **prob** (float, 可选) - 对每个图像应用剪切混合处理的概率,取值范围:[0.0, 1.0]。默认值: ``1.0`` 。 + + 异常: + - **TypeError** - 如果 `image_batch_format` 不是 :class:`mindspore.dataset.vision.ImageBatchFormat` 的类型。 + - **TypeError** - 如果 `alpha` 不是float类型。 + - **TypeError** - 如果 `prob` 不是 float 类型。 + - **ValueError** - 如果 `alpha` 小于或等于 0。 + - **ValueError** - 如果 `prob` 不在 [0.0, 1.0] 范围内。 + - **RuntimeError** - 如果输入图像的shape不是 。 + + 教程样例: + - `视觉变换样例库 + `_ diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.CutOut.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.CutOut.rst index fb4b28ec2ab..26e1e4255ac 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.CutOut.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.CutOut.rst @@ -1,23 +1,23 @@ -mindspore.dataset.vision.CutOut -============================================ - -.. py:class:: mindspore.dataset.vision.CutOut(length, num_patches=1, is_hwc=True) - - 从输入图像数组中随机裁剪出给定数量的正方形区域。 - - 参数: - - **length** (int) - 每个正方形区域的边长,必须大于 0。 - - **num_patches** (int, 可选) - 要从图像中切出的正方形区域数,必须大于0。默认值: ``1`` 。 - - **is_hwc** (bool, 可选) - 表示输入图像是否为HWC格式, ``True`` 为HWC格式, ``False`` 为CHW格式。默认值: ``True`` 。 - - 异常: - - **TypeError** - 如果 `length` 不是int类型。 - - **TypeError** - 如果 `num_patches` 不是int类型。 - - **TypeError** - 如果 `is_hwc` 不是bool类型。 - - **ValueError** - 如果 `length` 小于或等于 0。 - - **ValueError** - 如果 `num_patches` 小于或等于 0。 - - **RuntimeError** - 如果输入图像的shape不是 。 - - 教程样例: - - `视觉变换样例库 - `_ +mindspore.dataset.vision.CutOut +============================================ + +.. py:class:: mindspore.dataset.vision.CutOut(length, num_patches=1, is_hwc=True) + + 从输入图像数组中随机裁剪出给定数量的正方形区域。 + + 参数: + - **length** (int) - 每个正方形区域的边长,必须大于 0。 + - **num_patches** (int, 可选) - 要从图像中切出的正方形区域数,必须大于0。默认值: ``1`` 。 + - **is_hwc** (bool, 可选) - 表示输入图像是否为HWC格式, ``True`` 为HWC格式, ``False`` 为CHW格式。默认值: ``True`` 。 + + 异常: + - **TypeError** - 如果 `length` 不是int类型。 + - **TypeError** - 如果 `num_patches` 不是int类型。 + - **TypeError** - 如果 `is_hwc` 不是bool类型。 + - **ValueError** - 如果 `length` 小于或等于 0。 + - **ValueError** - 如果 `num_patches` 小于或等于 0。 + - **RuntimeError** - 如果输入图像的shape不是 。 + + 教程样例: + - `视觉变换样例库 + `_ diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.Decode.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.Decode.rst index 5bec886beba..3fce1fa87c1 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.Decode.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.Decode.rst @@ -1,33 +1,33 @@ -mindspore.dataset.vision.Decode -=============================== - -.. py:class:: mindspore.dataset.vision.Decode(to_pil=False) - - 将输入的压缩图像解码为RGB格式。当前支持的图片类型:JPEG、BMP、PNG、TIFF、GIF(需要指定 `to_pil=True`)、WEBP(需要指定 `to_pil=True`)。 - - 支持 Ascend 硬件加速,需要通过 `.device("Ascend")` 方式开启。 - - 参数: - - **to_pil** (bool,可选) - 是否将图像解码为PIL数据类型。若为 ``True`` ,图像将被解码为PIL数据类型,否则解码为NumPy数据类型。默认值: ``False`` 。 - - 异常: - - **RuntimeError** - 如果输入图像不是一维序列。 - - **RuntimeError** - 如果输入数据不是合法的图像字节数据。 - - **RuntimeError** - 如果输入数据已经是解码的图像数据。 - - 教程样例: - - `视觉变换样例库 - `_ - - .. py:method:: device(device_target="CPU") - - 指定该变换执行的设备。 - - - 当执行设备是 Ascend 时,输入数据仅支持 `uint8` 类型。 - - 参数: - - **device_target** (str, 可选) - 算子将在指定的设备上运行。当前支持 ``CPU`` 和 ``Ascend`` 。默认值: ``CPU`` 。 - - 异常: - - **TypeError** - 当 `device_target` 的类型不为str。 - - **ValueError** - 当 `device_target` 的取值不为 ``CPU`` / ``Ascend`` 。 +mindspore.dataset.vision.Decode +=============================== + +.. py:class:: mindspore.dataset.vision.Decode(to_pil=False) + + 将输入的压缩图像解码为RGB格式。当前支持的图片类型:JPEG、BMP、PNG、TIFF、GIF(需要指定 `to_pil=True`)、WEBP(需要指定 `to_pil=True`)。 + + 支持 Ascend 硬件加速,需要通过 `.device("Ascend")` 方式开启。 + + 参数: + - **to_pil** (bool,可选) - 是否将图像解码为PIL数据类型。若为 ``True`` ,图像将被解码为PIL数据类型,否则解码为NumPy数据类型。默认值: ``False`` 。 + + 异常: + - **RuntimeError** - 如果输入图像不是一维序列。 + - **RuntimeError** - 如果输入数据不是合法的图像字节数据。 + - **RuntimeError** - 如果输入数据已经是解码的图像数据。 + + 教程样例: + - `视觉变换样例库 + `_ + + .. py:method:: device(device_target="CPU") + + 指定该变换执行的设备。 + + - 当执行设备是 Ascend 时,输入数据仅支持 `uint8` 类型。 + + 参数: + - **device_target** (str, 可选) - 算子将在指定的设备上运行。当前支持 ``CPU`` 和 ``Ascend`` 。默认值: ``CPU`` 。 + + 异常: + - **TypeError** - 当 `device_target` 的类型不为str。 + - **ValueError** - 当 `device_target` 的取值不为 ``CPU`` / ``Ascend`` 。 diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.Equalize.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.Equalize.rst index 89917d6dc26..edd34282b6f 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.Equalize.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.Equalize.rst @@ -1,28 +1,28 @@ -mindspore.dataset.vision.Equalize -================================= - -.. py:class:: mindspore.dataset.vision.Equalize() - - 对输入图像进行直方图均衡化。 - - 支持 Ascend 硬件加速,需要通过 `.device("Ascend")` 方式开启。 - - 异常: - - **RuntimeError** - 如果输入图像的shape不是 。 - - 教程样例: - - `视觉变换样例库 - `_ - - .. py:method:: device(device_target="CPU") - - 指定该变换执行的设备。 - - - 当执行设备是 Ascend 时,输入数据仅支持 `uint8` 类型,输入数据的通道仅支持1和3。输入数据的高度限制范围为[4, 8192]、宽度限制范围为[6, 4096]。 - - 参数: - - **device_target** (str, 可选) - 算子将在指定的设备上运行。当前支持 ``CPU`` 和 ``Ascend`` 。默认值: ``CPU`` 。 - - 异常: - - **TypeError** - 当 `device_target` 的类型不为str。 - - **ValueError** - 当 `device_target` 的取值不为 ``CPU`` / ``Ascend`` 。 +mindspore.dataset.vision.Equalize +================================= + +.. py:class:: mindspore.dataset.vision.Equalize() + + 对输入图像进行直方图均衡化。 + + 支持 Ascend 硬件加速,需要通过 `.device("Ascend")` 方式开启。 + + 异常: + - **RuntimeError** - 如果输入图像的shape不是 。 + + 教程样例: + - `视觉变换样例库 + `_ + + .. py:method:: device(device_target="CPU") + + 指定该变换执行的设备。 + + - 当执行设备是 Ascend 时,输入数据仅支持 `uint8` 类型,输入数据的通道仅支持1和3。输入数据的高度限制范围为[4, 8192]、宽度限制范围为[6, 4096]。 + + 参数: + - **device_target** (str, 可选) - 算子将在指定的设备上运行。当前支持 ``CPU`` 和 ``Ascend`` 。默认值: ``CPU`` 。 + + 异常: + - **TypeError** - 当 `device_target` 的类型不为str。 + - **ValueError** - 当 `device_target` 的取值不为 ``CPU`` / ``Ascend`` 。 diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.GaussianBlur.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.GaussianBlur.rst index 2fb3f23a0cf..0a144f2b43f 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.GaussianBlur.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.GaussianBlur.rst @@ -1,41 +1,41 @@ -mindspore.dataset.vision.GaussianBlur -===================================== - -.. py:class:: mindspore.dataset.vision.GaussianBlur(kernel_size, sigma=None) - - 使用指定的高斯核对输入图像进行模糊处理。 - - 支持 Ascend 硬件加速,需要通过 `.device("Ascend")` 方式开启。 - - 参数: - - **kernel_size** (Union[int, Sequence[int, int]]) - 高斯核的大小。需为正奇数。 - 若输入类型为int,将同时使用该值作为高斯核的宽、高。 - 若输入类型为Sequence[int, int],将分别使用这两个元素作为高斯核的宽、高。 - - **sigma** (Union[float, Sequence[float, float]], 可选) - 高斯核的标准差。需为正数。 - 若输入类型为float,将同时使用该值作为高斯核宽、高的标准差。 - 若输入类型为Sequence[float, float],将分别使用这两个元素作为高斯核宽、高的标准差。 - 默认值: ``None`` ,将通过公式 :math:`((kernel\_size - 1) * 0.5 - 1) * 0.3 + 0.8` 计算得到高斯核的标准差。 - - 异常: - - **TypeError** - 如果 `kernel_size` 不是int或Sequence[int]类型。 - - **TypeError** - 如果 `sigma` 不是float或Sequence[float]类型。 - - **ValueError** - 如果 `kernel_size` 不是正数和奇数。 - - **ValueError** - 如果 `sigma` 不是正数。 - - **RuntimeError** - 如果输入图像的shape不是 。 - - 教程样例: - - `视觉变换样例库 - `_ - - .. py:method:: device(device_target="CPU") - - 指定该变换执行的设备。 - - - 当执行设备是 Ascend 时,参数 `kernel_size` 仅支持取值1、3、5。输入数据的维度限制为[4, 6]和[8192, 4096]之间。 - - 参数: - - **device_target** (str, 可选) - 算子将在指定的设备上运行。当前支持 ``CPU`` 和 ``Ascend`` 。默认值: ``CPU`` 。 - - 异常: - - **TypeError** - 当 `device_target` 的类型不为str。 - - **ValueError** - 当 `device_target` 的取值不为 ``CPU`` / ``Ascend`` 。 +mindspore.dataset.vision.GaussianBlur +===================================== + +.. py:class:: mindspore.dataset.vision.GaussianBlur(kernel_size, sigma=None) + + 使用指定的高斯核对输入图像进行模糊处理。 + + 支持 Ascend 硬件加速,需要通过 `.device("Ascend")` 方式开启。 + + 参数: + - **kernel_size** (Union[int, Sequence[int, int]]) - 高斯核的大小。需为正奇数。 + 若输入类型为int,将同时使用该值作为高斯核的宽、高。 + 若输入类型为Sequence[int, int],将分别使用这两个元素作为高斯核的宽、高。 + - **sigma** (Union[float, Sequence[float, float]], 可选) - 高斯核的标准差。需为正数。 + 若输入类型为float,将同时使用该值作为高斯核宽、高的标准差。 + 若输入类型为Sequence[float, float],将分别使用这两个元素作为高斯核宽、高的标准差。 + 默认值: ``None`` ,将通过公式 :math:`((kernel\_size - 1) * 0.5 - 1) * 0.3 + 0.8` 计算得到高斯核的标准差。 + + 异常: + - **TypeError** - 如果 `kernel_size` 不是int或Sequence[int]类型。 + - **TypeError** - 如果 `sigma` 不是float或Sequence[float]类型。 + - **ValueError** - 如果 `kernel_size` 不是正数和奇数。 + - **ValueError** - 如果 `sigma` 不是正数。 + - **RuntimeError** - 如果输入图像的shape不是 。 + + 教程样例: + - `视觉变换样例库 + `_ + + .. py:method:: device(device_target="CPU") + + 指定该变换执行的设备。 + + - 当执行设备是 Ascend 时,参数 `kernel_size` 仅支持取值1、3、5。输入数据的维度限制为[4, 6]和[8192, 4096]之间。 + + 参数: + - **device_target** (str, 可选) - 算子将在指定的设备上运行。当前支持 ``CPU`` 和 ``Ascend`` 。默认值: ``CPU`` 。 + + 异常: + - **TypeError** - 当 `device_target` 的类型不为str。 + - **ValueError** - 当 `device_target` 的取值不为 ``CPU`` / ``Ascend`` 。 diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.HWC2CHW.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.HWC2CHW.rst index 5f14b2f62ff..04f5becde25 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.HWC2CHW.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.HWC2CHW.rst @@ -1,16 +1,16 @@ -mindspore.dataset.vision.HWC2CHW -================================ - -.. py:class:: mindspore.dataset.vision.HWC2CHW() - - 将输入图像的shape从 转换为 。 - 如果输入图像的shape为 ,图像将保持不变。 - - .. note:: 此操作默认通过 CPU 执行,也支持异构加速到 GPU 或 Ascend 上执行。 - - 异常: - - **RuntimeError** - 如果输入图像的shape不是 。 - - 教程样例: - - `视觉变换样例库 - `_ +mindspore.dataset.vision.HWC2CHW +================================ + +.. py:class:: mindspore.dataset.vision.HWC2CHW() + + 将输入图像的shape从 转换为 。 + 如果输入图像的shape为 ,图像将保持不变。 + + .. note:: 此操作默认通过 CPU 执行,也支持异构加速到 GPU 或 Ascend 上执行。 + + 异常: + - **RuntimeError** - 如果输入图像的shape不是 。 + + 教程样例: + - `视觉变换样例库 + `_ diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.HorizontalFlip.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.HorizontalFlip.rst index ee839aed2d4..db596cf6cc1 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.HorizontalFlip.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.HorizontalFlip.rst @@ -1,28 +1,28 @@ -mindspore.dataset.vision.HorizontalFlip -======================================= - -.. py:class:: mindspore.dataset.vision.HorizontalFlip() - - 水平翻转输入图像。 - - 支持 Ascend 硬件加速,需要通过 `.device("Ascend")` 方式开启。 - - 异常: - - **RuntimeError** - 如果输入图像的shape不是 或 <..., H, W, C>。 - - 教程样例: - - `视觉变换样例库 - `_ - - .. py:method:: device(device_target="CPU") - - 指定该变换执行的设备。 - - - 当执行设备是 Ascend 时,输入数据支持 `uint8` 和 `float32` 类型,输入数据的通道仅支持 1和3。输入数据的高度限制范围为[4, 8192]、宽度限制范围为[6, 4096]。 - - 参数: - - **device_target** (str, 可选) - 算子将在指定的设备上运行。当前支持 ``CPU`` 和 ``Ascend`` 。默认值: ``CPU`` 。 - - 异常: - - **TypeError** - 当 `device_target` 的类型不为str。 - - **ValueError** - 当 `device_target` 的取值不为 ``CPU`` / ``Ascend`` 。 +mindspore.dataset.vision.HorizontalFlip +======================================= + +.. py:class:: mindspore.dataset.vision.HorizontalFlip() + + 水平翻转输入图像。 + + 支持 Ascend 硬件加速,需要通过 `.device("Ascend")` 方式开启。 + + 异常: + - **RuntimeError** - 如果输入图像的shape不是 或 <..., H, W, C>。 + + 教程样例: + - `视觉变换样例库 + `_ + + .. py:method:: device(device_target="CPU") + + 指定该变换执行的设备。 + + - 当执行设备是 Ascend 时,输入数据支持 `uint8` 和 `float32` 类型,输入数据的通道仅支持 1和3。输入数据的高度限制范围为[4, 8192]、宽度限制范围为[6, 4096]。 + + 参数: + - **device_target** (str, 可选) - 算子将在指定的设备上运行。当前支持 ``CPU`` 和 ``Ascend`` 。默认值: ``CPU`` 。 + + 异常: + - **TypeError** - 当 `device_target` 的类型不为str。 + - **ValueError** - 当 `device_target` 的取值不为 ``CPU`` / ``Ascend`` 。 diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.ImageBatchFormat.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.ImageBatchFormat.rst index 8cea10dee66..4d19285571a 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.ImageBatchFormat.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.ImageBatchFormat.rst @@ -1,11 +1,11 @@ -mindspore.dataset.vision.ImageBatchFormat -========================================= - -.. py:class:: mindspore.dataset.vision.ImageBatchFormat - - 图像批处理输出格式枚举类。 - - 可选枚举值为: ``ImageBatchFormat.NHWC`` 、 ``ImageBatchFormat.NCHW`` 。 - - - **ImageBatchFormat.NHWC** - 按批次N、高度H、宽度W、通道C的顺序存储数据。 - - **ImageBatchFormat.NCHW** - 按批次N、通道C、高度H、宽度W的顺序存储数据。 +mindspore.dataset.vision.ImageBatchFormat +========================================= + +.. py:class:: mindspore.dataset.vision.ImageBatchFormat + + 图像批处理输出格式枚举类。 + + 可选枚举值为: ``ImageBatchFormat.NHWC`` 、 ``ImageBatchFormat.NCHW`` 。 + + - **ImageBatchFormat.NHWC** - 按批次N、高度H、宽度W、通道C的顺序存储数据。 + - **ImageBatchFormat.NCHW** - 按批次N、通道C、高度H、宽度W的顺序存储数据。 diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.ImageReadMode.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.ImageReadMode.rst index 68b57793bb4..e56e3f03224 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.ImageReadMode.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.ImageReadMode.rst @@ -1,12 +1,12 @@ -mindspore.dataset.vision.ImageReadMode -====================================== - -.. py:class:: mindspore.dataset.vision.ImageReadMode - - 图像文件读取方式枚举类。 - - 可选枚举值为: ``ImageReadMode.UNCHANGED`` 、 ``ImageReadMode.GRAYSCALE`` 、 ``ImageReadMode.COLOR`` 。 - - - **ImageReadMode.UNCHANGED** - 按照图像原始格式读取。 - - **ImageReadMode.GRAYSCALE** - 读取并转为单通道灰度数据。 - - **ImageReadMode.COLOR** - 读取并换为3通道RGB彩色数据。 +mindspore.dataset.vision.ImageReadMode +====================================== + +.. py:class:: mindspore.dataset.vision.ImageReadMode + + 图像文件读取方式枚举类。 + + 可选枚举值为: ``ImageReadMode.UNCHANGED`` 、 ``ImageReadMode.GRAYSCALE`` 、 ``ImageReadMode.COLOR`` 。 + + - **ImageReadMode.UNCHANGED** - 按照图像原始格式读取。 + - **ImageReadMode.GRAYSCALE** - 读取并转为单通道灰度数据。 + - **ImageReadMode.COLOR** - 读取并换为3通道RGB彩色数据。 diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.Inter.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.Inter.rst index c8b9bef33fa..06bcf8e2dce 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.Inter.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.Inter.rst @@ -1,17 +1,17 @@ -mindspore.dataset.vision.Inter -============================== - -.. py:class:: mindspore.dataset.vision.Inter - - 图像插值方法枚举类。 - - 可选值如下: - - - **Inter.NEAREST** - 最近邻插值。 - - **Inter.ANTIALIAS** - 抗锯齿插值。仅当输入为 `PIL.Image.Image` 时支持。 - - **Inter.LINEAR** - 线性插值,实现同 ``Inter.BILINEAR`` 。 - - **Inter.BILINEAR** - 双线性插值。 - - **Inter.CUBIC** - 三次插值,实现同 ``Inter.BICUBIC`` 。 - - **Inter.BICUBIC** - 双三次插值。 - - **Inter.AREA** - 像素区域插值。仅当输入为 `numpy.ndarray` 时支持。 - - **Inter.PILCUBIC** - 类Pillow实现的双三次插值。仅当输入为 `numpy.ndarray` 时支持。 +mindspore.dataset.vision.Inter +============================== + +.. py:class:: mindspore.dataset.vision.Inter + + 图像插值方法枚举类。 + + 可选值如下: + + - **Inter.NEAREST** - 最近邻插值。 + - **Inter.ANTIALIAS** - 抗锯齿插值。仅当输入为 `PIL.Image.Image` 时支持。 + - **Inter.LINEAR** - 线性插值,实现同 ``Inter.BILINEAR`` 。 + - **Inter.BILINEAR** - 双线性插值。 + - **Inter.CUBIC** - 三次插值,实现同 ``Inter.BICUBIC`` 。 + - **Inter.BICUBIC** - 双三次插值。 + - **Inter.AREA** - 像素区域插值。仅当输入为 `numpy.ndarray` 时支持。 + - **Inter.PILCUBIC** - 类Pillow实现的双三次插值。仅当输入为 `numpy.ndarray` 时支持。 diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.Invert.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.Invert.rst index f3e67792a01..11b53d369c0 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.Invert.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.Invert.rst @@ -1,30 +1,30 @@ -mindspore.dataset.vision.Invert -=============================== - -.. py:class:: mindspore.dataset.vision.Invert() - - 对输入的RGB图像进行色彩反转。 - - 对于图像中的每个像素,若原像素值为 `pixel` ,则反转后的像素值为 `255 - pixel` 。 - - 支持 Ascend 硬件加速,需要通过 `.device("Ascend")` 方式开启。 - - 异常: - - **RuntimeError** - 如果输入图像的shape不是 。 - - 教程样例: - - `视觉变换样例库 - `_ - - .. py:method:: device(device_target="CPU") - - 指定该变换执行的设备。 - - - 当执行设备是 Ascend 时,输入数据仅支持 `uint8` 类型,输入数据的通道仅支持1和3。输入数据的高度限制范围为[4, 8192]、宽度限制范围为[6, 4096]。 - - 参数: - - **device_target** (str, 可选) - 算子将在指定的设备上运行。当前支持 ``CPU`` 和 ``Ascend`` 。默认值: ``CPU`` 。 - - 异常: - - **TypeError** - 当 `device_target` 的类型不为str。 - - **ValueError** - 当 `device_target` 的取值不为 ``CPU`` / ``Ascend`` 。 +mindspore.dataset.vision.Invert +=============================== + +.. py:class:: mindspore.dataset.vision.Invert() + + 对输入的RGB图像进行色彩反转。 + + 对于图像中的每个像素,若原像素值为 `pixel` ,则反转后的像素值为 `255 - pixel` 。 + + 支持 Ascend 硬件加速,需要通过 `.device("Ascend")` 方式开启。 + + 异常: + - **RuntimeError** - 如果输入图像的shape不是 。 + + 教程样例: + - `视觉变换样例库 + `_ + + .. py:method:: device(device_target="CPU") + + 指定该变换执行的设备。 + + - 当执行设备是 Ascend 时,输入数据仅支持 `uint8` 类型,输入数据的通道仅支持1和3。输入数据的高度限制范围为[4, 8192]、宽度限制范围为[6, 4096]。 + + 参数: + - **device_target** (str, 可选) - 算子将在指定的设备上运行。当前支持 ``CPU`` 和 ``Ascend`` 。默认值: ``CPU`` 。 + + 异常: + - **TypeError** - 当 `device_target` 的类型不为str。 + - **ValueError** - 当 `device_target` 的取值不为 ``CPU`` / ``Ascend`` 。 diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.MixUpBatch.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.MixUpBatch.rst index 1effd9e68aa..bb8110fc191 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.MixUpBatch.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.MixUpBatch.rst @@ -1,22 +1,22 @@ -mindspore.dataset.vision.MixUpBatch -=================================== - -.. py:class:: mindspore.dataset.vision.MixUpBatch(alpha=1.0) - - 对输入批次的图像和标注应用混合转换。从批处理中随机抽取两个图像,其中一个图像乘以随机权重 (lambda),另一个图像乘以 (1 - lambda),并相加。该处理将会同时应用于one-hot标注。 - - 上述的 lambda 是根据指定的参数 `alpha` 生成的。计算方式为在 [alpha, 1] 范围内随机生成两个系数 x1,x2 ,然后 lambda = (x1 / (x1 + x2))。 - - 请注意,在调用此处理之前,您需要将标注制作成 one-hot 格式并进行batch操作。 - - 参数: - - **alpha** (float, 可选) - β分布的超参数,该值必须为正。默认值: ``1.0`` 。 - - 异常: - - **TypeError** - 如果 `alpha` 不是float类型。 - - **ValueError** - 如果 `alpha` 不是正数。 - - **RuntimeError** - 如果输入图像的shape不是 。 - - 教程样例: - - `视觉变换样例库 - `_ +mindspore.dataset.vision.MixUpBatch +=================================== + +.. py:class:: mindspore.dataset.vision.MixUpBatch(alpha=1.0) + + 对输入批次的图像和标注应用混合转换。从批处理中随机抽取两个图像,其中一个图像乘以随机权重 (lambda),另一个图像乘以 (1 - lambda),并相加。该处理将会同时应用于one-hot标注。 + + 上述的 lambda 是根据指定的参数 `alpha` 生成的。计算方式为在 [alpha, 1] 范围内随机生成两个系数 x1,x2 ,然后 lambda = (x1 / (x1 + x2))。 + + 请注意,在调用此处理之前,您需要将标注制作成 one-hot 格式并进行batch操作。 + + 参数: + - **alpha** (float, 可选) - β分布的超参数,该值必须为正。默认值: ``1.0`` 。 + + 异常: + - **TypeError** - 如果 `alpha` 不是float类型。 + - **ValueError** - 如果 `alpha` 不是正数。 + - **RuntimeError** - 如果输入图像的shape不是 。 + + 教程样例: + - `视觉变换样例库 + `_ diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.Normalize.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.Normalize.rst index ae117ec68fe..86513e446c6 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.Normalize.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.Normalize.rst @@ -1,43 +1,43 @@ -mindspore.dataset.vision.Normalize -================================== - -.. py:class:: mindspore.dataset.vision.Normalize(mean, std, is_hwc=True) - - 根据均值和标准差对输入图像进行归一化。 - - 此处理将使用以下公式对输入图像进行归一化:output[channel] = (input[channel] - mean[channel]) / std[channel],其中 channel 代表通道索引,channel >= 1。 - - 支持 Ascend 硬件加速,需要通过 `.device("Ascend")` 方式开启。 - - .. note:: 此操作默认通过 CPU 执行,也支持异构加速到 GPU 或 Ascend 上执行。 - - 参数: - - **mean** (sequence) - 图像每个通道的均值组成的列表或元组。平均值必须在 [0.0, 255.0] 范围内。 - - **std** (sequence) - 图像每个通道的标准差组成的列表或元组。标准差值必须在 (0.0, 255.0] 范围内。 - - **is_hwc** (bool, 可选) - 表示输入图像是否为HWC格式, ``True`` 为HWC格式, ``False`` 为CHW格式。默认值: ``True`` 。 - - 异常: - - **TypeError** - 如果 `mean` 不是sequence类型。 - - **TypeError** - 如果 `std` 不是sequence类型。 - - **TypeError** - 如果 `is_hwc` 不是bool类型。 - - **ValueError** - 如果 `mean` 不在 [0.0, 255.0] 范围内。 - - **ValueError** - 如果 `std` 不在 (0.0, 255.0] 范围内。 - - **RuntimeError** - 如果给定的tensor format不是或<...,H, W, C>。 - - 教程样例: - - `视觉变换样例库 - `_ - - .. py:method:: device(device_target="CPU") - - 指定该变换执行的设备。 - - - 当执行设备是 CPU 时,输入数据支持 `uint8` 、 `float32` 或者 `float64` 类型,输入数据的通道支持 1/2/3 。 - - 当执行设备是 Ascend 时,输入数据支持 `uint8` 或者 `float32` 类型,输入数据的通道仅支持 1/3。输入数据的维度限制为[4, 6]和[8192, 4096]之间。 - - 参数: - - **device_target** (str, 可选) - 算子将在指定的设备上运行。当前支持 ``CPU`` 和 ``Ascend`` 。默认值: ``CPU`` 。 - - 异常: - - **TypeError** - 当 `device_target` 的类型不为str。 - - **ValueError** - 当 `device_target` 的取值不为 ``CPU`` / ``Ascend`` 。 +mindspore.dataset.vision.Normalize +================================== + +.. py:class:: mindspore.dataset.vision.Normalize(mean, std, is_hwc=True) + + 根据均值和标准差对输入图像进行归一化。 + + 此处理将使用以下公式对输入图像进行归一化:output[channel] = (input[channel] - mean[channel]) / std[channel],其中 channel 代表通道索引,channel >= 1。 + + 支持 Ascend 硬件加速,需要通过 `.device("Ascend")` 方式开启。 + + .. note:: 此操作默认通过 CPU 执行,也支持异构加速到 GPU 或 Ascend 上执行。 + + 参数: + - **mean** (sequence) - 图像每个通道的均值组成的列表或元组。平均值必须在 [0.0, 255.0] 范围内。 + - **std** (sequence) - 图像每个通道的标准差组成的列表或元组。标准差值必须在 (0.0, 255.0] 范围内。 + - **is_hwc** (bool, 可选) - 表示输入图像是否为HWC格式, ``True`` 为HWC格式, ``False`` 为CHW格式。默认值: ``True`` 。 + + 异常: + - **TypeError** - 如果 `mean` 不是sequence类型。 + - **TypeError** - 如果 `std` 不是sequence类型。 + - **TypeError** - 如果 `is_hwc` 不是bool类型。 + - **ValueError** - 如果 `mean` 不在 [0.0, 255.0] 范围内。 + - **ValueError** - 如果 `std` 不在 (0.0, 255.0] 范围内。 + - **RuntimeError** - 如果给定的tensor format不是或<...,H, W, C>。 + + 教程样例: + - `视觉变换样例库 + `_ + + .. py:method:: device(device_target="CPU") + + 指定该变换执行的设备。 + + - 当执行设备是 CPU 时,输入数据支持 `uint8` 、 `float32` 或者 `float64` 类型,输入数据的通道支持 1/2/3 。 + - 当执行设备是 Ascend 时,输入数据支持 `uint8` 或者 `float32` 类型,输入数据的通道仅支持 1/3。输入数据的维度限制为[4, 6]和[8192, 4096]之间。 + + 参数: + - **device_target** (str, 可选) - 算子将在指定的设备上运行。当前支持 ``CPU`` 和 ``Ascend`` 。默认值: ``CPU`` 。 + + 异常: + - **TypeError** - 当 `device_target` 的类型不为str。 + - **ValueError** - 当 `device_target` 的取值不为 ``CPU`` / ``Ascend`` 。 diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.NormalizePad.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.NormalizePad.rst index 52449e3d53a..2401c595e18 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.NormalizePad.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.NormalizePad.rst @@ -1,25 +1,25 @@ -mindspore.dataset.vision.NormalizePad -===================================== - -.. py:class:: mindspore.dataset.vision.NormalizePad(mean, std, dtype="float32", is_hwc=True) - - 根据均值和标准差对输入图像进行归一化,然后填充一个全零的额外通道。 - - 参数: - - **mean** (sequence) - 图像每个通道的均值组成的列表或元组。平均值必须在 (0.0, 255.0] 范围内。 - - **std** (sequence) - 图像每个通道的标准差组成的列表或元组。标准差值必须在 (0.0, 255.0] 范围内。 - - **dtype** (str, 可选) - 输出图像的数据类型。默认值: ``"float32"`` 。 - - **is_hwc** (bool, 可选) - 指定输入图像的格式,若为 ``True`` ,表示输入为 HW(C) 格式,否则为 CHW 格式。默认值: ``True`` 。 - - 异常: - - **TypeError** - 如果 `mean` 不是sequence类型。 - - **TypeError** - 如果 `std` 不是sequence类型。 - - **TypeError** - 如果 `dtype` 不是str类型。 - - **TypeError** - 如果 `is_hwc` 不是bool类型。 - - **ValueError** - 如果 `mean` 不在 [0.0, 255.0] 范围内。 - - **ValueError** - 如果 `std` 不在范围内 (0.0, 255.0]。 - - **RuntimeError** - 如果输入图像的shape不是 , 。 - - 教程样例: - - `视觉变换样例库 - `_ +mindspore.dataset.vision.NormalizePad +===================================== + +.. py:class:: mindspore.dataset.vision.NormalizePad(mean, std, dtype="float32", is_hwc=True) + + 根据均值和标准差对输入图像进行归一化,然后填充一个全零的额外通道。 + + 参数: + - **mean** (sequence) - 图像每个通道的均值组成的列表或元组。平均值必须在 (0.0, 255.0] 范围内。 + - **std** (sequence) - 图像每个通道的标准差组成的列表或元组。标准差值必须在 (0.0, 255.0] 范围内。 + - **dtype** (str, 可选) - 输出图像的数据类型。默认值: ``"float32"`` 。 + - **is_hwc** (bool, 可选) - 指定输入图像的格式,若为 ``True`` ,表示输入为 HW(C) 格式,否则为 CHW 格式。默认值: ``True`` 。 + + 异常: + - **TypeError** - 如果 `mean` 不是sequence类型。 + - **TypeError** - 如果 `std` 不是sequence类型。 + - **TypeError** - 如果 `dtype` 不是str类型。 + - **TypeError** - 如果 `is_hwc` 不是bool类型。 + - **ValueError** - 如果 `mean` 不在 [0.0, 255.0] 范围内。 + - **ValueError** - 如果 `std` 不在范围内 (0.0, 255.0]。 + - **RuntimeError** - 如果输入图像的shape不是 , 。 + + 教程样例: + - `视觉变换样例库 + `_ diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.Pad.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.Pad.rst index 6acee6f071a..1e22b958da2 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.Pad.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.Pad.rst @@ -1,50 +1,50 @@ -mindspore.dataset.vision.Pad -============================ - -.. py:class:: mindspore.dataset.vision.Pad(padding, fill_value=0, padding_mode=Border.CONSTANT) - - 填充图像。 - - 支持 Ascend 硬件加速,需要通过 `.device("Ascend")` 方式开启。 - - 参数: - - **padding** (Union[int, Sequence[int, int], Sequence[int, int, int, int]]) - 图像各边填充的像素数。 - 如果 `padding` 是一个整数,代表为图像的所有方向填充该值大小的像素。 - 如果 `padding` 是一个包含2个值的元组或列表,第一个值会用于填充图像的左侧和右侧,第二个值会用于填充图像的上侧和下侧。 - 如果 `padding` 是一个包含4个值的元组或列表,则分别填充图像的左侧、上侧、右侧和下侧。 - 填充值必须为非负值。 - - **fill_value** (Union[int, tuple[int]], 可选) - 填充的像素值,仅在 `padding_mode` 取值为 ``Border.CONSTANT`` 时有效。 - 如果是3元素元组,则分别用于填充R、G、B通道。 - 如果是整数,则用于所有 RGB 通道。 - `fill_value` 值必须在 [0, 255] 范围内。默认值: ``0`` 。 - - **padding_mode** (:class:`~.vision.Border`, 可选) - 边界填充方式。可以是 ``Border.CONSTANT`` 、 ``Border.EDGE`` 、 ``Border.REFLECT`` 、 ``Border.SYMMETRIC`` 。默认值: ``Border.CONSTANT`` 。 - - - **Border.CONSTANT** - 使用常量值进行填充。 - - **Border.EDGE** - 使用各边的边界像素值进行填充。 - - **Border.REFLECT** - 以各边的边界为轴进行镜像填充,忽略边界像素值。 - - **Border.SYMMETRIC** - 以各边的边界为轴进行对称填充,包括边界像素值。 - - 异常: - - **TypeError** - 如果 `padding` 不是int或Sequence[int, int], Sequence[int, int, int, int]类型。 - - **TypeError** - 如果 `fill_value` 不是int或tuple[int]类型。 - - **TypeError** - 如果 `padding_mode` 不是 :class:`mindspore.dataset.vision.Border` 的类型。 - - **ValueError** - 如果 `padding` 为负数。 - - **ValueError** - 如果 `fill_value` 不在 [0, 255] 范围内。 - - **RuntimeError** - 如果输入图像的shape不是 。 - - 教程样例: - - `视觉变换样例库 - `_ - - .. py:method:: device(device_target="CPU") - - 指定该变换执行的设备。 - - - 当执行设备是 Ascend 时,输入/输出数据的维度限制为[4, 6]和[32768, 32768]之间。 - - 参数: - - **device_target** (str, 可选) - 算子将在指定的设备上运行。当前支持 ``CPU`` 和 ``Ascend`` 。默认值: ``CPU`` 。 - - 异常: - - **TypeError** - 当 `device_target` 的类型不为str。 - - **ValueError** - 当 `device_target` 的取值不为 ``CPU`` / ``Ascend`` 。 +mindspore.dataset.vision.Pad +============================ + +.. py:class:: mindspore.dataset.vision.Pad(padding, fill_value=0, padding_mode=Border.CONSTANT) + + 填充图像。 + + 支持 Ascend 硬件加速,需要通过 `.device("Ascend")` 方式开启。 + + 参数: + - **padding** (Union[int, Sequence[int, int], Sequence[int, int, int, int]]) - 图像各边填充的像素数。 + 如果 `padding` 是一个整数,代表为图像的所有方向填充该值大小的像素。 + 如果 `padding` 是一个包含2个值的元组或列表,第一个值会用于填充图像的左侧和右侧,第二个值会用于填充图像的上侧和下侧。 + 如果 `padding` 是一个包含4个值的元组或列表,则分别填充图像的左侧、上侧、右侧和下侧。 + 填充值必须为非负值。 + - **fill_value** (Union[int, tuple[int]], 可选) - 填充的像素值,仅在 `padding_mode` 取值为 ``Border.CONSTANT`` 时有效。 + 如果是3元素元组,则分别用于填充R、G、B通道。 + 如果是整数,则用于所有 RGB 通道。 + `fill_value` 值必须在 [0, 255] 范围内。默认值: ``0`` 。 + - **padding_mode** (:class:`~.vision.Border`, 可选) - 边界填充方式。可以是 ``Border.CONSTANT`` 、 ``Border.EDGE`` 、 ``Border.REFLECT`` 、 ``Border.SYMMETRIC`` 。默认值: ``Border.CONSTANT`` 。 + + - **Border.CONSTANT** - 使用常量值进行填充。 + - **Border.EDGE** - 使用各边的边界像素值进行填充。 + - **Border.REFLECT** - 以各边的边界为轴进行镜像填充,忽略边界像素值。 + - **Border.SYMMETRIC** - 以各边的边界为轴进行对称填充,包括边界像素值。 + + 异常: + - **TypeError** - 如果 `padding` 不是int或Sequence[int, int], Sequence[int, int, int, int]类型。 + - **TypeError** - 如果 `fill_value` 不是int或tuple[int]类型。 + - **TypeError** - 如果 `padding_mode` 不是 :class:`mindspore.dataset.vision.Border` 的类型。 + - **ValueError** - 如果 `padding` 为负数。 + - **ValueError** - 如果 `fill_value` 不在 [0, 255] 范围内。 + - **RuntimeError** - 如果输入图像的shape不是 。 + + 教程样例: + - `视觉变换样例库 + `_ + + .. py:method:: device(device_target="CPU") + + 指定该变换执行的设备。 + + - 当执行设备是 Ascend 时,输入/输出数据的维度限制为[4, 6]和[32768, 32768]之间。 + + 参数: + - **device_target** (str, 可选) - 算子将在指定的设备上运行。当前支持 ``CPU`` 和 ``Ascend`` 。默认值: ``CPU`` 。 + + 异常: + - **TypeError** - 当 `device_target` 的类型不为str。 + - **ValueError** - 当 `device_target` 的取值不为 ``CPU`` / ``Ascend`` 。 diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.PadToSize.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.PadToSize.rst index 9700e8146da..6b2401ac14e 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.PadToSize.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.PadToSize.rst @@ -1,38 +1,38 @@ -mindspore.dataset.vision.PadToSize -================================== - -.. py:class:: mindspore.dataset.vision.PadToSize(size, offset=None, fill_value=0, padding_mode=Border.CONSTANT) - - 将图像填充到固定大小。 - - 参数: - - **size** (Union[int, Sequence[int, int]]) - 要填充的目标大小。 - 若输入整型,则将图像填充为(size, size)大小;如果提供了序列[int, int],则将图像填充为(高度, 宽度)大小。 - - **offset** (Union[int, Sequence[int, int]], 可选) - 顶部和左侧要填充的长度。 - 如果输入整型,使用此值填充图像上侧和左侧。 - 如果提供了序列[int, int],则应按[top, left]的顺序排列,填充图像上侧和左侧。 - 默认值: ``None`` ,表示对称填充,保持原始图像处于中心位置。 - - **fill_value** (Union[int, tuple[int, int, int]], 可选) - 填充的像素值,仅在 `padding_mode` 取值为 ``Border.CONSTANT`` 时有效。 - 如果是3元素元组,则分别用于填充R、G、B通道。 - 如果是整数,则用于所有 RGB 通道。 - `fill_value` 值必须在 [0, 255] 范围内。默认值: ``0`` 。 - - **padding_mode** (:class:`~.vision.Border`, 可选) - 边界填充方式。可以是 ``Border.CONSTANT`` 、 ``Border.EDGE`` 、 ``Border.REFLECT`` 、 ``Border.SYMMETRIC`` 。默认值: ``Border.CONSTANT`` 。 - - - **Border.CONSTANT** - 使用常量值进行填充。 - - **Border.EDGE** - 使用各边的边界像素值进行填充。 - - **Border.REFLECT** - 以各边的边界为轴进行镜像填充,忽略边界像素值。 - - **Border.SYMMETRIC** - 以各边的边界为轴进行对称填充,包括边界像素值。 - - 异常: - - **TypeError** - 如果 `size` 不是int或tuple[int, int]类型。 - - **TypeError** - 如果 `offset` 不是int或tupl[int, int]类型。 - - **TypeError** - 如果 `fill_value` 不是int或tuple[int]类型。 - - **TypeError** - 如果 `padding_mode` 不是 :class:`mindspore.dataset.vision.Border` 的类型。 - - **ValueError** - 如果 `size` 不是正数。 - - **ValueError** - 如果 `offset` 为负数。 - - **ValueError** - 如果 `fill_value` 不在[0, 255]的范围内。 - - **RuntimeError** - 如果输入图像的形状不是。 - - 教程样例: - - `视觉变换样例库 - `_ +mindspore.dataset.vision.PadToSize +================================== + +.. py:class:: mindspore.dataset.vision.PadToSize(size, offset=None, fill_value=0, padding_mode=Border.CONSTANT) + + 将图像填充到固定大小。 + + 参数: + - **size** (Union[int, Sequence[int, int]]) - 要填充的目标大小。 + 若输入整型,则将图像填充为(size, size)大小;如果提供了序列[int, int],则将图像填充为(高度, 宽度)大小。 + - **offset** (Union[int, Sequence[int, int]], 可选) - 顶部和左侧要填充的长度。 + 如果输入整型,使用此值填充图像上侧和左侧。 + 如果提供了序列[int, int],则应按[top, left]的顺序排列,填充图像上侧和左侧。 + 默认值: ``None`` ,表示对称填充,保持原始图像处于中心位置。 + - **fill_value** (Union[int, tuple[int, int, int]], 可选) - 填充的像素值,仅在 `padding_mode` 取值为 ``Border.CONSTANT`` 时有效。 + 如果是3元素元组,则分别用于填充R、G、B通道。 + 如果是整数,则用于所有 RGB 通道。 + `fill_value` 值必须在 [0, 255] 范围内。默认值: ``0`` 。 + - **padding_mode** (:class:`~.vision.Border`, 可选) - 边界填充方式。可以是 ``Border.CONSTANT`` 、 ``Border.EDGE`` 、 ``Border.REFLECT`` 、 ``Border.SYMMETRIC`` 。默认值: ``Border.CONSTANT`` 。 + + - **Border.CONSTANT** - 使用常量值进行填充。 + - **Border.EDGE** - 使用各边的边界像素值进行填充。 + - **Border.REFLECT** - 以各边的边界为轴进行镜像填充,忽略边界像素值。 + - **Border.SYMMETRIC** - 以各边的边界为轴进行对称填充,包括边界像素值。 + + 异常: + - **TypeError** - 如果 `size` 不是int或tuple[int, int]类型。 + - **TypeError** - 如果 `offset` 不是int或tupl[int, int]类型。 + - **TypeError** - 如果 `fill_value` 不是int或tuple[int]类型。 + - **TypeError** - 如果 `padding_mode` 不是 :class:`mindspore.dataset.vision.Border` 的类型。 + - **ValueError** - 如果 `size` 不是正数。 + - **ValueError** - 如果 `offset` 为负数。 + - **ValueError** - 如果 `fill_value` 不在[0, 255]的范围内。 + - **RuntimeError** - 如果输入图像的形状不是。 + + 教程样例: + - `视觉变换样例库 + `_ diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomAffine.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomAffine.rst index dd493e7566d..95d182246de 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomAffine.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomAffine.rst @@ -1,43 +1,43 @@ -mindspore.dataset.vision.RandomAffine -===================================== - -.. py:class:: mindspore.dataset.vision.RandomAffine(degrees, translate=None, scale=None, shear=None, resample=Inter.NEAREST, fill_value=0) - - 对输入图像应用随机仿射变换。 - - 参数: - - **degrees** (Union[int, float, sequence]) - 旋转度数的范围。 - 如果 `degrees` 是一个数字,它代表旋转范围是(-degrees, degrees)。 - 如果 `degrees` 是一个序列,它代表旋转是 (min, max)。 - - **translate** (sequence, 可选) - 一个序列(tx_min, tx_max, ty_min, ty_max)用于表示水平(tx)方向和垂直(ty)方向的最小/最大平移范围,取值范围 [-1.0, 1.0]。默认值: ``None`` 。 - 水平和垂直偏移分别从以下范围中随机选择:(tx_min*width, tx_max*width) 和 (ty_min*height, ty_max*height)。 - 如果 `translate` 是一个包含2个值的元组或列表,则 (translate[0], translate[1]) 表示水平(X)方向的随机平移范围。 - 如果 `translate` 是一个包含4个值的元组或列表,则 (translate[0], translate[1]) 表示水平(X)方向的随机平移范围,(translate[2], translate[3])表示垂直(Y)方向的随机平移范围。 - 如果为None,则不对图像进行任何平移。 - - **scale** (sequence, 可选) - 图像的比例因子的随机范围,必须为非负数,使用原始比例。默认值: ``None`` 。 - - **shear** (Union[float, Sequence[float, float], Sequence[float, float, float, float]], 可选) - 图像的剪切因子的随机范围,必须为正数。默认值: ``None`` 。 - 如果是数字,则应用在 (-shear, +shear) 范围内平行于 X 轴的剪切。 - 如果 `shear` 是一个包含2个值的元组或列表,则在 (shear[0],shear[1]) 范围内进行水平(X)方向的剪切变换。 - 如果 `shear` 是一个包含4个值的元组或列表,则在 (shear[0],shear[1]) 范围内进行水平(X)方向的剪切变换,并在(shear[2], shear[3])范围内进行垂直(Y)方向的剪切变换。 - 如果为None,则不应用任何剪切。 - - **resample** (:class:`~.vision.Inter`, 可选) - 图像插值方法。可选值详见 :class:`mindspore.dataset.vision.Inter` 。 - 默认值: ``Inter.NEAREST``。 - - - **fill_value** (Union[int, tuple[int]], 可选) - 用于填充输出图像中变换之外的区域。元组中必须有三个值,取值范围是[0, 255]。默认值: ``0`` 。 - - 异常: - - **TypeError** - 如果 `degrees` 不是int、float或sequence类型。 - - **TypeError** - 如果 `translate` 不是sequence类型。 - - **TypeError** - 如果 `scale` 不是sequence类型。 - - **TypeError** - 如果 `shear` 不是int、float或sequence类型。 - - **TypeError** - 如果 `resample` 不是 :class:`mindspore.dataset.vision.Inter` 的类型。 - - **TypeError** - 如果 `fill_value` 不是int或tuple[int]类型。 - - **ValueError** - 如果 `degrees` 为负数。 - - **ValueError** - 如果 `translate` 不在范围 [-1.0, 1.0] 内。 - - **ValueError** - 如果 `scale` 为负数。 - - **ValueError** - 如果 `shear` 不是正数。 - - **RuntimeError** - 如果输入图像的shape不是 。 - - 教程样例: - - `视觉变换样例库 - `_ +mindspore.dataset.vision.RandomAffine +===================================== + +.. py:class:: mindspore.dataset.vision.RandomAffine(degrees, translate=None, scale=None, shear=None, resample=Inter.NEAREST, fill_value=0) + + 对输入图像应用随机仿射变换。 + + 参数: + - **degrees** (Union[int, float, sequence]) - 旋转度数的范围。 + 如果 `degrees` 是一个数字,它代表旋转范围是(-degrees, degrees)。 + 如果 `degrees` 是一个序列,它代表旋转是 (min, max)。 + - **translate** (sequence, 可选) - 一个序列(tx_min, tx_max, ty_min, ty_max)用于表示水平(tx)方向和垂直(ty)方向的最小/最大平移范围,取值范围 [-1.0, 1.0]。默认值: ``None`` 。 + 水平和垂直偏移分别从以下范围中随机选择:(tx_min*width, tx_max*width) 和 (ty_min*height, ty_max*height)。 + 如果 `translate` 是一个包含2个值的元组或列表,则 (translate[0], translate[1]) 表示水平(X)方向的随机平移范围。 + 如果 `translate` 是一个包含4个值的元组或列表,则 (translate[0], translate[1]) 表示水平(X)方向的随机平移范围,(translate[2], translate[3])表示垂直(Y)方向的随机平移范围。 + 如果为None,则不对图像进行任何平移。 + - **scale** (sequence, 可选) - 图像的比例因子的随机范围,必须为非负数,使用原始比例。默认值: ``None`` 。 + - **shear** (Union[float, Sequence[float, float], Sequence[float, float, float, float]], 可选) - 图像的剪切因子的随机范围,必须为正数。默认值: ``None`` 。 + 如果是数字,则应用在 (-shear, +shear) 范围内平行于 X 轴的剪切。 + 如果 `shear` 是一个包含2个值的元组或列表,则在 (shear[0],shear[1]) 范围内进行水平(X)方向的剪切变换。 + 如果 `shear` 是一个包含4个值的元组或列表,则在 (shear[0],shear[1]) 范围内进行水平(X)方向的剪切变换,并在(shear[2], shear[3])范围内进行垂直(Y)方向的剪切变换。 + 如果为None,则不应用任何剪切。 + - **resample** (:class:`~.vision.Inter`, 可选) - 图像插值方法。可选值详见 :class:`mindspore.dataset.vision.Inter` 。 + 默认值: ``Inter.NEAREST``。 + + - **fill_value** (Union[int, tuple[int]], 可选) - 用于填充输出图像中变换之外的区域。元组中必须有三个值,取值范围是[0, 255]。默认值: ``0`` 。 + + 异常: + - **TypeError** - 如果 `degrees` 不是int、float或sequence类型。 + - **TypeError** - 如果 `translate` 不是sequence类型。 + - **TypeError** - 如果 `scale` 不是sequence类型。 + - **TypeError** - 如果 `shear` 不是int、float或sequence类型。 + - **TypeError** - 如果 `resample` 不是 :class:`mindspore.dataset.vision.Inter` 的类型。 + - **TypeError** - 如果 `fill_value` 不是int或tuple[int]类型。 + - **ValueError** - 如果 `degrees` 为负数。 + - **ValueError** - 如果 `translate` 不在范围 [-1.0, 1.0] 内。 + - **ValueError** - 如果 `scale` 为负数。 + - **ValueError** - 如果 `shear` 不是正数。 + - **RuntimeError** - 如果输入图像的shape不是 。 + + 教程样例: + - `视觉变换样例库 + `_ diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomAutoContrast.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomAutoContrast.rst index f3093192e7f..aeafd3ad2d5 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomAutoContrast.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomAutoContrast.rst @@ -1,24 +1,24 @@ -mindspore.dataset.vision.RandomAutoContrast -=========================================== - -.. py:class:: mindspore.dataset.vision.RandomAutoContrast(cutoff=0.0, ignore=None, prob=0.5) - - 以给定的概率自动调整图像的对比度。 - - 参数: - - **cutoff** (float, 可选) - 输入图像直方图中需要剔除的最亮和最暗像素的百分比。该值必须在 [0.0, 50.0) 范围内。默认值: ``0.0`` 。 - - **ignore** (Union[int, sequence], 可选) - 要忽略的背景像素值,该值必须在 [0, 255] 范围内。默认值: ``None`` 。 - - **prob** (float, 可选) - 图像被调整对比度的概率,取值范围:[0.0, 1.0]。默认值: ``0.5`` 。 - - 异常: - - **TypeError** - 如果 `cutoff` 不是float类型。 - - **TypeError** - 如果 `ignore` 不是int或sequence类型。 - - **TypeError** - 如果 `prob` 的类型不为float。 - - **ValueError** - 如果 `cutoff` 不在[0, 50.0) 范围内。 - - **ValueError** - 如果 `ignore` 不在[0, 255] 范围内。 - - **ValueError** - 如果 `prob` 不在 [0.0, 1.0] 范围。 - - **RuntimeError** - 如果输入图像的shape不是 。 - - 教程样例: - - `视觉变换样例库 - `_ +mindspore.dataset.vision.RandomAutoContrast +=========================================== + +.. py:class:: mindspore.dataset.vision.RandomAutoContrast(cutoff=0.0, ignore=None, prob=0.5) + + 以给定的概率自动调整图像的对比度。 + + 参数: + - **cutoff** (float, 可选) - 输入图像直方图中需要剔除的最亮和最暗像素的百分比。该值必须在 [0.0, 50.0) 范围内。默认值: ``0.0`` 。 + - **ignore** (Union[int, sequence], 可选) - 要忽略的背景像素值,该值必须在 [0, 255] 范围内。默认值: ``None`` 。 + - **prob** (float, 可选) - 图像被调整对比度的概率,取值范围:[0.0, 1.0]。默认值: ``0.5`` 。 + + 异常: + - **TypeError** - 如果 `cutoff` 不是float类型。 + - **TypeError** - 如果 `ignore` 不是int或sequence类型。 + - **TypeError** - 如果 `prob` 的类型不为float。 + - **ValueError** - 如果 `cutoff` 不在[0, 50.0) 范围内。 + - **ValueError** - 如果 `ignore` 不在[0, 255] 范围内。 + - **ValueError** - 如果 `prob` 不在 [0.0, 1.0] 范围。 + - **RuntimeError** - 如果输入图像的shape不是 。 + + 教程样例: + - `视觉变换样例库 + `_ diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomColor.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomColor.rst index cb5afe50dac..0f7447a7d4a 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomColor.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomColor.rst @@ -1,19 +1,19 @@ -mindspore.dataset.vision.RandomColor -==================================== - -.. py:class:: mindspore.dataset.vision.RandomColor(degrees=(0.1, 1.9)) - - 随机调整输入图像的颜色。此操作仅适用于 3 通道RGB图像。 - - 参数: - - **degrees** (Sequence[float], 可选) - 色彩调节系数的范围,必须为非负数。它应该是(min, max)格式。 - 如果min与max相等,则代表色彩变化步长固定。默认值: ``(0.1, 1.9)`` 。 - - 异常: - - **TypeError** - 如果 `degrees` 不是Sequence[float]类型。 - - **ValueError** - 如果 `degrees` 为负数。 - - **RuntimeError** - 如果输入图像的shape不是 。 - - 教程样例: - - `视觉变换样例库 - `_ +mindspore.dataset.vision.RandomColor +==================================== + +.. py:class:: mindspore.dataset.vision.RandomColor(degrees=(0.1, 1.9)) + + 随机调整输入图像的颜色。此操作仅适用于 3 通道RGB图像。 + + 参数: + - **degrees** (Sequence[float], 可选) - 色彩调节系数的范围,必须为非负数。它应该是(min, max)格式。 + 如果min与max相等,则代表色彩变化步长固定。默认值: ``(0.1, 1.9)`` 。 + + 异常: + - **TypeError** - 如果 `degrees` 不是Sequence[float]类型。 + - **ValueError** - 如果 `degrees` 为负数。 + - **RuntimeError** - 如果输入图像的shape不是 。 + + 教程样例: + - `视觉变换样例库 + `_ diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomColorAdjust.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomColorAdjust.rst index 0352e8dab79..257885604f1 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomColorAdjust.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomColorAdjust.rst @@ -1,37 +1,37 @@ -mindspore.dataset.vision.RandomColorAdjust -========================================== - -.. py:class:: mindspore.dataset.vision.RandomColorAdjust(brightness=(1, 1), contrast=(1, 1), saturation=(1, 1), hue=(0, 0)) - - 随机调整输入图像的亮度、对比度、饱和度和色调。 - - .. note:: 此操作默认通过 CPU 执行,也支持异构加速到 GPU 或 Ascend 上执行。 - - 参数: - - **brightness** (Union[float, Sequence[float]], 可选) - 亮度调整因子。不能为负。默认值: ``(1, 1)`` 。 - 如果是浮点数,则从 [max(0, 1-brightness), 1+brightness] 范围内统一选择因子。 - 如果它是一个序列,则代表是范围 [min, max],从此范围中选择调整因子。 - - **contrast** (Union[float, Sequence[float]], 可选) - 对比度调整因子。不能为负。默认值: ``(1, 1)`` 。 - 如果是浮点数,则从 [max(0, 1-contrast), 1+contrast] 范围内统一选择因子。 - 如果它是一个序列,则代表是范围 [min, max],从此范围中选择调整因子。 - - **saturation** (Union[float, Sequence[float]], 可选) - 饱和度调整因子。不能为负。默认值: ``(1, 1)`` 。 - 如果是浮点数,则从 [max(0, 1-saturation), 1+saturation] 范围内统一选择因子。 - 如果它是一个序列,则代表是范围 [min, max],从此范围中选择调整因子。 - - **hue** (Union[float, Sequence[float]], 可选) - 色调调整因子。默认值: ``(0, 0)`` 。 - 如果是浮点数,则代表是范围 [-hue, hue],从此范围中选择调整因子。注意 `hue` 取值应为[0, 0.5]。 - 如果它是一个序列,则代表是范围 [min, max],从此范围中选择调整因子。注意取值范围min和max是 [-0.5, 0.5] 范围内的浮点数,并且min小于等于max。 - - 异常: - - **TypeError** - 如果 `brightness` 不是float或Sequence[float]类型。 - - **TypeError** - 如果 `contrast` 不是float或Sequence[float]类型。 - - **TypeError** - 如果 `saturation` 不是float或Sequence[float]类型。 - - **TypeError** - 如果 `hue` 不是float或Sequence[float]类型。 - - **ValueError** - 如果 `brightness` 为负数。 - - **ValueError** - 如果 `contrast` 为负数。 - - **ValueError** - 如果 `saturation` 为负数。 - - **ValueError** - 如果 `hue` 不在 [-0.5, 0.5] 范围内。 - - **RuntimeError** - 如果输入图像的shape不是 。 - - 教程样例: - - `视觉变换样例库 - `_ +mindspore.dataset.vision.RandomColorAdjust +========================================== + +.. py:class:: mindspore.dataset.vision.RandomColorAdjust(brightness=(1, 1), contrast=(1, 1), saturation=(1, 1), hue=(0, 0)) + + 随机调整输入图像的亮度、对比度、饱和度和色调。 + + .. note:: 此操作默认通过 CPU 执行,也支持异构加速到 GPU 或 Ascend 上执行。 + + 参数: + - **brightness** (Union[float, Sequence[float]], 可选) - 亮度调整因子。不能为负。默认值: ``(1, 1)`` 。 + 如果是浮点数,则从 [max(0, 1-brightness), 1+brightness] 范围内统一选择因子。 + 如果它是一个序列,则代表是范围 [min, max],从此范围中选择调整因子。 + - **contrast** (Union[float, Sequence[float]], 可选) - 对比度调整因子。不能为负。默认值: ``(1, 1)`` 。 + 如果是浮点数,则从 [max(0, 1-contrast), 1+contrast] 范围内统一选择因子。 + 如果它是一个序列,则代表是范围 [min, max],从此范围中选择调整因子。 + - **saturation** (Union[float, Sequence[float]], 可选) - 饱和度调整因子。不能为负。默认值: ``(1, 1)`` 。 + 如果是浮点数,则从 [max(0, 1-saturation), 1+saturation] 范围内统一选择因子。 + 如果它是一个序列,则代表是范围 [min, max],从此范围中选择调整因子。 + - **hue** (Union[float, Sequence[float]], 可选) - 色调调整因子。默认值: ``(0, 0)`` 。 + 如果是浮点数,则代表是范围 [-hue, hue],从此范围中选择调整因子。注意 `hue` 取值应为[0, 0.5]。 + 如果它是一个序列,则代表是范围 [min, max],从此范围中选择调整因子。注意取值范围min和max是 [-0.5, 0.5] 范围内的浮点数,并且min小于等于max。 + + 异常: + - **TypeError** - 如果 `brightness` 不是float或Sequence[float]类型。 + - **TypeError** - 如果 `contrast` 不是float或Sequence[float]类型。 + - **TypeError** - 如果 `saturation` 不是float或Sequence[float]类型。 + - **TypeError** - 如果 `hue` 不是float或Sequence[float]类型。 + - **ValueError** - 如果 `brightness` 为负数。 + - **ValueError** - 如果 `contrast` 为负数。 + - **ValueError** - 如果 `saturation` 为负数。 + - **ValueError** - 如果 `hue` 不在 [-0.5, 0.5] 范围内。 + - **RuntimeError** - 如果输入图像的shape不是 。 + + 教程样例: + - `视觉变换样例库 + `_ diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomCrop.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomCrop.rst index cd8feb8ad59..8c4081dc256 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomCrop.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomCrop.rst @@ -1,44 +1,44 @@ -mindspore.dataset.vision.RandomCrop -=================================== - -.. py:class:: mindspore.dataset.vision.RandomCrop(size, padding=None, pad_if_needed=False, fill_value=0, padding_mode=Border.CONSTANT) - - 对输入图像进行随机区域的裁剪。如果输入图像尺寸小于输出尺寸,输入图像将在裁剪前被填充。 - - .. note:: 如果在多个数据列上应用此处理,则需要确保每个数据列图像的shape相同。 - - 参数: - - **size** (Union[int, Sequence[int]]) - 裁剪图像的输出尺寸大小。值必须为正。 - 如果 size 是整数,则返回一个裁剪尺寸大小为 (size, size) 的正方形。 - 如果 size 是一个长度为 2 的序列,则以2个元素分别为高和宽放缩至(高度, 宽度)大小。 - - **padding** (Union[int, Sequence[int]], 可选) - 图像各边填充的像素数。填充值必须为非负值。默认值: ``None`` 。 - 如果 `padding` 不为 None,则首先使用 `padding` 填充图像。 - 如果 `padding` 是一个整数,代表为图像的所有方向填充该值大小的像素。 - 如果 `padding` 是一个包含2个值的元组或列表,第一个值会用于填充图像的左侧和右侧,第二个值会用于填充图像的上侧和下侧。 - 如果 `padding` 是一个包含4个值的元组或列表,则分别填充图像的左侧、上侧、右侧和下侧。 - - **pad_if_needed** (bool, 可选) - 如果输入图像高度或者宽度小于 `size` 指定的输出图像尺寸大小,是否进行填充。默认值: ``False`` 。 - - **fill_value** (Union[int, tuple[int]], 可选) - 边框的像素强度,仅当 `padding_mode` 为 ``Border.CONSTANT`` 时有效。 - 如果是3元素元组,则分别用于填充R、G、B通道。 - 如果是整数,则用于所有RGB通道。 - `fill_value` 值必须在 [0, 255] 范围内。默认值: ``0`` 。 - - **padding_mode** (:class:`~.vision.Border`, 可选) - 边界填充方式。它可以是 ``Border.CONSTANT`` 、 ``Border.EDGE`` 、 ``Border.REFLECT`` 、 ``Border.SYMMETRIC`` 。默认值: ``Border.CONSTANT`` 。 - - - **Border.CONSTANT** - 使用常量值进行填充。 - - **Border.EDGE** - 使用各边的边界像素值进行填充。 - - **Border.REFLECT** - 以各边的边界为轴进行镜像填充,忽略边界像素值。 - - **Border.SYMMETRIC** - 以各边的边界为轴进行对称填充,包括边界像素值。 - - 异常: - - **TypeError** - 如果 `size` 不是int或Sequence[int]类型。 - - **TypeError** - 如果 `padding` 不是int或Sequence[int]类型。 - - **TypeError** - 如果 `pad_if_needed` 不是bool类型。 - - **TypeError** - 如果 `fill_value` 不是int或tuple[int]类型。 - - **TypeError** - 如果 `padding_mode` 不是 :class:`mindspore.dataset.vision.Border` 的类型。 - - **ValueError** - 如果 `size` 不是正数。 - - **ValueError** - 如果 `padding` 为负数。 - - **ValueError** - 如果 `fill_value` 不在 [0, 255] 范围内。 - - **RuntimeError** - 如果输入图像的shape不是 或 <..., H, W, C>。 - - 教程样例: - - `视觉变换样例库 - `_ +mindspore.dataset.vision.RandomCrop +=================================== + +.. py:class:: mindspore.dataset.vision.RandomCrop(size, padding=None, pad_if_needed=False, fill_value=0, padding_mode=Border.CONSTANT) + + 对输入图像进行随机区域的裁剪。如果输入图像尺寸小于输出尺寸,输入图像将在裁剪前被填充。 + + .. note:: 如果在多个数据列上应用此处理,则需要确保每个数据列图像的shape相同。 + + 参数: + - **size** (Union[int, Sequence[int]]) - 裁剪图像的输出尺寸大小。值必须为正。 + 如果 size 是整数,则返回一个裁剪尺寸大小为 (size, size) 的正方形。 + 如果 size 是一个长度为 2 的序列,则以2个元素分别为高和宽放缩至(高度, 宽度)大小。 + - **padding** (Union[int, Sequence[int]], 可选) - 图像各边填充的像素数。填充值必须为非负值。默认值: ``None`` 。 + 如果 `padding` 不为 None,则首先使用 `padding` 填充图像。 + 如果 `padding` 是一个整数,代表为图像的所有方向填充该值大小的像素。 + 如果 `padding` 是一个包含2个值的元组或列表,第一个值会用于填充图像的左侧和右侧,第二个值会用于填充图像的上侧和下侧。 + 如果 `padding` 是一个包含4个值的元组或列表,则分别填充图像的左侧、上侧、右侧和下侧。 + - **pad_if_needed** (bool, 可选) - 如果输入图像高度或者宽度小于 `size` 指定的输出图像尺寸大小,是否进行填充。默认值: ``False`` 。 + - **fill_value** (Union[int, tuple[int]], 可选) - 边框的像素强度,仅当 `padding_mode` 为 ``Border.CONSTANT`` 时有效。 + 如果是3元素元组,则分别用于填充R、G、B通道。 + 如果是整数,则用于所有RGB通道。 + `fill_value` 值必须在 [0, 255] 范围内。默认值: ``0`` 。 + - **padding_mode** (:class:`~.vision.Border`, 可选) - 边界填充方式。它可以是 ``Border.CONSTANT`` 、 ``Border.EDGE`` 、 ``Border.REFLECT`` 、 ``Border.SYMMETRIC`` 。默认值: ``Border.CONSTANT`` 。 + + - **Border.CONSTANT** - 使用常量值进行填充。 + - **Border.EDGE** - 使用各边的边界像素值进行填充。 + - **Border.REFLECT** - 以各边的边界为轴进行镜像填充,忽略边界像素值。 + - **Border.SYMMETRIC** - 以各边的边界为轴进行对称填充,包括边界像素值。 + + 异常: + - **TypeError** - 如果 `size` 不是int或Sequence[int]类型。 + - **TypeError** - 如果 `padding` 不是int或Sequence[int]类型。 + - **TypeError** - 如果 `pad_if_needed` 不是bool类型。 + - **TypeError** - 如果 `fill_value` 不是int或tuple[int]类型。 + - **TypeError** - 如果 `padding_mode` 不是 :class:`mindspore.dataset.vision.Border` 的类型。 + - **ValueError** - 如果 `size` 不是正数。 + - **ValueError** - 如果 `padding` 为负数。 + - **ValueError** - 如果 `fill_value` 不在 [0, 255] 范围内。 + - **RuntimeError** - 如果输入图像的shape不是 或 <..., H, W, C>。 + + 教程样例: + - `视觉变换样例库 + `_ diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomCropDecodeResize.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomCropDecodeResize.rst index 0f35c910707..9c56941de3a 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomCropDecodeResize.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomCropDecodeResize.rst @@ -1,32 +1,32 @@ -mindspore.dataset.vision.RandomCropDecodeResize -=============================================== - -.. py:class:: mindspore.dataset.vision.RandomCropDecodeResize(size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Inter.BILINEAR, max_attempts=10) - - "裁剪"、"解码"和"调整尺寸大小"的组合处理。该操作将在随机位置裁剪输入图像,以 RGB 模式对裁剪后的图像进行解码,并调整解码图像的尺寸大小。针对 JPEG 图像进行了优化, 可以获得更好的性能。 - - 参数: - - **size** (Union[int, Sequence[int]]) - 调整后图像的输出尺寸大小。大小值必须为正。 - 如果 size 是整数,则返回一个裁剪尺寸大小为 (size, size) 的正方形。 - 如果 size 是一个长度为 2 的序列,则以2个元素分别为高和宽放缩至(高度, 宽度)大小。 - - **scale** (Union[list, tuple], 可选) - 要裁剪的原始尺寸大小的各个尺寸的范围[min, max),必须为非负数。默认值: ``(0.08, 1.0)`` 。 - - **ratio** (Union[list, tuple], 可选) - 宽高比的范围 [min, max) 裁剪,必须为非负数。默认值: ``(3. / 4., 4. / 3.)``。 - - **interpolation** (:class:`~.vision.Inter`, 可选) - 图像插值方法。可选值详见 :class:`mindspore.dataset.vision.Inter` 。 - 默认值: ``Inter.BILINEAR``。 - - **max_attempts** (int, 可选) - 生成随机裁剪位置的最大尝试次数,超过该次数时将使用中心裁剪, `max_attempts` 值必须为正数。默认值: ``10`` 。 - - 异常: - - **TypeError** - 如果 `size` 不是int或Sequence[int]类型。 - - **TypeError** - 如果 `scale` 不是tuple或list类型。 - - **TypeError** - 如果 `ratio` 不是tuple或list类型。 - - **TypeError** - 如果 `interpolation` 不是 :class:`mindspore.dataset.vision.Inter` 的类型。 - - **TypeError** - 如果 `max_attempts` 不是int类型。 - - **ValueError** - 如果 `size` 不是正数。 - - **ValueError** - 如果 `scale` 为负数。 - - **ValueError** - 如果 `ratio` 为负数。 - - **ValueError** - 如果 `max_attempts` 不是正数。 - - **RuntimeError** - 如果输入图像不是一维序列。 - - 教程样例: - - `视觉变换样例库 - `_ +mindspore.dataset.vision.RandomCropDecodeResize +=============================================== + +.. py:class:: mindspore.dataset.vision.RandomCropDecodeResize(size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Inter.BILINEAR, max_attempts=10) + + "裁剪"、"解码"和"调整尺寸大小"的组合处理。该操作将在随机位置裁剪输入图像,以 RGB 模式对裁剪后的图像进行解码,并调整解码图像的尺寸大小。针对 JPEG 图像进行了优化, 可以获得更好的性能。 + + 参数: + - **size** (Union[int, Sequence[int]]) - 调整后图像的输出尺寸大小。大小值必须为正。 + 如果 size 是整数,则返回一个裁剪尺寸大小为 (size, size) 的正方形。 + 如果 size 是一个长度为 2 的序列,则以2个元素分别为高和宽放缩至(高度, 宽度)大小。 + - **scale** (Union[list, tuple], 可选) - 要裁剪的原始尺寸大小的各个尺寸的范围[min, max),必须为非负数。默认值: ``(0.08, 1.0)`` 。 + - **ratio** (Union[list, tuple], 可选) - 宽高比的范围 [min, max) 裁剪,必须为非负数。默认值: ``(3. / 4., 4. / 3.)``。 + - **interpolation** (:class:`~.vision.Inter`, 可选) - 图像插值方法。可选值详见 :class:`mindspore.dataset.vision.Inter` 。 + 默认值: ``Inter.BILINEAR``。 + - **max_attempts** (int, 可选) - 生成随机裁剪位置的最大尝试次数,超过该次数时将使用中心裁剪, `max_attempts` 值必须为正数。默认值: ``10`` 。 + + 异常: + - **TypeError** - 如果 `size` 不是int或Sequence[int]类型。 + - **TypeError** - 如果 `scale` 不是tuple或list类型。 + - **TypeError** - 如果 `ratio` 不是tuple或list类型。 + - **TypeError** - 如果 `interpolation` 不是 :class:`mindspore.dataset.vision.Inter` 的类型。 + - **TypeError** - 如果 `max_attempts` 不是int类型。 + - **ValueError** - 如果 `size` 不是正数。 + - **ValueError** - 如果 `scale` 为负数。 + - **ValueError** - 如果 `ratio` 为负数。 + - **ValueError** - 如果 `max_attempts` 不是正数。 + - **RuntimeError** - 如果输入图像不是一维序列。 + + 教程样例: + - `视觉变换样例库 + `_ diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomCropWithBBox.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomCropWithBBox.rst index 5db89f37f68..c6b63fec4bc 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomCropWithBBox.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomCropWithBBox.rst @@ -1,42 +1,42 @@ -mindspore.dataset.vision.RandomCropWithBBox -=========================================== - -.. py:class:: mindspore.dataset.vision.RandomCropWithBBox(size, padding=None, pad_if_needed=False, fill_value=0, padding_mode=Border.CONSTANT) - - 在输入图像的随机位置进行裁剪并相应地调整边界框。 - - 参数: - - **size** (Union[int, Sequence[int]]) - 裁剪图像的输出尺寸大小。大小值必须为正。 - 如果 size 是整数,则返回一个裁剪尺寸大小为 (size, size) 的正方形。 - 如果 size 是一个长度为 2 的序列,则以2个元素分别为高和宽放缩至(高度, 宽度)大小。 - - **padding** (Union[int, Sequence[int]], 可选) - 填充图像的像素数。填充值必须非负值。默认值: ``None`` 。 - 如果 `padding` 不为 None,则首先使用 `padding` 填充图像。 - 如果 `padding` 是一个整数,代表为图像的所有方向填充该值大小的像素。 - 如果 `padding` 是一个包含2个值的元组或列表,第一个值会用于填充图像的左侧和右侧,第二个值会用于填充图像的上侧和下侧。 - 如果 `padding` 是一个包含4个值的元组或列表,则分别填充图像的左侧、上侧、右侧和下侧。 - - **pad_if_needed** (bool, 可选) - 如果输入图像高度或者宽度小于 `size` 指定的输出图像尺寸大小,是否进行填充。默认值: ``False`` 。 - - **fill_value** (Union[int, tuple[int]], 可选) - 边框的像素强度,仅当 `padding_mode` 为 ``Border.CONSTANT`` 时有效。 - 如果是3元素元组,则分别用于填充R、G、B通道。 - 如果是整数,则用于所有 RGB 通道。 - `fill_value` 值必须在 [0, 255] 范围内。默认值: ``0`` 。 - - **padding_mode** (:class:`~.vision.Border`, 可选) - 边界填充方式。它可以是 ``Border.CONSTANT`` 、 ``Border.EDGE`` 、 ``Border.REFLECT`` 、 ``Border.SYMMETRIC`` 。默认值: ``Border.CONSTANT`` 。 - - - **Border.CONSTANT** - 使用常量值进行填充。 - - **Border.EDGE** - 使用各边的边界像素值进行填充。 - - **Border.REFLECT** - 以各边的边界为轴进行镜像填充,忽略边界像素值。 - - **Border.SYMMETRIC** - 以各边的边界为轴进行对称填充,包括边界像素值。 - - 异常: - - **TypeError** - 如果 `size` 不是int或Sequence[int]类型。 - - **TypeError** - 如果 `padding` 不是int或Sequence[int]类型。 - - **TypeError** - 如果 `pad_if_needed` 不是bool类型。 - - **TypeError** - 如果 `fill_value` 不是int或tuple[int]类型。 - - **TypeError** - 如果 `padding_mode` 不是 :class:`mindspore.dataset.vision.Border` 的类型。 - - **ValueError** - 如果 `size` 不是正数。 - - **ValueError** - 如果 `padding` 为负数。 - - **ValueError** - 如果 `fill_value` 不在 [0, 255] 范围内。 - - **RuntimeError** - 如果输入图像的shape不是 。 - - 教程样例: - - `视觉变换样例库 - `_ +mindspore.dataset.vision.RandomCropWithBBox +=========================================== + +.. py:class:: mindspore.dataset.vision.RandomCropWithBBox(size, padding=None, pad_if_needed=False, fill_value=0, padding_mode=Border.CONSTANT) + + 在输入图像的随机位置进行裁剪并相应地调整边界框。 + + 参数: + - **size** (Union[int, Sequence[int]]) - 裁剪图像的输出尺寸大小。大小值必须为正。 + 如果 size 是整数,则返回一个裁剪尺寸大小为 (size, size) 的正方形。 + 如果 size 是一个长度为 2 的序列,则以2个元素分别为高和宽放缩至(高度, 宽度)大小。 + - **padding** (Union[int, Sequence[int]], 可选) - 填充图像的像素数。填充值必须非负值。默认值: ``None`` 。 + 如果 `padding` 不为 None,则首先使用 `padding` 填充图像。 + 如果 `padding` 是一个整数,代表为图像的所有方向填充该值大小的像素。 + 如果 `padding` 是一个包含2个值的元组或列表,第一个值会用于填充图像的左侧和右侧,第二个值会用于填充图像的上侧和下侧。 + 如果 `padding` 是一个包含4个值的元组或列表,则分别填充图像的左侧、上侧、右侧和下侧。 + - **pad_if_needed** (bool, 可选) - 如果输入图像高度或者宽度小于 `size` 指定的输出图像尺寸大小,是否进行填充。默认值: ``False`` 。 + - **fill_value** (Union[int, tuple[int]], 可选) - 边框的像素强度,仅当 `padding_mode` 为 ``Border.CONSTANT`` 时有效。 + 如果是3元素元组,则分别用于填充R、G、B通道。 + 如果是整数,则用于所有 RGB 通道。 + `fill_value` 值必须在 [0, 255] 范围内。默认值: ``0`` 。 + - **padding_mode** (:class:`~.vision.Border`, 可选) - 边界填充方式。它可以是 ``Border.CONSTANT`` 、 ``Border.EDGE`` 、 ``Border.REFLECT`` 、 ``Border.SYMMETRIC`` 。默认值: ``Border.CONSTANT`` 。 + + - **Border.CONSTANT** - 使用常量值进行填充。 + - **Border.EDGE** - 使用各边的边界像素值进行填充。 + - **Border.REFLECT** - 以各边的边界为轴进行镜像填充,忽略边界像素值。 + - **Border.SYMMETRIC** - 以各边的边界为轴进行对称填充,包括边界像素值。 + + 异常: + - **TypeError** - 如果 `size` 不是int或Sequence[int]类型。 + - **TypeError** - 如果 `padding` 不是int或Sequence[int]类型。 + - **TypeError** - 如果 `pad_if_needed` 不是bool类型。 + - **TypeError** - 如果 `fill_value` 不是int或tuple[int]类型。 + - **TypeError** - 如果 `padding_mode` 不是 :class:`mindspore.dataset.vision.Border` 的类型。 + - **ValueError** - 如果 `size` 不是正数。 + - **ValueError** - 如果 `padding` 为负数。 + - **ValueError** - 如果 `fill_value` 不在 [0, 255] 范围内。 + - **RuntimeError** - 如果输入图像的shape不是 。 + + 教程样例: + - `视觉变换样例库 + `_ diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomEqualize.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomEqualize.rst index d3fd80f780a..9828128ed50 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomEqualize.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomEqualize.rst @@ -1,18 +1,18 @@ -mindspore.dataset.vision.RandomEqualize -======================================= - -.. py:class:: mindspore.dataset.vision.RandomEqualize(prob=0.5) - - 以给定的概率随机对输入图像进行直方图均衡化。 - - 参数: - - **prob** (float, 可选) - 图像被均衡化的概率,取值范围:[0.0, 1.0]。默认值: ``0.5`` 。 - - 异常: - - **TypeError** - 如果 `prob` 的类型不为float。 - - **ValueError** - 如果 `prob` 不在 [0.0, 1.0] 范围。 - - **RuntimeError** - 如果输入图像的shape不是 。 - - 教程样例: - - `视觉变换样例库 - `_ +mindspore.dataset.vision.RandomEqualize +======================================= + +.. py:class:: mindspore.dataset.vision.RandomEqualize(prob=0.5) + + 以给定的概率随机对输入图像进行直方图均衡化。 + + 参数: + - **prob** (float, 可选) - 图像被均衡化的概率,取值范围:[0.0, 1.0]。默认值: ``0.5`` 。 + + 异常: + - **TypeError** - 如果 `prob` 的类型不为float。 + - **ValueError** - 如果 `prob` 不在 [0.0, 1.0] 范围。 + - **RuntimeError** - 如果输入图像的shape不是 。 + + 教程样例: + - `视觉变换样例库 + `_ diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomHorizontalFlip.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomHorizontalFlip.rst index 9a96791679a..bc01203445b 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomHorizontalFlip.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomHorizontalFlip.rst @@ -1,18 +1,18 @@ -mindspore.dataset.vision.RandomHorizontalFlip -============================================= - -.. py:class:: mindspore.dataset.vision.RandomHorizontalFlip(prob=0.5) - - 对输入图像按给定的概率进行水平随机翻转。 - - 参数: - - **prob** (float, 可选) - 图像被翻转的概率,取值范围:[0.0, 1.0]。默认值: ``0.5`` 。 - - 异常: - - **TypeError** - 如果 `prob` 不是float类型。 - - **ValueError** - 如果 `prob` 不在 [0.0, 1.0] 范围内。 - - **RuntimeError** - 如果输入图像的shape不是 。 - - 教程样例: - - `视觉变换样例库 - `_ +mindspore.dataset.vision.RandomHorizontalFlip +============================================= + +.. py:class:: mindspore.dataset.vision.RandomHorizontalFlip(prob=0.5) + + 对输入图像按给定的概率进行水平随机翻转。 + + 参数: + - **prob** (float, 可选) - 图像被翻转的概率,取值范围:[0.0, 1.0]。默认值: ``0.5`` 。 + + 异常: + - **TypeError** - 如果 `prob` 不是float类型。 + - **ValueError** - 如果 `prob` 不在 [0.0, 1.0] 范围内。 + - **RuntimeError** - 如果输入图像的shape不是 。 + + 教程样例: + - `视觉变换样例库 + `_ diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomHorizontalFlipWithBBox.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomHorizontalFlipWithBBox.rst index 341693696a3..079dd299d5e 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomHorizontalFlipWithBBox.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomHorizontalFlipWithBBox.rst @@ -1,18 +1,18 @@ -mindspore.dataset.vision.RandomHorizontalFlipWithBBox -===================================================== - -.. py:class:: mindspore.dataset.vision.RandomHorizontalFlipWithBBox(prob=0.5) - - 按给定的概率,对输入图像及其边界框进行随机水平翻转。 - - 参数: - - **prob** (float, 可选) - 图像被翻转的概率,取值范围:[0.0, 1.0]。默认值: ``0.5`` 。 - - 异常: - - **TypeError** - 如果 `prob` 不是float类型。 - - **ValueError** - 如果 `prob` 不在 [0.0, 1.0] 范围内。 - - **RuntimeError** - 如果输入图像的shape不是 。 - - 教程样例: - - `视觉变换样例库 - `_ +mindspore.dataset.vision.RandomHorizontalFlipWithBBox +===================================================== + +.. py:class:: mindspore.dataset.vision.RandomHorizontalFlipWithBBox(prob=0.5) + + 按给定的概率,对输入图像及其边界框进行随机水平翻转。 + + 参数: + - **prob** (float, 可选) - 图像被翻转的概率,取值范围:[0.0, 1.0]。默认值: ``0.5`` 。 + + 异常: + - **TypeError** - 如果 `prob` 不是float类型。 + - **ValueError** - 如果 `prob` 不在 [0.0, 1.0] 范围内。 + - **RuntimeError** - 如果输入图像的shape不是 。 + + 教程样例: + - `视觉变换样例库 + `_ diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomInvert.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomInvert.rst index 0c80e76054e..97c34a6f240 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomInvert.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomInvert.rst @@ -1,18 +1,18 @@ -mindspore.dataset.vision.RandomInvert -===================================== - -.. py:class:: mindspore.dataset.vision.RandomInvert(prob=0.5) - - 以给定的概率随机反转图像的颜色。 - - 参数: - - **prob** (float, 可选) - 图像被反转颜色的概率,取值范围:[0.0, 1.0]。默认值: ``0.5`` 。 - - 异常: - - **TypeError** - 如果 `prob` 的类型不为float。 - - **ValueError** - 如果 `prob` 不在 [0.0, 1.0] 范围。 - - **RuntimeError** - 如果输入图像的shape不是 。 - - 教程样例: - - `视觉变换样例库 - `_ +mindspore.dataset.vision.RandomInvert +===================================== + +.. py:class:: mindspore.dataset.vision.RandomInvert(prob=0.5) + + 以给定的概率随机反转图像的颜色。 + + 参数: + - **prob** (float, 可选) - 图像被反转颜色的概率,取值范围:[0.0, 1.0]。默认值: ``0.5`` 。 + + 异常: + - **TypeError** - 如果 `prob` 的类型不为float。 + - **ValueError** - 如果 `prob` 不在 [0.0, 1.0] 范围。 + - **RuntimeError** - 如果输入图像的shape不是 。 + + 教程样例: + - `视觉变换样例库 + `_ diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomLighting.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomLighting.rst index d78f5a6a242..dccb223c5ce 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomLighting.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomLighting.rst @@ -1,18 +1,18 @@ -mindspore.dataset.vision.RandomLighting -======================================== - -.. py:class:: mindspore.dataset.vision.RandomLighting(alpha=0.05) - - 将AlexNet PCA的噪声添加到图像中。Alexnet PCA噪声的特征值和特征向量是由ImageNet数据集计算得出。 - - 参数: - - **alpha** (float, 可选) - 图像的强度,必须是非负的。默认值: ``0.05`` 。 - - 异常: - - **TypeError** - 如果 `alpha` 的类型不为float。 - - **ValueError** - 如果 `alpha` 为负数。 - - **RuntimeError** - 如果输入图像的shape不是 。 - - 教程样例: - - `视觉变换样例库 - `_ +mindspore.dataset.vision.RandomLighting +======================================== + +.. py:class:: mindspore.dataset.vision.RandomLighting(alpha=0.05) + + 将AlexNet PCA的噪声添加到图像中。Alexnet PCA噪声的特征值和特征向量是由ImageNet数据集计算得出。 + + 参数: + - **alpha** (float, 可选) - 图像的强度,必须是非负的。默认值: ``0.05`` 。 + + 异常: + - **TypeError** - 如果 `alpha` 的类型不为float。 + - **ValueError** - 如果 `alpha` 为负数。 + - **RuntimeError** - 如果输入图像的shape不是 。 + + 教程样例: + - `视觉变换样例库 + `_ diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomPosterize.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomPosterize.rst index 16c112957fd..ca136b49728 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomPosterize.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomPosterize.rst @@ -1,19 +1,19 @@ -mindspore.dataset.vision.RandomPosterize -======================================== - -.. py:class:: mindspore.dataset.vision.RandomPosterize(bits=(8, 8)) - - 随机减少图像的颜色通道的比特位数,使图像变得高对比度和颜色鲜艳。 - - 参数: - - **bits** (Union[int, Sequence[int]], 可选) - 随机位数压缩的范围。位值必须在 [1,8] 范围内,并且在给定范围内至少包含一个整数值。它必须是 (min, max) 或整数格式。 - 如果min与max相等,那么它是一个单一的位数压缩操作。默认值: ``(8, 8)`` 。 - - 异常: - - **TypeError** - 如果 `bits` 不是int或Sequence[int]类型。 - - **ValueError** - 如果 `bits` 不在 [1, 8] 范围内。 - - **RuntimeError** - 如果输入图像的shape不是 。 - - 教程样例: - - `视觉变换样例库 - `_ +mindspore.dataset.vision.RandomPosterize +======================================== + +.. py:class:: mindspore.dataset.vision.RandomPosterize(bits=(8, 8)) + + 随机减少图像的颜色通道的比特位数,使图像变得高对比度和颜色鲜艳。 + + 参数: + - **bits** (Union[int, Sequence[int]], 可选) - 随机位数压缩的范围。位值必须在 [1,8] 范围内,并且在给定范围内至少包含一个整数值。它必须是 (min, max) 或整数格式。 + 如果min与max相等,那么它是一个单一的位数压缩操作。默认值: ``(8, 8)`` 。 + + 异常: + - **TypeError** - 如果 `bits` 不是int或Sequence[int]类型。 + - **ValueError** - 如果 `bits` 不在 [1, 8] 范围内。 + - **RuntimeError** - 如果输入图像的shape不是 。 + + 教程样例: + - `视觉变换样例库 + `_ diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomResize.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomResize.rst index 807b0634f23..e4b75126383 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomResize.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.RandomResize.rst @@ -1,18 +1,18 @@ -mindspore.dataset.vision.RandomResize -===================================== - -.. py:class:: mindspore.dataset.vision.RandomResize(size) - - 对输入图像使用随机选择的 :class:`mindspore.dataset.vision.Inter` 插值方式去调整它的尺寸大小。 - - 参数: - - **size** (Union[int, Sequence[int]]) - 调整后图像的输出尺寸大小。值必须为正。若输入整型,则放缩至(size, size)大小;若输入2元素序列,则以2个元素分别为高和宽放缩至(高度, 宽度)大小。 - - 异常: - - **TypeError** - 如果 `size` 不是int或Sequence[int]类型。 - - **ValueError** - 如果 `size` 不是正数。 - - **RuntimeError** - 如果输入图像的shape不是 。 - - 教程样例: - - `视觉变换样例库 - `_ +mindspore.dataset.vision.RandomResize +===================================== + +.. py:class:: mindspore.dataset.vision.RandomResize(size) + + 对输入图像使用随机选择的 :class:`mindspore.dataset.vision.Inter` 插值方式去调整它的尺寸大小。 + + 参数: + - **size** (Union[int, Sequence[int]]) - 调整后图像的输出尺寸大小。值必须为正。若输入整型,则放缩至(size, size)大小;若输入2元素序列,则以2个元素分别为高和宽放缩至(高度, 宽度)大小。 + + 异常: + - **TypeError** - 如果 `size` 不是int或Sequence[int]类型。 + - **ValueError** - 如果 `size` 不是正数。 + - **RuntimeError** - 如果输入图像的shape不是 。 + + 教程样例: + - `视觉变换样例库 + `_ diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.SliceMode.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.SliceMode.rst index 2060f27a1df..32a686faa48 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.SliceMode.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.SliceMode.rst @@ -1,11 +1,11 @@ -mindspore.dataset.vision.SliceMode -================================== - -.. py:class:: mindspore.dataset.vision.SliceMode - - Tensor切片方式枚举类。 - - 可选枚举值为: ``SliceMode.PAD`` 、 ``SliceMode.DROP``。 - - - **SliceMode.PAD** - 当图像无法进行整数块切分时,填充最后一个图像块至指定切分尺寸大小。 - - **SliceMode.DROP** - 当图像无法进行整数块切分时,丢弃最后一个图像块。 +mindspore.dataset.vision.SliceMode +================================== + +.. py:class:: mindspore.dataset.vision.SliceMode + + Tensor切片方式枚举类。 + + 可选枚举值为: ``SliceMode.PAD`` 、 ``SliceMode.DROP``。 + + - **SliceMode.PAD** - 当图像无法进行整数块切分时,填充最后一个图像块至指定切分尺寸大小。 + - **SliceMode.DROP** - 当图像无法进行整数块切分时,丢弃最后一个图像块。 diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.encode_jpeg.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.encode_jpeg.rst old mode 100755 new mode 100644 index 16dac22b581..1605023f36b --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.encode_jpeg.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.encode_jpeg.rst @@ -1,20 +1,20 @@ -mindspore.dataset.vision.encode_jpeg -==================================== - -.. py:function:: mindspore.dataset.vision.encode_jpeg(image, quality=75) - - 将输入的图像编码为JPEG数据。 - - 参数: - - **image** (Union[numpy.ndarray, mindspore.Tensor]) - 编码的图像。 - - **quality** (int, 可选) - 生成的JPEG数据的质量,取值范围为[1, 100]。默认值: ``75`` 。 - - 返回: - - numpy.ndarray, 一维uint8类型数据。 - - 异常: - - **TypeError** - 如果 `image` 不是numpy.ndarray或mindspore.Tensor类型。 - - **TypeError** - 如果 `quality` 不是int类型。 - - **RuntimeError** - 如果 `image` 的数据类型不是uint8类型。 - - **RuntimeError** - 如果 `image` 的shape不是 。 - - **RuntimeError** - 如果 `quality` 小于1或大于100。 +mindspore.dataset.vision.encode_jpeg +==================================== + +.. py:function:: mindspore.dataset.vision.encode_jpeg(image, quality=75) + + 将输入的图像编码为JPEG数据。 + + 参数: + - **image** (Union[numpy.ndarray, mindspore.Tensor]) - 编码的图像。 + - **quality** (int, 可选) - 生成的JPEG数据的质量,取值范围为[1, 100]。默认值: ``75`` 。 + + 返回: + - numpy.ndarray, 一维uint8类型数据。 + + 异常: + - **TypeError** - 如果 `image` 不是numpy.ndarray或mindspore.Tensor类型。 + - **TypeError** - 如果 `quality` 不是int类型。 + - **RuntimeError** - 如果 `image` 的数据类型不是uint8类型。 + - **RuntimeError** - 如果 `image` 的shape不是 。 + - **RuntimeError** - 如果 `quality` 小于1或大于100。 diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.encode_png.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.encode_png.rst old mode 100755 new mode 100644 index 585fface047..6e3ee7d234a --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.encode_png.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.encode_png.rst @@ -1,20 +1,20 @@ -mindspore.dataset.vision.encode_png -=================================== - -.. py:function:: mindspore.dataset.vision.encode_png(image, compression_level=6) - - 将输入的图像编码为PNG数据。 - - 参数: - - **image** (Union[numpy.ndarray, mindspore.Tensor]) - 编码的图像。 - - **compression_level** (int, 可选) - 编码压缩因子,取值范围为[0, 9]。默认值: ``6`` 。 - - 返回: - - numpy.ndarray, 一维uint8类型数据。 - - 异常: - - **TypeError** - 如果 `image` 不是numpy.ndarray或mindspore.Tensor类型。 - - **TypeError** - 如果 `compression_level` 不是int类型。 - - **RuntimeError** - 如果 `image` 的数据类型不是uint8类型。 - - **RuntimeError** - 如果 `image` 的shape不是 。 - - **RuntimeError** - 如果 `compression_level` 小于0或大于9。 +mindspore.dataset.vision.encode_png +=================================== + +.. py:function:: mindspore.dataset.vision.encode_png(image, compression_level=6) + + 将输入的图像编码为PNG数据。 + + 参数: + - **image** (Union[numpy.ndarray, mindspore.Tensor]) - 编码的图像。 + - **compression_level** (int, 可选) - 编码压缩因子,取值范围为[0, 9]。默认值: ``6`` 。 + + 返回: + - numpy.ndarray, 一维uint8类型数据。 + + 异常: + - **TypeError** - 如果 `image` 不是numpy.ndarray或mindspore.Tensor类型。 + - **TypeError** - 如果 `compression_level` 不是int类型。 + - **RuntimeError** - 如果 `image` 的数据类型不是uint8类型。 + - **RuntimeError** - 如果 `image` 的shape不是 。 + - **RuntimeError** - 如果 `compression_level` 小于0或大于9。 diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.read_file.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.read_file.rst old mode 100755 new mode 100644 index 16d1a20d4ca..264b06be85f --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.read_file.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.read_file.rst @@ -1,16 +1,16 @@ -mindspore.dataset.vision.read_file -================================== - -.. py:function:: mindspore.dataset.vision.read_file(filename) - - 以二进制模式读取文件。 - - 参数: - - **filename** (str) - 待读取文件路径。 - - 返回: - - numpy.ndarray, 一维uint8类型数据。 - - 异常: - - **TypeError** - 如果 `filename` 不是str类型。 - - **RuntimeError** - 如果 `filename` 不存在或不是普通文件。 +mindspore.dataset.vision.read_file +================================== + +.. py:function:: mindspore.dataset.vision.read_file(filename) + + 以二进制模式读取文件。 + + 参数: + - **filename** (str) - 待读取文件路径。 + + 返回: + - numpy.ndarray, 一维uint8类型数据。 + + 异常: + - **TypeError** - 如果 `filename` 不是str类型。 + - **RuntimeError** - 如果 `filename` 不存在或不是普通文件。 diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.read_image.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.read_image.rst old mode 100755 new mode 100644 index d339a3e39e4..644616e3a63 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.read_image.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.read_image.rst @@ -1,24 +1,24 @@ -mindspore.dataset.vision.read_image -=================================== - -.. py:function:: mindspore.dataset.vision.read_image(filename, mode=ImageReadMode.UNCHANGED) - - 读取图像文件并解码为3通道RGB彩色数据或灰度数据。 - 支持的文件类型有JPEG、PNG、BMP和TIFF。 - - 参数: - - **filename** (str) - 待读取图像文件路径。 - - **mode** (:class:`~.vision.ImageReadMode`, 可选) - 图像读取模式。它可以是 ``ImageReadMode.UNCHANGED`` 、 ``ImageReadMode.GRAYSCALE`` 、 ``ImageReadMode.COLOR`` 。 - 默认值: ``ImageReadMode.UNCHANGED`` 。 - - - **ImageReadMode.UNCHANGED** - 按照图像原始格式读取。 - - **ImageReadMode.GRAYSCALE** - 读取并转为单通道灰度数据。 - - **ImageReadMode.COLOR** - 读取并换为3通道RGB彩色数据。 - - 返回: - - numpy.ndarray, 三维uint8类型数据,shape为(H, W, C)。 - - 异常: - - **TypeError** - 如果 `filename` 不是str类型。 - - **TypeError** - 如果 `mode` 不是 :class:`mindspore.dataset.vision.ImageReadMode` 类型。 - - **RuntimeError** - 如果 `filename` 不存在或不是普通文件或由于格式等原因无法正常读取。 +mindspore.dataset.vision.read_image +=================================== + +.. py:function:: mindspore.dataset.vision.read_image(filename, mode=ImageReadMode.UNCHANGED) + + 读取图像文件并解码为3通道RGB彩色数据或灰度数据。 + 支持的文件类型有JPEG、PNG、BMP和TIFF。 + + 参数: + - **filename** (str) - 待读取图像文件路径。 + - **mode** (:class:`~.vision.ImageReadMode`, 可选) - 图像读取模式。它可以是 ``ImageReadMode.UNCHANGED`` 、 ``ImageReadMode.GRAYSCALE`` 、 ``ImageReadMode.COLOR`` 。 + 默认值: ``ImageReadMode.UNCHANGED`` 。 + + - **ImageReadMode.UNCHANGED** - 按照图像原始格式读取。 + - **ImageReadMode.GRAYSCALE** - 读取并转为单通道灰度数据。 + - **ImageReadMode.COLOR** - 读取并换为3通道RGB彩色数据。 + + 返回: + - numpy.ndarray, 三维uint8类型数据,shape为(H, W, C)。 + + 异常: + - **TypeError** - 如果 `filename` 不是str类型。 + - **TypeError** - 如果 `mode` 不是 :class:`mindspore.dataset.vision.ImageReadMode` 类型。 + - **RuntimeError** - 如果 `filename` 不存在或不是普通文件或由于格式等原因无法正常读取。 diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.write_file.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.write_file.rst old mode 100755 new mode 100644 index 7983418e291..b895b870462 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.write_file.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.write_file.rst @@ -1,17 +1,17 @@ -mindspore.dataset.vision.write_file -=================================== - -.. py:function:: mindspore.dataset.vision.write_file(filename, data) - - 使用二进制模式将一维uint8类型数据数组写到文件。 - - 参数: - - **filename** (str) - 要写入的文件的路径。 - - **data** (Union[numpy.ndarray, mindspore.Tensor]) - 要写入的一维uint8数据。 - - 异常: - - **TypeError** - 如果 `filename` 不是str类型。 - - **TypeError** - 如果 `data` 不是numpy.ndarray或mindspore.Tensor类型。 - - **RuntimeError** - 如果 `filename` 不是普通文件。 - - **RuntimeError** - 如果 `data` 的数据类型不是uint8类型。 - - **RuntimeError** - 如果 `data` 的shape不是一维数组。 +mindspore.dataset.vision.write_file +=================================== + +.. py:function:: mindspore.dataset.vision.write_file(filename, data) + + 使用二进制模式将一维uint8类型数据数组写到文件。 + + 参数: + - **filename** (str) - 要写入的文件的路径。 + - **data** (Union[numpy.ndarray, mindspore.Tensor]) - 要写入的一维uint8数据。 + + 异常: + - **TypeError** - 如果 `filename` 不是str类型。 + - **TypeError** - 如果 `data` 不是numpy.ndarray或mindspore.Tensor类型。 + - **RuntimeError** - 如果 `filename` 不是普通文件。 + - **RuntimeError** - 如果 `data` 的数据类型不是uint8类型。 + - **RuntimeError** - 如果 `data` 的shape不是一维数组。 diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.write_jpeg.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.write_jpeg.rst old mode 100755 new mode 100644 index e50ba2ef39c..fe42d17dffe --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.write_jpeg.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.write_jpeg.rst @@ -1,20 +1,20 @@ -mindspore.dataset.vision.write_jpeg -=================================== - -.. py:function:: mindspore.dataset.vision.write_jpeg(filename, image, quality=75) - - 将图像数据保存为JPEG文件。 - - 参数: - - **filename** (str) - 要写入的文件的路径。 - - **image** (Union[numpy.ndarray, mindspore.Tensor]) - 要写入的图像数据。 - - **quality** (int, 可选) - 生成的JPEG文件的质量,取值范围为[1, 100]。默认值: ``75`` 。 - - 异常: - - **TypeError** - 如果 `filename` 不是str类型。 - - **TypeError** - 如果 `image` 不是numpy.ndarray或mindspore.Tensor类型。 - - **TypeError** - 如果 `quality` 不是int类型。 - - **RuntimeError** - 如果 `filename` 不存在或不是普通文件。 - - **RuntimeError** - 如果 `image` 的数据类型不是uint8类型。 - - **RuntimeError** - 如果 `image` 的shape不是 。 - - **RuntimeError** - 如果 `quality` 小于1或大于100。 +mindspore.dataset.vision.write_jpeg +=================================== + +.. py:function:: mindspore.dataset.vision.write_jpeg(filename, image, quality=75) + + 将图像数据保存为JPEG文件。 + + 参数: + - **filename** (str) - 要写入的文件的路径。 + - **image** (Union[numpy.ndarray, mindspore.Tensor]) - 要写入的图像数据。 + - **quality** (int, 可选) - 生成的JPEG文件的质量,取值范围为[1, 100]。默认值: ``75`` 。 + + 异常: + - **TypeError** - 如果 `filename` 不是str类型。 + - **TypeError** - 如果 `image` 不是numpy.ndarray或mindspore.Tensor类型。 + - **TypeError** - 如果 `quality` 不是int类型。 + - **RuntimeError** - 如果 `filename` 不存在或不是普通文件。 + - **RuntimeError** - 如果 `image` 的数据类型不是uint8类型。 + - **RuntimeError** - 如果 `image` 的shape不是 。 + - **RuntimeError** - 如果 `quality` 小于1或大于100。 diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.write_png.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.write_png.rst old mode 100755 new mode 100644 index f9da0970623..006cf757419 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.write_png.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.write_png.rst @@ -1,20 +1,20 @@ -mindspore.dataset.vision.write_png -================================== - -.. py:function:: mindspore.dataset.vision.write_png(filename, image, compression_level=6) - - 将图像数据保存为PNG文件。 - - 参数: - - **filename** (str) - 要写入的文件的路径。 - - **image** (Union[numpy.ndarray, mindspore.Tensor]) - 要写入的图像数据。 - - **compression_level** (int, 可选) - 生成PNG文件的压缩级别,取值范围为[0, 9]。默认值: ``6``。 - - 异常: - - **TypeError** - 如果 `filename` 不是str类型。 - - **TypeError** - 如果 `image` 不是numpy.ndarray或mindspore.Tensor类型。 - - **TypeError** - 如果 `compression_level` 不是int类型。 - - **RuntimeError** - 如果 `filename` 不存在或不是普通文件。 - - **RuntimeError** - 如果 `image` 的数据类型不是uint8类型。 - - **RuntimeError** - 如果 `image` 的shape不是 。 - - **RuntimeError** - 如果 `compression_level` 小于0或大于9。 +mindspore.dataset.vision.write_png +================================== + +.. py:function:: mindspore.dataset.vision.write_png(filename, image, compression_level=6) + + 将图像数据保存为PNG文件。 + + 参数: + - **filename** (str) - 要写入的文件的路径。 + - **image** (Union[numpy.ndarray, mindspore.Tensor]) - 要写入的图像数据。 + - **compression_level** (int, 可选) - 生成PNG文件的压缩级别,取值范围为[0, 9]。默认值: ``6``。 + + 异常: + - **TypeError** - 如果 `filename` 不是str类型。 + - **TypeError** - 如果 `image` 不是numpy.ndarray或mindspore.Tensor类型。 + - **TypeError** - 如果 `compression_level` 不是int类型。 + - **RuntimeError** - 如果 `filename` 不存在或不是普通文件。 + - **RuntimeError** - 如果 `image` 的数据类型不是uint8类型。 + - **RuntimeError** - 如果 `image` 的shape不是 。 + - **RuntimeError** - 如果 `compression_level` 小于0或大于9。 diff --git a/docs/api/api_python/experimental/optim/mindspore.experimental.optim.Adam.rst b/docs/api/api_python/experimental/optim/mindspore.experimental.optim.Adam.rst index 64795c27cfc..5ac063bd115 100644 --- a/docs/api/api_python/experimental/optim/mindspore.experimental.optim.Adam.rst +++ b/docs/api/api_python/experimental/optim/mindspore.experimental.optim.Adam.rst @@ -1,62 +1,62 @@ -mindspore.experimental.optim.Adam -=================================== - -.. py:class:: mindspore.experimental.optim.Adam(params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0, amsgrad=False, *, maximize=False) - - Adaptive Moment Estimation (Adam)算法的实现。 - - 更新公式如下: - - .. math:: - \begin{aligned} - &\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2 - \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)} \\ - &\hspace{13mm} \lambda \text{ (weight decay)}, \: \textit{amsgrad}, - \:\textit{maximize} \\ - &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)}, - v_0\leftarrow 0 \text{ (second moment)},\: \widehat{v_0}^{max}\leftarrow 0\\[-1.ex] - &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ - &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\ - &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\ - &\hspace{5mm}\textbf{else} \\ - &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ - &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\ - &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ - &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ - &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\ - &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\ - &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\ - &\hspace{5mm}\textbf{if} \: amsgrad \\ - &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max}, - \widehat{v_t}) \\ - &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/ - \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\ - &\hspace{5mm}\textbf{else} \\ - &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/ - \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\ - &\bf{return} \: \theta_t \\[-1.ex] - \end{aligned} - - .. warning:: - 这是一个实验性的优化器接口,需要和 `LRScheduler `_ 下的动态学习率接口配合使用。 - - 参数: - - **params** (Union[list(Parameter), list(dict)]) - 网络参数的列表或指定了参数组的列表。 - - **lr** (Union[int, float, Tensor], 可选) - 学习率。默认值:``1e-3``。 - - **betas** (Tuple[float, float], 可选) - 动量矩阵的指数衰减率。参数范围(0.0, 1.0)。默认值:``(0.9, 0.999)``。 - - **eps** (float, 可选) - 加在分母上的值,以确保数值稳定。必须大于0。默认值:``1e-8``。 - - **weight_decay** (float, 可选) - 权重衰减(L2 penalty)。默认值:``0.0``。 - - **amsgrad** (bool, 可选) - 是否使用AMSGrad算法。默认值:``False``。 - - 关键字参数: - - **maximize** (bool, 可选) - 是否根据目标函数最大化网络参数。默认值:``False``。 - - 输入: - - **gradients** (tuple[Tensor]) - 网络权重的梯度。 - - 异常: - - **ValueError** - 学习率不是int、float或Tensor。 - - **ValueError** - 学习率小于0。 - - **ValueError** - `eps` 小于0。 - - **ValueError** - `betas` 范围不在[0, 1)之间。 +mindspore.experimental.optim.Adam +=================================== + +.. py:class:: mindspore.experimental.optim.Adam(params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0, amsgrad=False, *, maximize=False) + + Adaptive Moment Estimation (Adam)算法的实现。 + + 更新公式如下: + + .. math:: + \begin{aligned} + &\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2 + \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)} \\ + &\hspace{13mm} \lambda \text{ (weight decay)}, \: \textit{amsgrad}, + \:\textit{maximize} \\ + &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)}, + v_0\leftarrow 0 \text{ (second moment)},\: \widehat{v_0}^{max}\leftarrow 0\\[-1.ex] + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\ + &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}\textbf{else} \\ + &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\ + &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ + &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ + &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\ + &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\ + &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\ + &\hspace{5mm}\textbf{if} \: amsgrad \\ + &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max}, + \widehat{v_t}) \\ + &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/ + \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\ + &\hspace{5mm}\textbf{else} \\ + &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/ + \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\ + &\bf{return} \: \theta_t \\[-1.ex] + \end{aligned} + + .. warning:: + 这是一个实验性的优化器接口,需要和 `LRScheduler `_ 下的动态学习率接口配合使用。 + + 参数: + - **params** (Union[list(Parameter), list(dict)]) - 网络参数的列表或指定了参数组的列表。 + - **lr** (Union[int, float, Tensor], 可选) - 学习率。默认值:``1e-3``。 + - **betas** (Tuple[float, float], 可选) - 动量矩阵的指数衰减率。参数范围(0.0, 1.0)。默认值:``(0.9, 0.999)``。 + - **eps** (float, 可选) - 加在分母上的值,以确保数值稳定。必须大于0。默认值:``1e-8``。 + - **weight_decay** (float, 可选) - 权重衰减(L2 penalty)。默认值:``0.0``。 + - **amsgrad** (bool, 可选) - 是否使用AMSGrad算法。默认值:``False``。 + + 关键字参数: + - **maximize** (bool, 可选) - 是否根据目标函数最大化网络参数。默认值:``False``。 + + 输入: + - **gradients** (tuple[Tensor]) - 网络权重的梯度。 + + 异常: + - **ValueError** - 学习率不是int、float或Tensor。 + - **ValueError** - 学习率小于0。 + - **ValueError** - `eps` 小于0。 + - **ValueError** - `betas` 范围不在[0, 1)之间。 - **ValueError** - `weight_decay` 小于0。 \ No newline at end of file diff --git a/docs/api/api_python/experimental/optim/mindspore.experimental.optim.AdamW.rst b/docs/api/api_python/experimental/optim/mindspore.experimental.optim.AdamW.rst index 5ac65ca0b15..f38dd7e90e9 100644 --- a/docs/api/api_python/experimental/optim/mindspore.experimental.optim.AdamW.rst +++ b/docs/api/api_python/experimental/optim/mindspore.experimental.optim.AdamW.rst @@ -1,62 +1,62 @@ -mindspore.experimental.optim.AdamW -=================================== - -.. py:class:: mindspore.experimental.optim.AdamW(params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, *, maximize=False) - - Adaptive Moment Estimation Weight Decay(AdamW)算法的实现。 - - 更新公式如下: - - .. math:: - \begin{aligned} - &\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2 - \text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)}, - \: \epsilon \text{ (epsilon)} \\ - &\hspace{13mm} \lambda \text{(weight decay)}, \: \textit{amsgrad}, - \: \textit{maximize} \\ - &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0 - \text{ ( second moment)}, \: \widehat{v_0}^{max}\leftarrow 0 \\[-1.ex] - &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ - &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\ - &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\ - &\hspace{5mm}\textbf{else} \\ - &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ - &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\ - &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ - &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\ - &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\ - &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\ - &\hspace{5mm}\textbf{if} \: amsgrad \\ - &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max}, - \widehat{v_t}) \\ - &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/ - \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\ - &\hspace{5mm}\textbf{else} \\ - &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/ - \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\ - &\bf{return} \: \theta_t \\[-1.ex] - \end{aligned} - - .. warning:: - 这是一个实验性的优化器接口,需要和 `LRScheduler `_ 下的动态学习率接口配合使用。 - - 参数: - - **params** (Union[list(Parameter), list(dict)]) - 网络参数的列表或指定了参数组的列表。 - - **lr** (Union[int, float, Tensor], 可选) - 学习率。默认值:``1e-3``。 - - **betas** (Tuple[float, float], 可选) - 动量矩阵的指数衰减率。参数范围(0.0, 1.0)。默认值:``(0.9, 0.999)``。 - - **eps** (float, 可选) - 加在分母上的值,以确保数值稳定。必须大于0。默认值:``1e-8``。 - - **weight_decay** (float, 可选) - 权重衰减(L2 penalty)。默认值:``1e-2``。 - - **amsgrad** (bool, 可选) - 是否使用AMSGrad算法。默认值:``False``。 - - 关键字参数: - - **maximize** (bool, 可选) - 是否根据目标函数最大化网络参数。默认值:``False``。 - - 输入: - - **gradients** (tuple[Tensor], 可选) - 网络权重的梯度。 - - 异常: - - **ValueError** - 学习率不是int、float或Tensor。 - - **ValueError** - 学习率小于0。 - - **ValueError** - `eps` 小于0。 - - **ValueError** - `betas` 范围不在[0, 1)之间。 +mindspore.experimental.optim.AdamW +=================================== + +.. py:class:: mindspore.experimental.optim.AdamW(params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, *, maximize=False) + + Adaptive Moment Estimation Weight Decay(AdamW)算法的实现。 + + 更新公式如下: + + .. math:: + \begin{aligned} + &\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2 + \text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)}, + \: \epsilon \text{ (epsilon)} \\ + &\hspace{13mm} \lambda \text{(weight decay)}, \: \textit{amsgrad}, + \: \textit{maximize} \\ + &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0 + \text{ ( second moment)}, \: \widehat{v_0}^{max}\leftarrow 0 \\[-1.ex] + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\ + &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}\textbf{else} \\ + &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\ + &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ + &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\ + &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\ + &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\ + &\hspace{5mm}\textbf{if} \: amsgrad \\ + &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max}, + \widehat{v_t}) \\ + &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/ + \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\ + &\hspace{5mm}\textbf{else} \\ + &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/ + \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\ + &\bf{return} \: \theta_t \\[-1.ex] + \end{aligned} + + .. warning:: + 这是一个实验性的优化器接口,需要和 `LRScheduler `_ 下的动态学习率接口配合使用。 + + 参数: + - **params** (Union[list(Parameter), list(dict)]) - 网络参数的列表或指定了参数组的列表。 + - **lr** (Union[int, float, Tensor], 可选) - 学习率。默认值:``1e-3``。 + - **betas** (Tuple[float, float], 可选) - 动量矩阵的指数衰减率。参数范围(0.0, 1.0)。默认值:``(0.9, 0.999)``。 + - **eps** (float, 可选) - 加在分母上的值,以确保数值稳定。必须大于0。默认值:``1e-8``。 + - **weight_decay** (float, 可选) - 权重衰减(L2 penalty)。默认值:``1e-2``。 + - **amsgrad** (bool, 可选) - 是否使用AMSGrad算法。默认值:``False``。 + + 关键字参数: + - **maximize** (bool, 可选) - 是否根据目标函数最大化网络参数。默认值:``False``。 + + 输入: + - **gradients** (tuple[Tensor], 可选) - 网络权重的梯度。 + + 异常: + - **ValueError** - 学习率不是int、float或Tensor。 + - **ValueError** - 学习率小于0。 + - **ValueError** - `eps` 小于0。 + - **ValueError** - `betas` 范围不在[0, 1)之间。 - **ValueError** - `weight_decay` 小于0。 \ No newline at end of file diff --git a/docs/api/api_python/experimental/optim/mindspore.experimental.optim.Optimizer.rst b/docs/api/api_python/experimental/optim/mindspore.experimental.optim.Optimizer.rst index 02768fee9bc..dc848042c76 100644 --- a/docs/api/api_python/experimental/optim/mindspore.experimental.optim.Optimizer.rst +++ b/docs/api/api_python/experimental/optim/mindspore.experimental.optim.Optimizer.rst @@ -1,20 +1,20 @@ -mindspore.experimental.optim.Optimizer -======================================= - -.. py:class:: mindspore.experimental.optim.Optimizer(params, defaults) - - 用于参数更新的优化器基类。 - - .. warning:: - 这是一个实验性的优化器模块,需要和 `LRScheduler `_ 下的动态学习率接口配合使用。 - - 参数: - - **params** (Union[list(Parameter), list(dict)]) - 网络参数的列表或指定了参数组的列表。 - - **defaults** (dict) - 一个包含了优化器参数默认值的字典(当参数组未指定参数值时使用此默认值)。 - - .. py:method:: add_param_group(param_group) - - 为 `Optimizer.param_groups` 属性添加一个参数组。 - - 参数: +mindspore.experimental.optim.Optimizer +======================================= + +.. py:class:: mindspore.experimental.optim.Optimizer(params, defaults) + + 用于参数更新的优化器基类。 + + .. warning:: + 这是一个实验性的优化器模块,需要和 `LRScheduler `_ 下的动态学习率接口配合使用。 + + 参数: + - **params** (Union[list(Parameter), list(dict)]) - 网络参数的列表或指定了参数组的列表。 + - **defaults** (dict) - 一个包含了优化器参数默认值的字典(当参数组未指定参数值时使用此默认值)。 + + .. py:method:: add_param_group(param_group) + + 为 `Optimizer.param_groups` 属性添加一个参数组。 + + 参数: - **param_group** (dict) - 指定了当前网络参数组的特定的优化器配置。 \ No newline at end of file diff --git a/docs/api/api_python/experimental/optim/mindspore.experimental.optim.SGD.rst b/docs/api/api_python/experimental/optim/mindspore.experimental.optim.SGD.rst index 38e805d40f7..2287abf5ac6 100644 --- a/docs/api/api_python/experimental/optim/mindspore.experimental.optim.SGD.rst +++ b/docs/api/api_python/experimental/optim/mindspore.experimental.optim.SGD.rst @@ -1,46 +1,46 @@ -mindspore.experimental.optim.SGD -================================= - -.. py:class:: mindspore.experimental.optim.SGD(params, lr, momentum=0, dampening=0, weight_decay=0.0, nesterov=False, *, maximize=False) - - 随机梯度下降算法。 - - .. math:: - v_{t+1} = u \ast v_{t} + gradient \ast (1-dampening) - - 如果nesterov为True: - - .. math:: - p_{t+1} = p_{t} - lr \ast (gradient + u \ast v_{t+1}) - - 如果nesterov为False: - - .. math:: - p_{t+1} = p_{t} - lr \ast v_{t+1} - - 需要注意的是,对于训练的第一步 :math:`v_{t+1} = gradient`。其中,p、v和u分别表示 `parameters`、`accum` 和 `momentum`。 - - .. warning:: - 这是一个实验性的优化器接口,需要和 `LRScheduler `_ 下的动态学习率接口配合使用。 - - 参数: - - **params** (Union[list(Parameter), list(dict)]) - 网络参数的列表或指定了参数组的列表。 - - **lr** (Union[int, float, Tensor]) - 学习率。 - - **momentum** (Union[int, float], 可选) - 动量值。默认值:``0``。 - - **weight_decay** (float, 可选) - 权重衰减(L2 penalty),必须大于等于0。默认值:``0.``。 - - **dampening** (Union[int, float], 可选) - 动量的阻尼值。默认值:``0``。 - - **nesterov** (bool, 可选) - 启用Nesterov动量。如果使用Nesterov,动量必须为正,阻尼必须等于0.0。默认值:``False``。 - - 关键字参数: - - **maximize** (bool, 可选) - 是否根据目标函数最大化网络参数。默认值:``False``。 - - 输入: - - **gradients** (tuple[Tensor]) - 网络权重的梯度。 - - 异常: - - **ValueError** - 学习率不是int、float或Tensor。 - - **ValueError** - 学习率小于0。 - - **ValueError** - ``momentum`` 和 ``weight_decay`` 值小于0.0。 - - **ValueError** - ``momentum``, ``dampening`` 和 ``weight_decay`` 不是int或float。 - - **ValueError** - ``nesterov`` 和 ``maximize`` 不是布尔类型。 +mindspore.experimental.optim.SGD +================================= + +.. py:class:: mindspore.experimental.optim.SGD(params, lr, momentum=0, dampening=0, weight_decay=0.0, nesterov=False, *, maximize=False) + + 随机梯度下降算法。 + + .. math:: + v_{t+1} = u \ast v_{t} + gradient \ast (1-dampening) + + 如果nesterov为True: + + .. math:: + p_{t+1} = p_{t} - lr \ast (gradient + u \ast v_{t+1}) + + 如果nesterov为False: + + .. math:: + p_{t+1} = p_{t} - lr \ast v_{t+1} + + 需要注意的是,对于训练的第一步 :math:`v_{t+1} = gradient`。其中,p、v和u分别表示 `parameters`、`accum` 和 `momentum`。 + + .. warning:: + 这是一个实验性的优化器接口,需要和 `LRScheduler `_ 下的动态学习率接口配合使用。 + + 参数: + - **params** (Union[list(Parameter), list(dict)]) - 网络参数的列表或指定了参数组的列表。 + - **lr** (Union[int, float, Tensor]) - 学习率。 + - **momentum** (Union[int, float], 可选) - 动量值。默认值:``0``。 + - **weight_decay** (float, 可选) - 权重衰减(L2 penalty),必须大于等于0。默认值:``0.``。 + - **dampening** (Union[int, float], 可选) - 动量的阻尼值。默认值:``0``。 + - **nesterov** (bool, 可选) - 启用Nesterov动量。如果使用Nesterov,动量必须为正,阻尼必须等于0.0。默认值:``False``。 + + 关键字参数: + - **maximize** (bool, 可选) - 是否根据目标函数最大化网络参数。默认值:``False``。 + + 输入: + - **gradients** (tuple[Tensor]) - 网络权重的梯度。 + + 异常: + - **ValueError** - 学习率不是int、float或Tensor。 + - **ValueError** - 学习率小于0。 + - **ValueError** - ``momentum`` 和 ``weight_decay`` 值小于0.0。 + - **ValueError** - ``momentum``, ``dampening`` 和 ``weight_decay`` 不是int或float。 + - **ValueError** - ``nesterov`` 和 ``maximize`` 不是布尔类型。 - **ValueError** - ``nesterov`` 为True时, ``momentum`` 不为正或 ``dampening`` 不为0。 \ No newline at end of file diff --git a/docs/api/api_python/experimental/optim/mindspore.experimental.optim.lr_scheduler.ConstantLR.rst b/docs/api/api_python/experimental/optim/mindspore.experimental.optim.lr_scheduler.ConstantLR.rst index a4a502379f1..96776830797 100644 --- a/docs/api/api_python/experimental/optim/mindspore.experimental.optim.lr_scheduler.ConstantLR.rst +++ b/docs/api/api_python/experimental/optim/mindspore.experimental.optim.lr_scheduler.ConstantLR.rst @@ -1,15 +1,15 @@ -mindspore.experimental.optim.lr_scheduler.ConstantLR -======================================================= - -.. py:class:: mindspore.experimental.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0 / 3, total_iters=5, last_epoch=-1) - - 将每个参数组的学习率按照衰减因子 `factor` 进行衰减,直到 `last_epoch` 达到 `total_iters`。注意,这种衰减可能与外部对于学习率的改变同时发生。 - - .. warning:: - 这是一个实验性的动态学习率接口,需要和 `mindspore.experimental.optim `_ 下的接口配合使用。 - - 参数: - - **optimizer** (:class:`mindspore.experimental.optim.Optimizer`) - 优化器实例。 - - **factor** (float,可选) - 学习率的衰减因子。 默认值:``1.0 / 3``。 - - **total_iters** (int,可选) - 学习率进行衰减的执行次数,当 `last_epoch` 数达到 `total_iters`,恢复学习率。默认值:``5``. - - **last_epoch** (int,可选) - 当前scheduler的 `step()` 方法的执行次数。默认值:``-1``。 +mindspore.experimental.optim.lr_scheduler.ConstantLR +======================================================= + +.. py:class:: mindspore.experimental.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0 / 3, total_iters=5, last_epoch=-1) + + 将每个参数组的学习率按照衰减因子 `factor` 进行衰减,直到 `last_epoch` 达到 `total_iters`。注意,这种衰减可能与外部对于学习率的改变同时发生。 + + .. warning:: + 这是一个实验性的动态学习率接口,需要和 `mindspore.experimental.optim `_ 下的接口配合使用。 + + 参数: + - **optimizer** (:class:`mindspore.experimental.optim.Optimizer`) - 优化器实例。 + - **factor** (float,可选) - 学习率的衰减因子。 默认值:``1.0 / 3``。 + - **total_iters** (int,可选) - 学习率进行衰减的执行次数,当 `last_epoch` 数达到 `total_iters`,恢复学习率。默认值:``5``. + - **last_epoch** (int,可选) - 当前scheduler的 `step()` 方法的执行次数。默认值:``-1``。 diff --git a/docs/api/api_python/experimental/optim/mindspore.experimental.optim.lr_scheduler.ExponentialLR.rst b/docs/api/api_python/experimental/optim/mindspore.experimental.optim.lr_scheduler.ExponentialLR.rst index 4807bc7d322..fab2e66f7fa 100644 --- a/docs/api/api_python/experimental/optim/mindspore.experimental.optim.lr_scheduler.ExponentialLR.rst +++ b/docs/api/api_python/experimental/optim/mindspore.experimental.optim.lr_scheduler.ExponentialLR.rst @@ -1,14 +1,14 @@ -mindspore.experimental.optim.lr_scheduler.ExponentialLR -========================================================== - -.. py:class:: mindspore.experimental.optim.lr_scheduler.ExponentialLR(optimizer, gamma, last_epoch=-1) - - 每个epoch呈指数衰减的学习率,即乘以 `gamma` 。注意,这种衰减可能与外部对于学习率的改变同时发生。 - - .. warning:: - 这是一个实验性的动态学习率接口,需要和 `mindspore.experimental.optim `_ 下的接口配合使用。 - - 参数: - - **optimizer** (:class:`mindspore.experimental.optim.Optimizer`) - 优化器实例。 - - **gamma** (float) - 学习率衰减的乘法因子。 - - **last_epoch** (int,可选) - 最后一个epoch的索引。默认值: ``-1``。 +mindspore.experimental.optim.lr_scheduler.ExponentialLR +========================================================== + +.. py:class:: mindspore.experimental.optim.lr_scheduler.ExponentialLR(optimizer, gamma, last_epoch=-1) + + 每个epoch呈指数衰减的学习率,即乘以 `gamma` 。注意,这种衰减可能与外部对于学习率的改变同时发生。 + + .. warning:: + 这是一个实验性的动态学习率接口,需要和 `mindspore.experimental.optim `_ 下的接口配合使用。 + + 参数: + - **optimizer** (:class:`mindspore.experimental.optim.Optimizer`) - 优化器实例。 + - **gamma** (float) - 学习率衰减的乘法因子。 + - **last_epoch** (int,可选) - 最后一个epoch的索引。默认值: ``-1``。 diff --git a/docs/api/api_python/experimental/optim/mindspore.experimental.optim.lr_scheduler.LRScheduler.rst b/docs/api/api_python/experimental/optim/mindspore.experimental.optim.lr_scheduler.LRScheduler.rst index 23b6c4cd149..aee5496713c 100644 --- a/docs/api/api_python/experimental/optim/mindspore.experimental.optim.lr_scheduler.LRScheduler.rst +++ b/docs/api/api_python/experimental/optim/mindspore.experimental.optim.lr_scheduler.LRScheduler.rst @@ -1,32 +1,32 @@ -mindspore.experimental.optim.lr_scheduler.LRScheduler -======================================================= - -.. py:class:: mindspore.experimental.optim.lr_scheduler.LRScheduler(optimizer, last_epoch=-1) - - 动态学习率的基类。 - - .. warning:: - 这是一个实验性的动态学习率模块,需要和 `mindspore.experimental.optim `_ 下的接口配合使用。 - - 参数: - - **optimizer** (:class:`mindspore.experimental.optim.Optimizer`) - 优化器实例。 - - **last_epoch** (int,可选) - 当前scheduler的 `step()` 方法的执行次数。默认值: ``-1``。 - - 异常: - - **TypeError** - `optimizer` 不是优化器。 - - **KeyError** - `last_epoch` 不是 -1 且 ``'initial_lr'`` 不在参数组内。 - - **ValueError** - `last_epoch` 不是int类型。 - - **ValueError** - `last_epoch` 小于-1。 - - .. py:method:: get_last_lr() - - 返回当前使用的学习率。 - - .. py:method:: step(epoch=None) - - 按照定义的计算逻辑计算并修改学习率。 - - 参数: - - **epoch** (int,可选) - epoch数。默认值: ``None``。 - - +mindspore.experimental.optim.lr_scheduler.LRScheduler +======================================================= + +.. py:class:: mindspore.experimental.optim.lr_scheduler.LRScheduler(optimizer, last_epoch=-1) + + 动态学习率的基类。 + + .. warning:: + 这是一个实验性的动态学习率模块,需要和 `mindspore.experimental.optim `_ 下的接口配合使用。 + + 参数: + - **optimizer** (:class:`mindspore.experimental.optim.Optimizer`) - 优化器实例。 + - **last_epoch** (int,可选) - 当前scheduler的 `step()` 方法的执行次数。默认值: ``-1``。 + + 异常: + - **TypeError** - `optimizer` 不是优化器。 + - **KeyError** - `last_epoch` 不是 -1 且 ``'initial_lr'`` 不在参数组内。 + - **ValueError** - `last_epoch` 不是int类型。 + - **ValueError** - `last_epoch` 小于-1。 + + .. py:method:: get_last_lr() + + 返回当前使用的学习率。 + + .. py:method:: step(epoch=None) + + 按照定义的计算逻辑计算并修改学习率。 + + 参数: + - **epoch** (int,可选) - epoch数。默认值: ``None``。 + + diff --git a/docs/api/api_python/experimental/optim/mindspore.experimental.optim.lr_scheduler.LambdaLR.rst b/docs/api/api_python/experimental/optim/mindspore.experimental.optim.lr_scheduler.LambdaLR.rst index 079b4f5e6f2..cac0065396e 100644 --- a/docs/api/api_python/experimental/optim/mindspore.experimental.optim.lr_scheduler.LambdaLR.rst +++ b/docs/api/api_python/experimental/optim/mindspore.experimental.optim.lr_scheduler.LambdaLR.rst @@ -1,17 +1,17 @@ -mindspore.experimental.optim.lr_scheduler.LambdaLR -===================================================== - -.. py:class:: mindspore.experimental.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1) - - 将每个参数组的学习率设定为初始学习率乘以指定的 `lr_lambda` 函数。当 `last_epoch = -1` 时,将学习率设置成初始学习率。 - - .. warning:: - 这是一个实验性的动态学习率接口,需要和 `mindspore.experimental.optim `_ 下的接口配合使用。 - - 参数: - - **optimizer** (:class:`mindspore.experimental.optim.Optimizer`) - 优化器实例。 - - **lr_lambda** (Union(function, list)) - 一个关于 `last_epoch` 的匿名函数,或类似函数的列表,列表中每个函数对应 `optimizer.param_groups` 中的每个参数组。 - - **last_epoch** (int,可选) - 当前scheduler的 `step()` 方法的执行次数。默认值:``-1``。 - - 异常: - - **ValueError** - `lr_lambda` 的长度不等于参数组数目。 +mindspore.experimental.optim.lr_scheduler.LambdaLR +===================================================== + +.. py:class:: mindspore.experimental.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1) + + 将每个参数组的学习率设定为初始学习率乘以指定的 `lr_lambda` 函数。当 `last_epoch = -1` 时,将学习率设置成初始学习率。 + + .. warning:: + 这是一个实验性的动态学习率接口,需要和 `mindspore.experimental.optim `_ 下的接口配合使用。 + + 参数: + - **optimizer** (:class:`mindspore.experimental.optim.Optimizer`) - 优化器实例。 + - **lr_lambda** (Union(function, list)) - 一个关于 `last_epoch` 的匿名函数,或类似函数的列表,列表中每个函数对应 `optimizer.param_groups` 中的每个参数组。 + - **last_epoch** (int,可选) - 当前scheduler的 `step()` 方法的执行次数。默认值:``-1``。 + + 异常: + - **ValueError** - `lr_lambda` 的长度不等于参数组数目。 diff --git a/docs/api/api_python/experimental/optim/mindspore.experimental.optim.lr_scheduler.LinearLR.rst b/docs/api/api_python/experimental/optim/mindspore.experimental.optim.lr_scheduler.LinearLR.rst index ab6051b286f..a536508e9e5 100644 --- a/docs/api/api_python/experimental/optim/mindspore.experimental.optim.lr_scheduler.LinearLR.rst +++ b/docs/api/api_python/experimental/optim/mindspore.experimental.optim.lr_scheduler.LinearLR.rst @@ -1,20 +1,20 @@ -mindspore.experimental.optim.lr_scheduler.LinearLR -======================================================== - -.. py:class:: mindspore.experimental.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0 / 3, end_factor=1.0, total_iters=5, last_epoch=-1) - - 线性减小学习率乘法因子 ,并将每个参数组的学习率按照此乘法因子进行衰减,直到 `last_epoch` 数达到 `total_iters`。注意,这种衰减可能与外部对于学习率的改变同时发生。 - - .. warning:: - 这是一个实验性的动态学习率接口,需要和 `mindspore.experimental.optim `_ 下的接口配合使用。 - - 参数: - - **optimizer** (:class:`mindspore.experimental.optim.Optimizer`) - 优化器实例。 - - **start_factor** (float,可选) - 初始的乘法因子值,后续向 `end_factor` 进行线性变化。默认值: ``1.0 /3``。 - - **end_factor** (float,可选) - 线性变化过程结束时的乘法因子值。默认值: ``1.0``。 - - **total_iters** (int,可选) - 迭代的次数。默认值: ``5``。 - - **last_epoch** (int,可选) - 当前scheduler的 `step()` 方法的执行次数。默认值: ``-1``。 - - 异常: - - **ValueError** - `start_factor` 不在(0, 1]范围内。 +mindspore.experimental.optim.lr_scheduler.LinearLR +======================================================== + +.. py:class:: mindspore.experimental.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0 / 3, end_factor=1.0, total_iters=5, last_epoch=-1) + + 线性减小学习率乘法因子 ,并将每个参数组的学习率按照此乘法因子进行衰减,直到 `last_epoch` 数达到 `total_iters`。注意,这种衰减可能与外部对于学习率的改变同时发生。 + + .. warning:: + 这是一个实验性的动态学习率接口,需要和 `mindspore.experimental.optim `_ 下的接口配合使用。 + + 参数: + - **optimizer** (:class:`mindspore.experimental.optim.Optimizer`) - 优化器实例。 + - **start_factor** (float,可选) - 初始的乘法因子值,后续向 `end_factor` 进行线性变化。默认值: ``1.0 /3``。 + - **end_factor** (float,可选) - 线性变化过程结束时的乘法因子值。默认值: ``1.0``。 + - **total_iters** (int,可选) - 迭代的次数。默认值: ``5``。 + - **last_epoch** (int,可选) - 当前scheduler的 `step()` 方法的执行次数。默认值: ``-1``。 + + 异常: + - **ValueError** - `start_factor` 不在(0, 1]范围内。 - **ValueError** - `end_factor` 不在[0, 1]范围内。 \ No newline at end of file diff --git a/docs/api/api_python/experimental/optim/mindspore.experimental.optim.lr_scheduler.MultiStepLR.rst b/docs/api/api_python/experimental/optim/mindspore.experimental.optim.lr_scheduler.MultiStepLR.rst index 721b493bbe6..b33d66cc539 100644 --- a/docs/api/api_python/experimental/optim/mindspore.experimental.optim.lr_scheduler.MultiStepLR.rst +++ b/docs/api/api_python/experimental/optim/mindspore.experimental.optim.lr_scheduler.MultiStepLR.rst @@ -1,20 +1,20 @@ -mindspore.experimental.optim.lr_scheduler.MultiStepLR -======================================================= - -.. py:class:: mindspore.experimental.optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1, last_epoch=-1) - - 当epoch/step达到 `milestones` 时,将每个参数组的学习率按照乘法因子 `gamma` 进行变化。注意,这种衰减可能与外部对于学习率的改变同时发生。 - - .. warning:: - 这是一个实验性的动态学习率接口,需要和 `mindspore.experimental.optim `_ 下的接口配合使用。 - - 参数: - - **optimizer** (:class:`mindspore.experimental.optim.Optimizer`) - 优化器实例。 - - **milestones** (list) - 阈值列表,当 `last_epoch` 数达到阈值时将学习率乘以 `gamma`。 - - **gamma** (float,可选) - 学习率的乘法因子。默认值: ``0.1``。 - - **last_epoch** (int,可选) - 当前scheduler的 `step()` 方法的执行次数。默认值:``-1``。 - - 异常: - - **TypeError** - `milestones` 不是列表。 - - **TypeError** - `milestones` 的元素不是int类型。 +mindspore.experimental.optim.lr_scheduler.MultiStepLR +======================================================= + +.. py:class:: mindspore.experimental.optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1, last_epoch=-1) + + 当epoch/step达到 `milestones` 时,将每个参数组的学习率按照乘法因子 `gamma` 进行变化。注意,这种衰减可能与外部对于学习率的改变同时发生。 + + .. warning:: + 这是一个实验性的动态学习率接口,需要和 `mindspore.experimental.optim `_ 下的接口配合使用。 + + 参数: + - **optimizer** (:class:`mindspore.experimental.optim.Optimizer`) - 优化器实例。 + - **milestones** (list) - 阈值列表,当 `last_epoch` 数达到阈值时将学习率乘以 `gamma`。 + - **gamma** (float,可选) - 学习率的乘法因子。默认值: ``0.1``。 + - **last_epoch** (int,可选) - 当前scheduler的 `step()` 方法的执行次数。默认值:``-1``。 + + 异常: + - **TypeError** - `milestones` 不是列表。 + - **TypeError** - `milestones` 的元素不是int类型。 - **TypeError** - `gamma` 不是float类型。 \ No newline at end of file diff --git a/docs/api/api_python/experimental/optim/mindspore.experimental.optim.lr_scheduler.MultiplicativeLR.rst b/docs/api/api_python/experimental/optim/mindspore.experimental.optim.lr_scheduler.MultiplicativeLR.rst index 922fabaebb0..a28ccadc9c5 100644 --- a/docs/api/api_python/experimental/optim/mindspore.experimental.optim.lr_scheduler.MultiplicativeLR.rst +++ b/docs/api/api_python/experimental/optim/mindspore.experimental.optim.lr_scheduler.MultiplicativeLR.rst @@ -1,14 +1,14 @@ -mindspore.experimental.optim.lr_scheduler.MultiplicativeLR -============================================================= - -.. py:class:: mindspore.experimental.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda, last_epoch=-1) - - 将每个参数组当前的学习率按照传入的 `lr_lambda` 函数乘以指定的乘法因子。当 `last_epoch = -1` 时,将学习率设置成初始学习率。 - - .. warning:: - 这是一个实验性的动态学习率接口,需要和 `mindspore.experimental.optim `_ 下的接口配合使用。 - - 参数: - - **optimizer** (:class:`mindspore.experimental.optim.Optimizer`) - 优化器实例。 - - **lr_lambda** (Union(function, list)) - 一个关于epoch/step的乘法函数,或类似函数的列表,列表中每个函数对应 `optimizer.param_groups` 中的每个参数组。 - - **last_epoch** (int,可选) - 当前scheduler的 `step()` 方法的执行次数。默认值:``-1``。 +mindspore.experimental.optim.lr_scheduler.MultiplicativeLR +============================================================= + +.. py:class:: mindspore.experimental.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda, last_epoch=-1) + + 将每个参数组当前的学习率按照传入的 `lr_lambda` 函数乘以指定的乘法因子。当 `last_epoch = -1` 时,将学习率设置成初始学习率。 + + .. warning:: + 这是一个实验性的动态学习率接口,需要和 `mindspore.experimental.optim `_ 下的接口配合使用。 + + 参数: + - **optimizer** (:class:`mindspore.experimental.optim.Optimizer`) - 优化器实例。 + - **lr_lambda** (Union(function, list)) - 一个关于epoch/step的乘法函数,或类似函数的列表,列表中每个函数对应 `optimizer.param_groups` 中的每个参数组。 + - **last_epoch** (int,可选) - 当前scheduler的 `step()` 方法的执行次数。默认值:``-1``。 diff --git a/docs/api/api_python/experimental/optim/mindspore.experimental.optim.lr_scheduler.PolynomialLR.rst b/docs/api/api_python/experimental/optim/mindspore.experimental.optim.lr_scheduler.PolynomialLR.rst index ecdac37d7fc..b91867ca283 100644 --- a/docs/api/api_python/experimental/optim/mindspore.experimental.optim.lr_scheduler.PolynomialLR.rst +++ b/docs/api/api_python/experimental/optim/mindspore.experimental.optim.lr_scheduler.PolynomialLR.rst @@ -1,24 +1,24 @@ -mindspore.experimental.optim.lr_scheduler.PolynomialLR -======================================================= - -.. py:class:: mindspore.experimental.optim.lr_scheduler.PolynomialLR(optimizer, total_iters=5, power=1.0, last_epoch=-1) - - 每个epoch,学习率通过多项式拟合来调整。当epoch大于等于 `total_iters` 时,学习率设置为 ``0`` 。注意,这种衰减可能与外部对于学习率的改变同时发生。 - - 学习率计算的多项式公式如下: - - .. math:: - \begin{split} - &factor = (\frac{1.0 - \frac{last\_epoch}{total\_iters}}{1.0 - \frac{last\_epoch - 1.0}{total\_iters}}) - ^{power}\\ - &lr = lr \times factor - \end{split} - - .. warning:: - 这是一个实验性的动态学习率接口,需要和 `mindspore.experimental.optim `_ 下的接口配合使用。 - - 参数: - - **optimizer** (:class:`mindspore.experimental.optim.Optimizer`) - 优化器实例。 - - **total_iters** (int,可选) - 通过多项式拟合调整学习率的迭代次数。默认值: ``5``。 - - **power** (float,可选) - 多项式的幂。默认值: ``1.0``。 - - **last_epoch** (int,可选) - 最后一个epoch的索引。默认值: ``-1``。 +mindspore.experimental.optim.lr_scheduler.PolynomialLR +======================================================= + +.. py:class:: mindspore.experimental.optim.lr_scheduler.PolynomialLR(optimizer, total_iters=5, power=1.0, last_epoch=-1) + + 每个epoch,学习率通过多项式拟合来调整。当epoch大于等于 `total_iters` 时,学习率设置为 ``0`` 。注意,这种衰减可能与外部对于学习率的改变同时发生。 + + 学习率计算的多项式公式如下: + + .. math:: + \begin{split} + &factor = (\frac{1.0 - \frac{last\_epoch}{total\_iters}}{1.0 - \frac{last\_epoch - 1.0}{total\_iters}}) + ^{power}\\ + &lr = lr \times factor + \end{split} + + .. warning:: + 这是一个实验性的动态学习率接口,需要和 `mindspore.experimental.optim `_ 下的接口配合使用。 + + 参数: + - **optimizer** (:class:`mindspore.experimental.optim.Optimizer`) - 优化器实例。 + - **total_iters** (int,可选) - 通过多项式拟合调整学习率的迭代次数。默认值: ``5``。 + - **power** (float,可选) - 多项式的幂。默认值: ``1.0``。 + - **last_epoch** (int,可选) - 最后一个epoch的索引。默认值: ``-1``。 diff --git a/docs/api/api_python/experimental/optim/mindspore.experimental.optim.lr_scheduler.StepLR.rst b/docs/api/api_python/experimental/optim/mindspore.experimental.optim.lr_scheduler.StepLR.rst index 450090fd0dd..06c06c893b9 100644 --- a/docs/api/api_python/experimental/optim/mindspore.experimental.optim.lr_scheduler.StepLR.rst +++ b/docs/api/api_python/experimental/optim/mindspore.experimental.optim.lr_scheduler.StepLR.rst @@ -1,15 +1,15 @@ -mindspore.experimental.optim.lr_scheduler.StepLR -================================================= - -.. py:class:: mindspore.experimental.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1, last_epoch=-1) - - 每 `step_size` 个epoch按 `gamma` 衰减每个参数组的学习率。`StepLR` 对于学习率的衰减可能与外部对于学习率的改变同时发生。 - - .. warning:: - 这是一个实验性的动态学习率接口,需要和 `mindspore.experimental.optim `_ 下的接口配合使用。 - - 参数: - - **optimizer** (:class:`mindspore.experimental.optim.Optimizer`) - 优化器实例。 - - **step_size** (int) - 学习率衰减的周期。 - - **gamma** (float,可选) - 学习率衰减的乘法因子。默认值: ``0.1``。 +mindspore.experimental.optim.lr_scheduler.StepLR +================================================= + +.. py:class:: mindspore.experimental.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1, last_epoch=-1) + + 每 `step_size` 个epoch按 `gamma` 衰减每个参数组的学习率。`StepLR` 对于学习率的衰减可能与外部对于学习率的改变同时发生。 + + .. warning:: + 这是一个实验性的动态学习率接口,需要和 `mindspore.experimental.optim `_ 下的接口配合使用。 + + 参数: + - **optimizer** (:class:`mindspore.experimental.optim.Optimizer`) - 优化器实例。 + - **step_size** (int) - 学习率衰减的周期。 + - **gamma** (float,可选) - 学习率衰减的乘法因子。默认值: ``0.1``。 - **last_epoch** (int,可选) - 当前scheduler的 `step()` 方法的执行次数。默认值: ``-1``。 \ No newline at end of file diff --git a/docs/api/api_python/mindspore.experimental.rst b/docs/api/api_python/mindspore.experimental.rst index a227f85a210..1b29c5012d5 100644 --- a/docs/api/api_python/mindspore.experimental.rst +++ b/docs/api/api_python/mindspore.experimental.rst @@ -1,79 +1,79 @@ -mindspore.experimental -======================= - -实验性模块。 - -实验性优化器 ------------- - -.. mscnplatformautosummary:: - :toctree: experimental/optim - :nosignatures: - :template: classtemplate.rst - - mindspore.experimental.optim.Optimizer - mindspore.experimental.optim.Adadelta - mindspore.experimental.optim.Adagrad - mindspore.experimental.optim.Adam - mindspore.experimental.optim.Adamax - mindspore.experimental.optim.AdamW - mindspore.experimental.optim.ASGD - mindspore.experimental.optim.NAdam - mindspore.experimental.optim.RAdam - mindspore.experimental.optim.RMSprop - mindspore.experimental.optim.Rprop - mindspore.experimental.optim.SGD - -LRScheduler类 -^^^^^^^^^^^^^^^^ - -本模块中的动态学习率都是LRScheduler的子类,此模块仅与mindspore.experimental.optim下的优化器配合使用,使用时将优化器实例传递给LRScheduler类。在训练过程中,LRScheduler子类通过调用 `step` 方法进行学习率的动态改变。 - -.. code-block:: - - import mindspore - from mindspore import nn - from mindspore.experimental import optim - # Define the network structure of LeNet5. Refer to - # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py - - net = LeNet5() - loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True) - optimizer = optim.Adam(net.trainable_params(), lr=0.05) - scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1) - def forward_fn(data, label): - logits = net(data) - loss = loss_fn(logits, label) - return loss, logits - grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True) - def train_step(data, label): - (loss, _), grads = grad_fn(data, label) - optimizer(grads) - return loss - for epoch in range(6): - # Create the dataset taking MNIST as an example. Refer to - # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py - - for data, label in create_dataset(need_download=False): - train_step(data, label) - scheduler.step() - -.. mscnplatformautosummary:: - :toctree: experimental/optim - :nosignatures: - :template: classtemplate.rst - - mindspore.experimental.optim.lr_scheduler.LRScheduler - mindspore.experimental.optim.lr_scheduler.ConstantLR - mindspore.experimental.optim.lr_scheduler.CosineAnnealingLR - mindspore.experimental.optim.lr_scheduler.CosineAnnealingWarmRestarts - mindspore.experimental.optim.lr_scheduler.CyclicLR - mindspore.experimental.optim.lr_scheduler.ExponentialLR - mindspore.experimental.optim.lr_scheduler.LambdaLR - mindspore.experimental.optim.lr_scheduler.LinearLR - mindspore.experimental.optim.lr_scheduler.MultiplicativeLR - mindspore.experimental.optim.lr_scheduler.MultiStepLR - mindspore.experimental.optim.lr_scheduler.PolynomialLR - mindspore.experimental.optim.lr_scheduler.ReduceLROnPlateau - mindspore.experimental.optim.lr_scheduler.SequentialLR +mindspore.experimental +======================= + +实验性模块。 + +实验性优化器 +------------ + +.. mscnplatformautosummary:: + :toctree: experimental/optim + :nosignatures: + :template: classtemplate.rst + + mindspore.experimental.optim.Optimizer + mindspore.experimental.optim.Adadelta + mindspore.experimental.optim.Adagrad + mindspore.experimental.optim.Adam + mindspore.experimental.optim.Adamax + mindspore.experimental.optim.AdamW + mindspore.experimental.optim.ASGD + mindspore.experimental.optim.NAdam + mindspore.experimental.optim.RAdam + mindspore.experimental.optim.RMSprop + mindspore.experimental.optim.Rprop + mindspore.experimental.optim.SGD + +LRScheduler类 +^^^^^^^^^^^^^^^^ + +本模块中的动态学习率都是LRScheduler的子类,此模块仅与mindspore.experimental.optim下的优化器配合使用,使用时将优化器实例传递给LRScheduler类。在训练过程中,LRScheduler子类通过调用 `step` 方法进行学习率的动态改变。 + +.. code-block:: + + import mindspore + from mindspore import nn + from mindspore.experimental import optim + # Define the network structure of LeNet5. Refer to + # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py + + net = LeNet5() + loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True) + optimizer = optim.Adam(net.trainable_params(), lr=0.05) + scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1) + def forward_fn(data, label): + logits = net(data) + loss = loss_fn(logits, label) + return loss, logits + grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True) + def train_step(data, label): + (loss, _), grads = grad_fn(data, label) + optimizer(grads) + return loss + for epoch in range(6): + # Create the dataset taking MNIST as an example. Refer to + # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py + + for data, label in create_dataset(need_download=False): + train_step(data, label) + scheduler.step() + +.. mscnplatformautosummary:: + :toctree: experimental/optim + :nosignatures: + :template: classtemplate.rst + + mindspore.experimental.optim.lr_scheduler.LRScheduler + mindspore.experimental.optim.lr_scheduler.ConstantLR + mindspore.experimental.optim.lr_scheduler.CosineAnnealingLR + mindspore.experimental.optim.lr_scheduler.CosineAnnealingWarmRestarts + mindspore.experimental.optim.lr_scheduler.CyclicLR + mindspore.experimental.optim.lr_scheduler.ExponentialLR + mindspore.experimental.optim.lr_scheduler.LambdaLR + mindspore.experimental.optim.lr_scheduler.LinearLR + mindspore.experimental.optim.lr_scheduler.MultiplicativeLR + mindspore.experimental.optim.lr_scheduler.MultiStepLR + mindspore.experimental.optim.lr_scheduler.PolynomialLR + mindspore.experimental.optim.lr_scheduler.ReduceLROnPlateau + mindspore.experimental.optim.lr_scheduler.SequentialLR mindspore.experimental.optim.lr_scheduler.StepLR \ No newline at end of file diff --git a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.floor.rst b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.floor.rst index c670cbd540f..999ed0f51e9 100644 --- a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.floor.rst +++ b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.floor.rst @@ -1,6 +1,6 @@ -mindspore.Tensor.floor -====================== - -.. py:method:: mindspore.Tensor.floor() - - 详情请参考 :func:`mindspore.ops.floor`。 +mindspore.Tensor.floor +====================== + +.. py:method:: mindspore.Tensor.floor() + + 详情请参考 :func:`mindspore.ops.floor`。 diff --git a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.fold.rst b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.fold.rst index 8aa1b90c2df..072d8f3a249 100644 --- a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.fold.rst +++ b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.fold.rst @@ -1,6 +1,6 @@ -mindspore.Tensor.fold -======================== - -.. py:method:: mindspore.Tensor.fold(output_size, kernel_size, dilation=1, padding=0, stride=1) - - 详情请参考 :func:`mindspore.ops.fold`。 +mindspore.Tensor.fold +======================== + +.. py:method:: mindspore.Tensor.fold(output_size, kernel_size, dilation=1, padding=0, stride=1) + + 详情请参考 :func:`mindspore.ops.fold`。 diff --git a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.ge.rst b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.ge.rst index 40d14a59a05..454b677467d 100644 --- a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.ge.rst +++ b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.ge.rst @@ -1,6 +1,6 @@ -mindspore.Tensor.ge -=================== - -.. py:method:: mindspore.Tensor.ge(x) - - 详情请参考 :func:`mindspore.ops.ge`。 +mindspore.Tensor.ge +=================== + +.. py:method:: mindspore.Tensor.ge(x) + + 详情请参考 :func:`mindspore.ops.ge`。 diff --git a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.greater.rst b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.greater.rst index e98f654fc19..5a8d59bfc0f 100644 --- a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.greater.rst +++ b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.greater.rst @@ -1,6 +1,6 @@ -mindspore.Tensor.greater -======================== - -.. py:method:: mindspore.Tensor.greater(other) - - 详情请参考 :func:`mindspore.ops.greater`。 +mindspore.Tensor.greater +======================== + +.. py:method:: mindspore.Tensor.greater(other) + + 详情请参考 :func:`mindspore.ops.greater`。 diff --git a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.greater_equal.rst b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.greater_equal.rst index 75aeda80046..47112d2060d 100644 --- a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.greater_equal.rst +++ b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.greater_equal.rst @@ -1,6 +1,6 @@ -mindspore.Tensor.greater_equal -============================== - -.. py:method:: mindspore.Tensor.greater_equal(other) - - 详情请参考 :func:`mindspore.ops.greater_equal`。 +mindspore.Tensor.greater_equal +============================== + +.. py:method:: mindspore.Tensor.greater_equal(other) + + 详情请参考 :func:`mindspore.ops.greater_equal`。 diff --git a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.gt.rst b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.gt.rst index 12065edab16..292545bcf09 100644 --- a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.gt.rst +++ b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.gt.rst @@ -1,6 +1,6 @@ -mindspore.Tensor.gt -==================== - -.. py:method:: mindspore.Tensor.gt(x) - - 详情请参考 :func:`mindspore.ops.gt`。 +mindspore.Tensor.gt +==================== + +.. py:method:: mindspore.Tensor.gt(x) + + 详情请参考 :func:`mindspore.ops.gt`。 diff --git a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.igamma.rst b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.igamma.rst index 9f44fa24cb5..8a269ef6e1b 100644 --- a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.igamma.rst +++ b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.igamma.rst @@ -1,6 +1,6 @@ -mindspore.Tensor.igamma -======================= - -.. py:method:: mindspore.Tensor.igamma(other) - - 详情请参考 :func:`mindspore.ops.igamma`。 +mindspore.Tensor.igamma +======================= + +.. py:method:: mindspore.Tensor.igamma(other) + + 详情请参考 :func:`mindspore.ops.igamma`。 diff --git a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.igammac.rst b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.igammac.rst index e33377cd2e5..e99a7a50bf9 100644 --- a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.igammac.rst +++ b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.igammac.rst @@ -1,6 +1,6 @@ -mindspore.Tensor.igammac -======================== - -.. py:method:: mindspore.Tensor.igammac(other) - - 详情请参考 :func:`mindspore.ops.igammac`。 +mindspore.Tensor.igammac +======================== + +.. py:method:: mindspore.Tensor.igammac(other) + + 详情请参考 :func:`mindspore.ops.igammac`。 diff --git a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.index_add.rst b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.index_add.rst index eea95b150fb..4734e1cc98b 100644 --- a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.index_add.rst +++ b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.index_add.rst @@ -1,6 +1,6 @@ -mindspore.Tensor.index_add -========================== - -.. py:method:: mindspore.Tensor.index_add(dim, index, source, *, alpha=1) - - 详情请参考 :func:`mindspore.ops.index_add`。 +mindspore.Tensor.index_add +========================== + +.. py:method:: mindspore.Tensor.index_add(dim, index, source, *, alpha=1) + + 详情请参考 :func:`mindspore.ops.index_add`。 diff --git a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.isinf.rst b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.isinf.rst index a78544aaffc..9f35900330d 100644 --- a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.isinf.rst +++ b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.isinf.rst @@ -1,6 +1,6 @@ -mindspore.Tensor.isinf -====================== - -.. py:method:: mindspore.Tensor.isinf() - - 详情请参考 :func:`mindspore.ops.isinf`。 +mindspore.Tensor.isinf +====================== + +.. py:method:: mindspore.Tensor.isinf() + + 详情请参考 :func:`mindspore.ops.isinf`。 diff --git a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.isnan.rst b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.isnan.rst index 85553c56d95..2e81deecda7 100644 --- a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.isnan.rst +++ b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.isnan.rst @@ -1,6 +1,6 @@ -mindspore.Tensor.isnan -====================== - -.. py:method:: mindspore.Tensor.isnan() - - 详情请参考 :func:`mindspore.ops.isnan`。 +mindspore.Tensor.isnan +====================== + +.. py:method:: mindspore.Tensor.isnan() + + 详情请参考 :func:`mindspore.ops.isnan`。 diff --git a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.le.rst b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.le.rst index d52cbdac759..6999508eab0 100644 --- a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.le.rst +++ b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.le.rst @@ -1,6 +1,6 @@ -mindspore.Tensor.le -=================== - -.. py:method:: mindspore.Tensor.le(other) - - 详情请参考 :func:`mindspore.ops.le`。 +mindspore.Tensor.le +=================== + +.. py:method:: mindspore.Tensor.le(other) + + 详情请参考 :func:`mindspore.ops.le`。 diff --git a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.less.rst b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.less.rst index 8cd5b477f8f..7e3b7550100 100644 --- a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.less.rst +++ b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.less.rst @@ -1,6 +1,6 @@ -mindspore.Tensor.less -===================== - -.. py:method:: mindspore.Tensor.less(other) - - 详情请参考 :func:`mindspore.ops.less`。 +mindspore.Tensor.less +===================== + +.. py:method:: mindspore.Tensor.less(other) + + 详情请参考 :func:`mindspore.ops.less`。 diff --git a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.log_normal.rst b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.log_normal.rst index cf4a3920c71..6bc12164eb8 100644 --- a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.log_normal.rst +++ b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.log_normal.rst @@ -1,21 +1,21 @@ -mindspore.Tensor.log_normal -============================ - -.. py:method:: mindspore.Tensor.log_normal(mean=1.0, std=2.0) - - 使用给定均值 `mean` 和标准差 `std` 的对数正态分布的数值填充当前Tensor。 - - .. math:: - \text{f}(x;1.0,2.0)=\frac{1}{x\delta \sqrt[]{2\pi} }e^{-\frac{(\ln x-\mu )^2}{2\delta ^2} } - - 其中 :math:`\mu`、:math:`\delta` 分别是对数正态分布的均值和标准差。 - - .. warning:: - 这是一个实验性API,后续可能修改或删除。 - - 参数: - - **mean** (float, 可选) - 对数正态分布的均值。默认值:1.0。 - - **std** (float, 可选) - 对数正态分布的标准差。默认值:2.0。 - - 返回: +mindspore.Tensor.log_normal +============================ + +.. py:method:: mindspore.Tensor.log_normal(mean=1.0, std=2.0) + + 使用给定均值 `mean` 和标准差 `std` 的对数正态分布的数值填充当前Tensor。 + + .. math:: + \text{f}(x;1.0,2.0)=\frac{1}{x\delta \sqrt[]{2\pi} }e^{-\frac{(\ln x-\mu )^2}{2\delta ^2} } + + 其中 :math:`\mu`、:math:`\delta` 分别是对数正态分布的均值和标准差。 + + .. warning:: + 这是一个实验性API,后续可能修改或删除。 + + 参数: + - **mean** (float, 可选) - 对数正态分布的均值。默认值:1.0。 + - **std** (float, 可选) - 对数正态分布的标准差。默认值:2.0。 + + 返回: Tensor,具有与当前Tensor相同的shape和dtype。 \ No newline at end of file diff --git a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.logical_and.rst b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.logical_and.rst index 742fc9dc3a7..c718bf07571 100644 --- a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.logical_and.rst +++ b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.logical_and.rst @@ -1,6 +1,6 @@ -mindspore.Tensor.logical_and -============================ - -.. py:method:: mindspore.Tensor.logical_and(other) - - 详情请参考 :func:`mindspore.ops.logical_and`。 +mindspore.Tensor.logical_and +============================ + +.. py:method:: mindspore.Tensor.logical_and(other) + + 详情请参考 :func:`mindspore.ops.logical_and`。 diff --git a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.logical_not.rst b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.logical_not.rst index feb0e697534..6fb74b3ce9a 100644 --- a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.logical_not.rst +++ b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.logical_not.rst @@ -1,6 +1,6 @@ -mindspore.Tensor.logical_not -============================ - -.. py:method:: mindspore.Tensor.logical_not() - - 详情请参考 :func:`mindspore.ops.logical_not`。 +mindspore.Tensor.logical_not +============================ + +.. py:method:: mindspore.Tensor.logical_not() + + 详情请参考 :func:`mindspore.ops.logical_not`。 diff --git a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.logical_or.rst b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.logical_or.rst index 86a2c787715..585b563d06a 100644 --- a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.logical_or.rst +++ b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.logical_or.rst @@ -1,6 +1,6 @@ -mindspore.Tensor.logical_or -=========================== - -.. py:method:: mindspore.Tensor.logical_or(other) - - 详情请参考 :func:`mindspore.ops.logical_or`。 +mindspore.Tensor.logical_or +=========================== + +.. py:method:: mindspore.Tensor.logical_or(other) + + 详情请参考 :func:`mindspore.ops.logical_or`。 diff --git a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.logical_xor.rst b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.logical_xor.rst index e6a971d5889..d9d07b3a91c 100644 --- a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.logical_xor.rst +++ b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.logical_xor.rst @@ -1,6 +1,6 @@ -mindspore.Tensor.logical_xor -============================ - -.. py:method:: mindspore.Tensor.logical_xor(other) - - 详情请参考 :func:`mindspore.ops.logical_xor`。 +mindspore.Tensor.logical_xor +============================ + +.. py:method:: mindspore.Tensor.logical_xor(other) + + 详情请参考 :func:`mindspore.ops.logical_xor`。 diff --git a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.lt.rst b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.lt.rst index 317afbf3a38..6427058d4e0 100644 --- a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.lt.rst +++ b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.lt.rst @@ -1,6 +1,6 @@ -mindspore.Tensor.lt -=================== - -.. py:method:: mindspore.Tensor.lt(other) - - :func:`mindspore.Tensor.less` 的别名。 +mindspore.Tensor.lt +=================== + +.. py:method:: mindspore.Tensor.lt(other) + + :func:`mindspore.Tensor.less` 的别名。 diff --git a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.roll.rst b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.roll.rst index a09a40a5fc8..b29eb500b80 100644 --- a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.roll.rst +++ b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.roll.rst @@ -1,6 +1,6 @@ -mindspore.Tensor.roll -====================== - -.. py:method:: mindspore.Tensor.roll(shifts, dims) - +mindspore.Tensor.roll +====================== + +.. py:method:: mindspore.Tensor.roll(shifts, dims) + 详情请参考 :func:`mindspore.ops.roll`。 \ No newline at end of file diff --git a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.rot90.rst b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.rot90.rst index a2148f5e3c1..c1db434e74d 100644 --- a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.rot90.rst +++ b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.rot90.rst @@ -1,6 +1,6 @@ -mindspore.Tensor.rot90 -======================= - -.. py:method:: mindspore.Tensor.rot90(k, dims) - +mindspore.Tensor.rot90 +======================= + +.. py:method:: mindspore.Tensor.rot90(k, dims) + 详情请参考 :func:`mindspore.ops.rot90`。 \ No newline at end of file diff --git a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.unfold.rst b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.unfold.rst index e6a619bd421..c581e0467a8 100644 --- a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.unfold.rst +++ b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.unfold.rst @@ -1,9 +1,9 @@ -mindspore.Tensor.unfold -======================= - -.. py:method:: mindspore.Tensor.unfold(kernel_size, dilation=1, padding=0, stride=1) - - 详情请参考 :func:`mindspore.ops.unfold`。 - - .. warning:: - 这是一个实验性API,后续可能修改或删除。 +mindspore.Tensor.unfold +======================= + +.. py:method:: mindspore.Tensor.unfold(kernel_size, dilation=1, padding=0, stride=1) + + 详情请参考 :func:`mindspore.ops.unfold`。 + + .. warning:: + 这是一个实验性API,后续可能修改或删除。 diff --git a/docs/api/api_python/mindspore/mindspore.QuantDtype.rst b/docs/api/api_python/mindspore/mindspore.QuantDtype.rst index 370ef02b7fd..71bc3b60afb 100644 --- a/docs/api/api_python/mindspore/mindspore.QuantDtype.rst +++ b/docs/api/api_python/mindspore/mindspore.QuantDtype.rst @@ -1,23 +1,23 @@ -mindspore.QuantDtype -==================== - -.. py:class:: mindspore.QuantDtype - - MindSpore量化数据类型枚举类,包含 `INT1` ~ `INT16`,`UINT1` ~ `UINT16` 。 - - `QuantDtype` 定义在 `dtype.py `_ 文件下 。运行以下命令导入环境: - - .. code-block:: - - from mindspore import QuantDtype - - 教程样例: - - `昇思金箍棒量化感知训练时配置算法 - `_ - - .. py:method:: value() - - 获取当前 `QuantDtype` 的值。该接口当前主要用于序列化或反序列化 `QuantDtype` 。 - - 返回: - int,表示当前 `QuantDtype` 的值。 +mindspore.QuantDtype +==================== + +.. py:class:: mindspore.QuantDtype + + MindSpore量化数据类型枚举类,包含 `INT1` ~ `INT16`,`UINT1` ~ `UINT16` 。 + + `QuantDtype` 定义在 `dtype.py `_ 文件下 。运行以下命令导入环境: + + .. code-block:: + + from mindspore import QuantDtype + + 教程样例: + - `昇思金箍棒量化感知训练时配置算法 + `_ + + .. py:method:: value() + + 获取当前 `QuantDtype` 的值。该接口当前主要用于序列化或反序列化 `QuantDtype` 。 + + 返回: + int,表示当前 `QuantDtype` 的值。 diff --git a/docs/api/api_python/mindspore/mindspore.common.np_dtype.rst b/docs/api/api_python/mindspore/mindspore.common.np_dtype.rst index 6534e0c199b..35cc1915bc5 100644 --- a/docs/api/api_python/mindspore/mindspore.common.np_dtype.rst +++ b/docs/api/api_python/mindspore/mindspore.common.np_dtype.rst @@ -1,20 +1,20 @@ -mindspore.common.np_dtype -========================= - -.. py:class:: mindspore.common.np_dtype - - `np_dtype` 扩展了Numpy的数据类型。 - - `np_dtype` 的实际路径为 `/mindspore/common/np_dtype.py` 。运行以下命令导入环境: - - .. code-block:: - - from mindspore.common import np_dtype - - - **数值型** - - ============================================== ============================= - 定义 描述 - ============================================== ============================= - ``bfloat16`` NumPy 下的 ``bfloat16`` 数据类型。该类型仅用于构造 ``bfloat16`` 类型的Tensor,不保证Numpy下的完整运算能力。仅当运行时的Numpy版本不小于编译时的Numpy版本时生效。 - ============================================== ============================= +mindspore.common.np_dtype +========================= + +.. py:class:: mindspore.common.np_dtype + + `np_dtype` 扩展了Numpy的数据类型。 + + `np_dtype` 的实际路径为 `/mindspore/common/np_dtype.py` 。运行以下命令导入环境: + + .. code-block:: + + from mindspore.common import np_dtype + + - **数值型** + + ============================================== ============================= + 定义 描述 + ============================================== ============================= + ``bfloat16`` NumPy 下的 ``bfloat16`` 数据类型。该类型仅用于构造 ``bfloat16`` 类型的Tensor,不保证Numpy下的完整运算能力。仅当运行时的Numpy版本不小于编译时的Numpy版本时生效。 + ============================================== ============================= diff --git a/docs/api/api_python/mindspore/mindspore.data_sink.rst b/docs/api/api_python/mindspore/mindspore.data_sink.rst index 2c09831d369..cef289d8354 100644 --- a/docs/api/api_python/mindspore/mindspore.data_sink.rst +++ b/docs/api/api_python/mindspore/mindspore.data_sink.rst @@ -1,22 +1,22 @@ -mindspore.data_sink -=================== - -.. py:function:: mindspore.data_sink(fn, dataset, sink_size=1, jit_config=None, input_signature=None) - - 对输入的函数封装生成一个新的函数。 - - .. note:: - 使用数据下沉时,数据集将被自动循环发送至设备,设备侧最多缓存100个batch的数据且所占内存不大于2G,此时仅需考虑每次下沉的步数 `sink_size` , `sink_size` 默认为 ``1`` ,代表每个epoch仅从缓存中取一个batch的数据进行训练并输出loss,若 `sink_size` 大于1,则每个epoch从缓存中取出 `sink_size` 个batch的数据进行训练然后输出loss。 - - 参数: - - **fn** (Function) - 将与数据集一起运行的函数。 - - **dataset** (Dataset) - 训练数据集迭代器。数据集可以由数据集生成器API在 `mindspore.dataset` 中生成,例如 :class:`mindspore.dataset.ImageFolderDataset` 。 - - **sink_size** (int) - 控制每次数据下沉的step数量。 `sink_size` 必须为正整数。默认值: ``1`` 。 - - **jit_config** (JitConfig) - 编译时所使用的JitConfig配置项,详细可参考 :class:`mindspore.JitConfig` 。默认值: ``None`` ,表示以PyNative模式运行。 - - **input_signature** (Union[Tensor, List or Tuple of Tensors]) - 用于表示输入参数的Tensor。Tensor的shape和dtype将作为函数的输入shape和dtype。如果指定了 `input_signature` ,则 `fn` 的每个输入都必须是Tensor,并且 `fn` 的输入参数将不会接受 `**kwargs` 参数,实际输入的shape和dtype应与 `input_signature` 相同,否则会出现TypeError。默认值: ``None`` 。 - - 返回: - 函数,该生成的函数会以数据下沉模式执行。 - - 异常: +mindspore.data_sink +=================== + +.. py:function:: mindspore.data_sink(fn, dataset, sink_size=1, jit_config=None, input_signature=None) + + 对输入的函数封装生成一个新的函数。 + + .. note:: + 使用数据下沉时,数据集将被自动循环发送至设备,设备侧最多缓存100个batch的数据且所占内存不大于2G,此时仅需考虑每次下沉的步数 `sink_size` , `sink_size` 默认为 ``1`` ,代表每个epoch仅从缓存中取一个batch的数据进行训练并输出loss,若 `sink_size` 大于1,则每个epoch从缓存中取出 `sink_size` 个batch的数据进行训练然后输出loss。 + + 参数: + - **fn** (Function) - 将与数据集一起运行的函数。 + - **dataset** (Dataset) - 训练数据集迭代器。数据集可以由数据集生成器API在 `mindspore.dataset` 中生成,例如 :class:`mindspore.dataset.ImageFolderDataset` 。 + - **sink_size** (int) - 控制每次数据下沉的step数量。 `sink_size` 必须为正整数。默认值: ``1`` 。 + - **jit_config** (JitConfig) - 编译时所使用的JitConfig配置项,详细可参考 :class:`mindspore.JitConfig` 。默认值: ``None`` ,表示以PyNative模式运行。 + - **input_signature** (Union[Tensor, List or Tuple of Tensors]) - 用于表示输入参数的Tensor。Tensor的shape和dtype将作为函数的输入shape和dtype。如果指定了 `input_signature` ,则 `fn` 的每个输入都必须是Tensor,并且 `fn` 的输入参数将不会接受 `**kwargs` 参数,实际输入的shape和dtype应与 `input_signature` 相同,否则会出现TypeError。默认值: ``None`` 。 + + 返回: + 函数,该生成的函数会以数据下沉模式执行。 + + 异常: - **ValueError** - 如果 `sink_size` 不是正整数。 \ No newline at end of file diff --git a/docs/api/api_python/mindspore/mindspore.dtype.rst b/docs/api/api_python/mindspore/mindspore.dtype.rst index 298b5a8eda2..3f6d064f6e7 100644 --- a/docs/api/api_python/mindspore/mindspore.dtype.rst +++ b/docs/api/api_python/mindspore/mindspore.dtype.rst @@ -1,58 +1,58 @@ -mindspore.dtype -=============== - -.. py:class:: mindspore.dtype - - 创建一个MindSpore数据类型的对象。 - - `dtype` 的实际路径为 `/mindspore/common/dtype.py` 。运行以下命令导入环境: - - .. code-block:: - - from mindspore import dtype as mstype - - - **数值型** - - 目前,MindSpore支持 ``int``,``uint`` 和 ``float`` 数据类型。详情请参照以下表格。 - - ============================================== ============================= - 定义 描述 - ============================================== ============================= - ``mindspore.int8`` , ``mindspore.byte`` 8位整型数 - ``mindspore.int16`` , ``mindspore.short`` 16位整型数 - ``mindspore.int32`` , ``mindspore.intc`` 32位整型数 - ``mindspore.int64`` , ``mindspore.intp`` 64位整型数 - ``mindspore.uint8`` , ``mindspore.ubyte`` 无符号8位整型数 - ``mindspore.uint16`` , ``mindspore.ushort`` 无符号16位整型数 - ``mindspore.uint32`` , ``mindspore.uintc`` 无符号32位整型数 - ``mindspore.uint64`` , ``mindspore.uintp`` 无符号64位整型数 - ``mindspore.float16`` , ``mindspore.half`` 16位浮点数 - ``mindspore.float32`` , ``mindspore.single`` 32位浮点数 - ``mindspore.float64`` , ``mindspore.double`` 64位浮点数 - ``mindspore.bfloat16`` 16位脑浮点数 - ``mindspore.complex64`` 64位复数 - ``mindspore.complex128`` 128位复数 - ============================================== ============================= - - - **其他类型** - - 除数值型以外的其他数据类型,请参照以下表格。 - - ============================ ================= - 类型 描述 - ============================ ================= - ``Tensor`` MindSpore中的张量类型。数据格式采用NCHW。详情请参考 `tensor `_ 。 - ``bool_`` 布尔型,值为 ``True`` 或者 ``False`` 。 - ``int_`` 整数标量。 - ``uint`` 无符号整数标量。 - ``float_`` 浮点标量。 - ``complex`` 复数标量。 - ``number`` 数值型,包括 ``int_``、``uint``、``float_``、``complex`` 和 ``bool_``。 - ``list_`` 由 ``tensor`` 构造的列表,例如 ``List[T0,T1,...,Tn]`` ,其中元素 ``Ti`` 可以是不同的类型。 - ``tuple_`` 由 ``tensor`` 构造的元组,例如 ``Tuple[T0,T1,...,Tn]`` ,其中元素 ``Ti`` 可以是不同的类型。 - ``function`` 函数类型。两种返回方式,当function不是None时,直接返回function,另一种当function为None时返回function(参数: List[T0,T1,...,Tn],返回值: T)。 - ``type_type`` 类型的类型定义。 - ``type_none`` 没有匹配的返回类型,对应 Python 中的 ``type(None)``。 - ``symbolic_key`` 在 ``env_type`` 中用作变量的键的变量的值。 - ``env_type`` 用于存储函数的自由变量的梯度,其中键是自由变量节点的 `symbolic_key` ,值是梯度。 - ============================ ================= +mindspore.dtype +=============== + +.. py:class:: mindspore.dtype + + 创建一个MindSpore数据类型的对象。 + + `dtype` 的实际路径为 `/mindspore/common/dtype.py` 。运行以下命令导入环境: + + .. code-block:: + + from mindspore import dtype as mstype + + - **数值型** + + 目前,MindSpore支持 ``int``,``uint`` 和 ``float`` 数据类型。详情请参照以下表格。 + + ============================================== ============================= + 定义 描述 + ============================================== ============================= + ``mindspore.int8`` , ``mindspore.byte`` 8位整型数 + ``mindspore.int16`` , ``mindspore.short`` 16位整型数 + ``mindspore.int32`` , ``mindspore.intc`` 32位整型数 + ``mindspore.int64`` , ``mindspore.intp`` 64位整型数 + ``mindspore.uint8`` , ``mindspore.ubyte`` 无符号8位整型数 + ``mindspore.uint16`` , ``mindspore.ushort`` 无符号16位整型数 + ``mindspore.uint32`` , ``mindspore.uintc`` 无符号32位整型数 + ``mindspore.uint64`` , ``mindspore.uintp`` 无符号64位整型数 + ``mindspore.float16`` , ``mindspore.half`` 16位浮点数 + ``mindspore.float32`` , ``mindspore.single`` 32位浮点数 + ``mindspore.float64`` , ``mindspore.double`` 64位浮点数 + ``mindspore.bfloat16`` 16位脑浮点数 + ``mindspore.complex64`` 64位复数 + ``mindspore.complex128`` 128位复数 + ============================================== ============================= + + - **其他类型** + + 除数值型以外的其他数据类型,请参照以下表格。 + + ============================ ================= + 类型 描述 + ============================ ================= + ``Tensor`` MindSpore中的张量类型。数据格式采用NCHW。详情请参考 `tensor `_ 。 + ``bool_`` 布尔型,值为 ``True`` 或者 ``False`` 。 + ``int_`` 整数标量。 + ``uint`` 无符号整数标量。 + ``float_`` 浮点标量。 + ``complex`` 复数标量。 + ``number`` 数值型,包括 ``int_``、``uint``、``float_``、``complex`` 和 ``bool_``。 + ``list_`` 由 ``tensor`` 构造的列表,例如 ``List[T0,T1,...,Tn]`` ,其中元素 ``Ti`` 可以是不同的类型。 + ``tuple_`` 由 ``tensor`` 构造的元组,例如 ``Tuple[T0,T1,...,Tn]`` ,其中元素 ``Ti`` 可以是不同的类型。 + ``function`` 函数类型。两种返回方式,当function不是None时,直接返回function,另一种当function为None时返回function(参数: List[T0,T1,...,Tn],返回值: T)。 + ``type_type`` 类型的类型定义。 + ``type_none`` 没有匹配的返回类型,对应 Python 中的 ``type(None)``。 + ``symbolic_key`` 在 ``env_type`` 中用作变量的键的变量的值。 + ``env_type`` 用于存储函数的自由变量的梯度,其中键是自由变量节点的 `symbolic_key` ,值是梯度。 + ============================ ================= diff --git a/docs/api/api_python/mindspore/mindspore.dtype_to_nptype.rst b/docs/api/api_python/mindspore/mindspore.dtype_to_nptype.rst index ac6f1e78fa4..82c20ecf0b1 100644 --- a/docs/api/api_python/mindspore/mindspore.dtype_to_nptype.rst +++ b/docs/api/api_python/mindspore/mindspore.dtype_to_nptype.rst @@ -1,12 +1,12 @@ -mindspore.dtype_to_nptype -========================== - -.. py:function:: mindspore.dtype_to_nptype(type_) - - 将MindSpore数据类型转换成NumPy数据类型。 - - 参数: - - **type_** (mindspore.dtype) - MindSpore中的dtype。 - - 返回: +mindspore.dtype_to_nptype +========================== + +.. py:function:: mindspore.dtype_to_nptype(type_) + + 将MindSpore数据类型转换成NumPy数据类型。 + + 参数: + - **type_** (mindspore.dtype) - MindSpore中的dtype。 + + 返回: NumPy的数据类型。 \ No newline at end of file diff --git a/docs/api/api_python/mindspore/mindspore.dtype_to_pytype.rst b/docs/api/api_python/mindspore/mindspore.dtype_to_pytype.rst index 48bcfa5b7c7..60bd97707bd 100644 --- a/docs/api/api_python/mindspore/mindspore.dtype_to_pytype.rst +++ b/docs/api/api_python/mindspore/mindspore.dtype_to_pytype.rst @@ -1,12 +1,12 @@ -mindspore.dtype_to_pytype -========================= - -.. py:function:: mindspore.dtype_to_pytype(type_) - - 将MindSpore数据类型转换为Python数据类型。 - - 参数: - - **type_** (mindspore.dtype) - MindSpore中的dtype。 - - 返回: +mindspore.dtype_to_pytype +========================= + +.. py:function:: mindspore.dtype_to_pytype(type_) + + 将MindSpore数据类型转换为Python数据类型。 + + 参数: + - **type_** (mindspore.dtype) - MindSpore中的dtype。 + + 返回: Python的数据类型。 \ No newline at end of file diff --git a/docs/api/api_python/mindspore/mindspore.get_py_obj_dtype.rst b/docs/api/api_python/mindspore/mindspore.get_py_obj_dtype.rst index 2456027a743..ed815756313 100644 --- a/docs/api/api_python/mindspore/mindspore.get_py_obj_dtype.rst +++ b/docs/api/api_python/mindspore/mindspore.get_py_obj_dtype.rst @@ -1,12 +1,12 @@ -mindspore.get_py_obj_dtype -=========================== - -.. py:function:: mindspore.get_py_obj_dtype(obj) - - 获取与Python数据类型对应的MindSpore数据类型。 - - 参数: - - **obj** (type) - Python数据对象,或在Python环境中定义的变量。 - - 返回: +mindspore.get_py_obj_dtype +=========================== + +.. py:function:: mindspore.get_py_obj_dtype(obj) + + 获取与Python数据类型对应的MindSpore数据类型。 + + 参数: + - **obj** (type) - Python数据对象,或在Python环境中定义的变量。 + + 返回: MindSpore的数据类型。 \ No newline at end of file diff --git a/docs/api/api_python/mindspore/mindspore.get_seed.rst b/docs/api/api_python/mindspore/mindspore.get_seed.rst index 38a65d0dfff..b24570b0192 100644 --- a/docs/api/api_python/mindspore/mindspore.get_seed.rst +++ b/docs/api/api_python/mindspore/mindspore.get_seed.rst @@ -1,9 +1,9 @@ -mindspore.get_seed -=================== - -.. py:function:: mindspore.get_seed() - - 获取随机种子。 - - 返回: - int,随机种子。 +mindspore.get_seed +=================== + +.. py:function:: mindspore.get_seed() + + 获取随机种子。 + + 返回: + int,随机种子。 diff --git a/docs/api/api_python/mindspore/mindspore.pytype_to_dtype.rst b/docs/api/api_python/mindspore/mindspore.pytype_to_dtype.rst index 1b1baec24de..11b43456f78 100644 --- a/docs/api/api_python/mindspore/mindspore.pytype_to_dtype.rst +++ b/docs/api/api_python/mindspore/mindspore.pytype_to_dtype.rst @@ -1,15 +1,15 @@ -mindspore.pytype_to_dtype -========================= - -.. py:function:: mindspore.pytype_to_dtype(obj) - - 将Python数据类型转换为MindSpore数据类型。 - - 参数: - - **obj** (type) - Python数据对象。 - - 返回: - MindSpore的数据类型。 - - 异常: +mindspore.pytype_to_dtype +========================= + +.. py:function:: mindspore.pytype_to_dtype(obj) + + 将Python数据类型转换为MindSpore数据类型。 + + 参数: + - **obj** (type) - Python数据对象。 + + 返回: + MindSpore的数据类型。 + + 异常: - **NotImplementedError** - Python类型无法转换为MindSpore类型。 \ No newline at end of file diff --git a/docs/api/api_python/mindspore/mindspore.run_check.rst b/docs/api/api_python/mindspore/mindspore.run_check.rst index c69e5c29c3d..fbabade28ee 100644 --- a/docs/api/api_python/mindspore/mindspore.run_check.rst +++ b/docs/api/api_python/mindspore/mindspore.run_check.rst @@ -1,6 +1,6 @@ -mindspore.run_check -=================== - -.. py:function:: mindspore.run_check() - +mindspore.run_check +=================== + +.. py:function:: mindspore.run_check() + 提供了便捷的API用以查询MindSpore的安装是否成功。如果检查返回结果中的版本不是你所期望的,请在run_check()之前使用 :func:`mindspore.set_context` 设置device_target。 diff --git a/docs/api/api_python/mindspore/mindspore.set_seed.rst b/docs/api/api_python/mindspore/mindspore.set_seed.rst index 5b9ffbe16f2..8ef91d52d36 100644 --- a/docs/api/api_python/mindspore/mindspore.set_seed.rst +++ b/docs/api/api_python/mindspore/mindspore.set_seed.rst @@ -1,19 +1,19 @@ -mindspore.set_seed -=================== - -.. py:function:: mindspore.set_seed(seed) - - 设置全局种子。 - - .. note:: - - 全局种子可用于numpy.random, mindspore.common.Initializer以及mindspore.nn.probability.distribution。 - - 如果没有设置全局种子,这些包将会各自使用自己的种子,numpy.random和mindspore.common.Initializer将会随机选择种子值,mindspore.nn.probability.distribution将会使用零作为种子值。 - - numpy.random.seed()设置的种子仅能被numpy.random使用,而这个API设置的种子也可被numpy.random使用,因此推荐使用这个API设置所有的种子。 - - 在semi_auto_parallel/auto_parallel模式下,使用set_seed时,同一节点具有相同形状和相同切分策略的权重将被初始化为相同的结果,否则,将被初始化为不同的结果。 - - 参数: - - **seed** (int) - 设置的全局种子。 - - 异常: - - **ValueError** - 种子值非法 (小于0)。 - - **TypeError** - 种子值非整型数。 +mindspore.set_seed +=================== + +.. py:function:: mindspore.set_seed(seed) + + 设置全局种子。 + + .. note:: + - 全局种子可用于numpy.random, mindspore.common.Initializer以及mindspore.nn.probability.distribution。 + - 如果没有设置全局种子,这些包将会各自使用自己的种子,numpy.random和mindspore.common.Initializer将会随机选择种子值,mindspore.nn.probability.distribution将会使用零作为种子值。 + - numpy.random.seed()设置的种子仅能被numpy.random使用,而这个API设置的种子也可被numpy.random使用,因此推荐使用这个API设置所有的种子。 + - 在semi_auto_parallel/auto_parallel模式下,使用set_seed时,同一节点具有相同形状和相同切分策略的权重将被初始化为相同的结果,否则,将被初始化为不同的结果。 + + 参数: + - **seed** (int) - 设置的全局种子。 + + 异常: + - **ValueError** - 种子值非法 (小于0)。 + - **TypeError** - 种子值非整型数。 diff --git a/docs/api/api_python/mint/mindspore.mint.nn.functional.grid_sample.rst b/docs/api/api_python/mint/mindspore.mint.nn.functional.grid_sample.rst old mode 100755 new mode 100644 index 6a04b90d730..85095aff974 --- a/docs/api/api_python/mint/mindspore.mint.nn.functional.grid_sample.rst +++ b/docs/api/api_python/mint/mindspore.mint.nn.functional.grid_sample.rst @@ -1,42 +1,42 @@ -mindspore.mint.nn.functional.grid_sample -======================================== - -.. py:function:: mindspore.mint.nn.functional.grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=False) - - 给定一个输入和一个网格,使用网格中的输入值和像素位置计算输出。`input` 只支持4-D(GridSampler2D)和5-D(GridSampler3D)。 - - 在4-D场景下,`input` 的shape为 :math:`(N, C, H_{in}, W_{in})`,`grid` 的shape为 :math:`(N, H_{out}, W_{out}, 2)`,`output` 的shape为 :math:`(N, C, H_{out}, W_{out})`。 - 对于每个输出位置 `output[n, :, h, w]`,`grid[n, h, w]` 指定 `input` 像素位置 `x` 和 `y`,用于计算 `output[n, :, h, w]` 的插值。以5D为例,`grid[n, d, h, w]` 指定 `x`, - `y`,`z` 像素位置的插值位置为[n, :, d, h, w]。`mode` 参数指定 `nearest` 或 `bilinear` (bicubic暂不支持)插值法对输入像素进行采样。 - - `grid` 指定由 `input` 归一化的采样像素位置。因此,它应该在 :math:`[-1, 1]` 范围内的值最多。 - - 如果 `grid` 的值在 :math:`[-1, 1]` 范围之外,则相应的输出将按照定义的 `padding_mode` 方式处理。如果 `padding_mode` 设置为 ``0`` ,则使用 :math:`0` 来表示出界的网格位置。 - 如果 `padding_mode` 设置为 ``border``,对于出界网格位置,则使用border值。如果 `padding_mode` 设置为 ``reflection`` ,请使用边界所反映的位置的值用于指定出界网格位置。对于 - 远离边界的位置,它会一直被反射,直到在边界内。 - - 参数: - - **input** (Tensor) - 4-D场景下,shape为 :math:`(N, C, H_{in}, W_{in})`,5-D场景下,shape为 :math:`(N, C, D_{in}, H_{in}, W_{in})`。数据类型为float32或float64。 - - **grid** (Tensor) - 4-D场景下,shape为 :math:`(N, H_{out}, W_{out}, 2)`,5-D场景下,shape为 :math:`(N, D_{out}, H_{out}, W_{out}, 3)`。数据类型与 `input` 保持一致。 - - **mode** (str) - 插值方法。可选方法为 ``'bilinear'``, ``'nearest'``。默认值: ``'bilinear'`` 。注: ``'bilinear'`` 还不支持。当 `mode` 为 ``'bilinear'``,且输入为5-D,则 `mode` 为 ``'trilinear'``。但是,当输入为4-D,则 `mode` 为 ``'bilinear'``。默认值: ``'bilinear'`` 。 - - - ``'nearest'``:最近邻插值。每个输出像素的值为最近的输入像素的值。这种方法简单快速,但可能导致块状或像素化的输出。 - - ``'bilinear'``:双线性插值。每个输出像素是最接近的四个输入像素的加权平均值,使用双线性插值计算。与最近邻插值相比,此方法产生更平滑的结果。 - - ``'trilinear'``:三线性插值。这是双线性插值在三维数据上的扩展。它在两个空间维度上执行双线性插值,并沿第三个维度进行线性插值。通常用于体积或三维图像插值。 - - - **padding_mode** (str) - 填充方法。可选方法为 ``'zeros'``,``'border'`` 和 ``'reflection'``。默认值: ``'zeros'`` 。 - - **align_corners** (bool) - 如果设置成 `True`,-1和1被视为引用输入角像素的中心点。如果设置为 `False`,将被视为引用到输入角像素的角点,使采样更不受分辨率影响。默认值为 `False`。 - - 返回: - Tensor,数据类型与 `input` 相同,4-D场景下,shape为 :math:`(N, C, H_{out}, W_{out})`,5-D场景下,shape为 :math:`(N, C, D_{out}, H_{out}, W_{out})`。 - - 异常: - - **TypeError** - 如果 `input` 或 `grid` 不是Tensor类型。 - - **TypeError** - 如果 `input` 和 `grid` 的数据类型不一致。 - - **TypeError** - 如果 `input` 或 `grid` 的数据类型无效。 - - **TypeError** - 如果 `align_corners` 不是一个布尔值。 - - **ValueError** - 如果 `input` 或 `grid` 的维度不是四维或五维。 - - **ValueError** - 如果 `input` 的第一个维度不等于 `grid` 的第一个维度。 - - **ValueError** - 如果 `grid` 最后一个维度不等于2(4-D场景)或者3(5-D场景)。 - - **ValueError** - 如果 `mode` 不是 `bilinear`,`nearest`,数据类型不为String。 - - **ValueError** - 如果 `padding_mode` 不是 `zeros`,`border`,`reflection`,数据类型不为String。 +mindspore.mint.nn.functional.grid_sample +======================================== + +.. py:function:: mindspore.mint.nn.functional.grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=False) + + 给定一个输入和一个网格,使用网格中的输入值和像素位置计算输出。`input` 只支持4-D(GridSampler2D)和5-D(GridSampler3D)。 + + 在4-D场景下,`input` 的shape为 :math:`(N, C, H_{in}, W_{in})`,`grid` 的shape为 :math:`(N, H_{out}, W_{out}, 2)`,`output` 的shape为 :math:`(N, C, H_{out}, W_{out})`。 + 对于每个输出位置 `output[n, :, h, w]`,`grid[n, h, w]` 指定 `input` 像素位置 `x` 和 `y`,用于计算 `output[n, :, h, w]` 的插值。以5D为例,`grid[n, d, h, w]` 指定 `x`, + `y`,`z` 像素位置的插值位置为[n, :, d, h, w]。`mode` 参数指定 `nearest` 或 `bilinear` (bicubic暂不支持)插值法对输入像素进行采样。 + + `grid` 指定由 `input` 归一化的采样像素位置。因此,它应该在 :math:`[-1, 1]` 范围内的值最多。 + + 如果 `grid` 的值在 :math:`[-1, 1]` 范围之外,则相应的输出将按照定义的 `padding_mode` 方式处理。如果 `padding_mode` 设置为 ``0`` ,则使用 :math:`0` 来表示出界的网格位置。 + 如果 `padding_mode` 设置为 ``border``,对于出界网格位置,则使用border值。如果 `padding_mode` 设置为 ``reflection`` ,请使用边界所反映的位置的值用于指定出界网格位置。对于 + 远离边界的位置,它会一直被反射,直到在边界内。 + + 参数: + - **input** (Tensor) - 4-D场景下,shape为 :math:`(N, C, H_{in}, W_{in})`,5-D场景下,shape为 :math:`(N, C, D_{in}, H_{in}, W_{in})`。数据类型为float32或float64。 + - **grid** (Tensor) - 4-D场景下,shape为 :math:`(N, H_{out}, W_{out}, 2)`,5-D场景下,shape为 :math:`(N, D_{out}, H_{out}, W_{out}, 3)`。数据类型与 `input` 保持一致。 + - **mode** (str) - 插值方法。可选方法为 ``'bilinear'``, ``'nearest'``。默认值: ``'bilinear'`` 。注: ``'bilinear'`` 还不支持。当 `mode` 为 ``'bilinear'``,且输入为5-D,则 `mode` 为 ``'trilinear'``。但是,当输入为4-D,则 `mode` 为 ``'bilinear'``。默认值: ``'bilinear'`` 。 + + - ``'nearest'``:最近邻插值。每个输出像素的值为最近的输入像素的值。这种方法简单快速,但可能导致块状或像素化的输出。 + - ``'bilinear'``:双线性插值。每个输出像素是最接近的四个输入像素的加权平均值,使用双线性插值计算。与最近邻插值相比,此方法产生更平滑的结果。 + - ``'trilinear'``:三线性插值。这是双线性插值在三维数据上的扩展。它在两个空间维度上执行双线性插值,并沿第三个维度进行线性插值。通常用于体积或三维图像插值。 + + - **padding_mode** (str) - 填充方法。可选方法为 ``'zeros'``,``'border'`` 和 ``'reflection'``。默认值: ``'zeros'`` 。 + - **align_corners** (bool) - 如果设置成 `True`,-1和1被视为引用输入角像素的中心点。如果设置为 `False`,将被视为引用到输入角像素的角点,使采样更不受分辨率影响。默认值为 `False`。 + + 返回: + Tensor,数据类型与 `input` 相同,4-D场景下,shape为 :math:`(N, C, H_{out}, W_{out})`,5-D场景下,shape为 :math:`(N, C, D_{out}, H_{out}, W_{out})`。 + + 异常: + - **TypeError** - 如果 `input` 或 `grid` 不是Tensor类型。 + - **TypeError** - 如果 `input` 和 `grid` 的数据类型不一致。 + - **TypeError** - 如果 `input` 或 `grid` 的数据类型无效。 + - **TypeError** - 如果 `align_corners` 不是一个布尔值。 + - **ValueError** - 如果 `input` 或 `grid` 的维度不是四维或五维。 + - **ValueError** - 如果 `input` 的第一个维度不等于 `grid` 的第一个维度。 + - **ValueError** - 如果 `grid` 最后一个维度不等于2(4-D场景)或者3(5-D场景)。 + - **ValueError** - 如果 `mode` 不是 `bilinear`,`nearest`,数据类型不为String。 + - **ValueError** - 如果 `padding_mode` 不是 `zeros`,`border`,`reflection`,数据类型不为String。 diff --git a/docs/api/api_python/nn/mindspore.nn.Cell.rst b/docs/api/api_python/nn/mindspore.nn.Cell.rst index 8417995b77a..2066b76eab2 100644 --- a/docs/api/api_python/nn/mindspore.nn.Cell.rst +++ b/docs/api/api_python/nn/mindspore.nn.Cell.rst @@ -1,618 +1,618 @@ -mindspore.nn.Cell -================== - -.. py:class:: mindspore.nn.Cell(auto_prefix=True, flags=None) - - MindSpore中神经网络的基本构成单元。模型或神经网络层应当继承该基类。 - - `mindspore.nn` 中神经网络层也是Cell的子类,如 :class:`mindspore.nn.Conv2d` 、 :class:`mindspore.nn.ReLU` 等。Cell在GRAPH_MODE(静态图模式)下将编译为一张计算图,在PYNATIVE_MODE(动态图模式)下作为神经网络的基础模块。 - - .. note:: - Cell默认情况下是推理模式。对于继承Cell的类,如果训练和推理具有不同结构,子类会默认执行推理分支。设置训练模式,请参考 `mindspore.nn.Cell.set_train` 。 - - .. warning:: - 在Cell的子类中不能定义名为'cast'的方法,不能定义名为'phase'和'cells'的属性, 否则会报错。 - - 参数: - - **auto_prefix** (bool,可选) - 是否自动为Cell及其子Cell生成NameSpace。该参数同时会影响 `Cell` 中权重参数的名称。如果设置为 ``True`` ,则自动给权重参数的名称添加前缀,否则不添加前缀。通常情况下,骨干网络应设置为 ``True`` ,否则会产生重名问题。用于训练骨干网络的优化器、 :class:`mindspore.nn.TrainOneStepCell` 等,应设置为 ``False`` ,否则骨干网络的权重参数名会被误改。默认值: ``True`` 。 - - **flags** (dict,可选) - Cell的配置信息,目前用于绑定Cell和数据集。用户也通过该参数自定义Cell属性。默认值: ``None`` 。 - - .. py:method:: add_flags(**flags) - - 为Cell添加自定义属性。 - - 在实例化Cell类时,如果入参flags不为空,会调用此方法。 - - 参数: - - **flags** (dict) - Cell的配置信息,目前用于绑定Cell和数据集。用户也通过该参数自定义Cell属性。 - - .. py:method:: add_flags_recursive(**flags) - - 如果Cell含有多个子Cell,此方法会递归得给所有子Cell添加自定义属性。 - - 参数: - - **flags** (dict) - Cell的配置信息,目前用于绑定Cell和数据集。用户也通过该参数自定义Cell属性。 - - .. py:method:: apply(fn) - - 递归地将 `fn` 应用于每个子Cell(由 `.cells()` 返回)以及自身。通常用于初始化模型的参数。 - - 参数: - - **fn** (function) - 被执行于每个Cell的function。 - - 返回: - Cell类型,Cell本身。 - - .. py:method:: auto_cast_inputs(inputs) - - 在混合精度下,自动对输入进行类型转换。 - - 参数: - - **inputs** (tuple) - construct方法的输入。 - - 返回: - Tuple类型,经过类型转换后的输入。 - - .. py:method:: bprop_debug - :property: - - 在图模式下使用,用于标识是否使用自定义的反向传播函数。 - - 教程样例: - - `Cell与参数 - 自定义Cell反向 - `_ - - .. py:method:: cast_inputs(inputs, dst_type) - - 将输入转换为指定类型。 - - 参数: - - **inputs** (tuple[Tensor]) - 输入。 - - **dst_type** (mindspore.dtype) - 指定的数据类型。 - - 返回: - tuple[Tensor]类型,转换类型后的结果。 - - .. py:method:: cast_param(param) - - 在PyNative模式下,根据自动混合精度的精度设置转换Cell中参数的类型。 - - 该接口目前在自动混合精度场景下使用。 - - 参数: - - **param** (Parameter) - 需要被转换类型的输入参数。 - - 返回: - Parameter类型,转换类型后的参数。 - - .. py:method:: cells() - - 返回当前Cell的子Cell的迭代器。 - - 返回: - Iteration类型,Cell的子Cell。 - - .. py:method:: cells_and_names(cells=None, name_prefix='') - - 递归地获取当前Cell及输入 `cells` 的所有子Cell的迭代器,包括Cell的名称及其本身。 - - 参数: - - **cells** (str) - 需要进行迭代的Cell。默认值: ``None`` 。 - - **name_prefix** (str) - 作用域。默认值: ``''`` 。 - - 返回: - Iteration类型,当前Cell及输入 `cells` 的所有子Cell和相对应的名称。 - - .. py:method:: check_names() - - 检查Cell中的网络参数名称是否重复。 - - .. py:method:: compile(*args, **kwargs) - - 编译Cell为计算图,输入需与construct中定义的输入一致。 - - 参数: - - **args** (tuple) - Cell的输入。 - - **kwargs** (dict) - Cell的输入。 - - .. py:method:: compile_and_run(*args, **kwargs) - - 编译并运行Cell,输入需与construct中定义的输入一致。 - - .. note:: - 不推荐使用该函数,建议直接调用Cell实例。 - - 参数: - - **args** (tuple) - Cell的输入。 - - **kwargs** (dict) - Cell的输入。 - - 返回: - Object类型,执行的结果。 - - .. py:method:: construct(*args, **kwargs) - - 定义要执行的计算逻辑。所有子类都必须重写此方法。 - - .. note:: - 当前不支持inputs同时输入tuple类型和非tuple类型。 - - 参数: - - **args** (tuple) - 可变参数列表,默认值: ``()`` 。 - - **kwargs** (dict) - 可变的关键字参数的字典,默认值: ``{}`` 。 - - 返回: - Tensor类型,返回计算结果。 - - .. py:method:: extend_repr() - - 在原有描述基础上扩展Cell的描述。 - - 若需要在print时输出个性化的扩展信息,请在您的网络中重新实现此方法。 - - .. py:method:: flatten_weights(fusion_size=0) - - 重置权重参数(即可训练参数)使用的数据内存,让这些参数按数据类型分组使用连续内存块。 - - .. note:: - 默认情况下,具有相同数据类型的参数会使用同一个连续内存块。但对于某些具有大量参数的模型, - 将一个大的连续内存块分为多个小一点的内存块有可能提升性能,对于这种情况, - 可以通过 `fusion_size` 参数来限制最大连续内存块的的大小。 - - 参数: - - **fusion_size** (int) - 最大连续内存块的大小(以字节为单位), ``0`` 表示不限制大小。默认值: ``0`` 。 - - .. py:method:: generate_scope() - - 为网络中的每个Cell对象生成NameSpace。 - - .. py:method:: get_flags() - - 获取该Cell的自定义属性,自定义属性通过 `add_flags` 方法添加。 - - .. py:method:: get_func_graph_proto() - - 返回图的二进制原型。 - - .. py:method:: get_inputs() - - 返回编译计算图所设置的输入。 - - 返回: - Tuple类型,编译计算图所设置的输入。 - - .. warning:: - 这是一个实验性API,后续可能修改或删除。 - - .. py:method:: get_parameters(expand=True) - - 返回Cell中parameter的迭代器。 - - 获取Cell的参数。如果 `expand` 为 ``true`` ,获取此cell和所有subcells的参数。关于subcell,请看下面的示例。 - - 参数: - - **expand** (bool) - 如果为 ``True`` ,则递归地获取当前Cell和所有子Cell的parameter。否则,只生成当前Cell的subcell的parameter。默认值: ``True`` 。 - - 返回: - Iteration类型,Cell的parameter。 - - .. py:method:: get_scope() - - 返回Cell的作用域。 - - 返回: - String类型,网络的作用域。 - - .. py:method:: infer_param_pipeline_stage() - - 推导Cell中当前 `pipeline_stage` 的参数。 - - .. note:: - - 这个接口在2.3版本废弃,并且会在未来版本移除。 - - 返回: - 属于当前 `pipeline_stage` 的参数。 - - 异常: - - **RuntimeError** - 如果参数不属于任何stage。 - - .. py:method:: init_parameters_data(auto_parallel_mode=False) - - 初始化并替换Cell中所有的parameter的值。 - - .. note:: - 在调用 `init_parameters_data` 后,`trainable_params()` 或其他相似的接口可能返回不同的参数对象,不要保存这些结果。 - - 参数: - - **auto_parallel_mode** (bool) - 是否在自动并行模式下执行。默认值: ``False`` 。 - - 返回: - Dict[Parameter, Parameter],返回一个原始参数和替换参数的字典。 - - .. py:method:: insert_child_to_cell(child_name, child_cell) - - 将一个给定名称的子Cell添加到当前Cell。 - - 参数: - - **child_name** (str) - 子Cell名称。 - - **child_cell** (Cell) - 要插入的子Cell。 - - 异常: - - **KeyError** - 如果子Cell的名称不正确或与其他子Cell名称重复。 - - **TypeError** - 如果 `child_name` 的类型不为str类型。 - - **TypeError** - 如果子Cell的类型不正确。 - - .. py:method:: insert_param_to_cell(param_name, param, check_name_contain_dot=True) - - 向当前Cell添加参数。 - - 将指定名称的参数添加到Cell中。目前在 `mindspore.nn.Cell.__setattr__` 中使用。 - - 参数: - - **param_name** (str) - 参数名称。 - - **param** (Parameter) - 要插入到Cell的参数。 - - **check_name_contain_dot** (bool) - 是否对 `param_name` 中的"."进行检查。默认值: ``True`` 。 - - 异常: - - **KeyError** - 如果参数名称为空或包含"."。 - - **TypeError** - 如果参数的类型不是Parameter。 - - .. py:method:: name_cells() - - 递归地获取一个Cell中所有子Cell的迭代器。 - - 包括Cell名称和Cell本身。 - - 返回: - Dict[String, Cell],Cell中的所有子Cell及其名称。 - - .. py:method:: param_prefix - :property: - - 当前Cell的子Cell的参数名前缀。 - - .. py:method:: parameter_layout_dict - :property: - - `parameter_layout_dict` 表示一个参数的张量layout,这种张量layout是由分片策略和分布式算子信息推断出来的。 - - .. py:method:: parameters_and_names(name_prefix='', expand=True) - - 返回Cell中parameter的迭代器。 - - 包含参数名称和参数本身。 - - 参数: - - **name_prefix** (str) - 作用域。默认值: ``''`` 。 - - **expand** (bool) - 如果为True,则递归地获取当前Cell和所有子Cell的参数及名称;如果为 ``False`` ,只生成当前Cell的子Cell的参数及名称。默认值: ``True`` 。 - - 返回: - 迭代器,Cell的名称和Cell本身。 - - 教程样例: - - `网络构建 - 模型参数 `_ - - .. py:method:: parameters_broadcast_dict(recurse=True) - - 获取这个Cell的参数广播字典。 - - 参数: - - **recurse** (bool) - 是否包含子Cell的参数。默认值: ``True`` 。 - - 返回: - OrderedDict,返回参数广播字典。 - - .. py:method:: parameters_dict(recurse=True) - - 获取此Cell的parameter字典。 - - 参数: - - **recurse** (bool) - 是否递归得包含所有子Cell的parameter。默认值: ``True`` 。 - - 返回: - OrderedDict类型,返回参数字典。 - - .. py:method:: pipeline_stage - :property: - - `pipeline_stage` 表示当前Cell所在的stage。 - - .. py:method:: place(role, rank_id) - - 为该Cell中所有算子设置标签。此标签告诉MindSpore编译器此Cell在哪个进程上启动。 - 每个标签都由进程角色 `role` 和 `rank_id` 组成,因此,通过对不同Cell设置不同标签,这些Cell将在不同进程启动,使用户可以进行分布式训练/推理等任务。 - - .. note:: - - 此接口只在成功调用 `mindspore.communication.init()` 完成动态组网后才能生效。 - - 参数: - - **role** (str) - 算子执行所在进程的角色。只支持'MS_WORKER'。 - - **rank_id** (int) - 算子执行所在进程的id。在相同进程角色间, `rank_id` 是唯一的。 - - .. py:method:: recompute(**kwargs) - - 设置Cell重计算。Cell中输出算子以外的所有算子将被设置为重计算。如果一个算子的计算结果被输出到一些反向节点来进行梯度计算,且被设置成重计算,那么我们会在反向传播中重新计算它,而不去存储在前向传播中的中间激活层的计算结果。 - - .. note:: - - 如果计算涉及到诸如随机化或全局变量之类的操作,那么目前还不能保证等价。 - - 如果该Cell中算子的重计算API也被调用,则该算子的重计算模式以算子的重计算API的设置为准。 - - 该接口仅配置一次,即当父Cell配置了,子Cell不需再配置。 - - Cell的输出算子默认不做重计算,这一点是基于我们减少内存占用的配置经验。如果一个Cell里面只有一个算子而且想要把这个算子设置为重计算的,那么请使用算子的重计算API。 - - 当应用了重计算且内存充足时,可以配置'mp_comm_recompute=False'来提升性能。 - - 当应用了重计算但内存不足时,可以配置'parallel_optimizer_comm_recompute=True'来节省内存。有相同融合group的Cell应该配置相同的parallel_optimizer_comm_recompute。 - - 参数: - - **mp_comm_recompute** (bool) - 表示在自动并行或半自动并行模式下,指定Cell内部由模型并行引入的通信操作是否重计算。默认值: ``True`` 。 - - **parallel_optimizer_comm_recompute** (bool) - 表示在自动并行或半自动并行模式下,指定Cell内部由优化器并行引入的AllGather通信是否重计算。默认值: ``False`` 。 - - .. py:method:: register_backward_hook(hook_fn) - - 设置Cell对象的反向hook函数。 - - .. note:: - - `register_backward_hook(hook_fn)` 在图模式下,或者在PyNative模式下使用 `jit` 装饰器功能时不起作用。 - - hook_fn必须有如下代码定义。 `cell_id` 是已注册Cell对象的信息,包括名称和ID。 `grad_input` 是反向传递给Cell对象的梯度。 `grad_output` 是Cell对象的反向输出梯度。用户可以在hook_fn中打印梯度数据或者返回新的输出梯度。 - - hook_fn返回新的输出梯度或者None:hook_fn(cell_id, grad_input, grad_output) -> New grad_output or None。 - - 为了避免脚本在切换到图模式时运行失败,不建议在Cell对象的 `construct` 函数中调用 `register_backward_hook(hook_fn)` 。 - - PyNative模式下,如果在Cell对象的 `construct` 函数中调用 `register_backward_hook(hook_fn)` ,那么Cell对象每次运行都将增加一个 `hook_fn` 。 - - 参数: - - **hook_fn** (function) - 捕获Cell对象信息和反向输入,输出梯度的 `hook_fn` 函数。 - - 返回: - 返回与 `hook_fn` 函数对应的 `handle` 对象。可通过调用 `handle.remove()` 来删除添加的 `hook_fn` 函数。 - - 异常: - - **TypeError** - 如果 `hook_fn` 不是Python函数。 - - .. py:method:: register_forward_hook(hook_fn) - - 设置Cell对象的正向hook函数。 - - .. note:: - - `register_forward_hook(hook_fn)` 在图模式下,或者在PyNative模式下使用 `jit` 装饰器功能时不起作用。 - - hook_fn必须有如下代码定义。 `cell` 是已注册Cell对象。 `inputs` 是网络正向传播时Cell对象的输入数据。 `outputs` 是网络正向传播时Cell对象的输出数据。用户可以在hook_fn中打印数据或者返回新的输出数据。 - - hook_fn返回新的输出数据或者None:hook_fn(cell, inputs, outputs) -> New outputs or None。 - - 为了避免脚本在切换到图模式时运行失败,不建议在Cell对象的 `construct` 函数中调用 `register_forward_hook(hook_fn)` 。 - - PyNative模式下,如果在Cell对象的 `construct` 函数中调用 `register_forward_hook(hook_fn)` ,那么Cell对象每次运行都将增加一个 `hook_fn` 。 - - 参数: - - **hook_fn** (function) - 捕获Cell对象信息和正向输入,输出数据的 `hook_fn` 函数。 - - 返回: - 返回与 `hook_fn` 函数对应的 `handle` 对象。可通过调用 `handle.remove()` 来删除添加的 `hook_fn` 函数。 - - 异常: - - **TypeError** - 如果 `hook_fn` 不是Python函数。 - - .. py:method:: register_forward_pre_hook(hook_fn) - - 设置Cell对象的正向pre_hook函数。 - - .. note:: - - `register_forward_pre_hook(hook_fn)` 在图模式下,或者在PyNative模式下使用 `jit` 装饰器功能时不起作用。 - - hook_fn必须有如下代码定义。 `cell` 是已注册Cell对象。 `inputs` 是网络正向传播时Cell对象的输入数据。用户可以在hook_fn中打印输入数据或者返回新的输入数据。 - - hook_fn返回新的输入数据或者None:hook_fn(cell, inputs) -> New inputs or None。 - - 为了避免脚本在切换到图模式时运行失败,不建议在Cell对象的 `construct` 函数中调用 `register_forward_pre_hook(hook_fn)` 。 - - PyNative模式下,如果在Cell对象的 `construct` 函数中调用 `register_forward_pre_hook(hook_fn)` ,那么Cell对象每次运行都将增加一个 `hook_fn` 。 - - 参数: - - **hook_fn** (function) - 捕获Cell对象信息和正向输入数据的hook_fn函数。 - - 返回: - 返回与 `hook_fn` 函数对应的 `handle` 对象。可通过调用 `handle.remove()` 来删除添加的 `hook_fn` 函数。 - - 异常: - - **TypeError** - 如果 `hook_fn` 不是Python函数。 - - .. py:method:: remove_redundant_parameters() - - 删除冗余参数。 - - 这个接口通常不需要显式调用。 - - .. py:method:: run_construct(cast_inputs, kwargs) - - 运行construct方法。 - - .. note:: - 该函数已经弃用,将会在未来版本中删除。不推荐使用此函数。 - - 参数: - - **cast_inputs** (tuple) - Cell的输入。 - - **kwargs** (dict) - 关键字参数。 - - 返回: - Cell的输出。 - - .. py:method:: set_boost(boost_type) - - 为了提升网络性能,可以配置boost内的算法让框架自动使能该算法来加速网络训练。 - - 请确保 `boost_type` 所选择的算法在 - `algorithm library `_ 算法库中。 - - .. note:: 部分加速算法可能影响网络精度,请谨慎选择。 - - 参数: - - **boost_type** (str) - 加速算法。 - - 返回: - Cell类型,Cell本身。 - - 异常: - - **ValueError** - 如果 `boost_type` 不在boost算法库内。 - - .. py:method:: set_broadcast_flag(mode=True) - - 设置该Cell的参数广播模式。 - - 参数: - - **mode** (bool) - 指定当前模式是否进行参数广播。默认值: ``True`` 。 - - .. py:method:: set_comm_fusion(fusion_type, recurse=True) - - 为Cell中的参数设置融合类型。请参考 :class:`mindspore.Parameter.comm_fusion` 的描述。 - - .. note:: 当函数被多次调用时,此属性值将被重写。 - - 参数: - - **fusion_type** (int) - Parameter的 `comm_fusion` 属性的设置值。 - - **recurse** (bool) - 是否递归地设置子Cell的可训练参数。默认值: ``True`` 。 - - .. py:method:: set_data_parallel() - - 在非自动策略搜索的情况下,如果此Cell的所有算子(包括此Cell内含嵌套的cell)未指定并行策略,则将为这些基本算子设置为数据并行策略。 - - .. note:: 仅在图模式,使用auto_parallel_context = ParallelMode.AUTO_PARALLEL生效。 - - .. py:method:: set_grad(requires_grad=True) - - Cell的梯度设置。在PyNative模式下,该参数指定Cell是否需要梯度。如果为 ``True`` ,则在执行正向网络时,将生成需要计算梯度的反向网络。 - - 参数: - - **requires_grad** (bool) - 指定网络是否需要梯度,如果为 ``True`` ,PyNative模式下Cell将构建反向网络。默认值: ``True`` 。 - - 返回: - Cell类型,Cell本身。 - - .. py:method:: set_inputs(*inputs, **kwargs) - - 设置编译计算图所需的输入。输入数量需与数据集数量一致。若使用Model接口,请确保所有传入Model的网络和损失函数都配置了set_inputs。 - 输入Tensor的shape可以为动态或静态。 - - .. note:: - 有两种配置模式: - - - 全量配置模式:输入将被用作图编译时的完整编译参数。 - - 增量配置模式:输入被配置到Cell的部分输入上,这些输入将替换图编译对应位置上的参数。 - - 只能传入inputs和kwargs的其中一个。inputs用于全量配置模式,kwargs用于增量配置模式。 - - 参数: - - **inputs** (tuple) - 全量配置模式的参数。 - - **kwargs** (dict) - 增量配置模式的参数。可设置的key值为 `self.construct` 中定义的参数名。 - - .. warning:: - 这是一个实验性API,后续可能修改或删除。 - - .. py:method:: set_jit_config(jit_config) - - 为Cell设置编译时所使用的JitConfig配置项。 - - 参数: - - **jit_config** (JitConfig) - Cell的Jit配置信息。详情请参考 :class:`mindspore.JitConfig` 。 - - .. py:method:: set_param_ps(recurse=True, init_in_server=False) - - 设置可训练参数是否由参数服务器更新,以及是否在服务器上初始化可训练参数。 - - .. note:: - 只在运行的任务处于参数服务器模式时有效。 - 只支持在图模式下调用。 - - 参数: - - **recurse** (bool) - 是否设置子网络的可训练参数。默认值: ``True`` 。 - - **init_in_server** (bool) - 是否在服务器上初始化由参数服务器更新的可训练参数。默认值: ``False`` 。 - - .. py:method:: set_train(mode=True) - - 将Cell设置为训练模式。 - - 设置当前Cell和所有子Cell的训练模式。对于训练和预测具有不同结构的网络层(如 `BatchNorm`),将通过这个属性区分分支。如果设置为True,则执行训练分支,否则执行另一个分支。 - - .. note:: - 当执行 :func:`mindspore.train.Model.train` 的时候,框架会默认调用Cell.set_train(True)。 - 当执行 :func:`mindspore.train.Model.eval` 的时候,框架会默认调用Cell.set_train(False)。 - - 参数: - - **mode** (bool) - 指定模型是否为训练模式。默认值: ``True`` 。 - - 返回: - Cell类型,Cell本身。 - - 教程样例: - - `模型训练 - 训练与评估实现 `_ - - .. py:method:: shard(in_strategy, out_strategy=None, parameter_plan=None, device="Ascend", level=0) - - 指定输入/输出Tensor的分布策略,通过其余算子的策略推导得到。在PyNative模式下,可以利用此方法指定某个Cell以图模式进行分布式执行。 在图模式下, - 可以利用此方法设置某个模块的分布式切分策略,未设置的会自动通过策略传播方式配置。 in_strategy/out_strategy需要为元组类型, - 其中的每一个元素指定对应的输入/输出的Tensor分布策略,可参考: :func:`mindspore.ops.Primitive.shard` 的描述。也可以设置为None,会默认以数据并行执行。 - 其余算子的并行策略由输入输出指定的策略推导得到。 - - .. note:: 调用该方法后,并行模式(parallel_mode)会自动设置为"auto_parallel"且搜索模式(search_mode)自动设置为"sharding_propagation"。 - 如果输入含有Parameter,其对应的策略应该在 `in_strategy` 里设置。 - - 参数: - - **in_strategy** (tuple) - 指定各输入的切分策略,输入元组的每个元素可以为元组或None,元组即具体指定输入每一维的切分策略,None则会默认以数据并行执行。 - - **out_strategy** (Union[None, tuple]) - 指定各输出的切分策略,用法同in_strategy,目前未使能。默认值: ``None`` 。 - - **parameter_plan** (Union[dict, None]) - 指定各参数的切分策略,传入字典时,键是str类型的参数名,值是一维整数tuple表示相应的切分策略, - 如果参数名错误或对应参数已经设置了切分策略,该参数的设置会被跳过。默认值: ``None`` 。 - - **device** (string) - 指定执行设备,可以为[ ``"CPU"`` , ``"GPU"`` , ``"Ascend"`` ]中任意一个,目前未使能。默认值: ``"Ascend"`` 。 - - **level** (int) - 指定搜索切分策略的目标函数,即是最大化计算通信比、最小化内存消耗、最大化执行速度等。可以为[ ``0`` , ``1`` , ``2`` ]中任意一个,默认值: ``0`` 。目前仅支持最大化计算通信比,其余模式未使能。 - - 返回: - Function,返回一个在自动并行流程下执行的函数。 - - .. py:method:: to_float(dst_type) - - 在Cell和所有子Cell的输入上添加类型转换,以使用特定的浮点类型运行。 - - 如果 `dst_type` 是 `mindspore.dtype.float16` ,Cell的所有输入(包括作为常量的input, Parameter, Tensor)都会被转换为float16。请参考 :func:`mindspore.amp.build_train_network` 的源代码中的用法。 - - .. note:: 多次调用将产生覆盖。 - - 参数: - - **dst_type** (mindspore.dtype) - Cell转换为 `dst_type` 类型运行。 `dst_type` 可以是 `mindspore.dtype.float16` 、 `mindspore.dtype.float32` 或者 `mindspore.dtype.bfloat16` 。 - - 返回: - Cell类型,Cell本身。 - - 异常: - - **ValueError** - 如果 `dst_type` 不是 `mindspore.dtype.float32` ,不是 `mindspore.dtype.float16` , 也不是 `mindspore.dtype.bfloat16` 。 - - .. py:method:: trainable_params(recurse=True) - - 返回Cell的一个可训练参数的列表。 - - 参数: - - **recurse** (bool) - 是否递归地包含当前Cell的所有子Cell的可训练参数。默认值: ``True`` 。 - - 返回: - List类型,可训练参数列表。 - - 教程样例: - - `模型训练 - 优化器 `_ - - .. py:method:: untrainable_params(recurse=True) - - 返回Cell的一个不可训练参数的列表。 - - 参数: - - **recurse** (bool) - 是否递归地包含当前Cell的所有子Cell的不可训练参数。默认值: ``True`` 。 - - 返回: - List类型,不可训练参数列表。 - - .. py:method:: update_cell_prefix() - - 递归地更新所有子Cell的 `param_prefix` 。 - - 在调用此方法后,可以通过Cell的 `param_prefix` 属性获取该Cell的所有子Cell的名称前缀。 - - .. py:method:: update_cell_type(cell_type) - - 量化感知训练网络场景下,更新当前Cell的类型。 - - 此方法将Cell类型设置为 `cell_type` 。 - - 参数: - - **cell_type** (str) - 被更新的类型,`cell_type` 可以是"quant"或"second-order"。 - - .. py:method:: update_parameters_name(prefix='', recurse=True) - - 给网络参数名称添加 `prefix` 前缀字符串。 - - 参数: - - **prefix** (str) - 前缀字符串。默认值: ``''`` 。 - - **recurse** (bool) - 是否递归地包含所有子Cell的参数。默认值: ``True`` 。 +mindspore.nn.Cell +================== + +.. py:class:: mindspore.nn.Cell(auto_prefix=True, flags=None) + + MindSpore中神经网络的基本构成单元。模型或神经网络层应当继承该基类。 + + `mindspore.nn` 中神经网络层也是Cell的子类,如 :class:`mindspore.nn.Conv2d` 、 :class:`mindspore.nn.ReLU` 等。Cell在GRAPH_MODE(静态图模式)下将编译为一张计算图,在PYNATIVE_MODE(动态图模式)下作为神经网络的基础模块。 + + .. note:: + Cell默认情况下是推理模式。对于继承Cell的类,如果训练和推理具有不同结构,子类会默认执行推理分支。设置训练模式,请参考 `mindspore.nn.Cell.set_train` 。 + + .. warning:: + 在Cell的子类中不能定义名为'cast'的方法,不能定义名为'phase'和'cells'的属性, 否则会报错。 + + 参数: + - **auto_prefix** (bool,可选) - 是否自动为Cell及其子Cell生成NameSpace。该参数同时会影响 `Cell` 中权重参数的名称。如果设置为 ``True`` ,则自动给权重参数的名称添加前缀,否则不添加前缀。通常情况下,骨干网络应设置为 ``True`` ,否则会产生重名问题。用于训练骨干网络的优化器、 :class:`mindspore.nn.TrainOneStepCell` 等,应设置为 ``False`` ,否则骨干网络的权重参数名会被误改。默认值: ``True`` 。 + - **flags** (dict,可选) - Cell的配置信息,目前用于绑定Cell和数据集。用户也通过该参数自定义Cell属性。默认值: ``None`` 。 + + .. py:method:: add_flags(**flags) + + 为Cell添加自定义属性。 + + 在实例化Cell类时,如果入参flags不为空,会调用此方法。 + + 参数: + - **flags** (dict) - Cell的配置信息,目前用于绑定Cell和数据集。用户也通过该参数自定义Cell属性。 + + .. py:method:: add_flags_recursive(**flags) + + 如果Cell含有多个子Cell,此方法会递归得给所有子Cell添加自定义属性。 + + 参数: + - **flags** (dict) - Cell的配置信息,目前用于绑定Cell和数据集。用户也通过该参数自定义Cell属性。 + + .. py:method:: apply(fn) + + 递归地将 `fn` 应用于每个子Cell(由 `.cells()` 返回)以及自身。通常用于初始化模型的参数。 + + 参数: + - **fn** (function) - 被执行于每个Cell的function。 + + 返回: + Cell类型,Cell本身。 + + .. py:method:: auto_cast_inputs(inputs) + + 在混合精度下,自动对输入进行类型转换。 + + 参数: + - **inputs** (tuple) - construct方法的输入。 + + 返回: + Tuple类型,经过类型转换后的输入。 + + .. py:method:: bprop_debug + :property: + + 在图模式下使用,用于标识是否使用自定义的反向传播函数。 + + 教程样例: + - `Cell与参数 - 自定义Cell反向 + `_ + + .. py:method:: cast_inputs(inputs, dst_type) + + 将输入转换为指定类型。 + + 参数: + - **inputs** (tuple[Tensor]) - 输入。 + - **dst_type** (mindspore.dtype) - 指定的数据类型。 + + 返回: + tuple[Tensor]类型,转换类型后的结果。 + + .. py:method:: cast_param(param) + + 在PyNative模式下,根据自动混合精度的精度设置转换Cell中参数的类型。 + + 该接口目前在自动混合精度场景下使用。 + + 参数: + - **param** (Parameter) - 需要被转换类型的输入参数。 + + 返回: + Parameter类型,转换类型后的参数。 + + .. py:method:: cells() + + 返回当前Cell的子Cell的迭代器。 + + 返回: + Iteration类型,Cell的子Cell。 + + .. py:method:: cells_and_names(cells=None, name_prefix='') + + 递归地获取当前Cell及输入 `cells` 的所有子Cell的迭代器,包括Cell的名称及其本身。 + + 参数: + - **cells** (str) - 需要进行迭代的Cell。默认值: ``None`` 。 + - **name_prefix** (str) - 作用域。默认值: ``''`` 。 + + 返回: + Iteration类型,当前Cell及输入 `cells` 的所有子Cell和相对应的名称。 + + .. py:method:: check_names() + + 检查Cell中的网络参数名称是否重复。 + + .. py:method:: compile(*args, **kwargs) + + 编译Cell为计算图,输入需与construct中定义的输入一致。 + + 参数: + - **args** (tuple) - Cell的输入。 + - **kwargs** (dict) - Cell的输入。 + + .. py:method:: compile_and_run(*args, **kwargs) + + 编译并运行Cell,输入需与construct中定义的输入一致。 + + .. note:: + 不推荐使用该函数,建议直接调用Cell实例。 + + 参数: + - **args** (tuple) - Cell的输入。 + - **kwargs** (dict) - Cell的输入。 + + 返回: + Object类型,执行的结果。 + + .. py:method:: construct(*args, **kwargs) + + 定义要执行的计算逻辑。所有子类都必须重写此方法。 + + .. note:: + 当前不支持inputs同时输入tuple类型和非tuple类型。 + + 参数: + - **args** (tuple) - 可变参数列表,默认值: ``()`` 。 + - **kwargs** (dict) - 可变的关键字参数的字典,默认值: ``{}`` 。 + + 返回: + Tensor类型,返回计算结果。 + + .. py:method:: extend_repr() + + 在原有描述基础上扩展Cell的描述。 + + 若需要在print时输出个性化的扩展信息,请在您的网络中重新实现此方法。 + + .. py:method:: flatten_weights(fusion_size=0) + + 重置权重参数(即可训练参数)使用的数据内存,让这些参数按数据类型分组使用连续内存块。 + + .. note:: + 默认情况下,具有相同数据类型的参数会使用同一个连续内存块。但对于某些具有大量参数的模型, + 将一个大的连续内存块分为多个小一点的内存块有可能提升性能,对于这种情况, + 可以通过 `fusion_size` 参数来限制最大连续内存块的的大小。 + + 参数: + - **fusion_size** (int) - 最大连续内存块的大小(以字节为单位), ``0`` 表示不限制大小。默认值: ``0`` 。 + + .. py:method:: generate_scope() + + 为网络中的每个Cell对象生成NameSpace。 + + .. py:method:: get_flags() + + 获取该Cell的自定义属性,自定义属性通过 `add_flags` 方法添加。 + + .. py:method:: get_func_graph_proto() + + 返回图的二进制原型。 + + .. py:method:: get_inputs() + + 返回编译计算图所设置的输入。 + + 返回: + Tuple类型,编译计算图所设置的输入。 + + .. warning:: + 这是一个实验性API,后续可能修改或删除。 + + .. py:method:: get_parameters(expand=True) + + 返回Cell中parameter的迭代器。 + + 获取Cell的参数。如果 `expand` 为 ``true`` ,获取此cell和所有subcells的参数。关于subcell,请看下面的示例。 + + 参数: + - **expand** (bool) - 如果为 ``True`` ,则递归地获取当前Cell和所有子Cell的parameter。否则,只生成当前Cell的subcell的parameter。默认值: ``True`` 。 + + 返回: + Iteration类型,Cell的parameter。 + + .. py:method:: get_scope() + + 返回Cell的作用域。 + + 返回: + String类型,网络的作用域。 + + .. py:method:: infer_param_pipeline_stage() + + 推导Cell中当前 `pipeline_stage` 的参数。 + + .. note:: + - 这个接口在2.3版本废弃,并且会在未来版本移除。 + + 返回: + 属于当前 `pipeline_stage` 的参数。 + + 异常: + - **RuntimeError** - 如果参数不属于任何stage。 + + .. py:method:: init_parameters_data(auto_parallel_mode=False) + + 初始化并替换Cell中所有的parameter的值。 + + .. note:: + 在调用 `init_parameters_data` 后,`trainable_params()` 或其他相似的接口可能返回不同的参数对象,不要保存这些结果。 + + 参数: + - **auto_parallel_mode** (bool) - 是否在自动并行模式下执行。默认值: ``False`` 。 + + 返回: + Dict[Parameter, Parameter],返回一个原始参数和替换参数的字典。 + + .. py:method:: insert_child_to_cell(child_name, child_cell) + + 将一个给定名称的子Cell添加到当前Cell。 + + 参数: + - **child_name** (str) - 子Cell名称。 + - **child_cell** (Cell) - 要插入的子Cell。 + + 异常: + - **KeyError** - 如果子Cell的名称不正确或与其他子Cell名称重复。 + - **TypeError** - 如果 `child_name` 的类型不为str类型。 + - **TypeError** - 如果子Cell的类型不正确。 + + .. py:method:: insert_param_to_cell(param_name, param, check_name_contain_dot=True) + + 向当前Cell添加参数。 + + 将指定名称的参数添加到Cell中。目前在 `mindspore.nn.Cell.__setattr__` 中使用。 + + 参数: + - **param_name** (str) - 参数名称。 + - **param** (Parameter) - 要插入到Cell的参数。 + - **check_name_contain_dot** (bool) - 是否对 `param_name` 中的"."进行检查。默认值: ``True`` 。 + + 异常: + - **KeyError** - 如果参数名称为空或包含"."。 + - **TypeError** - 如果参数的类型不是Parameter。 + + .. py:method:: name_cells() + + 递归地获取一个Cell中所有子Cell的迭代器。 + + 包括Cell名称和Cell本身。 + + 返回: + Dict[String, Cell],Cell中的所有子Cell及其名称。 + + .. py:method:: param_prefix + :property: + + 当前Cell的子Cell的参数名前缀。 + + .. py:method:: parameter_layout_dict + :property: + + `parameter_layout_dict` 表示一个参数的张量layout,这种张量layout是由分片策略和分布式算子信息推断出来的。 + + .. py:method:: parameters_and_names(name_prefix='', expand=True) + + 返回Cell中parameter的迭代器。 + + 包含参数名称和参数本身。 + + 参数: + - **name_prefix** (str) - 作用域。默认值: ``''`` 。 + - **expand** (bool) - 如果为True,则递归地获取当前Cell和所有子Cell的参数及名称;如果为 ``False`` ,只生成当前Cell的子Cell的参数及名称。默认值: ``True`` 。 + + 返回: + 迭代器,Cell的名称和Cell本身。 + + 教程样例: + - `网络构建 - 模型参数 `_ + + .. py:method:: parameters_broadcast_dict(recurse=True) + + 获取这个Cell的参数广播字典。 + + 参数: + - **recurse** (bool) - 是否包含子Cell的参数。默认值: ``True`` 。 + + 返回: + OrderedDict,返回参数广播字典。 + + .. py:method:: parameters_dict(recurse=True) + + 获取此Cell的parameter字典。 + + 参数: + - **recurse** (bool) - 是否递归得包含所有子Cell的parameter。默认值: ``True`` 。 + + 返回: + OrderedDict类型,返回参数字典。 + + .. py:method:: pipeline_stage + :property: + + `pipeline_stage` 表示当前Cell所在的stage。 + + .. py:method:: place(role, rank_id) + + 为该Cell中所有算子设置标签。此标签告诉MindSpore编译器此Cell在哪个进程上启动。 + 每个标签都由进程角色 `role` 和 `rank_id` 组成,因此,通过对不同Cell设置不同标签,这些Cell将在不同进程启动,使用户可以进行分布式训练/推理等任务。 + + .. note:: + - 此接口只在成功调用 `mindspore.communication.init()` 完成动态组网后才能生效。 + + 参数: + - **role** (str) - 算子执行所在进程的角色。只支持'MS_WORKER'。 + - **rank_id** (int) - 算子执行所在进程的id。在相同进程角色间, `rank_id` 是唯一的。 + + .. py:method:: recompute(**kwargs) + + 设置Cell重计算。Cell中输出算子以外的所有算子将被设置为重计算。如果一个算子的计算结果被输出到一些反向节点来进行梯度计算,且被设置成重计算,那么我们会在反向传播中重新计算它,而不去存储在前向传播中的中间激活层的计算结果。 + + .. note:: + - 如果计算涉及到诸如随机化或全局变量之类的操作,那么目前还不能保证等价。 + - 如果该Cell中算子的重计算API也被调用,则该算子的重计算模式以算子的重计算API的设置为准。 + - 该接口仅配置一次,即当父Cell配置了,子Cell不需再配置。 + - Cell的输出算子默认不做重计算,这一点是基于我们减少内存占用的配置经验。如果一个Cell里面只有一个算子而且想要把这个算子设置为重计算的,那么请使用算子的重计算API。 + - 当应用了重计算且内存充足时,可以配置'mp_comm_recompute=False'来提升性能。 + - 当应用了重计算但内存不足时,可以配置'parallel_optimizer_comm_recompute=True'来节省内存。有相同融合group的Cell应该配置相同的parallel_optimizer_comm_recompute。 + + 参数: + - **mp_comm_recompute** (bool) - 表示在自动并行或半自动并行模式下,指定Cell内部由模型并行引入的通信操作是否重计算。默认值: ``True`` 。 + - **parallel_optimizer_comm_recompute** (bool) - 表示在自动并行或半自动并行模式下,指定Cell内部由优化器并行引入的AllGather通信是否重计算。默认值: ``False`` 。 + + .. py:method:: register_backward_hook(hook_fn) + + 设置Cell对象的反向hook函数。 + + .. note:: + - `register_backward_hook(hook_fn)` 在图模式下,或者在PyNative模式下使用 `jit` 装饰器功能时不起作用。 + - hook_fn必须有如下代码定义。 `cell_id` 是已注册Cell对象的信息,包括名称和ID。 `grad_input` 是反向传递给Cell对象的梯度。 `grad_output` 是Cell对象的反向输出梯度。用户可以在hook_fn中打印梯度数据或者返回新的输出梯度。 + - hook_fn返回新的输出梯度或者None:hook_fn(cell_id, grad_input, grad_output) -> New grad_output or None。 + - 为了避免脚本在切换到图模式时运行失败,不建议在Cell对象的 `construct` 函数中调用 `register_backward_hook(hook_fn)` 。 + - PyNative模式下,如果在Cell对象的 `construct` 函数中调用 `register_backward_hook(hook_fn)` ,那么Cell对象每次运行都将增加一个 `hook_fn` 。 + + 参数: + - **hook_fn** (function) - 捕获Cell对象信息和反向输入,输出梯度的 `hook_fn` 函数。 + + 返回: + 返回与 `hook_fn` 函数对应的 `handle` 对象。可通过调用 `handle.remove()` 来删除添加的 `hook_fn` 函数。 + + 异常: + - **TypeError** - 如果 `hook_fn` 不是Python函数。 + + .. py:method:: register_forward_hook(hook_fn) + + 设置Cell对象的正向hook函数。 + + .. note:: + - `register_forward_hook(hook_fn)` 在图模式下,或者在PyNative模式下使用 `jit` 装饰器功能时不起作用。 + - hook_fn必须有如下代码定义。 `cell` 是已注册Cell对象。 `inputs` 是网络正向传播时Cell对象的输入数据。 `outputs` 是网络正向传播时Cell对象的输出数据。用户可以在hook_fn中打印数据或者返回新的输出数据。 + - hook_fn返回新的输出数据或者None:hook_fn(cell, inputs, outputs) -> New outputs or None。 + - 为了避免脚本在切换到图模式时运行失败,不建议在Cell对象的 `construct` 函数中调用 `register_forward_hook(hook_fn)` 。 + - PyNative模式下,如果在Cell对象的 `construct` 函数中调用 `register_forward_hook(hook_fn)` ,那么Cell对象每次运行都将增加一个 `hook_fn` 。 + + 参数: + - **hook_fn** (function) - 捕获Cell对象信息和正向输入,输出数据的 `hook_fn` 函数。 + + 返回: + 返回与 `hook_fn` 函数对应的 `handle` 对象。可通过调用 `handle.remove()` 来删除添加的 `hook_fn` 函数。 + + 异常: + - **TypeError** - 如果 `hook_fn` 不是Python函数。 + + .. py:method:: register_forward_pre_hook(hook_fn) + + 设置Cell对象的正向pre_hook函数。 + + .. note:: + - `register_forward_pre_hook(hook_fn)` 在图模式下,或者在PyNative模式下使用 `jit` 装饰器功能时不起作用。 + - hook_fn必须有如下代码定义。 `cell` 是已注册Cell对象。 `inputs` 是网络正向传播时Cell对象的输入数据。用户可以在hook_fn中打印输入数据或者返回新的输入数据。 + - hook_fn返回新的输入数据或者None:hook_fn(cell, inputs) -> New inputs or None。 + - 为了避免脚本在切换到图模式时运行失败,不建议在Cell对象的 `construct` 函数中调用 `register_forward_pre_hook(hook_fn)` 。 + - PyNative模式下,如果在Cell对象的 `construct` 函数中调用 `register_forward_pre_hook(hook_fn)` ,那么Cell对象每次运行都将增加一个 `hook_fn` 。 + + 参数: + - **hook_fn** (function) - 捕获Cell对象信息和正向输入数据的hook_fn函数。 + + 返回: + 返回与 `hook_fn` 函数对应的 `handle` 对象。可通过调用 `handle.remove()` 来删除添加的 `hook_fn` 函数。 + + 异常: + - **TypeError** - 如果 `hook_fn` 不是Python函数。 + + .. py:method:: remove_redundant_parameters() + + 删除冗余参数。 + + 这个接口通常不需要显式调用。 + + .. py:method:: run_construct(cast_inputs, kwargs) + + 运行construct方法。 + + .. note:: + 该函数已经弃用,将会在未来版本中删除。不推荐使用此函数。 + + 参数: + - **cast_inputs** (tuple) - Cell的输入。 + - **kwargs** (dict) - 关键字参数。 + + 返回: + Cell的输出。 + + .. py:method:: set_boost(boost_type) + + 为了提升网络性能,可以配置boost内的算法让框架自动使能该算法来加速网络训练。 + + 请确保 `boost_type` 所选择的算法在 + `algorithm library `_ 算法库中。 + + .. note:: 部分加速算法可能影响网络精度,请谨慎选择。 + + 参数: + - **boost_type** (str) - 加速算法。 + + 返回: + Cell类型,Cell本身。 + + 异常: + - **ValueError** - 如果 `boost_type` 不在boost算法库内。 + + .. py:method:: set_broadcast_flag(mode=True) + + 设置该Cell的参数广播模式。 + + 参数: + - **mode** (bool) - 指定当前模式是否进行参数广播。默认值: ``True`` 。 + + .. py:method:: set_comm_fusion(fusion_type, recurse=True) + + 为Cell中的参数设置融合类型。请参考 :class:`mindspore.Parameter.comm_fusion` 的描述。 + + .. note:: 当函数被多次调用时,此属性值将被重写。 + + 参数: + - **fusion_type** (int) - Parameter的 `comm_fusion` 属性的设置值。 + - **recurse** (bool) - 是否递归地设置子Cell的可训练参数。默认值: ``True`` 。 + + .. py:method:: set_data_parallel() + + 在非自动策略搜索的情况下,如果此Cell的所有算子(包括此Cell内含嵌套的cell)未指定并行策略,则将为这些基本算子设置为数据并行策略。 + + .. note:: 仅在图模式,使用auto_parallel_context = ParallelMode.AUTO_PARALLEL生效。 + + .. py:method:: set_grad(requires_grad=True) + + Cell的梯度设置。在PyNative模式下,该参数指定Cell是否需要梯度。如果为 ``True`` ,则在执行正向网络时,将生成需要计算梯度的反向网络。 + + 参数: + - **requires_grad** (bool) - 指定网络是否需要梯度,如果为 ``True`` ,PyNative模式下Cell将构建反向网络。默认值: ``True`` 。 + + 返回: + Cell类型,Cell本身。 + + .. py:method:: set_inputs(*inputs, **kwargs) + + 设置编译计算图所需的输入。输入数量需与数据集数量一致。若使用Model接口,请确保所有传入Model的网络和损失函数都配置了set_inputs。 + 输入Tensor的shape可以为动态或静态。 + + .. note:: + 有两种配置模式: + + - 全量配置模式:输入将被用作图编译时的完整编译参数。 + - 增量配置模式:输入被配置到Cell的部分输入上,这些输入将替换图编译对应位置上的参数。 + + 只能传入inputs和kwargs的其中一个。inputs用于全量配置模式,kwargs用于增量配置模式。 + + 参数: + - **inputs** (tuple) - 全量配置模式的参数。 + - **kwargs** (dict) - 增量配置模式的参数。可设置的key值为 `self.construct` 中定义的参数名。 + + .. warning:: + 这是一个实验性API,后续可能修改或删除。 + + .. py:method:: set_jit_config(jit_config) + + 为Cell设置编译时所使用的JitConfig配置项。 + + 参数: + - **jit_config** (JitConfig) - Cell的Jit配置信息。详情请参考 :class:`mindspore.JitConfig` 。 + + .. py:method:: set_param_ps(recurse=True, init_in_server=False) + + 设置可训练参数是否由参数服务器更新,以及是否在服务器上初始化可训练参数。 + + .. note:: + 只在运行的任务处于参数服务器模式时有效。 + 只支持在图模式下调用。 + + 参数: + - **recurse** (bool) - 是否设置子网络的可训练参数。默认值: ``True`` 。 + - **init_in_server** (bool) - 是否在服务器上初始化由参数服务器更新的可训练参数。默认值: ``False`` 。 + + .. py:method:: set_train(mode=True) + + 将Cell设置为训练模式。 + + 设置当前Cell和所有子Cell的训练模式。对于训练和预测具有不同结构的网络层(如 `BatchNorm`),将通过这个属性区分分支。如果设置为True,则执行训练分支,否则执行另一个分支。 + + .. note:: + 当执行 :func:`mindspore.train.Model.train` 的时候,框架会默认调用Cell.set_train(True)。 + 当执行 :func:`mindspore.train.Model.eval` 的时候,框架会默认调用Cell.set_train(False)。 + + 参数: + - **mode** (bool) - 指定模型是否为训练模式。默认值: ``True`` 。 + + 返回: + Cell类型,Cell本身。 + + 教程样例: + - `模型训练 - 训练与评估实现 `_ + + .. py:method:: shard(in_strategy, out_strategy=None, parameter_plan=None, device="Ascend", level=0) + + 指定输入/输出Tensor的分布策略,通过其余算子的策略推导得到。在PyNative模式下,可以利用此方法指定某个Cell以图模式进行分布式执行。 在图模式下, + 可以利用此方法设置某个模块的分布式切分策略,未设置的会自动通过策略传播方式配置。 in_strategy/out_strategy需要为元组类型, + 其中的每一个元素指定对应的输入/输出的Tensor分布策略,可参考: :func:`mindspore.ops.Primitive.shard` 的描述。也可以设置为None,会默认以数据并行执行。 + 其余算子的并行策略由输入输出指定的策略推导得到。 + + .. note:: 调用该方法后,并行模式(parallel_mode)会自动设置为"auto_parallel"且搜索模式(search_mode)自动设置为"sharding_propagation"。 + 如果输入含有Parameter,其对应的策略应该在 `in_strategy` 里设置。 + + 参数: + - **in_strategy** (tuple) - 指定各输入的切分策略,输入元组的每个元素可以为元组或None,元组即具体指定输入每一维的切分策略,None则会默认以数据并行执行。 + - **out_strategy** (Union[None, tuple]) - 指定各输出的切分策略,用法同in_strategy,目前未使能。默认值: ``None`` 。 + - **parameter_plan** (Union[dict, None]) - 指定各参数的切分策略,传入字典时,键是str类型的参数名,值是一维整数tuple表示相应的切分策略, + 如果参数名错误或对应参数已经设置了切分策略,该参数的设置会被跳过。默认值: ``None`` 。 + - **device** (string) - 指定执行设备,可以为[ ``"CPU"`` , ``"GPU"`` , ``"Ascend"`` ]中任意一个,目前未使能。默认值: ``"Ascend"`` 。 + - **level** (int) - 指定搜索切分策略的目标函数,即是最大化计算通信比、最小化内存消耗、最大化执行速度等。可以为[ ``0`` , ``1`` , ``2`` ]中任意一个,默认值: ``0`` 。目前仅支持最大化计算通信比,其余模式未使能。 + + 返回: + Function,返回一个在自动并行流程下执行的函数。 + + .. py:method:: to_float(dst_type) + + 在Cell和所有子Cell的输入上添加类型转换,以使用特定的浮点类型运行。 + + 如果 `dst_type` 是 `mindspore.dtype.float16` ,Cell的所有输入(包括作为常量的input, Parameter, Tensor)都会被转换为float16。请参考 :func:`mindspore.amp.build_train_network` 的源代码中的用法。 + + .. note:: 多次调用将产生覆盖。 + + 参数: + - **dst_type** (mindspore.dtype) - Cell转换为 `dst_type` 类型运行。 `dst_type` 可以是 `mindspore.dtype.float16` 、 `mindspore.dtype.float32` 或者 `mindspore.dtype.bfloat16` 。 + + 返回: + Cell类型,Cell本身。 + + 异常: + - **ValueError** - 如果 `dst_type` 不是 `mindspore.dtype.float32` ,不是 `mindspore.dtype.float16` , 也不是 `mindspore.dtype.bfloat16` 。 + + .. py:method:: trainable_params(recurse=True) + + 返回Cell的一个可训练参数的列表。 + + 参数: + - **recurse** (bool) - 是否递归地包含当前Cell的所有子Cell的可训练参数。默认值: ``True`` 。 + + 返回: + List类型,可训练参数列表。 + + 教程样例: + - `模型训练 - 优化器 `_ + + .. py:method:: untrainable_params(recurse=True) + + 返回Cell的一个不可训练参数的列表。 + + 参数: + - **recurse** (bool) - 是否递归地包含当前Cell的所有子Cell的不可训练参数。默认值: ``True`` 。 + + 返回: + List类型,不可训练参数列表。 + + .. py:method:: update_cell_prefix() + + 递归地更新所有子Cell的 `param_prefix` 。 + + 在调用此方法后,可以通过Cell的 `param_prefix` 属性获取该Cell的所有子Cell的名称前缀。 + + .. py:method:: update_cell_type(cell_type) + + 量化感知训练网络场景下,更新当前Cell的类型。 + + 此方法将Cell类型设置为 `cell_type` 。 + + 参数: + - **cell_type** (str) - 被更新的类型,`cell_type` 可以是"quant"或"second-order"。 + + .. py:method:: update_parameters_name(prefix='', recurse=True) + + 给网络参数名称添加 `prefix` 前缀字符串。 + + 参数: + - **prefix** (str) - 前缀字符串。默认值: ``''`` 。 + - **recurse** (bool) - 是否递归地包含所有子Cell的参数。默认值: ``True`` 。 diff --git a/docs/api/api_python/nn/mindspore.nn.GetNextSingleOp.rst b/docs/api/api_python/nn/mindspore.nn.GetNextSingleOp.rst index 37769f6a9bf..2d0110ecbb7 100644 --- a/docs/api/api_python/nn/mindspore.nn.GetNextSingleOp.rst +++ b/docs/api/api_python/nn/mindspore.nn.GetNextSingleOp.rst @@ -1,14 +1,14 @@ -mindspore.nn.GetNextSingleOp -============================= - -.. py:class:: mindspore.nn.GetNextSingleOp(dataset_types, dataset_shapes, queue_name) - - 用于获取下一条数据的Cell。更详细的信息请参考 :class:`mindspore.ops.GetNext` 。 - - 参数: - - **dataset_types** (list[:class:`mindspore.dtype`]) - 数据集类型。 - - **dataset_shapes** (list[tuple[int]]) - 数据集的shape。 - - **queue_name** (str) - 待获取数据的队列名称。 - - 输出: - tuple[Tensor],从数据集中获取的数据。 +mindspore.nn.GetNextSingleOp +============================= + +.. py:class:: mindspore.nn.GetNextSingleOp(dataset_types, dataset_shapes, queue_name) + + 用于获取下一条数据的Cell。更详细的信息请参考 :class:`mindspore.ops.GetNext` 。 + + 参数: + - **dataset_types** (list[:class:`mindspore.dtype`]) - 数据集类型。 + - **dataset_shapes** (list[tuple[int]]) - 数据集的shape。 + - **queue_name** (str) - 待获取数据的队列名称。 + + 输出: + tuple[Tensor],从数据集中获取的数据。 diff --git a/docs/api/api_python/nn/mindspore.nn.GraphCell.rst b/docs/api/api_python/nn/mindspore.nn.GraphCell.rst index 2c6abb92256..0cd4089a8a1 100644 --- a/docs/api/api_python/nn/mindspore.nn.GraphCell.rst +++ b/docs/api/api_python/nn/mindspore.nn.GraphCell.rst @@ -1,19 +1,19 @@ -mindspore.nn.GraphCell -====================== - -.. py:class:: mindspore.nn.GraphCell(graph, params_init=None, obf_random_seed=None) - - 运行从MindIR加载的计算图。 - - 此功能仍在开发中。目前 `GraphCell` 不支持修改图结构,在导出MindIR时只能使用shape和类型与输入相同的数据。 - - 参数: - - **graph** (FuncGraph) - 从MindIR加载的编译图。 - - **params_init** (dict) - 需要在图中初始化的参数。key为参数名称,类型为字符串,value为 Tensor 或 Parameter。如果参数名在图中已经存在,则更新其值;如果不存在,则忽略。默认值: ``None`` 。 - - **obf_random_seed** (Union[int, None]) - 用于动态混淆保护的混淆随机种子。动态混淆是一种模型保护方法,可以参考 :func:`mindspore.obfuscate_model` 。如果导入的 `graph` 是一个经过混淆的模型,那么须提供 `obf_random_seed` 。 `obf_random_seed` 的取值范围是(0, 9223372036854775807]。默认值: ``None`` 。 - - 异常: - - **TypeError** - 如果图不是FuncGraph类型。 - - **TypeError** - 如果 `params_init` 不是字典。 - - **TypeError** - 如果 `params_init` 的key不是字符串。 - - **TypeError** - 如果 `params_init` 的value既不是 Tensor也不是Parameter。 +mindspore.nn.GraphCell +====================== + +.. py:class:: mindspore.nn.GraphCell(graph, params_init=None, obf_random_seed=None) + + 运行从MindIR加载的计算图。 + + 此功能仍在开发中。目前 `GraphCell` 不支持修改图结构,在导出MindIR时只能使用shape和类型与输入相同的数据。 + + 参数: + - **graph** (FuncGraph) - 从MindIR加载的编译图。 + - **params_init** (dict) - 需要在图中初始化的参数。key为参数名称,类型为字符串,value为 Tensor 或 Parameter。如果参数名在图中已经存在,则更新其值;如果不存在,则忽略。默认值: ``None`` 。 + - **obf_random_seed** (Union[int, None]) - 用于动态混淆保护的混淆随机种子。动态混淆是一种模型保护方法,可以参考 :func:`mindspore.obfuscate_model` 。如果导入的 `graph` 是一个经过混淆的模型,那么须提供 `obf_random_seed` 。 `obf_random_seed` 的取值范围是(0, 9223372036854775807]。默认值: ``None`` 。 + + 异常: + - **TypeError** - 如果图不是FuncGraph类型。 + - **TypeError** - 如果 `params_init` 不是字典。 + - **TypeError** - 如果 `params_init` 的key不是字符串。 + - **TypeError** - 如果 `params_init` 的value既不是 Tensor也不是Parameter。 diff --git a/docs/api/api_python/nn/mindspore.nn.LeakyReLU.rst b/docs/api/api_python/nn/mindspore.nn.LeakyReLU.rst index f9a0056d2ad..33f1e1c2098 100644 --- a/docs/api/api_python/nn/mindspore.nn.LeakyReLU.rst +++ b/docs/api/api_python/nn/mindspore.nn.LeakyReLU.rst @@ -1,33 +1,33 @@ -mindspore.nn.LeakyReLU -======================= - -.. py:class:: mindspore.nn.LeakyReLU(alpha=0.2) - - 逐元素计算Leaky ReLU激活函数。 - - 该激活函数定义如下: - - .. math:: - \text{leaky_relu}(x) = \begin{cases}x, &\text{if } x \geq 0; \cr - {\alpha} * x, &\text{otherwise.}\end{cases} - - 其中,:math:`\alpha` 表示 `alpha` 参数。 - - 更多细节详见 `Rectifier Nonlinearities Improve Neural Network Acoustic Models `_ 。 - - LeakyReLU函数图: - - .. image:: ../images/LeakyReLU.png - :align: center - - 参数: - - **alpha** (`Union[int, float]`) - `x` 小于0时激活函数的斜率,默认值: ``0.2`` 。 - - 输入: - - **x** (Tensor) - 计算LeakyReLU的任意维度的Tensor。 - - 输出: - Tensor,数据类型和shape与 `x` 相同。 - - 异常: - - **TypeError** - `alpha` 不是浮点数或整数。 +mindspore.nn.LeakyReLU +======================= + +.. py:class:: mindspore.nn.LeakyReLU(alpha=0.2) + + 逐元素计算Leaky ReLU激活函数。 + + 该激活函数定义如下: + + .. math:: + \text{leaky_relu}(x) = \begin{cases}x, &\text{if } x \geq 0; \cr + {\alpha} * x, &\text{otherwise.}\end{cases} + + 其中,:math:`\alpha` 表示 `alpha` 参数。 + + 更多细节详见 `Rectifier Nonlinearities Improve Neural Network Acoustic Models `_ 。 + + LeakyReLU函数图: + + .. image:: ../images/LeakyReLU.png + :align: center + + 参数: + - **alpha** (`Union[int, float]`) - `x` 小于0时激活函数的斜率,默认值: ``0.2`` 。 + + 输入: + - **x** (Tensor) - 计算LeakyReLU的任意维度的Tensor。 + + 输出: + Tensor,数据类型和shape与 `x` 相同。 + + 异常: + - **TypeError** - `alpha` 不是浮点数或整数。 diff --git a/docs/api/api_python/nn/mindspore.nn.ParameterUpdate.rst b/docs/api/api_python/nn/mindspore.nn.ParameterUpdate.rst index 18e4bd80aff..d806f7fb9cb 100644 --- a/docs/api/api_python/nn/mindspore.nn.ParameterUpdate.rst +++ b/docs/api/api_python/nn/mindspore.nn.ParameterUpdate.rst @@ -1,20 +1,20 @@ -mindspore.nn.ParameterUpdate -========================================= - -.. py:class:: mindspore.nn.ParameterUpdate(param) - - 更新参数的Cell。 - - 使用输入的 `Tensor` 值更新 `param` 的值。 - - 参数: - - **param** (Parameter) - 输入的参数。 - - 输入: - - **x** (Tensor) - shape和type与 `param` 相同的Tensor。 - - 输出: - Tensor,更新后的值。 - - 异常: - - **KeyError** - 指定名称的参数不存在。 +mindspore.nn.ParameterUpdate +========================================= + +.. py:class:: mindspore.nn.ParameterUpdate(param) + + 更新参数的Cell。 + + 使用输入的 `Tensor` 值更新 `param` 的值。 + + 参数: + - **param** (Parameter) - 输入的参数。 + + 输入: + - **x** (Tensor) - shape和type与 `param` 相同的Tensor。 + + 输出: + Tensor,更新后的值。 + + 异常: + - **KeyError** - 指定名称的参数不存在。 diff --git a/docs/api/api_python/ops/mindspore.ops.func_bessel_i0.rst b/docs/api/api_python/ops/mindspore.ops.func_bessel_i0.rst old mode 100755 new mode 100644 index 6d69d2f75c0..03c7dba2851 --- a/docs/api/api_python/ops/mindspore.ops.func_bessel_i0.rst +++ b/docs/api/api_python/ops/mindspore.ops.func_bessel_i0.rst @@ -1,26 +1,26 @@ -mindspore.ops.bessel_i0 -======================= - -.. py:function:: mindspore.ops.bessel_i0(x) - - 逐元素计算第一类零阶修正Bessel函数值。 - - 计算公式定义如下: - - .. math:: - \begin{array}{ll} \\ - I_{0}(x)=J_{0}(\mathrm{i} x)=\sum_{m=0}^{\infty} - \frac{x^{2 m}}{2^{2 m} (m !)^{2}} - \end{array} - - 其中 :math:`J_{0}` 是第一类零阶Bessel函数。 - - 参数: - - **x** (Tensor) - Tensor的输入。数据类型应为float16,float32或float64。 - - 返回: - Tensor,shape和数据类型与 `x` 相同。 - - 异常: - - **TypeError** - `x` 不是Tensor。 - - **TypeError** - `x` 的数据类型不是float16,float32或float64。 +mindspore.ops.bessel_i0 +======================= + +.. py:function:: mindspore.ops.bessel_i0(x) + + 逐元素计算第一类零阶修正Bessel函数值。 + + 计算公式定义如下: + + .. math:: + \begin{array}{ll} \\ + I_{0}(x)=J_{0}(\mathrm{i} x)=\sum_{m=0}^{\infty} + \frac{x^{2 m}}{2^{2 m} (m !)^{2}} + \end{array} + + 其中 :math:`J_{0}` 是第一类零阶Bessel函数。 + + 参数: + - **x** (Tensor) - Tensor的输入。数据类型应为float16,float32或float64。 + + 返回: + Tensor,shape和数据类型与 `x` 相同。 + + 异常: + - **TypeError** - `x` 不是Tensor。 + - **TypeError** - `x` 的数据类型不是float16,float32或float64。 diff --git a/docs/api/api_python/ops/mindspore.ops.func_bessel_i0e.rst b/docs/api/api_python/ops/mindspore.ops.func_bessel_i0e.rst old mode 100755 new mode 100644 index 8882a37bfbf..50d84a3f586 --- a/docs/api/api_python/ops/mindspore.ops.func_bessel_i0e.rst +++ b/docs/api/api_python/ops/mindspore.ops.func_bessel_i0e.rst @@ -1,26 +1,26 @@ -mindspore.ops.bessel_i0e -======================== - -.. py:function:: mindspore.ops.bessel_i0e(x) - - 逐元素计算指数缩放第一类零阶修正贝塞尔函数。 - - 计算公式定义如下: - - .. math:: - \begin{array}{ll} \\ - \text I_{0}e(x)=e^{(-|x|)} * I_{0}(x)=e^{(-|x|)} * \sum_{m=0}^ - {\infty} \frac{x^{2 m}}{2^{2 m} (m !)^{2}} - \end{array} - - 其中 :math:`I_{0}` 是第一类零阶修正Bessel函数。 - - 参数: - - **x** (Tensor) - Tensor的输入。数据类型应为float16,float32或float64。 - - 返回: - Tensor,shape和数据类型与 `x` 相同。 - - 异常: - - **TypeError** - `x` 不是Tensor。 - - **TypeError** - `x` 的数据类型不是float16,float32或float64。 +mindspore.ops.bessel_i0e +======================== + +.. py:function:: mindspore.ops.bessel_i0e(x) + + 逐元素计算指数缩放第一类零阶修正贝塞尔函数。 + + 计算公式定义如下: + + .. math:: + \begin{array}{ll} \\ + \text I_{0}e(x)=e^{(-|x|)} * I_{0}(x)=e^{(-|x|)} * \sum_{m=0}^ + {\infty} \frac{x^{2 m}}{2^{2 m} (m !)^{2}} + \end{array} + + 其中 :math:`I_{0}` 是第一类零阶修正Bessel函数。 + + 参数: + - **x** (Tensor) - Tensor的输入。数据类型应为float16,float32或float64。 + + 返回: + Tensor,shape和数据类型与 `x` 相同。 + + 异常: + - **TypeError** - `x` 不是Tensor。 + - **TypeError** - `x` 的数据类型不是float16,float32或float64。 diff --git a/docs/api/api_python/ops/mindspore.ops.func_bessel_j0.rst b/docs/api/api_python/ops/mindspore.ops.func_bessel_j0.rst old mode 100755 new mode 100644 index f0f1748b3ff..1a6d4e25742 --- a/docs/api/api_python/ops/mindspore.ops.func_bessel_j0.rst +++ b/docs/api/api_python/ops/mindspore.ops.func_bessel_j0.rst @@ -1,24 +1,24 @@ -mindspore.ops.bessel_j0 -======================= - -.. py:function:: mindspore.ops.bessel_j0(x) - - 逐元素计算输入数据的第一类零阶的Bessel函数。 - - 计算公式定义如下: - - .. math:: - \begin{array}{ll} \\ - J_{0}(x) = \frac{1}{\pi} \int_{0}^{\pi} \cos (x \sin \theta) d \theta - =\sum_{m=0}^{\infty} \frac{(-1)^{m} x^{2 m}}{2^{2 m} (m !)^2} - \end{array} - - 参数: - - **x** (Tensor) - 输入Tensor。数据类型应为float16,float32或float64。 - - 返回: - Tensor,shape和数据类型与 `x` 相同。 - - 异常: - - **TypeError** - `x` 不是Tensor。 - - **TypeError** - `x` 的数据类型不是float16,float32或float64。 +mindspore.ops.bessel_j0 +======================= + +.. py:function:: mindspore.ops.bessel_j0(x) + + 逐元素计算输入数据的第一类零阶的Bessel函数。 + + 计算公式定义如下: + + .. math:: + \begin{array}{ll} \\ + J_{0}(x) = \frac{1}{\pi} \int_{0}^{\pi} \cos (x \sin \theta) d \theta + =\sum_{m=0}^{\infty} \frac{(-1)^{m} x^{2 m}}{2^{2 m} (m !)^2} + \end{array} + + 参数: + - **x** (Tensor) - 输入Tensor。数据类型应为float16,float32或float64。 + + 返回: + Tensor,shape和数据类型与 `x` 相同。 + + 异常: + - **TypeError** - `x` 不是Tensor。 + - **TypeError** - `x` 的数据类型不是float16,float32或float64。 diff --git a/docs/api/api_python/ops/mindspore.ops.func_bessel_j1.rst b/docs/api/api_python/ops/mindspore.ops.func_bessel_j1.rst old mode 100755 new mode 100644 index 9b6c5287e46..fb3b9475a78 --- a/docs/api/api_python/ops/mindspore.ops.func_bessel_j1.rst +++ b/docs/api/api_python/ops/mindspore.ops.func_bessel_j1.rst @@ -1,24 +1,24 @@ -mindspore.ops.bessel_j1 -======================= - -.. py:function:: mindspore.ops.bessel_j1(x) - - 逐元素计算输入数据的第一类一阶的Bessel函数。 - - 计算公式定义如下: - - .. math:: - \begin{array}{ll} \\ - J_{1}(x) = \frac{1}{\pi} \int_{0}^{\pi} \cos (x \sin \theta- \theta) d \theta - =\sum_{m=0}^{\infty} \frac{(-1)^{m} x^{2 m+1}}{2^{2 m+1} m !(m+1) !} - \end{array} - - 参数: - - **x** (Tensor) - Tensor的输入。数据类型应为float16,float32或float64。 - - 返回: - Tensor,shape和数据类型与 `x` 相同。 - - 异常: - - **TypeError** - `x` 不是Tensor。 - - **TypeError** - `x` 的数据类型不是float16,float32或float64。 +mindspore.ops.bessel_j1 +======================= + +.. py:function:: mindspore.ops.bessel_j1(x) + + 逐元素计算输入数据的第一类一阶的Bessel函数。 + + 计算公式定义如下: + + .. math:: + \begin{array}{ll} \\ + J_{1}(x) = \frac{1}{\pi} \int_{0}^{\pi} \cos (x \sin \theta- \theta) d \theta + =\sum_{m=0}^{\infty} \frac{(-1)^{m} x^{2 m+1}}{2^{2 m+1} m !(m+1) !} + \end{array} + + 参数: + - **x** (Tensor) - Tensor的输入。数据类型应为float16,float32或float64。 + + 返回: + Tensor,shape和数据类型与 `x` 相同。 + + 异常: + - **TypeError** - `x` 不是Tensor。 + - **TypeError** - `x` 的数据类型不是float16,float32或float64。 diff --git a/docs/api/api_python/ops/mindspore.ops.func_bessel_k0.rst b/docs/api/api_python/ops/mindspore.ops.func_bessel_k0.rst old mode 100755 new mode 100644 index df78e49471b..8a54a091d01 --- a/docs/api/api_python/ops/mindspore.ops.func_bessel_k0.rst +++ b/docs/api/api_python/ops/mindspore.ops.func_bessel_k0.rst @@ -1,26 +1,26 @@ -mindspore.ops.bessel_k0 -======================= - -.. py:function:: mindspore.ops.bessel_k0(x) - - 逐元素计算第二类零阶修正Bessel函数值。 - - 计算公式定义如下: - - .. math:: - \begin{array}{ll} \\ - K_{0}(x)= \lim_{\nu \to 0} \left(\frac{\pi}{2}\right) \frac - {I_{-\nu}(x)-I_{\nu}(x)}{\sin (\nu \pi)} = \int_{0}^{\infty} e^{-x \cosh t} d t - \end{array} - - 其中 :math:`I_{0}` 是第一类零阶修正Bessel函数。 - - 参数: - - **x** (Tensor) - 输入Tensor。数据类型应为float16,float32或float64。 - - 返回: - Tensor,shape和数据类型与 `x` 相同。 - - 异常: - - **TypeError** - `x` 不是Tensor。 - - **TypeError** - `x` 的数据类型不是float16,float32或float64。 +mindspore.ops.bessel_k0 +======================= + +.. py:function:: mindspore.ops.bessel_k0(x) + + 逐元素计算第二类零阶修正Bessel函数值。 + + 计算公式定义如下: + + .. math:: + \begin{array}{ll} \\ + K_{0}(x)= \lim_{\nu \to 0} \left(\frac{\pi}{2}\right) \frac + {I_{-\nu}(x)-I_{\nu}(x)}{\sin (\nu \pi)} = \int_{0}^{\infty} e^{-x \cosh t} d t + \end{array} + + 其中 :math:`I_{0}` 是第一类零阶修正Bessel函数。 + + 参数: + - **x** (Tensor) - 输入Tensor。数据类型应为float16,float32或float64。 + + 返回: + Tensor,shape和数据类型与 `x` 相同。 + + 异常: + - **TypeError** - `x` 不是Tensor。 + - **TypeError** - `x` 的数据类型不是float16,float32或float64。 diff --git a/docs/api/api_python/ops/mindspore.ops.func_bessel_k0e.rst b/docs/api/api_python/ops/mindspore.ops.func_bessel_k0e.rst old mode 100755 new mode 100644 index ac1ac8a8ebc..74e8d01ff08 --- a/docs/api/api_python/ops/mindspore.ops.func_bessel_k0e.rst +++ b/docs/api/api_python/ops/mindspore.ops.func_bessel_k0e.rst @@ -1,26 +1,26 @@ -mindspore.ops.bessel_k0e -======================== - -.. py:function:: mindspore.ops.bessel_k0e(x) - - 逐元素计算指数缩放第二类零阶修正Bessel函数值。 - - 计算公式定义如下: - - .. math:: - \begin{array}{ll} \\ - K_{0}e(x)= e^{(-|x|)} * K_{0}(x) = e^{(-|x|)} * \int_{0}^ - {\infty} e^{-x \cosh t} d t - \end{array} - - 其中 :math:`K_{0}` 是第二类零阶修正Bessel函数。 - - 参数: - - **x** (Tensor) - Tensor的输入。数据类型应为float16,float32或float64。 - - 返回: - Tensor,shape和数据类型与 `x` 相同。 - - 异常: - - **TypeError** - `x` 不是Tensor。 - - **TypeError** - `x` 的数据类型不是float16,float32或float64。 +mindspore.ops.bessel_k0e +======================== + +.. py:function:: mindspore.ops.bessel_k0e(x) + + 逐元素计算指数缩放第二类零阶修正Bessel函数值。 + + 计算公式定义如下: + + .. math:: + \begin{array}{ll} \\ + K_{0}e(x)= e^{(-|x|)} * K_{0}(x) = e^{(-|x|)} * \int_{0}^ + {\infty} e^{-x \cosh t} d t + \end{array} + + 其中 :math:`K_{0}` 是第二类零阶修正Bessel函数。 + + 参数: + - **x** (Tensor) - Tensor的输入。数据类型应为float16,float32或float64。 + + 返回: + Tensor,shape和数据类型与 `x` 相同。 + + 异常: + - **TypeError** - `x` 不是Tensor。 + - **TypeError** - `x` 的数据类型不是float16,float32或float64。 diff --git a/docs/api/api_python/ops/mindspore.ops.func_bincount.rst b/docs/api/api_python/ops/mindspore.ops.func_bincount.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_bucketize.rst b/docs/api/api_python/ops/mindspore.ops.func_bucketize.rst index 347c876eaf2..f01f6517a1b 100644 --- a/docs/api/api_python/ops/mindspore.ops.func_bucketize.rst +++ b/docs/api/api_python/ops/mindspore.ops.func_bucketize.rst @@ -1,27 +1,27 @@ -mindspore.ops.bucketize -========================== - -.. py:function:: mindspore.ops.bucketize(input, boundaries, *, right=False) - - 根据 `boundaries` 对 `input` 进行分桶。如果 `right` 为 ``False``,则左边界关闭,对于 `input` 中的每个元素 x,返回的索引满足以下规则: - - .. math:: - - \begin{cases} - boundaries[i-1] < x <= boundaries[i], & \text{if right} = False\\ - boundaries[i-1] <= x < boundaries[i], & \text{if right} = True - \end{cases} - - 参数: - - **input** (Tensor) - 输入的Tensor。 - - **boundaries** (list) - 表示桶的边界值的有序列表。 - - 关键字参数: - - **right** (bool, 可选) - 如果为 ``False``,则从边界获取输入中每个值的下限索引;如果为 ``True``,则改为获取上限索引。默认值:``False``。 - - 返回: - Tensor,返回的索引值,shape与输入Tensor的shape相同,数据类型为int32。 - - 异常: - - **TypeError** - `boundaries` 不是list。 - - **TypeError** - `input` 不是Tensor。 +mindspore.ops.bucketize +========================== + +.. py:function:: mindspore.ops.bucketize(input, boundaries, *, right=False) + + 根据 `boundaries` 对 `input` 进行分桶。如果 `right` 为 ``False``,则左边界关闭,对于 `input` 中的每个元素 x,返回的索引满足以下规则: + + .. math:: + + \begin{cases} + boundaries[i-1] < x <= boundaries[i], & \text{if right} = False\\ + boundaries[i-1] <= x < boundaries[i], & \text{if right} = True + \end{cases} + + 参数: + - **input** (Tensor) - 输入的Tensor。 + - **boundaries** (list) - 表示桶的边界值的有序列表。 + + 关键字参数: + - **right** (bool, 可选) - 如果为 ``False``,则从边界获取输入中每个值的下限索引;如果为 ``True``,则改为获取上限索引。默认值:``False``。 + + 返回: + Tensor,返回的索引值,shape与输入Tensor的shape相同,数据类型为int32。 + + 异常: + - **TypeError** - `boundaries` 不是list。 + - **TypeError** - `input` 不是Tensor。 diff --git a/docs/api/api_python/ops/mindspore.ops.func_col2im.rst b/docs/api/api_python/ops/mindspore.ops.func_col2im.rst index f3566f1be8e..2e2e488644b 100644 --- a/docs/api/api_python/ops/mindspore.ops.func_col2im.rst +++ b/docs/api/api_python/ops/mindspore.ops.func_col2im.rst @@ -1,24 +1,24 @@ -mindspore.ops.col2im -==================== - -.. py:function:: mindspore.ops.col2im(input_x, output_size, kernel_size, dilation, padding_value, stride) - - 将一组滑动局部块组合成一个大的Tensor。 - - 参数: - - **input_x** (Tensor) - 四维Tensor,输入的批量的滑动局部块,数据类型支持float16和float32。 - - **output_size** (Tensor) - 包含两个int元素的一维Tensor,输出张量的后两维的shape。 - - **kernel_size** (Union[int, tuple[int], list[int]]) - 滑动窗口的大小。tuple的两个元素分别对应kernel的高度与宽度。如果为一个int则kernel的高度与宽度均为该值。 - - **dilation** (Union[int, tuple[int], list[int]]) - 滑动窗口扩张的大小。 - - **padding_value** (Union[int, tuple[int], list[int]]) - 填充的大小。 - - **stride** (Union[int, tuple[int], list[int]]) - 步长的大小。 - - 返回: - Tensor,输出的张量,维度和类型和输入一致。 - - 异常: - - **TypeError** - 如果 `kernel_size`,`dilation`,`padding_value`,`stride` 不属于 Union[int, tuple[int], list[int]]。 - - **ValueError** - 如果 `kernel_size`,`dilation`,`stride` 值小于等于0或者个数大于2。 - - **ValueError** - 如果 `padding_value` 值小于0或者个数大于2。 - - **ValueError** - 如果 `input_x.dims(2)` 不等于 `kernel_size[0] * kernel_size[1]` 。 - - **ValueError** - 如果 `input_x.dims(3)` 与计算出的滑动块数量不匹配。 +mindspore.ops.col2im +==================== + +.. py:function:: mindspore.ops.col2im(input_x, output_size, kernel_size, dilation, padding_value, stride) + + 将一组滑动局部块组合成一个大的Tensor。 + + 参数: + - **input_x** (Tensor) - 四维Tensor,输入的批量的滑动局部块,数据类型支持float16和float32。 + - **output_size** (Tensor) - 包含两个int元素的一维Tensor,输出张量的后两维的shape。 + - **kernel_size** (Union[int, tuple[int], list[int]]) - 滑动窗口的大小。tuple的两个元素分别对应kernel的高度与宽度。如果为一个int则kernel的高度与宽度均为该值。 + - **dilation** (Union[int, tuple[int], list[int]]) - 滑动窗口扩张的大小。 + - **padding_value** (Union[int, tuple[int], list[int]]) - 填充的大小。 + - **stride** (Union[int, tuple[int], list[int]]) - 步长的大小。 + + 返回: + Tensor,输出的张量,维度和类型和输入一致。 + + 异常: + - **TypeError** - 如果 `kernel_size`,`dilation`,`padding_value`,`stride` 不属于 Union[int, tuple[int], list[int]]。 + - **ValueError** - 如果 `kernel_size`,`dilation`,`stride` 值小于等于0或者个数大于2。 + - **ValueError** - 如果 `padding_value` 值小于0或者个数大于2。 + - **ValueError** - 如果 `input_x.dims(2)` 不等于 `kernel_size[0] * kernel_size[1]` 。 + - **ValueError** - 如果 `input_x.dims(3)` 与计算出的滑动块数量不匹配。 diff --git a/docs/api/api_python/ops/mindspore.ops.func_coo_abs.rst b/docs/api/api_python/ops/mindspore.ops.func_coo_abs.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_coo_acos.rst b/docs/api/api_python/ops/mindspore.ops.func_coo_acos.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_coo_acosh.rst b/docs/api/api_python/ops/mindspore.ops.func_coo_acosh.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_coo_asin.rst b/docs/api/api_python/ops/mindspore.ops.func_coo_asin.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_coo_asinh.rst b/docs/api/api_python/ops/mindspore.ops.func_coo_asinh.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_coo_atan.rst b/docs/api/api_python/ops/mindspore.ops.func_coo_atan.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_coo_atanh.rst b/docs/api/api_python/ops/mindspore.ops.func_coo_atanh.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_coo_ceil.rst b/docs/api/api_python/ops/mindspore.ops.func_coo_ceil.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_coo_cos.rst b/docs/api/api_python/ops/mindspore.ops.func_coo_cos.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_coo_cosh.rst b/docs/api/api_python/ops/mindspore.ops.func_coo_cosh.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_coo_exp.rst b/docs/api/api_python/ops/mindspore.ops.func_coo_exp.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_coo_expm1.rst b/docs/api/api_python/ops/mindspore.ops.func_coo_expm1.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_coo_floor.rst b/docs/api/api_python/ops/mindspore.ops.func_coo_floor.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_coo_inv.rst b/docs/api/api_python/ops/mindspore.ops.func_coo_inv.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_coo_isfinite.rst b/docs/api/api_python/ops/mindspore.ops.func_coo_isfinite.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_coo_isnan.rst b/docs/api/api_python/ops/mindspore.ops.func_coo_isnan.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_coo_log.rst b/docs/api/api_python/ops/mindspore.ops.func_coo_log.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_coo_log1p.rst b/docs/api/api_python/ops/mindspore.ops.func_coo_log1p.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_coo_neg.rst b/docs/api/api_python/ops/mindspore.ops.func_coo_neg.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_coo_relu.rst b/docs/api/api_python/ops/mindspore.ops.func_coo_relu.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_coo_relu6.rst b/docs/api/api_python/ops/mindspore.ops.func_coo_relu6.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_coo_round.rst b/docs/api/api_python/ops/mindspore.ops.func_coo_round.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_coo_sigmoid.rst b/docs/api/api_python/ops/mindspore.ops.func_coo_sigmoid.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_coo_sin.rst b/docs/api/api_python/ops/mindspore.ops.func_coo_sin.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_coo_sinh.rst b/docs/api/api_python/ops/mindspore.ops.func_coo_sinh.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_coo_softsign.rst b/docs/api/api_python/ops/mindspore.ops.func_coo_softsign.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_coo_sqrt.rst b/docs/api/api_python/ops/mindspore.ops.func_coo_sqrt.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_coo_square.rst b/docs/api/api_python/ops/mindspore.ops.func_coo_square.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_coo_tan.rst b/docs/api/api_python/ops/mindspore.ops.func_coo_tan.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_coo_tanh.rst b/docs/api/api_python/ops/mindspore.ops.func_coo_tanh.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_csr_abs.rst b/docs/api/api_python/ops/mindspore.ops.func_csr_abs.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_csr_acos.rst b/docs/api/api_python/ops/mindspore.ops.func_csr_acos.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_csr_acosh.rst b/docs/api/api_python/ops/mindspore.ops.func_csr_acosh.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_csr_asin.rst b/docs/api/api_python/ops/mindspore.ops.func_csr_asin.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_csr_asinh.rst b/docs/api/api_python/ops/mindspore.ops.func_csr_asinh.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_csr_atan.rst b/docs/api/api_python/ops/mindspore.ops.func_csr_atan.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_csr_atanh.rst b/docs/api/api_python/ops/mindspore.ops.func_csr_atanh.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_csr_ceil.rst b/docs/api/api_python/ops/mindspore.ops.func_csr_ceil.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_csr_cos.rst b/docs/api/api_python/ops/mindspore.ops.func_csr_cos.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_csr_cosh.rst b/docs/api/api_python/ops/mindspore.ops.func_csr_cosh.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_csr_exp.rst b/docs/api/api_python/ops/mindspore.ops.func_csr_exp.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_csr_expm1.rst b/docs/api/api_python/ops/mindspore.ops.func_csr_expm1.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_csr_floor.rst b/docs/api/api_python/ops/mindspore.ops.func_csr_floor.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_csr_inv.rst b/docs/api/api_python/ops/mindspore.ops.func_csr_inv.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_csr_isfinite.rst b/docs/api/api_python/ops/mindspore.ops.func_csr_isfinite.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_csr_isnan.rst b/docs/api/api_python/ops/mindspore.ops.func_csr_isnan.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_csr_log.rst b/docs/api/api_python/ops/mindspore.ops.func_csr_log.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_csr_log1p.rst b/docs/api/api_python/ops/mindspore.ops.func_csr_log1p.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_csr_neg.rst b/docs/api/api_python/ops/mindspore.ops.func_csr_neg.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_csr_relu.rst b/docs/api/api_python/ops/mindspore.ops.func_csr_relu.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_csr_relu6.rst b/docs/api/api_python/ops/mindspore.ops.func_csr_relu6.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_csr_round.rst b/docs/api/api_python/ops/mindspore.ops.func_csr_round.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_csr_sigmoid.rst b/docs/api/api_python/ops/mindspore.ops.func_csr_sigmoid.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_csr_sin.rst b/docs/api/api_python/ops/mindspore.ops.func_csr_sin.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_csr_sinh.rst b/docs/api/api_python/ops/mindspore.ops.func_csr_sinh.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_csr_softsign.rst b/docs/api/api_python/ops/mindspore.ops.func_csr_softsign.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_csr_sqrt.rst b/docs/api/api_python/ops/mindspore.ops.func_csr_sqrt.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_csr_square.rst b/docs/api/api_python/ops/mindspore.ops.func_csr_square.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_csr_tan.rst b/docs/api/api_python/ops/mindspore.ops.func_csr_tan.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_csr_tanh.rst b/docs/api/api_python/ops/mindspore.ops.func_csr_tanh.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_ctc_greedy_decoder.rst b/docs/api/api_python/ops/mindspore.ops.func_ctc_greedy_decoder.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_grid_sample.rst b/docs/api/api_python/ops/mindspore.ops.func_grid_sample.rst old mode 100755 new mode 100644 index 42908e79666..c9e1a9de954 --- a/docs/api/api_python/ops/mindspore.ops.func_grid_sample.rst +++ b/docs/api/api_python/ops/mindspore.ops.func_grid_sample.rst @@ -1,42 +1,42 @@ -mindspore.ops.grid_sample -========================= - -.. py:function:: mindspore.ops.grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=False) - - 给定一个输入和一个网格,使用网格中的输入值和像素位置计算输出。`input` 只支持4-D(GridSampler2D)和5-D(GridSampler3D)。 - - 在4-D场景下,`input` 的shape为 :math:`(N, C, H_{in}, W_{in})`,`grid` 的shape为 :math:`(N, H_{out}, W_{out}, 2)`,`output` 的shape为 :math:`(N, C, H_{out}, W_{out})`。 - 对于每个输出位置 `output[n, :, h, w]`,`grid[n, h, w]` 指定 `input` 像素位置 `x` 和 `y`,用于计算 `output[n, :, h, w]` 的插值。以5D为例,`grid[n, d, h, w]` 指定 `x`, - `y`,`z` 像素位置的插值位置为[n, :, d, h, w]。`mode` 参数指定 `nearest` 或 `bilinear` (bicubic暂不支持)插值法对输入像素进行采样。 - - `grid` 指定由 `input` 归一化的采样像素位置。因此,它应该在 :math:`[-1, 1]` 范围内的值最多。 - - 如果 `grid` 的值在 :math:`[-1, 1]` 范围之外,则相应的输出将按照定义的 `padding_mode` 方式处理。如果 `padding_mode` 设置为 ``0`` ,则使用 :math:`0` 来表示出界的网格位置。 - 如果 `padding_mode` 设置为 ``border``,对于出界网格位置,则使用border值。如果 `padding_mode` 设置为 ``reflection`` ,请使用边界所反映的位置的值用于指定出界网格位置。对于 - 远离边界的位置,它会一直被反射,直到在边界内。 - - 参数: - - **input** (Tensor) - 4-D场景下,shape为 :math:`(N, C, H_{in}, W_{in})`,5-D场景下,shape为 :math:`(N, C, D_{in}, H_{in}, W_{in})`。数据类型为float32或float64。 - - **grid** (Tensor) - 4-D场景下,shape为 :math:`(N, H_{out}, W_{out}, 2)`,5-D场景下,shape为 :math:`(N, D_{out}, H_{out}, W_{out}, 3)`。数据类型与 `input` 保持一致。 - - **mode** (str) - 插值方法。可选方法为 ``'bilinear'``, ``'nearest'``。默认值: ``'bilinear'`` 。注: ``'bilinear'`` 还不支持。当 `mode` 为 ``'bilinear'``,且输入为5-D,则 `mode` 为 ``'trilinear'``。但是,当输入为4-D,则 `mode` 为 ``'bilinear'``。默认值: ``'bilinear'`` 。 - - - ``'nearest'``:最近邻插值。每个输出像素的值为最近的输入像素的值。这种方法简单快速,但可能导致块状或像素化的输出。 - - ``'bilinear'``:双线性插值。每个输出像素是最接近的四个输入像素的加权平均值,使用双线性插值计算。与最近邻插值相比,此方法产生更平滑的结果。 - - ``'trilinear'``:三线性插值。这是双线性插值在三维数据上的扩展。它在两个空间维度上执行双线性插值,并沿第三个维度进行线性插值。通常用于体积或三维图像插值。 - - - **padding_mode** (str) - 填充方法。可选方法为 ``'zeros'``,``'border'`` 和 ``'reflection'``。默认值: ``'zeros'`` 。 - - **align_corners** (bool) - 如果设置成 `True`,-1和1被视为引用输入角像素的中心点。如果设置为 `False`,将被视为引用到输入角像素的角点,使采样更不受分辨率影响。默认值为 `False`。 - - 返回: - Tensor,数据类型与 `input` 相同,4-D场景下,shape为 :math:`(N, C, H_{out}, W_{out})`,5-D场景下,shape为 :math:`(N, C, D_{out}, H_{out}, W_{out})`。 - - 异常: - - **TypeError** - 如果 `input` 或 `grid` 不是Tensor类型。 - - **TypeError** - 如果 `input` 和 `grid` 的数据类型不一致。 - - **TypeError** - 如果 `input` 或 `grid` 的数据类型无效。 - - **TypeError** - 如果 `align_corners` 不是一个布尔值。 - - **ValueError** - 如果 `input` 或 `grid` 的维度不是四维或五维。 - - **ValueError** - 如果 `input` 的第一个维度不等于 `grid` 的第一个维度。 - - **ValueError** - 如果 `grid` 最后一个维度不等于2(4-D场景)或者3(5-D场景)。 - - **ValueError** - 如果 `mode` 不是 `bilinear`,`nearest`,数据类型不为String。 - - **ValueError** - 如果 `padding_mode` 不是 `zeros`,`border`,`reflection`,数据类型不为String。 +mindspore.ops.grid_sample +========================= + +.. py:function:: mindspore.ops.grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=False) + + 给定一个输入和一个网格,使用网格中的输入值和像素位置计算输出。`input` 只支持4-D(GridSampler2D)和5-D(GridSampler3D)。 + + 在4-D场景下,`input` 的shape为 :math:`(N, C, H_{in}, W_{in})`,`grid` 的shape为 :math:`(N, H_{out}, W_{out}, 2)`,`output` 的shape为 :math:`(N, C, H_{out}, W_{out})`。 + 对于每个输出位置 `output[n, :, h, w]`,`grid[n, h, w]` 指定 `input` 像素位置 `x` 和 `y`,用于计算 `output[n, :, h, w]` 的插值。以5D为例,`grid[n, d, h, w]` 指定 `x`, + `y`,`z` 像素位置的插值位置为[n, :, d, h, w]。`mode` 参数指定 `nearest` 或 `bilinear` (bicubic暂不支持)插值法对输入像素进行采样。 + + `grid` 指定由 `input` 归一化的采样像素位置。因此,它应该在 :math:`[-1, 1]` 范围内的值最多。 + + 如果 `grid` 的值在 :math:`[-1, 1]` 范围之外,则相应的输出将按照定义的 `padding_mode` 方式处理。如果 `padding_mode` 设置为 ``0`` ,则使用 :math:`0` 来表示出界的网格位置。 + 如果 `padding_mode` 设置为 ``border``,对于出界网格位置,则使用border值。如果 `padding_mode` 设置为 ``reflection`` ,请使用边界所反映的位置的值用于指定出界网格位置。对于 + 远离边界的位置,它会一直被反射,直到在边界内。 + + 参数: + - **input** (Tensor) - 4-D场景下,shape为 :math:`(N, C, H_{in}, W_{in})`,5-D场景下,shape为 :math:`(N, C, D_{in}, H_{in}, W_{in})`。数据类型为float32或float64。 + - **grid** (Tensor) - 4-D场景下,shape为 :math:`(N, H_{out}, W_{out}, 2)`,5-D场景下,shape为 :math:`(N, D_{out}, H_{out}, W_{out}, 3)`。数据类型与 `input` 保持一致。 + - **mode** (str) - 插值方法。可选方法为 ``'bilinear'``, ``'nearest'``。默认值: ``'bilinear'`` 。注: ``'bilinear'`` 还不支持。当 `mode` 为 ``'bilinear'``,且输入为5-D,则 `mode` 为 ``'trilinear'``。但是,当输入为4-D,则 `mode` 为 ``'bilinear'``。默认值: ``'bilinear'`` 。 + + - ``'nearest'``:最近邻插值。每个输出像素的值为最近的输入像素的值。这种方法简单快速,但可能导致块状或像素化的输出。 + - ``'bilinear'``:双线性插值。每个输出像素是最接近的四个输入像素的加权平均值,使用双线性插值计算。与最近邻插值相比,此方法产生更平滑的结果。 + - ``'trilinear'``:三线性插值。这是双线性插值在三维数据上的扩展。它在两个空间维度上执行双线性插值,并沿第三个维度进行线性插值。通常用于体积或三维图像插值。 + + - **padding_mode** (str) - 填充方法。可选方法为 ``'zeros'``,``'border'`` 和 ``'reflection'``。默认值: ``'zeros'`` 。 + - **align_corners** (bool) - 如果设置成 `True`,-1和1被视为引用输入角像素的中心点。如果设置为 `False`,将被视为引用到输入角像素的角点,使采样更不受分辨率影响。默认值为 `False`。 + + 返回: + Tensor,数据类型与 `input` 相同,4-D场景下,shape为 :math:`(N, C, H_{out}, W_{out})`,5-D场景下,shape为 :math:`(N, C, D_{out}, H_{out}, W_{out})`。 + + 异常: + - **TypeError** - 如果 `input` 或 `grid` 不是Tensor类型。 + - **TypeError** - 如果 `input` 和 `grid` 的数据类型不一致。 + - **TypeError** - 如果 `input` 或 `grid` 的数据类型无效。 + - **TypeError** - 如果 `align_corners` 不是一个布尔值。 + - **ValueError** - 如果 `input` 或 `grid` 的维度不是四维或五维。 + - **ValueError** - 如果 `input` 的第一个维度不等于 `grid` 的第一个维度。 + - **ValueError** - 如果 `grid` 最后一个维度不等于2(4-D场景)或者3(5-D场景)。 + - **ValueError** - 如果 `mode` 不是 `bilinear`,`nearest`,数据类型不为String。 + - **ValueError** - 如果 `padding_mode` 不是 `zeros`,`border`,`reflection`,数据类型不为String。 diff --git a/docs/api/api_python/ops/mindspore.ops.func_hamming_window.rst b/docs/api/api_python/ops/mindspore.ops.func_hamming_window.rst old mode 100755 new mode 100644 diff --git a/docs/api/api_python/ops/mindspore.ops.func_i0.rst b/docs/api/api_python/ops/mindspore.ops.func_i0.rst old mode 100755 new mode 100644 index 171a602117a..a1e2b181999 --- a/docs/api/api_python/ops/mindspore.ops.func_i0.rst +++ b/docs/api/api_python/ops/mindspore.ops.func_i0.rst @@ -1,6 +1,6 @@ -mindspore.ops.i0 -================= - -.. py:function:: mindspore.ops.i0(input) - - :func:`mindspore.ops.bessel_i0` 的别名。 +mindspore.ops.i0 +================= + +.. py:function:: mindspore.ops.i0(input) + + :func:`mindspore.ops.bessel_i0` 的别名。 diff --git a/docs/api/api_python/ops/mindspore.ops.func_population_count.rst b/docs/api/api_python/ops/mindspore.ops.func_population_count.rst old mode 100755 new mode 100644 index 887a651c632..fe5c7b4f426 --- a/docs/api/api_python/ops/mindspore.ops.func_population_count.rst +++ b/docs/api/api_python/ops/mindspore.ops.func_population_count.rst @@ -1,18 +1,18 @@ -mindspore.ops.population_count -============================== - -.. py:function:: mindspore.ops.population_count(input_x) - - 逐元素计算population count(又称bitsum, bitcount)。 - 对于 `input_x` 中的每个entry,计算该entry的二进制表示中的1比特的数量。 - - 参数: - - **input_x** (Tensor) - 任意维度的Tensor。Ascend平台支持的数据类型为int16、uint16,CPU和GPU平台支持的数据类型为int8、int16、int32、int64、uint8、uint16、uint32、uint64。 - - 返回: - Tensor,shape与 `input_x` 相同,数据类型为uint8。 - - 异常: - - **TypeError** - `input_x` 不是Tensor。 - - **TypeError** - `input_x` 的数据类型不是int16或uint16(Ascend平台)。 - - **TypeError** - `input_x` 的数据类型不是int8、int16、int32、int64、uint8、uint16、uint32、uint64(CPU和GPU平台)。 +mindspore.ops.population_count +============================== + +.. py:function:: mindspore.ops.population_count(input_x) + + 逐元素计算population count(又称bitsum, bitcount)。 + 对于 `input_x` 中的每个entry,计算该entry的二进制表示中的1比特的数量。 + + 参数: + - **input_x** (Tensor) - 任意维度的Tensor。Ascend平台支持的数据类型为int16、uint16,CPU和GPU平台支持的数据类型为int8、int16、int32、int64、uint8、uint16、uint32、uint64。 + + 返回: + Tensor,shape与 `input_x` 相同,数据类型为uint8。 + + 异常: + - **TypeError** - `input_x` 不是Tensor。 + - **TypeError** - `input_x` 的数据类型不是int16或uint16(Ascend平台)。 + - **TypeError** - `input_x` 的数据类型不是int8、int16、int32、int64、uint8、uint16、uint32、uint64(CPU和GPU平台)。 diff --git a/docs/api/api_python/ops/mindspore.ops.func_trunc.rst b/docs/api/api_python/ops/mindspore.ops.func_trunc.rst old mode 100755 new mode 100644 index a1551e55876..a40fc3facc2 --- a/docs/api/api_python/ops/mindspore.ops.func_trunc.rst +++ b/docs/api/api_python/ops/mindspore.ops.func_trunc.rst @@ -1,15 +1,15 @@ -mindspore.ops.trunc -=================== - -.. py:function:: mindspore.ops.trunc(input) - - 返回一个新的Tensor,该Tensor具有输入元素的截断整数值。 - - 参数: - - **input** (Tensor) - 任意维度的Tensor。 - - 返回: - Tensor,shape和数据类型与 `input` 相同。 - - 异常: - - **TypeError** - `input` 不是Tensor。 +mindspore.ops.trunc +=================== + +.. py:function:: mindspore.ops.trunc(input) + + 返回一个新的Tensor,该Tensor具有输入元素的截断整数值。 + + 参数: + - **input** (Tensor) - 任意维度的Tensor。 + + 返回: + Tensor,shape和数据类型与 `input` 相同。 + + 异常: + - **TypeError** - `input` 不是Tensor。 diff --git a/docs/api/api_python/scipy/mindspore.scipy.linalg.block_diag.rst b/docs/api/api_python/scipy/mindspore.scipy.linalg.block_diag.rst index d21a5e75dca..fdca4ccb351 100644 --- a/docs/api/api_python/scipy/mindspore.scipy.linalg.block_diag.rst +++ b/docs/api/api_python/scipy/mindspore.scipy.linalg.block_diag.rst @@ -1,27 +1,27 @@ -mindspore.scipy.linalg.block_diag -================================= - -.. py:function:: mindspore.scipy.linalg.block_diag(*arrs) - - 根据输入的数组创建块对角矩阵。 - - 输入为:`A`、`B` 和 `C` 的Tensor列表。输出为:在对角线上排列这些Tensor的块对角矩阵。 - - .. code-block:: - - [[A, 0, 0], - [0, B, 0], - [0, 0, C]] - - .. note:: - Windows平台上还不支持 `block_diag`。 - - 参数: - - **arrs** (list) - 最大支持2D的Tensor输入。 - 一个或多个Tensor,维度支持0D,1D、2D。 - - 返回: - 对角线上含有 `A`、`B`、`C`,...的Tensor,数据类型与 `A` 相同。 - - 异常: - - **ValueError** - 输入参数中存在维度大于2的Tensor。 +mindspore.scipy.linalg.block_diag +================================= + +.. py:function:: mindspore.scipy.linalg.block_diag(*arrs) + + 根据输入的数组创建块对角矩阵。 + + 输入为:`A`、`B` 和 `C` 的Tensor列表。输出为:在对角线上排列这些Tensor的块对角矩阵。 + + .. code-block:: + + [[A, 0, 0], + [0, B, 0], + [0, 0, C]] + + .. note:: + Windows平台上还不支持 `block_diag`。 + + 参数: + - **arrs** (list) - 最大支持2D的Tensor输入。 + 一个或多个Tensor,维度支持0D,1D、2D。 + + 返回: + 对角线上含有 `A`、`B`、`C`,...的Tensor,数据类型与 `A` 相同。 + + 异常: + - **ValueError** - 输入参数中存在维度大于2的Tensor。 diff --git a/docs/api/api_python/scipy/mindspore.scipy.linalg.cho_factor.rst b/docs/api/api_python/scipy/mindspore.scipy.linalg.cho_factor.rst index d10eda39358..acd7b0bbea1 100644 --- a/docs/api/api_python/scipy/mindspore.scipy.linalg.cho_factor.rst +++ b/docs/api/api_python/scipy/mindspore.scipy.linalg.cho_factor.rst @@ -1,44 +1,44 @@ -mindspore.scipy.linalg.cho_factor -================================= - -.. py:function:: mindspore.scipy.linalg.cho_factor(a, lower=False, overwrite_a=False, check_finite=True) - - 计算矩阵的cholesky分解,用于 :func:`mindspore.scipy.linalg.cho_solve`。 - - 返回包含cholesky分解的矩阵,对于一个Hermitian正定矩阵 `A`,根据 `lower` 取值,进行如下形式的分解: - - - `lower` 为True: :math:`A = L L^*` - - `lower` 为False: :math:`A = U^* U` - - 其中, :math:`L^*` 为 :math:`L` 的共轭转置矩阵。 - 其中, :math:`U^*` 为 :math:`U` 的共轭转置矩阵。 - - 返回值可以直接作为 :func:`mindspore.scipy.linalg.cho_solve` 的第一个参数使用。 - - .. note:: - - Windows平台上还不支持 `cho_factor`。 - - 仅支持float32、float64、int32、int64类型的Tensor类型。 - - 如果Tensor是int32、int64类型,它将被强制转换为:mstype.float64类型。 - - .. warning:: - 返回的矩阵中还包含cholesky分解不使用的条目中的随机数据。如果需要将这些条目清零,请改用 :func:`mindspore.scipy.linalg.cholesky` 函数。 - - 参数: - - **a** (Tensor) - 要分解的 :math:`(M,M)` 方阵。 - - **lower** (bool, 可选) - 是计算上三角还是下三角的cholesky分解。 - 默认值:``False``。 - - **overwrite_a** (bool, 可选) - 是否覆盖参数 `a` 中的数据(可能会提高性能)。 - 默认值:``False``。 - 在MindSpore中,这个参数当前不起作用。 - - **check_finite** (bool, 可选) - 是否检查输入矩阵是否只包含有限数。 - 禁用可能会带来性能增益,但如果输入确实包含INF或NaN,则可能会导致问题(崩溃、程序不终止)。 - 默认值:``True``。 - 在MindSpore中,当前这个参数不起作用。 - - 返回: - - **c** (Tensor) - 在上三角或下三角中包含 `a` 的cholesky因子的矩阵。 - 矩阵的其他部分包含随机数据。 - - **lower** (bool) - 表示cholesky因子是在下三角形还是上三角形。 - - 异常: - - **ValueError** - 如果输入的Tensor不是2D方阵。 +mindspore.scipy.linalg.cho_factor +================================= + +.. py:function:: mindspore.scipy.linalg.cho_factor(a, lower=False, overwrite_a=False, check_finite=True) + + 计算矩阵的cholesky分解,用于 :func:`mindspore.scipy.linalg.cho_solve`。 + + 返回包含cholesky分解的矩阵,对于一个Hermitian正定矩阵 `A`,根据 `lower` 取值,进行如下形式的分解: + + - `lower` 为True: :math:`A = L L^*` + - `lower` 为False: :math:`A = U^* U` + + 其中, :math:`L^*` 为 :math:`L` 的共轭转置矩阵。 + 其中, :math:`U^*` 为 :math:`U` 的共轭转置矩阵。 + + 返回值可以直接作为 :func:`mindspore.scipy.linalg.cho_solve` 的第一个参数使用。 + + .. note:: + - Windows平台上还不支持 `cho_factor`。 + - 仅支持float32、float64、int32、int64类型的Tensor类型。 + - 如果Tensor是int32、int64类型,它将被强制转换为:mstype.float64类型。 + + .. warning:: + 返回的矩阵中还包含cholesky分解不使用的条目中的随机数据。如果需要将这些条目清零,请改用 :func:`mindspore.scipy.linalg.cholesky` 函数。 + + 参数: + - **a** (Tensor) - 要分解的 :math:`(M,M)` 方阵。 + - **lower** (bool, 可选) - 是计算上三角还是下三角的cholesky分解。 + 默认值:``False``。 + - **overwrite_a** (bool, 可选) - 是否覆盖参数 `a` 中的数据(可能会提高性能)。 + 默认值:``False``。 + 在MindSpore中,这个参数当前不起作用。 + - **check_finite** (bool, 可选) - 是否检查输入矩阵是否只包含有限数。 + 禁用可能会带来性能增益,但如果输入确实包含INF或NaN,则可能会导致问题(崩溃、程序不终止)。 + 默认值:``True``。 + 在MindSpore中,当前这个参数不起作用。 + + 返回: + - **c** (Tensor) - 在上三角或下三角中包含 `a` 的cholesky因子的矩阵。 + 矩阵的其他部分包含随机数据。 + - **lower** (bool) - 表示cholesky因子是在下三角形还是上三角形。 + + 异常: + - **ValueError** - 如果输入的Tensor不是2D方阵。 diff --git a/docs/api/api_python/scipy/mindspore.scipy.linalg.cho_solve.rst b/docs/api/api_python/scipy/mindspore.scipy.linalg.cho_solve.rst index 62709382881..2d8aa06a357 100644 --- a/docs/api/api_python/scipy/mindspore.scipy.linalg.cho_solve.rst +++ b/docs/api/api_python/scipy/mindspore.scipy.linalg.cho_solve.rst @@ -1,25 +1,25 @@ -mindspore.scipy.linalg.cho_solve -================================ - -.. py:function:: mindspore.scipy.linalg.cho_solve(c_and_lower, b, overwrite_b=False, check_finite=True) - - 给定 :math:`A` 的cholesky分解,求解线性方程组。 - - .. math:: - A x = b - - .. note:: - - 仅支持float32、float64、int32、int64类型的Tensor类型。 - - 如果Tensor是int32、int64类型,它将被强制转换为:mstype.float64类型。 - - 参数: - - **c_and_lower** ((Tensor, bool)) - :math:`a` 的cholesky分解,由::func:`mindspore.scipy.linalg.cho_factor` 计算得出。 - - **b** (Tensor) - 方程右侧的值。 - - **overwrite_b** (bool, 可选) - 是否覆盖::math:`b` 中的数据(可能会提高性能)。 - 默认值:``False``。 - - **check_finite** (bool, 可选) - 是否检查输入矩阵是否只包含有限数。 - 禁用可能会带来性能增益,但如果输入确实包含INF或NaN,则可能会导致问题(崩溃、程序不终止)。 - 默认值:``True``。 - - 返回: - Tensor,线性方程 :math:`A x = b` 的解。 +mindspore.scipy.linalg.cho_solve +================================ + +.. py:function:: mindspore.scipy.linalg.cho_solve(c_and_lower, b, overwrite_b=False, check_finite=True) + + 给定 :math:`A` 的cholesky分解,求解线性方程组。 + + .. math:: + A x = b + + .. note:: + - 仅支持float32、float64、int32、int64类型的Tensor类型。 + - 如果Tensor是int32、int64类型,它将被强制转换为:mstype.float64类型。 + + 参数: + - **c_and_lower** ((Tensor, bool)) - :math:`a` 的cholesky分解,由::func:`mindspore.scipy.linalg.cho_factor` 计算得出。 + - **b** (Tensor) - 方程右侧的值。 + - **overwrite_b** (bool, 可选) - 是否覆盖::math:`b` 中的数据(可能会提高性能)。 + 默认值:``False``。 + - **check_finite** (bool, 可选) - 是否检查输入矩阵是否只包含有限数。 + 禁用可能会带来性能增益,但如果输入确实包含INF或NaN,则可能会导致问题(崩溃、程序不终止)。 + 默认值:``True``。 + + 返回: + Tensor,线性方程 :math:`A x = b` 的解。 diff --git a/docs/api/api_python/scipy/mindspore.scipy.linalg.cholesky.rst b/docs/api/api_python/scipy/mindspore.scipy.linalg.cholesky.rst index ef8ef400f38..2a069f544a2 100644 --- a/docs/api/api_python/scipy/mindspore.scipy.linalg.cholesky.rst +++ b/docs/api/api_python/scipy/mindspore.scipy.linalg.cholesky.rst @@ -1,37 +1,37 @@ -mindspore.scipy.linalg.cholesky -=============================== - -.. py:function:: mindspore.scipy.linalg.cholesky(a, lower=False, overwrite_a=False, check_finite=True) - - 计算矩阵的cholesky分解。 - - 返回包含cholesky分解的矩阵,对于一个Hermitian正定矩阵 `A`,根据 `lower` 取值,进行如下形式的分解: - - - `lower` 为True: :math:`A = L L^*` - - `lower` 为False: :math:`A = U^* U` - - 其中, :math:`L^*` 为 :math:`L` 的共轭转置矩阵。 - 其中, :math:`U^*` 为 :math:`U` 的共轭转置矩阵。 - - .. note:: - - Windows平台上还不支持 `cholesky`。 - - 仅支持float32、float64、int32、int64类型的Tensor类型。 - - 如果Tensor是int32、int64类型,它将被强制转换为:mstype.float64类型。 - - 参数: - - **a** (Tensor) - 要分解的 :math:`(M, M)` 方阵。 - - **lower** (bool, 可选) - 是计算上三角还是下三角的cholesky分解。 - 默认值:``False``。 - - **overwrite_a** (bool, 可选) - 是否覆盖参数 `a` 中的数据(可能会提高性能)。 - 默认值:``False``。 - 在MindSpore中,这个参数当前不起作用。 - - **check_finite** (bool, 可选) - 是否检查输入矩阵是否只包含有限数。 - 禁用可能会带来性能增益,但如果输入确实包含INF或NaN,则可能会导致问题(崩溃、程序不终止)。 - 默认值:``True``。 - 在MindSpore中,当前这个参数不起作用。 - - 返回: - Tensor,`a` 的上三角或下三角cholesky因子。 - - 异常: - - **ValueError** - 如果输入的Tensor不是2D方阵。 +mindspore.scipy.linalg.cholesky +=============================== + +.. py:function:: mindspore.scipy.linalg.cholesky(a, lower=False, overwrite_a=False, check_finite=True) + + 计算矩阵的cholesky分解。 + + 返回包含cholesky分解的矩阵,对于一个Hermitian正定矩阵 `A`,根据 `lower` 取值,进行如下形式的分解: + + - `lower` 为True: :math:`A = L L^*` + - `lower` 为False: :math:`A = U^* U` + + 其中, :math:`L^*` 为 :math:`L` 的共轭转置矩阵。 + 其中, :math:`U^*` 为 :math:`U` 的共轭转置矩阵。 + + .. note:: + - Windows平台上还不支持 `cholesky`。 + - 仅支持float32、float64、int32、int64类型的Tensor类型。 + - 如果Tensor是int32、int64类型,它将被强制转换为:mstype.float64类型。 + + 参数: + - **a** (Tensor) - 要分解的 :math:`(M, M)` 方阵。 + - **lower** (bool, 可选) - 是计算上三角还是下三角的cholesky分解。 + 默认值:``False``。 + - **overwrite_a** (bool, 可选) - 是否覆盖参数 `a` 中的数据(可能会提高性能)。 + 默认值:``False``。 + 在MindSpore中,这个参数当前不起作用。 + - **check_finite** (bool, 可选) - 是否检查输入矩阵是否只包含有限数。 + 禁用可能会带来性能增益,但如果输入确实包含INF或NaN,则可能会导致问题(崩溃、程序不终止)。 + 默认值:``True``。 + 在MindSpore中,当前这个参数不起作用。 + + 返回: + Tensor,`a` 的上三角或下三角cholesky因子。 + + 异常: + - **ValueError** - 如果输入的Tensor不是2D方阵。 diff --git a/docs/api/api_python/scipy/mindspore.scipy.linalg.eigh.rst b/docs/api/api_python/scipy/mindspore.scipy.linalg.eigh.rst index 58a47b150d3..e9c1dd77ca6 100644 --- a/docs/api/api_python/scipy/mindspore.scipy.linalg.eigh.rst +++ b/docs/api/api_python/scipy/mindspore.scipy.linalg.eigh.rst @@ -1,71 +1,71 @@ -mindspore.scipy.linalg.eigh -=========================== - -.. py:function:: mindspore.scipy.linalg.eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False, overwrite_b=False, turbo=True, eigvals=None, type=1, check_finite=True) - - 求解复Hermitian矩阵或实对称矩阵的标准或广义特征值问题。 - - 求出 `a` 的特征值Tensor `w` 和可选的特征值Tensor `v`,其中 `b` 是正定的,使得对于每个特征值 `λ` ( `w` 的第i个条目)及其特征向量 `vi` ( `v` 的第i列)满足: - - .. code-block:: - - a @ vi = λ * b @ vi - vi.conj().T @ a @ vi = λ - vi.conj().T @ b @ vi = 1 - - 在标准问题中,假设 `b` 是单位矩阵。 - - .. note:: - - Windows平台上还不支持 `eigh`。 - - 如果Tensor是int32、int64类型,它将被强制转换为:mstype.float64类型。 - - 参数: - - **a** (Tensor) - 一个shape 为 :math:`(M,M)` 的复Hermitian矩阵或实对称矩阵,用于计算其特征值和特征向量。 - - **b** (Tensor, 可选) - 一个shape为 :math:`(M,M)` 的复Hermitian矩阵或实对称正矩阵。 - 如果缺省,则假定为传入单位矩阵。 - 默认值:``None``。 - - **lower** (bool, 可选) - 控制相关的Tensor数据是取自 `a` 和 `b` 的下三角还是上三角。 - 默认值:``True``。 - - **eigvals_only** (bool, 可选) - 是否只计算特征值,不计算特征向量。 - 默认值:``False``。 - - **overwrite_a** (bool, 可选) - 是否覆盖 `a` 中的数据(可能会提高性能)。 - 默认值:``False``。 - - **overwrite_b** (bool, 可选) - 是否覆盖 `b` 中的数据(可能会提高性能)。 - 默认值:``False``。 - - **turbo** (bool, 可选) - 使用分而治之算法(速度更快,但占用大量内存,仅适用于需要计算全量特征值的广义特征值问题)。 - 如果不需要计算特征向量,则没有显著影响。 - 默认值:``True``。 - - **eigvals** (tuple, 可选) - 要返回的最小和最大(按升序排列)特征值和对应的特征向量的索引: :math:`0 <= lo <= hi <= M-1`。 - 如果缺省,则返回所有特征值和特征向量。 - 默认值:``None``。 - - **type** (int, 可选) - 对于广义问题,此参数指定 `w` 和 `v` 要解决的问题类型(仅取1、2、3作为可能的输入): - - .. code-block:: - - 1 => a @ v = w @ b @ v - 2 => a @ b @ v = w @ v - 3 => b @ a @ v = w @ v - - 对于标准问题,会忽略此关键字。默认值:``1``。 - - **check_finite** (bool, 可选) - 是否检查输入矩阵是否只包含有限数。 - 禁用可能会带来性能增益,但如果输入确实包含INF或NaN,则可能会导致问题(崩溃、程序不终止)。 - 默认值:``True``。 - - 返回: - - **w** (Tensor) - 返回shape为 :math:`(N,)` 的Tensor,其中特征值 :math:`N (1<=N<=M)`,按升序排列,根据其多样性重复。 - - - **v** (Tensor) - 如果 `eigvals_only==False`,返回shape为 :math:`(M, N)` 的Tensor。 - - 异常: - - **RuntimeError** - 如果特征值计算不收敛或 `b` 不是正定矩阵,则会触发报错。 - 如果输入矩阵不是对称矩阵或Hermitian矩阵,则不会报告错误,但结果将是错误的。 - - **TypeError** - 如果 `a` 不是Tensor。 - - **TypeError** - 如果 `low` 不是bool类型。 - - **TypeError** - 如果 `eigvals_only` 不是bool类型。 - - **TypeError** - 如果 `overwrite_a` 不是bool类型。 - - **TypeError** - 如果 `overwrite_b` 不是bool类型。 - - **TypeError** - 如果 `turbo` 不是bool类型。 - - **TypeError** - 如果 `check_finite` 不是bool类型。 - - **ValueError** - 如果 `a` 不是2D方阵。 - - **ValueError** - 如果 `b` 不为None。 - - **ValueError** - 如果 `eigvals` 不是None。 +mindspore.scipy.linalg.eigh +=========================== + +.. py:function:: mindspore.scipy.linalg.eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False, overwrite_b=False, turbo=True, eigvals=None, type=1, check_finite=True) + + 求解复Hermitian矩阵或实对称矩阵的标准或广义特征值问题。 + + 求出 `a` 的特征值Tensor `w` 和可选的特征值Tensor `v`,其中 `b` 是正定的,使得对于每个特征值 `λ` ( `w` 的第i个条目)及其特征向量 `vi` ( `v` 的第i列)满足: + + .. code-block:: + + a @ vi = λ * b @ vi + vi.conj().T @ a @ vi = λ + vi.conj().T @ b @ vi = 1 + + 在标准问题中,假设 `b` 是单位矩阵。 + + .. note:: + - Windows平台上还不支持 `eigh`。 + - 如果Tensor是int32、int64类型,它将被强制转换为:mstype.float64类型。 + + 参数: + - **a** (Tensor) - 一个shape 为 :math:`(M,M)` 的复Hermitian矩阵或实对称矩阵,用于计算其特征值和特征向量。 + - **b** (Tensor, 可选) - 一个shape为 :math:`(M,M)` 的复Hermitian矩阵或实对称正矩阵。 + 如果缺省,则假定为传入单位矩阵。 + 默认值:``None``。 + - **lower** (bool, 可选) - 控制相关的Tensor数据是取自 `a` 和 `b` 的下三角还是上三角。 + 默认值:``True``。 + - **eigvals_only** (bool, 可选) - 是否只计算特征值,不计算特征向量。 + 默认值:``False``。 + - **overwrite_a** (bool, 可选) - 是否覆盖 `a` 中的数据(可能会提高性能)。 + 默认值:``False``。 + - **overwrite_b** (bool, 可选) - 是否覆盖 `b` 中的数据(可能会提高性能)。 + 默认值:``False``。 + - **turbo** (bool, 可选) - 使用分而治之算法(速度更快,但占用大量内存,仅适用于需要计算全量特征值的广义特征值问题)。 + 如果不需要计算特征向量,则没有显著影响。 + 默认值:``True``。 + - **eigvals** (tuple, 可选) - 要返回的最小和最大(按升序排列)特征值和对应的特征向量的索引: :math:`0 <= lo <= hi <= M-1`。 + 如果缺省,则返回所有特征值和特征向量。 + 默认值:``None``。 + - **type** (int, 可选) - 对于广义问题,此参数指定 `w` 和 `v` 要解决的问题类型(仅取1、2、3作为可能的输入): + + .. code-block:: + + 1 => a @ v = w @ b @ v + 2 => a @ b @ v = w @ v + 3 => b @ a @ v = w @ v + + 对于标准问题,会忽略此关键字。默认值:``1``。 + - **check_finite** (bool, 可选) - 是否检查输入矩阵是否只包含有限数。 + 禁用可能会带来性能增益,但如果输入确实包含INF或NaN,则可能会导致问题(崩溃、程序不终止)。 + 默认值:``True``。 + + 返回: + - **w** (Tensor) - 返回shape为 :math:`(N,)` 的Tensor,其中特征值 :math:`N (1<=N<=M)`,按升序排列,根据其多样性重复。 + + - **v** (Tensor) - 如果 `eigvals_only==False`,返回shape为 :math:`(M, N)` 的Tensor。 + + 异常: + - **RuntimeError** - 如果特征值计算不收敛或 `b` 不是正定矩阵,则会触发报错。 + 如果输入矩阵不是对称矩阵或Hermitian矩阵,则不会报告错误,但结果将是错误的。 + - **TypeError** - 如果 `a` 不是Tensor。 + - **TypeError** - 如果 `low` 不是bool类型。 + - **TypeError** - 如果 `eigvals_only` 不是bool类型。 + - **TypeError** - 如果 `overwrite_a` 不是bool类型。 + - **TypeError** - 如果 `overwrite_b` 不是bool类型。 + - **TypeError** - 如果 `turbo` 不是bool类型。 + - **TypeError** - 如果 `check_finite` 不是bool类型。 + - **ValueError** - 如果 `a` 不是2D方阵。 + - **ValueError** - 如果 `b` 不为None。 + - **ValueError** - 如果 `eigvals` 不是None。 diff --git a/docs/api/api_python/scipy/mindspore.scipy.linalg.inv.rst b/docs/api/api_python/scipy/mindspore.scipy.linalg.inv.rst index be5ee96dcb3..86c29524912 100644 --- a/docs/api/api_python/scipy/mindspore.scipy.linalg.inv.rst +++ b/docs/api/api_python/scipy/mindspore.scipy.linalg.inv.rst @@ -1,26 +1,26 @@ -mindspore.scipy.linalg.inv -========================== - -.. py:function:: mindspore.scipy.linalg.inv(a, overwrite_a=False, check_finite=True) - - 计算矩阵的逆。 - - .. note:: - - Windows平台上还不支持 `inv`。 - - 仅支持float32、float64、int32、int64类型的Tensor类型。 - - 如果Tensor是int32、int64类型,它将被强制转换为:mstype.float64类型。 - - 参数: - - **a** (Tensor) - 要求逆的方阵。 - - **overwrite_a** (bool, 可选) - 是否覆盖参数 `a` 中的数据(可能会提高性能)。 - 默认值:``False``。 - - **check_finite** (bool, 可选) - 是否检查输入矩阵是否只包含有限数。 - 禁用可能会带来性能增益,但如果输入确实包含INF或NaN,则可能会导致问题(崩溃、程序不终止)。 - 默认值:``True``。 - - 返回: - Tensor,矩阵 `a` 的逆。 - - 异常: - - **LinAlgError** - 如果 :math:`a` 是单数。 - - **ValueError** - 如果 :math:`a` 不是2D方阵。 +mindspore.scipy.linalg.inv +========================== + +.. py:function:: mindspore.scipy.linalg.inv(a, overwrite_a=False, check_finite=True) + + 计算矩阵的逆。 + + .. note:: + - Windows平台上还不支持 `inv`。 + - 仅支持float32、float64、int32、int64类型的Tensor类型。 + - 如果Tensor是int32、int64类型,它将被强制转换为:mstype.float64类型。 + + 参数: + - **a** (Tensor) - 要求逆的方阵。 + - **overwrite_a** (bool, 可选) - 是否覆盖参数 `a` 中的数据(可能会提高性能)。 + 默认值:``False``。 + - **check_finite** (bool, 可选) - 是否检查输入矩阵是否只包含有限数。 + 禁用可能会带来性能增益,但如果输入确实包含INF或NaN,则可能会导致问题(崩溃、程序不终止)。 + 默认值:``True``。 + + 返回: + Tensor,矩阵 `a` 的逆。 + + 异常: + - **LinAlgError** - 如果 :math:`a` 是单数。 + - **ValueError** - 如果 :math:`a` 不是2D方阵。 diff --git a/docs/api/api_python/scipy/mindspore.scipy.linalg.lu.rst b/docs/api/api_python/scipy/mindspore.scipy.linalg.lu.rst index 0d8e7c011b1..c8ef3c4cc74 100644 --- a/docs/api/api_python/scipy/mindspore.scipy.linalg.lu.rst +++ b/docs/api/api_python/scipy/mindspore.scipy.linalg.lu.rst @@ -1,41 +1,41 @@ -mindspore.scipy.linalg.lu -========================= - -.. py:function:: mindspore.scipy.linalg.lu(a, permute_l=False, overwrite_a=False, check_finite=True) - - 计算通用矩阵的LU分解。 - - 分解为: - - .. math:: - A = P L U - - 其中, :math:`P` 是一个置换矩阵, :math:`L` 是对角线元素全为1的下三角矩阵, :math:`U` 是上三角矩阵。 - - .. note:: - - Windows平台上还不支持 `LU`。 - - 仅支持float32、float64、int32、int64类型的Tensor类型。 - - 如果Tensor是int32、int64类型,它将被强制转换为:mstype.float64类型。 - - 参数: - - **a** (Tensor) - 要分解的 :math:`(M, N)` 方阵。 - 如果输入Tensor不是float类型,那么它将被强制转换为:mstype.float32。 - - **permute_l** (bool, 可选) - 执行乘法运算 :math:`P L`(默认:不进行置换)。 - 默认值:``False``。 - - **overwrite_a** (bool, 可选) - 是否覆盖 :math:`a` 中的数据(可能会提高性能)。 - 默认值:``False``。 - - **check_finite** (bool, 可选) - 是否检查输入矩阵是否只包含有限数。 - 禁用可能会带来性能增益,但如果输入确实包含INF或NaN,则可能会导致问题(崩溃、程序不终止)。 - 默认值:``True``。 - - 返回: - **如果 permute_l == False** - - - **p** (Tensor) - :math:`(M, M)` 置换矩阵。 - - **l** (Tensor) - :math:`(M, K)` 对角线元素全为1的下三角矩阵或梯形矩阵。 :math:`K = min(M, N)`。 - - **u** (Tensor) - :math:`(K, N)` 上三角矩阵或梯形矩阵。 - - **如果 permute_l == True** - - - **pl** (Tensor) - :math:`(M, K)` 置换L矩阵。 :math:`K = min(M,N)`。 - - **u** (Tensor) - :math:`(K, N)` 上三角矩阵或梯形矩阵。 +mindspore.scipy.linalg.lu +========================= + +.. py:function:: mindspore.scipy.linalg.lu(a, permute_l=False, overwrite_a=False, check_finite=True) + + 计算通用矩阵的LU分解。 + + 分解为: + + .. math:: + A = P L U + + 其中, :math:`P` 是一个置换矩阵, :math:`L` 是对角线元素全为1的下三角矩阵, :math:`U` 是上三角矩阵。 + + .. note:: + - Windows平台上还不支持 `LU`。 + - 仅支持float32、float64、int32、int64类型的Tensor类型。 + - 如果Tensor是int32、int64类型,它将被强制转换为:mstype.float64类型。 + + 参数: + - **a** (Tensor) - 要分解的 :math:`(M, N)` 方阵。 + 如果输入Tensor不是float类型,那么它将被强制转换为:mstype.float32。 + - **permute_l** (bool, 可选) - 执行乘法运算 :math:`P L`(默认:不进行置换)。 + 默认值:``False``。 + - **overwrite_a** (bool, 可选) - 是否覆盖 :math:`a` 中的数据(可能会提高性能)。 + 默认值:``False``。 + - **check_finite** (bool, 可选) - 是否检查输入矩阵是否只包含有限数。 + 禁用可能会带来性能增益,但如果输入确实包含INF或NaN,则可能会导致问题(崩溃、程序不终止)。 + 默认值:``True``。 + + 返回: + **如果 permute_l == False** + + - **p** (Tensor) - :math:`(M, M)` 置换矩阵。 + - **l** (Tensor) - :math:`(M, K)` 对角线元素全为1的下三角矩阵或梯形矩阵。 :math:`K = min(M, N)`。 + - **u** (Tensor) - :math:`(K, N)` 上三角矩阵或梯形矩阵。 + + **如果 permute_l == True** + + - **pl** (Tensor) - :math:`(M, K)` 置换L矩阵。 :math:`K = min(M,N)`。 + - **u** (Tensor) - :math:`(K, N)` 上三角矩阵或梯形矩阵。 diff --git a/docs/api/api_python/scipy/mindspore.scipy.linalg.lu_factor.rst b/docs/api/api_python/scipy/mindspore.scipy.linalg.lu_factor.rst index 738315bb332..8f4aae86e7e 100644 --- a/docs/api/api_python/scipy/mindspore.scipy.linalg.lu_factor.rst +++ b/docs/api/api_python/scipy/mindspore.scipy.linalg.lu_factor.rst @@ -1,35 +1,35 @@ -mindspore.scipy.linalg.lu_factor -================================ - -.. py:function:: mindspore.scipy.linalg.lu_factor(a, overwrite_a=False, check_finite=True) - - 计算方阵的LU分解,其输出可以直接作为 `lu_solve` 的输入。 - - 分解为: - - .. math:: - a = P L U - - 其中, :math:`P` 是一个置换矩阵, :math:`L` 是对角线元素全为1的下三角矩阵, :math:`U` 是上三角矩阵。 - - .. note:: - - Windows平台上还不支持 `lu_factor`。 - - 仅支持float32、float64、int32、int64的Tensor类型。 - - 如果Tensor是int32、int64类型,它将被强制转换为:mstype.float64类型。 - - 参数: - - **a** (Tensor) - 要分解的 :math:`(M, M)` 方阵。 - 如果输入Tensor不是float类型,那么它将被强制转换为:mstype.float32。 - - **overwrite_a** (bool, 可选) - 是否覆盖 :math:`a` 中的数据(可能会提高性能)。 - 默认值:``False``。 - - **check_finite** (bool, 可选) - 是否检查输入矩阵是否只包含有限数。 - 禁用可能会带来性能增益,但如果输入确实包含 `INF` 或 `NaN`,则可能会导致问题(崩溃、程序不终止)。 - 默认值:``True``。 - - 返回: - - **lu** (Tensor) - 一个 :math:`(M, M)` 的方阵,在它的上三角中包含 `u`,它的下三角形中包含 `l`。 - 不含 `l` 中对角线全为1的元素。 - - **piv** (Tensor) - shape为 :math:`(M,)` 的Tensor,表示置换矩阵 `p` 的索引:索引中的第 `i` 个元素值 `j` 表示矩阵的第 `i` 行与第 `j` 行互换。 - - 异常: - - **ValueError** - 如果 :math:`a` 不是2D方阵。 +mindspore.scipy.linalg.lu_factor +================================ + +.. py:function:: mindspore.scipy.linalg.lu_factor(a, overwrite_a=False, check_finite=True) + + 计算方阵的LU分解,其输出可以直接作为 `lu_solve` 的输入。 + + 分解为: + + .. math:: + a = P L U + + 其中, :math:`P` 是一个置换矩阵, :math:`L` 是对角线元素全为1的下三角矩阵, :math:`U` 是上三角矩阵。 + + .. note:: + - Windows平台上还不支持 `lu_factor`。 + - 仅支持float32、float64、int32、int64的Tensor类型。 + - 如果Tensor是int32、int64类型,它将被强制转换为:mstype.float64类型。 + + 参数: + - **a** (Tensor) - 要分解的 :math:`(M, M)` 方阵。 + 如果输入Tensor不是float类型,那么它将被强制转换为:mstype.float32。 + - **overwrite_a** (bool, 可选) - 是否覆盖 :math:`a` 中的数据(可能会提高性能)。 + 默认值:``False``。 + - **check_finite** (bool, 可选) - 是否检查输入矩阵是否只包含有限数。 + 禁用可能会带来性能增益,但如果输入确实包含 `INF` 或 `NaN`,则可能会导致问题(崩溃、程序不终止)。 + 默认值:``True``。 + + 返回: + - **lu** (Tensor) - 一个 :math:`(M, M)` 的方阵,在它的上三角中包含 `u`,它的下三角形中包含 `l`。 + 不含 `l` 中对角线全为1的元素。 + - **piv** (Tensor) - shape为 :math:`(M,)` 的Tensor,表示置换矩阵 `p` 的索引:索引中的第 `i` 个元素值 `j` 表示矩阵的第 `i` 行与第 `j` 行互换。 + + 异常: + - **ValueError** - 如果 :math:`a` 不是2D方阵。 diff --git a/docs/api/api_python/scipy/mindspore.scipy.optimize.line_search.rst b/docs/api/api_python/scipy/mindspore.scipy.optimize.line_search.rst index 67cf41fad51..780f6aa2e62 100644 --- a/docs/api/api_python/scipy/mindspore.scipy.optimize.line_search.rst +++ b/docs/api/api_python/scipy/mindspore.scipy.optimize.line_search.rst @@ -1,28 +1,28 @@ -mindspore.scipy.optimize.line_search -==================================== - -.. py:function:: mindspore.scipy.optimize.line_search(f, xk, pk, jac=None, gfk=None, old_fval=None, old_old_fval=None, c1=0.0001, c2=0.9, maxiter=20) - - 满足强Wolfe条件的非精确线搜索。 - - 来自Wright和Nocedal,'Numerical Optimization',1999,第59-61页,算法3.5章节。 - - .. note:: - Windows平台上还不支持 `line_search`。 - - 参数: - - **f** (function) - 形式为f(x)的函数,其中x是一个扁平Tensor,并返回一个实数标量。 - 该函数应该由 `vjp` 定义的算子组成。 - - **xk** (Tensor) - 初始猜测。 - - **pk** (Tensor) - 要搜索的方向。假定方向是下降方向。 - - **jac** (function) - 求x处的梯度的函数,其中x是一个扁平的Tensor,函数返回一个Tensor。 - 如果要使用自动微分,则可以传 ``None``。 - - **gfk** (Tensor) - `value_and_gradient` 作为位置的初始值。默认值:``None``。 - - **old_fval** (Tensor) - 同 `gfk`。默认值:``None``。 - - **old_old_fval** (Tensor) - 未使用的参数,仅用于scipy API合规性。默认值:``None``。 - - **c1** (float) - Wolfe准则常量,参见ref。默认值:``1e-4``。 - - **c2** (float) - 与 `c1` 相同。默认值:``0.9``。 - - **maxiter** (int) - 搜索的最大迭代次数。默认值:``20``。 - - 返回: - 线搜索的结果。 +mindspore.scipy.optimize.line_search +==================================== + +.. py:function:: mindspore.scipy.optimize.line_search(f, xk, pk, jac=None, gfk=None, old_fval=None, old_old_fval=None, c1=0.0001, c2=0.9, maxiter=20) + + 满足强Wolfe条件的非精确线搜索。 + + 来自Wright和Nocedal,'Numerical Optimization',1999,第59-61页,算法3.5章节。 + + .. note:: + Windows平台上还不支持 `line_search`。 + + 参数: + - **f** (function) - 形式为f(x)的函数,其中x是一个扁平Tensor,并返回一个实数标量。 + 该函数应该由 `vjp` 定义的算子组成。 + - **xk** (Tensor) - 初始猜测。 + - **pk** (Tensor) - 要搜索的方向。假定方向是下降方向。 + - **jac** (function) - 求x处的梯度的函数,其中x是一个扁平的Tensor,函数返回一个Tensor。 + 如果要使用自动微分,则可以传 ``None``。 + - **gfk** (Tensor) - `value_and_gradient` 作为位置的初始值。默认值:``None``。 + - **old_fval** (Tensor) - 同 `gfk`。默认值:``None``。 + - **old_old_fval** (Tensor) - 未使用的参数,仅用于scipy API合规性。默认值:``None``。 + - **c1** (float) - Wolfe准则常量,参见ref。默认值:``1e-4``。 + - **c2** (float) - 与 `c1` 相同。默认值:``0.9``。 + - **maxiter** (int) - 搜索的最大迭代次数。默认值:``20``。 + + 返回: + 线搜索的结果。 diff --git a/docs/api/api_python/scipy/mindspore.scipy.optimize.minimize.rst b/docs/api/api_python/scipy/mindspore.scipy.optimize.minimize.rst index 05df2f88f5e..2332d7221a0 100644 --- a/docs/api/api_python/scipy/mindspore.scipy.optimize.minimize.rst +++ b/docs/api/api_python/scipy/mindspore.scipy.optimize.minimize.rst @@ -1,54 +1,54 @@ -mindspore.scipy.optimize.minimize -================================= - -.. py:function:: mindspore.scipy.optimize.minimize(func, x0, args=(), method=None, jac=None, hess=None, hessp=None, bounds=None, constraints=(), tol=None, callback=None, options=None) - - 最小化一个或多个变量的标量函数。 - - 此函数的API与SciPy匹配,但有一些细微的差异: - - 当 `jac` 为None时,会使用MindSpore的自动微分功能计算 ``func`` 的反向梯度。 - ``method`` 参数是必需的。如果不指定求解器,将触发异常。 - 尚未实现SciPy接口中的如下可选参数:`"hess"`、`"hessp"`、`"bounds"`、`"constraints"`、`"tol"`、`"callback"`。 - 由于线搜索实现的差异,优化结果可能与SciPy不同。 - - .. note:: - - `minimize` 接口当前还不支持多维Tensor输入或求微分,但有支持的计划。 - - Windows平台上还不支持 `minimize`。 - - `LAGRANGE` 方法仅在 `GPU` 上支持。 - - 参数: - - **func** (Callable) - 要最小化的目标函数 :math:`fun(x,*args) -> float`,其中 `x` 是一个一维数组,其shape为 :math:`(n,)`。 - `args` 是一个Tuple,用于指定 `func` 的执行所需的所有参数。 - 当 `jac` 为None时,`func` 必须能支持微分。 - - **x0** (Tensor) - 初始猜测。shape为 :math:`(n,)` 的实数数组,其中 `n` 是自变量的个数。 - - **args** (Tuple) - 传递给目标函数的额外参数。默认值:``()`` 。 - - **method** (str) - 求解器类型。应为 `“BFGS”` 和 `“LBFGS”`、`“LAGRANGE”` 中的一种。 - - **jac** (Callable, 可选) - 计算梯度向量的函数。 - 只支持 `"BFGS"` 和 `"LBFGS"`。如果为None,则将使用 ``func`` 的反向梯度函数进行梯度计算。 - 如果 `jac` 是可执行的,则应该是能返回梯度向量的函数::math:`jac(x, *args) -> array\_like, shape (n,)`,其中x是一个数组,其shape为 :math:`(n,)`,`args` 是一个具有固定参数的元组。 - - **hess** (Callable, 可选) - 计算Hessian矩阵的方法。当前尚未实现。 - - **hessp** (Callable, 可选) - 目标函数的Hessian乘以任意向量 `p` 。当前尚未实现。 - - **bounds** (Sequence, 可选) - `x` 中的每个元素的 `(min, max)` 对的序列。当前尚未实现。 - - **constraints** (Callable, 可选) - 表示不等式的约束,约束中的每个函数都将 `function < 0` 表示为不等式约束。 - - **tol** (float, 可选) - 异常终止的容差范围。如需更具体的操控,请使用求解器里专门的选项。默认值:``None``。 - - **callback** (Callable, 可选) - 每次迭代后调用的可执行函数。当前尚未实现。 - - **options** (Mapping[str, Any], 可选) - 用于保存求解器可选项的字典。所有求解器方法都能支持下述通用选项。默认值:``None``。 - - - ``"history_size"`` (int) - 用于更新Hession矩阵的逆的缓冲区大小,仅支持与 `method="LBFGS"` 一起使用。默认值:``20``。 - - ``"maxiter"`` (int) - 要执行的最大迭代次数。根据方法的不同,每个迭代可能会使用多个函数进行求值。 - - 以下选项是拉格朗日方法的专有选项: - - - ``"save_tol"`` (list) - 保存 `tol` 的列表,长度与 `constrains` 相同。 - - ``"obj_weight"`` (float) - 目标函数的权重,通常在1.0 - 100000.0之间。 - - ``"lower"`` (Tensor) - 变量的下限约束,必须具有与 `x0` 相同的shape。 - - ``"upper"`` (Tensor) - 变量的上限约束,必须具有与 `x0` 相同的shape。 - - ``"learning_rate"`` (float) - 每个Adam步骤的学习率。 - - ``"coincide_func"`` (Callable) - 子函数,表示目标函数和约束之间的公共部分,用于避免冗余计算。 - - ``"rounds"`` (int) - 更新拉格朗日乘数的次数。 - - ``"steps"`` (int) - 每执行 `steps` 次就执行一次Adam去更新拉格朗日乘数。 - - ``"log_sw"`` (bool) - 是否打印每一步的 `loss` 值。 - - 返回: - 优化的结果。 +mindspore.scipy.optimize.minimize +================================= + +.. py:function:: mindspore.scipy.optimize.minimize(func, x0, args=(), method=None, jac=None, hess=None, hessp=None, bounds=None, constraints=(), tol=None, callback=None, options=None) + + 最小化一个或多个变量的标量函数。 + + 此函数的API与SciPy匹配,但有一些细微的差异: + + 当 `jac` 为None时,会使用MindSpore的自动微分功能计算 ``func`` 的反向梯度。 + ``method`` 参数是必需的。如果不指定求解器,将触发异常。 + 尚未实现SciPy接口中的如下可选参数:`"hess"`、`"hessp"`、`"bounds"`、`"constraints"`、`"tol"`、`"callback"`。 + 由于线搜索实现的差异,优化结果可能与SciPy不同。 + + .. note:: + - `minimize` 接口当前还不支持多维Tensor输入或求微分,但有支持的计划。 + - Windows平台上还不支持 `minimize`。 + - `LAGRANGE` 方法仅在 `GPU` 上支持。 + + 参数: + - **func** (Callable) - 要最小化的目标函数 :math:`fun(x,*args) -> float`,其中 `x` 是一个一维数组,其shape为 :math:`(n,)`。 + `args` 是一个Tuple,用于指定 `func` 的执行所需的所有参数。 + 当 `jac` 为None时,`func` 必须能支持微分。 + - **x0** (Tensor) - 初始猜测。shape为 :math:`(n,)` 的实数数组,其中 `n` 是自变量的个数。 + - **args** (Tuple) - 传递给目标函数的额外参数。默认值:``()`` 。 + - **method** (str) - 求解器类型。应为 `“BFGS”` 和 `“LBFGS”`、`“LAGRANGE”` 中的一种。 + - **jac** (Callable, 可选) - 计算梯度向量的函数。 + 只支持 `"BFGS"` 和 `"LBFGS"`。如果为None,则将使用 ``func`` 的反向梯度函数进行梯度计算。 + 如果 `jac` 是可执行的,则应该是能返回梯度向量的函数::math:`jac(x, *args) -> array\_like, shape (n,)`,其中x是一个数组,其shape为 :math:`(n,)`,`args` 是一个具有固定参数的元组。 + - **hess** (Callable, 可选) - 计算Hessian矩阵的方法。当前尚未实现。 + - **hessp** (Callable, 可选) - 目标函数的Hessian乘以任意向量 `p` 。当前尚未实现。 + - **bounds** (Sequence, 可选) - `x` 中的每个元素的 `(min, max)` 对的序列。当前尚未实现。 + - **constraints** (Callable, 可选) - 表示不等式的约束,约束中的每个函数都将 `function < 0` 表示为不等式约束。 + - **tol** (float, 可选) - 异常终止的容差范围。如需更具体的操控,请使用求解器里专门的选项。默认值:``None``。 + - **callback** (Callable, 可选) - 每次迭代后调用的可执行函数。当前尚未实现。 + - **options** (Mapping[str, Any], 可选) - 用于保存求解器可选项的字典。所有求解器方法都能支持下述通用选项。默认值:``None``。 + + - ``"history_size"`` (int) - 用于更新Hession矩阵的逆的缓冲区大小,仅支持与 `method="LBFGS"` 一起使用。默认值:``20``。 + - ``"maxiter"`` (int) - 要执行的最大迭代次数。根据方法的不同,每个迭代可能会使用多个函数进行求值。 + + 以下选项是拉格朗日方法的专有选项: + + - ``"save_tol"`` (list) - 保存 `tol` 的列表,长度与 `constrains` 相同。 + - ``"obj_weight"`` (float) - 目标函数的权重,通常在1.0 - 100000.0之间。 + - ``"lower"`` (Tensor) - 变量的下限约束,必须具有与 `x0` 相同的shape。 + - ``"upper"`` (Tensor) - 变量的上限约束,必须具有与 `x0` 相同的shape。 + - ``"learning_rate"`` (float) - 每个Adam步骤的学习率。 + - ``"coincide_func"`` (Callable) - 子函数,表示目标函数和约束之间的公共部分,用于避免冗余计算。 + - ``"rounds"`` (int) - 更新拉格朗日乘数的次数。 + - ``"steps"`` (int) - 每执行 `steps` 次就执行一次Adam去更新拉格朗日乘数。 + - ``"log_sw"`` (bool) - 是否打印每一步的 `loss` 值。 + + 返回: + 优化的结果。 diff --git a/docs/api/api_python/train/mindspore.train.BleuScore.rst b/docs/api/api_python/train/mindspore.train.BleuScore.rst index ad715c21cae..d60d0d38079 100644 --- a/docs/api/api_python/train/mindspore.train.BleuScore.rst +++ b/docs/api/api_python/train/mindspore.train.BleuScore.rst @@ -1,38 +1,38 @@ -mindspore.train.BleuScore -========================== - -.. py:class:: mindspore.train.BleuScore(n_gram=4, smooth=False) - - 计算BLEU分数。BLEU指的是具有一个或多个引用的机器翻译文本的metric。 - - 参数: - - **n_gram** (int) - 取值范围为1~4。默认值: ``4`` 。 - - **smooth** (bool) - 是否采用平滑计算的方式。默认值: ``False`` 。 - - 异常: - - **ValueError** - `n_gram` 的取值范围不在1~4之间。 - - .. py:method:: clear() - - 重置评估结果。 - - .. py:method:: eval() - - 计算BLEU分数。 - - 返回: - numpy.ndarray,numpy类型的BLEU分数。 - - 异常: - - **RuntimeError** - 调用该方法前没有先调用update方法。 - - .. py:method:: update(*inputs) - - 使用输入的内容更新内部评估结果。 - - 参数: - - **inputs** (iterator) - 输入的元组,第一个输入是机器翻译语料库列表(candidate_corpus),第二个输入是引用语料库列表(reference_corpus)。 - - 异常: - - **ValueError** - 输入参数的数量不等于2。 +mindspore.train.BleuScore +========================== + +.. py:class:: mindspore.train.BleuScore(n_gram=4, smooth=False) + + 计算BLEU分数。BLEU指的是具有一个或多个引用的机器翻译文本的metric。 + + 参数: + - **n_gram** (int) - 取值范围为1~4。默认值: ``4`` 。 + - **smooth** (bool) - 是否采用平滑计算的方式。默认值: ``False`` 。 + + 异常: + - **ValueError** - `n_gram` 的取值范围不在1~4之间。 + + .. py:method:: clear() + + 重置评估结果。 + + .. py:method:: eval() + + 计算BLEU分数。 + + 返回: + numpy.ndarray,numpy类型的BLEU分数。 + + 异常: + - **RuntimeError** - 调用该方法前没有先调用update方法。 + + .. py:method:: update(*inputs) + + 使用输入的内容更新内部评估结果。 + + 参数: + - **inputs** (iterator) - 输入的元组,第一个输入是机器翻译语料库列表(candidate_corpus),第二个输入是引用语料库列表(reference_corpus)。 + + 异常: + - **ValueError** - 输入参数的数量不等于2。 - **ValueError** - `candidate_corpus` 的长度与 `reference_corpus` 不同。 \ No newline at end of file diff --git a/docs/api/api_python/train/mindspore.train.ConfusionMatrix.rst b/docs/api/api_python/train/mindspore.train.ConfusionMatrix.rst index 7c7d63b8365..2bd2d19a0b4 100644 --- a/docs/api/api_python/train/mindspore.train.ConfusionMatrix.rst +++ b/docs/api/api_python/train/mindspore.train.ConfusionMatrix.rst @@ -1,42 +1,42 @@ -mindspore.train.ConfusionMatrix -================================ - -.. py:class:: mindspore.train.ConfusionMatrix(num_classes, normalize="no_norm", threshold=0.5) - - 计算混淆矩阵(confusion matrix),通常用于评估分类模型的性能,包括二分类和多分类场景。 - - 如果只想使用混淆矩阵,请使用该类。如果想计算"PPV"、"TPR"、"TNR"等,请使用 :class:`mindspore.train.ConfusionMatrixMetric` 类。 - - 参数: - - **num_classes** (int) - 数据集中的类别数量。 - - **normalize** (str) - 计算ConfusionMatrix的参数支持四种归一化模式,默认值: ``"no_norm"`` 。 - - - ``"no_norm"`` :不使用标准化。 - - ``"target"`` :基于目标值的标准化。 - - ``"prediction"`` :基于预测值的标准化。 - - ``"all"`` :整个矩阵的标准化。 - - - **threshold** (float) - 阈值,用于与输入Tensor进行比较。默认值: ``0.5`` 。 - - .. py:method:: clear() - - 重置评估结果。 - - .. py:method:: eval() - - 计算混淆矩阵。 - - 返回: - numpy.ndarray,计算的结果。 - - .. py:method:: update(*inputs) - - 使用y_pred和y更新内部评估结果。 - - 参数: - - ***inputs** (tuple) - 输入 `y_pred` 和 `y` 。 `y_pred` 和 `y` 是 `Tensor` 、列表或数组。 - `y_pred` 是预测值, `y` 是真实值, `y_pred` 的shape是 :math:`(N, C, ...)` 或 :math:`(N, ...)` , `y` 的shape是 :math:`(N, ...)` 。 - - 异常: - - **ValueError** - 输入参数的数量不等于2。 - - **ValueError** - 如果预测值和标签的维度不一致。 +mindspore.train.ConfusionMatrix +================================ + +.. py:class:: mindspore.train.ConfusionMatrix(num_classes, normalize="no_norm", threshold=0.5) + + 计算混淆矩阵(confusion matrix),通常用于评估分类模型的性能,包括二分类和多分类场景。 + + 如果只想使用混淆矩阵,请使用该类。如果想计算"PPV"、"TPR"、"TNR"等,请使用 :class:`mindspore.train.ConfusionMatrixMetric` 类。 + + 参数: + - **num_classes** (int) - 数据集中的类别数量。 + - **normalize** (str) - 计算ConfusionMatrix的参数支持四种归一化模式,默认值: ``"no_norm"`` 。 + + - ``"no_norm"`` :不使用标准化。 + - ``"target"`` :基于目标值的标准化。 + - ``"prediction"`` :基于预测值的标准化。 + - ``"all"`` :整个矩阵的标准化。 + + - **threshold** (float) - 阈值,用于与输入Tensor进行比较。默认值: ``0.5`` 。 + + .. py:method:: clear() + + 重置评估结果。 + + .. py:method:: eval() + + 计算混淆矩阵。 + + 返回: + numpy.ndarray,计算的结果。 + + .. py:method:: update(*inputs) + + 使用y_pred和y更新内部评估结果。 + + 参数: + - ***inputs** (tuple) - 输入 `y_pred` 和 `y` 。 `y_pred` 和 `y` 是 `Tensor` 、列表或数组。 + `y_pred` 是预测值, `y` 是真实值, `y_pred` 的shape是 :math:`(N, C, ...)` 或 :math:`(N, ...)` , `y` 的shape是 :math:`(N, ...)` 。 + + 异常: + - **ValueError** - 输入参数的数量不等于2。 + - **ValueError** - 如果预测值和标签的维度不一致。 diff --git a/docs/api/api_python/train/mindspore.train.ConfusionMatrixMetric.rst b/docs/api/api_python/train/mindspore.train.ConfusionMatrixMetric.rst index d810b72d594..98b5796e902 100644 --- a/docs/api/api_python/train/mindspore.train.ConfusionMatrixMetric.rst +++ b/docs/api/api_python/train/mindspore.train.ConfusionMatrixMetric.rst @@ -1,48 +1,48 @@ -mindspore.train.ConfusionMatrixMetric -====================================== - -.. py:class:: mindspore.train.ConfusionMatrixMetric(skip_channel=True, metric_name="sensitivity", calculation_method=False, decrease='mean') - - 计算与混淆矩阵相关的度量。 - - 该计算基于全尺度张量,并收集批处理平均值,类通道数和迭代数。 - 此函数支持计算参数metric_name中描述中列出的所有度量名称。 - - 如果要使用混淆矩阵计算,如"PPV"、"TPR"、"TNR",请使用此类。 - 如果只想计算混淆矩阵,请使用 :class:`mindspore.train.ConfusionMatrix` 。 - - 参数: - - **skip_channel** (bool) - 是否跳过预测输出的第一个通道的度量计算。默认值: ``True`` 。 - - **metric_name** (str) - 建议采用如下指标。当然,也可以为这些指标设置通用别名。 - 取值范围:["sensitivity", "specificity", "precision", "negative predictive value", "miss rate", "fall out", "false discovery rate", "false omission rate", "prevalence threshold", "threat score", "accuracy", "balanced accuracy", "f1 score", "matthews correlation coefficient", "fowlkes mallows index", "informedness", "markedness"]。 - 默认值: ``"sensitivity"`` 。 - - **calculation_method** (bool) - 如果为True,则计算每个样本的度量值。如果为False,则累积所有样本的混淆矩阵。 - 对于分类任务, `calculation_method` 应为False。默认值: ``False`` 。 - - **decrease** (str) - 定义减少一批数据计算结果的模式。仅当 `calculation_method` 为True时,才生效。 - 取值范围:["none", "mean", "sum", "mean_batch", "sum_batch", "mean_channel", "sum_channel"]。默认值: ``"mean"`` 。 - - .. py:method:: clear() - - 重置评估结果。 - - .. py:method:: eval() - - 计算混淆矩阵度量。 - - 返回: - numpy.ndarray,计算的结果。 - - .. py:method:: update(*inputs) - - 使用预测值和目标值更新状态。 - - 参数: - - **inputs** (tuple) - `y_pred` 和 `y` 。 `y_pred` 和 `y` 是 `Tensor` 、列表或数组。 - - - **y_pred** (ndarray) - 待计算的输入数据。格式必须为one-hot,且第一个维度是batch。 - `y_pred` 的shape是 :math:`(N, C, ...)` 或 :math:`(N, ...)` 。 - 至于分类任务, `y_pred` 的shape应为 :math:`(B, N)` ,其中N大于1。对于分割任务,shape应为 :math:`(B, N, H, W)` 或 :math:`(B, N, H, W, D)` 。 - - **y** (ndarray) - 计算度量值的真实值。格式必须为one-hot,且第一个维度是batch。`y` 的shape是 :math:`(N, C, ...)` 。 - - 异常: - - **ValueError** - 输入参数的数量不等于2。 +mindspore.train.ConfusionMatrixMetric +====================================== + +.. py:class:: mindspore.train.ConfusionMatrixMetric(skip_channel=True, metric_name="sensitivity", calculation_method=False, decrease='mean') + + 计算与混淆矩阵相关的度量。 + + 该计算基于全尺度张量,并收集批处理平均值,类通道数和迭代数。 + 此函数支持计算参数metric_name中描述中列出的所有度量名称。 + + 如果要使用混淆矩阵计算,如"PPV"、"TPR"、"TNR",请使用此类。 + 如果只想计算混淆矩阵,请使用 :class:`mindspore.train.ConfusionMatrix` 。 + + 参数: + - **skip_channel** (bool) - 是否跳过预测输出的第一个通道的度量计算。默认值: ``True`` 。 + - **metric_name** (str) - 建议采用如下指标。当然,也可以为这些指标设置通用别名。 + 取值范围:["sensitivity", "specificity", "precision", "negative predictive value", "miss rate", "fall out", "false discovery rate", "false omission rate", "prevalence threshold", "threat score", "accuracy", "balanced accuracy", "f1 score", "matthews correlation coefficient", "fowlkes mallows index", "informedness", "markedness"]。 + 默认值: ``"sensitivity"`` 。 + - **calculation_method** (bool) - 如果为True,则计算每个样本的度量值。如果为False,则累积所有样本的混淆矩阵。 + 对于分类任务, `calculation_method` 应为False。默认值: ``False`` 。 + - **decrease** (str) - 定义减少一批数据计算结果的模式。仅当 `calculation_method` 为True时,才生效。 + 取值范围:["none", "mean", "sum", "mean_batch", "sum_batch", "mean_channel", "sum_channel"]。默认值: ``"mean"`` 。 + + .. py:method:: clear() + + 重置评估结果。 + + .. py:method:: eval() + + 计算混淆矩阵度量。 + + 返回: + numpy.ndarray,计算的结果。 + + .. py:method:: update(*inputs) + + 使用预测值和目标值更新状态。 + + 参数: + - **inputs** (tuple) - `y_pred` 和 `y` 。 `y_pred` 和 `y` 是 `Tensor` 、列表或数组。 + + - **y_pred** (ndarray) - 待计算的输入数据。格式必须为one-hot,且第一个维度是batch。 + `y_pred` 的shape是 :math:`(N, C, ...)` 或 :math:`(N, ...)` 。 + 至于分类任务, `y_pred` 的shape应为 :math:`(B, N)` ,其中N大于1。对于分割任务,shape应为 :math:`(B, N, H, W)` 或 :math:`(B, N, H, W, D)` 。 + - **y** (ndarray) - 计算度量值的真实值。格式必须为one-hot,且第一个维度是batch。`y` 的shape是 :math:`(N, C, ...)` 。 + + 异常: + - **ValueError** - 输入参数的数量不等于2。 diff --git a/docs/api/api_python/train/mindspore.train.Dice.rst b/docs/api/api_python/train/mindspore.train.Dice.rst index 5dedc084f94..6f6bfe05e4f 100644 --- a/docs/api/api_python/train/mindspore.train.Dice.rst +++ b/docs/api/api_python/train/mindspore.train.Dice.rst @@ -1,39 +1,39 @@ -mindspore.train.Dice -===================== - -.. py:class:: mindspore.train.Dice(smooth=1e-5) - - 集合相似性度量。 - - 用于计算两个样本之间的相似性。当分割结果最好时,Dice系数的值为1,当分割结果最差时,Dice系数的值为0。Dice系数表示预测值与真实值交集同预测值和真实值并集之间的比值。 - - .. math:: - dice = \frac{2 * (pred \bigcap true)}{pred \bigcup true} - - 参数: - - **smooth** (float) - 在计算过程中添加到分母里,用于提高数值稳定性,取值需大于0。默认值: ``1e-5`` 。 - - .. py:method:: clear() - - 重置评估结果。 - - .. py:method:: eval() - - 计算Dice系数。 - - 返回: - float,计算的结果。 - - 异常: - - **RuntimeError** - 样本数为0。 - - .. py:method:: update(*inputs) - - 更新内部评估结果 `y_pred` 和 `y` 。 - - 参数: - - **inputs** (tuple) - 输入 `y_pred` 和 `y` 。 `y_pred` 和 `y` 是tensor、列表或numpy.ndarray。 `y_pred` 是预测值, `y` 是真实值。 `y_pred` 和 `y` 的shape都是 :math:`(N, ...)`。 - - 异常: - - **ValueError** - 输入参数的数量不等于2。 - - **ValueError** - 如果预测值和标签shape不一致。 +mindspore.train.Dice +===================== + +.. py:class:: mindspore.train.Dice(smooth=1e-5) + + 集合相似性度量。 + + 用于计算两个样本之间的相似性。当分割结果最好时,Dice系数的值为1,当分割结果最差时,Dice系数的值为0。Dice系数表示预测值与真实值交集同预测值和真实值并集之间的比值。 + + .. math:: + dice = \frac{2 * (pred \bigcap true)}{pred \bigcup true} + + 参数: + - **smooth** (float) - 在计算过程中添加到分母里,用于提高数值稳定性,取值需大于0。默认值: ``1e-5`` 。 + + .. py:method:: clear() + + 重置评估结果。 + + .. py:method:: eval() + + 计算Dice系数。 + + 返回: + float,计算的结果。 + + 异常: + - **RuntimeError** - 样本数为0。 + + .. py:method:: update(*inputs) + + 更新内部评估结果 `y_pred` 和 `y` 。 + + 参数: + - **inputs** (tuple) - 输入 `y_pred` 和 `y` 。 `y_pred` 和 `y` 是tensor、列表或numpy.ndarray。 `y_pred` 是预测值, `y` 是真实值。 `y_pred` 和 `y` 的shape都是 :math:`(N, ...)`。 + + 异常: + - **ValueError** - 输入参数的数量不等于2。 + - **ValueError** - 如果预测值和标签shape不一致。 diff --git a/docs/api/api_python/train/mindspore.train.HausdorffDistance.rst b/docs/api/api_python/train/mindspore.train.HausdorffDistance.rst index 7f377771b9c..4531a51a668 100644 --- a/docs/api/api_python/train/mindspore.train.HausdorffDistance.rst +++ b/docs/api/api_python/train/mindspore.train.HausdorffDistance.rst @@ -1,50 +1,50 @@ -mindspore.train.HausdorffDistance -============================================ - -.. py:class:: mindspore.train.HausdorffDistance(distance_metric="euclidean", percentile=None, directed=False, crop=True) - - 计算Hausdorff距离。Hausdorff距离是两个点集之间两点的最小距离的最大值,度量了两个点集间的最大不匹配程度。 - - 给定两个集合A和B,A和B之间的Hausdorff距离定义如下: - - .. math:: - \begin{array}{ll} \\ - H(A, B) = \text{max}[h(A, B), h(B, A)]\\ - h(A, B) = \underset{a \in A}{\text{max}}\{\underset{b \in B}{\text{min}} \rVert a - b \rVert \}\\ - h(B, A) = \underset{b \in B}{\text{max}}\{\underset{a \in A}{\text{min}} \rVert b - a \rVert \} - \end{array} - - 其中 :math:`h(A, B)` 表示,对A中的每个点a找到B集合里的最近点,这些最短距离的最大值为从A到B的单向Hausdorff距离,同理,:math:`h(B, A)` 为集合B到集合A中最近点的最大距离。Hausdoff距离是有方向性的,通常情况下 :math:`h(A, B)` 不等于 :math:`h(B, A)`。:math:`H(A, B)` 为双向Hausdorff距离。 - - 参数: - - **distance_metric** (string) - 支持如下三种距离计算方法: ``"euclidean"`` (欧式距离)、 ``"chessboard"`` (棋盘距离、切比雪夫距离) 或 ``"taxicab"`` (出租车距离、曼哈顿距离)。默认值: ``"euclidean"`` 。 - - **percentile** (float) - 0到100之间的浮点数。指定最终返回的Hausdorff距离的百分位数。默认值: ``None`` 。 - - **directed** (bool) - 如果为True,为单向Hausdorff距离,只计算h(y_pred, y)距离;如果为False,为双向Hausdorff距离,计算max(h(y_pred, y), h(y, y_pred))。默认值: ``False`` 。 - - **crop** (bool) - 是否裁剪输入图像,仅保留目标区域。为了保证y_pred和y的shape匹配,使用(y_pred | y),即两图像的并集来确定bounding box。默认值: ``True`` 。 - - .. py:method:: clear() - - 内部评估结果清零。 - - .. py:method:: eval() - - 计算定向或非定向Hausdorff距离。 - - 返回: - numpy.float64,计算得到的Hausdorff距离。 - - 异常: - - **RuntimeError** - 如果没有先调用update方法。 - - .. py:method:: update(*inputs) - - 使用 `y_pred`、 `y` 和 `label_idx` 更新内部评估结果。 - - 参数: - - **inputs** - `y_pred`、 `y` 和 `label_idx`。 `y_pred` 和 `y` 为Tensor, list或numpy.ndarray, `y_pred` 是预测的二值图像, `y` 是实际的二值图像。 `label_idx` 的数据类型为int或float,表示像素点的类别值。 - - 异常: - - **ValueError** - 输入的数量不等于3。 - - **TypeError** - label_idx 的数据类型不是int或float。 - - **ValueError** - label_idx 的值不在y_pred或y中。 - - **ValueError** - y_pred 和 y 的shape不同。 +mindspore.train.HausdorffDistance +============================================ + +.. py:class:: mindspore.train.HausdorffDistance(distance_metric="euclidean", percentile=None, directed=False, crop=True) + + 计算Hausdorff距离。Hausdorff距离是两个点集之间两点的最小距离的最大值,度量了两个点集间的最大不匹配程度。 + + 给定两个集合A和B,A和B之间的Hausdorff距离定义如下: + + .. math:: + \begin{array}{ll} \\ + H(A, B) = \text{max}[h(A, B), h(B, A)]\\ + h(A, B) = \underset{a \in A}{\text{max}}\{\underset{b \in B}{\text{min}} \rVert a - b \rVert \}\\ + h(B, A) = \underset{b \in B}{\text{max}}\{\underset{a \in A}{\text{min}} \rVert b - a \rVert \} + \end{array} + + 其中 :math:`h(A, B)` 表示,对A中的每个点a找到B集合里的最近点,这些最短距离的最大值为从A到B的单向Hausdorff距离,同理,:math:`h(B, A)` 为集合B到集合A中最近点的最大距离。Hausdoff距离是有方向性的,通常情况下 :math:`h(A, B)` 不等于 :math:`h(B, A)`。:math:`H(A, B)` 为双向Hausdorff距离。 + + 参数: + - **distance_metric** (string) - 支持如下三种距离计算方法: ``"euclidean"`` (欧式距离)、 ``"chessboard"`` (棋盘距离、切比雪夫距离) 或 ``"taxicab"`` (出租车距离、曼哈顿距离)。默认值: ``"euclidean"`` 。 + - **percentile** (float) - 0到100之间的浮点数。指定最终返回的Hausdorff距离的百分位数。默认值: ``None`` 。 + - **directed** (bool) - 如果为True,为单向Hausdorff距离,只计算h(y_pred, y)距离;如果为False,为双向Hausdorff距离,计算max(h(y_pred, y), h(y, y_pred))。默认值: ``False`` 。 + - **crop** (bool) - 是否裁剪输入图像,仅保留目标区域。为了保证y_pred和y的shape匹配,使用(y_pred | y),即两图像的并集来确定bounding box。默认值: ``True`` 。 + + .. py:method:: clear() + + 内部评估结果清零。 + + .. py:method:: eval() + + 计算定向或非定向Hausdorff距离。 + + 返回: + numpy.float64,计算得到的Hausdorff距离。 + + 异常: + - **RuntimeError** - 如果没有先调用update方法。 + + .. py:method:: update(*inputs) + + 使用 `y_pred`、 `y` 和 `label_idx` 更新内部评估结果。 + + 参数: + - **inputs** - `y_pred`、 `y` 和 `label_idx`。 `y_pred` 和 `y` 为Tensor, list或numpy.ndarray, `y_pred` 是预测的二值图像, `y` 是实际的二值图像。 `label_idx` 的数据类型为int或float,表示像素点的类别值。 + + 异常: + - **ValueError** - 输入的数量不等于3。 + - **TypeError** - label_idx 的数据类型不是int或float。 + - **ValueError** - label_idx 的值不在y_pred或y中。 + - **ValueError** - y_pred 和 y 的shape不同。 diff --git a/docs/api/api_python/train/mindspore.train.MeanSurfaceDistance.rst b/docs/api/api_python/train/mindspore.train.MeanSurfaceDistance.rst index 8c7976e43ca..fa57108c081 100644 --- a/docs/api/api_python/train/mindspore.train.MeanSurfaceDistance.rst +++ b/docs/api/api_python/train/mindspore.train.MeanSurfaceDistance.rst @@ -1,57 +1,57 @@ -mindspore.train.MeanSurfaceDistance -=============================================== - -.. py:class:: mindspore.train.MeanSurfaceDistance(symmetric=False, distance_metric="euclidean") - - 计算从 `y_pred` 到 `y` 的平均表面距离。通常情况下,用来衡量分割任务中,预测情况和真实情况之间的差异度。 - - 给定两个集合A和B,S(A)表示A的表面像素,任意v到S(A)的最短距离定义为: - - .. math:: - {\text{dis}}\left (v, S(A)\right ) = \underset{s_{A} \in S(A)}{\text{min }}\rVert v - s_{A} \rVert - - 从集合B到集合A的平均表面距离(Average Surface Distance)为: - - .. math:: - AvgSurDis(B \rightarrow A) = \frac{\sum_{s_{B} \in S(B)}^{} {\text{dis} \left - ( s_{B}, S(A) \right )} } {\left | S(B) \right |} - - 其中 \|\|\*\|\| 表示距离度量。 \|\*\| 表示元素的数量。 - - 从集合B到集合A以及从集合A到集合B的表面距离平均值为: - - .. math:: - MeanSurDis(A \leftrightarrow B) = \frac{\sum_{s_{A} \in S(A)}^{} {\text{dis} \left ( s_{A}, S(B) \right )} - + \sum_{s_{B} \in S(B)}^{} {\text{dis} \left ( s_{B}, S(A) \right )} }{\left | S(A) \right | + - \left | S(B) \right |} - - 参数: - - **distance_metric** (string) - 支持如下三种距离计算方法: ``"euclidean"`` (欧式距离)、 ``"chessboard"`` (棋盘距离、切比雪夫距离) 或 ``"taxicab"`` (出租车距离、曼哈顿距离)。默认值: ``"euclidean"`` 。 - - **symmetric** (bool) - 是否计算 `y_pred` 和 `y` 之间的对称平均平面距离。如果为False,计算方式为 :math:`AvgSurDis(y\_pred \rightarrow y)` ,如果为True,计算方式为 :math:`MeanSurDis(y\_pred \leftrightarrow y)` 。默认值: ``False`` 。 - - .. py:method:: clear() - - 内部评估结果清零。 - - .. py:method:: eval() - - 计算平均表面距离。 - - 返回: - numpy.float64,计算得到的平均表面距离值。 - - 异常: - - **RuntimeError** - 如果没有先调用update方法。 - - .. py:method:: update(*inputs) - - 使用 `y_pred`、 `y` 和 `label_idx` 更新内部评估结果。 - - 参数: - - **inputs** - `y_pred`、 `y` 和 `label_idx`。 `y_pred` 和 `y` 为Tensor,list或numpy.ndarray, `y_pred` 是预测的二值图像。 `y` 是实际的二值图像。 `label_idx` 数据类型为int或float,表示像素点的类别值。 - - 异常: - - **ValueError** - 输入的数量不等于3。 - - **TypeError** - `label_idx` 的数据类型不是int或float。 - - **ValueError** - `label_idx` 的值不在y_pred或y中。 - - **ValueError** - `y_pred` 和 `y` 的shape不同。 +mindspore.train.MeanSurfaceDistance +=============================================== + +.. py:class:: mindspore.train.MeanSurfaceDistance(symmetric=False, distance_metric="euclidean") + + 计算从 `y_pred` 到 `y` 的平均表面距离。通常情况下,用来衡量分割任务中,预测情况和真实情况之间的差异度。 + + 给定两个集合A和B,S(A)表示A的表面像素,任意v到S(A)的最短距离定义为: + + .. math:: + {\text{dis}}\left (v, S(A)\right ) = \underset{s_{A} \in S(A)}{\text{min }}\rVert v - s_{A} \rVert + + 从集合B到集合A的平均表面距离(Average Surface Distance)为: + + .. math:: + AvgSurDis(B \rightarrow A) = \frac{\sum_{s_{B} \in S(B)}^{} {\text{dis} \left + ( s_{B}, S(A) \right )} } {\left | S(B) \right |} + + 其中 \|\|\*\|\| 表示距离度量。 \|\*\| 表示元素的数量。 + + 从集合B到集合A以及从集合A到集合B的表面距离平均值为: + + .. math:: + MeanSurDis(A \leftrightarrow B) = \frac{\sum_{s_{A} \in S(A)}^{} {\text{dis} \left ( s_{A}, S(B) \right )} + + \sum_{s_{B} \in S(B)}^{} {\text{dis} \left ( s_{B}, S(A) \right )} }{\left | S(A) \right | + + \left | S(B) \right |} + + 参数: + - **distance_metric** (string) - 支持如下三种距离计算方法: ``"euclidean"`` (欧式距离)、 ``"chessboard"`` (棋盘距离、切比雪夫距离) 或 ``"taxicab"`` (出租车距离、曼哈顿距离)。默认值: ``"euclidean"`` 。 + - **symmetric** (bool) - 是否计算 `y_pred` 和 `y` 之间的对称平均平面距离。如果为False,计算方式为 :math:`AvgSurDis(y\_pred \rightarrow y)` ,如果为True,计算方式为 :math:`MeanSurDis(y\_pred \leftrightarrow y)` 。默认值: ``False`` 。 + + .. py:method:: clear() + + 内部评估结果清零。 + + .. py:method:: eval() + + 计算平均表面距离。 + + 返回: + numpy.float64,计算得到的平均表面距离值。 + + 异常: + - **RuntimeError** - 如果没有先调用update方法。 + + .. py:method:: update(*inputs) + + 使用 `y_pred`、 `y` 和 `label_idx` 更新内部评估结果。 + + 参数: + - **inputs** - `y_pred`、 `y` 和 `label_idx`。 `y_pred` 和 `y` 为Tensor,list或numpy.ndarray, `y_pred` 是预测的二值图像。 `y` 是实际的二值图像。 `label_idx` 数据类型为int或float,表示像素点的类别值。 + + 异常: + - **ValueError** - 输入的数量不等于3。 + - **TypeError** - `label_idx` 的数据类型不是int或float。 + - **ValueError** - `label_idx` 的值不在y_pred或y中。 + - **ValueError** - `y_pred` 和 `y` 的shape不同。 diff --git a/docs/api/api_python/train/mindspore.train.OcclusionSensitivity.rst b/docs/api/api_python/train/mindspore.train.OcclusionSensitivity.rst index 6217cb45a7f..39a7bc5773b 100644 --- a/docs/api/api_python/train/mindspore.train.OcclusionSensitivity.rst +++ b/docs/api/api_python/train/mindspore.train.OcclusionSensitivity.rst @@ -1,40 +1,40 @@ -mindspore.train.OcclusionSensitivity -============================================= - -.. py:class:: mindspore.train.OcclusionSensitivity(pad_val=0.0, margin=2, n_batch=128, b_box=None) - - 用于计算神经网络对给定图像的遮挡灵敏度(Occlusion Sensitivity),表示了图像的哪些部分对神经网络的分类决策最重要。 - - 遮挡敏感度是指神经网络对图像的类别预测概率如何随着图像被遮挡部分的变化而变化。遮挡敏感度值越高,意味着模型对类别预测的概率值下降越大,说明遮挡区域在神经网络的分类决策过程中越重要。 - - 参数: - - **pad_val** (float) - 图像中被遮挡部分的填充值。默认值: ``0.0`` 。 - - **margin** (Union[int, Sequence]) - 在要遮挡的像素点周围设置的长方体/立方体。默认值: ``2`` 。 - - **n_batch** (int) - 一个batch中样本的数量。默认值: ``128`` 。 - - **b_box** (Sequence) - 执行分析的目标区域的边界框(Bounding box),其大小与输出图像的大小相匹配。如果没有设置此入参,Bounding box将与输入图像的大小相同;如果设置了此入参,输入图像将被裁剪为此大小,此设置值应形如:``[min1, max1, min2, max2,...]``,分别对应除batch size外各维度的最大最小值。默认值: ``None`` 。 - - .. py:method:: clear() - - 内部评估结果清零。 - - .. py:method:: eval() - - 计算遮挡敏感度。 - - 返回: - numpy ndarray。计算得到的遮挡敏感度值。 - - 异常: - - **RuntimeError** - 如果没有先调用update方法,则会报错。 - - .. py:method:: update(*inputs) - - 更新inputs,包括 `model` 、 `y_pred` 和 `label` 。 - - 参数: - - **inputs** - `y_pred` 和 `label` 为Tensor,list或numpy.ndarray,`y_pred` 是要测试的图像,一般为2D或3D,`label` 是用于检测神经网络预测值变化的类别标签,通常情况下为真实标签。`model` 为神经网络模型。 - - 异常: - - **ValueError** - 输入数量不是3。 - - **RuntimeError** - `y_pred.shape[0]` 不是1。 - - **RuntimeError** - 标签数量与batch数量不同。 +mindspore.train.OcclusionSensitivity +============================================= + +.. py:class:: mindspore.train.OcclusionSensitivity(pad_val=0.0, margin=2, n_batch=128, b_box=None) + + 用于计算神经网络对给定图像的遮挡灵敏度(Occlusion Sensitivity),表示了图像的哪些部分对神经网络的分类决策最重要。 + + 遮挡敏感度是指神经网络对图像的类别预测概率如何随着图像被遮挡部分的变化而变化。遮挡敏感度值越高,意味着模型对类别预测的概率值下降越大,说明遮挡区域在神经网络的分类决策过程中越重要。 + + 参数: + - **pad_val** (float) - 图像中被遮挡部分的填充值。默认值: ``0.0`` 。 + - **margin** (Union[int, Sequence]) - 在要遮挡的像素点周围设置的长方体/立方体。默认值: ``2`` 。 + - **n_batch** (int) - 一个batch中样本的数量。默认值: ``128`` 。 + - **b_box** (Sequence) - 执行分析的目标区域的边界框(Bounding box),其大小与输出图像的大小相匹配。如果没有设置此入参,Bounding box将与输入图像的大小相同;如果设置了此入参,输入图像将被裁剪为此大小,此设置值应形如:``[min1, max1, min2, max2,...]``,分别对应除batch size外各维度的最大最小值。默认值: ``None`` 。 + + .. py:method:: clear() + + 内部评估结果清零。 + + .. py:method:: eval() + + 计算遮挡敏感度。 + + 返回: + numpy ndarray。计算得到的遮挡敏感度值。 + + 异常: + - **RuntimeError** - 如果没有先调用update方法,则会报错。 + + .. py:method:: update(*inputs) + + 更新inputs,包括 `model` 、 `y_pred` 和 `label` 。 + + 参数: + - **inputs** - `y_pred` 和 `label` 为Tensor,list或numpy.ndarray,`y_pred` 是要测试的图像,一般为2D或3D,`label` 是用于检测神经网络预测值变化的类别标签,通常情况下为真实标签。`model` 为神经网络模型。 + + 异常: + - **ValueError** - 输入数量不是3。 + - **RuntimeError** - `y_pred.shape[0]` 不是1。 + - **RuntimeError** - 标签数量与batch数量不同。 diff --git a/docs/api/api_python/train/mindspore.train.Perplexity.rst b/docs/api/api_python/train/mindspore.train.Perplexity.rst index e9a3818ef50..725471e4408 100644 --- a/docs/api/api_python/train/mindspore.train.Perplexity.rst +++ b/docs/api/api_python/train/mindspore.train.Perplexity.rst @@ -1,40 +1,40 @@ -mindspore.train.Perplexity -=========================== - -.. py:class:: mindspore.train.Perplexity(ignore_label=None) - - 计算困惑度(perplexity)。困惑度是衡量一个概率分布或语言模型好坏的标准。低困惑度表明语言模型可以很好地预测样本。计算方式如下: - - .. math:: - PP(W)=P(w_{1}w_{2}...w_{N})^{-\frac{1}{N}}=\sqrt[N]{\frac{1}{P(w_{1}w_{2}...w_{N})}} - - 其中 :math:`w` 代表语料库中的单词。根号内是句子概率的倒数,句子越好(概率大),困惑度越小。 - - 参数: - - **ignore_label** (Union[int, None]) - 计数时要忽略的无效标签的索引。如果设置为None,它将包括所有条目。默认值: ``None`` 。 - - .. py:method:: clear() - - 内部评估结果清零。 - - .. py:method:: eval() - - 返回当前评估结果。 - - 返回: - numpy.float64,计算得到的困惑度结果。 - - 异常: - - **RuntimeError** - 样本量为0。 - - .. py:method:: update(*inputs) - - 使用 `preds` 和 `labels` 更新内部评估结果。 - - 参数: - - **inputs** - 输入 `preds` 和 `labels` 。 `preds` 和 `labels` 是Tensor、list或numpy.ndarray。 `preds` 是预测值, `labels` 是数据的标签。 `preds` 和 `labels` 的shape都是 :math:`(N, C)` 。 - - 异常: - - **ValueError** - 输入数量不是2。 - - **RuntimeError** - 预测值和标签的长度不同。 - - **RuntimeError** - 预测值和标签的shape不同。 +mindspore.train.Perplexity +=========================== + +.. py:class:: mindspore.train.Perplexity(ignore_label=None) + + 计算困惑度(perplexity)。困惑度是衡量一个概率分布或语言模型好坏的标准。低困惑度表明语言模型可以很好地预测样本。计算方式如下: + + .. math:: + PP(W)=P(w_{1}w_{2}...w_{N})^{-\frac{1}{N}}=\sqrt[N]{\frac{1}{P(w_{1}w_{2}...w_{N})}} + + 其中 :math:`w` 代表语料库中的单词。根号内是句子概率的倒数,句子越好(概率大),困惑度越小。 + + 参数: + - **ignore_label** (Union[int, None]) - 计数时要忽略的无效标签的索引。如果设置为None,它将包括所有条目。默认值: ``None`` 。 + + .. py:method:: clear() + + 内部评估结果清零。 + + .. py:method:: eval() + + 返回当前评估结果。 + + 返回: + numpy.float64,计算得到的困惑度结果。 + + 异常: + - **RuntimeError** - 样本量为0。 + + .. py:method:: update(*inputs) + + 使用 `preds` 和 `labels` 更新内部评估结果。 + + 参数: + - **inputs** - 输入 `preds` 和 `labels` 。 `preds` 和 `labels` 是Tensor、list或numpy.ndarray。 `preds` 是预测值, `labels` 是数据的标签。 `preds` 和 `labels` 的shape都是 :math:`(N, C)` 。 + + 异常: + - **ValueError** - 输入数量不是2。 + - **RuntimeError** - 预测值和标签的长度不同。 + - **RuntimeError** - 预测值和标签的shape不同。 diff --git a/docs/api/api_python/train/mindspore.train.RootMeanSquareDistance.rst b/docs/api/api_python/train/mindspore.train.RootMeanSquareDistance.rst index b93be595110..e9a64c4cfeb 100644 --- a/docs/api/api_python/train/mindspore.train.RootMeanSquareDistance.rst +++ b/docs/api/api_python/train/mindspore.train.RootMeanSquareDistance.rst @@ -1,57 +1,57 @@ -mindspore.train.RootMeanSquareDistance -======================================= - -.. py:class:: mindspore.train.RootMeanSquareDistance(symmetric=False, distance_metric="euclidean") - - 计算从 `y_pred` 到 `y` 的均方根表面距离。 - - 给定两个集合A和B,S(A)表示A的表面像素,任意v到S(A)的最短距离定义为: - - .. math:: - {\text{dis}}\left (v, S(A)\right ) = \underset{s_{A} \in S(A)}{\text{min }}\rVert v - s_{A} \rVert - - 从集合B到集合A的均方根表面距离(Root Mean Square Surface Distance)为: - - .. math:: - RmsSurDis(B \rightarrow A) = \sqrt{\frac{\sum_{s_{B} \in S(B)}^{} {\text{dis}^2 \left ( s_{B}, S(A) - \right )} }{\left | S(B) \right |}} - - 其中 \|\|\*\|\| 表示距离度量。 \|\*\| 表示元素的数量。 - - 从集合B到集合A以及从集合A到集合B的表面距离平均值为: - - .. math:: - RmsSurDis(A \leftrightarrow B) = \sqrt{\frac{\sum_{s_{A} \in S(A)}^{} {\text{dis} \left ( s_{A}, - S(B) \right ) ^{2}} + \sum_{s_{B} \in S(B)}^{} {\text{dis} \left ( s_{B}, S(A) \right ) ^{2}}}{\left | S(A) - \right | + \left | S(B) \right |}} - - 参数: - - **distance_metric** (string) - 支持如下三种距离计算方法: ``"euclidean"`` (欧式距离)、 ``"chessboard"`` (棋盘距离、切比雪夫距离) 或 ``"taxicab"`` (出租车距离、曼哈顿距离)。默认值: ``"euclidean"`` 。 - - **symmetric** (bool) - 是否计算 `y_pred` 和 `y` 之间的对称平均平面距离。如果为 ``False`` ,计算方式为 :math:`RmsSurDis(y\_pred, y)` ,如果为 ``True`` ,计算方式为 :math:`RmsSurDis(y\_pred \leftrightarrow y)` 。默认值: ``False`` 。 - - .. py:method:: clear() - - 内部评估结果清零。 - - .. py:method:: eval() - - 计算均方根表面距离。 - - 返回: - numpy.float64,计算得到的均方根表面距离值。 - - 异常: - - **RuntimeError** - 如果没有先调用update方法,则会报错。 - - .. py:method:: update(*inputs) - - 使用 `y_pred`、`y` 和 `label_idx` 更新内部评估结果。 - - 参数: - - **inputs** - `y_pred`、 `y` 和 `label_idx`。 `y_pred` 和 `y` 为Tensor,list或numpy.ndarray, `y_pred` 是预测的二值图像。 `y` 是实际的二值图像。 `label_idx` 数据类型为int或float,表示像素点的类别值。 - - 异常: - - **ValueError** - 输入的数量不等于3。 - - **TypeError** - `label_idx` 的数据类型不是int或float。 - - **ValueError** - `label_idx` 的值不在y_pred或y中。 - - **ValueError** - `y_pred` 和 `y` 的shape不同。 +mindspore.train.RootMeanSquareDistance +======================================= + +.. py:class:: mindspore.train.RootMeanSquareDistance(symmetric=False, distance_metric="euclidean") + + 计算从 `y_pred` 到 `y` 的均方根表面距离。 + + 给定两个集合A和B,S(A)表示A的表面像素,任意v到S(A)的最短距离定义为: + + .. math:: + {\text{dis}}\left (v, S(A)\right ) = \underset{s_{A} \in S(A)}{\text{min }}\rVert v - s_{A} \rVert + + 从集合B到集合A的均方根表面距离(Root Mean Square Surface Distance)为: + + .. math:: + RmsSurDis(B \rightarrow A) = \sqrt{\frac{\sum_{s_{B} \in S(B)}^{} {\text{dis}^2 \left ( s_{B}, S(A) + \right )} }{\left | S(B) \right |}} + + 其中 \|\|\*\|\| 表示距离度量。 \|\*\| 表示元素的数量。 + + 从集合B到集合A以及从集合A到集合B的表面距离平均值为: + + .. math:: + RmsSurDis(A \leftrightarrow B) = \sqrt{\frac{\sum_{s_{A} \in S(A)}^{} {\text{dis} \left ( s_{A}, + S(B) \right ) ^{2}} + \sum_{s_{B} \in S(B)}^{} {\text{dis} \left ( s_{B}, S(A) \right ) ^{2}}}{\left | S(A) + \right | + \left | S(B) \right |}} + + 参数: + - **distance_metric** (string) - 支持如下三种距离计算方法: ``"euclidean"`` (欧式距离)、 ``"chessboard"`` (棋盘距离、切比雪夫距离) 或 ``"taxicab"`` (出租车距离、曼哈顿距离)。默认值: ``"euclidean"`` 。 + - **symmetric** (bool) - 是否计算 `y_pred` 和 `y` 之间的对称平均平面距离。如果为 ``False`` ,计算方式为 :math:`RmsSurDis(y\_pred, y)` ,如果为 ``True`` ,计算方式为 :math:`RmsSurDis(y\_pred \leftrightarrow y)` 。默认值: ``False`` 。 + + .. py:method:: clear() + + 内部评估结果清零。 + + .. py:method:: eval() + + 计算均方根表面距离。 + + 返回: + numpy.float64,计算得到的均方根表面距离值。 + + 异常: + - **RuntimeError** - 如果没有先调用update方法,则会报错。 + + .. py:method:: update(*inputs) + + 使用 `y_pred`、`y` 和 `label_idx` 更新内部评估结果。 + + 参数: + - **inputs** - `y_pred`、 `y` 和 `label_idx`。 `y_pred` 和 `y` 为Tensor,list或numpy.ndarray, `y_pred` 是预测的二值图像。 `y` 是实际的二值图像。 `label_idx` 数据类型为int或float,表示像素点的类别值。 + + 异常: + - **ValueError** - 输入的数量不等于3。 + - **TypeError** - `label_idx` 的数据类型不是int或float。 + - **ValueError** - `label_idx` 的值不在y_pred或y中。 + - **ValueError** - `y_pred` 和 `y` 的shape不同。 diff --git a/docs/api/api_python/train/mindspore.train.auc.rst b/docs/api/api_python/train/mindspore.train.auc.rst index 31a1edcae67..0b55a43839f 100644 --- a/docs/api/api_python/train/mindspore.train.auc.rst +++ b/docs/api/api_python/train/mindspore.train.auc.rst @@ -1,15 +1,15 @@ -mindspore.train.auc -==================== - -.. py:function:: mindspore.train.auc(x, y, reorder=False) - - 使用梯形法则计算曲线下面积AUC(Area Under the Curve,AUC)。这是一个一般函数,给定曲线上的点, - 用于计算ROC (Receiver Operating Curve, ROC) 曲线下的面积。 - - 参数: - - **x** (Union[np.array, list]) - 从ROC曲线(False Positive Rate, FPR)来看,np.array具有假阳性率。如果是多类,则为np.array列表。Shape为 :math:`(N)` 。 - - **y** (Union[np.array, list]) - 从ROC曲线(True Positive Rate, TPR)来看,np.array具有假阳性率。如果是多类,则为np.array列表。Shape为 :math:`(N)` 。 - - **reorder** (bool) - 如果为False,那么 `x` 必须是单调上升或下降的,如果为True,那么 `x` 将会按照升序排序。默认值: ``False`` 。 - - 返回: - float,曲线下面积的值AUC。 +mindspore.train.auc +==================== + +.. py:function:: mindspore.train.auc(x, y, reorder=False) + + 使用梯形法则计算曲线下面积AUC(Area Under the Curve,AUC)。这是一个一般函数,给定曲线上的点, + 用于计算ROC (Receiver Operating Curve, ROC) 曲线下的面积。 + + 参数: + - **x** (Union[np.array, list]) - 从ROC曲线(False Positive Rate, FPR)来看,np.array具有假阳性率。如果是多类,则为np.array列表。Shape为 :math:`(N)` 。 + - **y** (Union[np.array, list]) - 从ROC曲线(True Positive Rate, TPR)来看,np.array具有假阳性率。如果是多类,则为np.array列表。Shape为 :math:`(N)` 。 + - **reorder** (bool) - 如果为False,那么 `x` 必须是单调上升或下降的,如果为True,那么 `x` 将会按照升序排序。默认值: ``False`` 。 + + 返回: + float,曲线下面积的值AUC。 diff --git a/docs/api/api_python_en/mindspore.experimental.rst b/docs/api/api_python_en/mindspore.experimental.rst index dd8534d4f53..f9f242f0a66 100644 --- a/docs/api/api_python_en/mindspore.experimental.rst +++ b/docs/api/api_python_en/mindspore.experimental.rst @@ -1,83 +1,83 @@ -mindspore.experimental -======================= - -The experimental modules. - -Experimental Optimizer ------------------------ - -.. msplatformautosummary:: - :toctree: experimental/optim - :nosignatures: - :template: classtemplate.rst - - mindspore.experimental.optim.Optimizer - mindspore.experimental.optim.Adadelta - mindspore.experimental.optim.Adagrad - mindspore.experimental.optim.Adam - mindspore.experimental.optim.Adamax - mindspore.experimental.optim.AdamW - mindspore.experimental.optim.ASGD - mindspore.experimental.optim.NAdam - mindspore.experimental.optim.RAdam - mindspore.experimental.optim.RMSprop - mindspore.experimental.optim.Rprop - mindspore.experimental.optim.SGD - - -LRScheduler Class -^^^^^^^^^^^^^^^^^^ - -The dynamic learning rates in this module are all subclasses of LRScheduler, this module should be used with optimizers -in mindspore.experimental.optim, pass the optimizer instance to a LRScheduler when used. During the training process, the -LRScheduler subclass dynamically changes the learning rate by calling the `step` method. - -.. code-block:: - - import mindspore - from mindspore import nn - from mindspore.experimental import optim - - # Define the network structure of LeNet5. Refer to - # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py - - net = LeNet5() - loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True) - optimizer = optim.Adam(net.trainable_params(), lr=0.05) - scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1) - def forward_fn(data, label): - logits = net(data) - loss = loss_fn(logits, label) - return loss, logits - grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True) - def train_step(data, label): - (loss, _), grads = grad_fn(data, label) - optimizer(grads) - return loss - for epoch in range(6): - # Create the dataset taking MNIST as an example. Refer to - # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py - - for data, label in create_dataset(need_download=False): - train_step(data, label) - scheduler.step() - -.. msplatformautosummary:: - :toctree: experimental/optim - :nosignatures: - :template: classtemplate.rst - - mindspore.experimental.optim.lr_scheduler.LRScheduler - mindspore.experimental.optim.lr_scheduler.ConstantLR - mindspore.experimental.optim.lr_scheduler.CosineAnnealingLR - mindspore.experimental.optim.lr_scheduler.CosineAnnealingWarmRestarts - mindspore.experimental.optim.lr_scheduler.CyclicLR - mindspore.experimental.optim.lr_scheduler.ExponentialLR - mindspore.experimental.optim.lr_scheduler.LambdaLR - mindspore.experimental.optim.lr_scheduler.LinearLR - mindspore.experimental.optim.lr_scheduler.MultiplicativeLR - mindspore.experimental.optim.lr_scheduler.MultiStepLR - mindspore.experimental.optim.lr_scheduler.PolynomialLR - mindspore.experimental.optim.lr_scheduler.ReduceLROnPlateau - mindspore.experimental.optim.lr_scheduler.SequentialLR +mindspore.experimental +======================= + +The experimental modules. + +Experimental Optimizer +----------------------- + +.. msplatformautosummary:: + :toctree: experimental/optim + :nosignatures: + :template: classtemplate.rst + + mindspore.experimental.optim.Optimizer + mindspore.experimental.optim.Adadelta + mindspore.experimental.optim.Adagrad + mindspore.experimental.optim.Adam + mindspore.experimental.optim.Adamax + mindspore.experimental.optim.AdamW + mindspore.experimental.optim.ASGD + mindspore.experimental.optim.NAdam + mindspore.experimental.optim.RAdam + mindspore.experimental.optim.RMSprop + mindspore.experimental.optim.Rprop + mindspore.experimental.optim.SGD + + +LRScheduler Class +^^^^^^^^^^^^^^^^^^ + +The dynamic learning rates in this module are all subclasses of LRScheduler, this module should be used with optimizers +in mindspore.experimental.optim, pass the optimizer instance to a LRScheduler when used. During the training process, the +LRScheduler subclass dynamically changes the learning rate by calling the `step` method. + +.. code-block:: + + import mindspore + from mindspore import nn + from mindspore.experimental import optim + + # Define the network structure of LeNet5. Refer to + # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py + + net = LeNet5() + loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True) + optimizer = optim.Adam(net.trainable_params(), lr=0.05) + scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1) + def forward_fn(data, label): + logits = net(data) + loss = loss_fn(logits, label) + return loss, logits + grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True) + def train_step(data, label): + (loss, _), grads = grad_fn(data, label) + optimizer(grads) + return loss + for epoch in range(6): + # Create the dataset taking MNIST as an example. Refer to + # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py + + for data, label in create_dataset(need_download=False): + train_step(data, label) + scheduler.step() + +.. msplatformautosummary:: + :toctree: experimental/optim + :nosignatures: + :template: classtemplate.rst + + mindspore.experimental.optim.lr_scheduler.LRScheduler + mindspore.experimental.optim.lr_scheduler.ConstantLR + mindspore.experimental.optim.lr_scheduler.CosineAnnealingLR + mindspore.experimental.optim.lr_scheduler.CosineAnnealingWarmRestarts + mindspore.experimental.optim.lr_scheduler.CyclicLR + mindspore.experimental.optim.lr_scheduler.ExponentialLR + mindspore.experimental.optim.lr_scheduler.LambdaLR + mindspore.experimental.optim.lr_scheduler.LinearLR + mindspore.experimental.optim.lr_scheduler.MultiplicativeLR + mindspore.experimental.optim.lr_scheduler.MultiStepLR + mindspore.experimental.optim.lr_scheduler.PolynomialLR + mindspore.experimental.optim.lr_scheduler.ReduceLROnPlateau + mindspore.experimental.optim.lr_scheduler.SequentialLR mindspore.experimental.optim.lr_scheduler.StepLR \ No newline at end of file diff --git a/graphengine b/graphengine deleted file mode 160000 index f4a38c95483..00000000000 --- a/graphengine +++ /dev/null @@ -1 +0,0 @@ -Subproject commit f4a38c95483cce7c3cd3991a173108e0b2b5ed69 diff --git a/include/api/model_group.h b/include/api/model_group.h old mode 100755 new mode 100644 diff --git a/mindspore/ccsrc/backend/common/graph_kernel/cast_matmul_fusion.h b/mindspore/ccsrc/backend/common/graph_kernel/cast_matmul_fusion.h index 8ccffcf7847..4490c577066 100644 --- a/mindspore/ccsrc/backend/common/graph_kernel/cast_matmul_fusion.h +++ b/mindspore/ccsrc/backend/common/graph_kernel/cast_matmul_fusion.h @@ -1,34 +1,34 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CAST_MATMUL_FUSION_H_ -#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CAST_MATMUL_FUSION_H_ - -#include -#include - -#include "include/backend/optimizer/pass.h" -#include "ir/func_graph.h" - -namespace mindspore::graphkernel { -class CastMatmulFusion : public opt::Pass { - public: - CastMatmulFusion() : Pass("cast_matmul_fusion") {} - ~CastMatmulFusion() override = default; - bool Run(const FuncGraphPtr &func_graph) override; -}; -using OptimizeMatmulPtr = std::shared_ptr; -} // namespace mindspore::graphkernel -#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CAST_MATMUL_FUSION_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CAST_MATMUL_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CAST_MATMUL_FUSION_H_ + +#include +#include + +#include "include/backend/optimizer/pass.h" +#include "ir/func_graph.h" + +namespace mindspore::graphkernel { +class CastMatmulFusion : public opt::Pass { + public: + CastMatmulFusion() : Pass("cast_matmul_fusion") {} + ~CastMatmulFusion() override = default; + bool Run(const FuncGraphPtr &func_graph) override; +}; +using OptimizeMatmulPtr = std::shared_ptr; +} // namespace mindspore::graphkernel +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CAST_MATMUL_FUSION_H_ diff --git a/mindspore/ccsrc/backend/common/graph_kernel/core/graph_kernel_pass_manager.cc b/mindspore/ccsrc/backend/common/graph_kernel/core/graph_kernel_pass_manager.cc index f2305d2ed2e..b289847cb46 100644 --- a/mindspore/ccsrc/backend/common/graph_kernel/core/graph_kernel_pass_manager.cc +++ b/mindspore/ccsrc/backend/common/graph_kernel/core/graph_kernel_pass_manager.cc @@ -1,67 +1,67 @@ -/** - * Copyright 2019-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "backend/common/graph_kernel/core/graph_kernel_pass_manager.h" - -#include - -#include "utils/log_adapter.h" - -namespace mindspore::graphkernel { -void GraphKernelPassManager::Add(const opt::PassPtr &pass, unsigned int pass_level, bool supported_device) { - MS_EXCEPTION_IF_NULL(pass); - auto pass_id = passes_.size(); - auto pass_name = pass->name(); - auto pass_in_list = [this, pass_id, &pass_name](const std::vector &pass_list) { - // the config format can be "stage_id.pass_id" or "stage_name.pass_name" - return std::find(pass_list.begin(), pass_list.end(), - std::to_string(this->stage_) + "." + std::to_string(pass_id)) != pass_list.end() || - std::find(pass_list.begin(), pass_list.end(), this->name_ + "." + pass_name) != pass_list.end(); - }; - bool enable = supported_device && flags_.opt_level >= pass_level; - if (enable) { - // if it meets the condition to enable, check whether it's in the disabled list. - enable = !pass_in_list(flags_.disable_pass); - } else { - // if it doesn't meet the condition to enable, check whether it's in the enabled list. - enable = pass_in_list(flags_.enable_pass); - } - passes_.push_back(pass); - enabled_.push_back(enable); -} - -std::string GraphKernelPassManager::GetPassFullname(size_t pass_id, const opt::PassPtr &pass) const { - return "stage" + std::to_string(stage_) + "_" + name() + "_" + std::to_string(pass_id) + "_" + pass->name(); -} - -bool GraphKernelPassManager::Run(const FuncGraphPtr &func_graph) const { - bool changed = false; - for (size_t i = 0; i < passes_.size(); i++) { - if (enabled_[i]) { - changed = RunPass(func_graph, i, passes_[i]) || changed; - // dump ir to a graph_kernel subdir, and set a global id in front of the filename - std::ostringstream oss; - static int g_id = 0; - constexpr int id_length = 4; - oss << "graph_kernel/" << std::setfill('0') << std::setw(id_length) << g_id++ << "_" - << GetPassFullname(i, passes_[i]); - DumpPassIR(func_graph, oss.str()); - } else { - MS_LOG(INFO) << "pass " << GetPassFullname(i, passes_[i]) << " is disabled."; - } - } - return changed; -} -} // namespace mindspore::graphkernel +/** + * Copyright 2019-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/common/graph_kernel/core/graph_kernel_pass_manager.h" + +#include + +#include "utils/log_adapter.h" + +namespace mindspore::graphkernel { +void GraphKernelPassManager::Add(const opt::PassPtr &pass, unsigned int pass_level, bool supported_device) { + MS_EXCEPTION_IF_NULL(pass); + auto pass_id = passes_.size(); + auto pass_name = pass->name(); + auto pass_in_list = [this, pass_id, &pass_name](const std::vector &pass_list) { + // the config format can be "stage_id.pass_id" or "stage_name.pass_name" + return std::find(pass_list.begin(), pass_list.end(), + std::to_string(this->stage_) + "." + std::to_string(pass_id)) != pass_list.end() || + std::find(pass_list.begin(), pass_list.end(), this->name_ + "." + pass_name) != pass_list.end(); + }; + bool enable = supported_device && flags_.opt_level >= pass_level; + if (enable) { + // if it meets the condition to enable, check whether it's in the disabled list. + enable = !pass_in_list(flags_.disable_pass); + } else { + // if it doesn't meet the condition to enable, check whether it's in the enabled list. + enable = pass_in_list(flags_.enable_pass); + } + passes_.push_back(pass); + enabled_.push_back(enable); +} + +std::string GraphKernelPassManager::GetPassFullname(size_t pass_id, const opt::PassPtr &pass) const { + return "stage" + std::to_string(stage_) + "_" + name() + "_" + std::to_string(pass_id) + "_" + pass->name(); +} + +bool GraphKernelPassManager::Run(const FuncGraphPtr &func_graph) const { + bool changed = false; + for (size_t i = 0; i < passes_.size(); i++) { + if (enabled_[i]) { + changed = RunPass(func_graph, i, passes_[i]) || changed; + // dump ir to a graph_kernel subdir, and set a global id in front of the filename + std::ostringstream oss; + static int g_id = 0; + constexpr int id_length = 4; + oss << "graph_kernel/" << std::setfill('0') << std::setw(id_length) << g_id++ << "_" + << GetPassFullname(i, passes_[i]); + DumpPassIR(func_graph, oss.str()); + } else { + MS_LOG(INFO) << "pass " << GetPassFullname(i, passes_[i]) << " is disabled."; + } + } + return changed; +} +} // namespace mindspore::graphkernel diff --git a/mindspore/ccsrc/backend/common/graph_kernel/core/graph_kernel_pass_manager.h b/mindspore/ccsrc/backend/common/graph_kernel/core/graph_kernel_pass_manager.h index d6e052499a3..6077b880c2f 100644 --- a/mindspore/ccsrc/backend/common/graph_kernel/core/graph_kernel_pass_manager.h +++ b/mindspore/ccsrc/backend/common/graph_kernel/core/graph_kernel_pass_manager.h @@ -1,49 +1,49 @@ -/** - * Copyright 2019-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_PASS_MANAGER_H_ -#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_PASS_MANAGER_H_ - -#include -#include -#include -#include - -#include "backend/common/graph_kernel/graph_kernel_flags.h" -#include "include/backend/optimizer/pass_manager.h" - -namespace mindspore::graphkernel { -using opt::PassManager; -class GraphKernelPassManager : public PassManager { - public: - GraphKernelPassManager(size_t stage, const std::string &name) - : PassManager(name, true), stage_(stage), flags_(GraphKernelFlags::GetInstance()) {} - ~GraphKernelPassManager() = default; - - // Add graph pass, the pass object will be freed when pass manager freed. - void Add(const opt::PassPtr &pass, unsigned int pass_level, bool supported_device = true); - - // Run passes on the func_graph - bool Run(const FuncGraphPtr &func_graph) const override; - - protected: - std::string GetPassFullname(size_t pass_id, const opt::PassPtr &pass) const override; - - size_t stage_; - std::vector enabled_; - const GraphKernelFlags &flags_; -}; -} // namespace mindspore::graphkernel -#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_PASS_MANAGER_H_ +/** + * Copyright 2019-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_PASS_MANAGER_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_PASS_MANAGER_H_ + +#include +#include +#include +#include + +#include "backend/common/graph_kernel/graph_kernel_flags.h" +#include "include/backend/optimizer/pass_manager.h" + +namespace mindspore::graphkernel { +using opt::PassManager; +class GraphKernelPassManager : public PassManager { + public: + GraphKernelPassManager(size_t stage, const std::string &name) + : PassManager(name, true), stage_(stage), flags_(GraphKernelFlags::GetInstance()) {} + ~GraphKernelPassManager() = default; + + // Add graph pass, the pass object will be freed when pass manager freed. + void Add(const opt::PassPtr &pass, unsigned int pass_level, bool supported_device = true); + + // Run passes on the func_graph + bool Run(const FuncGraphPtr &func_graph) const override; + + protected: + std::string GetPassFullname(size_t pass_id, const opt::PassPtr &pass) const override; + + size_t stage_; + std::vector enabled_; + const GraphKernelFlags &flags_; +}; +} // namespace mindspore::graphkernel +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_PASS_MANAGER_H_ diff --git a/mindspore/ccsrc/backend/common/graph_kernel/insert_pad.h b/mindspore/ccsrc/backend/common/graph_kernel/insert_pad.h index d0ea92b7d43..656fa179012 100644 --- a/mindspore/ccsrc/backend/common/graph_kernel/insert_pad.h +++ b/mindspore/ccsrc/backend/common/graph_kernel/insert_pad.h @@ -1,35 +1,35 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_INSERT_PAD_OPS_H_ -#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_INSERT_PAD_OPS_H_ - -#include -#include - -#include "include/backend/optimizer/pass.h" -#include "ir/func_graph.h" -#include "backend/common/graph_kernel/graph_kernel_helper.h" - -namespace mindspore::graphkernel { -class InsertPadOps : public opt::Pass { - public: - InsertPadOps() : Pass("insert_pad_ops") {} - ~InsertPadOps() override = default; - bool Run(const FuncGraphPtr &func_graph) override; -}; -using InsertPadOpsPtr = std::shared_ptr; -} // namespace mindspore::graphkernel -#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_INSERT_PAD_OPS_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_INSERT_PAD_OPS_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_INSERT_PAD_OPS_H_ + +#include +#include + +#include "include/backend/optimizer/pass.h" +#include "ir/func_graph.h" +#include "backend/common/graph_kernel/graph_kernel_helper.h" + +namespace mindspore::graphkernel { +class InsertPadOps : public opt::Pass { + public: + InsertPadOps() : Pass("insert_pad_ops") {} + ~InsertPadOps() override = default; + bool Run(const FuncGraphPtr &func_graph) override; +}; +using InsertPadOpsPtr = std::shared_ptr; +} // namespace mindspore::graphkernel +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_INSERT_PAD_OPS_H_ diff --git a/mindspore/ccsrc/backend/common/graph_kernel/model/lite_graph.cc b/mindspore/ccsrc/backend/common/graph_kernel/model/lite_graph.cc index 53ad41cfc69..12b55482e39 100644 --- a/mindspore/ccsrc/backend/common/graph_kernel/model/lite_graph.cc +++ b/mindspore/ccsrc/backend/common/graph_kernel/model/lite_graph.cc @@ -1,152 +1,152 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "backend/common/graph_kernel/model/lite_graph.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "utils/hash_map.h" -#include "backend/common/graph_kernel/model/node.h" -#include "backend/common/graph_kernel/model/op_node.h" -#include "backend/common/graph_kernel/model/op_register.h" - -namespace mindspore::graphkernel::inner { -std::string LiteGraph::ToString(bool reset_node_name) const { - if (reset_node_name) { - param_id_ = node_id_ = 0; - for (auto &inp : inputs_) { - inp->SetDebugName(ParamName()); - } - for (auto &node : ops_) { - node->SetDebugName(NodeName()); - } - } - std::ostringstream os; - os << name_ << "("; - for (size_t i = 0; i < inputs_.size(); i++) { - os << inputs_[i]->debug_name(); - if (i != inputs_.size() - 1) { - os << ", "; - } - } - os << ") -> "; - auto &outputs = GetOutputs(); - for (size_t i = 0; i < outputs.size(); i++) { - os << outputs[i]->debug_name(); - if (i != outputs.size() - 1) { - os << ", "; - } - } - os << " {\n"; - for (const NodePtr &op : ops_) { - os << " " << op->ToString() << "\n"; - } - os << "}"; - return os.str(); -} - -const NodePtrList &LiteGraph::GetOrderedNodes() { - mindspore::HashMap outdegrees; - std::function dfs; - std::set visited; - // record the out degree of each nodes by Dfs. - dfs = [&dfs, &outdegrees, &visited](const NodePtr &node) { - (void)visited.insert(node); - for (auto &input : node->inputs()) { - if (input->NodeType() == NType::Primitive) { - ++outdegrees[input]; - if (visited.count(input) == 0) { - dfs(input); - } - } - } - }; - dfs(output_); - NodePtrList res; - NodePtrList stack; - - // toposort algorithm with out degree - stack.push_back(output_); - while (!stack.empty()) { - auto cur = stack.back(); - stack.pop_back(); - res.push_back(cur); - for (auto &input : cur->inputs()) { - if (input->NodeType() != NType::Primitive) { - continue; - } - --outdegrees[input]; - if (outdegrees[input] == 0) { - stack.push_back(input); - (void)outdegrees.erase(input); - } - } - } - if (!outdegrees.empty()) { - MS_LOG(ERROR) << "Circle was found:"; - for (auto &node : outdegrees) { - MS_LOG(ERROR) << " " << node.first->debug_name(); - } - MS_LOG(EXCEPTION) << "Circle size: " << outdegrees.size(); - } - std::reverse(res.begin(), res.end()); - // remove the "OutputNode" - res.pop_back(); - ops_ = std::move(res); - return ops_; -} - -PrimOpPtr CreateOp(const std::string &op, const std::string &debug_name) { - auto node = OpRegistry::Instance().NewOp(op); - node->SetDebugName(debug_name); - return node; -} - -NodePtr LiteGraph::GraphBuilderBase::Emit(const std::string &op, const NodePtrList &inputs, const DAttrs &attrs) const { - PrimOpPtr op_ptr = CreateOp(op, graph_->NodeName()); - auto baseinfo = op_ptr->Infer(inputs, attrs); - op_ptr->SetInputs(inputs); - op_ptr->SetAttrs(attrs); - op_ptr->SetBaseInfo(baseinfo); - (void)graph_->ops_.emplace_back(op_ptr); - return op_ptr; -} - -NodePtr LiteGraph::GraphBuilderBase::Op(const std::string &op, const NodeBaseList &baseinfolist, - const NodePtrList &inputs, const DAttrs &attrs) const { - PrimOpPtr op_ptr = CreateOp(op, graph_->NodeName()); - op_ptr->SetInputs(inputs); - op_ptr->SetAttrs(attrs); - op_ptr->SetBaseInfo(baseinfolist); - (void)graph_->ops_.emplace_back(op_ptr); - return op_ptr; -} - -NodePtr LiteGraph::GraphBuilderBase::Op(const std::string &op, const NodeBase &baseinfo, const NodePtrList &inputs, - const DAttrs &attrs) const { - PrimOpPtr op_ptr = CreateOp(op, graph_->NodeName()); - op_ptr->SetInputs(inputs); - op_ptr->SetAttrs(attrs); - op_ptr->SetBaseInfo({baseinfo}); - (void)graph_->ops_.emplace_back(op_ptr); - return op_ptr; -} -} // namespace mindspore::graphkernel::inner +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/common/graph_kernel/model/lite_graph.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "utils/hash_map.h" +#include "backend/common/graph_kernel/model/node.h" +#include "backend/common/graph_kernel/model/op_node.h" +#include "backend/common/graph_kernel/model/op_register.h" + +namespace mindspore::graphkernel::inner { +std::string LiteGraph::ToString(bool reset_node_name) const { + if (reset_node_name) { + param_id_ = node_id_ = 0; + for (auto &inp : inputs_) { + inp->SetDebugName(ParamName()); + } + for (auto &node : ops_) { + node->SetDebugName(NodeName()); + } + } + std::ostringstream os; + os << name_ << "("; + for (size_t i = 0; i < inputs_.size(); i++) { + os << inputs_[i]->debug_name(); + if (i != inputs_.size() - 1) { + os << ", "; + } + } + os << ") -> "; + auto &outputs = GetOutputs(); + for (size_t i = 0; i < outputs.size(); i++) { + os << outputs[i]->debug_name(); + if (i != outputs.size() - 1) { + os << ", "; + } + } + os << " {\n"; + for (const NodePtr &op : ops_) { + os << " " << op->ToString() << "\n"; + } + os << "}"; + return os.str(); +} + +const NodePtrList &LiteGraph::GetOrderedNodes() { + mindspore::HashMap outdegrees; + std::function dfs; + std::set visited; + // record the out degree of each nodes by Dfs. + dfs = [&dfs, &outdegrees, &visited](const NodePtr &node) { + (void)visited.insert(node); + for (auto &input : node->inputs()) { + if (input->NodeType() == NType::Primitive) { + ++outdegrees[input]; + if (visited.count(input) == 0) { + dfs(input); + } + } + } + }; + dfs(output_); + NodePtrList res; + NodePtrList stack; + + // toposort algorithm with out degree + stack.push_back(output_); + while (!stack.empty()) { + auto cur = stack.back(); + stack.pop_back(); + res.push_back(cur); + for (auto &input : cur->inputs()) { + if (input->NodeType() != NType::Primitive) { + continue; + } + --outdegrees[input]; + if (outdegrees[input] == 0) { + stack.push_back(input); + (void)outdegrees.erase(input); + } + } + } + if (!outdegrees.empty()) { + MS_LOG(ERROR) << "Circle was found:"; + for (auto &node : outdegrees) { + MS_LOG(ERROR) << " " << node.first->debug_name(); + } + MS_LOG(EXCEPTION) << "Circle size: " << outdegrees.size(); + } + std::reverse(res.begin(), res.end()); + // remove the "OutputNode" + res.pop_back(); + ops_ = std::move(res); + return ops_; +} + +PrimOpPtr CreateOp(const std::string &op, const std::string &debug_name) { + auto node = OpRegistry::Instance().NewOp(op); + node->SetDebugName(debug_name); + return node; +} + +NodePtr LiteGraph::GraphBuilderBase::Emit(const std::string &op, const NodePtrList &inputs, const DAttrs &attrs) const { + PrimOpPtr op_ptr = CreateOp(op, graph_->NodeName()); + auto baseinfo = op_ptr->Infer(inputs, attrs); + op_ptr->SetInputs(inputs); + op_ptr->SetAttrs(attrs); + op_ptr->SetBaseInfo(baseinfo); + (void)graph_->ops_.emplace_back(op_ptr); + return op_ptr; +} + +NodePtr LiteGraph::GraphBuilderBase::Op(const std::string &op, const NodeBaseList &baseinfolist, + const NodePtrList &inputs, const DAttrs &attrs) const { + PrimOpPtr op_ptr = CreateOp(op, graph_->NodeName()); + op_ptr->SetInputs(inputs); + op_ptr->SetAttrs(attrs); + op_ptr->SetBaseInfo(baseinfolist); + (void)graph_->ops_.emplace_back(op_ptr); + return op_ptr; +} + +NodePtr LiteGraph::GraphBuilderBase::Op(const std::string &op, const NodeBase &baseinfo, const NodePtrList &inputs, + const DAttrs &attrs) const { + PrimOpPtr op_ptr = CreateOp(op, graph_->NodeName()); + op_ptr->SetInputs(inputs); + op_ptr->SetAttrs(attrs); + op_ptr->SetBaseInfo({baseinfo}); + (void)graph_->ops_.emplace_back(op_ptr); + return op_ptr; +} +} // namespace mindspore::graphkernel::inner diff --git a/mindspore/ccsrc/backend/common/graph_kernel/model/lite_graph.h b/mindspore/ccsrc/backend/common/graph_kernel/model/lite_graph.h index eceab8369f1..3494139ad36 100644 --- a/mindspore/ccsrc/backend/common/graph_kernel/model/lite_graph.h +++ b/mindspore/ccsrc/backend/common/graph_kernel/model/lite_graph.h @@ -1,85 +1,85 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_LITE_GRAPH_H_ -#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_LITE_GRAPH_H_ - -#include -#include -#include "backend/common/graph_kernel/model/node.h" - -namespace mindspore::graphkernel::inner { -class LiteGraph { - public: - class GraphBuilderBase; - explicit LiteGraph(const std::string &name = "") : name_(name), output_(new OutputNode()) {} - - const NodePtrList &GetOrderedNodes(); - std::string ToString(bool reset_node_name = false) const; - const std::string &name() const { return name_; } - const NodePtrList &ops() const { return ops_; } - const NodePtrList &inputs() const { return inputs_; } - const NodePtr &output(size_t i) const { return output_->input(i); } - const NodePtrList &GetOutputs() const { return output_->inputs(); } - - void SetOutput(size_t i, const NodePtr &node) { output_->SetInput(i, node); } - void SetOutputs(const NodePtrList &nodes) { output_->SetInputs(nodes); } - - protected: - std::string name_; - NodePtrList ops_; // save all operators in topo order - NodePtrList inputs_; - NodePtr output_; - - private: - std::string ParamName() const { return "input_" + std::to_string(param_id_++); } - std::string NodeName() const { return "output_" + std::to_string(node_id_++); } - mutable int param_id_{0}; - mutable int node_id_{0}; -}; -using LiteGraphPtr = std::shared_ptr; -class LiteGraph::GraphBuilderBase { - public: - explicit GraphBuilderBase(const std::string &name = "") { graph_ = std::make_shared(name); } - ~GraphBuilderBase() = default; - - // Create a parameter of graph - NodePtr Parameter(const NodeBase &baseinfo) const { - auto para = std::make_shared(baseinfo); - para->SetDebugName(graph_->ParamName()); - graph_->inputs_.push_back(para); - return para; - } - - // Create a const value node - NodePtr Value(const tensor::TensorPtr &data) const { return std::make_shared(data); } - - void SetOutputs(const NodePtrList &nodes) const { graph_->output_->SetInputs(nodes); } - - // Emit op, auto inferring the baseinfo of Node. - NodePtr Emit(const std::string &op, const NodePtrList &inputs, const DAttrs &attrs = {}) const; - - // Create op node with given baseinfo. - NodePtr Op(const std::string &op, const NodeBaseList &baseinfolist, const NodePtrList &inputs, - const DAttrs &attrs = {}) const; - NodePtr Op(const std::string &op, const NodeBase &baseinfo, const NodePtrList &inputs, - const DAttrs &attrs = {}) const; - LiteGraphPtr Get() const { return graph_; } - - private: - LiteGraphPtr graph_; -}; -} // namespace mindspore::graphkernel::inner -#endif +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_LITE_GRAPH_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_LITE_GRAPH_H_ + +#include +#include +#include "backend/common/graph_kernel/model/node.h" + +namespace mindspore::graphkernel::inner { +class LiteGraph { + public: + class GraphBuilderBase; + explicit LiteGraph(const std::string &name = "") : name_(name), output_(new OutputNode()) {} + + const NodePtrList &GetOrderedNodes(); + std::string ToString(bool reset_node_name = false) const; + const std::string &name() const { return name_; } + const NodePtrList &ops() const { return ops_; } + const NodePtrList &inputs() const { return inputs_; } + const NodePtr &output(size_t i) const { return output_->input(i); } + const NodePtrList &GetOutputs() const { return output_->inputs(); } + + void SetOutput(size_t i, const NodePtr &node) { output_->SetInput(i, node); } + void SetOutputs(const NodePtrList &nodes) { output_->SetInputs(nodes); } + + protected: + std::string name_; + NodePtrList ops_; // save all operators in topo order + NodePtrList inputs_; + NodePtr output_; + + private: + std::string ParamName() const { return "input_" + std::to_string(param_id_++); } + std::string NodeName() const { return "output_" + std::to_string(node_id_++); } + mutable int param_id_{0}; + mutable int node_id_{0}; +}; +using LiteGraphPtr = std::shared_ptr; +class LiteGraph::GraphBuilderBase { + public: + explicit GraphBuilderBase(const std::string &name = "") { graph_ = std::make_shared(name); } + ~GraphBuilderBase() = default; + + // Create a parameter of graph + NodePtr Parameter(const NodeBase &baseinfo) const { + auto para = std::make_shared(baseinfo); + para->SetDebugName(graph_->ParamName()); + graph_->inputs_.push_back(para); + return para; + } + + // Create a const value node + NodePtr Value(const tensor::TensorPtr &data) const { return std::make_shared(data); } + + void SetOutputs(const NodePtrList &nodes) const { graph_->output_->SetInputs(nodes); } + + // Emit op, auto inferring the baseinfo of Node. + NodePtr Emit(const std::string &op, const NodePtrList &inputs, const DAttrs &attrs = {}) const; + + // Create op node with given baseinfo. + NodePtr Op(const std::string &op, const NodeBaseList &baseinfolist, const NodePtrList &inputs, + const DAttrs &attrs = {}) const; + NodePtr Op(const std::string &op, const NodeBase &baseinfo, const NodePtrList &inputs, + const DAttrs &attrs = {}) const; + LiteGraphPtr Get() const { return graph_; } + + private: + LiteGraphPtr graph_; +}; +} // namespace mindspore::graphkernel::inner +#endif diff --git a/mindspore/ccsrc/backend/common/graph_kernel/model/node.cc b/mindspore/ccsrc/backend/common/graph_kernel/model/node.cc index 08f6a2048fd..508eb4f4642 100644 --- a/mindspore/ccsrc/backend/common/graph_kernel/model/node.cc +++ b/mindspore/ccsrc/backend/common/graph_kernel/model/node.cc @@ -1,135 +1,135 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include -#include "backend/common/graph_kernel/model/node.h" -#include "abstract/utils.h" - -namespace mindspore::graphkernel::inner { -ConstScalarNode::ConstScalarNode(const ValuePtr &data) - : Node({DShape({}), kNumberTypeEnd, kOpFormat_DEFAULT}), data_(data) { - auto type_ptr = data->ToAbstract()->BuildType(); - MS_EXCEPTION_IF_NULL(type_ptr); - type = type_ptr->type_id(); -} - -ConstTupleNode::ConstTupleNode(const ValuePtr &data, const size_t len) - : Node({DShape({SizeToLong(len)}), kNumberTypeEnd, kOpFormat_DEFAULT}), data_(data) { - auto type_ptr = data->ToAbstract()->BuildType(); - MS_EXCEPTION_IF_NULL(type_ptr); - type = type_ptr->type_id(); -} - -void Node::SetBaseInfo(const NodeBaseList &baseinfo) { - this->shape = baseinfo[0].shape; - this->type = baseinfo[0].type; - this->format = baseinfo[0].format; - this->symbolic_shape = baseinfo[0].symbolic_shape; - if (baseinfo.size() > 1) { - outputs_ = baseinfo; - } -} - -std::string Node::ToString() const { - std::ostringstream oss; - oss << debug_name() << "["; - for (size_t i = 0; i < shape.size(); i++) { - oss << shape[i]; - if (i + 1 < shape.size()) { - oss << ","; - } - } - auto type_str = (type == TypeId::kNumberTypeBegin) ? "NOTYPE" : TypeIdToString(type); - oss << "]{" << type_str << "x" << format << "}"; - return oss.str(); -} - -abstract::AbstractBasePtr Node::ToAbstract() const { - if (outputs_.empty()) { - return std::make_shared(TypeIdToType(this->type), this->shape); - } - AbstractBasePtrList abs_list(outputs_.size()); - (void)std::transform(outputs_.cbegin(), outputs_.cend(), abs_list.begin(), [](const NodeBase &node) { - return std::make_shared(TypeIdToType(node.type), node.shape); - }); - return std::make_shared(std::move(abs_list)); -} - -void Node::AddInput(const NodePtr &new_input) { - MS_EXCEPTION_IF_NULL(new_input); - new_input->AddUser(this, inputs_.size()); - (void)inputs_.emplace_back(new_input); -} - -void Node::SetInput(size_t i, const NodePtr &new_input) { - MS_EXCEPTION_IF_NULL(new_input); - if (i >= inputs_.size()) { - MS_LOG(EXCEPTION) << "The index " << i << " is out of the inputs range [0, " << inputs_.size() << ")"; - } - auto &old_input = inputs_[i]; - old_input->RemoveUser(this, i); - new_input->AddUser(this, i); - inputs_[i] = new_input; -} - -void Node::SetInputs(const NodePtrList &inputs) { - ClearInputs(); - inputs_.reserve(inputs.size()); - for (const auto &inp : inputs) { - AddInput(inp); - } -} - -void Node::ClearInputs() noexcept { - if (!inputs_.empty()) { - // remove the original inputs - for (size_t i = 0; i < inputs_.size(); i++) { - inputs_[i]->RemoveUser(this, i); - } - inputs_.clear(); - } -} - -void Node::ReplaceWith(const NodePtr &other_node) { - if (this->users_.empty()) { - return; - } - // the users_ will be changed, so we copy the users before traversal - auto users = this->users_; - for (auto &user : users) { - for (const auto &idx : user.second) { - user.first->SetInput(idx, other_node); - } - } -} - -void Node::RemoveUser(Node *const user, size_t index) { - if (auto iter = users_.find(user); iter != users_.end()) { - (void)iter->second.erase(index); - if (iter->second.empty()) { - (void)users_.erase(iter); - } - } -} - -size_t Node::tensor_size(bool in_bytes) const { - if (IsDynamic(this->shape)) { - return 0; - } - size_t size = LongToSize(abstract::ShapeSize(this->shape)); - return in_bytes ? abstract::TypeIdSize(this->type) * size : size; -} -} // namespace mindspore::graphkernel::inner +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "backend/common/graph_kernel/model/node.h" +#include "abstract/utils.h" + +namespace mindspore::graphkernel::inner { +ConstScalarNode::ConstScalarNode(const ValuePtr &data) + : Node({DShape({}), kNumberTypeEnd, kOpFormat_DEFAULT}), data_(data) { + auto type_ptr = data->ToAbstract()->BuildType(); + MS_EXCEPTION_IF_NULL(type_ptr); + type = type_ptr->type_id(); +} + +ConstTupleNode::ConstTupleNode(const ValuePtr &data, const size_t len) + : Node({DShape({SizeToLong(len)}), kNumberTypeEnd, kOpFormat_DEFAULT}), data_(data) { + auto type_ptr = data->ToAbstract()->BuildType(); + MS_EXCEPTION_IF_NULL(type_ptr); + type = type_ptr->type_id(); +} + +void Node::SetBaseInfo(const NodeBaseList &baseinfo) { + this->shape = baseinfo[0].shape; + this->type = baseinfo[0].type; + this->format = baseinfo[0].format; + this->symbolic_shape = baseinfo[0].symbolic_shape; + if (baseinfo.size() > 1) { + outputs_ = baseinfo; + } +} + +std::string Node::ToString() const { + std::ostringstream oss; + oss << debug_name() << "["; + for (size_t i = 0; i < shape.size(); i++) { + oss << shape[i]; + if (i + 1 < shape.size()) { + oss << ","; + } + } + auto type_str = (type == TypeId::kNumberTypeBegin) ? "NOTYPE" : TypeIdToString(type); + oss << "]{" << type_str << "x" << format << "}"; + return oss.str(); +} + +abstract::AbstractBasePtr Node::ToAbstract() const { + if (outputs_.empty()) { + return std::make_shared(TypeIdToType(this->type), this->shape); + } + AbstractBasePtrList abs_list(outputs_.size()); + (void)std::transform(outputs_.cbegin(), outputs_.cend(), abs_list.begin(), [](const NodeBase &node) { + return std::make_shared(TypeIdToType(node.type), node.shape); + }); + return std::make_shared(std::move(abs_list)); +} + +void Node::AddInput(const NodePtr &new_input) { + MS_EXCEPTION_IF_NULL(new_input); + new_input->AddUser(this, inputs_.size()); + (void)inputs_.emplace_back(new_input); +} + +void Node::SetInput(size_t i, const NodePtr &new_input) { + MS_EXCEPTION_IF_NULL(new_input); + if (i >= inputs_.size()) { + MS_LOG(EXCEPTION) << "The index " << i << " is out of the inputs range [0, " << inputs_.size() << ")"; + } + auto &old_input = inputs_[i]; + old_input->RemoveUser(this, i); + new_input->AddUser(this, i); + inputs_[i] = new_input; +} + +void Node::SetInputs(const NodePtrList &inputs) { + ClearInputs(); + inputs_.reserve(inputs.size()); + for (const auto &inp : inputs) { + AddInput(inp); + } +} + +void Node::ClearInputs() noexcept { + if (!inputs_.empty()) { + // remove the original inputs + for (size_t i = 0; i < inputs_.size(); i++) { + inputs_[i]->RemoveUser(this, i); + } + inputs_.clear(); + } +} + +void Node::ReplaceWith(const NodePtr &other_node) { + if (this->users_.empty()) { + return; + } + // the users_ will be changed, so we copy the users before traversal + auto users = this->users_; + for (auto &user : users) { + for (const auto &idx : user.second) { + user.first->SetInput(idx, other_node); + } + } +} + +void Node::RemoveUser(Node *const user, size_t index) { + if (auto iter = users_.find(user); iter != users_.end()) { + (void)iter->second.erase(index); + if (iter->second.empty()) { + (void)users_.erase(iter); + } + } +} + +size_t Node::tensor_size(bool in_bytes) const { + if (IsDynamic(this->shape)) { + return 0; + } + size_t size = LongToSize(abstract::ShapeSize(this->shape)); + return in_bytes ? abstract::TypeIdSize(this->type) * size : size; +} +} // namespace mindspore::graphkernel::inner diff --git a/mindspore/ccsrc/backend/common/graph_kernel/model/node.h b/mindspore/ccsrc/backend/common/graph_kernel/model/node.h index 8d570bf3acf..332a3b13684 100644 --- a/mindspore/ccsrc/backend/common/graph_kernel/model/node.h +++ b/mindspore/ccsrc/backend/common/graph_kernel/model/node.h @@ -1,167 +1,167 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_NODE_H_ -#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_NODE_H_ - -#include -#include -#include -#include -#include "ir/dtype/type_id.h" -#include "ir/anf.h" -#include "ir/value.h" -#include "ir/tensor.h" -#include "utils/hash_map.h" -#include "utils/shape_utils.h" -#include "include/common/utils/utils.h" -#include "include/backend/visible.h" -#include "mindspore/core/symbolic_shape/symbol.h" - -namespace mindspore::graphkernel::inner { -enum class NType { - Base, - Primitive, - Parameter, - Tensor, - Scalar, - Tuple, - Output, -}; - -using DFormat = std::string; -using DShape = ShapeVector; -using DAttrs = mindspore::HashMap; - -struct BACKEND_EXPORT NodeBase { - DShape shape; - TypeId type; - DFormat format; - ListSymbolPtr symbolic_shape{nullptr}; -}; -using NodeBaseList = std::vector; - -class BACKEND_EXPORT Node; -using NodePtr = std::shared_ptr; -using NodePtrList = std::vector; -class BACKEND_EXPORT Node : public NodeBase, public std::enable_shared_from_this { - public: - explicit Node(const NodeBase &baseinfo) : NodeBase(baseinfo) {} - virtual ~Node() { ClearInputs(); } // remove this node from the previous nodes' user. - - virtual NType NodeType() { return NType::Base; } - virtual std::string ToString() const; - virtual abstract::AbstractBasePtr ToAbstract() const; - - virtual void SetBaseInfo(const NodeBaseList &baseinfo); - void AddInput(const NodePtr &new_input); - void SetInput(size_t i, const NodePtr &new_input); - void SetInputs(const NodePtrList &inputs); - void ClearInputs() noexcept; - void ReplaceWith(const NodePtr &other_node); - void SetAttrs(const DAttrs &attrs) { attrs_ = attrs; } - void SetAttr(const std::string &key, const ValuePtr &value) { attrs_[key] = value; } - void SetDebugName(const std::string &debug_name) { debug_name_ = debug_name; } - - template - std::shared_ptr As() { - return std::static_pointer_cast(shared_from_this()); - } - - const std::string &debug_name() const { return debug_name_; } - const DAttrs &attrs() const { return attrs_; } - const NodePtr &input(size_t i) const { return inputs_[i]; } - const NodePtrList &inputs() const { return inputs_; } - const mindspore::HashMap> &users() const { return users_; } - size_t tensor_size(bool in_bytes = false) const; - const NodeBaseList &outputs() const { return outputs_; } - - protected: - // only used in Dump function - mutable std::string debug_name_; - DAttrs attrs_; - NodePtrList inputs_; - // {user_node: {input edge index set}} - mindspore::HashMap> users_; - // save output tensor info when the node is a multi-output operator. - // it should keep empty when the node is single-output. - NodeBaseList outputs_; - - private: - // the nodes' users are only maintained by AddInput/SetInput. - void AddUser(Node *const user, size_t index) { (void)users_[user].insert(index); } - void RemoveUser(Node *const user, size_t index); -}; - -class BACKEND_EXPORT ConstTensorNode : public Node { - public: - explicit ConstTensorNode(const tensor::TensorPtr &data) - : Node({data->DataSize() == 1 ? DShape({1}) : data->shape(), data->data_type(), kOpFormat_DEFAULT}), - data_(data) {} - ~ConstTensorNode() = default; - - NType NodeType() override { return NType::Tensor; } - std::string ToString() const override { return data_->data().ToString(data_->data_type(), data_->shape(), false); } - const tensor::TensorPtr data() const { return data_; } - abstract::AbstractBasePtr ToAbstract() const override { return data_->ToAbstract(); } - - protected: - tensor::TensorPtr data_; -}; - -class ConstScalarNode : public Node { - public: - explicit ConstScalarNode(const ValuePtr &data); - ~ConstScalarNode() = default; - - NType NodeType() override { return NType::Scalar; } - const ValuePtr data() const { return data_; } - abstract::AbstractBasePtr ToAbstract() const override { return data_->ToAbstract(); } - - protected: - ValuePtr data_; -}; - -class ConstTupleNode : public Node { - public: - explicit ConstTupleNode(const ValuePtr &data, const size_t len); - ~ConstTupleNode() = default; - - NType NodeType() override { return NType::Tuple; } - const ValuePtr data() const { return data_; } - abstract::AbstractBasePtr ToAbstract() const override { return data_->ToAbstract(); } - - protected: - ValuePtr data_; -}; - -class ParamNode : public Node { - public: - explicit ParamNode(const NodeBase &baseinfo) : Node(baseinfo) {} - ~ParamNode() = default; - - NType NodeType() override { return NType::Parameter; } -}; - -// the OutputNode's inputs are the real outputs of graph, like the `make_tuple` in FuncGraph. -class OutputNode : public Node { - public: - OutputNode() : Node({{1}, TypeId::kNumberTypeBegin, kOpFormat_DEFAULT}) { debug_name_ = "Output"; } - ~OutputNode() = default; - - NType NodeType() override { return NType::Output; } -}; -} // namespace mindspore::graphkernel::inner -#endif +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_NODE_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_NODE_H_ + +#include +#include +#include +#include +#include "ir/dtype/type_id.h" +#include "ir/anf.h" +#include "ir/value.h" +#include "ir/tensor.h" +#include "utils/hash_map.h" +#include "utils/shape_utils.h" +#include "include/common/utils/utils.h" +#include "include/backend/visible.h" +#include "mindspore/core/symbolic_shape/symbol.h" + +namespace mindspore::graphkernel::inner { +enum class NType { + Base, + Primitive, + Parameter, + Tensor, + Scalar, + Tuple, + Output, +}; + +using DFormat = std::string; +using DShape = ShapeVector; +using DAttrs = mindspore::HashMap; + +struct BACKEND_EXPORT NodeBase { + DShape shape; + TypeId type; + DFormat format; + ListSymbolPtr symbolic_shape{nullptr}; +}; +using NodeBaseList = std::vector; + +class BACKEND_EXPORT Node; +using NodePtr = std::shared_ptr; +using NodePtrList = std::vector; +class BACKEND_EXPORT Node : public NodeBase, public std::enable_shared_from_this { + public: + explicit Node(const NodeBase &baseinfo) : NodeBase(baseinfo) {} + virtual ~Node() { ClearInputs(); } // remove this node from the previous nodes' user. + + virtual NType NodeType() { return NType::Base; } + virtual std::string ToString() const; + virtual abstract::AbstractBasePtr ToAbstract() const; + + virtual void SetBaseInfo(const NodeBaseList &baseinfo); + void AddInput(const NodePtr &new_input); + void SetInput(size_t i, const NodePtr &new_input); + void SetInputs(const NodePtrList &inputs); + void ClearInputs() noexcept; + void ReplaceWith(const NodePtr &other_node); + void SetAttrs(const DAttrs &attrs) { attrs_ = attrs; } + void SetAttr(const std::string &key, const ValuePtr &value) { attrs_[key] = value; } + void SetDebugName(const std::string &debug_name) { debug_name_ = debug_name; } + + template + std::shared_ptr As() { + return std::static_pointer_cast(shared_from_this()); + } + + const std::string &debug_name() const { return debug_name_; } + const DAttrs &attrs() const { return attrs_; } + const NodePtr &input(size_t i) const { return inputs_[i]; } + const NodePtrList &inputs() const { return inputs_; } + const mindspore::HashMap> &users() const { return users_; } + size_t tensor_size(bool in_bytes = false) const; + const NodeBaseList &outputs() const { return outputs_; } + + protected: + // only used in Dump function + mutable std::string debug_name_; + DAttrs attrs_; + NodePtrList inputs_; + // {user_node: {input edge index set}} + mindspore::HashMap> users_; + // save output tensor info when the node is a multi-output operator. + // it should keep empty when the node is single-output. + NodeBaseList outputs_; + + private: + // the nodes' users are only maintained by AddInput/SetInput. + void AddUser(Node *const user, size_t index) { (void)users_[user].insert(index); } + void RemoveUser(Node *const user, size_t index); +}; + +class BACKEND_EXPORT ConstTensorNode : public Node { + public: + explicit ConstTensorNode(const tensor::TensorPtr &data) + : Node({data->DataSize() == 1 ? DShape({1}) : data->shape(), data->data_type(), kOpFormat_DEFAULT}), + data_(data) {} + ~ConstTensorNode() = default; + + NType NodeType() override { return NType::Tensor; } + std::string ToString() const override { return data_->data().ToString(data_->data_type(), data_->shape(), false); } + const tensor::TensorPtr data() const { return data_; } + abstract::AbstractBasePtr ToAbstract() const override { return data_->ToAbstract(); } + + protected: + tensor::TensorPtr data_; +}; + +class ConstScalarNode : public Node { + public: + explicit ConstScalarNode(const ValuePtr &data); + ~ConstScalarNode() = default; + + NType NodeType() override { return NType::Scalar; } + const ValuePtr data() const { return data_; } + abstract::AbstractBasePtr ToAbstract() const override { return data_->ToAbstract(); } + + protected: + ValuePtr data_; +}; + +class ConstTupleNode : public Node { + public: + explicit ConstTupleNode(const ValuePtr &data, const size_t len); + ~ConstTupleNode() = default; + + NType NodeType() override { return NType::Tuple; } + const ValuePtr data() const { return data_; } + abstract::AbstractBasePtr ToAbstract() const override { return data_->ToAbstract(); } + + protected: + ValuePtr data_; +}; + +class ParamNode : public Node { + public: + explicit ParamNode(const NodeBase &baseinfo) : Node(baseinfo) {} + ~ParamNode() = default; + + NType NodeType() override { return NType::Parameter; } +}; + +// the OutputNode's inputs are the real outputs of graph, like the `make_tuple` in FuncGraph. +class OutputNode : public Node { + public: + OutputNode() : Node({{1}, TypeId::kNumberTypeBegin, kOpFormat_DEFAULT}) { debug_name_ = "Output"; } + ~OutputNode() = default; + + NType NodeType() override { return NType::Output; } +}; +} // namespace mindspore::graphkernel::inner +#endif diff --git a/mindspore/ccsrc/backend/common/graph_kernel/model/op_node.cc b/mindspore/ccsrc/backend/common/graph_kernel/model/op_node.cc index b05fb8eedee..ae8f6b40f36 100644 --- a/mindspore/ccsrc/backend/common/graph_kernel/model/op_node.cc +++ b/mindspore/ccsrc/backend/common/graph_kernel/model/op_node.cc @@ -1,1288 +1,1288 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "backend/common/graph_kernel/model/op_node.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "abstract/ops/primitive_infer_map.h" -#include "utils/anf_utils.h" -#include "utils/hash_map.h" -#include "utils/check_convert_utils.h" -#include "backend/common/graph_kernel/core/graph_kernel_utils.h" -#include "backend/common/graph_kernel/model/node.h" -#include "backend/operator/ops_backend_infer_function.h" -#include "utils/log_adapter.h" -#include "ops/auto_generate/gen_ops_primitive.h" - -namespace mindspore::graphkernel::inner { -std::vector GetListInt(const ValuePtr &attr_value) { - std::vector list_int; - const auto &vals = attr_value->cast()->value(); - (void)std::transform(vals.begin(), vals.end(), std::back_inserter(list_int), - [](const ValuePtr &v) { return AnfUtils::GetIntValue(v); }); - return list_int; -} - -BaseShapePtr InferShapeWithAbstract(const PrimitivePtr &prim, const AbstractBasePtrList &abs_list) { - auto shape_optional = abstract::InferShapeByFuncImpl(prim, abs_list, true); - if (shape_optional.has_value()) { - return shape_optional.value(); - } - - auto found = abstract::GetBackendPrimitiveInferImpl(prim); - if (found.has_value()) { - auto infer = found.value(); - if (infer.IsImplInferShapeAndType()) { - return infer.InferShape(prim, abs_list); - } - } - MS_LOG(EXCEPTION) << "The infer function of [" << prim->name() << "] is not defined."; - return nullptr; -} - -TypePtr InferTypeWithAbstract(const PrimitivePtr &prim, const AbstractBasePtrList &abs_list) { - auto type_optional = abstract::InferTypeByFuncImpl(prim, abs_list, true); - if (type_optional.has_value()) { - return type_optional.value(); - } - - auto found = abstract::GetBackendPrimitiveInferImpl(prim); - if (found.has_value()) { - auto infer = found.value(); - if (infer.IsImplInferShapeAndType()) { - return infer.InferType(prim, abs_list); - } - } - MS_LOG(EXCEPTION) << "The infer function of [" << prim->name() << "] is not defined."; - return nullptr; -} - -tensor::TensorPtr InferValueWithAbstract(const PrimitivePtr &prim, const AbstractBasePtrList &abs_list) { - auto value_optional = abstract::InferValueByFuncImpl(prim, abs_list); - if (value_optional.has_value()) { - return std::static_pointer_cast(value_optional.value()); - } - - auto found = abstract::GetBackendPrimitiveInferImpl(prim); - if (found.has_value()) { - auto infer = found.value(); - if (infer.IsImplInferValue()) { - return std::static_pointer_cast(infer.InferValue(prim, abs_list)); - } - } - return nullptr; -} - -std::pair PrimOp::GenPrimAndAbstract(const NodePtrList &inputs, - const DAttrs &attrs) const { - auto prim = std::make_shared(op_); - MS_EXCEPTION_IF_NULL(prim); - (void)prim->SetAttrs(attrs); - AbstractBasePtrList abs_list(inputs.size()); - (void)std::transform(inputs.cbegin(), inputs.cend(), abs_list.begin(), - [](const NodePtr &node) { return node->ToAbstract(); }); - return std::make_pair(prim, abs_list); -} - -std::vector PrimOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { - auto [prim, abs_list] = GenPrimAndAbstract(inputs, attrs); - RectifyAbstract(prim, &abs_list); - auto baseshape = InferShapeWithAbstract(prim, abs_list); - MS_EXCEPTION_IF_NULL(baseshape); - if (baseshape->isa()) { - auto tuple_shape = baseshape->cast(); - MS_EXCEPTION_IF_NULL(tuple_shape); - const auto &shape_elements = tuple_shape->shape(); - std::vector result(shape_elements.size()); - (void)std::transform(shape_elements.cbegin(), shape_elements.cend(), result.begin(), - [](const BaseShapePtr &s) { return s->cast()->shape(); }); - return result; - } - auto shape = baseshape->cast(); - if (shape != nullptr) { - return {shape->shape()}; - } - return {DShape()}; -} - -std::vector PrimOp::InferType(const NodePtrList &inputs, const DAttrs &attrs) { - auto [prim, abs_list] = GenPrimAndAbstract(inputs, attrs); - RectifyAbstract(prim, &abs_list); - auto type = InferTypeWithAbstract(prim, abs_list); - MS_EXCEPTION_IF_NULL(type); - auto get_type_id = [](const TypePtr &t) { - return t->isa() ? t->cast()->element()->type_id() : t->type_id(); - }; - if (type->isa()) { - auto elements = type->cast()->elements(); - std::vector result(elements.size()); - (void)std::transform(elements.cbegin(), elements.cend(), result.begin(), get_type_id); - return result; - } - return {get_type_id(type)}; -} - -NodeBaseList PrimOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) { - Check(inputs, attrs); - NodeBaseList result; - auto format = InferFormat(inputs, attrs); - auto shapes = InferShape(inputs, attrs); - auto types = InferType(inputs, attrs); - if (shapes.size() != types.size()) { - MS_LOG(EXCEPTION) << "The num of shapes and types should be equal. (" << shapes.size() << " vs " << types.size() - << ")"; - } - for (size_t i = 0; i < shapes.size(); i++) { - (void)result.emplace_back(NodeBase{shapes[i], types[i], format}); - } - return result; -} - -std::string PrimOp::ToString() const { - std::ostringstream oss; - oss << Node::ToString(); - oss << " = " << this->op_ << "("; - for (size_t i = 0; i < inputs_.size(); i++) { - if (inputs_[i]->NodeType() == NType::Primitive) { - oss << inputs_[i]->Node::ToString(); - } else { - oss << inputs_[i]->ToString(); - } - if (i != inputs_.size() - 1) { - oss << ", "; - } - } - oss << ")"; - std::ostringstream attr_oss; - bool has_attr = false; - std::set black_list = {"IsFeatureMapInputList", "IsFeatureMapOutput", "output_names", "input_names"}; - for (auto attr : attrs_) { - if (attr.second != nullptr && black_list.count(attr.first) == 0) { - if (has_attr) { - attr_oss << ", "; - } else { - has_attr = true; - } - attr_oss << attr.first << ": " << attr.second->ToString(); - } - } - if (has_attr) { - oss << " // attr {" << attr_oss.str() << "}"; - } - return oss.str(); -} - -template -std::vector ChangeDataToVec(const NodePtr &n) { - std::vector res; - TD *data = static_cast(std::static_pointer_cast(n)->data()->data_c()); - for (size_t elem = 0; elem < n->tensor_size(); elem++) { - res.push_back(static_cast(*(data + elem))); - } - return res; -} - -template -tensor::TensorPtr PrimOp::CalcByOperator(const NodePtrList &inputs, const DAttrs &) const { - const size_t unary_input_num = 1; - const size_t binary_input_num = 2; - if (inputs.size() > 0) { - bool all_shape_equal = - std::all_of(inputs.begin(), inputs.end(), [&inputs](const NodePtr &t) { return t->shape == inputs[0]->shape; }); - if (!all_shape_equal) { - return nullptr; - } - } - std::vector> inputs_tm; - const auto &op = this->op(); - const auto tid = this->type; - for (const auto &t : inputs) { - (void)inputs_tm.emplace_back(ChangeDataToVec(t)); - } - if (inputs.size() == unary_input_num) { - mindspore::HashMap> func_map = { - {"Abs", [](const TM &a) { return a <= TM(0) ? -a : a; }}, - {"Exp", [](const TM &a) { return exp(a); }}, - {"Log", [](const TM &a) { return log(a); }}, - {"Neg", [](const TM &a) { return -a; }}, - {"Reciprocal", - [](const TM &a) { - if (a == TM(0)) { - MS_LOG(EXCEPTION) << "During graph kernel constant fold for reciprocal, divisor is zero."; - } - return TM(1) / a; - }}, - {"Rsqrt", - [](const TM &a) { - if (a == TM(0)) { - MS_LOG(EXCEPTION) << "During graph kernel constant fold for rsqrt, divisor is zero."; - } - return TM(1) / sqrt(a); - }}, - {"Sqrt", [](const TM &a) { return sqrt(a); }}, - }; - if (func_map.find(op) == func_map.end()) { - return nullptr; - } - const auto &input_a = inputs_tm[0]; - std::vector res; - (void)std::transform(input_a.begin(), input_a.end(), std::back_inserter(res), - [&func_map, &op](const TM &i) { return func_map[op](i); }); - return std::make_shared(tid, this->shape, &res[0], tid); - } else if (inputs.size() == binary_input_num) { - mindspore::HashMap> func_map = { - {"Add", [](const TM &a, const TM &b) { return a + b; }}, - {"Sub", [](const TM &a, const TM &b) { return a - b; }}, - {"Mul", [](const TM &a, const TM &b) { return a * b; }}, - {"RealDiv", - [](const TM &a, const TM &b) { - if (b == TM(0)) { - MS_LOG(EXCEPTION) << "During graph kernel constant fold for realdiv, divisor is zero."; - } - return a / b; - }}, - }; - if (func_map.find(op) == func_map.end()) { - return nullptr; - } - const auto &input_a = inputs_tm[0]; - const auto &input_b = inputs_tm[1]; - std::vector res; - for (size_t i = 0; i < input_a.size(); i++) { - (void)res.emplace_back(func_map[op](input_a[i], input_b[i])); - } - return std::make_shared(tid, this->shape, &res[0], tid); - } - return nullptr; -} - -NodePtr PrimOp::InferValue(const NodePtrList &inputs, const DAttrs &attrs) { - for (auto i : inputs) { - if (i->NodeType() != NType::Tensor) { - return nullptr; - } - } - TypeId output_type = this->type; - tensor::TensorPtr res = nullptr; - switch (static_cast(output_type)) { - case TypeId::kNumberTypeUInt8: { - res = CalcByOperator(inputs, attrs); - break; - } - case TypeId::kNumberTypeInt8: { - res = CalcByOperator(inputs, attrs); - break; - } - case TypeId::kNumberTypeInt16: { - res = CalcByOperator(inputs, attrs); - break; - } - case TypeId::kNumberTypeInt32: { - res = CalcByOperator(inputs, attrs); - break; - } - case TypeId::kNumberTypeInt64: { - res = CalcByOperator(inputs, attrs); - break; - } - case TypeId::kNumberTypeUInt16: { - res = CalcByOperator(inputs, attrs); - break; - } - case TypeId::kNumberTypeUInt32: { - res = CalcByOperator(inputs, attrs); - break; - } - case TypeId::kNumberTypeUInt64: { - res = CalcByOperator(inputs, attrs); - break; - } - case TypeId::kNumberTypeFloat16: { - res = CalcByOperator(inputs, attrs); - break; - } - case TypeId::kNumberTypeFloat32: { - res = CalcByOperator(inputs, attrs); - break; - } - case TypeId::kNumberTypeFloat64: { - res = CalcByOperator(inputs, attrs); - break; - } - case TypeId::kNumberTypeBFloat16: { - res = CalcByOperator(inputs, attrs); - break; - } - default: - return nullptr; - } - if (res == nullptr) { - auto [prim, inputs_abstract] = GenPrimAndAbstract(inputs, attrs); - RectifyAbstract(prim, &inputs_abstract); - res = InferValueWithAbstract(prim, inputs_abstract); - } - return res == nullptr ? nullptr : std::make_shared(res); -} - -NodePtr ReshapeOp::InferValue(const NodePtrList &inputs, const DAttrs &) { - if (inputs[0]->NodeType() != NType::Tensor) { - return nullptr; - } - void *tensor_data = inputs[0]->As()->data()->data_c(); - tensor::TensorPtr result_tensor = std::make_shared(this->type, this->shape, tensor_data, this->type); - return std::make_shared(result_tensor); -} - -// default format shape to fractal_Nz format shape -DShape ToNz(const DShape &default_shape) { - constexpr size_t nz_size = 2; - constexpr auto align16 = 16; - auto len = default_shape.size(); - DShape leading_shape; - DShape tail_shape; - if (default_shape.size() == 1 && default_shape[0] == 1) { - // # As shape (1,) can broadcast to any shape, it can be regarded as a special FractalNZ shape - return default_shape; - } - if (default_shape.size() > nz_size) { - (void)leading_shape.insert(leading_shape.cend(), default_shape.cbegin(), - default_shape.cend() - SizeToLong(nz_size)); - } - if (default_shape.size() == 1 || (default_shape.size() >= nz_size && default_shape[len - nz_size] == 1)) { - // (32) or (N, 1, 32) -> (N, 2, 1, 1, 16) - if (default_shape.back() % align16 != 0) { - MS_LOG(EXCEPTION) << "default_shape[-1] should be multiplies of 16, but got " << default_shape.back(); - } - tail_shape = {default_shape.back() / align16, 1, 1, align16}; - } else if (default_shape.size() >= nz_size || default_shape[1] == 1) { - // (N, 32, 1) -> (N, 1, 2, 16, 1) - if (default_shape[len - nz_size] % align16 != 0) { - MS_LOG(EXCEPTION) << "default_shape[-2] should be multiplies of 16, but got " << default_shape[len - nz_size]; - } - tail_shape = {1, default_shape[0] / align16, align16, 1}; - } else { - // (N, 32, 48) -> (N, 3, 2, 16, 16) - if (default_shape.back() % align16 != 0 || default_shape[len - nz_size] % align16 != 0) { - MS_LOG(EXCEPTION) << "default_shape[-1] and default_shape[-2]should be multiplies of 16, but got " - << default_shape.back() << " " << default_shape[len - nz_size]; - } - tail_shape = {default_shape[1] / align16, default_shape[0] / align16, align16, align16}; - } - (void)leading_shape.insert(leading_shape.cend(), tail_shape.cbegin(), tail_shape.cend()); - return leading_shape; -} - -DShape BroadcastShape(const NodePtrList &inputs, bool to_nz = false) { - std::vector> shapes; - for (auto &input : inputs) { - if (to_nz && input->format != kOpFormat_FRAC_NZ) { - (void)shapes.emplace_back(ToNz(input->shape)); - } else { - (void)shapes.emplace_back(input->shape); - } - } - auto max_dim_input = - std::max_element(shapes.begin(), shapes.end(), - [](const std::vector &a, const std::vector &b) { return a.size() < b.size(); }); - auto max_dim = max_dim_input->size(); - std::vector> align_shapes; - for (auto &s : shapes) { - std::vector cur(max_dim - s.size(), 1); - (void)cur.insert(cur.cend(), s.cbegin(), s.cend()); - (void)align_shapes.emplace_back(cur); - } - std::vector output_shape(max_dim, 1); - for (size_t i = 0; i < max_dim; i++) { - for (auto &align_shape : align_shapes) { - if (align_shape[i] > 1) { - if (output_shape[i] == 1) { - output_shape[i] = align_shape[i]; - } - if (output_shape[i] != align_shape[i]) { - MS_LOG(EXCEPTION) << "Shape broadcast failed: " << output_shape[i] << " vs " << align_shape[i]; - } - } - } - } - return output_shape; -} - -std::vector ElemwiseOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { - if (std::any_of(inputs.begin(), inputs.end(), - [](const NodePtr &input) { return input->format == kOpFormat_FRAC_NZ; })) { - return {BroadcastShape(inputs, true)}; - } - return PrimOp::InferShape(inputs, attrs); -} - -DFormat ElemwiseOp::InferFormat(const NodePtrList &inputs, const DAttrs &) { - if (inputs.empty()) { - return kOpFormat_DEFAULT; - } - auto first_format = inputs[0]->format; - for (const auto &inp : inputs) { - auto cur_format = inp->format; - if (cur_format.find("FRACTAL") != std::string::npos) { - // special format - return cur_format; - } - if (cur_format != kOpFormat_DEFAULT && inp->tensor_size() != 1) { - return cur_format; - } - } - return first_format; -} - -std::vector ArgReduceOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { - CHECK_ATTR(attrs, "axis"); - auto axis = GetListInt(attrs.find("axis")->second); - const auto &input_shape = inputs[0]->shape; - int64_t size = SizeToLong(input_shape.size()); - std::vector real_axis; - (void)std::transform(axis.begin(), axis.end(), std::back_inserter(real_axis), - [&size](const int64_t &x) { return x < 0 ? (x + size) : x; }); - - DShape new_shape; - for (size_t i = 0; i < input_shape.size(); i++) { - if (std::find(real_axis.begin(), real_axis.end(), SizeToLong(i)) == real_axis.end()) { - (void)new_shape.emplace_back(input_shape[i]); - } - } - if (new_shape.empty()) { - (void)new_shape.emplace_back(1); - } - return {new_shape}; -} - -std::vector ArgReduceOp::InferType(const NodePtrList &, const DAttrs &attrs) { - CHECK_ATTR(attrs, "output_type"); - return {attrs.find("output_type")->second->cast()->type_id()}; -} - -DFormat TransposeOp::InferFormat(const NodePtrList &inputs, const DAttrs &attrs) { - if (attrs.count(kAttrDstFormat) != 0) { - return GetValue(attrs.find(kAttrDstFormat)->second); - } - // only support NCHW/NHWC now - constexpr size_t kRank4 = 4; - if (inputs[0]->shape.size() != kRank4) { - return kOpFormat_DEFAULT; - } - auto perm_node = inputs[1]; - auto perm_tensor = perm_node->As()->data(); - auto perm = CheckAndConvertUtils::CheckTensorIntValue("permutation", perm_tensor, "Transpose"); - const auto &ori_format = inputs[0]->format; - if (ori_format == kOpFormat_DEFAULT || ori_format == kOpFormat_NCHW) { - std::vector nchw2nhwc = {0, 2, 3, 1}; - if (perm == nchw2nhwc) { - return kOpFormat_NHWC; - } - } else if (ori_format == kOpFormat_NHWC) { - std::vector nhwc2nchw = {0, 3, 1, 2}; - if (perm == nhwc2nchw) { - return kOpFormat_NCHW; - } - } - return kOpFormat_DEFAULT; -} - -NodePtr ConstantOfShapeOp::InferValue(const NodePtrList &inputs, const DAttrs &attrs) { - for (auto i : inputs) { - if (i->NodeType() != NType::Tensor) { - return nullptr; - } - } - const auto &value = GetValue>(attrs.find("value")->second); - std::vector res; - size_t elem_num = LongToSize(std::accumulate(this->shape.begin(), this->shape.end(), 1, std::multiplies())); - if (value.size() == 1) { - res = std::vector(elem_num, value[0]); - } else if (value.size() == elem_num) { - res = value; - } else { - return nullptr; - } - auto tensor = std::make_shared(this->type, this->shape, &res[0], kNumberTypeFloat32); - return std::make_shared(tensor); -} - -std::vector ConstantOfShapeOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { - const auto &value = attrs.find("shape")->second; - std::vector res; - if (value->isa()) { - res = GetValue>(value); - return {res}; - } else if (value->isa()) { - auto tvalue = value->cast(); - if (tvalue->data_type_c() == static_cast(TypeId::kNumberTypeInt32)) { - int *data = static_cast(tvalue->data_c()); - for (size_t elem = 0; elem < tvalue->DataSize(); elem++) { - res.push_back(IntToLong(*(data + elem))); - } - return {res}; - } else if (tvalue->data_type_c() == static_cast(TypeId::kNumberTypeInt64)) { - int64_t *data = static_cast(tvalue->data_c()); - res = std::vector(data, data + tvalue->DataSize()); - return {res}; - } - } - return PrimOp::InferShape(inputs, attrs); -} - -NodePtr ShapeOp::InferValue(const NodePtrList &inputs, const DAttrs &) { - auto tensor = std::make_shared(this->type, this->shape, inputs[0]->shape.data(), kNumberTypeInt64); - return std::make_shared(tensor); -} - -std::vector PadAkgOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { - std::vector shape0 = inputs[0]->shape; - size_t n = shape0.size(); - CHECK_ATTR(attrs, "head"); - CHECK_ATTR(attrs, "tail"); - std::vector pad_before = GetListInt(attrs.find("head")->second); - std::vector pad_after = GetListInt(attrs.find("tail")->second); - if (pad_before.size() != n || pad_after.size() != n) { - MS_LOG(EXCEPTION) << "Input dimension and pad mismatch: " << n << " vs " << pad_before.size() << " vs " - << pad_after.size(); - } - std::vector output; - for (size_t i = 0; i < n; i++) { - (void)output.emplace_back(shape0[i] + pad_before[i] + pad_after[i]); - } - return {output}; -} - -std::vector UnPadAkgOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { - std::vector shape0 = inputs[0]->shape; - size_t n = shape0.size(); - CHECK_ATTR(attrs, "tail"); - std::vector unpad_after = GetListInt(attrs.find("tail")->second); - if (unpad_after.size() != n) { - MS_LOG(EXCEPTION) << "Input dimension and pad mismatch: " << n << " vs " << unpad_after.size(); - } - std::vector output; - for (size_t i = 0; i < n; i++) { - (void)output.emplace_back(shape0[i] - unpad_after[i]); - } - return {output}; -} - -bool Conv2dOp::HadPad(const ShapeVector &pad_list, const std::string &pad_mode) { - constexpr size_t kTop = 0; - constexpr size_t kBottom = 1; - constexpr size_t kLeft = 2; - constexpr size_t kRight = 3; - - if (pad_list[kTop] != pad_list[kBottom] || pad_list[kLeft] != pad_list[kRight]) { - return true; - } - if (pad_mode != "VALID" && pad_mode != "valid") { - return std::any_of(pad_list.begin(), pad_list.end(), [](auto a) { return a != 0; }); - } - return false; -} - -std::vector Conv2dOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { - // get the output shape when format is NHWC/NCHW - if (inputs[0]->shape.size() == kDim4) { - CHECK_ATTR(attrs, "format"); - if (inputs[0]->format == kOpFormat_NHWC || inputs[1]->format == kOpFormat_NHWC || - GetValue(attrs.find("format")->second) == kOpFormat_NHWC) { - CHECK_ATTR(attrs, "pad_mode"); - CHECK_ATTR(attrs, "pad_list"); - CHECK_ATTR(attrs, "kernel_size"); - CHECK_ATTR(attrs, "stride"); - CHECK_ATTR(attrs, "dilation"); - - auto x_shape = inputs[0]->shape; - auto w_shape = inputs[1]->shape; - auto pad_mode = GetValue(attrs.find("pad_mode")->second); - auto pad_list = GetListInt(attrs.find("pad_list")->second); - auto kernel_size = GetListInt(attrs.find("kernel_size")->second); - auto stride = GetListInt(attrs.find("stride")->second); - auto dilation = GetListInt(attrs.find("dilation")->second); - constexpr size_t kPadSize = 4; - constexpr size_t kKernelSize = 2; - constexpr size_t kStrideSize = 4; - constexpr size_t kDilationSize = 4; - if (x_shape.size() != kDim4 || w_shape.size() != kDim4 || pad_list.size() != kPadSize || - kernel_size.size() != kKernelSize || stride.size() != kStrideSize || dilation.size() != kDilationSize) { - MS_LOG(EXCEPTION) << "For 'Conv2D', got sizes of x_shape, w_shape, pad_list, kernel_size, stride and dilation: " - << x_shape.size() << ", " << w_shape.size() << ", " << pad_list.size() << ", " - << kernel_size.size() << ", " << stride.size() << ", " << dilation.size() - << ". But expect: 4, 4, 4, 2, 4, 4"; - } - auto has_pad = HadPad(pad_list, pad_mode); - if (!has_pad) { - pad_list = {0, 0, 0, 0}; - } - - auto k_h = (kernel_size[0] - 1) * dilation[2] + 1; - auto k_w = (kernel_size[1] - 1) * dilation[3] + 1; - auto out_h = (x_shape[1] + pad_list[0] + pad_list[1] - k_h) / stride[2] + 1; - auto out_w = (x_shape[2] + pad_list[2] + pad_list[3] - k_w) / stride[3] + 1; - return {{x_shape[0], out_h, out_w, w_shape[3]}}; - } else { - return OpaqueOp::InferShape(inputs, attrs); - } - } - - // get the output shape when format is NCHWc - std::vector data_shape = inputs[0]->shape; - std::vector weight_shape = inputs[1]->shape; - auto n = data_shape[0]; - auto i_h = data_shape[2]; - auto i_w = data_shape[3]; - auto c_o_o = weight_shape[0]; - auto k_h = weight_shape[2]; - auto k_w = weight_shape[3]; - auto c_o_i = weight_shape[5]; - - CHECK_ATTR(attrs, "stride"); - CHECK_ATTR(attrs, "dilation"); - - std::vector strides = GetListInt(attrs.find("stride")->second); - std::vector dilations = GetListInt(attrs.find("dilation")->second); - - auto d_h = dilations[0]; - auto d_w = dilations[1]; - auto s_h = strides[0]; - auto s_w = strides[1]; - auto k_h_d = (k_h - 1) * d_h + 1; - auto k_w_d = (k_w - 1) * d_w + 1; - auto o_h = (i_h - k_h_d) / s_h + 1; - auto o_w = (i_w - k_w_d) / s_w + 1; - - std::vector output_shape{n, c_o_o, o_h, o_w, c_o_i}; - return {output_shape}; -} - -std::vector Conv2dOp::InferType(const NodePtrList &inputs, const DAttrs &attrs) { - if (inputs[0]->shape.size() == kDim4) { - return PrimOp::InferType(inputs, attrs); - } - return {inputs[0]->type}; -} - -DFormat Conv2dOp::InferFormat(const NodePtrList &inputs, const DAttrs &attrs) { - if (inputs[0]->shape.size() == kDim4) { - return PrimOp::InferFormat(inputs, attrs); - } - CHECK_ATTR(attrs, "conv_out_format"); - return GetValue(attrs.find("conv_out_format")->second); -} - -void ConcatOp::RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *input_abstract_ptr) { - AbstractBasePtrList rectifyed_abs_list; - (void)rectifyed_abs_list.emplace_back(std::make_shared(*input_abstract_ptr)); - input_abstract_ptr->swap(rectifyed_abs_list); -} - -void ReduceOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) { - CHECK_ATTR(prim->attrs(), "keep_dims"); - (void)abs_list->emplace_back(prim->GetAttr("keep_dims")->ToAbstract()); - if (prim->name() == prim::kPrimReduceSum->name()) { - CHECK_ATTR(prim->attrs(), "skip_mode"); - (void)abs_list->emplace_back(prim->GetAttr("skip_mode")->ToAbstract()); - } -} - -void OneHotOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) { - CHECK_ATTR(prim->attrs(), "axis"); - (void)abs_list->emplace_back(prim->GetAttr("axis")->ToAbstract()); -} - -void CumSumOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) { - CHECK_ATTR(prim->attrs(), "exclusive"); - (void)abs_list->emplace_back(prim->GetAttr("exclusive")->ToAbstract()); - CHECK_ATTR(prim->attrs(), "reverse"); - (void)abs_list->emplace_back(prim->GetAttr("reverse")->ToAbstract()); -} - -void GatherOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) { - CHECK_ATTR(prim->attrs(), "batch_dims"); - (void)abs_list->emplace_back(prim->GetAttr("batch_dims")->ToAbstract()); -} - -void ArgReduceOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) { - CHECK_ATTR(prim->attrs(), "axis"); - (void)abs_list->emplace_back(prim->GetAttr("axis")->ToAbstract()); - CHECK_ATTR(prim->attrs(), "output_type"); - (void)abs_list->emplace_back(prim->GetAttr("output_type")->ToAbstract()); -} - -void PagedAttentionOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) { - constexpr size_t PA_INPUT_NUM = 5; - constexpr size_t PA_MASK_INPUT_NUM = 6; - if (abs_list->size() == PA_INPUT_NUM || abs_list->size() == PA_MASK_INPUT_NUM) { - CHECK_ATTR(prim->attrs(), "head_num"); - (void)abs_list->emplace_back(prim->GetAttr("head_num")->ToAbstract()); - CHECK_ATTR(prim->attrs(), "scale_value"); - (void)abs_list->emplace_back(prim->GetAttr("scale_value")->ToAbstract()); - CHECK_ATTR(prim->attrs(), "kv_head_num"); - (void)abs_list->emplace_back(prim->GetAttr("kv_head_num")->ToAbstract()); - } -} - -std::vector CompactShape(const ShapeVector &origin, int64_t axis) { - std::vector new_shape; - size_t accu = 1; - for (size_t i = 0; i < origin.size(); i++) { - if (LongToSize(axis) == i) { - new_shape.push_back(accu); - new_shape.push_back(LongToSize(origin[i])); - accu = 1; - } else { - accu *= LongToSize(origin[i]); - } - } - new_shape.push_back(accu); - return new_shape; -} - -template -tensor::TensorPtr GatherOp::CalcGather(const NodePtrList &inputs, const DAttrs &attrs) const { - constexpr size_t param_index = 0; - constexpr size_t indice_index = 1; - constexpr size_t axis_index = 2; - constexpr size_t input_num = 3; - constexpr size_t first_dim = 0; - constexpr size_t second_dim = 1; - constexpr size_t third_dim = 2; - int64_t axis = 0; - if (attrs.count("axis") > 0) { - axis = GetValue(attrs.find("axis")->second); - } else if (inputs.size() == input_num) { - int *data_axis = - static_cast(std::static_pointer_cast(inputs[axis_index])->data()->data_c()); - axis = IntToLong(*data_axis); - } else { - return nullptr; - } - ShapeVector param_shp = inputs[param_index]->shape; - axis = axis < 0 ? SizeToLong(param_shp.size()) + axis : axis; - std::vector indices; - switch (static_cast(inputs[indice_index]->type)) { - case TypeId::kNumberTypeInt8: { - indices = ChangeDataToVec(inputs[indice_index]); - break; - } - case TypeId::kNumberTypeInt16: { - indices = ChangeDataToVec(inputs[indice_index]); - break; - } - case TypeId::kNumberTypeInt32: { - indices = ChangeDataToVec(inputs[indice_index]); - break; - } - case TypeId::kNumberTypeInt64: { - indices = ChangeDataToVec(inputs[indice_index]); - break; - } - default: - return nullptr; - } - - TM *input_x = - static_cast(std::static_pointer_cast(inputs[param_index])->data()->data_c()); - std::vector compact_shp = CompactShape(param_shp, axis); - std::vector res; - if (compact_shp.size() == input_num) { - for (size_t i = 0; i < compact_shp[first_dim]; i++) { - for (auto j : indices) { - for (size_t k = 0; k < compact_shp[third_dim]; k++) { - (void)res.emplace_back( - input_x[i * compact_shp[second_dim] * compact_shp[third_dim] + j * compact_shp[third_dim] + k]); - } - } - } - return std::make_shared(this->type, this->shape, &res[0], this->type); - } - return nullptr; -} - -NodePtr GatherOp::InferValue(const NodePtrList &inputs, const DAttrs &attrs) { - for (auto i : inputs) { - if (i->NodeType() != NType::Tensor) { - return nullptr; - } - } - TypeId output_type = this->type; - tensor::TensorPtr res = nullptr; - switch (static_cast(output_type)) { - case TypeId::kNumberTypeUInt8: { - res = CalcGather(inputs, attrs); - break; - } - case TypeId::kNumberTypeInt8: { - res = CalcGather(inputs, attrs); - break; - } - case TypeId::kNumberTypeInt16: { - res = CalcGather(inputs, attrs); - break; - } - case TypeId::kNumberTypeInt32: { - res = CalcGather(inputs, attrs); - break; - } - case TypeId::kNumberTypeInt64: { - res = CalcGather(inputs, attrs); - break; - } - case TypeId::kNumberTypeUInt16: { - res = CalcGather(inputs, attrs); - break; - } - case TypeId::kNumberTypeUInt32: { - res = CalcGather(inputs, attrs); - break; - } - case TypeId::kNumberTypeUInt64: { - res = CalcGather(inputs, attrs); - break; - } - case TypeId::kNumberTypeFloat16: { - res = CalcGather(inputs, attrs); - break; - } - case TypeId::kNumberTypeFloat32: { - res = CalcGather(inputs, attrs); - break; - } - case TypeId::kNumberTypeFloat64: { - res = CalcGather(inputs, attrs); - break; - } - case TypeId::kNumberTypeBFloat16: { - res = CalcGather(inputs, attrs); - break; - } - default: - return nullptr; - } - return res == nullptr ? nullptr : std::make_shared(res); -} - -template -tensor::TensorPtr ConcatOp::CalcConcat(const NodePtrList &inputs, const DAttrs &attrs) { - constexpr size_t first_dim = 0; - constexpr size_t second_dim = 1; - constexpr size_t third_dim = 2; - int64_t axis = 0; - auto axis_node = inputs.back(); - if (axis_node->NodeType() == NType::Scalar) { - auto scalar_node = axis_node->As(); - axis = GetValue(scalar_node->data()); - } else { - return nullptr; - } - axis = axis < 0 ? SizeToLong(this->shape.size()) + axis : axis; - std::vector> inputs_tm; - for (const auto &t : inputs) { - (void)inputs_tm.emplace_back(ChangeDataToVec(t)); - } - std::vector> all_shps; - (void)std::transform(inputs.begin(), inputs.end(), std::back_inserter(all_shps), - [&axis](const NodePtr &t) { return CompactShape(t->shape, axis); }); - std::vector res; - if (all_shps.size() > 0) { - const size_t third_dim_size = all_shps[0][third_dim]; - const size_t first_dim_size = all_shps[0][first_dim]; - for (size_t i = 0; i < first_dim_size; i++) { - for (size_t t = 0; t < inputs_tm.size(); t++) { - for (size_t j = 0; j < all_shps[t][second_dim]; j++) { - for (size_t k = 0; k < third_dim_size; k++) { - (void)res.emplace_back(inputs_tm[t][i * all_shps[t][second_dim] * third_dim_size + j * third_dim_size + k]); - } - } - } - } - return std::make_shared(this->type, this->shape, &res[0], this->type); - } - return nullptr; -} - -NodePtr ConcatOp::InferValue(const NodePtrList &inputs, const DAttrs &attrs) { - for (auto i : inputs) { - if (i->NodeType() != NType::Tensor) { - return nullptr; - } - } - TypeId output_type = this->type; - tensor::TensorPtr res = nullptr; - switch (static_cast(output_type)) { - case TypeId::kNumberTypeUInt8: { - res = CalcConcat(inputs, attrs); - break; - } - case TypeId::kNumberTypeInt8: { - res = CalcConcat(inputs, attrs); - break; - } - case TypeId::kNumberTypeInt16: { - res = CalcConcat(inputs, attrs); - break; - } - case TypeId::kNumberTypeInt32: { - res = CalcConcat(inputs, attrs); - break; - } - case TypeId::kNumberTypeInt64: { - res = CalcConcat(inputs, attrs); - break; - } - case TypeId::kNumberTypeUInt16: { - res = CalcConcat(inputs, attrs); - break; - } - case TypeId::kNumberTypeUInt32: { - res = CalcConcat(inputs, attrs); - break; - } - case TypeId::kNumberTypeUInt64: { - res = CalcConcat(inputs, attrs); - break; - } - case TypeId::kNumberTypeFloat16: { - res = CalcConcat(inputs, attrs); - break; - } - case TypeId::kNumberTypeFloat32: { - res = CalcConcat(inputs, attrs); - break; - } - case TypeId::kNumberTypeFloat64: { - res = CalcConcat(inputs, attrs); - break; - } - case TypeId::kNumberTypeBFloat16: { - res = CalcConcat(inputs, attrs); - break; - } - default: - return nullptr; - } - return res == nullptr ? nullptr : std::make_shared(res); -} - -std::vector LayoutTransformOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { - CHECK_ATTR(attrs, kAttrSrcFormat); - CHECK_ATTR(attrs, kAttrDstFormat); - auto src_format = GetValue(attrs.find(kAttrSrcFormat)->second); - auto dst_format = GetValue(attrs.find(kAttrDstFormat)->second); - std::vector data_shape = inputs[0]->shape; - if (src_format == kOpFormat_NHWC) { - auto n = data_shape[0]; - auto h = data_shape[1]; - auto w = data_shape[2]; - auto c = data_shape[3]; - auto c_o_i = GkUtils::GetChannelInConvFormat(dst_format); - if (c_o_i == 0) { - c_o_i = 1; - } - auto c_o_o = c / c_o_i; - std::vector output_shape{n, c_o_o, h, w, c_o_i}; - return {output_shape}; - } - if (dst_format == kOpFormat_NHWC) { - auto n = data_shape[0]; - auto c_o_o = data_shape[1]; - auto h = data_shape[2]; - auto w = data_shape[3]; - auto c_o_i = data_shape[4]; - auto c = c_o_o * c_o_i; - std::vector output_shape{n, h, w, c}; - return {output_shape}; - } - // LayoutTransform between nchwnc - auto n = data_shape[0]; - auto c_o_o = data_shape[1]; - auto h = data_shape[2]; - auto w = data_shape[3]; - auto c_o_i = data_shape[4]; - auto c_o_i_new = GkUtils::GetChannelInConvFormat(dst_format); - if (c_o_i_new == 0) { - c_o_i_new = 1; - } - auto c_o_o_new = c_o_o * c_o_i / c_o_i_new; - std::vector output_shape{n, c_o_o_new, h, w, c_o_i_new}; - return {output_shape}; -} - -std::vector Pool2DOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { - CHECK_ATTR(attrs, "global"); - std::vector input_shape = inputs[0]->shape; - bool is_nhwc = input_shape.size() == 4; - int64_t n = input_shape[0]; - int64_t c; - int64_t h; - int64_t w; - if (is_nhwc) { - constexpr size_t h_idx = 1; - constexpr size_t w_idx = 2; - constexpr size_t c_idx = 3; - h = input_shape[h_idx]; - w = input_shape[w_idx]; - c = input_shape[c_idx]; - } else { - constexpr size_t c_idx = 1; - constexpr size_t h_idx = 2; - constexpr size_t w_idx = 3; - c = input_shape[c_idx]; - h = input_shape[h_idx]; - w = input_shape[w_idx]; - } - - if (GetValue(attrs.find("global")->second)) { - h = 1; - w = 1; - } else { - CHECK_ATTR(attrs, "strides"); - CHECK_ATTR(attrs, "kernel_size"); - CHECK_ATTR(attrs, "round_mode"); - std::vector strides = GetListInt(attrs.find("strides")->second); - std::vector kernels = GetListInt(attrs.find("kernel_size")->second); - if (AnfUtils::GetIntValue(attrs.find("round_mode")->second) == 0) { - // ceil mode - h = ((h - kernels[0] + strides[0] - 1) / strides[0]) + 1; - w = ((w - kernels[1] + strides[1] - 1) / strides[1]) + 1; - } else { - // round mode - h = ((h - kernels[0]) / strides[0]) + 1; - w = ((w - kernels[1]) / strides[1]) + 1; - } - } - if (is_nhwc) { - return {{n, h, w, c}}; - } else { - auto ci = input_shape[4]; - return {{n, c, h, w, ci}}; - } -} - -void ComplexOp::Check(const NodePtrList &inputs, const DAttrs &) { - if (inputs[0]->type != TypeId::kNumberTypeFloat32) { - MS_LOG(EXCEPTION) << "Complex's input[0] should be float32, but got " << TypeIdToString(inputs[0]->type, true); - } - if (inputs[0]->type != inputs[1]->type) { - MS_LOG(EXCEPTION) << "Complex's input[0] and inputs[1]'s type mismatch: " << TypeIdToString(inputs[0]->type, true) - << " vs " << TypeIdToString(inputs[1]->type, true); - } -} - -std::vector StandardNormalOp::InferShape(const NodePtrList &, const DAttrs &attrs) { - CHECK_ATTR(attrs, "shape"); - return {GetListInt(attrs.find("shape")->second)}; -} - -template -tensor::TensorPtr StridedSliceOnnxOp::CalcStridedSliceOnnx(const NodePtrList &inputs, const DAttrs &) const { - constexpr size_t input_index = 0; - constexpr size_t begin_index = 1; - constexpr size_t end_index = 2; - constexpr size_t axes_index = 3; - constexpr size_t stride_index = 4; - - ShapeVector input_shape = inputs[input_index]->shape; - std::vector begin = ChangeDataToVec(inputs[begin_index]); - std::vector end = ChangeDataToVec(inputs[end_index]); - std::vector axes = ChangeDataToVec(inputs[axes_index]); - std::vector stride = ChangeDataToVec(inputs[stride_index]); - - std::unordered_map> info; - for (size_t i = 0; i < axes.size(); i++) { - int axis = axes[i] < 0 ? axes[i] + SizeToInt(input_shape.size()) : axes[i]; - if (begin[i] < 0 || end[i] < 0 || stride[i] < 0) { - MS_LOG(INFO) << "Only do infervalue for StridedSliceOnnx when begin, end and stride are non-negative."; - return nullptr; - } - std::unordered_set pos; - int index = begin[i]; - while (index < end[i]) { - (void)pos.insert(IntToSize(index)); - index += stride[i]; - } - (void)info.emplace(axis, pos); - } - - TM *input_x = - static_cast(std::static_pointer_cast(inputs[input_index])->data()->data_c()); - - std::vector res; - - std::function func; - func = [&func, &input_x, &res, &info, &input_shape](size_t dim, size_t offset) { - if ((dim + 1) == input_shape.size()) { - for (size_t i = 0; i < LongToSize(input_shape[dim]); i++) { - if (info.count(SizeToInt(dim)) > 0) { - if (info[SizeToInt(dim)].count(i) > 0) { - (void)res.emplace_back(input_x[offset + i]); - } - } else { - (void)res.emplace_back(input_x[offset + i]); - } - } - } else if ((dim + 1) < input_shape.size()) { - size_t accu = 1; - for (size_t j = dim + 1; j < input_shape.size(); j++) { - accu *= LongToSize(input_shape[j]); - } - for (size_t i = 0; i < LongToSize(input_shape[dim]); i++) { - if (info.count(SizeToInt(dim)) > 0) { - if (info[SizeToInt(dim)].count(i) > 0) { - func(dim + 1, offset + i * accu); - } - } else { - func(dim + 1, offset + i * accu); - } - } - } - return; - }; - func(0, 0); - return std::make_shared(this->type, this->shape, &res[0], this->type); -} - -NodePtr StridedSliceOnnxOp::InferValue(const NodePtrList &inputs, const DAttrs &attrs) { - for (auto i : inputs) { - if (i->NodeType() != NType::Tensor) { - return nullptr; - } - } - TypeId output_type = this->type; - tensor::TensorPtr res = nullptr; - switch (static_cast(output_type)) { - case TypeId::kNumberTypeUInt8: { - res = CalcStridedSliceOnnx(inputs, attrs); - break; - } - case TypeId::kNumberTypeInt8: { - res = CalcStridedSliceOnnx(inputs, attrs); - break; - } - case TypeId::kNumberTypeInt16: { - res = CalcStridedSliceOnnx(inputs, attrs); - break; - } - case TypeId::kNumberTypeInt32: { - res = CalcStridedSliceOnnx(inputs, attrs); - break; - } - case TypeId::kNumberTypeInt64: { - res = CalcStridedSliceOnnx(inputs, attrs); - break; - } - case TypeId::kNumberTypeUInt16: { - res = CalcStridedSliceOnnx(inputs, attrs); - break; - } - case TypeId::kNumberTypeUInt32: { - res = CalcStridedSliceOnnx(inputs, attrs); - break; - } - case TypeId::kNumberTypeUInt64: { - res = CalcStridedSliceOnnx(inputs, attrs); - break; - } - case TypeId::kNumberTypeFloat16: { - res = CalcStridedSliceOnnx(inputs, attrs); - break; - } - case TypeId::kNumberTypeFloat32: { - res = CalcStridedSliceOnnx(inputs, attrs); - break; - } - case TypeId::kNumberTypeFloat64: { - res = CalcStridedSliceOnnx(inputs, attrs); - break; - } - case TypeId::kNumberTypeBFloat16: { - res = CalcStridedSliceOnnx(inputs, attrs); - break; - } - default: - return nullptr; - } - return res == nullptr ? nullptr : std::make_shared(res); -} - -void MatMulOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) { - CHECK_ATTR(prim->attrs(), "transpose_a"); - (void)abs_list->emplace_back(prim->GetAttr("transpose_a")->ToAbstract()); - CHECK_ATTR(prim->attrs(), "transpose_b"); - (void)abs_list->emplace_back(prim->GetAttr("transpose_b")->ToAbstract()); -} - -std::vector MatMulOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { - // the prim's infer shape does not supports batch dims - constexpr size_t kMatMulRank = 2; - if (inputs[0]->shape.size() > kMatMulRank || inputs[1]->shape.size() > kMatMulRank) { - NodePtrList new_inputs = inputs; - std::vector batches(inputs.size()); - auto cut_batches = [&new_inputs, &batches, kMatMulRank](size_t i) -> void { - const auto &shape_i = new_inputs[i]->shape; - if (shape_i.size() > kMatMulRank) { - DShape real_shape(shape_i.cend() - kMatMulRank, shape_i.cend()); - new_inputs[i] = std::make_shared(NodeBase{real_shape, new_inputs[i]->type, new_inputs[i]->format}); - batches[i].assign(shape_i.cbegin(), shape_i.cend() - kMatMulRank); - } - }; - - cut_batches(0); - cut_batches(1); - if (batches[0].size() != batches[1].size()) { - MS_LOG(EXCEPTION) << "The Matmul's batch rank should be equal, but got " << batches[0].size() << " vs " - << batches[1].size(); - } - DShape batch; - for (size_t i = 0; i < batches[0].size(); i++) { - if (batches[0][i] != batches[1][i]) { - if (batches[0][i] != 1 && batches[1][i] != 1) { - MS_LOG(EXCEPTION) << "The Matmul's batch dim is unmatched. got " << inputs[0]->shape << " and " - << inputs[1]->shape; - } - } - batch.push_back(std::max(batches[0][i], batches[1][i])); - } - - auto out_shape = PrimOp::InferShape(new_inputs, attrs)[0]; - // just reuse the `batch` vector - (void)batch.insert(batch.end(), out_shape.begin(), out_shape.end()); - return {batch}; - } - return PrimOp::InferShape(inputs, attrs); -} - -std::vector MatMulOp::InferType(const NodePtrList &inputs, const DAttrs &attrs) { - if (attrs.count("dst_type") != 0) { - return {attrs.find("dst_type")->second->cast()->type_id()}; - } - if (inputs[0]->type == TypeId::kNumberTypeInt8) { - return {TypeId::kNumberTypeInt32}; - } - return {inputs[0]->type}; -} -} // namespace mindspore::graphkernel::inner +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/common/graph_kernel/model/op_node.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "abstract/ops/primitive_infer_map.h" +#include "utils/anf_utils.h" +#include "utils/hash_map.h" +#include "utils/check_convert_utils.h" +#include "backend/common/graph_kernel/core/graph_kernel_utils.h" +#include "backend/common/graph_kernel/model/node.h" +#include "backend/operator/ops_backend_infer_function.h" +#include "utils/log_adapter.h" +#include "ops/auto_generate/gen_ops_primitive.h" + +namespace mindspore::graphkernel::inner { +std::vector GetListInt(const ValuePtr &attr_value) { + std::vector list_int; + const auto &vals = attr_value->cast()->value(); + (void)std::transform(vals.begin(), vals.end(), std::back_inserter(list_int), + [](const ValuePtr &v) { return AnfUtils::GetIntValue(v); }); + return list_int; +} + +BaseShapePtr InferShapeWithAbstract(const PrimitivePtr &prim, const AbstractBasePtrList &abs_list) { + auto shape_optional = abstract::InferShapeByFuncImpl(prim, abs_list, true); + if (shape_optional.has_value()) { + return shape_optional.value(); + } + + auto found = abstract::GetBackendPrimitiveInferImpl(prim); + if (found.has_value()) { + auto infer = found.value(); + if (infer.IsImplInferShapeAndType()) { + return infer.InferShape(prim, abs_list); + } + } + MS_LOG(EXCEPTION) << "The infer function of [" << prim->name() << "] is not defined."; + return nullptr; +} + +TypePtr InferTypeWithAbstract(const PrimitivePtr &prim, const AbstractBasePtrList &abs_list) { + auto type_optional = abstract::InferTypeByFuncImpl(prim, abs_list, true); + if (type_optional.has_value()) { + return type_optional.value(); + } + + auto found = abstract::GetBackendPrimitiveInferImpl(prim); + if (found.has_value()) { + auto infer = found.value(); + if (infer.IsImplInferShapeAndType()) { + return infer.InferType(prim, abs_list); + } + } + MS_LOG(EXCEPTION) << "The infer function of [" << prim->name() << "] is not defined."; + return nullptr; +} + +tensor::TensorPtr InferValueWithAbstract(const PrimitivePtr &prim, const AbstractBasePtrList &abs_list) { + auto value_optional = abstract::InferValueByFuncImpl(prim, abs_list); + if (value_optional.has_value()) { + return std::static_pointer_cast(value_optional.value()); + } + + auto found = abstract::GetBackendPrimitiveInferImpl(prim); + if (found.has_value()) { + auto infer = found.value(); + if (infer.IsImplInferValue()) { + return std::static_pointer_cast(infer.InferValue(prim, abs_list)); + } + } + return nullptr; +} + +std::pair PrimOp::GenPrimAndAbstract(const NodePtrList &inputs, + const DAttrs &attrs) const { + auto prim = std::make_shared(op_); + MS_EXCEPTION_IF_NULL(prim); + (void)prim->SetAttrs(attrs); + AbstractBasePtrList abs_list(inputs.size()); + (void)std::transform(inputs.cbegin(), inputs.cend(), abs_list.begin(), + [](const NodePtr &node) { return node->ToAbstract(); }); + return std::make_pair(prim, abs_list); +} + +std::vector PrimOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { + auto [prim, abs_list] = GenPrimAndAbstract(inputs, attrs); + RectifyAbstract(prim, &abs_list); + auto baseshape = InferShapeWithAbstract(prim, abs_list); + MS_EXCEPTION_IF_NULL(baseshape); + if (baseshape->isa()) { + auto tuple_shape = baseshape->cast(); + MS_EXCEPTION_IF_NULL(tuple_shape); + const auto &shape_elements = tuple_shape->shape(); + std::vector result(shape_elements.size()); + (void)std::transform(shape_elements.cbegin(), shape_elements.cend(), result.begin(), + [](const BaseShapePtr &s) { return s->cast()->shape(); }); + return result; + } + auto shape = baseshape->cast(); + if (shape != nullptr) { + return {shape->shape()}; + } + return {DShape()}; +} + +std::vector PrimOp::InferType(const NodePtrList &inputs, const DAttrs &attrs) { + auto [prim, abs_list] = GenPrimAndAbstract(inputs, attrs); + RectifyAbstract(prim, &abs_list); + auto type = InferTypeWithAbstract(prim, abs_list); + MS_EXCEPTION_IF_NULL(type); + auto get_type_id = [](const TypePtr &t) { + return t->isa() ? t->cast()->element()->type_id() : t->type_id(); + }; + if (type->isa()) { + auto elements = type->cast()->elements(); + std::vector result(elements.size()); + (void)std::transform(elements.cbegin(), elements.cend(), result.begin(), get_type_id); + return result; + } + return {get_type_id(type)}; +} + +NodeBaseList PrimOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) { + Check(inputs, attrs); + NodeBaseList result; + auto format = InferFormat(inputs, attrs); + auto shapes = InferShape(inputs, attrs); + auto types = InferType(inputs, attrs); + if (shapes.size() != types.size()) { + MS_LOG(EXCEPTION) << "The num of shapes and types should be equal. (" << shapes.size() << " vs " << types.size() + << ")"; + } + for (size_t i = 0; i < shapes.size(); i++) { + (void)result.emplace_back(NodeBase{shapes[i], types[i], format}); + } + return result; +} + +std::string PrimOp::ToString() const { + std::ostringstream oss; + oss << Node::ToString(); + oss << " = " << this->op_ << "("; + for (size_t i = 0; i < inputs_.size(); i++) { + if (inputs_[i]->NodeType() == NType::Primitive) { + oss << inputs_[i]->Node::ToString(); + } else { + oss << inputs_[i]->ToString(); + } + if (i != inputs_.size() - 1) { + oss << ", "; + } + } + oss << ")"; + std::ostringstream attr_oss; + bool has_attr = false; + std::set black_list = {"IsFeatureMapInputList", "IsFeatureMapOutput", "output_names", "input_names"}; + for (auto attr : attrs_) { + if (attr.second != nullptr && black_list.count(attr.first) == 0) { + if (has_attr) { + attr_oss << ", "; + } else { + has_attr = true; + } + attr_oss << attr.first << ": " << attr.second->ToString(); + } + } + if (has_attr) { + oss << " // attr {" << attr_oss.str() << "}"; + } + return oss.str(); +} + +template +std::vector ChangeDataToVec(const NodePtr &n) { + std::vector res; + TD *data = static_cast(std::static_pointer_cast(n)->data()->data_c()); + for (size_t elem = 0; elem < n->tensor_size(); elem++) { + res.push_back(static_cast(*(data + elem))); + } + return res; +} + +template +tensor::TensorPtr PrimOp::CalcByOperator(const NodePtrList &inputs, const DAttrs &) const { + const size_t unary_input_num = 1; + const size_t binary_input_num = 2; + if (inputs.size() > 0) { + bool all_shape_equal = + std::all_of(inputs.begin(), inputs.end(), [&inputs](const NodePtr &t) { return t->shape == inputs[0]->shape; }); + if (!all_shape_equal) { + return nullptr; + } + } + std::vector> inputs_tm; + const auto &op = this->op(); + const auto tid = this->type; + for (const auto &t : inputs) { + (void)inputs_tm.emplace_back(ChangeDataToVec(t)); + } + if (inputs.size() == unary_input_num) { + mindspore::HashMap> func_map = { + {"Abs", [](const TM &a) { return a <= TM(0) ? -a : a; }}, + {"Exp", [](const TM &a) { return exp(a); }}, + {"Log", [](const TM &a) { return log(a); }}, + {"Neg", [](const TM &a) { return -a; }}, + {"Reciprocal", + [](const TM &a) { + if (a == TM(0)) { + MS_LOG(EXCEPTION) << "During graph kernel constant fold for reciprocal, divisor is zero."; + } + return TM(1) / a; + }}, + {"Rsqrt", + [](const TM &a) { + if (a == TM(0)) { + MS_LOG(EXCEPTION) << "During graph kernel constant fold for rsqrt, divisor is zero."; + } + return TM(1) / sqrt(a); + }}, + {"Sqrt", [](const TM &a) { return sqrt(a); }}, + }; + if (func_map.find(op) == func_map.end()) { + return nullptr; + } + const auto &input_a = inputs_tm[0]; + std::vector res; + (void)std::transform(input_a.begin(), input_a.end(), std::back_inserter(res), + [&func_map, &op](const TM &i) { return func_map[op](i); }); + return std::make_shared(tid, this->shape, &res[0], tid); + } else if (inputs.size() == binary_input_num) { + mindspore::HashMap> func_map = { + {"Add", [](const TM &a, const TM &b) { return a + b; }}, + {"Sub", [](const TM &a, const TM &b) { return a - b; }}, + {"Mul", [](const TM &a, const TM &b) { return a * b; }}, + {"RealDiv", + [](const TM &a, const TM &b) { + if (b == TM(0)) { + MS_LOG(EXCEPTION) << "During graph kernel constant fold for realdiv, divisor is zero."; + } + return a / b; + }}, + }; + if (func_map.find(op) == func_map.end()) { + return nullptr; + } + const auto &input_a = inputs_tm[0]; + const auto &input_b = inputs_tm[1]; + std::vector res; + for (size_t i = 0; i < input_a.size(); i++) { + (void)res.emplace_back(func_map[op](input_a[i], input_b[i])); + } + return std::make_shared(tid, this->shape, &res[0], tid); + } + return nullptr; +} + +NodePtr PrimOp::InferValue(const NodePtrList &inputs, const DAttrs &attrs) { + for (auto i : inputs) { + if (i->NodeType() != NType::Tensor) { + return nullptr; + } + } + TypeId output_type = this->type; + tensor::TensorPtr res = nullptr; + switch (static_cast(output_type)) { + case TypeId::kNumberTypeUInt8: { + res = CalcByOperator(inputs, attrs); + break; + } + case TypeId::kNumberTypeInt8: { + res = CalcByOperator(inputs, attrs); + break; + } + case TypeId::kNumberTypeInt16: { + res = CalcByOperator(inputs, attrs); + break; + } + case TypeId::kNumberTypeInt32: { + res = CalcByOperator(inputs, attrs); + break; + } + case TypeId::kNumberTypeInt64: { + res = CalcByOperator(inputs, attrs); + break; + } + case TypeId::kNumberTypeUInt16: { + res = CalcByOperator(inputs, attrs); + break; + } + case TypeId::kNumberTypeUInt32: { + res = CalcByOperator(inputs, attrs); + break; + } + case TypeId::kNumberTypeUInt64: { + res = CalcByOperator(inputs, attrs); + break; + } + case TypeId::kNumberTypeFloat16: { + res = CalcByOperator(inputs, attrs); + break; + } + case TypeId::kNumberTypeFloat32: { + res = CalcByOperator(inputs, attrs); + break; + } + case TypeId::kNumberTypeFloat64: { + res = CalcByOperator(inputs, attrs); + break; + } + case TypeId::kNumberTypeBFloat16: { + res = CalcByOperator(inputs, attrs); + break; + } + default: + return nullptr; + } + if (res == nullptr) { + auto [prim, inputs_abstract] = GenPrimAndAbstract(inputs, attrs); + RectifyAbstract(prim, &inputs_abstract); + res = InferValueWithAbstract(prim, inputs_abstract); + } + return res == nullptr ? nullptr : std::make_shared(res); +} + +NodePtr ReshapeOp::InferValue(const NodePtrList &inputs, const DAttrs &) { + if (inputs[0]->NodeType() != NType::Tensor) { + return nullptr; + } + void *tensor_data = inputs[0]->As()->data()->data_c(); + tensor::TensorPtr result_tensor = std::make_shared(this->type, this->shape, tensor_data, this->type); + return std::make_shared(result_tensor); +} + +// default format shape to fractal_Nz format shape +DShape ToNz(const DShape &default_shape) { + constexpr size_t nz_size = 2; + constexpr auto align16 = 16; + auto len = default_shape.size(); + DShape leading_shape; + DShape tail_shape; + if (default_shape.size() == 1 && default_shape[0] == 1) { + // # As shape (1,) can broadcast to any shape, it can be regarded as a special FractalNZ shape + return default_shape; + } + if (default_shape.size() > nz_size) { + (void)leading_shape.insert(leading_shape.cend(), default_shape.cbegin(), + default_shape.cend() - SizeToLong(nz_size)); + } + if (default_shape.size() == 1 || (default_shape.size() >= nz_size && default_shape[len - nz_size] == 1)) { + // (32) or (N, 1, 32) -> (N, 2, 1, 1, 16) + if (default_shape.back() % align16 != 0) { + MS_LOG(EXCEPTION) << "default_shape[-1] should be multiplies of 16, but got " << default_shape.back(); + } + tail_shape = {default_shape.back() / align16, 1, 1, align16}; + } else if (default_shape.size() >= nz_size || default_shape[1] == 1) { + // (N, 32, 1) -> (N, 1, 2, 16, 1) + if (default_shape[len - nz_size] % align16 != 0) { + MS_LOG(EXCEPTION) << "default_shape[-2] should be multiplies of 16, but got " << default_shape[len - nz_size]; + } + tail_shape = {1, default_shape[0] / align16, align16, 1}; + } else { + // (N, 32, 48) -> (N, 3, 2, 16, 16) + if (default_shape.back() % align16 != 0 || default_shape[len - nz_size] % align16 != 0) { + MS_LOG(EXCEPTION) << "default_shape[-1] and default_shape[-2]should be multiplies of 16, but got " + << default_shape.back() << " " << default_shape[len - nz_size]; + } + tail_shape = {default_shape[1] / align16, default_shape[0] / align16, align16, align16}; + } + (void)leading_shape.insert(leading_shape.cend(), tail_shape.cbegin(), tail_shape.cend()); + return leading_shape; +} + +DShape BroadcastShape(const NodePtrList &inputs, bool to_nz = false) { + std::vector> shapes; + for (auto &input : inputs) { + if (to_nz && input->format != kOpFormat_FRAC_NZ) { + (void)shapes.emplace_back(ToNz(input->shape)); + } else { + (void)shapes.emplace_back(input->shape); + } + } + auto max_dim_input = + std::max_element(shapes.begin(), shapes.end(), + [](const std::vector &a, const std::vector &b) { return a.size() < b.size(); }); + auto max_dim = max_dim_input->size(); + std::vector> align_shapes; + for (auto &s : shapes) { + std::vector cur(max_dim - s.size(), 1); + (void)cur.insert(cur.cend(), s.cbegin(), s.cend()); + (void)align_shapes.emplace_back(cur); + } + std::vector output_shape(max_dim, 1); + for (size_t i = 0; i < max_dim; i++) { + for (auto &align_shape : align_shapes) { + if (align_shape[i] > 1) { + if (output_shape[i] == 1) { + output_shape[i] = align_shape[i]; + } + if (output_shape[i] != align_shape[i]) { + MS_LOG(EXCEPTION) << "Shape broadcast failed: " << output_shape[i] << " vs " << align_shape[i]; + } + } + } + } + return output_shape; +} + +std::vector ElemwiseOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { + if (std::any_of(inputs.begin(), inputs.end(), + [](const NodePtr &input) { return input->format == kOpFormat_FRAC_NZ; })) { + return {BroadcastShape(inputs, true)}; + } + return PrimOp::InferShape(inputs, attrs); +} + +DFormat ElemwiseOp::InferFormat(const NodePtrList &inputs, const DAttrs &) { + if (inputs.empty()) { + return kOpFormat_DEFAULT; + } + auto first_format = inputs[0]->format; + for (const auto &inp : inputs) { + auto cur_format = inp->format; + if (cur_format.find("FRACTAL") != std::string::npos) { + // special format + return cur_format; + } + if (cur_format != kOpFormat_DEFAULT && inp->tensor_size() != 1) { + return cur_format; + } + } + return first_format; +} + +std::vector ArgReduceOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { + CHECK_ATTR(attrs, "axis"); + auto axis = GetListInt(attrs.find("axis")->second); + const auto &input_shape = inputs[0]->shape; + int64_t size = SizeToLong(input_shape.size()); + std::vector real_axis; + (void)std::transform(axis.begin(), axis.end(), std::back_inserter(real_axis), + [&size](const int64_t &x) { return x < 0 ? (x + size) : x; }); + + DShape new_shape; + for (size_t i = 0; i < input_shape.size(); i++) { + if (std::find(real_axis.begin(), real_axis.end(), SizeToLong(i)) == real_axis.end()) { + (void)new_shape.emplace_back(input_shape[i]); + } + } + if (new_shape.empty()) { + (void)new_shape.emplace_back(1); + } + return {new_shape}; +} + +std::vector ArgReduceOp::InferType(const NodePtrList &, const DAttrs &attrs) { + CHECK_ATTR(attrs, "output_type"); + return {attrs.find("output_type")->second->cast()->type_id()}; +} + +DFormat TransposeOp::InferFormat(const NodePtrList &inputs, const DAttrs &attrs) { + if (attrs.count(kAttrDstFormat) != 0) { + return GetValue(attrs.find(kAttrDstFormat)->second); + } + // only support NCHW/NHWC now + constexpr size_t kRank4 = 4; + if (inputs[0]->shape.size() != kRank4) { + return kOpFormat_DEFAULT; + } + auto perm_node = inputs[1]; + auto perm_tensor = perm_node->As()->data(); + auto perm = CheckAndConvertUtils::CheckTensorIntValue("permutation", perm_tensor, "Transpose"); + const auto &ori_format = inputs[0]->format; + if (ori_format == kOpFormat_DEFAULT || ori_format == kOpFormat_NCHW) { + std::vector nchw2nhwc = {0, 2, 3, 1}; + if (perm == nchw2nhwc) { + return kOpFormat_NHWC; + } + } else if (ori_format == kOpFormat_NHWC) { + std::vector nhwc2nchw = {0, 3, 1, 2}; + if (perm == nhwc2nchw) { + return kOpFormat_NCHW; + } + } + return kOpFormat_DEFAULT; +} + +NodePtr ConstantOfShapeOp::InferValue(const NodePtrList &inputs, const DAttrs &attrs) { + for (auto i : inputs) { + if (i->NodeType() != NType::Tensor) { + return nullptr; + } + } + const auto &value = GetValue>(attrs.find("value")->second); + std::vector res; + size_t elem_num = LongToSize(std::accumulate(this->shape.begin(), this->shape.end(), 1, std::multiplies())); + if (value.size() == 1) { + res = std::vector(elem_num, value[0]); + } else if (value.size() == elem_num) { + res = value; + } else { + return nullptr; + } + auto tensor = std::make_shared(this->type, this->shape, &res[0], kNumberTypeFloat32); + return std::make_shared(tensor); +} + +std::vector ConstantOfShapeOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { + const auto &value = attrs.find("shape")->second; + std::vector res; + if (value->isa()) { + res = GetValue>(value); + return {res}; + } else if (value->isa()) { + auto tvalue = value->cast(); + if (tvalue->data_type_c() == static_cast(TypeId::kNumberTypeInt32)) { + int *data = static_cast(tvalue->data_c()); + for (size_t elem = 0; elem < tvalue->DataSize(); elem++) { + res.push_back(IntToLong(*(data + elem))); + } + return {res}; + } else if (tvalue->data_type_c() == static_cast(TypeId::kNumberTypeInt64)) { + int64_t *data = static_cast(tvalue->data_c()); + res = std::vector(data, data + tvalue->DataSize()); + return {res}; + } + } + return PrimOp::InferShape(inputs, attrs); +} + +NodePtr ShapeOp::InferValue(const NodePtrList &inputs, const DAttrs &) { + auto tensor = std::make_shared(this->type, this->shape, inputs[0]->shape.data(), kNumberTypeInt64); + return std::make_shared(tensor); +} + +std::vector PadAkgOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { + std::vector shape0 = inputs[0]->shape; + size_t n = shape0.size(); + CHECK_ATTR(attrs, "head"); + CHECK_ATTR(attrs, "tail"); + std::vector pad_before = GetListInt(attrs.find("head")->second); + std::vector pad_after = GetListInt(attrs.find("tail")->second); + if (pad_before.size() != n || pad_after.size() != n) { + MS_LOG(EXCEPTION) << "Input dimension and pad mismatch: " << n << " vs " << pad_before.size() << " vs " + << pad_after.size(); + } + std::vector output; + for (size_t i = 0; i < n; i++) { + (void)output.emplace_back(shape0[i] + pad_before[i] + pad_after[i]); + } + return {output}; +} + +std::vector UnPadAkgOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { + std::vector shape0 = inputs[0]->shape; + size_t n = shape0.size(); + CHECK_ATTR(attrs, "tail"); + std::vector unpad_after = GetListInt(attrs.find("tail")->second); + if (unpad_after.size() != n) { + MS_LOG(EXCEPTION) << "Input dimension and pad mismatch: " << n << " vs " << unpad_after.size(); + } + std::vector output; + for (size_t i = 0; i < n; i++) { + (void)output.emplace_back(shape0[i] - unpad_after[i]); + } + return {output}; +} + +bool Conv2dOp::HadPad(const ShapeVector &pad_list, const std::string &pad_mode) { + constexpr size_t kTop = 0; + constexpr size_t kBottom = 1; + constexpr size_t kLeft = 2; + constexpr size_t kRight = 3; + + if (pad_list[kTop] != pad_list[kBottom] || pad_list[kLeft] != pad_list[kRight]) { + return true; + } + if (pad_mode != "VALID" && pad_mode != "valid") { + return std::any_of(pad_list.begin(), pad_list.end(), [](auto a) { return a != 0; }); + } + return false; +} + +std::vector Conv2dOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { + // get the output shape when format is NHWC/NCHW + if (inputs[0]->shape.size() == kDim4) { + CHECK_ATTR(attrs, "format"); + if (inputs[0]->format == kOpFormat_NHWC || inputs[1]->format == kOpFormat_NHWC || + GetValue(attrs.find("format")->second) == kOpFormat_NHWC) { + CHECK_ATTR(attrs, "pad_mode"); + CHECK_ATTR(attrs, "pad_list"); + CHECK_ATTR(attrs, "kernel_size"); + CHECK_ATTR(attrs, "stride"); + CHECK_ATTR(attrs, "dilation"); + + auto x_shape = inputs[0]->shape; + auto w_shape = inputs[1]->shape; + auto pad_mode = GetValue(attrs.find("pad_mode")->second); + auto pad_list = GetListInt(attrs.find("pad_list")->second); + auto kernel_size = GetListInt(attrs.find("kernel_size")->second); + auto stride = GetListInt(attrs.find("stride")->second); + auto dilation = GetListInt(attrs.find("dilation")->second); + constexpr size_t kPadSize = 4; + constexpr size_t kKernelSize = 2; + constexpr size_t kStrideSize = 4; + constexpr size_t kDilationSize = 4; + if (x_shape.size() != kDim4 || w_shape.size() != kDim4 || pad_list.size() != kPadSize || + kernel_size.size() != kKernelSize || stride.size() != kStrideSize || dilation.size() != kDilationSize) { + MS_LOG(EXCEPTION) << "For 'Conv2D', got sizes of x_shape, w_shape, pad_list, kernel_size, stride and dilation: " + << x_shape.size() << ", " << w_shape.size() << ", " << pad_list.size() << ", " + << kernel_size.size() << ", " << stride.size() << ", " << dilation.size() + << ". But expect: 4, 4, 4, 2, 4, 4"; + } + auto has_pad = HadPad(pad_list, pad_mode); + if (!has_pad) { + pad_list = {0, 0, 0, 0}; + } + + auto k_h = (kernel_size[0] - 1) * dilation[2] + 1; + auto k_w = (kernel_size[1] - 1) * dilation[3] + 1; + auto out_h = (x_shape[1] + pad_list[0] + pad_list[1] - k_h) / stride[2] + 1; + auto out_w = (x_shape[2] + pad_list[2] + pad_list[3] - k_w) / stride[3] + 1; + return {{x_shape[0], out_h, out_w, w_shape[3]}}; + } else { + return OpaqueOp::InferShape(inputs, attrs); + } + } + + // get the output shape when format is NCHWc + std::vector data_shape = inputs[0]->shape; + std::vector weight_shape = inputs[1]->shape; + auto n = data_shape[0]; + auto i_h = data_shape[2]; + auto i_w = data_shape[3]; + auto c_o_o = weight_shape[0]; + auto k_h = weight_shape[2]; + auto k_w = weight_shape[3]; + auto c_o_i = weight_shape[5]; + + CHECK_ATTR(attrs, "stride"); + CHECK_ATTR(attrs, "dilation"); + + std::vector strides = GetListInt(attrs.find("stride")->second); + std::vector dilations = GetListInt(attrs.find("dilation")->second); + + auto d_h = dilations[0]; + auto d_w = dilations[1]; + auto s_h = strides[0]; + auto s_w = strides[1]; + auto k_h_d = (k_h - 1) * d_h + 1; + auto k_w_d = (k_w - 1) * d_w + 1; + auto o_h = (i_h - k_h_d) / s_h + 1; + auto o_w = (i_w - k_w_d) / s_w + 1; + + std::vector output_shape{n, c_o_o, o_h, o_w, c_o_i}; + return {output_shape}; +} + +std::vector Conv2dOp::InferType(const NodePtrList &inputs, const DAttrs &attrs) { + if (inputs[0]->shape.size() == kDim4) { + return PrimOp::InferType(inputs, attrs); + } + return {inputs[0]->type}; +} + +DFormat Conv2dOp::InferFormat(const NodePtrList &inputs, const DAttrs &attrs) { + if (inputs[0]->shape.size() == kDim4) { + return PrimOp::InferFormat(inputs, attrs); + } + CHECK_ATTR(attrs, "conv_out_format"); + return GetValue(attrs.find("conv_out_format")->second); +} + +void ConcatOp::RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *input_abstract_ptr) { + AbstractBasePtrList rectifyed_abs_list; + (void)rectifyed_abs_list.emplace_back(std::make_shared(*input_abstract_ptr)); + input_abstract_ptr->swap(rectifyed_abs_list); +} + +void ReduceOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) { + CHECK_ATTR(prim->attrs(), "keep_dims"); + (void)abs_list->emplace_back(prim->GetAttr("keep_dims")->ToAbstract()); + if (prim->name() == prim::kPrimReduceSum->name()) { + CHECK_ATTR(prim->attrs(), "skip_mode"); + (void)abs_list->emplace_back(prim->GetAttr("skip_mode")->ToAbstract()); + } +} + +void OneHotOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) { + CHECK_ATTR(prim->attrs(), "axis"); + (void)abs_list->emplace_back(prim->GetAttr("axis")->ToAbstract()); +} + +void CumSumOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) { + CHECK_ATTR(prim->attrs(), "exclusive"); + (void)abs_list->emplace_back(prim->GetAttr("exclusive")->ToAbstract()); + CHECK_ATTR(prim->attrs(), "reverse"); + (void)abs_list->emplace_back(prim->GetAttr("reverse")->ToAbstract()); +} + +void GatherOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) { + CHECK_ATTR(prim->attrs(), "batch_dims"); + (void)abs_list->emplace_back(prim->GetAttr("batch_dims")->ToAbstract()); +} + +void ArgReduceOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) { + CHECK_ATTR(prim->attrs(), "axis"); + (void)abs_list->emplace_back(prim->GetAttr("axis")->ToAbstract()); + CHECK_ATTR(prim->attrs(), "output_type"); + (void)abs_list->emplace_back(prim->GetAttr("output_type")->ToAbstract()); +} + +void PagedAttentionOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) { + constexpr size_t PA_INPUT_NUM = 5; + constexpr size_t PA_MASK_INPUT_NUM = 6; + if (abs_list->size() == PA_INPUT_NUM || abs_list->size() == PA_MASK_INPUT_NUM) { + CHECK_ATTR(prim->attrs(), "head_num"); + (void)abs_list->emplace_back(prim->GetAttr("head_num")->ToAbstract()); + CHECK_ATTR(prim->attrs(), "scale_value"); + (void)abs_list->emplace_back(prim->GetAttr("scale_value")->ToAbstract()); + CHECK_ATTR(prim->attrs(), "kv_head_num"); + (void)abs_list->emplace_back(prim->GetAttr("kv_head_num")->ToAbstract()); + } +} + +std::vector CompactShape(const ShapeVector &origin, int64_t axis) { + std::vector new_shape; + size_t accu = 1; + for (size_t i = 0; i < origin.size(); i++) { + if (LongToSize(axis) == i) { + new_shape.push_back(accu); + new_shape.push_back(LongToSize(origin[i])); + accu = 1; + } else { + accu *= LongToSize(origin[i]); + } + } + new_shape.push_back(accu); + return new_shape; +} + +template +tensor::TensorPtr GatherOp::CalcGather(const NodePtrList &inputs, const DAttrs &attrs) const { + constexpr size_t param_index = 0; + constexpr size_t indice_index = 1; + constexpr size_t axis_index = 2; + constexpr size_t input_num = 3; + constexpr size_t first_dim = 0; + constexpr size_t second_dim = 1; + constexpr size_t third_dim = 2; + int64_t axis = 0; + if (attrs.count("axis") > 0) { + axis = GetValue(attrs.find("axis")->second); + } else if (inputs.size() == input_num) { + int *data_axis = + static_cast(std::static_pointer_cast(inputs[axis_index])->data()->data_c()); + axis = IntToLong(*data_axis); + } else { + return nullptr; + } + ShapeVector param_shp = inputs[param_index]->shape; + axis = axis < 0 ? SizeToLong(param_shp.size()) + axis : axis; + std::vector indices; + switch (static_cast(inputs[indice_index]->type)) { + case TypeId::kNumberTypeInt8: { + indices = ChangeDataToVec(inputs[indice_index]); + break; + } + case TypeId::kNumberTypeInt16: { + indices = ChangeDataToVec(inputs[indice_index]); + break; + } + case TypeId::kNumberTypeInt32: { + indices = ChangeDataToVec(inputs[indice_index]); + break; + } + case TypeId::kNumberTypeInt64: { + indices = ChangeDataToVec(inputs[indice_index]); + break; + } + default: + return nullptr; + } + + TM *input_x = + static_cast(std::static_pointer_cast(inputs[param_index])->data()->data_c()); + std::vector compact_shp = CompactShape(param_shp, axis); + std::vector res; + if (compact_shp.size() == input_num) { + for (size_t i = 0; i < compact_shp[first_dim]; i++) { + for (auto j : indices) { + for (size_t k = 0; k < compact_shp[third_dim]; k++) { + (void)res.emplace_back( + input_x[i * compact_shp[second_dim] * compact_shp[third_dim] + j * compact_shp[third_dim] + k]); + } + } + } + return std::make_shared(this->type, this->shape, &res[0], this->type); + } + return nullptr; +} + +NodePtr GatherOp::InferValue(const NodePtrList &inputs, const DAttrs &attrs) { + for (auto i : inputs) { + if (i->NodeType() != NType::Tensor) { + return nullptr; + } + } + TypeId output_type = this->type; + tensor::TensorPtr res = nullptr; + switch (static_cast(output_type)) { + case TypeId::kNumberTypeUInt8: { + res = CalcGather(inputs, attrs); + break; + } + case TypeId::kNumberTypeInt8: { + res = CalcGather(inputs, attrs); + break; + } + case TypeId::kNumberTypeInt16: { + res = CalcGather(inputs, attrs); + break; + } + case TypeId::kNumberTypeInt32: { + res = CalcGather(inputs, attrs); + break; + } + case TypeId::kNumberTypeInt64: { + res = CalcGather(inputs, attrs); + break; + } + case TypeId::kNumberTypeUInt16: { + res = CalcGather(inputs, attrs); + break; + } + case TypeId::kNumberTypeUInt32: { + res = CalcGather(inputs, attrs); + break; + } + case TypeId::kNumberTypeUInt64: { + res = CalcGather(inputs, attrs); + break; + } + case TypeId::kNumberTypeFloat16: { + res = CalcGather(inputs, attrs); + break; + } + case TypeId::kNumberTypeFloat32: { + res = CalcGather(inputs, attrs); + break; + } + case TypeId::kNumberTypeFloat64: { + res = CalcGather(inputs, attrs); + break; + } + case TypeId::kNumberTypeBFloat16: { + res = CalcGather(inputs, attrs); + break; + } + default: + return nullptr; + } + return res == nullptr ? nullptr : std::make_shared(res); +} + +template +tensor::TensorPtr ConcatOp::CalcConcat(const NodePtrList &inputs, const DAttrs &attrs) { + constexpr size_t first_dim = 0; + constexpr size_t second_dim = 1; + constexpr size_t third_dim = 2; + int64_t axis = 0; + auto axis_node = inputs.back(); + if (axis_node->NodeType() == NType::Scalar) { + auto scalar_node = axis_node->As(); + axis = GetValue(scalar_node->data()); + } else { + return nullptr; + } + axis = axis < 0 ? SizeToLong(this->shape.size()) + axis : axis; + std::vector> inputs_tm; + for (const auto &t : inputs) { + (void)inputs_tm.emplace_back(ChangeDataToVec(t)); + } + std::vector> all_shps; + (void)std::transform(inputs.begin(), inputs.end(), std::back_inserter(all_shps), + [&axis](const NodePtr &t) { return CompactShape(t->shape, axis); }); + std::vector res; + if (all_shps.size() > 0) { + const size_t third_dim_size = all_shps[0][third_dim]; + const size_t first_dim_size = all_shps[0][first_dim]; + for (size_t i = 0; i < first_dim_size; i++) { + for (size_t t = 0; t < inputs_tm.size(); t++) { + for (size_t j = 0; j < all_shps[t][second_dim]; j++) { + for (size_t k = 0; k < third_dim_size; k++) { + (void)res.emplace_back(inputs_tm[t][i * all_shps[t][second_dim] * third_dim_size + j * third_dim_size + k]); + } + } + } + } + return std::make_shared(this->type, this->shape, &res[0], this->type); + } + return nullptr; +} + +NodePtr ConcatOp::InferValue(const NodePtrList &inputs, const DAttrs &attrs) { + for (auto i : inputs) { + if (i->NodeType() != NType::Tensor) { + return nullptr; + } + } + TypeId output_type = this->type; + tensor::TensorPtr res = nullptr; + switch (static_cast(output_type)) { + case TypeId::kNumberTypeUInt8: { + res = CalcConcat(inputs, attrs); + break; + } + case TypeId::kNumberTypeInt8: { + res = CalcConcat(inputs, attrs); + break; + } + case TypeId::kNumberTypeInt16: { + res = CalcConcat(inputs, attrs); + break; + } + case TypeId::kNumberTypeInt32: { + res = CalcConcat(inputs, attrs); + break; + } + case TypeId::kNumberTypeInt64: { + res = CalcConcat(inputs, attrs); + break; + } + case TypeId::kNumberTypeUInt16: { + res = CalcConcat(inputs, attrs); + break; + } + case TypeId::kNumberTypeUInt32: { + res = CalcConcat(inputs, attrs); + break; + } + case TypeId::kNumberTypeUInt64: { + res = CalcConcat(inputs, attrs); + break; + } + case TypeId::kNumberTypeFloat16: { + res = CalcConcat(inputs, attrs); + break; + } + case TypeId::kNumberTypeFloat32: { + res = CalcConcat(inputs, attrs); + break; + } + case TypeId::kNumberTypeFloat64: { + res = CalcConcat(inputs, attrs); + break; + } + case TypeId::kNumberTypeBFloat16: { + res = CalcConcat(inputs, attrs); + break; + } + default: + return nullptr; + } + return res == nullptr ? nullptr : std::make_shared(res); +} + +std::vector LayoutTransformOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { + CHECK_ATTR(attrs, kAttrSrcFormat); + CHECK_ATTR(attrs, kAttrDstFormat); + auto src_format = GetValue(attrs.find(kAttrSrcFormat)->second); + auto dst_format = GetValue(attrs.find(kAttrDstFormat)->second); + std::vector data_shape = inputs[0]->shape; + if (src_format == kOpFormat_NHWC) { + auto n = data_shape[0]; + auto h = data_shape[1]; + auto w = data_shape[2]; + auto c = data_shape[3]; + auto c_o_i = GkUtils::GetChannelInConvFormat(dst_format); + if (c_o_i == 0) { + c_o_i = 1; + } + auto c_o_o = c / c_o_i; + std::vector output_shape{n, c_o_o, h, w, c_o_i}; + return {output_shape}; + } + if (dst_format == kOpFormat_NHWC) { + auto n = data_shape[0]; + auto c_o_o = data_shape[1]; + auto h = data_shape[2]; + auto w = data_shape[3]; + auto c_o_i = data_shape[4]; + auto c = c_o_o * c_o_i; + std::vector output_shape{n, h, w, c}; + return {output_shape}; + } + // LayoutTransform between nchwnc + auto n = data_shape[0]; + auto c_o_o = data_shape[1]; + auto h = data_shape[2]; + auto w = data_shape[3]; + auto c_o_i = data_shape[4]; + auto c_o_i_new = GkUtils::GetChannelInConvFormat(dst_format); + if (c_o_i_new == 0) { + c_o_i_new = 1; + } + auto c_o_o_new = c_o_o * c_o_i / c_o_i_new; + std::vector output_shape{n, c_o_o_new, h, w, c_o_i_new}; + return {output_shape}; +} + +std::vector Pool2DOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { + CHECK_ATTR(attrs, "global"); + std::vector input_shape = inputs[0]->shape; + bool is_nhwc = input_shape.size() == 4; + int64_t n = input_shape[0]; + int64_t c; + int64_t h; + int64_t w; + if (is_nhwc) { + constexpr size_t h_idx = 1; + constexpr size_t w_idx = 2; + constexpr size_t c_idx = 3; + h = input_shape[h_idx]; + w = input_shape[w_idx]; + c = input_shape[c_idx]; + } else { + constexpr size_t c_idx = 1; + constexpr size_t h_idx = 2; + constexpr size_t w_idx = 3; + c = input_shape[c_idx]; + h = input_shape[h_idx]; + w = input_shape[w_idx]; + } + + if (GetValue(attrs.find("global")->second)) { + h = 1; + w = 1; + } else { + CHECK_ATTR(attrs, "strides"); + CHECK_ATTR(attrs, "kernel_size"); + CHECK_ATTR(attrs, "round_mode"); + std::vector strides = GetListInt(attrs.find("strides")->second); + std::vector kernels = GetListInt(attrs.find("kernel_size")->second); + if (AnfUtils::GetIntValue(attrs.find("round_mode")->second) == 0) { + // ceil mode + h = ((h - kernels[0] + strides[0] - 1) / strides[0]) + 1; + w = ((w - kernels[1] + strides[1] - 1) / strides[1]) + 1; + } else { + // round mode + h = ((h - kernels[0]) / strides[0]) + 1; + w = ((w - kernels[1]) / strides[1]) + 1; + } + } + if (is_nhwc) { + return {{n, h, w, c}}; + } else { + auto ci = input_shape[4]; + return {{n, c, h, w, ci}}; + } +} + +void ComplexOp::Check(const NodePtrList &inputs, const DAttrs &) { + if (inputs[0]->type != TypeId::kNumberTypeFloat32) { + MS_LOG(EXCEPTION) << "Complex's input[0] should be float32, but got " << TypeIdToString(inputs[0]->type, true); + } + if (inputs[0]->type != inputs[1]->type) { + MS_LOG(EXCEPTION) << "Complex's input[0] and inputs[1]'s type mismatch: " << TypeIdToString(inputs[0]->type, true) + << " vs " << TypeIdToString(inputs[1]->type, true); + } +} + +std::vector StandardNormalOp::InferShape(const NodePtrList &, const DAttrs &attrs) { + CHECK_ATTR(attrs, "shape"); + return {GetListInt(attrs.find("shape")->second)}; +} + +template +tensor::TensorPtr StridedSliceOnnxOp::CalcStridedSliceOnnx(const NodePtrList &inputs, const DAttrs &) const { + constexpr size_t input_index = 0; + constexpr size_t begin_index = 1; + constexpr size_t end_index = 2; + constexpr size_t axes_index = 3; + constexpr size_t stride_index = 4; + + ShapeVector input_shape = inputs[input_index]->shape; + std::vector begin = ChangeDataToVec(inputs[begin_index]); + std::vector end = ChangeDataToVec(inputs[end_index]); + std::vector axes = ChangeDataToVec(inputs[axes_index]); + std::vector stride = ChangeDataToVec(inputs[stride_index]); + + std::unordered_map> info; + for (size_t i = 0; i < axes.size(); i++) { + int axis = axes[i] < 0 ? axes[i] + SizeToInt(input_shape.size()) : axes[i]; + if (begin[i] < 0 || end[i] < 0 || stride[i] < 0) { + MS_LOG(INFO) << "Only do infervalue for StridedSliceOnnx when begin, end and stride are non-negative."; + return nullptr; + } + std::unordered_set pos; + int index = begin[i]; + while (index < end[i]) { + (void)pos.insert(IntToSize(index)); + index += stride[i]; + } + (void)info.emplace(axis, pos); + } + + TM *input_x = + static_cast(std::static_pointer_cast(inputs[input_index])->data()->data_c()); + + std::vector res; + + std::function func; + func = [&func, &input_x, &res, &info, &input_shape](size_t dim, size_t offset) { + if ((dim + 1) == input_shape.size()) { + for (size_t i = 0; i < LongToSize(input_shape[dim]); i++) { + if (info.count(SizeToInt(dim)) > 0) { + if (info[SizeToInt(dim)].count(i) > 0) { + (void)res.emplace_back(input_x[offset + i]); + } + } else { + (void)res.emplace_back(input_x[offset + i]); + } + } + } else if ((dim + 1) < input_shape.size()) { + size_t accu = 1; + for (size_t j = dim + 1; j < input_shape.size(); j++) { + accu *= LongToSize(input_shape[j]); + } + for (size_t i = 0; i < LongToSize(input_shape[dim]); i++) { + if (info.count(SizeToInt(dim)) > 0) { + if (info[SizeToInt(dim)].count(i) > 0) { + func(dim + 1, offset + i * accu); + } + } else { + func(dim + 1, offset + i * accu); + } + } + } + return; + }; + func(0, 0); + return std::make_shared(this->type, this->shape, &res[0], this->type); +} + +NodePtr StridedSliceOnnxOp::InferValue(const NodePtrList &inputs, const DAttrs &attrs) { + for (auto i : inputs) { + if (i->NodeType() != NType::Tensor) { + return nullptr; + } + } + TypeId output_type = this->type; + tensor::TensorPtr res = nullptr; + switch (static_cast(output_type)) { + case TypeId::kNumberTypeUInt8: { + res = CalcStridedSliceOnnx(inputs, attrs); + break; + } + case TypeId::kNumberTypeInt8: { + res = CalcStridedSliceOnnx(inputs, attrs); + break; + } + case TypeId::kNumberTypeInt16: { + res = CalcStridedSliceOnnx(inputs, attrs); + break; + } + case TypeId::kNumberTypeInt32: { + res = CalcStridedSliceOnnx(inputs, attrs); + break; + } + case TypeId::kNumberTypeInt64: { + res = CalcStridedSliceOnnx(inputs, attrs); + break; + } + case TypeId::kNumberTypeUInt16: { + res = CalcStridedSliceOnnx(inputs, attrs); + break; + } + case TypeId::kNumberTypeUInt32: { + res = CalcStridedSliceOnnx(inputs, attrs); + break; + } + case TypeId::kNumberTypeUInt64: { + res = CalcStridedSliceOnnx(inputs, attrs); + break; + } + case TypeId::kNumberTypeFloat16: { + res = CalcStridedSliceOnnx(inputs, attrs); + break; + } + case TypeId::kNumberTypeFloat32: { + res = CalcStridedSliceOnnx(inputs, attrs); + break; + } + case TypeId::kNumberTypeFloat64: { + res = CalcStridedSliceOnnx(inputs, attrs); + break; + } + case TypeId::kNumberTypeBFloat16: { + res = CalcStridedSliceOnnx(inputs, attrs); + break; + } + default: + return nullptr; + } + return res == nullptr ? nullptr : std::make_shared(res); +} + +void MatMulOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) { + CHECK_ATTR(prim->attrs(), "transpose_a"); + (void)abs_list->emplace_back(prim->GetAttr("transpose_a")->ToAbstract()); + CHECK_ATTR(prim->attrs(), "transpose_b"); + (void)abs_list->emplace_back(prim->GetAttr("transpose_b")->ToAbstract()); +} + +std::vector MatMulOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { + // the prim's infer shape does not supports batch dims + constexpr size_t kMatMulRank = 2; + if (inputs[0]->shape.size() > kMatMulRank || inputs[1]->shape.size() > kMatMulRank) { + NodePtrList new_inputs = inputs; + std::vector batches(inputs.size()); + auto cut_batches = [&new_inputs, &batches, kMatMulRank](size_t i) -> void { + const auto &shape_i = new_inputs[i]->shape; + if (shape_i.size() > kMatMulRank) { + DShape real_shape(shape_i.cend() - kMatMulRank, shape_i.cend()); + new_inputs[i] = std::make_shared(NodeBase{real_shape, new_inputs[i]->type, new_inputs[i]->format}); + batches[i].assign(shape_i.cbegin(), shape_i.cend() - kMatMulRank); + } + }; + + cut_batches(0); + cut_batches(1); + if (batches[0].size() != batches[1].size()) { + MS_LOG(EXCEPTION) << "The Matmul's batch rank should be equal, but got " << batches[0].size() << " vs " + << batches[1].size(); + } + DShape batch; + for (size_t i = 0; i < batches[0].size(); i++) { + if (batches[0][i] != batches[1][i]) { + if (batches[0][i] != 1 && batches[1][i] != 1) { + MS_LOG(EXCEPTION) << "The Matmul's batch dim is unmatched. got " << inputs[0]->shape << " and " + << inputs[1]->shape; + } + } + batch.push_back(std::max(batches[0][i], batches[1][i])); + } + + auto out_shape = PrimOp::InferShape(new_inputs, attrs)[0]; + // just reuse the `batch` vector + (void)batch.insert(batch.end(), out_shape.begin(), out_shape.end()); + return {batch}; + } + return PrimOp::InferShape(inputs, attrs); +} + +std::vector MatMulOp::InferType(const NodePtrList &inputs, const DAttrs &attrs) { + if (attrs.count("dst_type") != 0) { + return {attrs.find("dst_type")->second->cast()->type_id()}; + } + if (inputs[0]->type == TypeId::kNumberTypeInt8) { + return {TypeId::kNumberTypeInt32}; + } + return {inputs[0]->type}; +} +} // namespace mindspore::graphkernel::inner diff --git a/mindspore/ccsrc/backend/common/graph_kernel/model/op_node.h b/mindspore/ccsrc/backend/common/graph_kernel/model/op_node.h index a7dbacb88ec..57307145889 100644 --- a/mindspore/ccsrc/backend/common/graph_kernel/model/op_node.h +++ b/mindspore/ccsrc/backend/common/graph_kernel/model/op_node.h @@ -1,407 +1,407 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_OP_NODE_H_ -#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_OP_NODE_H_ - -#include -#include -#include -#include -#include "ops/primitive_c.h" -#include "backend/common/graph_kernel/model/node.h" -#include "ir/dtype/type.h" -#include "include/backend/visible.h" - -namespace mindspore::graphkernel::inner { -#define CHECK_ATTR(attrs, attr_name) \ - do { \ - if (attrs.count(attr_name) == 0) { \ - MS_LOG(EXCEPTION) << "The attr [" << attr_name << "] does not exist in [" << #attrs << "]"; \ - } \ - } while (0) - -class BACKEND_EXPORT PrimOp : public Node { - public: - enum class ComputeType : int { - VIRTUAL = 0, - RESHAPE = 1, - ELEMWISE = 2, - BROADCAST = 3, - REDUCE = 4, - OPAQUE = 5, - }; - - PrimOp(const std::string &op, ComputeType compute) - : Node({{}, TypeId::kNumberTypeBegin, kOpFormat_DEFAULT}), op_(op), compute_type_(compute) {} - ~PrimOp() = default; - - NodeBaseList Infer(const NodePtrList &inputs, const DAttrs &attrs); - - std::string ToString() const override; - NType NodeType() override { return NType::Primitive; } - - const std::string &op() const { return op_; } - ComputeType compute_type() const { return compute_type_; } - // infer output value when all inputs are constant - virtual NodePtr InferValue(const NodePtrList &inputs, const DAttrs &attrs); - - protected: - // Check node info before inference the shape/type/format. - virtual void Check(const NodePtrList &, const DAttrs &) {} - - // Infer format. assume all outputs have the same format. - virtual DFormat InferFormat(const NodePtrList &inputs, const DAttrs &) { return inputs[0]->format; } - - // Infer shape. returning an empty vector means using PrimitiveC's infer_shape function. - virtual std::vector InferShape(const NodePtrList &, const DAttrs &); - - // Infer type. returning an empty vector means using PrimitiveC's infer_type function. - virtual std::vector InferType(const NodePtrList &, const DAttrs &); - - // calculate const inputs, used for InferValue - template - tensor::TensorPtr CalcByOperator(const NodePtrList &inputs, const DAttrs &) const; - - // Gen PrimitiveC and abstract list to call PrimitiveC's inference function. - std::pair GenPrimAndAbstract(const NodePtrList &inputs, const DAttrs &attrs) const; - - // rectify abstract before calling PrimitiveC's inference function. - virtual void RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *) {} - - std::string op_; - ComputeType compute_type_; -}; -using PrimOpPtr = std::shared_ptr; - -class ReshapeOp : public PrimOp { - public: - explicit ReshapeOp(const std::string &op) : PrimOp(op, ComputeType::RESHAPE) {} - ~ReshapeOp() = default; - NodePtr InferValue(const NodePtrList &inputs, const DAttrs &) override; - - protected: - DFormat InferFormat(const NodePtrList &, const DAttrs &attrs) override { - return attrs.find("format") == attrs.end() ? kOpFormat_DEFAULT - : GetValue(attrs.find("format")->second); - } -}; - -class ElemwiseOp : public PrimOp { - public: - explicit ElemwiseOp(const std::string &op) : PrimOp(op, ComputeType::ELEMWISE) {} - ~ElemwiseOp() = default; - - protected: - std::vector InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; - DFormat InferFormat(const NodePtrList &inputs, const DAttrs &) override; -}; - -class BroadcastOp : public PrimOp { - public: - explicit BroadcastOp(const std::string &op) : PrimOp(op, ComputeType::BROADCAST) {} - ~BroadcastOp() = default; -}; - -class TileOp : public BroadcastOp { - public: - explicit TileOp(const std::string &op) : BroadcastOp(op) {} - ~TileOp() = default; -}; - -class ReduceOp : public PrimOp { - public: - explicit ReduceOp(const std::string &op) : PrimOp(op, ComputeType::REDUCE) {} - ~ReduceOp() = default; - - protected: - DFormat InferFormat(const NodePtrList &, const DAttrs &) override { return kOpFormat_DEFAULT; }; - void RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *input_abstract_ptr) override; -}; - -class ArgReduceOp : public ReduceOp { - public: - explicit ArgReduceOp(const std::string &op) : ReduceOp(op) {} - ~ArgReduceOp() = default; - - protected: - std::vector InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; - std::vector InferType(const NodePtrList &, const DAttrs &attrs) override; - void RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *input_abstract_ptr) override; -}; - -class OpaqueOp : public PrimOp { - public: - explicit OpaqueOp(const std::string &op) : PrimOp(op, ComputeType::OPAQUE) {} - ~OpaqueOp() = default; - - protected: - // for pclint warning: 1790 public base symbol of symbol has no non-destructor virtual functions - virtual void DoNothing() {} -}; - -class VirtualOp : public PrimOp { - public: - explicit VirtualOp(const std::string &op) : PrimOp(op, ComputeType::VIRTUAL) {} - ~VirtualOp() = default; -}; - -class TransposeOp : public OpaqueOp { - public: - explicit TransposeOp(const std::string &op) : OpaqueOp(op) {} - ~TransposeOp() = default; - - protected: - DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override; -}; - -class OneHotOp : public OpaqueOp { - public: - explicit OneHotOp(const std::string &op) : OpaqueOp(op) {} - ~OneHotOp() = default; - - protected: - void RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *input_abstract_ptr) override; -}; - -class CumSumOp : public OpaqueOp { - public: - explicit CumSumOp(const std::string &op) : OpaqueOp(op) {} - ~CumSumOp() = default; - - protected: - void RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *input_abstract_ptr) override; -}; - -class LayoutTransformOp : public OpaqueOp { - public: - explicit LayoutTransformOp(const std::string &op) : OpaqueOp(op) {} - ~LayoutTransformOp() = default; - - protected: - std::vector InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; - std::vector InferType(const NodePtrList &inputs, const DAttrs &) override { return {inputs[0]->type}; } - DFormat InferFormat(const NodePtrList &, const DAttrs &attrs) override { - return GetValue(attrs.find("dst_format")->second); - } -}; - -class ElemAnyOp : public OpaqueOp { - public: - explicit ElemAnyOp(const std::string &op) : OpaqueOp(op) {} - ~ElemAnyOp() = default; - - protected: - std::vector InferShape(const NodePtrList &, const DAttrs &attrs) override { - auto iter = attrs.find("empty_shape"); - if (iter != attrs.end() && GetValue(iter->second) == true) { - return {{}}; - } - return {{1}}; - } - std::vector InferType(const NodePtrList &, const DAttrs &) override { return {TypeId::kNumberTypeFloat32}; } -}; - -class ShapeOp : public OpaqueOp { - public: - explicit ShapeOp(const std::string &op) : OpaqueOp(op) {} - ~ShapeOp() = default; - NodePtr InferValue(const NodePtrList &inputs, const DAttrs &) override; - - protected: - std::vector InferShape(const NodePtrList &inputs, const DAttrs &) override { - return {{SizeToLong(inputs[0]->shape.size())}}; - } - std::vector InferType(const NodePtrList &, const DAttrs &) override { return {TypeId::kNumberTypeInt32}; } - DFormat InferFormat(const NodePtrList &, const DAttrs &) override { return kOpFormat_DEFAULT; }; -}; - -class ConstantOfShapeOp : public OpaqueOp { - public: - explicit ConstantOfShapeOp(const std::string &op) : OpaqueOp(op) {} - ~ConstantOfShapeOp() = default; - - NodePtr InferValue(const NodePtrList &inputs, const DAttrs &attrs) override; - - protected: - std::vector InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; - std::vector InferType(const NodePtrList &, const DAttrs &attrs) override { - return {static_cast(GetValue(attrs.find("data_type")->second))}; - } - DFormat InferFormat(const NodePtrList &, const DAttrs &) override { return kOpFormat_DEFAULT; } -}; - -class PadAkgOp : public OpaqueOp { - public: - explicit PadAkgOp(const std::string &op) : OpaqueOp(op) {} - ~PadAkgOp() = default; - - protected: - std::vector InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; - std::vector InferType(const NodePtrList &inputs, const DAttrs &) override { return {inputs[0]->type}; } -}; - -class UnPadAkgOp : public OpaqueOp { - public: - explicit UnPadAkgOp(const std::string &op) : OpaqueOp(op) {} - ~UnPadAkgOp() = default; - - protected: - std::vector InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; - std::vector InferType(const NodePtrList &inputs, const DAttrs &) override { return {inputs[0]->type}; } -}; - -class Conv2dOp : public OpaqueOp { - public: - explicit Conv2dOp(const std::string &op) : OpaqueOp(op) {} - ~Conv2dOp() = default; - static bool HadPad(const ShapeVector &pad_list, const std::string &pad_mode); - - protected: - std::vector InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; - std::vector InferType(const NodePtrList &inputs, const DAttrs &attrs) override; - DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override; -}; - -class GatherOp : public OpaqueOp { - public: - explicit GatherOp(const std::string &op) : OpaqueOp(op) {} - ~GatherOp() = default; - NodePtr InferValue(const NodePtrList &inputs, const DAttrs &attrs) override; - - protected: - template - tensor::TensorPtr CalcGather(const NodePtrList &inputs, const DAttrs &attrs) const; - DFormat InferFormat(const NodePtrList &, const DAttrs &) override { return kOpFormat_DEFAULT; }; - void RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *input_abstract_ptr) override; -}; - -class ConcatOp : public OpaqueOp { - public: - explicit ConcatOp(const std::string &op) : OpaqueOp(op) {} - ~ConcatOp() = default; - NodePtr InferValue(const NodePtrList &inputs, const DAttrs &attrs) override; - - protected: - template - tensor::TensorPtr CalcConcat(const NodePtrList &inputs, const DAttrs &attrs); - DFormat InferFormat(const NodePtrList &, const DAttrs &) override { return kOpFormat_DEFAULT; }; - void RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *input_abstract_ptr) override; -}; - -class CImagRealOp : public ElemwiseOp { - public: - explicit CImagRealOp(const std::string &op) : ElemwiseOp(op) {} - ~CImagRealOp() = default; - - protected: - void Check(const NodePtrList &inputs, const DAttrs &) override { - if (inputs[0]->type != TypeId::kNumberTypeComplex64) { - MS_LOG(EXCEPTION) << op_ << "'s input[0] should be complex64, but got " << TypeIdToString(inputs[0]->type, true); - } - }; - - std::vector InferShape(const NodePtrList &inputs, const DAttrs &) override { return {inputs[0]->shape}; } - std::vector InferType(const NodePtrList &, const DAttrs &) override { return {TypeId::kNumberTypeFloat32}; } -}; - -class Pool2DOp : public OpaqueOp { - public: - explicit Pool2DOp(const std::string &op) : OpaqueOp(op) {} - ~Pool2DOp() = default; - - protected: - std::vector InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; - std::vector InferType(const NodePtrList &inputs, const DAttrs &) override { return {inputs[0]->type}; } -}; - -class ComplexOp : public ElemwiseOp { - public: - explicit ComplexOp(const std::string &op) : ElemwiseOp(op) {} - ~ComplexOp() = default; - - protected: - void Check(const NodePtrList &inputs, const DAttrs &) override; - std::vector InferShape(const NodePtrList &inputs, const DAttrs &) override { return {inputs[0]->shape}; } - std::vector InferType(const NodePtrList &, const DAttrs &) override { return {TypeId::kNumberTypeComplex64}; } -}; - -class StandardNormalOp : public OpaqueOp { - public: - explicit StandardNormalOp(const std::string &op) : OpaqueOp(op) {} - ~StandardNormalOp() = default; - - protected: - std::vector InferShape(const NodePtrList &, const DAttrs &attrs) override; - std::vector InferType(const NodePtrList &, const DAttrs &) override { return {TypeId::kNumberTypeFloat32}; } - DFormat InferFormat(const NodePtrList &, const DAttrs &) override { return kOpFormat_DEFAULT; } -}; - -class StridedSliceOp : public OpaqueOp { - public: - explicit StridedSliceOp(const std::string &op) : OpaqueOp(op) {} - ~StridedSliceOp() = default; - void RectifyAbstract(const PrimitivePtr &p, AbstractBasePtrList *input_abstract_ptr) override { - input_abstract_ptr->push_back(p->GetAttr("begin_mask")->ToAbstract()); - input_abstract_ptr->push_back(p->GetAttr("end_mask")->ToAbstract()); - input_abstract_ptr->push_back(p->GetAttr("ellipsis_mask")->ToAbstract()); - input_abstract_ptr->push_back(p->GetAttr("new_axis_mask")->ToAbstract()); - input_abstract_ptr->push_back(p->GetAttr("shrink_axis_mask")->ToAbstract()); - } -}; - -class StridedSliceOnnxOp : public OpaqueOp { - public: - explicit StridedSliceOnnxOp(const std::string &op) : OpaqueOp(op) {} - ~StridedSliceOnnxOp() = default; - NodePtr InferValue(const NodePtrList &inputs, const DAttrs &attrs) override; - - protected: - template - tensor::TensorPtr CalcStridedSliceOnnx(const NodePtrList &inputs, const DAttrs &) const; - std::vector InferShape(const NodePtrList &, const DAttrs &attrs) override { - return GetValue>(attrs.find("output_shape")->second); - } - std::vector InferType(const NodePtrList &inputs, const DAttrs &) override { return {inputs[0]->type}; } - DFormat InferFormat(const NodePtrList &, const DAttrs &) override { return kOpFormat_DEFAULT; } -}; - -class MatMulOp : public OpaqueOp { - public: - explicit MatMulOp(const std::string &op) : OpaqueOp(op) {} - ~MatMulOp() = default; - - protected: - void RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *input_abstract_ptr) override; - std::vector InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; - std::vector InferType(const NodePtrList &inputs, const DAttrs &attrs) override; -}; - -class TupleGetItemOp : public VirtualOp { - public: - using VirtualOp::VirtualOp; - ~TupleGetItemOp() = default; -}; - -class PagedAttentionOp : public OpaqueOp { - public: - explicit PagedAttentionOp(const std::string &op) : OpaqueOp(op) {} - ~PagedAttentionOp() = default; - - protected: - void RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *input_abstract_ptr) override; -}; -} // namespace mindspore::graphkernel::inner -#endif +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_OP_NODE_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_OP_NODE_H_ + +#include +#include +#include +#include +#include "ops/primitive_c.h" +#include "backend/common/graph_kernel/model/node.h" +#include "ir/dtype/type.h" +#include "include/backend/visible.h" + +namespace mindspore::graphkernel::inner { +#define CHECK_ATTR(attrs, attr_name) \ + do { \ + if (attrs.count(attr_name) == 0) { \ + MS_LOG(EXCEPTION) << "The attr [" << attr_name << "] does not exist in [" << #attrs << "]"; \ + } \ + } while (0) + +class BACKEND_EXPORT PrimOp : public Node { + public: + enum class ComputeType : int { + VIRTUAL = 0, + RESHAPE = 1, + ELEMWISE = 2, + BROADCAST = 3, + REDUCE = 4, + OPAQUE = 5, + }; + + PrimOp(const std::string &op, ComputeType compute) + : Node({{}, TypeId::kNumberTypeBegin, kOpFormat_DEFAULT}), op_(op), compute_type_(compute) {} + ~PrimOp() = default; + + NodeBaseList Infer(const NodePtrList &inputs, const DAttrs &attrs); + + std::string ToString() const override; + NType NodeType() override { return NType::Primitive; } + + const std::string &op() const { return op_; } + ComputeType compute_type() const { return compute_type_; } + // infer output value when all inputs are constant + virtual NodePtr InferValue(const NodePtrList &inputs, const DAttrs &attrs); + + protected: + // Check node info before inference the shape/type/format. + virtual void Check(const NodePtrList &, const DAttrs &) {} + + // Infer format. assume all outputs have the same format. + virtual DFormat InferFormat(const NodePtrList &inputs, const DAttrs &) { return inputs[0]->format; } + + // Infer shape. returning an empty vector means using PrimitiveC's infer_shape function. + virtual std::vector InferShape(const NodePtrList &, const DAttrs &); + + // Infer type. returning an empty vector means using PrimitiveC's infer_type function. + virtual std::vector InferType(const NodePtrList &, const DAttrs &); + + // calculate const inputs, used for InferValue + template + tensor::TensorPtr CalcByOperator(const NodePtrList &inputs, const DAttrs &) const; + + // Gen PrimitiveC and abstract list to call PrimitiveC's inference function. + std::pair GenPrimAndAbstract(const NodePtrList &inputs, const DAttrs &attrs) const; + + // rectify abstract before calling PrimitiveC's inference function. + virtual void RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *) {} + + std::string op_; + ComputeType compute_type_; +}; +using PrimOpPtr = std::shared_ptr; + +class ReshapeOp : public PrimOp { + public: + explicit ReshapeOp(const std::string &op) : PrimOp(op, ComputeType::RESHAPE) {} + ~ReshapeOp() = default; + NodePtr InferValue(const NodePtrList &inputs, const DAttrs &) override; + + protected: + DFormat InferFormat(const NodePtrList &, const DAttrs &attrs) override { + return attrs.find("format") == attrs.end() ? kOpFormat_DEFAULT + : GetValue(attrs.find("format")->second); + } +}; + +class ElemwiseOp : public PrimOp { + public: + explicit ElemwiseOp(const std::string &op) : PrimOp(op, ComputeType::ELEMWISE) {} + ~ElemwiseOp() = default; + + protected: + std::vector InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; + DFormat InferFormat(const NodePtrList &inputs, const DAttrs &) override; +}; + +class BroadcastOp : public PrimOp { + public: + explicit BroadcastOp(const std::string &op) : PrimOp(op, ComputeType::BROADCAST) {} + ~BroadcastOp() = default; +}; + +class TileOp : public BroadcastOp { + public: + explicit TileOp(const std::string &op) : BroadcastOp(op) {} + ~TileOp() = default; +}; + +class ReduceOp : public PrimOp { + public: + explicit ReduceOp(const std::string &op) : PrimOp(op, ComputeType::REDUCE) {} + ~ReduceOp() = default; + + protected: + DFormat InferFormat(const NodePtrList &, const DAttrs &) override { return kOpFormat_DEFAULT; }; + void RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *input_abstract_ptr) override; +}; + +class ArgReduceOp : public ReduceOp { + public: + explicit ArgReduceOp(const std::string &op) : ReduceOp(op) {} + ~ArgReduceOp() = default; + + protected: + std::vector InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; + std::vector InferType(const NodePtrList &, const DAttrs &attrs) override; + void RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *input_abstract_ptr) override; +}; + +class OpaqueOp : public PrimOp { + public: + explicit OpaqueOp(const std::string &op) : PrimOp(op, ComputeType::OPAQUE) {} + ~OpaqueOp() = default; + + protected: + // for pclint warning: 1790 public base symbol of symbol has no non-destructor virtual functions + virtual void DoNothing() {} +}; + +class VirtualOp : public PrimOp { + public: + explicit VirtualOp(const std::string &op) : PrimOp(op, ComputeType::VIRTUAL) {} + ~VirtualOp() = default; +}; + +class TransposeOp : public OpaqueOp { + public: + explicit TransposeOp(const std::string &op) : OpaqueOp(op) {} + ~TransposeOp() = default; + + protected: + DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override; +}; + +class OneHotOp : public OpaqueOp { + public: + explicit OneHotOp(const std::string &op) : OpaqueOp(op) {} + ~OneHotOp() = default; + + protected: + void RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *input_abstract_ptr) override; +}; + +class CumSumOp : public OpaqueOp { + public: + explicit CumSumOp(const std::string &op) : OpaqueOp(op) {} + ~CumSumOp() = default; + + protected: + void RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *input_abstract_ptr) override; +}; + +class LayoutTransformOp : public OpaqueOp { + public: + explicit LayoutTransformOp(const std::string &op) : OpaqueOp(op) {} + ~LayoutTransformOp() = default; + + protected: + std::vector InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; + std::vector InferType(const NodePtrList &inputs, const DAttrs &) override { return {inputs[0]->type}; } + DFormat InferFormat(const NodePtrList &, const DAttrs &attrs) override { + return GetValue(attrs.find("dst_format")->second); + } +}; + +class ElemAnyOp : public OpaqueOp { + public: + explicit ElemAnyOp(const std::string &op) : OpaqueOp(op) {} + ~ElemAnyOp() = default; + + protected: + std::vector InferShape(const NodePtrList &, const DAttrs &attrs) override { + auto iter = attrs.find("empty_shape"); + if (iter != attrs.end() && GetValue(iter->second) == true) { + return {{}}; + } + return {{1}}; + } + std::vector InferType(const NodePtrList &, const DAttrs &) override { return {TypeId::kNumberTypeFloat32}; } +}; + +class ShapeOp : public OpaqueOp { + public: + explicit ShapeOp(const std::string &op) : OpaqueOp(op) {} + ~ShapeOp() = default; + NodePtr InferValue(const NodePtrList &inputs, const DAttrs &) override; + + protected: + std::vector InferShape(const NodePtrList &inputs, const DAttrs &) override { + return {{SizeToLong(inputs[0]->shape.size())}}; + } + std::vector InferType(const NodePtrList &, const DAttrs &) override { return {TypeId::kNumberTypeInt32}; } + DFormat InferFormat(const NodePtrList &, const DAttrs &) override { return kOpFormat_DEFAULT; }; +}; + +class ConstantOfShapeOp : public OpaqueOp { + public: + explicit ConstantOfShapeOp(const std::string &op) : OpaqueOp(op) {} + ~ConstantOfShapeOp() = default; + + NodePtr InferValue(const NodePtrList &inputs, const DAttrs &attrs) override; + + protected: + std::vector InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; + std::vector InferType(const NodePtrList &, const DAttrs &attrs) override { + return {static_cast(GetValue(attrs.find("data_type")->second))}; + } + DFormat InferFormat(const NodePtrList &, const DAttrs &) override { return kOpFormat_DEFAULT; } +}; + +class PadAkgOp : public OpaqueOp { + public: + explicit PadAkgOp(const std::string &op) : OpaqueOp(op) {} + ~PadAkgOp() = default; + + protected: + std::vector InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; + std::vector InferType(const NodePtrList &inputs, const DAttrs &) override { return {inputs[0]->type}; } +}; + +class UnPadAkgOp : public OpaqueOp { + public: + explicit UnPadAkgOp(const std::string &op) : OpaqueOp(op) {} + ~UnPadAkgOp() = default; + + protected: + std::vector InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; + std::vector InferType(const NodePtrList &inputs, const DAttrs &) override { return {inputs[0]->type}; } +}; + +class Conv2dOp : public OpaqueOp { + public: + explicit Conv2dOp(const std::string &op) : OpaqueOp(op) {} + ~Conv2dOp() = default; + static bool HadPad(const ShapeVector &pad_list, const std::string &pad_mode); + + protected: + std::vector InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; + std::vector InferType(const NodePtrList &inputs, const DAttrs &attrs) override; + DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override; +}; + +class GatherOp : public OpaqueOp { + public: + explicit GatherOp(const std::string &op) : OpaqueOp(op) {} + ~GatherOp() = default; + NodePtr InferValue(const NodePtrList &inputs, const DAttrs &attrs) override; + + protected: + template + tensor::TensorPtr CalcGather(const NodePtrList &inputs, const DAttrs &attrs) const; + DFormat InferFormat(const NodePtrList &, const DAttrs &) override { return kOpFormat_DEFAULT; }; + void RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *input_abstract_ptr) override; +}; + +class ConcatOp : public OpaqueOp { + public: + explicit ConcatOp(const std::string &op) : OpaqueOp(op) {} + ~ConcatOp() = default; + NodePtr InferValue(const NodePtrList &inputs, const DAttrs &attrs) override; + + protected: + template + tensor::TensorPtr CalcConcat(const NodePtrList &inputs, const DAttrs &attrs); + DFormat InferFormat(const NodePtrList &, const DAttrs &) override { return kOpFormat_DEFAULT; }; + void RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *input_abstract_ptr) override; +}; + +class CImagRealOp : public ElemwiseOp { + public: + explicit CImagRealOp(const std::string &op) : ElemwiseOp(op) {} + ~CImagRealOp() = default; + + protected: + void Check(const NodePtrList &inputs, const DAttrs &) override { + if (inputs[0]->type != TypeId::kNumberTypeComplex64) { + MS_LOG(EXCEPTION) << op_ << "'s input[0] should be complex64, but got " << TypeIdToString(inputs[0]->type, true); + } + }; + + std::vector InferShape(const NodePtrList &inputs, const DAttrs &) override { return {inputs[0]->shape}; } + std::vector InferType(const NodePtrList &, const DAttrs &) override { return {TypeId::kNumberTypeFloat32}; } +}; + +class Pool2DOp : public OpaqueOp { + public: + explicit Pool2DOp(const std::string &op) : OpaqueOp(op) {} + ~Pool2DOp() = default; + + protected: + std::vector InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; + std::vector InferType(const NodePtrList &inputs, const DAttrs &) override { return {inputs[0]->type}; } +}; + +class ComplexOp : public ElemwiseOp { + public: + explicit ComplexOp(const std::string &op) : ElemwiseOp(op) {} + ~ComplexOp() = default; + + protected: + void Check(const NodePtrList &inputs, const DAttrs &) override; + std::vector InferShape(const NodePtrList &inputs, const DAttrs &) override { return {inputs[0]->shape}; } + std::vector InferType(const NodePtrList &, const DAttrs &) override { return {TypeId::kNumberTypeComplex64}; } +}; + +class StandardNormalOp : public OpaqueOp { + public: + explicit StandardNormalOp(const std::string &op) : OpaqueOp(op) {} + ~StandardNormalOp() = default; + + protected: + std::vector InferShape(const NodePtrList &, const DAttrs &attrs) override; + std::vector InferType(const NodePtrList &, const DAttrs &) override { return {TypeId::kNumberTypeFloat32}; } + DFormat InferFormat(const NodePtrList &, const DAttrs &) override { return kOpFormat_DEFAULT; } +}; + +class StridedSliceOp : public OpaqueOp { + public: + explicit StridedSliceOp(const std::string &op) : OpaqueOp(op) {} + ~StridedSliceOp() = default; + void RectifyAbstract(const PrimitivePtr &p, AbstractBasePtrList *input_abstract_ptr) override { + input_abstract_ptr->push_back(p->GetAttr("begin_mask")->ToAbstract()); + input_abstract_ptr->push_back(p->GetAttr("end_mask")->ToAbstract()); + input_abstract_ptr->push_back(p->GetAttr("ellipsis_mask")->ToAbstract()); + input_abstract_ptr->push_back(p->GetAttr("new_axis_mask")->ToAbstract()); + input_abstract_ptr->push_back(p->GetAttr("shrink_axis_mask")->ToAbstract()); + } +}; + +class StridedSliceOnnxOp : public OpaqueOp { + public: + explicit StridedSliceOnnxOp(const std::string &op) : OpaqueOp(op) {} + ~StridedSliceOnnxOp() = default; + NodePtr InferValue(const NodePtrList &inputs, const DAttrs &attrs) override; + + protected: + template + tensor::TensorPtr CalcStridedSliceOnnx(const NodePtrList &inputs, const DAttrs &) const; + std::vector InferShape(const NodePtrList &, const DAttrs &attrs) override { + return GetValue>(attrs.find("output_shape")->second); + } + std::vector InferType(const NodePtrList &inputs, const DAttrs &) override { return {inputs[0]->type}; } + DFormat InferFormat(const NodePtrList &, const DAttrs &) override { return kOpFormat_DEFAULT; } +}; + +class MatMulOp : public OpaqueOp { + public: + explicit MatMulOp(const std::string &op) : OpaqueOp(op) {} + ~MatMulOp() = default; + + protected: + void RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *input_abstract_ptr) override; + std::vector InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; + std::vector InferType(const NodePtrList &inputs, const DAttrs &attrs) override; +}; + +class TupleGetItemOp : public VirtualOp { + public: + using VirtualOp::VirtualOp; + ~TupleGetItemOp() = default; +}; + +class PagedAttentionOp : public OpaqueOp { + public: + explicit PagedAttentionOp(const std::string &op) : OpaqueOp(op) {} + ~PagedAttentionOp() = default; + + protected: + void RectifyAbstract(const PrimitivePtr &, AbstractBasePtrList *input_abstract_ptr) override; +}; +} // namespace mindspore::graphkernel::inner +#endif diff --git a/mindspore/ccsrc/backend/common/graph_kernel/model/op_register.h b/mindspore/ccsrc/backend/common/graph_kernel/model/op_register.h index 2ebd6263bf9..82e435474d1 100644 --- a/mindspore/ccsrc/backend/common/graph_kernel/model/op_register.h +++ b/mindspore/ccsrc/backend/common/graph_kernel/model/op_register.h @@ -1,53 +1,53 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_OP_REGISTER_H_ -#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_OP_REGISTER_H_ - -#include -#include - -#include "utils/hash_map.h" -#include "backend/common/graph_kernel/model/op_node.h" -#include "include/backend/visible.h" - -namespace mindspore::graphkernel::inner { -using CreatorFunc = std::function; -class BACKEND_EXPORT OpRegistry { - public: - static OpRegistry &Instance() { - static OpRegistry instance{}; - return instance; - } - void Register(const std::string &op_name, const CreatorFunc &func) { (void)creators.emplace(op_name, func); } - - PrimOpPtr NewOp(const std::string &op) { - // "OpaqueOp" is registered by default. - return creators.find(op) == creators.end() ? creators["_opaque"](op) : creators[op](op); - } - - private: - OpRegistry() = default; - ~OpRegistry() = default; - - OpRegistry(const OpRegistry &) = delete; - OpRegistry(const OpRegistry &&) = delete; - OpRegistry &operator=(const OpRegistry &) = delete; - OpRegistry &operator=(const OpRegistry &&) = delete; - - mindspore::HashMap creators; -}; -} // namespace mindspore::graphkernel::inner -#endif +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_OP_REGISTER_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_OP_REGISTER_H_ + +#include +#include + +#include "utils/hash_map.h" +#include "backend/common/graph_kernel/model/op_node.h" +#include "include/backend/visible.h" + +namespace mindspore::graphkernel::inner { +using CreatorFunc = std::function; +class BACKEND_EXPORT OpRegistry { + public: + static OpRegistry &Instance() { + static OpRegistry instance{}; + return instance; + } + void Register(const std::string &op_name, const CreatorFunc &func) { (void)creators.emplace(op_name, func); } + + PrimOpPtr NewOp(const std::string &op) { + // "OpaqueOp" is registered by default. + return creators.find(op) == creators.end() ? creators["_opaque"](op) : creators[op](op); + } + + private: + OpRegistry() = default; + ~OpRegistry() = default; + + OpRegistry(const OpRegistry &) = delete; + OpRegistry(const OpRegistry &&) = delete; + OpRegistry &operator=(const OpRegistry &) = delete; + OpRegistry &operator=(const OpRegistry &&) = delete; + + mindspore::HashMap creators; +}; +} // namespace mindspore::graphkernel::inner +#endif diff --git a/mindspore/ccsrc/backend/common/optimizer/pass_manager.cc b/mindspore/ccsrc/backend/common/optimizer/pass_manager.cc index 3fdaa30520d..1cc87a05b0b 100644 --- a/mindspore/ccsrc/backend/common/optimizer/pass_manager.cc +++ b/mindspore/ccsrc/backend/common/optimizer/pass_manager.cc @@ -1,96 +1,96 @@ -/** - * Copyright 2019-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "include/backend/optimizer/pass_manager.h" -#include -#include -#include "utils/ms_context.h" -#include "include/common/debug/anf_ir_dump.h" -#include "backend/common/optimizer/cache_manager.h" - -namespace mindspore { -namespace opt { -PassManager::PassManager(const std::string &name, bool run_only_once) - : name_(name), passes_{}, run_only_once_(run_only_once), cache_manager_(std::make_shared()) {} - -void PassManager::AddPass(const PassPtr &pass) { - if (pass != nullptr) { - passes_.push_back(pass); - } -} - -bool PassManager::RunPass(const FuncGraphPtr &func_graph, size_t pass_id, const PassPtr &pass) const { - auto start_time = std::chrono::steady_clock::now(); - bool changed = pass->Run(func_graph); - constexpr auto kMicroSendUnit = 1000000; - auto end_time = std::chrono::steady_clock::now(); - std::chrono::duration> cost = end_time - start_time; - MS_LOG(INFO) << "Run pass " + GetPassFullname(pass_id, pass) + " in " << cost.count() << " us"; - return changed; -} - -std::string PassManager::GetPassFullname(size_t pass_id, const PassPtr &pass) const { - return std::string("hwopt_") + name() + "_" + std::to_string(pass_id) + "_" + pass->name(); -} - -void PassManager::DumpPassIR(const FuncGraphPtr &func_graph, const std::string &pass_fullname) const { -#ifdef ENABLE_DUMP_IR - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - static const auto enable_dump = !GetDumpConfig().disable_backend_dump; - if (context_ptr->CanDump(kAdvanced) && enable_dump) { - std::ostringstream oss; - oss << "verbose_ir_files" - << "/"; - oss << (pass_fullname + ".ir"); - DumpIR(oss.str(), func_graph, true); - } -#endif -} - -bool PassManager::Run(const FuncGraphPtr &func_graph, const std::vector &passes) const { - if (func_graph == nullptr) { - return false; - } - bool changed = false; - size_t num = 0; - for (const auto &pass : passes) { - if (pass != nullptr) { - pass->SetCacheManager(cache_manager_); - changed = RunPass(func_graph, num, pass) || changed; -#ifdef ENABLE_DUMP_IR - DumpPassIR(func_graph, GetPassFullname(num, pass)); -#endif - num++; - } - } - return changed; -} - -bool PassManager::Run(const FuncGraphPtr &func_graph) const { - bool changed = false; - // run all passes - bool change = true; - while (change) { - change = Run(func_graph, passes_); - changed = change || changed; - if (run_only_once_) { - break; - } - } - return changed; -} -} // namespace opt -} // namespace mindspore +/** + * Copyright 2019-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "include/backend/optimizer/pass_manager.h" +#include +#include +#include "utils/ms_context.h" +#include "include/common/debug/anf_ir_dump.h" +#include "backend/common/optimizer/cache_manager.h" + +namespace mindspore { +namespace opt { +PassManager::PassManager(const std::string &name, bool run_only_once) + : name_(name), passes_{}, run_only_once_(run_only_once), cache_manager_(std::make_shared()) {} + +void PassManager::AddPass(const PassPtr &pass) { + if (pass != nullptr) { + passes_.push_back(pass); + } +} + +bool PassManager::RunPass(const FuncGraphPtr &func_graph, size_t pass_id, const PassPtr &pass) const { + auto start_time = std::chrono::steady_clock::now(); + bool changed = pass->Run(func_graph); + constexpr auto kMicroSendUnit = 1000000; + auto end_time = std::chrono::steady_clock::now(); + std::chrono::duration> cost = end_time - start_time; + MS_LOG(INFO) << "Run pass " + GetPassFullname(pass_id, pass) + " in " << cost.count() << " us"; + return changed; +} + +std::string PassManager::GetPassFullname(size_t pass_id, const PassPtr &pass) const { + return std::string("hwopt_") + name() + "_" + std::to_string(pass_id) + "_" + pass->name(); +} + +void PassManager::DumpPassIR(const FuncGraphPtr &func_graph, const std::string &pass_fullname) const { +#ifdef ENABLE_DUMP_IR + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + static const auto enable_dump = !GetDumpConfig().disable_backend_dump; + if (context_ptr->CanDump(kAdvanced) && enable_dump) { + std::ostringstream oss; + oss << "verbose_ir_files" + << "/"; + oss << (pass_fullname + ".ir"); + DumpIR(oss.str(), func_graph, true); + } +#endif +} + +bool PassManager::Run(const FuncGraphPtr &func_graph, const std::vector &passes) const { + if (func_graph == nullptr) { + return false; + } + bool changed = false; + size_t num = 0; + for (const auto &pass : passes) { + if (pass != nullptr) { + pass->SetCacheManager(cache_manager_); + changed = RunPass(func_graph, num, pass) || changed; +#ifdef ENABLE_DUMP_IR + DumpPassIR(func_graph, GetPassFullname(num, pass)); +#endif + num++; + } + } + return changed; +} + +bool PassManager::Run(const FuncGraphPtr &func_graph) const { + bool changed = false; + // run all passes + bool change = true; + while (change) { + change = Run(func_graph, passes_); + changed = change || changed; + if (run_only_once_) { + break; + } + } + return changed; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/common/pass/add_training_attr.h b/mindspore/ccsrc/backend/common/pass/add_training_attr.h index bf5922d22ef..61291920ffb 100644 --- a/mindspore/ccsrc/backend/common/pass/add_training_attr.h +++ b/mindspore/ccsrc/backend/common/pass/add_training_attr.h @@ -1,35 +1,35 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_ADD_TRAINING_ATTR_H -#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_ADD_TRAINING_ATTR_H -#include -#include "ir/anf.h" -#include "include/common/utils/convert_utils.h" -#include "include/backend/optimizer/optimizer.h" - -namespace mindspore { -namespace opt { -class AddTrainingAttr : public PatternProcessPass { - public: - explicit AddTrainingAttr(bool multigraph = true) : PatternProcessPass("add_training_attr", multigraph) {} - ~AddTrainingAttr() override = default; - const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_ADD_TRAINING_ATTR_H +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_ADD_TRAINING_ATTR_H +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_ADD_TRAINING_ATTR_H +#include +#include "ir/anf.h" +#include "include/common/utils/convert_utils.h" +#include "include/backend/optimizer/optimizer.h" + +namespace mindspore { +namespace opt { +class AddTrainingAttr : public PatternProcessPass { + public: + explicit AddTrainingAttr(bool multigraph = true) : PatternProcessPass("add_training_attr", multigraph) {} + ~AddTrainingAttr() override = default; + const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_ADD_TRAINING_ATTR_H diff --git a/mindspore/ccsrc/backend/common/pass/communication_op_fusion.h b/mindspore/ccsrc/backend/common/pass/communication_op_fusion.h index 8879410bbde..5e51cacd9bd 100644 --- a/mindspore/ccsrc/backend/common/pass/communication_op_fusion.h +++ b/mindspore/ccsrc/backend/common/pass/communication_op_fusion.h @@ -1,99 +1,99 @@ -/** - * Copyright 2019-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_COMMUNICATION_OP_FUSION_H_ -#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_COMMUNICATION_OP_FUSION_H_ -#include -#include -#include -#include "include/backend/visible.h" -#include "include/backend/optimizer/pass.h" -#include "ir/func_graph.h" -#include "ir/anf.h" -#include "include/common/utils/utils.h" -#include "ops/array_op_name.h" -#include "ops/ascend_op_name.h" -#include "ops/framework_op_name.h" -#include "ops/other_op_name.h" - -namespace mindspore { -namespace opt { -struct CommunicationOpInfo { - std::vector communication_op_nodes; - std::vector input_grad_size; - std::vector input_grad_time; - std::string group_name; -}; - -class BACKEND_EXPORT CommunicationOpFusion : public Pass { - public: - explicit CommunicationOpFusion(const std::string &name, std::string op_name, size_t groups = 1) - : Pass(name), op_name_(std::move(op_name)), groups_(groups) {} - ~CommunicationOpFusion() override = default; - bool Run(const FuncGraphPtr &func_graph) override; - - private: - bool DoFusion(const FuncGraphPtr &func_graph, const CommunicationOpInfo &communication_op_info, - const std::vector &segment_index) const; - void GetAllReduceSplitSegment(const std::vector &nodes, int64_t threshold, - std::vector *segment_index) const; - AnfNodePtr CreateFusedCommunicationOp(const FuncGraphPtr &func_graph, - const CommunicationOpInfo &communication_op_info, size_t start_index, - size_t end_index) const; - bool GetSplitSegments(const CommunicationOpInfo &communication_op_info, std::vector *segment_index, - const std::string &group) const; - std::string op_name_; - size_t groups_ = 1; -}; - -class SendFusion : public CommunicationOpFusion { - public: - explicit SendFusion(size_t groups = 1) : CommunicationOpFusion("send_fusion", kSendOpName, groups) {} - ~SendFusion() override = default; -}; - -class RecvFusion : public CommunicationOpFusion { - public: - explicit RecvFusion(size_t groups = 1) : CommunicationOpFusion("recv_fusion", kReceiveOpName, groups) {} - ~RecvFusion() override = default; -}; - -class AllReduceFusion : public CommunicationOpFusion { - public: - explicit AllReduceFusion(size_t groups = 1) : CommunicationOpFusion("all_reduce_fusion", kAllReduceOpName, groups) {} - ~AllReduceFusion() override = default; -}; - -class AllGatherFusion : public CommunicationOpFusion { - public: - explicit AllGatherFusion(size_t groups = 1) : CommunicationOpFusion("all_gather_fusion", kAllGatherOpName, groups) {} - ~AllGatherFusion() override = default; -}; - -class BroadcastFusion : public CommunicationOpFusion { - public: - explicit BroadcastFusion(size_t groups = 1) : CommunicationOpFusion("broadcast_fusion", kBroadcastOpName, groups) {} - ~BroadcastFusion() override = default; -}; - -class ReduceScatterFusion : public CommunicationOpFusion { - public: - explicit ReduceScatterFusion(size_t groups = 1) - : CommunicationOpFusion("reduce_scatter_fusion", kReduceScatterOpName, groups) {} - ~ReduceScatterFusion() override = default; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_COMMUNICATION_OP_FUSION_H_ +/** + * Copyright 2019-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_COMMUNICATION_OP_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_COMMUNICATION_OP_FUSION_H_ +#include +#include +#include +#include "include/backend/visible.h" +#include "include/backend/optimizer/pass.h" +#include "ir/func_graph.h" +#include "ir/anf.h" +#include "include/common/utils/utils.h" +#include "ops/array_op_name.h" +#include "ops/ascend_op_name.h" +#include "ops/framework_op_name.h" +#include "ops/other_op_name.h" + +namespace mindspore { +namespace opt { +struct CommunicationOpInfo { + std::vector communication_op_nodes; + std::vector input_grad_size; + std::vector input_grad_time; + std::string group_name; +}; + +class BACKEND_EXPORT CommunicationOpFusion : public Pass { + public: + explicit CommunicationOpFusion(const std::string &name, std::string op_name, size_t groups = 1) + : Pass(name), op_name_(std::move(op_name)), groups_(groups) {} + ~CommunicationOpFusion() override = default; + bool Run(const FuncGraphPtr &func_graph) override; + + private: + bool DoFusion(const FuncGraphPtr &func_graph, const CommunicationOpInfo &communication_op_info, + const std::vector &segment_index) const; + void GetAllReduceSplitSegment(const std::vector &nodes, int64_t threshold, + std::vector *segment_index) const; + AnfNodePtr CreateFusedCommunicationOp(const FuncGraphPtr &func_graph, + const CommunicationOpInfo &communication_op_info, size_t start_index, + size_t end_index) const; + bool GetSplitSegments(const CommunicationOpInfo &communication_op_info, std::vector *segment_index, + const std::string &group) const; + std::string op_name_; + size_t groups_ = 1; +}; + +class SendFusion : public CommunicationOpFusion { + public: + explicit SendFusion(size_t groups = 1) : CommunicationOpFusion("send_fusion", kSendOpName, groups) {} + ~SendFusion() override = default; +}; + +class RecvFusion : public CommunicationOpFusion { + public: + explicit RecvFusion(size_t groups = 1) : CommunicationOpFusion("recv_fusion", kReceiveOpName, groups) {} + ~RecvFusion() override = default; +}; + +class AllReduceFusion : public CommunicationOpFusion { + public: + explicit AllReduceFusion(size_t groups = 1) : CommunicationOpFusion("all_reduce_fusion", kAllReduceOpName, groups) {} + ~AllReduceFusion() override = default; +}; + +class AllGatherFusion : public CommunicationOpFusion { + public: + explicit AllGatherFusion(size_t groups = 1) : CommunicationOpFusion("all_gather_fusion", kAllGatherOpName, groups) {} + ~AllGatherFusion() override = default; +}; + +class BroadcastFusion : public CommunicationOpFusion { + public: + explicit BroadcastFusion(size_t groups = 1) : CommunicationOpFusion("broadcast_fusion", kBroadcastOpName, groups) {} + ~BroadcastFusion() override = default; +}; + +class ReduceScatterFusion : public CommunicationOpFusion { + public: + explicit ReduceScatterFusion(size_t groups = 1) + : CommunicationOpFusion("reduce_scatter_fusion", kReduceScatterOpName, groups) {} + ~ReduceScatterFusion() override = default; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_COMMUNICATION_OP_FUSION_H_ diff --git a/mindspore/ccsrc/backend/common/session/CMakeLists.txt b/mindspore/ccsrc/backend/common/session/CMakeLists.txt index f3b958b38b4..f56ceb8cafc 100644 --- a/mindspore/ccsrc/backend/common/session/CMakeLists.txt +++ b/mindspore/ccsrc/backend/common/session/CMakeLists.txt @@ -1,26 +1,26 @@ -file(GLOB_RECURSE _SESSION_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} - "kernel_build_client.cc" - "kernel_graph.cc" - "kernel_graph_mgr.cc" - "exec_order_builder.cc" - "session_basic.cc" - "session_factory.cc" - "executor.cc" - "executor_manager.cc" - "anf_runtime_algorithm.cc" - "py_execute_utils.cc" - "debug_register.cc" - "single_kernel_graph.cc" -) - -if("${ENABLE_HIDDEN}" STREQUAL "OFF" AND NOT MSVC) - string(REPLACE " -Werror " " " CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") - string(REPLACE " -fvisibility=hidden" " -fvisibility=default" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") -endif() - -if(CMAKE_SYSTEM_NAME MATCHES "Darwin") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-overloaded-virtual -Wno-delete-abstract-non-virtual-dtor") -endif() - -set_property(SOURCE ${_SESSION_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_SESSION) -add_library(_mindspore_backend_common_session_obj OBJECT ${_SESSION_SRC_LIST}) +file(GLOB_RECURSE _SESSION_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + "kernel_build_client.cc" + "kernel_graph.cc" + "kernel_graph_mgr.cc" + "exec_order_builder.cc" + "session_basic.cc" + "session_factory.cc" + "executor.cc" + "executor_manager.cc" + "anf_runtime_algorithm.cc" + "py_execute_utils.cc" + "debug_register.cc" + "single_kernel_graph.cc" +) + +if("${ENABLE_HIDDEN}" STREQUAL "OFF" AND NOT MSVC) + string(REPLACE " -Werror " " " CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") + string(REPLACE " -fvisibility=hidden" " -fvisibility=default" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") +endif() + +if(CMAKE_SYSTEM_NAME MATCHES "Darwin") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-overloaded-virtual -Wno-delete-abstract-non-virtual-dtor") +endif() + +set_property(SOURCE ${_SESSION_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_SESSION) +add_library(_mindspore_backend_common_session_obj OBJECT ${_SESSION_SRC_LIST}) diff --git a/mindspore/ccsrc/backend/common/session/executor_manager.h b/mindspore/ccsrc/backend/common/session/executor_manager.h index 9883b26d6a9..fe51416368a 100644 --- a/mindspore/ccsrc/backend/common/session/executor_manager.h +++ b/mindspore/ccsrc/backend/common/session/executor_manager.h @@ -1,50 +1,50 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_MANAGER_H_ -#define MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_MANAGER_H_ -#include -#include -#include -#include -#include "backend/common/session/executor.h" -#include "include/backend/visible.h" - -namespace mindspore::session { -class Executor; -class BACKEND_EXPORT ExecutorManager { - public: - static ExecutorManager &Instance(); - std::shared_ptr GetExecutor(const std::string &device_name, uint32_t device_id); - void OnEvent(const ExecutorEvent &event); - void Clear(); - void ClearDoneTasks() { - for (const auto &item : executors_) { - auto &executor = item.second; - if (executor != nullptr) { - executor->ClearDoneTasks(); - } - } - } - - private: - ExecutorManager() = default; - ~ExecutorManager() = default; - DISABLE_COPY_AND_ASSIGN(ExecutorManager) - - std::map> executors_; -}; -} // namespace mindspore::session -#endif // MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_MANAGER_H_ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_MANAGER_H_ +#define MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_MANAGER_H_ +#include +#include +#include +#include +#include "backend/common/session/executor.h" +#include "include/backend/visible.h" + +namespace mindspore::session { +class Executor; +class BACKEND_EXPORT ExecutorManager { + public: + static ExecutorManager &Instance(); + std::shared_ptr GetExecutor(const std::string &device_name, uint32_t device_id); + void OnEvent(const ExecutorEvent &event); + void Clear(); + void ClearDoneTasks() { + for (const auto &item : executors_) { + auto &executor = item.second; + if (executor != nullptr) { + executor->ClearDoneTasks(); + } + } + } + + private: + ExecutorManager() = default; + ~ExecutorManager() = default; + DISABLE_COPY_AND_ASSIGN(ExecutorManager) + + std::map> executors_; +}; +} // namespace mindspore::session +#endif // MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_MANAGER_H_ diff --git a/mindspore/ccsrc/backend/graph_compiler/graph_partition.h b/mindspore/ccsrc/backend/graph_compiler/graph_partition.h index 1478fe7ad68..12ce0d66c59 100644 --- a/mindspore/ccsrc/backend/graph_compiler/graph_partition.h +++ b/mindspore/ccsrc/backend/graph_compiler/graph_partition.h @@ -1,49 +1,49 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_VM_GRAPH_PARTITION_H_ -#define MINDSPORE_CCSRC_VM_GRAPH_PARTITION_H_ -#include -#include -#include -#include "ir/anf.h" -#include "ir/func_graph.h" -#include "ir/graph_utils.h" -#include "base/base_ref.h" -#include "include/backend/visible.h" - -namespace mindspore { -constexpr char kMsConvert[] = "ms"; -constexpr char kMsVm[] = "vm"; -constexpr char kGeVm[] = "ge"; - -namespace compile { -class GraphPartition { - public: - explicit GraphPartition(const std::vector &cut_list, const std::string &backend_name); - ~GraphPartition() = default; - std::vector Partition(const FuncGraphPtr &graph, bool *multi_target = nullptr); - - private: - bool IsCut(const AnfNodePtr &node); - std::vector cut_list_; - std::string backend_name_; -}; - -using GraphPartitionPtr = std::shared_ptr; -} // namespace compile -} // namespace mindspore -#endif // MINDSPORE_CCSRC_VM_GRAPH_PARTITION_H_ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_VM_GRAPH_PARTITION_H_ +#define MINDSPORE_CCSRC_VM_GRAPH_PARTITION_H_ +#include +#include +#include +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "ir/graph_utils.h" +#include "base/base_ref.h" +#include "include/backend/visible.h" + +namespace mindspore { +constexpr char kMsConvert[] = "ms"; +constexpr char kMsVm[] = "vm"; +constexpr char kGeVm[] = "ge"; + +namespace compile { +class GraphPartition { + public: + explicit GraphPartition(const std::vector &cut_list, const std::string &backend_name); + ~GraphPartition() = default; + std::vector Partition(const FuncGraphPtr &graph, bool *multi_target = nullptr); + + private: + bool IsCut(const AnfNodePtr &node); + std::vector cut_list_; + std::string backend_name_; +}; + +using GraphPartitionPtr = std::shared_ptr; +} // namespace compile +} // namespace mindspore +#endif // MINDSPORE_CCSRC_VM_GRAPH_PARTITION_H_ diff --git a/mindspore/ccsrc/backend/operator/CMakeLists.txt b/mindspore/ccsrc/backend/operator/CMakeLists.txt old mode 100755 new mode 100644 index 0bf4c0fc625..39892044e0e --- a/mindspore/ccsrc/backend/operator/CMakeLists.txt +++ b/mindspore/ccsrc/backend/operator/CMakeLists.txt @@ -1,8 +1,8 @@ -file(GLOB_RECURSE _OPERATOR_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") -set_property(SOURCE ${_OPERATOR_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_VM) -add_library(_mindspore_backend_operator_obj OBJECT ${_OPERATOR_SRC_LIST}) - -if("${ENABLE_HIDDEN}" STREQUAL "OFF" AND NOT MSVC) - string(REPLACE " -Werror " " " CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") - string(REPLACE " -fvisibility=hidden" " -fvisibility=default" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") +file(GLOB_RECURSE _OPERATOR_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +set_property(SOURCE ${_OPERATOR_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_VM) +add_library(_mindspore_backend_operator_obj OBJECT ${_OPERATOR_SRC_LIST}) + +if("${ENABLE_HIDDEN}" STREQUAL "OFF" AND NOT MSVC) + string(REPLACE " -Werror " " " CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") + string(REPLACE " -fvisibility=hidden" " -fvisibility=default" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") endif() \ No newline at end of file diff --git a/mindspore/ccsrc/backend/operator/ops_backend_infer_function.cc b/mindspore/ccsrc/backend/operator/ops_backend_infer_function.cc old mode 100755 new mode 100644 diff --git a/mindspore/ccsrc/backend/operator/ops_backend_infer_function.h b/mindspore/ccsrc/backend/operator/ops_backend_infer_function.h old mode 100755 new mode 100644 diff --git a/mindspore/ccsrc/distributed/embedding_cache/embedding_hash_map.cc b/mindspore/ccsrc/distributed/embedding_cache/embedding_hash_map.cc old mode 100755 new mode 100644 diff --git a/mindspore/ccsrc/frontend/expander/CMakeLists.txt b/mindspore/ccsrc/frontend/expander/CMakeLists.txt index 9efbd38ade5..588cca35fc4 100644 --- a/mindspore/ccsrc/frontend/expander/CMakeLists.txt +++ b/mindspore/ccsrc/frontend/expander/CMakeLists.txt @@ -1,4 +1,4 @@ -file(GLOB_RECURSE _EXPANDER_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") -set_property(SOURCE ${_EXPANDER_SRC_FILES} PROPERTY COMPILE_DEFINITIONS - SUBMODULE_ID=mindspore::SubModuleId::SM_ANALYZER) -add_library(_mindspore_frontend_expander_obj OBJECT ${_EXPANDER_SRC_FILES}) +file(GLOB_RECURSE _EXPANDER_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +set_property(SOURCE ${_EXPANDER_SRC_FILES} PROPERTY COMPILE_DEFINITIONS + SUBMODULE_ID=mindspore::SubModuleId::SM_ANALYZER) +add_library(_mindspore_frontend_expander_obj OBJECT ${_EXPANDER_SRC_FILES}) diff --git a/mindspore/ccsrc/frontend/operator/CMakeLists.txt b/mindspore/ccsrc/frontend/operator/CMakeLists.txt index 4ac71811b7a..270410ee777 100644 --- a/mindspore/ccsrc/frontend/operator/CMakeLists.txt +++ b/mindspore/ccsrc/frontend/operator/CMakeLists.txt @@ -1,4 +1,4 @@ -file(GLOB_RECURSE _OPERATOR_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") -set_property(SOURCE ${_OPERATOR_SRC_FILES} PROPERTY COMPILE_DEFINITIONS - SUBMODULE_ID=mindspore::SubModuleId::SM_ANALYZER) -add_library(_mindspore_frontend_operator_obj OBJECT ${_OPERATOR_SRC_FILES}) +file(GLOB_RECURSE _OPERATOR_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +set_property(SOURCE ${_OPERATOR_SRC_FILES} PROPERTY COMPILE_DEFINITIONS + SUBMODULE_ID=mindspore::SubModuleId::SM_ANALYZER) +add_library(_mindspore_frontend_operator_obj OBJECT ${_OPERATOR_SRC_FILES}) diff --git a/mindspore/ccsrc/frontend/operator/ops.h b/mindspore/ccsrc/frontend/operator/ops.h old mode 100755 new mode 100644 diff --git a/mindspore/ccsrc/frontend/operator/ops_extends.cc b/mindspore/ccsrc/frontend/operator/ops_extends.cc old mode 100755 new mode 100644 diff --git a/mindspore/ccsrc/frontend/optimizer/CMakeLists.txt b/mindspore/ccsrc/frontend/optimizer/CMakeLists.txt index f4d0eb29be9..5f71c97b822 100644 --- a/mindspore/ccsrc/frontend/optimizer/CMakeLists.txt +++ b/mindspore/ccsrc/frontend/optimizer/CMakeLists.txt @@ -1,9 +1,9 @@ -file(GLOB_RECURSE _OPTIMIZER_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") - -if(CMAKE_SYSTEM_NAME MATCHES "Darwin") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-delete-non-abstract-non-virtual-dtor") -endif() - -set_property(SOURCE ${_OPTIMIZER_SRC_FILES} PROPERTY COMPILE_DEFINITIONS - SUBMODULE_ID=mindspore::SubModuleId::SM_OPTIMIZER) -add_library(_mindspore_frontend_optimizer_obj OBJECT ${_OPTIMIZER_SRC_FILES}) +file(GLOB_RECURSE _OPTIMIZER_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") + +if(CMAKE_SYSTEM_NAME MATCHES "Darwin") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-delete-non-abstract-non-virtual-dtor") +endif() + +set_property(SOURCE ${_OPTIMIZER_SRC_FILES} PROPERTY COMPILE_DEFINITIONS + SUBMODULE_ID=mindspore::SubModuleId::SM_OPTIMIZER) +add_library(_mindspore_frontend_optimizer_obj OBJECT ${_OPTIMIZER_SRC_FILES}) diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.cc index cc3e31c388a..53b1741a51f 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.cc +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.cc @@ -1,490 +1,490 @@ -/** - * Copyright 2024 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h" - -#include -#include -#include -#include - -#include "ir/value.h" -#include "frontend/parallel/auto_parallel/rec_core/rec_graph.h" -#include "frontend/parallel/auto_parallel/rec_core/rec_tensor.h" -#include "frontend/parallel/ops_info/operator_info.h" - -namespace mindspore { -namespace parallel { -const TensorParam MakeTensor(int64_t n, int64_t c, int64_t h, int64_t w) { - TensorParam new_tensor; - new_tensor.tensor_type = kFloat32; - new_tensor.tensor_shape.shape_n = n; - new_tensor.tensor_shape.shape_c = c; - new_tensor.tensor_shape.shape_h = h; - new_tensor.tensor_shape.shape_w = w; - const TensorParam &tensor = new_tensor; - return tensor; -} - -Graph::NodeType MakeNewOperator(const std::vector> &ops, size_t iter_ops) { - Graph::NodeType NewOp; - NewOp.name = ops[iter_ops]->name(); - NewOp.info = InfoType::kApplication; - - auto pos = ops[iter_ops]->name().find("Info"); - auto name = ops[iter_ops]->name().substr(0, pos); - auto op_type = ops[iter_ops]->type(); - auto idx = DictOpType.find(op_type); - if (idx != DictOpType.end()) { - NewOp.apply.op_type = DictOpType.at(op_type); - } else if (name == STAND_ALONE) { - MS_LOG(INFO) << ops[iter_ops]->type() << ": standalone operator."; - NewOp.apply.op_type = OperatorType::kRecStandAlone; - } else if (name == BATCH_PARALLEL) { - MS_LOG(INFO) << ops[iter_ops]->type() << ": batch parallel operator."; - NewOp.apply.op_type = OperatorType::kRecBatchParallel; - } else { - NewOp.apply.op_type = OperatorType::kRecUnknownType; - MS_LOG(INFO) << ops[iter_ops]->name() << ": Unknown operator type " << op_type; - } - - if (ops[iter_ops]->outputs_shape().size() == SIZE_ZERO) { - MS_LOG(EXCEPTION) << ops[iter_ops]->name() << " outputs shape is empty."; - } - - if (ops[iter_ops]->outputs_shape()[0].size() == SIZE_FOUR) { - NewOp.tensor_parm = MakeTensor(ops[iter_ops]->outputs_shape()[0][0], ops[iter_ops]->outputs_shape()[0][1], - ops[iter_ops]->outputs_shape()[INDEX_ZERO][INDEX_TWO], - ops[iter_ops]->outputs_shape()[INDEX_ZERO][INDEX_THREE]); - } else if (ops[iter_ops]->outputs_shape()[0].size() == SIZE_THREE) { - NewOp.tensor_parm = MakeTensor(1, ops[iter_ops]->outputs_shape()[0][0], ops[iter_ops]->outputs_shape()[0][1], - ops[iter_ops]->outputs_shape()[INDEX_ZERO][INDEX_TWO]); - } else if (ops[iter_ops]->outputs_shape()[0].size() == SIZE_TWO) { - NewOp.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->outputs_shape()[0][0], ops[iter_ops]->outputs_shape()[0][1]); - } else if (ops[iter_ops]->outputs_shape()[0].size() == SIZE_ONE) { - NewOp.tensor_parm = MakeTensor(1, 1, 1, ops[iter_ops]->outputs_shape()[0][0]); - } else if (ops[iter_ops]->outputs_shape()[0].size() == SIZE_ZERO) { - NewOp.tensor_parm = MakeTensor(1, 1, 1, 1); - } else { - MS_LOG(WARNING) << ops[iter_ops]->name() << ": output tensor shape is unexpected."; - } - - CompleteOperatorInputs(ops, iter_ops, &NewOp); - MS_LOG(INFO) << "Node " << NewOp.name << "created successfully" - << " its input is " << ops[iter_ops]->inputs_shape() << " and its output is " - << ops[iter_ops]->outputs_shape() << "."; - return NewOp; -} - -void CompleteOperatorInputs(const std::vector> &ops, const size_t iter_ops, - Graph::NodeType *NewTensor) { - size_t input_tensor_size = ops[iter_ops]->inputs_shape().size(); - if (ops[iter_ops]->type() == STACK) { - input_tensor_size = 1; - } - if (input_tensor_size > MAX_INPUT_NUM) { - MS_LOG(EXCEPTION) << ops[iter_ops]->name() << " input tensor " << input_tensor_size << " num exceeds limit(" - << MAX_INPUT_NUM << ")."; - } - - for (size_t iter_input_tensors = 0; iter_input_tensors < input_tensor_size; iter_input_tensors++) { - if (ops[iter_ops]->inputs_shape()[iter_input_tensors].size() == SIZE_FOUR) { - Complete4DInputs(ops, iter_ops, iter_input_tensors, NewTensor); - } else if (ops[iter_ops]->inputs_shape()[iter_input_tensors].size() == SIZE_THREE) { - NewTensor->apply.arguments[iter_input_tensors] = - MakeTensor(1, ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ZERO], - ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ONE], - ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_TWO]); - } else if (ops[iter_ops]->inputs_shape()[iter_input_tensors].size() == SIZE_TWO) { - Complete2DInputs(ops, iter_ops, iter_input_tensors, NewTensor); - } else if (ops[iter_ops]->inputs_shape()[iter_input_tensors].size() == SIZE_ONE) { - NewTensor->apply.arguments[iter_input_tensors] = - MakeTensor(1, 1, 1, ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ZERO]); - } else if (ops[iter_ops]->inputs_shape()[iter_input_tensors].size() == 0) { - NewTensor->apply.arguments[iter_input_tensors] = MakeTensor(1, 1, 1, 1); - } else { - MS_LOG(WARNING) << ops[iter_ops]->name() << ": input tensor shape is unexpected."; - } - } -} - -void Complete2DInputs(const std::vector> &ops, const size_t iter_ops, - const size_t iter_input_tensors, Graph::NodeType *NewTensor) { - if (NewTensor->apply.op_type == OperatorType::kRecMatMul) { - auto input_value = ops[iter_ops]->input_value(); - bool transpose_a = input_value[2]->cast()->value(); - bool transpose_b = input_value[3]->cast()->value(); - if (transpose_a && (iter_input_tensors == 0)) { - NewTensor->apply.arguments[iter_input_tensors] = - MakeTensor(1, 1, ops[iter_ops]->inputs_shape()[iter_input_tensors][1], - ops[iter_ops]->inputs_shape()[iter_input_tensors][0]); - } else if (transpose_b && (iter_input_tensors == 1)) { - NewTensor->apply.arguments[iter_input_tensors] = - MakeTensor(1, 1, ops[iter_ops]->inputs_shape()[iter_input_tensors][1], - ops[iter_ops]->inputs_shape()[iter_input_tensors][0]); - } else { - NewTensor->apply.arguments[iter_input_tensors] = - MakeTensor(1, 1, ops[iter_ops]->inputs_shape()[iter_input_tensors][0], - ops[iter_ops]->inputs_shape()[iter_input_tensors][1]); - } - } else { - NewTensor->apply.arguments[iter_input_tensors] = MakeTensor( - 1, 1, ops[iter_ops]->inputs_shape()[iter_input_tensors][0], ops[iter_ops]->inputs_shape()[iter_input_tensors][1]); - } -} - -void Complete4DInputs(const std::vector> &ops, const size_t iter_ops, - const size_t iter_input_tensors, Graph::NodeType *NewTensor) { - if (NewTensor->apply.op_type == OperatorType::kRecBatchMatMul) { - auto input_value = ops[iter_ops]->input_value(); - bool transpose_a = input_value[2]->cast()->value(); - bool transpose_b = input_value[3]->cast()->value(); - if (transpose_a && (iter_input_tensors == 0)) { - NewTensor->apply.arguments[iter_input_tensors] = - MakeTensor(ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ZERO], - ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ONE], - ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_THREE], - ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_TWO]); - } else if (transpose_b && (iter_input_tensors == 1)) { - NewTensor->apply.arguments[iter_input_tensors] = - MakeTensor(ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ZERO], - ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ONE], - ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_THREE], - ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_TWO]); - } else { - NewTensor->apply.arguments[iter_input_tensors] = - MakeTensor(ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ZERO], - ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ONE], - ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_TWO], - ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_THREE]); - } - } else { - NewTensor->apply.arguments[iter_input_tensors] = - MakeTensor(ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ZERO], - ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ONE], - ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_TWO], - ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_THREE]); - } -} - -std::shared_ptr ParseGraph(const std::vector> &ops, - const std::vector> &input_tensor_names) { - std::shared_ptr graph = std::make_shared(); - constexpr size_t MAX_OP_NUM = SIZE_MAX / 2; - if (ops.size() > MAX_OP_NUM) { - MS_LOG(EXCEPTION) << "Total number of operators is bigger than " << MAX_OP_NUM; - } - - for (size_t iter_ops = 0; iter_ops < ops.size(); iter_ops++) { - Graph::NodeType NewOp = MakeNewOperator(ops, iter_ops); - NewOp.param_name = ops[iter_ops]->get_involved_param_name(); - graph->nodes.push_back(NewOp); - } - MakeEdge(input_tensor_names, graph); - - return graph; -} - -void MakeEdge(const std::vector> &input_tensor_names, const std::shared_ptr &graph) { - for (size_t iter_i = 0; iter_i < input_tensor_names.size(); iter_i++) { - for (size_t iter_j = 1; iter_j < input_tensor_names[iter_i].size(); iter_j++) { - size_t head_node_index = GetIndexInInputTensorNames(input_tensor_names, input_tensor_names[iter_i][iter_j]); - if (head_node_index < SIZE_MAX / 2 && head_node_index != iter_i) { - graph->nodes[iter_i].node_in.push_back(head_node_index); - graph->nodes[head_node_index].node_out.push_back(iter_i); - } - } - } -} - -size_t GetIndexInInputTensorNames(const std::vector> &input_tensor_name, - const std::string &input_name) { - for (size_t index = 0; index < input_tensor_name.size(); index++) { - if (input_tensor_name[index][0] == input_name) { - return index; - } - } - MS_LOG(INFO) << "Get index failed, using SIZE_MAX instead"; - return SIZE_MAX; -} - -void Eliminate_Aux(size_t node_index, const std::shared_ptr &graph, - const std::shared_ptr>> &eli_list) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(eli_list); - std::vector eli; - eli.push_back(node_index); - for (size_t i = 0; i < graph->nodes[node_index].node_out.size(); i++) { - auto outgoing_node_idx = graph->nodes[node_index].node_out[i]; - eli.push_back(outgoing_node_idx); - if (!graph->nodes[node_index].param_name.empty() && - graph->nodes[node_index].apply.op_type == OperatorType::kRecCast && - (graph->nodes[outgoing_node_idx].apply.op_type == OperatorType::kRecMatMul || - graph->nodes[outgoing_node_idx].apply.op_type == OperatorType::kRecBatchMatMul)) { - graph->nodes[outgoing_node_idx].param_name = graph->nodes[node_index].param_name; - } - } - eli_list->push_back(eli); - - // Iterate over all input operators of the current node - for (size_t i = 0; i < graph->nodes[node_index].node_in.size(); i++) { - auto *incoming_outputs = &graph->nodes[graph->nodes[node_index].node_in[i]].node_out; - auto it = find(incoming_outputs->begin(), incoming_outputs->end(), node_index); - if (it != incoming_outputs->end()) { - it = incoming_outputs->erase(it); - for (auto outgoing_index : graph->nodes[node_index].node_out) { - it = find(incoming_outputs->begin(), incoming_outputs->end(), outgoing_index); - if (it == incoming_outputs->end()) { - incoming_outputs->push_back(outgoing_index); - } - } - } - } - - // Iterate over all aux_input operators of the current node - for (size_t i = 0; i < graph->nodes[node_index].node_in_aux.size(); i++) { - auto *aux_incoming_outputs = &graph->nodes[graph->nodes[node_index].node_in_aux[i]].node_out; - auto it = find(aux_incoming_outputs->begin(), aux_incoming_outputs->end(), node_index); - if (it != aux_incoming_outputs->end()) { - it = aux_incoming_outputs->erase(it); - for (auto outgoing_index : graph->nodes[node_index].node_out) { - it = find(aux_incoming_outputs->begin(), aux_incoming_outputs->end(), outgoing_index); - if (it == aux_incoming_outputs->end()) { - aux_incoming_outputs->push_back(outgoing_index); - } - } - } - } - - // Iterate over all output operators of the current node - Eliminate_Aux_Outgoing(node_index, graph); -} - -void EliminateAuxOutgoingInput(size_t node_index, const std::shared_ptr &graph, size_t i) { - MS_EXCEPTION_IF_NULL(graph); - auto *outgoing_inputs = &graph->nodes[graph->nodes[node_index].node_out[i]].node_in; - MS_EXCEPTION_IF_NULL(outgoing_inputs); - // Check if the current node is the input operator of the current node's output operator - auto it = find(outgoing_inputs->begin(), outgoing_inputs->end(), node_index); - if (it != outgoing_inputs->end()) { - if (graph->nodes[node_index].node_in.size() > 0) { - // If the current node has input operator, then add input[0] of the current node to the input of the current - // node's output operator (if input[0] is also in the aux_input of the current node's output operator, then remove - // it from the aux_input and keep it only in the input) - auto exist_in_outgoing_auxinputs = - find(graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.begin(), - graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.end(), graph->nodes[node_index].node_in[0]); - if (exist_in_outgoing_auxinputs != graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.end()) { - size_t index_remove_node = LongToSize(std::distance( - graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.begin(), exist_in_outgoing_auxinputs)); - if (graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux_idx.size() > index_remove_node) { - (void)graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux_idx.erase( - graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux_idx.begin() + index_remove_node); - } else { - MS_LOG(DEBUG) << "Trying to erase vector element at index " << index_remove_node << ", out of range!"; - } - if (graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.size() > index_remove_node) { - (void)graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.erase(exist_in_outgoing_auxinputs); - } else { - MS_LOG(DEBUG) << "Trying to erase vector element at index " << index_remove_node - << ", which is out of range!"; - } - } - size_t idx = LongToSize(std::distance(outgoing_inputs->begin(), it)); - if (outgoing_inputs->size() > idx) { - outgoing_inputs->at(idx) = graph->nodes[node_index].node_in[0]; - } else { - MS_LOG(DEBUG) << "Trying to index vector element at index " << idx << ", out of range!"; - } - // Then add the other input operators of the current node to the aux_input of the current node's output operator - for (size_t j = 1; j < graph->nodes[node_index].node_in.size(); j++) { - exist_in_outgoing_auxinputs = find(graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.begin(), - graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.end(), - graph->nodes[node_index].node_in[j]); - if (exist_in_outgoing_auxinputs == graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.end()) { - size_t index_aux = LongToSize(std::distance(outgoing_inputs->begin(), it)); - graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux_idx.push_back(index_aux); - graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.push_back(graph->nodes[node_index].node_in[j]); - } - } - // Then add all the operators in the aux_input of the current node to the aux_input of the output operator of the - // current node - for (size_t j = 0; j < graph->nodes[node_index].node_in_aux.size(); j++) { - exist_in_outgoing_auxinputs = find(graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.begin(), - graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.end(), - graph->nodes[node_index].node_in_aux[j]); - if (exist_in_outgoing_auxinputs == graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.end()) { - size_t index_aux = LongToSize(std::distance(outgoing_inputs->begin(), it)); - graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux_idx.push_back(index_aux); - graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.push_back( - graph->nodes[node_index].node_in_aux[j]); - } - } - } else { - auto idx = LongToSize(std::distance(outgoing_inputs->begin(), it)); - if (outgoing_inputs->size() > idx) { - (void)outgoing_inputs->erase(it); - } else { - MS_LOG(DEBUG) << "Trying to erase vector element at index " << idx << ", out of range!"; - } - } - } -} - -void EliminateAuxOutgoingAuxInput(size_t node_index, const std::shared_ptr &graph, size_t i) { - MS_EXCEPTION_IF_NULL(graph); - auto *outgoing_auxinputs = &graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux; - MS_EXCEPTION_IF_NULL(outgoing_auxinputs); - auto *outgoing_auxinputs_index = &graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux_idx; - // Check if the current node is the aux_input operator of the current node's output operator - auto it = find(outgoing_auxinputs->begin(), outgoing_auxinputs->end(), node_index); - size_t index_entree = LongToSize(std::distance(outgoing_auxinputs->begin(), it)); - if (it != outgoing_auxinputs->end()) { - if (graph->nodes[node_index].node_in.size() > 0) { - // If the current node has input operator, and if the input[0] of the current node is in - // the input of the output operator of the current node, then delete it - // from the aux_input of the output of the current node, otherwise add the input[0] - // to the auxinput of the output of the current node - auto exist_in_outgoing_inputs = - find(graph->nodes[graph->nodes[node_index].node_out[i]].node_in.begin(), - graph->nodes[graph->nodes[node_index].node_out[i]].node_in.end(), graph->nodes[node_index].node_in[0]); - if (exist_in_outgoing_inputs != graph->nodes[graph->nodes[node_index].node_out[i]].node_in.end()) { - index_entree = LongToSize(std::distance(outgoing_auxinputs->begin(), it)); - if (outgoing_auxinputs_index->size() > index_entree) { - (void)outgoing_auxinputs_index->erase(outgoing_auxinputs_index->begin() + index_entree); - } else { - MS_LOG(DEBUG) << "Trying to erase vector element at index " << index_entree << ", out of range!"; - } - if (outgoing_auxinputs->size() > index_entree) { - (void)outgoing_auxinputs->erase(it); - } else { - MS_LOG(DEBUG) << "Trying to erase vector element at index " << index_entree << ", out of range!"; - } - } else { - size_t idx = LongToSize(std::distance(outgoing_auxinputs->begin(), it)); - if (outgoing_auxinputs->size() > idx) { - outgoing_auxinputs->at(idx) = graph->nodes[node_index].node_in[0]; - } else { - MS_LOG(DEBUG) << "Trying to index vector element at index " << idx << ", out of range!"; - } - index_entree = LongToSize(std::distance( - outgoing_auxinputs->begin(), - find(outgoing_auxinputs->begin(), outgoing_auxinputs->end(), graph->nodes[node_index].node_in[0]))); - } - // Determine whether the other input operator of the current node is in the input of the output operator, - // and if not, add it to the aux_input of the output operator - for (size_t j = 1; j < graph->nodes[node_index].node_in.size(); j++) { - exist_in_outgoing_inputs = - find(graph->nodes[graph->nodes[node_index].node_out[i]].node_in.begin(), - graph->nodes[graph->nodes[node_index].node_out[i]].node_in.end(), graph->nodes[node_index].node_in[j]); - if (exist_in_outgoing_inputs == graph->nodes[graph->nodes[node_index].node_out[i]].node_in.end()) { - outgoing_auxinputs->push_back(graph->nodes[node_index].node_in[j]); - if (outgoing_auxinputs_index->size() > index_entree) { - outgoing_auxinputs_index->push_back(outgoing_auxinputs_index->at(index_entree)); - } else { - MS_LOG(DEBUG) << "Trying to index vector element at index " << index_entree << ", out of range!"; - } - } - } - // Determine if the aux_input operator of the current node is in the input of the output operator, - // and if not, add it to the aux_input of the output operator - for (size_t j = 0; j < graph->nodes[node_index].node_in_aux.size(); j++) { - exist_in_outgoing_inputs = find(graph->nodes[graph->nodes[node_index].node_out[i]].node_in.begin(), - graph->nodes[graph->nodes[node_index].node_out[i]].node_in.end(), - graph->nodes[node_index].node_in_aux[j]); - if (exist_in_outgoing_inputs == graph->nodes[graph->nodes[node_index].node_out[i]].node_in.end()) { - outgoing_auxinputs->push_back(graph->nodes[node_index].node_in_aux[j]); - outgoing_auxinputs_index->push_back(outgoing_auxinputs_index->at(index_entree)); - } - } - } else { - if (outgoing_auxinputs_index->size() > index_entree) { - (void)outgoing_auxinputs_index->erase(outgoing_auxinputs_index->begin() + index_entree); - } else { - MS_LOG(DEBUG) << "Trying to erase vector element at index " << index_entree << ", out of range!"; - } - if (outgoing_auxinputs->size() > index_entree) { - (void)outgoing_auxinputs->erase(it); - } else { - MS_LOG(DEBUG) << "Trying to erase vector element at index " << index_entree << ", which is out of range."; - } - } - } -} - -void Eliminate_Aux_Outgoing(size_t node_index, const std::shared_ptr &graph) { - for (size_t i = 0; i < graph->nodes[node_index].node_out.size(); i++) { - // Handle the output operator connected to the current node via main edge - EliminateAuxOutgoingInput(node_index, graph, i); - // Handle the output operator connected to the current node via auxiliary edge - EliminateAuxOutgoingAuxInput(node_index, graph, i); - } -} - -static void EraseEliminatedNode(std::vector *nodes, const std::shared_ptr> &index_list) { - for (size_t j = nodes->size(); j > 0; j--) { - bool IsEliminated = (index_list->at(nodes->at(j - 1)) == SIZE_MAX); - if (IsEliminated) { - (void)nodes->erase(nodes->begin() + SizeToLong(j) - 1); - } else { - nodes->at(j - 1) = index_list->at(nodes->at(j - 1)); - } - } -} - -std::shared_ptr EliminateGraph(const std::shared_ptr &graph, - const std::shared_ptr>> &eli_list, - const std::shared_ptr> &index_list, - const bool dyn_shape_tmp_fix) { - MS_EXCEPTION_IF_NULL(graph); - for (size_t node_index = 0; node_index < graph->nodes.size(); node_index++) { - auto type = graph->nodes[node_index].apply.op_type; - if (dyn_shape_tmp_fix && type == OperatorType::kRecBatchMatMul) { - continue; - } else if (EliminateOpType.find(type) != EliminateOpType.end()) { - Eliminate_Aux(node_index, graph, eli_list); - } - } - index_list->reserve(graph->nodes.size()); - for (size_t i = 0; i < graph->nodes.size(); i++) { - index_list->push_back(i); - } - for (size_t i = 0; i < eli_list->size(); i++) { - if (eli_list->at(i)[0] >= index_list->size()) { - MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range."; - } - index_list->at(eli_list->at(i)[0]) = SIZE_MAX; - for (size_t j = eli_list->at(i)[0] + 1; j < index_list->size(); j++) { - index_list->at(j)--; - } - } - std::shared_ptr new_graph = std::make_shared(); - for (size_t i = 0; i < graph->nodes.size(); i++) { - if (index_list->at(i) > SIZE_MAX / 2) { - continue; - } - new_graph->nodes.push_back(graph->nodes[i]); - auto *node_in = &new_graph->nodes[index_list->at(i)].node_in; - EraseEliminatedNode(node_in, index_list); - auto *node_in_aux = &new_graph->nodes[index_list->at(i)].node_in_aux; - EraseEliminatedNode(node_in_aux, index_list); - auto *node_out = &new_graph->nodes[index_list->at(i)].node_out; - EraseEliminatedNode(node_out, index_list); - } - return new_graph; -} -} // namespace parallel -} // namespace mindspore +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h" + +#include +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/rec_core/rec_graph.h" +#include "frontend/parallel/auto_parallel/rec_core/rec_tensor.h" +#include "frontend/parallel/ops_info/operator_info.h" + +namespace mindspore { +namespace parallel { +const TensorParam MakeTensor(int64_t n, int64_t c, int64_t h, int64_t w) { + TensorParam new_tensor; + new_tensor.tensor_type = kFloat32; + new_tensor.tensor_shape.shape_n = n; + new_tensor.tensor_shape.shape_c = c; + new_tensor.tensor_shape.shape_h = h; + new_tensor.tensor_shape.shape_w = w; + const TensorParam &tensor = new_tensor; + return tensor; +} + +Graph::NodeType MakeNewOperator(const std::vector> &ops, size_t iter_ops) { + Graph::NodeType NewOp; + NewOp.name = ops[iter_ops]->name(); + NewOp.info = InfoType::kApplication; + + auto pos = ops[iter_ops]->name().find("Info"); + auto name = ops[iter_ops]->name().substr(0, pos); + auto op_type = ops[iter_ops]->type(); + auto idx = DictOpType.find(op_type); + if (idx != DictOpType.end()) { + NewOp.apply.op_type = DictOpType.at(op_type); + } else if (name == STAND_ALONE) { + MS_LOG(INFO) << ops[iter_ops]->type() << ": standalone operator."; + NewOp.apply.op_type = OperatorType::kRecStandAlone; + } else if (name == BATCH_PARALLEL) { + MS_LOG(INFO) << ops[iter_ops]->type() << ": batch parallel operator."; + NewOp.apply.op_type = OperatorType::kRecBatchParallel; + } else { + NewOp.apply.op_type = OperatorType::kRecUnknownType; + MS_LOG(INFO) << ops[iter_ops]->name() << ": Unknown operator type " << op_type; + } + + if (ops[iter_ops]->outputs_shape().size() == SIZE_ZERO) { + MS_LOG(EXCEPTION) << ops[iter_ops]->name() << " outputs shape is empty."; + } + + if (ops[iter_ops]->outputs_shape()[0].size() == SIZE_FOUR) { + NewOp.tensor_parm = MakeTensor(ops[iter_ops]->outputs_shape()[0][0], ops[iter_ops]->outputs_shape()[0][1], + ops[iter_ops]->outputs_shape()[INDEX_ZERO][INDEX_TWO], + ops[iter_ops]->outputs_shape()[INDEX_ZERO][INDEX_THREE]); + } else if (ops[iter_ops]->outputs_shape()[0].size() == SIZE_THREE) { + NewOp.tensor_parm = MakeTensor(1, ops[iter_ops]->outputs_shape()[0][0], ops[iter_ops]->outputs_shape()[0][1], + ops[iter_ops]->outputs_shape()[INDEX_ZERO][INDEX_TWO]); + } else if (ops[iter_ops]->outputs_shape()[0].size() == SIZE_TWO) { + NewOp.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->outputs_shape()[0][0], ops[iter_ops]->outputs_shape()[0][1]); + } else if (ops[iter_ops]->outputs_shape()[0].size() == SIZE_ONE) { + NewOp.tensor_parm = MakeTensor(1, 1, 1, ops[iter_ops]->outputs_shape()[0][0]); + } else if (ops[iter_ops]->outputs_shape()[0].size() == SIZE_ZERO) { + NewOp.tensor_parm = MakeTensor(1, 1, 1, 1); + } else { + MS_LOG(WARNING) << ops[iter_ops]->name() << ": output tensor shape is unexpected."; + } + + CompleteOperatorInputs(ops, iter_ops, &NewOp); + MS_LOG(INFO) << "Node " << NewOp.name << "created successfully" + << " its input is " << ops[iter_ops]->inputs_shape() << " and its output is " + << ops[iter_ops]->outputs_shape() << "."; + return NewOp; +} + +void CompleteOperatorInputs(const std::vector> &ops, const size_t iter_ops, + Graph::NodeType *NewTensor) { + size_t input_tensor_size = ops[iter_ops]->inputs_shape().size(); + if (ops[iter_ops]->type() == STACK) { + input_tensor_size = 1; + } + if (input_tensor_size > MAX_INPUT_NUM) { + MS_LOG(EXCEPTION) << ops[iter_ops]->name() << " input tensor " << input_tensor_size << " num exceeds limit(" + << MAX_INPUT_NUM << ")."; + } + + for (size_t iter_input_tensors = 0; iter_input_tensors < input_tensor_size; iter_input_tensors++) { + if (ops[iter_ops]->inputs_shape()[iter_input_tensors].size() == SIZE_FOUR) { + Complete4DInputs(ops, iter_ops, iter_input_tensors, NewTensor); + } else if (ops[iter_ops]->inputs_shape()[iter_input_tensors].size() == SIZE_THREE) { + NewTensor->apply.arguments[iter_input_tensors] = + MakeTensor(1, ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ZERO], + ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ONE], + ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_TWO]); + } else if (ops[iter_ops]->inputs_shape()[iter_input_tensors].size() == SIZE_TWO) { + Complete2DInputs(ops, iter_ops, iter_input_tensors, NewTensor); + } else if (ops[iter_ops]->inputs_shape()[iter_input_tensors].size() == SIZE_ONE) { + NewTensor->apply.arguments[iter_input_tensors] = + MakeTensor(1, 1, 1, ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ZERO]); + } else if (ops[iter_ops]->inputs_shape()[iter_input_tensors].size() == 0) { + NewTensor->apply.arguments[iter_input_tensors] = MakeTensor(1, 1, 1, 1); + } else { + MS_LOG(WARNING) << ops[iter_ops]->name() << ": input tensor shape is unexpected."; + } + } +} + +void Complete2DInputs(const std::vector> &ops, const size_t iter_ops, + const size_t iter_input_tensors, Graph::NodeType *NewTensor) { + if (NewTensor->apply.op_type == OperatorType::kRecMatMul) { + auto input_value = ops[iter_ops]->input_value(); + bool transpose_a = input_value[2]->cast()->value(); + bool transpose_b = input_value[3]->cast()->value(); + if (transpose_a && (iter_input_tensors == 0)) { + NewTensor->apply.arguments[iter_input_tensors] = + MakeTensor(1, 1, ops[iter_ops]->inputs_shape()[iter_input_tensors][1], + ops[iter_ops]->inputs_shape()[iter_input_tensors][0]); + } else if (transpose_b && (iter_input_tensors == 1)) { + NewTensor->apply.arguments[iter_input_tensors] = + MakeTensor(1, 1, ops[iter_ops]->inputs_shape()[iter_input_tensors][1], + ops[iter_ops]->inputs_shape()[iter_input_tensors][0]); + } else { + NewTensor->apply.arguments[iter_input_tensors] = + MakeTensor(1, 1, ops[iter_ops]->inputs_shape()[iter_input_tensors][0], + ops[iter_ops]->inputs_shape()[iter_input_tensors][1]); + } + } else { + NewTensor->apply.arguments[iter_input_tensors] = MakeTensor( + 1, 1, ops[iter_ops]->inputs_shape()[iter_input_tensors][0], ops[iter_ops]->inputs_shape()[iter_input_tensors][1]); + } +} + +void Complete4DInputs(const std::vector> &ops, const size_t iter_ops, + const size_t iter_input_tensors, Graph::NodeType *NewTensor) { + if (NewTensor->apply.op_type == OperatorType::kRecBatchMatMul) { + auto input_value = ops[iter_ops]->input_value(); + bool transpose_a = input_value[2]->cast()->value(); + bool transpose_b = input_value[3]->cast()->value(); + if (transpose_a && (iter_input_tensors == 0)) { + NewTensor->apply.arguments[iter_input_tensors] = + MakeTensor(ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ZERO], + ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ONE], + ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_THREE], + ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_TWO]); + } else if (transpose_b && (iter_input_tensors == 1)) { + NewTensor->apply.arguments[iter_input_tensors] = + MakeTensor(ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ZERO], + ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ONE], + ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_THREE], + ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_TWO]); + } else { + NewTensor->apply.arguments[iter_input_tensors] = + MakeTensor(ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ZERO], + ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ONE], + ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_TWO], + ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_THREE]); + } + } else { + NewTensor->apply.arguments[iter_input_tensors] = + MakeTensor(ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ZERO], + ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_ONE], + ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_TWO], + ops[iter_ops]->inputs_shape()[iter_input_tensors][INDEX_THREE]); + } +} + +std::shared_ptr ParseGraph(const std::vector> &ops, + const std::vector> &input_tensor_names) { + std::shared_ptr graph = std::make_shared(); + constexpr size_t MAX_OP_NUM = SIZE_MAX / 2; + if (ops.size() > MAX_OP_NUM) { + MS_LOG(EXCEPTION) << "Total number of operators is bigger than " << MAX_OP_NUM; + } + + for (size_t iter_ops = 0; iter_ops < ops.size(); iter_ops++) { + Graph::NodeType NewOp = MakeNewOperator(ops, iter_ops); + NewOp.param_name = ops[iter_ops]->get_involved_param_name(); + graph->nodes.push_back(NewOp); + } + MakeEdge(input_tensor_names, graph); + + return graph; +} + +void MakeEdge(const std::vector> &input_tensor_names, const std::shared_ptr &graph) { + for (size_t iter_i = 0; iter_i < input_tensor_names.size(); iter_i++) { + for (size_t iter_j = 1; iter_j < input_tensor_names[iter_i].size(); iter_j++) { + size_t head_node_index = GetIndexInInputTensorNames(input_tensor_names, input_tensor_names[iter_i][iter_j]); + if (head_node_index < SIZE_MAX / 2 && head_node_index != iter_i) { + graph->nodes[iter_i].node_in.push_back(head_node_index); + graph->nodes[head_node_index].node_out.push_back(iter_i); + } + } + } +} + +size_t GetIndexInInputTensorNames(const std::vector> &input_tensor_name, + const std::string &input_name) { + for (size_t index = 0; index < input_tensor_name.size(); index++) { + if (input_tensor_name[index][0] == input_name) { + return index; + } + } + MS_LOG(INFO) << "Get index failed, using SIZE_MAX instead"; + return SIZE_MAX; +} + +void Eliminate_Aux(size_t node_index, const std::shared_ptr &graph, + const std::shared_ptr>> &eli_list) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(eli_list); + std::vector eli; + eli.push_back(node_index); + for (size_t i = 0; i < graph->nodes[node_index].node_out.size(); i++) { + auto outgoing_node_idx = graph->nodes[node_index].node_out[i]; + eli.push_back(outgoing_node_idx); + if (!graph->nodes[node_index].param_name.empty() && + graph->nodes[node_index].apply.op_type == OperatorType::kRecCast && + (graph->nodes[outgoing_node_idx].apply.op_type == OperatorType::kRecMatMul || + graph->nodes[outgoing_node_idx].apply.op_type == OperatorType::kRecBatchMatMul)) { + graph->nodes[outgoing_node_idx].param_name = graph->nodes[node_index].param_name; + } + } + eli_list->push_back(eli); + + // Iterate over all input operators of the current node + for (size_t i = 0; i < graph->nodes[node_index].node_in.size(); i++) { + auto *incoming_outputs = &graph->nodes[graph->nodes[node_index].node_in[i]].node_out; + auto it = find(incoming_outputs->begin(), incoming_outputs->end(), node_index); + if (it != incoming_outputs->end()) { + it = incoming_outputs->erase(it); + for (auto outgoing_index : graph->nodes[node_index].node_out) { + it = find(incoming_outputs->begin(), incoming_outputs->end(), outgoing_index); + if (it == incoming_outputs->end()) { + incoming_outputs->push_back(outgoing_index); + } + } + } + } + + // Iterate over all aux_input operators of the current node + for (size_t i = 0; i < graph->nodes[node_index].node_in_aux.size(); i++) { + auto *aux_incoming_outputs = &graph->nodes[graph->nodes[node_index].node_in_aux[i]].node_out; + auto it = find(aux_incoming_outputs->begin(), aux_incoming_outputs->end(), node_index); + if (it != aux_incoming_outputs->end()) { + it = aux_incoming_outputs->erase(it); + for (auto outgoing_index : graph->nodes[node_index].node_out) { + it = find(aux_incoming_outputs->begin(), aux_incoming_outputs->end(), outgoing_index); + if (it == aux_incoming_outputs->end()) { + aux_incoming_outputs->push_back(outgoing_index); + } + } + } + } + + // Iterate over all output operators of the current node + Eliminate_Aux_Outgoing(node_index, graph); +} + +void EliminateAuxOutgoingInput(size_t node_index, const std::shared_ptr &graph, size_t i) { + MS_EXCEPTION_IF_NULL(graph); + auto *outgoing_inputs = &graph->nodes[graph->nodes[node_index].node_out[i]].node_in; + MS_EXCEPTION_IF_NULL(outgoing_inputs); + // Check if the current node is the input operator of the current node's output operator + auto it = find(outgoing_inputs->begin(), outgoing_inputs->end(), node_index); + if (it != outgoing_inputs->end()) { + if (graph->nodes[node_index].node_in.size() > 0) { + // If the current node has input operator, then add input[0] of the current node to the input of the current + // node's output operator (if input[0] is also in the aux_input of the current node's output operator, then remove + // it from the aux_input and keep it only in the input) + auto exist_in_outgoing_auxinputs = + find(graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.begin(), + graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.end(), graph->nodes[node_index].node_in[0]); + if (exist_in_outgoing_auxinputs != graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.end()) { + size_t index_remove_node = LongToSize(std::distance( + graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.begin(), exist_in_outgoing_auxinputs)); + if (graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux_idx.size() > index_remove_node) { + (void)graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux_idx.erase( + graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux_idx.begin() + index_remove_node); + } else { + MS_LOG(DEBUG) << "Trying to erase vector element at index " << index_remove_node << ", out of range!"; + } + if (graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.size() > index_remove_node) { + (void)graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.erase(exist_in_outgoing_auxinputs); + } else { + MS_LOG(DEBUG) << "Trying to erase vector element at index " << index_remove_node + << ", which is out of range!"; + } + } + size_t idx = LongToSize(std::distance(outgoing_inputs->begin(), it)); + if (outgoing_inputs->size() > idx) { + outgoing_inputs->at(idx) = graph->nodes[node_index].node_in[0]; + } else { + MS_LOG(DEBUG) << "Trying to index vector element at index " << idx << ", out of range!"; + } + // Then add the other input operators of the current node to the aux_input of the current node's output operator + for (size_t j = 1; j < graph->nodes[node_index].node_in.size(); j++) { + exist_in_outgoing_auxinputs = find(graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.begin(), + graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.end(), + graph->nodes[node_index].node_in[j]); + if (exist_in_outgoing_auxinputs == graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.end()) { + size_t index_aux = LongToSize(std::distance(outgoing_inputs->begin(), it)); + graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux_idx.push_back(index_aux); + graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.push_back(graph->nodes[node_index].node_in[j]); + } + } + // Then add all the operators in the aux_input of the current node to the aux_input of the output operator of the + // current node + for (size_t j = 0; j < graph->nodes[node_index].node_in_aux.size(); j++) { + exist_in_outgoing_auxinputs = find(graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.begin(), + graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.end(), + graph->nodes[node_index].node_in_aux[j]); + if (exist_in_outgoing_auxinputs == graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.end()) { + size_t index_aux = LongToSize(std::distance(outgoing_inputs->begin(), it)); + graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux_idx.push_back(index_aux); + graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.push_back( + graph->nodes[node_index].node_in_aux[j]); + } + } + } else { + auto idx = LongToSize(std::distance(outgoing_inputs->begin(), it)); + if (outgoing_inputs->size() > idx) { + (void)outgoing_inputs->erase(it); + } else { + MS_LOG(DEBUG) << "Trying to erase vector element at index " << idx << ", out of range!"; + } + } + } +} + +void EliminateAuxOutgoingAuxInput(size_t node_index, const std::shared_ptr &graph, size_t i) { + MS_EXCEPTION_IF_NULL(graph); + auto *outgoing_auxinputs = &graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux; + MS_EXCEPTION_IF_NULL(outgoing_auxinputs); + auto *outgoing_auxinputs_index = &graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux_idx; + // Check if the current node is the aux_input operator of the current node's output operator + auto it = find(outgoing_auxinputs->begin(), outgoing_auxinputs->end(), node_index); + size_t index_entree = LongToSize(std::distance(outgoing_auxinputs->begin(), it)); + if (it != outgoing_auxinputs->end()) { + if (graph->nodes[node_index].node_in.size() > 0) { + // If the current node has input operator, and if the input[0] of the current node is in + // the input of the output operator of the current node, then delete it + // from the aux_input of the output of the current node, otherwise add the input[0] + // to the auxinput of the output of the current node + auto exist_in_outgoing_inputs = + find(graph->nodes[graph->nodes[node_index].node_out[i]].node_in.begin(), + graph->nodes[graph->nodes[node_index].node_out[i]].node_in.end(), graph->nodes[node_index].node_in[0]); + if (exist_in_outgoing_inputs != graph->nodes[graph->nodes[node_index].node_out[i]].node_in.end()) { + index_entree = LongToSize(std::distance(outgoing_auxinputs->begin(), it)); + if (outgoing_auxinputs_index->size() > index_entree) { + (void)outgoing_auxinputs_index->erase(outgoing_auxinputs_index->begin() + index_entree); + } else { + MS_LOG(DEBUG) << "Trying to erase vector element at index " << index_entree << ", out of range!"; + } + if (outgoing_auxinputs->size() > index_entree) { + (void)outgoing_auxinputs->erase(it); + } else { + MS_LOG(DEBUG) << "Trying to erase vector element at index " << index_entree << ", out of range!"; + } + } else { + size_t idx = LongToSize(std::distance(outgoing_auxinputs->begin(), it)); + if (outgoing_auxinputs->size() > idx) { + outgoing_auxinputs->at(idx) = graph->nodes[node_index].node_in[0]; + } else { + MS_LOG(DEBUG) << "Trying to index vector element at index " << idx << ", out of range!"; + } + index_entree = LongToSize(std::distance( + outgoing_auxinputs->begin(), + find(outgoing_auxinputs->begin(), outgoing_auxinputs->end(), graph->nodes[node_index].node_in[0]))); + } + // Determine whether the other input operator of the current node is in the input of the output operator, + // and if not, add it to the aux_input of the output operator + for (size_t j = 1; j < graph->nodes[node_index].node_in.size(); j++) { + exist_in_outgoing_inputs = + find(graph->nodes[graph->nodes[node_index].node_out[i]].node_in.begin(), + graph->nodes[graph->nodes[node_index].node_out[i]].node_in.end(), graph->nodes[node_index].node_in[j]); + if (exist_in_outgoing_inputs == graph->nodes[graph->nodes[node_index].node_out[i]].node_in.end()) { + outgoing_auxinputs->push_back(graph->nodes[node_index].node_in[j]); + if (outgoing_auxinputs_index->size() > index_entree) { + outgoing_auxinputs_index->push_back(outgoing_auxinputs_index->at(index_entree)); + } else { + MS_LOG(DEBUG) << "Trying to index vector element at index " << index_entree << ", out of range!"; + } + } + } + // Determine if the aux_input operator of the current node is in the input of the output operator, + // and if not, add it to the aux_input of the output operator + for (size_t j = 0; j < graph->nodes[node_index].node_in_aux.size(); j++) { + exist_in_outgoing_inputs = find(graph->nodes[graph->nodes[node_index].node_out[i]].node_in.begin(), + graph->nodes[graph->nodes[node_index].node_out[i]].node_in.end(), + graph->nodes[node_index].node_in_aux[j]); + if (exist_in_outgoing_inputs == graph->nodes[graph->nodes[node_index].node_out[i]].node_in.end()) { + outgoing_auxinputs->push_back(graph->nodes[node_index].node_in_aux[j]); + outgoing_auxinputs_index->push_back(outgoing_auxinputs_index->at(index_entree)); + } + } + } else { + if (outgoing_auxinputs_index->size() > index_entree) { + (void)outgoing_auxinputs_index->erase(outgoing_auxinputs_index->begin() + index_entree); + } else { + MS_LOG(DEBUG) << "Trying to erase vector element at index " << index_entree << ", out of range!"; + } + if (outgoing_auxinputs->size() > index_entree) { + (void)outgoing_auxinputs->erase(it); + } else { + MS_LOG(DEBUG) << "Trying to erase vector element at index " << index_entree << ", which is out of range."; + } + } + } +} + +void Eliminate_Aux_Outgoing(size_t node_index, const std::shared_ptr &graph) { + for (size_t i = 0; i < graph->nodes[node_index].node_out.size(); i++) { + // Handle the output operator connected to the current node via main edge + EliminateAuxOutgoingInput(node_index, graph, i); + // Handle the output operator connected to the current node via auxiliary edge + EliminateAuxOutgoingAuxInput(node_index, graph, i); + } +} + +static void EraseEliminatedNode(std::vector *nodes, const std::shared_ptr> &index_list) { + for (size_t j = nodes->size(); j > 0; j--) { + bool IsEliminated = (index_list->at(nodes->at(j - 1)) == SIZE_MAX); + if (IsEliminated) { + (void)nodes->erase(nodes->begin() + SizeToLong(j) - 1); + } else { + nodes->at(j - 1) = index_list->at(nodes->at(j - 1)); + } + } +} + +std::shared_ptr EliminateGraph(const std::shared_ptr &graph, + const std::shared_ptr>> &eli_list, + const std::shared_ptr> &index_list, + const bool dyn_shape_tmp_fix) { + MS_EXCEPTION_IF_NULL(graph); + for (size_t node_index = 0; node_index < graph->nodes.size(); node_index++) { + auto type = graph->nodes[node_index].apply.op_type; + if (dyn_shape_tmp_fix && type == OperatorType::kRecBatchMatMul) { + continue; + } else if (EliminateOpType.find(type) != EliminateOpType.end()) { + Eliminate_Aux(node_index, graph, eli_list); + } + } + index_list->reserve(graph->nodes.size()); + for (size_t i = 0; i < graph->nodes.size(); i++) { + index_list->push_back(i); + } + for (size_t i = 0; i < eli_list->size(); i++) { + if (eli_list->at(i)[0] >= index_list->size()) { + MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range."; + } + index_list->at(eli_list->at(i)[0]) = SIZE_MAX; + for (size_t j = eli_list->at(i)[0] + 1; j < index_list->size(); j++) { + index_list->at(j)--; + } + } + std::shared_ptr new_graph = std::make_shared(); + for (size_t i = 0; i < graph->nodes.size(); i++) { + if (index_list->at(i) > SIZE_MAX / 2) { + continue; + } + new_graph->nodes.push_back(graph->nodes[i]); + auto *node_in = &new_graph->nodes[index_list->at(i)].node_in; + EraseEliminatedNode(node_in, index_list); + auto *node_in_aux = &new_graph->nodes[index_list->at(i)].node_in_aux; + EraseEliminatedNode(node_in_aux, index_list); + auto *node_out = &new_graph->nodes[index_list->at(i)].node_out; + EraseEliminatedNode(node_out, index_list); + } + return new_graph; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/came_parallel_handler.cc b/mindspore/ccsrc/frontend/parallel/came_parallel_handler.cc index e7983913bab..c3fc6347a22 100644 --- a/mindspore/ccsrc/frontend/parallel/came_parallel_handler.cc +++ b/mindspore/ccsrc/frontend/parallel/came_parallel_handler.cc @@ -1,527 +1,527 @@ -/** - * Copyright 2024 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "frontend/parallel/came_parallel_handler.h" - -#include -#include - -#include "frontend/parallel/parameter_manager.h" -#include "mindspore/core/ops/sequence_ops.h" -#include "mindspore/core/ops/other_ops.h" -#include "mindspore/core/ops/array_ops.h" -#include "mindspore/core/ops/framework_ops.h" -#include "mindspore/core/utils/convert_utils_base.h" -#include "utils/hash_map.h" -#include "frontend/operator/ops.h" -#include "frontend/optimizer/optimizer.h" -#include "include/common/utils/parallel_context.h" -#include "frontend/parallel/device_manager.h" -#include "frontend/parallel/graph_util/generate_graph.h" -#include "frontend/parallel/graph_util/graph_info.h" -#include "frontend/parallel/graph_util/node_info.h" -#include "frontend/parallel/graph_util/get_parallel_info.h" -#include "frontend/parallel/graph_util/pipeline_split_utils.h" -#include "frontend/parallel/node_check.h" -#include "ir/param_info.h" -#include "ir/tensor.h" -#include "utils/trace_base.h" -#include "include/common/utils/comm_manager.h" -#include "utils/ms_context.h" -#include "utils/symbolic.h" -#include "pipeline/jit/ps/pipeline.h" -#include "mindspore/core/utils/parallel_node_check.h" -#include "frontend/parallel/step_parallel_utils.h" -#include "mindspore/core/ops/nn_ops.h" - -namespace mindspore { -namespace parallel { -const std::string GetCNodeOpName(const CNodePtr &cnode) { - // get the prim name of cnode - ValueNodePtr prim_anf_node = cnode->input(0)->cast(); - MS_EXCEPTION_IF_NULL(prim_anf_node); - PrimitivePtr node_prim = prim_anf_node->value()->cast(); - MS_EXCEPTION_IF_NULL(node_prim); - return node_prim->name(); -} - -std::pair BackwardSearchCNode(const CNodePtr &bottom_node, - const std::vector> &bwd_calls, - const std::string &target_name) { - CNodePtr target_node = bottom_node; - for (const auto &call_param : bwd_calls) { - const auto node_name = call_param.first; - const auto idx = call_param.second; - auto cnode_name = GetCNodeOpName(target_node); - if (cnode_name != node_name) { - MS_LOG(DEBUG) << "[CAME] backward search failed, expect node name: " << node_name << " but got " << cnode_name; - return {false, bottom_node}; - } - const auto ¶m_node = target_node->input(idx + 1); - if (!param_node) { - MS_LOG(DEBUG) << "[CAME] backward search failed, expect param at index: " << (idx + 1) << " but got null"; - return {false, bottom_node}; - } - if (!param_node->isa()) { - MS_LOG(DEBUG) << "[CAME] param node is not a cnode!"; - return {false, bottom_node}; - } - auto param_cnode = param_node->cast(); - MS_EXCEPTION_IF_NULL(param_cnode); - target_node = param_cnode; - } - auto cnode_name = GetCNodeOpName(target_node); - if (cnode_name != target_name) { - MS_LOG(DEBUG) << "[CAME] backward search failed, expect target node name: " << target_name << " but got " - << cnode_name; - return {false, bottom_node}; - } - return {true, target_node}; -} - -std::pair> ForwardSearchCNode(const CNodePtr &start_node, - const std::vector &fwd_calls, - const NodeUsersMap &node_user_map) { - if (!start_node) { - MS_LOG(DEBUG) << "[CAME] forward search start is null!"; - return {false, {}}; - } - if (fwd_calls.empty()) { - MS_LOG(DEBUG) << "[CAME] gives empty forward calls!"; - return {false, {}}; - } - std::vector candidates; - std::deque visited; - CNodePtr cur_node = nullptr; - uint32_t depth = 0; - - visited.push_back(start_node); - CNodePtr last_node = visited.back(); - while (!visited.empty()) { - if (depth == fwd_calls.size() - 1) { - std::copy(visited.begin(), visited.end(), std::back_inserter(candidates)); - break; - } - cur_node = visited.front(); - MS_LOG(INFO) << "[CAME] fwd current node: " << cur_node->DebugString(); - visited.pop_front(); - auto node_set = node_user_map.at(cur_node->cast()); - for (auto item : node_set) { - auto user_node = item.first; - if (!user_node->isa()) { - continue; - } - auto user_cnode = user_node->cast(); - if (GetCNodeOpName(user_cnode) == fwd_calls[depth + 1]) { - visited.push_back(user_cnode); - } - } - if (last_node == cur_node) { - last_node = visited.back(); - depth++; - } - } - - if (candidates.empty()) { - return {false, {}}; - } else { - return {true, candidates}; - } -} - -CameCommHandler::CameCommHandler(ParameterPtr origin, const std::vector &all_parameters, - const NodeUsersMap &node_user_map) - : origin(origin), all_parameters(all_parameters), node_user_map(node_user_map) { - CheckGlobalDeviceManager(); - cur_rank = g_device_manager->global_rank(); - full_rank_list = g_device_manager->GetDeviceListInThisStage(); - - tensor_layout = origin->user_data(); - MS_EXCEPTION_IF_NULL(tensor_layout); - - auto opt_shard_group_name = tensor_layout->opt_shard_group(); - if (!opt_shard_group_name.empty()) { - is_opt_shard = true; - } - MS_LOG(DEBUG) << "CAME processing parameter"; - MS_LOG(DEBUG) << "tensor shape:" << tensor_layout->tensor_shape().ToString(); - MS_LOG(DEBUG) << "slice shape:" << tensor_layout->slice_shape().ToString(); - - MS_LOG(DEBUG) << "opt shard slice shape:"; - for (const auto &item : tensor_layout->opt_shard_slice_shape()) { - MS_LOG(DEBUG) << item; - } - MS_LOG(DEBUG) << "opt shard group:" << tensor_layout->opt_shard_group(); - MS_LOG(DEBUG) << "opt shard step:" << tensor_layout->opt_weight_shard_step(); - - MS_LOG(DEBUG) << "device arrangement:" << tensor_layout->device_arrangement().ToString(); - MS_LOG(DEBUG) << "original device arrangement:" << tensor_layout->device_arrangement_origin().ToString(); - - MS_LOG(DEBUG) << "tensor map:" << tensor_layout->tensor_map().ToString(); - MS_LOG(DEBUG) << "original tensor map:" << tensor_layout->origin_tensor_map().ToString(); - - FindCameParams(); -} - -void CameCommHandler::FindCameParams() { - const std::string origin_name = origin->name(); - const std::string exp_row_name = EXP_AVG_SQ_ROW + origin_name; - const std::string exp_col_name = EXP_AVG_SQ_COL + origin_name; - const std::string exp_insta_row_name = EXP_AVG_INSTA_ROW + origin_name; - const std::string exp_insta_col_name = EXP_AVG_INSTA_COL + origin_name; - const std::string exp_avg_name = std::string(EXP_AVG) + "." + origin_name; - const size_t param_to_find_size = 5; - size_t cur_found_param_count = 0; - for (const auto ¶m_node : all_parameters) { - auto param = param_node->cast(); - MS_EXCEPTION_IF_NULL(param); - const std::string param_name = param->name(); - if (param_name == exp_row_name) { - MS_LOG(DEBUG) << "[CAME] found exp_avg_sq_row: " << param_name; - exp_avg_sq_row = param; - cur_found_param_count++; - } else if (param_name == exp_col_name) { - MS_LOG(DEBUG) << "[CAME] found exp_avg_sq_col: " << param_name; - exp_avg_sq_col = param; - cur_found_param_count++; - } else if (param_name == exp_insta_row_name) { - MS_LOG(DEBUG) << "[CAME] found exp_avg_insta_row: " << param_name; - exp_avg_insta_row = param; - cur_found_param_count++; - } else if (param_name == exp_insta_col_name) { - MS_LOG(DEBUG) << "[CAME] found exp_avg_insta_col: " << param_name; - exp_avg_insta_col = param; - cur_found_param_count++; - } else if (param_name == exp_avg_name) { - MS_LOG(DEBUG) << "[CAME] found exp_avg: " << param_name; - exp_avg = param; - cur_found_param_count++; - } - - if (cur_found_param_count == param_to_find_size) { - break; - } - } - MS_LOG(INFO) << "[CAME] found params corresponding to origin param size: " << cur_found_param_count; -} - -std::pair CameCommHandler::GetOptShardRankList(const int64_t rank) { - DeviceMatrix temp_dev_matrix(rank, full_rank_list, tensor_layout->device_arrangement().array()); - RankList group_devices; - Shape orig_tensor_map = tensor_layout->tensor_map().array(); - if (temp_dev_matrix.GetDevicesByTensorMap(orig_tensor_map, &group_devices) != SUCCESS) { - return {FAILED, {}}; - } - if (group_devices.size() < 2) { - MS_LOG(ERROR) << "get opt shard rank list with less than two group devices!"; - return {FAILED, {}}; - } - - int64_t optimizer_weight_shard_size = ParallelContext::GetInstance()->optimizer_weight_shard_size(); - MS_EXCEPTION_IF_ZERO("optimizer_weight_shard_size", optimizer_weight_shard_size); - if ((optimizer_weight_shard_size == -1) || (optimizer_weight_shard_size > SizeToLong(group_devices.size()))) { - MS_LOG(INFO) << "[CAME] detect optimizer_weight_shard_size = -1 or exceed max shard size, use group devices size: " - << group_devices.size(); - optimizer_weight_shard_size = SizeToLong(group_devices.size()); - } - - int64_t index = std::find(group_devices.begin(), group_devices.end(), rank) - group_devices.begin(); - - // eg: optimizer_weight_shard_size = 2, [0, 8, 16, 24] -> [0, 8], [16, 24] - auto rank_list = - RankList(group_devices.begin() + index / optimizer_weight_shard_size * optimizer_weight_shard_size, - group_devices.begin() + (index / optimizer_weight_shard_size + 1) * optimizer_weight_shard_size); - return std::make_pair(SUCCESS, rank_list); -} - -std::pair CameCommHandler::GetDimRankList(const int64_t rank, const int64_t dim) { - DeviceMatrix dev_matrix(rank, full_rank_list, tensor_layout->device_arrangement().array()); - int64_t device_reverse_dim = tensor_layout->tensor_map().GetDimByIdx(dim); - if (device_reverse_dim == -1) { - return {SUCCESS, {rank}}; - } - int64_t device_dim = SizeToLong(tensor_layout->device_arrangement().array().size()) - 1 - device_reverse_dim; - RankList rank_list; - if (dev_matrix.GetDevicesAlongDim(LongToUlong(device_dim), &rank_list) != SUCCESS) { - MS_LOG(ERROR) << "Get devices along dim failed"; - return {FAILED, rank_list}; - } - return {SUCCESS, rank_list}; -} - -RankList CameCommHandler::ExpandRankListWithOptShard(const RankList &rank_list) { - if (!is_opt_shard) { - return rank_list; - } - MS_LOG(INFO) << "opt shard yes, group name:" << tensor_layout->opt_shard_group(); - - RankList opt_rank_list_find = g_device_manager->FindRankListByHashName(tensor_layout->opt_shard_group()); - for (const auto &opt_find_rank : opt_rank_list_find) { - MS_LOG(INFO) << "group device member:" << opt_find_rank; - } - - RankList expanded_list; - for (const auto &rank : rank_list) { - Status ret_state; - RankList opt_shard_rank_list; - std::tie(ret_state, opt_shard_rank_list) = GetOptShardRankList(rank); - if (ret_state != SUCCESS) { - MS_LOG(EXCEPTION) << "find opt shard rank list in adafactor failed"; - } - MS_LOG(INFO) << "found opt shard rank list for rank " << rank; - - for (const auto &opt_rank : opt_shard_rank_list) { - MS_LOG(INFO) << opt_rank; - } - expanded_list.insert(expanded_list.end(), opt_shard_rank_list.begin(), opt_shard_rank_list.end()); - } - std::sort(expanded_list.begin(), expanded_list.end()); - MS_LOG(INFO) << "expand rank list with opt shard, before:"; - for (const auto &item : rank_list) { - MS_LOG(INFO) << item; - } - MS_LOG(INFO) << "after:"; - for (const auto &item : expanded_list) { - MS_LOG(INFO) << item; - } - return expanded_list; -} - -RankList CameCommHandler::ExpandRankListWithDim(const RankList &rank_list, const int64_t dim) { - RankList expanded_list; - for (const auto &rank : rank_list) { - Status ret_status; - RankList dim_rank_list; - std::tie(ret_status, dim_rank_list) = GetDimRankList(rank, dim); - if (ret_status != SUCCESS) { - MS_LOG(EXCEPTION) << "find dim rank list in adafactor failed"; - } - expanded_list.insert(expanded_list.end(), dim_rank_list.begin(), dim_rank_list.end()); - } - std::sort(expanded_list.begin(), expanded_list.end()); - return expanded_list; -} - -CNodePtr CameCommHandler::FindReduceMean(size_t number) { - if (reduce_mean_numbers.find(number) == reduce_mean_numbers.end()) { - MS_LOG(INFO) << "[CAME] invalid reduce mean number: " << number; - } - - if (number == kFirstCameReduceMean) { - return FindReduceMean1256(exp_avg_sq_row); - } else if (number == kSecondCameReduceMean) { - return FindReduceMean1256(exp_avg_sq_col); - } else if (number == kThirdCameReduceMean) { - return FindReduceMean37(exp_avg_sq_row); - } else if (number == kForthCameReduceMean) { - return FindReduceMean4(); - } else if (number == kFifthCameReduceMean) { - return FindReduceMean1256(exp_avg_insta_row); - } else if (number == kSixthCameReduceMean) { - return FindReduceMean1256(exp_avg_insta_col); - } else if (number == kSeventhCameReduceMean) { - return FindReduceMean37(exp_avg_insta_row); - } else { - return nullptr; - } -} - -CNodePtr CameCommHandler::FindReduceMean1256(const ParameterPtr ¶m) { - if (!param) { - return nullptr; - } - MS_LOG(INFO) << "[CAME] try find reduce_mean according to " << param->name() << " Assign:"; - auto param_user_set = node_user_map.at(param->cast()); - for (auto ¶m_pair : param_user_set) { - auto user_cnode = param_pair.first->cast(); - MS_EXCEPTION_IF_NULL(user_cnode); - if (IsSomePrimitive(user_cnode, ASSIGN)) { - MS_LOG(INFO) << "[CAME] found assign node"; - // assign 1 -> add 1 -> mul 0 -> reduce_mean - auto res = BackwardSearchCNode(user_cnode, {{ASSIGN, 1}, {ADD, 1}, {MUL, 0}}, REDUCE_MEAN); - if (res.first) { - MS_LOG(INFO) << "[CAME] found reduce mean node: " << res.second->DebugString(); - return res.second; - } - } - } - return nullptr; -} - -CNodePtr CameCommHandler::FindReduceMean37(const ParameterPtr ¶m) { - if (!param) { - return nullptr; - } - auto param_user_set = node_user_map.at(param->cast()); - MS_LOG(INFO) << "[CAME] user map size: " << param_user_set.size(); - size_t load_count = 0; - for (auto ¶m_pair : param_user_set) { - auto user_cnode = param_pair.first->cast(); - MS_EXCEPTION_IF_NULL(user_cnode); - if (IsSomePrimitive(user_cnode, LOAD)) { - MS_LOG(INFO) << "[CAME] found load node"; - load_count++; - // load -> reduce mean - auto res = ForwardSearchCNode(user_cnode, {LOAD, REDUCE_MEAN}, node_user_map); - if (res.first) { - MS_LOG(INFO) << "[CAME] found reduce mean node size: " << res.second.size(); - return res.second[0]; // get the first one - } - } - } - MS_LOG(INFO) << "[CAME] found load count: " << load_count; - return nullptr; -} - -CNodePtr CameCommHandler::FindReduceMean4() { - MS_LOG(INFO) << "[CAME] try find reduce_mean no.4 according to exp_avg Assign:"; - if (!exp_avg) { - return nullptr; - } - auto exp_avg_user_set = node_user_map.at(exp_avg->cast()); - for (auto ¶m_pair : exp_avg_user_set) { - auto user_cnode = param_pair.first->cast(); - MS_EXCEPTION_IF_NULL(user_cnode); - if (IsSomePrimitive(user_cnode, ASSIGN)) { - MS_LOG(INFO) << "[CAME] found exp_avg's assign node"; - auto res = BackwardSearchCNode( - user_cnode, {{ASSIGN, 1}, {ADD, 1}, {MUL, 0}, {REAL_DIV, 1}, {MAXIMUM, 0}, {REAL_DIV, 0}, {SQRT, 0}}, - REDUCE_MEAN); - if (res.first) { - MS_LOG(INFO) << "[CAME] found reduce mean node: " << res.second->DebugString(); - return res.second; - } - } - } - return nullptr; -} - -void CameCommHandler::InsertAllReduceAndRealDivToReduceMeanInput(CNodePtr reduce_mean, const RankList &comm_rank_list) { - // construct all reduce cnode and insert to the first input - if (!reduce_mean) { - return; - } - FuncGraphPtr func_graph = reduce_mean->func_graph(); - MS_EXCEPTION_IF_NULL(func_graph); - FuncGraphManagerPtr manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - - CheckGlobalDeviceManager(); - - MS_LOG(INFO) << "Insert All Reduce and RealDiv to node" << reduce_mean->DebugString(); - // insert all reduce - OperatorName allreduce_op_name = ALL_REDUCE; - OperatorAttrs all_reduce_op_attrs; - ValuePtr allreduce_pyop_instance = CreateOpInstance(all_reduce_op_attrs, allreduce_op_name, "came_norm_allreduce"); - std::vector all_reduce_input = {NewValueNode(allreduce_pyop_instance), reduce_mean}; - auto all_reduce_node = func_graph->NewCNode(all_reduce_input); - auto all_reduce_prim = GetCNodePrimitive(all_reduce_node); - auto all_reduce_attrs = all_reduce_prim->attrs(); - all_reduce_attrs["op"] = MakeValue(REDUCE_OP_SUM); - - std::string group_name = CreateCommGroupFromRankList(comm_rank_list); - MS_LOG(INFO) << "[CAME] came allreduce opt shard group: " << group_name; - all_reduce_attrs["group"] = MakeValue(group_name); - int64_t fusion_id = 0; - all_reduce_attrs["fusion"] = MakeValue(fusion_id); - all_reduce_prim->SetAttrs(all_reduce_attrs); - // insert real div - OperatorName operator_name = REAL_DIV; - OperatorAttrs operator_attrs; - - ValuePtr pyop_instance = CreateOpInstance(operator_attrs, operator_name, "came_norm_realdiv"); - MS_EXCEPTION_IF_NULL(pyop_instance); - - size_t group_rank_size = comm_rank_list.size(); - mindspore::tensor::TensorPtr tensor_ptr = std::make_shared( - static_cast(group_rank_size), - reduce_mean->abstract()->cast()->element()->GetType()); - ValuePtr scale_value = MakeValue(tensor_ptr); - - std::vector real_div_input = {NewValueNode(pyop_instance), all_reduce_node->cast(), - NewValueNode(scale_value)}; - auto real_div_node = func_graph->NewCNode(real_div_input); - manager->Replace(reduce_mean, real_div_node); -} - -void CameCommHandler::Process() { - auto reduce_mean_1 = FindReduceMean(1); - auto reduce_mean_2 = FindReduceMean(2); - auto reduce_mean_3 = FindReduceMean(3); - auto reduce_mean_4 = FindReduceMean(4); - auto reduce_mean_5 = FindReduceMean(5); - auto reduce_mean_6 = FindReduceMean(6); - auto reduce_mean_7 = FindReduceMean(7); - MS_LOG(INFO) << "found all reduce mean for came/adafactor"; - - auto shape_size = tensor_layout->slice_shape().array().size(); - if (shape_size == 1) { - // for shape [A], mp and opt shard may overlay on dim A. - Status ret_status; - RankList comm_rank_list; - std::tie(ret_status, comm_rank_list) = GetDimRankList(cur_rank, 0); - if (ret_status != SUCCESS) { - MS_LOG(ERROR) << "[CAME] shape size = 1, getting rank list along 0 failed"; - } - comm_rank_list = ExpandRankListWithOptShard(comm_rank_list); - if (comm_rank_list.size() > 1) { - InsertAllReduceAndRealDivToReduceMeanInput(reduce_mean_4, comm_rank_list); - } - } else { - Status ret_status; - RankList comm_rank_list_along_neg_1; - RankList comm_rank_list_along_neg_2; - RankList comm_rank_list_along_neg_12; - int64_t actual_dim_of_neg_1 = SizeToLong(shape_size) - 1; - int64_t actual_dim_of_neg_2 = SizeToLong(shape_size) - 2; - std::tie(ret_status, comm_rank_list_along_neg_1) = GetDimRankList(cur_rank, actual_dim_of_neg_1); - if (ret_status != SUCCESS) { - MS_LOG(ERROR) << "[CAME] shape = 2, getting rank list along negative dim -1 failed"; - } - std::tie(ret_status, comm_rank_list_along_neg_2) = GetDimRankList(cur_rank, actual_dim_of_neg_2); - if (ret_status != SUCCESS) { - MS_LOG(ERROR) << "[CAME] shape = 2, getting rank list along negative dim -2 failed"; - } - if (shape_size == kParameterDimTwo) { - comm_rank_list_along_neg_2 = ExpandRankListWithOptShard(comm_rank_list_along_neg_2); - } - comm_rank_list_along_neg_12 = ExpandRankListWithDim(comm_rank_list_along_neg_2, actual_dim_of_neg_1); - if (comm_rank_list_along_neg_1.size() > 1) { - InsertAllReduceAndRealDivToReduceMeanInput(reduce_mean_1, comm_rank_list_along_neg_1); - InsertAllReduceAndRealDivToReduceMeanInput(reduce_mean_5, comm_rank_list_along_neg_1); - } - if (comm_rank_list_along_neg_2.size() > 1) { - InsertAllReduceAndRealDivToReduceMeanInput(reduce_mean_2, comm_rank_list_along_neg_2); - InsertAllReduceAndRealDivToReduceMeanInput(reduce_mean_3, comm_rank_list_along_neg_2); - InsertAllReduceAndRealDivToReduceMeanInput(reduce_mean_6, comm_rank_list_along_neg_2); - InsertAllReduceAndRealDivToReduceMeanInput(reduce_mean_7, comm_rank_list_along_neg_2); - } - if (comm_rank_list_along_neg_12.size() > 1) { - InsertAllReduceAndRealDivToReduceMeanInput(reduce_mean_4, comm_rank_list_along_neg_12); - } - } -} - -std::string CameCommHandler::CreateCommGroupFromRankList(const RankList &rank_list) { - Group comm_group; - if (g_device_manager->CreateGroup(rank_list, &comm_group) != SUCCESS) { - MS_LOG(EXCEPTION) << "Create comm group failed in came"; - } - std::string group_name = comm_group.name(); - return group_name; -} - -} // namespace parallel -} // namespace mindspore +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "frontend/parallel/came_parallel_handler.h" + +#include +#include + +#include "frontend/parallel/parameter_manager.h" +#include "mindspore/core/ops/sequence_ops.h" +#include "mindspore/core/ops/other_ops.h" +#include "mindspore/core/ops/array_ops.h" +#include "mindspore/core/ops/framework_ops.h" +#include "mindspore/core/utils/convert_utils_base.h" +#include "utils/hash_map.h" +#include "frontend/operator/ops.h" +#include "frontend/optimizer/optimizer.h" +#include "include/common/utils/parallel_context.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/graph_util/generate_graph.h" +#include "frontend/parallel/graph_util/graph_info.h" +#include "frontend/parallel/graph_util/node_info.h" +#include "frontend/parallel/graph_util/get_parallel_info.h" +#include "frontend/parallel/graph_util/pipeline_split_utils.h" +#include "frontend/parallel/node_check.h" +#include "ir/param_info.h" +#include "ir/tensor.h" +#include "utils/trace_base.h" +#include "include/common/utils/comm_manager.h" +#include "utils/ms_context.h" +#include "utils/symbolic.h" +#include "pipeline/jit/ps/pipeline.h" +#include "mindspore/core/utils/parallel_node_check.h" +#include "frontend/parallel/step_parallel_utils.h" +#include "mindspore/core/ops/nn_ops.h" + +namespace mindspore { +namespace parallel { +const std::string GetCNodeOpName(const CNodePtr &cnode) { + // get the prim name of cnode + ValueNodePtr prim_anf_node = cnode->input(0)->cast(); + MS_EXCEPTION_IF_NULL(prim_anf_node); + PrimitivePtr node_prim = prim_anf_node->value()->cast(); + MS_EXCEPTION_IF_NULL(node_prim); + return node_prim->name(); +} + +std::pair BackwardSearchCNode(const CNodePtr &bottom_node, + const std::vector> &bwd_calls, + const std::string &target_name) { + CNodePtr target_node = bottom_node; + for (const auto &call_param : bwd_calls) { + const auto node_name = call_param.first; + const auto idx = call_param.second; + auto cnode_name = GetCNodeOpName(target_node); + if (cnode_name != node_name) { + MS_LOG(DEBUG) << "[CAME] backward search failed, expect node name: " << node_name << " but got " << cnode_name; + return {false, bottom_node}; + } + const auto ¶m_node = target_node->input(idx + 1); + if (!param_node) { + MS_LOG(DEBUG) << "[CAME] backward search failed, expect param at index: " << (idx + 1) << " but got null"; + return {false, bottom_node}; + } + if (!param_node->isa()) { + MS_LOG(DEBUG) << "[CAME] param node is not a cnode!"; + return {false, bottom_node}; + } + auto param_cnode = param_node->cast(); + MS_EXCEPTION_IF_NULL(param_cnode); + target_node = param_cnode; + } + auto cnode_name = GetCNodeOpName(target_node); + if (cnode_name != target_name) { + MS_LOG(DEBUG) << "[CAME] backward search failed, expect target node name: " << target_name << " but got " + << cnode_name; + return {false, bottom_node}; + } + return {true, target_node}; +} + +std::pair> ForwardSearchCNode(const CNodePtr &start_node, + const std::vector &fwd_calls, + const NodeUsersMap &node_user_map) { + if (!start_node) { + MS_LOG(DEBUG) << "[CAME] forward search start is null!"; + return {false, {}}; + } + if (fwd_calls.empty()) { + MS_LOG(DEBUG) << "[CAME] gives empty forward calls!"; + return {false, {}}; + } + std::vector candidates; + std::deque visited; + CNodePtr cur_node = nullptr; + uint32_t depth = 0; + + visited.push_back(start_node); + CNodePtr last_node = visited.back(); + while (!visited.empty()) { + if (depth == fwd_calls.size() - 1) { + std::copy(visited.begin(), visited.end(), std::back_inserter(candidates)); + break; + } + cur_node = visited.front(); + MS_LOG(INFO) << "[CAME] fwd current node: " << cur_node->DebugString(); + visited.pop_front(); + auto node_set = node_user_map.at(cur_node->cast()); + for (auto item : node_set) { + auto user_node = item.first; + if (!user_node->isa()) { + continue; + } + auto user_cnode = user_node->cast(); + if (GetCNodeOpName(user_cnode) == fwd_calls[depth + 1]) { + visited.push_back(user_cnode); + } + } + if (last_node == cur_node) { + last_node = visited.back(); + depth++; + } + } + + if (candidates.empty()) { + return {false, {}}; + } else { + return {true, candidates}; + } +} + +CameCommHandler::CameCommHandler(ParameterPtr origin, const std::vector &all_parameters, + const NodeUsersMap &node_user_map) + : origin(origin), all_parameters(all_parameters), node_user_map(node_user_map) { + CheckGlobalDeviceManager(); + cur_rank = g_device_manager->global_rank(); + full_rank_list = g_device_manager->GetDeviceListInThisStage(); + + tensor_layout = origin->user_data(); + MS_EXCEPTION_IF_NULL(tensor_layout); + + auto opt_shard_group_name = tensor_layout->opt_shard_group(); + if (!opt_shard_group_name.empty()) { + is_opt_shard = true; + } + MS_LOG(DEBUG) << "CAME processing parameter"; + MS_LOG(DEBUG) << "tensor shape:" << tensor_layout->tensor_shape().ToString(); + MS_LOG(DEBUG) << "slice shape:" << tensor_layout->slice_shape().ToString(); + + MS_LOG(DEBUG) << "opt shard slice shape:"; + for (const auto &item : tensor_layout->opt_shard_slice_shape()) { + MS_LOG(DEBUG) << item; + } + MS_LOG(DEBUG) << "opt shard group:" << tensor_layout->opt_shard_group(); + MS_LOG(DEBUG) << "opt shard step:" << tensor_layout->opt_weight_shard_step(); + + MS_LOG(DEBUG) << "device arrangement:" << tensor_layout->device_arrangement().ToString(); + MS_LOG(DEBUG) << "original device arrangement:" << tensor_layout->device_arrangement_origin().ToString(); + + MS_LOG(DEBUG) << "tensor map:" << tensor_layout->tensor_map().ToString(); + MS_LOG(DEBUG) << "original tensor map:" << tensor_layout->origin_tensor_map().ToString(); + + FindCameParams(); +} + +void CameCommHandler::FindCameParams() { + const std::string origin_name = origin->name(); + const std::string exp_row_name = EXP_AVG_SQ_ROW + origin_name; + const std::string exp_col_name = EXP_AVG_SQ_COL + origin_name; + const std::string exp_insta_row_name = EXP_AVG_INSTA_ROW + origin_name; + const std::string exp_insta_col_name = EXP_AVG_INSTA_COL + origin_name; + const std::string exp_avg_name = std::string(EXP_AVG) + "." + origin_name; + const size_t param_to_find_size = 5; + size_t cur_found_param_count = 0; + for (const auto ¶m_node : all_parameters) { + auto param = param_node->cast(); + MS_EXCEPTION_IF_NULL(param); + const std::string param_name = param->name(); + if (param_name == exp_row_name) { + MS_LOG(DEBUG) << "[CAME] found exp_avg_sq_row: " << param_name; + exp_avg_sq_row = param; + cur_found_param_count++; + } else if (param_name == exp_col_name) { + MS_LOG(DEBUG) << "[CAME] found exp_avg_sq_col: " << param_name; + exp_avg_sq_col = param; + cur_found_param_count++; + } else if (param_name == exp_insta_row_name) { + MS_LOG(DEBUG) << "[CAME] found exp_avg_insta_row: " << param_name; + exp_avg_insta_row = param; + cur_found_param_count++; + } else if (param_name == exp_insta_col_name) { + MS_LOG(DEBUG) << "[CAME] found exp_avg_insta_col: " << param_name; + exp_avg_insta_col = param; + cur_found_param_count++; + } else if (param_name == exp_avg_name) { + MS_LOG(DEBUG) << "[CAME] found exp_avg: " << param_name; + exp_avg = param; + cur_found_param_count++; + } + + if (cur_found_param_count == param_to_find_size) { + break; + } + } + MS_LOG(INFO) << "[CAME] found params corresponding to origin param size: " << cur_found_param_count; +} + +std::pair CameCommHandler::GetOptShardRankList(const int64_t rank) { + DeviceMatrix temp_dev_matrix(rank, full_rank_list, tensor_layout->device_arrangement().array()); + RankList group_devices; + Shape orig_tensor_map = tensor_layout->tensor_map().array(); + if (temp_dev_matrix.GetDevicesByTensorMap(orig_tensor_map, &group_devices) != SUCCESS) { + return {FAILED, {}}; + } + if (group_devices.size() < 2) { + MS_LOG(ERROR) << "get opt shard rank list with less than two group devices!"; + return {FAILED, {}}; + } + + int64_t optimizer_weight_shard_size = ParallelContext::GetInstance()->optimizer_weight_shard_size(); + MS_EXCEPTION_IF_ZERO("optimizer_weight_shard_size", optimizer_weight_shard_size); + if ((optimizer_weight_shard_size == -1) || (optimizer_weight_shard_size > SizeToLong(group_devices.size()))) { + MS_LOG(INFO) << "[CAME] detect optimizer_weight_shard_size = -1 or exceed max shard size, use group devices size: " + << group_devices.size(); + optimizer_weight_shard_size = SizeToLong(group_devices.size()); + } + + int64_t index = std::find(group_devices.begin(), group_devices.end(), rank) - group_devices.begin(); + + // eg: optimizer_weight_shard_size = 2, [0, 8, 16, 24] -> [0, 8], [16, 24] + auto rank_list = + RankList(group_devices.begin() + index / optimizer_weight_shard_size * optimizer_weight_shard_size, + group_devices.begin() + (index / optimizer_weight_shard_size + 1) * optimizer_weight_shard_size); + return std::make_pair(SUCCESS, rank_list); +} + +std::pair CameCommHandler::GetDimRankList(const int64_t rank, const int64_t dim) { + DeviceMatrix dev_matrix(rank, full_rank_list, tensor_layout->device_arrangement().array()); + int64_t device_reverse_dim = tensor_layout->tensor_map().GetDimByIdx(dim); + if (device_reverse_dim == -1) { + return {SUCCESS, {rank}}; + } + int64_t device_dim = SizeToLong(tensor_layout->device_arrangement().array().size()) - 1 - device_reverse_dim; + RankList rank_list; + if (dev_matrix.GetDevicesAlongDim(LongToUlong(device_dim), &rank_list) != SUCCESS) { + MS_LOG(ERROR) << "Get devices along dim failed"; + return {FAILED, rank_list}; + } + return {SUCCESS, rank_list}; +} + +RankList CameCommHandler::ExpandRankListWithOptShard(const RankList &rank_list) { + if (!is_opt_shard) { + return rank_list; + } + MS_LOG(INFO) << "opt shard yes, group name:" << tensor_layout->opt_shard_group(); + + RankList opt_rank_list_find = g_device_manager->FindRankListByHashName(tensor_layout->opt_shard_group()); + for (const auto &opt_find_rank : opt_rank_list_find) { + MS_LOG(INFO) << "group device member:" << opt_find_rank; + } + + RankList expanded_list; + for (const auto &rank : rank_list) { + Status ret_state; + RankList opt_shard_rank_list; + std::tie(ret_state, opt_shard_rank_list) = GetOptShardRankList(rank); + if (ret_state != SUCCESS) { + MS_LOG(EXCEPTION) << "find opt shard rank list in adafactor failed"; + } + MS_LOG(INFO) << "found opt shard rank list for rank " << rank; + + for (const auto &opt_rank : opt_shard_rank_list) { + MS_LOG(INFO) << opt_rank; + } + expanded_list.insert(expanded_list.end(), opt_shard_rank_list.begin(), opt_shard_rank_list.end()); + } + std::sort(expanded_list.begin(), expanded_list.end()); + MS_LOG(INFO) << "expand rank list with opt shard, before:"; + for (const auto &item : rank_list) { + MS_LOG(INFO) << item; + } + MS_LOG(INFO) << "after:"; + for (const auto &item : expanded_list) { + MS_LOG(INFO) << item; + } + return expanded_list; +} + +RankList CameCommHandler::ExpandRankListWithDim(const RankList &rank_list, const int64_t dim) { + RankList expanded_list; + for (const auto &rank : rank_list) { + Status ret_status; + RankList dim_rank_list; + std::tie(ret_status, dim_rank_list) = GetDimRankList(rank, dim); + if (ret_status != SUCCESS) { + MS_LOG(EXCEPTION) << "find dim rank list in adafactor failed"; + } + expanded_list.insert(expanded_list.end(), dim_rank_list.begin(), dim_rank_list.end()); + } + std::sort(expanded_list.begin(), expanded_list.end()); + return expanded_list; +} + +CNodePtr CameCommHandler::FindReduceMean(size_t number) { + if (reduce_mean_numbers.find(number) == reduce_mean_numbers.end()) { + MS_LOG(INFO) << "[CAME] invalid reduce mean number: " << number; + } + + if (number == kFirstCameReduceMean) { + return FindReduceMean1256(exp_avg_sq_row); + } else if (number == kSecondCameReduceMean) { + return FindReduceMean1256(exp_avg_sq_col); + } else if (number == kThirdCameReduceMean) { + return FindReduceMean37(exp_avg_sq_row); + } else if (number == kForthCameReduceMean) { + return FindReduceMean4(); + } else if (number == kFifthCameReduceMean) { + return FindReduceMean1256(exp_avg_insta_row); + } else if (number == kSixthCameReduceMean) { + return FindReduceMean1256(exp_avg_insta_col); + } else if (number == kSeventhCameReduceMean) { + return FindReduceMean37(exp_avg_insta_row); + } else { + return nullptr; + } +} + +CNodePtr CameCommHandler::FindReduceMean1256(const ParameterPtr ¶m) { + if (!param) { + return nullptr; + } + MS_LOG(INFO) << "[CAME] try find reduce_mean according to " << param->name() << " Assign:"; + auto param_user_set = node_user_map.at(param->cast()); + for (auto ¶m_pair : param_user_set) { + auto user_cnode = param_pair.first->cast(); + MS_EXCEPTION_IF_NULL(user_cnode); + if (IsSomePrimitive(user_cnode, ASSIGN)) { + MS_LOG(INFO) << "[CAME] found assign node"; + // assign 1 -> add 1 -> mul 0 -> reduce_mean + auto res = BackwardSearchCNode(user_cnode, {{ASSIGN, 1}, {ADD, 1}, {MUL, 0}}, REDUCE_MEAN); + if (res.first) { + MS_LOG(INFO) << "[CAME] found reduce mean node: " << res.second->DebugString(); + return res.second; + } + } + } + return nullptr; +} + +CNodePtr CameCommHandler::FindReduceMean37(const ParameterPtr ¶m) { + if (!param) { + return nullptr; + } + auto param_user_set = node_user_map.at(param->cast()); + MS_LOG(INFO) << "[CAME] user map size: " << param_user_set.size(); + size_t load_count = 0; + for (auto ¶m_pair : param_user_set) { + auto user_cnode = param_pair.first->cast(); + MS_EXCEPTION_IF_NULL(user_cnode); + if (IsSomePrimitive(user_cnode, LOAD)) { + MS_LOG(INFO) << "[CAME] found load node"; + load_count++; + // load -> reduce mean + auto res = ForwardSearchCNode(user_cnode, {LOAD, REDUCE_MEAN}, node_user_map); + if (res.first) { + MS_LOG(INFO) << "[CAME] found reduce mean node size: " << res.second.size(); + return res.second[0]; // get the first one + } + } + } + MS_LOG(INFO) << "[CAME] found load count: " << load_count; + return nullptr; +} + +CNodePtr CameCommHandler::FindReduceMean4() { + MS_LOG(INFO) << "[CAME] try find reduce_mean no.4 according to exp_avg Assign:"; + if (!exp_avg) { + return nullptr; + } + auto exp_avg_user_set = node_user_map.at(exp_avg->cast()); + for (auto ¶m_pair : exp_avg_user_set) { + auto user_cnode = param_pair.first->cast(); + MS_EXCEPTION_IF_NULL(user_cnode); + if (IsSomePrimitive(user_cnode, ASSIGN)) { + MS_LOG(INFO) << "[CAME] found exp_avg's assign node"; + auto res = BackwardSearchCNode( + user_cnode, {{ASSIGN, 1}, {ADD, 1}, {MUL, 0}, {REAL_DIV, 1}, {MAXIMUM, 0}, {REAL_DIV, 0}, {SQRT, 0}}, + REDUCE_MEAN); + if (res.first) { + MS_LOG(INFO) << "[CAME] found reduce mean node: " << res.second->DebugString(); + return res.second; + } + } + } + return nullptr; +} + +void CameCommHandler::InsertAllReduceAndRealDivToReduceMeanInput(CNodePtr reduce_mean, const RankList &comm_rank_list) { + // construct all reduce cnode and insert to the first input + if (!reduce_mean) { + return; + } + FuncGraphPtr func_graph = reduce_mean->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + FuncGraphManagerPtr manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + + CheckGlobalDeviceManager(); + + MS_LOG(INFO) << "Insert All Reduce and RealDiv to node" << reduce_mean->DebugString(); + // insert all reduce + OperatorName allreduce_op_name = ALL_REDUCE; + OperatorAttrs all_reduce_op_attrs; + ValuePtr allreduce_pyop_instance = CreateOpInstance(all_reduce_op_attrs, allreduce_op_name, "came_norm_allreduce"); + std::vector all_reduce_input = {NewValueNode(allreduce_pyop_instance), reduce_mean}; + auto all_reduce_node = func_graph->NewCNode(all_reduce_input); + auto all_reduce_prim = GetCNodePrimitive(all_reduce_node); + auto all_reduce_attrs = all_reduce_prim->attrs(); + all_reduce_attrs["op"] = MakeValue(REDUCE_OP_SUM); + + std::string group_name = CreateCommGroupFromRankList(comm_rank_list); + MS_LOG(INFO) << "[CAME] came allreduce opt shard group: " << group_name; + all_reduce_attrs["group"] = MakeValue(group_name); + int64_t fusion_id = 0; + all_reduce_attrs["fusion"] = MakeValue(fusion_id); + all_reduce_prim->SetAttrs(all_reduce_attrs); + // insert real div + OperatorName operator_name = REAL_DIV; + OperatorAttrs operator_attrs; + + ValuePtr pyop_instance = CreateOpInstance(operator_attrs, operator_name, "came_norm_realdiv"); + MS_EXCEPTION_IF_NULL(pyop_instance); + + size_t group_rank_size = comm_rank_list.size(); + mindspore::tensor::TensorPtr tensor_ptr = std::make_shared( + static_cast(group_rank_size), + reduce_mean->abstract()->cast()->element()->GetType()); + ValuePtr scale_value = MakeValue(tensor_ptr); + + std::vector real_div_input = {NewValueNode(pyop_instance), all_reduce_node->cast(), + NewValueNode(scale_value)}; + auto real_div_node = func_graph->NewCNode(real_div_input); + manager->Replace(reduce_mean, real_div_node); +} + +void CameCommHandler::Process() { + auto reduce_mean_1 = FindReduceMean(1); + auto reduce_mean_2 = FindReduceMean(2); + auto reduce_mean_3 = FindReduceMean(3); + auto reduce_mean_4 = FindReduceMean(4); + auto reduce_mean_5 = FindReduceMean(5); + auto reduce_mean_6 = FindReduceMean(6); + auto reduce_mean_7 = FindReduceMean(7); + MS_LOG(INFO) << "found all reduce mean for came/adafactor"; + + auto shape_size = tensor_layout->slice_shape().array().size(); + if (shape_size == 1) { + // for shape [A], mp and opt shard may overlay on dim A. + Status ret_status; + RankList comm_rank_list; + std::tie(ret_status, comm_rank_list) = GetDimRankList(cur_rank, 0); + if (ret_status != SUCCESS) { + MS_LOG(ERROR) << "[CAME] shape size = 1, getting rank list along 0 failed"; + } + comm_rank_list = ExpandRankListWithOptShard(comm_rank_list); + if (comm_rank_list.size() > 1) { + InsertAllReduceAndRealDivToReduceMeanInput(reduce_mean_4, comm_rank_list); + } + } else { + Status ret_status; + RankList comm_rank_list_along_neg_1; + RankList comm_rank_list_along_neg_2; + RankList comm_rank_list_along_neg_12; + int64_t actual_dim_of_neg_1 = SizeToLong(shape_size) - 1; + int64_t actual_dim_of_neg_2 = SizeToLong(shape_size) - 2; + std::tie(ret_status, comm_rank_list_along_neg_1) = GetDimRankList(cur_rank, actual_dim_of_neg_1); + if (ret_status != SUCCESS) { + MS_LOG(ERROR) << "[CAME] shape = 2, getting rank list along negative dim -1 failed"; + } + std::tie(ret_status, comm_rank_list_along_neg_2) = GetDimRankList(cur_rank, actual_dim_of_neg_2); + if (ret_status != SUCCESS) { + MS_LOG(ERROR) << "[CAME] shape = 2, getting rank list along negative dim -2 failed"; + } + if (shape_size == kParameterDimTwo) { + comm_rank_list_along_neg_2 = ExpandRankListWithOptShard(comm_rank_list_along_neg_2); + } + comm_rank_list_along_neg_12 = ExpandRankListWithDim(comm_rank_list_along_neg_2, actual_dim_of_neg_1); + if (comm_rank_list_along_neg_1.size() > 1) { + InsertAllReduceAndRealDivToReduceMeanInput(reduce_mean_1, comm_rank_list_along_neg_1); + InsertAllReduceAndRealDivToReduceMeanInput(reduce_mean_5, comm_rank_list_along_neg_1); + } + if (comm_rank_list_along_neg_2.size() > 1) { + InsertAllReduceAndRealDivToReduceMeanInput(reduce_mean_2, comm_rank_list_along_neg_2); + InsertAllReduceAndRealDivToReduceMeanInput(reduce_mean_3, comm_rank_list_along_neg_2); + InsertAllReduceAndRealDivToReduceMeanInput(reduce_mean_6, comm_rank_list_along_neg_2); + InsertAllReduceAndRealDivToReduceMeanInput(reduce_mean_7, comm_rank_list_along_neg_2); + } + if (comm_rank_list_along_neg_12.size() > 1) { + InsertAllReduceAndRealDivToReduceMeanInput(reduce_mean_4, comm_rank_list_along_neg_12); + } + } +} + +std::string CameCommHandler::CreateCommGroupFromRankList(const RankList &rank_list) { + Group comm_group; + if (g_device_manager->CreateGroup(rank_list, &comm_group) != SUCCESS) { + MS_LOG(EXCEPTION) << "Create comm group failed in came"; + } + std::string group_name = comm_group.name(); + return group_name; +} + +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/came_parallel_handler.h b/mindspore/ccsrc/frontend/parallel/came_parallel_handler.h index 19292067079..e91c645eed5 100644 --- a/mindspore/ccsrc/frontend/parallel/came_parallel_handler.h +++ b/mindspore/ccsrc/frontend/parallel/came_parallel_handler.h @@ -1,97 +1,97 @@ -/** - * Copyright 2024 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_CAME_PARALLEL_HANDLER_H_ -#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_CAME_PARALLEL_HANDLER_H_ - -#include -#include - -#include -#include -#include -#include -#include "base/base.h" -#include "frontend/parallel/device_manager.h" -#include "frontend/parallel/tensor_layout/tensor_layout.h" - -namespace mindspore { -namespace parallel { -using TensorLayoutPtr = std::shared_ptr; - -constexpr size_t kFirstCameReduceMean = 1; -constexpr size_t kSecondCameReduceMean = 2; -constexpr size_t kThirdCameReduceMean = 3; -constexpr size_t kForthCameReduceMean = 4; -constexpr size_t kFifthCameReduceMean = 5; -constexpr size_t kSixthCameReduceMean = 6; -constexpr size_t kSeventhCameReduceMean = 7; -constexpr size_t kParameterDimTwo = 2; - -constexpr char EXP_AVG[] = "exp_avg"; -constexpr char EXP_AVG_SQ_ROW[] = "exp_avg_sq_row_"; -constexpr char EXP_AVG_SQ_COL[] = "exp_avg_sq_col_"; -constexpr char EXP_AVG_INSTA_ROW[] = "exp_avg_insta_row_"; -constexpr char EXP_AVG_INSTA_COL[] = "exp_avg_insta_col_"; -constexpr char EXP_AVG_SQ[] = "exp_avg_sq_"; - -class CameCommHandler { - public: - CameCommHandler(ParameterPtr origin, const std::vector &all_parameters, - const NodeUsersMap &node_user_map); - void Process(); - - private: - ParameterPtr origin; - const std::vector &all_parameters; - TensorLayoutPtr tensor_layout; - const NodeUsersMap &node_user_map; - - int64_t cur_rank = -1; - DeviceMatrix dev_matrix; - RankList full_rank_list; - - bool is_opt_shard = false; - - ParameterPtr exp_avg_sq_row = nullptr; - ParameterPtr exp_avg_sq_col = nullptr; - ParameterPtr exp_avg = nullptr; - ParameterPtr exp_avg_insta_row = nullptr; - ParameterPtr exp_avg_insta_col = nullptr; - - std::set reduce_mean_numbers = {kFirstCameReduceMean, kSecondCameReduceMean, kThirdCameReduceMean, - kForthCameReduceMean, kFifthCameReduceMean, kSixthCameReduceMean, - kSeventhCameReduceMean}; - - void FindCameParams(); - - CNodePtr FindReduceMean(size_t number); - CNodePtr FindReduceMean1256(const ParameterPtr ¶m); - CNodePtr FindReduceMean37(const ParameterPtr ¶m); - CNodePtr FindReduceMean4(); - - std::pair GetOptShardRankList(const int64_t rank); - std::pair GetDimRankList(const int64_t rank, const int64_t dim); - - RankList ExpandRankListWithOptShard(const RankList &rank_list); - RankList ExpandRankListWithDim(const RankList &base, const int64_t dim); - - std::string CreateCommGroupFromRankList(const RankList &rank_list); - void InsertAllReduceAndRealDivToReduceMeanInput(CNodePtr reduce_mean, const RankList &comm_rank_list); -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_CAME_PARALLEL_HANDLER_H_ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_CAME_PARALLEL_HANDLER_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_CAME_PARALLEL_HANDLER_H_ + +#include +#include + +#include +#include +#include +#include +#include "base/base.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/tensor_layout/tensor_layout.h" + +namespace mindspore { +namespace parallel { +using TensorLayoutPtr = std::shared_ptr; + +constexpr size_t kFirstCameReduceMean = 1; +constexpr size_t kSecondCameReduceMean = 2; +constexpr size_t kThirdCameReduceMean = 3; +constexpr size_t kForthCameReduceMean = 4; +constexpr size_t kFifthCameReduceMean = 5; +constexpr size_t kSixthCameReduceMean = 6; +constexpr size_t kSeventhCameReduceMean = 7; +constexpr size_t kParameterDimTwo = 2; + +constexpr char EXP_AVG[] = "exp_avg"; +constexpr char EXP_AVG_SQ_ROW[] = "exp_avg_sq_row_"; +constexpr char EXP_AVG_SQ_COL[] = "exp_avg_sq_col_"; +constexpr char EXP_AVG_INSTA_ROW[] = "exp_avg_insta_row_"; +constexpr char EXP_AVG_INSTA_COL[] = "exp_avg_insta_col_"; +constexpr char EXP_AVG_SQ[] = "exp_avg_sq_"; + +class CameCommHandler { + public: + CameCommHandler(ParameterPtr origin, const std::vector &all_parameters, + const NodeUsersMap &node_user_map); + void Process(); + + private: + ParameterPtr origin; + const std::vector &all_parameters; + TensorLayoutPtr tensor_layout; + const NodeUsersMap &node_user_map; + + int64_t cur_rank = -1; + DeviceMatrix dev_matrix; + RankList full_rank_list; + + bool is_opt_shard = false; + + ParameterPtr exp_avg_sq_row = nullptr; + ParameterPtr exp_avg_sq_col = nullptr; + ParameterPtr exp_avg = nullptr; + ParameterPtr exp_avg_insta_row = nullptr; + ParameterPtr exp_avg_insta_col = nullptr; + + std::set reduce_mean_numbers = {kFirstCameReduceMean, kSecondCameReduceMean, kThirdCameReduceMean, + kForthCameReduceMean, kFifthCameReduceMean, kSixthCameReduceMean, + kSeventhCameReduceMean}; + + void FindCameParams(); + + CNodePtr FindReduceMean(size_t number); + CNodePtr FindReduceMean1256(const ParameterPtr ¶m); + CNodePtr FindReduceMean37(const ParameterPtr ¶m); + CNodePtr FindReduceMean4(); + + std::pair GetOptShardRankList(const int64_t rank); + std::pair GetDimRankList(const int64_t rank, const int64_t dim); + + RankList ExpandRankListWithOptShard(const RankList &rank_list); + RankList ExpandRankListWithDim(const RankList &base, const int64_t dim); + + std::string CreateCommGroupFromRankList(const RankList &rank_list); + void InsertAllReduceAndRealDivToReduceMeanInput(CNodePtr reduce_mean, const RankList &comm_rank_list); +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_CAME_PARALLEL_HANDLER_H_ diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/fold_pipeline_split_utils.cc b/mindspore/ccsrc/frontend/parallel/graph_util/fold_pipeline_split_utils.cc index 93a5ecc1ce6..1747beef350 100644 --- a/mindspore/ccsrc/frontend/parallel/graph_util/fold_pipeline_split_utils.cc +++ b/mindspore/ccsrc/frontend/parallel/graph_util/fold_pipeline_split_utils.cc @@ -1,569 +1,569 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "frontend/parallel/graph_util/fold_pipeline_split_utils.h" -#include -#include -#include -#include -#include - -#include "frontend/parallel/graph_util/generate_graph.h" -#include "frontend/parallel/graph_util/pipeline_split_utils.h" -#include "ops/other_ops.h" -#include "ops/math_ops.h" -#include "ops/framework_ops.h" -#include "ops/array_ops.h" -#include "ops/nn_ops.h" -#include "ir/value.h" -#include "frontend/parallel/ops_info/ops_utils.h" -#include "frontend/parallel/device_manager.h" -#include "include/common/utils/parallel_context.h" -#include "frontend/parallel/step_parallel.h" -#include "frontend/parallel/step_parallel_utils.h" -#include "frontend/parallel/graph_util/node_info.h" -#include "utils/parallel_node_check.h" - -namespace mindspore { -namespace parallel { - -namespace { -constexpr int kBackwardEnd = 1; -constexpr int kForwardStart = 2; -constexpr int kForwardEnd = 3; -} // namespace - -const std::set END_NODE_BLACK_LIST = { - prim::kPrimDepend, prim::kPrimTupleGetItem, prim::kPrimAdd, prim::kPrimSoftmaxCrossEntropyWithLogits, - prim::kPrimMakeTuple, prim::kPrimUpdateState, prim::kPrimReshape}; - -int64_t GetSegmentMax(const FuncGraphPtr &root, const std::vector &forward_end) { - int64_t seg_max = 0; - if (forward_end.empty()) { - MS_LOG(EXCEPTION) << "Can not find the end node of pipeline, you are advised to use 'PipelineCell' to fix it."; - } else { - auto forward_end_cnode = forward_end.back()->cast(); - auto seg_size = forward_end_cnode->GetPrimalAttr(SEGMENT); - MS_EXCEPTION_IF_NULL(seg_size); - seg_max = GetValue(seg_size); - } - return seg_max; -} - -std::vector GetSubStepPairs(const PipelinePair &fp_or_bp_pair, int64_t sub_step_num, int64_t seg_num, - int64_t sub_micro_num, int64_t micro_num) { - std::vector fp_or_bp_sub_pairs; - for (int64_t s = 0; s < sub_step_num; s++) { - std::vector temp_first; - std::vector temp_second; - for (int64_t sid = 0; sid < seg_num; sid++) { - temp_first.insert(temp_first.end(), fp_or_bp_pair.first.begin() + s * sub_micro_num + sid * micro_num, - fp_or_bp_pair.first.begin() + (s + 1) * sub_micro_num + sid * micro_num); - temp_second.insert(temp_second.end(), fp_or_bp_pair.second.begin() + s * sub_micro_num + sid * micro_num, - fp_or_bp_pair.second.begin() + (s + 1) * sub_micro_num + sid * micro_num); - } - fp_or_bp_sub_pairs.emplace_back(temp_first, temp_second); - } - return fp_or_bp_sub_pairs; -} - -bool CompFuncBySegAscending(const AnfNodePtr &node1, const AnfNodePtr &node2) { - auto parallel_context = parallel::ParallelContext::GetInstance(); - if (parallel_context->enable_fold_pipeline()) { - auto get_value_func = [](const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto seg = cnode->GetPrimalAttr(SEGMENT); - MS_EXCEPTION_IF_NULL(seg); - return GetValue(seg); - }; - - if (get_value_func(node1) != get_value_func(node2)) { - return get_value_func(node1) < get_value_func(node2); - } - } - return CompFunc(node1, node2); -} - -bool CompFuncBySegDescending(const AnfNodePtr &node1, const AnfNodePtr &node2) { - auto parallel_context = parallel::ParallelContext::GetInstance(); - if (parallel_context->enable_fold_pipeline()) { - auto get_value_func = [](const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto seg = cnode->GetPrimalAttr(SEGMENT); - MS_EXCEPTION_IF_NULL(seg); - return GetValue(seg); - }; - - if (get_value_func(node1) != get_value_func(node2)) { - return get_value_func(node1) > get_value_func(node2); - } - } - return CompFunc(node1, node2); -} - -void InsertVirtualFoldPipelineEndNode(const AnfNodePtr &temp_node, const FuncGraphManagerPtr &manager) { - auto end_node = GetPreNode(temp_node); - MS_EXCEPTION_IF_NULL(end_node); - auto end_cnode = end_node->cast(); - MS_EXCEPTION_IF_NULL(end_cnode); - auto end_prim = GetCNodePrimitive(end_node); - OperatorAttrs attrs_; - auto op = CreateOpInstance(attrs_, "_VirtualPipelineEnd", "end_node"); - auto value_node = NewValueNode(op); - auto new_prim = GetValueNode(value_node)->cast(); - (void)new_prim->SetAttrs(end_prim->attrs()); - manager->SetEdge(end_node, 0, value_node); - end_cnode->AddPrimalAttr(PIPELINE_END, end_cnode->GetPrimalAttr(MICRO)); - auto seg = ParallelContext::GetInstance()->pipeline_segment_split_num(); - end_cnode->AddPrimalAttr(SEGMENT, MakeValue(seg - 1)); -} - -AnfNodePtr FindNodeFirstUser(const FuncGraphPtr &root, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(root); - auto node_users_map = root->manager()->node_users(); - auto users = node_users_map[node]; - for (auto &temp_user : users) { - MS_LOG(INFO) << "Receive user: " << (temp_user.first)->ToString(); - return temp_user.first; - } - return nullptr; -} - -static bool IsInEndNodeBlackListOrParallelBlackList(const CNodePtr &cnode) { - MS_EXCEPTION_IF_NULL(cnode); - if (!IsValueNode(cnode->input(0))) { - return true; - } - auto prim = GetValueNode(cnode->input(0)); - if (IsInParallelBlackList(prim)) { - return true; - } - for (auto &prim_node : END_NODE_BLACK_LIST) { - if (IsPrimitiveCNode(cnode, prim_node)) { - return true; - } - } - return false; -} - -AnfNodePtr GetPreNode(const AnfNodePtr &node) { - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - std::vector node_queue = {node}; - while (!node_queue.empty()) { - auto cur_node = (*node_queue.begin())->cast(); - (void)node_queue.erase(node_queue.begin()); - if (!cur_node) { - continue; - } - if (!IsInEndNodeBlackListOrParallelBlackList(cur_node) && cur_node->HasPrimalAttr(NEED_GRAD)) { - MS_LOG(INFO) << "Pipeline End node: " << cur_node->DebugString(); - return cur_node; - } - (void)node_queue.insert(node_queue.end(), cur_node->inputs().begin() + 1, cur_node->inputs().end()); - } - MS_LOG(EXCEPTION) << "Get Pipeline End node failed."; -} - -static bool ComputeLastSegForwardEndIdx(const PipelinePair &forward_start, size_t curr_idx, int64_t micro_max, - int64_t stage_num, int64_t stage_id) { - auto last_seg_idx = static_cast(1 + micro_max + 1 - 2 * (stage_num - stage_id - 1) - 1); - return curr_idx > forward_start.first.size() - last_seg_idx; -} - -void ReorderForFoldPipelineForward(const std::vector &pair_vector, int64_t seg_max, int64_t micro_max, - const FuncGraphPtr &root, AnfNodePtr *start_of_forward, AnfNodePtr *end_of_forward, - bool enable_1f1b) { - MS_EXCEPTION_IF_NULL(g_device_manager); - MS_EXCEPTION_IF_NULL(root); - auto manager = root->manager(); - MS_EXCEPTION_IF_NULL(manager); - - auto stage_num = g_device_manager->stage_num(); - auto stage_id = g_device_manager->stage_id(); - *start_of_forward = pair_vector[kForwardStart].first[0]; - for (size_t i = 1; i < pair_vector[kForwardStart].first.size(); ++i) { - auto prior_node_begin = pair_vector[kForwardEnd].first[i - 1]; - auto prior_node_end = pair_vector[kForwardEnd].second[i - 1]; - auto post_node_begin = pair_vector[kForwardStart].first[i]; - auto post_node_end = pair_vector[kForwardStart].second[i]; - if (IsFirstStage() && (i > IntToSize(micro_max))) { - auto receive_node = post_node_begin; - post_node_begin = FindNodeFirstUser(root, post_node_begin); - - MS_EXCEPTION_IF_NULL(post_node_begin); - auto insert_idx = i - LongToSize(micro_max + 1) + LongToSize(stage_num - 1); - auto send_node_begin = pair_vector[3].first[insert_idx]; - auto send_node_end = pair_vector[3].second[insert_idx]; - InsertDepend(post_node_end, send_node_begin, manager, root); - - auto send_cnode = send_node_begin->cast(); - auto before_send_node = GetActualOp(send_cnode->input(1)); - - InsertDepend(before_send_node, receive_node, manager, root); - } - if (enable_1f1b && ComputeLastSegForwardEndIdx(pair_vector[kForwardStart], i, micro_max, stage_num, stage_id)) { - continue; - } - - InsertDepend(prior_node_end, post_node_begin, manager, root); - *end_of_forward = pair_vector[kForwardEnd].second[i]; - } - (*end_of_forward)->cast()->AddPrimalAttr(FORWARD_END, MakeValue(true)); - (*end_of_forward)->cast()->AddPrimalAttr(SEGMENT_MAX, MakeValue(seg_max)); -} - -void ReorderForBackwardLastSeg(const std::vector &pair_vector, const FuncGraphPtr &root, - AnfNodePtr *start_of_backward, AnfNodePtr *end_of_backward, int64_t micro_max) { - MS_EXCEPTION_IF_NULL(g_device_manager); - MS_EXCEPTION_IF_NULL(root); - auto manager = root->manager(); - MS_EXCEPTION_IF_NULL(manager); - auto stage_num = g_device_manager->stage_num(); - auto stage_id = g_device_manager->stage_id(); - int64_t seg_max = GetSegmentMax(root, pair_vector[3].second); - MS_LOG(INFO) << "Micro max:" << micro_max << "seg_max" << seg_max; - int64_t last_seg_index = SizeToLong(pair_vector[2].first.size()) - 1 - micro_max; - int64_t cur_stage_fwd_max_idx = 2 * (stage_num - stage_id - 1) + 1; - if (!IsFirstStage() && (micro_max + 1 > cur_stage_fwd_max_idx)) { - for (size_t i = LongToSize(cur_stage_fwd_max_idx); i < LongToSize(micro_max + 1); ++i) { - auto forward_node_begin = pair_vector[2].first[LongToSize(last_seg_index) + i]; - auto forward_node_end = pair_vector[2].second[LongToSize(last_seg_index) + i]; - size_t insert_idx; - if (i == LongToSize(cur_stage_fwd_max_idx)) { - if (IsLastStage()) { - continue; - } - insert_idx = LongToSize(last_seg_index) + i - 1; - auto post_node = pair_vector[3].first[insert_idx]; - InsertDepend(forward_node_end, post_node, manager, root); - - auto prior_node = pair_vector[4].second[insert_idx]; - InsertDepend(prior_node, forward_node_begin, manager, root); - } else { - if (IsLastStage() && i == LongToSize(cur_stage_fwd_max_idx + 1)) { - auto post_node0 = pair_vector[1].first[0]; - InsertDepend(forward_node_end, post_node0, manager, root); - auto pre_prior_node = pair_vector[2].second[LongToSize(last_seg_index) + i - 1]; - InsertDepend(pre_prior_node, forward_node_begin, manager, root); - auto pre_post_node = pair_vector[2].first[LongToSize(last_seg_index) + i - 1]; - auto prior_node0 = GetActualOp(pair_vector[1].first[0]->cast()->input(1)); - InsertDepend(prior_node0, pre_post_node, manager, root); - continue; - } - insert_idx = i - LongToSize(cur_stage_fwd_max_idx) - 1; - auto post_node1 = pair_vector[1].first[insert_idx]; - InsertDepend(forward_node_end, post_node1, manager, root); - - auto prior_cnode1 = post_node1->cast(); - auto before_prior_cnode = GetActualOp(prior_cnode1->input(1)); - InsertDepend(before_prior_cnode, forward_node_begin, manager, root); - } - } - } - - if (micro_max + 1 > cur_stage_fwd_max_idx) { - for (size_t i = LongToSize(cur_stage_fwd_max_idx); i < LongToSize(micro_max + 1); ++i) { - if (!IsLastStage()) { - auto prior_node1 = pair_vector[3].second[last_seg_index + i]; - auto post_node1 = pair_vector[0].first[LongToSize(SizeToLong(i) - cur_stage_fwd_max_idx + 1)]; - InsertDepend(prior_node1, post_node1, manager, root); - } - std::shared_ptr post_node2; - post_node2 = FindNodeFirstUser(root, pair_vector[kForwardStart].first[last_seg_index + i]); - auto prior_node2 = pair_vector[1].second[LongToSize(SizeToLong(i) - cur_stage_fwd_max_idx)]; - InsertDepend(prior_node2, post_node2, manager, root); - } - - for (size_t j = LongToSize(micro_max + 1 - 2 * (stage_num - stage_id - 1)); j < LongToSize(micro_max + 1); ++j) { - auto prior_node3 = pair_vector[1].second[j - 1]; - auto post_node3 = pair_vector[0].first[j]; - InsertDepend(prior_node3, post_node3, manager, root); - } - } else { - for (size_t j = 1; j < LongToSize(micro_max + 1); ++j) { - auto prior_node4 = pair_vector[1].second[j - 1]; - auto post_node4 = pair_vector[0].first[j]; - InsertDepend(prior_node4, post_node4, manager, root); - } - } - - if (!IsLastStage()) { - std::shared_ptr prior_node5; - if ((micro_max + 1 > cur_stage_fwd_max_idx)) { - prior_node5 = pair_vector[kForwardEnd].second[LongToSize(last_seg_index + cur_stage_fwd_max_idx - 1)]; - } else { - prior_node5 = pair_vector[kForwardEnd].second[LongToSize(last_seg_index + micro_max)]; - } - auto post_node5 = pair_vector[0].first[0]; - InsertDepend(prior_node5, post_node5, manager, root); - } - - for (size_t i = 0; i < pair_vector[0].first.size(); ++i) { - pair_vector[0].first[i]->cast()->AddPrimalAttr(BACKWARD_MICRO_END, MakeValue(true)); - pair_vector[0].first[i]->cast()->AddPrimalAttr(SEGMENT_MAX, MakeValue(seg_max)); - } - *start_of_backward = pair_vector[0].first[0]; - *end_of_backward = pair_vector[1].second.back(); - ReorderForBackwardOtherSeg(pair_vector[0], pair_vector[1], micro_max, stage_num, root); -} - -void ReorderForBackwardOtherSeg(const PipelinePair &backward_start_pair, const PipelinePair &backward_end_pair, - int64_t micro_max, int64_t stage_num, const FuncGraphPtr &root) { - MS_EXCEPTION_IF_NULL(root); - auto manager = root->manager(); - for (size_t i = LongToSize(micro_max) + 1; i < backward_start_pair.first.size(); ++i) { - auto prior_node_begin = backward_end_pair.first[i - 1]; - auto prior_node_end = backward_end_pair.second[i - 1]; - auto post_node_begin = backward_start_pair.first[i]; - auto post_node_end = backward_start_pair.second[i]; - - if (IsLastStage() && (i > IntToSize(micro_max))) { - auto receive_node = post_node_begin; - post_node_begin = FindNodeFirstUser(root, post_node_begin); - int64_t insert_idx = SizeToLong(i) - (micro_max + 1) + (stage_num - 1); - auto send_node_begin = backward_end_pair.first[insert_idx]; - auto send_node_end = backward_end_pair.second[insert_idx]; - InsertDepend(post_node_end, send_node_begin, manager, root); - - auto send_cnode = send_node_begin->cast(); - auto before_send_node = GetActualOp(send_cnode->input(1)); - before_send_node = GetActualOp((before_send_node->cast())->input(1)); - - InsertDepend(before_send_node, receive_node, manager, root); - } - - InsertDepend(prior_node_end, post_node_begin, manager, root); - } -} - -PipelinePair Deduplicate(const std::vector &node_vector, const FuncGraphPtr &root, int64_t micro_max, - int64_t seg_max, bool is_train) { - std::vector out_vec_begin; - std::vector out_vec_end; - for (int64_t h = 0; h <= seg_max; ++h) { - CommonDeduplicate(node_vector, &out_vec_begin, &out_vec_end, root, micro_max, seg_max, h, is_train); - } - if (out_vec_begin.empty()) { - return std::make_pair(node_vector, node_vector); - } - return std::make_pair(out_vec_begin, out_vec_end); -} - -PipelinePair DeduplicateBySegAscending(const std::vector &node_vector, const FuncGraphPtr &root, - int64_t micro_max, bool is_train, int64_t seg_max = 0) { - std::vector out_vec_begin; - std::vector out_vec_end; - for (int64_t h = 0; h <= seg_max; ++h) { - CommonDeduplicate(node_vector, &out_vec_begin, &out_vec_end, root, micro_max, seg_max, h, is_train); - } - if (out_vec_begin.empty()) { - return std::make_pair(node_vector, node_vector); - } - return std::make_pair(out_vec_begin, out_vec_end); -} - -PipelinePair DeduplicateBySegDescending(const std::vector &node_vector, const FuncGraphPtr &root, - int64_t micro_max, bool is_train, int64_t seg_max = 0) { - std::vector out_vec_begin; - std::vector out_vec_end; - for (int64_t h = seg_max; h >= 0; --h) { - CommonDeduplicate(node_vector, &out_vec_begin, &out_vec_end, root, micro_max, seg_max, h, is_train); - } - if (out_vec_begin.empty()) { - return std::make_pair(node_vector, node_vector); - } - return std::make_pair(out_vec_begin, out_vec_end); -} - -void ReorderForFoldPipelineBackward(const std::vector &pair_vector, int64_t seg_max, int64_t micro_max, - const FuncGraphPtr &root, AnfNodePtr *start_of_backward, - AnfNodePtr *end_of_backward) { - MS_EXCEPTION_IF_NULL(g_device_manager); - MS_EXCEPTION_IF_NULL(root); - auto manager = root->manager(); - MS_EXCEPTION_IF_NULL(manager); - auto stage_num = g_device_manager->stage_num(); - - bool first = true; - for (size_t i = 0; i < pair_vector[0].first.size(); ++i) { - pair_vector[0].first[i]->cast()->AddPrimalAttr(BACKWARD_MICRO_END, MakeValue(true)); - pair_vector[0].first[i]->cast()->AddPrimalAttr(SEGMENT_MAX, MakeValue(seg_max)); - } - for (size_t i = 1; i < pair_vector[0].first.size(); ++i) { - auto prior_node_begin = pair_vector[1].first[i - 1]; - auto prior_node_end = pair_vector[1].second[i - 1]; - auto post_node_begin = pair_vector[0].first[i]; - auto post_node_end = pair_vector[0].second[i]; - - if (IsLastStage() && (i > IntToSize(micro_max))) { - auto receive_node = post_node_begin; - post_node_begin = FindNodeFirstUser(root, post_node_begin); - auto insert_idx = i - (IntToSize(micro_max) + 1) + (IntToSize(stage_num) - 1); - auto send_node_begin = pair_vector[1].first[insert_idx]; - auto send_node_end = pair_vector[1].second[insert_idx]; - - InsertDepend(post_node_end, send_node_begin, manager, root); - - auto send_cnode = send_node_begin->cast(); - auto before_send_node = GetActualOp(send_cnode->input(1)); - before_send_node = GetActualOp((before_send_node->cast())->input(1)); - - InsertDepend(before_send_node, receive_node, manager, root); - } - - InsertDepend(prior_node_end, post_node_begin, manager, root); - if (first) { - *start_of_backward = pair_vector[0].first[i - 1]; - first = false; - } - } - *end_of_backward = pair_vector[1].second.back(); -} - -PipelinePairVector UpdateSubPairs(int64_t sub_step_num, int64_t micro_num, std::vector pair_vector, - int64_t sub_micro_num, int64_t seg_num) { - PipelinePairVector sub_pair_vector; - PipelinePairVector tmp_pair_vector; - if (micro_num % sub_step_num != 0) { - MS_LOG(EXCEPTION) << "Micro_num(" << micro_num << ")cannot be divisible by sub_step_num(" << sub_step_num << ")."; - } - - if (sub_micro_num < g_device_manager->stage_num()) { - MS_LOG(EXCEPTION) << "Sub_micro_num(" << sub_micro_num << ") is less than stage_num(" - << g_device_manager->stage_num() << ")."; - } - MS_LOG(INFO) << "Micro_num=" << micro_num << ",sub_micro_num=" << sub_micro_num << ",seg_num = " << seg_num; - - std::transform(pair_vector.begin(), pair_vector.end(), std::back_inserter(tmp_pair_vector), - [&sub_step_num, &seg_num, &sub_micro_num, µ_num](const auto &pipeline_pair) { - return GetSubStepPairs(pipeline_pair, sub_step_num, seg_num, sub_micro_num, micro_num); - }); - - for (size_t i = 0; i < tmp_pair_vector.size(); i++) { - std::vector sub_step1; - std::vector sub_step2; - if (!sub_pair_vector.empty()) { - sub_pair_vector[0].push_back(sub_pair_vector[i][0]); - sub_pair_vector[1].push_back(sub_pair_vector[i][1]); - } else { - sub_step1.push_back(sub_pair_vector[i][0]); - sub_pair_vector.push_back(sub_step1); - sub_step2.push_back(sub_pair_vector[i][1]); - sub_pair_vector.push_back(sub_step2); - } - } - return sub_pair_vector; -} - -void FoldPipelineReorder(const FuncGraphPtr &root) { - std::vector forward_start; - std::vector forward_end; - std::vector forward_params; - std::vector backward_start; - std::vector backward_end; - std::vector backward_params; - std::vector allreduce_params; - - SetParameterStartForCellShare(root); - GetBorderNode(&forward_start, &forward_end, &backward_start, &backward_end, &forward_params, &backward_params, - &allreduce_params, root); - int64_t micro_max = GetMicroMax(root, forward_end); - int64_t seg_max = GetSegmentMax(root, forward_end); - std::vector seg_micro_max{micro_max, seg_max}; - - auto backward_start_pair = DeduplicateBySegDescending(backward_start, root, micro_max, true, seg_max); - auto backward_end_pair = DeduplicateBySegDescending(backward_end, root, micro_max, true, seg_max); - auto forward_start_pair = DeduplicateBySegAscending(forward_start, root, micro_max, true, seg_max); - auto forward_end_pair = DeduplicateBySegAscending(forward_end, root, micro_max, true, seg_max); - auto forward_params_pair = Deduplicate(forward_params, root, micro_max, true, seg_max); - auto backward_params_pair = Deduplicate(backward_params, root, micro_max, true, seg_max); - CheckBorderNode(forward_start_pair, forward_end_pair, backward_start_pair, backward_end_pair, seg_micro_max); - auto forward_end_before_pair = GetForwardEndBeforePair(forward_end_pair); - std::vector pair_vector{backward_start_pair, backward_end_pair, forward_start_pair, forward_end_pair, - forward_end_before_pair}; - AnfNodePtr start_of_forward; - AnfNodePtr end_of_forward; - AnfNodePtr start_of_backward; - AnfNodePtr end_of_backward; - AnfNodePtr pre_end_of_backward; - - bool enable_1f1b = false; - if (common::GetEnv("FOLD_LAST_SEG_1F1B") != "") { - enable_1f1b = true; - } - int64_t sub_step_num = 0; - int64_t sub_micro_num = 0; - if (common::GetEnv("FOLD_ACCUMULATION") != "") sub_step_num = std::stoi(common::GetEnv("FOLD_ACCUMULATION")); - MS_LOG(INFO) << "Sub_step_num=" << sub_step_num; - PipelinePairVector sub_pair_vector; - if (sub_step_num > 0) { - int64_t micro_num = micro_max + 1; - int64_t seg_num = seg_max + 1; - sub_micro_num = micro_num / sub_step_num; - sub_pair_vector = UpdateSubPairs(sub_step_num, micro_num, pair_vector, sub_micro_num, seg_num); - } - - if (enable_1f1b) { - if (sub_step_num > 0) { - for (int64_t s = 0; s < sub_step_num; s++) { - ReorderForFoldPipelineForward(sub_pair_vector[s], seg_max, sub_micro_num - 1, root, &start_of_forward, - &end_of_forward, enable_1f1b); - ReorderForBackwardLastSeg(sub_pair_vector[s], root, &start_of_backward, &end_of_backward, sub_micro_num - 1); - if (s > 0) { - InsertDepend(pre_end_of_backward, start_of_forward, root->manager(), root); - } - pre_end_of_backward = end_of_backward; - ReorderForParams(backward_params_pair, forward_params_pair, sub_pair_vector[kBackwardEnd][s], - sub_pair_vector[kForwardStart][s], root); - } - } else { - ReorderForFoldPipelineForward(pair_vector, seg_max, micro_max, root, &start_of_forward, &end_of_forward, - enable_1f1b); - ReorderForBackwardLastSeg(pair_vector, root, &start_of_backward, &end_of_backward, micro_max); - ReorderForParams(backward_params_pair, forward_params_pair, backward_end_pair, forward_start_pair, root); - } - } else { - if (sub_step_num > 0) { - for (int64_t s = 0; s < sub_step_num; s++) { - ReorderForFoldPipelineForward(sub_pair_vector[s], seg_max, sub_micro_num - 1, root, &start_of_forward, - &end_of_forward, enable_1f1b); - - ReorderForFoldPipelineBackward(sub_pair_vector[s], seg_max, sub_micro_num - 1, root, &start_of_backward, - &end_of_backward); - InsertDepend(end_of_forward, start_of_backward, root->manager(), root); - if (s > 0) { - InsertDepend(pre_end_of_backward, start_of_forward, root->manager(), root); - } - pre_end_of_backward = end_of_backward; - ReorderForParams(backward_params_pair, forward_params_pair, sub_pair_vector[1][s], sub_pair_vector[2][s], root); - } - } else { - ReorderForFoldPipelineForward(pair_vector, seg_max, micro_max, root, &start_of_forward, &end_of_forward, - enable_1f1b); - ReorderForFoldPipelineBackward(pair_vector, seg_max, micro_max, root, &start_of_backward, &end_of_backward); - InsertDepend(end_of_forward, start_of_backward, root->manager(), root); - ReorderForParams(backward_params_pair, forward_params_pair, backward_end_pair, forward_start_pair, root); - } - } -} - -} // namespace parallel -} // namespace mindspore +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/graph_util/fold_pipeline_split_utils.h" +#include +#include +#include +#include +#include + +#include "frontend/parallel/graph_util/generate_graph.h" +#include "frontend/parallel/graph_util/pipeline_split_utils.h" +#include "ops/other_ops.h" +#include "ops/math_ops.h" +#include "ops/framework_ops.h" +#include "ops/array_ops.h" +#include "ops/nn_ops.h" +#include "ir/value.h" +#include "frontend/parallel/ops_info/ops_utils.h" +#include "frontend/parallel/device_manager.h" +#include "include/common/utils/parallel_context.h" +#include "frontend/parallel/step_parallel.h" +#include "frontend/parallel/step_parallel_utils.h" +#include "frontend/parallel/graph_util/node_info.h" +#include "utils/parallel_node_check.h" + +namespace mindspore { +namespace parallel { + +namespace { +constexpr int kBackwardEnd = 1; +constexpr int kForwardStart = 2; +constexpr int kForwardEnd = 3; +} // namespace + +const std::set END_NODE_BLACK_LIST = { + prim::kPrimDepend, prim::kPrimTupleGetItem, prim::kPrimAdd, prim::kPrimSoftmaxCrossEntropyWithLogits, + prim::kPrimMakeTuple, prim::kPrimUpdateState, prim::kPrimReshape}; + +int64_t GetSegmentMax(const FuncGraphPtr &root, const std::vector &forward_end) { + int64_t seg_max = 0; + if (forward_end.empty()) { + MS_LOG(EXCEPTION) << "Can not find the end node of pipeline, you are advised to use 'PipelineCell' to fix it."; + } else { + auto forward_end_cnode = forward_end.back()->cast(); + auto seg_size = forward_end_cnode->GetPrimalAttr(SEGMENT); + MS_EXCEPTION_IF_NULL(seg_size); + seg_max = GetValue(seg_size); + } + return seg_max; +} + +std::vector GetSubStepPairs(const PipelinePair &fp_or_bp_pair, int64_t sub_step_num, int64_t seg_num, + int64_t sub_micro_num, int64_t micro_num) { + std::vector fp_or_bp_sub_pairs; + for (int64_t s = 0; s < sub_step_num; s++) { + std::vector temp_first; + std::vector temp_second; + for (int64_t sid = 0; sid < seg_num; sid++) { + temp_first.insert(temp_first.end(), fp_or_bp_pair.first.begin() + s * sub_micro_num + sid * micro_num, + fp_or_bp_pair.first.begin() + (s + 1) * sub_micro_num + sid * micro_num); + temp_second.insert(temp_second.end(), fp_or_bp_pair.second.begin() + s * sub_micro_num + sid * micro_num, + fp_or_bp_pair.second.begin() + (s + 1) * sub_micro_num + sid * micro_num); + } + fp_or_bp_sub_pairs.emplace_back(temp_first, temp_second); + } + return fp_or_bp_sub_pairs; +} + +bool CompFuncBySegAscending(const AnfNodePtr &node1, const AnfNodePtr &node2) { + auto parallel_context = parallel::ParallelContext::GetInstance(); + if (parallel_context->enable_fold_pipeline()) { + auto get_value_func = [](const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto seg = cnode->GetPrimalAttr(SEGMENT); + MS_EXCEPTION_IF_NULL(seg); + return GetValue(seg); + }; + + if (get_value_func(node1) != get_value_func(node2)) { + return get_value_func(node1) < get_value_func(node2); + } + } + return CompFunc(node1, node2); +} + +bool CompFuncBySegDescending(const AnfNodePtr &node1, const AnfNodePtr &node2) { + auto parallel_context = parallel::ParallelContext::GetInstance(); + if (parallel_context->enable_fold_pipeline()) { + auto get_value_func = [](const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto seg = cnode->GetPrimalAttr(SEGMENT); + MS_EXCEPTION_IF_NULL(seg); + return GetValue(seg); + }; + + if (get_value_func(node1) != get_value_func(node2)) { + return get_value_func(node1) > get_value_func(node2); + } + } + return CompFunc(node1, node2); +} + +void InsertVirtualFoldPipelineEndNode(const AnfNodePtr &temp_node, const FuncGraphManagerPtr &manager) { + auto end_node = GetPreNode(temp_node); + MS_EXCEPTION_IF_NULL(end_node); + auto end_cnode = end_node->cast(); + MS_EXCEPTION_IF_NULL(end_cnode); + auto end_prim = GetCNodePrimitive(end_node); + OperatorAttrs attrs_; + auto op = CreateOpInstance(attrs_, "_VirtualPipelineEnd", "end_node"); + auto value_node = NewValueNode(op); + auto new_prim = GetValueNode(value_node)->cast(); + (void)new_prim->SetAttrs(end_prim->attrs()); + manager->SetEdge(end_node, 0, value_node); + end_cnode->AddPrimalAttr(PIPELINE_END, end_cnode->GetPrimalAttr(MICRO)); + auto seg = ParallelContext::GetInstance()->pipeline_segment_split_num(); + end_cnode->AddPrimalAttr(SEGMENT, MakeValue(seg - 1)); +} + +AnfNodePtr FindNodeFirstUser(const FuncGraphPtr &root, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(root); + auto node_users_map = root->manager()->node_users(); + auto users = node_users_map[node]; + for (auto &temp_user : users) { + MS_LOG(INFO) << "Receive user: " << (temp_user.first)->ToString(); + return temp_user.first; + } + return nullptr; +} + +static bool IsInEndNodeBlackListOrParallelBlackList(const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + if (!IsValueNode(cnode->input(0))) { + return true; + } + auto prim = GetValueNode(cnode->input(0)); + if (IsInParallelBlackList(prim)) { + return true; + } + for (auto &prim_node : END_NODE_BLACK_LIST) { + if (IsPrimitiveCNode(cnode, prim_node)) { + return true; + } + } + return false; +} + +AnfNodePtr GetPreNode(const AnfNodePtr &node) { + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + std::vector node_queue = {node}; + while (!node_queue.empty()) { + auto cur_node = (*node_queue.begin())->cast(); + (void)node_queue.erase(node_queue.begin()); + if (!cur_node) { + continue; + } + if (!IsInEndNodeBlackListOrParallelBlackList(cur_node) && cur_node->HasPrimalAttr(NEED_GRAD)) { + MS_LOG(INFO) << "Pipeline End node: " << cur_node->DebugString(); + return cur_node; + } + (void)node_queue.insert(node_queue.end(), cur_node->inputs().begin() + 1, cur_node->inputs().end()); + } + MS_LOG(EXCEPTION) << "Get Pipeline End node failed."; +} + +static bool ComputeLastSegForwardEndIdx(const PipelinePair &forward_start, size_t curr_idx, int64_t micro_max, + int64_t stage_num, int64_t stage_id) { + auto last_seg_idx = static_cast(1 + micro_max + 1 - 2 * (stage_num - stage_id - 1) - 1); + return curr_idx > forward_start.first.size() - last_seg_idx; +} + +void ReorderForFoldPipelineForward(const std::vector &pair_vector, int64_t seg_max, int64_t micro_max, + const FuncGraphPtr &root, AnfNodePtr *start_of_forward, AnfNodePtr *end_of_forward, + bool enable_1f1b) { + MS_EXCEPTION_IF_NULL(g_device_manager); + MS_EXCEPTION_IF_NULL(root); + auto manager = root->manager(); + MS_EXCEPTION_IF_NULL(manager); + + auto stage_num = g_device_manager->stage_num(); + auto stage_id = g_device_manager->stage_id(); + *start_of_forward = pair_vector[kForwardStart].first[0]; + for (size_t i = 1; i < pair_vector[kForwardStart].first.size(); ++i) { + auto prior_node_begin = pair_vector[kForwardEnd].first[i - 1]; + auto prior_node_end = pair_vector[kForwardEnd].second[i - 1]; + auto post_node_begin = pair_vector[kForwardStart].first[i]; + auto post_node_end = pair_vector[kForwardStart].second[i]; + if (IsFirstStage() && (i > IntToSize(micro_max))) { + auto receive_node = post_node_begin; + post_node_begin = FindNodeFirstUser(root, post_node_begin); + + MS_EXCEPTION_IF_NULL(post_node_begin); + auto insert_idx = i - LongToSize(micro_max + 1) + LongToSize(stage_num - 1); + auto send_node_begin = pair_vector[3].first[insert_idx]; + auto send_node_end = pair_vector[3].second[insert_idx]; + InsertDepend(post_node_end, send_node_begin, manager, root); + + auto send_cnode = send_node_begin->cast(); + auto before_send_node = GetActualOp(send_cnode->input(1)); + + InsertDepend(before_send_node, receive_node, manager, root); + } + if (enable_1f1b && ComputeLastSegForwardEndIdx(pair_vector[kForwardStart], i, micro_max, stage_num, stage_id)) { + continue; + } + + InsertDepend(prior_node_end, post_node_begin, manager, root); + *end_of_forward = pair_vector[kForwardEnd].second[i]; + } + (*end_of_forward)->cast()->AddPrimalAttr(FORWARD_END, MakeValue(true)); + (*end_of_forward)->cast()->AddPrimalAttr(SEGMENT_MAX, MakeValue(seg_max)); +} + +void ReorderForBackwardLastSeg(const std::vector &pair_vector, const FuncGraphPtr &root, + AnfNodePtr *start_of_backward, AnfNodePtr *end_of_backward, int64_t micro_max) { + MS_EXCEPTION_IF_NULL(g_device_manager); + MS_EXCEPTION_IF_NULL(root); + auto manager = root->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto stage_num = g_device_manager->stage_num(); + auto stage_id = g_device_manager->stage_id(); + int64_t seg_max = GetSegmentMax(root, pair_vector[3].second); + MS_LOG(INFO) << "Micro max:" << micro_max << "seg_max" << seg_max; + int64_t last_seg_index = SizeToLong(pair_vector[2].first.size()) - 1 - micro_max; + int64_t cur_stage_fwd_max_idx = 2 * (stage_num - stage_id - 1) + 1; + if (!IsFirstStage() && (micro_max + 1 > cur_stage_fwd_max_idx)) { + for (size_t i = LongToSize(cur_stage_fwd_max_idx); i < LongToSize(micro_max + 1); ++i) { + auto forward_node_begin = pair_vector[2].first[LongToSize(last_seg_index) + i]; + auto forward_node_end = pair_vector[2].second[LongToSize(last_seg_index) + i]; + size_t insert_idx; + if (i == LongToSize(cur_stage_fwd_max_idx)) { + if (IsLastStage()) { + continue; + } + insert_idx = LongToSize(last_seg_index) + i - 1; + auto post_node = pair_vector[3].first[insert_idx]; + InsertDepend(forward_node_end, post_node, manager, root); + + auto prior_node = pair_vector[4].second[insert_idx]; + InsertDepend(prior_node, forward_node_begin, manager, root); + } else { + if (IsLastStage() && i == LongToSize(cur_stage_fwd_max_idx + 1)) { + auto post_node0 = pair_vector[1].first[0]; + InsertDepend(forward_node_end, post_node0, manager, root); + auto pre_prior_node = pair_vector[2].second[LongToSize(last_seg_index) + i - 1]; + InsertDepend(pre_prior_node, forward_node_begin, manager, root); + auto pre_post_node = pair_vector[2].first[LongToSize(last_seg_index) + i - 1]; + auto prior_node0 = GetActualOp(pair_vector[1].first[0]->cast()->input(1)); + InsertDepend(prior_node0, pre_post_node, manager, root); + continue; + } + insert_idx = i - LongToSize(cur_stage_fwd_max_idx) - 1; + auto post_node1 = pair_vector[1].first[insert_idx]; + InsertDepend(forward_node_end, post_node1, manager, root); + + auto prior_cnode1 = post_node1->cast(); + auto before_prior_cnode = GetActualOp(prior_cnode1->input(1)); + InsertDepend(before_prior_cnode, forward_node_begin, manager, root); + } + } + } + + if (micro_max + 1 > cur_stage_fwd_max_idx) { + for (size_t i = LongToSize(cur_stage_fwd_max_idx); i < LongToSize(micro_max + 1); ++i) { + if (!IsLastStage()) { + auto prior_node1 = pair_vector[3].second[last_seg_index + i]; + auto post_node1 = pair_vector[0].first[LongToSize(SizeToLong(i) - cur_stage_fwd_max_idx + 1)]; + InsertDepend(prior_node1, post_node1, manager, root); + } + std::shared_ptr post_node2; + post_node2 = FindNodeFirstUser(root, pair_vector[kForwardStart].first[last_seg_index + i]); + auto prior_node2 = pair_vector[1].second[LongToSize(SizeToLong(i) - cur_stage_fwd_max_idx)]; + InsertDepend(prior_node2, post_node2, manager, root); + } + + for (size_t j = LongToSize(micro_max + 1 - 2 * (stage_num - stage_id - 1)); j < LongToSize(micro_max + 1); ++j) { + auto prior_node3 = pair_vector[1].second[j - 1]; + auto post_node3 = pair_vector[0].first[j]; + InsertDepend(prior_node3, post_node3, manager, root); + } + } else { + for (size_t j = 1; j < LongToSize(micro_max + 1); ++j) { + auto prior_node4 = pair_vector[1].second[j - 1]; + auto post_node4 = pair_vector[0].first[j]; + InsertDepend(prior_node4, post_node4, manager, root); + } + } + + if (!IsLastStage()) { + std::shared_ptr prior_node5; + if ((micro_max + 1 > cur_stage_fwd_max_idx)) { + prior_node5 = pair_vector[kForwardEnd].second[LongToSize(last_seg_index + cur_stage_fwd_max_idx - 1)]; + } else { + prior_node5 = pair_vector[kForwardEnd].second[LongToSize(last_seg_index + micro_max)]; + } + auto post_node5 = pair_vector[0].first[0]; + InsertDepend(prior_node5, post_node5, manager, root); + } + + for (size_t i = 0; i < pair_vector[0].first.size(); ++i) { + pair_vector[0].first[i]->cast()->AddPrimalAttr(BACKWARD_MICRO_END, MakeValue(true)); + pair_vector[0].first[i]->cast()->AddPrimalAttr(SEGMENT_MAX, MakeValue(seg_max)); + } + *start_of_backward = pair_vector[0].first[0]; + *end_of_backward = pair_vector[1].second.back(); + ReorderForBackwardOtherSeg(pair_vector[0], pair_vector[1], micro_max, stage_num, root); +} + +void ReorderForBackwardOtherSeg(const PipelinePair &backward_start_pair, const PipelinePair &backward_end_pair, + int64_t micro_max, int64_t stage_num, const FuncGraphPtr &root) { + MS_EXCEPTION_IF_NULL(root); + auto manager = root->manager(); + for (size_t i = LongToSize(micro_max) + 1; i < backward_start_pair.first.size(); ++i) { + auto prior_node_begin = backward_end_pair.first[i - 1]; + auto prior_node_end = backward_end_pair.second[i - 1]; + auto post_node_begin = backward_start_pair.first[i]; + auto post_node_end = backward_start_pair.second[i]; + + if (IsLastStage() && (i > IntToSize(micro_max))) { + auto receive_node = post_node_begin; + post_node_begin = FindNodeFirstUser(root, post_node_begin); + int64_t insert_idx = SizeToLong(i) - (micro_max + 1) + (stage_num - 1); + auto send_node_begin = backward_end_pair.first[insert_idx]; + auto send_node_end = backward_end_pair.second[insert_idx]; + InsertDepend(post_node_end, send_node_begin, manager, root); + + auto send_cnode = send_node_begin->cast(); + auto before_send_node = GetActualOp(send_cnode->input(1)); + before_send_node = GetActualOp((before_send_node->cast())->input(1)); + + InsertDepend(before_send_node, receive_node, manager, root); + } + + InsertDepend(prior_node_end, post_node_begin, manager, root); + } +} + +PipelinePair Deduplicate(const std::vector &node_vector, const FuncGraphPtr &root, int64_t micro_max, + int64_t seg_max, bool is_train) { + std::vector out_vec_begin; + std::vector out_vec_end; + for (int64_t h = 0; h <= seg_max; ++h) { + CommonDeduplicate(node_vector, &out_vec_begin, &out_vec_end, root, micro_max, seg_max, h, is_train); + } + if (out_vec_begin.empty()) { + return std::make_pair(node_vector, node_vector); + } + return std::make_pair(out_vec_begin, out_vec_end); +} + +PipelinePair DeduplicateBySegAscending(const std::vector &node_vector, const FuncGraphPtr &root, + int64_t micro_max, bool is_train, int64_t seg_max = 0) { + std::vector out_vec_begin; + std::vector out_vec_end; + for (int64_t h = 0; h <= seg_max; ++h) { + CommonDeduplicate(node_vector, &out_vec_begin, &out_vec_end, root, micro_max, seg_max, h, is_train); + } + if (out_vec_begin.empty()) { + return std::make_pair(node_vector, node_vector); + } + return std::make_pair(out_vec_begin, out_vec_end); +} + +PipelinePair DeduplicateBySegDescending(const std::vector &node_vector, const FuncGraphPtr &root, + int64_t micro_max, bool is_train, int64_t seg_max = 0) { + std::vector out_vec_begin; + std::vector out_vec_end; + for (int64_t h = seg_max; h >= 0; --h) { + CommonDeduplicate(node_vector, &out_vec_begin, &out_vec_end, root, micro_max, seg_max, h, is_train); + } + if (out_vec_begin.empty()) { + return std::make_pair(node_vector, node_vector); + } + return std::make_pair(out_vec_begin, out_vec_end); +} + +void ReorderForFoldPipelineBackward(const std::vector &pair_vector, int64_t seg_max, int64_t micro_max, + const FuncGraphPtr &root, AnfNodePtr *start_of_backward, + AnfNodePtr *end_of_backward) { + MS_EXCEPTION_IF_NULL(g_device_manager); + MS_EXCEPTION_IF_NULL(root); + auto manager = root->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto stage_num = g_device_manager->stage_num(); + + bool first = true; + for (size_t i = 0; i < pair_vector[0].first.size(); ++i) { + pair_vector[0].first[i]->cast()->AddPrimalAttr(BACKWARD_MICRO_END, MakeValue(true)); + pair_vector[0].first[i]->cast()->AddPrimalAttr(SEGMENT_MAX, MakeValue(seg_max)); + } + for (size_t i = 1; i < pair_vector[0].first.size(); ++i) { + auto prior_node_begin = pair_vector[1].first[i - 1]; + auto prior_node_end = pair_vector[1].second[i - 1]; + auto post_node_begin = pair_vector[0].first[i]; + auto post_node_end = pair_vector[0].second[i]; + + if (IsLastStage() && (i > IntToSize(micro_max))) { + auto receive_node = post_node_begin; + post_node_begin = FindNodeFirstUser(root, post_node_begin); + auto insert_idx = i - (IntToSize(micro_max) + 1) + (IntToSize(stage_num) - 1); + auto send_node_begin = pair_vector[1].first[insert_idx]; + auto send_node_end = pair_vector[1].second[insert_idx]; + + InsertDepend(post_node_end, send_node_begin, manager, root); + + auto send_cnode = send_node_begin->cast(); + auto before_send_node = GetActualOp(send_cnode->input(1)); + before_send_node = GetActualOp((before_send_node->cast())->input(1)); + + InsertDepend(before_send_node, receive_node, manager, root); + } + + InsertDepend(prior_node_end, post_node_begin, manager, root); + if (first) { + *start_of_backward = pair_vector[0].first[i - 1]; + first = false; + } + } + *end_of_backward = pair_vector[1].second.back(); +} + +PipelinePairVector UpdateSubPairs(int64_t sub_step_num, int64_t micro_num, std::vector pair_vector, + int64_t sub_micro_num, int64_t seg_num) { + PipelinePairVector sub_pair_vector; + PipelinePairVector tmp_pair_vector; + if (micro_num % sub_step_num != 0) { + MS_LOG(EXCEPTION) << "Micro_num(" << micro_num << ")cannot be divisible by sub_step_num(" << sub_step_num << ")."; + } + + if (sub_micro_num < g_device_manager->stage_num()) { + MS_LOG(EXCEPTION) << "Sub_micro_num(" << sub_micro_num << ") is less than stage_num(" + << g_device_manager->stage_num() << ")."; + } + MS_LOG(INFO) << "Micro_num=" << micro_num << ",sub_micro_num=" << sub_micro_num << ",seg_num = " << seg_num; + + std::transform(pair_vector.begin(), pair_vector.end(), std::back_inserter(tmp_pair_vector), + [&sub_step_num, &seg_num, &sub_micro_num, µ_num](const auto &pipeline_pair) { + return GetSubStepPairs(pipeline_pair, sub_step_num, seg_num, sub_micro_num, micro_num); + }); + + for (size_t i = 0; i < tmp_pair_vector.size(); i++) { + std::vector sub_step1; + std::vector sub_step2; + if (!sub_pair_vector.empty()) { + sub_pair_vector[0].push_back(sub_pair_vector[i][0]); + sub_pair_vector[1].push_back(sub_pair_vector[i][1]); + } else { + sub_step1.push_back(sub_pair_vector[i][0]); + sub_pair_vector.push_back(sub_step1); + sub_step2.push_back(sub_pair_vector[i][1]); + sub_pair_vector.push_back(sub_step2); + } + } + return sub_pair_vector; +} + +void FoldPipelineReorder(const FuncGraphPtr &root) { + std::vector forward_start; + std::vector forward_end; + std::vector forward_params; + std::vector backward_start; + std::vector backward_end; + std::vector backward_params; + std::vector allreduce_params; + + SetParameterStartForCellShare(root); + GetBorderNode(&forward_start, &forward_end, &backward_start, &backward_end, &forward_params, &backward_params, + &allreduce_params, root); + int64_t micro_max = GetMicroMax(root, forward_end); + int64_t seg_max = GetSegmentMax(root, forward_end); + std::vector seg_micro_max{micro_max, seg_max}; + + auto backward_start_pair = DeduplicateBySegDescending(backward_start, root, micro_max, true, seg_max); + auto backward_end_pair = DeduplicateBySegDescending(backward_end, root, micro_max, true, seg_max); + auto forward_start_pair = DeduplicateBySegAscending(forward_start, root, micro_max, true, seg_max); + auto forward_end_pair = DeduplicateBySegAscending(forward_end, root, micro_max, true, seg_max); + auto forward_params_pair = Deduplicate(forward_params, root, micro_max, true, seg_max); + auto backward_params_pair = Deduplicate(backward_params, root, micro_max, true, seg_max); + CheckBorderNode(forward_start_pair, forward_end_pair, backward_start_pair, backward_end_pair, seg_micro_max); + auto forward_end_before_pair = GetForwardEndBeforePair(forward_end_pair); + std::vector pair_vector{backward_start_pair, backward_end_pair, forward_start_pair, forward_end_pair, + forward_end_before_pair}; + AnfNodePtr start_of_forward; + AnfNodePtr end_of_forward; + AnfNodePtr start_of_backward; + AnfNodePtr end_of_backward; + AnfNodePtr pre_end_of_backward; + + bool enable_1f1b = false; + if (common::GetEnv("FOLD_LAST_SEG_1F1B") != "") { + enable_1f1b = true; + } + int64_t sub_step_num = 0; + int64_t sub_micro_num = 0; + if (common::GetEnv("FOLD_ACCUMULATION") != "") sub_step_num = std::stoi(common::GetEnv("FOLD_ACCUMULATION")); + MS_LOG(INFO) << "Sub_step_num=" << sub_step_num; + PipelinePairVector sub_pair_vector; + if (sub_step_num > 0) { + int64_t micro_num = micro_max + 1; + int64_t seg_num = seg_max + 1; + sub_micro_num = micro_num / sub_step_num; + sub_pair_vector = UpdateSubPairs(sub_step_num, micro_num, pair_vector, sub_micro_num, seg_num); + } + + if (enable_1f1b) { + if (sub_step_num > 0) { + for (int64_t s = 0; s < sub_step_num; s++) { + ReorderForFoldPipelineForward(sub_pair_vector[s], seg_max, sub_micro_num - 1, root, &start_of_forward, + &end_of_forward, enable_1f1b); + ReorderForBackwardLastSeg(sub_pair_vector[s], root, &start_of_backward, &end_of_backward, sub_micro_num - 1); + if (s > 0) { + InsertDepend(pre_end_of_backward, start_of_forward, root->manager(), root); + } + pre_end_of_backward = end_of_backward; + ReorderForParams(backward_params_pair, forward_params_pair, sub_pair_vector[kBackwardEnd][s], + sub_pair_vector[kForwardStart][s], root); + } + } else { + ReorderForFoldPipelineForward(pair_vector, seg_max, micro_max, root, &start_of_forward, &end_of_forward, + enable_1f1b); + ReorderForBackwardLastSeg(pair_vector, root, &start_of_backward, &end_of_backward, micro_max); + ReorderForParams(backward_params_pair, forward_params_pair, backward_end_pair, forward_start_pair, root); + } + } else { + if (sub_step_num > 0) { + for (int64_t s = 0; s < sub_step_num; s++) { + ReorderForFoldPipelineForward(sub_pair_vector[s], seg_max, sub_micro_num - 1, root, &start_of_forward, + &end_of_forward, enable_1f1b); + + ReorderForFoldPipelineBackward(sub_pair_vector[s], seg_max, sub_micro_num - 1, root, &start_of_backward, + &end_of_backward); + InsertDepend(end_of_forward, start_of_backward, root->manager(), root); + if (s > 0) { + InsertDepend(pre_end_of_backward, start_of_forward, root->manager(), root); + } + pre_end_of_backward = end_of_backward; + ReorderForParams(backward_params_pair, forward_params_pair, sub_pair_vector[1][s], sub_pair_vector[2][s], root); + } + } else { + ReorderForFoldPipelineForward(pair_vector, seg_max, micro_max, root, &start_of_forward, &end_of_forward, + enable_1f1b); + ReorderForFoldPipelineBackward(pair_vector, seg_max, micro_max, root, &start_of_backward, &end_of_backward); + InsertDepend(end_of_forward, start_of_backward, root->manager(), root); + ReorderForParams(backward_params_pair, forward_params_pair, backward_end_pair, forward_start_pair, root); + } + } +} + +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/fold_pipeline_split_utils.h b/mindspore/ccsrc/frontend/parallel/graph_util/fold_pipeline_split_utils.h index c90431a7a68..89e8e7b00e5 100644 --- a/mindspore/ccsrc/frontend/parallel/graph_util/fold_pipeline_split_utils.h +++ b/mindspore/ccsrc/frontend/parallel/graph_util/fold_pipeline_split_utils.h @@ -1,41 +1,41 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_FOLD_PIPELINE_SPLIT_UTILS_H_ -#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_FOLD_PIPELINE_SPLIT_UTILS_H_ - -#include -#include -#include -#include "ir/anf.h" -#include "ir/manager.h" -#include "frontend/parallel/graph_util/pipeline_split_utils.h" - -namespace mindspore { -namespace parallel { -void FoldPipelineReorder(const FuncGraphPtr &root); -void ReorderForBackwardOtherSeg(const PipelinePair &backward_start_pair, const PipelinePair &backward_end_pair, - int64_t micro_max, int64_t stage_num, const FuncGraphPtr &root); -void InsertVirtualFoldPipelineEndNode(const AnfNodePtr &temp_node, const FuncGraphManagerPtr &manager); -bool CompFuncBySegAscending(const AnfNodePtr &node1, const AnfNodePtr &node2); -bool CompFuncBySegDescending(const AnfNodePtr &node1, const AnfNodePtr &node2); -AnfNodePtr GetPreNode(const AnfNodePtr &node); -PipelinePair Deduplicate(const std::vector &node_vector, const FuncGraphPtr &root, int64_t micro_max, - int64_t seg_max, bool is_train); -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_FOLD_PIPELINE_SPLIT_UTILS_H_ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_FOLD_PIPELINE_SPLIT_UTILS_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_FOLD_PIPELINE_SPLIT_UTILS_H_ + +#include +#include +#include +#include "ir/anf.h" +#include "ir/manager.h" +#include "frontend/parallel/graph_util/pipeline_split_utils.h" + +namespace mindspore { +namespace parallel { +void FoldPipelineReorder(const FuncGraphPtr &root); +void ReorderForBackwardOtherSeg(const PipelinePair &backward_start_pair, const PipelinePair &backward_end_pair, + int64_t micro_max, int64_t stage_num, const FuncGraphPtr &root); +void InsertVirtualFoldPipelineEndNode(const AnfNodePtr &temp_node, const FuncGraphManagerPtr &manager); +bool CompFuncBySegAscending(const AnfNodePtr &node1, const AnfNodePtr &node2); +bool CompFuncBySegDescending(const AnfNodePtr &node1, const AnfNodePtr &node2); +AnfNodePtr GetPreNode(const AnfNodePtr &node); +PipelinePair Deduplicate(const std::vector &node_vector, const FuncGraphPtr &root, int64_t micro_max, + int64_t seg_max, bool is_train); +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_FOLD_PIPELINE_SPLIT_UTILS_H_ diff --git a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/fold_pipeline_transformer.cc b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/fold_pipeline_transformer.cc index bf2e15855a6..1426c3c5c4d 100644 --- a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/fold_pipeline_transformer.cc +++ b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/fold_pipeline_transformer.cc @@ -1,709 +1,709 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "frontend/parallel/pipeline_transformer/fold_pipeline_transformer.h" -#include -#include -#include -#include -#include -#include -#include "frontend/parallel/pipeline_transformer/pipeline_transformer.h" -#include "frontend/parallel/auto_parallel/graph_costmodel.h" -#include "frontend/parallel/graph_util/graph_splitter.h" -#include "frontend/parallel/ops_info/ops_utils.h" -#include "frontend/parallel/group_manager.h" -#include "frontend/parallel/parameter_manager.h" -#include "include/common/utils/parallel_context.h" -#include "frontend/parallel/step_parallel.h" -#include "frontend/parallel/node_check.h" -#include "frontend/parallel/graph_util/node_info.h" -#include "frontend/parallel/graph_util/pipeline_split_utils.h" -#include "frontend/parallel/step_parallel_utils.h" -#include "ir/anf.h" -#include "ir/graph_utils.h" -#include "ops/other_ops.h" -#include "ops/array_ops.h" -#include "ops/framework_ops.h" -#include "include/common/utils/comm_manager.h" -#include "utils/ms_context.h" -#include "utils/parallel_node_check.h" - -namespace mindspore { -namespace parallel { -mindspore::HashMap fold_send_tag_map; -mindspore::HashMap fold_recv_tag_map; - -void FoldPipelineTransformer::CreateForwardGroup2() { - auto rank_id = g_device_manager->global_rank(); - auto stage_id = g_device_manager->stage_id(); - auto stage_num = g_device_manager->stage_num(); - - std::vector forward_rank_list; - forward_rank_list.push_back(rank_id); - if (stage_id < stage_num - 1) { - forward_rank_list.push_back(rank_id + per_stage_rank_num_); - } else { - forward_rank_list.push_back(rank_id + per_stage_rank_num_ * (0 - stage_id)); - } - - Group g; - - if (g_device_manager->CreateGroup(forward_rank_list, &g) != SUCCESS) { - MS_LOG(EXCEPTION) << "Create forward communication group between all pipeline stages failed, the rank_list is: " - << forward_rank_list; - } - - std::vector backward_rank_list; - if (stage_id == 0) { - backward_rank_list.push_back(rank_id + per_stage_rank_num_ * (stage_num - 1)); - } else { - backward_rank_list.push_back(rank_id - per_stage_rank_num_); - } - backward_rank_list.push_back(rank_id); - - Group g_back; - if (g_device_manager->CreateGroup(backward_rank_list, &g_back) != SUCCESS) { - MS_LOG(EXCEPTION) << "Create backward communication group between all pipeline stages failed, the rank_list is: " - << backward_rank_list; - } - - group_.push_back(g.name()); - group_.push_back(g_back.name()); -} -void HandleSegment(const ValuePtr &value, const FuncGraphPtr &graph) { - MS_EXCEPTION_IF_NULL(graph); - auto nodes = graph->nodes(); - for (auto node : nodes) { - if (node->isa()) { - auto cnode = node->cast(); - MS_LOG(INFO) << "Handle Segment cnode: " << cnode->fullname_with_scope(); - cnode->AddPrimalAttr(SEGMENT, value); - } - } -} -void FoldPipelineTransformer::Coloring() { - auto need_coloring = true; - std::set stage_set; - std::set segment_set; - if (!IsTraining(manager_)) { - is_train_ = false; - } - while (need_coloring) { - need_coloring = false; - for (auto &fg : manager_->func_graphs()) { - if (fg == root_ && is_train_) { - continue; - } - auto value_nodes = fg->value_nodes(); - for (auto value_pair = value_nodes.cbegin(); value_pair != value_nodes.cend(); ++value_pair) { - auto node = (*value_pair).first; - if (!IsValueNode(node)) { - continue; - } - auto graph = GetValueNode(node); - if (graph->stage() == -1) { - continue; - } - (void)stage_set.insert(graph->stage()); - (void)segment_set.insert(graph->segment()); - auto node_users = manager_->node_users()[node]; - HandleSegment(MakeValue(graph->segment()), graph); - for (auto &user_pair : node_users) { - auto user_node = user_pair.first->cast(); - user_node->set_user_data(std::make_shared(graph->stage())); - user_node->set_user_data(std::make_shared(graph->segment())); - auto user_node_graph = user_node->func_graph(); - if (graph->stage() == stage_ && user_node_graph->stage() == -1) { - user_node_graph->set_stage(graph->stage()); - MS_LOG(INFO) << "Set_segment in Coloring" << graph->segment(); - user_node_graph->set_segment(graph->segment()); - need_coloring = true; - } - } - } - } - } - MS_EXCEPTION_IF_NULL(g_device_manager); - auto stage_num = g_device_manager->stage_num(); - auto segment_num = ParallelContext::GetInstance()->pipeline_segment_split_num(); - if (SizeToLong(stage_set.size()) != stage_num) { - MS_LOG(EXCEPTION) << "Stage num is " << stage_num << " is not equal to stage used: " << stage_set.size(); - } - if (SizeToLong(segment_set.size()) != segment_num) { - MS_LOG(EXCEPTION) << "Segment num is " << segment_num << " is not equal to segment used: " << segment_set.size(); - } -} - -void FoldPipelineTransformer::ColorForNodes() { - for (auto &fg : manager_->func_graphs()) { - auto stage = fg->stage(); - auto segment = fg->segment(); - if (stage < 0) { - continue; - } - if (segment < 0) { - continue; - } - if (fg == root_ || fg == main_graph_ || fg == shared_cell_) { - continue; - } - auto all_nodes = fg->nodes(); - for (auto node : all_nodes) { - if (node->user_data() != nullptr) { - continue; - } - node->set_user_data(std::make_shared(stage)); - if (node->user_data() != nullptr) { - continue; - } - node->set_user_data(std::make_shared(segment)); - } - } -} - -void FoldPipelineTransformer::BroadCastColoring() { - auto need_coloring = true; - while (need_coloring) { - need_coloring = false; - auto all_nodes = main_graph_->nodes(); - auto node_users = manager_->node_users(); - for (auto node = all_nodes.cbegin(); node != all_nodes.cend(); ++node) { - auto stage_info = (*node)->user_data(); - auto segment_info = (*node)->user_data(); - if (!(*node)->isa() || stage_info == nullptr || stage_info->stage() == -1 || - IsPrimitiveCNode(*node, prim::kPrimUpdateState)) { - continue; - } - auto stage = stage_info->stage(); - auto segment = segment_info->segment(); - for (auto &user_pair : node_users[*node]) { - auto user_node = user_pair.first->cast(); - auto user_stage_info = user_node->user_data(); - auto user_segment_info = user_node->user_data(); - if (user_stage_info == nullptr) { - user_node->set_user_data(std::make_shared(stage)); - user_node->set_user_data(std::make_shared(segment)); - need_coloring = true; - continue; - } - auto user_node_stage = user_stage_info->stage(); - auto user_node_segment = user_segment_info->segment(); - if (stage > user_node_stage && segment == user_node_segment) { - if (IsValueNode(user_node->input(0))) { - MS_LOG(WARNING) << "The stage setting is incorrect. PreNode's stage: " << stage - << " is larger than NextNode's stage:" << user_node_stage; - } - user_node->set_user_data(std::make_shared(stage)); - need_coloring = true; - } - if (segment > user_node_segment) { - user_node->set_user_data(std::make_shared(segment)); - need_coloring = true; - } - } - } - } - ColorForNodes(); -} - -SendAttr FoldPipelineTransformer::InsertSend(const AnfNodePtr ¶meter, int64_t user_node_stage, int64_t node_stage, - const ValuePtr &value, int64_t segment) { - auto dest_rank = global_rank_ + (user_node_stage - node_stage) * per_stage_rank_num_; - int64_t send_tag; - auto stage_num = g_device_manager->stage_num(); - if (node_stage == 0 && user_node_stage > 1 && stage_num > 2) { - if (fold_recv_tag_map.find(dest_rank) != fold_recv_tag_map.end()) { - send_tag = fold_recv_tag_map[dest_rank] + 1; - fold_recv_tag_map[dest_rank] += 1; - } else { - send_tag = 0; - fold_recv_tag_map[dest_rank] = 0; - } - } else { - if (fold_send_tag_map.find(dest_rank) != fold_send_tag_map.end()) { - send_tag = fold_send_tag_map[dest_rank] + 1; - fold_send_tag_map[dest_rank] += 1; - } else { - send_tag = 0; - fold_send_tag_map[dest_rank] = 0; - } - } - Attr attr_tag = std::make_pair(SR_TAG, MakeValue(send_tag)); - Attr attr_rank = std::make_pair(DEST_RANK, MakeValue(user_node_stage)); - Attr attr_group = std::make_pair(GROUP, MakeValue(group_[0])); - Attr attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[1])); - if (stage_num > 2) { - auto next = (user_node_stage == 0) ? 0 : 1; - attr_rank = std::make_pair(DEST_RANK, MakeValue(next)); - attr_group = std::make_pair(GROUP, MakeValue(group_[0])); - attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[0])); - } - - if (node_stage == 0 && user_node_stage > 1 && stage_num > 2) { - attr_group = std::make_pair(GROUP, MakeValue(group_[1])); - attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[1])); - attr_rank = std::make_pair(DEST_RANK, MakeValue(1)); - } - auto graph = enable_share_cell_ ? shared_cell_ : main_graph_; - std::vector send_input = {parameter}; - OperatorAttrs attrs = {attr_tag, attr_rank, attr_group, attr_group_back}; - CNodePtr send = CreateCNodeByInputsAndAttr(graph, SEND, SEND, send_input, attrs); - auto prim = GetCNodePrimitive(send); - AnfNodePtr care_node; - bool is_param = true; - auto op_info_pair = GetOpInfoPair(parameter, parameter, &care_node, &is_param); - auto tensor_info = GetTensorInfo(op_info_pair, is_param); - - auto index = op_info_pair.second; - auto op_info = op_info_pair.first; - auto slice_shape = tensor_info.slice_shape(); - auto shape_type_pair = GetShapeType(parameter, slice_shape, 0); - prim->set_attr(SHAPE, shape_type_pair.first); - prim->set_attr(DTYPE, shape_type_pair.second); - if (!is_param) { - send->AddPrimalAttr(PIPELINE_END, value); - } else { - send->AddPrimalAttr(PIPELINE_PARAM, value); - send->set_user_data(op_info); - send->AddPrimalAttr(PARAM_INDEX, MakeValue(index)); - auto param = care_node ? care_node : parameter; - send->set_user_data(INPUT_PARAM, param); - } - send->AddPrimalAttr(MICRO, value); - send->AddPrimalAttr(SEGMENT, MakeValue(segment)); - MS_LOG(INFO) << "Insert Send op, segment is " << segment; - send->AddPrimalAttr(DEST_RANK, MakeValue(user_node_stage)); - OperatorAttrs depend_attrs; - CNodePtr depend = CreateCNodeByInputsAndAttr(graph, DEPEND, DEPEND, AnfNodePtrList{parameter, send}, depend_attrs); - auto abstract = parameter->abstract(); - if (care_node) { - abstract = care_node->abstract(); - } - depend->set_abstract(abstract); - send->set_abstract(abstract); - SendAttr send_out = {shape_type_pair.first, shape_type_pair.second, depend}; - - send->set_user_data(DEST_RANK, std::make_shared(dest_rank)); - send->set_user_data(USER_NODE_STAGE, std::make_shared(user_node_stage)); - return send_out; -} - -int64_t FoldPipelineTransformer::ComputeRecvTag(int64_t node_stage, int64_t user_node_stage, int64_t stage_num, - int64_t src_rank) { - int64_t recv_tag; - if (node_stage == 0 && user_node_stage > 1 && stage_num > 2) { - if (fold_send_tag_map.find(src_rank) != fold_send_tag_map.end()) { - recv_tag = fold_send_tag_map[src_rank] + 1; - fold_send_tag_map[src_rank] += 1; - } else { - recv_tag = 0; - fold_send_tag_map[src_rank] = 0; - } - } else { - if (fold_recv_tag_map.find(src_rank) != fold_recv_tag_map.end()) { - recv_tag = fold_recv_tag_map[src_rank] + 1; - fold_recv_tag_map[src_rank] += 1; - } else { - recv_tag = 0; - fold_recv_tag_map[src_rank] = 0; - } - } - return recv_tag; -} - -AnfNodePtr FoldPipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, - const AnfNodePtr &use_node, int index, int64_t user_node_stage, - int64_t node_stage, const ValuePtr &value, - const AnfNodePtr &graph_param, int64_t segment) { - auto src_rank = global_rank_ - (user_node_stage - node_stage) * per_stage_rank_num_; - auto stage_num = g_device_manager->stage_num(); - auto recv_tag = ComputeRecvTag(node_stage, user_node_stage, stage_num, src_rank); - Attr attr_tag = std::make_pair(SR_TAG, MakeValue(recv_tag)); - Attr attr_rank = std::make_pair(SRC_RANK, MakeValue(node_stage)); - Attr attr_group = std::make_pair(GROUP, MakeValue(group_[0])); - Attr attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[1])); - - if (stage_num > 2) { - auto next = (user_node_stage == 0) ? 1 : 0; - attr_rank = std::make_pair(SRC_RANK, MakeValue(next)); - attr_group = std::make_pair(GROUP, MakeValue(group_[1])); - attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[1])); - } - bool is_param = true; - AnfNodePtr care_node; - auto op_info_pair = GetOpInfoPair(node, graph_param, &care_node, &is_param); - auto tensor_info = GetTensorInfo(op_info_pair, is_param); - auto tensor_layout = tensor_info.tensor_layout(); - Shape slice_shape = tensor_info.slice_shape(); - auto shape_type_pair = GetShapeType(node, slice_shape, 0); - Attr attr_shape = std::make_pair(SHAPE, shape_type_pair.first); - Attr attr_dtype = std::make_pair(DTYPE, shape_type_pair.second); - if (node_stage == 0 && user_node_stage > 1 && stage_num > 2) { - attr_group = std::make_pair(GROUP, MakeValue(group_[0])); - attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[0])); - attr_rank = std::make_pair(SRC_RANK, MakeValue(0)); - } - OperatorAttrs attrs = {attr_tag, attr_rank, attr_shape, attr_dtype, attr_group, attr_group_back}; - std::vector recv_input; - if (node->isa()) { - recv_input = {node}; - } else { - recv_input = {virtual_param_}; - } - auto recv = CreateCNodeByInputsAndAttr(graph, RECEIVE, RECEIVE, recv_input, attrs); - if (is_param) { - recv->set_user_data(PIPELINE_PARAM, node); - recv->AddPrimalAttr(PIPELINE_PARAM, value); - auto param = care_node ? care_node : node; - recv->set_user_data(INPUT_PARAM, param); - } else { - recv->AddPrimalAttr(PIPELINE_BEGIN, value); - } - recv->AddPrimalAttr(MICRO, value); - recv->AddPrimalAttr(SRC_RANK, MakeValue(node_stage)); - recv->AddPrimalAttr(SEGMENT, MakeValue(segment)); - MS_LOG(INFO) << "Insertreceive segment" << segment; - auto node_abstract = node->abstract(); - if (node->isa()) { - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (IsValueNode(cnode->input(0))) { - auto output = GetValueNode(cnode->input(0))->output(); - MS_EXCEPTION_IF_NULL(output); - node_abstract = output->abstract(); - } - } - MS_EXCEPTION_IF_NULL(node_abstract); - recv->set_abstract(node_abstract); - if (node->isa()) { - BaseShapePtr parallel_shape = std::make_shared(slice_shape); - auto abstract_clone = node->abstract()->Clone(); - MS_EXCEPTION_IF_NULL(abstract_clone); - abstract_clone->set_shape(parallel_shape); - node->set_abstract(abstract_clone); - node->set_user_data(std::make_shared(tensor_layout)); - auto actual_param = RefParameterToActualParameter(node); - if (actual_param) { - actual_param->set_user_data(std::make_shared(tensor_layout)); - auto actual_param_abstract = actual_param->abstract()->Clone(); - actual_param_abstract->set_shape(parallel_shape); - actual_param->set_abstract(actual_param_abstract); - } - } - recv->set_user_data(std::make_shared(tensor_layout)); - recv->set_user_data(op_info_pair.first); - - recv->set_user_data(SRC_RANK, std::make_shared(src_rank)); - recv->set_user_data(NODE_STAGE, std::make_shared(node_stage)); - recv->set_user_data(SLICE_DTYPE, shape_type_pair.second); - recv->set_user_data(SLICE_SHAPE, std::make_shared(slice_shape)); - - manager_->SetEdge(use_node, index, recv); - return recv; -} - -AnfNodePtr FoldPipelineTransformer::Reuse(const AnfNodePtr &node, int64_t stage, int64_t node_segment, - const std::vector &out_input, - const std::vector &out_input_segment, const std::string &tag) { - std::vector> zipped; - std::transform(out_input.begin(), out_input.end(), out_input_segment.begin(), std::back_inserter(zipped), - [](const auto &send, const auto &send_segment) { return std::make_pair(send, send_segment); }); - - for (auto &zipp : zipped) { - auto input = zipp.first; - auto send_segment = zipp.second; - auto cnode = input->cast(); - if (!cnode) { - continue; - } - if (IsPrimitiveCNode(cnode, prim::kPrimDepend)) { - cnode = cnode->input(DEPEND_NODE_SOURCE_INDEX)->cast(); - } - if (cnode->input(1) == node) { - auto prim = GetValueNode(cnode->input(0)); - auto dest_rank_send = GetValue(prim->GetAttr(tag)); - if (dest_rank_send == stage && node_segment == send_segment) { - return input; - } - } - } - return nullptr; -} - -AnfNodePtr FoldPipelineTransformer::HandleParameterGraph(const AnfNodePtr &node, const AnfNodePtr &use_node, - int64_t stage, int64_t user_stage, const ValuePtr µ, - size_t pos, const std::vector &ops) { - CNodePtr call_node = nullptr; - auto argument = GetRealKernelNode(node, -1, &call_node).first; - - auto use_cnode = use_node->cast(); - MS_EXCEPTION_IF_NULL(use_cnode); - if (!IsValueNode(use_cnode->input(0))) { - MS_LOG(EXCEPTION) << "Parameter must be used by a graph, but got: " << use_cnode->DebugString(); - } - auto use_graph = GetValueNode(use_cnode->input(0)); - auto use_parameter_list = use_graph->parameters(); - auto parameter = use_parameter_list.at(pos - 1); - - // insert receive - if (stage_ == user_stage) { - auto recv = PipelineTransformer::Reuse(argument, stage, ops, SRC_RANK); - if (recv) { - manager_->SetEdge(use_node, SizeToInt(pos), recv); - return nullptr; - } - auto root_param = argument; - if (argument->isa() && argument->func_graph() != root_) { - root_param = GetArgumentsByParameter(argument); - } - (void)parameter_color_map_[root_param].insert(user_stage); - auto graph = enable_share_cell_ ? shared_cell_ : main_graph_; - return InsertReceive(graph, argument, use_node, SizeToInt(pos), user_stage, stage, micro, parameter, 0); - } - // insert send - if (PipelineTransformer::Reuse(argument, user_stage, ops, DEST_RANK)) { - return nullptr; - } - auto send_out = InsertSend(argument, user_stage, stage_, micro, 0); - send_out.depend->set_user_data(DTYPE, send_out.type); - send_out.depend->set_user_data(SHAPE, send_out.shape); - return send_out.depend; -} - -bool IsStageConflict(int64_t node_stage, int64_t user_node_stage, int64_t node_segment, int64_t user_node_segment, - int64_t stage_num, bool isEmbed) { - if (isEmbed || (node_stage < user_node_stage && node_segment == user_node_segment) || - (node_stage == stage_num - 1 && user_node_stage == 0 && node_segment < user_node_segment)) { - return true; - } - return false; -} - -void FoldPipelineTransformer::CutBorderForNode(const FuncGraphPtr &graph, const AnfNodePtr &node, - std::vector *send_ops, - std::vector *send_ops_segment, - std::vector *receive_ops) { - auto stage_info = node->user_data(); - auto segment_info = node->user_data(); - auto node_users = manager_->node_users()[node]; - AnfNodePtr receive = nullptr; - for (auto &user_pair : node_users) { - auto user_node = user_pair.first; - auto node_stage = stage_info->stage(); - auto node_segment = segment_info->segment(); - auto user_stage_info = user_node->user_data(); - if (user_stage_info == nullptr) { - continue; - } - auto user_segment_info = user_node->user_data(); - if (user_segment_info == nullptr) { - continue; - } - auto user_node_stage = user_stage_info->stage(); - if (node_stage != stage_ && user_node_stage != stage_) { - continue; - } - auto micro = user_node->cast()->GetPrimalAttr(MICRO); - auto user_node_segment = user_segment_info->segment(); - if (!micro) { - MS_LOG(INFO) << "Can't find micro_batch information, use micro(0)"; - micro = MakeValue(int64_t(0)); - } - auto stage_num = g_device_manager->stage_num(); - - bool isEmbed = node_stage < user_node_stage && node_segment != user_node_segment; - if (IsStageConflict(node_stage, user_node_stage, node_segment, user_node_segment, stage_num, isEmbed)) { - if (node_stage == stage_) { - if (IsParameterGraph(node) && isEmbed) { - auto send_depend = HandleParameterGraph(node, user_node, node_stage, user_node_stage, micro, - IntToSize(user_pair.second), *send_ops); - if (!send_depend) { - continue; - } - (void)send_ops->insert(send_ops->cbegin(), send_depend); - (void)send_ops_segment->insert(send_ops_segment->begin(), node_segment); - continue; - } - if (Reuse(node, user_node_stage, user_node_segment, *send_ops, *send_ops_segment, DEST_RANK)) { - continue; - } - auto send_out = InsertSend(node, user_node_stage, node_stage, micro, node_segment); - MS_EXCEPTION_IF_NULL(send_out.depend); - send_ops->push_back(send_out.depend); - send_ops_segment->push_back(node_segment); - send_out.depend->set_user_data(DTYPE, send_out.type); - send_out.depend->set_user_data(SHAPE, send_out.shape); - } else { - if (!receive) { - if (IsParameterGraph(node)) { - receive = HandleParameterGraph(node, user_node, node_stage, user_node_stage, micro, - IntToSize(user_pair.second), *receive_ops); - if (!receive) { - continue; - } - receive_ops->push_back(receive); - } else { - receive = InsertReceive(graph, node, user_node, user_pair.second, user_node_stage, node_stage, micro, node, - user_node_segment); - receive_ops->push_back(receive); - } - } else { - manager_->SetEdge(user_node, user_pair.second, receive); - } - } - continue; - } - if (node_stage > user_node_stage && node_segment == user_node_segment) { - MS_LOG(EXCEPTION) << "Within a segment, node_stage: " << node_stage - << " must be smaller than user_node_stage: " << user_node_stage; - } - } -} - -std::pair, std::vector> FoldPipelineTransformer::CutBorder( - const FuncGraphPtr &graph) { - std::vector send_ops; - std::vector send_ops_segment; - std::vector receive_ops; - auto ret = graph->get_return(); - MS_EXCEPTION_IF_NULL(ret); - std::vector all_nodes = DeepScopedGraphSearch(ret); - std::reverse(all_nodes.begin(), all_nodes.end()); - auto stage_num = g_device_manager->stage_num(); - if (is_train_ && (stage_num > micro_size_)) { - MS_LOG(EXCEPTION) << "MicroBatch size: " << micro_size_ << " can't less than stage num: " << stage_num; - } - for (auto &node : all_nodes) { - auto stage_info = node->user_data(); - if (!node->isa() || stage_info == nullptr || stage_info->stage() == -1 || - IsPrimitiveCNode(node, prim::kPrimUpdateState)) { - continue; - } - CutBorderForNode(graph, node, &send_ops, &send_ops_segment, &receive_ops); - } - RemoveMonadNode(); - return std::make_pair(send_ops, receive_ops); -} - -std::pair, std::vector> FoldPipelineTransformer::HandleSharedParameter() { - auto parameters = root_->parameters(); - std::vector sends = {}; - std::vector recvs = {}; - for (auto ¶meter : parameters) { - auto parameter_stage = parameter_color_map_[parameter]; - if (parameter_stage.size() <= 1) { - continue; - } - const auto &node_users_map = manager_->node_users(); - auto users = GetParameterLoadUsers(parameter, node_users_map); - for (auto &user : users) { - auto node = user.first; - auto cnode = node->cast(); - auto graph = node->func_graph(); - if (IsValueNode(cnode->input(0))) { - graph = GetValueNode(cnode->input(0)); - } - if (graph == root_ || graph->stage() == -1 || parameter_stage.count(stage_) == 0) { - continue; - } - auto micro = cnode->GetPrimalAttr(MICRO); - if (!micro) { - MS_LOG(INFO) << "Parameter: " << parameter->ToString() << " doesn't have micro batch"; - micro = MakeValue(int64_t(0)); - } - if (stage_ == *parameter_stage.begin()) { - auto user_stage = graph->stage(); - auto stage_info = node->user_data(); - if (stage_info) { - user_stage = stage_info->stage(); - } - if (graph->stage() == stage_ || user_stage == -1) { - continue; - } - if (PipelineTransformer::Reuse(parameter, user_stage, sends, DEST_RANK)) { - continue; - } - auto send_out = InsertSend(parameter, user_stage, stage_, micro, 0); - sends.push_back(send_out.depend); - } else { - auto receive = PipelineTransformer::Reuse(parameter, *parameter_stage.begin(), recvs, SRC_RANK); - if (receive) { - manager_->SetEdge(node, user.second, receive); - } else { - AnfNodePtr recv; - auto fg = enable_share_cell_ ? shared_cell_ : main_graph_; - recv = InsertReceive(fg, parameter, node, user.second, stage_, *parameter_stage.begin(), micro, parameter, 0); - (void)(recvs.push_back(recv)); - } - } - } - } - return std::make_pair(sends, recvs); -} - -void FoldPipelineTransformer::CutGraph() { - CreateForwardGroup2(); - MS_EXCEPTION_IF_NULL(main_graph_); - auto send_recv_shared_param = HandleSharedParameter(); - auto graph = enable_share_cell_ ? shared_cell_ : main_graph_; - MS_EXCEPTION_IF_NULL(graph); - auto send_recv_cut_border = CutBorder(graph); - std::vector send_ops; - (void)(send_ops.insert(send_ops.end(), send_recv_shared_param.first.begin(), send_recv_shared_param.first.end())); - (void)(send_ops.insert(send_ops.end(), send_recv_cut_border.first.begin(), send_recv_cut_border.first.end())); - if (IsLastStage() && !enable_share_cell_) { - auto out_node = main_graph_->output(); - - auto make_tuple = CreateMakeTupleNode(main_graph_, send_ops); - - std::vector tuple_out_depend = {NewValueNode(prim::kPrimDepend)}; - tuple_out_depend.push_back(out_node); - tuple_out_depend.push_back(make_tuple); - - auto tuple_out_depend_node = main_graph_->NewCNode(tuple_out_depend); - tuple_out_depend_node->set_abstract(out_node->abstract()); - (void)manager_->Replace(main_graph_->output(), tuple_out_depend_node); - return; - } - if (send_ops.empty() && !is_train_) { - return; - } - if (!send_ops.empty()) { - type_ptr_ = send_ops.back()->user_data(DTYPE); - shape_ = send_ops.back()->user_data(SHAPE); - } - if (!enable_share_cell_) { - auto make_tuple = CreateMakeTupleNode(main_graph_, send_ops); - auto zero_outputs = GetZeroOutputs(main_graph_); - std::vector out = {NewValueNode(prim::kPrimDepend), zero_outputs, make_tuple}; - auto out_node = main_graph_->NewCNode(out); - (void)manager_->Replace(main_graph_->output(), out_node); - return; - } - fold_send_tag_map.clear(); - fold_recv_tag_map.clear(); - if (!IsLastStage()) { - HandleGraphOutputs(send_ops); - } - std::vector recv_ops; - (void)(recv_ops.insert(recv_ops.end(), send_recv_shared_param.second.begin(), send_recv_shared_param.second.end())); - (void)(recv_ops.insert(recv_ops.end(), send_recv_cut_border.second.begin(), send_recv_cut_border.second.end())); - HandleGraphInputs(recv_ops); -} - -} // namespace parallel -} // namespace mindspore +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/pipeline_transformer/fold_pipeline_transformer.h" +#include +#include +#include +#include +#include +#include +#include "frontend/parallel/pipeline_transformer/pipeline_transformer.h" +#include "frontend/parallel/auto_parallel/graph_costmodel.h" +#include "frontend/parallel/graph_util/graph_splitter.h" +#include "frontend/parallel/ops_info/ops_utils.h" +#include "frontend/parallel/group_manager.h" +#include "frontend/parallel/parameter_manager.h" +#include "include/common/utils/parallel_context.h" +#include "frontend/parallel/step_parallel.h" +#include "frontend/parallel/node_check.h" +#include "frontend/parallel/graph_util/node_info.h" +#include "frontend/parallel/graph_util/pipeline_split_utils.h" +#include "frontend/parallel/step_parallel_utils.h" +#include "ir/anf.h" +#include "ir/graph_utils.h" +#include "ops/other_ops.h" +#include "ops/array_ops.h" +#include "ops/framework_ops.h" +#include "include/common/utils/comm_manager.h" +#include "utils/ms_context.h" +#include "utils/parallel_node_check.h" + +namespace mindspore { +namespace parallel { +mindspore::HashMap fold_send_tag_map; +mindspore::HashMap fold_recv_tag_map; + +void FoldPipelineTransformer::CreateForwardGroup2() { + auto rank_id = g_device_manager->global_rank(); + auto stage_id = g_device_manager->stage_id(); + auto stage_num = g_device_manager->stage_num(); + + std::vector forward_rank_list; + forward_rank_list.push_back(rank_id); + if (stage_id < stage_num - 1) { + forward_rank_list.push_back(rank_id + per_stage_rank_num_); + } else { + forward_rank_list.push_back(rank_id + per_stage_rank_num_ * (0 - stage_id)); + } + + Group g; + + if (g_device_manager->CreateGroup(forward_rank_list, &g) != SUCCESS) { + MS_LOG(EXCEPTION) << "Create forward communication group between all pipeline stages failed, the rank_list is: " + << forward_rank_list; + } + + std::vector backward_rank_list; + if (stage_id == 0) { + backward_rank_list.push_back(rank_id + per_stage_rank_num_ * (stage_num - 1)); + } else { + backward_rank_list.push_back(rank_id - per_stage_rank_num_); + } + backward_rank_list.push_back(rank_id); + + Group g_back; + if (g_device_manager->CreateGroup(backward_rank_list, &g_back) != SUCCESS) { + MS_LOG(EXCEPTION) << "Create backward communication group between all pipeline stages failed, the rank_list is: " + << backward_rank_list; + } + + group_.push_back(g.name()); + group_.push_back(g_back.name()); +} +void HandleSegment(const ValuePtr &value, const FuncGraphPtr &graph) { + MS_EXCEPTION_IF_NULL(graph); + auto nodes = graph->nodes(); + for (auto node : nodes) { + if (node->isa()) { + auto cnode = node->cast(); + MS_LOG(INFO) << "Handle Segment cnode: " << cnode->fullname_with_scope(); + cnode->AddPrimalAttr(SEGMENT, value); + } + } +} +void FoldPipelineTransformer::Coloring() { + auto need_coloring = true; + std::set stage_set; + std::set segment_set; + if (!IsTraining(manager_)) { + is_train_ = false; + } + while (need_coloring) { + need_coloring = false; + for (auto &fg : manager_->func_graphs()) { + if (fg == root_ && is_train_) { + continue; + } + auto value_nodes = fg->value_nodes(); + for (auto value_pair = value_nodes.cbegin(); value_pair != value_nodes.cend(); ++value_pair) { + auto node = (*value_pair).first; + if (!IsValueNode(node)) { + continue; + } + auto graph = GetValueNode(node); + if (graph->stage() == -1) { + continue; + } + (void)stage_set.insert(graph->stage()); + (void)segment_set.insert(graph->segment()); + auto node_users = manager_->node_users()[node]; + HandleSegment(MakeValue(graph->segment()), graph); + for (auto &user_pair : node_users) { + auto user_node = user_pair.first->cast(); + user_node->set_user_data(std::make_shared(graph->stage())); + user_node->set_user_data(std::make_shared(graph->segment())); + auto user_node_graph = user_node->func_graph(); + if (graph->stage() == stage_ && user_node_graph->stage() == -1) { + user_node_graph->set_stage(graph->stage()); + MS_LOG(INFO) << "Set_segment in Coloring" << graph->segment(); + user_node_graph->set_segment(graph->segment()); + need_coloring = true; + } + } + } + } + } + MS_EXCEPTION_IF_NULL(g_device_manager); + auto stage_num = g_device_manager->stage_num(); + auto segment_num = ParallelContext::GetInstance()->pipeline_segment_split_num(); + if (SizeToLong(stage_set.size()) != stage_num) { + MS_LOG(EXCEPTION) << "Stage num is " << stage_num << " is not equal to stage used: " << stage_set.size(); + } + if (SizeToLong(segment_set.size()) != segment_num) { + MS_LOG(EXCEPTION) << "Segment num is " << segment_num << " is not equal to segment used: " << segment_set.size(); + } +} + +void FoldPipelineTransformer::ColorForNodes() { + for (auto &fg : manager_->func_graphs()) { + auto stage = fg->stage(); + auto segment = fg->segment(); + if (stage < 0) { + continue; + } + if (segment < 0) { + continue; + } + if (fg == root_ || fg == main_graph_ || fg == shared_cell_) { + continue; + } + auto all_nodes = fg->nodes(); + for (auto node : all_nodes) { + if (node->user_data() != nullptr) { + continue; + } + node->set_user_data(std::make_shared(stage)); + if (node->user_data() != nullptr) { + continue; + } + node->set_user_data(std::make_shared(segment)); + } + } +} + +void FoldPipelineTransformer::BroadCastColoring() { + auto need_coloring = true; + while (need_coloring) { + need_coloring = false; + auto all_nodes = main_graph_->nodes(); + auto node_users = manager_->node_users(); + for (auto node = all_nodes.cbegin(); node != all_nodes.cend(); ++node) { + auto stage_info = (*node)->user_data(); + auto segment_info = (*node)->user_data(); + if (!(*node)->isa() || stage_info == nullptr || stage_info->stage() == -1 || + IsPrimitiveCNode(*node, prim::kPrimUpdateState)) { + continue; + } + auto stage = stage_info->stage(); + auto segment = segment_info->segment(); + for (auto &user_pair : node_users[*node]) { + auto user_node = user_pair.first->cast(); + auto user_stage_info = user_node->user_data(); + auto user_segment_info = user_node->user_data(); + if (user_stage_info == nullptr) { + user_node->set_user_data(std::make_shared(stage)); + user_node->set_user_data(std::make_shared(segment)); + need_coloring = true; + continue; + } + auto user_node_stage = user_stage_info->stage(); + auto user_node_segment = user_segment_info->segment(); + if (stage > user_node_stage && segment == user_node_segment) { + if (IsValueNode(user_node->input(0))) { + MS_LOG(WARNING) << "The stage setting is incorrect. PreNode's stage: " << stage + << " is larger than NextNode's stage:" << user_node_stage; + } + user_node->set_user_data(std::make_shared(stage)); + need_coloring = true; + } + if (segment > user_node_segment) { + user_node->set_user_data(std::make_shared(segment)); + need_coloring = true; + } + } + } + } + ColorForNodes(); +} + +SendAttr FoldPipelineTransformer::InsertSend(const AnfNodePtr ¶meter, int64_t user_node_stage, int64_t node_stage, + const ValuePtr &value, int64_t segment) { + auto dest_rank = global_rank_ + (user_node_stage - node_stage) * per_stage_rank_num_; + int64_t send_tag; + auto stage_num = g_device_manager->stage_num(); + if (node_stage == 0 && user_node_stage > 1 && stage_num > 2) { + if (fold_recv_tag_map.find(dest_rank) != fold_recv_tag_map.end()) { + send_tag = fold_recv_tag_map[dest_rank] + 1; + fold_recv_tag_map[dest_rank] += 1; + } else { + send_tag = 0; + fold_recv_tag_map[dest_rank] = 0; + } + } else { + if (fold_send_tag_map.find(dest_rank) != fold_send_tag_map.end()) { + send_tag = fold_send_tag_map[dest_rank] + 1; + fold_send_tag_map[dest_rank] += 1; + } else { + send_tag = 0; + fold_send_tag_map[dest_rank] = 0; + } + } + Attr attr_tag = std::make_pair(SR_TAG, MakeValue(send_tag)); + Attr attr_rank = std::make_pair(DEST_RANK, MakeValue(user_node_stage)); + Attr attr_group = std::make_pair(GROUP, MakeValue(group_[0])); + Attr attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[1])); + if (stage_num > 2) { + auto next = (user_node_stage == 0) ? 0 : 1; + attr_rank = std::make_pair(DEST_RANK, MakeValue(next)); + attr_group = std::make_pair(GROUP, MakeValue(group_[0])); + attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[0])); + } + + if (node_stage == 0 && user_node_stage > 1 && stage_num > 2) { + attr_group = std::make_pair(GROUP, MakeValue(group_[1])); + attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[1])); + attr_rank = std::make_pair(DEST_RANK, MakeValue(1)); + } + auto graph = enable_share_cell_ ? shared_cell_ : main_graph_; + std::vector send_input = {parameter}; + OperatorAttrs attrs = {attr_tag, attr_rank, attr_group, attr_group_back}; + CNodePtr send = CreateCNodeByInputsAndAttr(graph, SEND, SEND, send_input, attrs); + auto prim = GetCNodePrimitive(send); + AnfNodePtr care_node; + bool is_param = true; + auto op_info_pair = GetOpInfoPair(parameter, parameter, &care_node, &is_param); + auto tensor_info = GetTensorInfo(op_info_pair, is_param); + + auto index = op_info_pair.second; + auto op_info = op_info_pair.first; + auto slice_shape = tensor_info.slice_shape(); + auto shape_type_pair = GetShapeType(parameter, slice_shape, 0); + prim->set_attr(SHAPE, shape_type_pair.first); + prim->set_attr(DTYPE, shape_type_pair.second); + if (!is_param) { + send->AddPrimalAttr(PIPELINE_END, value); + } else { + send->AddPrimalAttr(PIPELINE_PARAM, value); + send->set_user_data(op_info); + send->AddPrimalAttr(PARAM_INDEX, MakeValue(index)); + auto param = care_node ? care_node : parameter; + send->set_user_data(INPUT_PARAM, param); + } + send->AddPrimalAttr(MICRO, value); + send->AddPrimalAttr(SEGMENT, MakeValue(segment)); + MS_LOG(INFO) << "Insert Send op, segment is " << segment; + send->AddPrimalAttr(DEST_RANK, MakeValue(user_node_stage)); + OperatorAttrs depend_attrs; + CNodePtr depend = CreateCNodeByInputsAndAttr(graph, DEPEND, DEPEND, AnfNodePtrList{parameter, send}, depend_attrs); + auto abstract = parameter->abstract(); + if (care_node) { + abstract = care_node->abstract(); + } + depend->set_abstract(abstract); + send->set_abstract(abstract); + SendAttr send_out = {shape_type_pair.first, shape_type_pair.second, depend}; + + send->set_user_data(DEST_RANK, std::make_shared(dest_rank)); + send->set_user_data(USER_NODE_STAGE, std::make_shared(user_node_stage)); + return send_out; +} + +int64_t FoldPipelineTransformer::ComputeRecvTag(int64_t node_stage, int64_t user_node_stage, int64_t stage_num, + int64_t src_rank) { + int64_t recv_tag; + if (node_stage == 0 && user_node_stage > 1 && stage_num > 2) { + if (fold_send_tag_map.find(src_rank) != fold_send_tag_map.end()) { + recv_tag = fold_send_tag_map[src_rank] + 1; + fold_send_tag_map[src_rank] += 1; + } else { + recv_tag = 0; + fold_send_tag_map[src_rank] = 0; + } + } else { + if (fold_recv_tag_map.find(src_rank) != fold_recv_tag_map.end()) { + recv_tag = fold_recv_tag_map[src_rank] + 1; + fold_recv_tag_map[src_rank] += 1; + } else { + recv_tag = 0; + fold_recv_tag_map[src_rank] = 0; + } + } + return recv_tag; +} + +AnfNodePtr FoldPipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, + const AnfNodePtr &use_node, int index, int64_t user_node_stage, + int64_t node_stage, const ValuePtr &value, + const AnfNodePtr &graph_param, int64_t segment) { + auto src_rank = global_rank_ - (user_node_stage - node_stage) * per_stage_rank_num_; + auto stage_num = g_device_manager->stage_num(); + auto recv_tag = ComputeRecvTag(node_stage, user_node_stage, stage_num, src_rank); + Attr attr_tag = std::make_pair(SR_TAG, MakeValue(recv_tag)); + Attr attr_rank = std::make_pair(SRC_RANK, MakeValue(node_stage)); + Attr attr_group = std::make_pair(GROUP, MakeValue(group_[0])); + Attr attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[1])); + + if (stage_num > 2) { + auto next = (user_node_stage == 0) ? 1 : 0; + attr_rank = std::make_pair(SRC_RANK, MakeValue(next)); + attr_group = std::make_pair(GROUP, MakeValue(group_[1])); + attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[1])); + } + bool is_param = true; + AnfNodePtr care_node; + auto op_info_pair = GetOpInfoPair(node, graph_param, &care_node, &is_param); + auto tensor_info = GetTensorInfo(op_info_pair, is_param); + auto tensor_layout = tensor_info.tensor_layout(); + Shape slice_shape = tensor_info.slice_shape(); + auto shape_type_pair = GetShapeType(node, slice_shape, 0); + Attr attr_shape = std::make_pair(SHAPE, shape_type_pair.first); + Attr attr_dtype = std::make_pair(DTYPE, shape_type_pair.second); + if (node_stage == 0 && user_node_stage > 1 && stage_num > 2) { + attr_group = std::make_pair(GROUP, MakeValue(group_[0])); + attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[0])); + attr_rank = std::make_pair(SRC_RANK, MakeValue(0)); + } + OperatorAttrs attrs = {attr_tag, attr_rank, attr_shape, attr_dtype, attr_group, attr_group_back}; + std::vector recv_input; + if (node->isa()) { + recv_input = {node}; + } else { + recv_input = {virtual_param_}; + } + auto recv = CreateCNodeByInputsAndAttr(graph, RECEIVE, RECEIVE, recv_input, attrs); + if (is_param) { + recv->set_user_data(PIPELINE_PARAM, node); + recv->AddPrimalAttr(PIPELINE_PARAM, value); + auto param = care_node ? care_node : node; + recv->set_user_data(INPUT_PARAM, param); + } else { + recv->AddPrimalAttr(PIPELINE_BEGIN, value); + } + recv->AddPrimalAttr(MICRO, value); + recv->AddPrimalAttr(SRC_RANK, MakeValue(node_stage)); + recv->AddPrimalAttr(SEGMENT, MakeValue(segment)); + MS_LOG(INFO) << "Insertreceive segment" << segment; + auto node_abstract = node->abstract(); + if (node->isa()) { + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (IsValueNode(cnode->input(0))) { + auto output = GetValueNode(cnode->input(0))->output(); + MS_EXCEPTION_IF_NULL(output); + node_abstract = output->abstract(); + } + } + MS_EXCEPTION_IF_NULL(node_abstract); + recv->set_abstract(node_abstract); + if (node->isa()) { + BaseShapePtr parallel_shape = std::make_shared(slice_shape); + auto abstract_clone = node->abstract()->Clone(); + MS_EXCEPTION_IF_NULL(abstract_clone); + abstract_clone->set_shape(parallel_shape); + node->set_abstract(abstract_clone); + node->set_user_data(std::make_shared(tensor_layout)); + auto actual_param = RefParameterToActualParameter(node); + if (actual_param) { + actual_param->set_user_data(std::make_shared(tensor_layout)); + auto actual_param_abstract = actual_param->abstract()->Clone(); + actual_param_abstract->set_shape(parallel_shape); + actual_param->set_abstract(actual_param_abstract); + } + } + recv->set_user_data(std::make_shared(tensor_layout)); + recv->set_user_data(op_info_pair.first); + + recv->set_user_data(SRC_RANK, std::make_shared(src_rank)); + recv->set_user_data(NODE_STAGE, std::make_shared(node_stage)); + recv->set_user_data(SLICE_DTYPE, shape_type_pair.second); + recv->set_user_data(SLICE_SHAPE, std::make_shared(slice_shape)); + + manager_->SetEdge(use_node, index, recv); + return recv; +} + +AnfNodePtr FoldPipelineTransformer::Reuse(const AnfNodePtr &node, int64_t stage, int64_t node_segment, + const std::vector &out_input, + const std::vector &out_input_segment, const std::string &tag) { + std::vector> zipped; + std::transform(out_input.begin(), out_input.end(), out_input_segment.begin(), std::back_inserter(zipped), + [](const auto &send, const auto &send_segment) { return std::make_pair(send, send_segment); }); + + for (auto &zipp : zipped) { + auto input = zipp.first; + auto send_segment = zipp.second; + auto cnode = input->cast(); + if (!cnode) { + continue; + } + if (IsPrimitiveCNode(cnode, prim::kPrimDepend)) { + cnode = cnode->input(DEPEND_NODE_SOURCE_INDEX)->cast(); + } + if (cnode->input(1) == node) { + auto prim = GetValueNode(cnode->input(0)); + auto dest_rank_send = GetValue(prim->GetAttr(tag)); + if (dest_rank_send == stage && node_segment == send_segment) { + return input; + } + } + } + return nullptr; +} + +AnfNodePtr FoldPipelineTransformer::HandleParameterGraph(const AnfNodePtr &node, const AnfNodePtr &use_node, + int64_t stage, int64_t user_stage, const ValuePtr µ, + size_t pos, const std::vector &ops) { + CNodePtr call_node = nullptr; + auto argument = GetRealKernelNode(node, -1, &call_node).first; + + auto use_cnode = use_node->cast(); + MS_EXCEPTION_IF_NULL(use_cnode); + if (!IsValueNode(use_cnode->input(0))) { + MS_LOG(EXCEPTION) << "Parameter must be used by a graph, but got: " << use_cnode->DebugString(); + } + auto use_graph = GetValueNode(use_cnode->input(0)); + auto use_parameter_list = use_graph->parameters(); + auto parameter = use_parameter_list.at(pos - 1); + + // insert receive + if (stage_ == user_stage) { + auto recv = PipelineTransformer::Reuse(argument, stage, ops, SRC_RANK); + if (recv) { + manager_->SetEdge(use_node, SizeToInt(pos), recv); + return nullptr; + } + auto root_param = argument; + if (argument->isa() && argument->func_graph() != root_) { + root_param = GetArgumentsByParameter(argument); + } + (void)parameter_color_map_[root_param].insert(user_stage); + auto graph = enable_share_cell_ ? shared_cell_ : main_graph_; + return InsertReceive(graph, argument, use_node, SizeToInt(pos), user_stage, stage, micro, parameter, 0); + } + // insert send + if (PipelineTransformer::Reuse(argument, user_stage, ops, DEST_RANK)) { + return nullptr; + } + auto send_out = InsertSend(argument, user_stage, stage_, micro, 0); + send_out.depend->set_user_data(DTYPE, send_out.type); + send_out.depend->set_user_data(SHAPE, send_out.shape); + return send_out.depend; +} + +bool IsStageConflict(int64_t node_stage, int64_t user_node_stage, int64_t node_segment, int64_t user_node_segment, + int64_t stage_num, bool isEmbed) { + if (isEmbed || (node_stage < user_node_stage && node_segment == user_node_segment) || + (node_stage == stage_num - 1 && user_node_stage == 0 && node_segment < user_node_segment)) { + return true; + } + return false; +} + +void FoldPipelineTransformer::CutBorderForNode(const FuncGraphPtr &graph, const AnfNodePtr &node, + std::vector *send_ops, + std::vector *send_ops_segment, + std::vector *receive_ops) { + auto stage_info = node->user_data(); + auto segment_info = node->user_data(); + auto node_users = manager_->node_users()[node]; + AnfNodePtr receive = nullptr; + for (auto &user_pair : node_users) { + auto user_node = user_pair.first; + auto node_stage = stage_info->stage(); + auto node_segment = segment_info->segment(); + auto user_stage_info = user_node->user_data(); + if (user_stage_info == nullptr) { + continue; + } + auto user_segment_info = user_node->user_data(); + if (user_segment_info == nullptr) { + continue; + } + auto user_node_stage = user_stage_info->stage(); + if (node_stage != stage_ && user_node_stage != stage_) { + continue; + } + auto micro = user_node->cast()->GetPrimalAttr(MICRO); + auto user_node_segment = user_segment_info->segment(); + if (!micro) { + MS_LOG(INFO) << "Can't find micro_batch information, use micro(0)"; + micro = MakeValue(int64_t(0)); + } + auto stage_num = g_device_manager->stage_num(); + + bool isEmbed = node_stage < user_node_stage && node_segment != user_node_segment; + if (IsStageConflict(node_stage, user_node_stage, node_segment, user_node_segment, stage_num, isEmbed)) { + if (node_stage == stage_) { + if (IsParameterGraph(node) && isEmbed) { + auto send_depend = HandleParameterGraph(node, user_node, node_stage, user_node_stage, micro, + IntToSize(user_pair.second), *send_ops); + if (!send_depend) { + continue; + } + (void)send_ops->insert(send_ops->cbegin(), send_depend); + (void)send_ops_segment->insert(send_ops_segment->begin(), node_segment); + continue; + } + if (Reuse(node, user_node_stage, user_node_segment, *send_ops, *send_ops_segment, DEST_RANK)) { + continue; + } + auto send_out = InsertSend(node, user_node_stage, node_stage, micro, node_segment); + MS_EXCEPTION_IF_NULL(send_out.depend); + send_ops->push_back(send_out.depend); + send_ops_segment->push_back(node_segment); + send_out.depend->set_user_data(DTYPE, send_out.type); + send_out.depend->set_user_data(SHAPE, send_out.shape); + } else { + if (!receive) { + if (IsParameterGraph(node)) { + receive = HandleParameterGraph(node, user_node, node_stage, user_node_stage, micro, + IntToSize(user_pair.second), *receive_ops); + if (!receive) { + continue; + } + receive_ops->push_back(receive); + } else { + receive = InsertReceive(graph, node, user_node, user_pair.second, user_node_stage, node_stage, micro, node, + user_node_segment); + receive_ops->push_back(receive); + } + } else { + manager_->SetEdge(user_node, user_pair.second, receive); + } + } + continue; + } + if (node_stage > user_node_stage && node_segment == user_node_segment) { + MS_LOG(EXCEPTION) << "Within a segment, node_stage: " << node_stage + << " must be smaller than user_node_stage: " << user_node_stage; + } + } +} + +std::pair, std::vector> FoldPipelineTransformer::CutBorder( + const FuncGraphPtr &graph) { + std::vector send_ops; + std::vector send_ops_segment; + std::vector receive_ops; + auto ret = graph->get_return(); + MS_EXCEPTION_IF_NULL(ret); + std::vector all_nodes = DeepScopedGraphSearch(ret); + std::reverse(all_nodes.begin(), all_nodes.end()); + auto stage_num = g_device_manager->stage_num(); + if (is_train_ && (stage_num > micro_size_)) { + MS_LOG(EXCEPTION) << "MicroBatch size: " << micro_size_ << " can't less than stage num: " << stage_num; + } + for (auto &node : all_nodes) { + auto stage_info = node->user_data(); + if (!node->isa() || stage_info == nullptr || stage_info->stage() == -1 || + IsPrimitiveCNode(node, prim::kPrimUpdateState)) { + continue; + } + CutBorderForNode(graph, node, &send_ops, &send_ops_segment, &receive_ops); + } + RemoveMonadNode(); + return std::make_pair(send_ops, receive_ops); +} + +std::pair, std::vector> FoldPipelineTransformer::HandleSharedParameter() { + auto parameters = root_->parameters(); + std::vector sends = {}; + std::vector recvs = {}; + for (auto ¶meter : parameters) { + auto parameter_stage = parameter_color_map_[parameter]; + if (parameter_stage.size() <= 1) { + continue; + } + const auto &node_users_map = manager_->node_users(); + auto users = GetParameterLoadUsers(parameter, node_users_map); + for (auto &user : users) { + auto node = user.first; + auto cnode = node->cast(); + auto graph = node->func_graph(); + if (IsValueNode(cnode->input(0))) { + graph = GetValueNode(cnode->input(0)); + } + if (graph == root_ || graph->stage() == -1 || parameter_stage.count(stage_) == 0) { + continue; + } + auto micro = cnode->GetPrimalAttr(MICRO); + if (!micro) { + MS_LOG(INFO) << "Parameter: " << parameter->ToString() << " doesn't have micro batch"; + micro = MakeValue(int64_t(0)); + } + if (stage_ == *parameter_stage.begin()) { + auto user_stage = graph->stage(); + auto stage_info = node->user_data(); + if (stage_info) { + user_stage = stage_info->stage(); + } + if (graph->stage() == stage_ || user_stage == -1) { + continue; + } + if (PipelineTransformer::Reuse(parameter, user_stage, sends, DEST_RANK)) { + continue; + } + auto send_out = InsertSend(parameter, user_stage, stage_, micro, 0); + sends.push_back(send_out.depend); + } else { + auto receive = PipelineTransformer::Reuse(parameter, *parameter_stage.begin(), recvs, SRC_RANK); + if (receive) { + manager_->SetEdge(node, user.second, receive); + } else { + AnfNodePtr recv; + auto fg = enable_share_cell_ ? shared_cell_ : main_graph_; + recv = InsertReceive(fg, parameter, node, user.second, stage_, *parameter_stage.begin(), micro, parameter, 0); + (void)(recvs.push_back(recv)); + } + } + } + } + return std::make_pair(sends, recvs); +} + +void FoldPipelineTransformer::CutGraph() { + CreateForwardGroup2(); + MS_EXCEPTION_IF_NULL(main_graph_); + auto send_recv_shared_param = HandleSharedParameter(); + auto graph = enable_share_cell_ ? shared_cell_ : main_graph_; + MS_EXCEPTION_IF_NULL(graph); + auto send_recv_cut_border = CutBorder(graph); + std::vector send_ops; + (void)(send_ops.insert(send_ops.end(), send_recv_shared_param.first.begin(), send_recv_shared_param.first.end())); + (void)(send_ops.insert(send_ops.end(), send_recv_cut_border.first.begin(), send_recv_cut_border.first.end())); + if (IsLastStage() && !enable_share_cell_) { + auto out_node = main_graph_->output(); + + auto make_tuple = CreateMakeTupleNode(main_graph_, send_ops); + + std::vector tuple_out_depend = {NewValueNode(prim::kPrimDepend)}; + tuple_out_depend.push_back(out_node); + tuple_out_depend.push_back(make_tuple); + + auto tuple_out_depend_node = main_graph_->NewCNode(tuple_out_depend); + tuple_out_depend_node->set_abstract(out_node->abstract()); + (void)manager_->Replace(main_graph_->output(), tuple_out_depend_node); + return; + } + if (send_ops.empty() && !is_train_) { + return; + } + if (!send_ops.empty()) { + type_ptr_ = send_ops.back()->user_data(DTYPE); + shape_ = send_ops.back()->user_data(SHAPE); + } + if (!enable_share_cell_) { + auto make_tuple = CreateMakeTupleNode(main_graph_, send_ops); + auto zero_outputs = GetZeroOutputs(main_graph_); + std::vector out = {NewValueNode(prim::kPrimDepend), zero_outputs, make_tuple}; + auto out_node = main_graph_->NewCNode(out); + (void)manager_->Replace(main_graph_->output(), out_node); + return; + } + fold_send_tag_map.clear(); + fold_recv_tag_map.clear(); + if (!IsLastStage()) { + HandleGraphOutputs(send_ops); + } + std::vector recv_ops; + (void)(recv_ops.insert(recv_ops.end(), send_recv_shared_param.second.begin(), send_recv_shared_param.second.end())); + (void)(recv_ops.insert(recv_ops.end(), send_recv_cut_border.second.begin(), send_recv_cut_border.second.end())); + HandleGraphInputs(recv_ops); +} + +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/fold_pipeline_transformer.h b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/fold_pipeline_transformer.h index 5cb5f2b6e4f..465cba64ab3 100644 --- a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/fold_pipeline_transformer.h +++ b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/fold_pipeline_transformer.h @@ -1,87 +1,87 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_FOLD_PIPELINE_TRANSFORMER_H_ -#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_FOLD_PIPELINE_TRANSFORMER_H_ - -#include -#include -#include -#include -#include -#include "ir/value.h" -#include "ir/graph_utils.h" -#include "base/base.h" -#include "utils/hash_map.h" -#include "frontend/parallel/step_parallel.h" -#include "frontend/parallel/graph_util/generate_graph.h" -#include "frontend/parallel/pipeline_transformer/pipeline_transformer.h" - -namespace mindspore { -namespace parallel { -const int32_t DEPEND_NODE_SOURCE_INDEX = 2; - -class FoldPipelineTransformer : public PipelineTransformer { - public: - FoldPipelineTransformer(const FuncGraphManagerPtr &manager, int stage, const FuncGraphPtr &root, int64_t global_rank, - int64_t per_stage_rank_num) - : PipelineTransformer(manager, stage, root, global_rank, per_stage_rank_num) {} - ~FoldPipelineTransformer() = default; - void Coloring() override; - void BroadCastColoring() override; - void CutGraph() override; - - SendAttr InsertSend(const AnfNodePtr ¶meter, int64_t user_node_stage, int64_t node_stage, const ValuePtr &value, - int64_t segment); - AnfNodePtr InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, int index, - int64_t user_node_stage, int64_t node_stage, const ValuePtr &value, - const AnfNodePtr &graph_param, int64_t segment); - - void CutBorderForNode(const FuncGraphPtr &graph, const AnfNodePtr &node, std::vector *send_ops, - std::vector *send_ops_segment, std::vector *receive_ops); - AnfNodePtr Reuse(const AnfNodePtr &node, int64_t stage, int64_t node_segment, - const std::vector &out_input, const std::vector &out_input_segment, - const std::string &tag); - std::pair, std::vector> CutBorder(const FuncGraphPtr &graph) override; - AnfNodePtr HandleParameterGraph(const AnfNodePtr &node, const AnfNodePtr &use_node, int64_t stage, int64_t user_stage, - const ValuePtr µ, size_t pos, const std::vector &ops) override; - std::pair, std::vector> HandleSharedParameter() override; - - private: - void CreateForwardGroup2(); - int64_t ComputeRecvTag(int64_t node_stage, int64_t user_node_stage, int64_t stage_num, int64_t src_rank); - void ColorForNodes(); - std::vector group_ = {}; -}; - -class NodeSegmentInfo { - public: - explicit NodeSegmentInfo(int64_t segment) : segment_(segment) {} - ~NodeSegmentInfo() = default; - - int64_t segment() const { return segment_; } - - // Key for user data. - constexpr static char key[] = "NodeSegmentInfo"; - - private: - int64_t segment_; -}; - -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_FOLD_PIPELINE_TRANSFORMER_H_ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_FOLD_PIPELINE_TRANSFORMER_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_FOLD_PIPELINE_TRANSFORMER_H_ + +#include +#include +#include +#include +#include +#include "ir/value.h" +#include "ir/graph_utils.h" +#include "base/base.h" +#include "utils/hash_map.h" +#include "frontend/parallel/step_parallel.h" +#include "frontend/parallel/graph_util/generate_graph.h" +#include "frontend/parallel/pipeline_transformer/pipeline_transformer.h" + +namespace mindspore { +namespace parallel { +const int32_t DEPEND_NODE_SOURCE_INDEX = 2; + +class FoldPipelineTransformer : public PipelineTransformer { + public: + FoldPipelineTransformer(const FuncGraphManagerPtr &manager, int stage, const FuncGraphPtr &root, int64_t global_rank, + int64_t per_stage_rank_num) + : PipelineTransformer(manager, stage, root, global_rank, per_stage_rank_num) {} + ~FoldPipelineTransformer() = default; + void Coloring() override; + void BroadCastColoring() override; + void CutGraph() override; + + SendAttr InsertSend(const AnfNodePtr ¶meter, int64_t user_node_stage, int64_t node_stage, const ValuePtr &value, + int64_t segment); + AnfNodePtr InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, int index, + int64_t user_node_stage, int64_t node_stage, const ValuePtr &value, + const AnfNodePtr &graph_param, int64_t segment); + + void CutBorderForNode(const FuncGraphPtr &graph, const AnfNodePtr &node, std::vector *send_ops, + std::vector *send_ops_segment, std::vector *receive_ops); + AnfNodePtr Reuse(const AnfNodePtr &node, int64_t stage, int64_t node_segment, + const std::vector &out_input, const std::vector &out_input_segment, + const std::string &tag); + std::pair, std::vector> CutBorder(const FuncGraphPtr &graph) override; + AnfNodePtr HandleParameterGraph(const AnfNodePtr &node, const AnfNodePtr &use_node, int64_t stage, int64_t user_stage, + const ValuePtr µ, size_t pos, const std::vector &ops) override; + std::pair, std::vector> HandleSharedParameter() override; + + private: + void CreateForwardGroup2(); + int64_t ComputeRecvTag(int64_t node_stage, int64_t user_node_stage, int64_t stage_num, int64_t src_rank); + void ColorForNodes(); + std::vector group_ = {}; +}; + +class NodeSegmentInfo { + public: + explicit NodeSegmentInfo(int64_t segment) : segment_(segment) {} + ~NodeSegmentInfo() = default; + + int64_t segment() const { return segment_; } + + // Key for user data. + constexpr static char key[] = "NodeSegmentInfo"; + + private: + int64_t segment_; +}; + +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_FOLD_PIPELINE_TRANSFORMER_H_ diff --git a/mindspore/ccsrc/include/backend/distributed/ps/scheduler.h b/mindspore/ccsrc/include/backend/distributed/ps/scheduler.h old mode 100755 new mode 100644 index 32ee9d23420..9db01993313 --- a/mindspore/ccsrc/include/backend/distributed/ps/scheduler.h +++ b/mindspore/ccsrc/include/backend/distributed/ps/scheduler.h @@ -1,45 +1,45 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PS_SCHEDULER_H_ -#define MINDSPORE_CCSRC_PS_SCHEDULER_H_ - -#include -#include "include/backend/distributed/ps/util.h" -#include "include/backend/distributed/ps/ps_context.h" -#include "include/backend/visible.h" - -namespace mindspore { -namespace ps { -namespace core { -class SchedulerNode; -} // namespace core -class BACKEND_EXPORT Scheduler { - public: - static Scheduler &GetInstance(); - - void Run(); - - private: - Scheduler(); - ~Scheduler(); - Scheduler(const Scheduler &) = delete; - Scheduler &operator=(const Scheduler &) = delete; - std::unique_ptr scheduler_node_; -}; -} // namespace ps -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PS_SCHEDULER_H_ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PS_SCHEDULER_H_ +#define MINDSPORE_CCSRC_PS_SCHEDULER_H_ + +#include +#include "include/backend/distributed/ps/util.h" +#include "include/backend/distributed/ps/ps_context.h" +#include "include/backend/visible.h" + +namespace mindspore { +namespace ps { +namespace core { +class SchedulerNode; +} // namespace core +class BACKEND_EXPORT Scheduler { + public: + static Scheduler &GetInstance(); + + void Run(); + + private: + Scheduler(); + ~Scheduler(); + Scheduler(const Scheduler &) = delete; + Scheduler &operator=(const Scheduler &) = delete; + std::unique_ptr scheduler_node_; +}; +} // namespace ps +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PS_SCHEDULER_H_ diff --git a/mindspore/ccsrc/include/backend/optimizer/node_pass.h b/mindspore/ccsrc/include/backend/optimizer/node_pass.h index 62747eb45d6..77d8ab05704 100644 --- a/mindspore/ccsrc/include/backend/optimizer/node_pass.h +++ b/mindspore/ccsrc/include/backend/optimizer/node_pass.h @@ -1,56 +1,56 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_NODE_PASS_H_ -#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_NODE_PASS_H_ -#include -#include -#include -#include - -#include "include/backend/optimizer/pass.h" -#include "include/backend/visible.h" - -namespace mindspore { -namespace opt { -// @brief ANF Node level optimization base pass -class BACKEND_EXPORT NodePass : public Pass { - public: - explicit NodePass(const std::string &name) : Pass(name) {} - ~NodePass() override = default; - bool Run(const FuncGraphPtr &func_graph) override; - virtual bool IsFastPass() { return false; } - virtual void AfterProcess(const AnfNodePtr &, const AnfNodePtr &, const FuncGraphPtr &, const FuncGraphIndexPtr &) {} - virtual std::string GetPatternRootPrimitiveName() { return ""; } - virtual AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) = 0; - virtual std::vector MustExistPrimitiveName() const { return {}; } - - protected: - bool is_add_ = true; - - private: - bool ProcessFastPassNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph, - const FuncGraphIndexPtr &func_graph_index, const FuncGraphManagerPtr &manager); - bool ProcessFastPass(const FuncGraphPtr &func_graph, const FuncGraphIndexPtr &func_graph_index); - bool ProcessPass(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager); -}; -void GenIndex(const FuncGraphPtr &func_graph, const FuncGraphIndexPtr &func_graph_index); -void ModifyOutputAndCallerToMap(const CNodePtr &cnode, const FuncGraphPtr &fg, - mindspore::HashMap> *out_caller_map, - bool is_add = true); -std::string GetCNodeKey(const AnfNodePtr &node); -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_NODE_PASS_H_ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_NODE_PASS_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_NODE_PASS_H_ +#include +#include +#include +#include + +#include "include/backend/optimizer/pass.h" +#include "include/backend/visible.h" + +namespace mindspore { +namespace opt { +// @brief ANF Node level optimization base pass +class BACKEND_EXPORT NodePass : public Pass { + public: + explicit NodePass(const std::string &name) : Pass(name) {} + ~NodePass() override = default; + bool Run(const FuncGraphPtr &func_graph) override; + virtual bool IsFastPass() { return false; } + virtual void AfterProcess(const AnfNodePtr &, const AnfNodePtr &, const FuncGraphPtr &, const FuncGraphIndexPtr &) {} + virtual std::string GetPatternRootPrimitiveName() { return ""; } + virtual AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) = 0; + virtual std::vector MustExistPrimitiveName() const { return {}; } + + protected: + bool is_add_ = true; + + private: + bool ProcessFastPassNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph, + const FuncGraphIndexPtr &func_graph_index, const FuncGraphManagerPtr &manager); + bool ProcessFastPass(const FuncGraphPtr &func_graph, const FuncGraphIndexPtr &func_graph_index); + bool ProcessPass(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager); +}; +void GenIndex(const FuncGraphPtr &func_graph, const FuncGraphIndexPtr &func_graph_index); +void ModifyOutputAndCallerToMap(const CNodePtr &cnode, const FuncGraphPtr &fg, + mindspore::HashMap> *out_caller_map, + bool is_add = true); +std::string GetCNodeKey(const AnfNodePtr &node); +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_NODE_PASS_H_ diff --git a/mindspore/ccsrc/include/backend/optimizer/pass_manager.h b/mindspore/ccsrc/include/backend/optimizer/pass_manager.h index 0701e79db74..4c0ddc8a394 100644 --- a/mindspore/ccsrc/include/backend/optimizer/pass_manager.h +++ b/mindspore/ccsrc/include/backend/optimizer/pass_manager.h @@ -1,67 +1,67 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PASS_MANAGER_H_ -#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PASS_MANAGER_H_ - -#include -#include -#include -#include -#include - -#include "include/backend/optimizer/pass.h" -#include "include/backend/optimizer/node_pass.h" -#include "include/backend/visible.h" - -namespace mindspore { -namespace opt { -// @brief For optimization passes management -class BACKEND_EXPORT PassManager { - public: - explicit PassManager(const std::string &name = "pm", bool run_only_once = true); - virtual ~PassManager() = default; - // Get all the passes added by AddPass - const std::vector &Passes() const { return passes_; } - // Add graph pass, the pass object will be freed when pass manager freed. - virtual void AddPass(const PassPtr &pass); - // Run passes added in pass manager on the input graph - // @param [in out] graph The graph to be optimized - // @return true, graph changed - // @return false, graph not changed - virtual bool Run(const FuncGraphPtr &func_graph) const; - // Run the given graph passes on the input graph - // @param [in out] graph The graph to be optimized - // @param [in] passes The given graph passes - // @return true, graph changed - // @return false, graph not changed - virtual bool Run(const FuncGraphPtr &func_graph, const std::vector &passes) const; - std::string name() const { return name_; } - - protected: - virtual bool RunPass(const FuncGraphPtr &func_graph, size_t pass_id, const PassPtr &pass) const; - virtual std::string GetPassFullname(size_t pass_id, const PassPtr &pass) const; - virtual void DumpPassIR(const FuncGraphPtr &func_graph, const std::string &pass_fullname) const; - - const std::string name_; - std::vector passes_; - bool run_only_once_; - CacheManagerPtr cache_manager_; -}; -using PassManagerPtr = std::shared_ptr; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PASS_MANAGER_H_ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PASS_MANAGER_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PASS_MANAGER_H_ + +#include +#include +#include +#include +#include + +#include "include/backend/optimizer/pass.h" +#include "include/backend/optimizer/node_pass.h" +#include "include/backend/visible.h" + +namespace mindspore { +namespace opt { +// @brief For optimization passes management +class BACKEND_EXPORT PassManager { + public: + explicit PassManager(const std::string &name = "pm", bool run_only_once = true); + virtual ~PassManager() = default; + // Get all the passes added by AddPass + const std::vector &Passes() const { return passes_; } + // Add graph pass, the pass object will be freed when pass manager freed. + virtual void AddPass(const PassPtr &pass); + // Run passes added in pass manager on the input graph + // @param [in out] graph The graph to be optimized + // @return true, graph changed + // @return false, graph not changed + virtual bool Run(const FuncGraphPtr &func_graph) const; + // Run the given graph passes on the input graph + // @param [in out] graph The graph to be optimized + // @param [in] passes The given graph passes + // @return true, graph changed + // @return false, graph not changed + virtual bool Run(const FuncGraphPtr &func_graph, const std::vector &passes) const; + std::string name() const { return name_; } + + protected: + virtual bool RunPass(const FuncGraphPtr &func_graph, size_t pass_id, const PassPtr &pass) const; + virtual std::string GetPassFullname(size_t pass_id, const PassPtr &pass) const; + virtual void DumpPassIR(const FuncGraphPtr &func_graph, const std::string &pass_fullname) const; + + const std::string name_; + std::vector passes_; + bool run_only_once_; + CacheManagerPtr cache_manager_; +}; +using PassManagerPtr = std::shared_ptr; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PASS_MANAGER_H_ diff --git a/mindspore/ccsrc/include/common/utils/recompute_helper.h b/mindspore/ccsrc/include/common/utils/recompute_helper.h index cbbe3eacde8..3d0aa3e155d 100644 --- a/mindspore/ccsrc/include/common/utils/recompute_helper.h +++ b/mindspore/ccsrc/include/common/utils/recompute_helper.h @@ -1,83 +1,83 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_RECOMPUTE_HELPER_H -#define MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_RECOMPUTE_HELPER_H - -#include -#include "ir/anf.h" -#include "utils/hash_map.h" -#include "utils/hash_set.h" -#include "ir/func_graph.h" -#include "include/common/visible.h" - -namespace mindspore { -bool CanNotRecomputed(const CNodePtr &node); - -bool IsBpropNode(const AnfNodePtr &node); - -ValuePtr GetRecomputeCNodeAttr(const AnfNodePtr &node); - -bool IsSetNoRecomputeCNodeAttr(const AnfNodePtr &node); - -bool IsSetRecomputeCNodeAttr(const AnfNodePtr &node); - -bool IsCandidateRecomputedNode(const CNodePtr &node); - -bool HasGradInputs(const AnfNodePtr &node, mindspore::HashMap *has_grad_inputs_map); - -bool HasForwardOutput(const FuncGraphManagerPtr &mng, const AnfNodePtr &node); - -void GetTupleGetItemOutputNodes(const FuncGraphManagerPtr &mng, const AnfNodePtr &node, - std::vector *tuple_getitem_output_nodes); - -bool SetRecomputedScope(const CNodePtr &node); - -CNodePtr CreateNewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &origin_node, - const std::vector &new_inputs); - -CNodePtr NewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &origin_node, - const std::vector &first_target_inputs, - const mindspore::HashSet &recomputed_origin_nodes, - mindspore::HashMap *origin_to_recomputed_nodes); - -COMMON_EXPORT void SetRecomputedAttr(const FuncGraphPtr &graph, const std::vector &origin_nodes_topological); - -COMMON_EXPORT bool WithRecomputedScope(const AnfNodePtr &node); - -COMMON_EXPORT std::vector FindCandidateRecomputedNodes(const FuncGraphManagerPtr &mng, - const std::vector &cnodes); - -COMMON_EXPORT void GetMaxSubGraph(const FuncGraphManagerPtr &mng, mindspore::HashSet *recomputed_nodes, - bool get_inputs, bool get_outputs); - -COMMON_EXPORT void GetOriginRecomputeAndTargetNodes(const FuncGraphManagerPtr &mng, - const mindspore::HashSet &max_recomputed_sub_graph, - mindspore::HashSet *recompute_nodes, - mindspore::HashSet *target_nodes); - -COMMON_EXPORT std::vector GetFirstTargetInputs(const std::vector &origin_nodes_topological, - const mindspore::HashSet &max_recomputed_sub_graph, - const mindspore::HashSet &recomputed_origin_nodes, - const mindspore::HashSet &target_nodes); - -COMMON_EXPORT void DuplicateRecomputedNodes(const FuncGraphPtr &graph, const mindspore::HashSet &target_nodes, - const mindspore::HashSet &origin_recomputed_nodes, - const std::vector &first_target_inputs, - mindspore::HashMap *origin_to_new_target_nodes, - mindspore::HashMap *origin_to_recomputed_nodes); -} // namespace mindspore -#endif // MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_RECOMPUTE_HELPER_H +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_RECOMPUTE_HELPER_H +#define MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_RECOMPUTE_HELPER_H + +#include +#include "ir/anf.h" +#include "utils/hash_map.h" +#include "utils/hash_set.h" +#include "ir/func_graph.h" +#include "include/common/visible.h" + +namespace mindspore { +bool CanNotRecomputed(const CNodePtr &node); + +bool IsBpropNode(const AnfNodePtr &node); + +ValuePtr GetRecomputeCNodeAttr(const AnfNodePtr &node); + +bool IsSetNoRecomputeCNodeAttr(const AnfNodePtr &node); + +bool IsSetRecomputeCNodeAttr(const AnfNodePtr &node); + +bool IsCandidateRecomputedNode(const CNodePtr &node); + +bool HasGradInputs(const AnfNodePtr &node, mindspore::HashMap *has_grad_inputs_map); + +bool HasForwardOutput(const FuncGraphManagerPtr &mng, const AnfNodePtr &node); + +void GetTupleGetItemOutputNodes(const FuncGraphManagerPtr &mng, const AnfNodePtr &node, + std::vector *tuple_getitem_output_nodes); + +bool SetRecomputedScope(const CNodePtr &node); + +CNodePtr CreateNewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &origin_node, + const std::vector &new_inputs); + +CNodePtr NewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &origin_node, + const std::vector &first_target_inputs, + const mindspore::HashSet &recomputed_origin_nodes, + mindspore::HashMap *origin_to_recomputed_nodes); + +COMMON_EXPORT void SetRecomputedAttr(const FuncGraphPtr &graph, const std::vector &origin_nodes_topological); + +COMMON_EXPORT bool WithRecomputedScope(const AnfNodePtr &node); + +COMMON_EXPORT std::vector FindCandidateRecomputedNodes(const FuncGraphManagerPtr &mng, + const std::vector &cnodes); + +COMMON_EXPORT void GetMaxSubGraph(const FuncGraphManagerPtr &mng, mindspore::HashSet *recomputed_nodes, + bool get_inputs, bool get_outputs); + +COMMON_EXPORT void GetOriginRecomputeAndTargetNodes(const FuncGraphManagerPtr &mng, + const mindspore::HashSet &max_recomputed_sub_graph, + mindspore::HashSet *recompute_nodes, + mindspore::HashSet *target_nodes); + +COMMON_EXPORT std::vector GetFirstTargetInputs(const std::vector &origin_nodes_topological, + const mindspore::HashSet &max_recomputed_sub_graph, + const mindspore::HashSet &recomputed_origin_nodes, + const mindspore::HashSet &target_nodes); + +COMMON_EXPORT void DuplicateRecomputedNodes(const FuncGraphPtr &graph, const mindspore::HashSet &target_nodes, + const mindspore::HashSet &origin_recomputed_nodes, + const std::vector &first_target_inputs, + mindspore::HashMap *origin_to_new_target_nodes, + mindspore::HashMap *origin_to_recomputed_nodes); +} // namespace mindspore +#endif // MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_RECOMPUTE_HELPER_H diff --git a/mindspore/ccsrc/minddata/dataset/api/audio.cc b/mindspore/ccsrc/minddata/dataset/api/audio.cc index 8422c5eea2f..1b55597706f 100644 --- a/mindspore/ccsrc/minddata/dataset/api/audio.cc +++ b/mindspore/ccsrc/minddata/dataset/api/audio.cc @@ -1,1254 +1,1254 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/include/dataset/audio.h" - -#include "minddata/dataset/audio/ir/kernels/allpass_biquad_ir.h" -#include "minddata/dataset/audio/ir/kernels/amplitude_to_db_ir.h" -#include "minddata/dataset/audio/ir/kernels/angle_ir.h" -#include "minddata/dataset/audio/ir/kernels/band_biquad_ir.h" -#include "minddata/dataset/audio/ir/kernels/bandpass_biquad_ir.h" -#include "minddata/dataset/audio/ir/kernels/bandreject_biquad_ir.h" -#include "minddata/dataset/audio/ir/kernels/bass_biquad_ir.h" -#include "minddata/dataset/audio/ir/kernels/biquad_ir.h" -#include "minddata/dataset/audio/ir/kernels/complex_norm_ir.h" -#include "minddata/dataset/audio/ir/kernels/compute_deltas_ir.h" -#include "minddata/dataset/audio/ir/kernels/contrast_ir.h" -#include "minddata/dataset/audio/ir/kernels/db_to_amplitude_ir.h" -#include "minddata/dataset/audio/ir/kernels/dc_shift_ir.h" -#include "minddata/dataset/audio/ir/kernels/deemph_biquad_ir.h" -#include "minddata/dataset/audio/ir/kernels/detect_pitch_frequency_ir.h" -#include "minddata/dataset/audio/ir/kernels/dither_ir.h" -#include "minddata/dataset/audio/ir/kernels/equalizer_biquad_ir.h" -#include "minddata/dataset/audio/ir/kernels/fade_ir.h" -#include "minddata/dataset/audio/ir/kernels/filtfilt_ir.h" -#include "minddata/dataset/audio/ir/kernels/flanger_ir.h" -#include "minddata/dataset/audio/ir/kernels/frequency_masking_ir.h" -#include "minddata/dataset/audio/ir/kernels/gain_ir.h" -#include "minddata/dataset/audio/ir/kernels/griffin_lim_ir.h" -#include "minddata/dataset/audio/ir/kernels/highpass_biquad_ir.h" -#include "minddata/dataset/audio/ir/kernels/inverse_mel_scale_ir.h" -#include "minddata/dataset/audio/ir/kernels/inverse_spectrogram_ir.h" -#include "minddata/dataset/audio/ir/kernels/lfcc_ir.h" -#include "minddata/dataset/audio/ir/kernels/lfilter_ir.h" -#include "minddata/dataset/audio/ir/kernels/lowpass_biquad_ir.h" -#include "minddata/dataset/audio/ir/kernels/magphase_ir.h" -#include "minddata/dataset/audio/ir/kernels/mask_along_axis_iid_ir.h" -#include "minddata/dataset/audio/ir/kernels/mask_along_axis_ir.h" -#include "minddata/dataset/audio/ir/kernels/mel_scale_ir.h" -#include "minddata/dataset/audio/ir/kernels/mel_spectrogram_ir.h" -#include "minddata/dataset/audio/ir/kernels/mfcc_ir.h" -#include "minddata/dataset/audio/ir/kernels/mu_law_decoding_ir.h" -#include "minddata/dataset/audio/ir/kernels/mu_law_encoding_ir.h" -#include "minddata/dataset/audio/ir/kernels/overdrive_ir.h" -#include "minddata/dataset/audio/ir/kernels/phase_vocoder_ir.h" -#include "minddata/dataset/audio/ir/kernels/phaser_ir.h" -#include "minddata/dataset/audio/ir/kernels/pitch_shift_ir.h" -#include "minddata/dataset/audio/ir/kernels/resample_ir.h" -#include "minddata/dataset/audio/ir/kernels/riaa_biquad_ir.h" -#include "minddata/dataset/audio/ir/kernels/sliding_window_cmn_ir.h" -#include "minddata/dataset/audio/ir/kernels/spectral_centroid_ir.h" -#include "minddata/dataset/audio/ir/kernels/spectrogram_ir.h" -#include "minddata/dataset/audio/ir/kernels/time_masking_ir.h" -#include "minddata/dataset/audio/ir/kernels/time_stretch_ir.h" -#include "minddata/dataset/audio/ir/kernels/treble_biquad_ir.h" -#include "minddata/dataset/audio/ir/kernels/vad_ir.h" -#include "minddata/dataset/audio/ir/kernels/vol_ir.h" -#include "minddata/dataset/audio/ir/validators.h" -#include "minddata/dataset/audio/kernels/audio_utils.h" - -namespace mindspore { -namespace dataset { -namespace audio { -// AllpassBiquad Transform Operation. -struct AllpassBiquad::Data { - Data(int32_t sample_rate, float central_freq, float Q) - : sample_rate_(sample_rate), central_freq_(central_freq), Q_(Q) {} - int32_t sample_rate_; - float central_freq_; - float Q_; -}; - -AllpassBiquad::AllpassBiquad(int32_t sample_rate, float central_freq, float Q) - : data_(std::make_shared(sample_rate, central_freq, Q)) {} - -std::shared_ptr AllpassBiquad::Parse() { - return std::make_shared(data_->sample_rate_, data_->central_freq_, data_->Q_); -} - -// AmplitudeToDB Transform Operation. -struct AmplitudeToDB::Data { - Data(ScaleType stype, float ref_value, float amin, float top_db) - : stype_(stype), ref_value_(ref_value), amin_(amin), top_db_(top_db) {} - ScaleType stype_; - float ref_value_; - float amin_; - float top_db_; -}; - -AmplitudeToDB::AmplitudeToDB(ScaleType stype, float ref_value, float amin, float top_db) - : data_(std::make_shared(stype, ref_value, amin, top_db)) {} - -std::shared_ptr AmplitudeToDB::Parse() { - return std::make_shared(data_->stype_, data_->ref_value_, data_->amin_, data_->top_db_); -} - -// Angle Transform Operation. -Angle::Angle() = default; - -std::shared_ptr Angle::Parse() { return std::make_shared(); } - -// BandBiquad Transform Operation. -struct BandBiquad::Data { - Data(int32_t sample_rate, float central_freq, float Q, bool noise) - : sample_rate_(sample_rate), central_freq_(central_freq), Q_(Q), noise_(noise) {} - int32_t sample_rate_; - float central_freq_; - float Q_; - bool noise_; -}; - -BandBiquad::BandBiquad(int32_t sample_rate, float central_freq, float Q, bool noise) - : data_(std::make_shared(sample_rate, central_freq, Q, noise)) {} - -std::shared_ptr BandBiquad::Parse() { - return std::make_shared(data_->sample_rate_, data_->central_freq_, data_->Q_, data_->noise_); -} - -// BandpassBiquad Transform Operation. -struct BandpassBiquad::Data { - Data(int32_t sample_rate, float central_freq, float Q, bool const_skirt_gain) - : sample_rate_(sample_rate), central_freq_(central_freq), Q_(Q), const_skirt_gain_(const_skirt_gain) {} - int32_t sample_rate_; - float central_freq_; - float Q_; - bool const_skirt_gain_; -}; - -BandpassBiquad::BandpassBiquad(int32_t sample_rate, float central_freq, float Q, bool const_skirt_gain) - : data_(std::make_shared(sample_rate, central_freq, Q, const_skirt_gain)) {} - -std::shared_ptr BandpassBiquad::Parse() { - return std::make_shared(data_->sample_rate_, data_->central_freq_, data_->Q_, - data_->const_skirt_gain_); -} - -// BandrejectBiquad Transform Operation. -struct BandrejectBiquad::Data { - Data(int32_t sample_rate, float central_freq, float Q) - : sample_rate_(sample_rate), central_freq_(central_freq), Q_(Q) {} - int32_t sample_rate_; - float central_freq_; - float Q_; -}; - -BandrejectBiquad::BandrejectBiquad(int32_t sample_rate, float central_freq, float Q) - : data_(std::make_shared(sample_rate, central_freq, Q)) {} - -std::shared_ptr BandrejectBiquad::Parse() { - return std::make_shared(data_->sample_rate_, data_->central_freq_, data_->Q_); -} - -// BassBiquad Transform Operation. -struct BassBiquad::Data { - Data(int32_t sample_rate, float gain, float central_freq, float Q) - : sample_rate_(sample_rate), gain_(gain), central_freq_(central_freq), Q_(Q) {} - int32_t sample_rate_; - float gain_; - float central_freq_; - float Q_; -}; - -BassBiquad::BassBiquad(int32_t sample_rate, float gain, float central_freq, float Q) - : data_(std::make_shared(sample_rate, gain, central_freq, Q)) {} - -std::shared_ptr BassBiquad::Parse() { - return std::make_shared(data_->sample_rate_, data_->gain_, data_->central_freq_, data_->Q_); -} - -// Biquad Transform Operation. -struct Biquad::Data { - Data(float b0, float b1, float b2, float a0, float a1, float a2) - : b0_(b0), b1_(b1), b2_(b2), a0_(a0), a1_(a1), a2_(a2) {} - float b0_; - float b1_; - float b2_; - float a0_; - float a1_; - float a2_; -}; - -Biquad::Biquad(float b0, float b1, float b2, float a0, float a1, float a2) - : data_(std::make_shared(b0, b1, b2, a0, a1, a2)) {} - -std::shared_ptr Biquad::Parse() { - return std::make_shared(data_->b0_, data_->b1_, data_->b2_, data_->a0_, data_->a1_, data_->a1_); -} - -// ComplexNorm Transform Operation. -struct ComplexNorm::Data { - explicit Data(float power) : power_(power) {} - float power_; -}; - -ComplexNorm::ComplexNorm(float power) : data_(std::make_shared(power)) {} - -std::shared_ptr ComplexNorm::Parse() { return std::make_shared(data_->power_); } - -// ComputeDeltas Transform Operation. -struct ComputeDeltas::Data { - Data(int32_t win_length, BorderType pad_mode) : win_length_(win_length), pad_mode_(pad_mode) {} - int32_t win_length_; - BorderType pad_mode_; -}; - -ComputeDeltas::ComputeDeltas(int32_t win_length, BorderType pad_mode) - : data_(std::make_shared(win_length, pad_mode)) {} - -std::shared_ptr ComputeDeltas::Parse() { - return std::make_shared(data_->win_length_, data_->pad_mode_); -} - -// Contrast Transform Operation. -struct Contrast::Data { - explicit Data(float enhancement_amount) : enhancement_amount_(enhancement_amount) {} - float enhancement_amount_; -}; - -Contrast::Contrast(float enhancement_amount) : data_(std::make_shared(enhancement_amount)) {} - -std::shared_ptr Contrast::Parse() { - return std::make_shared(data_->enhancement_amount_); -} - -// DBToAmplitude Transform Operation. -struct DBToAmplitude::Data { - explicit Data(float ref, float power) : ref_(ref), power_(power) {} - float ref_; - float power_; -}; - -DBToAmplitude::DBToAmplitude(float ref, float power) : data_(std::make_shared(ref, power)) {} - -std::shared_ptr DBToAmplitude::Parse() { - return std::make_shared(data_->ref_, data_->power_); -} - -// DCShift Transform Operation. -struct DCShift::Data { - Data(float shift, float limiter_gain) : shift_(shift), limiter_gain_(limiter_gain) {} - float shift_; - float limiter_gain_; -}; - -DCShift::DCShift(float shift) : data_(std::make_shared(shift, shift)) {} - -DCShift::DCShift(float shift, float limiter_gain) : data_(std::make_shared(shift, limiter_gain)) {} - -std::shared_ptr DCShift::Parse() { - return std::make_shared(data_->shift_, data_->limiter_gain_); -} - -Status CreateDct(mindspore::MSTensor *output, int32_t n_mfcc, int32_t n_mels, NormMode norm) { - RETURN_UNEXPECTED_IF_NULL(output); - RETURN_IF_NOT_OK(ValidateIntScalarPositive("CreateDct", "n_mfcc", n_mfcc)); - RETURN_IF_NOT_OK(ValidateIntScalarPositive("CreateDct", "n_mels", n_mels)); - - std::shared_ptr dct; - RETURN_IF_NOT_OK(Dct(&dct, n_mfcc, n_mels, norm)); - CHECK_FAIL_RETURN_UNEXPECTED(dct->HasData(), "CreateDct: get an empty tensor with shape " + dct->shape().ToString()); - *output = mindspore::MSTensor(std::make_shared(dct)); - return Status::OK(); -} - -// DeemphBiquad Transform Operation. -struct DeemphBiquad::Data { - explicit Data(int32_t sample_rate) : sample_rate_(sample_rate) {} - int32_t sample_rate_; -}; - -DeemphBiquad::DeemphBiquad(int32_t sample_rate) : data_(std::make_shared(sample_rate)) {} - -std::shared_ptr DeemphBiquad::Parse() { - return std::make_shared(data_->sample_rate_); -} - -// DetectPitchFrequency Transform Operation. -struct DetectPitchFrequency::Data { - Data(int32_t sample_rate, float frame_time, int32_t win_length, int32_t freq_low, int32_t freq_high) - : sample_rate_(sample_rate), - frame_time_(frame_time), - win_length_(win_length), - freq_low_(freq_low), - freq_high_(freq_high) {} - int32_t sample_rate_; - float frame_time_; - int32_t win_length_; - int32_t freq_low_; - int32_t freq_high_; -}; - -DetectPitchFrequency::DetectPitchFrequency(int32_t sample_rate, float frame_time, int32_t win_length, int32_t freq_low, - int32_t freq_high) - : data_(std::make_shared(sample_rate, frame_time, win_length, freq_low, freq_high)) {} - -std::shared_ptr DetectPitchFrequency::Parse() { - return std::make_shared(data_->sample_rate_, data_->frame_time_, data_->win_length_, - data_->freq_low_, data_->freq_high_); -} - -// Dither Transform Operation. -struct Dither::Data { - Data(DensityFunction density_function, bool noise_shaping) - : density_function_(density_function), noise_shaping_(noise_shaping) {} - DensityFunction density_function_; - bool noise_shaping_; -}; - -Dither::Dither(DensityFunction density_function, bool noise_shaping) - : data_(std::make_shared(density_function, noise_shaping)) {} - -std::shared_ptr Dither::Parse() { - return std::make_shared(data_->density_function_, data_->noise_shaping_); -} - -// EqualizerBiquad Transform Operation. -struct EqualizerBiquad::Data { - Data(int32_t sample_rate, float center_freq, float gain, float Q) - : sample_rate_(sample_rate), center_freq_(center_freq), gain_(gain), Q_(Q) {} - int32_t sample_rate_; - float center_freq_; - float gain_; - float Q_; -}; - -EqualizerBiquad::EqualizerBiquad(int32_t sample_rate, float center_freq, float gain, float Q) - : data_(std::make_shared(sample_rate, center_freq, gain, Q)) {} - -std::shared_ptr EqualizerBiquad::Parse() { - return std::make_shared(data_->sample_rate_, data_->center_freq_, data_->gain_, data_->Q_); -} - -// Fade Transform Operation. -struct Fade::Data { - Data(int32_t fade_in_len, int32_t fade_out_len, FadeShape fade_shape) - : fade_in_len_(fade_in_len), fade_out_len_(fade_out_len), fade_shape_(fade_shape) {} - int32_t fade_in_len_; - int32_t fade_out_len_; - FadeShape fade_shape_; -}; - -Fade::Fade(int32_t fade_in_len, int32_t fade_out_len, FadeShape fade_shape) - : data_(std::make_shared(fade_in_len, fade_out_len, fade_shape)) {} - -std::shared_ptr Fade::Parse() { - return std::make_shared(data_->fade_in_len_, data_->fade_out_len_, data_->fade_shape_); -} - -// Filtfilt Transform Operation. -struct Filtfilt::Data { - Data(const std::vector &a_coeffs, const std::vector &b_coeffs, bool clamp) - : a_coeffs_(a_coeffs), b_coeffs_(b_coeffs), clamp_(clamp) {} - std::vector a_coeffs_; - std::vector b_coeffs_; - bool clamp_; -}; - -Filtfilt::Filtfilt(const std::vector &a_coeffs, const std::vector &b_coeffs, bool clamp) - : data_(std::make_shared(a_coeffs, b_coeffs, clamp)) {} - -std::shared_ptr Filtfilt::Parse() { - return std::make_shared(data_->a_coeffs_, data_->b_coeffs_, data_->clamp_); -} - -// Flanger Transform Operation. -struct Flanger::Data { - Data(int32_t sample_rate, float delay, float depth, float regen, float width, float speed, float phase, - Modulation modulation, Interpolation interpolation) - : sample_rate_(sample_rate), - delay_(delay), - depth_(depth), - regen_(regen), - width_(width), - speed_(speed), - phase_(phase), - modulation_(modulation), - interpolation_(interpolation) {} - int32_t sample_rate_; - float delay_; - float depth_; - float regen_; - float width_; - float speed_; - float phase_; - Modulation modulation_; - Interpolation interpolation_; -}; - -Flanger::Flanger(int32_t sample_rate, float delay, float depth, float regen, float width, float speed, float phase, - Modulation modulation, Interpolation interpolation) - : data_(std::make_shared(sample_rate, delay, depth, regen, width, speed, phase, modulation, interpolation)) {} - -std::shared_ptr Flanger::Parse() { - return std::make_shared(data_->sample_rate_, data_->delay_, data_->depth_, data_->regen_, - data_->width_, data_->speed_, data_->phase_, data_->modulation_, - data_->interpolation_); -} - -// FrequencyMasking Transform Operation. -struct FrequencyMasking::Data { - Data(bool iid_masks, int32_t frequency_mask_param, int32_t mask_start, float mask_value) - : iid_masks_(iid_masks), - frequency_mask_param_(frequency_mask_param), - mask_start_(mask_start), - mask_value_(mask_value) {} - bool iid_masks_; - int32_t frequency_mask_param_; - int32_t mask_start_; - float mask_value_; -}; - -FrequencyMasking::FrequencyMasking(bool iid_masks, int32_t frequency_mask_param, int32_t mask_start, float mask_value) - : data_(std::make_shared(iid_masks, frequency_mask_param, mask_start, mask_value)) {} - -std::shared_ptr FrequencyMasking::Parse() { - return std::make_shared(data_->iid_masks_, data_->frequency_mask_param_, - data_->mask_start_, data_->mask_value_); -} - -// Gain Transform Operation. -struct Gain::Data { - explicit Data(float gain_db) : gain_db_(gain_db) {} - float gain_db_; -}; - -Gain::Gain(float gain_db) : data_(std::make_shared(gain_db)) {} - -std::shared_ptr Gain::Parse() { return std::make_shared(data_->gain_db_); } - -// GriffinLim Transform Operation. -struct GriffinLim::Data { - Data(int32_t n_fft, int32_t n_iter, int32_t win_length, int32_t hop_length, WindowType window_type, float power, - float momentum, int32_t length, bool rand_init) - : n_fft_(n_fft), - n_iter_(n_iter), - win_length_(win_length), - hop_length_(hop_length), - window_type_(window_type), - power_(power), - momentum_(momentum), - length_(length), - rand_init_(rand_init) {} - int32_t n_fft_; - int32_t n_iter_; - int32_t win_length_; - int32_t hop_length_; - WindowType window_type_; - float power_; - float momentum_; - int32_t length_; - bool rand_init_; -}; - -GriffinLim::GriffinLim(int32_t n_fft, int32_t n_iter, int32_t win_length, int32_t hop_length, WindowType window_type, - float power, float momentum, int32_t length, bool rand_init) - : data_(std::make_shared(n_fft, n_iter, win_length, hop_length, window_type, power, momentum, length, - rand_init)) {} - -std::shared_ptr GriffinLim::Parse() { - return std::make_shared(data_->n_fft_, data_->n_iter_, data_->win_length_, data_->hop_length_, - data_->window_type_, data_->power_, data_->momentum_, data_->length_, - data_->rand_init_); -} - -// HighpassBiquad Transform Operation. -struct HighpassBiquad::Data { - Data(int32_t sample_rate, float cutoff_freq, float Q) : sample_rate_(sample_rate), cutoff_freq_(cutoff_freq), Q_(Q) {} - int32_t sample_rate_; - float cutoff_freq_; - float Q_; -}; - -HighpassBiquad::HighpassBiquad(int32_t sample_rate, float cutoff_freq, float Q) - : data_(std::make_shared(sample_rate, cutoff_freq, Q)) {} - -std::shared_ptr HighpassBiquad::Parse() { - return std::make_shared(data_->sample_rate_, data_->cutoff_freq_, data_->Q_); -} - -// InverseMelScale Transform Operation. -struct InverseMelScale::Data { - Data(int32_t n_stft, int32_t n_mels, int32_t sample_rate, float f_min, float f_max, int32_t max_iter, - float tolerance_loss, float tolerance_change, const std::map &sgdargs, NormType norm, - MelType mel_type) - : n_stft_(n_stft), - n_mels_(n_mels), - sample_rate_(sample_rate), - f_min_(f_min), - f_max_(f_max), - max_iter_(max_iter), - tolerance_loss_(tolerance_loss), - tolerance_change_(tolerance_change), - sgdargs_(sgdargs), - norm_(norm), - mel_type_(mel_type) {} - int32_t n_stft_; - int32_t n_mels_; - int32_t sample_rate_; - float f_min_; - float f_max_; - int32_t max_iter_; - float tolerance_loss_; - float tolerance_change_; - std::map sgdargs_; - NormType norm_; - MelType mel_type_; -}; - -InverseMelScale::InverseMelScale(int32_t n_stft, int32_t n_mels, int32_t sample_rate, float f_min, float f_max, - int32_t max_iter, float tolerance_loss, float tolerance_change, - const std::map &sgdargs, NormType norm, MelType mel_type) - : data_(std::make_shared(n_stft, n_mels, sample_rate, f_min, f_max, max_iter, tolerance_loss, - tolerance_change, sgdargs, norm, mel_type)) {} - -std::shared_ptr InverseMelScale::Parse() { - return std::make_shared( - data_->n_stft_, data_->n_mels_, data_->sample_rate_, data_->f_min_, data_->f_max_, data_->max_iter_, - data_->tolerance_loss_, data_->tolerance_change_, data_->sgdargs_, data_->norm_, data_->mel_type_); -} - -// InverseSpectrogram Transform Operation. -struct InverseSpectrogram::Data { - Data(int32_t length, int32_t n_fft, int32_t win_length, int32_t hop_length, int32_t pad, WindowType window, - bool normalized, bool center, BorderType pad_mode, bool onesided) - : length_(length), - n_fft_(n_fft), - win_length_(win_length), - hop_length_(hop_length), - pad_(pad), - window_(window), - normalized_(normalized), - center_(center), - pad_mode_(pad_mode), - onesided_(onesided) {} - int32_t length_; - int32_t n_fft_; - int32_t win_length_; - int32_t hop_length_; - int32_t pad_; - WindowType window_; - bool normalized_; - bool center_; - BorderType pad_mode_; - bool onesided_; -}; - -InverseSpectrogram::InverseSpectrogram(int32_t length, int32_t n_fft, int32_t win_length, int32_t hop_length, - int32_t pad, WindowType window, bool normalized, bool center, - BorderType pad_mode, bool onesided) - : data_(std::make_shared(length, n_fft, win_length, hop_length, pad, window, normalized, center, pad_mode, - onesided)) {} - -std::shared_ptr InverseSpectrogram::Parse() { - return std::make_shared( - data_->length_, data_->n_fft_, data_->win_length_, data_->hop_length_, data_->pad_, data_->window_, - data_->normalized_, data_->center_, data_->pad_mode_, data_->onesided_); -} - -// LFCC Transform Operation. -struct LFCC::Data { - Data(int32_t sample_rate, int32_t n_filter, int32_t n_lfcc, float f_min, float f_max, int32_t dct_type, NormMode norm, - bool log_lf, int32_t n_fft, int32_t win_length, int32_t hop_length, int32_t pad, WindowType window, float power, - bool normalized, bool center, BorderType pad_mode, bool onesided) - : sample_rate_(sample_rate), - n_filter_(n_filter), - n_lfcc_(n_lfcc), - f_min_(f_min), - f_max_(f_max), - dct_type_(dct_type), - norm_(norm), - log_lf_(log_lf), - n_fft_(n_fft), - win_length_(win_length), - hop_length_(hop_length), - pad_(pad), - window_(window), - power_(power), - normalized_(normalized), - center_(center), - pad_mode_(pad_mode), - onesided_(onesided) {} - int32_t sample_rate_; - int32_t n_filter_; - int32_t n_lfcc_; - float f_min_; - float f_max_; - int32_t dct_type_; - NormMode norm_; - bool log_lf_; - int32_t n_fft_; - int32_t win_length_; - int32_t hop_length_; - int32_t pad_; - WindowType window_; - float power_; - bool normalized_; - bool center_; - BorderType pad_mode_; - bool onesided_; -}; - -LFCC::LFCC(int32_t sample_rate, int32_t n_filter, int32_t n_lfcc, float f_min, float f_max, int32_t dct_type, - NormMode norm, bool log_lf, int32_t n_fft, int32_t win_length, int32_t hop_length, int32_t pad, - WindowType window, float power, bool normalized, bool center, BorderType pad_mode, bool onesided) - : data_(std::make_shared(sample_rate, n_filter, n_lfcc, f_min, f_max, dct_type, norm, log_lf, n_fft, - win_length, hop_length, pad, window, power, normalized, center, pad_mode, - onesided)) {} - -std::shared_ptr LFCC::Parse() { - return std::make_shared( - data_->sample_rate_, data_->n_filter_, data_->n_lfcc_, data_->f_min_, data_->f_max_, data_->dct_type_, data_->norm_, - data_->log_lf_, data_->n_fft_, data_->win_length_, data_->hop_length_, data_->pad_, data_->window_, data_->power_, - data_->normalized_, data_->center_, data_->pad_mode_, data_->onesided_); -} - -// LFilter Transform Operation. -struct LFilter::Data { - Data(const std::vector &a_coeffs, const std::vector &b_coeffs, bool clamp) - : a_coeffs_(a_coeffs), b_coeffs_(b_coeffs), clamp_(clamp) {} - std::vector a_coeffs_; - std::vector b_coeffs_; - bool clamp_; -}; - -LFilter::LFilter(const std::vector &a_coeffs, const std::vector &b_coeffs, bool clamp) - : data_(std::make_shared(a_coeffs, b_coeffs, clamp)) {} - -std::shared_ptr LFilter::Parse() { - return std::make_shared(data_->a_coeffs_, data_->b_coeffs_, data_->clamp_); -} - -// LowpassBiquad Transform Operation. -struct LowpassBiquad::Data { - Data(int32_t sample_rate, float cutoff_freq, float Q) : sample_rate_(sample_rate), cutoff_freq_(cutoff_freq), Q_(Q) {} - int32_t sample_rate_; - float cutoff_freq_; - float Q_; -}; - -LowpassBiquad::LowpassBiquad(int32_t sample_rate, float cutoff_freq, float Q) - : data_(std::make_shared(sample_rate, cutoff_freq, Q)) {} - -std::shared_ptr LowpassBiquad::Parse() { - return std::make_shared(data_->sample_rate_, data_->cutoff_freq_, data_->Q_); -} - -// Magphase Transform Operation. -struct Magphase::Data { - explicit Data(float power) : power_(power) {} - float power_; -}; - -Magphase::Magphase(float power) : data_(std::make_shared(power)) {} - -std::shared_ptr Magphase::Parse() { return std::make_shared(data_->power_); } - -// MaskAlongAxis Transform Operation. -struct MaskAlongAxis::Data { - Data(int32_t mask_start, int32_t mask_width, float mask_value, int32_t axis) - : mask_start_(mask_start), mask_width_(mask_width), mask_value_(mask_value), axis_(axis) {} - int32_t mask_start_; - int32_t mask_width_; - float mask_value_; - int32_t axis_; -}; - -MaskAlongAxis::MaskAlongAxis(int32_t mask_start, int32_t mask_width, float mask_value, int32_t axis) - : data_(std::make_shared(mask_start, mask_width, mask_value, axis)) {} - -std::shared_ptr MaskAlongAxis::Parse() { - return std::make_shared(data_->mask_start_, data_->mask_width_, data_->mask_value_, - data_->axis_); -} - -// MaskAlongAxisIID Transform Operation. -struct MaskAlongAxisIID::Data { - Data(int32_t mask_param, float mask_value, int32_t axis) - : mask_param_(mask_param), mask_value_(mask_value), axis_(axis) {} - int32_t mask_param_; - float mask_value_; - int32_t axis_; -}; - -MaskAlongAxisIID::MaskAlongAxisIID(int32_t mask_param, float mask_value, int32_t axis) - : data_(std::make_shared(mask_param, mask_value, axis)) {} - -std::shared_ptr MaskAlongAxisIID::Parse() { - return std::make_shared(data_->mask_param_, data_->mask_value_, data_->axis_); -} - -// MelScale Transform Operation. -struct MelScale::Data { - Data(int32_t n_mels, int32_t sample_rate, float f_min, float f_max, int32_t n_stft, NormType norm, MelType mel_type) - : n_mels_(n_mels), - sample_rate_(sample_rate), - f_min_(f_min), - f_max_(f_max), - n_stft_(n_stft), - norm_(norm), - mel_type_(mel_type) {} - int32_t n_mels_; - int32_t sample_rate_; - float f_min_; - float f_max_; - int32_t n_stft_; - NormType norm_; - MelType mel_type_; -}; - -MelScale::MelScale(int32_t n_mels, int32_t sample_rate, float f_min, float f_max, int32_t n_stft, NormType norm, - MelType mel_type) - : data_(std::make_shared(n_mels, sample_rate, f_min, f_max, n_stft, norm, mel_type)) {} - -std::shared_ptr MelScale::Parse() { - return std::make_shared(data_->n_mels_, data_->sample_rate_, data_->f_min_, data_->f_max_, - data_->n_stft_, data_->norm_, data_->mel_type_); -} - -// MelscaleFbanks Function. -Status MelscaleFbanks(MSTensor *output, int32_t n_freqs, float f_min, float f_max, int32_t n_mels, int32_t sample_rate, - NormType norm, MelType mel_type) { - RETURN_UNEXPECTED_IF_NULL(output); - CHECK_FAIL_RETURN_UNEXPECTED(n_freqs > 0, - "MelscaleFbanks: n_freqs must be greater than 0, got: " + std::to_string(n_freqs)); - - CHECK_FAIL_RETURN_UNEXPECTED(f_min >= 0, "MelscaleFbanks: f_min must be non negative, got: " + std::to_string(f_min)); - CHECK_FAIL_RETURN_UNEXPECTED(f_max > 0, - "MelscaleFbanks: f_max must be greater than 0, got: " + std::to_string(f_max)); - CHECK_FAIL_RETURN_UNEXPECTED(n_mels > 0, - "MelscaleFbanks: n_mels must be greater than 0, got: " + std::to_string(n_mels)); - CHECK_FAIL_RETURN_UNEXPECTED( - sample_rate > 0, "MelscaleFbanks: sample_rate must be greater than 0, got: " + std::to_string(sample_rate)); - CHECK_FAIL_RETURN_UNEXPECTED(f_max > f_min, "MelscaleFbanks: f_max must be greater than f_min, got: f_min = " + - std::to_string(f_min) + ", while f_max = " + std::to_string(f_max)); - std::shared_ptr fb; - RETURN_IF_NOT_OK(CreateFbanks(&fb, n_freqs, f_min, f_max, n_mels, sample_rate, norm, mel_type)); - CHECK_FAIL_RETURN_UNEXPECTED(fb->HasData(), - "MelscaleFbanks: get an empty tensor with shape " + fb->shape().ToString()); - *output = mindspore::MSTensor(std::make_shared(fb)); - return Status::OK(); -} - -// MelSpectrogram Transform Operation. -struct MelSpectrogram::Data { - Data(int32_t sample_rate, int32_t n_fft, int32_t win_length, int32_t hop_length, float f_min, float f_max, - int32_t pad, int32_t n_mels, WindowType window, float power, bool normalized, bool center, BorderType pad_mode, - bool onesided, NormType norm, MelType mel_scale) - : sample_rate_(sample_rate), - n_fft_(n_fft), - win_length_(win_length), - hop_length_(hop_length), - f_min_(f_min), - f_max_(f_max), - pad_(pad), - n_mels_(n_mels), - window_(window), - power_(power), - normalized_(normalized), - center_(center), - pad_mode_(pad_mode), - onesided_(onesided), - norm_(norm), - mel_scale_(mel_scale) {} - int32_t sample_rate_; - int32_t n_fft_; - int32_t win_length_; - int32_t hop_length_; - float f_min_; - float f_max_; - int32_t pad_; - int32_t n_mels_; - WindowType window_; - float power_; - bool normalized_; - bool center_; - BorderType pad_mode_; - bool onesided_; - NormType norm_; - MelType mel_scale_; -}; - -MelSpectrogram::MelSpectrogram(int32_t sample_rate, int32_t n_fft, int32_t win_length, int32_t hop_length, float f_min, - float f_max, int32_t pad, int32_t n_mels, WindowType window, float power, - bool normalized, bool center, BorderType pad_mode, bool onesided, NormType norm, - MelType mel_scale) - : data_(std::make_shared(sample_rate, n_fft, win_length, hop_length, f_min, f_max, pad, n_mels, window, power, - normalized, center, pad_mode, onesided, norm, mel_scale)) {} - -std::shared_ptr MelSpectrogram::Parse() { - return std::make_shared( - data_->sample_rate_, data_->n_fft_, data_->win_length_, data_->hop_length_, data_->f_min_, data_->f_max_, - data_->pad_, data_->n_mels_, data_->window_, data_->power_, data_->normalized_, data_->center_, data_->pad_mode_, - data_->onesided_, data_->norm_, data_->mel_scale_); -} - -// MFCC Transform Operation. -struct MFCC::Data { - Data(int32_t sample_rate, int32_t n_mfcc, int32_t dct_type, NormMode norm, bool log_mels, int32_t n_fft, - int32_t win_length, int32_t hop_length, float f_min, float f_max, int32_t pad, int32_t n_mels, WindowType window, - float power, bool normalized, bool center, BorderType pad_mode, bool onesided, NormType norm_mel, - MelType mel_scale) - : sample_rate_(sample_rate), - n_mfcc_(n_mfcc), - dct_type_(dct_type), - norm_(norm), - log_mels_(log_mels), - n_fft_(n_fft), - win_length_(win_length), - hop_length_(hop_length), - f_min_(f_min), - f_max_(f_max), - pad_(pad), - n_mels_(n_mels), - window_(window), - power_(power), - normalized_(normalized), - center_(center), - pad_mode_(pad_mode), - onesided_(onesided), - norm_mel_(norm_mel), - mel_scale_(mel_scale) {} - int32_t sample_rate_; - int32_t n_mfcc_; - int32_t dct_type_; - NormMode norm_; - bool log_mels_; - int32_t n_fft_; - int32_t win_length_; - int32_t hop_length_; - float f_min_; - float f_max_; - int32_t pad_; - int32_t n_mels_; - WindowType window_; - float power_; - bool normalized_; - bool center_; - BorderType pad_mode_; - bool onesided_; - NormType norm_mel_; - MelType mel_scale_; - std::map melkwargs_; -}; - -MFCC::MFCC(int32_t sample_rate, int32_t n_mfcc, int32_t dct_type, NormMode norm, bool log_mels, int32_t n_fft, - int32_t win_length, int32_t hop_length, float f_min, float f_max, int32_t pad, int32_t n_mels, - WindowType window, float power, bool normalized, bool center, BorderType pad_mode, bool onesided, - NormType norm_mel, MelType mel_scale) - : data_(std::make_shared(sample_rate, n_mfcc, dct_type, norm, log_mels, n_fft, win_length, hop_length, f_min, - f_max, pad, n_mels, window, power, normalized, center, pad_mode, onesided, norm_mel, - mel_scale)) {} - -std::shared_ptr MFCC::Parse() { - return std::make_shared(data_->sample_rate_, data_->n_mfcc_, data_->dct_type_, data_->norm_, - data_->log_mels_, data_->n_fft_, data_->win_length_, data_->hop_length_, - data_->f_min_, data_->f_max_, data_->pad_, data_->n_mels_, data_->window_, - data_->power_, data_->normalized_, data_->center_, data_->pad_mode_, - data_->onesided_, data_->norm_mel_, data_->mel_scale_); -} - -// MuLawDecoding Transform Operation. -struct MuLawDecoding::Data { - explicit Data(int32_t quantization_channels) : quantization_channels_(quantization_channels) {} - int32_t quantization_channels_; -}; - -MuLawDecoding::MuLawDecoding(int32_t quantization_channels) : data_(std::make_shared(quantization_channels)) {} - -std::shared_ptr MuLawDecoding::Parse() { - return std::make_shared(data_->quantization_channels_); -} - -// MuLawEncoding Transform Operation. -struct MuLawEncoding::Data { - explicit Data(int32_t quantization_channels) : quantization_channels_(quantization_channels) {} - int32_t quantization_channels_; -}; - -MuLawEncoding::MuLawEncoding(int32_t quantization_channels) : data_(std::make_shared(quantization_channels)) {} - -std::shared_ptr MuLawEncoding::Parse() { - return std::make_shared(data_->quantization_channels_); -} - -// LinearFbanks Function. -Status LinearFbanks(MSTensor *output, int32_t n_freqs, float f_min, float f_max, int32_t n_filter, - int32_t sample_rate) { - RETURN_UNEXPECTED_IF_NULL(output); - CHECK_FAIL_RETURN_UNEXPECTED(n_freqs > 0, - "LinearFbanks: n_freqs must be greater than 0, got: " + std::to_string(n_freqs)); - - CHECK_FAIL_RETURN_UNEXPECTED(f_min >= 0, "LinearFbanks: f_min must be non negative, got: " + std::to_string(f_min)); - CHECK_FAIL_RETURN_UNEXPECTED(f_max > 0, "LinearFbanks: f_max must be greater than 0, got: " + std::to_string(f_max)); - CHECK_FAIL_RETURN_UNEXPECTED(n_filter > 0, - "LinearFbanks: n_filter must be greater than 0, got: " + std::to_string(n_filter)); - CHECK_FAIL_RETURN_UNEXPECTED(sample_rate > 0, - "LinearFbanks: sample_rate must be greater than 0, got: " + std::to_string(sample_rate)); - CHECK_FAIL_RETURN_UNEXPECTED(f_max > f_min, "LinearFbanks: f_max must be greater than f_min, got: f_min = " + - std::to_string(f_min) + ", while f_max = " + std::to_string(f_max)); - std::shared_ptr fb; - RETURN_IF_NOT_OK(CreateLinearFbanks(&fb, n_freqs, f_min, f_max, n_filter, sample_rate)); - CHECK_FAIL_RETURN_UNEXPECTED(fb->HasData(), "LinearFbanks: get an empty tensor with shape " + fb->shape().ToString()); - *output = mindspore::MSTensor(std::make_shared(fb)); - return Status::OK(); -} - -// Overdrive Transform Operation. -struct Overdrive::Data { - Data(float gain, float color) : gain_(gain), color_(color) {} - float gain_; - float color_; -}; - -Overdrive::Overdrive(float gain, float color) : data_(std::make_shared(gain, color)) {} - -std::shared_ptr Overdrive::Parse() { - return std::make_shared(data_->gain_, data_->color_); -} - -// Phaser Transform Operation. -struct Phaser::Data { - Data(int32_t sample_rate, float gain_in, float gain_out, float delay_ms, float decay, float mod_speed, - bool sinusoidal) - : sample_rate_(sample_rate), - gain_in_(gain_in), - gain_out_(gain_out), - delay_ms_(delay_ms), - decay_(decay), - mod_speed_(mod_speed), - sinusoidal_(sinusoidal) {} - int32_t sample_rate_; - float gain_in_; - float gain_out_; - float delay_ms_; - float decay_; - float mod_speed_; - bool sinusoidal_; -}; - -Phaser::Phaser(int32_t sample_rate, float gain_in, float gain_out, float delay_ms, float decay, float mod_speed, - bool sinusoidal) - : data_(std::make_shared(sample_rate, gain_in, gain_out, delay_ms, decay, mod_speed, sinusoidal)) {} - -std::shared_ptr Phaser::Parse() { - return std::make_shared(data_->sample_rate_, data_->gain_in_, data_->gain_out_, data_->delay_ms_, - data_->decay_, data_->mod_speed_, data_->sinusoidal_); -} - -// PhaseVocoder Transofrm Operation. -struct PhaseVocoder::Data { - Data(float rate, const MSTensor &phase_advance) : rate_(rate), phase_advance_(phase_advance) {} - float rate_; - MSTensor phase_advance_; -}; - -PhaseVocoder::PhaseVocoder(float rate, const MSTensor &phase_advance) - : data_(std::make_shared(rate, phase_advance)) {} - -std::shared_ptr PhaseVocoder::Parse() { - std::shared_ptr phase_advance; - Status rc = Tensor::CreateFromMSTensor(data_->phase_advance_, &phase_advance); - if (rc.IsError()) { - MS_LOG(ERROR) << "Error creating phase_vocoder constant tensor." << rc; - return nullptr; - } - return std::make_shared(data_->rate_, phase_advance); -} - -// pitchshift -struct PitchShift::Data { - Data(int32_t sample_rate, int32_t n_steps, int32_t bins_per_octave, int32_t n_fft, int32_t win_length, - int32_t hop_length, WindowType window) - : sample_rate_(sample_rate), - n_steps_(n_steps), - bins_per_octave_(bins_per_octave), - n_fft_(n_fft), - win_length_(win_length), - hop_length_(hop_length), - window_(window) {} - int32_t sample_rate_; - int32_t n_steps_; - int32_t bins_per_octave_; - int32_t n_fft_; - int32_t win_length_; - int32_t hop_length_; - WindowType window_; -}; - -PitchShift::PitchShift(int32_t sample_rate, int32_t n_steps, int32_t bins_per_octave, int32_t n_fft, int32_t win_length, - int32_t hop_length, WindowType window) - : data_(std::make_shared(sample_rate, n_steps, bins_per_octave, n_fft, win_length, hop_length, window)) {} - -std::shared_ptr PitchShift::Parse() { - return std::make_shared(data_->sample_rate_, data_->n_steps_, data_->bins_per_octave_, - data_->n_fft_, data_->win_length_, data_->hop_length_, data_->window_); -} - -// Resample Transform Operation. -struct Resample::Data { - Data(float orig_freq, float new_freq, ResampleMethod resample_method, int32_t lowpass_filter_width, float rolloff, - float beta) - : orig_freq_(orig_freq), - new_freq_(new_freq), - resample_method_(resample_method), - lowpass_filter_width_(lowpass_filter_width), - rolloff_(rolloff), - beta_(beta) {} - float orig_freq_; - float new_freq_; - ResampleMethod resample_method_; - int32_t lowpass_filter_width_; - float rolloff_; - float beta_; -}; - -Resample::Resample(float orig_freq, float new_freq, ResampleMethod resample_method, int32_t lowpass_filter_width, - float rolloff, float beta) - : data_(std::make_shared(orig_freq, new_freq, resample_method, lowpass_filter_width, rolloff, beta)) {} - -std::shared_ptr Resample::Parse() { - return std::make_shared(data_->orig_freq_, data_->new_freq_, data_->resample_method_, - data_->lowpass_filter_width_, data_->rolloff_, data_->beta_); -} - -// RiaaBiquad Transform Operation. -struct RiaaBiquad::Data { - explicit Data(int32_t sample_rate) : sample_rate_(sample_rate) {} - int32_t sample_rate_; -}; - -RiaaBiquad::RiaaBiquad(int32_t sample_rate) : data_(std::make_shared(sample_rate)) {} - -std::shared_ptr RiaaBiquad::Parse() { - return std::make_shared(data_->sample_rate_); -} - -// SlidingWindowCmn Transform Operation. -struct SlidingWindowCmn::Data { - Data(int32_t cmn_window, int32_t min_cmn_window, bool center, bool norm_vars) - : cmn_window_(cmn_window), min_cmn_window_(min_cmn_window), center_(center), norm_vars_(norm_vars) {} - int32_t cmn_window_; - int32_t min_cmn_window_; - bool center_; - bool norm_vars_; -}; - -SlidingWindowCmn::SlidingWindowCmn(int32_t cmn_window, int32_t min_cmn_window, bool center, bool norm_vars) - : data_(std::make_shared(cmn_window, min_cmn_window, center, norm_vars)) {} - -std::shared_ptr SlidingWindowCmn::Parse() { - return std::make_shared(data_->cmn_window_, data_->min_cmn_window_, data_->center_, - data_->norm_vars_); -} - -// Spectrogram Transform Operation. -struct Spectrogram::Data { - Data(int32_t n_fft, int32_t win_length, int32_t hop_length, int32_t pad, WindowType window, float power, - bool normalized, bool center, BorderType pad_mode, bool onesided) - : n_fft_(n_fft), - win_length_(win_length), - hop_length_(hop_length), - pad_(pad), - window_(window), - power_(power), - normalized_(normalized), - center_(center), - pad_mode_(pad_mode), - onesided_(onesided) {} - int32_t n_fft_; - int32_t win_length_; - int32_t hop_length_; - int32_t pad_; - WindowType window_; - float power_; - bool normalized_; - bool center_; - BorderType pad_mode_; - bool onesided_; -}; - -Spectrogram::Spectrogram(int32_t n_fft, int32_t win_length, int32_t hop_length, int32_t pad, WindowType window, - float power, bool normalized, bool center, BorderType pad_mode, bool onesided) - : data_(std::make_shared(n_fft, win_length, hop_length, pad, window, power, normalized, center, pad_mode, - onesided)) {} - -std::shared_ptr Spectrogram::Parse() { - return std::make_shared(data_->n_fft_, data_->win_length_, data_->hop_length_, data_->pad_, - data_->window_, data_->power_, data_->normalized_, data_->center_, - data_->pad_mode_, data_->onesided_); -} - -// SpectralCentroid Transform Operation. -struct SpectralCentroid::Data { - Data(int32_t sample_rate, int32_t n_fft, int32_t win_length, int32_t hop_length, int32_t pad, WindowType window) - : sample_rate_(sample_rate), - n_fft_(n_fft), - win_length_(win_length), - hop_length_(hop_length), - pad_(pad), - window_(window) {} - int32_t sample_rate_; - int32_t n_fft_; - int32_t win_length_; - int32_t hop_length_; - int32_t pad_; - WindowType window_; -}; - -SpectralCentroid::SpectralCentroid(int32_t sample_rate, int32_t n_fft, int32_t win_length, int32_t hop_length, - int32_t pad, WindowType window) - : data_(std::make_shared(sample_rate, n_fft, win_length, hop_length, pad, window)) {} - -std::shared_ptr SpectralCentroid::Parse() { - return std::make_shared(data_->sample_rate_, data_->n_fft_, data_->win_length_, - data_->hop_length_, data_->pad_, data_->window_); -} - -// TimeMasking Transform Operation. -struct TimeMasking::Data { - Data(bool iid_masks, int32_t time_mask_param, int32_t mask_start, float mask_value) - : iid_masks_(iid_masks), time_mask_param_(time_mask_param), mask_start_(mask_start), mask_value_(mask_value) {} - bool iid_masks_; - int32_t time_mask_param_; - int32_t mask_start_; - float mask_value_; -}; - -TimeMasking::TimeMasking(bool iid_masks, int32_t time_mask_param, int32_t mask_start, float mask_value) - : data_(std::make_shared(iid_masks, time_mask_param, mask_start, mask_value)) {} - -std::shared_ptr TimeMasking::Parse() { - return std::make_shared(data_->iid_masks_, data_->time_mask_param_, data_->mask_start_, - data_->mask_value_); -} - -// TimeStretch Transform Operation. -struct TimeStretch::Data { - explicit Data(float hop_length, int32_t n_freq, float fixed_rate) - : hop_length_(hop_length), n_freq_(n_freq), fixed_rate_(fixed_rate) {} - float hop_length_; - int32_t n_freq_; - float fixed_rate_; -}; - -TimeStretch::TimeStretch(float hop_length, int32_t n_freq, float fixed_rate) - : data_(std::make_shared(hop_length, n_freq, fixed_rate)) {} - -std::shared_ptr TimeStretch::Parse() { - return std::make_shared(data_->hop_length_, data_->n_freq_, data_->fixed_rate_); -} - -// TrebleBiquad Transform Operation. -struct TrebleBiquad::Data { - Data(int32_t sample_rate, float gain, float central_freq, float Q) - : sample_rate_(sample_rate), gain_(gain), central_freq_(central_freq), Q_(Q) {} - int32_t sample_rate_; - float gain_; - float central_freq_; - float Q_; -}; - -TrebleBiquad::TrebleBiquad(int32_t sample_rate, float gain, float central_freq, float Q) - : data_(std::make_shared(sample_rate, gain, central_freq, Q)) {} - -std::shared_ptr TrebleBiquad::Parse() { - return std::make_shared(data_->sample_rate_, data_->gain_, data_->central_freq_, data_->Q_); -} - -// Vad Transform Operation. -struct Vad::Data { - Data(int32_t sample_rate, float trigger_level, float trigger_time, float search_time, float allowed_gap, - float pre_trigger_time, float boot_time, float noise_up_time, float noise_down_time, - float noise_reduction_amount, float measure_freq, float measure_duration, float measure_smooth_time, - float hp_filter_freq, float lp_filter_freq, float hp_lifter_freq, float lp_lifter_freq) - : sample_rate_(sample_rate), - trigger_level_(trigger_level), - trigger_time_(trigger_time), - search_time_(search_time), - allowed_gap_(allowed_gap), - pre_trigger_time_(pre_trigger_time), - boot_time_(boot_time), - noise_up_time_(noise_up_time), - noise_down_time_(noise_down_time), - noise_reduction_amount_(noise_reduction_amount), - measure_freq_(measure_freq), - measure_duration_(measure_duration), - measure_smooth_time_(measure_smooth_time), - hp_filter_freq_(hp_filter_freq), - lp_filter_freq_(lp_filter_freq), - hp_lifter_freq_(hp_lifter_freq), - lp_lifter_freq_(lp_lifter_freq) {} - int32_t sample_rate_; - float trigger_level_; - float trigger_time_; - float search_time_; - float allowed_gap_; - float pre_trigger_time_; - float boot_time_; - float noise_up_time_; - float noise_down_time_; - float noise_reduction_amount_; - float measure_freq_; - float measure_duration_; - float measure_smooth_time_; - float hp_filter_freq_; - float lp_filter_freq_; - float hp_lifter_freq_; - float lp_lifter_freq_; -}; - -Vad::Vad(int32_t sample_rate, float trigger_level, float trigger_time, float search_time, float allowed_gap, - float pre_trigger_time, float boot_time, float noise_up_time, float noise_down_time, - float noise_reduction_amount, float measure_freq, float measure_duration, float measure_smooth_time, - float hp_filter_freq, float lp_filter_freq, float hp_lifter_freq, float lp_lifter_freq) - : data_(std::make_shared(sample_rate, trigger_level, trigger_time, search_time, allowed_gap, pre_trigger_time, - boot_time, noise_up_time, noise_down_time, noise_reduction_amount, measure_freq, - measure_duration, measure_smooth_time, hp_filter_freq, lp_filter_freq, - hp_lifter_freq, lp_lifter_freq)) {} - -std::shared_ptr Vad::Parse() { - return std::make_shared( - data_->sample_rate_, data_->trigger_level_, data_->trigger_time_, data_->search_time_, data_->allowed_gap_, - data_->pre_trigger_time_, data_->boot_time_, data_->noise_up_time_, data_->noise_down_time_, - data_->noise_reduction_amount_, data_->measure_freq_, data_->measure_duration_, data_->measure_smooth_time_, - data_->hp_filter_freq_, data_->lp_filter_freq_, data_->hp_lifter_freq_, data_->lp_lifter_freq_); -} - -// Vol Transform Operation. -struct Vol::Data { - Data(float gain, GainType gain_type) : gain_(gain), gain_type_(gain_type) {} - float gain_; - GainType gain_type_; -}; - -Vol::Vol(float gain, GainType gain_type) : data_(std::make_shared(gain, gain_type)) {} - -std::shared_ptr Vol::Parse() { - return std::make_shared(data_->gain_, data_->gain_type_); -} -} // namespace audio -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/include/dataset/audio.h" + +#include "minddata/dataset/audio/ir/kernels/allpass_biquad_ir.h" +#include "minddata/dataset/audio/ir/kernels/amplitude_to_db_ir.h" +#include "minddata/dataset/audio/ir/kernels/angle_ir.h" +#include "minddata/dataset/audio/ir/kernels/band_biquad_ir.h" +#include "minddata/dataset/audio/ir/kernels/bandpass_biquad_ir.h" +#include "minddata/dataset/audio/ir/kernels/bandreject_biquad_ir.h" +#include "minddata/dataset/audio/ir/kernels/bass_biquad_ir.h" +#include "minddata/dataset/audio/ir/kernels/biquad_ir.h" +#include "minddata/dataset/audio/ir/kernels/complex_norm_ir.h" +#include "minddata/dataset/audio/ir/kernels/compute_deltas_ir.h" +#include "minddata/dataset/audio/ir/kernels/contrast_ir.h" +#include "minddata/dataset/audio/ir/kernels/db_to_amplitude_ir.h" +#include "minddata/dataset/audio/ir/kernels/dc_shift_ir.h" +#include "minddata/dataset/audio/ir/kernels/deemph_biquad_ir.h" +#include "minddata/dataset/audio/ir/kernels/detect_pitch_frequency_ir.h" +#include "minddata/dataset/audio/ir/kernels/dither_ir.h" +#include "minddata/dataset/audio/ir/kernels/equalizer_biquad_ir.h" +#include "minddata/dataset/audio/ir/kernels/fade_ir.h" +#include "minddata/dataset/audio/ir/kernels/filtfilt_ir.h" +#include "minddata/dataset/audio/ir/kernels/flanger_ir.h" +#include "minddata/dataset/audio/ir/kernels/frequency_masking_ir.h" +#include "minddata/dataset/audio/ir/kernels/gain_ir.h" +#include "minddata/dataset/audio/ir/kernels/griffin_lim_ir.h" +#include "minddata/dataset/audio/ir/kernels/highpass_biquad_ir.h" +#include "minddata/dataset/audio/ir/kernels/inverse_mel_scale_ir.h" +#include "minddata/dataset/audio/ir/kernels/inverse_spectrogram_ir.h" +#include "minddata/dataset/audio/ir/kernels/lfcc_ir.h" +#include "minddata/dataset/audio/ir/kernels/lfilter_ir.h" +#include "minddata/dataset/audio/ir/kernels/lowpass_biquad_ir.h" +#include "minddata/dataset/audio/ir/kernels/magphase_ir.h" +#include "minddata/dataset/audio/ir/kernels/mask_along_axis_iid_ir.h" +#include "minddata/dataset/audio/ir/kernels/mask_along_axis_ir.h" +#include "minddata/dataset/audio/ir/kernels/mel_scale_ir.h" +#include "minddata/dataset/audio/ir/kernels/mel_spectrogram_ir.h" +#include "minddata/dataset/audio/ir/kernels/mfcc_ir.h" +#include "minddata/dataset/audio/ir/kernels/mu_law_decoding_ir.h" +#include "minddata/dataset/audio/ir/kernels/mu_law_encoding_ir.h" +#include "minddata/dataset/audio/ir/kernels/overdrive_ir.h" +#include "minddata/dataset/audio/ir/kernels/phase_vocoder_ir.h" +#include "minddata/dataset/audio/ir/kernels/phaser_ir.h" +#include "minddata/dataset/audio/ir/kernels/pitch_shift_ir.h" +#include "minddata/dataset/audio/ir/kernels/resample_ir.h" +#include "minddata/dataset/audio/ir/kernels/riaa_biquad_ir.h" +#include "minddata/dataset/audio/ir/kernels/sliding_window_cmn_ir.h" +#include "minddata/dataset/audio/ir/kernels/spectral_centroid_ir.h" +#include "minddata/dataset/audio/ir/kernels/spectrogram_ir.h" +#include "minddata/dataset/audio/ir/kernels/time_masking_ir.h" +#include "minddata/dataset/audio/ir/kernels/time_stretch_ir.h" +#include "minddata/dataset/audio/ir/kernels/treble_biquad_ir.h" +#include "minddata/dataset/audio/ir/kernels/vad_ir.h" +#include "minddata/dataset/audio/ir/kernels/vol_ir.h" +#include "minddata/dataset/audio/ir/validators.h" +#include "minddata/dataset/audio/kernels/audio_utils.h" + +namespace mindspore { +namespace dataset { +namespace audio { +// AllpassBiquad Transform Operation. +struct AllpassBiquad::Data { + Data(int32_t sample_rate, float central_freq, float Q) + : sample_rate_(sample_rate), central_freq_(central_freq), Q_(Q) {} + int32_t sample_rate_; + float central_freq_; + float Q_; +}; + +AllpassBiquad::AllpassBiquad(int32_t sample_rate, float central_freq, float Q) + : data_(std::make_shared(sample_rate, central_freq, Q)) {} + +std::shared_ptr AllpassBiquad::Parse() { + return std::make_shared(data_->sample_rate_, data_->central_freq_, data_->Q_); +} + +// AmplitudeToDB Transform Operation. +struct AmplitudeToDB::Data { + Data(ScaleType stype, float ref_value, float amin, float top_db) + : stype_(stype), ref_value_(ref_value), amin_(amin), top_db_(top_db) {} + ScaleType stype_; + float ref_value_; + float amin_; + float top_db_; +}; + +AmplitudeToDB::AmplitudeToDB(ScaleType stype, float ref_value, float amin, float top_db) + : data_(std::make_shared(stype, ref_value, amin, top_db)) {} + +std::shared_ptr AmplitudeToDB::Parse() { + return std::make_shared(data_->stype_, data_->ref_value_, data_->amin_, data_->top_db_); +} + +// Angle Transform Operation. +Angle::Angle() = default; + +std::shared_ptr Angle::Parse() { return std::make_shared(); } + +// BandBiquad Transform Operation. +struct BandBiquad::Data { + Data(int32_t sample_rate, float central_freq, float Q, bool noise) + : sample_rate_(sample_rate), central_freq_(central_freq), Q_(Q), noise_(noise) {} + int32_t sample_rate_; + float central_freq_; + float Q_; + bool noise_; +}; + +BandBiquad::BandBiquad(int32_t sample_rate, float central_freq, float Q, bool noise) + : data_(std::make_shared(sample_rate, central_freq, Q, noise)) {} + +std::shared_ptr BandBiquad::Parse() { + return std::make_shared(data_->sample_rate_, data_->central_freq_, data_->Q_, data_->noise_); +} + +// BandpassBiquad Transform Operation. +struct BandpassBiquad::Data { + Data(int32_t sample_rate, float central_freq, float Q, bool const_skirt_gain) + : sample_rate_(sample_rate), central_freq_(central_freq), Q_(Q), const_skirt_gain_(const_skirt_gain) {} + int32_t sample_rate_; + float central_freq_; + float Q_; + bool const_skirt_gain_; +}; + +BandpassBiquad::BandpassBiquad(int32_t sample_rate, float central_freq, float Q, bool const_skirt_gain) + : data_(std::make_shared(sample_rate, central_freq, Q, const_skirt_gain)) {} + +std::shared_ptr BandpassBiquad::Parse() { + return std::make_shared(data_->sample_rate_, data_->central_freq_, data_->Q_, + data_->const_skirt_gain_); +} + +// BandrejectBiquad Transform Operation. +struct BandrejectBiquad::Data { + Data(int32_t sample_rate, float central_freq, float Q) + : sample_rate_(sample_rate), central_freq_(central_freq), Q_(Q) {} + int32_t sample_rate_; + float central_freq_; + float Q_; +}; + +BandrejectBiquad::BandrejectBiquad(int32_t sample_rate, float central_freq, float Q) + : data_(std::make_shared(sample_rate, central_freq, Q)) {} + +std::shared_ptr BandrejectBiquad::Parse() { + return std::make_shared(data_->sample_rate_, data_->central_freq_, data_->Q_); +} + +// BassBiquad Transform Operation. +struct BassBiquad::Data { + Data(int32_t sample_rate, float gain, float central_freq, float Q) + : sample_rate_(sample_rate), gain_(gain), central_freq_(central_freq), Q_(Q) {} + int32_t sample_rate_; + float gain_; + float central_freq_; + float Q_; +}; + +BassBiquad::BassBiquad(int32_t sample_rate, float gain, float central_freq, float Q) + : data_(std::make_shared(sample_rate, gain, central_freq, Q)) {} + +std::shared_ptr BassBiquad::Parse() { + return std::make_shared(data_->sample_rate_, data_->gain_, data_->central_freq_, data_->Q_); +} + +// Biquad Transform Operation. +struct Biquad::Data { + Data(float b0, float b1, float b2, float a0, float a1, float a2) + : b0_(b0), b1_(b1), b2_(b2), a0_(a0), a1_(a1), a2_(a2) {} + float b0_; + float b1_; + float b2_; + float a0_; + float a1_; + float a2_; +}; + +Biquad::Biquad(float b0, float b1, float b2, float a0, float a1, float a2) + : data_(std::make_shared(b0, b1, b2, a0, a1, a2)) {} + +std::shared_ptr Biquad::Parse() { + return std::make_shared(data_->b0_, data_->b1_, data_->b2_, data_->a0_, data_->a1_, data_->a1_); +} + +// ComplexNorm Transform Operation. +struct ComplexNorm::Data { + explicit Data(float power) : power_(power) {} + float power_; +}; + +ComplexNorm::ComplexNorm(float power) : data_(std::make_shared(power)) {} + +std::shared_ptr ComplexNorm::Parse() { return std::make_shared(data_->power_); } + +// ComputeDeltas Transform Operation. +struct ComputeDeltas::Data { + Data(int32_t win_length, BorderType pad_mode) : win_length_(win_length), pad_mode_(pad_mode) {} + int32_t win_length_; + BorderType pad_mode_; +}; + +ComputeDeltas::ComputeDeltas(int32_t win_length, BorderType pad_mode) + : data_(std::make_shared(win_length, pad_mode)) {} + +std::shared_ptr ComputeDeltas::Parse() { + return std::make_shared(data_->win_length_, data_->pad_mode_); +} + +// Contrast Transform Operation. +struct Contrast::Data { + explicit Data(float enhancement_amount) : enhancement_amount_(enhancement_amount) {} + float enhancement_amount_; +}; + +Contrast::Contrast(float enhancement_amount) : data_(std::make_shared(enhancement_amount)) {} + +std::shared_ptr Contrast::Parse() { + return std::make_shared(data_->enhancement_amount_); +} + +// DBToAmplitude Transform Operation. +struct DBToAmplitude::Data { + explicit Data(float ref, float power) : ref_(ref), power_(power) {} + float ref_; + float power_; +}; + +DBToAmplitude::DBToAmplitude(float ref, float power) : data_(std::make_shared(ref, power)) {} + +std::shared_ptr DBToAmplitude::Parse() { + return std::make_shared(data_->ref_, data_->power_); +} + +// DCShift Transform Operation. +struct DCShift::Data { + Data(float shift, float limiter_gain) : shift_(shift), limiter_gain_(limiter_gain) {} + float shift_; + float limiter_gain_; +}; + +DCShift::DCShift(float shift) : data_(std::make_shared(shift, shift)) {} + +DCShift::DCShift(float shift, float limiter_gain) : data_(std::make_shared(shift, limiter_gain)) {} + +std::shared_ptr DCShift::Parse() { + return std::make_shared(data_->shift_, data_->limiter_gain_); +} + +Status CreateDct(mindspore::MSTensor *output, int32_t n_mfcc, int32_t n_mels, NormMode norm) { + RETURN_UNEXPECTED_IF_NULL(output); + RETURN_IF_NOT_OK(ValidateIntScalarPositive("CreateDct", "n_mfcc", n_mfcc)); + RETURN_IF_NOT_OK(ValidateIntScalarPositive("CreateDct", "n_mels", n_mels)); + + std::shared_ptr dct; + RETURN_IF_NOT_OK(Dct(&dct, n_mfcc, n_mels, norm)); + CHECK_FAIL_RETURN_UNEXPECTED(dct->HasData(), "CreateDct: get an empty tensor with shape " + dct->shape().ToString()); + *output = mindspore::MSTensor(std::make_shared(dct)); + return Status::OK(); +} + +// DeemphBiquad Transform Operation. +struct DeemphBiquad::Data { + explicit Data(int32_t sample_rate) : sample_rate_(sample_rate) {} + int32_t sample_rate_; +}; + +DeemphBiquad::DeemphBiquad(int32_t sample_rate) : data_(std::make_shared(sample_rate)) {} + +std::shared_ptr DeemphBiquad::Parse() { + return std::make_shared(data_->sample_rate_); +} + +// DetectPitchFrequency Transform Operation. +struct DetectPitchFrequency::Data { + Data(int32_t sample_rate, float frame_time, int32_t win_length, int32_t freq_low, int32_t freq_high) + : sample_rate_(sample_rate), + frame_time_(frame_time), + win_length_(win_length), + freq_low_(freq_low), + freq_high_(freq_high) {} + int32_t sample_rate_; + float frame_time_; + int32_t win_length_; + int32_t freq_low_; + int32_t freq_high_; +}; + +DetectPitchFrequency::DetectPitchFrequency(int32_t sample_rate, float frame_time, int32_t win_length, int32_t freq_low, + int32_t freq_high) + : data_(std::make_shared(sample_rate, frame_time, win_length, freq_low, freq_high)) {} + +std::shared_ptr DetectPitchFrequency::Parse() { + return std::make_shared(data_->sample_rate_, data_->frame_time_, data_->win_length_, + data_->freq_low_, data_->freq_high_); +} + +// Dither Transform Operation. +struct Dither::Data { + Data(DensityFunction density_function, bool noise_shaping) + : density_function_(density_function), noise_shaping_(noise_shaping) {} + DensityFunction density_function_; + bool noise_shaping_; +}; + +Dither::Dither(DensityFunction density_function, bool noise_shaping) + : data_(std::make_shared(density_function, noise_shaping)) {} + +std::shared_ptr Dither::Parse() { + return std::make_shared(data_->density_function_, data_->noise_shaping_); +} + +// EqualizerBiquad Transform Operation. +struct EqualizerBiquad::Data { + Data(int32_t sample_rate, float center_freq, float gain, float Q) + : sample_rate_(sample_rate), center_freq_(center_freq), gain_(gain), Q_(Q) {} + int32_t sample_rate_; + float center_freq_; + float gain_; + float Q_; +}; + +EqualizerBiquad::EqualizerBiquad(int32_t sample_rate, float center_freq, float gain, float Q) + : data_(std::make_shared(sample_rate, center_freq, gain, Q)) {} + +std::shared_ptr EqualizerBiquad::Parse() { + return std::make_shared(data_->sample_rate_, data_->center_freq_, data_->gain_, data_->Q_); +} + +// Fade Transform Operation. +struct Fade::Data { + Data(int32_t fade_in_len, int32_t fade_out_len, FadeShape fade_shape) + : fade_in_len_(fade_in_len), fade_out_len_(fade_out_len), fade_shape_(fade_shape) {} + int32_t fade_in_len_; + int32_t fade_out_len_; + FadeShape fade_shape_; +}; + +Fade::Fade(int32_t fade_in_len, int32_t fade_out_len, FadeShape fade_shape) + : data_(std::make_shared(fade_in_len, fade_out_len, fade_shape)) {} + +std::shared_ptr Fade::Parse() { + return std::make_shared(data_->fade_in_len_, data_->fade_out_len_, data_->fade_shape_); +} + +// Filtfilt Transform Operation. +struct Filtfilt::Data { + Data(const std::vector &a_coeffs, const std::vector &b_coeffs, bool clamp) + : a_coeffs_(a_coeffs), b_coeffs_(b_coeffs), clamp_(clamp) {} + std::vector a_coeffs_; + std::vector b_coeffs_; + bool clamp_; +}; + +Filtfilt::Filtfilt(const std::vector &a_coeffs, const std::vector &b_coeffs, bool clamp) + : data_(std::make_shared(a_coeffs, b_coeffs, clamp)) {} + +std::shared_ptr Filtfilt::Parse() { + return std::make_shared(data_->a_coeffs_, data_->b_coeffs_, data_->clamp_); +} + +// Flanger Transform Operation. +struct Flanger::Data { + Data(int32_t sample_rate, float delay, float depth, float regen, float width, float speed, float phase, + Modulation modulation, Interpolation interpolation) + : sample_rate_(sample_rate), + delay_(delay), + depth_(depth), + regen_(regen), + width_(width), + speed_(speed), + phase_(phase), + modulation_(modulation), + interpolation_(interpolation) {} + int32_t sample_rate_; + float delay_; + float depth_; + float regen_; + float width_; + float speed_; + float phase_; + Modulation modulation_; + Interpolation interpolation_; +}; + +Flanger::Flanger(int32_t sample_rate, float delay, float depth, float regen, float width, float speed, float phase, + Modulation modulation, Interpolation interpolation) + : data_(std::make_shared(sample_rate, delay, depth, regen, width, speed, phase, modulation, interpolation)) {} + +std::shared_ptr Flanger::Parse() { + return std::make_shared(data_->sample_rate_, data_->delay_, data_->depth_, data_->regen_, + data_->width_, data_->speed_, data_->phase_, data_->modulation_, + data_->interpolation_); +} + +// FrequencyMasking Transform Operation. +struct FrequencyMasking::Data { + Data(bool iid_masks, int32_t frequency_mask_param, int32_t mask_start, float mask_value) + : iid_masks_(iid_masks), + frequency_mask_param_(frequency_mask_param), + mask_start_(mask_start), + mask_value_(mask_value) {} + bool iid_masks_; + int32_t frequency_mask_param_; + int32_t mask_start_; + float mask_value_; +}; + +FrequencyMasking::FrequencyMasking(bool iid_masks, int32_t frequency_mask_param, int32_t mask_start, float mask_value) + : data_(std::make_shared(iid_masks, frequency_mask_param, mask_start, mask_value)) {} + +std::shared_ptr FrequencyMasking::Parse() { + return std::make_shared(data_->iid_masks_, data_->frequency_mask_param_, + data_->mask_start_, data_->mask_value_); +} + +// Gain Transform Operation. +struct Gain::Data { + explicit Data(float gain_db) : gain_db_(gain_db) {} + float gain_db_; +}; + +Gain::Gain(float gain_db) : data_(std::make_shared(gain_db)) {} + +std::shared_ptr Gain::Parse() { return std::make_shared(data_->gain_db_); } + +// GriffinLim Transform Operation. +struct GriffinLim::Data { + Data(int32_t n_fft, int32_t n_iter, int32_t win_length, int32_t hop_length, WindowType window_type, float power, + float momentum, int32_t length, bool rand_init) + : n_fft_(n_fft), + n_iter_(n_iter), + win_length_(win_length), + hop_length_(hop_length), + window_type_(window_type), + power_(power), + momentum_(momentum), + length_(length), + rand_init_(rand_init) {} + int32_t n_fft_; + int32_t n_iter_; + int32_t win_length_; + int32_t hop_length_; + WindowType window_type_; + float power_; + float momentum_; + int32_t length_; + bool rand_init_; +}; + +GriffinLim::GriffinLim(int32_t n_fft, int32_t n_iter, int32_t win_length, int32_t hop_length, WindowType window_type, + float power, float momentum, int32_t length, bool rand_init) + : data_(std::make_shared(n_fft, n_iter, win_length, hop_length, window_type, power, momentum, length, + rand_init)) {} + +std::shared_ptr GriffinLim::Parse() { + return std::make_shared(data_->n_fft_, data_->n_iter_, data_->win_length_, data_->hop_length_, + data_->window_type_, data_->power_, data_->momentum_, data_->length_, + data_->rand_init_); +} + +// HighpassBiquad Transform Operation. +struct HighpassBiquad::Data { + Data(int32_t sample_rate, float cutoff_freq, float Q) : sample_rate_(sample_rate), cutoff_freq_(cutoff_freq), Q_(Q) {} + int32_t sample_rate_; + float cutoff_freq_; + float Q_; +}; + +HighpassBiquad::HighpassBiquad(int32_t sample_rate, float cutoff_freq, float Q) + : data_(std::make_shared(sample_rate, cutoff_freq, Q)) {} + +std::shared_ptr HighpassBiquad::Parse() { + return std::make_shared(data_->sample_rate_, data_->cutoff_freq_, data_->Q_); +} + +// InverseMelScale Transform Operation. +struct InverseMelScale::Data { + Data(int32_t n_stft, int32_t n_mels, int32_t sample_rate, float f_min, float f_max, int32_t max_iter, + float tolerance_loss, float tolerance_change, const std::map &sgdargs, NormType norm, + MelType mel_type) + : n_stft_(n_stft), + n_mels_(n_mels), + sample_rate_(sample_rate), + f_min_(f_min), + f_max_(f_max), + max_iter_(max_iter), + tolerance_loss_(tolerance_loss), + tolerance_change_(tolerance_change), + sgdargs_(sgdargs), + norm_(norm), + mel_type_(mel_type) {} + int32_t n_stft_; + int32_t n_mels_; + int32_t sample_rate_; + float f_min_; + float f_max_; + int32_t max_iter_; + float tolerance_loss_; + float tolerance_change_; + std::map sgdargs_; + NormType norm_; + MelType mel_type_; +}; + +InverseMelScale::InverseMelScale(int32_t n_stft, int32_t n_mels, int32_t sample_rate, float f_min, float f_max, + int32_t max_iter, float tolerance_loss, float tolerance_change, + const std::map &sgdargs, NormType norm, MelType mel_type) + : data_(std::make_shared(n_stft, n_mels, sample_rate, f_min, f_max, max_iter, tolerance_loss, + tolerance_change, sgdargs, norm, mel_type)) {} + +std::shared_ptr InverseMelScale::Parse() { + return std::make_shared( + data_->n_stft_, data_->n_mels_, data_->sample_rate_, data_->f_min_, data_->f_max_, data_->max_iter_, + data_->tolerance_loss_, data_->tolerance_change_, data_->sgdargs_, data_->norm_, data_->mel_type_); +} + +// InverseSpectrogram Transform Operation. +struct InverseSpectrogram::Data { + Data(int32_t length, int32_t n_fft, int32_t win_length, int32_t hop_length, int32_t pad, WindowType window, + bool normalized, bool center, BorderType pad_mode, bool onesided) + : length_(length), + n_fft_(n_fft), + win_length_(win_length), + hop_length_(hop_length), + pad_(pad), + window_(window), + normalized_(normalized), + center_(center), + pad_mode_(pad_mode), + onesided_(onesided) {} + int32_t length_; + int32_t n_fft_; + int32_t win_length_; + int32_t hop_length_; + int32_t pad_; + WindowType window_; + bool normalized_; + bool center_; + BorderType pad_mode_; + bool onesided_; +}; + +InverseSpectrogram::InverseSpectrogram(int32_t length, int32_t n_fft, int32_t win_length, int32_t hop_length, + int32_t pad, WindowType window, bool normalized, bool center, + BorderType pad_mode, bool onesided) + : data_(std::make_shared(length, n_fft, win_length, hop_length, pad, window, normalized, center, pad_mode, + onesided)) {} + +std::shared_ptr InverseSpectrogram::Parse() { + return std::make_shared( + data_->length_, data_->n_fft_, data_->win_length_, data_->hop_length_, data_->pad_, data_->window_, + data_->normalized_, data_->center_, data_->pad_mode_, data_->onesided_); +} + +// LFCC Transform Operation. +struct LFCC::Data { + Data(int32_t sample_rate, int32_t n_filter, int32_t n_lfcc, float f_min, float f_max, int32_t dct_type, NormMode norm, + bool log_lf, int32_t n_fft, int32_t win_length, int32_t hop_length, int32_t pad, WindowType window, float power, + bool normalized, bool center, BorderType pad_mode, bool onesided) + : sample_rate_(sample_rate), + n_filter_(n_filter), + n_lfcc_(n_lfcc), + f_min_(f_min), + f_max_(f_max), + dct_type_(dct_type), + norm_(norm), + log_lf_(log_lf), + n_fft_(n_fft), + win_length_(win_length), + hop_length_(hop_length), + pad_(pad), + window_(window), + power_(power), + normalized_(normalized), + center_(center), + pad_mode_(pad_mode), + onesided_(onesided) {} + int32_t sample_rate_; + int32_t n_filter_; + int32_t n_lfcc_; + float f_min_; + float f_max_; + int32_t dct_type_; + NormMode norm_; + bool log_lf_; + int32_t n_fft_; + int32_t win_length_; + int32_t hop_length_; + int32_t pad_; + WindowType window_; + float power_; + bool normalized_; + bool center_; + BorderType pad_mode_; + bool onesided_; +}; + +LFCC::LFCC(int32_t sample_rate, int32_t n_filter, int32_t n_lfcc, float f_min, float f_max, int32_t dct_type, + NormMode norm, bool log_lf, int32_t n_fft, int32_t win_length, int32_t hop_length, int32_t pad, + WindowType window, float power, bool normalized, bool center, BorderType pad_mode, bool onesided) + : data_(std::make_shared(sample_rate, n_filter, n_lfcc, f_min, f_max, dct_type, norm, log_lf, n_fft, + win_length, hop_length, pad, window, power, normalized, center, pad_mode, + onesided)) {} + +std::shared_ptr LFCC::Parse() { + return std::make_shared( + data_->sample_rate_, data_->n_filter_, data_->n_lfcc_, data_->f_min_, data_->f_max_, data_->dct_type_, data_->norm_, + data_->log_lf_, data_->n_fft_, data_->win_length_, data_->hop_length_, data_->pad_, data_->window_, data_->power_, + data_->normalized_, data_->center_, data_->pad_mode_, data_->onesided_); +} + +// LFilter Transform Operation. +struct LFilter::Data { + Data(const std::vector &a_coeffs, const std::vector &b_coeffs, bool clamp) + : a_coeffs_(a_coeffs), b_coeffs_(b_coeffs), clamp_(clamp) {} + std::vector a_coeffs_; + std::vector b_coeffs_; + bool clamp_; +}; + +LFilter::LFilter(const std::vector &a_coeffs, const std::vector &b_coeffs, bool clamp) + : data_(std::make_shared(a_coeffs, b_coeffs, clamp)) {} + +std::shared_ptr LFilter::Parse() { + return std::make_shared(data_->a_coeffs_, data_->b_coeffs_, data_->clamp_); +} + +// LowpassBiquad Transform Operation. +struct LowpassBiquad::Data { + Data(int32_t sample_rate, float cutoff_freq, float Q) : sample_rate_(sample_rate), cutoff_freq_(cutoff_freq), Q_(Q) {} + int32_t sample_rate_; + float cutoff_freq_; + float Q_; +}; + +LowpassBiquad::LowpassBiquad(int32_t sample_rate, float cutoff_freq, float Q) + : data_(std::make_shared(sample_rate, cutoff_freq, Q)) {} + +std::shared_ptr LowpassBiquad::Parse() { + return std::make_shared(data_->sample_rate_, data_->cutoff_freq_, data_->Q_); +} + +// Magphase Transform Operation. +struct Magphase::Data { + explicit Data(float power) : power_(power) {} + float power_; +}; + +Magphase::Magphase(float power) : data_(std::make_shared(power)) {} + +std::shared_ptr Magphase::Parse() { return std::make_shared(data_->power_); } + +// MaskAlongAxis Transform Operation. +struct MaskAlongAxis::Data { + Data(int32_t mask_start, int32_t mask_width, float mask_value, int32_t axis) + : mask_start_(mask_start), mask_width_(mask_width), mask_value_(mask_value), axis_(axis) {} + int32_t mask_start_; + int32_t mask_width_; + float mask_value_; + int32_t axis_; +}; + +MaskAlongAxis::MaskAlongAxis(int32_t mask_start, int32_t mask_width, float mask_value, int32_t axis) + : data_(std::make_shared(mask_start, mask_width, mask_value, axis)) {} + +std::shared_ptr MaskAlongAxis::Parse() { + return std::make_shared(data_->mask_start_, data_->mask_width_, data_->mask_value_, + data_->axis_); +} + +// MaskAlongAxisIID Transform Operation. +struct MaskAlongAxisIID::Data { + Data(int32_t mask_param, float mask_value, int32_t axis) + : mask_param_(mask_param), mask_value_(mask_value), axis_(axis) {} + int32_t mask_param_; + float mask_value_; + int32_t axis_; +}; + +MaskAlongAxisIID::MaskAlongAxisIID(int32_t mask_param, float mask_value, int32_t axis) + : data_(std::make_shared(mask_param, mask_value, axis)) {} + +std::shared_ptr MaskAlongAxisIID::Parse() { + return std::make_shared(data_->mask_param_, data_->mask_value_, data_->axis_); +} + +// MelScale Transform Operation. +struct MelScale::Data { + Data(int32_t n_mels, int32_t sample_rate, float f_min, float f_max, int32_t n_stft, NormType norm, MelType mel_type) + : n_mels_(n_mels), + sample_rate_(sample_rate), + f_min_(f_min), + f_max_(f_max), + n_stft_(n_stft), + norm_(norm), + mel_type_(mel_type) {} + int32_t n_mels_; + int32_t sample_rate_; + float f_min_; + float f_max_; + int32_t n_stft_; + NormType norm_; + MelType mel_type_; +}; + +MelScale::MelScale(int32_t n_mels, int32_t sample_rate, float f_min, float f_max, int32_t n_stft, NormType norm, + MelType mel_type) + : data_(std::make_shared(n_mels, sample_rate, f_min, f_max, n_stft, norm, mel_type)) {} + +std::shared_ptr MelScale::Parse() { + return std::make_shared(data_->n_mels_, data_->sample_rate_, data_->f_min_, data_->f_max_, + data_->n_stft_, data_->norm_, data_->mel_type_); +} + +// MelscaleFbanks Function. +Status MelscaleFbanks(MSTensor *output, int32_t n_freqs, float f_min, float f_max, int32_t n_mels, int32_t sample_rate, + NormType norm, MelType mel_type) { + RETURN_UNEXPECTED_IF_NULL(output); + CHECK_FAIL_RETURN_UNEXPECTED(n_freqs > 0, + "MelscaleFbanks: n_freqs must be greater than 0, got: " + std::to_string(n_freqs)); + + CHECK_FAIL_RETURN_UNEXPECTED(f_min >= 0, "MelscaleFbanks: f_min must be non negative, got: " + std::to_string(f_min)); + CHECK_FAIL_RETURN_UNEXPECTED(f_max > 0, + "MelscaleFbanks: f_max must be greater than 0, got: " + std::to_string(f_max)); + CHECK_FAIL_RETURN_UNEXPECTED(n_mels > 0, + "MelscaleFbanks: n_mels must be greater than 0, got: " + std::to_string(n_mels)); + CHECK_FAIL_RETURN_UNEXPECTED( + sample_rate > 0, "MelscaleFbanks: sample_rate must be greater than 0, got: " + std::to_string(sample_rate)); + CHECK_FAIL_RETURN_UNEXPECTED(f_max > f_min, "MelscaleFbanks: f_max must be greater than f_min, got: f_min = " + + std::to_string(f_min) + ", while f_max = " + std::to_string(f_max)); + std::shared_ptr fb; + RETURN_IF_NOT_OK(CreateFbanks(&fb, n_freqs, f_min, f_max, n_mels, sample_rate, norm, mel_type)); + CHECK_FAIL_RETURN_UNEXPECTED(fb->HasData(), + "MelscaleFbanks: get an empty tensor with shape " + fb->shape().ToString()); + *output = mindspore::MSTensor(std::make_shared(fb)); + return Status::OK(); +} + +// MelSpectrogram Transform Operation. +struct MelSpectrogram::Data { + Data(int32_t sample_rate, int32_t n_fft, int32_t win_length, int32_t hop_length, float f_min, float f_max, + int32_t pad, int32_t n_mels, WindowType window, float power, bool normalized, bool center, BorderType pad_mode, + bool onesided, NormType norm, MelType mel_scale) + : sample_rate_(sample_rate), + n_fft_(n_fft), + win_length_(win_length), + hop_length_(hop_length), + f_min_(f_min), + f_max_(f_max), + pad_(pad), + n_mels_(n_mels), + window_(window), + power_(power), + normalized_(normalized), + center_(center), + pad_mode_(pad_mode), + onesided_(onesided), + norm_(norm), + mel_scale_(mel_scale) {} + int32_t sample_rate_; + int32_t n_fft_; + int32_t win_length_; + int32_t hop_length_; + float f_min_; + float f_max_; + int32_t pad_; + int32_t n_mels_; + WindowType window_; + float power_; + bool normalized_; + bool center_; + BorderType pad_mode_; + bool onesided_; + NormType norm_; + MelType mel_scale_; +}; + +MelSpectrogram::MelSpectrogram(int32_t sample_rate, int32_t n_fft, int32_t win_length, int32_t hop_length, float f_min, + float f_max, int32_t pad, int32_t n_mels, WindowType window, float power, + bool normalized, bool center, BorderType pad_mode, bool onesided, NormType norm, + MelType mel_scale) + : data_(std::make_shared(sample_rate, n_fft, win_length, hop_length, f_min, f_max, pad, n_mels, window, power, + normalized, center, pad_mode, onesided, norm, mel_scale)) {} + +std::shared_ptr MelSpectrogram::Parse() { + return std::make_shared( + data_->sample_rate_, data_->n_fft_, data_->win_length_, data_->hop_length_, data_->f_min_, data_->f_max_, + data_->pad_, data_->n_mels_, data_->window_, data_->power_, data_->normalized_, data_->center_, data_->pad_mode_, + data_->onesided_, data_->norm_, data_->mel_scale_); +} + +// MFCC Transform Operation. +struct MFCC::Data { + Data(int32_t sample_rate, int32_t n_mfcc, int32_t dct_type, NormMode norm, bool log_mels, int32_t n_fft, + int32_t win_length, int32_t hop_length, float f_min, float f_max, int32_t pad, int32_t n_mels, WindowType window, + float power, bool normalized, bool center, BorderType pad_mode, bool onesided, NormType norm_mel, + MelType mel_scale) + : sample_rate_(sample_rate), + n_mfcc_(n_mfcc), + dct_type_(dct_type), + norm_(norm), + log_mels_(log_mels), + n_fft_(n_fft), + win_length_(win_length), + hop_length_(hop_length), + f_min_(f_min), + f_max_(f_max), + pad_(pad), + n_mels_(n_mels), + window_(window), + power_(power), + normalized_(normalized), + center_(center), + pad_mode_(pad_mode), + onesided_(onesided), + norm_mel_(norm_mel), + mel_scale_(mel_scale) {} + int32_t sample_rate_; + int32_t n_mfcc_; + int32_t dct_type_; + NormMode norm_; + bool log_mels_; + int32_t n_fft_; + int32_t win_length_; + int32_t hop_length_; + float f_min_; + float f_max_; + int32_t pad_; + int32_t n_mels_; + WindowType window_; + float power_; + bool normalized_; + bool center_; + BorderType pad_mode_; + bool onesided_; + NormType norm_mel_; + MelType mel_scale_; + std::map melkwargs_; +}; + +MFCC::MFCC(int32_t sample_rate, int32_t n_mfcc, int32_t dct_type, NormMode norm, bool log_mels, int32_t n_fft, + int32_t win_length, int32_t hop_length, float f_min, float f_max, int32_t pad, int32_t n_mels, + WindowType window, float power, bool normalized, bool center, BorderType pad_mode, bool onesided, + NormType norm_mel, MelType mel_scale) + : data_(std::make_shared(sample_rate, n_mfcc, dct_type, norm, log_mels, n_fft, win_length, hop_length, f_min, + f_max, pad, n_mels, window, power, normalized, center, pad_mode, onesided, norm_mel, + mel_scale)) {} + +std::shared_ptr MFCC::Parse() { + return std::make_shared(data_->sample_rate_, data_->n_mfcc_, data_->dct_type_, data_->norm_, + data_->log_mels_, data_->n_fft_, data_->win_length_, data_->hop_length_, + data_->f_min_, data_->f_max_, data_->pad_, data_->n_mels_, data_->window_, + data_->power_, data_->normalized_, data_->center_, data_->pad_mode_, + data_->onesided_, data_->norm_mel_, data_->mel_scale_); +} + +// MuLawDecoding Transform Operation. +struct MuLawDecoding::Data { + explicit Data(int32_t quantization_channels) : quantization_channels_(quantization_channels) {} + int32_t quantization_channels_; +}; + +MuLawDecoding::MuLawDecoding(int32_t quantization_channels) : data_(std::make_shared(quantization_channels)) {} + +std::shared_ptr MuLawDecoding::Parse() { + return std::make_shared(data_->quantization_channels_); +} + +// MuLawEncoding Transform Operation. +struct MuLawEncoding::Data { + explicit Data(int32_t quantization_channels) : quantization_channels_(quantization_channels) {} + int32_t quantization_channels_; +}; + +MuLawEncoding::MuLawEncoding(int32_t quantization_channels) : data_(std::make_shared(quantization_channels)) {} + +std::shared_ptr MuLawEncoding::Parse() { + return std::make_shared(data_->quantization_channels_); +} + +// LinearFbanks Function. +Status LinearFbanks(MSTensor *output, int32_t n_freqs, float f_min, float f_max, int32_t n_filter, + int32_t sample_rate) { + RETURN_UNEXPECTED_IF_NULL(output); + CHECK_FAIL_RETURN_UNEXPECTED(n_freqs > 0, + "LinearFbanks: n_freqs must be greater than 0, got: " + std::to_string(n_freqs)); + + CHECK_FAIL_RETURN_UNEXPECTED(f_min >= 0, "LinearFbanks: f_min must be non negative, got: " + std::to_string(f_min)); + CHECK_FAIL_RETURN_UNEXPECTED(f_max > 0, "LinearFbanks: f_max must be greater than 0, got: " + std::to_string(f_max)); + CHECK_FAIL_RETURN_UNEXPECTED(n_filter > 0, + "LinearFbanks: n_filter must be greater than 0, got: " + std::to_string(n_filter)); + CHECK_FAIL_RETURN_UNEXPECTED(sample_rate > 0, + "LinearFbanks: sample_rate must be greater than 0, got: " + std::to_string(sample_rate)); + CHECK_FAIL_RETURN_UNEXPECTED(f_max > f_min, "LinearFbanks: f_max must be greater than f_min, got: f_min = " + + std::to_string(f_min) + ", while f_max = " + std::to_string(f_max)); + std::shared_ptr fb; + RETURN_IF_NOT_OK(CreateLinearFbanks(&fb, n_freqs, f_min, f_max, n_filter, sample_rate)); + CHECK_FAIL_RETURN_UNEXPECTED(fb->HasData(), "LinearFbanks: get an empty tensor with shape " + fb->shape().ToString()); + *output = mindspore::MSTensor(std::make_shared(fb)); + return Status::OK(); +} + +// Overdrive Transform Operation. +struct Overdrive::Data { + Data(float gain, float color) : gain_(gain), color_(color) {} + float gain_; + float color_; +}; + +Overdrive::Overdrive(float gain, float color) : data_(std::make_shared(gain, color)) {} + +std::shared_ptr Overdrive::Parse() { + return std::make_shared(data_->gain_, data_->color_); +} + +// Phaser Transform Operation. +struct Phaser::Data { + Data(int32_t sample_rate, float gain_in, float gain_out, float delay_ms, float decay, float mod_speed, + bool sinusoidal) + : sample_rate_(sample_rate), + gain_in_(gain_in), + gain_out_(gain_out), + delay_ms_(delay_ms), + decay_(decay), + mod_speed_(mod_speed), + sinusoidal_(sinusoidal) {} + int32_t sample_rate_; + float gain_in_; + float gain_out_; + float delay_ms_; + float decay_; + float mod_speed_; + bool sinusoidal_; +}; + +Phaser::Phaser(int32_t sample_rate, float gain_in, float gain_out, float delay_ms, float decay, float mod_speed, + bool sinusoidal) + : data_(std::make_shared(sample_rate, gain_in, gain_out, delay_ms, decay, mod_speed, sinusoidal)) {} + +std::shared_ptr Phaser::Parse() { + return std::make_shared(data_->sample_rate_, data_->gain_in_, data_->gain_out_, data_->delay_ms_, + data_->decay_, data_->mod_speed_, data_->sinusoidal_); +} + +// PhaseVocoder Transofrm Operation. +struct PhaseVocoder::Data { + Data(float rate, const MSTensor &phase_advance) : rate_(rate), phase_advance_(phase_advance) {} + float rate_; + MSTensor phase_advance_; +}; + +PhaseVocoder::PhaseVocoder(float rate, const MSTensor &phase_advance) + : data_(std::make_shared(rate, phase_advance)) {} + +std::shared_ptr PhaseVocoder::Parse() { + std::shared_ptr phase_advance; + Status rc = Tensor::CreateFromMSTensor(data_->phase_advance_, &phase_advance); + if (rc.IsError()) { + MS_LOG(ERROR) << "Error creating phase_vocoder constant tensor." << rc; + return nullptr; + } + return std::make_shared(data_->rate_, phase_advance); +} + +// pitchshift +struct PitchShift::Data { + Data(int32_t sample_rate, int32_t n_steps, int32_t bins_per_octave, int32_t n_fft, int32_t win_length, + int32_t hop_length, WindowType window) + : sample_rate_(sample_rate), + n_steps_(n_steps), + bins_per_octave_(bins_per_octave), + n_fft_(n_fft), + win_length_(win_length), + hop_length_(hop_length), + window_(window) {} + int32_t sample_rate_; + int32_t n_steps_; + int32_t bins_per_octave_; + int32_t n_fft_; + int32_t win_length_; + int32_t hop_length_; + WindowType window_; +}; + +PitchShift::PitchShift(int32_t sample_rate, int32_t n_steps, int32_t bins_per_octave, int32_t n_fft, int32_t win_length, + int32_t hop_length, WindowType window) + : data_(std::make_shared(sample_rate, n_steps, bins_per_octave, n_fft, win_length, hop_length, window)) {} + +std::shared_ptr PitchShift::Parse() { + return std::make_shared(data_->sample_rate_, data_->n_steps_, data_->bins_per_octave_, + data_->n_fft_, data_->win_length_, data_->hop_length_, data_->window_); +} + +// Resample Transform Operation. +struct Resample::Data { + Data(float orig_freq, float new_freq, ResampleMethod resample_method, int32_t lowpass_filter_width, float rolloff, + float beta) + : orig_freq_(orig_freq), + new_freq_(new_freq), + resample_method_(resample_method), + lowpass_filter_width_(lowpass_filter_width), + rolloff_(rolloff), + beta_(beta) {} + float orig_freq_; + float new_freq_; + ResampleMethod resample_method_; + int32_t lowpass_filter_width_; + float rolloff_; + float beta_; +}; + +Resample::Resample(float orig_freq, float new_freq, ResampleMethod resample_method, int32_t lowpass_filter_width, + float rolloff, float beta) + : data_(std::make_shared(orig_freq, new_freq, resample_method, lowpass_filter_width, rolloff, beta)) {} + +std::shared_ptr Resample::Parse() { + return std::make_shared(data_->orig_freq_, data_->new_freq_, data_->resample_method_, + data_->lowpass_filter_width_, data_->rolloff_, data_->beta_); +} + +// RiaaBiquad Transform Operation. +struct RiaaBiquad::Data { + explicit Data(int32_t sample_rate) : sample_rate_(sample_rate) {} + int32_t sample_rate_; +}; + +RiaaBiquad::RiaaBiquad(int32_t sample_rate) : data_(std::make_shared(sample_rate)) {} + +std::shared_ptr RiaaBiquad::Parse() { + return std::make_shared(data_->sample_rate_); +} + +// SlidingWindowCmn Transform Operation. +struct SlidingWindowCmn::Data { + Data(int32_t cmn_window, int32_t min_cmn_window, bool center, bool norm_vars) + : cmn_window_(cmn_window), min_cmn_window_(min_cmn_window), center_(center), norm_vars_(norm_vars) {} + int32_t cmn_window_; + int32_t min_cmn_window_; + bool center_; + bool norm_vars_; +}; + +SlidingWindowCmn::SlidingWindowCmn(int32_t cmn_window, int32_t min_cmn_window, bool center, bool norm_vars) + : data_(std::make_shared(cmn_window, min_cmn_window, center, norm_vars)) {} + +std::shared_ptr SlidingWindowCmn::Parse() { + return std::make_shared(data_->cmn_window_, data_->min_cmn_window_, data_->center_, + data_->norm_vars_); +} + +// Spectrogram Transform Operation. +struct Spectrogram::Data { + Data(int32_t n_fft, int32_t win_length, int32_t hop_length, int32_t pad, WindowType window, float power, + bool normalized, bool center, BorderType pad_mode, bool onesided) + : n_fft_(n_fft), + win_length_(win_length), + hop_length_(hop_length), + pad_(pad), + window_(window), + power_(power), + normalized_(normalized), + center_(center), + pad_mode_(pad_mode), + onesided_(onesided) {} + int32_t n_fft_; + int32_t win_length_; + int32_t hop_length_; + int32_t pad_; + WindowType window_; + float power_; + bool normalized_; + bool center_; + BorderType pad_mode_; + bool onesided_; +}; + +Spectrogram::Spectrogram(int32_t n_fft, int32_t win_length, int32_t hop_length, int32_t pad, WindowType window, + float power, bool normalized, bool center, BorderType pad_mode, bool onesided) + : data_(std::make_shared(n_fft, win_length, hop_length, pad, window, power, normalized, center, pad_mode, + onesided)) {} + +std::shared_ptr Spectrogram::Parse() { + return std::make_shared(data_->n_fft_, data_->win_length_, data_->hop_length_, data_->pad_, + data_->window_, data_->power_, data_->normalized_, data_->center_, + data_->pad_mode_, data_->onesided_); +} + +// SpectralCentroid Transform Operation. +struct SpectralCentroid::Data { + Data(int32_t sample_rate, int32_t n_fft, int32_t win_length, int32_t hop_length, int32_t pad, WindowType window) + : sample_rate_(sample_rate), + n_fft_(n_fft), + win_length_(win_length), + hop_length_(hop_length), + pad_(pad), + window_(window) {} + int32_t sample_rate_; + int32_t n_fft_; + int32_t win_length_; + int32_t hop_length_; + int32_t pad_; + WindowType window_; +}; + +SpectralCentroid::SpectralCentroid(int32_t sample_rate, int32_t n_fft, int32_t win_length, int32_t hop_length, + int32_t pad, WindowType window) + : data_(std::make_shared(sample_rate, n_fft, win_length, hop_length, pad, window)) {} + +std::shared_ptr SpectralCentroid::Parse() { + return std::make_shared(data_->sample_rate_, data_->n_fft_, data_->win_length_, + data_->hop_length_, data_->pad_, data_->window_); +} + +// TimeMasking Transform Operation. +struct TimeMasking::Data { + Data(bool iid_masks, int32_t time_mask_param, int32_t mask_start, float mask_value) + : iid_masks_(iid_masks), time_mask_param_(time_mask_param), mask_start_(mask_start), mask_value_(mask_value) {} + bool iid_masks_; + int32_t time_mask_param_; + int32_t mask_start_; + float mask_value_; +}; + +TimeMasking::TimeMasking(bool iid_masks, int32_t time_mask_param, int32_t mask_start, float mask_value) + : data_(std::make_shared(iid_masks, time_mask_param, mask_start, mask_value)) {} + +std::shared_ptr TimeMasking::Parse() { + return std::make_shared(data_->iid_masks_, data_->time_mask_param_, data_->mask_start_, + data_->mask_value_); +} + +// TimeStretch Transform Operation. +struct TimeStretch::Data { + explicit Data(float hop_length, int32_t n_freq, float fixed_rate) + : hop_length_(hop_length), n_freq_(n_freq), fixed_rate_(fixed_rate) {} + float hop_length_; + int32_t n_freq_; + float fixed_rate_; +}; + +TimeStretch::TimeStretch(float hop_length, int32_t n_freq, float fixed_rate) + : data_(std::make_shared(hop_length, n_freq, fixed_rate)) {} + +std::shared_ptr TimeStretch::Parse() { + return std::make_shared(data_->hop_length_, data_->n_freq_, data_->fixed_rate_); +} + +// TrebleBiquad Transform Operation. +struct TrebleBiquad::Data { + Data(int32_t sample_rate, float gain, float central_freq, float Q) + : sample_rate_(sample_rate), gain_(gain), central_freq_(central_freq), Q_(Q) {} + int32_t sample_rate_; + float gain_; + float central_freq_; + float Q_; +}; + +TrebleBiquad::TrebleBiquad(int32_t sample_rate, float gain, float central_freq, float Q) + : data_(std::make_shared(sample_rate, gain, central_freq, Q)) {} + +std::shared_ptr TrebleBiquad::Parse() { + return std::make_shared(data_->sample_rate_, data_->gain_, data_->central_freq_, data_->Q_); +} + +// Vad Transform Operation. +struct Vad::Data { + Data(int32_t sample_rate, float trigger_level, float trigger_time, float search_time, float allowed_gap, + float pre_trigger_time, float boot_time, float noise_up_time, float noise_down_time, + float noise_reduction_amount, float measure_freq, float measure_duration, float measure_smooth_time, + float hp_filter_freq, float lp_filter_freq, float hp_lifter_freq, float lp_lifter_freq) + : sample_rate_(sample_rate), + trigger_level_(trigger_level), + trigger_time_(trigger_time), + search_time_(search_time), + allowed_gap_(allowed_gap), + pre_trigger_time_(pre_trigger_time), + boot_time_(boot_time), + noise_up_time_(noise_up_time), + noise_down_time_(noise_down_time), + noise_reduction_amount_(noise_reduction_amount), + measure_freq_(measure_freq), + measure_duration_(measure_duration), + measure_smooth_time_(measure_smooth_time), + hp_filter_freq_(hp_filter_freq), + lp_filter_freq_(lp_filter_freq), + hp_lifter_freq_(hp_lifter_freq), + lp_lifter_freq_(lp_lifter_freq) {} + int32_t sample_rate_; + float trigger_level_; + float trigger_time_; + float search_time_; + float allowed_gap_; + float pre_trigger_time_; + float boot_time_; + float noise_up_time_; + float noise_down_time_; + float noise_reduction_amount_; + float measure_freq_; + float measure_duration_; + float measure_smooth_time_; + float hp_filter_freq_; + float lp_filter_freq_; + float hp_lifter_freq_; + float lp_lifter_freq_; +}; + +Vad::Vad(int32_t sample_rate, float trigger_level, float trigger_time, float search_time, float allowed_gap, + float pre_trigger_time, float boot_time, float noise_up_time, float noise_down_time, + float noise_reduction_amount, float measure_freq, float measure_duration, float measure_smooth_time, + float hp_filter_freq, float lp_filter_freq, float hp_lifter_freq, float lp_lifter_freq) + : data_(std::make_shared(sample_rate, trigger_level, trigger_time, search_time, allowed_gap, pre_trigger_time, + boot_time, noise_up_time, noise_down_time, noise_reduction_amount, measure_freq, + measure_duration, measure_smooth_time, hp_filter_freq, lp_filter_freq, + hp_lifter_freq, lp_lifter_freq)) {} + +std::shared_ptr Vad::Parse() { + return std::make_shared( + data_->sample_rate_, data_->trigger_level_, data_->trigger_time_, data_->search_time_, data_->allowed_gap_, + data_->pre_trigger_time_, data_->boot_time_, data_->noise_up_time_, data_->noise_down_time_, + data_->noise_reduction_amount_, data_->measure_freq_, data_->measure_duration_, data_->measure_smooth_time_, + data_->hp_filter_freq_, data_->lp_filter_freq_, data_->hp_lifter_freq_, data_->lp_lifter_freq_); +} + +// Vol Transform Operation. +struct Vol::Data { + Data(float gain, GainType gain_type) : gain_(gain), gain_type_(gain_type) {} + float gain_; + GainType gain_type_; +}; + +Vol::Vol(float gain, GainType gain_type) : data_(std::make_shared(gain, gain_type)) {} + +std::shared_ptr Vol::Parse() { + return std::make_shared(data_->gain_, data_->gain_type_); +} +} // namespace audio +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/audio/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/audio/bindings.cc old mode 100755 new mode 100644 index 245012b53c5..871fc7a02c5 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/audio/bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/audio/bindings.cc @@ -1,73 +1,73 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pybind11/pybind11.h" - -#include "minddata/dataset/api/python/pybind_conversion.h" -#include "minddata/dataset/api/python/pybind_register.h" -#include "minddata/dataset/audio/kernels/audio_utils.h" - -namespace mindspore { -namespace dataset { -PYBIND_REGISTER(CreateDct, 1, ([](py::module *m) { - (void)m->def("create_dct", ([](int32_t n_mfcc, int32_t n_mels, NormMode norm) { - std::shared_ptr out; - THROW_IF_ERROR(Dct(&out, n_mfcc, n_mels, norm)); - return out; - })); - })); - -PYBIND_REGISTER(MelscaleFbanks, 1, ([](py::module *m) { - (void)m->def("melscale_fbanks", ([](int32_t n_freqs, float f_min, float f_max, int32_t n_mels, - int32_t sample_rate, NormType norm, MelType mel_type) { - std::shared_ptr fb; - THROW_IF_ERROR(CreateFbanks(&fb, n_freqs, f_min, f_max, n_mels, sample_rate, - norm, mel_type)); - return fb; - })); - })); - -PYBIND_REGISTER(MelType, 0, ([](const py::module *m) { - (void)py::enum_(*m, "MelType", py::arithmetic()) - .value("DE_MEL_TYPE_HTK", MelType::kHtk) - .value("DE_MEL_TYPE_SLANEY", MelType::kSlaney) - .export_values(); - })); - -PYBIND_REGISTER(NormType, 0, ([](const py::module *m) { - (void)py::enum_(*m, "NormType", py::arithmetic()) - .value("DE_NORM_TYPE_NONE", NormType::kNone) - .value("DE_NORM_TYPE_SLANEY", NormType::kSlaney) - .export_values(); - })); - -PYBIND_REGISTER(LinearFbanks, 1, ([](py::module *m) { - (void)m->def("linear_fbanks", - ([](int32_t n_freqs, float f_min, float f_max, int32_t n_filter, int32_t sample_rate) { - std::shared_ptr fb; - THROW_IF_ERROR(CreateLinearFbanks(&fb, n_freqs, f_min, f_max, n_filter, sample_rate)); - return fb; - })); - })); - -PYBIND_REGISTER(NormMode, 0, ([](const py::module *m) { - (void)py::enum_(*m, "NormMode", py::arithmetic()) - .value("DE_NORM_MODE_NONE", NormMode::kNone) - .value("DE_NORM_MODE_ORTHO", NormMode::kOrtho) - .export_values(); - })); -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pybind11/pybind11.h" + +#include "minddata/dataset/api/python/pybind_conversion.h" +#include "minddata/dataset/api/python/pybind_register.h" +#include "minddata/dataset/audio/kernels/audio_utils.h" + +namespace mindspore { +namespace dataset { +PYBIND_REGISTER(CreateDct, 1, ([](py::module *m) { + (void)m->def("create_dct", ([](int32_t n_mfcc, int32_t n_mels, NormMode norm) { + std::shared_ptr out; + THROW_IF_ERROR(Dct(&out, n_mfcc, n_mels, norm)); + return out; + })); + })); + +PYBIND_REGISTER(MelscaleFbanks, 1, ([](py::module *m) { + (void)m->def("melscale_fbanks", ([](int32_t n_freqs, float f_min, float f_max, int32_t n_mels, + int32_t sample_rate, NormType norm, MelType mel_type) { + std::shared_ptr fb; + THROW_IF_ERROR(CreateFbanks(&fb, n_freqs, f_min, f_max, n_mels, sample_rate, + norm, mel_type)); + return fb; + })); + })); + +PYBIND_REGISTER(MelType, 0, ([](const py::module *m) { + (void)py::enum_(*m, "MelType", py::arithmetic()) + .value("DE_MEL_TYPE_HTK", MelType::kHtk) + .value("DE_MEL_TYPE_SLANEY", MelType::kSlaney) + .export_values(); + })); + +PYBIND_REGISTER(NormType, 0, ([](const py::module *m) { + (void)py::enum_(*m, "NormType", py::arithmetic()) + .value("DE_NORM_TYPE_NONE", NormType::kNone) + .value("DE_NORM_TYPE_SLANEY", NormType::kSlaney) + .export_values(); + })); + +PYBIND_REGISTER(LinearFbanks, 1, ([](py::module *m) { + (void)m->def("linear_fbanks", + ([](int32_t n_freqs, float f_min, float f_max, int32_t n_filter, int32_t sample_rate) { + std::shared_ptr fb; + THROW_IF_ERROR(CreateLinearFbanks(&fb, n_freqs, f_min, f_max, n_filter, sample_rate)); + return fb; + })); + })); + +PYBIND_REGISTER(NormMode, 0, ([](const py::module *m) { + (void)py::enum_(*m, "NormMode", py::arithmetic()) + .value("DE_NORM_MODE_NONE", NormMode::kNone) + .value("DE_NORM_MODE_ORTHO", NormMode::kOrtho) + .export_values(); + })); +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/audio/kernels/ir/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/audio/kernels/ir/bindings.cc index c3712e7497d..272f1bcca39 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/audio/kernels/ir/bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/audio/kernels/ir/bindings.cc @@ -1,756 +1,756 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pybind11/pybind11.h" - -#include "minddata/dataset/api/python/pybind_conversion.h" -#include "minddata/dataset/api/python/pybind_register.h" -#include "minddata/dataset/include/dataset/transforms.h" - -#include "minddata/dataset/audio/ir/kernels/allpass_biquad_ir.h" -#include "minddata/dataset/audio/ir/kernels/amplitude_to_db_ir.h" -#include "minddata/dataset/audio/ir/kernels/angle_ir.h" -#include "minddata/dataset/audio/ir/kernels/band_biquad_ir.h" -#include "minddata/dataset/audio/ir/kernels/bandpass_biquad_ir.h" -#include "minddata/dataset/audio/ir/kernels/bandreject_biquad_ir.h" -#include "minddata/dataset/audio/ir/kernels/bass_biquad_ir.h" -#include "minddata/dataset/audio/ir/kernels/biquad_ir.h" -#include "minddata/dataset/audio/ir/kernels/complex_norm_ir.h" -#include "minddata/dataset/audio/ir/kernels/compute_deltas_ir.h" -#include "minddata/dataset/audio/ir/kernels/contrast_ir.h" -#include "minddata/dataset/audio/ir/kernels/db_to_amplitude_ir.h" -#include "minddata/dataset/audio/ir/kernels/dc_shift_ir.h" -#include "minddata/dataset/audio/ir/kernels/deemph_biquad_ir.h" -#include "minddata/dataset/audio/ir/kernels/detect_pitch_frequency_ir.h" -#include "minddata/dataset/audio/ir/kernels/dither_ir.h" -#include "minddata/dataset/audio/ir/kernels/equalizer_biquad_ir.h" -#include "minddata/dataset/audio/ir/kernels/fade_ir.h" -#include "minddata/dataset/audio/ir/kernels/filtfilt_ir.h" -#include "minddata/dataset/audio/ir/kernels/flanger_ir.h" -#include "minddata/dataset/audio/ir/kernels/frequency_masking_ir.h" -#include "minddata/dataset/audio/ir/kernels/gain_ir.h" -#include "minddata/dataset/audio/ir/kernels/griffin_lim_ir.h" -#include "minddata/dataset/audio/ir/kernels/highpass_biquad_ir.h" -#include "minddata/dataset/audio/ir/kernels/inverse_mel_scale_ir.h" -#include "minddata/dataset/audio/ir/kernels/inverse_spectrogram_ir.h" -#include "minddata/dataset/audio/ir/kernels/lfcc_ir.h" -#include "minddata/dataset/audio/ir/kernels/lfilter_ir.h" -#include "minddata/dataset/audio/ir/kernels/lowpass_biquad_ir.h" -#include "minddata/dataset/audio/ir/kernels/magphase_ir.h" -#include "minddata/dataset/audio/ir/kernels/mask_along_axis_iid_ir.h" -#include "minddata/dataset/audio/ir/kernels/mask_along_axis_ir.h" -#include "minddata/dataset/audio/ir/kernels/mel_scale_ir.h" -#include "minddata/dataset/audio/ir/kernels/mel_spectrogram_ir.h" -#include "minddata/dataset/audio/ir/kernels/mfcc_ir.h" -#include "minddata/dataset/audio/ir/kernels/mu_law_decoding_ir.h" -#include "minddata/dataset/audio/ir/kernels/mu_law_encoding_ir.h" -#include "minddata/dataset/audio/ir/kernels/overdrive_ir.h" -#include "minddata/dataset/audio/ir/kernels/phase_vocoder_ir.h" -#include "minddata/dataset/audio/ir/kernels/phaser_ir.h" -#include "minddata/dataset/audio/ir/kernels/pitch_shift_ir.h" -#include "minddata/dataset/audio/ir/kernels/resample_ir.h" -#include "minddata/dataset/audio/ir/kernels/riaa_biquad_ir.h" -#include "minddata/dataset/audio/ir/kernels/sliding_window_cmn_ir.h" -#include "minddata/dataset/audio/ir/kernels/spectral_centroid_ir.h" -#include "minddata/dataset/audio/ir/kernels/spectrogram_ir.h" -#include "minddata/dataset/audio/ir/kernels/time_masking_ir.h" -#include "minddata/dataset/audio/ir/kernels/time_stretch_ir.h" -#include "minddata/dataset/audio/ir/kernels/treble_biquad_ir.h" -#include "minddata/dataset/audio/ir/kernels/vad_ir.h" -#include "minddata/dataset/audio/ir/kernels/vol_ir.h" - -namespace mindspore { -namespace dataset { -PYBIND_REGISTER( - AllpassBiquadOperation, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "AllpassBiquadOperation") - .def(py::init([](int32_t sample_rate, float central_freq, float Q) { - auto allpass_biquad = std::make_shared(sample_rate, central_freq, Q); - THROW_IF_ERROR(allpass_biquad->ValidateParams()); - return allpass_biquad; - })); - })); - -PYBIND_REGISTER( - AmplitudeToDBOperation, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "AmplitudeToDBOperation") - .def(py::init([](ScaleType stype, float ref_value, float amin, float top_db) { - auto amplitude_to_db = std::make_shared(stype, ref_value, amin, top_db); - THROW_IF_ERROR(amplitude_to_db->ValidateParams()); - return amplitude_to_db; - })); - })); - -PYBIND_REGISTER(ScaleType, 0, ([](const py::module *m) { - (void)py::enum_(*m, "ScaleType", py::arithmetic()) - .value("DE_SCALE_TYPE_MAGNITUDE", ScaleType::kMagnitude) - .value("DE_SCALE_TYPE_POWER", ScaleType::kPower) - .export_values(); - })); - -PYBIND_REGISTER(AngleOperation, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "AngleOperation") - .def(py::init([]() { - auto angle = std::make_shared(); - THROW_IF_ERROR(angle->ValidateParams()); - return angle; - })); - })); - -PYBIND_REGISTER( - BandBiquadOperation, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "BandBiquadOperation") - .def(py::init([](int32_t sample_rate, float central_freq, float Q, bool noise) { - auto band_biquad = std::make_shared(sample_rate, central_freq, Q, noise); - THROW_IF_ERROR(band_biquad->ValidateParams()); - return band_biquad; - })); - })); - -PYBIND_REGISTER( - BandpassBiquadOperation, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "BandpassBiquadOperation") - .def(py::init([](int32_t sample_rate, float central_freq, float Q, bool const_skirt_gain) { - auto bandpass_biquad = - std::make_shared(sample_rate, central_freq, Q, const_skirt_gain); - THROW_IF_ERROR(bandpass_biquad->ValidateParams()); - return bandpass_biquad; - })); - })); - -PYBIND_REGISTER(BandrejectBiquadOperation, 1, ([](const py::module *m) { - (void)py::class_>(*m, "BandrejectBiquadOperation") - .def(py::init([](int32_t sample_rate, float central_freq, float Q) { - auto bandreject_biquad = - std::make_shared(sample_rate, central_freq, Q); - THROW_IF_ERROR(bandreject_biquad->ValidateParams()); - return bandreject_biquad; - })); - })); - -PYBIND_REGISTER( - BassBiquadOperation, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "BassBiquadOperation") - .def(py::init([](int32_t sample_rate, float gain, float central_freq, float Q) { - auto bass_biquad = std::make_shared(sample_rate, gain, central_freq, Q); - THROW_IF_ERROR(bass_biquad->ValidateParams()); - return bass_biquad; - })); - })); - -PYBIND_REGISTER(BiquadOperation, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "BiquadOperation") - .def(py::init([](float b0, float b1, float b2, float a0, float a1, float a2) { - auto biquad = std::make_shared(b0, b1, b2, a0, a1, a2); - THROW_IF_ERROR(biquad->ValidateParams()); - return biquad; - })); - })); - -PYBIND_REGISTER( - ComplexNormOperation, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "ComplexNormOperation") - .def(py::init([](float power) { - auto complex_norm = std::make_shared(power); - THROW_IF_ERROR(complex_norm->ValidateParams()); - return complex_norm; - })); - })); - -PYBIND_REGISTER( - ComputeDeltasOperation, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "ComputeDeltasOperation") - .def(py::init([](int32_t win_length, BorderType pad_mode) { - auto compute_deltas = std::make_shared(win_length, pad_mode); - THROW_IF_ERROR(compute_deltas->ValidateParams()); - return compute_deltas; - })); - })); - -PYBIND_REGISTER(ContrastOperation, 1, ([](const py::module *m) { - (void) - py::class_>( - *m, "ContrastOperation") - .def(py::init([](float enhancement_amount) { - auto contrast = std::make_shared(enhancement_amount); - THROW_IF_ERROR(contrast->ValidateParams()); - return contrast; - })); - })); - -PYBIND_REGISTER( - DBToAmplitudeOperation, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "DBToAmplitudeOperation") - .def(py::init([](float ref, float power) { - auto db_to_amplitude = std::make_shared(ref, power); - THROW_IF_ERROR(db_to_amplitude->ValidateParams()); - return db_to_amplitude; - })); - })); - -PYBIND_REGISTER(DCShiftOperation, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "DCShiftOperation") - .def(py::init([](float shift, float limiter_gain) { - auto dc_shift = std::make_shared(shift, limiter_gain); - THROW_IF_ERROR(dc_shift->ValidateParams()); - return dc_shift; - })); - })); - -PYBIND_REGISTER( - DeemphBiquadOperation, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "DeemphBiquadOperation") - .def(py::init([](int32_t sample_rate) { - auto deemph_biquad = std::make_shared(sample_rate); - THROW_IF_ERROR(deemph_biquad->ValidateParams()); - return deemph_biquad; - })); - })); - -PYBIND_REGISTER(DetectPitchFrequencyOperation, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "DetectPitchFrequencyOperation") - .def(py::init([](int32_t sample_rate, float frame_time, int32_t win_length, int32_t freq_low, - int32_t freq_high) { - auto detect_pitch_frequency = std::make_shared( - sample_rate, frame_time, win_length, freq_low, freq_high); - THROW_IF_ERROR(detect_pitch_frequency->ValidateParams()); - return detect_pitch_frequency; - })); - })); - -PYBIND_REGISTER(DensityFunction, 0, ([](const py::module *m) { - (void)py::enum_(*m, "DensityFunction", py::arithmetic()) - .value("DE_DENSITY_FUNCTION_TPDF", DensityFunction::kTPDF) - .value("DE_DENSITY_FUNCTION_RPDF", DensityFunction::kRPDF) - .value("DE_DENSITY_FUNCTION_GPDF", DensityFunction::kGPDF) - .export_values(); - })); - -PYBIND_REGISTER(DitherOperation, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "DitherOperation") - .def(py::init([](DensityFunction density_function, bool noise_shaping) { - auto dither = std::make_shared(density_function, noise_shaping); - THROW_IF_ERROR(dither->ValidateParams()); - return dither; - })); - })); - -PYBIND_REGISTER(EqualizerBiquadOperation, 1, ([](const py::module *m) { - (void)py::class_>(*m, "EqualizerBiquadOperation") - .def(py::init([](int sample_rate, float center_freq, float gain, float Q) { - auto equalizer_biquad = - std::make_shared(sample_rate, center_freq, gain, Q); - THROW_IF_ERROR(equalizer_biquad->ValidateParams()); - return equalizer_biquad; - })); - })); - -PYBIND_REGISTER(FadeShape, 0, ([](const py::module *m) { - (void)py::enum_(*m, "FadeShape", py::arithmetic()) - .value("DE_FADE_SHAPE_LINEAR", FadeShape::kLinear) - .value("DE_FADE_SHAPE_EXPONENTIAL", FadeShape::kExponential) - .value("DE_FADE_SHAPE_LOGARITHMIC", FadeShape::kLogarithmic) - .value("DE_FADE_SHAPE_QUARTER_SINE", FadeShape::kQuarterSine) - .value("DE_FADE_SHAPE_HALF_SINE", FadeShape::kHalfSine) - .export_values(); - })); - -PYBIND_REGISTER(FadeOperation, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "FadeOperation") - .def(py::init([](int fade_in_len, int fade_out_len, FadeShape fade_shape) { - auto fade = std::make_shared(fade_in_len, fade_out_len, fade_shape); - THROW_IF_ERROR(fade->ValidateParams()); - return fade; - })); - })); - -PYBIND_REGISTER(FiltfiltOperation, 1, ([](const py::module *m) { - (void) - py::class_>( - *m, "FiltfiltOperation") - .def(py::init([](const std::vector &a_coeffs, std::vector &b_coeffs, bool clamp) { - auto filtfilt = std::make_shared(a_coeffs, b_coeffs, clamp); - THROW_IF_ERROR(filtfilt->ValidateParams()); - return filtfilt; - })); - })); - -PYBIND_REGISTER(Modulation, 0, ([](const py::module *m) { - (void)py::enum_(*m, "Modulation", py::arithmetic()) - .value("DE_MODULATION_SINUSOIDAL", Modulation::kSinusoidal) - .value("DE_MODULATION_TRIANGULAR", Modulation::kTriangular) - .export_values(); - })); - -PYBIND_REGISTER(Interpolation, 0, ([](const py::module *m) { - (void)py::enum_(*m, "Interpolation", py::arithmetic()) - .value("DE_INTERPOLATION_LINEAR", Interpolation::kLinear) - .value("DE_INTERPOLATION_QUADRATIC", Interpolation::kQuadratic) - .export_values(); - })); - -PYBIND_REGISTER(FlangerOperation, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "FlangerOperation") - .def(py::init([](int32_t sample_rate, float delay, float depth, float regen, float width, - float speed, float phase, Modulation modulation, Interpolation interpolation) { - auto flanger = std::make_shared(sample_rate, delay, depth, regen, width, - speed, phase, modulation, interpolation); - THROW_IF_ERROR(flanger->ValidateParams()); - return flanger; - })); - })); - -PYBIND_REGISTER( - FrequencyMaskingOperation, 1, ([](const py::module *m) { - (void) - py::class_>( - *m, "FrequencyMaskingOperation") - .def(py::init([](bool iid_masks, int32_t frequency_mask_param, int32_t mask_start, float mask_value) { - auto frequency_masking = - std::make_shared(iid_masks, frequency_mask_param, mask_start, mask_value); - THROW_IF_ERROR(frequency_masking->ValidateParams()); - return frequency_masking; - })); - })); - -PYBIND_REGISTER(GainOperation, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "GainOperation") - .def(py::init([](float gain_db) { - auto gain = std::make_shared(gain_db); - THROW_IF_ERROR(gain->ValidateParams()); - return gain; - })); - })); - -PYBIND_REGISTER( - GriffinLimOperation, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "GriffinLimOperation") - .def(py::init([](int32_t n_fft, int32_t n_iter, int32_t win_length, int32_t hop_length, WindowType window_type, - float power, float momentum, int32_t length, bool rand_init) { - auto griffin_lim = std::make_shared( - n_fft, n_iter, win_length, hop_length, window_type, power, momentum, length, rand_init); - THROW_IF_ERROR(griffin_lim->ValidateParams()); - return griffin_lim; - })); - })); - -PYBIND_REGISTER( - HighpassBiquadOperation, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "HighpassBiquadOperation") - .def(py::init([](float sample_rate, float cutoff_freq, float Q) { - auto highpass_biquad = std::make_shared(sample_rate, cutoff_freq, Q); - THROW_IF_ERROR(highpass_biquad->ValidateParams()); - return highpass_biquad; - })); - })); - -PYBIND_REGISTER(InverseMelScaleOperation, 1, ([](const py::module *m) { - (void)py::class_>(*m, "InverseMelScaleOperation") - .def(py::init([](int32_t n_stft, int32_t n_mels, int32_t sample_rate, float f_min, float f_max, - int32_t max_iter, float tolerance_loss, float tolerance_change, - const py::dict &sgdargs, NormType norm, MelType mel_type) { - auto inverse_mel_scale = std::make_shared( - n_stft, n_mels, sample_rate, f_min, f_max, max_iter, tolerance_loss, tolerance_change, - toStringFloatMap(sgdargs), norm, mel_type); - THROW_IF_ERROR(inverse_mel_scale->ValidateParams()); - return inverse_mel_scale; - })); - })); - -PYBIND_REGISTER(InverseSpectrogramOperation, 1, ([](const py::module *m) { - (void)py::class_>(*m, - "InverseSpectrogramOperation") - .def( - py::init([](int32_t length, int32_t n_fft, int32_t win_length, int32_t hop_length, int32_t pad, - WindowType window, bool normalized, bool center, BorderType pad_mode, bool onesided) { - auto inverse_spectrogram = std::make_shared( - length, n_fft, win_length, hop_length, pad, window, normalized, center, pad_mode, onesided); - THROW_IF_ERROR(inverse_spectrogram->ValidateParams()); - return inverse_spectrogram; - })); - })); - -PYBIND_REGISTER(LFCCOperation, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "LFCCOperation") - .def(py::init([](int32_t sample_rate, int32_t n_filter, int32_t n_lfcc, float f_min, float f_max, - int32_t dct_type, NormMode norm, bool log_lf, const py::dict &speckwargs, - WindowType window, BorderType pad_mode) { - int32_t n_fft = py::cast(speckwargs["n_fft"]); - int32_t win_length = py::cast(speckwargs["win_length"]); - int32_t hop_length = py::cast(speckwargs["hop_length"]); - int32_t pad = py::cast(speckwargs["pad"]); - float power = py::cast(speckwargs["power"]); - bool normalized = py::cast(speckwargs["normalized"]); - bool center = py::cast(speckwargs["center"]); - bool onesided = py::cast(speckwargs["onesided"]); - auto lfcc = std::make_shared( - sample_rate, n_filter, n_lfcc, f_min, f_max, dct_type, norm, log_lf, n_fft, win_length, - hop_length, pad, window, power, normalized, center, pad_mode, onesided); - THROW_IF_ERROR(lfcc->ValidateParams()); - return lfcc; - })); - })); - -PYBIND_REGISTER(LFilterOperation, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "LFilterOperation") - .def(py::init([](std::vector a_coeffs, std::vector b_coeffs, bool clamp) { - auto lfilter = std::make_shared(a_coeffs, b_coeffs, clamp); - THROW_IF_ERROR(lfilter->ValidateParams()); - return lfilter; - })); - })); - -PYBIND_REGISTER( - LowpassBiquadOperation, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "LowpassBiquadOperation") - .def(py::init([](int sample_rate, float cutoff_freq, float Q) { - auto lowpass_biquad = std::make_shared(sample_rate, cutoff_freq, Q); - THROW_IF_ERROR(lowpass_biquad->ValidateParams()); - return lowpass_biquad; - })); - })); - -PYBIND_REGISTER(MagphaseOperation, 1, ([](const py::module *m) { - (void) - py::class_>( - *m, "MagphaseOperation") - .def(py::init([](float power) { - auto magphase = std::make_shared(power); - THROW_IF_ERROR(magphase->ValidateParams()); - return magphase; - })); - })); - -PYBIND_REGISTER(MaskAlongAxisIIDOperation, 1, ([](const py::module *m) { - (void)py::class_>(*m, "MaskAlongAxisIIDOperation") - .def(py::init([](int32_t mask_param, float mask_value, int32_t axis) { - auto mask_along_axis_iid = - std::make_shared(mask_param, mask_value, axis); - THROW_IF_ERROR(mask_along_axis_iid->ValidateParams()); - return mask_along_axis_iid; - })); - })); - -PYBIND_REGISTER( - MaskAlongAxisOperation, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "MaskAlongAxisOperation") - .def(py::init([](int32_t mask_start, int32_t mask_width, float mask_value, int32_t axis) { - auto mask_along_axis = - std::make_shared(mask_start, mask_width, mask_value, axis); - THROW_IF_ERROR(mask_along_axis->ValidateParams()); - return mask_along_axis; - })); - })); - -PYBIND_REGISTER(MelScaleOperation, 1, ([](const py::module *m) { - (void) - py::class_>( - *m, "MelScaleOperation") - .def(py::init([](int32_t n_mels, int32_t sample_rate, float f_min, float f_max, int32_t n_stft, - NormType norm, MelType mel_type) { - auto mel_scale = std::make_shared(n_mels, sample_rate, f_min, f_max, - n_stft, norm, mel_type); - THROW_IF_ERROR(mel_scale->ValidateParams()); - return mel_scale; - })); - })); - -PYBIND_REGISTER( - MelSpectrogramOperation, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "MelSpectrogramOperation") - .def(py::init([](int32_t sample_rate, int32_t n_fft, int32_t win_length, int32_t hop_length, float f_min, - float f_max, int32_t pad, int32_t n_mels, WindowType window, float power, bool normalized, - bool center, BorderType pad_mode, bool onesided, NormType norm, MelType mel_scale) { - auto mel_spectrogram = std::make_shared( - sample_rate, n_fft, win_length, hop_length, f_min, f_max, pad, n_mels, window, power, normalized, center, - pad_mode, onesided, norm, mel_scale); - THROW_IF_ERROR(mel_spectrogram->ValidateParams()); - return mel_spectrogram; - })); - })); - -PYBIND_REGISTER(MFCCOperation, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "MFCCOperation") - .def(py::init([](int32_t sample_rate, int32_t n_mfcc, int32_t dct_type, NormMode norm, - bool log_mels, const py::dict &melkwargs, WindowType window, BorderType pad_mode, - NormType norm_mel, MelType mel_scale) { - int32_t n_fft = py::cast(melkwargs["n_fft"]); - int32_t win_length = py::cast(melkwargs["win_length"]); - int32_t hop_length = py::cast(melkwargs["hop_length"]); - float f_min = py::cast(melkwargs["f_min"]); - float f_max = py::cast(melkwargs["f_max"]); - int32_t pad = py::cast(melkwargs["pad"]); - int32_t n_mels = py::cast(melkwargs["n_mels"]); - float power = py::cast(melkwargs["power"]); - bool normalized = py::cast(melkwargs["normalized"]); - bool center = py::cast(melkwargs["center"]); - bool onesided = py::cast(melkwargs["onesided"]); - auto mfcc = std::make_shared( - sample_rate, n_mfcc, dct_type, norm, log_mels, n_fft, win_length, hop_length, f_min, f_max, pad, - n_mels, window, power, normalized, center, pad_mode, onesided, norm_mel, mel_scale); - THROW_IF_ERROR(mfcc->ValidateParams()); - return mfcc; - })); - })); - -PYBIND_REGISTER( - MuLawDecodingOperation, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "MuLawDecodingOperation") - .def(py::init([](int32_t quantization_channels) { - auto mu_law_decoding = std::make_shared(quantization_channels); - THROW_IF_ERROR(mu_law_decoding->ValidateParams()); - return mu_law_decoding; - })); - })); - -PYBIND_REGISTER( - MuLawEncodingOperation, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "MuLawEncodingOperation") - .def(py::init([](int32_t quantization_channels) { - auto mu_law_encoding = std::make_shared(quantization_channels); - THROW_IF_ERROR(mu_law_encoding->ValidateParams()); - return mu_law_encoding; - })); - })); - -PYBIND_REGISTER(OverdriveOperation, 1, ([](const py::module *m) { - (void) - py::class_>( - *m, "OverdriveOperation") - .def(py::init([](float gain, float color) { - auto overdrive = std::make_shared(gain, color); - THROW_IF_ERROR(overdrive->ValidateParams()); - return overdrive; - })); - })); - -PYBIND_REGISTER(PhaserOperation, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "PhaserOperation") - .def(py::init([](int32_t sample_rate, float gain_in, float gain_out, float delay_ms, float decay, - float mod_speed, bool sinusoidal) { - auto phaser = std::make_shared(sample_rate, gain_in, gain_out, delay_ms, - decay, mod_speed, sinusoidal); - THROW_IF_ERROR(phaser->ValidateParams()); - return phaser; - })); - })); - -PYBIND_REGISTER( - PhaseVocoderOperation, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "PhaseVocoderOperation") - .def(py::init([](float rate, const std::shared_ptr &phase_advance) { - auto phase_vocoder = std::make_shared(rate, phase_advance); - THROW_IF_ERROR(phase_vocoder->ValidateParams()); - return phase_vocoder; - })); - })); - -PYBIND_REGISTER( - PitchShiftOperation, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "PitchShiftOperation") - .def(py::init([](int32_t sample_rate, int32_t n_steps, int32_t bins_per_octave, int32_t n_fft, int32_t win_length, - int32_t hop_length, WindowType window) { - auto pitch_shift = std::make_shared(sample_rate, n_steps, bins_per_octave, n_fft, - win_length, hop_length, window); - THROW_IF_ERROR(pitch_shift->ValidateParams()); - return pitch_shift; - })); - })); - -PYBIND_REGISTER(ResampleMethod, 0, ([](const py::module *m) { - (void)py::enum_(*m, "ResampleMethod", py::arithmetic()) - .value("DE_RESAMPLE_SINC_INTERPOLATION", ResampleMethod::kSincInterpolation) - .value("DE_RESAMPLE_KAISER_WINDOW", ResampleMethod::kKaiserWindow) - .export_values(); - })); - -PYBIND_REGISTER(ResampleOperation, 1, ([](const py::module *m) { - (void) - py::class_>( - *m, "ResampleOperation") - .def(py::init([](float orig_freq, float new_freq, ResampleMethod resample_method, - int32_t lowpass_filter_width, float rolloff, float beta) { - auto resample = std::make_shared(orig_freq, new_freq, resample_method, - lowpass_filter_width, rolloff, beta); - THROW_IF_ERROR(resample->ValidateParams()); - return resample; - })); - })); - -PYBIND_REGISTER( - RiaaBiquadOperation, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "RiaaBiquadOperation") - .def(py::init([](int32_t sample_rate) { - auto riaa_biquad = std::make_shared(sample_rate); - THROW_IF_ERROR(riaa_biquad->ValidateParams()); - return riaa_biquad; - })); - })); - -PYBIND_REGISTER(SlidingWindowCmnOperation, 1, ([](const py::module *m) { - (void)py::class_>(*m, "SlidingWindowCmnOperation") - .def(py::init([](int32_t cmn_window, int32_t min_cmn_window, bool center, bool norm_vars) { - auto sliding_window_cmn = std::make_shared( - cmn_window, min_cmn_window, center, norm_vars); - THROW_IF_ERROR(sliding_window_cmn->ValidateParams()); - return sliding_window_cmn; - })); - })); - -PYBIND_REGISTER(WindowType, 0, ([](const py::module *m) { - (void)py::enum_(*m, "WindowType", py::arithmetic()) - .value("DE_WINDOW_TYPE_BARTLETT", WindowType::kBartlett) - .value("DE_WINDOW_TYPE_BLACKMAN", WindowType::kBlackman) - .value("DE_WINDOW_TYPE_HAMMING", WindowType::kHamming) - .value("DE_WINDOW_TYPE_HANN", WindowType::kHann) - .value("DE_WINDOW_TYPE_KAISER", WindowType::kKaiser) - .export_values(); - })); - -PYBIND_REGISTER( - SpectralCentroidOperation, 1, ([](const py::module *m) { - (void) - py::class_>( - *m, "SpectralCentroidOperation") - .def(py::init([](int sample_rate, int n_fft, int win_length, int hop_length, int pad, WindowType window) { - auto spectral_centroid = - std::make_shared(sample_rate, n_fft, win_length, hop_length, pad, window); - THROW_IF_ERROR(spectral_centroid->ValidateParams()); - return spectral_centroid; - })); - })); - -PYBIND_REGISTER( - SpectrogramOperation, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "SpectrogramOperation") - .def(py::init([](int32_t n_fft, int32_t win_length, int32_t hop_length, int32_t pad, WindowType window, - float power, bool normalized, bool center, BorderType pad_mode, bool onesided) { - auto spectrogram = std::make_shared(n_fft, win_length, hop_length, pad, window, - power, normalized, center, pad_mode, onesided); - THROW_IF_ERROR(spectrogram->ValidateParams()); - return spectrogram; - })); - })); - -PYBIND_REGISTER( - TimeMaskingOperation, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "TimeMaskingOperation") - .def(py::init([](bool iid_masks, int32_t time_mask_param, int32_t mask_start, float mask_value) { - auto time_masking = - std::make_shared(iid_masks, time_mask_param, mask_start, mask_value); - THROW_IF_ERROR(time_masking->ValidateParams()); - return time_masking; - })); - })); - -PYBIND_REGISTER( - TimeStretchOperation, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "TimeStretchOperation") - .def(py::init([](float hop_length, int n_freq, float fixed_rate) { - auto timestretch = std::make_shared(hop_length, n_freq, fixed_rate); - THROW_IF_ERROR(timestretch->ValidateParams()); - return timestretch; - })); - })); - -PYBIND_REGISTER( - TrebleBiquadOperation, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "TrebleBiquadOperation") - .def(py::init([](int32_t sample_rate, float gain, float central_freq, float Q) { - auto treble_biquad = std::make_shared(sample_rate, gain, central_freq, Q); - THROW_IF_ERROR(treble_biquad->ValidateParams()); - return treble_biquad; - })); - })); - -PYBIND_REGISTER(VadOperation, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "VadOperation") - .def(py::init([](int32_t sample_rate, float trigger_level, float trigger_time, float search_time, - float allowed_gap, float pre_trigger_time, float boot_time, float noise_up_time, - float noise_down_time, float noise_reduction_amount, float measure_freq, - float measure_duration, float measure_smooth_time, float hp_filter_freq, - float lp_filter_freq, float hp_lifter_freq, float lp_lifter_freq) { - auto vad = std::make_shared( - sample_rate, trigger_level, trigger_time, search_time, allowed_gap, pre_trigger_time, boot_time, - noise_up_time, noise_down_time, noise_reduction_amount, measure_freq, measure_duration, - measure_smooth_time, hp_filter_freq, lp_filter_freq, hp_lifter_freq, lp_lifter_freq); - THROW_IF_ERROR(vad->ValidateParams()); - return vad; - })); - })); - -PYBIND_REGISTER(VolOperation, 1, ([](const py::module *m) { - (void)py::class_>( - *m, "VolOperation") - .def(py::init([](float gain, GainType gain_type) { - auto vol = std::make_shared(gain, gain_type); - THROW_IF_ERROR(vol->ValidateParams()); - return vol; - })); - })); - -PYBIND_REGISTER(GainType, 0, ([](const py::module *m) { - (void)py::enum_(*m, "GainType", py::arithmetic()) - .value("DE_GAIN_TYPE_AMPLITUDE", GainType::kAmplitude) - .value("DE_GAIN_TYPE_POWER", GainType::kPower) - .value("DE_GAIN_TYPE_DB", GainType::kDb) - .export_values(); - })); -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pybind11/pybind11.h" + +#include "minddata/dataset/api/python/pybind_conversion.h" +#include "minddata/dataset/api/python/pybind_register.h" +#include "minddata/dataset/include/dataset/transforms.h" + +#include "minddata/dataset/audio/ir/kernels/allpass_biquad_ir.h" +#include "minddata/dataset/audio/ir/kernels/amplitude_to_db_ir.h" +#include "minddata/dataset/audio/ir/kernels/angle_ir.h" +#include "minddata/dataset/audio/ir/kernels/band_biquad_ir.h" +#include "minddata/dataset/audio/ir/kernels/bandpass_biquad_ir.h" +#include "minddata/dataset/audio/ir/kernels/bandreject_biquad_ir.h" +#include "minddata/dataset/audio/ir/kernels/bass_biquad_ir.h" +#include "minddata/dataset/audio/ir/kernels/biquad_ir.h" +#include "minddata/dataset/audio/ir/kernels/complex_norm_ir.h" +#include "minddata/dataset/audio/ir/kernels/compute_deltas_ir.h" +#include "minddata/dataset/audio/ir/kernels/contrast_ir.h" +#include "minddata/dataset/audio/ir/kernels/db_to_amplitude_ir.h" +#include "minddata/dataset/audio/ir/kernels/dc_shift_ir.h" +#include "minddata/dataset/audio/ir/kernels/deemph_biquad_ir.h" +#include "minddata/dataset/audio/ir/kernels/detect_pitch_frequency_ir.h" +#include "minddata/dataset/audio/ir/kernels/dither_ir.h" +#include "minddata/dataset/audio/ir/kernels/equalizer_biquad_ir.h" +#include "minddata/dataset/audio/ir/kernels/fade_ir.h" +#include "minddata/dataset/audio/ir/kernels/filtfilt_ir.h" +#include "minddata/dataset/audio/ir/kernels/flanger_ir.h" +#include "minddata/dataset/audio/ir/kernels/frequency_masking_ir.h" +#include "minddata/dataset/audio/ir/kernels/gain_ir.h" +#include "minddata/dataset/audio/ir/kernels/griffin_lim_ir.h" +#include "minddata/dataset/audio/ir/kernels/highpass_biquad_ir.h" +#include "minddata/dataset/audio/ir/kernels/inverse_mel_scale_ir.h" +#include "minddata/dataset/audio/ir/kernels/inverse_spectrogram_ir.h" +#include "minddata/dataset/audio/ir/kernels/lfcc_ir.h" +#include "minddata/dataset/audio/ir/kernels/lfilter_ir.h" +#include "minddata/dataset/audio/ir/kernels/lowpass_biquad_ir.h" +#include "minddata/dataset/audio/ir/kernels/magphase_ir.h" +#include "minddata/dataset/audio/ir/kernels/mask_along_axis_iid_ir.h" +#include "minddata/dataset/audio/ir/kernels/mask_along_axis_ir.h" +#include "minddata/dataset/audio/ir/kernels/mel_scale_ir.h" +#include "minddata/dataset/audio/ir/kernels/mel_spectrogram_ir.h" +#include "minddata/dataset/audio/ir/kernels/mfcc_ir.h" +#include "minddata/dataset/audio/ir/kernels/mu_law_decoding_ir.h" +#include "minddata/dataset/audio/ir/kernels/mu_law_encoding_ir.h" +#include "minddata/dataset/audio/ir/kernels/overdrive_ir.h" +#include "minddata/dataset/audio/ir/kernels/phase_vocoder_ir.h" +#include "minddata/dataset/audio/ir/kernels/phaser_ir.h" +#include "minddata/dataset/audio/ir/kernels/pitch_shift_ir.h" +#include "minddata/dataset/audio/ir/kernels/resample_ir.h" +#include "minddata/dataset/audio/ir/kernels/riaa_biquad_ir.h" +#include "minddata/dataset/audio/ir/kernels/sliding_window_cmn_ir.h" +#include "minddata/dataset/audio/ir/kernels/spectral_centroid_ir.h" +#include "minddata/dataset/audio/ir/kernels/spectrogram_ir.h" +#include "minddata/dataset/audio/ir/kernels/time_masking_ir.h" +#include "minddata/dataset/audio/ir/kernels/time_stretch_ir.h" +#include "minddata/dataset/audio/ir/kernels/treble_biquad_ir.h" +#include "minddata/dataset/audio/ir/kernels/vad_ir.h" +#include "minddata/dataset/audio/ir/kernels/vol_ir.h" + +namespace mindspore { +namespace dataset { +PYBIND_REGISTER( + AllpassBiquadOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "AllpassBiquadOperation") + .def(py::init([](int32_t sample_rate, float central_freq, float Q) { + auto allpass_biquad = std::make_shared(sample_rate, central_freq, Q); + THROW_IF_ERROR(allpass_biquad->ValidateParams()); + return allpass_biquad; + })); + })); + +PYBIND_REGISTER( + AmplitudeToDBOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "AmplitudeToDBOperation") + .def(py::init([](ScaleType stype, float ref_value, float amin, float top_db) { + auto amplitude_to_db = std::make_shared(stype, ref_value, amin, top_db); + THROW_IF_ERROR(amplitude_to_db->ValidateParams()); + return amplitude_to_db; + })); + })); + +PYBIND_REGISTER(ScaleType, 0, ([](const py::module *m) { + (void)py::enum_(*m, "ScaleType", py::arithmetic()) + .value("DE_SCALE_TYPE_MAGNITUDE", ScaleType::kMagnitude) + .value("DE_SCALE_TYPE_POWER", ScaleType::kPower) + .export_values(); + })); + +PYBIND_REGISTER(AngleOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "AngleOperation") + .def(py::init([]() { + auto angle = std::make_shared(); + THROW_IF_ERROR(angle->ValidateParams()); + return angle; + })); + })); + +PYBIND_REGISTER( + BandBiquadOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "BandBiquadOperation") + .def(py::init([](int32_t sample_rate, float central_freq, float Q, bool noise) { + auto band_biquad = std::make_shared(sample_rate, central_freq, Q, noise); + THROW_IF_ERROR(band_biquad->ValidateParams()); + return band_biquad; + })); + })); + +PYBIND_REGISTER( + BandpassBiquadOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "BandpassBiquadOperation") + .def(py::init([](int32_t sample_rate, float central_freq, float Q, bool const_skirt_gain) { + auto bandpass_biquad = + std::make_shared(sample_rate, central_freq, Q, const_skirt_gain); + THROW_IF_ERROR(bandpass_biquad->ValidateParams()); + return bandpass_biquad; + })); + })); + +PYBIND_REGISTER(BandrejectBiquadOperation, 1, ([](const py::module *m) { + (void)py::class_>(*m, "BandrejectBiquadOperation") + .def(py::init([](int32_t sample_rate, float central_freq, float Q) { + auto bandreject_biquad = + std::make_shared(sample_rate, central_freq, Q); + THROW_IF_ERROR(bandreject_biquad->ValidateParams()); + return bandreject_biquad; + })); + })); + +PYBIND_REGISTER( + BassBiquadOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "BassBiquadOperation") + .def(py::init([](int32_t sample_rate, float gain, float central_freq, float Q) { + auto bass_biquad = std::make_shared(sample_rate, gain, central_freq, Q); + THROW_IF_ERROR(bass_biquad->ValidateParams()); + return bass_biquad; + })); + })); + +PYBIND_REGISTER(BiquadOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "BiquadOperation") + .def(py::init([](float b0, float b1, float b2, float a0, float a1, float a2) { + auto biquad = std::make_shared(b0, b1, b2, a0, a1, a2); + THROW_IF_ERROR(biquad->ValidateParams()); + return biquad; + })); + })); + +PYBIND_REGISTER( + ComplexNormOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "ComplexNormOperation") + .def(py::init([](float power) { + auto complex_norm = std::make_shared(power); + THROW_IF_ERROR(complex_norm->ValidateParams()); + return complex_norm; + })); + })); + +PYBIND_REGISTER( + ComputeDeltasOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "ComputeDeltasOperation") + .def(py::init([](int32_t win_length, BorderType pad_mode) { + auto compute_deltas = std::make_shared(win_length, pad_mode); + THROW_IF_ERROR(compute_deltas->ValidateParams()); + return compute_deltas; + })); + })); + +PYBIND_REGISTER(ContrastOperation, 1, ([](const py::module *m) { + (void) + py::class_>( + *m, "ContrastOperation") + .def(py::init([](float enhancement_amount) { + auto contrast = std::make_shared(enhancement_amount); + THROW_IF_ERROR(contrast->ValidateParams()); + return contrast; + })); + })); + +PYBIND_REGISTER( + DBToAmplitudeOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "DBToAmplitudeOperation") + .def(py::init([](float ref, float power) { + auto db_to_amplitude = std::make_shared(ref, power); + THROW_IF_ERROR(db_to_amplitude->ValidateParams()); + return db_to_amplitude; + })); + })); + +PYBIND_REGISTER(DCShiftOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "DCShiftOperation") + .def(py::init([](float shift, float limiter_gain) { + auto dc_shift = std::make_shared(shift, limiter_gain); + THROW_IF_ERROR(dc_shift->ValidateParams()); + return dc_shift; + })); + })); + +PYBIND_REGISTER( + DeemphBiquadOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "DeemphBiquadOperation") + .def(py::init([](int32_t sample_rate) { + auto deemph_biquad = std::make_shared(sample_rate); + THROW_IF_ERROR(deemph_biquad->ValidateParams()); + return deemph_biquad; + })); + })); + +PYBIND_REGISTER(DetectPitchFrequencyOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "DetectPitchFrequencyOperation") + .def(py::init([](int32_t sample_rate, float frame_time, int32_t win_length, int32_t freq_low, + int32_t freq_high) { + auto detect_pitch_frequency = std::make_shared( + sample_rate, frame_time, win_length, freq_low, freq_high); + THROW_IF_ERROR(detect_pitch_frequency->ValidateParams()); + return detect_pitch_frequency; + })); + })); + +PYBIND_REGISTER(DensityFunction, 0, ([](const py::module *m) { + (void)py::enum_(*m, "DensityFunction", py::arithmetic()) + .value("DE_DENSITY_FUNCTION_TPDF", DensityFunction::kTPDF) + .value("DE_DENSITY_FUNCTION_RPDF", DensityFunction::kRPDF) + .value("DE_DENSITY_FUNCTION_GPDF", DensityFunction::kGPDF) + .export_values(); + })); + +PYBIND_REGISTER(DitherOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "DitherOperation") + .def(py::init([](DensityFunction density_function, bool noise_shaping) { + auto dither = std::make_shared(density_function, noise_shaping); + THROW_IF_ERROR(dither->ValidateParams()); + return dither; + })); + })); + +PYBIND_REGISTER(EqualizerBiquadOperation, 1, ([](const py::module *m) { + (void)py::class_>(*m, "EqualizerBiquadOperation") + .def(py::init([](int sample_rate, float center_freq, float gain, float Q) { + auto equalizer_biquad = + std::make_shared(sample_rate, center_freq, gain, Q); + THROW_IF_ERROR(equalizer_biquad->ValidateParams()); + return equalizer_biquad; + })); + })); + +PYBIND_REGISTER(FadeShape, 0, ([](const py::module *m) { + (void)py::enum_(*m, "FadeShape", py::arithmetic()) + .value("DE_FADE_SHAPE_LINEAR", FadeShape::kLinear) + .value("DE_FADE_SHAPE_EXPONENTIAL", FadeShape::kExponential) + .value("DE_FADE_SHAPE_LOGARITHMIC", FadeShape::kLogarithmic) + .value("DE_FADE_SHAPE_QUARTER_SINE", FadeShape::kQuarterSine) + .value("DE_FADE_SHAPE_HALF_SINE", FadeShape::kHalfSine) + .export_values(); + })); + +PYBIND_REGISTER(FadeOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "FadeOperation") + .def(py::init([](int fade_in_len, int fade_out_len, FadeShape fade_shape) { + auto fade = std::make_shared(fade_in_len, fade_out_len, fade_shape); + THROW_IF_ERROR(fade->ValidateParams()); + return fade; + })); + })); + +PYBIND_REGISTER(FiltfiltOperation, 1, ([](const py::module *m) { + (void) + py::class_>( + *m, "FiltfiltOperation") + .def(py::init([](const std::vector &a_coeffs, std::vector &b_coeffs, bool clamp) { + auto filtfilt = std::make_shared(a_coeffs, b_coeffs, clamp); + THROW_IF_ERROR(filtfilt->ValidateParams()); + return filtfilt; + })); + })); + +PYBIND_REGISTER(Modulation, 0, ([](const py::module *m) { + (void)py::enum_(*m, "Modulation", py::arithmetic()) + .value("DE_MODULATION_SINUSOIDAL", Modulation::kSinusoidal) + .value("DE_MODULATION_TRIANGULAR", Modulation::kTriangular) + .export_values(); + })); + +PYBIND_REGISTER(Interpolation, 0, ([](const py::module *m) { + (void)py::enum_(*m, "Interpolation", py::arithmetic()) + .value("DE_INTERPOLATION_LINEAR", Interpolation::kLinear) + .value("DE_INTERPOLATION_QUADRATIC", Interpolation::kQuadratic) + .export_values(); + })); + +PYBIND_REGISTER(FlangerOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "FlangerOperation") + .def(py::init([](int32_t sample_rate, float delay, float depth, float regen, float width, + float speed, float phase, Modulation modulation, Interpolation interpolation) { + auto flanger = std::make_shared(sample_rate, delay, depth, regen, width, + speed, phase, modulation, interpolation); + THROW_IF_ERROR(flanger->ValidateParams()); + return flanger; + })); + })); + +PYBIND_REGISTER( + FrequencyMaskingOperation, 1, ([](const py::module *m) { + (void) + py::class_>( + *m, "FrequencyMaskingOperation") + .def(py::init([](bool iid_masks, int32_t frequency_mask_param, int32_t mask_start, float mask_value) { + auto frequency_masking = + std::make_shared(iid_masks, frequency_mask_param, mask_start, mask_value); + THROW_IF_ERROR(frequency_masking->ValidateParams()); + return frequency_masking; + })); + })); + +PYBIND_REGISTER(GainOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "GainOperation") + .def(py::init([](float gain_db) { + auto gain = std::make_shared(gain_db); + THROW_IF_ERROR(gain->ValidateParams()); + return gain; + })); + })); + +PYBIND_REGISTER( + GriffinLimOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "GriffinLimOperation") + .def(py::init([](int32_t n_fft, int32_t n_iter, int32_t win_length, int32_t hop_length, WindowType window_type, + float power, float momentum, int32_t length, bool rand_init) { + auto griffin_lim = std::make_shared( + n_fft, n_iter, win_length, hop_length, window_type, power, momentum, length, rand_init); + THROW_IF_ERROR(griffin_lim->ValidateParams()); + return griffin_lim; + })); + })); + +PYBIND_REGISTER( + HighpassBiquadOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "HighpassBiquadOperation") + .def(py::init([](float sample_rate, float cutoff_freq, float Q) { + auto highpass_biquad = std::make_shared(sample_rate, cutoff_freq, Q); + THROW_IF_ERROR(highpass_biquad->ValidateParams()); + return highpass_biquad; + })); + })); + +PYBIND_REGISTER(InverseMelScaleOperation, 1, ([](const py::module *m) { + (void)py::class_>(*m, "InverseMelScaleOperation") + .def(py::init([](int32_t n_stft, int32_t n_mels, int32_t sample_rate, float f_min, float f_max, + int32_t max_iter, float tolerance_loss, float tolerance_change, + const py::dict &sgdargs, NormType norm, MelType mel_type) { + auto inverse_mel_scale = std::make_shared( + n_stft, n_mels, sample_rate, f_min, f_max, max_iter, tolerance_loss, tolerance_change, + toStringFloatMap(sgdargs), norm, mel_type); + THROW_IF_ERROR(inverse_mel_scale->ValidateParams()); + return inverse_mel_scale; + })); + })); + +PYBIND_REGISTER(InverseSpectrogramOperation, 1, ([](const py::module *m) { + (void)py::class_>(*m, + "InverseSpectrogramOperation") + .def( + py::init([](int32_t length, int32_t n_fft, int32_t win_length, int32_t hop_length, int32_t pad, + WindowType window, bool normalized, bool center, BorderType pad_mode, bool onesided) { + auto inverse_spectrogram = std::make_shared( + length, n_fft, win_length, hop_length, pad, window, normalized, center, pad_mode, onesided); + THROW_IF_ERROR(inverse_spectrogram->ValidateParams()); + return inverse_spectrogram; + })); + })); + +PYBIND_REGISTER(LFCCOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "LFCCOperation") + .def(py::init([](int32_t sample_rate, int32_t n_filter, int32_t n_lfcc, float f_min, float f_max, + int32_t dct_type, NormMode norm, bool log_lf, const py::dict &speckwargs, + WindowType window, BorderType pad_mode) { + int32_t n_fft = py::cast(speckwargs["n_fft"]); + int32_t win_length = py::cast(speckwargs["win_length"]); + int32_t hop_length = py::cast(speckwargs["hop_length"]); + int32_t pad = py::cast(speckwargs["pad"]); + float power = py::cast(speckwargs["power"]); + bool normalized = py::cast(speckwargs["normalized"]); + bool center = py::cast(speckwargs["center"]); + bool onesided = py::cast(speckwargs["onesided"]); + auto lfcc = std::make_shared( + sample_rate, n_filter, n_lfcc, f_min, f_max, dct_type, norm, log_lf, n_fft, win_length, + hop_length, pad, window, power, normalized, center, pad_mode, onesided); + THROW_IF_ERROR(lfcc->ValidateParams()); + return lfcc; + })); + })); + +PYBIND_REGISTER(LFilterOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "LFilterOperation") + .def(py::init([](std::vector a_coeffs, std::vector b_coeffs, bool clamp) { + auto lfilter = std::make_shared(a_coeffs, b_coeffs, clamp); + THROW_IF_ERROR(lfilter->ValidateParams()); + return lfilter; + })); + })); + +PYBIND_REGISTER( + LowpassBiquadOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "LowpassBiquadOperation") + .def(py::init([](int sample_rate, float cutoff_freq, float Q) { + auto lowpass_biquad = std::make_shared(sample_rate, cutoff_freq, Q); + THROW_IF_ERROR(lowpass_biquad->ValidateParams()); + return lowpass_biquad; + })); + })); + +PYBIND_REGISTER(MagphaseOperation, 1, ([](const py::module *m) { + (void) + py::class_>( + *m, "MagphaseOperation") + .def(py::init([](float power) { + auto magphase = std::make_shared(power); + THROW_IF_ERROR(magphase->ValidateParams()); + return magphase; + })); + })); + +PYBIND_REGISTER(MaskAlongAxisIIDOperation, 1, ([](const py::module *m) { + (void)py::class_>(*m, "MaskAlongAxisIIDOperation") + .def(py::init([](int32_t mask_param, float mask_value, int32_t axis) { + auto mask_along_axis_iid = + std::make_shared(mask_param, mask_value, axis); + THROW_IF_ERROR(mask_along_axis_iid->ValidateParams()); + return mask_along_axis_iid; + })); + })); + +PYBIND_REGISTER( + MaskAlongAxisOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "MaskAlongAxisOperation") + .def(py::init([](int32_t mask_start, int32_t mask_width, float mask_value, int32_t axis) { + auto mask_along_axis = + std::make_shared(mask_start, mask_width, mask_value, axis); + THROW_IF_ERROR(mask_along_axis->ValidateParams()); + return mask_along_axis; + })); + })); + +PYBIND_REGISTER(MelScaleOperation, 1, ([](const py::module *m) { + (void) + py::class_>( + *m, "MelScaleOperation") + .def(py::init([](int32_t n_mels, int32_t sample_rate, float f_min, float f_max, int32_t n_stft, + NormType norm, MelType mel_type) { + auto mel_scale = std::make_shared(n_mels, sample_rate, f_min, f_max, + n_stft, norm, mel_type); + THROW_IF_ERROR(mel_scale->ValidateParams()); + return mel_scale; + })); + })); + +PYBIND_REGISTER( + MelSpectrogramOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "MelSpectrogramOperation") + .def(py::init([](int32_t sample_rate, int32_t n_fft, int32_t win_length, int32_t hop_length, float f_min, + float f_max, int32_t pad, int32_t n_mels, WindowType window, float power, bool normalized, + bool center, BorderType pad_mode, bool onesided, NormType norm, MelType mel_scale) { + auto mel_spectrogram = std::make_shared( + sample_rate, n_fft, win_length, hop_length, f_min, f_max, pad, n_mels, window, power, normalized, center, + pad_mode, onesided, norm, mel_scale); + THROW_IF_ERROR(mel_spectrogram->ValidateParams()); + return mel_spectrogram; + })); + })); + +PYBIND_REGISTER(MFCCOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "MFCCOperation") + .def(py::init([](int32_t sample_rate, int32_t n_mfcc, int32_t dct_type, NormMode norm, + bool log_mels, const py::dict &melkwargs, WindowType window, BorderType pad_mode, + NormType norm_mel, MelType mel_scale) { + int32_t n_fft = py::cast(melkwargs["n_fft"]); + int32_t win_length = py::cast(melkwargs["win_length"]); + int32_t hop_length = py::cast(melkwargs["hop_length"]); + float f_min = py::cast(melkwargs["f_min"]); + float f_max = py::cast(melkwargs["f_max"]); + int32_t pad = py::cast(melkwargs["pad"]); + int32_t n_mels = py::cast(melkwargs["n_mels"]); + float power = py::cast(melkwargs["power"]); + bool normalized = py::cast(melkwargs["normalized"]); + bool center = py::cast(melkwargs["center"]); + bool onesided = py::cast(melkwargs["onesided"]); + auto mfcc = std::make_shared( + sample_rate, n_mfcc, dct_type, norm, log_mels, n_fft, win_length, hop_length, f_min, f_max, pad, + n_mels, window, power, normalized, center, pad_mode, onesided, norm_mel, mel_scale); + THROW_IF_ERROR(mfcc->ValidateParams()); + return mfcc; + })); + })); + +PYBIND_REGISTER( + MuLawDecodingOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "MuLawDecodingOperation") + .def(py::init([](int32_t quantization_channels) { + auto mu_law_decoding = std::make_shared(quantization_channels); + THROW_IF_ERROR(mu_law_decoding->ValidateParams()); + return mu_law_decoding; + })); + })); + +PYBIND_REGISTER( + MuLawEncodingOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "MuLawEncodingOperation") + .def(py::init([](int32_t quantization_channels) { + auto mu_law_encoding = std::make_shared(quantization_channels); + THROW_IF_ERROR(mu_law_encoding->ValidateParams()); + return mu_law_encoding; + })); + })); + +PYBIND_REGISTER(OverdriveOperation, 1, ([](const py::module *m) { + (void) + py::class_>( + *m, "OverdriveOperation") + .def(py::init([](float gain, float color) { + auto overdrive = std::make_shared(gain, color); + THROW_IF_ERROR(overdrive->ValidateParams()); + return overdrive; + })); + })); + +PYBIND_REGISTER(PhaserOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "PhaserOperation") + .def(py::init([](int32_t sample_rate, float gain_in, float gain_out, float delay_ms, float decay, + float mod_speed, bool sinusoidal) { + auto phaser = std::make_shared(sample_rate, gain_in, gain_out, delay_ms, + decay, mod_speed, sinusoidal); + THROW_IF_ERROR(phaser->ValidateParams()); + return phaser; + })); + })); + +PYBIND_REGISTER( + PhaseVocoderOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "PhaseVocoderOperation") + .def(py::init([](float rate, const std::shared_ptr &phase_advance) { + auto phase_vocoder = std::make_shared(rate, phase_advance); + THROW_IF_ERROR(phase_vocoder->ValidateParams()); + return phase_vocoder; + })); + })); + +PYBIND_REGISTER( + PitchShiftOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "PitchShiftOperation") + .def(py::init([](int32_t sample_rate, int32_t n_steps, int32_t bins_per_octave, int32_t n_fft, int32_t win_length, + int32_t hop_length, WindowType window) { + auto pitch_shift = std::make_shared(sample_rate, n_steps, bins_per_octave, n_fft, + win_length, hop_length, window); + THROW_IF_ERROR(pitch_shift->ValidateParams()); + return pitch_shift; + })); + })); + +PYBIND_REGISTER(ResampleMethod, 0, ([](const py::module *m) { + (void)py::enum_(*m, "ResampleMethod", py::arithmetic()) + .value("DE_RESAMPLE_SINC_INTERPOLATION", ResampleMethod::kSincInterpolation) + .value("DE_RESAMPLE_KAISER_WINDOW", ResampleMethod::kKaiserWindow) + .export_values(); + })); + +PYBIND_REGISTER(ResampleOperation, 1, ([](const py::module *m) { + (void) + py::class_>( + *m, "ResampleOperation") + .def(py::init([](float orig_freq, float new_freq, ResampleMethod resample_method, + int32_t lowpass_filter_width, float rolloff, float beta) { + auto resample = std::make_shared(orig_freq, new_freq, resample_method, + lowpass_filter_width, rolloff, beta); + THROW_IF_ERROR(resample->ValidateParams()); + return resample; + })); + })); + +PYBIND_REGISTER( + RiaaBiquadOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "RiaaBiquadOperation") + .def(py::init([](int32_t sample_rate) { + auto riaa_biquad = std::make_shared(sample_rate); + THROW_IF_ERROR(riaa_biquad->ValidateParams()); + return riaa_biquad; + })); + })); + +PYBIND_REGISTER(SlidingWindowCmnOperation, 1, ([](const py::module *m) { + (void)py::class_>(*m, "SlidingWindowCmnOperation") + .def(py::init([](int32_t cmn_window, int32_t min_cmn_window, bool center, bool norm_vars) { + auto sliding_window_cmn = std::make_shared( + cmn_window, min_cmn_window, center, norm_vars); + THROW_IF_ERROR(sliding_window_cmn->ValidateParams()); + return sliding_window_cmn; + })); + })); + +PYBIND_REGISTER(WindowType, 0, ([](const py::module *m) { + (void)py::enum_(*m, "WindowType", py::arithmetic()) + .value("DE_WINDOW_TYPE_BARTLETT", WindowType::kBartlett) + .value("DE_WINDOW_TYPE_BLACKMAN", WindowType::kBlackman) + .value("DE_WINDOW_TYPE_HAMMING", WindowType::kHamming) + .value("DE_WINDOW_TYPE_HANN", WindowType::kHann) + .value("DE_WINDOW_TYPE_KAISER", WindowType::kKaiser) + .export_values(); + })); + +PYBIND_REGISTER( + SpectralCentroidOperation, 1, ([](const py::module *m) { + (void) + py::class_>( + *m, "SpectralCentroidOperation") + .def(py::init([](int sample_rate, int n_fft, int win_length, int hop_length, int pad, WindowType window) { + auto spectral_centroid = + std::make_shared(sample_rate, n_fft, win_length, hop_length, pad, window); + THROW_IF_ERROR(spectral_centroid->ValidateParams()); + return spectral_centroid; + })); + })); + +PYBIND_REGISTER( + SpectrogramOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "SpectrogramOperation") + .def(py::init([](int32_t n_fft, int32_t win_length, int32_t hop_length, int32_t pad, WindowType window, + float power, bool normalized, bool center, BorderType pad_mode, bool onesided) { + auto spectrogram = std::make_shared(n_fft, win_length, hop_length, pad, window, + power, normalized, center, pad_mode, onesided); + THROW_IF_ERROR(spectrogram->ValidateParams()); + return spectrogram; + })); + })); + +PYBIND_REGISTER( + TimeMaskingOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "TimeMaskingOperation") + .def(py::init([](bool iid_masks, int32_t time_mask_param, int32_t mask_start, float mask_value) { + auto time_masking = + std::make_shared(iid_masks, time_mask_param, mask_start, mask_value); + THROW_IF_ERROR(time_masking->ValidateParams()); + return time_masking; + })); + })); + +PYBIND_REGISTER( + TimeStretchOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "TimeStretchOperation") + .def(py::init([](float hop_length, int n_freq, float fixed_rate) { + auto timestretch = std::make_shared(hop_length, n_freq, fixed_rate); + THROW_IF_ERROR(timestretch->ValidateParams()); + return timestretch; + })); + })); + +PYBIND_REGISTER( + TrebleBiquadOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "TrebleBiquadOperation") + .def(py::init([](int32_t sample_rate, float gain, float central_freq, float Q) { + auto treble_biquad = std::make_shared(sample_rate, gain, central_freq, Q); + THROW_IF_ERROR(treble_biquad->ValidateParams()); + return treble_biquad; + })); + })); + +PYBIND_REGISTER(VadOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "VadOperation") + .def(py::init([](int32_t sample_rate, float trigger_level, float trigger_time, float search_time, + float allowed_gap, float pre_trigger_time, float boot_time, float noise_up_time, + float noise_down_time, float noise_reduction_amount, float measure_freq, + float measure_duration, float measure_smooth_time, float hp_filter_freq, + float lp_filter_freq, float hp_lifter_freq, float lp_lifter_freq) { + auto vad = std::make_shared( + sample_rate, trigger_level, trigger_time, search_time, allowed_gap, pre_trigger_time, boot_time, + noise_up_time, noise_down_time, noise_reduction_amount, measure_freq, measure_duration, + measure_smooth_time, hp_filter_freq, lp_filter_freq, hp_lifter_freq, lp_lifter_freq); + THROW_IF_ERROR(vad->ValidateParams()); + return vad; + })); + })); + +PYBIND_REGISTER(VolOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "VolOperation") + .def(py::init([](float gain, GainType gain_type) { + auto vol = std::make_shared(gain, gain_type); + THROW_IF_ERROR(vol->ValidateParams()); + return vol; + })); + })); + +PYBIND_REGISTER(GainType, 0, ([](const py::module *m) { + (void)py::enum_(*m, "GainType", py::arithmetic()) + .value("DE_GAIN_TYPE_AMPLITUDE", GainType::kAmplitude) + .value("DE_GAIN_TYPE_POWER", GainType::kPower) + .value("DE_GAIN_TYPE_DB", GainType::kDb) + .export_values(); + })); +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/allpass_biquad_ir.cc b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/allpass_biquad_ir.cc index 18f55235981..e307c908513 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/allpass_biquad_ir.cc +++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/allpass_biquad_ir.cc @@ -1,52 +1,52 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/audio/ir/kernels/allpass_biquad_ir.h" - -#include "minddata/dataset/audio/ir/validators.h" -#include "minddata/dataset/audio/kernels/allpass_biquad_op.h" - -namespace mindspore { -namespace dataset { -namespace audio { -// AllpassBiquadOperation -AllpassBiquadOperation::AllpassBiquadOperation(int32_t sample_rate, float central_freq, float Q) - : sample_rate_(sample_rate), central_freq_(central_freq), Q_(Q) {} - -Status AllpassBiquadOperation::ValidateParams() { - RETURN_IF_NOT_OK(ValidateScalarNotZero("AllpassBiquad", "sample_rate", sample_rate_)); - RETURN_IF_NOT_OK(ValidateScalarNotZero("AllpassBiquad", "central_freq", central_freq_)); - RETURN_IF_NOT_OK(ValidateScalar("AllpassBiquad", "Q", Q_, {0, 1.0}, true, false)); - return Status::OK(); -} - -std::shared_ptr AllpassBiquadOperation::Build() { - std::shared_ptr tensor_op = std::make_shared(sample_rate_, central_freq_, Q_); - return tensor_op; -} - -Status AllpassBiquadOperation::to_json(nlohmann::json *out_json) { - RETURN_UNEXPECTED_IF_NULL(out_json); - nlohmann::json args; - args["sample_rate"] = sample_rate_; - args["central_freq"] = central_freq_; - args["Q"] = Q_; - *out_json = args; - return Status::OK(); -} -} // namespace audio -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/audio/ir/kernels/allpass_biquad_ir.h" + +#include "minddata/dataset/audio/ir/validators.h" +#include "minddata/dataset/audio/kernels/allpass_biquad_op.h" + +namespace mindspore { +namespace dataset { +namespace audio { +// AllpassBiquadOperation +AllpassBiquadOperation::AllpassBiquadOperation(int32_t sample_rate, float central_freq, float Q) + : sample_rate_(sample_rate), central_freq_(central_freq), Q_(Q) {} + +Status AllpassBiquadOperation::ValidateParams() { + RETURN_IF_NOT_OK(ValidateScalarNotZero("AllpassBiquad", "sample_rate", sample_rate_)); + RETURN_IF_NOT_OK(ValidateScalarNotZero("AllpassBiquad", "central_freq", central_freq_)); + RETURN_IF_NOT_OK(ValidateScalar("AllpassBiquad", "Q", Q_, {0, 1.0}, true, false)); + return Status::OK(); +} + +std::shared_ptr AllpassBiquadOperation::Build() { + std::shared_ptr tensor_op = std::make_shared(sample_rate_, central_freq_, Q_); + return tensor_op; +} + +Status AllpassBiquadOperation::to_json(nlohmann::json *out_json) { + RETURN_UNEXPECTED_IF_NULL(out_json); + nlohmann::json args; + args["sample_rate"] = sample_rate_; + args["central_freq"] = central_freq_; + args["Q"] = Q_; + *out_json = args; + return Status::OK(); +} +} // namespace audio +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/allpass_biquad_ir.h b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/allpass_biquad_ir.h index 58164064fc1..01032067f51 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/allpass_biquad_ir.h +++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/allpass_biquad_ir.h @@ -1,57 +1,57 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_ALLPASS_BIQUAD_IR_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_ALLPASS_BIQUAD_IR_H_ - -#include -#include -#include - -#include "include/api/status.h" -#include "minddata/dataset/include/dataset/constants.h" -#include "minddata/dataset/include/dataset/transforms.h" -#include "minddata/dataset/kernels/ir/tensor_operation.h" - -namespace mindspore { -namespace dataset { -namespace audio { - -constexpr char kAllpassBiquadOperation[] = "AllpassBiquad"; - -class AllpassBiquadOperation : public TensorOperation { - public: - AllpassBiquadOperation(int32_t sample_rate, float central_freq, float Q); - - ~AllpassBiquadOperation() override = default; - - std::shared_ptr Build() override; - - Status ValidateParams() override; - - std::string Name() const override { return kAllpassBiquadOperation; } - - Status to_json(nlohmann::json *out_json) override; - - private: - int32_t sample_rate_; - float central_freq_; - float Q_; -}; -} // namespace audio -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_ALLPASS_BIQUAD_IR_H_ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_ALLPASS_BIQUAD_IR_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_ALLPASS_BIQUAD_IR_H_ + +#include +#include +#include + +#include "include/api/status.h" +#include "minddata/dataset/include/dataset/constants.h" +#include "minddata/dataset/include/dataset/transforms.h" +#include "minddata/dataset/kernels/ir/tensor_operation.h" + +namespace mindspore { +namespace dataset { +namespace audio { + +constexpr char kAllpassBiquadOperation[] = "AllpassBiquad"; + +class AllpassBiquadOperation : public TensorOperation { + public: + AllpassBiquadOperation(int32_t sample_rate, float central_freq, float Q); + + ~AllpassBiquadOperation() override = default; + + std::shared_ptr Build() override; + + Status ValidateParams() override; + + std::string Name() const override { return kAllpassBiquadOperation; } + + Status to_json(nlohmann::json *out_json) override; + + private: + int32_t sample_rate_; + float central_freq_; + float Q_; +}; +} // namespace audio +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_ALLPASS_BIQUAD_IR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/angle_ir.cc b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/angle_ir.cc old mode 100755 new mode 100644 diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/angle_ir.h b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/angle_ir.h old mode 100755 new mode 100644 diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/band_biquad_ir.cc b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/band_biquad_ir.cc index 2359ec70853..10b5a4b5c4e 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/band_biquad_ir.cc +++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/band_biquad_ir.cc @@ -1,52 +1,52 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/audio/ir/kernels/band_biquad_ir.h" - -#include "minddata/dataset/audio/ir/validators.h" -#include "minddata/dataset/audio/kernels/band_biquad_op.h" - -namespace mindspore { -namespace dataset { -namespace audio { -// BandBiquadOperation -BandBiquadOperation::BandBiquadOperation(int32_t sample_rate, float central_freq, float Q, bool noise) - : sample_rate_(sample_rate), central_freq_(central_freq), Q_(Q), noise_(noise) {} - -Status BandBiquadOperation::ValidateParams() { - RETURN_IF_NOT_OK(ValidateScalar("BandBiquad", "Q", Q_, {0, 1.0}, true, false)); - RETURN_IF_NOT_OK(ValidateScalarNotZero("BandBiquad", "sample_rate", sample_rate_)); - return Status::OK(); -} - -std::shared_ptr BandBiquadOperation::Build() { - std::shared_ptr tensor_op = std::make_shared(sample_rate_, central_freq_, Q_, noise_); - return tensor_op; -} - -Status BandBiquadOperation::to_json(nlohmann::json *out_json) { - RETURN_UNEXPECTED_IF_NULL(out_json); - nlohmann::json args; - args["sample_rate"] = sample_rate_; - args["central_freq"] = central_freq_; - args["Q"] = Q_; - args["noise"] = noise_; - *out_json = args; - return Status::OK(); -} -} // namespace audio -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/audio/ir/kernels/band_biquad_ir.h" + +#include "minddata/dataset/audio/ir/validators.h" +#include "minddata/dataset/audio/kernels/band_biquad_op.h" + +namespace mindspore { +namespace dataset { +namespace audio { +// BandBiquadOperation +BandBiquadOperation::BandBiquadOperation(int32_t sample_rate, float central_freq, float Q, bool noise) + : sample_rate_(sample_rate), central_freq_(central_freq), Q_(Q), noise_(noise) {} + +Status BandBiquadOperation::ValidateParams() { + RETURN_IF_NOT_OK(ValidateScalar("BandBiquad", "Q", Q_, {0, 1.0}, true, false)); + RETURN_IF_NOT_OK(ValidateScalarNotZero("BandBiquad", "sample_rate", sample_rate_)); + return Status::OK(); +} + +std::shared_ptr BandBiquadOperation::Build() { + std::shared_ptr tensor_op = std::make_shared(sample_rate_, central_freq_, Q_, noise_); + return tensor_op; +} + +Status BandBiquadOperation::to_json(nlohmann::json *out_json) { + RETURN_UNEXPECTED_IF_NULL(out_json); + nlohmann::json args; + args["sample_rate"] = sample_rate_; + args["central_freq"] = central_freq_; + args["Q"] = Q_; + args["noise"] = noise_; + *out_json = args; + return Status::OK(); +} +} // namespace audio +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/band_biquad_ir.h b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/band_biquad_ir.h index eecbaa9c504..d11f75d1558 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/band_biquad_ir.h +++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/band_biquad_ir.h @@ -1,58 +1,58 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_BAND_BIQUAD_IR_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_BAND_BIQUAD_IR_H_ - -#include -#include -#include -#include - -#include "include/api/status.h" -#include "minddata/dataset/include/dataset/constants.h" -#include "minddata/dataset/include/dataset/transforms.h" -#include "minddata/dataset/kernels/ir/tensor_operation.h" - -namespace mindspore { -namespace dataset { -namespace audio { -constexpr char kBandBiquadOperation[] = "BandBiquad"; - -class BandBiquadOperation : public TensorOperation { - public: - BandBiquadOperation(int32_t sample_rate, float central_freq, float Q, bool noise); - - ~BandBiquadOperation() override = default; - - std::shared_ptr Build() override; - - Status ValidateParams() override; - - std::string Name() const override { return kBandBiquadOperation; } - - Status to_json(nlohmann::json *out_json) override; - - private: - int32_t sample_rate_; - float central_freq_; - float Q_; - bool noise_; -}; -} // namespace audio -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_BAND_BIQUAD_IR_H_ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_BAND_BIQUAD_IR_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_BAND_BIQUAD_IR_H_ + +#include +#include +#include +#include + +#include "include/api/status.h" +#include "minddata/dataset/include/dataset/constants.h" +#include "minddata/dataset/include/dataset/transforms.h" +#include "minddata/dataset/kernels/ir/tensor_operation.h" + +namespace mindspore { +namespace dataset { +namespace audio { +constexpr char kBandBiquadOperation[] = "BandBiquad"; + +class BandBiquadOperation : public TensorOperation { + public: + BandBiquadOperation(int32_t sample_rate, float central_freq, float Q, bool noise); + + ~BandBiquadOperation() override = default; + + std::shared_ptr Build() override; + + Status ValidateParams() override; + + std::string Name() const override { return kBandBiquadOperation; } + + Status to_json(nlohmann::json *out_json) override; + + private: + int32_t sample_rate_; + float central_freq_; + float Q_; + bool noise_; +}; +} // namespace audio +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_BAND_BIQUAD_IR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/bandpass_biquad_ir.cc b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/bandpass_biquad_ir.cc old mode 100755 new mode 100644 index 84bfe304bf1..e6ca561cd15 --- a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/bandpass_biquad_ir.cc +++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/bandpass_biquad_ir.cc @@ -1,54 +1,54 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/audio/ir/kernels/bandpass_biquad_ir.h" - -#include "minddata/dataset/audio/ir/validators.h" -#include "minddata/dataset/audio/kernels/bandpass_biquad_op.h" - -namespace mindspore { -namespace dataset { -namespace audio { -// BandpassBiquadOperation -BandpassBiquadOperation::BandpassBiquadOperation(int32_t sample_rate, float central_freq, float Q, - bool const_skirt_gain) - : sample_rate_(sample_rate), central_freq_(central_freq), Q_(Q), const_skirt_gain_(const_skirt_gain) {} - -Status BandpassBiquadOperation::ValidateParams() { - RETURN_IF_NOT_OK(ValidateScalar("BandpassBiquad", "Q", Q_, {0, 1.0}, true, false)); - RETURN_IF_NOT_OK(ValidateScalarNotZero("BandpassBiquad", "sample_rate", sample_rate_)); - return Status::OK(); -} - -std::shared_ptr BandpassBiquadOperation::Build() { - std::shared_ptr tensor_op = - std::make_shared(sample_rate_, central_freq_, Q_, const_skirt_gain_); - return tensor_op; -} - -Status BandpassBiquadOperation::to_json(nlohmann::json *out_json) { - RETURN_UNEXPECTED_IF_NULL(out_json); - nlohmann::json args; - args["sample_rate"] = sample_rate_; - args["central_freq"] = central_freq_; - args["Q"] = Q_; - args["const_skirt_gain"] = const_skirt_gain_; - *out_json = args; - return Status::OK(); -} -} // namespace audio -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/audio/ir/kernels/bandpass_biquad_ir.h" + +#include "minddata/dataset/audio/ir/validators.h" +#include "minddata/dataset/audio/kernels/bandpass_biquad_op.h" + +namespace mindspore { +namespace dataset { +namespace audio { +// BandpassBiquadOperation +BandpassBiquadOperation::BandpassBiquadOperation(int32_t sample_rate, float central_freq, float Q, + bool const_skirt_gain) + : sample_rate_(sample_rate), central_freq_(central_freq), Q_(Q), const_skirt_gain_(const_skirt_gain) {} + +Status BandpassBiquadOperation::ValidateParams() { + RETURN_IF_NOT_OK(ValidateScalar("BandpassBiquad", "Q", Q_, {0, 1.0}, true, false)); + RETURN_IF_NOT_OK(ValidateScalarNotZero("BandpassBiquad", "sample_rate", sample_rate_)); + return Status::OK(); +} + +std::shared_ptr BandpassBiquadOperation::Build() { + std::shared_ptr tensor_op = + std::make_shared(sample_rate_, central_freq_, Q_, const_skirt_gain_); + return tensor_op; +} + +Status BandpassBiquadOperation::to_json(nlohmann::json *out_json) { + RETURN_UNEXPECTED_IF_NULL(out_json); + nlohmann::json args; + args["sample_rate"] = sample_rate_; + args["central_freq"] = central_freq_; + args["Q"] = Q_; + args["const_skirt_gain"] = const_skirt_gain_; + *out_json = args; + return Status::OK(); +} +} // namespace audio +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/bandpass_biquad_ir.h b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/bandpass_biquad_ir.h old mode 100755 new mode 100644 index 1ab9bd6f10e..92f9bc9a63b --- a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/bandpass_biquad_ir.h +++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/bandpass_biquad_ir.h @@ -1,58 +1,58 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_BANDPASS_BIQUAD_IR_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_BANDPASS_BIQUAD_IR_H_ - -#include -#include -#include -#include - -#include "include/api/status.h" -#include "minddata/dataset/include/dataset/constants.h" -#include "minddata/dataset/include/dataset/transforms.h" -#include "minddata/dataset/kernels/ir/tensor_operation.h" - -namespace mindspore { -namespace dataset { -namespace audio { -constexpr char kBandpassBiquadOperation[] = "BandpassBiquad"; - -class BandpassBiquadOperation : public TensorOperation { - public: - BandpassBiquadOperation(int32_t sample_rate, float central_freq, float Q, bool const_skirt_gain); - - ~BandpassBiquadOperation() override = default; - - std::shared_ptr Build() override; - - Status ValidateParams() override; - - std::string Name() const override { return kBandpassBiquadOperation; } - - Status to_json(nlohmann::json *out_json) override; - - private: - int32_t sample_rate_; - float central_freq_; - float Q_; - bool const_skirt_gain_; -}; -} // namespace audio -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_BANDPASS_BIQUAD_IR_H_ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_BANDPASS_BIQUAD_IR_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_BANDPASS_BIQUAD_IR_H_ + +#include +#include +#include +#include + +#include "include/api/status.h" +#include "minddata/dataset/include/dataset/constants.h" +#include "minddata/dataset/include/dataset/transforms.h" +#include "minddata/dataset/kernels/ir/tensor_operation.h" + +namespace mindspore { +namespace dataset { +namespace audio { +constexpr char kBandpassBiquadOperation[] = "BandpassBiquad"; + +class BandpassBiquadOperation : public TensorOperation { + public: + BandpassBiquadOperation(int32_t sample_rate, float central_freq, float Q, bool const_skirt_gain); + + ~BandpassBiquadOperation() override = default; + + std::shared_ptr Build() override; + + Status ValidateParams() override; + + std::string Name() const override { return kBandpassBiquadOperation; } + + Status to_json(nlohmann::json *out_json) override; + + private: + int32_t sample_rate_; + float central_freq_; + float Q_; + bool const_skirt_gain_; +}; +} // namespace audio +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_BANDPASS_BIQUAD_IR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/bandreject_biquad_ir.cc b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/bandreject_biquad_ir.cc index d8db2478b5b..250f81cefba 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/bandreject_biquad_ir.cc +++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/bandreject_biquad_ir.cc @@ -1,51 +1,51 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/audio/ir/kernels/bandreject_biquad_ir.h" - -#include "minddata/dataset/audio/ir/validators.h" -#include "minddata/dataset/audio/kernels/bandreject_biquad_op.h" - -namespace mindspore { -namespace dataset { -namespace audio { -// BandrejectBiquadOperation -BandrejectBiquadOperation::BandrejectBiquadOperation(int32_t sample_rate, float central_freq, float Q) - : sample_rate_(sample_rate), central_freq_(central_freq), Q_(Q) {} - -Status BandrejectBiquadOperation::ValidateParams() { - RETURN_IF_NOT_OK(ValidateScalar("BandrejectBiquad", "Q", Q_, {0, 1.0}, true, false)); - RETURN_IF_NOT_OK(ValidateScalarNotZero("BandrejectBiquad", "sample_rate", sample_rate_)); - return Status::OK(); -} - -std::shared_ptr BandrejectBiquadOperation::Build() { - std::shared_ptr tensor_op = std::make_shared(sample_rate_, central_freq_, Q_); - return tensor_op; -} - -Status BandrejectBiquadOperation::to_json(nlohmann::json *out_json) { - RETURN_UNEXPECTED_IF_NULL(out_json); - nlohmann::json args; - args["sample_rate"] = sample_rate_; - args["central_freq"] = central_freq_; - args["Q"] = Q_; - *out_json = args; - return Status::OK(); -} -} // namespace audio -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/audio/ir/kernels/bandreject_biquad_ir.h" + +#include "minddata/dataset/audio/ir/validators.h" +#include "minddata/dataset/audio/kernels/bandreject_biquad_op.h" + +namespace mindspore { +namespace dataset { +namespace audio { +// BandrejectBiquadOperation +BandrejectBiquadOperation::BandrejectBiquadOperation(int32_t sample_rate, float central_freq, float Q) + : sample_rate_(sample_rate), central_freq_(central_freq), Q_(Q) {} + +Status BandrejectBiquadOperation::ValidateParams() { + RETURN_IF_NOT_OK(ValidateScalar("BandrejectBiquad", "Q", Q_, {0, 1.0}, true, false)); + RETURN_IF_NOT_OK(ValidateScalarNotZero("BandrejectBiquad", "sample_rate", sample_rate_)); + return Status::OK(); +} + +std::shared_ptr BandrejectBiquadOperation::Build() { + std::shared_ptr tensor_op = std::make_shared(sample_rate_, central_freq_, Q_); + return tensor_op; +} + +Status BandrejectBiquadOperation::to_json(nlohmann::json *out_json) { + RETURN_UNEXPECTED_IF_NULL(out_json); + nlohmann::json args; + args["sample_rate"] = sample_rate_; + args["central_freq"] = central_freq_; + args["Q"] = Q_; + *out_json = args; + return Status::OK(); +} +} // namespace audio +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/bandreject_biquad_ir.h b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/bandreject_biquad_ir.h index dbc6a9a91a2..da42ee5704b 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/bandreject_biquad_ir.h +++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/bandreject_biquad_ir.h @@ -1,57 +1,57 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_BANDREJECT_BIQUAD_IR_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_BANDREJECT_BIQUAD_IR_H_ - -#include -#include -#include -#include - -#include "include/api/status.h" -#include "minddata/dataset/include/dataset/constants.h" -#include "minddata/dataset/include/dataset/transforms.h" -#include "minddata/dataset/kernels/ir/tensor_operation.h" - -namespace mindspore { -namespace dataset { -namespace audio { -constexpr char kBandrejectBiquadOperation[] = "BandrejectBiquad"; - -class BandrejectBiquadOperation : public TensorOperation { - public: - BandrejectBiquadOperation(int32_t sample_rate, float central_freq, float Q); - - ~BandrejectBiquadOperation() override = default; - - std::shared_ptr Build() override; - - Status ValidateParams() override; - - std::string Name() const override { return kBandrejectBiquadOperation; } - - Status to_json(nlohmann::json *out_json) override; - - private: - int32_t sample_rate_; - float central_freq_; - float Q_; -}; -} // namespace audio -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_BANDREJECT_BIQUAD_IR_H_ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_BANDREJECT_BIQUAD_IR_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_BANDREJECT_BIQUAD_IR_H_ + +#include +#include +#include +#include + +#include "include/api/status.h" +#include "minddata/dataset/include/dataset/constants.h" +#include "minddata/dataset/include/dataset/transforms.h" +#include "minddata/dataset/kernels/ir/tensor_operation.h" + +namespace mindspore { +namespace dataset { +namespace audio { +constexpr char kBandrejectBiquadOperation[] = "BandrejectBiquad"; + +class BandrejectBiquadOperation : public TensorOperation { + public: + BandrejectBiquadOperation(int32_t sample_rate, float central_freq, float Q); + + ~BandrejectBiquadOperation() override = default; + + std::shared_ptr Build() override; + + Status ValidateParams() override; + + std::string Name() const override { return kBandrejectBiquadOperation; } + + Status to_json(nlohmann::json *out_json) override; + + private: + int32_t sample_rate_; + float central_freq_; + float Q_; +}; +} // namespace audio +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_BANDREJECT_BIQUAD_IR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/bass_biquad_ir.cc b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/bass_biquad_ir.cc index a95613aaf4c..bad1ad10357 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/bass_biquad_ir.cc +++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/bass_biquad_ir.cc @@ -1,52 +1,52 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/audio/ir/kernels/bass_biquad_ir.h" - -#include "minddata/dataset/audio/ir/validators.h" -#include "minddata/dataset/audio/kernels/bass_biquad_op.h" - -namespace mindspore { -namespace dataset { -namespace audio { -// BassBiquadOperation -BassBiquadOperation::BassBiquadOperation(int32_t sample_rate, float gain, float central_freq, float Q) - : sample_rate_(sample_rate), gain_(gain), central_freq_(central_freq), Q_(Q) {} - -Status BassBiquadOperation::ValidateParams() { - RETURN_IF_NOT_OK(ValidateScalar("BassBiquad", "Q", Q_, {0, 1.0}, true, false)); - RETURN_IF_NOT_OK(ValidateScalarNotZero("BassBiquad", "sample_rate", sample_rate_)); - return Status::OK(); -} - -std::shared_ptr BassBiquadOperation::Build() { - std::shared_ptr tensor_op = std::make_shared(sample_rate_, gain_, central_freq_, Q_); - return tensor_op; -} - -Status BassBiquadOperation::to_json(nlohmann::json *out_json) { - RETURN_UNEXPECTED_IF_NULL(out_json); - nlohmann::json args; - args["sample_rate"] = sample_rate_; - args["gain"] = gain_; - args["central_freq"] = central_freq_; - args["Q"] = Q_; - *out_json = args; - return Status::OK(); -} -} // namespace audio -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/audio/ir/kernels/bass_biquad_ir.h" + +#include "minddata/dataset/audio/ir/validators.h" +#include "minddata/dataset/audio/kernels/bass_biquad_op.h" + +namespace mindspore { +namespace dataset { +namespace audio { +// BassBiquadOperation +BassBiquadOperation::BassBiquadOperation(int32_t sample_rate, float gain, float central_freq, float Q) + : sample_rate_(sample_rate), gain_(gain), central_freq_(central_freq), Q_(Q) {} + +Status BassBiquadOperation::ValidateParams() { + RETURN_IF_NOT_OK(ValidateScalar("BassBiquad", "Q", Q_, {0, 1.0}, true, false)); + RETURN_IF_NOT_OK(ValidateScalarNotZero("BassBiquad", "sample_rate", sample_rate_)); + return Status::OK(); +} + +std::shared_ptr BassBiquadOperation::Build() { + std::shared_ptr tensor_op = std::make_shared(sample_rate_, gain_, central_freq_, Q_); + return tensor_op; +} + +Status BassBiquadOperation::to_json(nlohmann::json *out_json) { + RETURN_UNEXPECTED_IF_NULL(out_json); + nlohmann::json args; + args["sample_rate"] = sample_rate_; + args["gain"] = gain_; + args["central_freq"] = central_freq_; + args["Q"] = Q_; + *out_json = args; + return Status::OK(); +} +} // namespace audio +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/bass_biquad_ir.h b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/bass_biquad_ir.h index 610bf6d36e6..f190082f3dd 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/bass_biquad_ir.h +++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/bass_biquad_ir.h @@ -1,58 +1,58 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_BASS_BIQUAD_IR_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_BASS_BIQUAD_IR_H_ - -#include -#include -#include -#include - -#include "include/api/status.h" -#include "minddata/dataset/include/dataset/constants.h" -#include "minddata/dataset/include/dataset/transforms.h" -#include "minddata/dataset/kernels/ir/tensor_operation.h" - -namespace mindspore { -namespace dataset { -namespace audio { -constexpr char kBassBiquadOperation[] = "BassBiquad"; - -class BassBiquadOperation : public TensorOperation { - public: - BassBiquadOperation(int32_t sample_rate, float gain, float central_freq, float Q); - - ~BassBiquadOperation() override = default; - - std::shared_ptr Build() override; - - Status ValidateParams() override; - - std::string Name() const override { return kBassBiquadOperation; } - - Status to_json(nlohmann::json *out_json) override; - - private: - int32_t sample_rate_; - float gain_; - float central_freq_; - float Q_; -}; -} // namespace audio -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_BASS_BIQUAD_IR_H_ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_BASS_BIQUAD_IR_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_BASS_BIQUAD_IR_H_ + +#include +#include +#include +#include + +#include "include/api/status.h" +#include "minddata/dataset/include/dataset/constants.h" +#include "minddata/dataset/include/dataset/transforms.h" +#include "minddata/dataset/kernels/ir/tensor_operation.h" + +namespace mindspore { +namespace dataset { +namespace audio { +constexpr char kBassBiquadOperation[] = "BassBiquad"; + +class BassBiquadOperation : public TensorOperation { + public: + BassBiquadOperation(int32_t sample_rate, float gain, float central_freq, float Q); + + ~BassBiquadOperation() override = default; + + std::shared_ptr Build() override; + + Status ValidateParams() override; + + std::string Name() const override { return kBassBiquadOperation; } + + Status to_json(nlohmann::json *out_json) override; + + private: + int32_t sample_rate_; + float gain_; + float central_freq_; + float Q_; +}; +} // namespace audio +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_BASS_BIQUAD_IR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/dc_shift_ir.h b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/dc_shift_ir.h old mode 100755 new mode 100644 diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/deemph_biquad_ir.cc b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/deemph_biquad_ir.cc index 5d432458636..f61f9f41a03 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/deemph_biquad_ir.cc +++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/deemph_biquad_ir.cc @@ -1,51 +1,51 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/audio/ir/kernels/deemph_biquad_ir.h" - -#include "minddata/dataset/audio/kernels/deemph_biquad_op.h" - -namespace mindspore { -namespace dataset { -namespace audio { -// DeemphBiquadOperation -DeemphBiquadOperation::DeemphBiquadOperation(int32_t sample_rate) : sample_rate_(sample_rate) {} - -Status DeemphBiquadOperation::ValidateParams() { - if ((sample_rate_ != 44100 && sample_rate_ != 48000)) { - std::string err_msg = - "DeemphBiquad: sample_rate can only be 44100 or 48000, but got: " + std::to_string(sample_rate_); - MS_LOG(ERROR) << err_msg; - RETURN_SYNTAX_ERROR(err_msg); - } - return Status::OK(); -} - -std::shared_ptr DeemphBiquadOperation::Build() { - std::shared_ptr tensor_op = std::make_shared(sample_rate_); - return tensor_op; -} - -Status DeemphBiquadOperation::to_json(nlohmann::json *out_json) { - RETURN_UNEXPECTED_IF_NULL(out_json); - nlohmann::json args; - args["sample_rate"] = sample_rate_; - *out_json = args; - return Status::OK(); -} -} // namespace audio -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/audio/ir/kernels/deemph_biquad_ir.h" + +#include "minddata/dataset/audio/kernels/deemph_biquad_op.h" + +namespace mindspore { +namespace dataset { +namespace audio { +// DeemphBiquadOperation +DeemphBiquadOperation::DeemphBiquadOperation(int32_t sample_rate) : sample_rate_(sample_rate) {} + +Status DeemphBiquadOperation::ValidateParams() { + if ((sample_rate_ != 44100 && sample_rate_ != 48000)) { + std::string err_msg = + "DeemphBiquad: sample_rate can only be 44100 or 48000, but got: " + std::to_string(sample_rate_); + MS_LOG(ERROR) << err_msg; + RETURN_SYNTAX_ERROR(err_msg); + } + return Status::OK(); +} + +std::shared_ptr DeemphBiquadOperation::Build() { + std::shared_ptr tensor_op = std::make_shared(sample_rate_); + return tensor_op; +} + +Status DeemphBiquadOperation::to_json(nlohmann::json *out_json) { + RETURN_UNEXPECTED_IF_NULL(out_json); + nlohmann::json args; + args["sample_rate"] = sample_rate_; + *out_json = args; + return Status::OK(); +} +} // namespace audio +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/deemph_biquad_ir.h b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/deemph_biquad_ir.h index f5fc8e4e37e..fc3a15b87ee 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/deemph_biquad_ir.h +++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/deemph_biquad_ir.h @@ -1,54 +1,54 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_DEEMPH_BIQUAD_IR_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_DEEMPH_BIQUAD_IR_H_ - -#include -#include -#include - -#include "include/api/status.h" -#include "minddata/dataset/include/dataset/constants.h" -#include "minddata/dataset/include/dataset/transforms.h" -#include "minddata/dataset/kernels/ir/tensor_operation.h" - -namespace mindspore { -namespace dataset { -namespace audio { -constexpr char kDeemphBiquadOperation[] = "DeemphBiquad"; - -class DeemphBiquadOperation : public TensorOperation { - public: - explicit DeemphBiquadOperation(int32_t sample_rate); - - ~DeemphBiquadOperation() override = default; - - std::shared_ptr Build() override; - - Status ValidateParams() override; - - std::string Name() const override { return kDeemphBiquadOperation; } - - Status to_json(nlohmann::json *out_json) override; - - private: - int32_t sample_rate_; -}; -} // namespace audio -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_DEEMPH_BIQUAD_IR_H_ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_DEEMPH_BIQUAD_IR_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_DEEMPH_BIQUAD_IR_H_ + +#include +#include +#include + +#include "include/api/status.h" +#include "minddata/dataset/include/dataset/constants.h" +#include "minddata/dataset/include/dataset/transforms.h" +#include "minddata/dataset/kernels/ir/tensor_operation.h" + +namespace mindspore { +namespace dataset { +namespace audio { +constexpr char kDeemphBiquadOperation[] = "DeemphBiquad"; + +class DeemphBiquadOperation : public TensorOperation { + public: + explicit DeemphBiquadOperation(int32_t sample_rate); + + ~DeemphBiquadOperation() override = default; + + std::shared_ptr Build() override; + + Status ValidateParams() override; + + std::string Name() const override { return kDeemphBiquadOperation; } + + Status to_json(nlohmann::json *out_json) override; + + private: + int32_t sample_rate_; +}; +} // namespace audio +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_DEEMPH_BIQUAD_IR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/flanger_ir.cc b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/flanger_ir.cc index 63250f57816..b3d3329d4be 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/flanger_ir.cc +++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/flanger_ir.cc @@ -1,72 +1,72 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/audio/ir/kernels/flanger_ir.h" - -#include "minddata/dataset/audio/ir/validators.h" -#include "minddata/dataset/audio/kernels/flanger_op.h" - -namespace mindspore { -namespace dataset { -namespace audio { -// FlangerOperation -FlangerOperation::FlangerOperation(int32_t sample_rate, float delay, float depth, float regen, float width, float speed, - float phase, Modulation modulation, Interpolation interpolation) - : sample_rate_(sample_rate), - delay_(delay), - depth_(depth), - regen_(regen), - width_(width), - speed_(speed), - phase_(phase), - modulation_(modulation), - interpolation_(interpolation) {} - -Status FlangerOperation::ValidateParams() { - RETURN_IF_NOT_OK(ValidateScalarNotZero("Flanger", "sample_rate", sample_rate_)); - RETURN_IF_NOT_OK(ValidateScalar("Flanger", "delay", delay_, {0, 30}, false, false)); - RETURN_IF_NOT_OK(ValidateScalar("Flanger", "depth", depth_, {0, 10}, false, false)); - RETURN_IF_NOT_OK(ValidateScalar("Flanger", "regen", regen_, {-95, 95}, false, false)); - RETURN_IF_NOT_OK(ValidateScalar("Flanger", "width", width_, {0, 100}, false, false)); - RETURN_IF_NOT_OK(ValidateScalar("Flanger", "speed", speed_, {0.1, 10}, false, false)); - RETURN_IF_NOT_OK(ValidateScalar("Flanger", "phase", phase_, {0, 100}, false, false)); - return Status::OK(); -} - -std::shared_ptr FlangerOperation::Build() { - std::shared_ptr tensor_op = std::make_shared(sample_rate_, delay_, depth_, regen_, width_, - speed_, phase_, modulation_, interpolation_); - return tensor_op; -} - -Status FlangerOperation::to_json(nlohmann::json *out_json) { - RETURN_UNEXPECTED_IF_NULL(out_json); - nlohmann::json args; - args["sample_rate"] = sample_rate_; - args["delay"] = delay_; - args["depth"] = depth_; - args["regen"] = regen_; - args["width"] = width_; - args["speed"] = speed_; - args["phase"] = phase_; - args["modulation"] = modulation_; - args["interpolation"] = interpolation_; - *out_json = args; - return Status::OK(); -} -} // namespace audio -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/audio/ir/kernels/flanger_ir.h" + +#include "minddata/dataset/audio/ir/validators.h" +#include "minddata/dataset/audio/kernels/flanger_op.h" + +namespace mindspore { +namespace dataset { +namespace audio { +// FlangerOperation +FlangerOperation::FlangerOperation(int32_t sample_rate, float delay, float depth, float regen, float width, float speed, + float phase, Modulation modulation, Interpolation interpolation) + : sample_rate_(sample_rate), + delay_(delay), + depth_(depth), + regen_(regen), + width_(width), + speed_(speed), + phase_(phase), + modulation_(modulation), + interpolation_(interpolation) {} + +Status FlangerOperation::ValidateParams() { + RETURN_IF_NOT_OK(ValidateScalarNotZero("Flanger", "sample_rate", sample_rate_)); + RETURN_IF_NOT_OK(ValidateScalar("Flanger", "delay", delay_, {0, 30}, false, false)); + RETURN_IF_NOT_OK(ValidateScalar("Flanger", "depth", depth_, {0, 10}, false, false)); + RETURN_IF_NOT_OK(ValidateScalar("Flanger", "regen", regen_, {-95, 95}, false, false)); + RETURN_IF_NOT_OK(ValidateScalar("Flanger", "width", width_, {0, 100}, false, false)); + RETURN_IF_NOT_OK(ValidateScalar("Flanger", "speed", speed_, {0.1, 10}, false, false)); + RETURN_IF_NOT_OK(ValidateScalar("Flanger", "phase", phase_, {0, 100}, false, false)); + return Status::OK(); +} + +std::shared_ptr FlangerOperation::Build() { + std::shared_ptr tensor_op = std::make_shared(sample_rate_, delay_, depth_, regen_, width_, + speed_, phase_, modulation_, interpolation_); + return tensor_op; +} + +Status FlangerOperation::to_json(nlohmann::json *out_json) { + RETURN_UNEXPECTED_IF_NULL(out_json); + nlohmann::json args; + args["sample_rate"] = sample_rate_; + args["delay"] = delay_; + args["depth"] = depth_; + args["regen"] = regen_; + args["width"] = width_; + args["speed"] = speed_; + args["phase"] = phase_; + args["modulation"] = modulation_; + args["interpolation"] = interpolation_; + *out_json = args; + return Status::OK(); +} +} // namespace audio +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/flanger_ir.h b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/flanger_ir.h index c94fe94d9df..06e759561cc 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/flanger_ir.h +++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/flanger_ir.h @@ -1,64 +1,64 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_FLANGER_IR_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_FLANGER_IR_H_ - -#include -#include -#include - -#include "include/api/status.h" -#include "minddata/dataset/include/dataset/constants.h" -#include "minddata/dataset/include/dataset/transforms.h" -#include "minddata/dataset/kernels/ir/tensor_operation.h" - -namespace mindspore { -namespace dataset { -namespace audio { - -constexpr char kFlangerOperation[] = "Flanger"; - -class FlangerOperation : public TensorOperation { - public: - explicit FlangerOperation(int32_t sample_rate, float delay, float depth, float regen, float width, float speed, - float phase, Modulation modulation, Interpolation interpolation); - - ~FlangerOperation() override = default; - - std::shared_ptr Build() override; - - Status ValidateParams() override; - - std::string Name() const override { return kFlangerOperation; } - - Status to_json(nlohmann::json *out_json) override; - - private: - int32_t sample_rate_; - float delay_; - float depth_; - float regen_; - float width_; - float speed_; - float phase_; - Modulation modulation_; - Interpolation interpolation_; -}; -} // namespace audio -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_FLANGER_IR_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_FLANGER_IR_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_FLANGER_IR_H_ + +#include +#include +#include + +#include "include/api/status.h" +#include "minddata/dataset/include/dataset/constants.h" +#include "minddata/dataset/include/dataset/transforms.h" +#include "minddata/dataset/kernels/ir/tensor_operation.h" + +namespace mindspore { +namespace dataset { +namespace audio { + +constexpr char kFlangerOperation[] = "Flanger"; + +class FlangerOperation : public TensorOperation { + public: + explicit FlangerOperation(int32_t sample_rate, float delay, float depth, float regen, float width, float speed, + float phase, Modulation modulation, Interpolation interpolation); + + ~FlangerOperation() override = default; + + std::shared_ptr Build() override; + + Status ValidateParams() override; + + std::string Name() const override { return kFlangerOperation; } + + Status to_json(nlohmann::json *out_json) override; + + private: + int32_t sample_rate_; + float delay_; + float depth_; + float regen_; + float width_; + float speed_; + float phase_; + Modulation modulation_; + Interpolation interpolation_; +}; +} // namespace audio +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_FLANGER_IR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/highpass_biquad_ir.cc b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/highpass_biquad_ir.cc old mode 100755 new mode 100644 index c7ab27aa86c..7a629a2f049 --- a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/highpass_biquad_ir.cc +++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/highpass_biquad_ir.cc @@ -1,51 +1,51 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/audio/ir/kernels/highpass_biquad_ir.h" - -#include "minddata/dataset/audio/ir/validators.h" -#include "minddata/dataset/audio/kernels/highpass_biquad_op.h" - -namespace mindspore { -namespace dataset { -namespace audio { -// HighpassBiquadOperation -HighpassBiquadOperation::HighpassBiquadOperation(int32_t sample_rate, float cutoff_freq, float Q) - : sample_rate_(sample_rate), cutoff_freq_(cutoff_freq), Q_(Q) {} - -Status HighpassBiquadOperation::ValidateParams() { - RETURN_IF_NOT_OK(ValidateScalarNotZero("HighpassBiquad", "sample_rate", sample_rate_)); - RETURN_IF_NOT_OK(ValidateScalar("HighpassBiquad", "Q", Q_, {0, 1.0}, true, false)); - return Status::OK(); -} - -std::shared_ptr HighpassBiquadOperation::Build() { - std::shared_ptr tensor_op = std::make_shared(sample_rate_, cutoff_freq_, Q_); - return tensor_op; -} - -Status HighpassBiquadOperation::to_json(nlohmann::json *out_json) { - RETURN_UNEXPECTED_IF_NULL(out_json); - nlohmann::json args; - args["sample_rate"] = sample_rate_; - args["cutoff_freq"] = cutoff_freq_; - args["Q"] = Q_; - *out_json = args; - return Status::OK(); -} -} // namespace audio -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/audio/ir/kernels/highpass_biquad_ir.h" + +#include "minddata/dataset/audio/ir/validators.h" +#include "minddata/dataset/audio/kernels/highpass_biquad_op.h" + +namespace mindspore { +namespace dataset { +namespace audio { +// HighpassBiquadOperation +HighpassBiquadOperation::HighpassBiquadOperation(int32_t sample_rate, float cutoff_freq, float Q) + : sample_rate_(sample_rate), cutoff_freq_(cutoff_freq), Q_(Q) {} + +Status HighpassBiquadOperation::ValidateParams() { + RETURN_IF_NOT_OK(ValidateScalarNotZero("HighpassBiquad", "sample_rate", sample_rate_)); + RETURN_IF_NOT_OK(ValidateScalar("HighpassBiquad", "Q", Q_, {0, 1.0}, true, false)); + return Status::OK(); +} + +std::shared_ptr HighpassBiquadOperation::Build() { + std::shared_ptr tensor_op = std::make_shared(sample_rate_, cutoff_freq_, Q_); + return tensor_op; +} + +Status HighpassBiquadOperation::to_json(nlohmann::json *out_json) { + RETURN_UNEXPECTED_IF_NULL(out_json); + nlohmann::json args; + args["sample_rate"] = sample_rate_; + args["cutoff_freq"] = cutoff_freq_; + args["Q"] = Q_; + *out_json = args; + return Status::OK(); +} +} // namespace audio +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/highpass_biquad_ir.h b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/highpass_biquad_ir.h old mode 100755 new mode 100644 index e52a70cc18e..a13a730fa8d --- a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/highpass_biquad_ir.h +++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/highpass_biquad_ir.h @@ -1,57 +1,57 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_HIGHPASS_BIQUAD_IR_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_HIGHPASS_BIQUAD_IR_H_ - -#include -#include -#include -#include - -#include "include/api/status.h" -#include "minddata/dataset/include/dataset/constants.h" -#include "minddata/dataset/include/dataset/transforms.h" -#include "minddata/dataset/kernels/ir/tensor_operation.h" - -namespace mindspore { -namespace dataset { -namespace audio { -constexpr char kHighpassBiquadOperation[] = "HighpassBiquad"; - -class HighpassBiquadOperation : public TensorOperation { - public: - HighpassBiquadOperation(int32_t sample_rate, float cutoff_freq, float Q); - - ~HighpassBiquadOperation() override = default; - - std::shared_ptr Build() override; - - Status ValidateParams() override; - - std::string Name() const override { return kHighpassBiquadOperation; } - - Status to_json(nlohmann::json *out_json) override; - - private: - int32_t sample_rate_; - float cutoff_freq_; - float Q_; -}; -} // namespace audio -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_HIGHPASS_BIQUAD_IR_H_ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_HIGHPASS_BIQUAD_IR_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_HIGHPASS_BIQUAD_IR_H_ + +#include +#include +#include +#include + +#include "include/api/status.h" +#include "minddata/dataset/include/dataset/constants.h" +#include "minddata/dataset/include/dataset/transforms.h" +#include "minddata/dataset/kernels/ir/tensor_operation.h" + +namespace mindspore { +namespace dataset { +namespace audio { +constexpr char kHighpassBiquadOperation[] = "HighpassBiquad"; + +class HighpassBiquadOperation : public TensorOperation { + public: + HighpassBiquadOperation(int32_t sample_rate, float cutoff_freq, float Q); + + ~HighpassBiquadOperation() override = default; + + std::shared_ptr Build() override; + + Status ValidateParams() override; + + std::string Name() const override { return kHighpassBiquadOperation; } + + Status to_json(nlohmann::json *out_json) override; + + private: + int32_t sample_rate_; + float cutoff_freq_; + float Q_; +}; +} // namespace audio +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_HIGHPASS_BIQUAD_IR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/lfilter_ir.cc b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/lfilter_ir.cc index 4c0fbfd85cb..13edc3f5646 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/lfilter_ir.cc +++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/lfilter_ir.cc @@ -1,53 +1,53 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/audio/ir/kernels/lfilter_ir.h" - -#include "minddata/dataset/audio/ir/validators.h" -#include "minddata/dataset/audio/kernels/lfilter_op.h" - -namespace mindspore { -namespace dataset { -namespace audio { -// LFilterOperation -LFilterOperation::LFilterOperation(const std::vector &a_coeffs, const std::vector &b_coeffs, bool clamp) - : a_coeffs_(a_coeffs), b_coeffs_(b_coeffs), clamp_(clamp) {} - -Status LFilterOperation::ValidateParams() { - RETURN_IF_NOT_OK(ValidateVectorNotEmpty("lfilter", "a_coeffs", a_coeffs_)); - RETURN_IF_NOT_OK(ValidateVectorNotEmpty("lfilter", "b_coeffs", b_coeffs_)); - RETURN_IF_NOT_OK(ValidateVectorSameSize("lfilter", "a_coeffs", a_coeffs_, "b_coeffs", b_coeffs_)); - RETURN_IF_NOT_OK(ValidateScalarNotZero("lfilter", "a_coeffs[0]", a_coeffs_[0])); - return Status::OK(); -} - -std::shared_ptr LFilterOperation::Build() { - std::shared_ptr tensor_op = std::make_shared(a_coeffs_, b_coeffs_, clamp_); - return tensor_op; -} - -Status LFilterOperation::to_json(nlohmann::json *out_json) { - RETURN_UNEXPECTED_IF_NULL(out_json); - nlohmann::json args; - args["a_coeffs"] = a_coeffs_; - args["b_coeffs"] = b_coeffs_; - args["clamp"] = clamp_; - *out_json = args; - return Status::OK(); -} -} // namespace audio -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/audio/ir/kernels/lfilter_ir.h" + +#include "minddata/dataset/audio/ir/validators.h" +#include "minddata/dataset/audio/kernels/lfilter_op.h" + +namespace mindspore { +namespace dataset { +namespace audio { +// LFilterOperation +LFilterOperation::LFilterOperation(const std::vector &a_coeffs, const std::vector &b_coeffs, bool clamp) + : a_coeffs_(a_coeffs), b_coeffs_(b_coeffs), clamp_(clamp) {} + +Status LFilterOperation::ValidateParams() { + RETURN_IF_NOT_OK(ValidateVectorNotEmpty("lfilter", "a_coeffs", a_coeffs_)); + RETURN_IF_NOT_OK(ValidateVectorNotEmpty("lfilter", "b_coeffs", b_coeffs_)); + RETURN_IF_NOT_OK(ValidateVectorSameSize("lfilter", "a_coeffs", a_coeffs_, "b_coeffs", b_coeffs_)); + RETURN_IF_NOT_OK(ValidateScalarNotZero("lfilter", "a_coeffs[0]", a_coeffs_[0])); + return Status::OK(); +} + +std::shared_ptr LFilterOperation::Build() { + std::shared_ptr tensor_op = std::make_shared(a_coeffs_, b_coeffs_, clamp_); + return tensor_op; +} + +Status LFilterOperation::to_json(nlohmann::json *out_json) { + RETURN_UNEXPECTED_IF_NULL(out_json); + nlohmann::json args; + args["a_coeffs"] = a_coeffs_; + args["b_coeffs"] = b_coeffs_; + args["clamp"] = clamp_; + *out_json = args; + return Status::OK(); +} +} // namespace audio +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/lfilter_ir.h b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/lfilter_ir.h index 303d280972f..5c2bee39ab5 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/lfilter_ir.h +++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/lfilter_ir.h @@ -1,57 +1,57 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_LFILTER_IR_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_LFILTER_IR_H_ - -#include -#include -#include - -#include "include/api/status.h" -#include "minddata/dataset/include/dataset/constants.h" -#include "minddata/dataset/include/dataset/transforms.h" -#include "minddata/dataset/kernels/ir/tensor_operation.h" - -namespace mindspore { -namespace dataset { -namespace audio { -// Char arrays storing name of corresponding classes (in alphabetical order) -constexpr char kLFilterOperation[] = "LFilter"; - -class LFilterOperation : public TensorOperation { - public: - LFilterOperation(const std::vector &a_coeffs, const std::vector &b_coeffs, bool clamp); - - ~LFilterOperation() override = default; - - std::shared_ptr Build() override; - - Status ValidateParams() override; - - std::string Name() const override { return kLFilterOperation; } - - Status to_json(nlohmann::json *out_json) override; - - private: - std::vector a_coeffs_; - std::vector b_coeffs_; - bool clamp_; -}; -} // namespace audio -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_LFILTER_IR_H_ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_LFILTER_IR_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_LFILTER_IR_H_ + +#include +#include +#include + +#include "include/api/status.h" +#include "minddata/dataset/include/dataset/constants.h" +#include "minddata/dataset/include/dataset/transforms.h" +#include "minddata/dataset/kernels/ir/tensor_operation.h" + +namespace mindspore { +namespace dataset { +namespace audio { +// Char arrays storing name of corresponding classes (in alphabetical order) +constexpr char kLFilterOperation[] = "LFilter"; + +class LFilterOperation : public TensorOperation { + public: + LFilterOperation(const std::vector &a_coeffs, const std::vector &b_coeffs, bool clamp); + + ~LFilterOperation() override = default; + + std::shared_ptr Build() override; + + Status ValidateParams() override; + + std::string Name() const override { return kLFilterOperation; } + + Status to_json(nlohmann::json *out_json) override; + + private: + std::vector a_coeffs_; + std::vector b_coeffs_; + bool clamp_; +}; +} // namespace audio +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_LFILTER_IR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/magphase_ir.cc b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/magphase_ir.cc index 5de5faae410..bfd4c311807 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/magphase_ir.cc +++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/magphase_ir.cc @@ -1,51 +1,51 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/audio/ir/kernels/magphase_ir.h" - -#include "minddata/dataset/audio/ir/validators.h" -#include "minddata/dataset/audio/kernels/magphase_op.h" - -namespace mindspore { -namespace dataset { -namespace audio { -MagphaseOperation::MagphaseOperation(float power) : power_(power) {} - -Status MagphaseOperation::ValidateParams() { - RETURN_IF_NOT_OK(ValidateFloatScalarNonNegative("Magphase", "power", power_)); - return Status::OK(); -} - -std::shared_ptr MagphaseOperation::Build() { return std::make_shared(power_); } - -Status MagphaseOperation::to_json(nlohmann::json *out_json) { - RETURN_UNEXPECTED_IF_NULL(out_json); - nlohmann::json args; - args["power"] = power_; - *out_json = args; - return Status::OK(); -} - -Status MagphaseOperation::from_json(nlohmann::json op_params, std::shared_ptr *operation) { - RETURN_UNEXPECTED_IF_NULL(operation); - RETURN_IF_NOT_OK(ValidateParamInJson(op_params, "power", kMagphaseOperation)); - float power = op_params["power"]; - *operation = std::make_shared(power); - return Status::OK(); -} -} // namespace audio -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/audio/ir/kernels/magphase_ir.h" + +#include "minddata/dataset/audio/ir/validators.h" +#include "minddata/dataset/audio/kernels/magphase_op.h" + +namespace mindspore { +namespace dataset { +namespace audio { +MagphaseOperation::MagphaseOperation(float power) : power_(power) {} + +Status MagphaseOperation::ValidateParams() { + RETURN_IF_NOT_OK(ValidateFloatScalarNonNegative("Magphase", "power", power_)); + return Status::OK(); +} + +std::shared_ptr MagphaseOperation::Build() { return std::make_shared(power_); } + +Status MagphaseOperation::to_json(nlohmann::json *out_json) { + RETURN_UNEXPECTED_IF_NULL(out_json); + nlohmann::json args; + args["power"] = power_; + *out_json = args; + return Status::OK(); +} + +Status MagphaseOperation::from_json(nlohmann::json op_params, std::shared_ptr *operation) { + RETURN_UNEXPECTED_IF_NULL(operation); + RETURN_IF_NOT_OK(ValidateParamInJson(op_params, "power", kMagphaseOperation)); + float power = op_params["power"]; + *operation = std::make_shared(power); + return Status::OK(); +} +} // namespace audio +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/magphase_ir.h b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/magphase_ir.h index 00fc23a12f3..f2d3ad5017c 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/magphase_ir.h +++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/magphase_ir.h @@ -1,55 +1,55 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_MAGPHASE_IR_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_MAGPHASE_IR_H_ - -#include -#include - -#include "include/api/status.h" -#include "minddata/dataset/include/dataset/constants.h" -#include "minddata/dataset/include/dataset/transforms.h" -#include "minddata/dataset/kernels/ir/tensor_operation.h" - -namespace mindspore { -namespace dataset { -namespace audio { -constexpr char kMagphaseOperation[] = "Magphase"; - -class MagphaseOperation : public TensorOperation { - public: - explicit MagphaseOperation(float power); - - ~MagphaseOperation() override = default; - - std::shared_ptr Build() override; - - std::string Name() const override { return kMagphaseOperation; } - - Status ValidateParams() override; - - Status to_json(nlohmann::json *out_json) override; - - static Status from_json(nlohmann::json op_params, std::shared_ptr *operation); - - private: - float power_; -}; -} // namespace audio -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_MAGPHASE_IR_H_ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_MAGPHASE_IR_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_MAGPHASE_IR_H_ + +#include +#include + +#include "include/api/status.h" +#include "minddata/dataset/include/dataset/constants.h" +#include "minddata/dataset/include/dataset/transforms.h" +#include "minddata/dataset/kernels/ir/tensor_operation.h" + +namespace mindspore { +namespace dataset { +namespace audio { +constexpr char kMagphaseOperation[] = "Magphase"; + +class MagphaseOperation : public TensorOperation { + public: + explicit MagphaseOperation(float power); + + ~MagphaseOperation() override = default; + + std::shared_ptr Build() override; + + std::string Name() const override { return kMagphaseOperation; } + + Status ValidateParams() override; + + Status to_json(nlohmann::json *out_json) override; + + static Status from_json(nlohmann::json op_params, std::shared_ptr *operation); + + private: + float power_; +}; +} // namespace audio +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_MAGPHASE_IR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/mask_along_axis_iid_ir.cc b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/mask_along_axis_iid_ir.cc index cd7f0488285..407481eca90 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/mask_along_axis_iid_ir.cc +++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/mask_along_axis_iid_ir.cc @@ -1,56 +1,56 @@ -/** - * Copyright 2022-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/audio/ir/kernels/mask_along_axis_iid_ir.h" - -#include "minddata/dataset/audio/ir/validators.h" -#include "minddata/dataset/audio/kernels/mask_along_axis_iid_op.h" - -namespace mindspore { -namespace dataset { -namespace audio { -MaskAlongAxisIIDOperation::MaskAlongAxisIIDOperation(int32_t mask_param, float mask_value, int32_t axis) - : mask_param_(mask_param), mask_value_(mask_value), axis_(axis) { - random_op_ = true; -} - -MaskAlongAxisIIDOperation::~MaskAlongAxisIIDOperation() = default; - -Status MaskAlongAxisIIDOperation::ValidateParams() { - RETURN_IF_NOT_OK(ValidateIntScalarNonNegative("MaskAlongAxisIID", "mask_param", mask_param_)); - RETURN_IF_NOT_OK(ValidateScalarValue("MaskAlongAxisIID", "axis", axis_, {1, 2})); - return Status::OK(); -} - -std::string MaskAlongAxisIIDOperation::Name() const { return kMaskAlongAxisIIDOperation; } - -std::shared_ptr MaskAlongAxisIIDOperation::Build() { - std::shared_ptr tensor_op = std::make_shared(mask_param_, mask_value_, axis_); - return tensor_op; -} - -Status MaskAlongAxisIIDOperation::to_json(nlohmann::json *out_json) { - RETURN_UNEXPECTED_IF_NULL(out_json); - nlohmann::json args; - args["mask_param"] = mask_param_; - args["mask_value"] = mask_value_; - args["axis"] = axis_; - *out_json = args; - return Status::OK(); -} -} // namespace audio -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2022-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/audio/ir/kernels/mask_along_axis_iid_ir.h" + +#include "minddata/dataset/audio/ir/validators.h" +#include "minddata/dataset/audio/kernels/mask_along_axis_iid_op.h" + +namespace mindspore { +namespace dataset { +namespace audio { +MaskAlongAxisIIDOperation::MaskAlongAxisIIDOperation(int32_t mask_param, float mask_value, int32_t axis) + : mask_param_(mask_param), mask_value_(mask_value), axis_(axis) { + random_op_ = true; +} + +MaskAlongAxisIIDOperation::~MaskAlongAxisIIDOperation() = default; + +Status MaskAlongAxisIIDOperation::ValidateParams() { + RETURN_IF_NOT_OK(ValidateIntScalarNonNegative("MaskAlongAxisIID", "mask_param", mask_param_)); + RETURN_IF_NOT_OK(ValidateScalarValue("MaskAlongAxisIID", "axis", axis_, {1, 2})); + return Status::OK(); +} + +std::string MaskAlongAxisIIDOperation::Name() const { return kMaskAlongAxisIIDOperation; } + +std::shared_ptr MaskAlongAxisIIDOperation::Build() { + std::shared_ptr tensor_op = std::make_shared(mask_param_, mask_value_, axis_); + return tensor_op; +} + +Status MaskAlongAxisIIDOperation::to_json(nlohmann::json *out_json) { + RETURN_UNEXPECTED_IF_NULL(out_json); + nlohmann::json args; + args["mask_param"] = mask_param_; + args["mask_value"] = mask_value_; + args["axis"] = axis_; + *out_json = args; + return Status::OK(); +} +} // namespace audio +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/mask_along_axis_iid_ir.h b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/mask_along_axis_iid_ir.h index fcfdbaa7932..43d7731b3d7 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/mask_along_axis_iid_ir.h +++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/mask_along_axis_iid_ir.h @@ -1,58 +1,58 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_MASK_ALONG_AXIS_IID_IR_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_MASK_ALONG_AXIS_IID_IR_H_ - -#include -#include -#include -#include -#include -#include - -#include "include/api/status.h" -#include "minddata/dataset/include/dataset/constants.h" -#include "minddata/dataset/kernels/ir/tensor_operation.h" -#include "minddata/dataset/util/random.h" - -namespace mindspore { -namespace dataset { -namespace audio { -constexpr char kMaskAlongAxisIIDOperation[] = "MaskAlongAxisIID"; - -class MaskAlongAxisIIDOperation : public TensorOperation { - public: - MaskAlongAxisIIDOperation(int32_t mask_param, float mask_value, int32_t axis); - - ~MaskAlongAxisIIDOperation() override; - - std::shared_ptr Build() override; - - Status ValidateParams() override; - - std::string Name() const override; - - Status to_json(nlohmann::json *out_json) override; - - private: - int32_t mask_param_; - float mask_value_; - int32_t axis_; -}; -} // namespace audio -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_MASK_ALONG_AXIS_IID_IR_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_MASK_ALONG_AXIS_IID_IR_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_MASK_ALONG_AXIS_IID_IR_H_ + +#include +#include +#include +#include +#include +#include + +#include "include/api/status.h" +#include "minddata/dataset/include/dataset/constants.h" +#include "minddata/dataset/kernels/ir/tensor_operation.h" +#include "minddata/dataset/util/random.h" + +namespace mindspore { +namespace dataset { +namespace audio { +constexpr char kMaskAlongAxisIIDOperation[] = "MaskAlongAxisIID"; + +class MaskAlongAxisIIDOperation : public TensorOperation { + public: + MaskAlongAxisIIDOperation(int32_t mask_param, float mask_value, int32_t axis); + + ~MaskAlongAxisIIDOperation() override; + + std::shared_ptr Build() override; + + Status ValidateParams() override; + + std::string Name() const override; + + Status to_json(nlohmann::json *out_json) override; + + private: + int32_t mask_param_; + float mask_value_; + int32_t axis_; +}; +} // namespace audio +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_MASK_ALONG_AXIS_IID_IR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/mask_along_axis_ir.cc b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/mask_along_axis_ir.cc index dd7fd01d1e5..b47fa0b34f3 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/mask_along_axis_ir.cc +++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/mask_along_axis_ir.cc @@ -1,57 +1,57 @@ -/** - * Copyright 2022-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/audio/ir/kernels/mask_along_axis_ir.h" - -#include "minddata/dataset/audio/ir/validators.h" -#include "minddata/dataset/audio/kernels/mask_along_axis_op.h" - -namespace mindspore { -namespace dataset { -namespace audio { -MaskAlongAxisOperation::MaskAlongAxisOperation(int32_t mask_start, int32_t mask_width, float mask_value, int32_t axis) - : mask_start_(mask_start), mask_width_(mask_width), mask_value_(mask_value), axis_(axis) {} - -MaskAlongAxisOperation::~MaskAlongAxisOperation() = default; - -Status MaskAlongAxisOperation::ValidateParams() { - RETURN_IF_NOT_OK(ValidateIntScalarNonNegative("MaskAlongAxis", "mask_start", mask_start_)); - RETURN_IF_NOT_OK(ValidateIntScalarPositive("MaskAlongAxis", "mask_width", mask_width_)); - RETURN_IF_NOT_OK(ValidateScalarValue("MaskAlongAxis", "axis", axis_, {1, 2})); - return Status::OK(); -} - -std::string MaskAlongAxisOperation::Name() const { return kMaskAlongAxisOperation; } - -std::shared_ptr MaskAlongAxisOperation::Build() { - std::shared_ptr tensor_op = - std::make_shared(mask_start_, mask_width_, mask_value_, axis_); - return tensor_op; -} - -Status MaskAlongAxisOperation::to_json(nlohmann::json *out_json) { - RETURN_UNEXPECTED_IF_NULL(out_json); - nlohmann::json args; - args["mask_start"] = mask_start_; - args["mask_width"] = mask_width_; - args["mask_value"] = mask_value_; - args["axis"] = axis_; - *out_json = args; - return Status::OK(); -} -} // namespace audio -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2022-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/audio/ir/kernels/mask_along_axis_ir.h" + +#include "minddata/dataset/audio/ir/validators.h" +#include "minddata/dataset/audio/kernels/mask_along_axis_op.h" + +namespace mindspore { +namespace dataset { +namespace audio { +MaskAlongAxisOperation::MaskAlongAxisOperation(int32_t mask_start, int32_t mask_width, float mask_value, int32_t axis) + : mask_start_(mask_start), mask_width_(mask_width), mask_value_(mask_value), axis_(axis) {} + +MaskAlongAxisOperation::~MaskAlongAxisOperation() = default; + +Status MaskAlongAxisOperation::ValidateParams() { + RETURN_IF_NOT_OK(ValidateIntScalarNonNegative("MaskAlongAxis", "mask_start", mask_start_)); + RETURN_IF_NOT_OK(ValidateIntScalarPositive("MaskAlongAxis", "mask_width", mask_width_)); + RETURN_IF_NOT_OK(ValidateScalarValue("MaskAlongAxis", "axis", axis_, {1, 2})); + return Status::OK(); +} + +std::string MaskAlongAxisOperation::Name() const { return kMaskAlongAxisOperation; } + +std::shared_ptr MaskAlongAxisOperation::Build() { + std::shared_ptr tensor_op = + std::make_shared(mask_start_, mask_width_, mask_value_, axis_); + return tensor_op; +} + +Status MaskAlongAxisOperation::to_json(nlohmann::json *out_json) { + RETURN_UNEXPECTED_IF_NULL(out_json); + nlohmann::json args; + args["mask_start"] = mask_start_; + args["mask_width"] = mask_width_; + args["mask_value"] = mask_value_; + args["axis"] = axis_; + *out_json = args; + return Status::OK(); +} +} // namespace audio +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/mask_along_axis_ir.h b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/mask_along_axis_ir.h index 07ba150c3c4..36b40a20ae3 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/mask_along_axis_ir.h +++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/mask_along_axis_ir.h @@ -1,58 +1,58 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_MASK_ALONG_AXIS_IR_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_MASK_ALONG_AXIS_IR_H_ - -#include -#include -#include -#include -#include - -#include "include/api/status.h" -#include "minddata/dataset/include/dataset/constants.h" -#include "minddata/dataset/kernels/ir/tensor_operation.h" - -namespace mindspore { -namespace dataset { -namespace audio { -constexpr char kMaskAlongAxisOperation[] = "MaskAlongAxis"; - -class MaskAlongAxisOperation : public TensorOperation { - public: - MaskAlongAxisOperation(int32_t mask_start, int32_t mask_width, float mask_value, int32_t axis); - - ~MaskAlongAxisOperation() override; - - std::shared_ptr Build() override; - - Status ValidateParams() override; - - std::string Name() const override; - - Status to_json(nlohmann::json *out_json) override; - - private: - int32_t mask_start_; - int32_t mask_width_; - float mask_value_; - int32_t axis_; -}; // class MaskAlongAxisOperation -} // namespace audio -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_MASK_ALONG_AXIS_IR_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_MASK_ALONG_AXIS_IR_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_MASK_ALONG_AXIS_IR_H_ + +#include +#include +#include +#include +#include + +#include "include/api/status.h" +#include "minddata/dataset/include/dataset/constants.h" +#include "minddata/dataset/kernels/ir/tensor_operation.h" + +namespace mindspore { +namespace dataset { +namespace audio { +constexpr char kMaskAlongAxisOperation[] = "MaskAlongAxis"; + +class MaskAlongAxisOperation : public TensorOperation { + public: + MaskAlongAxisOperation(int32_t mask_start, int32_t mask_width, float mask_value, int32_t axis); + + ~MaskAlongAxisOperation() override; + + std::shared_ptr Build() override; + + Status ValidateParams() override; + + std::string Name() const override; + + Status to_json(nlohmann::json *out_json) override; + + private: + int32_t mask_start_; + int32_t mask_width_; + float mask_value_; + int32_t axis_; +}; // class MaskAlongAxisOperation +} // namespace audio +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_MASK_ALONG_AXIS_IR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/phase_vocoder_ir.cc b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/phase_vocoder_ir.cc index a5178cb661b..368cf175ed4 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/phase_vocoder_ir.cc +++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/phase_vocoder_ir.cc @@ -1,59 +1,59 @@ -/** - * Copyright 2022-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "minddata/dataset/audio/ir/kernels/phase_vocoder_ir.h" - -#include "minddata/dataset/audio/ir/validators.h" -#include "minddata/dataset/audio/kernels/phase_vocoder_op.h" - -namespace mindspore { -namespace dataset { -namespace audio { -PhaseVocoderOperation::PhaseVocoderOperation(float rate, const std::shared_ptr &phase_advance) - : rate_(rate), phase_advance_(phase_advance) {} - -PhaseVocoderOperation::~PhaseVocoderOperation() = default; - -Status PhaseVocoderOperation::ValidateParams() { - const int kPhaseAdvanceRank = 2; - const int kLastDim = -1; - const int kLastDimSize = 1; - RETURN_IF_NOT_OK(ValidateFloatScalarPositive("PhaseVocoder", "rate", rate_)); - CHECK_FAIL_RETURN_SYNTAX_ERROR( - phase_advance_->Rank() == kPhaseAdvanceRank && phase_advance_->shape()[kLastDim] == kLastDimSize, - "PhaseVocoder: invalid parameter, 'phase_advance' should be in shape of ."); - return Status::OK(); -} - -std::string PhaseVocoderOperation::Name() const { return kPhaseVocoderOperation; } - -std::shared_ptr PhaseVocoderOperation::Build() { - std::shared_ptr tensor_op = std::make_shared(rate_, phase_advance_); - return tensor_op; -} - -Status PhaseVocoderOperation::to_json(nlohmann::json *out_json) { - RETURN_UNEXPECTED_IF_NULL(out_json); - nlohmann::json args; - args["rate"] = rate_; - nlohmann::json phase_advance; - RETURN_IF_NOT_OK(phase_advance_->to_json(&phase_advance)); - args["phase_advance"] = phase_advance; - *out_json = args; - return Status::OK(); -} -} // namespace audio -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2022-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/audio/ir/kernels/phase_vocoder_ir.h" + +#include "minddata/dataset/audio/ir/validators.h" +#include "minddata/dataset/audio/kernels/phase_vocoder_op.h" + +namespace mindspore { +namespace dataset { +namespace audio { +PhaseVocoderOperation::PhaseVocoderOperation(float rate, const std::shared_ptr &phase_advance) + : rate_(rate), phase_advance_(phase_advance) {} + +PhaseVocoderOperation::~PhaseVocoderOperation() = default; + +Status PhaseVocoderOperation::ValidateParams() { + const int kPhaseAdvanceRank = 2; + const int kLastDim = -1; + const int kLastDimSize = 1; + RETURN_IF_NOT_OK(ValidateFloatScalarPositive("PhaseVocoder", "rate", rate_)); + CHECK_FAIL_RETURN_SYNTAX_ERROR( + phase_advance_->Rank() == kPhaseAdvanceRank && phase_advance_->shape()[kLastDim] == kLastDimSize, + "PhaseVocoder: invalid parameter, 'phase_advance' should be in shape of ."); + return Status::OK(); +} + +std::string PhaseVocoderOperation::Name() const { return kPhaseVocoderOperation; } + +std::shared_ptr PhaseVocoderOperation::Build() { + std::shared_ptr tensor_op = std::make_shared(rate_, phase_advance_); + return tensor_op; +} + +Status PhaseVocoderOperation::to_json(nlohmann::json *out_json) { + RETURN_UNEXPECTED_IF_NULL(out_json); + nlohmann::json args; + args["rate"] = rate_; + nlohmann::json phase_advance; + RETURN_IF_NOT_OK(phase_advance_->to_json(&phase_advance)); + args["phase_advance"] = phase_advance; + *out_json = args; + return Status::OK(); +} +} // namespace audio +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/phase_vocoder_ir.h b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/phase_vocoder_ir.h index 4ce9c8f2992..4f074a402f5 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/phase_vocoder_ir.h +++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/phase_vocoder_ir.h @@ -1,52 +1,52 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_PHASE_VOCODER_IR_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_PHASE_VOCODER_IR_H_ - -#include -#include -#include - -#include "include/api/status.h" -#include "minddata/dataset/kernels/ir/tensor_operation.h" - -namespace mindspore { -namespace dataset { -namespace audio { -constexpr char kPhaseVocoderOperation[] = "PhaseVocoder"; - -class PhaseVocoderOperation : public TensorOperation { - public: - PhaseVocoderOperation(float rate, const std::shared_ptr &phase_advance); - - ~PhaseVocoderOperation(); - - std::shared_ptr Build() override; - - Status ValidateParams() override; - - std::string Name() const override; - - Status to_json(nlohmann::json *out_json) override; - - private: - float rate_; - std::shared_ptr phase_advance_; -}; -} // namespace audio -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_PHASE_VOCODER_IR_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_PHASE_VOCODER_IR_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_PHASE_VOCODER_IR_H_ + +#include +#include +#include + +#include "include/api/status.h" +#include "minddata/dataset/kernels/ir/tensor_operation.h" + +namespace mindspore { +namespace dataset { +namespace audio { +constexpr char kPhaseVocoderOperation[] = "PhaseVocoder"; + +class PhaseVocoderOperation : public TensorOperation { + public: + PhaseVocoderOperation(float rate, const std::shared_ptr &phase_advance); + + ~PhaseVocoderOperation(); + + std::shared_ptr Build() override; + + Status ValidateParams() override; + + std::string Name() const override; + + Status to_json(nlohmann::json *out_json) override; + + private: + float rate_; + std::shared_ptr phase_advance_; +}; +} // namespace audio +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_PHASE_VOCODER_IR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/allpass_biquad_op.cc b/mindspore/ccsrc/minddata/dataset/audio/kernels/allpass_biquad_op.cc index f31efe61945..e270446ed83 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/kernels/allpass_biquad_op.cc +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/allpass_biquad_op.cc @@ -1,47 +1,47 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "minddata/dataset/audio/kernels/allpass_biquad_op.h" - -#include "minddata/dataset/audio/kernels/audio_utils.h" -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { -Status AllpassBiquadOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - RETURN_IF_NOT_OK(ValidateLowRank("AllpassBiquad", input, kMinAudioDim, "<..., time>")); - RETURN_IF_NOT_OK(ValidateTensorFloat("AllpassBiquad", input)); - double w0 = 2 * PI * central_freq_ / sample_rate_; - double alpha = sin(w0) / 2 / Q_; - double b0 = 1 - alpha; - double b1 = -2 * cos(w0); - double b2 = 1 + alpha; - double a0 = b2; - double a1 = -2 * cos(w0); - double a2 = 1 - alpha; - if (input->type() == DataType(DataType::DE_FLOAT32)) { - return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), - static_cast(a0), static_cast(a1), static_cast(a2)); - } else if (input->type() == DataType(DataType::DE_FLOAT64)) { - return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), - static_cast(a0), static_cast(a1), static_cast(a2)); - } else { - return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), - static_cast(a0), static_cast(a1), static_cast(a2)); - } -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/audio/kernels/allpass_biquad_op.h" + +#include "minddata/dataset/audio/kernels/audio_utils.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +Status AllpassBiquadOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + RETURN_IF_NOT_OK(ValidateLowRank("AllpassBiquad", input, kMinAudioDim, "<..., time>")); + RETURN_IF_NOT_OK(ValidateTensorFloat("AllpassBiquad", input)); + double w0 = 2 * PI * central_freq_ / sample_rate_; + double alpha = sin(w0) / 2 / Q_; + double b0 = 1 - alpha; + double b1 = -2 * cos(w0); + double b2 = 1 + alpha; + double a0 = b2; + double a1 = -2 * cos(w0); + double a2 = 1 - alpha; + if (input->type() == DataType(DataType::DE_FLOAT32)) { + return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), + static_cast(a0), static_cast(a1), static_cast(a2)); + } else if (input->type() == DataType(DataType::DE_FLOAT64)) { + return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), + static_cast(a0), static_cast(a1), static_cast(a2)); + } else { + return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), + static_cast(a0), static_cast(a1), static_cast(a2)); + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/allpass_biquad_op.h b/mindspore/ccsrc/minddata/dataset/audio/kernels/allpass_biquad_op.h index 26c7b729f0a..3dd9b237deb 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/kernels/allpass_biquad_op.h +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/allpass_biquad_op.h @@ -1,53 +1,53 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_ALLPASS_BIQUAD_OP_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_ALLPASS_BIQUAD_OP_H_ - -#include -#include -#include - -#include "minddata/dataset/core/tensor.h" -#include "minddata/dataset/kernels/tensor_op.h" -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { - -class AllpassBiquadOp : public TensorOp { - public: - AllpassBiquadOp(int32_t sample_rate, float central_freq, float Q) - : sample_rate_(sample_rate), central_freq_(central_freq), Q_(Q) {} - - ~AllpassBiquadOp() override = default; - - void Print(std::ostream &out) const override { - out << Name() << ": sample_rate: " << sample_rate_ << ", central_freq: " << central_freq_ << ", Q: " << Q_ - << std::endl; - } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - std::string Name() const override { return kAllpassBiquadOp; } - - private: - int32_t sample_rate_; - float central_freq_; - float Q_; -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_ALLPASS_BIQUAD_OP_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_ALLPASS_BIQUAD_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_ALLPASS_BIQUAD_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +class AllpassBiquadOp : public TensorOp { + public: + AllpassBiquadOp(int32_t sample_rate, float central_freq, float Q) + : sample_rate_(sample_rate), central_freq_(central_freq), Q_(Q) {} + + ~AllpassBiquadOp() override = default; + + void Print(std::ostream &out) const override { + out << Name() << ": sample_rate: " << sample_rate_ << ", central_freq: " << central_freq_ << ", Q: " << Q_ + << std::endl; + } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kAllpassBiquadOp; } + + private: + int32_t sample_rate_; + float central_freq_; + float Q_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_ALLPASS_BIQUAD_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/band_biquad_op.cc b/mindspore/ccsrc/minddata/dataset/audio/kernels/band_biquad_op.cc index 192af94f0f7..bda261cd3a4 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/kernels/band_biquad_op.cc +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/band_biquad_op.cc @@ -1,56 +1,56 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "minddata/dataset/audio/kernels/band_biquad_op.h" - -#include "minddata/dataset/audio/kernels/audio_utils.h" -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { -Status BandBiquadOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - // check input tensor dimension, it should be greater than 0. - RETURN_IF_NOT_OK(ValidateLowRank("BandBiquad", input, kMinAudioDim, "<..., time>")); - // check input type, it should be DE_FLOAT32 or DE_FLOAT16 or DE_FLOAT64 - RETURN_IF_NOT_OK(ValidateTensorFloat("BandBiquad", input)); - double w0 = 2 * PI * central_freq_ / sample_rate_; - double bw_Hz = central_freq_ / Q_; - double a0 = 1.; - double a2 = exp(-2 * PI * bw_Hz / sample_rate_); - double a1 = -4 * a2 / (1 + a2) * cos(w0); - CHECK_FAIL_RETURN_UNEXPECTED( - a2 != 0, "BandBiquad: zero division error, 'central_freq / Q / sample_rate' got a big negative value."); - double b0 = sqrt(1 - a1 * a1 / (4 * a2)) * (1 - a2); - if (noise_) { - CHECK_FAIL_RETURN_UNEXPECTED(b0 != 0, "BandBiquad: zero division error, 'b0' can not be zero."); - double mutl = sqrt(((1 + a2) * (1 + a2) - a1 * a1) * (1 - a2) / (1 + a2)) / b0; - b0 *= mutl; - } - double b1 = 0.; - double b2 = 0.; - if (input->type() == DataType(DataType::DE_FLOAT32)) { - return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), - static_cast(a0), static_cast(a1), static_cast(a2)); - } else if (input->type() == DataType(DataType::DE_FLOAT64)) { - return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), - static_cast(a0), static_cast(a1), static_cast(a2)); - } else { - return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), - static_cast(a0), static_cast(a1), static_cast(a2)); - } -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/audio/kernels/band_biquad_op.h" + +#include "minddata/dataset/audio/kernels/audio_utils.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +Status BandBiquadOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + // check input tensor dimension, it should be greater than 0. + RETURN_IF_NOT_OK(ValidateLowRank("BandBiquad", input, kMinAudioDim, "<..., time>")); + // check input type, it should be DE_FLOAT32 or DE_FLOAT16 or DE_FLOAT64 + RETURN_IF_NOT_OK(ValidateTensorFloat("BandBiquad", input)); + double w0 = 2 * PI * central_freq_ / sample_rate_; + double bw_Hz = central_freq_ / Q_; + double a0 = 1.; + double a2 = exp(-2 * PI * bw_Hz / sample_rate_); + double a1 = -4 * a2 / (1 + a2) * cos(w0); + CHECK_FAIL_RETURN_UNEXPECTED( + a2 != 0, "BandBiquad: zero division error, 'central_freq / Q / sample_rate' got a big negative value."); + double b0 = sqrt(1 - a1 * a1 / (4 * a2)) * (1 - a2); + if (noise_) { + CHECK_FAIL_RETURN_UNEXPECTED(b0 != 0, "BandBiquad: zero division error, 'b0' can not be zero."); + double mutl = sqrt(((1 + a2) * (1 + a2) - a1 * a1) * (1 - a2) / (1 + a2)) / b0; + b0 *= mutl; + } + double b1 = 0.; + double b2 = 0.; + if (input->type() == DataType(DataType::DE_FLOAT32)) { + return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), + static_cast(a0), static_cast(a1), static_cast(a2)); + } else if (input->type() == DataType(DataType::DE_FLOAT64)) { + return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), + static_cast(a0), static_cast(a1), static_cast(a2)); + } else { + return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), + static_cast(a0), static_cast(a1), static_cast(a2)); + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/band_biquad_op.h b/mindspore/ccsrc/minddata/dataset/audio/kernels/band_biquad_op.h index c92bda5fdd6..01f37d4de39 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/kernels/band_biquad_op.h +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/band_biquad_op.h @@ -1,54 +1,54 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_BAND_BIQUAD_OP_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_BAND_BIQUAD_OP_H_ - -#include -#include -#include - -#include "minddata/dataset/core/tensor.h" -#include "minddata/dataset/kernels/tensor_op.h" -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { - -class BandBiquadOp : public TensorOp { - public: - BandBiquadOp(int32_t sample_rate, float central_freq, float Q, bool noise) - : sample_rate_(sample_rate), central_freq_(central_freq), Q_(Q), noise_(noise) {} - - ~BandBiquadOp() override = default; - - void Print(std::ostream &out) const override { - out << Name() << ": sample_rate: " << sample_rate_ << ", central_freq: " << central_freq_ << ", Q: " << Q_ - << ", noise: " << noise_ << std::endl; - } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - std::string Name() const override { return kBandBiquadOp; } - - private: - int32_t sample_rate_; - float central_freq_; - float Q_; - bool noise_; -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_BAND_BIQUAD_OP_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_BAND_BIQUAD_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_BAND_BIQUAD_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +class BandBiquadOp : public TensorOp { + public: + BandBiquadOp(int32_t sample_rate, float central_freq, float Q, bool noise) + : sample_rate_(sample_rate), central_freq_(central_freq), Q_(Q), noise_(noise) {} + + ~BandBiquadOp() override = default; + + void Print(std::ostream &out) const override { + out << Name() << ": sample_rate: " << sample_rate_ << ", central_freq: " << central_freq_ << ", Q: " << Q_ + << ", noise: " << noise_ << std::endl; + } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kBandBiquadOp; } + + private: + int32_t sample_rate_; + float central_freq_; + float Q_; + bool noise_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_BAND_BIQUAD_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/bandpass_biquad_op.cc b/mindspore/ccsrc/minddata/dataset/audio/kernels/bandpass_biquad_op.cc old mode 100755 new mode 100644 index 9076016fb00..9fe3bdeaefb --- a/mindspore/ccsrc/minddata/dataset/audio/kernels/bandpass_biquad_op.cc +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/bandpass_biquad_op.cc @@ -1,56 +1,56 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "minddata/dataset/audio/kernels/bandpass_biquad_op.h" - -#include "minddata/dataset/audio/kernels/audio_utils.h" -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { -Status BandpassBiquadOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - RETURN_IF_NOT_OK(ValidateLowRank("BandpassBiquad", input, kMinAudioDim, "<..., time>")); - // check input type, it should be DE_FLOAT32 or DE_FLOAT16 or DE_FLOAT64 - RETURN_IF_NOT_OK(ValidateTensorFloat("BandpassBiquad", input)); - float w0 = 2 * PI * central_freq_ / sample_rate_; - float alpha = sin(w0) / 2 / Q_; - float temp; - if (const_skirt_gain_) { - temp = sin(w0) / 2; - } else { - temp = alpha; - } - - float b0 = temp; - float b1 = 0.0; - float b2 = -temp; - float a0 = 1 + alpha; - float a1 = (-2) * cos(w0); - float a2 = 1 - alpha; - - if (input->type() == DataType(DataType::DE_FLOAT32)) { - return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), - static_cast(a0), static_cast(a1), static_cast(a2)); - } else if (input->type() == DataType(DataType::DE_FLOAT64)) { - return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), - static_cast(a0), static_cast(a1), static_cast(a2)); - } else { - return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), - static_cast(a0), static_cast(a1), static_cast(a2)); - } -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/audio/kernels/bandpass_biquad_op.h" + +#include "minddata/dataset/audio/kernels/audio_utils.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +Status BandpassBiquadOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + RETURN_IF_NOT_OK(ValidateLowRank("BandpassBiquad", input, kMinAudioDim, "<..., time>")); + // check input type, it should be DE_FLOAT32 or DE_FLOAT16 or DE_FLOAT64 + RETURN_IF_NOT_OK(ValidateTensorFloat("BandpassBiquad", input)); + float w0 = 2 * PI * central_freq_ / sample_rate_; + float alpha = sin(w0) / 2 / Q_; + float temp; + if (const_skirt_gain_) { + temp = sin(w0) / 2; + } else { + temp = alpha; + } + + float b0 = temp; + float b1 = 0.0; + float b2 = -temp; + float a0 = 1 + alpha; + float a1 = (-2) * cos(w0); + float a2 = 1 - alpha; + + if (input->type() == DataType(DataType::DE_FLOAT32)) { + return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), + static_cast(a0), static_cast(a1), static_cast(a2)); + } else if (input->type() == DataType(DataType::DE_FLOAT64)) { + return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), + static_cast(a0), static_cast(a1), static_cast(a2)); + } else { + return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), + static_cast(a0), static_cast(a1), static_cast(a2)); + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/bandpass_biquad_op.h b/mindspore/ccsrc/minddata/dataset/audio/kernels/bandpass_biquad_op.h old mode 100755 new mode 100644 index dead035fbc4..8a3f966203a --- a/mindspore/ccsrc/minddata/dataset/audio/kernels/bandpass_biquad_op.h +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/bandpass_biquad_op.h @@ -1,54 +1,54 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_BANDPASS_BIQUAD_OP_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_BANDPASS_BIQUAD_OP_H_ - -#include -#include -#include - -#include "minddata/dataset/core/tensor.h" -#include "minddata/dataset/kernels/tensor_op.h" -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { - -class BandpassBiquadOp : public TensorOp { - public: - BandpassBiquadOp(int32_t sample_rate, float central_freq, float Q, bool const_skirt_gain) - : sample_rate_(sample_rate), central_freq_(central_freq), Q_(Q), const_skirt_gain_(const_skirt_gain) {} - - ~BandpassBiquadOp() override = default; - - void Print(std::ostream &out) const override { - out << Name() << ": sample_rate: " << sample_rate_ << ", central_freq: " << central_freq_ << ", Q: " << Q_ - << ", const_skirt_gain: " << const_skirt_gain_ << std::endl; - } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - std::string Name() const override { return kBandpassBiquadOp; } - - private: - int32_t sample_rate_; - float central_freq_; - float Q_; - bool const_skirt_gain_; -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_BANDPASS_BIQUAD_OP_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_BANDPASS_BIQUAD_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_BANDPASS_BIQUAD_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +class BandpassBiquadOp : public TensorOp { + public: + BandpassBiquadOp(int32_t sample_rate, float central_freq, float Q, bool const_skirt_gain) + : sample_rate_(sample_rate), central_freq_(central_freq), Q_(Q), const_skirt_gain_(const_skirt_gain) {} + + ~BandpassBiquadOp() override = default; + + void Print(std::ostream &out) const override { + out << Name() << ": sample_rate: " << sample_rate_ << ", central_freq: " << central_freq_ << ", Q: " << Q_ + << ", const_skirt_gain: " << const_skirt_gain_ << std::endl; + } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kBandpassBiquadOp; } + + private: + int32_t sample_rate_; + float central_freq_; + float Q_; + bool const_skirt_gain_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_BANDPASS_BIQUAD_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/bandreject_biquad_op.cc b/mindspore/ccsrc/minddata/dataset/audio/kernels/bandreject_biquad_op.cc index a165e5fbc8d..4d2e5f4996f 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/kernels/bandreject_biquad_op.cc +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/bandreject_biquad_op.cc @@ -1,48 +1,48 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "minddata/dataset/audio/kernels/bandreject_biquad_op.h" - -#include "minddata/dataset/audio/kernels/audio_utils.h" -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { -Status BandrejectBiquadOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - // check input type and input shape - RETURN_IF_NOT_OK(ValidateLowRank("BandrejectBiquad", input, kMinAudioDim, "<..., time>")); - RETURN_IF_NOT_OK(ValidateTensorFloat("BandrejectBiquad", input)); - double w0 = 2 * PI * central_freq_ / sample_rate_; - double alpha = sin(w0) / 2 / Q_; - double b0 = 1; - double b1 = -2 * cos(w0); - double b2 = 1; - double a0 = 1 + alpha; - double a1 = b1; - double a2 = 1 - alpha; - if (input->type() == DataType(DataType::DE_FLOAT32)) { - return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), - static_cast(a0), static_cast(a1), static_cast(a2)); - } else if (input->type() == DataType(DataType::DE_FLOAT64)) { - return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), - static_cast(a0), static_cast(a1), static_cast(a2)); - } else { - return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), - static_cast(a0), static_cast(a1), static_cast(a2)); - } -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/audio/kernels/bandreject_biquad_op.h" + +#include "minddata/dataset/audio/kernels/audio_utils.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +Status BandrejectBiquadOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + // check input type and input shape + RETURN_IF_NOT_OK(ValidateLowRank("BandrejectBiquad", input, kMinAudioDim, "<..., time>")); + RETURN_IF_NOT_OK(ValidateTensorFloat("BandrejectBiquad", input)); + double w0 = 2 * PI * central_freq_ / sample_rate_; + double alpha = sin(w0) / 2 / Q_; + double b0 = 1; + double b1 = -2 * cos(w0); + double b2 = 1; + double a0 = 1 + alpha; + double a1 = b1; + double a2 = 1 - alpha; + if (input->type() == DataType(DataType::DE_FLOAT32)) { + return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), + static_cast(a0), static_cast(a1), static_cast(a2)); + } else if (input->type() == DataType(DataType::DE_FLOAT64)) { + return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), + static_cast(a0), static_cast(a1), static_cast(a2)); + } else { + return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), + static_cast(a0), static_cast(a1), static_cast(a2)); + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/bandreject_biquad_op.h b/mindspore/ccsrc/minddata/dataset/audio/kernels/bandreject_biquad_op.h index e59d0cf3220..e2d5e50df72 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/kernels/bandreject_biquad_op.h +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/bandreject_biquad_op.h @@ -1,53 +1,53 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_BANDREJECT_BIQUAD_OP_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_BANDREJECT_BIQUAD_OP_H_ - -#include -#include -#include - -#include "minddata/dataset/core/tensor.h" -#include "minddata/dataset/kernels/tensor_op.h" -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { - -class BandrejectBiquadOp : public TensorOp { - public: - BandrejectBiquadOp(int32_t sample_rate, float central_freq, float Q) - : sample_rate_(sample_rate), central_freq_(central_freq), Q_(Q) {} - - ~BandrejectBiquadOp() override = default; - - void Print(std::ostream &out) const override { - out << Name() << ": sample_rate: " << sample_rate_ << ", central_freq: " << central_freq_ << ", Q: " << Q_ - << std::endl; - } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - std::string Name() const override { return kBandrejectBiquadOp; } - - private: - int32_t sample_rate_; - float central_freq_; - float Q_; -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_BANDREJECT_BIQUAD_OP_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_BANDREJECT_BIQUAD_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_BANDREJECT_BIQUAD_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +class BandrejectBiquadOp : public TensorOp { + public: + BandrejectBiquadOp(int32_t sample_rate, float central_freq, float Q) + : sample_rate_(sample_rate), central_freq_(central_freq), Q_(Q) {} + + ~BandrejectBiquadOp() override = default; + + void Print(std::ostream &out) const override { + out << Name() << ": sample_rate: " << sample_rate_ << ", central_freq: " << central_freq_ << ", Q: " << Q_ + << std::endl; + } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kBandrejectBiquadOp; } + + private: + int32_t sample_rate_; + float central_freq_; + float Q_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_BANDREJECT_BIQUAD_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/bass_biquad_op.cc b/mindspore/ccsrc/minddata/dataset/audio/kernels/bass_biquad_op.cc index 22116a7d8bb..6915fb967f8 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/kernels/bass_biquad_op.cc +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/bass_biquad_op.cc @@ -1,56 +1,56 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "minddata/dataset/audio/kernels/bass_biquad_op.h" - -#include "minddata/dataset/audio/kernels/audio_utils.h" -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { -Status BassBiquadOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - RETURN_IF_NOT_OK(ValidateLowRank("BassBiquad", input, kMinAudioDim, "<..., time>")); - // check input type, it should be DE_FLOAT32 or DE_FLOAT16 or DE_FLOAT64 - RETURN_IF_NOT_OK(ValidateTensorFloat("BassBiquad", input)); - double w0 = 2 * PI * central_freq_ / sample_rate_; - double alpha = sin(w0) / 2 / Q_; - double A = exp(gain_ / 40 * log(10)); - - double temp1 = 2 * sqrt(A) * alpha; - double temp2 = (A - 1) * cos(w0); - double temp3 = (A + 1) * cos(w0); - - double b0 = A * ((A + 1) - temp2 + temp1); - double b1 = 2 * A * ((A - 1) - temp3); - double b2 = A * ((A + 1) - temp2 - temp1); - double a0 = (A + 1) + temp2 + temp1; - double a1 = -2 * ((A - 1) + temp3); - double a2 = (A + 1) + temp2 - temp1; - if (input->type() == DataType(DataType::DE_FLOAT32)) { - return Biquad(input, output, static_cast(b0 / a0), static_cast(b1 / a0), static_cast(b2 / a0), - static_cast(1.0), static_cast(a1 / a0), static_cast(a2 / a0)); - } else if (input->type() == DataType(DataType::DE_FLOAT64)) { - return Biquad(input, output, static_cast(b0 / a0), static_cast(b1 / a0), - static_cast(b2 / a0), static_cast(1.0), static_cast(a1 / a0), - static_cast(a2 / a0)); - } else { - return Biquad(input, output, static_cast(b0 / a0), static_cast(b1 / a0), - static_cast(b2 / a0), static_cast(1.0), static_cast(a1 / a0), - static_cast(a2 / a0)); - } -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/audio/kernels/bass_biquad_op.h" + +#include "minddata/dataset/audio/kernels/audio_utils.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +Status BassBiquadOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + RETURN_IF_NOT_OK(ValidateLowRank("BassBiquad", input, kMinAudioDim, "<..., time>")); + // check input type, it should be DE_FLOAT32 or DE_FLOAT16 or DE_FLOAT64 + RETURN_IF_NOT_OK(ValidateTensorFloat("BassBiquad", input)); + double w0 = 2 * PI * central_freq_ / sample_rate_; + double alpha = sin(w0) / 2 / Q_; + double A = exp(gain_ / 40 * log(10)); + + double temp1 = 2 * sqrt(A) * alpha; + double temp2 = (A - 1) * cos(w0); + double temp3 = (A + 1) * cos(w0); + + double b0 = A * ((A + 1) - temp2 + temp1); + double b1 = 2 * A * ((A - 1) - temp3); + double b2 = A * ((A + 1) - temp2 - temp1); + double a0 = (A + 1) + temp2 + temp1; + double a1 = -2 * ((A - 1) + temp3); + double a2 = (A + 1) + temp2 - temp1; + if (input->type() == DataType(DataType::DE_FLOAT32)) { + return Biquad(input, output, static_cast(b0 / a0), static_cast(b1 / a0), static_cast(b2 / a0), + static_cast(1.0), static_cast(a1 / a0), static_cast(a2 / a0)); + } else if (input->type() == DataType(DataType::DE_FLOAT64)) { + return Biquad(input, output, static_cast(b0 / a0), static_cast(b1 / a0), + static_cast(b2 / a0), static_cast(1.0), static_cast(a1 / a0), + static_cast(a2 / a0)); + } else { + return Biquad(input, output, static_cast(b0 / a0), static_cast(b1 / a0), + static_cast(b2 / a0), static_cast(1.0), static_cast(a1 / a0), + static_cast(a2 / a0)); + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/bass_biquad_op.h b/mindspore/ccsrc/minddata/dataset/audio/kernels/bass_biquad_op.h index 68552c1bb80..87985d176cd 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/kernels/bass_biquad_op.h +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/bass_biquad_op.h @@ -1,55 +1,55 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_BASS_BIQUAD_OP_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_BASS_BIQUAD_OP_H_ - -#include -#include -#include - -#include "minddata/dataset/core/tensor.h" -#include "minddata/dataset/kernels/tensor_op.h" -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { - -class BassBiquadOp : public TensorOp { - public: - BassBiquadOp(int32_t sample_rate, float gain, float central_freq, float Q) - : sample_rate_(sample_rate), gain_(gain), central_freq_(central_freq), Q_(Q) {} - - ~BassBiquadOp() override = default; - - void Print(std::ostream &out) const override { - out << Name() << ": sample_rate: " << sample_rate_ << ", gain: " << gain_ << ", central_freq: " << central_freq_ - << ", Q: " << Q_ << std::endl; - } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - std::string Name() const override { return kBassBiquadOp; } - - private: - int32_t sample_rate_; - float gain_; - float central_freq_; - float Q_; -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_BASS_BIQUAD_OP_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_BASS_BIQUAD_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_BASS_BIQUAD_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +class BassBiquadOp : public TensorOp { + public: + BassBiquadOp(int32_t sample_rate, float gain, float central_freq, float Q) + : sample_rate_(sample_rate), gain_(gain), central_freq_(central_freq), Q_(Q) {} + + ~BassBiquadOp() override = default; + + void Print(std::ostream &out) const override { + out << Name() << ": sample_rate: " << sample_rate_ << ", gain: " << gain_ << ", central_freq: " << central_freq_ + << ", Q: " << Q_ << std::endl; + } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kBassBiquadOp; } + + private: + int32_t sample_rate_; + float gain_; + float central_freq_; + float Q_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_BASS_BIQUAD_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/biquad_op.cc b/mindspore/ccsrc/minddata/dataset/audio/kernels/biquad_op.cc index aa3cb0b45cb..4cb37b9cff7 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/kernels/biquad_op.cc +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/biquad_op.cc @@ -1,41 +1,41 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "minddata/dataset/audio/kernels/biquad_op.h" - -#include "minddata/dataset/audio/kernels/audio_utils.h" -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { -Status BiquadOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - // check input tensor dimension, it should be greater than 0. - RETURN_IF_NOT_OK(ValidateLowRank("Biquad", input, kMinAudioDim, "<..., time>")); - // check input type, it should be DE_FLOAT32 or DE_FLOAT16 or DE_FLOAT64 - RETURN_IF_NOT_OK(ValidateTensorFloat("Biquad", input)); - if (input->type() == DataType(DataType::DE_FLOAT32)) { - return Biquad(input, output, static_cast(b0_), static_cast(b1_), static_cast(b2_), - static_cast(a0_), static_cast(a1_), static_cast(a2_)); - } else if (input->type() == DataType(DataType::DE_FLOAT64)) { - return Biquad(input, output, static_cast(b0_), static_cast(b1_), static_cast(b2_), - static_cast(a0_), static_cast(a1_), static_cast(a2_)); - } else { - return Biquad(input, output, static_cast(b0_), static_cast(b1_), static_cast(b2_), - static_cast(a0_), static_cast(a1_), static_cast(a2_)); - } -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/audio/kernels/biquad_op.h" + +#include "minddata/dataset/audio/kernels/audio_utils.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +Status BiquadOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + // check input tensor dimension, it should be greater than 0. + RETURN_IF_NOT_OK(ValidateLowRank("Biquad", input, kMinAudioDim, "<..., time>")); + // check input type, it should be DE_FLOAT32 or DE_FLOAT16 or DE_FLOAT64 + RETURN_IF_NOT_OK(ValidateTensorFloat("Biquad", input)); + if (input->type() == DataType(DataType::DE_FLOAT32)) { + return Biquad(input, output, static_cast(b0_), static_cast(b1_), static_cast(b2_), + static_cast(a0_), static_cast(a1_), static_cast(a2_)); + } else if (input->type() == DataType(DataType::DE_FLOAT64)) { + return Biquad(input, output, static_cast(b0_), static_cast(b1_), static_cast(b2_), + static_cast(a0_), static_cast(a1_), static_cast(a2_)); + } else { + return Biquad(input, output, static_cast(b0_), static_cast(b1_), static_cast(b2_), + static_cast(a0_), static_cast(a1_), static_cast(a2_)); + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/biquad_op.h b/mindspore/ccsrc/minddata/dataset/audio/kernels/biquad_op.h index e6fe3553c70..cbca0956c20 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/kernels/biquad_op.h +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/biquad_op.h @@ -1,55 +1,55 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_BIQUAD_OP_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_BIQUAD_OP_H_ - -#include -#include -#include - -#include "minddata/dataset/core/tensor.h" -#include "minddata/dataset/kernels/tensor_op.h" -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class BiquadOp : public TensorOp { - public: - BiquadOp(float b0, float b1, float b2, float a0, float a1, float a2) - : b0_(b0), b1_(b1), b2_(b2), a0_(a0), a1_(a1), a2_(a2) {} - - ~BiquadOp() override = default; - - void Print(std::ostream &out) const override { - out << Name() << ": b0: " << b0_ << ", b1: " << b1_ << ", b2: " << b2_ << ", a0: " << a0_ << ", a1: " << a1_ - << ", a2: " << a2_ << std::endl; - } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - std::string Name() const override { return kBiquadOp; } - - private: - float b0_; - float b1_; - float b2_; - float a0_; - float a1_; - float a2_; -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_BIQUAD_OP_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_BIQUAD_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_BIQUAD_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class BiquadOp : public TensorOp { + public: + BiquadOp(float b0, float b1, float b2, float a0, float a1, float a2) + : b0_(b0), b1_(b1), b2_(b2), a0_(a0), a1_(a1), a2_(a2) {} + + ~BiquadOp() override = default; + + void Print(std::ostream &out) const override { + out << Name() << ": b0: " << b0_ << ", b1: " << b1_ << ", b2: " << b2_ << ", a0: " << a0_ << ", a1: " << a1_ + << ", a2: " << a2_ << std::endl; + } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kBiquadOp; } + + private: + float b0_; + float b1_; + float b2_; + float a0_; + float a1_; + float a2_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_BIQUAD_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/deemph_biquad_op.cc b/mindspore/ccsrc/minddata/dataset/audio/kernels/deemph_biquad_op.cc index 6d4721b213f..e7fefa2da9f 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/kernels/deemph_biquad_op.cc +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/deemph_biquad_op.cc @@ -1,72 +1,72 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "minddata/dataset/audio/kernels/deemph_biquad_op.h" - -#include "minddata/dataset/audio/kernels/audio_utils.h" -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { -Status DeemphBiquadOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - RETURN_IF_NOT_OK(ValidateLowRank("DeemphBiquad", input, kMinAudioDim, "<..., time>")); - RETURN_IF_NOT_OK(ValidateTensorFloat("DeemphBiquad", input)); - const int32_t kSampleRate44100 = 44100; - const int32_t kSampleRate48000 = 48000; - int32_t central_freq = 0; - double width_slope = 1; - double gain = 0.0; - if (sample_rate_ == kSampleRate44100) { - central_freq = 5283; // central_freq value from SoX - width_slope = 0.4845; // width_slope value from SoX - gain = -9.477; // gain value from SoX - } else if (sample_rate_ == kSampleRate48000) { - central_freq = 5356; // central_freq value from SoX - width_slope = 0.479; // width_slope value from SoX - gain = -9.62; // gain value from SoX - } else { - RETURN_STATUS_UNEXPECTED( - "The sample_rate parameter only supports 44100 or 48000, but got: " + std::to_string(sample_rate_) + "."); - } - - double w0 = 2 * PI * central_freq / sample_rate_; - double A = exp(gain / 40 * log(10)); - double alpha = sin(w0) / 2 * sqrt((A + 1 / A) * (1 / width_slope - 1) + 2); - - // temp1, temp2, temp3 are the intermediate variable used to solve for a and b. - double temp1 = 2 * sqrt(A) * alpha; - double temp2 = (A - 1) * cos(w0); - double temp3 = (A + 1) * cos(w0); - - double b0 = A * ((A + 1) + temp2 + temp1); - double b1 = -2 * A * ((A - 1) + temp3); - double b2 = A * ((A + 1) + temp2 - temp1); - double a0 = (A + 1) - temp2 + temp1; - double a1 = 2 * ((A - 1) - temp3); - double a2 = (A + 1) - temp2 - temp1; - if (input->type() == DataType(DataType::DE_FLOAT32)) { - return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), - static_cast(a0), static_cast(a1), static_cast(a2)); - } else if (input->type() == DataType(DataType::DE_FLOAT64)) { - return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), - static_cast(a0), static_cast(a1), static_cast(a2)); - } else { - return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), - static_cast(a0), static_cast(a1), static_cast(a2)); - } -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/audio/kernels/deemph_biquad_op.h" + +#include "minddata/dataset/audio/kernels/audio_utils.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +Status DeemphBiquadOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + RETURN_IF_NOT_OK(ValidateLowRank("DeemphBiquad", input, kMinAudioDim, "<..., time>")); + RETURN_IF_NOT_OK(ValidateTensorFloat("DeemphBiquad", input)); + const int32_t kSampleRate44100 = 44100; + const int32_t kSampleRate48000 = 48000; + int32_t central_freq = 0; + double width_slope = 1; + double gain = 0.0; + if (sample_rate_ == kSampleRate44100) { + central_freq = 5283; // central_freq value from SoX + width_slope = 0.4845; // width_slope value from SoX + gain = -9.477; // gain value from SoX + } else if (sample_rate_ == kSampleRate48000) { + central_freq = 5356; // central_freq value from SoX + width_slope = 0.479; // width_slope value from SoX + gain = -9.62; // gain value from SoX + } else { + RETURN_STATUS_UNEXPECTED( + "The sample_rate parameter only supports 44100 or 48000, but got: " + std::to_string(sample_rate_) + "."); + } + + double w0 = 2 * PI * central_freq / sample_rate_; + double A = exp(gain / 40 * log(10)); + double alpha = sin(w0) / 2 * sqrt((A + 1 / A) * (1 / width_slope - 1) + 2); + + // temp1, temp2, temp3 are the intermediate variable used to solve for a and b. + double temp1 = 2 * sqrt(A) * alpha; + double temp2 = (A - 1) * cos(w0); + double temp3 = (A + 1) * cos(w0); + + double b0 = A * ((A + 1) + temp2 + temp1); + double b1 = -2 * A * ((A - 1) + temp3); + double b2 = A * ((A + 1) + temp2 - temp1); + double a0 = (A + 1) - temp2 + temp1; + double a1 = 2 * ((A - 1) - temp3); + double a2 = (A + 1) - temp2 - temp1; + if (input->type() == DataType(DataType::DE_FLOAT32)) { + return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), + static_cast(a0), static_cast(a1), static_cast(a2)); + } else if (input->type() == DataType(DataType::DE_FLOAT64)) { + return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), + static_cast(a0), static_cast(a1), static_cast(a2)); + } else { + return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), + static_cast(a0), static_cast(a1), static_cast(a2)); + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/deemph_biquad_op.h b/mindspore/ccsrc/minddata/dataset/audio/kernels/deemph_biquad_op.h index 1e830c02f27..0324a5ce104 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/kernels/deemph_biquad_op.h +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/deemph_biquad_op.h @@ -1,46 +1,46 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_DEEMPH_BIQUAD_OP_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_DEEMPH_BIQUAD_OP_H_ - -#include -#include -#include - -#include "minddata/dataset/core/tensor.h" -#include "minddata/dataset/kernels/tensor_op.h" -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class DeemphBiquadOp : public TensorOp { - public: - explicit DeemphBiquadOp(int32_t sample_rate) : sample_rate_(sample_rate) {} - - ~DeemphBiquadOp() override = default; - - void Print(std::ostream &out) const override { out << Name() << ": sample_rate: " << sample_rate_ << std::endl; } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - std::string Name() const override { return kDeemphBiquadOp; } - - private: - int32_t sample_rate_; -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_DEEMPH_BIQUAD_OP_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_DEEMPH_BIQUAD_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_DEEMPH_BIQUAD_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class DeemphBiquadOp : public TensorOp { + public: + explicit DeemphBiquadOp(int32_t sample_rate) : sample_rate_(sample_rate) {} + + ~DeemphBiquadOp() override = default; + + void Print(std::ostream &out) const override { out << Name() << ": sample_rate: " << sample_rate_ << std::endl; } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kDeemphBiquadOp; } + + private: + int32_t sample_rate_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_DEEMPH_BIQUAD_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/flanger_op.cc b/mindspore/ccsrc/minddata/dataset/audio/kernels/flanger_op.cc index 0a12ead6368..9f56b2be1ac 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/kernels/flanger_op.cc +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/flanger_op.cc @@ -1,55 +1,55 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "minddata/dataset/audio/kernels/flanger_op.h" - -#include "minddata/dataset/audio/kernels/audio_utils.h" -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { -Status FlangerOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - // check input dimensions, it should be 2 dimensions or more - RETURN_IF_NOT_OK(ValidateLowRank("Flanger", input, kDefaultAudioDim, "<..., channel, time>")); - - // check input channel, it should be less than or equal to 4 - const int32_t kChannelIndex = -2; - const int32_t kChannelLimit = 4; - CHECK_FAIL_RETURN_SYNTAX_ERROR(input->shape()[kChannelIndex] <= kChannelLimit, - "Flanger: the channel of input tensor dose not match the requirement of operator. " - "Expecting tensor with channel less than or equal to 4. But got channel: " + - std::to_string(input->shape()[kChannelIndex])); - - // check input type, it should be [int, float, double] - RETURN_IF_NOT_OK(ValidateTensorNumeric("Flanger", input)); - - if (input->type() == DataType(DataType::DE_FLOAT64)) { - return Flanger(input, output, sample_rate_, delay_, depth_, regen_, width_, speed_, phase_, Modulation_, - Interpolation_); - } else { - return Flanger(input, output, sample_rate_, delay_, depth_, regen_, width_, speed_, phase_, Modulation_, - Interpolation_); - } -} - -Status FlangerOp::OutputType(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); - RETURN_IF_NOT_OK(ValidateTensorType("Flanger", inputs[0].IsNumeric(), "[int, float, double]", inputs[0].ToString())); - outputs[0] = inputs[0]; - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/audio/kernels/flanger_op.h" + +#include "minddata/dataset/audio/kernels/audio_utils.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +Status FlangerOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + // check input dimensions, it should be 2 dimensions or more + RETURN_IF_NOT_OK(ValidateLowRank("Flanger", input, kDefaultAudioDim, "<..., channel, time>")); + + // check input channel, it should be less than or equal to 4 + const int32_t kChannelIndex = -2; + const int32_t kChannelLimit = 4; + CHECK_FAIL_RETURN_SYNTAX_ERROR(input->shape()[kChannelIndex] <= kChannelLimit, + "Flanger: the channel of input tensor dose not match the requirement of operator. " + "Expecting tensor with channel less than or equal to 4. But got channel: " + + std::to_string(input->shape()[kChannelIndex])); + + // check input type, it should be [int, float, double] + RETURN_IF_NOT_OK(ValidateTensorNumeric("Flanger", input)); + + if (input->type() == DataType(DataType::DE_FLOAT64)) { + return Flanger(input, output, sample_rate_, delay_, depth_, regen_, width_, speed_, phase_, Modulation_, + Interpolation_); + } else { + return Flanger(input, output, sample_rate_, delay_, depth_, regen_, width_, speed_, phase_, Modulation_, + Interpolation_); + } +} + +Status FlangerOp::OutputType(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); + RETURN_IF_NOT_OK(ValidateTensorType("Flanger", inputs[0].IsNumeric(), "[int, float, double]", inputs[0].ToString())); + outputs[0] = inputs[0]; + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/flanger_op.h b/mindspore/ccsrc/minddata/dataset/audio/kernels/flanger_op.h index cc03625e2f4..4d5dc51cf9a 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/kernels/flanger_op.h +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/flanger_op.h @@ -1,73 +1,73 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_FLANGER_OP_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_FLANGER_OP_H_ - -#include -#include -#include - -#include "minddata/dataset/core/tensor.h" -#include "minddata/dataset/include/dataset/constants.h" -#include "minddata/dataset/kernels/tensor_op.h" -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { - -class FlangerOp : public TensorOp { - public: - explicit FlangerOp(int32_t sample_rate, float delay, float depth, float regen, float width, float speed, float phase, - Modulation modulation, Interpolation interpolation) - : sample_rate_(sample_rate), - delay_(delay), - depth_(depth), - regen_(regen), - width_(width), - speed_(speed), - phase_(phase), - Modulation_(modulation), - Interpolation_(interpolation) {} - - ~FlangerOp() override = default; - - void Print(std::ostream &out) const override { - out << Name() << ": sample_rate: " << sample_rate_ << ", delay:" << delay_ << ", depth: " << depth_ - << ", regen: " << regen_ << ", width: " << width_ << ", speed: " << speed_ << ", phase: " << phase_ - << ", Modulation: " << static_cast(Modulation_) << ", Interpolation: " << static_cast(Interpolation_) - << std::endl; - } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - std::string Name() const override { return kFlangerOp; } - - Status OutputType(const std::vector &inputs, std::vector &outputs) override; - - private: - int32_t sample_rate_; - float delay_; - float depth_; - float regen_; - float width_; - float speed_; - float phase_; - Modulation Modulation_; - Interpolation Interpolation_; -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_FLANGER_OP_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_FLANGER_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_FLANGER_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/include/dataset/constants.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +class FlangerOp : public TensorOp { + public: + explicit FlangerOp(int32_t sample_rate, float delay, float depth, float regen, float width, float speed, float phase, + Modulation modulation, Interpolation interpolation) + : sample_rate_(sample_rate), + delay_(delay), + depth_(depth), + regen_(regen), + width_(width), + speed_(speed), + phase_(phase), + Modulation_(modulation), + Interpolation_(interpolation) {} + + ~FlangerOp() override = default; + + void Print(std::ostream &out) const override { + out << Name() << ": sample_rate: " << sample_rate_ << ", delay:" << delay_ << ", depth: " << depth_ + << ", regen: " << regen_ << ", width: " << width_ << ", speed: " << speed_ << ", phase: " << phase_ + << ", Modulation: " << static_cast(Modulation_) << ", Interpolation: " << static_cast(Interpolation_) + << std::endl; + } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kFlangerOp; } + + Status OutputType(const std::vector &inputs, std::vector &outputs) override; + + private: + int32_t sample_rate_; + float delay_; + float depth_; + float regen_; + float width_; + float speed_; + float phase_; + Modulation Modulation_; + Interpolation Interpolation_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_FLANGER_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/highpass_biquad_op.cc b/mindspore/ccsrc/minddata/dataset/audio/kernels/highpass_biquad_op.cc index 26f5ca3ca35..675f3e8b5d7 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/kernels/highpass_biquad_op.cc +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/highpass_biquad_op.cc @@ -1,50 +1,50 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/audio/kernels/highpass_biquad_op.h" - -#include "minddata/dataset/audio/kernels/audio_utils.h" - -namespace mindspore { -namespace dataset { -Status HighpassBiquadOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - // check input tensor dimension, it should be greater than 0. - RETURN_IF_NOT_OK(ValidateLowRank("HighpassBiquad", input, kMinAudioDim, "<..., time>")); - // check input tensor type, it should be DE_FLOAT32 or DE_FLOAT16 or DE_FLOAT64 - RETURN_IF_NOT_OK(ValidateTensorFloat("HighpassBiquad", input)); - double w0 = 2 * PI * cutoff_freq_ / sample_rate_; - double alpha = sin(w0) / 2 / Q_; - - double b0 = (1 + cos(w0)) / 2; - double b1 = -1 - cos(w0); - double b2 = b0; - double a0 = 1 + alpha; - double a1 = -2 * cos(w0); - double a2 = 1 - alpha; - if (input->type() == DataType(DataType::DE_FLOAT32)) { - return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), - static_cast(a0), static_cast(a1), static_cast(a2)); - } else if (input->type() == DataType(DataType::DE_FLOAT64)) { - return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), - static_cast(a0), static_cast(a1), static_cast(a2)); - } else { - return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), - static_cast(a0), static_cast(a1), static_cast(a2)); - } -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/audio/kernels/highpass_biquad_op.h" + +#include "minddata/dataset/audio/kernels/audio_utils.h" + +namespace mindspore { +namespace dataset { +Status HighpassBiquadOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + // check input tensor dimension, it should be greater than 0. + RETURN_IF_NOT_OK(ValidateLowRank("HighpassBiquad", input, kMinAudioDim, "<..., time>")); + // check input tensor type, it should be DE_FLOAT32 or DE_FLOAT16 or DE_FLOAT64 + RETURN_IF_NOT_OK(ValidateTensorFloat("HighpassBiquad", input)); + double w0 = 2 * PI * cutoff_freq_ / sample_rate_; + double alpha = sin(w0) / 2 / Q_; + + double b0 = (1 + cos(w0)) / 2; + double b1 = -1 - cos(w0); + double b2 = b0; + double a0 = 1 + alpha; + double a1 = -2 * cos(w0); + double a2 = 1 - alpha; + if (input->type() == DataType(DataType::DE_FLOAT32)) { + return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), + static_cast(a0), static_cast(a1), static_cast(a2)); + } else if (input->type() == DataType(DataType::DE_FLOAT64)) { + return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), + static_cast(a0), static_cast(a1), static_cast(a2)); + } else { + return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), + static_cast(a0), static_cast(a1), static_cast(a2)); + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/highpass_biquad_op.h b/mindspore/ccsrc/minddata/dataset/audio/kernels/highpass_biquad_op.h old mode 100755 new mode 100644 index 9c279245b15..cdeac961a75 --- a/mindspore/ccsrc/minddata/dataset/audio/kernels/highpass_biquad_op.h +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/highpass_biquad_op.h @@ -1,48 +1,48 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_HIGHPASS_BIQUAD_OP_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_HIGHPASS_BIQUAD_OP_H_ - -#include -#include - -#include "minddata/dataset/core/tensor.h" -#include "minddata/dataset/kernels/tensor_op.h" -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { - -class HighpassBiquadOp : public TensorOp { - public: - HighpassBiquadOp(int32_t sample_rate, float cutoff_freq, float Q) - : sample_rate_(sample_rate), cutoff_freq_(cutoff_freq), Q_(Q) {} - - ~HighpassBiquadOp() override = default; - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - std::string Name() const override { return kHighpassBiquadOp; }; - - protected: - int32_t sample_rate_; - float cutoff_freq_; - float Q_; -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_HIGHPASS_BIQUAD_OP_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_HIGHPASS_BIQUAD_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_HIGHPASS_BIQUAD_OP_H_ + +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +class HighpassBiquadOp : public TensorOp { + public: + HighpassBiquadOp(int32_t sample_rate, float cutoff_freq, float Q) + : sample_rate_(sample_rate), cutoff_freq_(cutoff_freq), Q_(Q) {} + + ~HighpassBiquadOp() override = default; + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kHighpassBiquadOp; }; + + protected: + int32_t sample_rate_; + float cutoff_freq_; + float Q_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_HIGHPASS_BIQUAD_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/lfilter_op.cc b/mindspore/ccsrc/minddata/dataset/audio/kernels/lfilter_op.cc index ada059f8892..b0162651698 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/kernels/lfilter_op.cc +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/lfilter_op.cc @@ -1,52 +1,52 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "minddata/dataset/audio/kernels/lfilter_op.h" - -#include "minddata/dataset/audio/kernels/audio_utils.h" -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { -Status LFilterOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - RETURN_IF_NOT_OK(ValidateLowRank("LFilter", input, kMinAudioDim, "<..., time>")); - RETURN_IF_NOT_OK(ValidateTensorFloat("LFilter", input)); - if (input->type() == DataType(DataType::DE_FLOAT32)) { - return LFilter(input, output, a_coeffs_, b_coeffs_, clamp_); - } else if (input->type() == DataType(DataType::DE_FLOAT64)) { - std::vector a_coeffs_double; - std::vector b_coeffs_double; - for (auto i = 0; i < a_coeffs_.size(); i++) { - a_coeffs_double.push_back(static_cast(a_coeffs_[i])); - } - for (auto i = 0; i < b_coeffs_.size(); i++) { - b_coeffs_double.push_back(static_cast(b_coeffs_[i])); - } - return LFilter(input, output, a_coeffs_double, b_coeffs_double, clamp_); - } else { - std::vector a_coeffs_float16; - std::vector b_coeffs_float16; - for (auto i = 0; i < a_coeffs_.size(); i++) { - a_coeffs_float16.push_back(static_cast(a_coeffs_[i])); - } - for (auto i = 0; i < b_coeffs_.size(); i++) { - b_coeffs_float16.push_back(static_cast(b_coeffs_[i])); - } - return LFilter(input, output, a_coeffs_float16, b_coeffs_float16, clamp_); - } -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/audio/kernels/lfilter_op.h" + +#include "minddata/dataset/audio/kernels/audio_utils.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +Status LFilterOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + RETURN_IF_NOT_OK(ValidateLowRank("LFilter", input, kMinAudioDim, "<..., time>")); + RETURN_IF_NOT_OK(ValidateTensorFloat("LFilter", input)); + if (input->type() == DataType(DataType::DE_FLOAT32)) { + return LFilter(input, output, a_coeffs_, b_coeffs_, clamp_); + } else if (input->type() == DataType(DataType::DE_FLOAT64)) { + std::vector a_coeffs_double; + std::vector b_coeffs_double; + for (auto i = 0; i < a_coeffs_.size(); i++) { + a_coeffs_double.push_back(static_cast(a_coeffs_[i])); + } + for (auto i = 0; i < b_coeffs_.size(); i++) { + b_coeffs_double.push_back(static_cast(b_coeffs_[i])); + } + return LFilter(input, output, a_coeffs_double, b_coeffs_double, clamp_); + } else { + std::vector a_coeffs_float16; + std::vector b_coeffs_float16; + for (auto i = 0; i < a_coeffs_.size(); i++) { + a_coeffs_float16.push_back(static_cast(a_coeffs_[i])); + } + for (auto i = 0; i < b_coeffs_.size(); i++) { + b_coeffs_float16.push_back(static_cast(b_coeffs_[i])); + } + return LFilter(input, output, a_coeffs_float16, b_coeffs_float16, clamp_); + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/lfilter_op.h b/mindspore/ccsrc/minddata/dataset/audio/kernels/lfilter_op.h index 16c21dc3f1c..884b24943b4 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/kernels/lfilter_op.h +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/lfilter_op.h @@ -1,60 +1,60 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_LFILTER_OP_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_LFILTER_OP_H_ - -#include -#include -#include - -#include "minddata/dataset/core/tensor.h" -#include "minddata/dataset/kernels/tensor_op.h" -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { - -class LFilterOp : public TensorOp { - public: - LFilterOp(const std::vector &a_coeffs, const std::vector &b_coeffs, bool clamp) - : a_coeffs_(a_coeffs), b_coeffs_(b_coeffs), clamp_(clamp) {} - - ~LFilterOp() override = default; - - void Print(std::ostream &out) const override { - out << Name() << ": a_coeffs: "; - for (auto i = 0; i < a_coeffs_.size(); i++) { - out << a_coeffs_[i] << " "; - } - out << "b_coeffs: "; - for (auto i = 0; i < b_coeffs_.size(); i++) { - out << b_coeffs_[i] << " "; - } - out << "clamp: " << clamp_ << std::endl; - } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - std::string Name() const override { return kLFilterOp; } - - private: - std::vector a_coeffs_; - std::vector b_coeffs_; - bool clamp_; -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_LFILTER_OP_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_LFILTER_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_LFILTER_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +class LFilterOp : public TensorOp { + public: + LFilterOp(const std::vector &a_coeffs, const std::vector &b_coeffs, bool clamp) + : a_coeffs_(a_coeffs), b_coeffs_(b_coeffs), clamp_(clamp) {} + + ~LFilterOp() override = default; + + void Print(std::ostream &out) const override { + out << Name() << ": a_coeffs: "; + for (auto i = 0; i < a_coeffs_.size(); i++) { + out << a_coeffs_[i] << " "; + } + out << "b_coeffs: "; + for (auto i = 0; i < b_coeffs_.size(); i++) { + out << b_coeffs_[i] << " "; + } + out << "clamp: " << clamp_ << std::endl; + } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kLFilterOp; } + + private: + std::vector a_coeffs_; + std::vector b_coeffs_; + bool clamp_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_LFILTER_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/lowpass_biquad_op.cc b/mindspore/ccsrc/minddata/dataset/audio/kernels/lowpass_biquad_op.cc index 27e7824b5c0..bbe868ac04d 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/kernels/lowpass_biquad_op.cc +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/lowpass_biquad_op.cc @@ -1,54 +1,54 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/audio/kernels/lowpass_biquad_op.h" - -#include "minddata/dataset/audio/kernels/audio_utils.h" -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { -const float LowpassBiquadOp::kQ = 0.707; -// constructor - -Status LowpassBiquadOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - // check input tensor dimension, it should be greater than 0. - RETURN_IF_NOT_OK(ValidateLowRank("LowpassBiquad", input, kMinAudioDim, "<..., time>")); - // check input type, it should be DE_FLOAT32 or DE_FLOAT16 or DE_FLOAT64 - RETURN_IF_NOT_OK(ValidateTensorFloat("LowpassBiquad", input)); - double w0 = 2 * PI * cutoff_freq_ / sample_rate_; - double alpha = sin(w0) / 2 / Q_; - - double b0 = (1 - cos(w0)) / 2; - double b1 = 1 - cos(w0); - double b2 = b0; - double a0 = 1 + alpha; - double a1 = -2 * cos(w0); - double a2 = 1 - alpha; - if (input->type() == DataType(DataType::DE_FLOAT32)) { - return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), - static_cast(a0), static_cast(a1), static_cast(a2)); - } else if (input->type() == DataType(DataType::DE_FLOAT64)) { - return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), - static_cast(a0), static_cast(a1), static_cast(a2)); - } else { - return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), - static_cast(a0), static_cast(a1), static_cast(a2)); - } -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/audio/kernels/lowpass_biquad_op.h" + +#include "minddata/dataset/audio/kernels/audio_utils.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +const float LowpassBiquadOp::kQ = 0.707; +// constructor + +Status LowpassBiquadOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + // check input tensor dimension, it should be greater than 0. + RETURN_IF_NOT_OK(ValidateLowRank("LowpassBiquad", input, kMinAudioDim, "<..., time>")); + // check input type, it should be DE_FLOAT32 or DE_FLOAT16 or DE_FLOAT64 + RETURN_IF_NOT_OK(ValidateTensorFloat("LowpassBiquad", input)); + double w0 = 2 * PI * cutoff_freq_ / sample_rate_; + double alpha = sin(w0) / 2 / Q_; + + double b0 = (1 - cos(w0)) / 2; + double b1 = 1 - cos(w0); + double b2 = b0; + double a0 = 1 + alpha; + double a1 = -2 * cos(w0); + double a2 = 1 - alpha; + if (input->type() == DataType(DataType::DE_FLOAT32)) { + return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), + static_cast(a0), static_cast(a1), static_cast(a2)); + } else if (input->type() == DataType(DataType::DE_FLOAT64)) { + return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), + static_cast(a0), static_cast(a1), static_cast(a2)); + } else { + return Biquad(input, output, static_cast(b0), static_cast(b1), static_cast(b2), + static_cast(a0), static_cast(a1), static_cast(a2)); + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/lowpass_biquad_op.h b/mindspore/ccsrc/minddata/dataset/audio/kernels/lowpass_biquad_op.h index 951d5f9e70b..3b3cad04d12 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/kernels/lowpass_biquad_op.h +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/lowpass_biquad_op.h @@ -1,56 +1,56 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_LOWPASS_BIQUAD_OP_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_LOWPASS_BIQUAD_OP_H_ - -#include -#include -#include -#include - -#include "minddata/dataset/core/tensor.h" -#include "minddata/dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { -class LowpassBiquadOp : public TensorOp { - public: - /// default values; - static const float kQ; - - LowpassBiquadOp(int32_t sample_rate, float cutoff_freq, float Q) - : sample_rate_(sample_rate), cutoff_freq_(cutoff_freq), Q_(Q) {} - - ~LowpassBiquadOp() override = default; - - void Print(std::ostream &out) const override { - out << Name() << ": sample_rate: " << sample_rate_ << ", cutoff_freq: " << cutoff_freq_ << ", Q: " << Q_ - << std::endl; - } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - std::string Name() const override { return kLowpassBiquadOp; } - - private: - int32_t sample_rate_; - float cutoff_freq_; - float Q_; -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_LOWPASS_BIQUAD_OP_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_LOWPASS_BIQUAD_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_LOWPASS_BIQUAD_OP_H_ + +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { +class LowpassBiquadOp : public TensorOp { + public: + /// default values; + static const float kQ; + + LowpassBiquadOp(int32_t sample_rate, float cutoff_freq, float Q) + : sample_rate_(sample_rate), cutoff_freq_(cutoff_freq), Q_(Q) {} + + ~LowpassBiquadOp() override = default; + + void Print(std::ostream &out) const override { + out << Name() << ": sample_rate: " << sample_rate_ << ", cutoff_freq: " << cutoff_freq_ << ", Q: " << Q_ + << std::endl; + } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kLowpassBiquadOp; } + + private: + int32_t sample_rate_; + float cutoff_freq_; + float Q_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_LOWPASS_BIQUAD_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/magphase_op.cc b/mindspore/ccsrc/minddata/dataset/audio/kernels/magphase_op.cc index 7cc8a45fa0c..fa4147da9bd 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/kernels/magphase_op.cc +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/magphase_op.cc @@ -1,57 +1,57 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "minddata/dataset/audio/kernels/magphase_op.h" - -#include "minddata/dataset/audio/kernels/audio_utils.h" -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { -constexpr float MagphaseOp::kPower = 1.0; - -Status MagphaseOp::Compute(const TensorRow &input, TensorRow *output) { - IO_CHECK_VECTOR(input, output); - RETURN_IF_NOT_OK(ValidateTensorShape("Magphase", input[0]->IsComplex(), "<..., complex=2>")); - RETURN_IF_NOT_OK(ValidateTensorNumeric("Magphase", input[0])); - RETURN_IF_NOT_OK(Magphase(input, output, power_)); - return Status::OK(); -} - -Status MagphaseOp::OutputShape(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); - outputs.clear(); - auto vec = inputs[0].AsVector(); - vec.pop_back(); - auto out = TensorShape(vec); - outputs = {out, out}; - if (!outputs.empty()) { - return Status::OK(); - } - return Status(StatusCode::kMDUnexpectedError, "Magphase: invalid shape of input tensor."); -} - -Status MagphaseOp::OutputType(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); - RETURN_IF_NOT_OK(ValidateTensorType("Magphase", inputs[0].IsNumeric(), "[int, float, double]", inputs[0].ToString())); - if (inputs[0] == DataType(DataType::DE_FLOAT64)) { - outputs[0] = DataType(DataType::DE_FLOAT64); - } else { - outputs[0] = DataType(DataType::DE_FLOAT32); - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/audio/kernels/magphase_op.h" + +#include "minddata/dataset/audio/kernels/audio_utils.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +constexpr float MagphaseOp::kPower = 1.0; + +Status MagphaseOp::Compute(const TensorRow &input, TensorRow *output) { + IO_CHECK_VECTOR(input, output); + RETURN_IF_NOT_OK(ValidateTensorShape("Magphase", input[0]->IsComplex(), "<..., complex=2>")); + RETURN_IF_NOT_OK(ValidateTensorNumeric("Magphase", input[0])); + RETURN_IF_NOT_OK(Magphase(input, output, power_)); + return Status::OK(); +} + +Status MagphaseOp::OutputShape(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); + outputs.clear(); + auto vec = inputs[0].AsVector(); + vec.pop_back(); + auto out = TensorShape(vec); + outputs = {out, out}; + if (!outputs.empty()) { + return Status::OK(); + } + return Status(StatusCode::kMDUnexpectedError, "Magphase: invalid shape of input tensor."); +} + +Status MagphaseOp::OutputType(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); + RETURN_IF_NOT_OK(ValidateTensorType("Magphase", inputs[0].IsNumeric(), "[int, float, double]", inputs[0].ToString())); + if (inputs[0] == DataType(DataType::DE_FLOAT64)) { + outputs[0] = DataType(DataType::DE_FLOAT64); + } else { + outputs[0] = DataType(DataType::DE_FLOAT32); + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/magphase_op.h b/mindspore/ccsrc/minddata/dataset/audio/kernels/magphase_op.h index bf4e1a86c1c..4e8c179a647 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/kernels/magphase_op.h +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/magphase_op.h @@ -1,52 +1,52 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_MAGPHASE_OP_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_MAGPHASE_OP_H_ - -#include -#include -#include - -#include "minddata/dataset/core/tensor.h" -#include "minddata/dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { - -class MagphaseOp : public TensorOp { - public: - static const float kPower; - - explicit MagphaseOp(float power = kPower) : power_(power) {} - - ~MagphaseOp() override = default; - - Status Compute(const TensorRow &input, TensorRow *output) override; - - Status OutputShape(const std::vector &inputs, std::vector &outputs) override; - - Status OutputType(const std::vector &inputs, std::vector &outputs) override; - - std::string Name() const override { return kMagphaseOp; } - - private: - float power_; -}; - -} // namespace dataset -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_MAGPHASE_OP_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_MAGPHASE_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_MAGPHASE_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { + +class MagphaseOp : public TensorOp { + public: + static const float kPower; + + explicit MagphaseOp(float power = kPower) : power_(power) {} + + ~MagphaseOp() override = default; + + Status Compute(const TensorRow &input, TensorRow *output) override; + + Status OutputShape(const std::vector &inputs, std::vector &outputs) override; + + Status OutputType(const std::vector &inputs, std::vector &outputs) override; + + std::string Name() const override { return kMagphaseOp; } + + private: + float power_; +}; + +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_MAGPHASE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/mask_along_axis_iid_op.cc b/mindspore/ccsrc/minddata/dataset/audio/kernels/mask_along_axis_iid_op.cc index 6eeccb44aeb..758cd260e6a 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/kernels/mask_along_axis_iid_op.cc +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/mask_along_axis_iid_op.cc @@ -1,71 +1,71 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "minddata/dataset/audio/kernels/mask_along_axis_iid_op.h" - -#include "minddata/dataset/audio/kernels/audio_utils.h" -#include "minddata/dataset/kernels/data/data_utils.h" -#include "minddata/dataset/util/random.h" - -namespace mindspore { -namespace dataset { -const int32_t kFrequencyAxis = 1; -const int32_t kTimeAxis = 2; -const int32_t kTensorFreqiencyPos = -2; -const int32_t kTensorTimePos = -1; - -Status MaskAlongAxisIIDOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - - RETURN_IF_NOT_OK(ValidateLowRank("MaskAlongAxisIID", input, kDefaultAudioDim, "<..., freq, time>")); - RETURN_IF_NOT_OK(ValidateTensorType("MaskAlongAxisIID", input->type().IsNumeric(), "[int, float, double]", - input->type().ToString())); - TensorShape input_shape = input->shape(); - - if (axis_ == kFrequencyAxis) { - CHECK_FAIL_RETURN_UNEXPECTED( - input_shape[kTensorFreqiencyPos] >= mask_param_, - "MaskAlongAxisIID: mask_param should be less than or equal to the length of frequency dimension."); - } else if (axis_ == kTimeAxis) { - CHECK_FAIL_RETURN_UNEXPECTED( - input_shape[kTensorTimePos] >= mask_param_, - "MaskAlongAxisIID: mask_param should be less than or equal to the length of time dimension."); - } else { - RETURN_STATUS_UNEXPECTED("MaskAlongAxisIID: only support Frequency and Time masking, axis should be 1 or 2."); - } - - std::shared_ptr input_tensor; - if (input->type() != DataType::DE_FLOAT64) { - RETURN_IF_NOT_OK(TypeCast(input, &input_tensor, DataType(DataType::DE_FLOAT32))); - } else { - input_tensor = input; - } - return RandomMaskAlongAxis(input_tensor, output, mask_param_, mask_value_, axis_, &random_generator_); -} - -Status MaskAlongAxisIIDOp::OutputType(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); - RETURN_IF_NOT_OK( - ValidateTensorType("MaskAlongAxisIID", inputs[0].IsNumeric(), "[int, float, double]", inputs[0].ToString())); - - if (inputs[0] == DataType(DataType::DE_FLOAT64)) { - outputs[0] = DataType(DataType::DE_FLOAT64); - } else { - outputs[0] = DataType(DataType::DE_FLOAT32); - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/audio/kernels/mask_along_axis_iid_op.h" + +#include "minddata/dataset/audio/kernels/audio_utils.h" +#include "minddata/dataset/kernels/data/data_utils.h" +#include "minddata/dataset/util/random.h" + +namespace mindspore { +namespace dataset { +const int32_t kFrequencyAxis = 1; +const int32_t kTimeAxis = 2; +const int32_t kTensorFreqiencyPos = -2; +const int32_t kTensorTimePos = -1; + +Status MaskAlongAxisIIDOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + + RETURN_IF_NOT_OK(ValidateLowRank("MaskAlongAxisIID", input, kDefaultAudioDim, "<..., freq, time>")); + RETURN_IF_NOT_OK(ValidateTensorType("MaskAlongAxisIID", input->type().IsNumeric(), "[int, float, double]", + input->type().ToString())); + TensorShape input_shape = input->shape(); + + if (axis_ == kFrequencyAxis) { + CHECK_FAIL_RETURN_UNEXPECTED( + input_shape[kTensorFreqiencyPos] >= mask_param_, + "MaskAlongAxisIID: mask_param should be less than or equal to the length of frequency dimension."); + } else if (axis_ == kTimeAxis) { + CHECK_FAIL_RETURN_UNEXPECTED( + input_shape[kTensorTimePos] >= mask_param_, + "MaskAlongAxisIID: mask_param should be less than or equal to the length of time dimension."); + } else { + RETURN_STATUS_UNEXPECTED("MaskAlongAxisIID: only support Frequency and Time masking, axis should be 1 or 2."); + } + + std::shared_ptr input_tensor; + if (input->type() != DataType::DE_FLOAT64) { + RETURN_IF_NOT_OK(TypeCast(input, &input_tensor, DataType(DataType::DE_FLOAT32))); + } else { + input_tensor = input; + } + return RandomMaskAlongAxis(input_tensor, output, mask_param_, mask_value_, axis_, &random_generator_); +} + +Status MaskAlongAxisIIDOp::OutputType(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); + RETURN_IF_NOT_OK( + ValidateTensorType("MaskAlongAxisIID", inputs[0].IsNumeric(), "[int, float, double]", inputs[0].ToString())); + + if (inputs[0] == DataType(DataType::DE_FLOAT64)) { + outputs[0] = DataType(DataType::DE_FLOAT64); + } else { + outputs[0] = DataType(DataType::DE_FLOAT32); + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/mask_along_axis_iid_op.h b/mindspore/ccsrc/minddata/dataset/audio/kernels/mask_along_axis_iid_op.h index 3f1dcbdea37..082ca2bbf66 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/kernels/mask_along_axis_iid_op.h +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/mask_along_axis_iid_op.h @@ -1,56 +1,56 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_MASK_ALONG_AXIS_IID_OP_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_MASK_ALONG_AXIS_IID_OP_H_ - -#include -#include -#include -#include - -#include "minddata/dataset/core/tensor.h" -#include "minddata/dataset/kernels/tensor_op.h" -#include "minddata/dataset/util/random.h" -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class MaskAlongAxisIIDOp : public RandomTensorOp { - public: - /// \brief Constructor. - /// \param[in] mask_param Number of columns to be masked, will be uniformly sampled from [0, mask_param], - /// must be non negative. - /// \param[in] mask_value Value to assign to the masked columns. - /// \param[in] axis Axis to apply masking on (1 for frequency and 2 for time). - MaskAlongAxisIIDOp(int32_t mask_param, float mask_value, int32_t axis) - : mask_param_(mask_param), mask_value_(mask_value), axis_(axis) {} - - ~MaskAlongAxisIIDOp() override = default; - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - Status OutputType(const std::vector &inputs, std::vector &outputs) override; - - std::string Name() const override { return kMaskAlongAxisIIDOp; } - - private: - int32_t mask_param_; - float mask_value_; - int32_t axis_; -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_MASK_ALONG_AXIS_IID_OP_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_MASK_ALONG_AXIS_IID_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_MASK_ALONG_AXIS_IID_OP_H_ + +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/random.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class MaskAlongAxisIIDOp : public RandomTensorOp { + public: + /// \brief Constructor. + /// \param[in] mask_param Number of columns to be masked, will be uniformly sampled from [0, mask_param], + /// must be non negative. + /// \param[in] mask_value Value to assign to the masked columns. + /// \param[in] axis Axis to apply masking on (1 for frequency and 2 for time). + MaskAlongAxisIIDOp(int32_t mask_param, float mask_value, int32_t axis) + : mask_param_(mask_param), mask_value_(mask_value), axis_(axis) {} + + ~MaskAlongAxisIIDOp() override = default; + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + Status OutputType(const std::vector &inputs, std::vector &outputs) override; + + std::string Name() const override { return kMaskAlongAxisIIDOp; } + + private: + int32_t mask_param_; + float mask_value_; + int32_t axis_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_MASK_ALONG_AXIS_IID_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/mask_along_axis_op.cc b/mindspore/ccsrc/minddata/dataset/audio/kernels/mask_along_axis_op.cc index 1bbb9561d03..ddb9a15add4 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/kernels/mask_along_axis_op.cc +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/mask_along_axis_op.cc @@ -1,51 +1,51 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/audio/kernels/mask_along_axis_op.h" - -#include "minddata/dataset/audio/kernels/audio_utils.h" -#include "minddata/dataset/kernels/data/data_utils.h" - -namespace mindspore { -namespace dataset { -Status MaskAlongAxisOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - // input <..., freq, time> - RETURN_IF_NOT_OK(ValidateLowRank("MaskAlongAxis", input, kDefaultAudioDim, "<..., freq, time>")); - RETURN_IF_NOT_OK( - ValidateTensorType("MaskAlongAxis", input->type().IsNumeric(), "[int, float, double]", input->type().ToString())); - std::shared_ptr input_tensor; - if (input->type() != DataType::DE_FLOAT64) { - RETURN_IF_NOT_OK(TypeCast(input, &input_tensor, DataType(DataType::DE_FLOAT32))); - } else { - input_tensor = input; - } - return MaskAlongAxis(input_tensor, output, mask_width_, mask_start_, mask_value_, axis_); -} - -Status MaskAlongAxisOp::OutputType(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); - RETURN_IF_NOT_OK( - ValidateTensorType("MaskAlongAxis", inputs[0].IsNumeric(), "[int, float, double]", inputs[0].ToString())); - if (inputs[0] == DataType(DataType::DE_FLOAT64)) { - outputs[0] = DataType(DataType::DE_FLOAT64); - } else { - outputs[0] = DataType(DataType::DE_FLOAT32); - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/audio/kernels/mask_along_axis_op.h" + +#include "minddata/dataset/audio/kernels/audio_utils.h" +#include "minddata/dataset/kernels/data/data_utils.h" + +namespace mindspore { +namespace dataset { +Status MaskAlongAxisOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + // input <..., freq, time> + RETURN_IF_NOT_OK(ValidateLowRank("MaskAlongAxis", input, kDefaultAudioDim, "<..., freq, time>")); + RETURN_IF_NOT_OK( + ValidateTensorType("MaskAlongAxis", input->type().IsNumeric(), "[int, float, double]", input->type().ToString())); + std::shared_ptr input_tensor; + if (input->type() != DataType::DE_FLOAT64) { + RETURN_IF_NOT_OK(TypeCast(input, &input_tensor, DataType(DataType::DE_FLOAT32))); + } else { + input_tensor = input; + } + return MaskAlongAxis(input_tensor, output, mask_width_, mask_start_, mask_value_, axis_); +} + +Status MaskAlongAxisOp::OutputType(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); + RETURN_IF_NOT_OK( + ValidateTensorType("MaskAlongAxis", inputs[0].IsNumeric(), "[int, float, double]", inputs[0].ToString())); + if (inputs[0] == DataType(DataType::DE_FLOAT64)) { + outputs[0] = DataType(DataType::DE_FLOAT64); + } else { + outputs[0] = DataType(DataType::DE_FLOAT32); + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/mask_along_axis_op.h b/mindspore/ccsrc/minddata/dataset/audio/kernels/mask_along_axis_op.h index f8910a0a343..9a9e5020039 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/kernels/mask_along_axis_op.h +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/mask_along_axis_op.h @@ -1,55 +1,55 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_MASK_ALONG_AXIS_OP_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_MASK_ALONG_AXIS_OP_H_ - -#include -#include -#include - -#include "minddata/dataset/core/tensor.h" -#include "minddata/dataset/kernels/tensor_op.h" -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class MaskAlongAxisOp : public TensorOp { - public: - /// \brief Constructor. - /// \param[in] mask_start Starting position of the mask. - /// \param[in] mask_width The width of the mask. - /// \param[in] mask_value Value to assign to the masked columns. - /// \param[in] axis Axis to apply masking on (1 for frequency and 2 for time). - MaskAlongAxisOp(int32_t mask_start, int32_t mask_width, float mask_value, int32_t axis) - : mask_start_(mask_start), mask_width_(mask_width), mask_value_(mask_value), axis_(axis) {} - - ~MaskAlongAxisOp() override = default; - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - Status OutputType(const std::vector &inputs, std::vector &outputs) override; - - std::string Name() const override { return kMaskAlongAxisOp; } - - private: - int32_t mask_start_; - int32_t mask_width_; - float mask_value_; - int32_t axis_; -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_MASK_ALONG_AXIS_OP_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_MASK_ALONG_AXIS_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_MASK_ALONG_AXIS_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class MaskAlongAxisOp : public TensorOp { + public: + /// \brief Constructor. + /// \param[in] mask_start Starting position of the mask. + /// \param[in] mask_width The width of the mask. + /// \param[in] mask_value Value to assign to the masked columns. + /// \param[in] axis Axis to apply masking on (1 for frequency and 2 for time). + MaskAlongAxisOp(int32_t mask_start, int32_t mask_width, float mask_value, int32_t axis) + : mask_start_(mask_start), mask_width_(mask_width), mask_value_(mask_value), axis_(axis) {} + + ~MaskAlongAxisOp() override = default; + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + Status OutputType(const std::vector &inputs, std::vector &outputs) override; + + std::string Name() const override { return kMaskAlongAxisOp; } + + private: + int32_t mask_start_; + int32_t mask_width_; + float mask_value_; + int32_t axis_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_MASK_ALONG_AXIS_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/phase_vocoder_op.cc b/mindspore/ccsrc/minddata/dataset/audio/kernels/phase_vocoder_op.cc index baee1b3f493..5c5acaa786f 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/kernels/phase_vocoder_op.cc +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/phase_vocoder_op.cc @@ -1,57 +1,57 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "minddata/dataset/audio/kernels/phase_vocoder_op.h" - -#include "minddata/dataset/audio/kernels/audio_utils.h" -#include "minddata/dataset/kernels/data/data_utils.h" - -namespace mindspore { -namespace dataset { -Status PhaseVocoderOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - return PhaseVocoder(input, output, rate_, phase_advance_); -} - -Status PhaseVocoderOp::OutputShape(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); - outputs.clear(); - const int32_t kTimePos = -2; - const int32_t kComplexDimSize = 2; - for (auto s : inputs) { - std::vector s_vec = s.AsVector(); - s_vec.pop_back(); - s_vec.pop_back(); - s_vec.push_back(std::ceil(s[kTimePos] / rate_)); - s_vec.push_back(kComplexDimSize); - outputs.emplace_back(TensorShape(s_vec)); - } - CHECK_FAIL_RETURN_UNEXPECTED(!outputs.empty(), "PhaseVocoder: invalid shape of input tensor."); - return Status::OK(); -} - -Status PhaseVocoderOp::OutputType(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); - RETURN_IF_NOT_OK( - ValidateTensorType("PhaseVocoder", inputs[0].IsNumeric(), "[int, float, double]", inputs[0].ToString())); - if (inputs[0] == DataType(DataType::DE_FLOAT64)) { - outputs[0] = DataType(DataType::DE_FLOAT64); - } else { - outputs[0] = DataType(DataType::DE_FLOAT32); - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/audio/kernels/phase_vocoder_op.h" + +#include "minddata/dataset/audio/kernels/audio_utils.h" +#include "minddata/dataset/kernels/data/data_utils.h" + +namespace mindspore { +namespace dataset { +Status PhaseVocoderOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + return PhaseVocoder(input, output, rate_, phase_advance_); +} + +Status PhaseVocoderOp::OutputShape(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); + outputs.clear(); + const int32_t kTimePos = -2; + const int32_t kComplexDimSize = 2; + for (auto s : inputs) { + std::vector s_vec = s.AsVector(); + s_vec.pop_back(); + s_vec.pop_back(); + s_vec.push_back(std::ceil(s[kTimePos] / rate_)); + s_vec.push_back(kComplexDimSize); + outputs.emplace_back(TensorShape(s_vec)); + } + CHECK_FAIL_RETURN_UNEXPECTED(!outputs.empty(), "PhaseVocoder: invalid shape of input tensor."); + return Status::OK(); +} + +Status PhaseVocoderOp::OutputType(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); + RETURN_IF_NOT_OK( + ValidateTensorType("PhaseVocoder", inputs[0].IsNumeric(), "[int, float, double]", inputs[0].ToString())); + if (inputs[0] == DataType(DataType::DE_FLOAT64)) { + outputs[0] = DataType(DataType::DE_FLOAT64); + } else { + outputs[0] = DataType(DataType::DE_FLOAT32); + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/phase_vocoder_op.h b/mindspore/ccsrc/minddata/dataset/audio/kernels/phase_vocoder_op.h index bb427e16ba4..bb745a1c2b4 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/kernels/phase_vocoder_op.h +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/phase_vocoder_op.h @@ -1,53 +1,53 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_PHASE_VOCODER_OP_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_PHASE_VOCODER_OP_H_ - -#include -#include -#include - -#include "minddata/dataset/core/tensor.h" -#include "minddata/dataset/kernels/tensor_op.h" -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class PhaseVocoderOp : public TensorOp { - public: - /// \brief Constructor. - /// \param[in] rate Speed-up factor. - /// \param[in] phase_advance Expected phase advance in each bin in shape of (freq, 1). - PhaseVocoderOp(float rate, const std::shared_ptr &phase_advance) - : rate_(rate), phase_advance_(phase_advance) {} - - ~PhaseVocoderOp() override = default; - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - std::string Name() const override { return kPhaseVocoderOp; } - - Status OutputShape(const std::vector &inputs, std::vector &outputs) override; - - Status OutputType(const std::vector &inputs, std::vector &outputs) override; - - private: - float rate_; - std::shared_ptr phase_advance_; -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_PHASE_VOCODER_OP_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_PHASE_VOCODER_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_PHASE_VOCODER_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class PhaseVocoderOp : public TensorOp { + public: + /// \brief Constructor. + /// \param[in] rate Speed-up factor. + /// \param[in] phase_advance Expected phase advance in each bin in shape of (freq, 1). + PhaseVocoderOp(float rate, const std::shared_ptr &phase_advance) + : rate_(rate), phase_advance_(phase_advance) {} + + ~PhaseVocoderOp() override = default; + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kPhaseVocoderOp; } + + Status OutputShape(const std::vector &inputs, std::vector &outputs) override; + + Status OutputType(const std::vector &inputs, std::vector &outputs) override; + + private: + float rate_; + std::shared_ptr phase_advance_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_PHASE_VOCODER_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/ag_news_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/ag_news_op.cc index 4e22d919c85..ca2f2340f72 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/ag_news_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/ag_news_op.cc @@ -1,55 +1,55 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "minddata/dataset/engine/datasetops/source/ag_news_op.h" - -#include - -#include "minddata/dataset/core/config_manager.h" -#include "minddata/dataset/engine/datasetops/source/csv_op.h" -#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" -#include "minddata/dataset/engine/execution_tree.h" -#include "minddata/dataset/engine/jagged_connector.h" - -namespace mindspore { -namespace dataset { -AGNewsOp::AGNewsOp(int32_t num_workers, int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, - bool shuffle_files, int32_t num_devices, int32_t device_id, char field_delim, - const std::vector> &column_default, - const std::vector &column_name, const std::vector &ag_news_list) - : CsvOp(ag_news_list, field_delim, column_default, column_name, num_workers, num_samples, worker_connector_size, - op_connector_size, shuffle_files, num_devices, device_id) {} - -// A print method typically used for debugging. -void AGNewsOp::Print(std::ostream &out, bool show_all) const { - if (!show_all) { - // Call the super class for displaying any common 1-liner info. - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op. - out << "\n"; - } else { - // Call the super class for displaying any common detailed info. - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nSample count: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_ - << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nAGNews files list:\n"; - for (int i = 0; i < csv_files_list_.size(); ++i) { - out << " " << csv_files_list_[i]; - } - out << "\n\n"; - } -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/source/ag_news_op.h" + +#include + +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/engine/datasetops/source/csv_op.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/jagged_connector.h" + +namespace mindspore { +namespace dataset { +AGNewsOp::AGNewsOp(int32_t num_workers, int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, + bool shuffle_files, int32_t num_devices, int32_t device_id, char field_delim, + const std::vector> &column_default, + const std::vector &column_name, const std::vector &ag_news_list) + : CsvOp(ag_news_list, field_delim, column_default, column_name, num_workers, num_samples, worker_connector_size, + op_connector_size, shuffle_files, num_devices, device_id) {} + +// A print method typically used for debugging. +void AGNewsOp::Print(std::ostream &out, bool show_all) const { + if (!show_all) { + // Call the super class for displaying any common 1-liner info. + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op. + out << "\n"; + } else { + // Call the super class for displaying any common detailed info. + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nSample count: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_ + << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nAGNews files list:\n"; + for (int i = 0; i < csv_files_list_.size(); ++i) { + out << " " << csv_files_list_[i]; + } + out << "\n\n"; + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/ag_news_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/ag_news_op.h index fc7e1532d0d..b2fd3c36a21 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/ag_news_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/ag_news_op.h @@ -1,77 +1,77 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_AG_NEWS_OP_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_AG_NEWS_OP_H_ - -#include -#include -#include - -#include "minddata/dataset/engine/datasetops/parallel_op.h" -#include "minddata/dataset/engine/datasetops/source/csv_op.h" -#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" -#include "minddata/dataset/engine/ir/cache/dataset_cache.h" -#include "minddata/dataset/engine/jagged_connector.h" -#include "minddata/dataset/util/auto_index.h" - -namespace mindspore { -namespace dataset { -class JaggedConnector; - -class AGNewsOp : public CsvOp { - public: - /// \brief Constructor. - /// \param[in] num_workers Number of workers reading images in parallel - /// \param[in] num_samples The number of samples to be included in the dataset. - /// (Default = 0 means all samples). - /// \param[in] worker_connector_size Size of each internal queue. - /// \param[in] op_connector_size Size of each queue in the connector that the child operator pulls from. - /// \param[in] shuffle_files Whether or not to shuffle the files before reading data. - /// \param[in] num_devices Number of devices that the dataset should be divided into. (Default = 1) - /// \param[in] device_id The device ID within num_devices. This argument should be - /// specified only when num_devices is also specified (Default = 0). - /// \param[in] field_delim A char that indicates the delimiter to separate fields (default=','). - /// \param[in] column_default List of default values for the CSV field (default={}). Each item in the list is - /// either a valid type (float, int, or string). If this is not provided, treats all columns as string type. - /// \param[in] column_name List of column names of the dataset (default={}). If this is not provided, infers the - /// column_names from the first row of CSV file. - /// \param[in] ag_news_list List of files to be read to search for a pattern of files. The list - /// will be sorted in a lexicographical order. - AGNewsOp(int32_t num_workers, int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, - bool shuffle_files, int32_t num_devices, int32_t device_id, char field_delim, - const std::vector> &column_default, const std::vector &column_name, - const std::vector &ag_news_list); - - /// \brief Default destructor. - ~AGNewsOp() = default; - - /// \brief A print method typically used for debugging. - /// \param[in] out he output stream to write output to. - /// \param[in] show_all A bool to control if you want to show all info or just a - /// summary. - void Print(std::ostream &out, bool show_all) const override; - - /// \brief Op name getter. - /// \return Name of the current Op. - std::string Name() const override { return "AGNewsOp"; } - - // DatasetName name getter - // \return DatasetName of the current Op - std::string DatasetName(bool upper = false) const { return upper ? "AGNews" : "ag news"; } -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_AG_NEWS_OP_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_AG_NEWS_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_AG_NEWS_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/engine/datasetops/parallel_op.h" +#include "minddata/dataset/engine/datasetops/source/csv_op.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/dataset/engine/ir/cache/dataset_cache.h" +#include "minddata/dataset/engine/jagged_connector.h" +#include "minddata/dataset/util/auto_index.h" + +namespace mindspore { +namespace dataset { +class JaggedConnector; + +class AGNewsOp : public CsvOp { + public: + /// \brief Constructor. + /// \param[in] num_workers Number of workers reading images in parallel + /// \param[in] num_samples The number of samples to be included in the dataset. + /// (Default = 0 means all samples). + /// \param[in] worker_connector_size Size of each internal queue. + /// \param[in] op_connector_size Size of each queue in the connector that the child operator pulls from. + /// \param[in] shuffle_files Whether or not to shuffle the files before reading data. + /// \param[in] num_devices Number of devices that the dataset should be divided into. (Default = 1) + /// \param[in] device_id The device ID within num_devices. This argument should be + /// specified only when num_devices is also specified (Default = 0). + /// \param[in] field_delim A char that indicates the delimiter to separate fields (default=','). + /// \param[in] column_default List of default values for the CSV field (default={}). Each item in the list is + /// either a valid type (float, int, or string). If this is not provided, treats all columns as string type. + /// \param[in] column_name List of column names of the dataset (default={}). If this is not provided, infers the + /// column_names from the first row of CSV file. + /// \param[in] ag_news_list List of files to be read to search for a pattern of files. The list + /// will be sorted in a lexicographical order. + AGNewsOp(int32_t num_workers, int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, + bool shuffle_files, int32_t num_devices, int32_t device_id, char field_delim, + const std::vector> &column_default, const std::vector &column_name, + const std::vector &ag_news_list); + + /// \brief Default destructor. + ~AGNewsOp() = default; + + /// \brief A print method typically used for debugging. + /// \param[in] out he output stream to write output to. + /// \param[in] show_all A bool to control if you want to show all info or just a + /// summary. + void Print(std::ostream &out, bool show_all) const override; + + /// \brief Op name getter. + /// \return Name of the current Op. + std::string Name() const override { return "AGNewsOp"; } + + // DatasetName name getter + // \return DatasetName of the current Op + std::string DatasetName(bool upper = false) const { return upper ? "AGNews" : "ag news"; } +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_AG_NEWS_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/amazon_review_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/amazon_review_op.cc old mode 100755 new mode 100644 index a194b437084..852d4beda35 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/amazon_review_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/amazon_review_op.cc @@ -1,50 +1,50 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/engine/datasetops/source/amazon_review_op.h" - -#include - -namespace mindspore { -namespace dataset { -AmazonReviewOp::AmazonReviewOp(int32_t num_workers, int64_t num_samples, int32_t worker_connector_size, - int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id, - char field_delim, const std::vector> &column_default, - const std::vector &column_name, - const std::vector &amazon_review_files_list) - : CsvOp(amazon_review_files_list, field_delim, column_default, column_name, num_workers, num_samples, - worker_connector_size, op_connector_size, shuffle_files, num_devices, device_id) {} - -void AmazonReviewOp::Print(std::ostream &out, bool show_all) const { - if (!show_all) { - // Call the super class for displaying any common 1-liner info. - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op. - out << "\n"; - } else { - // Call the super class for displaying any common detailed info. - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal stuff. - out << "\nSample count: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_ - << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nAmazonReview files list:\n"; - for (int i = 0; i < csv_files_list_.size(); ++i) { - out << " " << csv_files_list_[i]; - } - out << "\n\n"; - } -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/engine/datasetops/source/amazon_review_op.h" + +#include + +namespace mindspore { +namespace dataset { +AmazonReviewOp::AmazonReviewOp(int32_t num_workers, int64_t num_samples, int32_t worker_connector_size, + int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id, + char field_delim, const std::vector> &column_default, + const std::vector &column_name, + const std::vector &amazon_review_files_list) + : CsvOp(amazon_review_files_list, field_delim, column_default, column_name, num_workers, num_samples, + worker_connector_size, op_connector_size, shuffle_files, num_devices, device_id) {} + +void AmazonReviewOp::Print(std::ostream &out, bool show_all) const { + if (!show_all) { + // Call the super class for displaying any common 1-liner info. + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op. + out << "\n"; + } else { + // Call the super class for displaying any common detailed info. + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff. + out << "\nSample count: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_ + << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nAmazonReview files list:\n"; + for (int i = 0; i < csv_files_list_.size(); ++i) { + out << " " << csv_files_list_[i]; + } + out << "\n\n"; + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/amazon_review_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/amazon_review_op.h old mode 100755 new mode 100644 index 2b58c85ae2b..7d006d7e45a --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/amazon_review_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/amazon_review_op.h @@ -1,71 +1,71 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_AMAZON_REVIEW_OP_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_AMAZON_REVIEW_OP_H_ - -#include -#include -#include - -#include "minddata/dataset/engine/datasetops/source/csv_op.h" - -namespace mindspore { -namespace dataset { -class JaggedConnector; - -/// \class AmazonReviewOp -/// \brief A Op derived class to represent AmazonReview Op. -class AmazonReviewOp : public CsvOp { - public: - /// \brief Constructor of AmazonReviewOp. - /// \param[in] num_workers Number of worker threads reading data from amazon_review files. - /// \param[in] num_samples The number of samples to be included in the dataset. - /// \param[in] worker_connector_size Size of each internal queue. - /// \param[in] op_connector_size Size of each queue in the connector that the child operator pulls from. - /// \param[in] shuffle_files Whether or not to shuffle the files before reading data. - /// \param[in] num_devices Number of devices that the dataset should be divided into. - /// \param[in] device_id The device ID within num_devices. - /// \param[in] field_delim A char that indicates the delimiter to separate fields. - /// \param[in] column_default List of default values for the CSV field. Each item in the list is - /// either a valid type (float, int, or string). - /// \param[in] column_name List of column names of the dataset. - /// \param[in] amazon_review_files_list List of file paths for the dataset files. - AmazonReviewOp(int32_t num_workers, int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, - bool shuffle_files, int32_t num_devices, int32_t device_id, char field_delim, - const std::vector> &column_default, - const std::vector &column_name, const std::vector &amazon_review_files_list); - - /// \brief Destructor. - ~AmazonReviewOp() = default; - - /// \brief A print method typically used for debugging. - /// \param[out] out The output stream to write output to. - /// \param[in] show_all A bool to control if you want to show all info or just a summary. - void Print(std::ostream &out, bool show_all) const override; - - /// \brief DatasetName name getter. - /// \param[in] upper A bool to control if you want to return uppercase or lowercase Op name. - /// \return DatasetName of the current Op. - std::string DatasetName(bool upper = false) const { return upper ? "AmazonReview" : "amazon review"; } - - /// \brief Op name getter. - /// \return Name of the current Op. - std::string Name() const override { return "AmazonReviewOp"; } -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_AMAZON_REVIEW_OP_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_AMAZON_REVIEW_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_AMAZON_REVIEW_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/engine/datasetops/source/csv_op.h" + +namespace mindspore { +namespace dataset { +class JaggedConnector; + +/// \class AmazonReviewOp +/// \brief A Op derived class to represent AmazonReview Op. +class AmazonReviewOp : public CsvOp { + public: + /// \brief Constructor of AmazonReviewOp. + /// \param[in] num_workers Number of worker threads reading data from amazon_review files. + /// \param[in] num_samples The number of samples to be included in the dataset. + /// \param[in] worker_connector_size Size of each internal queue. + /// \param[in] op_connector_size Size of each queue in the connector that the child operator pulls from. + /// \param[in] shuffle_files Whether or not to shuffle the files before reading data. + /// \param[in] num_devices Number of devices that the dataset should be divided into. + /// \param[in] device_id The device ID within num_devices. + /// \param[in] field_delim A char that indicates the delimiter to separate fields. + /// \param[in] column_default List of default values for the CSV field. Each item in the list is + /// either a valid type (float, int, or string). + /// \param[in] column_name List of column names of the dataset. + /// \param[in] amazon_review_files_list List of file paths for the dataset files. + AmazonReviewOp(int32_t num_workers, int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, + bool shuffle_files, int32_t num_devices, int32_t device_id, char field_delim, + const std::vector> &column_default, + const std::vector &column_name, const std::vector &amazon_review_files_list); + + /// \brief Destructor. + ~AmazonReviewOp() = default; + + /// \brief A print method typically used for debugging. + /// \param[out] out The output stream to write output to. + /// \param[in] show_all A bool to control if you want to show all info or just a summary. + void Print(std::ostream &out, bool show_all) const override; + + /// \brief DatasetName name getter. + /// \param[in] upper A bool to control if you want to return uppercase or lowercase Op name. + /// \return DatasetName of the current Op. + std::string DatasetName(bool upper = false) const { return upper ? "AmazonReview" : "amazon review"; } + + /// \brief Op name getter. + /// \return Name of the current Op. + std::string Name() const override { return "AmazonReviewOp"; } +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_AMAZON_REVIEW_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/caltech_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/caltech_op.cc old mode 100755 new mode 100644 index 0e975705a28..1721e7a52f4 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/caltech_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/caltech_op.cc @@ -1,32 +1,32 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - - * http://www.apache.org/licenses/LICENSE-2.0 - - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. -*/ -#include "minddata/dataset/engine/datasetops/source/caltech_op.h" - -#include -#include -#include -#include - -namespace mindspore { -namespace dataset { -const std::set kExts = {".jpg", ".JPEG"}; -const std::map kClassIndex = {}; -CaltechOp::CaltechOp(int32_t num_workers, const std::string &file_dir, int32_t queue_size, bool do_decode, - std::unique_ptr data_schema, std::shared_ptr sampler) - : ImageFolderOp(num_workers, file_dir, queue_size, false, do_decode, kExts, kClassIndex, std::move(data_schema), - std::move(sampler)) {} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ +#include "minddata/dataset/engine/datasetops/source/caltech_op.h" + +#include +#include +#include +#include + +namespace mindspore { +namespace dataset { +const std::set kExts = {".jpg", ".JPEG"}; +const std::map kClassIndex = {}; +CaltechOp::CaltechOp(int32_t num_workers, const std::string &file_dir, int32_t queue_size, bool do_decode, + std::unique_ptr data_schema, std::shared_ptr sampler) + : ImageFolderOp(num_workers, file_dir, queue_size, false, do_decode, kExts, kClassIndex, std::move(data_schema), + std::move(sampler)) {} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/caltech_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/caltech_op.h old mode 100755 new mode 100644 diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cmu_arctic_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cmu_arctic_op.cc index a0efc6aee54..52e32547c65 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cmu_arctic_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cmu_arctic_op.cc @@ -1,171 +1,171 @@ -/** - * Copyright 2022-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/engine/datasetops/source/cmu_arctic_op.h" - -#include - -#include "minddata/dataset/audio/kernels/audio_utils.h" -#include "minddata/dataset/core/config_manager.h" -#include "minddata/dataset/core/tensor_shape.h" -#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" -#include "minddata/dataset/engine/execution_tree.h" -#include "utils/file_utils.h" - -namespace mindspore { -namespace dataset { -const char kDataDirectory[] = "wav"; -const char kLabelDirectory[] = "etc"; -const char kLabelFileName[] = "txt.done.data"; -const char kDataFilePrefix[] = "cmu_us_"; -const char kDataFileSuffix[] = "_arctic"; - -CMUArcticOp::CMUArcticOp(const std::string &dataset_dir, const std::string &name, int32_t num_workers, - int32_t queue_size, std::unique_ptr data_schema, - std::shared_ptr sampler) - : MappableLeafOp(num_workers, queue_size, std::move(sampler)), - folder_path_(dataset_dir), - name_(name), - data_schema_(std::move(data_schema)) {} - -Status CMUArcticOp::LoadTensorRow(row_id_type row_id, TensorRow *trow) { - RETURN_UNEXPECTED_IF_NULL(trow); - const uint32_t sample_rate = 16000; - const std::string wav_suffix = ".wav"; - size_t pos = label_pairs_[row_id].first.find_last_of('_'); - CHECK_FAIL_RETURN_UNEXPECTED( - pos != std::string::npos && pos + 1 < label_pairs_[row_id].first.size(), - "Invalid utterance id, please check if it is in valid format: " + label_pairs_[row_id].first); - std::string utterance_id_t = label_pairs_[row_id].first.substr(pos + 1); - std::string full_name_path = kDataFilePrefix + name_ + kDataFileSuffix; - std::string file_name = label_pairs_[row_id].first + wav_suffix; - Path root_folder(real_path_); - Path wav_file_path = root_folder / full_name_path / kDataDirectory / file_name; - std::shared_ptr waveform, rate, transcript, utterance_id; - RETURN_IF_NOT_OK(ReadAudio(wav_file_path.ToString(), &waveform)); - RETURN_IF_NOT_OK(Tensor::CreateScalar(sample_rate, &rate)); - RETURN_IF_NOT_OK(Tensor::CreateScalar(label_pairs_[row_id].second, &transcript)); - RETURN_IF_NOT_OK(Tensor::CreateScalar(utterance_id_t, &utterance_id)); - (*trow) = TensorRow(row_id, {std::move(waveform), std::move(rate), std::move(transcript), std::move(utterance_id)}); - Path label_dir = root_folder / full_name_path / kLabelDirectory / kLabelFileName; - trow->setPath({wav_file_path.ToString(), wav_file_path.ToString(), label_dir.ToString(), label_dir.ToString()}); - return Status::OK(); -} - -void CMUArcticOp::Print(std::ostream &out, bool show_all) const { - if (!show_all) { - ParallelOp::Print(out, show_all); - out << "\n"; - } else { - ParallelOp::Print(out, show_all); - out << "\nNumber of rows: " << num_rows_ << "\nCMUArctic directory: " << folder_path_ << "\n\n"; - } -} - -Status CMUArcticOp::CountTotalRows(const std::string &dir, const std::string &name, int64_t *count) { - RETURN_UNEXPECTED_IF_NULL(count); - *count = 0; - const int64_t num_samples = 0; - const int64_t start_index = 0; - auto sampler = std::make_shared(start_index, num_samples); - auto schema = std::make_unique(); - - RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("waveform", DataType(DataType::DE_FLOAT32), TensorImpl::kCv, 1))); - TensorShape scalar_rate = TensorShape::CreateScalar(); - RETURN_IF_NOT_OK(schema->AddColumn( - ColDescriptor("sample_rate", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar_rate))); - TensorShape scalar_utterance = TensorShape::CreateScalar(); - RETURN_IF_NOT_OK(schema->AddColumn( - ColDescriptor("transcript", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar_utterance))); - TensorShape scalar_utterance_id = TensorShape::CreateScalar(); - RETURN_IF_NOT_OK(schema->AddColumn( - ColDescriptor("utterance_id", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar_utterance_id))); - std::shared_ptr cfg = GlobalContext::config_manager(); - - int32_t num_workers = cfg->num_parallel_workers(); - int32_t op_connect_size = cfg->op_connector_size(); - auto op = - std::make_shared(dir, name, num_workers, op_connect_size, std::move(schema), std::move(sampler)); - RETURN_IF_NOT_OK(op->PrepareData()); - *count = op->label_pairs_.size(); - return Status::OK(); -} - -Status CMUArcticOp::ComputeColMap() { - if (column_name_id_map_.empty()) { - for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { - column_name_id_map_[data_schema_->Column(i).Name()] = i; - } - } else { - MS_LOG(WARNING) << "Column name map is already set!"; - } - return Status::OK(); -} - -Status CMUArcticOp::ReadAudio(const std::string &audio_dir, std::shared_ptr *waveform) { - RETURN_UNEXPECTED_IF_NULL(waveform); - const int32_t kWavFileSampleRate = 16000; - int32_t sample_rate = 0; - std::vector waveform_vec; - RETURN_IF_NOT_OK(ReadWaveFile(audio_dir, &waveform_vec, &sample_rate)); - CHECK_FAIL_RETURN_UNEXPECTED( - sample_rate == kWavFileSampleRate, - "Invalid file, sampling rate of CMUArctic wav file must be 16000, file path: " + audio_dir); - RETURN_IF_NOT_OK(Tensor::CreateFromVector(waveform_vec, waveform)); - RETURN_IF_NOT_OK((*waveform)->ExpandDim(0)); - return Status::OK(); -} - -Status CMUArcticOp::PrepareData() { - auto realpath = FileUtils::GetRealPath(folder_path_.c_str()); - if (!realpath.has_value()) { - MS_LOG(ERROR) << "Invalid file path, CMUArctic Dataset dir: " << folder_path_ << " does not exist."; - RETURN_STATUS_UNEXPECTED("Invalid file path, CMUArctic Dataset dir: " + folder_path_ + " does not exist."); - } - real_path_ = realpath.value(); - Path dir(real_path_); - std::string full_name_path = kDataFilePrefix + name_ + kDataFileSuffix; - Path label_dir = dir / full_name_path / kLabelDirectory / kLabelFileName; - CHECK_FAIL_RETURN_UNEXPECTED(label_dir.Exists() && !label_dir.IsDirectory(), - "Invalid file, failed to find label file: " + label_dir.ToString()); - std::ifstream label_reader(label_dir.ToString(), std::ifstream::in); - CHECK_FAIL_RETURN_UNEXPECTED(label_reader.is_open(), - "Invalid file, failed to open label file: " + label_dir.ToString() + - ", make sure file not damaged or permission denied."); - std::string line = ""; - while (getline(label_reader, line)) { - size_t quot_inx[2] = {0}; - size_t quot_num = 0; - size_t quot_exact = 2; - for (size_t i = 0; quot_num < quot_exact && i < line.size(); i++) { - if (line[i] == '"') { - quot_inx[quot_num++] = i; - } - } - if (quot_num != quot_exact) { - label_reader.close(); - RETURN_STATUS_UNEXPECTED("Invalid file, the file may not be a CMUArctic dataset file: " + label_dir.ToString()); - } - label_pairs_.push_back( - {line.substr(2, quot_inx[0] - 3), line.substr(quot_inx[0] + 1, quot_inx[1] - quot_inx[0] - 1)}); - } - label_reader.close(); - num_rows_ = label_pairs_.size(); - CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "Invalid data, no valid data found in path: " + folder_path_); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2022-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/engine/datasetops/source/cmu_arctic_op.h" + +#include + +#include "minddata/dataset/audio/kernels/audio_utils.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "utils/file_utils.h" + +namespace mindspore { +namespace dataset { +const char kDataDirectory[] = "wav"; +const char kLabelDirectory[] = "etc"; +const char kLabelFileName[] = "txt.done.data"; +const char kDataFilePrefix[] = "cmu_us_"; +const char kDataFileSuffix[] = "_arctic"; + +CMUArcticOp::CMUArcticOp(const std::string &dataset_dir, const std::string &name, int32_t num_workers, + int32_t queue_size, std::unique_ptr data_schema, + std::shared_ptr sampler) + : MappableLeafOp(num_workers, queue_size, std::move(sampler)), + folder_path_(dataset_dir), + name_(name), + data_schema_(std::move(data_schema)) {} + +Status CMUArcticOp::LoadTensorRow(row_id_type row_id, TensorRow *trow) { + RETURN_UNEXPECTED_IF_NULL(trow); + const uint32_t sample_rate = 16000; + const std::string wav_suffix = ".wav"; + size_t pos = label_pairs_[row_id].first.find_last_of('_'); + CHECK_FAIL_RETURN_UNEXPECTED( + pos != std::string::npos && pos + 1 < label_pairs_[row_id].first.size(), + "Invalid utterance id, please check if it is in valid format: " + label_pairs_[row_id].first); + std::string utterance_id_t = label_pairs_[row_id].first.substr(pos + 1); + std::string full_name_path = kDataFilePrefix + name_ + kDataFileSuffix; + std::string file_name = label_pairs_[row_id].first + wav_suffix; + Path root_folder(real_path_); + Path wav_file_path = root_folder / full_name_path / kDataDirectory / file_name; + std::shared_ptr waveform, rate, transcript, utterance_id; + RETURN_IF_NOT_OK(ReadAudio(wav_file_path.ToString(), &waveform)); + RETURN_IF_NOT_OK(Tensor::CreateScalar(sample_rate, &rate)); + RETURN_IF_NOT_OK(Tensor::CreateScalar(label_pairs_[row_id].second, &transcript)); + RETURN_IF_NOT_OK(Tensor::CreateScalar(utterance_id_t, &utterance_id)); + (*trow) = TensorRow(row_id, {std::move(waveform), std::move(rate), std::move(transcript), std::move(utterance_id)}); + Path label_dir = root_folder / full_name_path / kLabelDirectory / kLabelFileName; + trow->setPath({wav_file_path.ToString(), wav_file_path.ToString(), label_dir.ToString(), label_dir.ToString()}); + return Status::OK(); +} + +void CMUArcticOp::Print(std::ostream &out, bool show_all) const { + if (!show_all) { + ParallelOp::Print(out, show_all); + out << "\n"; + } else { + ParallelOp::Print(out, show_all); + out << "\nNumber of rows: " << num_rows_ << "\nCMUArctic directory: " << folder_path_ << "\n\n"; + } +} + +Status CMUArcticOp::CountTotalRows(const std::string &dir, const std::string &name, int64_t *count) { + RETURN_UNEXPECTED_IF_NULL(count); + *count = 0; + const int64_t num_samples = 0; + const int64_t start_index = 0; + auto sampler = std::make_shared(start_index, num_samples); + auto schema = std::make_unique(); + + RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("waveform", DataType(DataType::DE_FLOAT32), TensorImpl::kCv, 1))); + TensorShape scalar_rate = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK(schema->AddColumn( + ColDescriptor("sample_rate", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar_rate))); + TensorShape scalar_utterance = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK(schema->AddColumn( + ColDescriptor("transcript", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar_utterance))); + TensorShape scalar_utterance_id = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK(schema->AddColumn( + ColDescriptor("utterance_id", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar_utterance_id))); + std::shared_ptr cfg = GlobalContext::config_manager(); + + int32_t num_workers = cfg->num_parallel_workers(); + int32_t op_connect_size = cfg->op_connector_size(); + auto op = + std::make_shared(dir, name, num_workers, op_connect_size, std::move(schema), std::move(sampler)); + RETURN_IF_NOT_OK(op->PrepareData()); + *count = op->label_pairs_.size(); + return Status::OK(); +} + +Status CMUArcticOp::ComputeColMap() { + if (column_name_id_map_.empty()) { + for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { + column_name_id_map_[data_schema_->Column(i).Name()] = i; + } + } else { + MS_LOG(WARNING) << "Column name map is already set!"; + } + return Status::OK(); +} + +Status CMUArcticOp::ReadAudio(const std::string &audio_dir, std::shared_ptr *waveform) { + RETURN_UNEXPECTED_IF_NULL(waveform); + const int32_t kWavFileSampleRate = 16000; + int32_t sample_rate = 0; + std::vector waveform_vec; + RETURN_IF_NOT_OK(ReadWaveFile(audio_dir, &waveform_vec, &sample_rate)); + CHECK_FAIL_RETURN_UNEXPECTED( + sample_rate == kWavFileSampleRate, + "Invalid file, sampling rate of CMUArctic wav file must be 16000, file path: " + audio_dir); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(waveform_vec, waveform)); + RETURN_IF_NOT_OK((*waveform)->ExpandDim(0)); + return Status::OK(); +} + +Status CMUArcticOp::PrepareData() { + auto realpath = FileUtils::GetRealPath(folder_path_.c_str()); + if (!realpath.has_value()) { + MS_LOG(ERROR) << "Invalid file path, CMUArctic Dataset dir: " << folder_path_ << " does not exist."; + RETURN_STATUS_UNEXPECTED("Invalid file path, CMUArctic Dataset dir: " + folder_path_ + " does not exist."); + } + real_path_ = realpath.value(); + Path dir(real_path_); + std::string full_name_path = kDataFilePrefix + name_ + kDataFileSuffix; + Path label_dir = dir / full_name_path / kLabelDirectory / kLabelFileName; + CHECK_FAIL_RETURN_UNEXPECTED(label_dir.Exists() && !label_dir.IsDirectory(), + "Invalid file, failed to find label file: " + label_dir.ToString()); + std::ifstream label_reader(label_dir.ToString(), std::ifstream::in); + CHECK_FAIL_RETURN_UNEXPECTED(label_reader.is_open(), + "Invalid file, failed to open label file: " + label_dir.ToString() + + ", make sure file not damaged or permission denied."); + std::string line = ""; + while (getline(label_reader, line)) { + size_t quot_inx[2] = {0}; + size_t quot_num = 0; + size_t quot_exact = 2; + for (size_t i = 0; quot_num < quot_exact && i < line.size(); i++) { + if (line[i] == '"') { + quot_inx[quot_num++] = i; + } + } + if (quot_num != quot_exact) { + label_reader.close(); + RETURN_STATUS_UNEXPECTED("Invalid file, the file may not be a CMUArctic dataset file: " + label_dir.ToString()); + } + label_pairs_.push_back( + {line.substr(2, quot_inx[0] - 3), line.substr(quot_inx[0] + 1, quot_inx[1] - quot_inx[0] - 1)}); + } + label_reader.close(); + num_rows_ = label_pairs_.size(); + CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "Invalid data, no valid data found in path: " + folder_path_); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cmu_arctic_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cmu_arctic_op.h index 033bff4c17d..6a20873f9ac 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cmu_arctic_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cmu_arctic_op.h @@ -1,99 +1,99 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_CMU_ARCTIC_OP_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_CMU_ARCTIC_OP_H_ - -#include -#include -#include -#include -#include -#include - -#include "minddata/dataset/core/tensor.h" -#include "minddata/dataset/engine/data_schema.h" -#include "minddata/dataset/engine/datasetops/parallel_op.h" -#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h" -#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" -#include "minddata/dataset/util/path.h" -#include "minddata/dataset/util/queue.h" -#include "minddata/dataset/util/status.h" -#include "minddata/dataset/util/wait_post.h" - -namespace mindspore { -namespace dataset { -class CMUArcticOp : public MappableLeafOp { - public: - /// \brief Constructor. - /// \param[in] dataset_dir Directory of CMUArctic. - /// \param[in] name Part of this dataset, can be "aew", "ahw", "aup", "awb", "axb", "bdl", - /// "clb", "eey", "fem", "gka", "jmk", "ksp", "ljm", "lnh", "rms", "rxr", "slp" or "slt" - /// \param[in] num_workers Number of workers reading audios in parallel. - /// \param[in] queue_size Connector queue size. - /// \param[in] data_schema The schema of the CMUArctic dataset. - /// \param[in] sampler Sampler tells CMUArcticOp what to read. - CMUArcticOp(const std::string &dataset_dir, const std::string &name, int32_t num_workers, int32_t queue_size, - std::unique_ptr data_schema, std::shared_ptr sampler); - - /// \brief Destructor. - ~CMUArcticOp() = default; - - /// \brief A print method typically used for debugging. - /// \param[out] out The output stream to write output to. - /// \param[in] show_all A bool to control if you want to show all info or just a summary. - void Print(std::ostream &out, bool show_all) const override; - - /// \brief Function to count the number of samples in the CMUArctic dataset. - /// \param[in] dir Path to the CMUArctic directory. - /// \param[in] name Choose the subset of CMUArctic dataset. - /// \param[out] count Output arg that will hold the minimum of the actual dataset size and numSamples. - /// \return Status The status code returned. - static Status CountTotalRows(const std::string &dir, const std::string &name, int64_t *count); - - /// \brief Op name getter. - /// \return Name of the current Op. - std::string Name() const override { return "CMUArcticOp"; } - - private: - /// \brief Load a tensor row according to a pair. - /// \param[in] row_id Id for this tensor row. - /// \param[out] row Audio & label read into this tensor row. - /// \return Status The status code returned. - Status LoadTensorRow(row_id_type row_id, TensorRow *row) override; - - /// \brief Parse a single wav file. - /// \param[in] audio_dir Audio file path. - /// \param[out] waveform The output waveform tensor. - /// \return Status The status code returned. - Status ReadAudio(const std::string &audio_dir, std::shared_ptr *waveform); - - /// \brief Prepare all data in the directory. - /// \return Status The status code returned. - Status PrepareData(); - - /// \brief Private function for computing the assignment of the column name map. - /// \return Status. - Status ComputeColMap() override; - - const std::string name_; - std::string folder_path_; - std::string real_path_; - std::unique_ptr data_schema_; - std::vector> label_pairs_; -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_CMU_ARCTIC_OP_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_CMU_ARCTIC_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_CMU_ARCTIC_OP_H_ + +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/data_schema.h" +#include "minddata/dataset/engine/datasetops/parallel_op.h" +#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/dataset/util/path.h" +#include "minddata/dataset/util/queue.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/util/wait_post.h" + +namespace mindspore { +namespace dataset { +class CMUArcticOp : public MappableLeafOp { + public: + /// \brief Constructor. + /// \param[in] dataset_dir Directory of CMUArctic. + /// \param[in] name Part of this dataset, can be "aew", "ahw", "aup", "awb", "axb", "bdl", + /// "clb", "eey", "fem", "gka", "jmk", "ksp", "ljm", "lnh", "rms", "rxr", "slp" or "slt" + /// \param[in] num_workers Number of workers reading audios in parallel. + /// \param[in] queue_size Connector queue size. + /// \param[in] data_schema The schema of the CMUArctic dataset. + /// \param[in] sampler Sampler tells CMUArcticOp what to read. + CMUArcticOp(const std::string &dataset_dir, const std::string &name, int32_t num_workers, int32_t queue_size, + std::unique_ptr data_schema, std::shared_ptr sampler); + + /// \brief Destructor. + ~CMUArcticOp() = default; + + /// \brief A print method typically used for debugging. + /// \param[out] out The output stream to write output to. + /// \param[in] show_all A bool to control if you want to show all info or just a summary. + void Print(std::ostream &out, bool show_all) const override; + + /// \brief Function to count the number of samples in the CMUArctic dataset. + /// \param[in] dir Path to the CMUArctic directory. + /// \param[in] name Choose the subset of CMUArctic dataset. + /// \param[out] count Output arg that will hold the minimum of the actual dataset size and numSamples. + /// \return Status The status code returned. + static Status CountTotalRows(const std::string &dir, const std::string &name, int64_t *count); + + /// \brief Op name getter. + /// \return Name of the current Op. + std::string Name() const override { return "CMUArcticOp"; } + + private: + /// \brief Load a tensor row according to a pair. + /// \param[in] row_id Id for this tensor row. + /// \param[out] row Audio & label read into this tensor row. + /// \return Status The status code returned. + Status LoadTensorRow(row_id_type row_id, TensorRow *row) override; + + /// \brief Parse a single wav file. + /// \param[in] audio_dir Audio file path. + /// \param[out] waveform The output waveform tensor. + /// \return Status The status code returned. + Status ReadAudio(const std::string &audio_dir, std::shared_ptr *waveform); + + /// \brief Prepare all data in the directory. + /// \return Status The status code returned. + Status PrepareData(); + + /// \brief Private function for computing the assignment of the column name map. + /// \return Status. + Status ComputeColMap() override; + + const std::string name_; + std::string folder_path_; + std::string real_path_; + std::unique_ptr data_schema_; + std::vector> label_pairs_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_CMU_ARCTIC_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/gtzan_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/gtzan_op.cc index 4195300600e..a3cbe9122f5 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/gtzan_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/gtzan_op.cc @@ -1,335 +1,335 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/engine/datasetops/source/gtzan_op.h" - -#include -#include - -#include "minddata/dataset/audio/kernels/audio_utils.h" -#include "minddata/dataset/core/config_manager.h" -#include "minddata/dataset/core/tensor_shape.h" -#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" -#include "minddata/dataset/engine/execution_tree.h" -#include "utils/file_utils.h" - -namespace mindspore { -namespace dataset { -const std::vector genres = { - "blues", "classical", "country", "disco", "hiphop", "jazz", "metal", "pop", "reggae", "rock", -}; - -const std::vector filtered_test = { - "blues.00012", "blues.00013", "blues.00014", "blues.00015", "blues.00016", "blues.00017", - "blues.00018", "blues.00019", "blues.00020", "blues.00021", "blues.00022", "blues.00023", - "blues.00024", "blues.00025", "blues.00026", "blues.00027", "blues.00028", "blues.00061", - "blues.00062", "blues.00063", "blues.00064", "blues.00065", "blues.00066", "blues.00067", - "blues.00068", "blues.00069", "blues.00070", "blues.00071", "blues.00072", "blues.00098", - "blues.00099", "classical.00011", "classical.00012", "classical.00013", "classical.00014", "classical.00015", - "classical.00016", "classical.00017", "classical.00018", "classical.00019", "classical.00020", "classical.00021", - "classical.00022", "classical.00023", "classical.00024", "classical.00025", "classical.00026", "classical.00027", - "classical.00028", "classical.00029", "classical.00034", "classical.00035", "classical.00036", "classical.00037", - "classical.00038", "classical.00039", "classical.00040", "classical.00041", "classical.00049", "classical.00077", - "classical.00078", "classical.00079", "country.00030", "country.00031", "country.00032", "country.00033", - "country.00034", "country.00035", "country.00036", "country.00037", "country.00038", "country.00039", - "country.00040", "country.00043", "country.00044", "country.00046", "country.00047", "country.00048", - "country.00050", "country.00051", "country.00053", "country.00054", "country.00055", "country.00056", - "country.00057", "country.00058", "country.00059", "country.00060", "country.00061", "country.00062", - "country.00063", "country.00064", "disco.00001", "disco.00021", "disco.00058", "disco.00062", - "disco.00063", "disco.00064", "disco.00065", "disco.00066", "disco.00069", "disco.00076", - "disco.00077", "disco.00078", "disco.00079", "disco.00080", "disco.00081", "disco.00082", - "disco.00083", "disco.00084", "disco.00085", "disco.00086", "disco.00087", "disco.00088", - "disco.00091", "disco.00092", "disco.00093", "disco.00094", "disco.00096", "disco.00097", - "disco.00099", "hiphop.00000", "hiphop.00026", "hiphop.00027", "hiphop.00030", "hiphop.00040", - "hiphop.00043", "hiphop.00044", "hiphop.00045", "hiphop.00051", "hiphop.00052", "hiphop.00053", - "hiphop.00054", "hiphop.00062", "hiphop.00063", "hiphop.00064", "hiphop.00065", "hiphop.00066", - "hiphop.00067", "hiphop.00068", "hiphop.00069", "hiphop.00070", "hiphop.00071", "hiphop.00072", - "hiphop.00073", "hiphop.00074", "hiphop.00075", "hiphop.00099", "jazz.00073", "jazz.00074", - "jazz.00075", "jazz.00076", "jazz.00077", "jazz.00078", "jazz.00079", "jazz.00080", - "jazz.00081", "jazz.00082", "jazz.00083", "jazz.00084", "jazz.00085", "jazz.00086", - "jazz.00087", "jazz.00088", "jazz.00089", "jazz.00090", "jazz.00091", "jazz.00092", - "jazz.00093", "jazz.00094", "jazz.00095", "jazz.00096", "jazz.00097", "jazz.00098", - "jazz.00099", "metal.00012", "metal.00013", "metal.00014", "metal.00015", "metal.00022", - "metal.00023", "metal.00025", "metal.00026", "metal.00027", "metal.00028", "metal.00029", - "metal.00030", "metal.00031", "metal.00032", "metal.00033", "metal.00038", "metal.00039", - "metal.00067", "metal.00070", "metal.00073", "metal.00074", "metal.00075", "metal.00078", - "metal.00083", "metal.00085", "metal.00087", "metal.00088", "pop.00000", "pop.00001", - "pop.00013", "pop.00014", "pop.00043", "pop.00063", "pop.00064", "pop.00065", - "pop.00066", "pop.00069", "pop.00070", "pop.00071", "pop.00072", "pop.00073", - "pop.00074", "pop.00075", "pop.00076", "pop.00077", "pop.00078", "pop.00079", - "pop.00082", "pop.00088", "pop.00089", "pop.00090", "pop.00091", "pop.00092", - "pop.00093", "pop.00094", "pop.00095", "pop.00096", "reggae.00034", "reggae.00035", - "reggae.00036", "reggae.00037", "reggae.00038", "reggae.00039", "reggae.00040", "reggae.00046", - "reggae.00047", "reggae.00048", "reggae.00052", "reggae.00053", "reggae.00064", "reggae.00065", - "reggae.00066", "reggae.00067", "reggae.00068", "reggae.00071", "reggae.00079", "reggae.00082", - "reggae.00083", "reggae.00084", "reggae.00087", "reggae.00088", "reggae.00089", "reggae.00090", - "rock.00010", "rock.00011", "rock.00012", "rock.00013", "rock.00014", "rock.00015", - "rock.00027", "rock.00028", "rock.00029", "rock.00030", "rock.00031", "rock.00032", - "rock.00033", "rock.00034", "rock.00035", "rock.00036", "rock.00037", "rock.00039", - "rock.00040", "rock.00041", "rock.00042", "rock.00043", "rock.00044", "rock.00045", - "rock.00046", "rock.00047", "rock.00048", "rock.00086", "rock.00087", "rock.00088", - "rock.00089", "rock.00090", -}; - -const std::vector filtered_train = { - "blues.00029", "blues.00030", "blues.00031", "blues.00032", "blues.00033", "blues.00034", - "blues.00035", "blues.00036", "blues.00037", "blues.00038", "blues.00039", "blues.00040", - "blues.00041", "blues.00042", "blues.00043", "blues.00044", "blues.00045", "blues.00046", - "blues.00047", "blues.00048", "blues.00049", "blues.00073", "blues.00074", "blues.00075", - "blues.00076", "blues.00077", "blues.00078", "blues.00079", "blues.00080", "blues.00081", - "blues.00082", "blues.00083", "blues.00084", "blues.00085", "blues.00086", "blues.00087", - "blues.00088", "blues.00089", "blues.00090", "blues.00091", "blues.00092", "blues.00093", - "blues.00094", "blues.00095", "blues.00096", "blues.00097", "classical.00030", "classical.00031", - "classical.00032", "classical.00033", "classical.00043", "classical.00044", "classical.00045", "classical.00046", - "classical.00047", "classical.00048", "classical.00050", "classical.00051", "classical.00052", "classical.00053", - "classical.00054", "classical.00055", "classical.00056", "classical.00057", "classical.00058", "classical.00059", - "classical.00060", "classical.00061", "classical.00062", "classical.00063", "classical.00064", "classical.00065", - "classical.00066", "classical.00067", "classical.00080", "classical.00081", "classical.00082", "classical.00083", - "classical.00084", "classical.00085", "classical.00086", "classical.00087", "classical.00088", "classical.00089", - "classical.00090", "classical.00091", "classical.00092", "classical.00093", "classical.00094", "classical.00095", - "classical.00096", "classical.00097", "classical.00098", "classical.00099", "country.00019", "country.00020", - "country.00021", "country.00022", "country.00023", "country.00024", "country.00025", "country.00026", - "country.00028", "country.00029", "country.00065", "country.00066", "country.00067", "country.00068", - "country.00069", "country.00070", "country.00071", "country.00072", "country.00073", "country.00074", - "country.00075", "country.00076", "country.00077", "country.00078", "country.00079", "country.00080", - "country.00081", "country.00082", "country.00083", "country.00084", "country.00085", "country.00086", - "country.00087", "country.00088", "country.00089", "country.00090", "country.00091", "country.00092", - "country.00093", "country.00094", "country.00095", "country.00096", "country.00097", "country.00098", - "country.00099", "disco.00005", "disco.00015", "disco.00016", "disco.00017", "disco.00018", - "disco.00019", "disco.00020", "disco.00022", "disco.00023", "disco.00024", "disco.00025", - "disco.00026", "disco.00027", "disco.00028", "disco.00029", "disco.00030", "disco.00031", - "disco.00032", "disco.00033", "disco.00034", "disco.00035", "disco.00036", "disco.00037", - "disco.00039", "disco.00040", "disco.00041", "disco.00042", "disco.00043", "disco.00044", - "disco.00045", "disco.00047", "disco.00049", "disco.00053", "disco.00054", "disco.00056", - "disco.00057", "disco.00059", "disco.00061", "disco.00070", "disco.00073", "disco.00074", - "disco.00089", "hiphop.00002", "hiphop.00003", "hiphop.00004", "hiphop.00005", "hiphop.00006", - "hiphop.00007", "hiphop.00008", "hiphop.00009", "hiphop.00010", "hiphop.00011", "hiphop.00012", - "hiphop.00013", "hiphop.00014", "hiphop.00015", "hiphop.00016", "hiphop.00017", "hiphop.00018", - "hiphop.00019", "hiphop.00020", "hiphop.00021", "hiphop.00022", "hiphop.00023", "hiphop.00024", - "hiphop.00025", "hiphop.00028", "hiphop.00029", "hiphop.00031", "hiphop.00032", "hiphop.00033", - "hiphop.00034", "hiphop.00035", "hiphop.00036", "hiphop.00037", "hiphop.00038", "hiphop.00041", - "hiphop.00042", "hiphop.00055", "hiphop.00056", "hiphop.00057", "hiphop.00058", "hiphop.00059", - "hiphop.00060", "hiphop.00061", "hiphop.00077", "hiphop.00078", "hiphop.00079", "hiphop.00080", - "jazz.00000", "jazz.00001", "jazz.00011", "jazz.00012", "jazz.00013", "jazz.00014", - "jazz.00015", "jazz.00016", "jazz.00017", "jazz.00018", "jazz.00019", "jazz.00020", - "jazz.00021", "jazz.00022", "jazz.00023", "jazz.00024", "jazz.00041", "jazz.00047", - "jazz.00048", "jazz.00049", "jazz.00050", "jazz.00051", "jazz.00052", "jazz.00053", - "jazz.00054", "jazz.00055", "jazz.00056", "jazz.00057", "jazz.00058", "jazz.00059", - "jazz.00060", "jazz.00061", "jazz.00062", "jazz.00063", "jazz.00064", "jazz.00065", - "jazz.00066", "jazz.00067", "jazz.00068", "jazz.00069", "jazz.00070", "jazz.00071", - "jazz.00072", "metal.00002", "metal.00003", "metal.00005", "metal.00021", "metal.00024", - "metal.00035", "metal.00046", "metal.00047", "metal.00048", "metal.00049", "metal.00050", - "metal.00051", "metal.00052", "metal.00053", "metal.00054", "metal.00055", "metal.00056", - "metal.00057", "metal.00059", "metal.00060", "metal.00061", "metal.00062", "metal.00063", - "metal.00064", "metal.00065", "metal.00066", "metal.00069", "metal.00071", "metal.00072", - "metal.00079", "metal.00080", "metal.00084", "metal.00086", "metal.00089", "metal.00090", - "metal.00091", "metal.00092", "metal.00093", "metal.00094", "metal.00095", "metal.00096", - "metal.00097", "metal.00098", "metal.00099", "pop.00002", "pop.00003", "pop.00004", - "pop.00005", "pop.00006", "pop.00007", "pop.00008", "pop.00009", "pop.00011", - "pop.00012", "pop.00016", "pop.00017", "pop.00018", "pop.00019", "pop.00020", - "pop.00023", "pop.00024", "pop.00025", "pop.00026", "pop.00027", "pop.00028", - "pop.00029", "pop.00031", "pop.00032", "pop.00033", "pop.00034", "pop.00035", - "pop.00036", "pop.00038", "pop.00039", "pop.00040", "pop.00041", "pop.00042", - "pop.00044", "pop.00046", "pop.00049", "pop.00050", "pop.00080", "pop.00097", - "pop.00098", "pop.00099", "reggae.00000", "reggae.00001", "reggae.00002", "reggae.00004", - "reggae.00006", "reggae.00009", "reggae.00011", "reggae.00012", "reggae.00014", "reggae.00015", - "reggae.00016", "reggae.00017", "reggae.00018", "reggae.00019", "reggae.00020", "reggae.00021", - "reggae.00022", "reggae.00023", "reggae.00024", "reggae.00025", "reggae.00026", "reggae.00027", - "reggae.00028", "reggae.00029", "reggae.00030", "reggae.00031", "reggae.00032", "reggae.00042", - "reggae.00043", "reggae.00044", "reggae.00045", "reggae.00049", "reggae.00050", "reggae.00051", - "reggae.00054", "reggae.00055", "reggae.00056", "reggae.00057", "reggae.00058", "reggae.00059", - "reggae.00060", "reggae.00063", "reggae.00069", "rock.00000", "rock.00001", "rock.00002", - "rock.00003", "rock.00004", "rock.00005", "rock.00006", "rock.00007", "rock.00008", - "rock.00009", "rock.00016", "rock.00017", "rock.00018", "rock.00019", "rock.00020", - "rock.00021", "rock.00022", "rock.00023", "rock.00024", "rock.00025", "rock.00026", - "rock.00057", "rock.00058", "rock.00059", "rock.00060", "rock.00061", "rock.00062", - "rock.00063", "rock.00064", "rock.00065", "rock.00066", "rock.00067", "rock.00068", - "rock.00069", "rock.00070", "rock.00091", "rock.00092", "rock.00093", "rock.00094", - "rock.00095", "rock.00096", "rock.00097", "rock.00098", "rock.00099", -}; - -const std::vector filtered_valid = { - "blues.00000", "blues.00001", "blues.00002", "blues.00003", "blues.00004", "blues.00005", - "blues.00006", "blues.00007", "blues.00008", "blues.00009", "blues.00010", "blues.00011", - "blues.00050", "blues.00051", "blues.00052", "blues.00053", "blues.00054", "blues.00055", - "blues.00056", "blues.00057", "blues.00058", "blues.00059", "blues.00060", "classical.00000", - "classical.00001", "classical.00002", "classical.00003", "classical.00004", "classical.00005", "classical.00006", - "classical.00007", "classical.00008", "classical.00009", "classical.00010", "classical.00068", "classical.00069", - "classical.00070", "classical.00071", "classical.00072", "classical.00073", "classical.00074", "classical.00075", - "classical.00076", "country.00000", "country.00001", "country.00002", "country.00003", "country.00004", - "country.00005", "country.00006", "country.00007", "country.00009", "country.00010", "country.00011", - "country.00012", "country.00013", "country.00014", "country.00015", "country.00016", "country.00017", - "country.00018", "country.00027", "country.00041", "country.00042", "country.00045", "country.00049", - "disco.00000", "disco.00002", "disco.00003", "disco.00004", "disco.00006", "disco.00007", - "disco.00008", "disco.00009", "disco.00010", "disco.00011", "disco.00012", "disco.00013", - "disco.00014", "disco.00046", "disco.00048", "disco.00052", "disco.00067", "disco.00068", - "disco.00072", "disco.00075", "disco.00090", "disco.00095", "hiphop.00081", "hiphop.00082", - "hiphop.00083", "hiphop.00084", "hiphop.00085", "hiphop.00086", "hiphop.00087", "hiphop.00088", - "hiphop.00089", "hiphop.00090", "hiphop.00091", "hiphop.00092", "hiphop.00093", "hiphop.00094", - "hiphop.00095", "hiphop.00096", "hiphop.00097", "hiphop.00098", "jazz.00002", "jazz.00003", - "jazz.00004", "jazz.00005", "jazz.00006", "jazz.00007", "jazz.00008", "jazz.00009", - "jazz.00010", "jazz.00025", "jazz.00026", "jazz.00027", "jazz.00028", "jazz.00029", - "jazz.00030", "jazz.00031", "jazz.00032", "metal.00000", "metal.00001", "metal.00006", - "metal.00007", "metal.00008", "metal.00009", "metal.00010", "metal.00011", "metal.00016", - "metal.00017", "metal.00018", "metal.00019", "metal.00020", "metal.00036", "metal.00037", - "metal.00068", "metal.00076", "metal.00077", "metal.00081", "metal.00082", "pop.00010", - "pop.00053", "pop.00055", "pop.00058", "pop.00059", "pop.00060", "pop.00061", - "pop.00062", "pop.00081", "pop.00083", "pop.00084", "pop.00085", "pop.00086", - "reggae.00061", "reggae.00062", "reggae.00070", "reggae.00072", "reggae.00074", "reggae.00076", - "reggae.00077", "reggae.00078", "reggae.00085", "reggae.00092", "reggae.00093", "reggae.00094", - "reggae.00095", "reggae.00096", "reggae.00097", "reggae.00098", "reggae.00099", "rock.00038", - "rock.00049", "rock.00050", "rock.00051", "rock.00052", "rock.00053", "rock.00054", - "rock.00055", "rock.00056", "rock.00071", "rock.00072", "rock.00073", "rock.00074", - "rock.00075", "rock.00076", "rock.00077", "rock.00078", "rock.00079", "rock.00080", - "rock.00081", "rock.00082", "rock.00083", "rock.00084", "rock.00085", -}; - -GTZANOp::GTZANOp(const std::string &usage, int32_t num_workers, const std::string &folder_path, int32_t queue_size, - std::unique_ptr data_schema, std::shared_ptr sampler) - : MappableLeafOp(num_workers, queue_size, std::move(sampler)), - usage_(usage), - folder_path_(folder_path), - data_schema_(std::move(data_schema)) {} - -Status GTZANOp::LoadTensorRow(row_id_type row_id, TensorRow *trow) { - RETURN_UNEXPECTED_IF_NULL(trow); - const uint32_t sample_rate = 22050; - std::shared_ptr waveform, rate, label; - RETURN_IF_NOT_OK(ReadAudio(audio_names_[row_id].first, &waveform)); - RETURN_IF_NOT_OK(Tensor::CreateScalar(sample_rate, &rate)); - RETURN_IF_NOT_OK(Tensor::CreateScalar(audio_names_[row_id].second, &label)); - (*trow) = TensorRow(row_id, {std::move(waveform), std::move(rate), std::move(label)}); - trow->setPath({audio_names_[row_id].first, audio_names_[row_id].first, audio_names_[row_id].first}); - return Status::OK(); -} - -void GTZANOp::Print(std::ostream &out, bool show_all) const { - if (!show_all) { - ParallelOp::Print(out, show_all); - out << "\n"; - return; - } - ParallelOp::Print(out, show_all); - out << "\nNumber of rows: " << num_rows_ << "\nGTZAN directory: " << folder_path_ << "\n\n"; -} - -Status GTZANOp::CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count) { - RETURN_UNEXPECTED_IF_NULL(count); - *count = 0; - const int64_t num_samples = 0; - const int64_t start_index = 0; - auto sampler = std::make_shared(start_index, num_samples); - auto schema = std::make_unique(); - - RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("waveform", DataType(DataType::DE_FLOAT32), TensorImpl::kCv, 1))); - TensorShape scalar_rate = TensorShape::CreateScalar(); - RETURN_IF_NOT_OK(schema->AddColumn( - ColDescriptor("sample_rate", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar_rate))); - TensorShape scalar_label = TensorShape::CreateScalar(); - RETURN_IF_NOT_OK( - schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar_label))); - - std::shared_ptr cfg = GlobalContext::config_manager(); - int32_t num_workers = cfg->num_parallel_workers(); - int32_t op_connect_size = cfg->op_connector_size(); - auto op = std::make_shared(usage, num_workers, dir, op_connect_size, std::move(schema), std::move(sampler)); - RETURN_IF_NOT_OK(op->PrepareData()); - *count = op->audio_names_.size(); - return Status::OK(); -} - -Status GTZANOp::ComputeColMap() { - if (column_name_id_map_.empty()) { - for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { - column_name_id_map_[data_schema_->Column(i).Name()] = i; - } - } else { - MS_LOG(WARNING) << "Column name map is already set!"; - } - return Status::OK(); -} - -Status GTZANOp::ReadAudio(const std::string &audio_dir, std::shared_ptr *waveform) { - RETURN_UNEXPECTED_IF_NULL(waveform); - const int32_t kWavFileSampleRate = 22050; - int32_t sample_rate = 0; - std::vector waveform_vec; - RETURN_IF_NOT_OK(ReadWaveFile(audio_dir, &waveform_vec, &sample_rate)); - CHECK_FAIL_RETURN_UNEXPECTED(sample_rate == kWavFileSampleRate, - "Invalid file, sampling rate of GTZAN wav file must be 22050, file path: " + audio_dir); - RETURN_IF_NOT_OK(Tensor::CreateFromVector(waveform_vec, waveform)); - RETURN_IF_NOT_OK((*waveform)->ExpandDim(0)); - return Status::OK(); -} - -Status GTZANOp::PrepareData() { - auto realpath = FileUtils::GetRealPath(folder_path_.c_str()); - if (!realpath.has_value()) { - MS_LOG(ERROR) << "Invalid file path, GTZAN Dataset dir: " << folder_path_ << " does not exist."; - RETURN_STATUS_UNEXPECTED("Invalid file path, GTZAN Dataset dir: " + folder_path_ + " does not exist."); - } - Path dir(folder_path_); - - if (usage_ == "all") { - for (std::string sub_directory : genres) { - Path full_dir = dir / sub_directory; - if (!full_dir.Exists() || !full_dir.IsDirectory()) { - continue; - } - auto dir_it = Path::DirIterator::OpenDirectory(&full_dir); - if (dir_it != nullptr) { - while (dir_it->HasNext()) { - Path file = dir_it->Next(); - std::string file_name = file.ToString(); - auto pos = file_name.find_last_of('.'); - std::string name = file_name.substr(0, pos), temp_ext = file_name.substr(pos); - if (temp_ext == ".wav" && name.find('.') != std::string::npos) { - audio_names_.push_back({file.ToString(), sub_directory}); - } else { - MS_LOG(WARNING) << "Invalid file, invalid file name or file type: " << file.ToString() << "."; - } - } - } else { - MS_LOG(WARNING) << "Invalid file path, unable to open directory: " << full_dir.ToString() << "."; - } - } - } else { - const std::vector *files_point = nullptr; - if (usage_ == "test") { - files_point = &filtered_test; - } else if (usage_ == "train") { - files_point = &filtered_train; - } else { - files_point = &filtered_valid; - } - std::string ext = ".wav"; - for (auto sub_file_name : *files_point) { - auto pos = sub_file_name.find_first_of('.'); - std::string cls = sub_file_name.substr(0, pos); - Path full_dir = dir / cls / (sub_file_name + ext); - if (full_dir.Exists()) { - audio_names_.push_back({full_dir.ToString(), cls}); - } else { - MS_LOG(WARNING) << "The audio file is lost, file name= " << (sub_file_name + ext); - } - } - } - num_rows_ = audio_names_.size(); - CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "Invalid data, no valid data found in path:" + folder_path_); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/engine/datasetops/source/gtzan_op.h" + +#include +#include + +#include "minddata/dataset/audio/kernels/audio_utils.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "utils/file_utils.h" + +namespace mindspore { +namespace dataset { +const std::vector genres = { + "blues", "classical", "country", "disco", "hiphop", "jazz", "metal", "pop", "reggae", "rock", +}; + +const std::vector filtered_test = { + "blues.00012", "blues.00013", "blues.00014", "blues.00015", "blues.00016", "blues.00017", + "blues.00018", "blues.00019", "blues.00020", "blues.00021", "blues.00022", "blues.00023", + "blues.00024", "blues.00025", "blues.00026", "blues.00027", "blues.00028", "blues.00061", + "blues.00062", "blues.00063", "blues.00064", "blues.00065", "blues.00066", "blues.00067", + "blues.00068", "blues.00069", "blues.00070", "blues.00071", "blues.00072", "blues.00098", + "blues.00099", "classical.00011", "classical.00012", "classical.00013", "classical.00014", "classical.00015", + "classical.00016", "classical.00017", "classical.00018", "classical.00019", "classical.00020", "classical.00021", + "classical.00022", "classical.00023", "classical.00024", "classical.00025", "classical.00026", "classical.00027", + "classical.00028", "classical.00029", "classical.00034", "classical.00035", "classical.00036", "classical.00037", + "classical.00038", "classical.00039", "classical.00040", "classical.00041", "classical.00049", "classical.00077", + "classical.00078", "classical.00079", "country.00030", "country.00031", "country.00032", "country.00033", + "country.00034", "country.00035", "country.00036", "country.00037", "country.00038", "country.00039", + "country.00040", "country.00043", "country.00044", "country.00046", "country.00047", "country.00048", + "country.00050", "country.00051", "country.00053", "country.00054", "country.00055", "country.00056", + "country.00057", "country.00058", "country.00059", "country.00060", "country.00061", "country.00062", + "country.00063", "country.00064", "disco.00001", "disco.00021", "disco.00058", "disco.00062", + "disco.00063", "disco.00064", "disco.00065", "disco.00066", "disco.00069", "disco.00076", + "disco.00077", "disco.00078", "disco.00079", "disco.00080", "disco.00081", "disco.00082", + "disco.00083", "disco.00084", "disco.00085", "disco.00086", "disco.00087", "disco.00088", + "disco.00091", "disco.00092", "disco.00093", "disco.00094", "disco.00096", "disco.00097", + "disco.00099", "hiphop.00000", "hiphop.00026", "hiphop.00027", "hiphop.00030", "hiphop.00040", + "hiphop.00043", "hiphop.00044", "hiphop.00045", "hiphop.00051", "hiphop.00052", "hiphop.00053", + "hiphop.00054", "hiphop.00062", "hiphop.00063", "hiphop.00064", "hiphop.00065", "hiphop.00066", + "hiphop.00067", "hiphop.00068", "hiphop.00069", "hiphop.00070", "hiphop.00071", "hiphop.00072", + "hiphop.00073", "hiphop.00074", "hiphop.00075", "hiphop.00099", "jazz.00073", "jazz.00074", + "jazz.00075", "jazz.00076", "jazz.00077", "jazz.00078", "jazz.00079", "jazz.00080", + "jazz.00081", "jazz.00082", "jazz.00083", "jazz.00084", "jazz.00085", "jazz.00086", + "jazz.00087", "jazz.00088", "jazz.00089", "jazz.00090", "jazz.00091", "jazz.00092", + "jazz.00093", "jazz.00094", "jazz.00095", "jazz.00096", "jazz.00097", "jazz.00098", + "jazz.00099", "metal.00012", "metal.00013", "metal.00014", "metal.00015", "metal.00022", + "metal.00023", "metal.00025", "metal.00026", "metal.00027", "metal.00028", "metal.00029", + "metal.00030", "metal.00031", "metal.00032", "metal.00033", "metal.00038", "metal.00039", + "metal.00067", "metal.00070", "metal.00073", "metal.00074", "metal.00075", "metal.00078", + "metal.00083", "metal.00085", "metal.00087", "metal.00088", "pop.00000", "pop.00001", + "pop.00013", "pop.00014", "pop.00043", "pop.00063", "pop.00064", "pop.00065", + "pop.00066", "pop.00069", "pop.00070", "pop.00071", "pop.00072", "pop.00073", + "pop.00074", "pop.00075", "pop.00076", "pop.00077", "pop.00078", "pop.00079", + "pop.00082", "pop.00088", "pop.00089", "pop.00090", "pop.00091", "pop.00092", + "pop.00093", "pop.00094", "pop.00095", "pop.00096", "reggae.00034", "reggae.00035", + "reggae.00036", "reggae.00037", "reggae.00038", "reggae.00039", "reggae.00040", "reggae.00046", + "reggae.00047", "reggae.00048", "reggae.00052", "reggae.00053", "reggae.00064", "reggae.00065", + "reggae.00066", "reggae.00067", "reggae.00068", "reggae.00071", "reggae.00079", "reggae.00082", + "reggae.00083", "reggae.00084", "reggae.00087", "reggae.00088", "reggae.00089", "reggae.00090", + "rock.00010", "rock.00011", "rock.00012", "rock.00013", "rock.00014", "rock.00015", + "rock.00027", "rock.00028", "rock.00029", "rock.00030", "rock.00031", "rock.00032", + "rock.00033", "rock.00034", "rock.00035", "rock.00036", "rock.00037", "rock.00039", + "rock.00040", "rock.00041", "rock.00042", "rock.00043", "rock.00044", "rock.00045", + "rock.00046", "rock.00047", "rock.00048", "rock.00086", "rock.00087", "rock.00088", + "rock.00089", "rock.00090", +}; + +const std::vector filtered_train = { + "blues.00029", "blues.00030", "blues.00031", "blues.00032", "blues.00033", "blues.00034", + "blues.00035", "blues.00036", "blues.00037", "blues.00038", "blues.00039", "blues.00040", + "blues.00041", "blues.00042", "blues.00043", "blues.00044", "blues.00045", "blues.00046", + "blues.00047", "blues.00048", "blues.00049", "blues.00073", "blues.00074", "blues.00075", + "blues.00076", "blues.00077", "blues.00078", "blues.00079", "blues.00080", "blues.00081", + "blues.00082", "blues.00083", "blues.00084", "blues.00085", "blues.00086", "blues.00087", + "blues.00088", "blues.00089", "blues.00090", "blues.00091", "blues.00092", "blues.00093", + "blues.00094", "blues.00095", "blues.00096", "blues.00097", "classical.00030", "classical.00031", + "classical.00032", "classical.00033", "classical.00043", "classical.00044", "classical.00045", "classical.00046", + "classical.00047", "classical.00048", "classical.00050", "classical.00051", "classical.00052", "classical.00053", + "classical.00054", "classical.00055", "classical.00056", "classical.00057", "classical.00058", "classical.00059", + "classical.00060", "classical.00061", "classical.00062", "classical.00063", "classical.00064", "classical.00065", + "classical.00066", "classical.00067", "classical.00080", "classical.00081", "classical.00082", "classical.00083", + "classical.00084", "classical.00085", "classical.00086", "classical.00087", "classical.00088", "classical.00089", + "classical.00090", "classical.00091", "classical.00092", "classical.00093", "classical.00094", "classical.00095", + "classical.00096", "classical.00097", "classical.00098", "classical.00099", "country.00019", "country.00020", + "country.00021", "country.00022", "country.00023", "country.00024", "country.00025", "country.00026", + "country.00028", "country.00029", "country.00065", "country.00066", "country.00067", "country.00068", + "country.00069", "country.00070", "country.00071", "country.00072", "country.00073", "country.00074", + "country.00075", "country.00076", "country.00077", "country.00078", "country.00079", "country.00080", + "country.00081", "country.00082", "country.00083", "country.00084", "country.00085", "country.00086", + "country.00087", "country.00088", "country.00089", "country.00090", "country.00091", "country.00092", + "country.00093", "country.00094", "country.00095", "country.00096", "country.00097", "country.00098", + "country.00099", "disco.00005", "disco.00015", "disco.00016", "disco.00017", "disco.00018", + "disco.00019", "disco.00020", "disco.00022", "disco.00023", "disco.00024", "disco.00025", + "disco.00026", "disco.00027", "disco.00028", "disco.00029", "disco.00030", "disco.00031", + "disco.00032", "disco.00033", "disco.00034", "disco.00035", "disco.00036", "disco.00037", + "disco.00039", "disco.00040", "disco.00041", "disco.00042", "disco.00043", "disco.00044", + "disco.00045", "disco.00047", "disco.00049", "disco.00053", "disco.00054", "disco.00056", + "disco.00057", "disco.00059", "disco.00061", "disco.00070", "disco.00073", "disco.00074", + "disco.00089", "hiphop.00002", "hiphop.00003", "hiphop.00004", "hiphop.00005", "hiphop.00006", + "hiphop.00007", "hiphop.00008", "hiphop.00009", "hiphop.00010", "hiphop.00011", "hiphop.00012", + "hiphop.00013", "hiphop.00014", "hiphop.00015", "hiphop.00016", "hiphop.00017", "hiphop.00018", + "hiphop.00019", "hiphop.00020", "hiphop.00021", "hiphop.00022", "hiphop.00023", "hiphop.00024", + "hiphop.00025", "hiphop.00028", "hiphop.00029", "hiphop.00031", "hiphop.00032", "hiphop.00033", + "hiphop.00034", "hiphop.00035", "hiphop.00036", "hiphop.00037", "hiphop.00038", "hiphop.00041", + "hiphop.00042", "hiphop.00055", "hiphop.00056", "hiphop.00057", "hiphop.00058", "hiphop.00059", + "hiphop.00060", "hiphop.00061", "hiphop.00077", "hiphop.00078", "hiphop.00079", "hiphop.00080", + "jazz.00000", "jazz.00001", "jazz.00011", "jazz.00012", "jazz.00013", "jazz.00014", + "jazz.00015", "jazz.00016", "jazz.00017", "jazz.00018", "jazz.00019", "jazz.00020", + "jazz.00021", "jazz.00022", "jazz.00023", "jazz.00024", "jazz.00041", "jazz.00047", + "jazz.00048", "jazz.00049", "jazz.00050", "jazz.00051", "jazz.00052", "jazz.00053", + "jazz.00054", "jazz.00055", "jazz.00056", "jazz.00057", "jazz.00058", "jazz.00059", + "jazz.00060", "jazz.00061", "jazz.00062", "jazz.00063", "jazz.00064", "jazz.00065", + "jazz.00066", "jazz.00067", "jazz.00068", "jazz.00069", "jazz.00070", "jazz.00071", + "jazz.00072", "metal.00002", "metal.00003", "metal.00005", "metal.00021", "metal.00024", + "metal.00035", "metal.00046", "metal.00047", "metal.00048", "metal.00049", "metal.00050", + "metal.00051", "metal.00052", "metal.00053", "metal.00054", "metal.00055", "metal.00056", + "metal.00057", "metal.00059", "metal.00060", "metal.00061", "metal.00062", "metal.00063", + "metal.00064", "metal.00065", "metal.00066", "metal.00069", "metal.00071", "metal.00072", + "metal.00079", "metal.00080", "metal.00084", "metal.00086", "metal.00089", "metal.00090", + "metal.00091", "metal.00092", "metal.00093", "metal.00094", "metal.00095", "metal.00096", + "metal.00097", "metal.00098", "metal.00099", "pop.00002", "pop.00003", "pop.00004", + "pop.00005", "pop.00006", "pop.00007", "pop.00008", "pop.00009", "pop.00011", + "pop.00012", "pop.00016", "pop.00017", "pop.00018", "pop.00019", "pop.00020", + "pop.00023", "pop.00024", "pop.00025", "pop.00026", "pop.00027", "pop.00028", + "pop.00029", "pop.00031", "pop.00032", "pop.00033", "pop.00034", "pop.00035", + "pop.00036", "pop.00038", "pop.00039", "pop.00040", "pop.00041", "pop.00042", + "pop.00044", "pop.00046", "pop.00049", "pop.00050", "pop.00080", "pop.00097", + "pop.00098", "pop.00099", "reggae.00000", "reggae.00001", "reggae.00002", "reggae.00004", + "reggae.00006", "reggae.00009", "reggae.00011", "reggae.00012", "reggae.00014", "reggae.00015", + "reggae.00016", "reggae.00017", "reggae.00018", "reggae.00019", "reggae.00020", "reggae.00021", + "reggae.00022", "reggae.00023", "reggae.00024", "reggae.00025", "reggae.00026", "reggae.00027", + "reggae.00028", "reggae.00029", "reggae.00030", "reggae.00031", "reggae.00032", "reggae.00042", + "reggae.00043", "reggae.00044", "reggae.00045", "reggae.00049", "reggae.00050", "reggae.00051", + "reggae.00054", "reggae.00055", "reggae.00056", "reggae.00057", "reggae.00058", "reggae.00059", + "reggae.00060", "reggae.00063", "reggae.00069", "rock.00000", "rock.00001", "rock.00002", + "rock.00003", "rock.00004", "rock.00005", "rock.00006", "rock.00007", "rock.00008", + "rock.00009", "rock.00016", "rock.00017", "rock.00018", "rock.00019", "rock.00020", + "rock.00021", "rock.00022", "rock.00023", "rock.00024", "rock.00025", "rock.00026", + "rock.00057", "rock.00058", "rock.00059", "rock.00060", "rock.00061", "rock.00062", + "rock.00063", "rock.00064", "rock.00065", "rock.00066", "rock.00067", "rock.00068", + "rock.00069", "rock.00070", "rock.00091", "rock.00092", "rock.00093", "rock.00094", + "rock.00095", "rock.00096", "rock.00097", "rock.00098", "rock.00099", +}; + +const std::vector filtered_valid = { + "blues.00000", "blues.00001", "blues.00002", "blues.00003", "blues.00004", "blues.00005", + "blues.00006", "blues.00007", "blues.00008", "blues.00009", "blues.00010", "blues.00011", + "blues.00050", "blues.00051", "blues.00052", "blues.00053", "blues.00054", "blues.00055", + "blues.00056", "blues.00057", "blues.00058", "blues.00059", "blues.00060", "classical.00000", + "classical.00001", "classical.00002", "classical.00003", "classical.00004", "classical.00005", "classical.00006", + "classical.00007", "classical.00008", "classical.00009", "classical.00010", "classical.00068", "classical.00069", + "classical.00070", "classical.00071", "classical.00072", "classical.00073", "classical.00074", "classical.00075", + "classical.00076", "country.00000", "country.00001", "country.00002", "country.00003", "country.00004", + "country.00005", "country.00006", "country.00007", "country.00009", "country.00010", "country.00011", + "country.00012", "country.00013", "country.00014", "country.00015", "country.00016", "country.00017", + "country.00018", "country.00027", "country.00041", "country.00042", "country.00045", "country.00049", + "disco.00000", "disco.00002", "disco.00003", "disco.00004", "disco.00006", "disco.00007", + "disco.00008", "disco.00009", "disco.00010", "disco.00011", "disco.00012", "disco.00013", + "disco.00014", "disco.00046", "disco.00048", "disco.00052", "disco.00067", "disco.00068", + "disco.00072", "disco.00075", "disco.00090", "disco.00095", "hiphop.00081", "hiphop.00082", + "hiphop.00083", "hiphop.00084", "hiphop.00085", "hiphop.00086", "hiphop.00087", "hiphop.00088", + "hiphop.00089", "hiphop.00090", "hiphop.00091", "hiphop.00092", "hiphop.00093", "hiphop.00094", + "hiphop.00095", "hiphop.00096", "hiphop.00097", "hiphop.00098", "jazz.00002", "jazz.00003", + "jazz.00004", "jazz.00005", "jazz.00006", "jazz.00007", "jazz.00008", "jazz.00009", + "jazz.00010", "jazz.00025", "jazz.00026", "jazz.00027", "jazz.00028", "jazz.00029", + "jazz.00030", "jazz.00031", "jazz.00032", "metal.00000", "metal.00001", "metal.00006", + "metal.00007", "metal.00008", "metal.00009", "metal.00010", "metal.00011", "metal.00016", + "metal.00017", "metal.00018", "metal.00019", "metal.00020", "metal.00036", "metal.00037", + "metal.00068", "metal.00076", "metal.00077", "metal.00081", "metal.00082", "pop.00010", + "pop.00053", "pop.00055", "pop.00058", "pop.00059", "pop.00060", "pop.00061", + "pop.00062", "pop.00081", "pop.00083", "pop.00084", "pop.00085", "pop.00086", + "reggae.00061", "reggae.00062", "reggae.00070", "reggae.00072", "reggae.00074", "reggae.00076", + "reggae.00077", "reggae.00078", "reggae.00085", "reggae.00092", "reggae.00093", "reggae.00094", + "reggae.00095", "reggae.00096", "reggae.00097", "reggae.00098", "reggae.00099", "rock.00038", + "rock.00049", "rock.00050", "rock.00051", "rock.00052", "rock.00053", "rock.00054", + "rock.00055", "rock.00056", "rock.00071", "rock.00072", "rock.00073", "rock.00074", + "rock.00075", "rock.00076", "rock.00077", "rock.00078", "rock.00079", "rock.00080", + "rock.00081", "rock.00082", "rock.00083", "rock.00084", "rock.00085", +}; + +GTZANOp::GTZANOp(const std::string &usage, int32_t num_workers, const std::string &folder_path, int32_t queue_size, + std::unique_ptr data_schema, std::shared_ptr sampler) + : MappableLeafOp(num_workers, queue_size, std::move(sampler)), + usage_(usage), + folder_path_(folder_path), + data_schema_(std::move(data_schema)) {} + +Status GTZANOp::LoadTensorRow(row_id_type row_id, TensorRow *trow) { + RETURN_UNEXPECTED_IF_NULL(trow); + const uint32_t sample_rate = 22050; + std::shared_ptr waveform, rate, label; + RETURN_IF_NOT_OK(ReadAudio(audio_names_[row_id].first, &waveform)); + RETURN_IF_NOT_OK(Tensor::CreateScalar(sample_rate, &rate)); + RETURN_IF_NOT_OK(Tensor::CreateScalar(audio_names_[row_id].second, &label)); + (*trow) = TensorRow(row_id, {std::move(waveform), std::move(rate), std::move(label)}); + trow->setPath({audio_names_[row_id].first, audio_names_[row_id].first, audio_names_[row_id].first}); + return Status::OK(); +} + +void GTZANOp::Print(std::ostream &out, bool show_all) const { + if (!show_all) { + ParallelOp::Print(out, show_all); + out << "\n"; + return; + } + ParallelOp::Print(out, show_all); + out << "\nNumber of rows: " << num_rows_ << "\nGTZAN directory: " << folder_path_ << "\n\n"; +} + +Status GTZANOp::CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count) { + RETURN_UNEXPECTED_IF_NULL(count); + *count = 0; + const int64_t num_samples = 0; + const int64_t start_index = 0; + auto sampler = std::make_shared(start_index, num_samples); + auto schema = std::make_unique(); + + RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("waveform", DataType(DataType::DE_FLOAT32), TensorImpl::kCv, 1))); + TensorShape scalar_rate = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK(schema->AddColumn( + ColDescriptor("sample_rate", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar_rate))); + TensorShape scalar_label = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK( + schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar_label))); + + std::shared_ptr cfg = GlobalContext::config_manager(); + int32_t num_workers = cfg->num_parallel_workers(); + int32_t op_connect_size = cfg->op_connector_size(); + auto op = std::make_shared(usage, num_workers, dir, op_connect_size, std::move(schema), std::move(sampler)); + RETURN_IF_NOT_OK(op->PrepareData()); + *count = op->audio_names_.size(); + return Status::OK(); +} + +Status GTZANOp::ComputeColMap() { + if (column_name_id_map_.empty()) { + for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { + column_name_id_map_[data_schema_->Column(i).Name()] = i; + } + } else { + MS_LOG(WARNING) << "Column name map is already set!"; + } + return Status::OK(); +} + +Status GTZANOp::ReadAudio(const std::string &audio_dir, std::shared_ptr *waveform) { + RETURN_UNEXPECTED_IF_NULL(waveform); + const int32_t kWavFileSampleRate = 22050; + int32_t sample_rate = 0; + std::vector waveform_vec; + RETURN_IF_NOT_OK(ReadWaveFile(audio_dir, &waveform_vec, &sample_rate)); + CHECK_FAIL_RETURN_UNEXPECTED(sample_rate == kWavFileSampleRate, + "Invalid file, sampling rate of GTZAN wav file must be 22050, file path: " + audio_dir); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(waveform_vec, waveform)); + RETURN_IF_NOT_OK((*waveform)->ExpandDim(0)); + return Status::OK(); +} + +Status GTZANOp::PrepareData() { + auto realpath = FileUtils::GetRealPath(folder_path_.c_str()); + if (!realpath.has_value()) { + MS_LOG(ERROR) << "Invalid file path, GTZAN Dataset dir: " << folder_path_ << " does not exist."; + RETURN_STATUS_UNEXPECTED("Invalid file path, GTZAN Dataset dir: " + folder_path_ + " does not exist."); + } + Path dir(folder_path_); + + if (usage_ == "all") { + for (std::string sub_directory : genres) { + Path full_dir = dir / sub_directory; + if (!full_dir.Exists() || !full_dir.IsDirectory()) { + continue; + } + auto dir_it = Path::DirIterator::OpenDirectory(&full_dir); + if (dir_it != nullptr) { + while (dir_it->HasNext()) { + Path file = dir_it->Next(); + std::string file_name = file.ToString(); + auto pos = file_name.find_last_of('.'); + std::string name = file_name.substr(0, pos), temp_ext = file_name.substr(pos); + if (temp_ext == ".wav" && name.find('.') != std::string::npos) { + audio_names_.push_back({file.ToString(), sub_directory}); + } else { + MS_LOG(WARNING) << "Invalid file, invalid file name or file type: " << file.ToString() << "."; + } + } + } else { + MS_LOG(WARNING) << "Invalid file path, unable to open directory: " << full_dir.ToString() << "."; + } + } + } else { + const std::vector *files_point = nullptr; + if (usage_ == "test") { + files_point = &filtered_test; + } else if (usage_ == "train") { + files_point = &filtered_train; + } else { + files_point = &filtered_valid; + } + std::string ext = ".wav"; + for (auto sub_file_name : *files_point) { + auto pos = sub_file_name.find_first_of('.'); + std::string cls = sub_file_name.substr(0, pos); + Path full_dir = dir / cls / (sub_file_name + ext); + if (full_dir.Exists()) { + audio_names_.push_back({full_dir.ToString(), cls}); + } else { + MS_LOG(WARNING) << "The audio file is lost, file name= " << (sub_file_name + ext); + } + } + } + num_rows_ = audio_names_.size(); + CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "Invalid data, no valid data found in path:" + folder_path_); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/gtzan_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/gtzan_op.h index 16df898a05a..76546cb293f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/gtzan_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/gtzan_op.h @@ -1,97 +1,97 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_GTZAN_OP_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_GTZAN_OP_H_ - -#include -#include -#include -#include -#include -#include - -#include "minddata/dataset/core/tensor.h" -#include "minddata/dataset/engine/data_schema.h" -#include "minddata/dataset/engine/datasetops/parallel_op.h" -#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h" -#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" -#include "minddata/dataset/util/path.h" -#include "minddata/dataset/util/queue.h" -#include "minddata/dataset/util/status.h" -#include "minddata/dataset/util/wait_post.h" - -namespace mindspore { -namespace dataset { -class GTZANOp : public MappableLeafOp { - public: - /// \brief Constructor - /// \param[in] usage Usage of this dataset, can be 'train', 'valid', 'test', or 'all'. - /// \param[in] num_workers Number of workers reading audios in parallel. - /// \param[in] folder_path Dir directory of GTZAN. - /// \param[in] queue_size Connector queue size. - /// \param[in] data_schema The schema of the GTZAN dataset. - /// \param[in] sampler Sampler tells GTZANOp what to read. - GTZANOp(const std::string &usage, int32_t num_workers, const std::string &folder_path, int32_t queue_size, - std::unique_ptr data_schema, std::shared_ptr sampler); - - /// \Destructor. - ~GTZANOp() = default; - - /// \A print method typically used for debugging. - /// \param[out] out Output stream. - /// \param[in] show_all Whether to show all information. - void Print(std::ostream &out, bool show_all) const override; - - /// \Function to count the number of samples in the GTZAN dataset. - /// \param[in] dir Path to the GTZAN directory. - /// \param[in] usage Choose the subset of GTZAN dataset. - /// \param[out] count Output arg that will hold the actual dataset size. - /// \return Status The status code returned. - static Status CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count); - - /// \Op name getter. - /// \return Name of the current Op. - std::string Name() const override { return "GTZANOp"; } - - private: - /// \Load a tensor row according to a pair. - /// \param[in] row_id Id for this tensor row. - /// \param[out] row Audio & label read into this tensor row. - /// \return Status The status code returned. - Status LoadTensorRow(row_id_type row_id, TensorRow *row) override; - - /// \Parse a audio file. - /// \param[in] audio_dir Audio file path. - /// \param[out] waveform The output waveform tensor. - /// \return Status The status code returned. - Status ReadAudio(const std::string &audio_dir, std::shared_ptr *waveform); - - /// \Prepare data. - /// \return Status The status code returned. - Status PrepareData(); - - /// \Private function for computing the assignment of the column name map. - /// \return Status The status code returned. - Status ComputeColMap() override; - - const std::string usage_; - std::string folder_path_; - std::unique_ptr data_schema_; - std::vector> audio_names_; -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_GTZAN_OP_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_GTZAN_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_GTZAN_OP_H_ + +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/data_schema.h" +#include "minddata/dataset/engine/datasetops/parallel_op.h" +#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/dataset/util/path.h" +#include "minddata/dataset/util/queue.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/util/wait_post.h" + +namespace mindspore { +namespace dataset { +class GTZANOp : public MappableLeafOp { + public: + /// \brief Constructor + /// \param[in] usage Usage of this dataset, can be 'train', 'valid', 'test', or 'all'. + /// \param[in] num_workers Number of workers reading audios in parallel. + /// \param[in] folder_path Dir directory of GTZAN. + /// \param[in] queue_size Connector queue size. + /// \param[in] data_schema The schema of the GTZAN dataset. + /// \param[in] sampler Sampler tells GTZANOp what to read. + GTZANOp(const std::string &usage, int32_t num_workers, const std::string &folder_path, int32_t queue_size, + std::unique_ptr data_schema, std::shared_ptr sampler); + + /// \Destructor. + ~GTZANOp() = default; + + /// \A print method typically used for debugging. + /// \param[out] out Output stream. + /// \param[in] show_all Whether to show all information. + void Print(std::ostream &out, bool show_all) const override; + + /// \Function to count the number of samples in the GTZAN dataset. + /// \param[in] dir Path to the GTZAN directory. + /// \param[in] usage Choose the subset of GTZAN dataset. + /// \param[out] count Output arg that will hold the actual dataset size. + /// \return Status The status code returned. + static Status CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count); + + /// \Op name getter. + /// \return Name of the current Op. + std::string Name() const override { return "GTZANOp"; } + + private: + /// \Load a tensor row according to a pair. + /// \param[in] row_id Id for this tensor row. + /// \param[out] row Audio & label read into this tensor row. + /// \return Status The status code returned. + Status LoadTensorRow(row_id_type row_id, TensorRow *row) override; + + /// \Parse a audio file. + /// \param[in] audio_dir Audio file path. + /// \param[out] waveform The output waveform tensor. + /// \return Status The status code returned. + Status ReadAudio(const std::string &audio_dir, std::shared_ptr *waveform); + + /// \Prepare data. + /// \return Status The status code returned. + Status PrepareData(); + + /// \Private function for computing the assignment of the column name map. + /// \return Status The status code returned. + Status ComputeColMap() override; + + const std::string usage_; + std::string folder_path_; + std::unique_ptr data_schema_; + std::vector> audio_names_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_GTZAN_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/libri_tts_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/libri_tts_op.cc index 5ade49012a4..f489bddc304 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/libri_tts_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/libri_tts_op.cc @@ -1,234 +1,234 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/engine/datasetops/source/libri_tts_op.h" - -#include -#include -#include - -#include "minddata/dataset/audio/kernels/audio_utils.h" -#include "minddata/dataset/core/config_manager.h" -#include "minddata/dataset/core/tensor_shape.h" -#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" -#include "minddata/dataset/engine/execution_tree.h" -#include "utils/file_utils.h" - -namespace mindspore { -namespace dataset { -const int32_t label_file_suffix_len = 10; -const char label_file_suffix[] = ".trans.tsv"; -const char audio_file_suffix[] = ".wav"; -const std::vector usage_list = {"dev-clean", "dev-other", "test-clean", "test-other", - "train-clean-100", "train-clean-360", "train-other-500"}; - -LibriTTSOp::LibriTTSOp(const std::string &dataset_dir, const std::string &usage, int32_t num_workers, - int32_t queue_size, std::unique_ptr data_schema, std::shared_ptr sampler) - : MappableLeafOp(num_workers, queue_size, std::move(sampler)), - dataset_dir_(dataset_dir), - usage_(usage), - data_schema_(std::move(data_schema)) {} - -Status LibriTTSOp::LoadTensorRow(row_id_type row_id, TensorRow *trow) { - RETURN_UNEXPECTED_IF_NULL(trow); - LibriTTSLabelTuple audio_tuple = audio_label_tuples_[row_id]; - const uint32_t rate = 24000; - std::shared_ptr waveform, sample_rate, original_text, normalized_text, speaker_id, chapter_id, utterance_id; - Path dir(real_path_); - std::string file_name = audio_tuple.utterance_id + audio_file_suffix; - Path full_dir = dir / audio_tuple.usage / std::to_string(audio_tuple.speaker_id) / - std::to_string(audio_tuple.chapter_id) / file_name; - RETURN_IF_NOT_OK(ReadAudio(full_dir.ToString(), &waveform)); - RETURN_IF_NOT_OK(Tensor::CreateScalar(rate, &sample_rate)); - RETURN_IF_NOT_OK(Tensor::CreateScalar(audio_tuple.original_text, &original_text)); - RETURN_IF_NOT_OK(Tensor::CreateScalar(audio_tuple.normalized_text, &normalized_text)); - RETURN_IF_NOT_OK(Tensor::CreateScalar(audio_tuple.speaker_id, &speaker_id)); - RETURN_IF_NOT_OK(Tensor::CreateScalar(audio_tuple.chapter_id, &chapter_id)); - RETURN_IF_NOT_OK(Tensor::CreateScalar(audio_tuple.utterance_id, &utterance_id)); - (*trow) = TensorRow( - row_id, {std::move(waveform), std::move(sample_rate), std::move(original_text), std::move(normalized_text), - std::move(speaker_id), std::move(chapter_id), std::move(utterance_id)}); - std::string label_path = audio_tuple.label_path; - trow->setPath({full_dir.ToString(), full_dir.ToString(), label_path, label_path, label_path, label_path, label_path}); - return Status::OK(); -} - -void LibriTTSOp::Print(std::ostream &out, bool show_all) const { - if (!show_all) { - ParallelOp::Print(out, show_all); - out << "\n"; - } else { - ParallelOp::Print(out, show_all); - out << "\nNumber of rows: " << num_rows_ << "\nLibriTTS directory: " << dataset_dir_ << "\n\n"; - } -} - -Status LibriTTSOp::CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count) { - RETURN_UNEXPECTED_IF_NULL(count); - *count = 0; - const int64_t num_samples = 0; - const int64_t start_index = 0; - auto sampler = std::make_shared(start_index, num_samples); - auto schema = std::make_unique(); - - RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("waveform", DataType(DataType::DE_FLOAT32), TensorImpl::kCv, 1))); - TensorShape scalar_rate = TensorShape::CreateScalar(); - RETURN_IF_NOT_OK(schema->AddColumn( - ColDescriptor("sample_rate", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar_rate))); - TensorShape scalar_original_text = TensorShape::CreateScalar(); - RETURN_IF_NOT_OK(schema->AddColumn( - ColDescriptor("original_text", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar_original_text))); - TensorShape scalar_normalized_text = TensorShape::CreateScalar(); - RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("normalized_text", DataType(DataType::DE_STRING), - TensorImpl::kFlexible, 0, &scalar_normalized_text))); - TensorShape scalar_speaker_id = TensorShape::CreateScalar(); - RETURN_IF_NOT_OK(schema->AddColumn( - ColDescriptor("speaker_id", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar_speaker_id))); - TensorShape scalar_chapter_id = TensorShape::CreateScalar(); - RETURN_IF_NOT_OK(schema->AddColumn( - ColDescriptor("chapter_id", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar_chapter_id))); - TensorShape scalar_utterance_id = TensorShape::CreateScalar(); - RETURN_IF_NOT_OK(schema->AddColumn( - ColDescriptor("utterance_id", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar_utterance_id))); - std::shared_ptr cfg = GlobalContext::config_manager(); - int32_t num_workers = cfg->num_parallel_workers(); - int32_t op_connect_size = cfg->op_connector_size(); - auto op = - std::make_shared(dir, usage, num_workers, op_connect_size, std::move(schema), std::move(sampler)); - RETURN_IF_NOT_OK(op->PrepareData()); - *count = op->audio_label_tuples_.size(); - return Status::OK(); -} - -Status LibriTTSOp::ComputeColMap() { - if (column_name_id_map_.empty()) { - for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { - column_name_id_map_[data_schema_->Column(i).Name()] = i; - } - } else { - MS_LOG(WARNING) << "Column name map is already set!"; - } - return Status::OK(); -} - -Status LibriTTSOp::ReadAudio(const std::string &audio_dir, std::shared_ptr *waveform) { - RETURN_UNEXPECTED_IF_NULL(waveform); - const int32_t kWavFileSampleRate = 24000; - int32_t sample_rate = 0; - std::vector waveform_vec; - RETURN_IF_NOT_OK(ReadWaveFile(audio_dir, &waveform_vec, &sample_rate)); - CHECK_FAIL_RETURN_UNEXPECTED( - sample_rate == kWavFileSampleRate, - "Invalid file, sampling rate of LibriTTS wav file must be 24000, file path: " + audio_dir); - RETURN_IF_NOT_OK(Tensor::CreateFromVector(waveform_vec, waveform)); - RETURN_IF_NOT_OK((*waveform)->ExpandDim(0)); - return Status::OK(); -} - -Status LibriTTSOp::PrepareData() { - auto realpath = FileUtils::GetRealPath(dataset_dir_.c_str()); - if (!realpath.has_value()) { - MS_LOG(ERROR) << "Invalid file path, LibriTTS dataset dir: " << dataset_dir_ << " does not exist."; - RETURN_STATUS_UNEXPECTED("Invalid file path, LibriTTS dataset dir: " + dataset_dir_ + " does not exist."); - } - real_path_ = realpath.value(); - Path dir(real_path_); - if (usage_ != "all") { - Path full_dir = dir / usage_; - cur_usage_ = usage_; - RETURN_IF_NOT_OK(GetPaths(&full_dir)); - RETURN_IF_NOT_OK(GetLabels()); - } else { - for (std::string usage_iter : usage_list) { - cur_usage_ = usage_iter; - Path full_dir = dir / cur_usage_; - RETURN_IF_NOT_OK(GetPaths(&full_dir)); - RETURN_IF_NOT_OK(GetLabels()); - } - } - num_rows_ = audio_label_tuples_.size(); - CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, - "Invalid data, no valid data matching the dataset API LibriTTSDataset. " - "Please check dataset API or file path: " + - dataset_dir_ + "."); - return Status::OK(); -} - -Status LibriTTSOp::GetPaths(Path *dir) { - RETURN_UNEXPECTED_IF_NULL(dir); - auto iter = Path::DirIterator::OpenDirectory(dir); - if (iter == nullptr) { - MS_LOG(WARNING) << "Invalid file path, unable to open directory: " << dir->ToString() << "."; - } else { - while (iter->HasNext()) { - Path sub_dir = iter->Next(); - if (sub_dir.IsDirectory()) { - RETURN_IF_NOT_OK(GetPaths(&sub_dir)); - } else { - Path file_path = sub_dir; - std::string file_name = file_path.Basename(); - int32_t length = file_name.size(); - if (length > label_file_suffix_len && file_name.substr(length - label_file_suffix_len) == label_file_suffix) { - label_files_.push_back(sub_dir.ToString()); - return Status::OK(); - } - } - } - } - return Status::OK(); -} - -Status LibriTTSOp::GetLabels() { - std::string utterance_id_body = ""; - std::string original_text_body = ""; - std::string normalized_text_body = ""; - const uint32_t base = 10; - const uint32_t ascii_zero = 48; - const size_t underline_exact = 3; - for (std::string label_file : label_files_) { - std::ifstream label_reader(label_file, std::ios::in); - while (getline(label_reader, utterance_id_body, '\t')) { - getline(label_reader, original_text_body, '\t'); - getline(label_reader, normalized_text_body, '\n'); - uint32_t speaker_id = 0; - uint32_t chapter_id = 0; - size_t underline_num = 0; - size_t underline_inx[4] = {0}; - for (size_t i = 0; i < utterance_id_body.size() && underline_num <= underline_exact; i++) { - if (utterance_id_body[i] == '_') { - underline_inx[underline_num++] = i; - } - } - if (underline_num != underline_exact) { - label_reader.close(); - RETURN_STATUS_UNEXPECTED("Invalid file, the file may not be a LibriTTS dataset file: " + label_file); - } - for (size_t i = 0; i < underline_inx[0]; i++) { - speaker_id = speaker_id * base + utterance_id_body[i] - ascii_zero; - } - for (size_t i = underline_inx[0] + 1; i < underline_inx[1]; i++) { - chapter_id = chapter_id * base + utterance_id_body[i] - ascii_zero; - } - audio_label_tuples_.push_back( - {cur_usage_, utterance_id_body, original_text_body, normalized_text_body, speaker_id, chapter_id, label_file}); - } - label_reader.close(); - } - label_files_.clear(); - return Status::OK(); -} -} // namespace dataset. -} // namespace mindspore. +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/engine/datasetops/source/libri_tts_op.h" + +#include +#include +#include + +#include "minddata/dataset/audio/kernels/audio_utils.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "utils/file_utils.h" + +namespace mindspore { +namespace dataset { +const int32_t label_file_suffix_len = 10; +const char label_file_suffix[] = ".trans.tsv"; +const char audio_file_suffix[] = ".wav"; +const std::vector usage_list = {"dev-clean", "dev-other", "test-clean", "test-other", + "train-clean-100", "train-clean-360", "train-other-500"}; + +LibriTTSOp::LibriTTSOp(const std::string &dataset_dir, const std::string &usage, int32_t num_workers, + int32_t queue_size, std::unique_ptr data_schema, std::shared_ptr sampler) + : MappableLeafOp(num_workers, queue_size, std::move(sampler)), + dataset_dir_(dataset_dir), + usage_(usage), + data_schema_(std::move(data_schema)) {} + +Status LibriTTSOp::LoadTensorRow(row_id_type row_id, TensorRow *trow) { + RETURN_UNEXPECTED_IF_NULL(trow); + LibriTTSLabelTuple audio_tuple = audio_label_tuples_[row_id]; + const uint32_t rate = 24000; + std::shared_ptr waveform, sample_rate, original_text, normalized_text, speaker_id, chapter_id, utterance_id; + Path dir(real_path_); + std::string file_name = audio_tuple.utterance_id + audio_file_suffix; + Path full_dir = dir / audio_tuple.usage / std::to_string(audio_tuple.speaker_id) / + std::to_string(audio_tuple.chapter_id) / file_name; + RETURN_IF_NOT_OK(ReadAudio(full_dir.ToString(), &waveform)); + RETURN_IF_NOT_OK(Tensor::CreateScalar(rate, &sample_rate)); + RETURN_IF_NOT_OK(Tensor::CreateScalar(audio_tuple.original_text, &original_text)); + RETURN_IF_NOT_OK(Tensor::CreateScalar(audio_tuple.normalized_text, &normalized_text)); + RETURN_IF_NOT_OK(Tensor::CreateScalar(audio_tuple.speaker_id, &speaker_id)); + RETURN_IF_NOT_OK(Tensor::CreateScalar(audio_tuple.chapter_id, &chapter_id)); + RETURN_IF_NOT_OK(Tensor::CreateScalar(audio_tuple.utterance_id, &utterance_id)); + (*trow) = TensorRow( + row_id, {std::move(waveform), std::move(sample_rate), std::move(original_text), std::move(normalized_text), + std::move(speaker_id), std::move(chapter_id), std::move(utterance_id)}); + std::string label_path = audio_tuple.label_path; + trow->setPath({full_dir.ToString(), full_dir.ToString(), label_path, label_path, label_path, label_path, label_path}); + return Status::OK(); +} + +void LibriTTSOp::Print(std::ostream &out, bool show_all) const { + if (!show_all) { + ParallelOp::Print(out, show_all); + out << "\n"; + } else { + ParallelOp::Print(out, show_all); + out << "\nNumber of rows: " << num_rows_ << "\nLibriTTS directory: " << dataset_dir_ << "\n\n"; + } +} + +Status LibriTTSOp::CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count) { + RETURN_UNEXPECTED_IF_NULL(count); + *count = 0; + const int64_t num_samples = 0; + const int64_t start_index = 0; + auto sampler = std::make_shared(start_index, num_samples); + auto schema = std::make_unique(); + + RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("waveform", DataType(DataType::DE_FLOAT32), TensorImpl::kCv, 1))); + TensorShape scalar_rate = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK(schema->AddColumn( + ColDescriptor("sample_rate", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar_rate))); + TensorShape scalar_original_text = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK(schema->AddColumn( + ColDescriptor("original_text", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar_original_text))); + TensorShape scalar_normalized_text = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("normalized_text", DataType(DataType::DE_STRING), + TensorImpl::kFlexible, 0, &scalar_normalized_text))); + TensorShape scalar_speaker_id = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK(schema->AddColumn( + ColDescriptor("speaker_id", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar_speaker_id))); + TensorShape scalar_chapter_id = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK(schema->AddColumn( + ColDescriptor("chapter_id", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar_chapter_id))); + TensorShape scalar_utterance_id = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK(schema->AddColumn( + ColDescriptor("utterance_id", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar_utterance_id))); + std::shared_ptr cfg = GlobalContext::config_manager(); + int32_t num_workers = cfg->num_parallel_workers(); + int32_t op_connect_size = cfg->op_connector_size(); + auto op = + std::make_shared(dir, usage, num_workers, op_connect_size, std::move(schema), std::move(sampler)); + RETURN_IF_NOT_OK(op->PrepareData()); + *count = op->audio_label_tuples_.size(); + return Status::OK(); +} + +Status LibriTTSOp::ComputeColMap() { + if (column_name_id_map_.empty()) { + for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { + column_name_id_map_[data_schema_->Column(i).Name()] = i; + } + } else { + MS_LOG(WARNING) << "Column name map is already set!"; + } + return Status::OK(); +} + +Status LibriTTSOp::ReadAudio(const std::string &audio_dir, std::shared_ptr *waveform) { + RETURN_UNEXPECTED_IF_NULL(waveform); + const int32_t kWavFileSampleRate = 24000; + int32_t sample_rate = 0; + std::vector waveform_vec; + RETURN_IF_NOT_OK(ReadWaveFile(audio_dir, &waveform_vec, &sample_rate)); + CHECK_FAIL_RETURN_UNEXPECTED( + sample_rate == kWavFileSampleRate, + "Invalid file, sampling rate of LibriTTS wav file must be 24000, file path: " + audio_dir); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(waveform_vec, waveform)); + RETURN_IF_NOT_OK((*waveform)->ExpandDim(0)); + return Status::OK(); +} + +Status LibriTTSOp::PrepareData() { + auto realpath = FileUtils::GetRealPath(dataset_dir_.c_str()); + if (!realpath.has_value()) { + MS_LOG(ERROR) << "Invalid file path, LibriTTS dataset dir: " << dataset_dir_ << " does not exist."; + RETURN_STATUS_UNEXPECTED("Invalid file path, LibriTTS dataset dir: " + dataset_dir_ + " does not exist."); + } + real_path_ = realpath.value(); + Path dir(real_path_); + if (usage_ != "all") { + Path full_dir = dir / usage_; + cur_usage_ = usage_; + RETURN_IF_NOT_OK(GetPaths(&full_dir)); + RETURN_IF_NOT_OK(GetLabels()); + } else { + for (std::string usage_iter : usage_list) { + cur_usage_ = usage_iter; + Path full_dir = dir / cur_usage_; + RETURN_IF_NOT_OK(GetPaths(&full_dir)); + RETURN_IF_NOT_OK(GetLabels()); + } + } + num_rows_ = audio_label_tuples_.size(); + CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, + "Invalid data, no valid data matching the dataset API LibriTTSDataset. " + "Please check dataset API or file path: " + + dataset_dir_ + "."); + return Status::OK(); +} + +Status LibriTTSOp::GetPaths(Path *dir) { + RETURN_UNEXPECTED_IF_NULL(dir); + auto iter = Path::DirIterator::OpenDirectory(dir); + if (iter == nullptr) { + MS_LOG(WARNING) << "Invalid file path, unable to open directory: " << dir->ToString() << "."; + } else { + while (iter->HasNext()) { + Path sub_dir = iter->Next(); + if (sub_dir.IsDirectory()) { + RETURN_IF_NOT_OK(GetPaths(&sub_dir)); + } else { + Path file_path = sub_dir; + std::string file_name = file_path.Basename(); + int32_t length = file_name.size(); + if (length > label_file_suffix_len && file_name.substr(length - label_file_suffix_len) == label_file_suffix) { + label_files_.push_back(sub_dir.ToString()); + return Status::OK(); + } + } + } + } + return Status::OK(); +} + +Status LibriTTSOp::GetLabels() { + std::string utterance_id_body = ""; + std::string original_text_body = ""; + std::string normalized_text_body = ""; + const uint32_t base = 10; + const uint32_t ascii_zero = 48; + const size_t underline_exact = 3; + for (std::string label_file : label_files_) { + std::ifstream label_reader(label_file, std::ios::in); + while (getline(label_reader, utterance_id_body, '\t')) { + getline(label_reader, original_text_body, '\t'); + getline(label_reader, normalized_text_body, '\n'); + uint32_t speaker_id = 0; + uint32_t chapter_id = 0; + size_t underline_num = 0; + size_t underline_inx[4] = {0}; + for (size_t i = 0; i < utterance_id_body.size() && underline_num <= underline_exact; i++) { + if (utterance_id_body[i] == '_') { + underline_inx[underline_num++] = i; + } + } + if (underline_num != underline_exact) { + label_reader.close(); + RETURN_STATUS_UNEXPECTED("Invalid file, the file may not be a LibriTTS dataset file: " + label_file); + } + for (size_t i = 0; i < underline_inx[0]; i++) { + speaker_id = speaker_id * base + utterance_id_body[i] - ascii_zero; + } + for (size_t i = underline_inx[0] + 1; i < underline_inx[1]; i++) { + chapter_id = chapter_id * base + utterance_id_body[i] - ascii_zero; + } + audio_label_tuples_.push_back( + {cur_usage_, utterance_id_body, original_text_body, normalized_text_body, speaker_id, chapter_id, label_file}); + } + label_reader.close(); + } + label_files_.clear(); + return Status::OK(); +} +} // namespace dataset. +} // namespace mindspore. diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/libri_tts_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/libri_tts_op.h index 143827b3429..90dc05e4420 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/libri_tts_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/libri_tts_op.h @@ -1,120 +1,120 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_LIBRI_TTS_OP_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_LIBRI_TTS_OP_H_ - -#include -#include -#include -#include -#include -#include - -#include "minddata/dataset/core/tensor.h" -#include "minddata/dataset/engine/data_schema.h" -#include "minddata/dataset/engine/datasetops/parallel_op.h" -#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h" -#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" -#include "minddata/dataset/util/path.h" -#include "minddata/dataset/util/queue.h" -#include "minddata/dataset/util/status.h" -#include "minddata/dataset/util/wait_post.h" - -namespace mindspore { -namespace dataset { -struct LibriTTSLabelTuple { - std::string usage; - std::string utterance_id; - std::string original_text; - std::string normalized_text; - uint32_t speaker_id; - uint32_t chapter_id; - std::string label_path; -}; - -class LibriTTSOp : public MappableLeafOp { - public: - /// \brief Constructor. - /// \param[in] dataset_dir Dir directory of LibriTTS. - /// \param[in] usage usage of this dataset, can be "dev-clean", "dev-other", "test-clean", "test-other", - /// "train-clean-100", "train-clean-360", "train-other-500", or "all". - /// \param[in] num_workers Number of workers reading audios in parallel. - /// \param[in] queue_size Connector queue size. - /// \param[in] data_schema The schema of the LibriTTS dataset. - /// \param[in] sampler Sampler tells LibriSpeechOp what to read. - LibriTTSOp(const std::string &dataset_dir, const std::string &usage, int32_t num_workers, int32_t queue_size, - std::unique_ptr data_schema, std::shared_ptr sampler); - - /// \brief Destructor. - ~LibriTTSOp() = default; - - /// \brief A print method typically used for debugging. - /// \param[out] out Output stream. - /// \param[in] show_all Whether to show all information. - void Print(std::ostream &out, bool show_all) const override; - - /// \brief Function to count the number of samples in the LibriTTS dataset. - /// \param[in] dir Path to the LibriTTS directory. - /// \param[in] usage Select the data set section. - /// \param[out] count Output arg that will hold the minimum of the actual dataset size and numSamples. - /// \return Status The status code returned. - static Status CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count); - - /// \brief Op name getter. - /// \return Name of the current Op. - std::string Name() const override { return "LibriTTSOp"; } - - private: - /// \brief Load a tensor row according to a pair. - /// \param[in] row_id Id for this tensor row. - /// \param[out] row Audio & label read into this tensor row. - /// \return Status The status code returned. - Status LoadTensorRow(row_id_type row_id, TensorRow *row) override; - - /// \brief Read all paths in the directory. - /// \param[in] dir File path to be traversed. - /// \return Status The status code returned. - Status GetPaths(Path *dir); - - /// \brief Read all label files. - /// \return Status The status code returned. - Status GetLabels(); - - /// \brief Parse a single wav file. - /// \param[in] audio_dir Audio file path. - /// \param[out] waveform The output waveform tensor. - /// \return Status The status code returned. - Status ReadAudio(const std::string &audio_dir, std::shared_ptr *waveform); - - /// \brief Prepare all data in the directory. - /// \return Status The status code returned. - Status PrepareData(); - - /// \brief Private function for computing the assignment of the column name map. - /// \return Status The status code returned. - Status ComputeColMap() override; - - const std::string usage_; - std::string cur_usage_; - std::string real_path_; - std::string dataset_dir_; - std::unique_ptr data_schema_; - std::vector audio_label_tuples_; - std::vector label_files_; -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_LIBRI_TTS_OP_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_LIBRI_TTS_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_LIBRI_TTS_OP_H_ + +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/data_schema.h" +#include "minddata/dataset/engine/datasetops/parallel_op.h" +#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/dataset/util/path.h" +#include "minddata/dataset/util/queue.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/util/wait_post.h" + +namespace mindspore { +namespace dataset { +struct LibriTTSLabelTuple { + std::string usage; + std::string utterance_id; + std::string original_text; + std::string normalized_text; + uint32_t speaker_id; + uint32_t chapter_id; + std::string label_path; +}; + +class LibriTTSOp : public MappableLeafOp { + public: + /// \brief Constructor. + /// \param[in] dataset_dir Dir directory of LibriTTS. + /// \param[in] usage usage of this dataset, can be "dev-clean", "dev-other", "test-clean", "test-other", + /// "train-clean-100", "train-clean-360", "train-other-500", or "all". + /// \param[in] num_workers Number of workers reading audios in parallel. + /// \param[in] queue_size Connector queue size. + /// \param[in] data_schema The schema of the LibriTTS dataset. + /// \param[in] sampler Sampler tells LibriSpeechOp what to read. + LibriTTSOp(const std::string &dataset_dir, const std::string &usage, int32_t num_workers, int32_t queue_size, + std::unique_ptr data_schema, std::shared_ptr sampler); + + /// \brief Destructor. + ~LibriTTSOp() = default; + + /// \brief A print method typically used for debugging. + /// \param[out] out Output stream. + /// \param[in] show_all Whether to show all information. + void Print(std::ostream &out, bool show_all) const override; + + /// \brief Function to count the number of samples in the LibriTTS dataset. + /// \param[in] dir Path to the LibriTTS directory. + /// \param[in] usage Select the data set section. + /// \param[out] count Output arg that will hold the minimum of the actual dataset size and numSamples. + /// \return Status The status code returned. + static Status CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count); + + /// \brief Op name getter. + /// \return Name of the current Op. + std::string Name() const override { return "LibriTTSOp"; } + + private: + /// \brief Load a tensor row according to a pair. + /// \param[in] row_id Id for this tensor row. + /// \param[out] row Audio & label read into this tensor row. + /// \return Status The status code returned. + Status LoadTensorRow(row_id_type row_id, TensorRow *row) override; + + /// \brief Read all paths in the directory. + /// \param[in] dir File path to be traversed. + /// \return Status The status code returned. + Status GetPaths(Path *dir); + + /// \brief Read all label files. + /// \return Status The status code returned. + Status GetLabels(); + + /// \brief Parse a single wav file. + /// \param[in] audio_dir Audio file path. + /// \param[out] waveform The output waveform tensor. + /// \return Status The status code returned. + Status ReadAudio(const std::string &audio_dir, std::shared_ptr *waveform); + + /// \brief Prepare all data in the directory. + /// \return Status The status code returned. + Status PrepareData(); + + /// \brief Private function for computing the assignment of the column name map. + /// \return Status The status code returned. + Status ComputeColMap() override; + + const std::string usage_; + std::string cur_usage_; + std::string real_path_; + std::string dataset_dir_; + std::unique_ptr data_schema_; + std::vector audio_label_tuples_; + std::vector label_files_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_LIBRI_TTS_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/penn_treebank_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/penn_treebank_op.cc index 24a7f22f3b3..5b756e68100 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/penn_treebank_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/penn_treebank_op.cc @@ -1,55 +1,55 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/engine/datasetops/source/penn_treebank_op.h" - -#include "include/common/debug/common.h" -#include "minddata/dataset/core/config_manager.h" -#include "minddata/dataset/engine/datasetops/source/io_block.h" -#include "minddata/dataset/engine/execution_tree.h" -#include "minddata/dataset/util/random.h" -#include "minddata/dataset/util/wait_post.h" - -namespace mindspore { -namespace dataset { -PennTreebankOp::PennTreebankOp(int32_t num_workers, int64_t total_rows, int32_t worker_connector_size, - std::unique_ptr schema, const std::vector &file_list, - int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id) - : TextFileOp(num_workers, total_rows, worker_connector_size, std::move(schema), file_list, op_connector_size, - shuffle_files, num_devices, device_id) {} - -// A print method typically used for debugging. -void PennTreebankOp::Print(std::ostream &out, bool show_all) const { - if (!show_all) { - // Call the super class for displaying any common 1-liner info. - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op. - out << "\n"; - } else { - // Call the super class for displaying any common detailed info. - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal stuff. - out << "\nRow count: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_ - << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nPennTreebank files list:\n"; - for (size_t i = 0; i < text_files_list_.size(); ++i) { - out << " " << text_files_list_[i]; - } - out << "\nData Schema:\n"; - out << *data_schema_ << "\n\n"; - } -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/engine/datasetops/source/penn_treebank_op.h" + +#include "include/common/debug/common.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/engine/datasetops/source/io_block.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/util/random.h" +#include "minddata/dataset/util/wait_post.h" + +namespace mindspore { +namespace dataset { +PennTreebankOp::PennTreebankOp(int32_t num_workers, int64_t total_rows, int32_t worker_connector_size, + std::unique_ptr schema, const std::vector &file_list, + int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id) + : TextFileOp(num_workers, total_rows, worker_connector_size, std::move(schema), file_list, op_connector_size, + shuffle_files, num_devices, device_id) {} + +// A print method typically used for debugging. +void PennTreebankOp::Print(std::ostream &out, bool show_all) const { + if (!show_all) { + // Call the super class for displaying any common 1-liner info. + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op. + out << "\n"; + } else { + // Call the super class for displaying any common detailed info. + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff. + out << "\nRow count: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_ + << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nPennTreebank files list:\n"; + for (size_t i = 0; i < text_files_list_.size(); ++i) { + out << " " << text_files_list_[i]; + } + out << "\nData Schema:\n"; + out << *data_schema_ << "\n\n"; + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/penn_treebank_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/penn_treebank_op.h index 3fc701e4a9a..399899b0d5e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/penn_treebank_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/penn_treebank_op.h @@ -1,69 +1,69 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_PENN_TREEBANK_OP_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_PENN_TREEBANK_OP_H_ - -#include -#include -#include -#include -#include -#include - -#include "minddata/dataset/engine/datasetops/source/text_file_op.h" -#include "minddata/dataset/util/queue.h" - -namespace mindspore { -namespace dataset { -class JaggedConnector; - -class PennTreebankOp : public TextFileOp { - public: - /// \brief Constructor. - /// \param[in] num_workers Number of workers reading images in parallel - /// \param[in] num_samples The number of samples to be included in the dataset. - /// \param[in] worker_connector_size Size of each internal queue. - /// \param[in] data_schema Path to dataset schema file. - /// \param[in] file_list List of files to be read to search for a pattern of files. The list - /// will be sorted in a lexicographical order. - /// \param[in] op_connector_size Size of each queue in the connector that the child operator pulls from. - /// \param[in] shuffle_files Whether or not to shuffle the files before reading data. - /// \param[in] num_devices Number of devices that the dataset should be divided into. - /// \param[in] device_id The device ID within num_devices. This argument should be - /// specified only when num_devices is also specified. - PennTreebankOp(int32_t num_workers, int64_t num_samples, int32_t worker_connector_size, std::unique_ptr, - const std::vector &file_list, int32_t op_connector_size, bool shuffle_files, - int32_t num_devices, int32_t device_id); - - /// \brief Default destructor. - ~PennTreebankOp() = default; - - /// \brief A print method typically used for debugging. - /// \param[in] out he output stream to write output to. - /// \param[in] show_all A bool to control if you want to show all info or just a summary. - void Print(std::ostream &out, bool show_all) const override; - - /// \brief Op name getter. - /// \return Name of the current Op. - std::string Name() const override { return "PennTreebankOp"; } - - /// \brief DatasetName name getter. - /// \return DatasetName of the current Op. - std::string DatasetName(bool upper = false) const { return upper ? "PennTreebank" : "penn treebank"; } -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_PENN_TREEBANK_OP_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_PENN_TREEBANK_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_PENN_TREEBANK_OP_H_ + +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/engine/datasetops/source/text_file_op.h" +#include "minddata/dataset/util/queue.h" + +namespace mindspore { +namespace dataset { +class JaggedConnector; + +class PennTreebankOp : public TextFileOp { + public: + /// \brief Constructor. + /// \param[in] num_workers Number of workers reading images in parallel + /// \param[in] num_samples The number of samples to be included in the dataset. + /// \param[in] worker_connector_size Size of each internal queue. + /// \param[in] data_schema Path to dataset schema file. + /// \param[in] file_list List of files to be read to search for a pattern of files. The list + /// will be sorted in a lexicographical order. + /// \param[in] op_connector_size Size of each queue in the connector that the child operator pulls from. + /// \param[in] shuffle_files Whether or not to shuffle the files before reading data. + /// \param[in] num_devices Number of devices that the dataset should be divided into. + /// \param[in] device_id The device ID within num_devices. This argument should be + /// specified only when num_devices is also specified. + PennTreebankOp(int32_t num_workers, int64_t num_samples, int32_t worker_connector_size, std::unique_ptr, + const std::vector &file_list, int32_t op_connector_size, bool shuffle_files, + int32_t num_devices, int32_t device_id); + + /// \brief Default destructor. + ~PennTreebankOp() = default; + + /// \brief A print method typically used for debugging. + /// \param[in] out he output stream to write output to. + /// \param[in] show_all A bool to control if you want to show all info or just a summary. + void Print(std::ostream &out, bool show_all) const override; + + /// \brief Op name getter. + /// \return Name of the current Op. + std::string Name() const override { return "PennTreebankOp"; } + + /// \brief DatasetName name getter. + /// \return DatasetName of the current Op. + std::string DatasetName(bool upper = false) const { return upper ? "PennTreebank" : "penn treebank"; } +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_PENN_TREEBANK_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/qmnist_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/qmnist_op.h index abf26a8d98a..708d68d3d13 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/qmnist_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/qmnist_op.h @@ -1,113 +1,113 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_QMNIST_OP_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_QMNIST_OP_H_ - -#include -#include -#include -#include -#include -#include - -#include "minddata/dataset/core/tensor.h" -#include "minddata/dataset/engine/data_schema.h" -#include "minddata/dataset/engine/datasetops/parallel_op.h" -#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h" -#include "minddata/dataset/engine/datasetops/source/mnist_op.h" -#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" -#include "minddata/dataset/util/path.h" -#include "minddata/dataset/util/queue.h" -#include "minddata/dataset/util/status.h" -#include "minddata/dataset/util/wait_post.h" - -namespace mindspore { -namespace dataset { - -using QMnistImageInfoPair = std::pair, std::shared_ptr>; - -class QMnistOp : public MnistOp { - public: - // Constructor. - // @param const std::string &folder_path - dir directory of QMNIST data file. - // @param const std::string &usage - Usage of this dataset, can be 'train', 'test', 'test10k', 'test50k', 'nist' or - // 'all'. - // @param bool compat - Compatibility with Mnist. - // @param std::unique_ptr data_schema - the schema of the QMNIST dataset. - // @param td::unique_ptr sampler - sampler tells QMnistOp what to read. - // @param int32_t num_workers - number of workers reading images in parallel. - // @param int32_t queue_size - connector queue size. - QMnistOp(const std::string &folder_path, const std::string &usage, bool compat, - std::unique_ptr data_schema, std::shared_ptr sampler, int32_t num_workers, - int32_t queue_size); - - // Destructor. - ~QMnistOp() = default; - - // Op name getter. - // @return std::string - Name of the current Op. - std::string Name() const override { return "QMnistOp"; } - - // DatasetName name getter - // \return std::string - DatasetName of the current Op - std::string DatasetName(bool upper = false) const { return upper ? "QMnist" : "qmnist"; } - - // A print method typically used for debugging. - // @param std::ostream &out - out stream. - // @param bool show_all - whether to show all information. - void Print(std::ostream &out, bool show_all) const override; - - // Function to count the number of samples in the QMNIST dataset. - // @param const std::string &dir - path to the QMNIST directory. - // @param const std::string &usage - Usage of this dataset, can be 'train', 'test', 'test10k', 'test50k', 'nist' or - // 'all'. - // @param int64_t *count - output arg that will hold the actual dataset size. - // @return Status -The status code returned. - static Status CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count); - - private: - // Load a tensor row according to a pair. - // @param row_id_type row_id - id for this tensor row. - // @param TensorRow row - image & label read into this tensor row. - // @return Status - The status code returned. - Status LoadTensorRow(row_id_type row_id, TensorRow *trow) override; - - // Get needed files in the folder_path_. - // @return Status - The status code returned. - Status WalkAllFiles() override; - - // Read images and labels from the file stream. - // @param std::ifstream *image_reader - image file stream. - // @param std::ifstream *label_reader - label file stream. - // @param size_t index - the index of file that is reading. - // @return Status The status code returned. - Status ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *label_reader, size_t index) override; - - // Check label stream. - // @param const std::string &file_name - label file name. - // @param std::ifstream *label_reader - label file stream. - // @param uint32_t num_labels - returns the number of labels. - // @return Status The status code returned. - Status CheckLabel(const std::string &file_name, std::ifstream *label_reader, uint32_t *num_labels) override; - - const bool compat_; // compatible with mnist - - std::vector image_info_pairs_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_QMNIST_OP_H_ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_QMNIST_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_QMNIST_OP_H_ + +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/data_schema.h" +#include "minddata/dataset/engine/datasetops/parallel_op.h" +#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h" +#include "minddata/dataset/engine/datasetops/source/mnist_op.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/dataset/util/path.h" +#include "minddata/dataset/util/queue.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/util/wait_post.h" + +namespace mindspore { +namespace dataset { + +using QMnistImageInfoPair = std::pair, std::shared_ptr>; + +class QMnistOp : public MnistOp { + public: + // Constructor. + // @param const std::string &folder_path - dir directory of QMNIST data file. + // @param const std::string &usage - Usage of this dataset, can be 'train', 'test', 'test10k', 'test50k', 'nist' or + // 'all'. + // @param bool compat - Compatibility with Mnist. + // @param std::unique_ptr data_schema - the schema of the QMNIST dataset. + // @param td::unique_ptr sampler - sampler tells QMnistOp what to read. + // @param int32_t num_workers - number of workers reading images in parallel. + // @param int32_t queue_size - connector queue size. + QMnistOp(const std::string &folder_path, const std::string &usage, bool compat, + std::unique_ptr data_schema, std::shared_ptr sampler, int32_t num_workers, + int32_t queue_size); + + // Destructor. + ~QMnistOp() = default; + + // Op name getter. + // @return std::string - Name of the current Op. + std::string Name() const override { return "QMnistOp"; } + + // DatasetName name getter + // \return std::string - DatasetName of the current Op + std::string DatasetName(bool upper = false) const { return upper ? "QMnist" : "qmnist"; } + + // A print method typically used for debugging. + // @param std::ostream &out - out stream. + // @param bool show_all - whether to show all information. + void Print(std::ostream &out, bool show_all) const override; + + // Function to count the number of samples in the QMNIST dataset. + // @param const std::string &dir - path to the QMNIST directory. + // @param const std::string &usage - Usage of this dataset, can be 'train', 'test', 'test10k', 'test50k', 'nist' or + // 'all'. + // @param int64_t *count - output arg that will hold the actual dataset size. + // @return Status -The status code returned. + static Status CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count); + + private: + // Load a tensor row according to a pair. + // @param row_id_type row_id - id for this tensor row. + // @param TensorRow row - image & label read into this tensor row. + // @return Status - The status code returned. + Status LoadTensorRow(row_id_type row_id, TensorRow *trow) override; + + // Get needed files in the folder_path_. + // @return Status - The status code returned. + Status WalkAllFiles() override; + + // Read images and labels from the file stream. + // @param std::ifstream *image_reader - image file stream. + // @param std::ifstream *label_reader - label file stream. + // @param size_t index - the index of file that is reading. + // @return Status The status code returned. + Status ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *label_reader, size_t index) override; + + // Check label stream. + // @param const std::string &file_name - label file name. + // @param std::ifstream *label_reader - label file stream. + // @param uint32_t num_labels - returns the number of labels. + // @return Status The status code returned. + Status CheckLabel(const std::string &file_name, std::ifstream *label_reader, uint32_t *num_labels) override; + + const bool compat_; // compatible with mnist + + std::vector image_info_pairs_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_QMNIST_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sogou_news_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sogou_news_op.cc index 4c1b2f9db16..072f22a387c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sogou_news_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sogou_news_op.cc @@ -1,52 +1,52 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/engine/datasetops/source/sogou_news_op.h" - -#include - -#include "include/common/debug/common.h" - -namespace mindspore { -namespace dataset { -SogouNewsOp::SogouNewsOp(int32_t num_workers, int64_t num_samples, int32_t worker_connector_size, - int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id, - char field_delim, const std::vector> &column_default, - const std::vector &column_name, - const std::vector &sogou_news_files_list) - : CsvOp(sogou_news_files_list, field_delim, column_default, column_name, num_workers, num_samples, - worker_connector_size, op_connector_size, shuffle_files, num_devices, device_id) {} - -void SogouNewsOp::Print(std::ostream &out, bool show_all) const { - if (!show_all) { - // Call the super class for displaying any common 1-liner info. - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op. - out << "\n"; - } else { - // Call the super class for displaying any common detailed info. - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal stuff. - out << "\nSample count: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_ - << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nSogouNews files list:\n"; - for (int i = 0; i < csv_files_list_.size(); ++i) { - out << " " << csv_files_list_[i]; - } - out << "\n\n"; - } -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/engine/datasetops/source/sogou_news_op.h" + +#include + +#include "include/common/debug/common.h" + +namespace mindspore { +namespace dataset { +SogouNewsOp::SogouNewsOp(int32_t num_workers, int64_t num_samples, int32_t worker_connector_size, + int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id, + char field_delim, const std::vector> &column_default, + const std::vector &column_name, + const std::vector &sogou_news_files_list) + : CsvOp(sogou_news_files_list, field_delim, column_default, column_name, num_workers, num_samples, + worker_connector_size, op_connector_size, shuffle_files, num_devices, device_id) {} + +void SogouNewsOp::Print(std::ostream &out, bool show_all) const { + if (!show_all) { + // Call the super class for displaying any common 1-liner info. + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op. + out << "\n"; + } else { + // Call the super class for displaying any common detailed info. + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff. + out << "\nSample count: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_ + << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nSogouNews files list:\n"; + for (int i = 0; i < csv_files_list_.size(); ++i) { + out << " " << csv_files_list_[i]; + } + out << "\n\n"; + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sogou_news_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sogou_news_op.h index 315bf4f1e35..f43b3805fd5 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sogou_news_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sogou_news_op.h @@ -1,71 +1,71 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SOGOU_NEWS_OP_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SOGOU_NEWS_OP_H_ - -#include -#include -#include - -#include "minddata/dataset/engine/datasetops/source/csv_op.h" - -namespace mindspore { -namespace dataset { -class JaggedConnector; - -/// \class SogouNewsOp -/// \brief A Op derived class to represent SogouNews Op. -class SogouNewsOp : public CsvOp { - public: - /// \brief Constructor of SogouNewsOp. - /// \param[in] num_workers Number of worker threads reading data from sogou_news files. - /// \param[in] num_samples The number of samples to be included in the dataset. - /// \param[in] worker_connector_size Size of each internal queue. - /// \param[in] op_connector_size Size of each queue in the connector that the child operator pulls from. - /// \param[in] shuffle_files Whether or not to shuffle the files before reading data. - /// \param[in] num_devices Number of devices that the dataset should be divided into. - /// \param[in] device_id The device ID within num_devices. - /// \param[in] field_delim A char that indicates the delimiter to separate fields. - /// \param[in] column_default List of default values for the CSV field (default={}). Each item in the list is - /// either a valid type (float, int, or string). - /// \param[in] column_name List of column names of the dataset. - /// \param[in] sogounews_files_list List of file paths for the dataset files. - SogouNewsOp(int32_t num_workers, int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, - bool shuffle_files, int32_t num_devices, int32_t device_id, char field_delim, - const std::vector> &column_default, - const std::vector &column_name, const std::vector &sogou_news_files_list); - - /// \brief Destructor. - ~SogouNewsOp() = default; - - /// \brief A print method typically used for debugging. - /// \param[out] out The output stream to write output to. - /// \param[in] show_all A bool to control if you want to show all info or just a summary. - void Print(std::ostream &out, bool show_all) const override; - - /// \brief DatasetName name getter. - /// \param[in] upper A bool to control if you want to return uppercase or lowercase Op name. - /// \return DatasetName of the current Op. - std::string DatasetName(bool upper = false) const { return upper ? "SogouNews" : "sogou news"; } - - /// \brief Op name getter. - /// \return Name of the current Op. - std::string Name() const override { return "SogouNewsOp"; } -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SOGOU_NEWS_OP_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SOGOU_NEWS_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SOGOU_NEWS_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/engine/datasetops/source/csv_op.h" + +namespace mindspore { +namespace dataset { +class JaggedConnector; + +/// \class SogouNewsOp +/// \brief A Op derived class to represent SogouNews Op. +class SogouNewsOp : public CsvOp { + public: + /// \brief Constructor of SogouNewsOp. + /// \param[in] num_workers Number of worker threads reading data from sogou_news files. + /// \param[in] num_samples The number of samples to be included in the dataset. + /// \param[in] worker_connector_size Size of each internal queue. + /// \param[in] op_connector_size Size of each queue in the connector that the child operator pulls from. + /// \param[in] shuffle_files Whether or not to shuffle the files before reading data. + /// \param[in] num_devices Number of devices that the dataset should be divided into. + /// \param[in] device_id The device ID within num_devices. + /// \param[in] field_delim A char that indicates the delimiter to separate fields. + /// \param[in] column_default List of default values for the CSV field (default={}). Each item in the list is + /// either a valid type (float, int, or string). + /// \param[in] column_name List of column names of the dataset. + /// \param[in] sogounews_files_list List of file paths for the dataset files. + SogouNewsOp(int32_t num_workers, int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, + bool shuffle_files, int32_t num_devices, int32_t device_id, char field_delim, + const std::vector> &column_default, + const std::vector &column_name, const std::vector &sogou_news_files_list); + + /// \brief Destructor. + ~SogouNewsOp() = default; + + /// \brief A print method typically used for debugging. + /// \param[out] out The output stream to write output to. + /// \param[in] show_all A bool to control if you want to show all info or just a summary. + void Print(std::ostream &out, bool show_all) const override; + + /// \brief DatasetName name getter. + /// \param[in] upper A bool to control if you want to return uppercase or lowercase Op name. + /// \return DatasetName of the current Op. + std::string DatasetName(bool upper = false) const { return upper ? "SogouNews" : "sogou news"; } + + /// \brief Op name getter. + /// \return Name of the current Op. + std::string Name() const override { return "SogouNewsOp"; } +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SOGOU_NEWS_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tedlium_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tedlium_op.cc index fdb3f0714c8..db8ad85e301 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tedlium_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tedlium_op.cc @@ -1,311 +1,311 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "minddata/dataset/engine/datasetops/source/tedlium_op.h" - -#include - -#include "minddata/dataset/core/config_manager.h" -#include "minddata/dataset/core/tensor_shape.h" -#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" -#include "minddata/dataset/engine/execution_tree.h" -#include "utils/file_utils.h" - -namespace mindspore { -namespace dataset { -TedliumOp::TedliumOp(const std::string &dataset_dir, const std::string &release, const std::string &usage, - const std::string &extensions, int32_t num_parallel_workers, - std::unique_ptr data_schema, std::shared_ptr sampler, int32_t queue_size) - : MappableLeafOp(num_parallel_workers, queue_size, std::move(sampler)), - dataset_dir_(dataset_dir), - release_(release), - usage_(usage), - extensions_(extensions), - data_schema_(std::move(data_schema)), - audio_files_({}), - usage_list_({}) {} - -void TedliumOp::Print(std::ostream &out, bool show_all) const { - if (!show_all) { - // Call the super class for displaying any common 1-liner info. - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op. - out << "\n"; - } else { - // Call the super class for displaying any common detailed info. - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal stuff. - out << "\nNumber of rows: " << num_rows_ << "\nTedliumOp directory: " << dataset_dir_; - } -} - -Status TedliumOp::PrepareData() { - auto real_path = FileUtils::GetRealPath(dataset_dir_.c_str()); - if (!real_path.has_value()) { - RETURN_STATUS_UNEXPECTED("Invalid file, get real path failed, path=" + dataset_dir_); - } - Path root_folder(real_path.value()); - - if (release_ == "release1" || release_ == "release2") { - if (usage_ == "train" || usage_ == "test" || usage_ == "dev") { - usage_list_.push_back(usage_); - } else if (usage_ == "all") { - usage_list_ = {"train", "test", "dev"}; - } else { - RETURN_STATUS_UNEXPECTED( - "Invalid parameter, usage should be \"train\", \"test\", \"dev\" or \"all\" when " - "specify \"release1\" or \"release2\" , got " + - usage_); - } - for (int32_t i = 0; i < usage_list_.size(); ++i) { - Path stm_folder = root_folder / usage_list_[i] / "stm"; - RETURN_IF_NOT_OK(ReadStmFolderRows(stm_folder, usage_list_[i])); - } - } else if (release_ == "release3") { - if (usage_ == "all") { - Path stm_folder = root_folder / "data" / "stm"; - RETURN_IF_NOT_OK(ReadStmFolderRows(stm_folder, "data")); - } else { - RETURN_STATUS_UNEXPECTED("Invalid parameter, usage should be \"all\" when specify \"release3\" , got " + usage_); - } - } - std::sort(audio_files_.begin(), audio_files_.end()); - num_rows_ = audio_files_.size(); - if (num_rows_ == 0) { - RETURN_STATUS_UNEXPECTED( - "Invalid data, no valid data matching the dataset API TedliumDataset. Please check file path or dataset API."); - } - return Status::OK(); -} - -Status TedliumOp::ReadStmFolderRows(const Path &stm_folder, const std::string &release_usage) { - Path dir(stm_folder); - std::shared_ptr dirItr = Path::DirIterator::OpenDirectory(&dir); - if (!dir.Exists() || dirItr == nullptr) { - RETURN_STATUS_UNEXPECTED("Invalid file, failed to open folder: " + dir.ToString()); - } - MS_LOG(DEBUG) << "Tedlium " + release_ + " stm folder Path found: " << dir << "."; - while (dirItr->HasNext()) { - Path file = dirItr->Next(); - if (file.Extension() == ".stm") { - std::ifstream handle(file.ToString(), std::ios::in); - if (!handle.is_open()) { - RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + file.ToString()); - } - std::string line; - int32_t numline = 0; - while (getline(handle, line)) { - std::string filename = line.substr(0, line.find(" ")); - std::stringstream ss; - ss << numline; - audio_files_.push_back({ss.str(), filename, release_usage}); - ++numline; - } - handle.close(); - } - } - return Status::OK(); -} - -Status TedliumOp::ReadStm(const Path &file_stm_path, int32_t row_line, std::string *talk_id, std::string *speaker_id, - std::string *start_time, std::string *end_time, std::string *identifier, - std::string *transcript) { - std::ifstream handle(file_stm_path.ToString().c_str(), std::ios::in); - if (!handle.is_open()) { - RETURN_STATUS_UNEXPECTED("Invalid file, get real path failed, path=" + file_stm_path.ToString()); - } - std::string line; - int32_t i = 0; - while (i <= row_line && getline(handle, line)) { - ++i; - } - handle.close(); - std::vector temp; - i = 0; - const int32_t data_stm_number = 7; - // There are seven pieces of data in each row, which need to be read out and stored - // with a space as a separator. - // Talk_id, _, speaker_id, start_time, end_time, identifier, transcript. - // "_" is the data we don't need. - while (i < data_stm_number - 1) { - std::string s = line.substr(0, line.find(" ")); - temp.push_back(s); - line.erase(0, line.find(" ") + 1); // to delete space, so use s.find(" ") + 1. - ++i; - } - temp.push_back(line); - if (temp.size() != data_stm_number) { - RETURN_STATUS_UNEXPECTED("Invalid data, stm data was broken."); - } - - const int32_t talk_id_num = 0, speaker_id_num = 2, start_time_num = 3, end_time_num = 4, identifier_num = 5, - transcript_num = 6; - *talk_id = temp[talk_id_num]; - // temp[1] is "_", which is the data we don't need. - *speaker_id = temp[speaker_id_num]; - *start_time = temp[start_time_num]; - *end_time = temp[end_time_num]; - *identifier = temp[identifier_num]; - *transcript = temp[transcript_num]; - - return Status::OK(); -} - -Status TedliumOp::ReadSph(const Path &file_sph_path, double start_time, double end_time, int32_t *sample_rate, - std::vector *result) { - std::ifstream handle(file_sph_path.ToString().c_str(), std::ios::in | std::ios::binary); - if (!handle.is_open()) { - RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + file_sph_path.ToString()); - } - - char head[1024]; - handle.read(head, sizeof(head)); - CHECK_FAIL_RETURN_UNEXPECTED(!handle.fail(), - "Invalid data, failed to read head part from sph file: " + file_sph_path.ToString() + - ", re-download dataset(make sure the data is true)."); - std::vector vec; - for (int32_t i = 0, j = 0; i < strlen(head); ++i) { - if (head[i] == '\n' || head[i] == ' ') { - while (head[i + 1] == ' ') { - i++; - } - std::string strTemp(head + j, i - j); - vec.push_back(strTemp); - j = i + 1; - } - } - const int32_t dataToBytes = 2; - for (int32_t i = 0; i < vec.size(); ++i) { - if (vec[i] == "sample_rate") { - *sample_rate = atoi(vec[i + dataToBytes].c_str()); - } - } - - int32_t start = static_cast(start_time * (*sample_rate)); - int32_t end = static_cast(end_time * (*sample_rate)); - const int32_t size = (end - start); - std::vector temp(size * dataToBytes); - handle.seekg(start, std::ios::beg); - int32_t j = 0; - char c; - while (j < size * dataToBytes) { - handle.read(&c, 1); - CHECK_FAIL_RETURN_UNEXPECTED(!handle.fail(), - "Invalid data, failed to read data part from sph file: " + file_sph_path.ToString() + - ", re-download dataset(make sure the data is true)."); - temp.push_back(c); - ++j; - } - - const float kMaxVal = 32767.0; - for (int32_t i = 0; i < size; ++i) { - char bh = temp[2 * i]; - char bl = temp[2 * i + 1]; - // SPH audio files is big-endian, so we should convert the two bytes of data into int16_t based - // on the high 8 bits and the low 8 bits. - int16_t s = static_cast(((bh & 0x00FF) << 8) | (bl & 0x00FF)); - // Data normalization: Convert the data from the interval [-32768,32767] to the interval [-1,1]. - double t = s / kMaxVal; - (*result).push_back(t); - } - handle.close(); - - return Status::OK(); -} - -Status TedliumOp::LoadTensorRow(row_id_type row_id, TensorRow *row) { - int32_t row_line = atoi(audio_files_[row_id][0].c_str()); - std::string file_name = audio_files_[row_id][1]; - std::string file_usage_or3_none_ = audio_files_[row_id][2]; - Path dir_path(dataset_dir_); - Path file_stm_path = dir_path / file_usage_or3_none_ / "stm" / (file_name + ".stm"); - Path file_sph_path = dir_path / file_usage_or3_none_ / "sph" / (file_name + extensions_); - std::string talk_id, speaker_id, start_time, end_time, identifier, transcript; - std::vector result; - int32_t sample_rate; - RETURN_IF_NOT_OK( - ReadStm(file_stm_path, row_line, &talk_id, &speaker_id, &start_time, &end_time, &identifier, &transcript)); - RETURN_IF_NOT_OK(ReadSph(file_sph_path, atof(start_time.c_str()), atof(end_time.c_str()), &sample_rate, &result)); - - std::shared_ptr sample_rate_tensor, talk_id_tensor, speaker_id_tensor, identifier_tensor, transcript_tensor; - RETURN_IF_NOT_OK(Tensor::CreateScalar(sample_rate, &sample_rate_tensor)); - RETURN_IF_NOT_OK(Tensor::CreateScalar(talk_id, &talk_id_tensor)); - RETURN_IF_NOT_OK(Tensor::CreateScalar(speaker_id, &speaker_id_tensor)); - RETURN_IF_NOT_OK(Tensor::CreateScalar(identifier, &identifier_tensor)); - RETURN_IF_NOT_OK(Tensor::CreateScalar(transcript, &transcript_tensor)); - - std::shared_ptr audio_tensor; - RETURN_IF_NOT_OK(Tensor::CreateFromVector(result, &audio_tensor)); - RETURN_IF_NOT_OK(audio_tensor->ExpandDim(0)); - (*row) = TensorRow(row_id, {audio_tensor, sample_rate_tensor, transcript_tensor, talk_id_tensor, speaker_id_tensor, - identifier_tensor}); - row->setPath({file_sph_path.ToString(), file_sph_path.ToString(), file_stm_path.ToString(), file_stm_path.ToString(), - file_stm_path.ToString(), file_stm_path.ToString()}); - - return Status::OK(); -} - -Status TedliumOp::CountTotalRows(const std::string &dataset_dir, const std::string &release, const std::string &usage, - const std::string &extensions, int64_t *count) { - // the logic of counting the number of samples is copied from PrepareData() - RETURN_UNEXPECTED_IF_NULL(count); - *count = 0; - const int64_t num_samples = 0; - const int64_t start_index = 0; - auto new_sampler = std::make_shared(start_index, num_samples); - - // build a new unique schema object - auto new_schema = std::make_unique(); - RETURN_IF_NOT_OK( - new_schema->AddColumn(ColDescriptor("waveform", DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1))); - TensorShape sample_rate_scalar = TensorShape::CreateScalar(); - TensorShape trans_scalar = TensorShape::CreateScalar(); - TensorShape talk_id_scalar = TensorShape::CreateScalar(); - TensorShape speaker_id_scalar = TensorShape::CreateScalar(); - TensorShape identi_scalar = TensorShape::CreateScalar(); - RETURN_IF_NOT_OK(new_schema->AddColumn( - ColDescriptor("sample_rate", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &sample_rate_scalar))); - RETURN_IF_NOT_OK(new_schema->AddColumn( - ColDescriptor("transcript", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &trans_scalar))); - RETURN_IF_NOT_OK(new_schema->AddColumn( - ColDescriptor("talk_id", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &talk_id_scalar))); - RETURN_IF_NOT_OK(new_schema->AddColumn( - ColDescriptor("speaker_id", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &speaker_id_scalar))); - RETURN_IF_NOT_OK(new_schema->AddColumn( - ColDescriptor("identifier", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &identi_scalar))); - - std::shared_ptr cfg = GlobalContext::config_manager(); - int32_t num_workers = cfg->num_parallel_workers(); - int32_t op_connect_size = cfg->op_connector_size(); - std::shared_ptr op = - std::make_shared(dataset_dir, release, usage, extensions, num_workers, std::move(new_schema), - std::move(new_sampler), op_connect_size); - RETURN_IF_NOT_OK(op->PrepareData()); - *count = static_cast(op->audio_files_.size()); - return Status::OK(); -} - -Status TedliumOp::ComputeColMap() { - if (column_name_id_map_.empty()) { - for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { - column_name_id_map_[data_schema_->Column(i).Name()] = i; - } - } else { - MS_LOG(WARNING) << "Column name map is already set!"; - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/source/tedlium_op.h" + +#include + +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "utils/file_utils.h" + +namespace mindspore { +namespace dataset { +TedliumOp::TedliumOp(const std::string &dataset_dir, const std::string &release, const std::string &usage, + const std::string &extensions, int32_t num_parallel_workers, + std::unique_ptr data_schema, std::shared_ptr sampler, int32_t queue_size) + : MappableLeafOp(num_parallel_workers, queue_size, std::move(sampler)), + dataset_dir_(dataset_dir), + release_(release), + usage_(usage), + extensions_(extensions), + data_schema_(std::move(data_schema)), + audio_files_({}), + usage_list_({}) {} + +void TedliumOp::Print(std::ostream &out, bool show_all) const { + if (!show_all) { + // Call the super class for displaying any common 1-liner info. + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op. + out << "\n"; + } else { + // Call the super class for displaying any common detailed info. + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff. + out << "\nNumber of rows: " << num_rows_ << "\nTedliumOp directory: " << dataset_dir_; + } +} + +Status TedliumOp::PrepareData() { + auto real_path = FileUtils::GetRealPath(dataset_dir_.c_str()); + if (!real_path.has_value()) { + RETURN_STATUS_UNEXPECTED("Invalid file, get real path failed, path=" + dataset_dir_); + } + Path root_folder(real_path.value()); + + if (release_ == "release1" || release_ == "release2") { + if (usage_ == "train" || usage_ == "test" || usage_ == "dev") { + usage_list_.push_back(usage_); + } else if (usage_ == "all") { + usage_list_ = {"train", "test", "dev"}; + } else { + RETURN_STATUS_UNEXPECTED( + "Invalid parameter, usage should be \"train\", \"test\", \"dev\" or \"all\" when " + "specify \"release1\" or \"release2\" , got " + + usage_); + } + for (int32_t i = 0; i < usage_list_.size(); ++i) { + Path stm_folder = root_folder / usage_list_[i] / "stm"; + RETURN_IF_NOT_OK(ReadStmFolderRows(stm_folder, usage_list_[i])); + } + } else if (release_ == "release3") { + if (usage_ == "all") { + Path stm_folder = root_folder / "data" / "stm"; + RETURN_IF_NOT_OK(ReadStmFolderRows(stm_folder, "data")); + } else { + RETURN_STATUS_UNEXPECTED("Invalid parameter, usage should be \"all\" when specify \"release3\" , got " + usage_); + } + } + std::sort(audio_files_.begin(), audio_files_.end()); + num_rows_ = audio_files_.size(); + if (num_rows_ == 0) { + RETURN_STATUS_UNEXPECTED( + "Invalid data, no valid data matching the dataset API TedliumDataset. Please check file path or dataset API."); + } + return Status::OK(); +} + +Status TedliumOp::ReadStmFolderRows(const Path &stm_folder, const std::string &release_usage) { + Path dir(stm_folder); + std::shared_ptr dirItr = Path::DirIterator::OpenDirectory(&dir); + if (!dir.Exists() || dirItr == nullptr) { + RETURN_STATUS_UNEXPECTED("Invalid file, failed to open folder: " + dir.ToString()); + } + MS_LOG(DEBUG) << "Tedlium " + release_ + " stm folder Path found: " << dir << "."; + while (dirItr->HasNext()) { + Path file = dirItr->Next(); + if (file.Extension() == ".stm") { + std::ifstream handle(file.ToString(), std::ios::in); + if (!handle.is_open()) { + RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + file.ToString()); + } + std::string line; + int32_t numline = 0; + while (getline(handle, line)) { + std::string filename = line.substr(0, line.find(" ")); + std::stringstream ss; + ss << numline; + audio_files_.push_back({ss.str(), filename, release_usage}); + ++numline; + } + handle.close(); + } + } + return Status::OK(); +} + +Status TedliumOp::ReadStm(const Path &file_stm_path, int32_t row_line, std::string *talk_id, std::string *speaker_id, + std::string *start_time, std::string *end_time, std::string *identifier, + std::string *transcript) { + std::ifstream handle(file_stm_path.ToString().c_str(), std::ios::in); + if (!handle.is_open()) { + RETURN_STATUS_UNEXPECTED("Invalid file, get real path failed, path=" + file_stm_path.ToString()); + } + std::string line; + int32_t i = 0; + while (i <= row_line && getline(handle, line)) { + ++i; + } + handle.close(); + std::vector temp; + i = 0; + const int32_t data_stm_number = 7; + // There are seven pieces of data in each row, which need to be read out and stored + // with a space as a separator. + // Talk_id, _, speaker_id, start_time, end_time, identifier, transcript. + // "_" is the data we don't need. + while (i < data_stm_number - 1) { + std::string s = line.substr(0, line.find(" ")); + temp.push_back(s); + line.erase(0, line.find(" ") + 1); // to delete space, so use s.find(" ") + 1. + ++i; + } + temp.push_back(line); + if (temp.size() != data_stm_number) { + RETURN_STATUS_UNEXPECTED("Invalid data, stm data was broken."); + } + + const int32_t talk_id_num = 0, speaker_id_num = 2, start_time_num = 3, end_time_num = 4, identifier_num = 5, + transcript_num = 6; + *talk_id = temp[talk_id_num]; + // temp[1] is "_", which is the data we don't need. + *speaker_id = temp[speaker_id_num]; + *start_time = temp[start_time_num]; + *end_time = temp[end_time_num]; + *identifier = temp[identifier_num]; + *transcript = temp[transcript_num]; + + return Status::OK(); +} + +Status TedliumOp::ReadSph(const Path &file_sph_path, double start_time, double end_time, int32_t *sample_rate, + std::vector *result) { + std::ifstream handle(file_sph_path.ToString().c_str(), std::ios::in | std::ios::binary); + if (!handle.is_open()) { + RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + file_sph_path.ToString()); + } + + char head[1024]; + handle.read(head, sizeof(head)); + CHECK_FAIL_RETURN_UNEXPECTED(!handle.fail(), + "Invalid data, failed to read head part from sph file: " + file_sph_path.ToString() + + ", re-download dataset(make sure the data is true)."); + std::vector vec; + for (int32_t i = 0, j = 0; i < strlen(head); ++i) { + if (head[i] == '\n' || head[i] == ' ') { + while (head[i + 1] == ' ') { + i++; + } + std::string strTemp(head + j, i - j); + vec.push_back(strTemp); + j = i + 1; + } + } + const int32_t dataToBytes = 2; + for (int32_t i = 0; i < vec.size(); ++i) { + if (vec[i] == "sample_rate") { + *sample_rate = atoi(vec[i + dataToBytes].c_str()); + } + } + + int32_t start = static_cast(start_time * (*sample_rate)); + int32_t end = static_cast(end_time * (*sample_rate)); + const int32_t size = (end - start); + std::vector temp(size * dataToBytes); + handle.seekg(start, std::ios::beg); + int32_t j = 0; + char c; + while (j < size * dataToBytes) { + handle.read(&c, 1); + CHECK_FAIL_RETURN_UNEXPECTED(!handle.fail(), + "Invalid data, failed to read data part from sph file: " + file_sph_path.ToString() + + ", re-download dataset(make sure the data is true)."); + temp.push_back(c); + ++j; + } + + const float kMaxVal = 32767.0; + for (int32_t i = 0; i < size; ++i) { + char bh = temp[2 * i]; + char bl = temp[2 * i + 1]; + // SPH audio files is big-endian, so we should convert the two bytes of data into int16_t based + // on the high 8 bits and the low 8 bits. + int16_t s = static_cast(((bh & 0x00FF) << 8) | (bl & 0x00FF)); + // Data normalization: Convert the data from the interval [-32768,32767] to the interval [-1,1]. + double t = s / kMaxVal; + (*result).push_back(t); + } + handle.close(); + + return Status::OK(); +} + +Status TedliumOp::LoadTensorRow(row_id_type row_id, TensorRow *row) { + int32_t row_line = atoi(audio_files_[row_id][0].c_str()); + std::string file_name = audio_files_[row_id][1]; + std::string file_usage_or3_none_ = audio_files_[row_id][2]; + Path dir_path(dataset_dir_); + Path file_stm_path = dir_path / file_usage_or3_none_ / "stm" / (file_name + ".stm"); + Path file_sph_path = dir_path / file_usage_or3_none_ / "sph" / (file_name + extensions_); + std::string talk_id, speaker_id, start_time, end_time, identifier, transcript; + std::vector result; + int32_t sample_rate; + RETURN_IF_NOT_OK( + ReadStm(file_stm_path, row_line, &talk_id, &speaker_id, &start_time, &end_time, &identifier, &transcript)); + RETURN_IF_NOT_OK(ReadSph(file_sph_path, atof(start_time.c_str()), atof(end_time.c_str()), &sample_rate, &result)); + + std::shared_ptr sample_rate_tensor, talk_id_tensor, speaker_id_tensor, identifier_tensor, transcript_tensor; + RETURN_IF_NOT_OK(Tensor::CreateScalar(sample_rate, &sample_rate_tensor)); + RETURN_IF_NOT_OK(Tensor::CreateScalar(talk_id, &talk_id_tensor)); + RETURN_IF_NOT_OK(Tensor::CreateScalar(speaker_id, &speaker_id_tensor)); + RETURN_IF_NOT_OK(Tensor::CreateScalar(identifier, &identifier_tensor)); + RETURN_IF_NOT_OK(Tensor::CreateScalar(transcript, &transcript_tensor)); + + std::shared_ptr audio_tensor; + RETURN_IF_NOT_OK(Tensor::CreateFromVector(result, &audio_tensor)); + RETURN_IF_NOT_OK(audio_tensor->ExpandDim(0)); + (*row) = TensorRow(row_id, {audio_tensor, sample_rate_tensor, transcript_tensor, talk_id_tensor, speaker_id_tensor, + identifier_tensor}); + row->setPath({file_sph_path.ToString(), file_sph_path.ToString(), file_stm_path.ToString(), file_stm_path.ToString(), + file_stm_path.ToString(), file_stm_path.ToString()}); + + return Status::OK(); +} + +Status TedliumOp::CountTotalRows(const std::string &dataset_dir, const std::string &release, const std::string &usage, + const std::string &extensions, int64_t *count) { + // the logic of counting the number of samples is copied from PrepareData() + RETURN_UNEXPECTED_IF_NULL(count); + *count = 0; + const int64_t num_samples = 0; + const int64_t start_index = 0; + auto new_sampler = std::make_shared(start_index, num_samples); + + // build a new unique schema object + auto new_schema = std::make_unique(); + RETURN_IF_NOT_OK( + new_schema->AddColumn(ColDescriptor("waveform", DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1))); + TensorShape sample_rate_scalar = TensorShape::CreateScalar(); + TensorShape trans_scalar = TensorShape::CreateScalar(); + TensorShape talk_id_scalar = TensorShape::CreateScalar(); + TensorShape speaker_id_scalar = TensorShape::CreateScalar(); + TensorShape identi_scalar = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK(new_schema->AddColumn( + ColDescriptor("sample_rate", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &sample_rate_scalar))); + RETURN_IF_NOT_OK(new_schema->AddColumn( + ColDescriptor("transcript", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &trans_scalar))); + RETURN_IF_NOT_OK(new_schema->AddColumn( + ColDescriptor("talk_id", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &talk_id_scalar))); + RETURN_IF_NOT_OK(new_schema->AddColumn( + ColDescriptor("speaker_id", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &speaker_id_scalar))); + RETURN_IF_NOT_OK(new_schema->AddColumn( + ColDescriptor("identifier", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &identi_scalar))); + + std::shared_ptr cfg = GlobalContext::config_manager(); + int32_t num_workers = cfg->num_parallel_workers(); + int32_t op_connect_size = cfg->op_connector_size(); + std::shared_ptr op = + std::make_shared(dataset_dir, release, usage, extensions, num_workers, std::move(new_schema), + std::move(new_sampler), op_connect_size); + RETURN_IF_NOT_OK(op->PrepareData()); + *count = static_cast(op->audio_files_.size()); + return Status::OK(); +} + +Status TedliumOp::ComputeColMap() { + if (column_name_id_map_.empty()) { + for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { + column_name_id_map_[data_schema_->Column(i).Name()] = i; + } + } else { + MS_LOG(WARNING) << "Column name map is already set!"; + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tedlium_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tedlium_op.h index 296823c7074..fd5e25034db 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tedlium_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tedlium_op.h @@ -1,126 +1,126 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_TEDLIUM_OP_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_TEDLIUM_OP_H_ - -#include -#include -#include -#include - -#include "minddata/dataset/core/tensor.h" -#include "minddata/dataset/engine/datasetops/parallel_op.h" -#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h" -#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" -#include "minddata/dataset/engine/ir/cache/dataset_cache.h" - -#include "minddata/dataset/util/status.h" -#include "minddata/dataset/util/path.h" - -namespace mindspore { -namespace dataset { -class TedliumOp : public MappableLeafOp { - public: - /// \brief Constructor. - /// \param[in] dataset_dir Directory of tedlium dataset. - /// \param[in] release Release of tedlium dataset, can be 'release1', 'release2' or 'release3'. - /// \param[in] usage Usage of this dataset, if release is release3, can be '', else 'train', 'dev', 'test' or 'all'. - /// \param[in] extensions Extensions of the sph file, only '.sph' is valid. - /// \param[in] num_parallel_workers Number of workers in parallel. - /// \param[in] data_schema Schema of dataset. - /// \param[in] sampler Sampler tells TedliumOp what to read. - /// \param[in] queue_size Connector queue size. - TedliumOp(const std::string &dataset_dir, const std::string &release, const std::string &usage, - const std::string &extensions, int32_t num_parallel_workers, std::unique_ptr data_schema, - std::shared_ptr sampler, int32_t queue_size); - - /// \brief Destructor. - ~TedliumOp() = default; - - /// \brief A print method typically used for debugging. - /// \param[in] out Out stream. - /// \param[in] show_all Whether to show all information. - void Print(std::ostream &out, bool show_all) const override; - - /// \brief Op name getter. - std::string Name() const override { return "TedliumOp"; } - - /// \brief Initialize TedliumOp related var, calls the function to walk all files. - /// \return Status The status code returned. - Status PrepareData() override; - - /// \brief Function to count the number of samples in the TEDLIUM dataset. - /// \param[in] dataset_dir Directory of tedlium dataset. - /// \param[in] release Release of tedlium dataset. - /// \param[in] usage Usage of this dataset, if release is release3, can be '', else 'train', 'dev', 'test' or 'all'. - /// \param[in] extensions Extensions of the sph file, only '.sph' is valid. - /// \param[in] count Output arg that will hold the actual dataset size. - /// \return Status The status code returned. - static Status CountTotalRows(const std::string &dataset_dir, const std::string &release, const std::string &usage, - const std::string &extensions, int64_t *count); - - private: - /// \brief Read stm file. - /// \param[in] file_stm_path The path of stm file. - /// \param[in] row_line Which line of the file we need to read. - /// \param[out] talk_id Talk identifier of the row_line in the file. - /// \param[out] speaker_id Speaker identifier of the row_line in the file. - /// \param[out] start_time Start time of the row_line in the file. - /// \param[out] end_time End time of the row_line in the file. - /// \param[out] identifier Identifier of the row_line in the file. - /// \param[out] transcript Transcript of the row_line in the file. - /// \return Status The status code returned. - Status ReadStm(const Path &file_stm_path, int32_t row_line, std::string *talk_id, std::string *speaker_id, - std::string *start_time, std::string *end_time, std::string *identifier, std::string *transcript); - - /// \brief Read sph file. - /// \param[in] file_sph_path The path of sph file. - /// \param[in] start_time The start_time of row we need to use. - /// \param[in] end_time The end_time of row we need to use. - /// \param[out] sample_rate Sample rate of the row. - /// \param[out] result Waveform result vector of the row. - /// \return Status The status code returned. - Status ReadSph(const Path &file_sph_path, double start_time, double end_time, int32_t *sample_rate, - std::vector *result); - - /// \brief Read stm files according current release`s usage. - /// \param[in] stm_folder The folder of stm files. - /// \param[in] release_usage For release1 or release2, use usage_, for release3, "data". - /// \return Status The status code returned. - Status ReadStmFolderRows(const Path &stm_folder, const std::string &release_usage); - - /// \brief Load a tensor row according to a pair. - /// \param[in] row_id Id of row need to load. - /// \param[in] row Audio & label read into this tensor row. - /// \return Status The status code returned. - Status LoadTensorRow(row_id_type row_id, TensorRow *row) override; - - /// \brief Private function for computing the assignment of the column name map. - /// \return Status The status code returned. - Status ComputeColMap() override; - - const std::string release_; - const std::string dataset_dir_; - const std::string usage_; - const std::string extensions_; - std::unique_ptr data_schema_; - - std::vector > audio_files_; - std::vector usage_list_; -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_TEDLIUM_OP_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_TEDLIUM_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_TEDLIUM_OP_H_ + +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/datasetops/parallel_op.h" +#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/dataset/engine/ir/cache/dataset_cache.h" + +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/util/path.h" + +namespace mindspore { +namespace dataset { +class TedliumOp : public MappableLeafOp { + public: + /// \brief Constructor. + /// \param[in] dataset_dir Directory of tedlium dataset. + /// \param[in] release Release of tedlium dataset, can be 'release1', 'release2' or 'release3'. + /// \param[in] usage Usage of this dataset, if release is release3, can be '', else 'train', 'dev', 'test' or 'all'. + /// \param[in] extensions Extensions of the sph file, only '.sph' is valid. + /// \param[in] num_parallel_workers Number of workers in parallel. + /// \param[in] data_schema Schema of dataset. + /// \param[in] sampler Sampler tells TedliumOp what to read. + /// \param[in] queue_size Connector queue size. + TedliumOp(const std::string &dataset_dir, const std::string &release, const std::string &usage, + const std::string &extensions, int32_t num_parallel_workers, std::unique_ptr data_schema, + std::shared_ptr sampler, int32_t queue_size); + + /// \brief Destructor. + ~TedliumOp() = default; + + /// \brief A print method typically used for debugging. + /// \param[in] out Out stream. + /// \param[in] show_all Whether to show all information. + void Print(std::ostream &out, bool show_all) const override; + + /// \brief Op name getter. + std::string Name() const override { return "TedliumOp"; } + + /// \brief Initialize TedliumOp related var, calls the function to walk all files. + /// \return Status The status code returned. + Status PrepareData() override; + + /// \brief Function to count the number of samples in the TEDLIUM dataset. + /// \param[in] dataset_dir Directory of tedlium dataset. + /// \param[in] release Release of tedlium dataset. + /// \param[in] usage Usage of this dataset, if release is release3, can be '', else 'train', 'dev', 'test' or 'all'. + /// \param[in] extensions Extensions of the sph file, only '.sph' is valid. + /// \param[in] count Output arg that will hold the actual dataset size. + /// \return Status The status code returned. + static Status CountTotalRows(const std::string &dataset_dir, const std::string &release, const std::string &usage, + const std::string &extensions, int64_t *count); + + private: + /// \brief Read stm file. + /// \param[in] file_stm_path The path of stm file. + /// \param[in] row_line Which line of the file we need to read. + /// \param[out] talk_id Talk identifier of the row_line in the file. + /// \param[out] speaker_id Speaker identifier of the row_line in the file. + /// \param[out] start_time Start time of the row_line in the file. + /// \param[out] end_time End time of the row_line in the file. + /// \param[out] identifier Identifier of the row_line in the file. + /// \param[out] transcript Transcript of the row_line in the file. + /// \return Status The status code returned. + Status ReadStm(const Path &file_stm_path, int32_t row_line, std::string *talk_id, std::string *speaker_id, + std::string *start_time, std::string *end_time, std::string *identifier, std::string *transcript); + + /// \brief Read sph file. + /// \param[in] file_sph_path The path of sph file. + /// \param[in] start_time The start_time of row we need to use. + /// \param[in] end_time The end_time of row we need to use. + /// \param[out] sample_rate Sample rate of the row. + /// \param[out] result Waveform result vector of the row. + /// \return Status The status code returned. + Status ReadSph(const Path &file_sph_path, double start_time, double end_time, int32_t *sample_rate, + std::vector *result); + + /// \brief Read stm files according current release`s usage. + /// \param[in] stm_folder The folder of stm files. + /// \param[in] release_usage For release1 or release2, use usage_, for release3, "data". + /// \return Status The status code returned. + Status ReadStmFolderRows(const Path &stm_folder, const std::string &release_usage); + + /// \brief Load a tensor row according to a pair. + /// \param[in] row_id Id of row need to load. + /// \param[in] row Audio & label read into this tensor row. + /// \return Status The status code returned. + Status LoadTensorRow(row_id_type row_id, TensorRow *row) override; + + /// \brief Private function for computing the assignment of the column name map. + /// \return Status The status code returned. + Status ComputeColMap() override; + + const std::string release_; + const std::string dataset_dir_; + const std::string usage_; + const std::string extensions_; + std::unique_ptr data_schema_; + + std::vector > audio_files_; + std::vector usage_list_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_TEDLIUM_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/usps_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/usps_op.h index 38463a8dae0..0d49634acb5 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/usps_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/usps_op.h @@ -1,137 +1,137 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_USPS_OP_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_USPS_OP_H_ - -#include -#include -#include -#include -#include -#include - -#include "minddata/dataset/core/tensor.h" -#include "minddata/dataset/engine/data_schema.h" -#include "minddata/dataset/engine/datasetops/parallel_op.h" -#include "minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h" -#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" -#include "minddata/dataset/util/path.h" -#include "minddata/dataset/util/queue.h" -#include "minddata/dataset/util/status.h" -#include "minddata/dataset/util/wait_post.h" -#include "minddata/dataset/engine/jagged_connector.h" - -namespace mindspore { -namespace dataset { -class USPSOp : public NonMappableLeafOp { - public: - // Constructor. - // @param const std::string &dataset_dir - dir directory of USPS data file. - // @param const std::string &usage - Usage of this dataset, can be 'train', 'test' or 'all'. - // @param std::unique_ptr data_schema - the schema of the USPS dataset. - // @param num_workers - number of worker threads reading data from tf_file files. - // @param worker_connector_size - size of each internal queue. - // @param num_samples - number of samples to read. - // @param op_connector_size - size of each queue in the connector that the child operator pulls from. - // @param shuffle_files - whether to shuffle the files before reading data. - // @param num_devices - number of devices. - // @param device_id - device id. - USPSOp(const std::string &dataset_dir, const std::string &usage, std::unique_ptr data_schema, - int32_t num_workers, int32_t worker_connector_size, int64_t num_samples, int32_t op_connector_size, - bool shuffle_files, int32_t num_devices, int32_t device_id); - - // Destructor. - ~USPSOp() = default; - - // Op name getter. - // @return std::string - Name of the current Op. - std::string Name() const override { return "USPSOp"; } - - // A print method typically used for debugging. - // @param std::ostream &out - out stream. - // @param bool show_all - whether to show all information. - void Print(std::ostream &out, bool show_all) const override; - - // Instantiates the internal queues and connectors - // @return Status - the error code returned. - Status Init() override; - - // Function to count the number of samples in the USPS dataset. - // @param const std::string &dir - path to the USPS directory. - // @param const std::string &usage - Usage of this dataset, can be 'train', 'test' or 'all'. - // @param int64_t *count - output arg that will hold the minimum of the actual dataset size and numSamples. - // @return Status - the error coed returned. - static Status CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count); - - // File names getter. - // @return Vector of the input file names. - std::vector FileNames() { return data_files_list_; } - - private: - // Function to count the number of samples in one data file. - // @param const std::string &data_file - path to the data file. - // @return int64_t - the count result. - int64_t CountRows(const std::string &data_file) const; - - // Reads a data file and loads the data into multiple TensorRows. - // @param data_file - the data file to read. - // @param start_offset - the start offset of file. - // @param end_offset - the end offset of file. - // @param worker_id - the id of the worker that is executing this function. - // @return Status - the error code returned. - Status LoadFile(const std::string &data_file, int64_t start_offset, int64_t end_offset, int32_t worker_id) override; - - // Parses a single row and puts the data into a tensor table. - // @param line - the content of the row. - // @param trow - image & label read into this tensor row. - // @return Status - the error code returned. - Status LoadTensor(std::string *line, TensorRow *trow); - - // Calculate number of rows in each shard. - // @return Status - the error code returned. - Status CalculateNumRowsPerShard() override; - - // Fill the IOBlockQueue. - // @param i_keys - keys of file to fill to the IOBlockQueue. - // @return Status - the error code returned. - Status FillIOBlockQueue(const std::vector &i_keys) override; - - // Get all files in the dataset_dir_. - // @return Status - The status code returned. - Status GetFiles(); - - // Parse a line to image and label. - // @param line - the content of the row. - // @param images_buffer - image destination. - // @param labels_buffer - label destination. - // @return Status - the status code returned. - Status ParseLine(std::string *line, const std::unique_ptr &images_buffer, - const std::unique_ptr &labels_buffer) const; - - // Private function for computing the assignment of the column name map. - // @return Status - the error code returned. - Status ComputeColMap() override; - - const std::string usage_; // can be "all", "train" or "test". - std::string dataset_dir_; // directory of data files. - std::unique_ptr data_schema_; - - std::vector data_files_list_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_USPS_OP_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_USPS_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_USPS_OP_H_ + +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/data_schema.h" +#include "minddata/dataset/engine/datasetops/parallel_op.h" +#include "minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/dataset/util/path.h" +#include "minddata/dataset/util/queue.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/util/wait_post.h" +#include "minddata/dataset/engine/jagged_connector.h" + +namespace mindspore { +namespace dataset { +class USPSOp : public NonMappableLeafOp { + public: + // Constructor. + // @param const std::string &dataset_dir - dir directory of USPS data file. + // @param const std::string &usage - Usage of this dataset, can be 'train', 'test' or 'all'. + // @param std::unique_ptr data_schema - the schema of the USPS dataset. + // @param num_workers - number of worker threads reading data from tf_file files. + // @param worker_connector_size - size of each internal queue. + // @param num_samples - number of samples to read. + // @param op_connector_size - size of each queue in the connector that the child operator pulls from. + // @param shuffle_files - whether to shuffle the files before reading data. + // @param num_devices - number of devices. + // @param device_id - device id. + USPSOp(const std::string &dataset_dir, const std::string &usage, std::unique_ptr data_schema, + int32_t num_workers, int32_t worker_connector_size, int64_t num_samples, int32_t op_connector_size, + bool shuffle_files, int32_t num_devices, int32_t device_id); + + // Destructor. + ~USPSOp() = default; + + // Op name getter. + // @return std::string - Name of the current Op. + std::string Name() const override { return "USPSOp"; } + + // A print method typically used for debugging. + // @param std::ostream &out - out stream. + // @param bool show_all - whether to show all information. + void Print(std::ostream &out, bool show_all) const override; + + // Instantiates the internal queues and connectors + // @return Status - the error code returned. + Status Init() override; + + // Function to count the number of samples in the USPS dataset. + // @param const std::string &dir - path to the USPS directory. + // @param const std::string &usage - Usage of this dataset, can be 'train', 'test' or 'all'. + // @param int64_t *count - output arg that will hold the minimum of the actual dataset size and numSamples. + // @return Status - the error coed returned. + static Status CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count); + + // File names getter. + // @return Vector of the input file names. + std::vector FileNames() { return data_files_list_; } + + private: + // Function to count the number of samples in one data file. + // @param const std::string &data_file - path to the data file. + // @return int64_t - the count result. + int64_t CountRows(const std::string &data_file) const; + + // Reads a data file and loads the data into multiple TensorRows. + // @param data_file - the data file to read. + // @param start_offset - the start offset of file. + // @param end_offset - the end offset of file. + // @param worker_id - the id of the worker that is executing this function. + // @return Status - the error code returned. + Status LoadFile(const std::string &data_file, int64_t start_offset, int64_t end_offset, int32_t worker_id) override; + + // Parses a single row and puts the data into a tensor table. + // @param line - the content of the row. + // @param trow - image & label read into this tensor row. + // @return Status - the error code returned. + Status LoadTensor(std::string *line, TensorRow *trow); + + // Calculate number of rows in each shard. + // @return Status - the error code returned. + Status CalculateNumRowsPerShard() override; + + // Fill the IOBlockQueue. + // @param i_keys - keys of file to fill to the IOBlockQueue. + // @return Status - the error code returned. + Status FillIOBlockQueue(const std::vector &i_keys) override; + + // Get all files in the dataset_dir_. + // @return Status - The status code returned. + Status GetFiles(); + + // Parse a line to image and label. + // @param line - the content of the row. + // @param images_buffer - image destination. + // @param labels_buffer - label destination. + // @return Status - the status code returned. + Status ParseLine(std::string *line, const std::unique_ptr &images_buffer, + const std::unique_ptr &labels_buffer) const; + + // Private function for computing the assignment of the column name map. + // @return Status - the error code returned. + Status ComputeColMap() override; + + const std::string usage_; // can be "all", "train" or "test". + std::string dataset_dir_; // directory of data files. + std::unique_ptr data_schema_; + + std::vector data_files_list_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_USPS_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/ag_news_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/ag_news_node.cc index f4499715f58..d61cd3a9584 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/ag_news_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/ag_news_node.cc @@ -1,197 +1,197 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/engine/ir/datasetops/source/ag_news_node.h" - -#include "minddata/dataset/engine/datasetops/source/ag_news_op.h" -#include "minddata/dataset/engine/datasetops/source/csv_op.h" -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { -// Constructor for AGNewsNode. -AGNewsNode::AGNewsNode(const std::string &dataset_dir, int64_t num_samples, ShuffleMode shuffle, - const std::string &usage, int32_t num_shards, int32_t shard_id, - const std::shared_ptr &cache) - : NonMappableSourceNode(std::move(cache)), - dataset_dir_(dataset_dir), - num_samples_(num_samples), - shuffle_(shuffle), - num_shards_(num_shards), - shard_id_(shard_id), - usage_(usage), - ag_news_files_list_(WalkAllFiles(usage, dataset_dir)) { - GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_); -} - -std::shared_ptr AGNewsNode::Copy() { - auto node = - std::make_shared(dataset_dir_, num_samples_, shuffle_, usage_, num_shards_, shard_id_, cache_); - (void)node->SetNumWorkers(num_workers_); - (void)node->SetConnectorQueueSize(connector_que_size_); - return node; -} - -void AGNewsNode::Print(std::ostream &out) const { - out << (Name() + "(cache: " + ((cache_ != nullptr) ? "true" : "false") + - ", num_shards: " + std::to_string(num_shards_) + ", shard_id: " + std::to_string(shard_id_) + ")"); -} - -Status AGNewsNode::ValidateParams() { - RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); - RETURN_IF_NOT_OK(ValidateDatasetDirParam("AGNewsDataset", dataset_dir_)); - RETURN_IF_NOT_OK(ValidateStringValue("AGNewsDataset", usage_, {"train", "test", "all"})); - RETURN_IF_NOT_OK(ValidateScalar("AGNewsDataset", "num_samples", num_samples_, {0}, false)); - RETURN_IF_NOT_OK(ValidateDatasetShardParams("AGNewsDataset", num_shards_, shard_id_)); - RETURN_IF_NOT_OK(ValidateEnum("AGNewsDataset", "ShuffleMode", shuffle_, - {ShuffleMode::kFalse, ShuffleMode::kFiles, ShuffleMode::kGlobal})); - - if (!column_names_.empty()) { - RETURN_IF_NOT_OK(ValidateDatasetColumnParam("AGNewsDataset", "column_names", column_names_)); - } - return Status::OK(); -} - -// Function to build AGNewsNode. -Status AGNewsNode::Build(std::vector> *const node_ops) { - bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); - // Sort the dataset files in a lexicographical order. - std::vector sorted_dataset_files = ag_news_files_list_; - std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end()); - // Because AGNews does not have external column_defaults nor column_names parameters, - // they need to be set before AGNewsOp is initialized. - // AGNews data set is formatted as three columns of data, so three columns are added. - std::vector> column_default; - column_default.push_back(std::make_shared>(AGNewsOp::STRING, "")); - column_default.push_back(std::make_shared>(AGNewsOp::STRING, "")); - column_default.push_back(std::make_shared>(AGNewsOp::STRING, "")); - std::vector column_name = {"index", "title", "description"}; - // AGNews data values are always delimited by a comma. - char field_delim_ = ','; - std::shared_ptr ag_news_op = - std::make_shared(num_workers_, num_samples_, worker_connector_size_, connector_que_size_, shuffle_files, - num_shards_, shard_id_, field_delim_, column_default, column_name, sorted_dataset_files); - RETURN_IF_NOT_OK(ag_news_op->Init()); - if (shuffle_ == ShuffleMode::kGlobal) { - // Inject ShuffleOp. - std::shared_ptr shuffle_op = nullptr; - int64_t num_rows = 0; - // First, get the number of rows in the dataset. - RETURN_IF_NOT_OK(AGNewsOp::CountAllFileRows(ag_news_files_list_, false, &num_rows)); - // Add the shuffle op after this op. - RETURN_IF_NOT_OK( - AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op)); - shuffle_op->SetTotalRepeats(GetTotalRepeats()); - shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); - shuffle_op->Skip(skip_steps_); - node_ops->push_back(shuffle_op); - } - ag_news_op->SetTotalRepeats(GetTotalRepeats()); - ag_news_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); - node_ops->push_back(ag_news_op); - return Status::OK(); -} - -// Get the shard id of node. -Status AGNewsNode::GetShardId(int32_t *shard_id) { - *shard_id = shard_id_; - return Status::OK(); -} - -// Get Dataset size. -Status AGNewsNode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, - int64_t *dataset_size) { - if (dataset_size_ > 0) { - *dataset_size = dataset_size_; - return Status::OK(); - } - - int64_t num_rows, sample_size; - RETURN_IF_NOT_OK(AGNewsOp::CountAllFileRows(ag_news_files_list_, false, &num_rows)); - sample_size = num_samples_; - num_rows = static_cast(ceil(num_rows / (1.0 * num_shards_))); - *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows; - dataset_size_ = *dataset_size; - return Status::OK(); -} - -Status AGNewsNode::to_json(nlohmann::json *out_json) { - nlohmann::json args; - args["num_parallel_workers"] = num_workers_; - args["connector_queue_size"] = connector_que_size_; - args["dataset_dir"] = dataset_dir_; - args["usage"] = usage_; - args["num_samples"] = num_samples_; - args["shuffle"] = shuffle_; - args["num_shards"] = num_shards_; - args["shard_id"] = shard_id_; - if (cache_ != nullptr) { - nlohmann::json cache_args; - RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); - args["cache"] = cache_args; - } - *out_json = args; - return Status::OK(); -} - -// Note: The following two functions are common among NonMappableSourceNode and -// should be promoted to its parent class. AGNews (for which internally is based off CSV) -// by itself is a non-mappable dataset that does not support sampling. -// However, if a cache operator is injected at some other place higher in the tree, -// that cache can inherit this sampler from the leaf, providing sampling support from -// the caching layer. -// Should be promoted to its parent class. -// That is why we setup the sampler for a leaf node that does not use sampling. -Status AGNewsNode::SetupSamplerForCache(std::shared_ptr *sampler) { - bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); - *sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_); - return Status::OK(); -} - -// If a cache has been added into the ascendant tree over this AGNews node, then -// the cache will be executing a sampler for fetching the data. As such, any -// options in the AGNews node need to be reset to its defaults so that this -// AGNews node will produce the full set of data into the cache. -Status AGNewsNode::MakeSimpleProducer() { - shard_id_ = 0; - num_shards_ = 1; - shuffle_ = ShuffleMode::kFalse; - num_samples_ = 0; - return Status::OK(); -} - -std::vector AGNewsNode::WalkAllFiles(const std::string &usage, const std::string &dataset_dir) { - std::vector ag_news_files_list; - Path train_prefix("train.csv"); - Path test_prefix("test.csv"); - Path dir(dataset_dir); - - if (usage == "train") { - Path temp_path = dir / train_prefix; - ag_news_files_list.push_back(temp_path.ToString()); - } else if (usage == "test") { - Path temp_path = dir / test_prefix; - ag_news_files_list.push_back(temp_path.ToString()); - } else { - Path temp_path = dir / train_prefix; - ag_news_files_list.push_back(temp_path.ToString()); - Path temp_path1 = dir / test_prefix; - ag_news_files_list.push_back(temp_path1.ToString()); - } - return ag_news_files_list; -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/engine/ir/datasetops/source/ag_news_node.h" + +#include "minddata/dataset/engine/datasetops/source/ag_news_op.h" +#include "minddata/dataset/engine/datasetops/source/csv_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +// Constructor for AGNewsNode. +AGNewsNode::AGNewsNode(const std::string &dataset_dir, int64_t num_samples, ShuffleMode shuffle, + const std::string &usage, int32_t num_shards, int32_t shard_id, + const std::shared_ptr &cache) + : NonMappableSourceNode(std::move(cache)), + dataset_dir_(dataset_dir), + num_samples_(num_samples), + shuffle_(shuffle), + num_shards_(num_shards), + shard_id_(shard_id), + usage_(usage), + ag_news_files_list_(WalkAllFiles(usage, dataset_dir)) { + GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_); +} + +std::shared_ptr AGNewsNode::Copy() { + auto node = + std::make_shared(dataset_dir_, num_samples_, shuffle_, usage_, num_shards_, shard_id_, cache_); + (void)node->SetNumWorkers(num_workers_); + (void)node->SetConnectorQueueSize(connector_que_size_); + return node; +} + +void AGNewsNode::Print(std::ostream &out) const { + out << (Name() + "(cache: " + ((cache_ != nullptr) ? "true" : "false") + + ", num_shards: " + std::to_string(num_shards_) + ", shard_id: " + std::to_string(shard_id_) + ")"); +} + +Status AGNewsNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); + RETURN_IF_NOT_OK(ValidateDatasetDirParam("AGNewsDataset", dataset_dir_)); + RETURN_IF_NOT_OK(ValidateStringValue("AGNewsDataset", usage_, {"train", "test", "all"})); + RETURN_IF_NOT_OK(ValidateScalar("AGNewsDataset", "num_samples", num_samples_, {0}, false)); + RETURN_IF_NOT_OK(ValidateDatasetShardParams("AGNewsDataset", num_shards_, shard_id_)); + RETURN_IF_NOT_OK(ValidateEnum("AGNewsDataset", "ShuffleMode", shuffle_, + {ShuffleMode::kFalse, ShuffleMode::kFiles, ShuffleMode::kGlobal})); + + if (!column_names_.empty()) { + RETURN_IF_NOT_OK(ValidateDatasetColumnParam("AGNewsDataset", "column_names", column_names_)); + } + return Status::OK(); +} + +// Function to build AGNewsNode. +Status AGNewsNode::Build(std::vector> *const node_ops) { + bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); + // Sort the dataset files in a lexicographical order. + std::vector sorted_dataset_files = ag_news_files_list_; + std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end()); + // Because AGNews does not have external column_defaults nor column_names parameters, + // they need to be set before AGNewsOp is initialized. + // AGNews data set is formatted as three columns of data, so three columns are added. + std::vector> column_default; + column_default.push_back(std::make_shared>(AGNewsOp::STRING, "")); + column_default.push_back(std::make_shared>(AGNewsOp::STRING, "")); + column_default.push_back(std::make_shared>(AGNewsOp::STRING, "")); + std::vector column_name = {"index", "title", "description"}; + // AGNews data values are always delimited by a comma. + char field_delim_ = ','; + std::shared_ptr ag_news_op = + std::make_shared(num_workers_, num_samples_, worker_connector_size_, connector_que_size_, shuffle_files, + num_shards_, shard_id_, field_delim_, column_default, column_name, sorted_dataset_files); + RETURN_IF_NOT_OK(ag_news_op->Init()); + if (shuffle_ == ShuffleMode::kGlobal) { + // Inject ShuffleOp. + std::shared_ptr shuffle_op = nullptr; + int64_t num_rows = 0; + // First, get the number of rows in the dataset. + RETURN_IF_NOT_OK(AGNewsOp::CountAllFileRows(ag_news_files_list_, false, &num_rows)); + // Add the shuffle op after this op. + RETURN_IF_NOT_OK( + AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op)); + shuffle_op->SetTotalRepeats(GetTotalRepeats()); + shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); + shuffle_op->Skip(skip_steps_); + node_ops->push_back(shuffle_op); + } + ag_news_op->SetTotalRepeats(GetTotalRepeats()); + ag_news_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); + node_ops->push_back(ag_news_op); + return Status::OK(); +} + +// Get the shard id of node. +Status AGNewsNode::GetShardId(int32_t *shard_id) { + *shard_id = shard_id_; + return Status::OK(); +} + +// Get Dataset size. +Status AGNewsNode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + + int64_t num_rows, sample_size; + RETURN_IF_NOT_OK(AGNewsOp::CountAllFileRows(ag_news_files_list_, false, &num_rows)); + sample_size = num_samples_; + num_rows = static_cast(ceil(num_rows / (1.0 * num_shards_))); + *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows; + dataset_size_ = *dataset_size; + return Status::OK(); +} + +Status AGNewsNode::to_json(nlohmann::json *out_json) { + nlohmann::json args; + args["num_parallel_workers"] = num_workers_; + args["connector_queue_size"] = connector_que_size_; + args["dataset_dir"] = dataset_dir_; + args["usage"] = usage_; + args["num_samples"] = num_samples_; + args["shuffle"] = shuffle_; + args["num_shards"] = num_shards_; + args["shard_id"] = shard_id_; + if (cache_ != nullptr) { + nlohmann::json cache_args; + RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); + args["cache"] = cache_args; + } + *out_json = args; + return Status::OK(); +} + +// Note: The following two functions are common among NonMappableSourceNode and +// should be promoted to its parent class. AGNews (for which internally is based off CSV) +// by itself is a non-mappable dataset that does not support sampling. +// However, if a cache operator is injected at some other place higher in the tree, +// that cache can inherit this sampler from the leaf, providing sampling support from +// the caching layer. +// Should be promoted to its parent class. +// That is why we setup the sampler for a leaf node that does not use sampling. +Status AGNewsNode::SetupSamplerForCache(std::shared_ptr *sampler) { + bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); + *sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_); + return Status::OK(); +} + +// If a cache has been added into the ascendant tree over this AGNews node, then +// the cache will be executing a sampler for fetching the data. As such, any +// options in the AGNews node need to be reset to its defaults so that this +// AGNews node will produce the full set of data into the cache. +Status AGNewsNode::MakeSimpleProducer() { + shard_id_ = 0; + num_shards_ = 1; + shuffle_ = ShuffleMode::kFalse; + num_samples_ = 0; + return Status::OK(); +} + +std::vector AGNewsNode::WalkAllFiles(const std::string &usage, const std::string &dataset_dir) { + std::vector ag_news_files_list; + Path train_prefix("train.csv"); + Path test_prefix("test.csv"); + Path dir(dataset_dir); + + if (usage == "train") { + Path temp_path = dir / train_prefix; + ag_news_files_list.push_back(temp_path.ToString()); + } else if (usage == "test") { + Path temp_path = dir / test_prefix; + ag_news_files_list.push_back(temp_path.ToString()); + } else { + Path temp_path = dir / train_prefix; + ag_news_files_list.push_back(temp_path.ToString()); + Path temp_path1 = dir / test_prefix; + ag_news_files_list.push_back(temp_path1.ToString()); + } + return ag_news_files_list; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/ag_news_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/ag_news_node.h index 4c2cb87b430..afd28a7703a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/ag_news_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/ag_news_node.h @@ -1,126 +1,126 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_AG_NEWS_NODE_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_AG_NEWS_NODE_H_ - -#include -#include -#include - -#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" - -namespace mindspore { -namespace dataset { -/// \brief class AGNewsNode. -/// \brief Dataset derived class to represent AGNews dataset. -class AGNewsNode : public NonMappableSourceNode { - public: - /// \brief Constructor. - AGNewsNode(const std::string &dataset_dir, int64_t num_samples, ShuffleMode shuffle, const std::string &usage, - int32_t num_shards, int32_t shard_id, const std::shared_ptr &cache); - - /// \brief Destructor. - ~AGNewsNode() override = default; - - /// \brief Node name getter. - /// \return Name of the current node. - std::string Name() const override { return kAGNewsNode; } - - /// \brief Print the description. - /// \param[in] out The output stream to write output to. - void Print(std::ostream &out) const override; - - /// \brief Copy the node to a new object. - /// \return A shared pointer to the new copy. - std::shared_ptr Copy() override; - - /// \brief A base class override function to create the required runtime dataset op objects for this class. - /// \param[in] node_ops A vector containing shared pointer to the Dataset Ops that this object will create. - /// \return Status Status::OK() if build successfully. - Status Build(std::vector> *const node_ops) override; - - /// \brief Parameters validation. - /// \return Status Status::OK() if all the parameters are valid. - Status ValidateParams() override; - - /// \brief Get the shard id of node. - /// \param[in] shard_id The shard id. - /// \return Status Status::OK() if get shard id successfully. - Status GetShardId(int32_t *shard_id) override; - - /// \brief Getter functions. - const std::string &DatasetDir() const { return dataset_dir_; } - const std::string &Usage() const { return usage_; } - int64_t NumSamples() const { return num_samples_; } - ShuffleMode Shuffle() const { return shuffle_; } - int32_t NumShards() const { return num_shards_; } - int32_t ShardId() const { return shard_id_; } - - /// \brief Base-class override for GetDatasetSize. - /// \param[in] size_getter Shared pointer to DatasetSizeGetter. - /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting - /// dataset size at the expense of accuracy. - /// \param[out] dataset_size the size of the dataset. - /// \return Status of the function. - Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, - int64_t *dataset_size) override; - - /// \brief Get the arguments of node - /// \param[out] out_json JSON string of all attributes - /// \return Status of the function - Status to_json(nlohmann::json *out_json) override; - - /// \brief AGNews by itself is a non-mappable dataset that does not support sampling. - /// However, if a cache operator is injected at some other place higher in - /// the tree, that cache can inherit this sampler from the leaf, providing - /// sampling support from the caching layer. That is why we setup the - /// sampler for a leaf node that does not use sampling. Note: This - /// function is common among NonMappableSourceNode and should be promoted - /// to its parent class. - /// \param[in] sampler The sampler to setup. - /// \return Status of the function. - Status SetupSamplerForCache(std::shared_ptr *sampler) override; - - /// \brief If a cache has been added into the ascendant tree over this ag_news node, - /// then the cache will be executing a sampler for fetching the data. - /// As such, any options in the AGNews node need to be reset to its defaults - /// so that this AGNews node will produce the full set of data into the cache. - /// Note: This function is common among NonMappableSourceNode and should be promoted to its - /// parent class. - /// \return Status of the function. - Status MakeSimpleProducer() override; - - /// \brief Generate a list of read file names according to usage. - /// \param[in] usage Part of dataset of AGNews. - /// \param[in] dataset_dir Path to the root directory that contains the dataset. - /// \return std::vector A list of read file names. - std::vector WalkAllFiles(const std::string &usage, const std::string &dataset_dir); - - private: - std::string dataset_dir_; - std::string usage_; - std::vector> column_defaults_; - std::vector column_names_; - int64_t num_samples_; - ShuffleMode shuffle_; - int32_t num_shards_; - int32_t shard_id_; - std::vector ag_news_files_list_; -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_AG_NEWS_NODE_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_AG_NEWS_NODE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_AG_NEWS_NODE_H_ + +#include +#include +#include + +#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" + +namespace mindspore { +namespace dataset { +/// \brief class AGNewsNode. +/// \brief Dataset derived class to represent AGNews dataset. +class AGNewsNode : public NonMappableSourceNode { + public: + /// \brief Constructor. + AGNewsNode(const std::string &dataset_dir, int64_t num_samples, ShuffleMode shuffle, const std::string &usage, + int32_t num_shards, int32_t shard_id, const std::shared_ptr &cache); + + /// \brief Destructor. + ~AGNewsNode() override = default; + + /// \brief Node name getter. + /// \return Name of the current node. + std::string Name() const override { return kAGNewsNode; } + + /// \brief Print the description. + /// \param[in] out The output stream to write output to. + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object. + /// \return A shared pointer to the new copy. + std::shared_ptr Copy() override; + + /// \brief A base class override function to create the required runtime dataset op objects for this class. + /// \param[in] node_ops A vector containing shared pointer to the Dataset Ops that this object will create. + /// \return Status Status::OK() if build successfully. + Status Build(std::vector> *const node_ops) override; + + /// \brief Parameters validation. + /// \return Status Status::OK() if all the parameters are valid. + Status ValidateParams() override; + + /// \brief Get the shard id of node. + /// \param[in] shard_id The shard id. + /// \return Status Status::OK() if get shard id successfully. + Status GetShardId(int32_t *shard_id) override; + + /// \brief Getter functions. + const std::string &DatasetDir() const { return dataset_dir_; } + const std::string &Usage() const { return usage_; } + int64_t NumSamples() const { return num_samples_; } + ShuffleMode Shuffle() const { return shuffle_; } + int32_t NumShards() const { return num_shards_; } + int32_t ShardId() const { return shard_id_; } + + /// \brief Base-class override for GetDatasetSize. + /// \param[in] size_getter Shared pointer to DatasetSizeGetter. + /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting + /// dataset size at the expense of accuracy. + /// \param[out] dataset_size the size of the dataset. + /// \return Status of the function. + Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) override; + + /// \brief Get the arguments of node + /// \param[out] out_json JSON string of all attributes + /// \return Status of the function + Status to_json(nlohmann::json *out_json) override; + + /// \brief AGNews by itself is a non-mappable dataset that does not support sampling. + /// However, if a cache operator is injected at some other place higher in + /// the tree, that cache can inherit this sampler from the leaf, providing + /// sampling support from the caching layer. That is why we setup the + /// sampler for a leaf node that does not use sampling. Note: This + /// function is common among NonMappableSourceNode and should be promoted + /// to its parent class. + /// \param[in] sampler The sampler to setup. + /// \return Status of the function. + Status SetupSamplerForCache(std::shared_ptr *sampler) override; + + /// \brief If a cache has been added into the ascendant tree over this ag_news node, + /// then the cache will be executing a sampler for fetching the data. + /// As such, any options in the AGNews node need to be reset to its defaults + /// so that this AGNews node will produce the full set of data into the cache. + /// Note: This function is common among NonMappableSourceNode and should be promoted to its + /// parent class. + /// \return Status of the function. + Status MakeSimpleProducer() override; + + /// \brief Generate a list of read file names according to usage. + /// \param[in] usage Part of dataset of AGNews. + /// \param[in] dataset_dir Path to the root directory that contains the dataset. + /// \return std::vector A list of read file names. + std::vector WalkAllFiles(const std::string &usage, const std::string &dataset_dir); + + private: + std::string dataset_dir_; + std::string usage_; + std::vector> column_defaults_; + std::vector column_names_; + int64_t num_samples_; + ShuffleMode shuffle_; + int32_t num_shards_; + int32_t shard_id_; + std::vector ag_news_files_list_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_AG_NEWS_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/amazon_review_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/amazon_review_node.cc old mode 100755 new mode 100644 index a8e35a4fc2b..122252a1e0c --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/amazon_review_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/amazon_review_node.cc @@ -1,193 +1,193 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "minddata/dataset/engine/ir/datasetops/source/amazon_review_node.h" - -namespace mindspore { -namespace dataset { -// Constructor for AmazonReviewNode -AmazonReviewNode::AmazonReviewNode(const std::string &dataset_dir, const std::string &usage, int64_t num_samples, - ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, - const std::shared_ptr &cache) - : NonMappableSourceNode(std::move(cache)), - dataset_dir_(dataset_dir), - num_samples_(num_samples), - shuffle_(shuffle), - num_shards_(num_shards), - shard_id_(shard_id), - usage_(usage), - amazon_review_files_list_(WalkAllFiles(usage, dataset_dir)) { - // Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. - // User discretion is advised. Auto_num_worker_pass is currently an experimental feature which can still work - // if the num_shards_ isn't 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to - // return num_shards. Once PreBuildSampler is phased out, this can be cleaned up. - GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_); -} - -std::shared_ptr AmazonReviewNode::Copy() { - auto node = - std::make_shared(dataset_dir_, usage_, num_samples_, shuffle_, num_shards_, shard_id_, cache_); - (void)node->SetNumWorkers(num_workers_); - (void)node->SetConnectorQueueSize(connector_que_size_); - return node; -} - -void AmazonReviewNode::Print(std::ostream &out) const { - out << (Name() + "(cache: " + ((cache_ != nullptr) ? "true" : "false") + - ", num_shards: " + std::to_string(num_shards_) + ", shard_id: " + std::to_string(shard_id_) + ")"); -} - -Status AmazonReviewNode::ValidateParams() { - RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); - RETURN_IF_NOT_OK(ValidateDatasetDirParam("AmazonReviewDataset", dataset_dir_)); - RETURN_IF_NOT_OK(ValidateStringValue("AmazonReviewDataset", usage_, {"train", "test", "all"})); - RETURN_IF_NOT_OK(ValidateDatasetFilesParam("AmazonReviewDataset", amazon_review_files_list_)); - RETURN_IF_NOT_OK(ValidateScalar("AmazonReviewDataset", "num_samples", num_samples_, {0}, false)); - RETURN_IF_NOT_OK(ValidateEnum("AmazonReviewDataset", "ShuffleMode", shuffle_, - {ShuffleMode::kFalse, ShuffleMode::kFiles, ShuffleMode::kGlobal})); - - RETURN_IF_NOT_OK(ValidateDatasetShardParams("AmazonReviewDataset", num_shards_, shard_id_)); - return Status::OK(); -} - -Status AmazonReviewNode::Build(std::vector> *const node_ops) { - bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); - - // Sort the dataset files in a lexicographical order. - std::vector sorted_dataset_files = amazon_review_files_list_; - std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end()); - - std::vector> column_default; - column_default.push_back(std::make_shared>(AmazonReviewOp::STRING, "")); - column_default.push_back(std::make_shared>(AmazonReviewOp::STRING, "")); - column_default.push_back(std::make_shared>(AmazonReviewOp::STRING, "")); - - std::vector column_name = {"label", "title", "content"}; - char field_delim = ','; - std::shared_ptr amazon_review_op = std::make_shared( - num_workers_, num_samples_, worker_connector_size_, connector_que_size_, shuffle_files, num_shards_, shard_id_, - field_delim, column_default, column_name, sorted_dataset_files); - RETURN_IF_NOT_OK(amazon_review_op->Init()); - - // If a global shuffle is used for AmazonReview, it will inject a shuffle op over the AmazonReview. - // But, if there is a cache in the tree, we do not need the global shuffle and the shuffle op should not be - // built.This is achieved in the cache transform pass where we call MakeSimpleProducer to reset AmazonReview's - // shuffle option to false. - if (shuffle_ == ShuffleMode::kGlobal) { - // Inject ShuffleOp. - std::shared_ptr shuffle_op = nullptr; - int64_t num_rows = 0; - - // First, get the number of rows in the dataset. - RETURN_IF_NOT_OK(AmazonReviewOp::CountAllFileRows(sorted_dataset_files, false, &num_rows)); - // Add the shuffle op after this op. - RETURN_IF_NOT_OK( - AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op)); - shuffle_op->SetTotalRepeats(GetTotalRepeats()); - shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); - shuffle_op->Skip(skip_steps_); - node_ops->push_back(shuffle_op); - } - amazon_review_op->SetTotalRepeats(GetTotalRepeats()); - amazon_review_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); - node_ops->push_back(amazon_review_op); - return Status::OK(); -} - -Status AmazonReviewNode::GetShardId(int32_t *shard_id) { - *shard_id = shard_id_; - return Status::OK(); -} - -// Get Dataset size -Status AmazonReviewNode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, - int64_t *dataset_size) { - if (dataset_size_ > 0) { - *dataset_size = dataset_size_; - return Status::OK(); - } - - int64_t num_rows, sample_size; - RETURN_IF_NOT_OK(AmazonReviewOp::CountAllFileRows(amazon_review_files_list_, false, &num_rows)); - sample_size = num_samples_; - num_rows = static_cast(ceil(num_rows / (1.0 * num_shards_))); - *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows; - dataset_size_ = *dataset_size; - return Status::OK(); -} - -Status AmazonReviewNode::to_json(nlohmann::json *out_json) { - nlohmann::json args; - args["num_parallel_workers"] = num_workers_; - args["connector_queue_size"] = connector_que_size_; - args["dataset_dir"] = dataset_dir_; - args["usage"] = usage_; - args["num_samples"] = num_samples_; - args["shuffle"] = shuffle_; - args["num_shards"] = num_shards_; - args["shard_id"] = shard_id_; - if (cache_ != nullptr) { - nlohmann::json cache_args; - RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); - args["cache"] = cache_args; - } - *out_json = args; - return Status::OK(); -} - -// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent -// class. AmazonReview by itself is a non-mappable dataset that does not support sampling. However, if a cache -// operator is injected at some other place higher in the tree, that cache can inherit this sampler from the leaf, -// providing sampling support from the caching layer. That is why we setup the sampler for a leaf node that does not -// use sampling. -Status AmazonReviewNode::SetupSamplerForCache(std::shared_ptr *sampler) { - bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); - *sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_); - return Status::OK(); -} - -// If a cache has been added into the ascendant tree over this AmazonReview node, then the cache will be executing -// a sampler for fetching the data. As such, any options in the AmazonReview node need to be reset to its defaults so -// If a cache has been added into the ascendant tree over this AmazonReview node, then the cache will be executing -Status AmazonReviewNode::MakeSimpleProducer() { - shard_id_ = 0; - num_shards_ = 1; - shuffle_ = ShuffleMode::kFalse; - num_samples_ = 0; - return Status::OK(); -} - -std::vector AmazonReviewNode::WalkAllFiles(const std::string &usage, const std::string &dataset_dir) { - std::vector amazon_review_files_list; - Path train_prefix("train.csv"); - Path test_prefix("test.csv"); - Path dir(dataset_dir); - - if (usage == "train") { - Path temp_path = dir / train_prefix; - amazon_review_files_list.push_back(temp_path.ToString()); - } else if (usage == "test") { - Path temp_path = dir / test_prefix; - amazon_review_files_list.push_back(temp_path.ToString()); - } else { - Path temp_path = dir / train_prefix; - amazon_review_files_list.push_back(temp_path.ToString()); - Path temp_path1 = dir / test_prefix; - amazon_review_files_list.push_back(temp_path1.ToString()); - } - return amazon_review_files_list; -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/ir/datasetops/source/amazon_review_node.h" + +namespace mindspore { +namespace dataset { +// Constructor for AmazonReviewNode +AmazonReviewNode::AmazonReviewNode(const std::string &dataset_dir, const std::string &usage, int64_t num_samples, + ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, + const std::shared_ptr &cache) + : NonMappableSourceNode(std::move(cache)), + dataset_dir_(dataset_dir), + num_samples_(num_samples), + shuffle_(shuffle), + num_shards_(num_shards), + shard_id_(shard_id), + usage_(usage), + amazon_review_files_list_(WalkAllFiles(usage, dataset_dir)) { + // Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. + // User discretion is advised. Auto_num_worker_pass is currently an experimental feature which can still work + // if the num_shards_ isn't 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to + // return num_shards. Once PreBuildSampler is phased out, this can be cleaned up. + GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_); +} + +std::shared_ptr AmazonReviewNode::Copy() { + auto node = + std::make_shared(dataset_dir_, usage_, num_samples_, shuffle_, num_shards_, shard_id_, cache_); + (void)node->SetNumWorkers(num_workers_); + (void)node->SetConnectorQueueSize(connector_que_size_); + return node; +} + +void AmazonReviewNode::Print(std::ostream &out) const { + out << (Name() + "(cache: " + ((cache_ != nullptr) ? "true" : "false") + + ", num_shards: " + std::to_string(num_shards_) + ", shard_id: " + std::to_string(shard_id_) + ")"); +} + +Status AmazonReviewNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); + RETURN_IF_NOT_OK(ValidateDatasetDirParam("AmazonReviewDataset", dataset_dir_)); + RETURN_IF_NOT_OK(ValidateStringValue("AmazonReviewDataset", usage_, {"train", "test", "all"})); + RETURN_IF_NOT_OK(ValidateDatasetFilesParam("AmazonReviewDataset", amazon_review_files_list_)); + RETURN_IF_NOT_OK(ValidateScalar("AmazonReviewDataset", "num_samples", num_samples_, {0}, false)); + RETURN_IF_NOT_OK(ValidateEnum("AmazonReviewDataset", "ShuffleMode", shuffle_, + {ShuffleMode::kFalse, ShuffleMode::kFiles, ShuffleMode::kGlobal})); + + RETURN_IF_NOT_OK(ValidateDatasetShardParams("AmazonReviewDataset", num_shards_, shard_id_)); + return Status::OK(); +} + +Status AmazonReviewNode::Build(std::vector> *const node_ops) { + bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); + + // Sort the dataset files in a lexicographical order. + std::vector sorted_dataset_files = amazon_review_files_list_; + std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end()); + + std::vector> column_default; + column_default.push_back(std::make_shared>(AmazonReviewOp::STRING, "")); + column_default.push_back(std::make_shared>(AmazonReviewOp::STRING, "")); + column_default.push_back(std::make_shared>(AmazonReviewOp::STRING, "")); + + std::vector column_name = {"label", "title", "content"}; + char field_delim = ','; + std::shared_ptr amazon_review_op = std::make_shared( + num_workers_, num_samples_, worker_connector_size_, connector_que_size_, shuffle_files, num_shards_, shard_id_, + field_delim, column_default, column_name, sorted_dataset_files); + RETURN_IF_NOT_OK(amazon_review_op->Init()); + + // If a global shuffle is used for AmazonReview, it will inject a shuffle op over the AmazonReview. + // But, if there is a cache in the tree, we do not need the global shuffle and the shuffle op should not be + // built.This is achieved in the cache transform pass where we call MakeSimpleProducer to reset AmazonReview's + // shuffle option to false. + if (shuffle_ == ShuffleMode::kGlobal) { + // Inject ShuffleOp. + std::shared_ptr shuffle_op = nullptr; + int64_t num_rows = 0; + + // First, get the number of rows in the dataset. + RETURN_IF_NOT_OK(AmazonReviewOp::CountAllFileRows(sorted_dataset_files, false, &num_rows)); + // Add the shuffle op after this op. + RETURN_IF_NOT_OK( + AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op)); + shuffle_op->SetTotalRepeats(GetTotalRepeats()); + shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); + shuffle_op->Skip(skip_steps_); + node_ops->push_back(shuffle_op); + } + amazon_review_op->SetTotalRepeats(GetTotalRepeats()); + amazon_review_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); + node_ops->push_back(amazon_review_op); + return Status::OK(); +} + +Status AmazonReviewNode::GetShardId(int32_t *shard_id) { + *shard_id = shard_id_; + return Status::OK(); +} + +// Get Dataset size +Status AmazonReviewNode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + + int64_t num_rows, sample_size; + RETURN_IF_NOT_OK(AmazonReviewOp::CountAllFileRows(amazon_review_files_list_, false, &num_rows)); + sample_size = num_samples_; + num_rows = static_cast(ceil(num_rows / (1.0 * num_shards_))); + *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows; + dataset_size_ = *dataset_size; + return Status::OK(); +} + +Status AmazonReviewNode::to_json(nlohmann::json *out_json) { + nlohmann::json args; + args["num_parallel_workers"] = num_workers_; + args["connector_queue_size"] = connector_que_size_; + args["dataset_dir"] = dataset_dir_; + args["usage"] = usage_; + args["num_samples"] = num_samples_; + args["shuffle"] = shuffle_; + args["num_shards"] = num_shards_; + args["shard_id"] = shard_id_; + if (cache_ != nullptr) { + nlohmann::json cache_args; + RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); + args["cache"] = cache_args; + } + *out_json = args; + return Status::OK(); +} + +// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent +// class. AmazonReview by itself is a non-mappable dataset that does not support sampling. However, if a cache +// operator is injected at some other place higher in the tree, that cache can inherit this sampler from the leaf, +// providing sampling support from the caching layer. That is why we setup the sampler for a leaf node that does not +// use sampling. +Status AmazonReviewNode::SetupSamplerForCache(std::shared_ptr *sampler) { + bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); + *sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_); + return Status::OK(); +} + +// If a cache has been added into the ascendant tree over this AmazonReview node, then the cache will be executing +// a sampler for fetching the data. As such, any options in the AmazonReview node need to be reset to its defaults so +// If a cache has been added into the ascendant tree over this AmazonReview node, then the cache will be executing +Status AmazonReviewNode::MakeSimpleProducer() { + shard_id_ = 0; + num_shards_ = 1; + shuffle_ = ShuffleMode::kFalse; + num_samples_ = 0; + return Status::OK(); +} + +std::vector AmazonReviewNode::WalkAllFiles(const std::string &usage, const std::string &dataset_dir) { + std::vector amazon_review_files_list; + Path train_prefix("train.csv"); + Path test_prefix("test.csv"); + Path dir(dataset_dir); + + if (usage == "train") { + Path temp_path = dir / train_prefix; + amazon_review_files_list.push_back(temp_path.ToString()); + } else if (usage == "test") { + Path temp_path = dir / test_prefix; + amazon_review_files_list.push_back(temp_path.ToString()); + } else { + Path temp_path = dir / train_prefix; + amazon_review_files_list.push_back(temp_path.ToString()); + Path temp_path1 = dir / test_prefix; + amazon_review_files_list.push_back(temp_path1.ToString()); + } + return amazon_review_files_list; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/amazon_review_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/amazon_review_node.h old mode 100755 new mode 100644 index 6696edcd700..208e6e64f59 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/amazon_review_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/amazon_review_node.h @@ -1,120 +1,120 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_AMAZON_REVIEW_NODE_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_AMAZON_REVIEW_NODE_H_ - -#include -#include -#include - -#include "minddata/dataset/engine/datasetops/source/amazon_review_op.h" -#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" - -namespace mindspore { -namespace dataset { -class AmazonReviewNode : public NonMappableSourceNode { - public: - /// \brief Constructor. - AmazonReviewNode(const std::string &dataset_dir, const std::string &usage, int64_t num_samples, ShuffleMode shuffle, - int32_t num_shards, int32_t shard_id, const std::shared_ptr &cache); - - /// \brief Destructor. - ~AmazonReviewNode() override = default; - - /// \brief Node name getter. - /// \return Name of the current node. - std::string Name() const override { return kAmazonReviewNode; } - - /// \brief Print the description. - /// \param[out] out The output stream to write output to. - void Print(std::ostream &out) const override; - - /// \brief Copy the node to a new object. - /// \return A shared pointer to the new copy. - std::shared_ptr Copy() override; - - /// \brief A base class override function to create the required runtime dataset op objects for this class. - /// \param node_ops A vector containing shared pointer to the Dataset Ops that this object will create. - /// \return Status Status::OK() if build successfully. - Status Build(std::vector> *const node_ops) override; - - /// \brief Parameters validation. - /// \return Status Status::OK() if all the parameters are valid. - Status ValidateParams() override; - - /// \brief Get the shard id of node. - /// \param[in] shard_id The shard id. - /// \return Status Status::OK() if get shard id successfully. - Status GetShardId(int32_t *shard_id) override; - - /// \brief Base-class override for GetDatasetSize. - /// \param[in] size_getter Shared pointer to DatasetSizeGetter. - /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting - /// dataset size at the expense of accuracy. - /// \param[out] dataset_size The size of the dataset. - /// \return Status of the function. - Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, - int64_t *dataset_size) override; - - /// \brief Getter functions - const std::string &DatasetDir() const { return dataset_dir_; } - const std::string &Usage() const { return usage_; } - int64_t NumSamples() const { return num_samples_; } - ShuffleMode Shuffle() const { return shuffle_; } - int32_t NumShards() const { return num_shards_; } - int32_t ShardId() const { return shard_id_; } - - /// \brief Get the arguments of node. - /// \param[out] out_json JSON string of all attributes. - /// \return Status of the function. - Status to_json(nlohmann::json *out_json) override; - - /// \brief AmazonReview by itself is a non-mappable dataset that does not support sampling. - /// However, if a cache operator is injected at some other place higher in the tree, that cache can - /// inherit this sampler from the leaf, providing sampling support from the caching layer. - /// That is why we setup the sampler for a leaf node that does not use sampling. - /// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class. - /// \param[in] sampler The sampler to setup. - /// \return Status of the function. - Status SetupSamplerForCache(std::shared_ptr *sampler) override; - - /// \brief If a cache has been added into the ascendant tree over this AmazonReview node, then the cache will be - /// executing a sampler for fetching the data. As such, any options in the AmazonReview node need to be reset - /// to its defaults so that this AmazonReview node will produce the full set of data into the cache. - /// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class. - /// \return Status of the function. - Status MakeSimpleProducer() override; - - /// \brief Generate a list of read file names according to usage. - /// \param[in] usage Part of dataset of AmazonReview. - /// \param[in] dataset_dir Path to the root directory that contains the dataset. - /// \return std::vector A list of read file names. - std::vector WalkAllFiles(const std::string &usage, const std::string &dataset_dir); - - private: - std::string dataset_dir_; - std::string usage_; - std::vector> column_defaults_; - std::vector column_names_; - int64_t num_samples_; - ShuffleMode shuffle_; - int32_t num_shards_; - int32_t shard_id_; - std::vector amazon_review_files_list_; -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_AMAZON_REVIEW_NODE_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_AMAZON_REVIEW_NODE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_AMAZON_REVIEW_NODE_H_ + +#include +#include +#include + +#include "minddata/dataset/engine/datasetops/source/amazon_review_op.h" +#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" + +namespace mindspore { +namespace dataset { +class AmazonReviewNode : public NonMappableSourceNode { + public: + /// \brief Constructor. + AmazonReviewNode(const std::string &dataset_dir, const std::string &usage, int64_t num_samples, ShuffleMode shuffle, + int32_t num_shards, int32_t shard_id, const std::shared_ptr &cache); + + /// \brief Destructor. + ~AmazonReviewNode() override = default; + + /// \brief Node name getter. + /// \return Name of the current node. + std::string Name() const override { return kAmazonReviewNode; } + + /// \brief Print the description. + /// \param[out] out The output stream to write output to. + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object. + /// \return A shared pointer to the new copy. + std::shared_ptr Copy() override; + + /// \brief A base class override function to create the required runtime dataset op objects for this class. + /// \param node_ops A vector containing shared pointer to the Dataset Ops that this object will create. + /// \return Status Status::OK() if build successfully. + Status Build(std::vector> *const node_ops) override; + + /// \brief Parameters validation. + /// \return Status Status::OK() if all the parameters are valid. + Status ValidateParams() override; + + /// \brief Get the shard id of node. + /// \param[in] shard_id The shard id. + /// \return Status Status::OK() if get shard id successfully. + Status GetShardId(int32_t *shard_id) override; + + /// \brief Base-class override for GetDatasetSize. + /// \param[in] size_getter Shared pointer to DatasetSizeGetter. + /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting + /// dataset size at the expense of accuracy. + /// \param[out] dataset_size The size of the dataset. + /// \return Status of the function. + Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) override; + + /// \brief Getter functions + const std::string &DatasetDir() const { return dataset_dir_; } + const std::string &Usage() const { return usage_; } + int64_t NumSamples() const { return num_samples_; } + ShuffleMode Shuffle() const { return shuffle_; } + int32_t NumShards() const { return num_shards_; } + int32_t ShardId() const { return shard_id_; } + + /// \brief Get the arguments of node. + /// \param[out] out_json JSON string of all attributes. + /// \return Status of the function. + Status to_json(nlohmann::json *out_json) override; + + /// \brief AmazonReview by itself is a non-mappable dataset that does not support sampling. + /// However, if a cache operator is injected at some other place higher in the tree, that cache can + /// inherit this sampler from the leaf, providing sampling support from the caching layer. + /// That is why we setup the sampler for a leaf node that does not use sampling. + /// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class. + /// \param[in] sampler The sampler to setup. + /// \return Status of the function. + Status SetupSamplerForCache(std::shared_ptr *sampler) override; + + /// \brief If a cache has been added into the ascendant tree over this AmazonReview node, then the cache will be + /// executing a sampler for fetching the data. As such, any options in the AmazonReview node need to be reset + /// to its defaults so that this AmazonReview node will produce the full set of data into the cache. + /// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class. + /// \return Status of the function. + Status MakeSimpleProducer() override; + + /// \brief Generate a list of read file names according to usage. + /// \param[in] usage Part of dataset of AmazonReview. + /// \param[in] dataset_dir Path to the root directory that contains the dataset. + /// \return std::vector A list of read file names. + std::vector WalkAllFiles(const std::string &usage, const std::string &dataset_dir); + + private: + std::string dataset_dir_; + std::string usage_; + std::vector> column_defaults_; + std::vector column_names_; + int64_t num_samples_; + ShuffleMode shuffle_; + int32_t num_shards_; + int32_t shard_id_; + std::vector amazon_review_files_list_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_AMAZON_REVIEW_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/caltech256_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/caltech256_node.cc old mode 100755 new mode 100644 diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/caltech256_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/caltech256_node.h old mode 100755 new mode 100644 diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cmu_arctic_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cmu_arctic_node.cc index 51497933a85..0c650ccd192 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cmu_arctic_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cmu_arctic_node.cc @@ -1,114 +1,114 @@ -/** - * Copyright 2022-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "minddata/dataset/engine/ir/datasetops/source/cmu_arctic_node.h" - -#include "minddata/dataset/engine/datasetops/source/cmu_arctic_op.h" -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { -CMUArcticNode::CMUArcticNode(const std::string &dataset_dir, const std::string &name, - std::shared_ptr sampler, std::shared_ptr cache) - : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), name_(name), sampler_(sampler) {} - -void CMUArcticNode::Print(std::ostream &out) const { out << Name(); } - -std::shared_ptr CMUArcticNode::Copy() { - std::shared_ptr sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy(); - auto node = std::make_shared(dataset_dir_, name_, sampler, cache_); - (void)node->SetNumWorkers(num_workers_); - (void)node->SetConnectorQueueSize(connector_que_size_); - return node; -} - -Status CMUArcticNode::ValidateParams() { - RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); - RETURN_IF_NOT_OK(ValidateDatasetDirParam("CMUArcticDataset", dataset_dir_)); - RETURN_IF_NOT_OK(ValidateDatasetSampler("CMUArcticDataset", sampler_)); - RETURN_IF_NOT_OK(ValidateStringValue("CMUArcticDataset", name_, - {"aew", "ahw", "aup", "awb", "axb", "bdl", "clb", "eey", "fem", "gka", "jmk", - "ksp", "ljm", "lnh", "rms", "rxr", "slp", "slt"})); - return Status::OK(); -} - -Status CMUArcticNode::Build(std::vector> *const node_ops) { - auto schema = std::make_unique(); - - RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("waveform", DataType(DataType::DE_FLOAT32), TensorImpl::kCv, 1))); - TensorShape scalar_rate = TensorShape::CreateScalar(); - RETURN_IF_NOT_OK(schema->AddColumn( - ColDescriptor("sample_rate", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar_rate))); - TensorShape scalar_utterance = TensorShape::CreateScalar(); - RETURN_IF_NOT_OK(schema->AddColumn( - ColDescriptor("transcript", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar_utterance))); - TensorShape scalar_utterance_id = TensorShape::CreateScalar(); - RETURN_IF_NOT_OK(schema->AddColumn( - ColDescriptor("utterance_id", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar_utterance_id))); - - std::shared_ptr sampler_rt = nullptr; - RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); - - auto op = std::make_shared(dataset_dir_, name_, num_workers_, connector_que_size_, std::move(schema), - std::move(sampler_rt)); - op->SetTotalRepeats(GetTotalRepeats()); - op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); - node_ops->push_back(op); - - return Status::OK(); -} - -Status CMUArcticNode::GetShardId(int32_t *shard_id) { - *shard_id = sampler_->ShardId(); - return Status::OK(); -} - -Status CMUArcticNode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, - int64_t *dataset_size) { - if (dataset_size_ > 0) { - *dataset_size = dataset_size_; - return Status::OK(); - } - int64_t num_rows, sample_size; - RETURN_IF_NOT_OK(CMUArcticOp::CountTotalRows(dataset_dir_, name_, &num_rows)); - std::shared_ptr sampler_rt = nullptr; - RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); - sample_size = sampler_rt->CalculateNumSamples(num_rows); - if (sample_size == -1) { - RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size)); - } - *dataset_size = sample_size; - dataset_size_ = *dataset_size; - return Status::OK(); -} - -Status CMUArcticNode::to_json(nlohmann::json *out_json) { - nlohmann::json args, sampler_args; - RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args)); - args["sampler"] = sampler_args; - args["num_parallel_workers"] = num_workers_; - args["connector_queue_size"] = connector_que_size_; - args["dataset_dir"] = dataset_dir_; - args["name"] = name_; - if (cache_ != nullptr) { - nlohmann::json cache_args; - RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); - args["cache"] = cache_args; - } - *out_json = args; - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2022-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/ir/datasetops/source/cmu_arctic_node.h" + +#include "minddata/dataset/engine/datasetops/source/cmu_arctic_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +CMUArcticNode::CMUArcticNode(const std::string &dataset_dir, const std::string &name, + std::shared_ptr sampler, std::shared_ptr cache) + : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), name_(name), sampler_(sampler) {} + +void CMUArcticNode::Print(std::ostream &out) const { out << Name(); } + +std::shared_ptr CMUArcticNode::Copy() { + std::shared_ptr sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy(); + auto node = std::make_shared(dataset_dir_, name_, sampler, cache_); + (void)node->SetNumWorkers(num_workers_); + (void)node->SetConnectorQueueSize(connector_que_size_); + return node; +} + +Status CMUArcticNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); + RETURN_IF_NOT_OK(ValidateDatasetDirParam("CMUArcticDataset", dataset_dir_)); + RETURN_IF_NOT_OK(ValidateDatasetSampler("CMUArcticDataset", sampler_)); + RETURN_IF_NOT_OK(ValidateStringValue("CMUArcticDataset", name_, + {"aew", "ahw", "aup", "awb", "axb", "bdl", "clb", "eey", "fem", "gka", "jmk", + "ksp", "ljm", "lnh", "rms", "rxr", "slp", "slt"})); + return Status::OK(); +} + +Status CMUArcticNode::Build(std::vector> *const node_ops) { + auto schema = std::make_unique(); + + RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("waveform", DataType(DataType::DE_FLOAT32), TensorImpl::kCv, 1))); + TensorShape scalar_rate = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK(schema->AddColumn( + ColDescriptor("sample_rate", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar_rate))); + TensorShape scalar_utterance = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK(schema->AddColumn( + ColDescriptor("transcript", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar_utterance))); + TensorShape scalar_utterance_id = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK(schema->AddColumn( + ColDescriptor("utterance_id", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar_utterance_id))); + + std::shared_ptr sampler_rt = nullptr; + RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); + + auto op = std::make_shared(dataset_dir_, name_, num_workers_, connector_que_size_, std::move(schema), + std::move(sampler_rt)); + op->SetTotalRepeats(GetTotalRepeats()); + op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); + node_ops->push_back(op); + + return Status::OK(); +} + +Status CMUArcticNode::GetShardId(int32_t *shard_id) { + *shard_id = sampler_->ShardId(); + return Status::OK(); +} + +Status CMUArcticNode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t num_rows, sample_size; + RETURN_IF_NOT_OK(CMUArcticOp::CountTotalRows(dataset_dir_, name_, &num_rows)); + std::shared_ptr sampler_rt = nullptr; + RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); + sample_size = sampler_rt->CalculateNumSamples(num_rows); + if (sample_size == -1) { + RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size)); + } + *dataset_size = sample_size; + dataset_size_ = *dataset_size; + return Status::OK(); +} + +Status CMUArcticNode::to_json(nlohmann::json *out_json) { + nlohmann::json args, sampler_args; + RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args)); + args["sampler"] = sampler_args; + args["num_parallel_workers"] = num_workers_; + args["connector_queue_size"] = connector_que_size_; + args["dataset_dir"] = dataset_dir_; + args["name"] = name_; + if (cache_ != nullptr) { + nlohmann::json cache_args; + RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); + args["cache"] = cache_args; + } + *out_json = args; + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cmu_arctic_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cmu_arctic_node.h index 8305599adf9..8e38c78ad5e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cmu_arctic_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cmu_arctic_node.h @@ -1,95 +1,95 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CMU_ARCTIC_NODE_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CMU_ARCTIC_NODE_H_ - -#include -#include -#include - -#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" - -namespace mindspore { -namespace dataset { -class CMUArcticNode : public MappableSourceNode { - public: - /// \brief Constructor. - CMUArcticNode(const std::string &dataset_dir, const std::string &name, std::shared_ptr sampler, - std::shared_ptr cache); - - /// \brief Destructor. - ~CMUArcticNode() = default; - - /// \brief Node name getter. - /// \return Name of the current node. - std::string Name() const override { return kCMUArcticNode; } - - /// \brief Print the description. - /// \param out The output stream to write output to. - void Print(std::ostream &out) const override; - - /// \brief Copy the node to a new object. - /// \return A shared pointer to the new copy. - std::shared_ptr Copy() override; - - /// \brief a base class override function to create the required runtime dataset op objects for this class. - /// \param node_ops A vector containing shared pointer to the Dataset Ops that this object will create. - /// \return Status Status::OK() if build successfully. - Status Build(std::vector> *const node_ops) override; - - /// \brief Parameters validation. - /// \return Status Status::OK() if all the parameters are valid. - Status ValidateParams() override; - - /// \brief Get the shard id of node. - /// \param[in] shard_id The shard ID within num_shards. - /// \return Status Status::OK() if get shard id successfully. - Status GetShardId(int32_t *shard_id) override; - - /// \brief Base-class override for GetDatasetSize. - /// \param[in] size_getter Shared pointer to DatasetSizeGetter. - /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting - /// dataset size at the expense of accuracy. - /// \param[out] dataset_size the size of the dataset. - /// \return Status of the function. - Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, - int64_t *dataset_size) override; - - /// \brief Getter functions. - const std::string &DatasetDir() const { return dataset_dir_; } - const std::string &GetName() const { return name_; } - - /// \brief Get the arguments of node. - /// \param[out] out_json JSON string of all attributes. - /// \return Status of the function. - Status to_json(nlohmann::json *out_json) override; - - /// \brief Sampler getter. - /// \return SamplerObj of the current node. - std::shared_ptr Sampler() override { return sampler_; } - - /// \brief Sampler setter. - /// \param[in] sampler Tells CMUArcticOp what to read. - void SetSampler(std::shared_ptr sampler) override { sampler_ = sampler; } - - private: - std::string dataset_dir_; - std::string name_; - std::shared_ptr sampler_; -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CMU_ARCTIC_NODE_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CMU_ARCTIC_NODE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CMU_ARCTIC_NODE_H_ + +#include +#include +#include + +#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" + +namespace mindspore { +namespace dataset { +class CMUArcticNode : public MappableSourceNode { + public: + /// \brief Constructor. + CMUArcticNode(const std::string &dataset_dir, const std::string &name, std::shared_ptr sampler, + std::shared_ptr cache); + + /// \brief Destructor. + ~CMUArcticNode() = default; + + /// \brief Node name getter. + /// \return Name of the current node. + std::string Name() const override { return kCMUArcticNode; } + + /// \brief Print the description. + /// \param out The output stream to write output to. + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object. + /// \return A shared pointer to the new copy. + std::shared_ptr Copy() override; + + /// \brief a base class override function to create the required runtime dataset op objects for this class. + /// \param node_ops A vector containing shared pointer to the Dataset Ops that this object will create. + /// \return Status Status::OK() if build successfully. + Status Build(std::vector> *const node_ops) override; + + /// \brief Parameters validation. + /// \return Status Status::OK() if all the parameters are valid. + Status ValidateParams() override; + + /// \brief Get the shard id of node. + /// \param[in] shard_id The shard ID within num_shards. + /// \return Status Status::OK() if get shard id successfully. + Status GetShardId(int32_t *shard_id) override; + + /// \brief Base-class override for GetDatasetSize. + /// \param[in] size_getter Shared pointer to DatasetSizeGetter. + /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting + /// dataset size at the expense of accuracy. + /// \param[out] dataset_size the size of the dataset. + /// \return Status of the function. + Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) override; + + /// \brief Getter functions. + const std::string &DatasetDir() const { return dataset_dir_; } + const std::string &GetName() const { return name_; } + + /// \brief Get the arguments of node. + /// \param[out] out_json JSON string of all attributes. + /// \return Status of the function. + Status to_json(nlohmann::json *out_json) override; + + /// \brief Sampler getter. + /// \return SamplerObj of the current node. + std::shared_ptr Sampler() override { return sampler_; } + + /// \brief Sampler setter. + /// \param[in] sampler Tells CMUArcticOp what to read. + void SetSampler(std::shared_ptr sampler) override { sampler_ = sampler; } + + private: + std::string dataset_dir_; + std::string name_; + std::shared_ptr sampler_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CMU_ARCTIC_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/gtzan_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/gtzan_node.cc index 15329458560..b2e00a5c53b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/gtzan_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/gtzan_node.cc @@ -1,108 +1,108 @@ -/** - * Copyright 2022-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "minddata/dataset/engine/ir/datasetops/source/gtzan_node.h" - -#include "minddata/dataset/engine/datasetops/source/gtzan_op.h" -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { -GTZANNode::GTZANNode(const std::string &dataset_dir, const std::string &usage, std::shared_ptr sampler, - std::shared_ptr cache) - : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} - -void GTZANNode::Print(std::ostream &out) const { out << Name(); } - -std::shared_ptr GTZANNode::Copy() { - std::shared_ptr sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy(); - auto node = std::make_shared(dataset_dir_, usage_, sampler, cache_); - (void)node->SetNumWorkers(num_workers_); - (void)node->SetConnectorQueueSize(connector_que_size_); - return node; -} - -Status GTZANNode::ValidateParams() { - RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); - RETURN_IF_NOT_OK(ValidateDatasetDirParam("GTZANDataset", dataset_dir_)); - RETURN_IF_NOT_OK(ValidateDatasetSampler("GTZANDataset", sampler_)); - RETURN_IF_NOT_OK(ValidateStringValue("GTZANDataset", usage_, {"train", "valid", "test", "all"})); - return Status::OK(); -} - -Status GTZANNode::Build(std::vector> *const node_ops) { - // Do internal Schema generation. - auto schema = std::make_unique(); - RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("waveform", DataType(DataType::DE_FLOAT64), TensorImpl::kCv, 1))); - TensorShape scalar_rate = TensorShape::CreateScalar(); - RETURN_IF_NOT_OK(schema->AddColumn( - ColDescriptor("sample_rate", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar_rate))); - TensorShape scalar_label = TensorShape::CreateScalar(); - RETURN_IF_NOT_OK( - schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar_label))); - std::shared_ptr sampler_rt = nullptr; - RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); - auto op = std::make_shared(usage_, num_workers_, dataset_dir_, connector_que_size_, std::move(schema), - std::move(sampler_rt)); - op->SetTotalRepeats(GetTotalRepeats()); - op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); - node_ops->push_back(op); - return Status::OK(); -} - -// Get the shard id of node. -Status GTZANNode::GetShardId(int32_t *shard_id) { - *shard_id = sampler_->ShardId(); - return Status::OK(); -} - -// Get Dataset size. -Status GTZANNode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, - int64_t *dataset_size) { - if (dataset_size_ > 0) { - *dataset_size = dataset_size_; - return Status::OK(); - } - int64_t num_rows, sample_size; - RETURN_IF_NOT_OK(GTZANOp::CountTotalRows(dataset_dir_, usage_, &num_rows)); - std::shared_ptr sampler_rt = nullptr; - RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); - sample_size = sampler_rt->CalculateNumSamples(num_rows); - if (sample_size == -1) { - RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size)); - } - *dataset_size = sample_size; - dataset_size_ = *dataset_size; - return Status::OK(); -} - -Status GTZANNode::to_json(nlohmann::json *out_json) { - nlohmann::json args, sampler_args; - RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args)); - args["sampler"] = sampler_args; - args["num_parallel_workers"] = num_workers_; - args["connector_queue_size"] = connector_que_size_; - args["dataset_dir"] = dataset_dir_; - args["usage"] = usage_; - if (cache_ != nullptr) { - nlohmann::json cache_args; - RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); - args["cache"] = cache_args; - } - *out_json = args; - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2022-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/ir/datasetops/source/gtzan_node.h" + +#include "minddata/dataset/engine/datasetops/source/gtzan_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +GTZANNode::GTZANNode(const std::string &dataset_dir, const std::string &usage, std::shared_ptr sampler, + std::shared_ptr cache) + : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} + +void GTZANNode::Print(std::ostream &out) const { out << Name(); } + +std::shared_ptr GTZANNode::Copy() { + std::shared_ptr sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy(); + auto node = std::make_shared(dataset_dir_, usage_, sampler, cache_); + (void)node->SetNumWorkers(num_workers_); + (void)node->SetConnectorQueueSize(connector_que_size_); + return node; +} + +Status GTZANNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); + RETURN_IF_NOT_OK(ValidateDatasetDirParam("GTZANDataset", dataset_dir_)); + RETURN_IF_NOT_OK(ValidateDatasetSampler("GTZANDataset", sampler_)); + RETURN_IF_NOT_OK(ValidateStringValue("GTZANDataset", usage_, {"train", "valid", "test", "all"})); + return Status::OK(); +} + +Status GTZANNode::Build(std::vector> *const node_ops) { + // Do internal Schema generation. + auto schema = std::make_unique(); + RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("waveform", DataType(DataType::DE_FLOAT64), TensorImpl::kCv, 1))); + TensorShape scalar_rate = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK(schema->AddColumn( + ColDescriptor("sample_rate", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar_rate))); + TensorShape scalar_label = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK( + schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar_label))); + std::shared_ptr sampler_rt = nullptr; + RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); + auto op = std::make_shared(usage_, num_workers_, dataset_dir_, connector_que_size_, std::move(schema), + std::move(sampler_rt)); + op->SetTotalRepeats(GetTotalRepeats()); + op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); + node_ops->push_back(op); + return Status::OK(); +} + +// Get the shard id of node. +Status GTZANNode::GetShardId(int32_t *shard_id) { + *shard_id = sampler_->ShardId(); + return Status::OK(); +} + +// Get Dataset size. +Status GTZANNode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t num_rows, sample_size; + RETURN_IF_NOT_OK(GTZANOp::CountTotalRows(dataset_dir_, usage_, &num_rows)); + std::shared_ptr sampler_rt = nullptr; + RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); + sample_size = sampler_rt->CalculateNumSamples(num_rows); + if (sample_size == -1) { + RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size)); + } + *dataset_size = sample_size; + dataset_size_ = *dataset_size; + return Status::OK(); +} + +Status GTZANNode::to_json(nlohmann::json *out_json) { + nlohmann::json args, sampler_args; + RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args)); + args["sampler"] = sampler_args; + args["num_parallel_workers"] = num_workers_; + args["connector_queue_size"] = connector_que_size_; + args["dataset_dir"] = dataset_dir_; + args["usage"] = usage_; + if (cache_ != nullptr) { + nlohmann::json cache_args; + RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); + args["cache"] = cache_args; + } + *out_json = args; + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/gtzan_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/gtzan_node.h index 4a38d98f45b..1d11ea56d46 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/gtzan_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/gtzan_node.h @@ -1,95 +1,95 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_GTZAN_NODE_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_GTZAN_NODE_H_ - -#include -#include -#include - -#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" - -namespace mindspore { -namespace dataset { -class GTZANNode : public MappableSourceNode { - public: - /// \brief Constructor - GTZANNode(const std::string &dataset_dir, const std::string &usage, std::shared_ptr sampler, - std::shared_ptr cache); - - /// \brief Destructor - ~GTZANNode() = default; - - /// \brief Node name getter. - /// \return Name of the current node. - std::string Name() const override { return "kGTZANNode"; } - - /// \brief Print the description. - /// \param out The output stream to write output to. - void Print(std::ostream &out) const override; - - /// \brief Copy the node to a new object. - /// \return A shared pointer to the new copy. - std::shared_ptr Copy() override; - - /// \brief a base class override function to create the required runtime dataset op objects for this class. - /// \param node_ops A vector containing shared pointer to the Dataset Ops that this object will create. - /// \return Status Status::OK() if build successfully. - Status Build(std::vector> *const node_ops) override; - - /// \brief Parameters validation. - /// \return Status Status::OK() if all the parameters are valid. - Status ValidateParams() override; - - /// \brief Get the shard id of node. - /// \param[in] shard_id The shard ID within num_shards. - /// \return Status Status::OK() if get shard id successfully. - Status GetShardId(int32_t *shard_id) override; - - /// \brief Base-class override for GetDatasetSize. - /// \param[in] size_getter Shared pointer to DatasetSizeGetter. - /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting - /// dataset size at the expense of accuracy. - /// \param[out] dataset_size the size of the dataset. - /// \return Status of the function. - Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, - int64_t *dataset_size) override; - - /// \brief Getter functions. - const std::string &DatasetDir() const { return dataset_dir_; } - const std::string &Usage() const { return usage_; } - - /// \brief Get the arguments of node. - /// \param[out] out_json JSON string of all attributes. - /// \return Status of the function. - Status to_json(nlohmann::json *out_json) override; - - /// \brief Sampler getter. - /// \return SamplerObj of the current node. - std::shared_ptr Sampler() override { return sampler_; } - - /// \brief Sampler setter. - /// \param[in] sampler Tells GTZANOp what to read. - void SetSampler(std::shared_ptr sampler) override { sampler_ = sampler; } - - private: - std::string dataset_dir_; - std::string usage_; - std::shared_ptr sampler_; -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_GTZAN_NODE_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_GTZAN_NODE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_GTZAN_NODE_H_ + +#include +#include +#include + +#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" + +namespace mindspore { +namespace dataset { +class GTZANNode : public MappableSourceNode { + public: + /// \brief Constructor + GTZANNode(const std::string &dataset_dir, const std::string &usage, std::shared_ptr sampler, + std::shared_ptr cache); + + /// \brief Destructor + ~GTZANNode() = default; + + /// \brief Node name getter. + /// \return Name of the current node. + std::string Name() const override { return "kGTZANNode"; } + + /// \brief Print the description. + /// \param out The output stream to write output to. + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object. + /// \return A shared pointer to the new copy. + std::shared_ptr Copy() override; + + /// \brief a base class override function to create the required runtime dataset op objects for this class. + /// \param node_ops A vector containing shared pointer to the Dataset Ops that this object will create. + /// \return Status Status::OK() if build successfully. + Status Build(std::vector> *const node_ops) override; + + /// \brief Parameters validation. + /// \return Status Status::OK() if all the parameters are valid. + Status ValidateParams() override; + + /// \brief Get the shard id of node. + /// \param[in] shard_id The shard ID within num_shards. + /// \return Status Status::OK() if get shard id successfully. + Status GetShardId(int32_t *shard_id) override; + + /// \brief Base-class override for GetDatasetSize. + /// \param[in] size_getter Shared pointer to DatasetSizeGetter. + /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting + /// dataset size at the expense of accuracy. + /// \param[out] dataset_size the size of the dataset. + /// \return Status of the function. + Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) override; + + /// \brief Getter functions. + const std::string &DatasetDir() const { return dataset_dir_; } + const std::string &Usage() const { return usage_; } + + /// \brief Get the arguments of node. + /// \param[out] out_json JSON string of all attributes. + /// \return Status of the function. + Status to_json(nlohmann::json *out_json) override; + + /// \brief Sampler getter. + /// \return SamplerObj of the current node. + std::shared_ptr Sampler() override { return sampler_; } + + /// \brief Sampler setter. + /// \param[in] sampler Tells GTZANOp what to read. + void SetSampler(std::shared_ptr sampler) override { sampler_ = sampler; } + + private: + std::string dataset_dir_; + std::string usage_; + std::shared_ptr sampler_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_GTZAN_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/libri_tts_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/libri_tts_node.cc index 6a6460b9240..294139f3cc3 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/libri_tts_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/libri_tts_node.cc @@ -1,119 +1,119 @@ -/** - * Copyright 2022-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "minddata/dataset/engine/ir/datasetops/source/libri_tts_node.h" - -#include "minddata/dataset/engine/datasetops/source/libri_tts_op.h" -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { -LibriTTSNode::LibriTTSNode(const std::string &dataset_dir, const std::string &usage, - std::shared_ptr sampler, std::shared_ptr cache) - : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} - -void LibriTTSNode::Print(std::ostream &out) const { out << Name(); } - -std::shared_ptr LibriTTSNode::Copy() { - std::shared_ptr sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy(); - auto node = std::make_shared(dataset_dir_, usage_, sampler, cache_); - (void)node->SetNumWorkers(num_workers_); - (void)node->SetConnectorQueueSize(connector_que_size_); - return node; -} - -Status LibriTTSNode::ValidateParams() { - RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); - RETURN_IF_NOT_OK(ValidateDatasetDirParam("LibriTTSDataset", dataset_dir_)); - RETURN_IF_NOT_OK(ValidateDatasetSampler("LibriTTSDataset", sampler_)); - RETURN_IF_NOT_OK(ValidateStringValue("LibriTTSDataset", usage_, - {"dev-clean", "dev-other", "test-clean", "test-other", "train-clean-100", - "train-clean-360", "train-other-500", "all"})); - return Status::OK(); -} - -Status LibriTTSNode::GetShardId(int32_t *shard_id) { - *shard_id = sampler_->ShardId(); - return Status::OK(); -} - -Status LibriTTSNode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, - int64_t *dataset_size) { - if (dataset_size_ > 0) { - *dataset_size = dataset_size_; - return Status::OK(); - } - int64_t num_rows, sample_size; - RETURN_IF_NOT_OK(LibriTTSOp::CountTotalRows(dataset_dir_, usage_, &num_rows)); - std::shared_ptr sampler_rt = nullptr; - RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); - sample_size = sampler_rt->CalculateNumSamples(num_rows); - if (sample_size == -1) { - RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size)); - } - *dataset_size = sample_size; - dataset_size_ = *dataset_size; - return Status::OK(); -} - -Status LibriTTSNode::Build(std::vector> *const node_ops) { - auto schema = std::make_unique(); - RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("waveform", DataType(DataType::DE_FLOAT32), TensorImpl::kCv, 1))); - TensorShape scalar_rate = TensorShape::CreateScalar(); - RETURN_IF_NOT_OK(schema->AddColumn( - ColDescriptor("sample_rate", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar_rate))); - TensorShape scalar_original_text = TensorShape::CreateScalar(); - RETURN_IF_NOT_OK(schema->AddColumn( - ColDescriptor("original_text", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar_original_text))); - TensorShape scalar_normalized_text = TensorShape::CreateScalar(); - RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("normalized_text", DataType(DataType::DE_STRING), - TensorImpl::kFlexible, 0, &scalar_normalized_text))); - TensorShape scalar_speaker_id = TensorShape::CreateScalar(); - RETURN_IF_NOT_OK(schema->AddColumn( - ColDescriptor("speaker_id", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar_speaker_id))); - TensorShape scalar_chapter_id = TensorShape::CreateScalar(); - RETURN_IF_NOT_OK(schema->AddColumn( - ColDescriptor("chapter_id", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar_chapter_id))); - TensorShape scalar_utterance_id = TensorShape::CreateScalar(); - RETURN_IF_NOT_OK(schema->AddColumn( - ColDescriptor("utterance_id", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar_utterance_id))); - std::shared_ptr sampler_rt = nullptr; - RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); - auto op = std::make_shared(dataset_dir_, usage_, num_workers_, connector_que_size_, std::move(schema), - std::move(sampler_rt)); - op->SetTotalRepeats(GetTotalRepeats()); - op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); - node_ops->push_back(op); - return Status::OK(); -} - -Status LibriTTSNode::to_json(nlohmann::json *out_json) { - nlohmann::json args, sampler_args; - RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args)); - args["sampler"] = sampler_args; - args["num_parallel_workers"] = num_workers_; - args["connector_queue_size"] = connector_que_size_; - args["dataset_dir"] = dataset_dir_; - args["usage"] = usage_; - if (cache_ != nullptr) { - nlohmann::json cache_args; - RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); - args["cache"] = cache_args; - } - *out_json = args; - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2022-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/ir/datasetops/source/libri_tts_node.h" + +#include "minddata/dataset/engine/datasetops/source/libri_tts_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +LibriTTSNode::LibriTTSNode(const std::string &dataset_dir, const std::string &usage, + std::shared_ptr sampler, std::shared_ptr cache) + : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} + +void LibriTTSNode::Print(std::ostream &out) const { out << Name(); } + +std::shared_ptr LibriTTSNode::Copy() { + std::shared_ptr sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy(); + auto node = std::make_shared(dataset_dir_, usage_, sampler, cache_); + (void)node->SetNumWorkers(num_workers_); + (void)node->SetConnectorQueueSize(connector_que_size_); + return node; +} + +Status LibriTTSNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); + RETURN_IF_NOT_OK(ValidateDatasetDirParam("LibriTTSDataset", dataset_dir_)); + RETURN_IF_NOT_OK(ValidateDatasetSampler("LibriTTSDataset", sampler_)); + RETURN_IF_NOT_OK(ValidateStringValue("LibriTTSDataset", usage_, + {"dev-clean", "dev-other", "test-clean", "test-other", "train-clean-100", + "train-clean-360", "train-other-500", "all"})); + return Status::OK(); +} + +Status LibriTTSNode::GetShardId(int32_t *shard_id) { + *shard_id = sampler_->ShardId(); + return Status::OK(); +} + +Status LibriTTSNode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t num_rows, sample_size; + RETURN_IF_NOT_OK(LibriTTSOp::CountTotalRows(dataset_dir_, usage_, &num_rows)); + std::shared_ptr sampler_rt = nullptr; + RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); + sample_size = sampler_rt->CalculateNumSamples(num_rows); + if (sample_size == -1) { + RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size)); + } + *dataset_size = sample_size; + dataset_size_ = *dataset_size; + return Status::OK(); +} + +Status LibriTTSNode::Build(std::vector> *const node_ops) { + auto schema = std::make_unique(); + RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("waveform", DataType(DataType::DE_FLOAT32), TensorImpl::kCv, 1))); + TensorShape scalar_rate = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK(schema->AddColumn( + ColDescriptor("sample_rate", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar_rate))); + TensorShape scalar_original_text = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK(schema->AddColumn( + ColDescriptor("original_text", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar_original_text))); + TensorShape scalar_normalized_text = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("normalized_text", DataType(DataType::DE_STRING), + TensorImpl::kFlexible, 0, &scalar_normalized_text))); + TensorShape scalar_speaker_id = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK(schema->AddColumn( + ColDescriptor("speaker_id", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar_speaker_id))); + TensorShape scalar_chapter_id = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK(schema->AddColumn( + ColDescriptor("chapter_id", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar_chapter_id))); + TensorShape scalar_utterance_id = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK(schema->AddColumn( + ColDescriptor("utterance_id", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar_utterance_id))); + std::shared_ptr sampler_rt = nullptr; + RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); + auto op = std::make_shared(dataset_dir_, usage_, num_workers_, connector_que_size_, std::move(schema), + std::move(sampler_rt)); + op->SetTotalRepeats(GetTotalRepeats()); + op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); + node_ops->push_back(op); + return Status::OK(); +} + +Status LibriTTSNode::to_json(nlohmann::json *out_json) { + nlohmann::json args, sampler_args; + RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args)); + args["sampler"] = sampler_args; + args["num_parallel_workers"] = num_workers_; + args["connector_queue_size"] = connector_que_size_; + args["dataset_dir"] = dataset_dir_; + args["usage"] = usage_; + if (cache_ != nullptr) { + nlohmann::json cache_args; + RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); + args["cache"] = cache_args; + } + *out_json = args; + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/libri_tts_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/libri_tts_node.h index 812a5b8e29a..df399540fc0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/libri_tts_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/libri_tts_node.h @@ -1,95 +1,95 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_LIBRI_TTS_NODE_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_LIBRI_TTS_NODE_H_ - -#include -#include -#include - -#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" - -namespace mindspore { -namespace dataset { -class LibriTTSNode : public MappableSourceNode { - public: - /// \brief Constructor. - LibriTTSNode(const std::string &dataset_dir, const std::string &usage, std::shared_ptr sampler, - std::shared_ptr cache); - - /// \brief Destructor. - ~LibriTTSNode() = default; - - /// \brief Node name getter. - /// \return Name of the current node. - std::string Name() const override { return kLibriTTSNode; } - - /// \brief Print the description. - /// \param out The output stream to write output to. - void Print(std::ostream &out) const override; - - /// \brief Copy the node to a new object. - /// \return A shared pointer to the new copy. - std::shared_ptr Copy() override; - - /// \brief a base class override function to create the required runtime dataset op objects for this class. - /// \param node_ops A vector containing shared pointer to the Dataset Ops that this object will create. - /// \return Status Status::OK() if build successfully. - Status Build(std::vector> *const node_ops) override; - - /// \brief Parameters validation. - /// \return Status Status::OK() if all the parameters are valid. - Status ValidateParams() override; - - /// \brief Get the shard id of node. - /// \param[in] shard_id The shard ID within num_shards. - /// \return Status Status::OK() if get shard id successfully. - Status GetShardId(int32_t *shard_id) override; - - /// \brief Base-class override for GetDatasetSize. - /// \param[in] size_getter Shared pointer to DatasetSizeGetter. - /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting - /// dataset size at the expense of accuracy. - /// \param[out] dataset_size the size of the dataset. - /// \return Status of the function. - Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, - int64_t *dataset_size) override; - - /// \brief Getter functions. - const std::string &DatasetDir() const { return dataset_dir_; } - const std::string &usage() const { return usage_; } - - /// \brief Get the arguments of node. - /// \param[out] out_json JSON string of all attributes. - /// \return Status of the function. - Status to_json(nlohmann::json *out_json) override; - - /// \brief Sampler getter. - /// \return SamplerObj of the current node. - std::shared_ptr Sampler() override { return sampler_; } - - /// \brief Sampler setter. - /// \param[in] sampler Tells LibriTTSOp what to read. - void SetSampler(std::shared_ptr sampler) override { sampler_ = sampler; } - - private: - std::string dataset_dir_; - std::string usage_; - std::shared_ptr sampler_; -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_LIBRI_TTS_NODE_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_LIBRI_TTS_NODE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_LIBRI_TTS_NODE_H_ + +#include +#include +#include + +#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" + +namespace mindspore { +namespace dataset { +class LibriTTSNode : public MappableSourceNode { + public: + /// \brief Constructor. + LibriTTSNode(const std::string &dataset_dir, const std::string &usage, std::shared_ptr sampler, + std::shared_ptr cache); + + /// \brief Destructor. + ~LibriTTSNode() = default; + + /// \brief Node name getter. + /// \return Name of the current node. + std::string Name() const override { return kLibriTTSNode; } + + /// \brief Print the description. + /// \param out The output stream to write output to. + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object. + /// \return A shared pointer to the new copy. + std::shared_ptr Copy() override; + + /// \brief a base class override function to create the required runtime dataset op objects for this class. + /// \param node_ops A vector containing shared pointer to the Dataset Ops that this object will create. + /// \return Status Status::OK() if build successfully. + Status Build(std::vector> *const node_ops) override; + + /// \brief Parameters validation. + /// \return Status Status::OK() if all the parameters are valid. + Status ValidateParams() override; + + /// \brief Get the shard id of node. + /// \param[in] shard_id The shard ID within num_shards. + /// \return Status Status::OK() if get shard id successfully. + Status GetShardId(int32_t *shard_id) override; + + /// \brief Base-class override for GetDatasetSize. + /// \param[in] size_getter Shared pointer to DatasetSizeGetter. + /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting + /// dataset size at the expense of accuracy. + /// \param[out] dataset_size the size of the dataset. + /// \return Status of the function. + Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) override; + + /// \brief Getter functions. + const std::string &DatasetDir() const { return dataset_dir_; } + const std::string &usage() const { return usage_; } + + /// \brief Get the arguments of node. + /// \param[out] out_json JSON string of all attributes. + /// \return Status of the function. + Status to_json(nlohmann::json *out_json) override; + + /// \brief Sampler getter. + /// \return SamplerObj of the current node. + std::shared_ptr Sampler() override { return sampler_; } + + /// \brief Sampler setter. + /// \param[in] sampler Tells LibriTTSOp what to read. + void SetSampler(std::shared_ptr sampler) override { sampler_ = sampler; } + + private: + std::string dataset_dir_; + std::string usage_; + std::shared_ptr sampler_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_LIBRI_TTS_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/penn_treebank_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/penn_treebank_node.cc index e281023a818..81f4d3c9e94 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/penn_treebank_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/penn_treebank_node.cc @@ -1,198 +1,198 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/engine/ir/datasetops/source/penn_treebank_node.h" - -#include "minddata/dataset/engine/datasetops/source/penn_treebank_op.h" -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { -// Constructor for PennTreebankNode. -PennTreebankNode::PennTreebankNode(const std::string &dataset_dir, const std::string &usage, int64_t num_samples, - ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, - const std::shared_ptr &cache) - : NonMappableSourceNode(std::move(cache)), - dataset_dir_(dataset_dir), - usage_(usage), - num_samples_(num_samples), - shuffle_(shuffle), - num_shards_(num_shards), - shard_id_(shard_id), - penn_treebank_files_list_(WalkAllFiles(usage, dataset_dir)) { - // Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. User discretion - // is advised. Auto_num_worker_pass is currently an experimental feature which can still work if the num_shards_ isn't - // 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to return num_shards. Once - // PreBuildSampler is phased out, this can be cleaned up. - GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_); -} - -std::shared_ptr PennTreebankNode::Copy() { - auto node = - std::make_shared(dataset_dir_, usage_, num_samples_, shuffle_, num_shards_, shard_id_, cache_); - (void)node->SetNumWorkers(num_workers_); - (void)node->SetConnectorQueueSize(connector_que_size_); - return node; -} - -void PennTreebankNode::Print(std::ostream &out) const { - out << (Name() + "(cache: " + ((cache_ != nullptr) ? "true" : "false") + - ", num_shards: " + std::to_string(num_shards_) + ", shard_id: " + std::to_string(shard_id_) + ")"); -} - -Status PennTreebankNode::ValidateParams() { - RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); - RETURN_IF_NOT_OK(ValidateDatasetDirParam("PennTreebankNode", dataset_dir_)); - RETURN_IF_NOT_OK(ValidateStringValue("PennTreebankNode", usage_, {"train", "test", "valid", "all"})); - RETURN_IF_NOT_OK(ValidateEnum("PennTreebankNode", "ShuffleMode", shuffle_, - {ShuffleMode::kFalse, ShuffleMode::kFiles, ShuffleMode::kGlobal})); - if (num_samples_ < 0) { - std::string err_msg = "PennTreebankNode: Invalid number of samples: " + std::to_string(num_samples_); - LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg); - } - RETURN_IF_NOT_OK(ValidateDatasetShardParams("PennTreebankNode", num_shards_, shard_id_)); - return Status::OK(); -} - -// Function to build PennTreebankNode. -Status PennTreebankNode::Build(std::vector> *const node_ops) { - bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); - // Sort the dataset files in a lexicographical order. - std::vector sorted_dataset_files = penn_treebank_files_list_; - std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end()); - // Do internal Schema generation. - auto schema = std::make_unique(); - RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); - // Create and initialize PennTreebankNode. - std::shared_ptr penn_treebank_op = - std::make_shared(num_workers_, num_samples_, worker_connector_size_, std::move(schema), - sorted_dataset_files, connector_que_size_, shuffle_files, num_shards_, shard_id_); - RETURN_IF_NOT_OK(penn_treebank_op->Init()); - // If a global shuffle is used for PennTreebank, it will inject a shuffle op over the PennTreebank. - // But, if there is a cache in the tree, we do not need the global shuffle and the shuffle op should not be built. - // This is achieved in the cache transform pass where we call MakeSimpleProducer to reset PennTreebank's shuffle - // option to false. - if (shuffle_ == ShuffleMode::kGlobal) { - // Inject ShuffleOp. - std::shared_ptr shuffle_op = nullptr; - int64_t num_rows = 0; - // First, get the number of rows in the dataset. - RETURN_IF_NOT_OK(PennTreebankOp::CountAllFileRows(penn_treebank_files_list_, &num_rows)); - // Add the shuffle op after this op. - RETURN_IF_NOT_OK( - AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op)); - shuffle_op->SetTotalRepeats(GetTotalRepeats()); - shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); - shuffle_op->Skip(skip_steps_); - node_ops->push_back(shuffle_op); - } - penn_treebank_op->SetTotalRepeats(GetTotalRepeats()); - penn_treebank_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); - // Add PennTreebankNode. - node_ops->push_back(penn_treebank_op); - return Status::OK(); -} - -// Get the shard id of node. -Status PennTreebankNode::GetShardId(int32_t *shard_id) { - *shard_id = shard_id_; - return Status::OK(); -} - -// Get Dataset size. -Status PennTreebankNode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, - int64_t *dataset_size) { - if (dataset_size_ > 0) { - *dataset_size = dataset_size_; - return Status::OK(); - } - int64_t num_rows, sample_size = num_samples_; - RETURN_IF_NOT_OK(PennTreebankOp::CountAllFileRows(penn_treebank_files_list_, &num_rows)); - num_rows = static_cast(ceil(num_rows / (1.0 * num_shards_))); - *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows; - dataset_size_ = *dataset_size; - return Status::OK(); -} - -Status PennTreebankNode::to_json(nlohmann::json *out_json) { - nlohmann::json args; - args["num_parallel_workers"] = num_workers_; - args["connector_queue_size"] = connector_que_size_; - args["dataset_dir"] = dataset_dir_; - args["usage"] = usage_; - args["num_samples"] = num_samples_; - args["shuffle"] = shuffle_; - args["num_shards"] = num_shards_; - args["shard_id"] = shard_id_; - if (cache_ != nullptr) { - nlohmann::json cache_args; - RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); - args["cache"] = cache_args; - } - *out_json = args; - return Status::OK(); -} - -// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class. -// PennTreebank by itself is a non-mappable dataset that does not support sampling. -// However, if a cache operator is injected at some other place higher in the tree, that cache can -// inherit this sampler from the leaf, providing sampling support from the caching layer. -// That is why we setup the sampler for a leaf node that does not use sampling. -Status PennTreebankNode::SetupSamplerForCache(std::shared_ptr *sampler) { - bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); - *sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_); - return Status::OK(); -} - -// If a cache has been added into the ascendant tree over this PennTreebank node, then the cache will be executing -// a sampler for fetching the data. As such, any options in the PennTreebank node need to be reset to its defaults so -// that this PennTreebank node will produce the full set of data into the cache. -Status PennTreebankNode::MakeSimpleProducer() { - shard_id_ = 0; - num_shards_ = 1; - shuffle_ = ShuffleMode::kFalse; - num_samples_ = 0; - return Status::OK(); -} - -std::vector PennTreebankNode::WalkAllFiles(const std::string &usage, const std::string &dataset_dir) { - std::vector penn_treebank_files_list; - Path train_prefix("ptb.train.txt"); - Path test_prefix("ptb.test.txt"); - Path valid_prefix("ptb.valid.txt"); - Path dir(dataset_dir); - - if (usage == "train") { - Path temp_path = dir / train_prefix; - penn_treebank_files_list.push_back(temp_path.ToString()); - } else if (usage == "test") { - Path temp_path = dir / test_prefix; - penn_treebank_files_list.push_back(temp_path.ToString()); - } else if (usage == "valid") { - Path temp_path = dir / valid_prefix; - penn_treebank_files_list.push_back(temp_path.ToString()); - } else { - Path temp_path = dir / train_prefix; - penn_treebank_files_list.push_back(temp_path.ToString()); - Path temp_path1 = dir / test_prefix; - penn_treebank_files_list.push_back(temp_path1.ToString()); - Path temp_path2 = dir / valid_prefix; - penn_treebank_files_list.push_back(temp_path2.ToString()); - } - return penn_treebank_files_list; -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/engine/ir/datasetops/source/penn_treebank_node.h" + +#include "minddata/dataset/engine/datasetops/source/penn_treebank_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +// Constructor for PennTreebankNode. +PennTreebankNode::PennTreebankNode(const std::string &dataset_dir, const std::string &usage, int64_t num_samples, + ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, + const std::shared_ptr &cache) + : NonMappableSourceNode(std::move(cache)), + dataset_dir_(dataset_dir), + usage_(usage), + num_samples_(num_samples), + shuffle_(shuffle), + num_shards_(num_shards), + shard_id_(shard_id), + penn_treebank_files_list_(WalkAllFiles(usage, dataset_dir)) { + // Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. User discretion + // is advised. Auto_num_worker_pass is currently an experimental feature which can still work if the num_shards_ isn't + // 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to return num_shards. Once + // PreBuildSampler is phased out, this can be cleaned up. + GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_); +} + +std::shared_ptr PennTreebankNode::Copy() { + auto node = + std::make_shared(dataset_dir_, usage_, num_samples_, shuffle_, num_shards_, shard_id_, cache_); + (void)node->SetNumWorkers(num_workers_); + (void)node->SetConnectorQueueSize(connector_que_size_); + return node; +} + +void PennTreebankNode::Print(std::ostream &out) const { + out << (Name() + "(cache: " + ((cache_ != nullptr) ? "true" : "false") + + ", num_shards: " + std::to_string(num_shards_) + ", shard_id: " + std::to_string(shard_id_) + ")"); +} + +Status PennTreebankNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); + RETURN_IF_NOT_OK(ValidateDatasetDirParam("PennTreebankNode", dataset_dir_)); + RETURN_IF_NOT_OK(ValidateStringValue("PennTreebankNode", usage_, {"train", "test", "valid", "all"})); + RETURN_IF_NOT_OK(ValidateEnum("PennTreebankNode", "ShuffleMode", shuffle_, + {ShuffleMode::kFalse, ShuffleMode::kFiles, ShuffleMode::kGlobal})); + if (num_samples_ < 0) { + std::string err_msg = "PennTreebankNode: Invalid number of samples: " + std::to_string(num_samples_); + LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg); + } + RETURN_IF_NOT_OK(ValidateDatasetShardParams("PennTreebankNode", num_shards_, shard_id_)); + return Status::OK(); +} + +// Function to build PennTreebankNode. +Status PennTreebankNode::Build(std::vector> *const node_ops) { + bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); + // Sort the dataset files in a lexicographical order. + std::vector sorted_dataset_files = penn_treebank_files_list_; + std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end()); + // Do internal Schema generation. + auto schema = std::make_unique(); + RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); + // Create and initialize PennTreebankNode. + std::shared_ptr penn_treebank_op = + std::make_shared(num_workers_, num_samples_, worker_connector_size_, std::move(schema), + sorted_dataset_files, connector_que_size_, shuffle_files, num_shards_, shard_id_); + RETURN_IF_NOT_OK(penn_treebank_op->Init()); + // If a global shuffle is used for PennTreebank, it will inject a shuffle op over the PennTreebank. + // But, if there is a cache in the tree, we do not need the global shuffle and the shuffle op should not be built. + // This is achieved in the cache transform pass where we call MakeSimpleProducer to reset PennTreebank's shuffle + // option to false. + if (shuffle_ == ShuffleMode::kGlobal) { + // Inject ShuffleOp. + std::shared_ptr shuffle_op = nullptr; + int64_t num_rows = 0; + // First, get the number of rows in the dataset. + RETURN_IF_NOT_OK(PennTreebankOp::CountAllFileRows(penn_treebank_files_list_, &num_rows)); + // Add the shuffle op after this op. + RETURN_IF_NOT_OK( + AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op)); + shuffle_op->SetTotalRepeats(GetTotalRepeats()); + shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); + shuffle_op->Skip(skip_steps_); + node_ops->push_back(shuffle_op); + } + penn_treebank_op->SetTotalRepeats(GetTotalRepeats()); + penn_treebank_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); + // Add PennTreebankNode. + node_ops->push_back(penn_treebank_op); + return Status::OK(); +} + +// Get the shard id of node. +Status PennTreebankNode::GetShardId(int32_t *shard_id) { + *shard_id = shard_id_; + return Status::OK(); +} + +// Get Dataset size. +Status PennTreebankNode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t num_rows, sample_size = num_samples_; + RETURN_IF_NOT_OK(PennTreebankOp::CountAllFileRows(penn_treebank_files_list_, &num_rows)); + num_rows = static_cast(ceil(num_rows / (1.0 * num_shards_))); + *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows; + dataset_size_ = *dataset_size; + return Status::OK(); +} + +Status PennTreebankNode::to_json(nlohmann::json *out_json) { + nlohmann::json args; + args["num_parallel_workers"] = num_workers_; + args["connector_queue_size"] = connector_que_size_; + args["dataset_dir"] = dataset_dir_; + args["usage"] = usage_; + args["num_samples"] = num_samples_; + args["shuffle"] = shuffle_; + args["num_shards"] = num_shards_; + args["shard_id"] = shard_id_; + if (cache_ != nullptr) { + nlohmann::json cache_args; + RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); + args["cache"] = cache_args; + } + *out_json = args; + return Status::OK(); +} + +// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class. +// PennTreebank by itself is a non-mappable dataset that does not support sampling. +// However, if a cache operator is injected at some other place higher in the tree, that cache can +// inherit this sampler from the leaf, providing sampling support from the caching layer. +// That is why we setup the sampler for a leaf node that does not use sampling. +Status PennTreebankNode::SetupSamplerForCache(std::shared_ptr *sampler) { + bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); + *sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_); + return Status::OK(); +} + +// If a cache has been added into the ascendant tree over this PennTreebank node, then the cache will be executing +// a sampler for fetching the data. As such, any options in the PennTreebank node need to be reset to its defaults so +// that this PennTreebank node will produce the full set of data into the cache. +Status PennTreebankNode::MakeSimpleProducer() { + shard_id_ = 0; + num_shards_ = 1; + shuffle_ = ShuffleMode::kFalse; + num_samples_ = 0; + return Status::OK(); +} + +std::vector PennTreebankNode::WalkAllFiles(const std::string &usage, const std::string &dataset_dir) { + std::vector penn_treebank_files_list; + Path train_prefix("ptb.train.txt"); + Path test_prefix("ptb.test.txt"); + Path valid_prefix("ptb.valid.txt"); + Path dir(dataset_dir); + + if (usage == "train") { + Path temp_path = dir / train_prefix; + penn_treebank_files_list.push_back(temp_path.ToString()); + } else if (usage == "test") { + Path temp_path = dir / test_prefix; + penn_treebank_files_list.push_back(temp_path.ToString()); + } else if (usage == "valid") { + Path temp_path = dir / valid_prefix; + penn_treebank_files_list.push_back(temp_path.ToString()); + } else { + Path temp_path = dir / train_prefix; + penn_treebank_files_list.push_back(temp_path.ToString()); + Path temp_path1 = dir / test_prefix; + penn_treebank_files_list.push_back(temp_path1.ToString()); + Path temp_path2 = dir / valid_prefix; + penn_treebank_files_list.push_back(temp_path2.ToString()); + } + return penn_treebank_files_list; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/penn_treebank_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/penn_treebank_node.h index 796052e8e78..56f8880779d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/penn_treebank_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/penn_treebank_node.h @@ -1,124 +1,124 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_PENN_TREEBANK_NODE_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_PENN_TREEBANK_NODE_H_ - -#include -#include -#include - -#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" - -namespace mindspore { -namespace dataset { -/// \brief class PennTreebankNode. -/// \brief Dataset derived class to represent PennTreebank dataset. -class PennTreebankNode : public NonMappableSourceNode { - public: - /// \brief Constructor. - PennTreebankNode(const std::string &dataset_dir, const std::string &usage, int64_t num_samples, ShuffleMode shuffle, - int32_t num_shards, int32_t shard_id, const std::shared_ptr &cache); - - /// \brief Destructor. - ~PennTreebankNode() override = default; - - /// \brief Node name getter. - /// \return Name of the current node. - std::string Name() const override { return kPennTreebankNode; } - - /// \brief Print the description. - /// \param[in] out The output stream to write output to. - void Print(std::ostream &out) const override; - - /// \brief Copy the node to a new object. - /// \return A shared pointer to the new copy. - std::shared_ptr Copy() override; - - /// \brief A base class override function to create the required runtime dataset op objects for this class. - /// \param[in] node_ops A vector containing shared pointer to the Dataset Ops that this object will create. - /// \return Status Status::OK() if build successfully. - Status Build(std::vector> *const node_ops) override; - - /// \brief Parameters validation. - /// \return Status Status::OK() if all the parameters are valid. - Status ValidateParams() override; - - /// \brief Get the shard id of node. - /// \param[in] shard_id The shard id. - /// \return Status Status::OK() if get shard id successfully. - Status GetShardId(int32_t *shard_id) override; - - /// \brief Base-class override for GetDatasetSize. - /// \param[in] size_getter Shared pointer to DatasetSizeGetter. - /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting - /// dataset size at the expense of accuracy. - /// \param[out] dataset_size the size of the dataset. - /// \return Status of the function. - Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, - int64_t *dataset_size) override; - - /// \brief Getter functions. - const std::string &DatasetDir() const { return dataset_dir_; } - int32_t NumSamples() const { return num_samples_; } - int32_t NumShards() const { return num_shards_; } - int32_t ShardId() const { return shard_id_; } - ShuffleMode Shuffle() const { return shuffle_; } - const std::string &Usage() const { return usage_; } - - /// \brief Get the arguments of node - /// \param[out] out_json JSON string of all attributes - /// \return Status of the function - Status to_json(nlohmann::json *out_json) override; - - /// \brief PennTreebank by itself is a non-mappable dataset that does not support sampling. - /// However, if a cache operator is injected at some other place higher in - /// the tree, that cache can inherit this sampler from the leaf, providing - /// sampling support from the caching layer. That is why we setup the - /// sampler for a leaf node that does not use sampling. Note: This - /// function is common among NonMappableSourceNode and should be promoted - /// to its parent class. - /// \param[in] sampler The sampler to setup. - /// \return Status of the function. - Status SetupSamplerForCache(std::shared_ptr *sampler) override; - - /// \brief If a cache has been added into the ascendant tree over this PennTreebank node, - /// then the cache will be executing a sampler for fetching the data. - /// As such, any options in the PennTreebank node need to be reset to its defaults - /// so that this PennTreebank node will produce the full set of data into the cache. - /// Note: This function is common among NonMappableSourceNode and should be promoted to its - /// parent class. - /// \return Status of the function. - Status MakeSimpleProducer() override; - - /// \brief Generate a list of read file names according to usage. - /// \param[in] usage Part of dataset of PennTreebank. - /// \param[in] dataset_dir Path to the root directory that contains the dataset. - /// \return std::vector A list of read file names. - std::vector WalkAllFiles(const std::string &usage, const std::string &dataset_dir); - - private: - std::string dataset_dir_; - std::string usage_; - int64_t num_samples_; - int32_t num_shards_; - int32_t shard_id_; - ShuffleMode shuffle_; - std::vector penn_treebank_files_list_; -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_PENN_TREEBANK_NODE_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_PENN_TREEBANK_NODE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_PENN_TREEBANK_NODE_H_ + +#include +#include +#include + +#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" + +namespace mindspore { +namespace dataset { +/// \brief class PennTreebankNode. +/// \brief Dataset derived class to represent PennTreebank dataset. +class PennTreebankNode : public NonMappableSourceNode { + public: + /// \brief Constructor. + PennTreebankNode(const std::string &dataset_dir, const std::string &usage, int64_t num_samples, ShuffleMode shuffle, + int32_t num_shards, int32_t shard_id, const std::shared_ptr &cache); + + /// \brief Destructor. + ~PennTreebankNode() override = default; + + /// \brief Node name getter. + /// \return Name of the current node. + std::string Name() const override { return kPennTreebankNode; } + + /// \brief Print the description. + /// \param[in] out The output stream to write output to. + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object. + /// \return A shared pointer to the new copy. + std::shared_ptr Copy() override; + + /// \brief A base class override function to create the required runtime dataset op objects for this class. + /// \param[in] node_ops A vector containing shared pointer to the Dataset Ops that this object will create. + /// \return Status Status::OK() if build successfully. + Status Build(std::vector> *const node_ops) override; + + /// \brief Parameters validation. + /// \return Status Status::OK() if all the parameters are valid. + Status ValidateParams() override; + + /// \brief Get the shard id of node. + /// \param[in] shard_id The shard id. + /// \return Status Status::OK() if get shard id successfully. + Status GetShardId(int32_t *shard_id) override; + + /// \brief Base-class override for GetDatasetSize. + /// \param[in] size_getter Shared pointer to DatasetSizeGetter. + /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting + /// dataset size at the expense of accuracy. + /// \param[out] dataset_size the size of the dataset. + /// \return Status of the function. + Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) override; + + /// \brief Getter functions. + const std::string &DatasetDir() const { return dataset_dir_; } + int32_t NumSamples() const { return num_samples_; } + int32_t NumShards() const { return num_shards_; } + int32_t ShardId() const { return shard_id_; } + ShuffleMode Shuffle() const { return shuffle_; } + const std::string &Usage() const { return usage_; } + + /// \brief Get the arguments of node + /// \param[out] out_json JSON string of all attributes + /// \return Status of the function + Status to_json(nlohmann::json *out_json) override; + + /// \brief PennTreebank by itself is a non-mappable dataset that does not support sampling. + /// However, if a cache operator is injected at some other place higher in + /// the tree, that cache can inherit this sampler from the leaf, providing + /// sampling support from the caching layer. That is why we setup the + /// sampler for a leaf node that does not use sampling. Note: This + /// function is common among NonMappableSourceNode and should be promoted + /// to its parent class. + /// \param[in] sampler The sampler to setup. + /// \return Status of the function. + Status SetupSamplerForCache(std::shared_ptr *sampler) override; + + /// \brief If a cache has been added into the ascendant tree over this PennTreebank node, + /// then the cache will be executing a sampler for fetching the data. + /// As such, any options in the PennTreebank node need to be reset to its defaults + /// so that this PennTreebank node will produce the full set of data into the cache. + /// Note: This function is common among NonMappableSourceNode and should be promoted to its + /// parent class. + /// \return Status of the function. + Status MakeSimpleProducer() override; + + /// \brief Generate a list of read file names according to usage. + /// \param[in] usage Part of dataset of PennTreebank. + /// \param[in] dataset_dir Path to the root directory that contains the dataset. + /// \return std::vector A list of read file names. + std::vector WalkAllFiles(const std::string &usage, const std::string &dataset_dir); + + private: + std::string dataset_dir_; + std::string usage_; + int64_t num_samples_; + int32_t num_shards_; + int32_t shard_id_; + ShuffleMode shuffle_; + std::vector penn_treebank_files_list_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_PENN_TREEBANK_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/qmnist_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/qmnist_node.cc index f4f06fcb08b..fc98d1227e3 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/qmnist_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/qmnist_node.cc @@ -1,154 +1,154 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/engine/ir/datasetops/source/qmnist_node.h" - -#include -#include -#include -#include - -#include "minddata/dataset/engine/datasetops/source/qmnist_op.h" -#ifndef ENABLE_ANDROID -#include "minddata/dataset/engine/serdes.h" -#endif -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { -QMnistNode::QMnistNode(const std::string &dataset_dir, const std::string &usage, bool compat, - std::shared_ptr sampler, std::shared_ptr cache) - : MappableSourceNode(std::move(cache)), - dataset_dir_(dataset_dir), - usage_(usage), - compat_(compat), - sampler_(sampler) {} - -std::shared_ptr QMnistNode::Copy() { - std::shared_ptr sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy(); - auto node = std::make_shared(dataset_dir_, usage_, compat_, sampler, cache_); - (void)node->SetNumWorkers(num_workers_); - (void)node->SetConnectorQueueSize(connector_que_size_); - return node; -} - -void QMnistNode::Print(std::ostream &out) const { - out << (Name() + "(dataset dir: " + dataset_dir_ + ", usage: " + usage_ + - ", compat: " + (compat_ ? "true" : "false") + ", cache: " + ((cache_ != nullptr) ? "true" : "false") + ")"); -} - -Status QMnistNode::ValidateParams() { - RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); - RETURN_IF_NOT_OK(ValidateDatasetDirParam("QMnistDataset", dataset_dir_)); - RETURN_IF_NOT_OK(ValidateDatasetSampler("QMnistDataset", sampler_)); - RETURN_IF_NOT_OK( - ValidateStringValue("QMnistDataset", usage_, {"train", "test", "test10k", "test50k", "nist", "all"})); - return Status::OK(); -} - -Status QMnistNode::Build(std::vector> *const node_ops) { - // Do internal Schema generation. - auto schema = std::make_unique(); - RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1))); - if (compat_) { - TensorShape scalar = TensorShape::CreateScalar(); - RETURN_IF_NOT_OK( - schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); - } else { - RETURN_IF_NOT_OK( - schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); - } - - std::shared_ptr sampler_rt = nullptr; - RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); - - auto op = std::make_shared(dataset_dir_, usage_, compat_, std::move(schema), std::move(sampler_rt), - num_workers_, connector_que_size_); - op->SetTotalRepeats(GetTotalRepeats()); - op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); - node_ops->push_back(op); - - return Status::OK(); -} - -// Get the shard id of node -Status QMnistNode::GetShardId(int32_t *shard_id) { - *shard_id = sampler_->ShardId(); - - return Status::OK(); -} - -// Get Dataset size -Status QMnistNode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, - int64_t *dataset_size) { - if (dataset_size_ > 0) { - *dataset_size = dataset_size_; - return Status::OK(); - } - int64_t num_rows, sample_size; - RETURN_IF_NOT_OK(QMnistOp::CountTotalRows(dataset_dir_, usage_, &num_rows)); - std::shared_ptr sampler_rt = nullptr; - RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); - sample_size = sampler_rt->CalculateNumSamples(num_rows); - if (sample_size == -1) { - RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size)); - } - *dataset_size = sample_size; - dataset_size_ = *dataset_size; - return Status::OK(); -} - -Status QMnistNode::to_json(nlohmann::json *out_json) { - nlohmann::json args, sampler_args; - RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args)); - args["sampler"] = sampler_args; - args["num_parallel_workers"] = num_workers_; - args["connector_queue_size"] = connector_que_size_; - args["dataset_dir"] = dataset_dir_; - args["usage"] = usage_; - args["compat"] = compat_; - if (cache_ != nullptr) { - nlohmann::json cache_args; - RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); - args["cache"] = cache_args; - } - *out_json = args; - return Status::OK(); -} - -#ifndef ENABLE_ANDROID -Status QMnistNode::from_json(nlohmann::json json_obj, std::shared_ptr *ds) { - RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kQMnistNode)); - RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "connector_queue_size", kQMnistNode)); - RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_dir", kQMnistNode)); - RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "usage", kQMnistNode)); - RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "compat", kQMnistNode)); - RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "sampler", kQMnistNode)); - std::string dataset_dir = json_obj["dataset_dir"]; - std::string usage = json_obj["usage"]; - bool compat = json_obj["compat"]; - std::shared_ptr sampler; - RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler)); - std::shared_ptr cache = nullptr; - RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache)); - *ds = std::make_shared(dataset_dir, usage, compat, sampler, cache); - (void)(*ds)->SetNumWorkers(json_obj["num_parallel_workers"]); - (void)(*ds)->SetConnectorQueueSize(json_obj["connector_queue_size"]); - return Status::OK(); -} -#endif -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/engine/ir/datasetops/source/qmnist_node.h" + +#include +#include +#include +#include + +#include "minddata/dataset/engine/datasetops/source/qmnist_op.h" +#ifndef ENABLE_ANDROID +#include "minddata/dataset/engine/serdes.h" +#endif +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +QMnistNode::QMnistNode(const std::string &dataset_dir, const std::string &usage, bool compat, + std::shared_ptr sampler, std::shared_ptr cache) + : MappableSourceNode(std::move(cache)), + dataset_dir_(dataset_dir), + usage_(usage), + compat_(compat), + sampler_(sampler) {} + +std::shared_ptr QMnistNode::Copy() { + std::shared_ptr sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy(); + auto node = std::make_shared(dataset_dir_, usage_, compat_, sampler, cache_); + (void)node->SetNumWorkers(num_workers_); + (void)node->SetConnectorQueueSize(connector_que_size_); + return node; +} + +void QMnistNode::Print(std::ostream &out) const { + out << (Name() + "(dataset dir: " + dataset_dir_ + ", usage: " + usage_ + + ", compat: " + (compat_ ? "true" : "false") + ", cache: " + ((cache_ != nullptr) ? "true" : "false") + ")"); +} + +Status QMnistNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); + RETURN_IF_NOT_OK(ValidateDatasetDirParam("QMnistDataset", dataset_dir_)); + RETURN_IF_NOT_OK(ValidateDatasetSampler("QMnistDataset", sampler_)); + RETURN_IF_NOT_OK( + ValidateStringValue("QMnistDataset", usage_, {"train", "test", "test10k", "test50k", "nist", "all"})); + return Status::OK(); +} + +Status QMnistNode::Build(std::vector> *const node_ops) { + // Do internal Schema generation. + auto schema = std::make_unique(); + RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1))); + if (compat_) { + TensorShape scalar = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK( + schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); + } else { + RETURN_IF_NOT_OK( + schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); + } + + std::shared_ptr sampler_rt = nullptr; + RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); + + auto op = std::make_shared(dataset_dir_, usage_, compat_, std::move(schema), std::move(sampler_rt), + num_workers_, connector_que_size_); + op->SetTotalRepeats(GetTotalRepeats()); + op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); + node_ops->push_back(op); + + return Status::OK(); +} + +// Get the shard id of node +Status QMnistNode::GetShardId(int32_t *shard_id) { + *shard_id = sampler_->ShardId(); + + return Status::OK(); +} + +// Get Dataset size +Status QMnistNode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t num_rows, sample_size; + RETURN_IF_NOT_OK(QMnistOp::CountTotalRows(dataset_dir_, usage_, &num_rows)); + std::shared_ptr sampler_rt = nullptr; + RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); + sample_size = sampler_rt->CalculateNumSamples(num_rows); + if (sample_size == -1) { + RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size)); + } + *dataset_size = sample_size; + dataset_size_ = *dataset_size; + return Status::OK(); +} + +Status QMnistNode::to_json(nlohmann::json *out_json) { + nlohmann::json args, sampler_args; + RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args)); + args["sampler"] = sampler_args; + args["num_parallel_workers"] = num_workers_; + args["connector_queue_size"] = connector_que_size_; + args["dataset_dir"] = dataset_dir_; + args["usage"] = usage_; + args["compat"] = compat_; + if (cache_ != nullptr) { + nlohmann::json cache_args; + RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); + args["cache"] = cache_args; + } + *out_json = args; + return Status::OK(); +} + +#ifndef ENABLE_ANDROID +Status QMnistNode::from_json(nlohmann::json json_obj, std::shared_ptr *ds) { + RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kQMnistNode)); + RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "connector_queue_size", kQMnistNode)); + RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_dir", kQMnistNode)); + RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "usage", kQMnistNode)); + RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "compat", kQMnistNode)); + RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "sampler", kQMnistNode)); + std::string dataset_dir = json_obj["dataset_dir"]; + std::string usage = json_obj["usage"]; + bool compat = json_obj["compat"]; + std::shared_ptr sampler; + RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler)); + std::shared_ptr cache = nullptr; + RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache)); + *ds = std::make_shared(dataset_dir, usage, compat, sampler, cache); + (void)(*ds)->SetNumWorkers(json_obj["num_parallel_workers"]); + (void)(*ds)->SetConnectorQueueSize(json_obj["connector_queue_size"]); + return Status::OK(); +} +#endif +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/qmnist_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/qmnist_node.h index 4f64e13c461..bdcd1ab276e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/qmnist_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/qmnist_node.h @@ -1,109 +1,109 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_QMNIST_NODE_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_QMNIST_NODE_H_ - -#include -#include -#include - -#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" - -namespace mindspore { -namespace dataset { -class QMnistNode : public MappableSourceNode { - public: - /// \brief Constructor. - QMnistNode(const std::string &dataset_dir, const std::string &usage, bool compat, std::shared_ptr sampler, - std::shared_ptr cache); - - /// \brief Destructor. - ~QMnistNode() override = default; - - /// \brief Node name getter. - /// \return Name of the current node. - std::string Name() const override { return kQMnistNode; } - - /// \brief Print the description. - /// \param out - The output stream to write output to. - void Print(std::ostream &out) const override; - - /// \brief Copy the node to a new object. - /// \return A shared pointer to the new copy. - std::shared_ptr Copy() override; - - /// \brief a base class override function to create the required runtime dataset op objects for this class. - /// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create. - /// \return Status Status::OK() if build successfully. - Status Build(std::vector> *const node_ops) override; - - /// \brief Parameters validation. - /// \return Status Status::OK() if all the parameters are valid. - Status ValidateParams() override; - - /// \brief Get the shard id of node. - /// \param[in] shard_id The shard id. - /// \return Status Status::OK() if get shard id successfully. - Status GetShardId(int32_t *shard_id) override; - - /// \brief Base-class override for GetDatasetSize. - /// \param[in] size_getter Shared pointer to DatasetSizeGetter. - /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting - /// dataset size at the expense of accuracy. - /// \param[out] dataset_size the size of the dataset. - /// \return Status of the function. - Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, - int64_t *dataset_size) override; - - /// \brief Getter functions. - const std::string &DatasetDir() const { return dataset_dir_; } - - /// \brief Getter functions. - const std::string &Usage() const { return usage_; } - - /// \brief Getter functions. - const bool Compat() const { return compat_; } - - /// \brief Get the arguments of node. - /// \param[out] out_json JSON string of all attributes. - /// \return Status of the function. - Status to_json(nlohmann::json *out_json) override; - -#ifndef ENABLE_ANDROID - /// \brief Function to read dataset in json - /// \param[in] json_obj The JSON object to be deserialized - /// \param[out] ds Deserialized dataset - /// \return Status The status code returned - static Status from_json(nlohmann::json json_obj, std::shared_ptr *ds); -#endif - - /// \brief Sampler getter. - /// \return SamplerObj of the current node. - std::shared_ptr Sampler() override { return sampler_; } - - /// \brief Sampler setter. - void SetSampler(std::shared_ptr sampler) override { sampler_ = sampler; } - - private: - std::string dataset_dir_; - std::string usage_; - bool compat_; - std::shared_ptr sampler_; -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_QMNIST_NODE_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_QMNIST_NODE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_QMNIST_NODE_H_ + +#include +#include +#include + +#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" + +namespace mindspore { +namespace dataset { +class QMnistNode : public MappableSourceNode { + public: + /// \brief Constructor. + QMnistNode(const std::string &dataset_dir, const std::string &usage, bool compat, std::shared_ptr sampler, + std::shared_ptr cache); + + /// \brief Destructor. + ~QMnistNode() override = default; + + /// \brief Node name getter. + /// \return Name of the current node. + std::string Name() const override { return kQMnistNode; } + + /// \brief Print the description. + /// \param out - The output stream to write output to. + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object. + /// \return A shared pointer to the new copy. + std::shared_ptr Copy() override; + + /// \brief a base class override function to create the required runtime dataset op objects for this class. + /// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create. + /// \return Status Status::OK() if build successfully. + Status Build(std::vector> *const node_ops) override; + + /// \brief Parameters validation. + /// \return Status Status::OK() if all the parameters are valid. + Status ValidateParams() override; + + /// \brief Get the shard id of node. + /// \param[in] shard_id The shard id. + /// \return Status Status::OK() if get shard id successfully. + Status GetShardId(int32_t *shard_id) override; + + /// \brief Base-class override for GetDatasetSize. + /// \param[in] size_getter Shared pointer to DatasetSizeGetter. + /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting + /// dataset size at the expense of accuracy. + /// \param[out] dataset_size the size of the dataset. + /// \return Status of the function. + Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) override; + + /// \brief Getter functions. + const std::string &DatasetDir() const { return dataset_dir_; } + + /// \brief Getter functions. + const std::string &Usage() const { return usage_; } + + /// \brief Getter functions. + const bool Compat() const { return compat_; } + + /// \brief Get the arguments of node. + /// \param[out] out_json JSON string of all attributes. + /// \return Status of the function. + Status to_json(nlohmann::json *out_json) override; + +#ifndef ENABLE_ANDROID + /// \brief Function to read dataset in json + /// \param[in] json_obj The JSON object to be deserialized + /// \param[out] ds Deserialized dataset + /// \return Status The status code returned + static Status from_json(nlohmann::json json_obj, std::shared_ptr *ds); +#endif + + /// \brief Sampler getter. + /// \return SamplerObj of the current node. + std::shared_ptr Sampler() override { return sampler_; } + + /// \brief Sampler setter. + void SetSampler(std::shared_ptr sampler) override { sampler_ = sampler; } + + private: + std::string dataset_dir_; + std::string usage_; + bool compat_; + std::shared_ptr sampler_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_QMNIST_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/sbu_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/sbu_node.cc index c909e8369df..9e282aa32ed 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/sbu_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/sbu_node.cc @@ -1,125 +1,125 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/engine/ir/datasetops/source/sbu_node.h" - -#include -#include -#include -#include - -#include "minddata/dataset/engine/datasetops/source/sbu_op.h" -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { -SBUNode::SBUNode(const std::string &dataset_dir, bool decode, const std::shared_ptr &sampler, - const std::shared_ptr &cache) - : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), decode_(decode), sampler_(sampler) {} - -std::shared_ptr SBUNode::Copy() { - std::shared_ptr sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy(); - auto node = std::make_shared(dataset_dir_, decode_, sampler, cache_); - (void)node->SetNumWorkers(num_workers_); - (void)node->SetConnectorQueueSize(connector_que_size_); - return node; -} - -void SBUNode::Print(std::ostream &out) const { - out << (Name() + "(dataset dir: " + dataset_dir_ + ", decode: " + (decode_ ? "true" : "false") + - ", cache: " + ((cache_ != nullptr) ? "true" : "false") + ")"); -} - -Status SBUNode::ValidateParams() { - RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); - RETURN_IF_NOT_OK(ValidateDatasetDirParam("SBUDataset", dataset_dir_)); - RETURN_IF_NOT_OK(ValidateDatasetSampler("SBUDataset", sampler_)); - - Path root_dir(dataset_dir_); - - Path url_path = root_dir / Path("SBU_captioned_photo_dataset_urls.txt"); - Path caption_path = root_dir / Path("SBU_captioned_photo_dataset_captions.txt"); - Path image_path = root_dir / Path("sbu_images"); - - RETURN_IF_NOT_OK(ValidateDatasetFilesParam("SBUDataset", {url_path.ToString()}, "url file")); - RETURN_IF_NOT_OK(ValidateDatasetFilesParam("SBUDataset", {caption_path.ToString()}, "caption file")); - RETURN_IF_NOT_OK(ValidateDatasetDirParam("SBUDataset", {image_path.ToString()})); - - return Status::OK(); -} - -Status SBUNode::Build(std::vector> *const node_ops) { - // Do internal Schema generation. - auto schema = std::make_unique(); - RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); - RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("caption", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); - std::shared_ptr sampler_rt = nullptr; - RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); - - auto op = std::make_shared(dataset_dir_, decode_, std::move(schema), std::move(sampler_rt), num_workers_, - connector_que_size_); - op->SetTotalRepeats(GetTotalRepeats()); - op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); - node_ops->push_back(op); - - return Status::OK(); -} - -// Get the shard id of node -Status SBUNode::GetShardId(int32_t *shard_id) { - *shard_id = sampler_->ShardId(); - - return Status::OK(); -} - -// Get Dataset size -Status SBUNode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, - int64_t *dataset_size) { - if (dataset_size_ > 0) { - *dataset_size = dataset_size_; - return Status::OK(); - } - int64_t num_rows, sample_size; - RETURN_IF_NOT_OK(SBUOp::CountTotalRows(dataset_dir_, &num_rows)); - std::shared_ptr sampler_rt = nullptr; - RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); - sample_size = sampler_rt->CalculateNumSamples(num_rows); - if (sample_size == -1) { - RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size)); - } - *dataset_size = sample_size; - dataset_size_ = *dataset_size; - return Status::OK(); -} - -Status SBUNode::to_json(nlohmann::json *out_json) { - nlohmann::json args, sampler_args; - RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args)); - args["sampler"] = sampler_args; - args["num_parallel_workers"] = num_workers_; - args["connector_queue_size"] = connector_que_size_; - args["dataset_dir"] = dataset_dir_; - args["decode"] = decode_; - if (cache_ != nullptr) { - nlohmann::json cache_args; - RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); - args["cache"] = cache_args; - } - *out_json = args; - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/engine/ir/datasetops/source/sbu_node.h" + +#include +#include +#include +#include + +#include "minddata/dataset/engine/datasetops/source/sbu_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +SBUNode::SBUNode(const std::string &dataset_dir, bool decode, const std::shared_ptr &sampler, + const std::shared_ptr &cache) + : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), decode_(decode), sampler_(sampler) {} + +std::shared_ptr SBUNode::Copy() { + std::shared_ptr sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy(); + auto node = std::make_shared(dataset_dir_, decode_, sampler, cache_); + (void)node->SetNumWorkers(num_workers_); + (void)node->SetConnectorQueueSize(connector_que_size_); + return node; +} + +void SBUNode::Print(std::ostream &out) const { + out << (Name() + "(dataset dir: " + dataset_dir_ + ", decode: " + (decode_ ? "true" : "false") + + ", cache: " + ((cache_ != nullptr) ? "true" : "false") + ")"); +} + +Status SBUNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); + RETURN_IF_NOT_OK(ValidateDatasetDirParam("SBUDataset", dataset_dir_)); + RETURN_IF_NOT_OK(ValidateDatasetSampler("SBUDataset", sampler_)); + + Path root_dir(dataset_dir_); + + Path url_path = root_dir / Path("SBU_captioned_photo_dataset_urls.txt"); + Path caption_path = root_dir / Path("SBU_captioned_photo_dataset_captions.txt"); + Path image_path = root_dir / Path("sbu_images"); + + RETURN_IF_NOT_OK(ValidateDatasetFilesParam("SBUDataset", {url_path.ToString()}, "url file")); + RETURN_IF_NOT_OK(ValidateDatasetFilesParam("SBUDataset", {caption_path.ToString()}, "caption file")); + RETURN_IF_NOT_OK(ValidateDatasetDirParam("SBUDataset", {image_path.ToString()})); + + return Status::OK(); +} + +Status SBUNode::Build(std::vector> *const node_ops) { + // Do internal Schema generation. + auto schema = std::make_unique(); + RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); + RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("caption", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); + std::shared_ptr sampler_rt = nullptr; + RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); + + auto op = std::make_shared(dataset_dir_, decode_, std::move(schema), std::move(sampler_rt), num_workers_, + connector_que_size_); + op->SetTotalRepeats(GetTotalRepeats()); + op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); + node_ops->push_back(op); + + return Status::OK(); +} + +// Get the shard id of node +Status SBUNode::GetShardId(int32_t *shard_id) { + *shard_id = sampler_->ShardId(); + + return Status::OK(); +} + +// Get Dataset size +Status SBUNode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t num_rows, sample_size; + RETURN_IF_NOT_OK(SBUOp::CountTotalRows(dataset_dir_, &num_rows)); + std::shared_ptr sampler_rt = nullptr; + RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); + sample_size = sampler_rt->CalculateNumSamples(num_rows); + if (sample_size == -1) { + RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size)); + } + *dataset_size = sample_size; + dataset_size_ = *dataset_size; + return Status::OK(); +} + +Status SBUNode::to_json(nlohmann::json *out_json) { + nlohmann::json args, sampler_args; + RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args)); + args["sampler"] = sampler_args; + args["num_parallel_workers"] = num_workers_; + args["connector_queue_size"] = connector_que_size_; + args["dataset_dir"] = dataset_dir_; + args["decode"] = decode_; + if (cache_ != nullptr) { + nlohmann::json cache_args; + RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); + args["cache"] = cache_args; + } + *out_json = args; + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/sbu_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/sbu_node.h index 70ad3110b18..45b67a0d637 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/sbu_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/sbu_node.h @@ -1,95 +1,95 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SBU_NODE_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SBU_NODE_H_ - -#include -#include -#include - -#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" - -namespace mindspore { -namespace dataset { -class SBUNode : public MappableSourceNode { - public: - /// \brief Constructor. - SBUNode(const std::string &dataset_dir, bool decode, const std::shared_ptr &sampler, - const std::shared_ptr &cache); - - /// \brief Destructor. - ~SBUNode() override = default; - - /// \brief Node name getter. - /// \return Name of the current node. - std::string Name() const override { return kSBUNode; } - - /// \brief Print the description. - /// \param out - The output stream to write output to. - void Print(std::ostream &out) const override; - - /// \brief Copy the node to a new object. - /// \return A shared pointer to the new copy. - std::shared_ptr Copy() override; - - /// \brief a base class override function to create the required runtime dataset op objects for this class. - /// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create. - /// \return Status Status::OK() if build successfully. - Status Build(std::vector> *const node_ops) override; - - /// \brief Parameters validation. - /// \return Status Status::OK() if all the parameters are valid. - Status ValidateParams() override; - - /// \brief Get the shard id of node. - /// \param[in] shard_id The shard id. - /// \return Status Status::OK() if get shard id successfully. - Status GetShardId(int32_t *shard_id) override; - - /// \brief Base-class override for GetDatasetSize. - /// \param[in] size_getter Shared pointer to DatasetSizeGetter. - /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting - /// dataset size at the expense of accuracy. - /// \param[out] dataset_size the size of the dataset. - /// \return Status of the function. - Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, - int64_t *dataset_size) override; - - /// \brief Getter functions. - const std::string &DatasetDir() const { return dataset_dir_; } - bool Decode() const { return decode_; } - - /// \brief Get the arguments of node. - /// \param[out] out_json JSON string of all attributes. - /// \return Status of the function. - Status to_json(nlohmann::json *out_json) override; - - /// \brief Sampler getter. - /// \return SamplerObj of the current node. - std::shared_ptr Sampler() override { return sampler_; } - - /// \brief Sampler setter. - void SetSampler(std::shared_ptr sampler) override { sampler_ = sampler; } - - private: - std::string dataset_dir_; - bool decode_; - std::shared_ptr sampler_; -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SBU_NODE_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SBU_NODE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SBU_NODE_H_ + +#include +#include +#include + +#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" + +namespace mindspore { +namespace dataset { +class SBUNode : public MappableSourceNode { + public: + /// \brief Constructor. + SBUNode(const std::string &dataset_dir, bool decode, const std::shared_ptr &sampler, + const std::shared_ptr &cache); + + /// \brief Destructor. + ~SBUNode() override = default; + + /// \brief Node name getter. + /// \return Name of the current node. + std::string Name() const override { return kSBUNode; } + + /// \brief Print the description. + /// \param out - The output stream to write output to. + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object. + /// \return A shared pointer to the new copy. + std::shared_ptr Copy() override; + + /// \brief a base class override function to create the required runtime dataset op objects for this class. + /// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create. + /// \return Status Status::OK() if build successfully. + Status Build(std::vector> *const node_ops) override; + + /// \brief Parameters validation. + /// \return Status Status::OK() if all the parameters are valid. + Status ValidateParams() override; + + /// \brief Get the shard id of node. + /// \param[in] shard_id The shard id. + /// \return Status Status::OK() if get shard id successfully. + Status GetShardId(int32_t *shard_id) override; + + /// \brief Base-class override for GetDatasetSize. + /// \param[in] size_getter Shared pointer to DatasetSizeGetter. + /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting + /// dataset size at the expense of accuracy. + /// \param[out] dataset_size the size of the dataset. + /// \return Status of the function. + Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) override; + + /// \brief Getter functions. + const std::string &DatasetDir() const { return dataset_dir_; } + bool Decode() const { return decode_; } + + /// \brief Get the arguments of node. + /// \param[out] out_json JSON string of all attributes. + /// \return Status of the function. + Status to_json(nlohmann::json *out_json) override; + + /// \brief Sampler getter. + /// \return SamplerObj of the current node. + std::shared_ptr Sampler() override { return sampler_; } + + /// \brief Sampler setter. + void SetSampler(std::shared_ptr sampler) override { sampler_ = sampler; } + + private: + std::string dataset_dir_; + bool decode_; + std::shared_ptr sampler_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SBU_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/sogou_news_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/sogou_news_node.cc index 688446afea8..85bf8bf365b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/sogou_news_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/sogou_news_node.cc @@ -1,194 +1,194 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "minddata/dataset/engine/ir/datasetops/source/sogou_news_node.h" - -#include "minddata/dataset/util/path.h" -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { -SogouNewsNode::SogouNewsNode(const std::string &dataset_dir, const std::string &usage, int64_t num_samples, - ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, - const std::shared_ptr &cache) - : NonMappableSourceNode(std::move(cache)), - dataset_dir_(dataset_dir), - num_samples_(num_samples), - shuffle_(shuffle), - num_shards_(num_shards), - shard_id_(shard_id), - usage_(usage), - sogou_news_files_list_(WalkAllFiles(usage, dataset_dir)) { - // Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. - // User discretion is advised. Auto_num_worker_pass is currently an experimental feature which can still work - // if the num_shards_ isn't 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to - // return num_shards. Once PreBuildSampler is phased out, this can be cleaned up. - GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_); -} - -std::shared_ptr SogouNewsNode::Copy() { - auto node = - std::make_shared(dataset_dir_, usage_, num_samples_, shuffle_, num_shards_, shard_id_, cache_); - (void)node->SetNumWorkers(num_workers_); - (void)node->SetConnectorQueueSize(connector_que_size_); - return node; -} - -void SogouNewsNode::Print(std::ostream &out) const { - out << (Name() + "(cache: " + ((cache_ != nullptr) ? "true" : "false") + - ", num_shards: " + std::to_string(num_shards_) + ", shard_id: " + std::to_string(shard_id_) + ")"); -} - -Status SogouNewsNode::ValidateParams() { - RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); - RETURN_IF_NOT_OK(ValidateDatasetDirParam("SogouNewsNode", dataset_dir_)); - RETURN_IF_NOT_OK(ValidateStringValue("SogouNewsNode", usage_, {"train", "test", "all"})); - RETURN_IF_NOT_OK(ValidateEnum("SogouNewsNode", "ShuffleMode", shuffle_, - {ShuffleMode::kFalse, ShuffleMode::kFiles, ShuffleMode::kGlobal})); - if (num_samples_ < 0) { - std::string err_msg = "SogouNewsNode: Invalid number of samples: " + std::to_string(num_samples_); - MS_LOG(ERROR) << err_msg; - LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg); - } - - RETURN_IF_NOT_OK(ValidateDatasetShardParams("SogouNewsNode", num_shards_, shard_id_)); - return Status::OK(); -} - -Status SogouNewsNode::Build(std::vector> *const node_ops) { - bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); - - // Sort the dataset files in a lexicographical order. - std::vector sorted_dataset_files = sogou_news_files_list_; - std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end()); - - std::vector> column_default; - column_default.push_back(std::make_shared>(SogouNewsOp::STRING, "")); - column_default.push_back(std::make_shared>(SogouNewsOp::STRING, "")); - column_default.push_back(std::make_shared>(SogouNewsOp::STRING, "")); - - std::vector column_name = {"index", "title", "content"}; - char field_delim = ','; - auto sogou_news_op = std::make_shared(num_workers_, num_samples_, worker_connector_size_, - connector_que_size_, shuffle_files, num_shards_, shard_id_, - field_delim, column_default, column_name, sogou_news_files_list_); - - RETURN_IF_NOT_OK(sogou_news_op->Init()); - - // If a global shuffle is used for SogouNews, it will inject a shuffle op over the SogouNews. - // But, if there is a cache in the tree, we do not need the global shuffle and the shuffle op should not be - // built.This is achieved in the cache transform pass where we call MakeSimpleProducer to reset SogouNews - // shuffle option to false. - if (shuffle_ == ShuffleMode::kGlobal) { - // Inject ShuffleOp. - std::shared_ptr shuffle_op = nullptr; - int64_t num_rows = 0; - - // First, get the number of rows in the dataset. - RETURN_IF_NOT_OK(SogouNewsOp::CountAllFileRows(sogou_news_files_list_, false, &num_rows)); - // Add the shuffle op after this op. - RETURN_IF_NOT_OK( - AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op)); - shuffle_op->SetTotalRepeats(GetTotalRepeats()); - shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); - shuffle_op->Skip(skip_steps_); - node_ops->push_back(shuffle_op); - } - sogou_news_op->SetTotalRepeats(GetTotalRepeats()); - sogou_news_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); - node_ops->push_back(sogou_news_op); - return Status::OK(); -} - -Status SogouNewsNode::GetShardId(int32_t *shard_id) { - *shard_id = shard_id_; - return Status::OK(); -} - -Status SogouNewsNode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, - int64_t *dataset_size) { - if (dataset_size_ > 0) { - *dataset_size = dataset_size_; - return Status::OK(); - } - - int64_t num_rows, sample_size; - RETURN_IF_NOT_OK(SogouNewsOp::CountAllFileRows(sogou_news_files_list_, false, &num_rows)); - sample_size = num_samples_; - num_rows = static_cast(ceil(num_rows / (1.0 * num_shards_))); - *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows; - dataset_size_ = *dataset_size; - return Status::OK(); -} - -Status SogouNewsNode::to_json(nlohmann::json *out_json) { - nlohmann::json args; - args["num_parallel_workers"] = num_workers_; - args["connector_queue_size"] = connector_que_size_; - args["dataset_dir"] = dataset_dir_; - args["usage"] = usage_; - args["num_samples"] = num_samples_; - args["shuffle"] = shuffle_; - args["num_shards"] = num_shards_; - args["shard_id"] = shard_id_; - if (cache_ != nullptr) { - nlohmann::json cache_args; - RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); - args["cache"] = cache_args; - } - *out_json = args; - return Status::OK(); -} - -Status SogouNewsNode::SetupSamplerForCache(std::shared_ptr *sampler) { - bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); - *sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_); - return Status::OK(); -} - -Status SogouNewsNode::MakeSimpleProducer() { - shard_id_ = 0; - num_shards_ = 1; - shuffle_ = ShuffleMode::kFalse; - num_samples_ = 0; - return Status::OK(); -} - -std::vector SogouNewsNode::WalkAllFiles(const std::string &usage, const std::string &dataset_dir) { - std::vector sogou_news_files_list; - Path train_prefix("train.csv"); - Path test_prefix("test.csv"); - Path dir(dataset_dir); - - if (usage == "train") { - Path temp_path = dir / train_prefix; - sogou_news_files_list.push_back(temp_path.ToString()); - } else if (usage == "test") { - Path temp_path = dir / test_prefix; - sogou_news_files_list.push_back(temp_path.ToString()); - } else { - Path temp_path = dir / train_prefix; - if (temp_path.Exists()) { - sogou_news_files_list.push_back(temp_path.ToString()); - } - Path temp_path1 = dir / test_prefix; - if (temp_path1.Exists()) { - sogou_news_files_list.push_back(temp_path1.ToString()); - } - } - return sogou_news_files_list; -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/ir/datasetops/source/sogou_news_node.h" + +#include "minddata/dataset/util/path.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +SogouNewsNode::SogouNewsNode(const std::string &dataset_dir, const std::string &usage, int64_t num_samples, + ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, + const std::shared_ptr &cache) + : NonMappableSourceNode(std::move(cache)), + dataset_dir_(dataset_dir), + num_samples_(num_samples), + shuffle_(shuffle), + num_shards_(num_shards), + shard_id_(shard_id), + usage_(usage), + sogou_news_files_list_(WalkAllFiles(usage, dataset_dir)) { + // Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. + // User discretion is advised. Auto_num_worker_pass is currently an experimental feature which can still work + // if the num_shards_ isn't 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to + // return num_shards. Once PreBuildSampler is phased out, this can be cleaned up. + GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_); +} + +std::shared_ptr SogouNewsNode::Copy() { + auto node = + std::make_shared(dataset_dir_, usage_, num_samples_, shuffle_, num_shards_, shard_id_, cache_); + (void)node->SetNumWorkers(num_workers_); + (void)node->SetConnectorQueueSize(connector_que_size_); + return node; +} + +void SogouNewsNode::Print(std::ostream &out) const { + out << (Name() + "(cache: " + ((cache_ != nullptr) ? "true" : "false") + + ", num_shards: " + std::to_string(num_shards_) + ", shard_id: " + std::to_string(shard_id_) + ")"); +} + +Status SogouNewsNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); + RETURN_IF_NOT_OK(ValidateDatasetDirParam("SogouNewsNode", dataset_dir_)); + RETURN_IF_NOT_OK(ValidateStringValue("SogouNewsNode", usage_, {"train", "test", "all"})); + RETURN_IF_NOT_OK(ValidateEnum("SogouNewsNode", "ShuffleMode", shuffle_, + {ShuffleMode::kFalse, ShuffleMode::kFiles, ShuffleMode::kGlobal})); + if (num_samples_ < 0) { + std::string err_msg = "SogouNewsNode: Invalid number of samples: " + std::to_string(num_samples_); + MS_LOG(ERROR) << err_msg; + LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg); + } + + RETURN_IF_NOT_OK(ValidateDatasetShardParams("SogouNewsNode", num_shards_, shard_id_)); + return Status::OK(); +} + +Status SogouNewsNode::Build(std::vector> *const node_ops) { + bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); + + // Sort the dataset files in a lexicographical order. + std::vector sorted_dataset_files = sogou_news_files_list_; + std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end()); + + std::vector> column_default; + column_default.push_back(std::make_shared>(SogouNewsOp::STRING, "")); + column_default.push_back(std::make_shared>(SogouNewsOp::STRING, "")); + column_default.push_back(std::make_shared>(SogouNewsOp::STRING, "")); + + std::vector column_name = {"index", "title", "content"}; + char field_delim = ','; + auto sogou_news_op = std::make_shared(num_workers_, num_samples_, worker_connector_size_, + connector_que_size_, shuffle_files, num_shards_, shard_id_, + field_delim, column_default, column_name, sogou_news_files_list_); + + RETURN_IF_NOT_OK(sogou_news_op->Init()); + + // If a global shuffle is used for SogouNews, it will inject a shuffle op over the SogouNews. + // But, if there is a cache in the tree, we do not need the global shuffle and the shuffle op should not be + // built.This is achieved in the cache transform pass where we call MakeSimpleProducer to reset SogouNews + // shuffle option to false. + if (shuffle_ == ShuffleMode::kGlobal) { + // Inject ShuffleOp. + std::shared_ptr shuffle_op = nullptr; + int64_t num_rows = 0; + + // First, get the number of rows in the dataset. + RETURN_IF_NOT_OK(SogouNewsOp::CountAllFileRows(sogou_news_files_list_, false, &num_rows)); + // Add the shuffle op after this op. + RETURN_IF_NOT_OK( + AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op)); + shuffle_op->SetTotalRepeats(GetTotalRepeats()); + shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); + shuffle_op->Skip(skip_steps_); + node_ops->push_back(shuffle_op); + } + sogou_news_op->SetTotalRepeats(GetTotalRepeats()); + sogou_news_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); + node_ops->push_back(sogou_news_op); + return Status::OK(); +} + +Status SogouNewsNode::GetShardId(int32_t *shard_id) { + *shard_id = shard_id_; + return Status::OK(); +} + +Status SogouNewsNode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + + int64_t num_rows, sample_size; + RETURN_IF_NOT_OK(SogouNewsOp::CountAllFileRows(sogou_news_files_list_, false, &num_rows)); + sample_size = num_samples_; + num_rows = static_cast(ceil(num_rows / (1.0 * num_shards_))); + *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows; + dataset_size_ = *dataset_size; + return Status::OK(); +} + +Status SogouNewsNode::to_json(nlohmann::json *out_json) { + nlohmann::json args; + args["num_parallel_workers"] = num_workers_; + args["connector_queue_size"] = connector_que_size_; + args["dataset_dir"] = dataset_dir_; + args["usage"] = usage_; + args["num_samples"] = num_samples_; + args["shuffle"] = shuffle_; + args["num_shards"] = num_shards_; + args["shard_id"] = shard_id_; + if (cache_ != nullptr) { + nlohmann::json cache_args; + RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); + args["cache"] = cache_args; + } + *out_json = args; + return Status::OK(); +} + +Status SogouNewsNode::SetupSamplerForCache(std::shared_ptr *sampler) { + bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); + *sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_); + return Status::OK(); +} + +Status SogouNewsNode::MakeSimpleProducer() { + shard_id_ = 0; + num_shards_ = 1; + shuffle_ = ShuffleMode::kFalse; + num_samples_ = 0; + return Status::OK(); +} + +std::vector SogouNewsNode::WalkAllFiles(const std::string &usage, const std::string &dataset_dir) { + std::vector sogou_news_files_list; + Path train_prefix("train.csv"); + Path test_prefix("test.csv"); + Path dir(dataset_dir); + + if (usage == "train") { + Path temp_path = dir / train_prefix; + sogou_news_files_list.push_back(temp_path.ToString()); + } else if (usage == "test") { + Path temp_path = dir / test_prefix; + sogou_news_files_list.push_back(temp_path.ToString()); + } else { + Path temp_path = dir / train_prefix; + if (temp_path.Exists()) { + sogou_news_files_list.push_back(temp_path.ToString()); + } + Path temp_path1 = dir / test_prefix; + if (temp_path1.Exists()) { + sogou_news_files_list.push_back(temp_path1.ToString()); + } + } + return sogou_news_files_list; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/sogou_news_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/sogou_news_node.h index 400564096f1..4fec2a07f04 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/sogou_news_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/sogou_news_node.h @@ -1,135 +1,135 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SOGOU_NEWS_NODE_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SOGOU_NEWS_NODE_H_ - -#include -#include -#include - -#include "minddata/dataset/engine/datasetops/source/sogou_news_op.h" -#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" - -namespace mindspore { -namespace dataset { -/// \class SogouNewsNode -/// \brief A Node derived class to represent SogouNews Node. -class SogouNewsNode : public NonMappableSourceNode { - public: - /// \brief Constructor of SogouNewsNode. - /// \param[in] dataset_dir Path to the root directory that contains the dataset. - /// \param[in] usage Part of dataset of SogouNews, can be "train", "test" or "all" data. - /// \param[in] num_samples The number of samples to be included in the dataset. - /// \param[in] shuffle The mode for shuffling data every epoch. - /// Can be any of: - /// ShuffleMode::kFalse - No shuffling is performed. - /// ShuffleMode::kFiles - Shuffle files only. - /// ShuffleMode::kGlobal - Shuffle both the files and samples. - /// \param[in] num_shards Number of shards that the dataset should be divided into. - /// \param[in] shard_id The shard ID within num_shards. This argument should be - /// specified only when num_shards is also specified. - /// \param[in] cache Tensor cache to use. - SogouNewsNode(const std::string &dataset_dir, const std::string &usage, int64_t num_samples, ShuffleMode shuffle, - int32_t num_shards, int32_t shard_id, const std::shared_ptr &cache); - - /// \brief Destructor. - ~SogouNewsNode() override = default; - - /// \brief Node name getter. - /// \return Name of the current node. - std::string Name() const override { return kSogouNewsNode; } - - /// \brief Print the description. - /// \param[out] out The output stream to write output to. - void Print(std::ostream &out) const override; - - /// \brief Copy the node to a new object. - /// \return A shared pointer to the new copy. - std::shared_ptr Copy() override; - - /// \brief A base class override function to create the required runtime dataset op objects for this class. - /// \param node_ops A vector containing shared pointer to the Dataset Ops that this object will create. - /// \return Status Status::OK() if build successfully. - Status Build(std::vector> *const node_ops) override; - - /// \brief Parameters validation. - /// \return Status Status::OK() if all the parameters are valid. - Status ValidateParams() override; - - /// \brief Get the shard id of node. - /// \param[in] shard_id The shard id. - /// \return Status Status::OK() if get shard id successfully. - Status GetShardId(int32_t *shard_id) override; - - /// \brief Base-class override for GetDatasetSize. - /// \param[in] size_getter Shared pointer to DatasetSizeGetter. - /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting. - /// dataset size at the expense of accuracy. - /// \param[out] dataset_size The size of the dataset. - /// \return Status of the function. - Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, - int64_t *dataset_size) override; - - /// \brief Getter functions. - const std::string &DatasetDir() const { return dataset_dir_; } - const std::string &Usage() const { return usage_; } - int64_t NumSamples() const { return num_samples_; } - ShuffleMode Shuffle() const { return shuffle_; } - int32_t NumShards() const { return num_shards_; } - int32_t ShardId() const { return shard_id_; } - - /// \brief Get the arguments of node. - /// \param[out] out_json JSON string of all attributes. - /// \return Status of the function. - Status to_json(nlohmann::json *out_json) override; - - /// \brief SogouNews by itself is a non-mappable dataset that does not support sampling. - /// However, if a cache operator is injected at some other place higher in the tree, that cache can - /// inherit this sampler from the leaf, providing sampling support from the caching layer. - /// That is why we setup the sampler for a leaf node that does not use sampling. - /// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class. - /// \param[in] sampler The sampler to setup. - /// \return Status of the function. - Status SetupSamplerForCache(std::shared_ptr *sampler) override; - - /// \brief If a cache has been added into the ascendant tree over this clue node, then the cache will be executing - /// a sampler for fetching the data. As such, any options in the clue node need to be reset to its defaults so. - /// that this clue node will produce the full set of data into the cache. - /// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class. - /// \return Status of the function. - Status MakeSimpleProducer() override; - - /// \brief Generate a list of read file names according to usage. - /// \param[in] usage Part of dataset of SogouNews. - /// \param[in] dataset_dir Path to the root directory that contains the dataset. - /// \return std::vector A list of read file names. - std::vector WalkAllFiles(const std::string &usage, const std::string &dataset_dir); - - private: - std::string dataset_dir_; - std::string usage_; - std::vector> column_defaults_; - std::vector column_names_; - int64_t num_samples_; - ShuffleMode shuffle_; - int32_t num_shards_; - int32_t shard_id_; - std::vector sogou_news_files_list_; -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SOGOU_NEWS_NODE_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SOGOU_NEWS_NODE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SOGOU_NEWS_NODE_H_ + +#include +#include +#include + +#include "minddata/dataset/engine/datasetops/source/sogou_news_op.h" +#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" + +namespace mindspore { +namespace dataset { +/// \class SogouNewsNode +/// \brief A Node derived class to represent SogouNews Node. +class SogouNewsNode : public NonMappableSourceNode { + public: + /// \brief Constructor of SogouNewsNode. + /// \param[in] dataset_dir Path to the root directory that contains the dataset. + /// \param[in] usage Part of dataset of SogouNews, can be "train", "test" or "all" data. + /// \param[in] num_samples The number of samples to be included in the dataset. + /// \param[in] shuffle The mode for shuffling data every epoch. + /// Can be any of: + /// ShuffleMode::kFalse - No shuffling is performed. + /// ShuffleMode::kFiles - Shuffle files only. + /// ShuffleMode::kGlobal - Shuffle both the files and samples. + /// \param[in] num_shards Number of shards that the dataset should be divided into. + /// \param[in] shard_id The shard ID within num_shards. This argument should be + /// specified only when num_shards is also specified. + /// \param[in] cache Tensor cache to use. + SogouNewsNode(const std::string &dataset_dir, const std::string &usage, int64_t num_samples, ShuffleMode shuffle, + int32_t num_shards, int32_t shard_id, const std::shared_ptr &cache); + + /// \brief Destructor. + ~SogouNewsNode() override = default; + + /// \brief Node name getter. + /// \return Name of the current node. + std::string Name() const override { return kSogouNewsNode; } + + /// \brief Print the description. + /// \param[out] out The output stream to write output to. + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object. + /// \return A shared pointer to the new copy. + std::shared_ptr Copy() override; + + /// \brief A base class override function to create the required runtime dataset op objects for this class. + /// \param node_ops A vector containing shared pointer to the Dataset Ops that this object will create. + /// \return Status Status::OK() if build successfully. + Status Build(std::vector> *const node_ops) override; + + /// \brief Parameters validation. + /// \return Status Status::OK() if all the parameters are valid. + Status ValidateParams() override; + + /// \brief Get the shard id of node. + /// \param[in] shard_id The shard id. + /// \return Status Status::OK() if get shard id successfully. + Status GetShardId(int32_t *shard_id) override; + + /// \brief Base-class override for GetDatasetSize. + /// \param[in] size_getter Shared pointer to DatasetSizeGetter. + /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting. + /// dataset size at the expense of accuracy. + /// \param[out] dataset_size The size of the dataset. + /// \return Status of the function. + Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) override; + + /// \brief Getter functions. + const std::string &DatasetDir() const { return dataset_dir_; } + const std::string &Usage() const { return usage_; } + int64_t NumSamples() const { return num_samples_; } + ShuffleMode Shuffle() const { return shuffle_; } + int32_t NumShards() const { return num_shards_; } + int32_t ShardId() const { return shard_id_; } + + /// \brief Get the arguments of node. + /// \param[out] out_json JSON string of all attributes. + /// \return Status of the function. + Status to_json(nlohmann::json *out_json) override; + + /// \brief SogouNews by itself is a non-mappable dataset that does not support sampling. + /// However, if a cache operator is injected at some other place higher in the tree, that cache can + /// inherit this sampler from the leaf, providing sampling support from the caching layer. + /// That is why we setup the sampler for a leaf node that does not use sampling. + /// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class. + /// \param[in] sampler The sampler to setup. + /// \return Status of the function. + Status SetupSamplerForCache(std::shared_ptr *sampler) override; + + /// \brief If a cache has been added into the ascendant tree over this clue node, then the cache will be executing + /// a sampler for fetching the data. As such, any options in the clue node need to be reset to its defaults so. + /// that this clue node will produce the full set of data into the cache. + /// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class. + /// \return Status of the function. + Status MakeSimpleProducer() override; + + /// \brief Generate a list of read file names according to usage. + /// \param[in] usage Part of dataset of SogouNews. + /// \param[in] dataset_dir Path to the root directory that contains the dataset. + /// \return std::vector A list of read file names. + std::vector WalkAllFiles(const std::string &usage, const std::string &dataset_dir); + + private: + std::string dataset_dir_; + std::string usage_; + std::vector> column_defaults_; + std::vector column_names_; + int64_t num_samples_; + ShuffleMode shuffle_; + int32_t num_shards_; + int32_t shard_id_; + std::vector sogou_news_files_list_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SOGOU_NEWS_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/usps_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/usps_node.cc index 257fb0f829a..4fea97209f0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/usps_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/usps_node.cc @@ -1,171 +1,171 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/engine/ir/datasetops/source/usps_node.h" - -#include -#include -#include -#include -#include - -#include "minddata/dataset/engine/datasetops/source/usps_op.h" -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { -USPSNode::USPSNode(const std::string &dataset_dir, const std::string &usage, int32_t num_samples, ShuffleMode shuffle, - int32_t num_shards, int32_t shard_id, std::shared_ptr cache) - : NonMappableSourceNode(std::move(cache)), - dataset_dir_(dataset_dir), - usage_(usage), - num_samples_(num_samples), - shuffle_(shuffle), - num_shards_(num_shards), - shard_id_(shard_id) { - // Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. User discretion - // is advised. Auto_num_worker_pass is currently an experimental feature which can still work if the num_shards_ isn't - // 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to return num_shards. Once - // PreBuildSampler is phased out, this can be cleaned up. - GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_); -} - -std::shared_ptr USPSNode::Copy() { - auto node = std::make_shared(dataset_dir_, usage_, num_samples_, shuffle_, num_shards_, shard_id_, cache_); - (void)node->SetNumWorkers(num_workers_); - (void)node->SetConnectorQueueSize(connector_que_size_); - return node; -} - -void USPSNode::Print(std::ostream &out) const { - out << (Name() + "(dataset dir:" + dataset_dir_ + ", usage:" + usage_ + - ", num_shards:" + std::to_string(num_shards_) + ", shard_id:" + std::to_string(shard_id_) + - ", num_samples:" + std::to_string(num_samples_) + ")"); -} - -Status USPSNode::ValidateParams() { - RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); - RETURN_IF_NOT_OK(ValidateDatasetDirParam("USPSDataset", dataset_dir_)); - RETURN_IF_NOT_OK(ValidateStringValue("USPSDataset", usage_, {"train", "test", "all"})); - RETURN_IF_NOT_OK(ValidateScalar("USPSDataset", "num_samples", num_samples_, {0}, false)); - RETURN_IF_NOT_OK(ValidateDatasetShardParams("USPSDataset", num_shards_, shard_id_)); - RETURN_IF_NOT_OK(ValidateEnum("USPSDataset", "ShuffleMode", shuffle_, - {ShuffleMode::kFalse, ShuffleMode::kFiles, ShuffleMode::kGlobal})); - return Status::OK(); -} - -Status USPSNode::Build(std::vector> *const node_ops) { - bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); - - // Do internal Schema generation. - auto schema = std::make_unique(); - RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1))); - TensorShape scalar = TensorShape::CreateScalar(); - RETURN_IF_NOT_OK( - schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); - - auto op = std::make_shared(dataset_dir_, usage_, std::move(schema), num_workers_, worker_connector_size_, - num_samples_, connector_que_size_, shuffle_files, num_shards_, shard_id_); - RETURN_IF_NOT_OK(op->Init()); - - // If a global shuffle is used for USPS, it will inject a shuffle op over the USPS. - // But, if there is a cache in the tree, we do not need the global shuffle and the shuffle op should not be built. - // This is achieved in the cache transform pass where we call MakeSimpleProducer to reset USPS's shuffle - // option to false. - if (shuffle_ == ShuffleMode::kGlobal) { - // Inject ShuffleOp - std::shared_ptr shuffle_op = nullptr; - int64_t num_rows = 0; - - // First, get the number of rows in the dataset - RETURN_IF_NOT_OK(USPSOp::CountTotalRows(dataset_dir_, usage_, &num_rows)); - - // Add the shuffle op after this op - RETURN_IF_NOT_OK(AddShuffleOp(op->FileNames().size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op)); - shuffle_op->SetTotalRepeats(GetTotalRepeats()); - shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); - shuffle_op->Skip(skip_steps_); - node_ops->push_back(shuffle_op); - } - op->SetTotalRepeats(GetTotalRepeats()); - op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); - node_ops->push_back(op); - return Status::OK(); -} - -// Get the shard id of node -Status USPSNode::GetShardId(int32_t *shard_id) { - *shard_id = shard_id_; - return Status::OK(); -} - -// Get Dataset size -Status USPSNode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, - int64_t *dataset_size) { - if (dataset_size_ > 0) { - *dataset_size = dataset_size_; - return Status::OK(); - } - int64_t num_rows, sample_size = num_samples_; - RETURN_IF_NOT_OK(USPSOp::CountTotalRows(dataset_dir_, usage_, &num_rows)); - num_rows = static_cast(ceil(num_rows / (1.0 * num_shards_))); - *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows; - dataset_size_ = *dataset_size; - return Status::OK(); -} - -Status USPSNode::to_json(nlohmann::json *out_json) { - nlohmann::json args; - args["num_parallel_workers"] = num_workers_; - args["connector_queue_size"] = connector_que_size_; - args["dataset_dir"] = dataset_dir_; - args["usage"] = usage_; - args["num_samples"] = num_samples_; - args["shuffle"] = shuffle_; - args["num_shards"] = num_shards_; - args["shard_id"] = shard_id_; - if (cache_ != nullptr) { - nlohmann::json cache_args; - RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); - args["cache"] = cache_args; - } - *out_json = args; - return Status::OK(); -} - -// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class. -// USPS by itself is a non-mappable dataset that does not support sampling. -// However, if a cache operator is injected at some other place higher in the tree, that cache can -// inherit this sampler from the leaf, providing sampling support from the caching layer. -// That is why we setup the sampler for a leaf node that does not use sampling. -Status USPSNode::SetupSamplerForCache(std::shared_ptr *sampler) { - bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); - *sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_); - return Status::OK(); -} - -// If a cache has been added into the ascendant tree over this USPS node, then the cache will be executing -// a sampler for fetching the data. As such, any options in the USPS node need to be reset to its defaults so -// that this USPS node will produce the full set of data into the cache. -Status USPSNode::MakeSimpleProducer() { - shard_id_ = 0; - num_shards_ = 1; - shuffle_ = ShuffleMode::kFalse; - num_samples_ = 0; - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/engine/ir/datasetops/source/usps_node.h" + +#include +#include +#include +#include +#include + +#include "minddata/dataset/engine/datasetops/source/usps_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +USPSNode::USPSNode(const std::string &dataset_dir, const std::string &usage, int32_t num_samples, ShuffleMode shuffle, + int32_t num_shards, int32_t shard_id, std::shared_ptr cache) + : NonMappableSourceNode(std::move(cache)), + dataset_dir_(dataset_dir), + usage_(usage), + num_samples_(num_samples), + shuffle_(shuffle), + num_shards_(num_shards), + shard_id_(shard_id) { + // Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. User discretion + // is advised. Auto_num_worker_pass is currently an experimental feature which can still work if the num_shards_ isn't + // 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to return num_shards. Once + // PreBuildSampler is phased out, this can be cleaned up. + GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_); +} + +std::shared_ptr USPSNode::Copy() { + auto node = std::make_shared(dataset_dir_, usage_, num_samples_, shuffle_, num_shards_, shard_id_, cache_); + (void)node->SetNumWorkers(num_workers_); + (void)node->SetConnectorQueueSize(connector_que_size_); + return node; +} + +void USPSNode::Print(std::ostream &out) const { + out << (Name() + "(dataset dir:" + dataset_dir_ + ", usage:" + usage_ + + ", num_shards:" + std::to_string(num_shards_) + ", shard_id:" + std::to_string(shard_id_) + + ", num_samples:" + std::to_string(num_samples_) + ")"); +} + +Status USPSNode::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); + RETURN_IF_NOT_OK(ValidateDatasetDirParam("USPSDataset", dataset_dir_)); + RETURN_IF_NOT_OK(ValidateStringValue("USPSDataset", usage_, {"train", "test", "all"})); + RETURN_IF_NOT_OK(ValidateScalar("USPSDataset", "num_samples", num_samples_, {0}, false)); + RETURN_IF_NOT_OK(ValidateDatasetShardParams("USPSDataset", num_shards_, shard_id_)); + RETURN_IF_NOT_OK(ValidateEnum("USPSDataset", "ShuffleMode", shuffle_, + {ShuffleMode::kFalse, ShuffleMode::kFiles, ShuffleMode::kGlobal})); + return Status::OK(); +} + +Status USPSNode::Build(std::vector> *const node_ops) { + bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); + + // Do internal Schema generation. + auto schema = std::make_unique(); + RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1))); + TensorShape scalar = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK( + schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); + + auto op = std::make_shared(dataset_dir_, usage_, std::move(schema), num_workers_, worker_connector_size_, + num_samples_, connector_que_size_, shuffle_files, num_shards_, shard_id_); + RETURN_IF_NOT_OK(op->Init()); + + // If a global shuffle is used for USPS, it will inject a shuffle op over the USPS. + // But, if there is a cache in the tree, we do not need the global shuffle and the shuffle op should not be built. + // This is achieved in the cache transform pass where we call MakeSimpleProducer to reset USPS's shuffle + // option to false. + if (shuffle_ == ShuffleMode::kGlobal) { + // Inject ShuffleOp + std::shared_ptr shuffle_op = nullptr; + int64_t num_rows = 0; + + // First, get the number of rows in the dataset + RETURN_IF_NOT_OK(USPSOp::CountTotalRows(dataset_dir_, usage_, &num_rows)); + + // Add the shuffle op after this op + RETURN_IF_NOT_OK(AddShuffleOp(op->FileNames().size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op)); + shuffle_op->SetTotalRepeats(GetTotalRepeats()); + shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); + shuffle_op->Skip(skip_steps_); + node_ops->push_back(shuffle_op); + } + op->SetTotalRepeats(GetTotalRepeats()); + op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); + node_ops->push_back(op); + return Status::OK(); +} + +// Get the shard id of node +Status USPSNode::GetShardId(int32_t *shard_id) { + *shard_id = shard_id_; + return Status::OK(); +} + +// Get Dataset size +Status USPSNode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t num_rows, sample_size = num_samples_; + RETURN_IF_NOT_OK(USPSOp::CountTotalRows(dataset_dir_, usage_, &num_rows)); + num_rows = static_cast(ceil(num_rows / (1.0 * num_shards_))); + *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows; + dataset_size_ = *dataset_size; + return Status::OK(); +} + +Status USPSNode::to_json(nlohmann::json *out_json) { + nlohmann::json args; + args["num_parallel_workers"] = num_workers_; + args["connector_queue_size"] = connector_que_size_; + args["dataset_dir"] = dataset_dir_; + args["usage"] = usage_; + args["num_samples"] = num_samples_; + args["shuffle"] = shuffle_; + args["num_shards"] = num_shards_; + args["shard_id"] = shard_id_; + if (cache_ != nullptr) { + nlohmann::json cache_args; + RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); + args["cache"] = cache_args; + } + *out_json = args; + return Status::OK(); +} + +// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class. +// USPS by itself is a non-mappable dataset that does not support sampling. +// However, if a cache operator is injected at some other place higher in the tree, that cache can +// inherit this sampler from the leaf, providing sampling support from the caching layer. +// That is why we setup the sampler for a leaf node that does not use sampling. +Status USPSNode::SetupSamplerForCache(std::shared_ptr *sampler) { + bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); + *sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_); + return Status::OK(); +} + +// If a cache has been added into the ascendant tree over this USPS node, then the cache will be executing +// a sampler for fetching the data. As such, any options in the USPS node need to be reset to its defaults so +// that this USPS node will produce the full set of data into the cache. +Status USPSNode::MakeSimpleProducer() { + shard_id_ = 0; + num_shards_ = 1; + shuffle_ = ShuffleMode::kFalse; + num_samples_ = 0; + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/usps_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/usps_node.h index ccb9f20cdc1..8ce700b84d5 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/usps_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/usps_node.h @@ -1,120 +1,120 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_USPS_NODE_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_USPS_NODE_H_ - -#include -#include -#include - -#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" - -namespace mindspore { -namespace dataset { -class USPSNode : public NonMappableSourceNode { - public: - /// \brief Constructor. - USPSNode(const std::string &dataset_dir, const std::string &usage, int32_t num_samples, ShuffleMode shuffle, - int32_t num_shards, int32_t shard_id, std::shared_ptr cache); - - /// \brief Destructor. - ~USPSNode() override = default; - - /// \brief Node name getter. - /// \return Name of the current node. - std::string Name() const override { return kUSPSNode; } - - /// \brief Print the description. - /// \param out - The output stream to write output to. - void Print(std::ostream &out) const override; - - /// \brief Copy the node to a new object. - /// \return A shared pointer to the new copy. - std::shared_ptr Copy() override; - - /// \brief a base class override function to create the required runtime dataset op objects for this class. - /// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create. - /// \return Status Status::OK() if build successfully. - Status Build(std::vector> *const node_ops) override; - - /// \brief Parameters validation. - /// \return Status Status::OK() if all the parameters are valid. - Status ValidateParams() override; - - /// \brief Get the shard id of node. - /// \return Status Status::OK() if get shard id successfully. - Status GetShardId(int32_t *shard_id) override; - - /// \brief Base-class override for GetDatasetSize. - /// \param[in] size_getter Shared pointer to DatasetSizeGetter. - /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting - /// dataset size at the expense of accuracy. - /// \param[out] dataset_size the size of the dataset. - /// \return Status of the function. - Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, - int64_t *dataset_size) override; - - /// \brief Getter functions. - const std::string &DatasetDir() const { return dataset_dir_; } - - /// \brief Getter functions. - const std::string &Usage() const { return usage_; } - - /// \brief Getter functions. - int32_t NumSamples() const { return num_samples_; } - - /// \brief Getter functions. - int32_t NumShards() const { return num_shards_; } - - /// \brief Getter functions. - int32_t ShardId() const { return shard_id_; } - - /// \brief Getter functions. - ShuffleMode Shuffle() const { return shuffle_; } - - /// \brief Get the arguments of node. - /// \param[out] out_json JSON string of all attributes. - /// \return Status of the function. - Status to_json(nlohmann::json *out_json) override; - - /// \brief USPS by itself is a non-mappable dataset that does not support sampling. - /// However, if a cache operator is injected at some other place higher in the tree, that cache can - /// inherit this sampler from the leaf, providing sampling support from the caching layer. - /// That is why we setup the sampler for a leaf node that does not use sampling. - /// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class. - /// \param[in] sampler The sampler to setup. - /// \return Status of the function. - Status SetupSamplerForCache(std::shared_ptr *sampler) override; - - /// \brief If a cache has been added into the ascendant tree over this USPS node, then the cache will be executing - /// a sampler for fetching the data. As such, any options in the USPS node need to be reset to its defaults - /// so that this USPS node will produce the full set of data into the cache. - /// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class. - /// \return Status of the function. - Status MakeSimpleProducer() override; - - private: - std::string dataset_dir_; - std::string usage_; - int32_t num_samples_; - ShuffleMode shuffle_; - int32_t num_shards_; - int32_t shard_id_; -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_USPS_NODE_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_USPS_NODE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_USPS_NODE_H_ + +#include +#include +#include + +#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" + +namespace mindspore { +namespace dataset { +class USPSNode : public NonMappableSourceNode { + public: + /// \brief Constructor. + USPSNode(const std::string &dataset_dir, const std::string &usage, int32_t num_samples, ShuffleMode shuffle, + int32_t num_shards, int32_t shard_id, std::shared_ptr cache); + + /// \brief Destructor. + ~USPSNode() override = default; + + /// \brief Node name getter. + /// \return Name of the current node. + std::string Name() const override { return kUSPSNode; } + + /// \brief Print the description. + /// \param out - The output stream to write output to. + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object. + /// \return A shared pointer to the new copy. + std::shared_ptr Copy() override; + + /// \brief a base class override function to create the required runtime dataset op objects for this class. + /// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create. + /// \return Status Status::OK() if build successfully. + Status Build(std::vector> *const node_ops) override; + + /// \brief Parameters validation. + /// \return Status Status::OK() if all the parameters are valid. + Status ValidateParams() override; + + /// \brief Get the shard id of node. + /// \return Status Status::OK() if get shard id successfully. + Status GetShardId(int32_t *shard_id) override; + + /// \brief Base-class override for GetDatasetSize. + /// \param[in] size_getter Shared pointer to DatasetSizeGetter. + /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting + /// dataset size at the expense of accuracy. + /// \param[out] dataset_size the size of the dataset. + /// \return Status of the function. + Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) override; + + /// \brief Getter functions. + const std::string &DatasetDir() const { return dataset_dir_; } + + /// \brief Getter functions. + const std::string &Usage() const { return usage_; } + + /// \brief Getter functions. + int32_t NumSamples() const { return num_samples_; } + + /// \brief Getter functions. + int32_t NumShards() const { return num_shards_; } + + /// \brief Getter functions. + int32_t ShardId() const { return shard_id_; } + + /// \brief Getter functions. + ShuffleMode Shuffle() const { return shuffle_; } + + /// \brief Get the arguments of node. + /// \param[out] out_json JSON string of all attributes. + /// \return Status of the function. + Status to_json(nlohmann::json *out_json) override; + + /// \brief USPS by itself is a non-mappable dataset that does not support sampling. + /// However, if a cache operator is injected at some other place higher in the tree, that cache can + /// inherit this sampler from the leaf, providing sampling support from the caching layer. + /// That is why we setup the sampler for a leaf node that does not use sampling. + /// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class. + /// \param[in] sampler The sampler to setup. + /// \return Status of the function. + Status SetupSamplerForCache(std::shared_ptr *sampler) override; + + /// \brief If a cache has been added into the ascendant tree over this USPS node, then the cache will be executing + /// a sampler for fetching the data. As such, any options in the USPS node need to be reset to its defaults + /// so that this USPS node will produce the full set of data into the cache. + /// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class. + /// \return Status of the function. + Status MakeSimpleProducer() override; + + private: + std::string dataset_dir_; + std::string usage_; + int32_t num_samples_; + ShuffleMode shuffle_; + int32_t num_shards_; + int32_t shard_id_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_USPS_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/include/dataset/audio.h b/mindspore/ccsrc/minddata/dataset/include/dataset/audio.h index a231513cd07..9caa4ac65ee 100644 --- a/mindspore/ccsrc/minddata/dataset/include/dataset/audio.h +++ b/mindspore/ccsrc/minddata/dataset/include/dataset/audio.h @@ -1,1437 +1,1437 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_AUDIO_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_AUDIO_H_ - -#include -#include -#include -#include -#include -#include - -#include "include/api/dual_abi_helper.h" -#include "include/api/status.h" -#include "include/api/types.h" -#include "include/dataset/constants.h" -#include "include/dataset/transforms.h" - -namespace mindspore { -namespace dataset { -class TensorOperation; - -// Transform operations for performing computer audio. -namespace audio { -/// \brief Compute the angle of complex tensor input. -class DATASET_API Angle final : public TensorTransform { - public: - /// \brief Constructor. - Angle(); - - /// \brief Destructor. - ~Angle() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; -}; - -/// \brief Design two-pole allpass filter. Similar to SoX implementation. -class DATASET_API AllpassBiquad final : public TensorTransform { - public: - /// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero. - /// \param[in] central_freq Central frequency (in Hz). - /// \param[in] Q Quality factor, https://en.wikipedia.org/wiki/Q_factor, range: (0, 1]. Default: 0.707. - explicit AllpassBiquad(int32_t sample_rate, float central_freq, float Q = 0.707); - - /// \brief Destructor. - ~AllpassBiquad() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief AmplitudeToDB TensorTransform. -/// \notes Turn a tensor from the power/amplitude scale to the decibel scale. -class DATASET_API AmplitudeToDB final : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] stype Scale of input tensor, must be one of [ScaleType::kPower, ScaleType::kMagnitude]. - /// Default: ScaleType::kPower. - /// \param[in] ref_value Calculate db_multiplier. Default: 1.0. - /// \param[in] amin Minimum threshold for input tensor and ref_value. It must be greater than zero. Default: 1e-10. - /// \param[in] top_db Decibels cut-off value. It must be greater than or equal to zero. Default: 80.0. - explicit AmplitudeToDB(ScaleType stype = ScaleType::kPower, float ref_value = 1.0, float amin = 1e-10, - float top_db = 80.0); - - /// \brief Destructor. - ~AmplitudeToDB() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief Design two-pole band filter. -class DATASET_API BandBiquad final : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero. - /// \param[in] central_freq Central frequency (in Hz). - /// \param[in] Q Quality factor, https://en.wikipedia.org/wiki/Q_factor, range: (0, 1]. Default: 0.707. - /// \param[in] noise Choose alternate mode for un-pitched audio or mode oriented to pitched audio. Default: False. - explicit BandBiquad(int32_t sample_rate, float central_freq, float Q = 0.707, bool noise = false); - - /// \brief Destructor. - ~BandBiquad() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief Design two-pole band-pass filter. -class DATASET_API BandpassBiquad final : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero. - /// \param[in] central_freq Central frequency (in Hz). - /// \param[in] Q Quality factor, https://en.wikipedia.org/wiki/Q_factor, range: (0, 1]. Default: 0.707. - /// \param[in] const_skirt_gain, If True, uses a constant skirt gain (peak gain = Q). If False, uses a - /// constant 0dB peak gain. Default: False. - explicit BandpassBiquad(int32_t sample_rate, float central_freq, float Q = 0.707, bool const_skirt_gain = false); - - /// \brief Destructor. - ~BandpassBiquad() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief Design two-pole band-reject filter. Similar to SoX implementation. -class DATASET_API BandrejectBiquad final : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero. - /// \param[in] central_freq Central frequency (in Hz). - /// \param[in] Q Quality factor, https://en.wikipedia.org/wiki/Q_factor, range: (0, 1]. Default: 0.707. - explicit BandrejectBiquad(int32_t sample_rate, float central_freq, float Q = 0.707); - - /// \brief Destructor. - ~BandrejectBiquad() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief Design a bass tone-control effect. -class DATASET_API BassBiquad final : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero. - /// \param[in] gain Desired gain at the boost (or attenuation) in dB. - /// \param[in] central_freq Central frequency (in Hz). Default: 100.0. - /// \param[in] Q Quality factor, https://en.wikipedia.org/wiki/Q_factor, range: (0, 1]. Default: 0.707. - explicit BassBiquad(int32_t sample_rate, float gain, float central_freq = 100.0, float Q = 0.707); - - /// \brief Destructor. - ~BassBiquad() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief Perform a biquad filter of input tensor. -class DATASET_API Biquad final : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] b0 Numerator coefficient of current input, x[n]. - /// \param[in] b1 Numerator coefficient of input one time step ago x[n-1]. - /// \param[in] b2 Numerator coefficient of input two time steps ago x[n-2]. - /// \param[in] a0 Denominator coefficient of current output y[n], the value can't be zero, typically 1. - /// \param[in] a1 Denominator coefficient of current output y[n-1]. - /// \param[in] a2 Denominator coefficient of current output y[n-2]. - explicit Biquad(float b0, float b1, float b2, float a0, float a1, float a2); - - /// \brief Destructor. - ~Biquad() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief ComplexNorm TensorTransform. -/// \notes Compute the norm of complex tensor input. -class DATASET_API ComplexNorm final : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] power Power of the norm, which must be non-negative. Default: 1.0. - explicit ComplexNorm(float power = 1.0); - - /// \brief Destructor. - ~ComplexNorm() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief ComputeDeltas Transform. -/// \note Compute delta coefficients of a spectrogram. -class DATASET_API ComputeDeltas final : public TensorTransform { - public: - /// \brief Construct a new Compute Deltas object. - /// \f[ - /// d_{t}=\frac{{\textstyle\sum_{n=1}^{N}}n(c_{t+n}-c_{t-n})}{2{\textstyle\sum_{n=1}^{N}}n^{2}} - /// \f] - /// \param[in] win_length The window length used for computing delta, must be no less than 3. Default: 5. - /// \param[in] pad_mode Padding mode. Can be one of BorderType::kConstant, BorderType::kEdge, - /// BorderType::kReflect or BorderType::kSymmetric. Default: BorderType::kEdge. - explicit ComputeDeltas(int32_t win_length = 5, BorderType pad_mode = BorderType::kEdge); - - /// \brief Destructor. - ~ComputeDeltas() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief Apply contrast effect. -class DATASET_API Contrast final : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] enhancement_amount Controls the amount of the enhancement. Default: 75.0. - explicit Contrast(float enhancement_amount = 75.0); - - /// \brief Destructor. - ~Contrast() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief Turn a waveform from the decibel scale to the power/amplitude scale. -class DATASET_API DBToAmplitude final : public TensorTransform { - public: - /// \brief Constructor - /// \param[in] ref Reference which the output will be scaled by. - /// \param[in] power If power equals 1, will compute DB to power. If 0.5, will compute DB to amplitude. - explicit DBToAmplitude(float ref, float power); - - /// \brief Destructor. - ~DBToAmplitude() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief Apply a DC shift to the audio. -class DATASET_API DCShift : public TensorTransform { - public: - /// \brief Constructor - /// \param[in] shift Indicates the amount to shift the audio, the value must be in the range [-2.0, 2.0]. - /// \param[in] limiter_gain Used only on peaks to prevent clipping. - DCShift(float shift, float limiter_gain); - - /// \brief Constructor - /// \param[in] shift Indicates the amount to shift the audio. - /// \note This constructor will use `shift` as `limiter_gain`. - explicit DCShift(float shift); - - /// \brief Destructor. - ~DCShift() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \param[in] n_mfcc Number of mfc coefficients to retain, the value must be greater than 0. -/// \param[in] n_mels Number of mel filterbanks, the value must be greater than 0. -/// \param[in] norm Norm to use, can be NormMode::kNone or NormMode::kOrtho. -/// \return Status error code, returns OK if no error encountered. -Status CreateDct(mindspore::MSTensor *output, int32_t n_mfcc, int32_t n_mels, NormMode norm = NormMode::kNone); - -/// \brief Design two-pole deemph filter. Similar to SoX implementation. -class DATASET_API DeemphBiquad final : public TensorTransform { - public: - /// \param[in] sample_rate Sampling rate of the waveform, the value can only be 44100 (Hz) or 48000(hz). - explicit DeemphBiquad(int32_t sample_rate); - - /// \brief Destructor. - ~DeemphBiquad() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief Detect pitch frequency. -class DATASET_API DetectPitchFrequency final : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero. - /// \param[in] frame_time Duration of a frame, the value must be greater than zero. Default: 0.02. - /// \param[in] win_length The window length for median smoothing (in number of frames), the value must - /// be greater than zero. Default: 30. - /// \param[in] freq_low Lowest frequency that can be detected (Hz), the value must be greater than zero. Default: 85. - /// \param[in] freq_high Highest frequency that can be detected (Hz), the value must be greater than - /// zero. Default: 3400. - explicit DetectPitchFrequency(int32_t sample_rate, float frame_time = 0.01, int32_t win_length = 30, - int32_t freq_low = 85, int32_t freq_high = 3400); - - /// \brief Destructor. - ~DetectPitchFrequency() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief Dither increases the perceived dynamic range of audio stored at a -/// particular bit-depth by eliminating nonlinear truncation distortion. -class DATASET_API Dither final : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] density_function The density function of a continuous random variable. - /// Can be one of DensityFunction::kTPDF (Triangular Probability Density Function), - /// DensityFunction::kRPDF (Rectangular Probability Density Function) or - /// DensityFunction::kGPDF (Gaussian Probability Density Function). Default: DensityFunction::kTPDF. - /// \param[in] noise_shaping A filtering process that shapes the spectral energy of - /// quantisation error. Default: false. - explicit Dither(DensityFunction density_function = DensityFunction::kTPDF, bool noise_shaping = false); - - /// \brief Destructor. - ~Dither() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief EqualizerBiquad TensorTransform. Apply highpass biquad filter on audio. -class DATASET_API EqualizerBiquad final : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero. - /// \param[in] center_freq Filter's central frequency (in Hz). - /// \param[in] gain Desired gain at the boost (or attenuation) in dB. - /// \param[in] Q Quality factor, https://en.wikipedia.org/wiki/Q_factor, range: (0, 1]. Default: 0.707. - EqualizerBiquad(int32_t sample_rate, float center_freq, float gain, float Q = 0.707); - - /// \brief Destructor. - ~EqualizerBiquad() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief Add fade in or/and fade out on the input audio. -class DATASET_API Fade final : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] fade_in_len Length of fade-in (time frames), which must be non-negative - /// and no more than the length of waveform. Default: 0. - /// \param[in] fade_out_len Length of fade-out (time frames), which must be non-negative - /// and no more than the length of waveform. Default: 0. - /// \param[in] fade_shape An enum for the fade shape. Default: FadeShape::kLinear. - explicit Fade(int32_t fade_in_len = 0, int32_t fade_out_len = 0, FadeShape fade_shape = FadeShape::kLinear); - - /// \brief Destructor. - ~Fade() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief Design IIR forward and backward filter. -class DATASET_API Filtfilt final : public TensorTransform { - public: - /// \param[in] a_coeffs Numerator coefficients of difference equation of dimension of (n_order + 1). - /// Lower delays coefficients are first, e.g. [a0, a1, a2, ...]. - /// Must be same size as b_coeffs (pad with 0's as necessary). - /// \param[in] b_coeffs Numerator coefficients of difference equation of dimension of (n_order + 1). - /// Lower delays coefficients are first, e.g. [b0, b1, b2, ...]. - /// Must be same size as a_coeffs (pad with 0's as necessary). - /// \param[in] clamp If True, clamp the output signal to be in the range [-1, 1]. Default: True. - Filtfilt(const std::vector &a_coeffs, const std::vector &b_coeffs, bool clamp = true); - - /// \brief Destructor. - ~Filtfilt() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief Apply a flanger effect to the audio. -class DATASET_API Flanger final : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz). - /// \param[in] delay Desired delay in milliseconds (ms), range: [0, 30]. Default: 0.0. - /// \param[in] depth Desired delay depth in milliseconds (ms), range: [0, 10]. Default: 2.0. - /// \param[in] regen Desired regen (feedback gain) in dB., range: [-95, 95]. Default: 0.0. - /// \param[in] width Desired width (delay gain) in dB, range: [0, 100]. Default: 71.0. - /// \param[in] speed Modulation speed in Hz, range: [0.1, 10]. Default: 0.5. - /// \param[in] phase Percentage phase-shift for multi-channel, range: [0, 100]. Default: 25.0. - /// \param[in] modulation Modulation of input tensor, must be one of [Modulation::kSinusoidal, - /// Modulation::kTriangular]. Default:Modulation::kSinusoidal. - /// \param[in] interpolation Interpolation of input tensor, must be one of [Interpolation::kLinear, - /// Interpolation::kQuadratic]. Default:Interpolation::kLinear. - explicit Flanger(int32_t sample_rate, float delay = 0.0, float depth = 2.0, float regen = 0.0, float width = 71.0, - float speed = 0.5, float phase = 25.0, Modulation modulation = Modulation::kSinusoidal, - Interpolation interpolation = Interpolation::kLinear); - - /// \brief Destructor. - ~Flanger() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief FrequencyMasking TensorTransform. -/// \notes Apply masking to a spectrogram in the frequency domain. -class DATASET_API FrequencyMasking final : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] iid_masks Whether to apply different masks to each example. - /// \param[in] frequency_mask_param Maximum possible length of the mask, range: [0, freq_length]. Default: 0. - /// Indices uniformly sampled from [0, frequency_mask_param]. - /// Mask width when iid_masks=true. - /// \param[in] mask_start Mask start when iid_masks=true, range: [0, freq_length-frequency_mask_param]. Default: 0. - /// \param[in] mask_value Mask value. - explicit FrequencyMasking(bool iid_masks = false, int32_t frequency_mask_param = 0, int32_t mask_start = 0, - float mask_value = 0.0); - - /// \brief Destructor. - ~FrequencyMasking() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief Apply amplification or attenuation to the whole waveform. -class DATASET_API Gain final : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] gain_db Gain adjustment in decibels (dB). Default: 1.0. - explicit Gain(float gain_db = 1.0); - - /// \brief Destructor. - ~Gain() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief Waveform calculation from linear scalar amplitude spectrogram using GriffinLim transform. -class DATASET_API GriffinLim final : public TensorTransform { - public: - /// \brief Constructor. - /// \notes Calculated by formula: - /// x(n)=\frac{\sum_{m=-\infty}^{\infty} w(m S-n) y_{w}(m S, n)}{\sum_{m=-\infty}^{\infty} w^{2}(m S-n)} - /// where w represents the window function, y represents the reconstructed signal of each frame and x represents - /// the whole signal. - /// \param[in] n_fft Size of FFT. Default: 400. - /// \param[in] n_iter Number of iteration for phase recovery. Default: 32. - /// \param[in] win_length Window size for GriffinLim. Default: 0, will be set to n_fft. - /// \param[in] hop_length Length of hop between STFT windows. Default: 0, will be set to win_length / 2. - /// \param[in] window_type Window type for GriffinLim. Default: WindowType::kHann. - /// \param[in] power Exponent for the magnitude spectrogram. Default: 2.0. - /// \param[in] momentum The momentum for fast Griffin-Lim. Default: 0.99. - /// \param[in] length Length of the expected output waveform. Default: 0.0, will be set to the value of last - /// dimension of the stft matrix. - /// \param[in] rand_init Flag for random phase initialization or all-zero phase initialization. Default: true. - explicit GriffinLim(int32_t n_fft = 400, int32_t n_iter = 32, int32_t win_length = 0, int32_t hop_length = 0, - WindowType window_type = WindowType::kHann, float power = 2.0, float momentum = 0.99, - int32_t length = 0, bool rand_init = true); - - /// \brief Destructor. - ~GriffinLim() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief HighpassBiquad TensorTransform. Apply highpass biquad filter on audio. -class DATASET_API HighpassBiquad final : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero. - /// \param[in] cutoff_freq Filter cutoff frequency (in Hz). - /// \param[in] Q Quality factor, https://en.wikipedia.org/wiki/Q_factor, range: (0, 1]. Default: 0.707. - HighpassBiquad(int32_t sample_rate, float cutoff_freq, float Q = 0.707); - - /// \brief Destructor. - ~HighpassBiquad() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief InverseMelScale TensorTransform -/// \notes Solve for a normal STFT from a mel frequency STFT, using a conversion matrix. -class DATASET_API InverseMelScale final : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] n_stft Number of bins in STFT, must be positive. - /// \param[in] n_mels Number of mel filter, must be positive. Default: 128. - /// \param[in] sample_rate Sample rate of the signal, the value can't be zero. Default: 16000. - /// \param[in] f_min Minimum frequency, must be non-negative. Default: 0.0. - /// \param[in] f_max Maximum frequency, must be non-negative. Default: 0.0, will be set to sample_rate / 2. - /// \param[in] max_iter Maximum number of optimization iterations, must be positive. Default: 100000. - /// \param[in] tolerance_loss Value of loss to stop optimization at, must be non-negative. Default: 1e-5. - /// \param[in] tolerance_change Difference in losses to stop optimization at, must be non-negative. Default: 1e-8. - /// \param[in] sgdargs Parameters of SGD optimizer, including lr, momentum. - /// Default: {{"sgd_lr", 0.1}, {"sgd_momentum", 0.0}}. - /// \param[in] norm Type of norm, value should be NormType::kSlaney or NormType::kNone. If norm is NormType::kSlaney, - /// divide the triangle mel weight by the width of the mel band. Default: NormType::kNone. - /// \param[in] mel_type Type of mel, value should be MelType::kHtk or MelType::kSlaney. Default: MelType::kHtk. - explicit InverseMelScale(int32_t n_stft, int32_t n_mels = 128, int32_t sample_rate = 16000, float f_min = 0.0, - float f_max = 0.0, int32_t max_iter = 100000, float tolerance_loss = 1e-5, - float tolerance_change = 1e-8, - const std::map &sgdargs = {{"sgd_lr", 0.1}, {"sgd_momentum", 0.0}}, - NormType norm = NormType::kNone, MelType mel_type = MelType::kHtk); - - /// \brief Destructor. - ~InverseMelScale() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief Create an inverse spectrogram to recover an audio signal from a spectrogram. -class DATASET_API InverseSpectrogram final : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] length The output length of the waveform. Default: 0, means to output the whole waveform. - /// \param[in] n_fft Size of FFT, creates n_fft // 2 + 1 bins. Default: 400. - /// \param[in] win_length Window size. Default: 0, will be set to `n_fft` . - /// \param[in] hop_length Length of hop between STFT windows. Default: 0, will be set to `win_length // 2` . - /// \param[in] pad Two sided padding of signal. Default: 0. - /// \param[in] window A function to create a window tensor that is applied/multiplied to each frame/window. - /// Default: WindowType::kHann. - /// \param[in] normalized Whether the spectrogram was normalized by magnitude after stft. Default:false. - /// \param[in] center Whether the signal in spectrogram was padded on both sides. Default: true. - /// \param[in] pad_mode Controls the padding method used when center is True. Default: BorderType::kReflect. - /// \param[in] onesided Controls whether spectrogram was used to return half of results to avoid - /// redundancy. Default: true. - explicit InverseSpectrogram(int32_t length = 0, int32_t n_fft = 400, int32_t win_length = 0, int32_t hop_length = 0, - int32_t pad = 0, WindowType window = WindowType::kHann, bool normalized = false, - bool center = true, BorderType pad_mode = BorderType::kReflect, bool onesided = true); - - /// \brief Destructor. - ~InverseSpectrogram() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief Create LFCC for a raw audio signal. -class DATASET_API LFCC final : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] sample_rate Sample rate of audio signal. Default: 16000. - /// \param[in] n_filter Number of linear filters to apply. Default: 128. - /// \param[in] n_lfcc Number of lfc coefficients to retain. Default: 40. - /// \param[in] f_min Minimum frequency. Default: 0.0. - /// \param[in] f_max Maximum frequency. Default: 0.0, will be set to sample_rate // 2. - /// \param[in] dct_type Type of DCT (discrete cosine transform) to use. Default: 2. - /// \param[in] norm Norm to use. Default: NormMode::kOrtho. - /// \param[in] log_lf Whether to use log-lf spectrograms instead of db-scaled. Default: false. - /// \param[in] n_fft Size of FFT, creates n_fft // 2 + 1 bins. Default: 400. - /// \param[in] win_length Window size. Default: 0, will be set to n_fft. - /// \param[in] hop_length Length of hop between STFT windows. Default: 0, will be set to win_length // 2. - /// \param[in] pad Two sided padding of signal. Default: 0. - /// \param[in] window A function to create a window tensor that is applied/multiplied to - /// each frame/window. Default: WindowType::kHann. - /// \param[in] power Exponent for the magnitude spectrogram, (must be > 0) e.g., 1 for energy, 2 - /// for power, etc. Default: 2.0. - /// \param[in] normalized Whether to normalize by magnitude after stft. Default: false - /// \param[in] center Whether to pad waveform on both sides so that the tt-th frame is centered at - /// time t t*hop_length. Default: true. - /// \param[in] pad_mode Controls the padding method used when center is True. Default: - /// BorderType::kReflect. - /// \param[in] onesided Controls whether to return half of results to avoid - /// redundancy. Default: true. - explicit LFCC(int32_t sample_rate = 16000, int32_t n_filter = 128, int32_t n_lfcc = 40, float f_min = 0.0, - float f_max = 0.0, int32_t dct_type = 2, NormMode norm = NormMode::kOrtho, bool log_lf = false, - int32_t n_fft = 400, int32_t win_length = 0, int32_t hop_length = 0, int32_t pad = 0, - WindowType window = WindowType::kHann, float power = 2.0, bool normalized = false, bool center = true, - BorderType pad_mode = BorderType::kReflect, bool onesided = true); - - /// \brief Destructor. - ~LFCC() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief Design filter. Similar to SoX implementation. -class DATASET_API LFilter final : public TensorTransform { - public: - /// \param[in] a_coeffs Numerator coefficients of difference equation of dimension of (n_order + 1). - /// Lower delays coefficients are first, e.g. [a0, a1, a2, ...]. - /// Must be same size as b_coeffs (pad with 0's as necessary). - /// \param[in] b_coeffs Numerator coefficients of difference equation of dimension of (n_order + 1). - /// Lower delays coefficients are first, e.g. [b0, b1, b2, ...]. - /// Must be same size as a_coeffs (pad with 0's as necessary). - /// \param[in] clamp If True, clamp the output signal to be in the range [-1, 1]. Default: True. - explicit LFilter(const std::vector &a_coeffs, const std::vector &b_coeffs, bool clamp = true); - - /// \brief Destructor. - ~LFilter() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief Creates a linear triangular filterbank. -/// \param output Tensor of a linear triangular filterbank. -/// \param n_freqs: Number of frequency. -/// \param f_min: Minimum of frequency in Hz. -/// \param f_max: Maximum of frequency in Hz. -/// \param n_filter: Number of (linear) triangular filter. -/// \param sample_rate: Sample rate. -/// \return Status code. -Status DATASET_API LinearFbanks(MSTensor *output, int32_t n_freqs, float f_min, float f_max, int32_t n_filter, - int32_t sample_rate); - -/// \brief Design biquad lowpass filter and perform filtering. Similar to SoX implementation. -class DATASET_API LowpassBiquad final : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero. - /// \param[in] cutoff_freq Filter cutoff frequency. - /// \param[in] Q Quality factor, https://en.wikipedia.org/wiki/Q_factor, range: (0, 1]. Default: 0.707. - LowpassBiquad(int32_t sample_rate, float cutoff_freq, float Q = 0.707); - - /// \brief Destructor. - ~LowpassBiquad() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief Separate a complex-valued spectrogram with shape (..., 2) into its magnitude and phase. -class DATASET_API Magphase final : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] power Power of the norm, which must be non-negative. Default: 1.0. - explicit Magphase(float power); - - /// \brief Destructor. - ~Magphase() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief MaskAlongAxis TensorTransform. -/// \note Tensor operation to mask the input tensor along axis. -class MaskAlongAxis final : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] mask_start Starting position of the mask, which must be non negative. - /// \param[in] mask_width The width of the mask, which must be positive. - /// \param[in] mask_value Value to assign to the masked columns. - /// \param[in] axis Axis to apply masking on (1 for frequency and 2 for time). - MaskAlongAxis(int32_t mask_start, int32_t mask_width, float mask_value, int32_t axis); - - /// \brief Destructor. - ~MaskAlongAxis() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief MaskAlongAxisIID TensorTransform. -/// \note Apply a mask along axis. -class MaskAlongAxisIID final : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] mask_param Number of columns to be masked, will be uniformly sampled from [0, mask_param], - /// must be non negative. - /// \param[in] mask_value Value to assign to the masked columns. - /// \param[in] axis Axis to apply masking on (1 for frequency and 2 for time). - MaskAlongAxisIID(int32_t mask_param, float mask_value, int32_t axis); - - /// \brief Destructor. - ~MaskAlongAxisIID() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief MelScale TensorTransform. -/// \notes Convert normal STFT to STFT at the Mel scale. -class DATASET_API MelScale final : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] n_mels Number of mel filter, which must be positive. Default: 128. - /// \param[in] sample_rate Sample rate of the signal, the value can't be zero. Default: 16000. - /// \param[in] f_min Minimum frequency, which must be non negative. Default: 0.0. - /// \param[in] f_max Maximum frequency, which must be positive. Default: 0.0, will be set to sample_rate / 2. - /// \param[in] n_stft Number of bins in STFT, which must be positive. Default: 201. - /// \param[in] norm Type of norm, value should be NormType::kSlaney or NormType::kNone. If norm is NormType::kSlaney, - /// divide the triangle mel weight by the width of the mel band. Default: NormType::kNone. - /// \param[in] mel_type Type of mel, value should be MelType::kHtk or MelType::kSlaney. Default: MelType::kHtk. - explicit MelScale(int32_t n_mels = 128, int32_t sample_rate = 16000, float f_min = 0.0, float f_max = 0.0, - int32_t n_stft = 201, NormType norm = NormType::kNone, MelType mel_type = MelType::kHtk); - - /// \brief Destructor. - ~MelScale() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief Create a frequency transformation matrix with shape (n_freqs, n_mels). -/// \param[in] output Tensor of the frequency transformation matrix. -/// \param[in] n_freqs Number of frequencies to highlight/apply. -/// \param[in] f_min Minimum frequency (Hz). -/// \param[in] f_max Maximum frequency (Hz). -/// \param[in] n_mels Number of mel filterbanks. -/// \param[in] sample_rate Sample rate of the audio waveform. -/// \param[in] norm Norm to use, can be NormType::kNone or NormType::kSlaney. Default: NormType::kNone. -/// \param[in] mel_type Scale to use, can be MelType::kHtk or MelType::kSlaney. Default: MelType::kHtz. -/// \return Status code. -Status DATASET_API MelscaleFbanks(MSTensor *output, int32_t n_freqs, float f_min, float f_max, int32_t n_mels, - int32_t sample_rate, NormType norm = NormType::kNone, - MelType mel_type = MelType::kHtk); - -/// \brief Create MelSpectrogram for a raw audio signal. -class DATASET_API MelSpectrogram final : public TensorTransform { - public: - /// \param[in] sample_rate Sample rate of audio signal. Default: 16000. - /// \param[in] n_fft Size of FFT, creates `n_fft // 2 + 1` bins. Default: 400. - /// \param[in] win_length Window size. Default: 0, will be set to `n_fft` . - /// \param[in] hop_length Length of hop between STFT windows. Default: 0, will be set to `win_length // 2` . - /// \param[in] f_min Minimum frequency. Default: 0.0. - /// \param[in] f_max Maximum frequency. Default: 0.0. - /// \param[in] pad Two sided padding of signal. Default: 0. - /// \param[in] n_mels Number of mel filterbanks. Default: 128. - /// \param[in] window A function to create a window tensor that is applied/multiplied to each frame/window. - /// Default: WindowType::kHann. - /// \param[in] power Exponent for the magnitude spectrogram, (must be > 0) e.g., 1 for energy, 2 for power, etc. - /// Default: 2.0. - /// \param[in] normalized Whether to normalize by magnitude after stft Default: false. - /// \param[in] center Whether to pad waveform on both sides. Default: true. - /// \param[in] pad_mode Controls the padding method used when center is True. Default: BorderType::kReflect. - /// \param[in] onesided Controls whether to return half of results to avoid redundancy. Default: true. - /// \param[in] norm If 'slaney', divide the triangular mel weights by the width of the mel band (area normalization). - /// Default: NormType::kNone. - /// \param[in] mel_scale Scale to use: htk or slaney. Default: MelType::kHtk. - explicit MelSpectrogram(int32_t sample_rate = 16000, int32_t n_fft = 400, int32_t win_length = 0, - int32_t hop_length = 0, float f_min = 0.0, float f_max = 0.0, int32_t pad = 0, - int32_t n_mels = 128, WindowType window = WindowType::kHann, float power = 2.0, - bool normalized = false, bool center = true, BorderType pad_mode = BorderType::kReflect, - bool onesided = true, NormType norm = NormType::kNone, MelType mel_scale = MelType::kHtk); - - /// \brief Destructor. - ~MelSpectrogram() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief Create MFCC for a raw audio signal. -class DATASET_API MFCC final : public TensorTransform { - public: - /// \param[in] sample_rate Sample rate of audio signal. Default: 16000. - /// \param[in] n_mfcc Number of mfc coefficients to retain. Default: 40. - /// \param[in] dct_type Type of DCT (discrete cosine transform) to use. Default: 2. - /// \param[in] norm If 'slaney', divide the triangular mel weights by the width of the mel band (area normalization). - /// Default: NormMode::kOrtho. - /// \param[in] log_mels Whether to use log-mel spectrograms instead of db-scaled. Default: false. - /// \param[in] n_fft Size of FFT, creates n_fft // 2 + 1 bins. Default: 400. - /// \param[in] win_length Window size. Default: 0. - /// \param[in] hop_length Length of hop between STFT windows. Default: 0. - /// \param[in] f_min Minimum frequency. Default: 0.0. - /// \param[in] f_max Maximum frequency. Default: 0.0. - /// \param[in] pad Two sided padding of signal. Default: 0. - /// \param[in] n_mels Number of mel filterbanks. Default: 128. - /// \param[in] window A function to create a window tensor that is applied/multiplied to each frame/window. - /// Default: WindowType::kHann. - /// \param[in] power Exponent for the magnitude spectrogram, (must be > 0) e.g., 1 for energy, 2 for power, etc. - /// Default: 2.0. - /// \param[in] normalized Whether to normalize by magnitude after stft. Default: false. - /// \param[in] center Whether to pad waveform on both sides. Default: true. - /// \param[in] pad_mode Controls the padding method used when center is True. Default: BorderType::kReflect. - /// \param[in] onesided Controls whether to return half of results to avoid redundancy. Default: true. - /// \param[in] norm_mel Norm to use. Default: NormType::kNone. - /// \param[in] mel_scale Scale to use: htk or slaney. Default: MelType::kHtk. - explicit MFCC(int32_t sample_rate = 16000, int32_t n_mfcc = 40, int32_t dct_type = 2, - NormMode norm = NormMode::kOrtho, bool log_mels = false, int32_t n_fft = 400, int32_t win_length = 0, - int32_t hop_length = 0, float f_min = 0.0, float f_max = 0.0, int32_t pad = 0, int32_t n_mels = 128, - WindowType window = WindowType::kHann, float power = 2.0, bool normalized = false, bool center = true, - BorderType pad_mode = BorderType::kReflect, bool onesided = true, NormType norm_mel = NormType::kNone, - MelType mel_scale = MelType::kHtk); - - /// \brief Destructor. - ~MFCC() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief MuLawDecoding TensorTransform. -/// \note Decode mu-law encoded signal. -class DATASET_API MuLawDecoding final : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] quantization_channels Number of channels, which must be positive. Default: 256. - explicit MuLawDecoding(int32_t quantization_channels = 256); - - /// \brief Destructor. - ~MuLawDecoding() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief MuLawEncoding TensorTransform. -/// \note Encode signal based on mu-law companding. -class DATASET_API MuLawEncoding final : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] quantization_channels Number of channels, which must be positive. Default: 256. - explicit MuLawEncoding(int32_t quantization_channels = 256); - - /// \brief Destructor. - ~MuLawEncoding() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief Overdrive TensorTransform. -class DATASET_API Overdrive final : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] gain Coefficient of overload in dB, in range of [0, 100]. Default: 20.0. - /// \param[in] color Coefficient of translation, in range of [0, 100]. Default: 20.0. - explicit Overdrive(float gain = 20.0, float color = 20.0); - - /// \brief Destructor. - ~Overdrive() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief Phaser TensorTransform. -class DATASET_API Phaser final : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz). - /// \param[in] gain_in Desired input gain at the boost (or attenuation) in dB. - /// Allowed range of values is [0, 1]. Default: 0.4. - /// \param[in] gain_out Desired output gain at the boost (or attenuation) in dB. - /// Allowed range of values is [0, 1e9]. Default: 0.74. - /// \param[in] delay_ms Desired delay in milli seconds. Allowed range of values is [0, 5]. Default: 3.0. - /// \param[in] decay Desired decay relative to gain-in. Allowed range of values is [0, 0.99]. Default: 0.4. - /// \param[in] mod_speed Modulation speed in Hz. Allowed range of values is [0.1, 2]. Default: 0.5. - /// \param[in] sinusoidal If true, use sinusoidal modulation (preferable for multiple instruments). - /// If false, use triangular modulation (gives single instruments a sharper phasing effect). Default: true. - explicit Phaser(int32_t sample_rate, float gain_in = 0.4, float gain_out = 0.74, float delay_ms = 3.0, - float decay = 0.4, float mod_speed = 0.5, bool sinusoidal = true); - - /// \brief Destructor. - ~Phaser() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief PhaseVocoder TensorTransform -/// \notes Given a STFT tensor, speed up in time without modifying pitch by factor of rate. -class DATASET_API PhaseVocoder final : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] rate Speed-up factor. - /// \param[in] phase_advance Expected phase advance in each bin in shape of (freq, 1). - PhaseVocoder(float rate, const MSTensor &phase_advance); - - /// \brief Destructor. - ~PhaseVocoder() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -// \brief Shift the pitch of a waveform by 'n_steps' steps. -class DATASET_API PitchShift final : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] sample_rate Sampling rate of "waveform". Default: 0. - /// \param[in] n_steps The (fractional) steps to shift "waveform". Default: 0. - /// \param[in] bins_per_octave The number of steps per octave. Default: 12. - /// \param[in] n_fft Size of FFT, creates "n_fft // 2 + 1" bins. Default: 512. - /// \param[in] win_length Window size. Default: 0, will be set to `n_fft` . - /// \param[in] hop_length Length of hop between STFT windows. Default: 0, will be set to `win_length // 4` . - /// \param[in] window Window tensor that is applied/multiplied to each frame/window. Default: WindowType::kHann. - explicit PitchShift(int32_t sample_rate = 0, int32_t n_steps = 0, int32_t bins_per_octave = 12, int32_t n_fft = 512, - int32_t win_length = 0, int32_t hop_length = 0, WindowType window = WindowType::kHann); - - /// \brief Destructor. - ~PitchShift() override = default; - - protected: - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief Resample TensorTransform. -/// \notes Resample a signal from one frequency to another. A sampling method can be given. -class DATASET_API Resample : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] orig_freq The original frequency of the signal, which must be positive. Default: 16000.0. - /// \param[in] new_freq The desired frequency, which must be positive. Default: 16000.0. - /// \param[in] resample_method The resampling method, which can be ResampleMethod::kSincInterpolation - /// and ResampleMethod::kKaiserWindow. Default: ResampleMethod::kSincInterpolation. - /// \param[in] lowpass_filter_width Controls the sharpness of the filter, more means sharper but less efficient, - /// which must be positive. Default: 6. - /// \param[in] rolloff The roll-off frequency of the filter, as a fraction of the Nyquist. Lower values - /// reduce anti-aliasing, but also reduce some of the highest frequencies, range: (0, 1]. Default: 0.99. - /// \param[in] beta The shape parameter used for kaiser window. Default: 14.769656459379492. - explicit Resample(float orig_freq = 16000.0, float new_freq = 16000.0, - ResampleMethod resample_method = ResampleMethod::kSincInterpolation, - int32_t lowpass_filter_width = 6, float rolloff = 0.99, float beta = 14.769656459379492); - - /// \brief Destructor. - ~Resample() override = default; - - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief Apply RIAA vinyl playback equalization. -class DATASET_API RiaaBiquad final : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz), - /// can only be one of 44100, 48000, 88200, 96000. - explicit RiaaBiquad(int32_t sample_rate); - - /// \brief Destructor. - ~RiaaBiquad() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief Apply sliding-window cepstral mean (and optionally variance) normalization per utterance. -class DATASET_API SlidingWindowCmn final : public TensorTransform { - public: - /// \brief Constructor of SlidingWindowCmnOp. - /// \param[in] cmn_window The window in frames for running average CMN computation. Default: 600. - /// \param[in] min_cmn_window The minimum CMN window. Only applicable if center is false, ignored if center - /// is true. Default: 100. - /// \param[in] center If true, use a window centered on the current frame. If false, window is to the left. - /// Default: false. - /// \param[in] norm_vars If true, normalize variance to one. Default: false. - explicit SlidingWindowCmn(int32_t cmn_window = 600, int32_t min_cmn_window = 100, bool center = false, - bool norm_vars = false); - - /// \brief Destructor. - ~SlidingWindowCmn() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief Create a spectral centroid from an audio signal. -class DATASET_API SpectralCentroid : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz). - /// \param[in] n_fft Size of FFT, creates n_fft / 2 + 1 bins. Default: 400. - /// \param[in] win_length Window size. Default: 0, will use n_fft. - /// \param[in] hop_length Length of hop between STFT windows. Default: 0, will use win_length / 2. - /// \param[in] pad Two sided padding of signal. Default: 0. - /// \param[in] window Window function that is applied/multiplied to each frame/window, - /// which can be WindowType::kBartlett, WindowType::kBlackman, WindowType::kHamming, - /// WindowType::kHann or WindowType::kKaiser. Default: WindowType::kHann. - explicit SpectralCentroid(int32_t sample_rate, int32_t n_fft = 400, int32_t win_length = 0, int32_t hop_length = 0, - int32_t pad = 0, WindowType window = WindowType::kHann); - - ~SpectralCentroid() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - int32_t sample_rate_; - int32_t n_fft_; - int32_t win_length_; - int32_t hop_length_; - int32_t pad_; - WindowType window_; - struct Data; - std::shared_ptr data_; -}; - -/// \brief Create a spectrogram from an audio signal. -class DATASET_API Spectrogram : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] n_fft Size of FFT, creates n_fft / 2 + 1 bins. Default: 400. - /// \param[in] win_length Window size. Default: 0, will use n_fft. - /// \param[in] hop_length Length of hop between STFT windows. Default: 0, will use win_length / 2. - /// \param[in] pad Two sided padding of signal. Default: 0. - /// \param[in] window Window function that is applied/multiplied to each frame/window, - /// which can be WindowType::kBartlett, WindowType::kBlackman, WindowType::kHamming, - /// WindowType::kHann or WindowType::kKaiser. Default: WindowType::kHann. - /// \param[in] power Exponent for the magnitude spectrogram, which must be greater than or equal to 0. Default: 2.0. - /// \param[in] normalized Whether to normalize by magnitude after stft. Default: false. - /// \param[in] center Whether to pad waveform on both sides. Default: true. - /// \param[in] pad_mode Controls the padding method used when center is true, - /// which can be BorderType::kReflect, BorderType::kConstant, BorderType::kEdge, - /// BorderType::kSymmetric. Default: BorderType::kReflect. - /// \param[in] onesided Controls whether to return half of results to avoid redundancy. Default: true. - explicit Spectrogram(int32_t n_fft = 400, int32_t win_length = 0, int32_t hop_length = 0, int32_t pad = 0, - WindowType window = WindowType::kHann, float power = 2.0, bool normalized = false, - bool center = true, BorderType pad_mode = BorderType::kReflect, bool onesided = true); - - /// \brief Destructor. - ~Spectrogram() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - int32_t n_fft_; - int32_t win_length_; - int32_t hop_length_; - int32_t pad_; - WindowType window_; - float power_; - bool normalized_; - bool center_; - BorderType pad_mode_; - bool onesided_; - struct Data; - std::shared_ptr data_; -}; - -/// \brief TimeMasking TensorTransform. -/// \notes Apply masking to a spectrogram in the time domain. -class DATASET_API TimeMasking final : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] iid_masks Whether to apply different masks to each example. - /// \param[in] time_mask_param Maximum possible length of the mask, range: [0, time_length]. Default: 0. - /// Indices uniformly sampled from [0, time_mask_param]. - /// Mask width when iid_masks=true. - /// \param[in] mask_start Mask start when iid_masks=true, range: [0, time_length-time_mask_param]. Default: 0. - /// \param[in] mask_value Mask value. - explicit TimeMasking(bool iid_masks = false, int32_t time_mask_param = 0, int32_t mask_start = 0, - float mask_value = 0.0); - - /// \brief Destructor. - ~TimeMasking() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief TimeStretch TensorTransform -/// \notes Stretch STFT in time at a given rate, without changing the pitch. -class DATASET_API TimeStretch final : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] hop_length Length of hop between STFT windows. Default: None, will use ((n_freq - 1) * 2) // 2. - /// \param[in] n_freq Number of filter banks form STFT. Default: 201. - /// \param[in] fixed_rate Rate to speed up or slow down the input in time. - /// Default: std::numeric_limits::quiet_NaN(), will keep the original rate. - explicit TimeStretch(float hop_length = std::numeric_limits::quiet_NaN(), int n_freq = 201, - float fixed_rate = std::numeric_limits::quiet_NaN()); - - /// \brief Destructor. - ~TimeStretch() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief Design a treble tone-control effect. -class DATASET_API TrebleBiquad final : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero. - /// \param[in] gain Desired gain at the boost (or attenuation) in dB. - /// \param[in] central_freq Central frequency (in Hz). Default: 3000.0. - /// \param[in] Q Quality factor, https://en.wikipedia.org/wiki/Q_factor, range: (0, 1]. Default: 0.707. - TrebleBiquad(int32_t sample_rate, float gain, float central_freq = 3000.0, float Q = 0.707); - - /// \brief Destructor. - ~TrebleBiquad() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief Vad TensorTransform. -/// \notes Attempt to trim silent background sounds from the end of the voice recording. -class DATASET_API Vad final : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] sample_rate Sample rate of audio signal. - /// \param[in] trigger_level The measurement level used to trigger activity detection. Default: 7.0. - /// \param[in] trigger_time The time constant (in seconds) used to help ignore short sounds. Default: 0.25. - /// \param[in] search_time The amount of audio (in seconds) to search for quieter/shorter sounds to include prior to - /// the detected trigger point. Default: 1.0. - /// \param[in] allowed_gap The allowed gap (in seconds) between quiteter/shorter sounds to include prior to the - /// detected trigger point. Default: 0.25. - /// \param[in] pre_trigger_time The amount of audio (in seconds) to preserve before the trigger point and any found - /// quieter/shorter bursts. Default: 0.0. - /// \param[in] boot_time The time for the initial noise estimate. Default: 0.35. - /// \param[in] noise_up_time Time constant used by the adaptive noise estimator, when the noise level is increasing. - /// Default: 0.1. - /// \param[in] noise_down_time Time constant used by the adaptive noise estimator, when the noise level is decreasing. - /// Default: 0.01. - /// \param[in] noise_reduction_amount The amount of noise reduction used in the detection algorithm. Default: 1.35. - /// \param[in] measure_freq The frequency of the algorithm’s processing. Default: 20.0. - /// \param[in] measure_duration The duration of measurement. Default: 0, use twice the measurement period. - /// \param[in] measure_smooth_time The time constant used to smooth spectral measurements. Default: 0.4. - /// \param[in] hp_filter_freq The "Brick-wall" frequency of high-pass filter applied at the input to the detector - /// algorithm. Default: 50.0. - /// \param[in] lp_filter_freq The "Brick-wall" frequency of low-pass filter applied at the input to the detector - /// algorithm. Default: 6000.0. - /// \param[in] hp_lifter_freq The "Brick-wall" frequency of high-pass lifter applied at the input to the detector - /// algorithm. Default: 150.0. - /// \param[in] lp_lifter_freq The "Brick-wall" frequency of low-pass lifter applied at the input to the detector - /// algorithm. Default: 2000.0. - explicit Vad(int32_t sample_rate, float trigger_level = 7.0, float trigger_time = 0.25, float search_time = 1.0, - float allowed_gap = 0.25, float pre_trigger_time = 0.0, float boot_time = 0.35, - float noise_up_time = 0.1, float noise_down_time = 0.01, float noise_reduction_amount = 1.35, - float measure_freq = 20.0, float measure_duration = 0.0, float measure_smooth_time = 0.4, - float hp_filter_freq = 50.0, float lp_filter_freq = 6000.0, float hp_lifter_freq = 150.0, - float lp_lifter_freq = 2000.0); - - /// \brief Destructor. - ~Vad() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; - -/// \brief Vol TensorTransform. -/// \notes Add a volume to an waveform. -class DATASET_API Vol final : public TensorTransform { - public: - /// \brief Constructor. - /// \param[in] gain Gain value, varies according to the value of gain_type. If gain_type is GainType::kAmplitude, - /// gain must be greater than or equal to zero. If gain_type is GainType::kPower, gain must be greater than zero. - /// If gain_type is GainType::kDb, there is no limit for gain. - /// \param[in] gain_type Type of gain, should be one of [GainType::kAmplitude, GainType::kDb, GainType::kPower]. - explicit Vol(float gain, GainType gain_type = GainType::kAmplitude); - - /// \brief Destructor. - ~Vol() override = default; - - protected: - /// \brief Function to convert TensorTransform object into a TensorOperation object. - /// \return Shared pointer to TensorOperation object. - std::shared_ptr Parse() override; - - private: - struct Data; - std::shared_ptr data_; -}; -} // namespace audio -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_AUDIO_H_ +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_AUDIO_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_AUDIO_H_ + +#include +#include +#include +#include +#include +#include + +#include "include/api/dual_abi_helper.h" +#include "include/api/status.h" +#include "include/api/types.h" +#include "include/dataset/constants.h" +#include "include/dataset/transforms.h" + +namespace mindspore { +namespace dataset { +class TensorOperation; + +// Transform operations for performing computer audio. +namespace audio { +/// \brief Compute the angle of complex tensor input. +class DATASET_API Angle final : public TensorTransform { + public: + /// \brief Constructor. + Angle(); + + /// \brief Destructor. + ~Angle() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; +}; + +/// \brief Design two-pole allpass filter. Similar to SoX implementation. +class DATASET_API AllpassBiquad final : public TensorTransform { + public: + /// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero. + /// \param[in] central_freq Central frequency (in Hz). + /// \param[in] Q Quality factor, https://en.wikipedia.org/wiki/Q_factor, range: (0, 1]. Default: 0.707. + explicit AllpassBiquad(int32_t sample_rate, float central_freq, float Q = 0.707); + + /// \brief Destructor. + ~AllpassBiquad() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief AmplitudeToDB TensorTransform. +/// \notes Turn a tensor from the power/amplitude scale to the decibel scale. +class DATASET_API AmplitudeToDB final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] stype Scale of input tensor, must be one of [ScaleType::kPower, ScaleType::kMagnitude]. + /// Default: ScaleType::kPower. + /// \param[in] ref_value Calculate db_multiplier. Default: 1.0. + /// \param[in] amin Minimum threshold for input tensor and ref_value. It must be greater than zero. Default: 1e-10. + /// \param[in] top_db Decibels cut-off value. It must be greater than or equal to zero. Default: 80.0. + explicit AmplitudeToDB(ScaleType stype = ScaleType::kPower, float ref_value = 1.0, float amin = 1e-10, + float top_db = 80.0); + + /// \brief Destructor. + ~AmplitudeToDB() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief Design two-pole band filter. +class DATASET_API BandBiquad final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero. + /// \param[in] central_freq Central frequency (in Hz). + /// \param[in] Q Quality factor, https://en.wikipedia.org/wiki/Q_factor, range: (0, 1]. Default: 0.707. + /// \param[in] noise Choose alternate mode for un-pitched audio or mode oriented to pitched audio. Default: False. + explicit BandBiquad(int32_t sample_rate, float central_freq, float Q = 0.707, bool noise = false); + + /// \brief Destructor. + ~BandBiquad() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief Design two-pole band-pass filter. +class DATASET_API BandpassBiquad final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero. + /// \param[in] central_freq Central frequency (in Hz). + /// \param[in] Q Quality factor, https://en.wikipedia.org/wiki/Q_factor, range: (0, 1]. Default: 0.707. + /// \param[in] const_skirt_gain, If True, uses a constant skirt gain (peak gain = Q). If False, uses a + /// constant 0dB peak gain. Default: False. + explicit BandpassBiquad(int32_t sample_rate, float central_freq, float Q = 0.707, bool const_skirt_gain = false); + + /// \brief Destructor. + ~BandpassBiquad() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief Design two-pole band-reject filter. Similar to SoX implementation. +class DATASET_API BandrejectBiquad final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero. + /// \param[in] central_freq Central frequency (in Hz). + /// \param[in] Q Quality factor, https://en.wikipedia.org/wiki/Q_factor, range: (0, 1]. Default: 0.707. + explicit BandrejectBiquad(int32_t sample_rate, float central_freq, float Q = 0.707); + + /// \brief Destructor. + ~BandrejectBiquad() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief Design a bass tone-control effect. +class DATASET_API BassBiquad final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero. + /// \param[in] gain Desired gain at the boost (or attenuation) in dB. + /// \param[in] central_freq Central frequency (in Hz). Default: 100.0. + /// \param[in] Q Quality factor, https://en.wikipedia.org/wiki/Q_factor, range: (0, 1]. Default: 0.707. + explicit BassBiquad(int32_t sample_rate, float gain, float central_freq = 100.0, float Q = 0.707); + + /// \brief Destructor. + ~BassBiquad() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief Perform a biquad filter of input tensor. +class DATASET_API Biquad final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] b0 Numerator coefficient of current input, x[n]. + /// \param[in] b1 Numerator coefficient of input one time step ago x[n-1]. + /// \param[in] b2 Numerator coefficient of input two time steps ago x[n-2]. + /// \param[in] a0 Denominator coefficient of current output y[n], the value can't be zero, typically 1. + /// \param[in] a1 Denominator coefficient of current output y[n-1]. + /// \param[in] a2 Denominator coefficient of current output y[n-2]. + explicit Biquad(float b0, float b1, float b2, float a0, float a1, float a2); + + /// \brief Destructor. + ~Biquad() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief ComplexNorm TensorTransform. +/// \notes Compute the norm of complex tensor input. +class DATASET_API ComplexNorm final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] power Power of the norm, which must be non-negative. Default: 1.0. + explicit ComplexNorm(float power = 1.0); + + /// \brief Destructor. + ~ComplexNorm() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief ComputeDeltas Transform. +/// \note Compute delta coefficients of a spectrogram. +class DATASET_API ComputeDeltas final : public TensorTransform { + public: + /// \brief Construct a new Compute Deltas object. + /// \f[ + /// d_{t}=\frac{{\textstyle\sum_{n=1}^{N}}n(c_{t+n}-c_{t-n})}{2{\textstyle\sum_{n=1}^{N}}n^{2}} + /// \f] + /// \param[in] win_length The window length used for computing delta, must be no less than 3. Default: 5. + /// \param[in] pad_mode Padding mode. Can be one of BorderType::kConstant, BorderType::kEdge, + /// BorderType::kReflect or BorderType::kSymmetric. Default: BorderType::kEdge. + explicit ComputeDeltas(int32_t win_length = 5, BorderType pad_mode = BorderType::kEdge); + + /// \brief Destructor. + ~ComputeDeltas() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief Apply contrast effect. +class DATASET_API Contrast final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] enhancement_amount Controls the amount of the enhancement. Default: 75.0. + explicit Contrast(float enhancement_amount = 75.0); + + /// \brief Destructor. + ~Contrast() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief Turn a waveform from the decibel scale to the power/amplitude scale. +class DATASET_API DBToAmplitude final : public TensorTransform { + public: + /// \brief Constructor + /// \param[in] ref Reference which the output will be scaled by. + /// \param[in] power If power equals 1, will compute DB to power. If 0.5, will compute DB to amplitude. + explicit DBToAmplitude(float ref, float power); + + /// \brief Destructor. + ~DBToAmplitude() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief Apply a DC shift to the audio. +class DATASET_API DCShift : public TensorTransform { + public: + /// \brief Constructor + /// \param[in] shift Indicates the amount to shift the audio, the value must be in the range [-2.0, 2.0]. + /// \param[in] limiter_gain Used only on peaks to prevent clipping. + DCShift(float shift, float limiter_gain); + + /// \brief Constructor + /// \param[in] shift Indicates the amount to shift the audio. + /// \note This constructor will use `shift` as `limiter_gain`. + explicit DCShift(float shift); + + /// \brief Destructor. + ~DCShift() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \param[in] n_mfcc Number of mfc coefficients to retain, the value must be greater than 0. +/// \param[in] n_mels Number of mel filterbanks, the value must be greater than 0. +/// \param[in] norm Norm to use, can be NormMode::kNone or NormMode::kOrtho. +/// \return Status error code, returns OK if no error encountered. +Status CreateDct(mindspore::MSTensor *output, int32_t n_mfcc, int32_t n_mels, NormMode norm = NormMode::kNone); + +/// \brief Design two-pole deemph filter. Similar to SoX implementation. +class DATASET_API DeemphBiquad final : public TensorTransform { + public: + /// \param[in] sample_rate Sampling rate of the waveform, the value can only be 44100 (Hz) or 48000(hz). + explicit DeemphBiquad(int32_t sample_rate); + + /// \brief Destructor. + ~DeemphBiquad() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief Detect pitch frequency. +class DATASET_API DetectPitchFrequency final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero. + /// \param[in] frame_time Duration of a frame, the value must be greater than zero. Default: 0.02. + /// \param[in] win_length The window length for median smoothing (in number of frames), the value must + /// be greater than zero. Default: 30. + /// \param[in] freq_low Lowest frequency that can be detected (Hz), the value must be greater than zero. Default: 85. + /// \param[in] freq_high Highest frequency that can be detected (Hz), the value must be greater than + /// zero. Default: 3400. + explicit DetectPitchFrequency(int32_t sample_rate, float frame_time = 0.01, int32_t win_length = 30, + int32_t freq_low = 85, int32_t freq_high = 3400); + + /// \brief Destructor. + ~DetectPitchFrequency() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief Dither increases the perceived dynamic range of audio stored at a +/// particular bit-depth by eliminating nonlinear truncation distortion. +class DATASET_API Dither final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] density_function The density function of a continuous random variable. + /// Can be one of DensityFunction::kTPDF (Triangular Probability Density Function), + /// DensityFunction::kRPDF (Rectangular Probability Density Function) or + /// DensityFunction::kGPDF (Gaussian Probability Density Function). Default: DensityFunction::kTPDF. + /// \param[in] noise_shaping A filtering process that shapes the spectral energy of + /// quantisation error. Default: false. + explicit Dither(DensityFunction density_function = DensityFunction::kTPDF, bool noise_shaping = false); + + /// \brief Destructor. + ~Dither() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief EqualizerBiquad TensorTransform. Apply highpass biquad filter on audio. +class DATASET_API EqualizerBiquad final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero. + /// \param[in] center_freq Filter's central frequency (in Hz). + /// \param[in] gain Desired gain at the boost (or attenuation) in dB. + /// \param[in] Q Quality factor, https://en.wikipedia.org/wiki/Q_factor, range: (0, 1]. Default: 0.707. + EqualizerBiquad(int32_t sample_rate, float center_freq, float gain, float Q = 0.707); + + /// \brief Destructor. + ~EqualizerBiquad() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief Add fade in or/and fade out on the input audio. +class DATASET_API Fade final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] fade_in_len Length of fade-in (time frames), which must be non-negative + /// and no more than the length of waveform. Default: 0. + /// \param[in] fade_out_len Length of fade-out (time frames), which must be non-negative + /// and no more than the length of waveform. Default: 0. + /// \param[in] fade_shape An enum for the fade shape. Default: FadeShape::kLinear. + explicit Fade(int32_t fade_in_len = 0, int32_t fade_out_len = 0, FadeShape fade_shape = FadeShape::kLinear); + + /// \brief Destructor. + ~Fade() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief Design IIR forward and backward filter. +class DATASET_API Filtfilt final : public TensorTransform { + public: + /// \param[in] a_coeffs Numerator coefficients of difference equation of dimension of (n_order + 1). + /// Lower delays coefficients are first, e.g. [a0, a1, a2, ...]. + /// Must be same size as b_coeffs (pad with 0's as necessary). + /// \param[in] b_coeffs Numerator coefficients of difference equation of dimension of (n_order + 1). + /// Lower delays coefficients are first, e.g. [b0, b1, b2, ...]. + /// Must be same size as a_coeffs (pad with 0's as necessary). + /// \param[in] clamp If True, clamp the output signal to be in the range [-1, 1]. Default: True. + Filtfilt(const std::vector &a_coeffs, const std::vector &b_coeffs, bool clamp = true); + + /// \brief Destructor. + ~Filtfilt() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief Apply a flanger effect to the audio. +class DATASET_API Flanger final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz). + /// \param[in] delay Desired delay in milliseconds (ms), range: [0, 30]. Default: 0.0. + /// \param[in] depth Desired delay depth in milliseconds (ms), range: [0, 10]. Default: 2.0. + /// \param[in] regen Desired regen (feedback gain) in dB., range: [-95, 95]. Default: 0.0. + /// \param[in] width Desired width (delay gain) in dB, range: [0, 100]. Default: 71.0. + /// \param[in] speed Modulation speed in Hz, range: [0.1, 10]. Default: 0.5. + /// \param[in] phase Percentage phase-shift for multi-channel, range: [0, 100]. Default: 25.0. + /// \param[in] modulation Modulation of input tensor, must be one of [Modulation::kSinusoidal, + /// Modulation::kTriangular]. Default:Modulation::kSinusoidal. + /// \param[in] interpolation Interpolation of input tensor, must be one of [Interpolation::kLinear, + /// Interpolation::kQuadratic]. Default:Interpolation::kLinear. + explicit Flanger(int32_t sample_rate, float delay = 0.0, float depth = 2.0, float regen = 0.0, float width = 71.0, + float speed = 0.5, float phase = 25.0, Modulation modulation = Modulation::kSinusoidal, + Interpolation interpolation = Interpolation::kLinear); + + /// \brief Destructor. + ~Flanger() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief FrequencyMasking TensorTransform. +/// \notes Apply masking to a spectrogram in the frequency domain. +class DATASET_API FrequencyMasking final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] iid_masks Whether to apply different masks to each example. + /// \param[in] frequency_mask_param Maximum possible length of the mask, range: [0, freq_length]. Default: 0. + /// Indices uniformly sampled from [0, frequency_mask_param]. + /// Mask width when iid_masks=true. + /// \param[in] mask_start Mask start when iid_masks=true, range: [0, freq_length-frequency_mask_param]. Default: 0. + /// \param[in] mask_value Mask value. + explicit FrequencyMasking(bool iid_masks = false, int32_t frequency_mask_param = 0, int32_t mask_start = 0, + float mask_value = 0.0); + + /// \brief Destructor. + ~FrequencyMasking() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief Apply amplification or attenuation to the whole waveform. +class DATASET_API Gain final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] gain_db Gain adjustment in decibels (dB). Default: 1.0. + explicit Gain(float gain_db = 1.0); + + /// \brief Destructor. + ~Gain() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief Waveform calculation from linear scalar amplitude spectrogram using GriffinLim transform. +class DATASET_API GriffinLim final : public TensorTransform { + public: + /// \brief Constructor. + /// \notes Calculated by formula: + /// x(n)=\frac{\sum_{m=-\infty}^{\infty} w(m S-n) y_{w}(m S, n)}{\sum_{m=-\infty}^{\infty} w^{2}(m S-n)} + /// where w represents the window function, y represents the reconstructed signal of each frame and x represents + /// the whole signal. + /// \param[in] n_fft Size of FFT. Default: 400. + /// \param[in] n_iter Number of iteration for phase recovery. Default: 32. + /// \param[in] win_length Window size for GriffinLim. Default: 0, will be set to n_fft. + /// \param[in] hop_length Length of hop between STFT windows. Default: 0, will be set to win_length / 2. + /// \param[in] window_type Window type for GriffinLim. Default: WindowType::kHann. + /// \param[in] power Exponent for the magnitude spectrogram. Default: 2.0. + /// \param[in] momentum The momentum for fast Griffin-Lim. Default: 0.99. + /// \param[in] length Length of the expected output waveform. Default: 0.0, will be set to the value of last + /// dimension of the stft matrix. + /// \param[in] rand_init Flag for random phase initialization or all-zero phase initialization. Default: true. + explicit GriffinLim(int32_t n_fft = 400, int32_t n_iter = 32, int32_t win_length = 0, int32_t hop_length = 0, + WindowType window_type = WindowType::kHann, float power = 2.0, float momentum = 0.99, + int32_t length = 0, bool rand_init = true); + + /// \brief Destructor. + ~GriffinLim() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief HighpassBiquad TensorTransform. Apply highpass biquad filter on audio. +class DATASET_API HighpassBiquad final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero. + /// \param[in] cutoff_freq Filter cutoff frequency (in Hz). + /// \param[in] Q Quality factor, https://en.wikipedia.org/wiki/Q_factor, range: (0, 1]. Default: 0.707. + HighpassBiquad(int32_t sample_rate, float cutoff_freq, float Q = 0.707); + + /// \brief Destructor. + ~HighpassBiquad() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief InverseMelScale TensorTransform +/// \notes Solve for a normal STFT from a mel frequency STFT, using a conversion matrix. +class DATASET_API InverseMelScale final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] n_stft Number of bins in STFT, must be positive. + /// \param[in] n_mels Number of mel filter, must be positive. Default: 128. + /// \param[in] sample_rate Sample rate of the signal, the value can't be zero. Default: 16000. + /// \param[in] f_min Minimum frequency, must be non-negative. Default: 0.0. + /// \param[in] f_max Maximum frequency, must be non-negative. Default: 0.0, will be set to sample_rate / 2. + /// \param[in] max_iter Maximum number of optimization iterations, must be positive. Default: 100000. + /// \param[in] tolerance_loss Value of loss to stop optimization at, must be non-negative. Default: 1e-5. + /// \param[in] tolerance_change Difference in losses to stop optimization at, must be non-negative. Default: 1e-8. + /// \param[in] sgdargs Parameters of SGD optimizer, including lr, momentum. + /// Default: {{"sgd_lr", 0.1}, {"sgd_momentum", 0.0}}. + /// \param[in] norm Type of norm, value should be NormType::kSlaney or NormType::kNone. If norm is NormType::kSlaney, + /// divide the triangle mel weight by the width of the mel band. Default: NormType::kNone. + /// \param[in] mel_type Type of mel, value should be MelType::kHtk or MelType::kSlaney. Default: MelType::kHtk. + explicit InverseMelScale(int32_t n_stft, int32_t n_mels = 128, int32_t sample_rate = 16000, float f_min = 0.0, + float f_max = 0.0, int32_t max_iter = 100000, float tolerance_loss = 1e-5, + float tolerance_change = 1e-8, + const std::map &sgdargs = {{"sgd_lr", 0.1}, {"sgd_momentum", 0.0}}, + NormType norm = NormType::kNone, MelType mel_type = MelType::kHtk); + + /// \brief Destructor. + ~InverseMelScale() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief Create an inverse spectrogram to recover an audio signal from a spectrogram. +class DATASET_API InverseSpectrogram final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] length The output length of the waveform. Default: 0, means to output the whole waveform. + /// \param[in] n_fft Size of FFT, creates n_fft // 2 + 1 bins. Default: 400. + /// \param[in] win_length Window size. Default: 0, will be set to `n_fft` . + /// \param[in] hop_length Length of hop between STFT windows. Default: 0, will be set to `win_length // 2` . + /// \param[in] pad Two sided padding of signal. Default: 0. + /// \param[in] window A function to create a window tensor that is applied/multiplied to each frame/window. + /// Default: WindowType::kHann. + /// \param[in] normalized Whether the spectrogram was normalized by magnitude after stft. Default:false. + /// \param[in] center Whether the signal in spectrogram was padded on both sides. Default: true. + /// \param[in] pad_mode Controls the padding method used when center is True. Default: BorderType::kReflect. + /// \param[in] onesided Controls whether spectrogram was used to return half of results to avoid + /// redundancy. Default: true. + explicit InverseSpectrogram(int32_t length = 0, int32_t n_fft = 400, int32_t win_length = 0, int32_t hop_length = 0, + int32_t pad = 0, WindowType window = WindowType::kHann, bool normalized = false, + bool center = true, BorderType pad_mode = BorderType::kReflect, bool onesided = true); + + /// \brief Destructor. + ~InverseSpectrogram() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief Create LFCC for a raw audio signal. +class DATASET_API LFCC final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] sample_rate Sample rate of audio signal. Default: 16000. + /// \param[in] n_filter Number of linear filters to apply. Default: 128. + /// \param[in] n_lfcc Number of lfc coefficients to retain. Default: 40. + /// \param[in] f_min Minimum frequency. Default: 0.0. + /// \param[in] f_max Maximum frequency. Default: 0.0, will be set to sample_rate // 2. + /// \param[in] dct_type Type of DCT (discrete cosine transform) to use. Default: 2. + /// \param[in] norm Norm to use. Default: NormMode::kOrtho. + /// \param[in] log_lf Whether to use log-lf spectrograms instead of db-scaled. Default: false. + /// \param[in] n_fft Size of FFT, creates n_fft // 2 + 1 bins. Default: 400. + /// \param[in] win_length Window size. Default: 0, will be set to n_fft. + /// \param[in] hop_length Length of hop between STFT windows. Default: 0, will be set to win_length // 2. + /// \param[in] pad Two sided padding of signal. Default: 0. + /// \param[in] window A function to create a window tensor that is applied/multiplied to + /// each frame/window. Default: WindowType::kHann. + /// \param[in] power Exponent for the magnitude spectrogram, (must be > 0) e.g., 1 for energy, 2 + /// for power, etc. Default: 2.0. + /// \param[in] normalized Whether to normalize by magnitude after stft. Default: false + /// \param[in] center Whether to pad waveform on both sides so that the tt-th frame is centered at + /// time t t*hop_length. Default: true. + /// \param[in] pad_mode Controls the padding method used when center is True. Default: + /// BorderType::kReflect. + /// \param[in] onesided Controls whether to return half of results to avoid + /// redundancy. Default: true. + explicit LFCC(int32_t sample_rate = 16000, int32_t n_filter = 128, int32_t n_lfcc = 40, float f_min = 0.0, + float f_max = 0.0, int32_t dct_type = 2, NormMode norm = NormMode::kOrtho, bool log_lf = false, + int32_t n_fft = 400, int32_t win_length = 0, int32_t hop_length = 0, int32_t pad = 0, + WindowType window = WindowType::kHann, float power = 2.0, bool normalized = false, bool center = true, + BorderType pad_mode = BorderType::kReflect, bool onesided = true); + + /// \brief Destructor. + ~LFCC() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief Design filter. Similar to SoX implementation. +class DATASET_API LFilter final : public TensorTransform { + public: + /// \param[in] a_coeffs Numerator coefficients of difference equation of dimension of (n_order + 1). + /// Lower delays coefficients are first, e.g. [a0, a1, a2, ...]. + /// Must be same size as b_coeffs (pad with 0's as necessary). + /// \param[in] b_coeffs Numerator coefficients of difference equation of dimension of (n_order + 1). + /// Lower delays coefficients are first, e.g. [b0, b1, b2, ...]. + /// Must be same size as a_coeffs (pad with 0's as necessary). + /// \param[in] clamp If True, clamp the output signal to be in the range [-1, 1]. Default: True. + explicit LFilter(const std::vector &a_coeffs, const std::vector &b_coeffs, bool clamp = true); + + /// \brief Destructor. + ~LFilter() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief Creates a linear triangular filterbank. +/// \param output Tensor of a linear triangular filterbank. +/// \param n_freqs: Number of frequency. +/// \param f_min: Minimum of frequency in Hz. +/// \param f_max: Maximum of frequency in Hz. +/// \param n_filter: Number of (linear) triangular filter. +/// \param sample_rate: Sample rate. +/// \return Status code. +Status DATASET_API LinearFbanks(MSTensor *output, int32_t n_freqs, float f_min, float f_max, int32_t n_filter, + int32_t sample_rate); + +/// \brief Design biquad lowpass filter and perform filtering. Similar to SoX implementation. +class DATASET_API LowpassBiquad final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero. + /// \param[in] cutoff_freq Filter cutoff frequency. + /// \param[in] Q Quality factor, https://en.wikipedia.org/wiki/Q_factor, range: (0, 1]. Default: 0.707. + LowpassBiquad(int32_t sample_rate, float cutoff_freq, float Q = 0.707); + + /// \brief Destructor. + ~LowpassBiquad() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief Separate a complex-valued spectrogram with shape (..., 2) into its magnitude and phase. +class DATASET_API Magphase final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] power Power of the norm, which must be non-negative. Default: 1.0. + explicit Magphase(float power); + + /// \brief Destructor. + ~Magphase() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief MaskAlongAxis TensorTransform. +/// \note Tensor operation to mask the input tensor along axis. +class MaskAlongAxis final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] mask_start Starting position of the mask, which must be non negative. + /// \param[in] mask_width The width of the mask, which must be positive. + /// \param[in] mask_value Value to assign to the masked columns. + /// \param[in] axis Axis to apply masking on (1 for frequency and 2 for time). + MaskAlongAxis(int32_t mask_start, int32_t mask_width, float mask_value, int32_t axis); + + /// \brief Destructor. + ~MaskAlongAxis() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief MaskAlongAxisIID TensorTransform. +/// \note Apply a mask along axis. +class MaskAlongAxisIID final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] mask_param Number of columns to be masked, will be uniformly sampled from [0, mask_param], + /// must be non negative. + /// \param[in] mask_value Value to assign to the masked columns. + /// \param[in] axis Axis to apply masking on (1 for frequency and 2 for time). + MaskAlongAxisIID(int32_t mask_param, float mask_value, int32_t axis); + + /// \brief Destructor. + ~MaskAlongAxisIID() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief MelScale TensorTransform. +/// \notes Convert normal STFT to STFT at the Mel scale. +class DATASET_API MelScale final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] n_mels Number of mel filter, which must be positive. Default: 128. + /// \param[in] sample_rate Sample rate of the signal, the value can't be zero. Default: 16000. + /// \param[in] f_min Minimum frequency, which must be non negative. Default: 0.0. + /// \param[in] f_max Maximum frequency, which must be positive. Default: 0.0, will be set to sample_rate / 2. + /// \param[in] n_stft Number of bins in STFT, which must be positive. Default: 201. + /// \param[in] norm Type of norm, value should be NormType::kSlaney or NormType::kNone. If norm is NormType::kSlaney, + /// divide the triangle mel weight by the width of the mel band. Default: NormType::kNone. + /// \param[in] mel_type Type of mel, value should be MelType::kHtk or MelType::kSlaney. Default: MelType::kHtk. + explicit MelScale(int32_t n_mels = 128, int32_t sample_rate = 16000, float f_min = 0.0, float f_max = 0.0, + int32_t n_stft = 201, NormType norm = NormType::kNone, MelType mel_type = MelType::kHtk); + + /// \brief Destructor. + ~MelScale() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief Create a frequency transformation matrix with shape (n_freqs, n_mels). +/// \param[in] output Tensor of the frequency transformation matrix. +/// \param[in] n_freqs Number of frequencies to highlight/apply. +/// \param[in] f_min Minimum frequency (Hz). +/// \param[in] f_max Maximum frequency (Hz). +/// \param[in] n_mels Number of mel filterbanks. +/// \param[in] sample_rate Sample rate of the audio waveform. +/// \param[in] norm Norm to use, can be NormType::kNone or NormType::kSlaney. Default: NormType::kNone. +/// \param[in] mel_type Scale to use, can be MelType::kHtk or MelType::kSlaney. Default: MelType::kHtz. +/// \return Status code. +Status DATASET_API MelscaleFbanks(MSTensor *output, int32_t n_freqs, float f_min, float f_max, int32_t n_mels, + int32_t sample_rate, NormType norm = NormType::kNone, + MelType mel_type = MelType::kHtk); + +/// \brief Create MelSpectrogram for a raw audio signal. +class DATASET_API MelSpectrogram final : public TensorTransform { + public: + /// \param[in] sample_rate Sample rate of audio signal. Default: 16000. + /// \param[in] n_fft Size of FFT, creates `n_fft // 2 + 1` bins. Default: 400. + /// \param[in] win_length Window size. Default: 0, will be set to `n_fft` . + /// \param[in] hop_length Length of hop between STFT windows. Default: 0, will be set to `win_length // 2` . + /// \param[in] f_min Minimum frequency. Default: 0.0. + /// \param[in] f_max Maximum frequency. Default: 0.0. + /// \param[in] pad Two sided padding of signal. Default: 0. + /// \param[in] n_mels Number of mel filterbanks. Default: 128. + /// \param[in] window A function to create a window tensor that is applied/multiplied to each frame/window. + /// Default: WindowType::kHann. + /// \param[in] power Exponent for the magnitude spectrogram, (must be > 0) e.g., 1 for energy, 2 for power, etc. + /// Default: 2.0. + /// \param[in] normalized Whether to normalize by magnitude after stft Default: false. + /// \param[in] center Whether to pad waveform on both sides. Default: true. + /// \param[in] pad_mode Controls the padding method used when center is True. Default: BorderType::kReflect. + /// \param[in] onesided Controls whether to return half of results to avoid redundancy. Default: true. + /// \param[in] norm If 'slaney', divide the triangular mel weights by the width of the mel band (area normalization). + /// Default: NormType::kNone. + /// \param[in] mel_scale Scale to use: htk or slaney. Default: MelType::kHtk. + explicit MelSpectrogram(int32_t sample_rate = 16000, int32_t n_fft = 400, int32_t win_length = 0, + int32_t hop_length = 0, float f_min = 0.0, float f_max = 0.0, int32_t pad = 0, + int32_t n_mels = 128, WindowType window = WindowType::kHann, float power = 2.0, + bool normalized = false, bool center = true, BorderType pad_mode = BorderType::kReflect, + bool onesided = true, NormType norm = NormType::kNone, MelType mel_scale = MelType::kHtk); + + /// \brief Destructor. + ~MelSpectrogram() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief Create MFCC for a raw audio signal. +class DATASET_API MFCC final : public TensorTransform { + public: + /// \param[in] sample_rate Sample rate of audio signal. Default: 16000. + /// \param[in] n_mfcc Number of mfc coefficients to retain. Default: 40. + /// \param[in] dct_type Type of DCT (discrete cosine transform) to use. Default: 2. + /// \param[in] norm If 'slaney', divide the triangular mel weights by the width of the mel band (area normalization). + /// Default: NormMode::kOrtho. + /// \param[in] log_mels Whether to use log-mel spectrograms instead of db-scaled. Default: false. + /// \param[in] n_fft Size of FFT, creates n_fft // 2 + 1 bins. Default: 400. + /// \param[in] win_length Window size. Default: 0. + /// \param[in] hop_length Length of hop between STFT windows. Default: 0. + /// \param[in] f_min Minimum frequency. Default: 0.0. + /// \param[in] f_max Maximum frequency. Default: 0.0. + /// \param[in] pad Two sided padding of signal. Default: 0. + /// \param[in] n_mels Number of mel filterbanks. Default: 128. + /// \param[in] window A function to create a window tensor that is applied/multiplied to each frame/window. + /// Default: WindowType::kHann. + /// \param[in] power Exponent for the magnitude spectrogram, (must be > 0) e.g., 1 for energy, 2 for power, etc. + /// Default: 2.0. + /// \param[in] normalized Whether to normalize by magnitude after stft. Default: false. + /// \param[in] center Whether to pad waveform on both sides. Default: true. + /// \param[in] pad_mode Controls the padding method used when center is True. Default: BorderType::kReflect. + /// \param[in] onesided Controls whether to return half of results to avoid redundancy. Default: true. + /// \param[in] norm_mel Norm to use. Default: NormType::kNone. + /// \param[in] mel_scale Scale to use: htk or slaney. Default: MelType::kHtk. + explicit MFCC(int32_t sample_rate = 16000, int32_t n_mfcc = 40, int32_t dct_type = 2, + NormMode norm = NormMode::kOrtho, bool log_mels = false, int32_t n_fft = 400, int32_t win_length = 0, + int32_t hop_length = 0, float f_min = 0.0, float f_max = 0.0, int32_t pad = 0, int32_t n_mels = 128, + WindowType window = WindowType::kHann, float power = 2.0, bool normalized = false, bool center = true, + BorderType pad_mode = BorderType::kReflect, bool onesided = true, NormType norm_mel = NormType::kNone, + MelType mel_scale = MelType::kHtk); + + /// \brief Destructor. + ~MFCC() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief MuLawDecoding TensorTransform. +/// \note Decode mu-law encoded signal. +class DATASET_API MuLawDecoding final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] quantization_channels Number of channels, which must be positive. Default: 256. + explicit MuLawDecoding(int32_t quantization_channels = 256); + + /// \brief Destructor. + ~MuLawDecoding() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief MuLawEncoding TensorTransform. +/// \note Encode signal based on mu-law companding. +class DATASET_API MuLawEncoding final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] quantization_channels Number of channels, which must be positive. Default: 256. + explicit MuLawEncoding(int32_t quantization_channels = 256); + + /// \brief Destructor. + ~MuLawEncoding() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief Overdrive TensorTransform. +class DATASET_API Overdrive final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] gain Coefficient of overload in dB, in range of [0, 100]. Default: 20.0. + /// \param[in] color Coefficient of translation, in range of [0, 100]. Default: 20.0. + explicit Overdrive(float gain = 20.0, float color = 20.0); + + /// \brief Destructor. + ~Overdrive() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief Phaser TensorTransform. +class DATASET_API Phaser final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz). + /// \param[in] gain_in Desired input gain at the boost (or attenuation) in dB. + /// Allowed range of values is [0, 1]. Default: 0.4. + /// \param[in] gain_out Desired output gain at the boost (or attenuation) in dB. + /// Allowed range of values is [0, 1e9]. Default: 0.74. + /// \param[in] delay_ms Desired delay in milli seconds. Allowed range of values is [0, 5]. Default: 3.0. + /// \param[in] decay Desired decay relative to gain-in. Allowed range of values is [0, 0.99]. Default: 0.4. + /// \param[in] mod_speed Modulation speed in Hz. Allowed range of values is [0.1, 2]. Default: 0.5. + /// \param[in] sinusoidal If true, use sinusoidal modulation (preferable for multiple instruments). + /// If false, use triangular modulation (gives single instruments a sharper phasing effect). Default: true. + explicit Phaser(int32_t sample_rate, float gain_in = 0.4, float gain_out = 0.74, float delay_ms = 3.0, + float decay = 0.4, float mod_speed = 0.5, bool sinusoidal = true); + + /// \brief Destructor. + ~Phaser() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief PhaseVocoder TensorTransform +/// \notes Given a STFT tensor, speed up in time without modifying pitch by factor of rate. +class DATASET_API PhaseVocoder final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] rate Speed-up factor. + /// \param[in] phase_advance Expected phase advance in each bin in shape of (freq, 1). + PhaseVocoder(float rate, const MSTensor &phase_advance); + + /// \brief Destructor. + ~PhaseVocoder() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +// \brief Shift the pitch of a waveform by 'n_steps' steps. +class DATASET_API PitchShift final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] sample_rate Sampling rate of "waveform". Default: 0. + /// \param[in] n_steps The (fractional) steps to shift "waveform". Default: 0. + /// \param[in] bins_per_octave The number of steps per octave. Default: 12. + /// \param[in] n_fft Size of FFT, creates "n_fft // 2 + 1" bins. Default: 512. + /// \param[in] win_length Window size. Default: 0, will be set to `n_fft` . + /// \param[in] hop_length Length of hop between STFT windows. Default: 0, will be set to `win_length // 4` . + /// \param[in] window Window tensor that is applied/multiplied to each frame/window. Default: WindowType::kHann. + explicit PitchShift(int32_t sample_rate = 0, int32_t n_steps = 0, int32_t bins_per_octave = 12, int32_t n_fft = 512, + int32_t win_length = 0, int32_t hop_length = 0, WindowType window = WindowType::kHann); + + /// \brief Destructor. + ~PitchShift() override = default; + + protected: + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief Resample TensorTransform. +/// \notes Resample a signal from one frequency to another. A sampling method can be given. +class DATASET_API Resample : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] orig_freq The original frequency of the signal, which must be positive. Default: 16000.0. + /// \param[in] new_freq The desired frequency, which must be positive. Default: 16000.0. + /// \param[in] resample_method The resampling method, which can be ResampleMethod::kSincInterpolation + /// and ResampleMethod::kKaiserWindow. Default: ResampleMethod::kSincInterpolation. + /// \param[in] lowpass_filter_width Controls the sharpness of the filter, more means sharper but less efficient, + /// which must be positive. Default: 6. + /// \param[in] rolloff The roll-off frequency of the filter, as a fraction of the Nyquist. Lower values + /// reduce anti-aliasing, but also reduce some of the highest frequencies, range: (0, 1]. Default: 0.99. + /// \param[in] beta The shape parameter used for kaiser window. Default: 14.769656459379492. + explicit Resample(float orig_freq = 16000.0, float new_freq = 16000.0, + ResampleMethod resample_method = ResampleMethod::kSincInterpolation, + int32_t lowpass_filter_width = 6, float rolloff = 0.99, float beta = 14.769656459379492); + + /// \brief Destructor. + ~Resample() override = default; + + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief Apply RIAA vinyl playback equalization. +class DATASET_API RiaaBiquad final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz), + /// can only be one of 44100, 48000, 88200, 96000. + explicit RiaaBiquad(int32_t sample_rate); + + /// \brief Destructor. + ~RiaaBiquad() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief Apply sliding-window cepstral mean (and optionally variance) normalization per utterance. +class DATASET_API SlidingWindowCmn final : public TensorTransform { + public: + /// \brief Constructor of SlidingWindowCmnOp. + /// \param[in] cmn_window The window in frames for running average CMN computation. Default: 600. + /// \param[in] min_cmn_window The minimum CMN window. Only applicable if center is false, ignored if center + /// is true. Default: 100. + /// \param[in] center If true, use a window centered on the current frame. If false, window is to the left. + /// Default: false. + /// \param[in] norm_vars If true, normalize variance to one. Default: false. + explicit SlidingWindowCmn(int32_t cmn_window = 600, int32_t min_cmn_window = 100, bool center = false, + bool norm_vars = false); + + /// \brief Destructor. + ~SlidingWindowCmn() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief Create a spectral centroid from an audio signal. +class DATASET_API SpectralCentroid : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz). + /// \param[in] n_fft Size of FFT, creates n_fft / 2 + 1 bins. Default: 400. + /// \param[in] win_length Window size. Default: 0, will use n_fft. + /// \param[in] hop_length Length of hop between STFT windows. Default: 0, will use win_length / 2. + /// \param[in] pad Two sided padding of signal. Default: 0. + /// \param[in] window Window function that is applied/multiplied to each frame/window, + /// which can be WindowType::kBartlett, WindowType::kBlackman, WindowType::kHamming, + /// WindowType::kHann or WindowType::kKaiser. Default: WindowType::kHann. + explicit SpectralCentroid(int32_t sample_rate, int32_t n_fft = 400, int32_t win_length = 0, int32_t hop_length = 0, + int32_t pad = 0, WindowType window = WindowType::kHann); + + ~SpectralCentroid() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + int32_t sample_rate_; + int32_t n_fft_; + int32_t win_length_; + int32_t hop_length_; + int32_t pad_; + WindowType window_; + struct Data; + std::shared_ptr data_; +}; + +/// \brief Create a spectrogram from an audio signal. +class DATASET_API Spectrogram : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] n_fft Size of FFT, creates n_fft / 2 + 1 bins. Default: 400. + /// \param[in] win_length Window size. Default: 0, will use n_fft. + /// \param[in] hop_length Length of hop between STFT windows. Default: 0, will use win_length / 2. + /// \param[in] pad Two sided padding of signal. Default: 0. + /// \param[in] window Window function that is applied/multiplied to each frame/window, + /// which can be WindowType::kBartlett, WindowType::kBlackman, WindowType::kHamming, + /// WindowType::kHann or WindowType::kKaiser. Default: WindowType::kHann. + /// \param[in] power Exponent for the magnitude spectrogram, which must be greater than or equal to 0. Default: 2.0. + /// \param[in] normalized Whether to normalize by magnitude after stft. Default: false. + /// \param[in] center Whether to pad waveform on both sides. Default: true. + /// \param[in] pad_mode Controls the padding method used when center is true, + /// which can be BorderType::kReflect, BorderType::kConstant, BorderType::kEdge, + /// BorderType::kSymmetric. Default: BorderType::kReflect. + /// \param[in] onesided Controls whether to return half of results to avoid redundancy. Default: true. + explicit Spectrogram(int32_t n_fft = 400, int32_t win_length = 0, int32_t hop_length = 0, int32_t pad = 0, + WindowType window = WindowType::kHann, float power = 2.0, bool normalized = false, + bool center = true, BorderType pad_mode = BorderType::kReflect, bool onesided = true); + + /// \brief Destructor. + ~Spectrogram() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + int32_t n_fft_; + int32_t win_length_; + int32_t hop_length_; + int32_t pad_; + WindowType window_; + float power_; + bool normalized_; + bool center_; + BorderType pad_mode_; + bool onesided_; + struct Data; + std::shared_ptr data_; +}; + +/// \brief TimeMasking TensorTransform. +/// \notes Apply masking to a spectrogram in the time domain. +class DATASET_API TimeMasking final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] iid_masks Whether to apply different masks to each example. + /// \param[in] time_mask_param Maximum possible length of the mask, range: [0, time_length]. Default: 0. + /// Indices uniformly sampled from [0, time_mask_param]. + /// Mask width when iid_masks=true. + /// \param[in] mask_start Mask start when iid_masks=true, range: [0, time_length-time_mask_param]. Default: 0. + /// \param[in] mask_value Mask value. + explicit TimeMasking(bool iid_masks = false, int32_t time_mask_param = 0, int32_t mask_start = 0, + float mask_value = 0.0); + + /// \brief Destructor. + ~TimeMasking() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief TimeStretch TensorTransform +/// \notes Stretch STFT in time at a given rate, without changing the pitch. +class DATASET_API TimeStretch final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] hop_length Length of hop between STFT windows. Default: None, will use ((n_freq - 1) * 2) // 2. + /// \param[in] n_freq Number of filter banks form STFT. Default: 201. + /// \param[in] fixed_rate Rate to speed up or slow down the input in time. + /// Default: std::numeric_limits::quiet_NaN(), will keep the original rate. + explicit TimeStretch(float hop_length = std::numeric_limits::quiet_NaN(), int n_freq = 201, + float fixed_rate = std::numeric_limits::quiet_NaN()); + + /// \brief Destructor. + ~TimeStretch() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief Design a treble tone-control effect. +class DATASET_API TrebleBiquad final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero. + /// \param[in] gain Desired gain at the boost (or attenuation) in dB. + /// \param[in] central_freq Central frequency (in Hz). Default: 3000.0. + /// \param[in] Q Quality factor, https://en.wikipedia.org/wiki/Q_factor, range: (0, 1]. Default: 0.707. + TrebleBiquad(int32_t sample_rate, float gain, float central_freq = 3000.0, float Q = 0.707); + + /// \brief Destructor. + ~TrebleBiquad() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief Vad TensorTransform. +/// \notes Attempt to trim silent background sounds from the end of the voice recording. +class DATASET_API Vad final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] sample_rate Sample rate of audio signal. + /// \param[in] trigger_level The measurement level used to trigger activity detection. Default: 7.0. + /// \param[in] trigger_time The time constant (in seconds) used to help ignore short sounds. Default: 0.25. + /// \param[in] search_time The amount of audio (in seconds) to search for quieter/shorter sounds to include prior to + /// the detected trigger point. Default: 1.0. + /// \param[in] allowed_gap The allowed gap (in seconds) between quiteter/shorter sounds to include prior to the + /// detected trigger point. Default: 0.25. + /// \param[in] pre_trigger_time The amount of audio (in seconds) to preserve before the trigger point and any found + /// quieter/shorter bursts. Default: 0.0. + /// \param[in] boot_time The time for the initial noise estimate. Default: 0.35. + /// \param[in] noise_up_time Time constant used by the adaptive noise estimator, when the noise level is increasing. + /// Default: 0.1. + /// \param[in] noise_down_time Time constant used by the adaptive noise estimator, when the noise level is decreasing. + /// Default: 0.01. + /// \param[in] noise_reduction_amount The amount of noise reduction used in the detection algorithm. Default: 1.35. + /// \param[in] measure_freq The frequency of the algorithm’s processing. Default: 20.0. + /// \param[in] measure_duration The duration of measurement. Default: 0, use twice the measurement period. + /// \param[in] measure_smooth_time The time constant used to smooth spectral measurements. Default: 0.4. + /// \param[in] hp_filter_freq The "Brick-wall" frequency of high-pass filter applied at the input to the detector + /// algorithm. Default: 50.0. + /// \param[in] lp_filter_freq The "Brick-wall" frequency of low-pass filter applied at the input to the detector + /// algorithm. Default: 6000.0. + /// \param[in] hp_lifter_freq The "Brick-wall" frequency of high-pass lifter applied at the input to the detector + /// algorithm. Default: 150.0. + /// \param[in] lp_lifter_freq The "Brick-wall" frequency of low-pass lifter applied at the input to the detector + /// algorithm. Default: 2000.0. + explicit Vad(int32_t sample_rate, float trigger_level = 7.0, float trigger_time = 0.25, float search_time = 1.0, + float allowed_gap = 0.25, float pre_trigger_time = 0.0, float boot_time = 0.35, + float noise_up_time = 0.1, float noise_down_time = 0.01, float noise_reduction_amount = 1.35, + float measure_freq = 20.0, float measure_duration = 0.0, float measure_smooth_time = 0.4, + float hp_filter_freq = 50.0, float lp_filter_freq = 6000.0, float hp_lifter_freq = 150.0, + float lp_lifter_freq = 2000.0); + + /// \brief Destructor. + ~Vad() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + +/// \brief Vol TensorTransform. +/// \notes Add a volume to an waveform. +class DATASET_API Vol final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] gain Gain value, varies according to the value of gain_type. If gain_type is GainType::kAmplitude, + /// gain must be greater than or equal to zero. If gain_type is GainType::kPower, gain must be greater than zero. + /// If gain_type is GainType::kDb, there is no limit for gain. + /// \param[in] gain_type Type of gain, should be one of [GainType::kAmplitude, GainType::kDb, GainType::kPower]. + explicit Vol(float gain, GainType gain_type = GainType::kAmplitude); + + /// \brief Destructor. + ~Vol() override = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; +} // namespace audio +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_AUDIO_H_ diff --git a/mindspore/ccsrc/minddata/dataset/include/dataset/vision.h b/mindspore/ccsrc/minddata/dataset/include/dataset/vision.h old mode 100755 new mode 100644 diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/CMakeLists.txt index 32e59cbfcb4..e5d7c5f79ff 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/CMakeLists.txt @@ -1,52 +1,52 @@ -file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") -set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) -add_definitions(-DENABLE_DVPP_INTERFACE) -set(DVPP_IMAGE_SOURCE - # Ascend310 - ascend310/dvpp_crop_jpeg_op.cc - ascend310/dvpp_decode_resize_crop_jpeg_op.cc - ascend310/dvpp_decode_resize_jpeg_op.cc - ascend310/dvpp_decode_jpeg_op.cc - ascend310/dvpp_decode_png_op.cc - ascend310/dvpp_decode_video_op.cc - ascend310/dvpp_normalize_op.cc - ascend310/dvpp_resize_jpeg_op.cc - # adaptor - acl_adapter.cc - ) - -if(NOT BUILD_LITE AND ENABLE_D) -set(DVPP_IMAGE_SOURCE - ${DVPP_IMAGE_SOURCE} - # Ascend910B - ascend910b/dvpp_adjust_brightness_op.cc - ascend910b/dvpp_adjust_contrast_op.cc - ascend910b/dvpp_adjust_hue_op.cc - ascend910b/dvpp_adjust_saturation_op.cc - ascend910b/dvpp_adjust_sharpness_op.cc - ascend910b/dvpp_affine_op.cc - ascend910b/dvpp_auto_contrast_op.cc - ascend910b/dvpp_crop_op.cc - ascend910b/dvpp_convert_color_op.cc - ascend910b/dvpp_decode_op.cc - ascend910b/dvpp_equalize_op.cc - ascend910b/dvpp_erase_op.cc - ascend910b/dvpp_gaussian_blur_op.cc - ascend910b/dvpp_horizontal_flip_op.cc - ascend910b/dvpp_invert_op.cc - ascend910b/dvpp_normalize_v2_op.cc - ascend910b/dvpp_pad_op.cc - ascend910b/dvpp_perspective_op.cc - ascend910b/dvpp_posterize_op.cc - ascend910b/dvpp_resize_op.cc - ascend910b/dvpp_resized_crop_op.cc - ascend910b/dvpp_rotate_op.cc - ascend910b/dvpp_solarize_op.cc - ascend910b/dvpp_vertical_flip_op.cc - ) -endif() - -add_library(kernels-dvpp-image OBJECT ${DVPP_IMAGE_SOURCE}) -if(ENABLE_ACL OR MSLITE_ENABLE_ACL) - add_subdirectory(utils) -endif() +file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) +add_definitions(-DENABLE_DVPP_INTERFACE) +set(DVPP_IMAGE_SOURCE + # Ascend310 + ascend310/dvpp_crop_jpeg_op.cc + ascend310/dvpp_decode_resize_crop_jpeg_op.cc + ascend310/dvpp_decode_resize_jpeg_op.cc + ascend310/dvpp_decode_jpeg_op.cc + ascend310/dvpp_decode_png_op.cc + ascend310/dvpp_decode_video_op.cc + ascend310/dvpp_normalize_op.cc + ascend310/dvpp_resize_jpeg_op.cc + # adaptor + acl_adapter.cc + ) + +if(NOT BUILD_LITE AND ENABLE_D) +set(DVPP_IMAGE_SOURCE + ${DVPP_IMAGE_SOURCE} + # Ascend910B + ascend910b/dvpp_adjust_brightness_op.cc + ascend910b/dvpp_adjust_contrast_op.cc + ascend910b/dvpp_adjust_hue_op.cc + ascend910b/dvpp_adjust_saturation_op.cc + ascend910b/dvpp_adjust_sharpness_op.cc + ascend910b/dvpp_affine_op.cc + ascend910b/dvpp_auto_contrast_op.cc + ascend910b/dvpp_crop_op.cc + ascend910b/dvpp_convert_color_op.cc + ascend910b/dvpp_decode_op.cc + ascend910b/dvpp_equalize_op.cc + ascend910b/dvpp_erase_op.cc + ascend910b/dvpp_gaussian_blur_op.cc + ascend910b/dvpp_horizontal_flip_op.cc + ascend910b/dvpp_invert_op.cc + ascend910b/dvpp_normalize_v2_op.cc + ascend910b/dvpp_pad_op.cc + ascend910b/dvpp_perspective_op.cc + ascend910b/dvpp_posterize_op.cc + ascend910b/dvpp_resize_op.cc + ascend910b/dvpp_resized_crop_op.cc + ascend910b/dvpp_rotate_op.cc + ascend910b/dvpp_solarize_op.cc + ascend910b/dvpp_vertical_flip_op.cc + ) +endif() + +add_library(kernels-dvpp-image OBJECT ${DVPP_IMAGE_SOURCE}) +if(ENABLE_ACL OR MSLITE_ENABLE_ACL) + add_subdirectory(utils) +endif() diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_crop_jpeg_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_crop_jpeg_op.h index 10f1c3b4e1b..a659f5ba0ac 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_crop_jpeg_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_crop_jpeg_op.h @@ -1,60 +1,60 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_DVPP_CROP_JPEG_OP_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_DVPP_CROP_JPEG_OP_H_ - -#include -#include -#include - -#include "minddata/dataset/core/data_type.h" -#include "minddata/dataset/core/device_resource.h" -#include "minddata/dataset/core/device_tensor.h" -#include "minddata/dataset/core/tensor.h" -#include "minddata/dataset/kernels/image/dvpp/acl_adapter.h" -#include "minddata/dataset/kernels/image/dvpp/utils/ErrorCode.h" -#include "minddata/dataset/kernels/tensor_op.h" -#include "minddata/dataset/util/log_adapter.h" -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class DvppCropJpegOp : public TensorOp { - public: - DvppCropJpegOp(int32_t crop_height, int32_t crop_width) : crop_height_(crop_height), crop_width_(crop_width) {} - - /// \brief Destructor - ~DvppCropJpegOp() = default; - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - Status OutputShape(const std::vector &inputs, std::vector &outputs) override; - - std::string Name() const override { return kDvppCropJpegOp; } - - Status SetAscendResource(const std::shared_ptr &resource) override; - - private: - uint32_t crop_height_; - uint32_t crop_width_; - std::shared_ptr processor_; -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_DVPP_CROP_JPEG_OP_H_ +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_DVPP_CROP_JPEG_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_DVPP_CROP_JPEG_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/data_type.h" +#include "minddata/dataset/core/device_resource.h" +#include "minddata/dataset/core/device_tensor.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/image/dvpp/acl_adapter.h" +#include "minddata/dataset/kernels/image/dvpp/utils/ErrorCode.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/log_adapter.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class DvppCropJpegOp : public TensorOp { + public: + DvppCropJpegOp(int32_t crop_height, int32_t crop_width) : crop_height_(crop_height), crop_width_(crop_width) {} + + /// \brief Destructor + ~DvppCropJpegOp() = default; + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + Status OutputShape(const std::vector &inputs, std::vector &outputs) override; + + std::string Name() const override { return kDvppCropJpegOp; } + + Status SetAscendResource(const std::shared_ptr &resource) override; + + private: + uint32_t crop_height_; + uint32_t crop_width_; + std::shared_ptr processor_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_DVPP_CROP_JPEG_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_decode_jpeg_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_decode_jpeg_op.cc index c98f5d34ecd..1c818030fe5 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_decode_jpeg_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_decode_jpeg_op.cc @@ -1,146 +1,146 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/kernels/image/dvpp/ascend310/dvpp_decode_jpeg_op.h" - -#include -#include - -#include "minddata/dataset/core/cv_tensor.h" -#include "minddata/dataset/core/data_type.h" -#include "minddata/dataset/core/device_tensor.h" -#include "minddata/dataset/kernels/image/dvpp/utils/CommonDataType.h" -#include "minddata/dataset/kernels/image/image_utils.h" - -namespace mindspore { -namespace dataset { -// Compute() will be called when context=="Ascend310" -Status DvppDecodeJpegOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - try { - CHECK_FAIL_RETURN_UNEXPECTED(input->GetDeviceBuffer() != nullptr, "The input image buffer on device is empty."); - APP_ERROR ret = AclAdapter::GetInstance().JPEG_D(processor_.get()); - if (ret != APP_ERR_OK) { - ret = AclAdapter::GetInstance().ReleaseAclProcess(processor_.get()); - CHECK_FAIL_RETURN_UNEXPECTED(ret == APP_ERR_OK, "Release memory failed."); - std::string error = "Error in dvpp processing: " + std::to_string(ret); - RETURN_STATUS_UNEXPECTED(error); - } - DvppDataInfo *DecodeOut = AclAdapter::GetInstance().GetDecodeDeviceData(processor_.get()); - const TensorShape dvpp_shape({1, 1, 1}); - const DataType dvpp_data_type(DataType::DE_UINT8); - RETURN_IF_NOT_OK(mindspore::dataset::DeviceTensor::CreateEmpty(dvpp_shape, dvpp_data_type, output)); - RETURN_IF_NOT_OK((*output)->SetAttributes(DecodeOut->data, DecodeOut->dataSize, DecodeOut->width, - DecodeOut->widthStride, DecodeOut->height, DecodeOut->heightStride)); - if (!((*output)->HasDeviceData())) { - std::string error = "[ERROR] Fail to get the Output result from memory!"; - RETURN_STATUS_UNEXPECTED(error); - } - } catch (const std::exception &e) { - std::string error = "[ERROR] Fail in DvppDecodeJpegOp: " + std::string(e.what()); - RETURN_STATUS_UNEXPECTED(error); - } - return Status::OK(); -} - -// Compute() will be called when context=="CPU" -Status DvppDecodeJpegOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - if (!IsNonEmptyJPEG(input)) { - RETURN_STATUS_UNEXPECTED("DvppDecodeJpegOp only support process JPEG image."); - } - try { - CHECK_FAIL_RETURN_UNEXPECTED(input->GetBuffer() != nullptr, "The input image buffer is empty."); - auto *buffer = const_cast(input->GetBuffer()); - RawData imageInfo{}; - uint32_t filesize = input->SizeInBytes(); - imageInfo.lenOfByte = filesize; - imageInfo.data = static_cast(buffer); - ResourceInfo resource; - resource.deviceIds.insert(0); - APP_ERROR ret = AclAdapter::GetInstance().InitResource(&resource); - if (ret != APP_ERR_OK) { - AclAdapter::GetInstance().Release(); - std::string error = "Error in Init D-chip: " + std::to_string(ret); - RETURN_STATUS_UNEXPECTED(error); - } - int deviceId = *(resource.deviceIds.begin()); - void *context = AclAdapter::GetInstance().GetContext(deviceId); - // Second part end where we initialize the resource of D-chip and set up all configures - std::shared_ptr process(AclAdapter::GetInstance().CreateAclProcess(context, false, nullptr, nullptr), - [](void *ptr) { AclAdapter::GetInstance().DestroyAclProcess(ptr); }); - ret = AclAdapter::GetInstance().InitAclProcess(process.get()); - if (ret != APP_ERR_OK) { - AclAdapter::GetInstance().Release(); - std::string error = "Error in Init resource: " + std::to_string(ret); - RETURN_STATUS_UNEXPECTED(error); - } - ret = AclAdapter::GetInstance().JPEG_D_WITH_DATA(process.get(), imageInfo); - if (ret != APP_ERR_OK) { - AclAdapter::GetInstance().Release(); - std::string error = "Error in dvpp processing: " + std::to_string(ret); - RETURN_STATUS_UNEXPECTED(error); - } - // Third part end where we execute the core function of dvpp - auto *ret_ptr = static_cast(AclAdapter::GetInstance().GetMemoryData(process.get())); - DvppDataInfo *DecodeOut = AclAdapter::GetInstance().GetDecodeDeviceData(process.get()); - dsize_t dvpp_length = DecodeOut->dataSize; - uint32_t decoded_height = DecodeOut->height; - uint32_t decoded_heightStride = DecodeOut->heightStride; - uint32_t decoded_width = DecodeOut->width; - uint32_t decoded_widthStride = DecodeOut->widthStride; - - const TensorShape dvpp_shape({dvpp_length, 1, 1}); - const DataType dvpp_data_type(DataType::DE_UINT8); - RETURN_IF_NOT_OK(mindspore::dataset::Tensor::CreateFromMemory(dvpp_shape, dvpp_data_type, ret_ptr, output)); - RETURN_IF_NOT_OK((*output)->SetYuvShape(decoded_width, decoded_widthStride, decoded_height, decoded_heightStride)); - if (!((*output)->HasData())) { - std::string error = "[ERROR] Fail to get the Output result from device memory!"; - RETURN_STATUS_UNEXPECTED(error); - } - ret = AclAdapter::GetInstance().DeviceMemoryRelease(process.get()); - CHECK_FAIL_RETURN_UNEXPECTED(ret == APP_ERR_OK, "Release device memory failed."); - ret = AclAdapter::GetInstance().ReleaseAclProcess(process.get()); - CHECK_FAIL_RETURN_UNEXPECTED(ret == APP_ERR_OK, "Release host memory failed."); - // Last part end where we transform the processed data into a tensor which can be applied in later units. - } catch (const std::exception &e) { - std::string error = "[ERROR] Fail in DvppDecodeJpegOp: " + std::string(e.what()); - RETURN_STATUS_UNEXPECTED(error); - } - return Status::OK(); -} - -Status DvppDecodeJpegOp::SetAscendResource(const std::shared_ptr &resource) { - processor_ = resource->GetInstance(); - if (!processor_) { - RETURN_STATUS_UNEXPECTED("Resource initialize fail, please check your env."); - } - return Status::OK(); -} - -Status DvppDecodeJpegOp::OutputShape(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); - outputs.clear(); - TensorShape out({-1, 1, 1}); // we don't know what is output image size, but we know it should be 3 channels - CHECK_FAIL_RETURN_UNEXPECTED(!inputs.empty(), "DvppDecodeJpeg: inputs cannot be empty."); - if (inputs[0].Rank() == 1) { - (void)outputs.emplace_back(out); - } - CHECK_FAIL_RETURN_UNEXPECTED(!outputs.empty(), "DvppDecodeJpeg: Invalid input shape."); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/kernels/image/dvpp/ascend310/dvpp_decode_jpeg_op.h" + +#include +#include + +#include "minddata/dataset/core/cv_tensor.h" +#include "minddata/dataset/core/data_type.h" +#include "minddata/dataset/core/device_tensor.h" +#include "minddata/dataset/kernels/image/dvpp/utils/CommonDataType.h" +#include "minddata/dataset/kernels/image/image_utils.h" + +namespace mindspore { +namespace dataset { +// Compute() will be called when context=="Ascend310" +Status DvppDecodeJpegOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + try { + CHECK_FAIL_RETURN_UNEXPECTED(input->GetDeviceBuffer() != nullptr, "The input image buffer on device is empty."); + APP_ERROR ret = AclAdapter::GetInstance().JPEG_D(processor_.get()); + if (ret != APP_ERR_OK) { + ret = AclAdapter::GetInstance().ReleaseAclProcess(processor_.get()); + CHECK_FAIL_RETURN_UNEXPECTED(ret == APP_ERR_OK, "Release memory failed."); + std::string error = "Error in dvpp processing: " + std::to_string(ret); + RETURN_STATUS_UNEXPECTED(error); + } + DvppDataInfo *DecodeOut = AclAdapter::GetInstance().GetDecodeDeviceData(processor_.get()); + const TensorShape dvpp_shape({1, 1, 1}); + const DataType dvpp_data_type(DataType::DE_UINT8); + RETURN_IF_NOT_OK(mindspore::dataset::DeviceTensor::CreateEmpty(dvpp_shape, dvpp_data_type, output)); + RETURN_IF_NOT_OK((*output)->SetAttributes(DecodeOut->data, DecodeOut->dataSize, DecodeOut->width, + DecodeOut->widthStride, DecodeOut->height, DecodeOut->heightStride)); + if (!((*output)->HasDeviceData())) { + std::string error = "[ERROR] Fail to get the Output result from memory!"; + RETURN_STATUS_UNEXPECTED(error); + } + } catch (const std::exception &e) { + std::string error = "[ERROR] Fail in DvppDecodeJpegOp: " + std::string(e.what()); + RETURN_STATUS_UNEXPECTED(error); + } + return Status::OK(); +} + +// Compute() will be called when context=="CPU" +Status DvppDecodeJpegOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + if (!IsNonEmptyJPEG(input)) { + RETURN_STATUS_UNEXPECTED("DvppDecodeJpegOp only support process JPEG image."); + } + try { + CHECK_FAIL_RETURN_UNEXPECTED(input->GetBuffer() != nullptr, "The input image buffer is empty."); + auto *buffer = const_cast(input->GetBuffer()); + RawData imageInfo{}; + uint32_t filesize = input->SizeInBytes(); + imageInfo.lenOfByte = filesize; + imageInfo.data = static_cast(buffer); + ResourceInfo resource; + resource.deviceIds.insert(0); + APP_ERROR ret = AclAdapter::GetInstance().InitResource(&resource); + if (ret != APP_ERR_OK) { + AclAdapter::GetInstance().Release(); + std::string error = "Error in Init D-chip: " + std::to_string(ret); + RETURN_STATUS_UNEXPECTED(error); + } + int deviceId = *(resource.deviceIds.begin()); + void *context = AclAdapter::GetInstance().GetContext(deviceId); + // Second part end where we initialize the resource of D-chip and set up all configures + std::shared_ptr process(AclAdapter::GetInstance().CreateAclProcess(context, false, nullptr, nullptr), + [](void *ptr) { AclAdapter::GetInstance().DestroyAclProcess(ptr); }); + ret = AclAdapter::GetInstance().InitAclProcess(process.get()); + if (ret != APP_ERR_OK) { + AclAdapter::GetInstance().Release(); + std::string error = "Error in Init resource: " + std::to_string(ret); + RETURN_STATUS_UNEXPECTED(error); + } + ret = AclAdapter::GetInstance().JPEG_D_WITH_DATA(process.get(), imageInfo); + if (ret != APP_ERR_OK) { + AclAdapter::GetInstance().Release(); + std::string error = "Error in dvpp processing: " + std::to_string(ret); + RETURN_STATUS_UNEXPECTED(error); + } + // Third part end where we execute the core function of dvpp + auto *ret_ptr = static_cast(AclAdapter::GetInstance().GetMemoryData(process.get())); + DvppDataInfo *DecodeOut = AclAdapter::GetInstance().GetDecodeDeviceData(process.get()); + dsize_t dvpp_length = DecodeOut->dataSize; + uint32_t decoded_height = DecodeOut->height; + uint32_t decoded_heightStride = DecodeOut->heightStride; + uint32_t decoded_width = DecodeOut->width; + uint32_t decoded_widthStride = DecodeOut->widthStride; + + const TensorShape dvpp_shape({dvpp_length, 1, 1}); + const DataType dvpp_data_type(DataType::DE_UINT8); + RETURN_IF_NOT_OK(mindspore::dataset::Tensor::CreateFromMemory(dvpp_shape, dvpp_data_type, ret_ptr, output)); + RETURN_IF_NOT_OK((*output)->SetYuvShape(decoded_width, decoded_widthStride, decoded_height, decoded_heightStride)); + if (!((*output)->HasData())) { + std::string error = "[ERROR] Fail to get the Output result from device memory!"; + RETURN_STATUS_UNEXPECTED(error); + } + ret = AclAdapter::GetInstance().DeviceMemoryRelease(process.get()); + CHECK_FAIL_RETURN_UNEXPECTED(ret == APP_ERR_OK, "Release device memory failed."); + ret = AclAdapter::GetInstance().ReleaseAclProcess(process.get()); + CHECK_FAIL_RETURN_UNEXPECTED(ret == APP_ERR_OK, "Release host memory failed."); + // Last part end where we transform the processed data into a tensor which can be applied in later units. + } catch (const std::exception &e) { + std::string error = "[ERROR] Fail in DvppDecodeJpegOp: " + std::string(e.what()); + RETURN_STATUS_UNEXPECTED(error); + } + return Status::OK(); +} + +Status DvppDecodeJpegOp::SetAscendResource(const std::shared_ptr &resource) { + processor_ = resource->GetInstance(); + if (!processor_) { + RETURN_STATUS_UNEXPECTED("Resource initialize fail, please check your env."); + } + return Status::OK(); +} + +Status DvppDecodeJpegOp::OutputShape(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); + outputs.clear(); + TensorShape out({-1, 1, 1}); // we don't know what is output image size, but we know it should be 3 channels + CHECK_FAIL_RETURN_UNEXPECTED(!inputs.empty(), "DvppDecodeJpeg: inputs cannot be empty."); + if (inputs[0].Rank() == 1) { + (void)outputs.emplace_back(out); + } + CHECK_FAIL_RETURN_UNEXPECTED(!outputs.empty(), "DvppDecodeJpeg: Invalid input shape."); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_decode_jpeg_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_decode_jpeg_op.h index 6600047b729..4e5c087a1f6 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_decode_jpeg_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_decode_jpeg_op.h @@ -1,57 +1,57 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_DVPP_DECODE_JPEG_OP_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_DVPP_DECODE_JPEG_OP_H_ - -#include -#include -#include - -#include "minddata/dataset/core/data_type.h" -#include "minddata/dataset/core/device_resource.h" -#include "minddata/dataset/core/tensor.h" -#include "minddata/dataset/kernels/image/dvpp/acl_adapter.h" -#include "minddata/dataset/kernels/image/dvpp/utils/ErrorCode.h" -#include "minddata/dataset/kernels/tensor_op.h" -#include "minddata/dataset/util/log_adapter.h" -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class DvppDecodeJpegOp : public TensorOp { - public: - DvppDecodeJpegOp() : processor_(nullptr) {} - - /// \brief Destructor - ~DvppDecodeJpegOp() = default; - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - Status OutputShape(const std::vector &inputs, std::vector &outputs) override; - - std::string Name() const override { return kDvppDecodeJpegOp; } - - Status SetAscendResource(const std::shared_ptr &resource) override; - - private: - std::shared_ptr processor_; -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_DVPP_DECODE_JPEG_OP_H_ +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_DVPP_DECODE_JPEG_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_DVPP_DECODE_JPEG_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/data_type.h" +#include "minddata/dataset/core/device_resource.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/image/dvpp/acl_adapter.h" +#include "minddata/dataset/kernels/image/dvpp/utils/ErrorCode.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/log_adapter.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class DvppDecodeJpegOp : public TensorOp { + public: + DvppDecodeJpegOp() : processor_(nullptr) {} + + /// \brief Destructor + ~DvppDecodeJpegOp() = default; + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + Status OutputShape(const std::vector &inputs, std::vector &outputs) override; + + std::string Name() const override { return kDvppDecodeJpegOp; } + + Status SetAscendResource(const std::shared_ptr &resource) override; + + private: + std::shared_ptr processor_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_DVPP_DECODE_JPEG_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_decode_png_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_decode_png_op.cc index 50eb3c3f4fb..2ec5fac92b1 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_decode_png_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_decode_png_op.cc @@ -1,141 +1,141 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/kernels/image/dvpp/ascend310/dvpp_decode_png_op.h" - -#include -#include - -#include "minddata/dataset/core/cv_tensor.h" -#include "minddata/dataset/core/data_type.h" -#include "minddata/dataset/kernels/image/dvpp/ascend310/dvpp_decode_resize_crop_jpeg_op.h" -#include "minddata/dataset/kernels/image/dvpp/utils/CommonDataType.h" -#include "minddata/dataset/kernels/image/image_utils.h" - -namespace mindspore { -namespace dataset { -Status DvppDecodePngOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - try { - CHECK_FAIL_RETURN_UNEXPECTED(input->GetDeviceBuffer() != nullptr, "The input image buffer on device is empty."); - APP_ERROR ret = AclAdapter::GetInstance().PNG_D(processor_.get()); - if (ret != APP_ERR_OK) { - ret = AclAdapter::GetInstance().ReleaseAclProcess(processor_.get()); - CHECK_FAIL_RETURN_UNEXPECTED(ret == APP_ERR_OK, "Release memory failed."); - std::string error = "Error in dvpp processing: " + std::to_string(ret); - RETURN_STATUS_UNEXPECTED(error); - } - DvppDataInfo *DecodeOut = AclAdapter::GetInstance().GetDecodeDeviceData(processor_.get()); - const TensorShape dvpp_shape({1, 1, 1}); - const DataType dvpp_data_type(DataType::DE_UINT8); - RETURN_IF_NOT_OK(mindspore::dataset::DeviceTensor::CreateEmpty(dvpp_shape, dvpp_data_type, output)); - RETURN_IF_NOT_OK((*output)->SetAttributes(DecodeOut->data, DecodeOut->dataSize, DecodeOut->width, - DecodeOut->widthStride, DecodeOut->height, DecodeOut->heightStride)); - if (!((*output)->HasDeviceData())) { - std::string error = "[ERROR] Fail to get the Output result from memory!"; - RETURN_STATUS_UNEXPECTED(error); - } - } catch (const std::exception &e) { - std::string error = "[ERROR] Fail in DvppDecodePngOp: " + std::string(e.what()); - RETURN_STATUS_UNEXPECTED(error); - } - return Status::OK(); -} - -Status DvppDecodePngOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - if (!IsNonEmptyPNG(input)) { - RETURN_STATUS_UNEXPECTED("DvppDecodePngOp only support process PNG image."); - } - try { - CHECK_FAIL_RETURN_UNEXPECTED(input->GetBuffer() != nullptr, "The input image buffer is empty."); - auto *buffer = const_cast(input->GetBuffer()); - RawData imageInfo{}; - uint32_t filesize = input->SizeInBytes(); - imageInfo.lenOfByte = filesize; - imageInfo.data = static_cast(buffer); - ResourceInfo resource; - resource.deviceIds.insert(0); - APP_ERROR ret = AclAdapter::GetInstance().InitResource(&resource); - if (ret != APP_ERR_OK) { - AclAdapter::GetInstance().Release(); - std::string error = "Error in Init D-chip: " + std::to_string(ret); - RETURN_STATUS_UNEXPECTED(error); - } - int deviceId = *(resource.deviceIds.begin()); - void *context = AclAdapter::GetInstance().GetContext(deviceId); - // Second part end where we initialize the resource of D-chip and set up all configures - std::shared_ptr process(AclAdapter::GetInstance().CreateAclProcess(context, false, nullptr, nullptr), - [](void *ptr) { AclAdapter::GetInstance().DestroyAclProcess(ptr); }); - ret = AclAdapter::GetInstance().InitAclProcess(process.get()); - if (ret != APP_ERR_OK) { - AclAdapter::GetInstance().Release(); - std::string error = "Error in Init resource: " + std::to_string(ret); - RETURN_STATUS_UNEXPECTED(error); - } - - ret = AclAdapter::GetInstance().PNG_D_WITH_DATA(process.get(), imageInfo); - if (ret != APP_ERR_OK) { - AclAdapter::GetInstance().Release(); - std::string error = "Error in dvpp processing: " + std::to_string(ret); - RETURN_STATUS_UNEXPECTED(error); - } - - // Third part end where we execute the core function of dvpp - auto *ret_ptr = static_cast(AclAdapter::GetInstance().GetMemoryData(process.get())); - DvppDataInfo *DecodeOut = AclAdapter::GetInstance().GetDecodeDeviceData(process.get()); - dsize_t dvpp_length = DecodeOut->dataSize; - - const TensorShape dvpp_shape({dvpp_length, 1, 1}); - const DataType dvpp_data_type(DataType::DE_UINT8); - mindspore::dataset::Tensor::CreateFromMemory(dvpp_shape, dvpp_data_type, ret_ptr, output); - if (!((*output)->HasData())) { - std::string error = "[ERROR] Fail to get the Output result from memory!"; - RETURN_STATUS_UNEXPECTED(error); - } - ret = AclAdapter::GetInstance().DeviceMemoryRelease(process.get()); - CHECK_FAIL_RETURN_UNEXPECTED(ret == APP_ERR_OK, "Release device memory failed."); - ret = AclAdapter::GetInstance().ReleaseAclProcess(process.get()); - CHECK_FAIL_RETURN_UNEXPECTED(ret == APP_ERR_OK, "Release host memory failed."); - // Last part end where we transform the processed data into a tensor which can be applied in later units. - } catch (const std::exception &e) { - std::string error = "[ERROR] Fail in DvppDecodePngOp: " + std::string(e.what()); - RETURN_STATUS_UNEXPECTED(error); - } - return Status::OK(); -} - -Status DvppDecodePngOp::OutputShape(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); - outputs.clear(); - TensorShape out({-1, 1, 1}); // we don't know what is output image size, but we know it should be 3 channels - CHECK_FAIL_RETURN_UNEXPECTED(!inputs.empty(), "DvppDecodePng: inputs cannot be empty."); - if (inputs[0].Rank() == 1) { - outputs.emplace_back(out); - } - CHECK_FAIL_RETURN_UNEXPECTED(!outputs.empty(), "DvppDecodePng: Invalid input shape."); - return Status::OK(); -} - -Status DvppDecodePngOp::SetAscendResource(const std::shared_ptr &resource) { - processor_ = resource->GetInstance(); - if (!processor_) { - RETURN_STATUS_UNEXPECTED("Resource initialize fail, please check your env."); - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/kernels/image/dvpp/ascend310/dvpp_decode_png_op.h" + +#include +#include + +#include "minddata/dataset/core/cv_tensor.h" +#include "minddata/dataset/core/data_type.h" +#include "minddata/dataset/kernels/image/dvpp/ascend310/dvpp_decode_resize_crop_jpeg_op.h" +#include "minddata/dataset/kernels/image/dvpp/utils/CommonDataType.h" +#include "minddata/dataset/kernels/image/image_utils.h" + +namespace mindspore { +namespace dataset { +Status DvppDecodePngOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + try { + CHECK_FAIL_RETURN_UNEXPECTED(input->GetDeviceBuffer() != nullptr, "The input image buffer on device is empty."); + APP_ERROR ret = AclAdapter::GetInstance().PNG_D(processor_.get()); + if (ret != APP_ERR_OK) { + ret = AclAdapter::GetInstance().ReleaseAclProcess(processor_.get()); + CHECK_FAIL_RETURN_UNEXPECTED(ret == APP_ERR_OK, "Release memory failed."); + std::string error = "Error in dvpp processing: " + std::to_string(ret); + RETURN_STATUS_UNEXPECTED(error); + } + DvppDataInfo *DecodeOut = AclAdapter::GetInstance().GetDecodeDeviceData(processor_.get()); + const TensorShape dvpp_shape({1, 1, 1}); + const DataType dvpp_data_type(DataType::DE_UINT8); + RETURN_IF_NOT_OK(mindspore::dataset::DeviceTensor::CreateEmpty(dvpp_shape, dvpp_data_type, output)); + RETURN_IF_NOT_OK((*output)->SetAttributes(DecodeOut->data, DecodeOut->dataSize, DecodeOut->width, + DecodeOut->widthStride, DecodeOut->height, DecodeOut->heightStride)); + if (!((*output)->HasDeviceData())) { + std::string error = "[ERROR] Fail to get the Output result from memory!"; + RETURN_STATUS_UNEXPECTED(error); + } + } catch (const std::exception &e) { + std::string error = "[ERROR] Fail in DvppDecodePngOp: " + std::string(e.what()); + RETURN_STATUS_UNEXPECTED(error); + } + return Status::OK(); +} + +Status DvppDecodePngOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + if (!IsNonEmptyPNG(input)) { + RETURN_STATUS_UNEXPECTED("DvppDecodePngOp only support process PNG image."); + } + try { + CHECK_FAIL_RETURN_UNEXPECTED(input->GetBuffer() != nullptr, "The input image buffer is empty."); + auto *buffer = const_cast(input->GetBuffer()); + RawData imageInfo{}; + uint32_t filesize = input->SizeInBytes(); + imageInfo.lenOfByte = filesize; + imageInfo.data = static_cast(buffer); + ResourceInfo resource; + resource.deviceIds.insert(0); + APP_ERROR ret = AclAdapter::GetInstance().InitResource(&resource); + if (ret != APP_ERR_OK) { + AclAdapter::GetInstance().Release(); + std::string error = "Error in Init D-chip: " + std::to_string(ret); + RETURN_STATUS_UNEXPECTED(error); + } + int deviceId = *(resource.deviceIds.begin()); + void *context = AclAdapter::GetInstance().GetContext(deviceId); + // Second part end where we initialize the resource of D-chip and set up all configures + std::shared_ptr process(AclAdapter::GetInstance().CreateAclProcess(context, false, nullptr, nullptr), + [](void *ptr) { AclAdapter::GetInstance().DestroyAclProcess(ptr); }); + ret = AclAdapter::GetInstance().InitAclProcess(process.get()); + if (ret != APP_ERR_OK) { + AclAdapter::GetInstance().Release(); + std::string error = "Error in Init resource: " + std::to_string(ret); + RETURN_STATUS_UNEXPECTED(error); + } + + ret = AclAdapter::GetInstance().PNG_D_WITH_DATA(process.get(), imageInfo); + if (ret != APP_ERR_OK) { + AclAdapter::GetInstance().Release(); + std::string error = "Error in dvpp processing: " + std::to_string(ret); + RETURN_STATUS_UNEXPECTED(error); + } + + // Third part end where we execute the core function of dvpp + auto *ret_ptr = static_cast(AclAdapter::GetInstance().GetMemoryData(process.get())); + DvppDataInfo *DecodeOut = AclAdapter::GetInstance().GetDecodeDeviceData(process.get()); + dsize_t dvpp_length = DecodeOut->dataSize; + + const TensorShape dvpp_shape({dvpp_length, 1, 1}); + const DataType dvpp_data_type(DataType::DE_UINT8); + mindspore::dataset::Tensor::CreateFromMemory(dvpp_shape, dvpp_data_type, ret_ptr, output); + if (!((*output)->HasData())) { + std::string error = "[ERROR] Fail to get the Output result from memory!"; + RETURN_STATUS_UNEXPECTED(error); + } + ret = AclAdapter::GetInstance().DeviceMemoryRelease(process.get()); + CHECK_FAIL_RETURN_UNEXPECTED(ret == APP_ERR_OK, "Release device memory failed."); + ret = AclAdapter::GetInstance().ReleaseAclProcess(process.get()); + CHECK_FAIL_RETURN_UNEXPECTED(ret == APP_ERR_OK, "Release host memory failed."); + // Last part end where we transform the processed data into a tensor which can be applied in later units. + } catch (const std::exception &e) { + std::string error = "[ERROR] Fail in DvppDecodePngOp: " + std::string(e.what()); + RETURN_STATUS_UNEXPECTED(error); + } + return Status::OK(); +} + +Status DvppDecodePngOp::OutputShape(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); + outputs.clear(); + TensorShape out({-1, 1, 1}); // we don't know what is output image size, but we know it should be 3 channels + CHECK_FAIL_RETURN_UNEXPECTED(!inputs.empty(), "DvppDecodePng: inputs cannot be empty."); + if (inputs[0].Rank() == 1) { + outputs.emplace_back(out); + } + CHECK_FAIL_RETURN_UNEXPECTED(!outputs.empty(), "DvppDecodePng: Invalid input shape."); + return Status::OK(); +} + +Status DvppDecodePngOp::SetAscendResource(const std::shared_ptr &resource) { + processor_ = resource->GetInstance(); + if (!processor_) { + RETURN_STATUS_UNEXPECTED("Resource initialize fail, please check your env."); + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_decode_png_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_decode_png_op.h index ffcb0fefeff..534719c3b2d 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_decode_png_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_decode_png_op.h @@ -1,57 +1,57 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_DVPP_DECODE_PNG_OP_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_DVPP_DECODE_PNG_OP_H_ - -#include -#include -#include - -#include "minddata/dataset/core/data_type.h" -#include "minddata/dataset/core/device_resource.h" -#include "minddata/dataset/core/tensor.h" -#include "minddata/dataset/kernels/image/dvpp/acl_adapter.h" -#include "minddata/dataset/kernels/image/dvpp/utils/ErrorCode.h" -#include "minddata/dataset/kernels/tensor_op.h" -#include "minddata/dataset/util/log_adapter.h" -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class DvppDecodePngOp : public TensorOp { - public: - DvppDecodePngOp() {} - - /// \brief Destructor - ~DvppDecodePngOp() = default; - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - Status OutputShape(const std::vector &inputs, std::vector &outputs) override; - - std::string Name() const override { return kDvppDecodePngOp; } - - Status SetAscendResource(const std::shared_ptr &resource) override; - - private: - std::shared_ptr processor_; -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_DVPP_DECODE_PNG_OP_H_ +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_DVPP_DECODE_PNG_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_DVPP_DECODE_PNG_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/data_type.h" +#include "minddata/dataset/core/device_resource.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/image/dvpp/acl_adapter.h" +#include "minddata/dataset/kernels/image/dvpp/utils/ErrorCode.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/log_adapter.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class DvppDecodePngOp : public TensorOp { + public: + DvppDecodePngOp() {} + + /// \brief Destructor + ~DvppDecodePngOp() = default; + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + Status OutputShape(const std::vector &inputs, std::vector &outputs) override; + + std::string Name() const override { return kDvppDecodePngOp; } + + Status SetAscendResource(const std::shared_ptr &resource) override; + + private: + std::shared_ptr processor_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_DVPP_DECODE_PNG_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_decode_resize_jpeg_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_decode_resize_jpeg_op.h index 658e921bbc7..09eec54055a 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_decode_resize_jpeg_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_decode_resize_jpeg_op.h @@ -1,60 +1,60 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_DVPP_DECODE_RESIZE_JPEG_OP_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_DVPP_DECODE_RESIZE_JPEG_OP_H_ - -#include -#include -#include - -#include "minddata/dataset/core/data_type.h" -#include "minddata/dataset/core/device_resource.h" -#include "minddata/dataset/core/tensor.h" -#include "minddata/dataset/kernels/image/dvpp/acl_adapter.h" -#include "minddata/dataset/kernels/image/dvpp/utils/ErrorCode.h" -#include "minddata/dataset/kernels/tensor_op.h" -#include "minddata/dataset/util/log_adapter.h" -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class DvppDecodeResizeJpegOp : public TensorOp { - public: - DvppDecodeResizeJpegOp(int32_t resized_height, int32_t resized_width) - : resized_height_(resized_height), resized_width_(resized_width) {} - - /// \brief Destructor - ~DvppDecodeResizeJpegOp() = default; - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - Status OutputShape(const std::vector &inputs, std::vector &outputs) override; - - std::string Name() const override { return kDvppDecodeResizeJpegOp; } - - Status SetAscendResource(const std::shared_ptr &resource) override; - - private: - int32_t resized_height_; - int32_t resized_width_; - std::shared_ptr processor_; -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_DVPP_DECODE_RESIZE_JPEG_OP_H_ +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_DVPP_DECODE_RESIZE_JPEG_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_DVPP_DECODE_RESIZE_JPEG_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/data_type.h" +#include "minddata/dataset/core/device_resource.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/image/dvpp/acl_adapter.h" +#include "minddata/dataset/kernels/image/dvpp/utils/ErrorCode.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/log_adapter.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class DvppDecodeResizeJpegOp : public TensorOp { + public: + DvppDecodeResizeJpegOp(int32_t resized_height, int32_t resized_width) + : resized_height_(resized_height), resized_width_(resized_width) {} + + /// \brief Destructor + ~DvppDecodeResizeJpegOp() = default; + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + Status OutputShape(const std::vector &inputs, std::vector &outputs) override; + + std::string Name() const override { return kDvppDecodeResizeJpegOp; } + + Status SetAscendResource(const std::shared_ptr &resource) override; + + private: + int32_t resized_height_; + int32_t resized_width_; + std::shared_ptr processor_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_DVPP_DECODE_RESIZE_JPEG_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_decode_video_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_decode_video_op.cc index d30a52a5489..cda5f56fdf0 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_decode_video_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_decode_video_op.cc @@ -1,90 +1,90 @@ -/** - * Copyright 2022-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/kernels/image/dvpp/ascend310/dvpp_decode_video_op.h" - -#include "include/api/context.h" -#include "minddata/dataset/core/cv_tensor.h" -#include "minddata/dataset/core/data_type.h" -#include "minddata/dataset/core/device_tensor.h" -#include "minddata/dataset/kernels/image/dvpp/acl_adapter.h" -#include "minddata/dataset/util/path.h" - -namespace mindspore { -namespace dataset { -const VdecOutputFormat DvppDecodeVideoOp::kDefVdecOutputFormat = VdecOutputFormat::kYuvSemiplanar420; -const char DvppDecodeVideoOp::kDefOutput[] = "./output"; - -Status DvppDecodeVideoOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - - try { - CHECK_FAIL_RETURN_UNEXPECTED(input->GetBuffer() != nullptr, "The input video buffer is empty."); - auto *buffer = const_cast(input->GetBuffer()); - auto data_size = input->SizeInBytes(); - // assuem that output equals to input - RETURN_IF_NOT_OK(mindspore::dataset::Tensor::CreateFromTensor(input, output)); - - ResourceInfo resource; - resource.deviceIds.insert(0); - APP_ERROR ret = AclAdapter::GetInstance().InitResource(&resource); - if (ret != APP_ERR_OK) { - AclAdapter::GetInstance().Release(); - std::string error = "DvppDecodeVideo: Error in Init D-chip: " + std::to_string(ret); - RETURN_STATUS_UNEXPECTED(error); - } - int deviceId = *(resource.deviceIds.begin()); - void *context = AclAdapter::GetInstance().GetContext(deviceId); - // initialize the resource of D-chip and set up all configures - - auto dvpp_video = AclAdapter::GetInstance().CreateDvppVideo(context, buffer, data_size, width_, height_, - static_cast(en_type_), - static_cast(format_), output_); - AclLiteError res = AclAdapter::GetInstance().InitDvppVideo(dvpp_video); - if (res != ACLLITE_OK) { - (void)AclAdapter::GetInstance().CloseDvppVideo(dvpp_video); - AclAdapter::GetInstance().Release(); - std::string error = "DvppDecodeVideo: Failed to initialize DvppVideo:" + std::to_string(res); - RETURN_STATUS_UNEXPECTED(error); - } - - res = AclAdapter::GetInstance().DvppVideoDumpFrame(dvpp_video); - if (res != ACLLITE_OK) { - (void)AclAdapter::GetInstance().CloseDvppVideo(dvpp_video); - AclAdapter::GetInstance().Release(); - std::string error = "DvppDecodeVideo: Error in DumpFrame:" + std::to_string(res); - RETURN_STATUS_UNEXPECTED(error); - } - (void)AclAdapter::GetInstance().CloseDvppVideo(dvpp_video); - } catch (const std::exception &e) { - std::string error = "[ERROR] Error in DvppDecodeVideoOp:" + std::string(e.what()); - RETURN_STATUS_UNEXPECTED(error); - } - return Status::OK(); -} - -Status DvppDecodeVideoOp::OutputShape(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); - outputs.clear(); - CHECK_FAIL_RETURN_UNEXPECTED(!inputs.empty(), "DvppDecodeVideo: inputs cannot be empty."); - if (inputs[0].Rank() == 1) { - outputs = inputs; - } - CHECK_FAIL_RETURN_UNEXPECTED(!outputs.empty(), "DvppDecodeVideo: Invalid input shape."); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2022-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/kernels/image/dvpp/ascend310/dvpp_decode_video_op.h" + +#include "include/api/context.h" +#include "minddata/dataset/core/cv_tensor.h" +#include "minddata/dataset/core/data_type.h" +#include "minddata/dataset/core/device_tensor.h" +#include "minddata/dataset/kernels/image/dvpp/acl_adapter.h" +#include "minddata/dataset/util/path.h" + +namespace mindspore { +namespace dataset { +const VdecOutputFormat DvppDecodeVideoOp::kDefVdecOutputFormat = VdecOutputFormat::kYuvSemiplanar420; +const char DvppDecodeVideoOp::kDefOutput[] = "./output"; + +Status DvppDecodeVideoOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + + try { + CHECK_FAIL_RETURN_UNEXPECTED(input->GetBuffer() != nullptr, "The input video buffer is empty."); + auto *buffer = const_cast(input->GetBuffer()); + auto data_size = input->SizeInBytes(); + // assuem that output equals to input + RETURN_IF_NOT_OK(mindspore::dataset::Tensor::CreateFromTensor(input, output)); + + ResourceInfo resource; + resource.deviceIds.insert(0); + APP_ERROR ret = AclAdapter::GetInstance().InitResource(&resource); + if (ret != APP_ERR_OK) { + AclAdapter::GetInstance().Release(); + std::string error = "DvppDecodeVideo: Error in Init D-chip: " + std::to_string(ret); + RETURN_STATUS_UNEXPECTED(error); + } + int deviceId = *(resource.deviceIds.begin()); + void *context = AclAdapter::GetInstance().GetContext(deviceId); + // initialize the resource of D-chip and set up all configures + + auto dvpp_video = AclAdapter::GetInstance().CreateDvppVideo(context, buffer, data_size, width_, height_, + static_cast(en_type_), + static_cast(format_), output_); + AclLiteError res = AclAdapter::GetInstance().InitDvppVideo(dvpp_video); + if (res != ACLLITE_OK) { + (void)AclAdapter::GetInstance().CloseDvppVideo(dvpp_video); + AclAdapter::GetInstance().Release(); + std::string error = "DvppDecodeVideo: Failed to initialize DvppVideo:" + std::to_string(res); + RETURN_STATUS_UNEXPECTED(error); + } + + res = AclAdapter::GetInstance().DvppVideoDumpFrame(dvpp_video); + if (res != ACLLITE_OK) { + (void)AclAdapter::GetInstance().CloseDvppVideo(dvpp_video); + AclAdapter::GetInstance().Release(); + std::string error = "DvppDecodeVideo: Error in DumpFrame:" + std::to_string(res); + RETURN_STATUS_UNEXPECTED(error); + } + (void)AclAdapter::GetInstance().CloseDvppVideo(dvpp_video); + } catch (const std::exception &e) { + std::string error = "[ERROR] Error in DvppDecodeVideoOp:" + std::string(e.what()); + RETURN_STATUS_UNEXPECTED(error); + } + return Status::OK(); +} + +Status DvppDecodeVideoOp::OutputShape(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); + outputs.clear(); + CHECK_FAIL_RETURN_UNEXPECTED(!inputs.empty(), "DvppDecodeVideo: inputs cannot be empty."); + if (inputs[0].Rank() == 1) { + outputs = inputs; + } + CHECK_FAIL_RETURN_UNEXPECTED(!outputs.empty(), "DvppDecodeVideo: Invalid input shape."); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_decode_video_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_decode_video_op.h index 7d9397362a0..a71d94c0aba 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_decode_video_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_decode_video_op.h @@ -1,72 +1,72 @@ -/** - * Copyright 2022-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_DVPP_DECODE_VIDEO_OP_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_DVPP_DECODE_VIDEO_OP_H_ - -#include -#include -#include - -#include "minddata/dataset/core/data_type.h" -#include "minddata/dataset/core/device_resource.h" -#include "minddata/dataset/core/tensor.h" -#include "minddata/dataset/kernels/image/dvpp/acl_adapter.h" -#include "minddata/dataset/kernels/tensor_op.h" -#include "minddata/dataset/util/status.h" -#include "mindspore/core/utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -class DvppDecodeVideoOp : public TensorOp { - public: - // Default values - static const VdecOutputFormat kDefVdecOutputFormat; - static const char kDefOutput[]; - - DvppDecodeVideoOp(uint32_t width, uint32_t height, VdecStreamFormat type, - VdecOutputFormat out_format = kDefVdecOutputFormat, const std::string &output = kDefOutput) - : width_(width), height_(height), format_(out_format), en_type_(type), output_(output) {} - - /// \brief Destructor - ~DvppDecodeVideoOp() = default; - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - Status OutputShape(const std::vector &inputs, std::vector &outputs) override; - - std::string Name() const override { return kDvppDecodeVideoOp; } - - private: - uint32_t width_; - uint32_t height_; - - /* 1:YUV420 semi-planner(nv12) - 2:YVU420 semi-planner(nv21) - */ - VdecOutputFormat format_; - - /* 0:H265 main level - * 1:H264 baseline level - * 2:H264 main level - * 3:H264 high level - */ - VdecStreamFormat en_type_; - std::string output_; -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_DVPP_DECODE_VIDEO_OP_H_ +/** + * Copyright 2022-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_DVPP_DECODE_VIDEO_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_DVPP_DECODE_VIDEO_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/data_type.h" +#include "minddata/dataset/core/device_resource.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/image/dvpp/acl_adapter.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" +#include "mindspore/core/utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +class DvppDecodeVideoOp : public TensorOp { + public: + // Default values + static const VdecOutputFormat kDefVdecOutputFormat; + static const char kDefOutput[]; + + DvppDecodeVideoOp(uint32_t width, uint32_t height, VdecStreamFormat type, + VdecOutputFormat out_format = kDefVdecOutputFormat, const std::string &output = kDefOutput) + : width_(width), height_(height), format_(out_format), en_type_(type), output_(output) {} + + /// \brief Destructor + ~DvppDecodeVideoOp() = default; + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + Status OutputShape(const std::vector &inputs, std::vector &outputs) override; + + std::string Name() const override { return kDvppDecodeVideoOp; } + + private: + uint32_t width_; + uint32_t height_; + + /* 1:YUV420 semi-planner(nv12) + 2:YVU420 semi-planner(nv21) + */ + VdecOutputFormat format_; + + /* 0:H265 main level + * 1:H264 baseline level + * 2:H264 main level + * 3:H264 high level + */ + VdecStreamFormat en_type_; + std::string output_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_DVPP_DECODE_VIDEO_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_resize_jpeg_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_resize_jpeg_op.cc index 519dee809a2..9e777c8a136 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_resize_jpeg_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_resize_jpeg_op.cc @@ -1,159 +1,159 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/kernels/image/dvpp/ascend310/dvpp_resize_jpeg_op.h" - -#include -#include - -#include "minddata/dataset/core/cv_tensor.h" -#include "minddata/dataset/core/data_type.h" -#include "minddata/dataset/core/device_tensor.h" -#include "minddata/dataset/kernels/image/dvpp/utils/CommonDataType.h" -#include "minddata/dataset/kernels/image/image_utils.h" - -namespace mindspore { -namespace dataset { -Status DvppResizeJpegOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - try { - CHECK_FAIL_RETURN_UNEXPECTED(input->GetDeviceBuffer() != nullptr, "The input image buffer is empty."); - std::string last_step = "Decode"; - DvppDataInfo *imageinfo(AclAdapter::GetInstance().GetDecodeDeviceData(processor_.get())); - if (!imageinfo->data) { - last_step = "Crop"; - } - APP_ERROR ret = AclAdapter::GetInstance().JPEG_R(processor_.get(), last_step); - if (ret != APP_ERR_OK) { - ret = AclAdapter::GetInstance().ReleaseAclProcess(processor_.get()); - CHECK_FAIL_RETURN_UNEXPECTED(ret == APP_ERR_OK, "Release memory failed."); - std::string error = "Error in dvpp processing: " + std::to_string(ret); - RETURN_STATUS_UNEXPECTED(error); - } - DvppDataInfo *ResizeOut(AclAdapter::GetInstance().GetResizedDeviceData(processor_.get())); - const TensorShape dvpp_shape({1, 1, 1}); - const DataType dvpp_data_type(DataType::DE_UINT8); - RETURN_IF_NOT_OK(mindspore::dataset::DeviceTensor::CreateEmpty(dvpp_shape, dvpp_data_type, output)); - RETURN_IF_NOT_OK((*output)->SetAttributes(ResizeOut->data, ResizeOut->dataSize, ResizeOut->width, - ResizeOut->widthStride, ResizeOut->height, ResizeOut->heightStride)); - if (!((*output)->HasDeviceData())) { - std::string error = "[ERROR] Fail to get the Output result from device memory!"; - RETURN_STATUS_UNEXPECTED(error); - } - } catch (const std::exception &e) { - std::string error = "[ERROR] Fail in DvppResizeJpegOp: " + std::string(e.what()); - RETURN_STATUS_UNEXPECTED(error); - } - return Status::OK(); -} - -Status DvppResizeJpegOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - if (!IsNonEmptyJPEG(input)) { - RETURN_STATUS_UNEXPECTED("DvppReiszeJpegOp only support process jpeg image."); - } - try { - CHECK_FAIL_RETURN_UNEXPECTED(input->GetBuffer() != nullptr, "The input image buffer is empty."); - auto *buffer = const_cast(input->GetBuffer()); - DvppDataInfo imageinfo; - imageinfo.dataSize = input->SizeInBytes(); - imageinfo.data = static_cast(buffer); - std::vector yuv_shape_ = input->GetYuvShape(); - imageinfo.width = yuv_shape_[0]; - imageinfo.widthStride = yuv_shape_[1]; - imageinfo.height = yuv_shape_[2]; - imageinfo.heightStride = yuv_shape_[3]; - imageinfo.format = 1; // 1 means PIXEL_FORMAT_YUV_SEMIPLANAR_420 - ResourceInfo resource; - resource.deviceIds.insert(0); - APP_ERROR ret = AclAdapter::GetInstance().InitResource(&resource); - if (ret != APP_ERR_OK) { - AclAdapter::GetInstance().Release(); - std::string error = "Error in Init D-chip: " + std::to_string(ret); - RETURN_STATUS_UNEXPECTED(error); - } - int deviceId = *(resource.deviceIds.begin()); - void *context = AclAdapter::GetInstance().GetContext(deviceId); - // Second part end where we initialize the resource of D-chip and set up all configures - std::shared_ptr process(AclAdapter::GetInstance().CreateAclProcessWithPara(resized_width_, resized_height_, - context, false, nullptr, nullptr), - [](void *ptr) { AclAdapter::GetInstance().DestroyAclProcess(ptr); }); - - ret = AclAdapter::GetInstance().InitAclProcess(process.get()); - if (ret != APP_ERR_OK) { - AclAdapter::GetInstance().Release(); - std::string error = "Error in Init resource: " + std::to_string(ret); - RETURN_STATUS_UNEXPECTED(error); - } - - ret = AclAdapter::GetInstance().JPEG_R_WITH_DATA(process.get(), imageinfo); - if (ret != APP_ERR_OK) { - AclAdapter::GetInstance().Release(); - std::string error = "Error in dvpp processing: " + std::to_string(ret); - RETURN_STATUS_UNEXPECTED(error); - } - - // Third part end where we execute the core function of dvpp - auto *ret_ptr = static_cast(AclAdapter::GetInstance().GetMemoryData(process.get())); - DvppDataInfo *ResizeOut = AclAdapter::GetInstance().GetResizedDeviceData(process.get()); - dsize_t dvpp_length = ResizeOut->dataSize; - const TensorShape dvpp_shape({dvpp_length, 1, 1}); - uint32_t resized_height = ResizeOut->height; - uint32_t resized_heightStride = ResizeOut->heightStride; - uint32_t resized_width = ResizeOut->width; - uint32_t resized_widthStride = ResizeOut->widthStride; - const DataType dvpp_data_type(DataType::DE_UINT8); - RETURN_IF_NOT_OK(mindspore::dataset::Tensor::CreateFromMemory(dvpp_shape, dvpp_data_type, ret_ptr, output)); - RETURN_IF_NOT_OK((*output)->SetYuvShape(resized_width, resized_widthStride, resized_height, resized_heightStride)); - if (!((*output)->HasData())) { - std::string error = "[ERROR] Fail to get the Output result from memory!"; - RETURN_STATUS_UNEXPECTED(error); - } - ret = AclAdapter::GetInstance().DeviceMemoryRelease(process.get()); - CHECK_FAIL_RETURN_UNEXPECTED(ret == APP_ERR_OK, "Release device memory failed."); - ret = AclAdapter::GetInstance().ReleaseAclProcess(process.get()); - CHECK_FAIL_RETURN_UNEXPECTED(ret == APP_ERR_OK, "Release host memory failed."); - // Last part end where we transform the processed data into a tensor which can be applied in later units. - } catch (const std::exception &e) { - std::string error = "[ERROR] Fail in DvppResizeJpegOp: " + std::string(e.what()); - RETURN_STATUS_UNEXPECTED(error); - } - return Status::OK(); -} - -Status DvppResizeJpegOp::SetAscendResource(const std::shared_ptr &resource) { - processor_ = resource->GetInstance(); - if (!processor_) { - RETURN_STATUS_UNEXPECTED("Resource initialize fail, please check your env."); - } - APP_ERROR ret = AclAdapter::GetInstance().SetResizeParas(processor_.get(), resized_width_, resized_height_); - CHECK_FAIL_RETURN_UNEXPECTED(ret == APP_ERR_OK, "SetResizeParas failed."); - return Status::OK(); -} - -Status DvppResizeJpegOp::OutputShape(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); - outputs.clear(); - TensorShape out({-1, 1, 1}); // we don't know what is output image size, but we know it should be 1 channels - CHECK_FAIL_RETURN_UNEXPECTED(!inputs.empty(), "DvppResizeJpeg: inputs cannot be empty."); - if (inputs[0].Rank() == 1) { - (void)outputs.emplace_back(out); - } - CHECK_FAIL_RETURN_UNEXPECTED(!outputs.empty(), "DvppResizeJpeg: Invalid input shape."); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/kernels/image/dvpp/ascend310/dvpp_resize_jpeg_op.h" + +#include +#include + +#include "minddata/dataset/core/cv_tensor.h" +#include "minddata/dataset/core/data_type.h" +#include "minddata/dataset/core/device_tensor.h" +#include "minddata/dataset/kernels/image/dvpp/utils/CommonDataType.h" +#include "minddata/dataset/kernels/image/image_utils.h" + +namespace mindspore { +namespace dataset { +Status DvppResizeJpegOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + try { + CHECK_FAIL_RETURN_UNEXPECTED(input->GetDeviceBuffer() != nullptr, "The input image buffer is empty."); + std::string last_step = "Decode"; + DvppDataInfo *imageinfo(AclAdapter::GetInstance().GetDecodeDeviceData(processor_.get())); + if (!imageinfo->data) { + last_step = "Crop"; + } + APP_ERROR ret = AclAdapter::GetInstance().JPEG_R(processor_.get(), last_step); + if (ret != APP_ERR_OK) { + ret = AclAdapter::GetInstance().ReleaseAclProcess(processor_.get()); + CHECK_FAIL_RETURN_UNEXPECTED(ret == APP_ERR_OK, "Release memory failed."); + std::string error = "Error in dvpp processing: " + std::to_string(ret); + RETURN_STATUS_UNEXPECTED(error); + } + DvppDataInfo *ResizeOut(AclAdapter::GetInstance().GetResizedDeviceData(processor_.get())); + const TensorShape dvpp_shape({1, 1, 1}); + const DataType dvpp_data_type(DataType::DE_UINT8); + RETURN_IF_NOT_OK(mindspore::dataset::DeviceTensor::CreateEmpty(dvpp_shape, dvpp_data_type, output)); + RETURN_IF_NOT_OK((*output)->SetAttributes(ResizeOut->data, ResizeOut->dataSize, ResizeOut->width, + ResizeOut->widthStride, ResizeOut->height, ResizeOut->heightStride)); + if (!((*output)->HasDeviceData())) { + std::string error = "[ERROR] Fail to get the Output result from device memory!"; + RETURN_STATUS_UNEXPECTED(error); + } + } catch (const std::exception &e) { + std::string error = "[ERROR] Fail in DvppResizeJpegOp: " + std::string(e.what()); + RETURN_STATUS_UNEXPECTED(error); + } + return Status::OK(); +} + +Status DvppResizeJpegOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + if (!IsNonEmptyJPEG(input)) { + RETURN_STATUS_UNEXPECTED("DvppReiszeJpegOp only support process jpeg image."); + } + try { + CHECK_FAIL_RETURN_UNEXPECTED(input->GetBuffer() != nullptr, "The input image buffer is empty."); + auto *buffer = const_cast(input->GetBuffer()); + DvppDataInfo imageinfo; + imageinfo.dataSize = input->SizeInBytes(); + imageinfo.data = static_cast(buffer); + std::vector yuv_shape_ = input->GetYuvShape(); + imageinfo.width = yuv_shape_[0]; + imageinfo.widthStride = yuv_shape_[1]; + imageinfo.height = yuv_shape_[2]; + imageinfo.heightStride = yuv_shape_[3]; + imageinfo.format = 1; // 1 means PIXEL_FORMAT_YUV_SEMIPLANAR_420 + ResourceInfo resource; + resource.deviceIds.insert(0); + APP_ERROR ret = AclAdapter::GetInstance().InitResource(&resource); + if (ret != APP_ERR_OK) { + AclAdapter::GetInstance().Release(); + std::string error = "Error in Init D-chip: " + std::to_string(ret); + RETURN_STATUS_UNEXPECTED(error); + } + int deviceId = *(resource.deviceIds.begin()); + void *context = AclAdapter::GetInstance().GetContext(deviceId); + // Second part end where we initialize the resource of D-chip and set up all configures + std::shared_ptr process(AclAdapter::GetInstance().CreateAclProcessWithPara(resized_width_, resized_height_, + context, false, nullptr, nullptr), + [](void *ptr) { AclAdapter::GetInstance().DestroyAclProcess(ptr); }); + + ret = AclAdapter::GetInstance().InitAclProcess(process.get()); + if (ret != APP_ERR_OK) { + AclAdapter::GetInstance().Release(); + std::string error = "Error in Init resource: " + std::to_string(ret); + RETURN_STATUS_UNEXPECTED(error); + } + + ret = AclAdapter::GetInstance().JPEG_R_WITH_DATA(process.get(), imageinfo); + if (ret != APP_ERR_OK) { + AclAdapter::GetInstance().Release(); + std::string error = "Error in dvpp processing: " + std::to_string(ret); + RETURN_STATUS_UNEXPECTED(error); + } + + // Third part end where we execute the core function of dvpp + auto *ret_ptr = static_cast(AclAdapter::GetInstance().GetMemoryData(process.get())); + DvppDataInfo *ResizeOut = AclAdapter::GetInstance().GetResizedDeviceData(process.get()); + dsize_t dvpp_length = ResizeOut->dataSize; + const TensorShape dvpp_shape({dvpp_length, 1, 1}); + uint32_t resized_height = ResizeOut->height; + uint32_t resized_heightStride = ResizeOut->heightStride; + uint32_t resized_width = ResizeOut->width; + uint32_t resized_widthStride = ResizeOut->widthStride; + const DataType dvpp_data_type(DataType::DE_UINT8); + RETURN_IF_NOT_OK(mindspore::dataset::Tensor::CreateFromMemory(dvpp_shape, dvpp_data_type, ret_ptr, output)); + RETURN_IF_NOT_OK((*output)->SetYuvShape(resized_width, resized_widthStride, resized_height, resized_heightStride)); + if (!((*output)->HasData())) { + std::string error = "[ERROR] Fail to get the Output result from memory!"; + RETURN_STATUS_UNEXPECTED(error); + } + ret = AclAdapter::GetInstance().DeviceMemoryRelease(process.get()); + CHECK_FAIL_RETURN_UNEXPECTED(ret == APP_ERR_OK, "Release device memory failed."); + ret = AclAdapter::GetInstance().ReleaseAclProcess(process.get()); + CHECK_FAIL_RETURN_UNEXPECTED(ret == APP_ERR_OK, "Release host memory failed."); + // Last part end where we transform the processed data into a tensor which can be applied in later units. + } catch (const std::exception &e) { + std::string error = "[ERROR] Fail in DvppResizeJpegOp: " + std::string(e.what()); + RETURN_STATUS_UNEXPECTED(error); + } + return Status::OK(); +} + +Status DvppResizeJpegOp::SetAscendResource(const std::shared_ptr &resource) { + processor_ = resource->GetInstance(); + if (!processor_) { + RETURN_STATUS_UNEXPECTED("Resource initialize fail, please check your env."); + } + APP_ERROR ret = AclAdapter::GetInstance().SetResizeParas(processor_.get(), resized_width_, resized_height_); + CHECK_FAIL_RETURN_UNEXPECTED(ret == APP_ERR_OK, "SetResizeParas failed."); + return Status::OK(); +} + +Status DvppResizeJpegOp::OutputShape(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); + outputs.clear(); + TensorShape out({-1, 1, 1}); // we don't know what is output image size, but we know it should be 1 channels + CHECK_FAIL_RETURN_UNEXPECTED(!inputs.empty(), "DvppResizeJpeg: inputs cannot be empty."); + if (inputs[0].Rank() == 1) { + (void)outputs.emplace_back(out); + } + CHECK_FAIL_RETURN_UNEXPECTED(!outputs.empty(), "DvppResizeJpeg: Invalid input shape."); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_resize_jpeg_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_resize_jpeg_op.h index 3328058de6a..3830e5c2ed8 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_resize_jpeg_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/ascend310/dvpp_resize_jpeg_op.h @@ -1,61 +1,61 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_DVPP_RESIZE_JPEG_OP_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_DVPP_RESIZE_JPEG_OP_H_ - -#include -#include -#include - -#include "minddata/dataset/core/data_type.h" -#include "minddata/dataset/core/device_resource.h" -#include "minddata/dataset/core/device_tensor.h" -#include "minddata/dataset/core/tensor.h" -#include "minddata/dataset/kernels/image/dvpp/acl_adapter.h" -#include "minddata/dataset/kernels/image/dvpp/utils/ErrorCode.h" -#include "minddata/dataset/kernels/tensor_op.h" -#include "minddata/dataset/util/log_adapter.h" -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class DvppResizeJpegOp : public TensorOp { - public: - DvppResizeJpegOp(int32_t resized_height, int32_t resized_width) - : resized_height_(resized_height), resized_width_(resized_width) {} - - /// \brief Destructor - ~DvppResizeJpegOp() = default; - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - Status OutputShape(const std::vector &inputs, std::vector &outputs) override; - - std::string Name() const override { return kDvppDecodeResizeJpegOp; } - - Status SetAscendResource(const std::shared_ptr &resource) override; - - private: - int32_t resized_height_; - int32_t resized_width_; - std::shared_ptr processor_; -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_DVPP_RESIZE_JPEG_OP_H_ +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_DVPP_RESIZE_JPEG_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_DVPP_RESIZE_JPEG_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/data_type.h" +#include "minddata/dataset/core/device_resource.h" +#include "minddata/dataset/core/device_tensor.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/image/dvpp/acl_adapter.h" +#include "minddata/dataset/kernels/image/dvpp/utils/ErrorCode.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/log_adapter.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class DvppResizeJpegOp : public TensorOp { + public: + DvppResizeJpegOp(int32_t resized_height, int32_t resized_width) + : resized_height_(resized_height), resized_width_(resized_width) {} + + /// \brief Destructor + ~DvppResizeJpegOp() = default; + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + Status OutputShape(const std::vector &inputs, std::vector &outputs) override; + + std::string Name() const override { return kDvppDecodeResizeJpegOp; } + + Status SetAscendResource(const std::shared_ptr &resource) override; + + private: + int32_t resized_height_; + int32_t resized_width_; + std::shared_ptr processor_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_DVPP_RESIZE_JPEG_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/AclLiteError.h b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/AclLiteError.h index 40c8d46cd0a..f38e39ecc8a 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/AclLiteError.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/AclLiteError.h @@ -1,191 +1,191 @@ -/** -* Copyright 2022-2023 Huawei Technologies Co., Ltd -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at - -* http://www.apache.org/licenses/LICENSE-2.0 - -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_ACL_LITE_ERROR_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_ACL_LITE_ERROR_H_ - -using AclLiteError = int; - -constexpr int ACLLITE_OK = 0; -constexpr int ACLLITE_ERROR = 1; -constexpr int ACLLITE_ERROR_INVALID_ARGS = 2; -constexpr int ACLLITE_ERROR_SET_ACL_CONTEXT = 3; -constexpr int ACLLITE_ERROR_GET_ACL_CONTEXT = 4; -constexpr int ACLLITE_ERROR_CREATE_ACL_CONTEXT = 5; -constexpr int ACLLITE_ERROR_CREATE_THREAD = 6; -constexpr int ACLLITE_ERROR_CREATE_STREAM = 7; -constexpr int ACLLITE_ERROR_GET_RUM_MODE = 8; -constexpr int ACLLITE_ERROR_APP_INIT = 9; -constexpr int ACLLITE_ERROR_DEST_INVALID = 10; -constexpr int ACLLITE_ERROR_INITED_ALREADY = 11; -constexpr int ACLLITE_ERROR_ENQUEUE = 12; -constexpr int ACLLITE_ERROR_WRITE_FILE = 13; -constexpr int ACLLITE_ERROR_THREAD_ABNORMAL = 14; -constexpr int ACLLITE_ERROR_START_THREAD = 15; -constexpr int ACLLITE_ERROR_ADD_THREAD = 16; - -// malloc or new memory failed -constexpr int ACLLITE_ERROR_MALLOC = 101; -// aclrtMalloc failed -constexpr int ACLLITE_ERROR_MALLOC_DEVICE = 102; - -constexpr int ACLLITE_ERROR_MALLOC_DVPP = 103; -// access file failed -constexpr int ACLLITE_ERROR_ACCESS_FILE = 201; -// the file is invalid -constexpr int ACLLITE_ERROR_INVALID_FILE = 202; -// open file failed -constexpr int ACLLITE_ERROR_OPEN_FILE = 203; - -// load model repeated -constexpr int ACLLITE_ERROR_LOAD_MODEL_REPEATED = 301; - -constexpr int ACLLITE_ERROR_NO_MODEL_DESC = 302; -// load mode by acl failed -constexpr int ACLLITE_ERROR_LOAD_MODEL = 303; - -constexpr int ACLLITE_ERROR_CREATE_MODEL_DESC = 304; - -constexpr int ACLLITE_ERROR_GET_MODEL_DESC = 305; - -constexpr int ACLLITE_ERROR_CREATE_DATASET = 306; - -constexpr int ACLLITE_ERROR_CREATE_DATA_BUFFER = 307; - -constexpr int ACLLITE_ERROR_ADD_DATASET_BUFFER = 308; - -constexpr int ACLLITE_ERROR_EXECUTE_MODEL = 309; - -constexpr int ACLLITE_ERROR_GET_DATASET_BUFFER = 310; - -constexpr int ACLLITE_ERROR_GET_DATA_BUFFER_ADDR = 311; - -constexpr int ACLLITE_ERROR_GET_DATA_BUFFER_SIZE = 312; - -constexpr int ACLLITE_ERROR_COPY_DATA = 313; - -constexpr int ACLLITE_ERROR_SET_CAMERA = 400; - -constexpr int ACLLITE_ERROR_CAMERA_NO_ACCESSABLE = 401; - -constexpr int ACLLITE_ERROR_OPEN_CAMERA = 402; - -constexpr int ACLLITE_ERROR_READ_CAMERA_FRAME = 403; - -constexpr int ACLLITE_ERROR_UNSURPPORT_PROPERTY = 404; - -constexpr int ACLLITE_ERROR_INVALID_PROPERTY_VALUE = 405; - -constexpr int ACLLITE_ERROR_UNSURPPORT_VIDEO_CAPTURE = 406; - -constexpr int ACLLITE_ERROR_CREATE_DVPP_CHANNEL_DESC = 501; - -constexpr int ACLLITE_ERRROR_CREATE_DVPP_CHANNEL = 502; - -constexpr int ACLLITE_ERROR_CREATE_PIC_DESC = 503; - -constexpr int ACLLITE_ERROR_CREATE_RESIZE_CONFIG = 504; - -constexpr int ACLLITE_ERROR_RESIZE_ASYNC = 505; - -constexpr int ACLLITE_ERROR_SYNC_STREAM = 506; - -constexpr int ACLLITE_ERROR_JPEGE_ASYNC = 507; - -constexpr int ACLLITE_ERROR_JPEGD_ASYNC = 508; - -constexpr int ACLLITE_ERROR_FFMPEG_DECODER_INIT = 601; - -constexpr int ACLLITE_ERROR_OPEN_VIDEO_UNREADY = 602; - -constexpr int ACLLITE_ERROR_TOO_MANY_VIDEO_DECODERS = 603; - -constexpr int ACLLITE_ERROR_SET_VDEC_CHANNEL_ID = 604; - -constexpr int ACLLITE_ERROR_SET_STREAM_DESC_DATA = 605; - -constexpr int ACLLITE_ERROR_SET_VDEC_CHANNEL_THREAD_ID = 606; - -constexpr int ACLLITE_ERROR_SET_VDEC_CALLBACK = 607; - -constexpr int ACLLITE_ERROR_SET_VDEC_ENTYPE = 608; - -constexpr int ACLLITE_ERROR_SET_VDEC_PIC_FORMAT = 609; - -constexpr int ACLLITE_ERROR_CREATE_VDEC_CHANNEL = 610; - -constexpr int ACLLITE_ERROR_CREATE_STREAM_DESC = 611; - -constexpr int ACLLITE_ERROR_SET_STREAM_DESC_EOS = 612; - -constexpr int ACLLITE_ERROR_SET_STREAM_DESC_SIZE = 613; - -constexpr int ACLLITE_ERROR_SET_PIC_DESC_DATA = 614; - -constexpr int ACLLITE_ERROR_SET_PIC_DESC_SIZE = 615; - -constexpr int ACLLITE_ERROR_SET_PIC_DESC_FORMAT = 616; - -constexpr int ACLLITE_ERROR_VDEC_IS_EXITTING = 617; - -constexpr int ACLLITE_ERROR_VDEC_SET_WIDTH = 618; - -constexpr int ACLLITE_ERROR_VDEC_WIDTH_INVALID = 619; - -constexpr int ACLLITE_ERROR_VDEC_HEIGHT_INVALID = 620; - -constexpr int ACLLITE_ERROR_VDEC_SET_HEIGHT = 621; - -constexpr int ACLLITE_ERROR_VDEC_ENTYPE_INVALID = 622; - -constexpr int ACLLITE_ERROR_VDEC_FORMAT_INVALID = 623; - -constexpr int ACLLITE_ERROR_VDEC_INVALID_PARAM = 624; - -constexpr int ACLLITE_ERROR_VDEC_SEND_FRAME = 625; - -constexpr int ACLLITE_ERROR_VDEC_QUEUE_FULL = 626; - -constexpr int ACLLITE_ERROR_SET_RTSP_TRANS = 627; - -constexpr int ACLLITE_ERROR_READ_EMPTY = 628; - -constexpr int ACLLITE_ERROR_VIDEO_DECODER_STATUS = 629; - -constexpr int ACLLITE_ERROR_DECODE_FINISH = 630; - -constexpr int ACLLITE_ERROR_H26X_FRAME = 631; - -constexpr int ACLLITE_ERROR_VENC_STATUS = 701; - -constexpr int ACLLITE_ERROR_VENC_QUEUE_FULL = 702; - -constexpr int ACLLITE_ERROR_CREATE_VENC_CHAN_DESC = 703; - -constexpr int ACLLITE_ERROR_SET_VENC_CHAN_TID = 704; - -constexpr int ACLLITE_ERROR_VENC_SET_EOS = 705; - -constexpr int ACLLITE_ERROR_VENC_SET_IF_FRAME = 706; - -constexpr int ACLLITE_ERROR_CREATE_VENC_CHAN = 707; - -constexpr int ACLLITE_ERROR_VENC_CREATE_FRAME_CONFIG = 708; - -constexpr int ACLLITE_ERROR_VENC_SEND_FRAME = 709; - -constexpr int ACLLITE_ERROR_SUBSCRIBE_REPORT = 710; - -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_ACL_LITE_ERROR_H_ +/** +* Copyright 2022-2023 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at + +* http://www.apache.org/licenses/LICENSE-2.0 + +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_ACL_LITE_ERROR_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_ACL_LITE_ERROR_H_ + +using AclLiteError = int; + +constexpr int ACLLITE_OK = 0; +constexpr int ACLLITE_ERROR = 1; +constexpr int ACLLITE_ERROR_INVALID_ARGS = 2; +constexpr int ACLLITE_ERROR_SET_ACL_CONTEXT = 3; +constexpr int ACLLITE_ERROR_GET_ACL_CONTEXT = 4; +constexpr int ACLLITE_ERROR_CREATE_ACL_CONTEXT = 5; +constexpr int ACLLITE_ERROR_CREATE_THREAD = 6; +constexpr int ACLLITE_ERROR_CREATE_STREAM = 7; +constexpr int ACLLITE_ERROR_GET_RUM_MODE = 8; +constexpr int ACLLITE_ERROR_APP_INIT = 9; +constexpr int ACLLITE_ERROR_DEST_INVALID = 10; +constexpr int ACLLITE_ERROR_INITED_ALREADY = 11; +constexpr int ACLLITE_ERROR_ENQUEUE = 12; +constexpr int ACLLITE_ERROR_WRITE_FILE = 13; +constexpr int ACLLITE_ERROR_THREAD_ABNORMAL = 14; +constexpr int ACLLITE_ERROR_START_THREAD = 15; +constexpr int ACLLITE_ERROR_ADD_THREAD = 16; + +// malloc or new memory failed +constexpr int ACLLITE_ERROR_MALLOC = 101; +// aclrtMalloc failed +constexpr int ACLLITE_ERROR_MALLOC_DEVICE = 102; + +constexpr int ACLLITE_ERROR_MALLOC_DVPP = 103; +// access file failed +constexpr int ACLLITE_ERROR_ACCESS_FILE = 201; +// the file is invalid +constexpr int ACLLITE_ERROR_INVALID_FILE = 202; +// open file failed +constexpr int ACLLITE_ERROR_OPEN_FILE = 203; + +// load model repeated +constexpr int ACLLITE_ERROR_LOAD_MODEL_REPEATED = 301; + +constexpr int ACLLITE_ERROR_NO_MODEL_DESC = 302; +// load mode by acl failed +constexpr int ACLLITE_ERROR_LOAD_MODEL = 303; + +constexpr int ACLLITE_ERROR_CREATE_MODEL_DESC = 304; + +constexpr int ACLLITE_ERROR_GET_MODEL_DESC = 305; + +constexpr int ACLLITE_ERROR_CREATE_DATASET = 306; + +constexpr int ACLLITE_ERROR_CREATE_DATA_BUFFER = 307; + +constexpr int ACLLITE_ERROR_ADD_DATASET_BUFFER = 308; + +constexpr int ACLLITE_ERROR_EXECUTE_MODEL = 309; + +constexpr int ACLLITE_ERROR_GET_DATASET_BUFFER = 310; + +constexpr int ACLLITE_ERROR_GET_DATA_BUFFER_ADDR = 311; + +constexpr int ACLLITE_ERROR_GET_DATA_BUFFER_SIZE = 312; + +constexpr int ACLLITE_ERROR_COPY_DATA = 313; + +constexpr int ACLLITE_ERROR_SET_CAMERA = 400; + +constexpr int ACLLITE_ERROR_CAMERA_NO_ACCESSABLE = 401; + +constexpr int ACLLITE_ERROR_OPEN_CAMERA = 402; + +constexpr int ACLLITE_ERROR_READ_CAMERA_FRAME = 403; + +constexpr int ACLLITE_ERROR_UNSURPPORT_PROPERTY = 404; + +constexpr int ACLLITE_ERROR_INVALID_PROPERTY_VALUE = 405; + +constexpr int ACLLITE_ERROR_UNSURPPORT_VIDEO_CAPTURE = 406; + +constexpr int ACLLITE_ERROR_CREATE_DVPP_CHANNEL_DESC = 501; + +constexpr int ACLLITE_ERRROR_CREATE_DVPP_CHANNEL = 502; + +constexpr int ACLLITE_ERROR_CREATE_PIC_DESC = 503; + +constexpr int ACLLITE_ERROR_CREATE_RESIZE_CONFIG = 504; + +constexpr int ACLLITE_ERROR_RESIZE_ASYNC = 505; + +constexpr int ACLLITE_ERROR_SYNC_STREAM = 506; + +constexpr int ACLLITE_ERROR_JPEGE_ASYNC = 507; + +constexpr int ACLLITE_ERROR_JPEGD_ASYNC = 508; + +constexpr int ACLLITE_ERROR_FFMPEG_DECODER_INIT = 601; + +constexpr int ACLLITE_ERROR_OPEN_VIDEO_UNREADY = 602; + +constexpr int ACLLITE_ERROR_TOO_MANY_VIDEO_DECODERS = 603; + +constexpr int ACLLITE_ERROR_SET_VDEC_CHANNEL_ID = 604; + +constexpr int ACLLITE_ERROR_SET_STREAM_DESC_DATA = 605; + +constexpr int ACLLITE_ERROR_SET_VDEC_CHANNEL_THREAD_ID = 606; + +constexpr int ACLLITE_ERROR_SET_VDEC_CALLBACK = 607; + +constexpr int ACLLITE_ERROR_SET_VDEC_ENTYPE = 608; + +constexpr int ACLLITE_ERROR_SET_VDEC_PIC_FORMAT = 609; + +constexpr int ACLLITE_ERROR_CREATE_VDEC_CHANNEL = 610; + +constexpr int ACLLITE_ERROR_CREATE_STREAM_DESC = 611; + +constexpr int ACLLITE_ERROR_SET_STREAM_DESC_EOS = 612; + +constexpr int ACLLITE_ERROR_SET_STREAM_DESC_SIZE = 613; + +constexpr int ACLLITE_ERROR_SET_PIC_DESC_DATA = 614; + +constexpr int ACLLITE_ERROR_SET_PIC_DESC_SIZE = 615; + +constexpr int ACLLITE_ERROR_SET_PIC_DESC_FORMAT = 616; + +constexpr int ACLLITE_ERROR_VDEC_IS_EXITTING = 617; + +constexpr int ACLLITE_ERROR_VDEC_SET_WIDTH = 618; + +constexpr int ACLLITE_ERROR_VDEC_WIDTH_INVALID = 619; + +constexpr int ACLLITE_ERROR_VDEC_HEIGHT_INVALID = 620; + +constexpr int ACLLITE_ERROR_VDEC_SET_HEIGHT = 621; + +constexpr int ACLLITE_ERROR_VDEC_ENTYPE_INVALID = 622; + +constexpr int ACLLITE_ERROR_VDEC_FORMAT_INVALID = 623; + +constexpr int ACLLITE_ERROR_VDEC_INVALID_PARAM = 624; + +constexpr int ACLLITE_ERROR_VDEC_SEND_FRAME = 625; + +constexpr int ACLLITE_ERROR_VDEC_QUEUE_FULL = 626; + +constexpr int ACLLITE_ERROR_SET_RTSP_TRANS = 627; + +constexpr int ACLLITE_ERROR_READ_EMPTY = 628; + +constexpr int ACLLITE_ERROR_VIDEO_DECODER_STATUS = 629; + +constexpr int ACLLITE_ERROR_DECODE_FINISH = 630; + +constexpr int ACLLITE_ERROR_H26X_FRAME = 631; + +constexpr int ACLLITE_ERROR_VENC_STATUS = 701; + +constexpr int ACLLITE_ERROR_VENC_QUEUE_FULL = 702; + +constexpr int ACLLITE_ERROR_CREATE_VENC_CHAN_DESC = 703; + +constexpr int ACLLITE_ERROR_SET_VENC_CHAN_TID = 704; + +constexpr int ACLLITE_ERROR_VENC_SET_EOS = 705; + +constexpr int ACLLITE_ERROR_VENC_SET_IF_FRAME = 706; + +constexpr int ACLLITE_ERROR_CREATE_VENC_CHAN = 707; + +constexpr int ACLLITE_ERROR_VENC_CREATE_FRAME_CONFIG = 708; + +constexpr int ACLLITE_ERROR_VENC_SEND_FRAME = 709; + +constexpr int ACLLITE_ERROR_SUBSCRIBE_REPORT = 710; + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_ACL_LITE_ERROR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/AclLiteType.h b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/AclLiteType.h index 7489096a1b4..f33d8cfe52c 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/AclLiteType.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/AclLiteType.h @@ -1,100 +1,100 @@ -/** -* Copyright 2022-2023 Huawei Technologies Co., Ltd -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at - -* http://www.apache.org/licenses/LICENSE-2.0 - -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_ACL_LITE_TYPE_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_ACL_LITE_TYPE_H_ - -#include - -#include -#include - -#include "acl/acl.h" -#include "acl/ops/acl_dvpp.h" - -enum class MemoryType { MEMORY_NORMAL = 0, MEMORY_HOST, MEMORY_DEVICE, MEMORY_DVPP, MEMORY_INVALID_TYPE }; - -enum class CopyDirection { TO_DEVICE = 0, TO_HOST, INVALID_COPY_DIRECT }; - -enum class CameraId { - CAMERA_ID_0 = 0, - CAMERA_ID_1, - CAMERA_ID_INVALID, -}; - -enum VencStatus { STATUS_VENC_INIT = 0, STATUS_VENC_WORK, STATUS_VENC_FINISH, STATUS_VENC_EXIT, STATUS_VENC_ERROR }; - -struct VencConfig { - uint32_t maxWidth = 0; - uint32_t maxHeight = 0; - std::string outFile; - acldvppPixelFormat format = PIXEL_FORMAT_YUV_SEMIPLANAR_420; - acldvppStreamFormat enType = H264_MAIN_LEVEL; - aclrtContext context = nullptr; - aclrtRunMode runMode = ACL_HOST; -}; - -struct ImageData { - acldvppPixelFormat format; - uint32_t width = 0; - uint32_t height = 0; - uint32_t alignWidth = 0; - uint32_t alignHeight = 0; - uint32_t size = 0; - std::shared_ptr data = nullptr; -}; - -struct FrameData { - bool isFinished = false; - uint32_t frameId = 0; - uint32_t size = 0; - void *data = nullptr; -}; - -struct Resolution { - uint32_t width = 0; - uint32_t height = 0; -}; - -struct Rect { - uint32_t ltX = 0; - uint32_t ltY = 0; - uint32_t rbX = 0; - uint32_t rbY = 0; -}; - -struct BBox { - Rect rect; - uint32_t score = 0; - std::string text; -}; - -struct AclLiteMessage { - int dest; - int msgId; - std::shared_ptr data = nullptr; -}; - -struct DataInfo { - void *data; - uint32_t size; -}; - -struct InferenceOutput { - std::shared_ptr data = nullptr; - uint32_t size; -}; - -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_ACL_LITE_TYPE_H_ +/** +* Copyright 2022-2023 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at + +* http://www.apache.org/licenses/LICENSE-2.0 + +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_ACL_LITE_TYPE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_ACL_LITE_TYPE_H_ + +#include + +#include +#include + +#include "acl/acl.h" +#include "acl/ops/acl_dvpp.h" + +enum class MemoryType { MEMORY_NORMAL = 0, MEMORY_HOST, MEMORY_DEVICE, MEMORY_DVPP, MEMORY_INVALID_TYPE }; + +enum class CopyDirection { TO_DEVICE = 0, TO_HOST, INVALID_COPY_DIRECT }; + +enum class CameraId { + CAMERA_ID_0 = 0, + CAMERA_ID_1, + CAMERA_ID_INVALID, +}; + +enum VencStatus { STATUS_VENC_INIT = 0, STATUS_VENC_WORK, STATUS_VENC_FINISH, STATUS_VENC_EXIT, STATUS_VENC_ERROR }; + +struct VencConfig { + uint32_t maxWidth = 0; + uint32_t maxHeight = 0; + std::string outFile; + acldvppPixelFormat format = PIXEL_FORMAT_YUV_SEMIPLANAR_420; + acldvppStreamFormat enType = H264_MAIN_LEVEL; + aclrtContext context = nullptr; + aclrtRunMode runMode = ACL_HOST; +}; + +struct ImageData { + acldvppPixelFormat format; + uint32_t width = 0; + uint32_t height = 0; + uint32_t alignWidth = 0; + uint32_t alignHeight = 0; + uint32_t size = 0; + std::shared_ptr data = nullptr; +}; + +struct FrameData { + bool isFinished = false; + uint32_t frameId = 0; + uint32_t size = 0; + void *data = nullptr; +}; + +struct Resolution { + uint32_t width = 0; + uint32_t height = 0; +}; + +struct Rect { + uint32_t ltX = 0; + uint32_t ltY = 0; + uint32_t rbX = 0; + uint32_t rbY = 0; +}; + +struct BBox { + Rect rect; + uint32_t score = 0; + std::string text; +}; + +struct AclLiteMessage { + int dest; + int msgId; + std::shared_ptr data = nullptr; +}; + +struct DataInfo { + void *data; + uint32_t size; +}; + +struct InferenceOutput { + std::shared_ptr data = nullptr; + uint32_t size; +}; + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_ACL_LITE_TYPE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/AclLiteUtils.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/AclLiteUtils.cc index a05fc5d00ce..3a57215a0f7 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/AclLiteUtils.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/AclLiteUtils.cc @@ -1,522 +1,522 @@ -/** -* Copyright 2022-2023 Huawei Technologies Co., Ltd -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at - -* http://www.apache.org/licenses/LICENSE-2.0 - -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ -#include "minddata/dataset/kernels/image/dvpp/utils/AclLiteUtils.h" - -#include -#include - -#include -#include -#include -#include -#include - -#include "acl/ops/acl_dvpp.h" -#include "transform/symbol/acl_rt_symbol.h" -#include "transform/symbol/symbol_utils.h" - -namespace { -const char COMMENT_CHAR = '#'; -const char EQUALS_CHAR = '='; -const char BLANK_SPACE_CHAR = ' '; -const char TABLE_CHAR = '\t'; - -const std::string kImagePathSeparator = ","; -const int kStatSuccess = 0; -const std::string kFileSperator = "/"; -const std::string kPathSeparator = "/"; -// output image prefix -const std::string kOutputFilePrefix = "out_"; - -const std::string kRegexIpAddr = - "^(1\\d{2}|2[0-4]\\d|25[0-5]|[1-9]\\d|[0-9])\\." - "(1\\d{2}|2[0-4]\\d|25[0-5]|[1-9]\\d|\\d)\\." - "(1\\d{2}|2[0-4]\\d|25[0-5]|[1-9]\\d|\\d)\\." - "(1\\d{2}|2[0-4]\\d|25[0-5]|[1-9]\\d|\\d)" - ":([1-9]|[1-9]\\d|[1-9]\\d{2}|[1-9]\\d{3}|[1-5]\\d{4}|" - "6[0-4]\\d{3}|65[0-4]\\d{2}|655[0-2]\\d|6553[0-5])$"; - -// regex for verify video file name -const std::string kRegexVideoFile = "^.+\\.(mp4|h264|h265)$"; - -// regex for verify RTSP rtsp://ip:port/channelname -const std::string kRegexRtsp = "^rtsp://.*"; -} // namespace - -bool IsDigitStr(const std::string &str) { return std::all_of(str.begin(), str.end(), isdigit); } - -bool IsPathExist(const std::string &path) { - std::ifstream file(path, std::ios::in); - if (!file) { - return false; - } - file.close(); - return true; -} - -bool IsVideoFile(const std::string &path) { - std::regex regexVideoFile(kRegexVideoFile.c_str()); - return regex_match(path, regexVideoFile); -} - -bool IsRtspAddr(const std::string &str) { - std::regex regexRtspAddress(kRegexRtsp.c_str()); - - return regex_match(str, regexRtspAddress); -} - -bool IsIpAddrWithPort(const std::string &addrStr) { - std::regex regexIpAddr(kRegexIpAddr.c_str()); - - return regex_match(addrStr, regexIpAddr); -} - -void ParseIpAddr(std::string &ip, std::string &port, const std::string &addr) { - std::string::size_type pos = addr.find(':'); - - (void)ip.assign(addr.substr(0, pos)); - (void)port.assign(addr.substr(pos + 1)); -} - -bool IsDirectory(const std::string &path) { - // get path stat - struct stat buf {}; - if (stat(path.c_str(), &buf) != kStatSuccess) { - return false; - } - - // check - return S_ISDIR(buf.st_mode); -} - -void SplitPath(const std::string &path, std::vector &pathVec) { - char *imageFile = strtok(const_cast(path.c_str()), kImagePathSeparator.c_str()); - while (imageFile) { - (void)pathVec.emplace_back(imageFile); - imageFile = strtok(nullptr, kImagePathSeparator.c_str()); - } -} - -void GetPathFiles(const std::string &path, std::vector &fileVec) { - if (IsDirectory(path)) { - DIR *dir = opendir(path.c_str()); - struct dirent *direntPtr; - while ((direntPtr = readdir(dir)) != nullptr) { - // skip . and .. - if (direntPtr->d_name[0] == '.') { - continue; - } - - // file path - std::string fullPath = path + kPathSeparator + direntPtr->d_name; - // directory need recursion - if (IsDirectory(fullPath)) { - GetPathFiles(fullPath, fileVec); - } else { - // put file - (void)fileVec.emplace_back(fullPath); - } - } - closedir(dir); - } else { - (void)fileVec.emplace_back(path); - } -} - -void GetAllFiles(const std::string &pathList, std::vector &fileVec) { - // split file path - std::vector pathVec; - SplitPath(pathList, pathVec); - - for (const std::string &everyPath : pathVec) { - // check path exist or not - if (!IsPathExist(pathList)) { - ACLLITE_LOG_ERROR("Failed to deal path=%s. Reason: not exist or can not access.", everyPath.c_str()); - continue; - } - // get files in path and sub-path - GetPathFiles(everyPath, fileVec); - } -} - -void *MallocMemory(uint32_t dataSize, MemoryType memType) { - void *buffer = nullptr; - aclError aclRet = ACL_SUCCESS; - - switch (memType) { - case MemoryType::MEMORY_NORMAL: - buffer = new uint8_t[dataSize]; - break; - case MemoryType::MEMORY_HOST: - aclRet = CALL_ASCEND_API(aclrtMallocHost, &buffer, dataSize); - break; - case MemoryType::MEMORY_DEVICE: - aclRet = CALL_ASCEND_API(aclrtMalloc, &buffer, dataSize, ACL_MEM_MALLOC_HUGE_FIRST); - break; - case MemoryType::MEMORY_DVPP: - aclRet = acldvppMalloc(&buffer, dataSize); - break; - default: - ACLLITE_LOG_ERROR("Invalid memory type %d", memType); - aclRet = ACL_ERROR_INVALID_PARAM; - break; - } - - if ((aclRet != ACL_SUCCESS) || (buffer == nullptr)) { - ACLLITE_LOG_ERROR("Malloc memory failed, type: %d, errorno:%d", memType, aclRet); - return nullptr; - } - - return buffer; -} - -void FreeMemory(void *mem, MemoryType memType) { - if (mem == nullptr) { - ACLLITE_LOG_ERROR("Invalid mem"); - return; - } - aclError ret = ACL_SUCCESS; - switch (memType) { - case MemoryType::MEMORY_NORMAL: - delete[](reinterpret_cast(mem)); - break; - case MemoryType::MEMORY_HOST: - ret = CALL_ASCEND_API(aclrtFreeHost, mem); - if (ret != ACL_SUCCESS) { - ACLLITE_LOG_ERROR("aclrtFreeHost failed, errorno: %d", ret); - } - break; - case MemoryType::MEMORY_DEVICE: - ret = CALL_ASCEND_API(aclrtFree, mem); - if (ret != ACL_SUCCESS) { - ACLLITE_LOG_ERROR("aclrtFree failed, errorno: %d", ret); - } - break; - case MemoryType::MEMORY_DVPP: - ret = acldvppFree(mem); - if (ret != ACL_SUCCESS) { - ACLLITE_LOG_ERROR("acldvppFree failed, errorno: %d", ret); - } - break; - default: - ACLLITE_LOG_ERROR("Invalid memory type %d", memType); - break; - } -} - -aclrtMemcpyKind GetCopyPolicy(aclrtRunMode srcDev, CopyDirection direct, MemoryType memType) { - aclrtMemcpyKind policy = ACL_MEMCPY_HOST_TO_HOST; - - if (direct == CopyDirection::TO_DEVICE) { - if (srcDev == ACL_HOST) { - policy = ACL_MEMCPY_HOST_TO_DEVICE; - } else { - policy = ACL_MEMCPY_DEVICE_TO_DEVICE; - } - } else { // TO_HOST - if (srcDev == ACL_DEVICE) { - policy = ACL_MEMCPY_DEVICE_TO_HOST; - } - } - - return policy; -} - -void *CopyDataToDevice(const void *data, uint32_t size, aclrtRunMode curRunMode, MemoryType memType) { - if ((data == nullptr) || (size == 0) || ((curRunMode != ACL_HOST) && (curRunMode != ACL_DEVICE)) || - (memType >= MemoryType::MEMORY_INVALID_TYPE) || (memType == MemoryType::MEMORY_HOST)) { - ACLLITE_LOG_ERROR( - "Copy data args invalid, data %p, " - "size %d, src dev %d, memory type %d", - data, size, curRunMode, memType); - return nullptr; - } - - aclrtMemcpyKind policy = GetCopyPolicy(curRunMode, CopyDirection::TO_DEVICE, memType); - - return CopyData(data, size, policy, memType); -} - -AclLiteError CopyDataToDeviceEx(void *dest, uint32_t destSize, const void *src, uint32_t srcSize, - aclrtRunMode runMode) { - aclrtMemcpyKind policy = ACL_MEMCPY_HOST_TO_DEVICE; - if (runMode == ACL_DEVICE) { - policy = ACL_MEMCPY_DEVICE_TO_DEVICE; - } - - aclError aclRet = CALL_ASCEND_API(aclrtMemcpy, dest, destSize, src, srcSize, policy); - if (aclRet != ACL_SUCCESS) { - ACLLITE_LOG_ERROR("Copy data to device failed, aclRet is %d", aclRet); - return ACLLITE_ERROR; - } - - return ACLLITE_OK; -} - -void *CopyDataToHost(const void *data, uint32_t size, aclrtRunMode curRunMode, MemoryType memType) { - if ((data == nullptr) || (size == 0) || ((curRunMode != ACL_HOST) && (curRunMode != ACL_DEVICE)) || - ((memType != MemoryType::MEMORY_HOST) && (memType != MemoryType::MEMORY_NORMAL))) { - ACLLITE_LOG_ERROR( - "Copy data args invalid, data %p, " - "size %d, src dev %d, memory type %d", - data, size, curRunMode, memType); - return nullptr; - } - - aclrtMemcpyKind policy = GetCopyPolicy(curRunMode, CopyDirection::TO_HOST, memType); - - return CopyData(data, size, policy, memType); -} - -AclLiteError CopyDataToHostEx(void *dest, uint32_t destSize, const void *src, uint32_t srcSize, aclrtRunMode runMode) { - aclrtMemcpyKind policy = ACL_MEMCPY_DEVICE_TO_HOST; - if (runMode == ACL_DEVICE) { - policy = ACL_MEMCPY_DEVICE_TO_DEVICE; - } - - aclError aclRet = CALL_ASCEND_API(aclrtMemcpy, dest, destSize, src, srcSize, policy); - if (aclRet != ACL_SUCCESS) { - ACLLITE_LOG_ERROR("Copy data to device failed, aclRet is %d", aclRet); - return ACLLITE_ERROR; - } - - return ACLLITE_OK; -} - -void *CopyData(const void *data, uint32_t size, aclrtMemcpyKind policy, MemoryType memType) { - void *buffer = MallocMemory(size, memType); - if (buffer == nullptr) { - return nullptr; - } - - aclError aclRet = CALL_ASCEND_API(aclrtMemcpy, buffer, size, data, size, policy); - if (aclRet != ACL_SUCCESS) { - ACLLITE_LOG_ERROR("Copy data to device failed, aclRet is %d", aclRet); - FreeMemory(buffer, memType); - return nullptr; - } - - return buffer; -} - -AclLiteError CopyImageToLocal(ImageData &destImage, ImageData &srcImage, aclrtRunMode curRunMode) { - void *data = CopyDataToHost(srcImage.data.get(), srcImage.size, curRunMode, MemoryType::MEMORY_NORMAL); - if (data == nullptr) { - return ACLLITE_ERROR_COPY_DATA; - } - - destImage.format = srcImage.format; - destImage.width = srcImage.width; - destImage.height = srcImage.height; - destImage.size = srcImage.size; - destImage.alignWidth = srcImage.alignWidth; - destImage.alignHeight = srcImage.alignHeight; - destImage.data = SHARED_PTR_U8_BUF(data); - - return ACLLITE_OK; -} - -AclLiteError CopyImageToDevice(ImageData &destImage, ImageData &srcImage, aclrtRunMode curRunMode, MemoryType memType) { - void *data = CopyDataToDevice(srcImage.data.get(), srcImage.size, curRunMode, memType); - if (data == nullptr) { - return ACLLITE_ERROR_COPY_DATA; - } - - destImage.format = srcImage.format; - destImage.width = srcImage.width; - destImage.height = srcImage.height; - destImage.size = srcImage.size; - destImage.alignWidth = srcImage.alignWidth; - destImage.alignHeight = srcImage.alignHeight; - - if (memType == MemoryType::MEMORY_DEVICE) { - destImage.data = SHARED_PTR_DEV_BUF(data); - } else { - destImage.data = SHARED_PTR_DVPP_BUF(data); - } - - return ACLLITE_OK; -} - -AclLiteError ReadBinFile(const std::string &fileName, void *&data, uint32_t &size) { - struct stat sBuf {}; - int fileStatus = stat(fileName.data(), &sBuf); - if (fileStatus == -1) { - ACLLITE_LOG_ERROR("failed to get file"); - return ACLLITE_ERROR_ACCESS_FILE; - } - if (S_ISREG(sBuf.st_mode) == 0) { - ACLLITE_LOG_ERROR("%s is not a file, please enter a file", fileName.c_str()); - return ACLLITE_ERROR_INVALID_FILE; - } - std::ifstream binFile(fileName, std::ifstream::in | std::ifstream::binary); - if (!binFile.is_open()) { - ACLLITE_LOG_ERROR("open file %s failed", fileName.c_str()); - return ACLLITE_ERROR_OPEN_FILE; - } - - (void)binFile.seekg(0, std::ifstream::end); - uint32_t binFileBufferLen = binFile.tellg(); - if (binFileBufferLen == 0) { - ACLLITE_LOG_ERROR("binfile is empty, filename is %s", fileName.c_str()); - binFile.close(); - return ACLLITE_ERROR_INVALID_FILE; - } - - (void)binFile.seekg(0, std::ifstream::beg); - - auto *binFileBufferData = new (std::nothrow) uint8_t[binFileBufferLen]; - if (binFileBufferData == nullptr) { - ACLLITE_LOG_ERROR("malloc binFileBufferData failed"); - binFile.close(); - return ACLLITE_ERROR_MALLOC; - } - (void)binFile.read(reinterpret_cast(binFileBufferData), binFileBufferLen); - binFile.close(); - - data = binFileBufferData; - size = binFileBufferLen; - - return ACLLITE_OK; -} - -AclLiteError ReadJpeg(ImageData &image, const std::string &fileName) { - uint32_t size = 0; - void *buf = nullptr; - - auto lite_ret = ReadBinFile(fileName, buf, size); - if (lite_ret != ACLLITE_OK) { - delete[](reinterpret_cast(buf)); - return lite_ret; - } - - int32_t ch = 0; - auto ret = acldvppJpegGetImageInfo(buf, size, &(image.width), &(image.height), &ch); - if (ret != ACL_SUCCESS) { - ACLLITE_LOG_ERROR("acldvppJpegGetImageInfo failed, errorno: %d", ret); - delete[](reinterpret_cast(buf)); - return ACLLITE_ERROR; - } - if (image.width == 0 || image.height == 0) { - ACLLITE_LOG_ERROR("unsupported format, only Baseline JPEG"); - delete[](reinterpret_cast(buf)); - return ACLLITE_ERROR; - } - image.data.reset(reinterpret_cast(buf), [](const uint8_t *p) { delete[](p); }); - image.size = size; - - return ACLLITE_OK; -} - -void SaveBinFile(const std::string &filename, const void *data, uint32_t size) { - FILE *outFileFp = fopen(filename.c_str(), "wb+"); - if (outFileFp == nullptr) { - ACLLITE_LOG_ERROR("Save file %s failed for open error", filename.c_str()); - return; - } - (void)fwrite(data, 1, size, outFileFp); - - (void)fflush(outFileFp); - (void)fclose(outFileFp); -} - -bool IsSpace(char c) { return (c == BLANK_SPACE_CHAR || c == TABLE_CHAR); } - -void Trim(std::string &str) { - if (str.empty()) { - return; - } - int32_t i; - int32_t start_pos; - int32_t end_pos; - for (i = 0; i < str.size(); ++i) { - if (!IsSpace(str[i])) { - break; - } - } - if (i == str.size()) { // is all blank space - str = ""; - return; - } - - start_pos = i; - - for (i = str.size() - 1; i >= 0; --i) { - if (!IsSpace(str[i])) { - break; - } - } - end_pos = i; - - str = str.substr(start_pos, end_pos - start_pos + 1); -} - -bool AnalyseLine(const std::string &line, std::string &key, std::string &value) { - if (line.empty()) { - return false; - } - - int start_pos = 0; - auto end_pos = line.size() - 1; - std::string::size_type pos = line.find(COMMENT_CHAR); - if (pos != std::string::npos) { - if (pos == 0) { // the first charactor is # - return false; - } - end_pos = pos - 1; - } - std::string new_line = line.substr(start_pos, start_pos + 1 - end_pos); // delete comment - pos = new_line.find(EQUALS_CHAR); - if (pos == std::string::npos) { // has no = - return false; - } - - key = new_line.substr(0, pos); - value = new_line.substr(pos + 1, end_pos + 1 - (pos + 1)); - - Trim(key); - if (key.empty()) { - return false; - } - Trim(value); - return true; -} - -bool ReadConfig(std::map &config, const char *configFile) { - config.clear(); - std::ifstream infile(configFile, std::ifstream::in); - if (!infile) { - return false; - } - std::string line; - std::string key; - std::string value; - while (getline(infile, line)) { - if (AnalyseLine(line, key, value)) { - config[key] = value; - } - } - - infile.close(); - return true; -} - -void PrintConfig(const std::map &config) { - auto mIter = config.begin(); - for (; mIter != config.end(); ++mIter) { - std::cout << mIter->first << "=" << mIter->second << std::endl; - } -} +/** +* Copyright 2022-2023 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at + +* http://www.apache.org/licenses/LICENSE-2.0 + +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +#include "minddata/dataset/kernels/image/dvpp/utils/AclLiteUtils.h" + +#include +#include + +#include +#include +#include +#include +#include + +#include "acl/ops/acl_dvpp.h" +#include "transform/symbol/acl_rt_symbol.h" +#include "transform/symbol/symbol_utils.h" + +namespace { +const char COMMENT_CHAR = '#'; +const char EQUALS_CHAR = '='; +const char BLANK_SPACE_CHAR = ' '; +const char TABLE_CHAR = '\t'; + +const std::string kImagePathSeparator = ","; +const int kStatSuccess = 0; +const std::string kFileSperator = "/"; +const std::string kPathSeparator = "/"; +// output image prefix +const std::string kOutputFilePrefix = "out_"; + +const std::string kRegexIpAddr = + "^(1\\d{2}|2[0-4]\\d|25[0-5]|[1-9]\\d|[0-9])\\." + "(1\\d{2}|2[0-4]\\d|25[0-5]|[1-9]\\d|\\d)\\." + "(1\\d{2}|2[0-4]\\d|25[0-5]|[1-9]\\d|\\d)\\." + "(1\\d{2}|2[0-4]\\d|25[0-5]|[1-9]\\d|\\d)" + ":([1-9]|[1-9]\\d|[1-9]\\d{2}|[1-9]\\d{3}|[1-5]\\d{4}|" + "6[0-4]\\d{3}|65[0-4]\\d{2}|655[0-2]\\d|6553[0-5])$"; + +// regex for verify video file name +const std::string kRegexVideoFile = "^.+\\.(mp4|h264|h265)$"; + +// regex for verify RTSP rtsp://ip:port/channelname +const std::string kRegexRtsp = "^rtsp://.*"; +} // namespace + +bool IsDigitStr(const std::string &str) { return std::all_of(str.begin(), str.end(), isdigit); } + +bool IsPathExist(const std::string &path) { + std::ifstream file(path, std::ios::in); + if (!file) { + return false; + } + file.close(); + return true; +} + +bool IsVideoFile(const std::string &path) { + std::regex regexVideoFile(kRegexVideoFile.c_str()); + return regex_match(path, regexVideoFile); +} + +bool IsRtspAddr(const std::string &str) { + std::regex regexRtspAddress(kRegexRtsp.c_str()); + + return regex_match(str, regexRtspAddress); +} + +bool IsIpAddrWithPort(const std::string &addrStr) { + std::regex regexIpAddr(kRegexIpAddr.c_str()); + + return regex_match(addrStr, regexIpAddr); +} + +void ParseIpAddr(std::string &ip, std::string &port, const std::string &addr) { + std::string::size_type pos = addr.find(':'); + + (void)ip.assign(addr.substr(0, pos)); + (void)port.assign(addr.substr(pos + 1)); +} + +bool IsDirectory(const std::string &path) { + // get path stat + struct stat buf {}; + if (stat(path.c_str(), &buf) != kStatSuccess) { + return false; + } + + // check + return S_ISDIR(buf.st_mode); +} + +void SplitPath(const std::string &path, std::vector &pathVec) { + char *imageFile = strtok(const_cast(path.c_str()), kImagePathSeparator.c_str()); + while (imageFile) { + (void)pathVec.emplace_back(imageFile); + imageFile = strtok(nullptr, kImagePathSeparator.c_str()); + } +} + +void GetPathFiles(const std::string &path, std::vector &fileVec) { + if (IsDirectory(path)) { + DIR *dir = opendir(path.c_str()); + struct dirent *direntPtr; + while ((direntPtr = readdir(dir)) != nullptr) { + // skip . and .. + if (direntPtr->d_name[0] == '.') { + continue; + } + + // file path + std::string fullPath = path + kPathSeparator + direntPtr->d_name; + // directory need recursion + if (IsDirectory(fullPath)) { + GetPathFiles(fullPath, fileVec); + } else { + // put file + (void)fileVec.emplace_back(fullPath); + } + } + closedir(dir); + } else { + (void)fileVec.emplace_back(path); + } +} + +void GetAllFiles(const std::string &pathList, std::vector &fileVec) { + // split file path + std::vector pathVec; + SplitPath(pathList, pathVec); + + for (const std::string &everyPath : pathVec) { + // check path exist or not + if (!IsPathExist(pathList)) { + ACLLITE_LOG_ERROR("Failed to deal path=%s. Reason: not exist or can not access.", everyPath.c_str()); + continue; + } + // get files in path and sub-path + GetPathFiles(everyPath, fileVec); + } +} + +void *MallocMemory(uint32_t dataSize, MemoryType memType) { + void *buffer = nullptr; + aclError aclRet = ACL_SUCCESS; + + switch (memType) { + case MemoryType::MEMORY_NORMAL: + buffer = new uint8_t[dataSize]; + break; + case MemoryType::MEMORY_HOST: + aclRet = CALL_ASCEND_API(aclrtMallocHost, &buffer, dataSize); + break; + case MemoryType::MEMORY_DEVICE: + aclRet = CALL_ASCEND_API(aclrtMalloc, &buffer, dataSize, ACL_MEM_MALLOC_HUGE_FIRST); + break; + case MemoryType::MEMORY_DVPP: + aclRet = acldvppMalloc(&buffer, dataSize); + break; + default: + ACLLITE_LOG_ERROR("Invalid memory type %d", memType); + aclRet = ACL_ERROR_INVALID_PARAM; + break; + } + + if ((aclRet != ACL_SUCCESS) || (buffer == nullptr)) { + ACLLITE_LOG_ERROR("Malloc memory failed, type: %d, errorno:%d", memType, aclRet); + return nullptr; + } + + return buffer; +} + +void FreeMemory(void *mem, MemoryType memType) { + if (mem == nullptr) { + ACLLITE_LOG_ERROR("Invalid mem"); + return; + } + aclError ret = ACL_SUCCESS; + switch (memType) { + case MemoryType::MEMORY_NORMAL: + delete[](reinterpret_cast(mem)); + break; + case MemoryType::MEMORY_HOST: + ret = CALL_ASCEND_API(aclrtFreeHost, mem); + if (ret != ACL_SUCCESS) { + ACLLITE_LOG_ERROR("aclrtFreeHost failed, errorno: %d", ret); + } + break; + case MemoryType::MEMORY_DEVICE: + ret = CALL_ASCEND_API(aclrtFree, mem); + if (ret != ACL_SUCCESS) { + ACLLITE_LOG_ERROR("aclrtFree failed, errorno: %d", ret); + } + break; + case MemoryType::MEMORY_DVPP: + ret = acldvppFree(mem); + if (ret != ACL_SUCCESS) { + ACLLITE_LOG_ERROR("acldvppFree failed, errorno: %d", ret); + } + break; + default: + ACLLITE_LOG_ERROR("Invalid memory type %d", memType); + break; + } +} + +aclrtMemcpyKind GetCopyPolicy(aclrtRunMode srcDev, CopyDirection direct, MemoryType memType) { + aclrtMemcpyKind policy = ACL_MEMCPY_HOST_TO_HOST; + + if (direct == CopyDirection::TO_DEVICE) { + if (srcDev == ACL_HOST) { + policy = ACL_MEMCPY_HOST_TO_DEVICE; + } else { + policy = ACL_MEMCPY_DEVICE_TO_DEVICE; + } + } else { // TO_HOST + if (srcDev == ACL_DEVICE) { + policy = ACL_MEMCPY_DEVICE_TO_HOST; + } + } + + return policy; +} + +void *CopyDataToDevice(const void *data, uint32_t size, aclrtRunMode curRunMode, MemoryType memType) { + if ((data == nullptr) || (size == 0) || ((curRunMode != ACL_HOST) && (curRunMode != ACL_DEVICE)) || + (memType >= MemoryType::MEMORY_INVALID_TYPE) || (memType == MemoryType::MEMORY_HOST)) { + ACLLITE_LOG_ERROR( + "Copy data args invalid, data %p, " + "size %d, src dev %d, memory type %d", + data, size, curRunMode, memType); + return nullptr; + } + + aclrtMemcpyKind policy = GetCopyPolicy(curRunMode, CopyDirection::TO_DEVICE, memType); + + return CopyData(data, size, policy, memType); +} + +AclLiteError CopyDataToDeviceEx(void *dest, uint32_t destSize, const void *src, uint32_t srcSize, + aclrtRunMode runMode) { + aclrtMemcpyKind policy = ACL_MEMCPY_HOST_TO_DEVICE; + if (runMode == ACL_DEVICE) { + policy = ACL_MEMCPY_DEVICE_TO_DEVICE; + } + + aclError aclRet = CALL_ASCEND_API(aclrtMemcpy, dest, destSize, src, srcSize, policy); + if (aclRet != ACL_SUCCESS) { + ACLLITE_LOG_ERROR("Copy data to device failed, aclRet is %d", aclRet); + return ACLLITE_ERROR; + } + + return ACLLITE_OK; +} + +void *CopyDataToHost(const void *data, uint32_t size, aclrtRunMode curRunMode, MemoryType memType) { + if ((data == nullptr) || (size == 0) || ((curRunMode != ACL_HOST) && (curRunMode != ACL_DEVICE)) || + ((memType != MemoryType::MEMORY_HOST) && (memType != MemoryType::MEMORY_NORMAL))) { + ACLLITE_LOG_ERROR( + "Copy data args invalid, data %p, " + "size %d, src dev %d, memory type %d", + data, size, curRunMode, memType); + return nullptr; + } + + aclrtMemcpyKind policy = GetCopyPolicy(curRunMode, CopyDirection::TO_HOST, memType); + + return CopyData(data, size, policy, memType); +} + +AclLiteError CopyDataToHostEx(void *dest, uint32_t destSize, const void *src, uint32_t srcSize, aclrtRunMode runMode) { + aclrtMemcpyKind policy = ACL_MEMCPY_DEVICE_TO_HOST; + if (runMode == ACL_DEVICE) { + policy = ACL_MEMCPY_DEVICE_TO_DEVICE; + } + + aclError aclRet = CALL_ASCEND_API(aclrtMemcpy, dest, destSize, src, srcSize, policy); + if (aclRet != ACL_SUCCESS) { + ACLLITE_LOG_ERROR("Copy data to device failed, aclRet is %d", aclRet); + return ACLLITE_ERROR; + } + + return ACLLITE_OK; +} + +void *CopyData(const void *data, uint32_t size, aclrtMemcpyKind policy, MemoryType memType) { + void *buffer = MallocMemory(size, memType); + if (buffer == nullptr) { + return nullptr; + } + + aclError aclRet = CALL_ASCEND_API(aclrtMemcpy, buffer, size, data, size, policy); + if (aclRet != ACL_SUCCESS) { + ACLLITE_LOG_ERROR("Copy data to device failed, aclRet is %d", aclRet); + FreeMemory(buffer, memType); + return nullptr; + } + + return buffer; +} + +AclLiteError CopyImageToLocal(ImageData &destImage, ImageData &srcImage, aclrtRunMode curRunMode) { + void *data = CopyDataToHost(srcImage.data.get(), srcImage.size, curRunMode, MemoryType::MEMORY_NORMAL); + if (data == nullptr) { + return ACLLITE_ERROR_COPY_DATA; + } + + destImage.format = srcImage.format; + destImage.width = srcImage.width; + destImage.height = srcImage.height; + destImage.size = srcImage.size; + destImage.alignWidth = srcImage.alignWidth; + destImage.alignHeight = srcImage.alignHeight; + destImage.data = SHARED_PTR_U8_BUF(data); + + return ACLLITE_OK; +} + +AclLiteError CopyImageToDevice(ImageData &destImage, ImageData &srcImage, aclrtRunMode curRunMode, MemoryType memType) { + void *data = CopyDataToDevice(srcImage.data.get(), srcImage.size, curRunMode, memType); + if (data == nullptr) { + return ACLLITE_ERROR_COPY_DATA; + } + + destImage.format = srcImage.format; + destImage.width = srcImage.width; + destImage.height = srcImage.height; + destImage.size = srcImage.size; + destImage.alignWidth = srcImage.alignWidth; + destImage.alignHeight = srcImage.alignHeight; + + if (memType == MemoryType::MEMORY_DEVICE) { + destImage.data = SHARED_PTR_DEV_BUF(data); + } else { + destImage.data = SHARED_PTR_DVPP_BUF(data); + } + + return ACLLITE_OK; +} + +AclLiteError ReadBinFile(const std::string &fileName, void *&data, uint32_t &size) { + struct stat sBuf {}; + int fileStatus = stat(fileName.data(), &sBuf); + if (fileStatus == -1) { + ACLLITE_LOG_ERROR("failed to get file"); + return ACLLITE_ERROR_ACCESS_FILE; + } + if (S_ISREG(sBuf.st_mode) == 0) { + ACLLITE_LOG_ERROR("%s is not a file, please enter a file", fileName.c_str()); + return ACLLITE_ERROR_INVALID_FILE; + } + std::ifstream binFile(fileName, std::ifstream::in | std::ifstream::binary); + if (!binFile.is_open()) { + ACLLITE_LOG_ERROR("open file %s failed", fileName.c_str()); + return ACLLITE_ERROR_OPEN_FILE; + } + + (void)binFile.seekg(0, std::ifstream::end); + uint32_t binFileBufferLen = binFile.tellg(); + if (binFileBufferLen == 0) { + ACLLITE_LOG_ERROR("binfile is empty, filename is %s", fileName.c_str()); + binFile.close(); + return ACLLITE_ERROR_INVALID_FILE; + } + + (void)binFile.seekg(0, std::ifstream::beg); + + auto *binFileBufferData = new (std::nothrow) uint8_t[binFileBufferLen]; + if (binFileBufferData == nullptr) { + ACLLITE_LOG_ERROR("malloc binFileBufferData failed"); + binFile.close(); + return ACLLITE_ERROR_MALLOC; + } + (void)binFile.read(reinterpret_cast(binFileBufferData), binFileBufferLen); + binFile.close(); + + data = binFileBufferData; + size = binFileBufferLen; + + return ACLLITE_OK; +} + +AclLiteError ReadJpeg(ImageData &image, const std::string &fileName) { + uint32_t size = 0; + void *buf = nullptr; + + auto lite_ret = ReadBinFile(fileName, buf, size); + if (lite_ret != ACLLITE_OK) { + delete[](reinterpret_cast(buf)); + return lite_ret; + } + + int32_t ch = 0; + auto ret = acldvppJpegGetImageInfo(buf, size, &(image.width), &(image.height), &ch); + if (ret != ACL_SUCCESS) { + ACLLITE_LOG_ERROR("acldvppJpegGetImageInfo failed, errorno: %d", ret); + delete[](reinterpret_cast(buf)); + return ACLLITE_ERROR; + } + if (image.width == 0 || image.height == 0) { + ACLLITE_LOG_ERROR("unsupported format, only Baseline JPEG"); + delete[](reinterpret_cast(buf)); + return ACLLITE_ERROR; + } + image.data.reset(reinterpret_cast(buf), [](const uint8_t *p) { delete[](p); }); + image.size = size; + + return ACLLITE_OK; +} + +void SaveBinFile(const std::string &filename, const void *data, uint32_t size) { + FILE *outFileFp = fopen(filename.c_str(), "wb+"); + if (outFileFp == nullptr) { + ACLLITE_LOG_ERROR("Save file %s failed for open error", filename.c_str()); + return; + } + (void)fwrite(data, 1, size, outFileFp); + + (void)fflush(outFileFp); + (void)fclose(outFileFp); +} + +bool IsSpace(char c) { return (c == BLANK_SPACE_CHAR || c == TABLE_CHAR); } + +void Trim(std::string &str) { + if (str.empty()) { + return; + } + int32_t i; + int32_t start_pos; + int32_t end_pos; + for (i = 0; i < str.size(); ++i) { + if (!IsSpace(str[i])) { + break; + } + } + if (i == str.size()) { // is all blank space + str = ""; + return; + } + + start_pos = i; + + for (i = str.size() - 1; i >= 0; --i) { + if (!IsSpace(str[i])) { + break; + } + } + end_pos = i; + + str = str.substr(start_pos, end_pos - start_pos + 1); +} + +bool AnalyseLine(const std::string &line, std::string &key, std::string &value) { + if (line.empty()) { + return false; + } + + int start_pos = 0; + auto end_pos = line.size() - 1; + std::string::size_type pos = line.find(COMMENT_CHAR); + if (pos != std::string::npos) { + if (pos == 0) { // the first charactor is # + return false; + } + end_pos = pos - 1; + } + std::string new_line = line.substr(start_pos, start_pos + 1 - end_pos); // delete comment + pos = new_line.find(EQUALS_CHAR); + if (pos == std::string::npos) { // has no = + return false; + } + + key = new_line.substr(0, pos); + value = new_line.substr(pos + 1, end_pos + 1 - (pos + 1)); + + Trim(key); + if (key.empty()) { + return false; + } + Trim(value); + return true; +} + +bool ReadConfig(std::map &config, const char *configFile) { + config.clear(); + std::ifstream infile(configFile, std::ifstream::in); + if (!infile) { + return false; + } + std::string line; + std::string key; + std::string value; + while (getline(infile, line)) { + if (AnalyseLine(line, key, value)) { + config[key] = value; + } + } + + infile.close(); + return true; +} + +void PrintConfig(const std::map &config) { + auto mIter = config.begin(); + for (; mIter != config.end(); ++mIter) { + std::cout << mIter->first << "=" << mIter->second << std::endl; + } +} diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/AclLiteUtils.h b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/AclLiteUtils.h index df728c1118b..0cf5f33ed79 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/AclLiteUtils.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/AclLiteUtils.h @@ -1,471 +1,471 @@ -/** -* Copyright 2022-2023 Huawei Technologies Co., Ltd -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at - -* http://www.apache.org/licenses/LICENSE-2.0 - -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_ACL_LITE_UTILS_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_ACL_LITE_UTILS_H_ - -#include - -#include -#include -#include -#include -#include -#include -#include - -#include "acl/ops/acl_dvpp.h" - -#include "minddata/dataset/kernels/image/dvpp/utils/AclLiteError.h" -#include "minddata/dataset/kernels/image/dvpp/utils/AclLiteType.h" -#include "transform/symbol/acl_rt_symbol.h" -#include "transform/symbol/symbol_utils.h" - -/** - * @brief calculate RGB 24bits image size - * @param [in]: width: image width - * @param [in]: height: image height - * @return bytes size of image - */ -#define RGBU8_IMAGE_SIZE(width, height) ((width) * (height)*3) - -/** - * @brief calculate RGB C3F32 image size - * @param [in]: width: image width - * @param [in]: height: image height - * @return bytes size of image - */ -#define RGBF32_IMAGE_SIZE(width, height) ((width) * (height)*3 * sizeof(float)) - -/** - * @brief calculate YUVSP420 image size - * @param [in]: width: image width - * @param [in]: height: image height - * @return bytes size of image - */ -#define YUV420SP_SIZE(width, height) ((width) * (height)*3 / 2) - -/** - * @brief calculate YUVSP420 nv12 load to opencv mat height paramter - * @param [in]: height: yuv image height - * @return bytes size of image - */ -#define YUV420SP_CV_MAT_HEIGHT(height) ((height)*3 / 2) - -/** - * @brief generate shared pointer of dvpp memory - * @param [in]: buf: memory pointer, malloc by acldvppMalloc - * @return shared pointer of input buffer - */ -#define SHARED_PTR_DVPP_BUF(buf) \ - (std::shared_ptr(reinterpret_cast(buf), [](uint8_t *p) { acldvppFree(p); })) - -/** - * @brief generate shared pointer of device memory - * @param [in]: buf: memory pointer, malloc by acldvppMalloc - * @return shared pointer of input buffer - */ -#define SHARED_PTR_DEV_BUF(buf) \ - (std::shared_ptr(reinterpret_cast(buf), [](uint8_t *p) { CALL_ASCEND_API(aclrtFree, p); })) - -/** - * @brief generate shared pointer of memory - * @param [in]: buf memory pointer, malloc by new - * @return shared pointer of input buffer - */ -#define SHARED_PTR_U8_BUF(buf) \ - (std::shared_ptr(reinterpret_cast(buf), [](uint8_t *p) { delete[](p); })) - -/** - * @brief calculate aligned number - * @param [in]: num: the original number that to aligned - * @param [in]: align: the align factor - * @return the number after aligned - */ -#define ALIGN_UP(num, align) (((num) + (align)-1) & ~((align)-1)) - -/** - * @brief calculate number align with 2 - * @param [in]: num: the original number that to aligned - * @return the number after aligned - */ -#define ALIGN_UP2(num) ALIGN_UP(num, 2) - -/** - * @brief calculate number align with 16 - * @param [in]: num: the original number that to aligned - * @return the number after aligned - */ -#define ALIGN_UP16(num) ALIGN_UP(num, 16) - -/** - * @brief calculate number align with 128 - * @param [in]: num: the original number that to aligned - * @return the number after aligned - */ -#define ALIGN_UP128(num) ALIGN_UP(num, 128) - -/** - * @brief calculate elements num of array - * @param [in]: array: the array variable - * @return elements num of array - */ -#define SIZEOF_ARRAY(array) (sizeof(array) / sizeof(array[0])) - -/** - * @brief Write acl error level log to host log - * @param [in]: fmt: the input format string - * @return none - */ -#define ACLLITE_LOG_ERROR(fmt, ...) \ - do { \ - aclAppLog(ACL_ERROR, __FUNCTION__, __FILE__, __LINE__, fmt, ##__VA_ARGS__); \ - fprintf(stdout, "[ERROR] " fmt "\n", ##__VA_ARGS__); \ - } while (0) - -/** - * @brief Write acl info level log to host log - * @param [in]: fmt: the input format string - * @return none - */ -#define ACLLITE_LOG_INFO(fmt, ...) \ - do { \ - aclAppLog(ACL_INFO, __FUNCTION__, __FILE__, __LINE__, fmt, ##__VA_ARGS__); \ - fprintf(stdout, "[INFO] " fmt "\n", ##__VA_ARGS__); \ - } while (0) - -/** - * @brief Write acl warining level log to host log - * @param [in]: fmt: the input format string - * @return none - */ -#define ACLLITE_LOG_WARNING(fmt, ...) \ - do { \ - aclAppLog(ACL_WARNING, __FUNCTION__, __FILE__, __LINE__, fmt, ##__VA_ARGS__); \ - fprintf(stdout, "[WARNING] " fmt "\n", ##__VA_ARGS__); \ - } while (0) - -/** - * @brief Write acl debug level log to host log - * @param [in]: fmt: the input format string - * @return none - */ -#define ACLLITE_LOG_DEBUG(fmt, ...) \ - do { \ - aclAppLog(ACL_DEBUG, __FUNCTION__, __FILE__, __LINE__, fmt, ##__VA_ARGS__); \ - fprintf(stdout, "[INFO] " fmt "\n", ##__VA_ARGS__); \ - } while (0) - -/** - * @brief define variable record time && - set start time - * @param [X]: function name - * @return X_START X_END - */ -#define TIME_START(X) auto X##_START = std::chrono::steady_clock::now(), X##_END = X##_START - -/** - * @brief set end time - * @param [X]: function name - * @return none - */ -#define TIME_END(X) X##_END = std::chrono::steady_clock::now() - -/** - * @brief calculate time by nanosecond - * @param [X]: function name - * @return none - */ -#define TIME_NSEC(X) std::chrono::duration_cast(X##_END - X##_START).count() - -/** - * @brief show time by nanosecond - * @param [X]: function name - * @return none - */ -#define TIME_NSEC_SHOW(X) cout << "Func " << #X << " cost : " << TIME_NSEC(X) << " ns " << endl - -/** - * @brief calculate time and show by microsecond - * @param [X]: variable name - * @return none - */ -#define TIME_USEC(X) std::chrono::duration_cast(X##_END - X##_START).count() - -/** - * @brief show time by microsecond - * @param [X]: function name - * @return none - */ -#define TIME_USEC_SHOW(X) cout << "Func " << #X << " cost : " << TIME_USEC(X) << " us " << endl - -/** - * @brief calculate time and show by millisecond - * @param [X]: variable name - * @return none - */ -#define TIME_MSEC(X) std::chrono::duration_cast(X##_END - X##_START).count() - -/** - * @brief show time by millisecond - * @param [X]: function name - * @return none - */ -#define TIME_MSEC_SHOW(X) cout << "Func " << #X << " cost : " << TIME_MSEC(X) << " ms " << endl - -/** - * @brief calculate time and show by second - * @param [X]: variable name - * @return none - */ -#define TIME_SEC(X) std::chrono::duration_cast(X##_END - X##_START).count() - -/** - * @brief show time by second - * @param [X]: function name - * @return none - */ -#define TIME_SEC_SHOW(X) cout << "Func " << #X << " cost : " << TIME_SEC(X) << " s " << endl - -/** - * @brief calculate time and show by minute - * @param [X]: variable name - * @return none - */ -#define TIME_MINUTE(X) std::chrono::duration_cast(X##_END - X##_START).count() - -/** - * @brief show time by minute - * @param [X]: function name - * @return none - */ -#define TIME_MINUTE_SHOW(X) cout << "Func " << #X << " cost : " << TIME_MINUTE(X) << " min " << endl - -/** - * @brief calculate time and show by hour - * @param [X]: variable name - * @return none - */ -#define TIME_HOUR(X) std::chrono::duration_cast(X##_END - X##_START).count() - -/** - * @brief show time by hour - * @param [X]: function name - * @return none - */ -#define TIME_HOUR_SHOW(X) cout << "Func " << #X << " cost : " << TIME_HOUR(X) << " h " << endl - -/** - * @brief Recognize the string is a accessable directory or not - * @param [in]: path: the input string - * @return bool true: is directory; false: not directory - */ -bool IsDirectory(const std::string &path); - -/** - * @brief Copy data to device - * @param [in]: data: The data to copy - * @param [in]: size: The data bytes size - * @param [in]: curRunMode: The run mode, get by aclrtGetRunMode, - * Atlas200DK is ACL_DEVICE, Atlas300 is ACL_HOST - * @param [in]: memType: The dest memory type:MEMORY_NORMAL(in Atlas200DK), - * MEMORY_DEVICE, MEMORY_DVPP - * @return void* The dest memory pointer - */ -void *CopyDataToDevice(const void *data, uint32_t size, aclrtRunMode curRunMode, MemoryType memType); - -/** - * @brief Copy data to device buffer - * @param [in]: dest: The device buffer - * @param [in]: destSize: The device buffer size - * @param [in]: src: The data to copy - * @param [in]: srcSize: The data bytes size - * @param [in]: curRunMode: The run mode, get by aclrtGetRunMode, - * Atlas200DK is ACL_DEVICE, Atlas300 is ACL_HOST - * @return AclLiteError ACLLITE_OK: copy success - * others: copy failed - */ -AclLiteError CopyDataToDeviceEx(void *dest, uint32_t destSize, const void *src, uint32_t srcSize, aclrtRunMode runMode); - -/** - * @brief Copy data to host - * @param [in]: data: The data to be copy - * @param [in]: size: The data bytes size - * @param [in]: curRunMode: The run mode, get by aclrtGetRunMode, - * Atlas200DK is ACL_DEVICE, Atlas300 is ACL_HOST - * @param [in]: memType: The dest memory type:MEMORY_NORMAL, MEMORY_HOST - * @return void* The dest memory pointer - */ -void *CopyDataToHost(const void *data, uint32_t size, aclrtRunMode curRunMode, MemoryType memType); - -/** - * @brief Copy data to host buffer - * @param [in]: dest: The host buffer - * @param [in]: destSize: The host buffer size - * @param [in]: src: The data to copy - * @param [in]: srcSize: The data bytes size - * @param [in]: curRunMode: The run mode, get by aclrtGetRunMode, - * Atlas200DK is ACL_DEVICE, Atlas300 is ACL_HOST - * @return AclLiteError ACLLITE_OK: copy success - * others: copy failed - */ -AclLiteError CopyDataToHostEx(void *dest, uint32_t destSize, const void *src, uint32_t srcSize, aclrtRunMode runMode); - -/** - * @brief Copy data to memory - * @param [in]: data: The data to be copy - * @param [in]: size: The data bytes size - * @param [in]: policy: the kind of sync, - * typedef enum aclrtMemcpyKind { - * ACL_MEMCPY_HOST_TO_HOST, // Memory copy from Host to Host - * ACL_MEMCPY_HOST_TO_DEVICE, // Memory copy from Host to Device - * ACL_MEMCPY_DEVICE_TO_HOST, // Memory copy from Device to Host - * ACL_MEMCPY_DEVICE_TO_DEVICE, // Memory copy from Device to Device - * } aclrtMemcpyKind; - * @param [in]: memType: The dest memory type - * @return void* The dest memory pointer - */ -void *CopyData(const void *data, uint32_t size, aclrtMemcpyKind policy, MemoryType memType); - -/** - * @brief Read jpeg image file. Only support baseline, not support progressive - * @param [out]: image: image data read from file. - * @param [in]: fileName: The data bytes size - * @return AclLiteError ACLLITE_OK: read success - * others: read failed - */ -AclLiteError ReadJpeg(ImageData &image, const std::string &fileName); - -/** - * @brief Get all files from file list string - * @param [in]: pathList: files list string, seperate by ',', - * the element could be file path or directory - * @param [in]: fileVec: The data bytes size - * @return AclLiteError ACLLITE_OK: read success - * others: read failed - */ -void GetAllFiles(const std::string &pathList, std::vector &fileVec); - -/** - * @brief Save data to binary file - * @param [in]: filename: binary file name with path - * @param [in]: data: binary data - * @param [in]: size: bytes size of data - * @return AclLiteError ACLLITE_OK: read success - * others: read failed - */ -void SaveBinFile(const std::string &filename, const void *data, uint32_t size); - -/** - * @brief Read binary file to buffer - * @param [in]: fileName: binary file name with path - * @param [in]: data: buffer - * @param [in]: size: buffer size - * @return AclLiteError ACLLITE_OK: read success - * others: read failed - */ -AclLiteError ReadBinFile(const std::string &fileName, void *&data, uint32_t &size); - -/** - * @brief Copy image to memory that malloc by new - * @param [out]: destImage: The image after copy - * @param [in]: srcImage: The image to copy - * @param [in]: curRunMode: The run mode, get by aclrtGetRunMode, - * Atlas200DK is ACL_DEVICE, Atlas300 is ACL_HOST - * @return AclLiteError ACLLITE_OK: read success - * others: read failed - */ -AclLiteError CopyImageToLocal(ImageData &destImage, ImageData &srcImage, aclrtRunMode curRunMode); - -/** - * @brief Copy image to acl device - * @param [out]: destImage: The image after copy - * @param [in]: srcImage: The image to copy - * @param [in]: curRunMode: The run mode, get by aclrtGetRunMode, - * Atlas200DK is ACL_DEVICE, Atlas300 is ACL_HOST - * @param [in]: memType: memory type, dvpp is MEMORY_DVPP, - * device is MEMPRY_DEVICE - * @return AclLiteError ACLLITE_OK: read success - * others: read failed - */ -AclLiteError CopyImageToDevice(ImageData &destImage, ImageData &srcImage, aclrtRunMode curRunMode, MemoryType memType); - -/** - * @brief Match ip address string as <1-255>.<0-255>.<0-255>.<0-255>: - * @param [in]: addrStr: Ip address string - * @return bool true: The input string match success - * false: is not match - */ -bool IsIpAddrWithPort(const std::string &addrStr); - -/** - * @brief Split ip address string <1-255>.<0-255>.<0-255>.<0-255>: to - * ip and port - * @param [out]: ip: Ip address <1-255>.<0-255>.<0-255>.<0-255> - * @param [out]: port: port string - * @param [in]: addr: Ip address string - * @return None - */ -void ParseIpAddr(std::string &ip, std::string &port, const std::string &addr); - -/** - * @brief Judge input string is mp4 file path - * @param [in]: path: file path - * @return bool true: input string is mp4 file path - * false: is not mp4 file path - */ -bool IsVideoFile(const std::string &path); - -/** - * @brief Judge input string is rtsp addr link rtsp:// - * @param [in]: str: input string - * @return bool true: input string is rtsp address - * false: is not rtsp address - */ -bool IsRtspAddr(const std::string &str); - -/** - * @brief Judge input string is digit string - * @param [in]: str: input string - * @return bool true: input string is digit string - * false: is not rtsp address - */ -bool IsDigitStr(const std::string &str); - -/** - * @brief Test file path is exist or not - * @param [in]: path: file path - * @return bool true: file path is exist - * false: is not exist - */ -bool IsPathExist(const std::string &path); - -/** - * @brief read file and save information to config - * @param [out]: config: map, save option information - * @param [in]: configFile: string, file - * @return bool true: read config success - * false: read config fail - */ -bool ReadConfig(std::map &config, const char *configFile); - -/** - * @brief print option information - * @param [in]: m: map, save option information - * @return None - */ -void PrintConfig(const std::map &m); - -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_ACL_LITE_UTILS_H_ +/** +* Copyright 2022-2023 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at + +* http://www.apache.org/licenses/LICENSE-2.0 + +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_ACL_LITE_UTILS_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_ACL_LITE_UTILS_H_ + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "acl/ops/acl_dvpp.h" + +#include "minddata/dataset/kernels/image/dvpp/utils/AclLiteError.h" +#include "minddata/dataset/kernels/image/dvpp/utils/AclLiteType.h" +#include "transform/symbol/acl_rt_symbol.h" +#include "transform/symbol/symbol_utils.h" + +/** + * @brief calculate RGB 24bits image size + * @param [in]: width: image width + * @param [in]: height: image height + * @return bytes size of image + */ +#define RGBU8_IMAGE_SIZE(width, height) ((width) * (height)*3) + +/** + * @brief calculate RGB C3F32 image size + * @param [in]: width: image width + * @param [in]: height: image height + * @return bytes size of image + */ +#define RGBF32_IMAGE_SIZE(width, height) ((width) * (height)*3 * sizeof(float)) + +/** + * @brief calculate YUVSP420 image size + * @param [in]: width: image width + * @param [in]: height: image height + * @return bytes size of image + */ +#define YUV420SP_SIZE(width, height) ((width) * (height)*3 / 2) + +/** + * @brief calculate YUVSP420 nv12 load to opencv mat height paramter + * @param [in]: height: yuv image height + * @return bytes size of image + */ +#define YUV420SP_CV_MAT_HEIGHT(height) ((height)*3 / 2) + +/** + * @brief generate shared pointer of dvpp memory + * @param [in]: buf: memory pointer, malloc by acldvppMalloc + * @return shared pointer of input buffer + */ +#define SHARED_PTR_DVPP_BUF(buf) \ + (std::shared_ptr(reinterpret_cast(buf), [](uint8_t *p) { acldvppFree(p); })) + +/** + * @brief generate shared pointer of device memory + * @param [in]: buf: memory pointer, malloc by acldvppMalloc + * @return shared pointer of input buffer + */ +#define SHARED_PTR_DEV_BUF(buf) \ + (std::shared_ptr(reinterpret_cast(buf), [](uint8_t *p) { CALL_ASCEND_API(aclrtFree, p); })) + +/** + * @brief generate shared pointer of memory + * @param [in]: buf memory pointer, malloc by new + * @return shared pointer of input buffer + */ +#define SHARED_PTR_U8_BUF(buf) \ + (std::shared_ptr(reinterpret_cast(buf), [](uint8_t *p) { delete[](p); })) + +/** + * @brief calculate aligned number + * @param [in]: num: the original number that to aligned + * @param [in]: align: the align factor + * @return the number after aligned + */ +#define ALIGN_UP(num, align) (((num) + (align)-1) & ~((align)-1)) + +/** + * @brief calculate number align with 2 + * @param [in]: num: the original number that to aligned + * @return the number after aligned + */ +#define ALIGN_UP2(num) ALIGN_UP(num, 2) + +/** + * @brief calculate number align with 16 + * @param [in]: num: the original number that to aligned + * @return the number after aligned + */ +#define ALIGN_UP16(num) ALIGN_UP(num, 16) + +/** + * @brief calculate number align with 128 + * @param [in]: num: the original number that to aligned + * @return the number after aligned + */ +#define ALIGN_UP128(num) ALIGN_UP(num, 128) + +/** + * @brief calculate elements num of array + * @param [in]: array: the array variable + * @return elements num of array + */ +#define SIZEOF_ARRAY(array) (sizeof(array) / sizeof(array[0])) + +/** + * @brief Write acl error level log to host log + * @param [in]: fmt: the input format string + * @return none + */ +#define ACLLITE_LOG_ERROR(fmt, ...) \ + do { \ + aclAppLog(ACL_ERROR, __FUNCTION__, __FILE__, __LINE__, fmt, ##__VA_ARGS__); \ + fprintf(stdout, "[ERROR] " fmt "\n", ##__VA_ARGS__); \ + } while (0) + +/** + * @brief Write acl info level log to host log + * @param [in]: fmt: the input format string + * @return none + */ +#define ACLLITE_LOG_INFO(fmt, ...) \ + do { \ + aclAppLog(ACL_INFO, __FUNCTION__, __FILE__, __LINE__, fmt, ##__VA_ARGS__); \ + fprintf(stdout, "[INFO] " fmt "\n", ##__VA_ARGS__); \ + } while (0) + +/** + * @brief Write acl warining level log to host log + * @param [in]: fmt: the input format string + * @return none + */ +#define ACLLITE_LOG_WARNING(fmt, ...) \ + do { \ + aclAppLog(ACL_WARNING, __FUNCTION__, __FILE__, __LINE__, fmt, ##__VA_ARGS__); \ + fprintf(stdout, "[WARNING] " fmt "\n", ##__VA_ARGS__); \ + } while (0) + +/** + * @brief Write acl debug level log to host log + * @param [in]: fmt: the input format string + * @return none + */ +#define ACLLITE_LOG_DEBUG(fmt, ...) \ + do { \ + aclAppLog(ACL_DEBUG, __FUNCTION__, __FILE__, __LINE__, fmt, ##__VA_ARGS__); \ + fprintf(stdout, "[INFO] " fmt "\n", ##__VA_ARGS__); \ + } while (0) + +/** + * @brief define variable record time && + set start time + * @param [X]: function name + * @return X_START X_END + */ +#define TIME_START(X) auto X##_START = std::chrono::steady_clock::now(), X##_END = X##_START + +/** + * @brief set end time + * @param [X]: function name + * @return none + */ +#define TIME_END(X) X##_END = std::chrono::steady_clock::now() + +/** + * @brief calculate time by nanosecond + * @param [X]: function name + * @return none + */ +#define TIME_NSEC(X) std::chrono::duration_cast(X##_END - X##_START).count() + +/** + * @brief show time by nanosecond + * @param [X]: function name + * @return none + */ +#define TIME_NSEC_SHOW(X) cout << "Func " << #X << " cost : " << TIME_NSEC(X) << " ns " << endl + +/** + * @brief calculate time and show by microsecond + * @param [X]: variable name + * @return none + */ +#define TIME_USEC(X) std::chrono::duration_cast(X##_END - X##_START).count() + +/** + * @brief show time by microsecond + * @param [X]: function name + * @return none + */ +#define TIME_USEC_SHOW(X) cout << "Func " << #X << " cost : " << TIME_USEC(X) << " us " << endl + +/** + * @brief calculate time and show by millisecond + * @param [X]: variable name + * @return none + */ +#define TIME_MSEC(X) std::chrono::duration_cast(X##_END - X##_START).count() + +/** + * @brief show time by millisecond + * @param [X]: function name + * @return none + */ +#define TIME_MSEC_SHOW(X) cout << "Func " << #X << " cost : " << TIME_MSEC(X) << " ms " << endl + +/** + * @brief calculate time and show by second + * @param [X]: variable name + * @return none + */ +#define TIME_SEC(X) std::chrono::duration_cast(X##_END - X##_START).count() + +/** + * @brief show time by second + * @param [X]: function name + * @return none + */ +#define TIME_SEC_SHOW(X) cout << "Func " << #X << " cost : " << TIME_SEC(X) << " s " << endl + +/** + * @brief calculate time and show by minute + * @param [X]: variable name + * @return none + */ +#define TIME_MINUTE(X) std::chrono::duration_cast(X##_END - X##_START).count() + +/** + * @brief show time by minute + * @param [X]: function name + * @return none + */ +#define TIME_MINUTE_SHOW(X) cout << "Func " << #X << " cost : " << TIME_MINUTE(X) << " min " << endl + +/** + * @brief calculate time and show by hour + * @param [X]: variable name + * @return none + */ +#define TIME_HOUR(X) std::chrono::duration_cast(X##_END - X##_START).count() + +/** + * @brief show time by hour + * @param [X]: function name + * @return none + */ +#define TIME_HOUR_SHOW(X) cout << "Func " << #X << " cost : " << TIME_HOUR(X) << " h " << endl + +/** + * @brief Recognize the string is a accessable directory or not + * @param [in]: path: the input string + * @return bool true: is directory; false: not directory + */ +bool IsDirectory(const std::string &path); + +/** + * @brief Copy data to device + * @param [in]: data: The data to copy + * @param [in]: size: The data bytes size + * @param [in]: curRunMode: The run mode, get by aclrtGetRunMode, + * Atlas200DK is ACL_DEVICE, Atlas300 is ACL_HOST + * @param [in]: memType: The dest memory type:MEMORY_NORMAL(in Atlas200DK), + * MEMORY_DEVICE, MEMORY_DVPP + * @return void* The dest memory pointer + */ +void *CopyDataToDevice(const void *data, uint32_t size, aclrtRunMode curRunMode, MemoryType memType); + +/** + * @brief Copy data to device buffer + * @param [in]: dest: The device buffer + * @param [in]: destSize: The device buffer size + * @param [in]: src: The data to copy + * @param [in]: srcSize: The data bytes size + * @param [in]: curRunMode: The run mode, get by aclrtGetRunMode, + * Atlas200DK is ACL_DEVICE, Atlas300 is ACL_HOST + * @return AclLiteError ACLLITE_OK: copy success + * others: copy failed + */ +AclLiteError CopyDataToDeviceEx(void *dest, uint32_t destSize, const void *src, uint32_t srcSize, aclrtRunMode runMode); + +/** + * @brief Copy data to host + * @param [in]: data: The data to be copy + * @param [in]: size: The data bytes size + * @param [in]: curRunMode: The run mode, get by aclrtGetRunMode, + * Atlas200DK is ACL_DEVICE, Atlas300 is ACL_HOST + * @param [in]: memType: The dest memory type:MEMORY_NORMAL, MEMORY_HOST + * @return void* The dest memory pointer + */ +void *CopyDataToHost(const void *data, uint32_t size, aclrtRunMode curRunMode, MemoryType memType); + +/** + * @brief Copy data to host buffer + * @param [in]: dest: The host buffer + * @param [in]: destSize: The host buffer size + * @param [in]: src: The data to copy + * @param [in]: srcSize: The data bytes size + * @param [in]: curRunMode: The run mode, get by aclrtGetRunMode, + * Atlas200DK is ACL_DEVICE, Atlas300 is ACL_HOST + * @return AclLiteError ACLLITE_OK: copy success + * others: copy failed + */ +AclLiteError CopyDataToHostEx(void *dest, uint32_t destSize, const void *src, uint32_t srcSize, aclrtRunMode runMode); + +/** + * @brief Copy data to memory + * @param [in]: data: The data to be copy + * @param [in]: size: The data bytes size + * @param [in]: policy: the kind of sync, + * typedef enum aclrtMemcpyKind { + * ACL_MEMCPY_HOST_TO_HOST, // Memory copy from Host to Host + * ACL_MEMCPY_HOST_TO_DEVICE, // Memory copy from Host to Device + * ACL_MEMCPY_DEVICE_TO_HOST, // Memory copy from Device to Host + * ACL_MEMCPY_DEVICE_TO_DEVICE, // Memory copy from Device to Device + * } aclrtMemcpyKind; + * @param [in]: memType: The dest memory type + * @return void* The dest memory pointer + */ +void *CopyData(const void *data, uint32_t size, aclrtMemcpyKind policy, MemoryType memType); + +/** + * @brief Read jpeg image file. Only support baseline, not support progressive + * @param [out]: image: image data read from file. + * @param [in]: fileName: The data bytes size + * @return AclLiteError ACLLITE_OK: read success + * others: read failed + */ +AclLiteError ReadJpeg(ImageData &image, const std::string &fileName); + +/** + * @brief Get all files from file list string + * @param [in]: pathList: files list string, seperate by ',', + * the element could be file path or directory + * @param [in]: fileVec: The data bytes size + * @return AclLiteError ACLLITE_OK: read success + * others: read failed + */ +void GetAllFiles(const std::string &pathList, std::vector &fileVec); + +/** + * @brief Save data to binary file + * @param [in]: filename: binary file name with path + * @param [in]: data: binary data + * @param [in]: size: bytes size of data + * @return AclLiteError ACLLITE_OK: read success + * others: read failed + */ +void SaveBinFile(const std::string &filename, const void *data, uint32_t size); + +/** + * @brief Read binary file to buffer + * @param [in]: fileName: binary file name with path + * @param [in]: data: buffer + * @param [in]: size: buffer size + * @return AclLiteError ACLLITE_OK: read success + * others: read failed + */ +AclLiteError ReadBinFile(const std::string &fileName, void *&data, uint32_t &size); + +/** + * @brief Copy image to memory that malloc by new + * @param [out]: destImage: The image after copy + * @param [in]: srcImage: The image to copy + * @param [in]: curRunMode: The run mode, get by aclrtGetRunMode, + * Atlas200DK is ACL_DEVICE, Atlas300 is ACL_HOST + * @return AclLiteError ACLLITE_OK: read success + * others: read failed + */ +AclLiteError CopyImageToLocal(ImageData &destImage, ImageData &srcImage, aclrtRunMode curRunMode); + +/** + * @brief Copy image to acl device + * @param [out]: destImage: The image after copy + * @param [in]: srcImage: The image to copy + * @param [in]: curRunMode: The run mode, get by aclrtGetRunMode, + * Atlas200DK is ACL_DEVICE, Atlas300 is ACL_HOST + * @param [in]: memType: memory type, dvpp is MEMORY_DVPP, + * device is MEMPRY_DEVICE + * @return AclLiteError ACLLITE_OK: read success + * others: read failed + */ +AclLiteError CopyImageToDevice(ImageData &destImage, ImageData &srcImage, aclrtRunMode curRunMode, MemoryType memType); + +/** + * @brief Match ip address string as <1-255>.<0-255>.<0-255>.<0-255>: + * @param [in]: addrStr: Ip address string + * @return bool true: The input string match success + * false: is not match + */ +bool IsIpAddrWithPort(const std::string &addrStr); + +/** + * @brief Split ip address string <1-255>.<0-255>.<0-255>.<0-255>: to + * ip and port + * @param [out]: ip: Ip address <1-255>.<0-255>.<0-255>.<0-255> + * @param [out]: port: port string + * @param [in]: addr: Ip address string + * @return None + */ +void ParseIpAddr(std::string &ip, std::string &port, const std::string &addr); + +/** + * @brief Judge input string is mp4 file path + * @param [in]: path: file path + * @return bool true: input string is mp4 file path + * false: is not mp4 file path + */ +bool IsVideoFile(const std::string &path); + +/** + * @brief Judge input string is rtsp addr link rtsp:// + * @param [in]: str: input string + * @return bool true: input string is rtsp address + * false: is not rtsp address + */ +bool IsRtspAddr(const std::string &str); + +/** + * @brief Judge input string is digit string + * @param [in]: str: input string + * @return bool true: input string is digit string + * false: is not rtsp address + */ +bool IsDigitStr(const std::string &str); + +/** + * @brief Test file path is exist or not + * @param [in]: path: file path + * @return bool true: file path is exist + * false: is not exist + */ +bool IsPathExist(const std::string &path); + +/** + * @brief read file and save information to config + * @param [out]: config: map, save option information + * @param [in]: configFile: string, file + * @return bool true: read config success + * false: read config fail + */ +bool ReadConfig(std::map &config, const char *configFile); + +/** + * @brief print option information + * @param [in]: m: map, save option information + * @return None + */ +void PrintConfig(const std::map &m); + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_ACL_LITE_UTILS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/CMakeLists.txt index d2b4797ea3c..291f186335e 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/CMakeLists.txt @@ -1,55 +1,55 @@ -file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") -set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) -add_definitions(-DENABLE_DVPP_INTERFACE) - -set(DVPP_UTILS_SRC - # Ascend310 - MDAclProcess.cc - DvppCommon.cc - ErrorCode.cpp - ResourceManager.cc - AclLiteUtils.cc - VdecHelper.cc - dvpp_video.cc - # plugin - acl_plugin.cc - ) -if(NOT MSLITE_ENABLE_ACL) - set(DVPP_UTILS_SRC - ${DVPP_UTILS_SRC} - ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/cxx_api/graph/acl/acl_env_guard.cc - ) -endif() - -if(NOT BUILD_LITE AND ENABLE_D) -set(DVPP_UTILS_SRC - ${DVPP_UTILS_SRC} - # Ascend910B - dvpp_image_utils.cc - ) -endif() - -add_library(dvpp_utils SHARED ${DVPP_UTILS_SRC}) -enable_target_when_only_build_plugins(dvpp_utils) - -if(MSLITE_ENABLE_ACL) - find_library(acl_dvpp libacl_dvpp.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - find_library(acl libascendcl.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - find_library(nnopbase libnnopbase.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - find_library(acl_dvpp_op libacl_dvpp_op.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - find_library(acl_dvpp_mpi libacl_dvpp_mpi.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - # find acl_env_guard in ascend_kernel_plugin - target_link_libraries(dvpp_utils PRIVATE ascend_kernel_plugin minddata-lite ${acl} ${acl_dvpp} mindspore_core ${nnopbase} ${acl_dvpp_op} ${acl_dvpp_mpi}) -else() - find_library(acl_dvpp libacl_dvpp.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - find_library(acl libascendcl.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - find_library(nnopbase libnnopbase.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - find_library(acl_dvpp_op libacl_dvpp_op.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - find_library(acl_dvpp_mpi libacl_dvpp_mpi.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - target_link_libraries(dvpp_utils PRIVATE _c_dataengine ${acl} ${acl_dvpp} mindspore_core mindspore_shared_lib ${nnopbase} ${acl_dvpp_op} ${acl_dvpp_mpi}) -endif() -target_link_libraries(dvpp_utils PRIVATE $) - -if(MSLITE_ENABLE_CLOUD_MIND_DATA) - add_dependencies(dvpp_utils fbs_src) -endif() +file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) +add_definitions(-DENABLE_DVPP_INTERFACE) + +set(DVPP_UTILS_SRC + # Ascend310 + MDAclProcess.cc + DvppCommon.cc + ErrorCode.cpp + ResourceManager.cc + AclLiteUtils.cc + VdecHelper.cc + dvpp_video.cc + # plugin + acl_plugin.cc + ) +if(NOT MSLITE_ENABLE_ACL) + set(DVPP_UTILS_SRC + ${DVPP_UTILS_SRC} + ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/cxx_api/graph/acl/acl_env_guard.cc + ) +endif() + +if(NOT BUILD_LITE AND ENABLE_D) +set(DVPP_UTILS_SRC + ${DVPP_UTILS_SRC} + # Ascend910B + dvpp_image_utils.cc + ) +endif() + +add_library(dvpp_utils SHARED ${DVPP_UTILS_SRC}) +enable_target_when_only_build_plugins(dvpp_utils) + +if(MSLITE_ENABLE_ACL) + find_library(acl_dvpp libacl_dvpp.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) + find_library(acl libascendcl.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) + find_library(nnopbase libnnopbase.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) + find_library(acl_dvpp_op libacl_dvpp_op.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) + find_library(acl_dvpp_mpi libacl_dvpp_mpi.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) + # find acl_env_guard in ascend_kernel_plugin + target_link_libraries(dvpp_utils PRIVATE ascend_kernel_plugin minddata-lite ${acl} ${acl_dvpp} mindspore_core ${nnopbase} ${acl_dvpp_op} ${acl_dvpp_mpi}) +else() + find_library(acl_dvpp libacl_dvpp.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) + find_library(acl libascendcl.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) + find_library(nnopbase libnnopbase.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) + find_library(acl_dvpp_op libacl_dvpp_op.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) + find_library(acl_dvpp_mpi libacl_dvpp_mpi.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) + target_link_libraries(dvpp_utils PRIVATE _c_dataengine ${acl} ${acl_dvpp} mindspore_core mindspore_shared_lib ${nnopbase} ${acl_dvpp_op} ${acl_dvpp_mpi}) +endif() +target_link_libraries(dvpp_utils PRIVATE $) + +if(MSLITE_ENABLE_CLOUD_MIND_DATA) + add_dependencies(dvpp_utils fbs_src) +endif() diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/ErrorCode.cpp b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/ErrorCode.cpp index fb540959b22..ba8caf922ec 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/ErrorCode.cpp +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/ErrorCode.cpp @@ -1,54 +1,54 @@ -/** - * Copyright 2020-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/kernels/image/dvpp/utils/ErrorCode.h" - -#include "minddata/dataset/util/log_adapter.h" - -std::string GetAppErrCodeInfo(const APP_ERROR err) { - if ((err < APP_ERR_ACL_END) && (err >= APP_ERR_ACL_FAILURE)) { - return APP_ERR_ACL_LOG_STRING[((err < 0) ? (err + APP_ERR_ACL_END + 1) : err)]; - } else if ((err < APP_ERR_COMM_END) && (err > APP_ERR_COMM_BASE)) { - return (err - APP_ERR_COMM_BASE) < static_cast(sizeof(APP_ERR_COMMON_LOG_STRING)) / - static_cast(sizeof(APP_ERR_COMMON_LOG_STRING[0])) - ? APP_ERR_COMMON_LOG_STRING[err - APP_ERR_COMM_BASE] - : "Undefine the error code information"; - } else if ((err < APP_ERR_DVPP_END) && (err > APP_ERR_DVPP_BASE)) { - return (err - APP_ERR_DVPP_BASE) < - static_cast(sizeof(APP_ERR_DVPP_LOG_STRING)) / static_cast(sizeof(APP_ERR_DVPP_LOG_STRING[0])) - ? APP_ERR_DVPP_LOG_STRING[err - APP_ERR_DVPP_BASE] - : "Undefine the error code information"; - } else if ((err < APP_ERR_QUEUE_END) && (err > APP_ERR_QUEUE_BASE)) { - return (err - APP_ERR_QUEUE_BASE) < static_cast(sizeof(APP_ERR_QUEUE_LOG_STRING)) / - static_cast(sizeof(APP_ERR_QUEUE_LOG_STRING[0])) - ? APP_ERR_QUEUE_LOG_STRING[err - APP_ERR_QUEUE_BASE] - : "Undefine the error code information"; - } else { - return "Error code unknown"; - } -} - -void AssertErrorCode(int code, const std::string &file, const std::string &function, int line) { - if (code != APP_ERR_OK) { - MS_LOG(ERROR) << "Failed at " << file << "->" << function << "->" << line << ": error code=" << code; - } -} - -void CheckErrorCode(int code, const std::string &file, const std::string &function, int line) { - if (code != APP_ERR_OK) { - MS_LOG(ERROR) << "Failed at " << file << "->" << function << "->" << line << ": error code=" << code; - } -} +/** + * Copyright 2020-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/kernels/image/dvpp/utils/ErrorCode.h" + +#include "minddata/dataset/util/log_adapter.h" + +std::string GetAppErrCodeInfo(const APP_ERROR err) { + if ((err < APP_ERR_ACL_END) && (err >= APP_ERR_ACL_FAILURE)) { + return APP_ERR_ACL_LOG_STRING[((err < 0) ? (err + APP_ERR_ACL_END + 1) : err)]; + } else if ((err < APP_ERR_COMM_END) && (err > APP_ERR_COMM_BASE)) { + return (err - APP_ERR_COMM_BASE) < static_cast(sizeof(APP_ERR_COMMON_LOG_STRING)) / + static_cast(sizeof(APP_ERR_COMMON_LOG_STRING[0])) + ? APP_ERR_COMMON_LOG_STRING[err - APP_ERR_COMM_BASE] + : "Undefine the error code information"; + } else if ((err < APP_ERR_DVPP_END) && (err > APP_ERR_DVPP_BASE)) { + return (err - APP_ERR_DVPP_BASE) < + static_cast(sizeof(APP_ERR_DVPP_LOG_STRING)) / static_cast(sizeof(APP_ERR_DVPP_LOG_STRING[0])) + ? APP_ERR_DVPP_LOG_STRING[err - APP_ERR_DVPP_BASE] + : "Undefine the error code information"; + } else if ((err < APP_ERR_QUEUE_END) && (err > APP_ERR_QUEUE_BASE)) { + return (err - APP_ERR_QUEUE_BASE) < static_cast(sizeof(APP_ERR_QUEUE_LOG_STRING)) / + static_cast(sizeof(APP_ERR_QUEUE_LOG_STRING[0])) + ? APP_ERR_QUEUE_LOG_STRING[err - APP_ERR_QUEUE_BASE] + : "Undefine the error code information"; + } else { + return "Error code unknown"; + } +} + +void AssertErrorCode(int code, const std::string &file, const std::string &function, int line) { + if (code != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed at " << file << "->" << function << "->" << line << ": error code=" << code; + } +} + +void CheckErrorCode(int code, const std::string &file, const std::string &function, int line) { + if (code != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed at " << file << "->" << function << "->" << line << ": error code=" << code; + } +} diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/ErrorCode.h b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/ErrorCode.h index f2f4d4d439b..05f17c2c5de 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/ErrorCode.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/ErrorCode.h @@ -1,291 +1,291 @@ -/** - * Copyright 2020-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_ERROR_CODE_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_ERROR_CODE_H_ - -#include - -using APP_ERROR = int; -// define the data tpye of error code -enum { - APP_ERR_OK = 0, - - // define the error code of ACL model, this is same with the aclError which is - // error code of ACL API Error codes 1~999 are reserved for the ACL. Do not - // add other error codes. Add it after APP_ERR_COMMON_ERR_BASE. - APP_ERR_ACL_FAILURE = -1, // ACL: general error - APP_ERR_ACL_INVALID_PARAM = 1, // ACL: invalid parameter - APP_ERR_ACL_BAD_ALLOC = 2, // ACL: memory allocation fail - APP_ERR_ACL_RT_FAILURE = 3, // ACL: runtime failure - APP_ERR_ACL_GE_FAILURE = 4, // ACL: Graph Engine failure - APP_ERR_ACL_OP_NOT_FOUND = 5, // ACL: operator not found - APP_ERR_ACL_OP_LOAD_FAILED = 6, // ACL: fail to load operator - APP_ERR_ACL_READ_MODEL_FAILURE = 7, // ACL: fail to read model - APP_ERR_ACL_PARSE_MODEL = 8, // ACL: parse model failure - APP_ERR_ACL_MODEL_MISSING_ATTR = 9, // ACL: model missing attribute - APP_ERR_ACL_DESERIALIZE_MODEL = 10, // ACL: deserialize model failure - APP_ERR_ACL_EVENT_NOT_READY = 12, // ACL: event not ready - APP_ERR_ACL_EVENT_COMPLETE = 13, // ACL: event complete - APP_ERR_ACL_UNSUPPORTED_DATA_TYPE = 14, // ACL: unsupported data type - APP_ERR_ACL_REPEAT_INITIALIZE = 15, // ACL: repeat initialize - APP_ERR_ACL_COMPILER_NOT_REGISTERED = 16, // ACL: compiler not registered - APP_ERR_ACL_IO = 17, // ACL: IO failed - APP_ERR_ACL_INVALID_FILE = 18, // ACL: invalid file - APP_ERR_ACL_INVALID_DUMP_CONFIG = 19, // ACL: invalid dump comfig - APP_ERR_ACL_INVALID_PROFILING_CONFIG = 20, // ACL: invalid profiling config - APP_ERR_ACL_OP_TYPE_NOT_MATCH = 21, // ACL: operator type not match - APP_ERR_ACL_OP_INPUT_NOT_MATCH = 22, // ACL: operator input not match - APP_ERR_ACL_OP_OUTPUT_NOT_MATCH = 23, // ACL: operator output not match - APP_ERR_ACL_OP_ATTR_NOT_MATCH = 24, // ACL: operator attribute not match - APP_ERR_ACL_API_NOT_SUPPORT = 25, // ACL: API not support - APP_ERR_ACL_CREATE_DATA_BUF_FAILED = 26, // ACL: create data buffer fail - APP_ERR_ACL_END, // Not an error code, define the range of ACL error code - - // define the common error code, range: 1001~1999 - APP_ERR_COMM_BASE = 1000, - APP_ERR_COMM_FAILURE = APP_ERR_COMM_BASE + 1, // General Failed - APP_ERR_COMM_INNER = APP_ERR_COMM_BASE + 2, // Internal error - APP_ERR_COMM_INVALID_POINTER = APP_ERR_COMM_BASE + 3, // Invalid Pointer - APP_ERR_COMM_INVALID_PARAM = APP_ERR_COMM_BASE + 4, // Invalid parameter - APP_ERR_COMM_UNREALIZED = APP_ERR_COMM_BASE + 5, // Not implemented - APP_ERR_COMM_OUT_OF_MEM = APP_ERR_COMM_BASE + 6, // Out of memory - APP_ERR_COMM_ALLOC_MEM = APP_ERR_COMM_BASE + 7, // memory allocation error - APP_ERR_COMM_FREE_MEM = APP_ERR_COMM_BASE + 8, // free memory error - APP_ERR_COMM_OUT_OF_RANGE = APP_ERR_COMM_BASE + 9, // out of range - APP_ERR_COMM_NO_PERMISSION = APP_ERR_COMM_BASE + 10, // NO Permission - APP_ERR_COMM_TIMEOUT = APP_ERR_COMM_BASE + 11, // Timed out - APP_ERR_COMM_NOT_INIT = APP_ERR_COMM_BASE + 12, // Not initialized - APP_ERR_COMM_INIT_FAIL = APP_ERR_COMM_BASE + 13, // initialize failed - APP_ERR_COMM_INPROGRESS = APP_ERR_COMM_BASE + 14, // Operation now in progress - APP_ERR_COMM_EXIST = APP_ERR_COMM_BASE + 15, // Object, file or other resource already exist - APP_ERR_COMM_NO_EXIST = APP_ERR_COMM_BASE + 16, // Object, file or other resource doesn't exist - APP_ERR_COMM_BUSY = APP_ERR_COMM_BASE + 17, // Object, file or other resource is in use - APP_ERR_COMM_FULL = APP_ERR_COMM_BASE + 18, // No available Device or resource - APP_ERR_COMM_OPEN_FAIL = APP_ERR_COMM_BASE + 19, // Device, file or resource open failed - APP_ERR_COMM_READ_FAIL = APP_ERR_COMM_BASE + 20, // Device, file or resource read failed - APP_ERR_COMM_WRITE_FAIL = APP_ERR_COMM_BASE + 21, // Device, file or resource write failed - APP_ERR_COMM_DESTORY_FAIL = APP_ERR_COMM_BASE + 22, // Device, file or resource destroy failed - APP_ERR_COMM_EXIT = APP_ERR_COMM_BASE + 23, // End of data stream, stop the application - APP_ERR_COMM_CONNECTION_CLOSE = APP_ERR_COMM_BASE + 24, // Out of connection, Communication shutdown - APP_ERR_COMM_CONNECTION_FAILURE = APP_ERR_COMM_BASE + 25, // connection fail - APP_ERR_COMM_STREAM_INVALID = APP_ERR_COMM_BASE + 26, // ACL stream is null pointer - APP_ERR_COMM_END, // Not an error code, define the range of common error code - - // define the error code of DVPP - APP_ERR_DVPP_BASE = 2000, - APP_ERR_DVPP_CROP_FAIL = APP_ERR_DVPP_BASE + 1, // DVPP: crop fail - APP_ERR_DVPP_RESIZE_FAIL = APP_ERR_DVPP_BASE + 2, // DVPP: resize fail - APP_ERR_DVPP_CROP_RESIZE_FAIL = APP_ERR_DVPP_BASE + 3, // DVPP: corp and resize fail - APP_ERR_DVPP_CONVERT_FROMAT_FAIL = APP_ERR_DVPP_BASE + 4, // DVPP: convert image format fail - APP_ERR_DVPP_VPC_FAIL = APP_ERR_DVPP_BASE + 5, // DVPP: VPC(crop, resize, convert format) fail - APP_ERR_DVPP_JPEG_DECODE_FAIL = APP_ERR_DVPP_BASE + 6, // DVPP: decode jpeg or jpg fail - APP_ERR_DVPP_JPEG_ENCODE_FAIL = APP_ERR_DVPP_BASE + 7, // DVPP: encode jpeg or jpg fail - APP_ERR_DVPP_PNG_DECODE_FAIL = APP_ERR_DVPP_BASE + 8, // DVPP: encode png fail - APP_ERR_DVPP_H26X_DECODE_FAIL = APP_ERR_DVPP_BASE + 9, // DVPP: decode H264 or H265 fail - APP_ERR_DVPP_H26X_ENCODE_FAIL = APP_ERR_DVPP_BASE + 10, // DVPP: encode H264 or H265 fail - APP_ERR_DVPP_HANDLE_NULL = APP_ERR_DVPP_BASE + 11, // DVPP: acldvppChannelDesc is nullptr - APP_ERR_DVPP_PICDESC_FAIL = APP_ERR_DVPP_BASE + 12, // DVPP: fail to create acldvppCreatePicDesc or - // fail to set acldvppCreatePicDesc - APP_ERR_DVPP_CONFIG_FAIL = APP_ERR_DVPP_BASE + 13, // DVPP: fail to set dvpp configuration,such as - // resize configuration,crop configuration - APP_ERR_DVPP_OBJ_FUNC_MISMATCH = APP_ERR_DVPP_BASE + 14, // DVPP: DvppCommon object mismatch the function - APP_ERR_DVPP_NORMALIZE_FAIL = APP_ERR_DVPP_BASE + 15, // DVPP: normalize fail - APP_ERR_DVPP_ADJUST_BRIGHTNESS_FAIL = APP_ERR_DVPP_BASE + 16, // DVPP: adjust brightness fail - APP_ERR_DVPP_ADJUST_CONTRAST_FAIL = APP_ERR_DVPP_BASE + 17, // DVPP: adjust contrast fail - APP_ERR_DVPP_ADJUST_HUE_FAIL = APP_ERR_DVPP_BASE + 18, // DVPP: adjust hue fail - APP_ERR_DVPP_ADJUST_SATURATION_FAIL = APP_ERR_DVPP_BASE + 19, // DVPP: adjust saturation fail - APP_ERR_DVPP_HORIZONTAL_FLIP_FAIL = APP_ERR_DVPP_BASE + 20, // DVPP: Horizontal Flip - APP_ERR_DVPP_VERTICAL_FLIP_FAIL = APP_ERR_DVPP_BASE + 21, // DVPP: vertical Flip - APP_ERR_DVPP_PERSPECTIVE_FAIL = APP_ERR_DVPP_BASE + 22, // DVPP: perspective fail - APP_ERR_DVPP_RESIZED_CROP_FAIL = APP_ERR_DVPP_BASE + 23, // DVPP: crop and resize fail - APP_ERR_DVPP_PAD_FAIL = APP_ERR_DVPP_BASE + 24, // DVPP: pad fail - APP_ERR_DVPP_AFFINE_FAIL = APP_ERR_DVPP_BASE + 25, // DVPP: affine fail - APP_ERR_DVPP_GAUSSIAN_BLUR_FAIL = APP_ERR_DVPP_BASE + 26, // DVPP: gaussian blur fail - APP_ERR_DVPP_EQUALIZE_FAIL = APP_ERR_DVPP_BASE + 27, // DVPP: equalize blur fail - APP_ERR_DVPP_ROTATE_FAIL = APP_ERR_DVPP_BASE + 28, // DVPP: rotate fail - APP_ERR_DVPP_AUTO_CONTRAST_FAIL = APP_ERR_DVPP_BASE + 29, // DVPP: auto contrast fail - APP_ERR_DVPP_POSTERIZE_FAIL = APP_ERR_DVPP_BASE + 30, // DVPP: posterize fail - APP_ERR_DVPP_ADJUST_SHARPNESS_FAIL = APP_ERR_DVPP_BASE + 31, // DVPP: adjust sharpness fail - APP_ERR_DVPP_INVERT_FAIL = APP_ERR_DVPP_BASE + 32, // DVPP: invert fail - APP_ERR_DVPP_SOLARIZE_FAIL = APP_ERR_DVPP_BASE + 33, // DVPP: solarize fail - APP_ERR_DVPP_CONVERT_COLOR_FAIL = APP_ERR_DVPP_BASE + 34, // DVPP: convert color fail - APP_ERR_DVPP_ERASE_FAIL = APP_ERR_DVPP_BASE + 35, // DVPP: erase fail - APP_ERR_DVPP_END, // Not an error code, define the range of common error code - - // define the error code of inference - APP_ERR_INFER_BASE = 3000, - APP_ERR_INFER_SET_INPUT_FAIL = APP_ERR_INFER_BASE + 1, // Infer: set input fail - APP_ERR_INFER_SET_OUTPUT_FAIL = APP_ERR_INFER_BASE + 2, // Infer: set output fail - APP_ERR_INFER_CREATE_OUTPUT_FAIL = APP_ERR_INFER_BASE + 3, // Infer: create output fail - APP_ERR_INFER_OP_SET_ATTR_FAIL = APP_ERR_INFER_BASE + 4, // Infer: set op attribute fail - APP_ERR_INFER_GET_OUTPUT_FAIL = APP_ERR_INFER_BASE + 5, // Infer: get model output fail - APP_ERR_INFER_FIND_MODEL_ID_FAIL = APP_ERR_INFER_BASE + 6, // Infer: find model id fail - APP_ERR_INFER_FIND_MODEL_DESC_FAIL = APP_ERR_INFER_BASE + 7, // Infer: find model description fail - APP_ERR_INFER_FIND_MODEL_MEM_FAIL = APP_ERR_INFER_BASE + 8, // Infer: find model memory fail - APP_ERR_INFER_FIND_MODEL_WEIGHT_FAIL = APP_ERR_INFER_BASE + 9, // Infer: find model weight fail - - APP_ERR_INFER_END, // Not an error code, define the range of inference error - // code - - // define the error code of transmission - APP_ERR_TRANS_BASE = 4000, - - APP_ERR_TRANS_END, // Not an error code, define the range of transmission - // error code - - // define the error code of blocking queue - APP_ERR_QUEUE_BASE = 5000, - APP_ERR_QUEUE_EMPTY = APP_ERR_QUEUE_BASE + 1, // Queue: empty queue - APP_ERR_QUEUE_STOPED = APP_ERR_QUEUE_BASE + 2, // Queue: queue stopped - APP_ERROR_QUEUE_FULL = APP_ERR_QUEUE_BASE + 3, // Queue: full queue - - // define the error code of destory - APP_ERR_DESTORY_BASE = 6000, - APP_ERR_DESTORY_TENSOR = APP_ERR_DESTORY_BASE + 1, - APP_ERR_DESTORY_SCALAR = APP_ERR_DESTORY_BASE + 2, - APP_ERR_DESTORY_INT_ARRAY = APP_ERR_DESTORY_BASE + 3, - APP_ERR_DESTORY_FLOAT_ARRAY = APP_ERR_DESTORY_BASE + 4, - APP_ERR_DESTORY_BOOL_ARRAY = APP_ERR_DESTORY_BASE + 5, - APP_ERR_DESTORY_TENSOR_LIST = APP_ERR_DESTORY_BASE + 6, - APP_ERR_DESTORY_SCALAR_LIST = APP_ERR_DESTORY_BASE + 7, - - // define the idrecognition web error code - APP_ERROR_FACE_WEB_USE_BASE = 10000, - APP_ERROR_FACE_WEB_USE_SYSTEM_ERROR = APP_ERROR_FACE_WEB_USE_BASE + 1, // Web: system error - APP_ERROR_FACE_WEB_USE_MUL_FACE = APP_ERROR_FACE_WEB_USE_BASE + 2, // Web: multiple cheeks - APP_ERROR_FACE_WEB_USE_REPEAT_REG = APP_ERROR_FACE_WEB_USE_BASE + 3, // Web: repeat registration - APP_ERROR_FACE_WEB_USE_PART_SUCCESS = APP_ERROR_FACE_WEB_USE_BASE + 4, // Web: partial search succeeded - APP_ERROR_FACE_WEB_USE_NO_FACE = APP_ERROR_FACE_WEB_USE_BASE + 5, // Web: no cheek detected - APP_ERR_QUEUE_END, // Not an error code, define the range of blocking queue - // error code -}; - -const std::string APP_ERR_ACL_LOG_STRING[] = { - "Success", // APP_ERR_OK - "ACL: invalid parameter", // APP_ERR_ACL_INVALID_PARAM - "ACL: memory allocation fail", // APP_ERR_ACL_BAD_ALLOC - "ACL: runtime failure", // APP_ERR_ACL_RT_FAILURE - "ACL: Graph Engine failure", // APP_ERR_ACL_GE_FAILURE - "ACL: operator not found", // APP_ERR_ACL_OP_NOT_FOUND - "ACL: fail to load operator", // APP_ERR_ACL_OP_LOAD_FAILED - "ACL: fail to read model", // APP_ERR_ACL_READ_MODEL_FAILURE - "ACL: parse model failure", // APP_ERR_ACL_PARSE_MODEL - "ACL: model missing attribute", // APP_ERR_ACL_MODEL_MISSING_ATTR - "ACL: deserialize model failure", // APP_ERR_ACL_DESERIALIZE_MODEL - "Placeholder", // 11 - "ACL: event not ready", // APP_ERR_ACL_EVENT_NOT_READY - "ACL: event complete", // APP_ERR_ACL_EVENT_COMPLETE - "ACL: unsupported data type", // APP_ERR_ACL_UNSUPPORTED_DATA_TYPE - "ACL: repeat initialize", // APP_ERR_ACL_REPEAT_INITIALIZE - "ACL: compiler not registered", // APP_ERR_ACL_COMPILER_NOT_REGISTERED - "ACL: IO failed", // APP_ERR_ACL_IO - "ACL: invalid file", // APP_ERR_ACL_INVALID_FILE - "ACL: invalid dump comfig", // APP_ERR_ACL_INVALID_DUMP_CONFIG - "ACL: invalid profiling config", // APP_ERR_ACL_INVALID_PROFILING_CONFIG - "ACL: operator type not match", // APP_ERR_ACL_OP_TYPE_NOT_MATCH - "ACL: operator input not match", // APP_ERR_ACL_OP_INPUT_NOT_MATCH - "ACL: operator output not match", // APP_ERR_ACL_OP_OUTPUT_NOT_MATCH - "ACL: operator attribute not match", // APP_ERR_ACL_OP_ATTR_NOT_MATCH - "ACL: API not supported", // APP_ERR_ACL_API_NOT_SUPPORT - "ACL: create data buffer fail", // APP_ERR_ACL_CREATE_DATA_BUF_FAILED -}; - -const std::string APP_ERR_COMMON_LOG_STRING[] = { - "Placeholder", // 0 - "General Failed", - "Internal error", - "Invalid Pointer", // 3 - "Invalid parameter", - "Not implemented", // 5 - "Out of memory", - "memory allocation error", - "free memory error", // 8 - "out of range", - "NO Permission ", // 10 - "Timed out", - "Not initialized", - "initialize failed", // 13 - "Operation now in progress ", - "Object, file or other resource already exist", // 15 - "Object, file or other resource already doesn't exist", - "Object, file or other resource is in use", - "No available Device or resource", // 18 - "Device, file or resource open failed", - "Device, file or resource read failed", // 20 - "Device, file or resource write failed", - "Device, file or resource destory failed", // 22 - " ", - "Out of connection, Communication shutdown", // 24 - "connection fail", - "ACL stream is null pointer", // 26 -}; - -const std::string APP_ERR_DVPP_LOG_STRING[] = { - "Placeholder", // 0 - "DVPP: crop fail", - "DVPP: resize fail", - "DVPP: corp and resize fail", - "DVPP: convert image format fail", - "DVPP: VPC(crop, resize, convert format) fail", // 5 - "DVPP: decode jpeg or jpg fail", - "DVPP: encode jpeg or jpg fail", - "DVPP: encode png fail", - "DVPP: decode H264 or H265 fail", - "DVPP: encode H264 or H265 fail", // 10 - "DVPP: acldvppChannelDesc is nullptr", - "DVPP: fail to create or set acldvppCreatePicDesc", - "DVPP: fail to set dvpp configuration", - "DVPP: DvppCommon object mismatch the function", // 14 -}; - -const std::string APP_ERR_INFER_LOG_STRING[] = { - "Placeholder", // 0 - "Infer: set input fail", - "Infer: set output fail", - "Infer: create output fail", - "Infer: set op attribute fail", - "Infer: get model output fail", // 5 - "Infer: find model id fail", - "Infer: find model description fail", - "Infer: find model memory fail", - "Infer: find model weight fail", // 9 -}; - -const std::string APP_ERR_QUEUE_LOG_STRING[] = { - "Placeholder", - "empty queue", - "queue stopped", - "full queue", -}; - -const std::string APP_ERR_FACE_LOG_STRING[] = { - "Placeholder", // 0 - "system error", // 1 - "multiple faces", // 2 - "repeat registration", // 3 - "partial search succeeded", // 4 - "no face detected", // 5 -}; - -std::string GetAppErrCodeInfo(APP_ERROR err); -void AssertErrorCode(int code, const std::string &file, const std::string &function, int line); -void CheckErrorCode(int code, const std::string &file, const std::string &function, int line); - -#define RtAssert(code) AssertErrorCode(code, DATASET_SRC_FILE_NAME, __FUNCTION__, __LINE__); -#define RtCheckError(code) CheckErrorCode(code, DATASET_SRC_FILE_NAME, __FUNCTION__, __LINE__); - -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_ERROR_CODE_H_ +/** + * Copyright 2020-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_ERROR_CODE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_ERROR_CODE_H_ + +#include + +using APP_ERROR = int; +// define the data tpye of error code +enum { + APP_ERR_OK = 0, + + // define the error code of ACL model, this is same with the aclError which is + // error code of ACL API Error codes 1~999 are reserved for the ACL. Do not + // add other error codes. Add it after APP_ERR_COMMON_ERR_BASE. + APP_ERR_ACL_FAILURE = -1, // ACL: general error + APP_ERR_ACL_INVALID_PARAM = 1, // ACL: invalid parameter + APP_ERR_ACL_BAD_ALLOC = 2, // ACL: memory allocation fail + APP_ERR_ACL_RT_FAILURE = 3, // ACL: runtime failure + APP_ERR_ACL_GE_FAILURE = 4, // ACL: Graph Engine failure + APP_ERR_ACL_OP_NOT_FOUND = 5, // ACL: operator not found + APP_ERR_ACL_OP_LOAD_FAILED = 6, // ACL: fail to load operator + APP_ERR_ACL_READ_MODEL_FAILURE = 7, // ACL: fail to read model + APP_ERR_ACL_PARSE_MODEL = 8, // ACL: parse model failure + APP_ERR_ACL_MODEL_MISSING_ATTR = 9, // ACL: model missing attribute + APP_ERR_ACL_DESERIALIZE_MODEL = 10, // ACL: deserialize model failure + APP_ERR_ACL_EVENT_NOT_READY = 12, // ACL: event not ready + APP_ERR_ACL_EVENT_COMPLETE = 13, // ACL: event complete + APP_ERR_ACL_UNSUPPORTED_DATA_TYPE = 14, // ACL: unsupported data type + APP_ERR_ACL_REPEAT_INITIALIZE = 15, // ACL: repeat initialize + APP_ERR_ACL_COMPILER_NOT_REGISTERED = 16, // ACL: compiler not registered + APP_ERR_ACL_IO = 17, // ACL: IO failed + APP_ERR_ACL_INVALID_FILE = 18, // ACL: invalid file + APP_ERR_ACL_INVALID_DUMP_CONFIG = 19, // ACL: invalid dump comfig + APP_ERR_ACL_INVALID_PROFILING_CONFIG = 20, // ACL: invalid profiling config + APP_ERR_ACL_OP_TYPE_NOT_MATCH = 21, // ACL: operator type not match + APP_ERR_ACL_OP_INPUT_NOT_MATCH = 22, // ACL: operator input not match + APP_ERR_ACL_OP_OUTPUT_NOT_MATCH = 23, // ACL: operator output not match + APP_ERR_ACL_OP_ATTR_NOT_MATCH = 24, // ACL: operator attribute not match + APP_ERR_ACL_API_NOT_SUPPORT = 25, // ACL: API not support + APP_ERR_ACL_CREATE_DATA_BUF_FAILED = 26, // ACL: create data buffer fail + APP_ERR_ACL_END, // Not an error code, define the range of ACL error code + + // define the common error code, range: 1001~1999 + APP_ERR_COMM_BASE = 1000, + APP_ERR_COMM_FAILURE = APP_ERR_COMM_BASE + 1, // General Failed + APP_ERR_COMM_INNER = APP_ERR_COMM_BASE + 2, // Internal error + APP_ERR_COMM_INVALID_POINTER = APP_ERR_COMM_BASE + 3, // Invalid Pointer + APP_ERR_COMM_INVALID_PARAM = APP_ERR_COMM_BASE + 4, // Invalid parameter + APP_ERR_COMM_UNREALIZED = APP_ERR_COMM_BASE + 5, // Not implemented + APP_ERR_COMM_OUT_OF_MEM = APP_ERR_COMM_BASE + 6, // Out of memory + APP_ERR_COMM_ALLOC_MEM = APP_ERR_COMM_BASE + 7, // memory allocation error + APP_ERR_COMM_FREE_MEM = APP_ERR_COMM_BASE + 8, // free memory error + APP_ERR_COMM_OUT_OF_RANGE = APP_ERR_COMM_BASE + 9, // out of range + APP_ERR_COMM_NO_PERMISSION = APP_ERR_COMM_BASE + 10, // NO Permission + APP_ERR_COMM_TIMEOUT = APP_ERR_COMM_BASE + 11, // Timed out + APP_ERR_COMM_NOT_INIT = APP_ERR_COMM_BASE + 12, // Not initialized + APP_ERR_COMM_INIT_FAIL = APP_ERR_COMM_BASE + 13, // initialize failed + APP_ERR_COMM_INPROGRESS = APP_ERR_COMM_BASE + 14, // Operation now in progress + APP_ERR_COMM_EXIST = APP_ERR_COMM_BASE + 15, // Object, file or other resource already exist + APP_ERR_COMM_NO_EXIST = APP_ERR_COMM_BASE + 16, // Object, file or other resource doesn't exist + APP_ERR_COMM_BUSY = APP_ERR_COMM_BASE + 17, // Object, file or other resource is in use + APP_ERR_COMM_FULL = APP_ERR_COMM_BASE + 18, // No available Device or resource + APP_ERR_COMM_OPEN_FAIL = APP_ERR_COMM_BASE + 19, // Device, file or resource open failed + APP_ERR_COMM_READ_FAIL = APP_ERR_COMM_BASE + 20, // Device, file or resource read failed + APP_ERR_COMM_WRITE_FAIL = APP_ERR_COMM_BASE + 21, // Device, file or resource write failed + APP_ERR_COMM_DESTORY_FAIL = APP_ERR_COMM_BASE + 22, // Device, file or resource destroy failed + APP_ERR_COMM_EXIT = APP_ERR_COMM_BASE + 23, // End of data stream, stop the application + APP_ERR_COMM_CONNECTION_CLOSE = APP_ERR_COMM_BASE + 24, // Out of connection, Communication shutdown + APP_ERR_COMM_CONNECTION_FAILURE = APP_ERR_COMM_BASE + 25, // connection fail + APP_ERR_COMM_STREAM_INVALID = APP_ERR_COMM_BASE + 26, // ACL stream is null pointer + APP_ERR_COMM_END, // Not an error code, define the range of common error code + + // define the error code of DVPP + APP_ERR_DVPP_BASE = 2000, + APP_ERR_DVPP_CROP_FAIL = APP_ERR_DVPP_BASE + 1, // DVPP: crop fail + APP_ERR_DVPP_RESIZE_FAIL = APP_ERR_DVPP_BASE + 2, // DVPP: resize fail + APP_ERR_DVPP_CROP_RESIZE_FAIL = APP_ERR_DVPP_BASE + 3, // DVPP: corp and resize fail + APP_ERR_DVPP_CONVERT_FROMAT_FAIL = APP_ERR_DVPP_BASE + 4, // DVPP: convert image format fail + APP_ERR_DVPP_VPC_FAIL = APP_ERR_DVPP_BASE + 5, // DVPP: VPC(crop, resize, convert format) fail + APP_ERR_DVPP_JPEG_DECODE_FAIL = APP_ERR_DVPP_BASE + 6, // DVPP: decode jpeg or jpg fail + APP_ERR_DVPP_JPEG_ENCODE_FAIL = APP_ERR_DVPP_BASE + 7, // DVPP: encode jpeg or jpg fail + APP_ERR_DVPP_PNG_DECODE_FAIL = APP_ERR_DVPP_BASE + 8, // DVPP: encode png fail + APP_ERR_DVPP_H26X_DECODE_FAIL = APP_ERR_DVPP_BASE + 9, // DVPP: decode H264 or H265 fail + APP_ERR_DVPP_H26X_ENCODE_FAIL = APP_ERR_DVPP_BASE + 10, // DVPP: encode H264 or H265 fail + APP_ERR_DVPP_HANDLE_NULL = APP_ERR_DVPP_BASE + 11, // DVPP: acldvppChannelDesc is nullptr + APP_ERR_DVPP_PICDESC_FAIL = APP_ERR_DVPP_BASE + 12, // DVPP: fail to create acldvppCreatePicDesc or + // fail to set acldvppCreatePicDesc + APP_ERR_DVPP_CONFIG_FAIL = APP_ERR_DVPP_BASE + 13, // DVPP: fail to set dvpp configuration,such as + // resize configuration,crop configuration + APP_ERR_DVPP_OBJ_FUNC_MISMATCH = APP_ERR_DVPP_BASE + 14, // DVPP: DvppCommon object mismatch the function + APP_ERR_DVPP_NORMALIZE_FAIL = APP_ERR_DVPP_BASE + 15, // DVPP: normalize fail + APP_ERR_DVPP_ADJUST_BRIGHTNESS_FAIL = APP_ERR_DVPP_BASE + 16, // DVPP: adjust brightness fail + APP_ERR_DVPP_ADJUST_CONTRAST_FAIL = APP_ERR_DVPP_BASE + 17, // DVPP: adjust contrast fail + APP_ERR_DVPP_ADJUST_HUE_FAIL = APP_ERR_DVPP_BASE + 18, // DVPP: adjust hue fail + APP_ERR_DVPP_ADJUST_SATURATION_FAIL = APP_ERR_DVPP_BASE + 19, // DVPP: adjust saturation fail + APP_ERR_DVPP_HORIZONTAL_FLIP_FAIL = APP_ERR_DVPP_BASE + 20, // DVPP: Horizontal Flip + APP_ERR_DVPP_VERTICAL_FLIP_FAIL = APP_ERR_DVPP_BASE + 21, // DVPP: vertical Flip + APP_ERR_DVPP_PERSPECTIVE_FAIL = APP_ERR_DVPP_BASE + 22, // DVPP: perspective fail + APP_ERR_DVPP_RESIZED_CROP_FAIL = APP_ERR_DVPP_BASE + 23, // DVPP: crop and resize fail + APP_ERR_DVPP_PAD_FAIL = APP_ERR_DVPP_BASE + 24, // DVPP: pad fail + APP_ERR_DVPP_AFFINE_FAIL = APP_ERR_DVPP_BASE + 25, // DVPP: affine fail + APP_ERR_DVPP_GAUSSIAN_BLUR_FAIL = APP_ERR_DVPP_BASE + 26, // DVPP: gaussian blur fail + APP_ERR_DVPP_EQUALIZE_FAIL = APP_ERR_DVPP_BASE + 27, // DVPP: equalize blur fail + APP_ERR_DVPP_ROTATE_FAIL = APP_ERR_DVPP_BASE + 28, // DVPP: rotate fail + APP_ERR_DVPP_AUTO_CONTRAST_FAIL = APP_ERR_DVPP_BASE + 29, // DVPP: auto contrast fail + APP_ERR_DVPP_POSTERIZE_FAIL = APP_ERR_DVPP_BASE + 30, // DVPP: posterize fail + APP_ERR_DVPP_ADJUST_SHARPNESS_FAIL = APP_ERR_DVPP_BASE + 31, // DVPP: adjust sharpness fail + APP_ERR_DVPP_INVERT_FAIL = APP_ERR_DVPP_BASE + 32, // DVPP: invert fail + APP_ERR_DVPP_SOLARIZE_FAIL = APP_ERR_DVPP_BASE + 33, // DVPP: solarize fail + APP_ERR_DVPP_CONVERT_COLOR_FAIL = APP_ERR_DVPP_BASE + 34, // DVPP: convert color fail + APP_ERR_DVPP_ERASE_FAIL = APP_ERR_DVPP_BASE + 35, // DVPP: erase fail + APP_ERR_DVPP_END, // Not an error code, define the range of common error code + + // define the error code of inference + APP_ERR_INFER_BASE = 3000, + APP_ERR_INFER_SET_INPUT_FAIL = APP_ERR_INFER_BASE + 1, // Infer: set input fail + APP_ERR_INFER_SET_OUTPUT_FAIL = APP_ERR_INFER_BASE + 2, // Infer: set output fail + APP_ERR_INFER_CREATE_OUTPUT_FAIL = APP_ERR_INFER_BASE + 3, // Infer: create output fail + APP_ERR_INFER_OP_SET_ATTR_FAIL = APP_ERR_INFER_BASE + 4, // Infer: set op attribute fail + APP_ERR_INFER_GET_OUTPUT_FAIL = APP_ERR_INFER_BASE + 5, // Infer: get model output fail + APP_ERR_INFER_FIND_MODEL_ID_FAIL = APP_ERR_INFER_BASE + 6, // Infer: find model id fail + APP_ERR_INFER_FIND_MODEL_DESC_FAIL = APP_ERR_INFER_BASE + 7, // Infer: find model description fail + APP_ERR_INFER_FIND_MODEL_MEM_FAIL = APP_ERR_INFER_BASE + 8, // Infer: find model memory fail + APP_ERR_INFER_FIND_MODEL_WEIGHT_FAIL = APP_ERR_INFER_BASE + 9, // Infer: find model weight fail + + APP_ERR_INFER_END, // Not an error code, define the range of inference error + // code + + // define the error code of transmission + APP_ERR_TRANS_BASE = 4000, + + APP_ERR_TRANS_END, // Not an error code, define the range of transmission + // error code + + // define the error code of blocking queue + APP_ERR_QUEUE_BASE = 5000, + APP_ERR_QUEUE_EMPTY = APP_ERR_QUEUE_BASE + 1, // Queue: empty queue + APP_ERR_QUEUE_STOPED = APP_ERR_QUEUE_BASE + 2, // Queue: queue stopped + APP_ERROR_QUEUE_FULL = APP_ERR_QUEUE_BASE + 3, // Queue: full queue + + // define the error code of destory + APP_ERR_DESTORY_BASE = 6000, + APP_ERR_DESTORY_TENSOR = APP_ERR_DESTORY_BASE + 1, + APP_ERR_DESTORY_SCALAR = APP_ERR_DESTORY_BASE + 2, + APP_ERR_DESTORY_INT_ARRAY = APP_ERR_DESTORY_BASE + 3, + APP_ERR_DESTORY_FLOAT_ARRAY = APP_ERR_DESTORY_BASE + 4, + APP_ERR_DESTORY_BOOL_ARRAY = APP_ERR_DESTORY_BASE + 5, + APP_ERR_DESTORY_TENSOR_LIST = APP_ERR_DESTORY_BASE + 6, + APP_ERR_DESTORY_SCALAR_LIST = APP_ERR_DESTORY_BASE + 7, + + // define the idrecognition web error code + APP_ERROR_FACE_WEB_USE_BASE = 10000, + APP_ERROR_FACE_WEB_USE_SYSTEM_ERROR = APP_ERROR_FACE_WEB_USE_BASE + 1, // Web: system error + APP_ERROR_FACE_WEB_USE_MUL_FACE = APP_ERROR_FACE_WEB_USE_BASE + 2, // Web: multiple cheeks + APP_ERROR_FACE_WEB_USE_REPEAT_REG = APP_ERROR_FACE_WEB_USE_BASE + 3, // Web: repeat registration + APP_ERROR_FACE_WEB_USE_PART_SUCCESS = APP_ERROR_FACE_WEB_USE_BASE + 4, // Web: partial search succeeded + APP_ERROR_FACE_WEB_USE_NO_FACE = APP_ERROR_FACE_WEB_USE_BASE + 5, // Web: no cheek detected + APP_ERR_QUEUE_END, // Not an error code, define the range of blocking queue + // error code +}; + +const std::string APP_ERR_ACL_LOG_STRING[] = { + "Success", // APP_ERR_OK + "ACL: invalid parameter", // APP_ERR_ACL_INVALID_PARAM + "ACL: memory allocation fail", // APP_ERR_ACL_BAD_ALLOC + "ACL: runtime failure", // APP_ERR_ACL_RT_FAILURE + "ACL: Graph Engine failure", // APP_ERR_ACL_GE_FAILURE + "ACL: operator not found", // APP_ERR_ACL_OP_NOT_FOUND + "ACL: fail to load operator", // APP_ERR_ACL_OP_LOAD_FAILED + "ACL: fail to read model", // APP_ERR_ACL_READ_MODEL_FAILURE + "ACL: parse model failure", // APP_ERR_ACL_PARSE_MODEL + "ACL: model missing attribute", // APP_ERR_ACL_MODEL_MISSING_ATTR + "ACL: deserialize model failure", // APP_ERR_ACL_DESERIALIZE_MODEL + "Placeholder", // 11 + "ACL: event not ready", // APP_ERR_ACL_EVENT_NOT_READY + "ACL: event complete", // APP_ERR_ACL_EVENT_COMPLETE + "ACL: unsupported data type", // APP_ERR_ACL_UNSUPPORTED_DATA_TYPE + "ACL: repeat initialize", // APP_ERR_ACL_REPEAT_INITIALIZE + "ACL: compiler not registered", // APP_ERR_ACL_COMPILER_NOT_REGISTERED + "ACL: IO failed", // APP_ERR_ACL_IO + "ACL: invalid file", // APP_ERR_ACL_INVALID_FILE + "ACL: invalid dump comfig", // APP_ERR_ACL_INVALID_DUMP_CONFIG + "ACL: invalid profiling config", // APP_ERR_ACL_INVALID_PROFILING_CONFIG + "ACL: operator type not match", // APP_ERR_ACL_OP_TYPE_NOT_MATCH + "ACL: operator input not match", // APP_ERR_ACL_OP_INPUT_NOT_MATCH + "ACL: operator output not match", // APP_ERR_ACL_OP_OUTPUT_NOT_MATCH + "ACL: operator attribute not match", // APP_ERR_ACL_OP_ATTR_NOT_MATCH + "ACL: API not supported", // APP_ERR_ACL_API_NOT_SUPPORT + "ACL: create data buffer fail", // APP_ERR_ACL_CREATE_DATA_BUF_FAILED +}; + +const std::string APP_ERR_COMMON_LOG_STRING[] = { + "Placeholder", // 0 + "General Failed", + "Internal error", + "Invalid Pointer", // 3 + "Invalid parameter", + "Not implemented", // 5 + "Out of memory", + "memory allocation error", + "free memory error", // 8 + "out of range", + "NO Permission ", // 10 + "Timed out", + "Not initialized", + "initialize failed", // 13 + "Operation now in progress ", + "Object, file or other resource already exist", // 15 + "Object, file or other resource already doesn't exist", + "Object, file or other resource is in use", + "No available Device or resource", // 18 + "Device, file or resource open failed", + "Device, file or resource read failed", // 20 + "Device, file or resource write failed", + "Device, file or resource destory failed", // 22 + " ", + "Out of connection, Communication shutdown", // 24 + "connection fail", + "ACL stream is null pointer", // 26 +}; + +const std::string APP_ERR_DVPP_LOG_STRING[] = { + "Placeholder", // 0 + "DVPP: crop fail", + "DVPP: resize fail", + "DVPP: corp and resize fail", + "DVPP: convert image format fail", + "DVPP: VPC(crop, resize, convert format) fail", // 5 + "DVPP: decode jpeg or jpg fail", + "DVPP: encode jpeg or jpg fail", + "DVPP: encode png fail", + "DVPP: decode H264 or H265 fail", + "DVPP: encode H264 or H265 fail", // 10 + "DVPP: acldvppChannelDesc is nullptr", + "DVPP: fail to create or set acldvppCreatePicDesc", + "DVPP: fail to set dvpp configuration", + "DVPP: DvppCommon object mismatch the function", // 14 +}; + +const std::string APP_ERR_INFER_LOG_STRING[] = { + "Placeholder", // 0 + "Infer: set input fail", + "Infer: set output fail", + "Infer: create output fail", + "Infer: set op attribute fail", + "Infer: get model output fail", // 5 + "Infer: find model id fail", + "Infer: find model description fail", + "Infer: find model memory fail", + "Infer: find model weight fail", // 9 +}; + +const std::string APP_ERR_QUEUE_LOG_STRING[] = { + "Placeholder", + "empty queue", + "queue stopped", + "full queue", +}; + +const std::string APP_ERR_FACE_LOG_STRING[] = { + "Placeholder", // 0 + "system error", // 1 + "multiple faces", // 2 + "repeat registration", // 3 + "partial search succeeded", // 4 + "no face detected", // 5 +}; + +std::string GetAppErrCodeInfo(APP_ERROR err); +void AssertErrorCode(int code, const std::string &file, const std::string &function, int line); +void CheckErrorCode(int code, const std::string &file, const std::string &function, int line); + +#define RtAssert(code) AssertErrorCode(code, DATASET_SRC_FILE_NAME, __FUNCTION__, __LINE__); +#define RtCheckError(code) CheckErrorCode(code, DATASET_SRC_FILE_NAME, __FUNCTION__, __LINE__); + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_ERROR_CODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/ThreadSafeQueue.h b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/ThreadSafeQueue.h index 1f550e3d253..7e7ae019091 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/ThreadSafeQueue.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/ThreadSafeQueue.h @@ -1,129 +1,129 @@ -/** - * ============================================================================ - * - * Copyright (C) 2018, Hisilicon Technologies Co., Ltd. All Rights Reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1 Redistributions of source code must retain the above copyright notice, - * this list of conditions and the following disclaimer. - * - * 2 Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3 Neither the names of the copyright holders nor the names of the - * contributors may be used to endorse or promote products derived from this - * software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - * POSSIBILITY OF SUCH DAMAGE. - * ============================================================================ - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_THREAD_SAFE_QUEUE_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_THREAD_SAFE_QUEUE_H_ - -#include -#include - -template -class ThreadSafeQueue { - public: - /** - * @brief ThreadSafeQueue constructor - * @param [in] capacity: the queue capacity - */ - explicit ThreadSafeQueue(uint32_t capacity) { - // check the input value: capacity is valid - if (capacity >= kMinQueueCapacity && capacity <= kMaxQueueCapacity) { - queueCapacity = capacity; - } else { // the input value: capacity is invalid, set the default value - queueCapacity = kDefaultQueueCapacity; - } - } - - /** - * @brief ThreadSafeQueue constructor - */ - ThreadSafeQueue() { queueCapacity = kDefaultQueueCapacity; } - - /** - * @brief ThreadSafeQueue destructor - */ - ~ThreadSafeQueue() = default; - - /** - * @brief push data to queue - * @param [in] input_value: the value will push to the queue - * @return true: success to push data; false: fail to push data - */ - bool Push(T input_value) { - std::lock_guard lock(mutex_); - - // check current size is less than capacity - if (queue_.size() < queueCapacity) { - queue_.push(input_value); - return true; - } - - return false; - } - - /** - * @brief pop data from queue - * @return true: success to pop data; false: fail to pop data - */ - T Pop() { - std::lock_guard lock(mutex_); - if (queue_.empty()) { // check the queue is empty - return nullptr; - } - - T tmp_ptr = queue_.front(); - queue_.pop(); - return tmp_ptr; - } - - /** - * @brief check the queue is empty - * @return true: the queue is empty; false: the queue is not empty - */ - bool Empty() { - std::lock_guard lock(mutex_); - return queue_.empty(); - } - - /** - * @brief get the queue size - * @return the queue size - */ - uint32_t Size() { - std::lock_guard lock(mutex_); - return queue_.size(); - } - - void ExtendCapacity(uint32_t newSize) { - queueCapacity = newSize; - kMaxQueueCapacity = newSize > kMaxQueueCapacity ? newSize : kMaxQueueCapacity; - } - - private: - std::queue queue_; // the queue - uint32_t queueCapacity; // queue capacity - mutable std::mutex mutex_; // the mutex value - const uint32_t kMinQueueCapacity = 1; // the minimum queue capacity - const uint32_t kMaxQueueCapacity = 10000; // the maximum queue capacity - const uint32_t kDefaultQueueCapacity = 10; // default queue capacity -}; -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_THREAD_SAFE_QUEUE_H_ +/** + * ============================================================================ + * + * Copyright (C) 2018, Hisilicon Technologies Co., Ltd. All Rights Reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1 Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2 Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3 Neither the names of the copyright holders nor the names of the + * contributors may be used to endorse or promote products derived from this + * software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * ============================================================================ + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_THREAD_SAFE_QUEUE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_THREAD_SAFE_QUEUE_H_ + +#include +#include + +template +class ThreadSafeQueue { + public: + /** + * @brief ThreadSafeQueue constructor + * @param [in] capacity: the queue capacity + */ + explicit ThreadSafeQueue(uint32_t capacity) { + // check the input value: capacity is valid + if (capacity >= kMinQueueCapacity && capacity <= kMaxQueueCapacity) { + queueCapacity = capacity; + } else { // the input value: capacity is invalid, set the default value + queueCapacity = kDefaultQueueCapacity; + } + } + + /** + * @brief ThreadSafeQueue constructor + */ + ThreadSafeQueue() { queueCapacity = kDefaultQueueCapacity; } + + /** + * @brief ThreadSafeQueue destructor + */ + ~ThreadSafeQueue() = default; + + /** + * @brief push data to queue + * @param [in] input_value: the value will push to the queue + * @return true: success to push data; false: fail to push data + */ + bool Push(T input_value) { + std::lock_guard lock(mutex_); + + // check current size is less than capacity + if (queue_.size() < queueCapacity) { + queue_.push(input_value); + return true; + } + + return false; + } + + /** + * @brief pop data from queue + * @return true: success to pop data; false: fail to pop data + */ + T Pop() { + std::lock_guard lock(mutex_); + if (queue_.empty()) { // check the queue is empty + return nullptr; + } + + T tmp_ptr = queue_.front(); + queue_.pop(); + return tmp_ptr; + } + + /** + * @brief check the queue is empty + * @return true: the queue is empty; false: the queue is not empty + */ + bool Empty() { + std::lock_guard lock(mutex_); + return queue_.empty(); + } + + /** + * @brief get the queue size + * @return the queue size + */ + uint32_t Size() { + std::lock_guard lock(mutex_); + return queue_.size(); + } + + void ExtendCapacity(uint32_t newSize) { + queueCapacity = newSize; + kMaxQueueCapacity = newSize > kMaxQueueCapacity ? newSize : kMaxQueueCapacity; + } + + private: + std::queue queue_; // the queue + uint32_t queueCapacity; // queue capacity + mutable std::mutex mutex_; // the mutex value + const uint32_t kMinQueueCapacity = 1; // the minimum queue capacity + const uint32_t kMaxQueueCapacity = 10000; // the maximum queue capacity + const uint32_t kDefaultQueueCapacity = 10; // default queue capacity +}; +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_THREAD_SAFE_QUEUE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/VdecHelper.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/VdecHelper.cc index 31326c2fa60..edbfc9a2ba9 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/VdecHelper.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/VdecHelper.cc @@ -1,364 +1,364 @@ -/** -* Copyright 2022-2023 Huawei Technologies Co., Ltd -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at - -* http://www.apache.org/licenses/LICENSE-2.0 - -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -#include "minddata/dataset/kernels/image/dvpp/utils/VdecHelper.h" - -#include "minddata/dataset/kernels/image/dvpp/utils/AclLiteUtils.h" -#include "transform/symbol/acl_rt_symbol.h" -#include "transform/symbol/symbol_utils.h" - -using namespace std; - -namespace { -const uint32_t kFrameWidthMax = 4096; -const uint32_t kFrameHeightMax = 4096; -} // namespace - -VdecHelper::VdecHelper(int channelId, uint32_t width, uint32_t height, int type, aclvdecCallback callback, - uint32_t outFormat) - : channelId_(channelId), - format_(outFormat), - enType_(type), - frameWidth_(width), - frameHeight_(height), - callback_(callback), - isExit_(false), - isReleased_(false), - isChannelExit_(false) { - alignWidth_ = ALIGN_UP16(frameWidth_); - alignHeight_ = ALIGN_UP2(frameHeight_); - outputPicSize_ = YUV420SP_SIZE(alignWidth_, alignHeight_); - - vdecChannelDesc_ = nullptr; - inputStreamDesc_ = nullptr; - outputPicDesc_ = nullptr; - outputPicBuf_ = nullptr; - - aclError aclRet; - ACLLITE_LOG_INFO("get current context"); - aclRet = CALL_ASCEND_API(aclrtGetCurrentContext, &context_); - if ((aclRet != ACL_SUCCESS) || (context_ == nullptr)) { - ACLLITE_LOG_ERROR("VdecHelper : Get current acl context error:%d", aclRet); - } - - ACLLITE_LOG_INFO("VDEC width %d, height %d", frameWidth_, frameHeight_); -} - -VdecHelper::~VdecHelper() { DestroyResource(); } - -void VdecHelper::DestroyChannel() { - if (isReleased_) { - return; - } - aclError ret; - if (vdecChannelDesc_ != nullptr) { - ret = aclvdecDestroyChannel(vdecChannelDesc_); - if (ret != ACL_SUCCESS) { - ACLLITE_LOG_ERROR("Vdec destroy channel failed, errorno: %d", ret); - } - ACLLITE_LOG_INFO("Vdec destory Channel ok"); - ret = aclvdecDestroyChannelDesc(vdecChannelDesc_); - if (ret != ACL_SUCCESS) { - ACLLITE_LOG_ERROR("Vdec destory ChannelDesc failed, errorno: %d", ret); - } - ACLLITE_LOG_INFO("Vdec destory ChannelDesc ok"); - vdecChannelDesc_ = nullptr; - isChannelExit_ = true; - } -} - -void VdecHelper::DestroyResource() { - if (isReleased_) { - return; - } - constexpr auto kMicrosecond = 1000; - while (!isChannelExit_) { - (void)usleep(kMicrosecond); - } - UnsubscribReportThread(); - - // destory stream - aclError ret; - if (stream_ != nullptr) { - ret = CALL_ASCEND_API(aclrtDestroyStream, stream_); - if (ret != ACL_SUCCESS) { - ACLLITE_LOG_ERROR("Vdec destroy stream failed"); - } - stream_ = nullptr; - } - isReleased_ = true; -} - -void *VdecHelper::SubscribeReportThreadFunc(void *arg) { - ACLLITE_LOG_INFO("Start vdec subscribe thread..."); - - // Notice: create context for this thread - auto *vdec = reinterpret_cast(arg); - aclrtContext context = vdec->GetContext(); - aclError ret = CALL_ASCEND_API(aclrtSetCurrentContext, context); - if (ret != ACL_SUCCESS) { - ACLLITE_LOG_ERROR("Video decoder set context failed, errorno: %d", ret); - } - - while (!vdec->IsExit()) { - // Notice: timeout 1000ms - ret = CALL_ASCEND_API(aclrtProcessReport, 1000); - if (ret != ACL_SUCCESS) { - ACLLITE_LOG_ERROR("Video decoder process report failed, errorno: %d", ret); - } - } - - ACLLITE_LOG_INFO("Vdec subscribe thread exit!"); - - return reinterpret_cast(ACLLITE_OK); -} - -void VdecHelper::UnsubscribReportThread() { - if ((subscribeThreadId_ == 0) || (stream_ == nullptr)) { - return; - } - - (void)aclrtUnSubscribeReport(static_cast(subscribeThreadId_), stream_); - // destory thread - isExit_ = true; - - void *res = nullptr; - int joinThreadErr = pthread_join(subscribeThreadId_, &res); - if (joinThreadErr) { - ACLLITE_LOG_ERROR("Join thread failed, threadId = %lu, err = %d", subscribeThreadId_, joinThreadErr); - } else { - if (reinterpret_cast(res) != 0) { - ACLLITE_LOG_ERROR("thread run failed. ret is %lu.", reinterpret_cast(res)); - } - } - ACLLITE_LOG_INFO("Destory report thread success."); -} - -AclLiteError VdecHelper::Init() { - ACLLITE_LOG_INFO("Vdec process init start..."); - aclError aclRet = aclrtCreateStream(&stream_); - if (aclRet != ACL_SUCCESS) { - ACLLITE_LOG_ERROR("Vdec create stream failed, errorno: %d", aclRet); - return ACLLITE_ERROR_CREATE_STREAM; - } - ACLLITE_LOG_INFO("Vdec create stream ok"); - - int ret = pthread_create(&subscribeThreadId_, nullptr, SubscribeReportThreadFunc, reinterpret_cast(this)); - if (ret) { - ACLLITE_LOG_ERROR("Start vdec subscribe thread failed, return: %d", ret); - return ACLLITE_ERROR_CREATE_THREAD; - } - (void)CALL_ASCEND_API(aclrtSubscribeReport, static_cast(subscribeThreadId_), stream_); - - ret = CreateVdecChannelDesc(); - if (ret != ACLLITE_OK) { - ACLLITE_LOG_ERROR("Create vdec channel failed"); - return ret; - } - - return ACLLITE_OK; -} - -AclLiteError VdecHelper::CreateVdecChannelDesc() { - vdecChannelDesc_ = aclvdecCreateChannelDesc(); - if (vdecChannelDesc_ == nullptr) { - ACLLITE_LOG_ERROR("Create vdec channel desc failed"); - return ACLLITE_ERROR_CREATE_DVPP_CHANNEL_DESC; - } - - // channelId: 0-15 - aclError ret = aclvdecSetChannelDescChannelId(vdecChannelDesc_, channelId_); - if (ret != ACL_SUCCESS) { - ACLLITE_LOG_ERROR("Set vdec channel id to %d failed, errorno:%d", channelId_, ret); - return ACLLITE_ERROR_SET_VDEC_CHANNEL_ID; - } - - ret = aclvdecSetChannelDescThreadId(vdecChannelDesc_, subscribeThreadId_); - if (ret != ACL_SUCCESS) { - ACLLITE_LOG_ERROR("Set vdec channel thread id failed, errorno:%d", ret); - return ACLLITE_ERROR_SET_VDEC_CHANNEL_THREAD_ID; - } - - // callback func - ret = aclvdecSetChannelDescCallback(vdecChannelDesc_, callback_); - if (ret != ACL_SUCCESS) { - ACLLITE_LOG_ERROR("Set vdec channel callback failed, errorno:%d", ret); - return ACLLITE_ERROR_SET_VDEC_CALLBACK; - } - - ret = aclvdecSetChannelDescEnType(vdecChannelDesc_, static_cast(enType_)); - if (ret != ACL_SUCCESS) { - ACLLITE_LOG_ERROR("Set vdec channel entype failed, errorno:%d", ret); - return ACLLITE_ERROR_SET_VDEC_ENTYPE; - } - - ret = aclvdecSetChannelDescOutPicFormat(vdecChannelDesc_, static_cast(format_)); - if (ret != ACL_SUCCESS) { - ACLLITE_LOG_ERROR("Set vdec channel pic format failed, errorno:%d", ret); - return ACLLITE_ERROR_SET_VDEC_PIC_FORMAT; - } - - // create vdec channel - ACLLITE_LOG_INFO("Start create vdec channel by desc..."); - ret = aclvdecCreateChannel(vdecChannelDesc_); - if (ret != ACL_SUCCESS) { - ACLLITE_LOG_ERROR("fail to create vdec channel"); - return ACLLITE_ERROR_CREATE_VDEC_CHANNEL; - } - ACLLITE_LOG_INFO("Create vdec channel ok"); - - return ACLLITE_OK; -} - -AclLiteError VdecHelper::CreateInputStreamDesc(const std::shared_ptr &frameData) { - inputStreamDesc_ = acldvppCreateStreamDesc(); - if (inputStreamDesc_ == nullptr) { - ACLLITE_LOG_ERROR("Create input stream desc failed"); - return ACLLITE_ERROR_CREATE_STREAM_DESC; - } - - aclError ret; - // to the last data,send an endding signal to dvpp vdec - if (frameData->isFinished) { - ret = acldvppSetStreamDescEos(inputStreamDesc_, 1); - if (ret != ACL_SUCCESS) { - ACLLITE_LOG_ERROR("Set EOS to input stream desc failed, errorno: %d", ret); - return ACLLITE_ERROR_SET_STREAM_DESC_EOS; - } - return ACLLITE_OK; - } - - ret = acldvppSetStreamDescData(inputStreamDesc_, frameData->data); - if (ret != ACL_SUCCESS) { - ACLLITE_LOG_ERROR("Set input stream data failed, errorno: %d", ret); - return ACLLITE_ERROR_SET_STREAM_DESC_DATA; - } - - // set size for dvpp stream desc - ret = acldvppSetStreamDescSize(inputStreamDesc_, frameData->size); - if (ret != ACL_SUCCESS) { - ACLLITE_LOG_ERROR("Set input stream size failed, errorno: %d", ret); - return ACLLITE_ERROR_SET_STREAM_DESC_SIZE; - } - - ret = acldvppSetStreamDescTimestamp(inputStreamDesc_, frameData->frameId); - if (ret != ACL_SUCCESS) { - ACLLITE_LOG_ERROR("Set input stream timestamp failed, errorno: %d", ret); - return ACLLITE_ERROR; - } - - return ACLLITE_OK; -} - -AclLiteError VdecHelper::CreateOutputPicDesc(size_t size) { - // Malloc output device memory - aclError ret = acldvppMalloc(&outputPicBuf_, size); - if (ret != ACL_SUCCESS) { - ACLLITE_LOG_ERROR( - "Malloc vdec output buffer failed when create " - "vdec output desc, errorno:%d", - ret); - return ACLLITE_ERROR_MALLOC_DVPP; - } - - outputPicDesc_ = acldvppCreatePicDesc(); - if (outputPicDesc_ == nullptr) { - ACLLITE_LOG_ERROR("Create vdec output pic desc failed"); - return ACLLITE_ERROR_CREATE_PIC_DESC; - } - - ret = acldvppSetPicDescData(outputPicDesc_, outputPicBuf_); - if (ret != ACL_SUCCESS) { - ACLLITE_LOG_ERROR("Set vdec output pic desc data failed, errorno:%d", ret); - return ACLLITE_ERROR_SET_PIC_DESC_DATA; - } - - ret = acldvppSetPicDescSize(outputPicDesc_, size); - if (ret != ACL_SUCCESS) { - ACLLITE_LOG_ERROR("Set vdec output pic size failed, errorno:%d", ret); - return ACLLITE_ERROR_SET_PIC_DESC_SIZE; - } - - ret = acldvppSetPicDescFormat(outputPicDesc_, static_cast(format_)); - if (ret != ACL_SUCCESS) { - ACLLITE_LOG_ERROR("Set vdec output pic format failed, errorno:%d", ret); - return ACLLITE_ERROR_SET_PIC_DESC_FORMAT; - } - - return ACLLITE_OK; -} - -AclLiteError VdecHelper::Process(const std::shared_ptr &frameData, void *userData) { - // create input desc - AclLiteError atlRet = CreateInputStreamDesc(frameData); - if (atlRet != ACLLITE_OK) { - ACLLITE_LOG_ERROR("Create stream desc failed"); - return atlRet; - } - // create out desc - atlRet = CreateOutputPicDesc(outputPicSize_); - if (atlRet != ACLLITE_OK) { - ACLLITE_LOG_ERROR("Create pic desc failed"); - return atlRet; - } - // send data to dvpp vdec to decode - aclError ret = aclvdecSendFrame(vdecChannelDesc_, inputStreamDesc_, outputPicDesc_, nullptr, userData); - if (ret != ACL_SUCCESS) { - ACLLITE_LOG_ERROR("Send frame to vdec failed, errorno:%d", ret); - return ACLLITE_ERROR_VDEC_SEND_FRAME; - } - - return ACLLITE_OK; -} - -AclLiteError VdecHelper::SetFormat(uint32_t format) { - if ((format != PIXEL_FORMAT_YUV_SEMIPLANAR_420) && (format != PIXEL_FORMAT_YVU_SEMIPLANAR_420)) { - ACLLITE_LOG_ERROR( - "Set video decode output image format to %d failed, " - "only support %d(YUV420SP NV12) and %d(YUV420SP NV21)", - format, (int)PIXEL_FORMAT_YUV_SEMIPLANAR_420, (int)PIXEL_FORMAT_YVU_SEMIPLANAR_420); - return ACLLITE_ERROR_VDEC_FORMAT_INVALID; - } - - format_ = format; - ACLLITE_LOG_INFO("Set video decode output image format to %d ok", format); - - return ACLLITE_OK; -} - -AclLiteError VdecHelper::VideoParamCheck() const { - if ((frameWidth_ == 0) || (frameWidth_ > kFrameWidthMax)) { - ACLLITE_LOG_ERROR("video frame width %d is invalid, the legal range is [0, %d]", frameWidth_, kFrameWidthMax); - return ACLLITE_ERROR_VDEC_INVALID_PARAM; - } - if ((frameHeight_ == 0) || (frameHeight_ > kFrameHeightMax)) { - ACLLITE_LOG_ERROR("video frame height %d is invalid, the legal range is [0, %d]", frameHeight_, kFrameHeightMax); - return ACLLITE_ERROR_VDEC_INVALID_PARAM; - } - if ((format_ != PIXEL_FORMAT_YUV_SEMIPLANAR_420) && (format_ != PIXEL_FORMAT_YVU_SEMIPLANAR_420)) { - ACLLITE_LOG_ERROR( - "video decode image format %d invalid, " - "only support %d(YUV420SP NV12) and %d(YUV420SP NV21)", - format_, (int)PIXEL_FORMAT_YUV_SEMIPLANAR_420, (int)PIXEL_FORMAT_YVU_SEMIPLANAR_420); - return ACLLITE_ERROR_VDEC_INVALID_PARAM; - } - if (enType_ > static_cast(H264_HIGH_LEVEL)) { - ACLLITE_LOG_ERROR("Input video stream format %d invalid", enType_); - return ACLLITE_ERROR_VDEC_INVALID_PARAM; - } - - return ACLLITE_OK; -} +/** +* Copyright 2022-2023 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at + +* http://www.apache.org/licenses/LICENSE-2.0 + +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#include "minddata/dataset/kernels/image/dvpp/utils/VdecHelper.h" + +#include "minddata/dataset/kernels/image/dvpp/utils/AclLiteUtils.h" +#include "transform/symbol/acl_rt_symbol.h" +#include "transform/symbol/symbol_utils.h" + +using namespace std; + +namespace { +const uint32_t kFrameWidthMax = 4096; +const uint32_t kFrameHeightMax = 4096; +} // namespace + +VdecHelper::VdecHelper(int channelId, uint32_t width, uint32_t height, int type, aclvdecCallback callback, + uint32_t outFormat) + : channelId_(channelId), + format_(outFormat), + enType_(type), + frameWidth_(width), + frameHeight_(height), + callback_(callback), + isExit_(false), + isReleased_(false), + isChannelExit_(false) { + alignWidth_ = ALIGN_UP16(frameWidth_); + alignHeight_ = ALIGN_UP2(frameHeight_); + outputPicSize_ = YUV420SP_SIZE(alignWidth_, alignHeight_); + + vdecChannelDesc_ = nullptr; + inputStreamDesc_ = nullptr; + outputPicDesc_ = nullptr; + outputPicBuf_ = nullptr; + + aclError aclRet; + ACLLITE_LOG_INFO("get current context"); + aclRet = CALL_ASCEND_API(aclrtGetCurrentContext, &context_); + if ((aclRet != ACL_SUCCESS) || (context_ == nullptr)) { + ACLLITE_LOG_ERROR("VdecHelper : Get current acl context error:%d", aclRet); + } + + ACLLITE_LOG_INFO("VDEC width %d, height %d", frameWidth_, frameHeight_); +} + +VdecHelper::~VdecHelper() { DestroyResource(); } + +void VdecHelper::DestroyChannel() { + if (isReleased_) { + return; + } + aclError ret; + if (vdecChannelDesc_ != nullptr) { + ret = aclvdecDestroyChannel(vdecChannelDesc_); + if (ret != ACL_SUCCESS) { + ACLLITE_LOG_ERROR("Vdec destroy channel failed, errorno: %d", ret); + } + ACLLITE_LOG_INFO("Vdec destory Channel ok"); + ret = aclvdecDestroyChannelDesc(vdecChannelDesc_); + if (ret != ACL_SUCCESS) { + ACLLITE_LOG_ERROR("Vdec destory ChannelDesc failed, errorno: %d", ret); + } + ACLLITE_LOG_INFO("Vdec destory ChannelDesc ok"); + vdecChannelDesc_ = nullptr; + isChannelExit_ = true; + } +} + +void VdecHelper::DestroyResource() { + if (isReleased_) { + return; + } + constexpr auto kMicrosecond = 1000; + while (!isChannelExit_) { + (void)usleep(kMicrosecond); + } + UnsubscribReportThread(); + + // destory stream + aclError ret; + if (stream_ != nullptr) { + ret = CALL_ASCEND_API(aclrtDestroyStream, stream_); + if (ret != ACL_SUCCESS) { + ACLLITE_LOG_ERROR("Vdec destroy stream failed"); + } + stream_ = nullptr; + } + isReleased_ = true; +} + +void *VdecHelper::SubscribeReportThreadFunc(void *arg) { + ACLLITE_LOG_INFO("Start vdec subscribe thread..."); + + // Notice: create context for this thread + auto *vdec = reinterpret_cast(arg); + aclrtContext context = vdec->GetContext(); + aclError ret = CALL_ASCEND_API(aclrtSetCurrentContext, context); + if (ret != ACL_SUCCESS) { + ACLLITE_LOG_ERROR("Video decoder set context failed, errorno: %d", ret); + } + + while (!vdec->IsExit()) { + // Notice: timeout 1000ms + ret = CALL_ASCEND_API(aclrtProcessReport, 1000); + if (ret != ACL_SUCCESS) { + ACLLITE_LOG_ERROR("Video decoder process report failed, errorno: %d", ret); + } + } + + ACLLITE_LOG_INFO("Vdec subscribe thread exit!"); + + return reinterpret_cast(ACLLITE_OK); +} + +void VdecHelper::UnsubscribReportThread() { + if ((subscribeThreadId_ == 0) || (stream_ == nullptr)) { + return; + } + + (void)aclrtUnSubscribeReport(static_cast(subscribeThreadId_), stream_); + // destory thread + isExit_ = true; + + void *res = nullptr; + int joinThreadErr = pthread_join(subscribeThreadId_, &res); + if (joinThreadErr) { + ACLLITE_LOG_ERROR("Join thread failed, threadId = %lu, err = %d", subscribeThreadId_, joinThreadErr); + } else { + if (reinterpret_cast(res) != 0) { + ACLLITE_LOG_ERROR("thread run failed. ret is %lu.", reinterpret_cast(res)); + } + } + ACLLITE_LOG_INFO("Destory report thread success."); +} + +AclLiteError VdecHelper::Init() { + ACLLITE_LOG_INFO("Vdec process init start..."); + aclError aclRet = aclrtCreateStream(&stream_); + if (aclRet != ACL_SUCCESS) { + ACLLITE_LOG_ERROR("Vdec create stream failed, errorno: %d", aclRet); + return ACLLITE_ERROR_CREATE_STREAM; + } + ACLLITE_LOG_INFO("Vdec create stream ok"); + + int ret = pthread_create(&subscribeThreadId_, nullptr, SubscribeReportThreadFunc, reinterpret_cast(this)); + if (ret) { + ACLLITE_LOG_ERROR("Start vdec subscribe thread failed, return: %d", ret); + return ACLLITE_ERROR_CREATE_THREAD; + } + (void)CALL_ASCEND_API(aclrtSubscribeReport, static_cast(subscribeThreadId_), stream_); + + ret = CreateVdecChannelDesc(); + if (ret != ACLLITE_OK) { + ACLLITE_LOG_ERROR("Create vdec channel failed"); + return ret; + } + + return ACLLITE_OK; +} + +AclLiteError VdecHelper::CreateVdecChannelDesc() { + vdecChannelDesc_ = aclvdecCreateChannelDesc(); + if (vdecChannelDesc_ == nullptr) { + ACLLITE_LOG_ERROR("Create vdec channel desc failed"); + return ACLLITE_ERROR_CREATE_DVPP_CHANNEL_DESC; + } + + // channelId: 0-15 + aclError ret = aclvdecSetChannelDescChannelId(vdecChannelDesc_, channelId_); + if (ret != ACL_SUCCESS) { + ACLLITE_LOG_ERROR("Set vdec channel id to %d failed, errorno:%d", channelId_, ret); + return ACLLITE_ERROR_SET_VDEC_CHANNEL_ID; + } + + ret = aclvdecSetChannelDescThreadId(vdecChannelDesc_, subscribeThreadId_); + if (ret != ACL_SUCCESS) { + ACLLITE_LOG_ERROR("Set vdec channel thread id failed, errorno:%d", ret); + return ACLLITE_ERROR_SET_VDEC_CHANNEL_THREAD_ID; + } + + // callback func + ret = aclvdecSetChannelDescCallback(vdecChannelDesc_, callback_); + if (ret != ACL_SUCCESS) { + ACLLITE_LOG_ERROR("Set vdec channel callback failed, errorno:%d", ret); + return ACLLITE_ERROR_SET_VDEC_CALLBACK; + } + + ret = aclvdecSetChannelDescEnType(vdecChannelDesc_, static_cast(enType_)); + if (ret != ACL_SUCCESS) { + ACLLITE_LOG_ERROR("Set vdec channel entype failed, errorno:%d", ret); + return ACLLITE_ERROR_SET_VDEC_ENTYPE; + } + + ret = aclvdecSetChannelDescOutPicFormat(vdecChannelDesc_, static_cast(format_)); + if (ret != ACL_SUCCESS) { + ACLLITE_LOG_ERROR("Set vdec channel pic format failed, errorno:%d", ret); + return ACLLITE_ERROR_SET_VDEC_PIC_FORMAT; + } + + // create vdec channel + ACLLITE_LOG_INFO("Start create vdec channel by desc..."); + ret = aclvdecCreateChannel(vdecChannelDesc_); + if (ret != ACL_SUCCESS) { + ACLLITE_LOG_ERROR("fail to create vdec channel"); + return ACLLITE_ERROR_CREATE_VDEC_CHANNEL; + } + ACLLITE_LOG_INFO("Create vdec channel ok"); + + return ACLLITE_OK; +} + +AclLiteError VdecHelper::CreateInputStreamDesc(const std::shared_ptr &frameData) { + inputStreamDesc_ = acldvppCreateStreamDesc(); + if (inputStreamDesc_ == nullptr) { + ACLLITE_LOG_ERROR("Create input stream desc failed"); + return ACLLITE_ERROR_CREATE_STREAM_DESC; + } + + aclError ret; + // to the last data,send an endding signal to dvpp vdec + if (frameData->isFinished) { + ret = acldvppSetStreamDescEos(inputStreamDesc_, 1); + if (ret != ACL_SUCCESS) { + ACLLITE_LOG_ERROR("Set EOS to input stream desc failed, errorno: %d", ret); + return ACLLITE_ERROR_SET_STREAM_DESC_EOS; + } + return ACLLITE_OK; + } + + ret = acldvppSetStreamDescData(inputStreamDesc_, frameData->data); + if (ret != ACL_SUCCESS) { + ACLLITE_LOG_ERROR("Set input stream data failed, errorno: %d", ret); + return ACLLITE_ERROR_SET_STREAM_DESC_DATA; + } + + // set size for dvpp stream desc + ret = acldvppSetStreamDescSize(inputStreamDesc_, frameData->size); + if (ret != ACL_SUCCESS) { + ACLLITE_LOG_ERROR("Set input stream size failed, errorno: %d", ret); + return ACLLITE_ERROR_SET_STREAM_DESC_SIZE; + } + + ret = acldvppSetStreamDescTimestamp(inputStreamDesc_, frameData->frameId); + if (ret != ACL_SUCCESS) { + ACLLITE_LOG_ERROR("Set input stream timestamp failed, errorno: %d", ret); + return ACLLITE_ERROR; + } + + return ACLLITE_OK; +} + +AclLiteError VdecHelper::CreateOutputPicDesc(size_t size) { + // Malloc output device memory + aclError ret = acldvppMalloc(&outputPicBuf_, size); + if (ret != ACL_SUCCESS) { + ACLLITE_LOG_ERROR( + "Malloc vdec output buffer failed when create " + "vdec output desc, errorno:%d", + ret); + return ACLLITE_ERROR_MALLOC_DVPP; + } + + outputPicDesc_ = acldvppCreatePicDesc(); + if (outputPicDesc_ == nullptr) { + ACLLITE_LOG_ERROR("Create vdec output pic desc failed"); + return ACLLITE_ERROR_CREATE_PIC_DESC; + } + + ret = acldvppSetPicDescData(outputPicDesc_, outputPicBuf_); + if (ret != ACL_SUCCESS) { + ACLLITE_LOG_ERROR("Set vdec output pic desc data failed, errorno:%d", ret); + return ACLLITE_ERROR_SET_PIC_DESC_DATA; + } + + ret = acldvppSetPicDescSize(outputPicDesc_, size); + if (ret != ACL_SUCCESS) { + ACLLITE_LOG_ERROR("Set vdec output pic size failed, errorno:%d", ret); + return ACLLITE_ERROR_SET_PIC_DESC_SIZE; + } + + ret = acldvppSetPicDescFormat(outputPicDesc_, static_cast(format_)); + if (ret != ACL_SUCCESS) { + ACLLITE_LOG_ERROR("Set vdec output pic format failed, errorno:%d", ret); + return ACLLITE_ERROR_SET_PIC_DESC_FORMAT; + } + + return ACLLITE_OK; +} + +AclLiteError VdecHelper::Process(const std::shared_ptr &frameData, void *userData) { + // create input desc + AclLiteError atlRet = CreateInputStreamDesc(frameData); + if (atlRet != ACLLITE_OK) { + ACLLITE_LOG_ERROR("Create stream desc failed"); + return atlRet; + } + // create out desc + atlRet = CreateOutputPicDesc(outputPicSize_); + if (atlRet != ACLLITE_OK) { + ACLLITE_LOG_ERROR("Create pic desc failed"); + return atlRet; + } + // send data to dvpp vdec to decode + aclError ret = aclvdecSendFrame(vdecChannelDesc_, inputStreamDesc_, outputPicDesc_, nullptr, userData); + if (ret != ACL_SUCCESS) { + ACLLITE_LOG_ERROR("Send frame to vdec failed, errorno:%d", ret); + return ACLLITE_ERROR_VDEC_SEND_FRAME; + } + + return ACLLITE_OK; +} + +AclLiteError VdecHelper::SetFormat(uint32_t format) { + if ((format != PIXEL_FORMAT_YUV_SEMIPLANAR_420) && (format != PIXEL_FORMAT_YVU_SEMIPLANAR_420)) { + ACLLITE_LOG_ERROR( + "Set video decode output image format to %d failed, " + "only support %d(YUV420SP NV12) and %d(YUV420SP NV21)", + format, (int)PIXEL_FORMAT_YUV_SEMIPLANAR_420, (int)PIXEL_FORMAT_YVU_SEMIPLANAR_420); + return ACLLITE_ERROR_VDEC_FORMAT_INVALID; + } + + format_ = format; + ACLLITE_LOG_INFO("Set video decode output image format to %d ok", format); + + return ACLLITE_OK; +} + +AclLiteError VdecHelper::VideoParamCheck() const { + if ((frameWidth_ == 0) || (frameWidth_ > kFrameWidthMax)) { + ACLLITE_LOG_ERROR("video frame width %d is invalid, the legal range is [0, %d]", frameWidth_, kFrameWidthMax); + return ACLLITE_ERROR_VDEC_INVALID_PARAM; + } + if ((frameHeight_ == 0) || (frameHeight_ > kFrameHeightMax)) { + ACLLITE_LOG_ERROR("video frame height %d is invalid, the legal range is [0, %d]", frameHeight_, kFrameHeightMax); + return ACLLITE_ERROR_VDEC_INVALID_PARAM; + } + if ((format_ != PIXEL_FORMAT_YUV_SEMIPLANAR_420) && (format_ != PIXEL_FORMAT_YVU_SEMIPLANAR_420)) { + ACLLITE_LOG_ERROR( + "video decode image format %d invalid, " + "only support %d(YUV420SP NV12) and %d(YUV420SP NV21)", + format_, (int)PIXEL_FORMAT_YUV_SEMIPLANAR_420, (int)PIXEL_FORMAT_YVU_SEMIPLANAR_420); + return ACLLITE_ERROR_VDEC_INVALID_PARAM; + } + if (enType_ > static_cast(H264_HIGH_LEVEL)) { + ACLLITE_LOG_ERROR("Input video stream format %d invalid", enType_); + return ACLLITE_ERROR_VDEC_INVALID_PARAM; + } + + return ACLLITE_OK; +} diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/VdecHelper.h b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/VdecHelper.h index d68f0d92f74..5c536b977d1 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/VdecHelper.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/VdecHelper.h @@ -1,88 +1,88 @@ -/** -* Copyright 2022 Huawei Technologies Co., Ltd -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at - -* http://www.apache.org/licenses/LICENSE-2.0 - -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_VDEC_HELPER_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_VDEC_HELPER_H_ - -#include -#include -#include - -#include "acl/acl.h" -#include "acl/ops/acl_dvpp.h" - -#include "minddata/dataset/kernels/image/dvpp/utils/AclLiteError.h" -#include "minddata/dataset/kernels/image/dvpp/utils/AclLiteType.h" - -class VdecHelper { - public: - VdecHelper(int channel, uint32_t width, uint32_t height, int type, aclvdecCallback callback, - uint32_t outFormat = PIXEL_FORMAT_YUV_SEMIPLANAR_420); - ~VdecHelper(); - - static void *SubscribeReportThreadFunc(void *arg); - - AclLiteError Init(); - void DestroyResource(); - void DestroyChannel(); - - AclLiteError Process(const std::shared_ptr &frameData, void *userData); - AclLiteError SetFormat(uint32_t format); - AclLiteError VideoParamCheck() const; - bool IsExit() const { return isExit_; } - aclrtContext GetContext() { return context_; } - - private: - AclLiteError CreateVdecChannelDesc(); - AclLiteError CreateInputStreamDesc(const std::shared_ptr &frameData); - AclLiteError CreateOutputPicDesc(size_t size); - void UnsubscribReportThread(); - - private: - int channelId_; - - /* 1:YUV420 semi-planner(nv12) - 2:YVU420 semi-planner(nv21) - */ - uint32_t format_; - - /* 0:H265 main level - * 1:H264 baseline level - * 2:H264 main level - * 3:H264 high level - */ - uint32_t enType_; - - uint32_t frameWidth_; - uint32_t frameHeight_; - uint32_t alignWidth_; - uint32_t alignHeight_; - uint32_t outputPicSize_; - void *outputPicBuf_; - aclvdecCallback callback_; - aclrtContext context_{}; - aclrtStream stream_{}; - - aclvdecChannelDesc *vdecChannelDesc_; - acldvppStreamDesc *inputStreamDesc_; - acldvppPicDesc *outputPicDesc_; - - pthread_t subscribeThreadId_{}; - bool isExit_; - bool isReleased_; - bool isChannelExit_; -}; - -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_VDEC_HELPER_H_ +/** +* Copyright 2022 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at + +* http://www.apache.org/licenses/LICENSE-2.0 + +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_VDEC_HELPER_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_VDEC_HELPER_H_ + +#include +#include +#include + +#include "acl/acl.h" +#include "acl/ops/acl_dvpp.h" + +#include "minddata/dataset/kernels/image/dvpp/utils/AclLiteError.h" +#include "minddata/dataset/kernels/image/dvpp/utils/AclLiteType.h" + +class VdecHelper { + public: + VdecHelper(int channel, uint32_t width, uint32_t height, int type, aclvdecCallback callback, + uint32_t outFormat = PIXEL_FORMAT_YUV_SEMIPLANAR_420); + ~VdecHelper(); + + static void *SubscribeReportThreadFunc(void *arg); + + AclLiteError Init(); + void DestroyResource(); + void DestroyChannel(); + + AclLiteError Process(const std::shared_ptr &frameData, void *userData); + AclLiteError SetFormat(uint32_t format); + AclLiteError VideoParamCheck() const; + bool IsExit() const { return isExit_; } + aclrtContext GetContext() { return context_; } + + private: + AclLiteError CreateVdecChannelDesc(); + AclLiteError CreateInputStreamDesc(const std::shared_ptr &frameData); + AclLiteError CreateOutputPicDesc(size_t size); + void UnsubscribReportThread(); + + private: + int channelId_; + + /* 1:YUV420 semi-planner(nv12) + 2:YVU420 semi-planner(nv21) + */ + uint32_t format_; + + /* 0:H265 main level + * 1:H264 baseline level + * 2:H264 main level + * 3:H264 high level + */ + uint32_t enType_; + + uint32_t frameWidth_; + uint32_t frameHeight_; + uint32_t alignWidth_; + uint32_t alignHeight_; + uint32_t outputPicSize_; + void *outputPicBuf_; + aclvdecCallback callback_; + aclrtContext context_{}; + aclrtStream stream_{}; + + aclvdecChannelDesc *vdecChannelDesc_; + acldvppStreamDesc *inputStreamDesc_; + acldvppPicDesc *outputPicDesc_; + + pthread_t subscribeThreadId_{}; + bool isExit_; + bool isReleased_; + bool isChannelExit_; +}; + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_VDEC_HELPER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/dvpp_video.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/dvpp_video.cc index 7fedd2b21ab..08d47e518d2 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/dvpp_video.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/dvpp_video.cc @@ -1,682 +1,682 @@ -/** - * Copyright 2022-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "minddata/dataset/kernels/image/dvpp/utils/dvpp_video.h" - -#include -#include - -#include -#include -#include -#include - -#include "minddata/dataset/kernels/image/dvpp/utils/AclLiteUtils.h" -#include "mindspore/core/utils/log_adapter.h" -#include "transform/symbol/acl_rt_symbol.h" -#include "transform/symbol/symbol_utils.h" - -namespace { -const int64_t kUsec = 1000000; -const uint32_t kDecodeFrameQueueSize = 256; -const int kDecodeQueueOpWait = 10000; // decode wait 10ms/frame -const int kFrameEnQueueRetryTimes = 1000; // max wait time for the frame to enter in queue -const int kQueueOpRetryTimes = 1000; -const int kOutputJamWait = 10000; -const int kInvalidTpye = -1; -const int kWaitDecodeFinishInterval = 1000; - -const uint32_t DVPP_VIDEO_H264 = 0; -const uint32_t DVPP_VIDEO_H265 = 1; - -ChannelIdGenerator channelIdGenerator; -} // namespace - -FrameExtarct::FrameExtarct(uint8_t *data, uint32_t size, uint32_t width, uint32_t height, uint32_t type) - : data_(data), size_(size), frameWidth_(width), frameHeight_(height) { - isFinished_ = false; - isStop_ = false; - videoType_ = (type == 0) ? DVPP_VIDEO_H265 : DVPP_VIDEO_H264; -} - -void FrameExtarct::ExtractFrameH264(const uint8_t *buf_ptr, int *size_ptr) { - if (buf_ptr == nullptr || size_ptr == nullptr || *size_ptr <= 0) { - return; - } - bool isFindStart = false; - bool isFindEnd = false; - int size = *size_ptr; - int i = 0; - for (; i < size - 8; i++) { - if (FindStartH264(buf_ptr, i)) { - isFindStart = true; - i += 8; - break; - } - } - - for (; i < size - 8; i++) { - if (FindEndH264(buf_ptr, i)) { - isFindEnd = true; - break; - } - } - - if (i > 0) { - *size_ptr = i; - } - - if (!isFindStart) { - MS_LOG(ERROR) << "Channel can not find H265 start code, please check input video coding protocol is H264."; - return; - } - if (!isFindEnd) { - *size_ptr = i + 8; - } -} - -void FrameExtarct::ExtractFrameH265(const uint8_t *buf_ptr, int *size_ptr) { - if (buf_ptr == nullptr || size_ptr == nullptr || *size_ptr <= 0) { - return; - } - bool isFindStart = false; - bool isFindEnd = false; - int i = 0; - for (; i < *size_ptr - 6; i++) { - if (FindStartH265(buf_ptr, i)) { - isFindStart = true; - i += 6; - break; - } - } - - for (; i < *size_ptr - 6; i++) { - if (FindEndH265(buf_ptr, i)) { - isFindEnd = true; - break; - } - } - if (i > 0) { - *size_ptr = i; - } - - if (!isFindStart) { - MS_LOG(ERROR) << "Channel can not find H265 start code, please check input video coding protocol is H265."; - return; - } - if (!isFindEnd) { - *size_ptr = i + 6; - } -} - -void FrameExtarct::Decode(FrameProcessCallBack callback, void *callbackParam) { - MS_LOG(INFO) << "Start extarct frame from video..."; - - int32_t usedBytes = 0; - uint32_t count = 0; - bool processOk = true; - - while (!isStop_ && processOk) { - uint8_t *bufPointer; - bufPointer = data_ + usedBytes; - int32_t readlen = size_ - usedBytes; - if (readlen <= 0) { - break; - } - - if (videoType_ == DVPP_VIDEO_H264) { // H264 - ExtractFrameH264(bufPointer, &readlen); - } else if (videoType_ == DVPP_VIDEO_H265) { // H265 - ExtractFrameH265(bufPointer, &readlen); - } - int ret = callback(callbackParam, bufPointer, readlen); - if (ret != 0) { - processOk = false; - } - count++; - usedBytes = usedBytes + readlen; - } - // Frame count - - isFinished_ = true; - MS_LOG(INFO) << "FrameExtarct decoder finished, frame count: " << count << "."; -} - -DvppVideo::DvppVideo(aclrtContext context, uint8_t *data, uint32_t size, uint32_t width, uint32_t height, uint32_t type, - uint32_t out_format, const std::string &output) - : data_(data), - size_(size), - frameWidth_(width), - frameHeight_(height), - format_(out_format), - output_(output), - isStop_(false), - isReleased_(false), - isJam_(false), - status_(DecodeStatus::DECODE_UNINIT), - context_(context), - channelId_(INVALID_CHANNEL_ID), - streamFormat_(type), - frameId_(0), - finFrameCnt_(0), - lastDecodeTime_(0), - frameExtarct_(nullptr), - dvppVdec_(nullptr), - frameImageQueue_(kDecodeFrameQueueSize) {} - -DvppVideo::~DvppVideo() { DestroyResource(); } - -void DvppVideo::DestroyResource() { - if (isReleased_) { - return; - } - // 1. stop ffmpeg - isStop_ = true; - frameExtarct_->StopDecode(); - while ((status_ >= DecodeStatus::DECODE_START) && (status_ < DecodeStatus::DECODE_FRAME_EXTRACT_FINISHED)) { - (void)usleep(kWaitDecodeFinishInterval); - } - // 2. delete ffmpeg decoder - delete frameExtarct_; - frameExtarct_ = nullptr; - - // 3. release dvpp vdec - dvppVdec_->DestroyResource(); - - // 4. release image memory in decode output queue - do { - std::shared_ptr frame = FrameImageOutQueue(true); - if (frame == nullptr) { - break; - } - - if (frame->data != nullptr) { - auto ret = acldvppFree(frame->data.get()); - if (ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "acldvppFree failed, errorno: " << ret; - } - frame->data = nullptr; - } - } while (true); - // 5. release channel id - channelIdGenerator.ReleaseChannelId(channelId_); - - isReleased_ = true; -} - -AclLiteError DvppVideo::InitResource() { - aclError aclRet; - // use current thread context default - if (context_ == nullptr) { - aclRet = CALL_ASCEND_API(aclrtGetCurrentContext, &context_); - if ((aclRet != ACL_SUCCESS) || (context_ == nullptr)) { - MS_LOG(ERROR) << "Get current acl context error: " << aclRet; - return ACLLITE_ERROR_GET_ACL_CONTEXT; - } - } - // Get current run mode - aclRet = CALL_ASCEND_API(aclrtGetRunMode, &runMode_); - if (aclRet != ACL_SUCCESS) { - MS_LOG(ERROR) << "acl get run mode failed"; - return ACLLITE_ERROR_GET_RUM_MODE; - } - - return ACLLITE_OK; -} - -AclLiteError DvppVideo::InitVdecDecoder() { - // Generate a unique channel id for video decoder - channelId_ = channelIdGenerator.GenerateChannelId(); - if (channelId_ == INVALID_CHANNEL_ID) { - MS_LOG(ERROR) << "Decoder number excessive " << VIDEO_CHANNEL_MAX; - return ACLLITE_ERROR_TOO_MANY_VIDEO_DECODERS; - } - - // Create dvpp vdec to decode h26x data - dvppVdec_ = new VdecHelper(channelId_, frameWidth_, frameHeight_, streamFormat_, DvppVideo::DvppVdecCallback); - - AclLiteError ret = dvppVdec_->SetFormat(format_); - if (ret != ACLLITE_OK) { - MS_LOG(ERROR) << "Dvpp vdec set out format failed"; - } - ret = dvppVdec_->Init(); - if (ret != ACLLITE_OK) { - MS_LOG(ERROR) << "Dvpp vdec init failed"; - } - - ret = this->dvppVdec_->VideoParamCheck(); - if (ret != ACLLITE_OK) { - this->SetStatus(DecodeStatus::DECODE_ERROR); - MS_LOG(ERROR) << "Dvpp vdec check param failed " << ret; - return ret; - } - - return ret; -} - -AclLiteError DvppVideo::InitFrameExtractor() { - // Create ffmpeg decoder to parse video stream to h26x frame data - frameExtarct_ = new FrameExtarct(data_, size_, frameWidth_, frameHeight_, streamFormat_); - return ACLLITE_OK; -} - -AclLiteError DvppVideo::Init() { - // Open video stream, if open failed before, return error directly - if (status_ == DecodeStatus::DECODE_ERROR) { - return ACLLITE_ERROR_OPEN_VIDEO_UNREADY; - } - // If open ok already - if (status_ != DecodeStatus::DECODE_UNINIT) { - return ACLLITE_OK; - } - // Init acl resource - AclLiteError ret = InitResource(); - if (ret != ACLLITE_OK) { - this->SetStatus(DecodeStatus::DECODE_ERROR); - MS_LOG(ERROR) << "Dvpp video init resource failed " << ret; - return ret; - } - // Init ffmpeg decoder - ret = InitFrameExtractor(); - if (ret != ACLLITE_OK) { - this->SetStatus(DecodeStatus::DECODE_ERROR); - MS_LOG(ERROR) << "Dvpp video init FrameExtractor failed " << ret; - return ret; - } - // Init dvpp vdec decoder - ret = InitVdecDecoder(); - if (ret != ACLLITE_OK) { - this->SetStatus(DecodeStatus::DECODE_ERROR); - MS_LOG(ERROR) << "Dvpp video init Vdec failed " << ret; - return ret; - } - // Set init ok - this->SetStatus(DecodeStatus::DECODE_READY); - MS_LOG(INFO) << "Dvpp video init ok"; - - return ACLLITE_OK; -} - -// dvpp vdec callback -void DvppVideo::DvppVdecCallback(acldvppStreamDesc *input, acldvppPicDesc *output, void *userData) { - auto *decoder = reinterpret_cast(userData); - // Get decoded image parameters - std::shared_ptr image = std::make_shared(); - image->format = acldvppGetPicDescFormat(output); - image->width = acldvppGetPicDescWidth(output); - image->height = acldvppGetPicDescHeight(output); - image->alignWidth = acldvppGetPicDescWidthStride(output); - image->alignHeight = acldvppGetPicDescHeightStride(output); - image->size = acldvppGetPicDescSize(output); - - void *vdecOutBufferDev = acldvppGetPicDescData(output); - image->data = SHARED_PTR_DVPP_BUF(vdecOutBufferDev); - - // Put the decoded image to queue for read - decoder->ProcessDecodedImage(image); - // Release resource - aclError ret = acldvppDestroyPicDesc(output); - if (ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "Dvpp vdec destroy pic desc failed " << ret; - } - - if (input != nullptr) { - void *inputBuf = acldvppGetStreamDescData(input); - if (inputBuf != nullptr) { - ret = acldvppFree(inputBuf); - if (ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "acldvppFree failed, errorno: " << ret; - } - } - - ret = acldvppDestroyStreamDesc(input); - if (ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "Dvpp vdec destroy input stream failed " << ret; - } - } -} - -void DvppVideo::ProcessDecodedImage(std::shared_ptr frameData) { - finFrameCnt_++; - if (YUV420SP_SIZE(frameData->width, frameData->height) != frameData->size) { - MS_LOG(ERROR) << "Invalid decoded frame parameter, width " << frameData->width << ", height " << frameData->height - << ", size " << frameData->size << ", buffer " << static_cast(frameData->data.get()); - return; - } - - auto ret = FrameImageEnQueue(frameData); - if (ret != ACLLITE_OK) { - MS_LOG(ERROR) << "FrameImageEnQueue faile, errorno: " << ret; - } - - if ((status_ == DecodeStatus::DECODE_FRAME_EXTRACT_FINISHED) && (finFrameCnt_ >= frameId_)) { - MS_LOG(INFO) << "Last frame decoded by dvpp, change status to " << DecodeStatus::DECODE_DVPP_FINISHED; - this->SetStatus(DecodeStatus::DECODE_DVPP_FINISHED); - } -} - -AclLiteError DvppVideo::FrameImageEnQueue(const std::shared_ptr &frameData) { - for (int count = 0; count < kFrameEnQueueRetryTimes; count++) { - if (frameImageQueue_.Push(frameData)) { - return ACLLITE_OK; - } - (void)usleep(kDecodeQueueOpWait); - } - MS_LOG(ERROR) << "Video lost decoded image for queue full"; - - return ACLLITE_ERROR_VDEC_QUEUE_FULL; -} - -// start decoder -void DvppVideo::StartFrameDecoder() { - if (status_ == DecodeStatus::DECODE_READY) { - decodeThread_ = std::thread(FrameDecodeThreadFunction, reinterpret_cast(this)); - decodeThread_.detach(); - - status_ = DecodeStatus::DECODE_START; - } -} - -// ffmpeg decoder entry -void DvppVideo::FrameDecodeThreadFunction(void *decoderSelf) { - if (decoderSelf == nullptr) { - return; - } - auto *thisPtr = reinterpret_cast(decoderSelf); - - aclError aclRet = thisPtr->SetAclContext(); - if (aclRet != ACL_SUCCESS) { - MS_LOG(ERROR) << "Set frame decoder context failed, errorno: " << aclRet; - return; - } - // start decode until complete - thisPtr->DecodeH26xFrame(); - if (thisPtr->IsStop()) { - thisPtr->SetStatus(DecodeStatus::DECODE_FINISHED); - return; - } - thisPtr->SetStatus(DecodeStatus::DECODE_FRAME_EXTRACT_FINISHED); - // when ffmpeg decode finish, send eos to vdec - std::shared_ptr videoFrame = std::make_shared(); - videoFrame->isFinished = true; - videoFrame->data = nullptr; - videoFrame->size = 0; - auto ret = thisPtr->dvppVdec_->Process(videoFrame, decoderSelf); - if (ret != ACLLITE_ERROR) { - MS_LOG(ERROR) << "DvppVdec procesing failed, errorno: " << ret; - } - - thisPtr->dvppVdec_->DestroyChannel(); - while ((thisPtr->GetStatus() != DecodeStatus::DECODE_DVPP_FINISHED) && !thisPtr->IsStop()) { - (void)usleep(kWaitDecodeFinishInterval); - } -} - -// callback of ffmpeg decode frame -AclLiteError DvppVideo::FrameDecodeCallback(void *decoder, void *frameData, int frameSize) { - if ((frameData == nullptr) || (frameSize == 0)) { - MS_LOG(ERROR) << "Frame data is null"; - return ACLLITE_ERROR_H26X_FRAME; - } - - // copy data to dvpp memory - if (decoder == nullptr) { - MS_LOG(ERROR) << "Decoder is nullptr"; - return ACLLITE_ERROR_H26X_FRAME; - } - auto *videoDecoder = reinterpret_cast(decoder); - - void *buffer = CopyDataToDevice(frameData, frameSize, videoDecoder->runMode_, MemoryType::MEMORY_DVPP); - if (buffer == nullptr) { - MS_LOG(ERROR) << "Copy frame h26x data to dvpp failed"; - return ACLLITE_ERROR_COPY_DATA; - } - - std::shared_ptr videoFrame = std::make_shared(); - videoDecoder->frameId_++; - videoFrame->frameId = videoDecoder->frameId_; - videoFrame->data = buffer; - videoFrame->size = frameSize; - // decode data by dvpp vdec - AclLiteError ret = videoDecoder->dvppVdec_->Process(videoFrame, decoder); - if (ret != ACLLITE_OK) { - MS_LOG(ERROR) << "Dvpp vdec process " << videoDecoder->frameId_ << "th frame failed, error: " << ret; - return ret; - } - return ACLLITE_OK; -} - -// read decoded frame -AclLiteError DvppVideo::Read(std::shared_ptr *image_ptr) { - if (image_ptr == nullptr) { - MS_LOG(ERROR) << "image_ptr is nullptr"; - return ACLLITE_ERROR; - } - // return nullptr,if decode fail/finish - if (status_ == DecodeStatus::DECODE_ERROR) { - MS_LOG(ERROR) << "Read failed for decode failed"; - return ACLLITE_ERROR_VIDEO_DECODER_STATUS; - } - - if (status_ == DecodeStatus::DECODE_FINISHED) { - MS_LOG(INFO) << "No frame to read for decode finished"; - return ACLLITE_ERROR_DECODE_FINISH; - } - // start decode if status is ok - if (status_ == DecodeStatus::DECODE_READY) { - StartFrameDecoder(); - (void)usleep(kDecodeQueueOpWait); - } - // read frame from decode queue - bool noWait = (status_ == DecodeStatus::DECODE_DVPP_FINISHED); - std::shared_ptr frame = FrameImageOutQueue(noWait); - if (noWait && (frame == nullptr)) { - SetStatus(DecodeStatus::DECODE_FINISHED); - MS_LOG(INFO) << "No frame to read anymore"; - return ACLLITE_ERROR_DECODE_FINISH; - } - - if (frame == nullptr) { - MS_LOG(ERROR) << "Empty frame image to read"; - return ACLLITE_ERROR_READ_EMPTY; - } - - (*image_ptr)->format = frame->format; - (*image_ptr)->width = frame->width; - (*image_ptr)->height = frame->height; - (*image_ptr)->alignWidth = frame->alignWidth; - (*image_ptr)->alignHeight = frame->alignHeight; - (*image_ptr)->size = frame->size; - (*image_ptr)->data = frame->data; - - return ACLLITE_OK; -} - -std::shared_ptr DvppVideo::FrameImageOutQueue(bool noWait) { - std::shared_ptr image = frameImageQueue_.Pop(); - - if (noWait || (image != nullptr)) { - return image; - } - - for (int count = 0; count < kQueueOpRetryTimes - 1; count++) { - (void)usleep(kDecodeQueueOpWait); - - image = frameImageQueue_.Pop(); - if (image != nullptr) { - return image; - } - } - - return nullptr; -} - -// YUV data write to a file -void DvppVideo::SaveYuvFile(FILE *const fd, const ImageData &frame) { - auto *addr = reinterpret_cast(frame.data.get()); - uint32_t imageSize = frame.width * frame.height * 3 / 2; // Size = width * height * 3 / 2 - uint8_t *outImageBuf = nullptr; - uint32_t outWidthStride = frame.alignWidth; - uint32_t outHeightStride = frame.alignHeight; - - if (runMode_ == ACL_HOST) { - // malloc host memory - AclLiteError ret = CALL_ASCEND_API(aclrtMallocHost, reinterpret_cast(&outImageBuf), imageSize); - if (ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "Chn " << channelId_ << " malloc host memory " << imageSize << " failed, error code " << ret; - return; - } - } - - if ((frame.width == outWidthStride) && (frame.height == outHeightStride)) { - if (runMode_ == ACL_HOST) { - // copy device data to host - AclLiteError ret = - CALL_ASCEND_API(aclrtMemcpy, outImageBuf, imageSize, addr, imageSize, ACL_MEMCPY_DEVICE_TO_HOST); - if (ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "Chn " << channelId_ << " Copy aclrtMemcpy " << imageSize - << " from device to host failed, error code " << ret; - ret = CALL_ASCEND_API(aclrtFreeHost, outImageBuf); - if (ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "aclrtFreeHost failed, errorno: " << ret; - } - return; - } - - (void)fwrite(outImageBuf, 1, imageSize, fd); - ret = CALL_ASCEND_API(aclrtFreeHost, outImageBuf); - if (ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "aclrtFreeHost failed, errorno: " << ret; - return; - } - } else { - (void)fwrite(addr, imageSize, 1, fd); - } - } else { - if (runMode_ == ACL_HOST) { - if (outImageBuf == nullptr) { - return; - } - // Copy valid Y data - for (uint32_t i = 0; i < frame.height; i++) { - AclLiteError ret = CALL_ASCEND_API(aclrtMemcpy, outImageBuf + i * frame.width, frame.width, - addr + i * outWidthStride, frame.width, ACL_MEMCPY_DEVICE_TO_HOST); - if (ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "Chn " << channelId_ << " Copy aclrtMemcpy " << imageSize - << " from device to host failed, error code " << ret; - ret = CALL_ASCEND_API(aclrtFreeHost, outImageBuf); - if (ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "aclrtFreeHost failed, errorno: " << ret; - } - return; - } - } - // Copy valid UV data - for (uint32_t i = 0; i < frame.height / 2; i++) { - AclLiteError ret = CALL_ASCEND_API(aclrtMemcpy, outImageBuf + i * frame.width + frame.width * frame.height, - frame.width, addr + i * outWidthStride + outWidthStride * outHeightStride, - frame.width, ACL_MEMCPY_DEVICE_TO_HOST); - if (ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "Chn " << channelId_ << " Copy aclrtMemcpy " << imageSize - << " from device to host failed, error code " << ret; - ret = CALL_ASCEND_API(aclrtFreeHost, outImageBuf); - if (ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "aclrtFreeHost failed, errorno: " << ret; - } - return; - } - } - - (void)fwrite(outImageBuf, 1, imageSize, fd); - aclError ret = CALL_ASCEND_API(aclrtFreeHost, outImageBuf); - if (ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "aclrtFreeHost failed, errorno: " << ret; - } - } else { - // Crop the invalid data, then write the valid data to a file - outImageBuf = reinterpret_cast(malloc(imageSize)); - if (outImageBuf == nullptr) { - MS_LOG(ERROR) << "Chn " << channelId_ << " Malloc failed"; - return; - } - // Copy valid Y data - for (uint32_t i = 0; i < frame.height; i++) { - int status = memcpy_s(outImageBuf + i * frame.width, frame.width, addr + i * outWidthStride, frame.width); - if (status != EOK) { - MS_LOG(ERROR) << "[Internal ERROR] memcpy failed."; - free(outImageBuf); - return; - } - } - // Copy valid UV data - for (uint32_t i = 0; i < frame.height / 2; i++) { - int status = memcpy_s(outImageBuf + i * frame.width + frame.width * frame.height, frame.width, - addr + i * outWidthStride + outWidthStride * outHeightStride, frame.width); - if (status != EOK) { - MS_LOG(ERROR) << "[Internal ERROR] memcpy failed."; - free(outImageBuf); - return; - } - } - - (void)fwrite(outImageBuf, 1, imageSize, fd); - free(outImageBuf); - } - } -} - -AclLiteError DvppVideo::DumpFrame() { - auto frame = std::make_shared(); - int frameCnt = 0; - while (true) { - AclLiteError ret = Read(&frame); - if (ret != ACLLITE_OK) { - if (ret == ACLLITE_ERROR_DECODE_FINISH) { - MS_LOG(INFO) << "Dump all " << frameCnt << " frames to " << output_; - return ACLLITE_OK; - } else { - MS_LOG(ERROR) << "Dump " << frameCnt << "td frame failed"; - return ret; - } - } - frameCnt++; - std::string full_path = output_ + "/" + "frame_" + std::to_string(frameCnt) + ".yuv"; - MS_LOG(INFO) << "Dump the " << frameCnt << "th frame to " << full_path; - FILE *outFileFp = fopen(full_path.c_str(), "wb+"); - SaveYuvFile(outFileFp, *frame); - (void)fflush(outFileFp); - (void)fclose(outFileFp); - } -} - -AclLiteError DvppVideo::SetAclContext() { - if (context_ == nullptr) { - MS_LOG(ERROR) << "Video decoder context is null"; - return ACLLITE_ERROR_SET_ACL_CONTEXT; - } - - aclError ret = CALL_ASCEND_API(aclrtSetCurrentContext, context_); - if (ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "Video decoder set context failed, error: " << ret; - return ACLLITE_ERROR_SET_ACL_CONTEXT; - } - - return ACLLITE_OK; -} - -AclLiteError DvppVideo::Close() { - DestroyResource(); - return ACLLITE_OK; -} +/** + * Copyright 2022-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/image/dvpp/utils/dvpp_video.h" + +#include +#include + +#include +#include +#include +#include + +#include "minddata/dataset/kernels/image/dvpp/utils/AclLiteUtils.h" +#include "mindspore/core/utils/log_adapter.h" +#include "transform/symbol/acl_rt_symbol.h" +#include "transform/symbol/symbol_utils.h" + +namespace { +const int64_t kUsec = 1000000; +const uint32_t kDecodeFrameQueueSize = 256; +const int kDecodeQueueOpWait = 10000; // decode wait 10ms/frame +const int kFrameEnQueueRetryTimes = 1000; // max wait time for the frame to enter in queue +const int kQueueOpRetryTimes = 1000; +const int kOutputJamWait = 10000; +const int kInvalidTpye = -1; +const int kWaitDecodeFinishInterval = 1000; + +const uint32_t DVPP_VIDEO_H264 = 0; +const uint32_t DVPP_VIDEO_H265 = 1; + +ChannelIdGenerator channelIdGenerator; +} // namespace + +FrameExtarct::FrameExtarct(uint8_t *data, uint32_t size, uint32_t width, uint32_t height, uint32_t type) + : data_(data), size_(size), frameWidth_(width), frameHeight_(height) { + isFinished_ = false; + isStop_ = false; + videoType_ = (type == 0) ? DVPP_VIDEO_H265 : DVPP_VIDEO_H264; +} + +void FrameExtarct::ExtractFrameH264(const uint8_t *buf_ptr, int *size_ptr) { + if (buf_ptr == nullptr || size_ptr == nullptr || *size_ptr <= 0) { + return; + } + bool isFindStart = false; + bool isFindEnd = false; + int size = *size_ptr; + int i = 0; + for (; i < size - 8; i++) { + if (FindStartH264(buf_ptr, i)) { + isFindStart = true; + i += 8; + break; + } + } + + for (; i < size - 8; i++) { + if (FindEndH264(buf_ptr, i)) { + isFindEnd = true; + break; + } + } + + if (i > 0) { + *size_ptr = i; + } + + if (!isFindStart) { + MS_LOG(ERROR) << "Channel can not find H265 start code, please check input video coding protocol is H264."; + return; + } + if (!isFindEnd) { + *size_ptr = i + 8; + } +} + +void FrameExtarct::ExtractFrameH265(const uint8_t *buf_ptr, int *size_ptr) { + if (buf_ptr == nullptr || size_ptr == nullptr || *size_ptr <= 0) { + return; + } + bool isFindStart = false; + bool isFindEnd = false; + int i = 0; + for (; i < *size_ptr - 6; i++) { + if (FindStartH265(buf_ptr, i)) { + isFindStart = true; + i += 6; + break; + } + } + + for (; i < *size_ptr - 6; i++) { + if (FindEndH265(buf_ptr, i)) { + isFindEnd = true; + break; + } + } + if (i > 0) { + *size_ptr = i; + } + + if (!isFindStart) { + MS_LOG(ERROR) << "Channel can not find H265 start code, please check input video coding protocol is H265."; + return; + } + if (!isFindEnd) { + *size_ptr = i + 6; + } +} + +void FrameExtarct::Decode(FrameProcessCallBack callback, void *callbackParam) { + MS_LOG(INFO) << "Start extarct frame from video..."; + + int32_t usedBytes = 0; + uint32_t count = 0; + bool processOk = true; + + while (!isStop_ && processOk) { + uint8_t *bufPointer; + bufPointer = data_ + usedBytes; + int32_t readlen = size_ - usedBytes; + if (readlen <= 0) { + break; + } + + if (videoType_ == DVPP_VIDEO_H264) { // H264 + ExtractFrameH264(bufPointer, &readlen); + } else if (videoType_ == DVPP_VIDEO_H265) { // H265 + ExtractFrameH265(bufPointer, &readlen); + } + int ret = callback(callbackParam, bufPointer, readlen); + if (ret != 0) { + processOk = false; + } + count++; + usedBytes = usedBytes + readlen; + } + // Frame count + + isFinished_ = true; + MS_LOG(INFO) << "FrameExtarct decoder finished, frame count: " << count << "."; +} + +DvppVideo::DvppVideo(aclrtContext context, uint8_t *data, uint32_t size, uint32_t width, uint32_t height, uint32_t type, + uint32_t out_format, const std::string &output) + : data_(data), + size_(size), + frameWidth_(width), + frameHeight_(height), + format_(out_format), + output_(output), + isStop_(false), + isReleased_(false), + isJam_(false), + status_(DecodeStatus::DECODE_UNINIT), + context_(context), + channelId_(INVALID_CHANNEL_ID), + streamFormat_(type), + frameId_(0), + finFrameCnt_(0), + lastDecodeTime_(0), + frameExtarct_(nullptr), + dvppVdec_(nullptr), + frameImageQueue_(kDecodeFrameQueueSize) {} + +DvppVideo::~DvppVideo() { DestroyResource(); } + +void DvppVideo::DestroyResource() { + if (isReleased_) { + return; + } + // 1. stop ffmpeg + isStop_ = true; + frameExtarct_->StopDecode(); + while ((status_ >= DecodeStatus::DECODE_START) && (status_ < DecodeStatus::DECODE_FRAME_EXTRACT_FINISHED)) { + (void)usleep(kWaitDecodeFinishInterval); + } + // 2. delete ffmpeg decoder + delete frameExtarct_; + frameExtarct_ = nullptr; + + // 3. release dvpp vdec + dvppVdec_->DestroyResource(); + + // 4. release image memory in decode output queue + do { + std::shared_ptr frame = FrameImageOutQueue(true); + if (frame == nullptr) { + break; + } + + if (frame->data != nullptr) { + auto ret = acldvppFree(frame->data.get()); + if (ret != ACL_SUCCESS) { + MS_LOG(ERROR) << "acldvppFree failed, errorno: " << ret; + } + frame->data = nullptr; + } + } while (true); + // 5. release channel id + channelIdGenerator.ReleaseChannelId(channelId_); + + isReleased_ = true; +} + +AclLiteError DvppVideo::InitResource() { + aclError aclRet; + // use current thread context default + if (context_ == nullptr) { + aclRet = CALL_ASCEND_API(aclrtGetCurrentContext, &context_); + if ((aclRet != ACL_SUCCESS) || (context_ == nullptr)) { + MS_LOG(ERROR) << "Get current acl context error: " << aclRet; + return ACLLITE_ERROR_GET_ACL_CONTEXT; + } + } + // Get current run mode + aclRet = CALL_ASCEND_API(aclrtGetRunMode, &runMode_); + if (aclRet != ACL_SUCCESS) { + MS_LOG(ERROR) << "acl get run mode failed"; + return ACLLITE_ERROR_GET_RUM_MODE; + } + + return ACLLITE_OK; +} + +AclLiteError DvppVideo::InitVdecDecoder() { + // Generate a unique channel id for video decoder + channelId_ = channelIdGenerator.GenerateChannelId(); + if (channelId_ == INVALID_CHANNEL_ID) { + MS_LOG(ERROR) << "Decoder number excessive " << VIDEO_CHANNEL_MAX; + return ACLLITE_ERROR_TOO_MANY_VIDEO_DECODERS; + } + + // Create dvpp vdec to decode h26x data + dvppVdec_ = new VdecHelper(channelId_, frameWidth_, frameHeight_, streamFormat_, DvppVideo::DvppVdecCallback); + + AclLiteError ret = dvppVdec_->SetFormat(format_); + if (ret != ACLLITE_OK) { + MS_LOG(ERROR) << "Dvpp vdec set out format failed"; + } + ret = dvppVdec_->Init(); + if (ret != ACLLITE_OK) { + MS_LOG(ERROR) << "Dvpp vdec init failed"; + } + + ret = this->dvppVdec_->VideoParamCheck(); + if (ret != ACLLITE_OK) { + this->SetStatus(DecodeStatus::DECODE_ERROR); + MS_LOG(ERROR) << "Dvpp vdec check param failed " << ret; + return ret; + } + + return ret; +} + +AclLiteError DvppVideo::InitFrameExtractor() { + // Create ffmpeg decoder to parse video stream to h26x frame data + frameExtarct_ = new FrameExtarct(data_, size_, frameWidth_, frameHeight_, streamFormat_); + return ACLLITE_OK; +} + +AclLiteError DvppVideo::Init() { + // Open video stream, if open failed before, return error directly + if (status_ == DecodeStatus::DECODE_ERROR) { + return ACLLITE_ERROR_OPEN_VIDEO_UNREADY; + } + // If open ok already + if (status_ != DecodeStatus::DECODE_UNINIT) { + return ACLLITE_OK; + } + // Init acl resource + AclLiteError ret = InitResource(); + if (ret != ACLLITE_OK) { + this->SetStatus(DecodeStatus::DECODE_ERROR); + MS_LOG(ERROR) << "Dvpp video init resource failed " << ret; + return ret; + } + // Init ffmpeg decoder + ret = InitFrameExtractor(); + if (ret != ACLLITE_OK) { + this->SetStatus(DecodeStatus::DECODE_ERROR); + MS_LOG(ERROR) << "Dvpp video init FrameExtractor failed " << ret; + return ret; + } + // Init dvpp vdec decoder + ret = InitVdecDecoder(); + if (ret != ACLLITE_OK) { + this->SetStatus(DecodeStatus::DECODE_ERROR); + MS_LOG(ERROR) << "Dvpp video init Vdec failed " << ret; + return ret; + } + // Set init ok + this->SetStatus(DecodeStatus::DECODE_READY); + MS_LOG(INFO) << "Dvpp video init ok"; + + return ACLLITE_OK; +} + +// dvpp vdec callback +void DvppVideo::DvppVdecCallback(acldvppStreamDesc *input, acldvppPicDesc *output, void *userData) { + auto *decoder = reinterpret_cast(userData); + // Get decoded image parameters + std::shared_ptr image = std::make_shared(); + image->format = acldvppGetPicDescFormat(output); + image->width = acldvppGetPicDescWidth(output); + image->height = acldvppGetPicDescHeight(output); + image->alignWidth = acldvppGetPicDescWidthStride(output); + image->alignHeight = acldvppGetPicDescHeightStride(output); + image->size = acldvppGetPicDescSize(output); + + void *vdecOutBufferDev = acldvppGetPicDescData(output); + image->data = SHARED_PTR_DVPP_BUF(vdecOutBufferDev); + + // Put the decoded image to queue for read + decoder->ProcessDecodedImage(image); + // Release resource + aclError ret = acldvppDestroyPicDesc(output); + if (ret != ACL_SUCCESS) { + MS_LOG(ERROR) << "Dvpp vdec destroy pic desc failed " << ret; + } + + if (input != nullptr) { + void *inputBuf = acldvppGetStreamDescData(input); + if (inputBuf != nullptr) { + ret = acldvppFree(inputBuf); + if (ret != ACL_SUCCESS) { + MS_LOG(ERROR) << "acldvppFree failed, errorno: " << ret; + } + } + + ret = acldvppDestroyStreamDesc(input); + if (ret != ACL_SUCCESS) { + MS_LOG(ERROR) << "Dvpp vdec destroy input stream failed " << ret; + } + } +} + +void DvppVideo::ProcessDecodedImage(std::shared_ptr frameData) { + finFrameCnt_++; + if (YUV420SP_SIZE(frameData->width, frameData->height) != frameData->size) { + MS_LOG(ERROR) << "Invalid decoded frame parameter, width " << frameData->width << ", height " << frameData->height + << ", size " << frameData->size << ", buffer " << static_cast(frameData->data.get()); + return; + } + + auto ret = FrameImageEnQueue(frameData); + if (ret != ACLLITE_OK) { + MS_LOG(ERROR) << "FrameImageEnQueue faile, errorno: " << ret; + } + + if ((status_ == DecodeStatus::DECODE_FRAME_EXTRACT_FINISHED) && (finFrameCnt_ >= frameId_)) { + MS_LOG(INFO) << "Last frame decoded by dvpp, change status to " << DecodeStatus::DECODE_DVPP_FINISHED; + this->SetStatus(DecodeStatus::DECODE_DVPP_FINISHED); + } +} + +AclLiteError DvppVideo::FrameImageEnQueue(const std::shared_ptr &frameData) { + for (int count = 0; count < kFrameEnQueueRetryTimes; count++) { + if (frameImageQueue_.Push(frameData)) { + return ACLLITE_OK; + } + (void)usleep(kDecodeQueueOpWait); + } + MS_LOG(ERROR) << "Video lost decoded image for queue full"; + + return ACLLITE_ERROR_VDEC_QUEUE_FULL; +} + +// start decoder +void DvppVideo::StartFrameDecoder() { + if (status_ == DecodeStatus::DECODE_READY) { + decodeThread_ = std::thread(FrameDecodeThreadFunction, reinterpret_cast(this)); + decodeThread_.detach(); + + status_ = DecodeStatus::DECODE_START; + } +} + +// ffmpeg decoder entry +void DvppVideo::FrameDecodeThreadFunction(void *decoderSelf) { + if (decoderSelf == nullptr) { + return; + } + auto *thisPtr = reinterpret_cast(decoderSelf); + + aclError aclRet = thisPtr->SetAclContext(); + if (aclRet != ACL_SUCCESS) { + MS_LOG(ERROR) << "Set frame decoder context failed, errorno: " << aclRet; + return; + } + // start decode until complete + thisPtr->DecodeH26xFrame(); + if (thisPtr->IsStop()) { + thisPtr->SetStatus(DecodeStatus::DECODE_FINISHED); + return; + } + thisPtr->SetStatus(DecodeStatus::DECODE_FRAME_EXTRACT_FINISHED); + // when ffmpeg decode finish, send eos to vdec + std::shared_ptr videoFrame = std::make_shared(); + videoFrame->isFinished = true; + videoFrame->data = nullptr; + videoFrame->size = 0; + auto ret = thisPtr->dvppVdec_->Process(videoFrame, decoderSelf); + if (ret != ACLLITE_ERROR) { + MS_LOG(ERROR) << "DvppVdec procesing failed, errorno: " << ret; + } + + thisPtr->dvppVdec_->DestroyChannel(); + while ((thisPtr->GetStatus() != DecodeStatus::DECODE_DVPP_FINISHED) && !thisPtr->IsStop()) { + (void)usleep(kWaitDecodeFinishInterval); + } +} + +// callback of ffmpeg decode frame +AclLiteError DvppVideo::FrameDecodeCallback(void *decoder, void *frameData, int frameSize) { + if ((frameData == nullptr) || (frameSize == 0)) { + MS_LOG(ERROR) << "Frame data is null"; + return ACLLITE_ERROR_H26X_FRAME; + } + + // copy data to dvpp memory + if (decoder == nullptr) { + MS_LOG(ERROR) << "Decoder is nullptr"; + return ACLLITE_ERROR_H26X_FRAME; + } + auto *videoDecoder = reinterpret_cast(decoder); + + void *buffer = CopyDataToDevice(frameData, frameSize, videoDecoder->runMode_, MemoryType::MEMORY_DVPP); + if (buffer == nullptr) { + MS_LOG(ERROR) << "Copy frame h26x data to dvpp failed"; + return ACLLITE_ERROR_COPY_DATA; + } + + std::shared_ptr videoFrame = std::make_shared(); + videoDecoder->frameId_++; + videoFrame->frameId = videoDecoder->frameId_; + videoFrame->data = buffer; + videoFrame->size = frameSize; + // decode data by dvpp vdec + AclLiteError ret = videoDecoder->dvppVdec_->Process(videoFrame, decoder); + if (ret != ACLLITE_OK) { + MS_LOG(ERROR) << "Dvpp vdec process " << videoDecoder->frameId_ << "th frame failed, error: " << ret; + return ret; + } + return ACLLITE_OK; +} + +// read decoded frame +AclLiteError DvppVideo::Read(std::shared_ptr *image_ptr) { + if (image_ptr == nullptr) { + MS_LOG(ERROR) << "image_ptr is nullptr"; + return ACLLITE_ERROR; + } + // return nullptr,if decode fail/finish + if (status_ == DecodeStatus::DECODE_ERROR) { + MS_LOG(ERROR) << "Read failed for decode failed"; + return ACLLITE_ERROR_VIDEO_DECODER_STATUS; + } + + if (status_ == DecodeStatus::DECODE_FINISHED) { + MS_LOG(INFO) << "No frame to read for decode finished"; + return ACLLITE_ERROR_DECODE_FINISH; + } + // start decode if status is ok + if (status_ == DecodeStatus::DECODE_READY) { + StartFrameDecoder(); + (void)usleep(kDecodeQueueOpWait); + } + // read frame from decode queue + bool noWait = (status_ == DecodeStatus::DECODE_DVPP_FINISHED); + std::shared_ptr frame = FrameImageOutQueue(noWait); + if (noWait && (frame == nullptr)) { + SetStatus(DecodeStatus::DECODE_FINISHED); + MS_LOG(INFO) << "No frame to read anymore"; + return ACLLITE_ERROR_DECODE_FINISH; + } + + if (frame == nullptr) { + MS_LOG(ERROR) << "Empty frame image to read"; + return ACLLITE_ERROR_READ_EMPTY; + } + + (*image_ptr)->format = frame->format; + (*image_ptr)->width = frame->width; + (*image_ptr)->height = frame->height; + (*image_ptr)->alignWidth = frame->alignWidth; + (*image_ptr)->alignHeight = frame->alignHeight; + (*image_ptr)->size = frame->size; + (*image_ptr)->data = frame->data; + + return ACLLITE_OK; +} + +std::shared_ptr DvppVideo::FrameImageOutQueue(bool noWait) { + std::shared_ptr image = frameImageQueue_.Pop(); + + if (noWait || (image != nullptr)) { + return image; + } + + for (int count = 0; count < kQueueOpRetryTimes - 1; count++) { + (void)usleep(kDecodeQueueOpWait); + + image = frameImageQueue_.Pop(); + if (image != nullptr) { + return image; + } + } + + return nullptr; +} + +// YUV data write to a file +void DvppVideo::SaveYuvFile(FILE *const fd, const ImageData &frame) { + auto *addr = reinterpret_cast(frame.data.get()); + uint32_t imageSize = frame.width * frame.height * 3 / 2; // Size = width * height * 3 / 2 + uint8_t *outImageBuf = nullptr; + uint32_t outWidthStride = frame.alignWidth; + uint32_t outHeightStride = frame.alignHeight; + + if (runMode_ == ACL_HOST) { + // malloc host memory + AclLiteError ret = CALL_ASCEND_API(aclrtMallocHost, reinterpret_cast(&outImageBuf), imageSize); + if (ret != ACL_SUCCESS) { + MS_LOG(ERROR) << "Chn " << channelId_ << " malloc host memory " << imageSize << " failed, error code " << ret; + return; + } + } + + if ((frame.width == outWidthStride) && (frame.height == outHeightStride)) { + if (runMode_ == ACL_HOST) { + // copy device data to host + AclLiteError ret = + CALL_ASCEND_API(aclrtMemcpy, outImageBuf, imageSize, addr, imageSize, ACL_MEMCPY_DEVICE_TO_HOST); + if (ret != ACL_SUCCESS) { + MS_LOG(ERROR) << "Chn " << channelId_ << " Copy aclrtMemcpy " << imageSize + << " from device to host failed, error code " << ret; + ret = CALL_ASCEND_API(aclrtFreeHost, outImageBuf); + if (ret != ACL_SUCCESS) { + MS_LOG(ERROR) << "aclrtFreeHost failed, errorno: " << ret; + } + return; + } + + (void)fwrite(outImageBuf, 1, imageSize, fd); + ret = CALL_ASCEND_API(aclrtFreeHost, outImageBuf); + if (ret != ACL_SUCCESS) { + MS_LOG(ERROR) << "aclrtFreeHost failed, errorno: " << ret; + return; + } + } else { + (void)fwrite(addr, imageSize, 1, fd); + } + } else { + if (runMode_ == ACL_HOST) { + if (outImageBuf == nullptr) { + return; + } + // Copy valid Y data + for (uint32_t i = 0; i < frame.height; i++) { + AclLiteError ret = CALL_ASCEND_API(aclrtMemcpy, outImageBuf + i * frame.width, frame.width, + addr + i * outWidthStride, frame.width, ACL_MEMCPY_DEVICE_TO_HOST); + if (ret != ACL_SUCCESS) { + MS_LOG(ERROR) << "Chn " << channelId_ << " Copy aclrtMemcpy " << imageSize + << " from device to host failed, error code " << ret; + ret = CALL_ASCEND_API(aclrtFreeHost, outImageBuf); + if (ret != ACL_SUCCESS) { + MS_LOG(ERROR) << "aclrtFreeHost failed, errorno: " << ret; + } + return; + } + } + // Copy valid UV data + for (uint32_t i = 0; i < frame.height / 2; i++) { + AclLiteError ret = CALL_ASCEND_API(aclrtMemcpy, outImageBuf + i * frame.width + frame.width * frame.height, + frame.width, addr + i * outWidthStride + outWidthStride * outHeightStride, + frame.width, ACL_MEMCPY_DEVICE_TO_HOST); + if (ret != ACL_SUCCESS) { + MS_LOG(ERROR) << "Chn " << channelId_ << " Copy aclrtMemcpy " << imageSize + << " from device to host failed, error code " << ret; + ret = CALL_ASCEND_API(aclrtFreeHost, outImageBuf); + if (ret != ACL_SUCCESS) { + MS_LOG(ERROR) << "aclrtFreeHost failed, errorno: " << ret; + } + return; + } + } + + (void)fwrite(outImageBuf, 1, imageSize, fd); + aclError ret = CALL_ASCEND_API(aclrtFreeHost, outImageBuf); + if (ret != ACL_SUCCESS) { + MS_LOG(ERROR) << "aclrtFreeHost failed, errorno: " << ret; + } + } else { + // Crop the invalid data, then write the valid data to a file + outImageBuf = reinterpret_cast(malloc(imageSize)); + if (outImageBuf == nullptr) { + MS_LOG(ERROR) << "Chn " << channelId_ << " Malloc failed"; + return; + } + // Copy valid Y data + for (uint32_t i = 0; i < frame.height; i++) { + int status = memcpy_s(outImageBuf + i * frame.width, frame.width, addr + i * outWidthStride, frame.width); + if (status != EOK) { + MS_LOG(ERROR) << "[Internal ERROR] memcpy failed."; + free(outImageBuf); + return; + } + } + // Copy valid UV data + for (uint32_t i = 0; i < frame.height / 2; i++) { + int status = memcpy_s(outImageBuf + i * frame.width + frame.width * frame.height, frame.width, + addr + i * outWidthStride + outWidthStride * outHeightStride, frame.width); + if (status != EOK) { + MS_LOG(ERROR) << "[Internal ERROR] memcpy failed."; + free(outImageBuf); + return; + } + } + + (void)fwrite(outImageBuf, 1, imageSize, fd); + free(outImageBuf); + } + } +} + +AclLiteError DvppVideo::DumpFrame() { + auto frame = std::make_shared(); + int frameCnt = 0; + while (true) { + AclLiteError ret = Read(&frame); + if (ret != ACLLITE_OK) { + if (ret == ACLLITE_ERROR_DECODE_FINISH) { + MS_LOG(INFO) << "Dump all " << frameCnt << " frames to " << output_; + return ACLLITE_OK; + } else { + MS_LOG(ERROR) << "Dump " << frameCnt << "td frame failed"; + return ret; + } + } + frameCnt++; + std::string full_path = output_ + "/" + "frame_" + std::to_string(frameCnt) + ".yuv"; + MS_LOG(INFO) << "Dump the " << frameCnt << "th frame to " << full_path; + FILE *outFileFp = fopen(full_path.c_str(), "wb+"); + SaveYuvFile(outFileFp, *frame); + (void)fflush(outFileFp); + (void)fclose(outFileFp); + } +} + +AclLiteError DvppVideo::SetAclContext() { + if (context_ == nullptr) { + MS_LOG(ERROR) << "Video decoder context is null"; + return ACLLITE_ERROR_SET_ACL_CONTEXT; + } + + aclError ret = CALL_ASCEND_API(aclrtSetCurrentContext, context_); + if (ret != ACL_SUCCESS) { + MS_LOG(ERROR) << "Video decoder set context failed, error: " << ret; + return ACLLITE_ERROR_SET_ACL_CONTEXT; + } + + return ACLLITE_OK; +} + +AclLiteError DvppVideo::Close() { + DestroyResource(); + return ACLLITE_OK; +} diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/dvpp_video.h b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/dvpp_video.h index 7cb899ac35a..8f883a92766 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/dvpp_video.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/dvpp_video.h @@ -1,215 +1,215 @@ -/** - * Copyright 2022-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_DVPP_VIDEO_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_DVPP_VIDEO_H_ - -#include -#include - -#include -#include -#include -#include -#include -#include - -#include "minddata/dataset/kernels/image/dvpp/utils/ThreadSafeQueue.h" -#include "minddata/dataset/kernels/image/dvpp/utils/VdecHelper.h" - -constexpr int INVALID_CHANNEL_ID = -1; -constexpr int INVALID_STREAM_FORMAT = -1; -constexpr int VIDEO_CHANNEL_MAX = 23; -constexpr int THIRD_ELEMENT_INDEX = 2; -constexpr int FOURTH_ELEMENT_INDEX = 3; -constexpr int FIFTH_ELEMENT_INDEX = 4; -constexpr int SIXTH_ELEMENT_INDEX = 5; -constexpr int EIGHTH_ELEMENT_INDEX = 7; - -using FrameProcessCallBack = int (*)(void *callback_param, void *frame_data, int frame_size); - -enum class DecodeStatus { - DECODE_ERROR = -1, - DECODE_UNINIT = 0, - DECODE_READY = 1, - DECODE_START = 2, - DECODE_FRAME_EXTRACT_FINISHED = 3, - DECODE_DVPP_FINISHED = 4, - DECODE_FINISHED = 5 -}; - -class ChannelIdGenerator { - public: - ChannelIdGenerator() noexcept { - for (int i = 0; i < VIDEO_CHANNEL_MAX; i++) { - channelId_[i] = INVALID_CHANNEL_ID; - } - } - ~ChannelIdGenerator() = default; - - int GenerateChannelId() { - std::lock_guard lock(mutex_lock_); - for (int i = 0; i < VIDEO_CHANNEL_MAX; i++) { - if (channelId_[i] == INVALID_CHANNEL_ID) { - channelId_[i] = i; - return i; - } - } - - return INVALID_CHANNEL_ID; - } - - void ReleaseChannelId(int channelId) { - std::lock_guard lock(mutex_lock_); - if ((channelId >= 0) && (channelId < VIDEO_CHANNEL_MAX)) { - channelId_[channelId] = INVALID_CHANNEL_ID; - } - } - - private: - int channelId_[VIDEO_CHANNEL_MAX]{}; - mutable std::mutex mutex_lock_; -}; - -class FrameExtarct { - public: - FrameExtarct(uint8_t *data, uint32_t size, uint32_t width, uint32_t height, uint32_t type); - ~FrameExtarct() = default; - void Decode(FrameProcessCallBack callback, void *callbackParam); - void ExtractFrameH264(const uint8_t *buf_ptr, int *size_ptr); - void ExtractFrameH265(const uint8_t *buf_ptr, int *size_ptr); - int IsFinished() const { return isFinished_; } - void StopDecode() { isStop_ = true; } - - private: - inline bool FindStartH264(const uint8_t *buf, int idx) { - int32_t tmp = buf[idx + FOURTH_ELEMENT_INDEX] & 0x1F; - // Find 00 00 01 - return (buf[idx] == 0) && (buf[idx + 1] == 0) && (buf[idx + THIRD_ELEMENT_INDEX] == 1) && - (((tmp == 0x5 || tmp == 0x1) && ((buf[idx + FIFTH_ELEMENT_INDEX] & 0x80) == 0x80)) || - (tmp == 0x14 && (buf[idx + EIGHTH_ELEMENT_INDEX] & 0x80) == 0x80)); - } - - inline bool FindEndH264(const uint8_t *buf, int idx) { - // Find 00 00 01 - int32_t tmp = buf[idx + FOURTH_ELEMENT_INDEX] & 0x1F; - return (buf[idx] == 0) && (buf[idx + 1] == 0) && (buf[idx + THIRD_ELEMENT_INDEX] == 1) && - ((tmp == 0xF) || (tmp == 0x7) || (tmp == 0x8) || (tmp == 0x6) || - ((tmp == 0x5 || tmp == 1) && ((buf[idx + FIFTH_ELEMENT_INDEX] & 0x80) == 0x80)) || - (tmp == 0x14 && (buf[idx + EIGHTH_ELEMENT_INDEX] & 0x80) == 0x80)); - } - - inline bool FindStartH265(const uint8_t *buf, int idx) { - uint32_t tmp = (buf[idx + FOURTH_ELEMENT_INDEX] & 0x7EU) >> 1; - // Find 00 00 01 - return (buf[idx + 0] == 0) && (buf[idx + 1] == 0) && (buf[idx + THIRD_ELEMENT_INDEX] == 1) && (tmp <= 0x15U) && - ((buf[idx + SIXTH_ELEMENT_INDEX] & 0x80) == 0x80); - } - - inline bool FindEndH265(const uint8_t *buf, int idx) { - uint32_t tmp = (buf[idx + FOURTH_ELEMENT_INDEX] & 0x7EU) >> 1; - // Find 00 00 01 - return ((buf[idx + 0] == 0) && (buf[idx + 1] == 0) && (buf[idx + THIRD_ELEMENT_INDEX] == 1) && - ((tmp == 0x20U) || (tmp == 0x21U) || (tmp == 0x22U) || (tmp == 0x27U) || (tmp == 0x28U) || - ((tmp <= 0x15U) && (buf[idx + SIXTH_ELEMENT_INDEX] & 0x80) == 0x80))); - } - - private: - uint8_t *data_; - uint32_t size_; - - uint32_t frameWidth_; - uint32_t frameHeight_; - int videoType_; - - bool isFinished_; - bool isStop_; -}; - -class DvppVideo { - public: - /** - * @brief DvppVideo constructor - */ - DvppVideo(aclrtContext context, uint8_t *data, uint32_t size, uint32_t width, uint32_t height, uint32_t type, - uint32_t out_format, const std::string &output); - - /** - * @brief DvppVideo destructor - */ - ~DvppVideo(); - - static void FrameDecodeThreadFunction(void *decoderSelf); - static AclLiteError FrameDecodeCallback(void *context, void *frameData, int frameSize); - static void DvppVdecCallback(acldvppStreamDesc *input, acldvppPicDesc *output, void *userdata); - - void ProcessDecodedImage(std::shared_ptr frameData); - void DecodeH26xFrame() { frameExtarct_->Decode(&DvppVideo::FrameDecodeCallback, reinterpret_cast(this)); } - - AclLiteError Init(); - void SetStatus(DecodeStatus status) { status_ = status; } - DecodeStatus GetStatus() { return status_; } - - AclLiteError Read(std::shared_ptr *image_ptr); - - AclLiteError DumpFrame(); - - AclLiteError SetAclContext(); - AclLiteError Close(); - - void DestroyResource(); - bool IsStop() const { return isStop_; } - bool IsJam() const { return isJam_; } - - private: - AclLiteError InitResource(); - AclLiteError InitVdecDecoder(); - AclLiteError InitFrameExtractor(); - void StartFrameDecoder(); - AclLiteError FrameImageEnQueue(const std::shared_ptr &frameData); - std::shared_ptr FrameImageOutQueue(bool noWait = false); - - void SaveYuvFile(FILE *fd, const ImageData &frame); - - private: - uint8_t *data_; - uint32_t size_; - - uint32_t frameWidth_; - uint32_t frameHeight_; - - /* 1:YUV420 semi-planner(nv12) - 2:YVU420 semi-planner(nv21) - */ - uint32_t format_; - std::string output_; - - bool isStop_; - bool isReleased_; - bool isJam_; - DecodeStatus status_; - aclrtContext context_; - aclrtRunMode runMode_; - int channelId_; - int streamFormat_; - uint32_t frameId_; - uint32_t finFrameCnt_; - int64_t lastDecodeTime_; - std::thread decodeThread_; - FrameExtarct *frameExtarct_; - VdecHelper *dvppVdec_; - ThreadSafeQueue> frameImageQueue_; -}; -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_DVPP_VIDEO_H_ +/** + * Copyright 2022-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_DVPP_VIDEO_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_DVPP_VIDEO_H_ + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/kernels/image/dvpp/utils/ThreadSafeQueue.h" +#include "minddata/dataset/kernels/image/dvpp/utils/VdecHelper.h" + +constexpr int INVALID_CHANNEL_ID = -1; +constexpr int INVALID_STREAM_FORMAT = -1; +constexpr int VIDEO_CHANNEL_MAX = 23; +constexpr int THIRD_ELEMENT_INDEX = 2; +constexpr int FOURTH_ELEMENT_INDEX = 3; +constexpr int FIFTH_ELEMENT_INDEX = 4; +constexpr int SIXTH_ELEMENT_INDEX = 5; +constexpr int EIGHTH_ELEMENT_INDEX = 7; + +using FrameProcessCallBack = int (*)(void *callback_param, void *frame_data, int frame_size); + +enum class DecodeStatus { + DECODE_ERROR = -1, + DECODE_UNINIT = 0, + DECODE_READY = 1, + DECODE_START = 2, + DECODE_FRAME_EXTRACT_FINISHED = 3, + DECODE_DVPP_FINISHED = 4, + DECODE_FINISHED = 5 +}; + +class ChannelIdGenerator { + public: + ChannelIdGenerator() noexcept { + for (int i = 0; i < VIDEO_CHANNEL_MAX; i++) { + channelId_[i] = INVALID_CHANNEL_ID; + } + } + ~ChannelIdGenerator() = default; + + int GenerateChannelId() { + std::lock_guard lock(mutex_lock_); + for (int i = 0; i < VIDEO_CHANNEL_MAX; i++) { + if (channelId_[i] == INVALID_CHANNEL_ID) { + channelId_[i] = i; + return i; + } + } + + return INVALID_CHANNEL_ID; + } + + void ReleaseChannelId(int channelId) { + std::lock_guard lock(mutex_lock_); + if ((channelId >= 0) && (channelId < VIDEO_CHANNEL_MAX)) { + channelId_[channelId] = INVALID_CHANNEL_ID; + } + } + + private: + int channelId_[VIDEO_CHANNEL_MAX]{}; + mutable std::mutex mutex_lock_; +}; + +class FrameExtarct { + public: + FrameExtarct(uint8_t *data, uint32_t size, uint32_t width, uint32_t height, uint32_t type); + ~FrameExtarct() = default; + void Decode(FrameProcessCallBack callback, void *callbackParam); + void ExtractFrameH264(const uint8_t *buf_ptr, int *size_ptr); + void ExtractFrameH265(const uint8_t *buf_ptr, int *size_ptr); + int IsFinished() const { return isFinished_; } + void StopDecode() { isStop_ = true; } + + private: + inline bool FindStartH264(const uint8_t *buf, int idx) { + int32_t tmp = buf[idx + FOURTH_ELEMENT_INDEX] & 0x1F; + // Find 00 00 01 + return (buf[idx] == 0) && (buf[idx + 1] == 0) && (buf[idx + THIRD_ELEMENT_INDEX] == 1) && + (((tmp == 0x5 || tmp == 0x1) && ((buf[idx + FIFTH_ELEMENT_INDEX] & 0x80) == 0x80)) || + (tmp == 0x14 && (buf[idx + EIGHTH_ELEMENT_INDEX] & 0x80) == 0x80)); + } + + inline bool FindEndH264(const uint8_t *buf, int idx) { + // Find 00 00 01 + int32_t tmp = buf[idx + FOURTH_ELEMENT_INDEX] & 0x1F; + return (buf[idx] == 0) && (buf[idx + 1] == 0) && (buf[idx + THIRD_ELEMENT_INDEX] == 1) && + ((tmp == 0xF) || (tmp == 0x7) || (tmp == 0x8) || (tmp == 0x6) || + ((tmp == 0x5 || tmp == 1) && ((buf[idx + FIFTH_ELEMENT_INDEX] & 0x80) == 0x80)) || + (tmp == 0x14 && (buf[idx + EIGHTH_ELEMENT_INDEX] & 0x80) == 0x80)); + } + + inline bool FindStartH265(const uint8_t *buf, int idx) { + uint32_t tmp = (buf[idx + FOURTH_ELEMENT_INDEX] & 0x7EU) >> 1; + // Find 00 00 01 + return (buf[idx + 0] == 0) && (buf[idx + 1] == 0) && (buf[idx + THIRD_ELEMENT_INDEX] == 1) && (tmp <= 0x15U) && + ((buf[idx + SIXTH_ELEMENT_INDEX] & 0x80) == 0x80); + } + + inline bool FindEndH265(const uint8_t *buf, int idx) { + uint32_t tmp = (buf[idx + FOURTH_ELEMENT_INDEX] & 0x7EU) >> 1; + // Find 00 00 01 + return ((buf[idx + 0] == 0) && (buf[idx + 1] == 0) && (buf[idx + THIRD_ELEMENT_INDEX] == 1) && + ((tmp == 0x20U) || (tmp == 0x21U) || (tmp == 0x22U) || (tmp == 0x27U) || (tmp == 0x28U) || + ((tmp <= 0x15U) && (buf[idx + SIXTH_ELEMENT_INDEX] & 0x80) == 0x80))); + } + + private: + uint8_t *data_; + uint32_t size_; + + uint32_t frameWidth_; + uint32_t frameHeight_; + int videoType_; + + bool isFinished_; + bool isStop_; +}; + +class DvppVideo { + public: + /** + * @brief DvppVideo constructor + */ + DvppVideo(aclrtContext context, uint8_t *data, uint32_t size, uint32_t width, uint32_t height, uint32_t type, + uint32_t out_format, const std::string &output); + + /** + * @brief DvppVideo destructor + */ + ~DvppVideo(); + + static void FrameDecodeThreadFunction(void *decoderSelf); + static AclLiteError FrameDecodeCallback(void *context, void *frameData, int frameSize); + static void DvppVdecCallback(acldvppStreamDesc *input, acldvppPicDesc *output, void *userdata); + + void ProcessDecodedImage(std::shared_ptr frameData); + void DecodeH26xFrame() { frameExtarct_->Decode(&DvppVideo::FrameDecodeCallback, reinterpret_cast(this)); } + + AclLiteError Init(); + void SetStatus(DecodeStatus status) { status_ = status; } + DecodeStatus GetStatus() { return status_; } + + AclLiteError Read(std::shared_ptr *image_ptr); + + AclLiteError DumpFrame(); + + AclLiteError SetAclContext(); + AclLiteError Close(); + + void DestroyResource(); + bool IsStop() const { return isStop_; } + bool IsJam() const { return isJam_; } + + private: + AclLiteError InitResource(); + AclLiteError InitVdecDecoder(); + AclLiteError InitFrameExtractor(); + void StartFrameDecoder(); + AclLiteError FrameImageEnQueue(const std::shared_ptr &frameData); + std::shared_ptr FrameImageOutQueue(bool noWait = false); + + void SaveYuvFile(FILE *fd, const ImageData &frame); + + private: + uint8_t *data_; + uint32_t size_; + + uint32_t frameWidth_; + uint32_t frameHeight_; + + /* 1:YUV420 semi-planner(nv12) + 2:YVU420 semi-planner(nv21) + */ + uint32_t format_; + std::string output_; + + bool isStop_; + bool isReleased_; + bool isJam_; + DecodeStatus status_; + aclrtContext context_; + aclrtRunMode runMode_; + int channelId_; + int streamFormat_; + uint32_t frameId_; + uint32_t finFrameCnt_; + int64_t lastDecodeTime_; + std::thread decodeThread_; + FrameExtarct *frameExtarct_; + VdecHelper *dvppVdec_; + ThreadSafeQueue> frameImageQueue_; +}; +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_DVPP_UTILS_DVPP_VIDEO_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h old mode 100755 new mode 100644 diff --git a/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/posterize_ir.cc b/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/posterize_ir.cc index e0fb15f11b3..b0fbfe62831 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/posterize_ir.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/posterize_ir.cc @@ -1,99 +1,99 @@ -/** - * Copyright 2022-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "minddata/dataset/kernels/ir/vision/posterize_ir.h" - -#ifndef ENABLE_ANDROID -#include "minddata/dataset/kernels/image/posterize_op.h" -#endif -#if !defined(BUILD_LITE) && defined(ENABLE_D) -#include "minddata/dataset/kernels/image/dvpp/ascend910b/dvpp_posterize_op.h" -#endif -#include "minddata/dataset/util/validators.h" - -namespace mindspore { -namespace dataset { -namespace vision { -#ifndef ENABLE_ANDROID -// PosterizeOperation -PosterizeOperation::PosterizeOperation(uint8_t bits, const std::string &device_target) - : bits_(bits), device_target_(device_target) {} - -PosterizeOperation::~PosterizeOperation() = default; - -Status PosterizeOperation::ValidateParams() { - constexpr uint8_t kMinimumBitValue = 0; - constexpr uint8_t kMaximumBitValue = 8; - - if (bits_ < kMinimumBitValue || bits_ > kMaximumBitValue) { - std::string err_msg = "Posterize: bits is out of range [0, 8], got: " + std::to_string(bits_); - LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg); - } - - // device target - if (device_target_ != "CPU" && device_target_ != "Ascend") { - std::string err_msg = "Posterize: Invalid device target. It's not CPU or Ascend."; - LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg); - } - - return Status::OK(); -} - -std::shared_ptr PosterizeOperation::Build() { - if (device_target_ == "CPU") { - std::shared_ptr tensor_op = std::make_shared(bits_); - return tensor_op; -#if !defined(BUILD_LITE) && defined(ENABLE_D) - } else if (device_target_ == "Ascend") { - std::shared_ptr dvpp_tensor_op = std::make_shared(bits_); - return dvpp_tensor_op; -#endif - } else { - MS_LOG(ERROR) << "Posterize: Invalid device target. It's not CPU or Ascend."; - return nullptr; - } -} - -Status PosterizeOperation::to_json(nlohmann::json *out_json) { - RETURN_UNEXPECTED_IF_NULL(out_json); - (*out_json)["bits"] = bits_; - (*out_json)["device_target"] = device_target_; - return Status::OK(); -} - -Status PosterizeOperation::from_json(nlohmann::json op_params, std::shared_ptr *operation) { - RETURN_UNEXPECTED_IF_NULL(operation); - RETURN_IF_NOT_OK(ValidateParamInJson(op_params, "bits", kPosterizeOperation)); - RETURN_IF_NOT_OK(ValidateParamInJson(op_params, "device_target", kPosterizeOperation)); - uint8_t bits_ = op_params["bits"]; - std::string device_target = op_params["device_target"]; - *operation = std::make_shared(bits_, device_target); - return Status::OK(); -} - -MapTargetDevice PosterizeOperation::Type() { - if (device_target_ == "CPU") { - return MapTargetDevice::kCpu; - } else if (device_target_ == "Ascend") { - return MapTargetDevice::kAscend910B; - } else { - MS_LOG(ERROR) << "Posterize: Invalid device target. It's not CPU or Ascend."; - } - return MapTargetDevice::kInvalid; -} -#endif -} // namespace vision -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2022-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/ir/vision/posterize_ir.h" + +#ifndef ENABLE_ANDROID +#include "minddata/dataset/kernels/image/posterize_op.h" +#endif +#if !defined(BUILD_LITE) && defined(ENABLE_D) +#include "minddata/dataset/kernels/image/dvpp/ascend910b/dvpp_posterize_op.h" +#endif +#include "minddata/dataset/util/validators.h" + +namespace mindspore { +namespace dataset { +namespace vision { +#ifndef ENABLE_ANDROID +// PosterizeOperation +PosterizeOperation::PosterizeOperation(uint8_t bits, const std::string &device_target) + : bits_(bits), device_target_(device_target) {} + +PosterizeOperation::~PosterizeOperation() = default; + +Status PosterizeOperation::ValidateParams() { + constexpr uint8_t kMinimumBitValue = 0; + constexpr uint8_t kMaximumBitValue = 8; + + if (bits_ < kMinimumBitValue || bits_ > kMaximumBitValue) { + std::string err_msg = "Posterize: bits is out of range [0, 8], got: " + std::to_string(bits_); + LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg); + } + + // device target + if (device_target_ != "CPU" && device_target_ != "Ascend") { + std::string err_msg = "Posterize: Invalid device target. It's not CPU or Ascend."; + LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg); + } + + return Status::OK(); +} + +std::shared_ptr PosterizeOperation::Build() { + if (device_target_ == "CPU") { + std::shared_ptr tensor_op = std::make_shared(bits_); + return tensor_op; +#if !defined(BUILD_LITE) && defined(ENABLE_D) + } else if (device_target_ == "Ascend") { + std::shared_ptr dvpp_tensor_op = std::make_shared(bits_); + return dvpp_tensor_op; +#endif + } else { + MS_LOG(ERROR) << "Posterize: Invalid device target. It's not CPU or Ascend."; + return nullptr; + } +} + +Status PosterizeOperation::to_json(nlohmann::json *out_json) { + RETURN_UNEXPECTED_IF_NULL(out_json); + (*out_json)["bits"] = bits_; + (*out_json)["device_target"] = device_target_; + return Status::OK(); +} + +Status PosterizeOperation::from_json(nlohmann::json op_params, std::shared_ptr *operation) { + RETURN_UNEXPECTED_IF_NULL(operation); + RETURN_IF_NOT_OK(ValidateParamInJson(op_params, "bits", kPosterizeOperation)); + RETURN_IF_NOT_OK(ValidateParamInJson(op_params, "device_target", kPosterizeOperation)); + uint8_t bits_ = op_params["bits"]; + std::string device_target = op_params["device_target"]; + *operation = std::make_shared(bits_, device_target); + return Status::OK(); +} + +MapTargetDevice PosterizeOperation::Type() { + if (device_target_ == "CPU") { + return MapTargetDevice::kCpu; + } else if (device_target_ == "Ascend") { + return MapTargetDevice::kAscend910B; + } else { + MS_LOG(ERROR) << "Posterize: Invalid device target. It's not CPU or Ascend."; + } + return MapTargetDevice::kInvalid; +} +#endif +} // namespace vision +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/posterize_ir.h b/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/posterize_ir.h index 5b6f1ea58a3..1d120ea6b7e 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/posterize_ir.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/posterize_ir.h @@ -1,59 +1,59 @@ -/** - * Copyright 2022-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_POSTERIZE_IR_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_POSTERIZE_IR_H_ - -#include -#include -#include - -#include "include/api/status.h" -#include "minddata/dataset/include/dataset/constants.h" -#include "minddata/dataset/include/dataset/transforms.h" -#include "minddata/dataset/kernels/ir/tensor_operation.h" - -namespace mindspore { -namespace dataset { -namespace vision { -constexpr char kPosterizeOperation[] = "Posterize"; - -class PosterizeOperation : public TensorOperation { - public: - explicit PosterizeOperation(uint8_t bits, const std::string &device_target = "CPU"); - - ~PosterizeOperation() override; - - std::shared_ptr Build() override; - - Status ValidateParams() override; - - std::string Name() const override { return kPosterizeOperation; }; - - Status to_json(nlohmann::json *out_json) override; - - static Status from_json(nlohmann::json op_params, std::shared_ptr *operation); - - MapTargetDevice Type() override; - - private: - uint8_t bits_; - std::string device_target_; -}; -} // namespace vision -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_POSTERIZE_IR_H_ +/** + * Copyright 2022-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_POSTERIZE_IR_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_POSTERIZE_IR_H_ + +#include +#include +#include + +#include "include/api/status.h" +#include "minddata/dataset/include/dataset/constants.h" +#include "minddata/dataset/include/dataset/transforms.h" +#include "minddata/dataset/kernels/ir/tensor_operation.h" + +namespace mindspore { +namespace dataset { +namespace vision { +constexpr char kPosterizeOperation[] = "Posterize"; + +class PosterizeOperation : public TensorOperation { + public: + explicit PosterizeOperation(uint8_t bits, const std::string &device_target = "CPU"); + + ~PosterizeOperation() override; + + std::shared_ptr Build() override; + + Status ValidateParams() override; + + std::string Name() const override { return kPosterizeOperation; }; + + Status to_json(nlohmann::json *out_json) override; + + static Status from_json(nlohmann::json op_params, std::shared_ptr *operation); + + MapTargetDevice Type() override; + + private: + uint8_t bits_; + std::string device_target_; +}; +} // namespace vision +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_POSTERIZE_IR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/char_n_gram.cc b/mindspore/ccsrc/minddata/dataset/text/char_n_gram.cc index c15e2c96f66..70b8e4f2c6b 100644 --- a/mindspore/ccsrc/minddata/dataset/text/char_n_gram.cc +++ b/mindspore/ccsrc/minddata/dataset/text/char_n_gram.cc @@ -1,98 +1,98 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/text/char_n_gram.h" - -#include "utils/file_utils.h" - -namespace mindspore { -namespace dataset { -CharNGram::CharNGram(const std::unordered_map> &map, int32_t dim) : Vectors(map, dim) {} - -Status CharNGram::BuildFromFile(std::shared_ptr *char_n_gram, const std::string &path, int32_t max_vectors) { - RETURN_UNEXPECTED_IF_NULL(char_n_gram); - std::unordered_map> map; - int vector_dim = -1; - RETURN_IF_NOT_OK(CharNGram::Load(path, max_vectors, &map, &vector_dim)); - *char_n_gram = std::make_shared(std::move(map), vector_dim); - return Status::OK(); -} - -std::vector CharNGram::Lookup(const std::string &token, const std::vector &unk_init, - bool lower_case_backup) { - std::vector init_vec(dim_, 0); - if (!unk_init.empty()) { - if (unk_init.size() != dim_) { - MS_LOG(WARNING) << "CharNGram: size of unk_init is not the same as vectors, will initialize with zero vectors."; - } else { - init_vec = unk_init; - } - } - std::string lower_token = token; - if (lower_case_backup) { - std::transform(lower_token.begin(), lower_token.end(), lower_token.begin(), ::tolower); - } - - std::vector chars; - chars.push_back("#BEGIN#"); - for (int i = 0; i < lower_token.length(); i++) { - std::string s; - s.push_back(lower_token[i]); // Convert a char type letter to a string type. - chars.push_back(s); - } - chars.push_back("#END#"); - - int len = chars.size(); - int num_vectors = 0; - std::vector vector_value_sum(dim_, 0); - std::vector vector_value_temp; - // The length of meaningful characters in the pre-training file is 2, 3, 4. - const int slice_len[3] = {2, 3, 4}; - const int slice_len_size = sizeof(slice_len) / sizeof(slice_len[0]); - for (int i = 0; i < slice_len_size; i++) { - int end = len - slice_len[i] + 1; - for (int pos = 0; pos < end; pos++) { - std::vector gram_vec; - std::vector::const_iterator first = chars.begin() + pos; - std::vector::const_iterator second = first + slice_len[i]; - gram_vec.assign(first, second); - std::string c = ""; - std::string gram = std::accumulate(gram_vec.begin(), gram_vec.end(), c); - std::string gram_key = std::to_string(slice_len[i]) + "gram-" + gram; - auto str_index = map_.find(gram_key); - if (str_index == map_.end()) { - vector_value_temp = init_vec; - } else { - vector_value_temp = str_index->second; - } - if (vector_value_temp != init_vec) { - std::transform(vector_value_temp.begin(), vector_value_temp.end(), vector_value_sum.begin(), - vector_value_sum.begin(), std::plus()); - num_vectors++; - } - } - } - std::vector vector_value(dim_, 0); - if (num_vectors > 0) { - std::transform(vector_value_sum.begin(), vector_value_sum.end(), vector_value.begin(), - [&num_vectors](float value) -> float { return value / num_vectors; }); - return vector_value; - } else { - return init_vec; - } -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/text/char_n_gram.h" + +#include "utils/file_utils.h" + +namespace mindspore { +namespace dataset { +CharNGram::CharNGram(const std::unordered_map> &map, int32_t dim) : Vectors(map, dim) {} + +Status CharNGram::BuildFromFile(std::shared_ptr *char_n_gram, const std::string &path, int32_t max_vectors) { + RETURN_UNEXPECTED_IF_NULL(char_n_gram); + std::unordered_map> map; + int vector_dim = -1; + RETURN_IF_NOT_OK(CharNGram::Load(path, max_vectors, &map, &vector_dim)); + *char_n_gram = std::make_shared(std::move(map), vector_dim); + return Status::OK(); +} + +std::vector CharNGram::Lookup(const std::string &token, const std::vector &unk_init, + bool lower_case_backup) { + std::vector init_vec(dim_, 0); + if (!unk_init.empty()) { + if (unk_init.size() != dim_) { + MS_LOG(WARNING) << "CharNGram: size of unk_init is not the same as vectors, will initialize with zero vectors."; + } else { + init_vec = unk_init; + } + } + std::string lower_token = token; + if (lower_case_backup) { + std::transform(lower_token.begin(), lower_token.end(), lower_token.begin(), ::tolower); + } + + std::vector chars; + chars.push_back("#BEGIN#"); + for (int i = 0; i < lower_token.length(); i++) { + std::string s; + s.push_back(lower_token[i]); // Convert a char type letter to a string type. + chars.push_back(s); + } + chars.push_back("#END#"); + + int len = chars.size(); + int num_vectors = 0; + std::vector vector_value_sum(dim_, 0); + std::vector vector_value_temp; + // The length of meaningful characters in the pre-training file is 2, 3, 4. + const int slice_len[3] = {2, 3, 4}; + const int slice_len_size = sizeof(slice_len) / sizeof(slice_len[0]); + for (int i = 0; i < slice_len_size; i++) { + int end = len - slice_len[i] + 1; + for (int pos = 0; pos < end; pos++) { + std::vector gram_vec; + std::vector::const_iterator first = chars.begin() + pos; + std::vector::const_iterator second = first + slice_len[i]; + gram_vec.assign(first, second); + std::string c = ""; + std::string gram = std::accumulate(gram_vec.begin(), gram_vec.end(), c); + std::string gram_key = std::to_string(slice_len[i]) + "gram-" + gram; + auto str_index = map_.find(gram_key); + if (str_index == map_.end()) { + vector_value_temp = init_vec; + } else { + vector_value_temp = str_index->second; + } + if (vector_value_temp != init_vec) { + std::transform(vector_value_temp.begin(), vector_value_temp.end(), vector_value_sum.begin(), + vector_value_sum.begin(), std::plus()); + num_vectors++; + } + } + } + std::vector vector_value(dim_, 0); + if (num_vectors > 0) { + std::transform(vector_value_sum.begin(), vector_value_sum.end(), vector_value.begin(), + [&num_vectors](float value) -> float { return value / num_vectors; }); + return vector_value; + } else { + return init_vec; + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/char_n_gram.h b/mindspore/ccsrc/minddata/dataset/text/char_n_gram.h index 360cd757cfb..92606e838f6 100644 --- a/mindspore/ccsrc/minddata/dataset/text/char_n_gram.h +++ b/mindspore/ccsrc/minddata/dataset/text/char_n_gram.h @@ -1,64 +1,64 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_CHAR_N_GRAM_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_CHAR_N_GRAM_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "minddata/dataset/text/vectors.h" - -namespace mindspore { -namespace dataset { -/// \brief Build CharNGram vectors from reading a Pre-train word vectors. -class CharNGram : public Vectors { - public: - // Constructor. - CharNGram() = default; - - /// Constructor. - /// \param[in] map A map between string and vector. - /// \param[in] dim Dimension of the vectors. - CharNGram(const std::unordered_map> &map, int32_t dim); - - // Destructor. - ~CharNGram() = default; - - /// \brief Build CharNGram from reading a CharNGram pre-train vector file. - /// \param[out] char_n_gram CharNGram object which contains the pre-train vectors. - /// \param[in] path Path to the CharNGram pre-trained word vector file. - /// \param[in] max_vectors This can be used to limit the number of pre-trained vectors loaded (default=0, no limit). - static Status BuildFromFile(std::shared_ptr *char_n_gram, const std::string &path, - int32_t max_vectors = 0); - - /// \brief Look up embedding vectors of token. - /// \param[in] token A token to be looked up. - /// \param[in] unk_init In case of the token is out-of-vectors (OOV), the result will be initialized with `unk_init`. - /// (default={}, means to initialize with zero vectors). - /// \param[in] lower_case_backup Whether to look up the token in the lower case (Default = false). - /// \return The vector of the input token. - std::vector Lookup(const std::string &token, const std::vector &unk_init = {}, - bool lower_case_backup = false); -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_CHAR_N_GRAM_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_CHAR_N_GRAM_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_CHAR_N_GRAM_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/text/vectors.h" + +namespace mindspore { +namespace dataset { +/// \brief Build CharNGram vectors from reading a Pre-train word vectors. +class CharNGram : public Vectors { + public: + // Constructor. + CharNGram() = default; + + /// Constructor. + /// \param[in] map A map between string and vector. + /// \param[in] dim Dimension of the vectors. + CharNGram(const std::unordered_map> &map, int32_t dim); + + // Destructor. + ~CharNGram() = default; + + /// \brief Build CharNGram from reading a CharNGram pre-train vector file. + /// \param[out] char_n_gram CharNGram object which contains the pre-train vectors. + /// \param[in] path Path to the CharNGram pre-trained word vector file. + /// \param[in] max_vectors This can be used to limit the number of pre-trained vectors loaded (default=0, no limit). + static Status BuildFromFile(std::shared_ptr *char_n_gram, const std::string &path, + int32_t max_vectors = 0); + + /// \brief Look up embedding vectors of token. + /// \param[in] token A token to be looked up. + /// \param[in] unk_init In case of the token is out-of-vectors (OOV), the result will be initialized with `unk_init`. + /// (default={}, means to initialize with zero vectors). + /// \param[in] lower_case_backup Whether to look up the token in the lower case (Default = false). + /// \return The vector of the input token. + std::vector Lookup(const std::string &token, const std::vector &unk_init = {}, + bool lower_case_backup = false); +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_CHAR_N_GRAM_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/fast_text.cc b/mindspore/ccsrc/minddata/dataset/text/fast_text.cc index e4037a7db88..b4c5e77745e 100644 --- a/mindspore/ccsrc/minddata/dataset/text/fast_text.cc +++ b/mindspore/ccsrc/minddata/dataset/text/fast_text.cc @@ -1,50 +1,50 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/text/fast_text.h" - -#include "utils/file_utils.h" - -namespace mindspore { -namespace dataset { -FastText::FastText(const std::unordered_map> &map, int32_t dim) : Vectors(map, dim) {} - -Status CheckFastText(const std::string &file_path) { - Path path = Path(file_path); - if (path.Exists() && !path.IsDirectory()) { - std::string basename = path.Basename(); - size_t dot = basename.rfind('.'); - std::string suffix = basename.substr(dot + 1); - if (suffix != "vec") { - RETURN_STATUS_UNEXPECTED("FastText: invalid file, can not find file '*.vec', but got: " + file_path); - } - return Status::OK(); - } else { - RETURN_STATUS_UNEXPECTED("FastText: invalid file, failed to open FastText file."); - } -} - -Status FastText::BuildFromFile(std::shared_ptr *fast_text, const std::string &path, int32_t max_vectors) { - RETURN_UNEXPECTED_IF_NULL(fast_text); - RETURN_IF_NOT_OK(CheckFastText(path)); - std::unordered_map> map; - int vector_dim = -1; - RETURN_IF_NOT_OK(Load(path, max_vectors, &map, &vector_dim)); - *fast_text = std::make_shared(std::move(map), vector_dim); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/text/fast_text.h" + +#include "utils/file_utils.h" + +namespace mindspore { +namespace dataset { +FastText::FastText(const std::unordered_map> &map, int32_t dim) : Vectors(map, dim) {} + +Status CheckFastText(const std::string &file_path) { + Path path = Path(file_path); + if (path.Exists() && !path.IsDirectory()) { + std::string basename = path.Basename(); + size_t dot = basename.rfind('.'); + std::string suffix = basename.substr(dot + 1); + if (suffix != "vec") { + RETURN_STATUS_UNEXPECTED("FastText: invalid file, can not find file '*.vec', but got: " + file_path); + } + return Status::OK(); + } else { + RETURN_STATUS_UNEXPECTED("FastText: invalid file, failed to open FastText file."); + } +} + +Status FastText::BuildFromFile(std::shared_ptr *fast_text, const std::string &path, int32_t max_vectors) { + RETURN_UNEXPECTED_IF_NULL(fast_text); + RETURN_IF_NOT_OK(CheckFastText(path)); + std::unordered_map> map; + int vector_dim = -1; + RETURN_IF_NOT_OK(Load(path, max_vectors, &map, &vector_dim)); + *fast_text = std::make_shared(std::move(map), vector_dim); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/fast_text.h b/mindspore/ccsrc/minddata/dataset/text/fast_text.h index 0421c7ed49c..6bde296e89e 100644 --- a/mindspore/ccsrc/minddata/dataset/text/fast_text.h +++ b/mindspore/ccsrc/minddata/dataset/text/fast_text.h @@ -1,55 +1,55 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_FAST_TEXT_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_FAST_TEXT_H_ - -#include -#include -#include -#include -#include - -#include "minddata/dataset/core/tensor.h" -#include "minddata/dataset/include/dataset/iterator.h" -#include "minddata/dataset/text/vectors.h" -#include "minddata/dataset/util/path.h" - -namespace mindspore { -namespace dataset { -/// \brief Pre-train word vectors. -class FastText : public Vectors { - public: - /// Constructor. - FastText() = default; - - /// Constructor. - /// \param[in] map A map between string and vector. - /// \param[in] dim Dimension of the vectors. - FastText(const std::unordered_map> &map, int32_t dim); - - /// Destructor. - ~FastText() = default; - - /// \brief Build Vectors from reading a pre-train vector file. - /// \param[out] fast_text FastText object which contains the pre-train vectors. - /// \param[in] path Path to the pre-trained word vector file. The suffix of set must be `*.vec`. - /// \param[in] max_vectors This can be used to limit the number of pre-trained vectors loaded (default=0, no limit). - static Status BuildFromFile(std::shared_ptr *fast_text, const std::string &path, int32_t max_vectors = 0); -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_FAST_TEXT_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_FAST_TEXT_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_FAST_TEXT_H_ + +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/include/dataset/iterator.h" +#include "minddata/dataset/text/vectors.h" +#include "minddata/dataset/util/path.h" + +namespace mindspore { +namespace dataset { +/// \brief Pre-train word vectors. +class FastText : public Vectors { + public: + /// Constructor. + FastText() = default; + + /// Constructor. + /// \param[in] map A map between string and vector. + /// \param[in] dim Dimension of the vectors. + FastText(const std::unordered_map> &map, int32_t dim); + + /// Destructor. + ~FastText() = default; + + /// \brief Build Vectors from reading a pre-train vector file. + /// \param[out] fast_text FastText object which contains the pre-train vectors. + /// \param[in] path Path to the pre-trained word vector file. The suffix of set must be `*.vec`. + /// \param[in] max_vectors This can be used to limit the number of pre-trained vectors loaded (default=0, no limit). + static Status BuildFromFile(std::shared_ptr *fast_text, const std::string &path, int32_t max_vectors = 0); +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_FAST_TEXT_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/glove.cc b/mindspore/ccsrc/minddata/dataset/text/glove.cc index d338b46e401..6460e7754bf 100644 --- a/mindspore/ccsrc/minddata/dataset/text/glove.cc +++ b/mindspore/ccsrc/minddata/dataset/text/glove.cc @@ -1,56 +1,56 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/text/glove.h" - -#include "utils/file_utils.h" - -namespace mindspore { -namespace dataset { -GloVe::GloVe(const std::unordered_map> &map, int32_t dim) : Vectors(map, dim) {} - -Status CheckGloVe(const std::string &file_path) { - Path path = Path(file_path); - if (path.Exists() && !path.IsDirectory()) { - std::string basename = path.Basename(); - size_t dot = basename.rfind('.'); - std::string suffix = basename.substr(dot + 1); - std::string sub_name = basename.substr(0, dot); - dot = sub_name.rfind('.'); - std::string glove_name = sub_name.substr(0, dot); - dot = glove_name.rfind('.'); - std::string infix = glove_name.substr(dot + 1); - std::string prefix = glove_name.substr(0, dot); - if (suffix != "txt" || infix != "6B" || prefix != "glove") { - RETURN_STATUS_UNEXPECTED("GloVe: invalid file, can not find file 'glove.6B.*.txt', but got: " + file_path); - } - return Status::OK(); - } else { - RETURN_STATUS_UNEXPECTED("GloVe: invalid file, failed to open GloVe file."); - } -} - -Status GloVe::BuildFromFile(std::shared_ptr *glove, const std::string &path, int32_t max_vectors) { - RETURN_UNEXPECTED_IF_NULL(glove); - RETURN_IF_NOT_OK(CheckGloVe(path)); - std::unordered_map> map; - int vector_dim = -1; - RETURN_IF_NOT_OK(Load(path, max_vectors, &map, &vector_dim)); - *glove = std::make_shared(std::move(map), vector_dim); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/text/glove.h" + +#include "utils/file_utils.h" + +namespace mindspore { +namespace dataset { +GloVe::GloVe(const std::unordered_map> &map, int32_t dim) : Vectors(map, dim) {} + +Status CheckGloVe(const std::string &file_path) { + Path path = Path(file_path); + if (path.Exists() && !path.IsDirectory()) { + std::string basename = path.Basename(); + size_t dot = basename.rfind('.'); + std::string suffix = basename.substr(dot + 1); + std::string sub_name = basename.substr(0, dot); + dot = sub_name.rfind('.'); + std::string glove_name = sub_name.substr(0, dot); + dot = glove_name.rfind('.'); + std::string infix = glove_name.substr(dot + 1); + std::string prefix = glove_name.substr(0, dot); + if (suffix != "txt" || infix != "6B" || prefix != "glove") { + RETURN_STATUS_UNEXPECTED("GloVe: invalid file, can not find file 'glove.6B.*.txt', but got: " + file_path); + } + return Status::OK(); + } else { + RETURN_STATUS_UNEXPECTED("GloVe: invalid file, failed to open GloVe file."); + } +} + +Status GloVe::BuildFromFile(std::shared_ptr *glove, const std::string &path, int32_t max_vectors) { + RETURN_UNEXPECTED_IF_NULL(glove); + RETURN_IF_NOT_OK(CheckGloVe(path)); + std::unordered_map> map; + int vector_dim = -1; + RETURN_IF_NOT_OK(Load(path, max_vectors, &map, &vector_dim)); + *glove = std::make_shared(std::move(map), vector_dim); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/glove.h b/mindspore/ccsrc/minddata/dataset/text/glove.h index c5f47dfe2a6..f4a91b98822 100644 --- a/mindspore/ccsrc/minddata/dataset/text/glove.h +++ b/mindspore/ccsrc/minddata/dataset/text/glove.h @@ -1,55 +1,55 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_GLOVE_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_GLOVE_H_ - -#include -#include -#include -#include -#include - -#include "minddata/dataset/core/tensor.h" -#include "minddata/dataset/include/dataset/iterator.h" -#include "minddata/dataset/text/vectors.h" -#include "minddata/dataset/util/path.h" - -namespace mindspore { -namespace dataset { -/// \brief Pre-train word vectors. -class GloVe : public Vectors { - public: - /// Constructor. - GloVe() = default; - - /// Constructor. - /// \param[in] map A map between string and vector. - /// \param[in] dim Dimension of the vectors. - GloVe(const std::unordered_map> &map, int32_t dim); - - /// Destructor. - ~GloVe() = default; - - /// \brief Build Vectors from reading a pre-train vector file. - /// \param[out] glove GloVe object which contains the pre-train vectors. - /// \param[in] path Path to the pre-trained word vector file. - /// \param[in] max_vectors This can be used to limit the number of pre-trained vectors loaded (default=0, no limit). - static Status BuildFromFile(std::shared_ptr *glove, const std::string &path, int32_t max_vectors = 0); -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_GLOVE_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_GLOVE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_GLOVE_H_ + +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/include/dataset/iterator.h" +#include "minddata/dataset/text/vectors.h" +#include "minddata/dataset/util/path.h" + +namespace mindspore { +namespace dataset { +/// \brief Pre-train word vectors. +class GloVe : public Vectors { + public: + /// Constructor. + GloVe() = default; + + /// Constructor. + /// \param[in] map A map between string and vector. + /// \param[in] dim Dimension of the vectors. + GloVe(const std::unordered_map> &map, int32_t dim); + + /// Destructor. + ~GloVe() = default; + + /// \brief Build Vectors from reading a pre-train vector file. + /// \param[out] glove GloVe object which contains the pre-train vectors. + /// \param[in] path Path to the pre-trained word vector file. + /// \param[in] max_vectors This can be used to limit the number of pre-trained vectors loaded (default=0, no limit). + static Status BuildFromFile(std::shared_ptr *glove, const std::string &path, int32_t max_vectors = 0); +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_GLOVE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/to_vectors_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/to_vectors_op.cc index 9033d5410fa..984bb28da61 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/to_vectors_op.cc +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/to_vectors_op.cc @@ -1,58 +1,58 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "minddata/dataset/text/kernels/to_vectors_op.h" - -namespace mindspore { -namespace dataset { -ToVectorsOp::ToVectorsOp(const std::shared_ptr &vectors, const std::vector &unk_init, - bool lower_case_backup) - : vectors_(vectors), unk_init_(unk_init), lower_case_backup_(lower_case_backup) {} - -Status ToVectorsOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "ToVectors: input tensor type should be string."); - CHECK_FAIL_RETURN_UNEXPECTED(unk_init_.size() == 0 || unk_init_.size() == vectors_->Dim(), - "ToVectors: unk_init must be the same length as vectors, but got unk_init: " + - std::to_string(unk_init_.size()) + " and vectors: " + std::to_string(vectors_->Dim())); - - std::vector vectors_vec; - int len = 0; - for (auto itr = input->begin(); itr != input->end(); ++itr) { - std::vector vectors_value = vectors_->Lookup(std::string(*itr), unk_init_, lower_case_backup_); - CHECK_FAIL_RETURN_UNEXPECTED(!vectors_value.empty(), "ToVectors: invalid data, token: \"" + std::string(*itr) + - "\" doesn't exist in vectors and no unk_init is specified."); - vectors_vec.insert(vectors_vec.end(), vectors_value.begin(), vectors_value.end()); - len++; - } - - int dim = static_cast(vectors_vec.size() / len); - if (vectors_vec.size() == dim) { - RETURN_IF_NOT_OK(Tensor::CreateFromVector(vectors_vec, output)); - } else { - RETURN_IF_NOT_OK(Tensor::CreateFromVector(vectors_vec, TensorShape({len, dim}), output)); - } - return Status::OK(); -} - -Status ToVectorsOp::OutputType(const std::vector &inputs, std::vector &outputs) { - CHECK_FAIL_RETURN_UNEXPECTED(inputs.size() == NumInput() && outputs.size() == NumOutput(), - "ToVectors: input and output size don't match."); - CHECK_FAIL_RETURN_UNEXPECTED(inputs[0] == DataType::DE_STRING, "ToVectors: input tensor type should be string."); - outputs[0] = DataType(DataType::DE_FLOAT32); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/text/kernels/to_vectors_op.h" + +namespace mindspore { +namespace dataset { +ToVectorsOp::ToVectorsOp(const std::shared_ptr &vectors, const std::vector &unk_init, + bool lower_case_backup) + : vectors_(vectors), unk_init_(unk_init), lower_case_backup_(lower_case_backup) {} + +Status ToVectorsOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "ToVectors: input tensor type should be string."); + CHECK_FAIL_RETURN_UNEXPECTED(unk_init_.size() == 0 || unk_init_.size() == vectors_->Dim(), + "ToVectors: unk_init must be the same length as vectors, but got unk_init: " + + std::to_string(unk_init_.size()) + " and vectors: " + std::to_string(vectors_->Dim())); + + std::vector vectors_vec; + int len = 0; + for (auto itr = input->begin(); itr != input->end(); ++itr) { + std::vector vectors_value = vectors_->Lookup(std::string(*itr), unk_init_, lower_case_backup_); + CHECK_FAIL_RETURN_UNEXPECTED(!vectors_value.empty(), "ToVectors: invalid data, token: \"" + std::string(*itr) + + "\" doesn't exist in vectors and no unk_init is specified."); + vectors_vec.insert(vectors_vec.end(), vectors_value.begin(), vectors_value.end()); + len++; + } + + int dim = static_cast(vectors_vec.size() / len); + if (vectors_vec.size() == dim) { + RETURN_IF_NOT_OK(Tensor::CreateFromVector(vectors_vec, output)); + } else { + RETURN_IF_NOT_OK(Tensor::CreateFromVector(vectors_vec, TensorShape({len, dim}), output)); + } + return Status::OK(); +} + +Status ToVectorsOp::OutputType(const std::vector &inputs, std::vector &outputs) { + CHECK_FAIL_RETURN_UNEXPECTED(inputs.size() == NumInput() && outputs.size() == NumOutput(), + "ToVectors: input and output size don't match."); + CHECK_FAIL_RETURN_UNEXPECTED(inputs[0] == DataType::DE_STRING, "ToVectors: input tensor type should be string."); + outputs[0] = DataType(DataType::DE_FLOAT32); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/to_vectors_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/to_vectors_op.h index 913b3a91bff..b910b395e3b 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/to_vectors_op.h +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/to_vectors_op.h @@ -1,64 +1,64 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_TO_VECTORS_OP_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_TO_VECTORS_OP_H_ - -#include -#include -#include -#include - -#include "minddata/dataset/core/tensor.h" -#include "minddata/dataset/kernels/tensor_op.h" -#include "minddata/dataset/text/vectors.h" -#include "minddata/dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class ToVectorsOp : public TensorOp { - public: - /// \brief Constructor. - /// \param[in] vectors Vectors used to lookup tokens. - /// \param[in] unk_init Vector used to initialize OOV token. - /// \param[in] lower_case_backup Whether to look up the token in the lower case. - ToVectorsOp(const std::shared_ptr &vectors, const std::vector &unk_init, bool lower_case_backup); - - /// \brief Destructor. - ~ToVectorsOp() = default; - - /// \brief Perform actual ToVectors on each tensor. - /// \param[in] input Input tensor. - /// \param[in] output Output tensor. - /// \return[out] Status code. - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - /// \param[in] inputs DataType of input tensor. - /// \param[in] outputs DataType of output tensor. - /// \return[out] Status code. - Status OutputType(const std::vector &inputs, std::vector &outputs) override; - - /// \brief Get Op name. - std::string Name() const override { return kToVectorsOp; } - - private: - std::shared_ptr vectors_; - std::vector unk_init_; - bool lower_case_backup_; -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_TO_VECTORS_OP_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_TO_VECTORS_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_TO_VECTORS_OP_H_ + +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/text/vectors.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class ToVectorsOp : public TensorOp { + public: + /// \brief Constructor. + /// \param[in] vectors Vectors used to lookup tokens. + /// \param[in] unk_init Vector used to initialize OOV token. + /// \param[in] lower_case_backup Whether to look up the token in the lower case. + ToVectorsOp(const std::shared_ptr &vectors, const std::vector &unk_init, bool lower_case_backup); + + /// \brief Destructor. + ~ToVectorsOp() = default; + + /// \brief Perform actual ToVectors on each tensor. + /// \param[in] input Input tensor. + /// \param[in] output Output tensor. + /// \return[out] Status code. + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + /// \param[in] inputs DataType of input tensor. + /// \param[in] outputs DataType of output tensor. + /// \return[out] Status code. + Status OutputType(const std::vector &inputs, std::vector &outputs) override; + + /// \brief Get Op name. + std::string Name() const override { return kToVectorsOp; } + + private: + std::shared_ptr vectors_; + std::vector unk_init_; + bool lower_case_backup_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_TO_VECTORS_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/wordpiece_tokenizer_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/wordpiece_tokenizer_op.h index 71d85ef2043..84ede593307 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/wordpiece_tokenizer_op.h +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/wordpiece_tokenizer_op.h @@ -1,69 +1,69 @@ -/** - * Copyright 2020-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_WORDPIECE_TOKENIZER_OP_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_WORDPIECE_TOKENIZER_OP_H_ -#include -#include -#include -#include - -#include "cppjieba/Unicode.hpp" - -#include "minddata/dataset/core/tensor.h" -#include "minddata/dataset/include/dataset/text.h" -#include "minddata/dataset/kernels/tensor_op.h" -#include "minddata/dataset/text/kernels/tokenizer_op.h" -#include "minddata/dataset/util/status.h" - -using cppjieba::DecodeRunesInString; -using cppjieba::RuneStrArray; -namespace mindspore { -namespace dataset { - -class WordpieceTokenizerOp : public TokenizerOp { - public: - static const char kDefSuffixIndicator[]; - static const int kDefMaxBytesPerToken; - static const char kDefUnknownToken[]; - WordpieceTokenizerOp(const std::shared_ptr &vocab, const std::string &suffix_indicator = kDefSuffixIndicator, - const int &max_bytes_per_token = kDefMaxBytesPerToken, - const std::string &unknown_token = kDefUnknownToken, const bool &with_offsets = kDefWithOffsets); - - ~WordpieceTokenizerOp() override = default; - - Status Compute(const TensorRow &input, TensorRow *output) override; - - protected: - Status AddSubword(const std::string &input_token, const int &start, const int &end, - std::vector *out_tokens) const; - Status FoundNoToken(const std::string &input_token, const uint32_t &basic_start, std::vector *out_tokens, - std::vector *offsets_start, std::vector *offsets_limit) const; - Status LookupWord(const std::string &input_token, const RuneStrArray &runes, const int start, bool *out_found, - int *out_end) const; - Status GetTokens(const std::string &input_token, const uint32_t &basic_start, std::vector *out_tokens, - std::vector *offsets_start, std::vector *offsets_limit) const; - - std::string Name() const override { return kWordpieceTokenizerOp; } - - private: - const std::shared_ptr vocab_; - const std::string suffix_indicator_; - const int max_bytes_per_token_; - const std::string unknown_token_; -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_WORDPIECE_TOKENIZER_OP_H_ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_WORDPIECE_TOKENIZER_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_WORDPIECE_TOKENIZER_OP_H_ +#include +#include +#include +#include + +#include "cppjieba/Unicode.hpp" + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/include/dataset/text.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/text/kernels/tokenizer_op.h" +#include "minddata/dataset/util/status.h" + +using cppjieba::DecodeRunesInString; +using cppjieba::RuneStrArray; +namespace mindspore { +namespace dataset { + +class WordpieceTokenizerOp : public TokenizerOp { + public: + static const char kDefSuffixIndicator[]; + static const int kDefMaxBytesPerToken; + static const char kDefUnknownToken[]; + WordpieceTokenizerOp(const std::shared_ptr &vocab, const std::string &suffix_indicator = kDefSuffixIndicator, + const int &max_bytes_per_token = kDefMaxBytesPerToken, + const std::string &unknown_token = kDefUnknownToken, const bool &with_offsets = kDefWithOffsets); + + ~WordpieceTokenizerOp() override = default; + + Status Compute(const TensorRow &input, TensorRow *output) override; + + protected: + Status AddSubword(const std::string &input_token, const int &start, const int &end, + std::vector *out_tokens) const; + Status FoundNoToken(const std::string &input_token, const uint32_t &basic_start, std::vector *out_tokens, + std::vector *offsets_start, std::vector *offsets_limit) const; + Status LookupWord(const std::string &input_token, const RuneStrArray &runes, const int start, bool *out_found, + int *out_end) const; + Status GetTokens(const std::string &input_token, const uint32_t &basic_start, std::vector *out_tokens, + std::vector *offsets_start, std::vector *offsets_limit) const; + + std::string Name() const override { return kWordpieceTokenizerOp; } + + private: + const std::shared_ptr vocab_; + const std::string suffix_indicator_; + const int max_bytes_per_token_; + const std::string unknown_token_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_WORDPIECE_TOKENIZER_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/vectors.cc b/mindspore/ccsrc/minddata/dataset/text/vectors.cc index 7841e7f3e59..f38aafadd42 100644 --- a/mindspore/ccsrc/minddata/dataset/text/vectors.cc +++ b/mindspore/ccsrc/minddata/dataset/text/vectors.cc @@ -1,153 +1,153 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/text/vectors.h" - -#include "utils/file_utils.h" - -namespace mindspore { -namespace dataset { -Status Vectors::InferShape(const std::string &path, int32_t max_vectors, int32_t *num_lines, int32_t *header_num_lines, - int32_t *vector_dim) { - RETURN_UNEXPECTED_IF_NULL(num_lines); - RETURN_UNEXPECTED_IF_NULL(header_num_lines); - RETURN_UNEXPECTED_IF_NULL(vector_dim); - - std::ifstream file_reader; - file_reader.open(path, std::ios::in); - CHECK_FAIL_RETURN_UNEXPECTED(file_reader.is_open(), "Vectors: invalid file, failed to open vector file: " + path); - - *num_lines = 0, *header_num_lines = 0, *vector_dim = -1; - std::string line, row; - while (std::getline(file_reader, line)) { - if (*vector_dim == -1) { - std::vector vec; - std::istringstream line_reader(line); - while (std::getline(line_reader, row, ' ')) { - vec.push_back(row); - } - // The number of rows and dimensions can be obtained directly from the information header. - const int kInfoHeaderSize = 2; - if (vec.size() == kInfoHeaderSize) { - (*header_num_lines)++; - } else { - *vector_dim = vec.size() - 1; - (*num_lines)++; - } - } else { - (*num_lines)++; - } - } - file_reader.close(); - CHECK_FAIL_RETURN_UNEXPECTED(*num_lines > 0, "Vectors: invalid file, file is empty."); - - if (max_vectors > 0) { - *num_lines = std::min(max_vectors, *num_lines); // Determine the true rows. - } - return Status::OK(); -} - -Status Vectors::Load(const std::string &path, int32_t max_vectors, - std::unordered_map> *map, int32_t *vector_dim) { - RETURN_UNEXPECTED_IF_NULL(map); - RETURN_UNEXPECTED_IF_NULL(vector_dim); - auto realpath = FileUtils::GetRealPath(common::SafeCStr(path)); - CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Vectors: get real path failed, path: " + path); - auto file_path = realpath.value(); - - CHECK_FAIL_RETURN_UNEXPECTED(max_vectors >= 0, - "Vectors: max_vectors must be non negative, but got: " + std::to_string(max_vectors)); - - int num_lines = 0, header_num_lines = 0; - RETURN_IF_NOT_OK(InferShape(file_path, max_vectors, &num_lines, &header_num_lines, vector_dim)); - - std::fstream file_reader; - file_reader.open(file_path, std::ios::in); - CHECK_FAIL_RETURN_UNEXPECTED(file_reader.is_open(), - "Vectors: invalid file, failed to open vector file: " + file_path); - - while (header_num_lines > 0) { - file_reader.ignore(std::numeric_limits::max(), '\n'); - header_num_lines--; - } - - std::string line, token, vector_value; - for (auto i = 0; i < num_lines; ++i) { - std::getline(file_reader, line); - std::istringstream line_reader(line); - std::getline(line_reader, token, ' '); - std::vector vector_values; - int dim = 0; - while (line_reader >> vector_value) { - dim++; - vector_values.push_back(atof(vector_value.c_str())); - } - if (dim <= 1) { - file_reader.close(); - RETURN_STATUS_UNEXPECTED("Vectors: token with 1-dimensional vector."); - } - if (dim != *vector_dim) { - file_reader.close(); - RETURN_STATUS_UNEXPECTED("Vectors: all vectors must have the same number of dimensions, but got dim " + - std::to_string(dim) + " while expecting " + std::to_string(*vector_dim)); - } - - auto token_index = map->find(token); - if (token_index == map->end()) { - (*map)[token] = vector_values; - } - } - file_reader.close(); - return Status::OK(); -} - -Vectors::Vectors(const std::unordered_map> &map, int32_t dim) { - map_ = map; - dim_ = dim; -} - -Status Vectors::BuildFromFile(std::shared_ptr *vectors, const std::string &path, int32_t max_vectors) { - RETURN_UNEXPECTED_IF_NULL(vectors); - std::unordered_map> map; - int vector_dim = -1; - RETURN_IF_NOT_OK(Load(path, max_vectors, &map, &vector_dim)); - *vectors = std::make_shared(std::move(map), vector_dim); - return Status::OK(); -} - -std::vector Vectors::Lookup(const std::string &token, const std::vector &unk_init, - bool lower_case_backup) { - std::vector init_vec(dim_, 0); - if (!unk_init.empty()) { - if (unk_init.size() != dim_) { - MS_LOG(WARNING) << "Vectors: size of unk_init is not the same as vectors, will initialize with zero vectors."; - } else { - init_vec = unk_init; - } - } - std::string lower_token = token; - if (lower_case_backup) { - transform(lower_token.begin(), lower_token.end(), lower_token.begin(), ::tolower); - } - auto str_index = map_.find(lower_token); - if (str_index == map_.end()) { - return init_vec; - } else { - return str_index->second; - } -} -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/text/vectors.h" + +#include "utils/file_utils.h" + +namespace mindspore { +namespace dataset { +Status Vectors::InferShape(const std::string &path, int32_t max_vectors, int32_t *num_lines, int32_t *header_num_lines, + int32_t *vector_dim) { + RETURN_UNEXPECTED_IF_NULL(num_lines); + RETURN_UNEXPECTED_IF_NULL(header_num_lines); + RETURN_UNEXPECTED_IF_NULL(vector_dim); + + std::ifstream file_reader; + file_reader.open(path, std::ios::in); + CHECK_FAIL_RETURN_UNEXPECTED(file_reader.is_open(), "Vectors: invalid file, failed to open vector file: " + path); + + *num_lines = 0, *header_num_lines = 0, *vector_dim = -1; + std::string line, row; + while (std::getline(file_reader, line)) { + if (*vector_dim == -1) { + std::vector vec; + std::istringstream line_reader(line); + while (std::getline(line_reader, row, ' ')) { + vec.push_back(row); + } + // The number of rows and dimensions can be obtained directly from the information header. + const int kInfoHeaderSize = 2; + if (vec.size() == kInfoHeaderSize) { + (*header_num_lines)++; + } else { + *vector_dim = vec.size() - 1; + (*num_lines)++; + } + } else { + (*num_lines)++; + } + } + file_reader.close(); + CHECK_FAIL_RETURN_UNEXPECTED(*num_lines > 0, "Vectors: invalid file, file is empty."); + + if (max_vectors > 0) { + *num_lines = std::min(max_vectors, *num_lines); // Determine the true rows. + } + return Status::OK(); +} + +Status Vectors::Load(const std::string &path, int32_t max_vectors, + std::unordered_map> *map, int32_t *vector_dim) { + RETURN_UNEXPECTED_IF_NULL(map); + RETURN_UNEXPECTED_IF_NULL(vector_dim); + auto realpath = FileUtils::GetRealPath(common::SafeCStr(path)); + CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Vectors: get real path failed, path: " + path); + auto file_path = realpath.value(); + + CHECK_FAIL_RETURN_UNEXPECTED(max_vectors >= 0, + "Vectors: max_vectors must be non negative, but got: " + std::to_string(max_vectors)); + + int num_lines = 0, header_num_lines = 0; + RETURN_IF_NOT_OK(InferShape(file_path, max_vectors, &num_lines, &header_num_lines, vector_dim)); + + std::fstream file_reader; + file_reader.open(file_path, std::ios::in); + CHECK_FAIL_RETURN_UNEXPECTED(file_reader.is_open(), + "Vectors: invalid file, failed to open vector file: " + file_path); + + while (header_num_lines > 0) { + file_reader.ignore(std::numeric_limits::max(), '\n'); + header_num_lines--; + } + + std::string line, token, vector_value; + for (auto i = 0; i < num_lines; ++i) { + std::getline(file_reader, line); + std::istringstream line_reader(line); + std::getline(line_reader, token, ' '); + std::vector vector_values; + int dim = 0; + while (line_reader >> vector_value) { + dim++; + vector_values.push_back(atof(vector_value.c_str())); + } + if (dim <= 1) { + file_reader.close(); + RETURN_STATUS_UNEXPECTED("Vectors: token with 1-dimensional vector."); + } + if (dim != *vector_dim) { + file_reader.close(); + RETURN_STATUS_UNEXPECTED("Vectors: all vectors must have the same number of dimensions, but got dim " + + std::to_string(dim) + " while expecting " + std::to_string(*vector_dim)); + } + + auto token_index = map->find(token); + if (token_index == map->end()) { + (*map)[token] = vector_values; + } + } + file_reader.close(); + return Status::OK(); +} + +Vectors::Vectors(const std::unordered_map> &map, int32_t dim) { + map_ = map; + dim_ = dim; +} + +Status Vectors::BuildFromFile(std::shared_ptr *vectors, const std::string &path, int32_t max_vectors) { + RETURN_UNEXPECTED_IF_NULL(vectors); + std::unordered_map> map; + int vector_dim = -1; + RETURN_IF_NOT_OK(Load(path, max_vectors, &map, &vector_dim)); + *vectors = std::make_shared(std::move(map), vector_dim); + return Status::OK(); +} + +std::vector Vectors::Lookup(const std::string &token, const std::vector &unk_init, + bool lower_case_backup) { + std::vector init_vec(dim_, 0); + if (!unk_init.empty()) { + if (unk_init.size() != dim_) { + MS_LOG(WARNING) << "Vectors: size of unk_init is not the same as vectors, will initialize with zero vectors."; + } else { + init_vec = unk_init; + } + } + std::string lower_token = token; + if (lower_case_backup) { + transform(lower_token.begin(), lower_token.end(), lower_token.begin(), ::tolower); + } + auto str_index = map_.find(lower_token); + if (str_index == map_.end()) { + return init_vec; + } else { + return str_index->second; + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/vectors.h b/mindspore/ccsrc/minddata/dataset/text/vectors.h index 3516b8c5872..dd82bca916f 100644 --- a/mindspore/ccsrc/minddata/dataset/text/vectors.h +++ b/mindspore/ccsrc/minddata/dataset/text/vectors.h @@ -1,88 +1,88 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_VECTORS_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_VECTORS_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "minddata/dataset/core/tensor.h" -#include "minddata/dataset/include/dataset/iterator.h" - -namespace mindspore { -namespace dataset { -/// \brief Pre-train word vectors. -class Vectors { - public: - /// Constructor. - Vectors() = default; - - /// Constructor. - /// \param[in] map A map between string and vector. - /// \param[in] dim Dimension of the vectors. - Vectors(const std::unordered_map> &map, int32_t dim); - - /// Destructor. - virtual ~Vectors() = default; - - /// \brief Build Vectors from reading a pre-train vector file. - /// \param[out] vectors Vectors object which contains the pre-train vectors. - /// \param[in] path Path to the pre-trained word vector file. - /// \param[in] max_vectors This can be used to limit the number of pre-trained vectors loaded (default=0, no limit). - static Status BuildFromFile(std::shared_ptr *vectors, const std::string &path, int32_t max_vectors = 0); - - /// \brief Look up embedding vectors of token. - /// \param[in] token A token to be looked up. - /// \param[in] unk_init In case of the token is out-of-vectors (OOV), the result will be initialized with `unk_init`. - /// (default={}, means to initialize with zero vectors). - /// \param[in] lower_case_backup Whether to look up the token in the lower case (Default = false). - /// \return The vector of the input token. - virtual std::vector Lookup(const std::string &token, const std::vector &unk_init = {}, - bool lower_case_backup = false); - - /// \brief Getter of dimension. - const int32_t &Dim() const { return dim_; } - - protected: - /// \brief Infer the shape of the pre-trained word vector file. - /// \param[in] path Path to the pre-trained word vector file. - /// \param[in] max_vectors Maximum number of pre-trained word vectors to be read. - /// \param[out] num_lines The number of lines of the file. - /// \param[out] header_num_lines The number of lines of file header. - /// \param[out] vector_dim The dimension of the vectors in the file. - static Status InferShape(const std::string &path, int32_t max_vectors, int32_t *num_lines, int32_t *header_num_lines, - int32_t *vector_dim); - - /// \brief Load map from reading a pre-train vector file. - /// \param[in] path Path to the pre-trained word vector file. - /// \param[in] max_vectors This can be used to limit the number of pre-trained vectors loaded, must be non negative. - /// \param[out] map The map between words and vectors. - /// \param[out] vector_dim The dimension of the vectors in the file. - static Status Load(const std::string &path, int32_t max_vectors, - std::unordered_map> *map, int32_t *vector_dim); - - int32_t dim_; - std::unordered_map> map_; -}; -} // namespace dataset -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_VECTORS_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_VECTORS_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_VECTORS_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/include/dataset/iterator.h" + +namespace mindspore { +namespace dataset { +/// \brief Pre-train word vectors. +class Vectors { + public: + /// Constructor. + Vectors() = default; + + /// Constructor. + /// \param[in] map A map between string and vector. + /// \param[in] dim Dimension of the vectors. + Vectors(const std::unordered_map> &map, int32_t dim); + + /// Destructor. + virtual ~Vectors() = default; + + /// \brief Build Vectors from reading a pre-train vector file. + /// \param[out] vectors Vectors object which contains the pre-train vectors. + /// \param[in] path Path to the pre-trained word vector file. + /// \param[in] max_vectors This can be used to limit the number of pre-trained vectors loaded (default=0, no limit). + static Status BuildFromFile(std::shared_ptr *vectors, const std::string &path, int32_t max_vectors = 0); + + /// \brief Look up embedding vectors of token. + /// \param[in] token A token to be looked up. + /// \param[in] unk_init In case of the token is out-of-vectors (OOV), the result will be initialized with `unk_init`. + /// (default={}, means to initialize with zero vectors). + /// \param[in] lower_case_backup Whether to look up the token in the lower case (Default = false). + /// \return The vector of the input token. + virtual std::vector Lookup(const std::string &token, const std::vector &unk_init = {}, + bool lower_case_backup = false); + + /// \brief Getter of dimension. + const int32_t &Dim() const { return dim_; } + + protected: + /// \brief Infer the shape of the pre-trained word vector file. + /// \param[in] path Path to the pre-trained word vector file. + /// \param[in] max_vectors Maximum number of pre-trained word vectors to be read. + /// \param[out] num_lines The number of lines of the file. + /// \param[out] header_num_lines The number of lines of file header. + /// \param[out] vector_dim The dimension of the vectors in the file. + static Status InferShape(const std::string &path, int32_t max_vectors, int32_t *num_lines, int32_t *header_num_lines, + int32_t *vector_dim); + + /// \brief Load map from reading a pre-train vector file. + /// \param[in] path Path to the pre-trained word vector file. + /// \param[in] max_vectors This can be used to limit the number of pre-trained vectors loaded, must be non negative. + /// \param[out] map The map between words and vectors. + /// \param[out] vector_dim The dimension of the vectors in the file. + static Status Load(const std::string &path, int32_t max_vectors, + std::unordered_map> *map, int32_t *vector_dim); + + int32_t dim_; + std::unordered_map> map_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_VECTORS_H_ diff --git a/mindspore/ccsrc/pipeline/jit/pi/graph_guard/info.cc b/mindspore/ccsrc/pipeline/jit/pi/graph_guard/info.cc index 7cdb170444b..312abb21392 100644 --- a/mindspore/ccsrc/pipeline/jit/pi/graph_guard/info.cc +++ b/mindspore/ccsrc/pipeline/jit/pi/graph_guard/info.cc @@ -1,400 +1,400 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pipeline/jit/pi/graph_guard/info.h" -#include -#include -#include -#include -#include -#include "pipeline/jit/pi/utils/utils.h" - -namespace mindspore { -namespace pijit { - -static constexpr char kSepFlag = '\\'; -static constexpr char kBeginFlag = '{'; -static constexpr char kEndFlag = '}'; -static constexpr char kArrayBeginFlag = '['; -static constexpr char kArrayEndFlag = ']'; -static constexpr size_t kInitLimit = 1024; - -template -size_t StoreScalar(uint8_t *buf, size_t ptr, T val) { - uint8_t *pVal = reinterpret_cast(&val); - constexpr int kScalarSize = sizeof(T); - for (int idx = 0; idx < kScalarSize; ++idx) { - buf[ptr++] = pVal[idx]; - } - return ptr; -} - -template -size_t AppendScalar(uint8_t *buf, size_t ptr, T v) { - ptr = StoreScalar(buf, ptr, v); - ptr = StoreScalar(buf, ptr, kSepFlag); - return ptr; -} - -template -size_t StoreVector(uint8_t *buf, size_t ptr, const std::vector &val) { - T *pVal = const_cast(val.data()); - size_t szVal = val.size() * sizeof(T); - ptr = StoreScalar(buf, ptr, szVal); - memcpy_s(buf + ptr, szVal, reinterpret_cast(pVal), szVal); - return ptr + szVal; -} - -template -size_t AppendVector(uint8_t *buf, size_t ptr, const std::vector &v) { - ptr = StoreScalar(buf, ptr, kArrayBeginFlag); - ptr = StoreVector(buf, ptr, v); - ptr = StoreScalar(buf, ptr, kArrayEndFlag); - ptr = StoreScalar(buf, ptr, kSepFlag); - return ptr; -} - -InfoPack::InfoPack() : id_(kInvalidId), buf_(std::make_unique(kInitLimit)), ptr_(0), limit_(kInitLimit) {} - -InfoPack::InfoPack(const InfoPack &dup) - : id_(dup.id_), buf_(std::make_unique(dup.ptr_)), ptr_(dup.ptr_), limit_(dup.ptr_) { - memcpy_s(buf_.get(), dup.ptr_, dup.buf_.get(), dup.ptr_); -} - -InfoPack::~InfoPack() { buf_.reset(nullptr); } - -size_t InfoPack::Id() const { return id_; } - -uint8_t *InfoPack::Buf(size_t *sz) const { - if (sz != nullptr) { - *sz = ptr_; - return buf_.get(); - } - return nullptr; -} - -void InfoPack::Update() { id_ = CalcBuffer(buf_.get(), ptr_); } - -InfoPack &InfoPack::Begin() { - AllocIfNeed(sizeof(kBeginFlag)); - *(buf_.get() + ptr_++) = (uint8_t)kBeginFlag; - return *this; -} - -InfoPack &InfoPack::End() { - if (buf_.get()[ptr_ - 1] == kSepFlag) { - buf_.get()[ptr_ - 1] = kEndFlag; - } else { - AllocIfNeed(sizeof(kEndFlag)); - *(buf_.get() + ptr_++) = (uint8_t)kEndFlag; - } - return *this; -} - -InfoPack &InfoPack::operator<<(int8_t v) { - AllocIfNeed(sizeof(v) + sizeof(kSepFlag)); - ptr_ = AppendScalar(buf_.get(), ptr_, v); - return *this; -} - -InfoPack &InfoPack::operator<<(uint8_t v) { - AllocIfNeed(sizeof(v) + sizeof(kSepFlag)); - ptr_ = AppendScalar(buf_.get(), ptr_, v); - return *this; -} - -InfoPack &InfoPack::operator<<(int16_t v) { - AllocIfNeed(sizeof(v) + sizeof(kSepFlag)); - ptr_ = AppendScalar(buf_.get(), ptr_, v); - return *this; -} - -InfoPack &InfoPack::operator<<(uint16_t v) { - AllocIfNeed(sizeof(v) + sizeof(kSepFlag)); - ptr_ = AppendScalar(buf_.get(), ptr_, v); - return *this; -} - -InfoPack &InfoPack::operator<<(int32_t v) { - AllocIfNeed(sizeof(v) + sizeof(kSepFlag)); - ptr_ = AppendScalar(buf_.get(), ptr_, v); - return *this; -} - -InfoPack &InfoPack::operator<<(uint32_t v) { - AllocIfNeed(sizeof(v) + sizeof(kSepFlag)); - ptr_ = AppendScalar(buf_.get(), ptr_, v); - return *this; -} - -InfoPack &InfoPack::operator<<(int64_t v) { - AllocIfNeed(sizeof(v) + sizeof(kSepFlag)); - ptr_ = AppendScalar(buf_.get(), ptr_, v); - return *this; -} - -InfoPack &InfoPack::operator<<(uint64_t v) { - AllocIfNeed(sizeof(v) + sizeof(kSepFlag)); - ptr_ = AppendScalar(buf_.get(), ptr_, v); - return *this; -} - -InfoPack &InfoPack::operator<<(float v) { - AllocIfNeed(sizeof(v) + sizeof(kSepFlag)); - ptr_ = AppendScalar(buf_.get(), ptr_, v); - return *this; -} - -InfoPack &InfoPack::operator<<(double v) { - AllocIfNeed(sizeof(v) + sizeof(kSepFlag)); - ptr_ = AppendScalar(buf_.get(), ptr_, v); - return *this; -} - -InfoPack &InfoPack::operator<<(bool vv) { - uint8_t v = vv ? 1 : 0; - AllocIfNeed(sizeof(v) + sizeof(kSepFlag)); - ptr_ = AppendScalar(buf_.get(), ptr_, v); - return *this; -} - -InfoPack &InfoPack::operator<<(void *v) { - AllocIfNeed(sizeof(v) + sizeof(kSepFlag)); - ptr_ = AppendScalar(buf_.get(), ptr_, v); - return *this; -} - -InfoPack &InfoPack::operator<<(PyObject *vv) { - uint8_t v = vv != nullptr ? 1 : 0; - if (vv != nullptr) { - size_t w = CalcString(std::string(py::str(vv))); - AllocIfNeed(sizeof(v) + sizeof(w) + sizeof(kSepFlag)); - ptr_ = StoreScalar(buf_.get(), ptr_, v); - ptr_ = AppendScalar(buf_.get(), ptr_, w); - } else { - AllocIfNeed(sizeof(v) + sizeof(kSepFlag)); - ptr_ = AppendScalar(buf_.get(), ptr_, v); - } - return *this; -} - -InfoPack &InfoPack::operator<<(mindspore::BasePtr vv) { - uint8_t v = vv != nullptr ? 1 : 0; - if (vv != nullptr) { - size_t w = CalcString(vv->ToString()); - AllocIfNeed(sizeof(v) + sizeof(w) + sizeof(kSepFlag)); - ptr_ = StoreScalar(buf_.get(), ptr_, v); - ptr_ = AppendScalar(buf_.get(), ptr_, w); - } else { - AllocIfNeed(sizeof(v) + sizeof(kSepFlag)); - ptr_ = AppendScalar(buf_.get(), ptr_, v); - } - return *this; -} - -InfoPack &InfoPack::operator<<(const std::string &vv) { - size_t v = CalcString(vv); - AllocIfNeed(sizeof(v) + sizeof(kSepFlag)); - ptr_ = AppendScalar(buf_.get(), ptr_, v); - return *this; -} - -InfoPack &InfoPack::operator<<(const std::vector &v) { - AllocIfNeed(sizeof(kArrayBeginFlag) + sizeof(kArrayEndFlag) + sizeof(kSepFlag) + sizeof(v[0]) * v.size() + - sizeof(size_t)); - ptr_ = AppendVector(buf_.get(), ptr_, v); - return *this; -} - -InfoPack &InfoPack::operator<<(const std::vector &v) { - AllocIfNeed(sizeof(kArrayBeginFlag) + sizeof(kArrayEndFlag) + sizeof(kSepFlag) + sizeof(v[0]) * v.size() + - sizeof(size_t)); - ptr_ = AppendVector(buf_.get(), ptr_, v); - return *this; -} - -InfoPack &InfoPack::operator<<(const std::vector &v) { - AllocIfNeed(sizeof(kArrayBeginFlag) + sizeof(kArrayEndFlag) + sizeof(kSepFlag) + sizeof(v[0]) * v.size() + - sizeof(size_t)); - ptr_ = AppendVector(buf_.get(), ptr_, v); - return *this; -} - -InfoPack &InfoPack::operator<<(const std::vector &v) { - AllocIfNeed(sizeof(kArrayBeginFlag) + sizeof(kArrayEndFlag) + sizeof(kSepFlag) + sizeof(v[0]) * v.size() + - sizeof(size_t)); - ptr_ = AppendVector(buf_.get(), ptr_, v); - return *this; -} - -InfoPack &InfoPack::operator<<(const std::vector &v) { - AllocIfNeed(sizeof(kArrayBeginFlag) + sizeof(kArrayEndFlag) + sizeof(kSepFlag) + sizeof(v[0]) * v.size() + - sizeof(size_t)); - ptr_ = AppendVector(buf_.get(), ptr_, v); - return *this; -} - -InfoPack &InfoPack::operator<<(const std::vector &v) { - AllocIfNeed(sizeof(kArrayBeginFlag) + sizeof(kArrayEndFlag) + sizeof(kSepFlag) + sizeof(v[0]) * v.size() + - sizeof(size_t)); - ptr_ = AppendVector(buf_.get(), ptr_, v); - return *this; -} - -InfoPack &InfoPack::operator<<(const std::vector &v) { - AllocIfNeed(sizeof(kArrayBeginFlag) + sizeof(kArrayEndFlag) + sizeof(kSepFlag) + sizeof(v[0]) * v.size() + - sizeof(size_t)); - ptr_ = AppendVector(buf_.get(), ptr_, v); - return *this; -} - -InfoPack &InfoPack::operator<<(const std::vector &v) { - AllocIfNeed(sizeof(kArrayBeginFlag) + sizeof(kArrayEndFlag) + sizeof(kSepFlag) + sizeof(v[0]) * v.size() + - sizeof(size_t)); - ptr_ = AppendVector(buf_.get(), ptr_, v); - return *this; -} - -InfoPack &InfoPack::operator<<(const std::vector &v) { - AllocIfNeed(sizeof(kArrayBeginFlag) + sizeof(kArrayEndFlag) + sizeof(kSepFlag) + sizeof(v[0]) * v.size() + - sizeof(size_t)); - ptr_ = AppendVector(buf_.get(), ptr_, v); - return *this; -} - -InfoPack &InfoPack::operator<<(const std::vector &v) { - AllocIfNeed(sizeof(kArrayBeginFlag) + sizeof(kArrayEndFlag) + sizeof(kSepFlag) + sizeof(v[0]) * v.size() + - sizeof(size_t)); - ptr_ = AppendVector(buf_.get(), ptr_, v); - return *this; -} - -InfoPack &InfoPack::operator<<(const std::vector &vv) { - std::vector v; - std::transform(vv.begin(), vv.end(), std::back_inserter(v), [](const auto &item) { return item ? 1 : 0; }); - AllocIfNeed(sizeof(kArrayBeginFlag) + sizeof(kArrayEndFlag) + sizeof(kSepFlag) + sizeof(v[0]) * v.size() + - sizeof(size_t)); - ptr_ = AppendVector(buf_.get(), ptr_, v); - return *this; -} - -InfoPack &InfoPack::operator<<(const std::vector &vv) { - std::vector v; - std::transform(vv.begin(), vv.end(), std::back_inserter(v), [this](const auto &item) { return CalcString(item); }); - AllocIfNeed(sizeof(kArrayBeginFlag) + sizeof(kArrayEndFlag) + sizeof(kSepFlag) + sizeof(v[0]) * v.size() + - sizeof(size_t)); - ptr_ = AppendVector(buf_.get(), ptr_, v); - return *this; -} - -InfoPack &InfoPack::operator<<(const std::vector &v) { - AllocIfNeed(sizeof(kArrayBeginFlag) + sizeof(kArrayEndFlag) + sizeof(kSepFlag) + sizeof(v[0]) * v.size() + - sizeof(size_t)); - ptr_ = AppendVector(buf_.get(), ptr_, v); - return *this; -} - -InfoPack &InfoPack::operator<<(const std::vector &vv) { - std::vector v; - std::transform(vv.begin(), vv.end(), std::back_inserter(v), - [this](const auto &item) { return CalcString(std::string(py::str(item))); }); - AllocIfNeed(sizeof(kArrayBeginFlag) + sizeof(kArrayEndFlag) + sizeof(kSepFlag) + sizeof(v[0]) * v.size() + - sizeof(size_t)); - ptr_ = AppendVector(buf_.get(), ptr_, v); - return *this; -} - -InfoPack &InfoPack::operator<<(const InfoPack &v) { - size_t id = v.Id(); - AllocIfNeed(sizeof(id) + sizeof(kSepFlag)); - ptr_ = AppendScalar(buf_.get(), ptr_, id); - return *this; -} - -class String2Id { - public: - String2Id() = default; - ~String2Id() = default; - size_t Insert(std::string key) { - if (map_.find(key) == map_.end()) { - map_[key] = map_.size(); - return map_.size() - 1; - } else { - return map_[key]; - } - } - - protected: - std::map map_; -}; - -static String2Id g_StrMap; - -size_t InfoPack::CalcString(std::string v) { return g_StrMap.Insert(v); } - -struct BufferHash { - bool operator()(const std::vector &lhs, const std::vector &rhs) const { - if (lhs.size() == rhs.size()) { - return memcmp(lhs.data(), rhs.data(), lhs.size()) == 0; - } - return false; - } - size_t operator()(const std::vector &k) const { - size_t ret = 0; - ret = std::accumulate(k.begin(), k.end(), ret, [](size_t key, uint8_t v) { - static constexpr int kShiftKey = 3; - return (key << kShiftKey) + v; - }); - return ret; - } -}; - -class Buffer2Id { - public: - Buffer2Id() = default; - ~Buffer2Id() = default; - size_t Insert(uint8_t *buf, size_t sz) { - std::vector vec(sz); - memcpy_s(vec.data(), sz, buf, sz); - auto it = map_.find(vec); - if (it == map_.end()) { - size_t ret = map_.size(); - map_[vec] = ret; - return ret; - } else { - return it->second; - } - } - - protected: - std::unordered_map, size_t, BufferHash, BufferHash> map_; -}; - -static Buffer2Id g_BufMap; - -size_t InfoPack::CalcBuffer(uint8_t *buf, size_t sz) { return g_BufMap.Insert(buf, sz); } - -void InfoPack::AllocIfNeed(size_t need) { - if (limit_ < need + ptr_) { - do { - limit_ += kInitLimit; - } while (limit_ < need + ptr_); - auto buf = std::make_unique(limit_); - memcpy_s(buf.get(), limit_, buf_.get(), ptr_ * sizeof(uint8_t)); - buf_.reset(buf.release()); - } -} -} // namespace pijit -} // namespace mindspore +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "pipeline/jit/pi/graph_guard/info.h" +#include +#include +#include +#include +#include +#include "pipeline/jit/pi/utils/utils.h" + +namespace mindspore { +namespace pijit { + +static constexpr char kSepFlag = '\\'; +static constexpr char kBeginFlag = '{'; +static constexpr char kEndFlag = '}'; +static constexpr char kArrayBeginFlag = '['; +static constexpr char kArrayEndFlag = ']'; +static constexpr size_t kInitLimit = 1024; + +template +size_t StoreScalar(uint8_t *buf, size_t ptr, T val) { + uint8_t *pVal = reinterpret_cast(&val); + constexpr int kScalarSize = sizeof(T); + for (int idx = 0; idx < kScalarSize; ++idx) { + buf[ptr++] = pVal[idx]; + } + return ptr; +} + +template +size_t AppendScalar(uint8_t *buf, size_t ptr, T v) { + ptr = StoreScalar(buf, ptr, v); + ptr = StoreScalar(buf, ptr, kSepFlag); + return ptr; +} + +template +size_t StoreVector(uint8_t *buf, size_t ptr, const std::vector &val) { + T *pVal = const_cast(val.data()); + size_t szVal = val.size() * sizeof(T); + ptr = StoreScalar(buf, ptr, szVal); + memcpy_s(buf + ptr, szVal, reinterpret_cast(pVal), szVal); + return ptr + szVal; +} + +template +size_t AppendVector(uint8_t *buf, size_t ptr, const std::vector &v) { + ptr = StoreScalar(buf, ptr, kArrayBeginFlag); + ptr = StoreVector(buf, ptr, v); + ptr = StoreScalar(buf, ptr, kArrayEndFlag); + ptr = StoreScalar(buf, ptr, kSepFlag); + return ptr; +} + +InfoPack::InfoPack() : id_(kInvalidId), buf_(std::make_unique(kInitLimit)), ptr_(0), limit_(kInitLimit) {} + +InfoPack::InfoPack(const InfoPack &dup) + : id_(dup.id_), buf_(std::make_unique(dup.ptr_)), ptr_(dup.ptr_), limit_(dup.ptr_) { + memcpy_s(buf_.get(), dup.ptr_, dup.buf_.get(), dup.ptr_); +} + +InfoPack::~InfoPack() { buf_.reset(nullptr); } + +size_t InfoPack::Id() const { return id_; } + +uint8_t *InfoPack::Buf(size_t *sz) const { + if (sz != nullptr) { + *sz = ptr_; + return buf_.get(); + } + return nullptr; +} + +void InfoPack::Update() { id_ = CalcBuffer(buf_.get(), ptr_); } + +InfoPack &InfoPack::Begin() { + AllocIfNeed(sizeof(kBeginFlag)); + *(buf_.get() + ptr_++) = (uint8_t)kBeginFlag; + return *this; +} + +InfoPack &InfoPack::End() { + if (buf_.get()[ptr_ - 1] == kSepFlag) { + buf_.get()[ptr_ - 1] = kEndFlag; + } else { + AllocIfNeed(sizeof(kEndFlag)); + *(buf_.get() + ptr_++) = (uint8_t)kEndFlag; + } + return *this; +} + +InfoPack &InfoPack::operator<<(int8_t v) { + AllocIfNeed(sizeof(v) + sizeof(kSepFlag)); + ptr_ = AppendScalar(buf_.get(), ptr_, v); + return *this; +} + +InfoPack &InfoPack::operator<<(uint8_t v) { + AllocIfNeed(sizeof(v) + sizeof(kSepFlag)); + ptr_ = AppendScalar(buf_.get(), ptr_, v); + return *this; +} + +InfoPack &InfoPack::operator<<(int16_t v) { + AllocIfNeed(sizeof(v) + sizeof(kSepFlag)); + ptr_ = AppendScalar(buf_.get(), ptr_, v); + return *this; +} + +InfoPack &InfoPack::operator<<(uint16_t v) { + AllocIfNeed(sizeof(v) + sizeof(kSepFlag)); + ptr_ = AppendScalar(buf_.get(), ptr_, v); + return *this; +} + +InfoPack &InfoPack::operator<<(int32_t v) { + AllocIfNeed(sizeof(v) + sizeof(kSepFlag)); + ptr_ = AppendScalar(buf_.get(), ptr_, v); + return *this; +} + +InfoPack &InfoPack::operator<<(uint32_t v) { + AllocIfNeed(sizeof(v) + sizeof(kSepFlag)); + ptr_ = AppendScalar(buf_.get(), ptr_, v); + return *this; +} + +InfoPack &InfoPack::operator<<(int64_t v) { + AllocIfNeed(sizeof(v) + sizeof(kSepFlag)); + ptr_ = AppendScalar(buf_.get(), ptr_, v); + return *this; +} + +InfoPack &InfoPack::operator<<(uint64_t v) { + AllocIfNeed(sizeof(v) + sizeof(kSepFlag)); + ptr_ = AppendScalar(buf_.get(), ptr_, v); + return *this; +} + +InfoPack &InfoPack::operator<<(float v) { + AllocIfNeed(sizeof(v) + sizeof(kSepFlag)); + ptr_ = AppendScalar(buf_.get(), ptr_, v); + return *this; +} + +InfoPack &InfoPack::operator<<(double v) { + AllocIfNeed(sizeof(v) + sizeof(kSepFlag)); + ptr_ = AppendScalar(buf_.get(), ptr_, v); + return *this; +} + +InfoPack &InfoPack::operator<<(bool vv) { + uint8_t v = vv ? 1 : 0; + AllocIfNeed(sizeof(v) + sizeof(kSepFlag)); + ptr_ = AppendScalar(buf_.get(), ptr_, v); + return *this; +} + +InfoPack &InfoPack::operator<<(void *v) { + AllocIfNeed(sizeof(v) + sizeof(kSepFlag)); + ptr_ = AppendScalar(buf_.get(), ptr_, v); + return *this; +} + +InfoPack &InfoPack::operator<<(PyObject *vv) { + uint8_t v = vv != nullptr ? 1 : 0; + if (vv != nullptr) { + size_t w = CalcString(std::string(py::str(vv))); + AllocIfNeed(sizeof(v) + sizeof(w) + sizeof(kSepFlag)); + ptr_ = StoreScalar(buf_.get(), ptr_, v); + ptr_ = AppendScalar(buf_.get(), ptr_, w); + } else { + AllocIfNeed(sizeof(v) + sizeof(kSepFlag)); + ptr_ = AppendScalar(buf_.get(), ptr_, v); + } + return *this; +} + +InfoPack &InfoPack::operator<<(mindspore::BasePtr vv) { + uint8_t v = vv != nullptr ? 1 : 0; + if (vv != nullptr) { + size_t w = CalcString(vv->ToString()); + AllocIfNeed(sizeof(v) + sizeof(w) + sizeof(kSepFlag)); + ptr_ = StoreScalar(buf_.get(), ptr_, v); + ptr_ = AppendScalar(buf_.get(), ptr_, w); + } else { + AllocIfNeed(sizeof(v) + sizeof(kSepFlag)); + ptr_ = AppendScalar(buf_.get(), ptr_, v); + } + return *this; +} + +InfoPack &InfoPack::operator<<(const std::string &vv) { + size_t v = CalcString(vv); + AllocIfNeed(sizeof(v) + sizeof(kSepFlag)); + ptr_ = AppendScalar(buf_.get(), ptr_, v); + return *this; +} + +InfoPack &InfoPack::operator<<(const std::vector &v) { + AllocIfNeed(sizeof(kArrayBeginFlag) + sizeof(kArrayEndFlag) + sizeof(kSepFlag) + sizeof(v[0]) * v.size() + + sizeof(size_t)); + ptr_ = AppendVector(buf_.get(), ptr_, v); + return *this; +} + +InfoPack &InfoPack::operator<<(const std::vector &v) { + AllocIfNeed(sizeof(kArrayBeginFlag) + sizeof(kArrayEndFlag) + sizeof(kSepFlag) + sizeof(v[0]) * v.size() + + sizeof(size_t)); + ptr_ = AppendVector(buf_.get(), ptr_, v); + return *this; +} + +InfoPack &InfoPack::operator<<(const std::vector &v) { + AllocIfNeed(sizeof(kArrayBeginFlag) + sizeof(kArrayEndFlag) + sizeof(kSepFlag) + sizeof(v[0]) * v.size() + + sizeof(size_t)); + ptr_ = AppendVector(buf_.get(), ptr_, v); + return *this; +} + +InfoPack &InfoPack::operator<<(const std::vector &v) { + AllocIfNeed(sizeof(kArrayBeginFlag) + sizeof(kArrayEndFlag) + sizeof(kSepFlag) + sizeof(v[0]) * v.size() + + sizeof(size_t)); + ptr_ = AppendVector(buf_.get(), ptr_, v); + return *this; +} + +InfoPack &InfoPack::operator<<(const std::vector &v) { + AllocIfNeed(sizeof(kArrayBeginFlag) + sizeof(kArrayEndFlag) + sizeof(kSepFlag) + sizeof(v[0]) * v.size() + + sizeof(size_t)); + ptr_ = AppendVector(buf_.get(), ptr_, v); + return *this; +} + +InfoPack &InfoPack::operator<<(const std::vector &v) { + AllocIfNeed(sizeof(kArrayBeginFlag) + sizeof(kArrayEndFlag) + sizeof(kSepFlag) + sizeof(v[0]) * v.size() + + sizeof(size_t)); + ptr_ = AppendVector(buf_.get(), ptr_, v); + return *this; +} + +InfoPack &InfoPack::operator<<(const std::vector &v) { + AllocIfNeed(sizeof(kArrayBeginFlag) + sizeof(kArrayEndFlag) + sizeof(kSepFlag) + sizeof(v[0]) * v.size() + + sizeof(size_t)); + ptr_ = AppendVector(buf_.get(), ptr_, v); + return *this; +} + +InfoPack &InfoPack::operator<<(const std::vector &v) { + AllocIfNeed(sizeof(kArrayBeginFlag) + sizeof(kArrayEndFlag) + sizeof(kSepFlag) + sizeof(v[0]) * v.size() + + sizeof(size_t)); + ptr_ = AppendVector(buf_.get(), ptr_, v); + return *this; +} + +InfoPack &InfoPack::operator<<(const std::vector &v) { + AllocIfNeed(sizeof(kArrayBeginFlag) + sizeof(kArrayEndFlag) + sizeof(kSepFlag) + sizeof(v[0]) * v.size() + + sizeof(size_t)); + ptr_ = AppendVector(buf_.get(), ptr_, v); + return *this; +} + +InfoPack &InfoPack::operator<<(const std::vector &v) { + AllocIfNeed(sizeof(kArrayBeginFlag) + sizeof(kArrayEndFlag) + sizeof(kSepFlag) + sizeof(v[0]) * v.size() + + sizeof(size_t)); + ptr_ = AppendVector(buf_.get(), ptr_, v); + return *this; +} + +InfoPack &InfoPack::operator<<(const std::vector &vv) { + std::vector v; + std::transform(vv.begin(), vv.end(), std::back_inserter(v), [](const auto &item) { return item ? 1 : 0; }); + AllocIfNeed(sizeof(kArrayBeginFlag) + sizeof(kArrayEndFlag) + sizeof(kSepFlag) + sizeof(v[0]) * v.size() + + sizeof(size_t)); + ptr_ = AppendVector(buf_.get(), ptr_, v); + return *this; +} + +InfoPack &InfoPack::operator<<(const std::vector &vv) { + std::vector v; + std::transform(vv.begin(), vv.end(), std::back_inserter(v), [this](const auto &item) { return CalcString(item); }); + AllocIfNeed(sizeof(kArrayBeginFlag) + sizeof(kArrayEndFlag) + sizeof(kSepFlag) + sizeof(v[0]) * v.size() + + sizeof(size_t)); + ptr_ = AppendVector(buf_.get(), ptr_, v); + return *this; +} + +InfoPack &InfoPack::operator<<(const std::vector &v) { + AllocIfNeed(sizeof(kArrayBeginFlag) + sizeof(kArrayEndFlag) + sizeof(kSepFlag) + sizeof(v[0]) * v.size() + + sizeof(size_t)); + ptr_ = AppendVector(buf_.get(), ptr_, v); + return *this; +} + +InfoPack &InfoPack::operator<<(const std::vector &vv) { + std::vector v; + std::transform(vv.begin(), vv.end(), std::back_inserter(v), + [this](const auto &item) { return CalcString(std::string(py::str(item))); }); + AllocIfNeed(sizeof(kArrayBeginFlag) + sizeof(kArrayEndFlag) + sizeof(kSepFlag) + sizeof(v[0]) * v.size() + + sizeof(size_t)); + ptr_ = AppendVector(buf_.get(), ptr_, v); + return *this; +} + +InfoPack &InfoPack::operator<<(const InfoPack &v) { + size_t id = v.Id(); + AllocIfNeed(sizeof(id) + sizeof(kSepFlag)); + ptr_ = AppendScalar(buf_.get(), ptr_, id); + return *this; +} + +class String2Id { + public: + String2Id() = default; + ~String2Id() = default; + size_t Insert(std::string key) { + if (map_.find(key) == map_.end()) { + map_[key] = map_.size(); + return map_.size() - 1; + } else { + return map_[key]; + } + } + + protected: + std::map map_; +}; + +static String2Id g_StrMap; + +size_t InfoPack::CalcString(std::string v) { return g_StrMap.Insert(v); } + +struct BufferHash { + bool operator()(const std::vector &lhs, const std::vector &rhs) const { + if (lhs.size() == rhs.size()) { + return memcmp(lhs.data(), rhs.data(), lhs.size()) == 0; + } + return false; + } + size_t operator()(const std::vector &k) const { + size_t ret = 0; + ret = std::accumulate(k.begin(), k.end(), ret, [](size_t key, uint8_t v) { + static constexpr int kShiftKey = 3; + return (key << kShiftKey) + v; + }); + return ret; + } +}; + +class Buffer2Id { + public: + Buffer2Id() = default; + ~Buffer2Id() = default; + size_t Insert(uint8_t *buf, size_t sz) { + std::vector vec(sz); + memcpy_s(vec.data(), sz, buf, sz); + auto it = map_.find(vec); + if (it == map_.end()) { + size_t ret = map_.size(); + map_[vec] = ret; + return ret; + } else { + return it->second; + } + } + + protected: + std::unordered_map, size_t, BufferHash, BufferHash> map_; +}; + +static Buffer2Id g_BufMap; + +size_t InfoPack::CalcBuffer(uint8_t *buf, size_t sz) { return g_BufMap.Insert(buf, sz); } + +void InfoPack::AllocIfNeed(size_t need) { + if (limit_ < need + ptr_) { + do { + limit_ += kInitLimit; + } while (limit_ < need + ptr_); + auto buf = std::make_unique(limit_); + memcpy_s(buf.get(), limit_, buf_.get(), ptr_ * sizeof(uint8_t)); + buf_.reset(buf.release()); + } +} +} // namespace pijit +} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/pi/graph_guard/info.h b/mindspore/ccsrc/pipeline/jit/pi/graph_guard/info.h index 05e9a1be816..ca48b4da7a8 100644 --- a/mindspore/ccsrc/pipeline/jit/pi/graph_guard/info.h +++ b/mindspore/ccsrc/pipeline/jit/pi/graph_guard/info.h @@ -1,87 +1,87 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_PI_JIT_INFO_H -#define MINDSPORE_PI_JIT_INFO_H - -#include -#include -#include -#include "pybind11/pybind11.h" -#include "mindspore/core/base/base.h" - -namespace py = pybind11; - -namespace mindspore { -namespace pijit { - -constexpr size_t kInvalidId = size_t(-1); - -class InfoPack { - public: - InfoPack(); - InfoPack(const InfoPack &); - virtual ~InfoPack(); - size_t Id() const; - uint8_t *Buf(size_t *sz) const; - void Update(); - InfoPack &Begin(); - InfoPack &End(); - InfoPack &operator<<(int8_t v); - InfoPack &operator<<(uint8_t v); - InfoPack &operator<<(int16_t v); - InfoPack &operator<<(uint16_t v); - InfoPack &operator<<(int32_t v); - InfoPack &operator<<(uint32_t v); - InfoPack &operator<<(int64_t v); - InfoPack &operator<<(uint64_t v); - InfoPack &operator<<(float v); - InfoPack &operator<<(double v); - InfoPack &operator<<(bool v); - InfoPack &operator<<(void *v); - InfoPack &operator<<(PyObject *v); - InfoPack &operator<<(mindspore::BasePtr v); - InfoPack &operator<<(const std::string &v); - InfoPack &operator<<(const std::vector &v); - InfoPack &operator<<(const std::vector &v); - InfoPack &operator<<(const std::vector &v); - InfoPack &operator<<(const std::vector &v); - InfoPack &operator<<(const std::vector &v); - InfoPack &operator<<(const std::vector &v); - InfoPack &operator<<(const std::vector &v); - InfoPack &operator<<(const std::vector &v); - InfoPack &operator<<(const std::vector &v); - InfoPack &operator<<(const std::vector &v); - InfoPack &operator<<(const std::vector &v); - InfoPack &operator<<(const std::vector &v); - InfoPack &operator<<(const std::vector &v); - InfoPack &operator<<(const std::vector &v); - InfoPack &operator<<(const InfoPack &v); - void AllocIfNeed(size_t need); - - protected: - size_t CalcBuffer(uint8_t *buf, size_t sz); - size_t CalcString(std::string v); - size_t id_; - std::unique_ptr buf_; - size_t ptr_; - size_t limit_; -}; -using InfoPackPtr = std::shared_ptr; - -} // namespace pijit -} // namespace mindspore - -#endif // MINDSPORE_PI_JIT_INFO_H +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_PI_JIT_INFO_H +#define MINDSPORE_PI_JIT_INFO_H + +#include +#include +#include +#include "pybind11/pybind11.h" +#include "mindspore/core/base/base.h" + +namespace py = pybind11; + +namespace mindspore { +namespace pijit { + +constexpr size_t kInvalidId = size_t(-1); + +class InfoPack { + public: + InfoPack(); + InfoPack(const InfoPack &); + virtual ~InfoPack(); + size_t Id() const; + uint8_t *Buf(size_t *sz) const; + void Update(); + InfoPack &Begin(); + InfoPack &End(); + InfoPack &operator<<(int8_t v); + InfoPack &operator<<(uint8_t v); + InfoPack &operator<<(int16_t v); + InfoPack &operator<<(uint16_t v); + InfoPack &operator<<(int32_t v); + InfoPack &operator<<(uint32_t v); + InfoPack &operator<<(int64_t v); + InfoPack &operator<<(uint64_t v); + InfoPack &operator<<(float v); + InfoPack &operator<<(double v); + InfoPack &operator<<(bool v); + InfoPack &operator<<(void *v); + InfoPack &operator<<(PyObject *v); + InfoPack &operator<<(mindspore::BasePtr v); + InfoPack &operator<<(const std::string &v); + InfoPack &operator<<(const std::vector &v); + InfoPack &operator<<(const std::vector &v); + InfoPack &operator<<(const std::vector &v); + InfoPack &operator<<(const std::vector &v); + InfoPack &operator<<(const std::vector &v); + InfoPack &operator<<(const std::vector &v); + InfoPack &operator<<(const std::vector &v); + InfoPack &operator<<(const std::vector &v); + InfoPack &operator<<(const std::vector &v); + InfoPack &operator<<(const std::vector &v); + InfoPack &operator<<(const std::vector &v); + InfoPack &operator<<(const std::vector &v); + InfoPack &operator<<(const std::vector &v); + InfoPack &operator<<(const std::vector &v); + InfoPack &operator<<(const InfoPack &v); + void AllocIfNeed(size_t need); + + protected: + size_t CalcBuffer(uint8_t *buf, size_t sz); + size_t CalcString(std::string v); + size_t id_; + std::unique_ptr buf_; + size_t ptr_; + size_t limit_; +}; +using InfoPackPtr = std::shared_ptr; + +} // namespace pijit +} // namespace mindspore + +#endif // MINDSPORE_PI_JIT_INFO_H diff --git a/mindspore/ccsrc/pipeline/jit/ps/CMakeLists.txt b/mindspore/ccsrc/pipeline/jit/ps/CMakeLists.txt index 9ba76f6d2ea..1b4cfa3c26a 100644 --- a/mindspore/ccsrc/pipeline/jit/ps/CMakeLists.txt +++ b/mindspore/ccsrc/pipeline/jit/ps/CMakeLists.txt @@ -1,35 +1,35 @@ -file(GLOB_RECURSE _PIPELINE_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} - "pipeline.cc" - "resource.cc" - "pass.cc" - "action.cc" - "validator.cc" - "remove_value_node_dup.cc" - "pipeline_split.cc" - "compile_cache_manager.cc" - "event_message_print.cc" - "fallback.cc" - "parse/*.cc" - "static_analysis/*.cc" - "debug/*.cc" - "load_mindir.cc" -) - -file(GLOB PIPELINE_SRC_FILES "*.cc") -set_property(SOURCE ${PIPELINE_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PIPELINE) - -file(GLOB_RECURSE PARSER_SRC_FILES "parse/*.cc") -set_property(SOURCE ${PARSER_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PARSER) - -file(GLOB_RECURSE ANALYZER_SRC_FILES "static_analysis/*.cc") -set_property(SOURCE ${ANALYZER_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ANALYZER) - -file(GLOB_RECURSE DEBUG_SRC_FILES "debug/*.cc") -set_property(SOURCE ${DEBUG_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEBUG) - -if("${ENABLE_HIDDEN}" STREQUAL "OFF" AND NOT MSVC) - string(REPLACE " -Werror " " " CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") - string(REPLACE " -fvisibility=hidden" " -fvisibility=default" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") -endif() - -add_library(_mindspore_pipeline_jit_ps_obj OBJECT ${_PIPELINE_SRC_FILES}) +file(GLOB_RECURSE _PIPELINE_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + "pipeline.cc" + "resource.cc" + "pass.cc" + "action.cc" + "validator.cc" + "remove_value_node_dup.cc" + "pipeline_split.cc" + "compile_cache_manager.cc" + "event_message_print.cc" + "fallback.cc" + "parse/*.cc" + "static_analysis/*.cc" + "debug/*.cc" + "load_mindir.cc" +) + +file(GLOB PIPELINE_SRC_FILES "*.cc") +set_property(SOURCE ${PIPELINE_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PIPELINE) + +file(GLOB_RECURSE PARSER_SRC_FILES "parse/*.cc") +set_property(SOURCE ${PARSER_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PARSER) + +file(GLOB_RECURSE ANALYZER_SRC_FILES "static_analysis/*.cc") +set_property(SOURCE ${ANALYZER_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ANALYZER) + +file(GLOB_RECURSE DEBUG_SRC_FILES "debug/*.cc") +set_property(SOURCE ${DEBUG_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEBUG) + +if("${ENABLE_HIDDEN}" STREQUAL "OFF" AND NOT MSVC) + string(REPLACE " -Werror " " " CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") + string(REPLACE " -fvisibility=hidden" " -fvisibility=default" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") +endif() + +add_library(_mindspore_pipeline_jit_ps_obj OBJECT ${_PIPELINE_SRC_FILES}) diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/device/distribute/ascend_collective.cc b/mindspore/ccsrc/plugin/device/ascend/hal/device/distribute/ascend_collective.cc index 27c6591093e..028b320d4bc 100644 --- a/mindspore/ccsrc/plugin/device/ascend/hal/device/distribute/ascend_collective.cc +++ b/mindspore/ccsrc/plugin/device/ascend/hal/device/distribute/ascend_collective.cc @@ -1,97 +1,97 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/ascend/hal/device/distribute/ascend_collective.h" -#include "utils/log_adapter.h" - -static constexpr const char *kAscendCollectiveFileName = "libascend_collective.so"; -namespace mindspore { -namespace device { -namespace ascend { -namespace collective { -HcclCollectiveGroup &HcclCollectiveGroup::instance() { - static HcclCollectiveGroup instance = {}; - return instance; -} - -void HcclCollectiveGroup::FinalizeCollective() { - MS_LOG(INFO) << "Finalize Collective"; - if (collective_handle_ != nullptr) { - MS_EXCEPTION_IF_NULL(finalize_mpi_); - finalize_mpi_(); - if (dlclose(collective_handle_) != 0) { - MS_LOG(EXCEPTION) << "Closing libascend_collective.so handle failed."; - } - collective_handle_ = nullptr; - } -} - -bool HcclCollectiveGroup::InitCollective() { - MS_LOG(INFO) << "InitCollective"; - if (inited_) { - return true; - } - collective_handle_ = dlopen(kAscendCollectiveFileName, RTLD_NOW); - if (collective_handle_ == nullptr) { - MS_LOG(DEBUG) << "Load lib" << kAscendCollectiveFileName << " failed, error message: " << dlerror(); - MS_LOG(EXCEPTION) - << "Loading libascend_collective.so failed. Many reasons could cause this:\n1.libascend_collective.so is not " - "installed.\n2.hccl is not " - "installed or found.\n3.mpi is not installed or found, please check if lib files of OpenMPI is added to " - "LD_LIBRARY_PATH or have the version specified in MindSpore document installed."; - } - init_mpi_ = DlsymFuncObj(InitMPI, collective_handle_); - finalize_mpi_ = DlsymFuncObj(FinalizeMPI, collective_handle_); - get_group_comm_ = DlsymFuncObj(GetGroupComm, collective_handle_); - get_group_size_ = DlsymFuncObj(GetGroupSize, collective_handle_); - get_rank_id_by_group_ = DlsymFuncObj(GetRankIdByGroup, collective_handle_); - get_device_id_ = DlsymFuncObj(GetDeviceId, collective_handle_); - create_comm_for_group_ = DlsymFuncObj(CreateCommForGroup, collective_handle_); - destroy_hccl_comm_ = DlsymFuncObj(DestroyHcclComm, collective_handle_); - MS_EXCEPTION_IF_NULL(init_mpi_); - init_mpi_(); - inited_ = true; - MS_LOG(INFO) << "InitCollective success"; - return true; -} -HcclComm HcclCollectiveGroup::GetGroupComm(const std::string &name) { - MS_EXCEPTION_IF_NULL(get_group_comm_); - return get_group_comm_(name); -} -int HcclCollectiveGroup::GetRankSize(const std::string &name) const { - MS_EXCEPTION_IF_NULL(get_group_size_); - return get_group_size_(name); -} -int HcclCollectiveGroup::GetRankId(const std::string &name) const { - MS_EXCEPTION_IF_NULL(get_rank_id_by_group_); - return get_rank_id_by_group_(name); -} -int HcclCollectiveGroup::GetDeviceId() const { - MS_EXCEPTION_IF_NULL(get_device_id_); - return get_device_id_(); -} -void HcclCollectiveGroup::CreateCommGroup(const std::string &name, const std::vector &ranks) { - MS_EXCEPTION_IF_NULL(create_comm_for_group_); - (void)create_comm_for_group_(name, ranks); -} -void HcclCollectiveGroup::DestroyCommGroup() { - MS_EXCEPTION_IF_NULL(destroy_hccl_comm_); - destroy_hccl_comm_(); -} -} // namespace collective -} // namespace ascend -} // namespace device -} // namespace mindspore +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/ascend/hal/device/distribute/ascend_collective.h" +#include "utils/log_adapter.h" + +static constexpr const char *kAscendCollectiveFileName = "libascend_collective.so"; +namespace mindspore { +namespace device { +namespace ascend { +namespace collective { +HcclCollectiveGroup &HcclCollectiveGroup::instance() { + static HcclCollectiveGroup instance = {}; + return instance; +} + +void HcclCollectiveGroup::FinalizeCollective() { + MS_LOG(INFO) << "Finalize Collective"; + if (collective_handle_ != nullptr) { + MS_EXCEPTION_IF_NULL(finalize_mpi_); + finalize_mpi_(); + if (dlclose(collective_handle_) != 0) { + MS_LOG(EXCEPTION) << "Closing libascend_collective.so handle failed."; + } + collective_handle_ = nullptr; + } +} + +bool HcclCollectiveGroup::InitCollective() { + MS_LOG(INFO) << "InitCollective"; + if (inited_) { + return true; + } + collective_handle_ = dlopen(kAscendCollectiveFileName, RTLD_NOW); + if (collective_handle_ == nullptr) { + MS_LOG(DEBUG) << "Load lib" << kAscendCollectiveFileName << " failed, error message: " << dlerror(); + MS_LOG(EXCEPTION) + << "Loading libascend_collective.so failed. Many reasons could cause this:\n1.libascend_collective.so is not " + "installed.\n2.hccl is not " + "installed or found.\n3.mpi is not installed or found, please check if lib files of OpenMPI is added to " + "LD_LIBRARY_PATH or have the version specified in MindSpore document installed."; + } + init_mpi_ = DlsymFuncObj(InitMPI, collective_handle_); + finalize_mpi_ = DlsymFuncObj(FinalizeMPI, collective_handle_); + get_group_comm_ = DlsymFuncObj(GetGroupComm, collective_handle_); + get_group_size_ = DlsymFuncObj(GetGroupSize, collective_handle_); + get_rank_id_by_group_ = DlsymFuncObj(GetRankIdByGroup, collective_handle_); + get_device_id_ = DlsymFuncObj(GetDeviceId, collective_handle_); + create_comm_for_group_ = DlsymFuncObj(CreateCommForGroup, collective_handle_); + destroy_hccl_comm_ = DlsymFuncObj(DestroyHcclComm, collective_handle_); + MS_EXCEPTION_IF_NULL(init_mpi_); + init_mpi_(); + inited_ = true; + MS_LOG(INFO) << "InitCollective success"; + return true; +} +HcclComm HcclCollectiveGroup::GetGroupComm(const std::string &name) { + MS_EXCEPTION_IF_NULL(get_group_comm_); + return get_group_comm_(name); +} +int HcclCollectiveGroup::GetRankSize(const std::string &name) const { + MS_EXCEPTION_IF_NULL(get_group_size_); + return get_group_size_(name); +} +int HcclCollectiveGroup::GetRankId(const std::string &name) const { + MS_EXCEPTION_IF_NULL(get_rank_id_by_group_); + return get_rank_id_by_group_(name); +} +int HcclCollectiveGroup::GetDeviceId() const { + MS_EXCEPTION_IF_NULL(get_device_id_); + return get_device_id_(); +} +void HcclCollectiveGroup::CreateCommGroup(const std::string &name, const std::vector &ranks) { + MS_EXCEPTION_IF_NULL(create_comm_for_group_); + (void)create_comm_for_group_(name, ranks); +} +void HcclCollectiveGroup::DestroyCommGroup() { + MS_EXCEPTION_IF_NULL(destroy_hccl_comm_); + destroy_hccl_comm_(); +} +} // namespace collective +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/device/distribute/ascend_collective.h b/mindspore/ccsrc/plugin/device/ascend/hal/device/distribute/ascend_collective.h index a14fb00b98a..7bd87e66bd0 100644 --- a/mindspore/ccsrc/plugin/device/ascend/hal/device/distribute/ascend_collective.h +++ b/mindspore/ccsrc/plugin/device/ascend/hal/device/distribute/ascend_collective.h @@ -1,75 +1,75 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_ASCEND_COLLECTIVE_H -#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_ASCEND_COLLECTIVE_H - -#include -#include -#include -#include "hccl/hccl_types.h" -#include "include/common/utils/utils.h" -#include "utils/dlopen_macro.h" -#include "ops/ascend_op_name.h" - -namespace mindspore { -namespace device { -namespace ascend { -namespace collective { - -ORIGIN_METHOD(InitMPI, void); -ORIGIN_METHOD(FinalizeMPI, void); -ORIGIN_METHOD(GetGroupComm, HcclComm, const std::string &); -ORIGIN_METHOD(GetGroupSize, int, const std::string &); -ORIGIN_METHOD(GetRankIdByGroup, int, const std::string &); -ORIGIN_METHOD(GetDeviceId, int); -ORIGIN_METHOD(CreateCommForGroup, bool, const std::string &, const std::vector &); -ORIGIN_METHOD(DestroyHcclComm, void); - -class HcclCollectiveGroup { - public: - HcclCollectiveGroup(HcclCollectiveGroup const &) = delete; - HcclCollectiveGroup &operator=(const HcclCollectiveGroup &) = delete; - static HcclCollectiveGroup &instance(); - bool InitCollective(); - void FinalizeCollective(); - HcclComm GetGroupComm(const std::string &name); - int GetDeviceId() const; - int GetRankId(const std::string &name = kHcclWorldGroup) const; - int GetRankSize(const std::string &name = kHcclWorldGroup) const; - void CreateCommGroup(const std::string &name, const std::vector &ranks); - void DestroyCommGroup(); - const void *collective_handle() const { return collective_handle_; } - - private: - HcclCollectiveGroup() = default; - ~HcclCollectiveGroup() = default; - bool inited_ = false; - void *collective_handle_ = nullptr; - InitMPIFunObj init_mpi_ = nullptr; - FinalizeMPIFunObj finalize_mpi_ = nullptr; - GetGroupCommFunObj get_group_comm_ = nullptr; - GetGroupSizeFunObj get_group_size_ = nullptr; - GetRankIdByGroupFunObj get_rank_id_by_group_ = nullptr; - GetDeviceIdFunObj get_device_id_ = nullptr; - CreateCommForGroupFunObj create_comm_for_group_ = nullptr; - DestroyHcclCommFunObj destroy_hccl_comm_ = nullptr; -}; -} // namespace collective -} // namespace ascend -} // namespace device -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_ASCEND_COLLECTIVE_H +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_ASCEND_COLLECTIVE_H +#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_ASCEND_COLLECTIVE_H + +#include +#include +#include +#include "hccl/hccl_types.h" +#include "include/common/utils/utils.h" +#include "utils/dlopen_macro.h" +#include "ops/ascend_op_name.h" + +namespace mindspore { +namespace device { +namespace ascend { +namespace collective { + +ORIGIN_METHOD(InitMPI, void); +ORIGIN_METHOD(FinalizeMPI, void); +ORIGIN_METHOD(GetGroupComm, HcclComm, const std::string &); +ORIGIN_METHOD(GetGroupSize, int, const std::string &); +ORIGIN_METHOD(GetRankIdByGroup, int, const std::string &); +ORIGIN_METHOD(GetDeviceId, int); +ORIGIN_METHOD(CreateCommForGroup, bool, const std::string &, const std::vector &); +ORIGIN_METHOD(DestroyHcclComm, void); + +class HcclCollectiveGroup { + public: + HcclCollectiveGroup(HcclCollectiveGroup const &) = delete; + HcclCollectiveGroup &operator=(const HcclCollectiveGroup &) = delete; + static HcclCollectiveGroup &instance(); + bool InitCollective(); + void FinalizeCollective(); + HcclComm GetGroupComm(const std::string &name); + int GetDeviceId() const; + int GetRankId(const std::string &name = kHcclWorldGroup) const; + int GetRankSize(const std::string &name = kHcclWorldGroup) const; + void CreateCommGroup(const std::string &name, const std::vector &ranks); + void DestroyCommGroup(); + const void *collective_handle() const { return collective_handle_; } + + private: + HcclCollectiveGroup() = default; + ~HcclCollectiveGroup() = default; + bool inited_ = false; + void *collective_handle_ = nullptr; + InitMPIFunObj init_mpi_ = nullptr; + FinalizeMPIFunObj finalize_mpi_ = nullptr; + GetGroupCommFunObj get_group_comm_ = nullptr; + GetGroupSizeFunObj get_group_size_ = nullptr; + GetRankIdByGroupFunObj get_rank_id_by_group_ = nullptr; + GetDeviceIdFunObj get_device_id_ = nullptr; + CreateCommForGroupFunObj create_comm_for_group_ = nullptr; + DestroyHcclCommFunObj destroy_hccl_comm_ = nullptr; +}; +} // namespace collective +} // namespace ascend +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_ASCEND_COLLECTIVE_H diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/device/distribute/collective_group_wrapper.cc b/mindspore/ccsrc/plugin/device/ascend/hal/device/distribute/collective_group_wrapper.cc index a0476f45088..611825db5a5 100644 --- a/mindspore/ccsrc/plugin/device/ascend/hal/device/distribute/collective_group_wrapper.cc +++ b/mindspore/ccsrc/plugin/device/ascend/hal/device/distribute/collective_group_wrapper.cc @@ -1,37 +1,37 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/ascend/hal/device/distribute/collective_group_wrapper.h" - -extern "C" { -void InitMPI() { (void)MPICollective::instance().Init(); } -void FinalizeMPI() { MPICollective::instance().FinalizeMPI(); } -int GetRankIdByGroup(const std::string &name) { return MPICollective::instance().GetRankIdByGroup(name); } -int GetGroupSize(const std::string &name) { return MPICollective::instance().GetGroupSize(name); } -int GetGroupLocalRankSize(const std::string &name) { return MPICollective::instance().GetGroupLocalRankSize(name); } -int GetWorldRankIdFromGroup(const std::string &name, const int rank_id) { - return MPICollective::instance().GetWorldRankIdFromGroup(name, rank_id); -} -int GetGroupRankIdFromWorld(const std::string &name, const int rank_id) { - return MPICollective::instance().GetGroupRankIdFromWorld(name, rank_id); -} -int GetDeviceId() { return MPICollective::instance().GetDeviceId(); } -HcclComm GetGroupComm(const std::string &name) { return MPICollective::instance().GetGroupComm(name); } -bool CreateCommForGroup(const std::string &name, const std::vector &ranks) { - return MPICollective::instance().CreateCommGroup(name, ranks); -} -void DestroyHcclComm() { MPICollective::instance().DestroyHcclComm(); } -} +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/ascend/hal/device/distribute/collective_group_wrapper.h" + +extern "C" { +void InitMPI() { (void)MPICollective::instance().Init(); } +void FinalizeMPI() { MPICollective::instance().FinalizeMPI(); } +int GetRankIdByGroup(const std::string &name) { return MPICollective::instance().GetRankIdByGroup(name); } +int GetGroupSize(const std::string &name) { return MPICollective::instance().GetGroupSize(name); } +int GetGroupLocalRankSize(const std::string &name) { return MPICollective::instance().GetGroupLocalRankSize(name); } +int GetWorldRankIdFromGroup(const std::string &name, const int rank_id) { + return MPICollective::instance().GetWorldRankIdFromGroup(name, rank_id); +} +int GetGroupRankIdFromWorld(const std::string &name, const int rank_id) { + return MPICollective::instance().GetGroupRankIdFromWorld(name, rank_id); +} +int GetDeviceId() { return MPICollective::instance().GetDeviceId(); } +HcclComm GetGroupComm(const std::string &name) { return MPICollective::instance().GetGroupComm(name); } +bool CreateCommForGroup(const std::string &name, const std::vector &ranks) { + return MPICollective::instance().CreateCommGroup(name, ranks); +} +void DestroyHcclComm() { MPICollective::instance().DestroyHcclComm(); } +} diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/device/distribute/collective_group_wrapper.h b/mindspore/ccsrc/plugin/device/ascend/hal/device/distribute/collective_group_wrapper.h index f3da2a0e39f..dcec584deb6 100644 --- a/mindspore/ccsrc/plugin/device/ascend/hal/device/distribute/collective_group_wrapper.h +++ b/mindspore/ccsrc/plugin/device/ascend/hal/device/distribute/collective_group_wrapper.h @@ -1,39 +1,39 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_COLLECTIVE_GROUP_WRAPPER_H -#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_COLLECTIVE_GROUP_WRAPPER_H - -#include -#include -#include "plugin/device/ascend/hal/device/distribute/mpi_collective_group.h" -#ifndef EXPORT_WRAPPER -#define EXPORT_WRAPPER __attribute__((visibility("default"))) -#endif -using MPICollective = mindspore::device::ascend::collective::MPICollective; - -extern "C" EXPORT_WRAPPER void InitMPI(); -extern "C" EXPORT_WRAPPER void FinalizeMPI(); -extern "C" EXPORT_WRAPPER int GetRankIdByGroup(const std::string &name); -extern "C" EXPORT_WRAPPER int GetGroupSize(const std::string &name); -extern "C" EXPORT_WRAPPER int GetGroupLocalRankSize(const std::string &name); -extern "C" EXPORT_WRAPPER int GetWorldRankIdFromGroup(const std::string &name, const int rank_id); -extern "C" EXPORT_WRAPPER int GetGroupRankIdFromWorld(const std::string &name, const int rank_id); -extern "C" EXPORT_WRAPPER int GetDeviceId(); -extern "C" EXPORT_WRAPPER HcclComm GetGroupComm(const std::string &name); -extern "C" EXPORT_WRAPPER bool CreateCommForGroup(const std::string &name, const std::vector &ranks); -extern "C" EXPORT_WRAPPER void DestroyHcclComm(); -#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_COLLECTIVE_GROUP_WRAPPER_H +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_COLLECTIVE_GROUP_WRAPPER_H +#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_COLLECTIVE_GROUP_WRAPPER_H + +#include +#include +#include "plugin/device/ascend/hal/device/distribute/mpi_collective_group.h" +#ifndef EXPORT_WRAPPER +#define EXPORT_WRAPPER __attribute__((visibility("default"))) +#endif +using MPICollective = mindspore::device::ascend::collective::MPICollective; + +extern "C" EXPORT_WRAPPER void InitMPI(); +extern "C" EXPORT_WRAPPER void FinalizeMPI(); +extern "C" EXPORT_WRAPPER int GetRankIdByGroup(const std::string &name); +extern "C" EXPORT_WRAPPER int GetGroupSize(const std::string &name); +extern "C" EXPORT_WRAPPER int GetGroupLocalRankSize(const std::string &name); +extern "C" EXPORT_WRAPPER int GetWorldRankIdFromGroup(const std::string &name, const int rank_id); +extern "C" EXPORT_WRAPPER int GetGroupRankIdFromWorld(const std::string &name, const int rank_id); +extern "C" EXPORT_WRAPPER int GetDeviceId(); +extern "C" EXPORT_WRAPPER HcclComm GetGroupComm(const std::string &name); +extern "C" EXPORT_WRAPPER bool CreateCommForGroup(const std::string &name, const std::vector &ranks); +extern "C" EXPORT_WRAPPER void DestroyHcclComm(); +#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_COLLECTIVE_GROUP_WRAPPER_H diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/device/distribute/mpi_collective_group.cc b/mindspore/ccsrc/plugin/device/ascend/hal/device/distribute/mpi_collective_group.cc index 4428c323a2d..552d5e36533 100644 --- a/mindspore/ccsrc/plugin/device/ascend/hal/device/distribute/mpi_collective_group.cc +++ b/mindspore/ccsrc/plugin/device/ascend/hal/device/distribute/mpi_collective_group.cc @@ -1,203 +1,203 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "hccl/hccl.h" -#include "transform/symbol/acl_rt_symbol.h" -#include "transform/symbol/symbol_utils.h" -#include "plugin/device/ascend/hal/device/distribute/mpi_collective_group.h" - -namespace mindspore { -namespace device { -namespace ascend { -namespace collective { -MPICollective::MPICollective() - : mpi_inited_(false), rank_id_(0), local_rank_id_(0), rank_size_(0), comm_group_world_(MPI_GROUP_NULL) {} -void MPICollective::FinalizeMPI() { - group_info_.clear(); - group_comm_.clear(); - int finalized; - (void)MPI_Finalized(&finalized); - if (finalized == 0) { - (void)MPI_Finalize(); - } -} - -MPICollective::~MPICollective() { - int finalized; - (void)MPI_Finalized(&finalized); - if (finalized == 0) { - (void)MPI_Finalize(); - } -} - -void MPICollective::DestroyHcclComm() { - for (auto iter = group_comm_.cbegin(); iter != group_comm_.cend(); ++iter) { - CHECK_RET(static_cast(HcclCommDestroy(iter->second)), static_cast(::HcclResult::HCCL_SUCCESS), - "HcclCommDestroy failed"); - } - group_comm_.clear(); -} - -MPICollective &MPICollective::instance() { - static MPICollective instance = {}; - return instance; -} - -int MPICollective::GetRankIdByGroup(const std::string &name) { - CHECK_RET(group_info_.count(name), 1, ("Failed to get MPI group rank by group name " + name)); - return std::get<0>(group_info_[name]); -} - -int MPICollective::GetGroupSize(const std::string &name) { - CHECK_RET(group_info_.count(name), 1, ("Failed to get MPI group size by group name " + name)); - return std::get<1>(group_info_[name]); -} - -int MPICollective::GetGroupLocalRankSize(const std::string &name) { - CHECK_RET(group_info_.count(name), 1, ("Failed to get MPI group local size by group name " + name)); - return std::get(group_info_[name]); -} - -int MPICollective::GetWorldRankIdFromGroup(const std::string &name, const int rank_id) { - CHECK_RET(world_map_.count(name), 1, ("Failed to get MPI world rank from group by group name " + name)); - CHECK_RET(static_cast(world_map_[name].size()) > rank_id && rank_id >= 0, 1, - ("The rank_id " + std::to_string(rank_id) + "is not in the range of group " + name)); - CHECK_RET(rank_id >= 0, true, "The rank_id[" + std::to_string(rank_id) + "] must be greater equal than zero."); - return world_map_[name][static_cast(rank_id)]; -} - -int MPICollective::GetGroupRankIdFromWorld(const std::string &name, const int rank_id) { - CHECK_RET(world_map_.count(name), 1, ("Failed to get MPI group rank from world by group name " + name)); - CHECK_RET(std::min(rank_size_ - 1, rank_id), rank_id, - ("The rank_id " + std::to_string(rank_id) + "is great than world rank size")); - CHECK_RET(std::count(world_map_[name].begin(), world_map_[name].end(), rank_id), 1, - ("The rank_id " + std::to_string(rank_id) + " is not in group " + name)); - return std::find(world_map_[name].begin(), world_map_[name].end(), rank_id) - world_map_[name].begin(); -} - -HcclComm MPICollective::GetGroupComm(const std::string &name) { - CHECK_RET(group_comm_.count(name), 1, ("Failed to get MPI group comm by group name " + name)); - return group_comm_[name]; -} - -int MPICollective::GetDeviceId() const { return local_rank_id_; } - -bool MPICollective::Init() { - int init_flag = 0; - CHECK_RET(MPI_Initialized(&init_flag), MPI_SUCCESS, "Check mpi initialized fail!"); - if (init_flag == 0) { - CHECK_RET(MPI_Init(nullptr, nullptr), MPI_SUCCESS, "Failed to init mpi!"); - } - - CHECK_RET(MPI_Comm_group(MPI_COMM_WORLD, &comm_group_world_), MPI_SUCCESS, "comm_group_world_ init fail!"); - - CHECK_RET(MPI_Comm_rank(MPI_COMM_WORLD, &rank_id_), MPI_SUCCESS, "Failed to init mpi rank id!"); - - CHECK_RET(MPI_Comm_size(MPI_COMM_WORLD, &rank_size_), MPI_SUCCESS, "Failed to init mpi rank size!"); - AssignLocalRankID(); - group_info_["hccl_world_group"] = std::make_tuple(rank_id_, rank_size_, 0); - mpi_inited_ = true; - return true; -} - -bool MPICollective::CreateCommGroup(const std::string &name, const std::vector &ranks) { - CHECK_RET(mpi_inited_, true, "HcclCollectiveGroup has not been inited."); - CHECK_RET(ranks.empty(), false, "Ranks is empty."); - std::vector group_ranks(ranks.begin(), ranks.end()); - if (group_comm_.count(name) != 0) { - return true; - } - CHECK_RET(CALL_ASCEND_API(aclrtSetDevice, local_rank_id_), ACL_ERROR_NONE, "Call aclrtSetDevice error."); - HcclRootInfo rootInfo; - if (static_cast(rank_id_) == ranks[0]) { - CHECK_RET(static_cast(HcclGetRootInfo(&rootInfo)), static_cast(::HcclResult::HCCL_SUCCESS), - "HcclGetRootInfo failed."); - } - MPI_Group mpi_group = MPI_GROUP_NULL; - CHECK_RET(MPI_Group_incl(comm_group_world_, group_ranks.size(), group_ranks.data(), &mpi_group), MPI_SUCCESS, - "Create mpi group failed!"); - MPI_Comm mpi_group_comm; - - CHECK_RET(MPI_Comm_create_group(MPI_COMM_WORLD, mpi_group, 0, &mpi_group_comm), MPI_SUCCESS, "Create mpi comm fail!"); - - CHECK_RET(MPI_Bcast(&rootInfo, sizeof(rootInfo), MPI_BYTE, 0, mpi_group_comm), MPI_SUCCESS, - "Mpi reduce_scatter failed!"); - - HcclComm group_hcomm = nullptr; - int group_rank[1]; - int global_rank[1] = {rank_id_}; - CHECK_RET(MPI_Group_translate_ranks(comm_group_world_, 1, global_rank, mpi_group, group_rank), MPI_SUCCESS, - "Failed to translate global rank to group rank."); - if (group_rank[0] == MPI_UNDEFINED) { - return false; - } - - CHECK_RET(static_cast(HcclCommInitRootInfo(static_cast(ranks.size()), &rootInfo, - static_cast(group_rank[0]), &group_hcomm)), - static_cast(::HcclResult::HCCL_SUCCESS), "HcclCommInitRootInfo failed."); - group_comm_[name] = group_hcomm; - group_info_[name] = std::make_tuple(group_rank[0], static_cast(ranks.size()), 0); - AssignLocalRankSize(name, group_ranks, mpi_group_comm); - return true; -} - -void MPICollective::AssignLocalRankSize(const std::string &name, const std::vector &group_ranks, - MPI_Comm mpi_group_comm) { - char host_name[max_hostname_len] = {0}; - CHECK_RET(gethostname(host_name, max_hostname_len), MPI_SUCCESS, "Getting host name failed!"); - size_t host_hash = std::hash()(host_name); - - auto rank_size = group_ranks.size(); - std::vector all_host_hashs(rank_size); - for (size_t i = 0; i < rank_size; ++i) { - if (group_ranks[i] == rank_id_) { - all_host_hashs[i] = host_hash; - } - } - CHECK_RET( - MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, all_host_hashs.data(), sizeof(size_t), MPI_BYTE, mpi_group_comm), - MPI_SUCCESS, "MPI_Allgather host hash failed."); - int local_rank_size = static_cast(std::count(all_host_hashs.begin(), all_host_hashs.end(), host_hash)); - std::get(group_info_[name]) = local_rank_size; - std::vector group_world_ranks(group_ranks.begin(), group_ranks.end()); - world_map_[name] = group_world_ranks; -} - -void MPICollective::AssignLocalRankID() { - char host_name[max_hostname_len] = {0}; - CHECK_RET(gethostname(host_name, max_hostname_len), MPI_SUCCESS, "Getting host name failed!"); - size_t host_hash = std::hash()(host_name); - - const int kRankSize = rank_size_; - size_t all_host_hashs[kRankSize]; - all_host_hashs[rank_id_] = host_hash; - CHECK_RET(MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, all_host_hashs, sizeof(size_t), MPI_BYTE, MPI_COMM_WORLD), - MPI_SUCCESS, "MPI_Allgather host hash failed."); - for (int global_rank = 0; global_rank < kRankSize; global_rank++) { - if (global_rank == rank_id_) { - break; - } - if (all_host_hashs[global_rank] == all_host_hashs[rank_id_]) { - local_rank_id_++; - } - } - return; -} -} // namespace collective -} // namespace ascend -} // namespace device -} // namespace mindspore +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "hccl/hccl.h" +#include "transform/symbol/acl_rt_symbol.h" +#include "transform/symbol/symbol_utils.h" +#include "plugin/device/ascend/hal/device/distribute/mpi_collective_group.h" + +namespace mindspore { +namespace device { +namespace ascend { +namespace collective { +MPICollective::MPICollective() + : mpi_inited_(false), rank_id_(0), local_rank_id_(0), rank_size_(0), comm_group_world_(MPI_GROUP_NULL) {} +void MPICollective::FinalizeMPI() { + group_info_.clear(); + group_comm_.clear(); + int finalized; + (void)MPI_Finalized(&finalized); + if (finalized == 0) { + (void)MPI_Finalize(); + } +} + +MPICollective::~MPICollective() { + int finalized; + (void)MPI_Finalized(&finalized); + if (finalized == 0) { + (void)MPI_Finalize(); + } +} + +void MPICollective::DestroyHcclComm() { + for (auto iter = group_comm_.cbegin(); iter != group_comm_.cend(); ++iter) { + CHECK_RET(static_cast(HcclCommDestroy(iter->second)), static_cast(::HcclResult::HCCL_SUCCESS), + "HcclCommDestroy failed"); + } + group_comm_.clear(); +} + +MPICollective &MPICollective::instance() { + static MPICollective instance = {}; + return instance; +} + +int MPICollective::GetRankIdByGroup(const std::string &name) { + CHECK_RET(group_info_.count(name), 1, ("Failed to get MPI group rank by group name " + name)); + return std::get<0>(group_info_[name]); +} + +int MPICollective::GetGroupSize(const std::string &name) { + CHECK_RET(group_info_.count(name), 1, ("Failed to get MPI group size by group name " + name)); + return std::get<1>(group_info_[name]); +} + +int MPICollective::GetGroupLocalRankSize(const std::string &name) { + CHECK_RET(group_info_.count(name), 1, ("Failed to get MPI group local size by group name " + name)); + return std::get(group_info_[name]); +} + +int MPICollective::GetWorldRankIdFromGroup(const std::string &name, const int rank_id) { + CHECK_RET(world_map_.count(name), 1, ("Failed to get MPI world rank from group by group name " + name)); + CHECK_RET(static_cast(world_map_[name].size()) > rank_id && rank_id >= 0, 1, + ("The rank_id " + std::to_string(rank_id) + "is not in the range of group " + name)); + CHECK_RET(rank_id >= 0, true, "The rank_id[" + std::to_string(rank_id) + "] must be greater equal than zero."); + return world_map_[name][static_cast(rank_id)]; +} + +int MPICollective::GetGroupRankIdFromWorld(const std::string &name, const int rank_id) { + CHECK_RET(world_map_.count(name), 1, ("Failed to get MPI group rank from world by group name " + name)); + CHECK_RET(std::min(rank_size_ - 1, rank_id), rank_id, + ("The rank_id " + std::to_string(rank_id) + "is great than world rank size")); + CHECK_RET(std::count(world_map_[name].begin(), world_map_[name].end(), rank_id), 1, + ("The rank_id " + std::to_string(rank_id) + " is not in group " + name)); + return std::find(world_map_[name].begin(), world_map_[name].end(), rank_id) - world_map_[name].begin(); +} + +HcclComm MPICollective::GetGroupComm(const std::string &name) { + CHECK_RET(group_comm_.count(name), 1, ("Failed to get MPI group comm by group name " + name)); + return group_comm_[name]; +} + +int MPICollective::GetDeviceId() const { return local_rank_id_; } + +bool MPICollective::Init() { + int init_flag = 0; + CHECK_RET(MPI_Initialized(&init_flag), MPI_SUCCESS, "Check mpi initialized fail!"); + if (init_flag == 0) { + CHECK_RET(MPI_Init(nullptr, nullptr), MPI_SUCCESS, "Failed to init mpi!"); + } + + CHECK_RET(MPI_Comm_group(MPI_COMM_WORLD, &comm_group_world_), MPI_SUCCESS, "comm_group_world_ init fail!"); + + CHECK_RET(MPI_Comm_rank(MPI_COMM_WORLD, &rank_id_), MPI_SUCCESS, "Failed to init mpi rank id!"); + + CHECK_RET(MPI_Comm_size(MPI_COMM_WORLD, &rank_size_), MPI_SUCCESS, "Failed to init mpi rank size!"); + AssignLocalRankID(); + group_info_["hccl_world_group"] = std::make_tuple(rank_id_, rank_size_, 0); + mpi_inited_ = true; + return true; +} + +bool MPICollective::CreateCommGroup(const std::string &name, const std::vector &ranks) { + CHECK_RET(mpi_inited_, true, "HcclCollectiveGroup has not been inited."); + CHECK_RET(ranks.empty(), false, "Ranks is empty."); + std::vector group_ranks(ranks.begin(), ranks.end()); + if (group_comm_.count(name) != 0) { + return true; + } + CHECK_RET(CALL_ASCEND_API(aclrtSetDevice, local_rank_id_), ACL_ERROR_NONE, "Call aclrtSetDevice error."); + HcclRootInfo rootInfo; + if (static_cast(rank_id_) == ranks[0]) { + CHECK_RET(static_cast(HcclGetRootInfo(&rootInfo)), static_cast(::HcclResult::HCCL_SUCCESS), + "HcclGetRootInfo failed."); + } + MPI_Group mpi_group = MPI_GROUP_NULL; + CHECK_RET(MPI_Group_incl(comm_group_world_, group_ranks.size(), group_ranks.data(), &mpi_group), MPI_SUCCESS, + "Create mpi group failed!"); + MPI_Comm mpi_group_comm; + + CHECK_RET(MPI_Comm_create_group(MPI_COMM_WORLD, mpi_group, 0, &mpi_group_comm), MPI_SUCCESS, "Create mpi comm fail!"); + + CHECK_RET(MPI_Bcast(&rootInfo, sizeof(rootInfo), MPI_BYTE, 0, mpi_group_comm), MPI_SUCCESS, + "Mpi reduce_scatter failed!"); + + HcclComm group_hcomm = nullptr; + int group_rank[1]; + int global_rank[1] = {rank_id_}; + CHECK_RET(MPI_Group_translate_ranks(comm_group_world_, 1, global_rank, mpi_group, group_rank), MPI_SUCCESS, + "Failed to translate global rank to group rank."); + if (group_rank[0] == MPI_UNDEFINED) { + return false; + } + + CHECK_RET(static_cast(HcclCommInitRootInfo(static_cast(ranks.size()), &rootInfo, + static_cast(group_rank[0]), &group_hcomm)), + static_cast(::HcclResult::HCCL_SUCCESS), "HcclCommInitRootInfo failed."); + group_comm_[name] = group_hcomm; + group_info_[name] = std::make_tuple(group_rank[0], static_cast(ranks.size()), 0); + AssignLocalRankSize(name, group_ranks, mpi_group_comm); + return true; +} + +void MPICollective::AssignLocalRankSize(const std::string &name, const std::vector &group_ranks, + MPI_Comm mpi_group_comm) { + char host_name[max_hostname_len] = {0}; + CHECK_RET(gethostname(host_name, max_hostname_len), MPI_SUCCESS, "Getting host name failed!"); + size_t host_hash = std::hash()(host_name); + + auto rank_size = group_ranks.size(); + std::vector all_host_hashs(rank_size); + for (size_t i = 0; i < rank_size; ++i) { + if (group_ranks[i] == rank_id_) { + all_host_hashs[i] = host_hash; + } + } + CHECK_RET( + MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, all_host_hashs.data(), sizeof(size_t), MPI_BYTE, mpi_group_comm), + MPI_SUCCESS, "MPI_Allgather host hash failed."); + int local_rank_size = static_cast(std::count(all_host_hashs.begin(), all_host_hashs.end(), host_hash)); + std::get(group_info_[name]) = local_rank_size; + std::vector group_world_ranks(group_ranks.begin(), group_ranks.end()); + world_map_[name] = group_world_ranks; +} + +void MPICollective::AssignLocalRankID() { + char host_name[max_hostname_len] = {0}; + CHECK_RET(gethostname(host_name, max_hostname_len), MPI_SUCCESS, "Getting host name failed!"); + size_t host_hash = std::hash()(host_name); + + const int kRankSize = rank_size_; + size_t all_host_hashs[kRankSize]; + all_host_hashs[rank_id_] = host_hash; + CHECK_RET(MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, all_host_hashs, sizeof(size_t), MPI_BYTE, MPI_COMM_WORLD), + MPI_SUCCESS, "MPI_Allgather host hash failed."); + for (int global_rank = 0; global_rank < kRankSize; global_rank++) { + if (global_rank == rank_id_) { + break; + } + if (all_host_hashs[global_rank] == all_host_hashs[rank_id_]) { + local_rank_id_++; + } + } + return; +} +} // namespace collective +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/device/distribute/mpi_collective_group.h b/mindspore/ccsrc/plugin/device/ascend/hal/device/distribute/mpi_collective_group.h index a94b31cd425..68ea132dac8 100644 --- a/mindspore/ccsrc/plugin/device/ascend/hal/device/distribute/mpi_collective_group.h +++ b/mindspore/ccsrc/plugin/device/ascend/hal/device/distribute/mpi_collective_group.h @@ -1,81 +1,81 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_MPI_COLLECTIVE_INIT_H -#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_MPI_COLLECTIVE_INIT_H - -#include -#include -#include -#include -#include -#include -#include -#include "hccl/hccl_types.h" -#include "pybind11/pybind11.h" -namespace mindspore { -namespace device { -namespace ascend { -namespace collective { -constexpr int max_hostname_len = 1024; -constexpr int local_rank_size_index = 2; -class MPICollective { - public: - MPICollective(MPICollective const &) = delete; - MPICollective &operator=(const MPICollective &) = delete; - static MPICollective &instance(); - void AssignLocalRankID(); - void AssignLocalRankSize(); - bool Init(); - void FinalizeMPI(); - int GetRankIdByGroup(const std::string &name); - int GetGroupSize(const std::string &name); - int GetGroupLocalRankSize(const std::string &name); - int GetWorldRankIdFromGroup(const std::string &name, const int rank_id); - int GetGroupRankIdFromWorld(const std::string &name, const int rank_id); - void AssignLocalRankSize(const std::string &name, const std::vector &group_ranks, MPI_Comm mpi_group_comm); - HcclComm GetGroupComm(const std::string &name); - int GetDeviceId() const; - bool CreateCommGroup(const std::string &name, const std::vector &ranks); - void DestroyHcclComm(); - std::map group_comm_; - - private: - MPICollective(); - ~MPICollective(); - bool mpi_inited_; - int rank_id_; - int local_rank_id_; - int rank_size_; - MPI_Group comm_group_world_; - std::map> group_info_; - std::map> world_map_; -}; -#define CHECK_RET(expression, result, message) \ - { \ - auto ret = (expression); \ - if (ret != (result)) { \ - std::ostringstream oss; \ - oss << "Error in file " << __FILE__ << " | Error on line " << __LINE__ \ - << " | Ascend collective Error: " << (message) << " | Error Number " << ret; \ - pybind11::pybind11_fail(oss.str()); \ - } \ - } -} // namespace collective -} // namespace ascend -} // namespace device -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_COLLECTIVE_INIT_H +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_MPI_COLLECTIVE_INIT_H +#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_MPI_COLLECTIVE_INIT_H + +#include +#include +#include +#include +#include +#include +#include +#include "hccl/hccl_types.h" +#include "pybind11/pybind11.h" +namespace mindspore { +namespace device { +namespace ascend { +namespace collective { +constexpr int max_hostname_len = 1024; +constexpr int local_rank_size_index = 2; +class MPICollective { + public: + MPICollective(MPICollective const &) = delete; + MPICollective &operator=(const MPICollective &) = delete; + static MPICollective &instance(); + void AssignLocalRankID(); + void AssignLocalRankSize(); + bool Init(); + void FinalizeMPI(); + int GetRankIdByGroup(const std::string &name); + int GetGroupSize(const std::string &name); + int GetGroupLocalRankSize(const std::string &name); + int GetWorldRankIdFromGroup(const std::string &name, const int rank_id); + int GetGroupRankIdFromWorld(const std::string &name, const int rank_id); + void AssignLocalRankSize(const std::string &name, const std::vector &group_ranks, MPI_Comm mpi_group_comm); + HcclComm GetGroupComm(const std::string &name); + int GetDeviceId() const; + bool CreateCommGroup(const std::string &name, const std::vector &ranks); + void DestroyHcclComm(); + std::map group_comm_; + + private: + MPICollective(); + ~MPICollective(); + bool mpi_inited_; + int rank_id_; + int local_rank_id_; + int rank_size_; + MPI_Group comm_group_world_; + std::map> group_info_; + std::map> world_map_; +}; +#define CHECK_RET(expression, result, message) \ + { \ + auto ret = (expression); \ + if (ret != (result)) { \ + std::ostringstream oss; \ + oss << "Error in file " << __FILE__ << " | Error on line " << __LINE__ \ + << " | Ascend collective Error: " << (message) << " | Error Number " << ret; \ + pybind11::pybind11_fail(oss.str()); \ + } \ + } +} // namespace collective +} // namespace ascend +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_COLLECTIVE_INIT_H diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/device/distribute/mpi_pycc.cc b/mindspore/ccsrc/plugin/device/ascend/hal/device/distribute/mpi_pycc.cc index 33db5febcfa..69ff90441a5 100644 --- a/mindspore/ccsrc/plugin/device/ascend/hal/device/distribute/mpi_pycc.cc +++ b/mindspore/ccsrc/plugin/device/ascend/hal/device/distribute/mpi_pycc.cc @@ -1,60 +1,60 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/ascend/hal/device/distribute/mpi_pycc.h" -#include -#include -#include - -namespace mindspore { -namespace device { -namespace ascend { -namespace collective { -MpiPycc &MpiPycc::instance() { - static MpiPycc instance = {}; - return instance; -} - -int MpiPycc::GetDeviceID() { return GetDeviceId(); } -int MpiPycc::GetRankId(const std::string &group) { return GetRankIdByGroup(group); } -int MpiPycc::GetRankSize(const std::string &group) { return GetGroupSize(group); } -int MpiPycc::GetLocalRankSize(const std::string &group) { return GetGroupLocalRankSize(group); } -int MpiPycc::GetGroupRankFromWorld(const int rank_id, const std::string &group) { - return GetGroupRankIdFromWorld(group, rank_id); -} -int MpiPycc::GetWorldRankFromGroup(const std::string &group, const int rank_id) { - return GetWorldRankIdFromGroup(group, rank_id); -} -void MpiPycc::CreateGroup(const std::string &group, const std::vector &ranks) { - (void)CreateCommForGroup(group, ranks); -} - -// cppcheck-suppress syntaxError -PYBIND11_MODULE(_ascend_mpi, mpi_initializer) { - (void)mpi_initializer.def("get_device_id", &MpiPycc::GetDeviceID, "get device id"); - (void)mpi_initializer.def("get_rank_id", &MpiPycc::GetRankId, "get rank id"); - (void)mpi_initializer.def("get_rank_size", &MpiPycc::GetRankSize, "get rank size"); - (void)mpi_initializer.def("get_local_rank_size", &MpiPycc::GetLocalRankSize, "get local rank size"); - (void)mpi_initializer.def("get_group_rank_from_world_rank", &MpiPycc::GetGroupRankFromWorld, - "get group rank from world rank"); - (void)mpi_initializer.def("get_world_rank_from_group_rank", &MpiPycc::GetWorldRankFromGroup, - "get world rank from group rank"); - (void)mpi_initializer.def("create_group", &MpiPycc::CreateGroup, "create group"); -} -} // namespace collective -} // namespace ascend -} // namespace device -} // namespace mindspore +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/ascend/hal/device/distribute/mpi_pycc.h" +#include +#include +#include + +namespace mindspore { +namespace device { +namespace ascend { +namespace collective { +MpiPycc &MpiPycc::instance() { + static MpiPycc instance = {}; + return instance; +} + +int MpiPycc::GetDeviceID() { return GetDeviceId(); } +int MpiPycc::GetRankId(const std::string &group) { return GetRankIdByGroup(group); } +int MpiPycc::GetRankSize(const std::string &group) { return GetGroupSize(group); } +int MpiPycc::GetLocalRankSize(const std::string &group) { return GetGroupLocalRankSize(group); } +int MpiPycc::GetGroupRankFromWorld(const int rank_id, const std::string &group) { + return GetGroupRankIdFromWorld(group, rank_id); +} +int MpiPycc::GetWorldRankFromGroup(const std::string &group, const int rank_id) { + return GetWorldRankIdFromGroup(group, rank_id); +} +void MpiPycc::CreateGroup(const std::string &group, const std::vector &ranks) { + (void)CreateCommForGroup(group, ranks); +} + +// cppcheck-suppress syntaxError +PYBIND11_MODULE(_ascend_mpi, mpi_initializer) { + (void)mpi_initializer.def("get_device_id", &MpiPycc::GetDeviceID, "get device id"); + (void)mpi_initializer.def("get_rank_id", &MpiPycc::GetRankId, "get rank id"); + (void)mpi_initializer.def("get_rank_size", &MpiPycc::GetRankSize, "get rank size"); + (void)mpi_initializer.def("get_local_rank_size", &MpiPycc::GetLocalRankSize, "get local rank size"); + (void)mpi_initializer.def("get_group_rank_from_world_rank", &MpiPycc::GetGroupRankFromWorld, + "get group rank from world rank"); + (void)mpi_initializer.def("get_world_rank_from_group_rank", &MpiPycc::GetWorldRankFromGroup, + "get world rank from group rank"); + (void)mpi_initializer.def("create_group", &MpiPycc::CreateGroup, "create group"); +} +} // namespace collective +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/device/distribute/mpi_pycc.h b/mindspore/ccsrc/plugin/device/ascend/hal/device/distribute/mpi_pycc.h index dc485311a96..c8a0c5a7204 100644 --- a/mindspore/ccsrc/plugin/device/ascend/hal/device/distribute/mpi_pycc.h +++ b/mindspore/ccsrc/plugin/device/ascend/hal/device/distribute/mpi_pycc.h @@ -1,49 +1,49 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_MPI_PYCC_H -#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_MPI_PYCC_H - -#include -#include -#include "plugin/device/ascend/hal/device/distribute/collective_group_wrapper.h" - -namespace mindspore { -namespace device { -namespace ascend { -namespace collective { -class MpiPycc { - public: - MpiPycc(MpiPycc const &) = delete; - MpiPycc &operator=(const MpiPycc &) = delete; - static MpiPycc &instance(); - static int GetDeviceID(); - static int GetRankId(const std::string &group); - static int GetRankSize(const std::string &group); - static int GetLocalRankSize(const std::string &group); - static int GetGroupRankFromWorld(const int rank_id, const std::string &group); - static int GetWorldRankFromGroup(const std::string &group, const int rank_id); - static void CreateGroup(const std::string &group, const std::vector &ranks); - - private: - MpiPycc() = default; - ~MpiPycc() = default; -}; -} // namespace collective -} // namespace ascend -} // namespace device -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_MPI_PYCC_H +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_MPI_PYCC_H +#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_MPI_PYCC_H + +#include +#include +#include "plugin/device/ascend/hal/device/distribute/collective_group_wrapper.h" + +namespace mindspore { +namespace device { +namespace ascend { +namespace collective { +class MpiPycc { + public: + MpiPycc(MpiPycc const &) = delete; + MpiPycc &operator=(const MpiPycc &) = delete; + static MpiPycc &instance(); + static int GetDeviceID(); + static int GetRankId(const std::string &group); + static int GetRankSize(const std::string &group); + static int GetLocalRankSize(const std::string &group); + static int GetGroupRankFromWorld(const int rank_id, const std::string &group); + static int GetWorldRankFromGroup(const std::string &group, const int rank_id); + static void CreateGroup(const std::string &group, const std::vector &ranks); + + private: + MpiPycc() = default; + ~MpiPycc() = default; +}; +} // namespace collective +} // namespace ascend +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DISTRIBUTE_MPI_PYCC_H diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_collective_comm_lib.cc b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_collective_comm_lib.cc index 5e05a4c30bc..d42f97ed48c 100644 --- a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_collective_comm_lib.cc +++ b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_collective_comm_lib.cc @@ -1,297 +1,297 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/ascend/hal/hardware/ascend_collective_comm_lib.h" -#include "ops/ascend_op_name.h" -#include "plugin/device/ascend/hal/common/ascend_utils.h" -#include "plugin/device/ascend/hal/hccl_adapter/hccl_adapter.h" -#include "runtime/hardware/device_context_manager.h" -#include "utils/convert_utils_base.h" -#include "utils/ms_context.h" - -constexpr size_t kPathMax = 4096; -namespace mindspore { -namespace device { -namespace ascend { -#define HCCL_RUN_CHECK(op_name, group, op) \ - do { \ - auto hccl_result = static_cast(op); \ - if (hccl_result != 0) { \ - MS_LOG(ERROR) << (op_name) << " failed: #" << (group) << "#"; \ - return false; \ - } \ - } while (0) - -#define HCCL_GROUP_CHECK_EMPTY(group) \ - do { \ - if ((group).length() == 0) { \ - MS_LOG(ERROR) << "The length of group name should not be 0"; \ - return false; \ - } \ - } while (0) - -#define HCCL_GROUP_CHECK_IS_WORLD(group) \ - do { \ - if ((group) == kHcclWorldGroup) { \ - MS_LOG(ERROR) << "The group name should not be " << kHcclWorldGroup; \ - return false; \ - } \ - } while (0) -AscendCollectiveCommLib::AscendCollectiveCommLib() { global_group_name_ = kHCCLGlobalGroupName; } - -bool AscendCollectiveCommLib::InitializeHccl() { - if (initialized_) { - return false; - } - auto ms_context = MsContext::GetInstance(); - ms_context->set_param(MS_CTX_ENABLE_HCCL, true); - MS_LOG(INFO) << "Create hccl_world_group with rank table."; - auto config_path_str = std::getenv("MINDSPORE_HCCL_CONFIG_PATH"); - if (config_path_str == nullptr) { - config_path_str = std::getenv("RANK_TABLE_FILE"); - if (config_path_str == nullptr) { - MS_LOG(ERROR) << "The environment variable 'MINDSPORE_HCCL_CONFIG_PATH' or 'RANK_TABLE_FILE' is not set, so get" - << " hccl json config failed, please set env 'MINDSPORE_HCCL_CONFIG_PATH' or 'RANK_TABLE_FILE'"; - return false; - } - } - if (strlen(config_path_str) >= kPathMax) { - MS_LOG(ERROR) << "Invalid environment variable 'MINDSPORE_HCCL_CONFIG_PATH' or 'RANK_TABLE_FILE', the path length" - << " should be smaller than " << kPathMax << ", but got " << config_path_str; - return false; - } - auto full_path = realpath(config_path_str, nullptr); - if (full_path == nullptr) { - MS_LOG(ERROR) << "Invalid environment variable 'MINDSPORE_HCCL_CONFIG_PATH' or 'RANK_TABLE_FILE', the path is: " - << config_path_str << ". Please check (1) whether the path exists, " - << "(2) whether the path has the access permission, (3) whether the path is too long. "; - return false; - } - auto rank_id_str = common::GetEnv("RANK_ID"); - if (rank_id_str.empty()) { - MS_LOG(EXCEPTION) << "Invalid environment variable 'RANK_ID', it should not be empty."; - } - auto device_id = ms_context->get_param(MS_CTX_DEVICE_ID); - MS_LOG(INFO) << "MINDSPORE_HCCL_CONFIG_PATH : " << full_path << ", RANK_ID: " << rank_id_str; - - auto mode = ms_context->get_param(MS_CTX_EXECUTION_MODE); - hccl::HcclMode hccl_mode = hccl::HcclMode::kGraph; - if (mode == kPynativeMode) { - hccl_mode = hccl::HcclMode::kPynative; - } else if (ms_context->IsKByKExecutorMode()) { - hccl_mode = hccl::HcclMode::kKernelByKernel; - } - - bool ret = hccl::HcclAdapter::GetInstance().InitHccl(device_id, rank_id_str, full_path, hccl_mode); - free(full_path); - if (!ret) { - MS_LOG(ERROR) << "Hcom init failed."; - return false; - } - initialized_ = true; - finalized_ = false; - return true; -} - -bool AscendCollectiveCommLib::Initialize(uint32_t global_rank, uint32_t global_rank_size, uint32_t local_rank_id) { - if (initialized_) { - return false; - } - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext( - {kAscendDevice, ms_context->get_param(MS_CTX_DEVICE_ID)}); - MS_EXCEPTION_IF_NULL(device_context); - MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface()); - (void)device_context->GetDeprecatedInterface()->OpenTsd(ms_context); - try { - if (!common::GetEnv(kSimulationLevel).empty()) { - std::string rank_id_str = std::to_string(0); - (void)hccl::HcclAdapter::GetInstance().InitHccl(local_rank_id, rank_id_str); - } else if (!common::UseHostCollective()) { - // Use rank table to launch distribtued job. - MS_LOG(WARNING) - << "Launch Ascend distributed job in RankTable manner. This manner will be deprecated in later version of " - "MindSpore. \n Please switch to 'msrun' or 'mpirun'. You can refer to this link about how to use these " - "commands: https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/startup_method.html."; - return InitializeHccl(); - } else { - if (hccl::HcclAdapter::GetInstance().UseHcclCM()) { - // Use dynamic cluster and hccl's CM envs to launch distributed job. This method is similar to rank table. It - // only supports to run in graph sink mode. - MS_LOG(INFO) << "Launch Ascend distributed job using hccl CM envs."; - } - std::string rank_id_str = std::to_string(global_rank); - (void)hccl::HcclAdapter::GetInstance().InitHccl(local_rank_id, rank_id_str); - } - } catch (const std::exception &e) { - MS_LOG(EXCEPTION) << "Ascend collective communication initialization failed.#dmsg#Framework Error Message:#dmsg#" - << e.what(); - } - ms_context->set_param(MS_CTX_ENABLE_HCCL, true); - global_rank_id_ = global_rank; - global_rank_size_ = global_rank_size; - local_rank_id_ = local_rank_id; - initialized_ = true; - finalized_ = false; - return true; -} - -bool AscendCollectiveCommLib::DestroyHcclComm() { - for (auto &group : groups_) { - CHECK_IF_NULL(group.second); - if (!group.second->Finalize()) { - return false; - } - } - groups_.clear(); - bool res = hccl::HcclAdapter::GetInstance().FinalizeHccl(); - if (!res) { - MS_LOG(ERROR) << "Hccl finalize failed"; - return false; - } - return true; -} - -bool AscendCollectiveCommLib::DestroyDeviceCommunicationGroup(const std::string &group_name) { - HCCL_GROUP_CHECK_EMPTY(group_name); - HCCL_RUN_CHECK(std::string("destroy communicate group"), group_name, - hccl::HcclAdapter::GetInstance().HcclDestroyGroup(group_name)); - return true; -} - -bool AscendCollectiveCommLib::DestroyCommunicationGroup(const std::string &group_name) { - // If using hccl CM, we reuse rank table launching interfaces. - if (hccl::HcclAdapter::GetInstance().UseHcclCM()) { - return DestroyDeviceCommunicationGroup(group_name); - } - - HCCL_GROUP_CHECK_EMPTY(group_name); - CHECK_RET((groups_.count(group_name) != 0), true, "The HCCL group " + group_name + " does not exist."); - - if (!groups_[group_name]->Finalize()) { - return false; - } - return true; -} - -bool AscendCollectiveCommLib::CreateDeviceCommunicationGroup(const std::string &group_name, - const std::vector &group_ranks) { - HCCL_GROUP_CHECK_EMPTY(group_name); - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - if (ms_context->get_param(MS_CTX_EXECUTION_MODE) == kPynativeMode) { - MS_LOG(ERROR) << "Creating custom communication group is not allowed in PyNative mode."; - return false; - } - auto rank_size = group_ranks.size(); - HCCL_RUN_CHECK(std::string("create communicate group"), group_name, - hccl::HcclAdapter::GetInstance().HcclCreateGroup(group_name, UlongToUint(rank_size), - std::vector(group_ranks).data())); - return true; -} - -bool AscendCollectiveCommLib::CreateCommunicationGroup(const std::string &group_name, - const std::vector &group_ranks, - uint32_t local_group_rank, uint32_t local_group_size) { - HCCL_GROUP_CHECK_EMPTY(group_name); - CHECK_RET((groups_.count(group_name) == 0), true, "The HCCL group " + group_name + " has already existed."); - - AscendCommunicationGroupPtr group = std::make_shared( - group_name, group_ranks, global_rank_id_, local_group_rank, local_group_size); - CHECK_IF_NULL(group); - groups_[group_name] = group; - - // If using hccl CM, we reuse rank table launching interfaces. - // It does not support to create hccl_world_group. - if (hccl::HcclAdapter::GetInstance().UseHcclCM() && group_name != kHCCLGlobalGroupName) { - return CreateDeviceCommunicationGroup(group_name, group_ranks); - } - return true; -} - -HcclComm AscendCollectiveCommLib::HcclCommunicator(const std::string &group_name) { - if (!common::UseHostCollective() || hccl::HcclAdapter::GetInstance().UseHcclCM()) { - return hccl::HcclAdapter::GetInstance().get_hccl_comm(); - } - CHECK_RET((groups_.count(group_name) != 0), true, "The HCCL group " + group_name + " does not existed."); - auto group = std::dynamic_pointer_cast(groups_[group_name]); - CHECK_IF_NULL(group); - return group->hccl_communicator(); -} - -std::string AscendCollectiveCommLib::HcclInnerCommName(const std::string &group_name) { - if (!common::UseHostCollective() || hccl::HcclAdapter::GetInstance().UseHcclCM()) { - return ""; - } - CHECK_RET((groups_.count(group_name) != 0), true, "The HCCL group " + group_name + " does not existed."); - auto group = std::dynamic_pointer_cast(groups_[group_name]); - CHECK_IF_NULL(group); - return group->inner_comm_name(); -} - -uint32_t AscendCollectiveCommLib::GetRankId(const std::string &group_name) { - uint32_t rank_id = 0; - HCCL_RUN_CHECK(std::string("get rank_id"), group_name, - hccl::HcclAdapter::GetInstance().HcclGetRankId(group_name, &rank_id)); - return rank_id; -} - -uint32_t AscendCollectiveCommLib::GetGroupSize(const std::string &group_name) { - HCCL_GROUP_CHECK_EMPTY(group_name); - uint32_t rank_size = 0; - HCCL_RUN_CHECK(std::string("get rank size"), group_name, - hccl::HcclAdapter::GetInstance().HcclGetRankSize(group_name, &rank_size)); - return rank_size; -} - -uint32_t AscendCollectiveCommLib::GetLocalRankId(const std::string &group_name) { - uint32_t rank_id = 0; - HCCL_RUN_CHECK(std::string("get rank_id"), group_name, - hccl::HcclAdapter::GetInstance().HcclGetLocalRankId(group_name, &rank_id)); - return rank_id; -} - -uint32_t AscendCollectiveCommLib::GetLocalGroupSize(const std::string &group_name) { - HCCL_GROUP_CHECK_EMPTY(group_name); - uint32_t rank_size = 0; - HCCL_RUN_CHECK(std::string("get rank size"), group_name, - hccl::HcclAdapter::GetInstance().HcclGetLocalRankSize(group_name, &rank_size)); - return rank_size; -} - -uint32_t AscendCollectiveCommLib::GetWorldRankFromGroupRank(const std::string &group_name, uint32_t local_rank) { - uint32_t world_rank_id = 0; - HCCL_RUN_CHECK( - std::string("get world rank id"), group_name, - hccl::HcclAdapter::GetInstance().HcclGetWorldRankFromGroupRank(group_name, local_rank, &world_rank_id)); - return world_rank_id; -} - -uint32_t AscendCollectiveCommLib::GetGroupRankFromWorldRank(uint32_t world_rank, const std::string &group_name) { - uint32_t local_rank_id = 0; - HCCL_RUN_CHECK( - std::string("get local rank id"), group_name, - hccl::HcclAdapter::GetInstance().HcclGetGroupRankFromWorldRank(world_rank, group_name, &local_rank_id)); - return local_rank_id; -} -} // namespace ascend - -using AscendCollectiveCommLib = mindspore::device::ascend::AscendCollectiveCommLib; - -CollectiveCommunicationLib *communication_lib_instance() { return &AscendCollectiveCommLib::GetInstance(); } -} // namespace device -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/ascend/hal/hardware/ascend_collective_comm_lib.h" +#include "ops/ascend_op_name.h" +#include "plugin/device/ascend/hal/common/ascend_utils.h" +#include "plugin/device/ascend/hal/hccl_adapter/hccl_adapter.h" +#include "runtime/hardware/device_context_manager.h" +#include "utils/convert_utils_base.h" +#include "utils/ms_context.h" + +constexpr size_t kPathMax = 4096; +namespace mindspore { +namespace device { +namespace ascend { +#define HCCL_RUN_CHECK(op_name, group, op) \ + do { \ + auto hccl_result = static_cast(op); \ + if (hccl_result != 0) { \ + MS_LOG(ERROR) << (op_name) << " failed: #" << (group) << "#"; \ + return false; \ + } \ + } while (0) + +#define HCCL_GROUP_CHECK_EMPTY(group) \ + do { \ + if ((group).length() == 0) { \ + MS_LOG(ERROR) << "The length of group name should not be 0"; \ + return false; \ + } \ + } while (0) + +#define HCCL_GROUP_CHECK_IS_WORLD(group) \ + do { \ + if ((group) == kHcclWorldGroup) { \ + MS_LOG(ERROR) << "The group name should not be " << kHcclWorldGroup; \ + return false; \ + } \ + } while (0) +AscendCollectiveCommLib::AscendCollectiveCommLib() { global_group_name_ = kHCCLGlobalGroupName; } + +bool AscendCollectiveCommLib::InitializeHccl() { + if (initialized_) { + return false; + } + auto ms_context = MsContext::GetInstance(); + ms_context->set_param(MS_CTX_ENABLE_HCCL, true); + MS_LOG(INFO) << "Create hccl_world_group with rank table."; + auto config_path_str = std::getenv("MINDSPORE_HCCL_CONFIG_PATH"); + if (config_path_str == nullptr) { + config_path_str = std::getenv("RANK_TABLE_FILE"); + if (config_path_str == nullptr) { + MS_LOG(ERROR) << "The environment variable 'MINDSPORE_HCCL_CONFIG_PATH' or 'RANK_TABLE_FILE' is not set, so get" + << " hccl json config failed, please set env 'MINDSPORE_HCCL_CONFIG_PATH' or 'RANK_TABLE_FILE'"; + return false; + } + } + if (strlen(config_path_str) >= kPathMax) { + MS_LOG(ERROR) << "Invalid environment variable 'MINDSPORE_HCCL_CONFIG_PATH' or 'RANK_TABLE_FILE', the path length" + << " should be smaller than " << kPathMax << ", but got " << config_path_str; + return false; + } + auto full_path = realpath(config_path_str, nullptr); + if (full_path == nullptr) { + MS_LOG(ERROR) << "Invalid environment variable 'MINDSPORE_HCCL_CONFIG_PATH' or 'RANK_TABLE_FILE', the path is: " + << config_path_str << ". Please check (1) whether the path exists, " + << "(2) whether the path has the access permission, (3) whether the path is too long. "; + return false; + } + auto rank_id_str = common::GetEnv("RANK_ID"); + if (rank_id_str.empty()) { + MS_LOG(EXCEPTION) << "Invalid environment variable 'RANK_ID', it should not be empty."; + } + auto device_id = ms_context->get_param(MS_CTX_DEVICE_ID); + MS_LOG(INFO) << "MINDSPORE_HCCL_CONFIG_PATH : " << full_path << ", RANK_ID: " << rank_id_str; + + auto mode = ms_context->get_param(MS_CTX_EXECUTION_MODE); + hccl::HcclMode hccl_mode = hccl::HcclMode::kGraph; + if (mode == kPynativeMode) { + hccl_mode = hccl::HcclMode::kPynative; + } else if (ms_context->IsKByKExecutorMode()) { + hccl_mode = hccl::HcclMode::kKernelByKernel; + } + + bool ret = hccl::HcclAdapter::GetInstance().InitHccl(device_id, rank_id_str, full_path, hccl_mode); + free(full_path); + if (!ret) { + MS_LOG(ERROR) << "Hcom init failed."; + return false; + } + initialized_ = true; + finalized_ = false; + return true; +} + +bool AscendCollectiveCommLib::Initialize(uint32_t global_rank, uint32_t global_rank_size, uint32_t local_rank_id) { + if (initialized_) { + return false; + } + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext( + {kAscendDevice, ms_context->get_param(MS_CTX_DEVICE_ID)}); + MS_EXCEPTION_IF_NULL(device_context); + MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface()); + (void)device_context->GetDeprecatedInterface()->OpenTsd(ms_context); + try { + if (!common::GetEnv(kSimulationLevel).empty()) { + std::string rank_id_str = std::to_string(0); + (void)hccl::HcclAdapter::GetInstance().InitHccl(local_rank_id, rank_id_str); + } else if (!common::UseHostCollective()) { + // Use rank table to launch distribtued job. + MS_LOG(WARNING) + << "Launch Ascend distributed job in RankTable manner. This manner will be deprecated in later version of " + "MindSpore. \n Please switch to 'msrun' or 'mpirun'. You can refer to this link about how to use these " + "commands: https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/startup_method.html."; + return InitializeHccl(); + } else { + if (hccl::HcclAdapter::GetInstance().UseHcclCM()) { + // Use dynamic cluster and hccl's CM envs to launch distributed job. This method is similar to rank table. It + // only supports to run in graph sink mode. + MS_LOG(INFO) << "Launch Ascend distributed job using hccl CM envs."; + } + std::string rank_id_str = std::to_string(global_rank); + (void)hccl::HcclAdapter::GetInstance().InitHccl(local_rank_id, rank_id_str); + } + } catch (const std::exception &e) { + MS_LOG(EXCEPTION) << "Ascend collective communication initialization failed.#dmsg#Framework Error Message:#dmsg#" + << e.what(); + } + ms_context->set_param(MS_CTX_ENABLE_HCCL, true); + global_rank_id_ = global_rank; + global_rank_size_ = global_rank_size; + local_rank_id_ = local_rank_id; + initialized_ = true; + finalized_ = false; + return true; +} + +bool AscendCollectiveCommLib::DestroyHcclComm() { + for (auto &group : groups_) { + CHECK_IF_NULL(group.second); + if (!group.second->Finalize()) { + return false; + } + } + groups_.clear(); + bool res = hccl::HcclAdapter::GetInstance().FinalizeHccl(); + if (!res) { + MS_LOG(ERROR) << "Hccl finalize failed"; + return false; + } + return true; +} + +bool AscendCollectiveCommLib::DestroyDeviceCommunicationGroup(const std::string &group_name) { + HCCL_GROUP_CHECK_EMPTY(group_name); + HCCL_RUN_CHECK(std::string("destroy communicate group"), group_name, + hccl::HcclAdapter::GetInstance().HcclDestroyGroup(group_name)); + return true; +} + +bool AscendCollectiveCommLib::DestroyCommunicationGroup(const std::string &group_name) { + // If using hccl CM, we reuse rank table launching interfaces. + if (hccl::HcclAdapter::GetInstance().UseHcclCM()) { + return DestroyDeviceCommunicationGroup(group_name); + } + + HCCL_GROUP_CHECK_EMPTY(group_name); + CHECK_RET((groups_.count(group_name) != 0), true, "The HCCL group " + group_name + " does not exist."); + + if (!groups_[group_name]->Finalize()) { + return false; + } + return true; +} + +bool AscendCollectiveCommLib::CreateDeviceCommunicationGroup(const std::string &group_name, + const std::vector &group_ranks) { + HCCL_GROUP_CHECK_EMPTY(group_name); + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + if (ms_context->get_param(MS_CTX_EXECUTION_MODE) == kPynativeMode) { + MS_LOG(ERROR) << "Creating custom communication group is not allowed in PyNative mode."; + return false; + } + auto rank_size = group_ranks.size(); + HCCL_RUN_CHECK(std::string("create communicate group"), group_name, + hccl::HcclAdapter::GetInstance().HcclCreateGroup(group_name, UlongToUint(rank_size), + std::vector(group_ranks).data())); + return true; +} + +bool AscendCollectiveCommLib::CreateCommunicationGroup(const std::string &group_name, + const std::vector &group_ranks, + uint32_t local_group_rank, uint32_t local_group_size) { + HCCL_GROUP_CHECK_EMPTY(group_name); + CHECK_RET((groups_.count(group_name) == 0), true, "The HCCL group " + group_name + " has already existed."); + + AscendCommunicationGroupPtr group = std::make_shared( + group_name, group_ranks, global_rank_id_, local_group_rank, local_group_size); + CHECK_IF_NULL(group); + groups_[group_name] = group; + + // If using hccl CM, we reuse rank table launching interfaces. + // It does not support to create hccl_world_group. + if (hccl::HcclAdapter::GetInstance().UseHcclCM() && group_name != kHCCLGlobalGroupName) { + return CreateDeviceCommunicationGroup(group_name, group_ranks); + } + return true; +} + +HcclComm AscendCollectiveCommLib::HcclCommunicator(const std::string &group_name) { + if (!common::UseHostCollective() || hccl::HcclAdapter::GetInstance().UseHcclCM()) { + return hccl::HcclAdapter::GetInstance().get_hccl_comm(); + } + CHECK_RET((groups_.count(group_name) != 0), true, "The HCCL group " + group_name + " does not existed."); + auto group = std::dynamic_pointer_cast(groups_[group_name]); + CHECK_IF_NULL(group); + return group->hccl_communicator(); +} + +std::string AscendCollectiveCommLib::HcclInnerCommName(const std::string &group_name) { + if (!common::UseHostCollective() || hccl::HcclAdapter::GetInstance().UseHcclCM()) { + return ""; + } + CHECK_RET((groups_.count(group_name) != 0), true, "The HCCL group " + group_name + " does not existed."); + auto group = std::dynamic_pointer_cast(groups_[group_name]); + CHECK_IF_NULL(group); + return group->inner_comm_name(); +} + +uint32_t AscendCollectiveCommLib::GetRankId(const std::string &group_name) { + uint32_t rank_id = 0; + HCCL_RUN_CHECK(std::string("get rank_id"), group_name, + hccl::HcclAdapter::GetInstance().HcclGetRankId(group_name, &rank_id)); + return rank_id; +} + +uint32_t AscendCollectiveCommLib::GetGroupSize(const std::string &group_name) { + HCCL_GROUP_CHECK_EMPTY(group_name); + uint32_t rank_size = 0; + HCCL_RUN_CHECK(std::string("get rank size"), group_name, + hccl::HcclAdapter::GetInstance().HcclGetRankSize(group_name, &rank_size)); + return rank_size; +} + +uint32_t AscendCollectiveCommLib::GetLocalRankId(const std::string &group_name) { + uint32_t rank_id = 0; + HCCL_RUN_CHECK(std::string("get rank_id"), group_name, + hccl::HcclAdapter::GetInstance().HcclGetLocalRankId(group_name, &rank_id)); + return rank_id; +} + +uint32_t AscendCollectiveCommLib::GetLocalGroupSize(const std::string &group_name) { + HCCL_GROUP_CHECK_EMPTY(group_name); + uint32_t rank_size = 0; + HCCL_RUN_CHECK(std::string("get rank size"), group_name, + hccl::HcclAdapter::GetInstance().HcclGetLocalRankSize(group_name, &rank_size)); + return rank_size; +} + +uint32_t AscendCollectiveCommLib::GetWorldRankFromGroupRank(const std::string &group_name, uint32_t local_rank) { + uint32_t world_rank_id = 0; + HCCL_RUN_CHECK( + std::string("get world rank id"), group_name, + hccl::HcclAdapter::GetInstance().HcclGetWorldRankFromGroupRank(group_name, local_rank, &world_rank_id)); + return world_rank_id; +} + +uint32_t AscendCollectiveCommLib::GetGroupRankFromWorldRank(uint32_t world_rank, const std::string &group_name) { + uint32_t local_rank_id = 0; + HCCL_RUN_CHECK( + std::string("get local rank id"), group_name, + hccl::HcclAdapter::GetInstance().HcclGetGroupRankFromWorldRank(world_rank, group_name, &local_rank_id)); + return local_rank_id; +} +} // namespace ascend + +using AscendCollectiveCommLib = mindspore::device::ascend::AscendCollectiveCommLib; + +CollectiveCommunicationLib *communication_lib_instance() { return &AscendCollectiveCommLib::GetInstance(); } +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_collective_comm_lib.h b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_collective_comm_lib.h index 686b3ea2415..b19d653f04d 100644 --- a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_collective_comm_lib.h +++ b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_collective_comm_lib.h @@ -1,81 +1,81 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_ASCEND_COLLECTIVE_COMM_LIB_H_ -#define MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_ASCEND_COLLECTIVE_COMM_LIB_H_ - -#include -#include -#include -#include -#include "runtime/collective/collective_communication_lib.h" -#include "plugin/device/ascend/hal/hardware/ascend_communication_group.h" - -#ifndef EXPORT_WRAPPER -#define EXPORT_WRAPPER __attribute__((visibility("default"))) -#endif - -namespace mindspore { -namespace device { -namespace ascend { -constexpr char kHCCLGlobalGroupName[] = "hccl_world_group"; - -class EXPORT_WRAPPER AscendCollectiveCommLib : public CollectiveCommunicationLib { - public: - static AscendCollectiveCommLib &GetInstance() { - static AscendCollectiveCommLib instance; - return instance; - } - - bool Initialize(uint32_t global_rank, uint32_t global_rank_size, uint32_t local_rank_id) override; - - bool InitializeHccl(); - - bool CreateCommunicationGroup(const std::string &group_name, const std::vector &group_ranks, - uint32_t local_group_rank, uint32_t local_group_size) override; - - bool CreateDeviceCommunicationGroup(const std::string &group_name, const std::vector &group_ranks) override; - - bool DestroyCommunicationGroup(const std::string &group_name) override; - - bool DestroyDeviceCommunicationGroup(const std::string &group_name) override; - - uint32_t GetRankId(const std::string &group_name) override; - - uint32_t GetGroupSize(const std::string &group_name) override; - - uint32_t GetLocalRankId(const std::string &group_name) override; - - uint32_t GetLocalGroupSize(const std::string &group_name) override; - - uint32_t GetWorldRankFromGroupRank(const std::string &group_name, uint32_t local_rank) override; - - uint32_t GetGroupRankFromWorldRank(uint32_t group_rank, const std::string &group_name) override; - - HcclComm HcclCommunicator(const std::string &group_name); - - std::string HcclInnerCommName(const std::string &group_name); - - bool DestroyHcclComm(); - - private: - AscendCollectiveCommLib(); - ~AscendCollectiveCommLib() override = default; -}; -} // namespace ascend -} // namespace device -} // namespace mindspore -#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_ASCEND_COLLECTIVE_COMM_LIB_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_ASCEND_COLLECTIVE_COMM_LIB_H_ +#define MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_ASCEND_COLLECTIVE_COMM_LIB_H_ + +#include +#include +#include +#include +#include "runtime/collective/collective_communication_lib.h" +#include "plugin/device/ascend/hal/hardware/ascend_communication_group.h" + +#ifndef EXPORT_WRAPPER +#define EXPORT_WRAPPER __attribute__((visibility("default"))) +#endif + +namespace mindspore { +namespace device { +namespace ascend { +constexpr char kHCCLGlobalGroupName[] = "hccl_world_group"; + +class EXPORT_WRAPPER AscendCollectiveCommLib : public CollectiveCommunicationLib { + public: + static AscendCollectiveCommLib &GetInstance() { + static AscendCollectiveCommLib instance; + return instance; + } + + bool Initialize(uint32_t global_rank, uint32_t global_rank_size, uint32_t local_rank_id) override; + + bool InitializeHccl(); + + bool CreateCommunicationGroup(const std::string &group_name, const std::vector &group_ranks, + uint32_t local_group_rank, uint32_t local_group_size) override; + + bool CreateDeviceCommunicationGroup(const std::string &group_name, const std::vector &group_ranks) override; + + bool DestroyCommunicationGroup(const std::string &group_name) override; + + bool DestroyDeviceCommunicationGroup(const std::string &group_name) override; + + uint32_t GetRankId(const std::string &group_name) override; + + uint32_t GetGroupSize(const std::string &group_name) override; + + uint32_t GetLocalRankId(const std::string &group_name) override; + + uint32_t GetLocalGroupSize(const std::string &group_name) override; + + uint32_t GetWorldRankFromGroupRank(const std::string &group_name, uint32_t local_rank) override; + + uint32_t GetGroupRankFromWorldRank(uint32_t group_rank, const std::string &group_name) override; + + HcclComm HcclCommunicator(const std::string &group_name); + + std::string HcclInnerCommName(const std::string &group_name); + + bool DestroyHcclComm(); + + private: + AscendCollectiveCommLib(); + ~AscendCollectiveCommLib() override = default; +}; +} // namespace ascend +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_ASCEND_COLLECTIVE_COMM_LIB_H_ diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_communication_group.cc b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_communication_group.cc index a8ee82cb99f..304a2ca0305 100644 --- a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_communication_group.cc +++ b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_communication_group.cc @@ -1,128 +1,128 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/ascend/hal/hardware/ascend_communication_group.h" -#include "plugin/device/ascend/hal/common/ascend_utils.h" -#include "plugin/device/ascend/hal/hccl_adapter/hccl_adapter.h" -#include "mindspore/core/utils/ms_context.h" -#include "transform/symbol/acl_rt_symbol.h" -#include "transform/symbol/acl_symbol.h" -#include "transform/symbol/symbol_utils.h" - -namespace mindspore { -namespace device { -namespace ascend { -AscendCommunicationGroup::AscendCommunicationGroup(const std::string &name, const std::vector &group_ranks, - uint32_t global_rank, uint32_t local_group_rank, - uint32_t local_group_size) - : CommunicationGroup(name, group_ranks, global_rank, local_group_rank, local_group_size), - unique_id_({}), - comm_(nullptr) { - (void)memset_s(inner_comm_name_, INNER_COMM_NAME_MAX_LENGTH, 0x00, INNER_COMM_NAME_MAX_LENGTH); -} - -bool AscendCommunicationGroup::Initialize(void *root_info) { - if (initialized_) { - return false; - } - if (hccl::HcclAdapter::GetInstance().UseHcclCM()) { - // If using hccl CM envs to launch distributed job, no need to call HcclCommInitRootInfo. The group will be - // initialized in rank table way. - initialized_ = true; - return true; - } - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - auto device_id = ms_context->get_param(MS_CTX_DEVICE_ID); - (void)CALL_ASCEND_API(aclrtSetDevice, device_id); - unique_id_ = *(static_cast(root_info)); - uint32_t group_rank; - auto group_size = size_; - if (!common::GetEnv(kSimulationLevel).empty()) { - group_size = 1; - group_rank = 0; - } else { - group_rank = GetGroupRank(global_rank_); - } - if (HcclCommInitRootInfo(static_cast(group_size), &unique_id_, static_cast(group_rank), &comm_) != - static_cast(HCCL_SUCCESS)) { - const string &error_message = ErrorManagerAdapter::GetErrorMessage(true); - MS_LOG(ERROR) << "HcclCommInitRootInfo failed. " + error_message; - return false; - } - // Get HCCL comm name which is used in graph sink mode for GE. - if (HcclGetCommName(comm_, inner_comm_name_) != static_cast(HCCL_SUCCESS)) { - const string &error_message = ErrorManagerAdapter::GetErrorMessage(true); - MS_LOG(ERROR) << "HcclGetCommName failed. " + error_message; - return false; - } - initialized_ = true; - (void)CALL_ASCEND_API(aclrtResetDevice, device_id); - return true; -} - -bool AscendCommunicationGroup::Finalize() { - if (!initialized_) { - return false; - } - if (hccl::HcclAdapter::GetInstance().UseHcclCM()) { - // If using hccl CM envs to launch distributed job, comm_ is not initialized. So directly return. - initialized_ = false; - return true; - } - - // This function will be called at a lonesome thread that has no rtContext, so HcclCommDestroy will be failed. - // Delete these codes when these threads can be bind. - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - auto device_id = ms_context->get_param(MS_CTX_DEVICE_ID); - (void)CALL_ASCEND_API(aclrtSetDevice, device_id); - RETURN_IF_FALSE_WITH_LOG(HcclCommDestroy(comm_) == static_cast(HCCL_SUCCESS), - "Failed to destroy HCCL communicator."); - (void)CALL_ASCEND_API(aclrtResetDevice, device_id); - initialized_ = false; - comm_ = nullptr; - return true; -} - -void *AscendCommunicationGroup::GenerateRootInfo(size_t *root_info_size) { - *root_info_size = sizeof(unique_id_); - if (!common::GetEnv(kSimulationLevel).empty() && !hccl::HcclAdapter::GetInstance().UseHcclCM()) { - if (HcclGetRootInfo(&unique_id_) != static_cast(HCCL_SUCCESS)) { - return nullptr; - } - return &unique_id_; - } - uint32_t group_rank = GetGroupRank(global_rank_); - if (group_rank == 0) { - if (hccl::HcclAdapter::GetInstance().UseHcclCM()) { - // If using hccl CM envs to launch distributed job, no need to call HcclGetRootInfo. - return &unique_id_; - } - if (HcclGetRootInfo(&unique_id_) != static_cast(HCCL_SUCCESS)) { - MS_LOG(ERROR) << "Failed to get HCCL unique id: " << CALL_ASCEND_API(aclGetRecentErrMsg); - return nullptr; - } - } - return &unique_id_; -} - -const HcclComm &AscendCommunicationGroup::hccl_communicator() const { return comm_; } - -std::string AscendCommunicationGroup::inner_comm_name() const { return inner_comm_name_; } -} // namespace ascend -} // namespace device -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/ascend/hal/hardware/ascend_communication_group.h" +#include "plugin/device/ascend/hal/common/ascend_utils.h" +#include "plugin/device/ascend/hal/hccl_adapter/hccl_adapter.h" +#include "mindspore/core/utils/ms_context.h" +#include "transform/symbol/acl_rt_symbol.h" +#include "transform/symbol/acl_symbol.h" +#include "transform/symbol/symbol_utils.h" + +namespace mindspore { +namespace device { +namespace ascend { +AscendCommunicationGroup::AscendCommunicationGroup(const std::string &name, const std::vector &group_ranks, + uint32_t global_rank, uint32_t local_group_rank, + uint32_t local_group_size) + : CommunicationGroup(name, group_ranks, global_rank, local_group_rank, local_group_size), + unique_id_({}), + comm_(nullptr) { + (void)memset_s(inner_comm_name_, INNER_COMM_NAME_MAX_LENGTH, 0x00, INNER_COMM_NAME_MAX_LENGTH); +} + +bool AscendCommunicationGroup::Initialize(void *root_info) { + if (initialized_) { + return false; + } + if (hccl::HcclAdapter::GetInstance().UseHcclCM()) { + // If using hccl CM envs to launch distributed job, no need to call HcclCommInitRootInfo. The group will be + // initialized in rank table way. + initialized_ = true; + return true; + } + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + auto device_id = ms_context->get_param(MS_CTX_DEVICE_ID); + (void)CALL_ASCEND_API(aclrtSetDevice, device_id); + unique_id_ = *(static_cast(root_info)); + uint32_t group_rank; + auto group_size = size_; + if (!common::GetEnv(kSimulationLevel).empty()) { + group_size = 1; + group_rank = 0; + } else { + group_rank = GetGroupRank(global_rank_); + } + if (HcclCommInitRootInfo(static_cast(group_size), &unique_id_, static_cast(group_rank), &comm_) != + static_cast(HCCL_SUCCESS)) { + const string &error_message = ErrorManagerAdapter::GetErrorMessage(true); + MS_LOG(ERROR) << "HcclCommInitRootInfo failed. " + error_message; + return false; + } + // Get HCCL comm name which is used in graph sink mode for GE. + if (HcclGetCommName(comm_, inner_comm_name_) != static_cast(HCCL_SUCCESS)) { + const string &error_message = ErrorManagerAdapter::GetErrorMessage(true); + MS_LOG(ERROR) << "HcclGetCommName failed. " + error_message; + return false; + } + initialized_ = true; + (void)CALL_ASCEND_API(aclrtResetDevice, device_id); + return true; +} + +bool AscendCommunicationGroup::Finalize() { + if (!initialized_) { + return false; + } + if (hccl::HcclAdapter::GetInstance().UseHcclCM()) { + // If using hccl CM envs to launch distributed job, comm_ is not initialized. So directly return. + initialized_ = false; + return true; + } + + // This function will be called at a lonesome thread that has no rtContext, so HcclCommDestroy will be failed. + // Delete these codes when these threads can be bind. + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + auto device_id = ms_context->get_param(MS_CTX_DEVICE_ID); + (void)CALL_ASCEND_API(aclrtSetDevice, device_id); + RETURN_IF_FALSE_WITH_LOG(HcclCommDestroy(comm_) == static_cast(HCCL_SUCCESS), + "Failed to destroy HCCL communicator."); + (void)CALL_ASCEND_API(aclrtResetDevice, device_id); + initialized_ = false; + comm_ = nullptr; + return true; +} + +void *AscendCommunicationGroup::GenerateRootInfo(size_t *root_info_size) { + *root_info_size = sizeof(unique_id_); + if (!common::GetEnv(kSimulationLevel).empty() && !hccl::HcclAdapter::GetInstance().UseHcclCM()) { + if (HcclGetRootInfo(&unique_id_) != static_cast(HCCL_SUCCESS)) { + return nullptr; + } + return &unique_id_; + } + uint32_t group_rank = GetGroupRank(global_rank_); + if (group_rank == 0) { + if (hccl::HcclAdapter::GetInstance().UseHcclCM()) { + // If using hccl CM envs to launch distributed job, no need to call HcclGetRootInfo. + return &unique_id_; + } + if (HcclGetRootInfo(&unique_id_) != static_cast(HCCL_SUCCESS)) { + MS_LOG(ERROR) << "Failed to get HCCL unique id: " << CALL_ASCEND_API(aclGetRecentErrMsg); + return nullptr; + } + } + return &unique_id_; +} + +const HcclComm &AscendCommunicationGroup::hccl_communicator() const { return comm_; } + +std::string AscendCommunicationGroup::inner_comm_name() const { return inner_comm_name_; } +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_communication_group.h b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_communication_group.h index 0bb46d9dac4..10bacb9b8f5 100644 --- a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_communication_group.h +++ b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_communication_group.h @@ -1,64 +1,64 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_ASCEND_COMMUNICATION_GROUP_H_ -#define MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_ASCEND_COMMUNICATION_GROUP_H_ - -#include -#include -#include -#include "hccl/hccl.h" -#include "runtime/collective/communication_group.h" -#include "utils/dlopen_macro.h" - -namespace mindspore { -namespace device { -namespace ascend { -// Confirmed by HCCL max length of hccl comm name is 128. -constexpr int INNER_COMM_NAME_MAX_LENGTH = 128; - -class AscendCommunicationGroup : public CommunicationGroup { - public: - explicit AscendCommunicationGroup(const std::string &name, const std::vector &group_ranks, - uint32_t global_rank, uint32_t local_group_rank, uint32_t local_group_size); - - ~AscendCommunicationGroup() override = default; - - bool Initialize(void *root_info) override; - bool Finalize() override; - - void *GenerateRootInfo(size_t *root_info_size) override; - - // Return HCCL communicator because collective operations need it as a input. - const HcclComm &hccl_communicator() const; - - // Return communicator name maintained by HCCL. This is different from the group set by user. - std::string inner_comm_name() const; - - private: - // The HCCL unique id for this group. Used to initialize this group's communicator. - HcclRootInfo unique_id_; - - // HCCL communicator of this group. - HcclComm comm_; - - char inner_comm_name_[INNER_COMM_NAME_MAX_LENGTH]; -}; -using AscendCommunicationGroupPtr = std::shared_ptr; -} // namespace ascend -} // namespace device -} // namespace mindspore -#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_ASCEND_COMMUNICATION_GROUP_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_ASCEND_COMMUNICATION_GROUP_H_ +#define MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_ASCEND_COMMUNICATION_GROUP_H_ + +#include +#include +#include +#include "hccl/hccl.h" +#include "runtime/collective/communication_group.h" +#include "utils/dlopen_macro.h" + +namespace mindspore { +namespace device { +namespace ascend { +// Confirmed by HCCL max length of hccl comm name is 128. +constexpr int INNER_COMM_NAME_MAX_LENGTH = 128; + +class AscendCommunicationGroup : public CommunicationGroup { + public: + explicit AscendCommunicationGroup(const std::string &name, const std::vector &group_ranks, + uint32_t global_rank, uint32_t local_group_rank, uint32_t local_group_size); + + ~AscendCommunicationGroup() override = default; + + bool Initialize(void *root_info) override; + bool Finalize() override; + + void *GenerateRootInfo(size_t *root_info_size) override; + + // Return HCCL communicator because collective operations need it as a input. + const HcclComm &hccl_communicator() const; + + // Return communicator name maintained by HCCL. This is different from the group set by user. + std::string inner_comm_name() const; + + private: + // The HCCL unique id for this group. Used to initialize this group's communicator. + HcclRootInfo unique_id_; + + // HCCL communicator of this group. + HcclComm comm_; + + char inner_comm_name_[INNER_COMM_NAME_MAX_LENGTH]; +}; +using AscendCommunicationGroupPtr = std::shared_ptr; +} // namespace ascend +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_ASCEND_COMMUNICATION_GROUP_H_ diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/lowlatency_collective_comm_lib.cc b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/lowlatency_collective_comm_lib.cc index a828a499eca..af4d4c61a4a 100644 --- a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/lowlatency_collective_comm_lib.cc +++ b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/lowlatency_collective_comm_lib.cc @@ -1,163 +1,163 @@ -/** - * Copyright 2024 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/ascend/hal/hardware/lowlatency_collective_comm_lib.h" - -namespace mindspore { -namespace device { -namespace ascend { -LowlatencyCollectiveCommLib::LowlatencyCollectiveCommLib() { global_group_name_ = kLCCLGlobalGroupName; } - -bool LowlatencyCollectiveCommLib::Initialize(uint32_t global_rank, uint32_t global_rank_size, uint32_t local_rank_id) { - if (initialized_) { - return false; - } - - global_rank_id_ = global_rank; - global_rank_size_ = global_rank_size; - local_rank_id_ = local_rank_id; - initialized_ = true; - finalized_ = false; - return true; -} - -bool LowlatencyCollectiveCommLib::CreateCommunicationGroup(const std::string &group_name, - const std::vector &group_ranks, - uint32_t local_group_rank, uint32_t local_group_size) { - CHECK_RET((groups_.count(group_name) == 0), true, "The LCCL group " + group_name + " has already existed."); - - LowlatencyCommunicationGroupPtr group = std::make_shared( - group_name, group_ranks, global_rank_id_, local_group_rank, local_group_size); - CHECK_IF_NULL(group); - groups_[group_name] = group; - return true; -} - -int LowlatencyCollectiveCommLib::AllReduce(const LcclPtr &lccl_ptr, void *send_buff, void *recv_buff, size_t count, - HcclDataType data_type, const HcclReduceOp reduce_op, - const aclrtStream stream) { - return lccl_ptr->AllReduce(send_buff, recv_buff, count, data_type, reduce_op, stream); -} - -int LowlatencyCollectiveCommLib::AllGather(const LcclPtr &lccl_ptr, void *send_buff, void *recv_buff, size_t count, - HcclDataType data_type, const aclrtStream stream) { - return lccl_ptr->AllGather(send_buff, recv_buff, count, data_type, stream); -} - -int LowlatencyCollectiveCommLib::ReduceScatter(const LcclPtr &lccl_ptr, void *send_buff, void *recv_buff, size_t count, - HcclDataType data_type, const HcclReduceOp reduce_op, - const aclrtStream stream) { - return lccl_ptr->ReduceScatter(send_buff, recv_buff, count, data_type, reduce_op, stream); -} - -int LowlatencyCollectiveCommLib::All2All(const LcclPtr &lccl_ptr, void *send_buff, void *recv_buff, size_t count, - HcclDataType data_type, const aclrtStream stream) { - return lccl_ptr->All2All(send_buff, recv_buff, count, data_type, stream); -} - -int LowlatencyCollectiveCommLib::Broadcast(const LcclPtr &lccl_ptr, void *buff, size_t count, HcclDataType data_type, - int root, const aclrtStream stream) { - return lccl_ptr->Broadcast(buff, count, data_type, root, stream); -} - -int LowlatencyCollectiveCommLib::MatmulAllReduce(const LcocPtr &lcoc_ptr, const CoCInputPkg &input_pkg, - const CoCOutputPkg &output_pkg, void *workspace, - const aclrtStream stream) { - return lcoc_ptr->MatmulAllReduce(input_pkg, output_pkg, workspace, stream); -} - -LcclPtr LowlatencyCollectiveCommLib::LcclCommunicator(const std::string &group_name) { - CHECK_RET((groups_.count(group_name) != 0), true, "The LCCL group " + group_name + " does not existed."); - auto group = std::dynamic_pointer_cast(groups_[group_name]); - CHECK_IF_NULL(group); - return group->lccl_communicator(); -} - -LcocPtr LowlatencyCollectiveCommLib::CreateLcocForOp(const std::string &group_name) { - CHECK_RET((groups_.count(group_name) != 0), true, "The LCCL group " + group_name + " does not existed."); - auto group = std::dynamic_pointer_cast(groups_[group_name]); - CHECK_IF_NULL(group); - - LcalCommPtr lcal_comm = group->lcal_comm(); - CHECK_IF_NULL(lcal_comm); - LcocPtr lcoc_ptr = std::make_shared(*(lcal_comm.get())); - return lcoc_ptr; -} - -void LowlatencyCollectiveCommLib::SetParamForLcoc(const LcocPtr &lcoc_ptr, LcalType lcal_type, const CoCTiling &tiling, - const CoCParamDesc ¶m_desc) { - lcoc_ptr->SetParam(lcal_type, tiling, param_desc); -} - -int64_t LowlatencyCollectiveCommLib::GetLcocWorkspaceSize(const LcocPtr &lcoc_ptr) { - return lcoc_ptr->GetWorkspaceSize(); -} -} // namespace ascend - -using LowlatencyCollectiveCommLib = mindspore::device::ascend::LowlatencyCollectiveCommLib; - -CollectiveCommunicationLib *communication_lib_instance() { return &LowlatencyCollectiveCommLib::GetInstance(); } - -int AllReduce(const LcclPtr &lccl_ptr, void *send_buff, void *recv_buff, size_t count, HcclDataType data_type, - const HcclReduceOp reduce_op, const aclrtStream stream) { - return LowlatencyCollectiveCommLib::GetInstance().AllReduce(lccl_ptr, send_buff, recv_buff, count, data_type, - reduce_op, stream); -} - -int AllGather(const LcclPtr &lccl_ptr, void *send_buff, void *recv_buff, size_t count, HcclDataType data_type, - const aclrtStream stream) { - return LowlatencyCollectiveCommLib::GetInstance().AllGather(lccl_ptr, send_buff, recv_buff, count, data_type, stream); -} - -int ReduceScatter(const LcclPtr &lccl_ptr, void *send_buff, void *recv_buff, size_t count, HcclDataType data_type, - const HcclReduceOp reduce_op, const aclrtStream stream) { - return LowlatencyCollectiveCommLib::GetInstance().ReduceScatter(lccl_ptr, send_buff, recv_buff, count, data_type, - reduce_op, stream); -} - -int All2All(const LcclPtr &lccl_ptr, void *send_buff, void *recv_buff, size_t count, HcclDataType data_type, - const aclrtStream stream) { - return LowlatencyCollectiveCommLib::GetInstance().All2All(lccl_ptr, send_buff, recv_buff, count, data_type, stream); -} - -int Broadcast(const LcclPtr &lccl_ptr, void *buff, size_t count, HcclDataType data_type, int root, - const aclrtStream stream) { - return LowlatencyCollectiveCommLib::GetInstance().Broadcast(lccl_ptr, buff, count, data_type, root, stream); -} - -int MatmulAllReduce(const LcocPtr &lcoc_ptr, const CoCInputPkg &input_pkg, const CoCOutputPkg &output_pkg, - void *workspace, const aclrtStream stream) { - return LowlatencyCollectiveCommLib::GetInstance().MatmulAllReduce(lcoc_ptr, input_pkg, output_pkg, workspace, stream); -} - -LcclPtr LcclCommunicator(const std::string &group_name) { - return LowlatencyCollectiveCommLib::GetInstance().LcclCommunicator(group_name); -} - -LcocPtr CreateLcocForOp(const std::string &group_name) { - return LowlatencyCollectiveCommLib::GetInstance().CreateLcocForOp(group_name); -} - -void SetParamForLcoc(const LcocPtr &lcoc_ptr, LcalType lcal_type, const CoCTiling &tiling, - const CoCParamDesc ¶m_desc) { - LowlatencyCollectiveCommLib::GetInstance().SetParamForLcoc(lcoc_ptr, lcal_type, tiling, param_desc); -} - -int64_t GetLcocWorkspaceSize(const LcocPtr &lcoc_ptr) { - return LowlatencyCollectiveCommLib::GetInstance().GetLcocWorkspaceSize(lcoc_ptr); -} -} // namespace device -} // namespace mindspore +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/ascend/hal/hardware/lowlatency_collective_comm_lib.h" + +namespace mindspore { +namespace device { +namespace ascend { +LowlatencyCollectiveCommLib::LowlatencyCollectiveCommLib() { global_group_name_ = kLCCLGlobalGroupName; } + +bool LowlatencyCollectiveCommLib::Initialize(uint32_t global_rank, uint32_t global_rank_size, uint32_t local_rank_id) { + if (initialized_) { + return false; + } + + global_rank_id_ = global_rank; + global_rank_size_ = global_rank_size; + local_rank_id_ = local_rank_id; + initialized_ = true; + finalized_ = false; + return true; +} + +bool LowlatencyCollectiveCommLib::CreateCommunicationGroup(const std::string &group_name, + const std::vector &group_ranks, + uint32_t local_group_rank, uint32_t local_group_size) { + CHECK_RET((groups_.count(group_name) == 0), true, "The LCCL group " + group_name + " has already existed."); + + LowlatencyCommunicationGroupPtr group = std::make_shared( + group_name, group_ranks, global_rank_id_, local_group_rank, local_group_size); + CHECK_IF_NULL(group); + groups_[group_name] = group; + return true; +} + +int LowlatencyCollectiveCommLib::AllReduce(const LcclPtr &lccl_ptr, void *send_buff, void *recv_buff, size_t count, + HcclDataType data_type, const HcclReduceOp reduce_op, + const aclrtStream stream) { + return lccl_ptr->AllReduce(send_buff, recv_buff, count, data_type, reduce_op, stream); +} + +int LowlatencyCollectiveCommLib::AllGather(const LcclPtr &lccl_ptr, void *send_buff, void *recv_buff, size_t count, + HcclDataType data_type, const aclrtStream stream) { + return lccl_ptr->AllGather(send_buff, recv_buff, count, data_type, stream); +} + +int LowlatencyCollectiveCommLib::ReduceScatter(const LcclPtr &lccl_ptr, void *send_buff, void *recv_buff, size_t count, + HcclDataType data_type, const HcclReduceOp reduce_op, + const aclrtStream stream) { + return lccl_ptr->ReduceScatter(send_buff, recv_buff, count, data_type, reduce_op, stream); +} + +int LowlatencyCollectiveCommLib::All2All(const LcclPtr &lccl_ptr, void *send_buff, void *recv_buff, size_t count, + HcclDataType data_type, const aclrtStream stream) { + return lccl_ptr->All2All(send_buff, recv_buff, count, data_type, stream); +} + +int LowlatencyCollectiveCommLib::Broadcast(const LcclPtr &lccl_ptr, void *buff, size_t count, HcclDataType data_type, + int root, const aclrtStream stream) { + return lccl_ptr->Broadcast(buff, count, data_type, root, stream); +} + +int LowlatencyCollectiveCommLib::MatmulAllReduce(const LcocPtr &lcoc_ptr, const CoCInputPkg &input_pkg, + const CoCOutputPkg &output_pkg, void *workspace, + const aclrtStream stream) { + return lcoc_ptr->MatmulAllReduce(input_pkg, output_pkg, workspace, stream); +} + +LcclPtr LowlatencyCollectiveCommLib::LcclCommunicator(const std::string &group_name) { + CHECK_RET((groups_.count(group_name) != 0), true, "The LCCL group " + group_name + " does not existed."); + auto group = std::dynamic_pointer_cast(groups_[group_name]); + CHECK_IF_NULL(group); + return group->lccl_communicator(); +} + +LcocPtr LowlatencyCollectiveCommLib::CreateLcocForOp(const std::string &group_name) { + CHECK_RET((groups_.count(group_name) != 0), true, "The LCCL group " + group_name + " does not existed."); + auto group = std::dynamic_pointer_cast(groups_[group_name]); + CHECK_IF_NULL(group); + + LcalCommPtr lcal_comm = group->lcal_comm(); + CHECK_IF_NULL(lcal_comm); + LcocPtr lcoc_ptr = std::make_shared(*(lcal_comm.get())); + return lcoc_ptr; +} + +void LowlatencyCollectiveCommLib::SetParamForLcoc(const LcocPtr &lcoc_ptr, LcalType lcal_type, const CoCTiling &tiling, + const CoCParamDesc ¶m_desc) { + lcoc_ptr->SetParam(lcal_type, tiling, param_desc); +} + +int64_t LowlatencyCollectiveCommLib::GetLcocWorkspaceSize(const LcocPtr &lcoc_ptr) { + return lcoc_ptr->GetWorkspaceSize(); +} +} // namespace ascend + +using LowlatencyCollectiveCommLib = mindspore::device::ascend::LowlatencyCollectiveCommLib; + +CollectiveCommunicationLib *communication_lib_instance() { return &LowlatencyCollectiveCommLib::GetInstance(); } + +int AllReduce(const LcclPtr &lccl_ptr, void *send_buff, void *recv_buff, size_t count, HcclDataType data_type, + const HcclReduceOp reduce_op, const aclrtStream stream) { + return LowlatencyCollectiveCommLib::GetInstance().AllReduce(lccl_ptr, send_buff, recv_buff, count, data_type, + reduce_op, stream); +} + +int AllGather(const LcclPtr &lccl_ptr, void *send_buff, void *recv_buff, size_t count, HcclDataType data_type, + const aclrtStream stream) { + return LowlatencyCollectiveCommLib::GetInstance().AllGather(lccl_ptr, send_buff, recv_buff, count, data_type, stream); +} + +int ReduceScatter(const LcclPtr &lccl_ptr, void *send_buff, void *recv_buff, size_t count, HcclDataType data_type, + const HcclReduceOp reduce_op, const aclrtStream stream) { + return LowlatencyCollectiveCommLib::GetInstance().ReduceScatter(lccl_ptr, send_buff, recv_buff, count, data_type, + reduce_op, stream); +} + +int All2All(const LcclPtr &lccl_ptr, void *send_buff, void *recv_buff, size_t count, HcclDataType data_type, + const aclrtStream stream) { + return LowlatencyCollectiveCommLib::GetInstance().All2All(lccl_ptr, send_buff, recv_buff, count, data_type, stream); +} + +int Broadcast(const LcclPtr &lccl_ptr, void *buff, size_t count, HcclDataType data_type, int root, + const aclrtStream stream) { + return LowlatencyCollectiveCommLib::GetInstance().Broadcast(lccl_ptr, buff, count, data_type, root, stream); +} + +int MatmulAllReduce(const LcocPtr &lcoc_ptr, const CoCInputPkg &input_pkg, const CoCOutputPkg &output_pkg, + void *workspace, const aclrtStream stream) { + return LowlatencyCollectiveCommLib::GetInstance().MatmulAllReduce(lcoc_ptr, input_pkg, output_pkg, workspace, stream); +} + +LcclPtr LcclCommunicator(const std::string &group_name) { + return LowlatencyCollectiveCommLib::GetInstance().LcclCommunicator(group_name); +} + +LcocPtr CreateLcocForOp(const std::string &group_name) { + return LowlatencyCollectiveCommLib::GetInstance().CreateLcocForOp(group_name); +} + +void SetParamForLcoc(const LcocPtr &lcoc_ptr, LcalType lcal_type, const CoCTiling &tiling, + const CoCParamDesc ¶m_desc) { + LowlatencyCollectiveCommLib::GetInstance().SetParamForLcoc(lcoc_ptr, lcal_type, tiling, param_desc); +} + +int64_t GetLcocWorkspaceSize(const LcocPtr &lcoc_ptr) { + return LowlatencyCollectiveCommLib::GetInstance().GetLcocWorkspaceSize(lcoc_ptr); +} +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/lowlatency_collective_comm_lib.h b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/lowlatency_collective_comm_lib.h index 75a856adb03..81b98181f73 100644 --- a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/lowlatency_collective_comm_lib.h +++ b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/lowlatency_collective_comm_lib.h @@ -1,123 +1,123 @@ -/** - * Copyright 2024 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_LOWLATENCY_COLLECTIVE_COMM_LIB_H_ -#define MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_LOWLATENCY_COLLECTIVE_COMM_LIB_H_ - -#include -#include -#include -#include -#include "runtime/collective/collective_communication_lib.h" -#include "plugin/device/ascend/hal/hardware/lowlatency_communication_group.h" - -#ifndef EXPORT_WRAPPER -#define EXPORT_WRAPPER __attribute__((visibility("default"))) -#endif - -namespace mindspore { -namespace device { -namespace ascend { -constexpr char kLCCLGlobalGroupName[] = "hccl_world_group"; - -// Low-latency collective communication libaray is implemented on Ascend platform. So some HCCL data types could be -// reused. -class EXPORT_WRAPPER LowlatencyCollectiveCommLib : public CollectiveCommunicationLib { - public: - static LowlatencyCollectiveCommLib &GetInstance() { - static LowlatencyCollectiveCommLib instance; - return instance; - } - - bool Initialize(uint32_t global_rank, uint32_t global_rank_size, uint32_t local_rank_id) override; - - bool CreateCommunicationGroup(const std::string &group_name, const std::vector &group_ranks, - uint32_t local_group_rank, uint32_t local_group_size) override; - - int AllReduce(const LcclPtr &lccl_ptr, void *send_buff, void *recv_buff, size_t count, HcclDataType dataType, - const HcclReduceOp op, const aclrtStream stream); - - int AllGather(const LcclPtr &lccl_ptr, void *send_buff, void *recv_buff, size_t count, HcclDataType dataType, - const aclrtStream stream); - - int ReduceScatter(const LcclPtr &lccl_ptr, void *send_buff, void *recv_buff, size_t count, HcclDataType dataType, - const HcclReduceOp op, const aclrtStream stream); - - int All2All(const LcclPtr &lccl_ptr, void *send_buff, void *recv_buff, size_t count, HcclDataType dataType, - const aclrtStream stream); - - int Broadcast(const LcclPtr &lccl_ptr, void *buff, size_t count, HcclDataType dataType, int root, - const aclrtStream stream); - - int MatmulAllReduce(const LcocPtr &lcoc_ptr, const CoCInputPkg &input_pkg, const CoCOutputPkg &output_pkg, - void *workspace, const aclrtStream stream); - - // Return lccl communicator so that caller could pass this communicator to communication APIs. - LcclPtr LcclCommunicator(const std::string &group_name); - - // For lcoc operations, lcoc object should be created for each operator so performance could be optimal. - LcocPtr CreateLcocForOp(const std::string &group_name); - - // Must set coc parameters before calling lcoc operators. - void SetParamForLcoc(const LcocPtr &lcoc_ptr, LcalType lcal_type, const CoCTiling &tiling, - const CoCParamDesc ¶m_desc); - - // Lcoc operators need workspace with size returned by lcoc object. - int64_t GetLcocWorkspaceSize(const LcocPtr &lcoc_ptr); - - private: - LowlatencyCollectiveCommLib(); - ~LowlatencyCollectiveCommLib() override = default; -}; -} // namespace ascend - -extern "C" EXPORT_WRAPPER CollectiveCommunicationLib *communication_lib_instance(); -extern "C" EXPORT_WRAPPER int AllReduce(const LcclPtr &lccl_ptr, void *send_buff, void *recv_buff, size_t count, - HcclDataType data_type, const HcclReduceOp reduce_op, const aclrtStream stream); -extern "C" EXPORT_WRAPPER int AllGather(const LcclPtr &lccl_ptr, void *send_buff, void *recv_buff, size_t count, - HcclDataType data_type, const aclrtStream stream); -extern "C" EXPORT_WRAPPER int ReduceScatter(const LcclPtr &lccl_ptr, void *send_buff, void *recv_buff, size_t count, - HcclDataType data_type, const HcclReduceOp reduce_op, - const aclrtStream stream); -extern "C" EXPORT_WRAPPER int All2All(const LcclPtr &lccl_ptr, void *send_buff, void *recv_buff, size_t count, - HcclDataType data_type, const aclrtStream stream); -extern "C" EXPORT_WRAPPER int Broadcast(const LcclPtr &lccl_ptr, void *buff, size_t count, HcclDataType data_type, - int root, const aclrtStream stream); -extern "C" EXPORT_WRAPPER int MatmulAllReduce(const LcocPtr &lcoc_ptr, const CoCInputPkg &input_pkg, - const CoCOutputPkg &output_pkg, void *workspace, - const aclrtStream stream); -extern "C" EXPORT_WRAPPER LcclPtr LcclCommunicator(const std::string &group_name); -extern "C" EXPORT_WRAPPER LcocPtr CreateLcocForOp(const std::string &group_name); -extern "C" EXPORT_WRAPPER void SetParamForLcoc(const LcocPtr &lcoc_ptr, LcalType lcal_type, const CoCTiling &tiling, - const CoCParamDesc ¶m_desc); -extern "C" EXPORT_WRAPPER int64_t GetLcocWorkspaceSize(const LcocPtr &lcoc_ptr); -} // namespace device -} // namespace mindspore - -ORIGIN_METHOD(AllReduce, int, const LcclPtr &, void *, void *, size_t, HcclDataType, const HcclReduceOp, - const aclrtStream) -ORIGIN_METHOD(AllGather, int, const LcclPtr &, void *, void *, size_t, HcclDataType, const aclrtStream) -ORIGIN_METHOD(ReduceScatter, int, const LcclPtr &, void *, void *, size_t, HcclDataType, const HcclReduceOp, - const aclrtStream) -ORIGIN_METHOD(All2All, int, const LcclPtr &, void *, void *, size_t, HcclDataType, const aclrtStream) -ORIGIN_METHOD(Broadcast, int, const LcclPtr &, void *, size_t, HcclDataType, int, const aclrtStream) -ORIGIN_METHOD(MatmulAllReduce, int, const LcocPtr &, const CoCInputPkg &, const CoCOutputPkg &, void *, - const aclrtStream) -ORIGIN_METHOD(LcclCommunicator, LcclPtr, const std::string &); -ORIGIN_METHOD(CreateLcocForOp, LcocPtr, const std::string &); -ORIGIN_METHOD(SetParamForLcoc, void, const LcocPtr &, LcalType, const CoCTiling &, const CoCParamDesc &); -ORIGIN_METHOD(GetLcocWorkspaceSize, int64_t, const LcocPtr &); -#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_LOWLATENCY_COLLECTIVE_COMM_LIB_H_ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_LOWLATENCY_COLLECTIVE_COMM_LIB_H_ +#define MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_LOWLATENCY_COLLECTIVE_COMM_LIB_H_ + +#include +#include +#include +#include +#include "runtime/collective/collective_communication_lib.h" +#include "plugin/device/ascend/hal/hardware/lowlatency_communication_group.h" + +#ifndef EXPORT_WRAPPER +#define EXPORT_WRAPPER __attribute__((visibility("default"))) +#endif + +namespace mindspore { +namespace device { +namespace ascend { +constexpr char kLCCLGlobalGroupName[] = "hccl_world_group"; + +// Low-latency collective communication libaray is implemented on Ascend platform. So some HCCL data types could be +// reused. +class EXPORT_WRAPPER LowlatencyCollectiveCommLib : public CollectiveCommunicationLib { + public: + static LowlatencyCollectiveCommLib &GetInstance() { + static LowlatencyCollectiveCommLib instance; + return instance; + } + + bool Initialize(uint32_t global_rank, uint32_t global_rank_size, uint32_t local_rank_id) override; + + bool CreateCommunicationGroup(const std::string &group_name, const std::vector &group_ranks, + uint32_t local_group_rank, uint32_t local_group_size) override; + + int AllReduce(const LcclPtr &lccl_ptr, void *send_buff, void *recv_buff, size_t count, HcclDataType dataType, + const HcclReduceOp op, const aclrtStream stream); + + int AllGather(const LcclPtr &lccl_ptr, void *send_buff, void *recv_buff, size_t count, HcclDataType dataType, + const aclrtStream stream); + + int ReduceScatter(const LcclPtr &lccl_ptr, void *send_buff, void *recv_buff, size_t count, HcclDataType dataType, + const HcclReduceOp op, const aclrtStream stream); + + int All2All(const LcclPtr &lccl_ptr, void *send_buff, void *recv_buff, size_t count, HcclDataType dataType, + const aclrtStream stream); + + int Broadcast(const LcclPtr &lccl_ptr, void *buff, size_t count, HcclDataType dataType, int root, + const aclrtStream stream); + + int MatmulAllReduce(const LcocPtr &lcoc_ptr, const CoCInputPkg &input_pkg, const CoCOutputPkg &output_pkg, + void *workspace, const aclrtStream stream); + + // Return lccl communicator so that caller could pass this communicator to communication APIs. + LcclPtr LcclCommunicator(const std::string &group_name); + + // For lcoc operations, lcoc object should be created for each operator so performance could be optimal. + LcocPtr CreateLcocForOp(const std::string &group_name); + + // Must set coc parameters before calling lcoc operators. + void SetParamForLcoc(const LcocPtr &lcoc_ptr, LcalType lcal_type, const CoCTiling &tiling, + const CoCParamDesc ¶m_desc); + + // Lcoc operators need workspace with size returned by lcoc object. + int64_t GetLcocWorkspaceSize(const LcocPtr &lcoc_ptr); + + private: + LowlatencyCollectiveCommLib(); + ~LowlatencyCollectiveCommLib() override = default; +}; +} // namespace ascend + +extern "C" EXPORT_WRAPPER CollectiveCommunicationLib *communication_lib_instance(); +extern "C" EXPORT_WRAPPER int AllReduce(const LcclPtr &lccl_ptr, void *send_buff, void *recv_buff, size_t count, + HcclDataType data_type, const HcclReduceOp reduce_op, const aclrtStream stream); +extern "C" EXPORT_WRAPPER int AllGather(const LcclPtr &lccl_ptr, void *send_buff, void *recv_buff, size_t count, + HcclDataType data_type, const aclrtStream stream); +extern "C" EXPORT_WRAPPER int ReduceScatter(const LcclPtr &lccl_ptr, void *send_buff, void *recv_buff, size_t count, + HcclDataType data_type, const HcclReduceOp reduce_op, + const aclrtStream stream); +extern "C" EXPORT_WRAPPER int All2All(const LcclPtr &lccl_ptr, void *send_buff, void *recv_buff, size_t count, + HcclDataType data_type, const aclrtStream stream); +extern "C" EXPORT_WRAPPER int Broadcast(const LcclPtr &lccl_ptr, void *buff, size_t count, HcclDataType data_type, + int root, const aclrtStream stream); +extern "C" EXPORT_WRAPPER int MatmulAllReduce(const LcocPtr &lcoc_ptr, const CoCInputPkg &input_pkg, + const CoCOutputPkg &output_pkg, void *workspace, + const aclrtStream stream); +extern "C" EXPORT_WRAPPER LcclPtr LcclCommunicator(const std::string &group_name); +extern "C" EXPORT_WRAPPER LcocPtr CreateLcocForOp(const std::string &group_name); +extern "C" EXPORT_WRAPPER void SetParamForLcoc(const LcocPtr &lcoc_ptr, LcalType lcal_type, const CoCTiling &tiling, + const CoCParamDesc ¶m_desc); +extern "C" EXPORT_WRAPPER int64_t GetLcocWorkspaceSize(const LcocPtr &lcoc_ptr); +} // namespace device +} // namespace mindspore + +ORIGIN_METHOD(AllReduce, int, const LcclPtr &, void *, void *, size_t, HcclDataType, const HcclReduceOp, + const aclrtStream) +ORIGIN_METHOD(AllGather, int, const LcclPtr &, void *, void *, size_t, HcclDataType, const aclrtStream) +ORIGIN_METHOD(ReduceScatter, int, const LcclPtr &, void *, void *, size_t, HcclDataType, const HcclReduceOp, + const aclrtStream) +ORIGIN_METHOD(All2All, int, const LcclPtr &, void *, void *, size_t, HcclDataType, const aclrtStream) +ORIGIN_METHOD(Broadcast, int, const LcclPtr &, void *, size_t, HcclDataType, int, const aclrtStream) +ORIGIN_METHOD(MatmulAllReduce, int, const LcocPtr &, const CoCInputPkg &, const CoCOutputPkg &, void *, + const aclrtStream) +ORIGIN_METHOD(LcclCommunicator, LcclPtr, const std::string &); +ORIGIN_METHOD(CreateLcocForOp, LcocPtr, const std::string &); +ORIGIN_METHOD(SetParamForLcoc, void, const LcocPtr &, LcalType, const CoCTiling &, const CoCParamDesc &); +ORIGIN_METHOD(GetLcocWorkspaceSize, int64_t, const LcocPtr &); +#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_LOWLATENCY_COLLECTIVE_COMM_LIB_H_ diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/lowlatency_communication_group.cc b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/lowlatency_communication_group.cc index f713a8cacf9..f3e87705165 100644 --- a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/lowlatency_communication_group.cc +++ b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/lowlatency_communication_group.cc @@ -1,66 +1,66 @@ -/** - * Copyright 2024 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/ascend/hal/hardware/lowlatency_communication_group.h" - -namespace mindspore { -namespace device { -namespace ascend { -LowlatencyCommunicationGroup::LowlatencyCommunicationGroup(const std::string &name, - const std::vector &group_ranks, - uint32_t global_rank, uint32_t local_group_rank, - uint32_t local_group_size) - : CommunicationGroup(name, group_ranks, global_rank, local_group_rank, local_group_size), - lcal_comm_(nullptr), - lccl_comm_(nullptr) {} - -bool LowlatencyCommunicationGroup::Initialize(void *root_info) { - if (initialized_) { - return true; - } - auto ret = aclrtSetDevice(local_group_rank_); - if (ret != ACL_RT_SUCCESS) { - return false; - } - uint32_t group_rank = GetGroupRank(global_rank_); - lcal_comm_ = std::make_shared(group_rank, size_); - if (lcal_comm_->Init() != LCAL_SUCCESS) { - return false; - } - lccl_comm_ = std::make_shared(*(lcal_comm_.get())); - initialized_ = true; - return true; -} - -bool LowlatencyCommunicationGroup::Finalize() { - if (!initialized_) { - return true; - } - initialized_ = false; - return true; -} - -void *LowlatencyCommunicationGroup::GenerateRootInfo(size_t *root_info_size) { - *root_info_size = sizeof(size_t); - return root_info_size; -} - -const LcclPtr &LowlatencyCommunicationGroup::lccl_communicator() const { return lccl_comm_; } - -const LcalCommPtr &LowlatencyCommunicationGroup::lcal_comm() const { return lcal_comm_; } -} // namespace ascend -} // namespace device -} // namespace mindspore +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/ascend/hal/hardware/lowlatency_communication_group.h" + +namespace mindspore { +namespace device { +namespace ascend { +LowlatencyCommunicationGroup::LowlatencyCommunicationGroup(const std::string &name, + const std::vector &group_ranks, + uint32_t global_rank, uint32_t local_group_rank, + uint32_t local_group_size) + : CommunicationGroup(name, group_ranks, global_rank, local_group_rank, local_group_size), + lcal_comm_(nullptr), + lccl_comm_(nullptr) {} + +bool LowlatencyCommunicationGroup::Initialize(void *root_info) { + if (initialized_) { + return true; + } + auto ret = aclrtSetDevice(local_group_rank_); + if (ret != ACL_RT_SUCCESS) { + return false; + } + uint32_t group_rank = GetGroupRank(global_rank_); + lcal_comm_ = std::make_shared(group_rank, size_); + if (lcal_comm_->Init() != LCAL_SUCCESS) { + return false; + } + lccl_comm_ = std::make_shared(*(lcal_comm_.get())); + initialized_ = true; + return true; +} + +bool LowlatencyCommunicationGroup::Finalize() { + if (!initialized_) { + return true; + } + initialized_ = false; + return true; +} + +void *LowlatencyCommunicationGroup::GenerateRootInfo(size_t *root_info_size) { + *root_info_size = sizeof(size_t); + return root_info_size; +} + +const LcclPtr &LowlatencyCommunicationGroup::lccl_communicator() const { return lccl_comm_; } + +const LcalCommPtr &LowlatencyCommunicationGroup::lcal_comm() const { return lcal_comm_; } +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/lowlatency_communication_group.h b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/lowlatency_communication_group.h index da1aa7ee042..7ea26be6c96 100644 --- a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/lowlatency_communication_group.h +++ b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/lowlatency_communication_group.h @@ -1,65 +1,65 @@ -/** - * Copyright 2024 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_LOWLATENCY_COMMUNICATION_GROUP_H_ -#define MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_LOWLATENCY_COMMUNICATION_GROUP_H_ - -#include -#include -#include -#include "lccl.h" -#include "lcoc.h" -#include "runtime/collective/communication_group.h" -#include "utils/dlopen_macro.h" - -using namespace Lcal; -using LcalCommPtr = std::shared_ptr; -using LcclPtr = std::shared_ptr; -using LcocPtr = std::shared_ptr; - -namespace mindspore { -namespace device { -namespace ascend { - -class LowlatencyCommunicationGroup : public CommunicationGroup { - public: - explicit LowlatencyCommunicationGroup(const std::string &name, const std::vector &group_ranks, - uint32_t global_rank, uint32_t local_group_rank, uint32_t local_group_size); - - ~LowlatencyCommunicationGroup() override = default; - - bool Initialize(void *root_info) override; - bool Finalize() override; - - void *GenerateRootInfo(size_t *root_info_size) override; - - // Return communicator for collective communication ops. - const LcclPtr &lccl_communicator() const; - // Return communicator of lcal. - const LcalCommPtr &lcal_comm() const; - - private: - // Lcal communicator of this group, but this should be encapsulated by 'Lccl' class to use communication operations. - LcalCommPtr lcal_comm_; - - // 'Lccl' object returned to call communication operations. - LcclPtr lccl_comm_; -}; -using LowlatencyCommunicationGroupPtr = std::shared_ptr; -} // namespace ascend -} // namespace device -} // namespace mindspore -#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_LOWLATENCY_COMMUNICATION_GROUP_H_ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_LOWLATENCY_COMMUNICATION_GROUP_H_ +#define MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_LOWLATENCY_COMMUNICATION_GROUP_H_ + +#include +#include +#include +#include "lccl.h" +#include "lcoc.h" +#include "runtime/collective/communication_group.h" +#include "utils/dlopen_macro.h" + +using namespace Lcal; +using LcalCommPtr = std::shared_ptr; +using LcclPtr = std::shared_ptr; +using LcocPtr = std::shared_ptr; + +namespace mindspore { +namespace device { +namespace ascend { + +class LowlatencyCommunicationGroup : public CommunicationGroup { + public: + explicit LowlatencyCommunicationGroup(const std::string &name, const std::vector &group_ranks, + uint32_t global_rank, uint32_t local_group_rank, uint32_t local_group_size); + + ~LowlatencyCommunicationGroup() override = default; + + bool Initialize(void *root_info) override; + bool Finalize() override; + + void *GenerateRootInfo(size_t *root_info_size) override; + + // Return communicator for collective communication ops. + const LcclPtr &lccl_communicator() const; + // Return communicator of lcal. + const LcalCommPtr &lcal_comm() const; + + private: + // Lcal communicator of this group, but this should be encapsulated by 'Lccl' class to use communication operations. + LcalCommPtr lcal_comm_; + + // 'Lccl' object returned to call communication operations. + LcclPtr lccl_comm_; +}; +using LowlatencyCommunicationGroupPtr = std::shared_ptr; +} // namespace ascend +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_LOWLATENCY_COMMUNICATION_GROUP_H_ diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicore/cmake/util/gen_impl_and_merge_json.sh b/mindspore/ccsrc/plugin/device/ascend/kernel/aicore/cmake/util/gen_impl_and_merge_json.sh old mode 100755 new mode 100644 diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/utils/common_shape_fns.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/utils/common_shape_fns.cc index b8cb24a54fd..053eb1a89a4 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/utils/common_shape_fns.cc +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/utils/common_shape_fns.cc @@ -1,851 +1,851 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd 2019-2022. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*! - * \file common_shape_fns.cpp - * \brief - */ -#include "common_shape_fns.h" -#include -#include -#include "op_log.h" -#include "error_util.h" -#include "util.h" - -namespace ge { -const std::map dtype_maps{{"DT_FLOAT", DT_FLOAT}, - {"DT_FLOAT16", DT_FLOAT16}, - {"DT_INT8", DT_INT8}, - {"DT_INT16", DT_INT16}, - {"DT_UINT16", DT_UINT16}, - {"DT_UINT8", DT_UINT8}, - {"DT_INT32", DT_INT32}, - {"DT_INT64", DT_INT64}, - {"DT_UINT32", DT_UINT32}, - {"DT_UINT64", DT_UINT64}, - {"DT_BOOL", DT_BOOL}, - {"DT_DOUBLE", DT_DOUBLE}, - {"DT_STRING", DT_STRING}, - {"DT_DUAL_SUB_INT8", DT_DUAL_SUB_INT8}, - {"DT_DUAL_SUB_UINT8", DT_DUAL_SUB_UINT8}, - {"DT_COMPLEX64", DT_COMPLEX64}, - {"DT_COMPLEX128", DT_COMPLEX128}, - {"DT_QINT8", DT_QINT8}, - {"DT_QINT16", DT_QINT16}, - {"DT_QINT32", DT_QINT32}, - {"DT_QUINT8", DT_QUINT8}, - {"DT_QUINT16", DT_QUINT16}, - {"DT_RESOURCE", DT_RESOURCE}, - {"DT_STRING_REF", DT_STRING_REF}, - {"DT_DUAL", DT_DUAL}, - {"DT_BF16", DT_BF16}, - {"DT_UNDEFINED", DT_UNDEFINED}}; - -graphStatus WithRankAtLeast(const TensorDesc &tensor, int64_t rank, Shape &out, const ge::Operator &op) { - if (rank > INT32_MAX) { - VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", rank, "] cannot exceed kint32max")); - return GRAPH_FAILED; - } - Shape s = tensor.GetShape(); - std::vector dims = s.GetDims(); - // dim.size() convert to be type int64_t can't overflow - int64_t size = static_cast(dims.size()); - if (!((size >= rank) || (dims == UNKNOWN_SHAPE))) { - VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", size, "] must be at least [", rank, "]")); - return GRAPH_FAILED; - } - out = s; - return GRAPH_SUCCESS; -} - -graphStatus WithRankAtLeast(const TensorDesc &tensor, int64_t rank, Shape &out, const char *op_name) { - if (rank > INT32_MAX) { - VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("rank[", rank, "] cannot exceed kint32max")); - return GRAPH_FAILED; - } - Shape s = tensor.GetShape(); - std::vector dims = s.GetDims(); - // dim.size() convert to be type int64_t can't overflow - int64_t size = static_cast(dims.size()); - if (!((size >= rank) || (dims == UNKNOWN_SHAPE))) { - VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), - ConcatString("rank[", size, "] must be at least [", rank, "]")); - return GRAPH_FAILED; - } - out = s; - return GRAPH_SUCCESS; -} - -graphStatus WithRankShape(Shape &shape, int64_t rank, const ge::Operator &op) { - if (rank > INT32_MAX) { - VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", rank, "] cannot exceed kint32max")); - return GRAPH_FAILED; - } - - int64_t existing = static_cast(shape.GetDimNum()); - - if (shape.GetDims() == UNKNOWN_RANK) { - std::vector out_shape(rank, UNKNOWN_DIM); - shape = Shape(out_shape); - return GRAPH_SUCCESS; - } - if (existing != rank) { - VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", existing, "] must be [", rank, "]")); - return GRAPH_FAILED; - } - - std::vector dim_values = shape.GetDims(); - shape = Shape(dim_values); - return GRAPH_SUCCESS; -} - -graphStatus WithRank(const TensorDesc &tensor, int64_t rank, Shape &out, const ge::Operator &op) { - if (rank > INT32_MAX) { - VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", rank, "] cannot exceed kint32max")); - return GRAPH_FAILED; - } - Shape s = tensor.GetShape(); - int64_t existing = static_cast(s.GetDimNum()); - - if (s.GetDims() == UNKNOWN_RANK) { - std::vector out_shape(rank, UNKNOWN_DIM); - out = Shape(out_shape); - return GRAPH_SUCCESS; - } - - if (existing != rank) { - VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", existing, "] must be [", rank, "]")); - return GRAPH_FAILED; - } - out = s; - return GRAPH_SUCCESS; -} - -graphStatus WithValue(int64_t dim, int64_t value, int64_t &out, const ge::Operator &op) { - out = value; - if (dim == UNKNOWN_DIM) { - return GRAPH_SUCCESS; - } - - if (dim != value) { - VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("dim[", dim, "] should be ", value)); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -graphStatus MergePrefix(const Shape &s, const Shape &prefix, Shape &s_out, Shape &prefix_out, const ge::Operator &op) { - // Same shape and unknown rank - if (!RankKnown(s) || !RankKnown(prefix)) { - s_out = s; - prefix_out = prefix; - return GRAPH_SUCCESS; - } - const size_t rank = prefix.GetDimNum(); - std::vector dims1 = s.GetDims(); - if ((dims1 != UNKNOWN_RANK) && (dims1.size() < rank)) { - std::string err_msg = ConcatString("first shape rank[", dims1.size(), "] must be at least rank[", rank, "]"); - VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, err_msg); - return GRAPH_FAILED; - } - - const size_t rank_s = s.GetDimNum(); - std::vector dims; - dims.reserve(std::max(rank, rank_s)); - dims.resize(rank); - for (size_t i = 0; i < rank; ++i) { - if (Merge(s.GetDim(i), prefix.GetDim(i), dims[i]) != GRAPH_SUCCESS) { - std::string err_msg = ConcatString(i, "th dim of first shape", DebugString(s.GetDims()), - " is not same as that of prefix shape", DebugString(prefix.GetDims())); - VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, err_msg); - return GRAPH_FAILED; - } - } - prefix_out = Shape(dims); - for (size_t i = rank; i < rank_s; ++i) { - dims.push_back(s.GetDim(i)); - } - s_out = Shape(dims); - return GRAPH_SUCCESS; -} - -graphStatus Merge(int64_t dim1, int64_t dim2, int64_t &out) { - if (dim1 == dim2) { - out = dim1; - return GRAPH_SUCCESS; - } else if (dim2 == UNKNOWN_DIM) { - out = dim1; - return GRAPH_SUCCESS; - } else if (dim1 == UNKNOWN_DIM) { - out = dim2; - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -graphStatus Merge(const Shape &s0, const Shape &s1, Shape &out, const ge::Operator &op) { - // Same shape and unknown rank - if (s0.GetDims() == s1.GetDims()) { - out = s0; - return GRAPH_SUCCESS; - } else if (!RankKnown(s1)) { - out = s0; - return GRAPH_SUCCESS; - } else if (!RankKnown(s0)) { - out = s1; - return GRAPH_SUCCESS; - } - - const size_t rank = s0.GetDimNum(); - if (s1.GetDimNum() != rank) { - std::string err_msg = ConcatString("different rank of first shape", DebugString(s0.GetDims()), " and second shape", - DebugString(s1.GetDims())); - VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, err_msg); - return GRAPH_FAILED; - } - - // Check if each dims equal - bool return_s0 = true; - bool return_s1 = true; - for (size_t i = 0; i < rank; i++) { - int64_t d0 = s0.GetDim(i); - int64_t d1 = s1.GetDim(i); - if (d0 == UNKNOWN_DIM) { - if (d1 != UNKNOWN_DIM) { - return_s0 = false; - } - } else if (d1 == UNKNOWN_DIM) { - return_s1 = false; - } else if (d0 != d1) { - std::string err_msg = ConcatString("different ", i, "th dim of first shape", DebugString(s0.GetDims()), - " and second shape", DebugString(s1.GetDims())); - VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, err_msg); - return GRAPH_FAILED; - } - } - - if (return_s0 || return_s1) { - out = return_s0 ? s0 : s1; - return GRAPH_SUCCESS; - } - - // Merge dims - std::vector dims(rank, 0); - for (size_t i = 0; i < rank; ++i) { - // Invariant for merge was checked earlier, so CHECK is ok. - if (Merge(s0.GetDim(i), s1.GetDim(i), dims[i]) == GRAPH_FAILED) { - std::string err_msg = ConcatString("merge ", i, "th dim failed, first shape", DebugString(s0.GetDims()), - " and second shape", DebugString(s1.GetDims())); - VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, err_msg); - return GRAPH_FAILED; - } - } - - out = Shape(dims); - return GRAPH_SUCCESS; -} - -void MergeShape(const Shape &shared_shape, const Shape &value_shape, std::vector &out, bool &shape_changed) { - for (size_t i = 0; i < out.size(); ++i) { - if (shared_shape.GetDim(i) == value_shape.GetDim(i) || shared_shape.GetDim(i) == -1) { - out[i] = shared_shape.GetDim(i); - } else { - out[i] = -1; - shape_changed = true; - } - } -} - -void MergeRange(const std::vector> &shared_shape_range, - const std::vector> &value_shape_range, - std::vector> &out, bool &shape_changed) { - for (size_t i = 0; i < out.size(); ++i) { - auto &shared_range = shared_shape_range[i]; - auto &value_range = value_shape_range[i]; - if (shared_range.first <= value_range.first) { - out[i].first = shared_range.first; - } else { - out[i].first = value_range.first; - shape_changed = true; - } - if (shared_range.second == -1 || (value_range.second != -1 && shared_range.second >= value_range.second)) { - out[i].second = shared_range.second; - } else { - out[i].second = value_range.second; - shape_changed = true; - } - } -} - -graphStatus MergeShapeAndRange(const ShapeAndRange &shared_shape_and_range, const ShapeAndRange &value_shape_and_range, - ShapeAndRange &out, bool &shape_changed, const ge::Operator &op) { - if (!RankKnown(shared_shape_and_range.shape_)) { - out = {Shape(UNKNOWN_RANK), {}, value_shape_and_range.shape_type_}; - return GRAPH_SUCCESS; - } - if (!RankKnown(value_shape_and_range.shape_) || - (shared_shape_and_range.shape_.GetDimNum() != value_shape_and_range.shape_.GetDimNum())) { - out = {Shape(UNKNOWN_RANK), {}, value_shape_and_range.shape_type_}; - shape_changed = true; - return GRAPH_SUCCESS; - } - auto actual_shared_range = shared_shape_and_range.shape_range_; - auto actual_value_range = value_shape_and_range.shape_range_; - if (shared_shape_and_range.shape_.GetDimNum() != shared_shape_and_range.shape_range_.size()) { - actual_shared_range.clear(); - for (auto dim : shared_shape_and_range.shape_.GetDims()) { - if (dim == ge::UNKNOWN_DIM) { - actual_shared_range.push_back({1, -1}); - } else { - actual_shared_range.push_back({dim, dim}); - } - } - } - if (value_shape_and_range.shape_.GetDimNum() != value_shape_and_range.shape_range_.size()) { - actual_value_range.clear(); - for (auto dim : value_shape_and_range.shape_.GetDims()) { - if (dim == ge::UNKNOWN_DIM) { - actual_value_range.push_back({1, -1}); - } else { - actual_value_range.push_back({dim, dim}); - } - } - } - const size_t rank = value_shape_and_range.shape_.GetDimNum(); - std::vector dims(rank); - std::vector> shape_range(rank); - MergeShape(shared_shape_and_range.shape_, value_shape_and_range.shape_, dims, shape_changed); - MergeRange(actual_shared_range, actual_value_range, shape_range, shape_changed); - out = {Shape(dims), shape_range, value_shape_and_range.shape_type_}; - return GRAPH_SUCCESS; -} - -graphStatus ReplaceDim(const Shape &s, int64_t dim_index_in, int64_t new_dim, Shape &out, const ge::Operator &op) { - if (!RankKnown(s)) { - out = Shape(ge::UNKNOWN_SHAPE); - return GRAPH_SUCCESS; - } - int64_t dim_index = dim_index_in; - if (dim_index < 0) { - dim_index = static_cast(s.GetDimNum()) + dim_index; - } - if (!FastBoundsCheck(dim_index, s.GetDimNum())) { - out = Shape(); - VECTOR_INFER_SHAPE_INNER_ERR_REPORT( - op, ConcatString("out of range: replace dim[", dim_index_in, "] for shape with rank[", s.GetDimNum(), "]")); - return GRAPH_FAILED; - } - std::vector dims = s.GetDims(); - dims[dim_index] = new_dim; - out = Shape(dims); - return GRAPH_SUCCESS; -} - -template -bool FastBoundsCheck(const Ta index, const Tb limit) { - static_assert(std::is_integral::value && std::is_integral::value, - "FastBoundsCheck can only be used on integer types."); - typedef typename std::make_unsigned::type UIndex; - return static_cast(index) < static_cast(limit); -} - -graphStatus Add(int64_t dim1, int64_t dim2, int64_t &out) { - if (dim1 == 0) { - out = dim2; - } else if (dim2 == 0) { - out = dim1; - } else if ((dim1 == UNKNOWN_DIM) || (dim2 == UNKNOWN_DIM)) { - out = UNKNOWN_DIM; - } else { - const int64_t sum = dim1 + dim2; - if (sum < 0) { - return GRAPH_FAILED; - } - out = sum; - } - return GRAPH_SUCCESS; -} - -graphStatus Subtract(int64_t dim1, int64_t dim2, int64_t &out, const ge::Operator &op) { - if (dim2 == 0) { - out = dim1; - } else if ((dim1 == UNKNOWN_DIM) || (dim2 == UNKNOWN_DIM)) { - out = UNKNOWN_DIM; - } else { - if (dim1 < dim2) { - VECTOR_INFER_SHAPE_INNER_ERR_REPORT( - op, ConcatString("negative dimension caused by subtracting. dim1[", dim1, "], dim2[", dim2, "]")); - return GRAPH_FAILED; - } - out = dim1 - dim2; - } - return GRAPH_SUCCESS; -} - -graphStatus SubShape(const Shape &s, int64_t start, int64_t end, int64_t stride, Shape &out, const ge::Operator &op) { - if (s.GetDimNum() > INT32_MAX) { - VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", s.GetDimNum(), "] cannot exceed kint32max")); - return GRAPH_FAILED; - } - const int64_t rank = static_cast(s.GetDimNum()); - TensorDesc tensor(s); - if (!RankKnown(s) || - (start == 0 && ((tensor.GetRealDimCnt() != -1 && end >= rank) || end == std::numeric_limits::max()))) { - out = s; - return GRAPH_SUCCESS; - } - - if (start > rank) { - start = rank; - } - if (end > rank) { - end = rank; - } - - if (stride < 0 && start == rank) { - --start; - } - - if (start < 0) { - start = rank + start; - if (start < 0) { - VECTOR_INFER_SHAPE_INNER_ERR_REPORT( - op, ConcatString("invalid start[", start - rank, "] to get sub shape with rank[", rank, "]")); - return GRAPH_FAILED; - } - } - - if (end < 0) { - end = rank + end; - if (end < 0) { - VECTOR_INFER_SHAPE_INNER_ERR_REPORT( - op, ConcatString("invalid end[", end - rank, "] to get sub shape with rank[", rank, "]")); - return GRAPH_FAILED; - } - } - - // stride > 0 and start > end - if (!((stride <= 0 || start <= end))) { - VECTOR_INFER_SHAPE_INNER_ERR_REPORT( - op, ConcatString("start[", start, "] should be less than end[", end, "] at positive stride[", stride, "]")); - return GRAPH_FAILED; - } - // stride < 0 and start < end - if (!(stride >= 0 || start >= end)) { - VECTOR_INFER_SHAPE_INNER_ERR_REPORT( - op, ConcatString("start[", start, "] should be greater than end[", end, "] at negative stride[", stride, "]")); - return GRAPH_FAILED; - } - std::vector dims; - for (int64_t i = start; stride > 0 ? i < end : i > end; i += stride) { - dims.push_back(s.GetDim(i)); - } - Shape tmp(dims); - out = tmp; - return GRAPH_SUCCESS; -} - -graphStatus Concatenate(const Shape &s1, const Shape &s2, Shape &out) { - if (!RankKnown(s1) || !RankKnown(s2)) { - out = Shape(ge::UNKNOWN_RANK); - return GRAPH_SUCCESS; - } - size_t s1_rank = s1.GetDimNum(); - size_t s2_rank = s2.GetDimNum(); - size_t rank = s1_rank + s2_rank; - std::vector dims; - dims.reserve(rank); - for (size_t i = 0; i < s1_rank; ++i) { - dims.push_back(s1.GetDim(i)); - } - for (size_t i = 0; i < s2_rank; ++i) { - dims.push_back(s2.GetDim(i)); - } - Shape s(dims); - out = s; - return GRAPH_SUCCESS; -} - -graphStatus Matrix(int64_t dim1, int64_t dim2, Shape &out) { - std::vector dims; - dims.reserve(2); // The number of dims is 2. - dims.push_back(dim1); - dims.push_back(dim2); - Shape s(dims); - out = s; - return GRAPH_SUCCESS; -} - -graphStatus Vector(int64_t dim, Shape &out) { - std::vector dims; - dims.reserve(1); - dims.push_back(dim); - Shape s(dims); - out = s; - return GRAPH_SUCCESS; -} - -static graphStatus GetShapeDataFromShapeTensor(Operator &op, const string &dst_name, int64_t rank, - std::vector &data) { - auto shape_data_desc = op.GetInputDesc(dst_name); - - std::vector input_infer_depends = {dst_name}; - PREPARE_DYNAMIC_SHAPE(input_infer_depends); - - Shape shape_data_shape(shape_data_desc.GetShape()); - std::vector dims = shape_data_shape.GetDims(); - DataType data_type = shape_data_desc.GetDataType(); - if (dims.size() != static_cast(rank)) { - VECTOR_INFER_SHAPE_INNER_ERR_REPORT( - op, ConcatString("invalid shape data rank[", dims.size(), "], should be [", rank, "]")); - return GRAPH_FAILED; - } - int64_t dim_value = ((rank > 0) && (dims[0] > 0)) ? dims[0] : 1; - data.clear(); - if (dims[0] < 0) { - OP_LOGI(op, "Shape rank is %zu, dims[0] value is [%ld]", dims.size(), dims[0]); - data.push_back(UNKNOWN_DIM_NUM); - return GRAPH_SUCCESS; - } - data.reserve(dim_value); - Tensor shape_tensor; - if (data_type == DT_INT32) { - if (op.GetInputConstData(dst_name.c_str(), shape_tensor) == GRAPH_SUCCESS) { - const auto *shape_data = reinterpret_cast(shape_tensor.GetData()); - for (int64_t i = 0; i < dim_value; i++) { - data.push_back(static_cast(shape_data[i])); - } - } else { - OP_LOGI(op, "Input [%s] is not a const tensor.", dst_name.c_str()); - for (int64_t i = 0; i < dim_value; i++) { - data.push_back(UNKNOWN_DIM); - } - } - } else if (data_type == DT_INT64) { - if (op.GetInputConstData(dst_name.c_str(), shape_tensor) == GRAPH_SUCCESS) { - const auto *shape_data = reinterpret_cast(shape_tensor.GetData()); - for (int64_t i = 0; i < dim_value; i++) { - data.push_back(static_cast(shape_data[i])); - } - } else { - OP_LOGI(op, "Input [%s] is not a const tensor.", dst_name.c_str()); - for (int64_t i = 0; i < dim_value; i++) { - data.push_back(UNKNOWN_DIM); - } - } - } else { - VECTOR_INFER_SHAPE_INNER_ERR_REPORT( - op, ConcatString("invalid data type[", DTypeStr(data_type), "], should be DT_INT32 or DT_INT64")); - return GRAPH_FAILED; - } - - return GRAPH_SUCCESS; -} - -static graphStatus GetShapeDataFromConstData(const Tensor &tensor, int64_t rank, std::vector &data, - const ge::Operator &op) { - TensorDesc shape_data_desc = tensor.GetTensorDesc(); - Shape shape_data_shape = shape_data_desc.GetShape(); - std::vector dims = shape_data_shape.GetDims(); - DataType data_type = shape_data_desc.GetDataType(); - - if (dims.size() != static_cast(rank)) { - VECTOR_INFER_SHAPE_INNER_ERR_REPORT( - op, ConcatString("invalid shape data rank[", dims.size(), "], should be [", rank, "]")); - return GRAPH_FAILED; - } - int64_t dim_value = rank > 0 ? dims[0] : 1; - OP_LOGI(op, "data_type = %d, dim_value = %ld", data_type, dim_value); - data.clear(); - data.reserve(dim_value); - if (data_type == DT_INT32) { - const int32_t *shape_data = reinterpret_cast(tensor.GetData()); - for (int64_t i = 0; i < dim_value; i++) { - OP_LOGI(op, "DT_INT32 i = %ld, shape_data[i] = %ld", i, static_cast(shape_data[i])); - data.push_back(static_cast(shape_data[i])); - } - } else if (data_type == DT_INT64) { - const int64_t *shape_data = reinterpret_cast(tensor.GetData()); - for (int64_t i = 0; i < dim_value; i++) { - OP_LOGI(op, "DT_INT64 i = %ld, shape_data[i] = %ld", i, shape_data[i]); - data.push_back(shape_data[i]); - } - } else { - VECTOR_INFER_SHAPE_INNER_ERR_REPORT( - op, ConcatString("invalid data type[", DTypeStr(data_type), "], should be DT_INT32 or DT_INT64")); - return GRAPH_FAILED; - } - - return GRAPH_SUCCESS; -} - -graphStatus MakeShapeFromShapeTensor(const Tensor &tensor, Shape &out, const ge::Operator &op) { - std::vector shape_data; - GetShapeDataFromConstData(tensor, 1, shape_data, op); - out = Shape(shape_data); - return GRAPH_SUCCESS; -} - -graphStatus MakeShapeFromShapeTensor(Operator &op, const string &dst_name, Shape &out) { - std::vector shape_data; - if (GetShapeDataFromShapeTensor(op, dst_name, 1, shape_data) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - out = Shape(shape_data); - return GRAPH_SUCCESS; -} - -graphStatus MakeDimForScalarInput(const Tensor &tensor, int64_t &out, const ge::Operator &op) { - std::vector shape_data; - GetShapeDataFromConstData(tensor, 0, shape_data, op); - out = shape_data[0]; - return GRAPH_SUCCESS; -} - -graphStatus WithRankAtMost(const TensorDesc &tensor, int64_t rank, Shape &out, const ge::Operator &op) { - if (rank > INT32_MAX) { - VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", rank, "] cannot exceed kint32max")); - return GRAPH_FAILED; - } - Shape s = tensor.GetShape(); - std::vector dims = s.GetDims(); - if (!((dims.size() <= static_cast(rank)) || (dims == ge::UNKNOWN_SHAPE))) { - VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("invalid rank[", dims.size(), "], should be at most ", rank)); - return GRAPH_FAILED; - } - out = s; - return GRAPH_SUCCESS; -} - -graphStatus Scalar(Shape &out) { - std::vector dims = {}; - Shape s(dims); - out = s; - return GRAPH_SUCCESS; -} - -graphStatus UnchangedShape(Operator &op, const string input_name, const string output_name) { - TensorDesc desc = op.GetOutputDescByName(output_name.c_str()); - desc.SetShape(op.GetInputDescByName(input_name.c_str()).GetShape()); - return op.UpdateOutputDesc(output_name.c_str(), desc); -} - -graphStatus Divide(const int64_t dividend, const int64_t divisor, const bool evenlyDivisible, int64_t &out, - const ge::Operator &op) { - if (divisor == 1) { - out = dividend; - } else if ((dividend == ge::UNKNOWN_DIM) || (divisor == ge::UNKNOWN_DIM)) { - out = ge::UNKNOWN_DIM; - } else { - if (divisor <= 0) { - VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("invalid divisor[", divisor, "], should be positive")); - return GRAPH_FAILED; - } - if (!((!evenlyDivisible) || (dividend % divisor) == 0)) { - VECTOR_INFER_SHAPE_INNER_ERR_REPORT( - op, ConcatString("[", dividend, "] cannot be evenly divisible by [", divisor, "]")); - return GRAPH_FAILED; - } - out = dividend / divisor; - } - return GRAPH_SUCCESS; -} - -bool ShapeFullDefined(const Shape &shape) { - if (!RankKnown(shape)) { - return false; - } - std::vector dims = shape.GetDims(); - - for (const auto &dim : dims) { - if (dim == ge::UNKNOWN_DIM) { - return false; - } - } - return true; -} - -bool ShapeFullyDefined(const Shape &shape) { - if (!RankKnown(shape)) { - return false; - } - - std::vector dims = shape.GetDims(); - for (const int64_t &dim : dims) { - if (dim == ge::UNKNOWN_DIM) { - return false; - } - } - - return true; -} - -bool RankKnown(const Shape &shape) { - std::vector dims = shape.GetDims(); - if (dims == ge::UNKNOWN_RANK) { - return false; - } - return true; -} - -Shape UnknownShapeOfRank(int64_t rank) { - std::vector dims(rank); - for (int64_t i = 0; i < rank; ++i) { - dims[i] = ge::UNKNOWN_DIM; - } - return Shape(dims); -} - -bool ValueKnown(const Shape &shape, const size_t &dim_index) { - if (shape.GetDims() == ge::UNKNOWN_SHAPE) { - return false; - } - if (dim_index >= shape.GetDims().size()) { - return false; - } - if (shape.GetDim(dim_index) == ge::UNKNOWN_DIM) { - return false; - } - - return true; -} - -graphStatus ValidateSparseTensor(const TensorDesc &indices, const TensorDesc &values, const TensorDesc &shape, - const ge::Operator &op) { - // Validate ranks - Shape unused_shape; - if (WithRank(indices, 2, unused_shape, op) != GRAPH_SUCCESS) { // The rank is 2. - std::string err_msg = ConcatString("failed to call WithRank function, indices has wrong shape", - DebugString(indices.GetShape().GetDims()), ", it should be 2D"); - AICPU_INFER_SHAPE_CALL_ERR_REPORT(op, err_msg); - return GRAPH_FAILED; - } - if (WithRank(values, 1, unused_shape, op) != GRAPH_SUCCESS) { - std::string err_msg = ConcatString("failed to call WithRank function, values has wrong shape", - DebugString(values.GetShape().GetDims()), ", it should be 1D"); - AICPU_INFER_SHAPE_CALL_ERR_REPORT(op, err_msg); - return GRAPH_FAILED; - } - if (WithRank(shape, 1, unused_shape, op) != GRAPH_SUCCESS) { - std::string err_msg = ConcatString("failed to call WithRank function, shape has wrong shape", - DebugString(shape.GetShape().GetDims()), ", it should be 1D"); - AICPU_INFER_SHAPE_CALL_ERR_REPORT(op, err_msg); - return GRAPH_FAILED; - } - - // Number of elements in indices and values must match - Shape indices_shape = indices.GetShape(); - Shape values_shape = values.GetShape(); - if (ValueKnown(indices_shape, 0)) { - if (ValueKnown(values_shape, 0)) { - if (indices_shape.GetDim(0) != values_shape.GetDim(0)) { - VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("dim[0] of indices and dim[0] of value do not match, ", - indices_shape.GetDim(0), " and ", values_shape.GetDim(0))); - return GRAPH_FAILED; - } - } - } - - // Rank embedded in indices must match shape. - Shape sparse_shape = shape.GetShape(); - if (ValueKnown(indices_shape, 1)) { - if (ValueKnown(sparse_shape, 0)) { - if (indices_shape.GetDim(1) != sparse_shape.GetDim(0)) { - VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("dim[1] of indices and dim[0] of sparse do not match, ", - indices_shape.GetDim(1), " and ", sparse_shape.GetDim(0))); - return GRAPH_FAILED; - } - } - } - return GRAPH_SUCCESS; -} - -std::string DTypeStr(DataType dtype) { - auto iter = - std::find_if(dtype_maps.begin(), dtype_maps.end(), - [dtype](const std::map::value_type &kv) { return (kv.second == dtype); }); - if (iter != dtype_maps.end()) { - return iter->first; - } else { - return std::string("DT_UNDEFINED"); - } -} - -graphStatus SetShapeAndRange(Operator &op, const ShapeAndRange &feed_shape_and_range) { - auto context = op.GetInferenceContext(); - std::vector marks; - context->GetMarks(marks); - - if (!marks.empty()) { - OP_LOGI(op, "Set marks[0] = %s", marks[0].GetString()); - bool shape_changed = false; - auto aicpu_resource_context = dynamic_cast(context->GetResourceContext(marks[0])); - if (aicpu_resource_context == nullptr) { - aicpu_resource_context = new (std::nothrow) AicpuResourceContext(); - if (aicpu_resource_context == nullptr) { - AICPU_INFER_SHAPE_INNER_ERR_REPORT(op, std::string("new AicpuResourceContext failed.")); - return GRAPH_FAILED; - } - aicpu_resource_context->shape_and_range_.push_back(feed_shape_and_range); - if (context->SetResourceContext(marks[0], aicpu_resource_context) != GRAPH_SUCCESS) { - delete aicpu_resource_context; - AICPU_INFER_SHAPE_CALL_ERR_REPORT(op, std::string("set resource context failed.")); - return GRAPH_FAILED; - } - shape_changed = true; - } else { - auto &shape_and_range = aicpu_resource_context->shape_and_range_; - if (shape_and_range.empty()) { - AICPU_INFER_SHAPE_CALL_ERR_REPORT(op, std::string("get resource context shape and ranges failed.")); - return GRAPH_FAILED; - } - MergeShapeAndRange(shape_and_range[0], feed_shape_and_range, shape_and_range[0], shape_changed, op); - } - if (shape_changed) { - if (context->AddChangedResourceKey(marks[0]) != GRAPH_SUCCESS) { - AICPU_INFER_SHAPE_CALL_ERR_REPORT(op, std::string("add change resource key failed.")); - return GRAPH_FAILED; - } - } - } - return GRAPH_SUCCESS; -} - -graphStatus GetShapeAndRange(Operator &op, ShapeAndRange &out, bool &geted, InferenceContextPtr infer_context) { - std::vector marks; - infer_context->GetMarks(marks); - if (!marks.empty()) { - OP_LOGI(op, "Get marks[0] = %s", marks[0].GetString()); - if (infer_context->RegisterReliedOnResourceKey(marks[0]) != GRAPH_SUCCESS) { - AICPU_INFER_SHAPE_INNER_ERR_REPORT(op, std::string("register relied on resource key failed.")); - return GRAPH_FAILED; - } - auto aicpu_resource_context = dynamic_cast(infer_context->GetResourceContext(marks[0])); - if (aicpu_resource_context != nullptr) { - auto &shape_and_range = aicpu_resource_context->shape_and_range_; - if (shape_and_range.empty()) { - AICPU_INFER_SHAPE_INNER_ERR_REPORT(op, std::string("get resource context shape and ranges failed.")); - return GRAPH_FAILED; - } - out.shape_ = shape_and_range[0].shape_; - out.shape_range_ = shape_and_range[0].shape_range_; - out.shape_type_ = shape_and_range[0].shape_type_; - geted = true; - } - } - return GRAPH_SUCCESS; -} -} // namespace ge +/* + * Copyright (c) Huawei Technologies Co., Ltd 2019-2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! + * \file common_shape_fns.cpp + * \brief + */ +#include "common_shape_fns.h" +#include +#include +#include "op_log.h" +#include "error_util.h" +#include "util.h" + +namespace ge { +const std::map dtype_maps{{"DT_FLOAT", DT_FLOAT}, + {"DT_FLOAT16", DT_FLOAT16}, + {"DT_INT8", DT_INT8}, + {"DT_INT16", DT_INT16}, + {"DT_UINT16", DT_UINT16}, + {"DT_UINT8", DT_UINT8}, + {"DT_INT32", DT_INT32}, + {"DT_INT64", DT_INT64}, + {"DT_UINT32", DT_UINT32}, + {"DT_UINT64", DT_UINT64}, + {"DT_BOOL", DT_BOOL}, + {"DT_DOUBLE", DT_DOUBLE}, + {"DT_STRING", DT_STRING}, + {"DT_DUAL_SUB_INT8", DT_DUAL_SUB_INT8}, + {"DT_DUAL_SUB_UINT8", DT_DUAL_SUB_UINT8}, + {"DT_COMPLEX64", DT_COMPLEX64}, + {"DT_COMPLEX128", DT_COMPLEX128}, + {"DT_QINT8", DT_QINT8}, + {"DT_QINT16", DT_QINT16}, + {"DT_QINT32", DT_QINT32}, + {"DT_QUINT8", DT_QUINT8}, + {"DT_QUINT16", DT_QUINT16}, + {"DT_RESOURCE", DT_RESOURCE}, + {"DT_STRING_REF", DT_STRING_REF}, + {"DT_DUAL", DT_DUAL}, + {"DT_BF16", DT_BF16}, + {"DT_UNDEFINED", DT_UNDEFINED}}; + +graphStatus WithRankAtLeast(const TensorDesc &tensor, int64_t rank, Shape &out, const ge::Operator &op) { + if (rank > INT32_MAX) { + VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", rank, "] cannot exceed kint32max")); + return GRAPH_FAILED; + } + Shape s = tensor.GetShape(); + std::vector dims = s.GetDims(); + // dim.size() convert to be type int64_t can't overflow + int64_t size = static_cast(dims.size()); + if (!((size >= rank) || (dims == UNKNOWN_SHAPE))) { + VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", size, "] must be at least [", rank, "]")); + return GRAPH_FAILED; + } + out = s; + return GRAPH_SUCCESS; +} + +graphStatus WithRankAtLeast(const TensorDesc &tensor, int64_t rank, Shape &out, const char *op_name) { + if (rank > INT32_MAX) { + VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("rank[", rank, "] cannot exceed kint32max")); + return GRAPH_FAILED; + } + Shape s = tensor.GetShape(); + std::vector dims = s.GetDims(); + // dim.size() convert to be type int64_t can't overflow + int64_t size = static_cast(dims.size()); + if (!((size >= rank) || (dims == UNKNOWN_SHAPE))) { + VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), + ConcatString("rank[", size, "] must be at least [", rank, "]")); + return GRAPH_FAILED; + } + out = s; + return GRAPH_SUCCESS; +} + +graphStatus WithRankShape(Shape &shape, int64_t rank, const ge::Operator &op) { + if (rank > INT32_MAX) { + VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", rank, "] cannot exceed kint32max")); + return GRAPH_FAILED; + } + + int64_t existing = static_cast(shape.GetDimNum()); + + if (shape.GetDims() == UNKNOWN_RANK) { + std::vector out_shape(rank, UNKNOWN_DIM); + shape = Shape(out_shape); + return GRAPH_SUCCESS; + } + if (existing != rank) { + VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", existing, "] must be [", rank, "]")); + return GRAPH_FAILED; + } + + std::vector dim_values = shape.GetDims(); + shape = Shape(dim_values); + return GRAPH_SUCCESS; +} + +graphStatus WithRank(const TensorDesc &tensor, int64_t rank, Shape &out, const ge::Operator &op) { + if (rank > INT32_MAX) { + VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", rank, "] cannot exceed kint32max")); + return GRAPH_FAILED; + } + Shape s = tensor.GetShape(); + int64_t existing = static_cast(s.GetDimNum()); + + if (s.GetDims() == UNKNOWN_RANK) { + std::vector out_shape(rank, UNKNOWN_DIM); + out = Shape(out_shape); + return GRAPH_SUCCESS; + } + + if (existing != rank) { + VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", existing, "] must be [", rank, "]")); + return GRAPH_FAILED; + } + out = s; + return GRAPH_SUCCESS; +} + +graphStatus WithValue(int64_t dim, int64_t value, int64_t &out, const ge::Operator &op) { + out = value; + if (dim == UNKNOWN_DIM) { + return GRAPH_SUCCESS; + } + + if (dim != value) { + VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("dim[", dim, "] should be ", value)); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +graphStatus MergePrefix(const Shape &s, const Shape &prefix, Shape &s_out, Shape &prefix_out, const ge::Operator &op) { + // Same shape and unknown rank + if (!RankKnown(s) || !RankKnown(prefix)) { + s_out = s; + prefix_out = prefix; + return GRAPH_SUCCESS; + } + const size_t rank = prefix.GetDimNum(); + std::vector dims1 = s.GetDims(); + if ((dims1 != UNKNOWN_RANK) && (dims1.size() < rank)) { + std::string err_msg = ConcatString("first shape rank[", dims1.size(), "] must be at least rank[", rank, "]"); + VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, err_msg); + return GRAPH_FAILED; + } + + const size_t rank_s = s.GetDimNum(); + std::vector dims; + dims.reserve(std::max(rank, rank_s)); + dims.resize(rank); + for (size_t i = 0; i < rank; ++i) { + if (Merge(s.GetDim(i), prefix.GetDim(i), dims[i]) != GRAPH_SUCCESS) { + std::string err_msg = ConcatString(i, "th dim of first shape", DebugString(s.GetDims()), + " is not same as that of prefix shape", DebugString(prefix.GetDims())); + VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, err_msg); + return GRAPH_FAILED; + } + } + prefix_out = Shape(dims); + for (size_t i = rank; i < rank_s; ++i) { + dims.push_back(s.GetDim(i)); + } + s_out = Shape(dims); + return GRAPH_SUCCESS; +} + +graphStatus Merge(int64_t dim1, int64_t dim2, int64_t &out) { + if (dim1 == dim2) { + out = dim1; + return GRAPH_SUCCESS; + } else if (dim2 == UNKNOWN_DIM) { + out = dim1; + return GRAPH_SUCCESS; + } else if (dim1 == UNKNOWN_DIM) { + out = dim2; + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +graphStatus Merge(const Shape &s0, const Shape &s1, Shape &out, const ge::Operator &op) { + // Same shape and unknown rank + if (s0.GetDims() == s1.GetDims()) { + out = s0; + return GRAPH_SUCCESS; + } else if (!RankKnown(s1)) { + out = s0; + return GRAPH_SUCCESS; + } else if (!RankKnown(s0)) { + out = s1; + return GRAPH_SUCCESS; + } + + const size_t rank = s0.GetDimNum(); + if (s1.GetDimNum() != rank) { + std::string err_msg = ConcatString("different rank of first shape", DebugString(s0.GetDims()), " and second shape", + DebugString(s1.GetDims())); + VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, err_msg); + return GRAPH_FAILED; + } + + // Check if each dims equal + bool return_s0 = true; + bool return_s1 = true; + for (size_t i = 0; i < rank; i++) { + int64_t d0 = s0.GetDim(i); + int64_t d1 = s1.GetDim(i); + if (d0 == UNKNOWN_DIM) { + if (d1 != UNKNOWN_DIM) { + return_s0 = false; + } + } else if (d1 == UNKNOWN_DIM) { + return_s1 = false; + } else if (d0 != d1) { + std::string err_msg = ConcatString("different ", i, "th dim of first shape", DebugString(s0.GetDims()), + " and second shape", DebugString(s1.GetDims())); + VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, err_msg); + return GRAPH_FAILED; + } + } + + if (return_s0 || return_s1) { + out = return_s0 ? s0 : s1; + return GRAPH_SUCCESS; + } + + // Merge dims + std::vector dims(rank, 0); + for (size_t i = 0; i < rank; ++i) { + // Invariant for merge was checked earlier, so CHECK is ok. + if (Merge(s0.GetDim(i), s1.GetDim(i), dims[i]) == GRAPH_FAILED) { + std::string err_msg = ConcatString("merge ", i, "th dim failed, first shape", DebugString(s0.GetDims()), + " and second shape", DebugString(s1.GetDims())); + VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, err_msg); + return GRAPH_FAILED; + } + } + + out = Shape(dims); + return GRAPH_SUCCESS; +} + +void MergeShape(const Shape &shared_shape, const Shape &value_shape, std::vector &out, bool &shape_changed) { + for (size_t i = 0; i < out.size(); ++i) { + if (shared_shape.GetDim(i) == value_shape.GetDim(i) || shared_shape.GetDim(i) == -1) { + out[i] = shared_shape.GetDim(i); + } else { + out[i] = -1; + shape_changed = true; + } + } +} + +void MergeRange(const std::vector> &shared_shape_range, + const std::vector> &value_shape_range, + std::vector> &out, bool &shape_changed) { + for (size_t i = 0; i < out.size(); ++i) { + auto &shared_range = shared_shape_range[i]; + auto &value_range = value_shape_range[i]; + if (shared_range.first <= value_range.first) { + out[i].first = shared_range.first; + } else { + out[i].first = value_range.first; + shape_changed = true; + } + if (shared_range.second == -1 || (value_range.second != -1 && shared_range.second >= value_range.second)) { + out[i].second = shared_range.second; + } else { + out[i].second = value_range.second; + shape_changed = true; + } + } +} + +graphStatus MergeShapeAndRange(const ShapeAndRange &shared_shape_and_range, const ShapeAndRange &value_shape_and_range, + ShapeAndRange &out, bool &shape_changed, const ge::Operator &op) { + if (!RankKnown(shared_shape_and_range.shape_)) { + out = {Shape(UNKNOWN_RANK), {}, value_shape_and_range.shape_type_}; + return GRAPH_SUCCESS; + } + if (!RankKnown(value_shape_and_range.shape_) || + (shared_shape_and_range.shape_.GetDimNum() != value_shape_and_range.shape_.GetDimNum())) { + out = {Shape(UNKNOWN_RANK), {}, value_shape_and_range.shape_type_}; + shape_changed = true; + return GRAPH_SUCCESS; + } + auto actual_shared_range = shared_shape_and_range.shape_range_; + auto actual_value_range = value_shape_and_range.shape_range_; + if (shared_shape_and_range.shape_.GetDimNum() != shared_shape_and_range.shape_range_.size()) { + actual_shared_range.clear(); + for (auto dim : shared_shape_and_range.shape_.GetDims()) { + if (dim == ge::UNKNOWN_DIM) { + actual_shared_range.push_back({1, -1}); + } else { + actual_shared_range.push_back({dim, dim}); + } + } + } + if (value_shape_and_range.shape_.GetDimNum() != value_shape_and_range.shape_range_.size()) { + actual_value_range.clear(); + for (auto dim : value_shape_and_range.shape_.GetDims()) { + if (dim == ge::UNKNOWN_DIM) { + actual_value_range.push_back({1, -1}); + } else { + actual_value_range.push_back({dim, dim}); + } + } + } + const size_t rank = value_shape_and_range.shape_.GetDimNum(); + std::vector dims(rank); + std::vector> shape_range(rank); + MergeShape(shared_shape_and_range.shape_, value_shape_and_range.shape_, dims, shape_changed); + MergeRange(actual_shared_range, actual_value_range, shape_range, shape_changed); + out = {Shape(dims), shape_range, value_shape_and_range.shape_type_}; + return GRAPH_SUCCESS; +} + +graphStatus ReplaceDim(const Shape &s, int64_t dim_index_in, int64_t new_dim, Shape &out, const ge::Operator &op) { + if (!RankKnown(s)) { + out = Shape(ge::UNKNOWN_SHAPE); + return GRAPH_SUCCESS; + } + int64_t dim_index = dim_index_in; + if (dim_index < 0) { + dim_index = static_cast(s.GetDimNum()) + dim_index; + } + if (!FastBoundsCheck(dim_index, s.GetDimNum())) { + out = Shape(); + VECTOR_INFER_SHAPE_INNER_ERR_REPORT( + op, ConcatString("out of range: replace dim[", dim_index_in, "] for shape with rank[", s.GetDimNum(), "]")); + return GRAPH_FAILED; + } + std::vector dims = s.GetDims(); + dims[dim_index] = new_dim; + out = Shape(dims); + return GRAPH_SUCCESS; +} + +template +bool FastBoundsCheck(const Ta index, const Tb limit) { + static_assert(std::is_integral::value && std::is_integral::value, + "FastBoundsCheck can only be used on integer types."); + typedef typename std::make_unsigned::type UIndex; + return static_cast(index) < static_cast(limit); +} + +graphStatus Add(int64_t dim1, int64_t dim2, int64_t &out) { + if (dim1 == 0) { + out = dim2; + } else if (dim2 == 0) { + out = dim1; + } else if ((dim1 == UNKNOWN_DIM) || (dim2 == UNKNOWN_DIM)) { + out = UNKNOWN_DIM; + } else { + const int64_t sum = dim1 + dim2; + if (sum < 0) { + return GRAPH_FAILED; + } + out = sum; + } + return GRAPH_SUCCESS; +} + +graphStatus Subtract(int64_t dim1, int64_t dim2, int64_t &out, const ge::Operator &op) { + if (dim2 == 0) { + out = dim1; + } else if ((dim1 == UNKNOWN_DIM) || (dim2 == UNKNOWN_DIM)) { + out = UNKNOWN_DIM; + } else { + if (dim1 < dim2) { + VECTOR_INFER_SHAPE_INNER_ERR_REPORT( + op, ConcatString("negative dimension caused by subtracting. dim1[", dim1, "], dim2[", dim2, "]")); + return GRAPH_FAILED; + } + out = dim1 - dim2; + } + return GRAPH_SUCCESS; +} + +graphStatus SubShape(const Shape &s, int64_t start, int64_t end, int64_t stride, Shape &out, const ge::Operator &op) { + if (s.GetDimNum() > INT32_MAX) { + VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", s.GetDimNum(), "] cannot exceed kint32max")); + return GRAPH_FAILED; + } + const int64_t rank = static_cast(s.GetDimNum()); + TensorDesc tensor(s); + if (!RankKnown(s) || + (start == 0 && ((tensor.GetRealDimCnt() != -1 && end >= rank) || end == std::numeric_limits::max()))) { + out = s; + return GRAPH_SUCCESS; + } + + if (start > rank) { + start = rank; + } + if (end > rank) { + end = rank; + } + + if (stride < 0 && start == rank) { + --start; + } + + if (start < 0) { + start = rank + start; + if (start < 0) { + VECTOR_INFER_SHAPE_INNER_ERR_REPORT( + op, ConcatString("invalid start[", start - rank, "] to get sub shape with rank[", rank, "]")); + return GRAPH_FAILED; + } + } + + if (end < 0) { + end = rank + end; + if (end < 0) { + VECTOR_INFER_SHAPE_INNER_ERR_REPORT( + op, ConcatString("invalid end[", end - rank, "] to get sub shape with rank[", rank, "]")); + return GRAPH_FAILED; + } + } + + // stride > 0 and start > end + if (!((stride <= 0 || start <= end))) { + VECTOR_INFER_SHAPE_INNER_ERR_REPORT( + op, ConcatString("start[", start, "] should be less than end[", end, "] at positive stride[", stride, "]")); + return GRAPH_FAILED; + } + // stride < 0 and start < end + if (!(stride >= 0 || start >= end)) { + VECTOR_INFER_SHAPE_INNER_ERR_REPORT( + op, ConcatString("start[", start, "] should be greater than end[", end, "] at negative stride[", stride, "]")); + return GRAPH_FAILED; + } + std::vector dims; + for (int64_t i = start; stride > 0 ? i < end : i > end; i += stride) { + dims.push_back(s.GetDim(i)); + } + Shape tmp(dims); + out = tmp; + return GRAPH_SUCCESS; +} + +graphStatus Concatenate(const Shape &s1, const Shape &s2, Shape &out) { + if (!RankKnown(s1) || !RankKnown(s2)) { + out = Shape(ge::UNKNOWN_RANK); + return GRAPH_SUCCESS; + } + size_t s1_rank = s1.GetDimNum(); + size_t s2_rank = s2.GetDimNum(); + size_t rank = s1_rank + s2_rank; + std::vector dims; + dims.reserve(rank); + for (size_t i = 0; i < s1_rank; ++i) { + dims.push_back(s1.GetDim(i)); + } + for (size_t i = 0; i < s2_rank; ++i) { + dims.push_back(s2.GetDim(i)); + } + Shape s(dims); + out = s; + return GRAPH_SUCCESS; +} + +graphStatus Matrix(int64_t dim1, int64_t dim2, Shape &out) { + std::vector dims; + dims.reserve(2); // The number of dims is 2. + dims.push_back(dim1); + dims.push_back(dim2); + Shape s(dims); + out = s; + return GRAPH_SUCCESS; +} + +graphStatus Vector(int64_t dim, Shape &out) { + std::vector dims; + dims.reserve(1); + dims.push_back(dim); + Shape s(dims); + out = s; + return GRAPH_SUCCESS; +} + +static graphStatus GetShapeDataFromShapeTensor(Operator &op, const string &dst_name, int64_t rank, + std::vector &data) { + auto shape_data_desc = op.GetInputDesc(dst_name); + + std::vector input_infer_depends = {dst_name}; + PREPARE_DYNAMIC_SHAPE(input_infer_depends); + + Shape shape_data_shape(shape_data_desc.GetShape()); + std::vector dims = shape_data_shape.GetDims(); + DataType data_type = shape_data_desc.GetDataType(); + if (dims.size() != static_cast(rank)) { + VECTOR_INFER_SHAPE_INNER_ERR_REPORT( + op, ConcatString("invalid shape data rank[", dims.size(), "], should be [", rank, "]")); + return GRAPH_FAILED; + } + int64_t dim_value = ((rank > 0) && (dims[0] > 0)) ? dims[0] : 1; + data.clear(); + if (dims[0] < 0) { + OP_LOGI(op, "Shape rank is %zu, dims[0] value is [%ld]", dims.size(), dims[0]); + data.push_back(UNKNOWN_DIM_NUM); + return GRAPH_SUCCESS; + } + data.reserve(dim_value); + Tensor shape_tensor; + if (data_type == DT_INT32) { + if (op.GetInputConstData(dst_name.c_str(), shape_tensor) == GRAPH_SUCCESS) { + const auto *shape_data = reinterpret_cast(shape_tensor.GetData()); + for (int64_t i = 0; i < dim_value; i++) { + data.push_back(static_cast(shape_data[i])); + } + } else { + OP_LOGI(op, "Input [%s] is not a const tensor.", dst_name.c_str()); + for (int64_t i = 0; i < dim_value; i++) { + data.push_back(UNKNOWN_DIM); + } + } + } else if (data_type == DT_INT64) { + if (op.GetInputConstData(dst_name.c_str(), shape_tensor) == GRAPH_SUCCESS) { + const auto *shape_data = reinterpret_cast(shape_tensor.GetData()); + for (int64_t i = 0; i < dim_value; i++) { + data.push_back(static_cast(shape_data[i])); + } + } else { + OP_LOGI(op, "Input [%s] is not a const tensor.", dst_name.c_str()); + for (int64_t i = 0; i < dim_value; i++) { + data.push_back(UNKNOWN_DIM); + } + } + } else { + VECTOR_INFER_SHAPE_INNER_ERR_REPORT( + op, ConcatString("invalid data type[", DTypeStr(data_type), "], should be DT_INT32 or DT_INT64")); + return GRAPH_FAILED; + } + + return GRAPH_SUCCESS; +} + +static graphStatus GetShapeDataFromConstData(const Tensor &tensor, int64_t rank, std::vector &data, + const ge::Operator &op) { + TensorDesc shape_data_desc = tensor.GetTensorDesc(); + Shape shape_data_shape = shape_data_desc.GetShape(); + std::vector dims = shape_data_shape.GetDims(); + DataType data_type = shape_data_desc.GetDataType(); + + if (dims.size() != static_cast(rank)) { + VECTOR_INFER_SHAPE_INNER_ERR_REPORT( + op, ConcatString("invalid shape data rank[", dims.size(), "], should be [", rank, "]")); + return GRAPH_FAILED; + } + int64_t dim_value = rank > 0 ? dims[0] : 1; + OP_LOGI(op, "data_type = %d, dim_value = %ld", data_type, dim_value); + data.clear(); + data.reserve(dim_value); + if (data_type == DT_INT32) { + const int32_t *shape_data = reinterpret_cast(tensor.GetData()); + for (int64_t i = 0; i < dim_value; i++) { + OP_LOGI(op, "DT_INT32 i = %ld, shape_data[i] = %ld", i, static_cast(shape_data[i])); + data.push_back(static_cast(shape_data[i])); + } + } else if (data_type == DT_INT64) { + const int64_t *shape_data = reinterpret_cast(tensor.GetData()); + for (int64_t i = 0; i < dim_value; i++) { + OP_LOGI(op, "DT_INT64 i = %ld, shape_data[i] = %ld", i, shape_data[i]); + data.push_back(shape_data[i]); + } + } else { + VECTOR_INFER_SHAPE_INNER_ERR_REPORT( + op, ConcatString("invalid data type[", DTypeStr(data_type), "], should be DT_INT32 or DT_INT64")); + return GRAPH_FAILED; + } + + return GRAPH_SUCCESS; +} + +graphStatus MakeShapeFromShapeTensor(const Tensor &tensor, Shape &out, const ge::Operator &op) { + std::vector shape_data; + GetShapeDataFromConstData(tensor, 1, shape_data, op); + out = Shape(shape_data); + return GRAPH_SUCCESS; +} + +graphStatus MakeShapeFromShapeTensor(Operator &op, const string &dst_name, Shape &out) { + std::vector shape_data; + if (GetShapeDataFromShapeTensor(op, dst_name, 1, shape_data) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + out = Shape(shape_data); + return GRAPH_SUCCESS; +} + +graphStatus MakeDimForScalarInput(const Tensor &tensor, int64_t &out, const ge::Operator &op) { + std::vector shape_data; + GetShapeDataFromConstData(tensor, 0, shape_data, op); + out = shape_data[0]; + return GRAPH_SUCCESS; +} + +graphStatus WithRankAtMost(const TensorDesc &tensor, int64_t rank, Shape &out, const ge::Operator &op) { + if (rank > INT32_MAX) { + VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", rank, "] cannot exceed kint32max")); + return GRAPH_FAILED; + } + Shape s = tensor.GetShape(); + std::vector dims = s.GetDims(); + if (!((dims.size() <= static_cast(rank)) || (dims == ge::UNKNOWN_SHAPE))) { + VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("invalid rank[", dims.size(), "], should be at most ", rank)); + return GRAPH_FAILED; + } + out = s; + return GRAPH_SUCCESS; +} + +graphStatus Scalar(Shape &out) { + std::vector dims = {}; + Shape s(dims); + out = s; + return GRAPH_SUCCESS; +} + +graphStatus UnchangedShape(Operator &op, const string input_name, const string output_name) { + TensorDesc desc = op.GetOutputDescByName(output_name.c_str()); + desc.SetShape(op.GetInputDescByName(input_name.c_str()).GetShape()); + return op.UpdateOutputDesc(output_name.c_str(), desc); +} + +graphStatus Divide(const int64_t dividend, const int64_t divisor, const bool evenlyDivisible, int64_t &out, + const ge::Operator &op) { + if (divisor == 1) { + out = dividend; + } else if ((dividend == ge::UNKNOWN_DIM) || (divisor == ge::UNKNOWN_DIM)) { + out = ge::UNKNOWN_DIM; + } else { + if (divisor <= 0) { + VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("invalid divisor[", divisor, "], should be positive")); + return GRAPH_FAILED; + } + if (!((!evenlyDivisible) || (dividend % divisor) == 0)) { + VECTOR_INFER_SHAPE_INNER_ERR_REPORT( + op, ConcatString("[", dividend, "] cannot be evenly divisible by [", divisor, "]")); + return GRAPH_FAILED; + } + out = dividend / divisor; + } + return GRAPH_SUCCESS; +} + +bool ShapeFullDefined(const Shape &shape) { + if (!RankKnown(shape)) { + return false; + } + std::vector dims = shape.GetDims(); + + for (const auto &dim : dims) { + if (dim == ge::UNKNOWN_DIM) { + return false; + } + } + return true; +} + +bool ShapeFullyDefined(const Shape &shape) { + if (!RankKnown(shape)) { + return false; + } + + std::vector dims = shape.GetDims(); + for (const int64_t &dim : dims) { + if (dim == ge::UNKNOWN_DIM) { + return false; + } + } + + return true; +} + +bool RankKnown(const Shape &shape) { + std::vector dims = shape.GetDims(); + if (dims == ge::UNKNOWN_RANK) { + return false; + } + return true; +} + +Shape UnknownShapeOfRank(int64_t rank) { + std::vector dims(rank); + for (int64_t i = 0; i < rank; ++i) { + dims[i] = ge::UNKNOWN_DIM; + } + return Shape(dims); +} + +bool ValueKnown(const Shape &shape, const size_t &dim_index) { + if (shape.GetDims() == ge::UNKNOWN_SHAPE) { + return false; + } + if (dim_index >= shape.GetDims().size()) { + return false; + } + if (shape.GetDim(dim_index) == ge::UNKNOWN_DIM) { + return false; + } + + return true; +} + +graphStatus ValidateSparseTensor(const TensorDesc &indices, const TensorDesc &values, const TensorDesc &shape, + const ge::Operator &op) { + // Validate ranks + Shape unused_shape; + if (WithRank(indices, 2, unused_shape, op) != GRAPH_SUCCESS) { // The rank is 2. + std::string err_msg = ConcatString("failed to call WithRank function, indices has wrong shape", + DebugString(indices.GetShape().GetDims()), ", it should be 2D"); + AICPU_INFER_SHAPE_CALL_ERR_REPORT(op, err_msg); + return GRAPH_FAILED; + } + if (WithRank(values, 1, unused_shape, op) != GRAPH_SUCCESS) { + std::string err_msg = ConcatString("failed to call WithRank function, values has wrong shape", + DebugString(values.GetShape().GetDims()), ", it should be 1D"); + AICPU_INFER_SHAPE_CALL_ERR_REPORT(op, err_msg); + return GRAPH_FAILED; + } + if (WithRank(shape, 1, unused_shape, op) != GRAPH_SUCCESS) { + std::string err_msg = ConcatString("failed to call WithRank function, shape has wrong shape", + DebugString(shape.GetShape().GetDims()), ", it should be 1D"); + AICPU_INFER_SHAPE_CALL_ERR_REPORT(op, err_msg); + return GRAPH_FAILED; + } + + // Number of elements in indices and values must match + Shape indices_shape = indices.GetShape(); + Shape values_shape = values.GetShape(); + if (ValueKnown(indices_shape, 0)) { + if (ValueKnown(values_shape, 0)) { + if (indices_shape.GetDim(0) != values_shape.GetDim(0)) { + VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("dim[0] of indices and dim[0] of value do not match, ", + indices_shape.GetDim(0), " and ", values_shape.GetDim(0))); + return GRAPH_FAILED; + } + } + } + + // Rank embedded in indices must match shape. + Shape sparse_shape = shape.GetShape(); + if (ValueKnown(indices_shape, 1)) { + if (ValueKnown(sparse_shape, 0)) { + if (indices_shape.GetDim(1) != sparse_shape.GetDim(0)) { + VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("dim[1] of indices and dim[0] of sparse do not match, ", + indices_shape.GetDim(1), " and ", sparse_shape.GetDim(0))); + return GRAPH_FAILED; + } + } + } + return GRAPH_SUCCESS; +} + +std::string DTypeStr(DataType dtype) { + auto iter = + std::find_if(dtype_maps.begin(), dtype_maps.end(), + [dtype](const std::map::value_type &kv) { return (kv.second == dtype); }); + if (iter != dtype_maps.end()) { + return iter->first; + } else { + return std::string("DT_UNDEFINED"); + } +} + +graphStatus SetShapeAndRange(Operator &op, const ShapeAndRange &feed_shape_and_range) { + auto context = op.GetInferenceContext(); + std::vector marks; + context->GetMarks(marks); + + if (!marks.empty()) { + OP_LOGI(op, "Set marks[0] = %s", marks[0].GetString()); + bool shape_changed = false; + auto aicpu_resource_context = dynamic_cast(context->GetResourceContext(marks[0])); + if (aicpu_resource_context == nullptr) { + aicpu_resource_context = new (std::nothrow) AicpuResourceContext(); + if (aicpu_resource_context == nullptr) { + AICPU_INFER_SHAPE_INNER_ERR_REPORT(op, std::string("new AicpuResourceContext failed.")); + return GRAPH_FAILED; + } + aicpu_resource_context->shape_and_range_.push_back(feed_shape_and_range); + if (context->SetResourceContext(marks[0], aicpu_resource_context) != GRAPH_SUCCESS) { + delete aicpu_resource_context; + AICPU_INFER_SHAPE_CALL_ERR_REPORT(op, std::string("set resource context failed.")); + return GRAPH_FAILED; + } + shape_changed = true; + } else { + auto &shape_and_range = aicpu_resource_context->shape_and_range_; + if (shape_and_range.empty()) { + AICPU_INFER_SHAPE_CALL_ERR_REPORT(op, std::string("get resource context shape and ranges failed.")); + return GRAPH_FAILED; + } + MergeShapeAndRange(shape_and_range[0], feed_shape_and_range, shape_and_range[0], shape_changed, op); + } + if (shape_changed) { + if (context->AddChangedResourceKey(marks[0]) != GRAPH_SUCCESS) { + AICPU_INFER_SHAPE_CALL_ERR_REPORT(op, std::string("add change resource key failed.")); + return GRAPH_FAILED; + } + } + } + return GRAPH_SUCCESS; +} + +graphStatus GetShapeAndRange(Operator &op, ShapeAndRange &out, bool &geted, InferenceContextPtr infer_context) { + std::vector marks; + infer_context->GetMarks(marks); + if (!marks.empty()) { + OP_LOGI(op, "Get marks[0] = %s", marks[0].GetString()); + if (infer_context->RegisterReliedOnResourceKey(marks[0]) != GRAPH_SUCCESS) { + AICPU_INFER_SHAPE_INNER_ERR_REPORT(op, std::string("register relied on resource key failed.")); + return GRAPH_FAILED; + } + auto aicpu_resource_context = dynamic_cast(infer_context->GetResourceContext(marks[0])); + if (aicpu_resource_context != nullptr) { + auto &shape_and_range = aicpu_resource_context->shape_and_range_; + if (shape_and_range.empty()) { + AICPU_INFER_SHAPE_INNER_ERR_REPORT(op, std::string("get resource context shape and ranges failed.")); + return GRAPH_FAILED; + } + out.shape_ = shape_and_range[0].shape_; + out.shape_range_ = shape_and_range[0].shape_range_; + out.shape_type_ = shape_and_range[0].shape_type_; + geted = true; + } + } + return GRAPH_SUCCESS; +} +} // namespace ge diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/utils/common_shape_fns.h b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/utils/common_shape_fns.h index ac3cda25d0e..b3f3ed90bc1 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/utils/common_shape_fns.h +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/utils/common_shape_fns.h @@ -1,406 +1,406 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*! - * \file common_shape_fns.h - * \brief - */ -#ifndef CUSTOMIZE_OP_PROTO_UTIL_COMMON_SHAPE_FNS_H_ -#define CUSTOMIZE_OP_PROTO_UTIL_COMMON_SHAPE_FNS_H_ - -#include -#include -#include "graph/tensor.h" -#include "graph/operator.h" -#include "graph/resource_context.h" - -namespace ge { - -struct ShapeAndRange { - Shape shape_; - std::vector> shape_range_; - DataType shape_type_; -}; - -struct AicpuResourceContext : public ResourceContext { - std::vector shape_and_range_; -}; - -/** - * Check whether Shape's rank is at least rank - * @param tensor Input tensor - * @param rank expect val of Shape - * @param out Output Shape - * @return status whether Shape's condition Satisfied - */ -graphStatus WithRankAtLeast(const TensorDesc &tensor, int64_t rank, Shape &out, const ge::Operator &op); - -/** - * Check whether Shape's rank is at least rank - * @param tensor Input tensor - * @param rank expect val of Shape - * @param out Output Shape - * @return status whether Shape's condition Satisfied - */ -graphStatus WithRankAtLeast(const TensorDesc &tensor, int64_t rank, Shape &out, const char *op_name); - -/** - * Check whether Shape's rank is equal to rank - * @param shape Input tensor shape - * @param rank expect shape rank - * @param out Output Shape - * @return status whether Shape's condition Satisfied - */ -graphStatus WithRankShape(Shape &shape, int64_t rank, const ge::Operator &op); - -/** - * Check whether Shape's rank is equal to rank - * @param tensor Input tensor - * @param rank expect val of Shape - * @param out Output Shape - * @return status whether Shape's condition Satisfied - */ -graphStatus WithRank(const TensorDesc &tensor, int64_t rank, Shape &out, const ge::Operator &op); - -/** - * Check whether dim is equal to value - * @param dim Input dim - * @param value expect val of dim - * @param out Output dim - * @return status whether Dim is equal to value - */ -graphStatus WithValue(int64_t dim, int64_t value, int64_t &out, const ge::Operator &op); - -/** - * Merge two shapes - * @param s0 first shape val - * @param prefix second shape val - * @param s_out merged shape val - * @param prefix_out prefix out shape val - * @return status whether this operation success - */ -graphStatus MergePrefix(const Shape &s, const Shape &prefix, Shape &s_out, Shape &prefix_out, const ge::Operator &op); - -/** - * Merge two dims of Shape - * @param dim0 first dim val - * @param dim1 second dim val - * @param out merged dim val - * @return status whether this operation success - */ -graphStatus Merge(int64_t dim1, int64_t dim2, int64_t &out); - -/** - * Merge two shapes - * @param s0 first shape val - * @param s1 second shape val - * @param out merged shape val - * @return status whether this operation success - */ -graphStatus Merge(const Shape &s0, const Shape &s1, Shape &out, const ge::Operator &op); - -/** - * Merge two shapes - * @param s0 first Geshape val - * @param s1 second Geshape val - * @param out merged Geshape val - * @return status whether this operation success - */ -graphStatus Merge(const Shape &s0, const Shape &s1, Shape &out, const ge::Operator &op); - -/** - * Merge two shapes - * @param shared_shape first Geshape val - * @param value_shape second Geshape val - * @param out merged shape val - * @param shape_changed status whether shape has changed - */ -void MergeShape(const Shape &shared_shape, const Shape &value_shape, std::vector &out, bool &shape_changed); - -/** - * Merge two shape ranges - * @param shared_shape_range first shape range val - * @param value_shape_range second shape range val - * @param out merged shape range val - * @param shape_changed status whether shape range has changed - */ -void MergeRange(const std::vector> &shared_shape_range, - const std::vector> &value_shape_range, - std::vector> &out, bool &shape_changed); - -/** - * Merge two shapes and ranges - * @param shared_shape_and_range first shape and range val - * @param value_shape_and_range second shape and range val - * @param out merged shape and range val - * @param shape_changed status whether shape and range has changed - * @return status whether this operation success - */ -graphStatus MergeShapeAndRange(const ShapeAndRange &shared_shape_and_range, const ShapeAndRange &value_shape_and_range, - ShapeAndRange &out, bool &shape_changed, const ge::Operator &op); - -/** - * Replace one dim in a given shape - * @param s original shape - * @param dim_index_in dim index - * @param new_dim new dim value - * @param out new shape - * @return status whether this operation success - */ -graphStatus ReplaceDim(const Shape &s, int64_t dim_index_in, int64_t new_dim, Shape &out, const ge::Operator &op); - -/** - * Replace one dim in a given shape - * @param s original shape - * @param dim_index_in dim index - * @param new_dim new dim value - * @param out new shape - * @return status whether this operation success - */ -graphStatus ReplaceDim(const Shape &s, int64_t dim_index_in, int64_t new_dim, Shape &out, const ge::Operator &op); - -/** - * Check if it satisfies 0 <= index < limit - * @param index first input - * @param limit second input - * @return status whether this operation success - */ -template -bool FastBoundsCheck(const Ta index, const Tb limit); - -/** - * Add two dims - * @param dim0 first dim val - * @param dim1 second dim val - * @param out sum dim val - * @return status whether this operation success - */ -graphStatus Add(int64_t dim1, int64_t dim2, int64_t &out); - -/** - * Subtract two dims - * @param dim0 first dim val - * @param dim1 second dim val - * @param out Subtract dim val - * @return status whether this operation success - */ -graphStatus Subtract(int64_t dim1, int64_t dim2, int64_t &out, const ge::Operator &op); - -/** - * Get SubShape according to start end index and step size stride - * @param s input Shape - * @param start sub start index - * @param end sub end index - * @param stride sub step size - * @param out sub shape output - * @return status whether this operation success - */ -graphStatus SubShape(const Shape &s, int64_t start, int64_t end, int64_t stride, Shape &out, const ge::Operator &op); - -/** - * Get SubShape according to start end index and step size stride - * @param s input Shape - * @param start sub start index - * @param end sub end index - * @param stride sub step size - * @param out sub shape output - * @return status whether this operation success - */ -graphStatus SubShape(const Shape &s, size_t start, size_t end, size_t stride, Shape &out); - -/** - * Get SubShape according to start end index and step size stride - * @param s input Shape - * @param start sub start index - * @param end sub end index - * @param stride sub step size - * @param out sub shape output - * @return status whether this operation success - */ -graphStatus SubShape(const Shape &s, int64_t start, int64_t end, int64_t stride, Shape &out, const ge::Operator &op); - -/** - * Concatenate two shape - * @param s1 first shape - * @param s2 second shape - * @param out concatenated shape - * @return status whether this operation success - */ -graphStatus Concatenate(const Shape &s1, const Shape &s2, Shape &out); - -/** - * Concatenate two shape - * @param s1 first shape - * @param s2 second shape - * @param out concatenated shape - * @return status whether this operation success - */ -graphStatus Concatenate(const Shape &s1, const Shape &s2, Shape &out); - -/** - * Gen matrix shape according d1 and d2 - * @param dim1 first dim val - * @param dim2 first dim val - * @param out matrix shape - * @return status whether this operation success - */ -graphStatus Matrix(int64_t dim1, int64_t dim2, Shape &out); - -/** - * Gen vector shape according d - * @param dim dim val - * @param out vector shape - * @return status whether this operation success - */ -graphStatus Vector(int64_t dim, Shape &out); - -/** - * Make shape from shape tensor - * @param tensor shape tensor - * @param out shape - * @return status whether this operation success - */ -graphStatus MakeShapeFromShapeTensor(const Tensor &tensor, Shape &out, const ge::Operator &op); - -/** - * Make shape from shape tensor - * @param op Operator - * @param dst_name const string & - * @param out Shape - * @return status whether this operation success - */ -graphStatus MakeShapeFromShapeTensor(Operator &op, const string &dst_name, Shape &out); - -/** - * Make dim from scalar tensor - * @param tensor shape tensor - * @param out shape - * @return status whether this operation success - */ -graphStatus MakeDimForScalarInput(const Tensor &tensor, int64_t &out, const ge::Operator &op); - -/** - * Check whether Shape's rank is at most rank - * @param tensor input tensor - * @param rank expect val of Shape - * @param out output Shape - * @return status whether Shape's condition Satisfied - */ -graphStatus WithRankAtMost(const TensorDesc &tensor, int64_t rank, Shape &out, const ge::Operator &op); - -/** - * Check whether Shape's rank is at most rank - * @param tensor input tensor - * @param rank expect val of Shape - * @param out output Shape - * @return status whether Shape's condition Satisfied - */ -graphStatus WithRankAtMost(const TensorDesc &tensorDesc, int64_t rank, Shape &out_shape, const ge::Operator &op); - -/** - * make a empty dim shape - * @param out output Shape - * @return status whether Shape's condition Satisfied - */ -graphStatus Scalar(Shape &out); - -/** - * set input_name shape to output_name shape - * @param op Operator which need to infershape - * @param input_name input name of Operator - * @param output_name output name of Operator - * @return status whether infershape success - */ -graphStatus UnchangedShape(Operator &op, const string input_name, const string output_name); - -/** - * Divide dim - * @param dividend - * @param divisor - * @param evenlyDivisible if to be divisible - * @param out dims - * @return status whether this operation success - */ -graphStatus Divide(const int64_t dividend, const int64_t divisor, const bool evenlyDivisible, int64_t &out, - const ge::Operator &op); - -/** - * check shape fully defined or not - * @param shape Shape is checked - * @return whether shape is fully defined - */ -bool ShapeFullDefined(const Shape &shape); - -/** - * check shape fully defined or not - * @param shape Shape is checked - * @return whether shape is fully defined - */ -bool ShapeFullyDefined(const Shape &shape); - -/** - * check shape known or not - * @param shape Shape is checked - * @return whether rank is known - */ -bool RankKnown(const Shape &shape); - -/** - * check ge_shape known or not - * @param shape Shape is checked - * @return whether rank is known - */ -bool RankKnown(const Shape &shape); - -/** - * make a unknown shape with rank - * @return unknown shape - */ -Shape UnknownShapeOfRank(int64_t rank); - -/** - * check dim value known or not - * @param shape which Shape need check dim value - * @param dimIndex the index of dim - * @return whether dim value is known - */ -bool ValueKnown(const Shape &shape, const size_t &dim_index); - -/** - * Validates the 3 component tensors of a sparse tensor - * have the proper shapes. - * @param sparse indices shape - * @param sparse values shape - * @param sparse shape - * @return status whether this operation success - */ -graphStatus ValidateSparseTensor(const TensorDesc &indices, const TensorDesc &values, const TensorDesc &shape, - const ge::Operator &op); - -/** - * @brief get string from data type - * @param dtype data type - * @return string of data type - */ -std::string DTypeStr(DataType dtype); - -graphStatus SetShapeAndRange(Operator &op, const ShapeAndRange &feed_shape_and_range); - -graphStatus GetShapeAndRange(Operator &op, ShapeAndRange &out, bool &geted, InferenceContextPtr infer_context); - -} // namespace ge - -#endif // CUSTOMIZE_OP_PROTO_UTIL_COMMON_SHAPE_FNS_H_ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! + * \file common_shape_fns.h + * \brief + */ +#ifndef CUSTOMIZE_OP_PROTO_UTIL_COMMON_SHAPE_FNS_H_ +#define CUSTOMIZE_OP_PROTO_UTIL_COMMON_SHAPE_FNS_H_ + +#include +#include +#include "graph/tensor.h" +#include "graph/operator.h" +#include "graph/resource_context.h" + +namespace ge { + +struct ShapeAndRange { + Shape shape_; + std::vector> shape_range_; + DataType shape_type_; +}; + +struct AicpuResourceContext : public ResourceContext { + std::vector shape_and_range_; +}; + +/** + * Check whether Shape's rank is at least rank + * @param tensor Input tensor + * @param rank expect val of Shape + * @param out Output Shape + * @return status whether Shape's condition Satisfied + */ +graphStatus WithRankAtLeast(const TensorDesc &tensor, int64_t rank, Shape &out, const ge::Operator &op); + +/** + * Check whether Shape's rank is at least rank + * @param tensor Input tensor + * @param rank expect val of Shape + * @param out Output Shape + * @return status whether Shape's condition Satisfied + */ +graphStatus WithRankAtLeast(const TensorDesc &tensor, int64_t rank, Shape &out, const char *op_name); + +/** + * Check whether Shape's rank is equal to rank + * @param shape Input tensor shape + * @param rank expect shape rank + * @param out Output Shape + * @return status whether Shape's condition Satisfied + */ +graphStatus WithRankShape(Shape &shape, int64_t rank, const ge::Operator &op); + +/** + * Check whether Shape's rank is equal to rank + * @param tensor Input tensor + * @param rank expect val of Shape + * @param out Output Shape + * @return status whether Shape's condition Satisfied + */ +graphStatus WithRank(const TensorDesc &tensor, int64_t rank, Shape &out, const ge::Operator &op); + +/** + * Check whether dim is equal to value + * @param dim Input dim + * @param value expect val of dim + * @param out Output dim + * @return status whether Dim is equal to value + */ +graphStatus WithValue(int64_t dim, int64_t value, int64_t &out, const ge::Operator &op); + +/** + * Merge two shapes + * @param s0 first shape val + * @param prefix second shape val + * @param s_out merged shape val + * @param prefix_out prefix out shape val + * @return status whether this operation success + */ +graphStatus MergePrefix(const Shape &s, const Shape &prefix, Shape &s_out, Shape &prefix_out, const ge::Operator &op); + +/** + * Merge two dims of Shape + * @param dim0 first dim val + * @param dim1 second dim val + * @param out merged dim val + * @return status whether this operation success + */ +graphStatus Merge(int64_t dim1, int64_t dim2, int64_t &out); + +/** + * Merge two shapes + * @param s0 first shape val + * @param s1 second shape val + * @param out merged shape val + * @return status whether this operation success + */ +graphStatus Merge(const Shape &s0, const Shape &s1, Shape &out, const ge::Operator &op); + +/** + * Merge two shapes + * @param s0 first Geshape val + * @param s1 second Geshape val + * @param out merged Geshape val + * @return status whether this operation success + */ +graphStatus Merge(const Shape &s0, const Shape &s1, Shape &out, const ge::Operator &op); + +/** + * Merge two shapes + * @param shared_shape first Geshape val + * @param value_shape second Geshape val + * @param out merged shape val + * @param shape_changed status whether shape has changed + */ +void MergeShape(const Shape &shared_shape, const Shape &value_shape, std::vector &out, bool &shape_changed); + +/** + * Merge two shape ranges + * @param shared_shape_range first shape range val + * @param value_shape_range second shape range val + * @param out merged shape range val + * @param shape_changed status whether shape range has changed + */ +void MergeRange(const std::vector> &shared_shape_range, + const std::vector> &value_shape_range, + std::vector> &out, bool &shape_changed); + +/** + * Merge two shapes and ranges + * @param shared_shape_and_range first shape and range val + * @param value_shape_and_range second shape and range val + * @param out merged shape and range val + * @param shape_changed status whether shape and range has changed + * @return status whether this operation success + */ +graphStatus MergeShapeAndRange(const ShapeAndRange &shared_shape_and_range, const ShapeAndRange &value_shape_and_range, + ShapeAndRange &out, bool &shape_changed, const ge::Operator &op); + +/** + * Replace one dim in a given shape + * @param s original shape + * @param dim_index_in dim index + * @param new_dim new dim value + * @param out new shape + * @return status whether this operation success + */ +graphStatus ReplaceDim(const Shape &s, int64_t dim_index_in, int64_t new_dim, Shape &out, const ge::Operator &op); + +/** + * Replace one dim in a given shape + * @param s original shape + * @param dim_index_in dim index + * @param new_dim new dim value + * @param out new shape + * @return status whether this operation success + */ +graphStatus ReplaceDim(const Shape &s, int64_t dim_index_in, int64_t new_dim, Shape &out, const ge::Operator &op); + +/** + * Check if it satisfies 0 <= index < limit + * @param index first input + * @param limit second input + * @return status whether this operation success + */ +template +bool FastBoundsCheck(const Ta index, const Tb limit); + +/** + * Add two dims + * @param dim0 first dim val + * @param dim1 second dim val + * @param out sum dim val + * @return status whether this operation success + */ +graphStatus Add(int64_t dim1, int64_t dim2, int64_t &out); + +/** + * Subtract two dims + * @param dim0 first dim val + * @param dim1 second dim val + * @param out Subtract dim val + * @return status whether this operation success + */ +graphStatus Subtract(int64_t dim1, int64_t dim2, int64_t &out, const ge::Operator &op); + +/** + * Get SubShape according to start end index and step size stride + * @param s input Shape + * @param start sub start index + * @param end sub end index + * @param stride sub step size + * @param out sub shape output + * @return status whether this operation success + */ +graphStatus SubShape(const Shape &s, int64_t start, int64_t end, int64_t stride, Shape &out, const ge::Operator &op); + +/** + * Get SubShape according to start end index and step size stride + * @param s input Shape + * @param start sub start index + * @param end sub end index + * @param stride sub step size + * @param out sub shape output + * @return status whether this operation success + */ +graphStatus SubShape(const Shape &s, size_t start, size_t end, size_t stride, Shape &out); + +/** + * Get SubShape according to start end index and step size stride + * @param s input Shape + * @param start sub start index + * @param end sub end index + * @param stride sub step size + * @param out sub shape output + * @return status whether this operation success + */ +graphStatus SubShape(const Shape &s, int64_t start, int64_t end, int64_t stride, Shape &out, const ge::Operator &op); + +/** + * Concatenate two shape + * @param s1 first shape + * @param s2 second shape + * @param out concatenated shape + * @return status whether this operation success + */ +graphStatus Concatenate(const Shape &s1, const Shape &s2, Shape &out); + +/** + * Concatenate two shape + * @param s1 first shape + * @param s2 second shape + * @param out concatenated shape + * @return status whether this operation success + */ +graphStatus Concatenate(const Shape &s1, const Shape &s2, Shape &out); + +/** + * Gen matrix shape according d1 and d2 + * @param dim1 first dim val + * @param dim2 first dim val + * @param out matrix shape + * @return status whether this operation success + */ +graphStatus Matrix(int64_t dim1, int64_t dim2, Shape &out); + +/** + * Gen vector shape according d + * @param dim dim val + * @param out vector shape + * @return status whether this operation success + */ +graphStatus Vector(int64_t dim, Shape &out); + +/** + * Make shape from shape tensor + * @param tensor shape tensor + * @param out shape + * @return status whether this operation success + */ +graphStatus MakeShapeFromShapeTensor(const Tensor &tensor, Shape &out, const ge::Operator &op); + +/** + * Make shape from shape tensor + * @param op Operator + * @param dst_name const string & + * @param out Shape + * @return status whether this operation success + */ +graphStatus MakeShapeFromShapeTensor(Operator &op, const string &dst_name, Shape &out); + +/** + * Make dim from scalar tensor + * @param tensor shape tensor + * @param out shape + * @return status whether this operation success + */ +graphStatus MakeDimForScalarInput(const Tensor &tensor, int64_t &out, const ge::Operator &op); + +/** + * Check whether Shape's rank is at most rank + * @param tensor input tensor + * @param rank expect val of Shape + * @param out output Shape + * @return status whether Shape's condition Satisfied + */ +graphStatus WithRankAtMost(const TensorDesc &tensor, int64_t rank, Shape &out, const ge::Operator &op); + +/** + * Check whether Shape's rank is at most rank + * @param tensor input tensor + * @param rank expect val of Shape + * @param out output Shape + * @return status whether Shape's condition Satisfied + */ +graphStatus WithRankAtMost(const TensorDesc &tensorDesc, int64_t rank, Shape &out_shape, const ge::Operator &op); + +/** + * make a empty dim shape + * @param out output Shape + * @return status whether Shape's condition Satisfied + */ +graphStatus Scalar(Shape &out); + +/** + * set input_name shape to output_name shape + * @param op Operator which need to infershape + * @param input_name input name of Operator + * @param output_name output name of Operator + * @return status whether infershape success + */ +graphStatus UnchangedShape(Operator &op, const string input_name, const string output_name); + +/** + * Divide dim + * @param dividend + * @param divisor + * @param evenlyDivisible if to be divisible + * @param out dims + * @return status whether this operation success + */ +graphStatus Divide(const int64_t dividend, const int64_t divisor, const bool evenlyDivisible, int64_t &out, + const ge::Operator &op); + +/** + * check shape fully defined or not + * @param shape Shape is checked + * @return whether shape is fully defined + */ +bool ShapeFullDefined(const Shape &shape); + +/** + * check shape fully defined or not + * @param shape Shape is checked + * @return whether shape is fully defined + */ +bool ShapeFullyDefined(const Shape &shape); + +/** + * check shape known or not + * @param shape Shape is checked + * @return whether rank is known + */ +bool RankKnown(const Shape &shape); + +/** + * check ge_shape known or not + * @param shape Shape is checked + * @return whether rank is known + */ +bool RankKnown(const Shape &shape); + +/** + * make a unknown shape with rank + * @return unknown shape + */ +Shape UnknownShapeOfRank(int64_t rank); + +/** + * check dim value known or not + * @param shape which Shape need check dim value + * @param dimIndex the index of dim + * @return whether dim value is known + */ +bool ValueKnown(const Shape &shape, const size_t &dim_index); + +/** + * Validates the 3 component tensors of a sparse tensor + * have the proper shapes. + * @param sparse indices shape + * @param sparse values shape + * @param sparse shape + * @return status whether this operation success + */ +graphStatus ValidateSparseTensor(const TensorDesc &indices, const TensorDesc &values, const TensorDesc &shape, + const ge::Operator &op); + +/** + * @brief get string from data type + * @param dtype data type + * @return string of data type + */ +std::string DTypeStr(DataType dtype); + +graphStatus SetShapeAndRange(Operator &op, const ShapeAndRange &feed_shape_and_range); + +graphStatus GetShapeAndRange(Operator &op, ShapeAndRange &out, bool &geted, InferenceContextPtr infer_context); + +} // namespace ge + +#endif // CUSTOMIZE_OP_PROTO_UTIL_COMMON_SHAPE_FNS_H_ diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/dvm/prebuild/aarch64/libdvm.a b/mindspore/ccsrc/plugin/device/ascend/kernel/dvm/prebuild/aarch64/libdvm.a deleted file mode 100644 index 470e569e7b8..00000000000 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/dvm/prebuild/aarch64/libdvm.a +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:682b2ee69405891c45de102bde19fa09268d60a81967bf71eb7f8f60e9af828b -size 772502 diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/dvm/prebuild/x86_64/libdvm.a b/mindspore/ccsrc/plugin/device/ascend/kernel/dvm/prebuild/x86_64/libdvm.a deleted file mode 100644 index d69ada74ee8..00000000000 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/dvm/prebuild/x86_64/libdvm.a +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:f738a1d8bdd2e362f7b126b58e1cbbad43f4d294123a46fcd58d14c991c32bf3 -size 769134 diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/hccl/hccl_kernel_metadata.h b/mindspore/ccsrc/plugin/device/ascend/kernel/hccl/hccl_kernel_metadata.h old mode 100755 new mode 100644 diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/host/dynamic_shape_kernel.h b/mindspore/ccsrc/plugin/device/ascend/kernel/host/dynamic_shape_kernel.h index d3603c83c69..26a6f7b3cd2 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/host/dynamic_shape_kernel.h +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/host/dynamic_shape_kernel.h @@ -1,41 +1,41 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_DYNAMIC_SHAPE_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_DYNAMIC_SHAPE_KERNEL_H_ -#include -#include -#include -#include "plugin/device/ascend/kernel/host/host_kernel_mod.h" -#include "kernel/kernel.h" -namespace mindspore { -namespace kernel { -class TensorShapeKernelMod : public HostKernelMod { - public: - TensorShapeKernelMod() = default; - ~TensorShapeKernelMod() override = default; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - - private: - void Execute(const std::vector &inputs, const std::vector &outputs, - void *stream_ptr) const; -}; -MS_HOST_REG_KERNEL(DynamicShape, TensorShapeKernelMod); -MS_HOST_REG_KERNEL(TensorShape, TensorShapeKernelMod); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_DYNAMIC_SHAPE_KERNEL_H_ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_DYNAMIC_SHAPE_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_DYNAMIC_SHAPE_KERNEL_H_ +#include +#include +#include +#include "plugin/device/ascend/kernel/host/host_kernel_mod.h" +#include "kernel/kernel.h" +namespace mindspore { +namespace kernel { +class TensorShapeKernelMod : public HostKernelMod { + public: + TensorShapeKernelMod() = default; + ~TensorShapeKernelMod() override = default; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + + private: + void Execute(const std::vector &inputs, const std::vector &outputs, + void *stream_ptr) const; +}; +MS_HOST_REG_KERNEL(DynamicShape, TensorShapeKernelMod); +MS_HOST_REG_KERNEL(TensorShape, TensorShapeKernelMod); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_DYNAMIC_SHAPE_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/host/host_kernel_build.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/host/host_kernel_build.cc index c03e09512b3..f48b511b303 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/host/host_kernel_build.cc +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/host/host_kernel_build.cc @@ -1,55 +1,55 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "plugin/device/ascend/kernel/host/host_kernel_build.h" -#include -#include "plugin/device/ascend/kernel/host/host_kernel_mod.h" -#include "include/common/utils/anfalgo.h" -#include "include/backend/anf_runtime_algorithm.h" -#include "utils/log_adapter.h" -#include "kernel/framework_utils.h" -#include "utils/trace_base.h" - -namespace mindspore { -namespace kernel { -KernelModPtr HostOpBuild(const std::shared_ptr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - auto prim = common::AnfAlgo::GetCNodePrimitive(anf_node); - MS_LOG(INFO) << "Build host op [" << prim->name() << "]"; - - auto kernel_mod_ptr = HostKernelFactory::Get(prim->name()); - if (kernel_mod_ptr == nullptr) { - MS_LOG(ERROR) << "Host can't find Kernel[" << prim->name() << "]"; - return nullptr; - } - - std::vector input_kernel_tensors = AnfAlgo::GetOrCreateAllInputKernelTensors(anf_node); - std::vector output_kernel_tensors = AnfAlgo::GetOrCreateAllOutputKernelTensors(anf_node); - if (!std::static_pointer_cast(kernel_mod_ptr)->Init(prim, input_kernel_tensors, output_kernel_tensors)) { - MS_LOG_WITH_NODE(EXCEPTION, anf_node) - << "#dmsg#Kernel build failed:#dmsg#Initialize host kernel op[" << anf_node->fullname_with_scope() << "] failed." - << trace::DumpSourceLines(anf_node); - } - - auto cnode = anf_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (kernel::CheckResizeCondition(cnode)) { - kernel_mod_ptr->Resize(input_kernel_tensors, output_kernel_tensors); - } - - return kernel_mod_ptr; -} -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "plugin/device/ascend/kernel/host/host_kernel_build.h" +#include +#include "plugin/device/ascend/kernel/host/host_kernel_mod.h" +#include "include/common/utils/anfalgo.h" +#include "include/backend/anf_runtime_algorithm.h" +#include "utils/log_adapter.h" +#include "kernel/framework_utils.h" +#include "utils/trace_base.h" + +namespace mindspore { +namespace kernel { +KernelModPtr HostOpBuild(const std::shared_ptr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + auto prim = common::AnfAlgo::GetCNodePrimitive(anf_node); + MS_LOG(INFO) << "Build host op [" << prim->name() << "]"; + + auto kernel_mod_ptr = HostKernelFactory::Get(prim->name()); + if (kernel_mod_ptr == nullptr) { + MS_LOG(ERROR) << "Host can't find Kernel[" << prim->name() << "]"; + return nullptr; + } + + std::vector input_kernel_tensors = AnfAlgo::GetOrCreateAllInputKernelTensors(anf_node); + std::vector output_kernel_tensors = AnfAlgo::GetOrCreateAllOutputKernelTensors(anf_node); + if (!std::static_pointer_cast(kernel_mod_ptr)->Init(prim, input_kernel_tensors, output_kernel_tensors)) { + MS_LOG_WITH_NODE(EXCEPTION, anf_node) + << "#dmsg#Kernel build failed:#dmsg#Initialize host kernel op[" << anf_node->fullname_with_scope() << "] failed." + << trace::DumpSourceLines(anf_node); + } + + auto cnode = anf_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (kernel::CheckResizeCondition(cnode)) { + kernel_mod_ptr->Resize(input_kernel_tensors, output_kernel_tensors); + } + + return kernel_mod_ptr; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/host/host_kernel_build.h b/mindspore/ccsrc/plugin/device/ascend/kernel/host/host_kernel_build.h index 5471d56b390..d61a966f320 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/host/host_kernel_build.h +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/host/host_kernel_build.h @@ -1,27 +1,27 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_BUILD_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_BUILD_H_ -#include -#include "kernel/kernel.h" - -namespace mindspore { -namespace kernel { -KernelModPtr HostOpBuild(const std::shared_ptr &anf_node); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_BUILD_H_ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_BUILD_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_BUILD_H_ +#include +#include "kernel/kernel.h" + +namespace mindspore { +namespace kernel { +KernelModPtr HostOpBuild(const std::shared_ptr &anf_node); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_BUILD_H_ diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/host/host_kernel_metadata.h b/mindspore/ccsrc/plugin/device/ascend/kernel/host/host_kernel_metadata.h index 4f9f13c0c71..9717007950d 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/host/host_kernel_metadata.h +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/host/host_kernel_metadata.h @@ -1,29 +1,29 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_META_DATA_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_META_DATA_H_ - -#include -#include -#include "kernel/kernel_build_info.h" - -namespace mindspore { -namespace kernel { -void HostMetadataInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list); -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_META_DATA_H_ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_META_DATA_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_META_DATA_H_ + +#include +#include +#include "kernel/kernel_build_info.h" + +namespace mindspore { +namespace kernel { +void HostMetadataInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list); +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_META_DATA_H_ diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/host/reshape_kernel.h b/mindspore/ccsrc/plugin/device/ascend/kernel/host/reshape_kernel.h index ca936391931..8a3e26636a8 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/host/reshape_kernel.h +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/host/reshape_kernel.h @@ -1,35 +1,35 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_RESHAPE_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_RESHAPE_KERNEL_H_ -#include -#include -#include -#include "plugin/device/ascend/kernel/host/host_kernel_mod.h" -namespace mindspore { -namespace kernel { -class ReshapeKernelMod : public HostKernelMod { - public: - ReshapeKernelMod() = default; - ~ReshapeKernelMod() override = default; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; -}; -MS_HOST_REG_KERNEL(Reshape, ReshapeKernelMod); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_RESHAPE_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_RESHAPE_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_RESHAPE_KERNEL_H_ +#include +#include +#include +#include "plugin/device/ascend/kernel/host/host_kernel_mod.h" +namespace mindspore { +namespace kernel { +class ReshapeKernelMod : public HostKernelMod { + public: + ReshapeKernelMod() = default; + ~ReshapeKernelMod() override = default; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; +}; +MS_HOST_REG_KERNEL(Reshape, ReshapeKernelMod); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_RESHAPE_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/aarch64/ms_kernels_internal.tar.gz b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/aarch64/ms_kernels_internal.tar.gz deleted file mode 100644 index 97fb224714a..00000000000 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/aarch64/ms_kernels_internal.tar.gz +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:d83a31705f685a9ea4f7d2c0ff4bc2263981813c2d02b3e8b0a1d8e4bc8f1be8 -size 93097904 diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/x86_64/ms_kernels_internal.tar.gz b/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/x86_64/ms_kernels_internal.tar.gz deleted file mode 100644 index 541842a0499..00000000000 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/x86_64/ms_kernels_internal.tar.gz +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:b59c4333d90b05ab6f752169bb8b47e465098722567c2e49cd329d4a37fc8436 -size 93094496 diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/pyboost/customize/non_zero.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/pyboost/customize/non_zero.cc old mode 100755 new mode 100644 diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/pyboost/customize/non_zero.h b/mindspore/ccsrc/plugin/device/ascend/kernel/pyboost/customize/non_zero.h old mode 100755 new mode 100644 diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/pyboost/customize/non_zero_ext.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/pyboost/customize/non_zero_ext.cc old mode 100755 new mode 100644 diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/pyboost/customize/non_zero_ext.h b/mindspore/ccsrc/plugin/device/ascend/kernel/pyboost/customize/non_zero_ext.h old mode 100755 new mode 100644 diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/enhancer/insert_depend_for_all_gather_output.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/enhancer/insert_depend_for_all_gather_output.cc index 5ef99e800bc..d4379a0565b 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/enhancer/insert_depend_for_all_gather_output.cc +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/enhancer/insert_depend_for_all_gather_output.cc @@ -1,209 +1,209 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "plugin/device/ascend/optimizer/enhancer/insert_depend_for_all_gather_output.h" -#include "include/common/utils/utils.h" -#include "include/backend/anf_runtime_algorithm.h" -#include "include/common/utils/anfalgo.h" -#include "include/common/utils/parallel_context.h" -#include "ops/structure_op_name.h" -#include "ops/framework_op_name.h" -#include "ops/framework_ops.h" - -namespace mindspore { -namespace opt { - -void InsertDependForAllGatherOutput::InsertDepend(const AnfNodePtr &prior_node, const AnfNodePtr &post_node, - const FuncGraphPtr &root) const { - MS_EXCEPTION_IF_NULL(prior_node); - MS_EXCEPTION_IF_NULL(post_node); - auto post_cnode = post_node->cast(); - auto manager = root->manager(); - std::vector depend_input = {NewValueNode(prim::kPrimDepend), post_cnode->input(1), prior_node}; - auto depend_node = root->NewCNode(depend_input); - manager->SetEdge(post_node, 1, depend_node); -} - -int64_t InsertDependForAllGatherOutput::DealSegment(const std::vector &node_list) { - int64_t seg_max = -1; - for (auto &node : node_list) { - if (!node->isa()) { - continue; - } - auto cnode = node->cast(); - // get forward segment first recv - if (!cnode->HasPrimalAttr(kPrimalAttrForwardNodeName) && cnode->HasPrimalAttr(kAttrSegment) && - cnode->HasPrimalAttr(kAttrMicro) && GetValue(cnode->GetPrimalAttr(kAttrMicro)) == 0 && - cnode->HasPrimalAttr("pipeline_begin")) { - forward_each_seg_first_recv_[GetValue(cnode->GetPrimalAttr(kAttrSegment))].push_back(node); - MS_LOG(INFO) << "Forward pipeline begin op is: " << node->fullname_with_scope() - << ", segment info: " << GetValue(cnode->GetPrimalAttr(kAttrSegment)); - } - // get max segment - if (cnode->HasPrimalAttr(kAttrSegment)) { - auto segment_info = GetValue(cnode->GetPrimalAttr(kAttrSegment)); - seg_max = std::max(seg_max, segment_info); - } - } - return seg_max; -} - -bool InsertDependForAllGatherOutput::IsLastSegWithRecv(int64_t seg_max, std::shared_ptr cnode) { - return common::AnfAlgo::GetCNodeName(cnode) == kReceiveOpName && cnode->HasPrimalAttr(kAttrSegment) && - GetValue(cnode->GetPrimalAttr(kAttrSegment)) == seg_max && - !cnode->HasPrimalAttr(kPrimalAttrForwardNodeName); -} - -bool InsertDependForAllGatherOutput::IsGatherNode(std::shared_ptr cnode, bool is_recompute) { - return common::AnfAlgo::GetCNodeName(cnode) == kAllGatherOpName && common::AnfAlgo::HasNodeAttr(kAttrFusion, cnode) && - common::AnfAlgo::HasNodeAttr(kAttrSegment, cnode) && - common::AnfAlgo::GetNodeAttr(cnode, kAttrFusion) > 0 && !is_recompute; -} - -bool InsertDependForAllGatherOutput::IsRedistriuteAllGatherNode(int64_t seg_max, std::shared_ptr cnode) { - return common::AnfAlgo::GetCNodeName(cnode) == kAllGatherOpName && - cnode->fullname_with_scope().find("head-PanGuHead") != std::string::npos && - cnode->HasPrimalAttr(kAttrSegment) && GetValue(cnode->GetPrimalAttr(kAttrSegment)) == seg_max; -} - -void InsertDependForAllGatherOutput::GetEachSegSend(const FuncGraphPtr &graph, const std::vector &node_list, - int64_t seg_max) { - for (auto &node : node_list) { - MS_EXCEPTION_IF_NULL(node); - if (!node->cast() || !AnfUtils::IsRealKernel(node)) { - continue; - } - auto cnode = node->cast(); - if (common::AnfAlgo::GetCNodeName(cnode) == kSendOpName && cnode->HasPrimalAttr("pipeline_param") && - !cnode->HasPrimalAttr(kPrimalAttrForwardNodeName)) { - pipeline_param_send_ = node; - MS_LOG(INFO) << "Pipeline_param_send_ is: " << pipeline_param_send_->fullname_with_scope(); - } - if (IsLastSegWithRecv(seg_max, cnode)) { - auto micro_info = GetValue(cnode->GetPrimalAttr(kAttrMicro)); - forward_last_seg_each_micro_recv_[micro_info] = node; - } - bool is_recompute = cnode->GetAttr(kAttrDuplicated) != nullptr && GetValue(cnode->GetAttr(kAttrDuplicated)); - if (IsGatherNode(cnode, is_recompute)) { - all_gather_node_[common::AnfAlgo::GetNodeAttr(cnode, kAttrFusion)] = node; - } - if (IsRedistriuteAllGatherNode(seg_max, cnode)) { - redistribution_all_gather_node_.push_back(cnode->input(1)); - } - if (common::AnfAlgo::GetCNodeName(cnode) == kGetNextOpName) { - auto node_users = graph->manager()->node_users()[node]; - auto node_pair = node_users.begin(); - for (size_t j = 0; j < node_users.size(); ++j) { - auto current_node = node_pair->first; - auto current_node_users = graph->manager()->node_users()[current_node]; - auto current_node_pair = current_node_users.begin(); - for (size_t k = 0; k < current_node_users.size(); ++k) { - auto current_node_user_node = current_node_pair->first; - get_next_tuplegetitem_node_.push_back(current_node_user_node); - } - node_pair++; - } - } - } -} - -void InsertDependForAllGatherOutput::ReorderGetnext(const FuncGraphPtr &graph, bool *changed) { - if (pipeline_param_send_ != nullptr) { - for (size_t i = 0; i < get_next_tuplegetitem_node_.size(); ++i) { - auto current_node = get_next_tuplegetitem_node_[i]; - MS_LOG(INFO) << "Insert depend for getnext tuplegetitem before first allgather op " - << pipeline_param_send_->fullname_with_scope(); - InsertDepend(current_node, all_gather_node_.begin()->second, graph); - InsertDepend(current_node, pipeline_param_send_, graph); - *changed = true; - } - } - - auto iter = all_gather_node_.end(); - iter--; - if (forward_each_seg_first_recv_.find(0) != forward_each_seg_first_recv_.end()) { - for (auto &node : forward_each_seg_first_recv_[0]) { - MS_LOG(INFO) << "Insert depend last allgather node before recv op " << node->fullname_with_scope(); - InsertDepend(iter->second, node, graph); - } - } -} - -void InsertDependForAllGatherOutput::IsChanged(const FuncGraphPtr &graph, AnfNodePtr node, int64_t segment_info, - bool *changed) { - if (segment_info != 0) { - auto node_users = graph->manager()->node_users()[node]; - auto current_node_pair = node_users.begin(); - for (auto &forward_node : forward_each_seg_first_recv_[segment_info]) { - MS_LOG(INFO) << "Insert depend for tuplegetitem after recv op " << forward_node->fullname_with_scope(); - InsertDepend(forward_node, current_node_pair->first, graph); - *changed = true; - } - for (size_t j = 0; j < node_users.size() - 1; ++j) { - auto current_node = current_node_pair->first; - auto next_node = (++current_node_pair)->first; - MS_LOG(INFO) << "Current_node " << current_node->fullname_with_scope() << ", next_node " - << next_node->fullname_with_scope(); - InsertDepend(current_node, next_node, graph); - *changed = true; - } - } -} - -bool InsertDependForAllGatherOutput::Run(const FuncGraphPtr &graph) { - MS_EXCEPTION_IF_NULL(graph); - bool changed = false; - auto parallel_context = parallel::ParallelContext::GetInstance(); - if (!parallel_context->enable_fold_pipeline()) { - return changed; - } - std::vector node_list = TopoSort(graph->get_return()); - - // find each seg last send, seg_max - int64_t seg_max = DealSegment(node_list); - - GetEachSegSend(graph, node_list, seg_max); - - if (!forward_each_seg_first_recv_.empty()) { - for (auto &node_pair : all_gather_node_) { - auto node = node_pair.second; - auto segment_info = common::AnfAlgo::GetNodeAttr(node, kAttrSegment); - MS_LOG(INFO) << "Node " << node->fullname_with_scope() << ", segment info: " << segment_info; - IsChanged(graph, node, segment_info, &changed); - } - } - - for (size_t i = 0; i < redistribution_all_gather_node_.size(); ++i) { - auto current_node = redistribution_all_gather_node_[i]; - auto micro_info = GetValue(current_node->cast()->GetPrimalAttr(kAttrMicro)); - MS_LOG(INFO) << "Current_node " << current_node->fullname_with_scope() << ", micro_info " << micro_info; - InsertDepend(forward_last_seg_each_micro_recv_[micro_info], current_node, graph); - changed = true; - } - - if (all_gather_node_.empty()) { - return changed; - } - - // reorder getnext tensormove to before pipeline param send - ReorderGetnext(graph, &changed); - - return changed; -} - -} // namespace opt -} // namespace mindspore +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "plugin/device/ascend/optimizer/enhancer/insert_depend_for_all_gather_output.h" +#include "include/common/utils/utils.h" +#include "include/backend/anf_runtime_algorithm.h" +#include "include/common/utils/anfalgo.h" +#include "include/common/utils/parallel_context.h" +#include "ops/structure_op_name.h" +#include "ops/framework_op_name.h" +#include "ops/framework_ops.h" + +namespace mindspore { +namespace opt { + +void InsertDependForAllGatherOutput::InsertDepend(const AnfNodePtr &prior_node, const AnfNodePtr &post_node, + const FuncGraphPtr &root) const { + MS_EXCEPTION_IF_NULL(prior_node); + MS_EXCEPTION_IF_NULL(post_node); + auto post_cnode = post_node->cast(); + auto manager = root->manager(); + std::vector depend_input = {NewValueNode(prim::kPrimDepend), post_cnode->input(1), prior_node}; + auto depend_node = root->NewCNode(depend_input); + manager->SetEdge(post_node, 1, depend_node); +} + +int64_t InsertDependForAllGatherOutput::DealSegment(const std::vector &node_list) { + int64_t seg_max = -1; + for (auto &node : node_list) { + if (!node->isa()) { + continue; + } + auto cnode = node->cast(); + // get forward segment first recv + if (!cnode->HasPrimalAttr(kPrimalAttrForwardNodeName) && cnode->HasPrimalAttr(kAttrSegment) && + cnode->HasPrimalAttr(kAttrMicro) && GetValue(cnode->GetPrimalAttr(kAttrMicro)) == 0 && + cnode->HasPrimalAttr("pipeline_begin")) { + forward_each_seg_first_recv_[GetValue(cnode->GetPrimalAttr(kAttrSegment))].push_back(node); + MS_LOG(INFO) << "Forward pipeline begin op is: " << node->fullname_with_scope() + << ", segment info: " << GetValue(cnode->GetPrimalAttr(kAttrSegment)); + } + // get max segment + if (cnode->HasPrimalAttr(kAttrSegment)) { + auto segment_info = GetValue(cnode->GetPrimalAttr(kAttrSegment)); + seg_max = std::max(seg_max, segment_info); + } + } + return seg_max; +} + +bool InsertDependForAllGatherOutput::IsLastSegWithRecv(int64_t seg_max, std::shared_ptr cnode) { + return common::AnfAlgo::GetCNodeName(cnode) == kReceiveOpName && cnode->HasPrimalAttr(kAttrSegment) && + GetValue(cnode->GetPrimalAttr(kAttrSegment)) == seg_max && + !cnode->HasPrimalAttr(kPrimalAttrForwardNodeName); +} + +bool InsertDependForAllGatherOutput::IsGatherNode(std::shared_ptr cnode, bool is_recompute) { + return common::AnfAlgo::GetCNodeName(cnode) == kAllGatherOpName && common::AnfAlgo::HasNodeAttr(kAttrFusion, cnode) && + common::AnfAlgo::HasNodeAttr(kAttrSegment, cnode) && + common::AnfAlgo::GetNodeAttr(cnode, kAttrFusion) > 0 && !is_recompute; +} + +bool InsertDependForAllGatherOutput::IsRedistriuteAllGatherNode(int64_t seg_max, std::shared_ptr cnode) { + return common::AnfAlgo::GetCNodeName(cnode) == kAllGatherOpName && + cnode->fullname_with_scope().find("head-PanGuHead") != std::string::npos && + cnode->HasPrimalAttr(kAttrSegment) && GetValue(cnode->GetPrimalAttr(kAttrSegment)) == seg_max; +} + +void InsertDependForAllGatherOutput::GetEachSegSend(const FuncGraphPtr &graph, const std::vector &node_list, + int64_t seg_max) { + for (auto &node : node_list) { + MS_EXCEPTION_IF_NULL(node); + if (!node->cast() || !AnfUtils::IsRealKernel(node)) { + continue; + } + auto cnode = node->cast(); + if (common::AnfAlgo::GetCNodeName(cnode) == kSendOpName && cnode->HasPrimalAttr("pipeline_param") && + !cnode->HasPrimalAttr(kPrimalAttrForwardNodeName)) { + pipeline_param_send_ = node; + MS_LOG(INFO) << "Pipeline_param_send_ is: " << pipeline_param_send_->fullname_with_scope(); + } + if (IsLastSegWithRecv(seg_max, cnode)) { + auto micro_info = GetValue(cnode->GetPrimalAttr(kAttrMicro)); + forward_last_seg_each_micro_recv_[micro_info] = node; + } + bool is_recompute = cnode->GetAttr(kAttrDuplicated) != nullptr && GetValue(cnode->GetAttr(kAttrDuplicated)); + if (IsGatherNode(cnode, is_recompute)) { + all_gather_node_[common::AnfAlgo::GetNodeAttr(cnode, kAttrFusion)] = node; + } + if (IsRedistriuteAllGatherNode(seg_max, cnode)) { + redistribution_all_gather_node_.push_back(cnode->input(1)); + } + if (common::AnfAlgo::GetCNodeName(cnode) == kGetNextOpName) { + auto node_users = graph->manager()->node_users()[node]; + auto node_pair = node_users.begin(); + for (size_t j = 0; j < node_users.size(); ++j) { + auto current_node = node_pair->first; + auto current_node_users = graph->manager()->node_users()[current_node]; + auto current_node_pair = current_node_users.begin(); + for (size_t k = 0; k < current_node_users.size(); ++k) { + auto current_node_user_node = current_node_pair->first; + get_next_tuplegetitem_node_.push_back(current_node_user_node); + } + node_pair++; + } + } + } +} + +void InsertDependForAllGatherOutput::ReorderGetnext(const FuncGraphPtr &graph, bool *changed) { + if (pipeline_param_send_ != nullptr) { + for (size_t i = 0; i < get_next_tuplegetitem_node_.size(); ++i) { + auto current_node = get_next_tuplegetitem_node_[i]; + MS_LOG(INFO) << "Insert depend for getnext tuplegetitem before first allgather op " + << pipeline_param_send_->fullname_with_scope(); + InsertDepend(current_node, all_gather_node_.begin()->second, graph); + InsertDepend(current_node, pipeline_param_send_, graph); + *changed = true; + } + } + + auto iter = all_gather_node_.end(); + iter--; + if (forward_each_seg_first_recv_.find(0) != forward_each_seg_first_recv_.end()) { + for (auto &node : forward_each_seg_first_recv_[0]) { + MS_LOG(INFO) << "Insert depend last allgather node before recv op " << node->fullname_with_scope(); + InsertDepend(iter->second, node, graph); + } + } +} + +void InsertDependForAllGatherOutput::IsChanged(const FuncGraphPtr &graph, AnfNodePtr node, int64_t segment_info, + bool *changed) { + if (segment_info != 0) { + auto node_users = graph->manager()->node_users()[node]; + auto current_node_pair = node_users.begin(); + for (auto &forward_node : forward_each_seg_first_recv_[segment_info]) { + MS_LOG(INFO) << "Insert depend for tuplegetitem after recv op " << forward_node->fullname_with_scope(); + InsertDepend(forward_node, current_node_pair->first, graph); + *changed = true; + } + for (size_t j = 0; j < node_users.size() - 1; ++j) { + auto current_node = current_node_pair->first; + auto next_node = (++current_node_pair)->first; + MS_LOG(INFO) << "Current_node " << current_node->fullname_with_scope() << ", next_node " + << next_node->fullname_with_scope(); + InsertDepend(current_node, next_node, graph); + *changed = true; + } + } +} + +bool InsertDependForAllGatherOutput::Run(const FuncGraphPtr &graph) { + MS_EXCEPTION_IF_NULL(graph); + bool changed = false; + auto parallel_context = parallel::ParallelContext::GetInstance(); + if (!parallel_context->enable_fold_pipeline()) { + return changed; + } + std::vector node_list = TopoSort(graph->get_return()); + + // find each seg last send, seg_max + int64_t seg_max = DealSegment(node_list); + + GetEachSegSend(graph, node_list, seg_max); + + if (!forward_each_seg_first_recv_.empty()) { + for (auto &node_pair : all_gather_node_) { + auto node = node_pair.second; + auto segment_info = common::AnfAlgo::GetNodeAttr(node, kAttrSegment); + MS_LOG(INFO) << "Node " << node->fullname_with_scope() << ", segment info: " << segment_info; + IsChanged(graph, node, segment_info, &changed); + } + } + + for (size_t i = 0; i < redistribution_all_gather_node_.size(); ++i) { + auto current_node = redistribution_all_gather_node_[i]; + auto micro_info = GetValue(current_node->cast()->GetPrimalAttr(kAttrMicro)); + MS_LOG(INFO) << "Current_node " << current_node->fullname_with_scope() << ", micro_info " << micro_info; + InsertDepend(forward_last_seg_each_micro_recv_[micro_info], current_node, graph); + changed = true; + } + + if (all_gather_node_.empty()) { + return changed; + } + + // reorder getnext tensormove to before pipeline param send + ReorderGetnext(graph, &changed); + + return changed; +} + +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/enhancer/insert_depend_for_all_gather_output.h b/mindspore/ccsrc/plugin/device/ascend/optimizer/enhancer/insert_depend_for_all_gather_output.h index d164b437885..403128d71d5 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/enhancer/insert_depend_for_all_gather_output.h +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/enhancer/insert_depend_for_all_gather_output.h @@ -1,59 +1,59 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_INSERT_DEPEND_FOR_ALL_GATHER_OUTPUT_H_ -#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_INSERT_DEPEND_FOR_ALL_GATHER_OUTPUT_H_ -#include -#include -#include -#include - -#include "include/backend/optimizer/pass.h" -#include "ir/func_graph.h" -#include "ir/anf.h" -#include "include/backend/optimizer/helper.h" -#include "include/backend/optimizer/optimizer.h" -#include "plugin/device/ascend/optimizer/ascend_helper.h" - -namespace mindspore { -namespace opt { -class InsertDependForAllGatherOutput : public Pass { - public: - InsertDependForAllGatherOutput() - : Pass("insert_depend_for_all_gather_output"), kernel_select_(std::make_shared()) {} - ~InsertDependForAllGatherOutput() override = default; - bool Run(const FuncGraphPtr &graph) override; - - private: - std::map> forward_each_seg_first_recv_; - std::vector redistribution_all_gather_node_; - std::vector get_next_tuplegetitem_node_; - std::map all_gather_node_; - std::map forward_last_seg_each_micro_recv_; - AnfNodePtr pipeline_param_send_ = nullptr; - KernelSelectPtr kernel_select_; - void InsertDepend(const AnfNodePtr &prior_node, const AnfNodePtr &post_node, const FuncGraphPtr &root) const; - int64_t DealSegment(const std::vector &node_list); - void ReorderGetnext(const FuncGraphPtr &graph, bool *changed); - bool IsLastSegWithRecv(int64_t seg_max, std::shared_ptr cnode); - bool IsGatherNode(std::shared_ptr cnode, bool is_recompute); - bool IsRedistriuteAllGatherNode(int64_t seg_max, std::shared_ptr cnode); - void GetEachSegSend(const FuncGraphPtr &graph, const std::vector &node_list, int64_t seg_max); - void IsChanged(const FuncGraphPtr &graph, AnfNodePtr node, int64_t segment_info, bool *changed); -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_INSERT_DEPEND_FOR_ALL_GATHER_OUTPUT_H_ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_INSERT_DEPEND_FOR_ALL_GATHER_OUTPUT_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_INSERT_DEPEND_FOR_ALL_GATHER_OUTPUT_H_ +#include +#include +#include +#include + +#include "include/backend/optimizer/pass.h" +#include "ir/func_graph.h" +#include "ir/anf.h" +#include "include/backend/optimizer/helper.h" +#include "include/backend/optimizer/optimizer.h" +#include "plugin/device/ascend/optimizer/ascend_helper.h" + +namespace mindspore { +namespace opt { +class InsertDependForAllGatherOutput : public Pass { + public: + InsertDependForAllGatherOutput() + : Pass("insert_depend_for_all_gather_output"), kernel_select_(std::make_shared()) {} + ~InsertDependForAllGatherOutput() override = default; + bool Run(const FuncGraphPtr &graph) override; + + private: + std::map> forward_each_seg_first_recv_; + std::vector redistribution_all_gather_node_; + std::vector get_next_tuplegetitem_node_; + std::map all_gather_node_; + std::map forward_last_seg_each_micro_recv_; + AnfNodePtr pipeline_param_send_ = nullptr; + KernelSelectPtr kernel_select_; + void InsertDepend(const AnfNodePtr &prior_node, const AnfNodePtr &post_node, const FuncGraphPtr &root) const; + int64_t DealSegment(const std::vector &node_list); + void ReorderGetnext(const FuncGraphPtr &graph, bool *changed); + bool IsLastSegWithRecv(int64_t seg_max, std::shared_ptr cnode); + bool IsGatherNode(std::shared_ptr cnode, bool is_recompute); + bool IsRedistriuteAllGatherNode(int64_t seg_max, std::shared_ptr cnode); + void GetEachSegSend(const FuncGraphPtr &graph, const std::vector &node_list, int64_t seg_max); + void IsChanged(const FuncGraphPtr &graph, AnfNodePtr node, int64_t segment_info, bool *changed); +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_INSERT_DEPEND_FOR_ALL_GATHER_OUTPUT_H_ diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/enhancer/insert_depend_for_all_reduce.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/enhancer/insert_depend_for_all_reduce.cc index 8832d0fedc5..45970ffa38c 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/enhancer/insert_depend_for_all_reduce.cc +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/enhancer/insert_depend_for_all_reduce.cc @@ -1,126 +1,126 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "plugin/device/ascend/optimizer/enhancer/insert_depend_for_all_reduce.h" -#include "include/common/utils/utils.h" -#include "include/backend/anf_runtime_algorithm.h" -#include "include/common/utils/anfalgo.h" -#include "include/common/utils/parallel_context.h" -#include "ops/framework_ops.h" - -namespace mindspore { -namespace opt { - -void InsertDependForAllReduce::InsertDepend(const AnfNodePtr &prior_node, const AnfNodePtr &post_node, - const FuncGraphPtr &graph) const { - MS_EXCEPTION_IF_NULL(prior_node); - MS_EXCEPTION_IF_NULL(post_node); - auto post_cnode = post_node->cast(); - auto manager = graph->manager(); - std::vector depend_input = {NewValueNode(prim::kPrimDepend), post_cnode->input(1), prior_node}; - auto depend_node = graph->NewCNode(depend_input); - manager->SetEdge(post_node, 1, depend_node); -} - -void InsertDependForAllReduce::InsertAllReduceOpAfterSendOp(const FuncGraphPtr &graph) { - for (size_t i = 0; i < all_reduce_node_.size(); ++i) { - auto prim = GetCNodePrimitive(all_reduce_node_[i]); - auto segment_info = GetValue(prim->GetAttr(kAttrSegment)); - if (backward_each_seg_last_send_.find(segment_info) != backward_each_seg_last_send_.end() && segment_info != 0) { - MS_LOG(INFO) << "Backward micro max send is: " - << backward_each_seg_last_send_[segment_info]->fullname_with_scope(); - auto before_send_op = backward_each_seg_last_send_[segment_info]->cast()->input(1); - if (IsPrimitiveCNode(before_send_op, prim::kPrimDepend)) { - before_send_op = before_send_op->cast()->input(1); - } - MS_LOG(INFO) << "Before send op is:" << before_send_op->fullname_with_scope(); - InsertDepend(all_reduce_node_[i], before_send_op, graph); - } - } -} - -void InsertDependForAllReduce::HandleAllReduceUsersNode(const FuncGraphPtr &graph) { - for (size_t i = 0; i < allreduce_users_list_.size(); ++i) { - for (size_t j = 1; j < allreduce_users_list_[i].size(); ++j) { - InsertDepend(allreduce_users_list_[i][j - 1], allreduce_users_list_[i][j], graph); - } - InsertDepend(last_allreduce_, allreduce_users_list_[i][0], graph); - } -} - -void InsertDependForAllReduce::FindEachSegLastSend() { - for (auto &node : node_list_) { - if (!node->isa()) { - continue; - } - auto cnode = node->cast(); - if (IsPrimitiveCNode(cnode, prim::kPrimSend) && cnode->HasPrimalAttr(kPrimalAttrForwardNodeName) && - cnode->HasPrimalAttr(kAttrMicro) && GetValue(cnode->GetPrimalAttr(kAttrMicro)) == micro_max_ && - cnode->HasPrimalAttr(kAttrSegment)) { - backward_each_seg_last_send_[GetValue(cnode->GetPrimalAttr(kAttrSegment))] = node; - } - } -} - -bool InsertDependForAllReduce::Run(const FuncGraphPtr &graph) { - MS_EXCEPTION_IF_NULL(graph); - bool changed = false; - auto parallel_context = parallel::ParallelContext::GetInstance(); - if (!parallel_context->enable_fold_pipeline()) { - return changed; - } - node_list_ = TopoSort(graph->get_return()); - for (auto &node : node_list_) { - MS_EXCEPTION_IF_NULL(node); - if (!node->cast() || !AnfUtils::IsRealKernel(node)) { - continue; - } - auto cnode = node->cast(); - if (cnode->HasPrimalAttr(kAttrMicro) && cnode->GetPrimalAttr(kAttrMicro)->isa()) { - int64_t micro = GetValue(cnode->GetPrimalAttr(kAttrMicro)); - micro_max_ = std::max(micro_max_, micro); - } - bool is_recompute = cnode->GetAttr(kAttrDuplicated) != nullptr && GetValue(cnode->GetAttr(kAttrDuplicated)); - if (common::AnfAlgo::GetCNodeName(cnode) == kAllReduceOpName && common::AnfAlgo::HasNodeAttr(kAttrFusion, cnode) && - common::AnfAlgo::GetNodeAttr(cnode, kAttrFusion) > 0 && !is_recompute) { - all_reduce_node_.push_back(node); - auto segment_info = common::AnfAlgo::GetNodeAttr(node, kAttrSegment); - auto fusion_info = common::AnfAlgo::GetNodeAttr(node, kAttrFusion); - MS_LOG(INFO) << "Find all reduce cnode :" << cnode->fullname_with_scope() << ", segment_info" << segment_info - << ", fusion_info" << fusion_info; - if (fusion_info < min_fusion_) { - min_fusion_ = fusion_info; - last_allreduce_ = node; - } - if (segment_info == 0) { - continue; - } - auto node_users = graph->manager()->node_users()[node]; - std::vector node_users_list; - for (auto &node_user : node_users) { - MS_LOG(INFO) << "Node_user: " << node_user.first->fullname_with_scope(); - node_users_list.push_back(node_user.first->cast()); - } - allreduce_users_list_.push_back(node_users_list); - changed = true; - } - } - FindEachSegLastSend(); - InsertAllReduceOpAfterSendOp(graph); - HandleAllReduceUsersNode(graph); - return changed; -} -} // namespace opt -} // namespace mindspore +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "plugin/device/ascend/optimizer/enhancer/insert_depend_for_all_reduce.h" +#include "include/common/utils/utils.h" +#include "include/backend/anf_runtime_algorithm.h" +#include "include/common/utils/anfalgo.h" +#include "include/common/utils/parallel_context.h" +#include "ops/framework_ops.h" + +namespace mindspore { +namespace opt { + +void InsertDependForAllReduce::InsertDepend(const AnfNodePtr &prior_node, const AnfNodePtr &post_node, + const FuncGraphPtr &graph) const { + MS_EXCEPTION_IF_NULL(prior_node); + MS_EXCEPTION_IF_NULL(post_node); + auto post_cnode = post_node->cast(); + auto manager = graph->manager(); + std::vector depend_input = {NewValueNode(prim::kPrimDepend), post_cnode->input(1), prior_node}; + auto depend_node = graph->NewCNode(depend_input); + manager->SetEdge(post_node, 1, depend_node); +} + +void InsertDependForAllReduce::InsertAllReduceOpAfterSendOp(const FuncGraphPtr &graph) { + for (size_t i = 0; i < all_reduce_node_.size(); ++i) { + auto prim = GetCNodePrimitive(all_reduce_node_[i]); + auto segment_info = GetValue(prim->GetAttr(kAttrSegment)); + if (backward_each_seg_last_send_.find(segment_info) != backward_each_seg_last_send_.end() && segment_info != 0) { + MS_LOG(INFO) << "Backward micro max send is: " + << backward_each_seg_last_send_[segment_info]->fullname_with_scope(); + auto before_send_op = backward_each_seg_last_send_[segment_info]->cast()->input(1); + if (IsPrimitiveCNode(before_send_op, prim::kPrimDepend)) { + before_send_op = before_send_op->cast()->input(1); + } + MS_LOG(INFO) << "Before send op is:" << before_send_op->fullname_with_scope(); + InsertDepend(all_reduce_node_[i], before_send_op, graph); + } + } +} + +void InsertDependForAllReduce::HandleAllReduceUsersNode(const FuncGraphPtr &graph) { + for (size_t i = 0; i < allreduce_users_list_.size(); ++i) { + for (size_t j = 1; j < allreduce_users_list_[i].size(); ++j) { + InsertDepend(allreduce_users_list_[i][j - 1], allreduce_users_list_[i][j], graph); + } + InsertDepend(last_allreduce_, allreduce_users_list_[i][0], graph); + } +} + +void InsertDependForAllReduce::FindEachSegLastSend() { + for (auto &node : node_list_) { + if (!node->isa()) { + continue; + } + auto cnode = node->cast(); + if (IsPrimitiveCNode(cnode, prim::kPrimSend) && cnode->HasPrimalAttr(kPrimalAttrForwardNodeName) && + cnode->HasPrimalAttr(kAttrMicro) && GetValue(cnode->GetPrimalAttr(kAttrMicro)) == micro_max_ && + cnode->HasPrimalAttr(kAttrSegment)) { + backward_each_seg_last_send_[GetValue(cnode->GetPrimalAttr(kAttrSegment))] = node; + } + } +} + +bool InsertDependForAllReduce::Run(const FuncGraphPtr &graph) { + MS_EXCEPTION_IF_NULL(graph); + bool changed = false; + auto parallel_context = parallel::ParallelContext::GetInstance(); + if (!parallel_context->enable_fold_pipeline()) { + return changed; + } + node_list_ = TopoSort(graph->get_return()); + for (auto &node : node_list_) { + MS_EXCEPTION_IF_NULL(node); + if (!node->cast() || !AnfUtils::IsRealKernel(node)) { + continue; + } + auto cnode = node->cast(); + if (cnode->HasPrimalAttr(kAttrMicro) && cnode->GetPrimalAttr(kAttrMicro)->isa()) { + int64_t micro = GetValue(cnode->GetPrimalAttr(kAttrMicro)); + micro_max_ = std::max(micro_max_, micro); + } + bool is_recompute = cnode->GetAttr(kAttrDuplicated) != nullptr && GetValue(cnode->GetAttr(kAttrDuplicated)); + if (common::AnfAlgo::GetCNodeName(cnode) == kAllReduceOpName && common::AnfAlgo::HasNodeAttr(kAttrFusion, cnode) && + common::AnfAlgo::GetNodeAttr(cnode, kAttrFusion) > 0 && !is_recompute) { + all_reduce_node_.push_back(node); + auto segment_info = common::AnfAlgo::GetNodeAttr(node, kAttrSegment); + auto fusion_info = common::AnfAlgo::GetNodeAttr(node, kAttrFusion); + MS_LOG(INFO) << "Find all reduce cnode :" << cnode->fullname_with_scope() << ", segment_info" << segment_info + << ", fusion_info" << fusion_info; + if (fusion_info < min_fusion_) { + min_fusion_ = fusion_info; + last_allreduce_ = node; + } + if (segment_info == 0) { + continue; + } + auto node_users = graph->manager()->node_users()[node]; + std::vector node_users_list; + for (auto &node_user : node_users) { + MS_LOG(INFO) << "Node_user: " << node_user.first->fullname_with_scope(); + node_users_list.push_back(node_user.first->cast()); + } + allreduce_users_list_.push_back(node_users_list); + changed = true; + } + } + FindEachSegLastSend(); + InsertAllReduceOpAfterSendOp(graph); + HandleAllReduceUsersNode(graph); + return changed; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/enhancer/insert_depend_for_all_reduce.h b/mindspore/ccsrc/plugin/device/ascend/optimizer/enhancer/insert_depend_for_all_reduce.h index 20d884a0bb6..03aa2b83363 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/enhancer/insert_depend_for_all_reduce.h +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/enhancer/insert_depend_for_all_reduce.h @@ -1,54 +1,54 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_INSERT_DEPEND_FOR_ALL_REDUCE_H_ -#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_INSERT_DEPEND_FOR_ALL_REDUCE_H_ -#include -#include -#include -#include -#include "include/backend/optimizer/pass.h" -#include "ir/func_graph.h" -#include "ir/anf.h" -#include "include/backend/optimizer/helper.h" -#include "include/backend/optimizer/optimizer.h" -#include "plugin/device/ascend/optimizer/ascend_helper.h" - -namespace mindspore { -namespace opt { -class InsertDependForAllReduce : public Pass { - public: - InsertDependForAllReduce() : Pass("insert_depend_for_all_reduce"), kernel_select_(std::make_shared()) {} - ~InsertDependForAllReduce() override = default; - bool Run(const FuncGraphPtr &graph) override; - - private: - void InsertDepend(const AnfNodePtr &prior_node, const AnfNodePtr &post_node, const FuncGraphPtr &graph) const; - void InsertAllReduceOpAfterSendOp(const FuncGraphPtr &graph); - void HandleAllReduceUsersNode(const FuncGraphPtr &graph); - void FindEachSegLastSend(); - KernelSelectPtr kernel_select_; - std::vector all_reduce_node_; - int64_t min_fusion_ = INT64_MAX; - int64_t micro_max_ = 0; - std::vector node_list_; - AnfNodePtr last_allreduce_ = nullptr; - std::map backward_each_seg_last_send_; - std::vector> allreduce_users_list_; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_INSERT_DEPEND_FOR_ALL_REDUCE_H_ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_INSERT_DEPEND_FOR_ALL_REDUCE_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_INSERT_DEPEND_FOR_ALL_REDUCE_H_ +#include +#include +#include +#include +#include "include/backend/optimizer/pass.h" +#include "ir/func_graph.h" +#include "ir/anf.h" +#include "include/backend/optimizer/helper.h" +#include "include/backend/optimizer/optimizer.h" +#include "plugin/device/ascend/optimizer/ascend_helper.h" + +namespace mindspore { +namespace opt { +class InsertDependForAllReduce : public Pass { + public: + InsertDependForAllReduce() : Pass("insert_depend_for_all_reduce"), kernel_select_(std::make_shared()) {} + ~InsertDependForAllReduce() override = default; + bool Run(const FuncGraphPtr &graph) override; + + private: + void InsertDepend(const AnfNodePtr &prior_node, const AnfNodePtr &post_node, const FuncGraphPtr &graph) const; + void InsertAllReduceOpAfterSendOp(const FuncGraphPtr &graph); + void HandleAllReduceUsersNode(const FuncGraphPtr &graph); + void FindEachSegLastSend(); + KernelSelectPtr kernel_select_; + std::vector all_reduce_node_; + int64_t min_fusion_ = INT64_MAX; + int64_t micro_max_ = 0; + std::vector node_list_; + AnfNodePtr last_allreduce_ = nullptr; + std::map backward_each_seg_last_send_; + std::vector> allreduce_users_list_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_INSERT_DEPEND_FOR_ALL_REDUCE_H_ diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/ge/lamb_fission.h b/mindspore/ccsrc/plugin/device/ascend/optimizer/ge/lamb_fission.h index af4fdec0645..86f9e33d7c8 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/ge/lamb_fission.h +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/ge/lamb_fission.h @@ -1,32 +1,32 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_IR_FISSION_LAMB_FISSION_GE_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_IR_FISSION_LAMB_FISSION_GE_H_ - -#include "include/backend/optimizer/optimizer.h" - -namespace mindspore { -namespace opt { -class LambFissionGe : public PatternProcessPass { - public: - explicit LambFissionGe(bool multi_graph = true) : PatternProcessPass("lamb_fission_ge", multi_graph) {} - ~LambFissionGe() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_IR_FISSION_LAMB_FISSION_GE_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_IR_FISSION_LAMB_FISSION_GE_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_IR_FISSION_LAMB_FISSION_GE_H_ + +#include "include/backend/optimizer/optimizer.h" + +namespace mindspore { +namespace opt { +class LambFissionGe : public PatternProcessPass { + public: + explicit LambFissionGe(bool multi_graph = true) : PatternProcessPass("lamb_fission_ge", multi_graph) {} + ~LambFissionGe() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_IR_FISSION_LAMB_FISSION_GE_H_ diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/adam_weight_decay_fission.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/adam_weight_decay_fission.cc index 6f3ef88bee7..8fc2cc2ada9 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/adam_weight_decay_fission.cc +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/adam_weight_decay_fission.cc @@ -1,149 +1,149 @@ -/** - * Copyright 2022-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "plugin/device/ascend/optimizer/ir_fission/adam_weight_decay_fission.h" -#include -#include -#include -#include "include/backend/anf_runtime_algorithm.h" -#include "include/backend/optimizer/helper.h" -#include "include/common/utils/anfalgo.h" -#include "ops/array_op_name.h" -#include "ops/math_op_name.h" -#include "ops/nn_optimizer_ops.h" -#include "ops/sequence_ops.h" - -namespace mindspore { -namespace opt { -namespace { -// AdamWeightDecay's inputs: param, m, v, lr, beta1, beta2, eps, weight_decay, gradient -constexpr size_t kIdxParam = 1; -constexpr size_t kIdxM = 2; -constexpr size_t kIdxV = 3; -constexpr size_t kIdxLr = 4; -constexpr size_t kIdxBeta1 = 5; -constexpr size_t kIdxBeta2 = 6; -constexpr size_t kIdxEps = 7; -constexpr size_t kIdxWeightDecay = 8; -constexpr size_t kIdxGradient = 9; -constexpr size_t kAamWeightDecayInputNum = 10; - -AnfNodePtr CreateNodeOfBinaryOp(const FuncGraphPtr &graph, const string &op_name, const AnfNodePtr &node1, - const AnfNodePtr &node2) { - std::vector new_node_inputs = {NewValueNode(std::make_shared(op_name)), node1, node2}; - return CreateNodeBase(graph, new_node_inputs, node2); -} - -AnfNodePtr CreateNodeOfUnaryOp(const FuncGraphPtr &graph, const string &op_name, const AnfNodePtr &node) { - std::vector new_node_inputs = {NewValueNode(std::make_shared(op_name)), node}; - return CreateNodeBase(graph, new_node_inputs, node); -} - -ValueNodePtr CreateValueNode(const FuncGraphPtr &graph, double value) { - auto tensor = std::make_shared(value); - auto kernel_graph = graph->cast(); - MS_EXCEPTION_IF_NULL(kernel_graph); - ValueNodePtr value_node = kernel_graph->NewValueNode(tensor->ToAbstract(), tensor); - kernel_graph->AddValueNodeToGraph(value_node); - return value_node; -} - -AnfNodePtr CreateCastNode(const FuncGraphPtr &graph, const AnfNodePtr &input, const TypeId dst_type) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(input); - if (common::AnfAlgo::GetOutputInferDataType(input, 0) != dst_type) { - AnfNodePtr cast = graph->NewCNode({NewValueNode(std::make_shared(kCastOpName)), input}); - MS_EXCEPTION_IF_NULL(cast); - common::AnfAlgo::SetOutputTypeAndDetailShape({dst_type}, {AnfAlgo::GetOutputDetailShape(input, 0)}, cast.get()); - common::AnfAlgo::SetNodeAttr(kAttrDstType, TypeIdToType(dst_type), cast); - cast->set_scope(input->scope()); - return cast; - } - return input; -} -} // namespace - -const BaseRef AdamWeightDecayFission::DefinePattern() const { - VarPtr Xs = std::make_shared(); - return VectorRef({prim::kPrimAdamWeightDecay, Xs}); -} - -const AnfNodePtr AdamWeightDecayFission::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - auto adam_weight_decay_cnode = node->cast(); - MS_EXCEPTION_IF_NULL(adam_weight_decay_cnode); - CheckCNodeInputSize(adam_weight_decay_cnode, kAamWeightDecayInputNum); - if (common::AnfAlgo::IsDynamicShape(adam_weight_decay_cnode)) { - MS_LOG_WITH_NODE(EXCEPTION, adam_weight_decay_cnode) - << "AdamWeightDecay don't support dynamic shape, node: " << adam_weight_decay_cnode->fullname_with_scope(); - } - - const auto ori_inputs = adam_weight_decay_cnode->inputs(); - - // cast param to float32 - auto param_fp32 = CreateCastNode(graph, ori_inputs[kIdxParam], kNumberTypeFloat32); - auto m_fp32 = CreateCastNode(graph, ori_inputs[kIdxM], kNumberTypeFloat32); - auto v_fp32 = CreateCastNode(graph, ori_inputs[kIdxV], kNumberTypeFloat32); - auto grad_fp32 = CreateCastNode(graph, ori_inputs[kIdxGradient], kNumberTypeFloat32); - - // create beta1 * m - auto mul_1 = CreateNodeOfBinaryOp(graph, kMulOpName, ori_inputs[kIdxBeta1], m_fp32); - // create 1-beta1 - auto num_one = CreateValueNode(graph, 1.0); - auto sub_1 = CreateNodeOfBinaryOp(graph, kSubOpName, num_one, ori_inputs[kIdxBeta1]); - // create (1-beta1) * gradient - auto mul_2 = CreateNodeOfBinaryOp(graph, kMulOpName, sub_1, grad_fp32); - // create next_m = beta1 * m + (1 - beat1) * gradient - auto add_1 = CreateNodeOfBinaryOp(graph, kTensorAddOpName, mul_1, mul_2); - - // create beta2 * v - auto mul_3 = CreateNodeOfBinaryOp(graph, kMulOpName, ori_inputs[kIdxBeta2], v_fp32); - // create gradient^2 - auto square = CreateNodeOfUnaryOp(graph, kSquareOpName, grad_fp32); - // create 1-beta2 - auto sub_2 = CreateNodeOfBinaryOp(graph, kSubOpName, num_one, ori_inputs[kIdxBeta2]); - // create (1-beta2) * gradient^2 - auto mul_4 = CreateNodeOfBinaryOp(graph, kMulOpName, sub_2, square); - // create next_v = beta2 * v + (1 - beta2) * gradient^2 - auto add_2 = CreateNodeOfBinaryOp(graph, kTensorAddOpName, mul_3, mul_4); - - // create sqrt(next_v) - auto sqrt = CreateNodeOfUnaryOp(graph, kSqrtOpName, add_2); - // create eps + sqrt(next_v) - auto add_3 = CreateNodeOfBinaryOp(graph, kTensorAddOpName, ori_inputs[kIdxEps], sqrt); - // create update = next_m / (eps + sqrt(next_v)) - auto real_div = CreateNodeOfBinaryOp(graph, kRealDivOpName, add_1, add_3); - // create weight_decay * param - auto mul_5 = CreateNodeOfBinaryOp(graph, kMulOpName, ori_inputs[kIdxWeightDecay], param_fp32); - // create update <== weight_decay * param + update - auto add_4 = CreateNodeOfBinaryOp(graph, kTensorAddOpName, mul_5, real_div); - // create update_with_lr = lr * update - auto mul_6 = CreateNodeOfBinaryOp(graph, kMulOpName, ori_inputs[kIdxLr], add_4); - // create param - update_with_lr - auto sub_3 = CreateNodeOfBinaryOp(graph, kSubOpName, param_fp32, mul_6); - - // create param = param - update_with_lr - auto assign_1 = CreateNodeOfBinaryOp(graph, prim::kPrimAssign->name(), param_fp32, sub_3); - // create m = next_m - auto assign_2 = CreateNodeOfBinaryOp(graph, prim::kPrimAssign->name(), m_fp32, add_1); - // create v = next_v - auto assign_3 = CreateNodeOfBinaryOp(graph, prim::kPrimAssign->name(), v_fp32, add_2); - - return CreateMakeTupleNode(graph, std::vector{assign_1, assign_2, assign_3}); -} -} // namespace opt -} // namespace mindspore +/** + * Copyright 2022-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "plugin/device/ascend/optimizer/ir_fission/adam_weight_decay_fission.h" +#include +#include +#include +#include "include/backend/anf_runtime_algorithm.h" +#include "include/backend/optimizer/helper.h" +#include "include/common/utils/anfalgo.h" +#include "ops/array_op_name.h" +#include "ops/math_op_name.h" +#include "ops/nn_optimizer_ops.h" +#include "ops/sequence_ops.h" + +namespace mindspore { +namespace opt { +namespace { +// AdamWeightDecay's inputs: param, m, v, lr, beta1, beta2, eps, weight_decay, gradient +constexpr size_t kIdxParam = 1; +constexpr size_t kIdxM = 2; +constexpr size_t kIdxV = 3; +constexpr size_t kIdxLr = 4; +constexpr size_t kIdxBeta1 = 5; +constexpr size_t kIdxBeta2 = 6; +constexpr size_t kIdxEps = 7; +constexpr size_t kIdxWeightDecay = 8; +constexpr size_t kIdxGradient = 9; +constexpr size_t kAamWeightDecayInputNum = 10; + +AnfNodePtr CreateNodeOfBinaryOp(const FuncGraphPtr &graph, const string &op_name, const AnfNodePtr &node1, + const AnfNodePtr &node2) { + std::vector new_node_inputs = {NewValueNode(std::make_shared(op_name)), node1, node2}; + return CreateNodeBase(graph, new_node_inputs, node2); +} + +AnfNodePtr CreateNodeOfUnaryOp(const FuncGraphPtr &graph, const string &op_name, const AnfNodePtr &node) { + std::vector new_node_inputs = {NewValueNode(std::make_shared(op_name)), node}; + return CreateNodeBase(graph, new_node_inputs, node); +} + +ValueNodePtr CreateValueNode(const FuncGraphPtr &graph, double value) { + auto tensor = std::make_shared(value); + auto kernel_graph = graph->cast(); + MS_EXCEPTION_IF_NULL(kernel_graph); + ValueNodePtr value_node = kernel_graph->NewValueNode(tensor->ToAbstract(), tensor); + kernel_graph->AddValueNodeToGraph(value_node); + return value_node; +} + +AnfNodePtr CreateCastNode(const FuncGraphPtr &graph, const AnfNodePtr &input, const TypeId dst_type) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(input); + if (common::AnfAlgo::GetOutputInferDataType(input, 0) != dst_type) { + AnfNodePtr cast = graph->NewCNode({NewValueNode(std::make_shared(kCastOpName)), input}); + MS_EXCEPTION_IF_NULL(cast); + common::AnfAlgo::SetOutputTypeAndDetailShape({dst_type}, {AnfAlgo::GetOutputDetailShape(input, 0)}, cast.get()); + common::AnfAlgo::SetNodeAttr(kAttrDstType, TypeIdToType(dst_type), cast); + cast->set_scope(input->scope()); + return cast; + } + return input; +} +} // namespace + +const BaseRef AdamWeightDecayFission::DefinePattern() const { + VarPtr Xs = std::make_shared(); + return VectorRef({prim::kPrimAdamWeightDecay, Xs}); +} + +const AnfNodePtr AdamWeightDecayFission::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + auto adam_weight_decay_cnode = node->cast(); + MS_EXCEPTION_IF_NULL(adam_weight_decay_cnode); + CheckCNodeInputSize(adam_weight_decay_cnode, kAamWeightDecayInputNum); + if (common::AnfAlgo::IsDynamicShape(adam_weight_decay_cnode)) { + MS_LOG_WITH_NODE(EXCEPTION, adam_weight_decay_cnode) + << "AdamWeightDecay don't support dynamic shape, node: " << adam_weight_decay_cnode->fullname_with_scope(); + } + + const auto ori_inputs = adam_weight_decay_cnode->inputs(); + + // cast param to float32 + auto param_fp32 = CreateCastNode(graph, ori_inputs[kIdxParam], kNumberTypeFloat32); + auto m_fp32 = CreateCastNode(graph, ori_inputs[kIdxM], kNumberTypeFloat32); + auto v_fp32 = CreateCastNode(graph, ori_inputs[kIdxV], kNumberTypeFloat32); + auto grad_fp32 = CreateCastNode(graph, ori_inputs[kIdxGradient], kNumberTypeFloat32); + + // create beta1 * m + auto mul_1 = CreateNodeOfBinaryOp(graph, kMulOpName, ori_inputs[kIdxBeta1], m_fp32); + // create 1-beta1 + auto num_one = CreateValueNode(graph, 1.0); + auto sub_1 = CreateNodeOfBinaryOp(graph, kSubOpName, num_one, ori_inputs[kIdxBeta1]); + // create (1-beta1) * gradient + auto mul_2 = CreateNodeOfBinaryOp(graph, kMulOpName, sub_1, grad_fp32); + // create next_m = beta1 * m + (1 - beat1) * gradient + auto add_1 = CreateNodeOfBinaryOp(graph, kTensorAddOpName, mul_1, mul_2); + + // create beta2 * v + auto mul_3 = CreateNodeOfBinaryOp(graph, kMulOpName, ori_inputs[kIdxBeta2], v_fp32); + // create gradient^2 + auto square = CreateNodeOfUnaryOp(graph, kSquareOpName, grad_fp32); + // create 1-beta2 + auto sub_2 = CreateNodeOfBinaryOp(graph, kSubOpName, num_one, ori_inputs[kIdxBeta2]); + // create (1-beta2) * gradient^2 + auto mul_4 = CreateNodeOfBinaryOp(graph, kMulOpName, sub_2, square); + // create next_v = beta2 * v + (1 - beta2) * gradient^2 + auto add_2 = CreateNodeOfBinaryOp(graph, kTensorAddOpName, mul_3, mul_4); + + // create sqrt(next_v) + auto sqrt = CreateNodeOfUnaryOp(graph, kSqrtOpName, add_2); + // create eps + sqrt(next_v) + auto add_3 = CreateNodeOfBinaryOp(graph, kTensorAddOpName, ori_inputs[kIdxEps], sqrt); + // create update = next_m / (eps + sqrt(next_v)) + auto real_div = CreateNodeOfBinaryOp(graph, kRealDivOpName, add_1, add_3); + // create weight_decay * param + auto mul_5 = CreateNodeOfBinaryOp(graph, kMulOpName, ori_inputs[kIdxWeightDecay], param_fp32); + // create update <== weight_decay * param + update + auto add_4 = CreateNodeOfBinaryOp(graph, kTensorAddOpName, mul_5, real_div); + // create update_with_lr = lr * update + auto mul_6 = CreateNodeOfBinaryOp(graph, kMulOpName, ori_inputs[kIdxLr], add_4); + // create param - update_with_lr + auto sub_3 = CreateNodeOfBinaryOp(graph, kSubOpName, param_fp32, mul_6); + + // create param = param - update_with_lr + auto assign_1 = CreateNodeOfBinaryOp(graph, prim::kPrimAssign->name(), param_fp32, sub_3); + // create m = next_m + auto assign_2 = CreateNodeOfBinaryOp(graph, prim::kPrimAssign->name(), m_fp32, add_1); + // create v = next_v + auto assign_3 = CreateNodeOfBinaryOp(graph, prim::kPrimAssign->name(), v_fp32, add_2); + + return CreateMakeTupleNode(graph, std::vector{assign_1, assign_2, assign_3}); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/adam_weight_decay_fission.h b/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/adam_weight_decay_fission.h index 7119a2161db..0caf172e048 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/adam_weight_decay_fission.h +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/adam_weight_decay_fission.h @@ -1,33 +1,33 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_IR_FISSION_ADAM_WEIGHT_DECAY_FISSION_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_IR_FISSION_ADAM_WEIGHT_DECAY_FISSION_H_ - -#include "include/backend/optimizer/optimizer.h" - -namespace mindspore { -namespace opt { -class AdamWeightDecayFission : public PatternProcessPass { - public: - explicit AdamWeightDecayFission(bool multi_graph = true) - : PatternProcessPass("adam_weight_decay_fission", multi_graph) {} - ~AdamWeightDecayFission() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_IR_FISSION_ADAM_WEIGHT_DECAY_FISSION_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_IR_FISSION_ADAM_WEIGHT_DECAY_FISSION_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_IR_FISSION_ADAM_WEIGHT_DECAY_FISSION_H_ + +#include "include/backend/optimizer/optimizer.h" + +namespace mindspore { +namespace opt { +class AdamWeightDecayFission : public PatternProcessPass { + public: + explicit AdamWeightDecayFission(bool multi_graph = true) + : PatternProcessPass("adam_weight_decay_fission", multi_graph) {} + ~AdamWeightDecayFission() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_IR_FISSION_ADAM_WEIGHT_DECAY_FISSION_H_ diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/lamb_fission.h b/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/lamb_fission.h index 430fc418dfa..87c1891c9f9 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/lamb_fission.h +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/lamb_fission.h @@ -1,32 +1,32 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_IR_FISSION_LAMB_FISSION_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_IR_FISSION_LAMB_FISSION_H_ - -#include "include/backend/optimizer/optimizer.h" - -namespace mindspore { -namespace opt { -class LambFission : public PatternProcessPass { - public: - explicit LambFission(bool multi_graph = true) : PatternProcessPass("lamb_fission", multi_graph) {} - ~LambFission() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_IR_FISSION_LAMB_FISSION_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_IR_FISSION_LAMB_FISSION_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_IR_FISSION_LAMB_FISSION_H_ + +#include "include/backend/optimizer/optimizer.h" + +namespace mindspore { +namespace opt { +class LambFission : public PatternProcessPass { + public: + explicit LambFission(bool multi_graph = true) : PatternProcessPass("lamb_fission", multi_graph) {} + ~LambFission() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_IR_FISSION_LAMB_FISSION_H_ diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/scale_grad_fission.h b/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/scale_grad_fission.h index a326687a044..118f7f1b63d 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/scale_grad_fission.h +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/scale_grad_fission.h @@ -1,32 +1,32 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_IR_FISSION_SCALE_GRAD_FISSION_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_IR_FISSION_SCALE_GRAD_FISSION_H_ - -#include "include/backend/optimizer/optimizer.h" - -namespace mindspore { -namespace opt { -class ScaleGradFission : public PatternProcessPass { - public: - explicit ScaleGradFission(bool multi_graph = true) : PatternProcessPass("scale_grad_fission", multi_graph) {} - ~ScaleGradFission() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_IR_FISSION_SCALE_GRAD_FISSION_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_IR_FISSION_SCALE_GRAD_FISSION_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_IR_FISSION_SCALE_GRAD_FISSION_H_ + +#include "include/backend/optimizer/optimizer.h" + +namespace mindspore { +namespace opt { +class ScaleGradFission : public PatternProcessPass { + public: + explicit ScaleGradFission(bool multi_graph = true) : PatternProcessPass("scale_grad_fission", multi_graph) {} + ~ScaleGradFission() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_IR_FISSION_SCALE_GRAD_FISSION_H_ diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/transdata_split.h b/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/transdata_split.h index d7888d2228c..bf3d5950e21 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/transdata_split.h +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/transdata_split.h @@ -1,45 +1,45 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_TRANSDATA_SPLIT_H_ -#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_TRANSDATA_SPLIT_H_ -#include -#include - -#include "include/backend/optimizer/pass.h" -#include "ir/func_graph.h" -#include "ir/anf.h" -#include "include/backend/optimizer/helper.h" -#include "include/backend/optimizer/optimizer.h" -#include "plugin/device/ascend/optimizer/ascend_helper.h" - -namespace mindspore { -namespace opt { -class TransDataSplit : public PatternProcessPass { - public: - explicit TransDataSplit(bool multigraph = true, const string &name = "trans_data_split") - : PatternProcessPass(name, multigraph), kernel_select_(std::make_shared()) {} - ~TransDataSplit() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const override; - - protected: - CNodePtr DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const; - bool IsFormatInvaild(const AnfNodePtr &node) const; - KernelSelectPtr kernel_select_; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_TRANSDATA_SPLIT_H_ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_TRANSDATA_SPLIT_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_TRANSDATA_SPLIT_H_ +#include +#include + +#include "include/backend/optimizer/pass.h" +#include "ir/func_graph.h" +#include "ir/anf.h" +#include "include/backend/optimizer/helper.h" +#include "include/backend/optimizer/optimizer.h" +#include "plugin/device/ascend/optimizer/ascend_helper.h" + +namespace mindspore { +namespace opt { +class TransDataSplit : public PatternProcessPass { + public: + explicit TransDataSplit(bool multigraph = true, const string &name = "trans_data_split") + : PatternProcessPass(name, multigraph), kernel_select_(std::make_shared()) {} + ~TransDataSplit() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const override; + + protected: + CNodePtr DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const; + bool IsFormatInvaild(const AnfNodePtr &node) const; + KernelSelectPtr kernel_select_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_TRANSDATA_SPLIT_H_ diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion/prelu_fusion.h b/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion/prelu_fusion.h index 6fceac3e5ac..ed521a01810 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion/prelu_fusion.h +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fusion/prelu_fusion.h @@ -1,40 +1,40 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_PRELU_FUSION_H_ -#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_PRELU_FUSION_H_ - -#include -#include "include/backend/optimizer/optimizer.h" - -namespace mindspore { -namespace opt { -class PReluFusion : public PatternProcessPass { - public: - explicit PReluFusion(bool multigraph = true) : PatternProcessPass("prelu_fusion", multigraph) { - x_ = std::make_shared(); - weight_ = std::make_shared(); - } - ~PReluFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &equiv) const override; - - private: - VarPtr x_; - VarPtr weight_; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_PRELU_FUSION_H_ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_PRELU_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_PRELU_FUSION_H_ + +#include +#include "include/backend/optimizer/optimizer.h" + +namespace mindspore { +namespace opt { +class PReluFusion : public PatternProcessPass { + public: + explicit PReluFusion(bool multigraph = true) : PatternProcessPass("prelu_fusion", multigraph) { + x_ = std::make_shared(); + weight_ = std::make_shared(); + } + ~PReluFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &equiv) const override; + + private: + VarPtr x_; + VarPtr weight_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_PRELU_FUSION_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/hal/device/mpi/mpi_export.h b/mindspore/ccsrc/plugin/device/cpu/hal/device/mpi/mpi_export.h index afc953a9473..c19d49fc7cf 100644 --- a/mindspore/ccsrc/plugin/device/cpu/hal/device/mpi/mpi_export.h +++ b/mindspore/ccsrc/plugin/device/cpu/hal/device/mpi/mpi_export.h @@ -1,34 +1,34 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_MPI_MPI_EXPORT_H_ -#define MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_MPI_MPI_EXPORT_H_ -#include -#include -#ifndef FUNC_EXPORT -#define FUNC_EXPORT __attribute__((visibility("default"))) -#endif - -extern "C" { -FUNC_EXPORT int GetMPIRankId(); -FUNC_EXPORT int GetMPIRankSize(); -FUNC_EXPORT bool MPIReduceScatter(const float *input, float *output, const std::vector &ranks_group, - size_t data_num, const std::string &op_type); -FUNC_EXPORT bool MPIReduceScatterOverwriteInput(float *input, const std::vector &ranks_group, size_t in_data_num, - size_t output_size, const std::string &op_type, float *output); -FUNC_EXPORT bool MPIAllGather(const float *input, float *output, const std::vector &ranks_group, size_t data_num); -} -#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_MPI_MPI_EXPORT_H_ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_MPI_MPI_EXPORT_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_MPI_MPI_EXPORT_H_ +#include +#include +#ifndef FUNC_EXPORT +#define FUNC_EXPORT __attribute__((visibility("default"))) +#endif + +extern "C" { +FUNC_EXPORT int GetMPIRankId(); +FUNC_EXPORT int GetMPIRankSize(); +FUNC_EXPORT bool MPIReduceScatter(const float *input, float *output, const std::vector &ranks_group, + size_t data_num, const std::string &op_type); +FUNC_EXPORT bool MPIReduceScatterOverwriteInput(float *input, const std::vector &ranks_group, size_t in_data_num, + size_t output_size, const std::string &op_type, float *output); +FUNC_EXPORT bool MPIAllGather(const float *input, float *output, const std::vector &ranks_group, size_t data_num); +} +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_MPI_MPI_EXPORT_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/hal/device/mpi/mpi_interface.h b/mindspore/ccsrc/plugin/device/cpu/hal/device/mpi/mpi_interface.h index 878f5456820..1464b8f8e5f 100644 --- a/mindspore/ccsrc/plugin/device/cpu/hal/device/mpi/mpi_interface.h +++ b/mindspore/ccsrc/plugin/device/cpu/hal/device/mpi/mpi_interface.h @@ -1,31 +1,31 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_MPI_MPI_INTERFACE_H_ -#define MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_MPI_MPI_INTERFACE_H_ -#include -#include -#ifdef ENABLE_MPI -constexpr auto kMPIOpTypeSum = "sum"; -int GetMPIRankId(); -int GetMPIRankSize(); -bool MPIReduceScatter(const float *input, float *output, const std::vector &ranks_group, size_t data_num, - const std::string &op_type = kMPIOpTypeSum); -bool MPIReduceScatterOverwriteInput(float *input, const std::vector &ranks_group, size_t in_data_num, - size_t output_size, const std::string &op_type = kMPIOpTypeSum, - float *output = nullptr); -bool MPIAllGather(const float *input, float *output, const std::vector &ranks_group, size_t data_num); -#endif // ENABLE_MPI -#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_MPI_MPI_INTERFACE_H_ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_MPI_MPI_INTERFACE_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_MPI_MPI_INTERFACE_H_ +#include +#include +#ifdef ENABLE_MPI +constexpr auto kMPIOpTypeSum = "sum"; +int GetMPIRankId(); +int GetMPIRankSize(); +bool MPIReduceScatter(const float *input, float *output, const std::vector &ranks_group, size_t data_num, + const std::string &op_type = kMPIOpTypeSum); +bool MPIReduceScatterOverwriteInput(float *input, const std::vector &ranks_group, size_t in_data_num, + size_t output_size, const std::string &op_type = kMPIOpTypeSum, + float *output = nullptr); +bool MPIAllGather(const float *input, float *output, const std::vector &ranks_group, size_t data_num); +#endif // ENABLE_MPI +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_MPI_MPI_INTERFACE_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/adaptive_max_pool_2d_grad_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/adaptive_max_pool_2d_grad_cpu_kernel.cc index 090f0f7a32f..ad41c4e5d37 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/adaptive_max_pool_2d_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/adaptive_max_pool_2d_grad_cpu_kernel.cc @@ -1,155 +1,155 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include -#include -#include "plugin/device/cpu/kernel/adaptive_max_pool_2d_grad_cpu_kernel.h" -#include "plugin/device/cpu/hal/device/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -namespace { -#define F64 kNumberTypeFloat64 -#define F32 kNumberTypeFloat32 -#define F16 kNumberTypeFloat16 -#define I32 kNumberTypeInt32 -#define I64 kNumberTypeInt64 - -constexpr size_t kHWSize = 2; -} // namespace - -bool AdaptiveMaxPool2DGradCpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr; - return false; - } - kernel_func_ = func_list_[index].second; - return true; -} - -int AdaptiveMaxPool2DGradCpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { - return ret; - } - input_y_grad_shape_ = inputs.at(kIndex0)->GetShapeVector(); - input_x_shape_ = inputs.at(kIndex1)->GetShapeVector(); - input_argmax_shape_ = inputs.at(kIndex2)->GetShapeVector(); - ShapeVector output_shape = outputs.at(kIndex0)->GetShapeVector(); - - outer_size_ = 1; - inner_size_ = 1; - output_stride_ = 1; - output_size_ = 1; - const size_t shape_size = input_argmax_shape_.size(); - for (size_t i = 0; i < shape_size; i++) { - if (i < shape_size - kHWSize) { - outer_size_ *= input_argmax_shape_[i]; - } else { - inner_size_ *= input_argmax_shape_[i]; - } - } - - const size_t output_shape_size = output_shape.size(); - for (size_t k = 0; k < output_shape_size; k++) { - output_size_ *= output_shape[k]; - if (k >= output_shape_size - kHWSize) { - output_stride_ *= output_shape[k]; - } - } - - return KRET_OK; -} - -template -bool AdaptiveMaxPool2DGradCpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &outputs) { - auto input_grad = GetDeviceAddress(inputs, kIndex0); - auto input_argmax = GetDeviceAddress(inputs, kIndex2); - auto output = GetDeviceAddress(outputs, kIndex0); - - std::atomic_int memset_ret{EOK}; - auto output_int8 = GetDeviceAddress(outputs, kIndex0); - auto init_task = [&](size_t start, size_t end) { - size_t mem_size = end - start; - while (mem_size > 0) { - size_t real_mem_size = mem_size; - if (real_mem_size > static_cast(SECUREC_MEM_MAX_LEN)) { - real_mem_size = static_cast(SECUREC_MEM_MAX_LEN); - } - auto ret = memset_s(output_int8 + start, real_mem_size, 0, real_mem_size); - if (ret != EOK) { - memset_ret = ret; - return; - } - mem_size -= real_mem_size; - start += real_mem_size; - } - }; - ParallelLaunchAutoSearch(init_task, outputs[kIndex0]->size(), this, &search_info_); - if (memset_ret != EOK) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memset_s failed, ret=" << memset_ret; - } - - auto adaptive_max_pool_2d_grad = [&](int64_t start, int64_t end) { - for (int64_t n = start; n < end; ++n) { - for (int64_t i = 0; i < inner_size_; ++i) { - int32_t maxp = input_argmax[i + n * inner_size_] + n * output_stride_; - output[maxp] += static_cast(input_grad[i + n * inner_size_]); - } - } - }; - ParallelLaunchAutoSearch(adaptive_max_pool_2d_grad, LongToSize(outer_size_), this, ¶llel_search_info_); - - return true; -} - -std::vector> - AdaptiveMaxPool2DGradCpuKernelMod::func_list_ = { - {KernelAttr().AddInputAttr(F16).AddInputAttr(F16).AddInputAttr(I32).AddOutputAttr(F16), - &AdaptiveMaxPool2DGradCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(F32).AddInputAttr(F32).AddInputAttr(I32).AddOutputAttr(F32), - &AdaptiveMaxPool2DGradCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(F64).AddInputAttr(F64).AddInputAttr(I32).AddOutputAttr(F64), - &AdaptiveMaxPool2DGradCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(F16).AddInputAttr(F16).AddInputAttr(I64).AddOutputAttr(F16), - &AdaptiveMaxPool2DGradCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(F32).AddInputAttr(F32).AddInputAttr(I64).AddOutputAttr(F32), - &AdaptiveMaxPool2DGradCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(F64).AddInputAttr(F64).AddInputAttr(I64).AddOutputAttr(F64), - &AdaptiveMaxPool2DGradCpuKernelMod::LaunchKernel}}; - -std::vector AdaptiveMaxPool2DGradCpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform( - func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { - return pair.first; - }); - return support_list; -} - -MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, AdaptiveMaxPool2DGrad, AdaptiveMaxPool2DGradCpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include "plugin/device/cpu/kernel/adaptive_max_pool_2d_grad_cpu_kernel.h" +#include "plugin/device/cpu/hal/device/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +namespace { +#define F64 kNumberTypeFloat64 +#define F32 kNumberTypeFloat32 +#define F16 kNumberTypeFloat16 +#define I32 kNumberTypeInt32 +#define I64 kNumberTypeInt64 + +constexpr size_t kHWSize = 2; +} // namespace + +bool AdaptiveMaxPool2DGradCpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr; + return false; + } + kernel_func_ = func_list_[index].second; + return true; +} + +int AdaptiveMaxPool2DGradCpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { + return ret; + } + input_y_grad_shape_ = inputs.at(kIndex0)->GetShapeVector(); + input_x_shape_ = inputs.at(kIndex1)->GetShapeVector(); + input_argmax_shape_ = inputs.at(kIndex2)->GetShapeVector(); + ShapeVector output_shape = outputs.at(kIndex0)->GetShapeVector(); + + outer_size_ = 1; + inner_size_ = 1; + output_stride_ = 1; + output_size_ = 1; + const size_t shape_size = input_argmax_shape_.size(); + for (size_t i = 0; i < shape_size; i++) { + if (i < shape_size - kHWSize) { + outer_size_ *= input_argmax_shape_[i]; + } else { + inner_size_ *= input_argmax_shape_[i]; + } + } + + const size_t output_shape_size = output_shape.size(); + for (size_t k = 0; k < output_shape_size; k++) { + output_size_ *= output_shape[k]; + if (k >= output_shape_size - kHWSize) { + output_stride_ *= output_shape[k]; + } + } + + return KRET_OK; +} + +template +bool AdaptiveMaxPool2DGradCpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { + auto input_grad = GetDeviceAddress(inputs, kIndex0); + auto input_argmax = GetDeviceAddress(inputs, kIndex2); + auto output = GetDeviceAddress(outputs, kIndex0); + + std::atomic_int memset_ret{EOK}; + auto output_int8 = GetDeviceAddress(outputs, kIndex0); + auto init_task = [&](size_t start, size_t end) { + size_t mem_size = end - start; + while (mem_size > 0) { + size_t real_mem_size = mem_size; + if (real_mem_size > static_cast(SECUREC_MEM_MAX_LEN)) { + real_mem_size = static_cast(SECUREC_MEM_MAX_LEN); + } + auto ret = memset_s(output_int8 + start, real_mem_size, 0, real_mem_size); + if (ret != EOK) { + memset_ret = ret; + return; + } + mem_size -= real_mem_size; + start += real_mem_size; + } + }; + ParallelLaunchAutoSearch(init_task, outputs[kIndex0]->size(), this, &search_info_); + if (memset_ret != EOK) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memset_s failed, ret=" << memset_ret; + } + + auto adaptive_max_pool_2d_grad = [&](int64_t start, int64_t end) { + for (int64_t n = start; n < end; ++n) { + for (int64_t i = 0; i < inner_size_; ++i) { + int32_t maxp = input_argmax[i + n * inner_size_] + n * output_stride_; + output[maxp] += static_cast(input_grad[i + n * inner_size_]); + } + } + }; + ParallelLaunchAutoSearch(adaptive_max_pool_2d_grad, LongToSize(outer_size_), this, ¶llel_search_info_); + + return true; +} + +std::vector> + AdaptiveMaxPool2DGradCpuKernelMod::func_list_ = { + {KernelAttr().AddInputAttr(F16).AddInputAttr(F16).AddInputAttr(I32).AddOutputAttr(F16), + &AdaptiveMaxPool2DGradCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(F32).AddInputAttr(F32).AddInputAttr(I32).AddOutputAttr(F32), + &AdaptiveMaxPool2DGradCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(F64).AddInputAttr(F64).AddInputAttr(I32).AddOutputAttr(F64), + &AdaptiveMaxPool2DGradCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(F16).AddInputAttr(F16).AddInputAttr(I64).AddOutputAttr(F16), + &AdaptiveMaxPool2DGradCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(F32).AddInputAttr(F32).AddInputAttr(I64).AddOutputAttr(F32), + &AdaptiveMaxPool2DGradCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(F64).AddInputAttr(F64).AddInputAttr(I64).AddOutputAttr(F64), + &AdaptiveMaxPool2DGradCpuKernelMod::LaunchKernel}}; + +std::vector AdaptiveMaxPool2DGradCpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform( + func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { + return pair.first; + }); + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, AdaptiveMaxPool2DGrad, AdaptiveMaxPool2DGradCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/adaptive_max_pool_2d_grad_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/adaptive_max_pool_2d_grad_cpu_kernel.h index ba2b89a23cc..e506d81baa4 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/adaptive_max_pool_2d_grad_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/adaptive_max_pool_2d_grad_cpu_kernel.h @@ -1,65 +1,65 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADAPTIVE_MAX_POOL_2D_GRAD_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADAPTIVE_MAX_POOL_2D_GRAD_CPU_KERNEL_H_ -#include -#include -#include -#include -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore { -namespace kernel { -class AdaptiveMaxPool2DGradCpuKernelMod : public NativeCpuKernelMod { - public: - AdaptiveMaxPool2DGradCpuKernelMod() = default; - ~AdaptiveMaxPool2DGradCpuKernelMod() override = default; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override { - return kernel_func_(this, inputs, outputs); - } - - protected: - std::vector GetOpSupport() override; - - private: - template - bool LaunchKernel(const std::vector &inputs, const std::vector &outputs); - - using AdaptiveMaxPool2DGradLaunchFunc = std::function &, const std::vector &)>; - static std::vector> func_list_; - AdaptiveMaxPool2DGradLaunchFunc kernel_func_; - ParallelSearchInfo search_info_; - - ShapeVector input_y_grad_shape_; - ShapeVector input_x_shape_; - ShapeVector input_argmax_shape_; - int64_t outer_size_{1}; - int64_t inner_size_{1}; - int64_t output_stride_{1}; - int64_t output_size_{1}; -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADAPTIVE_MAX_POOL_2D_GRAD_CPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADAPTIVE_MAX_POOL_2D_GRAD_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADAPTIVE_MAX_POOL_2D_GRAD_CPU_KERNEL_H_ +#include +#include +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class AdaptiveMaxPool2DGradCpuKernelMod : public NativeCpuKernelMod { + public: + AdaptiveMaxPool2DGradCpuKernelMod() = default; + ~AdaptiveMaxPool2DGradCpuKernelMod() override = default; + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override { + return kernel_func_(this, inputs, outputs); + } + + protected: + std::vector GetOpSupport() override; + + private: + template + bool LaunchKernel(const std::vector &inputs, const std::vector &outputs); + + using AdaptiveMaxPool2DGradLaunchFunc = std::function &, const std::vector &)>; + static std::vector> func_list_; + AdaptiveMaxPool2DGradLaunchFunc kernel_func_; + ParallelSearchInfo search_info_; + + ShapeVector input_y_grad_shape_; + ShapeVector input_x_shape_; + ShapeVector input_argmax_shape_; + int64_t outer_size_{1}; + int64_t inner_size_{1}; + int64_t output_stride_{1}; + int64_t output_size_{1}; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADAPTIVE_MAX_POOL_2D_GRAD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/adjust_saturation_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/adjust_saturation_cpu_kernel.cc index 197c4b6423d..bf48c8e52bd 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/adjust_saturation_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/adjust_saturation_cpu_kernel.cc @@ -1,187 +1,187 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/cpu/kernel/adjust_saturation_cpu_kernel.h" -#include -#include -#include -#include -#include "plugin/device/cpu/hal/device/cpu_device_address.h" -#include "utils/ms_utils.h" - -namespace mindspore { -namespace kernel { -namespace { -const std::int64_t kAdjustSaturationParallelNum = 64 * 1024; -const std::int64_t kAdjustSaturationZero = 0; -const std::int64_t kAdjustSaturationOne = 1; -const std::int64_t kAdjustSaturationTwo = 2; -const std::int64_t kAdjustSaturationThree = 3; -const std::int64_t kAdjustSaturationFour = 4; -const std::int64_t kAdjustSaturationFive = 5; -const std::float_t kAdjustSaturationSix = 6; -} // namespace - -namespace detail { -static void rgb_to_hsv(float r, float g, float b, float *h, float *s, float *v) { - float vv = std::max(r, std::max(g, b)); - float range = vv - std::min(r, std::min(g, b)); - const float eps = 1e-6; - if (vv > 0) { - *s = range / vv; - } else { - *s = 0; - } - float norm = kAdjustSaturationOne / (kAdjustSaturationSix * range); - float hh; - if (std::fabs(r - vv) <= eps) { - hh = norm * (g - b); - } else if (std::fabs(g - vv) <= eps) { - hh = norm * (b - r) + kAdjustSaturationTwo / kAdjustSaturationSix; - } else { - hh = norm * (r - g) + kAdjustSaturationFour / kAdjustSaturationSix; - } - if (range <= 0.0) { - hh = 0; - } - if (hh < 0.0) { - hh = hh + kAdjustSaturationOne; - } - *v = vv; - *h = hh; -} - -template -static void hsv_to_rgb(float h, float s, float v, T *r, T *g, T *b) { - float c = s * v; - float m = v - c; - float dh = h * kAdjustSaturationSix; - float rr, gg, bb; - int h_category = static_cast(dh); - float fmodu = dh; - while (fmodu <= 0) { - fmodu += kAdjustSaturationTwo; - } - while (fmodu >= kAdjustSaturationTwo) { - fmodu -= kAdjustSaturationTwo; - } - float x = c * (1 - std::abs(fmodu - 1)); - switch (h_category) { - case kAdjustSaturationZero: - rr = c; - gg = x; - bb = 0; - break; - case kAdjustSaturationOne: - rr = x; - gg = c; - bb = 0; - break; - case kAdjustSaturationTwo: - rr = 0; - gg = c; - bb = x; - break; - case kAdjustSaturationThree: - rr = 0; - gg = x; - bb = c; - break; - case kAdjustSaturationFour: - rr = x; - gg = 0; - bb = c; - break; - case kAdjustSaturationFive: - rr = c; - gg = 0; - bb = x; - break; - default: - rr = 0; - gg = 0; - bb = 0; - } - *r = static_cast(rr + m); - *g = static_cast(gg + m); - *b = static_cast(bb + m); -} - -template -bool LaunchAdjustSaturationKernel(const std::vector &inputs, - const std::vector &outputs) { - auto input{static_cast(inputs[0]->device_ptr())}; - auto scale{static_cast(inputs[1]->device_ptr())}; - auto output{static_cast(outputs[0]->device_ptr())}; - constexpr int64_t kChannelSize = 3; - std::int64_t num_elements = static_cast(inputs[0]->size() / sizeof(T)); - auto sharder_adjustsaturation = [input, scale, output, kChannelSize](int64_t start, int64_t end) { - for (int64_t i = start * kChannelSize; i < end * kChannelSize; i = i + kChannelSize) { - float h, s, v; - // Convert the RGB color to Hue/V-range. - rgb_to_hsv(static_cast(*(input + i)), static_cast(*(input + i + 1)), - static_cast(*(input + i + 2)), &h, &s, &v); - s = std::min(1.0f, std::max(0.0f, s * scale[0])); - // Convert the hue and v-range back into RGB. - hsv_to_rgb(h, s, v, &output[i], &output[i + 1], &output[i + 2]); - } - }; - std::int64_t total = num_elements / kChannelSize; - if (total > kAdjustSaturationParallelNum) { - std::int64_t per_unit_size = - total / std::min(kAdjustSaturationParallelNum - SizeToLong(kAdjustSaturationTwo), total); - CPUKernelUtils::ParallelFor(sharder_adjustsaturation, static_cast(total), - static_cast(per_unit_size)); - } else { - sharder_adjustsaturation(0, total); - } - return true; -} -} // namespace detail - -bool AdjustSaturationCpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kAdjustSaturationTwo, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kAdjustSaturationOne, kernel_name_); - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto is_match = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match.first) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr; - return false; - } - input_type_ = inputs[kIndex0]->dtype_id(); - return true; -} - -bool AdjustSaturationCpuKernelMod::Launch(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs) { - if (input_type_ == kNumberTypeFloat32) { - return detail::LaunchAdjustSaturationKernel(inputs, outputs); - } else if (input_type_ == kNumberTypeFloat16) { - return detail::LaunchAdjustSaturationKernel(inputs, outputs); - } else { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', unsupported input data type " << TypeIdLabel(input_type_); - } -} -std::vector AdjustSaturationCpuKernelMod::GetOpSupport() { - static const std::vector support_list = { - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)}; - return support_list; -} -MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, AdjustSaturation, AdjustSaturationCpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/cpu/kernel/adjust_saturation_cpu_kernel.h" +#include +#include +#include +#include +#include "plugin/device/cpu/hal/device/cpu_device_address.h" +#include "utils/ms_utils.h" + +namespace mindspore { +namespace kernel { +namespace { +const std::int64_t kAdjustSaturationParallelNum = 64 * 1024; +const std::int64_t kAdjustSaturationZero = 0; +const std::int64_t kAdjustSaturationOne = 1; +const std::int64_t kAdjustSaturationTwo = 2; +const std::int64_t kAdjustSaturationThree = 3; +const std::int64_t kAdjustSaturationFour = 4; +const std::int64_t kAdjustSaturationFive = 5; +const std::float_t kAdjustSaturationSix = 6; +} // namespace + +namespace detail { +static void rgb_to_hsv(float r, float g, float b, float *h, float *s, float *v) { + float vv = std::max(r, std::max(g, b)); + float range = vv - std::min(r, std::min(g, b)); + const float eps = 1e-6; + if (vv > 0) { + *s = range / vv; + } else { + *s = 0; + } + float norm = kAdjustSaturationOne / (kAdjustSaturationSix * range); + float hh; + if (std::fabs(r - vv) <= eps) { + hh = norm * (g - b); + } else if (std::fabs(g - vv) <= eps) { + hh = norm * (b - r) + kAdjustSaturationTwo / kAdjustSaturationSix; + } else { + hh = norm * (r - g) + kAdjustSaturationFour / kAdjustSaturationSix; + } + if (range <= 0.0) { + hh = 0; + } + if (hh < 0.0) { + hh = hh + kAdjustSaturationOne; + } + *v = vv; + *h = hh; +} + +template +static void hsv_to_rgb(float h, float s, float v, T *r, T *g, T *b) { + float c = s * v; + float m = v - c; + float dh = h * kAdjustSaturationSix; + float rr, gg, bb; + int h_category = static_cast(dh); + float fmodu = dh; + while (fmodu <= 0) { + fmodu += kAdjustSaturationTwo; + } + while (fmodu >= kAdjustSaturationTwo) { + fmodu -= kAdjustSaturationTwo; + } + float x = c * (1 - std::abs(fmodu - 1)); + switch (h_category) { + case kAdjustSaturationZero: + rr = c; + gg = x; + bb = 0; + break; + case kAdjustSaturationOne: + rr = x; + gg = c; + bb = 0; + break; + case kAdjustSaturationTwo: + rr = 0; + gg = c; + bb = x; + break; + case kAdjustSaturationThree: + rr = 0; + gg = x; + bb = c; + break; + case kAdjustSaturationFour: + rr = x; + gg = 0; + bb = c; + break; + case kAdjustSaturationFive: + rr = c; + gg = 0; + bb = x; + break; + default: + rr = 0; + gg = 0; + bb = 0; + } + *r = static_cast(rr + m); + *g = static_cast(gg + m); + *b = static_cast(bb + m); +} + +template +bool LaunchAdjustSaturationKernel(const std::vector &inputs, + const std::vector &outputs) { + auto input{static_cast(inputs[0]->device_ptr())}; + auto scale{static_cast(inputs[1]->device_ptr())}; + auto output{static_cast(outputs[0]->device_ptr())}; + constexpr int64_t kChannelSize = 3; + std::int64_t num_elements = static_cast(inputs[0]->size() / sizeof(T)); + auto sharder_adjustsaturation = [input, scale, output, kChannelSize](int64_t start, int64_t end) { + for (int64_t i = start * kChannelSize; i < end * kChannelSize; i = i + kChannelSize) { + float h, s, v; + // Convert the RGB color to Hue/V-range. + rgb_to_hsv(static_cast(*(input + i)), static_cast(*(input + i + 1)), + static_cast(*(input + i + 2)), &h, &s, &v); + s = std::min(1.0f, std::max(0.0f, s * scale[0])); + // Convert the hue and v-range back into RGB. + hsv_to_rgb(h, s, v, &output[i], &output[i + 1], &output[i + 2]); + } + }; + std::int64_t total = num_elements / kChannelSize; + if (total > kAdjustSaturationParallelNum) { + std::int64_t per_unit_size = + total / std::min(kAdjustSaturationParallelNum - SizeToLong(kAdjustSaturationTwo), total); + CPUKernelUtils::ParallelFor(sharder_adjustsaturation, static_cast(total), + static_cast(per_unit_size)); + } else { + sharder_adjustsaturation(0, total); + } + return true; +} +} // namespace detail + +bool AdjustSaturationCpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kAdjustSaturationTwo, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kAdjustSaturationOne, kernel_name_); + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto is_match = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match.first) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr; + return false; + } + input_type_ = inputs[kIndex0]->dtype_id(); + return true; +} + +bool AdjustSaturationCpuKernelMod::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + if (input_type_ == kNumberTypeFloat32) { + return detail::LaunchAdjustSaturationKernel(inputs, outputs); + } else if (input_type_ == kNumberTypeFloat16) { + return detail::LaunchAdjustSaturationKernel(inputs, outputs); + } else { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', unsupported input data type " << TypeIdLabel(input_type_); + } +} +std::vector AdjustSaturationCpuKernelMod::GetOpSupport() { + static const std::vector support_list = { + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)}; + return support_list; +} +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, AdjustSaturation, AdjustSaturationCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/adjust_saturation_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/adjust_saturation_cpu_kernel.h index 23d549984c8..34814130ba4 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/adjust_saturation_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/adjust_saturation_cpu_kernel.h @@ -1,43 +1,43 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DECIVE_CPU_ADJUST_SATURATION_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_PLUGIN_DECIVE_CPU_ADJUST_SATURATION_CPU_KERNEL_H_ -#include -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore { -namespace kernel { -class AdjustSaturationCpuKernelMod : public NativeCpuKernelMod { - public: - AdjustSaturationCpuKernelMod() = default; - ~AdjustSaturationCpuKernelMod() override = default; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs); - - std::vector GetOpSupport() override; - - private: - TypeId input_type_{kTypeUnknown}; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PLUGIN_DECIVE_CPU_ADJUST_SATURATION_CPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DECIVE_CPU_ADJUST_SATURATION_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_PLUGIN_DECIVE_CPU_ADJUST_SATURATION_CPU_KERNEL_H_ +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class AdjustSaturationCpuKernelMod : public NativeCpuKernelMod { + public: + AdjustSaturationCpuKernelMod() = default; + ~AdjustSaturationCpuKernelMod() override = default; + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs); + + std::vector GetOpSupport() override; + + private: + TypeId input_type_{kTypeUnknown}; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PLUGIN_DECIVE_CPU_ADJUST_SATURATION_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/apply_power_sign_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/apply_power_sign_cpu_kernel.cc index 0338c83a3b5..83c0a0b3231 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/apply_power_sign_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/apply_power_sign_cpu_kernel.cc @@ -1,196 +1,196 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/cpu/kernel/apply_power_sign_cpu_kernel.h" -#include -#include -#include -#include "kernel/common_utils.h" -#include "plugin/device/cpu/kernel/nnacl/fp32/adam_fp32.h" -#include "plugin/device/cpu/hal/device/cpu_device_address.h" -#include "ops/op_utils.h" - -namespace mindspore { -namespace kernel { -namespace { -constexpr size_t kPowerSignInputsNum = 7; -constexpr size_t kPowerSignOutputsNum = 2; -constexpr size_t kIndexVar = 0; -constexpr size_t kIndexM = 1; -constexpr size_t kIndexLr = 2; -constexpr size_t kIndexLogBase = 3; -constexpr size_t kIndexSignDecay = 4; -constexpr size_t kIndexBeta = 5; -constexpr size_t kIndexGrad = 6; - -template -int Sgn(const T &x) { - if (x > T(0)) { - return 1; - } - if (x < T(0)) { - return -1; - } - return 0; -} -} // namespace - -template -void ApplyPowerSignCpuKernelMod::LaunchPowerSign(const std::vector &inputs, - const std::vector &) { - T *var = reinterpret_cast(inputs[kIndexVar]->device_ptr()); - T *m = reinterpret_cast(inputs[kIndexM]->device_ptr()); - T *lr = reinterpret_cast(inputs[kIndexLr]->device_ptr()); - T *logbase = reinterpret_cast(inputs[kIndexLogBase]->device_ptr()); - T *sign_decay = reinterpret_cast(inputs[kIndexSignDecay]->device_ptr()); - T *beta = reinterpret_cast(inputs[kIndexBeta]->device_ptr()); - T *gradient = reinterpret_cast(inputs[kIndexGrad]->device_ptr()); - - for (int64_t b = 0; b < batch_size_; b++) { - // multithreading - auto task = [this, &var, &m, &gradient, &lr, &beta, &logbase, &sign_decay](size_t start, size_t end) { - T one = static_cast(1.0); - for (size_t i = start; i < end; i++) { - m[i] = gradient[i] * (one - beta[0]) + m[i] * beta[0]; - T sign_value = static_cast(Sgn(gradient[i]) * Sgn(m[i])); - T update = exp(logbase[0] * sign_decay[0] * sign_value) * gradient[i]; - var[i] = var[i] - lr[0] * update; - } - }; - ParallelLaunchAutoSearch(task, LongToSize(input_elements_), this, ¶llel_search_info_); - var = var + input_elements_; - m = m + input_elements_; - gradient = gradient + input_elements_; - lr++; - beta++; - logbase++; - sign_decay++; - } -} - -bool ApplyPowerSignCpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - dtype_ = inputs[0]->dtype_id(); - batch_rank_ = ops::get_batch_rank(primitive_); - return true; -} - -bool ApplyPowerSignCpuKernelMod::Launch(const std::vector &inputs, - const std::vector &, - const std::vector &outputs) { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kPowerSignInputsNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kPowerSignOutputsNum, kernel_name_); - - if (dtype_ == kNumberTypeFloat32) { - LaunchPowerSign(inputs, outputs); - } else if (dtype_ == kNumberTypeFloat16) { - LaunchPowerSign(inputs, outputs); - } else { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dtype of 'var' should be Float16 or Float32, but got " - << TypeIdToType(dtype_)->ToString(); - } - return true; -} - -int ApplyPowerSignCpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - int ret = KernelMod::Resize(inputs, outputs); - if (ret != 0) { - return ret; - } - - std::vector var_shape = inputs[kIndexVar]->GetShapeVector(); - std::vector m_shape = inputs[kIndexM]->GetShapeVector(); - std::vector lr_shape = inputs[kIndexLr]->GetShapeVector(); - std::vector grad_shape = inputs[kIndexGrad]->GetShapeVector(); - - if (var_shape.empty()) { - MS_LOG(ERROR) << "For '" << kernel_name_ - << "', the dimension of 'var' must be at least 1-D, but got scalar or None."; - return KRET_RESIZE_FAILED; - } - - if (!IsSameShape(var_shape, m_shape)) { - MS_LOG(ERROR) << "For '" << kernel_name_ - << "', the shape of 'accum' must be the same as the shape of 'var', " - "but got the shape of 'accum': " - << m_shape << " and the shape of 'var': " << var_shape; - return KRET_RESIZE_FAILED; - } - - if (!IsSameShape(var_shape, grad_shape)) { - MS_LOG(ERROR) << "For '" << kernel_name_ - << "', the shape of 'grad' must be the same as the shape of 'var', " - "but got the shape of 'grad': " - << grad_shape << " and the shape of 'var': " << var_shape; - return KRET_RESIZE_FAILED; - } - - if ((batch_rank_ != 0) && (lr_shape.size() != static_cast(batch_rank_))) { - MS_LOG(ERROR) << "For '" << kernel_name_ - << "', the shape size of 'lr' must be equal to 'batch_rank', " - "but got the shape of 'lr': " - << lr_shape << " and 'batch_rank': " << batch_rank_; - return KRET_RESIZE_FAILED; - } - - if (!lr_shape.empty()) { - batch_size_ = std::accumulate(lr_shape.begin(), lr_shape.end(), 1, std::multiplies()); - } - - if (batch_size_ <= 0) { - MS_LOG(ERROR) << "For '" << kernel_name_ - << "', batch_size_ must be greater than 0, but got batch_size: " << batch_size_; - return KRET_RESIZE_FAILED; - } - - input_elements_ = std::accumulate(var_shape.begin(), var_shape.end(), 1, std::multiplies()); - input_elements_ = input_elements_ / batch_size_; - - return ret; -} - -std::vector ApplyPowerSignCpuKernelMod::GetOpSupport() { - static std::vector kernel_attr_list = {KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutInRef(0, 0) - .AddOutInRef(1, 1), - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutInRef(0, 0) - .AddOutInRef(1, 1)}; - return kernel_attr_list; -} - -MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, ApplyPowerSign, ApplyPowerSignCpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/cpu/kernel/apply_power_sign_cpu_kernel.h" +#include +#include +#include +#include "kernel/common_utils.h" +#include "plugin/device/cpu/kernel/nnacl/fp32/adam_fp32.h" +#include "plugin/device/cpu/hal/device/cpu_device_address.h" +#include "ops/op_utils.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kPowerSignInputsNum = 7; +constexpr size_t kPowerSignOutputsNum = 2; +constexpr size_t kIndexVar = 0; +constexpr size_t kIndexM = 1; +constexpr size_t kIndexLr = 2; +constexpr size_t kIndexLogBase = 3; +constexpr size_t kIndexSignDecay = 4; +constexpr size_t kIndexBeta = 5; +constexpr size_t kIndexGrad = 6; + +template +int Sgn(const T &x) { + if (x > T(0)) { + return 1; + } + if (x < T(0)) { + return -1; + } + return 0; +} +} // namespace + +template +void ApplyPowerSignCpuKernelMod::LaunchPowerSign(const std::vector &inputs, + const std::vector &) { + T *var = reinterpret_cast(inputs[kIndexVar]->device_ptr()); + T *m = reinterpret_cast(inputs[kIndexM]->device_ptr()); + T *lr = reinterpret_cast(inputs[kIndexLr]->device_ptr()); + T *logbase = reinterpret_cast(inputs[kIndexLogBase]->device_ptr()); + T *sign_decay = reinterpret_cast(inputs[kIndexSignDecay]->device_ptr()); + T *beta = reinterpret_cast(inputs[kIndexBeta]->device_ptr()); + T *gradient = reinterpret_cast(inputs[kIndexGrad]->device_ptr()); + + for (int64_t b = 0; b < batch_size_; b++) { + // multithreading + auto task = [this, &var, &m, &gradient, &lr, &beta, &logbase, &sign_decay](size_t start, size_t end) { + T one = static_cast(1.0); + for (size_t i = start; i < end; i++) { + m[i] = gradient[i] * (one - beta[0]) + m[i] * beta[0]; + T sign_value = static_cast(Sgn(gradient[i]) * Sgn(m[i])); + T update = exp(logbase[0] * sign_decay[0] * sign_value) * gradient[i]; + var[i] = var[i] - lr[0] * update; + } + }; + ParallelLaunchAutoSearch(task, LongToSize(input_elements_), this, ¶llel_search_info_); + var = var + input_elements_; + m = m + input_elements_; + gradient = gradient + input_elements_; + lr++; + beta++; + logbase++; + sign_decay++; + } +} + +bool ApplyPowerSignCpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + dtype_ = inputs[0]->dtype_id(); + batch_rank_ = ops::get_batch_rank(primitive_); + return true; +} + +bool ApplyPowerSignCpuKernelMod::Launch(const std::vector &inputs, + const std::vector &, + const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kPowerSignInputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kPowerSignOutputsNum, kernel_name_); + + if (dtype_ == kNumberTypeFloat32) { + LaunchPowerSign(inputs, outputs); + } else if (dtype_ == kNumberTypeFloat16) { + LaunchPowerSign(inputs, outputs); + } else { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dtype of 'var' should be Float16 or Float32, but got " + << TypeIdToType(dtype_)->ToString(); + } + return true; +} + +int ApplyPowerSignCpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + int ret = KernelMod::Resize(inputs, outputs); + if (ret != 0) { + return ret; + } + + std::vector var_shape = inputs[kIndexVar]->GetShapeVector(); + std::vector m_shape = inputs[kIndexM]->GetShapeVector(); + std::vector lr_shape = inputs[kIndexLr]->GetShapeVector(); + std::vector grad_shape = inputs[kIndexGrad]->GetShapeVector(); + + if (var_shape.empty()) { + MS_LOG(ERROR) << "For '" << kernel_name_ + << "', the dimension of 'var' must be at least 1-D, but got scalar or None."; + return KRET_RESIZE_FAILED; + } + + if (!IsSameShape(var_shape, m_shape)) { + MS_LOG(ERROR) << "For '" << kernel_name_ + << "', the shape of 'accum' must be the same as the shape of 'var', " + "but got the shape of 'accum': " + << m_shape << " and the shape of 'var': " << var_shape; + return KRET_RESIZE_FAILED; + } + + if (!IsSameShape(var_shape, grad_shape)) { + MS_LOG(ERROR) << "For '" << kernel_name_ + << "', the shape of 'grad' must be the same as the shape of 'var', " + "but got the shape of 'grad': " + << grad_shape << " and the shape of 'var': " << var_shape; + return KRET_RESIZE_FAILED; + } + + if ((batch_rank_ != 0) && (lr_shape.size() != static_cast(batch_rank_))) { + MS_LOG(ERROR) << "For '" << kernel_name_ + << "', the shape size of 'lr' must be equal to 'batch_rank', " + "but got the shape of 'lr': " + << lr_shape << " and 'batch_rank': " << batch_rank_; + return KRET_RESIZE_FAILED; + } + + if (!lr_shape.empty()) { + batch_size_ = std::accumulate(lr_shape.begin(), lr_shape.end(), 1, std::multiplies()); + } + + if (batch_size_ <= 0) { + MS_LOG(ERROR) << "For '" << kernel_name_ + << "', batch_size_ must be greater than 0, but got batch_size: " << batch_size_; + return KRET_RESIZE_FAILED; + } + + input_elements_ = std::accumulate(var_shape.begin(), var_shape.end(), 1, std::multiplies()); + input_elements_ = input_elements_ / batch_size_; + + return ret; +} + +std::vector ApplyPowerSignCpuKernelMod::GetOpSupport() { + static std::vector kernel_attr_list = {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutInRef(0, 0) + .AddOutInRef(1, 1), + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutInRef(0, 0) + .AddOutInRef(1, 1)}; + return kernel_attr_list; +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, ApplyPowerSign, ApplyPowerSignCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/apply_power_sign_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/apply_power_sign_cpu_kernel.h index 23516c97d66..b02b92d4dbc 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/apply_power_sign_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/apply_power_sign_cpu_kernel.h @@ -1,55 +1,55 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_APPLY_POWER_SIGN_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_APPLY_POWER_SIGN_CPU_KERNEL_H_ - -#include -#include -#include - -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore { -namespace kernel { -class ApplyPowerSignCpuKernelMod : public NativeCpuKernelMod { - public: - ApplyPowerSignCpuKernelMod() = default; - ~ApplyPowerSignCpuKernelMod() override = default; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - std::vector GetOpSupport() override; - - private: - template - void LaunchPowerSign(const std::vector &inputs, const std::vector &outputs); - int64_t batch_size_{1}; - int64_t batch_rank_{0}; - bool use_locking{false}; - int64_t input_elements_{0}; - TypeId dtype_{kTypeUnknown}; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_APPLY_POWER_SIGN_CPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_APPLY_POWER_SIGN_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_APPLY_POWER_SIGN_CPU_KERNEL_H_ + +#include +#include +#include + +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class ApplyPowerSignCpuKernelMod : public NativeCpuKernelMod { + public: + ApplyPowerSignCpuKernelMod() = default; + ~ApplyPowerSignCpuKernelMod() override = default; + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + std::vector GetOpSupport() override; + + private: + template + void LaunchPowerSign(const std::vector &inputs, const std::vector &outputs); + int64_t batch_size_{1}; + int64_t batch_rank_{0}; + bool use_locking{false}; + int64_t input_elements_{0}; + TypeId dtype_{kTypeUnknown}; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_APPLY_POWER_SIGN_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/betainc_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/betainc_cpu_kernel.cc index d28a9018de0..5c74d11bfd0 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/betainc_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/betainc_cpu_kernel.cc @@ -1,104 +1,104 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/cpu/kernel/betainc_cpu_kernel.h" -#include -#include -#include "unsupported/Eigen/CXX11/Tensor" - -namespace mindspore { -namespace kernel { -namespace { -constexpr size_t kBetaincInputsNum = 3; -constexpr size_t kBetaincOutputsNum = 1; -} // namespace - -bool BetaincCpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { - return MatchKernelFunc(kernel_name_, inputs, outputs); -} - -int BetaincCpuKernelMod::Resize(const std::vector &inputs, const std::vector &outputs) { - int ret = 0; - if ((ret = NativeCpuKernelMod::Resize(inputs, outputs)) != 0) { - return ret; - } - input0_shape_ = inputs[kIndex0]->GetShapeVector(); - input1_shape_ = inputs[kIndex1]->GetShapeVector(); - input2_shape_ = inputs[kIndex2]->GetShapeVector(); - output_shape_ = outputs[kIndex0]->GetShapeVector(); - if (!IsSameShape(input0_shape_, input1_shape_)) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', the shape of 'b' should be same with the shape of 'a', " - << "but got the shape of 'b': " << input1_shape_ << " and 'a': " << input0_shape_; - return KRET_RESIZE_FAILED; - } - if (!IsSameShape(input0_shape_, input2_shape_)) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', the shape of 'x' should be same with the shape of 'a', " - << "but got the shape of 'x': " << input2_shape_ << " and 'a': " << input0_shape_; - return KRET_RESIZE_FAILED; - } - if (!IsSameShape(input0_shape_, output_shape_)) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', the shape of output should be same with the shape of the 'a', " - << "but got the shape of the output: " << output_shape_ << " and 'a': " << input0_shape_; - return KRET_RESIZE_FAILED; - } - return 0; -} - -template -inline T ScalarBetainc(T a, T b, T x) { - return Eigen::numext::betainc(a, b, x); -} - -template -bool BetaincCpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs) { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kBetaincInputsNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kBetaincOutputsNum, kernel_name_); - T *input0 = reinterpret_cast(inputs[0]->device_ptr()); - T *input1 = reinterpret_cast(inputs[1]->device_ptr()); - T *input2 = reinterpret_cast(inputs[2]->device_ptr()); - T *output = reinterpret_cast(outputs[0]->device_ptr()); - auto total = inputs[0]->size() / sizeof(T); - auto task = [&input0, &input1, &input2, &output](std::int64_t begin, std::int64_t end) { - for (std::int64_t i = begin; i < end; i++) { - output[i] = ScalarBetainc(input0[i], input1[i], input2[i]); - } - }; - ParallelLaunchAutoSearch(task, total, this, ¶llel_search_info_); - return true; -} - -const std::vector> &BetaincCpuKernelMod::GetFuncList() const { - static const std::vector> func_list = { - {KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - &BetaincCpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeFloat64) - .AddOutputAttr(kNumberTypeFloat64), - &BetaincCpuKernelMod::LaunchKernel}}; - return func_list; -} - -MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Betainc, BetaincCpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/cpu/kernel/betainc_cpu_kernel.h" +#include +#include +#include "unsupported/Eigen/CXX11/Tensor" + +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kBetaincInputsNum = 3; +constexpr size_t kBetaincOutputsNum = 1; +} // namespace + +bool BetaincCpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { + return MatchKernelFunc(kernel_name_, inputs, outputs); +} + +int BetaincCpuKernelMod::Resize(const std::vector &inputs, const std::vector &outputs) { + int ret = 0; + if ((ret = NativeCpuKernelMod::Resize(inputs, outputs)) != 0) { + return ret; + } + input0_shape_ = inputs[kIndex0]->GetShapeVector(); + input1_shape_ = inputs[kIndex1]->GetShapeVector(); + input2_shape_ = inputs[kIndex2]->GetShapeVector(); + output_shape_ = outputs[kIndex0]->GetShapeVector(); + if (!IsSameShape(input0_shape_, input1_shape_)) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', the shape of 'b' should be same with the shape of 'a', " + << "but got the shape of 'b': " << input1_shape_ << " and 'a': " << input0_shape_; + return KRET_RESIZE_FAILED; + } + if (!IsSameShape(input0_shape_, input2_shape_)) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', the shape of 'x' should be same with the shape of 'a', " + << "but got the shape of 'x': " << input2_shape_ << " and 'a': " << input0_shape_; + return KRET_RESIZE_FAILED; + } + if (!IsSameShape(input0_shape_, output_shape_)) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', the shape of output should be same with the shape of the 'a', " + << "but got the shape of the output: " << output_shape_ << " and 'a': " << input0_shape_; + return KRET_RESIZE_FAILED; + } + return 0; +} + +template +inline T ScalarBetainc(T a, T b, T x) { + return Eigen::numext::betainc(a, b, x); +} + +template +bool BetaincCpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kBetaincInputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kBetaincOutputsNum, kernel_name_); + T *input0 = reinterpret_cast(inputs[0]->device_ptr()); + T *input1 = reinterpret_cast(inputs[1]->device_ptr()); + T *input2 = reinterpret_cast(inputs[2]->device_ptr()); + T *output = reinterpret_cast(outputs[0]->device_ptr()); + auto total = inputs[0]->size() / sizeof(T); + auto task = [&input0, &input1, &input2, &output](std::int64_t begin, std::int64_t end) { + for (std::int64_t i = begin; i < end; i++) { + output[i] = ScalarBetainc(input0[i], input1[i], input2[i]); + } + }; + ParallelLaunchAutoSearch(task, total, this, ¶llel_search_info_); + return true; +} + +const std::vector> &BetaincCpuKernelMod::GetFuncList() const { + static const std::vector> func_list = { + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + &BetaincCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat64), + &BetaincCpuKernelMod::LaunchKernel}}; + return func_list; +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Betainc, BetaincCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/betainc_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/betainc_cpu_kernel.h index 428dd5a96eb..668600b4190 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/betainc_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/betainc_cpu_kernel.h @@ -1,62 +1,62 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BETAINC_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BETAINC_CPU_KERNEL_H_ - -#include -#include -#include - -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore { -namespace kernel { -class BetaincCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper { - public: - BetaincCpuKernelMod() = default; - - ~BetaincCpuKernelMod() override = default; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override { - return kernel_func_(this, inputs, workspace, outputs); - } - - const std::vector> &GetFuncList() const override; - - protected: - std::vector GetOpSupport() override { return OpSupport(); } - - private: - template - bool LaunchKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs); - std::vector input0_shape_; - std::vector input1_shape_; - std::vector input2_shape_; - std::vector output_shape_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BETAINC_CPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BETAINC_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BETAINC_CPU_KERNEL_H_ + +#include +#include +#include + +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class BetaincCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper { + public: + BetaincCpuKernelMod() = default; + + ~BetaincCpuKernelMod() override = default; + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override { + return kernel_func_(this, inputs, workspace, outputs); + } + + const std::vector> &GetFuncList() const override; + + protected: + std::vector GetOpSupport() override { return OpSupport(); } + + private: + template + bool LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs); + std::vector input0_shape_; + std::vector input1_shape_; + std::vector input2_shape_; + std::vector output_shape_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BETAINC_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/blackman_window_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/blackman_window_cpu_kernel.cc index 3a490b73513..1540f51606c 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/blackman_window_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/blackman_window_cpu_kernel.cc @@ -1,102 +1,102 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include "plugin/device/cpu/kernel/blackman_window_cpu_kernel.h" -#include "plugin/device/cpu/hal/device/cpu_device_address.h" -#include "mindspore/core/ops/blackman_window.h" - -namespace mindspore { -namespace kernel { -namespace { -constexpr size_t kBlackmanWindowInputsNum = 1; -constexpr size_t kBlackmanWindowOutputsNum = 1; -} // namespace - -bool BlackmanWindowCpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kBlackmanWindowInputsNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kBlackmanWindowOutputsNum, kernel_name_); - - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr; - return false; - } - kernel_func_ = func_list_[index].second; - periodic_ = GetValue(primitive_->GetAttr(ops::kPeriodic)); - return true; -} - -template -bool BlackmanWindowCpuKernelMod::BlackmanWindowKernelFunc(const std::vector &inputs, - const std::vector &, - const std::vector &outputs) const { - auto input = static_cast(inputs[0]->device_ptr()); - auto output = static_cast(outputs[0]->device_ptr()); - - if (*input < 0) { - MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', input window_length should be >= 0, but got " << *input; - } - - auto window_length = static_cast(*input); - double pre_window_length = static_cast(window_length); - const size_t OUTPUTISONE = 1; - - if (*input == 1) { - *output = static_cast(OUTPUTISONE); - } else { - if (periodic_) { - window_length += 1; - } - const double PI = 3.14159265358979323846; - const double x = static_cast(window_length); - for (size_t i = 0; i < pre_window_length; i++) { - auto temp = static_cast(0.08 * cos((4 * PI * i) / (x - 1)) - 0.5 * cos((2 * PI * i) / (x - 1)) + 0.42); - *(output + i) = temp; - } - } - return true; -} - -std::vector> - BlackmanWindowCpuKernelMod::func_list_ = { - {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), - &BlackmanWindowCpuKernelMod::BlackmanWindowKernelFunc}, - {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), - &BlackmanWindowCpuKernelMod::BlackmanWindowKernelFunc}, - {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64), - &BlackmanWindowCpuKernelMod::BlackmanWindowKernelFunc}, - {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), - &BlackmanWindowCpuKernelMod::BlackmanWindowKernelFunc}, - {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), - &BlackmanWindowCpuKernelMod::BlackmanWindowKernelFunc}, - {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64), - &BlackmanWindowCpuKernelMod::BlackmanWindowKernelFunc}}; - -std::vector BlackmanWindowCpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - - return support_list; -} -MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, BlackmanWindow, BlackmanWindowCpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "plugin/device/cpu/kernel/blackman_window_cpu_kernel.h" +#include "plugin/device/cpu/hal/device/cpu_device_address.h" +#include "mindspore/core/ops/blackman_window.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kBlackmanWindowInputsNum = 1; +constexpr size_t kBlackmanWindowOutputsNum = 1; +} // namespace + +bool BlackmanWindowCpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kBlackmanWindowInputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kBlackmanWindowOutputsNum, kernel_name_); + + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr; + return false; + } + kernel_func_ = func_list_[index].second; + periodic_ = GetValue(primitive_->GetAttr(ops::kPeriodic)); + return true; +} + +template +bool BlackmanWindowCpuKernelMod::BlackmanWindowKernelFunc(const std::vector &inputs, + const std::vector &, + const std::vector &outputs) const { + auto input = static_cast(inputs[0]->device_ptr()); + auto output = static_cast(outputs[0]->device_ptr()); + + if (*input < 0) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', input window_length should be >= 0, but got " << *input; + } + + auto window_length = static_cast(*input); + double pre_window_length = static_cast(window_length); + const size_t OUTPUTISONE = 1; + + if (*input == 1) { + *output = static_cast(OUTPUTISONE); + } else { + if (periodic_) { + window_length += 1; + } + const double PI = 3.14159265358979323846; + const double x = static_cast(window_length); + for (size_t i = 0; i < pre_window_length; i++) { + auto temp = static_cast(0.08 * cos((4 * PI * i) / (x - 1)) - 0.5 * cos((2 * PI * i) / (x - 1)) + 0.42); + *(output + i) = temp; + } + } + return true; +} + +std::vector> + BlackmanWindowCpuKernelMod::func_list_ = { + {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + &BlackmanWindowCpuKernelMod::BlackmanWindowKernelFunc}, + {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), + &BlackmanWindowCpuKernelMod::BlackmanWindowKernelFunc}, + {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64), + &BlackmanWindowCpuKernelMod::BlackmanWindowKernelFunc}, + {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), + &BlackmanWindowCpuKernelMod::BlackmanWindowKernelFunc}, + {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), + &BlackmanWindowCpuKernelMod::BlackmanWindowKernelFunc}, + {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64), + &BlackmanWindowCpuKernelMod::BlackmanWindowKernelFunc}}; + +std::vector BlackmanWindowCpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + + return support_list; +} +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, BlackmanWindow, BlackmanWindowCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/blackman_window_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/blackman_window_cpu_kernel.h index 9deb625c0dd..918a5589fc2 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/blackman_window_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/blackman_window_cpu_kernel.h @@ -1,56 +1,56 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_CPU_BLACKMAN_WINDOW_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_CPU_BLACKMAN_WINDOW_CPU_KERNEL_H_ - -#include -#include -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore { -namespace kernel { -class BlackmanWindowCpuKernelMod : public NativeCpuKernelMod { - public: - BlackmanWindowCpuKernelMod() = default; - ~BlackmanWindowCpuKernelMod() override = default; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override { - return kernel_func_(this, inputs, workspace, outputs); - } - - std::vector GetOpSupport() override; - - private: - template - bool BlackmanWindowKernelFunc(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs) const; - bool periodic_{true}; - TypeId input_dtype{kTypeUnknown}; - using BlackmanWindowFunc = - std::function &, - const std::vector &, const std::vector &)>; - static std::vector> func_list_; - BlackmanWindowFunc kernel_func_; -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_CPU_BLACKMAN_WINDOW_CPU_KERNEL_H_ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_CPU_BLACKMAN_WINDOW_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_CPU_BLACKMAN_WINDOW_CPU_KERNEL_H_ + +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class BlackmanWindowCpuKernelMod : public NativeCpuKernelMod { + public: + BlackmanWindowCpuKernelMod() = default; + ~BlackmanWindowCpuKernelMod() override = default; + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override { + return kernel_func_(this, inputs, workspace, outputs); + } + + std::vector GetOpSupport() override; + + private: + template + bool BlackmanWindowKernelFunc(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) const; + bool periodic_{true}; + TypeId input_dtype{kTypeUnknown}; + using BlackmanWindowFunc = + std::function &, + const std::vector &, const std::vector &)>; + static std::vector> func_list_; + BlackmanWindowFunc kernel_func_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_CPU_BLACKMAN_WINDOW_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/bucketize_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/bucketize_cpu_kernel.cc index a38159efbaa..abdcfae1f6a 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/bucketize_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/bucketize_cpu_kernel.cc @@ -1,115 +1,115 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include "ops/bucketize.h" -#include "plugin/device/cpu/kernel/bucketize_cpu_kernel.h" -#include "plugin/device/cpu/hal/device/cpu_device_address.h" -#include "utils/convert_utils_base.h" - -namespace mindspore { -namespace kernel { -namespace { -const size_t kOutputNum = 1; -const size_t kInputNum = 1; -const size_t kParallelDataNumSameShape = 64 * 1024; -const size_t kParallelDataNumSameShapeMid = 35 * 1024; -} // namespace - -bool BucketizeCpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - dtype_ = inputs.at(kIndex0)->dtype_id(); - boundaries_ = GetValue>(primitive_->GetAttr(ops::kBoundaries)); - return true; -} - -int BucketizeCpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { - return ret; - } - input_shape_ = inputs.at(kIndex0)->GetShapeVector(); - output_shape_ = outputs.at(kIndex0)->GetShapeVector(); - return KRET_OK; -} - -bool BucketizeCpuKernelMod::Launch(const std::vector &inputs, - const std::vector & /* workspace */, - const std::vector &outputs) { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_); - if (dtype_ != kNumberTypeInt32 && dtype_ != kNumberTypeInt64 && dtype_ != kNumberTypeFloat32 && - dtype_ != kNumberTypeFloat64) { - MS_LOG(EXCEPTION) << "Input data type must int32 or int64 or float32 or float64, but got data type." << dtype_; - } - size_t input_sizes = input_shape_.size(); - size_t output_sizes = output_shape_.size(); - if (input_sizes != output_sizes) { - MS_LOG(EXCEPTION) << "The tensor shape of input need be same with output."; - } - switch (dtype_) { - case kNumberTypeInt32: - return BucketizeCompute(inputs, outputs); - case kNumberTypeInt64: - return BucketizeCompute(inputs, outputs); - case kNumberTypeFloat32: - return BucketizeCompute(inputs, outputs); - case kNumberTypeFloat64: - return BucketizeCompute(inputs, outputs); - default: - MS_LOG(ERROR) << "Unsupported data type."; - } - return true; -} - -template -bool BucketizeCpuKernelMod::BucketizeCompute(const std::vector &inputs, - const std::vector &outputs) { - auto input_data = reinterpret_cast(inputs[0]->device_ptr()); - auto output_data = reinterpret_cast(outputs[0]->device_ptr()); - size_t data_num_ = std::accumulate(input_shape_.begin(), input_shape_.end(), size_t(1), std::multiplies()); - std::vector boundaries_data = boundaries_; - std::sort(boundaries_data.begin(), boundaries_data.end()); - if (data_num_ >= kParallelDataNumSameShape) { - auto sharder_bucketize = [&](size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - auto first_bigger_it = std::upper_bound(boundaries_data.begin(), boundaries_data.end(), input_data[i]); - output_data[i] = LongToInt(first_bigger_it - boundaries_data.begin()); - } - }; - ParallelLaunchAutoSearch(sharder_bucketize, data_num_, this, ¶llel_search_info_); - } else { - for (size_t i = 0; i < data_num_; i++) { - auto first_bigger_it = std::upper_bound(boundaries_data.begin(), boundaries_data.end(), input_data[i]); - output_data[i] = LongToInt(first_bigger_it - boundaries_data.begin()); - } - } - return true; -} - -std::vector BucketizeCpuKernelMod::GetOpSupport() { - static const std::vector support_list = { - KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), - KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32)}; - return support_list; -} - -MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Bucketize, BucketizeCpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "ops/bucketize.h" +#include "plugin/device/cpu/kernel/bucketize_cpu_kernel.h" +#include "plugin/device/cpu/hal/device/cpu_device_address.h" +#include "utils/convert_utils_base.h" + +namespace mindspore { +namespace kernel { +namespace { +const size_t kOutputNum = 1; +const size_t kInputNum = 1; +const size_t kParallelDataNumSameShape = 64 * 1024; +const size_t kParallelDataNumSameShapeMid = 35 * 1024; +} // namespace + +bool BucketizeCpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + dtype_ = inputs.at(kIndex0)->dtype_id(); + boundaries_ = GetValue>(primitive_->GetAttr(ops::kBoundaries)); + return true; +} + +int BucketizeCpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { + return ret; + } + input_shape_ = inputs.at(kIndex0)->GetShapeVector(); + output_shape_ = outputs.at(kIndex0)->GetShapeVector(); + return KRET_OK; +} + +bool BucketizeCpuKernelMod::Launch(const std::vector &inputs, + const std::vector & /* workspace */, + const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_); + if (dtype_ != kNumberTypeInt32 && dtype_ != kNumberTypeInt64 && dtype_ != kNumberTypeFloat32 && + dtype_ != kNumberTypeFloat64) { + MS_LOG(EXCEPTION) << "Input data type must int32 or int64 or float32 or float64, but got data type." << dtype_; + } + size_t input_sizes = input_shape_.size(); + size_t output_sizes = output_shape_.size(); + if (input_sizes != output_sizes) { + MS_LOG(EXCEPTION) << "The tensor shape of input need be same with output."; + } + switch (dtype_) { + case kNumberTypeInt32: + return BucketizeCompute(inputs, outputs); + case kNumberTypeInt64: + return BucketizeCompute(inputs, outputs); + case kNumberTypeFloat32: + return BucketizeCompute(inputs, outputs); + case kNumberTypeFloat64: + return BucketizeCompute(inputs, outputs); + default: + MS_LOG(ERROR) << "Unsupported data type."; + } + return true; +} + +template +bool BucketizeCpuKernelMod::BucketizeCompute(const std::vector &inputs, + const std::vector &outputs) { + auto input_data = reinterpret_cast(inputs[0]->device_ptr()); + auto output_data = reinterpret_cast(outputs[0]->device_ptr()); + size_t data_num_ = std::accumulate(input_shape_.begin(), input_shape_.end(), size_t(1), std::multiplies()); + std::vector boundaries_data = boundaries_; + std::sort(boundaries_data.begin(), boundaries_data.end()); + if (data_num_ >= kParallelDataNumSameShape) { + auto sharder_bucketize = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + auto first_bigger_it = std::upper_bound(boundaries_data.begin(), boundaries_data.end(), input_data[i]); + output_data[i] = LongToInt(first_bigger_it - boundaries_data.begin()); + } + }; + ParallelLaunchAutoSearch(sharder_bucketize, data_num_, this, ¶llel_search_info_); + } else { + for (size_t i = 0; i < data_num_; i++) { + auto first_bigger_it = std::upper_bound(boundaries_data.begin(), boundaries_data.end(), input_data[i]); + output_data[i] = LongToInt(first_bigger_it - boundaries_data.begin()); + } + } + return true; +} + +std::vector BucketizeCpuKernelMod::GetOpSupport() { + static const std::vector support_list = { + KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), + KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32)}; + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Bucketize, BucketizeCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/bucketize_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/bucketize_cpu_kernel.h index 7169c558fa9..ae510e4325f 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/bucketize_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/bucketize_cpu_kernel.h @@ -1,52 +1,52 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BUCKETIZE_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BUCKETIZE_CPU_KERNEL_H_ - -#include -#include -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore { -namespace kernel { -class BucketizeCpuKernelMod : public NativeCpuKernelMod { - public: - BucketizeCpuKernelMod() = default; - ~BucketizeCpuKernelMod() override = default; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - template - bool BucketizeCompute(const std::vector &inputs, const std::vector &outputs); - - std::vector GetOpSupport() override; - - private: - ShapeVector input_shape_; - ShapeVector output_shape_; - std::vector boundaries_; - TypeId dtype_{kTypeUnknown}; -}; -} // namespace kernel -} // namespace mindspore - -#endif +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BUCKETIZE_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BUCKETIZE_CPU_KERNEL_H_ + +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class BucketizeCpuKernelMod : public NativeCpuKernelMod { + public: + BucketizeCpuKernelMod() = default; + ~BucketizeCpuKernelMod() override = default; + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + template + bool BucketizeCompute(const std::vector &inputs, const std::vector &outputs); + + std::vector GetOpSupport() override; + + private: + ShapeVector input_shape_; + ShapeVector output_shape_; + std::vector boundaries_; + TypeId dtype_{kTypeUnknown}; +}; +} // namespace kernel +} // namespace mindspore + +#endif diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/cdist_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/cdist_cpu_kernel.cc index b61371d6a80..e8ab5fda083 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/cdist_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/cdist_cpu_kernel.cc @@ -1,141 +1,141 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/cpu/kernel/cdist_cpu_kernel.h" -#include -#include -#include "plugin/device/cpu/kernel/nnacl/op_base.h" -#include "plugin/device/cpu/kernel/nnacl/fp32/cdist_fp32.h" -namespace mindspore { -namespace kernel { -namespace { -constexpr size_t kCdistInputDimsMin = 2; - -const std::vector kernel_attr = { - {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)}}; -} // namespace - -void CdistCpuKernelMod::InitFunc(float p) { - if (p == 0.0) { - dist_func_ = CdistZeroNormalOpt; - } else if (p == 1.0) { - dist_func_ = CdistOneNormalOpt; - } else if (p == 2.0) { - dist_func_ = CdistTwoNormalOpt; - } else if (std::isinf(p)) { - dist_func_ = CdistInfNormalOpt; - } else { - dist_func_ = CdistPNormalOpt; - } -} - -bool CdistCpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { - p_ = GetValue(primitive_->GetAttr(ops::kP)); - auto input_type_id = inputs[0]->dtype_id(); - switch (input_type_id) { - case kNumberTypeFloat32: - InitFunc(p_); - break; - default: - MS_LOG(ERROR) << "cdist kernel does not support " << TypeIdToString(input_type_id); - return false; - } - return true; -} - -int CdistCpuKernelMod::Resize(const std::vector &inputs, const std::vector &outputs) { - int ret = 0; - if ((ret = KernelMod::Resize(inputs, outputs)) != 0) { - return ret; - } - std::vector in_shape0 = inputs[0]->GetShapeVector(); - std::vector in_shape1 = inputs[1]->GetShapeVector(); - auto in_shape_size = in_shape0.size(); - if (in_shape1.size() != in_shape_size || in_shape_size < kCdistInputDimsMin) { - MS_LOG(ERROR) << "invalid input shape, input0 shape size " << in_shape_size << ", input1 shape size " - << in_shape1.size() << ", kernel_name_ " << kernel_name_; - return KRET_RESIZE_FAILED; - } - batch_ = 1; - for (size_t i = 0; i < in_shape_size - kCdistInputDimsMin; i++) { - batch_ *= in_shape0[i]; - } - - r0_ = in_shape0[in_shape_size - 2]; - m_ = in_shape0[in_shape_size - 1]; - r1_ = in_shape1[in_shape_size - 2]; - - thread_num_ = std::min(static_cast(batch_), pool_->GetKernelThreadNum()); - - return 0; -} - -bool CdistCpuKernelMod::LaunchKernel(int64_t start, int64_t end) { - const auto *in_data0 = reinterpret_cast(in_data0_) + start * r0_ * m_; - const auto *in_data1 = reinterpret_cast(in_data1_) + start * r1_ * m_; - auto *out_data = reinterpret_cast(out_data_) + start * r0_ * r1_; - - for (int64_t b_i = 0; b_i < end - start; b_i++) { - for (int64_t p_i = 0; p_i < r0_; p_i++) { - auto in_data_tmp1 = in_data1; - for (int64_t r_i = 0; r_i < r1_; r_i++) { - dist_func_(in_data0, in_data_tmp1, &(out_data[r_i]), m_, p_); - in_data_tmp1 = in_data_tmp1 + m_; - } - in_data0 = in_data0 + m_; - out_data = out_data + r1_; - } - in_data1 = in_data1 + r1_ * m_; - } - - return true; -} - -std::vector CdistCpuKernelMod::GetOpSupport() { return kernel_attr; } - -bool CdistCpuKernelMod::DoLaunch(int task_id) { - auto batch_per_thread = UP_DIV(batch_, thread_num_); - int64_t start = batch_per_thread * task_id; - int64_t end = start + batch_per_thread; - end = std::min(end, batch_); - return LaunchKernel(start, end); -} - -int CdistRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) { - auto cdist_kernel = reinterpret_cast(cdata); - if (!cdist_kernel->DoLaunch(task_id)) { - MS_LOG(ERROR) << "cdist_kernel DoLaunch failed, task_id:" << task_id; - return -1; - } - return 0; -} - -bool CdistCpuKernelMod::Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) { - in_data0_ = inputs[0]->device_ptr(); - in_data1_ = inputs[1]->device_ptr(); - out_data_ = outputs[0]->device_ptr(); - int ret = pool_->ParallelLaunch(CdistRun, this, thread_num_); - if (ret != 0) { - MS_LOG(ERROR) << "CdistCpuKernelMod ParallelLaunch failed, error_code[" << ret << "]"; - return false; - } - return true; -} - -MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Cdist, CdistCpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/cpu/kernel/cdist_cpu_kernel.h" +#include +#include +#include "plugin/device/cpu/kernel/nnacl/op_base.h" +#include "plugin/device/cpu/kernel/nnacl/fp32/cdist_fp32.h" +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kCdistInputDimsMin = 2; + +const std::vector kernel_attr = { + {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)}}; +} // namespace + +void CdistCpuKernelMod::InitFunc(float p) { + if (p == 0.0) { + dist_func_ = CdistZeroNormalOpt; + } else if (p == 1.0) { + dist_func_ = CdistOneNormalOpt; + } else if (p == 2.0) { + dist_func_ = CdistTwoNormalOpt; + } else if (std::isinf(p)) { + dist_func_ = CdistInfNormalOpt; + } else { + dist_func_ = CdistPNormalOpt; + } +} + +bool CdistCpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { + p_ = GetValue(primitive_->GetAttr(ops::kP)); + auto input_type_id = inputs[0]->dtype_id(); + switch (input_type_id) { + case kNumberTypeFloat32: + InitFunc(p_); + break; + default: + MS_LOG(ERROR) << "cdist kernel does not support " << TypeIdToString(input_type_id); + return false; + } + return true; +} + +int CdistCpuKernelMod::Resize(const std::vector &inputs, const std::vector &outputs) { + int ret = 0; + if ((ret = KernelMod::Resize(inputs, outputs)) != 0) { + return ret; + } + std::vector in_shape0 = inputs[0]->GetShapeVector(); + std::vector in_shape1 = inputs[1]->GetShapeVector(); + auto in_shape_size = in_shape0.size(); + if (in_shape1.size() != in_shape_size || in_shape_size < kCdistInputDimsMin) { + MS_LOG(ERROR) << "invalid input shape, input0 shape size " << in_shape_size << ", input1 shape size " + << in_shape1.size() << ", kernel_name_ " << kernel_name_; + return KRET_RESIZE_FAILED; + } + batch_ = 1; + for (size_t i = 0; i < in_shape_size - kCdistInputDimsMin; i++) { + batch_ *= in_shape0[i]; + } + + r0_ = in_shape0[in_shape_size - 2]; + m_ = in_shape0[in_shape_size - 1]; + r1_ = in_shape1[in_shape_size - 2]; + + thread_num_ = std::min(static_cast(batch_), pool_->GetKernelThreadNum()); + + return 0; +} + +bool CdistCpuKernelMod::LaunchKernel(int64_t start, int64_t end) { + const auto *in_data0 = reinterpret_cast(in_data0_) + start * r0_ * m_; + const auto *in_data1 = reinterpret_cast(in_data1_) + start * r1_ * m_; + auto *out_data = reinterpret_cast(out_data_) + start * r0_ * r1_; + + for (int64_t b_i = 0; b_i < end - start; b_i++) { + for (int64_t p_i = 0; p_i < r0_; p_i++) { + auto in_data_tmp1 = in_data1; + for (int64_t r_i = 0; r_i < r1_; r_i++) { + dist_func_(in_data0, in_data_tmp1, &(out_data[r_i]), m_, p_); + in_data_tmp1 = in_data_tmp1 + m_; + } + in_data0 = in_data0 + m_; + out_data = out_data + r1_; + } + in_data1 = in_data1 + r1_ * m_; + } + + return true; +} + +std::vector CdistCpuKernelMod::GetOpSupport() { return kernel_attr; } + +bool CdistCpuKernelMod::DoLaunch(int task_id) { + auto batch_per_thread = UP_DIV(batch_, thread_num_); + int64_t start = batch_per_thread * task_id; + int64_t end = start + batch_per_thread; + end = std::min(end, batch_); + return LaunchKernel(start, end); +} + +int CdistRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) { + auto cdist_kernel = reinterpret_cast(cdata); + if (!cdist_kernel->DoLaunch(task_id)) { + MS_LOG(ERROR) << "cdist_kernel DoLaunch failed, task_id:" << task_id; + return -1; + } + return 0; +} + +bool CdistCpuKernelMod::Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) { + in_data0_ = inputs[0]->device_ptr(); + in_data1_ = inputs[1]->device_ptr(); + out_data_ = outputs[0]->device_ptr(); + int ret = pool_->ParallelLaunch(CdistRun, this, thread_num_); + if (ret != 0) { + MS_LOG(ERROR) << "CdistCpuKernelMod ParallelLaunch failed, error_code[" << ret << "]"; + return false; + } + return true; +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Cdist, CdistCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/cdist_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/cdist_cpu_kernel.h index 8d2840720df..cb183d159f4 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/cdist_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/cdist_cpu_kernel.h @@ -1,65 +1,65 @@ -/** - * Copyright 2020-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_CDIST_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_CDIST_CPU_KERNEL_H_ - -#include -#include -#include -#include -#include "mindspore/core/ops/cdist.h" -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "plugin/factory/ms_factory.h" -namespace mindspore { -namespace kernel { -class CdistCpuKernelMod : public NativeCpuKernelMod { - public: - CdistCpuKernelMod() {} - ~CdistCpuKernelMod() override = default; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - std::vector GetOpSupport() override; - bool DoLaunch(int task_id); - - private: - bool LaunchKernel(int64_t start, int64_t end); - - void InitFunc(float p); - - using DistFunc = std::function; - DistFunc dist_func_; - - int64_t batch_; - int64_t r0_; - int64_t m_; - int64_t r1_; - float p_ = 2; - size_t thread_num_; - void *in_data0_; - void *in_data1_; - void *out_data_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_CDIST_CPU_KERNEL_H_ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_CDIST_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_CDIST_CPU_KERNEL_H_ + +#include +#include +#include +#include +#include "mindspore/core/ops/cdist.h" +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" +namespace mindspore { +namespace kernel { +class CdistCpuKernelMod : public NativeCpuKernelMod { + public: + CdistCpuKernelMod() {} + ~CdistCpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + std::vector GetOpSupport() override; + bool DoLaunch(int task_id); + + private: + bool LaunchKernel(int64_t start, int64_t end); + + void InitFunc(float p); + + using DistFunc = std::function; + DistFunc dist_func_; + + int64_t batch_; + int64_t r0_; + int64_t m_; + int64_t r1_; + float p_ = 2; + size_t thread_num_; + void *in_data0_; + void *in_data1_; + void *out_data_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_CDIST_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/cdist_grad_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/cdist_grad_cpu_kernel.cc index b0ba6317976..671c5762d9b 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/cdist_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/cdist_grad_cpu_kernel.cc @@ -1,184 +1,184 @@ -/** - * Copyright 2020-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/cpu/kernel/cdist_grad_cpu_kernel.h" -#include -#include -#include "plugin/device/cpu/kernel/nnacl/op_base.h" -namespace mindspore { -namespace kernel { -namespace { -constexpr size_t kCdistInputDimsMin = 2; - -const std::vector kernel_attr = { - {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)}}; -} // namespace - -inline float DistSign(float val) { - return std::min(std::max(0.f, std::ceil(val)), (1.f)) + std::min(std::max((-1.f), std::floor(val)), (0.f)); -} - -float CdistOneNormalcompute(float diff, float grad, float dist, float p) { return grad * DistSign(diff); } - -float CdistLessTwoNormalcompute(float diff, float grad, float dist, float p) { - if (diff == 0.0 || p < 1.0) { - return 0.f; - } - return (DistSign(diff) * std::pow(std::abs(diff), (p - 1)) * grad / std::pow(dist, (p - 1))); -} - -float CdistTwoNormalcompute(float diff, float grad, float dist, float p) { - return dist == 0.0 ? 0.f : grad * diff / dist; -} - -float CdistInfNormalcompute(float diff, float grad, float dist, float p) { - return grad * DistSign(diff) * (1 - std::min(1.f, std::ceil(std::abs(std::abs(diff) - dist)))); -} - -float CdistPNormalcompute(float diff, float grad, float dist, float p) { - float result; - - if (dist == 0.0) { - result = 0.f; - } else { - result = diff * std::pow(std::abs(diff), (p - 2)) * grad / std::pow(dist, (p - 1)); - } - return result; -} - -void CdistGradCpuKernelMod::InitFunc(float p) { - if (p == 0.0) { - dist_func_ = nullptr; - } else if (p == 1.0) { - dist_func_ = CdistOneNormalcompute; - } else if (p < 2.0) { - dist_func_ = CdistLessTwoNormalcompute; - } else if (p == 2.0) { - dist_func_ = CdistTwoNormalcompute; - } else if (std::isinf(p)) { - dist_func_ = CdistInfNormalcompute; - } else { - dist_func_ = CdistPNormalcompute; - } -} - -bool CdistGradCpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - p_ = GetValue(primitive_->GetAttr(ops::kP)); - auto input_type_id = inputs[0]->dtype_id(); - switch (input_type_id) { - case kNumberTypeFloat32: - InitFunc(p_); - break; - default: - MS_LOG(ERROR) << "cdist grad kernel does not support " << TypeIdToString(input_type_id); - return false; - } - return true; -} - -int CdistGradCpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - auto ret = KernelMod::Resize(inputs, outputs); - if (ret != 0) { - MS_LOG(WARNING) << "For " << kernel_name_ << " Resize failed. ret " << ret; - return ret; - } - std::vector in_shape0 = inputs[1]->GetShapeVector(); - std::vector in_shape1 = inputs[2]->GetShapeVector(); - auto in_shape_size = in_shape0.size(); - if (in_shape1.size() != in_shape_size || in_shape_size < kCdistInputDimsMin) { - MS_LOG(ERROR) << "For " << kernel_name_ << ",invalid input shape, input0 shape size " << in_shape_size - << ", input1 shape size " << in_shape1.size(); - return KRET_RESIZE_FAILED; - } - batch_ = 0; - for (size_t i = 0; i < in_shape_size - kCdistInputDimsMin; i++) { - batch_ += in_shape0[i]; - } - batch_ = (batch_ <= 0) ? 1 : batch_; - - r0_ = in_shape0[in_shape_size - 2]; - m_ = in_shape0[in_shape_size - 1]; - r1_ = in_shape1[in_shape_size - 2]; - - l1_size = r0_ * m_; - l2_size = r1_ * m_; - - return 0; -} - -std::vector CdistGradCpuKernelMod::GetOpSupport() { return kernel_attr; } - -bool CdistGradCpuKernelMod::Launch(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs) { - float *grad_start = reinterpret_cast(inputs[0]->device_ptr()); - float *dist_start = reinterpret_cast(inputs[3]->device_ptr()); - float *t1_start = reinterpret_cast(inputs[1]->device_ptr()); - float *t2_start = reinterpret_cast(inputs[2]->device_ptr()); - float *res_start = reinterpret_cast(outputs[0]->device_ptr()); - auto ret = memset_s(res_start, outputs[0]->size(), 0, outputs[0]->size()); - if (ret != EOK) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memset_s failed, ret=" << ret; - } - if (p_ == 0.0) { - return true; - } - - auto task = [this, grad_start, dist_start, t1_start, t2_start, res_start](size_t b_start, size_t b_end) { - const float *i = t1_start + b_start; - const float *j = t2_start + b_start; - float *res_l = res_start + b_start; - float *res_end = res_start + b_end; - for (; res_l != res_end; i += 1, j += 1, res_l += 1) { - const float *t1 = i; - const float *t2 = j; - float *res = res_l; - const float *t1_end = t1 + l1_size; - const float *t2_end = t2 + l2_size; - auto grad_k = grad_start; - auto dist_k = dist_start; - - for (int64_t l = 0; l < batch_; l++) { - for (; t1 != t1_end; t1 += m_, res += m_) { - float t1_tmp = *t1; - float res_tmp = *res; - - for (const float *t2_curr = t2; t2_curr != t2_end; t2_curr += m_, grad_k += 1, dist_k += 1) { - auto diff = t1_tmp - *t2_curr; - float res_curr = dist_func_(diff, (*grad_k), (*dist_k), p_); - res_tmp = res_tmp + res_curr; - } - - *res = res_tmp; - } - t1_end += l1_size; - t2_end += l2_size; - t2 += l2_size; - } - } - }; - ParallelLaunchAutoSearch(task, m_, this, ¶llel_search_info_, pool_); - - return true; -} - -MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, CdistGrad, CdistGradCpuKernelMod); - -}; // namespace kernel - -} // namespace mindspore +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/cpu/kernel/cdist_grad_cpu_kernel.h" +#include +#include +#include "plugin/device/cpu/kernel/nnacl/op_base.h" +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kCdistInputDimsMin = 2; + +const std::vector kernel_attr = { + {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)}}; +} // namespace + +inline float DistSign(float val) { + return std::min(std::max(0.f, std::ceil(val)), (1.f)) + std::min(std::max((-1.f), std::floor(val)), (0.f)); +} + +float CdistOneNormalcompute(float diff, float grad, float dist, float p) { return grad * DistSign(diff); } + +float CdistLessTwoNormalcompute(float diff, float grad, float dist, float p) { + if (diff == 0.0 || p < 1.0) { + return 0.f; + } + return (DistSign(diff) * std::pow(std::abs(diff), (p - 1)) * grad / std::pow(dist, (p - 1))); +} + +float CdistTwoNormalcompute(float diff, float grad, float dist, float p) { + return dist == 0.0 ? 0.f : grad * diff / dist; +} + +float CdistInfNormalcompute(float diff, float grad, float dist, float p) { + return grad * DistSign(diff) * (1 - std::min(1.f, std::ceil(std::abs(std::abs(diff) - dist)))); +} + +float CdistPNormalcompute(float diff, float grad, float dist, float p) { + float result; + + if (dist == 0.0) { + result = 0.f; + } else { + result = diff * std::pow(std::abs(diff), (p - 2)) * grad / std::pow(dist, (p - 1)); + } + return result; +} + +void CdistGradCpuKernelMod::InitFunc(float p) { + if (p == 0.0) { + dist_func_ = nullptr; + } else if (p == 1.0) { + dist_func_ = CdistOneNormalcompute; + } else if (p < 2.0) { + dist_func_ = CdistLessTwoNormalcompute; + } else if (p == 2.0) { + dist_func_ = CdistTwoNormalcompute; + } else if (std::isinf(p)) { + dist_func_ = CdistInfNormalcompute; + } else { + dist_func_ = CdistPNormalcompute; + } +} + +bool CdistGradCpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + p_ = GetValue(primitive_->GetAttr(ops::kP)); + auto input_type_id = inputs[0]->dtype_id(); + switch (input_type_id) { + case kNumberTypeFloat32: + InitFunc(p_); + break; + default: + MS_LOG(ERROR) << "cdist grad kernel does not support " << TypeIdToString(input_type_id); + return false; + } + return true; +} + +int CdistGradCpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + auto ret = KernelMod::Resize(inputs, outputs); + if (ret != 0) { + MS_LOG(WARNING) << "For " << kernel_name_ << " Resize failed. ret " << ret; + return ret; + } + std::vector in_shape0 = inputs[1]->GetShapeVector(); + std::vector in_shape1 = inputs[2]->GetShapeVector(); + auto in_shape_size = in_shape0.size(); + if (in_shape1.size() != in_shape_size || in_shape_size < kCdistInputDimsMin) { + MS_LOG(ERROR) << "For " << kernel_name_ << ",invalid input shape, input0 shape size " << in_shape_size + << ", input1 shape size " << in_shape1.size(); + return KRET_RESIZE_FAILED; + } + batch_ = 0; + for (size_t i = 0; i < in_shape_size - kCdistInputDimsMin; i++) { + batch_ += in_shape0[i]; + } + batch_ = (batch_ <= 0) ? 1 : batch_; + + r0_ = in_shape0[in_shape_size - 2]; + m_ = in_shape0[in_shape_size - 1]; + r1_ = in_shape1[in_shape_size - 2]; + + l1_size = r0_ * m_; + l2_size = r1_ * m_; + + return 0; +} + +std::vector CdistGradCpuKernelMod::GetOpSupport() { return kernel_attr; } + +bool CdistGradCpuKernelMod::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + float *grad_start = reinterpret_cast(inputs[0]->device_ptr()); + float *dist_start = reinterpret_cast(inputs[3]->device_ptr()); + float *t1_start = reinterpret_cast(inputs[1]->device_ptr()); + float *t2_start = reinterpret_cast(inputs[2]->device_ptr()); + float *res_start = reinterpret_cast(outputs[0]->device_ptr()); + auto ret = memset_s(res_start, outputs[0]->size(), 0, outputs[0]->size()); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memset_s failed, ret=" << ret; + } + if (p_ == 0.0) { + return true; + } + + auto task = [this, grad_start, dist_start, t1_start, t2_start, res_start](size_t b_start, size_t b_end) { + const float *i = t1_start + b_start; + const float *j = t2_start + b_start; + float *res_l = res_start + b_start; + float *res_end = res_start + b_end; + for (; res_l != res_end; i += 1, j += 1, res_l += 1) { + const float *t1 = i; + const float *t2 = j; + float *res = res_l; + const float *t1_end = t1 + l1_size; + const float *t2_end = t2 + l2_size; + auto grad_k = grad_start; + auto dist_k = dist_start; + + for (int64_t l = 0; l < batch_; l++) { + for (; t1 != t1_end; t1 += m_, res += m_) { + float t1_tmp = *t1; + float res_tmp = *res; + + for (const float *t2_curr = t2; t2_curr != t2_end; t2_curr += m_, grad_k += 1, dist_k += 1) { + auto diff = t1_tmp - *t2_curr; + float res_curr = dist_func_(diff, (*grad_k), (*dist_k), p_); + res_tmp = res_tmp + res_curr; + } + + *res = res_tmp; + } + t1_end += l1_size; + t2_end += l2_size; + t2 += l2_size; + } + } + }; + ParallelLaunchAutoSearch(task, m_, this, ¶llel_search_info_, pool_); + + return true; +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, CdistGrad, CdistGradCpuKernelMod); + +}; // namespace kernel + +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/cdist_grad_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/cdist_grad_cpu_kernel.h index c4a6b10b26e..c913116bd01 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/cdist_grad_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/cdist_grad_cpu_kernel.h @@ -1,61 +1,61 @@ -/** - * Copyright 2020-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_CDIST_GRAD_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_CDIST_GRAD_CPU_KERNEL_H_ - -#include -#include -#include -#include -#include "mindspore/core/ops/grad/cdist_grad.h" -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "plugin/factory/ms_factory.h" -namespace mindspore { -namespace kernel { -class CdistGradCpuKernelMod : public NativeCpuKernelMod { - public: - CdistGradCpuKernelMod() {} - ~CdistGradCpuKernelMod() override = default; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - std::vector GetOpSupport() override; - - private: - void InitFunc(float p); - - using DistFunc = std::function; - DistFunc dist_func_; - - int64_t batch_; - int64_t r0_; - int64_t m_; - int64_t r1_; - int64_t l1_size; - int64_t l2_size; - - float p_ = 2; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_CDIST_GRAD_CPU_KERNEL_H_ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_CDIST_GRAD_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_CDIST_GRAD_CPU_KERNEL_H_ + +#include +#include +#include +#include +#include "mindspore/core/ops/grad/cdist_grad.h" +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" +namespace mindspore { +namespace kernel { +class CdistGradCpuKernelMod : public NativeCpuKernelMod { + public: + CdistGradCpuKernelMod() {} + ~CdistGradCpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + std::vector GetOpSupport() override; + + private: + void InitFunc(float p); + + using DistFunc = std::function; + DistFunc dist_func_; + + int64_t batch_; + int64_t r0_; + int64_t m_; + int64_t r1_; + int64_t l1_size; + int64_t l2_size; + + float p_ = 2; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_CDIST_GRAD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/celu_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/celu_cpu_kernel.cc index 5d90b53023f..cd75ffc3e78 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/celu_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/celu_cpu_kernel.cc @@ -1,75 +1,75 @@ -/** - * Copyright 2020-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/cpu/kernel/celu_cpu_kernel.h" -#include "mindspore/core/ops/ops_func_impl/celu.h" -#include -#include -#include "plugin/device/cpu/kernel/nnacl/op_base.h" - -namespace mindspore { -namespace kernel { -namespace { - -const std::vector kernel_attr = {{KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kObjectTypeNumber, kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32)}}; -} // namespace - -bool CeluCpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { - auto input_type_id = inputs[0]->dtype_id(); - if (input_type_id != kNumberTypeFloat32) { - MS_LOG(ERROR) << "celu kernel does not support " << TypeIdToString(input_type_id); - return false; - } - unit_size_ = sizeof(float); - return true; -} - -int CeluCpuKernelMod::Resize(const std::vector &inputs, const std::vector &outputs) { - int ret = KernelMod::Resize(inputs, outputs); - if (ret != 0) { - return ret; - } - - input_elements_ = output_size_list_[0] / unit_size_; - alpha_ = static_cast(inputs[kIndex1]->GetValueWithCheck()); - return KRET_OK; -} - -std::vector CeluCpuKernelMod::GetOpSupport() { return kernel_attr; } - -bool CeluCpuKernelMod::Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) { - auto in_data = static_cast(inputs[0]->device_ptr()); - auto out_data = static_cast(outputs[0]->device_ptr()); - auto task = [this, in_data, out_data](size_t start, size_t end) { - auto src = in_data + start; - auto dst = out_data + start; - auto length = end - start; - for (size_t i = 0; i < length; ++i) { - dst[i] = src[i] > 0 ? src[i] : (expm1(src[i] / alpha_) * alpha_); - } - }; - ParallelLaunchAutoSearch(task, input_elements_, this, ¶llel_search_info_, pool_); - - return true; -} - -MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, CeLU, CeluCpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2020-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/cpu/kernel/celu_cpu_kernel.h" +#include "mindspore/core/ops/ops_func_impl/celu.h" +#include +#include +#include "plugin/device/cpu/kernel/nnacl/op_base.h" + +namespace mindspore { +namespace kernel { +namespace { + +const std::vector kernel_attr = {{KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32)}}; +} // namespace + +bool CeluCpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { + auto input_type_id = inputs[0]->dtype_id(); + if (input_type_id != kNumberTypeFloat32) { + MS_LOG(ERROR) << "celu kernel does not support " << TypeIdToString(input_type_id); + return false; + } + unit_size_ = sizeof(float); + return true; +} + +int CeluCpuKernelMod::Resize(const std::vector &inputs, const std::vector &outputs) { + int ret = KernelMod::Resize(inputs, outputs); + if (ret != 0) { + return ret; + } + + input_elements_ = output_size_list_[0] / unit_size_; + alpha_ = static_cast(inputs[kIndex1]->GetValueWithCheck()); + return KRET_OK; +} + +std::vector CeluCpuKernelMod::GetOpSupport() { return kernel_attr; } + +bool CeluCpuKernelMod::Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) { + auto in_data = static_cast(inputs[0]->device_ptr()); + auto out_data = static_cast(outputs[0]->device_ptr()); + auto task = [this, in_data, out_data](size_t start, size_t end) { + auto src = in_data + start; + auto dst = out_data + start; + auto length = end - start; + for (size_t i = 0; i < length; ++i) { + dst[i] = src[i] > 0 ? src[i] : (expm1(src[i] / alpha_) * alpha_); + } + }; + ParallelLaunchAutoSearch(task, input_elements_, this, ¶llel_search_info_, pool_); + + return true; +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, CeLU, CeluCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/celu_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/celu_cpu_kernel.h index c8a1fbd8be9..0a39512c91c 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/celu_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/celu_cpu_kernel.h @@ -1,51 +1,51 @@ -/** - * Copyright 2020-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_CELU_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_CELU_CPU_KERNEL_H_ - -#include -#include -#include -#include -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore { -namespace kernel { -class CeluCpuKernelMod : public NativeCpuKernelMod { - public: - CeluCpuKernelMod() {} - ~CeluCpuKernelMod() override = default; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - std::vector GetOpSupport() override; - - private: - size_t unit_size_; - size_t input_elements_; - double alpha_{1.0}; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_CDIST_CPU_KERNEL_H_ +/** + * Copyright 2020-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_CELU_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_CELU_CPU_KERNEL_H_ + +#include +#include +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class CeluCpuKernelMod : public NativeCpuKernelMod { + public: + CeluCpuKernelMod() {} + ~CeluCpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + std::vector GetOpSupport() override; + + private: + size_t unit_size_; + size_t input_elements_; + double alpha_{1.0}; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_CDIST_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/cumulative_logsumexp_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/cumulative_logsumexp_cpu_kernel.cc index afaf0c3d48d..1b4c44a300b 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/cumulative_logsumexp_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/cumulative_logsumexp_cpu_kernel.cc @@ -1,183 +1,183 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/cpu/kernel/cumulative_logsumexp_cpu_kernel.h" -#include -#include -#include -#include -#include "plugin/device/cpu/hal/device/cpu_device_address.h" -#include "plugin/device/cpu/kernel/mkldnn/mkl_cpu_kernel.h" -#include "mindspore/core/ops/cumulative_logsumexp.h" -namespace mindspore { -namespace kernel { -namespace { -constexpr size_t kCumulativeLogsumexpInputsNum = 2; -constexpr size_t kCumulativeLogsumexpOutputsNum = 1; -constexpr size_t kAxisDimension = 1; -constexpr size_t kAxisShapeSize = 1; -constexpr size_t kInputIndex0 = 0; -const float float16_exclusive_data = -65504e+0; -const float float_exclusive_data = -3.4028235e+38; -const double double_exclusive_data = -1.7976931348623157e+308; -} // namespace - -bool CumulativeLogsumexpCpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - shape_ = inputs[kInputIndex0]->GetShapeVector(); - dtype_ = inputs[kInputIndex0]->dtype_id(); - exclusive_ = GetValue(primitive_->GetAttr(ops::kExclusive)); - reverse_ = GetValue(primitive_->GetAttr(ops::kReverse)); - return true; -} - -int CumulativeLogsumexpCpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - auto ret = KernelMod::Resize(inputs, outputs); - if (ret != KRET_OK) { - return ret; - } - - shape_ = inputs[kInputIndex0]->GetShapeVector(); - dtype_ = inputs[kInputIndex0]->dtype_id(); - return KRET_OK; -} - -bool CumulativeLogsumexpCpuKernelMod::Launch(const std::vector &inputs, - const std::vector &, - const std::vector &outputs) { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kCumulativeLogsumexpInputsNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kCumulativeLogsumexpOutputsNum, kernel_name_); - if (dtype_ == kNumberTypeFloat64) { - LaunchKernel(inputs, outputs); - } else if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat) { - LaunchKernel(inputs, outputs); - } else if (dtype_ == kNumberTypeFloat16) { - LaunchKernel(inputs, outputs); - } else { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', kernel data type " << TypeIdLabel(dtype_) << "not support."; - } - return true; -} - -template -void CumulativeLogsumexpCpuKernelMod::CumulativeProcess(const t *input_data, t *output_data, const uint32_t outer, - const uint32_t inner, const uint32_t depth) const { - for (size_t outer_index = 0; outer_index < outer; ++outer_index) { - size_t outer_index_adj; - if (reverse_) { - outer_index_adj = (outer - 1) - outer_index; - } else { - outer_index_adj = outer_index; - } - for (size_t inner_index = 0; inner_index < inner; ++inner_index) { - double one = 1; - double temp = 0; - size_t inner_index_adj; - if (reverse_) { - inner_index_adj = (inner - 1) - inner_index; - } else { - inner_index_adj = inner_index; - } - for (size_t depth_index = 0; depth_index < depth; ++depth_index) { - size_t depth_index_adj; - if (reverse_) { - depth_index_adj = (depth - 1) - depth_index; - } else { - depth_index_adj = depth_index; - } - size_t index = outer_index_adj; - index += inner_index_adj * depth * outer; - index += depth_index_adj * outer; - if (exclusive_) { - if (depth_index == 0) { - if (dtype_ == kNumberTypeFloat16) { - output_data[index] = static_cast(float16_exclusive_data); - } else if (dtype_ == kNumberTypeFloat32) { - output_data[index] = static_cast(float_exclusive_data); - } else { - output_data[index] = static_cast(double_exclusive_data); - } - temp = static_cast(input_data[index]); - } else { - output_data[index] = static_cast(temp); - double a = temp; - double b, min, max; - b = static_cast(input_data[index]); - min = (a < b) ? a : b; - max = (a >= b) ? a : b; - temp = log(one + exp(min - max)) + max; - } - } else { - if (depth_index == 0) { - output_data[index] = input_data[index]; - temp = static_cast(input_data[index]); - } else { - double a = temp; - double b, min, max; - b = static_cast(input_data[index]); - min = (a < b) ? a : b; - max = (a >= b) ? a : b; - output_data[index] = static_cast(log(one + exp(min - max)) + max); - temp = log(one + exp(min - max)) + max; - } - } - } - } - } -} - -template -void CumulativeLogsumexpCpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &outputs) { - auto *input_data = static_cast(inputs[kIndex0]->device_ptr()); - auto axis_ = static_cast(inputs[kIndex1]->device_ptr()); - auto *output_data = static_cast(outputs[kIndex0]->device_ptr()); - size_t lens = inputs[kIndex0]->size() > 0 ? static_cast(inputs[kIndex0]->size() / sizeof(T)) : 1; - auto task = [this, input_data, axis_, output_data](const size_t start, const size_t end) { - int32_t x_rank = SizeToInt(shape_.size()); - if (axis_[0] >= x_rank || axis_[0] < -x_rank) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << ", 'axis' must be in range [" << -x_rank << ", " << x_rank - << "), but got: " << axis_[0]; - } - if (axis_[0] < 0) { - axis_[0] += x_rank; - } - uint32_t inner = 1; - uint32_t depth = static_cast(shape_[IntToSize(axis_[0])]); - uint32_t outer = 1; - for (size_t i = 0; i < IntToSize(axis_[0]); i++) { - inner *= static_cast(shape_[i]); - } - for (size_t i = IntToSize(axis_[0]) + 1; i < shape_.size(); i++) { - outer *= static_cast(shape_[i]); - } - CumulativeProcess(input_data, output_data, outer, inner, depth); - }; - ParallelLaunchAutoSearch(task, lens, this, ¶llel_search_info_); -} - -std::vector CumulativeLogsumexpCpuKernelMod::GetOpSupport() { - std::vector support_list = { - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), - KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64)}; - return support_list; -} - -MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, CumulativeLogsumexp, CumulativeLogsumexpCpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/cpu/kernel/cumulative_logsumexp_cpu_kernel.h" +#include +#include +#include +#include +#include "plugin/device/cpu/hal/device/cpu_device_address.h" +#include "plugin/device/cpu/kernel/mkldnn/mkl_cpu_kernel.h" +#include "mindspore/core/ops/cumulative_logsumexp.h" +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kCumulativeLogsumexpInputsNum = 2; +constexpr size_t kCumulativeLogsumexpOutputsNum = 1; +constexpr size_t kAxisDimension = 1; +constexpr size_t kAxisShapeSize = 1; +constexpr size_t kInputIndex0 = 0; +const float float16_exclusive_data = -65504e+0; +const float float_exclusive_data = -3.4028235e+38; +const double double_exclusive_data = -1.7976931348623157e+308; +} // namespace + +bool CumulativeLogsumexpCpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + shape_ = inputs[kInputIndex0]->GetShapeVector(); + dtype_ = inputs[kInputIndex0]->dtype_id(); + exclusive_ = GetValue(primitive_->GetAttr(ops::kExclusive)); + reverse_ = GetValue(primitive_->GetAttr(ops::kReverse)); + return true; +} + +int CumulativeLogsumexpCpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + auto ret = KernelMod::Resize(inputs, outputs); + if (ret != KRET_OK) { + return ret; + } + + shape_ = inputs[kInputIndex0]->GetShapeVector(); + dtype_ = inputs[kInputIndex0]->dtype_id(); + return KRET_OK; +} + +bool CumulativeLogsumexpCpuKernelMod::Launch(const std::vector &inputs, + const std::vector &, + const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kCumulativeLogsumexpInputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kCumulativeLogsumexpOutputsNum, kernel_name_); + if (dtype_ == kNumberTypeFloat64) { + LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat) { + LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeFloat16) { + LaunchKernel(inputs, outputs); + } else { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', kernel data type " << TypeIdLabel(dtype_) << "not support."; + } + return true; +} + +template +void CumulativeLogsumexpCpuKernelMod::CumulativeProcess(const t *input_data, t *output_data, const uint32_t outer, + const uint32_t inner, const uint32_t depth) const { + for (size_t outer_index = 0; outer_index < outer; ++outer_index) { + size_t outer_index_adj; + if (reverse_) { + outer_index_adj = (outer - 1) - outer_index; + } else { + outer_index_adj = outer_index; + } + for (size_t inner_index = 0; inner_index < inner; ++inner_index) { + double one = 1; + double temp = 0; + size_t inner_index_adj; + if (reverse_) { + inner_index_adj = (inner - 1) - inner_index; + } else { + inner_index_adj = inner_index; + } + for (size_t depth_index = 0; depth_index < depth; ++depth_index) { + size_t depth_index_adj; + if (reverse_) { + depth_index_adj = (depth - 1) - depth_index; + } else { + depth_index_adj = depth_index; + } + size_t index = outer_index_adj; + index += inner_index_adj * depth * outer; + index += depth_index_adj * outer; + if (exclusive_) { + if (depth_index == 0) { + if (dtype_ == kNumberTypeFloat16) { + output_data[index] = static_cast(float16_exclusive_data); + } else if (dtype_ == kNumberTypeFloat32) { + output_data[index] = static_cast(float_exclusive_data); + } else { + output_data[index] = static_cast(double_exclusive_data); + } + temp = static_cast(input_data[index]); + } else { + output_data[index] = static_cast(temp); + double a = temp; + double b, min, max; + b = static_cast(input_data[index]); + min = (a < b) ? a : b; + max = (a >= b) ? a : b; + temp = log(one + exp(min - max)) + max; + } + } else { + if (depth_index == 0) { + output_data[index] = input_data[index]; + temp = static_cast(input_data[index]); + } else { + double a = temp; + double b, min, max; + b = static_cast(input_data[index]); + min = (a < b) ? a : b; + max = (a >= b) ? a : b; + output_data[index] = static_cast(log(one + exp(min - max)) + max); + temp = log(one + exp(min - max)) + max; + } + } + } + } + } +} + +template +void CumulativeLogsumexpCpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { + auto *input_data = static_cast(inputs[kIndex0]->device_ptr()); + auto axis_ = static_cast(inputs[kIndex1]->device_ptr()); + auto *output_data = static_cast(outputs[kIndex0]->device_ptr()); + size_t lens = inputs[kIndex0]->size() > 0 ? static_cast(inputs[kIndex0]->size() / sizeof(T)) : 1; + auto task = [this, input_data, axis_, output_data](const size_t start, const size_t end) { + int32_t x_rank = SizeToInt(shape_.size()); + if (axis_[0] >= x_rank || axis_[0] < -x_rank) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << ", 'axis' must be in range [" << -x_rank << ", " << x_rank + << "), but got: " << axis_[0]; + } + if (axis_[0] < 0) { + axis_[0] += x_rank; + } + uint32_t inner = 1; + uint32_t depth = static_cast(shape_[IntToSize(axis_[0])]); + uint32_t outer = 1; + for (size_t i = 0; i < IntToSize(axis_[0]); i++) { + inner *= static_cast(shape_[i]); + } + for (size_t i = IntToSize(axis_[0]) + 1; i < shape_.size(); i++) { + outer *= static_cast(shape_[i]); + } + CumulativeProcess(input_data, output_data, outer, inner, depth); + }; + ParallelLaunchAutoSearch(task, lens, this, ¶llel_search_info_); +} + +std::vector CumulativeLogsumexpCpuKernelMod::GetOpSupport() { + std::vector support_list = { + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64)}; + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, CumulativeLogsumexp, CumulativeLogsumexpCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/cumulative_logsumexp_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/cumulative_logsumexp_cpu_kernel.h index 45da6da5309..5bdba9c850b 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/cumulative_logsumexp_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/cumulative_logsumexp_cpu_kernel.h @@ -1,58 +1,58 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CUMULATIVE_LOGSUMEXP_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CUMULATIVE_LOGSUMEXP_CPU_KERNEL_H_ - -#include -#include -#include -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore { -namespace kernel { -class CumulativeLogsumexpCpuKernelMod : public NativeCpuKernelMod { - public: - CumulativeLogsumexpCpuKernelMod() = default; - ~CumulativeLogsumexpCpuKernelMod() override = default; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - template - void CumulativeProcess(const t *input_data, t *output_data, const uint32_t outer, const uint32_t inner, - const uint32_t depth) const; - - template - void LaunchKernel(const std::vector &inputs, const std::vector &outputs); - - protected: - std::vector GetOpSupport() override; - - private: - ShapeVector shape_; - bool exclusive_{false}; - bool reverse_{false}; - TypeId dtype_{kTypeUnknown}; -}; -} // namespace kernel -} // namespace mindspore -#endif +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CUMULATIVE_LOGSUMEXP_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CUMULATIVE_LOGSUMEXP_CPU_KERNEL_H_ + +#include +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class CumulativeLogsumexpCpuKernelMod : public NativeCpuKernelMod { + public: + CumulativeLogsumexpCpuKernelMod() = default; + ~CumulativeLogsumexpCpuKernelMod() override = default; + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + template + void CumulativeProcess(const t *input_data, t *output_data, const uint32_t outer, const uint32_t inner, + const uint32_t depth) const; + + template + void LaunchKernel(const std::vector &inputs, const std::vector &outputs); + + protected: + std::vector GetOpSupport() override; + + private: + ShapeVector shape_; + bool exclusive_{false}; + bool reverse_{false}; + TypeId dtype_{kTypeUnknown}; +}; +} // namespace kernel +} // namespace mindspore +#endif diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/diag_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/diag_cpu_kernel.cc index c1d14cc4bdf..e11bae74463 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/diag_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/diag_cpu_kernel.cc @@ -1,86 +1,86 @@ -/** - * Copyright 2022-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/cpu/kernel/diag_cpu_kernel.h" -#include -#include -#include - -namespace mindspore { -namespace kernel { -namespace { -constexpr size_t kDiagInputsNum = 1; -constexpr size_t kDiagOutputsNum = 1; -} // namespace - -bool DiagCpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kDiagInputsNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kDiagOutputsNum, kernel_name_); - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(ERROR) << "Diag does not support this kernel data type: " << kernel_attr; - return false; - } - kernel_func_ = func_list_[index].second; - return true; -} - -template -bool DiagCpuKernelMod::LaunchKernel(const std::vector &inputs, const std::vector &, - const std::vector &outputs) { - auto aptr = static_cast(inputs[0]->device_ptr()); - auto xptr = static_cast(outputs[0]->device_ptr()); - - int64_t data_num = static_cast(inputs[0]->size() / sizeof(T)); - - auto task = [&xptr, &aptr, &data_num](int64_t start, int64_t end) { - std::fill(xptr + data_num * start, xptr + data_num * end, T()); - for (int64_t index = start; index < end; index++) { - *(xptr + (1 + data_num) * index) = *(aptr + index); - } - }; - ParallelLaunchAutoSearch(task, data_num, this, ¶llel_search_info_); - - return true; -} - -std::vector> DiagCpuKernelMod::func_list_ = { - {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - &DiagCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), - &DiagCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - &DiagCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - &DiagCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), - &DiagCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), - &DiagCpuKernelMod::LaunchKernel>}, - {KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), - &DiagCpuKernelMod::LaunchKernel>}}; - -std::vector DiagCpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - return support_list; -} - -MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Diag, DiagCpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/cpu/kernel/diag_cpu_kernel.h" +#include +#include +#include + +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kDiagInputsNum = 1; +constexpr size_t kDiagOutputsNum = 1; +} // namespace + +bool DiagCpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kDiagInputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kDiagOutputsNum, kernel_name_); + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "Diag does not support this kernel data type: " << kernel_attr; + return false; + } + kernel_func_ = func_list_[index].second; + return true; +} + +template +bool DiagCpuKernelMod::LaunchKernel(const std::vector &inputs, const std::vector &, + const std::vector &outputs) { + auto aptr = static_cast(inputs[0]->device_ptr()); + auto xptr = static_cast(outputs[0]->device_ptr()); + + int64_t data_num = static_cast(inputs[0]->size() / sizeof(T)); + + auto task = [&xptr, &aptr, &data_num](int64_t start, int64_t end) { + std::fill(xptr + data_num * start, xptr + data_num * end, T()); + for (int64_t index = start; index < end; index++) { + *(xptr + (1 + data_num) * index) = *(aptr + index); + } + }; + ParallelLaunchAutoSearch(task, data_num, this, ¶llel_search_info_); + + return true; +} + +std::vector> DiagCpuKernelMod::func_list_ = { + {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + &DiagCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + &DiagCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + &DiagCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + &DiagCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + &DiagCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), + &DiagCpuKernelMod::LaunchKernel>}, + {KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), + &DiagCpuKernelMod::LaunchKernel>}}; + +std::vector DiagCpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Diag, DiagCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/diag_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/diag_cpu_kernel.h index 2976bcb56c4..f6e627a7ed0 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/diag_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/diag_cpu_kernel.h @@ -1,57 +1,57 @@ -/** - * Copyright 2022-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DIAG_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DIAG_CPU_KERNEL_H_ - -#include -#include -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore { -namespace kernel { -class DiagCpuKernelMod : public NativeCpuKernelMod { - public: - DiagCpuKernelMod() = default; - ~DiagCpuKernelMod() override = default; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override { - return kernel_func_(this, inputs, workspace, outputs); - } - - protected: - std::vector GetOpSupport() override; - - private: - template - bool LaunchKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs); - - using DiagFunc = - std::function &, - const std::vector &, const std::vector &)>; - static std::vector> func_list_; - DiagFunc kernel_func_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DIAG_CPU_KERNEL_H_ +/** + * Copyright 2022-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DIAG_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DIAG_CPU_KERNEL_H_ + +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class DiagCpuKernelMod : public NativeCpuKernelMod { + public: + DiagCpuKernelMod() = default; + ~DiagCpuKernelMod() override = default; + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override { + return kernel_func_(this, inputs, workspace, outputs); + } + + protected: + std::vector GetOpSupport() override; + + private: + template + bool LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs); + + using DiagFunc = + std::function &, + const std::vector &, const std::vector &)>; + static std::vector> func_list_; + DiagFunc kernel_func_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DIAG_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/eigen/expand_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/eigen/expand_cpu_kernel.cc index 6a026d5a44c..105fb55a275 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/eigen/expand_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/eigen/expand_cpu_kernel.cc @@ -1,190 +1,190 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/cpu/kernel/eigen/expand_cpu_kernel.h" -#include -#include "unsupported/Eigen/CXX11/Tensor" -#include "plugin/device/cpu/hal/device/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -namespace { -const size_t kExpandInputsNum = 2; -const size_t kExpandOutputsNum = 1; -const size_t kNoBroadcastValue = 1; -const size_t kRank0 = 0; -const size_t kRank1 = 1; -const size_t kRank2 = 2; -const size_t kRank3 = 3; -const size_t kRank4 = 4; -const size_t kRank5 = 5; -const size_t kRank6 = 6; -const size_t kRank7 = 7; -const size_t kRank8 = 8; -} // namespace - -bool ExpandCpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kExpandInputsNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kExpandOutputsNum, kernel_name_); - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto is_match = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match.first) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr; - return false; - } - return true; -} - -int ExpandCpuKernelMod::Resize(const std::vector &inputs, const std::vector &outputs) { - if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { - return ret; - } - input_x_shape_ = LongVecToSizeVec(inputs[kIndex0]->GetDeviceShapeVector()); - input_x_dtype_ = inputs[kIndex0]->dtype_id(); - input_shape_ = LongVecToSizeVec(outputs[kIndex0]->GetDeviceShapeVector()); - output_y_shape_ = LongVecToSizeVec(outputs[kIndex0]->GetDeviceShapeVector()); - return KRET_OK; -} - -bool ExpandCpuKernelMod::Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs) { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kExpandInputsNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kExpandOutputsNum, kernel_name_); - switch (input_x_dtype_) { - case kNumberTypeFloat16: - return ExpandCompute(inputs, outputs); - case kNumberTypeFloat32: - return ExpandCompute(inputs, outputs); - case kNumberTypeInt8: - return ExpandCompute(inputs, outputs); - case kNumberTypeInt32: - return ExpandCompute(inputs, outputs); - case kNumberTypeUInt8: - return ExpandCompute(inputs, outputs); - default: - MS_LOG(EXCEPTION) << "For " << kernel_name_ - << ", the dtype of input `x` must in [float16, float32, int8, int32, uint8] " - << "but got " << TypeIdToType(input_x_dtype_)->ToString() << "."; - } -} - -size_t ExpandCpuKernelMod::get_element_num(const std::vector &shape) const { - size_t size = 1; - for (size_t i = 0; i < shape.size(); i++) { - size *= shape[i]; - } - return size; -} - -template -bool ExpandCpuKernelMod::ExpandCompute(const std::vector &inputs, - const std::vector &outputs) { - size_t rank = static_cast(output_y_shape_.size()); - switch (rank) { - case kRank0: { - T v0 = *(reinterpret_cast(inputs[0]->device_ptr())); - T *value_out = reinterpret_cast(outputs[0]->device_ptr()); - *(value_out) = v0; - return true; - } - case kRank1: - return ExpandCalculate(inputs, outputs); - case kRank2: - return ExpandCalculate(inputs, outputs); - case kRank3: - return ExpandCalculate(inputs, outputs); - case kRank4: - return ExpandCalculate(inputs, outputs); - case kRank5: - return ExpandCalculate(inputs, outputs); - case kRank6: - return ExpandCalculate(inputs, outputs); - case kRank7: - return ExpandCalculate(inputs, outputs); - case kRank8: - return ExpandCalculate(inputs, outputs); - default: - MS_LOG(EXCEPTION) << "For " << kernel_name_ << ", the rank of output should not expand than 8 but got " - << std::to_string(rank) << "."; - return false; - } -} - -template -bool ExpandCpuKernelMod::ExpandCalculate(const std::vector &inputs, - const std::vector &outputs) { - size_t input_x_element_num = get_element_num(input_x_shape_); - size_t output_y_element_num = get_element_num(output_y_shape_); - - (void)input_x_shape_.insert(input_x_shape_.begin(), RANK - input_x_shape_.size(), 1); - input_x_bcast_.clear(); - input_x_bcast_.resize(RANK, kNoBroadcastValue); - for (size_t i = 0; i < RANK; i++) { - if (input_x_shape_[i] == input_shape_[i]) { - continue; - } - if (input_x_shape_[i] == kNoBroadcastValue) { - input_x_bcast_[i] = input_shape_[i]; - } else { - MS_LOG(EXCEPTION) << "For " << kernel_name_ << ", broadcast not support, dim_x[" << std::to_string(i) - << "]=" << std::to_string(input_x_shape_[i]) << ", dim_y[" << std::to_string(i) - << "]=" << std::to_string(input_shape_[i]) << "."; - return false; - } - } - - Eigen::TensorMap, Eigen::Aligned> input_x(static_cast(inputs[0]->device_ptr()), - input_x_element_num); - Eigen::TensorMap, Eigen::Aligned> output_y(static_cast(outputs[0]->device_ptr()), - output_y_element_num); - - Eigen::DSizes input_reshape; - Eigen::DSizes output_shape; - Eigen::array bcast; - - for (size_t i = 0; i < RANK; i++) { - input_reshape[RANK - i - 1] = static_cast(input_x_shape_[i]); - output_shape[RANK - i - 1] = static_cast(output_y_shape_[i]); - bcast[RANK - i - 1] = static_cast(input_x_bcast_[i]); - } - - output_y.reshape(output_shape) = input_x.reshape(input_reshape).broadcast(bcast); - return true; -} - -std::vector ExpandCpuKernelMod::GetOpSupport() { - static const std::vector kernel_attr_list = { - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat16), - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat32), - KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt8), - KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32), - KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt8), - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), - KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8), - KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8), - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), - KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8), - KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32)}; - - return kernel_attr_list; -} - -MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Expand, ExpandCpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/cpu/kernel/eigen/expand_cpu_kernel.h" +#include +#include "unsupported/Eigen/CXX11/Tensor" +#include "plugin/device/cpu/hal/device/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +namespace { +const size_t kExpandInputsNum = 2; +const size_t kExpandOutputsNum = 1; +const size_t kNoBroadcastValue = 1; +const size_t kRank0 = 0; +const size_t kRank1 = 1; +const size_t kRank2 = 2; +const size_t kRank3 = 3; +const size_t kRank4 = 4; +const size_t kRank5 = 5; +const size_t kRank6 = 6; +const size_t kRank7 = 7; +const size_t kRank8 = 8; +} // namespace + +bool ExpandCpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kExpandInputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kExpandOutputsNum, kernel_name_); + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto is_match = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match.first) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr; + return false; + } + return true; +} + +int ExpandCpuKernelMod::Resize(const std::vector &inputs, const std::vector &outputs) { + if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { + return ret; + } + input_x_shape_ = LongVecToSizeVec(inputs[kIndex0]->GetDeviceShapeVector()); + input_x_dtype_ = inputs[kIndex0]->dtype_id(); + input_shape_ = LongVecToSizeVec(outputs[kIndex0]->GetDeviceShapeVector()); + output_y_shape_ = LongVecToSizeVec(outputs[kIndex0]->GetDeviceShapeVector()); + return KRET_OK; +} + +bool ExpandCpuKernelMod::Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kExpandInputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kExpandOutputsNum, kernel_name_); + switch (input_x_dtype_) { + case kNumberTypeFloat16: + return ExpandCompute(inputs, outputs); + case kNumberTypeFloat32: + return ExpandCompute(inputs, outputs); + case kNumberTypeInt8: + return ExpandCompute(inputs, outputs); + case kNumberTypeInt32: + return ExpandCompute(inputs, outputs); + case kNumberTypeUInt8: + return ExpandCompute(inputs, outputs); + default: + MS_LOG(EXCEPTION) << "For " << kernel_name_ + << ", the dtype of input `x` must in [float16, float32, int8, int32, uint8] " + << "but got " << TypeIdToType(input_x_dtype_)->ToString() << "."; + } +} + +size_t ExpandCpuKernelMod::get_element_num(const std::vector &shape) const { + size_t size = 1; + for (size_t i = 0; i < shape.size(); i++) { + size *= shape[i]; + } + return size; +} + +template +bool ExpandCpuKernelMod::ExpandCompute(const std::vector &inputs, + const std::vector &outputs) { + size_t rank = static_cast(output_y_shape_.size()); + switch (rank) { + case kRank0: { + T v0 = *(reinterpret_cast(inputs[0]->device_ptr())); + T *value_out = reinterpret_cast(outputs[0]->device_ptr()); + *(value_out) = v0; + return true; + } + case kRank1: + return ExpandCalculate(inputs, outputs); + case kRank2: + return ExpandCalculate(inputs, outputs); + case kRank3: + return ExpandCalculate(inputs, outputs); + case kRank4: + return ExpandCalculate(inputs, outputs); + case kRank5: + return ExpandCalculate(inputs, outputs); + case kRank6: + return ExpandCalculate(inputs, outputs); + case kRank7: + return ExpandCalculate(inputs, outputs); + case kRank8: + return ExpandCalculate(inputs, outputs); + default: + MS_LOG(EXCEPTION) << "For " << kernel_name_ << ", the rank of output should not expand than 8 but got " + << std::to_string(rank) << "."; + return false; + } +} + +template +bool ExpandCpuKernelMod::ExpandCalculate(const std::vector &inputs, + const std::vector &outputs) { + size_t input_x_element_num = get_element_num(input_x_shape_); + size_t output_y_element_num = get_element_num(output_y_shape_); + + (void)input_x_shape_.insert(input_x_shape_.begin(), RANK - input_x_shape_.size(), 1); + input_x_bcast_.clear(); + input_x_bcast_.resize(RANK, kNoBroadcastValue); + for (size_t i = 0; i < RANK; i++) { + if (input_x_shape_[i] == input_shape_[i]) { + continue; + } + if (input_x_shape_[i] == kNoBroadcastValue) { + input_x_bcast_[i] = input_shape_[i]; + } else { + MS_LOG(EXCEPTION) << "For " << kernel_name_ << ", broadcast not support, dim_x[" << std::to_string(i) + << "]=" << std::to_string(input_x_shape_[i]) << ", dim_y[" << std::to_string(i) + << "]=" << std::to_string(input_shape_[i]) << "."; + return false; + } + } + + Eigen::TensorMap, Eigen::Aligned> input_x(static_cast(inputs[0]->device_ptr()), + input_x_element_num); + Eigen::TensorMap, Eigen::Aligned> output_y(static_cast(outputs[0]->device_ptr()), + output_y_element_num); + + Eigen::DSizes input_reshape; + Eigen::DSizes output_shape; + Eigen::array bcast; + + for (size_t i = 0; i < RANK; i++) { + input_reshape[RANK - i - 1] = static_cast(input_x_shape_[i]); + output_shape[RANK - i - 1] = static_cast(output_y_shape_[i]); + bcast[RANK - i - 1] = static_cast(input_x_bcast_[i]); + } + + output_y.reshape(output_shape) = input_x.reshape(input_reshape).broadcast(bcast); + return true; +} + +std::vector ExpandCpuKernelMod::GetOpSupport() { + static const std::vector kernel_attr_list = { + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat16), + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat32), + KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt8), + KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32), + KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt8), + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8), + KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8), + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), + KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8), + KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32)}; + + return kernel_attr_list; +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Expand, ExpandCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/eigen/expand_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/eigen/expand_cpu_kernel.h index a76cc80c84e..105198e7574 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/eigen/expand_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/eigen/expand_cpu_kernel.h @@ -1,62 +1,62 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EXPAND_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EXPAND_CPU_KERNEL_H_ - -#include -#include -#include -#include -#include -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore { -namespace kernel { -class ExpandCpuKernelMod : public NativeCpuKernelMod { - public: - ExpandCpuKernelMod() = default; - ~ExpandCpuKernelMod() override = default; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - size_t get_element_num(const std::vector &shape) const; - - template - bool ExpandCompute(const std::vector &inputs, const std::vector &outputs); - - template - bool ExpandCalculate(const std::vector &inputs, const std::vector &outputs); - - std::vector GetOpSupport() override; - - private: - std::vector input_x_shape_; - TypeId input_x_dtype_{kNumberTypeFloat32}; - std::vector input_shape_; - std::vector output_y_shape_; - std::vector input_x_bcast_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EXPAND_CPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EXPAND_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EXPAND_CPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class ExpandCpuKernelMod : public NativeCpuKernelMod { + public: + ExpandCpuKernelMod() = default; + ~ExpandCpuKernelMod() override = default; + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + size_t get_element_num(const std::vector &shape) const; + + template + bool ExpandCompute(const std::vector &inputs, const std::vector &outputs); + + template + bool ExpandCalculate(const std::vector &inputs, const std::vector &outputs); + + std::vector GetOpSupport() override; + + private: + std::vector input_x_shape_; + TypeId input_x_dtype_{kNumberTypeFloat32}; + std::vector input_shape_; + std::vector output_y_shape_; + std::vector input_x_bcast_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EXPAND_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/fft_with_size_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/fft_with_size_cpu_kernel.cc index 75074e408db..8720f1dd853 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/fft_with_size_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/fft_with_size_cpu_kernel.cc @@ -1,313 +1,313 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "plugin/device/cpu/kernel/fft_with_size_cpu_kernel.h" -#include -#include "ops/op_utils.h" -#include "kernel/kernel.h" - -#define FFTWITHSIZE_SWITCH_DIM_CALCULATE(T1, T2, real, inverse) \ - if (signal_ndim_ == 1) { \ - FFTWithSizeCompute(p_x, p_y, onesided_, normalized_, checked_signal_size, x_shape_); \ - } else if (signal_ndim_ == 2) { \ - FFTWithSizeCompute(p_x, p_y, onesided_, normalized_, checked_signal_size, x_shape_); \ - } else { \ - FFTWithSizeCompute(p_x, p_y, onesided_, normalized_, checked_signal_size, x_shape_); \ - } -using std::vector; -namespace mindspore { -namespace kernel { -namespace { -constexpr int kDimNum_FFT = 1; -constexpr int kDimNum_IFFT = 2; -constexpr int kDimNum_RFFT = 3; -constexpr int kDimNum_IRFFT = 4; -constexpr int kRealFFTSideNum = 2; - -int64_t FFTWithSize_choose(bool real, bool inverse) { - if (!real) { - if (!inverse) { - return kDimNum_FFT; // fftz - } else { - return kDimNum_IFFT; // ifft - } - } else { - if (!inverse) { - return kDimNum_RFFT; // rfft - } else { - return kDimNum_IRFFT; // irfft - } - } -} - -int64_t get_element_num(const std::vector &shape, size_t rank) { - size_t back_itr = shape.size(); - int64_t size = 1; - for (size_t i = 1; i <= rank; i++) { - auto dim = shape[back_itr - i]; - MS_EXCEPTION_IF_CHECK_FAIL(dim > 0, "The element in shape must be positive."); - size *= dim; - } - return size; -} - -template -void change_axes(Eigen::array *axes) { - for (unsigned i = from; i <= (unsigned)to; i++) { - axes->operator[](i - 1) = i; - } - return; -} -} // namespace - -bool FFTWithSizeCpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(EXCEPTION) << kernel_name_ << " valid cpu kernel does not support this kernel data type: " << kernel_attr; - } - kernel_func_ = func_list_[index].second; - return true; -} - -int FFTWithSizeCpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { - return ret; - } - - x_shape_ = inputs[kIndex0]->GetShapeVector(); - signal_ndim_ = inputs[kIndex1]->GetValueWithCheck(); - inverse_ = inputs[kIndex2]->GetValueWithCheck(); - real_ = inputs[kIndex3]->GetValueWithCheck(); - normalized_ = inputs[kIndex4]->GetValueWithCheck(); - onesided_ = inputs[kIndex5]->GetValueWithCheck(); - raw_checked_signal_size_ = inputs[kIndex6]->GetValueWithCheck>(); - - return KRET_OK; -} - -double Getnormalized(int64_t element_num, const std::string &normalized, bool is_reverse) { - double result = 1.0; - if (!is_reverse) { - if (normalized == "forward") result = 1.0 / element_num; - if (normalized == "backward") result = 1.0; - if (normalized == "ortho") result = 1.0 / sqrt(static_cast(element_num)); - } - if (is_reverse) { - if (normalized == "forward") result = 1.0 * element_num; - if (normalized == "backward") result = 1.0; - if (normalized == "ortho") result = 1.0 * sqrt(static_cast(element_num)); - } - return result; -} - -template -inline Eigen::DSizes GetFlatShape(const std::vector &x_shape, - size_t x_dims) { - Eigen::DSizes tensor_shape; - if (x_dims == signal_ndim) { - tensor_shape[0] = 1; - for (size_t i = 0; i < x_dims; i++) { - tensor_shape[i + 1] = x_shape[i]; - } - } else if (x_dims == signal_ndim + 1) { - for (size_t i = 0; i < x_dims; i++) { - tensor_shape[i] = x_shape[i]; - } - } else if (x_dims > signal_ndim + 1) { - tensor_shape[0] = 1; - for (size_t i = 0; i < x_dims - signal_ndim; i++) { - tensor_shape[0] *= x_shape[i]; - } - for (size_t j = x_dims - static_cast(signal_ndim), i = 1; j < x_dims; j++, i++) { - tensor_shape[i] = x_shape[j]; - } - } else { - MS_LOG(EXCEPTION) << "x_dims must not be less than signal_ndim."; - } - return tensor_shape; -} - -template -bool FFTWithSizeCompute(T1 *input_x, T2 *output_y, bool onesided, std::string normalized, - const vector &checked_signal_size, const vector &x_shape) { - Eigen::DSizes tensor_shape = GetFlatShape(x_shape, x_shape.size()); - Eigen::TensorMap, Eigen::RowMajor> in(&input_x[0], tensor_shape); - Eigen::array axes; - change_axes(&axes); - Eigen::Tensor out; - vector norm_shape(x_shape); - if constexpr (is_real) { // for rfft and irfft COMPILE TIME EXPANSION - if constexpr (is_inverse) { // irfft - Eigen::Tensor complex_out; - if (onesided) { - // compute the full fft tensor shape: full_fft_shape[-1] / 2 + 1 - Eigen::DSizes temp_tensor_shape(tensor_shape); - if (checked_signal_size.empty()) { - if (temp_tensor_shape[signal_ndim] == 1) { - MS_EXCEPTION(ValueError) << "For 'FFTWithSize', the last dimension of the input cannot be 1, but got: " - << temp_tensor_shape[signal_ndim]; - } - temp_tensor_shape[signal_ndim] = (temp_tensor_shape[signal_ndim] - 1) * kRealFFTSideNum; - } else { - if (checked_signal_size.back() / kRealFFTSideNum + 1 == temp_tensor_shape[signal_ndim]) { - temp_tensor_shape[static_cast(signal_ndim)] = checked_signal_size.back(); - } - } - if (temp_tensor_shape.back() == tensor_shape.back()) { - // fake there is no need to reconstruct signal tensor - complex_out = in.template fft(axes); - } else { - // Reconstruct the full fft tensor: temp_tensor - Eigen::Tensor temp_tensor(temp_tensor_shape); - temp_tensor.setZero(); - Eigen::DSizes zero_offsets; - Eigen::DSizes input_slice_sizes(in.dimensions()); - temp_tensor.slice(zero_offsets, input_slice_sizes) = in; - // do ifft at outer axes - if (signal_ndim > 1) { - Eigen::array outer_axes; - change_axes(&outer_axes); - temp_tensor = temp_tensor.template fft(outer_axes); - } - // rebuild the last axis with symmetrical data - Eigen::array reverse_last_axis; - for (auto i = 0; i <= signal_ndim; i++) { - reverse_last_axis[i] = i == signal_ndim; - } - auto reverse_size = input_slice_sizes; - reverse_size[signal_ndim] = temp_tensor_shape[signal_ndim] - input_slice_sizes[signal_ndim]; - Eigen::DSizes reverse_start_indices; - reverse_start_indices[signal_ndim] = 1; - Eigen::DSizes reverse_target_indices; - reverse_target_indices[signal_ndim] = input_slice_sizes[signal_ndim]; - temp_tensor.slice(reverse_target_indices, reverse_size) = - temp_tensor.slice(reverse_start_indices, reverse_size).reverse(reverse_last_axis).conjugate(); - // do irfft at the last axis: - auto inner_axis = Eigen::array{signal_ndim}; - complex_out = temp_tensor.template fft(inner_axis); - } - norm_shape.back() = static_cast(temp_tensor_shape.back()); - } else { - complex_out = in.template fft(axes); - } - out.resize(complex_out.dimensions()); - T1 *complex_out_ptr = complex_out.data(); - for (int i = 0; i < complex_out.size(); i++) { - *(out.data() + i) = (complex_out_ptr + i)->real(); - } - } else { // rfft - Eigen::Tensor complex_in(in.dimensions()); - T2 *in_data_ptr = complex_in.data(); - for (int i = 0; i < in.size(); i++) { - (in_data_ptr + i)->real(*(input_x + i)); - (in_data_ptr + i)->imag(0); - } - Eigen::Tensor full_fft = - complex_in.template fft(axes); - if (onesided) { - auto dims = in.dimensions(); - Eigen::DSizes offsets; - Eigen::DSizes input_slice_sizes; - for (auto i = 0; i <= signal_ndim; i++) { - input_slice_sizes[i] = (i == signal_ndim) ? (dims[i] / kRealFFTSideNum + 1) : dims[i]; - } - out = full_fft.slice(offsets, input_slice_sizes); - } else { - out = full_fft; - } - } - } else { // fft and ifft - if (is_inverse) { - out = in.template fft(axes); - } else { - out = in.template fft(axes); - } - } - - int64_t element_num = get_element_num(norm_shape, static_cast(signal_ndim)); - double norm = Getnormalized(element_num, normalized, is_inverse); - T2 *out_ptr = out.data(); - for (int i = 0; i < out.size(); i++) { - T2 temp_value = *(out_ptr + i); - temp_value *= norm; - *(output_y + i) = temp_value; - } - return true; -} - -template -bool FFTWithSizeCpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &outputs) { - std::vector checked_signal_size(raw_checked_signal_size_.begin(), raw_checked_signal_size_.end()); - const int64_t choose = FFTWithSize_choose(real_, inverse_); - auto p_x = reinterpret_cast(inputs[kIndex0]->device_ptr()); - auto p_y = reinterpret_cast(outputs[kIndex0]->device_ptr()); - if constexpr (std::is_same::value) { // fft and ifft - if (choose == kDimNum_FFT) { - FFTWITHSIZE_SWITCH_DIM_CALCULATE(T1, T2, false, false); - } else { - FFTWITHSIZE_SWITCH_DIM_CALCULATE(T1, T2, false, true); - } - } else { // rfft and irfft - if constexpr (std::is_same>::value || - std::is_same>::value) { // irfft - FFTWITHSIZE_SWITCH_DIM_CALCULATE(T1, T2, true, true); - } else { // rfft - FFTWITHSIZE_SWITCH_DIM_CALCULATE(T1, T2, true, false); - } - } - return true; -} - -#define FFT_CPU_REG(MS_I, MS_O, I, O) \ - KernelAttr() \ - .AddInputAttr(MS_I) /* x */ \ - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) /* signal_ndim */ \ - .AddInputAttr(kObjectTypeNumber, kNumberTypeBool) /* inverse */ \ - .AddInputAttr(kObjectTypeNumber, kNumberTypeBool) /* real */ \ - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) /* norm */ \ - .AddInputAttr(kObjectTypeNumber, kNumberTypeBool) /* onesided */ \ - .AddInputAttr(kObjectTypeTuple, kNumberTypeInt64) /* signal_sizes */ \ - .AddOutputAttr(MS_O), \ - &FFTWithSizeCpuKernelMod::LaunchKernel - -std::vector> FFTWithSizeCpuKernelMod::func_list_ = { - {FFT_CPU_REG(kNumberTypeComplex64, kNumberTypeComplex64, std::complex, std::complex)}, - {FFT_CPU_REG(kNumberTypeComplex128, kNumberTypeComplex128, std::complex, std::complex)}, - {FFT_CPU_REG(kNumberTypeFloat32, kNumberTypeComplex64, float, std::complex)}, - {FFT_CPU_REG(kNumberTypeComplex64, kNumberTypeFloat32, std::complex, float)}, - {FFT_CPU_REG(kNumberTypeFloat64, kNumberTypeComplex128, double, std::complex)}, - {FFT_CPU_REG(kNumberTypeComplex128, kNumberTypeFloat64, std::complex, double)}, - {FFT_CPU_REG(kNumberTypeUInt8, kNumberTypeComplex64, uint8_t, std::complex)}, - {FFT_CPU_REG(kNumberTypeInt8, kNumberTypeComplex64, int8_t, std::complex)}, - {FFT_CPU_REG(kNumberTypeInt16, kNumberTypeComplex64, int16_t, std::complex)}, - {FFT_CPU_REG(kNumberTypeInt32, kNumberTypeComplex64, int32_t, std::complex)}, - {FFT_CPU_REG(kNumberTypeInt64, kNumberTypeComplex64, int64_t, std::complex)}, - {FFT_CPU_REG(kNumberTypeBool, kNumberTypeComplex64, bool, std::complex)}}; - -std::vector FFTWithSizeCpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - - return support_list; -} - -MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, FFTWithSize, FFTWithSizeCpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "plugin/device/cpu/kernel/fft_with_size_cpu_kernel.h" +#include +#include "ops/op_utils.h" +#include "kernel/kernel.h" + +#define FFTWITHSIZE_SWITCH_DIM_CALCULATE(T1, T2, real, inverse) \ + if (signal_ndim_ == 1) { \ + FFTWithSizeCompute(p_x, p_y, onesided_, normalized_, checked_signal_size, x_shape_); \ + } else if (signal_ndim_ == 2) { \ + FFTWithSizeCompute(p_x, p_y, onesided_, normalized_, checked_signal_size, x_shape_); \ + } else { \ + FFTWithSizeCompute(p_x, p_y, onesided_, normalized_, checked_signal_size, x_shape_); \ + } +using std::vector; +namespace mindspore { +namespace kernel { +namespace { +constexpr int kDimNum_FFT = 1; +constexpr int kDimNum_IFFT = 2; +constexpr int kDimNum_RFFT = 3; +constexpr int kDimNum_IRFFT = 4; +constexpr int kRealFFTSideNum = 2; + +int64_t FFTWithSize_choose(bool real, bool inverse) { + if (!real) { + if (!inverse) { + return kDimNum_FFT; // fftz + } else { + return kDimNum_IFFT; // ifft + } + } else { + if (!inverse) { + return kDimNum_RFFT; // rfft + } else { + return kDimNum_IRFFT; // irfft + } + } +} + +int64_t get_element_num(const std::vector &shape, size_t rank) { + size_t back_itr = shape.size(); + int64_t size = 1; + for (size_t i = 1; i <= rank; i++) { + auto dim = shape[back_itr - i]; + MS_EXCEPTION_IF_CHECK_FAIL(dim > 0, "The element in shape must be positive."); + size *= dim; + } + return size; +} + +template +void change_axes(Eigen::array *axes) { + for (unsigned i = from; i <= (unsigned)to; i++) { + axes->operator[](i - 1) = i; + } + return; +} +} // namespace + +bool FFTWithSizeCpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(EXCEPTION) << kernel_name_ << " valid cpu kernel does not support this kernel data type: " << kernel_attr; + } + kernel_func_ = func_list_[index].second; + return true; +} + +int FFTWithSizeCpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { + return ret; + } + + x_shape_ = inputs[kIndex0]->GetShapeVector(); + signal_ndim_ = inputs[kIndex1]->GetValueWithCheck(); + inverse_ = inputs[kIndex2]->GetValueWithCheck(); + real_ = inputs[kIndex3]->GetValueWithCheck(); + normalized_ = inputs[kIndex4]->GetValueWithCheck(); + onesided_ = inputs[kIndex5]->GetValueWithCheck(); + raw_checked_signal_size_ = inputs[kIndex6]->GetValueWithCheck>(); + + return KRET_OK; +} + +double Getnormalized(int64_t element_num, const std::string &normalized, bool is_reverse) { + double result = 1.0; + if (!is_reverse) { + if (normalized == "forward") result = 1.0 / element_num; + if (normalized == "backward") result = 1.0; + if (normalized == "ortho") result = 1.0 / sqrt(static_cast(element_num)); + } + if (is_reverse) { + if (normalized == "forward") result = 1.0 * element_num; + if (normalized == "backward") result = 1.0; + if (normalized == "ortho") result = 1.0 * sqrt(static_cast(element_num)); + } + return result; +} + +template +inline Eigen::DSizes GetFlatShape(const std::vector &x_shape, + size_t x_dims) { + Eigen::DSizes tensor_shape; + if (x_dims == signal_ndim) { + tensor_shape[0] = 1; + for (size_t i = 0; i < x_dims; i++) { + tensor_shape[i + 1] = x_shape[i]; + } + } else if (x_dims == signal_ndim + 1) { + for (size_t i = 0; i < x_dims; i++) { + tensor_shape[i] = x_shape[i]; + } + } else if (x_dims > signal_ndim + 1) { + tensor_shape[0] = 1; + for (size_t i = 0; i < x_dims - signal_ndim; i++) { + tensor_shape[0] *= x_shape[i]; + } + for (size_t j = x_dims - static_cast(signal_ndim), i = 1; j < x_dims; j++, i++) { + tensor_shape[i] = x_shape[j]; + } + } else { + MS_LOG(EXCEPTION) << "x_dims must not be less than signal_ndim."; + } + return tensor_shape; +} + +template +bool FFTWithSizeCompute(T1 *input_x, T2 *output_y, bool onesided, std::string normalized, + const vector &checked_signal_size, const vector &x_shape) { + Eigen::DSizes tensor_shape = GetFlatShape(x_shape, x_shape.size()); + Eigen::TensorMap, Eigen::RowMajor> in(&input_x[0], tensor_shape); + Eigen::array axes; + change_axes(&axes); + Eigen::Tensor out; + vector norm_shape(x_shape); + if constexpr (is_real) { // for rfft and irfft COMPILE TIME EXPANSION + if constexpr (is_inverse) { // irfft + Eigen::Tensor complex_out; + if (onesided) { + // compute the full fft tensor shape: full_fft_shape[-1] / 2 + 1 + Eigen::DSizes temp_tensor_shape(tensor_shape); + if (checked_signal_size.empty()) { + if (temp_tensor_shape[signal_ndim] == 1) { + MS_EXCEPTION(ValueError) << "For 'FFTWithSize', the last dimension of the input cannot be 1, but got: " + << temp_tensor_shape[signal_ndim]; + } + temp_tensor_shape[signal_ndim] = (temp_tensor_shape[signal_ndim] - 1) * kRealFFTSideNum; + } else { + if (checked_signal_size.back() / kRealFFTSideNum + 1 == temp_tensor_shape[signal_ndim]) { + temp_tensor_shape[static_cast(signal_ndim)] = checked_signal_size.back(); + } + } + if (temp_tensor_shape.back() == tensor_shape.back()) { + // fake there is no need to reconstruct signal tensor + complex_out = in.template fft(axes); + } else { + // Reconstruct the full fft tensor: temp_tensor + Eigen::Tensor temp_tensor(temp_tensor_shape); + temp_tensor.setZero(); + Eigen::DSizes zero_offsets; + Eigen::DSizes input_slice_sizes(in.dimensions()); + temp_tensor.slice(zero_offsets, input_slice_sizes) = in; + // do ifft at outer axes + if (signal_ndim > 1) { + Eigen::array outer_axes; + change_axes(&outer_axes); + temp_tensor = temp_tensor.template fft(outer_axes); + } + // rebuild the last axis with symmetrical data + Eigen::array reverse_last_axis; + for (auto i = 0; i <= signal_ndim; i++) { + reverse_last_axis[i] = i == signal_ndim; + } + auto reverse_size = input_slice_sizes; + reverse_size[signal_ndim] = temp_tensor_shape[signal_ndim] - input_slice_sizes[signal_ndim]; + Eigen::DSizes reverse_start_indices; + reverse_start_indices[signal_ndim] = 1; + Eigen::DSizes reverse_target_indices; + reverse_target_indices[signal_ndim] = input_slice_sizes[signal_ndim]; + temp_tensor.slice(reverse_target_indices, reverse_size) = + temp_tensor.slice(reverse_start_indices, reverse_size).reverse(reverse_last_axis).conjugate(); + // do irfft at the last axis: + auto inner_axis = Eigen::array{signal_ndim}; + complex_out = temp_tensor.template fft(inner_axis); + } + norm_shape.back() = static_cast(temp_tensor_shape.back()); + } else { + complex_out = in.template fft(axes); + } + out.resize(complex_out.dimensions()); + T1 *complex_out_ptr = complex_out.data(); + for (int i = 0; i < complex_out.size(); i++) { + *(out.data() + i) = (complex_out_ptr + i)->real(); + } + } else { // rfft + Eigen::Tensor complex_in(in.dimensions()); + T2 *in_data_ptr = complex_in.data(); + for (int i = 0; i < in.size(); i++) { + (in_data_ptr + i)->real(*(input_x + i)); + (in_data_ptr + i)->imag(0); + } + Eigen::Tensor full_fft = + complex_in.template fft(axes); + if (onesided) { + auto dims = in.dimensions(); + Eigen::DSizes offsets; + Eigen::DSizes input_slice_sizes; + for (auto i = 0; i <= signal_ndim; i++) { + input_slice_sizes[i] = (i == signal_ndim) ? (dims[i] / kRealFFTSideNum + 1) : dims[i]; + } + out = full_fft.slice(offsets, input_slice_sizes); + } else { + out = full_fft; + } + } + } else { // fft and ifft + if (is_inverse) { + out = in.template fft(axes); + } else { + out = in.template fft(axes); + } + } + + int64_t element_num = get_element_num(norm_shape, static_cast(signal_ndim)); + double norm = Getnormalized(element_num, normalized, is_inverse); + T2 *out_ptr = out.data(); + for (int i = 0; i < out.size(); i++) { + T2 temp_value = *(out_ptr + i); + temp_value *= norm; + *(output_y + i) = temp_value; + } + return true; +} + +template +bool FFTWithSizeCpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { + std::vector checked_signal_size(raw_checked_signal_size_.begin(), raw_checked_signal_size_.end()); + const int64_t choose = FFTWithSize_choose(real_, inverse_); + auto p_x = reinterpret_cast(inputs[kIndex0]->device_ptr()); + auto p_y = reinterpret_cast(outputs[kIndex0]->device_ptr()); + if constexpr (std::is_same::value) { // fft and ifft + if (choose == kDimNum_FFT) { + FFTWITHSIZE_SWITCH_DIM_CALCULATE(T1, T2, false, false); + } else { + FFTWITHSIZE_SWITCH_DIM_CALCULATE(T1, T2, false, true); + } + } else { // rfft and irfft + if constexpr (std::is_same>::value || + std::is_same>::value) { // irfft + FFTWITHSIZE_SWITCH_DIM_CALCULATE(T1, T2, true, true); + } else { // rfft + FFTWITHSIZE_SWITCH_DIM_CALCULATE(T1, T2, true, false); + } + } + return true; +} + +#define FFT_CPU_REG(MS_I, MS_O, I, O) \ + KernelAttr() \ + .AddInputAttr(MS_I) /* x */ \ + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) /* signal_ndim */ \ + .AddInputAttr(kObjectTypeNumber, kNumberTypeBool) /* inverse */ \ + .AddInputAttr(kObjectTypeNumber, kNumberTypeBool) /* real */ \ + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) /* norm */ \ + .AddInputAttr(kObjectTypeNumber, kNumberTypeBool) /* onesided */ \ + .AddInputAttr(kObjectTypeTuple, kNumberTypeInt64) /* signal_sizes */ \ + .AddOutputAttr(MS_O), \ + &FFTWithSizeCpuKernelMod::LaunchKernel + +std::vector> FFTWithSizeCpuKernelMod::func_list_ = { + {FFT_CPU_REG(kNumberTypeComplex64, kNumberTypeComplex64, std::complex, std::complex)}, + {FFT_CPU_REG(kNumberTypeComplex128, kNumberTypeComplex128, std::complex, std::complex)}, + {FFT_CPU_REG(kNumberTypeFloat32, kNumberTypeComplex64, float, std::complex)}, + {FFT_CPU_REG(kNumberTypeComplex64, kNumberTypeFloat32, std::complex, float)}, + {FFT_CPU_REG(kNumberTypeFloat64, kNumberTypeComplex128, double, std::complex)}, + {FFT_CPU_REG(kNumberTypeComplex128, kNumberTypeFloat64, std::complex, double)}, + {FFT_CPU_REG(kNumberTypeUInt8, kNumberTypeComplex64, uint8_t, std::complex)}, + {FFT_CPU_REG(kNumberTypeInt8, kNumberTypeComplex64, int8_t, std::complex)}, + {FFT_CPU_REG(kNumberTypeInt16, kNumberTypeComplex64, int16_t, std::complex)}, + {FFT_CPU_REG(kNumberTypeInt32, kNumberTypeComplex64, int32_t, std::complex)}, + {FFT_CPU_REG(kNumberTypeInt64, kNumberTypeComplex64, int64_t, std::complex)}, + {FFT_CPU_REG(kNumberTypeBool, kNumberTypeComplex64, bool, std::complex)}}; + +std::vector FFTWithSizeCpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, FFTWithSize, FFTWithSizeCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/fft_with_size_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/fft_with_size_cpu_kernel.h index 3a6e241fef2..d2c2d2b84b0 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/fft_with_size_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/fft_with_size_cpu_kernel.h @@ -1,67 +1,67 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FFTWITHSIZE_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FFTWITHSIZE_CPU_KERNEL_H_ -#include -#include -#include -#include -#include -#include -#include -#include -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "plugin/factory/ms_factory.h" -#include "unsupported/Eigen/CXX11/Tensor" - -namespace mindspore { -constexpr size_t kInputNum = 1; -constexpr size_t kOutputNum = 1; -namespace kernel { -class FFTWithSizeCpuKernelMod : public NativeCpuKernelMod { - public: - FFTWithSizeCpuKernelMod() = default; - ~FFTWithSizeCpuKernelMod() override = default; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - int Resize(const std::vector &inputs, const std::vector &outputs) override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - return kernel_func_(this, inputs, outputs); - } - - protected: - std::vector GetOpSupport() override; - - private: - template - bool LaunchKernel(const std::vector &inputs, - const std::vector &outputs); - using FFTWithSizeFunc = std::function &, - const std::vector &)>; - static std::vector> func_list_; - FFTWithSizeFunc kernel_func_; - bool real_; - bool inverse_; - bool onesided_; - int64_t signal_ndim_; - std::string normalized_; - std::vector raw_checked_signal_size_; - std::vector x_shape_; -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FFTWITHSIZE_CPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FFTWITHSIZE_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FFTWITHSIZE_CPU_KERNEL_H_ +#include +#include +#include +#include +#include +#include +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" +#include "unsupported/Eigen/CXX11/Tensor" + +namespace mindspore { +constexpr size_t kInputNum = 1; +constexpr size_t kOutputNum = 1; +namespace kernel { +class FFTWithSizeCpuKernelMod : public NativeCpuKernelMod { + public: + FFTWithSizeCpuKernelMod() = default; + ~FFTWithSizeCpuKernelMod() override = default; + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + int Resize(const std::vector &inputs, const std::vector &outputs) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + return kernel_func_(this, inputs, outputs); + } + + protected: + std::vector GetOpSupport() override; + + private: + template + bool LaunchKernel(const std::vector &inputs, + const std::vector &outputs); + using FFTWithSizeFunc = std::function &, + const std::vector &)>; + static std::vector> func_list_; + FFTWithSizeFunc kernel_func_; + bool real_; + bool inverse_; + bool onesided_; + int64_t signal_ndim_; + std::string normalized_; + std::vector raw_checked_signal_size_; + std::vector x_shape_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FFTWITHSIZE_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/fill_diagonal_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/fill_diagonal_cpu_kernel.cc index 4a3a0d40430..56cb1140830 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/fill_diagonal_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/fill_diagonal_cpu_kernel.cc @@ -1,155 +1,155 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/cpu/kernel/fill_diagonal_cpu_kernel.h" -#include -#include -#include -#include -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "mindspore/core/ops/fill_diagonal.h" - -namespace mindspore { -namespace kernel { -namespace { -const size_t kFillDiagonalInputNum = 1; -const size_t kFillDiagonalOutputNum = 1; -const size_t kInputDimIndex0 = 0; -const size_t kInputDimIndex1 = 1; -const size_t kInputMinDim = 2; -constexpr int64_t kParallelDataNums = 512 * 1024; -} // namespace - -bool FillDiagonalCpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kFillDiagonalInputNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kFillDiagonalOutputNum, kernel_name_); - - input_type_ = inputs[0]->dtype_id(); - fill_value_ = GetValue(primitive_->GetAttr(ops::kFillValue)); - wrap_ = GetValue(primitive_->GetAttr(ops::kWrap)); - - if (IsOneOfUnsignedType(input_type_) && fill_value_ < 0) { - MS_LOG(ERROR) << "For " << kernel_name_ << ", [file_value] should be non_negative for input of unsigned type."; - return false; - } - return true; -} - -int FillDiagonalCpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - auto ret = KernelMod::Resize(inputs, outputs); - if (ret != KRET_OK) { - return ret; - } - input_shape_ = inputs[0]->GetDeviceShapeVector(); - return KRET_OK; -} - -bool FillDiagonalCpuKernelMod::Launch(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs) { - if (input_type_ == kNumberTypeFloat16) { - return LaunchKernel(inputs, outputs); - } else if (input_type_ == kNumberTypeFloat32) { - return LaunchKernel(inputs, outputs); - } else if (input_type_ == kNumberTypeFloat64) { - return LaunchKernel(inputs, outputs); - } else if (input_type_ == kNumberTypeUInt8) { - return LaunchKernel(inputs, outputs); - } else if (input_type_ == kNumberTypeUInt16) { - return LaunchKernel(inputs, outputs); - } else if (input_type_ == kNumberTypeUInt32) { - return LaunchKernel(inputs, outputs); - } else if (input_type_ == kNumberTypeUInt64) { - return LaunchKernel(inputs, outputs); - } else if (input_type_ == kNumberTypeInt8) { - return LaunchKernel(inputs, outputs); - } else if (input_type_ == kNumberTypeInt16) { - return LaunchKernel(inputs, outputs); - } else if (input_type_ == kNumberTypeInt32) { - return LaunchKernel(inputs, outputs); - } else if (input_type_ == kNumberTypeInt64) { - return LaunchKernel(inputs, outputs); - } else { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ - << "the datatype of the input not support, support datatype: float, int32, int64."; - } -} - -template -bool FillDiagonalCpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &outputs) { - T *input_ptr = reinterpret_cast(inputs[0]->device_ptr()); - MS_EXCEPTION_IF_NULL(input_ptr); - T *output_ptr = reinterpret_cast(outputs[0]->device_ptr()); - MS_EXCEPTION_IF_NULL(output_ptr); - - size_t data_nums = outputs[0]->size() / sizeof(T); - if (SizeToLong(data_nums) <= kParallelDataNums) { - auto ret_code = memcpy_s(output_ptr, data_nums * sizeof(T), input_ptr, data_nums * sizeof(T)); - if (ret_code != EOK) { - MS_LOG(EXCEPTION) << "Failed to copy data, memcpy_s errorno: " << ret_code; - } - } else { - auto task = [this, input_ptr, output_ptr](size_t start, size_t end) { - auto ret_code = - memcpy_s(output_ptr + start, (end - start) * sizeof(T), input_ptr + start, (end - start) * sizeof(T)); - if (ret_code != EOK) { - MS_LOG(EXCEPTION) << "Failed to copy data, memcpy_s errorno: " << ret_code; - } - }; - CPUKernelUtils::ParallelFor(task, data_nums); - } - - int64_t height = input_shape_[kInputDimIndex0]; - int64_t width = input_shape_[kInputDimIndex1]; - int64_t size = std::min(height, width); - - int64_t stride = 0; - for (int64_t i = (SizeToLong(input_shape_.size()) - 1); i >= 0; i--) { - stride += static_cast(pow(width, i)); - } - for (int64_t i = 0; i < size; ++i) { - output_ptr[stride * i] = static_cast(fill_value_); - } - - if (wrap_ && input_shape_.size() == kInputMinDim && height > width + 1) { - int64_t location = size * (size + 1); - while (location < SizeToLong(data_nums)) { - output_ptr[location] = static_cast(fill_value_); - location += stride; - } - } - - return true; -} - -std::vector FillDiagonalCpuKernelMod::GetOpSupport() { - static const std::vector support_list = { - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt8), - KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), - KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), - KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64)}; - return support_list; -} - -MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, FillDiagonal, FillDiagonalCpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/cpu/kernel/fill_diagonal_cpu_kernel.h" +#include +#include +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "mindspore/core/ops/fill_diagonal.h" + +namespace mindspore { +namespace kernel { +namespace { +const size_t kFillDiagonalInputNum = 1; +const size_t kFillDiagonalOutputNum = 1; +const size_t kInputDimIndex0 = 0; +const size_t kInputDimIndex1 = 1; +const size_t kInputMinDim = 2; +constexpr int64_t kParallelDataNums = 512 * 1024; +} // namespace + +bool FillDiagonalCpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kFillDiagonalInputNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kFillDiagonalOutputNum, kernel_name_); + + input_type_ = inputs[0]->dtype_id(); + fill_value_ = GetValue(primitive_->GetAttr(ops::kFillValue)); + wrap_ = GetValue(primitive_->GetAttr(ops::kWrap)); + + if (IsOneOfUnsignedType(input_type_) && fill_value_ < 0) { + MS_LOG(ERROR) << "For " << kernel_name_ << ", [file_value] should be non_negative for input of unsigned type."; + return false; + } + return true; +} + +int FillDiagonalCpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + auto ret = KernelMod::Resize(inputs, outputs); + if (ret != KRET_OK) { + return ret; + } + input_shape_ = inputs[0]->GetDeviceShapeVector(); + return KRET_OK; +} + +bool FillDiagonalCpuKernelMod::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + if (input_type_ == kNumberTypeFloat16) { + return LaunchKernel(inputs, outputs); + } else if (input_type_ == kNumberTypeFloat32) { + return LaunchKernel(inputs, outputs); + } else if (input_type_ == kNumberTypeFloat64) { + return LaunchKernel(inputs, outputs); + } else if (input_type_ == kNumberTypeUInt8) { + return LaunchKernel(inputs, outputs); + } else if (input_type_ == kNumberTypeUInt16) { + return LaunchKernel(inputs, outputs); + } else if (input_type_ == kNumberTypeUInt32) { + return LaunchKernel(inputs, outputs); + } else if (input_type_ == kNumberTypeUInt64) { + return LaunchKernel(inputs, outputs); + } else if (input_type_ == kNumberTypeInt8) { + return LaunchKernel(inputs, outputs); + } else if (input_type_ == kNumberTypeInt16) { + return LaunchKernel(inputs, outputs); + } else if (input_type_ == kNumberTypeInt32) { + return LaunchKernel(inputs, outputs); + } else if (input_type_ == kNumberTypeInt64) { + return LaunchKernel(inputs, outputs); + } else { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ + << "the datatype of the input not support, support datatype: float, int32, int64."; + } +} + +template +bool FillDiagonalCpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { + T *input_ptr = reinterpret_cast(inputs[0]->device_ptr()); + MS_EXCEPTION_IF_NULL(input_ptr); + T *output_ptr = reinterpret_cast(outputs[0]->device_ptr()); + MS_EXCEPTION_IF_NULL(output_ptr); + + size_t data_nums = outputs[0]->size() / sizeof(T); + if (SizeToLong(data_nums) <= kParallelDataNums) { + auto ret_code = memcpy_s(output_ptr, data_nums * sizeof(T), input_ptr, data_nums * sizeof(T)); + if (ret_code != EOK) { + MS_LOG(EXCEPTION) << "Failed to copy data, memcpy_s errorno: " << ret_code; + } + } else { + auto task = [this, input_ptr, output_ptr](size_t start, size_t end) { + auto ret_code = + memcpy_s(output_ptr + start, (end - start) * sizeof(T), input_ptr + start, (end - start) * sizeof(T)); + if (ret_code != EOK) { + MS_LOG(EXCEPTION) << "Failed to copy data, memcpy_s errorno: " << ret_code; + } + }; + CPUKernelUtils::ParallelFor(task, data_nums); + } + + int64_t height = input_shape_[kInputDimIndex0]; + int64_t width = input_shape_[kInputDimIndex1]; + int64_t size = std::min(height, width); + + int64_t stride = 0; + for (int64_t i = (SizeToLong(input_shape_.size()) - 1); i >= 0; i--) { + stride += static_cast(pow(width, i)); + } + for (int64_t i = 0; i < size; ++i) { + output_ptr[stride * i] = static_cast(fill_value_); + } + + if (wrap_ && input_shape_.size() == kInputMinDim && height > width + 1) { + int64_t location = size * (size + 1); + while (location < SizeToLong(data_nums)) { + output_ptr[location] = static_cast(fill_value_); + location += stride; + } + } + + return true; +} + +std::vector FillDiagonalCpuKernelMod::GetOpSupport() { + static const std::vector support_list = { + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt8), + KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), + KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), + KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64)}; + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, FillDiagonal, FillDiagonalCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/fill_diagonal_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/fill_diagonal_cpu_kernel.h index 6bfab06232a..3e2870e0adc 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/fill_diagonal_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/fill_diagonal_cpu_kernel.h @@ -1,51 +1,51 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FILL_DIAGONAL_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FILL_DIAGONAL_CPU_KERNEL_H_ - -#include -#include -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore { -namespace kernel { -class FillDiagonalCpuKernelMod : public NativeCpuKernelMod { - public: - FillDiagonalCpuKernelMod() = default; - ~FillDiagonalCpuKernelMod() override = default; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - int Resize(const std::vector &inputs, const std::vector &outputs) override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - protected: - std::vector GetOpSupport() override; - - private: - template - bool LaunchKernel(const std::vector &inputs, const std::vector &outputs); - TypeId input_type_{kTypeUnknown}; - std::vector input_shape_; - float fill_value_; - bool wrap_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FILL_DIAGONAL_CPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FILL_DIAGONAL_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FILL_DIAGONAL_CPU_KERNEL_H_ + +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class FillDiagonalCpuKernelMod : public NativeCpuKernelMod { + public: + FillDiagonalCpuKernelMod() = default; + ~FillDiagonalCpuKernelMod() override = default; + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + int Resize(const std::vector &inputs, const std::vector &outputs) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + protected: + std::vector GetOpSupport() override; + + private: + template + bool LaunchKernel(const std::vector &inputs, const std::vector &outputs); + TypeId input_type_{kTypeUnknown}; + std::vector input_shape_; + float fill_value_; + bool wrap_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FILL_DIAGONAL_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/fractional_max_pool_grad_with_fixed_ksize_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/fractional_max_pool_grad_with_fixed_ksize_cpu_kernel.cc index 5698fdccc80..9e274d5bd27 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/fractional_max_pool_grad_with_fixed_ksize_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/fractional_max_pool_grad_with_fixed_ksize_cpu_kernel.cc @@ -1,198 +1,198 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "plugin/device/cpu/kernel/fractional_max_pool_grad_with_fixed_ksize_cpu_kernel.h" -#include -#include -#include -#include -#include -#include -#include -#include "plugin/device/cpu/hal/device/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -namespace { -const size_t kInputsNum = 3; -const size_t kOutputsNum = 1; -const size_t kInputIndex0 = 0; -const size_t kInputIndex1 = 1; -const size_t kInputIndex2 = 2; -const size_t kInputsDimSize = 4; -const size_t kInputsDimIndexN = 0; -const size_t kInputsDimIndexC = 1; -const size_t kInputsDimIndexH = 2; -const size_t kInputsDimIndexW = 3; - -#define ADD_KERNEL(t1, t2, t3, t4) \ - KernelAttr() \ - .AddInputAttr(kNumberType##t1) \ - .AddInputAttr(kNumberType##t2) \ - .AddInputAttr(kNumberType##t3) \ - .AddOutputAttr(kNumberType##t4) -} // namespace - -bool FractionalMaxPoolGradWithFixedKsizeCPUKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - constexpr size_t input_num = kInputsNum; - constexpr size_t output_num = kOutputsNum; - CHECK_KERNEL_INPUTS_NUM(inputs.size(), input_num, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), output_num, kernel_name_); - out_backprop_type_ = inputs[kInputIndex1]->dtype_id(); - - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto match = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!match.first) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr; - return false; - } - data_format_ = GetValue(primitive_->GetAttr(ops::kFormat)); - if (data_format_ != "NCHW") { - MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the attr data_format must be NCHW."; - } - return true; -} - -int FractionalMaxPoolGradWithFixedKsizeCPUKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - int ret = KernelMod::Resize(inputs, outputs); - if (ret != KRET_OK) { - return ret; - } - input_shape_ = inputs[kInputIndex0]->GetDeviceShapeVector(); - out_backprop_shape_ = inputs[kInputIndex1]->GetDeviceShapeVector(); - argmax_shape_ = inputs[kInputIndex2]->GetDeviceShapeVector(); - if (input_shape_.size() != kInputsDimSize) { - MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', The dim of input origin_input must be 4, but got " - << input_shape_.size() << "."; - } - input_n_ = input_shape_[kInputsDimIndexN]; - input_c_ = input_shape_[kInputsDimIndexC]; - input_h_ = input_shape_[kInputsDimIndexH]; - input_w_ = input_shape_[kInputsDimIndexW]; - if (out_backprop_shape_.size() != kInputsDimSize) { - MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', The dim of input out_backprop must be 4, but got " - << out_backprop_shape_.size() << "."; - } - out_backprop_h_ = out_backprop_shape_[kInputsDimIndexH]; - out_backprop_w_ = out_backprop_shape_[kInputsDimIndexW]; - if (argmax_shape_.size() != kInputsDimSize) { - MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', The dim of input argmax must be 4, but got " - << argmax_shape_.size() << "."; - } - for (size_t i = 0; i < kInputsDimSize; i++) { - if (out_backprop_shape_[i] != argmax_shape_[i]) { - MS_EXCEPTION(ValueError) << "For '" << kernel_name_ - << "', The shape of input out_backprop and input argmax must be equal."; - } - } - - if (input_n_ != out_backprop_shape_[kInputsDimIndexN]) { - MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', The first dimension of three inputs must be equal."; - } - if (input_c_ != out_backprop_shape_[kInputsDimIndexC]) { - MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', The second dimension of three inputs must be equal."; - } - return ret; -} - -bool FractionalMaxPoolGradWithFixedKsizeCPUKernelMod::Launch(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs) { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_); - switch (out_backprop_type_) { - case kNumberTypeFloat16: - return GradComputeTemplate(inputs, outputs); - case kNumberTypeFloat32: - return GradComputeTemplate(inputs, outputs); - case kNumberTypeFloat64: - return GradComputeTemplate(inputs, outputs); - case kNumberTypeInt32: - return GradComputeTemplate(inputs, outputs); - case kNumberTypeInt64: - return GradComputeTemplate(inputs, outputs); - default: - MS_EXCEPTION(TypeError) << "For '" << kernel_name_ << "', out_backprop_type" << out_backprop_type_ - << "not support, must be in [{DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64}]."; - } - return true; -} - -template -bool FractionalMaxPoolGradWithFixedKsizeCPUKernelMod::GradComputeTemplate( - const std::vector &inputs, const std::vector &outputs) const { - backprop_t *out_backprop_ptr = static_cast(inputs[1]->device_ptr()); - int64_t *argmax_ptr = static_cast(inputs[2]->device_ptr()); - backprop_t *output_ptr = static_cast(outputs[0]->device_ptr()); - - auto shard_fractional_max_pool_grad_with_fixed_ksize = [&](size_t start, size_t end) { - for (size_t n = start; n < end; n++) { - backprop_t *out_backpropForPlane = out_backprop_ptr + n * input_c_ * out_backprop_h_ * out_backprop_w_; - int64_t *argmaxForPlane = argmax_ptr + n * input_c_ * out_backprop_h_ * out_backprop_w_; - backprop_t *outputForPlane = output_ptr + n * input_c_ * input_h_ * input_w_; - - FractionalMaxPoolGradWithFixedKsizeCompute(out_backpropForPlane, argmaxForPlane, outputForPlane); - } - }; - CPUKernelUtils::ParallelFor(shard_fractional_max_pool_grad_with_fixed_ksize, LongToSize(input_n_)); - return true; -} - -template -void FractionalMaxPoolGradWithFixedKsizeCPUKernelMod::FractionalMaxPoolGradWithFixedKsizeCompute( - backprop_t *out_backpropForPlane, int64_t *argmaxForPlane, backprop_t *outputForPlane) const { - for (int64_t plane = 0; plane < input_c_; plane++) { - backprop_t *out_backpropPlane = out_backpropForPlane + plane * out_backprop_h_ * out_backprop_w_; - int64_t *argmaxPlane = argmaxForPlane + plane * out_backprop_h_ * out_backprop_w_; - backprop_t *outputPlane = outputForPlane + plane * input_h_ * input_w_; - - for (int64_t i = 0; i < input_h_; i++) { - for (int64_t j = 0; j < input_w_; j++) { - outputPlane[i * input_w_ + j] = static_cast(0); - } - } - - for (int64_t h = 0; h < out_backprop_h_; h++) { - for (int64_t w = 0; w < out_backprop_w_; w++) { - int input_index = h * out_backprop_w_ + w; - if (input_index < 0 || input_index >= (out_backprop_h_ * out_backprop_w_)) { - MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the index value of argmax is illegal."; - } - int output_index = argmaxPlane[input_index]; - if (output_index < 0 || output_index >= (input_h_ * input_w_)) { - MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the index value of output is illegal."; - } - outputPlane[output_index] += out_backpropPlane[input_index]; - } - } - } -} - -std::vector FractionalMaxPoolGradWithFixedKsizeCPUKernelMod::GetOpSupport() { - static std::vector kernel_attr_list = { - ADD_KERNEL(Int32, Float16, Int64, Float16), ADD_KERNEL(Int32, Float32, Int64, Float32), - ADD_KERNEL(Int32, Float64, Int64, Float64), ADD_KERNEL(Int32, Int32, Int64, Int32), - ADD_KERNEL(Int32, Int64, Int64, Int64), ADD_KERNEL(Int64, Float16, Int64, Float16), - ADD_KERNEL(Int64, Float32, Int64, Float32), ADD_KERNEL(Int64, Float64, Int64, Float64), - ADD_KERNEL(Int64, Int32, Int64, Int32), ADD_KERNEL(Int64, Int64, Int64, Int64)}; - return kernel_attr_list; -} - -MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, FractionalMaxPoolGradWithFixedKsize, - FractionalMaxPoolGradWithFixedKsizeCPUKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "plugin/device/cpu/kernel/fractional_max_pool_grad_with_fixed_ksize_cpu_kernel.h" +#include +#include +#include +#include +#include +#include +#include +#include "plugin/device/cpu/hal/device/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +namespace { +const size_t kInputsNum = 3; +const size_t kOutputsNum = 1; +const size_t kInputIndex0 = 0; +const size_t kInputIndex1 = 1; +const size_t kInputIndex2 = 2; +const size_t kInputsDimSize = 4; +const size_t kInputsDimIndexN = 0; +const size_t kInputsDimIndexC = 1; +const size_t kInputsDimIndexH = 2; +const size_t kInputsDimIndexW = 3; + +#define ADD_KERNEL(t1, t2, t3, t4) \ + KernelAttr() \ + .AddInputAttr(kNumberType##t1) \ + .AddInputAttr(kNumberType##t2) \ + .AddInputAttr(kNumberType##t3) \ + .AddOutputAttr(kNumberType##t4) +} // namespace + +bool FractionalMaxPoolGradWithFixedKsizeCPUKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + constexpr size_t input_num = kInputsNum; + constexpr size_t output_num = kOutputsNum; + CHECK_KERNEL_INPUTS_NUM(inputs.size(), input_num, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), output_num, kernel_name_); + out_backprop_type_ = inputs[kInputIndex1]->dtype_id(); + + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto match = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!match.first) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr; + return false; + } + data_format_ = GetValue(primitive_->GetAttr(ops::kFormat)); + if (data_format_ != "NCHW") { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the attr data_format must be NCHW."; + } + return true; +} + +int FractionalMaxPoolGradWithFixedKsizeCPUKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + int ret = KernelMod::Resize(inputs, outputs); + if (ret != KRET_OK) { + return ret; + } + input_shape_ = inputs[kInputIndex0]->GetDeviceShapeVector(); + out_backprop_shape_ = inputs[kInputIndex1]->GetDeviceShapeVector(); + argmax_shape_ = inputs[kInputIndex2]->GetDeviceShapeVector(); + if (input_shape_.size() != kInputsDimSize) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', The dim of input origin_input must be 4, but got " + << input_shape_.size() << "."; + } + input_n_ = input_shape_[kInputsDimIndexN]; + input_c_ = input_shape_[kInputsDimIndexC]; + input_h_ = input_shape_[kInputsDimIndexH]; + input_w_ = input_shape_[kInputsDimIndexW]; + if (out_backprop_shape_.size() != kInputsDimSize) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', The dim of input out_backprop must be 4, but got " + << out_backprop_shape_.size() << "."; + } + out_backprop_h_ = out_backprop_shape_[kInputsDimIndexH]; + out_backprop_w_ = out_backprop_shape_[kInputsDimIndexW]; + if (argmax_shape_.size() != kInputsDimSize) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', The dim of input argmax must be 4, but got " + << argmax_shape_.size() << "."; + } + for (size_t i = 0; i < kInputsDimSize; i++) { + if (out_backprop_shape_[i] != argmax_shape_[i]) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ + << "', The shape of input out_backprop and input argmax must be equal."; + } + } + + if (input_n_ != out_backprop_shape_[kInputsDimIndexN]) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', The first dimension of three inputs must be equal."; + } + if (input_c_ != out_backprop_shape_[kInputsDimIndexC]) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', The second dimension of three inputs must be equal."; + } + return ret; +} + +bool FractionalMaxPoolGradWithFixedKsizeCPUKernelMod::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_); + switch (out_backprop_type_) { + case kNumberTypeFloat16: + return GradComputeTemplate(inputs, outputs); + case kNumberTypeFloat32: + return GradComputeTemplate(inputs, outputs); + case kNumberTypeFloat64: + return GradComputeTemplate(inputs, outputs); + case kNumberTypeInt32: + return GradComputeTemplate(inputs, outputs); + case kNumberTypeInt64: + return GradComputeTemplate(inputs, outputs); + default: + MS_EXCEPTION(TypeError) << "For '" << kernel_name_ << "', out_backprop_type" << out_backprop_type_ + << "not support, must be in [{DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64}]."; + } + return true; +} + +template +bool FractionalMaxPoolGradWithFixedKsizeCPUKernelMod::GradComputeTemplate( + const std::vector &inputs, const std::vector &outputs) const { + backprop_t *out_backprop_ptr = static_cast(inputs[1]->device_ptr()); + int64_t *argmax_ptr = static_cast(inputs[2]->device_ptr()); + backprop_t *output_ptr = static_cast(outputs[0]->device_ptr()); + + auto shard_fractional_max_pool_grad_with_fixed_ksize = [&](size_t start, size_t end) { + for (size_t n = start; n < end; n++) { + backprop_t *out_backpropForPlane = out_backprop_ptr + n * input_c_ * out_backprop_h_ * out_backprop_w_; + int64_t *argmaxForPlane = argmax_ptr + n * input_c_ * out_backprop_h_ * out_backprop_w_; + backprop_t *outputForPlane = output_ptr + n * input_c_ * input_h_ * input_w_; + + FractionalMaxPoolGradWithFixedKsizeCompute(out_backpropForPlane, argmaxForPlane, outputForPlane); + } + }; + CPUKernelUtils::ParallelFor(shard_fractional_max_pool_grad_with_fixed_ksize, LongToSize(input_n_)); + return true; +} + +template +void FractionalMaxPoolGradWithFixedKsizeCPUKernelMod::FractionalMaxPoolGradWithFixedKsizeCompute( + backprop_t *out_backpropForPlane, int64_t *argmaxForPlane, backprop_t *outputForPlane) const { + for (int64_t plane = 0; plane < input_c_; plane++) { + backprop_t *out_backpropPlane = out_backpropForPlane + plane * out_backprop_h_ * out_backprop_w_; + int64_t *argmaxPlane = argmaxForPlane + plane * out_backprop_h_ * out_backprop_w_; + backprop_t *outputPlane = outputForPlane + plane * input_h_ * input_w_; + + for (int64_t i = 0; i < input_h_; i++) { + for (int64_t j = 0; j < input_w_; j++) { + outputPlane[i * input_w_ + j] = static_cast(0); + } + } + + for (int64_t h = 0; h < out_backprop_h_; h++) { + for (int64_t w = 0; w < out_backprop_w_; w++) { + int input_index = h * out_backprop_w_ + w; + if (input_index < 0 || input_index >= (out_backprop_h_ * out_backprop_w_)) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the index value of argmax is illegal."; + } + int output_index = argmaxPlane[input_index]; + if (output_index < 0 || output_index >= (input_h_ * input_w_)) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the index value of output is illegal."; + } + outputPlane[output_index] += out_backpropPlane[input_index]; + } + } + } +} + +std::vector FractionalMaxPoolGradWithFixedKsizeCPUKernelMod::GetOpSupport() { + static std::vector kernel_attr_list = { + ADD_KERNEL(Int32, Float16, Int64, Float16), ADD_KERNEL(Int32, Float32, Int64, Float32), + ADD_KERNEL(Int32, Float64, Int64, Float64), ADD_KERNEL(Int32, Int32, Int64, Int32), + ADD_KERNEL(Int32, Int64, Int64, Int64), ADD_KERNEL(Int64, Float16, Int64, Float16), + ADD_KERNEL(Int64, Float32, Int64, Float32), ADD_KERNEL(Int64, Float64, Int64, Float64), + ADD_KERNEL(Int64, Int32, Int64, Int32), ADD_KERNEL(Int64, Int64, Int64, Int64)}; + return kernel_attr_list; +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, FractionalMaxPoolGradWithFixedKsize, + FractionalMaxPoolGradWithFixedKsizeCPUKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/fractional_max_pool_grad_with_fixed_ksize_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/fractional_max_pool_grad_with_fixed_ksize_cpu_kernel.h index 9c8b44872bb..78159eb30b1 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/fractional_max_pool_grad_with_fixed_ksize_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/fractional_max_pool_grad_with_fixed_ksize_cpu_kernel.h @@ -1,61 +1,61 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_FRACTIONAL_MAX_POOL_GRAD_WITH_FIXED_KSIZE_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_FRACTIONAL_MAX_POOL_GRAD_WITH_FIXED_KSIZE_CPU_KERNEL_H_ -#include -#include -#include -#include -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "plugin/factory/ms_factory.h" -#include "mindspore/core/ops/grad/fractional_max_pool_grad_with_fixed_ksize.h" - -namespace mindspore { -namespace kernel { -class FractionalMaxPoolGradWithFixedKsizeCPUKernelMod : public NativeCpuKernelMod { - public: - FractionalMaxPoolGradWithFixedKsizeCPUKernelMod() = default; - ~FractionalMaxPoolGradWithFixedKsizeCPUKernelMod() override = default; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - int Resize(const std::vector &inputs, const std::vector &outputs) override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - protected: - std::vector GetOpSupport() override; - - private: - template - bool GradComputeTemplate(const std::vector &inputs, const std::vector &outputs) const; - template - void FractionalMaxPoolGradWithFixedKsizeCompute(backprop_t *out_backpropForPlane, int64_t *argmaxForPlane, - backprop_t *outputForPlane) const; - std::vector input_shape_; - std::vector out_backprop_shape_; - std::vector argmax_shape_; - std::string data_format_{"NCHW"}; - TypeId out_backprop_type_; - int64_t input_n_; - int64_t input_c_; - int64_t input_h_; - int64_t input_w_; - int64_t out_backprop_h_; - int64_t out_backprop_w_; -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_FRACTIONAL_MAX_POOL_GRAD_WITH_FIXED_KSIZE_CPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_FRACTIONAL_MAX_POOL_GRAD_WITH_FIXED_KSIZE_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_FRACTIONAL_MAX_POOL_GRAD_WITH_FIXED_KSIZE_CPU_KERNEL_H_ +#include +#include +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" +#include "mindspore/core/ops/grad/fractional_max_pool_grad_with_fixed_ksize.h" + +namespace mindspore { +namespace kernel { +class FractionalMaxPoolGradWithFixedKsizeCPUKernelMod : public NativeCpuKernelMod { + public: + FractionalMaxPoolGradWithFixedKsizeCPUKernelMod() = default; + ~FractionalMaxPoolGradWithFixedKsizeCPUKernelMod() override = default; + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + int Resize(const std::vector &inputs, const std::vector &outputs) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + protected: + std::vector GetOpSupport() override; + + private: + template + bool GradComputeTemplate(const std::vector &inputs, const std::vector &outputs) const; + template + void FractionalMaxPoolGradWithFixedKsizeCompute(backprop_t *out_backpropForPlane, int64_t *argmaxForPlane, + backprop_t *outputForPlane) const; + std::vector input_shape_; + std::vector out_backprop_shape_; + std::vector argmax_shape_; + std::string data_format_{"NCHW"}; + TypeId out_backprop_type_; + int64_t input_n_; + int64_t input_c_; + int64_t input_h_; + int64_t input_w_; + int64_t out_backprop_h_; + int64_t out_backprop_w_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_FRACTIONAL_MAX_POOL_GRAD_WITH_FIXED_KSIZE_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/fractional_max_pool_with_fixed_ksize_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/fractional_max_pool_with_fixed_ksize_cpu_kernel.cc index 7eed0d614ee..eefb3f2da7b 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/fractional_max_pool_with_fixed_ksize_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/fractional_max_pool_with_fixed_ksize_cpu_kernel.cc @@ -1,306 +1,306 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "plugin/device/cpu/kernel/fractional_max_pool_with_fixed_ksize_cpu_kernel.h" -#include -#include -#include -#include -#include -#include -#include -#include "plugin/device/cpu/hal/device/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -namespace { -const size_t kInputsNum = 2; -const size_t kOutputsNum = 2; -const size_t kInputIndex0 = 0; -const size_t kInputIndex1 = 1; -const size_t kOutputIndex1 = 1; -const size_t kInputDimIndexN = 0; -const size_t kInputDimIndexC = 1; -const size_t kInputDimIndexH = 2; -const size_t kInputDimIndexW = 3; -const size_t kDimSize1 = 1; -const size_t kDimSize2 = 2; -const size_t kDimSize3 = 3; -const size_t kDimSize4 = 4; -const size_t kKszieIndexH = 0; -const size_t kKszieIndexW = 1; -const size_t kOutputShapeIndexH = 0; -const size_t kOutputShapeIndexW = 1; -const size_t kRandomSimplesLastDimIndex = 2; -const int64_t kRandomSimplesThirdDimSize = 2; -const size_t kKsizeLength1 = 1; -const size_t kKsizeLength2 = 2; -const size_t kOutputShapeLength1 = 1; -const size_t kOutputShapeLength2 = 2; - -#define ADD_KERNEL(t1, t2, t3, t4) \ - KernelAttr() \ - .AddInputAttr(kNumberType##t1) \ - .AddInputAttr(kNumberType##t2) \ - .AddOutputAttr(kNumberType##t3) \ - .AddOutputAttr(kNumberType##t4) -} // namespace - -bool FractionalMaxPoolWithFixedKsizeCPUKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - constexpr size_t input_num = kInputsNum; - constexpr size_t output_num = kOutputsNum; - CHECK_KERNEL_INPUTS_NUM(inputs.size(), input_num, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), output_num, kernel_name_); - input_type_ = inputs[kInputIndex0]->dtype_id(); - random_samples_type_ = inputs[kInputIndex1]->dtype_id(); - argmax_type_ = outputs[kOutputIndex1]->dtype_id(); - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto match = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!match.first) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', does not support this kernel data type: " << kernel_attr; - return false; - } - output_shape_ = GetValue>(primitive_->GetAttr("output_shape")); - ksize_ = GetValue>(primitive_->GetAttr("ksize")); - data_format_ = GetValue(primitive_->GetAttr(ops::kFormat)); - if (data_format_ != "NCHW") { - MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the attr data_format must be NCHW."; - } - if (std::any_of(output_shape_.begin(), output_shape_.end(), [](int64_t output_shape) { return output_shape <= 0; })) { - MS_EXCEPTION(ValueError) << "For '" << kernel_name_ - << "', the output_shape should all be positive numbers, but there are negative numbers."; - } - if (output_shape_.size() == kOutputShapeLength1) { - output_h_ = output_shape_[kOutputShapeIndexH]; - output_w_ = output_shape_[kOutputShapeIndexH]; - } else if (output_shape_.size() == kOutputShapeLength2) { - output_h_ = output_shape_[kOutputShapeIndexH]; - output_w_ = output_shape_[kOutputShapeIndexW]; - } else { - MS_EXCEPTION(ValueError) << "For '" << kernel_name_ - << "', the size of attr output_shape must be equal to 1 or 2, but got " - << output_shape_.size() << "."; - } - if (ksize_.size() == kKsizeLength1) { - ksize_h_ = ksize_[kKszieIndexH]; - ksize_w_ = ksize_[kKszieIndexH]; - } else if (ksize_.size() == kKsizeLength2) { - ksize_h_ = ksize_[kKszieIndexH]; - ksize_w_ = ksize_[kKszieIndexW]; - } else { - MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the size of attr kszie must be equal to 1 or 2, but got " - << ksize_.size() << "."; - } - return true; -} - -int FractionalMaxPoolWithFixedKsizeCPUKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - int ret = KernelMod::Resize(inputs, outputs); - if (ret != KRET_OK) { - return ret; - } - input_shape_ = inputs[kInputIndex0]->GetDeviceShapeVector(); - random_samples_shape_ = inputs[kInputIndex1]->GetDeviceShapeVector(); - - input_n_ = input_shape_[kInputDimIndexN]; - input_c_ = input_shape_[kInputDimIndexC]; - input_h_ = input_shape_[kInputDimIndexH]; - input_w_ = input_shape_[kInputDimIndexW]; - - if (output_h_ + ksize_h_ - 1 > input_h_) { - MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', ksize height [" << ksize_h_ << "] + output_shape_h [" - << output_h_ << "] too large relative to input height [" << input_h_ - << "], conflict with the rule: ksize_h + output_shape_h - 1 <= input_h"; - } - if (output_w_ + ksize_w_ - 1 > input_w_) { - MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', ksize width [" << ksize_w_ << "] + output_shape_w [" - << output_w_ << "] too large relative to input width [" << input_w_ - << "], conflict with the rule: ksize_w + output_shape_w - 1 <= input_w"; - } - if (random_samples_shape_[kInputDimIndexN] != input_n_) { - MS_EXCEPTION(ValueError) << "For '" << kernel_name_ - << "', The first dim of input[x] and input[random_samples] must be equal, but " - << "got x=[" << input_n_ << "] and random_samples=[" - << random_samples_shape_[kInputDimIndexN] << "]."; - } - if (random_samples_shape_[kInputDimIndexC] != input_c_) { - MS_EXCEPTION(ValueError) << "For '" << kernel_name_ - << "', The second dim of input[x] and input[random_samples] must be equal, but " - << "got x=[" << input_c_ << "] and random_samples=[" - << random_samples_shape_[kInputDimIndexC] << "]."; - } - if (random_samples_shape_[kRandomSimplesLastDimIndex] != kRandomSimplesThirdDimSize) { - MS_EXCEPTION(ValueError) << "For '" << kernel_name_ - << "', The third dim of input[random_samples] must be 2, but got " - << random_samples_shape_[kRandomSimplesLastDimIndex] << "."; - } - return ret; -} - -bool FractionalMaxPoolWithFixedKsizeCPUKernelMod::Launch(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs) { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_); - switch (input_type_) { - case kNumberTypeFloat16: - return DoComputeWithRandomSamplesType(inputs, outputs, random_samples_type_); - case kNumberTypeFloat32: - return DoComputeWithRandomSamplesType(inputs, outputs, random_samples_type_); - case kNumberTypeFloat64: - return DoComputeWithRandomSamplesType(inputs, outputs, random_samples_type_); - case kNumberTypeInt32: - return DoComputeWithRandomSamplesType(inputs, outputs, random_samples_type_); - case kNumberTypeInt64: - return DoComputeWithRandomSamplesType(inputs, outputs, random_samples_type_); - default: - MS_EXCEPTION(TypeError) << "For '" << kernel_name_ << "', the data type of input not support."; - } - return true; -} - -template -bool FractionalMaxPoolWithFixedKsizeCPUKernelMod::DoComputeWithRandomSamplesType( - const std::vector &inputs, const std::vector &outputs, - TypeId random_samples_type_) const { - switch (random_samples_type_) { - case kNumberTypeFloat16: - return ComputeTemplate(inputs, outputs); - case kNumberTypeFloat32: - return ComputeTemplate(inputs, outputs); - case kNumberTypeFloat64: - return ComputeTemplate(inputs, outputs); - default: - MS_EXCEPTION(TypeError) << "For '" << kernel_name_ << "', random_samples_type" << random_samples_type_ - << "not support, must be in [{DT_FLOAT16, DT_FLOAT, DT_DOUBLE}]."; - } -} - -template -bool FractionalMaxPoolWithFixedKsizeCPUKernelMod::ComputeTemplate(const std::vector &inputs, - const std::vector &outputs) const { - scalar_t *input_ptr = static_cast(inputs[0]->device_ptr()); - random_sample_t *random_samples_ptr = static_cast(inputs[1]->device_ptr()); - scalar_t *output_ptr = static_cast(outputs[0]->device_ptr()); - int64_t *argmax_ptr = static_cast(outputs[1]->device_ptr()); - MS_EXCEPTION_IF_NULL(input_ptr); - MS_EXCEPTION_IF_NULL(random_samples_ptr); - MS_EXCEPTION_IF_NULL(output_ptr); - MS_EXCEPTION_IF_NULL(argmax_ptr); - - auto shard_fractional_max_pool_with_fixed_ksize = [&](size_t start, size_t end) { - for (size_t n = start; n < end; n++) { - scalar_t *inputForPlane = input_ptr + n * input_c_ * input_h_ * input_w_; - random_sample_t *random_samplesForPlane = random_samples_ptr + n * input_c_ * kRandomSimplesThirdDimSize; - scalar_t *outputForPlane = output_ptr + n * input_c_ * output_h_ * output_w_; - int64_t *argmaxForPlane = argmax_ptr + n * input_c_ * output_h_ * output_w_; - - FractionalMaxPoolWithFixedKsizeCompute(inputForPlane, random_samplesForPlane, - outputForPlane, argmaxForPlane); - } - }; - CPUKernelUtils::ParallelFor(shard_fractional_max_pool_with_fixed_ksize, LongToSize(input_n_)); - - return true; -} - -template -void FractionalMaxPoolWithFixedKsizeCPUKernelMod::FractionalMaxPoolWithFixedKsizeCompute( - scalar_t *inputForPlane, random_sample_t *random_samplesForPlane, scalar_t *outputForPlane, - int64_t *argmaxForPlane) const { - for (int64_t plane = 0; plane < input_c_; plane++) { - random_sample_t *random_samplesPlane = random_samplesForPlane + plane * 2; - std::vector sequenceW = GenerateIntervals( - random_samplesPlane[0], static_cast(input_w_), static_cast(output_w_), static_cast(ksize_w_)); - std::vector sequenceH = GenerateIntervals( - random_samplesPlane[1], static_cast(input_h_), static_cast(output_h_), static_cast(ksize_h_)); - - scalar_t *inputPlane = inputForPlane + plane * input_h_ * input_w_; - scalar_t *outputPlane = outputForPlane + plane * output_h_ * output_w_; - int64_t *argmaxPlane = argmaxForPlane + plane * output_h_ * output_w_; - - int h; - int w; - for (h = 0; h < output_h_; h++) { - int inputHStart = sequenceH[h]; - for (w = 0; w < output_w_; w++) { - int inputWStart = sequenceW[w]; - int h2 = inputHStart; - int w2 = inputWStart; - scalar_t maxValue = -std::numeric_limits::infinity(); - int64_t maxIndex = h2 * input_w_ + w2; - - for (h2 = inputHStart; h2 < inputHStart + ksize_h_; h2++) { - for (w2 = inputWStart; w2 < inputWStart + ksize_w_; w2++) { - if (h2 < 0 || h2 >= input_h_) { - MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', index H value is illegal."; - } - if (w2 < 0 || w2 >= input_w_) { - MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', index W value is illegal."; - } - - int index = h2 * input_w_ + w2; - scalar_t value = inputPlane[index]; - if (value > maxValue) { - maxValue = value; - maxIndex = index; - } - } - } - - outputPlane[h * output_w_ + w] = maxValue; - argmaxPlane[h * output_w_ + w] = maxIndex; - } - } - } -} - -template -std::vector FractionalMaxPoolWithFixedKsizeCPUKernelMod::GenerateIntervals(random_sample_t sample, int input_size, - int output_size, - int kernel_size) const { - std::vector sequence(output_size); - if (output_size > 1) { - random_sample_t alpha = - static_cast(input_size - kernel_size) / static_cast(output_size - 1); - - for (int i = 0; i < output_size - 1; i++) { - sequence[i] = - static_cast((static_cast(i) + sample) * alpha) - static_cast(sample * alpha); - } - } - sequence[output_size - 1] = input_size - kernel_size; - - return sequence; -} - -std::vector FractionalMaxPoolWithFixedKsizeCPUKernelMod::GetOpSupport() { - static std::vector kernel_attr_list = { - ADD_KERNEL(Int32, Float32, Int32, Int64), ADD_KERNEL(Int64, Float32, Int64, Int64), - ADD_KERNEL(Float16, Float32, Float16, Int64), ADD_KERNEL(Float32, Float32, Float32, Int64), - ADD_KERNEL(Float64, Float32, Float64, Int64), ADD_KERNEL(Int32, Float16, Int32, Int64), - ADD_KERNEL(Int64, Float16, Int64, Int64), ADD_KERNEL(Float16, Float16, Float16, Int64), - ADD_KERNEL(Float32, Float16, Float32, Int64), ADD_KERNEL(Float64, Float16, Float64, Int64), - ADD_KERNEL(Int32, Float64, Int32, Int64), ADD_KERNEL(Int64, Float64, Int64, Int64), - ADD_KERNEL(Float16, Float64, Float16, Int64), ADD_KERNEL(Float32, Float64, Float32, Int64), - ADD_KERNEL(Float64, Float64, Float64, Int64)}; - return kernel_attr_list; -} - -MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, FractionalMaxPoolWithFixedKsize, FractionalMaxPoolWithFixedKsizeCPUKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "plugin/device/cpu/kernel/fractional_max_pool_with_fixed_ksize_cpu_kernel.h" +#include +#include +#include +#include +#include +#include +#include +#include "plugin/device/cpu/hal/device/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +namespace { +const size_t kInputsNum = 2; +const size_t kOutputsNum = 2; +const size_t kInputIndex0 = 0; +const size_t kInputIndex1 = 1; +const size_t kOutputIndex1 = 1; +const size_t kInputDimIndexN = 0; +const size_t kInputDimIndexC = 1; +const size_t kInputDimIndexH = 2; +const size_t kInputDimIndexW = 3; +const size_t kDimSize1 = 1; +const size_t kDimSize2 = 2; +const size_t kDimSize3 = 3; +const size_t kDimSize4 = 4; +const size_t kKszieIndexH = 0; +const size_t kKszieIndexW = 1; +const size_t kOutputShapeIndexH = 0; +const size_t kOutputShapeIndexW = 1; +const size_t kRandomSimplesLastDimIndex = 2; +const int64_t kRandomSimplesThirdDimSize = 2; +const size_t kKsizeLength1 = 1; +const size_t kKsizeLength2 = 2; +const size_t kOutputShapeLength1 = 1; +const size_t kOutputShapeLength2 = 2; + +#define ADD_KERNEL(t1, t2, t3, t4) \ + KernelAttr() \ + .AddInputAttr(kNumberType##t1) \ + .AddInputAttr(kNumberType##t2) \ + .AddOutputAttr(kNumberType##t3) \ + .AddOutputAttr(kNumberType##t4) +} // namespace + +bool FractionalMaxPoolWithFixedKsizeCPUKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + constexpr size_t input_num = kInputsNum; + constexpr size_t output_num = kOutputsNum; + CHECK_KERNEL_INPUTS_NUM(inputs.size(), input_num, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), output_num, kernel_name_); + input_type_ = inputs[kInputIndex0]->dtype_id(); + random_samples_type_ = inputs[kInputIndex1]->dtype_id(); + argmax_type_ = outputs[kOutputIndex1]->dtype_id(); + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto match = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!match.first) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', does not support this kernel data type: " << kernel_attr; + return false; + } + output_shape_ = GetValue>(primitive_->GetAttr("output_shape")); + ksize_ = GetValue>(primitive_->GetAttr("ksize")); + data_format_ = GetValue(primitive_->GetAttr(ops::kFormat)); + if (data_format_ != "NCHW") { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the attr data_format must be NCHW."; + } + if (std::any_of(output_shape_.begin(), output_shape_.end(), [](int64_t output_shape) { return output_shape <= 0; })) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ + << "', the output_shape should all be positive numbers, but there are negative numbers."; + } + if (output_shape_.size() == kOutputShapeLength1) { + output_h_ = output_shape_[kOutputShapeIndexH]; + output_w_ = output_shape_[kOutputShapeIndexH]; + } else if (output_shape_.size() == kOutputShapeLength2) { + output_h_ = output_shape_[kOutputShapeIndexH]; + output_w_ = output_shape_[kOutputShapeIndexW]; + } else { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ + << "', the size of attr output_shape must be equal to 1 or 2, but got " + << output_shape_.size() << "."; + } + if (ksize_.size() == kKsizeLength1) { + ksize_h_ = ksize_[kKszieIndexH]; + ksize_w_ = ksize_[kKszieIndexH]; + } else if (ksize_.size() == kKsizeLength2) { + ksize_h_ = ksize_[kKszieIndexH]; + ksize_w_ = ksize_[kKszieIndexW]; + } else { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the size of attr kszie must be equal to 1 or 2, but got " + << ksize_.size() << "."; + } + return true; +} + +int FractionalMaxPoolWithFixedKsizeCPUKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + int ret = KernelMod::Resize(inputs, outputs); + if (ret != KRET_OK) { + return ret; + } + input_shape_ = inputs[kInputIndex0]->GetDeviceShapeVector(); + random_samples_shape_ = inputs[kInputIndex1]->GetDeviceShapeVector(); + + input_n_ = input_shape_[kInputDimIndexN]; + input_c_ = input_shape_[kInputDimIndexC]; + input_h_ = input_shape_[kInputDimIndexH]; + input_w_ = input_shape_[kInputDimIndexW]; + + if (output_h_ + ksize_h_ - 1 > input_h_) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', ksize height [" << ksize_h_ << "] + output_shape_h [" + << output_h_ << "] too large relative to input height [" << input_h_ + << "], conflict with the rule: ksize_h + output_shape_h - 1 <= input_h"; + } + if (output_w_ + ksize_w_ - 1 > input_w_) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', ksize width [" << ksize_w_ << "] + output_shape_w [" + << output_w_ << "] too large relative to input width [" << input_w_ + << "], conflict with the rule: ksize_w + output_shape_w - 1 <= input_w"; + } + if (random_samples_shape_[kInputDimIndexN] != input_n_) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ + << "', The first dim of input[x] and input[random_samples] must be equal, but " + << "got x=[" << input_n_ << "] and random_samples=[" + << random_samples_shape_[kInputDimIndexN] << "]."; + } + if (random_samples_shape_[kInputDimIndexC] != input_c_) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ + << "', The second dim of input[x] and input[random_samples] must be equal, but " + << "got x=[" << input_c_ << "] and random_samples=[" + << random_samples_shape_[kInputDimIndexC] << "]."; + } + if (random_samples_shape_[kRandomSimplesLastDimIndex] != kRandomSimplesThirdDimSize) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ + << "', The third dim of input[random_samples] must be 2, but got " + << random_samples_shape_[kRandomSimplesLastDimIndex] << "."; + } + return ret; +} + +bool FractionalMaxPoolWithFixedKsizeCPUKernelMod::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_); + switch (input_type_) { + case kNumberTypeFloat16: + return DoComputeWithRandomSamplesType(inputs, outputs, random_samples_type_); + case kNumberTypeFloat32: + return DoComputeWithRandomSamplesType(inputs, outputs, random_samples_type_); + case kNumberTypeFloat64: + return DoComputeWithRandomSamplesType(inputs, outputs, random_samples_type_); + case kNumberTypeInt32: + return DoComputeWithRandomSamplesType(inputs, outputs, random_samples_type_); + case kNumberTypeInt64: + return DoComputeWithRandomSamplesType(inputs, outputs, random_samples_type_); + default: + MS_EXCEPTION(TypeError) << "For '" << kernel_name_ << "', the data type of input not support."; + } + return true; +} + +template +bool FractionalMaxPoolWithFixedKsizeCPUKernelMod::DoComputeWithRandomSamplesType( + const std::vector &inputs, const std::vector &outputs, + TypeId random_samples_type_) const { + switch (random_samples_type_) { + case kNumberTypeFloat16: + return ComputeTemplate(inputs, outputs); + case kNumberTypeFloat32: + return ComputeTemplate(inputs, outputs); + case kNumberTypeFloat64: + return ComputeTemplate(inputs, outputs); + default: + MS_EXCEPTION(TypeError) << "For '" << kernel_name_ << "', random_samples_type" << random_samples_type_ + << "not support, must be in [{DT_FLOAT16, DT_FLOAT, DT_DOUBLE}]."; + } +} + +template +bool FractionalMaxPoolWithFixedKsizeCPUKernelMod::ComputeTemplate(const std::vector &inputs, + const std::vector &outputs) const { + scalar_t *input_ptr = static_cast(inputs[0]->device_ptr()); + random_sample_t *random_samples_ptr = static_cast(inputs[1]->device_ptr()); + scalar_t *output_ptr = static_cast(outputs[0]->device_ptr()); + int64_t *argmax_ptr = static_cast(outputs[1]->device_ptr()); + MS_EXCEPTION_IF_NULL(input_ptr); + MS_EXCEPTION_IF_NULL(random_samples_ptr); + MS_EXCEPTION_IF_NULL(output_ptr); + MS_EXCEPTION_IF_NULL(argmax_ptr); + + auto shard_fractional_max_pool_with_fixed_ksize = [&](size_t start, size_t end) { + for (size_t n = start; n < end; n++) { + scalar_t *inputForPlane = input_ptr + n * input_c_ * input_h_ * input_w_; + random_sample_t *random_samplesForPlane = random_samples_ptr + n * input_c_ * kRandomSimplesThirdDimSize; + scalar_t *outputForPlane = output_ptr + n * input_c_ * output_h_ * output_w_; + int64_t *argmaxForPlane = argmax_ptr + n * input_c_ * output_h_ * output_w_; + + FractionalMaxPoolWithFixedKsizeCompute(inputForPlane, random_samplesForPlane, + outputForPlane, argmaxForPlane); + } + }; + CPUKernelUtils::ParallelFor(shard_fractional_max_pool_with_fixed_ksize, LongToSize(input_n_)); + + return true; +} + +template +void FractionalMaxPoolWithFixedKsizeCPUKernelMod::FractionalMaxPoolWithFixedKsizeCompute( + scalar_t *inputForPlane, random_sample_t *random_samplesForPlane, scalar_t *outputForPlane, + int64_t *argmaxForPlane) const { + for (int64_t plane = 0; plane < input_c_; plane++) { + random_sample_t *random_samplesPlane = random_samplesForPlane + plane * 2; + std::vector sequenceW = GenerateIntervals( + random_samplesPlane[0], static_cast(input_w_), static_cast(output_w_), static_cast(ksize_w_)); + std::vector sequenceH = GenerateIntervals( + random_samplesPlane[1], static_cast(input_h_), static_cast(output_h_), static_cast(ksize_h_)); + + scalar_t *inputPlane = inputForPlane + plane * input_h_ * input_w_; + scalar_t *outputPlane = outputForPlane + plane * output_h_ * output_w_; + int64_t *argmaxPlane = argmaxForPlane + plane * output_h_ * output_w_; + + int h; + int w; + for (h = 0; h < output_h_; h++) { + int inputHStart = sequenceH[h]; + for (w = 0; w < output_w_; w++) { + int inputWStart = sequenceW[w]; + int h2 = inputHStart; + int w2 = inputWStart; + scalar_t maxValue = -std::numeric_limits::infinity(); + int64_t maxIndex = h2 * input_w_ + w2; + + for (h2 = inputHStart; h2 < inputHStart + ksize_h_; h2++) { + for (w2 = inputWStart; w2 < inputWStart + ksize_w_; w2++) { + if (h2 < 0 || h2 >= input_h_) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', index H value is illegal."; + } + if (w2 < 0 || w2 >= input_w_) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', index W value is illegal."; + } + + int index = h2 * input_w_ + w2; + scalar_t value = inputPlane[index]; + if (value > maxValue) { + maxValue = value; + maxIndex = index; + } + } + } + + outputPlane[h * output_w_ + w] = maxValue; + argmaxPlane[h * output_w_ + w] = maxIndex; + } + } + } +} + +template +std::vector FractionalMaxPoolWithFixedKsizeCPUKernelMod::GenerateIntervals(random_sample_t sample, int input_size, + int output_size, + int kernel_size) const { + std::vector sequence(output_size); + if (output_size > 1) { + random_sample_t alpha = + static_cast(input_size - kernel_size) / static_cast(output_size - 1); + + for (int i = 0; i < output_size - 1; i++) { + sequence[i] = + static_cast((static_cast(i) + sample) * alpha) - static_cast(sample * alpha); + } + } + sequence[output_size - 1] = input_size - kernel_size; + + return sequence; +} + +std::vector FractionalMaxPoolWithFixedKsizeCPUKernelMod::GetOpSupport() { + static std::vector kernel_attr_list = { + ADD_KERNEL(Int32, Float32, Int32, Int64), ADD_KERNEL(Int64, Float32, Int64, Int64), + ADD_KERNEL(Float16, Float32, Float16, Int64), ADD_KERNEL(Float32, Float32, Float32, Int64), + ADD_KERNEL(Float64, Float32, Float64, Int64), ADD_KERNEL(Int32, Float16, Int32, Int64), + ADD_KERNEL(Int64, Float16, Int64, Int64), ADD_KERNEL(Float16, Float16, Float16, Int64), + ADD_KERNEL(Float32, Float16, Float32, Int64), ADD_KERNEL(Float64, Float16, Float64, Int64), + ADD_KERNEL(Int32, Float64, Int32, Int64), ADD_KERNEL(Int64, Float64, Int64, Int64), + ADD_KERNEL(Float16, Float64, Float16, Int64), ADD_KERNEL(Float32, Float64, Float32, Int64), + ADD_KERNEL(Float64, Float64, Float64, Int64)}; + return kernel_attr_list; +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, FractionalMaxPoolWithFixedKsize, FractionalMaxPoolWithFixedKsizeCPUKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/fractional_max_pool_with_fixed_ksize_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/fractional_max_pool_with_fixed_ksize_cpu_kernel.h index fa263485c0c..103efc465fe 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/fractional_max_pool_with_fixed_ksize_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/fractional_max_pool_with_fixed_ksize_cpu_kernel.h @@ -1,71 +1,71 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_FRACTIONAL_MAX_POOL_WITH_FIXED_KSIZE_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_FRACTIONAL_MAX_POOL_WITH_FIXED_KSIZE_CPU_KERNEL_H_ -#include -#include -#include -#include -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "plugin/factory/ms_factory.h" -#include "mindspore/core/ops/fractional_max_pool_with_fixed_ksize.h" - -namespace mindspore { -namespace kernel { -class FractionalMaxPoolWithFixedKsizeCPUKernelMod : public NativeCpuKernelMod { - public: - FractionalMaxPoolWithFixedKsizeCPUKernelMod() = default; - ~FractionalMaxPoolWithFixedKsizeCPUKernelMod() override = default; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - int Resize(const std::vector &inputs, const std::vector &outputs) override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - protected: - std::vector GetOpSupport() override; - - private: - template - bool DoComputeWithRandomSamplesType(const std::vector &inputs, - const std::vector &outputs, TypeId random_samples_type) const; - template - bool ComputeTemplate(const std::vector &inputs, const std::vector &outputs) const; - template - void FractionalMaxPoolWithFixedKsizeCompute(scalar_t *inputForPlane, random_sample_t *random_samplesForPlane, - scalar_t *outputForPlane, int64_t *argmaxForPlane) const; - template - std::vector GenerateIntervals(random_sample_t sample, int input_size, int output_size, int kernel_size) const; - std::vector input_shape_; - std::vector random_samples_shape_; - std::vector output_shape_; - std::vector ksize_; - std::string data_format_{"NCHW"}; - TypeId input_type_; - TypeId random_samples_type_; - TypeId argmax_type_; - int64_t input_n_; - int64_t input_c_; - int64_t input_h_; - int64_t input_w_; - int64_t ksize_h_; - int64_t ksize_w_; - int64_t output_h_; - int64_t output_w_; -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_FRACTIONAL_MAX_POOL_WITH_FIXED_KSIZE_CPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_FRACTIONAL_MAX_POOL_WITH_FIXED_KSIZE_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_FRACTIONAL_MAX_POOL_WITH_FIXED_KSIZE_CPU_KERNEL_H_ +#include +#include +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" +#include "mindspore/core/ops/fractional_max_pool_with_fixed_ksize.h" + +namespace mindspore { +namespace kernel { +class FractionalMaxPoolWithFixedKsizeCPUKernelMod : public NativeCpuKernelMod { + public: + FractionalMaxPoolWithFixedKsizeCPUKernelMod() = default; + ~FractionalMaxPoolWithFixedKsizeCPUKernelMod() override = default; + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + int Resize(const std::vector &inputs, const std::vector &outputs) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + protected: + std::vector GetOpSupport() override; + + private: + template + bool DoComputeWithRandomSamplesType(const std::vector &inputs, + const std::vector &outputs, TypeId random_samples_type) const; + template + bool ComputeTemplate(const std::vector &inputs, const std::vector &outputs) const; + template + void FractionalMaxPoolWithFixedKsizeCompute(scalar_t *inputForPlane, random_sample_t *random_samplesForPlane, + scalar_t *outputForPlane, int64_t *argmaxForPlane) const; + template + std::vector GenerateIntervals(random_sample_t sample, int input_size, int output_size, int kernel_size) const; + std::vector input_shape_; + std::vector random_samples_shape_; + std::vector output_shape_; + std::vector ksize_; + std::string data_format_{"NCHW"}; + TypeId input_type_; + TypeId random_samples_type_; + TypeId argmax_type_; + int64_t input_n_; + int64_t input_c_; + int64_t input_h_; + int64_t input_w_; + int64_t ksize_h_; + int64_t ksize_w_; + int64_t output_h_; + int64_t output_w_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_FRACTIONAL_MAX_POOL_WITH_FIXED_KSIZE_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/geqrf_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/geqrf_cpu_kernel.cc index f4c1daaa85e..d5c6ee2993d 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/geqrf_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/geqrf_cpu_kernel.cc @@ -1,167 +1,167 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/cpu/kernel/geqrf_cpu_kernel.h" - -#include -#include "plugin/device/cpu/hal/device/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -namespace { -constexpr size_t kInputsNum = 1; -constexpr size_t kOutputsNum = 2; -constexpr size_t kInputIndex0 = 0; -constexpr size_t kOutputIndex0 = 0; -constexpr size_t kOutputIndex1 = 1; -constexpr int64_t kLastSecond = -2; -} // namespace - -bool GeqrfCpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_); - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(ERROR) << "For" << kernel_name_ << " does not support this kernel data type: " << kernel_attr; - return false; - } - kernel_func_ = func_list_[index].second; - return true; -} - -int GeqrfCpuKernelMod::Resize(const std::vector &inputs, const std::vector &outputs) { - if (int ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { - return ret; - } - std::vector input0_tensor_shape = inputs[0]->GetShapeVector(); - MS_EXCEPTION_IF_CHECK_FAIL(!input0_tensor_shape.empty(), "For Geqrf, input0_tensor_shape should not be empty."); - elem_num = static_cast( - std::accumulate(input0_tensor_shape.begin(), input0_tensor_shape.end(), 1, std::multiplies())); - num_m = static_cast(input0_tensor_shape.end()[kLastSecond]); - num_n = static_cast(input0_tensor_shape.back()); - batch_num = elem_num / (num_m * num_n); - return KRET_OK; -} - -template -void GeqrfCpuKernelMod::Larfg(size_t n, size_t vm, size_t vn, T *x, T *tau) { - T zero = static_cast(0); - if (n <= 1) { - *tau = zero; - return; - } - T xnorm = zero; - for (size_t i = vm + 1; i < vm + n; i++) { - xnorm = xnorm + (*(x + i * num_n + vn) * *(x + i * num_n + vn)); - } - xnorm = static_cast(sqrt(xnorm)); - if (xnorm == zero) { - *tau = zero; - return; - } else { - T beta = sqrt((*(x + vm * num_n + vn) * *(x + vm * num_n + vn)) + xnorm * xnorm); - if (*(x + vm * num_n + vn) > zero) { - beta = -beta; - } - if (beta == zero) { - return; - } - *tau = (beta - *(x + vm * num_n + vn)) / beta; - auto scal = *(x + vm * num_n + vn) - beta; - for (size_t i = vm + 1; i < vm + n; i++) { - *(x + i * num_n + vn) /= scal; - } - *(x + vm * num_n + vn) = beta; - } -} - -template -std::unique_ptr GeqrfCpuKernelMod::Larf(size_t m, size_t n, T *x, T *tau, std::unique_ptr workspace, - size_t cm, size_t cn) { - if (m <= 0 || n <= 0) { - return std::move(workspace); - } - for (size_t i = 0; i < n; i++) { - workspace[i] = static_cast(0); - } - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < n; j++) { - workspace[j] += *(x + ((cm + i) * num_n) + (cn - 1)) * *(x + ((cm + i) * num_n) + (cn + j)); - } - } - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < n; j++) { - *(x + ((cm + i) * num_n) + (cn + j)) -= (*tau) * *(x + ((cm + i) * num_n) + (cn - 1)) * workspace[j]; - } - } - return std::move(workspace); -} - -template -void GeqrfCpuKernelMod::Geqrf(size_t num_m_, size_t num_n_, T *x, T *tau) { - size_t k = std::min(num_m_, num_n_); - T one = static_cast(1); - auto x_origin = x; - auto tau_origin = tau; - auto geqrf_shard = [&](size_t start, size_t end) { - std::unique_ptr workspace = std::make_unique(num_n_); - for (size_t batch = start; batch < end; ++batch) { - x = x_origin + batch * num_m_ * num_n_; - tau = tau_origin + batch * k; - for (size_t i = 0; i < k; i++) { - Larfg(num_m_ - i, i, i, x, tau + i); - T aii = *(x + i * num_n_ + i); - *(x + i * num_n_ + i) = one; - workspace = Larf(num_m_ - i, num_n_ - i - 1, x, tau + i, std::move(workspace), i, i + 1); - *(x + i * num_n_ + i) = aii; - } - } - }; - ParallelLaunchAutoSearch(geqrf_shard, batch_num, this, ¶llel_search_info_); -} - -template -bool GeqrfCpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &outputs) { - MS_EXCEPTION_IF_NULL(inputs[kInputIndex0]); - MS_EXCEPTION_IF_NULL(outputs[kOutputIndex0]); - MS_EXCEPTION_IF_NULL(outputs[kOutputIndex1]); - T *x = static_cast(inputs[kInputIndex0]->device_ptr()); - T *y = static_cast(outputs[kOutputIndex0]->device_ptr()); - T *tau = static_cast(outputs[kOutputIndex1]->device_ptr()); - std::copy(x, x + elem_num, y); - Geqrf(num_m, num_n, y, tau); - return true; -} - -std::vector> GeqrfCpuKernelMod::func_list_ = { - {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - &GeqrfCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), - &GeqrfCpuKernelMod::LaunchKernel}}; - -std::vector GeqrfCpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - - return support_list; -} - -MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Geqrf, GeqrfCpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/cpu/kernel/geqrf_cpu_kernel.h" + +#include +#include "plugin/device/cpu/hal/device/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kInputsNum = 1; +constexpr size_t kOutputsNum = 2; +constexpr size_t kInputIndex0 = 0; +constexpr size_t kOutputIndex0 = 0; +constexpr size_t kOutputIndex1 = 1; +constexpr int64_t kLastSecond = -2; +} // namespace + +bool GeqrfCpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_); + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For" << kernel_name_ << " does not support this kernel data type: " << kernel_attr; + return false; + } + kernel_func_ = func_list_[index].second; + return true; +} + +int GeqrfCpuKernelMod::Resize(const std::vector &inputs, const std::vector &outputs) { + if (int ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { + return ret; + } + std::vector input0_tensor_shape = inputs[0]->GetShapeVector(); + MS_EXCEPTION_IF_CHECK_FAIL(!input0_tensor_shape.empty(), "For Geqrf, input0_tensor_shape should not be empty."); + elem_num = static_cast( + std::accumulate(input0_tensor_shape.begin(), input0_tensor_shape.end(), 1, std::multiplies())); + num_m = static_cast(input0_tensor_shape.end()[kLastSecond]); + num_n = static_cast(input0_tensor_shape.back()); + batch_num = elem_num / (num_m * num_n); + return KRET_OK; +} + +template +void GeqrfCpuKernelMod::Larfg(size_t n, size_t vm, size_t vn, T *x, T *tau) { + T zero = static_cast(0); + if (n <= 1) { + *tau = zero; + return; + } + T xnorm = zero; + for (size_t i = vm + 1; i < vm + n; i++) { + xnorm = xnorm + (*(x + i * num_n + vn) * *(x + i * num_n + vn)); + } + xnorm = static_cast(sqrt(xnorm)); + if (xnorm == zero) { + *tau = zero; + return; + } else { + T beta = sqrt((*(x + vm * num_n + vn) * *(x + vm * num_n + vn)) + xnorm * xnorm); + if (*(x + vm * num_n + vn) > zero) { + beta = -beta; + } + if (beta == zero) { + return; + } + *tau = (beta - *(x + vm * num_n + vn)) / beta; + auto scal = *(x + vm * num_n + vn) - beta; + for (size_t i = vm + 1; i < vm + n; i++) { + *(x + i * num_n + vn) /= scal; + } + *(x + vm * num_n + vn) = beta; + } +} + +template +std::unique_ptr GeqrfCpuKernelMod::Larf(size_t m, size_t n, T *x, T *tau, std::unique_ptr workspace, + size_t cm, size_t cn) { + if (m <= 0 || n <= 0) { + return std::move(workspace); + } + for (size_t i = 0; i < n; i++) { + workspace[i] = static_cast(0); + } + for (size_t i = 0; i < m; i++) { + for (size_t j = 0; j < n; j++) { + workspace[j] += *(x + ((cm + i) * num_n) + (cn - 1)) * *(x + ((cm + i) * num_n) + (cn + j)); + } + } + for (size_t i = 0; i < m; i++) { + for (size_t j = 0; j < n; j++) { + *(x + ((cm + i) * num_n) + (cn + j)) -= (*tau) * *(x + ((cm + i) * num_n) + (cn - 1)) * workspace[j]; + } + } + return std::move(workspace); +} + +template +void GeqrfCpuKernelMod::Geqrf(size_t num_m_, size_t num_n_, T *x, T *tau) { + size_t k = std::min(num_m_, num_n_); + T one = static_cast(1); + auto x_origin = x; + auto tau_origin = tau; + auto geqrf_shard = [&](size_t start, size_t end) { + std::unique_ptr workspace = std::make_unique(num_n_); + for (size_t batch = start; batch < end; ++batch) { + x = x_origin + batch * num_m_ * num_n_; + tau = tau_origin + batch * k; + for (size_t i = 0; i < k; i++) { + Larfg(num_m_ - i, i, i, x, tau + i); + T aii = *(x + i * num_n_ + i); + *(x + i * num_n_ + i) = one; + workspace = Larf(num_m_ - i, num_n_ - i - 1, x, tau + i, std::move(workspace), i, i + 1); + *(x + i * num_n_ + i) = aii; + } + } + }; + ParallelLaunchAutoSearch(geqrf_shard, batch_num, this, ¶llel_search_info_); +} + +template +bool GeqrfCpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { + MS_EXCEPTION_IF_NULL(inputs[kInputIndex0]); + MS_EXCEPTION_IF_NULL(outputs[kOutputIndex0]); + MS_EXCEPTION_IF_NULL(outputs[kOutputIndex1]); + T *x = static_cast(inputs[kInputIndex0]->device_ptr()); + T *y = static_cast(outputs[kOutputIndex0]->device_ptr()); + T *tau = static_cast(outputs[kOutputIndex1]->device_ptr()); + std::copy(x, x + elem_num, y); + Geqrf(num_m, num_n, y, tau); + return true; +} + +std::vector> GeqrfCpuKernelMod::func_list_ = { + {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + &GeqrfCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + &GeqrfCpuKernelMod::LaunchKernel}}; + +std::vector GeqrfCpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Geqrf, GeqrfCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/geqrf_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/geqrf_cpu_kernel.h index e90c8aa4d99..c3648787611 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/geqrf_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/geqrf_cpu_kernel.h @@ -1,73 +1,73 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GEQRF_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GEQRF_KERNEL_H_ - -#include -#include -#include -#include -#include -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore { -namespace kernel { -class GeqrfCpuKernelMod : public NativeCpuKernelMod { - public: - GeqrfCpuKernelMod() = default; - ~GeqrfCpuKernelMod() override = default; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override { - return kernel_func_(this, inputs, outputs); - }; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - protected: - std::vector GetOpSupport() override; - - private: - template - bool LaunchKernel(const std::vector &inputs, const std::vector &outputs); - - using GeqrfLaunchFunc = std::function &, - const std::vector &)>; - - static std::vector> func_list_; - GeqrfLaunchFunc kernel_func_; - - template - void Larfg(size_t n, size_t vm, size_t vn, T *x, T *tau); - - template - std::unique_ptr Larf(size_t m, size_t n, T *x, T *tau, std::unique_ptr workspace, size_t cm, size_t cn); - - template - void Geqrf(size_t m, size_t n, T *x, T *tau); - size_t num_m = 0; - size_t num_n = 0; - size_t elem_num; - size_t batch_num; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GEQRF_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GEQRF_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GEQRF_KERNEL_H_ + +#include +#include +#include +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class GeqrfCpuKernelMod : public NativeCpuKernelMod { + public: + GeqrfCpuKernelMod() = default; + ~GeqrfCpuKernelMod() override = default; + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override { + return kernel_func_(this, inputs, outputs); + }; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + protected: + std::vector GetOpSupport() override; + + private: + template + bool LaunchKernel(const std::vector &inputs, const std::vector &outputs); + + using GeqrfLaunchFunc = std::function &, + const std::vector &)>; + + static std::vector> func_list_; + GeqrfLaunchFunc kernel_func_; + + template + void Larfg(size_t n, size_t vm, size_t vn, T *x, T *tau); + + template + std::unique_ptr Larf(size_t m, size_t n, T *x, T *tau, std::unique_ptr workspace, size_t cm, size_t cn); + + template + void Geqrf(size_t m, size_t n, T *x, T *tau); + size_t num_m = 0; + size_t num_n = 0; + size_t elem_num; + size_t batch_num; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GEQRF_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/glu_grad_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/glu_grad_cpu_kernel.cc index 73c16a19292..c50bcdfa6bd 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/glu_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/glu_grad_cpu_kernel.cc @@ -1,136 +1,136 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "plugin/device/cpu/kernel/glu_grad_cpu_kernel.h" -#include -#include -#include "plugin/device/cpu/hal/device/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -namespace { -const int64_t kEvenNum = 2; -} // namespace - -int GluGradCpuKernelMod::Resize(const std::vector &inputs, const std::vector &outputs) { - if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { - return ret; - } - dtype_ = inputs[kIndex0]->dtype_id(); - auto axis_value = GetValue(primitive_->GetAttr("axis")); - grad_shape_ = inputs[kIndex0]->GetShapeVector(); - x_shape_ = inputs[kIndex1]->GetShapeVector(); - - int64_t rank = SizeToLong(x_shape_.size()); - if (axis_value < -rank || axis_value >= rank) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'axis' must be in range [" << -rank << ", " << rank - << "), but got " << axis_value << "."; - } - if (axis_value < 0) { - axis_ = axis_value + rank; - } else { - axis_ = axis_value; - } - - if (x_shape_[axis_] % kEvenNum != 0) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', x.shape[" << axis_value << "] must be even, but got " - << x_shape_[axis_] << "."; - } - - auto expected_grad_shape = x_shape_; - expected_grad_shape[axis_] /= kEvenNum; - if (grad_shape_ != expected_grad_shape) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ - << "', x.shape must be euqal to grad.shape except for grad.shape[axis]=x.shape[axis]" - "/2, but got axis=" - << axis_value << ", x.shape=" << x_shape_ << " and grad.shape=" << grad_shape_ << "."; - } - return KRET_OK; -} - -bool GluGradCpuKernelMod::Launch(const std::vector &inputs, - const std::vector &, - const std::vector &outputs) { - if (dtype_ == kNumberTypeFloat32) { - LaunchKernel(inputs, outputs); - } else if (dtype_ == kNumberTypeFloat64) { - LaunchKernel(inputs, outputs); - } else if (dtype_ == kNumberTypeFloat16) { - LaunchKernel(inputs, outputs); - } else { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dtype of x must be float16, float32 or float64, but got " - << TypeIdLabel(dtype_) << "."; - } - return true; -} - -template -void GluGradCpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &outputs) { - const auto *input0 = static_cast(inputs[0]->device_ptr()); - const auto *input1 = static_cast(inputs[1]->device_ptr()); - auto *output = static_cast(outputs[0]->device_ptr()); - std::vector shape = x_shape_; - int64_t dim = axis_; - size_t lens = outputs[0]->size() > 0 ? outputs[0]->size() / sizeof(T) : 1; - auto task = [&input0, &input1, &output, &shape, &dim](const size_t start, const size_t end) { - int64_t input_num = std::accumulate(shape.cbegin(), shape.cend(), 1, std::multiplies()); - int num = input_num; - for (int64_t m = 0; m <= dim; m++) { - if (m < dim) { - num = num / shape[m]; - } else if (m == dim) { - num = num / 2; - } - } - int64_t n_m = 1; - int64_t size_m = 0; - int64_t grad_offset_b = 0; - int64_t grad_offset_a = 0; - for (int i = 0; i < input_num; i++) { - if (n_m % 2 != 0) { - *(output + i) = (T(1.0) / (T(1.0) + exp(-(*(input1 + (i + num)))))) * (*(input0 + grad_offset_b)); - grad_offset_b += 1; - size_m = size_m + 1; - if (size_m == num) { - n_m += 1; - size_m = 0; - } - } else { - *(output + i) = *(input1 + (i - num)) * (T(1.0) / (T(1.0) + exp(-(*(input1 + i))))) * - (T(1.0) - (T(1.0) / (T(1.0) + exp(-(*(input1 + i)))))) * (*(input0 + grad_offset_a)); - grad_offset_a += 1; - size_m = size_m + 1; - if (size_m == num) { - n_m += 1; - size_m = 0; - } - } - } - }; - ParallelLaunchAutoSearch(task, lens, this, ¶llel_search_info_); -} - -std::vector GluGradCpuKernelMod::GetOpSupport() { - std::vector support_list = { - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64)}; - return support_list; -} - -MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, GluGrad, GluGradCpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "plugin/device/cpu/kernel/glu_grad_cpu_kernel.h" +#include +#include +#include "plugin/device/cpu/hal/device/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +namespace { +const int64_t kEvenNum = 2; +} // namespace + +int GluGradCpuKernelMod::Resize(const std::vector &inputs, const std::vector &outputs) { + if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { + return ret; + } + dtype_ = inputs[kIndex0]->dtype_id(); + auto axis_value = GetValue(primitive_->GetAttr("axis")); + grad_shape_ = inputs[kIndex0]->GetShapeVector(); + x_shape_ = inputs[kIndex1]->GetShapeVector(); + + int64_t rank = SizeToLong(x_shape_.size()); + if (axis_value < -rank || axis_value >= rank) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'axis' must be in range [" << -rank << ", " << rank + << "), but got " << axis_value << "."; + } + if (axis_value < 0) { + axis_ = axis_value + rank; + } else { + axis_ = axis_value; + } + + if (x_shape_[axis_] % kEvenNum != 0) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', x.shape[" << axis_value << "] must be even, but got " + << x_shape_[axis_] << "."; + } + + auto expected_grad_shape = x_shape_; + expected_grad_shape[axis_] /= kEvenNum; + if (grad_shape_ != expected_grad_shape) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ + << "', x.shape must be euqal to grad.shape except for grad.shape[axis]=x.shape[axis]" + "/2, but got axis=" + << axis_value << ", x.shape=" << x_shape_ << " and grad.shape=" << grad_shape_ << "."; + } + return KRET_OK; +} + +bool GluGradCpuKernelMod::Launch(const std::vector &inputs, + const std::vector &, + const std::vector &outputs) { + if (dtype_ == kNumberTypeFloat32) { + LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeFloat64) { + LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeFloat16) { + LaunchKernel(inputs, outputs); + } else { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dtype of x must be float16, float32 or float64, but got " + << TypeIdLabel(dtype_) << "."; + } + return true; +} + +template +void GluGradCpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { + const auto *input0 = static_cast(inputs[0]->device_ptr()); + const auto *input1 = static_cast(inputs[1]->device_ptr()); + auto *output = static_cast(outputs[0]->device_ptr()); + std::vector shape = x_shape_; + int64_t dim = axis_; + size_t lens = outputs[0]->size() > 0 ? outputs[0]->size() / sizeof(T) : 1; + auto task = [&input0, &input1, &output, &shape, &dim](const size_t start, const size_t end) { + int64_t input_num = std::accumulate(shape.cbegin(), shape.cend(), 1, std::multiplies()); + int num = input_num; + for (int64_t m = 0; m <= dim; m++) { + if (m < dim) { + num = num / shape[m]; + } else if (m == dim) { + num = num / 2; + } + } + int64_t n_m = 1; + int64_t size_m = 0; + int64_t grad_offset_b = 0; + int64_t grad_offset_a = 0; + for (int i = 0; i < input_num; i++) { + if (n_m % 2 != 0) { + *(output + i) = (T(1.0) / (T(1.0) + exp(-(*(input1 + (i + num)))))) * (*(input0 + grad_offset_b)); + grad_offset_b += 1; + size_m = size_m + 1; + if (size_m == num) { + n_m += 1; + size_m = 0; + } + } else { + *(output + i) = *(input1 + (i - num)) * (T(1.0) / (T(1.0) + exp(-(*(input1 + i))))) * + (T(1.0) - (T(1.0) / (T(1.0) + exp(-(*(input1 + i)))))) * (*(input0 + grad_offset_a)); + grad_offset_a += 1; + size_m = size_m + 1; + if (size_m == num) { + n_m += 1; + size_m = 0; + } + } + } + }; + ParallelLaunchAutoSearch(task, lens, this, ¶llel_search_info_); +} + +std::vector GluGradCpuKernelMod::GetOpSupport() { + std::vector support_list = { + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64)}; + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, GluGrad, GluGradCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/glu_grad_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/glu_grad_cpu_kernel.h index 05ee67cf990..c4311efd27e 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/glu_grad_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/glu_grad_cpu_kernel.h @@ -1,54 +1,54 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GLU_GRAD_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GLU_GRAD_KERNEL_H_ - -#include -#include -#include -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore { -namespace kernel { -class GluGradCpuKernelMod : public NativeCpuKernelMod { - public: - GluGradCpuKernelMod() = default; - ~GluGradCpuKernelMod() override = default; - bool Init(const std::vector &inputs, const std::vector &outputs) override { - return true; - } - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - protected: - std::vector GetOpSupport() override; - - private: - template - void LaunchKernel(const std::vector &inputs, const std::vector &outputs); - int64_t axis_{1}; - std::vector grad_shape_; - std::vector x_shape_; - TypeId dtype_{kTypeUnknown}; -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GLU_GRAD_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GLU_GRAD_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GLU_GRAD_KERNEL_H_ + +#include +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class GluGradCpuKernelMod : public NativeCpuKernelMod { + public: + GluGradCpuKernelMod() = default; + ~GluGradCpuKernelMod() override = default; + bool Init(const std::vector &inputs, const std::vector &outputs) override { + return true; + } + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + protected: + std::vector GetOpSupport() override; + + private: + template + void LaunchKernel(const std::vector &inputs, const std::vector &outputs); + int64_t axis_{1}; + std::vector grad_shape_; + std::vector x_shape_; + TypeId dtype_{kTypeUnknown}; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GLU_GRAD_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/hamming_window_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/hamming_window_cpu_kernel.cc index 2addb0c1c2b..9368da41e4e 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/hamming_window_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/hamming_window_cpu_kernel.cc @@ -1,139 +1,139 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/cpu/kernel/hamming_window_cpu_kernel.h" -#include -#include -#include -#include -#include "mindspore/core/ops/hamming_window.h" -#include "plugin/device/cpu/hal/device/cpu_device_address.h" -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "plugin/device/cpu/kernel/arithmetic_cpu_kernel.h" - -namespace mindspore { -namespace kernel { -namespace { -const size_t kHammingWindowOutputNum = 1; -const size_t kHammingWindowInputNum = 1; -} // namespace - -bool HammingWindowCpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kHammingWindowInputNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kHammingWindowOutputNum, kernel_name_); - periodic_ = GetValue(primitive_->GetAttr(ops::kPeriodic)); - - alpha_ = GetValue(primitive_->GetAttr(ops::kAlpha)); - beta_ = GetValue(primitive_->GetAttr(ops::kBeta)); - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(ERROR) << "HammingWindow does not support this kernel data type: " << kernel_attr; - return false; - } - kernel_func_ = func_list_[index].second; - return true; -} - -template -bool HammingWindowCpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector & /* workspace */, - const std::vector &outputs) const { - auto *length_addr = static_cast(inputs[0]->device_ptr()); - auto *output = static_cast(outputs[0]->device_ptr()); - int64_t window_length_ = static_cast(*length_addr); - if (window_length_ < 0) { - MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the value of input 'length' cannot be negative, but got " - << window_length_; - } else if (window_length_ == 0) { - return true; - } else if (window_length_ == 1) { - *output = S{1}; - return true; - } - int64_t length = periodic_ ? window_length_ : (window_length_ - 1); - constexpr double t_pi = 6.283185307179586476925286766559; - auto func = [length, alpha = alpha_, beta = beta_, t_pi, &output](int64_t start, int64_t end) { - for (int64_t i = start; i < end; i++) { - double result = alpha - beta * std::cos(i * t_pi / length); - output[i] = static_cast(result); - } - }; - ParallelLaunch(func, LongToSize(window_length_)); - return true; -} - -std::vector> HammingWindowCpuKernelMod::func_list_ = - {{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat16), - &HammingWindowCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat16), - &HammingWindowCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), - &HammingWindowCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), - &HammingWindowCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat16), - &HammingWindowCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat16), - &HammingWindowCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat16), - &HammingWindowCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat16), - &HammingWindowCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat32), - &HammingWindowCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat32), - &HammingWindowCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), - &HammingWindowCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), - &HammingWindowCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat32), - &HammingWindowCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat32), - &HammingWindowCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat32), - &HammingWindowCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat32), - &HammingWindowCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat64), - &HammingWindowCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat64), - &HammingWindowCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64), - &HammingWindowCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64), - &HammingWindowCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat64), - &HammingWindowCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat64), - &HammingWindowCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat64), - &HammingWindowCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat64), - &HammingWindowCpuKernelMod::LaunchKernel}}; - -std::vector HammingWindowCpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &item) { return item.first; }); - return support_list; -} - -MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, HammingWindow, HammingWindowCpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/cpu/kernel/hamming_window_cpu_kernel.h" +#include +#include +#include +#include +#include "mindspore/core/ops/hamming_window.h" +#include "plugin/device/cpu/hal/device/cpu_device_address.h" +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/device/cpu/kernel/arithmetic_cpu_kernel.h" + +namespace mindspore { +namespace kernel { +namespace { +const size_t kHammingWindowOutputNum = 1; +const size_t kHammingWindowInputNum = 1; +} // namespace + +bool HammingWindowCpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kHammingWindowInputNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kHammingWindowOutputNum, kernel_name_); + periodic_ = GetValue(primitive_->GetAttr(ops::kPeriodic)); + + alpha_ = GetValue(primitive_->GetAttr(ops::kAlpha)); + beta_ = GetValue(primitive_->GetAttr(ops::kBeta)); + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "HammingWindow does not support this kernel data type: " << kernel_attr; + return false; + } + kernel_func_ = func_list_[index].second; + return true; +} + +template +bool HammingWindowCpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector & /* workspace */, + const std::vector &outputs) const { + auto *length_addr = static_cast(inputs[0]->device_ptr()); + auto *output = static_cast(outputs[0]->device_ptr()); + int64_t window_length_ = static_cast(*length_addr); + if (window_length_ < 0) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the value of input 'length' cannot be negative, but got " + << window_length_; + } else if (window_length_ == 0) { + return true; + } else if (window_length_ == 1) { + *output = S{1}; + return true; + } + int64_t length = periodic_ ? window_length_ : (window_length_ - 1); + constexpr double t_pi = 6.283185307179586476925286766559; + auto func = [length, alpha = alpha_, beta = beta_, t_pi, &output](int64_t start, int64_t end) { + for (int64_t i = start; i < end; i++) { + double result = alpha - beta * std::cos(i * t_pi / length); + output[i] = static_cast(result); + } + }; + ParallelLaunch(func, LongToSize(window_length_)); + return true; +} + +std::vector> HammingWindowCpuKernelMod::func_list_ = + {{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat16), + &HammingWindowCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat16), + &HammingWindowCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), + &HammingWindowCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), + &HammingWindowCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat16), + &HammingWindowCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat16), + &HammingWindowCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat16), + &HammingWindowCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat16), + &HammingWindowCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat32), + &HammingWindowCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat32), + &HammingWindowCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + &HammingWindowCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), + &HammingWindowCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat32), + &HammingWindowCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat32), + &HammingWindowCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat32), + &HammingWindowCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat32), + &HammingWindowCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat64), + &HammingWindowCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat64), + &HammingWindowCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64), + &HammingWindowCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64), + &HammingWindowCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat64), + &HammingWindowCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat64), + &HammingWindowCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat64), + &HammingWindowCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat64), + &HammingWindowCpuKernelMod::LaunchKernel}}; + +std::vector HammingWindowCpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &item) { return item.first; }); + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, HammingWindow, HammingWindowCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/hamming_window_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/hamming_window_cpu_kernel.h index 3807c09107f..f6bc0044de0 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/hamming_window_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/hamming_window_cpu_kernel.h @@ -1,58 +1,58 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_HAMMING_WINDOW_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_HAMMING_WINDOW_CPU_KERNEL_H_ - -#include -#include -#include -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore { -namespace kernel { -class HammingWindowCpuKernelMod : public NativeCpuKernelMod { - public: - HammingWindowCpuKernelMod() = default; - ~HammingWindowCpuKernelMod() override = default; - bool Init(const std::vector &inputs, const std::vector &outputs); - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override { - return kernel_func_(this, inputs, workspace, outputs); - } - - protected: - std::vector GetOpSupport() override; - - private: - template - bool LaunchKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs) const; - using HammingWindowFunc = - std::function &, - const std::vector &, const std::vector &)>; - static std::vector> func_list_; - HammingWindowFunc kernel_func_; - - bool periodic_; - float alpha_, beta_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_HAMMING_WINDOW_CPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_HAMMING_WINDOW_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_HAMMING_WINDOW_CPU_KERNEL_H_ + +#include +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class HammingWindowCpuKernelMod : public NativeCpuKernelMod { + public: + HammingWindowCpuKernelMod() = default; + ~HammingWindowCpuKernelMod() override = default; + bool Init(const std::vector &inputs, const std::vector &outputs); + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override { + return kernel_func_(this, inputs, workspace, outputs); + } + + protected: + std::vector GetOpSupport() override; + + private: + template + bool LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) const; + using HammingWindowFunc = + std::function &, + const std::vector &, const std::vector &)>; + static std::vector> func_list_; + HammingWindowFunc kernel_func_; + + bool periodic_; + float alpha_, beta_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_HAMMING_WINDOW_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/lstsq_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/lstsq_cpu_kernel.cc index 5794f22d0f9..0efe81b31ad 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/lstsq_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/lstsq_cpu_kernel.cc @@ -1,132 +1,132 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/cpu/kernel/lstsq_cpu_kernel.h" -#include -#include "plugin/device/cpu/hal/device/cpu_device_address.h" -#include "kernel/common_utils.h" - -namespace mindspore { -namespace kernel { -namespace { -constexpr size_t kLstsqInputsNum = 2; -constexpr size_t kLstsqOutputsNum = 1; -constexpr size_t kXDimNum = 2; -constexpr size_t kADimNum_1 = 1; -constexpr size_t kADimNum_2 = 2; -} // namespace - -bool LstsqCpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kLstsqInputsNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kLstsqOutputsNum, kernel_name_); - - dtype_0_ = inputs.at(kIndex0)->dtype_id(); - dtype_1_ = inputs.at(kIndex1)->dtype_id(); - if (dtype_0_ != dtype_1_) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', input's dtypes are not the same."; - return false; - } - - return true; -} - -int LstsqCpuKernelMod::Resize(const std::vector &inputs, const std::vector &outputs) { - if (int ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { - return ret; - } - - input_0_shape_ = inputs[kIndex0]->GetDeviceShapeVector(); - input_1_shape_ = inputs[kIndex1]->GetDeviceShapeVector(); - if (input_0_shape_.size() != kXDimNum) { - MS_LOG(ERROR) << "For '" << kernel_name_ - << "', the input x tensor's rank must be 2 for 'Lstsq' Op, but x tensor's rank is " - << input_0_shape_.size(); - return KRET_RESIZE_FAILED; - } - if (input_1_shape_.size() != kADimNum_2 && input_1_shape_.size() != kADimNum_1) { - MS_LOG(ERROR) << "For '" << kernel_name_ - << "', the input a tensor's rank must be 2 or 1 for 'Lstsq' Op, but a tensor's rank is " - << input_1_shape_.size(); - return KRET_RESIZE_FAILED; - } - if (input_0_shape_[0] != input_1_shape_[0]) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', the length of x_dim[0]: " << input_0_shape_[0] - << " is not equal to the length of a_dims[0]: " << input_1_shape_[0] << "."; - return KRET_RESIZE_FAILED; - } - return KRET_OK; -} - -bool LstsqCpuKernelMod::Launch(const std::vector &inputs, - const std::vector &, - const std::vector &outputs) { - if (dtype_0_ == kNumberTypeFloat16) { - LaunchKernel(inputs, outputs); - } else if (dtype_0_ == kNumberTypeFloat32) { - LaunchKernel(inputs, outputs); - } else if (dtype_0_ == kNumberTypeFloat64) { - LaunchKernel(inputs, outputs); - } else { - MS_LOG(EXCEPTION) << "Unsupported input data type."; - } - return true; -} - -template -void LstsqCpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &outputs) { - auto input_0_addr = reinterpret_cast(inputs[0]->device_ptr()); - auto input_1_addr = reinterpret_cast(inputs[1]->device_ptr()); - auto output_addr = reinterpret_cast(outputs[0]->device_ptr()); - size_t m = static_cast(input_0_shape_[0]); - size_t n = static_cast(input_0_shape_[1]); - size_t k = 0; - if (input_1_shape_.size() == kADimNum_1) { - k = 1; - } else { - k = static_cast(input_1_shape_[1]); - } - - typedef Eigen::Matrix MartixXd; // NOLINT - MartixXd A(m, n); - MartixXd B(m, k); - for (size_t i = 0; i < m * n; i++) { - A.data()[i] = static_cast(input_0_addr[i]); - } - for (size_t i = 0; i < m * k; i++) { - B.data()[i] = static_cast(input_1_addr[i]); - } - MartixXd result; - if (m >= n) { - result = A.colPivHouseholderQr().solve(B); - } else { - MartixXd A_Transpose = A.transpose(); - MartixXd temp = A * A_Transpose; - MartixXd tempI = temp.inverse(); - MartixXd x = A_Transpose * tempI; - MartixXd output = x * B; - result = output; - } - for (size_t i = 0; i < n; i++) { - for (size_t j = 0; j < k; j++) { - *(output_addr + i * k + j) = static_cast(result(i, j)); // NOLINT - } - } -} - -MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Lstsq, LstsqCpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/cpu/kernel/lstsq_cpu_kernel.h" +#include +#include "plugin/device/cpu/hal/device/cpu_device_address.h" +#include "kernel/common_utils.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kLstsqInputsNum = 2; +constexpr size_t kLstsqOutputsNum = 1; +constexpr size_t kXDimNum = 2; +constexpr size_t kADimNum_1 = 1; +constexpr size_t kADimNum_2 = 2; +} // namespace + +bool LstsqCpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kLstsqInputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kLstsqOutputsNum, kernel_name_); + + dtype_0_ = inputs.at(kIndex0)->dtype_id(); + dtype_1_ = inputs.at(kIndex1)->dtype_id(); + if (dtype_0_ != dtype_1_) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', input's dtypes are not the same."; + return false; + } + + return true; +} + +int LstsqCpuKernelMod::Resize(const std::vector &inputs, const std::vector &outputs) { + if (int ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { + return ret; + } + + input_0_shape_ = inputs[kIndex0]->GetDeviceShapeVector(); + input_1_shape_ = inputs[kIndex1]->GetDeviceShapeVector(); + if (input_0_shape_.size() != kXDimNum) { + MS_LOG(ERROR) << "For '" << kernel_name_ + << "', the input x tensor's rank must be 2 for 'Lstsq' Op, but x tensor's rank is " + << input_0_shape_.size(); + return KRET_RESIZE_FAILED; + } + if (input_1_shape_.size() != kADimNum_2 && input_1_shape_.size() != kADimNum_1) { + MS_LOG(ERROR) << "For '" << kernel_name_ + << "', the input a tensor's rank must be 2 or 1 for 'Lstsq' Op, but a tensor's rank is " + << input_1_shape_.size(); + return KRET_RESIZE_FAILED; + } + if (input_0_shape_[0] != input_1_shape_[0]) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', the length of x_dim[0]: " << input_0_shape_[0] + << " is not equal to the length of a_dims[0]: " << input_1_shape_[0] << "."; + return KRET_RESIZE_FAILED; + } + return KRET_OK; +} + +bool LstsqCpuKernelMod::Launch(const std::vector &inputs, + const std::vector &, + const std::vector &outputs) { + if (dtype_0_ == kNumberTypeFloat16) { + LaunchKernel(inputs, outputs); + } else if (dtype_0_ == kNumberTypeFloat32) { + LaunchKernel(inputs, outputs); + } else if (dtype_0_ == kNumberTypeFloat64) { + LaunchKernel(inputs, outputs); + } else { + MS_LOG(EXCEPTION) << "Unsupported input data type."; + } + return true; +} + +template +void LstsqCpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { + auto input_0_addr = reinterpret_cast(inputs[0]->device_ptr()); + auto input_1_addr = reinterpret_cast(inputs[1]->device_ptr()); + auto output_addr = reinterpret_cast(outputs[0]->device_ptr()); + size_t m = static_cast(input_0_shape_[0]); + size_t n = static_cast(input_0_shape_[1]); + size_t k = 0; + if (input_1_shape_.size() == kADimNum_1) { + k = 1; + } else { + k = static_cast(input_1_shape_[1]); + } + + typedef Eigen::Matrix MartixXd; // NOLINT + MartixXd A(m, n); + MartixXd B(m, k); + for (size_t i = 0; i < m * n; i++) { + A.data()[i] = static_cast(input_0_addr[i]); + } + for (size_t i = 0; i < m * k; i++) { + B.data()[i] = static_cast(input_1_addr[i]); + } + MartixXd result; + if (m >= n) { + result = A.colPivHouseholderQr().solve(B); + } else { + MartixXd A_Transpose = A.transpose(); + MartixXd temp = A * A_Transpose; + MartixXd tempI = temp.inverse(); + MartixXd x = A_Transpose * tempI; + MartixXd output = x * B; + result = output; + } + for (size_t i = 0; i < n; i++) { + for (size_t j = 0; j < k; j++) { + *(output_addr + i * k + j) = static_cast(result(i, j)); // NOLINT + } + } +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Lstsq, LstsqCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/lstsq_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/lstsq_cpu_kernel.h index b653e6f7205..9d26530f9ab 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/lstsq_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/lstsq_cpu_kernel.h @@ -1,59 +1,59 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_LSTSQ_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_LSTSQ_CPU_KERNEL_H_ - -#include -#include -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore { -namespace kernel { -class LstsqCpuKernelMod : public NativeCpuKernelMod { - public: - LstsqCpuKernelMod() = default; - ~LstsqCpuKernelMod() override = default; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - template - void LaunchKernel(const std::vector &inputs, const std::vector &outputs); - - std::vector GetOpSupport() override { - static std::vector support_list = { - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64)}; - return support_list; - } - - private: - std::vector input_0_shape_; - std::vector input_1_shape_; - TypeId dtype_0_{kTypeUnknown}; - TypeId dtype_1_{kTypeUnknown}; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_LSTSQ_CPU_KERNEL_H_ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_LSTSQ_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_LSTSQ_CPU_KERNEL_H_ + +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class LstsqCpuKernelMod : public NativeCpuKernelMod { + public: + LstsqCpuKernelMod() = default; + ~LstsqCpuKernelMod() override = default; + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + template + void LaunchKernel(const std::vector &inputs, const std::vector &outputs); + + std::vector GetOpSupport() override { + static std::vector support_list = { + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64)}; + return support_list; + } + + private: + std::vector input_0_shape_; + std::vector input_1_shape_; + TypeId dtype_0_{kTypeUnknown}; + TypeId dtype_1_{kTypeUnknown}; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_LSTSQ_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/lu_solve_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/lu_solve_cpu_kernel.cc index a5825856324..a5fa2ba0e18 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/lu_solve_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/lu_solve_cpu_kernel.cc @@ -1,157 +1,157 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "plugin/device/cpu/kernel/lu_solve_cpu_kernel.h" -#include -#include -#include "plugin/device/cpu/hal/device/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -namespace { -constexpr size_t kDimNum = 2; -} - -int64_t get_element_num(const std::vector &shape) { return SizeToLong(SizeOf(shape)); } - -bool LuSolveCpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { - size_t input_num = inputs.size(); - size_t output_num = outputs.size(); - CHECK_KERNEL_INPUTS_NUM(input_num, kInputNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(output_num, kOutputNum, kernel_name_); - - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - auto [is_match, index] = MatchKernelAttr(kernel_attr, support_list); - if (!is_match) { - MS_LOG(ERROR) << "LuSolve does not support this kernel data type: " << kernel_attr; - return false; - } - kernel_func_ = func_list_[index].second; - return true; -} - -int LuSolveCpuKernelMod::Resize(const std::vector &inputs, const std::vector &outputs) { - if (int ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { - return ret; - } - - input_0_shape_ = inputs[kIndex0]->GetDeviceShapeVector(); - input_1_shape_ = inputs[kIndex1]->GetDeviceShapeVector(); - output_shape_ = outputs[kIndex0]->GetDeviceShapeVector(); - return KRET_OK; -} - -template -void LuSolveCpuKernelMod::LuSolve(const std::vector &inputs, - const std::vector &outputs, T1 *b_working_ptr, - T1 *lu_working_ptr, int32_t *pivots_working_ptr, size_t b_stride, size_t a) { - auto output_y = reinterpret_cast(outputs[0]->device_ptr()); - size_t lu_dims = input_1_shape_.size(); - size_t lu_maxtrix_sizes = LongToSize(input_1_shape_[lu_dims - 2]); - size_t b_dim = input_0_shape_.size(); - size_t b_m = LongToSize(input_0_shape_[b_dim - 1]); - typedef Eigen::Matrix MatrixXd; - MatrixXd matrix_b = Eigen::Map(b_working_ptr, lu_maxtrix_sizes, b_m); - MatrixXd matrix_A = Eigen::Map(lu_working_ptr, lu_maxtrix_sizes, lu_maxtrix_sizes); - for (size_t i = 0; i < LongToSize(input_0_shape_[b_dim - kDimNum]); i++) { - size_t pivots_i = *(pivots_working_ptr + i) - 1; - if ((pivots_i > LongToSize(input_0_shape_[b_dim - kDimNum])) || (pivots_i >= IntToSize(matrix_b.rows())) || - (i >= IntToSize(matrix_b.rows()))) { - MS_EXCEPTION(ValueError) << "lu_pivots values out of index of lu_data. "; - } - matrix_b.row(i).swap(matrix_b.row(pivots_i)); - } - MatrixXd result = matrix_A.template triangularView().solve(matrix_b); - result.noalias() = matrix_A.template triangularView().solve(result); - for (size_t m = 0; m < b_stride; m++) { - *(output_y + a * b_stride + m) = (T2) * (result.data() + m); - } -} - -template -bool LuSolveCpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &outputs) { - auto input_x0 = reinterpret_cast(inputs[0]->device_ptr()); - auto input_x1 = reinterpret_cast(inputs[1]->device_ptr()); - auto input_x2 = reinterpret_cast(inputs[2]->device_ptr()); - auto input0_element_num = SizeOf(input_0_shape_); - auto input1_element_num = SizeOf(input_1_shape_); - auto output_element_num = SizeOf(output_shape_); - std::vector input_0(input_x0, input_x0 + input0_element_num); - std::vector input_1(input_x1, input_x1 + input1_element_num); - size_t b_dims = input_0_shape_.size(); - std::vector b_dims_vector = input_0_shape_; - size_t lu_dims = input_1_shape_.size(); - std::vector lu_dims_vector = input_1_shape_; - size_t b_stride = static_cast(input_0_shape_[b_dims - 1] * input_0_shape_[b_dims - 2]); - size_t lu_stride = static_cast(input_1_shape_[lu_dims - 1] * input_1_shape_[lu_dims - 2]); - size_t pivots_stride = static_cast(input_1_shape_[lu_dims - 1]); - MS_EXCEPTION_IF_ZERO("b_stride", b_stride); - size_t batch_num = output_element_num / b_stride; - if (b_dims == lu_dims) { - for (size_t i = 0; i < batch_num; i++) { - T1 *b_working_ptr = input_0.data() + i * b_stride; - T1 *lu_working_ptr = input_1.data() + i * lu_stride; - int32_t *pivots_working_ptr = &input_x2[i * pivots_stride]; - LuSolve(inputs, outputs, b_working_ptr, lu_working_ptr, pivots_working_ptr, b_stride, i); - } - } else { - std::vector b_shape = b_dims_vector; - std::vector lu_shape = lu_dims_vector; - for (size_t i = 0; i < kDimNum; i++) { - b_shape.pop_back(); - lu_shape.pop_back(); - } - auto output_shape = CPUKernelUtils::GetBroadcastShape(b_shape, lu_shape); - BroadcastIterator iter(b_shape, lu_shape, output_shape); - iter.SetPos(0); - for (size_t i = 0; i < batch_num; i++) { - T1 *b_working_ptr = input_0.data() + iter.GetInputPosA() * b_stride; - T1 *lu_working_ptr = input_1.data() + iter.GetInputPosB() * lu_stride; - int32_t *pivots_working_ptr = &input_x2[iter.GetInputPosB() * pivots_stride]; - LuSolve(inputs, outputs, b_working_ptr, lu_working_ptr, pivots_working_ptr, b_stride, i); - iter.GenNextPos(); - } - } - return true; -} - -std::vector> LuSolveCpuKernelMod::func_list_ = { - {KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat16), - &LuSolveCpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat32), - &LuSolveCpuKernelMod::LaunchKernel}}; - -std::vector LuSolveCpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - return support_list; -} - -MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, LuSolve, LuSolveCpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "plugin/device/cpu/kernel/lu_solve_cpu_kernel.h" +#include +#include +#include "plugin/device/cpu/hal/device/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kDimNum = 2; +} + +int64_t get_element_num(const std::vector &shape) { return SizeToLong(SizeOf(shape)); } + +bool LuSolveCpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { + size_t input_num = inputs.size(); + size_t output_num = outputs.size(); + CHECK_KERNEL_INPUTS_NUM(input_num, kInputNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(output_num, kOutputNum, kernel_name_); + + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + auto [is_match, index] = MatchKernelAttr(kernel_attr, support_list); + if (!is_match) { + MS_LOG(ERROR) << "LuSolve does not support this kernel data type: " << kernel_attr; + return false; + } + kernel_func_ = func_list_[index].second; + return true; +} + +int LuSolveCpuKernelMod::Resize(const std::vector &inputs, const std::vector &outputs) { + if (int ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { + return ret; + } + + input_0_shape_ = inputs[kIndex0]->GetDeviceShapeVector(); + input_1_shape_ = inputs[kIndex1]->GetDeviceShapeVector(); + output_shape_ = outputs[kIndex0]->GetDeviceShapeVector(); + return KRET_OK; +} + +template +void LuSolveCpuKernelMod::LuSolve(const std::vector &inputs, + const std::vector &outputs, T1 *b_working_ptr, + T1 *lu_working_ptr, int32_t *pivots_working_ptr, size_t b_stride, size_t a) { + auto output_y = reinterpret_cast(outputs[0]->device_ptr()); + size_t lu_dims = input_1_shape_.size(); + size_t lu_maxtrix_sizes = LongToSize(input_1_shape_[lu_dims - 2]); + size_t b_dim = input_0_shape_.size(); + size_t b_m = LongToSize(input_0_shape_[b_dim - 1]); + typedef Eigen::Matrix MatrixXd; + MatrixXd matrix_b = Eigen::Map(b_working_ptr, lu_maxtrix_sizes, b_m); + MatrixXd matrix_A = Eigen::Map(lu_working_ptr, lu_maxtrix_sizes, lu_maxtrix_sizes); + for (size_t i = 0; i < LongToSize(input_0_shape_[b_dim - kDimNum]); i++) { + size_t pivots_i = *(pivots_working_ptr + i) - 1; + if ((pivots_i > LongToSize(input_0_shape_[b_dim - kDimNum])) || (pivots_i >= IntToSize(matrix_b.rows())) || + (i >= IntToSize(matrix_b.rows()))) { + MS_EXCEPTION(ValueError) << "lu_pivots values out of index of lu_data. "; + } + matrix_b.row(i).swap(matrix_b.row(pivots_i)); + } + MatrixXd result = matrix_A.template triangularView().solve(matrix_b); + result.noalias() = matrix_A.template triangularView().solve(result); + for (size_t m = 0; m < b_stride; m++) { + *(output_y + a * b_stride + m) = (T2) * (result.data() + m); + } +} + +template +bool LuSolveCpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { + auto input_x0 = reinterpret_cast(inputs[0]->device_ptr()); + auto input_x1 = reinterpret_cast(inputs[1]->device_ptr()); + auto input_x2 = reinterpret_cast(inputs[2]->device_ptr()); + auto input0_element_num = SizeOf(input_0_shape_); + auto input1_element_num = SizeOf(input_1_shape_); + auto output_element_num = SizeOf(output_shape_); + std::vector input_0(input_x0, input_x0 + input0_element_num); + std::vector input_1(input_x1, input_x1 + input1_element_num); + size_t b_dims = input_0_shape_.size(); + std::vector b_dims_vector = input_0_shape_; + size_t lu_dims = input_1_shape_.size(); + std::vector lu_dims_vector = input_1_shape_; + size_t b_stride = static_cast(input_0_shape_[b_dims - 1] * input_0_shape_[b_dims - 2]); + size_t lu_stride = static_cast(input_1_shape_[lu_dims - 1] * input_1_shape_[lu_dims - 2]); + size_t pivots_stride = static_cast(input_1_shape_[lu_dims - 1]); + MS_EXCEPTION_IF_ZERO("b_stride", b_stride); + size_t batch_num = output_element_num / b_stride; + if (b_dims == lu_dims) { + for (size_t i = 0; i < batch_num; i++) { + T1 *b_working_ptr = input_0.data() + i * b_stride; + T1 *lu_working_ptr = input_1.data() + i * lu_stride; + int32_t *pivots_working_ptr = &input_x2[i * pivots_stride]; + LuSolve(inputs, outputs, b_working_ptr, lu_working_ptr, pivots_working_ptr, b_stride, i); + } + } else { + std::vector b_shape = b_dims_vector; + std::vector lu_shape = lu_dims_vector; + for (size_t i = 0; i < kDimNum; i++) { + b_shape.pop_back(); + lu_shape.pop_back(); + } + auto output_shape = CPUKernelUtils::GetBroadcastShape(b_shape, lu_shape); + BroadcastIterator iter(b_shape, lu_shape, output_shape); + iter.SetPos(0); + for (size_t i = 0; i < batch_num; i++) { + T1 *b_working_ptr = input_0.data() + iter.GetInputPosA() * b_stride; + T1 *lu_working_ptr = input_1.data() + iter.GetInputPosB() * lu_stride; + int32_t *pivots_working_ptr = &input_x2[iter.GetInputPosB() * pivots_stride]; + LuSolve(inputs, outputs, b_working_ptr, lu_working_ptr, pivots_working_ptr, b_stride, i); + iter.GenNextPos(); + } + } + return true; +} + +std::vector> LuSolveCpuKernelMod::func_list_ = { + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat16), + &LuSolveCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32), + &LuSolveCpuKernelMod::LaunchKernel}}; + +std::vector LuSolveCpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, LuSolve, LuSolveCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/lu_solve_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/lu_solve_cpu_kernel.h index 5e43a9b431a..e56dc2afac7 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/lu_solve_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/lu_solve_cpu_kernel.h @@ -1,64 +1,64 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_LUSOLVE_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_LUSOLVE_CPU_KERNEL_H_ -#include -#include -#include -#include -#include -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore { -constexpr size_t kInputNum = 3; -constexpr size_t kOutputNum = 1; -namespace kernel { -class LuSolveCpuKernelMod : public NativeCpuKernelMod { - public: - LuSolveCpuKernelMod() = default; - ~LuSolveCpuKernelMod() override = default; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs) override { - return kernel_func_(this, inputs, outputs); - } - - std::vector GetOpSupport() override; - - private: - template - void LuSolve(const std::vector &inputs, const std::vector &outputs, - T1 *b_working_ptr, T1 *lu_working_ptr, int32_t *pivots_working_ptr, size_t b_stride, size_t a); - template - bool LaunchKernel(const std::vector &inputs, - const std::vector &outputs); - using LuSolveFunc = std::function &, - const std::vector &)>; - static std::vector> func_list_; - LuSolveFunc kernel_func_; - CNodePtr node_wpt_; - std::vector input_0_shape_; - std::vector input_1_shape_; - std::vector output_shape_; -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_LUSOLVE_CPU_KERNEL_H_ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_LUSOLVE_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_LUSOLVE_CPU_KERNEL_H_ +#include +#include +#include +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +constexpr size_t kInputNum = 3; +constexpr size_t kOutputNum = 1; +namespace kernel { +class LuSolveCpuKernelMod : public NativeCpuKernelMod { + public: + LuSolveCpuKernelMod() = default; + ~LuSolveCpuKernelMod() override = default; + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs) override { + return kernel_func_(this, inputs, outputs); + } + + std::vector GetOpSupport() override; + + private: + template + void LuSolve(const std::vector &inputs, const std::vector &outputs, + T1 *b_working_ptr, T1 *lu_working_ptr, int32_t *pivots_working_ptr, size_t b_stride, size_t a); + template + bool LaunchKernel(const std::vector &inputs, + const std::vector &outputs); + using LuSolveFunc = std::function &, + const std::vector &)>; + static std::vector> func_list_; + LuSolveFunc kernel_func_; + CNodePtr node_wpt_; + std::vector input_0_shape_; + std::vector input_1_shape_; + std::vector output_shape_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_LUSOLVE_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/matrix_triangular_solve_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/matrix_triangular_solve_cpu_kernel.cc index f030af8e427..c093f66c57f 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/matrix_triangular_solve_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/matrix_triangular_solve_cpu_kernel.cc @@ -1,165 +1,165 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/cpu/kernel/matrix_triangular_solve_cpu_kernel.h" -#include -#include -#include -#include -#include - -namespace mindspore { -namespace kernel { -using Eigen::ColMajor; -using Eigen::Dynamic; -using Eigen::Lower; -using Eigen::Map; -using Eigen::MatrixBase; -using Eigen::RowMajor; -using Eigen::UnitLower; -using Eigen::UnitUpper; -using Eigen::Upper; -template -using Matrix = Eigen::Matrix; -constexpr auto kSolveTriangularInputsNum = 2; -constexpr auto kSolveTriangularOutputsNum = 1; -constexpr auto kAVectorxDimNum = 1; -constexpr auto kAMatrixDimNum = 2; -constexpr size_t kRowIndex = 2; -constexpr size_t kColIndex = 1; - -bool MatrixTriangularSolveCpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSolveTriangularInputsNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSolveTriangularOutputsNum, kernel_name_); - - trans_ = GetValue(primitive_->GetAttr(ADJOINT)); - lower_ = GetValue(primitive_->GetAttr(LOWER)); - - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(EXCEPTION) << kernel_name_ << " does not support this kernel data type: " << kernel_attr; - } - kernel_func_ = func_list_[index].second; - return true; -} - -int MatrixTriangularSolveCpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { - return ret; - } - auto a_shape = inputs[0]->GetShapeVector(); - auto b_shape = inputs[1]->GetShapeVector(); - // Since the shape check is done in frontend, we can suppose that the shape of a, b here is valid. - size_t a_dims = a_shape.size(); - size_t aRowIndex = a_dims - kRowIndex; - m_ = static_cast(a_shape[aRowIndex]); - size_t b_sims = b_shape.size(); - bool vector_b = b_sims == a_dims - 1; - if (vector_b) { - n_ = 1; - } else { - n_ = static_cast(b_shape[b_sims - 1]); - } - batch_ = 1; - for (size_t batch = 0; batch < a_dims - kRowIndex; ++batch) { - batch_ *= static_cast(a_shape[batch]); - } - return KRET_OK; -} - -template -inline void solve(const MatrixBase &a, const MatrixBase &b, T *output_addr, int m, int n, - bool lower, bool unit_diagonal) { - Map> output(output_addr, m, n); - if (unit_diagonal) { - if (lower) { - output.noalias() = a.template triangularView().solve(b); - } else { - output.noalias() = a.template triangularView().solve(b); - } - } else { - if (lower) { - output.noalias() = a.template triangularView().solve(b); - } else { - output.noalias() = a.template triangularView().solve(b); - } - } -} - -template -bool MatrixTriangularSolveCpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &, - const std::vector &outputs) { - auto a_addr = reinterpret_cast(inputs[0]->device_ptr()); - auto b_addr = reinterpret_cast(inputs[1]->device_ptr()); - auto output_addr = reinterpret_cast(outputs[0]->device_ptr()); - MS_EXCEPTION_IF_NULL(a_addr); - MS_EXCEPTION_IF_NULL(b_addr); - MS_EXCEPTION_IF_NULL(output_addr); - - size_t a_batch_size = m_ * m_; - size_t b_batch_size = m_ * n_; - size_t output_batch_size = m_ * n_; - - for (size_t i = 0; i < batch_; ++i) { - T *a_batch_addr = a_addr + i * a_batch_size; - T *b_batch_addr = b_addr + i * b_batch_size; - T *output_batch_addr = output_addr + i * output_batch_size; - - Map> b(b_batch_addr, m_, n_); - if (trans_) { - Map> a(a_batch_addr, m_, m_); - auto a_conj = a.conjugate(); - solve(a_conj, b, output_batch_addr, m_, n_, !lower_, unit_diagonal_); - } else { - Map> a(a_batch_addr, m_, m_); - solve(a, b, output_batch_addr, m_, n_, lower_, unit_diagonal_); - } - } - - return true; -} - -std::vector> - MatrixTriangularSolveCpuKernelMod::func_list_ = { - {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - &MatrixTriangularSolveCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), - &MatrixTriangularSolveCpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeComplex64) - .AddInputAttr(kNumberTypeComplex64) - .AddOutputAttr(kNumberTypeComplex64), - &MatrixTriangularSolveCpuKernelMod::LaunchKernel>}, - {KernelAttr() - .AddInputAttr(kNumberTypeComplex128) - .AddInputAttr(kNumberTypeComplex128) - .AddOutputAttr(kNumberTypeComplex128), - &MatrixTriangularSolveCpuKernelMod::LaunchKernel>}}; - -std::vector MatrixTriangularSolveCpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - return support_list; -} - -MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MatrixTriangularSolve, MatrixTriangularSolveCpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/cpu/kernel/matrix_triangular_solve_cpu_kernel.h" +#include +#include +#include +#include +#include + +namespace mindspore { +namespace kernel { +using Eigen::ColMajor; +using Eigen::Dynamic; +using Eigen::Lower; +using Eigen::Map; +using Eigen::MatrixBase; +using Eigen::RowMajor; +using Eigen::UnitLower; +using Eigen::UnitUpper; +using Eigen::Upper; +template +using Matrix = Eigen::Matrix; +constexpr auto kSolveTriangularInputsNum = 2; +constexpr auto kSolveTriangularOutputsNum = 1; +constexpr auto kAVectorxDimNum = 1; +constexpr auto kAMatrixDimNum = 2; +constexpr size_t kRowIndex = 2; +constexpr size_t kColIndex = 1; + +bool MatrixTriangularSolveCpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSolveTriangularInputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSolveTriangularOutputsNum, kernel_name_); + + trans_ = GetValue(primitive_->GetAttr(ADJOINT)); + lower_ = GetValue(primitive_->GetAttr(LOWER)); + + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(EXCEPTION) << kernel_name_ << " does not support this kernel data type: " << kernel_attr; + } + kernel_func_ = func_list_[index].second; + return true; +} + +int MatrixTriangularSolveCpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { + return ret; + } + auto a_shape = inputs[0]->GetShapeVector(); + auto b_shape = inputs[1]->GetShapeVector(); + // Since the shape check is done in frontend, we can suppose that the shape of a, b here is valid. + size_t a_dims = a_shape.size(); + size_t aRowIndex = a_dims - kRowIndex; + m_ = static_cast(a_shape[aRowIndex]); + size_t b_sims = b_shape.size(); + bool vector_b = b_sims == a_dims - 1; + if (vector_b) { + n_ = 1; + } else { + n_ = static_cast(b_shape[b_sims - 1]); + } + batch_ = 1; + for (size_t batch = 0; batch < a_dims - kRowIndex; ++batch) { + batch_ *= static_cast(a_shape[batch]); + } + return KRET_OK; +} + +template +inline void solve(const MatrixBase &a, const MatrixBase &b, T *output_addr, int m, int n, + bool lower, bool unit_diagonal) { + Map> output(output_addr, m, n); + if (unit_diagonal) { + if (lower) { + output.noalias() = a.template triangularView().solve(b); + } else { + output.noalias() = a.template triangularView().solve(b); + } + } else { + if (lower) { + output.noalias() = a.template triangularView().solve(b); + } else { + output.noalias() = a.template triangularView().solve(b); + } + } +} + +template +bool MatrixTriangularSolveCpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &, + const std::vector &outputs) { + auto a_addr = reinterpret_cast(inputs[0]->device_ptr()); + auto b_addr = reinterpret_cast(inputs[1]->device_ptr()); + auto output_addr = reinterpret_cast(outputs[0]->device_ptr()); + MS_EXCEPTION_IF_NULL(a_addr); + MS_EXCEPTION_IF_NULL(b_addr); + MS_EXCEPTION_IF_NULL(output_addr); + + size_t a_batch_size = m_ * m_; + size_t b_batch_size = m_ * n_; + size_t output_batch_size = m_ * n_; + + for (size_t i = 0; i < batch_; ++i) { + T *a_batch_addr = a_addr + i * a_batch_size; + T *b_batch_addr = b_addr + i * b_batch_size; + T *output_batch_addr = output_addr + i * output_batch_size; + + Map> b(b_batch_addr, m_, n_); + if (trans_) { + Map> a(a_batch_addr, m_, m_); + auto a_conj = a.conjugate(); + solve(a_conj, b, output_batch_addr, m_, n_, !lower_, unit_diagonal_); + } else { + Map> a(a_batch_addr, m_, m_); + solve(a, b, output_batch_addr, m_, n_, lower_, unit_diagonal_); + } + } + + return true; +} + +std::vector> + MatrixTriangularSolveCpuKernelMod::func_list_ = { + {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + &MatrixTriangularSolveCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + &MatrixTriangularSolveCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeComplex64) + .AddInputAttr(kNumberTypeComplex64) + .AddOutputAttr(kNumberTypeComplex64), + &MatrixTriangularSolveCpuKernelMod::LaunchKernel>}, + {KernelAttr() + .AddInputAttr(kNumberTypeComplex128) + .AddInputAttr(kNumberTypeComplex128) + .AddOutputAttr(kNumberTypeComplex128), + &MatrixTriangularSolveCpuKernelMod::LaunchKernel>}}; + +std::vector MatrixTriangularSolveCpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MatrixTriangularSolve, MatrixTriangularSolveCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/matrix_triangular_solve_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/matrix_triangular_solve_cpu_kernel.h index aed61f939a2..8c80b992b0e 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/matrix_triangular_solve_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/matrix_triangular_solve_cpu_kernel.h @@ -1,64 +1,64 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_MATRIX_TRIANGULAR_SOLVE_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_MATRIX_TRIANGULAR_SOLVE_CPU_KERNEL_H_ - -#include -#include -#include -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore { -namespace kernel { -class MatrixTriangularSolveCpuKernelMod : public NativeCpuKernelMod { - public: - MatrixTriangularSolveCpuKernelMod() = default; - ~MatrixTriangularSolveCpuKernelMod() override = default; - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override { - return kernel_func_(this, inputs, workspace, outputs); - } - - protected: - std::vector GetOpSupport() override; - - private: - template - bool LaunchKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs); - using MatrixTriangularSolveFunc = - std::function &, - const std::vector &, const std::vector &)>; - static std::vector> func_list_; - MatrixTriangularSolveFunc kernel_func_; - - size_t m_{0}; - size_t n_{0}; - size_t batch_{1}; - bool lower_{false}; - bool trans_{false}; - bool unit_diagonal_{false}; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_MATRIX_TRIANGULAR_SOLVE_CPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_MATRIX_TRIANGULAR_SOLVE_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_MATRIX_TRIANGULAR_SOLVE_CPU_KERNEL_H_ + +#include +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class MatrixTriangularSolveCpuKernelMod : public NativeCpuKernelMod { + public: + MatrixTriangularSolveCpuKernelMod() = default; + ~MatrixTriangularSolveCpuKernelMod() override = default; + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override { + return kernel_func_(this, inputs, workspace, outputs); + } + + protected: + std::vector GetOpSupport() override; + + private: + template + bool LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs); + using MatrixTriangularSolveFunc = + std::function &, + const std::vector &, const std::vector &)>; + static std::vector> func_list_; + MatrixTriangularSolveFunc kernel_func_; + + size_t m_{0}; + size_t n_{0}; + size_t batch_{1}; + bool lower_{false}; + bool trans_{false}; + bool unit_diagonal_{false}; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_MATRIX_TRIANGULAR_SOLVE_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/mvlgamma_grad_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/mvlgamma_grad_cpu_kernel.h index 39f186331d8..969f306c461 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/mvlgamma_grad_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/mvlgamma_grad_cpu_kernel.h @@ -1,65 +1,65 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_MVLGAMMA_GRAD_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_MVLGAMMA_GRAD_CPU_KERNEL_H_ -#include -#include -#include -#include -#include -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore { -namespace kernel { -class MvlgammaGradCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper { - public: - MvlgammaGradCpuKernelMod() = default; - ~MvlgammaGradCpuKernelMod() override = default; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override { - MS_EXCEPTION_IF_NULL(kernel_func_); - return kernel_func_(this, inputs, workspace, outputs); - } - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - const std::vector> &GetFuncList() const override; - - protected: - std::vector GetOpSupport() override { return OpSupport(); } - - template - bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs); - - template - T Digamma(const T &input) const; - - template - T MvlgammaGradSingle(const T &y_grad, const T &x, const int64_t &p) const; - - ShapeVector input_shape_; - int64_t attr_p_; - int64_t input_tensor_size_; -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_MVLGAMMA_GRAD_CPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_MVLGAMMA_GRAD_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_MVLGAMMA_GRAD_CPU_KERNEL_H_ +#include +#include +#include +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class MvlgammaGradCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper { + public: + MvlgammaGradCpuKernelMod() = default; + ~MvlgammaGradCpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override { + MS_EXCEPTION_IF_NULL(kernel_func_); + return kernel_func_(this, inputs, workspace, outputs); + } + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + const std::vector> &GetFuncList() const override; + + protected: + std::vector GetOpSupport() override { return OpSupport(); } + + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs); + + template + T Digamma(const T &input) const; + + template + T MvlgammaGradSingle(const T &y_grad, const T &x, const int64_t &p) const; + + ShapeVector input_shape_; + int64_t attr_p_; + int64_t input_tensor_size_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_MVLGAMMA_GRAD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/non_zero_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/non_zero_cpu_kernel.cc index 56dbce7d381..628689408a6 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/non_zero_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/non_zero_cpu_kernel.cc @@ -1,167 +1,167 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/cpu/kernel/non_zero_cpu_kernel.h" -#include -#include -#include -#include "plugin/device/cpu/hal/device/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -namespace { -constexpr size_t kInputNum = 1; -constexpr size_t kOutputNum = 1; -constexpr size_t kInputMinDim = 1; -constexpr size_t kOutputDim = 2; -} // namespace - -bool NonZeroCpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_); - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr; - return false; - } - kernel_func_ = func_list_[index].second; - data_size_ = abstract::TypeIdSize(inputs[kIndex0]->dtype_id()); - index_size_ = abstract::TypeIdSize(outputs[kIndex0]->dtype_id()); - return true; -} - -int NonZeroCpuKernelMod::Resize(const std::vector &inputs, const std::vector &outputs) { - auto ret = KernelMod::Resize(inputs, outputs); - if (ret != KRET_UNKNOWN_OUT_SHAPE && ret != KRET_OK) { - return ret; - } - ResetResource(); - auto input_shape = inputs[kIndex0]->GetDeviceShapeVector(); - (void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(input_shape_), - [](int64_t x) { return x < 0 ? 0 : LongToSize(x); }); - input_size_ = std::accumulate(input_shape_.begin(), input_shape_.end(), size_t(1), std::multiplies{}); - if (input_size_ == 0) { - return KRET_UNKNOWN_SHAPE; - } - input_rank_ = input_shape_.size(); - - output_size_list_.push_back(input_size_ * input_shape_.size() * index_size_); - return KRET_OK; -} - -void NonZeroCpuKernelMod::ResetResource() noexcept { - real_output_size_ = 0; - input_shape_.clear(); - output_size_list_.clear(); - workspace_size_list_.clear(); -} - -template -bool NonZeroCpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &outputs) { - if (input_size_ == 0) { - return true; - } - - auto input_addr = static_cast(inputs[0]->device_ptr()); - auto output_addr = static_cast(outputs[0]->device_ptr()); - real_output_size_ = NonZeroCompute(input_addr, output_addr, input_size_); - return true; -} - -void NonZeroCpuKernelMod::UpdateOutputShapeAndSize(const std::vector &inputs, - const std::vector &outputs) { - std::vector new_output_shape = {SizeToLong(real_output_size_), SizeToLong(input_shape_.size())}; - outputs[kIndex0]->SetShapeVector(new_output_shape); - outputs[kIndex0]->set_size(real_output_size_ * input_shape_.size() * index_size_); -} - -template -size_t NonZeroCpuKernelMod::NonZeroCompute(const T *input, int64_t *output, size_t input_num) { - size_t non_zero_count = 0; - std::vector dim_strides(input_rank_, 1); - - for (size_t i = input_rank_ - 1; i >= 1; --i) { - dim_strides[i - 1] = dim_strides[i] * input_shape_[i]; - } - - for (size_t elem_i = 0; elem_i < input_num; ++elem_i) { - auto zero = static_cast(0); - if constexpr (std::is_same_v) { - if (common::IsDoubleEqual(input[elem_i], zero)) { - continue; - } - } else { - if constexpr (std::is_same_v) { - if (common::IsFloatEqual(input[elem_i], zero)) { - continue; - } - } else { - if (input[elem_i] == zero) { - continue; - } - } - } - size_t index = elem_i; - for (size_t pos_j = 0; pos_j < input_rank_; ++pos_j) { - output[non_zero_count * input_rank_ + pos_j] = static_cast(index / dim_strides[pos_j]); - index %= dim_strides[pos_j]; - } - non_zero_count++; - } - return non_zero_count; -} - -std::vector> NonZeroCpuKernelMod::func_list_ = { - {KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt64), - &NonZeroCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64), - &NonZeroCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt64), - &NonZeroCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), - &NonZeroCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), - &NonZeroCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64), - &NonZeroCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt64), - &NonZeroCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64), - &NonZeroCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt64), - &NonZeroCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64), - &NonZeroCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64), - &NonZeroCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt64), - &NonZeroCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeInt64), - &NonZeroCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeInt64), - &NonZeroCpuKernelMod::LaunchKernel}}; - -std::vector NonZeroCpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &item) { return item.first; }); - return support_list; -} -MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, NonZero, NonZeroCpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/cpu/kernel/non_zero_cpu_kernel.h" +#include +#include +#include +#include "plugin/device/cpu/hal/device/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kInputNum = 1; +constexpr size_t kOutputNum = 1; +constexpr size_t kInputMinDim = 1; +constexpr size_t kOutputDim = 2; +} // namespace + +bool NonZeroCpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_); + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr; + return false; + } + kernel_func_ = func_list_[index].second; + data_size_ = abstract::TypeIdSize(inputs[kIndex0]->dtype_id()); + index_size_ = abstract::TypeIdSize(outputs[kIndex0]->dtype_id()); + return true; +} + +int NonZeroCpuKernelMod::Resize(const std::vector &inputs, const std::vector &outputs) { + auto ret = KernelMod::Resize(inputs, outputs); + if (ret != KRET_UNKNOWN_OUT_SHAPE && ret != KRET_OK) { + return ret; + } + ResetResource(); + auto input_shape = inputs[kIndex0]->GetDeviceShapeVector(); + (void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(input_shape_), + [](int64_t x) { return x < 0 ? 0 : LongToSize(x); }); + input_size_ = std::accumulate(input_shape_.begin(), input_shape_.end(), size_t(1), std::multiplies{}); + if (input_size_ == 0) { + return KRET_UNKNOWN_SHAPE; + } + input_rank_ = input_shape_.size(); + + output_size_list_.push_back(input_size_ * input_shape_.size() * index_size_); + return KRET_OK; +} + +void NonZeroCpuKernelMod::ResetResource() noexcept { + real_output_size_ = 0; + input_shape_.clear(); + output_size_list_.clear(); + workspace_size_list_.clear(); +} + +template +bool NonZeroCpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { + if (input_size_ == 0) { + return true; + } + + auto input_addr = static_cast(inputs[0]->device_ptr()); + auto output_addr = static_cast(outputs[0]->device_ptr()); + real_output_size_ = NonZeroCompute(input_addr, output_addr, input_size_); + return true; +} + +void NonZeroCpuKernelMod::UpdateOutputShapeAndSize(const std::vector &inputs, + const std::vector &outputs) { + std::vector new_output_shape = {SizeToLong(real_output_size_), SizeToLong(input_shape_.size())}; + outputs[kIndex0]->SetShapeVector(new_output_shape); + outputs[kIndex0]->set_size(real_output_size_ * input_shape_.size() * index_size_); +} + +template +size_t NonZeroCpuKernelMod::NonZeroCompute(const T *input, int64_t *output, size_t input_num) { + size_t non_zero_count = 0; + std::vector dim_strides(input_rank_, 1); + + for (size_t i = input_rank_ - 1; i >= 1; --i) { + dim_strides[i - 1] = dim_strides[i] * input_shape_[i]; + } + + for (size_t elem_i = 0; elem_i < input_num; ++elem_i) { + auto zero = static_cast(0); + if constexpr (std::is_same_v) { + if (common::IsDoubleEqual(input[elem_i], zero)) { + continue; + } + } else { + if constexpr (std::is_same_v) { + if (common::IsFloatEqual(input[elem_i], zero)) { + continue; + } + } else { + if (input[elem_i] == zero) { + continue; + } + } + } + size_t index = elem_i; + for (size_t pos_j = 0; pos_j < input_rank_; ++pos_j) { + output[non_zero_count * input_rank_ + pos_j] = static_cast(index / dim_strides[pos_j]); + index %= dim_strides[pos_j]; + } + non_zero_count++; + } + return non_zero_count; +} + +std::vector> NonZeroCpuKernelMod::func_list_ = { + {KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt64), + &NonZeroCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64), + &NonZeroCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt64), + &NonZeroCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), + &NonZeroCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + &NonZeroCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64), + &NonZeroCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt64), + &NonZeroCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64), + &NonZeroCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt64), + &NonZeroCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64), + &NonZeroCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64), + &NonZeroCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt64), + &NonZeroCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeInt64), + &NonZeroCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeInt64), + &NonZeroCpuKernelMod::LaunchKernel}}; + +std::vector NonZeroCpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &item) { return item.first; }); + return support_list; +} +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, NonZero, NonZeroCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/non_zero_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/non_zero_cpu_kernel.h index 9e5315f3e67..be273040a2e 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/non_zero_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/non_zero_cpu_kernel.h @@ -1,77 +1,77 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NON_ZERO_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NON_ZERO_CPU_KERNEL_H_ - -#include -#include -#include -#include -#include -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore { -namespace kernel { -using complex64 = std::complex; -using complex128 = std::complex; - -class NonZeroCpuKernelMod : public NativeCpuKernelMod { - public: - NonZeroCpuKernelMod() = default; - ~NonZeroCpuKernelMod() override = default; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs) override { - return kernel_func_(this, inputs, outputs); - } - - protected: - void UpdateOutputShapeAndSize(const std::vector &inputs, - const std::vector &outputs) override; - bool IsNeedUpdateOutputShapeAndSize() override { return true; } - std::vector GetOpSupport() override; - - private: - void ResetResource() noexcept; - - template - size_t NonZeroCompute(const T *input, int64_t *output, size_t input_num); - - template - bool LaunchKernel(const std::vector &inputs, - const std::vector &outputs); - using NonZeroFunc = std::function &, - const std::vector &)>; - static std::vector> func_list_; - - NonZeroFunc kernel_func_; - std::vector input_shape_; - size_t input_rank_{0}; - size_t input_size_{0}; - size_t data_size_{0}; // That is, sizeof(DataType). - size_t index_size_{0}; // That is, sizeof(IndexType) - size_t real_output_size_{0}; // Dynamic shape related. -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_NON_ZERO_CPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NON_ZERO_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NON_ZERO_CPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +using complex64 = std::complex; +using complex128 = std::complex; + +class NonZeroCpuKernelMod : public NativeCpuKernelMod { + public: + NonZeroCpuKernelMod() = default; + ~NonZeroCpuKernelMod() override = default; + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs) override { + return kernel_func_(this, inputs, outputs); + } + + protected: + void UpdateOutputShapeAndSize(const std::vector &inputs, + const std::vector &outputs) override; + bool IsNeedUpdateOutputShapeAndSize() override { return true; } + std::vector GetOpSupport() override; + + private: + void ResetResource() noexcept; + + template + size_t NonZeroCompute(const T *input, int64_t *output, size_t input_num); + + template + bool LaunchKernel(const std::vector &inputs, + const std::vector &outputs); + using NonZeroFunc = std::function &, + const std::vector &)>; + static std::vector> func_list_; + + NonZeroFunc kernel_func_; + std::vector input_shape_; + size_t input_rank_{0}; + size_t input_size_{0}; + size_t data_size_{0}; // That is, sizeof(DataType). + size_t index_size_{0}; // That is, sizeof(IndexType) + size_t real_output_size_{0}; // Dynamic shape related. +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_NON_ZERO_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/parallel_concat_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/parallel_concat_cpu_kernel.cc index d38ad220a96..ae05f262d15 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/parallel_concat_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/parallel_concat_cpu_kernel.cc @@ -1,157 +1,157 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/cpu/kernel/parallel_concat_cpu_kernel.h" -#include -#include -#include "plugin/device/cpu/hal/device/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -namespace { -constexpr int axis = 0; -constexpr size_t kParallelConcatOutputsNum = 1; -} // namespace - -bool ParallelConcatCpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - if (inputs.empty() || outputs.empty()) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', it got empty inputs or outputs, which is invalid."; - return false; - } - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(ERROR) << "For '" << kernel_name_ - << "', the data type of input must be float or double, but got: " << kernel_attr << "."; - return false; - } - kernel_func_ = func_list_[index].second; - return true; -} - -int ParallelConcatCpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { - return ret; - } - ResetResource(); - std::vector output_shape = outputs[0]->GetShapeVector(); - int64_t output_elements = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); - if (output_elements == 0) { - is_null_input_ = true; - } - input_num_ = inputs.size(); - if (inputs.empty()) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the input tensor is empty, which is not expected."; - } - auto x_shape = inputs[0]->GetShapeVector(); - for (size_t i = 0; i < input_num_; i++) { - if (x_shape != inputs[i]->GetShapeVector()) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', the shape of all tensors must be the same, but got tensor0.shape " - << x_shape << " and tensor" << i << ".shape " << inputs[i]->GetShapeVector(); - } - } - input_flat_shape_list_.reserve(input_num_); - for (size_t i = 0; i < input_num_; i++) { - auto input_shape_i = inputs[i]->GetShapeVector(); - auto flat_shape = CPUKernelUtils::FlatShapeByAxis(input_shape_i, axis); - (void)input_flat_shape_list_.emplace_back(flat_shape); - } - return KRET_OK; -} - -template -bool ParallelConcatCpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &outputs) { - std::vector input_addr_list; - for (size_t j = 0; j < input_num_; ++j) { - auto *tmp_addr = static_cast(inputs[j]->device_ptr()); - MS_EXCEPTION_IF_NULL(tmp_addr); - (void)input_addr_list.emplace_back(tmp_addr); - } - auto *output_addr = static_cast(outputs[0]->device_ptr()); - MS_EXCEPTION_IF_NULL(output_addr); - - size_t output_dim_1 = 0; - for (size_t j = 0; j < input_num_; ++j) { - output_dim_1 += LongToSize(input_flat_shape_list_[j][1]); - } - - // each input's row of shape after flat are same - auto before_axis = LongToSize(input_flat_shape_list_[0][0]); - auto task = [&](size_t start, size_t end) { - for (size_t i = start; i < end; ++i) { - auto output_ptr = output_addr + i * output_dim_1; - for (size_t j = 0; j < input_num_; ++j) { - if (input_flat_shape_list_[j][1] == 0) { - continue; - } - auto copy_num = LongToSize(input_flat_shape_list_[j][1]); - auto copy_size = copy_num * sizeof(T); - auto offset = copy_num * i; - auto ret = memcpy_s(output_ptr, copy_size, input_addr_list[j] + offset, copy_size); - if (ret != EOK) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memcpy failed. Error no: " << ret; - } - output_ptr += copy_num; - } - } - }; - ParallelLaunchAutoSearch(task, before_axis, this, ¶llel_search_info_); - return true; -} - -std::vector> ParallelConcatCpuKernelMod::func_list_ = { - {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), - &ParallelConcatCpuKernelMod::LaunchKernel}, - {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), - &ParallelConcatCpuKernelMod::LaunchKernel}, - {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), - &ParallelConcatCpuKernelMod::LaunchKernel}, - {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), - &ParallelConcatCpuKernelMod::LaunchKernel}, - {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), - &ParallelConcatCpuKernelMod::LaunchKernel}, - {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), - &ParallelConcatCpuKernelMod::LaunchKernel}, - {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - &ParallelConcatCpuKernelMod::LaunchKernel}, - {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), - &ParallelConcatCpuKernelMod::LaunchKernel}, - {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - &ParallelConcatCpuKernelMod::LaunchKernel}, - {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - &ParallelConcatCpuKernelMod::LaunchKernel}, - {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), - &ParallelConcatCpuKernelMod::LaunchKernel}, - {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), - &ParallelConcatCpuKernelMod::LaunchKernel}, - {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), - &ParallelConcatCpuKernelMod::LaunchKernel}, - {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), - &ParallelConcatCpuKernelMod::LaunchKernel}}; - -std::vector ParallelConcatCpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - return support_list; -} - -MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, ParallelConcat, ParallelConcatCpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/cpu/kernel/parallel_concat_cpu_kernel.h" +#include +#include +#include "plugin/device/cpu/hal/device/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr int axis = 0; +constexpr size_t kParallelConcatOutputsNum = 1; +} // namespace + +bool ParallelConcatCpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + if (inputs.empty() || outputs.empty()) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', it got empty inputs or outputs, which is invalid."; + return false; + } + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ + << "', the data type of input must be float or double, but got: " << kernel_attr << "."; + return false; + } + kernel_func_ = func_list_[index].second; + return true; +} + +int ParallelConcatCpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { + return ret; + } + ResetResource(); + std::vector output_shape = outputs[0]->GetShapeVector(); + int64_t output_elements = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + if (output_elements == 0) { + is_null_input_ = true; + } + input_num_ = inputs.size(); + if (inputs.empty()) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the input tensor is empty, which is not expected."; + } + auto x_shape = inputs[0]->GetShapeVector(); + for (size_t i = 0; i < input_num_; i++) { + if (x_shape != inputs[i]->GetShapeVector()) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', the shape of all tensors must be the same, but got tensor0.shape " + << x_shape << " and tensor" << i << ".shape " << inputs[i]->GetShapeVector(); + } + } + input_flat_shape_list_.reserve(input_num_); + for (size_t i = 0; i < input_num_; i++) { + auto input_shape_i = inputs[i]->GetShapeVector(); + auto flat_shape = CPUKernelUtils::FlatShapeByAxis(input_shape_i, axis); + (void)input_flat_shape_list_.emplace_back(flat_shape); + } + return KRET_OK; +} + +template +bool ParallelConcatCpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { + std::vector input_addr_list; + for (size_t j = 0; j < input_num_; ++j) { + auto *tmp_addr = static_cast(inputs[j]->device_ptr()); + MS_EXCEPTION_IF_NULL(tmp_addr); + (void)input_addr_list.emplace_back(tmp_addr); + } + auto *output_addr = static_cast(outputs[0]->device_ptr()); + MS_EXCEPTION_IF_NULL(output_addr); + + size_t output_dim_1 = 0; + for (size_t j = 0; j < input_num_; ++j) { + output_dim_1 += LongToSize(input_flat_shape_list_[j][1]); + } + + // each input's row of shape after flat are same + auto before_axis = LongToSize(input_flat_shape_list_[0][0]); + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; ++i) { + auto output_ptr = output_addr + i * output_dim_1; + for (size_t j = 0; j < input_num_; ++j) { + if (input_flat_shape_list_[j][1] == 0) { + continue; + } + auto copy_num = LongToSize(input_flat_shape_list_[j][1]); + auto copy_size = copy_num * sizeof(T); + auto offset = copy_num * i; + auto ret = memcpy_s(output_ptr, copy_size, input_addr_list[j] + offset, copy_size); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memcpy failed. Error no: " << ret; + } + output_ptr += copy_num; + } + } + }; + ParallelLaunchAutoSearch(task, before_axis, this, ¶llel_search_info_); + return true; +} + +std::vector> ParallelConcatCpuKernelMod::func_list_ = { + {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), + &ParallelConcatCpuKernelMod::LaunchKernel}, + {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), + &ParallelConcatCpuKernelMod::LaunchKernel}, + {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), + &ParallelConcatCpuKernelMod::LaunchKernel}, + {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), + &ParallelConcatCpuKernelMod::LaunchKernel}, + {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), + &ParallelConcatCpuKernelMod::LaunchKernel}, + {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), + &ParallelConcatCpuKernelMod::LaunchKernel}, + {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + &ParallelConcatCpuKernelMod::LaunchKernel}, + {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + &ParallelConcatCpuKernelMod::LaunchKernel}, + {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + &ParallelConcatCpuKernelMod::LaunchKernel}, + {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + &ParallelConcatCpuKernelMod::LaunchKernel}, + {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + &ParallelConcatCpuKernelMod::LaunchKernel}, + {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), + &ParallelConcatCpuKernelMod::LaunchKernel}, + {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), + &ParallelConcatCpuKernelMod::LaunchKernel}, + {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), + &ParallelConcatCpuKernelMod::LaunchKernel}}; + +std::vector ParallelConcatCpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, ParallelConcat, ParallelConcatCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/parallel_concat_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/parallel_concat_cpu_kernel.h index 59e477b1183..c72ba9f2e82 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/parallel_concat_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/parallel_concat_cpu_kernel.h @@ -1,71 +1,71 @@ -/** - * Copyright 2020-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PARALLEL_CONCAT_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PARALLEL_CONCAT_CPU_KERNEL_H_ - -#include -#include -#include -#include -#include -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore { -namespace kernel { -using complex64 = std::complex; -using complex128 = std::complex; - -class ParallelConcatCpuKernelMod : public NativeCpuKernelMod { - public: - ParallelConcatCpuKernelMod() = default; - ~ParallelConcatCpuKernelMod() override = default; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override { - return kernel_func_(this, inputs, outputs); - } - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - protected: - void ResetResource() noexcept { - input_num_ = 0; - input_flat_shape_list_.clear(); - } - std::vector GetOpSupport() override; - - private: - template - bool LaunchKernel(const std::vector &inputs, - const std::vector &outputs); - using PCFunc = std::function &, - const std::vector &)>; - - private: - size_t input_num_; - PCFunc kernel_func_; - bool is_null_input_{false}; - std::vector input_flat_shape_list_; - static std::vector> func_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PARALLEL_CONCAT_CPU_KERNEL_H_ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PARALLEL_CONCAT_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PARALLEL_CONCAT_CPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +using complex64 = std::complex; +using complex128 = std::complex; + +class ParallelConcatCpuKernelMod : public NativeCpuKernelMod { + public: + ParallelConcatCpuKernelMod() = default; + ~ParallelConcatCpuKernelMod() override = default; + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override { + return kernel_func_(this, inputs, outputs); + } + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + protected: + void ResetResource() noexcept { + input_num_ = 0; + input_flat_shape_list_.clear(); + } + std::vector GetOpSupport() override; + + private: + template + bool LaunchKernel(const std::vector &inputs, + const std::vector &outputs); + using PCFunc = std::function &, + const std::vector &)>; + + private: + size_t input_num_; + PCFunc kernel_func_; + bool is_null_input_{false}; + std::vector input_flat_shape_list_; + static std::vector> func_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PARALLEL_CONCAT_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/pyboost/customize/non_zero.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/pyboost/customize/non_zero.cc old mode 100755 new mode 100644 diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/pyboost/customize/non_zero.h b/mindspore/ccsrc/plugin/device/cpu/kernel/pyboost/customize/non_zero.h old mode 100755 new mode 100644 diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/scale_grad_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/scale_grad_cpu_kernel.cc index b9d0683451b..61c6a1d0148 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/scale_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/scale_grad_cpu_kernel.cc @@ -1,98 +1,98 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/cpu/kernel/scale_grad_cpu_kernel.h" -#include -#include -#include "mindspore/core/ops/fusion/scale_grad_fusion.h" -#include "plugin/device/cpu/hal/device/cpu_device_address.h" - -namespace mindspore::kernel { -template -void ScaleGradCpuKernelMod::LaunchScaleGradPerGrad(const std::vector &inputs, - const std::vector &outputs, - const float16 *scale_addr_half, const float *scale_addr_float, - size_t index) { - T *input_addr = GetDeviceAddress(inputs, index); - T *output_addr = GetDeviceAddress(outputs, index); - T x1; - if (scale_addr_half != nullptr) { - x1 = static_cast(*scale_addr_half); - } else { - MS_EXCEPTION_IF_NULL(scale_addr_float); - x1 = static_cast(*scale_addr_float); - } - - size_t lens = outputs[index]->size() > 0 ? static_cast(outputs[index]->size() / sizeof(T)) : 1; - auto task = [input_addr, x1, output_addr](size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - output_addr[i] = input_addr[i] * x1; - } - }; - ParallelLaunchAutoSearch(task, lens, this, ¶llel_search_info_); -} - -bool ScaleGradCpuKernelMod::Launch(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs) { - float16 *scale_addr_half = nullptr; - float *scale_addr_float = nullptr; - if (input_info_.back() == kNumberTypeFloat16) { - scale_addr_half = GetDeviceAddress(inputs, inputs.size() - 1); - } else { - scale_addr_float = GetDeviceAddress(inputs, inputs.size() - 1); - } - - for (size_t i = 0; i < inputs.size() - 1; i++) { - switch (input_info_[i]) { - case kNumberTypeFloat16: { - LaunchScaleGradPerGrad(inputs, outputs, scale_addr_half, scale_addr_float, i); - break; - } - case kNumberTypeFloat32: { - LaunchScaleGradPerGrad(inputs, outputs, scale_addr_half, scale_addr_float, i); - break; - } - default: - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the typeid cannot be " << input_info_[i]; - } - } - return true; -} - -std::vector ScaleGradCpuKernelMod::GetOpSupport() { - std::vector support_list; - support_list.push_back(KernelAttr().AddSkipCheckAttr(true)); - return support_list; -} - -bool ScaleGradCpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - auto input_size = inputs.size(); - for (size_t index = 0; index < input_size; index++) { - auto type_id = inputs[index]->dtype_id(); - input_info_.push_back(type_id); - - if (index < input_size - 1) { - output_size_list_.push_back(inputs[index]->size()); - } - } - - return true; -} - -MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, ScaleGrad, ScaleGradCpuKernelMod); -} // namespace mindspore::kernel +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/cpu/kernel/scale_grad_cpu_kernel.h" +#include +#include +#include "mindspore/core/ops/fusion/scale_grad_fusion.h" +#include "plugin/device/cpu/hal/device/cpu_device_address.h" + +namespace mindspore::kernel { +template +void ScaleGradCpuKernelMod::LaunchScaleGradPerGrad(const std::vector &inputs, + const std::vector &outputs, + const float16 *scale_addr_half, const float *scale_addr_float, + size_t index) { + T *input_addr = GetDeviceAddress(inputs, index); + T *output_addr = GetDeviceAddress(outputs, index); + T x1; + if (scale_addr_half != nullptr) { + x1 = static_cast(*scale_addr_half); + } else { + MS_EXCEPTION_IF_NULL(scale_addr_float); + x1 = static_cast(*scale_addr_float); + } + + size_t lens = outputs[index]->size() > 0 ? static_cast(outputs[index]->size() / sizeof(T)) : 1; + auto task = [input_addr, x1, output_addr](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + output_addr[i] = input_addr[i] * x1; + } + }; + ParallelLaunchAutoSearch(task, lens, this, ¶llel_search_info_); +} + +bool ScaleGradCpuKernelMod::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + float16 *scale_addr_half = nullptr; + float *scale_addr_float = nullptr; + if (input_info_.back() == kNumberTypeFloat16) { + scale_addr_half = GetDeviceAddress(inputs, inputs.size() - 1); + } else { + scale_addr_float = GetDeviceAddress(inputs, inputs.size() - 1); + } + + for (size_t i = 0; i < inputs.size() - 1; i++) { + switch (input_info_[i]) { + case kNumberTypeFloat16: { + LaunchScaleGradPerGrad(inputs, outputs, scale_addr_half, scale_addr_float, i); + break; + } + case kNumberTypeFloat32: { + LaunchScaleGradPerGrad(inputs, outputs, scale_addr_half, scale_addr_float, i); + break; + } + default: + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the typeid cannot be " << input_info_[i]; + } + } + return true; +} + +std::vector ScaleGradCpuKernelMod::GetOpSupport() { + std::vector support_list; + support_list.push_back(KernelAttr().AddSkipCheckAttr(true)); + return support_list; +} + +bool ScaleGradCpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + auto input_size = inputs.size(); + for (size_t index = 0; index < input_size; index++) { + auto type_id = inputs[index]->dtype_id(); + input_info_.push_back(type_id); + + if (index < input_size - 1) { + output_size_list_.push_back(inputs[index]->size()); + } + } + + return true; +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, ScaleGrad, ScaleGradCpuKernelMod); +} // namespace mindspore::kernel diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/scale_grad_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/scale_grad_cpu_kernel.h index 719412b1085..fdb9e7b2366 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/scale_grad_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/scale_grad_cpu_kernel.h @@ -1,47 +1,47 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SCALE_GRAD_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SCALE_GRAD_CPU_KERNEL_H_ - -#include -#include -#include -#include -#include -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore::kernel { -class ScaleGradCpuKernelMod : public NativeCpuKernelMod { - public: - ScaleGradCpuKernelMod() = default; - ~ScaleGradCpuKernelMod() override = default; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - std::vector GetOpSupport() override; - - private: - template - void LaunchScaleGradPerGrad(const std::vector &inputs, const std::vector &outputs, - const float16 *scale_addr_half, const float *scale_addr_float, size_t index); - std::vector input_info_; -}; -} // namespace mindspore::kernel - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SCALE_GRAD_CPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SCALE_GRAD_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SCALE_GRAD_CPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore::kernel { +class ScaleGradCpuKernelMod : public NativeCpuKernelMod { + public: + ScaleGradCpuKernelMod() = default; + ~ScaleGradCpuKernelMod() override = default; + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + std::vector GetOpSupport() override; + + private: + template + void LaunchScaleGradPerGrad(const std::vector &inputs, const std::vector &outputs, + const float16 *scale_addr_half, const float *scale_addr_float, size_t index); + std::vector input_info_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SCALE_GRAD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/self_adjoint_eig_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/self_adjoint_eig_cpu_kernel.cc index c8d2244b587..3280d2245f1 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/self_adjoint_eig_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/self_adjoint_eig_cpu_kernel.cc @@ -1,171 +1,171 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/cpu/kernel/self_adjoint_eig_cpu_kernel.h" -#include -#include -#include "plugin/device/cpu/hal/device/cpu_device_address.h" -#include "Eigen/Dense" - -namespace mindspore::kernel { -constexpr auto kSelfAdjopintEig = "SelfAdjopintEig"; -constexpr const size_t kInputsNum = 1; -constexpr const size_t kOutputsNum = 2; - -bool SelfAdjointEigCpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_); - - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto is_match = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match.first) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr; - } - - dtype_ = inputs[kIndex0]->dtype_id(); - compute_v_ = GetValue(primitive_->GetAttr("compute_v")); - - return true; -} - -int SelfAdjointEigCpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { - return ret; - } - input_shape_ = inputs[kIndex0]->GetShapeVector(); - - return KRET_OK; -} - -bool SelfAdjointEigCpuKernelMod::Launch(const std::vector &inputs, - const std::vector &, - const std::vector &outputs) { - if (dtype_ == kNumberTypeFloat32) { - (void)LaunchKernel(inputs, outputs); - } else if (dtype_ == kNumberTypeFloat64) { - (void)LaunchKernel(inputs, outputs); - } else if (dtype_ == kNumberTypeComplex64) { - (void)LaunchKernel>(inputs, outputs); - } else if (dtype_ == kNumberTypeComplex128) { - (void)LaunchKernel>(inputs, outputs); - } else { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dtype of x must be float32 or float64, but got " - << TypeIdLabel(dtype_) << "."; - } - return true; -} - -template -bool SelfAdjointEigCpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &outputs) { - auto *input = reinterpret_cast(inputs[kIndex0]->device_ptr()); - auto *output0 = reinterpret_cast(outputs[kIndex0]->device_ptr()); - auto *output1 = reinterpret_cast(outputs[kIndex1]->device_ptr()); - bool attr0_ = compute_v_; - // The size of each dimension - std::vector shape = input_shape_; - // rank - auto input_dims = input_shape_.size(); - // Total number of elements - size_t input_numelements = static_cast(inputs[0]->size() / sizeof(T)); - // The length of the line - const int32_t m = shape[input_dims - 1]; - // The length of the column - const int32_t n = shape[input_dims - 2]; - auto num_array = (SizeToLong(input_numelements)) / (m * n); - using MatrixMap = Eigen::Map>; - const size_t input_dim_min = 2; - if (input_dims <= input_dim_min) { - MatrixMap Input0(input, m, n); - MatrixMap Output0(output0, m, 1); - MatrixMap Output1(output1, m, n); - Eigen::SelfAdjointEigenSolver> es( - Input0, attr0_ ? Eigen::ComputeEigenvectors : Eigen::EigenvaluesOnly); - Output0 = es.eigenvalues().template cast(); - for (int64_t i = 0; i < m; i++) { - *(output0 + i) = Output0(i, 0); - } - if (attr0_) { - Output1 = es.eigenvectors(); - for (int64_t i = 0; i < m; i++) { - for (int64_t j = 0; j < m; j++) { - *(output1 + i * m + j) = Output1(i, j); - } - } - } - } else { - for (int64_t batch = 0; batch < num_array; ++batch) { - T *A = static_cast(new T[m * n]); - T *B = static_cast(new T[m]); - T *C = static_cast(new T[m * n]); - // Get the address of the input and output matrix for each batch - for (int64_t i = 0; i < m * n; ++i) { - A[i] = input[batch * m * n + i]; - C[i] = output1[batch * m * n + i]; - } - for (int64_t i = 0; i < m; ++i) { - B[i] = output0[batch * m + i]; - } - MatrixMap Input0(A, m, n); - MatrixMap Output0(B, m, 1); - MatrixMap Output1(C, m, n); - Eigen::SelfAdjointEigenSolver> es( - Input0, attr0_ ? Eigen::ComputeEigenvectors : Eigen::EigenvaluesOnly); - Output0 = es.eigenvalues().template cast(); - for (int64_t i = 0; i < m; i++) { - *(output0 + batch * n + i) = Output0(i, 0); - } - if (attr0_) { - Output1 = es.eigenvectors(); - for (int64_t i = 0; i < m; i++) { - for (int64_t j = 0; j < m; j++) { - *(output1 + batch * m * n + i * m + j) = Output1(i, j); - } - } - } - } - } - return true; -} -std::vector> - SelfAdjointEigCpuKernelMod::func_list_ = { - {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - &SelfAdjointEigCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), - &SelfAdjointEigCpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeComplex64) - .AddOutputAttr(kNumberTypeComplex64) - .AddOutputAttr(kNumberTypeComplex64), - &SelfAdjointEigCpuKernelMod::LaunchKernel>}, - {KernelAttr() - .AddInputAttr(kNumberTypeComplex128) - .AddOutputAttr(kNumberTypeComplex128) - .AddOutputAttr(kNumberTypeComplex128), - &SelfAdjointEigCpuKernelMod::LaunchKernel>}}; - -std::vector SelfAdjointEigCpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - - return support_list; -} - -MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SelfAdjointEig, SelfAdjointEigCpuKernelMod); -} // namespace mindspore::kernel +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/cpu/kernel/self_adjoint_eig_cpu_kernel.h" +#include +#include +#include "plugin/device/cpu/hal/device/cpu_device_address.h" +#include "Eigen/Dense" + +namespace mindspore::kernel { +constexpr auto kSelfAdjopintEig = "SelfAdjopintEig"; +constexpr const size_t kInputsNum = 1; +constexpr const size_t kOutputsNum = 2; + +bool SelfAdjointEigCpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_); + + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto is_match = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match.first) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr; + } + + dtype_ = inputs[kIndex0]->dtype_id(); + compute_v_ = GetValue(primitive_->GetAttr("compute_v")); + + return true; +} + +int SelfAdjointEigCpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { + return ret; + } + input_shape_ = inputs[kIndex0]->GetShapeVector(); + + return KRET_OK; +} + +bool SelfAdjointEigCpuKernelMod::Launch(const std::vector &inputs, + const std::vector &, + const std::vector &outputs) { + if (dtype_ == kNumberTypeFloat32) { + (void)LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeFloat64) { + (void)LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeComplex64) { + (void)LaunchKernel>(inputs, outputs); + } else if (dtype_ == kNumberTypeComplex128) { + (void)LaunchKernel>(inputs, outputs); + } else { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dtype of x must be float32 or float64, but got " + << TypeIdLabel(dtype_) << "."; + } + return true; +} + +template +bool SelfAdjointEigCpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { + auto *input = reinterpret_cast(inputs[kIndex0]->device_ptr()); + auto *output0 = reinterpret_cast(outputs[kIndex0]->device_ptr()); + auto *output1 = reinterpret_cast(outputs[kIndex1]->device_ptr()); + bool attr0_ = compute_v_; + // The size of each dimension + std::vector shape = input_shape_; + // rank + auto input_dims = input_shape_.size(); + // Total number of elements + size_t input_numelements = static_cast(inputs[0]->size() / sizeof(T)); + // The length of the line + const int32_t m = shape[input_dims - 1]; + // The length of the column + const int32_t n = shape[input_dims - 2]; + auto num_array = (SizeToLong(input_numelements)) / (m * n); + using MatrixMap = Eigen::Map>; + const size_t input_dim_min = 2; + if (input_dims <= input_dim_min) { + MatrixMap Input0(input, m, n); + MatrixMap Output0(output0, m, 1); + MatrixMap Output1(output1, m, n); + Eigen::SelfAdjointEigenSolver> es( + Input0, attr0_ ? Eigen::ComputeEigenvectors : Eigen::EigenvaluesOnly); + Output0 = es.eigenvalues().template cast(); + for (int64_t i = 0; i < m; i++) { + *(output0 + i) = Output0(i, 0); + } + if (attr0_) { + Output1 = es.eigenvectors(); + for (int64_t i = 0; i < m; i++) { + for (int64_t j = 0; j < m; j++) { + *(output1 + i * m + j) = Output1(i, j); + } + } + } + } else { + for (int64_t batch = 0; batch < num_array; ++batch) { + T *A = static_cast(new T[m * n]); + T *B = static_cast(new T[m]); + T *C = static_cast(new T[m * n]); + // Get the address of the input and output matrix for each batch + for (int64_t i = 0; i < m * n; ++i) { + A[i] = input[batch * m * n + i]; + C[i] = output1[batch * m * n + i]; + } + for (int64_t i = 0; i < m; ++i) { + B[i] = output0[batch * m + i]; + } + MatrixMap Input0(A, m, n); + MatrixMap Output0(B, m, 1); + MatrixMap Output1(C, m, n); + Eigen::SelfAdjointEigenSolver> es( + Input0, attr0_ ? Eigen::ComputeEigenvectors : Eigen::EigenvaluesOnly); + Output0 = es.eigenvalues().template cast(); + for (int64_t i = 0; i < m; i++) { + *(output0 + batch * n + i) = Output0(i, 0); + } + if (attr0_) { + Output1 = es.eigenvectors(); + for (int64_t i = 0; i < m; i++) { + for (int64_t j = 0; j < m; j++) { + *(output1 + batch * m * n + i * m + j) = Output1(i, j); + } + } + } + } + } + return true; +} +std::vector> + SelfAdjointEigCpuKernelMod::func_list_ = { + {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + &SelfAdjointEigCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + &SelfAdjointEigCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeComplex64) + .AddOutputAttr(kNumberTypeComplex64) + .AddOutputAttr(kNumberTypeComplex64), + &SelfAdjointEigCpuKernelMod::LaunchKernel>}, + {KernelAttr() + .AddInputAttr(kNumberTypeComplex128) + .AddOutputAttr(kNumberTypeComplex128) + .AddOutputAttr(kNumberTypeComplex128), + &SelfAdjointEigCpuKernelMod::LaunchKernel>}}; + +std::vector SelfAdjointEigCpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SelfAdjointEig, SelfAdjointEigCpuKernelMod); +} // namespace mindspore::kernel diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/self_adjoint_eig_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/self_adjoint_eig_cpu_kernel.h index 27180f4cace..efc7ceae389 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/self_adjoint_eig_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/self_adjoint_eig_cpu_kernel.h @@ -1,55 +1,55 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SELF_ADJOINT_EIG_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SELF_ADJOINT_EIG_CPU_KERNEL_H_ - -#include -#include -#include -#include -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore { -namespace kernel { -class SelfAdjointEigCpuKernelMod : public NativeCpuKernelMod { - public: - SelfAdjointEigCpuKernelMod() = default; - ~SelfAdjointEigCpuKernelMod() override = default; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - int Resize(const std::vector &inputs, const std::vector &outputs) override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - protected: - std::vector GetOpSupport() override; - - private: - template - bool LaunchKernel(const std::vector &inputs, const std::vector &outputs); - using SelfAdjointEigLaunchFunc = - std::function &, - const std::vector &)>; - static std::vector> func_list_; - TypeId dtype_{kTypeUnknown}; - std::vector input_shape_; - bool compute_v_; -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SELF_ADJOINT_EIG_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SELF_ADJOINT_EIG_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SELF_ADJOINT_EIG_CPU_KERNEL_H_ + +#include +#include +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class SelfAdjointEigCpuKernelMod : public NativeCpuKernelMod { + public: + SelfAdjointEigCpuKernelMod() = default; + ~SelfAdjointEigCpuKernelMod() override = default; + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + int Resize(const std::vector &inputs, const std::vector &outputs) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + protected: + std::vector GetOpSupport() override; + + private: + template + bool LaunchKernel(const std::vector &inputs, const std::vector &outputs); + using SelfAdjointEigLaunchFunc = + std::function &, + const std::vector &)>; + static std::vector> func_list_; + TypeId dtype_{kTypeUnknown}; + std::vector input_shape_; + bool compute_v_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SELF_ADJOINT_EIG_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/sequence/sequence_stack_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/sequence/sequence_stack_cpu_kernel.cc index 9b01e97e374..b9abd3fec24 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/sequence/sequence_stack_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/sequence/sequence_stack_cpu_kernel.cc @@ -1,127 +1,127 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/cpu/kernel/sequence/sequence_stack_cpu_kernel.h" -#include -#include -#include -#include -#include -#include -#include "plugin/device/cpu/hal/device/cpu_device_address.h" -#include "plugin/device/cpu/kernel/nnacl/fp32/add_fp32.h" -#include "plugin/device/cpu/kernel/nnacl/errorcode.h" -#include "utils/ms_utils.h" -#include "include/common/thread_pool.h" -#include "mindspore/core/ops/sequence_stack.h" - -namespace mindspore { -namespace kernel { -namespace { -constexpr size_t kSequenceStackOutputsNum = 1; - -using complex64 = std::complex; -using complex128 = std::complex; -} // namespace - -bool SequenceStackFwdCpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSequenceStackOutputsNum, kernel_name_); - return MatchKernelFunc(kernel_name_, inputs, outputs); -} - -int SequenceStackFwdCpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - int ret = KernelMod::Resize(inputs, outputs); - if (ret != 0) { - return ret; - } - tuple_shape_ = inputs[0]->GetShapeVector(); - if (tuple_shape_.empty()) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << " the input tuple size must greater 0"; - } - std::vector shape_vec_item; - std::copy(tuple_shape_.begin() + 1, tuple_shape_.end(), std::back_inserter(shape_vec_item)); - - input_num_ = tuple_shape_[0]; - axis_ = GetValue(primitive_->GetAttr(ops::kAxis)); - if (axis_ < 0) { - axis_ += (SizeToInt(shape_vec_item.size()) + 1); - } - dims_behind_axis_ = 1; - // calculate elements while dim >= axis - for (size_t i = IntToSize(axis_); i < shape_vec_item.size(); i++) { - dims_behind_axis_ *= static_cast(shape_vec_item[i]); - } - auto output_shape = outputs.at(kIndex0)->GetShapeVector(); - output_size_ = 1; - for (size_t i = 0; i < output_shape.size(); i++) { - output_size_ *= static_cast(output_shape[i]); - } - return KRET_OK; -} - -template -bool SequenceStackFwdCpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs) { - const auto input_addr = GetDeviceAddress(inputs, 0); - auto output_addr = GetDeviceAddress(outputs, 0); - - size_t element_index_size = - static_cast(std::accumulate(tuple_shape_.begin() + 1, tuple_shape_.end(), 1, std::multiplies())); - - // multi-threading - size_t input_size = output_size_; - size_t dims_behind_axis = dims_behind_axis_; - size_t copy_time = input_size / dims_behind_axis; - size_t single_copy_size = dims_behind_axis * sizeof(T); - auto task = [&](size_t start, size_t end) { - for (size_t pos = start; pos < end; ++pos) { - size_t cur_input_index = pos % this->input_num_; - size_t local_idx = pos / this->input_num_; - auto ret = - memcpy_s(output_addr + dims_behind_axis * pos, single_copy_size, - input_addr + cur_input_index * element_index_size + dims_behind_axis * local_idx, single_copy_size); - if (ret != EOK) { - MS_LOG(EXCEPTION) << "memcpy_s failed: " << ret; - } - } - }; - ParallelLaunchAutoSearch(task, copy_time, this, ¶llel_search_info_); - return true; -} - -#define SEQUENCE_STACK_REG(ms_type, builtin_type) \ - { \ - KernelAttr().AddInputAttr(kObjectTypeTuple, ms_type).AddOutputAttr(ms_type), \ - &SequenceStackFwdCpuKernelMod::LaunchKernel \ - } - -const SequenceStackFwdCpuKernelMod::FuncList &SequenceStackFwdCpuKernelMod::GetFuncList() const { - static const FuncList func_list = { - SEQUENCE_STACK_REG(kNumberTypeInt8, int8_t), SEQUENCE_STACK_REG(kNumberTypeInt16, int16_t), - SEQUENCE_STACK_REG(kNumberTypeInt32, int32_t), SEQUENCE_STACK_REG(kNumberTypeInt64, int64_t), - SEQUENCE_STACK_REG(kNumberTypeUInt8, uint8_t), SEQUENCE_STACK_REG(kNumberTypeUInt16, uint16_t), - SEQUENCE_STACK_REG(kNumberTypeUInt32, uint32_t), SEQUENCE_STACK_REG(kNumberTypeUInt64, uint64_t), - SEQUENCE_STACK_REG(kNumberTypeFloat16, float16), SEQUENCE_STACK_REG(kNumberTypeFloat32, float), - SEQUENCE_STACK_REG(kNumberTypeFloat64, double), SEQUENCE_STACK_REG(kNumberTypeComplex64, complex64), - SEQUENCE_STACK_REG(kNumberTypeComplex128, complex128), SEQUENCE_STACK_REG(kNumberTypeBool, bool)}; - return func_list; -} -MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SequenceStack, SequenceStackFwdCpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/cpu/kernel/sequence/sequence_stack_cpu_kernel.h" +#include +#include +#include +#include +#include +#include +#include "plugin/device/cpu/hal/device/cpu_device_address.h" +#include "plugin/device/cpu/kernel/nnacl/fp32/add_fp32.h" +#include "plugin/device/cpu/kernel/nnacl/errorcode.h" +#include "utils/ms_utils.h" +#include "include/common/thread_pool.h" +#include "mindspore/core/ops/sequence_stack.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kSequenceStackOutputsNum = 1; + +using complex64 = std::complex; +using complex128 = std::complex; +} // namespace + +bool SequenceStackFwdCpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSequenceStackOutputsNum, kernel_name_); + return MatchKernelFunc(kernel_name_, inputs, outputs); +} + +int SequenceStackFwdCpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + int ret = KernelMod::Resize(inputs, outputs); + if (ret != 0) { + return ret; + } + tuple_shape_ = inputs[0]->GetShapeVector(); + if (tuple_shape_.empty()) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << " the input tuple size must greater 0"; + } + std::vector shape_vec_item; + std::copy(tuple_shape_.begin() + 1, tuple_shape_.end(), std::back_inserter(shape_vec_item)); + + input_num_ = tuple_shape_[0]; + axis_ = GetValue(primitive_->GetAttr(ops::kAxis)); + if (axis_ < 0) { + axis_ += (SizeToInt(shape_vec_item.size()) + 1); + } + dims_behind_axis_ = 1; + // calculate elements while dim >= axis + for (size_t i = IntToSize(axis_); i < shape_vec_item.size(); i++) { + dims_behind_axis_ *= static_cast(shape_vec_item[i]); + } + auto output_shape = outputs.at(kIndex0)->GetShapeVector(); + output_size_ = 1; + for (size_t i = 0; i < output_shape.size(); i++) { + output_size_ *= static_cast(output_shape[i]); + } + return KRET_OK; +} + +template +bool SequenceStackFwdCpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + const auto input_addr = GetDeviceAddress(inputs, 0); + auto output_addr = GetDeviceAddress(outputs, 0); + + size_t element_index_size = + static_cast(std::accumulate(tuple_shape_.begin() + 1, tuple_shape_.end(), 1, std::multiplies())); + + // multi-threading + size_t input_size = output_size_; + size_t dims_behind_axis = dims_behind_axis_; + size_t copy_time = input_size / dims_behind_axis; + size_t single_copy_size = dims_behind_axis * sizeof(T); + auto task = [&](size_t start, size_t end) { + for (size_t pos = start; pos < end; ++pos) { + size_t cur_input_index = pos % this->input_num_; + size_t local_idx = pos / this->input_num_; + auto ret = + memcpy_s(output_addr + dims_behind_axis * pos, single_copy_size, + input_addr + cur_input_index * element_index_size + dims_behind_axis * local_idx, single_copy_size); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "memcpy_s failed: " << ret; + } + } + }; + ParallelLaunchAutoSearch(task, copy_time, this, ¶llel_search_info_); + return true; +} + +#define SEQUENCE_STACK_REG(ms_type, builtin_type) \ + { \ + KernelAttr().AddInputAttr(kObjectTypeTuple, ms_type).AddOutputAttr(ms_type), \ + &SequenceStackFwdCpuKernelMod::LaunchKernel \ + } + +const SequenceStackFwdCpuKernelMod::FuncList &SequenceStackFwdCpuKernelMod::GetFuncList() const { + static const FuncList func_list = { + SEQUENCE_STACK_REG(kNumberTypeInt8, int8_t), SEQUENCE_STACK_REG(kNumberTypeInt16, int16_t), + SEQUENCE_STACK_REG(kNumberTypeInt32, int32_t), SEQUENCE_STACK_REG(kNumberTypeInt64, int64_t), + SEQUENCE_STACK_REG(kNumberTypeUInt8, uint8_t), SEQUENCE_STACK_REG(kNumberTypeUInt16, uint16_t), + SEQUENCE_STACK_REG(kNumberTypeUInt32, uint32_t), SEQUENCE_STACK_REG(kNumberTypeUInt64, uint64_t), + SEQUENCE_STACK_REG(kNumberTypeFloat16, float16), SEQUENCE_STACK_REG(kNumberTypeFloat32, float), + SEQUENCE_STACK_REG(kNumberTypeFloat64, double), SEQUENCE_STACK_REG(kNumberTypeComplex64, complex64), + SEQUENCE_STACK_REG(kNumberTypeComplex128, complex128), SEQUENCE_STACK_REG(kNumberTypeBool, bool)}; + return func_list; +} +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SequenceStack, SequenceStackFwdCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/sequence/sequence_stack_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/sequence/sequence_stack_cpu_kernel.h index a9c6ea9b2be..712d53cee26 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/sequence/sequence_stack_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/sequence/sequence_stack_cpu_kernel.h @@ -1,63 +1,63 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SEQUENCE_STACK_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SEQUENCE_STACK_CPU_KERNEL_H_ - -#include -#include -#include -#include -#include -#include -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore { -namespace kernel { -class SequenceStackFwdCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper { - public: - SequenceStackFwdCpuKernelMod() = default; - ~SequenceStackFwdCpuKernelMod() override = default; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override { - MS_EXCEPTION_IF_NULL(kernel_func_); - return kernel_func_(this, inputs, workspace, outputs); - } - - using FuncList = std::vector>; - const FuncList &GetFuncList() const override; - - protected: - std::vector GetOpSupport() override { return OpSupport(); } - - template - bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs); - std::vector tuple_shape_; - int axis_{0}; - size_t input_num_{1}; - size_t output_size_{0}; - size_t dims_behind_axis_{1}; -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SEQUENCE_STACK_CPU_KERNEL_H_ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SEQUENCE_STACK_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SEQUENCE_STACK_CPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class SequenceStackFwdCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper { + public: + SequenceStackFwdCpuKernelMod() = default; + ~SequenceStackFwdCpuKernelMod() override = default; + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override { + MS_EXCEPTION_IF_NULL(kernel_func_); + return kernel_func_(this, inputs, workspace, outputs); + } + + using FuncList = std::vector>; + const FuncList &GetFuncList() const override; + + protected: + std::vector GetOpSupport() override { return OpSupport(); } + + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs); + std::vector tuple_shape_; + int axis_{0}; + size_t input_num_{1}; + size_t output_size_{0}; + size_t dims_behind_axis_{1}; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SEQUENCE_STACK_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/sinc_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/sinc_cpu_kernel.cc index 1bf1eaa2063..155af4c4dcb 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/sinc_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/sinc_cpu_kernel.cc @@ -1,140 +1,140 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "plugin/device/cpu/kernel/sinc_cpu_kernel.h" - -namespace mindspore { -namespace kernel { -namespace { -constexpr size_t kSincInputsNum = 1; -constexpr size_t kSincOutputsNum = 1; -} // namespace - -bool SincCpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { - return MatchKernelFunc(kernel_name_, inputs, outputs); -} - -template -bool SincCpuKernelMod::LaunchSameKernel(const std::vector &inputs, - const std::vector &, - const std::vector &outputs) { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSincInputsNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSincOutputsNum, kernel_name_); - auto input = static_cast(inputs[0]->device_ptr()); - auto output = static_cast(outputs[0]->device_ptr()); - size_t total = inputs[0]->size() / sizeof(T); - auto task = [&input, &output](size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - if (static_cast(input[i]) == static_cast(0.0f)) { - output[i] = static_cast(1.0f); - } else { - T pi = static_cast(3.14159265358979323846L); - T product = pi * input[i]; - output[i] = sin(product) / product; - } - } - }; - ParallelLaunchAutoSearch(task, total, this, ¶llel_search_info_); - return true; -} - -template -bool SincCpuKernelMod::LaunchDiffKernel(const std::vector &inputs, - const std::vector &, - const std::vector &outputs) { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSincInputsNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSincOutputsNum, kernel_name_); - auto input = static_cast(inputs[0]->device_ptr()); - auto output = static_cast(outputs[0]->device_ptr()); - size_t total = inputs[0]->size() / sizeof(T); - auto task = [&input, &output](size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - if (input[i] == static_cast(0)) { - output[i] = static_cast(1.0f); - } else { - float pi = static_cast(3.14159265358979323846); - float product = pi * input[i]; - output[i] = sin(product) / product; - } - } - }; - ParallelLaunchAutoSearch(task, total, this, ¶llel_search_info_); - return true; -} - -template -bool SincCpuKernelMod::LaunchBoolKernel(const std::vector &inputs, - const std::vector &, - const std::vector &outputs) { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSincInputsNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSincOutputsNum, kernel_name_); - auto input = static_cast(inputs[0]->device_ptr()); - auto output = static_cast(outputs[0]->device_ptr()); - size_t total = inputs[0]->size() / sizeof(T); - auto task = [&input, &output](size_t start, size_t end) { - for (size_t i = start; i < end; i++) { - float tmp; - if (input[i] == true) { - tmp = 1.0f; - } else { - tmp = 0.0f; - } - float pi = 3.14159265358979323846; - float product = pi * tmp; - output[i] = sin(product) / product; - } - }; - ParallelLaunchAutoSearch(task, total, this, ¶llel_search_info_); - return true; -} - -const std::vector> &SincCpuKernelMod::GetFuncList() const { - static const std::vector> func_list = { - {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat32), - &SincCpuKernelMod::LaunchDiffKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat32), - &SincCpuKernelMod::LaunchDiffKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat32), - &SincCpuKernelMod::LaunchDiffKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat32), - &SincCpuKernelMod::LaunchDiffKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat32), - &SincCpuKernelMod::LaunchDiffKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), - &SincCpuKernelMod::LaunchDiffKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat32), - &SincCpuKernelMod::LaunchDiffKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), - &SincCpuKernelMod::LaunchDiffKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - &SincCpuKernelMod::LaunchSameKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - &SincCpuKernelMod::LaunchSameKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), - &SincCpuKernelMod::LaunchSameKernel}, - {KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), - &SincCpuKernelMod::LaunchSameKernel>}, - {KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), - &SincCpuKernelMod::LaunchSameKernel>}, - {KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat32), - &SincCpuKernelMod::LaunchBoolKernel}}; - return func_list; -} - -MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Sinc, SincCpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "plugin/device/cpu/kernel/sinc_cpu_kernel.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kSincInputsNum = 1; +constexpr size_t kSincOutputsNum = 1; +} // namespace + +bool SincCpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { + return MatchKernelFunc(kernel_name_, inputs, outputs); +} + +template +bool SincCpuKernelMod::LaunchSameKernel(const std::vector &inputs, + const std::vector &, + const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSincInputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSincOutputsNum, kernel_name_); + auto input = static_cast(inputs[0]->device_ptr()); + auto output = static_cast(outputs[0]->device_ptr()); + size_t total = inputs[0]->size() / sizeof(T); + auto task = [&input, &output](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + if (static_cast(input[i]) == static_cast(0.0f)) { + output[i] = static_cast(1.0f); + } else { + T pi = static_cast(3.14159265358979323846L); + T product = pi * input[i]; + output[i] = sin(product) / product; + } + } + }; + ParallelLaunchAutoSearch(task, total, this, ¶llel_search_info_); + return true; +} + +template +bool SincCpuKernelMod::LaunchDiffKernel(const std::vector &inputs, + const std::vector &, + const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSincInputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSincOutputsNum, kernel_name_); + auto input = static_cast(inputs[0]->device_ptr()); + auto output = static_cast(outputs[0]->device_ptr()); + size_t total = inputs[0]->size() / sizeof(T); + auto task = [&input, &output](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + if (input[i] == static_cast(0)) { + output[i] = static_cast(1.0f); + } else { + float pi = static_cast(3.14159265358979323846); + float product = pi * input[i]; + output[i] = sin(product) / product; + } + } + }; + ParallelLaunchAutoSearch(task, total, this, ¶llel_search_info_); + return true; +} + +template +bool SincCpuKernelMod::LaunchBoolKernel(const std::vector &inputs, + const std::vector &, + const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSincInputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSincOutputsNum, kernel_name_); + auto input = static_cast(inputs[0]->device_ptr()); + auto output = static_cast(outputs[0]->device_ptr()); + size_t total = inputs[0]->size() / sizeof(T); + auto task = [&input, &output](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + float tmp; + if (input[i] == true) { + tmp = 1.0f; + } else { + tmp = 0.0f; + } + float pi = 3.14159265358979323846; + float product = pi * tmp; + output[i] = sin(product) / product; + } + }; + ParallelLaunchAutoSearch(task, total, this, ¶llel_search_info_); + return true; +} + +const std::vector> &SincCpuKernelMod::GetFuncList() const { + static const std::vector> func_list = { + {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat32), + &SincCpuKernelMod::LaunchDiffKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat32), + &SincCpuKernelMod::LaunchDiffKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat32), + &SincCpuKernelMod::LaunchDiffKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat32), + &SincCpuKernelMod::LaunchDiffKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat32), + &SincCpuKernelMod::LaunchDiffKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + &SincCpuKernelMod::LaunchDiffKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat32), + &SincCpuKernelMod::LaunchDiffKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), + &SincCpuKernelMod::LaunchDiffKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + &SincCpuKernelMod::LaunchSameKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + &SincCpuKernelMod::LaunchSameKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + &SincCpuKernelMod::LaunchSameKernel}, + {KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), + &SincCpuKernelMod::LaunchSameKernel>}, + {KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), + &SincCpuKernelMod::LaunchSameKernel>}, + {KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat32), + &SincCpuKernelMod::LaunchBoolKernel}}; + return func_list; +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Sinc, SincCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/sinc_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/sinc_cpu_kernel.h index f58a8921a89..b41ac5d248e 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/sinc_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/sinc_cpu_kernel.h @@ -1,62 +1,62 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SINC_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SINC_CPU_KERNEL_H_ - -#include -#include - -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore { -namespace kernel { -class SincCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper { - public: - SincCpuKernelMod() = default; - ~SincCpuKernelMod() override = default; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override { - return kernel_func_(this, inputs, workspace, outputs); - } - - const std::vector> &GetFuncList() const override; - - protected: - std::vector GetOpSupport() override { return OpSupport(); } - - private: - template - bool LaunchSameKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs); - template - bool LaunchDiffKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs); - template - bool LaunchBoolKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs); -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SINC_CPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SINC_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SINC_CPU_KERNEL_H_ + +#include +#include + +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class SincCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper { + public: + SincCpuKernelMod() = default; + ~SincCpuKernelMod() override = default; + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override { + return kernel_func_(this, inputs, workspace, outputs); + } + + const std::vector> &GetFuncList() const override; + + protected: + std::vector GetOpSupport() override { return OpSupport(); } + + private: + template + bool LaunchSameKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs); + template + bool LaunchDiffKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs); + template + bool LaunchBoolKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs); +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SINC_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_addmm_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_addmm_cpu_kernel.cc index 3e7a835bbe3..95fb834bd40 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_addmm_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_addmm_cpu_kernel.cc @@ -1,379 +1,379 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/cpu/kernel/sparse_addmm_cpu_kernel.h" -#include -#include -#include - -namespace mindspore { -namespace kernel { -namespace { -constexpr size_t kSparseAddmmInputsNum = 7; -constexpr size_t kSparseAddmmOutputsNum = 1; -constexpr size_t kSparseAddmmOutputShapeSize = 2; -constexpr size_t kSparseAddmmDenseShapeSize = 2; -constexpr size_t kIndicesSizeNum = 2; -constexpr size_t kIndices2rdDimNum = 2; -constexpr size_t kShapeValue = 0; -constexpr size_t kIndex0 = 0; -constexpr size_t kIndex1 = 1; -constexpr size_t kIndex2 = 2; -constexpr size_t kIndex3 = 3; -constexpr size_t kIndex4 = 4; -constexpr size_t kIndex5 = 5; -constexpr size_t kIndex6 = 6; -} // namespace - -bool SparseAddmmCpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(EXCEPTION) << "SparseAddmm does not support this kernel data type: " << kernel_attr; - } - kernel_func_ = func_list_[index].second; - return true; -} - -int SparseAddmmCpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { - return ret; - } - auto indices_shape = inputs.at(kIndex0)->GetShapeVector(); - if (indices_shape.size() != kIndicesSizeNum && LongToSize(indices_shape[1]) != kIndices2rdDimNum) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ - << "', it requires 'indices' should be a 2-D Tensor and the second dimension length " - "should be 2, but got 'indices' shape: " - << indices_shape; - } - auto values_shape = inputs.at(kIndex1)->GetShapeVector(); - if (values_shape.size() != 1 || values_shape[0] != indices_shape[0]) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ - << "', it requires 'values' should be a 1-D Tensor and the first dimension length " - " should be equal to the first dimension length of 'indices', but got 'values' shape: " - << values_shape << " and 'indices' shape: " << indices_shape; - } - output_shape_ = Convert2SizeT(outputs[0]->GetShapeVector()); - values_size_ = LongToSize(values_shape[0]); - b_shape_ = Convert2SizeT(inputs.at(kIndex3)->GetShapeVector()); - if (b_shape_.size() != kSparseAddmmDenseShapeSize) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of 'dense' should be " - << kSparseAddmmDenseShapeSize << "-D, but got " << b_shape_.size() << "-D"; - } - if (output_shape_.size() != kSparseAddmmOutputShapeSize) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of output should be " - << kSparseAddmmOutputShapeSize << "-D, but got " << output_shape_.size() << "-D"; - } - return KRET_OK; -} - -template -bool SparseAddmmCpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &outputs) { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSparseAddmmInputsNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSparseAddmmOutputsNum, kernel_name_); - auto ret = memset_s(outputs[0]->device_ptr(), outputs[0]->size(), 0, outputs[0]->size()); - if (ret != EOK) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memset output failed. Error no: " << ret; - } - - auto *a_indices = static_cast(inputs[kIndex0]->device_ptr()); - auto *a_values = static_cast(inputs[kIndex1]->device_ptr()); - auto *x1_shape = static_cast(inputs[kIndex2]->device_ptr()); - auto *b = static_cast(inputs[kIndex3]->device_ptr()); - auto *c = static_cast(inputs[kIndex4]->device_ptr()); - auto *alpha = static_cast(inputs[kIndex5]->device_ptr()); - auto *beta = static_cast(inputs[kIndex6]->device_ptr()); - auto *out = static_cast(outputs[kIndex0]->device_ptr()); - - const size_t indices_length = inputs[kIndex0]->size() / sizeof(I); - const size_t values_length = inputs[kIndex1]->size() / sizeof(T); - const size_t b_length = inputs[kIndex3]->size() / sizeof(T); - - const size_t dim_num = 2; - const size_t out_dim_0 = output_shape_[0]; - const size_t out_dim_1 = output_shape_[1]; - const size_t b_dim_0 = b_shape_[0]; - const size_t b_dim_1 = b_shape_[1]; - const size_t same_dim = b_dim_0; - - const I x1_shape_0 = x1_shape[0]; - const I x1_shape_1 = x1_shape[1]; - - const size_t x1_shape_0_s = IntToSize(x1_shape_0); - const size_t x1_shape_1_s = IntToSize(x1_shape_1); - if (x1_shape_0_s <= kShapeValue || x1_shape_1_s <= kShapeValue) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the value of 'x1_shape' should be greater than 0."; - } - if (x1_shape_1_s != b_dim_0) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ - << "', the col of 'x1_shape' should be equal to the row of 'x2_dense'," - " but got col: " - << x1_shape_1_s << ", row: " << b_dim_0; - } - - for (size_t i = 0; i < values_size_; ++i) { - if (i * dim_num + 1 >= indices_length) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the index of 'indices' out of bounds."; - } - if (i >= values_length) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the index of 'values' out of bounds."; - } - - const int row = a_indices[i * dim_num]; - const int col = a_indices[i * dim_num + 1]; - if (row >= SizeToInt(out_dim_0) || row < 0 || col >= SizeToInt(same_dim) || col < 0) { - MS_EXCEPTION(ValueError) << "The indices including out of bounds index, row range: [0, " << out_dim_0 - << "), col range: [0, " << same_dim << "), but got row: " << row << ", col: " << col; - } - - const size_t row_s = IntToSize(row); - const size_t col_s = IntToSize(col); - const T alpha_value = *(alpha); - for (size_t n = 0; n < out_dim_1; ++n) { - if (col_s * b_dim_1 + n >= b_length) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the index of 'b' out of bounds."; - } - const T b_value = b[col_s * b_dim_1 + n]; - out[row_s * out_dim_1 + n] += alpha_value * a_values[i] * b_value; - } - } - - const T beta_value = *(beta); - for (size_t i = 0; i < out_dim_0; ++i) { - for (size_t j = 0; j < out_dim_1; ++j) { - const T c_value = c[i * out_dim_1 + j]; - out[i * out_dim_1 + j] += beta_value * c_value; - } - } - - return true; -} - -std::vector> SparseAddmmCpuKernelMod::func_list_ = { - {KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt8) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt8) - .AddInputAttr(kNumberTypeInt8) - .AddInputAttr(kNumberTypeInt8) - .AddInputAttr(kNumberTypeInt8) - .AddOutputAttr(kNumberTypeInt8), - &SparseAddmmCpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt8) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt8) - .AddInputAttr(kNumberTypeInt8) - .AddInputAttr(kNumberTypeInt8) - .AddInputAttr(kNumberTypeInt8) - .AddOutputAttr(kNumberTypeInt8), - &SparseAddmmCpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt16) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt16) - .AddInputAttr(kNumberTypeInt16) - .AddInputAttr(kNumberTypeInt16) - .AddInputAttr(kNumberTypeInt16) - .AddOutputAttr(kNumberTypeInt16), - &SparseAddmmCpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt16) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt16) - .AddInputAttr(kNumberTypeInt16) - .AddInputAttr(kNumberTypeInt16) - .AddInputAttr(kNumberTypeInt16) - .AddOutputAttr(kNumberTypeInt16), - &SparseAddmmCpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt32), - &SparseAddmmCpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt32), - &SparseAddmmCpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt64), - &SparseAddmmCpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt64), - &SparseAddmmCpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - &SparseAddmmCpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - &SparseAddmmCpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeFloat64) - .AddOutputAttr(kNumberTypeFloat64), - &SparseAddmmCpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeFloat64) - .AddOutputAttr(kNumberTypeFloat64), - &SparseAddmmCpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeUInt8) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeUInt8) - .AddInputAttr(kNumberTypeUInt8) - .AddInputAttr(kNumberTypeUInt8) - .AddInputAttr(kNumberTypeUInt8) - .AddOutputAttr(kNumberTypeUInt8), - &SparseAddmmCpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeUInt8) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeUInt8) - .AddInputAttr(kNumberTypeUInt8) - .AddInputAttr(kNumberTypeUInt8) - .AddInputAttr(kNumberTypeUInt8) - .AddOutputAttr(kNumberTypeUInt8), - &SparseAddmmCpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeUInt16) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeUInt16) - .AddInputAttr(kNumberTypeUInt16) - .AddInputAttr(kNumberTypeUInt16) - .AddInputAttr(kNumberTypeUInt16) - .AddOutputAttr(kNumberTypeUInt16), - &SparseAddmmCpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeUInt16) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeUInt16) - .AddInputAttr(kNumberTypeUInt16) - .AddInputAttr(kNumberTypeUInt16) - .AddInputAttr(kNumberTypeUInt16) - .AddOutputAttr(kNumberTypeUInt16), - &SparseAddmmCpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeUInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeUInt32) - .AddInputAttr(kNumberTypeUInt32) - .AddInputAttr(kNumberTypeUInt32) - .AddInputAttr(kNumberTypeUInt32) - .AddOutputAttr(kNumberTypeUInt32), - &SparseAddmmCpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeUInt32) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeUInt32) - .AddInputAttr(kNumberTypeUInt32) - .AddInputAttr(kNumberTypeUInt32) - .AddInputAttr(kNumberTypeUInt32) - .AddOutputAttr(kNumberTypeUInt32), - &SparseAddmmCpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeUInt64) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeUInt64) - .AddInputAttr(kNumberTypeUInt64) - .AddInputAttr(kNumberTypeUInt64) - .AddInputAttr(kNumberTypeUInt64) - .AddOutputAttr(kNumberTypeUInt64), - &SparseAddmmCpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeUInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeUInt64) - .AddInputAttr(kNumberTypeUInt64) - .AddInputAttr(kNumberTypeUInt64) - .AddInputAttr(kNumberTypeUInt64) - .AddOutputAttr(kNumberTypeUInt64), - &SparseAddmmCpuKernelMod::LaunchKernel}}; - -std::vector SparseAddmmCpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - return support_list; -} - -MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SparseAddmm, SparseAddmmCpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/cpu/kernel/sparse_addmm_cpu_kernel.h" +#include +#include +#include + +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kSparseAddmmInputsNum = 7; +constexpr size_t kSparseAddmmOutputsNum = 1; +constexpr size_t kSparseAddmmOutputShapeSize = 2; +constexpr size_t kSparseAddmmDenseShapeSize = 2; +constexpr size_t kIndicesSizeNum = 2; +constexpr size_t kIndices2rdDimNum = 2; +constexpr size_t kShapeValue = 0; +constexpr size_t kIndex0 = 0; +constexpr size_t kIndex1 = 1; +constexpr size_t kIndex2 = 2; +constexpr size_t kIndex3 = 3; +constexpr size_t kIndex4 = 4; +constexpr size_t kIndex5 = 5; +constexpr size_t kIndex6 = 6; +} // namespace + +bool SparseAddmmCpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(EXCEPTION) << "SparseAddmm does not support this kernel data type: " << kernel_attr; + } + kernel_func_ = func_list_[index].second; + return true; +} + +int SparseAddmmCpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { + return ret; + } + auto indices_shape = inputs.at(kIndex0)->GetShapeVector(); + if (indices_shape.size() != kIndicesSizeNum && LongToSize(indices_shape[1]) != kIndices2rdDimNum) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ + << "', it requires 'indices' should be a 2-D Tensor and the second dimension length " + "should be 2, but got 'indices' shape: " + << indices_shape; + } + auto values_shape = inputs.at(kIndex1)->GetShapeVector(); + if (values_shape.size() != 1 || values_shape[0] != indices_shape[0]) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ + << "', it requires 'values' should be a 1-D Tensor and the first dimension length " + " should be equal to the first dimension length of 'indices', but got 'values' shape: " + << values_shape << " and 'indices' shape: " << indices_shape; + } + output_shape_ = Convert2SizeT(outputs[0]->GetShapeVector()); + values_size_ = LongToSize(values_shape[0]); + b_shape_ = Convert2SizeT(inputs.at(kIndex3)->GetShapeVector()); + if (b_shape_.size() != kSparseAddmmDenseShapeSize) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of 'dense' should be " + << kSparseAddmmDenseShapeSize << "-D, but got " << b_shape_.size() << "-D"; + } + if (output_shape_.size() != kSparseAddmmOutputShapeSize) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of output should be " + << kSparseAddmmOutputShapeSize << "-D, but got " << output_shape_.size() << "-D"; + } + return KRET_OK; +} + +template +bool SparseAddmmCpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSparseAddmmInputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSparseAddmmOutputsNum, kernel_name_); + auto ret = memset_s(outputs[0]->device_ptr(), outputs[0]->size(), 0, outputs[0]->size()); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memset output failed. Error no: " << ret; + } + + auto *a_indices = static_cast(inputs[kIndex0]->device_ptr()); + auto *a_values = static_cast(inputs[kIndex1]->device_ptr()); + auto *x1_shape = static_cast(inputs[kIndex2]->device_ptr()); + auto *b = static_cast(inputs[kIndex3]->device_ptr()); + auto *c = static_cast(inputs[kIndex4]->device_ptr()); + auto *alpha = static_cast(inputs[kIndex5]->device_ptr()); + auto *beta = static_cast(inputs[kIndex6]->device_ptr()); + auto *out = static_cast(outputs[kIndex0]->device_ptr()); + + const size_t indices_length = inputs[kIndex0]->size() / sizeof(I); + const size_t values_length = inputs[kIndex1]->size() / sizeof(T); + const size_t b_length = inputs[kIndex3]->size() / sizeof(T); + + const size_t dim_num = 2; + const size_t out_dim_0 = output_shape_[0]; + const size_t out_dim_1 = output_shape_[1]; + const size_t b_dim_0 = b_shape_[0]; + const size_t b_dim_1 = b_shape_[1]; + const size_t same_dim = b_dim_0; + + const I x1_shape_0 = x1_shape[0]; + const I x1_shape_1 = x1_shape[1]; + + const size_t x1_shape_0_s = IntToSize(x1_shape_0); + const size_t x1_shape_1_s = IntToSize(x1_shape_1); + if (x1_shape_0_s <= kShapeValue || x1_shape_1_s <= kShapeValue) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the value of 'x1_shape' should be greater than 0."; + } + if (x1_shape_1_s != b_dim_0) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ + << "', the col of 'x1_shape' should be equal to the row of 'x2_dense'," + " but got col: " + << x1_shape_1_s << ", row: " << b_dim_0; + } + + for (size_t i = 0; i < values_size_; ++i) { + if (i * dim_num + 1 >= indices_length) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the index of 'indices' out of bounds."; + } + if (i >= values_length) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the index of 'values' out of bounds."; + } + + const int row = a_indices[i * dim_num]; + const int col = a_indices[i * dim_num + 1]; + if (row >= SizeToInt(out_dim_0) || row < 0 || col >= SizeToInt(same_dim) || col < 0) { + MS_EXCEPTION(ValueError) << "The indices including out of bounds index, row range: [0, " << out_dim_0 + << "), col range: [0, " << same_dim << "), but got row: " << row << ", col: " << col; + } + + const size_t row_s = IntToSize(row); + const size_t col_s = IntToSize(col); + const T alpha_value = *(alpha); + for (size_t n = 0; n < out_dim_1; ++n) { + if (col_s * b_dim_1 + n >= b_length) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the index of 'b' out of bounds."; + } + const T b_value = b[col_s * b_dim_1 + n]; + out[row_s * out_dim_1 + n] += alpha_value * a_values[i] * b_value; + } + } + + const T beta_value = *(beta); + for (size_t i = 0; i < out_dim_0; ++i) { + for (size_t j = 0; j < out_dim_1; ++j) { + const T c_value = c[i * out_dim_1 + j]; + out[i * out_dim_1 + j] += beta_value * c_value; + } + } + + return true; +} + +std::vector> SparseAddmmCpuKernelMod::func_list_ = { + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt8) + .AddOutputAttr(kNumberTypeInt8), + &SparseAddmmCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt8) + .AddOutputAttr(kNumberTypeInt8), + &SparseAddmmCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt16) + .AddOutputAttr(kNumberTypeInt16), + &SparseAddmmCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt16) + .AddOutputAttr(kNumberTypeInt16), + &SparseAddmmCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + &SparseAddmmCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + &SparseAddmmCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64), + &SparseAddmmCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64), + &SparseAddmmCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + &SparseAddmmCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + &SparseAddmmCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat64), + &SparseAddmmCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat64), + &SparseAddmmCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeUInt8) + .AddOutputAttr(kNumberTypeUInt8), + &SparseAddmmCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeUInt8) + .AddOutputAttr(kNumberTypeUInt8), + &SparseAddmmCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeUInt16) + .AddOutputAttr(kNumberTypeUInt16), + &SparseAddmmCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeUInt16) + .AddOutputAttr(kNumberTypeUInt16), + &SparseAddmmCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeUInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeUInt32) + .AddInputAttr(kNumberTypeUInt32) + .AddInputAttr(kNumberTypeUInt32) + .AddInputAttr(kNumberTypeUInt32) + .AddOutputAttr(kNumberTypeUInt32), + &SparseAddmmCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeUInt32) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeUInt32) + .AddInputAttr(kNumberTypeUInt32) + .AddInputAttr(kNumberTypeUInt32) + .AddInputAttr(kNumberTypeUInt32) + .AddOutputAttr(kNumberTypeUInt32), + &SparseAddmmCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeUInt64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeUInt64) + .AddInputAttr(kNumberTypeUInt64) + .AddInputAttr(kNumberTypeUInt64) + .AddInputAttr(kNumberTypeUInt64) + .AddOutputAttr(kNumberTypeUInt64), + &SparseAddmmCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeUInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeUInt64) + .AddInputAttr(kNumberTypeUInt64) + .AddInputAttr(kNumberTypeUInt64) + .AddInputAttr(kNumberTypeUInt64) + .AddOutputAttr(kNumberTypeUInt64), + &SparseAddmmCpuKernelMod::LaunchKernel}}; + +std::vector SparseAddmmCpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SparseAddmm, SparseAddmmCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_addmm_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_addmm_cpu_kernel.h index 6b3292a68bd..553e90a865d 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_addmm_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_addmm_cpu_kernel.h @@ -1,61 +1,61 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_ADDMM_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_ADDMM_CPU_KERNEL_H_ - -#include -#include -#include -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore { -namespace kernel { -class SparseAddmmCpuKernelMod : public NativeCpuKernelMod { - public: - SparseAddmmCpuKernelMod() = default; - ~SparseAddmmCpuKernelMod() override = default; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override { - return kernel_func_(this, inputs, outputs); - } - - protected: - std::vector GetOpSupport() override; - - private: - template - bool LaunchKernel(const std::vector &inputs, - const std::vector &outputs); - using SparseAddmmFunc = std::function &, - const std::vector &)>; - static std::vector> func_list_; - SparseAddmmFunc kernel_func_; - - std::vector output_shape_; - std::vector b_shape_; - size_t output_size_{0}; - size_t values_size_{0}; -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RMSPROP_CPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_ADDMM_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_ADDMM_CPU_KERNEL_H_ + +#include +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class SparseAddmmCpuKernelMod : public NativeCpuKernelMod { + public: + SparseAddmmCpuKernelMod() = default; + ~SparseAddmmCpuKernelMod() override = default; + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override { + return kernel_func_(this, inputs, outputs); + } + + protected: + std::vector GetOpSupport() override; + + private: + template + bool LaunchKernel(const std::vector &inputs, + const std::vector &outputs); + using SparseAddmmFunc = std::function &, + const std::vector &)>; + static std::vector> func_list_; + SparseAddmmFunc kernel_func_; + + std::vector output_shape_; + std::vector b_shape_; + size_t output_size_{0}; + size_t values_size_{0}; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RMSPROP_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_apply_momentum_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_apply_momentum_cpu_kernel.cc index a92eef0a19b..5ad6ca4e3b4 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_apply_momentum_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_apply_momentum_cpu_kernel.cc @@ -1,221 +1,221 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/cpu/kernel/sparse_apply_momentum_cpu_kernel.h" - -#include -#include -#include -#include - -#include "plugin/device/cpu/hal/device/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -namespace { -constexpr size_t kSparseApplyMomentumInputsNum = 6; -constexpr size_t kSparseApplyMomentumOutputsNum = 1; - -using KernelRunFunc = SparseApplyMomentumCpuKernelMod::KernelRunFunc; - -#define ADD_KERNEL(t1, t2, t3, t4, t5, t6, t7) \ - KernelAttr() \ - .AddInputAttr(kNumberType##t1) \ - .AddInputAttr(kNumberType##t2) \ - .AddInputAttr(kNumberType##t3) \ - .AddInputAttr(kNumberType##t4) \ - .AddInputAttr(kNumberType##t5) \ - .AddInputAttr(kNumberType##t6) \ - .AddOutputAttr(kNumberType##t7) -} // namespace - -bool SparseApplyMomentumCpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - if (inputs.empty() || outputs.empty()) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', it got empty inputs or outputs, which is invalid."; - return false; - } - if (inputs.size() != kSparseApplyMomentumInputsNum) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', input size must be " << kSparseApplyMomentumInputsNum - << ", but got " << inputs.size() << "."; - return false; - } - if (outputs.size() != kSparseApplyMomentumOutputsNum) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', output size must be " << kSparseApplyMomentumOutputsNum - << ", but got " << outputs.size() << "."; - return false; - } - use_nesterov_ = GetValue(primitive_->GetAttr(ops::kUseNesterov)); - if (!MatchKernelFunc(kernel_name_, inputs, outputs)) { - return false; - } - return true; -} - -void SparseApplyMomentumCpuKernelMod::ResetResource() noexcept { - output_size_list_.clear(); - workspace_size_list_.clear(); - indices_data_type_ = kNumberTypeInt32; - indices_size_ = 0; - var_first_dim_size_ = 0; - var_outer_dim_size_ = 1; - use_nesterov_ = false; -} - -int SparseApplyMomentumCpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - ResetResource(); - int ret = KernelMod::Resize(inputs, outputs); - if (ret != KRET_OK) { - return ret; - } - enum input_index : size_t { Var_no, Accum_no, Lr_no, Grad_no, Indices_no, Momentum_no }; - auto var_shape = inputs[static_cast(Var_no)]->GetShapeVector(); - auto accum_shape = inputs[static_cast(Accum_no)]->GetShapeVector(); - auto lr_shape = inputs[static_cast(Lr_no)]->GetShapeVector(); - auto grad_shape = inputs[static_cast(Grad_no)]->GetShapeVector(); - auto indices_shape = inputs[static_cast(Indices_no)]->GetShapeVector(); - auto momentum_shape = inputs[static_cast(Momentum_no)]->GetShapeVector(); - if (var_shape.empty()) { - MS_LOG(EXCEPTION) << "For SparseApplyMomentum, var must be at least 1D."; - } else { - var_first_dim_size_ = LongToSize(var_shape[0]); - } - if (var_shape.size() != grad_shape.size()) { - MS_LOG(EXCEPTION) << "For SparseApplyMomentum, rank(grad) should be same as rank(var), but got rank(grad): " - << grad_shape.size() << ", rank(var): " << var_shape.size() << "."; - } - if (!IsSameShape(var_shape, accum_shape)) { - MS_LOG(EXCEPTION) << "For SparseApplyMomentum, var and accum should have the same shape."; - } - for (size_t i = 1; i < var_shape.size(); ++i) { - if (var_shape[i] != grad_shape[i]) { - MS_LOG(EXCEPTION) << "For SparseApplyMomentum, the shape of var and grad must equal in dimension " << i << "."; - } - var_outer_dim_size_ *= LongToSize(var_shape[i]); - } - if (indices_shape.size() != 1) { - MS_LOG(EXCEPTION) << "For SparseApplyMomentum, indices must be 1D, but got " << indices_shape.size() << "D."; - } - indices_size_ = LongToSize(indices_shape[0]); - if (grad_shape[0] != SizeToLong(indices_size_)) { - MS_LOG(EXCEPTION) - << "For SparseApplyMomentum, grad.shape[0] must be equal to indices.shape[0], but got grad.shape[0]: " - << grad_shape[0] << ", indices.shape[0]: " << indices_size_ << "."; - } - if (!lr_shape.empty()) { - MS_LOG(EXCEPTION) << "For SparseApplyMomentum, lr is not a scalar, got shape: " << lr_shape << "."; - } - if (!momentum_shape.empty()) { - MS_LOG(EXCEPTION) << "For SparseApplyMomentum, momentum is not a scalar, got shape: " << momentum_shape << "."; - } - return static_cast(KRET_OK); -} - -template -bool SparseApplyMomentumCpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &, - const std::vector &outputs) const { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSparseApplyMomentumInputsNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSparseApplyMomentumOutputsNum, kernel_name_); - - auto var = static_cast(inputs[0]->device_ptr()); - auto accum = static_cast(inputs[1]->device_ptr()); - auto grad = static_cast(inputs[3]->device_ptr()); - auto indices = static_cast(inputs[4]->device_ptr()); - auto lr_scalar = static_cast(inputs[2]->device_ptr())[0]; - auto momentum_scalar = static_cast(inputs[5]->device_ptr())[0]; - auto output = static_cast(outputs[0]->device_ptr()); - - for (size_t i = 0; i < indices_size_; ++i) { - I index = indices[i]; - if (index < 0 || LongToSize(index) >= var_first_dim_size_) { - MS_LOG(EXCEPTION) << "For SparseApplyMomentum, values in indices should be [0, var.shape[0]), but got " << index - << "."; - } - size_t start_index = var_outer_dim_size_ * static_cast(index); - size_t end_index = start_index + var_outer_dim_size_; - for (size_t j = start_index, k = var_outer_dim_size_ * i; j < end_index; ++j, ++k) { - accum[j] = accum[j] * momentum_scalar + grad[k]; - if (use_nesterov_) { - var[j] -= lr_scalar * grad[k] + lr_scalar * momentum_scalar * accum[j]; - } else { - var[j] -= lr_scalar * accum[j]; - } - } - } - - size_t copy_size = var_first_dim_size_ * var_outer_dim_size_ * sizeof(T); - auto ret = memcpy_s(output, copy_size, var, copy_size); - if (ret != 0) { - MS_LOG(EXCEPTION) << "For SparseApplyMomentum, memcpy_s error, errorno: " << ret << "."; - } - - return true; -} - -const std::vector> &SparseApplyMomentumCpuKernelMod::GetFuncList() const { - static const std::vector> func_list_ = { - {ADD_KERNEL(Int8, Int8, Int8, Int8, Int32, Int8, Int8), - &SparseApplyMomentumCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(Int16, Int16, Int16, Int16, Int32, Int16, Int16), - &SparseApplyMomentumCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(Int32, Int32, Int32, Int32, Int32, Int32, Int32), - &SparseApplyMomentumCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(Int64, Int64, Int64, Int64, Int32, Int64, Int64), - &SparseApplyMomentumCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(UInt8, UInt8, UInt8, UInt8, Int32, UInt8, UInt8), - &SparseApplyMomentumCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(UInt16, UInt16, UInt16, UInt16, Int32, UInt16, UInt16), - &SparseApplyMomentumCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(UInt32, UInt32, UInt32, UInt32, Int32, UInt32, UInt32), - &SparseApplyMomentumCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(UInt64, UInt64, UInt64, UInt64, Int32, UInt64, UInt64), - &SparseApplyMomentumCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(Float16, Float16, Float16, Float16, Int32, Float16, Float16), - &SparseApplyMomentumCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(Float32, Float32, Float32, Float32, Int32, Float32, Float32), - &SparseApplyMomentumCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(Float64, Float64, Float64, Float64, Int32, Float64, Float64), - &SparseApplyMomentumCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(Int8, Int8, Int8, Int8, Int64, Int8, Int8), - &SparseApplyMomentumCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(Int16, Int16, Int16, Int16, Int64, Int16, Int16), - &SparseApplyMomentumCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(Int32, Int32, Int32, Int32, Int64, Int32, Int32), - &SparseApplyMomentumCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(Int64, Int64, Int64, Int64, Int64, Int64, Int64), - &SparseApplyMomentumCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(UInt8, UInt8, UInt8, UInt8, Int64, UInt8, UInt8), - &SparseApplyMomentumCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(UInt16, UInt16, UInt16, UInt16, Int64, UInt16, UInt16), - &SparseApplyMomentumCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(UInt32, UInt32, UInt32, UInt32, Int64, UInt32, UInt32), - &SparseApplyMomentumCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(UInt64, UInt64, UInt64, UInt64, Int64, UInt64, UInt64), - &SparseApplyMomentumCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(Float16, Float16, Float16, Float16, Int64, Float16, Float16), - &SparseApplyMomentumCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(Float32, Float32, Float32, Float32, Int64, Float32, Float32), - &SparseApplyMomentumCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(Float64, Float64, Float64, Float64, Int64, Float64, Float64), - &SparseApplyMomentumCpuKernelMod::LaunchKernel}}; - return func_list_; -} - -MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SparseApplyMomentum, SparseApplyMomentumCpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/cpu/kernel/sparse_apply_momentum_cpu_kernel.h" + +#include +#include +#include +#include + +#include "plugin/device/cpu/hal/device/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kSparseApplyMomentumInputsNum = 6; +constexpr size_t kSparseApplyMomentumOutputsNum = 1; + +using KernelRunFunc = SparseApplyMomentumCpuKernelMod::KernelRunFunc; + +#define ADD_KERNEL(t1, t2, t3, t4, t5, t6, t7) \ + KernelAttr() \ + .AddInputAttr(kNumberType##t1) \ + .AddInputAttr(kNumberType##t2) \ + .AddInputAttr(kNumberType##t3) \ + .AddInputAttr(kNumberType##t4) \ + .AddInputAttr(kNumberType##t5) \ + .AddInputAttr(kNumberType##t6) \ + .AddOutputAttr(kNumberType##t7) +} // namespace + +bool SparseApplyMomentumCpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + if (inputs.empty() || outputs.empty()) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', it got empty inputs or outputs, which is invalid."; + return false; + } + if (inputs.size() != kSparseApplyMomentumInputsNum) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', input size must be " << kSparseApplyMomentumInputsNum + << ", but got " << inputs.size() << "."; + return false; + } + if (outputs.size() != kSparseApplyMomentumOutputsNum) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', output size must be " << kSparseApplyMomentumOutputsNum + << ", but got " << outputs.size() << "."; + return false; + } + use_nesterov_ = GetValue(primitive_->GetAttr(ops::kUseNesterov)); + if (!MatchKernelFunc(kernel_name_, inputs, outputs)) { + return false; + } + return true; +} + +void SparseApplyMomentumCpuKernelMod::ResetResource() noexcept { + output_size_list_.clear(); + workspace_size_list_.clear(); + indices_data_type_ = kNumberTypeInt32; + indices_size_ = 0; + var_first_dim_size_ = 0; + var_outer_dim_size_ = 1; + use_nesterov_ = false; +} + +int SparseApplyMomentumCpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + ResetResource(); + int ret = KernelMod::Resize(inputs, outputs); + if (ret != KRET_OK) { + return ret; + } + enum input_index : size_t { Var_no, Accum_no, Lr_no, Grad_no, Indices_no, Momentum_no }; + auto var_shape = inputs[static_cast(Var_no)]->GetShapeVector(); + auto accum_shape = inputs[static_cast(Accum_no)]->GetShapeVector(); + auto lr_shape = inputs[static_cast(Lr_no)]->GetShapeVector(); + auto grad_shape = inputs[static_cast(Grad_no)]->GetShapeVector(); + auto indices_shape = inputs[static_cast(Indices_no)]->GetShapeVector(); + auto momentum_shape = inputs[static_cast(Momentum_no)]->GetShapeVector(); + if (var_shape.empty()) { + MS_LOG(EXCEPTION) << "For SparseApplyMomentum, var must be at least 1D."; + } else { + var_first_dim_size_ = LongToSize(var_shape[0]); + } + if (var_shape.size() != grad_shape.size()) { + MS_LOG(EXCEPTION) << "For SparseApplyMomentum, rank(grad) should be same as rank(var), but got rank(grad): " + << grad_shape.size() << ", rank(var): " << var_shape.size() << "."; + } + if (!IsSameShape(var_shape, accum_shape)) { + MS_LOG(EXCEPTION) << "For SparseApplyMomentum, var and accum should have the same shape."; + } + for (size_t i = 1; i < var_shape.size(); ++i) { + if (var_shape[i] != grad_shape[i]) { + MS_LOG(EXCEPTION) << "For SparseApplyMomentum, the shape of var and grad must equal in dimension " << i << "."; + } + var_outer_dim_size_ *= LongToSize(var_shape[i]); + } + if (indices_shape.size() != 1) { + MS_LOG(EXCEPTION) << "For SparseApplyMomentum, indices must be 1D, but got " << indices_shape.size() << "D."; + } + indices_size_ = LongToSize(indices_shape[0]); + if (grad_shape[0] != SizeToLong(indices_size_)) { + MS_LOG(EXCEPTION) + << "For SparseApplyMomentum, grad.shape[0] must be equal to indices.shape[0], but got grad.shape[0]: " + << grad_shape[0] << ", indices.shape[0]: " << indices_size_ << "."; + } + if (!lr_shape.empty()) { + MS_LOG(EXCEPTION) << "For SparseApplyMomentum, lr is not a scalar, got shape: " << lr_shape << "."; + } + if (!momentum_shape.empty()) { + MS_LOG(EXCEPTION) << "For SparseApplyMomentum, momentum is not a scalar, got shape: " << momentum_shape << "."; + } + return static_cast(KRET_OK); +} + +template +bool SparseApplyMomentumCpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &, + const std::vector &outputs) const { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSparseApplyMomentumInputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSparseApplyMomentumOutputsNum, kernel_name_); + + auto var = static_cast(inputs[0]->device_ptr()); + auto accum = static_cast(inputs[1]->device_ptr()); + auto grad = static_cast(inputs[3]->device_ptr()); + auto indices = static_cast(inputs[4]->device_ptr()); + auto lr_scalar = static_cast(inputs[2]->device_ptr())[0]; + auto momentum_scalar = static_cast(inputs[5]->device_ptr())[0]; + auto output = static_cast(outputs[0]->device_ptr()); + + for (size_t i = 0; i < indices_size_; ++i) { + I index = indices[i]; + if (index < 0 || LongToSize(index) >= var_first_dim_size_) { + MS_LOG(EXCEPTION) << "For SparseApplyMomentum, values in indices should be [0, var.shape[0]), but got " << index + << "."; + } + size_t start_index = var_outer_dim_size_ * static_cast(index); + size_t end_index = start_index + var_outer_dim_size_; + for (size_t j = start_index, k = var_outer_dim_size_ * i; j < end_index; ++j, ++k) { + accum[j] = accum[j] * momentum_scalar + grad[k]; + if (use_nesterov_) { + var[j] -= lr_scalar * grad[k] + lr_scalar * momentum_scalar * accum[j]; + } else { + var[j] -= lr_scalar * accum[j]; + } + } + } + + size_t copy_size = var_first_dim_size_ * var_outer_dim_size_ * sizeof(T); + auto ret = memcpy_s(output, copy_size, var, copy_size); + if (ret != 0) { + MS_LOG(EXCEPTION) << "For SparseApplyMomentum, memcpy_s error, errorno: " << ret << "."; + } + + return true; +} + +const std::vector> &SparseApplyMomentumCpuKernelMod::GetFuncList() const { + static const std::vector> func_list_ = { + {ADD_KERNEL(Int8, Int8, Int8, Int8, Int32, Int8, Int8), + &SparseApplyMomentumCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(Int16, Int16, Int16, Int16, Int32, Int16, Int16), + &SparseApplyMomentumCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(Int32, Int32, Int32, Int32, Int32, Int32, Int32), + &SparseApplyMomentumCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(Int64, Int64, Int64, Int64, Int32, Int64, Int64), + &SparseApplyMomentumCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(UInt8, UInt8, UInt8, UInt8, Int32, UInt8, UInt8), + &SparseApplyMomentumCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(UInt16, UInt16, UInt16, UInt16, Int32, UInt16, UInt16), + &SparseApplyMomentumCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(UInt32, UInt32, UInt32, UInt32, Int32, UInt32, UInt32), + &SparseApplyMomentumCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(UInt64, UInt64, UInt64, UInt64, Int32, UInt64, UInt64), + &SparseApplyMomentumCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(Float16, Float16, Float16, Float16, Int32, Float16, Float16), + &SparseApplyMomentumCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(Float32, Float32, Float32, Float32, Int32, Float32, Float32), + &SparseApplyMomentumCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(Float64, Float64, Float64, Float64, Int32, Float64, Float64), + &SparseApplyMomentumCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(Int8, Int8, Int8, Int8, Int64, Int8, Int8), + &SparseApplyMomentumCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(Int16, Int16, Int16, Int16, Int64, Int16, Int16), + &SparseApplyMomentumCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(Int32, Int32, Int32, Int32, Int64, Int32, Int32), + &SparseApplyMomentumCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(Int64, Int64, Int64, Int64, Int64, Int64, Int64), + &SparseApplyMomentumCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(UInt8, UInt8, UInt8, UInt8, Int64, UInt8, UInt8), + &SparseApplyMomentumCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(UInt16, UInt16, UInt16, UInt16, Int64, UInt16, UInt16), + &SparseApplyMomentumCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(UInt32, UInt32, UInt32, UInt32, Int64, UInt32, UInt32), + &SparseApplyMomentumCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(UInt64, UInt64, UInt64, UInt64, Int64, UInt64, UInt64), + &SparseApplyMomentumCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(Float16, Float16, Float16, Float16, Int64, Float16, Float16), + &SparseApplyMomentumCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(Float32, Float32, Float32, Float32, Int64, Float32, Float32), + &SparseApplyMomentumCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(Float64, Float64, Float64, Float64, Int64, Float64, Float64), + &SparseApplyMomentumCpuKernelMod::LaunchKernel}}; + return func_list_; +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SparseApplyMomentum, SparseApplyMomentumCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_apply_momentum_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_apply_momentum_cpu_kernel.h index 6b8baab83ac..827610ba711 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_apply_momentum_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_apply_momentum_cpu_kernel.h @@ -1,61 +1,61 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_MOMENTUM_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_MOMENTUM_CPU_KERNEL_H_ - -#include -#include -#include - -#include "mindspore/core/ops/sparse_apply_momentum.h" -#include "plugin/device/cpu/kernel/sparse_optimizer_cpu_kernel.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore { -namespace kernel { -class SparseApplyMomentumCpuKernelMod : public SparseOptimizerCpuKernelMod, - public MatchKernelHelper { - public: - SparseApplyMomentumCpuKernelMod() = default; - ~SparseApplyMomentumCpuKernelMod() override = default; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override { - return kernel_func_(this, inputs, workspace, outputs); - } - - template - bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) const; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - const std::vector> &GetFuncList() const override; - - protected: - std::vector GetOpSupport() override { return OpSupport(); } - void ResetResource() noexcept; - - private: - bool use_nesterov_{false}; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_MOMENTUM_CPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_MOMENTUM_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_MOMENTUM_CPU_KERNEL_H_ + +#include +#include +#include + +#include "mindspore/core/ops/sparse_apply_momentum.h" +#include "plugin/device/cpu/kernel/sparse_optimizer_cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class SparseApplyMomentumCpuKernelMod : public SparseOptimizerCpuKernelMod, + public MatchKernelHelper { + public: + SparseApplyMomentumCpuKernelMod() = default; + ~SparseApplyMomentumCpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override { + return kernel_func_(this, inputs, workspace, outputs); + } + + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) const; + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + const std::vector> &GetFuncList() const override; + + protected: + std::vector GetOpSupport() override { return OpSupport(); } + void ResetResource() noexcept; + + private: + bool use_nesterov_{false}; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_MOMENTUM_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_apply_proximal_gradient_descent_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_apply_proximal_gradient_descent_cpu_kernel.cc index e142681ef35..44348394067 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_apply_proximal_gradient_descent_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_apply_proximal_gradient_descent_cpu_kernel.cc @@ -1,232 +1,232 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/cpu/kernel/sparse_apply_proximal_gradient_descent_cpu_kernel.h" - -#include -#include -#include -#include - -#include "plugin/device/cpu/hal/device/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -namespace { -constexpr size_t kSparseApplyProximalGradientDescentInputsNum = 6; -constexpr size_t kSparseApplyProximalGradientDescentOutputsNum = 1; - -using KernelRunFunc = SparseApplyProximalGradientDescentCpuKernelMod::KernelRunFunc; - -#define ADD_KERNEL(t1, t2, t3, t4, t5, t6, t7) \ - KernelAttr() \ - .AddInputAttr(kNumberType##t1) \ - .AddInputAttr(kNumberType##t2) \ - .AddInputAttr(kNumberType##t3) \ - .AddInputAttr(kNumberType##t4) \ - .AddInputAttr(kNumberType##t5) \ - .AddInputAttr(kNumberType##t6) \ - .AddOutputAttr(kNumberType##t7) -} // namespace - -bool SparseApplyProximalGradientDescentCpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - if (inputs.empty() || outputs.empty()) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', it got empty inputs or outputs, which is invalid."; - return false; - } - if (inputs.size() != kSparseApplyProximalGradientDescentInputsNum) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', input size must be " << kSparseApplyProximalGradientDescentInputsNum - << ", but got " << inputs.size() << "."; - return false; - } - if (outputs.size() != kSparseApplyProximalGradientDescentOutputsNum) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', output size must be " - << kSparseApplyProximalGradientDescentOutputsNum << ", but got " << outputs.size() << "."; - return false; - } - if (!MatchKernelFunc(kernel_name_, inputs, outputs)) { - return false; - } - return true; -} - -void SparseApplyProximalGradientDescentCpuKernelMod::ResetResouce() noexcept { - output_size_list_.clear(); - workspace_size_list_.clear(); - indices_data_type_ = kNumberTypeInt32; - indices_size_ = 0; - var_first_dim_size_ = 0; - var_outer_dim_size_ = 1; -} - -int SparseApplyProximalGradientDescentCpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - ResetResouce(); - int ret = KernelMod::Resize(inputs, outputs); - if (ret != static_cast(KRET_OK)) { - return ret; - } - enum input_index : size_t { Var_no, Alpha_no, L1_no, L2_no, Grad_no, Indices_no }; - auto var_shape = inputs[static_cast(Var_no)]->GetShapeVector(); - auto alpha_shape = inputs[static_cast(Alpha_no)]->GetShapeVector(); - auto l1_shape = inputs[static_cast(L1_no)]->GetShapeVector(); - auto l2_shape = inputs[static_cast(L2_no)]->GetShapeVector(); - auto grad_shape = inputs[static_cast(Grad_no)]->GetShapeVector(); - auto indices_shape = inputs[static_cast(Indices_no)]->GetShapeVector(); - if (var_shape.empty()) { - MS_LOG(EXCEPTION) << "For SparseApplyProximalGradientDescent, var must be at least 1D."; - } else { - var_first_dim_size_ = LongToSize(var_shape[0]); - } - if (var_shape.size() != grad_shape.size()) { - MS_LOG(EXCEPTION) << "For SparseApplyProximalGradientDescent, rank(grad) should be same as rank(var), but " - "got rank(grad): " - << grad_shape.size() << ", rank(var): " << var_shape.size() << "."; - } - for (size_t i = 1; i < var_shape.size(); ++i) { - if (var_shape[i] != grad_shape[i]) { - MS_LOG(EXCEPTION) << "For SparseApplyProximalGradientDescent, the shape of var and grad must equal in dimension " - << i << "."; - } - var_outer_dim_size_ *= LongToSize(var_shape[i]); - } - if (indices_shape.size() != 1) { - MS_LOG(EXCEPTION) << "For SparseApplyProximalGradientDescent, indices must be 1D, but got " << indices_shape.size() - << "D."; - } - indices_size_ = LongToSize(indices_shape[0]); - if (grad_shape[0] != SizeToLong(indices_size_)) { - MS_LOG(EXCEPTION) << "For SparseApplyProximalGradientDescent, grad.shape[0] must be equal to indices.shape[0], but " - "got grad.shape[0]: " - << grad_shape[0] << " indices.shape[0]: " << indices_size_ << "."; - } - if (!alpha_shape.empty()) { - MS_LOG(EXCEPTION) << "For SparseApplyProximalGradientDescent, alpha is not a scalar, got shape: " << alpha_shape - << "."; - } - if (!l1_shape.empty()) { - MS_LOG(EXCEPTION) << "For SparseApplyProximalGradientDescent, l1 is not a scalar, got shape: " << l1_shape << "."; - } - if (!l2_shape.empty()) { - MS_LOG(EXCEPTION) << "For SparseApplyProximalGradientDescent, l2 is not a scalar, got shape: " << l2_shape << "."; - } - return static_cast(KRET_OK); -} - -template -bool SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel( - const std::vector &inputs, const std::vector &, - const std::vector &outputs) const { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSparseApplyProximalGradientDescentInputsNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSparseApplyProximalGradientDescentOutputsNum, kernel_name_); - - auto var = static_cast(inputs[0]->device_ptr()); - auto grad = static_cast(inputs[4]->device_ptr()); - auto indices = static_cast(inputs[5]->device_ptr()); - auto alpha_scalar = static_cast(inputs[1]->device_ptr())[0]; - auto l1_scalar = static_cast(inputs[2]->device_ptr())[0]; - auto l2_scalar = static_cast(inputs[3]->device_ptr())[0]; - auto output = static_cast(outputs[0]->device_ptr()); - - for (size_t i = 0; i < indices_size_; i++) { - I index = indices[i]; - if (index < 0 || LongToSize(index) >= var_first_dim_size_) { - MS_LOG(EXCEPTION) - << "For SparseApplyProximalGradientDescent, values in indices should be [0, var.shape[0]), but got " << index - << "."; - } - size_t start_index = var_outer_dim_size_ * static_cast(index); - size_t end_index = start_index + var_outer_dim_size_; - for (size_t j = start_index, k = var_outer_dim_size_ * i; j < end_index; ++j, ++k) { - auto learning_rate = alpha_scalar; - auto prox_v = var[j]; - prox_v -= grad[k] * learning_rate; - if (l1_scalar > static_cast(0.0)) { - var[j] = static_cast(Sign(static_cast(prox_v))) * - static_cast(std::fmax(std::fabs(static_cast(prox_v)) - - static_cast(learning_rate) * static_cast(l1_scalar), - static_cast(0.0))) / - (static_cast(1.0) + l2_scalar * learning_rate); - } else { - var[j] = static_cast(prox_v) / (static_cast(1.0) + l2_scalar * learning_rate); - } - } - } - - auto copy_size = var_first_dim_size_ * var_outer_dim_size_ * sizeof(T); - auto ret = memcpy_s(output, copy_size, var, copy_size); - if (ret != 0) { - MS_LOG(EXCEPTION) << "For SparseApplyProximalGradientDescent, memcpy_s error, errorno: " << ret << "."; - } - - return true; -} - -const std::vector> &SparseApplyProximalGradientDescentCpuKernelMod::GetFuncList() - const { - static const std::vector> func_list_ = { - {ADD_KERNEL(Int8, Int8, Int8, Int8, Int8, Int32, Int8), - &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(Int16, Int16, Int16, Int16, Int16, Int32, Int16), - &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(Int32, Int32, Int32, Int32, Int32, Int32, Int32), - &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(Int64, Int64, Int64, Int64, Int64, Int32, Int64), - &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(UInt8, UInt8, UInt8, UInt8, UInt8, Int32, UInt8), - &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(UInt16, UInt16, UInt16, UInt16, UInt16, Int32, UInt16), - &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(UInt32, UInt32, UInt32, UInt32, UInt32, Int32, UInt32), - &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(UInt64, UInt64, UInt64, UInt64, UInt64, Int32, UInt64), - &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(Float16, Float16, Float16, Float16, Float16, Int32, Float16), - &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(Float32, Float32, Float32, Float32, Float32, Int32, Float32), - &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(Float64, Float64, Float64, Float64, Float64, Int32, Float64), - &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(Int8, Int8, Int8, Int8, Int8, Int64, Int8), - &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(Int16, Int16, Int16, Int16, Int16, Int64, Int16), - &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(Int32, Int32, Int32, Int32, Int32, Int64, Int32), - &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(Int64, Int64, Int64, Int64, Int64, Int64, Int64), - &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(UInt8, UInt8, UInt8, UInt8, UInt8, Int64, UInt8), - &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(UInt16, UInt16, UInt16, UInt16, UInt16, Int64, UInt16), - &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(UInt32, UInt32, UInt32, UInt32, UInt32, Int64, UInt32), - &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(UInt64, UInt64, UInt64, UInt64, UInt64, Int64, UInt64), - &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(Float16, Float16, Float16, Float16, Float16, Int64, Float16), - &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(Float32, Float32, Float32, Float32, Float32, Int64, Float32), - &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, - {ADD_KERNEL(Float64, Float64, Float64, Float64, Float64, Int64, Float64), - &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}}; - return func_list_; -} - -MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SparseApplyProximalGradientDescent, - SparseApplyProximalGradientDescentCpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/cpu/kernel/sparse_apply_proximal_gradient_descent_cpu_kernel.h" + +#include +#include +#include +#include + +#include "plugin/device/cpu/hal/device/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kSparseApplyProximalGradientDescentInputsNum = 6; +constexpr size_t kSparseApplyProximalGradientDescentOutputsNum = 1; + +using KernelRunFunc = SparseApplyProximalGradientDescentCpuKernelMod::KernelRunFunc; + +#define ADD_KERNEL(t1, t2, t3, t4, t5, t6, t7) \ + KernelAttr() \ + .AddInputAttr(kNumberType##t1) \ + .AddInputAttr(kNumberType##t2) \ + .AddInputAttr(kNumberType##t3) \ + .AddInputAttr(kNumberType##t4) \ + .AddInputAttr(kNumberType##t5) \ + .AddInputAttr(kNumberType##t6) \ + .AddOutputAttr(kNumberType##t7) +} // namespace + +bool SparseApplyProximalGradientDescentCpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + if (inputs.empty() || outputs.empty()) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', it got empty inputs or outputs, which is invalid."; + return false; + } + if (inputs.size() != kSparseApplyProximalGradientDescentInputsNum) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', input size must be " << kSparseApplyProximalGradientDescentInputsNum + << ", but got " << inputs.size() << "."; + return false; + } + if (outputs.size() != kSparseApplyProximalGradientDescentOutputsNum) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', output size must be " + << kSparseApplyProximalGradientDescentOutputsNum << ", but got " << outputs.size() << "."; + return false; + } + if (!MatchKernelFunc(kernel_name_, inputs, outputs)) { + return false; + } + return true; +} + +void SparseApplyProximalGradientDescentCpuKernelMod::ResetResouce() noexcept { + output_size_list_.clear(); + workspace_size_list_.clear(); + indices_data_type_ = kNumberTypeInt32; + indices_size_ = 0; + var_first_dim_size_ = 0; + var_outer_dim_size_ = 1; +} + +int SparseApplyProximalGradientDescentCpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + ResetResouce(); + int ret = KernelMod::Resize(inputs, outputs); + if (ret != static_cast(KRET_OK)) { + return ret; + } + enum input_index : size_t { Var_no, Alpha_no, L1_no, L2_no, Grad_no, Indices_no }; + auto var_shape = inputs[static_cast(Var_no)]->GetShapeVector(); + auto alpha_shape = inputs[static_cast(Alpha_no)]->GetShapeVector(); + auto l1_shape = inputs[static_cast(L1_no)]->GetShapeVector(); + auto l2_shape = inputs[static_cast(L2_no)]->GetShapeVector(); + auto grad_shape = inputs[static_cast(Grad_no)]->GetShapeVector(); + auto indices_shape = inputs[static_cast(Indices_no)]->GetShapeVector(); + if (var_shape.empty()) { + MS_LOG(EXCEPTION) << "For SparseApplyProximalGradientDescent, var must be at least 1D."; + } else { + var_first_dim_size_ = LongToSize(var_shape[0]); + } + if (var_shape.size() != grad_shape.size()) { + MS_LOG(EXCEPTION) << "For SparseApplyProximalGradientDescent, rank(grad) should be same as rank(var), but " + "got rank(grad): " + << grad_shape.size() << ", rank(var): " << var_shape.size() << "."; + } + for (size_t i = 1; i < var_shape.size(); ++i) { + if (var_shape[i] != grad_shape[i]) { + MS_LOG(EXCEPTION) << "For SparseApplyProximalGradientDescent, the shape of var and grad must equal in dimension " + << i << "."; + } + var_outer_dim_size_ *= LongToSize(var_shape[i]); + } + if (indices_shape.size() != 1) { + MS_LOG(EXCEPTION) << "For SparseApplyProximalGradientDescent, indices must be 1D, but got " << indices_shape.size() + << "D."; + } + indices_size_ = LongToSize(indices_shape[0]); + if (grad_shape[0] != SizeToLong(indices_size_)) { + MS_LOG(EXCEPTION) << "For SparseApplyProximalGradientDescent, grad.shape[0] must be equal to indices.shape[0], but " + "got grad.shape[0]: " + << grad_shape[0] << " indices.shape[0]: " << indices_size_ << "."; + } + if (!alpha_shape.empty()) { + MS_LOG(EXCEPTION) << "For SparseApplyProximalGradientDescent, alpha is not a scalar, got shape: " << alpha_shape + << "."; + } + if (!l1_shape.empty()) { + MS_LOG(EXCEPTION) << "For SparseApplyProximalGradientDescent, l1 is not a scalar, got shape: " << l1_shape << "."; + } + if (!l2_shape.empty()) { + MS_LOG(EXCEPTION) << "For SparseApplyProximalGradientDescent, l2 is not a scalar, got shape: " << l2_shape << "."; + } + return static_cast(KRET_OK); +} + +template +bool SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel( + const std::vector &inputs, const std::vector &, + const std::vector &outputs) const { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSparseApplyProximalGradientDescentInputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSparseApplyProximalGradientDescentOutputsNum, kernel_name_); + + auto var = static_cast(inputs[0]->device_ptr()); + auto grad = static_cast(inputs[4]->device_ptr()); + auto indices = static_cast(inputs[5]->device_ptr()); + auto alpha_scalar = static_cast(inputs[1]->device_ptr())[0]; + auto l1_scalar = static_cast(inputs[2]->device_ptr())[0]; + auto l2_scalar = static_cast(inputs[3]->device_ptr())[0]; + auto output = static_cast(outputs[0]->device_ptr()); + + for (size_t i = 0; i < indices_size_; i++) { + I index = indices[i]; + if (index < 0 || LongToSize(index) >= var_first_dim_size_) { + MS_LOG(EXCEPTION) + << "For SparseApplyProximalGradientDescent, values in indices should be [0, var.shape[0]), but got " << index + << "."; + } + size_t start_index = var_outer_dim_size_ * static_cast(index); + size_t end_index = start_index + var_outer_dim_size_; + for (size_t j = start_index, k = var_outer_dim_size_ * i; j < end_index; ++j, ++k) { + auto learning_rate = alpha_scalar; + auto prox_v = var[j]; + prox_v -= grad[k] * learning_rate; + if (l1_scalar > static_cast(0.0)) { + var[j] = static_cast(Sign(static_cast(prox_v))) * + static_cast(std::fmax(std::fabs(static_cast(prox_v)) - + static_cast(learning_rate) * static_cast(l1_scalar), + static_cast(0.0))) / + (static_cast(1.0) + l2_scalar * learning_rate); + } else { + var[j] = static_cast(prox_v) / (static_cast(1.0) + l2_scalar * learning_rate); + } + } + } + + auto copy_size = var_first_dim_size_ * var_outer_dim_size_ * sizeof(T); + auto ret = memcpy_s(output, copy_size, var, copy_size); + if (ret != 0) { + MS_LOG(EXCEPTION) << "For SparseApplyProximalGradientDescent, memcpy_s error, errorno: " << ret << "."; + } + + return true; +} + +const std::vector> &SparseApplyProximalGradientDescentCpuKernelMod::GetFuncList() + const { + static const std::vector> func_list_ = { + {ADD_KERNEL(Int8, Int8, Int8, Int8, Int8, Int32, Int8), + &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(Int16, Int16, Int16, Int16, Int16, Int32, Int16), + &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(Int32, Int32, Int32, Int32, Int32, Int32, Int32), + &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(Int64, Int64, Int64, Int64, Int64, Int32, Int64), + &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(UInt8, UInt8, UInt8, UInt8, UInt8, Int32, UInt8), + &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(UInt16, UInt16, UInt16, UInt16, UInt16, Int32, UInt16), + &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(UInt32, UInt32, UInt32, UInt32, UInt32, Int32, UInt32), + &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(UInt64, UInt64, UInt64, UInt64, UInt64, Int32, UInt64), + &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(Float16, Float16, Float16, Float16, Float16, Int32, Float16), + &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(Float32, Float32, Float32, Float32, Float32, Int32, Float32), + &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(Float64, Float64, Float64, Float64, Float64, Int32, Float64), + &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(Int8, Int8, Int8, Int8, Int8, Int64, Int8), + &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(Int16, Int16, Int16, Int16, Int16, Int64, Int16), + &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(Int32, Int32, Int32, Int32, Int32, Int64, Int32), + &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(Int64, Int64, Int64, Int64, Int64, Int64, Int64), + &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(UInt8, UInt8, UInt8, UInt8, UInt8, Int64, UInt8), + &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(UInt16, UInt16, UInt16, UInt16, UInt16, Int64, UInt16), + &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(UInt32, UInt32, UInt32, UInt32, UInt32, Int64, UInt32), + &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(UInt64, UInt64, UInt64, UInt64, UInt64, Int64, UInt64), + &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(Float16, Float16, Float16, Float16, Float16, Int64, Float16), + &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(Float32, Float32, Float32, Float32, Float32, Int64, Float32), + &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}, + {ADD_KERNEL(Float64, Float64, Float64, Float64, Float64, Int64, Float64), + &SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel}}; + return func_list_; +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SparseApplyProximalGradientDescent, + SparseApplyProximalGradientDescentCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_apply_proximal_gradient_descent_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_apply_proximal_gradient_descent_cpu_kernel.h index 45f5115b9d8..c8b6e0424fe 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_apply_proximal_gradient_descent_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_apply_proximal_gradient_descent_cpu_kernel.h @@ -1,59 +1,59 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_PROXIMAL_GRADIENT_DESCENT_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_PROXIMAL_GRADIENT_DESCENT_CPU_KERNEL_H_ - -#include -#include -#include - -#include "mindspore/core/ops/sparse_apply_proximal_gradient_descent.h" -#include "plugin/device/cpu/kernel/sparse_optimizer_cpu_kernel.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore { -namespace kernel { -class SparseApplyProximalGradientDescentCpuKernelMod - : public SparseOptimizerCpuKernelMod, - public MatchKernelHelper { - public: - SparseApplyProximalGradientDescentCpuKernelMod() = default; - ~SparseApplyProximalGradientDescentCpuKernelMod() override = default; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override { - return kernel_func_(this, inputs, workspace, outputs); - } - - template - bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) const; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - const std::vector> &GetFuncList() const override; - - protected: - std::vector GetOpSupport() override { return OpSupport(); } - void ResetResouce() noexcept; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_PROXIMAL_GRADIENT_DESCENT_CPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_PROXIMAL_GRADIENT_DESCENT_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_PROXIMAL_GRADIENT_DESCENT_CPU_KERNEL_H_ + +#include +#include +#include + +#include "mindspore/core/ops/sparse_apply_proximal_gradient_descent.h" +#include "plugin/device/cpu/kernel/sparse_optimizer_cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class SparseApplyProximalGradientDescentCpuKernelMod + : public SparseOptimizerCpuKernelMod, + public MatchKernelHelper { + public: + SparseApplyProximalGradientDescentCpuKernelMod() = default; + ~SparseApplyProximalGradientDescentCpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override { + return kernel_func_(this, inputs, workspace, outputs); + } + + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) const; + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + const std::vector> &GetFuncList() const override; + + protected: + std::vector GetOpSupport() override { return OpSupport(); } + void ResetResouce() noexcept; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_PROXIMAL_GRADIENT_DESCENT_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_softmax_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_softmax_cpu_kernel.cc index e6fe153bc9d..fc7da699ae3 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_softmax_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_softmax_cpu_kernel.cc @@ -1,259 +1,259 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/cpu/kernel/sparse_softmax_cpu_kernel.h" -#include -#include -#include -#include "plugin/device/cpu/hal/device/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -namespace { -constexpr size_t kIndicesShapeSize = 2; -constexpr size_t kValuesShapeSize = 1; -constexpr size_t kShapeShapeSize = 1; -constexpr size_t kSparseSoftmaxInputsNum = 3; -constexpr size_t kSparseSoftmaxOutputsNum = 1; -constexpr size_t kShapeMinSize = 2; -constexpr size_t kinput_indices = 0; -constexpr size_t kinput_values = 1; -constexpr size_t kinput_shape = 2; -constexpr size_t kIndex0 = 0; -constexpr size_t kIndex1 = 1; -constexpr size_t kIndex2 = 2; - -template -inline bool CompareIndices(const I *a, const I *b, const size_t &len) { - size_t i = 0; - while (i < len) { - if (a[i] != b[i]) { - return a[i] > b[i]; - } - ++i; - } - return true; -} - -template -inline void CopyIndicesAndValue(I *dst_indices_addr, T *dst_values_addr, const I *src_indices_addr, - const T *src_values_addr, const size_t &indices_size) { - auto ret = memcpy_s(dst_indices_addr, indices_size, src_indices_addr, indices_size); - if (ret != EOK) { - MS_LOG(ERROR) << "Execute memcpy_s failed."; - } - *dst_values_addr = *src_values_addr; -} - -template -inline int64_t Partition(I *__restrict indices_addr, T *__restrict values_addr, I *__restrict tmp_indices, - const size_t &indices_len, const int64_t &left, const int64_t &right) { - int64_t i = left; - int64_t j = right; - T tmp_values = 0; - const size_t indices_size = indices_len * sizeof(I); -#define INDICES_OFFSET_ADDR(addr, index, len) (addr) + (index) * (len) - - CopyIndicesAndValue(tmp_indices, &tmp_values, INDICES_OFFSET_ADDR(indices_addr, left, indices_len), - values_addr + left, indices_size); - while (i < j) { - while (i < j && CompareIndices(INDICES_OFFSET_ADDR(indices_addr, j, indices_len), tmp_indices, indices_len)) { - --j; - } - CopyIndicesAndValue(INDICES_OFFSET_ADDR(indices_addr, i, indices_len), values_addr + i, - INDICES_OFFSET_ADDR(indices_addr, j, indices_len), values_addr + j, indices_size); - while (i < j && !CompareIndices(INDICES_OFFSET_ADDR(indices_addr, i, indices_len), tmp_indices, indices_len)) { - ++i; - } - CopyIndicesAndValue(INDICES_OFFSET_ADDR(indices_addr, j, indices_len), values_addr + j, - INDICES_OFFSET_ADDR(indices_addr, i, indices_len), values_addr + i, indices_size); - } - CopyIndicesAndValue(INDICES_OFFSET_ADDR(indices_addr, i, indices_len), values_addr + i, tmp_indices, &tmp_values, - indices_size); - return i; -} -} // namespace - -template -void QuickSortIndicesAndValues(I *__restrict indices_addr, T *__restrict values_addr, const size_t &indices_len, - const int64_t &left, const int64_t &right) { - std::stack index_stk; - (void)index_stk.emplace(right); - (void)index_stk.emplace(left); - I *indices_buff = new I[indices_len]; - - while (!index_stk.empty()) { - int64_t i = index_stk.top(); - index_stk.pop(); - int64_t j = index_stk.top(); - index_stk.pop(); - if (i < j) { - int64_t k = Partition(indices_addr, values_addr, indices_buff, indices_len, i, j); - if (k > i) { - (void)index_stk.emplace(k - 1); - (void)index_stk.emplace(i); - } - if (j > k) { - (void)index_stk.emplace(j); - (void)index_stk.emplace(k + 1); - } - } - } - delete[] indices_buff; -} - -bool SparseSoftmaxCpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - if (inputs.empty() || outputs.empty()) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid."; - return false; - } - if (!MatchKernelFunc(kernel_name_, inputs, outputs)) { - return false; - } - return true; -} - -int SparseSoftmaxCpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { - return ret; - } - auto indices_shape = inputs.at(kIndex0)->GetShapeVector(); - auto values_shape = inputs.at(kIndex1)->GetShapeVector(); - values_size_ = LongToSize(values_shape[0]); - auto shape_shape = inputs.at(kIndex2)->GetShapeVector(); - shape_size_ = LongToSize(shape_shape[0]); - auto output_shape = outputs.at(kIndex0)->GetShapeVector(); - output_shape_ = Convert2SizeT(output_shape); - if (indices_shape.size() != kIndicesShapeSize) { - MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', it requires 'indices' should be a " << kIndicesShapeSize - << "-D Tensor, but got " << indices_shape.size() << "-D"; - } - if (values_shape.size() != kValuesShapeSize || values_shape[0] != indices_shape[0]) { - MS_EXCEPTION(ValueError) << "For '" << kernel_name_ - << "', it requires 'values' should be a 1-D Tensor and " - "the first dimension length should be equal to the first dimension length of " - "'indices', but got 'values' shape: " - << values_shape << " and 'indices' shape: " << indices_shape; - } - if (shape_shape.size() != kShapeShapeSize || LongToSize(shape_shape[0]) < kShapeMinSize || - shape_shape[0] != indices_shape[1]) { - MS_EXCEPTION(ValueError) << "For '" << kernel_name_ - << "', it requires 'shape' should be 1-D and more than 2 element, the element " - "should be equal to the second dimension length of 'indices', but " - "got 'shape' shape: " - << shape_shape << " and 'indices' shape: " << indices_shape; - } - return KRET_OK; -} - -template -bool SparseSoftmaxCpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &, - const std::vector &outputs) { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSparseSoftmaxInputsNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSparseSoftmaxOutputsNum, kernel_name_); - if (outputs[0]->size() == 0) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', output memory size should be greater than 0, but got 0."; - } - auto ret = memset_s(outputs[0]->device_ptr(), outputs[0]->size(), 0, outputs[0]->size()); - if (ret != EOK) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memset output failed. Error no: " << ret; - } - auto *indices_addr = static_cast(inputs[kIndex0]->device_ptr()); - auto *values_addr = static_cast(inputs[kIndex1]->device_ptr()); - auto *output_addr = static_cast(outputs[kIndex0]->device_ptr()); - const size_t indices_length = inputs[kIndex0]->size() / sizeof(I); - const size_t values_length = inputs[kIndex1]->size() / sizeof(T); - std::vector exp_values; - std::vector index_values; - std::vector visited; - - QuickSortIndicesAndValues(indices_addr, values_addr, shape_size_, 0, SizeToLong(values_size_) - 1); - - for (size_t i = 0; i < values_size_; ++i) { - visited.push_back(0); - } - T exp_sum = static_cast(0); - int equal_judge = 0; - for (size_t i = 0; i < values_size_; ++i) { - if (visited[i] == 1) { - continue; - } - if (i >= values_length) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ - << "', the size of 'values' " - "should be the same size as values memory length '" - << values_length << "'but got '" << i << "'."; - } - for (size_t j = i; j < values_size_; j++) { - for (size_t k = 0; k < shape_size_ - 1; k++) { - if (i * shape_size_ + k >= indices_length || j * shape_size_ + k >= indices_length) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ - << "', the size of 'indices' " - "should be the same size as indices memory length '" - << indices_length << "'but got '" << i << "'."; - } - size_t index_i = i * shape_size_ + k; - size_t index_j = j * shape_size_ + k; - if (indices_addr[index_i] == indices_addr[index_j]) { - equal_judge = 1; - } else { - equal_judge = 0; - break; - } - } - if (equal_judge == 1) { - visited[j] = 1; - exp_values.push_back(exp(values_addr[j])); - index_values.push_back(j); - } - equal_judge = 0; - } - for (size_t p = 0; p < exp_values.size(); ++p) { - exp_sum += exp_values[p]; - } - for (size_t q = 0; q < exp_values.size(); ++q) { - output_addr[index_values[q]] = exp_values[q] / exp_sum; - } - exp_sum = 0; - std::vector().swap(exp_values); - std::vector().swap(index_values); - } - return true; -} -const std::vector> - &SparseSoftmaxCpuKernelMod::GetFuncList() const { - static const std::vector> func_list = { - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeFloat32), - &SparseSoftmaxCpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeFloat64), - &SparseSoftmaxCpuKernelMod::LaunchKernel}, - }; - return func_list; -} -MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SparseSoftmax, SparseSoftmaxCpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/cpu/kernel/sparse_softmax_cpu_kernel.h" +#include +#include +#include +#include "plugin/device/cpu/hal/device/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kIndicesShapeSize = 2; +constexpr size_t kValuesShapeSize = 1; +constexpr size_t kShapeShapeSize = 1; +constexpr size_t kSparseSoftmaxInputsNum = 3; +constexpr size_t kSparseSoftmaxOutputsNum = 1; +constexpr size_t kShapeMinSize = 2; +constexpr size_t kinput_indices = 0; +constexpr size_t kinput_values = 1; +constexpr size_t kinput_shape = 2; +constexpr size_t kIndex0 = 0; +constexpr size_t kIndex1 = 1; +constexpr size_t kIndex2 = 2; + +template +inline bool CompareIndices(const I *a, const I *b, const size_t &len) { + size_t i = 0; + while (i < len) { + if (a[i] != b[i]) { + return a[i] > b[i]; + } + ++i; + } + return true; +} + +template +inline void CopyIndicesAndValue(I *dst_indices_addr, T *dst_values_addr, const I *src_indices_addr, + const T *src_values_addr, const size_t &indices_size) { + auto ret = memcpy_s(dst_indices_addr, indices_size, src_indices_addr, indices_size); + if (ret != EOK) { + MS_LOG(ERROR) << "Execute memcpy_s failed."; + } + *dst_values_addr = *src_values_addr; +} + +template +inline int64_t Partition(I *__restrict indices_addr, T *__restrict values_addr, I *__restrict tmp_indices, + const size_t &indices_len, const int64_t &left, const int64_t &right) { + int64_t i = left; + int64_t j = right; + T tmp_values = 0; + const size_t indices_size = indices_len * sizeof(I); +#define INDICES_OFFSET_ADDR(addr, index, len) (addr) + (index) * (len) + + CopyIndicesAndValue(tmp_indices, &tmp_values, INDICES_OFFSET_ADDR(indices_addr, left, indices_len), + values_addr + left, indices_size); + while (i < j) { + while (i < j && CompareIndices(INDICES_OFFSET_ADDR(indices_addr, j, indices_len), tmp_indices, indices_len)) { + --j; + } + CopyIndicesAndValue(INDICES_OFFSET_ADDR(indices_addr, i, indices_len), values_addr + i, + INDICES_OFFSET_ADDR(indices_addr, j, indices_len), values_addr + j, indices_size); + while (i < j && !CompareIndices(INDICES_OFFSET_ADDR(indices_addr, i, indices_len), tmp_indices, indices_len)) { + ++i; + } + CopyIndicesAndValue(INDICES_OFFSET_ADDR(indices_addr, j, indices_len), values_addr + j, + INDICES_OFFSET_ADDR(indices_addr, i, indices_len), values_addr + i, indices_size); + } + CopyIndicesAndValue(INDICES_OFFSET_ADDR(indices_addr, i, indices_len), values_addr + i, tmp_indices, &tmp_values, + indices_size); + return i; +} +} // namespace + +template +void QuickSortIndicesAndValues(I *__restrict indices_addr, T *__restrict values_addr, const size_t &indices_len, + const int64_t &left, const int64_t &right) { + std::stack index_stk; + (void)index_stk.emplace(right); + (void)index_stk.emplace(left); + I *indices_buff = new I[indices_len]; + + while (!index_stk.empty()) { + int64_t i = index_stk.top(); + index_stk.pop(); + int64_t j = index_stk.top(); + index_stk.pop(); + if (i < j) { + int64_t k = Partition(indices_addr, values_addr, indices_buff, indices_len, i, j); + if (k > i) { + (void)index_stk.emplace(k - 1); + (void)index_stk.emplace(i); + } + if (j > k) { + (void)index_stk.emplace(j); + (void)index_stk.emplace(k + 1); + } + } + } + delete[] indices_buff; +} + +bool SparseSoftmaxCpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + if (inputs.empty() || outputs.empty()) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid."; + return false; + } + if (!MatchKernelFunc(kernel_name_, inputs, outputs)) { + return false; + } + return true; +} + +int SparseSoftmaxCpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { + return ret; + } + auto indices_shape = inputs.at(kIndex0)->GetShapeVector(); + auto values_shape = inputs.at(kIndex1)->GetShapeVector(); + values_size_ = LongToSize(values_shape[0]); + auto shape_shape = inputs.at(kIndex2)->GetShapeVector(); + shape_size_ = LongToSize(shape_shape[0]); + auto output_shape = outputs.at(kIndex0)->GetShapeVector(); + output_shape_ = Convert2SizeT(output_shape); + if (indices_shape.size() != kIndicesShapeSize) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', it requires 'indices' should be a " << kIndicesShapeSize + << "-D Tensor, but got " << indices_shape.size() << "-D"; + } + if (values_shape.size() != kValuesShapeSize || values_shape[0] != indices_shape[0]) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ + << "', it requires 'values' should be a 1-D Tensor and " + "the first dimension length should be equal to the first dimension length of " + "'indices', but got 'values' shape: " + << values_shape << " and 'indices' shape: " << indices_shape; + } + if (shape_shape.size() != kShapeShapeSize || LongToSize(shape_shape[0]) < kShapeMinSize || + shape_shape[0] != indices_shape[1]) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ + << "', it requires 'shape' should be 1-D and more than 2 element, the element " + "should be equal to the second dimension length of 'indices', but " + "got 'shape' shape: " + << shape_shape << " and 'indices' shape: " << indices_shape; + } + return KRET_OK; +} + +template +bool SparseSoftmaxCpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &, + const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSparseSoftmaxInputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSparseSoftmaxOutputsNum, kernel_name_); + if (outputs[0]->size() == 0) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', output memory size should be greater than 0, but got 0."; + } + auto ret = memset_s(outputs[0]->device_ptr(), outputs[0]->size(), 0, outputs[0]->size()); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memset output failed. Error no: " << ret; + } + auto *indices_addr = static_cast(inputs[kIndex0]->device_ptr()); + auto *values_addr = static_cast(inputs[kIndex1]->device_ptr()); + auto *output_addr = static_cast(outputs[kIndex0]->device_ptr()); + const size_t indices_length = inputs[kIndex0]->size() / sizeof(I); + const size_t values_length = inputs[kIndex1]->size() / sizeof(T); + std::vector exp_values; + std::vector index_values; + std::vector visited; + + QuickSortIndicesAndValues(indices_addr, values_addr, shape_size_, 0, SizeToLong(values_size_) - 1); + + for (size_t i = 0; i < values_size_; ++i) { + visited.push_back(0); + } + T exp_sum = static_cast(0); + int equal_judge = 0; + for (size_t i = 0; i < values_size_; ++i) { + if (visited[i] == 1) { + continue; + } + if (i >= values_length) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ + << "', the size of 'values' " + "should be the same size as values memory length '" + << values_length << "'but got '" << i << "'."; + } + for (size_t j = i; j < values_size_; j++) { + for (size_t k = 0; k < shape_size_ - 1; k++) { + if (i * shape_size_ + k >= indices_length || j * shape_size_ + k >= indices_length) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ + << "', the size of 'indices' " + "should be the same size as indices memory length '" + << indices_length << "'but got '" << i << "'."; + } + size_t index_i = i * shape_size_ + k; + size_t index_j = j * shape_size_ + k; + if (indices_addr[index_i] == indices_addr[index_j]) { + equal_judge = 1; + } else { + equal_judge = 0; + break; + } + } + if (equal_judge == 1) { + visited[j] = 1; + exp_values.push_back(exp(values_addr[j])); + index_values.push_back(j); + } + equal_judge = 0; + } + for (size_t p = 0; p < exp_values.size(); ++p) { + exp_sum += exp_values[p]; + } + for (size_t q = 0; q < exp_values.size(); ++q) { + output_addr[index_values[q]] = exp_values[q] / exp_sum; + } + exp_sum = 0; + std::vector().swap(exp_values); + std::vector().swap(index_values); + } + return true; +} +const std::vector> + &SparseSoftmaxCpuKernelMod::GetFuncList() const { + static const std::vector> func_list = { + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat32), + &SparseSoftmaxCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat64), + &SparseSoftmaxCpuKernelMod::LaunchKernel}, + }; + return func_list; +} +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SparseSoftmax, SparseSoftmaxCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_softmax_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_softmax_cpu_kernel.h index 8945b0f1a60..1aee7f3f8a6 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_softmax_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_softmax_cpu_kernel.h @@ -1,60 +1,60 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_SOFTMAX_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_SOFTMAX_CPU_KERNEL_H_ - -#include -#include -#include -#include -#include -#include -#include "kernel/common_utils.h" -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore { -namespace kernel { -class SparseSoftmaxCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper { - public: - SparseSoftmaxCpuKernelMod() = default; - ~SparseSoftmaxCpuKernelMod() override = default; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override { - return kernel_func_(this, inputs, workspace, outputs); - } - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - const std::vector> &GetFuncList() const override; - - protected: - std::vector GetOpSupport() override { return OpSupport(); }; - - private: - template - bool LaunchKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs); - - std::vector output_shape_; - size_t values_size_{0}; - size_t shape_size_{0}; -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_SOFTMAX_CPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_SOFTMAX_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_SOFTMAX_CPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include +#include "kernel/common_utils.h" +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class SparseSoftmaxCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper { + public: + SparseSoftmaxCpuKernelMod() = default; + ~SparseSoftmaxCpuKernelMod() override = default; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override { + return kernel_func_(this, inputs, workspace, outputs); + } + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + const std::vector> &GetFuncList() const override; + + protected: + std::vector GetOpSupport() override { return OpSupport(); }; + + private: + template + bool LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs); + + std::vector output_shape_; + size_t values_size_{0}; + size_t shape_size_{0}; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_SOFTMAX_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_tensor_to_csr_sparse_matrix_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_tensor_to_csr_sparse_matrix_cpu_kernel.cc index 52e876a0651..a901cffb7f7 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_tensor_to_csr_sparse_matrix_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_tensor_to_csr_sparse_matrix_cpu_kernel.cc @@ -1,195 +1,195 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/cpu/kernel/sparse_tensor_to_csr_sparse_matrix_cpu_kernel.h" -#include "plugin/device/cpu/hal/device/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -namespace { -constexpr int64_t kRankWithoutBatch = 2; -constexpr int64_t kRankWithBatch = 3; -constexpr int64_t kZero = 0; -constexpr int64_t kOne = 1; -constexpr int64_t kTwo = 2; -constexpr int64_t kSparseTensorToCSRSparseMatrixInputsNum = 3; -constexpr int64_t kSparseTensorToCSRSparseMatrixOutputsNum = 5; -constexpr size_t kInputIndex0 = 0; -constexpr size_t kInputIndex1 = 1; -constexpr size_t kInputIndex2 = 2; -constexpr size_t kOutputIndex0 = 0; -constexpr size_t kOutputIndex1 = 1; -constexpr size_t kOutputIndex2 = 2; -constexpr size_t kOutputIndex3 = 3; -constexpr size_t kOutputIndex4 = 4; -constexpr int64_t kInitPrevBatch = -1; -constexpr char kKernelName[] = "SparseTensorToCSRSparseMatrix"; - -#define ADD_KERNEL(t1, t2, t3, t4, t5, t6, t7, t8) \ - KernelAttr() \ - .AddInputAttr(kNumberType##t1) \ - .AddInputAttr(kNumberType##t2) \ - .AddInputAttr(kNumberType##t3) \ - .AddOutputAttr(kNumberType##t4) \ - .AddOutputAttr(kNumberType##t5) \ - .AddOutputAttr(kNumberType##t6) \ - .AddOutputAttr(kNumberType##t7) \ - .AddOutputAttr(kNumberType##t8) -} // namespace - -bool SparseTensorToCSRSparseMatrixCpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - indice_type_ = inputs.at(kIndex0)->dtype_id(); - value_type_ = inputs.at(kIndex1)->dtype_id(); - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSparseTensorToCSRSparseMatrixInputsNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSparseTensorToCSRSparseMatrixOutputsNum, kernel_name_); - return true; -} - -int SparseTensorToCSRSparseMatrixCpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { - return ret; - } - auto x_indices_shape = inputs.at(kIndex0)->GetShapeVector(); - total_nnz_ = x_indices_shape[0]; - auto input_shape = inputs.at(kIndex2)->GetShapeVector(); - rank_ = input_shape[0]; - return KRET_OK; -} - -bool SparseTensorToCSRSparseMatrixCpuKernelMod::Launch(const std::vector &inputs, - const std::vector &, - const std::vector &outputs) { - switch (indice_type_) { - case kNumberTypeInt32: - switch (value_type_) { - case kNumberTypeFloat32: - LaunchKernel(inputs, outputs); - break; - case kNumberTypeFloat64: - LaunchKernel(inputs, outputs); - break; - case kNumberTypeComplex64: - LaunchKernel(inputs, outputs); - break; - case kNumberTypeComplex128: - LaunchKernel(inputs, outputs); - break; - default: - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', dtype of values should be " - << "float32, float64, complex64 or complex128, but got " - << TypeIdToType(value_type_)->ToString() << "."; - } - break; - case kNumberTypeInt64: - switch (value_type_) { - case kNumberTypeFloat32: - LaunchKernel(inputs, outputs); - break; - case kNumberTypeFloat64: - LaunchKernel(inputs, outputs); - break; - case kNumberTypeComplex64: - LaunchKernel(inputs, outputs); - break; - case kNumberTypeComplex128: - LaunchKernel(inputs, outputs); - break; - default: - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', dtype of values should be " - << "float32, float64, complex64 or complex128, but got " - << TypeIdToType(value_type_)->ToString() << "."; - } - break; - default: - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', dtype of indices should be int32 or int64, " - << "but got " << TypeIdToType(indice_type_)->ToString() << "."; - } - return true; -} - -template -void SparseTensorToCSRSparseMatrixCpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &outputs) { - const int64_t shift = (rank_ == kRankWithoutBatch) ? kZero : kOne; - num_rows_ = *(static_cast(inputs[kInputIndex2]->device_ptr()) + shift); - indiceT *x_indices = static_cast(inputs[kInputIndex0]->device_ptr()); - valueT *x_values = static_cast(inputs[kInputIndex1]->device_ptr()); - indiceT *x_dense_shape = static_cast(inputs[kInputIndex2]->device_ptr()); - batch_size_ = (rank_ == kRankWithoutBatch) ? kOne : x_dense_shape[kZero]; - indiceT *y_dense_shape_addr = static_cast(outputs[kOutputIndex0]->device_ptr()); - indiceT *y_batch_pointers_addr = static_cast(outputs[kOutputIndex1]->device_ptr()); - indiceT *y_row_pointers_addr = static_cast(outputs[kOutputIndex2]->device_ptr()); - indiceT *y_col_indices_addr = static_cast(outputs[kOutputIndex3]->device_ptr()); - valueT *y_values_addr = static_cast(outputs[kOutputIndex4]->device_ptr()); - - for (int64_t i = kZero; i < rank_; i++) { - y_dense_shape_addr[i] = x_dense_shape[i]; - } - - for (int64_t i = kZero; i < total_nnz_; i++) { - y_values_addr[i] = x_values[i]; - } - - for (int64_t i = kZero; i < batch_size_ * (num_rows_ + 1); i++) { - y_row_pointers_addr[i] = indiceT(kZero); - } - - int64_t prev_batch = kInitPrevBatch; - if (rank_ == kRankWithoutBatch) { - y_batch_pointers_addr[kZero] = indiceT(kZero); - ++prev_batch; - for (int64_t i = kZero; i < total_nnz_; ++i) { - y_row_pointers_addr[x_indices[i * rank_] + kOne] += indiceT(kOne); - y_col_indices_addr[i] = x_indices[i * rank_ + kOne]; - } - } else { - for (int64_t i = kZero; i < total_nnz_; ++i) { - int64_t cur_batch = static_cast(x_indices[i * rank_]); - y_row_pointers_addr[cur_batch * (num_rows_ + kOne) + x_indices[i * rank_ + kOne] + kOne] += kOne; - y_col_indices_addr[i] = x_indices[i * rank_ + kTwo]; - while (prev_batch < cur_batch) { - y_batch_pointers_addr[prev_batch + kOne] = indiceT(i); - ++prev_batch; - } - } - } - while (prev_batch < batch_size_) { - y_batch_pointers_addr[prev_batch + kOne] = total_nnz_; - ++prev_batch; - } - for (int64_t batch_idx = 0; batch_idx < batch_size_; ++batch_idx) { - auto *row_ptr_batch = y_row_pointers_addr + batch_idx * (num_rows_ + kOne); - (void)std::partial_sum(row_ptr_batch, row_ptr_batch + num_rows_ + kOne, row_ptr_batch); - } -} -std::vector SparseTensorToCSRSparseMatrixCpuKernelMod::GetOpSupport() { - static std::vector kernel_attr_list = { - ADD_KERNEL(Int32, Float32, Int32, Int32, Int32, Int32, Int32, Float32), - ADD_KERNEL(Int32, Float64, Int32, Int32, Int32, Int32, Int32, Float64), - ADD_KERNEL(Int32, Complex64, Int32, Int32, Int32, Int32, Int32, Complex64), - ADD_KERNEL(Int32, Complex128, Int32, Int32, Int32, Int32, Int32, Complex128), - ADD_KERNEL(Int64, Float32, Int64, Int64, Int64, Int64, Int64, Float32), - ADD_KERNEL(Int64, Float64, Int64, Int64, Int64, Int64, Int64, Float64), - ADD_KERNEL(Int64, Complex64, Int64, Int64, Int64, Int64, Int64, Complex64), - ADD_KERNEL(Int64, Complex128, Int64, Int64, Int64, Int64, Int64, Complex128)}; - return kernel_attr_list; -} - -MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SparseTensorToCSRSparseMatrix, SparseTensorToCSRSparseMatrixCpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/cpu/kernel/sparse_tensor_to_csr_sparse_matrix_cpu_kernel.h" +#include "plugin/device/cpu/hal/device/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr int64_t kRankWithoutBatch = 2; +constexpr int64_t kRankWithBatch = 3; +constexpr int64_t kZero = 0; +constexpr int64_t kOne = 1; +constexpr int64_t kTwo = 2; +constexpr int64_t kSparseTensorToCSRSparseMatrixInputsNum = 3; +constexpr int64_t kSparseTensorToCSRSparseMatrixOutputsNum = 5; +constexpr size_t kInputIndex0 = 0; +constexpr size_t kInputIndex1 = 1; +constexpr size_t kInputIndex2 = 2; +constexpr size_t kOutputIndex0 = 0; +constexpr size_t kOutputIndex1 = 1; +constexpr size_t kOutputIndex2 = 2; +constexpr size_t kOutputIndex3 = 3; +constexpr size_t kOutputIndex4 = 4; +constexpr int64_t kInitPrevBatch = -1; +constexpr char kKernelName[] = "SparseTensorToCSRSparseMatrix"; + +#define ADD_KERNEL(t1, t2, t3, t4, t5, t6, t7, t8) \ + KernelAttr() \ + .AddInputAttr(kNumberType##t1) \ + .AddInputAttr(kNumberType##t2) \ + .AddInputAttr(kNumberType##t3) \ + .AddOutputAttr(kNumberType##t4) \ + .AddOutputAttr(kNumberType##t5) \ + .AddOutputAttr(kNumberType##t6) \ + .AddOutputAttr(kNumberType##t7) \ + .AddOutputAttr(kNumberType##t8) +} // namespace + +bool SparseTensorToCSRSparseMatrixCpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + indice_type_ = inputs.at(kIndex0)->dtype_id(); + value_type_ = inputs.at(kIndex1)->dtype_id(); + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSparseTensorToCSRSparseMatrixInputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSparseTensorToCSRSparseMatrixOutputsNum, kernel_name_); + return true; +} + +int SparseTensorToCSRSparseMatrixCpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { + return ret; + } + auto x_indices_shape = inputs.at(kIndex0)->GetShapeVector(); + total_nnz_ = x_indices_shape[0]; + auto input_shape = inputs.at(kIndex2)->GetShapeVector(); + rank_ = input_shape[0]; + return KRET_OK; +} + +bool SparseTensorToCSRSparseMatrixCpuKernelMod::Launch(const std::vector &inputs, + const std::vector &, + const std::vector &outputs) { + switch (indice_type_) { + case kNumberTypeInt32: + switch (value_type_) { + case kNumberTypeFloat32: + LaunchKernel(inputs, outputs); + break; + case kNumberTypeFloat64: + LaunchKernel(inputs, outputs); + break; + case kNumberTypeComplex64: + LaunchKernel(inputs, outputs); + break; + case kNumberTypeComplex128: + LaunchKernel(inputs, outputs); + break; + default: + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', dtype of values should be " + << "float32, float64, complex64 or complex128, but got " + << TypeIdToType(value_type_)->ToString() << "."; + } + break; + case kNumberTypeInt64: + switch (value_type_) { + case kNumberTypeFloat32: + LaunchKernel(inputs, outputs); + break; + case kNumberTypeFloat64: + LaunchKernel(inputs, outputs); + break; + case kNumberTypeComplex64: + LaunchKernel(inputs, outputs); + break; + case kNumberTypeComplex128: + LaunchKernel(inputs, outputs); + break; + default: + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', dtype of values should be " + << "float32, float64, complex64 or complex128, but got " + << TypeIdToType(value_type_)->ToString() << "."; + } + break; + default: + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', dtype of indices should be int32 or int64, " + << "but got " << TypeIdToType(indice_type_)->ToString() << "."; + } + return true; +} + +template +void SparseTensorToCSRSparseMatrixCpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { + const int64_t shift = (rank_ == kRankWithoutBatch) ? kZero : kOne; + num_rows_ = *(static_cast(inputs[kInputIndex2]->device_ptr()) + shift); + indiceT *x_indices = static_cast(inputs[kInputIndex0]->device_ptr()); + valueT *x_values = static_cast(inputs[kInputIndex1]->device_ptr()); + indiceT *x_dense_shape = static_cast(inputs[kInputIndex2]->device_ptr()); + batch_size_ = (rank_ == kRankWithoutBatch) ? kOne : x_dense_shape[kZero]; + indiceT *y_dense_shape_addr = static_cast(outputs[kOutputIndex0]->device_ptr()); + indiceT *y_batch_pointers_addr = static_cast(outputs[kOutputIndex1]->device_ptr()); + indiceT *y_row_pointers_addr = static_cast(outputs[kOutputIndex2]->device_ptr()); + indiceT *y_col_indices_addr = static_cast(outputs[kOutputIndex3]->device_ptr()); + valueT *y_values_addr = static_cast(outputs[kOutputIndex4]->device_ptr()); + + for (int64_t i = kZero; i < rank_; i++) { + y_dense_shape_addr[i] = x_dense_shape[i]; + } + + for (int64_t i = kZero; i < total_nnz_; i++) { + y_values_addr[i] = x_values[i]; + } + + for (int64_t i = kZero; i < batch_size_ * (num_rows_ + 1); i++) { + y_row_pointers_addr[i] = indiceT(kZero); + } + + int64_t prev_batch = kInitPrevBatch; + if (rank_ == kRankWithoutBatch) { + y_batch_pointers_addr[kZero] = indiceT(kZero); + ++prev_batch; + for (int64_t i = kZero; i < total_nnz_; ++i) { + y_row_pointers_addr[x_indices[i * rank_] + kOne] += indiceT(kOne); + y_col_indices_addr[i] = x_indices[i * rank_ + kOne]; + } + } else { + for (int64_t i = kZero; i < total_nnz_; ++i) { + int64_t cur_batch = static_cast(x_indices[i * rank_]); + y_row_pointers_addr[cur_batch * (num_rows_ + kOne) + x_indices[i * rank_ + kOne] + kOne] += kOne; + y_col_indices_addr[i] = x_indices[i * rank_ + kTwo]; + while (prev_batch < cur_batch) { + y_batch_pointers_addr[prev_batch + kOne] = indiceT(i); + ++prev_batch; + } + } + } + while (prev_batch < batch_size_) { + y_batch_pointers_addr[prev_batch + kOne] = total_nnz_; + ++prev_batch; + } + for (int64_t batch_idx = 0; batch_idx < batch_size_; ++batch_idx) { + auto *row_ptr_batch = y_row_pointers_addr + batch_idx * (num_rows_ + kOne); + (void)std::partial_sum(row_ptr_batch, row_ptr_batch + num_rows_ + kOne, row_ptr_batch); + } +} +std::vector SparseTensorToCSRSparseMatrixCpuKernelMod::GetOpSupport() { + static std::vector kernel_attr_list = { + ADD_KERNEL(Int32, Float32, Int32, Int32, Int32, Int32, Int32, Float32), + ADD_KERNEL(Int32, Float64, Int32, Int32, Int32, Int32, Int32, Float64), + ADD_KERNEL(Int32, Complex64, Int32, Int32, Int32, Int32, Int32, Complex64), + ADD_KERNEL(Int32, Complex128, Int32, Int32, Int32, Int32, Int32, Complex128), + ADD_KERNEL(Int64, Float32, Int64, Int64, Int64, Int64, Int64, Float32), + ADD_KERNEL(Int64, Float64, Int64, Int64, Int64, Int64, Int64, Float64), + ADD_KERNEL(Int64, Complex64, Int64, Int64, Int64, Int64, Int64, Complex64), + ADD_KERNEL(Int64, Complex128, Int64, Int64, Int64, Int64, Int64, Complex128)}; + return kernel_attr_list; +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SparseTensorToCSRSparseMatrix, SparseTensorToCSRSparseMatrixCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_tensor_to_csr_sparse_matrix_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_tensor_to_csr_sparse_matrix_cpu_kernel.h index 457c378540a..344239a7564 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_tensor_to_csr_sparse_matrix_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_tensor_to_csr_sparse_matrix_cpu_kernel.h @@ -1,64 +1,64 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_SPARSE_TENSOR_TO_CSR_SPARSE_MATRIX_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_SPARSE_TENSOR_TO_CSR_SPARSE_MATRIX_CPU_KERNEL_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore { -namespace kernel { -using complex64 = std::complex; -using complex128 = std::complex; -class SparseTensorToCSRSparseMatrixCpuKernelMod : public NativeCpuKernelMod { - public: - SparseTensorToCSRSparseMatrixCpuKernelMod() = default; - ~SparseTensorToCSRSparseMatrixCpuKernelMod() override = default; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - protected: - std::vector GetOpSupport() override; - - private: - template - void LaunchKernel(const std::vector &inputs, const std::vector &outputs); - int64_t rank_{1}; - int64_t num_rows_{1}; - int64_t total_nnz_{1}; - int64_t batch_size_{1}; - CNodeWeakPtr node_wpt_; - TypeId value_type_{kTypeUnknown}; - TypeId indice_type_{kTypeUnknown}; -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_SPARSE_TENSOR_TO_CSR_SPARSE_MATRIX_CPU_KERNEL_H_ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_SPARSE_TENSOR_TO_CSR_SPARSE_MATRIX_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_SPARSE_TENSOR_TO_CSR_SPARSE_MATRIX_CPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +using complex64 = std::complex; +using complex128 = std::complex; +class SparseTensorToCSRSparseMatrixCpuKernelMod : public NativeCpuKernelMod { + public: + SparseTensorToCSRSparseMatrixCpuKernelMod() = default; + ~SparseTensorToCSRSparseMatrixCpuKernelMod() override = default; + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + protected: + std::vector GetOpSupport() override; + + private: + template + void LaunchKernel(const std::vector &inputs, const std::vector &outputs); + int64_t rank_{1}; + int64_t num_rows_{1}; + int64_t total_nnz_{1}; + int64_t batch_size_{1}; + CNodeWeakPtr node_wpt_; + TypeId value_type_{kTypeUnknown}; + TypeId indice_type_{kTypeUnknown}; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_SPARSE_TENSOR_TO_CSR_SPARSE_MATRIX_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/trace_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/trace_cpu_kernel.cc index 8c9d7f81433..beee9ddcc8c 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/trace_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/trace_cpu_kernel.cc @@ -1,106 +1,106 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "plugin/device/cpu/kernel/trace_cpu_kernel.h" -#include "plugin/device/cpu/hal/device/cpu_device_address.h" -#include "kernel/common_utils.h" - -namespace mindspore { -namespace kernel { -namespace { -constexpr size_t kInputNum = 1; -constexpr size_t kInputDim = 2; -constexpr size_t kOutputNum = 1; -} // namespace - -bool TraceCpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { - values_type_ = inputs.at(kIndex0)->dtype_id(); - return true; -} - -int TraceCpuKernelMod::Resize(const std::vector &inputs, const std::vector &outputs) { - if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { - return ret; - } - input_shape_ = Convert2SizeT(inputs.at(kIndex0)->GetDeviceShapeVector()); - if (input_shape_.size() != kInputDim) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', input tensor's dimension should be " << kInputDim - << ", but got " << input_shape_.size(); - } - return KRET_OK; -} - -bool TraceCpuKernelMod::Launch(const std::vector &inputs, - const std::vector &, - const std::vector &outputs) { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_); - switch (values_type_) { - case kNumberTypeInt8: - LaunchKernel(inputs, outputs); - break; - case kNumberTypeInt16: - LaunchKernel(inputs, outputs); - break; - case kNumberTypeInt32: - LaunchKernel(inputs, outputs); - break; - case kNumberTypeInt64: - LaunchKernel(inputs, outputs); - break; - case kNumberTypeUInt8: - LaunchKernel(inputs, outputs); - break; - case kNumberTypeUInt16: - LaunchKernel(inputs, outputs); - break; - case kNumberTypeUInt32: - LaunchKernel(inputs, outputs); - break; - case kNumberTypeUInt64: - LaunchKernel(inputs, outputs); - break; - case kNumberTypeFloat16: - LaunchKernel(inputs, outputs); - break; - case kNumberTypeFloat32: - LaunchKernel(inputs, outputs); - break; - case kNumberTypeFloat64: - LaunchKernel(inputs, outputs); - break; - default: - MS_LOG(EXCEPTION) << "Unsupported input data type."; - } - return true; -} - -template -void TraceCpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &outputs) { - T *input_addr = GetDeviceAddress(inputs, kIndex0); - T *output_addr = GetDeviceAddress(outputs, kIndex0); - size_t min_size = std::min(input_shape_[0], input_shape_[1]); - if (memset_s(output_addr, outputs[0]->size(), 0, outputs[0]->size()) != EOK) { - MS_LOG(EXCEPTION) << "Failed to init output memory."; - } - for (size_t i = 0; i < min_size; ++i) { - *output_addr += *(input_addr + i * input_shape_[1] + i); - } -} - -MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Trace, TraceCpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "plugin/device/cpu/kernel/trace_cpu_kernel.h" +#include "plugin/device/cpu/hal/device/cpu_device_address.h" +#include "kernel/common_utils.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kInputNum = 1; +constexpr size_t kInputDim = 2; +constexpr size_t kOutputNum = 1; +} // namespace + +bool TraceCpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { + values_type_ = inputs.at(kIndex0)->dtype_id(); + return true; +} + +int TraceCpuKernelMod::Resize(const std::vector &inputs, const std::vector &outputs) { + if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { + return ret; + } + input_shape_ = Convert2SizeT(inputs.at(kIndex0)->GetDeviceShapeVector()); + if (input_shape_.size() != kInputDim) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', input tensor's dimension should be " << kInputDim + << ", but got " << input_shape_.size(); + } + return KRET_OK; +} + +bool TraceCpuKernelMod::Launch(const std::vector &inputs, + const std::vector &, + const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_); + switch (values_type_) { + case kNumberTypeInt8: + LaunchKernel(inputs, outputs); + break; + case kNumberTypeInt16: + LaunchKernel(inputs, outputs); + break; + case kNumberTypeInt32: + LaunchKernel(inputs, outputs); + break; + case kNumberTypeInt64: + LaunchKernel(inputs, outputs); + break; + case kNumberTypeUInt8: + LaunchKernel(inputs, outputs); + break; + case kNumberTypeUInt16: + LaunchKernel(inputs, outputs); + break; + case kNumberTypeUInt32: + LaunchKernel(inputs, outputs); + break; + case kNumberTypeUInt64: + LaunchKernel(inputs, outputs); + break; + case kNumberTypeFloat16: + LaunchKernel(inputs, outputs); + break; + case kNumberTypeFloat32: + LaunchKernel(inputs, outputs); + break; + case kNumberTypeFloat64: + LaunchKernel(inputs, outputs); + break; + default: + MS_LOG(EXCEPTION) << "Unsupported input data type."; + } + return true; +} + +template +void TraceCpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { + T *input_addr = GetDeviceAddress(inputs, kIndex0); + T *output_addr = GetDeviceAddress(outputs, kIndex0); + size_t min_size = std::min(input_shape_[0], input_shape_[1]); + if (memset_s(output_addr, outputs[0]->size(), 0, outputs[0]->size()) != EOK) { + MS_LOG(EXCEPTION) << "Failed to init output memory."; + } + for (size_t i = 0; i < min_size; ++i) { + *output_addr += *(input_addr + i * input_shape_[1] + i); + } +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Trace, TraceCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/trace_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/trace_cpu_kernel.h index ed96a7f6116..2ccdbbb18b0 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/trace_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/trace_cpu_kernel.h @@ -1,66 +1,66 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRACE_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRACE_CPU_KERNEL_H_ - -#include -#include -#include -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore { -namespace kernel { -class TraceCpuKernelMod : public NativeCpuKernelMod { - public: - TraceCpuKernelMod() = default; - ~TraceCpuKernelMod() override = default; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - template - void LaunchKernel(const std::vector &inputs, const std::vector &outputs); - - std::vector GetOpSupport() override { - static const std::vector support_list = { - KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), - KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), - KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), - KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), - KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), - KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), - KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64)}; - return support_list; - } - - private: - std::vector input_shape_; - TypeId values_type_{kTypeUnknown}; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRACE_CPU_KERNEL_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRACE_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRACE_CPU_KERNEL_H_ + +#include +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class TraceCpuKernelMod : public NativeCpuKernelMod { + public: + TraceCpuKernelMod() = default; + ~TraceCpuKernelMod() override = default; + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + template + void LaunchKernel(const std::vector &inputs, const std::vector &outputs); + + std::vector GetOpSupport() override { + static const std::vector support_list = { + KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), + KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), + KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), + KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), + KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), + KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64)}; + return support_list; + } + + private: + std::vector input_shape_; + TypeId values_type_{kTypeUnknown}; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRACE_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/trace_grad_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/trace_grad_cpu_kernel.cc index 10ac74817d6..6b162ac9443 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/trace_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/trace_grad_cpu_kernel.cc @@ -1,112 +1,112 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "plugin/device/cpu/kernel/trace_grad_cpu_kernel.h" -#include "plugin/device/cpu/hal/device/cpu_device_address.h" -#include "kernel/common_utils.h" - -namespace mindspore { -namespace kernel { -namespace { -constexpr size_t kInputNum = 2; -constexpr size_t kOutputNum = 1; -} // namespace - -bool TraceGradCpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - values_type_ = inputs.at(kIndex0)->dtype_id(); - return true; -} - -int TraceGradCpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { - return ret; - } - input_shape_ = inputs.at(kIndex1)->GetDeviceShapeVector(); - const std::vector x_shape_ = {2}; - if (input_shape_ != x_shape_) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the shape of input[x_shape] should be " << x_shape_ - << ", but got " << input_shape_ << "."; - } - return KRET_OK; -} - -bool TraceGradCpuKernelMod::Launch(const std::vector &inputs, - const std::vector &, - const std::vector &outputs) { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_); - switch (values_type_) { - case kNumberTypeInt8: - LaunchKernel(inputs, outputs); - break; - case kNumberTypeUInt8: - LaunchKernel(inputs, outputs); - break; - case kNumberTypeInt16: - LaunchKernel(inputs, outputs); - break; - case kNumberTypeUInt16: - LaunchKernel(inputs, outputs); - break; - case kNumberTypeFloat16: - LaunchKernel(inputs, outputs); - break; - case kNumberTypeInt32: - LaunchKernel(inputs, outputs); - break; - case kNumberTypeUInt32: - LaunchKernel(inputs, outputs); - break; - case kNumberTypeFloat32: - LaunchKernel(inputs, outputs); - break; - case kNumberTypeInt64: - LaunchKernel(inputs, outputs); - break; - case kNumberTypeUInt64: - LaunchKernel(inputs, outputs); - break; - case kNumberTypeFloat64: - LaunchKernel(inputs, outputs); - break; - default: - MS_LOG(EXCEPTION) << "Trace Grad Unsupported input data type."; - } - return true; -} - -template -void TraceGradCpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &outputs) { - T *grad = GetDeviceAddress(inputs, kIndex0); - MS_EXCEPTION_IF_NULL(grad); - auto shape = GetDeviceAddress(inputs, kIndex1); - MS_EXCEPTION_IF_NULL(shape); - T *output_addr = GetDeviceAddress(outputs, kIndex0); - MS_EXCEPTION_IF_NULL(output_addr); - - if (memset_s(output_addr, outputs[0]->size(), 0, outputs[0]->size()) != EOK) { - MS_LOG(EXCEPTION) << "Failed to init output memory."; - } - int64_t min_size = std::min(shape[0], shape[1]); - for (int64_t i = 0; i < min_size; ++i) { - *(output_addr + i * shape[1] + i) = *grad; - } -} -MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, TraceGrad, TraceGradCpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "plugin/device/cpu/kernel/trace_grad_cpu_kernel.h" +#include "plugin/device/cpu/hal/device/cpu_device_address.h" +#include "kernel/common_utils.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kInputNum = 2; +constexpr size_t kOutputNum = 1; +} // namespace + +bool TraceGradCpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + values_type_ = inputs.at(kIndex0)->dtype_id(); + return true; +} + +int TraceGradCpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { + return ret; + } + input_shape_ = inputs.at(kIndex1)->GetDeviceShapeVector(); + const std::vector x_shape_ = {2}; + if (input_shape_ != x_shape_) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the shape of input[x_shape] should be " << x_shape_ + << ", but got " << input_shape_ << "."; + } + return KRET_OK; +} + +bool TraceGradCpuKernelMod::Launch(const std::vector &inputs, + const std::vector &, + const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_); + switch (values_type_) { + case kNumberTypeInt8: + LaunchKernel(inputs, outputs); + break; + case kNumberTypeUInt8: + LaunchKernel(inputs, outputs); + break; + case kNumberTypeInt16: + LaunchKernel(inputs, outputs); + break; + case kNumberTypeUInt16: + LaunchKernel(inputs, outputs); + break; + case kNumberTypeFloat16: + LaunchKernel(inputs, outputs); + break; + case kNumberTypeInt32: + LaunchKernel(inputs, outputs); + break; + case kNumberTypeUInt32: + LaunchKernel(inputs, outputs); + break; + case kNumberTypeFloat32: + LaunchKernel(inputs, outputs); + break; + case kNumberTypeInt64: + LaunchKernel(inputs, outputs); + break; + case kNumberTypeUInt64: + LaunchKernel(inputs, outputs); + break; + case kNumberTypeFloat64: + LaunchKernel(inputs, outputs); + break; + default: + MS_LOG(EXCEPTION) << "Trace Grad Unsupported input data type."; + } + return true; +} + +template +void TraceGradCpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { + T *grad = GetDeviceAddress(inputs, kIndex0); + MS_EXCEPTION_IF_NULL(grad); + auto shape = GetDeviceAddress(inputs, kIndex1); + MS_EXCEPTION_IF_NULL(shape); + T *output_addr = GetDeviceAddress(outputs, kIndex0); + MS_EXCEPTION_IF_NULL(output_addr); + + if (memset_s(output_addr, outputs[0]->size(), 0, outputs[0]->size()) != EOK) { + MS_LOG(EXCEPTION) << "Failed to init output memory."; + } + int64_t min_size = std::min(shape[0], shape[1]); + for (int64_t i = 0; i < min_size; ++i) { + *(output_addr + i * shape[1] + i) = *grad; + } +} +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, TraceGrad, TraceGradCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/trace_grad_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/trace_grad_cpu_kernel.h index c6df5f4ffc6..13d8939b2d8 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/trace_grad_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/trace_grad_cpu_kernel.h @@ -1,66 +1,66 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRACE_GRAD_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRACE_GRAD_CPU_KERNEL_H_ - -#include -#include -#include -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore { -namespace kernel { -class TraceGradCpuKernelMod : public NativeCpuKernelMod { - public: - TraceGradCpuKernelMod() = default; - ~TraceGradCpuKernelMod() override = default; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - template - void LaunchKernel(const std::vector &inputs, const std::vector &outputs); - - std::vector GetOpSupport() override { - static const std::vector support_list = { - KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8), - KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16), - KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), - KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), - KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8), - KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16), - KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32), - KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64), - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), - KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64)}; - return support_list; - } - - private: - std::vector input_shape_; - TypeId values_type_{kTypeUnknown}; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRACE_GRAD_CPU_KERNEL_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRACE_GRAD_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRACE_GRAD_CPU_KERNEL_H_ + +#include +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class TraceGradCpuKernelMod : public NativeCpuKernelMod { + public: + TraceGradCpuKernelMod() = default; + ~TraceGradCpuKernelMod() override = default; + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + template + void LaunchKernel(const std::vector &inputs, const std::vector &outputs); + + std::vector GetOpSupport() override { + static const std::vector support_list = { + KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8), + KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16), + KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), + KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8), + KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16), + KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32), + KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64), + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), + KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64)}; + return support_list; + } + + private: + std::vector input_shape_; + TypeId values_type_{kTypeUnknown}; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRACE_GRAD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/optimizer/insert_cast_cpu.h b/mindspore/ccsrc/plugin/device/cpu/optimizer/insert_cast_cpu.h index 51aef9993a0..d5fc3df2d52 100644 --- a/mindspore/ccsrc/plugin/device/cpu/optimizer/insert_cast_cpu.h +++ b/mindspore/ccsrc/plugin/device/cpu/optimizer/insert_cast_cpu.h @@ -1,35 +1,35 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_CPU_INSERT_CAST_CPU_H -#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_CPU_INSERT_CAST_CPU_H - -#include -#include "include/backend/optimizer/optimizer.h" -#include "ir/anf.h" - -namespace mindspore { -namespace opt { -class InsertCastCPU : public Pass { - public: - explicit InsertCastCPU(const std::string & /* name */) : Pass("insert_cast_cpu") {} - ~InsertCastCPU() override = default; - bool Run(const FuncGraphPtr &graph) override; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_CPU_INSERT_CAST_CPU_H +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_CPU_INSERT_CAST_CPU_H +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_CPU_INSERT_CAST_CPU_H + +#include +#include "include/backend/optimizer/optimizer.h" +#include "ir/anf.h" + +namespace mindspore { +namespace opt { +class InsertCastCPU : public Pass { + public: + explicit InsertCastCPU(const std::string & /* name */) : Pass("insert_cast_cpu") {} + ~InsertCastCPU() override = default; + bool Run(const FuncGraphPtr &graph) override; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_CPU_INSERT_CAST_CPU_H diff --git a/mindspore/ccsrc/plugin/device/cpu/optimizer/insert_format_transform_op.h b/mindspore/ccsrc/plugin/device/cpu/optimizer/insert_format_transform_op.h index 74463653ca5..cd6611a8f47 100644 --- a/mindspore/ccsrc/plugin/device/cpu/optimizer/insert_format_transform_op.h +++ b/mindspore/ccsrc/plugin/device/cpu/optimizer/insert_format_transform_op.h @@ -1,35 +1,35 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_CPU_INSERT_FORMAT_TRANSFORM_OP_H -#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_CPU_INSERT_FORMAT_TRANSFORM_OP_H - -#include -#include "include/backend/optimizer/optimizer.h" -#include "ir/anf.h" - -namespace mindspore { -namespace opt { -class InsertFormatTransformOpCPU : public Pass { - public: - explicit InsertFormatTransformOpCPU(const std::string &) : Pass("insert_format_transform_op_cpu") {} - ~InsertFormatTransformOpCPU() override = default; - bool Run(const FuncGraphPtr &graph) override; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_CPU_INSERT_FORMAT_TRANSFORM_OP_H +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_CPU_INSERT_FORMAT_TRANSFORM_OP_H +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_CPU_INSERT_FORMAT_TRANSFORM_OP_H + +#include +#include "include/backend/optimizer/optimizer.h" +#include "ir/anf.h" + +namespace mindspore { +namespace opt { +class InsertFormatTransformOpCPU : public Pass { + public: + explicit InsertFormatTransformOpCPU(const std::string &) : Pass("insert_format_transform_op_cpu") {} + ~InsertFormatTransformOpCPU() override = default; + bool Run(const FuncGraphPtr &graph) override; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_CPU_INSERT_FORMAT_TRANSFORM_OP_H diff --git a/mindspore/ccsrc/plugin/device/gpu/hal/device/cuda_driver.h b/mindspore/ccsrc/plugin/device/gpu/hal/device/cuda_driver.h index 4c35db905f2..a963beabdb0 100644 --- a/mindspore/ccsrc/plugin/device/gpu/hal/device/cuda_driver.h +++ b/mindspore/ccsrc/plugin/device/gpu/hal/device/cuda_driver.h @@ -1,90 +1,90 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_CUDA_DRIVER_H_ -#define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_CUDA_DRIVER_H_ - -#include - -namespace mindspore { -namespace device { -namespace gpu { -typedef void *CudaDeviceStream; -typedef void *CudaDeviceEvent; -typedef void *HostMemPtr; -typedef void *DeviceMemPtr; - -class CudaDriver { - public: - // Encapsulate the cuda APIs associated with memory operations - // such as malloc/free and memory copy from host to device and reverse. - static size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr); - static bool FreeDeviceMem(const DeviceMemPtr &addr); - static size_t AllocHostPinnedMem(size_t size, void **addr); - static void FreeHostPinnedMem(void *addr); - - static void CudaHostRegister(void *addr, size_t alloc_size); - - static void CudaHostUnregister(void *addr); - - static bool CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size); - static bool CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size); - static bool CopyHostMemToHost(const DeviceMemPtr &dst, const void *src, size_t size); - - static bool CopyHostMemToDeviceAsync(const DeviceMemPtr &dst, const void *src, size_t size, - CudaDeviceStream stream = 0); - static bool CopyDeviceMemToHostAsync(const HostMemPtr &dst, const void *src, size_t size, - CudaDeviceStream stream = 0); - static bool CopyDeviceMemToDeviceAsync(const DeviceMemPtr &dst, const void *src, size_t size, - CudaDeviceStream stream = 0); - - static size_t total_mem_size(); - static size_t free_mem_size(); - - // Encapsulate the cuda APIs associated with device resource - // such as Stream and Event. - static bool CreateStream(CudaDeviceStream *stream); - static bool CreateStreamWithPriority(CudaDeviceStream *stream, int priority); - static bool DestroyStream(const CudaDeviceStream &stream); - static bool SyncStream(const CudaDeviceStream &stream); - static bool QueryStream(const CudaDeviceStream &stream); - - static bool ConstructEvent(CudaDeviceEvent *event, unsigned int flag = cudaEventDefault); - static bool DestroyEvent(const CudaDeviceEvent &event); - static bool RecordEvent(CudaDeviceEvent event, CudaDeviceStream stream = 0); - static bool SyncEvent(const CudaDeviceEvent &event); - static bool QueryEvent(const CudaDeviceEvent &event); - static bool ElapsedTime(float *cost_time, const CudaDeviceEvent &start, const CudaDeviceEvent &end); - - // Encapsulate the cuda APIs associated with device management. - static int device_count(); - static bool SetDevice(int index); - - private: - CudaDriver() = delete; - ~CudaDriver() = delete; - CudaDriver(const CudaDriver &) = delete; - CudaDriver &operator=(const CudaDriver &) = delete; - - static constexpr float mem_malloc_retry_rate_{0.99}; - static constexpr size_t mem_malloc_retry_conut_max_{10}; - static constexpr size_t mem_malloc_align_size_{4}; -}; -} // namespace gpu -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_CUDA_DRIVER_H_ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_CUDA_DRIVER_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_CUDA_DRIVER_H_ + +#include + +namespace mindspore { +namespace device { +namespace gpu { +typedef void *CudaDeviceStream; +typedef void *CudaDeviceEvent; +typedef void *HostMemPtr; +typedef void *DeviceMemPtr; + +class CudaDriver { + public: + // Encapsulate the cuda APIs associated with memory operations + // such as malloc/free and memory copy from host to device and reverse. + static size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr); + static bool FreeDeviceMem(const DeviceMemPtr &addr); + static size_t AllocHostPinnedMem(size_t size, void **addr); + static void FreeHostPinnedMem(void *addr); + + static void CudaHostRegister(void *addr, size_t alloc_size); + + static void CudaHostUnregister(void *addr); + + static bool CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size); + static bool CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size); + static bool CopyHostMemToHost(const DeviceMemPtr &dst, const void *src, size_t size); + + static bool CopyHostMemToDeviceAsync(const DeviceMemPtr &dst, const void *src, size_t size, + CudaDeviceStream stream = 0); + static bool CopyDeviceMemToHostAsync(const HostMemPtr &dst, const void *src, size_t size, + CudaDeviceStream stream = 0); + static bool CopyDeviceMemToDeviceAsync(const DeviceMemPtr &dst, const void *src, size_t size, + CudaDeviceStream stream = 0); + + static size_t total_mem_size(); + static size_t free_mem_size(); + + // Encapsulate the cuda APIs associated with device resource + // such as Stream and Event. + static bool CreateStream(CudaDeviceStream *stream); + static bool CreateStreamWithPriority(CudaDeviceStream *stream, int priority); + static bool DestroyStream(const CudaDeviceStream &stream); + static bool SyncStream(const CudaDeviceStream &stream); + static bool QueryStream(const CudaDeviceStream &stream); + + static bool ConstructEvent(CudaDeviceEvent *event, unsigned int flag = cudaEventDefault); + static bool DestroyEvent(const CudaDeviceEvent &event); + static bool RecordEvent(CudaDeviceEvent event, CudaDeviceStream stream = 0); + static bool SyncEvent(const CudaDeviceEvent &event); + static bool QueryEvent(const CudaDeviceEvent &event); + static bool ElapsedTime(float *cost_time, const CudaDeviceEvent &start, const CudaDeviceEvent &end); + + // Encapsulate the cuda APIs associated with device management. + static int device_count(); + static bool SetDevice(int index); + + private: + CudaDriver() = delete; + ~CudaDriver() = delete; + CudaDriver(const CudaDriver &) = delete; + CudaDriver &operator=(const CudaDriver &) = delete; + + static constexpr float mem_malloc_retry_rate_{0.99}; + static constexpr size_t mem_malloc_retry_conut_max_{10}; + static constexpr size_t mem_malloc_align_size_{4}; +}; +} // namespace gpu +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_CUDA_DRIVER_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_device_manager.cc b/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_device_manager.cc index df1b78a33d6..ff6b3ec1eb5 100644 --- a/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_device_manager.cc +++ b/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_device_manager.cc @@ -1,288 +1,288 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/hal/device/gpu_device_manager.h" - -#include - -#include "plugin/device/gpu/hal/device/gpu_common.h" -#include "utils/log_adapter.h" -#include "include/common/utils/convert_utils.h" - -namespace mindspore { -namespace device { -namespace gpu { -GPUDeviceManager &GPUDeviceManager::GetInstance() { - static GPUDeviceManager instance; - return instance; -} - -void GPUDeviceManager::InitDevice() { - CHECK_OP_RET_WITH_EXCEPT(CudaDriver::SetDevice(SizeToInt(cur_dev_id_)), "Failed to set current device id"); - if (dev_alive_) { - return; - } - CHECK_OP_RET_WITH_EXCEPT(CreateStream(&default_stream_), "Failed to create CUDA stream."); - default_stream_id_ = gpu_streams_.size() - 1; - CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnCreate(&cudnn_handle_), "Failed to create cuDNN handle"); - CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnSetStream(cudnn_handle_, reinterpret_cast(default_stream())), - "Failed to set stream for cuDNN handle."); - CHECK_CUBLAS_RET_WITH_EXCEPT_NOTRACE(cublasCreate(&cublas_handle_), "Failed to create cuBLAS handle."); - CHECK_CUBLAS_RET_WITH_EXCEPT_NOTRACE( - cublasSetStream(cublas_handle_, reinterpret_cast(default_stream())), - "Failed to set stream for cuBLAS handle."); - CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(cusolverDnCreate(&cusolver_dn_handle_), - "Failed to create cusolver dn handle."); - CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE( - cusolverDnSetStream(cusolver_dn_handle_, reinterpret_cast(default_stream())), - "Failed to set stream for cusolver dn handle"); - // Create cusparse handle. - CHECK_CUSPARSE_RET_WITH_EXCEPT(cusparseCreate(&cusparse_handle_), "Failed to create sparse handle."); - CHECK_CUSPARSE_RET_WITH_EXCEPT(cusparseSetStream(cusparse_handle_, reinterpret_cast(default_stream())), - "Failed to set stream for cusparse handle"); - - CHECK_OP_RET_WITH_EXCEPT(GPUMemoryAllocator::GetInstance().Init(), "Failed to Init gpu memory allocator") - dev_alive_ = true; -} - -void GPUDeviceManager::ReleaseDevice() { - // Avoid repeated release device resource. - if (!dev_alive_) { - return; - } - { - std::lock_guard lock_gpu_streams(stream_mutex_); - for (CudaDeviceStream stream : gpu_streams_) { - if (stream != nullptr) { - CHECK_OP_RET_WITH_ERROR(CudaDriver::DestroyStream(stream), "Failed to destroy CUDA stream."); - } - } - gpu_streams_.clear(); - } - - if (cudnn_handle_ != nullptr) { - CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(cudnnDestroy(cudnn_handle_), "Failed to destroy cuDNN handle"); - } - if (cublas_handle_ != nullptr) { - CHECK_CUBLAS_RET_WITH_ERROR(cublasDestroy(cublas_handle_), "Failed to destroy cuBLAS handle."); - } - if (cusolver_dn_handle_ != nullptr) { - CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnDestroy(cusolver_dn_handle_), "Failed to destroy cusolver dn handle."); - } - if (cusparse_handle_ != nullptr) { - CHECK_CUSPARSE_RET_WITH_ERROR(cusparseDestroy(cusparse_handle_), "Failed to destroy cusparse handle."); - } - - dev_alive_ = false; -} - -bool GPUDeviceManager::CreateStream(CudaDeviceStream *stream) { - std::lock_guard lock_gpu_streams(stream_mutex_); - CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateStream(stream), "Failed to create CUDA stream"); - (void)gpu_streams_.emplace_back(*stream); - return true; -} - -bool GPUDeviceManager::CreateStream(size_t *stream_id) { - MS_EXCEPTION_IF_NULL(stream_id); - - std::lock_guard lock_gpu_streams(stream_mutex_); - CudaDeviceStream stream; - CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateStream(&stream), "Failed to create CUDA stream"); - *stream_id = gpu_streams_.size(); - (void)gpu_streams_.emplace_back(stream); - return true; -} - -bool GPUDeviceManager::CreateStreamWithPriority(size_t *stream_id, int32_t priority) { - MS_EXCEPTION_IF_NULL(stream_id); - - std::lock_guard lock_gpu_streams(stream_mutex_); - CudaDeviceStream stream; - CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateStreamWithPriority(&stream, priority), - "Failed to create CUDA stream with priority"); - *stream_id = gpu_streams_.size(); - (void)gpu_streams_.emplace_back(stream); - - return true; -} - -bool GPUDeviceManager::DestroyStream(size_t stream_id) { - std::lock_guard lock_gpu_streams(stream_mutex_); - if (stream_id >= gpu_streams_.size()) { - MS_LOG(ERROR) << "CUDA stream not found for stream id " << stream_id; - return false; - } - if (gpu_streams_.at(stream_id) == nullptr) { - MS_LOG(WARNING) << "CUDA stream hsa been destroyed for stream id " << stream_id; - return true; - } - CHECK_OP_RET_WITH_EXCEPT(CudaDriver::DestroyStream(gpu_streams_.at(stream_id)), "Failed to destroy CUDA stream"); - gpu_streams_[stream_id] = nullptr; - return true; -} - -CudaDeviceStream GPUDeviceManager::GetStream(size_t stream_id) const { - if (stream_id >= gpu_streams_.size()) { - MS_LOG(DEBUG) << "Stream for stream id[" << stream_id << "] not found, return nullptr."; - return nullptr; - } - return gpu_streams_[stream_id]; -} - -size_t GPUDeviceManager::QueryStreamSize() const { - return std::count_if(gpu_streams_.begin(), gpu_streams_.end(), - [](CudaDeviceStream stream) { return stream != nullptr; }); -} - -std::vector GPUDeviceManager::GetStreamIds() const { - std::vector stream_ids; - for (size_t i = 0; i < gpu_streams_.size(); i++) { - if (gpu_streams_[i] != nullptr) { - (void)stream_ids.emplace_back(static_cast(i)); - } - } - return stream_ids; -} - -void GPUDeviceManager::set_current_stream(size_t stream_id) { current_stream_id_ = stream_id; } - -size_t GPUDeviceManager::current_stream() const { return current_stream_id_; } - -bool GPUDeviceManager::QueryStream(size_t stream_id) { - if (stream_id >= gpu_streams_.size()) { - MS_LOG(ERROR) << "CUDA stream not found for stream id " << stream_id; - return false; - } - if (gpu_streams_.at(stream_id) == nullptr) { - MS_LOG(WARNING) << "CUDA stream has been destroyed for stream id " << stream_id; - return true; - } - MS_LOG(DEBUG) << "Query completion status of stream id: " << stream_id; - return CudaDriver::QueryStream(gpu_streams_.at(stream_id)); -} - -const CudaDeviceStream &GPUDeviceManager::default_stream() const { return default_stream_; } - -size_t GPUDeviceManager::default_stream_id() const { return default_stream_id_; } - -int GPUDeviceManager::device_count() const { return CudaDriver::device_count(); } - -bool GPUDeviceManager::set_cur_device_id(uint32_t device_id) { - if (!dev_id_init_) { - dev_id_init_ = true; - cur_dev_id_ = device_id; - return true; - } else { - MS_LOG(ERROR) << "Device already been set."; - return false; - } -} - -uint32_t GPUDeviceManager::cur_device_id() const { return cur_dev_id_; } - -bool GPUDeviceManager::is_device_id_init() const { return dev_id_init_; } - -const cudnnHandle_t &GPUDeviceManager::GetCudnnHandle() const { return cudnn_handle_; } - -const cublasHandle_t &GPUDeviceManager::GetCublasHandle() const { return cublas_handle_; } - -const cusolverDnHandle_t &GPUDeviceManager::GetCusolverDnHandle() const { return cusolver_dn_handle_; } - -const cusparseHandle_t &GPUDeviceManager::GetCuSparseHandle() const { return cusparse_handle_; } - -bool GPUDeviceManager::SyncStream(size_t stream_id) const { - if (!dev_alive_) { - return false; - } - auto stream = GetStream(stream_id); - if (stream == nullptr) { - MS_LOG(EXCEPTION) << "Get CUDA stream for stream id failed."; - } - return SyncStream(stream); -} - -bool GPUDeviceManager::SyncStream(const CudaDeviceStream &stream) const { - return dev_alive_ && CudaDriver::SyncStream(stream); -} - -bool GPUDeviceManager::SyncAllStreams() const { - if (!dev_alive_) { - return false; - } - for (const auto &stream : gpu_streams_) { - if (stream != nullptr && !SyncStream(stream)) { - return false; - } - } - return true; -} - -bool GPUDeviceManager::SyncNotDefaultStreams() const { - bool res = true; - for (size_t i = 0; i < gpu_streams_.size(); i++) { - if (i != default_stream_id_ && !SyncStream(i)) { - MS_LOG(ERROR) << "Failed to sync for gpu stream id: " << i; - res = false; - } - } - return res; -} - -bool GPUDeviceManager::SyncExceptStreamsInList(const std::set &except_streams) const { - bool res = true; - for (size_t i = 0; i < gpu_streams_.size(); i++) { - if (except_streams.count(gpu_streams_[i]) > 0) { - MS_LOG(DEBUG) << "Stream id:" << i << " is been synchronized."; - continue; - } - if (!SyncStream(i)) { - MS_LOG(ERROR) << "Failed to sync for gpu stream id: " << i; - res = false; - } - } - return res; -} - -bool GPUDeviceManager::CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const { - return CudaDriver::CopyDeviceMemToHost(dst, src, size); -} - -bool GPUDeviceManager::CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) const { - return CudaDriver::CopyHostMemToDevice(dst, src, size); -} - -bool GPUDeviceManager::CopyHostMemToHost(const HostMemPtr &dst, const void *src, size_t size) const { - return CudaDriver::CopyHostMemToHost(dst, src, size); -} - -bool GPUDeviceManager::CopyDeviceMemToHostAsync(const HostMemPtr &dst, const void *src, size_t size, - CudaDeviceStream stream) const { - return CudaDriver::CopyDeviceMemToHostAsync(dst, src, size, stream); -} - -bool GPUDeviceManager::CopyHostMemToDeviceAsync(const DeviceMemPtr &dst, const void *src, size_t size, - CudaDeviceStream stream) const { - return CudaDriver::CopyHostMemToDeviceAsync(dst, src, size, stream); -} - -bool GPUDeviceManager::CopyDeviceMemToDeviceAsync(const DeviceMemPtr &dst, const void *src, size_t size, - CudaDeviceStream stream) const { - return CudaDriver::CopyDeviceMemToDeviceAsync(dst, src, size, stream); -} -} // namespace gpu -} // namespace device -} // namespace mindspore +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/hal/device/gpu_device_manager.h" + +#include + +#include "plugin/device/gpu/hal/device/gpu_common.h" +#include "utils/log_adapter.h" +#include "include/common/utils/convert_utils.h" + +namespace mindspore { +namespace device { +namespace gpu { +GPUDeviceManager &GPUDeviceManager::GetInstance() { + static GPUDeviceManager instance; + return instance; +} + +void GPUDeviceManager::InitDevice() { + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::SetDevice(SizeToInt(cur_dev_id_)), "Failed to set current device id"); + if (dev_alive_) { + return; + } + CHECK_OP_RET_WITH_EXCEPT(CreateStream(&default_stream_), "Failed to create CUDA stream."); + default_stream_id_ = gpu_streams_.size() - 1; + CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnCreate(&cudnn_handle_), "Failed to create cuDNN handle"); + CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnSetStream(cudnn_handle_, reinterpret_cast(default_stream())), + "Failed to set stream for cuDNN handle."); + CHECK_CUBLAS_RET_WITH_EXCEPT_NOTRACE(cublasCreate(&cublas_handle_), "Failed to create cuBLAS handle."); + CHECK_CUBLAS_RET_WITH_EXCEPT_NOTRACE( + cublasSetStream(cublas_handle_, reinterpret_cast(default_stream())), + "Failed to set stream for cuBLAS handle."); + CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(cusolverDnCreate(&cusolver_dn_handle_), + "Failed to create cusolver dn handle."); + CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE( + cusolverDnSetStream(cusolver_dn_handle_, reinterpret_cast(default_stream())), + "Failed to set stream for cusolver dn handle"); + // Create cusparse handle. + CHECK_CUSPARSE_RET_WITH_EXCEPT(cusparseCreate(&cusparse_handle_), "Failed to create sparse handle."); + CHECK_CUSPARSE_RET_WITH_EXCEPT(cusparseSetStream(cusparse_handle_, reinterpret_cast(default_stream())), + "Failed to set stream for cusparse handle"); + + CHECK_OP_RET_WITH_EXCEPT(GPUMemoryAllocator::GetInstance().Init(), "Failed to Init gpu memory allocator") + dev_alive_ = true; +} + +void GPUDeviceManager::ReleaseDevice() { + // Avoid repeated release device resource. + if (!dev_alive_) { + return; + } + { + std::lock_guard lock_gpu_streams(stream_mutex_); + for (CudaDeviceStream stream : gpu_streams_) { + if (stream != nullptr) { + CHECK_OP_RET_WITH_ERROR(CudaDriver::DestroyStream(stream), "Failed to destroy CUDA stream."); + } + } + gpu_streams_.clear(); + } + + if (cudnn_handle_ != nullptr) { + CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(cudnnDestroy(cudnn_handle_), "Failed to destroy cuDNN handle"); + } + if (cublas_handle_ != nullptr) { + CHECK_CUBLAS_RET_WITH_ERROR(cublasDestroy(cublas_handle_), "Failed to destroy cuBLAS handle."); + } + if (cusolver_dn_handle_ != nullptr) { + CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnDestroy(cusolver_dn_handle_), "Failed to destroy cusolver dn handle."); + } + if (cusparse_handle_ != nullptr) { + CHECK_CUSPARSE_RET_WITH_ERROR(cusparseDestroy(cusparse_handle_), "Failed to destroy cusparse handle."); + } + + dev_alive_ = false; +} + +bool GPUDeviceManager::CreateStream(CudaDeviceStream *stream) { + std::lock_guard lock_gpu_streams(stream_mutex_); + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateStream(stream), "Failed to create CUDA stream"); + (void)gpu_streams_.emplace_back(*stream); + return true; +} + +bool GPUDeviceManager::CreateStream(size_t *stream_id) { + MS_EXCEPTION_IF_NULL(stream_id); + + std::lock_guard lock_gpu_streams(stream_mutex_); + CudaDeviceStream stream; + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateStream(&stream), "Failed to create CUDA stream"); + *stream_id = gpu_streams_.size(); + (void)gpu_streams_.emplace_back(stream); + return true; +} + +bool GPUDeviceManager::CreateStreamWithPriority(size_t *stream_id, int32_t priority) { + MS_EXCEPTION_IF_NULL(stream_id); + + std::lock_guard lock_gpu_streams(stream_mutex_); + CudaDeviceStream stream; + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateStreamWithPriority(&stream, priority), + "Failed to create CUDA stream with priority"); + *stream_id = gpu_streams_.size(); + (void)gpu_streams_.emplace_back(stream); + + return true; +} + +bool GPUDeviceManager::DestroyStream(size_t stream_id) { + std::lock_guard lock_gpu_streams(stream_mutex_); + if (stream_id >= gpu_streams_.size()) { + MS_LOG(ERROR) << "CUDA stream not found for stream id " << stream_id; + return false; + } + if (gpu_streams_.at(stream_id) == nullptr) { + MS_LOG(WARNING) << "CUDA stream hsa been destroyed for stream id " << stream_id; + return true; + } + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::DestroyStream(gpu_streams_.at(stream_id)), "Failed to destroy CUDA stream"); + gpu_streams_[stream_id] = nullptr; + return true; +} + +CudaDeviceStream GPUDeviceManager::GetStream(size_t stream_id) const { + if (stream_id >= gpu_streams_.size()) { + MS_LOG(DEBUG) << "Stream for stream id[" << stream_id << "] not found, return nullptr."; + return nullptr; + } + return gpu_streams_[stream_id]; +} + +size_t GPUDeviceManager::QueryStreamSize() const { + return std::count_if(gpu_streams_.begin(), gpu_streams_.end(), + [](CudaDeviceStream stream) { return stream != nullptr; }); +} + +std::vector GPUDeviceManager::GetStreamIds() const { + std::vector stream_ids; + for (size_t i = 0; i < gpu_streams_.size(); i++) { + if (gpu_streams_[i] != nullptr) { + (void)stream_ids.emplace_back(static_cast(i)); + } + } + return stream_ids; +} + +void GPUDeviceManager::set_current_stream(size_t stream_id) { current_stream_id_ = stream_id; } + +size_t GPUDeviceManager::current_stream() const { return current_stream_id_; } + +bool GPUDeviceManager::QueryStream(size_t stream_id) { + if (stream_id >= gpu_streams_.size()) { + MS_LOG(ERROR) << "CUDA stream not found for stream id " << stream_id; + return false; + } + if (gpu_streams_.at(stream_id) == nullptr) { + MS_LOG(WARNING) << "CUDA stream has been destroyed for stream id " << stream_id; + return true; + } + MS_LOG(DEBUG) << "Query completion status of stream id: " << stream_id; + return CudaDriver::QueryStream(gpu_streams_.at(stream_id)); +} + +const CudaDeviceStream &GPUDeviceManager::default_stream() const { return default_stream_; } + +size_t GPUDeviceManager::default_stream_id() const { return default_stream_id_; } + +int GPUDeviceManager::device_count() const { return CudaDriver::device_count(); } + +bool GPUDeviceManager::set_cur_device_id(uint32_t device_id) { + if (!dev_id_init_) { + dev_id_init_ = true; + cur_dev_id_ = device_id; + return true; + } else { + MS_LOG(ERROR) << "Device already been set."; + return false; + } +} + +uint32_t GPUDeviceManager::cur_device_id() const { return cur_dev_id_; } + +bool GPUDeviceManager::is_device_id_init() const { return dev_id_init_; } + +const cudnnHandle_t &GPUDeviceManager::GetCudnnHandle() const { return cudnn_handle_; } + +const cublasHandle_t &GPUDeviceManager::GetCublasHandle() const { return cublas_handle_; } + +const cusolverDnHandle_t &GPUDeviceManager::GetCusolverDnHandle() const { return cusolver_dn_handle_; } + +const cusparseHandle_t &GPUDeviceManager::GetCuSparseHandle() const { return cusparse_handle_; } + +bool GPUDeviceManager::SyncStream(size_t stream_id) const { + if (!dev_alive_) { + return false; + } + auto stream = GetStream(stream_id); + if (stream == nullptr) { + MS_LOG(EXCEPTION) << "Get CUDA stream for stream id failed."; + } + return SyncStream(stream); +} + +bool GPUDeviceManager::SyncStream(const CudaDeviceStream &stream) const { + return dev_alive_ && CudaDriver::SyncStream(stream); +} + +bool GPUDeviceManager::SyncAllStreams() const { + if (!dev_alive_) { + return false; + } + for (const auto &stream : gpu_streams_) { + if (stream != nullptr && !SyncStream(stream)) { + return false; + } + } + return true; +} + +bool GPUDeviceManager::SyncNotDefaultStreams() const { + bool res = true; + for (size_t i = 0; i < gpu_streams_.size(); i++) { + if (i != default_stream_id_ && !SyncStream(i)) { + MS_LOG(ERROR) << "Failed to sync for gpu stream id: " << i; + res = false; + } + } + return res; +} + +bool GPUDeviceManager::SyncExceptStreamsInList(const std::set &except_streams) const { + bool res = true; + for (size_t i = 0; i < gpu_streams_.size(); i++) { + if (except_streams.count(gpu_streams_[i]) > 0) { + MS_LOG(DEBUG) << "Stream id:" << i << " is been synchronized."; + continue; + } + if (!SyncStream(i)) { + MS_LOG(ERROR) << "Failed to sync for gpu stream id: " << i; + res = false; + } + } + return res; +} + +bool GPUDeviceManager::CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const { + return CudaDriver::CopyDeviceMemToHost(dst, src, size); +} + +bool GPUDeviceManager::CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) const { + return CudaDriver::CopyHostMemToDevice(dst, src, size); +} + +bool GPUDeviceManager::CopyHostMemToHost(const HostMemPtr &dst, const void *src, size_t size) const { + return CudaDriver::CopyHostMemToHost(dst, src, size); +} + +bool GPUDeviceManager::CopyDeviceMemToHostAsync(const HostMemPtr &dst, const void *src, size_t size, + CudaDeviceStream stream) const { + return CudaDriver::CopyDeviceMemToHostAsync(dst, src, size, stream); +} + +bool GPUDeviceManager::CopyHostMemToDeviceAsync(const DeviceMemPtr &dst, const void *src, size_t size, + CudaDeviceStream stream) const { + return CudaDriver::CopyHostMemToDeviceAsync(dst, src, size, stream); +} + +bool GPUDeviceManager::CopyDeviceMemToDeviceAsync(const DeviceMemPtr &dst, const void *src, size_t size, + CudaDeviceStream stream) const { + return CudaDriver::CopyDeviceMemToDeviceAsync(dst, src, size, stream); +} +} // namespace gpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_device_manager.h b/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_device_manager.h index 4ebb2ff3296..93627246d4c 100644 --- a/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_device_manager.h +++ b/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_device_manager.h @@ -1,120 +1,120 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_DEVICE_MANAGER_H_ -#define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_DEVICE_MANAGER_H_ - -#include -#include -#include -#include -#include -#include -#include -#include "plugin/device/gpu/hal/device/cuda_driver.h" -#include "plugin/device/gpu/hal/device/gpu_memory_allocator.h" - -namespace mindspore { -namespace device { -namespace gpu { -class GPUDeviceManager { - public: - void InitDevice(); - void ReleaseDevice(); - - int device_count() const; - bool set_cur_device_id(uint32_t device_id); - uint32_t cur_device_id() const; - bool is_device_id_init() const; - - bool CreateStream(CudaDeviceStream *stream); - bool CreateStream(size_t *stream_id); - bool CreateStreamWithPriority(size_t *stream_id, int32_t priority); - bool DestroyStream(size_t stream_id); - CudaDeviceStream GetStream(size_t stream_id) const; - size_t QueryStreamSize() const; - std::vector GetStreamIds() const; - void set_current_stream(size_t stream_id); - size_t current_stream() const; - bool QueryStream(size_t stream_id); - bool SyncStream(size_t stream_id) const; - bool SyncStream(const CudaDeviceStream &stream) const; - bool SyncAllStreams() const; - bool SyncNotDefaultStreams() const; - // Sync all streams except the streams in except_streams. - bool SyncExceptStreamsInList(const std::set &except_streams) const; - const CudaDeviceStream &default_stream() const; - size_t default_stream_id() const; - - const cudnnHandle_t &GetCudnnHandle() const; - const cublasHandle_t &GetCublasHandle() const; - const cusolverDnHandle_t &GetCusolverDnHandle() const; - const cusparseHandle_t &GetCuSparseHandle() const; - - bool CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const; - bool CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) const; - - bool CopyDeviceMemToHostAsync(const HostMemPtr &dst, const void *src, size_t size, CudaDeviceStream stream) const; - bool CopyHostMemToDeviceAsync(const DeviceMemPtr &dst, const void *src, size_t size, CudaDeviceStream stream) const; - bool CopyDeviceMemToDeviceAsync(const DeviceMemPtr &dst, const void *src, size_t size, CudaDeviceStream stream) const; - bool CopyHostMemToHost(const HostMemPtr &dst, const void *src, size_t size) const; - - static GPUDeviceManager &GetInstance(); - bool single_op_multi_stream_enable() const { return single_op_multi_stream_enable_; } - void set_single_op_multi_stream_enable(bool single_op_multi_stream_enable) { - single_op_multi_stream_enable_ = single_op_multi_stream_enable; - } - - private: - GPUDeviceManager() : dev_id_init_(false), cur_dev_id_(0), dev_alive_(false) {} - ~GPUDeviceManager() = default; - GPUDeviceManager(const GPUDeviceManager &) = delete; - GPUDeviceManager &operator=(const GPUDeviceManager &) = delete; - - // Ensure the thread safety for creating and destroying stream. - std::mutex stream_mutex_; - - // default CUDA stream used for all the kernels. - CudaDeviceStream default_stream_{nullptr}; - - size_t default_stream_id_{0}; - - size_t current_stream_id_{0}; - - // all gpu CUDA streams including default_stream_. - std::vector gpu_streams_; - - // handle used for cuDNN kernels. - cudnnHandle_t cudnn_handle_{nullptr}; - - // handle used for cuBLAS kernels. - cublasHandle_t cublas_handle_{nullptr}; - - // handle used for cusolver dn kernels; - cusolverDnHandle_t cusolver_dn_handle_{nullptr}; - - // handle used for cusparse kernels; - cusparseHandle_t cusparse_handle_{nullptr}; - bool dev_id_init_; - uint32_t cur_dev_id_; - bool dev_alive_; - bool single_op_multi_stream_enable_{false}; -}; -} // namespace gpu -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_DEVICE_MANAGER_H_ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_DEVICE_MANAGER_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_DEVICE_MANAGER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "plugin/device/gpu/hal/device/cuda_driver.h" +#include "plugin/device/gpu/hal/device/gpu_memory_allocator.h" + +namespace mindspore { +namespace device { +namespace gpu { +class GPUDeviceManager { + public: + void InitDevice(); + void ReleaseDevice(); + + int device_count() const; + bool set_cur_device_id(uint32_t device_id); + uint32_t cur_device_id() const; + bool is_device_id_init() const; + + bool CreateStream(CudaDeviceStream *stream); + bool CreateStream(size_t *stream_id); + bool CreateStreamWithPriority(size_t *stream_id, int32_t priority); + bool DestroyStream(size_t stream_id); + CudaDeviceStream GetStream(size_t stream_id) const; + size_t QueryStreamSize() const; + std::vector GetStreamIds() const; + void set_current_stream(size_t stream_id); + size_t current_stream() const; + bool QueryStream(size_t stream_id); + bool SyncStream(size_t stream_id) const; + bool SyncStream(const CudaDeviceStream &stream) const; + bool SyncAllStreams() const; + bool SyncNotDefaultStreams() const; + // Sync all streams except the streams in except_streams. + bool SyncExceptStreamsInList(const std::set &except_streams) const; + const CudaDeviceStream &default_stream() const; + size_t default_stream_id() const; + + const cudnnHandle_t &GetCudnnHandle() const; + const cublasHandle_t &GetCublasHandle() const; + const cusolverDnHandle_t &GetCusolverDnHandle() const; + const cusparseHandle_t &GetCuSparseHandle() const; + + bool CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const; + bool CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) const; + + bool CopyDeviceMemToHostAsync(const HostMemPtr &dst, const void *src, size_t size, CudaDeviceStream stream) const; + bool CopyHostMemToDeviceAsync(const DeviceMemPtr &dst, const void *src, size_t size, CudaDeviceStream stream) const; + bool CopyDeviceMemToDeviceAsync(const DeviceMemPtr &dst, const void *src, size_t size, CudaDeviceStream stream) const; + bool CopyHostMemToHost(const HostMemPtr &dst, const void *src, size_t size) const; + + static GPUDeviceManager &GetInstance(); + bool single_op_multi_stream_enable() const { return single_op_multi_stream_enable_; } + void set_single_op_multi_stream_enable(bool single_op_multi_stream_enable) { + single_op_multi_stream_enable_ = single_op_multi_stream_enable; + } + + private: + GPUDeviceManager() : dev_id_init_(false), cur_dev_id_(0), dev_alive_(false) {} + ~GPUDeviceManager() = default; + GPUDeviceManager(const GPUDeviceManager &) = delete; + GPUDeviceManager &operator=(const GPUDeviceManager &) = delete; + + // Ensure the thread safety for creating and destroying stream. + std::mutex stream_mutex_; + + // default CUDA stream used for all the kernels. + CudaDeviceStream default_stream_{nullptr}; + + size_t default_stream_id_{0}; + + size_t current_stream_id_{0}; + + // all gpu CUDA streams including default_stream_. + std::vector gpu_streams_; + + // handle used for cuDNN kernels. + cudnnHandle_t cudnn_handle_{nullptr}; + + // handle used for cuBLAS kernels. + cublasHandle_t cublas_handle_{nullptr}; + + // handle used for cusolver dn kernels; + cusolverDnHandle_t cusolver_dn_handle_{nullptr}; + + // handle used for cusparse kernels; + cusparseHandle_t cusparse_handle_{nullptr}; + bool dev_id_init_; + uint32_t cur_dev_id_; + bool dev_alive_; + bool single_op_multi_stream_enable_{false}; +}; +} // namespace gpu +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_DEVICE_MANAGER_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_memory_allocator.cc b/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_memory_allocator.cc index a59ca15ef88..5ea74b24075 100644 --- a/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_memory_allocator.cc +++ b/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_memory_allocator.cc @@ -1,111 +1,111 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "plugin/device/gpu/hal/device/gpu_memory_allocator.h" -#include "plugin/device/gpu/hal/device/cuda_driver.h" -#include "utils/log_adapter.h" -#include "utils/ms_context.h" -#include "utils/convert_utils_base.h" - -namespace mindspore { -namespace device { -namespace gpu { -const size_t kGBToByte = 1024 << 20; -constexpr float kReservedMemoryRatio = 0.0625; // 1/16 -static const size_t MEM_ALIGN_SIZE = 512; - -bool GPUMemoryAllocator::Init() { - size_t total_size = CudaDriver::total_mem_size(); - size_t free_size = CudaDriver::free_mem_size(); - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - limited_device_memory_ = context_ptr->get_param(MS_CTX_MAX_DEVICE_MEMORY); - available_device_memory_ = FloatToSize(limited_device_memory_ * kGBToByte); - if (total_size > 0 && free_size > 0 && available_device_memory_ > 0) { - MS_LOG(INFO) << "GPU device total memory size " << total_size << ", current free memory size " << free_size - << ", set max available memory size " << available_device_memory_ << "."; - } else { - MS_LOG(EXCEPTION) << "#umsg#GPU memory error:#umsg#The total size or free size or max_device_memory size of GPU " - "memory can't be zero, total memory size " - << total_size << ", current free memory size " << free_size << ", set max available memory size " - << available_device_memory_ << "."; - } - // In gpu mode, recommend 1/16 reserved for other cuda functions - if (available_device_memory_ > total_size) { - size_t recommend_mem_size_for_others = FloatToSize(total_size * kReservedMemoryRatio); - SetMemPoolBlockSize(std::min(available_device_memory_, total_size - recommend_mem_size_for_others)); - } else { - SetMemPoolBlockSize(std::min(available_device_memory_, total_size)); - } - return true; -} - -void GPUMemoryAllocator::CheckMaxDeviceMemory() const { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - auto max_device_memory = context_ptr->get_param(MS_CTX_MAX_DEVICE_MEMORY); - // Currently not support modifying the max device memory. - if (!common::IsFloatEqual(limited_device_memory_, max_device_memory)) { - MS_LOG(EXCEPTION) << "#umsg#Can't change or set context param max_device_memory during running:#umsg#Currently " - "effective max_device_memory(" - << limited_device_memory_ << "GB), set new max_device_memory(" << max_device_memory - << "GB) failed."; - } -} - -bool GPUMemoryAllocator::AllocBufferQueueMem(size_t size, DeviceMemPtr *addr) { - auto alloc_size = AllocDeviceMem(size, addr); - buffer_q_addr_ = *addr; - // Buffer queue needs to ensure that the alloc_size and size is equal. - return alloc_size == size; -} - -size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr *addr) { - if (size == 0) { - MS_LOG(EXCEPTION) << "#umsg#GPU memory error:#umsg#The memory alloc size is 0."; - } - auto free_size = free_mem_size(); - if (size > free_size) { - MS_LOG(EXCEPTION) << "#umsg#Memory not enough:#umsg#Current free memory size[" << free_size - << "] is smaller than required size[" << size << "]."; - } - - auto alloc_size = CudaDriver::AllocDeviceMem(size, addr); - if (alloc_size == 0) { - MS_LOG(EXCEPTION) << "#umsg#Memory not enough:#umsg#Alloc device memory[" << size << "] failed."; - } - total_used_device_memory_ += alloc_size; - available_device_memory_ -= alloc_size; - MS_LOG(INFO) << "Cuda current free memory size[" << free_size << "], alloc size[" << alloc_size - << "], left free memory size[" << free_size - alloc_size << "]" - << ".Total used size[" << total_used_device_memory_ << "]."; - return alloc_size; -} - -bool GPUMemoryAllocator::FreeDeviceMem(const DeviceMemPtr &addr) { return CudaDriver::FreeDeviceMem(addr); } - -size_t GPUMemoryAllocator::free_mem_size() { return std::min(CudaDriver::free_mem_size(), available_device_memory_); } - -size_t GPUMemoryAllocator::AlignMemorySize(size_t size) const { - if (size == 0) { - return MEM_ALIGN_SIZE; - } - return ((size + MEM_ALIGN_SIZE - 1) / MEM_ALIGN_SIZE) * MEM_ALIGN_SIZE; -} -} // namespace gpu -} // namespace device -} // namespace mindspore +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "plugin/device/gpu/hal/device/gpu_memory_allocator.h" +#include "plugin/device/gpu/hal/device/cuda_driver.h" +#include "utils/log_adapter.h" +#include "utils/ms_context.h" +#include "utils/convert_utils_base.h" + +namespace mindspore { +namespace device { +namespace gpu { +const size_t kGBToByte = 1024 << 20; +constexpr float kReservedMemoryRatio = 0.0625; // 1/16 +static const size_t MEM_ALIGN_SIZE = 512; + +bool GPUMemoryAllocator::Init() { + size_t total_size = CudaDriver::total_mem_size(); + size_t free_size = CudaDriver::free_mem_size(); + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + limited_device_memory_ = context_ptr->get_param(MS_CTX_MAX_DEVICE_MEMORY); + available_device_memory_ = FloatToSize(limited_device_memory_ * kGBToByte); + if (total_size > 0 && free_size > 0 && available_device_memory_ > 0) { + MS_LOG(INFO) << "GPU device total memory size " << total_size << ", current free memory size " << free_size + << ", set max available memory size " << available_device_memory_ << "."; + } else { + MS_LOG(EXCEPTION) << "#umsg#GPU memory error:#umsg#The total size or free size or max_device_memory size of GPU " + "memory can't be zero, total memory size " + << total_size << ", current free memory size " << free_size << ", set max available memory size " + << available_device_memory_ << "."; + } + // In gpu mode, recommend 1/16 reserved for other cuda functions + if (available_device_memory_ > total_size) { + size_t recommend_mem_size_for_others = FloatToSize(total_size * kReservedMemoryRatio); + SetMemPoolBlockSize(std::min(available_device_memory_, total_size - recommend_mem_size_for_others)); + } else { + SetMemPoolBlockSize(std::min(available_device_memory_, total_size)); + } + return true; +} + +void GPUMemoryAllocator::CheckMaxDeviceMemory() const { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + auto max_device_memory = context_ptr->get_param(MS_CTX_MAX_DEVICE_MEMORY); + // Currently not support modifying the max device memory. + if (!common::IsFloatEqual(limited_device_memory_, max_device_memory)) { + MS_LOG(EXCEPTION) << "#umsg#Can't change or set context param max_device_memory during running:#umsg#Currently " + "effective max_device_memory(" + << limited_device_memory_ << "GB), set new max_device_memory(" << max_device_memory + << "GB) failed."; + } +} + +bool GPUMemoryAllocator::AllocBufferQueueMem(size_t size, DeviceMemPtr *addr) { + auto alloc_size = AllocDeviceMem(size, addr); + buffer_q_addr_ = *addr; + // Buffer queue needs to ensure that the alloc_size and size is equal. + return alloc_size == size; +} + +size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr *addr) { + if (size == 0) { + MS_LOG(EXCEPTION) << "#umsg#GPU memory error:#umsg#The memory alloc size is 0."; + } + auto free_size = free_mem_size(); + if (size > free_size) { + MS_LOG(EXCEPTION) << "#umsg#Memory not enough:#umsg#Current free memory size[" << free_size + << "] is smaller than required size[" << size << "]."; + } + + auto alloc_size = CudaDriver::AllocDeviceMem(size, addr); + if (alloc_size == 0) { + MS_LOG(EXCEPTION) << "#umsg#Memory not enough:#umsg#Alloc device memory[" << size << "] failed."; + } + total_used_device_memory_ += alloc_size; + available_device_memory_ -= alloc_size; + MS_LOG(INFO) << "Cuda current free memory size[" << free_size << "], alloc size[" << alloc_size + << "], left free memory size[" << free_size - alloc_size << "]" + << ".Total used size[" << total_used_device_memory_ << "]."; + return alloc_size; +} + +bool GPUMemoryAllocator::FreeDeviceMem(const DeviceMemPtr &addr) { return CudaDriver::FreeDeviceMem(addr); } + +size_t GPUMemoryAllocator::free_mem_size() { return std::min(CudaDriver::free_mem_size(), available_device_memory_); } + +size_t GPUMemoryAllocator::AlignMemorySize(size_t size) const { + if (size == 0) { + return MEM_ALIGN_SIZE; + } + return ((size + MEM_ALIGN_SIZE - 1) / MEM_ALIGN_SIZE) * MEM_ALIGN_SIZE; +} +} // namespace gpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_memory_allocator.h b/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_memory_allocator.h index c15755fd20e..510498c69aa 100644 --- a/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_memory_allocator.h +++ b/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_memory_allocator.h @@ -1,61 +1,61 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_MEMORY_ALLOCATOR_H_ -#define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_MEMORY_ALLOCATOR_H_ - -#include -#include -#include "plugin/device/gpu/hal/device/cuda_driver.h" -#include "include/backend/mem_reuse/mem_dynamic_allocator.h" - -namespace mindspore { -namespace device { -namespace gpu { -class GPUMemoryAllocator : public DynamicMemPoolBestFit { - public: - ~GPUMemoryAllocator() override = default; - bool Init(); - void CheckMaxDeviceMemory() const; - bool AllocBufferQueueMem(size_t size, DeviceMemPtr *addr); - - size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) override; - bool FreeDeviceMem(const DeviceMemPtr &addr) override; - size_t free_mem_size() override; - size_t AlignMemorySize(size_t size) const override; - std::string GetMemoryPoolType() const override { return "GPU"; } - static GPUMemoryAllocator &GetInstance() { - static GPUMemoryAllocator instance; - return instance; - } - - private: - GPUMemoryAllocator() = default; - GPUMemoryAllocator(const GPUMemoryAllocator &) = delete; - GPUMemoryAllocator &operator=(const GPUMemoryAllocator &) = delete; - - // Used to track address of data buffer queue. - DeviceMemPtr buffer_q_addr_{nullptr}; - - float limited_device_memory_{0.0}; - size_t total_used_device_memory_{0}; - size_t available_device_memory_{0}; -}; -} // namespace gpu -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_MEMORY_ALLOCATOR_H_ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_MEMORY_ALLOCATOR_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_MEMORY_ALLOCATOR_H_ + +#include +#include +#include "plugin/device/gpu/hal/device/cuda_driver.h" +#include "include/backend/mem_reuse/mem_dynamic_allocator.h" + +namespace mindspore { +namespace device { +namespace gpu { +class GPUMemoryAllocator : public DynamicMemPoolBestFit { + public: + ~GPUMemoryAllocator() override = default; + bool Init(); + void CheckMaxDeviceMemory() const; + bool AllocBufferQueueMem(size_t size, DeviceMemPtr *addr); + + size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) override; + bool FreeDeviceMem(const DeviceMemPtr &addr) override; + size_t free_mem_size() override; + size_t AlignMemorySize(size_t size) const override; + std::string GetMemoryPoolType() const override { return "GPU"; } + static GPUMemoryAllocator &GetInstance() { + static GPUMemoryAllocator instance; + return instance; + } + + private: + GPUMemoryAllocator() = default; + GPUMemoryAllocator(const GPUMemoryAllocator &) = delete; + GPUMemoryAllocator &operator=(const GPUMemoryAllocator &) = delete; + + // Used to track address of data buffer queue. + DeviceMemPtr buffer_q_addr_{nullptr}; + + float limited_device_memory_{0.0}; + size_t total_used_device_memory_{0}; + size_t available_device_memory_{0}; +}; +} // namespace gpu +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_MEMORY_ALLOCATOR_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/argmax_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/argmax_gpu_kernel.cc index 68679ea32e6..1640e111ac5 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/argmax_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/argmax_gpu_kernel.cc @@ -1,229 +1,229 @@ -/** - * Copyright 2020-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/arrays/argmax_gpu_kernel.h" -#include -#include "ops/op_utils.h" - -namespace mindspore { -namespace kernel { -template -bool ArgmaxGpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - S bound = static_cast(bound_); - T *input_ptr = GetDeviceAddress(inputs, 0); - S *output_ptr = GetDeviceAddress(outputs, 0); - // call cuda kernel - auto status = CalArgmax(input_ptr, bound, outer_size_, inner_size_, output_ptr, device_id_, - reinterpret_cast(stream_ptr)); - CHECK_CUDA_STATUS(status, kernel_name_); - return true; -} - -std::vector> ArgmaxGpuKernelMod::func_list_ = { - {KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt32), - &ArgmaxGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt32), - &ArgmaxGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt32), - &ArgmaxGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt8) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt32), - &ArgmaxGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt16) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt32), - &ArgmaxGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt32), - &ArgmaxGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt32), - &ArgmaxGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeUInt8) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt32), - &ArgmaxGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeUInt16) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt32), - &ArgmaxGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeUInt32) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt32), - &ArgmaxGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeUInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt32), - &ArgmaxGpuKernelMod::LaunchKernel}, - - {KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt64), - &ArgmaxGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt64), - &ArgmaxGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt64), - &ArgmaxGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt8) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt64), - &ArgmaxGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt16) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt64), - &ArgmaxGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt64), - &ArgmaxGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt64), - &ArgmaxGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeUInt8) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt64), - &ArgmaxGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeUInt16) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt64), - &ArgmaxGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeUInt32) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt64), - &ArgmaxGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeUInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt64), - &ArgmaxGpuKernelMod::LaunchKernel}}; - -bool ArgmaxGpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { - auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(tensor_attr, GetOpSupport()); - kernel_func_ = func_list_[index].second; - if (!is_match) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << tensor_attr; - return false; - } - return true; -} - -int ArgmaxGpuKernelMod::Resize(const std::vector &inputs, const std::vector &outputs) { - int ret = KernelMod::Resize(inputs, outputs); - if (ret != KRET_OK) { - return ret; - } - constexpr auto kSizeIndex = 1; - axis_ = inputs[kSizeIndex]->GetValueWithCheck(); - auto input_shape = inputs[kIndex0]->GetShapeVector(); - auto output_shape = outputs[kIndex0]->GetShapeVector(); - is_null_input_ = - CHECK_SHAPE_NULL(input_shape, kernel_name_, "input") || CHECK_SHAPE_NULL(output_shape, kernel_name_, "output"); - int64_t dims = static_cast(input_shape.size()); - if (axis_ < 0) { - axis_ += dims; - } - bound_ = input_shape[axis_]; - if (input_shape[axis_] != bound_) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', the value of input_shape[axis] should be " - << static_cast(bound_) << ", but got " << input_shape[axis_]; - return -1; - } - if (is_null_input_) { - return true; - } - outer_size_ = 1; - for (int64_t i = axis_ - 1; i >= 0; i--) { - outer_size_ *= input_shape[i]; - } - inner_size_ = 1; - for (int64_t i = axis_ + 1; i < static_cast(input_shape.size()); i++) { - inner_size_ *= input_shape[i]; - } - return KRET_OK; -} - -std::vector ArgmaxGpuKernelMod::GetOpSupport() { - static std::vector support_list; - if (support_list.empty()) { - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - } - return support_list; -} - -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Argmax, ArgmaxGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/arrays/argmax_gpu_kernel.h" +#include +#include "ops/op_utils.h" + +namespace mindspore { +namespace kernel { +template +bool ArgmaxGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + S bound = static_cast(bound_); + T *input_ptr = GetDeviceAddress(inputs, 0); + S *output_ptr = GetDeviceAddress(outputs, 0); + // call cuda kernel + auto status = CalArgmax(input_ptr, bound, outer_size_, inner_size_, output_ptr, device_id_, + reinterpret_cast(stream_ptr)); + CHECK_CUDA_STATUS(status, kernel_name_); + return true; +} + +std::vector> ArgmaxGpuKernelMod::func_list_ = { + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt32), + &ArgmaxGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt32), + &ArgmaxGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt32), + &ArgmaxGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt32), + &ArgmaxGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt32), + &ArgmaxGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt32), + &ArgmaxGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt32), + &ArgmaxGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt32), + &ArgmaxGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt32), + &ArgmaxGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt32), + &ArgmaxGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt32), + &ArgmaxGpuKernelMod::LaunchKernel}, + + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64), + &ArgmaxGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64), + &ArgmaxGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64), + &ArgmaxGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64), + &ArgmaxGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64), + &ArgmaxGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64), + &ArgmaxGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64), + &ArgmaxGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64), + &ArgmaxGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64), + &ArgmaxGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64), + &ArgmaxGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64), + &ArgmaxGpuKernelMod::LaunchKernel}}; + +bool ArgmaxGpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { + auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(tensor_attr, GetOpSupport()); + kernel_func_ = func_list_[index].second; + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << tensor_attr; + return false; + } + return true; +} + +int ArgmaxGpuKernelMod::Resize(const std::vector &inputs, const std::vector &outputs) { + int ret = KernelMod::Resize(inputs, outputs); + if (ret != KRET_OK) { + return ret; + } + constexpr auto kSizeIndex = 1; + axis_ = inputs[kSizeIndex]->GetValueWithCheck(); + auto input_shape = inputs[kIndex0]->GetShapeVector(); + auto output_shape = outputs[kIndex0]->GetShapeVector(); + is_null_input_ = + CHECK_SHAPE_NULL(input_shape, kernel_name_, "input") || CHECK_SHAPE_NULL(output_shape, kernel_name_, "output"); + int64_t dims = static_cast(input_shape.size()); + if (axis_ < 0) { + axis_ += dims; + } + bound_ = input_shape[axis_]; + if (input_shape[axis_] != bound_) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', the value of input_shape[axis] should be " + << static_cast(bound_) << ", but got " << input_shape[axis_]; + return -1; + } + if (is_null_input_) { + return true; + } + outer_size_ = 1; + for (int64_t i = axis_ - 1; i >= 0; i--) { + outer_size_ *= input_shape[i]; + } + inner_size_ = 1; + for (int64_t i = axis_ + 1; i < static_cast(input_shape.size()); i++) { + inner_size_ *= input_shape[i]; + } + return KRET_OK; +} + +std::vector ArgmaxGpuKernelMod::GetOpSupport() { + static std::vector support_list; + if (support_list.empty()) { + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + } + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Argmax, ArgmaxGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/argmax_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/argmax_gpu_kernel.h index 6e8a50a50bb..89915d350a1 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/argmax_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/argmax_gpu_kernel.h @@ -1,64 +1,64 @@ -/** - * Copyright 2020-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_ARGMAX_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_ARGMAX_GPU_KERNEL_H_ - -#include -#include -#include -#include -#include -#include -#include -#include "plugin/device/gpu/kernel/gpu_kernel.h" -#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/argmax_impl.cuh" -namespace mindspore { -namespace kernel { -class ArgmaxGpuKernelMod : public NativeGpuKernelMod { - public: - ArgmaxGpuKernelMod() = default; - ~ArgmaxGpuKernelMod() override = default; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - return kernel_func_(this, inputs, workspace, outputs, stream_ptr); - } - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - int Resize(const std::vector &inputs, const std::vector &outputs) override; - std::vector GetOpSupport() override; - - private: - template - bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr); - using ArgmaxFunc = - std::function &, const std::vector &, - const std::vector &, void *)>; - static std::vector> func_list_; - ArgmaxFunc kernel_func_; - bool is_null_input_; - int64_t axis_; - int64_t bound_; - size_t inner_size_; - size_t outer_size_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_ARGMAX_GPU_KERNEL_H_ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_ARGMAX_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_ARGMAX_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/argmax_impl.cuh" +namespace mindspore { +namespace kernel { +class ArgmaxGpuKernelMod : public NativeGpuKernelMod { + public: + ArgmaxGpuKernelMod() = default; + ~ArgmaxGpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + return kernel_func_(this, inputs, workspace, outputs, stream_ptr); + } + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + int Resize(const std::vector &inputs, const std::vector &outputs) override; + std::vector GetOpSupport() override; + + private: + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr); + using ArgmaxFunc = + std::function &, const std::vector &, + const std::vector &, void *)>; + static std::vector> func_list_; + ArgmaxFunc kernel_func_; + bool is_null_input_; + int64_t axis_; + int64_t bound_; + size_t inner_size_; + size_t outer_size_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_ARGMAX_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/check_numerics_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/check_numerics_gpu_kernel.cc index 91e57418d52..de825b3f139 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/check_numerics_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/check_numerics_gpu_kernel.cc @@ -1,111 +1,111 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/arrays/check_numerics_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -constexpr size_t kNumber2 = 2; -bool CheckNumericsGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - if (inputs.empty() || outputs.empty()) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid."; - return false; - } - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', the kernel type should be in [float16, float32, float64], " - << "but got: " << kernel_attr << "."; - return false; - } - kernel_func_ = func_list_[index].second; - unit_output_size_ = abstract::TypeIdSize(kernel_attr.GetOutputAttr(kIndex0).dtype); - return true; -} - -int CheckNumericsGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - for (const auto &input : inputs) { - // If any input shape contains -1, means input shape is dynamic, so just return do nothing. - auto input_shape = input->GetShapeVector(); - if (!IsValidShape(input_shape)) { - return KRET_UNKNOWN_SHAPE; - } - } - ResetResource(); - std::vector output_shape = std::vector(outputs.at(kIndex0)->GetDeviceShapeVector().begin(), - outputs.at(kIndex0)->GetDeviceShapeVector().end()); - output_elements_ = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); - if (output_elements_ == 0) { - is_null_input_ = true; - } - size_t output_size = output_elements_ * unit_output_size_; - output_size_list_.push_back(output_size); - workspace_size_list_.push_back(kNumber2 * sizeof(int32_t)); - return KRET_OK; -} - -template -bool CheckNumericsGpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs) { - T *input = GetDeviceAddress(inputs, 0); - T *output = GetDeviceAddress(outputs, 0); - int32_t *flag_device = GetDeviceAddress(workspace, 0); - cudaStream_t stream = reinterpret_cast(cuda_stream_); - int32_t flag_host[2]; - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemsetAsync(flag_device, 0, kNumber2 * sizeof(int32_t), stream), - "flag_check cudaMemsetAsync failed."); - auto status = CalCheckNumerics(output_elements_, input, flag_device, device_id_, stream); - CHECK_CUDA_STATUS(status, kernel_name_); - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( - cudaMemcpyAsync(flag_host, flag_device, kNumber2 * sizeof(int32_t), cudaMemcpyDeviceToHost, stream), - "For 'checkNumerics', flag_host cudaMemcpyAsync failed."); - if (cudaStreamQuery(stream) != cudaSuccess) { - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(stream), "cuda Stream Sync Failed."); - } - if (flag_host[0] == 1 && flag_host[1] == 1) { - MS_EXCEPTION(ValueError) << ": Tensor had Inf and NaN values [Op" << kernel_name_ << "]."; - } else if (flag_host[0] == 1 && flag_host[1] == 0) { - MS_EXCEPTION(ValueError) << ": Tensor had NaN values [Op" << kernel_name_ << "]."; - } else if (flag_host[0] == 0 && flag_host[1] == 1) { - MS_EXCEPTION(ValueError) << ": Tensor had Inf values [Op" << kernel_name_ << "]."; - } else { - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( - cudaMemcpyAsync(output, input, output_elements_ * sizeof(T), cudaMemcpyDeviceToDevice, stream), - "cudaMemcpyAsync value variable failed."); - } - return true; -} - -std::vector> CheckNumericsGpuKernelMod::func_list_ = { - {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - &CheckNumericsGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - &CheckNumericsGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), - &CheckNumericsGpuKernelMod::LaunchKernel}}; - -std::vector CheckNumericsGpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - return support_list; -} -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, CheckNumerics, CheckNumericsGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/arrays/check_numerics_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +constexpr size_t kNumber2 = 2; +bool CheckNumericsGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + if (inputs.empty() || outputs.empty()) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid."; + return false; + } + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', the kernel type should be in [float16, float32, float64], " + << "but got: " << kernel_attr << "."; + return false; + } + kernel_func_ = func_list_[index].second; + unit_output_size_ = abstract::TypeIdSize(kernel_attr.GetOutputAttr(kIndex0).dtype); + return true; +} + +int CheckNumericsGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + for (const auto &input : inputs) { + // If any input shape contains -1, means input shape is dynamic, so just return do nothing. + auto input_shape = input->GetShapeVector(); + if (!IsValidShape(input_shape)) { + return KRET_UNKNOWN_SHAPE; + } + } + ResetResource(); + std::vector output_shape = std::vector(outputs.at(kIndex0)->GetDeviceShapeVector().begin(), + outputs.at(kIndex0)->GetDeviceShapeVector().end()); + output_elements_ = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + if (output_elements_ == 0) { + is_null_input_ = true; + } + size_t output_size = output_elements_ * unit_output_size_; + output_size_list_.push_back(output_size); + workspace_size_list_.push_back(kNumber2 * sizeof(int32_t)); + return KRET_OK; +} + +template +bool CheckNumericsGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + T *input = GetDeviceAddress(inputs, 0); + T *output = GetDeviceAddress(outputs, 0); + int32_t *flag_device = GetDeviceAddress(workspace, 0); + cudaStream_t stream = reinterpret_cast(cuda_stream_); + int32_t flag_host[2]; + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemsetAsync(flag_device, 0, kNumber2 * sizeof(int32_t), stream), + "flag_check cudaMemsetAsync failed."); + auto status = CalCheckNumerics(output_elements_, input, flag_device, device_id_, stream); + CHECK_CUDA_STATUS(status, kernel_name_); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(flag_host, flag_device, kNumber2 * sizeof(int32_t), cudaMemcpyDeviceToHost, stream), + "For 'checkNumerics', flag_host cudaMemcpyAsync failed."); + if (cudaStreamQuery(stream) != cudaSuccess) { + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(stream), "cuda Stream Sync Failed."); + } + if (flag_host[0] == 1 && flag_host[1] == 1) { + MS_EXCEPTION(ValueError) << ": Tensor had Inf and NaN values [Op" << kernel_name_ << "]."; + } else if (flag_host[0] == 1 && flag_host[1] == 0) { + MS_EXCEPTION(ValueError) << ": Tensor had NaN values [Op" << kernel_name_ << "]."; + } else if (flag_host[0] == 0 && flag_host[1] == 1) { + MS_EXCEPTION(ValueError) << ": Tensor had Inf values [Op" << kernel_name_ << "]."; + } else { + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(output, input, output_elements_ * sizeof(T), cudaMemcpyDeviceToDevice, stream), + "cudaMemcpyAsync value variable failed."); + } + return true; +} + +std::vector> CheckNumericsGpuKernelMod::func_list_ = { + {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + &CheckNumericsGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + &CheckNumericsGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + &CheckNumericsGpuKernelMod::LaunchKernel}}; + +std::vector CheckNumericsGpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, CheckNumerics, CheckNumericsGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/check_numerics_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/check_numerics_gpu_kernel.h index 081cf01a410..8e442114a9e 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/check_numerics_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/check_numerics_gpu_kernel.h @@ -1,82 +1,82 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_CHECK_NUMERICS_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_CHECK_NUMERICS_GPU_KERNEL_H_ -#include -#include -#include -#include -#include -#include -#include -#include "mindspore/core/ops/check_numerics.h" -#include "abstract/utils.h" -#include "plugin/factory/ms_factory.h" -#include "plugin/device/gpu/kernel/gpu_kernel.h" -#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/check_numerics_impl.cuh" - -namespace mindspore { -namespace kernel { -class CheckNumericsGpuKernelMod : public NativeGpuKernelMod { - public: - CheckNumericsGpuKernelMod() { ResetResource(); } - ~CheckNumericsGpuKernelMod() override = default; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *cuda_stream) override { - if (is_null_input_) { - return true; - } - cuda_stream_ = cuda_stream; - return kernel_func_(this, inputs, workspace, outputs); - } - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - protected: - void ResetResource() noexcept { - output_elements_ = 0; - is_null_input_ = false; - output_size_list_.clear(); - workspace_size_list_.clear(); - } - - std::vector GetOpSupport() override; - - private: - template - bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs); - using CNFunc = - std::function &, - const std::vector &, const std::vector &)>; - - private: - size_t unit_output_size_{1}; - size_t output_elements_; - CNFunc kernel_func_{}; - bool is_null_input_{false}; - void *cuda_stream_{nullptr}; - static std::vector> func_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_CHECK_NUMERICS_GPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_CHECK_NUMERICS_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_CHECK_NUMERICS_GPU_KERNEL_H_ +#include +#include +#include +#include +#include +#include +#include +#include "mindspore/core/ops/check_numerics.h" +#include "abstract/utils.h" +#include "plugin/factory/ms_factory.h" +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/check_numerics_impl.cuh" + +namespace mindspore { +namespace kernel { +class CheckNumericsGpuKernelMod : public NativeGpuKernelMod { + public: + CheckNumericsGpuKernelMod() { ResetResource(); } + ~CheckNumericsGpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *cuda_stream) override { + if (is_null_input_) { + return true; + } + cuda_stream_ = cuda_stream; + return kernel_func_(this, inputs, workspace, outputs); + } + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + protected: + void ResetResource() noexcept { + output_elements_ = 0; + is_null_input_ = false; + output_size_list_.clear(); + workspace_size_list_.clear(); + } + + std::vector GetOpSupport() override; + + private: + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs); + using CNFunc = + std::function &, + const std::vector &, const std::vector &)>; + + private: + size_t unit_output_size_{1}; + size_t output_elements_; + CNFunc kernel_func_{}; + bool is_null_input_{false}; + void *cuda_stream_{nullptr}; + static std::vector> func_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_CHECK_NUMERICS_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/diag_part_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/diag_part_gpu_kernel.cc index 33c8a95311f..f9a413e5b79 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/diag_part_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/diag_part_gpu_kernel.cc @@ -1,108 +1,108 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include "mindspore/core/abstract/utils.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/diag_part_impl.cuh" -#include "plugin/device/gpu/kernel/arrays/diag_part_gpu_kernel.h" -#include "kernel/common_utils.h" - -namespace mindspore { -namespace kernel { -bool DiagPartGpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { - if (inputs.empty() || outputs.empty()) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid."; - return false; - } - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', the kernel type is not available, but got: " << kernel_attr << "."; - return false; - } - kernel_func_ = func_list_[index].second; - unit_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).dtype); - return true; -} - -int DiagPartGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - for (const auto &input : inputs) { - // If any input shape contains -1, means input shape is dynamic, so just - // return do nothing. - auto input_shape = input->GetShapeVector(); - if (!IsValidShape(input_shape)) { - return KRET_UNKNOWN_SHAPE; - } - } - ResetResource(); - std::vector input_shape = inputs[kIndex0]->GetShapeVector(); - int64_t input_dims = input_shape.size(); - int kNumberTwo = 2; - output_dims = input_dims / kNumberTwo; - output_elements_ = 1; - for (int i = 0; i < output_dims; i++) { - output_elements_ *= input_shape[i]; - } - input_elements_ = output_elements_ * output_elements_; - if (input_elements_ == 0) { - is_null_input_ = true; - } - size_t output_size = output_elements_ * unit_size_; - output_size_list_.push_back(output_size); - return KRET_OK; -} - -template -bool DiagPartGpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs) { - T *input = GetDeviceAddress(inputs, 0); - T *output = GetDeviceAddress(outputs, 0); - - auto status = CalDiagPart(output_elements_, input, output, device_id_, reinterpret_cast(cuda_stream_)); - CHECK_CUDA_STATUS(status, kernel_name_); - return true; -} - -std::vector> DiagPartGpuKernelMod::func_list_ = { - {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - &DiagPartGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), - &DiagPartGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - &DiagPartGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - &DiagPartGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), - &DiagPartGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), - &DiagPartGpuKernelMod::LaunchKernel>}, - {KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), - &DiagPartGpuKernelMod::LaunchKernel>}}; - -std::vector DiagPartGpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - return support_list; -} -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, DiagPart, DiagPartGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include "mindspore/core/abstract/utils.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/diag_part_impl.cuh" +#include "plugin/device/gpu/kernel/arrays/diag_part_gpu_kernel.h" +#include "kernel/common_utils.h" + +namespace mindspore { +namespace kernel { +bool DiagPartGpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { + if (inputs.empty() || outputs.empty()) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid."; + return false; + } + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', the kernel type is not available, but got: " << kernel_attr << "."; + return false; + } + kernel_func_ = func_list_[index].second; + unit_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).dtype); + return true; +} + +int DiagPartGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + for (const auto &input : inputs) { + // If any input shape contains -1, means input shape is dynamic, so just + // return do nothing. + auto input_shape = input->GetShapeVector(); + if (!IsValidShape(input_shape)) { + return KRET_UNKNOWN_SHAPE; + } + } + ResetResource(); + std::vector input_shape = inputs[kIndex0]->GetShapeVector(); + int64_t input_dims = input_shape.size(); + int kNumberTwo = 2; + output_dims = input_dims / kNumberTwo; + output_elements_ = 1; + for (int i = 0; i < output_dims; i++) { + output_elements_ *= input_shape[i]; + } + input_elements_ = output_elements_ * output_elements_; + if (input_elements_ == 0) { + is_null_input_ = true; + } + size_t output_size = output_elements_ * unit_size_; + output_size_list_.push_back(output_size); + return KRET_OK; +} + +template +bool DiagPartGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + T *input = GetDeviceAddress(inputs, 0); + T *output = GetDeviceAddress(outputs, 0); + + auto status = CalDiagPart(output_elements_, input, output, device_id_, reinterpret_cast(cuda_stream_)); + CHECK_CUDA_STATUS(status, kernel_name_); + return true; +} + +std::vector> DiagPartGpuKernelMod::func_list_ = { + {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + &DiagPartGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + &DiagPartGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + &DiagPartGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + &DiagPartGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + &DiagPartGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), + &DiagPartGpuKernelMod::LaunchKernel>}, + {KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), + &DiagPartGpuKernelMod::LaunchKernel>}}; + +std::vector DiagPartGpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, DiagPart, DiagPartGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/diag_part_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/diag_part_gpu_kernel.h index 3bdeccce49e..ff95c0bb7bf 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/diag_part_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/diag_part_gpu_kernel.h @@ -1,87 +1,87 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_DIAG_PART_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_DIAG_PART_GPU_KERNEL_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "abstract/utils.h" -#include "mindspore/core/ops/diag_part.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/diag_part_impl.cuh" -#include "plugin/device/gpu/kernel/gpu_kernel.h" -#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore { -namespace kernel { -class DiagPartGpuKernelMod : public NativeGpuKernelMod { - public: - DiagPartGpuKernelMod() { ResetResource(); } - ~DiagPartGpuKernelMod() override = default; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *cuda_stream) override { - if (is_null_input_) { - return true; - } - cuda_stream_ = cuda_stream; - return kernel_func_(this, inputs, workspace, outputs); - } - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - protected: - void ResetResource() noexcept { - output_size_list_.clear(); - workspace_size_list_.clear(); - } - - std::vector GetOpSupport() override; - - private: - template - bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs); - using DiagPartFunc = - std::function &, - const std::vector &, const std::vector &)>; - - private: - int p_{0}; - int64_t output_dims{0}; - size_t unit_size_{1}; - size_t input_elements_{0}; - size_t output_elements_{1}; - DiagPartFunc kernel_func_{}; - bool is_null_input_{false}; - void *cuda_stream_{nullptr}; - static std::vector> func_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_DIAG_PART_GPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_DIAG_PART_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_DIAG_PART_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "abstract/utils.h" +#include "mindspore/core/ops/diag_part.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/diag_part_impl.cuh" +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class DiagPartGpuKernelMod : public NativeGpuKernelMod { + public: + DiagPartGpuKernelMod() { ResetResource(); } + ~DiagPartGpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *cuda_stream) override { + if (is_null_input_) { + return true; + } + cuda_stream_ = cuda_stream; + return kernel_func_(this, inputs, workspace, outputs); + } + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + protected: + void ResetResource() noexcept { + output_size_list_.clear(); + workspace_size_list_.clear(); + } + + std::vector GetOpSupport() override; + + private: + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs); + using DiagPartFunc = + std::function &, + const std::vector &, const std::vector &)>; + + private: + int p_{0}; + int64_t output_dims{0}; + size_t unit_size_{1}; + size_t input_elements_{0}; + size_t output_elements_{1}; + DiagPartFunc kernel_func_{}; + bool is_null_input_{false}; + void *cuda_stream_{nullptr}; + static std::vector> func_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_DIAG_PART_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/fill_diagonal_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/fill_diagonal_gpu_kernel.cc index 25d7a1b4f8a..4def18e4d7b 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/fill_diagonal_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/fill_diagonal_gpu_kernel.cc @@ -1,144 +1,144 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/arrays/fill_diagonal_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -namespace { -constexpr size_t kDiagonalInputsNum = 1; -constexpr size_t kDiagonalOutputsNum = 1; -constexpr size_t kInputDimIndex0 = 0; -constexpr size_t kInputNull = 0; -constexpr size_t kInputDimIndex1 = 1; -constexpr int64_t kInputMinDim = 2; -} // namespace - -bool FillDiagonalGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kDiagonalInputsNum, kernel_name_); - - if (inputs.empty() || outputs.empty()) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid."; - return false; - } - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(ERROR) << "For '" << kernel_name_ - << "', the kernel type should be in [float32, int32, int64], but got: " << kernel_attr << "."; - return false; - } - kernel_func_ = func_list_[index].second; - unit_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).dtype); - fill_value_ = GetValue(primitive_->GetAttr("fill_value")); - wrap_ = GetValue(primitive_->GetAttr("wrap")); - - if (IsOneOfUnsignedType(inputs.at(0)->dtype_id()) && fill_value_ < 0) { - MS_LOG(ERROR) << "For " << kernel_name_ << ", [file_value] should be non_negative for input of unsigned type."; - return false; - } - - return true; -} - -int FillDiagonalGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kDiagonalInputsNum, kernel_name_); - for (const auto &input : inputs) { - // If any input shape contains -1, means input shape is dynamic, so just - // return do nothing. - auto input_shape = input->GetShapeVector(); - if (!IsValidShape(input_shape)) { - return KRET_UNKNOWN_SHAPE; - } - } - ResetResource(); - std::vector input_shape = std::vector(inputs.at(kIndex0)->GetDeviceShapeVector().begin(), - inputs.at(kIndex0)->GetDeviceShapeVector().end()); - matrix_row_ = input_shape[kInputDimIndex0]; - matrix_col_ = input_shape[kInputDimIndex1]; - int64_t min_size = std::min(matrix_row_, matrix_col_); - input_elements_ = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); - if (input_elements_ == kInputNull) { - is_null_input_ = true; - } - input_dims_ = input_shape.size(); - if (input_dims_ == kInputMinDim) { - for (int64_t i = (input_dims_ - 1); i >= 0; i--) { - step_ += pow(matrix_col_, i); - } - } else { - std::vector cumprod(input_dims_); - auto dims = input_shape; - std::partial_sum(dims.begin(), dims.end() - 1, cumprod.begin(), std::multiplies()); - step_ = 1 + std::accumulate(cumprod.begin(), cumprod.end(), static_cast(0)); - } - if (wrap_ || input_dims_ > kInputMinDim || matrix_row_ < matrix_col_) { - num_diagonal_elements_ = ceil(static_cast(input_elements_) / step_); - } else { - num_diagonal_elements_ = ceil(static_cast(min_size * min_size) / step_); - } - size_t input_size = input_elements_ * unit_size_; - output_size_list_.push_back(input_size); - workspace_size_list_.push_back(sizeof(bool)); - return KRET_OK; -} - -template -bool FillDiagonalGpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs) { - T *input = GetDeviceAddress(inputs, 0); - T *output = GetDeviceAddress(outputs, 0); - - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( - cudaMemcpyAsync(output, input, input_elements_ * unit_size_, cudaMemcpyDeviceToDevice, - reinterpret_cast(cuda_stream_)), - "cudaMemcpyAsync output 'output' from 'input' failed."); - auto status = CalFillDiagonal(num_diagonal_elements_, fill_value_, step_, output, device_id_, - reinterpret_cast(cuda_stream_)); - CHECK_CUDA_STATUS(status, kernel_name_); - return true; -} - -std::vector> FillDiagonalGpuKernelMod::func_list_ = { - {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - &FillDiagonalGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - &FillDiagonalGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), - &FillDiagonalGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), - &FillDiagonalGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), - &FillDiagonalGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), - &FillDiagonalGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - &FillDiagonalGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), - &FillDiagonalGpuKernelMod::LaunchKernel}}; - -std::vector FillDiagonalGpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - return support_list; -} -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, FillDiagonal, FillDiagonalGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/arrays/fill_diagonal_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kDiagonalInputsNum = 1; +constexpr size_t kDiagonalOutputsNum = 1; +constexpr size_t kInputDimIndex0 = 0; +constexpr size_t kInputNull = 0; +constexpr size_t kInputDimIndex1 = 1; +constexpr int64_t kInputMinDim = 2; +} // namespace + +bool FillDiagonalGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kDiagonalInputsNum, kernel_name_); + + if (inputs.empty() || outputs.empty()) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid."; + return false; + } + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ + << "', the kernel type should be in [float32, int32, int64], but got: " << kernel_attr << "."; + return false; + } + kernel_func_ = func_list_[index].second; + unit_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).dtype); + fill_value_ = GetValue(primitive_->GetAttr("fill_value")); + wrap_ = GetValue(primitive_->GetAttr("wrap")); + + if (IsOneOfUnsignedType(inputs.at(0)->dtype_id()) && fill_value_ < 0) { + MS_LOG(ERROR) << "For " << kernel_name_ << ", [file_value] should be non_negative for input of unsigned type."; + return false; + } + + return true; +} + +int FillDiagonalGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kDiagonalInputsNum, kernel_name_); + for (const auto &input : inputs) { + // If any input shape contains -1, means input shape is dynamic, so just + // return do nothing. + auto input_shape = input->GetShapeVector(); + if (!IsValidShape(input_shape)) { + return KRET_UNKNOWN_SHAPE; + } + } + ResetResource(); + std::vector input_shape = std::vector(inputs.at(kIndex0)->GetDeviceShapeVector().begin(), + inputs.at(kIndex0)->GetDeviceShapeVector().end()); + matrix_row_ = input_shape[kInputDimIndex0]; + matrix_col_ = input_shape[kInputDimIndex1]; + int64_t min_size = std::min(matrix_row_, matrix_col_); + input_elements_ = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); + if (input_elements_ == kInputNull) { + is_null_input_ = true; + } + input_dims_ = input_shape.size(); + if (input_dims_ == kInputMinDim) { + for (int64_t i = (input_dims_ - 1); i >= 0; i--) { + step_ += pow(matrix_col_, i); + } + } else { + std::vector cumprod(input_dims_); + auto dims = input_shape; + std::partial_sum(dims.begin(), dims.end() - 1, cumprod.begin(), std::multiplies()); + step_ = 1 + std::accumulate(cumprod.begin(), cumprod.end(), static_cast(0)); + } + if (wrap_ || input_dims_ > kInputMinDim || matrix_row_ < matrix_col_) { + num_diagonal_elements_ = ceil(static_cast(input_elements_) / step_); + } else { + num_diagonal_elements_ = ceil(static_cast(min_size * min_size) / step_); + } + size_t input_size = input_elements_ * unit_size_; + output_size_list_.push_back(input_size); + workspace_size_list_.push_back(sizeof(bool)); + return KRET_OK; +} + +template +bool FillDiagonalGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + T *input = GetDeviceAddress(inputs, 0); + T *output = GetDeviceAddress(outputs, 0); + + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(output, input, input_elements_ * unit_size_, cudaMemcpyDeviceToDevice, + reinterpret_cast(cuda_stream_)), + "cudaMemcpyAsync output 'output' from 'input' failed."); + auto status = CalFillDiagonal(num_diagonal_elements_, fill_value_, step_, output, device_id_, + reinterpret_cast(cuda_stream_)); + CHECK_CUDA_STATUS(status, kernel_name_); + return true; +} + +std::vector> FillDiagonalGpuKernelMod::func_list_ = { + {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + &FillDiagonalGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + &FillDiagonalGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + &FillDiagonalGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), + &FillDiagonalGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), + &FillDiagonalGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), + &FillDiagonalGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + &FillDiagonalGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + &FillDiagonalGpuKernelMod::LaunchKernel}}; + +std::vector FillDiagonalGpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, FillDiagonal, FillDiagonalGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/fill_diagonal_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/fill_diagonal_gpu_kernel.h index 702080422eb..c0d6cd23abd 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/fill_diagonal_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/fill_diagonal_gpu_kernel.h @@ -1,93 +1,93 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_FILL_DIAGONAL_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_FILL_DIAGONAL_GPU_KERNEL_H_ - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "abstract/utils.h" -#include "mindspore/core/ops/fill_diagonal.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/fill_diagonal_impl.cuh" -#include "plugin/device/gpu/kernel/gpu_kernel.h" -#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore { -namespace kernel { -class FillDiagonalGpuKernelMod : public NativeGpuKernelMod { - public: - FillDiagonalGpuKernelMod() { ResetResource(); } - ~FillDiagonalGpuKernelMod() override = default; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *cuda_stream) override { - if (is_null_input_) { - return true; - } - cuda_stream_ = cuda_stream; - return kernel_func_(this, inputs, workspace, outputs); - } - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - protected: - void ResetResource() noexcept { - is_null_input_ = false; - input_elements_ = 0; - step_ = 0; - output_size_list_.clear(); - workspace_size_list_.clear(); - } - - std::vector GetOpSupport() override; - - private: - template - bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs); - using FillDiagonalFunc = - std::function &, - const std::vector &, const std::vector &)>; - - private: - float fill_value_{0.0}; - bool wrap_{false}; - size_t num_diagonal_elements_{0}; - int64_t step_{0}; - int64_t input_dims_{0}; - int64_t matrix_row_{0}; - int64_t matrix_col_{0}; - size_t unit_size_{1}; - size_t input_elements_{0}; - FillDiagonalFunc kernel_func_{}; - bool is_null_input_{false}; - void *cuda_stream_{nullptr}; - static std::vector> func_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_FILL_DIAGONAL_GPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_FILL_DIAGONAL_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_FILL_DIAGONAL_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "abstract/utils.h" +#include "mindspore/core/ops/fill_diagonal.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/fill_diagonal_impl.cuh" +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class FillDiagonalGpuKernelMod : public NativeGpuKernelMod { + public: + FillDiagonalGpuKernelMod() { ResetResource(); } + ~FillDiagonalGpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *cuda_stream) override { + if (is_null_input_) { + return true; + } + cuda_stream_ = cuda_stream; + return kernel_func_(this, inputs, workspace, outputs); + } + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + protected: + void ResetResource() noexcept { + is_null_input_ = false; + input_elements_ = 0; + step_ = 0; + output_size_list_.clear(); + workspace_size_list_.clear(); + } + + std::vector GetOpSupport() override; + + private: + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs); + using FillDiagonalFunc = + std::function &, + const std::vector &, const std::vector &)>; + + private: + float fill_value_{0.0}; + bool wrap_{false}; + size_t num_diagonal_elements_{0}; + int64_t step_{0}; + int64_t input_dims_{0}; + int64_t matrix_row_{0}; + int64_t matrix_col_{0}; + size_t unit_size_{1}; + size_t input_elements_{0}; + FillDiagonalFunc kernel_func_{}; + bool is_null_input_{false}; + void *cuda_stream_{nullptr}; + static std::vector> func_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_FILL_DIAGONAL_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/list_diff_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/list_diff_gpu_kernel.cc index b3669f96f55..f462b36a633 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/list_diff_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/list_diff_gpu_kernel.cc @@ -1,203 +1,203 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/arrays/list_diff_gpu_kernel.h" -#include -#include -#include -#include -#include "runtime/device/ms_device_shape_transfer.h" -#include "kernel/common_utils.h" - -namespace mindspore { -namespace kernel { -namespace { -template -std::unique_ptr CreateListDiffKernelPtr(const std::string &kernel_name, - const uint32_t &device_id) { - return std::make_unique>(kernel_name, device_id); -} -using ListDiffPtrCreatorFunc = - std::function(const std::string &, const uint32_t &)>; - -const std::vector> kernel_attr = { - {KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeInt64), - CreateListDiffKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeInt64), - CreateListDiffKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeFloat64) - .AddOutputAttr(kNumberTypeFloat64) - .AddOutputAttr(kNumberTypeInt64), - CreateListDiffKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeUInt8) - .AddInputAttr(kNumberTypeUInt8) - .AddOutputAttr(kNumberTypeUInt8) - .AddOutputAttr(kNumberTypeInt64), - CreateListDiffKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeUInt16) - .AddInputAttr(kNumberTypeUInt16) - .AddOutputAttr(kNumberTypeUInt16) - .AddOutputAttr(kNumberTypeInt64), - CreateListDiffKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt8) - .AddInputAttr(kNumberTypeInt8) - .AddOutputAttr(kNumberTypeInt8) - .AddOutputAttr(kNumberTypeInt64), - CreateListDiffKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt16) - .AddInputAttr(kNumberTypeInt16) - .AddOutputAttr(kNumberTypeInt16) - .AddOutputAttr(kNumberTypeInt64), - CreateListDiffKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt64), - CreateListDiffKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt64), - CreateListDiffKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeInt32), - CreateListDiffKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeInt32), - CreateListDiffKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeFloat64) - .AddOutputAttr(kNumberTypeFloat64) - .AddOutputAttr(kNumberTypeInt32), - CreateListDiffKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeUInt8) - .AddInputAttr(kNumberTypeUInt8) - .AddOutputAttr(kNumberTypeUInt8) - .AddOutputAttr(kNumberTypeInt32), - CreateListDiffKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeUInt16) - .AddInputAttr(kNumberTypeUInt16) - .AddOutputAttr(kNumberTypeUInt16) - .AddOutputAttr(kNumberTypeInt32), - CreateListDiffKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt8) - .AddInputAttr(kNumberTypeInt8) - .AddOutputAttr(kNumberTypeInt8) - .AddOutputAttr(kNumberTypeInt32), - CreateListDiffKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt16) - .AddInputAttr(kNumberTypeInt16) - .AddOutputAttr(kNumberTypeInt16) - .AddOutputAttr(kNumberTypeInt32), - CreateListDiffKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt32), - CreateListDiffKernelPtr}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt32), - CreateListDiffKernelPtr}}; -} // namespace - -bool ListDiffGpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { - auto [is_match, index] = MatchKernelAttr(GetKernelAttrFromTensors(inputs, outputs), GetOpSupport()); - if (!is_match) { - return false; - } - helper_ptr_ = kernel_attr[index].second(kernel_name_, device_id_); - if (inputs.empty() || outputs.empty()) { - MS_LOG(ERROR) << "For 'ListDiff' got empty inputs or outputs, which is invalid."; - return false; - } - return true; -} - -int ListDiffGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - for (const auto &input : inputs) { - auto input_shape = input->GetShapeVector(); - if (!IsValidShape(input_shape)) { - return KRET_UNKNOWN_SHAPE; - } - } - std::vector> input_shapes; - std::vector> output_shapes; - input_shapes.emplace_back(inputs[kIndex0]->GetDeviceShapeVector()); - input_shapes.emplace_back(inputs[kIndex1]->GetDeviceShapeVector()); - helper_ptr_->CalMemSize(input_shapes, output_shapes); - ResetResource(); - InitSizeLists(); - return KRET_OK; -} - -void ListDiffGpuKernelMod::UpdateOutputShapeAndSize(const std::vector &inputs, - const std::vector &outputs) { - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(reinterpret_cast(stream_ptr_)), - "cudaStreamSynchronized failed"); - size_t output_num = outputs.size(); - auto dyn_out = helper_ptr_->GetOutputTensorInfo(); - auto num_out = dyn_out.shapes[kIndex0][kIndex0]; - std::vector shape = outputs[kIndex0]->GetShapeVector(); - shape[kIndex0] = num_out; - for (size_t i = 0; i < output_num; ++i) { - outputs[i]->SetShapeVector(std::vector(shape.begin(), shape.end())); - outputs[i]->set_size(LongToSize(std::accumulate(shape.begin(), shape.end(), UnitSizeInBytes(outputs[i]->dtype_id()), - std::multiplies()))); - } -} - -std::vector ListDiffGpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(kernel_attr.begin(), kernel_attr.end(), std::back_inserter(support_list), - [](const std::pair &item) { return item.first; }); - return support_list; -} - -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, ListDiff, ListDiffGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/arrays/list_diff_gpu_kernel.h" +#include +#include +#include +#include +#include "runtime/device/ms_device_shape_transfer.h" +#include "kernel/common_utils.h" + +namespace mindspore { +namespace kernel { +namespace { +template +std::unique_ptr CreateListDiffKernelPtr(const std::string &kernel_name, + const uint32_t &device_id) { + return std::make_unique>(kernel_name, device_id); +} +using ListDiffPtrCreatorFunc = + std::function(const std::string &, const uint32_t &)>; + +const std::vector> kernel_attr = { + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeInt64), + CreateListDiffKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeInt64), + CreateListDiffKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeInt64), + CreateListDiffKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeUInt8) + .AddOutputAttr(kNumberTypeUInt8) + .AddOutputAttr(kNumberTypeInt64), + CreateListDiffKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeUInt16) + .AddOutputAttr(kNumberTypeUInt16) + .AddOutputAttr(kNumberTypeInt64), + CreateListDiffKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt8) + .AddOutputAttr(kNumberTypeInt8) + .AddOutputAttr(kNumberTypeInt64), + CreateListDiffKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt16) + .AddOutputAttr(kNumberTypeInt16) + .AddOutputAttr(kNumberTypeInt64), + CreateListDiffKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt64), + CreateListDiffKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64), + CreateListDiffKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeInt32), + CreateListDiffKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeInt32), + CreateListDiffKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeInt32), + CreateListDiffKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeUInt8) + .AddOutputAttr(kNumberTypeUInt8) + .AddOutputAttr(kNumberTypeInt32), + CreateListDiffKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeUInt16) + .AddOutputAttr(kNumberTypeUInt16) + .AddOutputAttr(kNumberTypeInt32), + CreateListDiffKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt8) + .AddOutputAttr(kNumberTypeInt8) + .AddOutputAttr(kNumberTypeInt32), + CreateListDiffKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt16) + .AddOutputAttr(kNumberTypeInt16) + .AddOutputAttr(kNumberTypeInt32), + CreateListDiffKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + CreateListDiffKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt32), + CreateListDiffKernelPtr}}; +} // namespace + +bool ListDiffGpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { + auto [is_match, index] = MatchKernelAttr(GetKernelAttrFromTensors(inputs, outputs), GetOpSupport()); + if (!is_match) { + return false; + } + helper_ptr_ = kernel_attr[index].second(kernel_name_, device_id_); + if (inputs.empty() || outputs.empty()) { + MS_LOG(ERROR) << "For 'ListDiff' got empty inputs or outputs, which is invalid."; + return false; + } + return true; +} + +int ListDiffGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + for (const auto &input : inputs) { + auto input_shape = input->GetShapeVector(); + if (!IsValidShape(input_shape)) { + return KRET_UNKNOWN_SHAPE; + } + } + std::vector> input_shapes; + std::vector> output_shapes; + input_shapes.emplace_back(inputs[kIndex0]->GetDeviceShapeVector()); + input_shapes.emplace_back(inputs[kIndex1]->GetDeviceShapeVector()); + helper_ptr_->CalMemSize(input_shapes, output_shapes); + ResetResource(); + InitSizeLists(); + return KRET_OK; +} + +void ListDiffGpuKernelMod::UpdateOutputShapeAndSize(const std::vector &inputs, + const std::vector &outputs) { + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(reinterpret_cast(stream_ptr_)), + "cudaStreamSynchronized failed"); + size_t output_num = outputs.size(); + auto dyn_out = helper_ptr_->GetOutputTensorInfo(); + auto num_out = dyn_out.shapes[kIndex0][kIndex0]; + std::vector shape = outputs[kIndex0]->GetShapeVector(); + shape[kIndex0] = num_out; + for (size_t i = 0; i < output_num; ++i) { + outputs[i]->SetShapeVector(std::vector(shape.begin(), shape.end())); + outputs[i]->set_size(LongToSize(std::accumulate(shape.begin(), shape.end(), UnitSizeInBytes(outputs[i]->dtype_id()), + std::multiplies()))); + } +} + +std::vector ListDiffGpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(kernel_attr.begin(), kernel_attr.end(), std::back_inserter(support_list), + [](const std::pair &item) { return item.first; }); + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, ListDiff, ListDiffGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/list_diff_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/list_diff_gpu_kernel.h index 9cce09dfb30..886b3a1aa98 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/list_diff_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/list_diff_gpu_kernel.h @@ -1,78 +1,78 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_LIST_DIFF_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_LIST_DIFF_GPU_KERNEL_H_ -#include -#include -#include -#include "plugin/device/gpu/kernel/gpu_kernel.h" -#include "plugin/factory/ms_factory.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/list_diff_helper.h" - -namespace mindspore { -namespace kernel { -class ListDiffGpuKernelMod : public NativeGpuKernelMod { - public: - ListDiffGpuKernelMod() { - KernelMod::kernel_name_ = "ListDiff"; - ResetResource(); - } - ~ListDiffGpuKernelMod() override = default; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - stream_ptr_ = stream_ptr; - std::vector input_ptrs = ConvertPtrs(inputs); - std::vector work_ptrs = ConvertPtrs(workspace); - std::vector output_ptrs = ConvertPtrs(outputs); - if (helper_ptr_->Process(input_ptrs, output_ptrs, work_ptrs, stream_ptr) != 0) { - return false; - } - return true; - } - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - protected: - bool IsNeedUpdateOutputShapeAndSize() override { return true; } - void UpdateOutputShapeAndSize(const std::vector &inputs, - const std::vector &outputs) override; - - void ResetResource() noexcept { - stream_ptr_ = nullptr; - output_size_list_.clear(); - workspace_size_list_.clear(); - } - - void InitSizeLists() { - output_size_list_ = helper_ptr_->GetOutputSizeList(); - workspace_size_list_ = helper_ptr_->GetWorkSizeList(); - } - - std::vector GetOpSupport() override; - - private: - std::unique_ptr helper_ptr_ = nullptr; - void *stream_ptr_; - std::optional is_input_dynamic_shape_ = {}; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_LIST_DIFF_GPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_LIST_DIFF_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_LIST_DIFF_GPU_KERNEL_H_ +#include +#include +#include +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/factory/ms_factory.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/list_diff_helper.h" + +namespace mindspore { +namespace kernel { +class ListDiffGpuKernelMod : public NativeGpuKernelMod { + public: + ListDiffGpuKernelMod() { + KernelMod::kernel_name_ = "ListDiff"; + ResetResource(); + } + ~ListDiffGpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + stream_ptr_ = stream_ptr; + std::vector input_ptrs = ConvertPtrs(inputs); + std::vector work_ptrs = ConvertPtrs(workspace); + std::vector output_ptrs = ConvertPtrs(outputs); + if (helper_ptr_->Process(input_ptrs, output_ptrs, work_ptrs, stream_ptr) != 0) { + return false; + } + return true; + } + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + protected: + bool IsNeedUpdateOutputShapeAndSize() override { return true; } + void UpdateOutputShapeAndSize(const std::vector &inputs, + const std::vector &outputs) override; + + void ResetResource() noexcept { + stream_ptr_ = nullptr; + output_size_list_.clear(); + workspace_size_list_.clear(); + } + + void InitSizeLists() { + output_size_list_ = helper_ptr_->GetOutputSizeList(); + workspace_size_list_ = helper_ptr_->GetWorkSizeList(); + } + + std::vector GetOpSupport() override; + + private: + std::unique_ptr helper_ptr_ = nullptr; + void *stream_ptr_; + std::optional is_input_dynamic_shape_ = {}; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_LIST_DIFF_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/mvlgamma_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/mvlgamma_gpu_kernel.cc index ac2b36ee39f..998fc494d6b 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/mvlgamma_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/mvlgamma_gpu_kernel.cc @@ -1,104 +1,104 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/arrays/mvlgamma_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -bool MvlgammaGpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { - if (inputs.empty() || outputs.empty()) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid."; - return false; - } - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(ERROR) << "For '" << kernel_name_ - << "', the kernel type should be in [float32, float64], but got: " << kernel_attr << "."; - return false; - } - kernel_func_ = func_list_[index].second; - unit_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).dtype); - p_ = GetValue(primitive_->GetAttr(ops::kP)); - if (p_ < 1) { - MS_LOG(ERROR) << "For " << kernel_name_ << ", the attr 'p' has to be greater than or equal to 1, " - << "but got " << p_ << "."; - return false; - } - return true; -} - -int MvlgammaGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - for (const auto &input : inputs) { - // If any input shape contains -1, means input shape is dynamic, so just - // return do nothing. - auto input_shape = input->GetShapeVector(); - if (!IsValidShape(input_shape)) { - return KRET_UNKNOWN_SHAPE; - } - } - ResetResource(); - std::vector input_shape = std::vector(inputs.at(kIndex0)->GetDeviceShapeVector().begin(), - inputs.at(kIndex0)->GetDeviceShapeVector().end()); - input_elements_ = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); - if (input_elements_ == 0) { - is_null_input_ = true; - } - int64_t input_dims = input_shape.size(); - if (input_dims < 1) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of 'x' should be at least 1-D, but got " << input_dims - << "-D."; - return KRET_RESIZE_FAILED; - } - size_t input_size = input_elements_ * unit_size_; - output_size_list_.push_back(input_size); - workspace_size_list_.push_back(sizeof(int)); - return KRET_OK; -} - -template -bool MvlgammaGpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs) { - T *input = GetDeviceAddress(inputs, 0); - T *output = GetDeviceAddress(outputs, 0); - int *valid_d = GetDeviceAddress(workspace, 0); - int host_valid = -1; - auto status = CalMvlgamma(valid_d, input_elements_, input, p_, output, device_id_, - reinterpret_cast(cuda_stream_), &host_valid); - CHECK_CUDA_STATUS(status, kernel_name_); - if (host_valid >= 0) { - MS_EXCEPTION(ValueError) << "For " << kernel_name_ << ", all elements of 'x' must be greater than (p-1)/2"; - } - return true; -} - -std::vector> MvlgammaGpuKernelMod::func_list_ = { - {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - &MvlgammaGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), - &MvlgammaGpuKernelMod::LaunchKernel}}; - -std::vector MvlgammaGpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - return support_list; -} -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Mvlgamma, MvlgammaGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/arrays/mvlgamma_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +bool MvlgammaGpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { + if (inputs.empty() || outputs.empty()) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid."; + return false; + } + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ + << "', the kernel type should be in [float32, float64], but got: " << kernel_attr << "."; + return false; + } + kernel_func_ = func_list_[index].second; + unit_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).dtype); + p_ = GetValue(primitive_->GetAttr(ops::kP)); + if (p_ < 1) { + MS_LOG(ERROR) << "For " << kernel_name_ << ", the attr 'p' has to be greater than or equal to 1, " + << "but got " << p_ << "."; + return false; + } + return true; +} + +int MvlgammaGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + for (const auto &input : inputs) { + // If any input shape contains -1, means input shape is dynamic, so just + // return do nothing. + auto input_shape = input->GetShapeVector(); + if (!IsValidShape(input_shape)) { + return KRET_UNKNOWN_SHAPE; + } + } + ResetResource(); + std::vector input_shape = std::vector(inputs.at(kIndex0)->GetDeviceShapeVector().begin(), + inputs.at(kIndex0)->GetDeviceShapeVector().end()); + input_elements_ = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); + if (input_elements_ == 0) { + is_null_input_ = true; + } + int64_t input_dims = input_shape.size(); + if (input_dims < 1) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of 'x' should be at least 1-D, but got " << input_dims + << "-D."; + return KRET_RESIZE_FAILED; + } + size_t input_size = input_elements_ * unit_size_; + output_size_list_.push_back(input_size); + workspace_size_list_.push_back(sizeof(int)); + return KRET_OK; +} + +template +bool MvlgammaGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + T *input = GetDeviceAddress(inputs, 0); + T *output = GetDeviceAddress(outputs, 0); + int *valid_d = GetDeviceAddress(workspace, 0); + int host_valid = -1; + auto status = CalMvlgamma(valid_d, input_elements_, input, p_, output, device_id_, + reinterpret_cast(cuda_stream_), &host_valid); + CHECK_CUDA_STATUS(status, kernel_name_); + if (host_valid >= 0) { + MS_EXCEPTION(ValueError) << "For " << kernel_name_ << ", all elements of 'x' must be greater than (p-1)/2"; + } + return true; +} + +std::vector> MvlgammaGpuKernelMod::func_list_ = { + {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + &MvlgammaGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + &MvlgammaGpuKernelMod::LaunchKernel}}; + +std::vector MvlgammaGpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Mvlgamma, MvlgammaGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/mvlgamma_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/mvlgamma_gpu_kernel.h index ef81319c263..b407570fee0 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/mvlgamma_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/mvlgamma_gpu_kernel.h @@ -1,86 +1,86 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_MVLGAMMA_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_MVLGAMMA_GPU_KERNEL_H_ - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "abstract/utils.h" -#include "mindspore/core/ops/mvlgamma.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/mvlgamma_impl.cuh" -#include "plugin/device/gpu/kernel/gpu_kernel.h" -#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore { -namespace kernel { -class MvlgammaGpuKernelMod : public NativeGpuKernelMod { - public: - MvlgammaGpuKernelMod() { ResetResource(); } - ~MvlgammaGpuKernelMod() override = default; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *cuda_stream) override { - if (is_null_input_) { - return true; - } - cuda_stream_ = cuda_stream; - return kernel_func_(this, inputs, workspace, outputs); - } - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - protected: - void ResetResource() noexcept { - is_null_input_ = false; - input_elements_ = 0; - output_size_list_.clear(); - workspace_size_list_.clear(); - } - - std::vector GetOpSupport() override; - - private: - template - bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs); - using MvlgammaFunc = - std::function &, - const std::vector &, const std::vector &)>; - - private: - int p_{0}; - size_t unit_size_{1}; - size_t input_elements_; - MvlgammaFunc kernel_func_{}; - bool is_null_input_{false}; - void *cuda_stream_{nullptr}; - static std::vector> func_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_mvlgamma_GPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_MVLGAMMA_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_MVLGAMMA_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "abstract/utils.h" +#include "mindspore/core/ops/mvlgamma.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/mvlgamma_impl.cuh" +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class MvlgammaGpuKernelMod : public NativeGpuKernelMod { + public: + MvlgammaGpuKernelMod() { ResetResource(); } + ~MvlgammaGpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *cuda_stream) override { + if (is_null_input_) { + return true; + } + cuda_stream_ = cuda_stream; + return kernel_func_(this, inputs, workspace, outputs); + } + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + protected: + void ResetResource() noexcept { + is_null_input_ = false; + input_elements_ = 0; + output_size_list_.clear(); + workspace_size_list_.clear(); + } + + std::vector GetOpSupport() override; + + private: + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs); + using MvlgammaFunc = + std::function &, + const std::vector &, const std::vector &)>; + + private: + int p_{0}; + size_t unit_size_{1}; + size_t input_elements_; + MvlgammaFunc kernel_func_{}; + bool is_null_input_{false}; + void *cuda_stream_{nullptr}; + static std::vector> func_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_mvlgamma_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/mvlgamma_grad_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/mvlgamma_grad_gpu_kernel.cc index 5ae3767aa06..e02cfb28834 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/mvlgamma_grad_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/mvlgamma_grad_gpu_kernel.cc @@ -1,95 +1,95 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/arrays/mvlgamma_grad_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -bool MvlgammaGradGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - if (inputs.empty() || outputs.empty()) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid."; - return false; - } - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "' does not support this kernel type: " << kernel_attr; - return false; - } - kernel_func_ = func_list_[index].second; - unit_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).dtype); - p_ = GetValue(primitive_->GetAttr(ops::kP)); - return true; -} - -int MvlgammaGradGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - for (const auto &input : inputs) { - // If any input shape contains -1, means input shape is dynamic, so just - // return do nothing. - auto input_shape = input->GetShapeVector(); - if (!IsValidShape(input_shape)) { - return KRET_UNKNOWN_SHAPE; - } - } - ResetResource(); - std::vector input_shape = std::vector(inputs.at(kIndex1)->GetDeviceShapeVector().begin(), - inputs.at(kIndex1)->GetDeviceShapeVector().end()); - input_elements_ = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); - if (input_elements_ == 0) { - is_null_input_ = true; - } - int64_t input_dims = input_shape.size(); - if (input_dims < 1) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of 'x' should be at least 1-D, but got " - << input_dims << "-D."; - return KRET_RESIZE_FAILED; - } - size_t input_size = input_elements_ * unit_size_; - - output_size_list_.push_back(input_size); - return KRET_OK; -} - -template -bool MvlgammaGradGpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs) { - T *y_grad = GetDeviceAddress(inputs, 0); - T *x = GetDeviceAddress(inputs, 1); - T *output = GetDeviceAddress(outputs, 0); - auto status = - CalMvlgammaGrad(input_elements_, y_grad, x, p_, output, device_id_, reinterpret_cast(cuda_stream_)); - CHECK_CUDA_STATUS(status, kernel_name_); - return true; -} - -std::vector> MvlgammaGradGpuKernelMod::func_list_ = { - {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - &MvlgammaGradGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), - &MvlgammaGradGpuKernelMod::LaunchKernel}}; - -std::vector MvlgammaGradGpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - return support_list; -} -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, MvlgammaGrad, MvlgammaGradGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/arrays/mvlgamma_grad_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +bool MvlgammaGradGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + if (inputs.empty() || outputs.empty()) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid."; + return false; + } + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "' does not support this kernel type: " << kernel_attr; + return false; + } + kernel_func_ = func_list_[index].second; + unit_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).dtype); + p_ = GetValue(primitive_->GetAttr(ops::kP)); + return true; +} + +int MvlgammaGradGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + for (const auto &input : inputs) { + // If any input shape contains -1, means input shape is dynamic, so just + // return do nothing. + auto input_shape = input->GetShapeVector(); + if (!IsValidShape(input_shape)) { + return KRET_UNKNOWN_SHAPE; + } + } + ResetResource(); + std::vector input_shape = std::vector(inputs.at(kIndex1)->GetDeviceShapeVector().begin(), + inputs.at(kIndex1)->GetDeviceShapeVector().end()); + input_elements_ = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); + if (input_elements_ == 0) { + is_null_input_ = true; + } + int64_t input_dims = input_shape.size(); + if (input_dims < 1) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of 'x' should be at least 1-D, but got " + << input_dims << "-D."; + return KRET_RESIZE_FAILED; + } + size_t input_size = input_elements_ * unit_size_; + + output_size_list_.push_back(input_size); + return KRET_OK; +} + +template +bool MvlgammaGradGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + T *y_grad = GetDeviceAddress(inputs, 0); + T *x = GetDeviceAddress(inputs, 1); + T *output = GetDeviceAddress(outputs, 0); + auto status = + CalMvlgammaGrad(input_elements_, y_grad, x, p_, output, device_id_, reinterpret_cast(cuda_stream_)); + CHECK_CUDA_STATUS(status, kernel_name_); + return true; +} + +std::vector> MvlgammaGradGpuKernelMod::func_list_ = { + {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + &MvlgammaGradGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + &MvlgammaGradGpuKernelMod::LaunchKernel}}; + +std::vector MvlgammaGradGpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, MvlgammaGrad, MvlgammaGradGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/mvlgamma_grad_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/mvlgamma_grad_gpu_kernel.h index d51aa151413..117180c45ba 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/mvlgamma_grad_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/mvlgamma_grad_gpu_kernel.h @@ -1,83 +1,83 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_MVLGAMMA_GRAD_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_MVLGAMMA_GRAD_GPU_KERNEL_H_ -#include -#include -#include -#include -#include -#include -#include -#include -#include "abstract/utils.h" -#include "mindspore/core/ops/grad/mvlgamma_grad.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/mvlgamma_grad_impl.cuh" -#include "plugin/device/gpu/kernel/gpu_kernel.h" -#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore { -namespace kernel { -class MvlgammaGradGpuKernelMod : public NativeGpuKernelMod { - public: - MvlgammaGradGpuKernelMod() { ResetResource(); } - ~MvlgammaGradGpuKernelMod() override = default; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *cuda_stream) override { - if (is_null_input_) { - return true; - } - cuda_stream_ = cuda_stream; - return kernel_func_(this, inputs, workspace, outputs); - } - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - protected: - void ResetResource() noexcept { - is_null_input_ = false; - input_elements_ = 0; - output_size_list_.clear(); - } - - std::vector GetOpSupport() override; - - private: - template - bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs); - using MvlgammaGradFunc = - std::function &, - const std::vector &, const std::vector &)>; - - private: - int p_{0}; - size_t input_elements_; - size_t unit_size_{1}; - MvlgammaGradFunc kernel_func_{}; - bool is_null_input_{false}; - void *cuda_stream_{nullptr}; - static std::vector> func_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_MVLGAMMA_GRAD_GPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_MVLGAMMA_GRAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_MVLGAMMA_GRAD_GPU_KERNEL_H_ +#include +#include +#include +#include +#include +#include +#include +#include +#include "abstract/utils.h" +#include "mindspore/core/ops/grad/mvlgamma_grad.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/mvlgamma_grad_impl.cuh" +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class MvlgammaGradGpuKernelMod : public NativeGpuKernelMod { + public: + MvlgammaGradGpuKernelMod() { ResetResource(); } + ~MvlgammaGradGpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *cuda_stream) override { + if (is_null_input_) { + return true; + } + cuda_stream_ = cuda_stream; + return kernel_func_(this, inputs, workspace, outputs); + } + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + protected: + void ResetResource() noexcept { + is_null_input_ = false; + input_elements_ = 0; + output_size_list_.clear(); + } + + std::vector GetOpSupport() override; + + private: + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs); + using MvlgammaGradFunc = + std::function &, + const std::vector &, const std::vector &)>; + + private: + int p_{0}; + size_t input_elements_; + size_t unit_size_{1}; + MvlgammaGradFunc kernel_func_{}; + bool is_null_input_{false}; + void *cuda_stream_{nullptr}; + static std::vector> func_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_MVLGAMMA_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/scalegrad_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/scalegrad_gpu_kernel.cc index 5c1e5fb2a65..d48cc8ac718 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/scalegrad_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/scalegrad_gpu_kernel.cc @@ -1,84 +1,84 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/arrays/scalegrad_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -template -void ScaleGradGpuKernelMod::LaunchScaleGradPerGrad(const std::vector &inputs, - const std::vector &outputs, void *stream_ptr, - const half *scale_addr_half, const float *scale_addr_float, - size_t index) { - T *input_addr = GetDeviceAddress(inputs, index); - T *output_addr = GetDeviceAddress(outputs, index); - cudaError_t status = cudaErrorNotReady; - if (scale_addr_half != nullptr) { - status = ScaleGradKernel(outputs[index]->size() / sizeof(T), input_addr, *scale_addr_half, output_addr, - reinterpret_cast(stream_ptr)); - } else { - MS_EXCEPTION_IF_NULL(scale_addr_float); - status = ScaleGradKernel(outputs[index]->size() / sizeof(T), input_addr, *scale_addr_float, output_addr, - reinterpret_cast(stream_ptr)); - } - CHECK_CUDA_STATUS(status, kernel_name_); -} - -bool ScaleGradGpuKernelMod::Launch(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - half *scale_addr_half = nullptr; - float *scale_addr_float = nullptr; - if (input_info_.back() == kNumberTypeFloat16) { - scale_addr_half = GetDeviceAddress(inputs, inputs.size() - 1); - } else { - scale_addr_float = GetDeviceAddress(inputs, inputs.size() - 1); - } - - for (size_t i = 0; i < inputs.size() - 1; i++) { - switch (input_info_[i]) { - case kNumberTypeFloat16: { - LaunchScaleGradPerGrad(inputs, outputs, stream_ptr, scale_addr_half, scale_addr_float, i); - break; - } - case kNumberTypeFloat32: { - LaunchScaleGradPerGrad(inputs, outputs, stream_ptr, scale_addr_half, scale_addr_float, i); - break; - } - default: - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the typeid cannot be " << input_info_[i]; - } - } - return true; -} - -bool ScaleGradGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - auto input_size = inputs.size(); - for (size_t index = 0; index < input_size; index++) { - auto type_id = inputs[index]->dtype_id(); - input_info_.push_back(type_id); - if (index < input_size - 1) { - output_size_list_.push_back(inputs[index]->size()); - } - } - - return true; -} - -MS_REG_GPU_KERNEL(ScaleGrad, ScaleGradGpuKernelMod) -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/arrays/scalegrad_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +template +void ScaleGradGpuKernelMod::LaunchScaleGradPerGrad(const std::vector &inputs, + const std::vector &outputs, void *stream_ptr, + const half *scale_addr_half, const float *scale_addr_float, + size_t index) { + T *input_addr = GetDeviceAddress(inputs, index); + T *output_addr = GetDeviceAddress(outputs, index); + cudaError_t status = cudaErrorNotReady; + if (scale_addr_half != nullptr) { + status = ScaleGradKernel(outputs[index]->size() / sizeof(T), input_addr, *scale_addr_half, output_addr, + reinterpret_cast(stream_ptr)); + } else { + MS_EXCEPTION_IF_NULL(scale_addr_float); + status = ScaleGradKernel(outputs[index]->size() / sizeof(T), input_addr, *scale_addr_float, output_addr, + reinterpret_cast(stream_ptr)); + } + CHECK_CUDA_STATUS(status, kernel_name_); +} + +bool ScaleGradGpuKernelMod::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + half *scale_addr_half = nullptr; + float *scale_addr_float = nullptr; + if (input_info_.back() == kNumberTypeFloat16) { + scale_addr_half = GetDeviceAddress(inputs, inputs.size() - 1); + } else { + scale_addr_float = GetDeviceAddress(inputs, inputs.size() - 1); + } + + for (size_t i = 0; i < inputs.size() - 1; i++) { + switch (input_info_[i]) { + case kNumberTypeFloat16: { + LaunchScaleGradPerGrad(inputs, outputs, stream_ptr, scale_addr_half, scale_addr_float, i); + break; + } + case kNumberTypeFloat32: { + LaunchScaleGradPerGrad(inputs, outputs, stream_ptr, scale_addr_half, scale_addr_float, i); + break; + } + default: + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the typeid cannot be " << input_info_[i]; + } + } + return true; +} + +bool ScaleGradGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + auto input_size = inputs.size(); + for (size_t index = 0; index < input_size; index++) { + auto type_id = inputs[index]->dtype_id(); + input_info_.push_back(type_id); + if (index < input_size - 1) { + output_size_list_.push_back(inputs[index]->size()); + } + } + + return true; +} + +MS_REG_GPU_KERNEL(ScaleGrad, ScaleGradGpuKernelMod) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/scalegrad_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/scalegrad_gpu_kernel.h index 89469049585..5ad01d9ffac 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/scalegrad_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/scalegrad_gpu_kernel.h @@ -1,49 +1,49 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SCALEGRAD_GPU_KERNEL_H -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SCALEGRAD_GPU_KERNEL_H - -#include -#include -#include -#include "mindspore/core/ops/fusion/scale_grad_fusion.h" -#include "plugin/device/gpu/kernel/gpu_kernel.h" -#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/scale_grad_impl.cuh" - -namespace mindspore { -namespace kernel { -class ScaleGradGpuKernelMod : public NativeGpuKernelMod { - public: - ScaleGradGpuKernelMod() { kernel_name_ = "ScaleGrad"; } - ~ScaleGradGpuKernelMod() override = default; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - private: - template - void LaunchScaleGradPerGrad(const std::vector &inputs, const std::vector &outputs, - void *stream_ptr, const half *scale_addr_half, const float *scale_addr_float, - size_t index); - std::vector input_info_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPLIT_GPU_KERNEL_H +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SCALEGRAD_GPU_KERNEL_H +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SCALEGRAD_GPU_KERNEL_H + +#include +#include +#include +#include "mindspore/core/ops/fusion/scale_grad_fusion.h" +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/scale_grad_impl.cuh" + +namespace mindspore { +namespace kernel { +class ScaleGradGpuKernelMod : public NativeGpuKernelMod { + public: + ScaleGradGpuKernelMod() { kernel_name_ = "ScaleGrad"; } + ~ScaleGradGpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + private: + template + void LaunchScaleGradPerGrad(const std::vector &inputs, const std::vector &outputs, + void *stream_ptr, const half *scale_addr_half, const float *scale_addr_float, + size_t index); + std::vector input_info_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPLIT_GPU_KERNEL_H diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/tensor_shape_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/tensor_shape_gpu_kernel.cc old mode 100755 new mode 100644 index aca89accf8c..6f8b1f4beef --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/tensor_shape_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/tensor_shape_gpu_kernel.cc @@ -1,140 +1,140 @@ -/** - * Copyright 2020-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include - -#include "plugin/device/gpu/kernel/arrays/tensor_shape_gpu_kernel.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h" -#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -template -using Complex = mindspore::utils::Complex; -MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32), - TensorShapeGpuKernelMod, uint8_t, int32_t) -MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt32), - TensorShapeGpuKernelMod, uint16_t, int32_t) -MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32), - TensorShapeGpuKernelMod, uint32_t, int32_t) -MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt32), - TensorShapeGpuKernelMod, uint64_t, int32_t) -MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32), - TensorShapeGpuKernelMod, int8_t, int32_t) -MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32), - TensorShapeGpuKernelMod, int16_t, int32_t) -MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - TensorShapeGpuKernelMod, int32_t, int32_t) -MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), - TensorShapeGpuKernelMod, int64_t, int32_t) -MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32), - TensorShapeGpuKernelMod, half, int32_t) -MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), - TensorShapeGpuKernelMod, float, int32_t) -MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32), - TensorShapeGpuKernelMod, double, int32_t) -MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32), - TensorShapeGpuKernelMod, bool, int32_t) -MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64), - TensorShapeGpuKernelMod, uint8_t, int64_t) -MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt64), - TensorShapeGpuKernelMod, uint16_t, int64_t) -MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64), - TensorShapeGpuKernelMod, uint32_t, int64_t) -MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt64), - TensorShapeGpuKernelMod, uint64_t, int64_t) -MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64), - TensorShapeGpuKernelMod, int8_t, int64_t) -MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt64), - TensorShapeGpuKernelMod, int16_t, int64_t) -MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), - TensorShapeGpuKernelMod, int32_t, int64_t) -MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), - TensorShapeGpuKernelMod, int64_t, int64_t) -MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64), - TensorShapeGpuKernelMod, half, int64_t) -MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64), - TensorShapeGpuKernelMod, float, int64_t) -MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt64), - TensorShapeGpuKernelMod, double, int64_t) -MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt64), - TensorShapeGpuKernelMod, bool, int64_t) -MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeInt32), - TensorShapeGpuKernelMod, Complex, int32_t) -MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeInt64), - TensorShapeGpuKernelMod, Complex, int64_t) -MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeInt32), - TensorShapeGpuKernelMod, Complex, int32_t) -MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeInt64), - TensorShapeGpuKernelMod, Complex, int64_t) - -MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32), - TensorShapeGpuKernelMod, uint8_t, int32_t) -MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt32), - TensorShapeGpuKernelMod, uint16_t, int32_t) -MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32), - TensorShapeGpuKernelMod, uint32_t, int32_t) -MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt32), - TensorShapeGpuKernelMod, uint64_t, int32_t) -MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32), - TensorShapeGpuKernelMod, int8_t, int32_t) -MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32), - TensorShapeGpuKernelMod, int16_t, int32_t) -MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - TensorShapeGpuKernelMod, int32_t, int32_t) -MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), - TensorShapeGpuKernelMod, int64_t, int32_t) -MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32), - TensorShapeGpuKernelMod, half, int32_t) -MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), - TensorShapeGpuKernelMod, float, int32_t) -MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32), - TensorShapeGpuKernelMod, double, int32_t) -MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32), - TensorShapeGpuKernelMod, bool, int32_t) -MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64), - TensorShapeGpuKernelMod, uint8_t, int64_t) -MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt64), - TensorShapeGpuKernelMod, uint16_t, int64_t) -MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64), - TensorShapeGpuKernelMod, uint32_t, int64_t) -MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt64), - TensorShapeGpuKernelMod, uint64_t, int64_t) -MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64), - TensorShapeGpuKernelMod, int8_t, int64_t) -MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt64), - TensorShapeGpuKernelMod, int16_t, int64_t) -MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), - TensorShapeGpuKernelMod, int32_t, int64_t) -MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), - TensorShapeGpuKernelMod, int64_t, int64_t) -MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64), - TensorShapeGpuKernelMod, half, int64_t) -MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64), - TensorShapeGpuKernelMod, float, int64_t) -MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt64), - TensorShapeGpuKernelMod, double, int64_t) -MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt64), - TensorShapeGpuKernelMod, bool, int64_t) -MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeInt32), - TensorShapeGpuKernelMod, Complex, int32_t) -MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeInt64), - TensorShapeGpuKernelMod, Complex, int64_t) -MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeInt32), - TensorShapeGpuKernelMod, Complex, int32_t) -MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeInt64), - TensorShapeGpuKernelMod, Complex, int64_t) -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2020-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include "plugin/device/gpu/kernel/arrays/tensor_shape_gpu_kernel.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +template +using Complex = mindspore::utils::Complex; +MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32), + TensorShapeGpuKernelMod, uint8_t, int32_t) +MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt32), + TensorShapeGpuKernelMod, uint16_t, int32_t) +MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32), + TensorShapeGpuKernelMod, uint32_t, int32_t) +MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt32), + TensorShapeGpuKernelMod, uint64_t, int32_t) +MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32), + TensorShapeGpuKernelMod, int8_t, int32_t) +MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32), + TensorShapeGpuKernelMod, int16_t, int32_t) +MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + TensorShapeGpuKernelMod, int32_t, int32_t) +MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), + TensorShapeGpuKernelMod, int64_t, int32_t) +MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32), + TensorShapeGpuKernelMod, half, int32_t) +MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), + TensorShapeGpuKernelMod, float, int32_t) +MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32), + TensorShapeGpuKernelMod, double, int32_t) +MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32), + TensorShapeGpuKernelMod, bool, int32_t) +MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64), + TensorShapeGpuKernelMod, uint8_t, int64_t) +MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt64), + TensorShapeGpuKernelMod, uint16_t, int64_t) +MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64), + TensorShapeGpuKernelMod, uint32_t, int64_t) +MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt64), + TensorShapeGpuKernelMod, uint64_t, int64_t) +MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64), + TensorShapeGpuKernelMod, int8_t, int64_t) +MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt64), + TensorShapeGpuKernelMod, int16_t, int64_t) +MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), + TensorShapeGpuKernelMod, int32_t, int64_t) +MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + TensorShapeGpuKernelMod, int64_t, int64_t) +MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64), + TensorShapeGpuKernelMod, half, int64_t) +MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64), + TensorShapeGpuKernelMod, float, int64_t) +MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt64), + TensorShapeGpuKernelMod, double, int64_t) +MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt64), + TensorShapeGpuKernelMod, bool, int64_t) +MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeInt32), + TensorShapeGpuKernelMod, Complex, int32_t) +MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeInt64), + TensorShapeGpuKernelMod, Complex, int64_t) +MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeInt32), + TensorShapeGpuKernelMod, Complex, int32_t) +MS_REG_GPU_KERNEL_TWO(TensorShape, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeInt64), + TensorShapeGpuKernelMod, Complex, int64_t) + +MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32), + TensorShapeGpuKernelMod, uint8_t, int32_t) +MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt32), + TensorShapeGpuKernelMod, uint16_t, int32_t) +MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32), + TensorShapeGpuKernelMod, uint32_t, int32_t) +MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt32), + TensorShapeGpuKernelMod, uint64_t, int32_t) +MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32), + TensorShapeGpuKernelMod, int8_t, int32_t) +MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32), + TensorShapeGpuKernelMod, int16_t, int32_t) +MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + TensorShapeGpuKernelMod, int32_t, int32_t) +MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), + TensorShapeGpuKernelMod, int64_t, int32_t) +MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32), + TensorShapeGpuKernelMod, half, int32_t) +MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), + TensorShapeGpuKernelMod, float, int32_t) +MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32), + TensorShapeGpuKernelMod, double, int32_t) +MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32), + TensorShapeGpuKernelMod, bool, int32_t) +MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64), + TensorShapeGpuKernelMod, uint8_t, int64_t) +MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt64), + TensorShapeGpuKernelMod, uint16_t, int64_t) +MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64), + TensorShapeGpuKernelMod, uint32_t, int64_t) +MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt64), + TensorShapeGpuKernelMod, uint64_t, int64_t) +MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64), + TensorShapeGpuKernelMod, int8_t, int64_t) +MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt64), + TensorShapeGpuKernelMod, int16_t, int64_t) +MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), + TensorShapeGpuKernelMod, int32_t, int64_t) +MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + TensorShapeGpuKernelMod, int64_t, int64_t) +MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64), + TensorShapeGpuKernelMod, half, int64_t) +MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64), + TensorShapeGpuKernelMod, float, int64_t) +MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt64), + TensorShapeGpuKernelMod, double, int64_t) +MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt64), + TensorShapeGpuKernelMod, bool, int64_t) +MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeInt32), + TensorShapeGpuKernelMod, Complex, int32_t) +MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeInt64), + TensorShapeGpuKernelMod, Complex, int64_t) +MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeInt32), + TensorShapeGpuKernelMod, Complex, int32_t) +MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeInt64), + TensorShapeGpuKernelMod, Complex, int64_t) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/tensor_shape_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/tensor_shape_gpu_kernel.h old mode 100755 new mode 100644 diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/transpose_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/transpose_gpu_kernel.cc index d280eaea556..c4497407855 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/transpose_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/transpose_gpu_kernel.cc @@ -1,144 +1,144 @@ -/** - * Copyright 2020-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/arrays/transpose_gpu_kernel.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h" -#include "utils/check_convert_utils.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore { -namespace kernel { -template -using Complex = mindspore::utils::Complex; - -constexpr size_t kPermInputNum = 2; - -#define OP_REGISTER(INPUTX, OUTPUT, T) \ - { \ - KernelAttr().AddInputAttr(INPUTX).AddInputAttr(kObjectTypeTuple, kNumberTypeInt64).AddOutputAttr(OUTPUT), \ - &TransposeGpuKernelMod::LaunchKernel \ - } - -const std::vector> &TransposeGpuKernelMod::GetFuncList() - const { - static const std::vector> func_list = { - OP_REGISTER(kNumberTypeComplex64, kNumberTypeComplex64, Complex), - OP_REGISTER(kNumberTypeComplex128, kNumberTypeComplex128, Complex), - OP_REGISTER(kNumberTypeBool, kNumberTypeBool, bool), - OP_REGISTER(kNumberTypeFloat64, kNumberTypeFloat64, double), - OP_REGISTER(kNumberTypeFloat32, kNumberTypeFloat32, float), - OP_REGISTER(kNumberTypeFloat16, kNumberTypeFloat16, half), - OP_REGISTER(kNumberTypeInt64, kNumberTypeInt64, int64_t), - OP_REGISTER(kNumberTypeInt32, kNumberTypeInt32, int32_t), - OP_REGISTER(kNumberTypeInt16, kNumberTypeInt16, int16_t), - OP_REGISTER(kNumberTypeInt8, kNumberTypeInt8, int8_t), - OP_REGISTER(kNumberTypeUInt8, kNumberTypeUInt8, uint8_t), - OP_REGISTER(kNumberTypeUInt16, kNumberTypeUInt16, uint16_t), - OP_REGISTER(kNumberTypeUInt32, kNumberTypeUInt32, uint32_t), - OP_REGISTER(kNumberTypeUInt64, kNumberTypeUInt64, uint64_t), - }; - return func_list; -} - -template -bool TransposeGpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs) { - T *input = GetDeviceAddress(inputs, 0); - T *output = GetDeviceAddress(outputs, 0); - - size_t size = SizeOf(input_shape_); - if (is_copy_) { - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(output, input, size * sizeof(T), cudaMemcpyDeviceToDevice, - reinterpret_cast(stream_ptr_)), - "For '" << kernel_name_ << "', cudaMemcpyAsync input to output failed."); - return true; - } - auto status = CalTranspose(size, input, info_, output, reinterpret_cast(stream_ptr_)); - - CHECK_CUDA_STATUS(status, kernel_name_); - return true; -} - -void TransposeGpuKernelMod::GetPermValue(const std::vector &perm, std::vector *input_perm) { - input_perm->clear(); - for (size_t j = 0; j < perm.size(); j++) { - auto p = (perm[j] >= 0) ? perm[j] : (perm.size() + perm[j]); - if (p < 0) { - MS_LOG(EXCEPTION) << "the perm value must be in [-" << perm.size() << ", " << (perm.size() - 1) << "], but got " - << perm; - } - input_perm->push_back(p); - } -} - -bool TransposeGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - if (!MatchKernelFunc(kernel_name_, inputs, outputs)) { - return false; - } - size_t input_num = inputs.size(); - size_t output_num = outputs.size(); - - if (input_num != kPermInputNum) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be " << kPermInputNum << ", but got " - << input_num; - } - if (output_num != 1) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of outputs must be 1, but got " << output_num; - } - return true; -} - -bool TransposeGpuKernelMod::IsCopy(const std::vector &perm) { - int32_t index = 0; - return !(std::any_of(perm.begin(), perm.end(), [&](int32_t x) { return x != index++; })); -} - -int TransposeGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - if (int ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { - return ret; - } - auto perm = inputs[kIndex1]->GetValueWithCheck>(); - std::vector input_perm; - GetPermValue(perm, &input_perm); - auto input_shape = inputs[kIndex0]->GetDeviceShapeVector(); - if (input_shape.empty()) { - is_copy_ = true; - return KRET_OK; - } - - if (std::any_of(input_shape.begin(), input_shape.end(), [](int64_t s) { return s == 0; })) { - is_empty_tensor_ = true; - return KRET_OK; - } - - shape_size_ = input_shape.size(); - if (shape_size_ > transpose_max_dimension) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of output cannot be greater than " - << transpose_max_dimension << ", but got " << shape_size_; - } - SimplifyTranspose(input_shape, input_perm, &input_shape_, &input_perm_); - info_.input_shape = input_shape_; - info_.perm = input_perm_; - is_copy_ = IsCopy(input_perm_); - return KRET_OK; -} - -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Transpose, TransposeGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/arrays/transpose_gpu_kernel.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h" +#include "utils/check_convert_utils.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +template +using Complex = mindspore::utils::Complex; + +constexpr size_t kPermInputNum = 2; + +#define OP_REGISTER(INPUTX, OUTPUT, T) \ + { \ + KernelAttr().AddInputAttr(INPUTX).AddInputAttr(kObjectTypeTuple, kNumberTypeInt64).AddOutputAttr(OUTPUT), \ + &TransposeGpuKernelMod::LaunchKernel \ + } + +const std::vector> &TransposeGpuKernelMod::GetFuncList() + const { + static const std::vector> func_list = { + OP_REGISTER(kNumberTypeComplex64, kNumberTypeComplex64, Complex), + OP_REGISTER(kNumberTypeComplex128, kNumberTypeComplex128, Complex), + OP_REGISTER(kNumberTypeBool, kNumberTypeBool, bool), + OP_REGISTER(kNumberTypeFloat64, kNumberTypeFloat64, double), + OP_REGISTER(kNumberTypeFloat32, kNumberTypeFloat32, float), + OP_REGISTER(kNumberTypeFloat16, kNumberTypeFloat16, half), + OP_REGISTER(kNumberTypeInt64, kNumberTypeInt64, int64_t), + OP_REGISTER(kNumberTypeInt32, kNumberTypeInt32, int32_t), + OP_REGISTER(kNumberTypeInt16, kNumberTypeInt16, int16_t), + OP_REGISTER(kNumberTypeInt8, kNumberTypeInt8, int8_t), + OP_REGISTER(kNumberTypeUInt8, kNumberTypeUInt8, uint8_t), + OP_REGISTER(kNumberTypeUInt16, kNumberTypeUInt16, uint16_t), + OP_REGISTER(kNumberTypeUInt32, kNumberTypeUInt32, uint32_t), + OP_REGISTER(kNumberTypeUInt64, kNumberTypeUInt64, uint64_t), + }; + return func_list; +} + +template +bool TransposeGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + T *input = GetDeviceAddress(inputs, 0); + T *output = GetDeviceAddress(outputs, 0); + + size_t size = SizeOf(input_shape_); + if (is_copy_) { + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(output, input, size * sizeof(T), cudaMemcpyDeviceToDevice, + reinterpret_cast(stream_ptr_)), + "For '" << kernel_name_ << "', cudaMemcpyAsync input to output failed."); + return true; + } + auto status = CalTranspose(size, input, info_, output, reinterpret_cast(stream_ptr_)); + + CHECK_CUDA_STATUS(status, kernel_name_); + return true; +} + +void TransposeGpuKernelMod::GetPermValue(const std::vector &perm, std::vector *input_perm) { + input_perm->clear(); + for (size_t j = 0; j < perm.size(); j++) { + auto p = (perm[j] >= 0) ? perm[j] : (perm.size() + perm[j]); + if (p < 0) { + MS_LOG(EXCEPTION) << "the perm value must be in [-" << perm.size() << ", " << (perm.size() - 1) << "], but got " + << perm; + } + input_perm->push_back(p); + } +} + +bool TransposeGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + if (!MatchKernelFunc(kernel_name_, inputs, outputs)) { + return false; + } + size_t input_num = inputs.size(); + size_t output_num = outputs.size(); + + if (input_num != kPermInputNum) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be " << kPermInputNum << ", but got " + << input_num; + } + if (output_num != 1) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of outputs must be 1, but got " << output_num; + } + return true; +} + +bool TransposeGpuKernelMod::IsCopy(const std::vector &perm) { + int32_t index = 0; + return !(std::any_of(perm.begin(), perm.end(), [&](int32_t x) { return x != index++; })); +} + +int TransposeGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + if (int ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { + return ret; + } + auto perm = inputs[kIndex1]->GetValueWithCheck>(); + std::vector input_perm; + GetPermValue(perm, &input_perm); + auto input_shape = inputs[kIndex0]->GetDeviceShapeVector(); + if (input_shape.empty()) { + is_copy_ = true; + return KRET_OK; + } + + if (std::any_of(input_shape.begin(), input_shape.end(), [](int64_t s) { return s == 0; })) { + is_empty_tensor_ = true; + return KRET_OK; + } + + shape_size_ = input_shape.size(); + if (shape_size_ > transpose_max_dimension) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of output cannot be greater than " + << transpose_max_dimension << ", but got " << shape_size_; + } + SimplifyTranspose(input_shape, input_perm, &input_shape_, &input_perm_); + info_.input_shape = input_shape_; + info_.perm = input_perm_; + is_copy_ = IsCopy(input_perm_); + return KRET_OK; +} + +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Transpose, TransposeGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/unravel_index_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/unravel_index_gpu_kernel.h index 1503955b0b6..1fe51f649e6 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/unravel_index_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/unravel_index_gpu_kernel.h @@ -1,52 +1,52 @@ -/** - * Copyright 2020-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_UNRAVEL_INDEX_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_UNRAVEL_INDEX_GPU_KERNEL_H_ - -#include -#include -#include -#include -#include -#include -#include "mindspore/core/ops/unravel_index.h" -#include "plugin/device/gpu/kernel/gpu_kernel.h" -#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/unravel_index_helper.h" -namespace mindspore { -namespace kernel { -class UnravelIndexGpuKernelMod : public NativeGpuKernelMod { - public: - UnravelIndexGpuKernelMod() {} - ~UnravelIndexGpuKernelMod() override = default; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - std::vector GetOpSupport() override; - - private: - std::unique_ptr helper_ptr_{nullptr}; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_UNRAVEL_INDEX_GPU_KERNEL_H_ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_UNRAVEL_INDEX_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_UNRAVEL_INDEX_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include +#include "mindspore/core/ops/unravel_index.h" +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/unravel_index_helper.h" +namespace mindspore { +namespace kernel { +class UnravelIndexGpuKernelMod : public NativeGpuKernelMod { + public: + UnravelIndexGpuKernelMod() {} + ~UnravelIndexGpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + std::vector GetOpSupport() override; + + private: + std::unique_ptr helper_ptr_{nullptr}; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_UNRAVEL_INDEX_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/adaptive_avg_pool3d_helper.h b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/adaptive_avg_pool3d_helper.h index 71e0161a74a..32564357418 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/adaptive_avg_pool3d_helper.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/adaptive_avg_pool3d_helper.h @@ -1,139 +1,139 @@ -/** - * Copyright 2019-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_ADAPTIVE_AVG_POOL3D_HELPER_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_ADAPTIVE_AVG_POOL3D_HELPER_H_ -#include -#include -#include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/helper_base.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/adaptive_avg_pool3d_impl.cuh" - -namespace mindspore { -namespace cukernel { -class AdaptiveAvgPool3DAttr : public GpuKernelAttrBase { - public: - AdaptiveAvgPool3DAttr() = default; - ~AdaptiveAvgPool3DAttr() override = default; - std::vector output_size; -}; - -template -class AdaptiveAvgPool3DHelperGpuKernel : public GpuKernelHelperBase { - public: - explicit AdaptiveAvgPool3DHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id) - : GpuKernelHelperBase(kernel_name, device_id) { - is_null_input_ = false; - } - - virtual ~AdaptiveAvgPool3DHelperGpuKernel() = default; - int CalMemSize(const std::vector> &input_shapes, - const std::vector> &output_shapes) override { - constexpr size_t OUTPUT_NUM = 1; - ResetResource(); - - int out_flag = - CalShapesSizeInBytes(output_shapes, OUTPUT_NUM, kernel_name_, "output_shapes", &output_size_list_); - if (out_flag == -1) { - return out_flag; - } - is_null_input_ = (HasZeroInShapes(input_shapes) || out_flag == 1); - - constexpr size_t kInputShapeSize = 3; - constexpr size_t kOutputShapeSize = 3; - auto input_rank = input_shapes[0].size(); - auto output_rank = output_shapes[0].size(); - if (input_rank < kInputShapeSize) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of input cannot be less than 3, but got " - << input_rank; - return -1; - } - if (output_rank < kOutputShapeSize) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of output cannot be less than 3, but got " - << output_rank; - return -1; - } - - constexpr int DEPTH = 1; - constexpr int WIDTH = 2; - constexpr int HEIGHT = 3; - constexpr int CHANNEL = 4; - constexpr int DIMENSION = 4; - input_channel_ = input_shapes[0][input_rank - CHANNEL]; - input_height_ = input_shapes[0][input_rank - HEIGHT]; - input_width_ = input_shapes[0][input_rank - WIDTH]; - input_depth_ = input_shapes[0][input_rank - DEPTH]; - output_channel_ = output_shapes[0][output_rank - CHANNEL]; - output_height_ = output_shapes[0][output_rank - HEIGHT]; - output_width_ = output_shapes[0][output_rank - WIDTH]; - output_depth_ = output_shapes[0][output_rank - DEPTH]; - out_size_ = output_rank == DIMENSION - ? output_shapes[0][0] * output_height_ * output_width_ * output_depth_ - : output_shapes[0][0] * output_shapes[0][1] * output_height_ * output_width_ * output_depth_; - - return 0; - } - - int Process(const std::vector &input_ptrs, const std::vector &output_ptrs, - const std::vector &work_ptrs, void *cuda_stream) override { - if (is_null_input_) { - return 0; - } - - T *input_ptr = nullptr; - T *output_ptr = nullptr; - int flag = GetDeviceAddress(input_ptrs, 0, kernel_name_, &input_ptr); - if (flag != 0) { - return flag; - } - - flag = GetDeviceAddress(output_ptrs, 0, kernel_name_, &output_ptr); - if (flag != 0) { - return flag; - } - - // call cuda kernel - auto status = - ApplyAdaptiveAvgPool3D((uint)out_size_, (uint)input_channel_, (uint)input_height_, (uint)input_width_, - (uint)input_depth_, (uint)output_channel_, (uint)output_height_, (uint)output_width_, - (uint)output_depth_, input_ptr, output_ptr, reinterpret_cast(cuda_stream)); - CHECK_CUDA_STATUS(status, kernel_name_); - return 0; - } - - void SetKernelParam(const GpuKernelAttrBasePtr &kernel_attr) override { - attr_ptr_ = std::dynamic_pointer_cast(kernel_attr); - } - - private: - std::shared_ptr attr_ptr_; - std::vector input_shape_; - int64_t len_{0}; - int64_t input_channel_{0}; - int64_t input_height_{0}; - int64_t input_width_{0}; - int64_t input_depth_{0}; - int64_t output_channel_{0}; - int64_t output_height_{0}; - int64_t output_width_{0}; - int64_t output_depth_{0}; - int64_t in_size_{0}; - int64_t out_size_{0}; - bool is_null_input_{false}; -}; -} // namespace cukernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_ADAPTIVE_AVG_POOL3D_HELPER_H_ +/** + * Copyright 2019-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_ADAPTIVE_AVG_POOL3D_HELPER_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_ADAPTIVE_AVG_POOL3D_HELPER_H_ +#include +#include +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/helper_base.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/adaptive_avg_pool3d_impl.cuh" + +namespace mindspore { +namespace cukernel { +class AdaptiveAvgPool3DAttr : public GpuKernelAttrBase { + public: + AdaptiveAvgPool3DAttr() = default; + ~AdaptiveAvgPool3DAttr() override = default; + std::vector output_size; +}; + +template +class AdaptiveAvgPool3DHelperGpuKernel : public GpuKernelHelperBase { + public: + explicit AdaptiveAvgPool3DHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id) + : GpuKernelHelperBase(kernel_name, device_id) { + is_null_input_ = false; + } + + virtual ~AdaptiveAvgPool3DHelperGpuKernel() = default; + int CalMemSize(const std::vector> &input_shapes, + const std::vector> &output_shapes) override { + constexpr size_t OUTPUT_NUM = 1; + ResetResource(); + + int out_flag = + CalShapesSizeInBytes(output_shapes, OUTPUT_NUM, kernel_name_, "output_shapes", &output_size_list_); + if (out_flag == -1) { + return out_flag; + } + is_null_input_ = (HasZeroInShapes(input_shapes) || out_flag == 1); + + constexpr size_t kInputShapeSize = 3; + constexpr size_t kOutputShapeSize = 3; + auto input_rank = input_shapes[0].size(); + auto output_rank = output_shapes[0].size(); + if (input_rank < kInputShapeSize) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of input cannot be less than 3, but got " + << input_rank; + return -1; + } + if (output_rank < kOutputShapeSize) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of output cannot be less than 3, but got " + << output_rank; + return -1; + } + + constexpr int DEPTH = 1; + constexpr int WIDTH = 2; + constexpr int HEIGHT = 3; + constexpr int CHANNEL = 4; + constexpr int DIMENSION = 4; + input_channel_ = input_shapes[0][input_rank - CHANNEL]; + input_height_ = input_shapes[0][input_rank - HEIGHT]; + input_width_ = input_shapes[0][input_rank - WIDTH]; + input_depth_ = input_shapes[0][input_rank - DEPTH]; + output_channel_ = output_shapes[0][output_rank - CHANNEL]; + output_height_ = output_shapes[0][output_rank - HEIGHT]; + output_width_ = output_shapes[0][output_rank - WIDTH]; + output_depth_ = output_shapes[0][output_rank - DEPTH]; + out_size_ = output_rank == DIMENSION + ? output_shapes[0][0] * output_height_ * output_width_ * output_depth_ + : output_shapes[0][0] * output_shapes[0][1] * output_height_ * output_width_ * output_depth_; + + return 0; + } + + int Process(const std::vector &input_ptrs, const std::vector &output_ptrs, + const std::vector &work_ptrs, void *cuda_stream) override { + if (is_null_input_) { + return 0; + } + + T *input_ptr = nullptr; + T *output_ptr = nullptr; + int flag = GetDeviceAddress(input_ptrs, 0, kernel_name_, &input_ptr); + if (flag != 0) { + return flag; + } + + flag = GetDeviceAddress(output_ptrs, 0, kernel_name_, &output_ptr); + if (flag != 0) { + return flag; + } + + // call cuda kernel + auto status = + ApplyAdaptiveAvgPool3D((uint)out_size_, (uint)input_channel_, (uint)input_height_, (uint)input_width_, + (uint)input_depth_, (uint)output_channel_, (uint)output_height_, (uint)output_width_, + (uint)output_depth_, input_ptr, output_ptr, reinterpret_cast(cuda_stream)); + CHECK_CUDA_STATUS(status, kernel_name_); + return 0; + } + + void SetKernelParam(const GpuKernelAttrBasePtr &kernel_attr) override { + attr_ptr_ = std::dynamic_pointer_cast(kernel_attr); + } + + private: + std::shared_ptr attr_ptr_; + std::vector input_shape_; + int64_t len_{0}; + int64_t input_channel_{0}; + int64_t input_height_{0}; + int64_t input_width_{0}; + int64_t input_depth_{0}; + int64_t output_channel_{0}; + int64_t output_height_{0}; + int64_t output_width_{0}; + int64_t output_depth_{0}; + int64_t in_size_{0}; + int64_t out_size_{0}; + bool is_null_input_{false}; +}; +} // namespace cukernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_ADAPTIVE_AVG_POOL3D_HELPER_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/adaptive_max_pool_grad_helper.h b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/adaptive_max_pool_grad_helper.h index 9fd3e5c4601..29d8cfe4774 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/adaptive_max_pool_grad_helper.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/adaptive_max_pool_grad_helper.h @@ -1,135 +1,135 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_ADAPTIVE_MAX_POOL_GRAD_HELPER_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_ADAPTIVE_MAX_POOL_GRAD_HELPER_H_ -#include -#include -#include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/helper_base.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/adaptive_max_pool2d_grad_impl.cuh" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/adaptive_max_pool3d_grad_impl.cuh" - -namespace mindspore { -namespace cukernel { -constexpr int64_t maxIndexIdx = 2; -constexpr int64_t dyDimSmall = 3; -constexpr int64_t hIdx = 2; - -class AdaptiveMaxPoolGradAttr : public GpuKernelAttrBase { - public: - AdaptiveMaxPoolGradAttr() = default; - ~AdaptiveMaxPoolGradAttr() override = default; -}; - -template -class AdaptiveMaxPoolGradHelperGpuKernel : public GpuKernelHelperBase { - public: - explicit AdaptiveMaxPoolGradHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id) - : GpuKernelHelperBase(kernel_name, device_id) { - is_null_input_ = false; - } - - virtual ~AdaptiveMaxPoolGradHelperGpuKernel() = default; - - int CalMemSize(const std::vector> &input_shapes, - const std::vector> &output_shapes) override { - ResetResource(); - is_null_input_ = CHECK_SHAPE_NULL(output_shapes[0], kernel_name_, "out_shape"); - if (is_null_input_) { - return -1; - } - input_shape_.emplace_back(input_shapes[0]); // dy - input_shape_.emplace_back(input_shapes[1]); // x - input_shape_.emplace_back(input_shapes[maxIndexIdx]); // index - output_shape_ = output_shapes[0]; // dx - - return 0; - } - - int Process(const std::vector &input_ptrs, const std::vector &output_ptrs, - const std::vector &work_ptrs, void *cuda_stream) override { - // get device ptr input index output - T *dy_ptr = nullptr; - S *index_ptr = nullptr; - T *dx_ptr = nullptr; - int flag = GetDeviceAddress(input_ptrs, 0, kernel_name_, &dy_ptr); - if (flag != 0) { - return flag; - } - - flag = GetDeviceAddress(input_ptrs, maxIndexIdx, kernel_name_, &index_ptr); - if (flag != 0) { - return flag; - } - - flag = GetDeviceAddress(output_ptrs, 0, kernel_name_, &dx_ptr); - if (flag != 0) { - return flag; - } - - cudaError_t status = cudaErrorNotReady; - if (kernel_name_ == kAdaptiveMaxPool3DGradOpName) { - const int64_t output_stride = output_shape_.cend()[-1] * output_shape_.cend()[-2] * output_shape_.cend()[-3]; - auto input_argmax_shape = input_shape_[maxIndexIdx]; - const int64_t argmax_stride = - input_argmax_shape.cend()[-1] * input_argmax_shape.cend()[-2] * input_argmax_shape.cend()[-3]; - const int64_t batch = std::accumulate(input_argmax_shape.begin(), input_argmax_shape.end() - 3, - static_cast(1), [=](int64_t a, int64_t b) { return a * b; }); - status = CalAdaptiveMaxPool3DGrad(dy_ptr, index_ptr, output_stride, argmax_stride, batch, dx_ptr, device_id_, - reinterpret_cast(cuda_stream)); - CHECK_CUDA_STATUS(status, kernel_name_); - return 0; - } - // call cuda kernel - const int shape_dim = output_shape_.size(); // dx grad dim 3 or 4 - auto input_shape = input_shape_[0]; // dy - const int kMinDims = 3; - if (shape_dim < kMinDims || SizeToInt(input_shape.size()) < kMinDims) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the rank of input and output can not less than " << kMinDims - << ", but got output shape: " << output_shape_ << ", input shape: " << input_shape_; - } - const int n = (shape_dim == dyDimSmall ? 1 : output_shape_[0]); - const int c = (shape_dim == dyDimSmall ? output_shape_[0] : output_shape_[1]); - const int in_h = input_shape[input_shape.size() - hIdx]; - const int in_w = input_shape[input_shape.size() - 1]; - const int out_h = output_shape_[output_shape_.size() - hIdx]; - const int out_w = output_shape_[output_shape_.size() - 1]; - - status = CalAdaptiveMaxPool2DGrad(dy_ptr, index_ptr, n, c, in_h, in_w, out_h, out_w, dx_ptr, device_id_, - reinterpret_cast(cuda_stream)); - CHECK_CUDA_STATUS(status, kernel_name_); - return 0; - } - - void SetKernelParam(const GpuKernelAttrBasePtr &kernel_attr) override { - attr_ptr_ = std::dynamic_pointer_cast(kernel_attr); - } - - void ResetResource() override { - input_shape_.clear(); - output_shape_.clear(); - } - - private: - std::shared_ptr attr_ptr_; - std::vector> input_shape_; // 0:input_shape(y_grad) 2:index_shape(argmax) - std::vector output_shape_; - bool is_null_input_; -}; -} // namespace cukernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_ADAPTIVE_MAX_POOL_GRAD_HELPER_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_ADAPTIVE_MAX_POOL_GRAD_HELPER_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_ADAPTIVE_MAX_POOL_GRAD_HELPER_H_ +#include +#include +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/helper_base.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/adaptive_max_pool2d_grad_impl.cuh" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/adaptive_max_pool3d_grad_impl.cuh" + +namespace mindspore { +namespace cukernel { +constexpr int64_t maxIndexIdx = 2; +constexpr int64_t dyDimSmall = 3; +constexpr int64_t hIdx = 2; + +class AdaptiveMaxPoolGradAttr : public GpuKernelAttrBase { + public: + AdaptiveMaxPoolGradAttr() = default; + ~AdaptiveMaxPoolGradAttr() override = default; +}; + +template +class AdaptiveMaxPoolGradHelperGpuKernel : public GpuKernelHelperBase { + public: + explicit AdaptiveMaxPoolGradHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id) + : GpuKernelHelperBase(kernel_name, device_id) { + is_null_input_ = false; + } + + virtual ~AdaptiveMaxPoolGradHelperGpuKernel() = default; + + int CalMemSize(const std::vector> &input_shapes, + const std::vector> &output_shapes) override { + ResetResource(); + is_null_input_ = CHECK_SHAPE_NULL(output_shapes[0], kernel_name_, "out_shape"); + if (is_null_input_) { + return -1; + } + input_shape_.emplace_back(input_shapes[0]); // dy + input_shape_.emplace_back(input_shapes[1]); // x + input_shape_.emplace_back(input_shapes[maxIndexIdx]); // index + output_shape_ = output_shapes[0]; // dx + + return 0; + } + + int Process(const std::vector &input_ptrs, const std::vector &output_ptrs, + const std::vector &work_ptrs, void *cuda_stream) override { + // get device ptr input index output + T *dy_ptr = nullptr; + S *index_ptr = nullptr; + T *dx_ptr = nullptr; + int flag = GetDeviceAddress(input_ptrs, 0, kernel_name_, &dy_ptr); + if (flag != 0) { + return flag; + } + + flag = GetDeviceAddress(input_ptrs, maxIndexIdx, kernel_name_, &index_ptr); + if (flag != 0) { + return flag; + } + + flag = GetDeviceAddress(output_ptrs, 0, kernel_name_, &dx_ptr); + if (flag != 0) { + return flag; + } + + cudaError_t status = cudaErrorNotReady; + if (kernel_name_ == kAdaptiveMaxPool3DGradOpName) { + const int64_t output_stride = output_shape_.cend()[-1] * output_shape_.cend()[-2] * output_shape_.cend()[-3]; + auto input_argmax_shape = input_shape_[maxIndexIdx]; + const int64_t argmax_stride = + input_argmax_shape.cend()[-1] * input_argmax_shape.cend()[-2] * input_argmax_shape.cend()[-3]; + const int64_t batch = std::accumulate(input_argmax_shape.begin(), input_argmax_shape.end() - 3, + static_cast(1), [=](int64_t a, int64_t b) { return a * b; }); + status = CalAdaptiveMaxPool3DGrad(dy_ptr, index_ptr, output_stride, argmax_stride, batch, dx_ptr, device_id_, + reinterpret_cast(cuda_stream)); + CHECK_CUDA_STATUS(status, kernel_name_); + return 0; + } + // call cuda kernel + const int shape_dim = output_shape_.size(); // dx grad dim 3 or 4 + auto input_shape = input_shape_[0]; // dy + const int kMinDims = 3; + if (shape_dim < kMinDims || SizeToInt(input_shape.size()) < kMinDims) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the rank of input and output can not less than " << kMinDims + << ", but got output shape: " << output_shape_ << ", input shape: " << input_shape_; + } + const int n = (shape_dim == dyDimSmall ? 1 : output_shape_[0]); + const int c = (shape_dim == dyDimSmall ? output_shape_[0] : output_shape_[1]); + const int in_h = input_shape[input_shape.size() - hIdx]; + const int in_w = input_shape[input_shape.size() - 1]; + const int out_h = output_shape_[output_shape_.size() - hIdx]; + const int out_w = output_shape_[output_shape_.size() - 1]; + + status = CalAdaptiveMaxPool2DGrad(dy_ptr, index_ptr, n, c, in_h, in_w, out_h, out_w, dx_ptr, device_id_, + reinterpret_cast(cuda_stream)); + CHECK_CUDA_STATUS(status, kernel_name_); + return 0; + } + + void SetKernelParam(const GpuKernelAttrBasePtr &kernel_attr) override { + attr_ptr_ = std::dynamic_pointer_cast(kernel_attr); + } + + void ResetResource() override { + input_shape_.clear(); + output_shape_.clear(); + } + + private: + std::shared_ptr attr_ptr_; + std::vector> input_shape_; // 0:input_shape(y_grad) 2:index_shape(argmax) + std::vector output_shape_; + bool is_null_input_; +}; +} // namespace cukernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_ADAPTIVE_MAX_POOL_GRAD_HELPER_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/heaviside_helper.h b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/heaviside_helper.h index 57d5dc23f65..b63c6cbec4d 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/heaviside_helper.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/heaviside_helper.h @@ -1,135 +1,135 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_HEAVISIDE_HELPER_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_HEAVISIDE_HELPER_H_ -#include -#include -#include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/helper_base.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/heaviside_impl.cuh" - -namespace mindspore { -namespace cukernel { -constexpr int MAX_DIMS = 7; -template -class HeavisideHelperGpuKernel : public GpuKernelHelperBase { - public: - explicit HeavisideHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id) - : GpuKernelHelperBase(kernel_name, device_id) { - is_null_input_ = false; - need_broadcast_ = false; - } - - virtual ~HeavisideHelperGpuKernel() = default; - int CalMemSize(const std::vector> &input_shapes, - const std::vector> &output_shapes) override { - constexpr size_t OUTPUT_NUM = 1; - ResetResource(); - int out_flag = - CalShapesSizeInBytes(output_shapes, OUTPUT_NUM, kernel_name_, "output_shapes", &output_size_list_); - if (out_flag == -1) { - return out_flag; - } - is_null_input_ = (HasZeroInShapes(input_shapes) || out_flag == 1); - - auto inputx_shape = input_shapes[0]; - auto inputy_shape = input_shapes[1]; - auto output_shape = output_shapes[0]; - - for (size_t i = 0; i < inputx_shape.size(); i++) { - if (inputx_shape[i] != inputy_shape[i]) { - need_broadcast_ = true; - } - } - - lhs_shape_.resize(MAX_DIMS, 1); - rhs_shape_.resize(MAX_DIMS, 1); - output_shape_.resize(MAX_DIMS, 1); - output_num_ = 1; - for (size_t i = 0; i < output_shape.size(); i++) { - if (need_broadcast_) { - output_shape_[i] = output_shape[i]; - } - output_num_ *= output_shape[i]; - } - int lhs_offset = output_shape.size() - inputx_shape.size(); - for (size_t j = 0; j < inputx_shape.size(); j++) { - if (need_broadcast_) { - if ((j + lhs_offset) >= 0 && (j + lhs_offset) < MAX_DIMS) { - lhs_shape_[j + lhs_offset] = inputx_shape[j]; - } - } - } - int rhs_offset = output_shape.size() - inputy_shape.size(); - for (size_t k = 0; k < inputy_shape.size(); k++) { - if (need_broadcast_) { - if ((k + rhs_offset) >= 0 && (k + rhs_offset) < MAX_DIMS) { - rhs_shape_[k + rhs_offset] = inputy_shape[k]; - } - } - } - - return CheckKernelParam(); - } - - int Process(const std::vector &input_ptrs, const std::vector &output_ptrs, - const std::vector &work_ptrs, void *cuda_stream) override { - if (is_null_input_) { - return 0; - } - - T *inputx_ptr = nullptr; - T *inputy_ptr = nullptr; - T *output_ptr = nullptr; - int flag = GetDeviceAddress(input_ptrs, 0, kernel_name_, &inputx_ptr); - if (flag != 0) { - return flag; - } - - flag = GetDeviceAddress(input_ptrs, 1, kernel_name_, &inputy_ptr); - if (flag != 0) { - return flag; - } - - flag = GetDeviceAddress(output_ptrs, 0, kernel_name_, &output_ptr); - if (flag != 0) { - return flag; - } - - // call cuda kernel - if (need_broadcast_) { - BroadcastHeaviside(lhs_shape_, rhs_shape_, output_shape_, inputx_ptr, inputy_ptr, output_ptr, device_id_, - reinterpret_cast(cuda_stream)); - } else { - CalHeaviside(output_num_, inputx_ptr, inputy_ptr, output_ptr, device_id_, - reinterpret_cast(cuda_stream)); - } - - return 0; - } - - private: - std::vector lhs_shape_; - std::vector rhs_shape_; - std::vector output_shape_; - bool need_broadcast_; - bool is_null_input_; - size_t output_num_; -}; -} // namespace cukernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_HEAVISIDE_HELPER_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_HEAVISIDE_HELPER_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_HEAVISIDE_HELPER_H_ +#include +#include +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/helper_base.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/heaviside_impl.cuh" + +namespace mindspore { +namespace cukernel { +constexpr int MAX_DIMS = 7; +template +class HeavisideHelperGpuKernel : public GpuKernelHelperBase { + public: + explicit HeavisideHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id) + : GpuKernelHelperBase(kernel_name, device_id) { + is_null_input_ = false; + need_broadcast_ = false; + } + + virtual ~HeavisideHelperGpuKernel() = default; + int CalMemSize(const std::vector> &input_shapes, + const std::vector> &output_shapes) override { + constexpr size_t OUTPUT_NUM = 1; + ResetResource(); + int out_flag = + CalShapesSizeInBytes(output_shapes, OUTPUT_NUM, kernel_name_, "output_shapes", &output_size_list_); + if (out_flag == -1) { + return out_flag; + } + is_null_input_ = (HasZeroInShapes(input_shapes) || out_flag == 1); + + auto inputx_shape = input_shapes[0]; + auto inputy_shape = input_shapes[1]; + auto output_shape = output_shapes[0]; + + for (size_t i = 0; i < inputx_shape.size(); i++) { + if (inputx_shape[i] != inputy_shape[i]) { + need_broadcast_ = true; + } + } + + lhs_shape_.resize(MAX_DIMS, 1); + rhs_shape_.resize(MAX_DIMS, 1); + output_shape_.resize(MAX_DIMS, 1); + output_num_ = 1; + for (size_t i = 0; i < output_shape.size(); i++) { + if (need_broadcast_) { + output_shape_[i] = output_shape[i]; + } + output_num_ *= output_shape[i]; + } + int lhs_offset = output_shape.size() - inputx_shape.size(); + for (size_t j = 0; j < inputx_shape.size(); j++) { + if (need_broadcast_) { + if ((j + lhs_offset) >= 0 && (j + lhs_offset) < MAX_DIMS) { + lhs_shape_[j + lhs_offset] = inputx_shape[j]; + } + } + } + int rhs_offset = output_shape.size() - inputy_shape.size(); + for (size_t k = 0; k < inputy_shape.size(); k++) { + if (need_broadcast_) { + if ((k + rhs_offset) >= 0 && (k + rhs_offset) < MAX_DIMS) { + rhs_shape_[k + rhs_offset] = inputy_shape[k]; + } + } + } + + return CheckKernelParam(); + } + + int Process(const std::vector &input_ptrs, const std::vector &output_ptrs, + const std::vector &work_ptrs, void *cuda_stream) override { + if (is_null_input_) { + return 0; + } + + T *inputx_ptr = nullptr; + T *inputy_ptr = nullptr; + T *output_ptr = nullptr; + int flag = GetDeviceAddress(input_ptrs, 0, kernel_name_, &inputx_ptr); + if (flag != 0) { + return flag; + } + + flag = GetDeviceAddress(input_ptrs, 1, kernel_name_, &inputy_ptr); + if (flag != 0) { + return flag; + } + + flag = GetDeviceAddress(output_ptrs, 0, kernel_name_, &output_ptr); + if (flag != 0) { + return flag; + } + + // call cuda kernel + if (need_broadcast_) { + BroadcastHeaviside(lhs_shape_, rhs_shape_, output_shape_, inputx_ptr, inputy_ptr, output_ptr, device_id_, + reinterpret_cast(cuda_stream)); + } else { + CalHeaviside(output_num_, inputx_ptr, inputy_ptr, output_ptr, device_id_, + reinterpret_cast(cuda_stream)); + } + + return 0; + } + + private: + std::vector lhs_shape_; + std::vector rhs_shape_; + std::vector output_shape_; + bool need_broadcast_; + bool is_null_input_; + size_t output_num_; +}; +} // namespace cukernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_HEAVISIDE_HELPER_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/hypot_helper.h b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/hypot_helper.h index 7f555bd5ada..c6063bb6c9f 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/hypot_helper.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/hypot_helper.h @@ -1,151 +1,151 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_HYPOT_HELPER_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_HYPOT_HELPER_H_ -#include -#include -#include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/helper_base.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/hypot_impl.cuh" - -namespace mindspore { -namespace cukernel { -constexpr int MAX_DIMS = 7; -template -class HypotHelperGpuKernel : public GpuKernelHelperBase { - public: - explicit HypotHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id) - : GpuKernelHelperBase(kernel_name, device_id) { - is_null_input_ = false; - need_broadcast_ = false; - } - - virtual ~HypotHelperGpuKernel() = default; - int CalMemSize(const std::vector> &input_shapes, - const std::vector> &output_shapes) override { - constexpr size_t OUTPUT_NUM = 1; - ResetResource(); - int out_flag = - CalShapesSizeInBytes(output_shapes, OUTPUT_NUM, kernel_name_, "output_shapes", &output_size_list_); - if (out_flag == -1) { - return out_flag; - } - is_null_input_ = (HasZeroInShapes(input_shapes) || out_flag == 1); - - auto inputx_shape = input_shapes[0]; - auto inputy_shape = input_shapes[1]; - auto output_shape = output_shapes[0]; - - ProcessScalar(&inputx_shape, &inputy_shape, &output_shape); - - for (size_t i = 0; i < inputx_shape.size(); i++) { - if (inputx_shape[i] != inputy_shape[i]) { - need_broadcast_ = true; - } - } - - lhs_shape_.resize(MAX_DIMS, 1); - rhs_shape_.resize(MAX_DIMS, 1); - output_shape_.resize(MAX_DIMS, 1); - output_num_ = 1; - for (size_t i = 0; i < output_shape.size(); i++) { - if (need_broadcast_) { - output_shape_[i] = output_shape[i]; - } - output_num_ *= output_shape[i]; - } - int lhs_offset = output_shape.size() - inputx_shape.size(); - for (size_t j = 0; j < inputx_shape.size(); j++) { - if (need_broadcast_) { - if ((j + lhs_offset) >= 0 && (j + lhs_offset) < MAX_DIMS) { - lhs_shape_[j + lhs_offset] = inputx_shape[j]; - } - } - } - int rhs_offset = output_shape.size() - inputy_shape.size(); - for (size_t k = 0; k < inputy_shape.size(); k++) { - if (need_broadcast_) { - if ((k + rhs_offset) >= 0 && (k + rhs_offset) < MAX_DIMS) { - rhs_shape_[k + rhs_offset] = inputy_shape[k]; - } - } - } - - return CheckKernelParam(); - } - - int Process(const std::vector &input_ptrs, const std::vector &output_ptrs, - const std::vector &work_ptrs, void *cuda_stream) override { - if (is_null_input_) { - return 0; - } - - T *inputx_ptr = nullptr; - T *inputy_ptr = nullptr; - T *output_ptr = nullptr; - int flag = GetDeviceAddress(input_ptrs, 0, kernel_name_, &inputx_ptr); - if (flag != 0) { - return flag; - } - - flag = GetDeviceAddress(input_ptrs, 1, kernel_name_, &inputy_ptr); - if (flag != 0) { - return flag; - } - - flag = GetDeviceAddress(output_ptrs, 0, kernel_name_, &output_ptr); - if (flag != 0) { - return flag; - } - - cudaError_t status = cudaErrorNotReady; - // call cuda kernel - if (need_broadcast_) { - status = BroadcastHypot(lhs_shape_, rhs_shape_, output_shape_, inputx_ptr, inputy_ptr, output_ptr, device_id_, - reinterpret_cast(cuda_stream)); - } else { - status = CalHypot(output_num_, inputx_ptr, inputy_ptr, output_ptr, device_id_, - reinterpret_cast(cuda_stream)); - } - CHECK_CUDA_STATUS(status, kernel_name_); - return 0; - } - - void ProcessScalar(std::vector *x1_shape, std::vector *x2_shape, std::vector *y_shape) { - // If there is a scalar in the inputs, its shape will be [], so it will be treated as [1]. - if (x1_shape->size() == 0) { - x1_shape->insert(x1_shape->begin(), 1); - } - if (x2_shape->size() == 0) { - x2_shape->insert(x2_shape->begin(), 1); - } - if (y_shape->size() == 0) { - y_shape->insert(y_shape->begin(), 1); - } - } - - private: - std::vector lhs_shape_; - std::vector rhs_shape_; - std::vector output_shape_; - bool need_broadcast_; - bool is_null_input_; - size_t output_num_; -}; -} // namespace cukernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_HYPOT_HELPER_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_HYPOT_HELPER_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_HYPOT_HELPER_H_ +#include +#include +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/helper_base.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/hypot_impl.cuh" + +namespace mindspore { +namespace cukernel { +constexpr int MAX_DIMS = 7; +template +class HypotHelperGpuKernel : public GpuKernelHelperBase { + public: + explicit HypotHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id) + : GpuKernelHelperBase(kernel_name, device_id) { + is_null_input_ = false; + need_broadcast_ = false; + } + + virtual ~HypotHelperGpuKernel() = default; + int CalMemSize(const std::vector> &input_shapes, + const std::vector> &output_shapes) override { + constexpr size_t OUTPUT_NUM = 1; + ResetResource(); + int out_flag = + CalShapesSizeInBytes(output_shapes, OUTPUT_NUM, kernel_name_, "output_shapes", &output_size_list_); + if (out_flag == -1) { + return out_flag; + } + is_null_input_ = (HasZeroInShapes(input_shapes) || out_flag == 1); + + auto inputx_shape = input_shapes[0]; + auto inputy_shape = input_shapes[1]; + auto output_shape = output_shapes[0]; + + ProcessScalar(&inputx_shape, &inputy_shape, &output_shape); + + for (size_t i = 0; i < inputx_shape.size(); i++) { + if (inputx_shape[i] != inputy_shape[i]) { + need_broadcast_ = true; + } + } + + lhs_shape_.resize(MAX_DIMS, 1); + rhs_shape_.resize(MAX_DIMS, 1); + output_shape_.resize(MAX_DIMS, 1); + output_num_ = 1; + for (size_t i = 0; i < output_shape.size(); i++) { + if (need_broadcast_) { + output_shape_[i] = output_shape[i]; + } + output_num_ *= output_shape[i]; + } + int lhs_offset = output_shape.size() - inputx_shape.size(); + for (size_t j = 0; j < inputx_shape.size(); j++) { + if (need_broadcast_) { + if ((j + lhs_offset) >= 0 && (j + lhs_offset) < MAX_DIMS) { + lhs_shape_[j + lhs_offset] = inputx_shape[j]; + } + } + } + int rhs_offset = output_shape.size() - inputy_shape.size(); + for (size_t k = 0; k < inputy_shape.size(); k++) { + if (need_broadcast_) { + if ((k + rhs_offset) >= 0 && (k + rhs_offset) < MAX_DIMS) { + rhs_shape_[k + rhs_offset] = inputy_shape[k]; + } + } + } + + return CheckKernelParam(); + } + + int Process(const std::vector &input_ptrs, const std::vector &output_ptrs, + const std::vector &work_ptrs, void *cuda_stream) override { + if (is_null_input_) { + return 0; + } + + T *inputx_ptr = nullptr; + T *inputy_ptr = nullptr; + T *output_ptr = nullptr; + int flag = GetDeviceAddress(input_ptrs, 0, kernel_name_, &inputx_ptr); + if (flag != 0) { + return flag; + } + + flag = GetDeviceAddress(input_ptrs, 1, kernel_name_, &inputy_ptr); + if (flag != 0) { + return flag; + } + + flag = GetDeviceAddress(output_ptrs, 0, kernel_name_, &output_ptr); + if (flag != 0) { + return flag; + } + + cudaError_t status = cudaErrorNotReady; + // call cuda kernel + if (need_broadcast_) { + status = BroadcastHypot(lhs_shape_, rhs_shape_, output_shape_, inputx_ptr, inputy_ptr, output_ptr, device_id_, + reinterpret_cast(cuda_stream)); + } else { + status = CalHypot(output_num_, inputx_ptr, inputy_ptr, output_ptr, device_id_, + reinterpret_cast(cuda_stream)); + } + CHECK_CUDA_STATUS(status, kernel_name_); + return 0; + } + + void ProcessScalar(std::vector *x1_shape, std::vector *x2_shape, std::vector *y_shape) { + // If there is a scalar in the inputs, its shape will be [], so it will be treated as [1]. + if (x1_shape->size() == 0) { + x1_shape->insert(x1_shape->begin(), 1); + } + if (x2_shape->size() == 0) { + x2_shape->insert(x2_shape->begin(), 1); + } + if (y_shape->size() == 0) { + y_shape->insert(y_shape->begin(), 1); + } + } + + private: + std::vector lhs_shape_; + std::vector rhs_shape_; + std::vector output_shape_; + bool need_broadcast_; + bool is_null_input_; + size_t output_num_; +}; +} // namespace cukernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_HYPOT_HELPER_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/list_diff_helper.h b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/list_diff_helper.h index 8bc77266dff..ddf966324b9 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/list_diff_helper.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/list_diff_helper.h @@ -1,113 +1,113 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_LIST_DIFF_HELPER_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_LIST_DIFF_HELPER_H_ -#include -#include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/helper_base.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/list_diff_impl.cuh" - -namespace mindspore { -namespace cukernel { -template -class ListDiffHelperGpuKernel : public GpuKernelHelperBase { - public: - explicit ListDiffHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id) - : GpuKernelHelperBase(kernel_name, device_id) { - num_elements_x_ = 0; - num_elements_y_ = 0; - post_output_size_ = 0; - } - virtual ~ListDiffHelperGpuKernel() = default; - int CalMemSize(const std::vector> &input_shapes, - const std::vector> &output_shapes) override { - ResetResource(); - num_elements_x_ = input_shapes[kIndex0][kIndex0]; - num_elements_y_ = input_shapes[kIndex1][kIndex0]; - output_size_list_.emplace_back(num_elements_x_ * sizeof(T)); - output_size_list_.emplace_back(num_elements_x_ * sizeof(S)); - work_size_list_.emplace_back(num_elements_y_ * sizeof(T)); - work_size_list_.emplace_back(num_elements_x_ * sizeof(S)); - work_size_list_.emplace_back(num_elements_x_ * sizeof(bool)); - return 0; - } - - int Process(const std::vector &input_ptrs, const std::vector &output_ptrs, - const std::vector &work_ptrs, void *cuda_stream) override { - T *x_ptr = nullptr; - T *y_ptr = nullptr; - T *out_ptr = nullptr; - S *idx_ptr = nullptr; - T *workspace_y_ptr = nullptr; - S *worksapce_xidx_ptr = nullptr; - bool *worksapce_flag_ptr = nullptr; - int flag = GetDeviceAddress(input_ptrs, kIndex0, kernel_name_, &x_ptr); - if (flag != 0) { - return flag; - } - - flag = GetDeviceAddress(input_ptrs, kIndex1, kernel_name_, &y_ptr); - if (flag != 0) { - return flag; - } - - flag = GetDeviceAddress(output_ptrs, kIndex0, kernel_name_, &out_ptr); - if (flag != 0) { - return flag; - } - - flag = GetDeviceAddress(output_ptrs, kIndex1, kernel_name_, &idx_ptr); - if (flag != 0) { - return flag; - } - - flag = GetDeviceAddress(work_ptrs, kIndex0, kernel_name_, &workspace_y_ptr); - if (flag != 0) { - return flag; - } - - flag = GetDeviceAddress(work_ptrs, kIndex1, kernel_name_, &worksapce_xidx_ptr); - if (flag != 0) { - return flag; - } - - flag = GetDeviceAddress(work_ptrs, kIndex2, kernel_name_, &worksapce_flag_ptr); - if (flag != 0) { - return flag; - } - - auto status = - CalListDiff(num_elements_x_, num_elements_y_, x_ptr, y_ptr, out_ptr, idx_ptr, workspace_y_ptr, worksapce_xidx_ptr, - worksapce_flag_ptr, device_id_, reinterpret_cast(cuda_stream), &post_output_size_); - CHECK_CUDA_STATUS(status, kernel_name_); - return 0; - } - - TensorInfo GetOutputTensorInfo() override { - TensorInfo dyn_out; - dyn_out.shapes.push_back({{post_output_size_}}); - return dyn_out; - } - - private: - size_t num_elements_x_; - size_t num_elements_y_; - int post_output_size_ = 0; -}; -} // namespace cukernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_LIST_DIFF_HELPER_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_LIST_DIFF_HELPER_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_LIST_DIFF_HELPER_H_ +#include +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/helper_base.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/list_diff_impl.cuh" + +namespace mindspore { +namespace cukernel { +template +class ListDiffHelperGpuKernel : public GpuKernelHelperBase { + public: + explicit ListDiffHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id) + : GpuKernelHelperBase(kernel_name, device_id) { + num_elements_x_ = 0; + num_elements_y_ = 0; + post_output_size_ = 0; + } + virtual ~ListDiffHelperGpuKernel() = default; + int CalMemSize(const std::vector> &input_shapes, + const std::vector> &output_shapes) override { + ResetResource(); + num_elements_x_ = input_shapes[kIndex0][kIndex0]; + num_elements_y_ = input_shapes[kIndex1][kIndex0]; + output_size_list_.emplace_back(num_elements_x_ * sizeof(T)); + output_size_list_.emplace_back(num_elements_x_ * sizeof(S)); + work_size_list_.emplace_back(num_elements_y_ * sizeof(T)); + work_size_list_.emplace_back(num_elements_x_ * sizeof(S)); + work_size_list_.emplace_back(num_elements_x_ * sizeof(bool)); + return 0; + } + + int Process(const std::vector &input_ptrs, const std::vector &output_ptrs, + const std::vector &work_ptrs, void *cuda_stream) override { + T *x_ptr = nullptr; + T *y_ptr = nullptr; + T *out_ptr = nullptr; + S *idx_ptr = nullptr; + T *workspace_y_ptr = nullptr; + S *worksapce_xidx_ptr = nullptr; + bool *worksapce_flag_ptr = nullptr; + int flag = GetDeviceAddress(input_ptrs, kIndex0, kernel_name_, &x_ptr); + if (flag != 0) { + return flag; + } + + flag = GetDeviceAddress(input_ptrs, kIndex1, kernel_name_, &y_ptr); + if (flag != 0) { + return flag; + } + + flag = GetDeviceAddress(output_ptrs, kIndex0, kernel_name_, &out_ptr); + if (flag != 0) { + return flag; + } + + flag = GetDeviceAddress(output_ptrs, kIndex1, kernel_name_, &idx_ptr); + if (flag != 0) { + return flag; + } + + flag = GetDeviceAddress(work_ptrs, kIndex0, kernel_name_, &workspace_y_ptr); + if (flag != 0) { + return flag; + } + + flag = GetDeviceAddress(work_ptrs, kIndex1, kernel_name_, &worksapce_xidx_ptr); + if (flag != 0) { + return flag; + } + + flag = GetDeviceAddress(work_ptrs, kIndex2, kernel_name_, &worksapce_flag_ptr); + if (flag != 0) { + return flag; + } + + auto status = + CalListDiff(num_elements_x_, num_elements_y_, x_ptr, y_ptr, out_ptr, idx_ptr, workspace_y_ptr, worksapce_xidx_ptr, + worksapce_flag_ptr, device_id_, reinterpret_cast(cuda_stream), &post_output_size_); + CHECK_CUDA_STATUS(status, kernel_name_); + return 0; + } + + TensorInfo GetOutputTensorInfo() override { + TensorInfo dyn_out; + dyn_out.shapes.push_back({{post_output_size_}}); + return dyn_out; + } + + private: + size_t num_elements_x_; + size_t num_elements_y_; + int post_output_size_ = 0; +}; +} // namespace cukernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_LIST_DIFF_HELPER_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/unravel_index_helper.h b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/unravel_index_helper.h index 97a29c677ad..9a7e4aed705 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/unravel_index_helper.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/unravel_index_helper.h @@ -1,91 +1,91 @@ -/** - * Copyright 2019-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_UNRAVEL_INDEX_HELPER_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_UNRAVEL_INDEX_HELPER_H_ -#include -#include -#include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/helper_base.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/unravel_index_impl.cuh" - -namespace mindspore { -namespace cukernel { -template -class UnravelIndexHelperGpuKernel : public GpuKernelHelperBase { - public: - explicit UnravelIndexHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id) - : GpuKernelHelperBase(kernel_name, device_id) { - is_null_input_ = false; - } - - virtual ~UnravelIndexHelperGpuKernel() = default; - int CalMemSize(const std::vector> &input_shapes, - const std::vector> &output_shapes) override { - ResetResource(); - - constexpr size_t OUTPUT_NUM = 1; - // get input shape vector - input_indices_shape_ = input_shapes[0]; - input_dims_shape_ = input_shapes[1]; - - int out_flag = - CalShapesSizeInBytes(output_shapes, OUTPUT_NUM, kernel_name_, "output_shapes", &output_size_list_); - if (out_flag == -1) { - return out_flag; - } - is_null_input_ = (HasZeroInShapes(input_shapes) || out_flag == 1); - - // emplace_back workspace_size - size_t check_dims_ptr_workspace_size = sizeof(T); - work_size_list_.emplace_back(check_dims_ptr_workspace_size); - - return 0; - } - - int Process(const std::vector &input_ptrs, const std::vector &output_ptrs, - const std::vector &work_ptrs, void *cuda_stream) override { - if (is_null_input_) { - return 0; - } - size_t indices_size = input_indices_shape_.size() == 0 ? 1 : input_indices_shape_[0]; - size_t dims_size = input_dims_shape_[0]; - - T *input_indices_ptr = nullptr; - T *input_dims_ptr = nullptr; - T *output_ptr = nullptr; - T *check_dims_ptr = nullptr; - - (void)GetDeviceAddress(input_ptrs, 0, kernel_name_, &input_indices_ptr); - (void)GetDeviceAddress(input_ptrs, 1, kernel_name_, &input_dims_ptr); - (void)GetDeviceAddress(output_ptrs, 0, kernel_name_, &output_ptr); - (void)GetDeviceAddress(work_ptrs, 0, kernel_name_, &check_dims_ptr); - - // call cuda kernel - auto status = CalUnravelIndex(input_indices_ptr, input_dims_ptr, indices_size, dims_size, output_ptr, device_id_, - reinterpret_cast(cuda_stream)); - CHECK_CUDA_STATUS(status, kernel_name_); - return 0; - } - - private: - std::vector input_indices_shape_; - std::vector input_dims_shape_; - bool is_null_input_; -}; -} // namespace cukernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_UNRAVEL_INDEX_HELPER_H_ +/** + * Copyright 2019-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_UNRAVEL_INDEX_HELPER_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_UNRAVEL_INDEX_HELPER_H_ +#include +#include +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/helper_base.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/unravel_index_impl.cuh" + +namespace mindspore { +namespace cukernel { +template +class UnravelIndexHelperGpuKernel : public GpuKernelHelperBase { + public: + explicit UnravelIndexHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id) + : GpuKernelHelperBase(kernel_name, device_id) { + is_null_input_ = false; + } + + virtual ~UnravelIndexHelperGpuKernel() = default; + int CalMemSize(const std::vector> &input_shapes, + const std::vector> &output_shapes) override { + ResetResource(); + + constexpr size_t OUTPUT_NUM = 1; + // get input shape vector + input_indices_shape_ = input_shapes[0]; + input_dims_shape_ = input_shapes[1]; + + int out_flag = + CalShapesSizeInBytes(output_shapes, OUTPUT_NUM, kernel_name_, "output_shapes", &output_size_list_); + if (out_flag == -1) { + return out_flag; + } + is_null_input_ = (HasZeroInShapes(input_shapes) || out_flag == 1); + + // emplace_back workspace_size + size_t check_dims_ptr_workspace_size = sizeof(T); + work_size_list_.emplace_back(check_dims_ptr_workspace_size); + + return 0; + } + + int Process(const std::vector &input_ptrs, const std::vector &output_ptrs, + const std::vector &work_ptrs, void *cuda_stream) override { + if (is_null_input_) { + return 0; + } + size_t indices_size = input_indices_shape_.size() == 0 ? 1 : input_indices_shape_[0]; + size_t dims_size = input_dims_shape_[0]; + + T *input_indices_ptr = nullptr; + T *input_dims_ptr = nullptr; + T *output_ptr = nullptr; + T *check_dims_ptr = nullptr; + + (void)GetDeviceAddress(input_ptrs, 0, kernel_name_, &input_indices_ptr); + (void)GetDeviceAddress(input_ptrs, 1, kernel_name_, &input_dims_ptr); + (void)GetDeviceAddress(output_ptrs, 0, kernel_name_, &output_ptr); + (void)GetDeviceAddress(work_ptrs, 0, kernel_name_, &check_dims_ptr); + + // call cuda kernel + auto status = CalUnravelIndex(input_indices_ptr, input_dims_ptr, indices_size, dims_size, output_ptr, device_id_, + reinterpret_cast(cuda_stream)); + CHECK_CUDA_STATUS(status, kernel_name_); + return 0; + } + + private: + std::vector input_indices_shape_; + std::vector input_dims_shape_; + bool is_null_input_; +}; +} // namespace cukernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_UNRAVEL_INDEX_HELPER_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/accumulate_n_v2_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/accumulate_n_v2_impl.cu index 24e6ab09c9d..51ca6bb28c2 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/accumulate_n_v2_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/accumulate_n_v2_impl.cu @@ -1,70 +1,70 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include -#include -#include "include/cuda_fp16.h" -#include "accumulate_n_v2_impl.cuh" - -template -__global__ void AccumulateNV2(const size_t size, const size_t n, T **inputs, T *output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - T temp = 0; - for (size_t num = 0; num < n; num++) { - temp += inputs[num][pos]; - } - output[pos] = temp; - } - return; -} - -template <> -__global__ void AccumulateNV2(const size_t size, const size_t n, half **inputs, half *output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - float temp = 0; - for (size_t num = 0; num < n; num++) { - temp += __half2float(inputs[num][pos]); - } - output[pos] = __float2half(temp); - } - return; -} - -template -cudaError_t CalAccumulateNV2(const size_t size, const size_t n, T **inputs, T *output, const uint32_t &device_id, - cudaStream_t cuda_stream) { - AccumulateNV2<<>>(size, n, inputs, output); - return GetCudaStatus(); -} - -template CUDA_LIB_EXPORT cudaError_t CalAccumulateNV2(const size_t size, const size_t n, uint8_t **inputs, - uint8_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalAccumulateNV2(const size_t size, const size_t n, int8_t **inputs, - int8_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalAccumulateNV2(const size_t size, const size_t n, int32_t **inputs, - int32_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalAccumulateNV2(const size_t size, const size_t n, half **inputs, - half *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalAccumulateNV2(const size_t size, const size_t n, float **inputs, - float *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalAccumulateNV2(const size_t size, const size_t n, double **inputs, - double *output, const uint32_t &device_id, - cudaStream_t cuda_stream); +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include "include/cuda_fp16.h" +#include "accumulate_n_v2_impl.cuh" + +template +__global__ void AccumulateNV2(const size_t size, const size_t n, T **inputs, T *output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + T temp = 0; + for (size_t num = 0; num < n; num++) { + temp += inputs[num][pos]; + } + output[pos] = temp; + } + return; +} + +template <> +__global__ void AccumulateNV2(const size_t size, const size_t n, half **inputs, half *output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + float temp = 0; + for (size_t num = 0; num < n; num++) { + temp += __half2float(inputs[num][pos]); + } + output[pos] = __float2half(temp); + } + return; +} + +template +cudaError_t CalAccumulateNV2(const size_t size, const size_t n, T **inputs, T *output, const uint32_t &device_id, + cudaStream_t cuda_stream) { + AccumulateNV2<<>>(size, n, inputs, output); + return GetCudaStatus(); +} + +template CUDA_LIB_EXPORT cudaError_t CalAccumulateNV2(const size_t size, const size_t n, uint8_t **inputs, + uint8_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalAccumulateNV2(const size_t size, const size_t n, int8_t **inputs, + int8_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalAccumulateNV2(const size_t size, const size_t n, int32_t **inputs, + int32_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalAccumulateNV2(const size_t size, const size_t n, half **inputs, + half *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalAccumulateNV2(const size_t size, const size_t n, float **inputs, + float *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalAccumulateNV2(const size_t size, const size_t n, double **inputs, + double *output, const uint32_t &device_id, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/accumulate_n_v2_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/accumulate_n_v2_impl.cuh index e4765651932..6b234d1e1aa 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/accumulate_n_v2_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/accumulate_n_v2_impl.cuh @@ -1,24 +1,24 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ACCUMULATE_N_V2_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ACCUMULATE_N_V2_IMPL_CUH_ -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" - -template -CUDA_LIB_EXPORT cudaError_t CalAccumulateNV2(const size_t size, const size_t n, T **inputs, T *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ACCUMULATE_N_V2_IMPL_CUH_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ACCUMULATE_N_V2_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ACCUMULATE_N_V2_IMPL_CUH_ +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" + +template +CUDA_LIB_EXPORT cudaError_t CalAccumulateNV2(const size_t size, const size_t n, T **inputs, T *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ACCUMULATE_N_V2_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/adaptive_avg_pool3d_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/adaptive_avg_pool3d_impl.cu index c3bd5c07d2d..9f474354dfc 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/adaptive_avg_pool3d_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/adaptive_avg_pool3d_impl.cu @@ -1,91 +1,91 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/adaptive_avg_pool3d_impl.cuh" -#include "include/cuda_fp16.h" - -template -__global__ void AdaptiveAvgPool3DKernel(const uint out_size, const uint input_channel, const uint input_height, - const uint input_width, const uint input_depth, const uint output_channel, - const uint output_height, const uint output_width, const uint output_depth, - T *input_data, T *output_data) { - for (uint pos = blockIdx.x * blockDim.x + threadIdx.x; pos < out_size; pos += gridDim.x * blockDim.x) { - const uint on = pos / (output_channel * output_height * output_width * output_depth); - const uint oc = pos / (output_height * output_width * output_depth) % output_channel; - const uint oh = pos / (output_width * output_depth) % output_height; - const uint ow = pos / output_depth % output_width; - const uint od = pos % output_depth; - const uint in = on; - const uint ic = oc; - - uint ih0 = floorf(__uint2float_rn(oh * input_height) / __uint2float_rn(output_height)); - uint ih1 = ceilf(__uint2float_rn((oh + 1) * input_height) / __uint2float_rn(output_height)); - uint kh = ih1 - ih0; - - uint iw0 = floorf(__uint2float_rn(ow * input_width) / __uint2float_rn(output_width)); - uint iw1 = ceilf(__uint2float_rn((ow + 1) * input_width) / __uint2float_rn(output_width)); - uint kw = iw1 - iw0; - - uint id0 = floorf(__uint2float_rn(od * input_depth) / __uint2float_rn(output_depth)); - uint id1 = ceilf(__uint2float_rn((od + 1) * input_depth) / __uint2float_rn(output_depth)); - uint kd = id1 - id0; - - T sum = 0; - uint in_index = 0; - for (uint ih = ih0; ih < ih1; ih++) { - for (uint iw = iw0; iw < iw1; iw++) { - for (uint id = id0; id < id1; id++) { - in_index = (((in * input_channel + ic) * input_height + ih) * input_width + iw) * input_depth + id; - sum += input_data[in_index]; - } - } - } - uint out_index = (((on * output_channel + oc) * output_height + oh) * output_width + ow) * output_depth + od; - output_data[out_index] = sum / static_cast(kh * kw * kd); - } -} - -template -cudaError_t ApplyAdaptiveAvgPool3D(const uint out_size, const uint input_channel, const uint input_height, - const uint input_width, const uint input_depth, const uint output_channel, - const uint output_height, const uint output_width, const uint output_depth, - T *input_data, T *output_data, cudaStream_t cuda_stream) { - AdaptiveAvgPool3DKernel<<>>( - out_size, input_channel, input_height, input_width, input_depth, output_channel, output_height, output_width, - output_depth, input_data, output_data); - return GetCudaStatus(); -} - -template CUDA_LIB_EXPORT cudaError_t ApplyAdaptiveAvgPool3D(const uint out_size, const uint input_channel, - const uint input_height, const uint input_width, - const uint input_depth, const uint output_channel, - const uint output_height, const uint output_width, - const uint output_depth, float *input_data, - float *output_data, cudaStream_t cuda_stream); - -template CUDA_LIB_EXPORT cudaError_t ApplyAdaptiveAvgPool3D(const uint out_size, const uint input_channel, - const uint input_height, const uint input_width, - const uint input_depth, const uint output_channel, - const uint output_height, const uint output_width, - const uint output_depth, half *input_data, - half *output_data, cudaStream_t cuda_stream); - -template CUDA_LIB_EXPORT cudaError_t ApplyAdaptiveAvgPool3D(const uint out_size, const uint input_channel, - const uint input_height, const uint input_width, - const uint input_depth, const uint output_channel, - const uint output_height, const uint output_width, - const uint output_depth, double *input_data, - double *output_data, cudaStream_t cuda_stream); +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/adaptive_avg_pool3d_impl.cuh" +#include "include/cuda_fp16.h" + +template +__global__ void AdaptiveAvgPool3DKernel(const uint out_size, const uint input_channel, const uint input_height, + const uint input_width, const uint input_depth, const uint output_channel, + const uint output_height, const uint output_width, const uint output_depth, + T *input_data, T *output_data) { + for (uint pos = blockIdx.x * blockDim.x + threadIdx.x; pos < out_size; pos += gridDim.x * blockDim.x) { + const uint on = pos / (output_channel * output_height * output_width * output_depth); + const uint oc = pos / (output_height * output_width * output_depth) % output_channel; + const uint oh = pos / (output_width * output_depth) % output_height; + const uint ow = pos / output_depth % output_width; + const uint od = pos % output_depth; + const uint in = on; + const uint ic = oc; + + uint ih0 = floorf(__uint2float_rn(oh * input_height) / __uint2float_rn(output_height)); + uint ih1 = ceilf(__uint2float_rn((oh + 1) * input_height) / __uint2float_rn(output_height)); + uint kh = ih1 - ih0; + + uint iw0 = floorf(__uint2float_rn(ow * input_width) / __uint2float_rn(output_width)); + uint iw1 = ceilf(__uint2float_rn((ow + 1) * input_width) / __uint2float_rn(output_width)); + uint kw = iw1 - iw0; + + uint id0 = floorf(__uint2float_rn(od * input_depth) / __uint2float_rn(output_depth)); + uint id1 = ceilf(__uint2float_rn((od + 1) * input_depth) / __uint2float_rn(output_depth)); + uint kd = id1 - id0; + + T sum = 0; + uint in_index = 0; + for (uint ih = ih0; ih < ih1; ih++) { + for (uint iw = iw0; iw < iw1; iw++) { + for (uint id = id0; id < id1; id++) { + in_index = (((in * input_channel + ic) * input_height + ih) * input_width + iw) * input_depth + id; + sum += input_data[in_index]; + } + } + } + uint out_index = (((on * output_channel + oc) * output_height + oh) * output_width + ow) * output_depth + od; + output_data[out_index] = sum / static_cast(kh * kw * kd); + } +} + +template +cudaError_t ApplyAdaptiveAvgPool3D(const uint out_size, const uint input_channel, const uint input_height, + const uint input_width, const uint input_depth, const uint output_channel, + const uint output_height, const uint output_width, const uint output_depth, + T *input_data, T *output_data, cudaStream_t cuda_stream) { + AdaptiveAvgPool3DKernel<<>>( + out_size, input_channel, input_height, input_width, input_depth, output_channel, output_height, output_width, + output_depth, input_data, output_data); + return GetCudaStatus(); +} + +template CUDA_LIB_EXPORT cudaError_t ApplyAdaptiveAvgPool3D(const uint out_size, const uint input_channel, + const uint input_height, const uint input_width, + const uint input_depth, const uint output_channel, + const uint output_height, const uint output_width, + const uint output_depth, float *input_data, + float *output_data, cudaStream_t cuda_stream); + +template CUDA_LIB_EXPORT cudaError_t ApplyAdaptiveAvgPool3D(const uint out_size, const uint input_channel, + const uint input_height, const uint input_width, + const uint input_depth, const uint output_channel, + const uint output_height, const uint output_width, + const uint output_depth, half *input_data, + half *output_data, cudaStream_t cuda_stream); + +template CUDA_LIB_EXPORT cudaError_t ApplyAdaptiveAvgPool3D(const uint out_size, const uint input_channel, + const uint input_height, const uint input_width, + const uint input_depth, const uint output_channel, + const uint output_height, const uint output_width, + const uint output_depth, double *input_data, + double *output_data, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/adaptive_avg_pool3d_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/adaptive_avg_pool3d_impl.cuh index 372e293a923..92a548d37f5 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/adaptive_avg_pool3d_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/adaptive_avg_pool3d_impl.cuh @@ -1,28 +1,28 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ADAPTIVE_AVGPOOL3D_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ADAPTIVE_AVGPOOL3D_IMPL_CUH_ -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" -template -CUDA_LIB_EXPORT cudaError_t ApplyAdaptiveAvgPool3D(const uint out_size, const uint input_channel, - const uint input_height, const uint input_width, - const uint input_depth, const uint output_channel, - const uint output_height, const uint output_width, - const uint output_depth, T *input_data, T *output_data, - cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ADAPTIVE_AVGPOOL3D_IMPL_CUH_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ADAPTIVE_AVGPOOL3D_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ADAPTIVE_AVGPOOL3D_IMPL_CUH_ +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" +template +CUDA_LIB_EXPORT cudaError_t ApplyAdaptiveAvgPool3D(const uint out_size, const uint input_channel, + const uint input_height, const uint input_width, + const uint input_depth, const uint output_channel, + const uint output_height, const uint output_width, + const uint output_depth, T *input_data, T *output_data, + cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ADAPTIVE_AVGPOOL3D_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/adaptive_max_pool2d_grad_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/adaptive_max_pool2d_grad_impl.cuh index e17a4fb0d83..d88244c5c14 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/adaptive_max_pool2d_grad_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/adaptive_max_pool2d_grad_impl.cuh @@ -1,29 +1,29 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ADAPTIVE_MAX_POOL2D_GRAD_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ADAPTIVE_MAX_POOL2D_GRAD_IMPL_CUH_ - -#include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" - -template -CUDA_LIB_EXPORT cudaError_t CalAdaptiveMaxPool2DGrad(const T *input_data, const S *max_index, const int n, const int c, - const uint input_height, const uint input_width, - const uint output_height, const uint output_width, T *output_data, - const uint32_t &device_id, cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ADAPTIVE_MAX_POOL2D_GRAD_IMPL_CUH_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ADAPTIVE_MAX_POOL2D_GRAD_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ADAPTIVE_MAX_POOL2D_GRAD_IMPL_CUH_ + +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" + +template +CUDA_LIB_EXPORT cudaError_t CalAdaptiveMaxPool2DGrad(const T *input_data, const S *max_index, const int n, const int c, + const uint input_height, const uint input_width, + const uint output_height, const uint output_width, T *output_data, + const uint32_t &device_id, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ADAPTIVE_MAX_POOL2D_GRAD_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/adaptive_max_pool3d_grad_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/adaptive_max_pool3d_grad_impl.cu index f2da25863f4..578a4bbb818 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/adaptive_max_pool3d_grad_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/adaptive_max_pool3d_grad_impl.cu @@ -1,69 +1,69 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/adaptive_max_pool3d_grad_impl.cuh" -#include "include/cuda_fp16.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh" - -template -__global__ void AdaptiveMaxPool3DGradKernel(const T *input_grad, const S *input_argmax, const int output_stride, - const int argmax_stride, const int batch, T *output_data) { - for (size_t n = blockIdx.x * blockDim.x + threadIdx.x; n < batch; n += blockDim.x * gridDim.x) { - for (int64_t i = 0; i < argmax_stride; ++i) { - int32_t maxp = input_argmax[i + n * argmax_stride] + n * output_stride; - MsAtomicAdd(output_data + static_cast(maxp), input_grad[i + n * argmax_stride]); - } - } - return; -} - -template -cudaError_t CalAdaptiveMaxPool3DGrad(const T *input_grad, const S *input_argmax, const int output_stride, - const int argmax_stride, const int batch, T *output_data, - const uint32_t &device_id, cudaStream_t cuda_stream) { - AdaptiveMaxPool3DGradKernel<<>>( - input_grad, input_argmax, output_stride, argmax_stride, batch, output_data); - return GetCudaStatus(); -} - -#define REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(type1, type2) \ - template CUDA_LIB_EXPORT cudaError_t CalAdaptiveMaxPool3DGrad( \ - const type1 *input_grad, const type2 *input_argmax, const int output_stride, const int argmax_stride, \ - const int batch, type1 *output_data, const uint32_t &device_id, cudaStream_t cuda_stream); - -REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(half, int32_t); -REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(float, int32_t); -REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(double, int32_t); -REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(int8_t, int32_t); -REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(int16_t, int32_t); -REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(int32_t, int32_t); -REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(int64_t, int32_t); -REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(uint8_t, int32_t); -REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(uint16_t, int32_t); -REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(uint32_t, int32_t); -REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(uint64_t, int32_t); - -REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(half, int64_t); -REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(float, int64_t); -REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(double, int64_t); -REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(int8_t, int64_t); -REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(int16_t, int64_t); -REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(int32_t, int64_t); -REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(int64_t, int64_t); -REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(uint8_t, int64_t); -REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(uint16_t, int64_t); -REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(uint32_t, int64_t); -REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(uint64_t, int64_t); +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/adaptive_max_pool3d_grad_impl.cuh" +#include "include/cuda_fp16.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh" + +template +__global__ void AdaptiveMaxPool3DGradKernel(const T *input_grad, const S *input_argmax, const int output_stride, + const int argmax_stride, const int batch, T *output_data) { + for (size_t n = blockIdx.x * blockDim.x + threadIdx.x; n < batch; n += blockDim.x * gridDim.x) { + for (int64_t i = 0; i < argmax_stride; ++i) { + int32_t maxp = input_argmax[i + n * argmax_stride] + n * output_stride; + MsAtomicAdd(output_data + static_cast(maxp), input_grad[i + n * argmax_stride]); + } + } + return; +} + +template +cudaError_t CalAdaptiveMaxPool3DGrad(const T *input_grad, const S *input_argmax, const int output_stride, + const int argmax_stride, const int batch, T *output_data, + const uint32_t &device_id, cudaStream_t cuda_stream) { + AdaptiveMaxPool3DGradKernel<<>>( + input_grad, input_argmax, output_stride, argmax_stride, batch, output_data); + return GetCudaStatus(); +} + +#define REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(type1, type2) \ + template CUDA_LIB_EXPORT cudaError_t CalAdaptiveMaxPool3DGrad( \ + const type1 *input_grad, const type2 *input_argmax, const int output_stride, const int argmax_stride, \ + const int batch, type1 *output_data, const uint32_t &device_id, cudaStream_t cuda_stream); + +REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(half, int32_t); +REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(float, int32_t); +REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(double, int32_t); +REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(int8_t, int32_t); +REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(int16_t, int32_t); +REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(int32_t, int32_t); +REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(int64_t, int32_t); +REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(uint8_t, int32_t); +REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(uint16_t, int32_t); +REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(uint32_t, int32_t); +REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(uint64_t, int32_t); + +REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(half, int64_t); +REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(float, int64_t); +REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(double, int64_t); +REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(int8_t, int64_t); +REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(int16_t, int64_t); +REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(int32_t, int64_t); +REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(int64_t, int64_t); +REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(uint8_t, int64_t); +REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(uint16_t, int64_t); +REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(uint32_t, int64_t); +REG_ADAPTIVE_MAX_POOL3D_GRAD_CUDA(uint64_t, int64_t); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/adaptive_max_pool3d_grad_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/adaptive_max_pool3d_grad_impl.cuh index 24a9567024f..1790ad7ee6f 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/adaptive_max_pool3d_grad_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/adaptive_max_pool3d_grad_impl.cuh @@ -1,28 +1,28 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ADAPTIVE_MAX_POOL3D_GRAD_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ADAPTIVE_MAX_POOL3D_GRAD_IMPL_CUH_ - -#include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" - -template -CUDA_LIB_EXPORT cudaError_t CalAdaptiveMaxPool3DGrad(const T *input_grad, const S *input_argmax, - const int output_stride, const int argmax_stride, const int batch, - T *output_data, const uint32_t &device_id, - cudaStream_t cuda_stream); -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ADAPTIVE_MAX_POOL3D_GRAD_IMPL_CUH_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ADAPTIVE_MAX_POOL3D_GRAD_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ADAPTIVE_MAX_POOL3D_GRAD_IMPL_CUH_ + +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" + +template +CUDA_LIB_EXPORT cudaError_t CalAdaptiveMaxPool3DGrad(const T *input_grad, const S *input_argmax, + const int output_stride, const int argmax_stride, const int batch, + T *output_data, const uint32_t &device_id, + cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ADAPTIVE_MAX_POOL3D_GRAD_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/apply_adagrad_d_a_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/apply_adagrad_d_a_impl.cu old mode 100755 new mode 100644 diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/apply_adagrad_d_a_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/apply_adagrad_d_a_impl.cuh old mode 100755 new mode 100644 diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/argmax_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/argmax_impl.cu old mode 100755 new mode 100644 index 8bd98347372..0e5872f0289 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/argmax_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/argmax_impl.cu @@ -1,134 +1,134 @@ -/** - * Copyright 2020-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "argmax_impl.cuh" -template -__global__ void Argmax(const T *input, const S bound, const size_t outer_size, const size_t inner_size, S *output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < outer_size * inner_size; - pos += gridDim.x * blockDim.x) { - size_t x = pos / inner_size % outer_size; - size_t y = pos % inner_size; - S idx = 0; - size_t input_offset = x * bound * inner_size + 0 * inner_size + y; - T max_data = input[input_offset]; - for (S i = 1; i < bound; i++) { - input_offset = x * bound * inner_size + i * inner_size + y; - auto input_data = input[input_offset]; - idx = input_data > max_data ? i : idx; - max_data = input_data > max_data ? input_data : max_data; - } - output[pos] = idx; - } - return; -} - -template -cudaError_t CalArgmax(const T *input, const S bound, const size_t outer_size, const size_t inner_size, S *output, - const uint32_t &device_id, cudaStream_t cuda_stream) { - Argmax<<>>(input, bound, outer_size, - inner_size, output); - return GetCudaStatus(); -} - -template CUDA_LIB_EXPORT cudaError_t CalArgmax(const half *input, const int32_t bound, - const size_t outer_size, const size_t inner_size, - int32_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalArgmax(const float *input, const int32_t bound, - const size_t outer_size, const size_t inner_size, - int32_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalArgmax(const double *input, const int32_t bound, - const size_t outer_size, const size_t inner_size, - int32_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalArgmax(const int8_t *input, const int32_t bound, - const size_t outer_size, const size_t inner_size, - int32_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalArgmax(const int16_t *input, const int32_t bound, - const size_t outer_size, const size_t inner_size, - int32_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalArgmax(const int32_t *input, const int32_t bound, - const size_t outer_size, const size_t inner_size, - int32_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalArgmax(const int64_t *input, const int32_t bound, - const size_t outer_size, const size_t inner_size, - int32_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalArgmax(const uint8_t *input, const int32_t bound, - const size_t outer_size, const size_t inner_size, - int32_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalArgmax(const uint16_t *input, const int32_t bound, - const size_t outer_size, const size_t inner_size, - int32_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalArgmax(const uint32_t *input, const int32_t bound, - const size_t outer_size, const size_t inner_size, - int32_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalArgmax(const uint64_t *input, const int32_t bound, - const size_t outer_size, const size_t inner_size, - int32_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); - -template CUDA_LIB_EXPORT cudaError_t CalArgmax(const half *input, const int64_t bound, - const size_t outer_size, const size_t inner_size, - int64_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalArgmax(const float *input, const int64_t bound, - const size_t outer_size, const size_t inner_size, - int64_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalArgmax(const double *input, const int64_t bound, - const size_t outer_size, const size_t inner_size, - int64_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalArgmax(const int8_t *input, const int64_t bound, - const size_t outer_size, const size_t inner_size, - int64_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalArgmax(const int16_t *input, const int64_t bound, - const size_t outer_size, const size_t inner_size, - int64_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalArgmax(const int32_t *input, const int64_t bound, - const size_t outer_size, const size_t inner_size, - int64_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalArgmax(const int64_t *input, const int64_t bound, - const size_t outer_size, const size_t inner_size, - int64_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalArgmax(const uint8_t *input, const int64_t bound, - const size_t outer_size, const size_t inner_size, - int64_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalArgmax(const uint16_t *input, const int64_t bound, - const size_t outer_size, const size_t inner_size, - int64_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalArgmax(const uint32_t *input, const int64_t bound, - const size_t outer_size, const size_t inner_size, - int64_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalArgmax(const uint64_t *input, const int64_t bound, - const size_t outer_size, const size_t inner_size, - int64_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "argmax_impl.cuh" +template +__global__ void Argmax(const T *input, const S bound, const size_t outer_size, const size_t inner_size, S *output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < outer_size * inner_size; + pos += gridDim.x * blockDim.x) { + size_t x = pos / inner_size % outer_size; + size_t y = pos % inner_size; + S idx = 0; + size_t input_offset = x * bound * inner_size + 0 * inner_size + y; + T max_data = input[input_offset]; + for (S i = 1; i < bound; i++) { + input_offset = x * bound * inner_size + i * inner_size + y; + auto input_data = input[input_offset]; + idx = input_data > max_data ? i : idx; + max_data = input_data > max_data ? input_data : max_data; + } + output[pos] = idx; + } + return; +} + +template +cudaError_t CalArgmax(const T *input, const S bound, const size_t outer_size, const size_t inner_size, S *output, + const uint32_t &device_id, cudaStream_t cuda_stream) { + Argmax<<>>(input, bound, outer_size, + inner_size, output); + return GetCudaStatus(); +} + +template CUDA_LIB_EXPORT cudaError_t CalArgmax(const half *input, const int32_t bound, + const size_t outer_size, const size_t inner_size, + int32_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalArgmax(const float *input, const int32_t bound, + const size_t outer_size, const size_t inner_size, + int32_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalArgmax(const double *input, const int32_t bound, + const size_t outer_size, const size_t inner_size, + int32_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalArgmax(const int8_t *input, const int32_t bound, + const size_t outer_size, const size_t inner_size, + int32_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalArgmax(const int16_t *input, const int32_t bound, + const size_t outer_size, const size_t inner_size, + int32_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalArgmax(const int32_t *input, const int32_t bound, + const size_t outer_size, const size_t inner_size, + int32_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalArgmax(const int64_t *input, const int32_t bound, + const size_t outer_size, const size_t inner_size, + int32_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalArgmax(const uint8_t *input, const int32_t bound, + const size_t outer_size, const size_t inner_size, + int32_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalArgmax(const uint16_t *input, const int32_t bound, + const size_t outer_size, const size_t inner_size, + int32_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalArgmax(const uint32_t *input, const int32_t bound, + const size_t outer_size, const size_t inner_size, + int32_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalArgmax(const uint64_t *input, const int32_t bound, + const size_t outer_size, const size_t inner_size, + int32_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); + +template CUDA_LIB_EXPORT cudaError_t CalArgmax(const half *input, const int64_t bound, + const size_t outer_size, const size_t inner_size, + int64_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalArgmax(const float *input, const int64_t bound, + const size_t outer_size, const size_t inner_size, + int64_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalArgmax(const double *input, const int64_t bound, + const size_t outer_size, const size_t inner_size, + int64_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalArgmax(const int8_t *input, const int64_t bound, + const size_t outer_size, const size_t inner_size, + int64_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalArgmax(const int16_t *input, const int64_t bound, + const size_t outer_size, const size_t inner_size, + int64_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalArgmax(const int32_t *input, const int64_t bound, + const size_t outer_size, const size_t inner_size, + int64_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalArgmax(const int64_t *input, const int64_t bound, + const size_t outer_size, const size_t inner_size, + int64_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalArgmax(const uint8_t *input, const int64_t bound, + const size_t outer_size, const size_t inner_size, + int64_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalArgmax(const uint16_t *input, const int64_t bound, + const size_t outer_size, const size_t inner_size, + int64_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalArgmax(const uint32_t *input, const int64_t bound, + const size_t outer_size, const size_t inner_size, + int64_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalArgmax(const uint64_t *input, const int64_t bound, + const size_t outer_size, const size_t inner_size, + int64_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/argmax_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/argmax_impl.cuh old mode 100755 new mode 100644 index 100d3efbb82..94fd06da4c5 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/argmax_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/argmax_impl.cuh @@ -1,40 +1,40 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ARGMAX_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ARGMAX_IMPL_CUH_ -#include "include/cuda_fp16.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" - -#ifdef __cplusplus -extern "C" { -#endif -CUDA_LIB_EXPORT cudaError_t CalArgmaxFp32(const float *input, const int bound, const size_t outer_size, - const size_t inner_size, int *output, const uint32_t &device_id, - cudaStream_t cuda_stream); - -CUDA_LIB_EXPORT cudaError_t CalArgmaxFp16(const half *input, const int bound, const size_t outer_size, - const size_t inner_size, int *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -#ifdef __cplusplus -} -#endif - -template -CUDA_LIB_EXPORT cudaError_t CalArgmax(const T *input, const S bound, const size_t outer_size, const size_t inner_size, - S *output, const uint32_t &device_id, cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ARGMAX_IMPL_CUH_ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ARGMAX_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ARGMAX_IMPL_CUH_ +#include "include/cuda_fp16.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" + +#ifdef __cplusplus +extern "C" { +#endif +CUDA_LIB_EXPORT cudaError_t CalArgmaxFp32(const float *input, const int bound, const size_t outer_size, + const size_t inner_size, int *output, const uint32_t &device_id, + cudaStream_t cuda_stream); + +CUDA_LIB_EXPORT cudaError_t CalArgmaxFp16(const half *input, const int bound, const size_t outer_size, + const size_t inner_size, int *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +#ifdef __cplusplus +} +#endif + +template +CUDA_LIB_EXPORT cudaError_t CalArgmax(const T *input, const S bound, const size_t outer_size, const size_t inner_size, + S *output, const uint32_t &device_id, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ARGMAX_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/assert_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/assert_impl.cu old mode 100755 new mode 100644 diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/assert_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/assert_impl.cuh old mode 100755 new mode 100644 diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/bartlett_window_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/bartlett_window_impl.cu index 048aaf5f5e9..b1f14d011d5 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/bartlett_window_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/bartlett_window_impl.cu @@ -1,77 +1,77 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include "bartlett_window_impl.cuh" - -template -__global__ void BartlettWindowOne(const size_t size, S *output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - output[pos] = static_cast(1); - } - return; -} - -template -__global__ void BartlettWindow(const size_t size, const double N, const double M, S *output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - double out = 0; - if (pos <= M) { - out = (2 * pos) / (N - 1); - } else { - out = 2 - (2 * pos) / (N - 1); - } - output[pos] = static_cast(out); - } - return; -} - -template -cudaError_t CalBartlettWindow(const size_t size, const T *input, const bool periodic, S *output, - const uint32_t &device_id, cudaStream_t cuda_stream) { - T N = 0; - cudaMemcpy(&N, &input[0], sizeof(T), cudaMemcpyDeviceToHost); - if (N == 1) { - BartlettWindowOne<<>>(size, output); - } else { - N = periodic ? static_cast(N + 1) : static_cast(N); - double M = (N - 1) / 2; - BartlettWindow<<>>(size, N, M, output); - } - return GetCudaStatus(); -} - -template CUDA_LIB_EXPORT cudaError_t CalBartlettWindow(const size_t size, const int *input, - const bool periodic, half *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalBartlettWindow(const size_t size, const int64_t *input, - const bool periodic, half *output, - const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalBartlettWindow(const size_t size, const int *input, - const bool periodic, float *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalBartlettWindow(const size_t size, const int64_t *input, - const bool periodic, float *output, - const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalBartlettWindow(const size_t size, const int *input, - const bool periodic, double *output, - const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalBartlettWindow(const size_t size, const int64_t *input, - const bool periodic, double *output, - const uint32_t &device_id, - cudaStream_t cuda_stream); +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "bartlett_window_impl.cuh" + +template +__global__ void BartlettWindowOne(const size_t size, S *output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + output[pos] = static_cast(1); + } + return; +} + +template +__global__ void BartlettWindow(const size_t size, const double N, const double M, S *output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + double out = 0; + if (pos <= M) { + out = (2 * pos) / (N - 1); + } else { + out = 2 - (2 * pos) / (N - 1); + } + output[pos] = static_cast(out); + } + return; +} + +template +cudaError_t CalBartlettWindow(const size_t size, const T *input, const bool periodic, S *output, + const uint32_t &device_id, cudaStream_t cuda_stream) { + T N = 0; + cudaMemcpy(&N, &input[0], sizeof(T), cudaMemcpyDeviceToHost); + if (N == 1) { + BartlettWindowOne<<>>(size, output); + } else { + N = periodic ? static_cast(N + 1) : static_cast(N); + double M = (N - 1) / 2; + BartlettWindow<<>>(size, N, M, output); + } + return GetCudaStatus(); +} + +template CUDA_LIB_EXPORT cudaError_t CalBartlettWindow(const size_t size, const int *input, + const bool periodic, half *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalBartlettWindow(const size_t size, const int64_t *input, + const bool periodic, half *output, + const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalBartlettWindow(const size_t size, const int *input, + const bool periodic, float *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalBartlettWindow(const size_t size, const int64_t *input, + const bool periodic, float *output, + const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalBartlettWindow(const size_t size, const int *input, + const bool periodic, double *output, + const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalBartlettWindow(const size_t size, const int64_t *input, + const bool periodic, double *output, + const uint32_t &device_id, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/bartlett_window_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/bartlett_window_impl.cuh index 0d27da32905..af94ce384f1 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/bartlett_window_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/bartlett_window_impl.cuh @@ -1,24 +1,24 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BARTLETT_WINDOW_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BARTLETT_WINDOW_IMPL_CUH_ -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" - -template -CUDA_LIB_EXPORT cudaError_t CalBartlettWindow(const size_t size, const T *input, const bool periodic, S *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BARTLETT_WINDOW_IMPL_CUH_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BARTLETT_WINDOW_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BARTLETT_WINDOW_IMPL_CUH_ +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" + +template +CUDA_LIB_EXPORT cudaError_t CalBartlettWindow(const size_t size, const T *input, const bool periodic, S *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BARTLETT_WINDOW_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/bias_add_nhwc.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/bias_add_nhwc.cu index d483947b21c..b677b2e72d8 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/bias_add_nhwc.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/bias_add_nhwc.cu @@ -1,50 +1,50 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include -#include -#include "include/cuda_fp16.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/bias_add_nhwc.cuh" - -template -__global__ void CalBiasAdd(const size_t num_value, const size_t num_bias, const T *src, const T *bias, T *output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < num_value; pos += blockDim.x * gridDim.x) { - size_t j = pos % num_bias; - output[pos] = src[pos] + bias[j]; - } - return; -} - -template -cudaError_t CalBiasAddNHWC(const size_t num_value, const size_t num_bias, const T *src, const T *bias, T *output, - const uint32_t &device_id, cudaStream_t cuda_stream) { - size_t thread_num = num_value > 256 ? 256 : num_value; - CalBiasAdd<<>>(num_value, num_bias, - src, bias, output); - return GetCudaStatus(); -} - -template CUDA_LIB_EXPORT cudaError_t CalBiasAddNHWC(const size_t num_value, const size_t num_bias, - const half *src, const half *bias, half *output, - const uint32_t &device_id, cudaStream_t cuda_stream); - -template CUDA_LIB_EXPORT cudaError_t CalBiasAddNHWC(const size_t num_value, const size_t num_bias, - const float *src, const float *bias, float *output, - const uint32_t &device_id, cudaStream_t cuda_stream); - -template CUDA_LIB_EXPORT cudaError_t CalBiasAddNHWC(const size_t num_value, const size_t num_bias, - const int8_t *src, const int8_t *bias, int8_t *output, - const uint32_t &device_id, cudaStream_t cuda_stream); +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include "include/cuda_fp16.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/bias_add_nhwc.cuh" + +template +__global__ void CalBiasAdd(const size_t num_value, const size_t num_bias, const T *src, const T *bias, T *output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < num_value; pos += blockDim.x * gridDim.x) { + size_t j = pos % num_bias; + output[pos] = src[pos] + bias[j]; + } + return; +} + +template +cudaError_t CalBiasAddNHWC(const size_t num_value, const size_t num_bias, const T *src, const T *bias, T *output, + const uint32_t &device_id, cudaStream_t cuda_stream) { + size_t thread_num = num_value > 256 ? 256 : num_value; + CalBiasAdd<<>>(num_value, num_bias, + src, bias, output); + return GetCudaStatus(); +} + +template CUDA_LIB_EXPORT cudaError_t CalBiasAddNHWC(const size_t num_value, const size_t num_bias, + const half *src, const half *bias, half *output, + const uint32_t &device_id, cudaStream_t cuda_stream); + +template CUDA_LIB_EXPORT cudaError_t CalBiasAddNHWC(const size_t num_value, const size_t num_bias, + const float *src, const float *bias, float *output, + const uint32_t &device_id, cudaStream_t cuda_stream); + +template CUDA_LIB_EXPORT cudaError_t CalBiasAddNHWC(const size_t num_value, const size_t num_bias, + const int8_t *src, const int8_t *bias, int8_t *output, + const uint32_t &device_id, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/bias_add_nhwc.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/bias_add_nhwc.cuh index 5f94c1b402f..6bf49587cb0 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/bias_add_nhwc.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/bias_add_nhwc.cuh @@ -1,24 +1,24 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BIAS_ADD_NHWC_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BIAS_ADD_NHWC_IMPL_CUH_ -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" - -template -CUDA_LIB_EXPORT cudaError_t CalBiasAddNHWC(const size_t num_value, const size_t num_bias, const T *src, const T *bias, - T *output, const uint32_t &device_id, cudaStream_t cuda_stream); -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BIAS_ADD_NHWC_IMPL_CUH_ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BIAS_ADD_NHWC_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BIAS_ADD_NHWC_IMPL_CUH_ +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" + +template +CUDA_LIB_EXPORT cudaError_t CalBiasAddNHWC(const size_t num_value, const size_t num_bias, const T *src, const T *bias, + T *output, const uint32_t &device_id, cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BIAS_ADD_NHWC_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/blackman_window_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/blackman_window_impl.cu index 181e2774c29..57607264b79 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/blackman_window_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/blackman_window_impl.cu @@ -1,72 +1,72 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include "blackman_window_impl.cuh" - -template -__global__ void BlackmanWindowOne(const size_t size, const double N, const double PI, S *output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - output[pos] = static_cast(1); - } - return; -} - -template -__global__ void BlackmanWindow(const size_t size, const double N, const double PI, S *output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - double out = 0.42 - 0.5 * cos((2 * PI * pos) / (N - 1)) + 0.08 * cos((4 * PI * pos) / (N - 1)); - output[pos] = static_cast(out); - } - return; -} - -template -cudaError_t CalBlackmanWindow(const size_t size, const T *input, const bool periodic, S *output, - const uint32_t &device_id, cudaStream_t cuda_stream) { - const double PI = acos(-1); - T N = 0; - cudaMemcpy(&N, &input[0], sizeof(T), cudaMemcpyDeviceToHost); - if (N == 1) { - BlackmanWindowOne<<>>(size, N, PI, output); - } else { - N = periodic ? static_cast(N + 1) : static_cast(N); - BlackmanWindow<<>>(size, N, PI, output); - } - return GetCudaStatus(); -} - -template CUDA_LIB_EXPORT cudaError_t CalBlackmanWindow(const size_t size, const int *input, - const bool periodic, half *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalBlackmanWindow(const size_t size, const int64_t *input, - const bool periodic, half *output, - const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalBlackmanWindow(const size_t size, const int *input, - const bool periodic, float *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalBlackmanWindow(const size_t size, const int64_t *input, - const bool periodic, float *output, - const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalBlackmanWindow(const size_t size, const int *input, - const bool periodic, double *output, - const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalBlackmanWindow(const size_t size, const int64_t *input, - const bool periodic, double *output, - const uint32_t &device_id, - cudaStream_t cuda_stream); +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "blackman_window_impl.cuh" + +template +__global__ void BlackmanWindowOne(const size_t size, const double N, const double PI, S *output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + output[pos] = static_cast(1); + } + return; +} + +template +__global__ void BlackmanWindow(const size_t size, const double N, const double PI, S *output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + double out = 0.42 - 0.5 * cos((2 * PI * pos) / (N - 1)) + 0.08 * cos((4 * PI * pos) / (N - 1)); + output[pos] = static_cast(out); + } + return; +} + +template +cudaError_t CalBlackmanWindow(const size_t size, const T *input, const bool periodic, S *output, + const uint32_t &device_id, cudaStream_t cuda_stream) { + const double PI = acos(-1); + T N = 0; + cudaMemcpy(&N, &input[0], sizeof(T), cudaMemcpyDeviceToHost); + if (N == 1) { + BlackmanWindowOne<<>>(size, N, PI, output); + } else { + N = periodic ? static_cast(N + 1) : static_cast(N); + BlackmanWindow<<>>(size, N, PI, output); + } + return GetCudaStatus(); +} + +template CUDA_LIB_EXPORT cudaError_t CalBlackmanWindow(const size_t size, const int *input, + const bool periodic, half *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalBlackmanWindow(const size_t size, const int64_t *input, + const bool periodic, half *output, + const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalBlackmanWindow(const size_t size, const int *input, + const bool periodic, float *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalBlackmanWindow(const size_t size, const int64_t *input, + const bool periodic, float *output, + const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalBlackmanWindow(const size_t size, const int *input, + const bool periodic, double *output, + const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalBlackmanWindow(const size_t size, const int64_t *input, + const bool periodic, double *output, + const uint32_t &device_id, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/blackman_window_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/blackman_window_impl.cuh index 97ea77310a9..4f2b1b76b61 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/blackman_window_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/blackman_window_impl.cuh @@ -1,24 +1,24 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BLACKMAN_WINDOW_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BLACKMAN_WINDOW_IMPL_CUH_ -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" - -template -CUDA_LIB_EXPORT cudaError_t CalBlackmanWindow(const size_t size, const T *input, const bool periodic, S *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BLACKMAN_WINDOW_IMPL_CUH_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BLACKMAN_WINDOW_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BLACKMAN_WINDOW_IMPL_CUH_ +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" + +template +CUDA_LIB_EXPORT cudaError_t CalBlackmanWindow(const size_t size, const T *input, const bool periodic, S *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BLACKMAN_WINDOW_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/check_numerics_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/check_numerics_impl.cu index 1de38335cbe..f1087f4da2c 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/check_numerics_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/check_numerics_impl.cu @@ -1,57 +1,57 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include "check_numerics_impl.cuh" - -template -__global__ void CheckNumerics(const size_t size, const T *input, int32_t *flag_device) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - if (isnan(input[pos])) { - flag_device[0] = 1; - } else if (isinf(input[pos])) { - flag_device[1] = 1; - } - } - return; -} - -template <> -__global__ void CheckNumerics(const size_t size, const half *input, int32_t *flag_device) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - if (isnan(__half2float(input[pos]))) { - flag_device[0] = 1; - } else if (isinf(__half2float(input[pos]))) { - flag_device[1] = 1; - } - } - return; -} - -template -cudaError_t CalCheckNumerics(const size_t size, const T *input, int32_t *flag_device, const uint32_t &device_id, - cudaStream_t cuda_stream) { - CheckNumerics<<>>(size, input, flag_device); - return GetCudaStatus(); -} - -template CUDA_LIB_EXPORT cudaError_t CalCheckNumerics(const size_t size, const half *input, int32_t *flag_device, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalCheckNumerics(const size_t size, const float *input, - int32_t *flag_device, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalCheckNumerics(const size_t size, const double *input, - int32_t *flag_device, const uint32_t &device_id, - cudaStream_t cuda_stream); +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "check_numerics_impl.cuh" + +template +__global__ void CheckNumerics(const size_t size, const T *input, int32_t *flag_device) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + if (isnan(input[pos])) { + flag_device[0] = 1; + } else if (isinf(input[pos])) { + flag_device[1] = 1; + } + } + return; +} + +template <> +__global__ void CheckNumerics(const size_t size, const half *input, int32_t *flag_device) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + if (isnan(__half2float(input[pos]))) { + flag_device[0] = 1; + } else if (isinf(__half2float(input[pos]))) { + flag_device[1] = 1; + } + } + return; +} + +template +cudaError_t CalCheckNumerics(const size_t size, const T *input, int32_t *flag_device, const uint32_t &device_id, + cudaStream_t cuda_stream) { + CheckNumerics<<>>(size, input, flag_device); + return GetCudaStatus(); +} + +template CUDA_LIB_EXPORT cudaError_t CalCheckNumerics(const size_t size, const half *input, int32_t *flag_device, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalCheckNumerics(const size_t size, const float *input, + int32_t *flag_device, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalCheckNumerics(const size_t size, const double *input, + int32_t *flag_device, const uint32_t &device_id, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/check_numerics_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/check_numerics_impl.cuh index 1099abbd897..6ecf1f4b67a 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/check_numerics_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/check_numerics_impl.cuh @@ -1,24 +1,24 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CHECK_NUMERICS_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CHECK_NUMERICS_IMPL_CUH_ -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" - -template -CUDA_LIB_EXPORT cudaError_t CalCheckNumerics(const size_t size, const T *input, int32_t *flag_device, - const uint32_t &device_id, cudaStream_t cuda_stream); -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CHECK_NUMERICS_IMPL_CUH_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CHECK_NUMERICS_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CHECK_NUMERICS_IMPL_CUH_ +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" + +template +CUDA_LIB_EXPORT cudaError_t CalCheckNumerics(const size_t size, const T *input, int32_t *flag_device, + const uint32_t &device_id, cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CHECK_NUMERICS_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/col2im_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/col2im_impl.cu index 6794527b943..436845689d6 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/col2im_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/col2im_impl.cu @@ -1,118 +1,118 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/col2im_impl.cuh" -#include -#include -#include -#include -#include "include/cuda_fp16.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h" - -template -using Complex = mindspore::utils::Complex; - -template -__global__ void Col2ImKernel(const T *input, T *output, const uint32_t num_kernels, const uint32_t per_batch_size, - const uint32_t per_channel_size, const uint32_t per_col_batch_size, - const uint32_t out_height, const uint32_t out_width, const uint32_t in_height, - const uint32_t in_width, const uint32_t kernel_height, const uint32_t kernel_width, - const uint32_t pad_height, const uint32_t pad_width, const uint32_t stride_height, - const uint32_t stride_width, const uint32_t dilation_height, - const uint32_t dilation_width) { - for (uint32_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_kernels; i += blockDim.x * gridDim.x) { - S val = static_cast(0); - uint32_t w_id = i % out_width + pad_width; - uint32_t h_id = i % per_batch_size / out_width % out_height + pad_height; - uint32_t c_id = i % per_batch_size / per_channel_size; - uint32_t n_col_offset = i / per_batch_size * per_col_batch_size; - uint32_t kernel_expand_h = (kernel_height - 1) * dilation_height + 1; - uint32_t kernel_expand_w = (kernel_width - 1) * dilation_width + 1; - // range coordinates - uint32_t out_height_start = h_id < kernel_expand_h ? 0 : (h_id - kernel_expand_h) / stride_height + 1; - uint32_t out_width_start = w_id < kernel_expand_w ? 0 : (w_id - kernel_expand_w) / stride_width + 1; - uint32_t out_height_end = min(h_id / stride_height + 1, in_height); - uint32_t out_width_end = min(w_id / stride_width + 1, in_width); - - for (uint32_t height = out_height_start; height < out_height_end; ++height) { - for (uint32_t width = out_width_start; width < out_width_end; ++width) { - uint32_t kernel_h = (h_id - height * stride_height); - uint32_t kernel_w = (w_id - width * stride_width); - if (kernel_h % dilation_height == 0 && kernel_w % dilation_width == 0) { - kernel_h /= dilation_height; - kernel_w /= dilation_width; - uint32_t data_index = - n_col_offset + - (((c_id * kernel_height + kernel_h) * kernel_width + kernel_w) * in_height + height) * in_width + width; - val += (S)input[data_index]; - } - } - } - output[i] = static_cast(val); - } -} - -template -cudaError_t Col2Im(const T *input, const uint32_t batch_size, const uint32_t channels, const uint32_t out_height, - const uint32_t out_width, const uint32_t in_height, const uint32_t in_width, - const uint32_t kernel_height, const uint32_t kernel_width, const uint32_t pad_height, - const uint32_t pad_width, const uint32_t stride_height, const uint32_t stride_width, - const uint32_t dilation_height, const uint32_t dilation_width, T *output, cudaStream_t cuda_stream) { - uint32_t per_channel_size = out_height * out_width; - uint32_t per_batch_size = channels * per_channel_size; - uint32_t num_kernels = batch_size * per_batch_size; - uint32_t per_col_batch_size = channels * in_height * in_width * kernel_width * kernel_height; - Col2ImKernel<<>>( - input, output, num_kernels, per_batch_size, per_channel_size, per_col_batch_size, out_height, out_width, in_height, - in_width, kernel_height, kernel_width, pad_height, pad_width, stride_height, stride_width, dilation_height, - dilation_width); - return GetCudaStatus(); -} - -template CUDA_LIB_EXPORT cudaError_t Col2Im( - const float *input, const uint32_t batch_size, const uint32_t channels, const uint32_t out_height, - const uint32_t out_width, const uint32_t in_height, const uint32_t in_width, const uint32_t kernel_height, - const uint32_t kernel_width, const uint32_t pad_height, const uint32_t pad_width, const uint32_t stride_height, - const uint32_t stride_width, const uint32_t dilation_height, const uint32_t dilation_width, float *output, - cudaStream_t cuda_stream); - -template CUDA_LIB_EXPORT cudaError_t Col2Im( - const half *input, const uint32_t batch_size, const uint32_t channels, const uint32_t out_height, - const uint32_t out_width, const uint32_t in_height, const uint32_t in_width, const uint32_t kernel_height, - const uint32_t kernel_width, const uint32_t pad_height, const uint32_t pad_width, const uint32_t stride_height, - const uint32_t stride_width, const uint32_t dilation_height, const uint32_t dilation_width, half *output, - cudaStream_t cuda_stream); - -template CUDA_LIB_EXPORT cudaError_t Col2Im( - const double *input, const uint32_t batch_size, const uint32_t channels, const uint32_t out_height, - const uint32_t out_width, const uint32_t in_height, const uint32_t in_width, const uint32_t kernel_height, - const uint32_t kernel_width, const uint32_t pad_height, const uint32_t pad_width, const uint32_t stride_height, - const uint32_t stride_width, const uint32_t dilation_height, const uint32_t dilation_width, double *output, - cudaStream_t cuda_stream); - -template CUDA_LIB_EXPORT cudaError_t Col2Im, Complex>( - const Complex *input, const uint32_t batch_size, const uint32_t channels, const uint32_t out_height, - const uint32_t out_width, const uint32_t in_height, const uint32_t in_width, const uint32_t kernel_height, - const uint32_t kernel_width, const uint32_t pad_height, const uint32_t pad_width, const uint32_t stride_height, - const uint32_t stride_width, const uint32_t dilation_height, const uint32_t dilation_width, Complex *output, - cudaStream_t cuda_stream); - -template CUDA_LIB_EXPORT cudaError_t Col2Im, Complex>( - const Complex *input, const uint32_t batch_size, const uint32_t channels, const uint32_t out_height, - const uint32_t out_width, const uint32_t in_height, const uint32_t in_width, const uint32_t kernel_height, - const uint32_t kernel_width, const uint32_t pad_height, const uint32_t pad_width, const uint32_t stride_height, - const uint32_t stride_width, const uint32_t dilation_height, const uint32_t dilation_width, Complex *output, - cudaStream_t cuda_stream); +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/col2im_impl.cuh" +#include +#include +#include +#include +#include "include/cuda_fp16.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h" + +template +using Complex = mindspore::utils::Complex; + +template +__global__ void Col2ImKernel(const T *input, T *output, const uint32_t num_kernels, const uint32_t per_batch_size, + const uint32_t per_channel_size, const uint32_t per_col_batch_size, + const uint32_t out_height, const uint32_t out_width, const uint32_t in_height, + const uint32_t in_width, const uint32_t kernel_height, const uint32_t kernel_width, + const uint32_t pad_height, const uint32_t pad_width, const uint32_t stride_height, + const uint32_t stride_width, const uint32_t dilation_height, + const uint32_t dilation_width) { + for (uint32_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_kernels; i += blockDim.x * gridDim.x) { + S val = static_cast(0); + uint32_t w_id = i % out_width + pad_width; + uint32_t h_id = i % per_batch_size / out_width % out_height + pad_height; + uint32_t c_id = i % per_batch_size / per_channel_size; + uint32_t n_col_offset = i / per_batch_size * per_col_batch_size; + uint32_t kernel_expand_h = (kernel_height - 1) * dilation_height + 1; + uint32_t kernel_expand_w = (kernel_width - 1) * dilation_width + 1; + // range coordinates + uint32_t out_height_start = h_id < kernel_expand_h ? 0 : (h_id - kernel_expand_h) / stride_height + 1; + uint32_t out_width_start = w_id < kernel_expand_w ? 0 : (w_id - kernel_expand_w) / stride_width + 1; + uint32_t out_height_end = min(h_id / stride_height + 1, in_height); + uint32_t out_width_end = min(w_id / stride_width + 1, in_width); + + for (uint32_t height = out_height_start; height < out_height_end; ++height) { + for (uint32_t width = out_width_start; width < out_width_end; ++width) { + uint32_t kernel_h = (h_id - height * stride_height); + uint32_t kernel_w = (w_id - width * stride_width); + if (kernel_h % dilation_height == 0 && kernel_w % dilation_width == 0) { + kernel_h /= dilation_height; + kernel_w /= dilation_width; + uint32_t data_index = + n_col_offset + + (((c_id * kernel_height + kernel_h) * kernel_width + kernel_w) * in_height + height) * in_width + width; + val += (S)input[data_index]; + } + } + } + output[i] = static_cast(val); + } +} + +template +cudaError_t Col2Im(const T *input, const uint32_t batch_size, const uint32_t channels, const uint32_t out_height, + const uint32_t out_width, const uint32_t in_height, const uint32_t in_width, + const uint32_t kernel_height, const uint32_t kernel_width, const uint32_t pad_height, + const uint32_t pad_width, const uint32_t stride_height, const uint32_t stride_width, + const uint32_t dilation_height, const uint32_t dilation_width, T *output, cudaStream_t cuda_stream) { + uint32_t per_channel_size = out_height * out_width; + uint32_t per_batch_size = channels * per_channel_size; + uint32_t num_kernels = batch_size * per_batch_size; + uint32_t per_col_batch_size = channels * in_height * in_width * kernel_width * kernel_height; + Col2ImKernel<<>>( + input, output, num_kernels, per_batch_size, per_channel_size, per_col_batch_size, out_height, out_width, in_height, + in_width, kernel_height, kernel_width, pad_height, pad_width, stride_height, stride_width, dilation_height, + dilation_width); + return GetCudaStatus(); +} + +template CUDA_LIB_EXPORT cudaError_t Col2Im( + const float *input, const uint32_t batch_size, const uint32_t channels, const uint32_t out_height, + const uint32_t out_width, const uint32_t in_height, const uint32_t in_width, const uint32_t kernel_height, + const uint32_t kernel_width, const uint32_t pad_height, const uint32_t pad_width, const uint32_t stride_height, + const uint32_t stride_width, const uint32_t dilation_height, const uint32_t dilation_width, float *output, + cudaStream_t cuda_stream); + +template CUDA_LIB_EXPORT cudaError_t Col2Im( + const half *input, const uint32_t batch_size, const uint32_t channels, const uint32_t out_height, + const uint32_t out_width, const uint32_t in_height, const uint32_t in_width, const uint32_t kernel_height, + const uint32_t kernel_width, const uint32_t pad_height, const uint32_t pad_width, const uint32_t stride_height, + const uint32_t stride_width, const uint32_t dilation_height, const uint32_t dilation_width, half *output, + cudaStream_t cuda_stream); + +template CUDA_LIB_EXPORT cudaError_t Col2Im( + const double *input, const uint32_t batch_size, const uint32_t channels, const uint32_t out_height, + const uint32_t out_width, const uint32_t in_height, const uint32_t in_width, const uint32_t kernel_height, + const uint32_t kernel_width, const uint32_t pad_height, const uint32_t pad_width, const uint32_t stride_height, + const uint32_t stride_width, const uint32_t dilation_height, const uint32_t dilation_width, double *output, + cudaStream_t cuda_stream); + +template CUDA_LIB_EXPORT cudaError_t Col2Im, Complex>( + const Complex *input, const uint32_t batch_size, const uint32_t channels, const uint32_t out_height, + const uint32_t out_width, const uint32_t in_height, const uint32_t in_width, const uint32_t kernel_height, + const uint32_t kernel_width, const uint32_t pad_height, const uint32_t pad_width, const uint32_t stride_height, + const uint32_t stride_width, const uint32_t dilation_height, const uint32_t dilation_width, Complex *output, + cudaStream_t cuda_stream); + +template CUDA_LIB_EXPORT cudaError_t Col2Im, Complex>( + const Complex *input, const uint32_t batch_size, const uint32_t channels, const uint32_t out_height, + const uint32_t out_width, const uint32_t in_height, const uint32_t in_width, const uint32_t kernel_height, + const uint32_t kernel_width, const uint32_t pad_height, const uint32_t pad_width, const uint32_t stride_height, + const uint32_t stride_width, const uint32_t dilation_height, const uint32_t dilation_width, Complex *output, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/col2im_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/col2im_impl.cuh index fe219c741dd..81f5c4a8572 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/col2im_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/col2im_impl.cuh @@ -1,28 +1,28 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_COL2IM_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_COL2IM_IMPL_CUH_ - -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" -template -CUDA_LIB_EXPORT cudaError_t Col2Im(const T *input, const uint32_t batch_size, const uint32_t channels, - const uint32_t out_height, const uint32_t out_width, const uint32_t in_height, - const uint32_t in_width, const uint32_t kernel_height, const uint32_t kernel_width, - const uint32_t pad_height, const uint32_t pad_width, const uint32_t stride_height, - const uint32_t stride_width, const uint32_t dilation_height, - const uint32_t dilation_width, T *output, cudaStream_t cuda_stream); -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_COL2IM_IMPL_CUH_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_COL2IM_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_COL2IM_IMPL_CUH_ + +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" +template +CUDA_LIB_EXPORT cudaError_t Col2Im(const T *input, const uint32_t batch_size, const uint32_t channels, + const uint32_t out_height, const uint32_t out_width, const uint32_t in_height, + const uint32_t in_width, const uint32_t kernel_height, const uint32_t kernel_width, + const uint32_t pad_height, const uint32_t pad_width, const uint32_t stride_height, + const uint32_t stride_width, const uint32_t dilation_height, + const uint32_t dilation_width, T *output, cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_COL2IM_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/ctcgreedydecoder_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/ctcgreedydecoder_impl.cu index e3463517e52..e8e090647e8 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/ctcgreedydecoder_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/ctcgreedydecoder_impl.cu @@ -1,162 +1,162 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ctcgreedydecoder_impl.cuh" -#include -#include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh" -template -__global__ void CTCGreedyDecoder(const T *input, const int bound, const size_t outer_size, const size_t batch_size, - int64_t *decoded_values_temp, T *log_probability) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < outer_size; pos += gridDim.x * blockDim.x) { - int idx = 0; - size_t input_offset = pos * bound; - T max_data = input[input_offset]; - for (int i = 1; i < bound; i++) { - input_offset = pos * bound + i; - auto input_data = input[input_offset]; - if (input_data > max_data) { - idx = i; - max_data = input_data; - } - } - decoded_values_temp[pos] = idx; - log_probability[pos] = -max_data; - } - return; -} - -template -__global__ void values_merge(int64_t *decoded_values_temp, const int32_t *sequence_length, const size_t batch_size, - const int bound, const bool merge_ok, T *log_probability, int64_t *nums_count) { - const int blank_idx = bound - 1; - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < batch_size; pos += gridDim.x * blockDim.x) { - if (sequence_length[pos] <= 0) { - nums_count[pos] = 0; - log_probability[pos] = 0; - nums_count[pos] = 0; - return; - } - size_t cnt = 0; - for (size_t i = 0, idx = pos; i < sequence_length[pos]; i++, idx += batch_size) { - if (idx != pos) { - log_probability[pos] += log_probability[idx]; - } - if (decoded_values_temp[idx] == blank_idx || - merge_ok && idx != pos && decoded_values_temp[idx] == decoded_values_temp[idx - batch_size]) { - continue; - } - decoded_values_temp[cnt * batch_size + pos] = decoded_values_temp[idx]; - cnt++; - } - nums_count[pos] = cnt; - } - return; -} - -__global__ void indicesCompute(const int64_t *decoded_values_temp, const int64_t *nums_count, const size_t batch_size, - int64_t *decoded_indices, int64_t *decoded_values, int64_t *decoded_shape, - int64_t *nums_count_pre_sum) { - for (size_t batch_pos = blockIdx.y * blockDim.y + threadIdx.y; batch_pos < batch_size; - batch_pos += gridDim.y * blockDim.y) { - for (size_t nums_count_pos = threadIdx.x; nums_count_pos < nums_count[batch_pos]; nums_count_pos += blockDim.x) { - decoded_indices[(nums_count_pre_sum[batch_pos] + nums_count_pos) * 2] = batch_pos; - decoded_indices[(nums_count_pre_sum[batch_pos] + nums_count_pos) * 2 + 1] = nums_count_pos; - decoded_values[nums_count_pre_sum[batch_pos] + nums_count_pos] = - decoded_values_temp[nums_count_pos * batch_size + batch_pos]; - } - if (threadIdx.x == 0) { - MsAtomicMax(decoded_shape + 1, nums_count[batch_pos]); - } - } - decoded_shape[0] = batch_size; -} - -template -cudaError_t CalCTCGreedyDecoder(const T *input, const int bound, const size_t outer_size, const size_t batch_size, - int64_t *decoded_values_temp, T *log_probability, const uint32_t &device_id, - cudaStream_t cuda_stream) { - CTCGreedyDecoder<<>>( - input, bound, outer_size, batch_size, decoded_values_temp, log_probability); - return GetCudaStatus(); -} - -template -cudaError_t Calmerge(int64_t *decoded_values_temp, const int32_t *sequence_length, const size_t batch_size, - const int bound, const bool merge_ok, T *log_probability, int64_t *nums_count, - const uint32_t &device_id, cudaStream_t cuda_stream) { - values_merge<<>>( - decoded_values_temp, sequence_length, batch_size, bound, merge_ok, log_probability, nums_count); - return GetCudaStatus(); -} - -cudaError_t Calindices(const int64_t *decoded_values_temp, const int64_t *nums_count, const size_t batch_size, - int64_t *decoded_indices, int64_t *decoded_values, int64_t *decoded_shape, const uint32_t &device_id, - cudaStream_t cuda_stream, int64_t *count) { - size_t temp_storage_bytes = 0; - int64_t *nums_count_pre_sum = nullptr; - cudaMalloc(&nums_count_pre_sum, sizeof(int64_t) * (batch_size + 1)); - cudaMemset(nums_count_pre_sum, 0, sizeof(int64_t) * (batch_size + 1)); - cudaMemset(decoded_shape, 0, sizeof(int64_t) * 2); - (void)cub::DeviceScan::InclusiveSum(nullptr, temp_storage_bytes, nums_count, nums_count_pre_sum + 1, - static_cast(batch_size), cuda_stream); - void *d_temp_storage = nullptr; - cudaStreamSynchronize(cuda_stream); - (void)cudaMalloc(&d_temp_storage, temp_storage_bytes); - (void)cub::DeviceScan::InclusiveSum(d_temp_storage, temp_storage_bytes, nums_count, nums_count_pre_sum + 1, - static_cast(batch_size), cuda_stream); - cudaStreamSynchronize(cuda_stream); - (void)cudaFree(d_temp_storage); - - int64_t sum_num_count = 0; - cudaMemcpy(&sum_num_count, nums_count_pre_sum + batch_size, sizeof(int64_t), cudaMemcpyDeviceToHost); - int64_t avg_num_count = sum_num_count / batch_size == 0 ? 1 : sum_num_count / batch_size; - size_t thread_x_num = avg_num_count > 32 ? 32 : avg_num_count; - size_t thread_y_num = 512 / thread_x_num; - - dim3 thread_num(thread_x_num, thread_y_num); - cudaDeviceProp prop; - (void)cudaGetDeviceProperties(&prop, device_id); - int max_blocks = prop.multiProcessorCount; - int block_num = - min(static_cast(((avg_num_count * batch_size - 1) / (thread_x_num * thread_y_num)) + 1), max_blocks); - - indicesCompute<<>>( - decoded_values_temp, nums_count, batch_size, decoded_indices, decoded_values, decoded_shape, nums_count_pre_sum); - cudaFree(nums_count_pre_sum); - *count = sum_num_count; - return GetCudaStatus(); -} - -template CUDA_LIB_EXPORT cudaError_t CalCTCGreedyDecoder(const float *input, const int bound, - const size_t outer_size, const size_t batch_size, - int64_t *decoded_values_temp, float *log_probability, - const uint32_t &device_id, cudaStream_t cuda_stream); - -template CUDA_LIB_EXPORT cudaError_t CalCTCGreedyDecoder(const double *input, const int bound, - const size_t outer_size, const size_t batch_size, - int64_t *decoded_values_temp, double *log_probability, - const uint32_t &device_id, cudaStream_t cuda_stream); - -template CUDA_LIB_EXPORT cudaError_t Calmerge(int64_t *decoded_values_temp, const int32_t *sequence_length, - const size_t batch_size, const int bound, const bool merge_ok, - float *log_probability, int64_t *nums_count, - const uint32_t &device_id, cudaStream_t cuda_stream); - -template CUDA_LIB_EXPORT cudaError_t Calmerge(int64_t *decoded_values_temp, const int32_t *sequence_length, - const size_t batch_size, const int bound, const bool merge_ok, - double *log_probability, int64_t *nums_count, - const uint32_t &device_id, cudaStream_t cuda_stream); +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ctcgreedydecoder_impl.cuh" +#include +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh" +template +__global__ void CTCGreedyDecoder(const T *input, const int bound, const size_t outer_size, const size_t batch_size, + int64_t *decoded_values_temp, T *log_probability) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < outer_size; pos += gridDim.x * blockDim.x) { + int idx = 0; + size_t input_offset = pos * bound; + T max_data = input[input_offset]; + for (int i = 1; i < bound; i++) { + input_offset = pos * bound + i; + auto input_data = input[input_offset]; + if (input_data > max_data) { + idx = i; + max_data = input_data; + } + } + decoded_values_temp[pos] = idx; + log_probability[pos] = -max_data; + } + return; +} + +template +__global__ void values_merge(int64_t *decoded_values_temp, const int32_t *sequence_length, const size_t batch_size, + const int bound, const bool merge_ok, T *log_probability, int64_t *nums_count) { + const int blank_idx = bound - 1; + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < batch_size; pos += gridDim.x * blockDim.x) { + if (sequence_length[pos] <= 0) { + nums_count[pos] = 0; + log_probability[pos] = 0; + nums_count[pos] = 0; + return; + } + size_t cnt = 0; + for (size_t i = 0, idx = pos; i < sequence_length[pos]; i++, idx += batch_size) { + if (idx != pos) { + log_probability[pos] += log_probability[idx]; + } + if (decoded_values_temp[idx] == blank_idx || + merge_ok && idx != pos && decoded_values_temp[idx] == decoded_values_temp[idx - batch_size]) { + continue; + } + decoded_values_temp[cnt * batch_size + pos] = decoded_values_temp[idx]; + cnt++; + } + nums_count[pos] = cnt; + } + return; +} + +__global__ void indicesCompute(const int64_t *decoded_values_temp, const int64_t *nums_count, const size_t batch_size, + int64_t *decoded_indices, int64_t *decoded_values, int64_t *decoded_shape, + int64_t *nums_count_pre_sum) { + for (size_t batch_pos = blockIdx.y * blockDim.y + threadIdx.y; batch_pos < batch_size; + batch_pos += gridDim.y * blockDim.y) { + for (size_t nums_count_pos = threadIdx.x; nums_count_pos < nums_count[batch_pos]; nums_count_pos += blockDim.x) { + decoded_indices[(nums_count_pre_sum[batch_pos] + nums_count_pos) * 2] = batch_pos; + decoded_indices[(nums_count_pre_sum[batch_pos] + nums_count_pos) * 2 + 1] = nums_count_pos; + decoded_values[nums_count_pre_sum[batch_pos] + nums_count_pos] = + decoded_values_temp[nums_count_pos * batch_size + batch_pos]; + } + if (threadIdx.x == 0) { + MsAtomicMax(decoded_shape + 1, nums_count[batch_pos]); + } + } + decoded_shape[0] = batch_size; +} + +template +cudaError_t CalCTCGreedyDecoder(const T *input, const int bound, const size_t outer_size, const size_t batch_size, + int64_t *decoded_values_temp, T *log_probability, const uint32_t &device_id, + cudaStream_t cuda_stream) { + CTCGreedyDecoder<<>>( + input, bound, outer_size, batch_size, decoded_values_temp, log_probability); + return GetCudaStatus(); +} + +template +cudaError_t Calmerge(int64_t *decoded_values_temp, const int32_t *sequence_length, const size_t batch_size, + const int bound, const bool merge_ok, T *log_probability, int64_t *nums_count, + const uint32_t &device_id, cudaStream_t cuda_stream) { + values_merge<<>>( + decoded_values_temp, sequence_length, batch_size, bound, merge_ok, log_probability, nums_count); + return GetCudaStatus(); +} + +cudaError_t Calindices(const int64_t *decoded_values_temp, const int64_t *nums_count, const size_t batch_size, + int64_t *decoded_indices, int64_t *decoded_values, int64_t *decoded_shape, const uint32_t &device_id, + cudaStream_t cuda_stream, int64_t *count) { + size_t temp_storage_bytes = 0; + int64_t *nums_count_pre_sum = nullptr; + cudaMalloc(&nums_count_pre_sum, sizeof(int64_t) * (batch_size + 1)); + cudaMemset(nums_count_pre_sum, 0, sizeof(int64_t) * (batch_size + 1)); + cudaMemset(decoded_shape, 0, sizeof(int64_t) * 2); + (void)cub::DeviceScan::InclusiveSum(nullptr, temp_storage_bytes, nums_count, nums_count_pre_sum + 1, + static_cast(batch_size), cuda_stream); + void *d_temp_storage = nullptr; + cudaStreamSynchronize(cuda_stream); + (void)cudaMalloc(&d_temp_storage, temp_storage_bytes); + (void)cub::DeviceScan::InclusiveSum(d_temp_storage, temp_storage_bytes, nums_count, nums_count_pre_sum + 1, + static_cast(batch_size), cuda_stream); + cudaStreamSynchronize(cuda_stream); + (void)cudaFree(d_temp_storage); + + int64_t sum_num_count = 0; + cudaMemcpy(&sum_num_count, nums_count_pre_sum + batch_size, sizeof(int64_t), cudaMemcpyDeviceToHost); + int64_t avg_num_count = sum_num_count / batch_size == 0 ? 1 : sum_num_count / batch_size; + size_t thread_x_num = avg_num_count > 32 ? 32 : avg_num_count; + size_t thread_y_num = 512 / thread_x_num; + + dim3 thread_num(thread_x_num, thread_y_num); + cudaDeviceProp prop; + (void)cudaGetDeviceProperties(&prop, device_id); + int max_blocks = prop.multiProcessorCount; + int block_num = + min(static_cast(((avg_num_count * batch_size - 1) / (thread_x_num * thread_y_num)) + 1), max_blocks); + + indicesCompute<<>>( + decoded_values_temp, nums_count, batch_size, decoded_indices, decoded_values, decoded_shape, nums_count_pre_sum); + cudaFree(nums_count_pre_sum); + *count = sum_num_count; + return GetCudaStatus(); +} + +template CUDA_LIB_EXPORT cudaError_t CalCTCGreedyDecoder(const float *input, const int bound, + const size_t outer_size, const size_t batch_size, + int64_t *decoded_values_temp, float *log_probability, + const uint32_t &device_id, cudaStream_t cuda_stream); + +template CUDA_LIB_EXPORT cudaError_t CalCTCGreedyDecoder(const double *input, const int bound, + const size_t outer_size, const size_t batch_size, + int64_t *decoded_values_temp, double *log_probability, + const uint32_t &device_id, cudaStream_t cuda_stream); + +template CUDA_LIB_EXPORT cudaError_t Calmerge(int64_t *decoded_values_temp, const int32_t *sequence_length, + const size_t batch_size, const int bound, const bool merge_ok, + float *log_probability, int64_t *nums_count, + const uint32_t &device_id, cudaStream_t cuda_stream); + +template CUDA_LIB_EXPORT cudaError_t Calmerge(int64_t *decoded_values_temp, const int32_t *sequence_length, + const size_t batch_size, const int bound, const bool merge_ok, + double *log_probability, int64_t *nums_count, + const uint32_t &device_id, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/ctcgreedydecoder_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/ctcgreedydecoder_impl.cuh index 358844d8e8c..41ca35f5ce0 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/ctcgreedydecoder_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/ctcgreedydecoder_impl.cuh @@ -1,37 +1,37 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CTC_GREEDY_DECODER_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CTC_GREEDY_DECODER_IMPL_CUH_ -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" - -template -CUDA_LIB_EXPORT cudaError_t CalCTCGreedyDecoder(const T *input, const int bound, const size_t outer_size, - const size_t batch_size, int64_t *decoded_values_temp, - T *log_probability, const uint32_t &device_id, - cudaStream_t cuda_stream); - -template -CUDA_LIB_EXPORT cudaError_t Calmerge(int64_t *decoded_values_temp, const int32_t *sequence_length, - const size_t batch_size, const int bound, const bool merge_ok, T *log_probability, - int64_t *nums_count, const uint32_t &device_id, cudaStream_t cuda_stream); - -CUDA_LIB_EXPORT cudaError_t Calindices(const int64_t *decoded_values_temp, const int64_t *nums_count, - const size_t batch_size, int64_t *decoded_indices, int64_t *decoded_values, - int64_t *decoded_shape, const uint32_t &device_id, cudaStream_t cuda_stream, - int64_t *count); - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CTC_GREEDY_DECODER_IMPL_CUH_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CTC_GREEDY_DECODER_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CTC_GREEDY_DECODER_IMPL_CUH_ +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" + +template +CUDA_LIB_EXPORT cudaError_t CalCTCGreedyDecoder(const T *input, const int bound, const size_t outer_size, + const size_t batch_size, int64_t *decoded_values_temp, + T *log_probability, const uint32_t &device_id, + cudaStream_t cuda_stream); + +template +CUDA_LIB_EXPORT cudaError_t Calmerge(int64_t *decoded_values_temp, const int32_t *sequence_length, + const size_t batch_size, const int bound, const bool merge_ok, T *log_probability, + int64_t *nums_count, const uint32_t &device_id, cudaStream_t cuda_stream); + +CUDA_LIB_EXPORT cudaError_t Calindices(const int64_t *decoded_values_temp, const int64_t *nums_count, + const size_t batch_size, int64_t *decoded_indices, int64_t *decoded_values, + int64_t *decoded_shape, const uint32_t &device_id, cudaStream_t cuda_stream, + int64_t *count); + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CTC_GREEDY_DECODER_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/ctcloss_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/ctcloss_impl.cu index c2ab47aa73f..cc1c423de79 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/ctcloss_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/ctcloss_impl.cu @@ -1,460 +1,460 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "ctcloss_impl.cuh" -template -__device__ T LogSumExp(const T logprob1, const T logprob2) { - if (logprob1 == logprob2 && logprob1 == -std::numeric_limits::infinity()) { - return logprob1; - } else { - return (logprob1 > logprob2) ? logprob1 + log1pf(expf(logprob2 - logprob1)) - : logprob2 + log1pf(expf(logprob1 - logprob2)); - } -} - -template -__global__ void CalculateFwdVarKernel(T *log_alpha_b, int *label_value_with_blank, T *softmax_probs, - const int *sequence_length, bool ctc_merge_repeated, int batch, int SOffSet, - int maxtime, int blank, int *label_squence_length, int *cum_labels_length, - bool ignore_longer_outputs_than_inputs) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch; i += blockDim.x * gridDim.x) { - if (sequence_length[i] == 0 || - (ignore_longer_outputs_than_inputs && label_squence_length[i] > sequence_length[i])) { - } else { - T *log_alpha_b_cur = &log_alpha_b[i * SOffSet * maxtime]; - int *label_value_with_blank_cur = &label_value_with_blank[0]; - if (i > 0) { - label_value_with_blank_cur = &label_value_with_blank[2 * cum_labels_length[i - 1] + i]; - } - int numclass = blank + 1; - int U = 2 * label_squence_length[i] + 1; - int Ti = sequence_length[i]; - int low = 0; - int high = 0; - log_alpha_b_cur[0] = log(softmax_probs[i * numclass + blank]); - int label0 = blank; - if (U > 1) { - label0 = label_value_with_blank_cur[1]; - log_alpha_b_cur[maxtime] = log(softmax_probs[i * numclass + label0]); - } - for (int t = 1; t < Ti; ++t) { - low = 0; - high = U; - int low_limit = U - (2 * (Ti - t)); - int high_limit = 2 * (t + 1); - if (low_limit > low) { - low = low_limit; - } - if (high_limit < U) { - high = high_limit; - } - for (int u = low; u < high; ++u) { - T sum_log_alpha = -std::numeric_limits::infinity(); - if (ctc_merge_repeated || label_value_with_blank_cur[u] == blank) { - sum_log_alpha = log_alpha_b_cur[u * maxtime + t - 1]; - } - if (u > 0) { - sum_log_alpha = LogSumExp(sum_log_alpha, log_alpha_b_cur[(u - 1) * maxtime + t - 1]); - } - if (u > 1) { - const bool matching_labels_merge = - ctc_merge_repeated && (label_value_with_blank_cur[u] == label_value_with_blank_cur[u - 2]); - if (label_value_with_blank_cur[u] != blank && !matching_labels_merge) { - sum_log_alpha = LogSumExp(sum_log_alpha, log_alpha_b_cur[(u - 2) * maxtime + t - 1]); - } - } - log_alpha_b_cur[u * maxtime + t] = - log(softmax_probs[i * numclass + label_value_with_blank_cur[u] + t * numclass * batch]) + sum_log_alpha; - } - } - } - } -} - -template -__global__ void CalculateBwdVarKernel(T *log_beta_b, int *label_value_with_blank, T *softmax_probs, - const int *sequence_length, bool ctc_merge_repeated, int batch, int SOffSet, - int maxtime, int blank, int *label_squence_length, int *cum_labels_length, - bool ignore_longer_outputs_than_inputs) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch; i += blockDim.x * gridDim.x) { - if (sequence_length[i] == 0 || - (ignore_longer_outputs_than_inputs && label_squence_length[i] > sequence_length[i])) { - } else { - T *log_beta_b_cur = &log_beta_b[i * SOffSet * maxtime]; - int *label_value_with_blank_cur = &label_value_with_blank[0]; - if (i > 0) { - label_value_with_blank_cur = &label_value_with_blank[2 * cum_labels_length[i - 1] + i]; - } - int numclass = blank + 1; - int U = 2 * label_squence_length[i] + 1; - int Ti = sequence_length[i]; - int low = 0; - int high = 0; - if (U > 1) { - for (int u = U - 2; u < U; ++u) { - log_beta_b_cur[u * maxtime + Ti - 1] = 0; - } - } else { - log_beta_b_cur[Ti - 1] = 0; - log_beta_b_cur[Ti - 2] = 0; - } - for (int t = Ti - 2; t >= 0; --t) { - low = 0; - high = U; - int low_limit = U - (2 * (Ti - t)); - int high_limit = 2 * (t + 1); - if (low_limit > low) { - low = low_limit; - } - if (high_limit < U) { - high = high_limit; - } - for (int u = low; u < high; ++u) { - if (ctc_merge_repeated || label_value_with_blank_cur[u] == blank) { - log_beta_b_cur[u * maxtime + t] = LogSumExp( - log_beta_b_cur[u * maxtime + t], - log_beta_b_cur[u * maxtime + t + 1] + - log(softmax_probs[i * numclass + label_value_with_blank_cur[u] + (t + 1) * numclass * batch])); - } - if (u + 1 < U) { - log_beta_b_cur[u * maxtime + t] = LogSumExp( - log_beta_b_cur[u * maxtime + t], - log_beta_b_cur[(u + 1) * maxtime + t + 1] + - log(softmax_probs[i * numclass + label_value_with_blank_cur[u + 1] + (t + 1) * numclass * batch])); - } - if (u + 2 < U) { - const bool matching_labels_merge = - ctc_merge_repeated && (label_value_with_blank_cur[u] == label_value_with_blank_cur[u + 2]); - if (label_value_with_blank_cur[u] != blank && !matching_labels_merge) { - log_beta_b_cur[u * maxtime + t] = LogSumExp( - log_beta_b_cur[u * maxtime + t], - log_beta_b_cur[(u + 2) * maxtime + t + 1] + - log(softmax_probs[i * numclass + label_value_with_blank_cur[u + 2] + (t + 1) * numclass * batch])); - } - } - } - } - } - } -} - -template -__global__ void ProbInitKernel(T *prob_num, int size) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) { - prob_num[i] = -std::numeric_limits::infinity(); - } -} -template -__global__ void LogBInitKernel(T *log_b, int log_prob_size) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < log_prob_size; i += blockDim.x * gridDim.x) { - log_b[i] = -std::numeric_limits::infinity(); - } -} - -template -__global__ void CTCLossKernel(T *log_alpha_b, T *log_beta_b, T *softmax_probs, int *label_value_with_blank, int batch, - int SOffSet, int maxtime, int numclass, const int *sequence_length, - int *label_squence_length, int *cum_labels_length, T *cost, T *grads, T *prob_num, - bool ignore_longer_outputs_than_inputs) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch; i += blockDim.x * gridDim.x) { - if (sequence_length[i] == 0 || - (ignore_longer_outputs_than_inputs && label_squence_length[i] > sequence_length[i])) { - } else { - T *grad_cur = &grads[i * numclass]; - const T *softmax_probs_cur = &softmax_probs[i * numclass]; - T *prob_num_cur = &prob_num[i * numclass]; - int U = 2 * label_squence_length[i] + 1; - T log_pzx = -std::numeric_limits::infinity(); - const T *log_alpha_b_cur = &log_alpha_b[i * SOffSet * maxtime]; - const T *log_beta_b_cur = &log_beta_b[i * SOffSet * maxtime]; - int *label_value_with_blank_cur = &label_value_with_blank[0]; - if (i > 0) { - label_value_with_blank_cur = &label_value_with_blank[2 * cum_labels_length[i - 1] + i]; - } - for (int u = 0; u < U; ++u) { - log_pzx = LogSumExp(log_pzx, log_alpha_b_cur[u * maxtime] + log_beta_b_cur[u * maxtime]); - } - cost[i] = -log_pzx; - // grad - int L = numclass; - int Ti = sequence_length[i]; - if (log_pzx == -std::numeric_limits::infinity()) { - for (int t = 0; t < Ti; ++t) { - for (int l = 0; l < L; ++l) { - grad_cur[t * numclass * batch + l] = softmax_probs_cur[t * numclass * batch + l]; - } - } - } else { - for (int t = 0; t < Ti; ++t) { - for (int u = 0; u < U; ++u) { - int l = label_value_with_blank_cur[u]; - prob_num_cur[t * batch * numclass + l] = - LogSumExp(prob_num_cur[t * batch * numclass + l], - log_alpha_b_cur[u * maxtime + t] + log_beta_b_cur[u * maxtime + t]); - } - for (int l = 0; l < L; ++l) { - grad_cur[t * numclass * batch + l] = - softmax_probs_cur[t * numclass * batch + l] - expf(prob_num_cur[t * batch * numclass + l] - log_pzx); - } - } - } - } - } -} - -template -__global__ void InnerSoftMaxKernel(const T *probs, T *softmax_probs, const int *sequence_length, int max_time, - int batch, int numclass) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch * max_time; i += blockDim.x * gridDim.x) { - int k = i / batch; - int m = i % batch; - if (k < sequence_length[m]) { - T maxCoeff = 0.; - T sumCoeff = 0.; - for (int j = i * numclass; j < (i + 1) * numclass; ++j) { - if (probs[j] > maxCoeff) { - maxCoeff = probs[j]; - } - } - for (int j = i * numclass; j < (i + 1) * numclass; ++j) { - sumCoeff += exp(probs[j] - maxCoeff); - softmax_probs[j] = exp(probs[j] - maxCoeff); - } - for (int j = i * numclass; j < (i + 1) * numclass; ++j) { - softmax_probs[j] /= sumCoeff; - } - } - } -} - -__global__ void GenLabelValuePCRKernel(int *label_value_sp, int *label_value_pcr, int *label_squence_length, - int *cum_labels_length, int batch) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch; i += blockDim.x * gridDim.x) { - int L = label_squence_length[i]; - label_squence_length[i] = 0; - int offset = 0; - if (i > 0) { - offset = cum_labels_length[i - 1]; - } - for (int l = offset; l < L; ++l) { - if (l == offset || label_value_sp[l] != label_value_sp[l - 1]) { - label_value_pcr[offset + label_squence_length[i]++] = label_value_sp[l]; - } - } - } -} - -__global__ void UpdateLengthKernel(int *label_squence_length, int *cum_labels_length, int *max_labels_length, - int batch) { - max_labels_length[0] = 0; - for (int i = 0; i < batch; ++i) { - if (label_squence_length[i] > max_labels_length[0]) { - max_labels_length[0] = label_squence_length[i]; - } - if (i == 0) { - cum_labels_length[i] = label_squence_length[i]; - } else { - cum_labels_length[i] = label_squence_length[i] + cum_labels_length[i - 1]; - } - } -} - -template -cudaError_t CalculateBwdVar(T *log_beta_b, int *label_value_with_blank, T *softmax_probs, const int *sequence_length, - bool ctc_merge_repeated, int batch, int SOffSet, int maxtime, int blank, - int *label_squence_length, int *cum_labels_length, bool ignore_longer_outputs_than_inputs, - cudaStream_t stream) { - int log_prob_size = SOffSet * batch * maxtime; - LogBInitKernel<<>>(log_beta_b, log_prob_size); - CalculateBwdVarKernel<<>>( - log_beta_b, label_value_with_blank, softmax_probs, sequence_length, ctc_merge_repeated, batch, SOffSet, maxtime, - blank, label_squence_length, cum_labels_length, ignore_longer_outputs_than_inputs); - return GetCudaStatus(); -} - -template -cudaError_t CalculateFwdVar(T *log_alpha_b, int *label_value_with_blank, T *softmax_probs, const int *sequence_length, - bool ctc_merge_repeated, int batch, int SOffSet, int maxtime, int blank, - int *label_squence_length, int *cum_labels_length, bool ignore_longer_outputs_than_inputs, - cudaStream_t stream) { - int log_prob_size = SOffSet * batch * maxtime; - LogBInitKernel<<>>(log_alpha_b, log_prob_size); - CalculateFwdVarKernel<<>>( - log_alpha_b, label_value_with_blank, softmax_probs, sequence_length, ctc_merge_repeated, batch, SOffSet, maxtime, - blank, label_squence_length, cum_labels_length, ignore_longer_outputs_than_inputs); - return GetCudaStatus(); -} - -template -cudaError_t InnerSoftMax(const T *probs, T *softmax_probs, const int *sequence_length, int max_time, int batch, - int numclass, cudaStream_t stream) { - InnerSoftMaxKernel<<>>(probs, softmax_probs, sequence_length, - max_time, batch, numclass); - return GetCudaStatus(); -} - -__global__ void GenLabelWithBlankKernel(int *label_value, int *label_value_with_blank, int *label_squence_length, - int *precum_labels_length, int *cum_labels_length, int batch, int blank) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch; i += blockDim.x * gridDim.x) { - int offset = 0; - int offset1 = 0; - if (i > 0) { - offset = 2 * cum_labels_length[i - 1] + i; - offset1 = precum_labels_length[i - 1]; - } - for (int j = 0; j < label_squence_length[i]; ++j) { - label_value_with_blank[offset + 2 * j] = blank; - label_value_with_blank[offset + 2 * j + 1] = label_value[offset1 + j]; - } - label_value_with_blank[offset + 2 * label_squence_length[i]] = blank; - } -} - -cudaError_t GenLabelWithBlank(int *label_value, int *label_value_with_blank, int *label_squence_length, - int *precum_labels_length, int *cum_labels_length, int batch, int blank, - cudaStream_t stream) { - GenLabelWithBlankKernel<<>>( - label_value, label_value_with_blank, label_squence_length, precum_labels_length, cum_labels_length, batch, blank); - return GetCudaStatus(); -} - -cudaError_t GenLabelValuePCR(int *label_value_sp, int *label_value_pcr, int *label_squence_length, - int *cum_labels_length, int *max_labels_length, int batch, cudaStream_t stream) { - GenLabelValuePCRKernel<<>>(label_value_sp, label_value_pcr, - label_squence_length, cum_labels_length, batch); - UpdateLengthKernel<<<1, 1, 0, stream>>>(label_squence_length, cum_labels_length, max_labels_length, batch); - return GetCudaStatus(); -} - -__global__ void GenLabelValueKernel(int *label_value_sp, const int64_t *label_indices, const int *label_values, - int *label_squence_length, int *cum_labels_length, int size) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) { - int64_t b = label_indices[i * 2]; - int offset = 0; - if (b > 0) { - offset = cum_labels_length[b - 1]; - } - int64_t index = offset + label_indices[i * 2 + 1]; - label_value_sp[index] = label_values[i]; - } -} -__global__ void LabelValueInitKernel(int *label_value_sp, int size, int blank) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) { - label_value_sp[i] = blank; - } -} -__global__ void RecalculateLengthKernel(int *label_value_sp, int *label_squence_length, int *cum_labels_length, - int batch, int blank) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch; i += blockDim.x * gridDim.x) { - int offset = 0; - if (i > 0) { - offset = cum_labels_length[i - 1]; - } - int L = label_squence_length[i]; - label_squence_length[i] = 0; - for (int j = offset; j < offset + L; ++j) { - if (label_value_sp[j] >= blank) { - break; - } else { - label_squence_length[i]++; - } - } - } -} -cudaError_t GenLabelValue(int *label_value_sp, const int64_t *label_indices, const int *label_values, - int *label_squence_length, int *cum_labels_length, int *max_labels_length, int size, - int blank, int batch, cudaStream_t stream) { - LabelValueInitKernel<<>>(label_value_sp, size, blank); - GenLabelValueKernel<<>>(label_value_sp, label_indices, label_values, - label_squence_length, cum_labels_length, size); - RecalculateLengthKernel<<>>(label_value_sp, label_squence_length, - cum_labels_length, batch, blank); - UpdateLengthKernel<<<1, 1, 0, stream>>>(label_squence_length, cum_labels_length, max_labels_length, batch); - return GetCudaStatus(); -} - -__global__ void CalculatePreLengthKernel(int *label_squence_length, int *precum_labels_length, int *cum_labels_length, - int *max_labels_length, const int64_t *label_indices, int batch, int size) { - max_labels_length[0] = 0; - for (int i = 0; i < size; ++i) { - label_squence_length[label_indices[i * 2]]++; - if (max_labels_length[0] < label_indices[i * 2]) { - max_labels_length[0] = label_indices[i * 2]; - } - } - precum_labels_length[0] = label_squence_length[0]; - cum_labels_length[0] = label_squence_length[0]; - for (int i = 1; i < batch; ++i) { - cum_labels_length[i] = cum_labels_length[i - 1] + label_squence_length[i]; - precum_labels_length[i] = precum_labels_length[i - 1] + label_squence_length[i]; - } -} - -__global__ void CalculateMaxSequenceKernel(const int *sequence_length, int *max_labels_length, int batch) { - max_labels_length[0] = 0; - for (int i = 0; i < batch; ++i) { - if (sequence_length[i] > max_labels_length[0]) { - max_labels_length[0] = sequence_length[i]; - } - } -} - -cudaError_t CalculateMaxSequence(const int *sequence_length, int *max_labels_length, int batch, cudaStream_t stream) { - CalculateMaxSequenceKernel<<<1, 1, 0, stream>>>(sequence_length, max_labels_length, batch); - return GetCudaStatus(); -} - -cudaError_t CalculatePreLength(int *label_squence_length, int *precum_labels_length, int *cum_labels_length, - int *max_labels_length, const int64_t *label_indices, int batch, int size, - cudaStream_t stream) { - CalculatePreLengthKernel<<<1, 1, 0, stream>>>(label_squence_length, precum_labels_length, cum_labels_length, - max_labels_length, label_indices, batch, size); - return GetCudaStatus(); -} - -template -cudaError_t CTCLoss(T *log_alpha_b, T *log_beta_b, T *softmax_probs, int *label_value_with_blank, int batch, - int SOffSet, int maxtime, int numclass, const int *sequence_length, int *label_squence_length, - int *cum_labels_length, T *cost, T *grads, T *prob_num, bool ignore_longer_outputs_than_inputs, - cudaStream_t stream) { - ProbInitKernel<<>>(prob_num, - maxtime * batch * numclass); - CTCLossKernel<<>>( - log_alpha_b, log_beta_b, softmax_probs, label_value_with_blank, batch, SOffSet, maxtime, numclass, sequence_length, - label_squence_length, cum_labels_length, cost, grads, prob_num, ignore_longer_outputs_than_inputs); - return GetCudaStatus(); -} - -template CUDA_LIB_EXPORT cudaError_t CalculateFwdVar( - float *log_alpha_b, int *label_value_with_blank, float *softmax_probs, const int *sequence_length, - bool ctc_merge_repeated, int batch, int SOffSet, int maxtime, int blank, int *label_squence_length, - int *cum_labels_length, bool ignore_longer_outputs_than_inputs, cudaStream_t stream); - -template CUDA_LIB_EXPORT cudaError_t CalculateBwdVar( - float *log_beta_b, int *label_value_with_blank, float *softmax_probs, const int *sequence_length, - bool ctc_merge_repeated, int batch, int SOffSet, int maxtime, int blank, int *label_squence_length, - int *cum_labels_length, bool ignore_longer_outputs_than_inputs, cudaStream_t stream); - -template CUDA_LIB_EXPORT cudaError_t InnerSoftMax(const float *probs, float *softmax_probs, - const int *sequence_length, int max_time, int batch, - int numclass, cudaStream_t stream); - -template CUDA_LIB_EXPORT cudaError_t CTCLoss(float *log_alpha_b, float *log_beta_b, float *softmax_probs, - int *label_value_with_blank, int batch, int SOffSet, int maxtime, - int numclass, const int *sequence_length, int *label_squence_length, - int *cum_labels_length, float *cost, float *grads, float *prob_num, - bool ignore_longer_outputs_than_inputs, cudaStream_t stream); +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "ctcloss_impl.cuh" +template +__device__ T LogSumExp(const T logprob1, const T logprob2) { + if (logprob1 == logprob2 && logprob1 == -std::numeric_limits::infinity()) { + return logprob1; + } else { + return (logprob1 > logprob2) ? logprob1 + log1pf(expf(logprob2 - logprob1)) + : logprob2 + log1pf(expf(logprob1 - logprob2)); + } +} + +template +__global__ void CalculateFwdVarKernel(T *log_alpha_b, int *label_value_with_blank, T *softmax_probs, + const int *sequence_length, bool ctc_merge_repeated, int batch, int SOffSet, + int maxtime, int blank, int *label_squence_length, int *cum_labels_length, + bool ignore_longer_outputs_than_inputs) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch; i += blockDim.x * gridDim.x) { + if (sequence_length[i] == 0 || + (ignore_longer_outputs_than_inputs && label_squence_length[i] > sequence_length[i])) { + } else { + T *log_alpha_b_cur = &log_alpha_b[i * SOffSet * maxtime]; + int *label_value_with_blank_cur = &label_value_with_blank[0]; + if (i > 0) { + label_value_with_blank_cur = &label_value_with_blank[2 * cum_labels_length[i - 1] + i]; + } + int numclass = blank + 1; + int U = 2 * label_squence_length[i] + 1; + int Ti = sequence_length[i]; + int low = 0; + int high = 0; + log_alpha_b_cur[0] = log(softmax_probs[i * numclass + blank]); + int label0 = blank; + if (U > 1) { + label0 = label_value_with_blank_cur[1]; + log_alpha_b_cur[maxtime] = log(softmax_probs[i * numclass + label0]); + } + for (int t = 1; t < Ti; ++t) { + low = 0; + high = U; + int low_limit = U - (2 * (Ti - t)); + int high_limit = 2 * (t + 1); + if (low_limit > low) { + low = low_limit; + } + if (high_limit < U) { + high = high_limit; + } + for (int u = low; u < high; ++u) { + T sum_log_alpha = -std::numeric_limits::infinity(); + if (ctc_merge_repeated || label_value_with_blank_cur[u] == blank) { + sum_log_alpha = log_alpha_b_cur[u * maxtime + t - 1]; + } + if (u > 0) { + sum_log_alpha = LogSumExp(sum_log_alpha, log_alpha_b_cur[(u - 1) * maxtime + t - 1]); + } + if (u > 1) { + const bool matching_labels_merge = + ctc_merge_repeated && (label_value_with_blank_cur[u] == label_value_with_blank_cur[u - 2]); + if (label_value_with_blank_cur[u] != blank && !matching_labels_merge) { + sum_log_alpha = LogSumExp(sum_log_alpha, log_alpha_b_cur[(u - 2) * maxtime + t - 1]); + } + } + log_alpha_b_cur[u * maxtime + t] = + log(softmax_probs[i * numclass + label_value_with_blank_cur[u] + t * numclass * batch]) + sum_log_alpha; + } + } + } + } +} + +template +__global__ void CalculateBwdVarKernel(T *log_beta_b, int *label_value_with_blank, T *softmax_probs, + const int *sequence_length, bool ctc_merge_repeated, int batch, int SOffSet, + int maxtime, int blank, int *label_squence_length, int *cum_labels_length, + bool ignore_longer_outputs_than_inputs) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch; i += blockDim.x * gridDim.x) { + if (sequence_length[i] == 0 || + (ignore_longer_outputs_than_inputs && label_squence_length[i] > sequence_length[i])) { + } else { + T *log_beta_b_cur = &log_beta_b[i * SOffSet * maxtime]; + int *label_value_with_blank_cur = &label_value_with_blank[0]; + if (i > 0) { + label_value_with_blank_cur = &label_value_with_blank[2 * cum_labels_length[i - 1] + i]; + } + int numclass = blank + 1; + int U = 2 * label_squence_length[i] + 1; + int Ti = sequence_length[i]; + int low = 0; + int high = 0; + if (U > 1) { + for (int u = U - 2; u < U; ++u) { + log_beta_b_cur[u * maxtime + Ti - 1] = 0; + } + } else { + log_beta_b_cur[Ti - 1] = 0; + log_beta_b_cur[Ti - 2] = 0; + } + for (int t = Ti - 2; t >= 0; --t) { + low = 0; + high = U; + int low_limit = U - (2 * (Ti - t)); + int high_limit = 2 * (t + 1); + if (low_limit > low) { + low = low_limit; + } + if (high_limit < U) { + high = high_limit; + } + for (int u = low; u < high; ++u) { + if (ctc_merge_repeated || label_value_with_blank_cur[u] == blank) { + log_beta_b_cur[u * maxtime + t] = LogSumExp( + log_beta_b_cur[u * maxtime + t], + log_beta_b_cur[u * maxtime + t + 1] + + log(softmax_probs[i * numclass + label_value_with_blank_cur[u] + (t + 1) * numclass * batch])); + } + if (u + 1 < U) { + log_beta_b_cur[u * maxtime + t] = LogSumExp( + log_beta_b_cur[u * maxtime + t], + log_beta_b_cur[(u + 1) * maxtime + t + 1] + + log(softmax_probs[i * numclass + label_value_with_blank_cur[u + 1] + (t + 1) * numclass * batch])); + } + if (u + 2 < U) { + const bool matching_labels_merge = + ctc_merge_repeated && (label_value_with_blank_cur[u] == label_value_with_blank_cur[u + 2]); + if (label_value_with_blank_cur[u] != blank && !matching_labels_merge) { + log_beta_b_cur[u * maxtime + t] = LogSumExp( + log_beta_b_cur[u * maxtime + t], + log_beta_b_cur[(u + 2) * maxtime + t + 1] + + log(softmax_probs[i * numclass + label_value_with_blank_cur[u + 2] + (t + 1) * numclass * batch])); + } + } + } + } + } + } +} + +template +__global__ void ProbInitKernel(T *prob_num, int size) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) { + prob_num[i] = -std::numeric_limits::infinity(); + } +} +template +__global__ void LogBInitKernel(T *log_b, int log_prob_size) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < log_prob_size; i += blockDim.x * gridDim.x) { + log_b[i] = -std::numeric_limits::infinity(); + } +} + +template +__global__ void CTCLossKernel(T *log_alpha_b, T *log_beta_b, T *softmax_probs, int *label_value_with_blank, int batch, + int SOffSet, int maxtime, int numclass, const int *sequence_length, + int *label_squence_length, int *cum_labels_length, T *cost, T *grads, T *prob_num, + bool ignore_longer_outputs_than_inputs) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch; i += blockDim.x * gridDim.x) { + if (sequence_length[i] == 0 || + (ignore_longer_outputs_than_inputs && label_squence_length[i] > sequence_length[i])) { + } else { + T *grad_cur = &grads[i * numclass]; + const T *softmax_probs_cur = &softmax_probs[i * numclass]; + T *prob_num_cur = &prob_num[i * numclass]; + int U = 2 * label_squence_length[i] + 1; + T log_pzx = -std::numeric_limits::infinity(); + const T *log_alpha_b_cur = &log_alpha_b[i * SOffSet * maxtime]; + const T *log_beta_b_cur = &log_beta_b[i * SOffSet * maxtime]; + int *label_value_with_blank_cur = &label_value_with_blank[0]; + if (i > 0) { + label_value_with_blank_cur = &label_value_with_blank[2 * cum_labels_length[i - 1] + i]; + } + for (int u = 0; u < U; ++u) { + log_pzx = LogSumExp(log_pzx, log_alpha_b_cur[u * maxtime] + log_beta_b_cur[u * maxtime]); + } + cost[i] = -log_pzx; + // grad + int L = numclass; + int Ti = sequence_length[i]; + if (log_pzx == -std::numeric_limits::infinity()) { + for (int t = 0; t < Ti; ++t) { + for (int l = 0; l < L; ++l) { + grad_cur[t * numclass * batch + l] = softmax_probs_cur[t * numclass * batch + l]; + } + } + } else { + for (int t = 0; t < Ti; ++t) { + for (int u = 0; u < U; ++u) { + int l = label_value_with_blank_cur[u]; + prob_num_cur[t * batch * numclass + l] = + LogSumExp(prob_num_cur[t * batch * numclass + l], + log_alpha_b_cur[u * maxtime + t] + log_beta_b_cur[u * maxtime + t]); + } + for (int l = 0; l < L; ++l) { + grad_cur[t * numclass * batch + l] = + softmax_probs_cur[t * numclass * batch + l] - expf(prob_num_cur[t * batch * numclass + l] - log_pzx); + } + } + } + } + } +} + +template +__global__ void InnerSoftMaxKernel(const T *probs, T *softmax_probs, const int *sequence_length, int max_time, + int batch, int numclass) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch * max_time; i += blockDim.x * gridDim.x) { + int k = i / batch; + int m = i % batch; + if (k < sequence_length[m]) { + T maxCoeff = 0.; + T sumCoeff = 0.; + for (int j = i * numclass; j < (i + 1) * numclass; ++j) { + if (probs[j] > maxCoeff) { + maxCoeff = probs[j]; + } + } + for (int j = i * numclass; j < (i + 1) * numclass; ++j) { + sumCoeff += exp(probs[j] - maxCoeff); + softmax_probs[j] = exp(probs[j] - maxCoeff); + } + for (int j = i * numclass; j < (i + 1) * numclass; ++j) { + softmax_probs[j] /= sumCoeff; + } + } + } +} + +__global__ void GenLabelValuePCRKernel(int *label_value_sp, int *label_value_pcr, int *label_squence_length, + int *cum_labels_length, int batch) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch; i += blockDim.x * gridDim.x) { + int L = label_squence_length[i]; + label_squence_length[i] = 0; + int offset = 0; + if (i > 0) { + offset = cum_labels_length[i - 1]; + } + for (int l = offset; l < L; ++l) { + if (l == offset || label_value_sp[l] != label_value_sp[l - 1]) { + label_value_pcr[offset + label_squence_length[i]++] = label_value_sp[l]; + } + } + } +} + +__global__ void UpdateLengthKernel(int *label_squence_length, int *cum_labels_length, int *max_labels_length, + int batch) { + max_labels_length[0] = 0; + for (int i = 0; i < batch; ++i) { + if (label_squence_length[i] > max_labels_length[0]) { + max_labels_length[0] = label_squence_length[i]; + } + if (i == 0) { + cum_labels_length[i] = label_squence_length[i]; + } else { + cum_labels_length[i] = label_squence_length[i] + cum_labels_length[i - 1]; + } + } +} + +template +cudaError_t CalculateBwdVar(T *log_beta_b, int *label_value_with_blank, T *softmax_probs, const int *sequence_length, + bool ctc_merge_repeated, int batch, int SOffSet, int maxtime, int blank, + int *label_squence_length, int *cum_labels_length, bool ignore_longer_outputs_than_inputs, + cudaStream_t stream) { + int log_prob_size = SOffSet * batch * maxtime; + LogBInitKernel<<>>(log_beta_b, log_prob_size); + CalculateBwdVarKernel<<>>( + log_beta_b, label_value_with_blank, softmax_probs, sequence_length, ctc_merge_repeated, batch, SOffSet, maxtime, + blank, label_squence_length, cum_labels_length, ignore_longer_outputs_than_inputs); + return GetCudaStatus(); +} + +template +cudaError_t CalculateFwdVar(T *log_alpha_b, int *label_value_with_blank, T *softmax_probs, const int *sequence_length, + bool ctc_merge_repeated, int batch, int SOffSet, int maxtime, int blank, + int *label_squence_length, int *cum_labels_length, bool ignore_longer_outputs_than_inputs, + cudaStream_t stream) { + int log_prob_size = SOffSet * batch * maxtime; + LogBInitKernel<<>>(log_alpha_b, log_prob_size); + CalculateFwdVarKernel<<>>( + log_alpha_b, label_value_with_blank, softmax_probs, sequence_length, ctc_merge_repeated, batch, SOffSet, maxtime, + blank, label_squence_length, cum_labels_length, ignore_longer_outputs_than_inputs); + return GetCudaStatus(); +} + +template +cudaError_t InnerSoftMax(const T *probs, T *softmax_probs, const int *sequence_length, int max_time, int batch, + int numclass, cudaStream_t stream) { + InnerSoftMaxKernel<<>>(probs, softmax_probs, sequence_length, + max_time, batch, numclass); + return GetCudaStatus(); +} + +__global__ void GenLabelWithBlankKernel(int *label_value, int *label_value_with_blank, int *label_squence_length, + int *precum_labels_length, int *cum_labels_length, int batch, int blank) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch; i += blockDim.x * gridDim.x) { + int offset = 0; + int offset1 = 0; + if (i > 0) { + offset = 2 * cum_labels_length[i - 1] + i; + offset1 = precum_labels_length[i - 1]; + } + for (int j = 0; j < label_squence_length[i]; ++j) { + label_value_with_blank[offset + 2 * j] = blank; + label_value_with_blank[offset + 2 * j + 1] = label_value[offset1 + j]; + } + label_value_with_blank[offset + 2 * label_squence_length[i]] = blank; + } +} + +cudaError_t GenLabelWithBlank(int *label_value, int *label_value_with_blank, int *label_squence_length, + int *precum_labels_length, int *cum_labels_length, int batch, int blank, + cudaStream_t stream) { + GenLabelWithBlankKernel<<>>( + label_value, label_value_with_blank, label_squence_length, precum_labels_length, cum_labels_length, batch, blank); + return GetCudaStatus(); +} + +cudaError_t GenLabelValuePCR(int *label_value_sp, int *label_value_pcr, int *label_squence_length, + int *cum_labels_length, int *max_labels_length, int batch, cudaStream_t stream) { + GenLabelValuePCRKernel<<>>(label_value_sp, label_value_pcr, + label_squence_length, cum_labels_length, batch); + UpdateLengthKernel<<<1, 1, 0, stream>>>(label_squence_length, cum_labels_length, max_labels_length, batch); + return GetCudaStatus(); +} + +__global__ void GenLabelValueKernel(int *label_value_sp, const int64_t *label_indices, const int *label_values, + int *label_squence_length, int *cum_labels_length, int size) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) { + int64_t b = label_indices[i * 2]; + int offset = 0; + if (b > 0) { + offset = cum_labels_length[b - 1]; + } + int64_t index = offset + label_indices[i * 2 + 1]; + label_value_sp[index] = label_values[i]; + } +} +__global__ void LabelValueInitKernel(int *label_value_sp, int size, int blank) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) { + label_value_sp[i] = blank; + } +} +__global__ void RecalculateLengthKernel(int *label_value_sp, int *label_squence_length, int *cum_labels_length, + int batch, int blank) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch; i += blockDim.x * gridDim.x) { + int offset = 0; + if (i > 0) { + offset = cum_labels_length[i - 1]; + } + int L = label_squence_length[i]; + label_squence_length[i] = 0; + for (int j = offset; j < offset + L; ++j) { + if (label_value_sp[j] >= blank) { + break; + } else { + label_squence_length[i]++; + } + } + } +} +cudaError_t GenLabelValue(int *label_value_sp, const int64_t *label_indices, const int *label_values, + int *label_squence_length, int *cum_labels_length, int *max_labels_length, int size, + int blank, int batch, cudaStream_t stream) { + LabelValueInitKernel<<>>(label_value_sp, size, blank); + GenLabelValueKernel<<>>(label_value_sp, label_indices, label_values, + label_squence_length, cum_labels_length, size); + RecalculateLengthKernel<<>>(label_value_sp, label_squence_length, + cum_labels_length, batch, blank); + UpdateLengthKernel<<<1, 1, 0, stream>>>(label_squence_length, cum_labels_length, max_labels_length, batch); + return GetCudaStatus(); +} + +__global__ void CalculatePreLengthKernel(int *label_squence_length, int *precum_labels_length, int *cum_labels_length, + int *max_labels_length, const int64_t *label_indices, int batch, int size) { + max_labels_length[0] = 0; + for (int i = 0; i < size; ++i) { + label_squence_length[label_indices[i * 2]]++; + if (max_labels_length[0] < label_indices[i * 2]) { + max_labels_length[0] = label_indices[i * 2]; + } + } + precum_labels_length[0] = label_squence_length[0]; + cum_labels_length[0] = label_squence_length[0]; + for (int i = 1; i < batch; ++i) { + cum_labels_length[i] = cum_labels_length[i - 1] + label_squence_length[i]; + precum_labels_length[i] = precum_labels_length[i - 1] + label_squence_length[i]; + } +} + +__global__ void CalculateMaxSequenceKernel(const int *sequence_length, int *max_labels_length, int batch) { + max_labels_length[0] = 0; + for (int i = 0; i < batch; ++i) { + if (sequence_length[i] > max_labels_length[0]) { + max_labels_length[0] = sequence_length[i]; + } + } +} + +cudaError_t CalculateMaxSequence(const int *sequence_length, int *max_labels_length, int batch, cudaStream_t stream) { + CalculateMaxSequenceKernel<<<1, 1, 0, stream>>>(sequence_length, max_labels_length, batch); + return GetCudaStatus(); +} + +cudaError_t CalculatePreLength(int *label_squence_length, int *precum_labels_length, int *cum_labels_length, + int *max_labels_length, const int64_t *label_indices, int batch, int size, + cudaStream_t stream) { + CalculatePreLengthKernel<<<1, 1, 0, stream>>>(label_squence_length, precum_labels_length, cum_labels_length, + max_labels_length, label_indices, batch, size); + return GetCudaStatus(); +} + +template +cudaError_t CTCLoss(T *log_alpha_b, T *log_beta_b, T *softmax_probs, int *label_value_with_blank, int batch, + int SOffSet, int maxtime, int numclass, const int *sequence_length, int *label_squence_length, + int *cum_labels_length, T *cost, T *grads, T *prob_num, bool ignore_longer_outputs_than_inputs, + cudaStream_t stream) { + ProbInitKernel<<>>(prob_num, + maxtime * batch * numclass); + CTCLossKernel<<>>( + log_alpha_b, log_beta_b, softmax_probs, label_value_with_blank, batch, SOffSet, maxtime, numclass, sequence_length, + label_squence_length, cum_labels_length, cost, grads, prob_num, ignore_longer_outputs_than_inputs); + return GetCudaStatus(); +} + +template CUDA_LIB_EXPORT cudaError_t CalculateFwdVar( + float *log_alpha_b, int *label_value_with_blank, float *softmax_probs, const int *sequence_length, + bool ctc_merge_repeated, int batch, int SOffSet, int maxtime, int blank, int *label_squence_length, + int *cum_labels_length, bool ignore_longer_outputs_than_inputs, cudaStream_t stream); + +template CUDA_LIB_EXPORT cudaError_t CalculateBwdVar( + float *log_beta_b, int *label_value_with_blank, float *softmax_probs, const int *sequence_length, + bool ctc_merge_repeated, int batch, int SOffSet, int maxtime, int blank, int *label_squence_length, + int *cum_labels_length, bool ignore_longer_outputs_than_inputs, cudaStream_t stream); + +template CUDA_LIB_EXPORT cudaError_t InnerSoftMax(const float *probs, float *softmax_probs, + const int *sequence_length, int max_time, int batch, + int numclass, cudaStream_t stream); + +template CUDA_LIB_EXPORT cudaError_t CTCLoss(float *log_alpha_b, float *log_beta_b, float *softmax_probs, + int *label_value_with_blank, int batch, int SOffSet, int maxtime, + int numclass, const int *sequence_length, int *label_squence_length, + int *cum_labels_length, float *cost, float *grads, float *prob_num, + bool ignore_longer_outputs_than_inputs, cudaStream_t stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/ctcloss_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/ctcloss_impl.cuh index c79eb4fb371..fa777051a68 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/ctcloss_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/ctcloss_impl.cuh @@ -1,59 +1,59 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CTCLOSS_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CTCLOSS_IMPL_CUH_ -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" - -template -CUDA_LIB_EXPORT cudaError_t CalculateFwdVar(T *log_alpha_b, int *label_value_with_blank, T *softmax_probs, - const int *sequence_length, bool ctc_merge_repeated, int batch, int SOffSet, - int maxtime, int blank, int *label_squence_length, int *cum_labels_length, - bool ignore_longer_outputs_than_inputs, cudaStream_t stream); - -template -CUDA_LIB_EXPORT cudaError_t CalculateBwdVar(T *log_beta_b, int *label_value_with_blank, T *softmax_probs, - const int *sequence_length, bool ctc_merge_repeated, int batch, int SOffSet, - int maxtime, int blank, int *label_squence_length, int *cum_labels_length, - bool ignore_longer_outputs_than_inputs, cudaStream_t stream); - -template -CUDA_LIB_EXPORT cudaError_t InnerSoftMax(const T *probs, T *softmax_cost, const int *sequence_length, int max_time, - int batch, int numclass, cudaStream_t stream); - -CUDA_LIB_EXPORT cudaError_t GenLabelValuePCR(int *label_value_sp, int *label_value_pcr, int *label_squence_length, - int *cum_labels_length, int *max_labels_length, int batch, - cudaStream_t stream); - -CUDA_LIB_EXPORT cudaError_t GenLabelWithBlank(int *label_value, int *label_value_with_blank, int *label_squence_length, - int *precum_labels_length, int *cum_labels_length, int batch, int blank, - cudaStream_t stream); - -CUDA_LIB_EXPORT cudaError_t GenLabelValue(int *label_value_sp, const int64_t *label_indices, const int *label_values, - int *label_squence_length, int *cum_labels_length, int *max_labels_length, - int size, int blank, int batch, cudaStream_t stream); - -CUDA_LIB_EXPORT cudaError_t CalculatePreLength(int *label_squence_length, int *precum_labels_length, - int *cum_labels_length, int *max_labels_length, - const int64_t *label_indices, int batch, int size, cudaStream_t stream); -CUDA_LIB_EXPORT cudaError_t CalculateMaxSequence(const int *sequence_length, int *max_labels_length, int batch, - cudaStream_t stream); -template -CUDA_LIB_EXPORT cudaError_t CTCLoss(T *log_alpha_b, T *log_beta_b, T *softmax_probs, int *label_value_with_blank, - int batch, int SOffSet, int maxtime, int numclass, const int *sequence_length, - int *label_squence_length, int *cum_labels_length, T *cost, T *grads, T *prob_num, - bool ignore_longer_outputs_than_inputs, cudaStream_t stream); -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CTCLOSS_IMPL_CUH_ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CTCLOSS_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CTCLOSS_IMPL_CUH_ +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" + +template +CUDA_LIB_EXPORT cudaError_t CalculateFwdVar(T *log_alpha_b, int *label_value_with_blank, T *softmax_probs, + const int *sequence_length, bool ctc_merge_repeated, int batch, int SOffSet, + int maxtime, int blank, int *label_squence_length, int *cum_labels_length, + bool ignore_longer_outputs_than_inputs, cudaStream_t stream); + +template +CUDA_LIB_EXPORT cudaError_t CalculateBwdVar(T *log_beta_b, int *label_value_with_blank, T *softmax_probs, + const int *sequence_length, bool ctc_merge_repeated, int batch, int SOffSet, + int maxtime, int blank, int *label_squence_length, int *cum_labels_length, + bool ignore_longer_outputs_than_inputs, cudaStream_t stream); + +template +CUDA_LIB_EXPORT cudaError_t InnerSoftMax(const T *probs, T *softmax_cost, const int *sequence_length, int max_time, + int batch, int numclass, cudaStream_t stream); + +CUDA_LIB_EXPORT cudaError_t GenLabelValuePCR(int *label_value_sp, int *label_value_pcr, int *label_squence_length, + int *cum_labels_length, int *max_labels_length, int batch, + cudaStream_t stream); + +CUDA_LIB_EXPORT cudaError_t GenLabelWithBlank(int *label_value, int *label_value_with_blank, int *label_squence_length, + int *precum_labels_length, int *cum_labels_length, int batch, int blank, + cudaStream_t stream); + +CUDA_LIB_EXPORT cudaError_t GenLabelValue(int *label_value_sp, const int64_t *label_indices, const int *label_values, + int *label_squence_length, int *cum_labels_length, int *max_labels_length, + int size, int blank, int batch, cudaStream_t stream); + +CUDA_LIB_EXPORT cudaError_t CalculatePreLength(int *label_squence_length, int *precum_labels_length, + int *cum_labels_length, int *max_labels_length, + const int64_t *label_indices, int batch, int size, cudaStream_t stream); +CUDA_LIB_EXPORT cudaError_t CalculateMaxSequence(const int *sequence_length, int *max_labels_length, int batch, + cudaStream_t stream); +template +CUDA_LIB_EXPORT cudaError_t CTCLoss(T *log_alpha_b, T *log_beta_b, T *softmax_probs, int *label_value_with_blank, + int batch, int SOffSet, int maxtime, int numclass, const int *sequence_length, + int *label_squence_length, int *cum_labels_length, T *cost, T *grads, T *prob_num, + bool ignore_longer_outputs_than_inputs, cudaStream_t stream); +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CTCLOSS_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/cumprod_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/cumprod_impl.cu index 1f729050f79..896fc73efd1 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/cumprod_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/cumprod_impl.cu @@ -1,190 +1,190 @@ -/** - * Copyright 2020-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "cumprod_impl.cuh" -#include "include/cuda_fp16.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h" - -template -__global__ void Copy(T *input, T *output, size_t size) { - size_t step = blockDim.x * gridDim.x; - for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < size; write_index += step) { - input[write_index] = output[write_index]; - } -} - -template -__global__ void LeftMoveProd(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, - size_t stride2) { - size_t num = dim0 * dim2; - size_t i, k, offset; - size_t step = blockDim.x * gridDim.x; - for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; write_index += step) { - i = write_index / dim2 % dim0; - k = write_index % dim2; - offset = i * stride + k; - for (size_t j = 0; j < dim1; ++j) { - size_t read_index = j * stride2 + offset; - if (j == 0) { - output[read_index] = 1; - } else { - size_t read_index2 = (j - 1) * stride2 + offset; - output[read_index] = input[read_index2]; - } - } - } -} - -template -__global__ void RightMoveProd(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, - size_t stride2) { - size_t num = dim0 * dim2; - size_t i, k, offset; - size_t step = blockDim.x * gridDim.x; - for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; write_index += step) { - i = write_index / dim2 % dim0; - k = write_index % dim2; - offset = i * stride + k; - for (int j = dim1 - 1; j >= 0; --j) { - size_t read_index = j * stride2 + offset; - if (j == dim1 - 1) { - output[read_index] = 1; - } else { - size_t read_index2 = (j + 1) * stride2 + offset; - output[read_index] = input[read_index2]; - } - } - } -} -template -__global__ void CumProdKernelReverse(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, - size_t stride2) { - size_t num = dim0 * dim2; - size_t i, k, offset; - size_t step = blockDim.x * gridDim.x; - for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; write_index += step) { - i = write_index / dim2 % dim0; - k = write_index % dim2; - offset = i * stride + k; - for (int j = dim1 - 1; j >= 0; --j) { - size_t read_index = j * stride2 + offset; - if (j == dim1 - 1) { - output[read_index] = input[read_index]; - } else { - size_t read_index2 = (j + 1) * stride2 + offset; - output[read_index] = output[read_index2] * input[read_index]; - } - } - } -} - -template -__global__ void CumProdKernel(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, - size_t stride2) { - size_t num = dim0 * dim2; - size_t i, k, offset; - size_t step = blockDim.x * gridDim.x; - for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; write_index += step) { - i = write_index / dim2 % dim0; - k = write_index % dim2; - offset = i * stride + k; - for (size_t j = 0; j < dim1; ++j) { - size_t read_index = j * stride2 + offset; - if (j == 0) { - output[read_index] = input[read_index]; - } else { - size_t read_index2 = (j - 1) * stride2 + offset; - output[read_index] = output[read_index2] * input[read_index]; - } - } - } -} -template -cudaError_t CumProd(const T *input, T *output, T *workspace, size_t dim0, size_t dim1, size_t dim2, size_t stride, - size_t stride2, bool exclusive_, bool reverse_, cudaStream_t stream) { - int size = dim0 * dim2; - if (exclusive_) { - if (reverse_) { - RightMoveProd<<>>(input, output, dim0, dim1, dim2, stride, stride2); - Copy<<>>(workspace, output, size * dim1); - CumProdKernelReverse<<>>(workspace, output, dim0, dim1, dim2, stride, - stride2); - } else { - LeftMoveProd<<>>(input, output, dim0, dim1, dim2, stride, stride2); - Copy<<>>(workspace, output, size * dim1); - CumProdKernel<<>>(workspace, output, dim0, dim1, dim2, stride, stride2); - } - } else { - if (reverse_) { - CumProdKernelReverse<<>>(input, output, dim0, dim1, dim2, stride, - stride2); - } else { - CumProdKernel<<>>(input, output, dim0, dim1, dim2, stride, stride2); - } - } - return GetCudaStatus(); -} - -template CUDA_LIB_EXPORT cudaError_t CumProd(const uint8_t *input, uint8_t *output, uint8_t *workspace, - size_t dim0, size_t dim1, size_t dim2, size_t stride, - size_t stride2, bool exclusive_, bool reverse_, - cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t CumProd(const uint16_t *input, uint16_t *output, uint16_t *workspace, - size_t dim0, size_t dim1, size_t dim2, size_t stride, - size_t stride2, bool exclusive_, bool reverse_, - cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t CumProd(const uint32_t *input, uint32_t *output, uint32_t *workspace, - size_t dim0, size_t dim1, size_t dim2, size_t stride, - size_t stride2, bool exclusive_, bool reverse_, - cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t CumProd(const uint64_t *input, uint64_t *output, uint64_t *workspace, - size_t dim0, size_t dim1, size_t dim2, size_t stride, - size_t stride2, bool exclusive_, bool reverse_, - cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t CumProd(const int8_t *input, int8_t *output, int8_t *workspace, - size_t dim0, size_t dim1, size_t dim2, size_t stride, - size_t stride2, bool exclusive_, bool reverse_, - cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t CumProd(const int16_t *input, int16_t *output, int16_t *workspace, - size_t dim0, size_t dim1, size_t dim2, size_t stride, - size_t stride2, bool exclusive_, bool reverse_, - cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t CumProd(const int32_t *input, int32_t *output, int32_t *workspace, - size_t dim0, size_t dim1, size_t dim2, size_t stride, - size_t stride2, bool exclusive_, bool reverse_, - cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t CumProd(const int64_t *input, int64_t *output, int64_t *workspace, - size_t dim0, size_t dim1, size_t dim2, size_t stride, - size_t stride2, bool exclusive_, bool reverse_, - cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t CumProd(const double *input, double *output, double *workspace, - size_t dim0, size_t dim1, size_t dim2, size_t stride, - size_t stride2, bool exclusive_, bool reverse_, - cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t CumProd(const float *input, float *output, float *workspace, size_t dim0, - size_t dim1, size_t dim2, size_t stride, size_t stride2, - bool exclusive_, bool reverse_, cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t CumProd(const half *input, half *output, half *workspace, size_t dim0, - size_t dim1, size_t dim2, size_t stride, size_t stride2, - bool exclusive_, bool reverse_, cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t CumProd>(const Complex *input, Complex *output, - Complex *workspace, size_t dim0, size_t dim1, - size_t dim2, size_t stride, size_t stride2, - bool exclusive_, bool reverse_, cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t CumProd>(const Complex *input, Complex *output, - Complex *workspace, size_t dim0, size_t dim1, - size_t dim2, size_t stride, size_t stride2, - bool exclusive_, bool reverse_, cudaStream_t stream); +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cumprod_impl.cuh" +#include "include/cuda_fp16.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h" + +template +__global__ void Copy(T *input, T *output, size_t size) { + size_t step = blockDim.x * gridDim.x; + for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < size; write_index += step) { + input[write_index] = output[write_index]; + } +} + +template +__global__ void LeftMoveProd(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, + size_t stride2) { + size_t num = dim0 * dim2; + size_t i, k, offset; + size_t step = blockDim.x * gridDim.x; + for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; write_index += step) { + i = write_index / dim2 % dim0; + k = write_index % dim2; + offset = i * stride + k; + for (size_t j = 0; j < dim1; ++j) { + size_t read_index = j * stride2 + offset; + if (j == 0) { + output[read_index] = 1; + } else { + size_t read_index2 = (j - 1) * stride2 + offset; + output[read_index] = input[read_index2]; + } + } + } +} + +template +__global__ void RightMoveProd(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, + size_t stride2) { + size_t num = dim0 * dim2; + size_t i, k, offset; + size_t step = blockDim.x * gridDim.x; + for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; write_index += step) { + i = write_index / dim2 % dim0; + k = write_index % dim2; + offset = i * stride + k; + for (int j = dim1 - 1; j >= 0; --j) { + size_t read_index = j * stride2 + offset; + if (j == dim1 - 1) { + output[read_index] = 1; + } else { + size_t read_index2 = (j + 1) * stride2 + offset; + output[read_index] = input[read_index2]; + } + } + } +} +template +__global__ void CumProdKernelReverse(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, + size_t stride2) { + size_t num = dim0 * dim2; + size_t i, k, offset; + size_t step = blockDim.x * gridDim.x; + for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; write_index += step) { + i = write_index / dim2 % dim0; + k = write_index % dim2; + offset = i * stride + k; + for (int j = dim1 - 1; j >= 0; --j) { + size_t read_index = j * stride2 + offset; + if (j == dim1 - 1) { + output[read_index] = input[read_index]; + } else { + size_t read_index2 = (j + 1) * stride2 + offset; + output[read_index] = output[read_index2] * input[read_index]; + } + } + } +} + +template +__global__ void CumProdKernel(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, + size_t stride2) { + size_t num = dim0 * dim2; + size_t i, k, offset; + size_t step = blockDim.x * gridDim.x; + for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; write_index += step) { + i = write_index / dim2 % dim0; + k = write_index % dim2; + offset = i * stride + k; + for (size_t j = 0; j < dim1; ++j) { + size_t read_index = j * stride2 + offset; + if (j == 0) { + output[read_index] = input[read_index]; + } else { + size_t read_index2 = (j - 1) * stride2 + offset; + output[read_index] = output[read_index2] * input[read_index]; + } + } + } +} +template +cudaError_t CumProd(const T *input, T *output, T *workspace, size_t dim0, size_t dim1, size_t dim2, size_t stride, + size_t stride2, bool exclusive_, bool reverse_, cudaStream_t stream) { + int size = dim0 * dim2; + if (exclusive_) { + if (reverse_) { + RightMoveProd<<>>(input, output, dim0, dim1, dim2, stride, stride2); + Copy<<>>(workspace, output, size * dim1); + CumProdKernelReverse<<>>(workspace, output, dim0, dim1, dim2, stride, + stride2); + } else { + LeftMoveProd<<>>(input, output, dim0, dim1, dim2, stride, stride2); + Copy<<>>(workspace, output, size * dim1); + CumProdKernel<<>>(workspace, output, dim0, dim1, dim2, stride, stride2); + } + } else { + if (reverse_) { + CumProdKernelReverse<<>>(input, output, dim0, dim1, dim2, stride, + stride2); + } else { + CumProdKernel<<>>(input, output, dim0, dim1, dim2, stride, stride2); + } + } + return GetCudaStatus(); +} + +template CUDA_LIB_EXPORT cudaError_t CumProd(const uint8_t *input, uint8_t *output, uint8_t *workspace, + size_t dim0, size_t dim1, size_t dim2, size_t stride, + size_t stride2, bool exclusive_, bool reverse_, + cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t CumProd(const uint16_t *input, uint16_t *output, uint16_t *workspace, + size_t dim0, size_t dim1, size_t dim2, size_t stride, + size_t stride2, bool exclusive_, bool reverse_, + cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t CumProd(const uint32_t *input, uint32_t *output, uint32_t *workspace, + size_t dim0, size_t dim1, size_t dim2, size_t stride, + size_t stride2, bool exclusive_, bool reverse_, + cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t CumProd(const uint64_t *input, uint64_t *output, uint64_t *workspace, + size_t dim0, size_t dim1, size_t dim2, size_t stride, + size_t stride2, bool exclusive_, bool reverse_, + cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t CumProd(const int8_t *input, int8_t *output, int8_t *workspace, + size_t dim0, size_t dim1, size_t dim2, size_t stride, + size_t stride2, bool exclusive_, bool reverse_, + cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t CumProd(const int16_t *input, int16_t *output, int16_t *workspace, + size_t dim0, size_t dim1, size_t dim2, size_t stride, + size_t stride2, bool exclusive_, bool reverse_, + cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t CumProd(const int32_t *input, int32_t *output, int32_t *workspace, + size_t dim0, size_t dim1, size_t dim2, size_t stride, + size_t stride2, bool exclusive_, bool reverse_, + cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t CumProd(const int64_t *input, int64_t *output, int64_t *workspace, + size_t dim0, size_t dim1, size_t dim2, size_t stride, + size_t stride2, bool exclusive_, bool reverse_, + cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t CumProd(const double *input, double *output, double *workspace, + size_t dim0, size_t dim1, size_t dim2, size_t stride, + size_t stride2, bool exclusive_, bool reverse_, + cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t CumProd(const float *input, float *output, float *workspace, size_t dim0, + size_t dim1, size_t dim2, size_t stride, size_t stride2, + bool exclusive_, bool reverse_, cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t CumProd(const half *input, half *output, half *workspace, size_t dim0, + size_t dim1, size_t dim2, size_t stride, size_t stride2, + bool exclusive_, bool reverse_, cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t CumProd>(const Complex *input, Complex *output, + Complex *workspace, size_t dim0, size_t dim1, + size_t dim2, size_t stride, size_t stride2, + bool exclusive_, bool reverse_, cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t CumProd>(const Complex *input, Complex *output, + Complex *workspace, size_t dim0, size_t dim1, + size_t dim2, size_t stride, size_t stride2, + bool exclusive_, bool reverse_, cudaStream_t stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/cumprod_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/cumprod_impl.cuh index 283e8af388a..f38d97d237c 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/cumprod_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/cumprod_impl.cuh @@ -1,24 +1,24 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CUMPROD_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CUMPROD_IMPL_CUH_ -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h" -template -CUDA_LIB_EXPORT cudaError_t CumProd(const T *input, T *output, T *workspace, size_t dim0, size_t dim1, size_t dim2, - size_t stride, size_t stride2, bool exclusive_, bool reverse_, cudaStream_t stream); -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CUMPROD_IMPL_CUH_ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CUMPROD_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CUMPROD_IMPL_CUH_ +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h" +template +CUDA_LIB_EXPORT cudaError_t CumProd(const T *input, T *output, T *workspace, size_t dim0, size_t dim1, size_t dim2, + size_t stride, size_t stride2, bool exclusive_, bool reverse_, cudaStream_t stream); +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CUMPROD_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/cumsum_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/cumsum_impl.cu index f4971b145c4..8dd2d8ceaea 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/cumsum_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/cumsum_impl.cu @@ -1,200 +1,200 @@ -/** - * Copyright 2020-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include "cumsum_impl.cuh" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h" - -template -__global__ void Copy(T *input, T *output, size_t size) { - size_t step = blockDim.x * gridDim.x; - for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < size; write_index += step) { - input[write_index] = output[write_index]; - } -} - -template -__global__ void LeftMoveSum(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, - size_t stride2) { - size_t num = dim0 * dim2; - size_t i, k, offset; - size_t step = blockDim.x * gridDim.x; - for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; write_index += step) { - i = write_index / dim2 % dim0; - k = write_index % dim2; - offset = i * stride + k; - for (size_t j = 0; j < dim1; ++j) { - size_t read_index = j * stride2 + offset; - if (j == 0) { - output[read_index] = 0; - } else { - size_t read_index2 = (j - 1) * stride2 + offset; - output[read_index] = input[read_index2]; - } - } - } -} - -template -__global__ void RightMoveSum(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, - size_t stride2) { - size_t num = dim0 * dim2; - size_t i, k, offset; - size_t step = blockDim.x * gridDim.x; - for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; write_index += step) { - i = write_index / dim2 % dim0; - k = write_index % dim2; - offset = i * stride + k; - for (int j = dim1 - 1; j >= 0; --j) { - size_t read_index = j * stride2 + offset; - if (j == dim1 - 1) { - output[read_index] = 0; - } else { - size_t read_index2 = (j + 1) * stride2 + offset; - output[read_index] = input[read_index2]; - } - } - } -} -template -__global__ void CumSumKernelReverse(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, - size_t stride2) { - size_t num = dim0 * dim2; - size_t i, k, offset; - size_t step = blockDim.x * gridDim.x; - for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; write_index += step) { - i = write_index / dim2 % dim0; - k = write_index % dim2; - offset = i * stride + k; - for (int j = dim1 - 1; j >= 0; --j) { - size_t read_index = j * stride2 + offset; - if (j == dim1 - 1) { - output[read_index] = input[read_index]; - } else { - size_t read_index2 = (j + 1) * stride2 + offset; - output[read_index] = output[read_index2] + input[read_index]; - } - } - } -} - -template -__global__ void CumSumKernel(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, - size_t stride2) { - size_t num = dim0 * dim2; - size_t i, k, offset; - size_t step = blockDim.x * gridDim.x; - for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; write_index += step) { - i = write_index / dim2 % dim0; - k = write_index % dim2; - offset = i * stride + k; - for (size_t j = 0; j < dim1; ++j) { - size_t read_index = j * stride2 + offset; - if (j == 0) { - output[read_index] = input[read_index]; - } else { - size_t read_index2 = (j - 1) * stride2 + offset; - output[read_index] = output[read_index2] + input[read_index]; - } - } - } -} -template -cudaError_t CumSum(const T *input, T *output, T *workspace, size_t dim0, size_t dim1, size_t dim2, size_t stride, - size_t stride2, bool exclusive_, bool reverse_, const uint32_t &device_id, cudaStream_t stream) { - int size = dim0 * dim2; - int block_num = size > 256 ? 256 : size; - if (exclusive_) { - if (reverse_) { - RightMoveSum<<>>(input, output, dim0, dim1, - dim2, stride, stride2); - Copy<<>>(workspace, output, size * dim1); - CumSumKernelReverse<<>>( - workspace, output, dim0, dim1, dim2, stride, stride2); - } else { - LeftMoveSum<<>>(input, output, dim0, dim1, - dim2, stride, stride2); - Copy<<>>(workspace, output, size * dim1); - CumSumKernel<<>>(workspace, output, dim0, dim1, - dim2, stride, stride2); - } - } else { - if (reverse_) { - CumSumKernelReverse<<>>( - input, output, dim0, dim1, dim2, stride, stride2); - } else { - CumSumKernel<<>>(input, output, dim0, dim1, - dim2, stride, stride2); - } - } - return GetCudaStatus(); -} - -template CUDA_LIB_EXPORT cudaError_t CumSum(const int8_t *input, int8_t *output, int8_t *workspace, size_t dim0, - size_t dim1, size_t dim2, size_t stride, size_t stride2, - bool exclusive_, bool reverse_, const uint32_t &device_id, - cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t CumSum(const int16_t *input, int16_t *output, int16_t *workspace, - size_t dim0, size_t dim1, size_t dim2, size_t stride, - size_t stride2, bool exclusive_, bool reverse_, - const uint32_t &device_id, cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t CumSum(const int32_t *input, int32_t *output, int32_t *workspace, - size_t dim0, size_t dim1, size_t dim2, size_t stride, - size_t stride2, bool exclusive_, bool reverse_, - const uint32_t &device_id, cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t CumSum(const int64_t *input, int64_t *output, int64_t *workspace, - size_t dim0, size_t dim1, size_t dim2, size_t stride, - size_t stride2, bool exclusive_, bool reverse_, - const uint32_t &device_id, cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t CumSum(const uint8_t *input, uint8_t *output, uint8_t *workspace, - size_t dim0, size_t dim1, size_t dim2, size_t stride, - size_t stride2, bool exclusive_, bool reverse_, - const uint32_t &device_id, cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t CumSum(const uint16_t *input, uint16_t *output, uint16_t *workspace, - size_t dim0, size_t dim1, size_t dim2, size_t stride, - size_t stride2, bool exclusive_, bool reverse_, - const uint32_t &device_id, cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t CumSum(const uint32_t *input, uint32_t *output, uint32_t *workspace, - size_t dim0, size_t dim1, size_t dim2, size_t stride, - size_t stride2, bool exclusive_, bool reverse_, - const uint32_t &device_id, cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t CumSum(const uint64_t *input, uint64_t *output, uint64_t *workspace, - size_t dim0, size_t dim1, size_t dim2, size_t stride, - size_t stride2, bool exclusive_, bool reverse_, - const uint32_t &device_id, cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t CumSum(const double *input, double *output, double *workspace, size_t dim0, - size_t dim1, size_t dim2, size_t stride, size_t stride2, - bool exclusive_, bool reverse_, const uint32_t &device_id, - cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t CumSum(const float *input, float *output, float *workspace, size_t dim0, - size_t dim1, size_t dim2, size_t stride, size_t stride2, - bool exclusive_, bool reverse_, const uint32_t &device_id, - cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t CumSum(const half *input, half *output, half *workspace, size_t dim0, - size_t dim1, size_t dim2, size_t stride, size_t stride2, - bool exclusive_, bool reverse_, const uint32_t &device_id, - cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t CumSum>(const Complex *input, Complex *output, - Complex *workspace, size_t dim0, size_t dim1, - size_t dim2, size_t stride, size_t stride2, bool exclusive_, - bool reverse_, const uint32_t &device_id, - cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t CumSum>(const Complex *input, Complex *output, - Complex *workspace, size_t dim0, size_t dim1, - size_t dim2, size_t stride, size_t stride2, - bool exclusive_, bool reverse_, const uint32_t &device_id, - cudaStream_t stream); +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "cumsum_impl.cuh" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h" + +template +__global__ void Copy(T *input, T *output, size_t size) { + size_t step = blockDim.x * gridDim.x; + for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < size; write_index += step) { + input[write_index] = output[write_index]; + } +} + +template +__global__ void LeftMoveSum(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, + size_t stride2) { + size_t num = dim0 * dim2; + size_t i, k, offset; + size_t step = blockDim.x * gridDim.x; + for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; write_index += step) { + i = write_index / dim2 % dim0; + k = write_index % dim2; + offset = i * stride + k; + for (size_t j = 0; j < dim1; ++j) { + size_t read_index = j * stride2 + offset; + if (j == 0) { + output[read_index] = 0; + } else { + size_t read_index2 = (j - 1) * stride2 + offset; + output[read_index] = input[read_index2]; + } + } + } +} + +template +__global__ void RightMoveSum(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, + size_t stride2) { + size_t num = dim0 * dim2; + size_t i, k, offset; + size_t step = blockDim.x * gridDim.x; + for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; write_index += step) { + i = write_index / dim2 % dim0; + k = write_index % dim2; + offset = i * stride + k; + for (int j = dim1 - 1; j >= 0; --j) { + size_t read_index = j * stride2 + offset; + if (j == dim1 - 1) { + output[read_index] = 0; + } else { + size_t read_index2 = (j + 1) * stride2 + offset; + output[read_index] = input[read_index2]; + } + } + } +} +template +__global__ void CumSumKernelReverse(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, + size_t stride2) { + size_t num = dim0 * dim2; + size_t i, k, offset; + size_t step = blockDim.x * gridDim.x; + for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; write_index += step) { + i = write_index / dim2 % dim0; + k = write_index % dim2; + offset = i * stride + k; + for (int j = dim1 - 1; j >= 0; --j) { + size_t read_index = j * stride2 + offset; + if (j == dim1 - 1) { + output[read_index] = input[read_index]; + } else { + size_t read_index2 = (j + 1) * stride2 + offset; + output[read_index] = output[read_index2] + input[read_index]; + } + } + } +} + +template +__global__ void CumSumKernel(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, + size_t stride2) { + size_t num = dim0 * dim2; + size_t i, k, offset; + size_t step = blockDim.x * gridDim.x; + for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; write_index += step) { + i = write_index / dim2 % dim0; + k = write_index % dim2; + offset = i * stride + k; + for (size_t j = 0; j < dim1; ++j) { + size_t read_index = j * stride2 + offset; + if (j == 0) { + output[read_index] = input[read_index]; + } else { + size_t read_index2 = (j - 1) * stride2 + offset; + output[read_index] = output[read_index2] + input[read_index]; + } + } + } +} +template +cudaError_t CumSum(const T *input, T *output, T *workspace, size_t dim0, size_t dim1, size_t dim2, size_t stride, + size_t stride2, bool exclusive_, bool reverse_, const uint32_t &device_id, cudaStream_t stream) { + int size = dim0 * dim2; + int block_num = size > 256 ? 256 : size; + if (exclusive_) { + if (reverse_) { + RightMoveSum<<>>(input, output, dim0, dim1, + dim2, stride, stride2); + Copy<<>>(workspace, output, size * dim1); + CumSumKernelReverse<<>>( + workspace, output, dim0, dim1, dim2, stride, stride2); + } else { + LeftMoveSum<<>>(input, output, dim0, dim1, + dim2, stride, stride2); + Copy<<>>(workspace, output, size * dim1); + CumSumKernel<<>>(workspace, output, dim0, dim1, + dim2, stride, stride2); + } + } else { + if (reverse_) { + CumSumKernelReverse<<>>( + input, output, dim0, dim1, dim2, stride, stride2); + } else { + CumSumKernel<<>>(input, output, dim0, dim1, + dim2, stride, stride2); + } + } + return GetCudaStatus(); +} + +template CUDA_LIB_EXPORT cudaError_t CumSum(const int8_t *input, int8_t *output, int8_t *workspace, size_t dim0, + size_t dim1, size_t dim2, size_t stride, size_t stride2, + bool exclusive_, bool reverse_, const uint32_t &device_id, + cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t CumSum(const int16_t *input, int16_t *output, int16_t *workspace, + size_t dim0, size_t dim1, size_t dim2, size_t stride, + size_t stride2, bool exclusive_, bool reverse_, + const uint32_t &device_id, cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t CumSum(const int32_t *input, int32_t *output, int32_t *workspace, + size_t dim0, size_t dim1, size_t dim2, size_t stride, + size_t stride2, bool exclusive_, bool reverse_, + const uint32_t &device_id, cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t CumSum(const int64_t *input, int64_t *output, int64_t *workspace, + size_t dim0, size_t dim1, size_t dim2, size_t stride, + size_t stride2, bool exclusive_, bool reverse_, + const uint32_t &device_id, cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t CumSum(const uint8_t *input, uint8_t *output, uint8_t *workspace, + size_t dim0, size_t dim1, size_t dim2, size_t stride, + size_t stride2, bool exclusive_, bool reverse_, + const uint32_t &device_id, cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t CumSum(const uint16_t *input, uint16_t *output, uint16_t *workspace, + size_t dim0, size_t dim1, size_t dim2, size_t stride, + size_t stride2, bool exclusive_, bool reverse_, + const uint32_t &device_id, cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t CumSum(const uint32_t *input, uint32_t *output, uint32_t *workspace, + size_t dim0, size_t dim1, size_t dim2, size_t stride, + size_t stride2, bool exclusive_, bool reverse_, + const uint32_t &device_id, cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t CumSum(const uint64_t *input, uint64_t *output, uint64_t *workspace, + size_t dim0, size_t dim1, size_t dim2, size_t stride, + size_t stride2, bool exclusive_, bool reverse_, + const uint32_t &device_id, cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t CumSum(const double *input, double *output, double *workspace, size_t dim0, + size_t dim1, size_t dim2, size_t stride, size_t stride2, + bool exclusive_, bool reverse_, const uint32_t &device_id, + cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t CumSum(const float *input, float *output, float *workspace, size_t dim0, + size_t dim1, size_t dim2, size_t stride, size_t stride2, + bool exclusive_, bool reverse_, const uint32_t &device_id, + cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t CumSum(const half *input, half *output, half *workspace, size_t dim0, + size_t dim1, size_t dim2, size_t stride, size_t stride2, + bool exclusive_, bool reverse_, const uint32_t &device_id, + cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t CumSum>(const Complex *input, Complex *output, + Complex *workspace, size_t dim0, size_t dim1, + size_t dim2, size_t stride, size_t stride2, bool exclusive_, + bool reverse_, const uint32_t &device_id, + cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t CumSum>(const Complex *input, Complex *output, + Complex *workspace, size_t dim0, size_t dim1, + size_t dim2, size_t stride, size_t stride2, + bool exclusive_, bool reverse_, const uint32_t &device_id, + cudaStream_t stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/cumsum_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/cumsum_impl.cuh index 71a4a19b9c6..ed681d4bf36 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/cumsum_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/cumsum_impl.cuh @@ -1,25 +1,25 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CUMSUM_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CUMSUM_IMPL_CUH_ -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h" -template -CUDA_LIB_EXPORT cudaError_t CumSum(const T *input, T *output, T *workspace, size_t dim0, size_t dim1, size_t dim2, - size_t stride, size_t stride2, bool exclusive_, bool reverse_, - const uint32_t &device_id, cudaStream_t stream); -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CUMSUM_IMPL_CUH_ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CUMSUM_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CUMSUM_IMPL_CUH_ +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h" +template +CUDA_LIB_EXPORT cudaError_t CumSum(const T *input, T *output, T *workspace, size_t dim0, size_t dim1, size_t dim2, + size_t stride, size_t stride2, bool exclusive_, bool reverse_, + const uint32_t &device_id, cudaStream_t stream); +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CUMSUM_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/data_format_vec_permute_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/data_format_vec_permute_impl.cu index e508de1f294..6cedc55b337 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/data_format_vec_permute_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/data_format_vec_permute_impl.cu @@ -1,68 +1,68 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/data_format_vec_permute_impl.cuh" -#include "include/cuda_runtime.h" - -template -__global__ void DataFormatVecPermuteKernel1D(const size_t size, const T *input, T *output, int32_t *index) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - output[pos] = input[index[pos]]; - } - return; -} - -template -__global__ void DataFormatVecPermuteKernel2D(const size_t size, const T *input, T *output, int32_t *index) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - int32_t dim = static_cast(2); - int32_t i = static_cast(pos) / dim; - output[dim * i] = input[dim * index[i]]; - output[dim * i + 1] = input[dim * index[i] + 1]; - } - return; -} - -template -cudaError_t CalDataFormatVecPermute1D(const size_t size, const T *input, T *output, int32_t *index, - const uint32_t &device_id, cudaStream_t cuda_stream) { - DataFormatVecPermuteKernel1D<<>>( - size, input, output, index); - return GetCudaStatus(); -} - -template -cudaError_t CalDataFormatVecPermute2D(const size_t size, const T *input, T *output, int32_t *index, - const uint32_t &device_id, cudaStream_t cuda_stream) { - DataFormatVecPermuteKernel2D<<>>( - size, input, output, index); - return GetCudaStatus(); -} - -template CUDA_LIB_EXPORT cudaError_t CalDataFormatVecPermute1D(const size_t size, const int *input, int *output, - int32_t *index, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalDataFormatVecPermute1D(const size_t size, const int64_t *input, - int64_t *output, int32_t *index, - const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalDataFormatVecPermute2D(const size_t size, const int *input, int *output, - int32_t *index, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalDataFormatVecPermute2D(const size_t size, const int64_t *input, - int64_t *output, int32_t *index, - const uint32_t &device_id, - cudaStream_t cuda_stream); +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/data_format_vec_permute_impl.cuh" +#include "include/cuda_runtime.h" + +template +__global__ void DataFormatVecPermuteKernel1D(const size_t size, const T *input, T *output, int32_t *index) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + output[pos] = input[index[pos]]; + } + return; +} + +template +__global__ void DataFormatVecPermuteKernel2D(const size_t size, const T *input, T *output, int32_t *index) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + int32_t dim = static_cast(2); + int32_t i = static_cast(pos) / dim; + output[dim * i] = input[dim * index[i]]; + output[dim * i + 1] = input[dim * index[i] + 1]; + } + return; +} + +template +cudaError_t CalDataFormatVecPermute1D(const size_t size, const T *input, T *output, int32_t *index, + const uint32_t &device_id, cudaStream_t cuda_stream) { + DataFormatVecPermuteKernel1D<<>>( + size, input, output, index); + return GetCudaStatus(); +} + +template +cudaError_t CalDataFormatVecPermute2D(const size_t size, const T *input, T *output, int32_t *index, + const uint32_t &device_id, cudaStream_t cuda_stream) { + DataFormatVecPermuteKernel2D<<>>( + size, input, output, index); + return GetCudaStatus(); +} + +template CUDA_LIB_EXPORT cudaError_t CalDataFormatVecPermute1D(const size_t size, const int *input, int *output, + int32_t *index, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalDataFormatVecPermute1D(const size_t size, const int64_t *input, + int64_t *output, int32_t *index, + const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalDataFormatVecPermute2D(const size_t size, const int *input, int *output, + int32_t *index, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalDataFormatVecPermute2D(const size_t size, const int64_t *input, + int64_t *output, int32_t *index, + const uint32_t &device_id, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/data_format_vec_permute_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/data_format_vec_permute_impl.cuh index d447f8b644c..b1b978d9e6b 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/data_format_vec_permute_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/data_format_vec_permute_impl.cuh @@ -1,29 +1,29 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_DATEFORMATEVECPERMUTE_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_DATEFORMATEVECPERMUTE_IMPL_CUH_ -#include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" - -template -CUDA_LIB_EXPORT cudaError_t CalDataFormatVecPermute1D(const size_t size, const T *input, T *output, int32_t *index, - const uint32_t &device_id, cudaStream_t cuda_stream); - -template -CUDA_LIB_EXPORT cudaError_t CalDataFormatVecPermute2D(const size_t size, const T *input, T *output, int32_t *index, - const uint32_t &device_id, cudaStream_t cuda_stream); -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_DATEFORMATEVECPERMUTE_IMPL_CUH_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_DATEFORMATEVECPERMUTE_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_DATEFORMATEVECPERMUTE_IMPL_CUH_ +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" + +template +CUDA_LIB_EXPORT cudaError_t CalDataFormatVecPermute1D(const size_t size, const T *input, T *output, int32_t *index, + const uint32_t &device_id, cudaStream_t cuda_stream); + +template +CUDA_LIB_EXPORT cudaError_t CalDataFormatVecPermute2D(const size_t size, const T *input, T *output, int32_t *index, + const uint32_t &device_id, cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_DATEFORMATEVECPERMUTE_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/diag_part_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/diag_part_impl.cu index 79137d639c5..eb4b494bb11 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/diag_part_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/diag_part_impl.cu @@ -1,53 +1,53 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/diag_part_impl.cuh" - -template -__global__ void DiagPart(const size_t size, const T *input, T *output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - output[pos] = input[(1 + size) * pos]; - } -} - -template -cudaError_t CalDiagPart(const size_t size, const T *input, T *output, const uint32_t &device_id, - cudaStream_t cuda_stream) { - DiagPart<<>>(size, input, output); - return GetCudaStatus(); -} - -template CUDA_LIB_EXPORT cudaError_t CalDiagPart(const size_t size, const int32_t *input, int32_t *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalDiagPart(const size_t size, const int64_t *input, int64_t *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalDiagPart(const size_t size, const half *input, half *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalDiagPart(const size_t size, const double *input, double *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalDiagPart(const size_t size, const float *input, float *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalDiagPart>(const size_t size, - const std::complex *input, - std::complex *output, - const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalDiagPart>(const size_t size, - const std::complex *input, - std::complex *output, - const uint32_t &device_id, - cudaStream_t cuda_stream); +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/diag_part_impl.cuh" + +template +__global__ void DiagPart(const size_t size, const T *input, T *output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + output[pos] = input[(1 + size) * pos]; + } +} + +template +cudaError_t CalDiagPart(const size_t size, const T *input, T *output, const uint32_t &device_id, + cudaStream_t cuda_stream) { + DiagPart<<>>(size, input, output); + return GetCudaStatus(); +} + +template CUDA_LIB_EXPORT cudaError_t CalDiagPart(const size_t size, const int32_t *input, int32_t *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalDiagPart(const size_t size, const int64_t *input, int64_t *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalDiagPart(const size_t size, const half *input, half *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalDiagPart(const size_t size, const double *input, double *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalDiagPart(const size_t size, const float *input, float *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalDiagPart>(const size_t size, + const std::complex *input, + std::complex *output, + const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalDiagPart>(const size_t size, + const std::complex *input, + std::complex *output, + const uint32_t &device_id, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/diag_part_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/diag_part_impl.cuh index cd470a36492..5abd113ee26 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/diag_part_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/diag_part_impl.cuh @@ -1,24 +1,24 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_DIAG_PART_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_DIAG_PART_IMPL_CUH_ -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" - -template -cudaError_t CalDiagPart(const size_t size, const T *input, T *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_DIAG_PART_IMPL_CUH_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_DIAG_PART_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_DIAG_PART_IMPL_CUH_ +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" + +template +cudaError_t CalDiagPart(const size_t size, const T *input, T *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_DIAG_PART_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/equalcount_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/equalcount_impl.cu index 0299b78320d..ba93fab2cc7 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/equalcount_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/equalcount_impl.cu @@ -1,43 +1,43 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "equalcount_impl.cuh" -#include "include/cuda_fp16.h" -template -__global__ void EqualCount(const int size, const T *input1, const T *input2, T *output) { - T equal_count = 0; - - for (int i = 0; i < size; i++) { - if (input1[i] == input2[i]) { - equal_count++; - } - } - - output[0] = equal_count; - return; -} -template -cudaError_t CalEqualCount(const int size, const T *input1, const T *input2, T *output, cudaStream_t cuda_stream) { - EqualCount<<<1, 1, 0, cuda_stream>>>(size, input1, input2, output); - return GetCudaStatus(); -} - -template CUDA_LIB_EXPORT cudaError_t CalEqualCount(const int size, const int *input1, const int *input2, - int *output, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalEqualCount(const int size, const float *input1, const float *input2, - float *output, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalEqualCount(const int size, const half *input1, const half *input2, - half *output, cudaStream_t cuda_stream); +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "equalcount_impl.cuh" +#include "include/cuda_fp16.h" +template +__global__ void EqualCount(const int size, const T *input1, const T *input2, T *output) { + T equal_count = 0; + + for (int i = 0; i < size; i++) { + if (input1[i] == input2[i]) { + equal_count++; + } + } + + output[0] = equal_count; + return; +} +template +cudaError_t CalEqualCount(const int size, const T *input1, const T *input2, T *output, cudaStream_t cuda_stream) { + EqualCount<<<1, 1, 0, cuda_stream>>>(size, input1, input2, output); + return GetCudaStatus(); +} + +template CUDA_LIB_EXPORT cudaError_t CalEqualCount(const int size, const int *input1, const int *input2, + int *output, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalEqualCount(const int size, const float *input1, const float *input2, + float *output, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalEqualCount(const int size, const half *input1, const half *input2, + half *output, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/equalcount_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/equalcount_impl.cuh index ba7c8b85508..df2a217016d 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/equalcount_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/equalcount_impl.cuh @@ -1,24 +1,24 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_EQUALCOUNT_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_EQUALCOUNT_IMPL_CUH_ -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" -template -CUDA_LIB_EXPORT cudaError_t CalEqualCount(const int size, const T *input1, const T *input2, T *output, - cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_EQUALCOUNT_IMPL_CUH_ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_EQUALCOUNT_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_EQUALCOUNT_IMPL_CUH_ +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" +template +CUDA_LIB_EXPORT cudaError_t CalEqualCount(const int size, const T *input1, const T *input2, T *output, + cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_EQUALCOUNT_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_learned_scale_quant_perchannel_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_learned_scale_quant_perchannel_impl.cu index 8ebf051025c..a3b4b47fcc1 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_learned_scale_quant_perchannel_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_learned_scale_quant_perchannel_impl.cu @@ -1,107 +1,107 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "fake_learned_scale_quant_perchannel_impl.cuh" -#include -#include -#include -#include -#include -#include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh" - -__global__ void FakeLearnedScaleQuantPerChannel(float *output, const int size, float *input_alpha, float *input_quant, - const int channel_num) { - int channel_idx = 0; - int per_channel_num = size / channel_num; - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) { - channel_idx = floor(static_cast(i) / static_cast(per_channel_num)); - // dequantize - output[i] = input_quant[i] * input_alpha[channel_idx]; - } - return; -} - -__global__ void FakeLearnedScaleQuantPerChannelGrad(float *grad_input, float *grad_alpha, const float *gradient, - const int size, const float *input_div_alpha, - const float *input_quant, const bool neg_trunc, - const int channel_num) { - int channel_idx = 0; - int per_channel_num = size / channel_num; - float lower_bound = -1.0 * !neg_trunc; - - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) { - float grad_alpha_temp = 0.f; - channel_idx = floor(static_cast(i) / static_cast(per_channel_num)); - if (input_div_alpha[i] > 1.0) { - grad_alpha_temp = gradient[i]; - grad_input[i] = 0; - } else if (input_div_alpha[i] < lower_bound) { - grad_alpha_temp = -gradient[i]; - grad_input[i] = 0; - } else { - grad_input[i] = gradient[i]; - grad_alpha_temp = (gradient[i] * (input_quant[i] - input_div_alpha[i])); - } - MsAtomicAdd(grad_alpha + channel_idx, grad_alpha_temp); - } - return; -} - -__global__ void LSQNudgePerChannel(const float *input, const int size, float *input_alpha, float *input_quant_max, - float *input_div_alpha, float *input_quant, const bool neg_trunc, - const int channel_num) { - float input_x; - int channel_idx = 0; - int per_channel_num = size / channel_num; - float lower_bound = -1.0 * !neg_trunc; - - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) { - channel_idx = floor(static_cast(i) / static_cast(per_channel_num)); - input_x = input[i] / input_alpha[channel_idx]; - input_div_alpha[i] = input_x; - input_x = max(input_x, lower_bound); - input_x = min(input_x, 1.0); - - // quantize - input_quant[i] = floor(input_x * input_quant_max[0] + 0.5f) / input_quant_max[0]; - } - return; -} - -cudaError_t CalFakeLearnedScaleQuantPerChannel(float *output, const int size, float *input_alpha, float *input_quant, - const int channel_num, cudaStream_t cuda_stream) { - FakeLearnedScaleQuantPerChannel<<>>(output, size, input_alpha, - input_quant, channel_num); - return GetCudaStatus(); -} - -cudaError_t CalFakeLearnedScaleQuantPerChannelGrad(float *grad_input, float *grad_alpha, const float *gradient, - const int size, const float *input_div_alpha, - const float *input_quant, const bool neg_trunc, - const int channel_num, cudaStream_t cuda_stream) { - FakeLearnedScaleQuantPerChannelGrad<<>>( - grad_input, grad_alpha, gradient, size, input_div_alpha, input_quant, neg_trunc, channel_num); - return GetCudaStatus(); -} - -cudaError_t CalLSQNudgePerChannel(const float *input, const int size, float *input_alpha, float *input_quant_max, - float *input_div_alpha, float *input_quant, const bool neg_trunc, - const int channel_num, cudaStream_t cuda_stream) { - LSQNudgePerChannel<<>>( - input, size, input_alpha, input_quant_max, input_div_alpha, input_quant, neg_trunc, channel_num); - return GetCudaStatus(); -} +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "fake_learned_scale_quant_perchannel_impl.cuh" +#include +#include +#include +#include +#include +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh" + +__global__ void FakeLearnedScaleQuantPerChannel(float *output, const int size, float *input_alpha, float *input_quant, + const int channel_num) { + int channel_idx = 0; + int per_channel_num = size / channel_num; + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) { + channel_idx = floor(static_cast(i) / static_cast(per_channel_num)); + // dequantize + output[i] = input_quant[i] * input_alpha[channel_idx]; + } + return; +} + +__global__ void FakeLearnedScaleQuantPerChannelGrad(float *grad_input, float *grad_alpha, const float *gradient, + const int size, const float *input_div_alpha, + const float *input_quant, const bool neg_trunc, + const int channel_num) { + int channel_idx = 0; + int per_channel_num = size / channel_num; + float lower_bound = -1.0 * !neg_trunc; + + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) { + float grad_alpha_temp = 0.f; + channel_idx = floor(static_cast(i) / static_cast(per_channel_num)); + if (input_div_alpha[i] > 1.0) { + grad_alpha_temp = gradient[i]; + grad_input[i] = 0; + } else if (input_div_alpha[i] < lower_bound) { + grad_alpha_temp = -gradient[i]; + grad_input[i] = 0; + } else { + grad_input[i] = gradient[i]; + grad_alpha_temp = (gradient[i] * (input_quant[i] - input_div_alpha[i])); + } + MsAtomicAdd(grad_alpha + channel_idx, grad_alpha_temp); + } + return; +} + +__global__ void LSQNudgePerChannel(const float *input, const int size, float *input_alpha, float *input_quant_max, + float *input_div_alpha, float *input_quant, const bool neg_trunc, + const int channel_num) { + float input_x; + int channel_idx = 0; + int per_channel_num = size / channel_num; + float lower_bound = -1.0 * !neg_trunc; + + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) { + channel_idx = floor(static_cast(i) / static_cast(per_channel_num)); + input_x = input[i] / input_alpha[channel_idx]; + input_div_alpha[i] = input_x; + input_x = max(input_x, lower_bound); + input_x = min(input_x, 1.0); + + // quantize + input_quant[i] = floor(input_x * input_quant_max[0] + 0.5f) / input_quant_max[0]; + } + return; +} + +cudaError_t CalFakeLearnedScaleQuantPerChannel(float *output, const int size, float *input_alpha, float *input_quant, + const int channel_num, cudaStream_t cuda_stream) { + FakeLearnedScaleQuantPerChannel<<>>(output, size, input_alpha, + input_quant, channel_num); + return GetCudaStatus(); +} + +cudaError_t CalFakeLearnedScaleQuantPerChannelGrad(float *grad_input, float *grad_alpha, const float *gradient, + const int size, const float *input_div_alpha, + const float *input_quant, const bool neg_trunc, + const int channel_num, cudaStream_t cuda_stream) { + FakeLearnedScaleQuantPerChannelGrad<<>>( + grad_input, grad_alpha, gradient, size, input_div_alpha, input_quant, neg_trunc, channel_num); + return GetCudaStatus(); +} + +cudaError_t CalLSQNudgePerChannel(const float *input, const int size, float *input_alpha, float *input_quant_max, + float *input_div_alpha, float *input_quant, const bool neg_trunc, + const int channel_num, cudaStream_t cuda_stream) { + LSQNudgePerChannel<<>>( + input, size, input_alpha, input_quant_max, input_div_alpha, input_quant, neg_trunc, channel_num); + return GetCudaStatus(); +} diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_learned_scale_quant_perchannel_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_learned_scale_quant_perchannel_impl.cuh index 9db9a83c36f..fde8d3ebc46 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_learned_scale_quant_perchannel_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_learned_scale_quant_perchannel_impl.cuh @@ -1,36 +1,36 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FAKE_LEARNED_SCALE_QUANT_PERCHANNEL_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FAKE_LEARNED_SCALE_QUANT_PERCHANNEL_IMPL_CUH_ -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" - -CUDA_LIB_EXPORT cudaError_t CalLSQNudgePerChannel(const float *input, const int size, float *input_alpha, - float *input_quant_max, float *input_div_alpha, float *input_quant, - const bool neg_trunc, const int channel_num, - cudaStream_t cuda_stream); - -CUDA_LIB_EXPORT cudaError_t CalFakeLearnedScaleQuantPerChannel(float *output, const int size, float *input_alpha, - float *input_quant, const int channel_num, - cudaStream_t cuda_stream); - -CUDA_LIB_EXPORT cudaError_t CalFakeLearnedScaleQuantPerChannelGrad(float *grad_input, float *grad_alpha, - const float *gradient, const int size, - const float *input_div_alpha, - const float *input_quant, const bool neg_trunc, - const int channel_num, cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FAKE_LEARNED_SCALE_QUANT_PERCHANNEL_IMPL_CUH_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FAKE_LEARNED_SCALE_QUANT_PERCHANNEL_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FAKE_LEARNED_SCALE_QUANT_PERCHANNEL_IMPL_CUH_ +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" + +CUDA_LIB_EXPORT cudaError_t CalLSQNudgePerChannel(const float *input, const int size, float *input_alpha, + float *input_quant_max, float *input_div_alpha, float *input_quant, + const bool neg_trunc, const int channel_num, + cudaStream_t cuda_stream); + +CUDA_LIB_EXPORT cudaError_t CalFakeLearnedScaleQuantPerChannel(float *output, const int size, float *input_alpha, + float *input_quant, const int channel_num, + cudaStream_t cuda_stream); + +CUDA_LIB_EXPORT cudaError_t CalFakeLearnedScaleQuantPerChannelGrad(float *grad_input, float *grad_alpha, + const float *gradient, const int size, + const float *input_div_alpha, + const float *input_quant, const bool neg_trunc, + const int channel_num, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FAKE_LEARNED_SCALE_QUANT_PERCHANNEL_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_learned_scale_quant_perlayer_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_learned_scale_quant_perlayer_impl.cu index 5b3b99a31d9..af23fa04e95 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_learned_scale_quant_perlayer_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_learned_scale_quant_perlayer_impl.cu @@ -1,92 +1,92 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "fake_learned_scale_quant_perlayer_impl.cuh" -#include -#include -#include -#include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh" - -__global__ void FakeLearnedScaleQuantPerLayer(float *output, const int size, float *input_alpha, float *input_quant) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) { - // dequantize - output[i] = input_quant[i] * input_alpha[0]; - } - return; -} - -__global__ void FakeLearnedScaleQuantPerLayerGrad(float *grad_input, float *grad_alpha, const float *gradient, - const int size, const float *input_div_alpha, - const float *input_quant, const bool neg_trunc) { - float grad_alpha_temp = 0.f; - float lower_bound = -1.0 * !neg_trunc; - - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) { - if (input_div_alpha[i] > 1.0) { - grad_alpha_temp += gradient[i]; - grad_input[i] = 0; - } else if (input_div_alpha[i] < lower_bound) { - grad_alpha_temp -= gradient[i]; - grad_input[i] = 0; - } else { - grad_input[i] = gradient[i]; - grad_alpha_temp += (gradient[i] * (input_quant[i] - input_div_alpha[i])); - } - } - MsAtomicAdd(grad_alpha, grad_alpha_temp); - return; -} - -__global__ void LSQNudgePerLayer(const float *input, const int size, float *input_alpha, float *input_quant_max, - float *input_div_alpha, float *input_quant, const bool neg_trunc) { - float input_x; - float lower_bound = -1.0 * !neg_trunc; - - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) { - input_x = input[i] / input_alpha[0]; - input_div_alpha[i] = input_x; - input_x = max(input_x, lower_bound); - input_x = min(input_x, 1.0); - - // quantize - input_quant[i] = floor(input_x * input_quant_max[0] + 0.5f) / input_quant_max[0]; - } - return; -} - -cudaError_t CalFakeLearnedScaleQuantPerLayer(float *output, const int size, float *input_alpha, float *input_quant, - cudaStream_t cuda_stream) { - FakeLearnedScaleQuantPerLayer<<>>(output, size, input_alpha, - input_quant); - return GetCudaStatus(); -} - -cudaError_t CalFakeLearnedScaleQuantPerLayerGrad(float *grad_input, float *grad_alpha, const float *gradient, - const int size, const float *input_div_alpha, const float *input_quant, - const bool neg_trunc, cudaStream_t cuda_stream) { - FakeLearnedScaleQuantPerLayerGrad<<>>( - grad_input, grad_alpha, gradient, size, input_div_alpha, input_quant, neg_trunc); - return GetCudaStatus(); -} - -cudaError_t CalLSQNudgePerLayer(const float *input, const int size, float *input_alpha, float *input_quant_max, - float *input_div_alpha, float *input_quant, const bool neg_trunc, - cudaStream_t cuda_stream) { - LSQNudgePerLayer<<>>(input, size, input_alpha, input_quant_max, - input_div_alpha, input_quant, neg_trunc); - return GetCudaStatus(); -} +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "fake_learned_scale_quant_perlayer_impl.cuh" +#include +#include +#include +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh" + +__global__ void FakeLearnedScaleQuantPerLayer(float *output, const int size, float *input_alpha, float *input_quant) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) { + // dequantize + output[i] = input_quant[i] * input_alpha[0]; + } + return; +} + +__global__ void FakeLearnedScaleQuantPerLayerGrad(float *grad_input, float *grad_alpha, const float *gradient, + const int size, const float *input_div_alpha, + const float *input_quant, const bool neg_trunc) { + float grad_alpha_temp = 0.f; + float lower_bound = -1.0 * !neg_trunc; + + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) { + if (input_div_alpha[i] > 1.0) { + grad_alpha_temp += gradient[i]; + grad_input[i] = 0; + } else if (input_div_alpha[i] < lower_bound) { + grad_alpha_temp -= gradient[i]; + grad_input[i] = 0; + } else { + grad_input[i] = gradient[i]; + grad_alpha_temp += (gradient[i] * (input_quant[i] - input_div_alpha[i])); + } + } + MsAtomicAdd(grad_alpha, grad_alpha_temp); + return; +} + +__global__ void LSQNudgePerLayer(const float *input, const int size, float *input_alpha, float *input_quant_max, + float *input_div_alpha, float *input_quant, const bool neg_trunc) { + float input_x; + float lower_bound = -1.0 * !neg_trunc; + + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) { + input_x = input[i] / input_alpha[0]; + input_div_alpha[i] = input_x; + input_x = max(input_x, lower_bound); + input_x = min(input_x, 1.0); + + // quantize + input_quant[i] = floor(input_x * input_quant_max[0] + 0.5f) / input_quant_max[0]; + } + return; +} + +cudaError_t CalFakeLearnedScaleQuantPerLayer(float *output, const int size, float *input_alpha, float *input_quant, + cudaStream_t cuda_stream) { + FakeLearnedScaleQuantPerLayer<<>>(output, size, input_alpha, + input_quant); + return GetCudaStatus(); +} + +cudaError_t CalFakeLearnedScaleQuantPerLayerGrad(float *grad_input, float *grad_alpha, const float *gradient, + const int size, const float *input_div_alpha, const float *input_quant, + const bool neg_trunc, cudaStream_t cuda_stream) { + FakeLearnedScaleQuantPerLayerGrad<<>>( + grad_input, grad_alpha, gradient, size, input_div_alpha, input_quant, neg_trunc); + return GetCudaStatus(); +} + +cudaError_t CalLSQNudgePerLayer(const float *input, const int size, float *input_alpha, float *input_quant_max, + float *input_div_alpha, float *input_quant, const bool neg_trunc, + cudaStream_t cuda_stream) { + LSQNudgePerLayer<<>>(input, size, input_alpha, input_quant_max, + input_div_alpha, input_quant, neg_trunc); + return GetCudaStatus(); +} diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_learned_scale_quant_perlayer_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_learned_scale_quant_perlayer_impl.cuh index 1de9d7924a2..ab860209cd9 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_learned_scale_quant_perlayer_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_learned_scale_quant_perlayer_impl.cuh @@ -1,33 +1,33 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FAKE_LEARNED_SCALE_QUANT_PERLAYER_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FAKE_LEARNED_SCALE_QUANT_PERLAYER_IMPL_CUH_ -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" - -CUDA_LIB_EXPORT cudaError_t CalLSQNudgePerLayer(const float *input, const int size, float *input_alpha, - float *input_quant_max, float *input_div_alpha, float *input_quant, - const bool neg_trunc, cudaStream_t cuda_stream); - -CUDA_LIB_EXPORT cudaError_t CalFakeLearnedScaleQuantPerLayer(float *output, const int size, float *input_alpha, - float *input_quant, cudaStream_t cuda_stream); - -CUDA_LIB_EXPORT cudaError_t CalFakeLearnedScaleQuantPerLayerGrad(float *grad_input, float *grad_alpha, - const float *gradient, const int size, - const float *input_div_alpha, const float *input_quant, - const bool neg_trunc, cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FAKE_LEARNED_SCALE_QUANT_PERLAYER_IMPL_CUH_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FAKE_LEARNED_SCALE_QUANT_PERLAYER_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FAKE_LEARNED_SCALE_QUANT_PERLAYER_IMPL_CUH_ +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" + +CUDA_LIB_EXPORT cudaError_t CalLSQNudgePerLayer(const float *input, const int size, float *input_alpha, + float *input_quant_max, float *input_div_alpha, float *input_quant, + const bool neg_trunc, cudaStream_t cuda_stream); + +CUDA_LIB_EXPORT cudaError_t CalFakeLearnedScaleQuantPerLayer(float *output, const int size, float *input_alpha, + float *input_quant, cudaStream_t cuda_stream); + +CUDA_LIB_EXPORT cudaError_t CalFakeLearnedScaleQuantPerLayerGrad(float *grad_input, float *grad_alpha, + const float *gradient, const int size, + const float *input_div_alpha, const float *input_quant, + const bool neg_trunc, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FAKE_LEARNED_SCALE_QUANT_PERLAYER_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_quant_perchannel_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_quant_perchannel_impl.cuh index 57ebfc0f57d..40cb970b862 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_quant_perchannel_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_quant_perchannel_impl.cuh @@ -1,35 +1,35 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FAKE_QUANT_PERCHANNEL_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FAKE_QUANT_PERCHANNEL_IMPL_CUH_ -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" - -CUDA_LIB_EXPORT cudaError_t CalNudgePerChannel(float *input_min, float *input_max, const float quant_min, - const float quant_max, float *nudge_min, float *nudge_max, float *scale, - const int channel_num, const bool symmetric, cudaStream_t cuda_stream); - -CUDA_LIB_EXPORT cudaError_t CalFakeQuantPerChannel(const float *input, float *output, const int total_num, - const int channel_num, const float *nudge_min, - const float *nudge_max, const float *scale, - cudaStream_t cuda_stream); - -CUDA_LIB_EXPORT cudaError_t CalFakeQuantPerChannelGrad(const float *input, const float *gradient, float *output, - const int total_num, const int channel_num, - const float *nudge_min, const float *nudge_max, - cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FAKE_QUANT_PERCHANNEL_IMPL_CUH_ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FAKE_QUANT_PERCHANNEL_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FAKE_QUANT_PERCHANNEL_IMPL_CUH_ +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" + +CUDA_LIB_EXPORT cudaError_t CalNudgePerChannel(float *input_min, float *input_max, const float quant_min, + const float quant_max, float *nudge_min, float *nudge_max, float *scale, + const int channel_num, const bool symmetric, cudaStream_t cuda_stream); + +CUDA_LIB_EXPORT cudaError_t CalFakeQuantPerChannel(const float *input, float *output, const int total_num, + const int channel_num, const float *nudge_min, + const float *nudge_max, const float *scale, + cudaStream_t cuda_stream); + +CUDA_LIB_EXPORT cudaError_t CalFakeQuantPerChannelGrad(const float *input, const float *gradient, float *output, + const int total_num, const int channel_num, + const float *nudge_min, const float *nudge_max, + cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FAKE_QUANT_PERCHANNEL_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_quant_perlayer_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_quant_perlayer_impl.cu index 5d436cd17fa..d9b77678be7 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_quant_perlayer_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_quant_perlayer_impl.cu @@ -1,113 +1,113 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include "fake_quant_perlayer_impl.cuh" - -__global__ void FakeQuantPerLayer(const float *input, float *output, const int size, const float *nudge_min, - const float *nudge_max, const float *scale) { - float input_x = 0.f; - int nudge_input = 0; - - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) { - input_x = input[i]; - // clamp input x - if (input_x < nudge_min[0]) { - input_x = nudge_min[0]; - } - if (input_x > nudge_max[0]) { - input_x = nudge_max[0]; - } - // clamp shift - nudge_input = round((input_x - nudge_min[0]) / scale[0]); - - // quantize - output[i] = nudge_input * scale[0] + nudge_min[0]; - } - return; -} - -__global__ void FakeQuantPerLayerGrad(const float *input, const float *gradient, float *output, const int size, - const float *nudge_min, const float *nudge_max) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) { - if (input[i] < nudge_min[0] || input[i] > nudge_max[0]) { - output[i] = 0; - } else { - output[i] = gradient[i]; - } - } - return; -} - -__global__ void NudgeMinMaxPerLayer(float *input_min, float *input_max, const float quant_min, const float quant_max, - float *nudge_min, float *nudge_max, float *scale, const bool symmetric) { - float zp_from_min = 0.f; - scale[0] = 0.f; - nudge_max[0] = 0.f; - nudge_min[0] = 0.f; - - float max_data = input_max[0]; - float min_data = input_min[0]; - if (symmetric) { - max_data = abs(input_min[0]) < input_max[0] ? input_max[0] : -input_min[0]; - min_data = abs(input_min[0]) < input_max[0] ? -input_max[0] : input_min[0]; - } - - if ((quant_max - quant_min) == 0 || (max_data - min_data) == 0) { - scale[0] = 0.f; - zp_from_min = 0.f; - } else { - scale[0] = (max_data - min_data) / (quant_max - quant_min); - zp_from_min = quant_min - min_data / scale[0]; - } - - float nudge_zp = 0.f; - if (zp_from_min <= quant_min) { - nudge_zp = quant_min; - } else if (zp_from_min >= quant_max) { - nudge_zp = quant_max; - } else { - nudge_zp = round(zp_from_min); - } - - nudge_min[0] = (quant_min - nudge_zp) * (scale[0]); - nudge_max[0] = (quant_max - nudge_zp) * (scale[0]); - return; -} - -cudaError_t CalFakeQuantPerLayer(const float *input, float *output, const int size, const float *nudge_min, - const float *nudge_max, const float *scale, cudaStream_t cuda_stream) { - FakeQuantPerLayer<<>>(input, output, size, nudge_min, nudge_max, - scale); - return GetCudaStatus(); -} - -cudaError_t CalFakeQuantPerLayerGrad(const float *input, const float *gradient, float *output, const int size, - const float *nudge_min, const float *nudge_max, cudaStream_t cuda_stream) { - FakeQuantPerLayerGrad<<>>(input, gradient, output, size, nudge_min, - nudge_max); - return GetCudaStatus(); -} - -cudaError_t CalNudgePerLayer(float *input_min, float *input_max, const float quant_min, const float quant_max, - float *nudge_min, float *nudge_max, float *scale, const bool symmetric, - cudaStream_t cuda_stream) { - NudgeMinMaxPerLayer<<<1, 1, 0, cuda_stream>>>(input_min, input_max, quant_min, quant_max, nudge_min, nudge_max, scale, - symmetric); - return GetCudaStatus(); -} +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "fake_quant_perlayer_impl.cuh" + +__global__ void FakeQuantPerLayer(const float *input, float *output, const int size, const float *nudge_min, + const float *nudge_max, const float *scale) { + float input_x = 0.f; + int nudge_input = 0; + + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) { + input_x = input[i]; + // clamp input x + if (input_x < nudge_min[0]) { + input_x = nudge_min[0]; + } + if (input_x > nudge_max[0]) { + input_x = nudge_max[0]; + } + // clamp shift + nudge_input = round((input_x - nudge_min[0]) / scale[0]); + + // quantize + output[i] = nudge_input * scale[0] + nudge_min[0]; + } + return; +} + +__global__ void FakeQuantPerLayerGrad(const float *input, const float *gradient, float *output, const int size, + const float *nudge_min, const float *nudge_max) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) { + if (input[i] < nudge_min[0] || input[i] > nudge_max[0]) { + output[i] = 0; + } else { + output[i] = gradient[i]; + } + } + return; +} + +__global__ void NudgeMinMaxPerLayer(float *input_min, float *input_max, const float quant_min, const float quant_max, + float *nudge_min, float *nudge_max, float *scale, const bool symmetric) { + float zp_from_min = 0.f; + scale[0] = 0.f; + nudge_max[0] = 0.f; + nudge_min[0] = 0.f; + + float max_data = input_max[0]; + float min_data = input_min[0]; + if (symmetric) { + max_data = abs(input_min[0]) < input_max[0] ? input_max[0] : -input_min[0]; + min_data = abs(input_min[0]) < input_max[0] ? -input_max[0] : input_min[0]; + } + + if ((quant_max - quant_min) == 0 || (max_data - min_data) == 0) { + scale[0] = 0.f; + zp_from_min = 0.f; + } else { + scale[0] = (max_data - min_data) / (quant_max - quant_min); + zp_from_min = quant_min - min_data / scale[0]; + } + + float nudge_zp = 0.f; + if (zp_from_min <= quant_min) { + nudge_zp = quant_min; + } else if (zp_from_min >= quant_max) { + nudge_zp = quant_max; + } else { + nudge_zp = round(zp_from_min); + } + + nudge_min[0] = (quant_min - nudge_zp) * (scale[0]); + nudge_max[0] = (quant_max - nudge_zp) * (scale[0]); + return; +} + +cudaError_t CalFakeQuantPerLayer(const float *input, float *output, const int size, const float *nudge_min, + const float *nudge_max, const float *scale, cudaStream_t cuda_stream) { + FakeQuantPerLayer<<>>(input, output, size, nudge_min, nudge_max, + scale); + return GetCudaStatus(); +} + +cudaError_t CalFakeQuantPerLayerGrad(const float *input, const float *gradient, float *output, const int size, + const float *nudge_min, const float *nudge_max, cudaStream_t cuda_stream) { + FakeQuantPerLayerGrad<<>>(input, gradient, output, size, nudge_min, + nudge_max); + return GetCudaStatus(); +} + +cudaError_t CalNudgePerLayer(float *input_min, float *input_max, const float quant_min, const float quant_max, + float *nudge_min, float *nudge_max, float *scale, const bool symmetric, + cudaStream_t cuda_stream) { + NudgeMinMaxPerLayer<<<1, 1, 0, cuda_stream>>>(input_min, input_max, quant_min, quant_max, nudge_min, nudge_max, scale, + symmetric); + return GetCudaStatus(); +} diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_quant_perlayer_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_quant_perlayer_impl.cuh index 9796a2d1b39..a6ce4ef871e 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_quant_perlayer_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_quant_perlayer_impl.cuh @@ -1,33 +1,33 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FAKE_QUANT_PERLAYER_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FAKE_QUANT_PERLAYER_IMPL_CUH_ -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" - -CUDA_LIB_EXPORT cudaError_t CalNudgePerLayer(float *input_min, float *input_max, const float quant_min, - const float quant_max, float *nudge_min, float *nudge_max, float *scale, - const bool symmetric, cudaStream_t cuda_stream); - -CUDA_LIB_EXPORT cudaError_t CalFakeQuantPerLayer(const float *input, float *output, const int size, - const float *nudge_min, const float *nudge_max, const float *scale, - cudaStream_t cuda_stream); - -CUDA_LIB_EXPORT cudaError_t CalFakeQuantPerLayerGrad(const float *input, const float *gradient, float *output, - const int size, const float *nudge_min, const float *nudge_max, - cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FAKE_QUANT_PERLAYER_IMPL_CUH_ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FAKE_QUANT_PERLAYER_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FAKE_QUANT_PERLAYER_IMPL_CUH_ +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" + +CUDA_LIB_EXPORT cudaError_t CalNudgePerLayer(float *input_min, float *input_max, const float quant_min, + const float quant_max, float *nudge_min, float *nudge_max, float *scale, + const bool symmetric, cudaStream_t cuda_stream); + +CUDA_LIB_EXPORT cudaError_t CalFakeQuantPerLayer(const float *input, float *output, const int size, + const float *nudge_min, const float *nudge_max, const float *scale, + cudaStream_t cuda_stream); + +CUDA_LIB_EXPORT cudaError_t CalFakeQuantPerLayerGrad(const float *input, const float *gradient, float *output, + const int size, const float *nudge_min, const float *nudge_max, + cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FAKE_QUANT_PERLAYER_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/fill_diagonal_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/fill_diagonal_impl.cu index 83876b4815d..dd3b8f776a4 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/fill_diagonal_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/fill_diagonal_impl.cu @@ -1,60 +1,60 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "fill_diagonal_impl.cuh" - -#include "include/cuda_fp16.h" - -template -__global__ void FillDiagonal(const size_t size, const float fill_value, const int64_t step_size, T *output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - output[step_size * pos] = static_cast(fill_value); - } - return; -} - -template -cudaError_t CalFillDiagonal(const size_t size, const float fill_value, const int64_t step_size, T *output, - const uint32_t &device_id, cudaStream_t cuda_stream) { - FillDiagonal<<>>(size, fill_value, step_size, - output); - return GetCudaStatus(); -} - -template CUDA_LIB_EXPORT cudaError_t CalFillDiagonal(const size_t size, const float fill_value, - const int64_t step_size, half *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalFillDiagonal(const size_t size, const float fill_value, - const int64_t step_size, float *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalFillDiagonal(const size_t size, const float fill_value, - const int64_t step_size, double *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalFillDiagonal(const size_t size, const float fill_value, - const int64_t step_size, uint8_t *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalFillDiagonal(const size_t size, const float fill_value, - const int64_t step_size, int8_t *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalFillDiagonal(const size_t size, const float fill_value, - const int64_t step_size, int16_t *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalFillDiagonal(const size_t size, const float fill_value, - const int64_t step_size, int32_t *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalFillDiagonal(const size_t size, const float fill_value, - const int64_t step_size, int64_t *output, - const uint32_t &device_id, cudaStream_t cuda_stream); +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "fill_diagonal_impl.cuh" + +#include "include/cuda_fp16.h" + +template +__global__ void FillDiagonal(const size_t size, const float fill_value, const int64_t step_size, T *output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + output[step_size * pos] = static_cast(fill_value); + } + return; +} + +template +cudaError_t CalFillDiagonal(const size_t size, const float fill_value, const int64_t step_size, T *output, + const uint32_t &device_id, cudaStream_t cuda_stream) { + FillDiagonal<<>>(size, fill_value, step_size, + output); + return GetCudaStatus(); +} + +template CUDA_LIB_EXPORT cudaError_t CalFillDiagonal(const size_t size, const float fill_value, + const int64_t step_size, half *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalFillDiagonal(const size_t size, const float fill_value, + const int64_t step_size, float *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalFillDiagonal(const size_t size, const float fill_value, + const int64_t step_size, double *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalFillDiagonal(const size_t size, const float fill_value, + const int64_t step_size, uint8_t *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalFillDiagonal(const size_t size, const float fill_value, + const int64_t step_size, int8_t *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalFillDiagonal(const size_t size, const float fill_value, + const int64_t step_size, int16_t *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalFillDiagonal(const size_t size, const float fill_value, + const int64_t step_size, int32_t *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalFillDiagonal(const size_t size, const float fill_value, + const int64_t step_size, int64_t *output, + const uint32_t &device_id, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/fill_diagonal_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/fill_diagonal_impl.cuh index 312ac791e04..8e3e594d5bb 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/fill_diagonal_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/fill_diagonal_impl.cuh @@ -1,25 +1,25 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FILL_DIAGONAL_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FILL_DIAGONAL_IMPL_CUH_ -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" - -template -cudaError_t CalFillDiagonal(const size_t size, const float fill_value, const int64_t step_size, T *output, - const uint32_t &device_id, cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FILL_DIAGONAL_IMPL_CUH_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FILL_DIAGONAL_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FILL_DIAGONAL_IMPL_CUH_ +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" + +template +cudaError_t CalFillDiagonal(const size_t size, const float fill_value, const int64_t step_size, T *output, + const uint32_t &device_id, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FILL_DIAGONAL_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/grid_sampler_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/grid_sampler_impl.cu index 3303ef8f52a..8a96836590b 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/grid_sampler_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/grid_sampler_impl.cu @@ -1,345 +1,345 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/grid_sampler_impl.cuh" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh" - -template -__inline__ __device__ T GetInput(const T *input, size_t index) { - return input[index]; -} -__inline__ __device__ float GetInput(const half *input, size_t index) { return __half2float(input[index]); } - -template -__global__ void GridSampler2DKernel(const size_t size, const T *input_addr, const T *grid_addr, T *output_addr, - const size_t C, const size_t inp_H, const size_t inp_W, const size_t out_H, - const size_t out_W, const size_t inp_sN, const size_t inp_sC, const size_t inp_sH, - const size_t inp_sW, const size_t grid_sN, const size_t grid_sH, - const size_t grid_sW, const size_t grid_sCoor, const size_t out_sN, - const size_t out_sC, const size_t out_sH, const size_t out_sW, - GridSamplerInterpolationMode interpolation_mode, - GridSamplerPaddingMode padding_mode, bool align_corners) { - for (size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < size; index += blockDim.x * gridDim.x) { - const size_t w = index % out_W; - const size_t h = (index / out_W) % out_H; - const size_t n = index / (out_H * out_W); - const size_t grid_offset = n * grid_sN + h * grid_sH + w * grid_sW; - - // get the corresponding input x, y coordinates from grid - auto x = GetInput(grid_addr, grid_offset); - auto y = GetInput(grid_addr, grid_offset + grid_sCoor); - - // ItmType is the intermediate type for computing. - // If input type T is fp16, ItmType represents the upcasting type fp32 of T. Otherwise, im_type is the same as T. - using ItmType = decltype(x); - - ItmType ix = grid_sampler_compute_source_index(x, inp_W, padding_mode, align_corners); - ItmType iy = grid_sampler_compute_source_index(y, inp_H, padding_mode, align_corners); - - if (interpolation_mode == GridSamplerInterpolationMode::BILINEAR) { - // get NE, NW, SE, SW pixel values from (x, y) - int64_t ix_nw = static_cast(::floor(ix)); - int64_t iy_nw = static_cast(::floor(iy)); - int64_t ix_ne = ix_nw + 1; - int64_t iy_ne = iy_nw; - int64_t ix_sw = ix_nw; - int64_t iy_sw = iy_nw + 1; - int64_t ix_se = ix_nw + 1; - int64_t iy_se = iy_nw + 1; - - // get surfaces to each neighbor: - ItmType nw = (ix_se - ix) * (iy_se - iy); - ItmType ne = (ix - ix_sw) * (iy_sw - iy); - ItmType sw = (ix_ne - ix) * (iy - iy_ne); - ItmType se = (ix - ix_nw) * (iy - iy_nw); - - // calculate bilinear weighted pixel value and set output pixel - auto inp_ptr_NC = input_addr + n * inp_sN; - auto out_ptr_NCHW = output_addr + n * out_sN + h * out_sH + w * out_sW; - for (size_t c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) { - ItmType intermediate_value = 0; - if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) { - intermediate_value += GetInput(inp_ptr_NC, iy_nw * inp_sH + ix_nw * inp_sW) * nw; - } - if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) { - intermediate_value += GetInput(inp_ptr_NC, iy_ne * inp_sH + ix_ne * inp_sW) * ne; - } - if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) { - intermediate_value += GetInput(inp_ptr_NC, iy_sw * inp_sH + ix_sw * inp_sW) * sw; - } - if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) { - intermediate_value += GetInput(inp_ptr_NC, iy_se * inp_sH + ix_se * inp_sW) * se; - } - *out_ptr_NCHW = static_cast(intermediate_value); - } - } else if (interpolation_mode == GridSamplerInterpolationMode::NEAREST) { - int64_t ix_nearest = static_cast(::round(ix)); - int64_t iy_nearest = static_cast(::round(iy)); - - // assign nearest neighbor pixel value to output pixel - auto inp_ptr_NC = input_addr + n * inp_sN; - auto out_ptr_NCHW = output_addr + n * out_sN + h * out_sH + w * out_sW; - for (size_t c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) { - if (within_bounds_2d(iy_nearest, ix_nearest, inp_H, inp_W)) { - *out_ptr_NCHW = inp_ptr_NC[iy_nearest * inp_sH + ix_nearest * inp_sW]; - } else { - *out_ptr_NCHW = static_cast(0); - } - } - } else if (interpolation_mode == GridSamplerInterpolationMode::BICUBIC) { - ix = grid_sampler_unnormalize(x, inp_W, align_corners); - iy = grid_sampler_unnormalize(y, inp_H, align_corners); - - ItmType ix_nw = ::floor(ix); - ItmType iy_nw = ::floor(iy); - - const ItmType tx = ix - ix_nw; - const ItmType ty = iy - iy_nw; - - auto inp_ptr_NC = input_addr + n * inp_sN; - auto out_ptr_NCHW = output_addr + n * out_sN + h * out_sH + w * out_sW; - for (size_t c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) { - T coefficients[4]; - - for (size_t i = 0; i < 4; ++i) { - coefficients[i] = cubic_interp1d(get_value_bounded(inp_ptr_NC, ix_nw - 1, iy_nw - 1 + i, inp_W, inp_H, - inp_sW, inp_sH, padding_mode, align_corners), - get_value_bounded(inp_ptr_NC, ix_nw + 0, iy_nw - 1 + i, inp_W, inp_H, - inp_sW, inp_sH, padding_mode, align_corners), - get_value_bounded(inp_ptr_NC, ix_nw + 1, iy_nw - 1 + i, inp_W, inp_H, - inp_sW, inp_sH, padding_mode, align_corners), - get_value_bounded(inp_ptr_NC, ix_nw + 2, iy_nw - 1 + i, inp_W, inp_H, - inp_sW, inp_sH, padding_mode, align_corners), - tx); - } - - *out_ptr_NCHW = cubic_interp1d(coefficients[0], coefficients[1], coefficients[2], coefficients[3], ty); - } - } - } -} - -template -cudaError_t GridSampler2D(const size_t size, const T *input_addr, const T *grid_addr, T *output_addr, - const std::vector &input_shape, const std::vector &grid_shape, - const std::vector &output_shape, const std::vector &input_stride, - const std::vector &grid_stride, const std::vector &output_stride, - const GridSamplerInterpolationMode interpolation_mode, - const GridSamplerPaddingMode padding_mode, const bool align_corners, - cudaStream_t cuda_stream) { - size_t thread_per_block = 256; - size_t block_per_grid = (size + thread_per_block - 1) / thread_per_block; - GridSampler2DKernel<<>>( - size, input_addr, grid_addr, output_addr, input_shape[1], input_shape[2], input_shape[3], grid_shape[1], - grid_shape[2], input_stride[0], input_stride[1], input_stride[2], input_stride[3], grid_stride[0], grid_stride[1], - grid_stride[2], grid_stride[3], output_stride[0], output_stride[1], output_stride[2], output_stride[3], - interpolation_mode, padding_mode, align_corners); - return GetCudaStatus(); -} - -template CUDA_LIB_EXPORT cudaError_t -GridSampler2D(const size_t size, const half *input_addr, const half *grid_addr, half *output_addr, - const std::vector &input_shape, const std::vector &grid_shape, - const std::vector &output_shape, const std::vector &input_stride, - const std::vector &grid_stride, const std::vector &output_stride, - const GridSamplerInterpolationMode interpolation_mode, const GridSamplerPaddingMode padding_mode, - const bool align_corners, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t -GridSampler2D(const size_t size, const float *input_addr, const float *grid_addr, float *output_addr, - const std::vector &input_shape, const std::vector &grid_shape, - const std::vector &output_shape, const std::vector &input_stride, - const std::vector &grid_stride, const std::vector &output_stride, - const GridSamplerInterpolationMode interpolation_mode, const GridSamplerPaddingMode padding_mode, - const bool align_corners, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t -GridSampler2D(const size_t size, const double *input_addr, const double *grid_addr, double *output_addr, - const std::vector &input_shape, const std::vector &grid_shape, - const std::vector &output_shape, const std::vector &input_stride, - const std::vector &grid_stride, const std::vector &output_stride, - const GridSamplerInterpolationMode interpolation_mode, const GridSamplerPaddingMode padding_mode, - const bool align_corners, cudaStream_t cuda_stream); - -template -__global__ void GridSampler3DKernel(const size_t size, const T *input_addr, const T *grid_addr, T *output_addr, - const size_t C, const size_t inp_D, const size_t inp_H, const size_t inp_W, - const size_t out_D, const size_t out_H, const size_t out_W, const size_t inp_sN, - const size_t inp_sC, const size_t inp_sD, const size_t inp_sH, const size_t inp_sW, - const size_t grid_sN, const size_t grid_sD, const size_t grid_sH, - const size_t grid_sW, const size_t grid_sCoor, const size_t out_sN, - const size_t out_sC, const size_t out_sD, const size_t out_sH, const size_t out_sW, - GridSamplerInterpolationMode interpolation_mode, - GridSamplerPaddingMode padding_mode, bool align_corners) { - for (size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < size; index += blockDim.x * gridDim.x) { - const size_t w = index % out_W; - const size_t h = (index / out_W) % out_H; - const size_t d = (index / (out_H * out_W)) % out_D; - const size_t n = index / (out_D * out_H * out_W); - const size_t grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW; - - // get the corresponding input x, y, z coordinates from grid - auto x = GetInput(grid_addr, grid_offset); - auto y = GetInput(grid_addr, grid_offset + grid_sCoor); - auto z = GetInput(grid_addr, grid_offset + 2 * grid_sCoor); - - // ItmType is the intermediate type for computing. - // If input type T is fp16, ItmType represents the upcasting type fp32 of T. Otherwise, im_type is the same as T. - using ItmType = decltype(x); - - ItmType ix = grid_sampler_compute_source_index(x, inp_W, padding_mode, align_corners); - ItmType iy = grid_sampler_compute_source_index(y, inp_H, padding_mode, align_corners); - ItmType iz = grid_sampler_compute_source_index(z, inp_D, padding_mode, align_corners); - - if (interpolation_mode == GridSamplerInterpolationMode::BILINEAR) { - // get corner pixel values from (x, y, z) - // for 4d, we used north-east-south-west - // for 5d, we add top-bottom - int64_t ix_tnw = static_cast(::floor(ix)); - int64_t iy_tnw = static_cast(::floor(iy)); - int64_t iz_tnw = static_cast(::floor(iz)); - - int64_t ix_tne = ix_tnw + 1; - int64_t iy_tne = iy_tnw; - int64_t iz_tne = iz_tnw; - - int64_t ix_tsw = ix_tnw; - int64_t iy_tsw = iy_tnw + 1; - int64_t iz_tsw = iz_tnw; - - int64_t ix_tse = ix_tnw + 1; - int64_t iy_tse = iy_tnw + 1; - int64_t iz_tse = iz_tnw; - - int64_t ix_bnw = ix_tnw; - int64_t iy_bnw = iy_tnw; - int64_t iz_bnw = iz_tnw + 1; - - int64_t ix_bne = ix_tnw + 1; - int64_t iy_bne = iy_tnw; - int64_t iz_bne = iz_tnw + 1; - - int64_t ix_bsw = ix_tnw; - int64_t iy_bsw = iy_tnw + 1; - int64_t iz_bsw = iz_tnw + 1; - - int64_t ix_bse = ix_tnw + 1; - int64_t iy_bse = iy_tnw + 1; - int64_t iz_bse = iz_tnw + 1; - - // get surfaces to each neighbor: - ItmType tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz); - ItmType tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz); - ItmType tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz); - ItmType tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz); - ItmType bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse); - ItmType bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw); - ItmType bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne); - ItmType bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw); - - auto inp_ptr_NC = input_addr + n * inp_sN; - auto out_ptr_NCDHW = output_addr + n * out_sN + d * out_sD + h * out_sH + w * out_sW; - for (size_t c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) { - // (c, iz_tnw, iy_tnw, ix_tnw) * tnw + (c, iz_tne, iy_tne, ix_tne) * tne - // + (c, iz_tsw, iy_tsw, ix_tsw) * tsw + (c, iz_tse, iy_tse, ix_tse) * tse - // + (c, iz_bnw, iy_bnw, ix_bnw) * bnw + (c, iz_bne, iy_bne, ix_bne) * bne - // + (c, iz_bsw, iy_bsw, ix_bsw) * bsw + (c, iz_bse, iy_bse, ix_bse) * bse - ItmType intermediate_value = 0; - if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) { - intermediate_value += GetInput(inp_ptr_NC, iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW) * tnw; - } - if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) { - intermediate_value += GetInput(inp_ptr_NC, iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW) * tne; - } - if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) { - intermediate_value += GetInput(inp_ptr_NC, iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW) * tsw; - } - if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) { - intermediate_value += GetInput(inp_ptr_NC, iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW) * tse; - } - if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) { - intermediate_value += GetInput(inp_ptr_NC, iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW) * bnw; - } - if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) { - intermediate_value += GetInput(inp_ptr_NC, iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW) * bne; - } - if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) { - intermediate_value += GetInput(inp_ptr_NC, iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW) * bsw; - } - if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) { - intermediate_value += GetInput(inp_ptr_NC, iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW) * bse; - } - *out_ptr_NCDHW = static_cast(intermediate_value); - } - } else if (interpolation_mode == GridSamplerInterpolationMode::NEAREST) { - int64_t ix_nearest = static_cast(::round(ix)); - int64_t iy_nearest = static_cast(::round(iy)); - int64_t iz_nearest = static_cast(::round(iz)); - - // assign nearest neighbor pixel value to output pixel - auto inp_ptr_NC = input_addr + n * inp_sN; - auto out_ptr_NCDHW = output_addr + n * out_sN + d * out_sD + h * out_sH + w * out_sW; - for (size_t c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) { - if (within_bounds_3d(iz_nearest, iy_nearest, ix_nearest, inp_D, inp_H, inp_W)) { - *out_ptr_NCDHW = inp_ptr_NC[iz_nearest * inp_sD + iy_nearest * inp_sH + ix_nearest * inp_sW]; - } else { - *out_ptr_NCDHW = static_cast(0); - } - } - } - } -} - -template -cudaError_t GridSampler3D(const size_t size, const T *input_addr, const T *grid_addr, T *output_addr, - const std::vector &input_shape, const std::vector &grid_shape, - const std::vector &output_shape, const std::vector &input_stride, - const std::vector &grid_stride, const std::vector &output_stride, - const GridSamplerInterpolationMode interpolation_mode, - const GridSamplerPaddingMode padding_mode, const bool align_corners, - cudaStream_t cuda_stream) { - size_t thread_per_block = 256; - size_t block_per_grid = (size + thread_per_block - 1) / thread_per_block; - GridSampler3DKernel<<>>( - size, input_addr, grid_addr, output_addr, input_shape[1], input_shape[2], input_shape[3], input_shape[4], - grid_shape[1], grid_shape[2], grid_shape[3], input_stride[0], input_stride[1], input_stride[2], input_stride[3], - input_stride[4], grid_stride[0], grid_stride[1], grid_stride[2], grid_stride[3], grid_stride[4], output_stride[0], - output_stride[1], output_stride[2], output_stride[3], output_stride[4], interpolation_mode, padding_mode, - align_corners); - return GetCudaStatus(); -} - -template CUDA_LIB_EXPORT cudaError_t -GridSampler3D(const size_t size, const half *input_addr, const half *grid_addr, half *output_addr, - const std::vector &input_shape, const std::vector &grid_shape, - const std::vector &output_shape, const std::vector &input_stride, - const std::vector &grid_stride, const std::vector &output_stride, - const GridSamplerInterpolationMode interpolation_mode, const GridSamplerPaddingMode padding_mode, - const bool align_corners, cudaStream_t cuda_stream); - -template CUDA_LIB_EXPORT cudaError_t -GridSampler3D(const size_t size, const float *input_addr, const float *grid_addr, float *output_addr, - const std::vector &input_shape, const std::vector &grid_shape, - const std::vector &output_shape, const std::vector &input_stride, - const std::vector &grid_stride, const std::vector &output_stride, - const GridSamplerInterpolationMode interpolation_mode, const GridSamplerPaddingMode padding_mode, - const bool align_corners, cudaStream_t cuda_stream); - -template CUDA_LIB_EXPORT cudaError_t -GridSampler3D(const size_t size, const double *input_addr, const double *grid_addr, double *output_addr, - const std::vector &input_shape, const std::vector &grid_shape, - const std::vector &output_shape, const std::vector &input_stride, - const std::vector &grid_stride, const std::vector &output_stride, - const GridSamplerInterpolationMode interpolation_mode, const GridSamplerPaddingMode padding_mode, - const bool align_corners, cudaStream_t cuda_stream); +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/grid_sampler_impl.cuh" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh" + +template +__inline__ __device__ T GetInput(const T *input, size_t index) { + return input[index]; +} +__inline__ __device__ float GetInput(const half *input, size_t index) { return __half2float(input[index]); } + +template +__global__ void GridSampler2DKernel(const size_t size, const T *input_addr, const T *grid_addr, T *output_addr, + const size_t C, const size_t inp_H, const size_t inp_W, const size_t out_H, + const size_t out_W, const size_t inp_sN, const size_t inp_sC, const size_t inp_sH, + const size_t inp_sW, const size_t grid_sN, const size_t grid_sH, + const size_t grid_sW, const size_t grid_sCoor, const size_t out_sN, + const size_t out_sC, const size_t out_sH, const size_t out_sW, + GridSamplerInterpolationMode interpolation_mode, + GridSamplerPaddingMode padding_mode, bool align_corners) { + for (size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < size; index += blockDim.x * gridDim.x) { + const size_t w = index % out_W; + const size_t h = (index / out_W) % out_H; + const size_t n = index / (out_H * out_W); + const size_t grid_offset = n * grid_sN + h * grid_sH + w * grid_sW; + + // get the corresponding input x, y coordinates from grid + auto x = GetInput(grid_addr, grid_offset); + auto y = GetInput(grid_addr, grid_offset + grid_sCoor); + + // ItmType is the intermediate type for computing. + // If input type T is fp16, ItmType represents the upcasting type fp32 of T. Otherwise, im_type is the same as T. + using ItmType = decltype(x); + + ItmType ix = grid_sampler_compute_source_index(x, inp_W, padding_mode, align_corners); + ItmType iy = grid_sampler_compute_source_index(y, inp_H, padding_mode, align_corners); + + if (interpolation_mode == GridSamplerInterpolationMode::BILINEAR) { + // get NE, NW, SE, SW pixel values from (x, y) + int64_t ix_nw = static_cast(::floor(ix)); + int64_t iy_nw = static_cast(::floor(iy)); + int64_t ix_ne = ix_nw + 1; + int64_t iy_ne = iy_nw; + int64_t ix_sw = ix_nw; + int64_t iy_sw = iy_nw + 1; + int64_t ix_se = ix_nw + 1; + int64_t iy_se = iy_nw + 1; + + // get surfaces to each neighbor: + ItmType nw = (ix_se - ix) * (iy_se - iy); + ItmType ne = (ix - ix_sw) * (iy_sw - iy); + ItmType sw = (ix_ne - ix) * (iy - iy_ne); + ItmType se = (ix - ix_nw) * (iy - iy_nw); + + // calculate bilinear weighted pixel value and set output pixel + auto inp_ptr_NC = input_addr + n * inp_sN; + auto out_ptr_NCHW = output_addr + n * out_sN + h * out_sH + w * out_sW; + for (size_t c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) { + ItmType intermediate_value = 0; + if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) { + intermediate_value += GetInput(inp_ptr_NC, iy_nw * inp_sH + ix_nw * inp_sW) * nw; + } + if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) { + intermediate_value += GetInput(inp_ptr_NC, iy_ne * inp_sH + ix_ne * inp_sW) * ne; + } + if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) { + intermediate_value += GetInput(inp_ptr_NC, iy_sw * inp_sH + ix_sw * inp_sW) * sw; + } + if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) { + intermediate_value += GetInput(inp_ptr_NC, iy_se * inp_sH + ix_se * inp_sW) * se; + } + *out_ptr_NCHW = static_cast(intermediate_value); + } + } else if (interpolation_mode == GridSamplerInterpolationMode::NEAREST) { + int64_t ix_nearest = static_cast(::round(ix)); + int64_t iy_nearest = static_cast(::round(iy)); + + // assign nearest neighbor pixel value to output pixel + auto inp_ptr_NC = input_addr + n * inp_sN; + auto out_ptr_NCHW = output_addr + n * out_sN + h * out_sH + w * out_sW; + for (size_t c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) { + if (within_bounds_2d(iy_nearest, ix_nearest, inp_H, inp_W)) { + *out_ptr_NCHW = inp_ptr_NC[iy_nearest * inp_sH + ix_nearest * inp_sW]; + } else { + *out_ptr_NCHW = static_cast(0); + } + } + } else if (interpolation_mode == GridSamplerInterpolationMode::BICUBIC) { + ix = grid_sampler_unnormalize(x, inp_W, align_corners); + iy = grid_sampler_unnormalize(y, inp_H, align_corners); + + ItmType ix_nw = ::floor(ix); + ItmType iy_nw = ::floor(iy); + + const ItmType tx = ix - ix_nw; + const ItmType ty = iy - iy_nw; + + auto inp_ptr_NC = input_addr + n * inp_sN; + auto out_ptr_NCHW = output_addr + n * out_sN + h * out_sH + w * out_sW; + for (size_t c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) { + T coefficients[4]; + + for (size_t i = 0; i < 4; ++i) { + coefficients[i] = cubic_interp1d(get_value_bounded(inp_ptr_NC, ix_nw - 1, iy_nw - 1 + i, inp_W, inp_H, + inp_sW, inp_sH, padding_mode, align_corners), + get_value_bounded(inp_ptr_NC, ix_nw + 0, iy_nw - 1 + i, inp_W, inp_H, + inp_sW, inp_sH, padding_mode, align_corners), + get_value_bounded(inp_ptr_NC, ix_nw + 1, iy_nw - 1 + i, inp_W, inp_H, + inp_sW, inp_sH, padding_mode, align_corners), + get_value_bounded(inp_ptr_NC, ix_nw + 2, iy_nw - 1 + i, inp_W, inp_H, + inp_sW, inp_sH, padding_mode, align_corners), + tx); + } + + *out_ptr_NCHW = cubic_interp1d(coefficients[0], coefficients[1], coefficients[2], coefficients[3], ty); + } + } + } +} + +template +cudaError_t GridSampler2D(const size_t size, const T *input_addr, const T *grid_addr, T *output_addr, + const std::vector &input_shape, const std::vector &grid_shape, + const std::vector &output_shape, const std::vector &input_stride, + const std::vector &grid_stride, const std::vector &output_stride, + const GridSamplerInterpolationMode interpolation_mode, + const GridSamplerPaddingMode padding_mode, const bool align_corners, + cudaStream_t cuda_stream) { + size_t thread_per_block = 256; + size_t block_per_grid = (size + thread_per_block - 1) / thread_per_block; + GridSampler2DKernel<<>>( + size, input_addr, grid_addr, output_addr, input_shape[1], input_shape[2], input_shape[3], grid_shape[1], + grid_shape[2], input_stride[0], input_stride[1], input_stride[2], input_stride[3], grid_stride[0], grid_stride[1], + grid_stride[2], grid_stride[3], output_stride[0], output_stride[1], output_stride[2], output_stride[3], + interpolation_mode, padding_mode, align_corners); + return GetCudaStatus(); +} + +template CUDA_LIB_EXPORT cudaError_t +GridSampler2D(const size_t size, const half *input_addr, const half *grid_addr, half *output_addr, + const std::vector &input_shape, const std::vector &grid_shape, + const std::vector &output_shape, const std::vector &input_stride, + const std::vector &grid_stride, const std::vector &output_stride, + const GridSamplerInterpolationMode interpolation_mode, const GridSamplerPaddingMode padding_mode, + const bool align_corners, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t +GridSampler2D(const size_t size, const float *input_addr, const float *grid_addr, float *output_addr, + const std::vector &input_shape, const std::vector &grid_shape, + const std::vector &output_shape, const std::vector &input_stride, + const std::vector &grid_stride, const std::vector &output_stride, + const GridSamplerInterpolationMode interpolation_mode, const GridSamplerPaddingMode padding_mode, + const bool align_corners, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t +GridSampler2D(const size_t size, const double *input_addr, const double *grid_addr, double *output_addr, + const std::vector &input_shape, const std::vector &grid_shape, + const std::vector &output_shape, const std::vector &input_stride, + const std::vector &grid_stride, const std::vector &output_stride, + const GridSamplerInterpolationMode interpolation_mode, const GridSamplerPaddingMode padding_mode, + const bool align_corners, cudaStream_t cuda_stream); + +template +__global__ void GridSampler3DKernel(const size_t size, const T *input_addr, const T *grid_addr, T *output_addr, + const size_t C, const size_t inp_D, const size_t inp_H, const size_t inp_W, + const size_t out_D, const size_t out_H, const size_t out_W, const size_t inp_sN, + const size_t inp_sC, const size_t inp_sD, const size_t inp_sH, const size_t inp_sW, + const size_t grid_sN, const size_t grid_sD, const size_t grid_sH, + const size_t grid_sW, const size_t grid_sCoor, const size_t out_sN, + const size_t out_sC, const size_t out_sD, const size_t out_sH, const size_t out_sW, + GridSamplerInterpolationMode interpolation_mode, + GridSamplerPaddingMode padding_mode, bool align_corners) { + for (size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < size; index += blockDim.x * gridDim.x) { + const size_t w = index % out_W; + const size_t h = (index / out_W) % out_H; + const size_t d = (index / (out_H * out_W)) % out_D; + const size_t n = index / (out_D * out_H * out_W); + const size_t grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW; + + // get the corresponding input x, y, z coordinates from grid + auto x = GetInput(grid_addr, grid_offset); + auto y = GetInput(grid_addr, grid_offset + grid_sCoor); + auto z = GetInput(grid_addr, grid_offset + 2 * grid_sCoor); + + // ItmType is the intermediate type for computing. + // If input type T is fp16, ItmType represents the upcasting type fp32 of T. Otherwise, im_type is the same as T. + using ItmType = decltype(x); + + ItmType ix = grid_sampler_compute_source_index(x, inp_W, padding_mode, align_corners); + ItmType iy = grid_sampler_compute_source_index(y, inp_H, padding_mode, align_corners); + ItmType iz = grid_sampler_compute_source_index(z, inp_D, padding_mode, align_corners); + + if (interpolation_mode == GridSamplerInterpolationMode::BILINEAR) { + // get corner pixel values from (x, y, z) + // for 4d, we used north-east-south-west + // for 5d, we add top-bottom + int64_t ix_tnw = static_cast(::floor(ix)); + int64_t iy_tnw = static_cast(::floor(iy)); + int64_t iz_tnw = static_cast(::floor(iz)); + + int64_t ix_tne = ix_tnw + 1; + int64_t iy_tne = iy_tnw; + int64_t iz_tne = iz_tnw; + + int64_t ix_tsw = ix_tnw; + int64_t iy_tsw = iy_tnw + 1; + int64_t iz_tsw = iz_tnw; + + int64_t ix_tse = ix_tnw + 1; + int64_t iy_tse = iy_tnw + 1; + int64_t iz_tse = iz_tnw; + + int64_t ix_bnw = ix_tnw; + int64_t iy_bnw = iy_tnw; + int64_t iz_bnw = iz_tnw + 1; + + int64_t ix_bne = ix_tnw + 1; + int64_t iy_bne = iy_tnw; + int64_t iz_bne = iz_tnw + 1; + + int64_t ix_bsw = ix_tnw; + int64_t iy_bsw = iy_tnw + 1; + int64_t iz_bsw = iz_tnw + 1; + + int64_t ix_bse = ix_tnw + 1; + int64_t iy_bse = iy_tnw + 1; + int64_t iz_bse = iz_tnw + 1; + + // get surfaces to each neighbor: + ItmType tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz); + ItmType tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz); + ItmType tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz); + ItmType tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz); + ItmType bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse); + ItmType bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw); + ItmType bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne); + ItmType bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw); + + auto inp_ptr_NC = input_addr + n * inp_sN; + auto out_ptr_NCDHW = output_addr + n * out_sN + d * out_sD + h * out_sH + w * out_sW; + for (size_t c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) { + // (c, iz_tnw, iy_tnw, ix_tnw) * tnw + (c, iz_tne, iy_tne, ix_tne) * tne + // + (c, iz_tsw, iy_tsw, ix_tsw) * tsw + (c, iz_tse, iy_tse, ix_tse) * tse + // + (c, iz_bnw, iy_bnw, ix_bnw) * bnw + (c, iz_bne, iy_bne, ix_bne) * bne + // + (c, iz_bsw, iy_bsw, ix_bsw) * bsw + (c, iz_bse, iy_bse, ix_bse) * bse + ItmType intermediate_value = 0; + if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) { + intermediate_value += GetInput(inp_ptr_NC, iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW) * tnw; + } + if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) { + intermediate_value += GetInput(inp_ptr_NC, iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW) * tne; + } + if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) { + intermediate_value += GetInput(inp_ptr_NC, iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW) * tsw; + } + if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) { + intermediate_value += GetInput(inp_ptr_NC, iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW) * tse; + } + if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) { + intermediate_value += GetInput(inp_ptr_NC, iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW) * bnw; + } + if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) { + intermediate_value += GetInput(inp_ptr_NC, iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW) * bne; + } + if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) { + intermediate_value += GetInput(inp_ptr_NC, iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW) * bsw; + } + if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) { + intermediate_value += GetInput(inp_ptr_NC, iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW) * bse; + } + *out_ptr_NCDHW = static_cast(intermediate_value); + } + } else if (interpolation_mode == GridSamplerInterpolationMode::NEAREST) { + int64_t ix_nearest = static_cast(::round(ix)); + int64_t iy_nearest = static_cast(::round(iy)); + int64_t iz_nearest = static_cast(::round(iz)); + + // assign nearest neighbor pixel value to output pixel + auto inp_ptr_NC = input_addr + n * inp_sN; + auto out_ptr_NCDHW = output_addr + n * out_sN + d * out_sD + h * out_sH + w * out_sW; + for (size_t c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) { + if (within_bounds_3d(iz_nearest, iy_nearest, ix_nearest, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW = inp_ptr_NC[iz_nearest * inp_sD + iy_nearest * inp_sH + ix_nearest * inp_sW]; + } else { + *out_ptr_NCDHW = static_cast(0); + } + } + } + } +} + +template +cudaError_t GridSampler3D(const size_t size, const T *input_addr, const T *grid_addr, T *output_addr, + const std::vector &input_shape, const std::vector &grid_shape, + const std::vector &output_shape, const std::vector &input_stride, + const std::vector &grid_stride, const std::vector &output_stride, + const GridSamplerInterpolationMode interpolation_mode, + const GridSamplerPaddingMode padding_mode, const bool align_corners, + cudaStream_t cuda_stream) { + size_t thread_per_block = 256; + size_t block_per_grid = (size + thread_per_block - 1) / thread_per_block; + GridSampler3DKernel<<>>( + size, input_addr, grid_addr, output_addr, input_shape[1], input_shape[2], input_shape[3], input_shape[4], + grid_shape[1], grid_shape[2], grid_shape[3], input_stride[0], input_stride[1], input_stride[2], input_stride[3], + input_stride[4], grid_stride[0], grid_stride[1], grid_stride[2], grid_stride[3], grid_stride[4], output_stride[0], + output_stride[1], output_stride[2], output_stride[3], output_stride[4], interpolation_mode, padding_mode, + align_corners); + return GetCudaStatus(); +} + +template CUDA_LIB_EXPORT cudaError_t +GridSampler3D(const size_t size, const half *input_addr, const half *grid_addr, half *output_addr, + const std::vector &input_shape, const std::vector &grid_shape, + const std::vector &output_shape, const std::vector &input_stride, + const std::vector &grid_stride, const std::vector &output_stride, + const GridSamplerInterpolationMode interpolation_mode, const GridSamplerPaddingMode padding_mode, + const bool align_corners, cudaStream_t cuda_stream); + +template CUDA_LIB_EXPORT cudaError_t +GridSampler3D(const size_t size, const float *input_addr, const float *grid_addr, float *output_addr, + const std::vector &input_shape, const std::vector &grid_shape, + const std::vector &output_shape, const std::vector &input_stride, + const std::vector &grid_stride, const std::vector &output_stride, + const GridSamplerInterpolationMode interpolation_mode, const GridSamplerPaddingMode padding_mode, + const bool align_corners, cudaStream_t cuda_stream); + +template CUDA_LIB_EXPORT cudaError_t +GridSampler3D(const size_t size, const double *input_addr, const double *grid_addr, double *output_addr, + const std::vector &input_shape, const std::vector &grid_shape, + const std::vector &output_shape, const std::vector &input_stride, + const std::vector &grid_stride, const std::vector &output_stride, + const GridSamplerInterpolationMode interpolation_mode, const GridSamplerPaddingMode padding_mode, + const bool align_corners, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/grid_sampler_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/grid_sampler_impl.cuh index f2ddd45ecdb..861ce207f1c 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/grid_sampler_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/grid_sampler_impl.cuh @@ -1,207 +1,207 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_GRID_SAMPLER_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_GRID_SAMPLER_CUH_ -#include -#include -#include -#include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" - -enum GridSamplerInterpolationMode { BILINEAR = 0, NEAREST, BICUBIC }; - -enum GridSamplerPaddingMode { ZEROS = 0, BORDER, REFLECTION }; - -static std::map kGridSamplerInterpolationMap{ - {"bilinear", GridSamplerInterpolationMode::BILINEAR}, - {"nearest", GridSamplerInterpolationMode::NEAREST}, - {"bicubic", GridSamplerInterpolationMode::BICUBIC}}; - -static std::map kGridSamplerPaddingMap{ - {"zeros", GridSamplerPaddingMode::ZEROS}, - {"border", GridSamplerPaddingMode::BORDER}, - {"reflection", GridSamplerPaddingMode::REFLECTION}}; - -template -CUDA_LIB_EXPORT cudaError_t GridSampler2D( - const size_t size, const T *input_addr, const T *grid_addr, T *output_addr, const std::vector &input_shape, - const std::vector &grid_shape, const std::vector &output_shape, - const std::vector &input_stride, const std::vector &grid_stride, - const std::vector &output_stride, const GridSamplerInterpolationMode interpolation_mode, - const GridSamplerPaddingMode padding_mode, const bool align_corners, cudaStream_t stream); - -template -CUDA_LIB_EXPORT cudaError_t GridSampler3D( - const size_t size, const T *input_addr, const T *grid_addr, T *output_addr, const std::vector &input_shape, - const std::vector &grid_shape, const std::vector &output_shape, - const std::vector &input_stride, const std::vector &grid_stride, - const std::vector &output_stride, const GridSamplerInterpolationMode interpolation_mode, - const GridSamplerPaddingMode padding_mode, const bool align_corners, cudaStream_t stream); - -template -static __forceinline__ __device__ T clip_coordinates(T in, int clip_limit) { - in = in > static_cast(0) ? in : static_cast(0); - return static_cast(clip_limit - 1) < in ? static_cast(clip_limit - 1) : in; -} - -template -static __forceinline__ __device__ T reflect_coordinates(T in, int twice_low, int twice_high) { - if (twice_low != twice_high) { - T min = static_cast(twice_low) / 2; - T span = static_cast(twice_high - twice_low) / 2; - in = ::fabs(in - min); - // `fmod` returns same sign as `in`, which is positive after the `fabs` above. - T extra = ::fmod(in, span); - int flips = static_cast(::floor(in / span)); - if (flips % 2 != 0) { - return span - extra + min; - } else { - return extra + min; - } - } else { - return static_cast(0); - } -} - -template -static __forceinline__ __device__ half reflect_coordinates(half in, int twice_low, int twice_high) { - if (twice_low != twice_high) { - float min = static_cast(twice_low) / 2; - float span = static_cast(twice_high - twice_low) / 2; - float new_in = ::fabs(__half2float(in) - min); - // `fmod` returns same sign as `in`, which is positive after the `fabs` above. - float extra = ::fmod(in, span); - int flips = static_cast(::floor(new_in / span)); - if (flips % 2 != 0) { - return __float2half(span - extra + min); - } else { - return __float2half(extra + min); - } - } else { - return static_cast(0.0); - } -} - -template -static __forceinline__ __device__ T safe_downgrade_to_int_range(T x) { - if (x > static_cast(INT_MAX - 1) || x < static_cast(INT_MIN) || !::isfinite(static_cast(x))) { - return static_cast(-100.0); - } else { - return x; - } -} - -template -__device__ __forceinline__ static T cubic_convolution_one(T x, const T A) { - return ((A + 2) * x - (A + 3)) * x * x + 1; -} - -template -__device__ __forceinline__ static T cubic_convolution_two(T x, const T A) { - return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A; -} - -static __forceinline__ __device__ bool within_bounds_2d(int h, int w, int H, int W) { - return w >= 0 && w < W && h >= 0 && h < H; -} - -static __forceinline__ __device__ bool within_bounds_3d(int d, int h, int w, int D, int H, int W) { - return w >= 0 && w < W && h >= 0 && h < H && d >= 0 && d < D; -} - -template -static __forceinline__ __device__ T grid_sampler_unnormalize(T coord, const int size, bool align_corners) { - if (!align_corners) { - // unnormalize coord from [-1, 1] to [-0.5, size - 0.5] - return ((coord + 1.f) * size - 1) / 2; - } else { - // unnormalize coord from [-1, 1] to [0, size - 1] - return ((coord + 1.f) / 2) * (size - 1); - } -} - -template -static __forceinline__ __device__ T compute_coordinates(T coord, const size_t size, GridSamplerPaddingMode padding_mode, - bool align_corners) { - if (padding_mode == GridSamplerPaddingMode::REFLECTION) { - if (!align_corners) { - coord = reflect_coordinates(coord, -1, 2 * size - 1); - } else { - coord = reflect_coordinates(coord, 0, 2 * (size - 1)); - } - coord = clip_coordinates(coord, size); - } else if (padding_mode == GridSamplerPaddingMode::BORDER) { - coord = clip_coordinates(coord, size); - } - - coord = safe_downgrade_to_int_range(coord); - return coord; -} - -template -static __forceinline__ __device__ T grid_sampler_compute_source_index(T coord, size_t size, - GridSamplerPaddingMode padding_mode, - bool align_corners) { - coord = compute_coordinates(grid_sampler_unnormalize(coord, size, align_corners), size, padding_mode, align_corners); - return coord; -} - -template -__device__ __forceinline__ static void get_cubic_upsampling_coefficients(T coeffs[4], T t) { - const T A = -0.75; - - // opposite coefficients - T op_x = 1.0 - t; - coeffs[2] = cubic_convolution_one(op_x, A); - coeffs[3] = cubic_convolution_two(op_x + 1.0, A); - - T x = t; - coeffs[0] = cubic_convolution_two(x + 1.0, A); - coeffs[1] = cubic_convolution_one(x, A); -} - -template -__device__ __forceinline__ static T cubic_interp1d(T x0, T x1, T x2, T x3, S t) { - S coeffs[4]; - get_cubic_upsampling_coefficients(coeffs, t); - - return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3]; -} - -template -__device__ __forceinline__ static half cubic_interp1d(half x0, half x1, half x2, half x3, S t) { - S coeffs[4]; - get_cubic_upsampling_coefficients(coeffs, t); - - return __float2half(__half2float(x0) * coeffs[0] + __half2float(x1) * coeffs[1] + __half2float(x2) * coeffs[2] + - __half2float(x3) * coeffs[3]); -} - -template -static __forceinline__ __device__ T get_value_bounded(const T *data, T x, T y, const size_t W, const size_t H, - const size_t sW, const size_t sH, - GridSamplerPaddingMode padding_mode, bool align_corners) { - int ix = static_cast(compute_coordinates(x, W, padding_mode, align_corners)); - int iy = static_cast(compute_coordinates(y, H, padding_mode, align_corners)); - - if (within_bounds_2d(iy, ix, H, W)) { - return data[iy * sH + ix * sW]; - } - return static_cast(0); -} - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_GRID_SAMPLER_CUH_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_GRID_SAMPLER_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_GRID_SAMPLER_CUH_ +#include +#include +#include +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" + +enum GridSamplerInterpolationMode { BILINEAR = 0, NEAREST, BICUBIC }; + +enum GridSamplerPaddingMode { ZEROS = 0, BORDER, REFLECTION }; + +static std::map kGridSamplerInterpolationMap{ + {"bilinear", GridSamplerInterpolationMode::BILINEAR}, + {"nearest", GridSamplerInterpolationMode::NEAREST}, + {"bicubic", GridSamplerInterpolationMode::BICUBIC}}; + +static std::map kGridSamplerPaddingMap{ + {"zeros", GridSamplerPaddingMode::ZEROS}, + {"border", GridSamplerPaddingMode::BORDER}, + {"reflection", GridSamplerPaddingMode::REFLECTION}}; + +template +CUDA_LIB_EXPORT cudaError_t GridSampler2D( + const size_t size, const T *input_addr, const T *grid_addr, T *output_addr, const std::vector &input_shape, + const std::vector &grid_shape, const std::vector &output_shape, + const std::vector &input_stride, const std::vector &grid_stride, + const std::vector &output_stride, const GridSamplerInterpolationMode interpolation_mode, + const GridSamplerPaddingMode padding_mode, const bool align_corners, cudaStream_t stream); + +template +CUDA_LIB_EXPORT cudaError_t GridSampler3D( + const size_t size, const T *input_addr, const T *grid_addr, T *output_addr, const std::vector &input_shape, + const std::vector &grid_shape, const std::vector &output_shape, + const std::vector &input_stride, const std::vector &grid_stride, + const std::vector &output_stride, const GridSamplerInterpolationMode interpolation_mode, + const GridSamplerPaddingMode padding_mode, const bool align_corners, cudaStream_t stream); + +template +static __forceinline__ __device__ T clip_coordinates(T in, int clip_limit) { + in = in > static_cast(0) ? in : static_cast(0); + return static_cast(clip_limit - 1) < in ? static_cast(clip_limit - 1) : in; +} + +template +static __forceinline__ __device__ T reflect_coordinates(T in, int twice_low, int twice_high) { + if (twice_low != twice_high) { + T min = static_cast(twice_low) / 2; + T span = static_cast(twice_high - twice_low) / 2; + in = ::fabs(in - min); + // `fmod` returns same sign as `in`, which is positive after the `fabs` above. + T extra = ::fmod(in, span); + int flips = static_cast(::floor(in / span)); + if (flips % 2 != 0) { + return span - extra + min; + } else { + return extra + min; + } + } else { + return static_cast(0); + } +} + +template +static __forceinline__ __device__ half reflect_coordinates(half in, int twice_low, int twice_high) { + if (twice_low != twice_high) { + float min = static_cast(twice_low) / 2; + float span = static_cast(twice_high - twice_low) / 2; + float new_in = ::fabs(__half2float(in) - min); + // `fmod` returns same sign as `in`, which is positive after the `fabs` above. + float extra = ::fmod(in, span); + int flips = static_cast(::floor(new_in / span)); + if (flips % 2 != 0) { + return __float2half(span - extra + min); + } else { + return __float2half(extra + min); + } + } else { + return static_cast(0.0); + } +} + +template +static __forceinline__ __device__ T safe_downgrade_to_int_range(T x) { + if (x > static_cast(INT_MAX - 1) || x < static_cast(INT_MIN) || !::isfinite(static_cast(x))) { + return static_cast(-100.0); + } else { + return x; + } +} + +template +__device__ __forceinline__ static T cubic_convolution_one(T x, const T A) { + return ((A + 2) * x - (A + 3)) * x * x + 1; +} + +template +__device__ __forceinline__ static T cubic_convolution_two(T x, const T A) { + return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A; +} + +static __forceinline__ __device__ bool within_bounds_2d(int h, int w, int H, int W) { + return w >= 0 && w < W && h >= 0 && h < H; +} + +static __forceinline__ __device__ bool within_bounds_3d(int d, int h, int w, int D, int H, int W) { + return w >= 0 && w < W && h >= 0 && h < H && d >= 0 && d < D; +} + +template +static __forceinline__ __device__ T grid_sampler_unnormalize(T coord, const int size, bool align_corners) { + if (!align_corners) { + // unnormalize coord from [-1, 1] to [-0.5, size - 0.5] + return ((coord + 1.f) * size - 1) / 2; + } else { + // unnormalize coord from [-1, 1] to [0, size - 1] + return ((coord + 1.f) / 2) * (size - 1); + } +} + +template +static __forceinline__ __device__ T compute_coordinates(T coord, const size_t size, GridSamplerPaddingMode padding_mode, + bool align_corners) { + if (padding_mode == GridSamplerPaddingMode::REFLECTION) { + if (!align_corners) { + coord = reflect_coordinates(coord, -1, 2 * size - 1); + } else { + coord = reflect_coordinates(coord, 0, 2 * (size - 1)); + } + coord = clip_coordinates(coord, size); + } else if (padding_mode == GridSamplerPaddingMode::BORDER) { + coord = clip_coordinates(coord, size); + } + + coord = safe_downgrade_to_int_range(coord); + return coord; +} + +template +static __forceinline__ __device__ T grid_sampler_compute_source_index(T coord, size_t size, + GridSamplerPaddingMode padding_mode, + bool align_corners) { + coord = compute_coordinates(grid_sampler_unnormalize(coord, size, align_corners), size, padding_mode, align_corners); + return coord; +} + +template +__device__ __forceinline__ static void get_cubic_upsampling_coefficients(T coeffs[4], T t) { + const T A = -0.75; + + // opposite coefficients + T op_x = 1.0 - t; + coeffs[2] = cubic_convolution_one(op_x, A); + coeffs[3] = cubic_convolution_two(op_x + 1.0, A); + + T x = t; + coeffs[0] = cubic_convolution_two(x + 1.0, A); + coeffs[1] = cubic_convolution_one(x, A); +} + +template +__device__ __forceinline__ static T cubic_interp1d(T x0, T x1, T x2, T x3, S t) { + S coeffs[4]; + get_cubic_upsampling_coefficients(coeffs, t); + + return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3]; +} + +template +__device__ __forceinline__ static half cubic_interp1d(half x0, half x1, half x2, half x3, S t) { + S coeffs[4]; + get_cubic_upsampling_coefficients(coeffs, t); + + return __float2half(__half2float(x0) * coeffs[0] + __half2float(x1) * coeffs[1] + __half2float(x2) * coeffs[2] + + __half2float(x3) * coeffs[3]); +} + +template +static __forceinline__ __device__ T get_value_bounded(const T *data, T x, T y, const size_t W, const size_t H, + const size_t sW, const size_t sH, + GridSamplerPaddingMode padding_mode, bool align_corners) { + int ix = static_cast(compute_coordinates(x, W, padding_mode, align_corners)); + int iy = static_cast(compute_coordinates(y, H, padding_mode, align_corners)); + + if (within_bounds_2d(iy, ix, H, W)) { + return data[iy * sH + ix * sW]; + } + return static_cast(0); +} + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_GRID_SAMPLER_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/hamming_window_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/hamming_window_impl.cu index 770d1eb9b93..d0936358e7f 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/hamming_window_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/hamming_window_impl.cu @@ -1,135 +1,135 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include "hamming_window_impl.cuh" - -template -__global__ void HammingWindowOne(const size_t size, const double N, const double PI, const float alpha, - const float beta, S *output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - output[pos] = static_cast(1); - } - return; -} - -template -__global__ void HammingWindow(const size_t size, const double N, const double PI, const float alpha, const float beta, - S *output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - double out = alpha - beta * cos((2 * pos * PI) / (N - 1)); - output[pos] = static_cast(out); - } - return; -} - -template -cudaError_t HammingWindow(const size_t size, T N, const float alpha, const float beta, const bool periodic, S *output, - const uint32_t &device_id, cudaStream_t cuda_stream) { - const double PI = acos(-1); - if (N == 1) { - HammingWindowOne<<>>(size, N, PI, alpha, - beta, output); - } else { - N = periodic ? static_cast(N + 1) : static_cast(N); - HammingWindow<<>>(size, N, PI, alpha, beta, - output); - } - return GetCudaStatus(); -} - -template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, int8_t N, const float alpha, - const float beta, const bool periodic, half *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, int16_t N, const float alpha, - const float beta, const bool periodic, half *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, int32_t N, const float alpha, - const float beta, const bool periodic, half *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, int64_t N, const float alpha, - const float beta, const bool periodic, half *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, uint8_t N, const float alpha, - const float beta, const bool periodic, half *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, uint16_t N, const float alpha, - const float beta, const bool periodic, half *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, uint32_t N, const float alpha, - const float beta, const bool periodic, half *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, uint64_t N, const float alpha, - const float beta, const bool periodic, half *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, int8_t N, const float alpha, - const float beta, const bool periodic, float *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, int16_t N, const float alpha, - const float beta, const bool periodic, float *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, int32_t N, const float alpha, - const float beta, const bool periodic, float *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, int64_t N, const float alpha, - const float beta, const bool periodic, float *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, uint8_t N, const float alpha, - const float beta, const bool periodic, float *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, uint16_t N, const float alpha, - const float beta, const bool periodic, - float *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, uint32_t N, const float alpha, - const float beta, const bool periodic, - float *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, uint64_t N, const float alpha, - const float beta, const bool periodic, - float *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, int8_t N, const float alpha, - const float beta, const bool periodic, - double *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, int16_t N, const float alpha, - const float beta, const bool periodic, - double *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, int32_t N, const float alpha, - const float beta, const bool periodic, - double *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, int64_t N, const float alpha, - const float beta, const bool periodic, - double *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, uint8_t N, const float alpha, - const float beta, const bool periodic, - double *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, uint16_t N, const float alpha, - const float beta, const bool periodic, - double *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, uint32_t N, const float alpha, - const float beta, const bool periodic, - double *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, uint64_t N, const float alpha, - const float beta, const bool periodic, - double *output, const uint32_t &device_id, - cudaStream_t cuda_stream); +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "hamming_window_impl.cuh" + +template +__global__ void HammingWindowOne(const size_t size, const double N, const double PI, const float alpha, + const float beta, S *output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + output[pos] = static_cast(1); + } + return; +} + +template +__global__ void HammingWindow(const size_t size, const double N, const double PI, const float alpha, const float beta, + S *output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + double out = alpha - beta * cos((2 * pos * PI) / (N - 1)); + output[pos] = static_cast(out); + } + return; +} + +template +cudaError_t HammingWindow(const size_t size, T N, const float alpha, const float beta, const bool periodic, S *output, + const uint32_t &device_id, cudaStream_t cuda_stream) { + const double PI = acos(-1); + if (N == 1) { + HammingWindowOne<<>>(size, N, PI, alpha, + beta, output); + } else { + N = periodic ? static_cast(N + 1) : static_cast(N); + HammingWindow<<>>(size, N, PI, alpha, beta, + output); + } + return GetCudaStatus(); +} + +template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, int8_t N, const float alpha, + const float beta, const bool periodic, half *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, int16_t N, const float alpha, + const float beta, const bool periodic, half *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, int32_t N, const float alpha, + const float beta, const bool periodic, half *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, int64_t N, const float alpha, + const float beta, const bool periodic, half *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, uint8_t N, const float alpha, + const float beta, const bool periodic, half *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, uint16_t N, const float alpha, + const float beta, const bool periodic, half *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, uint32_t N, const float alpha, + const float beta, const bool periodic, half *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, uint64_t N, const float alpha, + const float beta, const bool periodic, half *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, int8_t N, const float alpha, + const float beta, const bool periodic, float *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, int16_t N, const float alpha, + const float beta, const bool periodic, float *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, int32_t N, const float alpha, + const float beta, const bool periodic, float *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, int64_t N, const float alpha, + const float beta, const bool periodic, float *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, uint8_t N, const float alpha, + const float beta, const bool periodic, float *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, uint16_t N, const float alpha, + const float beta, const bool periodic, + float *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, uint32_t N, const float alpha, + const float beta, const bool periodic, + float *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, uint64_t N, const float alpha, + const float beta, const bool periodic, + float *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, int8_t N, const float alpha, + const float beta, const bool periodic, + double *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, int16_t N, const float alpha, + const float beta, const bool periodic, + double *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, int32_t N, const float alpha, + const float beta, const bool periodic, + double *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, int64_t N, const float alpha, + const float beta, const bool periodic, + double *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, uint8_t N, const float alpha, + const float beta, const bool periodic, + double *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, uint16_t N, const float alpha, + const float beta, const bool periodic, + double *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, uint32_t N, const float alpha, + const float beta, const bool periodic, + double *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, uint64_t N, const float alpha, + const float beta, const bool periodic, + double *output, const uint32_t &device_id, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/hamming_window_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/hamming_window_impl.cuh index 00eb7784124..00f4c68b1ec 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/hamming_window_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/hamming_window_impl.cuh @@ -1,25 +1,25 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_HAMMING_WINDOW_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_HAMMING_WINDOW_IMPL_CUH_ -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" - -template -CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, T N, const float alpha, const float beta, - const bool periodic, S *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_HAMMING_WINDOW_IMPL_CUH_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_HAMMING_WINDOW_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_HAMMING_WINDOW_IMPL_CUH_ +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" + +template +CUDA_LIB_EXPORT cudaError_t HammingWindow(const size_t size, T N, const float alpha, const float beta, + const bool periodic, S *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_HAMMING_WINDOW_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/heaviside_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/heaviside_impl.cu index 7db4c659fc8..36e90cd155d 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/heaviside_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/heaviside_impl.cu @@ -1,194 +1,194 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/heaviside_impl.cuh" - -__constant__ size_t start_cal[5]; -__constant__ size_t end_cal[5]; -__constant__ size_t output_cal[5]; - -template -struct HeavisideFunc { - __device__ __host__ __forceinline__ T operator()(const T &x1, const T &x2) { - if (x1 > T(0)) { - return T(1); - } else if (x1 == T(0)) { - return x2; - } else { - return T(0); - } - } -}; - -template -__global__ void CalHeavisideKernel(size_t size, const T *x1, const T *x2, T *y) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - y[pos] = Func()(x1[pos], x2[pos]); - } -} - -__device__ __forceinline__ size_t Index(const size_t &index, const size_t &dim) { return dim == 1 ? 0 : index; } - -template -__global__ void BroadcastHeavisideKernel(const size_t l0, const size_t l1, const size_t l2, const size_t l3, - const size_t l4, const size_t l5, const size_t l6, const size_t r0, - const size_t r1, const size_t r2, const size_t r3, const size_t r4, - const size_t r5, const size_t r6, const size_t d0, const size_t d1, - const size_t d2, const size_t d3, const size_t d4, const size_t d5, - const size_t d6, const T *x1, const T *x2, T *y) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3 * d4 * d5 * d6; - pos += blockDim.x * gridDim.x) { - size_t i = pos / output_cal[0] % d0; - size_t j = pos / output_cal[1] % d1; - size_t k = pos / output_cal[2] % d2; - size_t l = pos / output_cal[3] % d3; - size_t m = pos / output_cal[4] % d4; - size_t n = pos / d6 % d5; - size_t o = pos % d6; - - size_t l_index = Index(i, l0) * start_cal[0]; - l_index += Index(j, l1) * start_cal[1]; - l_index += Index(k, l2) * start_cal[2]; - l_index += Index(l, l3) * start_cal[3]; - l_index += Index(m, l4) * start_cal[4]; - l_index += Index(n, l5) * l6; - l_index += Index(o, l6); - size_t r_index = Index(i, r0) * end_cal[0]; - r_index += Index(j, r1) * end_cal[1]; - r_index += Index(k, r2) * end_cal[2]; - r_index += Index(l, r3) * end_cal[3]; - r_index += Index(m, r4) * end_cal[4]; - r_index += Index(n, r5) * r6; - r_index += Index(o, r6); - y[pos] = Func()(x1[l_index], x2[r_index]); - } -} - -template -cudaError_t CalHeaviside(size_t size, const T *x1, const T *x2, T *y, const uint32_t &device_id, - cudaStream_t cuda_stream) { - CalHeavisideKernel> - <<>>(size, x1, x2, y); - return GetCudaStatus(); -} - -cudaError_t CalData(const std::vector &start_shape, size_t *output) { - output[4] = start_shape[5] * start_shape[6]; - output[3] = output[4] * start_shape[4]; - output[2] = output[3] * start_shape[3]; - output[1] = output[2] * start_shape[2]; - output[0] = output[1] * start_shape[1]; - return GetCudaStatus(); -} - -template -cudaError_t BroadcastHeaviside(const std::vector &x1_shape, const std::vector &x2_shape, - const std::vector &y_shape, const T *x1, const T *x2, T *y, - const uint32_t &device_id, cudaStream_t cuda_stream) { - size_t size = 1; - for (auto d : y_shape) { - size *= d; - } - size_t start_dim[5]; - size_t end_dim[5]; - size_t output_dim[5]; - CalData(x1_shape, start_dim); - CalData(x2_shape, end_dim); - CalData(y_shape, output_dim); - cudaMemcpyToSymbol(start_cal, start_dim, sizeof(size_t) * 5); - cudaMemcpyToSymbol(end_cal, end_dim, sizeof(size_t) * 5); - cudaMemcpyToSymbol(output_cal, output_dim, sizeof(size_t) * 5); - BroadcastHeavisideKernel> - <<>>( - x1_shape[0], x1_shape[1], x1_shape[2], x1_shape[3], x1_shape[4], x1_shape[5], x1_shape[6], x2_shape[0], - x2_shape[1], x2_shape[2], x2_shape[3], x2_shape[4], x2_shape[5], x2_shape[6], y_shape[0], y_shape[1], y_shape[2], - y_shape[3], y_shape[4], y_shape[5], y_shape[6], x1, x2, y); - return GetCudaStatus(); -} -template CUDA_LIB_EXPORT cudaError_t CalHeaviside(size_t, const uint8_t *, const uint8_t *, uint8_t *, - const uint32_t &, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalHeaviside(size_t, const uint16_t *, const uint16_t *, uint16_t *, - const uint32_t &, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalHeaviside(size_t, const uint32_t *, const uint32_t *, uint32_t *, - const uint32_t &, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalHeaviside(size_t, const uint64_t *, const uint64_t *, uint64_t *, - const uint32_t &, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalHeaviside(size_t, const int8_t *, const int8_t *, int8_t *, - const uint32_t &, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalHeaviside(size_t, const int16_t *, const int16_t *, int16_t *, - const uint32_t &, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalHeaviside(size_t, const int32_t *, const int32_t *, int32_t *, - const uint32_t &, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalHeaviside(size_t, const int64_t *, const int64_t *, int64_t *, - const uint32_t &, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalHeaviside(size_t, const half *, const half *, half *, const uint32_t &, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalHeaviside(size_t, const float *, const float *, float *, - const uint32_t &, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalHeaviside(size_t, const double *, const double *, double *, - const uint32_t &, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t BroadcastHeaviside(const std::vector &, - const std::vector &, - const std::vector &, const uint8_t *, - const uint8_t *, uint8_t *, const uint32_t &, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t BroadcastHeaviside(const std::vector &, - const std::vector &, - const std::vector &, const uint16_t *, - const uint16_t *, uint16_t *, const uint32_t &, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t BroadcastHeaviside(const std::vector &, - const std::vector &, - const std::vector &, const uint32_t *, - const uint32_t *, uint32_t *, const uint32_t &, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t BroadcastHeaviside(const std::vector &, - const std::vector &, - const std::vector &, const uint64_t *, - const uint64_t *, uint64_t *, const uint32_t &, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t BroadcastHeaviside(const std::vector &, - const std::vector &, - const std::vector &, const int8_t *, - const int8_t *, int8_t *, const uint32_t &, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t BroadcastHeaviside(const std::vector &, - const std::vector &, - const std::vector &, const int16_t *, - const int16_t *, int16_t *, const uint32_t &, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t BroadcastHeaviside(const std::vector &, - const std::vector &, - const std::vector &, const int32_t *, - const int32_t *, int32_t *, const uint32_t &, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t BroadcastHeaviside(const std::vector &, - const std::vector &, - const std::vector &, const int64_t *, - const int64_t *, int64_t *, const uint32_t &, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t BroadcastHeaviside(const std::vector &, const std::vector &, - const std::vector &, const half *, const half *, - half *, const uint32_t &, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t BroadcastHeaviside(const std::vector &, const std::vector &, - const std::vector &, const float *, - const float *, float *, const uint32_t &, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t BroadcastHeaviside(const std::vector &, - const std::vector &, - const std::vector &, const double *, - const double *, double *, const uint32_t &, - cudaStream_t cuda_stream); +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/heaviside_impl.cuh" + +__constant__ size_t start_cal[5]; +__constant__ size_t end_cal[5]; +__constant__ size_t output_cal[5]; + +template +struct HeavisideFunc { + __device__ __host__ __forceinline__ T operator()(const T &x1, const T &x2) { + if (x1 > T(0)) { + return T(1); + } else if (x1 == T(0)) { + return x2; + } else { + return T(0); + } + } +}; + +template +__global__ void CalHeavisideKernel(size_t size, const T *x1, const T *x2, T *y) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + y[pos] = Func()(x1[pos], x2[pos]); + } +} + +__device__ __forceinline__ size_t Index(const size_t &index, const size_t &dim) { return dim == 1 ? 0 : index; } + +template +__global__ void BroadcastHeavisideKernel(const size_t l0, const size_t l1, const size_t l2, const size_t l3, + const size_t l4, const size_t l5, const size_t l6, const size_t r0, + const size_t r1, const size_t r2, const size_t r3, const size_t r4, + const size_t r5, const size_t r6, const size_t d0, const size_t d1, + const size_t d2, const size_t d3, const size_t d4, const size_t d5, + const size_t d6, const T *x1, const T *x2, T *y) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3 * d4 * d5 * d6; + pos += blockDim.x * gridDim.x) { + size_t i = pos / output_cal[0] % d0; + size_t j = pos / output_cal[1] % d1; + size_t k = pos / output_cal[2] % d2; + size_t l = pos / output_cal[3] % d3; + size_t m = pos / output_cal[4] % d4; + size_t n = pos / d6 % d5; + size_t o = pos % d6; + + size_t l_index = Index(i, l0) * start_cal[0]; + l_index += Index(j, l1) * start_cal[1]; + l_index += Index(k, l2) * start_cal[2]; + l_index += Index(l, l3) * start_cal[3]; + l_index += Index(m, l4) * start_cal[4]; + l_index += Index(n, l5) * l6; + l_index += Index(o, l6); + size_t r_index = Index(i, r0) * end_cal[0]; + r_index += Index(j, r1) * end_cal[1]; + r_index += Index(k, r2) * end_cal[2]; + r_index += Index(l, r3) * end_cal[3]; + r_index += Index(m, r4) * end_cal[4]; + r_index += Index(n, r5) * r6; + r_index += Index(o, r6); + y[pos] = Func()(x1[l_index], x2[r_index]); + } +} + +template +cudaError_t CalHeaviside(size_t size, const T *x1, const T *x2, T *y, const uint32_t &device_id, + cudaStream_t cuda_stream) { + CalHeavisideKernel> + <<>>(size, x1, x2, y); + return GetCudaStatus(); +} + +cudaError_t CalData(const std::vector &start_shape, size_t *output) { + output[4] = start_shape[5] * start_shape[6]; + output[3] = output[4] * start_shape[4]; + output[2] = output[3] * start_shape[3]; + output[1] = output[2] * start_shape[2]; + output[0] = output[1] * start_shape[1]; + return GetCudaStatus(); +} + +template +cudaError_t BroadcastHeaviside(const std::vector &x1_shape, const std::vector &x2_shape, + const std::vector &y_shape, const T *x1, const T *x2, T *y, + const uint32_t &device_id, cudaStream_t cuda_stream) { + size_t size = 1; + for (auto d : y_shape) { + size *= d; + } + size_t start_dim[5]; + size_t end_dim[5]; + size_t output_dim[5]; + CalData(x1_shape, start_dim); + CalData(x2_shape, end_dim); + CalData(y_shape, output_dim); + cudaMemcpyToSymbol(start_cal, start_dim, sizeof(size_t) * 5); + cudaMemcpyToSymbol(end_cal, end_dim, sizeof(size_t) * 5); + cudaMemcpyToSymbol(output_cal, output_dim, sizeof(size_t) * 5); + BroadcastHeavisideKernel> + <<>>( + x1_shape[0], x1_shape[1], x1_shape[2], x1_shape[3], x1_shape[4], x1_shape[5], x1_shape[6], x2_shape[0], + x2_shape[1], x2_shape[2], x2_shape[3], x2_shape[4], x2_shape[5], x2_shape[6], y_shape[0], y_shape[1], y_shape[2], + y_shape[3], y_shape[4], y_shape[5], y_shape[6], x1, x2, y); + return GetCudaStatus(); +} +template CUDA_LIB_EXPORT cudaError_t CalHeaviside(size_t, const uint8_t *, const uint8_t *, uint8_t *, + const uint32_t &, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalHeaviside(size_t, const uint16_t *, const uint16_t *, uint16_t *, + const uint32_t &, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalHeaviside(size_t, const uint32_t *, const uint32_t *, uint32_t *, + const uint32_t &, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalHeaviside(size_t, const uint64_t *, const uint64_t *, uint64_t *, + const uint32_t &, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalHeaviside(size_t, const int8_t *, const int8_t *, int8_t *, + const uint32_t &, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalHeaviside(size_t, const int16_t *, const int16_t *, int16_t *, + const uint32_t &, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalHeaviside(size_t, const int32_t *, const int32_t *, int32_t *, + const uint32_t &, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalHeaviside(size_t, const int64_t *, const int64_t *, int64_t *, + const uint32_t &, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalHeaviside(size_t, const half *, const half *, half *, const uint32_t &, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalHeaviside(size_t, const float *, const float *, float *, + const uint32_t &, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalHeaviside(size_t, const double *, const double *, double *, + const uint32_t &, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t BroadcastHeaviside(const std::vector &, + const std::vector &, + const std::vector &, const uint8_t *, + const uint8_t *, uint8_t *, const uint32_t &, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t BroadcastHeaviside(const std::vector &, + const std::vector &, + const std::vector &, const uint16_t *, + const uint16_t *, uint16_t *, const uint32_t &, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t BroadcastHeaviside(const std::vector &, + const std::vector &, + const std::vector &, const uint32_t *, + const uint32_t *, uint32_t *, const uint32_t &, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t BroadcastHeaviside(const std::vector &, + const std::vector &, + const std::vector &, const uint64_t *, + const uint64_t *, uint64_t *, const uint32_t &, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t BroadcastHeaviside(const std::vector &, + const std::vector &, + const std::vector &, const int8_t *, + const int8_t *, int8_t *, const uint32_t &, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t BroadcastHeaviside(const std::vector &, + const std::vector &, + const std::vector &, const int16_t *, + const int16_t *, int16_t *, const uint32_t &, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t BroadcastHeaviside(const std::vector &, + const std::vector &, + const std::vector &, const int32_t *, + const int32_t *, int32_t *, const uint32_t &, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t BroadcastHeaviside(const std::vector &, + const std::vector &, + const std::vector &, const int64_t *, + const int64_t *, int64_t *, const uint32_t &, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t BroadcastHeaviside(const std::vector &, const std::vector &, + const std::vector &, const half *, const half *, + half *, const uint32_t &, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t BroadcastHeaviside(const std::vector &, const std::vector &, + const std::vector &, const float *, + const float *, float *, const uint32_t &, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t BroadcastHeaviside(const std::vector &, + const std::vector &, + const std::vector &, const double *, + const double *, double *, const uint32_t &, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/heaviside_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/heaviside_impl.cuh index 60aaa307464..4edcea07246 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/heaviside_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/heaviside_impl.cuh @@ -1,31 +1,31 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_HEAVISIDE_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_HEAVISIDE_IMPL_CUH_ - -#include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" - -template -CUDA_LIB_EXPORT cudaError_t CalHeaviside(size_t size, const T *x1, const T *x2, T *y, const uint32_t &device_id, - cudaStream_t cuda_stream); - -template -CUDA_LIB_EXPORT cudaError_t BroadcastHeaviside(const std::vector &x1_shape, const std::vector &x2_shape, - const std::vector &y_shape, const T *x1, const T *x2, T *y, - const uint32_t &device_id, cudaStream_t cuda_stream); -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_HEAVISIDE_IMPL_CUH_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_HEAVISIDE_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_HEAVISIDE_IMPL_CUH_ + +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" + +template +CUDA_LIB_EXPORT cudaError_t CalHeaviside(size_t size, const T *x1, const T *x2, T *y, const uint32_t &device_id, + cudaStream_t cuda_stream); + +template +CUDA_LIB_EXPORT cudaError_t BroadcastHeaviside(const std::vector &x1_shape, const std::vector &x2_shape, + const std::vector &y_shape, const T *x1, const T *x2, T *y, + const uint32_t &device_id, cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_HEAVISIDE_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/hypot_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/hypot_impl.cu index d0625a56242..61b703d2634 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/hypot_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/hypot_impl.cu @@ -1,126 +1,126 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/hypot_impl.cuh" - -__constant__ size_t start_cal[5]; -__constant__ size_t end_cal[5]; -__constant__ size_t output_cal[5]; - -template -struct HypotFunc { - __device__ __host__ __forceinline__ T operator()(const T &x1, const T &x2) { return hypotf(x1, x2); } -}; - -template <> -struct HypotFunc { - __device__ __host__ __forceinline__ double operator()(const double &x1, const double &x2) { return hypot(x1, x2); } -}; - -template -__global__ void CalHypotKernel(size_t size, const T *x1, const T *x2, T *y) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - y[pos] = Func()(x1[pos], x2[pos]); - } -} - -__device__ __forceinline__ size_t Index(const size_t &index, const size_t &dim) { return dim == 1 ? 0 : index; } - -template -__global__ void BroadcastHypotKernel(const size_t l0, const size_t l1, const size_t l2, const size_t l3, - const size_t l4, const size_t l5, const size_t l6, const size_t r0, - const size_t r1, const size_t r2, const size_t r3, const size_t r4, - const size_t r5, const size_t r6, const size_t d0, const size_t d1, - const size_t d2, const size_t d3, const size_t d4, const size_t d5, - const size_t d6, const T *x1, const T *x2, T *y) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3 * d4 * d5 * d6; - pos += blockDim.x * gridDim.x) { - size_t i = pos / output_cal[0] % d0; - size_t j = pos / output_cal[1] % d1; - size_t k = pos / output_cal[2] % d2; - size_t l = pos / output_cal[3] % d3; - size_t m = pos / output_cal[4] % d4; - size_t n = pos / d6 % d5; - size_t o = pos % d6; - - size_t l_index = Index(i, l0) * start_cal[0]; - l_index += Index(j, l1) * start_cal[1]; - l_index += Index(k, l2) * start_cal[2]; - l_index += Index(l, l3) * start_cal[3]; - l_index += Index(m, l4) * start_cal[4]; - l_index += Index(n, l5) * l6; - l_index += Index(o, l6); - size_t r_index = Index(i, r0) * end_cal[0]; - r_index += Index(j, r1) * end_cal[1]; - r_index += Index(k, r2) * end_cal[2]; - r_index += Index(l, r3) * end_cal[3]; - r_index += Index(m, r4) * end_cal[4]; - r_index += Index(n, r5) * r6; - r_index += Index(o, r6); - y[pos] = Func()(x1[l_index], x2[r_index]); - } -} - -template -cudaError_t CalHypot(size_t size, const T *x1, const T *x2, T *y, const uint32_t &device_id, cudaStream_t cuda_stream) { - CalHypotKernel> - <<>>(size, x1, x2, y); - return GetCudaStatus(); -} - -void CalShapeData(const std::vector &start_shape, size_t *output) { - output[4] = start_shape[5] * start_shape[6]; - output[3] = output[4] * start_shape[4]; - output[2] = output[3] * start_shape[3]; - output[1] = output[2] * start_shape[2]; - output[0] = output[1] * start_shape[1]; -} - -template -cudaError_t BroadcastHypot(const std::vector &x1_shape, const std::vector &x2_shape, - const std::vector &y_shape, const T *x1, const T *x2, T *y, - const uint32_t &device_id, cudaStream_t cuda_stream) { - size_t size = 1; - for (auto d : y_shape) { - size *= d; - } - size_t start_dim[5]; - size_t end_dim[5]; - size_t output_dim[5]; - CalShapeData(x1_shape, start_dim); - CalShapeData(x2_shape, end_dim); - CalShapeData(y_shape, output_dim); - cudaMemcpyToSymbol(start_cal, start_dim, sizeof(size_t) * 5); - cudaMemcpyToSymbol(end_cal, end_dim, sizeof(size_t) * 5); - cudaMemcpyToSymbol(output_cal, output_dim, sizeof(size_t) * 5); - BroadcastHypotKernel><<>>( - x1_shape[0], x1_shape[1], x1_shape[2], x1_shape[3], x1_shape[4], x1_shape[5], x1_shape[6], x2_shape[0], x2_shape[1], - x2_shape[2], x2_shape[3], x2_shape[4], x2_shape[5], x2_shape[6], y_shape[0], y_shape[1], y_shape[2], y_shape[3], - y_shape[4], y_shape[5], y_shape[6], x1, x2, y); - return GetCudaStatus(); -} - -template CUDA_LIB_EXPORT cudaError_t CalHypot(size_t, const float *, const float *, float *, const uint32_t &, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalHypot(size_t, const double *, const double *, double *, - const uint32_t &, cudaStream_t cuda_stream); - -template CUDA_LIB_EXPORT cudaError_t BroadcastHypot(const std::vector &, const std::vector &, - const std::vector &, const float *, const float *, - float *, const uint32_t &, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t BroadcastHypot(const std::vector &, const std::vector &, - const std::vector &, const double *, const double *, - double *, const uint32_t &, cudaStream_t cuda_stream); +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/hypot_impl.cuh" + +__constant__ size_t start_cal[5]; +__constant__ size_t end_cal[5]; +__constant__ size_t output_cal[5]; + +template +struct HypotFunc { + __device__ __host__ __forceinline__ T operator()(const T &x1, const T &x2) { return hypotf(x1, x2); } +}; + +template <> +struct HypotFunc { + __device__ __host__ __forceinline__ double operator()(const double &x1, const double &x2) { return hypot(x1, x2); } +}; + +template +__global__ void CalHypotKernel(size_t size, const T *x1, const T *x2, T *y) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + y[pos] = Func()(x1[pos], x2[pos]); + } +} + +__device__ __forceinline__ size_t Index(const size_t &index, const size_t &dim) { return dim == 1 ? 0 : index; } + +template +__global__ void BroadcastHypotKernel(const size_t l0, const size_t l1, const size_t l2, const size_t l3, + const size_t l4, const size_t l5, const size_t l6, const size_t r0, + const size_t r1, const size_t r2, const size_t r3, const size_t r4, + const size_t r5, const size_t r6, const size_t d0, const size_t d1, + const size_t d2, const size_t d3, const size_t d4, const size_t d5, + const size_t d6, const T *x1, const T *x2, T *y) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3 * d4 * d5 * d6; + pos += blockDim.x * gridDim.x) { + size_t i = pos / output_cal[0] % d0; + size_t j = pos / output_cal[1] % d1; + size_t k = pos / output_cal[2] % d2; + size_t l = pos / output_cal[3] % d3; + size_t m = pos / output_cal[4] % d4; + size_t n = pos / d6 % d5; + size_t o = pos % d6; + + size_t l_index = Index(i, l0) * start_cal[0]; + l_index += Index(j, l1) * start_cal[1]; + l_index += Index(k, l2) * start_cal[2]; + l_index += Index(l, l3) * start_cal[3]; + l_index += Index(m, l4) * start_cal[4]; + l_index += Index(n, l5) * l6; + l_index += Index(o, l6); + size_t r_index = Index(i, r0) * end_cal[0]; + r_index += Index(j, r1) * end_cal[1]; + r_index += Index(k, r2) * end_cal[2]; + r_index += Index(l, r3) * end_cal[3]; + r_index += Index(m, r4) * end_cal[4]; + r_index += Index(n, r5) * r6; + r_index += Index(o, r6); + y[pos] = Func()(x1[l_index], x2[r_index]); + } +} + +template +cudaError_t CalHypot(size_t size, const T *x1, const T *x2, T *y, const uint32_t &device_id, cudaStream_t cuda_stream) { + CalHypotKernel> + <<>>(size, x1, x2, y); + return GetCudaStatus(); +} + +void CalShapeData(const std::vector &start_shape, size_t *output) { + output[4] = start_shape[5] * start_shape[6]; + output[3] = output[4] * start_shape[4]; + output[2] = output[3] * start_shape[3]; + output[1] = output[2] * start_shape[2]; + output[0] = output[1] * start_shape[1]; +} + +template +cudaError_t BroadcastHypot(const std::vector &x1_shape, const std::vector &x2_shape, + const std::vector &y_shape, const T *x1, const T *x2, T *y, + const uint32_t &device_id, cudaStream_t cuda_stream) { + size_t size = 1; + for (auto d : y_shape) { + size *= d; + } + size_t start_dim[5]; + size_t end_dim[5]; + size_t output_dim[5]; + CalShapeData(x1_shape, start_dim); + CalShapeData(x2_shape, end_dim); + CalShapeData(y_shape, output_dim); + cudaMemcpyToSymbol(start_cal, start_dim, sizeof(size_t) * 5); + cudaMemcpyToSymbol(end_cal, end_dim, sizeof(size_t) * 5); + cudaMemcpyToSymbol(output_cal, output_dim, sizeof(size_t) * 5); + BroadcastHypotKernel><<>>( + x1_shape[0], x1_shape[1], x1_shape[2], x1_shape[3], x1_shape[4], x1_shape[5], x1_shape[6], x2_shape[0], x2_shape[1], + x2_shape[2], x2_shape[3], x2_shape[4], x2_shape[5], x2_shape[6], y_shape[0], y_shape[1], y_shape[2], y_shape[3], + y_shape[4], y_shape[5], y_shape[6], x1, x2, y); + return GetCudaStatus(); +} + +template CUDA_LIB_EXPORT cudaError_t CalHypot(size_t, const float *, const float *, float *, const uint32_t &, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalHypot(size_t, const double *, const double *, double *, + const uint32_t &, cudaStream_t cuda_stream); + +template CUDA_LIB_EXPORT cudaError_t BroadcastHypot(const std::vector &, const std::vector &, + const std::vector &, const float *, const float *, + float *, const uint32_t &, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t BroadcastHypot(const std::vector &, const std::vector &, + const std::vector &, const double *, const double *, + double *, const uint32_t &, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/hypot_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/hypot_impl.cuh index 41cf7c94ce4..0f2a57d9d89 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/hypot_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/hypot_impl.cuh @@ -1,31 +1,31 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_HYPOT_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_HYPOT_IMPL_CUH_ - -#include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" - -template -CUDA_LIB_EXPORT cudaError_t CalHypot(size_t size, const T *x1, const T *x2, T *y, const uint32_t &device_id, - cudaStream_t cuda_stream); - -template -CUDA_LIB_EXPORT cudaError_t BroadcastHypot(const std::vector &x1_shape, const std::vector &x2_shape, - const std::vector &y_shape, const T *x1, const T *x2, T *y, - const uint32_t &device_id, cudaStream_t cuda_stream); -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_HYPOT_IMPL_CUH_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_HYPOT_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_HYPOT_IMPL_CUH_ + +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" + +template +CUDA_LIB_EXPORT cudaError_t CalHypot(size_t size, const T *x1, const T *x2, T *y, const uint32_t &device_id, + cudaStream_t cuda_stream); + +template +CUDA_LIB_EXPORT cudaError_t BroadcastHypot(const std::vector &x1_shape, const std::vector &x2_shape, + const std::vector &y_shape, const T *x1, const T *x2, T *y, + const uint32_t &device_id, cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_HYPOT_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/list_diff_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/list_diff_impl.cu index 803d89b546a..8b405e75c03 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/list_diff_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/list_diff_impl.cu @@ -1,117 +1,117 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include -#include -#include -#include "include/cuda_fp16.h" -#include "include/cuda_runtime.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/list_diff_impl.cuh" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh" - -struct is_selected { - __host__ __device__ bool operator()(const bool x) { return x == false; } -}; -template -cudaError_t CalListDiff(size_t x_size, size_t y_size, const T *x, const T *y, T *out, S *idx, T *workspace_y, - S *workspace_xidx, bool *workspace_flag, const uint32_t &device_id, cudaStream_t cuda_stream, - int *count) { - int count_out = 0; - auto policy = thrust::cuda::par.on(cuda_stream); - cudaMemcpy(workspace_y, y, y_size * sizeof(T), cudaMemcpyDeviceToDevice); - thrust::sequence(policy, thrust::device_pointer_cast(workspace_xidx), - thrust::device_pointer_cast(workspace_xidx) + x_size); - thrust::stable_sort(policy, thrust::device_pointer_cast(workspace_y), - thrust::device_pointer_cast(workspace_y) + y_size); - thrust::binary_search(thrust::device_pointer_cast(workspace_y), thrust::device_pointer_cast(workspace_y) + y_size, - thrust::device_pointer_cast(x), thrust::device_pointer_cast(x) + x_size, - thrust::device_pointer_cast(workspace_flag)); - count_out = thrust::count(policy, thrust::device_pointer_cast(workspace_flag), - thrust::device_pointer_cast(workspace_flag) + x_size, false); - thrust::copy_if( - policy, - thrust::make_zip_iterator( - thrust::make_tuple(thrust::device_pointer_cast(workspace_xidx), thrust::device_pointer_cast(x))), - thrust::make_zip_iterator(thrust::make_tuple(thrust::device_pointer_cast(workspace_xidx) + x_size, - thrust::device_pointer_cast(x) + x_size)), - thrust::device_pointer_cast(workspace_flag), - thrust::make_zip_iterator(thrust::make_tuple(thrust::device_pointer_cast(idx), thrust::device_pointer_cast(out))), - is_selected()); - *count = count_out; - return GetCudaStatus(); -} - -template CUDA_LIB_EXPORT cudaError_t CalListDiff(size_t, size_t, const half *, const half *, half *, int64_t *, half *, - int64_t *, bool *, const uint32_t &, cudaStream_t cuda_stream, - int *count); -template CUDA_LIB_EXPORT cudaError_t CalListDiff(size_t, size_t, const float *, const float *, float *, int64_t *, - float *, int64_t *, bool *, const uint32_t &, cudaStream_t cuda_stream, - int *count); -template CUDA_LIB_EXPORT cudaError_t CalListDiff(size_t, size_t, const double *, const double *, double *, int64_t *, - double *, int64_t *, bool *, const uint32_t &, - cudaStream_t cuda_stream, int *count); -template CUDA_LIB_EXPORT cudaError_t CalListDiff(size_t, size_t, const uint8_t *, const uint8_t *, uint8_t *, int64_t *, - uint8_t *, int64_t *, bool *, const uint32_t &, - cudaStream_t cuda_stream, int *count); -template CUDA_LIB_EXPORT cudaError_t CalListDiff(size_t, size_t, const uint16_t *, const uint16_t *, uint16_t *, - int64_t *, uint16_t *, int64_t *, bool *, const uint32_t &, - cudaStream_t cuda_stream, int *count); -template CUDA_LIB_EXPORT cudaError_t CalListDiff(size_t, size_t, const int8_t *, const int8_t *, int8_t *, int64_t *, - int8_t *, int64_t *, bool *, const uint32_t &, - cudaStream_t cuda_stream, int *count); -template CUDA_LIB_EXPORT cudaError_t CalListDiff(size_t, size_t, const int16_t *, const int16_t *, int16_t *, int64_t *, - int16_t *, int64_t *, bool *, const uint32_t &, - cudaStream_t cuda_stream, int *count); -template CUDA_LIB_EXPORT cudaError_t CalListDiff(size_t, size_t, const int32_t *, const int32_t *, int32_t *, int64_t *, - int32_t *, int64_t *, bool *, const uint32_t &, - cudaStream_t cuda_stream, int *count); -template CUDA_LIB_EXPORT cudaError_t CalListDiff(size_t, size_t, const int64_t *, const int64_t *, int64_t *, int64_t *, - int64_t *, int64_t *, bool *, const uint32_t &, - cudaStream_t cuda_stream, int *count); - -template CUDA_LIB_EXPORT cudaError_t CalListDiff(size_t, size_t, const half *, const half *, half *, int32_t *, half *, - int32_t *, bool *, const uint32_t &, cudaStream_t cuda_stream, - int *count); -template CUDA_LIB_EXPORT cudaError_t CalListDiff(size_t, size_t, const float *, const float *, float *, int32_t *, - float *, int32_t *, bool *, const uint32_t &, cudaStream_t cuda_stream, - int *count); -template CUDA_LIB_EXPORT cudaError_t CalListDiff(size_t, size_t, const double *, const double *, double *, int32_t *, - double *, int32_t *, bool *, const uint32_t &, - cudaStream_t cuda_stream, int *count); -template CUDA_LIB_EXPORT cudaError_t CalListDiff(size_t, size_t, const uint8_t *, const uint8_t *, uint8_t *, int32_t *, - uint8_t *, int32_t *, bool *, const uint32_t &, - cudaStream_t cuda_stream, int *count); -template CUDA_LIB_EXPORT cudaError_t CalListDiff(size_t, size_t, const uint16_t *, const uint16_t *, uint16_t *, - int32_t *, uint16_t *, int32_t *, bool *, const uint32_t &, - cudaStream_t cuda_stream, int *count); -template CUDA_LIB_EXPORT cudaError_t CalListDiff(size_t, size_t, const int8_t *, const int8_t *, int8_t *, int32_t *, - int8_t *, int32_t *, bool *, const uint32_t &, - cudaStream_t cuda_stream, int *count); -template CUDA_LIB_EXPORT cudaError_t CalListDiff(size_t, size_t, const int16_t *, const int16_t *, int16_t *, int32_t *, - int16_t *, int32_t *, bool *, const uint32_t &, - cudaStream_t cuda_stream, int *count); -template CUDA_LIB_EXPORT cudaError_t CalListDiff(size_t, size_t, const int32_t *, const int32_t *, int32_t *, int32_t *, - int32_t *, int32_t *, bool *, const uint32_t &, - cudaStream_t cuda_stream, int *count); -template CUDA_LIB_EXPORT cudaError_t CalListDiff(size_t, size_t, const int64_t *, const int64_t *, int64_t *, int32_t *, - int64_t *, int32_t *, bool *, const uint32_t &, - cudaStream_t cuda_stream, int *count); +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "include/cuda_fp16.h" +#include "include/cuda_runtime.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/list_diff_impl.cuh" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh" + +struct is_selected { + __host__ __device__ bool operator()(const bool x) { return x == false; } +}; +template +cudaError_t CalListDiff(size_t x_size, size_t y_size, const T *x, const T *y, T *out, S *idx, T *workspace_y, + S *workspace_xidx, bool *workspace_flag, const uint32_t &device_id, cudaStream_t cuda_stream, + int *count) { + int count_out = 0; + auto policy = thrust::cuda::par.on(cuda_stream); + cudaMemcpy(workspace_y, y, y_size * sizeof(T), cudaMemcpyDeviceToDevice); + thrust::sequence(policy, thrust::device_pointer_cast(workspace_xidx), + thrust::device_pointer_cast(workspace_xidx) + x_size); + thrust::stable_sort(policy, thrust::device_pointer_cast(workspace_y), + thrust::device_pointer_cast(workspace_y) + y_size); + thrust::binary_search(thrust::device_pointer_cast(workspace_y), thrust::device_pointer_cast(workspace_y) + y_size, + thrust::device_pointer_cast(x), thrust::device_pointer_cast(x) + x_size, + thrust::device_pointer_cast(workspace_flag)); + count_out = thrust::count(policy, thrust::device_pointer_cast(workspace_flag), + thrust::device_pointer_cast(workspace_flag) + x_size, false); + thrust::copy_if( + policy, + thrust::make_zip_iterator( + thrust::make_tuple(thrust::device_pointer_cast(workspace_xidx), thrust::device_pointer_cast(x))), + thrust::make_zip_iterator(thrust::make_tuple(thrust::device_pointer_cast(workspace_xidx) + x_size, + thrust::device_pointer_cast(x) + x_size)), + thrust::device_pointer_cast(workspace_flag), + thrust::make_zip_iterator(thrust::make_tuple(thrust::device_pointer_cast(idx), thrust::device_pointer_cast(out))), + is_selected()); + *count = count_out; + return GetCudaStatus(); +} + +template CUDA_LIB_EXPORT cudaError_t CalListDiff(size_t, size_t, const half *, const half *, half *, int64_t *, half *, + int64_t *, bool *, const uint32_t &, cudaStream_t cuda_stream, + int *count); +template CUDA_LIB_EXPORT cudaError_t CalListDiff(size_t, size_t, const float *, const float *, float *, int64_t *, + float *, int64_t *, bool *, const uint32_t &, cudaStream_t cuda_stream, + int *count); +template CUDA_LIB_EXPORT cudaError_t CalListDiff(size_t, size_t, const double *, const double *, double *, int64_t *, + double *, int64_t *, bool *, const uint32_t &, + cudaStream_t cuda_stream, int *count); +template CUDA_LIB_EXPORT cudaError_t CalListDiff(size_t, size_t, const uint8_t *, const uint8_t *, uint8_t *, int64_t *, + uint8_t *, int64_t *, bool *, const uint32_t &, + cudaStream_t cuda_stream, int *count); +template CUDA_LIB_EXPORT cudaError_t CalListDiff(size_t, size_t, const uint16_t *, const uint16_t *, uint16_t *, + int64_t *, uint16_t *, int64_t *, bool *, const uint32_t &, + cudaStream_t cuda_stream, int *count); +template CUDA_LIB_EXPORT cudaError_t CalListDiff(size_t, size_t, const int8_t *, const int8_t *, int8_t *, int64_t *, + int8_t *, int64_t *, bool *, const uint32_t &, + cudaStream_t cuda_stream, int *count); +template CUDA_LIB_EXPORT cudaError_t CalListDiff(size_t, size_t, const int16_t *, const int16_t *, int16_t *, int64_t *, + int16_t *, int64_t *, bool *, const uint32_t &, + cudaStream_t cuda_stream, int *count); +template CUDA_LIB_EXPORT cudaError_t CalListDiff(size_t, size_t, const int32_t *, const int32_t *, int32_t *, int64_t *, + int32_t *, int64_t *, bool *, const uint32_t &, + cudaStream_t cuda_stream, int *count); +template CUDA_LIB_EXPORT cudaError_t CalListDiff(size_t, size_t, const int64_t *, const int64_t *, int64_t *, int64_t *, + int64_t *, int64_t *, bool *, const uint32_t &, + cudaStream_t cuda_stream, int *count); + +template CUDA_LIB_EXPORT cudaError_t CalListDiff(size_t, size_t, const half *, const half *, half *, int32_t *, half *, + int32_t *, bool *, const uint32_t &, cudaStream_t cuda_stream, + int *count); +template CUDA_LIB_EXPORT cudaError_t CalListDiff(size_t, size_t, const float *, const float *, float *, int32_t *, + float *, int32_t *, bool *, const uint32_t &, cudaStream_t cuda_stream, + int *count); +template CUDA_LIB_EXPORT cudaError_t CalListDiff(size_t, size_t, const double *, const double *, double *, int32_t *, + double *, int32_t *, bool *, const uint32_t &, + cudaStream_t cuda_stream, int *count); +template CUDA_LIB_EXPORT cudaError_t CalListDiff(size_t, size_t, const uint8_t *, const uint8_t *, uint8_t *, int32_t *, + uint8_t *, int32_t *, bool *, const uint32_t &, + cudaStream_t cuda_stream, int *count); +template CUDA_LIB_EXPORT cudaError_t CalListDiff(size_t, size_t, const uint16_t *, const uint16_t *, uint16_t *, + int32_t *, uint16_t *, int32_t *, bool *, const uint32_t &, + cudaStream_t cuda_stream, int *count); +template CUDA_LIB_EXPORT cudaError_t CalListDiff(size_t, size_t, const int8_t *, const int8_t *, int8_t *, int32_t *, + int8_t *, int32_t *, bool *, const uint32_t &, + cudaStream_t cuda_stream, int *count); +template CUDA_LIB_EXPORT cudaError_t CalListDiff(size_t, size_t, const int16_t *, const int16_t *, int16_t *, int32_t *, + int16_t *, int32_t *, bool *, const uint32_t &, + cudaStream_t cuda_stream, int *count); +template CUDA_LIB_EXPORT cudaError_t CalListDiff(size_t, size_t, const int32_t *, const int32_t *, int32_t *, int32_t *, + int32_t *, int32_t *, bool *, const uint32_t &, + cudaStream_t cuda_stream, int *count); +template CUDA_LIB_EXPORT cudaError_t CalListDiff(size_t, size_t, const int64_t *, const int64_t *, int64_t *, int32_t *, + int64_t *, int32_t *, bool *, const uint32_t &, + cudaStream_t cuda_stream, int *count); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/list_diff_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/list_diff_impl.cuh index 31fb17c18cd..512fd76d630 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/list_diff_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/list_diff_impl.cuh @@ -1,29 +1,29 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LIST_DIFF_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LIST_DIFF_IMPL_CUH_ - -#include -#include -#include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" - -template -CUDA_LIB_EXPORT cudaError_t CalListDiff(size_t x_size, size_t y_size, const T *x, const T *y, T *out, S *idx, - T *workspace_y, S *workspace_xidx, bool *workspace_flag, - const uint32_t &device_id, cudaStream_t cuda_stream, int *count); -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LIST_DIFF_IMPL_CUH_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LIST_DIFF_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LIST_DIFF_IMPL_CUH_ + +#include +#include +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" + +template +CUDA_LIB_EXPORT cudaError_t CalListDiff(size_t x_size, size_t y_size, const T *x, const T *y, T *out, S *idx, + T *workspace_y, S *workspace_xidx, bool *workspace_flag, + const uint32_t &device_id, cudaStream_t cuda_stream, int *count); +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LIST_DIFF_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/loss_with_reduction_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/loss_with_reduction_impl.cu index df474b96304..61159cd6e3d 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/loss_with_reduction_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/loss_with_reduction_impl.cu @@ -1,657 +1,657 @@ -/** - * Copyright 2020-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "loss_with_reduction_impl.cuh" -#include "util.cuh" - -inline __device__ float logT(float x) { return logf(x); } -inline __device__ half logT(half x) { return hlog(x); } -inline __device__ float castT(float ref, int x) { return __int2float_rd(x); } -inline __device__ half castT(half ref, int x) { return __int2half_rd(x); } -inline __device__ float maxT(float a, float b) { return fmaxf(a, b); } -inline __device__ half maxT(half a, half b) { return a > b ? a : b; } - -template -__global__ void Copy(T *loss, T *tmp_loss, ReductionMode reduction, int input_size) { - loss[0] += tmp_loss[0]; - if (reduction == ReductionMode::kMean) { - loss[0] /= castT(loss[0], input_size); - } -} - -template -__global__ void AddTile(T *tmp_loss, int index) { - tmp_loss[0] += tmp_loss[index]; -} -template -__global__ void PartialSum(T *tmp_loss, int stride) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < stride; i += blockDim.x * gridDim.x) { - tmp_loss[i] += tmp_loss[i + stride]; - } -} - -template -__device__ void MultiplyDevice(const S a, const T b, T *out) { - *out = a * b; -} - -template <> -__device__ void MultiplyDevice(const half a, const float b, float *out) { - // cast a to float for calculation - float a_float = __half2float(a); - *out = a_float * b; -} - -template <> -__device__ void MultiplyDevice(const float a, const half b, half *out) { - // cast b to float for calculation - float b_float = __half2float(b); - float out_float = a * b_float; - *out = __float2half(out_float); -} - -template -__device__ void Divide(const T *numerator, const S *denominator, T *result) { - result[0] = numerator[0] / denominator[0]; -} - -template <> -__device__ void Divide(const float *numerator, const half *denominator, float *result) { - float denom_float = __half2float(denominator[0]); - - result[0] = numerator[0] / denom_float; -} - -template <> -__device__ void Divide(const half *numerator, const float *denominator, half *result) { - float numer_float = __half2float(numerator[0]); - - float result_float = numer_float / denominator[0]; - - result[0] = __float2half(result_float); -} - -template -__device__ __forceinline__ void WarpReduce(T *shared_data, const int tid) { - T local_data = shared_data[tid]; - if (BlockDimX >= 32) { - local_data = local_data + __shfl_down_sync(0xFFFFFFFF, local_data, 16); - } - if (BlockDimX >= 16) { - local_data = local_data + __shfl_down_sync(0xFFFFFFFF, local_data, 8); - } - if (BlockDimX >= 8) { - local_data = local_data + __shfl_down_sync(0xFFFFFFFF, local_data, 4); - } - if (BlockDimX >= 4) { - local_data = local_data + __shfl_down_sync(0xFFFFFFFF, local_data, 2); - } - if (BlockDimX >= 2) { - local_data = local_data + __shfl_down_sync(0xFFFFFFFF, local_data, 1); - } - if (tid == 0) { - shared_data[tid] = local_data; - } -} - -template -__device__ __forceinline__ void BinaryWarpReduce(T *shared_data0, S *shared_data1, const int tid) { - T local_data0 = shared_data0[tid]; - S local_data1 = shared_data1[tid]; - if (BlockDimX >= 32) { - local_data0 = local_data0 + __shfl_down_sync(0xFFFFFFFF, local_data0, 16); - local_data1 = local_data1 + __shfl_down_sync(0xFFFFFFFF, local_data1, 16); - } - if (BlockDimX >= 16) { - local_data0 = local_data0 + __shfl_down_sync(0xFFFFFFFF, local_data0, 8); - local_data1 = local_data1 + __shfl_down_sync(0xFFFFFFFF, local_data1, 8); - } - if (BlockDimX >= 8) { - local_data0 = local_data0 + __shfl_down_sync(0xFFFFFFFF, local_data0, 4); - local_data1 = local_data1 + __shfl_down_sync(0xFFFFFFFF, local_data1, 4); - } - if (BlockDimX >= 4) { - local_data0 = local_data0 + __shfl_down_sync(0xFFFFFFFF, local_data0, 2); - local_data1 = local_data1 + __shfl_down_sync(0xFFFFFFFF, local_data1, 2); - } - if (BlockDimX >= 2) { - local_data0 = local_data0 + __shfl_down_sync(0xFFFFFFFF, local_data0, 1); - local_data1 = local_data1 + __shfl_down_sync(0xFFFFFFFF, local_data1, 1); - } - if (tid == 0) { - shared_data0[tid] = local_data0; - shared_data1[tid] = local_data1; - } -} - -template -__device__ __forceinline__ void BlockReduce(T *shared_data, const unsigned int tid) { - if (BlockDimX >= 1024) { - if (tid < 512) { - shared_data[tid] = shared_data[tid] + shared_data[tid + 512]; - } - __syncthreads(); - } - if (BlockDimX >= 512) { - if (tid < 256) { - shared_data[tid] = shared_data[tid] + shared_data[tid + 256]; - } - __syncthreads(); - } - if (BlockDimX >= 256) { - if (tid < 128) { - shared_data[tid] = shared_data[tid] + shared_data[tid + 128]; - } - __syncthreads(); - } - if (BlockDimX >= 128) { - if (tid < 64) { - shared_data[tid] = shared_data[tid] + shared_data[tid + 64]; - } - __syncthreads(); - } - if (BlockDimX >= 64) { - if (tid < 32) { - shared_data[tid] = shared_data[tid] + shared_data[tid + 32]; - } - } - __syncthreads(); - - if (tid < 32) WarpReduce(shared_data, tid); - - __syncthreads(); -} - -template -__device__ __forceinline__ void BinaryBlockReduce(T *shared_data0, S *shared_data1, const unsigned int tid) { - if (BlockDimX >= 1024) { - if (tid < 512) { - shared_data0[tid] = shared_data0[tid] + shared_data0[tid + 512]; - shared_data1[tid] = shared_data1[tid] + shared_data1[tid + 512]; - } - __syncthreads(); - } - if (BlockDimX >= 512) { - if (tid < 256) { - shared_data0[tid] = shared_data0[tid] + shared_data0[tid + 256]; - shared_data1[tid] = shared_data1[tid] + shared_data1[tid + 256]; - } - __syncthreads(); - } - if (BlockDimX >= 256) { - if (tid < 128) { - shared_data0[tid] = shared_data0[tid] + shared_data0[tid + 128]; - shared_data1[tid] = shared_data1[tid] + shared_data1[tid + 128]; - } - __syncthreads(); - } - if (BlockDimX >= 128) { - if (tid < 64) { - shared_data0[tid] = shared_data0[tid] + shared_data0[tid + 64]; - shared_data1[tid] = shared_data1[tid] + shared_data1[tid + 64]; - } - __syncthreads(); - } - if (BlockDimX >= 64) { - if (tid < 32) { - shared_data0[tid] = shared_data0[tid] + shared_data0[tid + 32]; - shared_data1[tid] = shared_data1[tid] + shared_data1[tid + 32]; - } - } - __syncthreads(); - - if (tid < 32) BinaryWarpReduce(shared_data0, shared_data1, tid); - - __syncthreads(); -} - -template -__inline__ __device__ void Reduce(T *output, T *shared_data, const unsigned int tid) { - BlockReduce(shared_data, tid); - - if (tid == 0) { - MsAtomicAdd(output, shared_data[0]); - } -} - -template -__inline__ __device__ void BinaryReduce(T *output0, S *output1, T *shared_data0, S *shared_data1, - const unsigned int tid) { - BinaryBlockReduce(shared_data0, shared_data1, tid); - - if (tid == 0) { - MsAtomicAdd(output0, shared_data0[0]); - MsAtomicAdd(output1, shared_data1[0]); - } -} - -template -__global__ void NLLLossNativeKernel(const T *logits, const int32_t *labels, const S *weights, T *loss, S *total_weight, - unsigned int label_size, unsigned int num_classes, int32_t ignore_index) { - unsigned int tid = threadIdx.x; - const S zero = static_cast(0); - const S one = static_cast(1); - __shared__ S shared_total_weight[sharedSize]; - shared_total_weight[tid] = zero; - if (tid == 0 && blockIdx.x == 0) { - total_weight[0] = zero; - } - - for (unsigned int gid = blockIdx.x * BlockDimX + tid, gridSize = BlockDimX * gridDim.x; gid < label_size; - gid += gridSize) { - int32_t label = labels[gid]; - if (label != ignore_index) { - CUDA_KERNEL_ASSERT(label >= 0 && label < num_classes); - S weight = weights ? weights[label] : one; - T logit; - MultiplyDevice(weight, -(logits[gid * num_classes + label]), &logit); - loss[gid] = logit; - shared_total_weight[tid] = shared_total_weight[tid] + weight; - } - } - __syncthreads(); - Reduce(total_weight, shared_total_weight, tid); -} - -template -__global__ void NLLLossReduceKernel(const T *logits, const int32_t *labels, const S *weights, T *loss, S *total_weight, - unsigned int label_size, unsigned int num_classes, int32_t ignore_index, - bool mean) { - unsigned int tid = threadIdx.x; - const S one = static_cast(1); - __shared__ T shared_loss[sharedSize0]; - __shared__ S shared_total_weight[sharedSize1]; - shared_loss[tid] = static_cast(0); - shared_total_weight[tid] = static_cast(0); - if (tid == 0 && blockIdx.x == 0) { - loss[0] = static_cast(0); - total_weight[0] = static_cast(0); - } - - for (unsigned int gid = blockIdx.x * BlockDimX + tid, gridSize = BlockDimX * gridDim.x; gid < label_size; - gid += gridSize) { - int32_t label = labels[gid]; - if (label != ignore_index) { - CUDA_KERNEL_ASSERT(label >= 0 && label < num_classes); - S weight = weights ? weights[label] : one; - T logit; - MultiplyDevice(weight, -(logits[gid * num_classes + label]), &logit); - shared_loss[tid] = shared_loss[tid] + logit; - shared_total_weight[tid] = shared_total_weight[tid] + weight; - } - } - __syncthreads(); - BinaryReduce(loss, total_weight, shared_loss, shared_total_weight, tid); - if (mean && tid == 0) { - __syncthreads(); - Divide(loss, total_weight, loss); - } -} - -template -__global__ void LossInitKernel(T *loss) { - loss[0] = static_cast(0.); -} - -template -__global__ void InitZero(T *array, int size) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) { - array[i] = static_cast(0.); - } -} - -template -__global__ void KLDivLossKernel(const int input_size, const ReductionMode reduction, const T *input_x, const T *input_y, - T *loss, T *tmp_loss) { - T epsilon = 1e-6; - if (reduction == ReductionMode::kNone) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { - T denominator = maxT(input_y[i], epsilon); - T value = input_y[i] * (logT(denominator) - input_x[i]); - loss[i] = value; - } - } else { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { - T denominator = maxT(input_y[i], epsilon); - T value = input_y[i] * (logT(denominator) - input_x[i]); - tmp_loss[i] = value; - } - } -} - -template -cudaError_t KLDivLoss(const int &input_size, const ReductionMode &reduction, const T *input_x, const T *input_y, - T *loss, T *tmp_loss, cudaStream_t stream) { - LossInitKernel<<<1, 1, 0, stream>>>(loss); - KLDivLossKernel<<>>(input_size, reduction, input_x, input_y, loss, - tmp_loss); - if (reduction != ReductionMode::kNone) { - if (input_size % 2 == 1) { - AddTile<<<1, 1, 0, stream>>>(tmp_loss, input_size - 1); - } - for (int stride = input_size / 2; stride > 0; stride >>= 1) { - PartialSum<<>>(tmp_loss, stride); - if (stride > 2 && stride % 2 == 1) { - AddTile<<<1, 1, 0, stream>>>(tmp_loss, stride - 1); - } - } - Copy<<<1, 1, 0, stream>>>(loss, tmp_loss, reduction, input_size); - } - return GetCudaStatus(); -} - -template -__global__ void KLDivLossGradKernel(const int input_size, const ReductionMode reduction, const T *input_x, - const T *input_y, const T *dloss, T *dx) { - T epsilon = 1e-6; - if (reduction == ReductionMode::kNone) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { - T denominator = maxT(input_y[i], epsilon); - dx[i] = -input_y[i] * dloss[i]; - } - } else { - T dloss1 = dloss[0]; - if (reduction == ReductionMode::kMean) { - dloss1 = dloss[0] / castT(dloss[0], input_size); - } - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { - T denominator = maxT(input_y[i], epsilon); - dx[i] = -input_y[i] * dloss1; - } - } -} - -template -cudaError_t KLDivLossGrad(const int &input_size, const ReductionMode &reduction, const T *input_x, const T *input_y, - const T *dloss, T *dx, cudaStream_t stream) { - KLDivLossGradKernel<<>>(input_size, reduction, input_x, input_y, - dloss, dx); - return GetCudaStatus(); -} - -template -__global__ void BinaryCrossEntropyLossKernel(const int input_size, const ReductionMode reduction, const T *input_x, - const T *input_y, const T *weight, T *loss, T *tmp_loss) { - T epsilon = 1e-12; - T zero = static_cast(0); - T one = static_cast(1); - if (reduction == ReductionMode::kNone && weight != nullptr) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { - CUDA_KERNEL_ASSERT(input_x[i] >= zero && input_x[i] <= one); - T value = - -weight[i] * (input_y[i] * logT(input_x[i] + epsilon) + (one - input_y[i]) * logT(one - input_x[i] + epsilon)); - loss[i] = value; - } - } else if (reduction == ReductionMode::kNone && weight == nullptr) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { - CUDA_KERNEL_ASSERT(input_x[i] >= zero && input_x[i] <= one); - T value = -(input_y[i] * logT(input_x[i] + epsilon) + (one - input_y[i]) * logT(one - input_x[i] + epsilon)); - loss[i] = value; - } - } else if (reduction != ReductionMode::kNone && weight != nullptr) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { - CUDA_KERNEL_ASSERT(input_x[i] >= zero && input_x[i] <= one); - T value = - -weight[i] * (input_y[i] * logT(input_x[i] + epsilon) + (one - input_y[i]) * logT(one - input_x[i] + epsilon)); - tmp_loss[i] = value; - } - } else { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { - CUDA_KERNEL_ASSERT(input_x[i] >= zero && input_x[i] <= one); - T value = -(input_y[i] * logT(input_x[i] + epsilon) + (one - input_y[i]) * logT(one - input_x[i] + epsilon)); - tmp_loss[i] = value; - } - } -} - -template -cudaError_t BinaryCrossEntropyLoss(const int &input_size, const ReductionMode &reduction, const T *input_x, - const T *input_y, const T *weight, T *loss, T *tmp_loss, cudaStream_t stream) { - LossInitKernel<<<1, 1, 0, stream>>>(loss); - BinaryCrossEntropyLossKernel<<>>(input_size, reduction, input_x, - input_y, weight, loss, tmp_loss); - if (reduction != ReductionMode::kNone) { - if (input_size % 2 == 1) { - AddTile<<<1, 1, 0, stream>>>(tmp_loss, input_size - 1); - } - for (int stride = input_size / 2; stride > 0; stride >>= 1) { - PartialSum<<>>(tmp_loss, stride); - if (stride > 2 && stride % 2 == 1) { - AddTile<<<1, 1, 0, stream>>>(tmp_loss, stride - 1); - } - } - Copy<<<1, 1, 0, stream>>>(loss, tmp_loss, reduction, input_size); - } - return GetCudaStatus(); -} - -template -__global__ void BinaryCrossEntropyLossGradKernel(const int input_size, const ReductionMode reduction, const T *input_x, - const T *input_y, const T *weight, const T *dloss, T *dx) { - T epsilon = 1e-12; - T one = static_cast(1); - if (reduction == ReductionMode::kNone) { - if (weight != nullptr) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { - T denominator = maxT(input_x[i] * (one - input_x[i]), epsilon); - T value = weight[i] * (input_x[i] - input_y[i]) / denominator; - dx[i] = value * dloss[i]; - } - } else { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { - T denominator = maxT(input_x[i] * (one - input_x[i]), epsilon); - T value = (input_x[i] - input_y[i]) / denominator; - dx[i] = value * dloss[i]; - } - } - } else { - T dloss1 = dloss[0]; - if (reduction == ReductionMode::kMean) { - dloss1 = dloss[0] / castT(dloss[0], input_size); - } - if (weight != nullptr) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { - T denominator = maxT(input_x[i] * (one - input_x[i]), epsilon); - T value = weight[i] * (input_x[i] - input_y[i]) / denominator; - dx[i] = value * dloss1; - } - } else { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { - T denominator = maxT(input_x[i] * (one - input_x[i]), epsilon); - T value = (input_x[i] - input_y[i]) / denominator; - dx[i] = value * dloss1; - } - } - } -} - -template -cudaError_t BinaryCrossEntropyLossGrad(const int &input_size, const ReductionMode &reduction, const T *input_x, - const T *input_y, const T *weight, const T *dloss, T *dx, cudaStream_t stream) { - BinaryCrossEntropyLossGradKernel<<>>(input_size, reduction, input_x, - input_y, weight, dloss, dx); - return GetCudaStatus(); -} - -template -cudaError_t NLLLoss(const T *logits, const int32_t *labels, const S *weights, T *loss, S *total_weight, - unsigned int label_size, unsigned int num_classes, const ReductionMode reduction, - int32_t ignore_index, cudaStream_t stream) { - const unsigned int Threads = 512; - if (reduction == ReductionMode::kNone) { - const unsigned int sharedSize = Threads * sizeof(S) + 1; - NLLLossNativeKernel<<>>( - logits, labels, weights, loss, total_weight, label_size, num_classes, ignore_index); - } else { - bool mean = (reduction == ReductionMode::kMean); - const unsigned int sharedSize0 = Threads * sizeof(T) + 1; - const unsigned int sharedSize1 = Threads * sizeof(S) + 1; - NLLLossReduceKernel<<>>( - logits, labels, weights, loss, total_weight, label_size, num_classes, ignore_index, mean); - } - cudaStreamSynchronize(stream); - return GetCudaStatus(); -} - -template -__global__ void NLLLossGradKernel(const int n, const int c, const ReductionMode reduction, const T *input, - const int32_t *target, const S *weight, const S *total_weight, int32_t ignore_index, - const T *dloss, T *dinput) { - int input_idx; - int target_class; - S tmp_quot; - if (reduction == ReductionMode::kNone) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) { - target_class = static_cast(target[i]); - if (target_class == ignore_index) { - continue; - } - - input_idx = (i * c) + target_class; - - MultiplyDevice(-weight[target_class], dloss[i], dinput + input_idx); - } - } else if (reduction == ReductionMode::kMean) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) { - target_class = static_cast(target[i]); - if (target_class == ignore_index) { - continue; - } - - input_idx = (i * c) + target_class; - - tmp_quot = (-weight[target_class]) / *total_weight; - MultiplyDevice(tmp_quot, dloss[0], dinput + input_idx); - } - } else if (reduction == ReductionMode::kSum) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) { - target_class = static_cast(target[i]); - if (target_class == ignore_index) { - continue; - } - - input_idx = (i * c) + target_class; - - MultiplyDevice(-weight[target_class], dloss[0], dinput + input_idx); - } - } -} - -template -cudaError_t NLLLossGrad(const int n, const int c, const ReductionMode reduction, const T *input, const int32_t *target, - const S *weight, const S *total_weight, const T *dloss, T *dinput, int32_t ignore_index, - cudaStream_t stream) { - int input_size = n * c; - InitZero<<>>(dinput, input_size); - - NLLLossGradKernel<<>>(n, c, reduction, input, target, weight, total_weight, - ignore_index, dloss, dinput); - return GetCudaStatus(); -} - -template CUDA_LIB_EXPORT cudaError_t NLLLoss(const half *logits, const int32_t *labels, const half *weights, - half *loss, half *total_weight, const unsigned int label_size, - const unsigned int num_classes, const ReductionMode reduction, - int32_t ignore_index, cudaStream_t stream); - -template CUDA_LIB_EXPORT cudaError_t NLLLoss(const half *logits, const int32_t *labels, - const float *weights, half *loss, float *total_weight, - unsigned int label_size, unsigned int num_classes, - const ReductionMode reduction, int32_t ignore_index, - cudaStream_t stream); - -template CUDA_LIB_EXPORT cudaError_t NLLLoss(const float *logits, const int32_t *labels, - const half *weights, float *loss, half *total_weight, - unsigned int label_size, unsigned int num_classes, - const ReductionMode reduction, int32_t ignore_index, - cudaStream_t stream); - -template CUDA_LIB_EXPORT cudaError_t NLLLoss(const float *logits, const int32_t *labels, - const float *weights, float *loss, float *total_weight, - unsigned int label_size, unsigned int num_classes, - const ReductionMode reduction, int32_t ignore_index, - cudaStream_t stream); - -template CUDA_LIB_EXPORT cudaError_t KLDivLoss(const int &input_size, const ReductionMode &reduction, - const float *input_x, const float *input_y, float *loss, - float *tmp_loss, cudaStream_t stream); - -template CUDA_LIB_EXPORT cudaError_t KLDivLossGrad(const int &input_size, const ReductionMode &reduction, - const float *input_x, const float *input_y, - const float *dloss, float *dx, cudaStream_t stream); - -template CUDA_LIB_EXPORT cudaError_t KLDivLoss(const int &input_size, const ReductionMode &reduction, - const double *input_x, const double *input_y, double *loss, - double *tmp_loss, cudaStream_t stream); - -template CUDA_LIB_EXPORT cudaError_t KLDivLossGrad(const int &input_size, const ReductionMode &reduction, - const double *input_x, const double *input_y, - const double *dloss, double *dx, cudaStream_t stream); - -template CUDA_LIB_EXPORT cudaError_t BinaryCrossEntropyLoss(const int &input_size, - const ReductionMode &reduction, const float *input_x, - const float *input_y, const float *weight, - float *loss, float *tmp_loss, cudaStream_t stream); - -template CUDA_LIB_EXPORT cudaError_t BinaryCrossEntropyLossGrad(const int &input_size, - const ReductionMode &reduction, - const float *input_x, const float *input_y, - const float *weight, const float *dloss, - float *dx, cudaStream_t stream); - -template CUDA_LIB_EXPORT cudaError_t NLLLossGrad(const int n, const int c, const ReductionMode reduction, - const float *input, const int32_t *target, - const float *weight, const float *total_weight, - const float *dloss, float *dinput, int32_t ignore_index, - cudaStream_t stream); - -template CUDA_LIB_EXPORT cudaError_t NLLLossGrad(const int n, const int c, const ReductionMode reduction, - const float *input, const int32_t *target, - const half *weight, const half *total_weight, - const float *dloss, float *dinput, int32_t ignore_index, - cudaStream_t stream); - -template CUDA_LIB_EXPORT cudaError_t KLDivLoss(const int &input_size, const ReductionMode &reduction, - const half *input_x, const half *input_y, half *loss, - half *tmp_loss, cudaStream_t stream); - -template CUDA_LIB_EXPORT cudaError_t KLDivLossGrad(const int &input_size, const ReductionMode &reduction, - const half *input_x, const half *input_y, const half *dloss, - half *dx, cudaStream_t stream); - -template CUDA_LIB_EXPORT cudaError_t BinaryCrossEntropyLoss(const int &input_size, const ReductionMode &reduction, - const half *input_x, const half *input_y, - const half *weight, half *loss, half *tmp_loss, - cudaStream_t stream); - -template CUDA_LIB_EXPORT cudaError_t BinaryCrossEntropyLossGrad(const int &input_size, - const ReductionMode &reduction, - const half *input_x, const half *input_y, - const half *weight, const half *dloss, half *dx, - cudaStream_t stream); - -template CUDA_LIB_EXPORT cudaError_t NLLLossGrad(const int n, const int c, const ReductionMode reduction, - const half *input, const int32_t *target, - const half *weight, const half *total_weight, - const half *dloss, half *dinput, int32_t ignore_index, - cudaStream_t stream); - -template CUDA_LIB_EXPORT cudaError_t NLLLossGrad(const int n, const int c, const ReductionMode reduction, - const half *input, const int32_t *target, - const float *weight, const float *total_weight, - const half *dloss, half *dinput, int32_t ignore_index, - cudaStream_t stream); +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "loss_with_reduction_impl.cuh" +#include "util.cuh" + +inline __device__ float logT(float x) { return logf(x); } +inline __device__ half logT(half x) { return hlog(x); } +inline __device__ float castT(float ref, int x) { return __int2float_rd(x); } +inline __device__ half castT(half ref, int x) { return __int2half_rd(x); } +inline __device__ float maxT(float a, float b) { return fmaxf(a, b); } +inline __device__ half maxT(half a, half b) { return a > b ? a : b; } + +template +__global__ void Copy(T *loss, T *tmp_loss, ReductionMode reduction, int input_size) { + loss[0] += tmp_loss[0]; + if (reduction == ReductionMode::kMean) { + loss[0] /= castT(loss[0], input_size); + } +} + +template +__global__ void AddTile(T *tmp_loss, int index) { + tmp_loss[0] += tmp_loss[index]; +} +template +__global__ void PartialSum(T *tmp_loss, int stride) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < stride; i += blockDim.x * gridDim.x) { + tmp_loss[i] += tmp_loss[i + stride]; + } +} + +template +__device__ void MultiplyDevice(const S a, const T b, T *out) { + *out = a * b; +} + +template <> +__device__ void MultiplyDevice(const half a, const float b, float *out) { + // cast a to float for calculation + float a_float = __half2float(a); + *out = a_float * b; +} + +template <> +__device__ void MultiplyDevice(const float a, const half b, half *out) { + // cast b to float for calculation + float b_float = __half2float(b); + float out_float = a * b_float; + *out = __float2half(out_float); +} + +template +__device__ void Divide(const T *numerator, const S *denominator, T *result) { + result[0] = numerator[0] / denominator[0]; +} + +template <> +__device__ void Divide(const float *numerator, const half *denominator, float *result) { + float denom_float = __half2float(denominator[0]); + + result[0] = numerator[0] / denom_float; +} + +template <> +__device__ void Divide(const half *numerator, const float *denominator, half *result) { + float numer_float = __half2float(numerator[0]); + + float result_float = numer_float / denominator[0]; + + result[0] = __float2half(result_float); +} + +template +__device__ __forceinline__ void WarpReduce(T *shared_data, const int tid) { + T local_data = shared_data[tid]; + if (BlockDimX >= 32) { + local_data = local_data + __shfl_down_sync(0xFFFFFFFF, local_data, 16); + } + if (BlockDimX >= 16) { + local_data = local_data + __shfl_down_sync(0xFFFFFFFF, local_data, 8); + } + if (BlockDimX >= 8) { + local_data = local_data + __shfl_down_sync(0xFFFFFFFF, local_data, 4); + } + if (BlockDimX >= 4) { + local_data = local_data + __shfl_down_sync(0xFFFFFFFF, local_data, 2); + } + if (BlockDimX >= 2) { + local_data = local_data + __shfl_down_sync(0xFFFFFFFF, local_data, 1); + } + if (tid == 0) { + shared_data[tid] = local_data; + } +} + +template +__device__ __forceinline__ void BinaryWarpReduce(T *shared_data0, S *shared_data1, const int tid) { + T local_data0 = shared_data0[tid]; + S local_data1 = shared_data1[tid]; + if (BlockDimX >= 32) { + local_data0 = local_data0 + __shfl_down_sync(0xFFFFFFFF, local_data0, 16); + local_data1 = local_data1 + __shfl_down_sync(0xFFFFFFFF, local_data1, 16); + } + if (BlockDimX >= 16) { + local_data0 = local_data0 + __shfl_down_sync(0xFFFFFFFF, local_data0, 8); + local_data1 = local_data1 + __shfl_down_sync(0xFFFFFFFF, local_data1, 8); + } + if (BlockDimX >= 8) { + local_data0 = local_data0 + __shfl_down_sync(0xFFFFFFFF, local_data0, 4); + local_data1 = local_data1 + __shfl_down_sync(0xFFFFFFFF, local_data1, 4); + } + if (BlockDimX >= 4) { + local_data0 = local_data0 + __shfl_down_sync(0xFFFFFFFF, local_data0, 2); + local_data1 = local_data1 + __shfl_down_sync(0xFFFFFFFF, local_data1, 2); + } + if (BlockDimX >= 2) { + local_data0 = local_data0 + __shfl_down_sync(0xFFFFFFFF, local_data0, 1); + local_data1 = local_data1 + __shfl_down_sync(0xFFFFFFFF, local_data1, 1); + } + if (tid == 0) { + shared_data0[tid] = local_data0; + shared_data1[tid] = local_data1; + } +} + +template +__device__ __forceinline__ void BlockReduce(T *shared_data, const unsigned int tid) { + if (BlockDimX >= 1024) { + if (tid < 512) { + shared_data[tid] = shared_data[tid] + shared_data[tid + 512]; + } + __syncthreads(); + } + if (BlockDimX >= 512) { + if (tid < 256) { + shared_data[tid] = shared_data[tid] + shared_data[tid + 256]; + } + __syncthreads(); + } + if (BlockDimX >= 256) { + if (tid < 128) { + shared_data[tid] = shared_data[tid] + shared_data[tid + 128]; + } + __syncthreads(); + } + if (BlockDimX >= 128) { + if (tid < 64) { + shared_data[tid] = shared_data[tid] + shared_data[tid + 64]; + } + __syncthreads(); + } + if (BlockDimX >= 64) { + if (tid < 32) { + shared_data[tid] = shared_data[tid] + shared_data[tid + 32]; + } + } + __syncthreads(); + + if (tid < 32) WarpReduce(shared_data, tid); + + __syncthreads(); +} + +template +__device__ __forceinline__ void BinaryBlockReduce(T *shared_data0, S *shared_data1, const unsigned int tid) { + if (BlockDimX >= 1024) { + if (tid < 512) { + shared_data0[tid] = shared_data0[tid] + shared_data0[tid + 512]; + shared_data1[tid] = shared_data1[tid] + shared_data1[tid + 512]; + } + __syncthreads(); + } + if (BlockDimX >= 512) { + if (tid < 256) { + shared_data0[tid] = shared_data0[tid] + shared_data0[tid + 256]; + shared_data1[tid] = shared_data1[tid] + shared_data1[tid + 256]; + } + __syncthreads(); + } + if (BlockDimX >= 256) { + if (tid < 128) { + shared_data0[tid] = shared_data0[tid] + shared_data0[tid + 128]; + shared_data1[tid] = shared_data1[tid] + shared_data1[tid + 128]; + } + __syncthreads(); + } + if (BlockDimX >= 128) { + if (tid < 64) { + shared_data0[tid] = shared_data0[tid] + shared_data0[tid + 64]; + shared_data1[tid] = shared_data1[tid] + shared_data1[tid + 64]; + } + __syncthreads(); + } + if (BlockDimX >= 64) { + if (tid < 32) { + shared_data0[tid] = shared_data0[tid] + shared_data0[tid + 32]; + shared_data1[tid] = shared_data1[tid] + shared_data1[tid + 32]; + } + } + __syncthreads(); + + if (tid < 32) BinaryWarpReduce(shared_data0, shared_data1, tid); + + __syncthreads(); +} + +template +__inline__ __device__ void Reduce(T *output, T *shared_data, const unsigned int tid) { + BlockReduce(shared_data, tid); + + if (tid == 0) { + MsAtomicAdd(output, shared_data[0]); + } +} + +template +__inline__ __device__ void BinaryReduce(T *output0, S *output1, T *shared_data0, S *shared_data1, + const unsigned int tid) { + BinaryBlockReduce(shared_data0, shared_data1, tid); + + if (tid == 0) { + MsAtomicAdd(output0, shared_data0[0]); + MsAtomicAdd(output1, shared_data1[0]); + } +} + +template +__global__ void NLLLossNativeKernel(const T *logits, const int32_t *labels, const S *weights, T *loss, S *total_weight, + unsigned int label_size, unsigned int num_classes, int32_t ignore_index) { + unsigned int tid = threadIdx.x; + const S zero = static_cast(0); + const S one = static_cast(1); + __shared__ S shared_total_weight[sharedSize]; + shared_total_weight[tid] = zero; + if (tid == 0 && blockIdx.x == 0) { + total_weight[0] = zero; + } + + for (unsigned int gid = blockIdx.x * BlockDimX + tid, gridSize = BlockDimX * gridDim.x; gid < label_size; + gid += gridSize) { + int32_t label = labels[gid]; + if (label != ignore_index) { + CUDA_KERNEL_ASSERT(label >= 0 && label < num_classes); + S weight = weights ? weights[label] : one; + T logit; + MultiplyDevice(weight, -(logits[gid * num_classes + label]), &logit); + loss[gid] = logit; + shared_total_weight[tid] = shared_total_weight[tid] + weight; + } + } + __syncthreads(); + Reduce(total_weight, shared_total_weight, tid); +} + +template +__global__ void NLLLossReduceKernel(const T *logits, const int32_t *labels, const S *weights, T *loss, S *total_weight, + unsigned int label_size, unsigned int num_classes, int32_t ignore_index, + bool mean) { + unsigned int tid = threadIdx.x; + const S one = static_cast(1); + __shared__ T shared_loss[sharedSize0]; + __shared__ S shared_total_weight[sharedSize1]; + shared_loss[tid] = static_cast(0); + shared_total_weight[tid] = static_cast(0); + if (tid == 0 && blockIdx.x == 0) { + loss[0] = static_cast(0); + total_weight[0] = static_cast(0); + } + + for (unsigned int gid = blockIdx.x * BlockDimX + tid, gridSize = BlockDimX * gridDim.x; gid < label_size; + gid += gridSize) { + int32_t label = labels[gid]; + if (label != ignore_index) { + CUDA_KERNEL_ASSERT(label >= 0 && label < num_classes); + S weight = weights ? weights[label] : one; + T logit; + MultiplyDevice(weight, -(logits[gid * num_classes + label]), &logit); + shared_loss[tid] = shared_loss[tid] + logit; + shared_total_weight[tid] = shared_total_weight[tid] + weight; + } + } + __syncthreads(); + BinaryReduce(loss, total_weight, shared_loss, shared_total_weight, tid); + if (mean && tid == 0) { + __syncthreads(); + Divide(loss, total_weight, loss); + } +} + +template +__global__ void LossInitKernel(T *loss) { + loss[0] = static_cast(0.); +} + +template +__global__ void InitZero(T *array, int size) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) { + array[i] = static_cast(0.); + } +} + +template +__global__ void KLDivLossKernel(const int input_size, const ReductionMode reduction, const T *input_x, const T *input_y, + T *loss, T *tmp_loss) { + T epsilon = 1e-6; + if (reduction == ReductionMode::kNone) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { + T denominator = maxT(input_y[i], epsilon); + T value = input_y[i] * (logT(denominator) - input_x[i]); + loss[i] = value; + } + } else { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { + T denominator = maxT(input_y[i], epsilon); + T value = input_y[i] * (logT(denominator) - input_x[i]); + tmp_loss[i] = value; + } + } +} + +template +cudaError_t KLDivLoss(const int &input_size, const ReductionMode &reduction, const T *input_x, const T *input_y, + T *loss, T *tmp_loss, cudaStream_t stream) { + LossInitKernel<<<1, 1, 0, stream>>>(loss); + KLDivLossKernel<<>>(input_size, reduction, input_x, input_y, loss, + tmp_loss); + if (reduction != ReductionMode::kNone) { + if (input_size % 2 == 1) { + AddTile<<<1, 1, 0, stream>>>(tmp_loss, input_size - 1); + } + for (int stride = input_size / 2; stride > 0; stride >>= 1) { + PartialSum<<>>(tmp_loss, stride); + if (stride > 2 && stride % 2 == 1) { + AddTile<<<1, 1, 0, stream>>>(tmp_loss, stride - 1); + } + } + Copy<<<1, 1, 0, stream>>>(loss, tmp_loss, reduction, input_size); + } + return GetCudaStatus(); +} + +template +__global__ void KLDivLossGradKernel(const int input_size, const ReductionMode reduction, const T *input_x, + const T *input_y, const T *dloss, T *dx) { + T epsilon = 1e-6; + if (reduction == ReductionMode::kNone) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { + T denominator = maxT(input_y[i], epsilon); + dx[i] = -input_y[i] * dloss[i]; + } + } else { + T dloss1 = dloss[0]; + if (reduction == ReductionMode::kMean) { + dloss1 = dloss[0] / castT(dloss[0], input_size); + } + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { + T denominator = maxT(input_y[i], epsilon); + dx[i] = -input_y[i] * dloss1; + } + } +} + +template +cudaError_t KLDivLossGrad(const int &input_size, const ReductionMode &reduction, const T *input_x, const T *input_y, + const T *dloss, T *dx, cudaStream_t stream) { + KLDivLossGradKernel<<>>(input_size, reduction, input_x, input_y, + dloss, dx); + return GetCudaStatus(); +} + +template +__global__ void BinaryCrossEntropyLossKernel(const int input_size, const ReductionMode reduction, const T *input_x, + const T *input_y, const T *weight, T *loss, T *tmp_loss) { + T epsilon = 1e-12; + T zero = static_cast(0); + T one = static_cast(1); + if (reduction == ReductionMode::kNone && weight != nullptr) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { + CUDA_KERNEL_ASSERT(input_x[i] >= zero && input_x[i] <= one); + T value = + -weight[i] * (input_y[i] * logT(input_x[i] + epsilon) + (one - input_y[i]) * logT(one - input_x[i] + epsilon)); + loss[i] = value; + } + } else if (reduction == ReductionMode::kNone && weight == nullptr) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { + CUDA_KERNEL_ASSERT(input_x[i] >= zero && input_x[i] <= one); + T value = -(input_y[i] * logT(input_x[i] + epsilon) + (one - input_y[i]) * logT(one - input_x[i] + epsilon)); + loss[i] = value; + } + } else if (reduction != ReductionMode::kNone && weight != nullptr) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { + CUDA_KERNEL_ASSERT(input_x[i] >= zero && input_x[i] <= one); + T value = + -weight[i] * (input_y[i] * logT(input_x[i] + epsilon) + (one - input_y[i]) * logT(one - input_x[i] + epsilon)); + tmp_loss[i] = value; + } + } else { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { + CUDA_KERNEL_ASSERT(input_x[i] >= zero && input_x[i] <= one); + T value = -(input_y[i] * logT(input_x[i] + epsilon) + (one - input_y[i]) * logT(one - input_x[i] + epsilon)); + tmp_loss[i] = value; + } + } +} + +template +cudaError_t BinaryCrossEntropyLoss(const int &input_size, const ReductionMode &reduction, const T *input_x, + const T *input_y, const T *weight, T *loss, T *tmp_loss, cudaStream_t stream) { + LossInitKernel<<<1, 1, 0, stream>>>(loss); + BinaryCrossEntropyLossKernel<<>>(input_size, reduction, input_x, + input_y, weight, loss, tmp_loss); + if (reduction != ReductionMode::kNone) { + if (input_size % 2 == 1) { + AddTile<<<1, 1, 0, stream>>>(tmp_loss, input_size - 1); + } + for (int stride = input_size / 2; stride > 0; stride >>= 1) { + PartialSum<<>>(tmp_loss, stride); + if (stride > 2 && stride % 2 == 1) { + AddTile<<<1, 1, 0, stream>>>(tmp_loss, stride - 1); + } + } + Copy<<<1, 1, 0, stream>>>(loss, tmp_loss, reduction, input_size); + } + return GetCudaStatus(); +} + +template +__global__ void BinaryCrossEntropyLossGradKernel(const int input_size, const ReductionMode reduction, const T *input_x, + const T *input_y, const T *weight, const T *dloss, T *dx) { + T epsilon = 1e-12; + T one = static_cast(1); + if (reduction == ReductionMode::kNone) { + if (weight != nullptr) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { + T denominator = maxT(input_x[i] * (one - input_x[i]), epsilon); + T value = weight[i] * (input_x[i] - input_y[i]) / denominator; + dx[i] = value * dloss[i]; + } + } else { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { + T denominator = maxT(input_x[i] * (one - input_x[i]), epsilon); + T value = (input_x[i] - input_y[i]) / denominator; + dx[i] = value * dloss[i]; + } + } + } else { + T dloss1 = dloss[0]; + if (reduction == ReductionMode::kMean) { + dloss1 = dloss[0] / castT(dloss[0], input_size); + } + if (weight != nullptr) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { + T denominator = maxT(input_x[i] * (one - input_x[i]), epsilon); + T value = weight[i] * (input_x[i] - input_y[i]) / denominator; + dx[i] = value * dloss1; + } + } else { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { + T denominator = maxT(input_x[i] * (one - input_x[i]), epsilon); + T value = (input_x[i] - input_y[i]) / denominator; + dx[i] = value * dloss1; + } + } + } +} + +template +cudaError_t BinaryCrossEntropyLossGrad(const int &input_size, const ReductionMode &reduction, const T *input_x, + const T *input_y, const T *weight, const T *dloss, T *dx, cudaStream_t stream) { + BinaryCrossEntropyLossGradKernel<<>>(input_size, reduction, input_x, + input_y, weight, dloss, dx); + return GetCudaStatus(); +} + +template +cudaError_t NLLLoss(const T *logits, const int32_t *labels, const S *weights, T *loss, S *total_weight, + unsigned int label_size, unsigned int num_classes, const ReductionMode reduction, + int32_t ignore_index, cudaStream_t stream) { + const unsigned int Threads = 512; + if (reduction == ReductionMode::kNone) { + const unsigned int sharedSize = Threads * sizeof(S) + 1; + NLLLossNativeKernel<<>>( + logits, labels, weights, loss, total_weight, label_size, num_classes, ignore_index); + } else { + bool mean = (reduction == ReductionMode::kMean); + const unsigned int sharedSize0 = Threads * sizeof(T) + 1; + const unsigned int sharedSize1 = Threads * sizeof(S) + 1; + NLLLossReduceKernel<<>>( + logits, labels, weights, loss, total_weight, label_size, num_classes, ignore_index, mean); + } + cudaStreamSynchronize(stream); + return GetCudaStatus(); +} + +template +__global__ void NLLLossGradKernel(const int n, const int c, const ReductionMode reduction, const T *input, + const int32_t *target, const S *weight, const S *total_weight, int32_t ignore_index, + const T *dloss, T *dinput) { + int input_idx; + int target_class; + S tmp_quot; + if (reduction == ReductionMode::kNone) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) { + target_class = static_cast(target[i]); + if (target_class == ignore_index) { + continue; + } + + input_idx = (i * c) + target_class; + + MultiplyDevice(-weight[target_class], dloss[i], dinput + input_idx); + } + } else if (reduction == ReductionMode::kMean) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) { + target_class = static_cast(target[i]); + if (target_class == ignore_index) { + continue; + } + + input_idx = (i * c) + target_class; + + tmp_quot = (-weight[target_class]) / *total_weight; + MultiplyDevice(tmp_quot, dloss[0], dinput + input_idx); + } + } else if (reduction == ReductionMode::kSum) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) { + target_class = static_cast(target[i]); + if (target_class == ignore_index) { + continue; + } + + input_idx = (i * c) + target_class; + + MultiplyDevice(-weight[target_class], dloss[0], dinput + input_idx); + } + } +} + +template +cudaError_t NLLLossGrad(const int n, const int c, const ReductionMode reduction, const T *input, const int32_t *target, + const S *weight, const S *total_weight, const T *dloss, T *dinput, int32_t ignore_index, + cudaStream_t stream) { + int input_size = n * c; + InitZero<<>>(dinput, input_size); + + NLLLossGradKernel<<>>(n, c, reduction, input, target, weight, total_weight, + ignore_index, dloss, dinput); + return GetCudaStatus(); +} + +template CUDA_LIB_EXPORT cudaError_t NLLLoss(const half *logits, const int32_t *labels, const half *weights, + half *loss, half *total_weight, const unsigned int label_size, + const unsigned int num_classes, const ReductionMode reduction, + int32_t ignore_index, cudaStream_t stream); + +template CUDA_LIB_EXPORT cudaError_t NLLLoss(const half *logits, const int32_t *labels, + const float *weights, half *loss, float *total_weight, + unsigned int label_size, unsigned int num_classes, + const ReductionMode reduction, int32_t ignore_index, + cudaStream_t stream); + +template CUDA_LIB_EXPORT cudaError_t NLLLoss(const float *logits, const int32_t *labels, + const half *weights, float *loss, half *total_weight, + unsigned int label_size, unsigned int num_classes, + const ReductionMode reduction, int32_t ignore_index, + cudaStream_t stream); + +template CUDA_LIB_EXPORT cudaError_t NLLLoss(const float *logits, const int32_t *labels, + const float *weights, float *loss, float *total_weight, + unsigned int label_size, unsigned int num_classes, + const ReductionMode reduction, int32_t ignore_index, + cudaStream_t stream); + +template CUDA_LIB_EXPORT cudaError_t KLDivLoss(const int &input_size, const ReductionMode &reduction, + const float *input_x, const float *input_y, float *loss, + float *tmp_loss, cudaStream_t stream); + +template CUDA_LIB_EXPORT cudaError_t KLDivLossGrad(const int &input_size, const ReductionMode &reduction, + const float *input_x, const float *input_y, + const float *dloss, float *dx, cudaStream_t stream); + +template CUDA_LIB_EXPORT cudaError_t KLDivLoss(const int &input_size, const ReductionMode &reduction, + const double *input_x, const double *input_y, double *loss, + double *tmp_loss, cudaStream_t stream); + +template CUDA_LIB_EXPORT cudaError_t KLDivLossGrad(const int &input_size, const ReductionMode &reduction, + const double *input_x, const double *input_y, + const double *dloss, double *dx, cudaStream_t stream); + +template CUDA_LIB_EXPORT cudaError_t BinaryCrossEntropyLoss(const int &input_size, + const ReductionMode &reduction, const float *input_x, + const float *input_y, const float *weight, + float *loss, float *tmp_loss, cudaStream_t stream); + +template CUDA_LIB_EXPORT cudaError_t BinaryCrossEntropyLossGrad(const int &input_size, + const ReductionMode &reduction, + const float *input_x, const float *input_y, + const float *weight, const float *dloss, + float *dx, cudaStream_t stream); + +template CUDA_LIB_EXPORT cudaError_t NLLLossGrad(const int n, const int c, const ReductionMode reduction, + const float *input, const int32_t *target, + const float *weight, const float *total_weight, + const float *dloss, float *dinput, int32_t ignore_index, + cudaStream_t stream); + +template CUDA_LIB_EXPORT cudaError_t NLLLossGrad(const int n, const int c, const ReductionMode reduction, + const float *input, const int32_t *target, + const half *weight, const half *total_weight, + const float *dloss, float *dinput, int32_t ignore_index, + cudaStream_t stream); + +template CUDA_LIB_EXPORT cudaError_t KLDivLoss(const int &input_size, const ReductionMode &reduction, + const half *input_x, const half *input_y, half *loss, + half *tmp_loss, cudaStream_t stream); + +template CUDA_LIB_EXPORT cudaError_t KLDivLossGrad(const int &input_size, const ReductionMode &reduction, + const half *input_x, const half *input_y, const half *dloss, + half *dx, cudaStream_t stream); + +template CUDA_LIB_EXPORT cudaError_t BinaryCrossEntropyLoss(const int &input_size, const ReductionMode &reduction, + const half *input_x, const half *input_y, + const half *weight, half *loss, half *tmp_loss, + cudaStream_t stream); + +template CUDA_LIB_EXPORT cudaError_t BinaryCrossEntropyLossGrad(const int &input_size, + const ReductionMode &reduction, + const half *input_x, const half *input_y, + const half *weight, const half *dloss, half *dx, + cudaStream_t stream); + +template CUDA_LIB_EXPORT cudaError_t NLLLossGrad(const int n, const int c, const ReductionMode reduction, + const half *input, const int32_t *target, + const half *weight, const half *total_weight, + const half *dloss, half *dinput, int32_t ignore_index, + cudaStream_t stream); + +template CUDA_LIB_EXPORT cudaError_t NLLLossGrad(const int n, const int c, const ReductionMode reduction, + const half *input, const int32_t *target, + const float *weight, const float *total_weight, + const half *dloss, half *dinput, int32_t ignore_index, + cudaStream_t stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/loss_with_reduction_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/loss_with_reduction_impl.cuh index 3a699b7fc09..99cf7da5d02 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/loss_with_reduction_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/loss_with_reduction_impl.cuh @@ -1,57 +1,57 @@ -/** - * Copyright 2020-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LOSS_WITH_REDUCTION_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LOSS_WITH_REDUCTION_IMPL_CUH_ -#include -#include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" -#include "mindapi/base/types.h" - -enum class ReductionMode { kNone, kMean, kSum }; - -static std::map kReductionModeMap{ - {"none", ReductionMode::kNone}, {"mean", ReductionMode::kMean}, {"sum", ReductionMode::kSum}}; - -static std::map kEnumReductionModeMap{ - {static_cast(mindspore::Reduction::NONE), ReductionMode::kNone}, - {static_cast(mindspore::Reduction::MEAN), ReductionMode::kMean}, - {static_cast(mindspore::Reduction::REDUCTION_SUM), ReductionMode::kSum}}; - -template -CUDA_LIB_EXPORT cudaError_t BinaryCrossEntropyLoss(const int &input_size, const ReductionMode &reduction, - const T *input_x, const T *input_y, const T *weight, T *loss, - T *tmp_loss, cudaStream_t stream); -template -CUDA_LIB_EXPORT cudaError_t BinaryCrossEntropyLossGrad(const int &input_size, const ReductionMode &reduction, - const T *input_x, const T *input_y, const T *weight, - const T *dloss, T *dx, cudaStream_t stream); -template -CUDA_LIB_EXPORT cudaError_t KLDivLoss(const int &input_size, const ReductionMode &reduction, const T *input_x, - const T *input_y, T *loss, T *tmp_loss, cudaStream_t stream); -template -CUDA_LIB_EXPORT cudaError_t KLDivLossGrad(const int &input_size, const ReductionMode &reduction, const T *input_x, - const T *input_y, const T *dloss, T *dx, cudaStream_t stream); -template -CUDA_LIB_EXPORT cudaError_t NLLLoss(const T *logits, const int32_t *labels, const S *weights, T *loss, S *total_weight, - unsigned int label_size, unsigned int num_classes, const ReductionMode reduction, - int32_t ignore_index, cudaStream_t stream); -template -CUDA_LIB_EXPORT cudaError_t NLLLossGrad(const int n, const int c, const ReductionMode reduction, const T *input, - const int32_t *target, const S *weight, const S *total_weight, const T *dloss, - T *dinput, int32_t ignore_index, cudaStream_t stream); - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LOSS_WITH_REDUCTION_IMPL_CUH_ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LOSS_WITH_REDUCTION_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LOSS_WITH_REDUCTION_IMPL_CUH_ +#include +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" +#include "mindapi/base/types.h" + +enum class ReductionMode { kNone, kMean, kSum }; + +static std::map kReductionModeMap{ + {"none", ReductionMode::kNone}, {"mean", ReductionMode::kMean}, {"sum", ReductionMode::kSum}}; + +static std::map kEnumReductionModeMap{ + {static_cast(mindspore::Reduction::NONE), ReductionMode::kNone}, + {static_cast(mindspore::Reduction::MEAN), ReductionMode::kMean}, + {static_cast(mindspore::Reduction::REDUCTION_SUM), ReductionMode::kSum}}; + +template +CUDA_LIB_EXPORT cudaError_t BinaryCrossEntropyLoss(const int &input_size, const ReductionMode &reduction, + const T *input_x, const T *input_y, const T *weight, T *loss, + T *tmp_loss, cudaStream_t stream); +template +CUDA_LIB_EXPORT cudaError_t BinaryCrossEntropyLossGrad(const int &input_size, const ReductionMode &reduction, + const T *input_x, const T *input_y, const T *weight, + const T *dloss, T *dx, cudaStream_t stream); +template +CUDA_LIB_EXPORT cudaError_t KLDivLoss(const int &input_size, const ReductionMode &reduction, const T *input_x, + const T *input_y, T *loss, T *tmp_loss, cudaStream_t stream); +template +CUDA_LIB_EXPORT cudaError_t KLDivLossGrad(const int &input_size, const ReductionMode &reduction, const T *input_x, + const T *input_y, const T *dloss, T *dx, cudaStream_t stream); +template +CUDA_LIB_EXPORT cudaError_t NLLLoss(const T *logits, const int32_t *labels, const S *weights, T *loss, S *total_weight, + unsigned int label_size, unsigned int num_classes, const ReductionMode reduction, + int32_t ignore_index, cudaStream_t stream); +template +CUDA_LIB_EXPORT cudaError_t NLLLossGrad(const int n, const int c, const ReductionMode reduction, const T *input, + const int32_t *target, const S *weight, const S *total_weight, const T *dloss, + T *dinput, int32_t ignore_index, cudaStream_t stream); + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LOSS_WITH_REDUCTION_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/multinomial_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/multinomial_impl.cu index 09d4e1e1438..df6b0c35037 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/multinomial_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/multinomial_impl.cu @@ -1,217 +1,217 @@ -/** - * Copyright 2020-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "multinomial_impl.cuh" -#include - -template -inline T Floor(const T &num, const S &unit) { - return static_cast(num / unit); -} - -template -inline T Ceil(const T &num, const S &unit) { - return static_cast((num + unit - 1) / unit); -} - -__global__ void InitRandStateKernel(uint64_t seed, uint64_t seed_offset, int num, curandState *state) { - for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += blockDim.x * gridDim.x) { - curand_init(seed, i, seed_offset, &state[i]); - } -} - -cudaError_t InitRandState(uint64_t seed, uint64_t seed_offset, int num, curandState *state, cudaStream_t stream) { - InitRandStateKernel<<<(num + 127) / 128, 128, 0, stream>>>(seed, seed_offset, num, state); - return GetCudaStatus(); -} - -template -__global__ void CheckZeroKernel(const size_t distributions, const size_t categories, const T *input, T *out) { - out[0] = 0; - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (distributions); pos += blockDim.x * gridDim.x) { - if (input[(1 + pos) * categories - 1] <= 0) { - out[0] = 1; - } - } - return; -} - -template -cudaError_t CheckZero(const size_t distributions, const size_t categories, const T *input, T *output, - cudaStream_t cuda_stream) { - CheckZeroKernel<<>>(distributions, categories, input, output); - return GetCudaStatus(); -} - -template -__global__ void CheckNonNegKernel(const size_t size, const T *input, T *out) { - out[0] = 0; - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { - if (input[pos] < 0) { - out[0] = 1; - } - } - return; -} - -template -cudaError_t CheckNonNeg(const size_t size, const T *input, T *output, cudaStream_t cuda_stream) { - CheckNonNegKernel<<>>(size, input, output); - return GetCudaStatus(); -} - -template -__device__ int BinarySearchForMultinomial(T *start_addr, int size, T rand) { - int start = 0; - int end = size; - while (end - start > 0) { - int mid = start + (end - start) / 2; - T mid_val = start_addr[mid]; - if (mid_val < rand) { - start = mid + 1; - } else { - end = mid; - } - } - if (start == size) { - start = size - 1; - } - return start; -} - -template -__global__ void MultinomialKernel(int row, int col, T *probs, curandState *state, int64_t *num_sample, S *output) { - // Load the probs to shared memory. - extern __shared__ float accum_probs[]; - int gid = blockIdx.x * blockDim.x + threadIdx.x; - int probs_base_index = gid * col; - if (probs_base_index >= row * col) { - return; - } - - int shm_base_index = threadIdx.x * col; - accum_probs[shm_base_index] = probs[probs_base_index]; - for (int i = 1; i < col; i++) { - probs_base_index++; - float prob = static_cast(probs[probs_base_index]); - CUDA_KERNEL_ASSERT(prob >= 0); - CUDA_KERNEL_ASSERT(!isnan(prob)); - CUDA_KERNEL_ASSERT(!isinf(prob)); - accum_probs[shm_base_index + i] = accum_probs[shm_base_index + i - 1] + prob; - } - __syncthreads(); - - // Probs normalization. - float max_probs = accum_probs[shm_base_index + col - 1]; - for (int i = 0; i < col; i++) { - accum_probs[shm_base_index + i] /= max_probs; - } - __syncthreads(); - - // Sample. - int output_base_index = gid * num_sample[0]; - auto local_state = state[gid]; - for (int i = 0; i < num_sample[0]; i++) { - float rand = curand_uniform(&local_state); - output[output_base_index + i] = static_cast(BinarySearchForMultinomial(&accum_probs[shm_base_index], col, rand)); - } - state[gid] = local_state; -} - -template -cudaError_t Multinomial(int row, int col, T *probs, curandState *state, int64_t *num_sample, S *output, - cudaStream_t stream) { - // Every block process several rows. It depends on shared memory usage. - constexpr int max_shm_used_per_block = 256; - int block_dim = std::max(Floor(std::min(row, max_shm_used_per_block), col), 1); - int grid_dim = Ceil(row, block_dim); - int shm_size = block_dim * col * sizeof(float); - - MultinomialKernel<<>>(row, col, probs, state, num_sample, output); - return GetCudaStatus(); -} - -template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, float *probs, curandState *state, - int64_t *num_sample, int64_t *output, - cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, double *probs, curandState *state, - int64_t *num_sample, int64_t *output, - cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, half *probs, curandState *state, - int64_t *num_sample, int64_t *output, - cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, int8_t *probs, curandState *state, - int64_t *num_sample, int64_t *output, - cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, int16_t *probs, curandState *state, - int64_t *num_sample, int64_t *output, - cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, int32_t *probs, curandState *state, - int64_t *num_sample, int64_t *output, - cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, int64_t *probs, curandState *state, - int64_t *num_sample, int64_t *output, - cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, uint8_t *probs, curandState *state, - int64_t *num_sample, int64_t *output, - cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, uint16_t *probs, - curandState *state, int64_t *num_sample, - int64_t *output, cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, uint32_t *probs, - curandState *state, int64_t *num_sample, - int64_t *output, cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, uint64_t *probs, - curandState *state, int64_t *num_sample, - int64_t *output, cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, float *probs, curandState *state, - int64_t *num_sample, int32_t *output, - cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, double *probs, curandState *state, - int64_t *num_sample, int32_t *output, - cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, half *probs, curandState *state, - int64_t *num_sample, int32_t *output, - cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, int8_t *probs, curandState *state, - int64_t *num_sample, int32_t *output, - cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, int16_t *probs, curandState *state, - int64_t *num_sample, int32_t *output, - cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, int32_t *probs, curandState *state, - int64_t *num_sample, int32_t *output, - cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, int64_t *probs, curandState *state, - int64_t *num_sample, int32_t *output, - cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, uint8_t *probs, curandState *state, - int64_t *num_sample, int32_t *output, - cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, uint16_t *probs, - curandState *state, int64_t *num_sample, - int32_t *output, cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, uint32_t *probs, - curandState *state, int64_t *num_sample, - int32_t *output, cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, uint64_t *probs, - curandState *state, int64_t *num_sample, - int32_t *output, cudaStream_t stream); - -template CUDA_LIB_EXPORT cudaError_t CheckNonNeg(const size_t size, const float *input, float *output, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CheckZero(const size_t distributions, const size_t categories, - const float *input, float *output, cudaStream_t cuda_stream); +/** + * Copyright 2020-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "multinomial_impl.cuh" +#include + +template +inline T Floor(const T &num, const S &unit) { + return static_cast(num / unit); +} + +template +inline T Ceil(const T &num, const S &unit) { + return static_cast((num + unit - 1) / unit); +} + +__global__ void InitRandStateKernel(uint64_t seed, uint64_t seed_offset, int num, curandState *state) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += blockDim.x * gridDim.x) { + curand_init(seed, i, seed_offset, &state[i]); + } +} + +cudaError_t InitRandState(uint64_t seed, uint64_t seed_offset, int num, curandState *state, cudaStream_t stream) { + InitRandStateKernel<<<(num + 127) / 128, 128, 0, stream>>>(seed, seed_offset, num, state); + return GetCudaStatus(); +} + +template +__global__ void CheckZeroKernel(const size_t distributions, const size_t categories, const T *input, T *out) { + out[0] = 0; + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (distributions); pos += blockDim.x * gridDim.x) { + if (input[(1 + pos) * categories - 1] <= 0) { + out[0] = 1; + } + } + return; +} + +template +cudaError_t CheckZero(const size_t distributions, const size_t categories, const T *input, T *output, + cudaStream_t cuda_stream) { + CheckZeroKernel<<>>(distributions, categories, input, output); + return GetCudaStatus(); +} + +template +__global__ void CheckNonNegKernel(const size_t size, const T *input, T *out) { + out[0] = 0; + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + if (input[pos] < 0) { + out[0] = 1; + } + } + return; +} + +template +cudaError_t CheckNonNeg(const size_t size, const T *input, T *output, cudaStream_t cuda_stream) { + CheckNonNegKernel<<>>(size, input, output); + return GetCudaStatus(); +} + +template +__device__ int BinarySearchForMultinomial(T *start_addr, int size, T rand) { + int start = 0; + int end = size; + while (end - start > 0) { + int mid = start + (end - start) / 2; + T mid_val = start_addr[mid]; + if (mid_val < rand) { + start = mid + 1; + } else { + end = mid; + } + } + if (start == size) { + start = size - 1; + } + return start; +} + +template +__global__ void MultinomialKernel(int row, int col, T *probs, curandState *state, int64_t *num_sample, S *output) { + // Load the probs to shared memory. + extern __shared__ float accum_probs[]; + int gid = blockIdx.x * blockDim.x + threadIdx.x; + int probs_base_index = gid * col; + if (probs_base_index >= row * col) { + return; + } + + int shm_base_index = threadIdx.x * col; + accum_probs[shm_base_index] = probs[probs_base_index]; + for (int i = 1; i < col; i++) { + probs_base_index++; + float prob = static_cast(probs[probs_base_index]); + CUDA_KERNEL_ASSERT(prob >= 0); + CUDA_KERNEL_ASSERT(!isnan(prob)); + CUDA_KERNEL_ASSERT(!isinf(prob)); + accum_probs[shm_base_index + i] = accum_probs[shm_base_index + i - 1] + prob; + } + __syncthreads(); + + // Probs normalization. + float max_probs = accum_probs[shm_base_index + col - 1]; + for (int i = 0; i < col; i++) { + accum_probs[shm_base_index + i] /= max_probs; + } + __syncthreads(); + + // Sample. + int output_base_index = gid * num_sample[0]; + auto local_state = state[gid]; + for (int i = 0; i < num_sample[0]; i++) { + float rand = curand_uniform(&local_state); + output[output_base_index + i] = static_cast(BinarySearchForMultinomial(&accum_probs[shm_base_index], col, rand)); + } + state[gid] = local_state; +} + +template +cudaError_t Multinomial(int row, int col, T *probs, curandState *state, int64_t *num_sample, S *output, + cudaStream_t stream) { + // Every block process several rows. It depends on shared memory usage. + constexpr int max_shm_used_per_block = 256; + int block_dim = std::max(Floor(std::min(row, max_shm_used_per_block), col), 1); + int grid_dim = Ceil(row, block_dim); + int shm_size = block_dim * col * sizeof(float); + + MultinomialKernel<<>>(row, col, probs, state, num_sample, output); + return GetCudaStatus(); +} + +template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, float *probs, curandState *state, + int64_t *num_sample, int64_t *output, + cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, double *probs, curandState *state, + int64_t *num_sample, int64_t *output, + cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, half *probs, curandState *state, + int64_t *num_sample, int64_t *output, + cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, int8_t *probs, curandState *state, + int64_t *num_sample, int64_t *output, + cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, int16_t *probs, curandState *state, + int64_t *num_sample, int64_t *output, + cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, int32_t *probs, curandState *state, + int64_t *num_sample, int64_t *output, + cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, int64_t *probs, curandState *state, + int64_t *num_sample, int64_t *output, + cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, uint8_t *probs, curandState *state, + int64_t *num_sample, int64_t *output, + cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, uint16_t *probs, + curandState *state, int64_t *num_sample, + int64_t *output, cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, uint32_t *probs, + curandState *state, int64_t *num_sample, + int64_t *output, cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, uint64_t *probs, + curandState *state, int64_t *num_sample, + int64_t *output, cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, float *probs, curandState *state, + int64_t *num_sample, int32_t *output, + cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, double *probs, curandState *state, + int64_t *num_sample, int32_t *output, + cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, half *probs, curandState *state, + int64_t *num_sample, int32_t *output, + cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, int8_t *probs, curandState *state, + int64_t *num_sample, int32_t *output, + cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, int16_t *probs, curandState *state, + int64_t *num_sample, int32_t *output, + cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, int32_t *probs, curandState *state, + int64_t *num_sample, int32_t *output, + cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, int64_t *probs, curandState *state, + int64_t *num_sample, int32_t *output, + cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, uint8_t *probs, curandState *state, + int64_t *num_sample, int32_t *output, + cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, uint16_t *probs, + curandState *state, int64_t *num_sample, + int32_t *output, cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, uint32_t *probs, + curandState *state, int64_t *num_sample, + int32_t *output, cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, uint64_t *probs, + curandState *state, int64_t *num_sample, + int32_t *output, cudaStream_t stream); + +template CUDA_LIB_EXPORT cudaError_t CheckNonNeg(const size_t size, const float *input, float *output, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CheckZero(const size_t distributions, const size_t categories, + const float *input, float *output, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/multinomial_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/multinomial_impl.cuh index 958c1734967..cfcff39006e 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/multinomial_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/multinomial_impl.cuh @@ -1,34 +1,34 @@ -/** - * Copyright 2020-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MULTINOMIAL_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MULTINOMIAL_IMPL_CUH_ -#include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" - -CUDA_LIB_EXPORT cudaError_t InitRandState(uint64_t seed, uint64_t seed_offset, int num, curandState *state, - cudaStream_t stream); -template -CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, T *probs, curandState *rand_state, int64_t *num_sample, - S *output, cudaStream_t stream); -template -CUDA_LIB_EXPORT cudaError_t CheckNonNeg(const size_t size, const T *input, T *output, cudaStream_t stream); -template -CUDA_LIB_EXPORT cudaError_t CheckZero(const size_t distributions, const size_t categories, const T *input, T *output, - cudaStream_t stream); - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MULTINOMIAL_IMPL_CUH_ +/** + * Copyright 2020-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MULTINOMIAL_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MULTINOMIAL_IMPL_CUH_ +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" + +CUDA_LIB_EXPORT cudaError_t InitRandState(uint64_t seed, uint64_t seed_offset, int num, curandState *state, + cudaStream_t stream); +template +CUDA_LIB_EXPORT cudaError_t Multinomial(int row, int col, T *probs, curandState *rand_state, int64_t *num_sample, + S *output, cudaStream_t stream); +template +CUDA_LIB_EXPORT cudaError_t CheckNonNeg(const size_t size, const T *input, T *output, cudaStream_t stream); +template +CUDA_LIB_EXPORT cudaError_t CheckZero(const size_t distributions, const size_t categories, const T *input, T *output, + cudaStream_t stream); + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MULTINOMIAL_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/mvlgamma_grad_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/mvlgamma_grad_impl.cu index a32b21db4f1..a3fef530e2d 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/mvlgamma_grad_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/mvlgamma_grad_impl.cu @@ -1,83 +1,83 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include "mvlgamma_grad_impl.cuh" -#define PI 3.141592653589793 - -__constant__ double kLanczosCoefficientsd[8] = {676.520368121885098567009190444019, -1259.13921672240287047156078755283, - 771.3234287776530788486528258894, -176.61502916214059906584551354, - 12.507343278686904814458936853, -0.13857109526572011689554707, - 9.984369578019570859563e-6, 1.50563273514931155834e-7}; -template -__device__ __forceinline__ T CalNumDivDenom(T x) { - T num = 0; - T denom = 0.99999999999980993227684700473478; - for (int j = 0; j < 8; ++j) { - num -= kLanczosCoefficientsd[j] / ((x + j + 1) * (x + j + 1)); - denom += kLanczosCoefficientsd[j] / (x + j + 1); - } - return num / denom; -} -template -__global__ void MvlgammaGrad(const size_t size, const T *y_grad, const T *x, const int p, T *output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - T kLanczosGamma = 7; - T log_lanczos_gamma_plus_one_half = log(7.5); - T temp = 0; - T cur_input = 0; - T num_div_denom = 0; - for (int i = 0; i < p; i++) { - cur_input = x[pos] - 0.5 * i; - if (cur_input < 0 && cur_input == floor(cur_input)) { - temp += std::numeric_limits::quiet_NaN(); - break; - } - if (cur_input < 0.5) { - num_div_denom = CalNumDivDenom(-cur_input); - temp += (log_lanczos_gamma_plus_one_half + log1pf((-cur_input) / (kLanczosGamma + 0.5))) + num_div_denom - - kLanczosGamma / (kLanczosGamma + 0.5 - cur_input); - temp -= PI / tan(PI * (cur_input + abs(floor(cur_input + 0.5)))); - } else { - num_div_denom = CalNumDivDenom(cur_input - 1); - temp += (log_lanczos_gamma_plus_one_half + log1pf((cur_input - 1) / (kLanczosGamma + 0.5))) + num_div_denom - - kLanczosGamma / (kLanczosGamma + 0.5 + cur_input - 1); - } - } - output[pos] = temp * y_grad[pos]; - } -} - -template -cudaError_t CalMvlgammaGrad(const size_t size, const T *y_grad, const T *x, const int p, T *output, - const uint32_t &device_id, cudaStream_t cuda_stream) { - int thread_num = 256 < size ? 256 : size; - cudaDeviceProp prop; - (void)cudaGetDeviceProperties(&prop, device_id); - int max_blocks = prop.multiProcessorCount; - int block_num = std::min(static_cast(((size - 1) / thread_num) + 1), max_blocks); - MvlgammaGrad<<>>(size, y_grad, x, p, output); - return GetCudaStatus(); -} - -template CUDA_LIB_EXPORT cudaError_t CalMvlgammaGrad(const size_t size, const float *y_grad, const float *x, - const int p, float *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalMvlgammaGrad(const size_t size, const double *y_grad, const double *x, - const int p, double *output, const uint32_t &device_id, - cudaStream_t cuda_stream); +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "mvlgamma_grad_impl.cuh" +#define PI 3.141592653589793 + +__constant__ double kLanczosCoefficientsd[8] = {676.520368121885098567009190444019, -1259.13921672240287047156078755283, + 771.3234287776530788486528258894, -176.61502916214059906584551354, + 12.507343278686904814458936853, -0.13857109526572011689554707, + 9.984369578019570859563e-6, 1.50563273514931155834e-7}; +template +__device__ __forceinline__ T CalNumDivDenom(T x) { + T num = 0; + T denom = 0.99999999999980993227684700473478; + for (int j = 0; j < 8; ++j) { + num -= kLanczosCoefficientsd[j] / ((x + j + 1) * (x + j + 1)); + denom += kLanczosCoefficientsd[j] / (x + j + 1); + } + return num / denom; +} +template +__global__ void MvlgammaGrad(const size_t size, const T *y_grad, const T *x, const int p, T *output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + T kLanczosGamma = 7; + T log_lanczos_gamma_plus_one_half = log(7.5); + T temp = 0; + T cur_input = 0; + T num_div_denom = 0; + for (int i = 0; i < p; i++) { + cur_input = x[pos] - 0.5 * i; + if (cur_input < 0 && cur_input == floor(cur_input)) { + temp += std::numeric_limits::quiet_NaN(); + break; + } + if (cur_input < 0.5) { + num_div_denom = CalNumDivDenom(-cur_input); + temp += (log_lanczos_gamma_plus_one_half + log1pf((-cur_input) / (kLanczosGamma + 0.5))) + num_div_denom - + kLanczosGamma / (kLanczosGamma + 0.5 - cur_input); + temp -= PI / tan(PI * (cur_input + abs(floor(cur_input + 0.5)))); + } else { + num_div_denom = CalNumDivDenom(cur_input - 1); + temp += (log_lanczos_gamma_plus_one_half + log1pf((cur_input - 1) / (kLanczosGamma + 0.5))) + num_div_denom - + kLanczosGamma / (kLanczosGamma + 0.5 + cur_input - 1); + } + } + output[pos] = temp * y_grad[pos]; + } +} + +template +cudaError_t CalMvlgammaGrad(const size_t size, const T *y_grad, const T *x, const int p, T *output, + const uint32_t &device_id, cudaStream_t cuda_stream) { + int thread_num = 256 < size ? 256 : size; + cudaDeviceProp prop; + (void)cudaGetDeviceProperties(&prop, device_id); + int max_blocks = prop.multiProcessorCount; + int block_num = std::min(static_cast(((size - 1) / thread_num) + 1), max_blocks); + MvlgammaGrad<<>>(size, y_grad, x, p, output); + return GetCudaStatus(); +} + +template CUDA_LIB_EXPORT cudaError_t CalMvlgammaGrad(const size_t size, const float *y_grad, const float *x, + const int p, float *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalMvlgammaGrad(const size_t size, const double *y_grad, const double *x, + const int p, double *output, const uint32_t &device_id, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/mvlgamma_grad_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/mvlgamma_grad_impl.cuh index 0f1b35a67c3..a2890325930 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/mvlgamma_grad_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/mvlgamma_grad_impl.cuh @@ -1,25 +1,25 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MVLGAMMA_GRAD_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MVLGAMMA_GRAD_IMPL_CUH_ -#include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" - -template -cudaError_t CalMvlgammaGrad(const size_t size, const T *y_grad, const T *x, const int p, T *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MVLGAMMA_GRAD_IMPL_CUH_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MVLGAMMA_GRAD_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MVLGAMMA_GRAD_IMPL_CUH_ +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" + +template +cudaError_t CalMvlgammaGrad(const size_t size, const T *y_grad, const T *x, const int p, T *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MVLGAMMA_GRAD_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/mvlgamma_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/mvlgamma_impl.cu index 2422c53242a..c5936736f31 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/mvlgamma_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/mvlgamma_impl.cu @@ -1,58 +1,58 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "mvlgamma_impl.cuh" -#ifdef _WIN32 -// for M_PI -#define _USE_MATH_DEFINES -#include -#endif - -template -__global__ void Mvlgamma(const size_t size, const T *input, const int p, T *output, int *valid) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - T input_val = input[pos]; - if (isnan(input_val) || input_val <= (0.5 * (p - 1))) { - *valid = static_cast(pos); - return; - } - T temp = 0; - for (int i = 1; i <= p; i++) { - temp += lgamma(input_val - static_cast((i - 1) * 0.5)); - } - output[pos] = temp + static_cast(p * (p - 1) * 0.25 * log(M_PI)); - } - return; -} - -template -cudaError_t CalMvlgamma(int *valid, const size_t size, const T *input, const int p, T *output, - const uint32_t &device_id, cudaStream_t cuda_stream, int *host_valid) { - *host_valid = -1; - int thread_num = size > 256 ? 256 : size; - cudaMemsetAsync(valid, -1, sizeof(int), cuda_stream); - Mvlgamma<<>>(size, input, p, output, valid); - cudaMemcpyAsync(host_valid, valid, sizeof(int), cudaMemcpyDeviceToHost, cuda_stream); - cudaStreamSynchronize(cuda_stream); - return GetCudaStatus(); -} - -template CUDA_LIB_EXPORT cudaError_t CalMvlgamma(int *valid, const size_t size, const float *input, const int p, - float *output, const uint32_t &device_id, - cudaStream_t cuda_stream, int *host_valid); -template CUDA_LIB_EXPORT cudaError_t CalMvlgamma(int *valid, const size_t size, const double *input, - const int p, double *output, const uint32_t &device_id, - cudaStream_t cuda_streamy, int *host_valid); +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mvlgamma_impl.cuh" +#ifdef _WIN32 +// for M_PI +#define _USE_MATH_DEFINES +#include +#endif + +template +__global__ void Mvlgamma(const size_t size, const T *input, const int p, T *output, int *valid) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + T input_val = input[pos]; + if (isnan(input_val) || input_val <= (0.5 * (p - 1))) { + *valid = static_cast(pos); + return; + } + T temp = 0; + for (int i = 1; i <= p; i++) { + temp += lgamma(input_val - static_cast((i - 1) * 0.5)); + } + output[pos] = temp + static_cast(p * (p - 1) * 0.25 * log(M_PI)); + } + return; +} + +template +cudaError_t CalMvlgamma(int *valid, const size_t size, const T *input, const int p, T *output, + const uint32_t &device_id, cudaStream_t cuda_stream, int *host_valid) { + *host_valid = -1; + int thread_num = size > 256 ? 256 : size; + cudaMemsetAsync(valid, -1, sizeof(int), cuda_stream); + Mvlgamma<<>>(size, input, p, output, valid); + cudaMemcpyAsync(host_valid, valid, sizeof(int), cudaMemcpyDeviceToHost, cuda_stream); + cudaStreamSynchronize(cuda_stream); + return GetCudaStatus(); +} + +template CUDA_LIB_EXPORT cudaError_t CalMvlgamma(int *valid, const size_t size, const float *input, const int p, + float *output, const uint32_t &device_id, + cudaStream_t cuda_stream, int *host_valid); +template CUDA_LIB_EXPORT cudaError_t CalMvlgamma(int *valid, const size_t size, const double *input, + const int p, double *output, const uint32_t &device_id, + cudaStream_t cuda_streamy, int *host_valid); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/mvlgamma_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/mvlgamma_impl.cuh index 31b53b31c3e..85fee5470d1 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/mvlgamma_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/mvlgamma_impl.cuh @@ -1,24 +1,24 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MVLGAMMA_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MVLGAMMA_IMPL_CUH_ -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" - -template -cudaError_t CalMvlgamma(int *valid, const size_t size, const T *input, const int p, T *output, - const uint32_t &device_id, cudaStream_t cuda_stream, int *host_valid); -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MVLGAMMA_IMPL_CUH_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MVLGAMMA_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MVLGAMMA_IMPL_CUH_ +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" + +template +cudaError_t CalMvlgamma(int *valid, const size_t size, const T *input, const int p, T *output, + const uint32_t &device_id, cudaStream_t cuda_stream, int *host_valid); +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MVLGAMMA_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/pdist_grad_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/pdist_grad_impl.cuh index 888221d5313..bc54478db5d 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/pdist_grad_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/pdist_grad_impl.cuh @@ -1,27 +1,27 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_PDIST_GRAD_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_PDIST_GRAD_IMPL_CUH_ -#include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" - -template -CUDA_LIB_EXPORT cudaError_t CalPDistGrad(const size_t x_size, const size_t y_size, const size_t grad_size, - const T *y_grad, const T *x, const T *y, const int64_t n, const int64_t m, - const float p, T *x_grad, T *buffer, const uint32_t &device_id, - cudaStream_t cuda_stream); -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_PDIST_GRAD_IMPL_CUH_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_PDIST_GRAD_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_PDIST_GRAD_IMPL_CUH_ +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" + +template +CUDA_LIB_EXPORT cudaError_t CalPDistGrad(const size_t x_size, const size_t y_size, const size_t grad_size, + const T *y_grad, const T *x, const T *y, const int64_t n, const int64_t m, + const float p, T *x_grad, T *buffer, const uint32_t &device_id, + cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_PDIST_GRAD_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/pdist_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/pdist_impl.cu index 4cea2739bd0..312448d8e3d 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/pdist_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/pdist_impl.cu @@ -1,232 +1,232 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pdist_impl.cuh" -#include - -static const int threads = 256; - -template -__device__ __forceinline__ T WARP_SHFL_DOWN(T value, unsigned int delta, int width = warpSize, - unsigned int mask = 0xffffffff) { -#if !defined(USE_ROCM) - return __shfl_down_sync(mask, value, delta, width); -#else - return __shfl_down(value, delta, width); -#endif -} - -template -__global__ void PDist_Zero(const T *x, T *y, const float p, const int64_t n, const int64_t m, const float n1, - const float n2) { - const int64_t pos = blockIdx.x; - const int s = blockDim.x; - - int64_t i = static_cast((n1 - sqrt(n2 - 2 * pos))); - int64_t j = pos - n * i + i * (i + 1) / 2 + i + 1; - - const T *const begin = x + i * m; - const T *const end = begin + m; - const T *x_i = begin + threadIdx.x; - const T *x_j = x + j * m + threadIdx.x; - T res = 0.0; - for (; x_i < end; x_i += s, x_j += s) { - res += (*x_i == *x_j) ? 0 : 1; - } - - for (int offset = warpSize / 2; offset > 0; offset /= 2) { - res += WARP_SHFL_DOWN(res, offset); - } - - __shared__ T shared[threads]; - int lane = threadIdx.x % warpSize; - int warp_id = threadIdx.x / warpSize; - if (lane == 0) { - shared[warp_id] = res; - } - - __syncthreads(); - res = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0.0; - if (warp_id == 0) { - for (int offset = blockDim.x / warpSize / 2; offset > 0; offset /= 2) { - res += WARP_SHFL_DOWN(res, offset); - } - } - - if (threadIdx.x == 0) { - y[pos] = res; - } -} - -template -__global__ void PDist_One(const T *x, T *y, const float p, const int64_t n, const int64_t m, const float n1, - const float n2) { - const int64_t pos = blockIdx.x; - const int s = blockDim.x; - - int64_t i = static_cast((n1 - sqrt(n2 - 2 * pos))); - int64_t j = pos - n * i + i * (i + 1) / 2 + i + 1; - - const T *const begin = x + i * m; - const T *const end = begin + m; - const T *x_i = begin + threadIdx.x; - const T *x_j = x + j * m + threadIdx.x; - T res = 0.0; - for (; x_i < end; x_i += s, x_j += s) { - res += abs(*x_i - *x_j); - } - - for (int offset = warpSize / 2; offset > 0; offset /= 2) { - res += WARP_SHFL_DOWN(res, offset); - } - - __shared__ T shared[threads]; - int lane = threadIdx.x % warpSize; - int warp_id = threadIdx.x / warpSize; - if (lane == 0) { - shared[warp_id] = res; - } - - __syncthreads(); - res = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0.0; - if (warp_id == 0) { - for (int offset = blockDim.x / warpSize / 2; offset > 0; offset /= 2) { - res += WARP_SHFL_DOWN(res, offset); - } - } - - if (threadIdx.x == 0) { - y[pos] = res; - } -} - -template -__global__ void PDist_Inf(const T *x, T *y, const float p, const int64_t n, const int64_t m, const float n1, - const float n2) { - const int64_t pos = blockIdx.x; - const int s = blockDim.x; - - // The -1 accounts for floating point truncation issues - int64_t i = static_cast((n1 - sqrt(n2 - 2 * pos))); - int64_t j = pos - n * i + i * (i + 1) / 2 + i + 1; - - const T *const begin = x + i * m; - const T *const end = begin + m; - const T *x_i = begin + threadIdx.x; - const T *x_j = x + j * m + threadIdx.x; - T res = 0.0; - for (; x_i < end; x_i += s, x_j += s) { - res = abs(*x_i - *x_j) > res ? abs(*x_i - *x_j) : res; - } - - for (int offset = warpSize / 2; offset > 0; offset /= 2) { - const T other = WARP_SHFL_DOWN(res, offset); - if (other > res) { - res = other; - } - } - - __shared__ T shared[threads]; - int lane = threadIdx.x % warpSize; - int warp_id = threadIdx.x / warpSize; - if (lane == 0) { - shared[warp_id] = res; - } - - __syncthreads(); - res = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0.0; - if (warp_id == 0) { - for (int offset = blockDim.x / warpSize / 2; offset > 0; offset /= 2) { - const T other = WARP_SHFL_DOWN(res, offset); - if (other > res) { - res = other; - } - } - } - - if (threadIdx.x == 0) { - y[pos] = res; - } -} - -template -__global__ void PDist_Other(const T *x, T *y, const float p, const int64_t n, const int64_t m, const float n1, - const float n2) { - const int64_t pos = blockIdx.x; - const int s = blockDim.x; - - // The -1 accounts for floating point truncation issues - int64_t i = static_cast((n1 - sqrt(n2 - 2 * pos))); - int64_t j = pos - n * i + i * (i + 1) / 2 + i + 1; - - const T *const begin = x + i * m; - const T *const end = begin + m; - const T *x_i = begin + threadIdx.x; - const T *x_j = x + j * m + threadIdx.x; - T res = 0.0; - for (; x_i < end; x_i += s, x_j += s) { - res += pow(abs(*x_i - *x_j), static_cast(p)); - } - - for (int offset = warpSize / 2; offset > 0; offset /= 2) { - res += WARP_SHFL_DOWN(res, offset); - } - - __shared__ T shared[threads]; - int lane = threadIdx.x % warpSize; - int warp_id = threadIdx.x / warpSize; - if (lane == 0) { - shared[warp_id] = res; - } - - __syncthreads(); - res = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0.0; - if (warp_id == 0) { - for (int offset = blockDim.x / warpSize / 2; offset > 0; offset /= 2) { - res += WARP_SHFL_DOWN(res, offset); - } - } - - if (threadIdx.x == 0) { - y[pos] = pow(res, static_cast(1.0 / p)); - } -} - -template -cudaError_t CalPDist(const size_t x_size, const size_t y_size, const T *x, T *y, const float p, const int64_t n, - const int64_t m, const uint32_t &device_id, cudaStream_t cuda_stream) { - const dim3 grid(y_size); - const dim3 block(threads); - const float n1 = n - .5; - const float n2 = n1 * n1 - 1; - if (p == 0.0) { - PDist_Zero<<>>(x, y, p, n, m, n1, n2); - } else if (p == 1.0) { - PDist_One<<>>(x, y, p, n, m, n1, n2); - } else if (std::isinf(p)) { - PDist_Inf<<>>(x, y, p, n, m, n1, n2); - } else { - PDist_Other<<>>(x, y, p, n, m, n1, n2); - } - return GetCudaStatus(); -} - -template CUDA_LIB_EXPORT cudaError_t CalPDist(const size_t x_size, const size_t y_size, const float *x, float *y, - const float p, const int64_t n, const int64_t m, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalPDist(const size_t x_size, const size_t y_size, const double *x, - double *y, const float p, const int64_t n, const int64_t m, - const uint32_t &device_id, cudaStream_t cuda_stream); +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pdist_impl.cuh" +#include + +static const int threads = 256; + +template +__device__ __forceinline__ T WARP_SHFL_DOWN(T value, unsigned int delta, int width = warpSize, + unsigned int mask = 0xffffffff) { +#if !defined(USE_ROCM) + return __shfl_down_sync(mask, value, delta, width); +#else + return __shfl_down(value, delta, width); +#endif +} + +template +__global__ void PDist_Zero(const T *x, T *y, const float p, const int64_t n, const int64_t m, const float n1, + const float n2) { + const int64_t pos = blockIdx.x; + const int s = blockDim.x; + + int64_t i = static_cast((n1 - sqrt(n2 - 2 * pos))); + int64_t j = pos - n * i + i * (i + 1) / 2 + i + 1; + + const T *const begin = x + i * m; + const T *const end = begin + m; + const T *x_i = begin + threadIdx.x; + const T *x_j = x + j * m + threadIdx.x; + T res = 0.0; + for (; x_i < end; x_i += s, x_j += s) { + res += (*x_i == *x_j) ? 0 : 1; + } + + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + res += WARP_SHFL_DOWN(res, offset); + } + + __shared__ T shared[threads]; + int lane = threadIdx.x % warpSize; + int warp_id = threadIdx.x / warpSize; + if (lane == 0) { + shared[warp_id] = res; + } + + __syncthreads(); + res = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0.0; + if (warp_id == 0) { + for (int offset = blockDim.x / warpSize / 2; offset > 0; offset /= 2) { + res += WARP_SHFL_DOWN(res, offset); + } + } + + if (threadIdx.x == 0) { + y[pos] = res; + } +} + +template +__global__ void PDist_One(const T *x, T *y, const float p, const int64_t n, const int64_t m, const float n1, + const float n2) { + const int64_t pos = blockIdx.x; + const int s = blockDim.x; + + int64_t i = static_cast((n1 - sqrt(n2 - 2 * pos))); + int64_t j = pos - n * i + i * (i + 1) / 2 + i + 1; + + const T *const begin = x + i * m; + const T *const end = begin + m; + const T *x_i = begin + threadIdx.x; + const T *x_j = x + j * m + threadIdx.x; + T res = 0.0; + for (; x_i < end; x_i += s, x_j += s) { + res += abs(*x_i - *x_j); + } + + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + res += WARP_SHFL_DOWN(res, offset); + } + + __shared__ T shared[threads]; + int lane = threadIdx.x % warpSize; + int warp_id = threadIdx.x / warpSize; + if (lane == 0) { + shared[warp_id] = res; + } + + __syncthreads(); + res = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0.0; + if (warp_id == 0) { + for (int offset = blockDim.x / warpSize / 2; offset > 0; offset /= 2) { + res += WARP_SHFL_DOWN(res, offset); + } + } + + if (threadIdx.x == 0) { + y[pos] = res; + } +} + +template +__global__ void PDist_Inf(const T *x, T *y, const float p, const int64_t n, const int64_t m, const float n1, + const float n2) { + const int64_t pos = blockIdx.x; + const int s = blockDim.x; + + // The -1 accounts for floating point truncation issues + int64_t i = static_cast((n1 - sqrt(n2 - 2 * pos))); + int64_t j = pos - n * i + i * (i + 1) / 2 + i + 1; + + const T *const begin = x + i * m; + const T *const end = begin + m; + const T *x_i = begin + threadIdx.x; + const T *x_j = x + j * m + threadIdx.x; + T res = 0.0; + for (; x_i < end; x_i += s, x_j += s) { + res = abs(*x_i - *x_j) > res ? abs(*x_i - *x_j) : res; + } + + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + const T other = WARP_SHFL_DOWN(res, offset); + if (other > res) { + res = other; + } + } + + __shared__ T shared[threads]; + int lane = threadIdx.x % warpSize; + int warp_id = threadIdx.x / warpSize; + if (lane == 0) { + shared[warp_id] = res; + } + + __syncthreads(); + res = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0.0; + if (warp_id == 0) { + for (int offset = blockDim.x / warpSize / 2; offset > 0; offset /= 2) { + const T other = WARP_SHFL_DOWN(res, offset); + if (other > res) { + res = other; + } + } + } + + if (threadIdx.x == 0) { + y[pos] = res; + } +} + +template +__global__ void PDist_Other(const T *x, T *y, const float p, const int64_t n, const int64_t m, const float n1, + const float n2) { + const int64_t pos = blockIdx.x; + const int s = blockDim.x; + + // The -1 accounts for floating point truncation issues + int64_t i = static_cast((n1 - sqrt(n2 - 2 * pos))); + int64_t j = pos - n * i + i * (i + 1) / 2 + i + 1; + + const T *const begin = x + i * m; + const T *const end = begin + m; + const T *x_i = begin + threadIdx.x; + const T *x_j = x + j * m + threadIdx.x; + T res = 0.0; + for (; x_i < end; x_i += s, x_j += s) { + res += pow(abs(*x_i - *x_j), static_cast(p)); + } + + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + res += WARP_SHFL_DOWN(res, offset); + } + + __shared__ T shared[threads]; + int lane = threadIdx.x % warpSize; + int warp_id = threadIdx.x / warpSize; + if (lane == 0) { + shared[warp_id] = res; + } + + __syncthreads(); + res = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0.0; + if (warp_id == 0) { + for (int offset = blockDim.x / warpSize / 2; offset > 0; offset /= 2) { + res += WARP_SHFL_DOWN(res, offset); + } + } + + if (threadIdx.x == 0) { + y[pos] = pow(res, static_cast(1.0 / p)); + } +} + +template +cudaError_t CalPDist(const size_t x_size, const size_t y_size, const T *x, T *y, const float p, const int64_t n, + const int64_t m, const uint32_t &device_id, cudaStream_t cuda_stream) { + const dim3 grid(y_size); + const dim3 block(threads); + const float n1 = n - .5; + const float n2 = n1 * n1 - 1; + if (p == 0.0) { + PDist_Zero<<>>(x, y, p, n, m, n1, n2); + } else if (p == 1.0) { + PDist_One<<>>(x, y, p, n, m, n1, n2); + } else if (std::isinf(p)) { + PDist_Inf<<>>(x, y, p, n, m, n1, n2); + } else { + PDist_Other<<>>(x, y, p, n, m, n1, n2); + } + return GetCudaStatus(); +} + +template CUDA_LIB_EXPORT cudaError_t CalPDist(const size_t x_size, const size_t y_size, const float *x, float *y, + const float p, const int64_t n, const int64_t m, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalPDist(const size_t x_size, const size_t y_size, const double *x, + double *y, const float p, const int64_t n, const int64_t m, + const uint32_t &device_id, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/pdist_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/pdist_impl.cuh index 7ee4f198862..d82392b5dac 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/pdist_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/pdist_impl.cuh @@ -1,26 +1,26 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_PDIST_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_PDIST_IMPL_CUH_ -#include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" - -template -CUDA_LIB_EXPORT cudaError_t CalPDist(const size_t x_size, const size_t y_size, const T *x, T *y, const float p, - const int64_t n, const int64_t m, const uint32_t &device_id, - cudaStream_t cuda_stream); -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_PDIST_IMPL_CUH_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_PDIST_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_PDIST_IMPL_CUH_ +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" + +template +CUDA_LIB_EXPORT cudaError_t CalPDist(const size_t x_size, const size_t y_size, const T *x, T *y, const float p, + const int64_t n, const int64_t m, const uint32_t &device_id, + cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_PDIST_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/polar_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/polar_impl.cu index 583b5a3ea2c..43420477a14 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/polar_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/polar_impl.cu @@ -1,114 +1,114 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/polar_impl.cuh" -#include -#include -#include -#include "plugin/device/cpu/kernel/nnacl/op_base.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/elementwise/elementswise_pub_impl.cuh" -constexpr uint kThreadsPerBlock = cuda::elementwise::kThreadsPerBlock; - -template -using Complex = mindspore::utils::Complex; - -template -struct PolarFunctor { - __device__ __forceinline__ S operator()(const T abs, const T angle) const { - S output = 0; - output.real(abs * std::cos(angle)); - output.imag(abs * std::sin(angle)); - return output; - } -}; - -template -__device__ __forceinline__ void NormalCall(Func func, const T *abs_addr, const T *angle_addr, S *output, uint offset, - uint remaining) { - uint loop = UP_DIV(remaining, vec_size); - for (uint i = threadIdx.x; i < loop; i += blockDim.x) { -#pragma unroll - for (uint j = 0; j < vec_size; j++) { - uint index = i * vec_size + j; - if (index >= remaining) { - return; - } - index += offset; - output[index] = func(abs_addr[index], angle_addr[index]); - } - } -} - -template -__device__ __forceinline__ void VectorizedCall(Func func, const T *abs_addr, const T *angle_addr, S *output, - uint offset) { - uint tid = threadIdx.x; - - using VecT = cuda::elementwise::AlignVec; - using VecS = cuda::elementwise::AlignVec; - - auto vec_abs = reinterpret_cast(abs_addr + offset); - auto vec_angle = reinterpret_cast(angle_addr + offset); - auto vec_output = reinterpret_cast(output + offset); - VecT abs = vec_abs[tid]; - VecT angle = vec_angle[tid]; - VecS out{0}; - -#pragma unroll - for (uint j = 0; j < vec_size; j++) { - out.elements_[j] = func(abs.elements_[j], angle.elements_[j]); - } - vec_output[tid] = out; -} - -template -__global__ void PolarVectorized(Func func, const T *abs_addr, const T *angle_addr, S *output, uint num_of_elements) { - uint elements_per_block = kThreadsPerBlock * vec_size; - for (uint offset = elements_per_block * blockIdx.x; offset < num_of_elements; - offset += elements_per_block * gridDim.x) { - uint remaining = num_of_elements - offset; - if (remaining < elements_per_block) { - NormalCall(func, abs_addr, angle_addr, output, offset, remaining); - } else { - VectorizedCall(func, abs_addr, angle_addr, output, offset); - } - } -} - -template -cudaError_t CalPolar(const size_t size, const T *abs, const T *angle, S *output, const uint32_t &device_id, - cudaStream_t cuda_stream) { - constexpr uint vec_size = cuda::elementwise::VecSize(); - const auto block_x = uint(kThreadsPerBlock); - const uint elements_per_block = kThreadsPerBlock * vec_size; - const auto grid_x = uint(UP_DIV(size, elements_per_block)); - dim3 block{block_x}; - dim3 grid{grid_x}; - PolarFunctor functor{}; - PolarVectorized, vec_size, T, S> - <<>>(functor, abs, angle, output, size); - return GetCudaStatus(); -} - -template CUDA_LIB_EXPORT cudaError_t CalPolar>(const size_t size, const float *abs, - const float *angle, Complex *output, - const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalPolar>(const size_t size, const double *abs, - const double *angle, Complex *output, - const uint32_t &device_id, - cudaStream_t cuda_stream); +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/polar_impl.cuh" +#include +#include +#include +#include "plugin/device/cpu/kernel/nnacl/op_base.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/elementwise/elementswise_pub_impl.cuh" +constexpr uint kThreadsPerBlock = cuda::elementwise::kThreadsPerBlock; + +template +using Complex = mindspore::utils::Complex; + +template +struct PolarFunctor { + __device__ __forceinline__ S operator()(const T abs, const T angle) const { + S output = 0; + output.real(abs * std::cos(angle)); + output.imag(abs * std::sin(angle)); + return output; + } +}; + +template +__device__ __forceinline__ void NormalCall(Func func, const T *abs_addr, const T *angle_addr, S *output, uint offset, + uint remaining) { + uint loop = UP_DIV(remaining, vec_size); + for (uint i = threadIdx.x; i < loop; i += blockDim.x) { +#pragma unroll + for (uint j = 0; j < vec_size; j++) { + uint index = i * vec_size + j; + if (index >= remaining) { + return; + } + index += offset; + output[index] = func(abs_addr[index], angle_addr[index]); + } + } +} + +template +__device__ __forceinline__ void VectorizedCall(Func func, const T *abs_addr, const T *angle_addr, S *output, + uint offset) { + uint tid = threadIdx.x; + + using VecT = cuda::elementwise::AlignVec; + using VecS = cuda::elementwise::AlignVec; + + auto vec_abs = reinterpret_cast(abs_addr + offset); + auto vec_angle = reinterpret_cast(angle_addr + offset); + auto vec_output = reinterpret_cast(output + offset); + VecT abs = vec_abs[tid]; + VecT angle = vec_angle[tid]; + VecS out{0}; + +#pragma unroll + for (uint j = 0; j < vec_size; j++) { + out.elements_[j] = func(abs.elements_[j], angle.elements_[j]); + } + vec_output[tid] = out; +} + +template +__global__ void PolarVectorized(Func func, const T *abs_addr, const T *angle_addr, S *output, uint num_of_elements) { + uint elements_per_block = kThreadsPerBlock * vec_size; + for (uint offset = elements_per_block * blockIdx.x; offset < num_of_elements; + offset += elements_per_block * gridDim.x) { + uint remaining = num_of_elements - offset; + if (remaining < elements_per_block) { + NormalCall(func, abs_addr, angle_addr, output, offset, remaining); + } else { + VectorizedCall(func, abs_addr, angle_addr, output, offset); + } + } +} + +template +cudaError_t CalPolar(const size_t size, const T *abs, const T *angle, S *output, const uint32_t &device_id, + cudaStream_t cuda_stream) { + constexpr uint vec_size = cuda::elementwise::VecSize(); + const auto block_x = uint(kThreadsPerBlock); + const uint elements_per_block = kThreadsPerBlock * vec_size; + const auto grid_x = uint(UP_DIV(size, elements_per_block)); + dim3 block{block_x}; + dim3 grid{grid_x}; + PolarFunctor functor{}; + PolarVectorized, vec_size, T, S> + <<>>(functor, abs, angle, output, size); + return GetCudaStatus(); +} + +template CUDA_LIB_EXPORT cudaError_t CalPolar>(const size_t size, const float *abs, + const float *angle, Complex *output, + const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalPolar>(const size_t size, const double *abs, + const double *angle, Complex *output, + const uint32_t &device_id, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/polar_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/polar_impl.cuh index 6e3160ab3d5..80c79d0e8e6 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/polar_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/polar_impl.cuh @@ -1,25 +1,25 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_POLAR_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_POLAR_IMPL_CUH_ -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" - -template -CUDA_LIB_EXPORT cudaError_t CalPolar(const size_t size, const T *abs, const T *angle, S *output, - const uint32_t &device_id, cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_POLAR_IMPL_CUH_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_POLAR_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_POLAR_IMPL_CUH_ +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" + +template +CUDA_LIB_EXPORT cudaError_t CalPolar(const size_t size, const T *abs, const T *angle, S *output, + const uint32_t &device_id, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_POLAR_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/real_to_complex_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/real_to_complex_impl.cu index 760224bd439..6128a62d5d6 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/real_to_complex_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/real_to_complex_impl.cu @@ -1,42 +1,42 @@ -/** - * Copyright 2020-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "real_to_complex_impl.cuh" - -template -__global__ void ToComplex(const size_t size, const T *input, T *output, cudaStream_t cuda_stream) { - // set the complex real to original real, imag to 0j - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - output[2 * pos] = input[pos]; - } -} - -template -cudaError_t RealToComplex(const size_t size, const T *input, T *output, cudaStream_t cuda_stream) { - cudaMemsetAsync(output, 0, 2 * size * sizeof(T), cuda_stream); - ToComplex<<>>(size, input, output, cuda_stream); - return GetCudaStatus(); -} - -template CUDA_LIB_EXPORT cudaError_t RealToComplex(const size_t size, const double *input, double *output, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t RealToComplex(const size_t size, const float *input, float *output, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t RealToComplex(const size_t size, const int *input, int *output, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t RealToComplex(const size_t size, const int64_t *input, int64_t *output, - cudaStream_t cuda_stream); +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "real_to_complex_impl.cuh" + +template +__global__ void ToComplex(const size_t size, const T *input, T *output, cudaStream_t cuda_stream) { + // set the complex real to original real, imag to 0j + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + output[2 * pos] = input[pos]; + } +} + +template +cudaError_t RealToComplex(const size_t size, const T *input, T *output, cudaStream_t cuda_stream) { + cudaMemsetAsync(output, 0, 2 * size * sizeof(T), cuda_stream); + ToComplex<<>>(size, input, output, cuda_stream); + return GetCudaStatus(); +} + +template CUDA_LIB_EXPORT cudaError_t RealToComplex(const size_t size, const double *input, double *output, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t RealToComplex(const size_t size, const float *input, float *output, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t RealToComplex(const size_t size, const int *input, int *output, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t RealToComplex(const size_t size, const int64_t *input, int64_t *output, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/real_to_complex_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/real_to_complex_impl.cuh old mode 100755 new mode 100644 index ea0b6147b60..535f8be66a1 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/real_to_complex_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/real_to_complex_impl.cuh @@ -1,23 +1,23 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_REAL_TO_COMPLEX_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_REAL_TO_COMPLEX_IMPL_CUH_ -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" -template -CUDA_LIB_EXPORT cudaError_t RealToComplex(const size_t size, const T *input, T *output, cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_REAL_TO_COMPLEX_IMPL_CUH_ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_REAL_TO_COMPLEX_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_REAL_TO_COMPLEX_IMPL_CUH_ +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" +template +CUDA_LIB_EXPORT cudaError_t RealToComplex(const size_t size, const T *input, T *output, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_REAL_TO_COMPLEX_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/scale_grad_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/scale_grad_impl.cu index 311b3607eb0..e39a359a094 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/scale_grad_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/scale_grad_impl.cu @@ -1,43 +1,43 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/scale_grad_impl.cuh" - -template -__global__ void ScaleGrad(const int nums, const T *x0, const S &x1, T *y) { - T x1_t = static_cast(x1); - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nums; pos += blockDim.x * gridDim.x) { - y[pos] = x0[pos] * x1_t; - } -} - -template -cudaError_t ScaleGradKernel(const int &nums, const T *x0, const S &x1, T *y, cudaStream_t stream) { - ScaleGrad<<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y); - return GetCudaStatus(); -} - -template CUDA_LIB_EXPORT cudaError_t ScaleGradKernel(const int &nums, const float *x0, const float &x1, - float *y, cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t ScaleGradKernel(const int &nums, const float *x0, const half &x1, - float *y, cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t ScaleGradKernel(const int &nums, const half *x0, const float &x1, - half *y, cudaStream_t stream); -template CUDA_LIB_EXPORT cudaError_t ScaleGradKernel(const int &nums, const half *x0, const half &x1, - half *y, cudaStream_t stream); +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/scale_grad_impl.cuh" + +template +__global__ void ScaleGrad(const int nums, const T *x0, const S &x1, T *y) { + T x1_t = static_cast(x1); + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nums; pos += blockDim.x * gridDim.x) { + y[pos] = x0[pos] * x1_t; + } +} + +template +cudaError_t ScaleGradKernel(const int &nums, const T *x0, const S &x1, T *y, cudaStream_t stream) { + ScaleGrad<<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y); + return GetCudaStatus(); +} + +template CUDA_LIB_EXPORT cudaError_t ScaleGradKernel(const int &nums, const float *x0, const float &x1, + float *y, cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t ScaleGradKernel(const int &nums, const float *x0, const half &x1, + float *y, cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t ScaleGradKernel(const int &nums, const half *x0, const float &x1, + half *y, cudaStream_t stream); +template CUDA_LIB_EXPORT cudaError_t ScaleGradKernel(const int &nums, const half *x0, const half &x1, + half *y, cudaStream_t stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/scale_grad_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/scale_grad_impl.cuh index f41b0bcd577..a6bb7965f68 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/scale_grad_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/scale_grad_impl.cuh @@ -1,24 +1,24 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_SCALE_GRAD_IMPL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_SCALE_GRAD_IMPL_H_ - -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" -template -cudaError_t ScaleGradKernel(const int &nums, const T *x0, const S &x1, T *y, cudaStream_t stream); -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPLIT_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_SCALE_GRAD_IMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_SCALE_GRAD_IMPL_H_ + +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" +template +cudaError_t ScaleGradKernel(const int &nums, const T *x0, const S &x1, T *y, cudaStream_t stream); +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPLIT_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/scatter_nd.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/scatter_nd.cuh index ed82ad9568f..48c9d53396c 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/scatter_nd.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/scatter_nd.cuh @@ -1,30 +1,30 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SCATTER_ND_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SCATTER_ND_CUH_ -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" - -template -struct ScatterNdInfo { - S indices_stride[8] = {0}; - S shape[8] = {0}; -}; -template -CUDA_LIB_EXPORT cudaError_t ScatterNd(S *indices, T *update, T *output, const size_t &block_size, - const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0, - const size_t &indices_dim_1, const ScatterNdInfo &info, cudaStream_t stream); -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SCATTER_ND_CUH_ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SCATTER_ND_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SCATTER_ND_CUH_ +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" + +template +struct ScatterNdInfo { + S indices_stride[8] = {0}; + S shape[8] = {0}; +}; +template +CUDA_LIB_EXPORT cudaError_t ScatterNd(S *indices, T *update, T *output, const size_t &block_size, + const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0, + const size_t &indices_dim_1, const ScatterNdInfo &info, cudaStream_t stream); +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SCATTER_ND_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sgd_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sgd_impl.cu index 85403585a4c..0bfe7970fb0 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sgd_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sgd_impl.cu @@ -1,58 +1,58 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sgd_impl.cuh" - -template -__global__ void SGDKernel(const int size, const T dampening, const T weight_decay, const bool nesterov, const T *grad, - const T *momentum, const T *lr, T *param, T *accum, T *stat) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) { - T grad_new = grad[i]; - if (weight_decay > static_cast(0)) { - grad_new += param[i] * weight_decay; - } - - if (momentum[0] > static_cast(0)) { - if (stat[i] > static_cast(0)) { - accum[i] = grad_new; - stat[i] = 0; - } else { - accum[i] = accum[i] * momentum[0] + (1.0 - dampening) * grad_new; - } - - if (nesterov) { - grad_new += accum[i] * momentum[0]; - } else { - grad_new = accum[i]; - } - } - - param[i] -= lr[0] * grad_new; - } -} - -template -cudaError_t SGD(const int size, const T dampening, const T weight_decay, const bool nesterov, const T *lr, - const T *momentum, const T *grad, T *param, T *accum, T *stat, cudaStream_t cuda_stream) { - SGDKernel<<>>(size, dampening, weight_decay, nesterov, grad, momentum, - lr, param, accum, stat); - return GetCudaStatus(); -} - -template CUDA_LIB_EXPORT cudaError_t SGD(const int size, const float dampening, const float weight_decay, - const bool nesterov, const float *lr, const float *momentum, const float *grad, - float *param, float *accum, float *stat, cudaStream_t cuda_stream); +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sgd_impl.cuh" + +template +__global__ void SGDKernel(const int size, const T dampening, const T weight_decay, const bool nesterov, const T *grad, + const T *momentum, const T *lr, T *param, T *accum, T *stat) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) { + T grad_new = grad[i]; + if (weight_decay > static_cast(0)) { + grad_new += param[i] * weight_decay; + } + + if (momentum[0] > static_cast(0)) { + if (stat[i] > static_cast(0)) { + accum[i] = grad_new; + stat[i] = 0; + } else { + accum[i] = accum[i] * momentum[0] + (1.0 - dampening) * grad_new; + } + + if (nesterov) { + grad_new += accum[i] * momentum[0]; + } else { + grad_new = accum[i]; + } + } + + param[i] -= lr[0] * grad_new; + } +} + +template +cudaError_t SGD(const int size, const T dampening, const T weight_decay, const bool nesterov, const T *lr, + const T *momentum, const T *grad, T *param, T *accum, T *stat, cudaStream_t cuda_stream) { + SGDKernel<<>>(size, dampening, weight_decay, nesterov, grad, momentum, + lr, param, accum, stat); + return GetCudaStatus(); +} + +template CUDA_LIB_EXPORT cudaError_t SGD(const int size, const float dampening, const float weight_decay, + const bool nesterov, const float *lr, const float *momentum, const float *grad, + float *param, float *accum, float *stat, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sgd_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sgd_impl.cuh index 4a04022a12d..bb24b71816d 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sgd_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sgd_impl.cuh @@ -1,25 +1,25 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SGD_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SGD_IMPL_CUH_ -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" - -template -CUDA_LIB_EXPORT cudaError_t SGD(const int size, const T dampening, const T weight_decay, const bool nesterov, - const T *lr, const T *momentum, const T *grad, T *param, T *accum, T *stat, - cudaStream_t cuda_stream); -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SGD_IMPL_CUH_ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SGD_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SGD_IMPL_CUH_ +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" + +template +CUDA_LIB_EXPORT cudaError_t SGD(const int size, const T dampening, const T weight_decay, const bool nesterov, + const T *lr, const T *momentum, const T *grad, T *param, T *accum, T *stat, + cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SGD_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_add_grad_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_add_grad_impl.cu index 19066904643..63fecd06d40 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_add_grad_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_add_grad_impl.cu @@ -1,132 +1,132 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "sparse_add_grad_impl.cuh" -#include -template -__global__ void SparseAddGrad(const S *dout, const T *x1_indices, size_t x1_size, const T *x2_indices, size_t x2_size, - const T *out_indices, size_t out_size, T *temp_save_ptr, S *dx1, S *dx2, size_t dim, - S init_val) { - size_t stride = gridDim.x * blockDim.x; - size_t threadId = blockIdx.x * blockDim.x + threadIdx.x; - size_t x1_idx = threadId; - while (x1_idx < x1_size) { - size_t idx = x1_idx * dim; - for (size_t i = 0; i < dim; i++) { - temp_save_ptr[i] = x1_indices[idx + i]; - } - for (size_t i = 0; i < out_size; i++) { - auto oi = i * dim; - bool same_flag = true; - for (size_t j = 0; j < dim; j++) { - if (temp_save_ptr[j] != out_indices[oi + j]) { - same_flag = false; - break; - } - } - if (same_flag) { - dx1[x1_idx] = dout[i]; - break; - } - } - x1_idx += stride; - } - - size_t x2_idx = threadId; - while (x2_idx < x2_size) { - size_t idx = x2_idx * dim; - for (size_t i = 0; i < dim; i++) { - temp_save_ptr[i] = x2_indices[idx + i]; - } - for (size_t i = 0; i < out_size; i++) { - auto oi = i * dim; - bool same_flag = true; - for (size_t j = 0; j < dim; j++) { - if (temp_save_ptr[j] != out_indices[oi + j]) { - same_flag = false; - break; - } - } - if (same_flag) { - dx2[x2_idx] = dout[i]; - break; - } - } - x2_idx += stride; - } - return; -} - -template -cudaError_t CalSparseAddGrad(const S *dout, const T *x1_indices, size_t x1_size, const T *x2_indices, size_t x2_size, - const T *out_indices, size_t out_size, T *temp_save_ptr, S *dx1, S *dx2, size_t dim, - const uint32_t &device_id, cudaStream_t cuda_stream) { - dim3 blockSize(1); - dim3 gridSize(1); - cudaMemset(dx1, 0, x1_size * sizeof(T)); - cudaMemset(dx2, 0, x2_size * sizeof(T)); - SparseAddGrad<<>>(dout, x1_indices, x1_size, x2_indices, x2_size, out_indices, - out_size, temp_save_ptr, dx1, dx2, dim, S(0)); - return GetCudaStatus(); -} - -template -cudaError_t CalSparseAddGrad(const cuComplex *dout, const T *x1_indices, size_t x1_size, const T *x2_indices, - size_t x2_size, const T *out_indices, size_t out_size, T *temp_save_ptr, cuComplex *dx1, - cuComplex *dx2, size_t dim, const uint32_t &device_id, cudaStream_t cuda_stream) { - dim3 blockSize(1); - dim3 gridSize(1); - cudaMemset(dx1, 0, x1_size * sizeof(T)); - cudaMemset(dx2, 0, x2_size * sizeof(T)); - SparseAddGrad<<>>(dout, x1_indices, x1_size, x2_indices, x2_size, out_indices, - out_size, temp_save_ptr, dx1, dx2, dim, {0, 0}); - return GetCudaStatus(); -} - -template -cudaError_t CalSparseAddGrad(const cuDoubleComplex *dout, const T *x1_indices, size_t x1_size, const T *x2_indices, - size_t x2_size, const T *out_indices, size_t out_size, T *temp_save_ptr, - cuDoubleComplex *dx1, cuDoubleComplex *dx2, size_t dim, const uint32_t &device_id, - cudaStream_t cuda_stream) { - dim3 blockSize(1); - dim3 gridSize(1); - cudaMemset(dx1, 0, x1_size * sizeof(T)); - cudaMemset(dx2, 0, x2_size * sizeof(T)); - SparseAddGrad<<>>(dout, x1_indices, x1_size, x2_indices, x2_size, out_indices, - out_size, temp_save_ptr, dx1, dx2, dim, {0, 0}); - return GetCudaStatus(); -} - -#define GPU_SPARSE_ADD_GRAD_EXPORT_REGISTER(index_type, val_type) \ - template CUDA_LIB_EXPORT cudaError_t CalSparseAddGrad( \ - const val_type *dout, const index_type *x1_indices, size_t x1_size, const index_type *x2_indices, size_t x2_size, \ - const index_type *out_indices, size_t out_size, index_type *temp_save_ptr, val_type *dx1, val_type *dx2, \ - size_t dim, const uint32_t &device_id, cudaStream_t cuda_stream); - -#define GPU_SPARSE_ADD_GRAD_COMPLEX_EXPORT_REGISTER(index_type, val_type) \ - template CUDA_LIB_EXPORT cudaError_t CalSparseAddGrad( \ - const val_type *dout, const index_type *x1_indices, size_t x1_size, const index_type *x2_indices, size_t x2_size, \ - const index_type *out_indices, size_t out_size, index_type *temp_save_ptr, val_type *dx1, val_type *dx2, \ - size_t dim, const uint32_t &device_id, cudaStream_t cuda_stream); - -GPU_SPARSE_ADD_GRAD_EXPORT_REGISTER(int64_t, int8_t) -GPU_SPARSE_ADD_GRAD_EXPORT_REGISTER(int64_t, int16_t) -GPU_SPARSE_ADD_GRAD_EXPORT_REGISTER(int64_t, int32_t) -GPU_SPARSE_ADD_GRAD_EXPORT_REGISTER(int64_t, int64_t) -GPU_SPARSE_ADD_GRAD_EXPORT_REGISTER(int64_t, float) -GPU_SPARSE_ADD_GRAD_EXPORT_REGISTER(int64_t, double) -GPU_SPARSE_ADD_GRAD_COMPLEX_EXPORT_REGISTER(int64_t, cuComplex) -GPU_SPARSE_ADD_GRAD_COMPLEX_EXPORT_REGISTER(int64_t, cuDoubleComplex) +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "sparse_add_grad_impl.cuh" +#include +template +__global__ void SparseAddGrad(const S *dout, const T *x1_indices, size_t x1_size, const T *x2_indices, size_t x2_size, + const T *out_indices, size_t out_size, T *temp_save_ptr, S *dx1, S *dx2, size_t dim, + S init_val) { + size_t stride = gridDim.x * blockDim.x; + size_t threadId = blockIdx.x * blockDim.x + threadIdx.x; + size_t x1_idx = threadId; + while (x1_idx < x1_size) { + size_t idx = x1_idx * dim; + for (size_t i = 0; i < dim; i++) { + temp_save_ptr[i] = x1_indices[idx + i]; + } + for (size_t i = 0; i < out_size; i++) { + auto oi = i * dim; + bool same_flag = true; + for (size_t j = 0; j < dim; j++) { + if (temp_save_ptr[j] != out_indices[oi + j]) { + same_flag = false; + break; + } + } + if (same_flag) { + dx1[x1_idx] = dout[i]; + break; + } + } + x1_idx += stride; + } + + size_t x2_idx = threadId; + while (x2_idx < x2_size) { + size_t idx = x2_idx * dim; + for (size_t i = 0; i < dim; i++) { + temp_save_ptr[i] = x2_indices[idx + i]; + } + for (size_t i = 0; i < out_size; i++) { + auto oi = i * dim; + bool same_flag = true; + for (size_t j = 0; j < dim; j++) { + if (temp_save_ptr[j] != out_indices[oi + j]) { + same_flag = false; + break; + } + } + if (same_flag) { + dx2[x2_idx] = dout[i]; + break; + } + } + x2_idx += stride; + } + return; +} + +template +cudaError_t CalSparseAddGrad(const S *dout, const T *x1_indices, size_t x1_size, const T *x2_indices, size_t x2_size, + const T *out_indices, size_t out_size, T *temp_save_ptr, S *dx1, S *dx2, size_t dim, + const uint32_t &device_id, cudaStream_t cuda_stream) { + dim3 blockSize(1); + dim3 gridSize(1); + cudaMemset(dx1, 0, x1_size * sizeof(T)); + cudaMemset(dx2, 0, x2_size * sizeof(T)); + SparseAddGrad<<>>(dout, x1_indices, x1_size, x2_indices, x2_size, out_indices, + out_size, temp_save_ptr, dx1, dx2, dim, S(0)); + return GetCudaStatus(); +} + +template +cudaError_t CalSparseAddGrad(const cuComplex *dout, const T *x1_indices, size_t x1_size, const T *x2_indices, + size_t x2_size, const T *out_indices, size_t out_size, T *temp_save_ptr, cuComplex *dx1, + cuComplex *dx2, size_t dim, const uint32_t &device_id, cudaStream_t cuda_stream) { + dim3 blockSize(1); + dim3 gridSize(1); + cudaMemset(dx1, 0, x1_size * sizeof(T)); + cudaMemset(dx2, 0, x2_size * sizeof(T)); + SparseAddGrad<<>>(dout, x1_indices, x1_size, x2_indices, x2_size, out_indices, + out_size, temp_save_ptr, dx1, dx2, dim, {0, 0}); + return GetCudaStatus(); +} + +template +cudaError_t CalSparseAddGrad(const cuDoubleComplex *dout, const T *x1_indices, size_t x1_size, const T *x2_indices, + size_t x2_size, const T *out_indices, size_t out_size, T *temp_save_ptr, + cuDoubleComplex *dx1, cuDoubleComplex *dx2, size_t dim, const uint32_t &device_id, + cudaStream_t cuda_stream) { + dim3 blockSize(1); + dim3 gridSize(1); + cudaMemset(dx1, 0, x1_size * sizeof(T)); + cudaMemset(dx2, 0, x2_size * sizeof(T)); + SparseAddGrad<<>>(dout, x1_indices, x1_size, x2_indices, x2_size, out_indices, + out_size, temp_save_ptr, dx1, dx2, dim, {0, 0}); + return GetCudaStatus(); +} + +#define GPU_SPARSE_ADD_GRAD_EXPORT_REGISTER(index_type, val_type) \ + template CUDA_LIB_EXPORT cudaError_t CalSparseAddGrad( \ + const val_type *dout, const index_type *x1_indices, size_t x1_size, const index_type *x2_indices, size_t x2_size, \ + const index_type *out_indices, size_t out_size, index_type *temp_save_ptr, val_type *dx1, val_type *dx2, \ + size_t dim, const uint32_t &device_id, cudaStream_t cuda_stream); + +#define GPU_SPARSE_ADD_GRAD_COMPLEX_EXPORT_REGISTER(index_type, val_type) \ + template CUDA_LIB_EXPORT cudaError_t CalSparseAddGrad( \ + const val_type *dout, const index_type *x1_indices, size_t x1_size, const index_type *x2_indices, size_t x2_size, \ + const index_type *out_indices, size_t out_size, index_type *temp_save_ptr, val_type *dx1, val_type *dx2, \ + size_t dim, const uint32_t &device_id, cudaStream_t cuda_stream); + +GPU_SPARSE_ADD_GRAD_EXPORT_REGISTER(int64_t, int8_t) +GPU_SPARSE_ADD_GRAD_EXPORT_REGISTER(int64_t, int16_t) +GPU_SPARSE_ADD_GRAD_EXPORT_REGISTER(int64_t, int32_t) +GPU_SPARSE_ADD_GRAD_EXPORT_REGISTER(int64_t, int64_t) +GPU_SPARSE_ADD_GRAD_EXPORT_REGISTER(int64_t, float) +GPU_SPARSE_ADD_GRAD_EXPORT_REGISTER(int64_t, double) +GPU_SPARSE_ADD_GRAD_COMPLEX_EXPORT_REGISTER(int64_t, cuComplex) +GPU_SPARSE_ADD_GRAD_COMPLEX_EXPORT_REGISTER(int64_t, cuDoubleComplex) diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_add_grad_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_add_grad_impl.cuh index a0b19762686..f5ae76f488a 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_add_grad_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_add_grad_impl.cuh @@ -1,38 +1,38 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_ADD_GRAD_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_ADD_GRAD_IMPL_CUH_ -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" - -template -CUDA_LIB_EXPORT cudaError_t CalSparseAddGrad(const S *dout, const T *x1_indices, size_t x1_size, const T *x2_indices, - size_t x2_size, const T *out_indices, size_t out_size, T *temp_save_ptr, - S *dx1, S *dx2, size_t dim, const uint32_t &device_id, - cudaStream_t cuda_stream); - -template -CUDA_LIB_EXPORT cudaError_t CalSparseAddGrad(const cuComplex *dout, const T *x1_indices, size_t x1_size, - const T *x2_indices, size_t x2_size, const T *out_indices, size_t out_size, - T *temp_save_ptr, cuComplex *dx1, cuComplex *dx2, size_t dim, - const uint32_t &device_id, cudaStream_t cuda_stream); - -template -CUDA_LIB_EXPORT cudaError_t CalSparseAddGrad(const cuDoubleComplex *dout, const T *x1_indices, size_t x1_size, - const T *x2_indices, size_t x2_size, const T *out_indices, size_t out_size, - T *temp_save_ptr, cuDoubleComplex *dx1, cuDoubleComplex *dx2, size_t dim, - const uint32_t &device_id, cudaStream_t cuda_stream); -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_ADD_GRAD_IMPL_CUH_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_ADD_GRAD_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_ADD_GRAD_IMPL_CUH_ +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" + +template +CUDA_LIB_EXPORT cudaError_t CalSparseAddGrad(const S *dout, const T *x1_indices, size_t x1_size, const T *x2_indices, + size_t x2_size, const T *out_indices, size_t out_size, T *temp_save_ptr, + S *dx1, S *dx2, size_t dim, const uint32_t &device_id, + cudaStream_t cuda_stream); + +template +CUDA_LIB_EXPORT cudaError_t CalSparseAddGrad(const cuComplex *dout, const T *x1_indices, size_t x1_size, + const T *x2_indices, size_t x2_size, const T *out_indices, size_t out_size, + T *temp_save_ptr, cuComplex *dx1, cuComplex *dx2, size_t dim, + const uint32_t &device_id, cudaStream_t cuda_stream); + +template +CUDA_LIB_EXPORT cudaError_t CalSparseAddGrad(const cuDoubleComplex *dout, const T *x1_indices, size_t x1_size, + const T *x2_indices, size_t x2_size, const T *out_indices, size_t out_size, + T *temp_save_ptr, cuDoubleComplex *dx1, cuDoubleComplex *dx2, size_t dim, + const uint32_t &device_id, cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_ADD_GRAD_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_apply_centered_rms_prop_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_apply_centered_rms_prop_impl.cu index 115d2abd296..e4a4f1d8317 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_apply_centered_rms_prop_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_apply_centered_rms_prop_impl.cu @@ -1,208 +1,208 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_apply_centered_rms_prop_impl.cuh" -#include "include/cuda_fp16.h" - -template -__device__ __forceinline__ T RsqrtFunc(T x) { - return __frsqrt_rn(x); -} - -template <> -__device__ __forceinline__ half RsqrtFunc(half x) { - return hrsqrt(x); -} - -template <> -__device__ __forceinline__ double RsqrtFunc(double x) { - return rsqrt(x); -} - -template -__global__ void SparseApplyCenteredRMSPropUpdate(const size_t size, const size_t indices_size, const bool use_locking, - T *learning_rate, T *decay_rate, T *epsilon, T *momentum, - const T *gradient, const S *indices, T *variable, T *mean_grad, - T *mean_square, T *mom, T *variable_out) { - const int64_t inner_size = static_cast(size * sizeof(int64_t) / sizeof(S)); - const T con1 = static_cast(1); - for (int64_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < static_cast(size); - pos += gridDim.x * blockDim.x) { - const int64_t index = pos / inner_size; - const int64_t inner_pos = pos % inner_size; - const int64_t grad_pos = pos; - const int64_t cur_pos = indices[index] * inner_size + inner_pos; - - mean_square[cur_pos] = - (*decay_rate) * mean_square[cur_pos] + (con1 - (*decay_rate)) * gradient[grad_pos] * gradient[grad_pos]; - mean_grad[cur_pos] = mean_grad[cur_pos] * (*decay_rate) + gradient[grad_pos] * (con1 - (*decay_rate)); - const T denom = mean_square[cur_pos] + (*epsilon) - mean_grad[cur_pos] * mean_grad[cur_pos]; - mom[cur_pos] = (*learning_rate) * gradient[grad_pos] * RsqrtFunc(denom) + mom[cur_pos] * (*momentum); - variable_out[cur_pos] = variable[cur_pos] - mom[cur_pos]; - } -} -template -__global__ void SparseApplyCenteredRMSPropUpdate(const size_t size, const size_t indices_size, const bool use_locking, - double *learning_rate, double *decay_rate, double *epsilon, - double *momentum, const double *gradient, const S *indices, - double *variable, double *mean_grad, double *mean_square, double *mom, - double *variable_out) { - const int64_t inner_size = static_cast(size * sizeof(int64_t) / sizeof(S)); - const double con1 = static_cast(1); - for (int64_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < static_cast(size); - pos += gridDim.x * blockDim.x) { - const int64_t index = pos / inner_size; - const int64_t inner_pos = pos % inner_size; - const int64_t grad_pos = pos; - const int64_t cur_pos = indices[index] * inner_size + inner_pos; - - mean_square[cur_pos] = - (*decay_rate) * mean_square[cur_pos] + (con1 - (*decay_rate)) * gradient[grad_pos] * gradient[grad_pos]; - mean_grad[cur_pos] = mean_grad[cur_pos] * (*decay_rate) + gradient[grad_pos] * (con1 - (*decay_rate)); - const double denom = mean_square[cur_pos] + (*epsilon) - mean_grad[cur_pos] * mean_grad[cur_pos]; - mom[cur_pos] = (*learning_rate) * gradient[grad_pos] * RsqrtFunc(denom) + mom[cur_pos] * (*momentum); - variable_out[cur_pos] = variable[cur_pos] - mom[cur_pos]; - } -} -template -__global__ void SparseApplyCenteredRMSPropUpdate(const size_t size, const size_t indices_size, const bool use_locking, - half *learning_rate, half *decay_rate, half *epsilon, half *momentum, - const half *gradient, const S *indices, half *variable, - half *mean_grad, half *mean_square, half *mom, half *variable_out) { - // const int64_t inner_size = static_cast(size / indices_size); - const int64_t inner_size = static_cast(size * sizeof(int64_t) / sizeof(S)); - const float con1 = static_cast(1); - for (int64_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < static_cast(size); - pos += gridDim.x * blockDim.x) { - const int64_t index = pos / inner_size; - const int64_t inner_pos = pos % inner_size; - const int64_t grad_pos = pos; - const int64_t cur_pos = indices[index] * inner_size + inner_pos; - - mean_square[cur_pos] = static_cast(*decay_rate) * static_cast(mean_square[cur_pos]) + - static_cast(con1 - static_cast(*decay_rate)) * - static_cast(gradient[grad_pos]) * static_cast(gradient[grad_pos]); - mean_grad[cur_pos] = static_cast(mean_grad[cur_pos]) * static_cast(*decay_rate) + - static_cast(gradient[grad_pos]) * (con1 - static_cast(*decay_rate)); - const float denom = static_cast(mean_square[cur_pos]) + static_cast(*epsilon) - - static_cast(mean_grad[cur_pos]) * static_cast(mean_grad[cur_pos]); - mom[cur_pos] = static_cast(*learning_rate) * static_cast(gradient[grad_pos]) * - static_cast(RsqrtFunc(denom)) + - static_cast(mom[cur_pos]) * static_cast(*momentum); - variable_out[cur_pos] = - static_cast(static_cast(variable[cur_pos]) - static_cast(mom[cur_pos])); - } -} - -template -cudaError_t CalSparseApplyCenteredRMSProp(const size_t size, const size_t indices_size, const bool use_locking, - T *learning_rate, T *decay_rate, T *epsilon, T *momentum, const T *gradient, - const S *indices, T *variable, T *mean_grad, T *mean_square, T *mom, - T *variable_out, cudaStream_t cuda_stream) { - SparseApplyCenteredRMSPropUpdate<<>>( - size, indices_size, use_locking, learning_rate, decay_rate, epsilon, momentum, gradient, indices, variable, - mean_grad, mean_square, mom, variable_out); - return GetCudaStatus(); -} - -template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( - const size_t size, const size_t indices_size, const bool use_locking, half *learning_rate, half *decay_rate, - half *epsilon, half *momentum, const half *gradient, const int32_t *indices, half *variable, half *mean_grad, - half *mean_square, half *mom, half *variable_out, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( - const size_t size, const size_t indices_size, const bool use_locking, float *learning_rate, float *decay_rate, - float *epsilon, float *momentum, const float *gradient, const int32_t *indices, float *variable, float *mean_grad, - float *mean_square, float *mom, float *variable_out, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( - const size_t size, const size_t indices_size, const bool use_locking, double *learning_rate, double *decay_rate, - double *epsilon, double *momentum, const double *gradient, const int32_t *indices, double *variable, - double *mean_grad, double *mean_square, double *mom, double *variable_out, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( - const size_t size, const size_t indices_size, const bool use_locking, int8_t *learning_rate, int8_t *decay_rate, - int8_t *epsilon, int8_t *momentum, const int8_t *gradient, const int32_t *indices, int8_t *variable, - int8_t *mean_grad, int8_t *mean_square, int8_t *mom, int8_t *variable_out, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( - const size_t size, const size_t indices_size, const bool use_locking, int16_t *learning_rate, int16_t *decay_rate, - int16_t *epsilon, int16_t *momentum, const int16_t *gradient, const int32_t *indices, int16_t *variable, - int16_t *mean_grad, int16_t *mean_square, int16_t *mom, int16_t *variable_out, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( - const size_t size, const size_t indices_size, const bool use_locking, int32_t *learning_rate, int32_t *decay_rate, - int32_t *epsilon, int32_t *momentum, const int32_t *gradient, const int32_t *indices, int32_t *variable, - int32_t *mean_grad, int32_t *mean_square, int32_t *mom, int32_t *variable_out, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( - const size_t size, const size_t indices_size, const bool use_locking, int64_t *learning_rate, int64_t *decay_rate, - int64_t *epsilon, int64_t *momentum, const int64_t *gradient, const int32_t *indices, int64_t *variable, - int64_t *mean_grad, int64_t *mean_square, int64_t *mom, int64_t *variable_out, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( - const size_t size, const size_t indices_size, const bool use_locking, uint8_t *learning_rate, uint8_t *decay_rate, - uint8_t *epsilon, uint8_t *momentum, const uint8_t *gradient, const int32_t *indices, uint8_t *variable, - uint8_t *mean_grad, uint8_t *mean_square, uint8_t *mom, uint8_t *variable_out, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( - const size_t size, const size_t indices_size, const bool use_locking, uint16_t *learning_rate, uint16_t *decay_rate, - uint16_t *epsilon, uint16_t *momentum, const uint16_t *gradient, const int32_t *indices, uint16_t *variable, - uint16_t *mean_grad, uint16_t *mean_square, uint16_t *mom, uint16_t *variable_out, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( - const size_t size, const size_t indices_size, const bool use_locking, uint32_t *learning_rate, uint32_t *decay_rate, - uint32_t *epsilon, uint32_t *momentum, const uint32_t *gradient, const int32_t *indices, uint32_t *variable, - uint32_t *mean_grad, uint32_t *mean_square, uint32_t *mom, uint32_t *variable_out, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( - const size_t size, const size_t indices_size, const bool use_locking, uint64_t *learning_rate, uint64_t *decay_rate, - uint64_t *epsilon, uint64_t *momentum, const uint64_t *gradient, const int32_t *indices, uint64_t *variable, - uint64_t *mean_grad, uint64_t *mean_square, uint64_t *mom, uint64_t *variable_out, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( - const size_t size, const size_t indices_size, const bool use_locking, half *learning_rate, half *decay_rate, - half *epsilon, half *momentum, const half *gradient, const int64_t *indices, half *variable, half *mean_grad, - half *mean_square, half *mom, half *variable_out, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( - const size_t size, const size_t indices_size, const bool use_locking, float *learning_rate, float *decay_rate, - float *epsilon, float *momentum, const float *gradient, const int64_t *indices, float *variable, float *mean_grad, - float *mean_square, float *mom, float *variable_out, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( - const size_t size, const size_t indices_size, const bool use_locking, double *learning_rate, double *decay_rate, - double *epsilon, double *momentum, const double *gradient, const int64_t *indices, double *variable, - double *mean_grad, double *mean_square, double *mom, double *variable_out, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( - const size_t size, const size_t indices_size, const bool use_locking, int8_t *learning_rate, int8_t *decay_rate, - int8_t *epsilon, int8_t *momentum, const int8_t *gradient, const int64_t *indices, int8_t *variable, - int8_t *mean_grad, int8_t *mean_square, int8_t *mom, int8_t *variable_out, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( - const size_t size, const size_t indices_size, const bool use_locking, int16_t *learning_rate, int16_t *decay_rate, - int16_t *epsilon, int16_t *momentum, const int16_t *gradient, const int64_t *indices, int16_t *variable, - int16_t *mean_grad, int16_t *mean_square, int16_t *mom, int16_t *variable_out, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( - const size_t size, const size_t indices_size, const bool use_locking, int32_t *learning_rate, int32_t *decay_rate, - int32_t *epsilon, int32_t *momentum, const int32_t *gradient, const int64_t *indices, int32_t *variable, - int32_t *mean_grad, int32_t *mean_square, int32_t *mom, int32_t *variable_out, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( - const size_t size, const size_t indices_size, const bool use_locking, int64_t *learning_rate, int64_t *decay_rate, - int64_t *epsilon, int64_t *momentum, const int64_t *gradient, const int64_t *indices, int64_t *variable, - int64_t *mean_grad, int64_t *mean_square, int64_t *mom, int64_t *variable_out, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( - const size_t size, const size_t indices_size, const bool use_locking, uint8_t *learning_rate, uint8_t *decay_rate, - uint8_t *epsilon, uint8_t *momentum, const uint8_t *gradient, const int64_t *indices, uint8_t *variable, - uint8_t *mean_grad, uint8_t *mean_square, uint8_t *mom, uint8_t *variable_out, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( - const size_t size, const size_t indices_size, const bool use_locking, uint16_t *learning_rate, uint16_t *decay_rate, - uint16_t *epsilon, uint16_t *momentum, const uint16_t *gradient, const int64_t *indices, uint16_t *variable, - uint16_t *mean_grad, uint16_t *mean_square, uint16_t *mom, uint16_t *variable_out, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( - const size_t size, const size_t indices_size, const bool use_locking, uint32_t *learning_rate, uint32_t *decay_rate, - uint32_t *epsilon, uint32_t *momentum, const uint32_t *gradient, const int64_t *indices, uint32_t *variable, - uint32_t *mean_grad, uint32_t *mean_square, uint32_t *mom, uint32_t *variable_out, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( - const size_t size, const size_t indices_size, const bool use_locking, uint64_t *learning_rate, uint64_t *decay_rate, - uint64_t *epsilon, uint64_t *momentum, const uint64_t *gradient, const int64_t *indices, uint64_t *variable, - uint64_t *mean_grad, uint64_t *mean_square, uint64_t *mom, uint64_t *variable_out, cudaStream_t cuda_stream); +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_apply_centered_rms_prop_impl.cuh" +#include "include/cuda_fp16.h" + +template +__device__ __forceinline__ T RsqrtFunc(T x) { + return __frsqrt_rn(x); +} + +template <> +__device__ __forceinline__ half RsqrtFunc(half x) { + return hrsqrt(x); +} + +template <> +__device__ __forceinline__ double RsqrtFunc(double x) { + return rsqrt(x); +} + +template +__global__ void SparseApplyCenteredRMSPropUpdate(const size_t size, const size_t indices_size, const bool use_locking, + T *learning_rate, T *decay_rate, T *epsilon, T *momentum, + const T *gradient, const S *indices, T *variable, T *mean_grad, + T *mean_square, T *mom, T *variable_out) { + const int64_t inner_size = static_cast(size * sizeof(int64_t) / sizeof(S)); + const T con1 = static_cast(1); + for (int64_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < static_cast(size); + pos += gridDim.x * blockDim.x) { + const int64_t index = pos / inner_size; + const int64_t inner_pos = pos % inner_size; + const int64_t grad_pos = pos; + const int64_t cur_pos = indices[index] * inner_size + inner_pos; + + mean_square[cur_pos] = + (*decay_rate) * mean_square[cur_pos] + (con1 - (*decay_rate)) * gradient[grad_pos] * gradient[grad_pos]; + mean_grad[cur_pos] = mean_grad[cur_pos] * (*decay_rate) + gradient[grad_pos] * (con1 - (*decay_rate)); + const T denom = mean_square[cur_pos] + (*epsilon) - mean_grad[cur_pos] * mean_grad[cur_pos]; + mom[cur_pos] = (*learning_rate) * gradient[grad_pos] * RsqrtFunc(denom) + mom[cur_pos] * (*momentum); + variable_out[cur_pos] = variable[cur_pos] - mom[cur_pos]; + } +} +template +__global__ void SparseApplyCenteredRMSPropUpdate(const size_t size, const size_t indices_size, const bool use_locking, + double *learning_rate, double *decay_rate, double *epsilon, + double *momentum, const double *gradient, const S *indices, + double *variable, double *mean_grad, double *mean_square, double *mom, + double *variable_out) { + const int64_t inner_size = static_cast(size * sizeof(int64_t) / sizeof(S)); + const double con1 = static_cast(1); + for (int64_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < static_cast(size); + pos += gridDim.x * blockDim.x) { + const int64_t index = pos / inner_size; + const int64_t inner_pos = pos % inner_size; + const int64_t grad_pos = pos; + const int64_t cur_pos = indices[index] * inner_size + inner_pos; + + mean_square[cur_pos] = + (*decay_rate) * mean_square[cur_pos] + (con1 - (*decay_rate)) * gradient[grad_pos] * gradient[grad_pos]; + mean_grad[cur_pos] = mean_grad[cur_pos] * (*decay_rate) + gradient[grad_pos] * (con1 - (*decay_rate)); + const double denom = mean_square[cur_pos] + (*epsilon) - mean_grad[cur_pos] * mean_grad[cur_pos]; + mom[cur_pos] = (*learning_rate) * gradient[grad_pos] * RsqrtFunc(denom) + mom[cur_pos] * (*momentum); + variable_out[cur_pos] = variable[cur_pos] - mom[cur_pos]; + } +} +template +__global__ void SparseApplyCenteredRMSPropUpdate(const size_t size, const size_t indices_size, const bool use_locking, + half *learning_rate, half *decay_rate, half *epsilon, half *momentum, + const half *gradient, const S *indices, half *variable, + half *mean_grad, half *mean_square, half *mom, half *variable_out) { + // const int64_t inner_size = static_cast(size / indices_size); + const int64_t inner_size = static_cast(size * sizeof(int64_t) / sizeof(S)); + const float con1 = static_cast(1); + for (int64_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < static_cast(size); + pos += gridDim.x * blockDim.x) { + const int64_t index = pos / inner_size; + const int64_t inner_pos = pos % inner_size; + const int64_t grad_pos = pos; + const int64_t cur_pos = indices[index] * inner_size + inner_pos; + + mean_square[cur_pos] = static_cast(*decay_rate) * static_cast(mean_square[cur_pos]) + + static_cast(con1 - static_cast(*decay_rate)) * + static_cast(gradient[grad_pos]) * static_cast(gradient[grad_pos]); + mean_grad[cur_pos] = static_cast(mean_grad[cur_pos]) * static_cast(*decay_rate) + + static_cast(gradient[grad_pos]) * (con1 - static_cast(*decay_rate)); + const float denom = static_cast(mean_square[cur_pos]) + static_cast(*epsilon) - + static_cast(mean_grad[cur_pos]) * static_cast(mean_grad[cur_pos]); + mom[cur_pos] = static_cast(*learning_rate) * static_cast(gradient[grad_pos]) * + static_cast(RsqrtFunc(denom)) + + static_cast(mom[cur_pos]) * static_cast(*momentum); + variable_out[cur_pos] = + static_cast(static_cast(variable[cur_pos]) - static_cast(mom[cur_pos])); + } +} + +template +cudaError_t CalSparseApplyCenteredRMSProp(const size_t size, const size_t indices_size, const bool use_locking, + T *learning_rate, T *decay_rate, T *epsilon, T *momentum, const T *gradient, + const S *indices, T *variable, T *mean_grad, T *mean_square, T *mom, + T *variable_out, cudaStream_t cuda_stream) { + SparseApplyCenteredRMSPropUpdate<<>>( + size, indices_size, use_locking, learning_rate, decay_rate, epsilon, momentum, gradient, indices, variable, + mean_grad, mean_square, mom, variable_out); + return GetCudaStatus(); +} + +template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( + const size_t size, const size_t indices_size, const bool use_locking, half *learning_rate, half *decay_rate, + half *epsilon, half *momentum, const half *gradient, const int32_t *indices, half *variable, half *mean_grad, + half *mean_square, half *mom, half *variable_out, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( + const size_t size, const size_t indices_size, const bool use_locking, float *learning_rate, float *decay_rate, + float *epsilon, float *momentum, const float *gradient, const int32_t *indices, float *variable, float *mean_grad, + float *mean_square, float *mom, float *variable_out, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( + const size_t size, const size_t indices_size, const bool use_locking, double *learning_rate, double *decay_rate, + double *epsilon, double *momentum, const double *gradient, const int32_t *indices, double *variable, + double *mean_grad, double *mean_square, double *mom, double *variable_out, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( + const size_t size, const size_t indices_size, const bool use_locking, int8_t *learning_rate, int8_t *decay_rate, + int8_t *epsilon, int8_t *momentum, const int8_t *gradient, const int32_t *indices, int8_t *variable, + int8_t *mean_grad, int8_t *mean_square, int8_t *mom, int8_t *variable_out, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( + const size_t size, const size_t indices_size, const bool use_locking, int16_t *learning_rate, int16_t *decay_rate, + int16_t *epsilon, int16_t *momentum, const int16_t *gradient, const int32_t *indices, int16_t *variable, + int16_t *mean_grad, int16_t *mean_square, int16_t *mom, int16_t *variable_out, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( + const size_t size, const size_t indices_size, const bool use_locking, int32_t *learning_rate, int32_t *decay_rate, + int32_t *epsilon, int32_t *momentum, const int32_t *gradient, const int32_t *indices, int32_t *variable, + int32_t *mean_grad, int32_t *mean_square, int32_t *mom, int32_t *variable_out, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( + const size_t size, const size_t indices_size, const bool use_locking, int64_t *learning_rate, int64_t *decay_rate, + int64_t *epsilon, int64_t *momentum, const int64_t *gradient, const int32_t *indices, int64_t *variable, + int64_t *mean_grad, int64_t *mean_square, int64_t *mom, int64_t *variable_out, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( + const size_t size, const size_t indices_size, const bool use_locking, uint8_t *learning_rate, uint8_t *decay_rate, + uint8_t *epsilon, uint8_t *momentum, const uint8_t *gradient, const int32_t *indices, uint8_t *variable, + uint8_t *mean_grad, uint8_t *mean_square, uint8_t *mom, uint8_t *variable_out, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( + const size_t size, const size_t indices_size, const bool use_locking, uint16_t *learning_rate, uint16_t *decay_rate, + uint16_t *epsilon, uint16_t *momentum, const uint16_t *gradient, const int32_t *indices, uint16_t *variable, + uint16_t *mean_grad, uint16_t *mean_square, uint16_t *mom, uint16_t *variable_out, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( + const size_t size, const size_t indices_size, const bool use_locking, uint32_t *learning_rate, uint32_t *decay_rate, + uint32_t *epsilon, uint32_t *momentum, const uint32_t *gradient, const int32_t *indices, uint32_t *variable, + uint32_t *mean_grad, uint32_t *mean_square, uint32_t *mom, uint32_t *variable_out, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( + const size_t size, const size_t indices_size, const bool use_locking, uint64_t *learning_rate, uint64_t *decay_rate, + uint64_t *epsilon, uint64_t *momentum, const uint64_t *gradient, const int32_t *indices, uint64_t *variable, + uint64_t *mean_grad, uint64_t *mean_square, uint64_t *mom, uint64_t *variable_out, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( + const size_t size, const size_t indices_size, const bool use_locking, half *learning_rate, half *decay_rate, + half *epsilon, half *momentum, const half *gradient, const int64_t *indices, half *variable, half *mean_grad, + half *mean_square, half *mom, half *variable_out, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( + const size_t size, const size_t indices_size, const bool use_locking, float *learning_rate, float *decay_rate, + float *epsilon, float *momentum, const float *gradient, const int64_t *indices, float *variable, float *mean_grad, + float *mean_square, float *mom, float *variable_out, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( + const size_t size, const size_t indices_size, const bool use_locking, double *learning_rate, double *decay_rate, + double *epsilon, double *momentum, const double *gradient, const int64_t *indices, double *variable, + double *mean_grad, double *mean_square, double *mom, double *variable_out, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( + const size_t size, const size_t indices_size, const bool use_locking, int8_t *learning_rate, int8_t *decay_rate, + int8_t *epsilon, int8_t *momentum, const int8_t *gradient, const int64_t *indices, int8_t *variable, + int8_t *mean_grad, int8_t *mean_square, int8_t *mom, int8_t *variable_out, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( + const size_t size, const size_t indices_size, const bool use_locking, int16_t *learning_rate, int16_t *decay_rate, + int16_t *epsilon, int16_t *momentum, const int16_t *gradient, const int64_t *indices, int16_t *variable, + int16_t *mean_grad, int16_t *mean_square, int16_t *mom, int16_t *variable_out, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( + const size_t size, const size_t indices_size, const bool use_locking, int32_t *learning_rate, int32_t *decay_rate, + int32_t *epsilon, int32_t *momentum, const int32_t *gradient, const int64_t *indices, int32_t *variable, + int32_t *mean_grad, int32_t *mean_square, int32_t *mom, int32_t *variable_out, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( + const size_t size, const size_t indices_size, const bool use_locking, int64_t *learning_rate, int64_t *decay_rate, + int64_t *epsilon, int64_t *momentum, const int64_t *gradient, const int64_t *indices, int64_t *variable, + int64_t *mean_grad, int64_t *mean_square, int64_t *mom, int64_t *variable_out, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( + const size_t size, const size_t indices_size, const bool use_locking, uint8_t *learning_rate, uint8_t *decay_rate, + uint8_t *epsilon, uint8_t *momentum, const uint8_t *gradient, const int64_t *indices, uint8_t *variable, + uint8_t *mean_grad, uint8_t *mean_square, uint8_t *mom, uint8_t *variable_out, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( + const size_t size, const size_t indices_size, const bool use_locking, uint16_t *learning_rate, uint16_t *decay_rate, + uint16_t *epsilon, uint16_t *momentum, const uint16_t *gradient, const int64_t *indices, uint16_t *variable, + uint16_t *mean_grad, uint16_t *mean_square, uint16_t *mom, uint16_t *variable_out, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( + const size_t size, const size_t indices_size, const bool use_locking, uint32_t *learning_rate, uint32_t *decay_rate, + uint32_t *epsilon, uint32_t *momentum, const uint32_t *gradient, const int64_t *indices, uint32_t *variable, + uint32_t *mean_grad, uint32_t *mean_square, uint32_t *mom, uint32_t *variable_out, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp( + const size_t size, const size_t indices_size, const bool use_locking, uint64_t *learning_rate, uint64_t *decay_rate, + uint64_t *epsilon, uint64_t *momentum, const uint64_t *gradient, const int64_t *indices, uint64_t *variable, + uint64_t *mean_grad, uint64_t *mean_square, uint64_t *mom, uint64_t *variable_out, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_apply_centered_rms_prop_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_apply_centered_rms_prop_impl.cuh index dcf885bb4dc..6bed098d371 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_apply_centered_rms_prop_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_apply_centered_rms_prop_impl.cuh @@ -1,28 +1,28 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_APPLY_CENTERED_RMS_PROP_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_APPLY_CENTERED_RMS_PROP_IMPL_CUH_ - -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" -template -CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp(const size_t size, const size_t indices_size, - const bool use_locking, T *learning_rate, T *decay_rate, - T *epsilon, T *momentum, const T *gradient, const S *indices, - T *variable, T *mean_grad, T *mean_square, T *mom, - T *variable_out, cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_APPLY_CENTERED_RMS_PROP_IMPL_CUH_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_APPLY_CENTERED_RMS_PROP_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_APPLY_CENTERED_RMS_PROP_IMPL_CUH_ + +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" +template +CUDA_LIB_EXPORT cudaError_t CalSparseApplyCenteredRMSProp(const size_t size, const size_t indices_size, + const bool use_locking, T *learning_rate, T *decay_rate, + T *epsilon, T *momentum, const T *gradient, const S *indices, + T *variable, T *mean_grad, T *mean_square, T *mom, + T *variable_out, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_APPLY_CENTERED_RMS_PROP_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_matrix_nnz_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_matrix_nnz_impl.cu index 430ed5fd857..35b55fc0309 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_matrix_nnz_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_matrix_nnz_impl.cu @@ -1,39 +1,39 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_matrix_nnz_impl.cuh" - -template -__global__ void SparseMatrixNNZ(const size_t size, const T *input, int32_t *output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - output[pos] = input[pos + 1] - (pos == 0 ? 0 : input[pos]); - } - return; -} - -template -cudaError_t CalSparseMatrixNNZ(const size_t size, const T *input, int32_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream) { - SparseMatrixNNZ<<>>(size, input, output); - return GetCudaStatus(); -} - -template CUDA_LIB_EXPORT cudaError_t CalSparseMatrixNNZ(const size_t size, const int32_t *input, - int32_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalSparseMatrixNNZ(const size_t size, const int64_t *input, - int32_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_matrix_nnz_impl.cuh" + +template +__global__ void SparseMatrixNNZ(const size_t size, const T *input, int32_t *output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + output[pos] = input[pos + 1] - (pos == 0 ? 0 : input[pos]); + } + return; +} + +template +cudaError_t CalSparseMatrixNNZ(const size_t size, const T *input, int32_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream) { + SparseMatrixNNZ<<>>(size, input, output); + return GetCudaStatus(); +} + +template CUDA_LIB_EXPORT cudaError_t CalSparseMatrixNNZ(const size_t size, const int32_t *input, + int32_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalSparseMatrixNNZ(const size_t size, const int64_t *input, + int32_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_matrix_nnz_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_matrix_nnz_impl.cuh index dbdfcde585b..6f2e4e3c1ce 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_matrix_nnz_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_matrix_nnz_impl.cuh @@ -1,24 +1,24 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_MATRIX_NNZ_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_MATRIX_NNZ_IMPL_CUH_ -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" - -template -cudaError_t CalSparseMatrixNNZ(const size_t size, const T *input, int32_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_MATRIX_NNZ_IMPL_CUH_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_MATRIX_NNZ_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_MATRIX_NNZ_IMPL_CUH_ +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" + +template +cudaError_t CalSparseMatrixNNZ(const size_t size, const T *input, int32_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_MATRIX_NNZ_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_grad_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_grad_impl.cu index f89c73f1809..d51036761ee 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_grad_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_grad_impl.cu @@ -1,217 +1,217 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "plugin/device/cpu/kernel/nnacl/op_base.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_grad_impl.cuh" - -template -__global__ void SparseSegmentPosKernel(const S *indices_ptr, size_t *indices_pos_ptr, size_t idx_seg_size, - size_t outer_size) { - for (size_t id = blockIdx.x * blockDim.x + threadIdx.x; id <= idx_seg_size; id += blockDim.x * gridDim.x) { - const S max_size = static_cast(outer_size); - const S min_size = S(0); - S beg_idx = (id == 0) ? min_size : indices_ptr[id - 1] + 1; - S end_idx = (id >= idx_seg_size) ? max_size : indices_ptr[id]; - beg_idx = max(min_size, min(max_size, beg_idx)); - end_idx = max(min_size, min(max_size, end_idx)); - for (S i = beg_idx; i <= end_idx; i++) { - indices_pos_ptr[i] = id; - } - } -} - -template -__global__ void SparseSegmentSumGradKernel(const R *grad_ptr, const S *indices_ptr, const S *segment_ids_ptr, - const size_t *indices_pos_ptr, size_t outer_size, size_t inner_size, - size_t output_dim0, R *y_ptr) { - size_t num_blocks = (inner_size - 1) / blockDim.x + 1; - for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) { - size_t inner_idx = threadIdx.x + bid * blockDim.x; - bool inner_valid = inner_idx < inner_size; - for (size_t inid = blockIdx.y; inid < outer_size; inid += gridDim.y) { - size_t beg_pos = indices_pos_ptr[inid]; - size_t end_pos = indices_pos_ptr[inid + 1]; - for (size_t pos = beg_pos; pos < end_pos; pos += 1) { - double reduce_result = 0; - S index = inner_valid ? indices_ptr[pos] : outer_size; - if (index >= 0 && index < outer_size) { - reduce_result = static_cast(grad_ptr[index * inner_size + inner_idx]); - } - if (threadIdx.y == 0 && inner_valid) { - R *out_pos = y_ptr + segment_ids_ptr[pos] * inner_size + inner_idx; - MsAtomicAdd(out_pos, static_cast(reduce_result)); - } - } - } - } -} - -template -__global__ void SparseSegmentSumGradKernel(const half *grad_ptr, const S *indices_ptr, const S *segment_ids_ptr, - const size_t *indices_pos_ptr, size_t outer_size, size_t inner_size, - size_t output_dim0, half *y_ptr) { - size_t num_blocks = (inner_size - 1) / blockDim.x + 1; - for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) { - size_t inner_idx = threadIdx.x + bid * blockDim.x; - bool inner_valid = inner_idx < inner_size; - for (size_t inid = blockIdx.y; inid < outer_size; inid += gridDim.y) { - size_t beg_pos = indices_pos_ptr[inid]; - size_t end_pos = indices_pos_ptr[inid + 1]; - for (size_t pos = beg_pos; pos < end_pos; pos += 1) { - double reduce_result = 0; - S index = inner_valid ? indices_ptr[pos] : outer_size; - if (index >= 0 && index < outer_size) { - reduce_result = static_cast(__half2float(grad_ptr[index * inner_size + inner_idx])); - } - if (threadIdx.y == 0 && inner_valid) { - half *out_pos = y_ptr + segment_ids_ptr[pos] * inner_size + inner_idx; - MsAtomicAdd(out_pos, __float2half(static_cast(reduce_result))); - } - } - } - } -} - -template -__global__ void SparseSegmentSqrtNGradKernel(const R *grad_ptr, const S *indices_ptr, const S *segment_ids_ptr, - const size_t *indices_pos_ptr, size_t outer_size, size_t inner_size, - size_t output_dim0, R *y_ptr) { - size_t num_blocks = (inner_size - 1) / blockDim.x + 1; - for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) { - size_t inner_idx = threadIdx.x + bid * blockDim.x; - bool inner_valid = inner_idx < inner_size; - for (size_t inid = blockIdx.y; inid < outer_size; inid += gridDim.y) { - size_t beg_pos = indices_pos_ptr[inid]; - size_t end_pos = indices_pos_ptr[inid + 1]; - double sqrt_segment_len = sqrt(static_cast(end_pos - beg_pos)); - for (size_t pos = beg_pos; pos < end_pos; pos += 1) { - double reduce_result = 0; - S index = inner_valid ? indices_ptr[pos] : outer_size; - if (index >= 0 && index < outer_size) { - reduce_result = static_cast(grad_ptr[index * inner_size + inner_idx]); - } - if (threadIdx.y == 0 && inner_valid) { - R *out_pos = y_ptr + segment_ids_ptr[pos] * inner_size + inner_idx; - MsAtomicAdd(out_pos, static_cast(reduce_result / sqrt_segment_len)); - } - } - } - } -} - -template -__global__ void SparseSegmentSqrtNGradKernel(const half *grad_ptr, const S *indices_ptr, const S *segment_ids_ptr, - const size_t *indices_pos_ptr, size_t outer_size, size_t inner_size, - size_t output_dim0, half *y_ptr) { - size_t num_blocks = (inner_size - 1) / blockDim.x + 1; - for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) { - size_t inner_idx = threadIdx.x + bid * blockDim.x; - bool inner_valid = inner_idx < inner_size; - for (size_t inid = blockIdx.y; inid < outer_size; inid += gridDim.y) { - size_t beg_pos = indices_pos_ptr[inid]; - size_t end_pos = indices_pos_ptr[inid + 1]; - double sqrt_segment_len = sqrt(static_cast(end_pos - beg_pos)); - for (size_t pos = beg_pos; pos < end_pos; pos += 1) { - double reduce_result = 0; - S index = inner_valid ? indices_ptr[pos] : outer_size; - if (index >= 0 && index < outer_size) { - reduce_result = static_cast(__half2float(grad_ptr[index * inner_size + inner_idx])); - } - if (threadIdx.y == 0 && inner_valid) { - half *out_pos = y_ptr + segment_ids_ptr[pos] * inner_size + inner_idx; - MsAtomicAdd(out_pos, __float2half(static_cast(reduce_result / sqrt_segment_len))); - } - } - } - } -} - -inline int Log2Floor_M(uint32_t n) { - if (n == 0) return -1; - int log = 0; - for (int i = 4; i >= 0; --i) { - int shift = (1 << i); - uint32_t x = n >> shift; - if (x) { - n = x; - log += shift; - } - } - return log; -} - -inline int Log2Floor64_M(uint64_t n) { - // Scan n first high 32 then low 32 bits. - const uint32_t high_32_bit = static_cast(n >> 32); - if (high_32_bit == 0) { - return Log2Floor_M(static_cast(n)); - } else { - return 32 + Log2Floor_M(high_32_bit); - } -} - -inline int Log2Ceil64_M(uint64_t n) { - int floor = Log2Floor64_M(n); - if (n == (n & ~(n - 1))) - return floor; - else - return floor + 1; -} - -template -cudaError_t CalSparseSegmentGradCombination(const std::string kernel_type, const R *grad_ptr, const S *indices_ptr, - const S *segment_ids_ptr, size_t *indices_pos_ptr, size_t outer_size, - size_t inner_size, size_t idx_seg_size, size_t output_dim0, R *y_ptr, - uint32_t device_id, cudaStream_t cuda_stream) { - // Get start position of each segment and set to indices_pos_ptr. - // The last element of indices_pos_ptr must equal to idx_seg_size. - SparseSegmentPosKernel<<>>( - indices_ptr, indices_pos_ptr, idx_seg_size, outer_size); - const unsigned int max_grid_x = (1u << 31) - 1; - const unsigned int max_grid_y = (1u << 16) - 1; - unsigned int block_x = 32; - unsigned int block_y = 1; - unsigned int grid_x = std::min(static_cast(UP_DIV(inner_size, block_x)), max_grid_x); - unsigned int grid_y = std::min(static_cast(output_dim0), max_grid_y); - dim3 block(block_x, block_y); - dim3 grid(grid_x, grid_y); - unsigned int shared_memory_size = block_x * block_y * sizeof(R); - if (kernel_type == "SparseSegmentSumGrad") { - SparseSegmentSumGradKernel<<>>( - grad_ptr, indices_ptr, segment_ids_ptr, indices_pos_ptr, outer_size, inner_size, output_dim0, y_ptr); - } else if (kernel_type == "SparseSegmentSqrtNGrad") { - SparseSegmentSqrtNGradKernel<<>>( - grad_ptr, indices_ptr, segment_ids_ptr, indices_pos_ptr, outer_size, inner_size, output_dim0, y_ptr); - } - return GetCudaStatus(); -} - -#define ADD_SPARSE_SEGMENT_GRAD(R, S) \ - template CUDA_LIB_EXPORT cudaError_t CalSparseSegmentGradCombination( \ - const std::string kernel_type, const R *grad_ptr, const S *indices_ptr, const S *segment_ids_ptr, \ - size_t *indices_pos_ptr, size_t outer_size, size_t inner_size, size_t idx_seg_size, size_t output_dim0, R *y_ptr, \ - uint32_t device_id, cudaStream_t cuda_stream); - -ADD_SPARSE_SEGMENT_GRAD(half, int32_t) -ADD_SPARSE_SEGMENT_GRAD(half, int64_t) - -ADD_SPARSE_SEGMENT_GRAD(float, int32_t) -ADD_SPARSE_SEGMENT_GRAD(float, int64_t) - -ADD_SPARSE_SEGMENT_GRAD(double, int32_t) -ADD_SPARSE_SEGMENT_GRAD(double, int64_t) +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "plugin/device/cpu/kernel/nnacl/op_base.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_grad_impl.cuh" + +template +__global__ void SparseSegmentPosKernel(const S *indices_ptr, size_t *indices_pos_ptr, size_t idx_seg_size, + size_t outer_size) { + for (size_t id = blockIdx.x * blockDim.x + threadIdx.x; id <= idx_seg_size; id += blockDim.x * gridDim.x) { + const S max_size = static_cast(outer_size); + const S min_size = S(0); + S beg_idx = (id == 0) ? min_size : indices_ptr[id - 1] + 1; + S end_idx = (id >= idx_seg_size) ? max_size : indices_ptr[id]; + beg_idx = max(min_size, min(max_size, beg_idx)); + end_idx = max(min_size, min(max_size, end_idx)); + for (S i = beg_idx; i <= end_idx; i++) { + indices_pos_ptr[i] = id; + } + } +} + +template +__global__ void SparseSegmentSumGradKernel(const R *grad_ptr, const S *indices_ptr, const S *segment_ids_ptr, + const size_t *indices_pos_ptr, size_t outer_size, size_t inner_size, + size_t output_dim0, R *y_ptr) { + size_t num_blocks = (inner_size - 1) / blockDim.x + 1; + for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) { + size_t inner_idx = threadIdx.x + bid * blockDim.x; + bool inner_valid = inner_idx < inner_size; + for (size_t inid = blockIdx.y; inid < outer_size; inid += gridDim.y) { + size_t beg_pos = indices_pos_ptr[inid]; + size_t end_pos = indices_pos_ptr[inid + 1]; + for (size_t pos = beg_pos; pos < end_pos; pos += 1) { + double reduce_result = 0; + S index = inner_valid ? indices_ptr[pos] : outer_size; + if (index >= 0 && index < outer_size) { + reduce_result = static_cast(grad_ptr[index * inner_size + inner_idx]); + } + if (threadIdx.y == 0 && inner_valid) { + R *out_pos = y_ptr + segment_ids_ptr[pos] * inner_size + inner_idx; + MsAtomicAdd(out_pos, static_cast(reduce_result)); + } + } + } + } +} + +template +__global__ void SparseSegmentSumGradKernel(const half *grad_ptr, const S *indices_ptr, const S *segment_ids_ptr, + const size_t *indices_pos_ptr, size_t outer_size, size_t inner_size, + size_t output_dim0, half *y_ptr) { + size_t num_blocks = (inner_size - 1) / blockDim.x + 1; + for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) { + size_t inner_idx = threadIdx.x + bid * blockDim.x; + bool inner_valid = inner_idx < inner_size; + for (size_t inid = blockIdx.y; inid < outer_size; inid += gridDim.y) { + size_t beg_pos = indices_pos_ptr[inid]; + size_t end_pos = indices_pos_ptr[inid + 1]; + for (size_t pos = beg_pos; pos < end_pos; pos += 1) { + double reduce_result = 0; + S index = inner_valid ? indices_ptr[pos] : outer_size; + if (index >= 0 && index < outer_size) { + reduce_result = static_cast(__half2float(grad_ptr[index * inner_size + inner_idx])); + } + if (threadIdx.y == 0 && inner_valid) { + half *out_pos = y_ptr + segment_ids_ptr[pos] * inner_size + inner_idx; + MsAtomicAdd(out_pos, __float2half(static_cast(reduce_result))); + } + } + } + } +} + +template +__global__ void SparseSegmentSqrtNGradKernel(const R *grad_ptr, const S *indices_ptr, const S *segment_ids_ptr, + const size_t *indices_pos_ptr, size_t outer_size, size_t inner_size, + size_t output_dim0, R *y_ptr) { + size_t num_blocks = (inner_size - 1) / blockDim.x + 1; + for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) { + size_t inner_idx = threadIdx.x + bid * blockDim.x; + bool inner_valid = inner_idx < inner_size; + for (size_t inid = blockIdx.y; inid < outer_size; inid += gridDim.y) { + size_t beg_pos = indices_pos_ptr[inid]; + size_t end_pos = indices_pos_ptr[inid + 1]; + double sqrt_segment_len = sqrt(static_cast(end_pos - beg_pos)); + for (size_t pos = beg_pos; pos < end_pos; pos += 1) { + double reduce_result = 0; + S index = inner_valid ? indices_ptr[pos] : outer_size; + if (index >= 0 && index < outer_size) { + reduce_result = static_cast(grad_ptr[index * inner_size + inner_idx]); + } + if (threadIdx.y == 0 && inner_valid) { + R *out_pos = y_ptr + segment_ids_ptr[pos] * inner_size + inner_idx; + MsAtomicAdd(out_pos, static_cast(reduce_result / sqrt_segment_len)); + } + } + } + } +} + +template +__global__ void SparseSegmentSqrtNGradKernel(const half *grad_ptr, const S *indices_ptr, const S *segment_ids_ptr, + const size_t *indices_pos_ptr, size_t outer_size, size_t inner_size, + size_t output_dim0, half *y_ptr) { + size_t num_blocks = (inner_size - 1) / blockDim.x + 1; + for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) { + size_t inner_idx = threadIdx.x + bid * blockDim.x; + bool inner_valid = inner_idx < inner_size; + for (size_t inid = blockIdx.y; inid < outer_size; inid += gridDim.y) { + size_t beg_pos = indices_pos_ptr[inid]; + size_t end_pos = indices_pos_ptr[inid + 1]; + double sqrt_segment_len = sqrt(static_cast(end_pos - beg_pos)); + for (size_t pos = beg_pos; pos < end_pos; pos += 1) { + double reduce_result = 0; + S index = inner_valid ? indices_ptr[pos] : outer_size; + if (index >= 0 && index < outer_size) { + reduce_result = static_cast(__half2float(grad_ptr[index * inner_size + inner_idx])); + } + if (threadIdx.y == 0 && inner_valid) { + half *out_pos = y_ptr + segment_ids_ptr[pos] * inner_size + inner_idx; + MsAtomicAdd(out_pos, __float2half(static_cast(reduce_result / sqrt_segment_len))); + } + } + } + } +} + +inline int Log2Floor_M(uint32_t n) { + if (n == 0) return -1; + int log = 0; + for (int i = 4; i >= 0; --i) { + int shift = (1 << i); + uint32_t x = n >> shift; + if (x) { + n = x; + log += shift; + } + } + return log; +} + +inline int Log2Floor64_M(uint64_t n) { + // Scan n first high 32 then low 32 bits. + const uint32_t high_32_bit = static_cast(n >> 32); + if (high_32_bit == 0) { + return Log2Floor_M(static_cast(n)); + } else { + return 32 + Log2Floor_M(high_32_bit); + } +} + +inline int Log2Ceil64_M(uint64_t n) { + int floor = Log2Floor64_M(n); + if (n == (n & ~(n - 1))) + return floor; + else + return floor + 1; +} + +template +cudaError_t CalSparseSegmentGradCombination(const std::string kernel_type, const R *grad_ptr, const S *indices_ptr, + const S *segment_ids_ptr, size_t *indices_pos_ptr, size_t outer_size, + size_t inner_size, size_t idx_seg_size, size_t output_dim0, R *y_ptr, + uint32_t device_id, cudaStream_t cuda_stream) { + // Get start position of each segment and set to indices_pos_ptr. + // The last element of indices_pos_ptr must equal to idx_seg_size. + SparseSegmentPosKernel<<>>( + indices_ptr, indices_pos_ptr, idx_seg_size, outer_size); + const unsigned int max_grid_x = (1u << 31) - 1; + const unsigned int max_grid_y = (1u << 16) - 1; + unsigned int block_x = 32; + unsigned int block_y = 1; + unsigned int grid_x = std::min(static_cast(UP_DIV(inner_size, block_x)), max_grid_x); + unsigned int grid_y = std::min(static_cast(output_dim0), max_grid_y); + dim3 block(block_x, block_y); + dim3 grid(grid_x, grid_y); + unsigned int shared_memory_size = block_x * block_y * sizeof(R); + if (kernel_type == "SparseSegmentSumGrad") { + SparseSegmentSumGradKernel<<>>( + grad_ptr, indices_ptr, segment_ids_ptr, indices_pos_ptr, outer_size, inner_size, output_dim0, y_ptr); + } else if (kernel_type == "SparseSegmentSqrtNGrad") { + SparseSegmentSqrtNGradKernel<<>>( + grad_ptr, indices_ptr, segment_ids_ptr, indices_pos_ptr, outer_size, inner_size, output_dim0, y_ptr); + } + return GetCudaStatus(); +} + +#define ADD_SPARSE_SEGMENT_GRAD(R, S) \ + template CUDA_LIB_EXPORT cudaError_t CalSparseSegmentGradCombination( \ + const std::string kernel_type, const R *grad_ptr, const S *indices_ptr, const S *segment_ids_ptr, \ + size_t *indices_pos_ptr, size_t outer_size, size_t inner_size, size_t idx_seg_size, size_t output_dim0, R *y_ptr, \ + uint32_t device_id, cudaStream_t cuda_stream); + +ADD_SPARSE_SEGMENT_GRAD(half, int32_t) +ADD_SPARSE_SEGMENT_GRAD(half, int64_t) + +ADD_SPARSE_SEGMENT_GRAD(float, int32_t) +ADD_SPARSE_SEGMENT_GRAD(float, int64_t) + +ADD_SPARSE_SEGMENT_GRAD(double, int32_t) +ADD_SPARSE_SEGMENT_GRAD(double, int64_t) diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_grad_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_grad_impl.cuh index 2eae7872ad9..78c993eb69d 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_grad_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_grad_impl.cuh @@ -1,30 +1,30 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SEGMENT_GRAD_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SEGMENT_GRAD_IMPL_CUH_ - -#include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" - -template -CUDA_LIB_EXPORT cudaError_t CalSparseSegmentGradCombination(const std::string kernel_type, const R *grad_ptr, - const S *indices_ptr, const S *segment_ids_ptr, - size_t *indices_pos_ptr, size_t outer_size, - size_t inner_size, size_t idx_seg_size, size_t output_dim0, - R *y_ptr, uint32_t device_id, cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SEGMENT_GRAD_IMPL_CUH_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SEGMENT_GRAD_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SEGMENT_GRAD_IMPL_CUH_ + +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" + +template +CUDA_LIB_EXPORT cudaError_t CalSparseSegmentGradCombination(const std::string kernel_type, const R *grad_ptr, + const S *indices_ptr, const S *segment_ids_ptr, + size_t *indices_pos_ptr, size_t outer_size, + size_t inner_size, size_t idx_seg_size, size_t output_dim0, + R *y_ptr, uint32_t device_id, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SEGMENT_GRAD_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_impl.cu index ff7259ed281..95b1c5b3ef8 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_impl.cu @@ -1,303 +1,303 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "plugin/device/cpu/kernel/nnacl/op_base.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_impl.cuh" - -template -__global__ void SparseSegmentPosKernel(const S *segment_ids_ptr, size_t *segment_pos_ptr, size_t idx_seg_size, - size_t output_dim0) { - for (size_t id = blockIdx.x * blockDim.x + threadIdx.x; id <= idx_seg_size; id += blockDim.x * gridDim.x) { - const S max_size = static_cast(output_dim0); - const S min_size = S(0); - S beg_idx = (id == 0) ? min_size : segment_ids_ptr[id - 1] + 1; - S end_idx = (id >= idx_seg_size) ? max_size : segment_ids_ptr[id]; - beg_idx = max(min_size, min(max_size, beg_idx)); - end_idx = max(min_size, min(max_size, end_idx)); - for (S i = beg_idx; i <= end_idx; i++) { - segment_pos_ptr[i] = id; - } - } -} - -template -__global__ void SparseSegmentSumKernel(const R *x_ptr, const S *indices_ptr, const size_t *segment_pos_ptr, - size_t outer_size, size_t inner_size, size_t output_dim0, R *y_ptr) { - size_t num_blocks = (inner_size - 1) / blockDim.x + 1; - for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) { - size_t inner_idx = threadIdx.x + bid * blockDim.x; - bool inner_valid = inner_idx < inner_size; - for (size_t sid = blockIdx.y; sid < output_dim0; sid += gridDim.y) { - size_t beg_pos = segment_pos_ptr[sid]; - size_t end_pos = segment_pos_ptr[sid + 1]; - R segment_sum = 0; - for (size_t pos = beg_pos; pos < end_pos; pos += 1) { - R reduce_result = 0; - S index = inner_valid ? indices_ptr[pos] : outer_size; - if (index >= 0 && index < outer_size) { - reduce_result = x_ptr[index * inner_size + inner_idx]; - } - if (threadIdx.y == 0 && inner_valid) { - segment_sum += reduce_result; - } - } - if (threadIdx.y == 0 && inner_valid) { - y_ptr[sid * inner_size + inner_idx] = beg_pos == end_pos ? R(0) : segment_sum; - } - } - } -} - -template -__global__ void SparseSegmentSumKernel(const float *x_ptr, const S *indices_ptr, const size_t *segment_pos_ptr, - size_t outer_size, size_t inner_size, size_t output_dim0, float *y_ptr) { - size_t num_blocks = (inner_size - 1) / blockDim.x + 1; - for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) { - size_t inner_idx = threadIdx.x + bid * blockDim.x; - bool inner_valid = inner_idx < inner_size; - for (size_t sid = blockIdx.y; sid < output_dim0; sid += gridDim.y) { - size_t beg_pos = segment_pos_ptr[sid]; - size_t end_pos = segment_pos_ptr[sid + 1]; - double segment_sum = 0; - for (size_t pos = beg_pos; pos < end_pos; pos += 1) { - double reduce_result = 0; - S index = inner_valid ? indices_ptr[pos] : outer_size; - if (index >= 0 && index < outer_size) { - reduce_result = static_cast(x_ptr[index * inner_size + inner_idx]); - } - if (threadIdx.y == 0 && inner_valid) { - segment_sum += reduce_result; - } - } - if (threadIdx.y == 0 && inner_valid) { - y_ptr[sid * inner_size + inner_idx] = - beg_pos == end_pos ? static_cast(0) : static_cast(segment_sum); - } - } - } -} - -template -__global__ void SparseSegmentSumKernel(const half *x_ptr, const S *indices_ptr, const size_t *segment_pos_ptr, - size_t outer_size, size_t inner_size, size_t output_dim0, half *y_ptr) { - size_t num_blocks = (inner_size - 1) / blockDim.x + 1; - for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) { - size_t inner_idx = threadIdx.x + bid * blockDim.x; - bool inner_valid = inner_idx < inner_size; - for (size_t sid = blockIdx.y; sid < output_dim0; sid += gridDim.y) { - size_t beg_pos = segment_pos_ptr[sid]; - size_t end_pos = segment_pos_ptr[sid + 1]; - float segment_sum = 0; - for (size_t pos = beg_pos; pos < end_pos; pos += 1) { - float reduce_result = 0; - S index = inner_valid ? indices_ptr[pos] : outer_size; - if (index >= 0 && index < outer_size) { - reduce_result = __half2float(x_ptr[index * inner_size + inner_idx]); - } - if (threadIdx.y == 0 && inner_valid) { - segment_sum += reduce_result; - } - } - if (threadIdx.y == 0 && inner_valid) { - y_ptr[sid * inner_size + inner_idx] = beg_pos == end_pos ? half(0) : __float2half(segment_sum); - } - } - } -} - -template -__global__ void SparseSegmentSqrtNKernel(const R *x_ptr, const S *indices_ptr, const size_t *segment_pos_ptr, - size_t outer_size, size_t inner_size, size_t output_dim0, R *y_ptr) { - size_t num_blocks = (inner_size - 1) / blockDim.x + 1; - for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) { - size_t inner_idx = threadIdx.x + bid * blockDim.x; - bool inner_valid = inner_idx < inner_size; - for (size_t sid = blockIdx.y; sid < output_dim0; sid += gridDim.y) { - size_t beg_pos = segment_pos_ptr[sid]; - size_t end_pos = segment_pos_ptr[sid + 1]; - R segment_sum = 0; - R sqrt_segment_len = R(sqrt(static_cast(end_pos - beg_pos))); - for (size_t pos = beg_pos; pos < end_pos; pos += 1) { - R reduce_result = 0; - S index = inner_valid ? indices_ptr[pos] : outer_size; - if (index >= 0 && index < outer_size) { - reduce_result = x_ptr[index * inner_size + inner_idx]; - } - if (threadIdx.y == 0 && inner_valid) { - segment_sum += reduce_result; - } - } - if (threadIdx.y == 0 && inner_valid) { - y_ptr[sid * inner_size + inner_idx] = beg_pos == end_pos ? R(0) : segment_sum / sqrt_segment_len; - } - } - } -} - -template -__global__ void SparseSegmentSqrtNKernel(const float *x_ptr, const S *indices_ptr, const size_t *segment_pos_ptr, - size_t outer_size, size_t inner_size, size_t output_dim0, float *y_ptr) { - size_t num_blocks = (inner_size - 1) / blockDim.x + 1; - for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) { - size_t inner_idx = threadIdx.x + bid * blockDim.x; - bool inner_valid = inner_idx < inner_size; - for (size_t sid = blockIdx.y; sid < output_dim0; sid += gridDim.y) { - size_t beg_pos = segment_pos_ptr[sid]; - size_t end_pos = segment_pos_ptr[sid + 1]; - double segment_sum = 0; - double sqrt_segment_len = sqrt(static_cast(end_pos - beg_pos)); - for (size_t pos = beg_pos; pos < end_pos; pos += 1) { - double reduce_result = 0; - S index = inner_valid ? indices_ptr[pos] : outer_size; - if (index >= 0 && index < outer_size) { - reduce_result = static_cast(x_ptr[index * inner_size + inner_idx]); - } - if (threadIdx.y == 0 && inner_valid) { - segment_sum += reduce_result; - } - } - if (threadIdx.y == 0 && inner_valid) { - y_ptr[sid * inner_size + inner_idx] = - beg_pos == end_pos ? static_cast(0) : static_cast(segment_sum / sqrt_segment_len); - } - } - } -} - -template -__global__ void SparseSegmentSqrtNKernel(const half *x_ptr, const S *indices_ptr, const size_t *segment_pos_ptr, - size_t outer_size, size_t inner_size, size_t output_dim0, half *y_ptr) { - size_t num_blocks = (inner_size - 1) / blockDim.x + 1; - for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) { - size_t inner_idx = threadIdx.x + bid * blockDim.x; - bool inner_valid = inner_idx < inner_size; - for (size_t sid = blockIdx.y; sid < output_dim0; sid += gridDim.y) { - size_t beg_pos = segment_pos_ptr[sid]; - size_t end_pos = segment_pos_ptr[sid + 1]; - float segment_sum = 0; - float sqrt_segment_len = sqrt(static_cast(end_pos - beg_pos)); - for (size_t pos = beg_pos; pos < end_pos; pos += 1) { - float reduce_result = 0; - S index = inner_valid ? indices_ptr[pos] : outer_size; - if (index >= 0 && index < outer_size) { - reduce_result = __half2float(x_ptr[index * inner_size + inner_idx]); - } - if (threadIdx.y == 0 && inner_valid) { - segment_sum += reduce_result; - } - } - if (threadIdx.y == 0 && inner_valid) { - y_ptr[sid * inner_size + inner_idx] = - beg_pos == end_pos ? half(0) : __float2half(segment_sum / sqrt_segment_len); - } - } - } -} - -inline int Log2Floor(uint32_t n) { - if (n == 0) return -1; - int log = 0; - for (int i = 4; i >= 0; --i) { - int shift = (1 << i); - uint32_t x = n >> shift; - if (x) { - n = x; - log += shift; - } - } - return log; -} - -inline int Log2Floor64(uint64_t n) { - // Scan n first high 32 then low 32 bits. - const uint32_t high_32_bit = static_cast(n >> 32); - if (high_32_bit == 0) { - return Log2Floor(static_cast(n)); - } else { - return 32 + Log2Floor(high_32_bit); - } -} - -inline int Log2Ceil64(uint64_t n) { - int floor = Log2Floor64(n); - if (n == (n & ~(n - 1))) - return floor; - else - return floor + 1; -} - -template -cudaError_t CalSparseSegmentCombination(const std::string kernel_type, const R *x_ptr, const S *indices_ptr, - const S *segment_ids_ptr, size_t *segment_pos_ptr, size_t outer_size, - size_t inner_size, size_t idx_seg_size, size_t output_dim0, R *y_ptr, - uint32_t device_id, cudaStream_t cuda_stream) { - // Get start position of each segment and set to segment_pos_ptr. - // The last element of segment_pos_ptr must equal to idx_seg_size. - SparseSegmentPosKernel<<>>( - segment_ids_ptr, segment_pos_ptr, idx_seg_size, output_dim0); - - const unsigned int max_grid_x = (1u << 31) - 1; - const unsigned int max_grid_y = (1u << 16) - 1; - unsigned int block_x = 32; - unsigned int block_y = 1; - unsigned int grid_x = std::min(static_cast(UP_DIV(inner_size, block_x)), max_grid_x); - unsigned int grid_y = std::min(static_cast(output_dim0), max_grid_y); - dim3 block(block_x, block_y); - dim3 grid(grid_x, grid_y); - unsigned int shared_memory_size = block_x * block_y * sizeof(R); - if (kernel_type == "SparseSegmentSum" || kernel_type == "SparseSegmentSumWithNumSegments") { - SparseSegmentSumKernel<<>>( - x_ptr, indices_ptr, segment_pos_ptr, outer_size, inner_size, output_dim0, y_ptr); - } else if (kernel_type == "SparseSegmentSqrtN" || kernel_type == "SparseSegmentSqrtNWithNumSegments") { - SparseSegmentSqrtNKernel<<>>( - x_ptr, indices_ptr, segment_pos_ptr, outer_size, inner_size, output_dim0, y_ptr); - } - return GetCudaStatus(); -} - -#define ADD_SPARSE_SEGMENT(R, S) \ - template CUDA_LIB_EXPORT cudaError_t CalSparseSegmentCombination( \ - const std::string kernel_type, const R *x_ptr, const S *indices_ptr, const S *segment_ids_ptr, \ - size_t *segment_pos_ptr, size_t outer_size, size_t inner_size, size_t idx_seg_size, size_t output_dim0, R *y_ptr, \ - uint32_t device_id, cudaStream_t cuda_stream); - -ADD_SPARSE_SEGMENT(uint8_t, int32_t) -ADD_SPARSE_SEGMENT(uint8_t, int64_t) - -ADD_SPARSE_SEGMENT(uint16_t, int32_t) -ADD_SPARSE_SEGMENT(uint16_t, int64_t) - -ADD_SPARSE_SEGMENT(int8_t, int32_t) -ADD_SPARSE_SEGMENT(int8_t, int64_t) - -ADD_SPARSE_SEGMENT(int16_t, int32_t) -ADD_SPARSE_SEGMENT(int16_t, int64_t) - -ADD_SPARSE_SEGMENT(int32_t, int32_t) -ADD_SPARSE_SEGMENT(int32_t, int64_t) - -ADD_SPARSE_SEGMENT(int64_t, int32_t) -ADD_SPARSE_SEGMENT(int64_t, int64_t) - -ADD_SPARSE_SEGMENT(half, int32_t) -ADD_SPARSE_SEGMENT(half, int64_t) - -ADD_SPARSE_SEGMENT(float, int32_t) -ADD_SPARSE_SEGMENT(float, int64_t) - -ADD_SPARSE_SEGMENT(double, int32_t) -ADD_SPARSE_SEGMENT(double, int64_t) +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "plugin/device/cpu/kernel/nnacl/op_base.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_impl.cuh" + +template +__global__ void SparseSegmentPosKernel(const S *segment_ids_ptr, size_t *segment_pos_ptr, size_t idx_seg_size, + size_t output_dim0) { + for (size_t id = blockIdx.x * blockDim.x + threadIdx.x; id <= idx_seg_size; id += blockDim.x * gridDim.x) { + const S max_size = static_cast(output_dim0); + const S min_size = S(0); + S beg_idx = (id == 0) ? min_size : segment_ids_ptr[id - 1] + 1; + S end_idx = (id >= idx_seg_size) ? max_size : segment_ids_ptr[id]; + beg_idx = max(min_size, min(max_size, beg_idx)); + end_idx = max(min_size, min(max_size, end_idx)); + for (S i = beg_idx; i <= end_idx; i++) { + segment_pos_ptr[i] = id; + } + } +} + +template +__global__ void SparseSegmentSumKernel(const R *x_ptr, const S *indices_ptr, const size_t *segment_pos_ptr, + size_t outer_size, size_t inner_size, size_t output_dim0, R *y_ptr) { + size_t num_blocks = (inner_size - 1) / blockDim.x + 1; + for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) { + size_t inner_idx = threadIdx.x + bid * blockDim.x; + bool inner_valid = inner_idx < inner_size; + for (size_t sid = blockIdx.y; sid < output_dim0; sid += gridDim.y) { + size_t beg_pos = segment_pos_ptr[sid]; + size_t end_pos = segment_pos_ptr[sid + 1]; + R segment_sum = 0; + for (size_t pos = beg_pos; pos < end_pos; pos += 1) { + R reduce_result = 0; + S index = inner_valid ? indices_ptr[pos] : outer_size; + if (index >= 0 && index < outer_size) { + reduce_result = x_ptr[index * inner_size + inner_idx]; + } + if (threadIdx.y == 0 && inner_valid) { + segment_sum += reduce_result; + } + } + if (threadIdx.y == 0 && inner_valid) { + y_ptr[sid * inner_size + inner_idx] = beg_pos == end_pos ? R(0) : segment_sum; + } + } + } +} + +template +__global__ void SparseSegmentSumKernel(const float *x_ptr, const S *indices_ptr, const size_t *segment_pos_ptr, + size_t outer_size, size_t inner_size, size_t output_dim0, float *y_ptr) { + size_t num_blocks = (inner_size - 1) / blockDim.x + 1; + for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) { + size_t inner_idx = threadIdx.x + bid * blockDim.x; + bool inner_valid = inner_idx < inner_size; + for (size_t sid = blockIdx.y; sid < output_dim0; sid += gridDim.y) { + size_t beg_pos = segment_pos_ptr[sid]; + size_t end_pos = segment_pos_ptr[sid + 1]; + double segment_sum = 0; + for (size_t pos = beg_pos; pos < end_pos; pos += 1) { + double reduce_result = 0; + S index = inner_valid ? indices_ptr[pos] : outer_size; + if (index >= 0 && index < outer_size) { + reduce_result = static_cast(x_ptr[index * inner_size + inner_idx]); + } + if (threadIdx.y == 0 && inner_valid) { + segment_sum += reduce_result; + } + } + if (threadIdx.y == 0 && inner_valid) { + y_ptr[sid * inner_size + inner_idx] = + beg_pos == end_pos ? static_cast(0) : static_cast(segment_sum); + } + } + } +} + +template +__global__ void SparseSegmentSumKernel(const half *x_ptr, const S *indices_ptr, const size_t *segment_pos_ptr, + size_t outer_size, size_t inner_size, size_t output_dim0, half *y_ptr) { + size_t num_blocks = (inner_size - 1) / blockDim.x + 1; + for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) { + size_t inner_idx = threadIdx.x + bid * blockDim.x; + bool inner_valid = inner_idx < inner_size; + for (size_t sid = blockIdx.y; sid < output_dim0; sid += gridDim.y) { + size_t beg_pos = segment_pos_ptr[sid]; + size_t end_pos = segment_pos_ptr[sid + 1]; + float segment_sum = 0; + for (size_t pos = beg_pos; pos < end_pos; pos += 1) { + float reduce_result = 0; + S index = inner_valid ? indices_ptr[pos] : outer_size; + if (index >= 0 && index < outer_size) { + reduce_result = __half2float(x_ptr[index * inner_size + inner_idx]); + } + if (threadIdx.y == 0 && inner_valid) { + segment_sum += reduce_result; + } + } + if (threadIdx.y == 0 && inner_valid) { + y_ptr[sid * inner_size + inner_idx] = beg_pos == end_pos ? half(0) : __float2half(segment_sum); + } + } + } +} + +template +__global__ void SparseSegmentSqrtNKernel(const R *x_ptr, const S *indices_ptr, const size_t *segment_pos_ptr, + size_t outer_size, size_t inner_size, size_t output_dim0, R *y_ptr) { + size_t num_blocks = (inner_size - 1) / blockDim.x + 1; + for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) { + size_t inner_idx = threadIdx.x + bid * blockDim.x; + bool inner_valid = inner_idx < inner_size; + for (size_t sid = blockIdx.y; sid < output_dim0; sid += gridDim.y) { + size_t beg_pos = segment_pos_ptr[sid]; + size_t end_pos = segment_pos_ptr[sid + 1]; + R segment_sum = 0; + R sqrt_segment_len = R(sqrt(static_cast(end_pos - beg_pos))); + for (size_t pos = beg_pos; pos < end_pos; pos += 1) { + R reduce_result = 0; + S index = inner_valid ? indices_ptr[pos] : outer_size; + if (index >= 0 && index < outer_size) { + reduce_result = x_ptr[index * inner_size + inner_idx]; + } + if (threadIdx.y == 0 && inner_valid) { + segment_sum += reduce_result; + } + } + if (threadIdx.y == 0 && inner_valid) { + y_ptr[sid * inner_size + inner_idx] = beg_pos == end_pos ? R(0) : segment_sum / sqrt_segment_len; + } + } + } +} + +template +__global__ void SparseSegmentSqrtNKernel(const float *x_ptr, const S *indices_ptr, const size_t *segment_pos_ptr, + size_t outer_size, size_t inner_size, size_t output_dim0, float *y_ptr) { + size_t num_blocks = (inner_size - 1) / blockDim.x + 1; + for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) { + size_t inner_idx = threadIdx.x + bid * blockDim.x; + bool inner_valid = inner_idx < inner_size; + for (size_t sid = blockIdx.y; sid < output_dim0; sid += gridDim.y) { + size_t beg_pos = segment_pos_ptr[sid]; + size_t end_pos = segment_pos_ptr[sid + 1]; + double segment_sum = 0; + double sqrt_segment_len = sqrt(static_cast(end_pos - beg_pos)); + for (size_t pos = beg_pos; pos < end_pos; pos += 1) { + double reduce_result = 0; + S index = inner_valid ? indices_ptr[pos] : outer_size; + if (index >= 0 && index < outer_size) { + reduce_result = static_cast(x_ptr[index * inner_size + inner_idx]); + } + if (threadIdx.y == 0 && inner_valid) { + segment_sum += reduce_result; + } + } + if (threadIdx.y == 0 && inner_valid) { + y_ptr[sid * inner_size + inner_idx] = + beg_pos == end_pos ? static_cast(0) : static_cast(segment_sum / sqrt_segment_len); + } + } + } +} + +template +__global__ void SparseSegmentSqrtNKernel(const half *x_ptr, const S *indices_ptr, const size_t *segment_pos_ptr, + size_t outer_size, size_t inner_size, size_t output_dim0, half *y_ptr) { + size_t num_blocks = (inner_size - 1) / blockDim.x + 1; + for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) { + size_t inner_idx = threadIdx.x + bid * blockDim.x; + bool inner_valid = inner_idx < inner_size; + for (size_t sid = blockIdx.y; sid < output_dim0; sid += gridDim.y) { + size_t beg_pos = segment_pos_ptr[sid]; + size_t end_pos = segment_pos_ptr[sid + 1]; + float segment_sum = 0; + float sqrt_segment_len = sqrt(static_cast(end_pos - beg_pos)); + for (size_t pos = beg_pos; pos < end_pos; pos += 1) { + float reduce_result = 0; + S index = inner_valid ? indices_ptr[pos] : outer_size; + if (index >= 0 && index < outer_size) { + reduce_result = __half2float(x_ptr[index * inner_size + inner_idx]); + } + if (threadIdx.y == 0 && inner_valid) { + segment_sum += reduce_result; + } + } + if (threadIdx.y == 0 && inner_valid) { + y_ptr[sid * inner_size + inner_idx] = + beg_pos == end_pos ? half(0) : __float2half(segment_sum / sqrt_segment_len); + } + } + } +} + +inline int Log2Floor(uint32_t n) { + if (n == 0) return -1; + int log = 0; + for (int i = 4; i >= 0; --i) { + int shift = (1 << i); + uint32_t x = n >> shift; + if (x) { + n = x; + log += shift; + } + } + return log; +} + +inline int Log2Floor64(uint64_t n) { + // Scan n first high 32 then low 32 bits. + const uint32_t high_32_bit = static_cast(n >> 32); + if (high_32_bit == 0) { + return Log2Floor(static_cast(n)); + } else { + return 32 + Log2Floor(high_32_bit); + } +} + +inline int Log2Ceil64(uint64_t n) { + int floor = Log2Floor64(n); + if (n == (n & ~(n - 1))) + return floor; + else + return floor + 1; +} + +template +cudaError_t CalSparseSegmentCombination(const std::string kernel_type, const R *x_ptr, const S *indices_ptr, + const S *segment_ids_ptr, size_t *segment_pos_ptr, size_t outer_size, + size_t inner_size, size_t idx_seg_size, size_t output_dim0, R *y_ptr, + uint32_t device_id, cudaStream_t cuda_stream) { + // Get start position of each segment and set to segment_pos_ptr. + // The last element of segment_pos_ptr must equal to idx_seg_size. + SparseSegmentPosKernel<<>>( + segment_ids_ptr, segment_pos_ptr, idx_seg_size, output_dim0); + + const unsigned int max_grid_x = (1u << 31) - 1; + const unsigned int max_grid_y = (1u << 16) - 1; + unsigned int block_x = 32; + unsigned int block_y = 1; + unsigned int grid_x = std::min(static_cast(UP_DIV(inner_size, block_x)), max_grid_x); + unsigned int grid_y = std::min(static_cast(output_dim0), max_grid_y); + dim3 block(block_x, block_y); + dim3 grid(grid_x, grid_y); + unsigned int shared_memory_size = block_x * block_y * sizeof(R); + if (kernel_type == "SparseSegmentSum" || kernel_type == "SparseSegmentSumWithNumSegments") { + SparseSegmentSumKernel<<>>( + x_ptr, indices_ptr, segment_pos_ptr, outer_size, inner_size, output_dim0, y_ptr); + } else if (kernel_type == "SparseSegmentSqrtN" || kernel_type == "SparseSegmentSqrtNWithNumSegments") { + SparseSegmentSqrtNKernel<<>>( + x_ptr, indices_ptr, segment_pos_ptr, outer_size, inner_size, output_dim0, y_ptr); + } + return GetCudaStatus(); +} + +#define ADD_SPARSE_SEGMENT(R, S) \ + template CUDA_LIB_EXPORT cudaError_t CalSparseSegmentCombination( \ + const std::string kernel_type, const R *x_ptr, const S *indices_ptr, const S *segment_ids_ptr, \ + size_t *segment_pos_ptr, size_t outer_size, size_t inner_size, size_t idx_seg_size, size_t output_dim0, R *y_ptr, \ + uint32_t device_id, cudaStream_t cuda_stream); + +ADD_SPARSE_SEGMENT(uint8_t, int32_t) +ADD_SPARSE_SEGMENT(uint8_t, int64_t) + +ADD_SPARSE_SEGMENT(uint16_t, int32_t) +ADD_SPARSE_SEGMENT(uint16_t, int64_t) + +ADD_SPARSE_SEGMENT(int8_t, int32_t) +ADD_SPARSE_SEGMENT(int8_t, int64_t) + +ADD_SPARSE_SEGMENT(int16_t, int32_t) +ADD_SPARSE_SEGMENT(int16_t, int64_t) + +ADD_SPARSE_SEGMENT(int32_t, int32_t) +ADD_SPARSE_SEGMENT(int32_t, int64_t) + +ADD_SPARSE_SEGMENT(int64_t, int32_t) +ADD_SPARSE_SEGMENT(int64_t, int64_t) + +ADD_SPARSE_SEGMENT(half, int32_t) +ADD_SPARSE_SEGMENT(half, int64_t) + +ADD_SPARSE_SEGMENT(float, int32_t) +ADD_SPARSE_SEGMENT(float, int64_t) + +ADD_SPARSE_SEGMENT(double, int32_t) +ADD_SPARSE_SEGMENT(double, int64_t) diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_impl.cuh index c7b1bfbe2a5..d462daaa836 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_impl.cuh @@ -1,30 +1,30 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SEGMENT_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SEGMENT_IMPL_CUH_ - -#include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" - -template -CUDA_LIB_EXPORT cudaError_t CalSparseSegmentCombination(const std::string kernel_type, const R *x_ptr, - const S *indices_ptr, const S *segment_ids_ptr, - size_t *segment_pos_ptr, size_t outer_size, size_t inner_size, - size_t indices_size, size_t output_dim0, R *y_ptr, - uint32_t device_id, cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SEGMENT_IMPL_CUH_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SEGMENT_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SEGMENT_IMPL_CUH_ + +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" + +template +CUDA_LIB_EXPORT cudaError_t CalSparseSegmentCombination(const std::string kernel_type, const R *x_ptr, + const S *indices_ptr, const S *segment_ids_ptr, + size_t *segment_pos_ptr, size_t outer_size, size_t inner_size, + size_t indices_size, size_t output_dim0, R *y_ptr, + uint32_t device_id, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SEGMENT_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_sparse_maximum_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_sparse_maximum_impl.cu index 6b7b751d5c8..6bf92fab77c 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_sparse_maximum_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_sparse_maximum_impl.cu @@ -1,227 +1,227 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "sparse_sparse_maximum_impl.cuh" -#include "include/cuda_fp16.h" -#include "include/cuda_runtime.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh" - -template -__global__ void SparseSparseMaximum1(const T *a_indices, const T *b_indices, int64_t *ab_status, const int64_t rank_1, - const int64_t a_indices_num, const int64_t b_indices_num) { - for (int x = blockIdx.x * blockDim.x + threadIdx.x, y = blockIdx.y * blockDim.y + threadIdx.y; - x < a_indices_num && y < b_indices_num; x += blockDim.x * gridDim.x, y += blockDim.y * gridDim.y) { - for (int64_t i = 0; i < rank_1; i++) { - if (a_indices[x * rank_1 + i] > b_indices[y * rank_1 + i]) { - ab_status[y * a_indices_num + x] = 1; - return; - } - if (a_indices[x * rank_1 + i] == b_indices[y * rank_1 + i]) { - ab_status[y * a_indices_num + x] = 0; - continue; - } - if (a_indices[x * rank_1 + i] < b_indices[y * rank_1 + i]) { - ab_status[y * a_indices_num + x] = -1; - return; - } - } - } -} - -template -__global__ void SparseSparseMaximum2(const T *a_indices, const T *b_indices, int64_t *ab_status, - const int64_t a_indices_num, const int64_t b_indices_num, const int64_t rank_1, - int64_t *sum_ptr, int64_t *ab_stauts1, int64_t *ab_stauts2) { - int64_t count = 0; - int64_t i = 0; - int64_t j = 0; - while (i < a_indices_num && j < b_indices_num) { - if (ab_status[j * a_indices_num + i] == -1) { - ab_stauts1[count] = 1; - ab_stauts2[count] = i; - count++; - i++; - continue; - } - if (ab_status[j * a_indices_num + i] == 0) { - ab_stauts1[count] = -i; - ab_stauts2[count] = j; - count++; - i++; - j++; - continue; - } - if (ab_status[j * a_indices_num + i] == 1) { - ab_stauts1[count] = 2; - ab_stauts2[count] = j; - count++; - j++; - continue; - } - } - for (int64_t y1 = i; y1 < a_indices_num; y1++) { - ab_stauts1[count] = 1; - ab_stauts2[count] = y1; - count++; - } - - for (int64_t y1 = j; y1 < b_indices_num; y1++) { - ab_stauts1[count] = 2; - ab_stauts2[count] = y1; - count++; - } - *sum_ptr = count; -} - -template -__global__ void SparseSparseMaximum3(const T *a_indices, const S *a_values, const T *b_indices, const S *b_values, - int64_t *ab_stauts1, int64_t *ab_stauts2, const int64_t rank_1, T *output_indices, - S *output_values, int64_t limit1) { - for (int x = blockIdx.x * blockDim.x + threadIdx.x; x < limit1; x += blockDim.x * gridDim.x) { - int64_t mid1 = ab_stauts2[x]; - int64_t mid2 = -ab_stauts1[x]; - if (ab_stauts1[x] == 3) { - return; - } else if (ab_stauts1[x] == 1) { - for (int64_t m = 0; m < rank_1; m++) { - output_indices[x * rank_1 + m] = a_indices[mid1 * rank_1 + m]; - } - output_values[x] = a_values[mid1] > 0 ? a_values[mid1] : 0; - } else if (ab_stauts1[x] == 2) { - for (int64_t m = 0; m < rank_1; m++) { - output_indices[x * rank_1 + m] = b_indices[mid1 * rank_1 + m]; - } - output_values[x] = b_values[mid1] > 0 ? b_values[mid1] : 0; - } else if (ab_stauts1[x] <= 0) { - for (int64_t m = 0; m < rank_1; m++) { - output_indices[x * rank_1 + m] = b_indices[mid1 * rank_1 + m]; - } - output_values[x] = a_values[mid2] > b_values[mid1] ? a_values[mid2] : b_values[mid1]; - } - } -} - -template -__global__ void SparseSparseMaximum3(const T *a_indices, const half *a_values, const T *b_indices, const half *b_values, - int64_t *ab_stauts1, int64_t *ab_stauts2, const int64_t rank_1, T *output_indices, - half *output_values, int64_t limit1) { - for (int x = blockIdx.x * blockDim.x + threadIdx.x; x < limit1; x += blockDim.x * gridDim.x) { - int64_t mid1 = ab_stauts2[x]; - int64_t mid2 = -ab_stauts1[x]; - if (ab_stauts1[x] == 3) return; - if (ab_stauts1[x] == 1) { - for (int64_t m = 0; m < rank_1; m++) { - output_indices[x * rank_1 + m] = a_indices[mid1 * rank_1 + m]; - } - output_values[x] = __half2float(a_values[mid1]) > 0 ? a_values[mid1] : __float2half(0.0); - } else if (ab_stauts1[x] == 2) { - for (int64_t m = 0; m < rank_1; m++) { - output_indices[x * rank_1 + m] = b_indices[mid1 * rank_1 + m]; - } - output_values[x] = __half2float(b_values[mid1]) > 0 ? b_values[mid1] : __float2half(0.0); - } else if (ab_stauts1[x] <= 0) { - for (int64_t m = 0; m < rank_1; m++) { - output_indices[x * rank_1 + m] = b_indices[mid1 * rank_1 + m]; - } - output_values[x] = a_values[mid2] > b_values[mid1] ? a_values[mid2] : b_values[mid1]; - } - } -} - -template -__global__ void Max_test1(const int64_t a_len, const T *a_indices, const S *a_values, T *output_indices, - S *output_values, const int64_t rank_1, int64_t *sum_ptr) { - *sum_ptr = a_len; - for (int x = blockIdx.x * blockDim.x + threadIdx.x; x < a_len; x += blockDim.x * gridDim.x) { - for (int64_t j = 0; j < rank_1; j++) { - output_indices[x * rank_1 + j] = a_indices[x * rank_1 + j]; - } - output_values[x] = a_values[x] > 0 ? a_values[x] : 0; - } -} - -template -__global__ void Max_test1(const int64_t a_len, const T *a_indices, const half *a_values, T *output_indices, - half *output_values, const int64_t rank_1, int64_t *sum_ptr) { - *sum_ptr = a_len; - for (int x = blockIdx.x * blockDim.x + threadIdx.x; x < a_len; x += blockDim.x * gridDim.x) { - for (int64_t j = 0; j < rank_1; j++) { - output_indices[x * rank_1 + j] = a_indices[x * rank_1 + j]; - } - output_values[x] = __half2float(a_values[x]) > 0 ? a_values[x] : __float2half(0.0); - } -} - -__global__ void Max_test2(int64_t *sum_ptr) { *sum_ptr = 0; } - -template -CUDA_LIB_EXPORT cudaError_t SparseSparseMaximum(const T *a_indices, const S *a_values, const T *b_indices, - const S *b_values, T *sum_indices, S *sum_values, - int64_t *ab_status_ptr, int64_t *sum_ptr, const int64_t a_indices_num, - const int64_t b_indices_num, const int64_t rank_1, - cudaStream_t cuda_stream1, const uint32_t &device_id, - int64_t *ab_status_ptr1, int64_t *ab_status_ptr2) { - if (a_indices_num != 0 && b_indices_num == 0) { - Max_test1<<>>( - a_indices_num, a_indices, a_values, sum_indices, sum_values, rank_1, sum_ptr); - cudaDeviceSynchronize(); - return GetCudaStatus(); - } - if (a_indices_num == 0 && b_indices_num != 0) { - Max_test1<<>>( - b_indices_num, b_indices, b_values, sum_indices, sum_values, rank_1, sum_ptr); - cudaDeviceSynchronize(); - return GetCudaStatus(); - } - if (a_indices_num == 0 && b_indices_num == 0) { - Max_test2<<<1, 1, 0, cuda_stream1>>>(sum_ptr); - return GetCudaStatus(); - } - const int block1 = 32; - const int block2 = 32; - const int grid1 = (a_indices_num + block1 - 1) / block1; - const int grid2 = (b_indices_num + block2 - 1) / block2; - const int grid3 = (a_indices_num + b_indices_num + block1 - 1) / block1; - dim3 block12(block1, block2); - dim3 grid12(grid1, grid2); - SparseSparseMaximum1<<>>(a_indices, b_indices, ab_status_ptr, rank_1, a_indices_num, - b_indices_num); - cudaDeviceSynchronize(); - SparseSparseMaximum2<<<1, 1, 0, cuda_stream1>>>(a_indices, b_indices, ab_status_ptr, a_indices_num, b_indices_num, - rank_1, sum_ptr, ab_status_ptr1, ab_status_ptr2); - cudaDeviceSynchronize(); - SparseSparseMaximum3<<>>(a_indices, a_values, b_indices, b_values, ab_status_ptr1, - ab_status_ptr2, rank_1, sum_indices, sum_values, - a_indices_num + b_indices_num); - cudaDeviceSynchronize(); - return GetCudaStatus(); -} - -#define GPU_SPARSE_SPARSE_MAXIMUM_GRAD_EXPORT_REGISTER(index_type1, val_type1) \ - template CUDA_LIB_EXPORT cudaError_t SparseSparseMaximum( \ - const index_type1 *a_indices, const val_type1 *a_values, const index_type1 *b_indices, const val_type1 *b_values, \ - index_type1 *sum_indices, val_type1 *sum_values, int64_t *ab_status_ptr, int64_t *sum_ptr, \ - const int64_t a_indices_num, const int64_t b_indices_num, const int64_t rank_1, cudaStream_t cuda_stream1, \ - const uint32_t &device_id, int64_t *ab_status_ptr1, int64_t *ab_status_ptr2); - -GPU_SPARSE_SPARSE_MAXIMUM_GRAD_EXPORT_REGISTER(int64_t, int8_t) -GPU_SPARSE_SPARSE_MAXIMUM_GRAD_EXPORT_REGISTER(int64_t, int16_t) -GPU_SPARSE_SPARSE_MAXIMUM_GRAD_EXPORT_REGISTER(int64_t, int32_t) -GPU_SPARSE_SPARSE_MAXIMUM_GRAD_EXPORT_REGISTER(int64_t, int64_t) -GPU_SPARSE_SPARSE_MAXIMUM_GRAD_EXPORT_REGISTER(int64_t, float) -GPU_SPARSE_SPARSE_MAXIMUM_GRAD_EXPORT_REGISTER(int64_t, half) -GPU_SPARSE_SPARSE_MAXIMUM_GRAD_EXPORT_REGISTER(int64_t, double) -GPU_SPARSE_SPARSE_MAXIMUM_GRAD_EXPORT_REGISTER(int64_t, uint8_t) -GPU_SPARSE_SPARSE_MAXIMUM_GRAD_EXPORT_REGISTER(int64_t, uint16_t) +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "sparse_sparse_maximum_impl.cuh" +#include "include/cuda_fp16.h" +#include "include/cuda_runtime.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh" + +template +__global__ void SparseSparseMaximum1(const T *a_indices, const T *b_indices, int64_t *ab_status, const int64_t rank_1, + const int64_t a_indices_num, const int64_t b_indices_num) { + for (int x = blockIdx.x * blockDim.x + threadIdx.x, y = blockIdx.y * blockDim.y + threadIdx.y; + x < a_indices_num && y < b_indices_num; x += blockDim.x * gridDim.x, y += blockDim.y * gridDim.y) { + for (int64_t i = 0; i < rank_1; i++) { + if (a_indices[x * rank_1 + i] > b_indices[y * rank_1 + i]) { + ab_status[y * a_indices_num + x] = 1; + return; + } + if (a_indices[x * rank_1 + i] == b_indices[y * rank_1 + i]) { + ab_status[y * a_indices_num + x] = 0; + continue; + } + if (a_indices[x * rank_1 + i] < b_indices[y * rank_1 + i]) { + ab_status[y * a_indices_num + x] = -1; + return; + } + } + } +} + +template +__global__ void SparseSparseMaximum2(const T *a_indices, const T *b_indices, int64_t *ab_status, + const int64_t a_indices_num, const int64_t b_indices_num, const int64_t rank_1, + int64_t *sum_ptr, int64_t *ab_stauts1, int64_t *ab_stauts2) { + int64_t count = 0; + int64_t i = 0; + int64_t j = 0; + while (i < a_indices_num && j < b_indices_num) { + if (ab_status[j * a_indices_num + i] == -1) { + ab_stauts1[count] = 1; + ab_stauts2[count] = i; + count++; + i++; + continue; + } + if (ab_status[j * a_indices_num + i] == 0) { + ab_stauts1[count] = -i; + ab_stauts2[count] = j; + count++; + i++; + j++; + continue; + } + if (ab_status[j * a_indices_num + i] == 1) { + ab_stauts1[count] = 2; + ab_stauts2[count] = j; + count++; + j++; + continue; + } + } + for (int64_t y1 = i; y1 < a_indices_num; y1++) { + ab_stauts1[count] = 1; + ab_stauts2[count] = y1; + count++; + } + + for (int64_t y1 = j; y1 < b_indices_num; y1++) { + ab_stauts1[count] = 2; + ab_stauts2[count] = y1; + count++; + } + *sum_ptr = count; +} + +template +__global__ void SparseSparseMaximum3(const T *a_indices, const S *a_values, const T *b_indices, const S *b_values, + int64_t *ab_stauts1, int64_t *ab_stauts2, const int64_t rank_1, T *output_indices, + S *output_values, int64_t limit1) { + for (int x = blockIdx.x * blockDim.x + threadIdx.x; x < limit1; x += blockDim.x * gridDim.x) { + int64_t mid1 = ab_stauts2[x]; + int64_t mid2 = -ab_stauts1[x]; + if (ab_stauts1[x] == 3) { + return; + } else if (ab_stauts1[x] == 1) { + for (int64_t m = 0; m < rank_1; m++) { + output_indices[x * rank_1 + m] = a_indices[mid1 * rank_1 + m]; + } + output_values[x] = a_values[mid1] > 0 ? a_values[mid1] : 0; + } else if (ab_stauts1[x] == 2) { + for (int64_t m = 0; m < rank_1; m++) { + output_indices[x * rank_1 + m] = b_indices[mid1 * rank_1 + m]; + } + output_values[x] = b_values[mid1] > 0 ? b_values[mid1] : 0; + } else if (ab_stauts1[x] <= 0) { + for (int64_t m = 0; m < rank_1; m++) { + output_indices[x * rank_1 + m] = b_indices[mid1 * rank_1 + m]; + } + output_values[x] = a_values[mid2] > b_values[mid1] ? a_values[mid2] : b_values[mid1]; + } + } +} + +template +__global__ void SparseSparseMaximum3(const T *a_indices, const half *a_values, const T *b_indices, const half *b_values, + int64_t *ab_stauts1, int64_t *ab_stauts2, const int64_t rank_1, T *output_indices, + half *output_values, int64_t limit1) { + for (int x = blockIdx.x * blockDim.x + threadIdx.x; x < limit1; x += blockDim.x * gridDim.x) { + int64_t mid1 = ab_stauts2[x]; + int64_t mid2 = -ab_stauts1[x]; + if (ab_stauts1[x] == 3) return; + if (ab_stauts1[x] == 1) { + for (int64_t m = 0; m < rank_1; m++) { + output_indices[x * rank_1 + m] = a_indices[mid1 * rank_1 + m]; + } + output_values[x] = __half2float(a_values[mid1]) > 0 ? a_values[mid1] : __float2half(0.0); + } else if (ab_stauts1[x] == 2) { + for (int64_t m = 0; m < rank_1; m++) { + output_indices[x * rank_1 + m] = b_indices[mid1 * rank_1 + m]; + } + output_values[x] = __half2float(b_values[mid1]) > 0 ? b_values[mid1] : __float2half(0.0); + } else if (ab_stauts1[x] <= 0) { + for (int64_t m = 0; m < rank_1; m++) { + output_indices[x * rank_1 + m] = b_indices[mid1 * rank_1 + m]; + } + output_values[x] = a_values[mid2] > b_values[mid1] ? a_values[mid2] : b_values[mid1]; + } + } +} + +template +__global__ void Max_test1(const int64_t a_len, const T *a_indices, const S *a_values, T *output_indices, + S *output_values, const int64_t rank_1, int64_t *sum_ptr) { + *sum_ptr = a_len; + for (int x = blockIdx.x * blockDim.x + threadIdx.x; x < a_len; x += blockDim.x * gridDim.x) { + for (int64_t j = 0; j < rank_1; j++) { + output_indices[x * rank_1 + j] = a_indices[x * rank_1 + j]; + } + output_values[x] = a_values[x] > 0 ? a_values[x] : 0; + } +} + +template +__global__ void Max_test1(const int64_t a_len, const T *a_indices, const half *a_values, T *output_indices, + half *output_values, const int64_t rank_1, int64_t *sum_ptr) { + *sum_ptr = a_len; + for (int x = blockIdx.x * blockDim.x + threadIdx.x; x < a_len; x += blockDim.x * gridDim.x) { + for (int64_t j = 0; j < rank_1; j++) { + output_indices[x * rank_1 + j] = a_indices[x * rank_1 + j]; + } + output_values[x] = __half2float(a_values[x]) > 0 ? a_values[x] : __float2half(0.0); + } +} + +__global__ void Max_test2(int64_t *sum_ptr) { *sum_ptr = 0; } + +template +CUDA_LIB_EXPORT cudaError_t SparseSparseMaximum(const T *a_indices, const S *a_values, const T *b_indices, + const S *b_values, T *sum_indices, S *sum_values, + int64_t *ab_status_ptr, int64_t *sum_ptr, const int64_t a_indices_num, + const int64_t b_indices_num, const int64_t rank_1, + cudaStream_t cuda_stream1, const uint32_t &device_id, + int64_t *ab_status_ptr1, int64_t *ab_status_ptr2) { + if (a_indices_num != 0 && b_indices_num == 0) { + Max_test1<<>>( + a_indices_num, a_indices, a_values, sum_indices, sum_values, rank_1, sum_ptr); + cudaDeviceSynchronize(); + return GetCudaStatus(); + } + if (a_indices_num == 0 && b_indices_num != 0) { + Max_test1<<>>( + b_indices_num, b_indices, b_values, sum_indices, sum_values, rank_1, sum_ptr); + cudaDeviceSynchronize(); + return GetCudaStatus(); + } + if (a_indices_num == 0 && b_indices_num == 0) { + Max_test2<<<1, 1, 0, cuda_stream1>>>(sum_ptr); + return GetCudaStatus(); + } + const int block1 = 32; + const int block2 = 32; + const int grid1 = (a_indices_num + block1 - 1) / block1; + const int grid2 = (b_indices_num + block2 - 1) / block2; + const int grid3 = (a_indices_num + b_indices_num + block1 - 1) / block1; + dim3 block12(block1, block2); + dim3 grid12(grid1, grid2); + SparseSparseMaximum1<<>>(a_indices, b_indices, ab_status_ptr, rank_1, a_indices_num, + b_indices_num); + cudaDeviceSynchronize(); + SparseSparseMaximum2<<<1, 1, 0, cuda_stream1>>>(a_indices, b_indices, ab_status_ptr, a_indices_num, b_indices_num, + rank_1, sum_ptr, ab_status_ptr1, ab_status_ptr2); + cudaDeviceSynchronize(); + SparseSparseMaximum3<<>>(a_indices, a_values, b_indices, b_values, ab_status_ptr1, + ab_status_ptr2, rank_1, sum_indices, sum_values, + a_indices_num + b_indices_num); + cudaDeviceSynchronize(); + return GetCudaStatus(); +} + +#define GPU_SPARSE_SPARSE_MAXIMUM_GRAD_EXPORT_REGISTER(index_type1, val_type1) \ + template CUDA_LIB_EXPORT cudaError_t SparseSparseMaximum( \ + const index_type1 *a_indices, const val_type1 *a_values, const index_type1 *b_indices, const val_type1 *b_values, \ + index_type1 *sum_indices, val_type1 *sum_values, int64_t *ab_status_ptr, int64_t *sum_ptr, \ + const int64_t a_indices_num, const int64_t b_indices_num, const int64_t rank_1, cudaStream_t cuda_stream1, \ + const uint32_t &device_id, int64_t *ab_status_ptr1, int64_t *ab_status_ptr2); + +GPU_SPARSE_SPARSE_MAXIMUM_GRAD_EXPORT_REGISTER(int64_t, int8_t) +GPU_SPARSE_SPARSE_MAXIMUM_GRAD_EXPORT_REGISTER(int64_t, int16_t) +GPU_SPARSE_SPARSE_MAXIMUM_GRAD_EXPORT_REGISTER(int64_t, int32_t) +GPU_SPARSE_SPARSE_MAXIMUM_GRAD_EXPORT_REGISTER(int64_t, int64_t) +GPU_SPARSE_SPARSE_MAXIMUM_GRAD_EXPORT_REGISTER(int64_t, float) +GPU_SPARSE_SPARSE_MAXIMUM_GRAD_EXPORT_REGISTER(int64_t, half) +GPU_SPARSE_SPARSE_MAXIMUM_GRAD_EXPORT_REGISTER(int64_t, double) +GPU_SPARSE_SPARSE_MAXIMUM_GRAD_EXPORT_REGISTER(int64_t, uint8_t) +GPU_SPARSE_SPARSE_MAXIMUM_GRAD_EXPORT_REGISTER(int64_t, uint16_t) diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_sparse_maximum_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_sparse_maximum_impl.cuh index 06524fd5092..21f8a247529 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_sparse_maximum_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_sparse_maximum_impl.cuh @@ -1,29 +1,29 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SPARSE_MAXIMUM_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SPARSE_MAXIMUM_IMPL_CUH_ -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" - -template -CUDA_LIB_EXPORT cudaError_t SparseSparseMaximum(const T *a_indices, const S *a_values, const T *b_indices, - const S *b_values, T *sum_indices, S *sum_values, - int64_t *ab_status_ptr, int64_t *sum_ptr, const int64_t a_indices_num, - const int64_t b_indices_num, const int64_t rank_1, - cudaStream_t cuda_stream, const uint32_t &device_id, - int64_t *ab_status_ptr1, int64_t *ab_status_ptr2); -#endif +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SPARSE_MAXIMUM_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SPARSE_MAXIMUM_IMPL_CUH_ +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" + +template +CUDA_LIB_EXPORT cudaError_t SparseSparseMaximum(const T *a_indices, const S *a_values, const T *b_indices, + const S *b_values, T *sum_indices, S *sum_values, + int64_t *ab_status_ptr, int64_t *sum_ptr, const int64_t a_indices_num, + const int64_t b_indices_num, const int64_t rank_1, + cudaStream_t cuda_stream, const uint32_t &device_id, + int64_t *ab_status_ptr1, int64_t *ab_status_ptr2); +#endif diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_sparse_minimum_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_sparse_minimum_impl.cu index e4c08a4effc..3179c8dc66a 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_sparse_minimum_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_sparse_minimum_impl.cu @@ -1,228 +1,228 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "sparse_sparse_minimum_impl.cuh" -#include "include/cuda_fp16.h" -#include "include/cuda_runtime.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh" - -template -__global__ void SparseSparseMinimum1(const T *a_indices, const T *b_indices, int64_t *ab_status, const int64_t rank_1, - const int64_t a_indices_num, const int64_t b_indices_num) { - for (int x = blockIdx.x * blockDim.x + threadIdx.x, y = blockIdx.y * blockDim.y + threadIdx.y; - x < a_indices_num && y < b_indices_num; x += blockDim.x * gridDim.x, y += blockDim.y * gridDim.y) { - if (x < a_indices_num && y < b_indices_num) { - for (int64_t i = 0; i < rank_1; i++) { - if (a_indices[x * rank_1 + i] > b_indices[y * rank_1 + i]) { - ab_status[y * a_indices_num + x] = 1; - return; - } - if (a_indices[x * rank_1 + i] == b_indices[y * rank_1 + i]) { - ab_status[y * a_indices_num + x] = 0; - continue; - } - if (a_indices[x * rank_1 + i] < b_indices[y * rank_1 + i]) { - ab_status[y * a_indices_num + x] = -1; - return; - } - } - } - } -} - -template -__global__ void SparseSparseMinimum2(const T *a_indices, const T *b_indices, int64_t *ab_status, - const int64_t a_indices_num, const int64_t b_indices_num, const int64_t rank_1, - int64_t *sum_ptr, int64_t *ab_stauts1, int64_t *ab_stauts2) { - int64_t count = 0; - int64_t i = 0; - int64_t j = 0; - while (i < a_indices_num && j < b_indices_num) { - if (ab_status[j * a_indices_num + i] == -1) { - ab_stauts1[count] = 1; - ab_stauts2[count] = i; - count++; - i++; - continue; - } - if (ab_status[j * a_indices_num + i] == 0) { - ab_stauts1[count] = -i; - ab_stauts2[count] = j; - count++; - i++; - j++; - continue; - } - if (ab_status[j * a_indices_num + i] == 1) { - ab_stauts1[count] = 2; - ab_stauts2[count] = j; - count++; - j++; - continue; - } - } - for (int64_t y1 = i; y1 < a_indices_num; y1++) { - ab_stauts1[count] = 1; - ab_stauts2[count] = y1; - count++; - } - - for (int64_t y1 = j; y1 < b_indices_num; y1++) { - ab_stauts1[count] = 2; - ab_stauts2[count] = y1; - count++; - } - *sum_ptr = count; -} - -template -__global__ void SparseSparseMinimum3(const T *a_indices, const S *a_values, const T *b_indices, const S *b_values, - int64_t *ab_stauts1, int64_t *ab_stauts2, const int64_t rank_1, T *output_indices, - S *output_values, int64_t limit1) { - for (int x = blockIdx.x * blockDim.x + threadIdx.x; x < limit1; x += blockDim.x * gridDim.x) { - int64_t mid1 = ab_stauts2[x]; - int64_t mid2 = -ab_stauts1[x]; - if (x >= limit1) return; - if (ab_stauts1[x] == 3) { - return; - } else if (ab_stauts1[x] == 1) { - for (int64_t m = 0; m < rank_1; m++) { - output_indices[x * rank_1 + m] = a_indices[mid1 * rank_1 + m]; - } - output_values[x] = a_values[mid1] < static_cast(0) ? a_values[mid1] : static_cast(0); - } else if (ab_stauts1[x] == 2) { - for (int64_t m = 0; m < rank_1; m++) { - output_indices[x * rank_1 + m] = b_indices[mid1 * rank_1 + m]; - } - output_values[x] = b_values[mid1] < static_cast(0) ? b_values[mid1] : static_cast(0); - } else if (ab_stauts1[x] <= 0) { - for (int64_t m = 0; m < rank_1; m++) { - output_indices[x * rank_1 + m] = b_indices[mid1 * rank_1 + m]; - } - output_values[x] = a_values[mid2] < b_values[mid1] ? a_values[mid2] : b_values[mid1]; - } - } -} - -template -__global__ void SparseSparseMinimum3(const T *a_indices, const half *a_values, const T *b_indices, const half *b_values, - int64_t *ab_stauts1, int64_t *ab_stauts2, const int64_t rank_1, T *output_indices, - half *output_values, int64_t limit1) { - for (int x = blockIdx.x * blockDim.x + threadIdx.x; x < limit1; x += blockDim.x * gridDim.x) { - int64_t mid1 = ab_stauts2[x]; - int64_t mid2 = -ab_stauts1[x]; - if (ab_stauts1[x] == 3) return; - if (ab_stauts1[x] == 1) { - for (int64_t m = 0; m < rank_1; m++) { - output_indices[x * rank_1 + m] = a_indices[mid1 * rank_1 + m]; - } - output_values[x] = __half2float(a_values[mid1]) < 0 ? a_values[mid1] : __float2half(0.0); - } else if (ab_stauts1[x] == 2) { - for (int64_t m = 0; m < rank_1; m++) { - output_indices[x * rank_1 + m] = b_indices[mid1 * rank_1 + m]; - } - output_values[x] = __half2float(b_values[mid1]) < 0 ? b_values[mid1] : __float2half(0.0); - } else if (ab_stauts1[x] <= 0) { - for (int64_t m = 0; m < rank_1; m++) { - output_indices[x * rank_1 + m] = b_indices[mid1 * rank_1 + m]; - } - output_values[x] = a_values[mid2] < b_values[mid1] ? a_values[mid2] : b_values[mid1]; - } - } -} - -template -__global__ void Min_test1(const int64_t a_len, const T *a_indices, const S *a_values, T *output_indices, - S *output_values, const int64_t rank_1, int64_t *sum_ptr) { - *sum_ptr = a_len; - for (int x = blockIdx.x * blockDim.x + threadIdx.x; x < a_len; x += blockDim.x * gridDim.x) { - for (int64_t j = 0; j < rank_1; j++) { - output_indices[x * rank_1 + j] = a_indices[x * rank_1 + j]; - } - output_values[x] = a_values[x] < static_cast(0) ? a_values[x] : static_cast(0); - } -} - -template -__global__ void Min_test1(const int64_t a_len, const T *a_indices, const half *a_values, T *output_indices, - half *output_values, const int64_t rank_1, int64_t *sum_ptr) { - *sum_ptr = a_len; - for (int x = blockIdx.x * blockDim.x + threadIdx.x; x < a_len; x += blockDim.x * gridDim.x) { - for (int64_t j = 0; j < rank_1; j++) { - output_indices[x * rank_1 + j] = a_indices[x * rank_1 + j]; - } - output_values[x] = __half2float(a_values[x]) < 0 ? a_values[x] : __float2half(0.0); - } -} - -__global__ void Min_test2(int64_t *sum_ptr) { *sum_ptr = 0; } - -template -CUDA_LIB_EXPORT cudaError_t SparseSparseMinimum(const T *a_indices, const S *a_values, const T *b_indices, - const S *b_values, T *sum_indices, S *sum_values, - int64_t *ab_status_ptr, int64_t *sum_ptr, const int64_t a_indices_num, - const int64_t b_indices_num, const int64_t rank_1, - cudaStream_t cuda_stream1, const uint32_t &device_id, - int64_t *ab_status_ptr1, int64_t *ab_status_ptr2) { - if (a_indices_num != 0 && b_indices_num == 0) { - Min_test1<<>>( - a_indices_num, a_indices, a_values, sum_indices, sum_values, rank_1, sum_ptr); - cudaDeviceSynchronize(); - return GetCudaStatus(); - } - if (a_indices_num == 0 && b_indices_num != 0) { - Min_test1<<>>( - b_indices_num, b_indices, b_values, sum_indices, sum_values, rank_1, sum_ptr); - cudaDeviceSynchronize(); - return GetCudaStatus(); - } - if (a_indices_num == 0 && b_indices_num == 0) { - Min_test2<<<1, 1>>>(sum_ptr); - return GetCudaStatus(); - } - const int block1 = 32; - const int block2 = 32; - const int grid1 = (a_indices_num + block1 - 1) / block1; - const int grid2 = (b_indices_num + block2 - 1) / block2; - const int grid3 = (a_indices_num + b_indices_num + block1 - 1) / block1; - dim3 block12(block1, block2); - dim3 grid12(grid1, grid2); - SparseSparseMinimum1<<>>(a_indices, b_indices, ab_status_ptr, rank_1, a_indices_num, b_indices_num); - cudaDeviceSynchronize(); - SparseSparseMinimum2<<<1, 1>>>(a_indices, b_indices, ab_status_ptr, a_indices_num, b_indices_num, rank_1, sum_ptr, - ab_status_ptr1, ab_status_ptr2); - cudaDeviceSynchronize(); - SparseSparseMinimum3<<>>(a_indices, a_values, b_indices, b_values, ab_status_ptr1, ab_status_ptr2, - rank_1, sum_indices, sum_values, a_indices_num + b_indices_num); - cudaDeviceSynchronize(); - return GetCudaStatus(); -} - -#define GPU_SPARSE_SPARSE_MINIMUM_GRAD_EXPORT_REGISTER(index_type1, val_type1) \ - template CUDA_LIB_EXPORT cudaError_t SparseSparseMinimum( \ - const index_type1 *a_indices, const val_type1 *a_values, const index_type1 *b_indices, const val_type1 *b_values, \ - index_type1 *sum_indices, val_type1 *sum_values, int64_t *ab_status_ptr, int64_t *sum_ptr, \ - const int64_t a_indices_num, const int64_t b_indices_num, const int64_t rank_1, cudaStream_t cuda_stream1, \ - const uint32_t &device_id, int64_t *ab_status_ptr1, int64_t *ab_status_ptr2); - -GPU_SPARSE_SPARSE_MINIMUM_GRAD_EXPORT_REGISTER(int64_t, int8_t) -GPU_SPARSE_SPARSE_MINIMUM_GRAD_EXPORT_REGISTER(int64_t, int16_t) -GPU_SPARSE_SPARSE_MINIMUM_GRAD_EXPORT_REGISTER(int64_t, int32_t) -GPU_SPARSE_SPARSE_MINIMUM_GRAD_EXPORT_REGISTER(int64_t, int64_t) -GPU_SPARSE_SPARSE_MINIMUM_GRAD_EXPORT_REGISTER(int64_t, float) -GPU_SPARSE_SPARSE_MINIMUM_GRAD_EXPORT_REGISTER(int64_t, half) -GPU_SPARSE_SPARSE_MINIMUM_GRAD_EXPORT_REGISTER(int64_t, double) -GPU_SPARSE_SPARSE_MINIMUM_GRAD_EXPORT_REGISTER(int64_t, uint8_t) -GPU_SPARSE_SPARSE_MINIMUM_GRAD_EXPORT_REGISTER(int64_t, uint16_t) +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "sparse_sparse_minimum_impl.cuh" +#include "include/cuda_fp16.h" +#include "include/cuda_runtime.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh" + +template +__global__ void SparseSparseMinimum1(const T *a_indices, const T *b_indices, int64_t *ab_status, const int64_t rank_1, + const int64_t a_indices_num, const int64_t b_indices_num) { + for (int x = blockIdx.x * blockDim.x + threadIdx.x, y = blockIdx.y * blockDim.y + threadIdx.y; + x < a_indices_num && y < b_indices_num; x += blockDim.x * gridDim.x, y += blockDim.y * gridDim.y) { + if (x < a_indices_num && y < b_indices_num) { + for (int64_t i = 0; i < rank_1; i++) { + if (a_indices[x * rank_1 + i] > b_indices[y * rank_1 + i]) { + ab_status[y * a_indices_num + x] = 1; + return; + } + if (a_indices[x * rank_1 + i] == b_indices[y * rank_1 + i]) { + ab_status[y * a_indices_num + x] = 0; + continue; + } + if (a_indices[x * rank_1 + i] < b_indices[y * rank_1 + i]) { + ab_status[y * a_indices_num + x] = -1; + return; + } + } + } + } +} + +template +__global__ void SparseSparseMinimum2(const T *a_indices, const T *b_indices, int64_t *ab_status, + const int64_t a_indices_num, const int64_t b_indices_num, const int64_t rank_1, + int64_t *sum_ptr, int64_t *ab_stauts1, int64_t *ab_stauts2) { + int64_t count = 0; + int64_t i = 0; + int64_t j = 0; + while (i < a_indices_num && j < b_indices_num) { + if (ab_status[j * a_indices_num + i] == -1) { + ab_stauts1[count] = 1; + ab_stauts2[count] = i; + count++; + i++; + continue; + } + if (ab_status[j * a_indices_num + i] == 0) { + ab_stauts1[count] = -i; + ab_stauts2[count] = j; + count++; + i++; + j++; + continue; + } + if (ab_status[j * a_indices_num + i] == 1) { + ab_stauts1[count] = 2; + ab_stauts2[count] = j; + count++; + j++; + continue; + } + } + for (int64_t y1 = i; y1 < a_indices_num; y1++) { + ab_stauts1[count] = 1; + ab_stauts2[count] = y1; + count++; + } + + for (int64_t y1 = j; y1 < b_indices_num; y1++) { + ab_stauts1[count] = 2; + ab_stauts2[count] = y1; + count++; + } + *sum_ptr = count; +} + +template +__global__ void SparseSparseMinimum3(const T *a_indices, const S *a_values, const T *b_indices, const S *b_values, + int64_t *ab_stauts1, int64_t *ab_stauts2, const int64_t rank_1, T *output_indices, + S *output_values, int64_t limit1) { + for (int x = blockIdx.x * blockDim.x + threadIdx.x; x < limit1; x += blockDim.x * gridDim.x) { + int64_t mid1 = ab_stauts2[x]; + int64_t mid2 = -ab_stauts1[x]; + if (x >= limit1) return; + if (ab_stauts1[x] == 3) { + return; + } else if (ab_stauts1[x] == 1) { + for (int64_t m = 0; m < rank_1; m++) { + output_indices[x * rank_1 + m] = a_indices[mid1 * rank_1 + m]; + } + output_values[x] = a_values[mid1] < static_cast(0) ? a_values[mid1] : static_cast(0); + } else if (ab_stauts1[x] == 2) { + for (int64_t m = 0; m < rank_1; m++) { + output_indices[x * rank_1 + m] = b_indices[mid1 * rank_1 + m]; + } + output_values[x] = b_values[mid1] < static_cast(0) ? b_values[mid1] : static_cast(0); + } else if (ab_stauts1[x] <= 0) { + for (int64_t m = 0; m < rank_1; m++) { + output_indices[x * rank_1 + m] = b_indices[mid1 * rank_1 + m]; + } + output_values[x] = a_values[mid2] < b_values[mid1] ? a_values[mid2] : b_values[mid1]; + } + } +} + +template +__global__ void SparseSparseMinimum3(const T *a_indices, const half *a_values, const T *b_indices, const half *b_values, + int64_t *ab_stauts1, int64_t *ab_stauts2, const int64_t rank_1, T *output_indices, + half *output_values, int64_t limit1) { + for (int x = blockIdx.x * blockDim.x + threadIdx.x; x < limit1; x += blockDim.x * gridDim.x) { + int64_t mid1 = ab_stauts2[x]; + int64_t mid2 = -ab_stauts1[x]; + if (ab_stauts1[x] == 3) return; + if (ab_stauts1[x] == 1) { + for (int64_t m = 0; m < rank_1; m++) { + output_indices[x * rank_1 + m] = a_indices[mid1 * rank_1 + m]; + } + output_values[x] = __half2float(a_values[mid1]) < 0 ? a_values[mid1] : __float2half(0.0); + } else if (ab_stauts1[x] == 2) { + for (int64_t m = 0; m < rank_1; m++) { + output_indices[x * rank_1 + m] = b_indices[mid1 * rank_1 + m]; + } + output_values[x] = __half2float(b_values[mid1]) < 0 ? b_values[mid1] : __float2half(0.0); + } else if (ab_stauts1[x] <= 0) { + for (int64_t m = 0; m < rank_1; m++) { + output_indices[x * rank_1 + m] = b_indices[mid1 * rank_1 + m]; + } + output_values[x] = a_values[mid2] < b_values[mid1] ? a_values[mid2] : b_values[mid1]; + } + } +} + +template +__global__ void Min_test1(const int64_t a_len, const T *a_indices, const S *a_values, T *output_indices, + S *output_values, const int64_t rank_1, int64_t *sum_ptr) { + *sum_ptr = a_len; + for (int x = blockIdx.x * blockDim.x + threadIdx.x; x < a_len; x += blockDim.x * gridDim.x) { + for (int64_t j = 0; j < rank_1; j++) { + output_indices[x * rank_1 + j] = a_indices[x * rank_1 + j]; + } + output_values[x] = a_values[x] < static_cast(0) ? a_values[x] : static_cast(0); + } +} + +template +__global__ void Min_test1(const int64_t a_len, const T *a_indices, const half *a_values, T *output_indices, + half *output_values, const int64_t rank_1, int64_t *sum_ptr) { + *sum_ptr = a_len; + for (int x = blockIdx.x * blockDim.x + threadIdx.x; x < a_len; x += blockDim.x * gridDim.x) { + for (int64_t j = 0; j < rank_1; j++) { + output_indices[x * rank_1 + j] = a_indices[x * rank_1 + j]; + } + output_values[x] = __half2float(a_values[x]) < 0 ? a_values[x] : __float2half(0.0); + } +} + +__global__ void Min_test2(int64_t *sum_ptr) { *sum_ptr = 0; } + +template +CUDA_LIB_EXPORT cudaError_t SparseSparseMinimum(const T *a_indices, const S *a_values, const T *b_indices, + const S *b_values, T *sum_indices, S *sum_values, + int64_t *ab_status_ptr, int64_t *sum_ptr, const int64_t a_indices_num, + const int64_t b_indices_num, const int64_t rank_1, + cudaStream_t cuda_stream1, const uint32_t &device_id, + int64_t *ab_status_ptr1, int64_t *ab_status_ptr2) { + if (a_indices_num != 0 && b_indices_num == 0) { + Min_test1<<>>( + a_indices_num, a_indices, a_values, sum_indices, sum_values, rank_1, sum_ptr); + cudaDeviceSynchronize(); + return GetCudaStatus(); + } + if (a_indices_num == 0 && b_indices_num != 0) { + Min_test1<<>>( + b_indices_num, b_indices, b_values, sum_indices, sum_values, rank_1, sum_ptr); + cudaDeviceSynchronize(); + return GetCudaStatus(); + } + if (a_indices_num == 0 && b_indices_num == 0) { + Min_test2<<<1, 1>>>(sum_ptr); + return GetCudaStatus(); + } + const int block1 = 32; + const int block2 = 32; + const int grid1 = (a_indices_num + block1 - 1) / block1; + const int grid2 = (b_indices_num + block2 - 1) / block2; + const int grid3 = (a_indices_num + b_indices_num + block1 - 1) / block1; + dim3 block12(block1, block2); + dim3 grid12(grid1, grid2); + SparseSparseMinimum1<<>>(a_indices, b_indices, ab_status_ptr, rank_1, a_indices_num, b_indices_num); + cudaDeviceSynchronize(); + SparseSparseMinimum2<<<1, 1>>>(a_indices, b_indices, ab_status_ptr, a_indices_num, b_indices_num, rank_1, sum_ptr, + ab_status_ptr1, ab_status_ptr2); + cudaDeviceSynchronize(); + SparseSparseMinimum3<<>>(a_indices, a_values, b_indices, b_values, ab_status_ptr1, ab_status_ptr2, + rank_1, sum_indices, sum_values, a_indices_num + b_indices_num); + cudaDeviceSynchronize(); + return GetCudaStatus(); +} + +#define GPU_SPARSE_SPARSE_MINIMUM_GRAD_EXPORT_REGISTER(index_type1, val_type1) \ + template CUDA_LIB_EXPORT cudaError_t SparseSparseMinimum( \ + const index_type1 *a_indices, const val_type1 *a_values, const index_type1 *b_indices, const val_type1 *b_values, \ + index_type1 *sum_indices, val_type1 *sum_values, int64_t *ab_status_ptr, int64_t *sum_ptr, \ + const int64_t a_indices_num, const int64_t b_indices_num, const int64_t rank_1, cudaStream_t cuda_stream1, \ + const uint32_t &device_id, int64_t *ab_status_ptr1, int64_t *ab_status_ptr2); + +GPU_SPARSE_SPARSE_MINIMUM_GRAD_EXPORT_REGISTER(int64_t, int8_t) +GPU_SPARSE_SPARSE_MINIMUM_GRAD_EXPORT_REGISTER(int64_t, int16_t) +GPU_SPARSE_SPARSE_MINIMUM_GRAD_EXPORT_REGISTER(int64_t, int32_t) +GPU_SPARSE_SPARSE_MINIMUM_GRAD_EXPORT_REGISTER(int64_t, int64_t) +GPU_SPARSE_SPARSE_MINIMUM_GRAD_EXPORT_REGISTER(int64_t, float) +GPU_SPARSE_SPARSE_MINIMUM_GRAD_EXPORT_REGISTER(int64_t, half) +GPU_SPARSE_SPARSE_MINIMUM_GRAD_EXPORT_REGISTER(int64_t, double) +GPU_SPARSE_SPARSE_MINIMUM_GRAD_EXPORT_REGISTER(int64_t, uint8_t) +GPU_SPARSE_SPARSE_MINIMUM_GRAD_EXPORT_REGISTER(int64_t, uint16_t) diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_sparse_minimum_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_sparse_minimum_impl.cuh index 98a7aa1622c..9ea169bdb41 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_sparse_minimum_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_sparse_minimum_impl.cuh @@ -1,29 +1,29 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SPARSE_MINIMUM_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SPARSE_MINIMUM_IMPL_CUH_ -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" - -template -CUDA_LIB_EXPORT cudaError_t SparseSparseMinimum(const T *a_indices, const S *a_values, const T *b_indices, - const S *b_values, T *sum_indices, S *sum_values, - int64_t *ab_status_ptr, int64_t *sum_ptr, const int64_t a_indices_num, - const int64_t b_indices_num, const int64_t rank_1, - cudaStream_t cuda_stream, const uint32_t &device_id, - int64_t *ab_status_ptr1, int64_t *ab_status_ptr2); -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SPARSE_MINIMUM_IMPL_CUH_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SPARSE_MINIMUM_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SPARSE_MINIMUM_IMPL_CUH_ +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" + +template +CUDA_LIB_EXPORT cudaError_t SparseSparseMinimum(const T *a_indices, const S *a_values, const T *b_indices, + const S *b_values, T *sum_indices, S *sum_values, + int64_t *ab_status_ptr, int64_t *sum_ptr, const int64_t a_indices_num, + const int64_t b_indices_num, const int64_t rank_1, + cudaStream_t cuda_stream, const uint32_t &device_id, + int64_t *ab_status_ptr1, int64_t *ab_status_ptr2); +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SPARSE_MINIMUM_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_tensor_to_csr_sparse_matrix_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_tensor_to_csr_sparse_matrix_impl.cu index 8bc01799158..d291dd5bb24 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_tensor_to_csr_sparse_matrix_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_tensor_to_csr_sparse_matrix_impl.cu @@ -1,52 +1,52 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "sparse_tensor_to_csr_sparse_matrix_impl.cuh" -#include -#include -#include "plugin/device/cpu/kernel/nnacl/op_base.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh" - -template -__global__ void SparseTensorToCSRSparseMatrixKernel(const IndiceType *x_indices_ptr, IndiceType *out_row_indices_ptr, - IndiceType *out_col_indices_ptr, IndiceType *out_batch_pointers_ptr, - int total_num, int rank) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < total_num; i += blockDim.x * gridDim.x) { - out_row_indices_ptr[i] = x_indices_ptr[i * rank + rank - 2]; - out_col_indices_ptr[i] = x_indices_ptr[i * rank + rank - 1]; - if (rank == 3) { - IndiceType batch = x_indices_ptr[i * rank]; - MsAtomicMax(out_batch_pointers_ptr + batch + 1, i + 1); - } else { - MsAtomicMax(out_batch_pointers_ptr + 1, i + 1); - } - } -} - -template -CUDA_LIB_EXPORT cudaError_t SparseTensorToCSRSparseMatrix(const IndiceType *x_indices_ptr, - IndiceType *out_row_indices_ptr, - IndiceType *out_col_indices_ptr, - IndiceType *out_batch_pointers_ptr, int total_num, int rank, - cudaStream_t cuda_stream, const uint32_t device_id) { - SparseTensorToCSRSparseMatrixKernel<<>>( - x_indices_ptr, out_row_indices_ptr, out_col_indices_ptr, out_batch_pointers_ptr, total_num, rank); - return GetCudaStatus(); -} - -template CUDA_LIB_EXPORT cudaError_t SparseTensorToCSRSparseMatrix( - const int32_t *x_indices_ptr, int32_t *out_row_indices_ptr, int32_t *out_col_indices_ptr, - int32_t *out_batch_pointers_ptr, int total_num, int rank, cudaStream_t cuda_stream, const uint32_t device_id); +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "sparse_tensor_to_csr_sparse_matrix_impl.cuh" +#include +#include +#include "plugin/device/cpu/kernel/nnacl/op_base.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh" + +template +__global__ void SparseTensorToCSRSparseMatrixKernel(const IndiceType *x_indices_ptr, IndiceType *out_row_indices_ptr, + IndiceType *out_col_indices_ptr, IndiceType *out_batch_pointers_ptr, + int total_num, int rank) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < total_num; i += blockDim.x * gridDim.x) { + out_row_indices_ptr[i] = x_indices_ptr[i * rank + rank - 2]; + out_col_indices_ptr[i] = x_indices_ptr[i * rank + rank - 1]; + if (rank == 3) { + IndiceType batch = x_indices_ptr[i * rank]; + MsAtomicMax(out_batch_pointers_ptr + batch + 1, i + 1); + } else { + MsAtomicMax(out_batch_pointers_ptr + 1, i + 1); + } + } +} + +template +CUDA_LIB_EXPORT cudaError_t SparseTensorToCSRSparseMatrix(const IndiceType *x_indices_ptr, + IndiceType *out_row_indices_ptr, + IndiceType *out_col_indices_ptr, + IndiceType *out_batch_pointers_ptr, int total_num, int rank, + cudaStream_t cuda_stream, const uint32_t device_id) { + SparseTensorToCSRSparseMatrixKernel<<>>( + x_indices_ptr, out_row_indices_ptr, out_col_indices_ptr, out_batch_pointers_ptr, total_num, rank); + return GetCudaStatus(); +} + +template CUDA_LIB_EXPORT cudaError_t SparseTensorToCSRSparseMatrix( + const int32_t *x_indices_ptr, int32_t *out_row_indices_ptr, int32_t *out_col_indices_ptr, + int32_t *out_batch_pointers_ptr, int total_num, int rank, cudaStream_t cuda_stream, const uint32_t device_id); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_tensor_to_csr_sparse_matrix_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_tensor_to_csr_sparse_matrix_impl.cuh index afdf0d24cd5..3bc3748964e 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_tensor_to_csr_sparse_matrix_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_tensor_to_csr_sparse_matrix_impl.cuh @@ -1,28 +1,28 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_TENSOR_TO_CSR_SPARSE_MATRIX_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_TENSOR_TO_CSR_SPARSE_MATRIX_IMPL_CUH_ -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" - -template -CUDA_LIB_EXPORT cudaError_t SparseTensorToCSRSparseMatrix(const IndiceType *x_indices_ptr, - IndiceType *out_row_indices_ptr, - IndiceType *out_col_indices_ptr, - IndiceType *out_batch_pointers_ptr, int total_num, int rank, - cudaStream_t cuda_stream, const uint32_t device_id); - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_TENSOR_TO_CSR_SPARSE_MATRIX_IMPL_CUH_ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_TENSOR_TO_CSR_SPARSE_MATRIX_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_TENSOR_TO_CSR_SPARSE_MATRIX_IMPL_CUH_ +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" + +template +CUDA_LIB_EXPORT cudaError_t SparseTensorToCSRSparseMatrix(const IndiceType *x_indices_ptr, + IndiceType *out_row_indices_ptr, + IndiceType *out_col_indices_ptr, + IndiceType *out_batch_pointers_ptr, int total_num, int rank, + cudaStream_t cuda_stream, const uint32_t device_id); + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_TENSOR_TO_CSR_SPARSE_MATRIX_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl.cu old mode 100755 new mode 100644 index bf443c1dcb8..5cf2e0e11e8 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl.cu @@ -1,398 +1,398 @@ -/** - * Copyright 2020-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include -#include -#include "include/cuda_fp16.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl.cuh" - -template -using Complex = mindspore::utils::Complex; - -template -__global__ void TransposeKernel(const T *__restrict__ input, const size_t size, const TransposeInfoDevice info, - const int ndims, T *__restrict__ output) { - const int32_t *in_strides = info.transpose_info_device; - const int32_t *out_strides = info.transpose_info_device + stride_ndims; - const int32_t *perm = info.transpose_info_device + stride_ndims * 2; - for (int output_pos = blockDim.x * blockIdx.x + threadIdx.x; output_pos < size; - output_pos += blockDim.x * gridDim.x) { - int32_t input_pos = 0; - int32_t temp = output_pos; - for (int i = 0; i < ndims; ++i) { - const int32_t ratio = temp / out_strides[i]; - temp -= ratio * out_strides[i]; - input_pos += ratio * in_strides[perm[i]]; - } - output[output_pos] = input[input_pos]; - } -} - -template -bool TransposeUsingTile(const T *input, const std::vector &shape, const std::vector &perm, T *output, - cudaStream_t cuda_stream) { - int dims = shape.size(); - if (dims < 2 || dims > 3) { - return false; - } - switch (dims) { - case 2: - if (perm[0] == 1 && perm[1] == 0) { - Swap3DTensorLast2Dim(input, (int64_t)1, shape[0], shape[1], output, cuda_stream); - return true; - } - break; - case 3: - if (perm == std::vector{0, 2, 1}) { - Swap3DTensorLast2Dim(input, shape[0], shape[1], shape[2], output, cuda_stream); - return true; - } else if (perm == std::vector{2, 1, 0}) { - Swap3DTensorDim0and2(input, shape[0], shape[1], shape[2], output, cuda_stream); - return true; - } else { - // Do not support other 3D Transpose. - return false; - } - break; - default: - return false; - } - return false; -} - -// Optimize nchw2nhwc && nhwc2nchw with tiling and shared memory. -// Firstly, combined 2 dims hw together, treat input and output as 3D tensor. -// Secondly, determine whether a matrix is a large matrix or a narrow matrix, -// which determines the chosen TileSize. -// Reason: tiling and shared memory can avoid uncoalesced global memory access. -// There are two stages of this kernel, load-to-shm and write-to-output. -// load-to-shm: Threads in a thread block work together to load input data tile to shared mem. -// write-to-output: Threads in a thread block work together to write shared mem to output tile. -// because of the shared mem usage, The access to both input and output memory can be coalesced. - -// SimpleTransposeKernel for small matrix - -__forceinline__ __device__ int TensorIdxToOneDimIdx(int ndims, const int *idx, const int *dims) { - int flat_idx = idx[0]; - for (int i = 1; i < ndims; i++) { - flat_idx = flat_idx * dims[i] + idx[i]; - } - return flat_idx; -} - -__forceinline__ __device__ void OneDimIdxToTensorIdx(int ndims, int idx, const int *dims, int *out_tensor_idx) { - for (int i = ndims - 1; i >= 0; i--) { - int new_idx = idx / dims[i]; - out_tensor_idx[i] = idx - dims[i] * new_idx; - idx = new_idx; - } -} - -template -__global__ void Transpose3DTensorSimple(const T *__restrict__ input, const size_t size, const int64_t dim0, - const int64_t dim1, const int64_t dim2, T *__restrict__ output) { - int output_shape[3]{0, 0, 0}; - output_shape[perm0] = dim0; - output_shape[perm1] = dim1; - output_shape[perm2] = dim2; - for (int output_pos = blockIdx.x * blockDim.x + threadIdx.x; output_pos < size; - output_pos += gridDim.x * blockDim.x) { - int output_tensor_index[3]{0, 0, 0}; - OneDimIdxToTensorIdx(3, output_pos, output_shape, output_tensor_index); - int input_tensor_index[3]{0, 0, 0}; - int input_shape[3]{static_cast(dim0), static_cast(dim1), static_cast(dim2)}; - input_tensor_index[0] = output_tensor_index[perm0]; - input_tensor_index[1] = output_tensor_index[perm1]; - input_tensor_index[2] = output_tensor_index[perm2]; - int input_pos = TensorIdxToOneDimIdx(3, input_tensor_index, input_shape); - output[output_pos] = input[input_pos]; - } -} - -template -__global__ void Swap3DTensorLast2DimKernel(const T *input, int NumThreads, int TileHeight, int TileWidth, - int input_dims_0, int input_dims_1, int input_dims_2, T *output) { - extern __shared__ unsigned char sdata_uchar[]; - // shm_tile[TileHeight][TileWidth + 1]: to avoid bank conflict in write-to-output period - T *shm_tile = reinterpret_cast(sdata_uchar); - int NumRowsPerLoadLoop = NumThreads / TileWidth; // the number of shm rows that all threads can load into shm once - int NumColsPerWriteLoop = - NumThreads / TileHeight; // the number of shm cols that all threads can write into output once - int load_thread_num_align = NumRowsPerLoadLoop * TileWidth; // use align num threads in load-to-shm period - int write_thread_num_align = NumColsPerWriteLoop * TileHeight; // use align num threads in write-to-output period - int tid = threadIdx.x; - int input_dims[3] = {input_dims_0, input_dims_1, input_dims_2}; - int output_dims[3] = {input_dims[0], input_dims[2], input_dims[1]}; - int input_dims_in_tiles[3] = {input_dims[0], (input_dims[1] + TileHeight - 1) / TileHeight, - (input_dims[2] + TileWidth - 1) / TileWidth}; - int input_tile_idx[3]; - OneDimIdxToTensorIdx(3, blockIdx.x, input_dims_in_tiles, input_tile_idx); - int input_tile_origin[3] = {input_tile_idx[0], input_tile_idx[1] * TileHeight, input_tile_idx[2] * TileWidth}; - int input_block_start_idx = TensorIdxToOneDimIdx(3, input_tile_origin, input_dims); // input idx of this thread block - bool full_tile = true; - int tile_width = TileWidth; - // Only the last row or column may not have the full size - // boundary process - if (input_tile_idx[2] == input_dims_in_tiles[2] - 1) { - tile_width = input_dims[2] - (input_dims_in_tiles[2] - 1) * TileWidth; - full_tile &= false; - } - int tile_height = TileHeight; - if (input_tile_idx[1] == input_dims_in_tiles[1] - 1) { - tile_height = input_dims[1] - (input_dims_in_tiles[1] - 1) * TileHeight; - full_tile &= false; - } - // load-to-shm: each block load input data into shared mem(loop) - if (tid < load_thread_num_align) { - // Map task blocks to thread blocks. - // organize threads to n*TileWidth - int shm_row_id = tid / TileWidth; // shem_row_id, also the block row_id of input - int shm_col_id = tid % TileWidth; // shem_col_id, also the block col_id of input - int input_idx = input_block_start_idx + shm_row_id * input_dims[2] + shm_col_id; // the input idx of this thread - int input_step = NumRowsPerLoadLoop * input_dims[2]; - if (full_tile) { // thread blocks responses for inner tiles -#pragma unroll - for (int row_id = shm_row_id; row_id < (TileHeight); - row_id += NumRowsPerLoadLoop) { // move to the next pass, loop - // shm_tile[row_id][shm_col_id] - shm_tile[row_id * (TileWidth + 1) + shm_col_id] = - input[input_idx]; // each thread load one input data into shared mem - input_idx += input_step; // calculate the next input idx this thread should load - } - } else { // boundary process: thread blocks responses for edge tiles - if (shm_col_id < tile_width) { - for (int row_id = shm_row_id; row_id < (tile_height); row_id += NumRowsPerLoadLoop) { - // shm_tile[row_id][shm_col_id] - shm_tile[row_id * (TileWidth + 1) + shm_col_id] = input[input_idx]; - input_idx += input_step; - } - } - } - } - __syncthreads(); - // load-to-shm: end - - // write-to-output: each block write shared mem into output(loop) - int output_tile_idx[3] = {input_tile_idx[0], input_tile_idx[2], input_tile_idx[1]}; - int output_tile_origin[3] = {output_tile_idx[0], output_tile_idx[1] * TileWidth, output_tile_idx[2] * TileHeight}; - int output_block_start_idx = TensorIdxToOneDimIdx(3, output_tile_origin, output_dims); - if (tid < write_thread_num_align) { - // organize threads to TileHeight*n1 - int shm_col_id = tid / TileHeight; // shm_col_id, also the block row_id of output - int shm_row_id = tid % TileHeight; // shm_row_id, also the block col_id of output - int output_idx = output_block_start_idx + shm_col_id * output_dims[2] + shm_row_id; - int output_step = NumColsPerWriteLoop * output_dims[2]; - if (full_tile) { -#pragma unroll - for (int col_id = shm_col_id; col_id < (TileWidth); - col_id += NumColsPerWriteLoop) { // move to the next pass, loop - // shm_tile[shm_row_id][col_id] - output[output_idx] = shm_tile[shm_row_id * (TileWidth + 1) + col_id]; // avoid bank conflict - output_idx += output_step; - } - } else { - if (shm_row_id < tile_height) { - for (int col_id = shm_col_id; col_id < (tile_width); col_id += NumColsPerWriteLoop) { - // shm_tile[shm_row_id][col_id]; - output[output_idx] = shm_tile[shm_row_id * (TileWidth + 1) + col_id]; - output_idx += output_step; - } - } - } - } -} - -template -__global__ void Transpose3DTensorSimpleVector(const T *__restrict__ input, size_t size, const int64_t dim0, - const int64_t dim1, const int64_t dim2, T *__restrict__ output) { - int output_shape[3]{0, 0, 0}; - output_shape[perm0] = dim0; - output_shape[perm1] = dim1; - output_shape[perm2] = dim2; - - const int stride = blockDim.x * gridDim.x * kUnroll; - const int tid = blockIdx.x * blockDim.x + threadIdx.x; - T vec[kUnroll]; - int output_pos; - for (output_pos = tid * kUnroll; output_pos + kUnroll - 1 < size; output_pos += stride) { -#pragma unroll - for (int i = 0; i < kUnroll; ++i) { - int outpos_pos_i = output_pos + i; - int output_tensor_index[3]{0, 0, 0}; - OneDimIdxToTensorIdx(3, outpos_pos_i, output_shape, output_tensor_index); - int input_tensor_index[3]{0, 0, 0}; - int input_shape[3]{static_cast(dim0), static_cast(dim1), static_cast(dim2)}; - input_tensor_index[0] = output_tensor_index[perm0]; - input_tensor_index[1] = output_tensor_index[perm1]; - input_tensor_index[2] = output_tensor_index[perm2]; - int input_pos_i = TensorIdxToOneDimIdx(3, input_tensor_index, input_shape); - vec[i] = input[input_pos_i]; - } - float2 *out = reinterpret_cast(output + output_pos); - *out = *reinterpret_cast(vec); - } - - for (; output_pos < size; ++output_pos) { - int output_tensor_index[3]{0, 0, 0}; - OneDimIdxToTensorIdx(3, output_pos, output_shape, output_tensor_index); - int input_tensor_index[3]{0, 0, 0}; - int input_shape[3]{static_cast(dim0), static_cast(dim1), static_cast(dim2)}; - input_tensor_index[0] = output_tensor_index[perm0]; - input_tensor_index[1] = output_tensor_index[perm1]; - input_tensor_index[2] = output_tensor_index[perm2]; - int input_pos = TensorIdxToOneDimIdx(3, input_tensor_index, input_shape); - output[output_pos] = input[input_pos]; - } -} - -template -void Swap3DTensorLast2Dim(const T *input, const int64_t dim0, const int64_t dim1, const int64_t dim2, T *output, - cudaStream_t cuda_stream) { - static const int kMinDimensionToUseTiles = 16; - static const int kMinDimensionToUseRectTiles = 96; - auto short_side = std::min(dim1, dim2); - auto long_side = std::max(dim1, dim2); - // large matrix - // Both dims are greater than 16 && cuda blocks have enough shared mem. - constexpr int kTileSizeLargeMat = 32; - constexpr int kNumThreadsLargeMat = 256; - auto ShmemReqLargeMat = kTileSizeLargeMat * (kTileSizeLargeMat + 1) * sizeof(T); - bool is_large_matrix = short_side >= kMinDimensionToUseTiles && ShmemReqLargeMat <= SHARED_MEM_PER_BLOCK; - // narrow matrix - // one dim less than 16 && one dim greater than 96(narrow) - constexpr int kTileSizeNarrowMatLongSide = 128; - const int kTileSizeNarrowMatShortSide = short_side; - constexpr int kNumThreadsNarrowMat = kTileSizeNarrowMatLongSide; - auto ShmemReqNarrowMat = kTileSizeNarrowMatLongSide * (kTileSizeNarrowMatShortSide + 1) * sizeof(T); - bool is_narrow_matrix = short_side < kMinDimensionToUseTiles && long_side >= kMinDimensionToUseRectTiles && - ShmemReqNarrowMat <= SHARED_MEM_PER_BLOCK; - if (is_large_matrix) { - int64_t input_dims_in_tiles[3]{dim0, (dim1 + kTileSizeLargeMat - 1) / kTileSizeLargeMat, - (dim2 + kTileSizeLargeMat - 1) / kTileSizeLargeMat}; - int TotalNumTiles = input_dims_in_tiles[0] * input_dims_in_tiles[1] * input_dims_in_tiles[2]; - Swap3DTensorLast2DimKernel<<>>( - input, kNumThreadsLargeMat, kTileSizeLargeMat, kTileSizeLargeMat, dim0, dim1, dim2, output); - } else if (is_narrow_matrix) { - int64_t input_dims_in_tiles[3]{dim0, 1, (long_side + kTileSizeNarrowMatLongSide - 1) / kTileSizeNarrowMatLongSide}; - int TotalNumTiles = input_dims_in_tiles[0] * input_dims_in_tiles[1] * input_dims_in_tiles[2]; - int TileHeight, TileWidth; - if (long_side == dim1) { - TileHeight = kTileSizeNarrowMatLongSide; - TileWidth = short_side; - } else { - TileHeight = short_side; - TileWidth = kTileSizeNarrowMatLongSide; - } - Swap3DTensorLast2DimKernel<<>>( - input, kNumThreadsNarrowMat, TileHeight, TileWidth, dim0, dim1, dim2, output); - } else { - size_t size = static_cast(dim0 * dim1 * dim2); - Transpose3DTensorSimple - <<>>(input, size, dim0, dim1, dim2, output); - } - return; -} - -template -void Swap3DTensorDim0and2(const T *input, const int64_t dim0, const int64_t dim1, const int64_t dim2, T *output, - cudaStream_t cuda_stream) { - size_t size = dim0 * dim1 * dim2; - auto out_ptr = reinterpret_cast(output); - bool aligned = (out_ptr % 16 == 0); // Is aligned with 16 bits(2 bytes)? - bool use_vector{false}, is_custom{false}; - if ((dim0 <= 128 && dim2 <= 128) || dim0 * dim1 <= 128 || dim1 * dim2 <= 8) { - use_vector = is_custom = true; - } else if (dim1 * dim2 <= 16384) { - use_vector = true; - } - if (sizeof(T) == 2 && aligned && use_vector) { - int grid_size; - if (is_custom) { - grid_size = (size + GET_THREADS - 1) / GET_THREADS; - } else { - grid_size = GET_BLOCKS(size); - } - Transpose3DTensorSimpleVector - <<>>(input, size, dim0, dim1, dim2, output); - } else { - Transpose3DTensorSimple - <<>>(input, size, dim0, dim1, dim2, output); - } - - return; -} - -template -cudaError_t CalTranspose(const size_t size, const T *input, const TransposeInfo &info, T *output, - cudaStream_t cuda_stream) { - std::vector new_shape{0}; - std::vector new_perm{0}; - - if (need_simplify) { - SimplifyTranspose(info.input_shape, info.perm, &new_shape, &new_perm); - } else { - new_shape = info.input_shape; - new_perm = info.perm; - } - - if (TransposeUsingTile(input, new_shape, new_perm, output, cuda_stream)) { - return GetCudaStatus(); - } - - TransposeInfoDevice transpose_info_device; - int32_t input_stride[kDimSize]; - int32_t output_stride[kDimSize]; - ComputeInputStride(new_shape, input_stride); - ComputeOutputStride(new_shape, new_perm, output_stride); - - for (size_t i = 0; i < new_shape.size(); ++i) { - transpose_info_device.transpose_info_device[i] = input_stride[i]; - transpose_info_device.transpose_info_device[i + stride_ndims] = output_stride[i]; - transpose_info_device.transpose_info_device[i + stride_ndims * 2] = new_perm[i]; - } - TransposeKernel<<>>(input, size, transpose_info_device, - new_shape.size(), output); - return GetCudaStatus(); -} - -#define REGISTER_CALTRANSPOSE(T, NEED_SIMPLIFY) \ - template CUDA_LIB_EXPORT cudaError_t CalTranspose( \ - const size_t size, const T *input, const TransposeInfo &info, T *output, cudaStream_t cuda_stream) - -#define REGISTER_BOTH_CALTRANSPOSE(T) \ - REGISTER_CALTRANSPOSE(T, true); \ - REGISTER_CALTRANSPOSE(T, false) - -REGISTER_BOTH_CALTRANSPOSE(bool); -REGISTER_BOTH_CALTRANSPOSE(int8_t); -REGISTER_BOTH_CALTRANSPOSE(int16_t); -REGISTER_BOTH_CALTRANSPOSE(int32_t); -REGISTER_BOTH_CALTRANSPOSE(int64_t); -REGISTER_BOTH_CALTRANSPOSE(uint8_t); -REGISTER_BOTH_CALTRANSPOSE(uint16_t); -REGISTER_BOTH_CALTRANSPOSE(uint32_t); -REGISTER_BOTH_CALTRANSPOSE(uint64_t); -REGISTER_BOTH_CALTRANSPOSE(half); -REGISTER_BOTH_CALTRANSPOSE(float); -REGISTER_BOTH_CALTRANSPOSE(double); -REGISTER_BOTH_CALTRANSPOSE(Complex); -REGISTER_BOTH_CALTRANSPOSE(Complex); +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include "include/cuda_fp16.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl.cuh" + +template +using Complex = mindspore::utils::Complex; + +template +__global__ void TransposeKernel(const T *__restrict__ input, const size_t size, const TransposeInfoDevice info, + const int ndims, T *__restrict__ output) { + const int32_t *in_strides = info.transpose_info_device; + const int32_t *out_strides = info.transpose_info_device + stride_ndims; + const int32_t *perm = info.transpose_info_device + stride_ndims * 2; + for (int output_pos = blockDim.x * blockIdx.x + threadIdx.x; output_pos < size; + output_pos += blockDim.x * gridDim.x) { + int32_t input_pos = 0; + int32_t temp = output_pos; + for (int i = 0; i < ndims; ++i) { + const int32_t ratio = temp / out_strides[i]; + temp -= ratio * out_strides[i]; + input_pos += ratio * in_strides[perm[i]]; + } + output[output_pos] = input[input_pos]; + } +} + +template +bool TransposeUsingTile(const T *input, const std::vector &shape, const std::vector &perm, T *output, + cudaStream_t cuda_stream) { + int dims = shape.size(); + if (dims < 2 || dims > 3) { + return false; + } + switch (dims) { + case 2: + if (perm[0] == 1 && perm[1] == 0) { + Swap3DTensorLast2Dim(input, (int64_t)1, shape[0], shape[1], output, cuda_stream); + return true; + } + break; + case 3: + if (perm == std::vector{0, 2, 1}) { + Swap3DTensorLast2Dim(input, shape[0], shape[1], shape[2], output, cuda_stream); + return true; + } else if (perm == std::vector{2, 1, 0}) { + Swap3DTensorDim0and2(input, shape[0], shape[1], shape[2], output, cuda_stream); + return true; + } else { + // Do not support other 3D Transpose. + return false; + } + break; + default: + return false; + } + return false; +} + +// Optimize nchw2nhwc && nhwc2nchw with tiling and shared memory. +// Firstly, combined 2 dims hw together, treat input and output as 3D tensor. +// Secondly, determine whether a matrix is a large matrix or a narrow matrix, +// which determines the chosen TileSize. +// Reason: tiling and shared memory can avoid uncoalesced global memory access. +// There are two stages of this kernel, load-to-shm and write-to-output. +// load-to-shm: Threads in a thread block work together to load input data tile to shared mem. +// write-to-output: Threads in a thread block work together to write shared mem to output tile. +// because of the shared mem usage, The access to both input and output memory can be coalesced. + +// SimpleTransposeKernel for small matrix + +__forceinline__ __device__ int TensorIdxToOneDimIdx(int ndims, const int *idx, const int *dims) { + int flat_idx = idx[0]; + for (int i = 1; i < ndims; i++) { + flat_idx = flat_idx * dims[i] + idx[i]; + } + return flat_idx; +} + +__forceinline__ __device__ void OneDimIdxToTensorIdx(int ndims, int idx, const int *dims, int *out_tensor_idx) { + for (int i = ndims - 1; i >= 0; i--) { + int new_idx = idx / dims[i]; + out_tensor_idx[i] = idx - dims[i] * new_idx; + idx = new_idx; + } +} + +template +__global__ void Transpose3DTensorSimple(const T *__restrict__ input, const size_t size, const int64_t dim0, + const int64_t dim1, const int64_t dim2, T *__restrict__ output) { + int output_shape[3]{0, 0, 0}; + output_shape[perm0] = dim0; + output_shape[perm1] = dim1; + output_shape[perm2] = dim2; + for (int output_pos = blockIdx.x * blockDim.x + threadIdx.x; output_pos < size; + output_pos += gridDim.x * blockDim.x) { + int output_tensor_index[3]{0, 0, 0}; + OneDimIdxToTensorIdx(3, output_pos, output_shape, output_tensor_index); + int input_tensor_index[3]{0, 0, 0}; + int input_shape[3]{static_cast(dim0), static_cast(dim1), static_cast(dim2)}; + input_tensor_index[0] = output_tensor_index[perm0]; + input_tensor_index[1] = output_tensor_index[perm1]; + input_tensor_index[2] = output_tensor_index[perm2]; + int input_pos = TensorIdxToOneDimIdx(3, input_tensor_index, input_shape); + output[output_pos] = input[input_pos]; + } +} + +template +__global__ void Swap3DTensorLast2DimKernel(const T *input, int NumThreads, int TileHeight, int TileWidth, + int input_dims_0, int input_dims_1, int input_dims_2, T *output) { + extern __shared__ unsigned char sdata_uchar[]; + // shm_tile[TileHeight][TileWidth + 1]: to avoid bank conflict in write-to-output period + T *shm_tile = reinterpret_cast(sdata_uchar); + int NumRowsPerLoadLoop = NumThreads / TileWidth; // the number of shm rows that all threads can load into shm once + int NumColsPerWriteLoop = + NumThreads / TileHeight; // the number of shm cols that all threads can write into output once + int load_thread_num_align = NumRowsPerLoadLoop * TileWidth; // use align num threads in load-to-shm period + int write_thread_num_align = NumColsPerWriteLoop * TileHeight; // use align num threads in write-to-output period + int tid = threadIdx.x; + int input_dims[3] = {input_dims_0, input_dims_1, input_dims_2}; + int output_dims[3] = {input_dims[0], input_dims[2], input_dims[1]}; + int input_dims_in_tiles[3] = {input_dims[0], (input_dims[1] + TileHeight - 1) / TileHeight, + (input_dims[2] + TileWidth - 1) / TileWidth}; + int input_tile_idx[3]; + OneDimIdxToTensorIdx(3, blockIdx.x, input_dims_in_tiles, input_tile_idx); + int input_tile_origin[3] = {input_tile_idx[0], input_tile_idx[1] * TileHeight, input_tile_idx[2] * TileWidth}; + int input_block_start_idx = TensorIdxToOneDimIdx(3, input_tile_origin, input_dims); // input idx of this thread block + bool full_tile = true; + int tile_width = TileWidth; + // Only the last row or column may not have the full size + // boundary process + if (input_tile_idx[2] == input_dims_in_tiles[2] - 1) { + tile_width = input_dims[2] - (input_dims_in_tiles[2] - 1) * TileWidth; + full_tile &= false; + } + int tile_height = TileHeight; + if (input_tile_idx[1] == input_dims_in_tiles[1] - 1) { + tile_height = input_dims[1] - (input_dims_in_tiles[1] - 1) * TileHeight; + full_tile &= false; + } + // load-to-shm: each block load input data into shared mem(loop) + if (tid < load_thread_num_align) { + // Map task blocks to thread blocks. + // organize threads to n*TileWidth + int shm_row_id = tid / TileWidth; // shem_row_id, also the block row_id of input + int shm_col_id = tid % TileWidth; // shem_col_id, also the block col_id of input + int input_idx = input_block_start_idx + shm_row_id * input_dims[2] + shm_col_id; // the input idx of this thread + int input_step = NumRowsPerLoadLoop * input_dims[2]; + if (full_tile) { // thread blocks responses for inner tiles +#pragma unroll + for (int row_id = shm_row_id; row_id < (TileHeight); + row_id += NumRowsPerLoadLoop) { // move to the next pass, loop + // shm_tile[row_id][shm_col_id] + shm_tile[row_id * (TileWidth + 1) + shm_col_id] = + input[input_idx]; // each thread load one input data into shared mem + input_idx += input_step; // calculate the next input idx this thread should load + } + } else { // boundary process: thread blocks responses for edge tiles + if (shm_col_id < tile_width) { + for (int row_id = shm_row_id; row_id < (tile_height); row_id += NumRowsPerLoadLoop) { + // shm_tile[row_id][shm_col_id] + shm_tile[row_id * (TileWidth + 1) + shm_col_id] = input[input_idx]; + input_idx += input_step; + } + } + } + } + __syncthreads(); + // load-to-shm: end + + // write-to-output: each block write shared mem into output(loop) + int output_tile_idx[3] = {input_tile_idx[0], input_tile_idx[2], input_tile_idx[1]}; + int output_tile_origin[3] = {output_tile_idx[0], output_tile_idx[1] * TileWidth, output_tile_idx[2] * TileHeight}; + int output_block_start_idx = TensorIdxToOneDimIdx(3, output_tile_origin, output_dims); + if (tid < write_thread_num_align) { + // organize threads to TileHeight*n1 + int shm_col_id = tid / TileHeight; // shm_col_id, also the block row_id of output + int shm_row_id = tid % TileHeight; // shm_row_id, also the block col_id of output + int output_idx = output_block_start_idx + shm_col_id * output_dims[2] + shm_row_id; + int output_step = NumColsPerWriteLoop * output_dims[2]; + if (full_tile) { +#pragma unroll + for (int col_id = shm_col_id; col_id < (TileWidth); + col_id += NumColsPerWriteLoop) { // move to the next pass, loop + // shm_tile[shm_row_id][col_id] + output[output_idx] = shm_tile[shm_row_id * (TileWidth + 1) + col_id]; // avoid bank conflict + output_idx += output_step; + } + } else { + if (shm_row_id < tile_height) { + for (int col_id = shm_col_id; col_id < (tile_width); col_id += NumColsPerWriteLoop) { + // shm_tile[shm_row_id][col_id]; + output[output_idx] = shm_tile[shm_row_id * (TileWidth + 1) + col_id]; + output_idx += output_step; + } + } + } + } +} + +template +__global__ void Transpose3DTensorSimpleVector(const T *__restrict__ input, size_t size, const int64_t dim0, + const int64_t dim1, const int64_t dim2, T *__restrict__ output) { + int output_shape[3]{0, 0, 0}; + output_shape[perm0] = dim0; + output_shape[perm1] = dim1; + output_shape[perm2] = dim2; + + const int stride = blockDim.x * gridDim.x * kUnroll; + const int tid = blockIdx.x * blockDim.x + threadIdx.x; + T vec[kUnroll]; + int output_pos; + for (output_pos = tid * kUnroll; output_pos + kUnroll - 1 < size; output_pos += stride) { +#pragma unroll + for (int i = 0; i < kUnroll; ++i) { + int outpos_pos_i = output_pos + i; + int output_tensor_index[3]{0, 0, 0}; + OneDimIdxToTensorIdx(3, outpos_pos_i, output_shape, output_tensor_index); + int input_tensor_index[3]{0, 0, 0}; + int input_shape[3]{static_cast(dim0), static_cast(dim1), static_cast(dim2)}; + input_tensor_index[0] = output_tensor_index[perm0]; + input_tensor_index[1] = output_tensor_index[perm1]; + input_tensor_index[2] = output_tensor_index[perm2]; + int input_pos_i = TensorIdxToOneDimIdx(3, input_tensor_index, input_shape); + vec[i] = input[input_pos_i]; + } + float2 *out = reinterpret_cast(output + output_pos); + *out = *reinterpret_cast(vec); + } + + for (; output_pos < size; ++output_pos) { + int output_tensor_index[3]{0, 0, 0}; + OneDimIdxToTensorIdx(3, output_pos, output_shape, output_tensor_index); + int input_tensor_index[3]{0, 0, 0}; + int input_shape[3]{static_cast(dim0), static_cast(dim1), static_cast(dim2)}; + input_tensor_index[0] = output_tensor_index[perm0]; + input_tensor_index[1] = output_tensor_index[perm1]; + input_tensor_index[2] = output_tensor_index[perm2]; + int input_pos = TensorIdxToOneDimIdx(3, input_tensor_index, input_shape); + output[output_pos] = input[input_pos]; + } +} + +template +void Swap3DTensorLast2Dim(const T *input, const int64_t dim0, const int64_t dim1, const int64_t dim2, T *output, + cudaStream_t cuda_stream) { + static const int kMinDimensionToUseTiles = 16; + static const int kMinDimensionToUseRectTiles = 96; + auto short_side = std::min(dim1, dim2); + auto long_side = std::max(dim1, dim2); + // large matrix + // Both dims are greater than 16 && cuda blocks have enough shared mem. + constexpr int kTileSizeLargeMat = 32; + constexpr int kNumThreadsLargeMat = 256; + auto ShmemReqLargeMat = kTileSizeLargeMat * (kTileSizeLargeMat + 1) * sizeof(T); + bool is_large_matrix = short_side >= kMinDimensionToUseTiles && ShmemReqLargeMat <= SHARED_MEM_PER_BLOCK; + // narrow matrix + // one dim less than 16 && one dim greater than 96(narrow) + constexpr int kTileSizeNarrowMatLongSide = 128; + const int kTileSizeNarrowMatShortSide = short_side; + constexpr int kNumThreadsNarrowMat = kTileSizeNarrowMatLongSide; + auto ShmemReqNarrowMat = kTileSizeNarrowMatLongSide * (kTileSizeNarrowMatShortSide + 1) * sizeof(T); + bool is_narrow_matrix = short_side < kMinDimensionToUseTiles && long_side >= kMinDimensionToUseRectTiles && + ShmemReqNarrowMat <= SHARED_MEM_PER_BLOCK; + if (is_large_matrix) { + int64_t input_dims_in_tiles[3]{dim0, (dim1 + kTileSizeLargeMat - 1) / kTileSizeLargeMat, + (dim2 + kTileSizeLargeMat - 1) / kTileSizeLargeMat}; + int TotalNumTiles = input_dims_in_tiles[0] * input_dims_in_tiles[1] * input_dims_in_tiles[2]; + Swap3DTensorLast2DimKernel<<>>( + input, kNumThreadsLargeMat, kTileSizeLargeMat, kTileSizeLargeMat, dim0, dim1, dim2, output); + } else if (is_narrow_matrix) { + int64_t input_dims_in_tiles[3]{dim0, 1, (long_side + kTileSizeNarrowMatLongSide - 1) / kTileSizeNarrowMatLongSide}; + int TotalNumTiles = input_dims_in_tiles[0] * input_dims_in_tiles[1] * input_dims_in_tiles[2]; + int TileHeight, TileWidth; + if (long_side == dim1) { + TileHeight = kTileSizeNarrowMatLongSide; + TileWidth = short_side; + } else { + TileHeight = short_side; + TileWidth = kTileSizeNarrowMatLongSide; + } + Swap3DTensorLast2DimKernel<<>>( + input, kNumThreadsNarrowMat, TileHeight, TileWidth, dim0, dim1, dim2, output); + } else { + size_t size = static_cast(dim0 * dim1 * dim2); + Transpose3DTensorSimple + <<>>(input, size, dim0, dim1, dim2, output); + } + return; +} + +template +void Swap3DTensorDim0and2(const T *input, const int64_t dim0, const int64_t dim1, const int64_t dim2, T *output, + cudaStream_t cuda_stream) { + size_t size = dim0 * dim1 * dim2; + auto out_ptr = reinterpret_cast(output); + bool aligned = (out_ptr % 16 == 0); // Is aligned with 16 bits(2 bytes)? + bool use_vector{false}, is_custom{false}; + if ((dim0 <= 128 && dim2 <= 128) || dim0 * dim1 <= 128 || dim1 * dim2 <= 8) { + use_vector = is_custom = true; + } else if (dim1 * dim2 <= 16384) { + use_vector = true; + } + if (sizeof(T) == 2 && aligned && use_vector) { + int grid_size; + if (is_custom) { + grid_size = (size + GET_THREADS - 1) / GET_THREADS; + } else { + grid_size = GET_BLOCKS(size); + } + Transpose3DTensorSimpleVector + <<>>(input, size, dim0, dim1, dim2, output); + } else { + Transpose3DTensorSimple + <<>>(input, size, dim0, dim1, dim2, output); + } + + return; +} + +template +cudaError_t CalTranspose(const size_t size, const T *input, const TransposeInfo &info, T *output, + cudaStream_t cuda_stream) { + std::vector new_shape{0}; + std::vector new_perm{0}; + + if (need_simplify) { + SimplifyTranspose(info.input_shape, info.perm, &new_shape, &new_perm); + } else { + new_shape = info.input_shape; + new_perm = info.perm; + } + + if (TransposeUsingTile(input, new_shape, new_perm, output, cuda_stream)) { + return GetCudaStatus(); + } + + TransposeInfoDevice transpose_info_device; + int32_t input_stride[kDimSize]; + int32_t output_stride[kDimSize]; + ComputeInputStride(new_shape, input_stride); + ComputeOutputStride(new_shape, new_perm, output_stride); + + for (size_t i = 0; i < new_shape.size(); ++i) { + transpose_info_device.transpose_info_device[i] = input_stride[i]; + transpose_info_device.transpose_info_device[i + stride_ndims] = output_stride[i]; + transpose_info_device.transpose_info_device[i + stride_ndims * 2] = new_perm[i]; + } + TransposeKernel<<>>(input, size, transpose_info_device, + new_shape.size(), output); + return GetCudaStatus(); +} + +#define REGISTER_CALTRANSPOSE(T, NEED_SIMPLIFY) \ + template CUDA_LIB_EXPORT cudaError_t CalTranspose( \ + const size_t size, const T *input, const TransposeInfo &info, T *output, cudaStream_t cuda_stream) + +#define REGISTER_BOTH_CALTRANSPOSE(T) \ + REGISTER_CALTRANSPOSE(T, true); \ + REGISTER_CALTRANSPOSE(T, false) + +REGISTER_BOTH_CALTRANSPOSE(bool); +REGISTER_BOTH_CALTRANSPOSE(int8_t); +REGISTER_BOTH_CALTRANSPOSE(int16_t); +REGISTER_BOTH_CALTRANSPOSE(int32_t); +REGISTER_BOTH_CALTRANSPOSE(int64_t); +REGISTER_BOTH_CALTRANSPOSE(uint8_t); +REGISTER_BOTH_CALTRANSPOSE(uint16_t); +REGISTER_BOTH_CALTRANSPOSE(uint32_t); +REGISTER_BOTH_CALTRANSPOSE(uint64_t); +REGISTER_BOTH_CALTRANSPOSE(half); +REGISTER_BOTH_CALTRANSPOSE(float); +REGISTER_BOTH_CALTRANSPOSE(double); +REGISTER_BOTH_CALTRANSPOSE(Complex); +REGISTER_BOTH_CALTRANSPOSE(Complex); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl.cuh old mode 100755 new mode 100644 index 3564dfec54a..63868830748 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl.cuh @@ -1,102 +1,102 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_TRANSPOSE_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_TRANSPOSE_IMPL_CUH_ - -#include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" - -constexpr int kUnroll = 4; // Size of vector. -constexpr int kInfoDims = 78; // Max TransposeInfoDevice length -constexpr int stride_ndims = 26; // Max length of input_shape -constexpr int transpose_max_dimension = 26; // Max dimension of input -constexpr int kDimSize = 26; - -struct TransposeInfo { - std::vector input_shape; - std::vector perm; -}; - -struct TransposeInfoDevice { - int32_t transpose_info_device[kInfoDims]; -}; - -inline void ComputeInputStride(const std::vector &shape, int32_t *strides) { - const int ndims = shape.size(); - int32_t stride = 1; - for (int i = ndims - 1; i >= 0; --i) { - strides[i] = stride; - stride *= static_cast(shape[i]); - } -} - -inline void ComputeOutputStride(const std::vector &shape, const std::vector &perm, int32_t *strides) { - const int ndims = shape.size(); - int32_t stride = 1; - for (int i = ndims - 1; i >= 0; --i) { - strides[i] = stride; - stride *= static_cast(shape[perm[i]]); - } -} - -inline void SimplifyTranspose(const std::vector &input_shape, const std::vector &input_perm, - std::vector *new_shape, std::vector *new_perm) { - auto input_shape_size = input_shape.size(); - std::vector combined_shape(input_shape_size, 0); - std::vector new_perm_position(input_shape_size, -1); - int32_t cur_dim = input_perm[0]; - new_perm_position[cur_dim] = 0; - combined_shape[0] = input_shape[cur_dim]; - int dim_index = 0; - for (size_t perm_index = 1; perm_index < input_shape_size; ++perm_index) { - if (input_perm[perm_index] == cur_dim + 1) { - cur_dim = input_perm[perm_index]; - combined_shape[dim_index] *= input_shape[cur_dim]; - } else { - cur_dim = input_perm[perm_index]; - dim_index++; - new_perm_position[cur_dim] = dim_index; - combined_shape[dim_index] = input_shape[cur_dim]; - } - } - new_shape->resize(dim_index + 1); - std::vector new_perm_temp(dim_index + 1, 0); - new_perm->resize(dim_index + 1); - dim_index = 0; - - for (size_t i = 0; i < new_perm_position.size(); ++i) { - if (new_perm_position[i] >= 0) { - int new_perm_index = new_perm_position[i]; - (*new_shape)[dim_index] = combined_shape[new_perm_index]; - new_perm_temp[dim_index] = new_perm_index; - dim_index++; - } - } - for (int i = 0; i < dim_index + 1; ++i) { - auto ret = std::find_if(new_perm_temp.begin(), new_perm_temp.end(), [&](int x) { return x == i; }); - if (ret != new_perm_temp.end()) { - (*new_perm)[i] = ret - new_perm_temp.begin(); - } - } -} - -template -CUDA_LIB_EXPORT cudaError_t CalTranspose(const size_t size, const T *input, const TransposeInfo &info, T *output, - cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_TRANSPOSE_IMPL_CUH_ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_TRANSPOSE_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_TRANSPOSE_IMPL_CUH_ + +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" + +constexpr int kUnroll = 4; // Size of vector. +constexpr int kInfoDims = 78; // Max TransposeInfoDevice length +constexpr int stride_ndims = 26; // Max length of input_shape +constexpr int transpose_max_dimension = 26; // Max dimension of input +constexpr int kDimSize = 26; + +struct TransposeInfo { + std::vector input_shape; + std::vector perm; +}; + +struct TransposeInfoDevice { + int32_t transpose_info_device[kInfoDims]; +}; + +inline void ComputeInputStride(const std::vector &shape, int32_t *strides) { + const int ndims = shape.size(); + int32_t stride = 1; + for (int i = ndims - 1; i >= 0; --i) { + strides[i] = stride; + stride *= static_cast(shape[i]); + } +} + +inline void ComputeOutputStride(const std::vector &shape, const std::vector &perm, int32_t *strides) { + const int ndims = shape.size(); + int32_t stride = 1; + for (int i = ndims - 1; i >= 0; --i) { + strides[i] = stride; + stride *= static_cast(shape[perm[i]]); + } +} + +inline void SimplifyTranspose(const std::vector &input_shape, const std::vector &input_perm, + std::vector *new_shape, std::vector *new_perm) { + auto input_shape_size = input_shape.size(); + std::vector combined_shape(input_shape_size, 0); + std::vector new_perm_position(input_shape_size, -1); + int32_t cur_dim = input_perm[0]; + new_perm_position[cur_dim] = 0; + combined_shape[0] = input_shape[cur_dim]; + int dim_index = 0; + for (size_t perm_index = 1; perm_index < input_shape_size; ++perm_index) { + if (input_perm[perm_index] == cur_dim + 1) { + cur_dim = input_perm[perm_index]; + combined_shape[dim_index] *= input_shape[cur_dim]; + } else { + cur_dim = input_perm[perm_index]; + dim_index++; + new_perm_position[cur_dim] = dim_index; + combined_shape[dim_index] = input_shape[cur_dim]; + } + } + new_shape->resize(dim_index + 1); + std::vector new_perm_temp(dim_index + 1, 0); + new_perm->resize(dim_index + 1); + dim_index = 0; + + for (size_t i = 0; i < new_perm_position.size(); ++i) { + if (new_perm_position[i] >= 0) { + int new_perm_index = new_perm_position[i]; + (*new_shape)[dim_index] = combined_shape[new_perm_index]; + new_perm_temp[dim_index] = new_perm_index; + dim_index++; + } + } + for (int i = 0; i < dim_index + 1; ++i) { + auto ret = std::find_if(new_perm_temp.begin(), new_perm_temp.end(), [&](int x) { return x == i; }); + if (ret != new_perm_temp.end()) { + (*new_perm)[i] = ret - new_perm_temp.begin(); + } + } +} + +template +CUDA_LIB_EXPORT cudaError_t CalTranspose(const size_t size, const T *input, const TransposeInfo &info, T *output, + cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_TRANSPOSE_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/tril_indices_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/tril_indices_impl.cu index 67fe087db44..2ea4125e4cc 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/tril_indices_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/tril_indices_impl.cu @@ -1,59 +1,59 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tril_indices_impl.cuh" - -template -__global__ void TrilIndices(const int64_t row_offset, const int64_t m_first_row, const int64_t col, - const int64_t trapezoid_size, const size_t tril_size, T *output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < tril_size; pos += blockDim.x * gridDim.x) { - int64_t row_idx, col_idx; - if (pos < trapezoid_size) { - int64_t t_first_row = m_first_row << 1; - auto t_bottom_row = t_first_row - 1; - double t_sqrt = sqrt(static_cast(t_bottom_row * t_bottom_row + (pos << 3))); - row_idx = __double2ll_rd((-t_bottom_row + t_sqrt) / 2); - col_idx = pos - ((t_first_row + row_idx - 1) * row_idx >> 1); - } else { - auto surplus = pos - trapezoid_size; - row_idx = surplus / col + col - m_first_row + 1; - col_idx = surplus % col; - } - row_idx += row_offset; - - output[pos] = static_cast(row_idx); - output[pos + tril_size] = static_cast(col_idx); - } -} - -template -cudaError_t CalTrilIndices(const int64_t row_offset, const int64_t m_first_row, const int64_t col, - const int64_t trapezoid_size, const size_t tril_size, T *output, const uint32_t &device_id, - cudaStream_t cuda_stream) { - TrilIndices<<>>( - row_offset, m_first_row, col, trapezoid_size, tril_size, output); - return GetCudaStatus(); -} - -template CUDA_LIB_EXPORT cudaError_t CalTrilIndices(const int64_t row_offset, const int64_t m_first_row, - const int64_t col, const int64_t trapezoid_size, - const size_t tril_size, int32_t *output, - const uint32_t &device_id, cudaStream_t cuda_stream); - -template CUDA_LIB_EXPORT cudaError_t CalTrilIndices(const int64_t row_offset, const int64_t m_first_row, - const int64_t col, const int64_t trapezoid_size, - const size_t tril_size, int64_t *output, - const uint32_t &device_id, cudaStream_t cuda_stream); +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tril_indices_impl.cuh" + +template +__global__ void TrilIndices(const int64_t row_offset, const int64_t m_first_row, const int64_t col, + const int64_t trapezoid_size, const size_t tril_size, T *output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < tril_size; pos += blockDim.x * gridDim.x) { + int64_t row_idx, col_idx; + if (pos < trapezoid_size) { + int64_t t_first_row = m_first_row << 1; + auto t_bottom_row = t_first_row - 1; + double t_sqrt = sqrt(static_cast(t_bottom_row * t_bottom_row + (pos << 3))); + row_idx = __double2ll_rd((-t_bottom_row + t_sqrt) / 2); + col_idx = pos - ((t_first_row + row_idx - 1) * row_idx >> 1); + } else { + auto surplus = pos - trapezoid_size; + row_idx = surplus / col + col - m_first_row + 1; + col_idx = surplus % col; + } + row_idx += row_offset; + + output[pos] = static_cast(row_idx); + output[pos + tril_size] = static_cast(col_idx); + } +} + +template +cudaError_t CalTrilIndices(const int64_t row_offset, const int64_t m_first_row, const int64_t col, + const int64_t trapezoid_size, const size_t tril_size, T *output, const uint32_t &device_id, + cudaStream_t cuda_stream) { + TrilIndices<<>>( + row_offset, m_first_row, col, trapezoid_size, tril_size, output); + return GetCudaStatus(); +} + +template CUDA_LIB_EXPORT cudaError_t CalTrilIndices(const int64_t row_offset, const int64_t m_first_row, + const int64_t col, const int64_t trapezoid_size, + const size_t tril_size, int32_t *output, + const uint32_t &device_id, cudaStream_t cuda_stream); + +template CUDA_LIB_EXPORT cudaError_t CalTrilIndices(const int64_t row_offset, const int64_t m_first_row, + const int64_t col, const int64_t trapezoid_size, + const size_t tril_size, int64_t *output, + const uint32_t &device_id, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/tril_indices_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/tril_indices_impl.cuh index 1a84af485b6..6d5202009d0 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/tril_indices_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/tril_indices_impl.cuh @@ -1,29 +1,29 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_TRIL_INDICES_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_TRIL_INDICES_IMPL_CUH_ - -#include - -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" - -template -CUDA_LIB_EXPORT cudaError_t CalTrilIndices(const int64_t row_offset, const int64_t m_first_row, const int64_t col, - const int64_t trapezoid_size, const size_t tril_size, T *output, - const uint32_t &device_id, cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_TRIL_INDICES_IMPL_CUH_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_TRIL_INDICES_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_TRIL_INDICES_IMPL_CUH_ + +#include + +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" + +template +CUDA_LIB_EXPORT cudaError_t CalTrilIndices(const int64_t row_offset, const int64_t m_first_row, const int64_t col, + const int64_t trapezoid_size, const size_t tril_size, T *output, + const uint32_t &device_id, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_TRIL_INDICES_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/tril_triu_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/tril_triu_impl.cu index 8e5347076fb..b4af0951027 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/tril_triu_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/tril_triu_impl.cu @@ -1,185 +1,185 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "tril_triu_impl.cuh" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h" - -template -using Complex = mindspore::utils::Complex; - -template -__global__ void Tril(const size_t size, const T *input, const int diagonal, const int64_t matrix_row, - const int64_t matrix_col, T *output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - int matrix_size = matrix_row * matrix_col; - int row = pos % matrix_size / matrix_col; - int col = pos % matrix_size % matrix_col; - output[pos] = row + diagonal >= col ? input[pos] : static_cast(0.0); - } - return; -} - -template -__global__ void Triu(const size_t size, const T *input, const int diagonal, const int64_t matrix_row, - const int64_t matrix_col, T *output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - int matrix_size = matrix_row * matrix_col; - int row = pos % matrix_size / matrix_col; - int col = pos % matrix_size % matrix_col; - output[pos] = row + diagonal <= col ? input[pos] : static_cast(0.0); - } - return; -} - -template <> -__global__ void Triu(const size_t size, const Complex *input, const int diagonal, const int64_t matrix_row, - const int64_t matrix_col, Complex *output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - size_t matrix_size = matrix_row * matrix_col; - int row = pos % matrix_size / matrix_col; - int col = pos % matrix_size % matrix_col; - float rs_real = row + diagonal <= col ? input[pos].real() : static_cast(0.0); - float rs_imag = row + diagonal <= col ? input[pos].imag() : static_cast(0.0); - output[pos].real(rs_real); - output[pos].imag(rs_imag); - } - return; -} - -template <> -__global__ void Triu(const size_t size, const Complex *input, const int diagonal, const int64_t matrix_row, - const int64_t matrix_col, Complex *output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - int matrix_size = matrix_row * matrix_col; - int row = pos % matrix_size / matrix_col; - int col = pos % matrix_size % matrix_col; - double rs_real = row + diagonal <= col ? input[pos].real() : static_cast(0.0); - double rs_imag = row + diagonal <= col ? input[pos].imag() : static_cast(0.0); - output[pos].real(rs_real); - output[pos].imag(rs_imag); - } - return; -} - -template -cudaError_t CalTril(const size_t size, const T *input, const int diagonal, const int64_t matrix_row, - const int64_t matrix_col, T *output, const uint32_t &device_id, cudaStream_t cuda_stream) { - Tril<<>>(size, input, diagonal, matrix_row, - matrix_col, output); - return GetCudaStatus(); -} - -template -cudaError_t CalTriu(const size_t size, const T *input, const int diagonal, const int64_t matrix_row, - const int64_t matrix_col, T *output, const uint32_t &device_id, cudaStream_t cuda_stream) { - Triu<<>>(size, input, diagonal, matrix_row, - matrix_col, output); - return GetCudaStatus(); -} - -template CUDA_LIB_EXPORT cudaError_t CalTril(const size_t size, const uint8_t *input, const int diagonal, - const int64_t matrix_row, const int64_t matrix_col, - uint8_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalTril(const size_t size, const uint16_t *input, const int diagonal, - const int64_t matrix_row, const int64_t matrix_col, - uint16_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalTril(const size_t size, const uint32_t *input, const int diagonal, - const int64_t matrix_row, const int64_t matrix_col, - uint32_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalTril(const size_t size, const uint64_t *input, const int diagonal, - const int64_t matrix_row, const int64_t matrix_col, - uint64_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalTril(const size_t size, const int8_t *input, const int diagonal, - const int64_t matrix_row, const int64_t matrix_col, int8_t *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalTril(const size_t size, const int16_t *input, const int diagonal, - const int64_t matrix_row, const int64_t matrix_col, - int16_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalTril(const size_t size, const int *input, const int diagonal, - const int64_t matrix_row, const int64_t matrix_col, int *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalTril(const size_t size, const int64_t *input, const int diagonal, - const int64_t matrix_row, const int64_t matrix_col, - int64_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalTril(const size_t size, const half *input, const int diagonal, - const int64_t matrix_row, const int64_t matrix_col, half *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalTril(const size_t size, const float *input, const int diagonal, - const int64_t matrix_row, const int64_t matrix_col, float *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalTril(const size_t size, const double *input, const int diagonal, - const int64_t matrix_row, const int64_t matrix_col, double *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalTril(const size_t size, const bool *input, const int diagonal, - const int64_t matrix_row, const int64_t matrix_col, bool *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalTriu(const size_t size, const uint8_t *input, const int diagonal, - const int64_t matrix_row, const int64_t matrix_col, - uint8_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalTriu(const size_t size, const uint16_t *input, const int diagonal, - const int64_t matrix_row, const int64_t matrix_col, - uint16_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalTriu(const size_t size, const uint32_t *input, const int diagonal, - const int64_t matrix_row, const int64_t matrix_col, - uint32_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalTriu(const size_t size, const uint64_t *input, const int diagonal, - const int64_t matrix_row, const int64_t matrix_col, - uint64_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalTriu(const size_t size, const int8_t *input, const int diagonal, - const int64_t matrix_row, const int64_t matrix_col, int8_t *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalTriu(const size_t size, const int16_t *input, const int diagonal, - const int64_t matrix_row, const int64_t matrix_col, - int16_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalTriu(const size_t size, const int *input, const int diagonal, - const int64_t matrix_row, const int64_t matrix_col, int *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalTriu(const size_t size, const int64_t *input, const int diagonal, - const int64_t matrix_row, const int64_t matrix_col, - int64_t *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalTriu(const size_t size, const half *input, const int diagonal, - const int64_t matrix_row, const int64_t matrix_col, half *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalTriu(const size_t size, const float *input, const int diagonal, - const int64_t matrix_row, const int64_t matrix_col, float *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalTriu(const size_t size, const double *input, const int diagonal, - const int64_t matrix_row, const int64_t matrix_col, double *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalTriu>(const size_t size, const Complex *input, - const int diagonal, const int64_t matrix_row, - const int64_t matrix_col, Complex *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalTriu>(const size_t size, const Complex *input, - const int diagonal, const int64_t matrix_row, - const int64_t matrix_col, Complex *output, - const uint32_t &device_id, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT cudaError_t CalTriu(const size_t size, const bool *input, const int diagonal, - const int64_t matrix_row, const int64_t matrix_col, bool *output, - const uint32_t &device_id, cudaStream_t cuda_stream); +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "tril_triu_impl.cuh" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h" + +template +using Complex = mindspore::utils::Complex; + +template +__global__ void Tril(const size_t size, const T *input, const int diagonal, const int64_t matrix_row, + const int64_t matrix_col, T *output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + int matrix_size = matrix_row * matrix_col; + int row = pos % matrix_size / matrix_col; + int col = pos % matrix_size % matrix_col; + output[pos] = row + diagonal >= col ? input[pos] : static_cast(0.0); + } + return; +} + +template +__global__ void Triu(const size_t size, const T *input, const int diagonal, const int64_t matrix_row, + const int64_t matrix_col, T *output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + int matrix_size = matrix_row * matrix_col; + int row = pos % matrix_size / matrix_col; + int col = pos % matrix_size % matrix_col; + output[pos] = row + diagonal <= col ? input[pos] : static_cast(0.0); + } + return; +} + +template <> +__global__ void Triu(const size_t size, const Complex *input, const int diagonal, const int64_t matrix_row, + const int64_t matrix_col, Complex *output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + size_t matrix_size = matrix_row * matrix_col; + int row = pos % matrix_size / matrix_col; + int col = pos % matrix_size % matrix_col; + float rs_real = row + diagonal <= col ? input[pos].real() : static_cast(0.0); + float rs_imag = row + diagonal <= col ? input[pos].imag() : static_cast(0.0); + output[pos].real(rs_real); + output[pos].imag(rs_imag); + } + return; +} + +template <> +__global__ void Triu(const size_t size, const Complex *input, const int diagonal, const int64_t matrix_row, + const int64_t matrix_col, Complex *output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + int matrix_size = matrix_row * matrix_col; + int row = pos % matrix_size / matrix_col; + int col = pos % matrix_size % matrix_col; + double rs_real = row + diagonal <= col ? input[pos].real() : static_cast(0.0); + double rs_imag = row + diagonal <= col ? input[pos].imag() : static_cast(0.0); + output[pos].real(rs_real); + output[pos].imag(rs_imag); + } + return; +} + +template +cudaError_t CalTril(const size_t size, const T *input, const int diagonal, const int64_t matrix_row, + const int64_t matrix_col, T *output, const uint32_t &device_id, cudaStream_t cuda_stream) { + Tril<<>>(size, input, diagonal, matrix_row, + matrix_col, output); + return GetCudaStatus(); +} + +template +cudaError_t CalTriu(const size_t size, const T *input, const int diagonal, const int64_t matrix_row, + const int64_t matrix_col, T *output, const uint32_t &device_id, cudaStream_t cuda_stream) { + Triu<<>>(size, input, diagonal, matrix_row, + matrix_col, output); + return GetCudaStatus(); +} + +template CUDA_LIB_EXPORT cudaError_t CalTril(const size_t size, const uint8_t *input, const int diagonal, + const int64_t matrix_row, const int64_t matrix_col, + uint8_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalTril(const size_t size, const uint16_t *input, const int diagonal, + const int64_t matrix_row, const int64_t matrix_col, + uint16_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalTril(const size_t size, const uint32_t *input, const int diagonal, + const int64_t matrix_row, const int64_t matrix_col, + uint32_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalTril(const size_t size, const uint64_t *input, const int diagonal, + const int64_t matrix_row, const int64_t matrix_col, + uint64_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalTril(const size_t size, const int8_t *input, const int diagonal, + const int64_t matrix_row, const int64_t matrix_col, int8_t *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalTril(const size_t size, const int16_t *input, const int diagonal, + const int64_t matrix_row, const int64_t matrix_col, + int16_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalTril(const size_t size, const int *input, const int diagonal, + const int64_t matrix_row, const int64_t matrix_col, int *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalTril(const size_t size, const int64_t *input, const int diagonal, + const int64_t matrix_row, const int64_t matrix_col, + int64_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalTril(const size_t size, const half *input, const int diagonal, + const int64_t matrix_row, const int64_t matrix_col, half *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalTril(const size_t size, const float *input, const int diagonal, + const int64_t matrix_row, const int64_t matrix_col, float *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalTril(const size_t size, const double *input, const int diagonal, + const int64_t matrix_row, const int64_t matrix_col, double *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalTril(const size_t size, const bool *input, const int diagonal, + const int64_t matrix_row, const int64_t matrix_col, bool *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalTriu(const size_t size, const uint8_t *input, const int diagonal, + const int64_t matrix_row, const int64_t matrix_col, + uint8_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalTriu(const size_t size, const uint16_t *input, const int diagonal, + const int64_t matrix_row, const int64_t matrix_col, + uint16_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalTriu(const size_t size, const uint32_t *input, const int diagonal, + const int64_t matrix_row, const int64_t matrix_col, + uint32_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalTriu(const size_t size, const uint64_t *input, const int diagonal, + const int64_t matrix_row, const int64_t matrix_col, + uint64_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalTriu(const size_t size, const int8_t *input, const int diagonal, + const int64_t matrix_row, const int64_t matrix_col, int8_t *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalTriu(const size_t size, const int16_t *input, const int diagonal, + const int64_t matrix_row, const int64_t matrix_col, + int16_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalTriu(const size_t size, const int *input, const int diagonal, + const int64_t matrix_row, const int64_t matrix_col, int *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalTriu(const size_t size, const int64_t *input, const int diagonal, + const int64_t matrix_row, const int64_t matrix_col, + int64_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalTriu(const size_t size, const half *input, const int diagonal, + const int64_t matrix_row, const int64_t matrix_col, half *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalTriu(const size_t size, const float *input, const int diagonal, + const int64_t matrix_row, const int64_t matrix_col, float *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalTriu(const size_t size, const double *input, const int diagonal, + const int64_t matrix_row, const int64_t matrix_col, double *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalTriu>(const size_t size, const Complex *input, + const int diagonal, const int64_t matrix_row, + const int64_t matrix_col, Complex *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalTriu>(const size_t size, const Complex *input, + const int diagonal, const int64_t matrix_row, + const int64_t matrix_col, Complex *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT cudaError_t CalTriu(const size_t size, const bool *input, const int diagonal, + const int64_t matrix_row, const int64_t matrix_col, bool *output, + const uint32_t &device_id, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/tril_triu_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/tril_triu_impl.cuh index 5050e76a7d1..a07aa98d722 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/tril_triu_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/tril_triu_impl.cuh @@ -1,30 +1,30 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_TRIL_TRIU_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_TRIL_TRIU_IMPL_CUH_ -#include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" - -template -CUDA_LIB_EXPORT cudaError_t CalTril(const size_t size, const T *input, const int diagonal, const int64_t matrix_row, - const int64_t matrix_col, T *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -template -CUDA_LIB_EXPORT cudaError_t CalTriu(const size_t size, const T *input, const int diagonal, const int64_t matrix_row, - const int64_t matrix_col, T *output, const uint32_t &device_id, - cudaStream_t cuda_stream); -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_TRIL_TRIU_IMPL_CUH_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_TRIL_TRIU_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_TRIL_TRIU_IMPL_CUH_ +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" + +template +CUDA_LIB_EXPORT cudaError_t CalTril(const size_t size, const T *input, const int diagonal, const int64_t matrix_row, + const int64_t matrix_col, T *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template +CUDA_LIB_EXPORT cudaError_t CalTriu(const size_t size, const T *input, const int diagonal, const int64_t matrix_row, + const int64_t matrix_col, T *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_TRIL_TRIU_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/triu_indices_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/triu_indices_impl.cu index e1e4b7651dc..baa73807b6d 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/triu_indices_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/triu_indices_impl.cu @@ -1,60 +1,60 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "triu_indices_impl.cuh" - -template -__global__ void TriuIndices(const int64_t col_offset, const int64_t m_first_row, const int64_t col, - const int64_t rectangle_size, const size_t triu_size, T *output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < triu_size; pos += blockDim.x * gridDim.x) { - int64_t row_idx, col_idx; - if (pos < rectangle_size) { - row_idx = pos / col; - col_idx = pos % col; - } else { - int64_t t_first_row = m_first_row << 1; - auto t_bottom_row = -1 - t_first_row; - int64_t idx = pos - rectangle_size; - double t_sqrt = sqrt(static_cast(t_bottom_row * t_bottom_row - (idx << 3))); - row_idx = __double2ll_rd((-t_bottom_row - t_sqrt) / 2); - col_idx = idx - ((t_first_row - row_idx + 1) * row_idx >> 1) + row_idx; - row_idx += rectangle_size / col; - } - col_idx += col_offset; - - output[pos] = static_cast(row_idx); - output[pos + triu_size] = static_cast(col_idx); - } -} - -template -cudaError_t CalTriuIndices(const int64_t col_offset, const int64_t m_first_row, const int64_t col, - const int64_t rectangle_size, const size_t triu_size, T *output, const uint32_t &device_id, - cudaStream_t cuda_stream) { - TriuIndices<<>>( - col_offset, m_first_row, col, rectangle_size, triu_size, output); - return GetCudaStatus(); -} - -template CUDA_LIB_EXPORT cudaError_t CalTriuIndices(const int64_t row_offset, const int64_t m_first_row, - const int64_t col, const int64_t trapezoid_size, - const size_t triu_size, int32_t *output, - const uint32_t &device_id, cudaStream_t cuda_stream); - -template CUDA_LIB_EXPORT cudaError_t CalTriuIndices(const int64_t row_offset, const int64_t m_first_row, - const int64_t col, const int64_t trapezoid_size, - const size_t triu_size, int64_t *output, - const uint32_t &device_id, cudaStream_t cuda_stream); +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "triu_indices_impl.cuh" + +template +__global__ void TriuIndices(const int64_t col_offset, const int64_t m_first_row, const int64_t col, + const int64_t rectangle_size, const size_t triu_size, T *output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < triu_size; pos += blockDim.x * gridDim.x) { + int64_t row_idx, col_idx; + if (pos < rectangle_size) { + row_idx = pos / col; + col_idx = pos % col; + } else { + int64_t t_first_row = m_first_row << 1; + auto t_bottom_row = -1 - t_first_row; + int64_t idx = pos - rectangle_size; + double t_sqrt = sqrt(static_cast(t_bottom_row * t_bottom_row - (idx << 3))); + row_idx = __double2ll_rd((-t_bottom_row - t_sqrt) / 2); + col_idx = idx - ((t_first_row - row_idx + 1) * row_idx >> 1) + row_idx; + row_idx += rectangle_size / col; + } + col_idx += col_offset; + + output[pos] = static_cast(row_idx); + output[pos + triu_size] = static_cast(col_idx); + } +} + +template +cudaError_t CalTriuIndices(const int64_t col_offset, const int64_t m_first_row, const int64_t col, + const int64_t rectangle_size, const size_t triu_size, T *output, const uint32_t &device_id, + cudaStream_t cuda_stream) { + TriuIndices<<>>( + col_offset, m_first_row, col, rectangle_size, triu_size, output); + return GetCudaStatus(); +} + +template CUDA_LIB_EXPORT cudaError_t CalTriuIndices(const int64_t row_offset, const int64_t m_first_row, + const int64_t col, const int64_t trapezoid_size, + const size_t triu_size, int32_t *output, + const uint32_t &device_id, cudaStream_t cuda_stream); + +template CUDA_LIB_EXPORT cudaError_t CalTriuIndices(const int64_t row_offset, const int64_t m_first_row, + const int64_t col, const int64_t trapezoid_size, + const size_t triu_size, int64_t *output, + const uint32_t &device_id, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/triu_indices_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/triu_indices_impl.cuh index 2780d208df4..8771777eaec 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/triu_indices_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/triu_indices_impl.cuh @@ -1,29 +1,29 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_TRIU_INDICES_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_TRIU_INDICES_IMPL_CUH_ - -#include - -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" - -template -CUDA_LIB_EXPORT cudaError_t CalTriuIndices(const int64_t col_offset, const int64_t m_first_row, const int64_t col, - const int64_t rectangle_size, const size_t triu_size, T *output, - const uint32_t &device_id, cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_TRIU_INDICES_IMPL_CUH_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_TRIU_INDICES_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_TRIU_INDICES_IMPL_CUH_ + +#include + +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" + +template +CUDA_LIB_EXPORT cudaError_t CalTriuIndices(const int64_t col_offset, const int64_t m_first_row, const int64_t col, + const int64_t rectangle_size, const size_t triu_size, T *output, + const uint32_t &device_id, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_TRIU_INDICES_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/unravel_index_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/unravel_index_impl.cuh index d10137a6a72..cff7542fc09 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/unravel_index_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/unravel_index_impl.cuh @@ -1,27 +1,27 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_UNRAVEL_INDEX_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_UNRAVEL_INDEX_IMPL_CUH_ -#include "include/cuda_fp16.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" - -template -CUDA_LIB_EXPORT cudaError_t CalUnravelIndex(T *input_indices, T *input_dims, const size_t indices_size, - const size_t dims_size, T *output, const uint32_t &device_id, - cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_UNRAVEL_INDEX_IMPL_CUH_ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_UNRAVEL_INDEX_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_UNRAVEL_INDEX_IMPL_CUH_ +#include "include/cuda_fp16.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" + +template +CUDA_LIB_EXPORT cudaError_t CalUnravelIndex(T *input_indices, T *input_dims, const size_t indices_size, + const size_t dims_size, T *output, const uint32_t &device_id, + cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_UNRAVEL_INDEX_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/upsample_trilinear_3d_grad_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/upsample_trilinear_3d_grad_impl.cu index 3d4e7b8eba3..36008c8ea06 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/upsample_trilinear_3d_grad_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/upsample_trilinear_3d_grad_impl.cu @@ -1,136 +1,136 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/upsample_trilinear_3d_grad_impl.cuh" -#include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/upsample.cuh" - -__device__ __forceinline__ int idx_dhw(const int height, const int width, const int z, const int y, const int x) { - return (z * height + y) * width + x; -} - -template -__device__ __forceinline__ S special_cast(T value) { - return static_cast(value); -} - -template <> -__device__ __forceinline__ float special_cast(half value) { - return __half2float(value); -} - -template <> -__device__ __forceinline__ half special_cast(float value) { - return __float2half(value); -} - -template -__global__ void UpsampleTrilinear3DGradKernel(const size_t elem_num, const T *grad, const int batchsize, - const int channels, const int grad_d, const int grad_h, const int grad_w, - const int grad_dhw, const int dinput_d, const int dinput_h, - const int dinput_w, const int dinput_dhw, const S d_scale, - const S h_scale, const S w_scale, const bool align_corner, T *dinput) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < grad_dhw; pos += blockDim.x * gridDim.x) { - const int t2 = pos / (grad_h * grad_w); - const int h2 = pos / grad_w % grad_h; - const int w2 = pos % grad_w; - - const S t1r = area_pixel_compute_source_index(d_scale, t2, align_corner, false); - const int t1 = floorf(t1r); - const int t1p = (t1 < (dinput_d - 1)) ? 1 : 0; - const S t1lambda = t1r - t1; - const S t0lambda = static_cast(1) - t1lambda; - - const S h1r = area_pixel_compute_source_index(h_scale, h2, align_corner, false); - const int h1 = floorf(h1r); - const int h1p = (h1 < (dinput_h - 1)) ? 1 : 0; - const S h1lambda = h1r - h1; - const S h0lambda = static_cast(1) - h1lambda; - - const S w1r = area_pixel_compute_source_index(w_scale, w2, align_corner, false); - const int w1 = floorf(w1r); - const int w1p = (w1 < (dinput_w - 1)) ? 1 : 0; - const S w1lambda = w1r - w1; - const S w0lambda = static_cast(1) - w1lambda; - - size_t dinput_offset = 0; - size_t dout_offset = 0; - for (int n = 0; n < batchsize; ++n) { - for (int c = 0; c < channels; ++c) { - const S d2val = special_cast(grad[dout_offset + (t2 * grad_h + h2) * grad_w + w2]); - - FastAtomicAdd(dinput, dinput_offset + idx_dhw(dinput_h, dinput_w, t1, h1, w1), elem_num, - special_cast(t0lambda * h0lambda * w0lambda * d2val)); - FastAtomicAdd(dinput, dinput_offset + idx_dhw(dinput_h, dinput_w, t1, h1, w1 + w1p), elem_num, - special_cast(t0lambda * h0lambda * w1lambda * d2val)); - FastAtomicAdd(dinput, dinput_offset + idx_dhw(dinput_h, dinput_w, t1, h1 + h1p, w1), elem_num, - special_cast(t0lambda * h1lambda * w0lambda * d2val)); - FastAtomicAdd(dinput, dinput_offset + idx_dhw(dinput_h, dinput_w, t1, h1 + h1p, w1 + w1p), elem_num, - special_cast(t0lambda * h1lambda * w1lambda * d2val)); - FastAtomicAdd(dinput, dinput_offset + idx_dhw(dinput_h, dinput_w, t1 + t1p, h1, w1), elem_num, - special_cast(t1lambda * h0lambda * w0lambda * d2val)); - FastAtomicAdd(dinput, dinput_offset + idx_dhw(dinput_h, dinput_w, t1 + t1p, h1, w1 + w1p), elem_num, - special_cast(t1lambda * h0lambda * w1lambda * d2val)); - FastAtomicAdd(dinput, dinput_offset + idx_dhw(dinput_h, dinput_w, t1 + t1p, h1 + h1p, w1), elem_num, - special_cast(t1lambda * h1lambda * w0lambda * d2val)); - FastAtomicAdd(dinput, dinput_offset + idx_dhw(dinput_h, dinput_w, t1 + t1p, h1 + h1p, w1 + w1p), elem_num, - special_cast(t1lambda * h1lambda * w1lambda * d2val)); - - dout_offset += grad_dhw; - dinput_offset += dinput_dhw; - } - } - } - return; -} - -template -cudaError_t CalUpsampleTrilinear3DGrad(const T *grad, const int n, const int c, const int grad_d, const int grad_h, - const int grad_w, const int dinput_d, const int dinput_h, const int dinput_w, - const S d_scale, const S h_scale, const S w_scale, const bool align_corner, - T *dinput, const uint32_t device_id, cudaStream_t cuda_stream) { - const int dinput_dhw = dinput_d * dinput_h * dinput_w; - const int grad_dhw = grad_d * grad_h * grad_w; - const int dinput_size = dinput_dhw * n * c; - if (dinput_d == grad_d && dinput_h == grad_h && dinput_w == grad_w) { - CudaMemcpyDeviceToDevice - <<>>(dinput_size, grad, dinput); - } else { - (void)cudaMemset(dinput, 0, sizeof(T) * dinput_size); - const size_t blockSize = std::min(CUDA_THREADS(device_id), static_cast(256)); - const size_t gridSize = (grad_dhw + blockSize - 1) / blockSize; - UpsampleTrilinear3DGradKernel<<>>( - dinput_size, grad, n, c, grad_d, grad_h, grad_w, grad_dhw, dinput_d, dinput_h, dinput_w, dinput_dhw, d_scale, - h_scale, w_scale, align_corner, dinput); - } - return GetCudaStatus(); -} - -template CUDA_LIB_EXPORT cudaError_t CalUpsampleTrilinear3DGrad( - const half *grad, const int n, const int c, const int grad_d, const int grad_h, const int grad_w, const int dinput_d, - const int dinput_h, const int dinput_w, const float d_scale, const float h_scale, const float w_scale, - const bool align_corner, half *dinput, const uint32_t device_id, cudaStream_t cuda_stream); - -template CUDA_LIB_EXPORT cudaError_t CalUpsampleTrilinear3DGrad( - const float *grad, const int n, const int c, const int grad_d, const int grad_h, const int grad_w, const int dinput_d, - const int dinput_h, const int dinput_w, const float d_scale, const float h_scale, const float w_scale, - const bool align_corner, float *dinput, const uint32_t device_id, cudaStream_t cuda_stream); - -template CUDA_LIB_EXPORT cudaError_t CalUpsampleTrilinear3DGrad( - const double *grad, const int n, const int c, const int grad_d, const int grad_h, const int grad_w, - const int dinput_d, const int dinput_h, const int dinput_w, const double d_scale, const double h_scale, - const double w_scale, const bool align_corner, double *dinput, const uint32_t device_id, cudaStream_t cuda_stream); +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/upsample_trilinear_3d_grad_impl.cuh" +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/upsample.cuh" + +__device__ __forceinline__ int idx_dhw(const int height, const int width, const int z, const int y, const int x) { + return (z * height + y) * width + x; +} + +template +__device__ __forceinline__ S special_cast(T value) { + return static_cast(value); +} + +template <> +__device__ __forceinline__ float special_cast(half value) { + return __half2float(value); +} + +template <> +__device__ __forceinline__ half special_cast(float value) { + return __float2half(value); +} + +template +__global__ void UpsampleTrilinear3DGradKernel(const size_t elem_num, const T *grad, const int batchsize, + const int channels, const int grad_d, const int grad_h, const int grad_w, + const int grad_dhw, const int dinput_d, const int dinput_h, + const int dinput_w, const int dinput_dhw, const S d_scale, + const S h_scale, const S w_scale, const bool align_corner, T *dinput) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < grad_dhw; pos += blockDim.x * gridDim.x) { + const int t2 = pos / (grad_h * grad_w); + const int h2 = pos / grad_w % grad_h; + const int w2 = pos % grad_w; + + const S t1r = area_pixel_compute_source_index(d_scale, t2, align_corner, false); + const int t1 = floorf(t1r); + const int t1p = (t1 < (dinput_d - 1)) ? 1 : 0; + const S t1lambda = t1r - t1; + const S t0lambda = static_cast(1) - t1lambda; + + const S h1r = area_pixel_compute_source_index(h_scale, h2, align_corner, false); + const int h1 = floorf(h1r); + const int h1p = (h1 < (dinput_h - 1)) ? 1 : 0; + const S h1lambda = h1r - h1; + const S h0lambda = static_cast(1) - h1lambda; + + const S w1r = area_pixel_compute_source_index(w_scale, w2, align_corner, false); + const int w1 = floorf(w1r); + const int w1p = (w1 < (dinput_w - 1)) ? 1 : 0; + const S w1lambda = w1r - w1; + const S w0lambda = static_cast(1) - w1lambda; + + size_t dinput_offset = 0; + size_t dout_offset = 0; + for (int n = 0; n < batchsize; ++n) { + for (int c = 0; c < channels; ++c) { + const S d2val = special_cast(grad[dout_offset + (t2 * grad_h + h2) * grad_w + w2]); + + FastAtomicAdd(dinput, dinput_offset + idx_dhw(dinput_h, dinput_w, t1, h1, w1), elem_num, + special_cast(t0lambda * h0lambda * w0lambda * d2val)); + FastAtomicAdd(dinput, dinput_offset + idx_dhw(dinput_h, dinput_w, t1, h1, w1 + w1p), elem_num, + special_cast(t0lambda * h0lambda * w1lambda * d2val)); + FastAtomicAdd(dinput, dinput_offset + idx_dhw(dinput_h, dinput_w, t1, h1 + h1p, w1), elem_num, + special_cast(t0lambda * h1lambda * w0lambda * d2val)); + FastAtomicAdd(dinput, dinput_offset + idx_dhw(dinput_h, dinput_w, t1, h1 + h1p, w1 + w1p), elem_num, + special_cast(t0lambda * h1lambda * w1lambda * d2val)); + FastAtomicAdd(dinput, dinput_offset + idx_dhw(dinput_h, dinput_w, t1 + t1p, h1, w1), elem_num, + special_cast(t1lambda * h0lambda * w0lambda * d2val)); + FastAtomicAdd(dinput, dinput_offset + idx_dhw(dinput_h, dinput_w, t1 + t1p, h1, w1 + w1p), elem_num, + special_cast(t1lambda * h0lambda * w1lambda * d2val)); + FastAtomicAdd(dinput, dinput_offset + idx_dhw(dinput_h, dinput_w, t1 + t1p, h1 + h1p, w1), elem_num, + special_cast(t1lambda * h1lambda * w0lambda * d2val)); + FastAtomicAdd(dinput, dinput_offset + idx_dhw(dinput_h, dinput_w, t1 + t1p, h1 + h1p, w1 + w1p), elem_num, + special_cast(t1lambda * h1lambda * w1lambda * d2val)); + + dout_offset += grad_dhw; + dinput_offset += dinput_dhw; + } + } + } + return; +} + +template +cudaError_t CalUpsampleTrilinear3DGrad(const T *grad, const int n, const int c, const int grad_d, const int grad_h, + const int grad_w, const int dinput_d, const int dinput_h, const int dinput_w, + const S d_scale, const S h_scale, const S w_scale, const bool align_corner, + T *dinput, const uint32_t device_id, cudaStream_t cuda_stream) { + const int dinput_dhw = dinput_d * dinput_h * dinput_w; + const int grad_dhw = grad_d * grad_h * grad_w; + const int dinput_size = dinput_dhw * n * c; + if (dinput_d == grad_d && dinput_h == grad_h && dinput_w == grad_w) { + CudaMemcpyDeviceToDevice + <<>>(dinput_size, grad, dinput); + } else { + (void)cudaMemset(dinput, 0, sizeof(T) * dinput_size); + const size_t blockSize = std::min(CUDA_THREADS(device_id), static_cast(256)); + const size_t gridSize = (grad_dhw + blockSize - 1) / blockSize; + UpsampleTrilinear3DGradKernel<<>>( + dinput_size, grad, n, c, grad_d, grad_h, grad_w, grad_dhw, dinput_d, dinput_h, dinput_w, dinput_dhw, d_scale, + h_scale, w_scale, align_corner, dinput); + } + return GetCudaStatus(); +} + +template CUDA_LIB_EXPORT cudaError_t CalUpsampleTrilinear3DGrad( + const half *grad, const int n, const int c, const int grad_d, const int grad_h, const int grad_w, const int dinput_d, + const int dinput_h, const int dinput_w, const float d_scale, const float h_scale, const float w_scale, + const bool align_corner, half *dinput, const uint32_t device_id, cudaStream_t cuda_stream); + +template CUDA_LIB_EXPORT cudaError_t CalUpsampleTrilinear3DGrad( + const float *grad, const int n, const int c, const int grad_d, const int grad_h, const int grad_w, const int dinput_d, + const int dinput_h, const int dinput_w, const float d_scale, const float h_scale, const float w_scale, + const bool align_corner, float *dinput, const uint32_t device_id, cudaStream_t cuda_stream); + +template CUDA_LIB_EXPORT cudaError_t CalUpsampleTrilinear3DGrad( + const double *grad, const int n, const int c, const int grad_d, const int grad_h, const int grad_w, + const int dinput_d, const int dinput_h, const int dinput_w, const double d_scale, const double h_scale, + const double w_scale, const bool align_corner, double *dinput, const uint32_t device_id, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/upsample_trilinear_3d_grad_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/upsample_trilinear_3d_grad_impl.cuh index c69803a83b6..3bef58c0b1d 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/upsample_trilinear_3d_grad_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/upsample_trilinear_3d_grad_impl.cuh @@ -1,26 +1,26 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_UPSAMPLE_TRILINEAR_3D_GRAD_IMPL_CUH_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_UPSAMPLE_TRILINEAR_3D_GRAD_IMPL_CUH_ -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" -template -CUDA_LIB_EXPORT cudaError_t CalUpsampleTrilinear3DGrad(const T *grad, const int n, const int c, const int grad_d, - const int grad_h, const int grad_w, const int dinput_d, - const int dinput_h, const int dinput_w, const S d_scale, - const S h_scale, const S w_scale, const bool align_corner, - T *dinput, const uint32_t device_id, cudaStream_t cuda_stream); -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_UPSAMPLE_TRILINEAR_3D_GRAD_IMPL_CUH_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_UPSAMPLE_TRILINEAR_3D_GRAD_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_UPSAMPLE_TRILINEAR_3D_GRAD_IMPL_CUH_ +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" +template +CUDA_LIB_EXPORT cudaError_t CalUpsampleTrilinear3DGrad(const T *grad, const int n, const int c, const int grad_d, + const int grad_h, const int grad_w, const int dinput_d, + const int dinput_h, const int dinput_w, const S d_scale, + const S h_scale, const S w_scale, const bool align_corner, + T *dinput, const uint32_t device_id, cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_UPSAMPLE_TRILINEAR_3D_GRAD_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/gpu_kernel.h index b203f6508af..8e455e67f2b 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/gpu_kernel.h @@ -1,250 +1,250 @@ -/** - * Copyright 2019-2024 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_GPU_KERNEL_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "kernel/kernel.h" -#include "plugin/device/gpu/kernel/gpu_kernel_mod.h" -#include "plugin/factory/ms_factory.h" -#include "plugin/device/gpu/kernel/kernel_constants.h" -#include "plugin/device/gpu/hal/device/gpu_device_manager.h" -#include "plugin/device/gpu/hal/device/gpu_device_address.h" -#include "plugin/device/gpu/hal/device/gpu_common.h" -#include "include/backend/anf_runtime_algorithm.h" -#include "include/common/utils/anfalgo.h" -#include "kernel/kernel_build_info.h" -#include "kernel/common_utils.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" - -using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm; - -// The max_limit of tensor shape size: 2 Giga-elements(2^31, the largest number in 32 bits). -#define SHAPE_SIZE_LIMIT 2147483648 - -namespace mindspore { -namespace kernel { -constexpr size_t kShapeIndex1st = 1; -constexpr size_t kShapeIndex2nd = 2; -constexpr size_t kShapeIndex3rd = 3; -constexpr size_t kShapeIndex4th = 4; -constexpr size_t kShapeIndex5nd = 5; -constexpr size_t kShapeIndex6rd = 6; -constexpr size_t kShapeIndex7th = 7; - -constexpr size_t kDim2DShapeSize = 4; -constexpr size_t kDim3DShapeSize = 5; -constexpr size_t kPoolingNbDims = kDim3DShapeSize; - -constexpr size_t kHelperDimsNum = 5; - -static std::map kNCHWToNHWCAxisMap = { - {0, 0}, - {1, 3}, - {2, 1}, - {3, 2}, -}; -static std::map kNHWCToNCHWAxisMap = { - {0, 0}, - {1, 2}, - {2, 3}, - {3, 1}, -}; - -static auto Anyone = [](auto &&k, auto &&... args) { return ((args == k) || ...); }; - -inline int CeilDivide(int m, int n) { return (m + n - 1) / n; } - -inline int GetPad(int input, int kernel, int stride) { - return std::max(0, (CeilDivide(input, stride) - 1) * stride + kernel - input); -} - -// Choose the suitable datatype for cudnn -inline cudnnDataType_t GetCudnnDataType(const std::string &Type) { - auto type = kCudnnDtypeMap.find(Type); - if (type == kCudnnDtypeMap.end()) { - MS_EXCEPTION(TypeError) << Type << " is not supported."; - } - return type->second; -} - -// Choose the suitable datatype for cublas -inline cudaDataType_t GetCudaDataType(const std::string &Type) { - auto type = kCudaDtypeMap.find(Type); - if (type == kCudaDtypeMap.end()) { - MS_EXCEPTION(TypeError) << Type << " is not supported."; - } - return type->second; -} - -class NativeGpuKernelMod : public GpuKernelMod { - public: - using ReduceDetail = std::tuple; - using ReducePrecisonRes = std::tuple, std::vector>; - - virtual void DestroyResource() noexcept {} - bool CheckSupport(const std::string &kernel_name, const KernelAttr &kernel_attr); - std::vector GetAllSupportedList(const std::string &kernel_name); - ReducePrecisonRes ReducePrecisionCheck(const std::string &kernel_name, const KernelAttr &kernel_attr); - static std::vector GetGpuSupportedList(const std::string &kernel_name) { - if (!Factory::Instance().IsRegistered(kernel_name)) { - return {}; - } - return Factory::Instance().Create(kernel_name)->GetAllSupportedList(kernel_name); - } - std::vector GetOpSupport() { return {}; } - static bool GpuCheckSupport(const std::string &kernel_name, const KernelAttr &kernel_attr); - - static ReducePrecisonRes GpuReducePrecisionCheck(const std::string &kernel_name, const KernelAttr &kernel_attr) { - return Factory::Instance().Create(kernel_name)->ReducePrecisionCheck(kernel_name, kernel_attr); - } - enum KernelModType GetKernelModType() const override { return KernelModType::NativeGpuKernelMod; } - - protected: - virtual void InitResource() {} - static mindspore::HashMap> support_map_; -}; - -std::vector ConvertPtrs(const std::vector &input_ptrs); - -// expand Nd Shape to 4d (N in [0,4]) -bool ShapeNdTo4d(const ShapeVector &src, ShapeVector *dst); - -template -inline T *GetPossiblyNullDeviceAddress(const std::vector &addr_list, size_t index) { - if (index >= addr_list.size()) { - MS_LOG(ERROR) << "Address index(" << index << ") out of range(" << addr_list.size() << ")"; - return nullptr; - } - // Kernels may run normally without workspace, the addr_list[index] maybe nullptr. - if ((addr_list[index] == nullptr) || (addr_list[index]->size() == 0)) { - return nullptr; - } - if (addr_list[index]->device_ptr() == nullptr) { - MS_LOG(ERROR) << "The device address is empty, address index:" << index; - return nullptr; - } - return reinterpret_cast(addr_list[index]->device_ptr()); -} -template -inline T *GetPossiblyNullDeviceAddress(const std::vector &addr_list, size_t index) { - if (index >= addr_list.size()) { - MS_LOG(ERROR) << "Address index(" << index << ") out of range(" << addr_list.size() << ")"; - return nullptr; - } - // Kernels may run normally without workspace, the addr_list[index] maybe nullptr. - if ((addr_list[index] == nullptr) || (addr_list[index]->size == 0)) { - return nullptr; - } - if (addr_list[index]->addr == nullptr) { - MS_LOG(ERROR) << "The device address is empty, address index:" << index; - return nullptr; - } - return reinterpret_cast(addr_list[index]->addr); -} - -int AxisTransform(const std::string &origin_data_format, const std::string &cal_format, int axis); - -// transpose shape: NCHW To NHWC -void ShapeNCHW2NHWC(ShapeVector *shape); - -// transpose shape: NCDHW To NDHWC -void ShapeNCDHW2NDHWC(ShapeVector *shape); - -//////////////// old: format string ///////////// -void SetDimA(const ShapeVector &shape, int *dimA, size_t len, const std::string &format); - -void SetStrideA(const ShapeVector &shape, int *strideA, size_t len, const std::string &format); - -void SetNCHW(const ShapeVector &shape, int *n, int *c, int *h, int *w, const std::string &format); - -void SetNCDHW(const ShapeVector &shape, int *n, int *c, int *d, int *h, int *w, const std::string &format); -//////////////////////////////////////////////// -//////////////// new: format enum/////////////// -void SetDimA(const ShapeVector &shape, int *dimA, size_t len, const mindspore::Format &format); - -void SetStrideA(const ShapeVector &shape, int *strideA, size_t len, const mindspore::Format &format); - -void SetNCHW(const ShapeVector &shape, int *n, int *c, int *h, int *w, const mindspore::Format &format); - -void SetNCDHW(const ShapeVector &shape, int *n, int *c, int *d, int *h, int *w, const mindspore::Format &format); -//////////////////////////////////////////////// - -bool CheckBroadcast4TensorOp(const std::vector &A, const std::vector &B, const std::vector &Out); - -// The tensor size is limited to 2G by cudnn. -bool CheckTensorSize(const std::initializer_list &shapes); - -// set the tensor descriptor for cudnn/cublas -bool CudnnSetTensorNdDescriptor(const ShapeVector &shape, cudnnTensorDescriptor_t descriptor, cudnnDataType_t data_type, - const std::string &node_name); - -// choose the suitable datatype for cudnn/cublas -bool GetCudnnDataType(const std::string &Type, cudnnDataType_t *out_type); - -bool GetCudaDataType(const std::string &Type, cudaDataType_t *out_type); - -bool ShapeEqual(const ShapeVector &s1, const ShapeVector &s2); - -template -T GetDimValue(const std::vector &inputs, const int index, const string kernel_name, - const TypeId &dim_type) { - size_t size = abstract::TypeIdSize(dim_type); - auto dim_gpu_addr = - std::make_shared(inputs[index]->device_ptr(), size, kOpFormat_DEFAULT, dim_type); - int res = 0; - if (dim_type == kNumberTypeInt32) { - int32_t host_dim = 0; - dim_gpu_addr->SyncDeviceToHost(size, &host_dim); - res = static_cast(host_dim); - } else if (dim_type == kNumberTypeInt64) { - int64_t host_dim = 0; - dim_gpu_addr->SyncDeviceToHost(size, &host_dim); - res = static_cast(host_dim); - } else { - MS_LOG(EXCEPTION) << "For '" << kernel_name << "', got unsupported data type of dim: " << dim_type; - } - return res; -} -// This is necessary for gpu kernels to support uint8 data type. In cuda, an unsigned, -// 8 bit integral type is represented by an unsigned char, but the MS_REG_GPU_KERNEL -// macros defined below will create compilation errors when datatype T contains a space, -// because the variable created by the macro will also contain a space. So, we solve this -// problem by writing uchar when calling these macros, and expanding uchar after the -// variable has been created. -using uchar = unsigned char; - -inline size_t GetTensorSize(std::vector shape) { - return std::accumulate(shape.begin(), shape.end(), size_t(1), std::multiplies()); -} -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_GPU_KERNEL_H_ +/** + * Copyright 2019-2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "kernel/kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_mod.h" +#include "plugin/factory/ms_factory.h" +#include "plugin/device/gpu/kernel/kernel_constants.h" +#include "plugin/device/gpu/hal/device/gpu_device_manager.h" +#include "plugin/device/gpu/hal/device/gpu_device_address.h" +#include "plugin/device/gpu/hal/device/gpu_common.h" +#include "include/backend/anf_runtime_algorithm.h" +#include "include/common/utils/anfalgo.h" +#include "kernel/kernel_build_info.h" +#include "kernel/common_utils.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" + +using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm; + +// The max_limit of tensor shape size: 2 Giga-elements(2^31, the largest number in 32 bits). +#define SHAPE_SIZE_LIMIT 2147483648 + +namespace mindspore { +namespace kernel { +constexpr size_t kShapeIndex1st = 1; +constexpr size_t kShapeIndex2nd = 2; +constexpr size_t kShapeIndex3rd = 3; +constexpr size_t kShapeIndex4th = 4; +constexpr size_t kShapeIndex5nd = 5; +constexpr size_t kShapeIndex6rd = 6; +constexpr size_t kShapeIndex7th = 7; + +constexpr size_t kDim2DShapeSize = 4; +constexpr size_t kDim3DShapeSize = 5; +constexpr size_t kPoolingNbDims = kDim3DShapeSize; + +constexpr size_t kHelperDimsNum = 5; + +static std::map kNCHWToNHWCAxisMap = { + {0, 0}, + {1, 3}, + {2, 1}, + {3, 2}, +}; +static std::map kNHWCToNCHWAxisMap = { + {0, 0}, + {1, 2}, + {2, 3}, + {3, 1}, +}; + +static auto Anyone = [](auto &&k, auto &&... args) { return ((args == k) || ...); }; + +inline int CeilDivide(int m, int n) { return (m + n - 1) / n; } + +inline int GetPad(int input, int kernel, int stride) { + return std::max(0, (CeilDivide(input, stride) - 1) * stride + kernel - input); +} + +// Choose the suitable datatype for cudnn +inline cudnnDataType_t GetCudnnDataType(const std::string &Type) { + auto type = kCudnnDtypeMap.find(Type); + if (type == kCudnnDtypeMap.end()) { + MS_EXCEPTION(TypeError) << Type << " is not supported."; + } + return type->second; +} + +// Choose the suitable datatype for cublas +inline cudaDataType_t GetCudaDataType(const std::string &Type) { + auto type = kCudaDtypeMap.find(Type); + if (type == kCudaDtypeMap.end()) { + MS_EXCEPTION(TypeError) << Type << " is not supported."; + } + return type->second; +} + +class NativeGpuKernelMod : public GpuKernelMod { + public: + using ReduceDetail = std::tuple; + using ReducePrecisonRes = std::tuple, std::vector>; + + virtual void DestroyResource() noexcept {} + bool CheckSupport(const std::string &kernel_name, const KernelAttr &kernel_attr); + std::vector GetAllSupportedList(const std::string &kernel_name); + ReducePrecisonRes ReducePrecisionCheck(const std::string &kernel_name, const KernelAttr &kernel_attr); + static std::vector GetGpuSupportedList(const std::string &kernel_name) { + if (!Factory::Instance().IsRegistered(kernel_name)) { + return {}; + } + return Factory::Instance().Create(kernel_name)->GetAllSupportedList(kernel_name); + } + std::vector GetOpSupport() { return {}; } + static bool GpuCheckSupport(const std::string &kernel_name, const KernelAttr &kernel_attr); + + static ReducePrecisonRes GpuReducePrecisionCheck(const std::string &kernel_name, const KernelAttr &kernel_attr) { + return Factory::Instance().Create(kernel_name)->ReducePrecisionCheck(kernel_name, kernel_attr); + } + enum KernelModType GetKernelModType() const override { return KernelModType::NativeGpuKernelMod; } + + protected: + virtual void InitResource() {} + static mindspore::HashMap> support_map_; +}; + +std::vector ConvertPtrs(const std::vector &input_ptrs); + +// expand Nd Shape to 4d (N in [0,4]) +bool ShapeNdTo4d(const ShapeVector &src, ShapeVector *dst); + +template +inline T *GetPossiblyNullDeviceAddress(const std::vector &addr_list, size_t index) { + if (index >= addr_list.size()) { + MS_LOG(ERROR) << "Address index(" << index << ") out of range(" << addr_list.size() << ")"; + return nullptr; + } + // Kernels may run normally without workspace, the addr_list[index] maybe nullptr. + if ((addr_list[index] == nullptr) || (addr_list[index]->size() == 0)) { + return nullptr; + } + if (addr_list[index]->device_ptr() == nullptr) { + MS_LOG(ERROR) << "The device address is empty, address index:" << index; + return nullptr; + } + return reinterpret_cast(addr_list[index]->device_ptr()); +} +template +inline T *GetPossiblyNullDeviceAddress(const std::vector &addr_list, size_t index) { + if (index >= addr_list.size()) { + MS_LOG(ERROR) << "Address index(" << index << ") out of range(" << addr_list.size() << ")"; + return nullptr; + } + // Kernels may run normally without workspace, the addr_list[index] maybe nullptr. + if ((addr_list[index] == nullptr) || (addr_list[index]->size == 0)) { + return nullptr; + } + if (addr_list[index]->addr == nullptr) { + MS_LOG(ERROR) << "The device address is empty, address index:" << index; + return nullptr; + } + return reinterpret_cast(addr_list[index]->addr); +} + +int AxisTransform(const std::string &origin_data_format, const std::string &cal_format, int axis); + +// transpose shape: NCHW To NHWC +void ShapeNCHW2NHWC(ShapeVector *shape); + +// transpose shape: NCDHW To NDHWC +void ShapeNCDHW2NDHWC(ShapeVector *shape); + +//////////////// old: format string ///////////// +void SetDimA(const ShapeVector &shape, int *dimA, size_t len, const std::string &format); + +void SetStrideA(const ShapeVector &shape, int *strideA, size_t len, const std::string &format); + +void SetNCHW(const ShapeVector &shape, int *n, int *c, int *h, int *w, const std::string &format); + +void SetNCDHW(const ShapeVector &shape, int *n, int *c, int *d, int *h, int *w, const std::string &format); +//////////////////////////////////////////////// +//////////////// new: format enum/////////////// +void SetDimA(const ShapeVector &shape, int *dimA, size_t len, const mindspore::Format &format); + +void SetStrideA(const ShapeVector &shape, int *strideA, size_t len, const mindspore::Format &format); + +void SetNCHW(const ShapeVector &shape, int *n, int *c, int *h, int *w, const mindspore::Format &format); + +void SetNCDHW(const ShapeVector &shape, int *n, int *c, int *d, int *h, int *w, const mindspore::Format &format); +//////////////////////////////////////////////// + +bool CheckBroadcast4TensorOp(const std::vector &A, const std::vector &B, const std::vector &Out); + +// The tensor size is limited to 2G by cudnn. +bool CheckTensorSize(const std::initializer_list &shapes); + +// set the tensor descriptor for cudnn/cublas +bool CudnnSetTensorNdDescriptor(const ShapeVector &shape, cudnnTensorDescriptor_t descriptor, cudnnDataType_t data_type, + const std::string &node_name); + +// choose the suitable datatype for cudnn/cublas +bool GetCudnnDataType(const std::string &Type, cudnnDataType_t *out_type); + +bool GetCudaDataType(const std::string &Type, cudaDataType_t *out_type); + +bool ShapeEqual(const ShapeVector &s1, const ShapeVector &s2); + +template +T GetDimValue(const std::vector &inputs, const int index, const string kernel_name, + const TypeId &dim_type) { + size_t size = abstract::TypeIdSize(dim_type); + auto dim_gpu_addr = + std::make_shared(inputs[index]->device_ptr(), size, kOpFormat_DEFAULT, dim_type); + int res = 0; + if (dim_type == kNumberTypeInt32) { + int32_t host_dim = 0; + dim_gpu_addr->SyncDeviceToHost(size, &host_dim); + res = static_cast(host_dim); + } else if (dim_type == kNumberTypeInt64) { + int64_t host_dim = 0; + dim_gpu_addr->SyncDeviceToHost(size, &host_dim); + res = static_cast(host_dim); + } else { + MS_LOG(EXCEPTION) << "For '" << kernel_name << "', got unsupported data type of dim: " << dim_type; + } + return res; +} +// This is necessary for gpu kernels to support uint8 data type. In cuda, an unsigned, +// 8 bit integral type is represented by an unsigned char, but the MS_REG_GPU_KERNEL +// macros defined below will create compilation errors when datatype T contains a space, +// because the variable created by the macro will also contain a space. So, we solve this +// problem by writing uchar when calling these macros, and expanding uchar after the +// variable has been created. +using uchar = unsigned char; + +inline size_t GetTensorSize(std::vector shape) { + return std::accumulate(shape.begin(), shape.end(), size_t(1), std::multiplies()); +} +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/accumulate_n_v2_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/math/accumulate_n_v2_gpu_kernel.cc index dbb354c7c06..8d049510ede 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/math/accumulate_n_v2_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/math/accumulate_n_v2_gpu_kernel.cc @@ -1,111 +1,111 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/math/accumulate_n_v2_gpu_kernel.h" -#include -namespace mindspore { -namespace kernel { -bool AccumulateNV2GpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - if (inputs.empty() || outputs.empty()) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid."; - return false; - } - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(ERROR) << "For '" << kernel_name_ - << "', the kernel type should be in [half, float, double, int32, int8, uint8], " - << "but got: " << kernel_attr << "."; - return false; - } - kernel_func_ = func_list_[index].second; - unit_output_size_ = abstract::TypeIdSize(kernel_attr.GetOutputAttr(kIndex0).dtype); - return true; -} - -int AccumulateNV2GpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - for (const auto &input : inputs) { - // If any input shape contains -1, means input shape is dynamic, so just return do nothing. - auto input_shape = input->GetShapeVector(); - if (!IsValidShape(input_shape)) { - return KRET_UNKNOWN_SHAPE; - } - } - for (const auto &output : outputs) { - // If any input shape contains -1, means input shape is dynamic, so just return do nothing. - auto output_shape = output->GetShapeVector(); - if (!IsValidShape(output_shape)) { - return KRET_UNKNOWN_SHAPE; - } - } - ResetResource(); - std::vector output_shape = std::vector(outputs.at(kIndex0)->GetDeviceShapeVector().begin(), - outputs.at(kIndex0)->GetDeviceShapeVector().end()); - output_elements_ = std::accumulate(output_shape.begin(), output_shape.end(), int64_t(1), std::multiplies()); - if (output_elements_ == 0) { - is_null_input_ = true; - } - n_ = inputs.size(); - size_t output_size = output_elements_ * unit_output_size_; - output_size_list_.push_back(output_size); - workspace_size_list_.push_back(n_ * sizeof(void *)); - return KRET_OK; -} - -template -bool AccumulateNV2GpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs) { - T *output = GetDeviceAddress(outputs, 0); - T **inputs_array = GetDeviceAddress(workspace, 0); - std::unique_ptr inputs_host = std::make_unique(n_); - for (size_t i = 0; i < inputs.size(); i++) { - inputs_host[i] = GetDeviceAddress(inputs, i); - } - cudaStream_t stream = reinterpret_cast(cuda_stream_); - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( - cudaMemcpyAsync(inputs_array, inputs_host.get(), n_ * sizeof(T *), cudaMemcpyHostToDevice, stream), - "cudaMemcpy failed."); - auto status = CalAccumulateNV2(output_elements_, n_, inputs_array, output, device_id_, stream); - CHECK_CUDA_STATUS(status, kernel_name_); - return true; -} - -std::vector> AccumulateNV2GpuKernelMod::func_list_ = { - {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), - &AccumulateNV2GpuKernelMod::LaunchKernel}, - {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), - &AccumulateNV2GpuKernelMod::LaunchKernel}, - {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - &AccumulateNV2GpuKernelMod::LaunchKernel}, - {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - &AccumulateNV2GpuKernelMod::LaunchKernel}, - {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - &AccumulateNV2GpuKernelMod::LaunchKernel}, - {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), - &AccumulateNV2GpuKernelMod::LaunchKernel}}; - -std::vector AccumulateNV2GpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - return support_list; -} -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, AccumulateNV2, AccumulateNV2GpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/math/accumulate_n_v2_gpu_kernel.h" +#include +namespace mindspore { +namespace kernel { +bool AccumulateNV2GpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + if (inputs.empty() || outputs.empty()) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid."; + return false; + } + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ + << "', the kernel type should be in [half, float, double, int32, int8, uint8], " + << "but got: " << kernel_attr << "."; + return false; + } + kernel_func_ = func_list_[index].second; + unit_output_size_ = abstract::TypeIdSize(kernel_attr.GetOutputAttr(kIndex0).dtype); + return true; +} + +int AccumulateNV2GpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + for (const auto &input : inputs) { + // If any input shape contains -1, means input shape is dynamic, so just return do nothing. + auto input_shape = input->GetShapeVector(); + if (!IsValidShape(input_shape)) { + return KRET_UNKNOWN_SHAPE; + } + } + for (const auto &output : outputs) { + // If any input shape contains -1, means input shape is dynamic, so just return do nothing. + auto output_shape = output->GetShapeVector(); + if (!IsValidShape(output_shape)) { + return KRET_UNKNOWN_SHAPE; + } + } + ResetResource(); + std::vector output_shape = std::vector(outputs.at(kIndex0)->GetDeviceShapeVector().begin(), + outputs.at(kIndex0)->GetDeviceShapeVector().end()); + output_elements_ = std::accumulate(output_shape.begin(), output_shape.end(), int64_t(1), std::multiplies()); + if (output_elements_ == 0) { + is_null_input_ = true; + } + n_ = inputs.size(); + size_t output_size = output_elements_ * unit_output_size_; + output_size_list_.push_back(output_size); + workspace_size_list_.push_back(n_ * sizeof(void *)); + return KRET_OK; +} + +template +bool AccumulateNV2GpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + T *output = GetDeviceAddress(outputs, 0); + T **inputs_array = GetDeviceAddress(workspace, 0); + std::unique_ptr inputs_host = std::make_unique(n_); + for (size_t i = 0; i < inputs.size(); i++) { + inputs_host[i] = GetDeviceAddress(inputs, i); + } + cudaStream_t stream = reinterpret_cast(cuda_stream_); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(inputs_array, inputs_host.get(), n_ * sizeof(T *), cudaMemcpyHostToDevice, stream), + "cudaMemcpy failed."); + auto status = CalAccumulateNV2(output_elements_, n_, inputs_array, output, device_id_, stream); + CHECK_CUDA_STATUS(status, kernel_name_); + return true; +} + +std::vector> AccumulateNV2GpuKernelMod::func_list_ = { + {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), + &AccumulateNV2GpuKernelMod::LaunchKernel}, + {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), + &AccumulateNV2GpuKernelMod::LaunchKernel}, + {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + &AccumulateNV2GpuKernelMod::LaunchKernel}, + {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + &AccumulateNV2GpuKernelMod::LaunchKernel}, + {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + &AccumulateNV2GpuKernelMod::LaunchKernel}, + {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + &AccumulateNV2GpuKernelMod::LaunchKernel}}; + +std::vector AccumulateNV2GpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, AccumulateNV2, AccumulateNV2GpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/accumulate_n_v2_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/math/accumulate_n_v2_gpu_kernel.h index 7af7cf93062..01daef58dea 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/math/accumulate_n_v2_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/math/accumulate_n_v2_gpu_kernel.h @@ -1,84 +1,84 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_ACCUMULATE_N_V2_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_ACCUMULATE_N_V2_GPU_KERNEL_H_ -#include -#include -#include -#include -#include -#include -#include -#include "mindspore/core/ops/accumulate_n_v2.h" -#include "abstract/utils.h" -#include "plugin/factory/ms_factory.h" -#include "plugin/device/gpu/kernel/gpu_kernel.h" -#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/accumulate_n_v2_impl.cuh" - -namespace mindspore { -namespace kernel { -class AccumulateNV2GpuKernelMod : public NativeGpuKernelMod { - public: - AccumulateNV2GpuKernelMod() { ResetResource(); } - ~AccumulateNV2GpuKernelMod() override = default; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *cuda_stream) override { - if (is_null_input_) { - return true; - } - cuda_stream_ = cuda_stream; - return kernel_func_(this, inputs, workspace, outputs); - } - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - protected: - void ResetResource() noexcept { - n_ = 0; - output_elements_ = 0; - is_null_input_ = false; - output_size_list_.clear(); - workspace_size_list_.clear(); - } - - std::vector GetOpSupport() override; - - private: - template - bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs); - using ANVFunc = - std::function &, - const std::vector &, const std::vector &)>; - - private: - size_t n_{1}; - size_t unit_output_size_{1}; - size_t output_elements_; - ANVFunc kernel_func_{}; - bool is_null_input_{false}; - void *cuda_stream_{nullptr}; - static std::vector> func_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_ACCUMULATE_N_V2_GPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_ACCUMULATE_N_V2_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_ACCUMULATE_N_V2_GPU_KERNEL_H_ +#include +#include +#include +#include +#include +#include +#include +#include "mindspore/core/ops/accumulate_n_v2.h" +#include "abstract/utils.h" +#include "plugin/factory/ms_factory.h" +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/accumulate_n_v2_impl.cuh" + +namespace mindspore { +namespace kernel { +class AccumulateNV2GpuKernelMod : public NativeGpuKernelMod { + public: + AccumulateNV2GpuKernelMod() { ResetResource(); } + ~AccumulateNV2GpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *cuda_stream) override { + if (is_null_input_) { + return true; + } + cuda_stream_ = cuda_stream; + return kernel_func_(this, inputs, workspace, outputs); + } + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + protected: + void ResetResource() noexcept { + n_ = 0; + output_elements_ = 0; + is_null_input_ = false; + output_size_list_.clear(); + workspace_size_list_.clear(); + } + + std::vector GetOpSupport() override; + + private: + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs); + using ANVFunc = + std::function &, + const std::vector &, const std::vector &)>; + + private: + size_t n_{1}; + size_t unit_output_size_{1}; + size_t output_elements_; + ANVFunc kernel_func_{}; + bool is_null_input_{false}; + void *cuda_stream_{nullptr}; + static std::vector> func_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_ACCUMULATE_N_V2_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/heaviside_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/math/heaviside_gpu_kernel.cc index 9fa8141126a..ecd412d71d3 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/math/heaviside_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/math/heaviside_gpu_kernel.cc @@ -1,120 +1,120 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/math/heaviside_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -namespace { -template -std::unique_ptr CreateHeavisideKernelPtr(const std::string &kernel_name, - const uint32_t &device_id) { - return std::make_unique>(kernel_name, device_id); -} -using HeavisidePtrCreatorFunc = - std::function(const std::string &, const uint32_t &)>; - -const std::vector> kernel_attr = { - {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), - CreateHeavisideKernelPtr}, - {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), - CreateHeavisideKernelPtr}, - {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), - CreateHeavisideKernelPtr}, - {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), - CreateHeavisideKernelPtr}, - {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), - CreateHeavisideKernelPtr}, - {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), - CreateHeavisideKernelPtr}, - {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - CreateHeavisideKernelPtr}, - {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), - CreateHeavisideKernelPtr}, - {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - CreateHeavisideKernelPtr}, - {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - CreateHeavisideKernelPtr}, - {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), - CreateHeavisideKernelPtr}}; -} // namespace - -bool HeavisideGpuKernelMod::Launch(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - std::vector input_ptrs = ConvertPtrs(inputs); - std::vector work_ptrs = ConvertPtrs(workspace); - std::vector output_ptrs = ConvertPtrs(outputs); - if (helper_ptr_->Process(input_ptrs, output_ptrs, work_ptrs, stream_ptr) != 0) { - return false; - } - return true; -} - -bool HeavisideGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(tensor_attr, GetOpSupport()); - if (!is_match) { - return false; - } - helper_ptr_ = std::move(kernel_attr[index].second(kernel_name_, device_id_)); - MS_ERROR_IF_NULL(helper_ptr_); - return true; -} - -int HeavisideGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - for (const auto &input : inputs) { - MS_ERROR_IF_NULL_W_RET_VAL(input, KRET_RESIZE_FAILED); - // If any input shape contains -1, means input shape is dynamic, so just return do nothing. - auto input_shape = input->GetShapeVector(); - if (!IsValidShape(input_shape)) { - return KRET_UNKNOWN_SHAPE; - } - } - - auto output = outputs.at(kIndex0); - MS_ERROR_IF_NULL_W_RET_VAL(output, KRET_RESIZE_FAILED); - std::vector> input_shapes; - std::vector> output_shapes; - std::vector inpx_shape = - inputs.at(kIndex0)->GetShapeVector().empty() ? std::vector({1}) : inputs.at(kIndex0)->GetShapeVector(); - std::vector inpy_shape = - inputs.at(kIndex1)->GetShapeVector().empty() ? std::vector({1}) : inputs.at(kIndex1)->GetShapeVector(); - std::vector out_shape = - output->GetShapeVector().empty() ? std::vector({1}) : output->GetShapeVector(); - input_shapes.emplace_back(inpx_shape); - input_shapes.emplace_back(inpy_shape); - output_shapes.emplace_back(out_shape); - if (helper_ptr_->CalMemSize(input_shapes, output_shapes) == -1) { - return KRET_RESIZE_FAILED; - } - output_size_list_ = helper_ptr_->GetOutputSizeList(); - workspace_size_list_ = helper_ptr_->GetWorkSizeList(); - return KRET_OK; -} - -std::vector HeavisideGpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(kernel_attr.begin(), kernel_attr.end(), std::back_inserter(support_list), - [](const std::pair &item) { return item.first; }); - return support_list; -} - -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Heaviside, HeavisideGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/math/heaviside_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +namespace { +template +std::unique_ptr CreateHeavisideKernelPtr(const std::string &kernel_name, + const uint32_t &device_id) { + return std::make_unique>(kernel_name, device_id); +} +using HeavisidePtrCreatorFunc = + std::function(const std::string &, const uint32_t &)>; + +const std::vector> kernel_attr = { + {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), + CreateHeavisideKernelPtr}, + {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), + CreateHeavisideKernelPtr}, + {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), + CreateHeavisideKernelPtr}, + {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), + CreateHeavisideKernelPtr}, + {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), + CreateHeavisideKernelPtr}, + {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), + CreateHeavisideKernelPtr}, + {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + CreateHeavisideKernelPtr}, + {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + CreateHeavisideKernelPtr}, + {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + CreateHeavisideKernelPtr}, + {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + CreateHeavisideKernelPtr}, + {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + CreateHeavisideKernelPtr}}; +} // namespace + +bool HeavisideGpuKernelMod::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + std::vector input_ptrs = ConvertPtrs(inputs); + std::vector work_ptrs = ConvertPtrs(workspace); + std::vector output_ptrs = ConvertPtrs(outputs); + if (helper_ptr_->Process(input_ptrs, output_ptrs, work_ptrs, stream_ptr) != 0) { + return false; + } + return true; +} + +bool HeavisideGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(tensor_attr, GetOpSupport()); + if (!is_match) { + return false; + } + helper_ptr_ = std::move(kernel_attr[index].second(kernel_name_, device_id_)); + MS_ERROR_IF_NULL(helper_ptr_); + return true; +} + +int HeavisideGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + for (const auto &input : inputs) { + MS_ERROR_IF_NULL_W_RET_VAL(input, KRET_RESIZE_FAILED); + // If any input shape contains -1, means input shape is dynamic, so just return do nothing. + auto input_shape = input->GetShapeVector(); + if (!IsValidShape(input_shape)) { + return KRET_UNKNOWN_SHAPE; + } + } + + auto output = outputs.at(kIndex0); + MS_ERROR_IF_NULL_W_RET_VAL(output, KRET_RESIZE_FAILED); + std::vector> input_shapes; + std::vector> output_shapes; + std::vector inpx_shape = + inputs.at(kIndex0)->GetShapeVector().empty() ? std::vector({1}) : inputs.at(kIndex0)->GetShapeVector(); + std::vector inpy_shape = + inputs.at(kIndex1)->GetShapeVector().empty() ? std::vector({1}) : inputs.at(kIndex1)->GetShapeVector(); + std::vector out_shape = + output->GetShapeVector().empty() ? std::vector({1}) : output->GetShapeVector(); + input_shapes.emplace_back(inpx_shape); + input_shapes.emplace_back(inpy_shape); + output_shapes.emplace_back(out_shape); + if (helper_ptr_->CalMemSize(input_shapes, output_shapes) == -1) { + return KRET_RESIZE_FAILED; + } + output_size_list_ = helper_ptr_->GetOutputSizeList(); + workspace_size_list_ = helper_ptr_->GetWorkSizeList(); + return KRET_OK; +} + +std::vector HeavisideGpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(kernel_attr.begin(), kernel_attr.end(), std::back_inserter(support_list), + [](const std::pair &item) { return item.first; }); + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Heaviside, HeavisideGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/heaviside_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/math/heaviside_gpu_kernel.h index 211311b6859..73e8b70eee5 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/math/heaviside_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/math/heaviside_gpu_kernel.h @@ -1,57 +1,57 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_HEAVISIDE_GPU_KERNEL_H -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_HEAVISIDE_GPU_KERNEL_H -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "mindspore/core/ops/heaviside.h" -#include "plugin/device/gpu/kernel/gpu_kernel.h" -#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/heaviside_helper.h" - -namespace mindspore { -namespace kernel { -class HeavisideGpuKernelMod : public NativeGpuKernelMod { - public: - HeavisideGpuKernelMod() {} - ~HeavisideGpuKernelMod() = default; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - std::vector GetOpSupport() override; - - private: - std::unique_ptr helper_ptr_{nullptr}; -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_HEAVISIDE_GPU_KERNEL_H +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_HEAVISIDE_GPU_KERNEL_H +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_HEAVISIDE_GPU_KERNEL_H +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "mindspore/core/ops/heaviside.h" +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/heaviside_helper.h" + +namespace mindspore { +namespace kernel { +class HeavisideGpuKernelMod : public NativeGpuKernelMod { + public: + HeavisideGpuKernelMod() {} + ~HeavisideGpuKernelMod() = default; + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + std::vector GetOpSupport() override; + + private: + std::unique_ptr helper_ptr_{nullptr}; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_HEAVISIDE_GPU_KERNEL_H diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/hypot_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/math/hypot_gpu_kernel.cc index cf575b00b67..f650a93a967 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/math/hypot_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/math/hypot_gpu_kernel.cc @@ -1,95 +1,95 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/math/hypot_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -namespace { -template -std::unique_ptr CreateHypotKernelPtr(const std::string &kernel_name, - const uint32_t &device_id) { - return std::make_unique>(kernel_name, device_id); -} -using HypotPtrCreatorFunc = - std::function(const std::string &, const uint32_t &)>; - -const std::vector> kernel_attr = { - {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - CreateHypotKernelPtr}, - {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), - CreateHypotKernelPtr}}; -} // namespace - -bool HypotGpuKernelMod::Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - std::vector input_ptrs = ConvertPtrs(inputs); - std::vector work_ptrs = ConvertPtrs(workspace); - std::vector output_ptrs = ConvertPtrs(outputs); - if (helper_ptr_->Process(input_ptrs, output_ptrs, work_ptrs, stream_ptr) != 0) { - return false; - } - return true; -} - -bool HypotGpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { - auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(tensor_attr, GetOpSupport()); - if (!is_match) { - return false; - } - helper_ptr_ = std::move(kernel_attr[index].second(kernel_name_, device_id_)); - MS_ERROR_IF_NULL(helper_ptr_); - return true; -} - -int HypotGpuKernelMod::Resize(const std::vector &inputs, const std::vector &outputs) { - for (const auto &input : inputs) { - MS_ERROR_IF_NULL_W_RET_VAL(input, KRET_RESIZE_FAILED); - // If any input shape contains -1, means input shape is dynamic, so just return do nothing. - auto input_shape = input->GetShapeVector(); - if (!IsValidShape(input_shape)) { - return KRET_UNKNOWN_SHAPE; - } - } - auto output = outputs.at(kIndex0); - MS_ERROR_IF_NULL_W_RET_VAL(output, KRET_RESIZE_FAILED); - std::vector> input_shapes; - std::vector> output_shapes; - std::vector inpx_shape = inputs.at(kIndex0)->GetShapeVector(); - std::vector inpy_shape = inputs.at(kIndex1)->GetShapeVector(); - std::vector out_shape = output->GetShapeVector(); - input_shapes.emplace_back(inpx_shape); - input_shapes.emplace_back(inpy_shape); - output_shapes.emplace_back(out_shape); - if (helper_ptr_->CalMemSize(input_shapes, output_shapes) == -1) { - return KRET_RESIZE_FAILED; - } - output_size_list_ = helper_ptr_->GetOutputSizeList(); - workspace_size_list_ = helper_ptr_->GetWorkSizeList(); - return KRET_OK; -} - -std::vector HypotGpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(kernel_attr.begin(), kernel_attr.end(), std::back_inserter(support_list), - [](const std::pair &item) { return item.first; }); - return support_list; -} - -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Hypot, HypotGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/math/hypot_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +namespace { +template +std::unique_ptr CreateHypotKernelPtr(const std::string &kernel_name, + const uint32_t &device_id) { + return std::make_unique>(kernel_name, device_id); +} +using HypotPtrCreatorFunc = + std::function(const std::string &, const uint32_t &)>; + +const std::vector> kernel_attr = { + {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + CreateHypotKernelPtr}, + {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + CreateHypotKernelPtr}}; +} // namespace + +bool HypotGpuKernelMod::Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + std::vector input_ptrs = ConvertPtrs(inputs); + std::vector work_ptrs = ConvertPtrs(workspace); + std::vector output_ptrs = ConvertPtrs(outputs); + if (helper_ptr_->Process(input_ptrs, output_ptrs, work_ptrs, stream_ptr) != 0) { + return false; + } + return true; +} + +bool HypotGpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { + auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(tensor_attr, GetOpSupport()); + if (!is_match) { + return false; + } + helper_ptr_ = std::move(kernel_attr[index].second(kernel_name_, device_id_)); + MS_ERROR_IF_NULL(helper_ptr_); + return true; +} + +int HypotGpuKernelMod::Resize(const std::vector &inputs, const std::vector &outputs) { + for (const auto &input : inputs) { + MS_ERROR_IF_NULL_W_RET_VAL(input, KRET_RESIZE_FAILED); + // If any input shape contains -1, means input shape is dynamic, so just return do nothing. + auto input_shape = input->GetShapeVector(); + if (!IsValidShape(input_shape)) { + return KRET_UNKNOWN_SHAPE; + } + } + auto output = outputs.at(kIndex0); + MS_ERROR_IF_NULL_W_RET_VAL(output, KRET_RESIZE_FAILED); + std::vector> input_shapes; + std::vector> output_shapes; + std::vector inpx_shape = inputs.at(kIndex0)->GetShapeVector(); + std::vector inpy_shape = inputs.at(kIndex1)->GetShapeVector(); + std::vector out_shape = output->GetShapeVector(); + input_shapes.emplace_back(inpx_shape); + input_shapes.emplace_back(inpy_shape); + output_shapes.emplace_back(out_shape); + if (helper_ptr_->CalMemSize(input_shapes, output_shapes) == -1) { + return KRET_RESIZE_FAILED; + } + output_size_list_ = helper_ptr_->GetOutputSizeList(); + workspace_size_list_ = helper_ptr_->GetWorkSizeList(); + return KRET_OK; +} + +std::vector HypotGpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(kernel_attr.begin(), kernel_attr.end(), std::back_inserter(support_list), + [](const std::pair &item) { return item.first; }); + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Hypot, HypotGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/hypot_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/math/hypot_gpu_kernel.h index 224b82f46e0..31e7d55c5bb 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/math/hypot_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/math/hypot_gpu_kernel.h @@ -1,57 +1,57 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_HYPOT_GPU_KERNEL_H -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_HYPOT_GPU_KERNEL_H -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "mindspore/core/ops/hypot.h" -#include "plugin/device/gpu/kernel/gpu_kernel.h" -#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/hypot_helper.h" - -namespace mindspore { -namespace kernel { -class HypotGpuKernelMod : public NativeGpuKernelMod { - public: - HypotGpuKernelMod() {} - ~HypotGpuKernelMod() = default; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - std::vector GetOpSupport() override; - - private: - std::unique_ptr helper_ptr_{nullptr}; -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_HYPOT_GPU_KERNEL_H +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_HYPOT_GPU_KERNEL_H +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_HYPOT_GPU_KERNEL_H +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "mindspore/core/ops/hypot.h" +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/hypot_helper.h" + +namespace mindspore { +namespace kernel { +class HypotGpuKernelMod : public NativeGpuKernelMod { + public: + HypotGpuKernelMod() {} + ~HypotGpuKernelMod() = default; + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + std::vector GetOpSupport() override; + + private: + std::unique_ptr helper_ptr_{nullptr}; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_HYPOT_GPU_KERNEL_H diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/lu_scipy_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/math/lu_scipy_gpu_kernel.cc index 3a86f329180..82352af1847 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/math/lu_scipy_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/math/lu_scipy_gpu_kernel.cc @@ -1,36 +1,36 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/math/lu_scipy_gpu_kernel.h" -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(LU, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt32), - LUGpuKernelMod, float) - -MS_REG_GPU_KERNEL_ONE(LU, - KernelAttr() - .AddInputAttr(kNumberTypeFloat64) - .AddOutputAttr(kNumberTypeFloat64) - .AddOutputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt32), - LUGpuKernelMod, double) -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/math/lu_scipy_gpu_kernel.h" +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(LU, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + LUGpuKernelMod, float) + +MS_REG_GPU_KERNEL_ONE(LU, + KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + LUGpuKernelMod, double) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/lu_scipy_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/math/lu_scipy_gpu_kernel.h index eef63960166..b49c9cfd026 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/math/lu_scipy_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/math/lu_scipy_gpu_kernel.h @@ -1,256 +1,256 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_LU_SCIPY_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_LU_SCIPY_GPU_KERNEL_H_ -#include -#include -#include -#include -#include -#include -#include -#include "include/common/utils/convert_utils.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl.cuh" -#include "plugin/device/gpu/kernel/gpu_kernel.h" -#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" -#include "plugin/device/gpu/kernel/kernel_constants.h" - -namespace mindspore { -namespace kernel { -template -class LUGpuKernelMod : public NativeGpuKernelMod { - public: - LUGpuKernelMod() : is_null_input_(false) {} - ~LUGpuKernelMod() = default; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnSetStream(handle_, reinterpret_cast(stream_ptr)), - "cusolverDnSetStream failed"); - T *batch_input_addr = GetDeviceAddress(inputs, kDim0); - T *batch_output_addr = GetDeviceAddress(outputs, kDim0); - int *batch_piv_output_addr = nullptr; - if (pivot_on_) { - batch_piv_output_addr = GetDeviceAddress(outputs, kDim1); - } - // workspace - int *batch_permutation_addr = GetDeviceAddress(outputs, kDim2); - int *info_output_addr = GetDeviceAddress(workspace, kDim0); - T *dev_transpose_work = GetDeviceAddress(workspace, kDim1); - - TransposeInfo info; - TransposeInfo work_info; - info.input_shape = std::vector{m_, n_}; - info.perm = std::vector{1, 0}; - work_info.input_shape = std::vector{n_, m_}; - work_info.perm = std::vector{1, 0}; - - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( - cudaMemcpyAsync(batch_output_addr, batch_input_addr, batch_size_ * m_ * n_ * unit_size_, cudaMemcpyDeviceToDevice, - reinterpret_cast(stream_ptr)), - "cudaMemcpyAsync failed in LUGpuKernelMod::Launch."); - - // 4. query working space of getrf - if constexpr (std::is_same_v) { - CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE( - cusolverDnSgetrf_bufferSize(handle_, m_, n_, batch_output_addr, lda_, &lwork_), - "cusolver query lu work size fail"); - } else if constexpr (std::is_same_v) { - CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE( - cusolverDnDgetrf_bufferSize(handle_, m_, n_, batch_output_addr, lda_, &lwork_), - "cusolver query lu work size fail"); - } else { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the data type only should be float or double, right now."; - } - // 5. malloc device working space of getrf - d_work_ = reinterpret_cast(device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(unit_size_ * lwork_)); - for (size_t batch = 0; batch < batch_size_; ++batch) { - T *output_addr = batch_output_addr + batch * m_ * n_; - int *permutation_addr = batch_permutation_addr + batch * k_ * k_; - int *piv_output_addr = batch_piv_output_addr + batch * k_; - auto s1 = CalTranspose(m_ * n_, output_addr, info, dev_transpose_work, - reinterpret_cast(stream_ptr)); - CHECK_CUDA_STATUS(s1, "Transpose called by " + kernel_name_); - - // 6.lu factorization according to cuSolver api, outputs have been written to input's matrix. - if constexpr (std::is_same_v) { - CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE( - - cusolverDnSgetrf(handle_, m_, n_, dev_transpose_work, lda_, d_work_, piv_output_addr, info_output_addr), - "cusolver lu fail"); - } else if constexpr (std::is_same_v) { - // 6.lu factorization according to cuSolver api, outputs have been written to input's matrix. - CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE( - cusolverDnDgetrf(handle_, m_, n_, dev_transpose_work, lda_, d_work_, piv_output_addr, info_output_addr), - "cusolver lu fail"); - } else { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the data type only should be float or double, right now."; - } - auto s2 = CalTranspose(m_ * n_, dev_transpose_work, work_info, output_addr, - reinterpret_cast(stream_ptr)); - CHECK_CUDA_STATUS(s2, "Transpose called by " + kernel_name_); - std::vector host_permuted(k_, 0); - std::vector host_pivots(k_, 0); - std::vector host_permutation(k_ * k_, 0); - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( - cudaMemcpyAsync(host_pivots.data(), piv_output_addr, sizeof(int) * k_, cudaMemcpyDeviceToHost, - reinterpret_cast(stream_ptr)), - "For 'LuScipy', cudaMemcpyAsync failed in LUGpuKernelMod::Launch copy pivots to host."); - if (cudaStreamQuery(reinterpret_cast(stream_ptr)) != cudaSuccess) { - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(reinterpret_cast(stream_ptr)), - "For 'LuScipy', cuda Stream Sync Failed."); - } - - // cal pivots && permutation major by row. - for (size_t i = 0; i < k_; ++i) { - host_pivots[i] -= 1; - host_permuted[i] = i; - } - for (size_t i = 0; i < k_; ++i) { - int tmp_value = host_permuted[i]; - host_permuted[i] = host_permuted[host_pivots[i]]; - host_permuted[host_pivots[i]] = tmp_value; - } - // gpu default is P.A = LU, so here is col swap. - for (size_t i = 0; i < k_; ++i) { - host_permutation[host_permuted[i] * k_ + i] = 1; - } - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( - cudaMemcpyAsync(permutation_addr, host_permutation.data(), sizeof(int) * k_ * k_, cudaMemcpyHostToDevice, - reinterpret_cast(stream_ptr)), - "cudaMemcpyAsync failed in LUGpuKernelMod::Launch copy permutation matrix."); - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( - cudaMemcpyAsync(piv_output_addr, host_pivots.data(), sizeof(int) * k_, cudaMemcpyHostToDevice, - reinterpret_cast(stream_ptr)), - "cudaMemcpyAsync failed in LUGpuKernelMod::Launch copy pivots array."); - } - device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(d_work_); - return true; - } - - bool Init(const std::vector &inputs, const std::vector &outputs) override { - handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCusolverDnHandle(); - return true; - } - - int Resize(const std::vector &inputs, const std::vector &outputs) override { - if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { - return ret; - } - batch_size_ = 1; - auto shape_signed = inputs[kIndex0]->GetShapeVector(); - auto in_shape = Convert2SizeT(shape_signed); - // 2. check input shape not null - is_null_input_ = CHECK_SHAPE_NULL(in_shape, kernel_name_, "input"); - if (is_null_input_) { - InitSizeLists(); - return KRET_OK; - } - // 3. calculate input size - if (!InitInputSize(in_shape)) { - MS_LOG(ERROR) << "For 'PureCholeskyGpuKernel', input shape init failed."; - return KRET_RESIZE_FAILED; - } - return KRET_OK; - } - - std::vector GetOpSupport() override { - static std::vector support_list = { - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt32), - KernelAttr() - .AddInputAttr(kNumberTypeFloat64) - .AddOutputAttr(kNumberTypeFloat64) - .AddOutputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt32), - }; - return support_list; - } - - private: - bool InitInputSize(const std::vector &in_shape) { - constexpr size_t lu_min_dim = 1; - if (in_shape.size() <= lu_min_dim) { - MS_LOG_EXCEPTION << kernel_name_ << " input shape is " << in_shape.size() << " which is invalid."; - } - constexpr size_t lu_reverse_row_dim = 2; - lu_row_ = in_shape.at(in_shape.size() - lu_reverse_row_dim); - lu_col_ = in_shape.at(in_shape.size() - 1); - batch_size_ = lu_min_dim; - for (int batch = 0; batch < static_cast(in_shape.size() - lu_reverse_row_dim); ++batch) { - batch_size_ *= in_shape.at(batch); - } - // set matrix row or col to be lead dimension - m_ = SizeToInt(lu_row_); - n_ = SizeToInt(lu_col_); - k_ = std::min(lu_row_, lu_col_); - lda_ = m_; - ldb_ = n_; - InitSizeLists(); - return true; - } - - void InitSizeLists() { - size_t output_size = batch_size_ * lu_row_ * lu_col_ * unit_size_; - - size_t output_piv_size = 0; - if (pivot_on_) { - output_piv_size = batch_size_ * k_ * sizeof(int); - } - size_t output_permutation_size = batch_size_ * k_ * k_ * sizeof(int); - output_size_list_.resize(kDim3); - output_size_list_[kDim0] = output_size; - output_size_list_[kDim1] = output_piv_size; - output_size_list_[kDim2] = output_permutation_size; - - // a device addr to place lu factor return code - workspace_size_list_.push_back(sizeof(int)); - - // transpose 2d matrix scalar args workspace - constexpr size_t shape_2d = 2; - workspace_size_list_.push_back(shape_2d * sizeof(size_t)); - workspace_size_list_.push_back(shape_2d * sizeof(size_t)); - - // transpose workspace - workspace_size_list_.push_back(m_ * n_ * unit_size_); - } - - size_t unit_size_{sizeof(T)}; - size_t batch_size_{1}; - size_t lu_row_{0}; - size_t lu_col_{0}; - size_t k_{0}; - int64_t m_{0}; - int64_t n_{0}; - size_t lda_{0}; - size_t ldb_{0}; - int lwork_{0}; - bool pivot_on_{true}; - T *d_work_{nullptr}; - cusolverDnHandle_t handle_{nullptr}; - bool is_null_input_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_LU_SCIPY_GPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_LU_SCIPY_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_LU_SCIPY_GPU_KERNEL_H_ +#include +#include +#include +#include +#include +#include +#include +#include "include/common/utils/convert_utils.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl.cuh" +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" +#include "plugin/device/gpu/kernel/kernel_constants.h" + +namespace mindspore { +namespace kernel { +template +class LUGpuKernelMod : public NativeGpuKernelMod { + public: + LUGpuKernelMod() : is_null_input_(false) {} + ~LUGpuKernelMod() = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } + CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnSetStream(handle_, reinterpret_cast(stream_ptr)), + "cusolverDnSetStream failed"); + T *batch_input_addr = GetDeviceAddress(inputs, kDim0); + T *batch_output_addr = GetDeviceAddress(outputs, kDim0); + int *batch_piv_output_addr = nullptr; + if (pivot_on_) { + batch_piv_output_addr = GetDeviceAddress(outputs, kDim1); + } + // workspace + int *batch_permutation_addr = GetDeviceAddress(outputs, kDim2); + int *info_output_addr = GetDeviceAddress(workspace, kDim0); + T *dev_transpose_work = GetDeviceAddress(workspace, kDim1); + + TransposeInfo info; + TransposeInfo work_info; + info.input_shape = std::vector{m_, n_}; + info.perm = std::vector{1, 0}; + work_info.input_shape = std::vector{n_, m_}; + work_info.perm = std::vector{1, 0}; + + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(batch_output_addr, batch_input_addr, batch_size_ * m_ * n_ * unit_size_, cudaMemcpyDeviceToDevice, + reinterpret_cast(stream_ptr)), + "cudaMemcpyAsync failed in LUGpuKernelMod::Launch."); + + // 4. query working space of getrf + if constexpr (std::is_same_v) { + CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE( + cusolverDnSgetrf_bufferSize(handle_, m_, n_, batch_output_addr, lda_, &lwork_), + "cusolver query lu work size fail"); + } else if constexpr (std::is_same_v) { + CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE( + cusolverDnDgetrf_bufferSize(handle_, m_, n_, batch_output_addr, lda_, &lwork_), + "cusolver query lu work size fail"); + } else { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the data type only should be float or double, right now."; + } + // 5. malloc device working space of getrf + d_work_ = reinterpret_cast(device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(unit_size_ * lwork_)); + for (size_t batch = 0; batch < batch_size_; ++batch) { + T *output_addr = batch_output_addr + batch * m_ * n_; + int *permutation_addr = batch_permutation_addr + batch * k_ * k_; + int *piv_output_addr = batch_piv_output_addr + batch * k_; + auto s1 = CalTranspose(m_ * n_, output_addr, info, dev_transpose_work, + reinterpret_cast(stream_ptr)); + CHECK_CUDA_STATUS(s1, "Transpose called by " + kernel_name_); + + // 6.lu factorization according to cuSolver api, outputs have been written to input's matrix. + if constexpr (std::is_same_v) { + CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE( + + cusolverDnSgetrf(handle_, m_, n_, dev_transpose_work, lda_, d_work_, piv_output_addr, info_output_addr), + "cusolver lu fail"); + } else if constexpr (std::is_same_v) { + // 6.lu factorization according to cuSolver api, outputs have been written to input's matrix. + CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE( + cusolverDnDgetrf(handle_, m_, n_, dev_transpose_work, lda_, d_work_, piv_output_addr, info_output_addr), + "cusolver lu fail"); + } else { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the data type only should be float or double, right now."; + } + auto s2 = CalTranspose(m_ * n_, dev_transpose_work, work_info, output_addr, + reinterpret_cast(stream_ptr)); + CHECK_CUDA_STATUS(s2, "Transpose called by " + kernel_name_); + std::vector host_permuted(k_, 0); + std::vector host_pivots(k_, 0); + std::vector host_permutation(k_ * k_, 0); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(host_pivots.data(), piv_output_addr, sizeof(int) * k_, cudaMemcpyDeviceToHost, + reinterpret_cast(stream_ptr)), + "For 'LuScipy', cudaMemcpyAsync failed in LUGpuKernelMod::Launch copy pivots to host."); + if (cudaStreamQuery(reinterpret_cast(stream_ptr)) != cudaSuccess) { + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(reinterpret_cast(stream_ptr)), + "For 'LuScipy', cuda Stream Sync Failed."); + } + + // cal pivots && permutation major by row. + for (size_t i = 0; i < k_; ++i) { + host_pivots[i] -= 1; + host_permuted[i] = i; + } + for (size_t i = 0; i < k_; ++i) { + int tmp_value = host_permuted[i]; + host_permuted[i] = host_permuted[host_pivots[i]]; + host_permuted[host_pivots[i]] = tmp_value; + } + // gpu default is P.A = LU, so here is col swap. + for (size_t i = 0; i < k_; ++i) { + host_permutation[host_permuted[i] * k_ + i] = 1; + } + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(permutation_addr, host_permutation.data(), sizeof(int) * k_ * k_, cudaMemcpyHostToDevice, + reinterpret_cast(stream_ptr)), + "cudaMemcpyAsync failed in LUGpuKernelMod::Launch copy permutation matrix."); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(piv_output_addr, host_pivots.data(), sizeof(int) * k_, cudaMemcpyHostToDevice, + reinterpret_cast(stream_ptr)), + "cudaMemcpyAsync failed in LUGpuKernelMod::Launch copy pivots array."); + } + device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(d_work_); + return true; + } + + bool Init(const std::vector &inputs, const std::vector &outputs) override { + handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCusolverDnHandle(); + return true; + } + + int Resize(const std::vector &inputs, const std::vector &outputs) override { + if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { + return ret; + } + batch_size_ = 1; + auto shape_signed = inputs[kIndex0]->GetShapeVector(); + auto in_shape = Convert2SizeT(shape_signed); + // 2. check input shape not null + is_null_input_ = CHECK_SHAPE_NULL(in_shape, kernel_name_, "input"); + if (is_null_input_) { + InitSizeLists(); + return KRET_OK; + } + // 3. calculate input size + if (!InitInputSize(in_shape)) { + MS_LOG(ERROR) << "For 'PureCholeskyGpuKernel', input shape init failed."; + return KRET_RESIZE_FAILED; + } + return KRET_OK; + } + + std::vector GetOpSupport() override { + static std::vector support_list = { + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + }; + return support_list; + } + + private: + bool InitInputSize(const std::vector &in_shape) { + constexpr size_t lu_min_dim = 1; + if (in_shape.size() <= lu_min_dim) { + MS_LOG_EXCEPTION << kernel_name_ << " input shape is " << in_shape.size() << " which is invalid."; + } + constexpr size_t lu_reverse_row_dim = 2; + lu_row_ = in_shape.at(in_shape.size() - lu_reverse_row_dim); + lu_col_ = in_shape.at(in_shape.size() - 1); + batch_size_ = lu_min_dim; + for (int batch = 0; batch < static_cast(in_shape.size() - lu_reverse_row_dim); ++batch) { + batch_size_ *= in_shape.at(batch); + } + // set matrix row or col to be lead dimension + m_ = SizeToInt(lu_row_); + n_ = SizeToInt(lu_col_); + k_ = std::min(lu_row_, lu_col_); + lda_ = m_; + ldb_ = n_; + InitSizeLists(); + return true; + } + + void InitSizeLists() { + size_t output_size = batch_size_ * lu_row_ * lu_col_ * unit_size_; + + size_t output_piv_size = 0; + if (pivot_on_) { + output_piv_size = batch_size_ * k_ * sizeof(int); + } + size_t output_permutation_size = batch_size_ * k_ * k_ * sizeof(int); + output_size_list_.resize(kDim3); + output_size_list_[kDim0] = output_size; + output_size_list_[kDim1] = output_piv_size; + output_size_list_[kDim2] = output_permutation_size; + + // a device addr to place lu factor return code + workspace_size_list_.push_back(sizeof(int)); + + // transpose 2d matrix scalar args workspace + constexpr size_t shape_2d = 2; + workspace_size_list_.push_back(shape_2d * sizeof(size_t)); + workspace_size_list_.push_back(shape_2d * sizeof(size_t)); + + // transpose workspace + workspace_size_list_.push_back(m_ * n_ * unit_size_); + } + + size_t unit_size_{sizeof(T)}; + size_t batch_size_{1}; + size_t lu_row_{0}; + size_t lu_col_{0}; + size_t k_{0}; + int64_t m_{0}; + int64_t n_{0}; + size_t lda_{0}; + size_t ldb_{0}; + int lwork_{0}; + bool pivot_on_{true}; + T *d_work_{nullptr}; + cusolverDnHandle_t handle_{nullptr}; + bool is_null_input_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_LU_SCIPY_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/multinomial_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/math/multinomial_gpu_kernel.cc index 8c1d0560eed..02a1200b048 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/math/multinomial_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/math/multinomial_gpu_kernel.cc @@ -1,150 +1,150 @@ -/** - * Copyright 2020-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/math/multinomial_gpu_kernel.h" -#include "kernel/philox_random.h" - -namespace mindspore { -namespace kernel { -namespace { -static constexpr size_t input_num_ = 2; -static constexpr size_t output_num_ = 1; -} // namespace -bool MultinomialGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), input_num_, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), output_num_, kernel_name_); - - uint64_t seed = static_cast(GetValue(primitive_->GetAttr("seed"))); - uint64_t seed2 = static_cast(GetValue(primitive_->GetAttr("seed2"))); - seed_ = random::GetSeed(seed, seed2); - - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr; - } - launch_func_ = func_list_[index].second; - return true; -} - -int MultinomialGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - workspace_size_list_.clear(); - int ret = KernelMod::Resize(inputs, outputs); - if (ret != KRET_OK) { - return ret; - } - auto input_shape_0 = Convert2SizeTClipNeg(inputs[0]->GetShapeVector()); - if (input_shape_0.size() <= 0) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', input0.shape is empty."; - } - if (input_shape_0.size() == 1) { - distributions_ = 1; - categories_ = input_shape_0[0]; - } else { - distributions_ = input_shape_0[0]; - categories_ = input_shape_0[1]; - } - auto &allocator = device::gpu::GPUMemoryAllocator::GetInstance(); - rand_state_ = static_cast(allocator.AllocTensorMem(sizeof(curandState) * distributions_)); - return ret; -} - -bool MultinomialGpuKernelMod::Launch(const std::vector &inputs, - const std::vector &, - const std::vector &outputs, void *stream_ptr) { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), input_num_, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), output_num_, kernel_name_); - - launch_func_(this, inputs, outputs, stream_ptr); - return true; -} - -template -void MultinomialGpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &outputs, void *stream_ptr) { - T *probs_addr = GetDeviceAddress(inputs, 0); - S *output_addr = GetDeviceAddress(outputs, 0); - int64_t *num_sample_addr = GetDeviceAddress(inputs, 1); - if (distributions_ == 0) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', divide by zero. the distributions_ is 0."; - } - cudaError_t status = cudaErrorNotReady; - auto stream = reinterpret_cast(stream_ptr); - status = InitRandState(seed_, seed_offset_, distributions_, rand_state_, stream); - CHECK_CUDA_STATUS(status, "InitRandState called by " + kernel_name_); - status = Multinomial(distributions_, categories_, probs_addr, rand_state_, num_sample_addr, output_addr, stream); - CHECK_CUDA_STATUS(status, kernel_name_); - seed_offset_ += 1; -} - -std::vector> MultinomialGpuKernelMod::func_list_ = { - {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), - &MultinomialGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), - &MultinomialGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), - &MultinomialGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), - &MultinomialGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), - &MultinomialGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), - &MultinomialGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), - &MultinomialGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), - &MultinomialGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), - &MultinomialGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), - &MultinomialGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), - &MultinomialGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), - &MultinomialGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), - &MultinomialGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), - &MultinomialGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), - &MultinomialGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), - &MultinomialGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), - &MultinomialGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), - &MultinomialGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), - &MultinomialGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), - &MultinomialGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), - &MultinomialGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), - &MultinomialGpuKernelMod::LaunchKernel}}; - -std::vector MultinomialGpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - return support_list; -} - -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Multinomial, MultinomialGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2020-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/math/multinomial_gpu_kernel.h" +#include "kernel/philox_random.h" + +namespace mindspore { +namespace kernel { +namespace { +static constexpr size_t input_num_ = 2; +static constexpr size_t output_num_ = 1; +} // namespace +bool MultinomialGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), input_num_, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), output_num_, kernel_name_); + + uint64_t seed = static_cast(GetValue(primitive_->GetAttr("seed"))); + uint64_t seed2 = static_cast(GetValue(primitive_->GetAttr("seed2"))); + seed_ = random::GetSeed(seed, seed2); + + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr; + } + launch_func_ = func_list_[index].second; + return true; +} + +int MultinomialGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + workspace_size_list_.clear(); + int ret = KernelMod::Resize(inputs, outputs); + if (ret != KRET_OK) { + return ret; + } + auto input_shape_0 = Convert2SizeTClipNeg(inputs[0]->GetShapeVector()); + if (input_shape_0.size() <= 0) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', input0.shape is empty."; + } + if (input_shape_0.size() == 1) { + distributions_ = 1; + categories_ = input_shape_0[0]; + } else { + distributions_ = input_shape_0[0]; + categories_ = input_shape_0[1]; + } + auto &allocator = device::gpu::GPUMemoryAllocator::GetInstance(); + rand_state_ = static_cast(allocator.AllocTensorMem(sizeof(curandState) * distributions_)); + return ret; +} + +bool MultinomialGpuKernelMod::Launch(const std::vector &inputs, + const std::vector &, + const std::vector &outputs, void *stream_ptr) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), input_num_, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), output_num_, kernel_name_); + + launch_func_(this, inputs, outputs, stream_ptr); + return true; +} + +template +void MultinomialGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &outputs, void *stream_ptr) { + T *probs_addr = GetDeviceAddress(inputs, 0); + S *output_addr = GetDeviceAddress(outputs, 0); + int64_t *num_sample_addr = GetDeviceAddress(inputs, 1); + if (distributions_ == 0) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', divide by zero. the distributions_ is 0."; + } + cudaError_t status = cudaErrorNotReady; + auto stream = reinterpret_cast(stream_ptr); + status = InitRandState(seed_, seed_offset_, distributions_, rand_state_, stream); + CHECK_CUDA_STATUS(status, "InitRandState called by " + kernel_name_); + status = Multinomial(distributions_, categories_, probs_addr, rand_state_, num_sample_addr, output_addr, stream); + CHECK_CUDA_STATUS(status, kernel_name_); + seed_offset_ += 1; +} + +std::vector> MultinomialGpuKernelMod::func_list_ = { + {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), + &MultinomialGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), + &MultinomialGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), + &MultinomialGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), + &MultinomialGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), + &MultinomialGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), + &MultinomialGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), + &MultinomialGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), + &MultinomialGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), + &MultinomialGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), + &MultinomialGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), + &MultinomialGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + &MultinomialGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + &MultinomialGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + &MultinomialGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + &MultinomialGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + &MultinomialGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + &MultinomialGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + &MultinomialGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + &MultinomialGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + &MultinomialGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + &MultinomialGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + &MultinomialGpuKernelMod::LaunchKernel}}; + +std::vector MultinomialGpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Multinomial, MultinomialGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/polar_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/math/polar_gpu_kernel.cc index 405728c8588..24743445870 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/math/polar_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/math/polar_gpu_kernel.cc @@ -1,103 +1,103 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/math/polar_gpu_kernel.h" -#include - -namespace mindspore { -namespace kernel { -namespace { -constexpr size_t kZero = 0; -constexpr size_t kOne = 1; -constexpr size_t kTwo = 2; -} // namespace - -bool PolarGpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { - if (inputs.empty() || outputs.empty()) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid."; - return false; - } - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(ERROR) << "For '" << kernel_name_ - << "', valid gpu kernel does not support this kernel type: " << kernel_attr; - return false; - } - kernel_func_ = func_list_[index].second; - unit_input_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).dtype); - unit_output_size_ = abstract::TypeIdSize(kernel_attr.GetOutputAttr(kIndex0).dtype); - return true; -} - -int PolarGpuKernelMod::Resize(const std::vector &inputs, const std::vector &outputs) { - for (const auto &input : inputs) { - // If any input shape contains -1, means input shape is dynamic, so just return do nothing. - auto input_shape = input->GetShapeVector(); - if (!IsValidShape(input_shape)) { - return KRET_UNKNOWN_SHAPE; - } - } - for (const auto &output : outputs) { - // If any input shape contains -1, means input shape is dynamic, so just return do nothing. - auto output_shape = output->GetShapeVector(); - if (!IsValidShape(output_shape)) { - return KRET_UNKNOWN_SHAPE; - } - } - ResetResource(); - std::vector output_shape = std::vector(outputs.at(kIndex0)->GetDeviceShapeVector().begin(), - outputs.at(kIndex0)->GetDeviceShapeVector().end()); - output_elements_ = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); - if (output_elements_ == 0) { - is_null_input_ = true; - } - output_size_list_.push_back(output_elements_ * unit_output_size_); - - return KRET_OK; -} - -template -bool PolarGpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs) { - T *abs = GetDeviceAddress(inputs, 0); - T *angle = GetDeviceAddress(inputs, 1); - S *output = GetDeviceAddress(outputs, 0); - auto status = - CalPolar(output_elements_, abs, angle, output, device_id_, reinterpret_cast(cuda_stream_)); - CHECK_CUDA_STATUS(status, kernel_name_); - return True; -} - -template -using Complex = mindspore::utils::Complex; - -std::vector> PolarGpuKernelMod::func_list_ = { - {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeComplex64), - &PolarGpuKernelMod::LaunchKernel>}, - {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeComplex128), - &PolarGpuKernelMod::LaunchKernel>}}; - -std::vector PolarGpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - return support_list; -} -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Polar, PolarGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/math/polar_gpu_kernel.h" +#include + +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kZero = 0; +constexpr size_t kOne = 1; +constexpr size_t kTwo = 2; +} // namespace + +bool PolarGpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { + if (inputs.empty() || outputs.empty()) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid."; + return false; + } + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ + << "', valid gpu kernel does not support this kernel type: " << kernel_attr; + return false; + } + kernel_func_ = func_list_[index].second; + unit_input_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).dtype); + unit_output_size_ = abstract::TypeIdSize(kernel_attr.GetOutputAttr(kIndex0).dtype); + return true; +} + +int PolarGpuKernelMod::Resize(const std::vector &inputs, const std::vector &outputs) { + for (const auto &input : inputs) { + // If any input shape contains -1, means input shape is dynamic, so just return do nothing. + auto input_shape = input->GetShapeVector(); + if (!IsValidShape(input_shape)) { + return KRET_UNKNOWN_SHAPE; + } + } + for (const auto &output : outputs) { + // If any input shape contains -1, means input shape is dynamic, so just return do nothing. + auto output_shape = output->GetShapeVector(); + if (!IsValidShape(output_shape)) { + return KRET_UNKNOWN_SHAPE; + } + } + ResetResource(); + std::vector output_shape = std::vector(outputs.at(kIndex0)->GetDeviceShapeVector().begin(), + outputs.at(kIndex0)->GetDeviceShapeVector().end()); + output_elements_ = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + if (output_elements_ == 0) { + is_null_input_ = true; + } + output_size_list_.push_back(output_elements_ * unit_output_size_); + + return KRET_OK; +} + +template +bool PolarGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + T *abs = GetDeviceAddress(inputs, 0); + T *angle = GetDeviceAddress(inputs, 1); + S *output = GetDeviceAddress(outputs, 0); + auto status = + CalPolar(output_elements_, abs, angle, output, device_id_, reinterpret_cast(cuda_stream_)); + CHECK_CUDA_STATUS(status, kernel_name_); + return True; +} + +template +using Complex = mindspore::utils::Complex; + +std::vector> PolarGpuKernelMod::func_list_ = { + {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeComplex64), + &PolarGpuKernelMod::LaunchKernel>}, + {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeComplex128), + &PolarGpuKernelMod::LaunchKernel>}}; + +std::vector PolarGpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Polar, PolarGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/polar_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/math/polar_gpu_kernel.h index f95faefc11a..31d3407d4b8 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/math/polar_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/math/polar_gpu_kernel.h @@ -1,85 +1,85 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_POLAR_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_POLAR_GPU_KERNEL_H_ -#include -#include -#include -#include -#include -#include -#include -#include -#include "mindspore/core/ops/polar.h" -#include "abstract/utils.h" -#include "plugin/factory/ms_factory.h" -#include "plugin/device/gpu/kernel/gpu_kernel.h" -#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/polar_impl.cuh" - -namespace mindspore { -namespace kernel { -class PolarGpuKernelMod : public NativeGpuKernelMod { - public: - PolarGpuKernelMod() { ResetResource(); } - ~PolarGpuKernelMod() override = default; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *cuda_stream) override { - if (is_null_input_) { - return true; - } - cuda_stream_ = cuda_stream; - return kernel_func_(this, inputs, workspace, outputs); - } - - protected: - void ResetResource() noexcept { - output_elements_ = 0; - is_null_input_ = false; - output_size_list_.clear(); - } - - std::vector GetOpSupport() override; - - private: - template - bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs); - - using Polarfunc = - std::function &, - const std::vector &, const std::vector &)>; - - private: - size_t output_elements_; - size_t unit_input_size_{1}; - size_t unit_output_size_{1}; - Polarfunc kernel_func_{}; - bool is_null_input_{false}; - void *cuda_stream_{nullptr}; - static std::vector> func_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_POLAR_GPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_POLAR_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_POLAR_GPU_KERNEL_H_ +#include +#include +#include +#include +#include +#include +#include +#include +#include "mindspore/core/ops/polar.h" +#include "abstract/utils.h" +#include "plugin/factory/ms_factory.h" +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/polar_impl.cuh" + +namespace mindspore { +namespace kernel { +class PolarGpuKernelMod : public NativeGpuKernelMod { + public: + PolarGpuKernelMod() { ResetResource(); } + ~PolarGpuKernelMod() override = default; + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *cuda_stream) override { + if (is_null_input_) { + return true; + } + cuda_stream_ = cuda_stream; + return kernel_func_(this, inputs, workspace, outputs); + } + + protected: + void ResetResource() noexcept { + output_elements_ = 0; + is_null_input_ = false; + output_size_list_.clear(); + } + + std::vector GetOpSupport() override; + + private: + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs); + + using Polarfunc = + std::function &, + const std::vector &, const std::vector &)>; + + private: + size_t output_elements_; + size_t unit_input_size_{1}; + size_t unit_output_size_{1}; + Polarfunc kernel_func_{}; + bool is_null_input_{false}; + void *cuda_stream_{nullptr}; + static std::vector> func_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_POLAR_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/tril_indices_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/math/tril_indices_gpu_kernel.cc index bb2df634369..5a31da7050e 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/math/tril_indices_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/math/tril_indices_gpu_kernel.cc @@ -1,100 +1,100 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/math/tril_indices_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -bool TrilIndicesGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - if (outputs.empty()) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty outputs, which is invalid."; - return false; - } - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', the kernel type should be in [int32, int64], " - << "but got: " << kernel_attr << "."; - return false; - } - kernel_func_ = func_list_[index].second; - row_ = GetValue(primitive_->GetAttr("row")); - col_ = GetValue(primitive_->GetAttr("col")); - offset_ = GetValue(primitive_->GetAttr("offset")); - return true; -} - -int TrilIndicesGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - MS_EXCEPTION_IF_NULL(outputs[kIndex0]); - ResetResource(); - auto ret = KRET_OK; - size_t tensor_size = 0; - size_t type_size = GetTypeByte(TypeIdToType(outputs.at(kIndex0)->dtype_id())); - auto shape = outputs.at(kIndex0)->GetShapeVector(); - if (!IsValidShape(shape)) { - ret = KRET_UNKNOWN_OUT_SHAPE; - } else { - tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies()); - } - output_size_list_.emplace_back(tensor_size); - const size_t matrix_dim = 2; - tril_size_ = tensor_size / (type_size * matrix_dim); - return ret; -} - -void TrilIndicesGpuKernelMod::ResetResource() noexcept { - tril_size_ = 0; - workspace_size_list_.clear(); - output_size_list_.clear(); -} - -template -bool TrilIndicesGpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs) { - T *output = GetDeviceAddress(outputs, kIndex0); - MS_EXCEPTION_IF_NULL(output); - if (tril_size_ > 0) { - auto m_first_row = offset_ > 0 ? std::min(col_, 1 + offset_) : row_ + offset_ > 0; - auto trapezoid_row_offset = std::max(0, -offset_); - auto rectangle_row_offset = trapezoid_row_offset + col_ - m_first_row + 1; - int64_t rectangle_size = 0; - if (rectangle_row_offset < row_) { - rectangle_size = (row_ - rectangle_row_offset) * col_; - } - auto status = - CalTrilIndices(trapezoid_row_offset, m_first_row, col_, static_cast(tril_size_) - rectangle_size, - tril_size_, output, device_id_, reinterpret_cast(cuda_stream_)); - CHECK_CUDA_STATUS(status, kernel_name_); - } - return true; -} - -std::vector> TrilIndicesGpuKernelMod::func_list_ = { - {KernelAttr().AddOutputAttr(kNumberTypeInt32), &TrilIndicesGpuKernelMod::LaunchKernel}, - {KernelAttr().AddOutputAttr(kNumberTypeInt64), &TrilIndicesGpuKernelMod::LaunchKernel}}; - -std::vector TrilIndicesGpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - return support_list; -} -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, TrilIndices, TrilIndicesGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/math/tril_indices_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +bool TrilIndicesGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + if (outputs.empty()) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty outputs, which is invalid."; + return false; + } + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', the kernel type should be in [int32, int64], " + << "but got: " << kernel_attr << "."; + return false; + } + kernel_func_ = func_list_[index].second; + row_ = GetValue(primitive_->GetAttr("row")); + col_ = GetValue(primitive_->GetAttr("col")); + offset_ = GetValue(primitive_->GetAttr("offset")); + return true; +} + +int TrilIndicesGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + MS_EXCEPTION_IF_NULL(outputs[kIndex0]); + ResetResource(); + auto ret = KRET_OK; + size_t tensor_size = 0; + size_t type_size = GetTypeByte(TypeIdToType(outputs.at(kIndex0)->dtype_id())); + auto shape = outputs.at(kIndex0)->GetShapeVector(); + if (!IsValidShape(shape)) { + ret = KRET_UNKNOWN_OUT_SHAPE; + } else { + tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies()); + } + output_size_list_.emplace_back(tensor_size); + const size_t matrix_dim = 2; + tril_size_ = tensor_size / (type_size * matrix_dim); + return ret; +} + +void TrilIndicesGpuKernelMod::ResetResource() noexcept { + tril_size_ = 0; + workspace_size_list_.clear(); + output_size_list_.clear(); +} + +template +bool TrilIndicesGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + T *output = GetDeviceAddress(outputs, kIndex0); + MS_EXCEPTION_IF_NULL(output); + if (tril_size_ > 0) { + auto m_first_row = offset_ > 0 ? std::min(col_, 1 + offset_) : row_ + offset_ > 0; + auto trapezoid_row_offset = std::max(0, -offset_); + auto rectangle_row_offset = trapezoid_row_offset + col_ - m_first_row + 1; + int64_t rectangle_size = 0; + if (rectangle_row_offset < row_) { + rectangle_size = (row_ - rectangle_row_offset) * col_; + } + auto status = + CalTrilIndices(trapezoid_row_offset, m_first_row, col_, static_cast(tril_size_) - rectangle_size, + tril_size_, output, device_id_, reinterpret_cast(cuda_stream_)); + CHECK_CUDA_STATUS(status, kernel_name_); + } + return true; +} + +std::vector> TrilIndicesGpuKernelMod::func_list_ = { + {KernelAttr().AddOutputAttr(kNumberTypeInt32), &TrilIndicesGpuKernelMod::LaunchKernel}, + {KernelAttr().AddOutputAttr(kNumberTypeInt64), &TrilIndicesGpuKernelMod::LaunchKernel}}; + +std::vector TrilIndicesGpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, TrilIndices, TrilIndicesGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/tril_indices_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/math/tril_indices_gpu_kernel.h index e21e27109c2..9e76c89b233 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/math/tril_indices_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/math/tril_indices_gpu_kernel.h @@ -1,76 +1,76 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_TRIL_INDICES_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_TRIL_INDICES_GPU_KERNEL_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "abstract/utils.h" -#include "mindspore/core/ops/tril_indices.h" -#include "plugin/factory/ms_factory.h" -#include "plugin/device/gpu/kernel/gpu_kernel.h" -#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/tril_indices_impl.cuh" - -namespace mindspore { -namespace kernel { -class TrilIndicesGpuKernelMod : public NativeGpuKernelMod { - public: - TrilIndicesGpuKernelMod() { ResetResource(); } - ~TrilIndicesGpuKernelMod() override = default; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *cuda_stream) override { - cuda_stream_ = cuda_stream; - return kernel_func_(this, inputs, workspace, outputs); - } - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - protected: - void ResetResource() noexcept; - std::vector GetOpSupport() override; - - private: - template - bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs); - using TrilIndicesFunc = - std::function &, - const std::vector &, const std::vector &)>; - TrilIndicesFunc kernel_func_{}; - static std::vector> func_list_; - - private: - int64_t row_{0}; - int64_t col_{0}; - int64_t offset_{0}; - size_t tril_size_{0}; - void *cuda_stream_{nullptr}; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_TRIL_INDICES_GPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_TRIL_INDICES_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_TRIL_INDICES_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "abstract/utils.h" +#include "mindspore/core/ops/tril_indices.h" +#include "plugin/factory/ms_factory.h" +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/tril_indices_impl.cuh" + +namespace mindspore { +namespace kernel { +class TrilIndicesGpuKernelMod : public NativeGpuKernelMod { + public: + TrilIndicesGpuKernelMod() { ResetResource(); } + ~TrilIndicesGpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *cuda_stream) override { + cuda_stream_ = cuda_stream; + return kernel_func_(this, inputs, workspace, outputs); + } + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + protected: + void ResetResource() noexcept; + std::vector GetOpSupport() override; + + private: + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs); + using TrilIndicesFunc = + std::function &, + const std::vector &, const std::vector &)>; + TrilIndicesFunc kernel_func_{}; + static std::vector> func_list_; + + private: + int64_t row_{0}; + int64_t col_{0}; + int64_t offset_{0}; + size_t tril_size_{0}; + void *cuda_stream_{nullptr}; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_TRIL_INDICES_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/triu_indices_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/math/triu_indices_gpu_kernel.cc index 76e90c28d8e..aecd7c73e8f 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/math/triu_indices_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/math/triu_indices_gpu_kernel.cc @@ -1,97 +1,97 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/math/triu_indices_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -bool TriuIndicesGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - if (outputs.empty()) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty outputs, which is invalid."; - return false; - } - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', the kernel type should be in [int32, int64], " - << "but got: " << kernel_attr << "."; - return false; - } - kernel_func_ = func_list_[index].second; - row_ = GetValue(primitive_->GetAttr("row")); - col_ = GetValue(primitive_->GetAttr("col")); - offset_ = GetValue(primitive_->GetAttr("offset")); - return true; -} - -int TriuIndicesGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - MS_EXCEPTION_IF_NULL(outputs[kIndex0]); - ResetResource(); - auto ret = KRET_OK; - size_t tensor_size = 0; - size_t type_size = GetTypeByte(TypeIdToType(outputs.at(kIndex0)->dtype_id())); - auto shape = outputs.at(kIndex0)->GetShapeVector(); - if (!IsValidShape(shape)) { - ret = KRET_UNKNOWN_OUT_SHAPE; - } else { - tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies()); - } - output_size_list_.emplace_back(tensor_size); - const size_t matrix_dim = 2; - triu_size_ = tensor_size / (type_size * matrix_dim); - return ret; -} - -void TriuIndicesGpuKernelMod::ResetResource() noexcept { - triu_size_ = 0; - workspace_size_list_.clear(); - output_size_list_.clear(); -} - -template -bool TriuIndicesGpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs) { - T *output = GetDeviceAddress(outputs, kIndex0); - MS_EXCEPTION_IF_NULL(output); - if (triu_size_ > 0) { - auto m_first_row = offset_ > 0 ? std::max(col_ - offset_, 0) : col_; - int64_t rectangle_size = 0; - if (offset_ < 0) { - rectangle_size = std::min(row_, -offset_) * col_; - } - auto status = CalTriuIndices(std::max(0, offset_), m_first_row, col_, rectangle_size, triu_size_, output, - device_id_, reinterpret_cast(cuda_stream_)); - CHECK_CUDA_STATUS(status, kernel_name_); - } - return true; -} - -std::vector> TriuIndicesGpuKernelMod::func_list_ = { - {KernelAttr().AddOutputAttr(kNumberTypeInt32), &TriuIndicesGpuKernelMod::LaunchKernel}, - {KernelAttr().AddOutputAttr(kNumberTypeInt64), &TriuIndicesGpuKernelMod::LaunchKernel}}; - -std::vector TriuIndicesGpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - return support_list; -} -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, TriuIndices, TriuIndicesGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/math/triu_indices_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +bool TriuIndicesGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + if (outputs.empty()) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty outputs, which is invalid."; + return false; + } + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', the kernel type should be in [int32, int64], " + << "but got: " << kernel_attr << "."; + return false; + } + kernel_func_ = func_list_[index].second; + row_ = GetValue(primitive_->GetAttr("row")); + col_ = GetValue(primitive_->GetAttr("col")); + offset_ = GetValue(primitive_->GetAttr("offset")); + return true; +} + +int TriuIndicesGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + MS_EXCEPTION_IF_NULL(outputs[kIndex0]); + ResetResource(); + auto ret = KRET_OK; + size_t tensor_size = 0; + size_t type_size = GetTypeByte(TypeIdToType(outputs.at(kIndex0)->dtype_id())); + auto shape = outputs.at(kIndex0)->GetShapeVector(); + if (!IsValidShape(shape)) { + ret = KRET_UNKNOWN_OUT_SHAPE; + } else { + tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies()); + } + output_size_list_.emplace_back(tensor_size); + const size_t matrix_dim = 2; + triu_size_ = tensor_size / (type_size * matrix_dim); + return ret; +} + +void TriuIndicesGpuKernelMod::ResetResource() noexcept { + triu_size_ = 0; + workspace_size_list_.clear(); + output_size_list_.clear(); +} + +template +bool TriuIndicesGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + T *output = GetDeviceAddress(outputs, kIndex0); + MS_EXCEPTION_IF_NULL(output); + if (triu_size_ > 0) { + auto m_first_row = offset_ > 0 ? std::max(col_ - offset_, 0) : col_; + int64_t rectangle_size = 0; + if (offset_ < 0) { + rectangle_size = std::min(row_, -offset_) * col_; + } + auto status = CalTriuIndices(std::max(0, offset_), m_first_row, col_, rectangle_size, triu_size_, output, + device_id_, reinterpret_cast(cuda_stream_)); + CHECK_CUDA_STATUS(status, kernel_name_); + } + return true; +} + +std::vector> TriuIndicesGpuKernelMod::func_list_ = { + {KernelAttr().AddOutputAttr(kNumberTypeInt32), &TriuIndicesGpuKernelMod::LaunchKernel}, + {KernelAttr().AddOutputAttr(kNumberTypeInt64), &TriuIndicesGpuKernelMod::LaunchKernel}}; + +std::vector TriuIndicesGpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, TriuIndices, TriuIndicesGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/triu_indices_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/math/triu_indices_gpu_kernel.h index 09647a4acfa..27103176402 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/math/triu_indices_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/math/triu_indices_gpu_kernel.h @@ -1,76 +1,76 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_TRIU_INDICES_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_TRIU_INDICES_GPU_KERNEL_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "abstract/utils.h" -#include "mindspore/core/ops/triu_indices.h" -#include "plugin/factory/ms_factory.h" -#include "plugin/device/gpu/kernel/gpu_kernel.h" -#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/triu_indices_impl.cuh" - -namespace mindspore { -namespace kernel { -class TriuIndicesGpuKernelMod : public NativeGpuKernelMod { - public: - TriuIndicesGpuKernelMod() { ResetResource(); } - ~TriuIndicesGpuKernelMod() override = default; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *cuda_stream) override { - cuda_stream_ = cuda_stream; - return kernel_func_(this, inputs, workspace, outputs); - } - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - protected: - void ResetResource() noexcept; - std::vector GetOpSupport() override; - - private: - template - bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs); - using TriuIndicesFunc = - std::function &, - const std::vector &, const std::vector &)>; - TriuIndicesFunc kernel_func_{}; - static std::vector> func_list_; - - private: - int64_t row_{0}; - int64_t col_{0}; - int64_t offset_{0}; - size_t triu_size_{0}; - void *cuda_stream_{nullptr}; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_TRIU_INDICES_GPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_TRIU_INDICES_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_TRIU_INDICES_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "abstract/utils.h" +#include "mindspore/core/ops/triu_indices.h" +#include "plugin/factory/ms_factory.h" +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/triu_indices_impl.cuh" + +namespace mindspore { +namespace kernel { +class TriuIndicesGpuKernelMod : public NativeGpuKernelMod { + public: + TriuIndicesGpuKernelMod() { ResetResource(); } + ~TriuIndicesGpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *cuda_stream) override { + cuda_stream_ = cuda_stream; + return kernel_func_(this, inputs, workspace, outputs); + } + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + protected: + void ResetResource() noexcept; + std::vector GetOpSupport() override; + + private: + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs); + using TriuIndicesFunc = + std::function &, + const std::vector &, const std::vector &)>; + TriuIndicesFunc kernel_func_{}; + static std::vector> func_list_; + + private: + int64_t row_{0}; + int64_t col_{0}; + int64_t offset_{0}; + size_t triu_size_{0}; + void *cuda_stream_{nullptr}; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_TRIU_INDICES_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/adaptive_avg_pool3d_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/adaptive_avg_pool3d_gpu_kernel.cc index 88e47967b31..b9dc9d0fc0c 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/adaptive_avg_pool3d_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/adaptive_avg_pool3d_gpu_kernel.cc @@ -1,97 +1,97 @@ -/** - * Copyright 2020-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/nn/adaptive_avg_pool3d_gpu_kernel.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/adaptive_avg_pool3d_helper.h" -#include -namespace mindspore { -namespace kernel { -namespace { -template -std::unique_ptr CreateAdaptiveAvgPool3DKernelPtr(const std::string &kernel_name, - const uint32_t &device_id) { - return std::make_unique>(kernel_name, device_id); -} -using AdaptiveAvgPool3DPtrCreatorFunc = - std::function(const std::string &, const uint32_t &)>; - -const std::vector> kernel_attr = { - {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - CreateAdaptiveAvgPool3DKernelPtr}, - {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - CreateAdaptiveAvgPool3DKernelPtr}, - {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), - CreateAdaptiveAvgPool3DKernelPtr}}; -} // namespace - -bool AdaptiveAvgPool3DGpuKernelMod::Launch(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - std::vector input_ptrs = ConvertPtrs(inputs); - std::vector work_ptrs = ConvertPtrs(workspace); - std::vector output_ptrs = ConvertPtrs(outputs); - if (helper_ptr_->Process(input_ptrs, output_ptrs, work_ptrs, stream_ptr) != 0) { - return false; - } - return true; -} - -bool AdaptiveAvgPool3DGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(tensor_attr, GetOpSupport()); - if (!is_match) { - return false; - } - const auto &output_size_ptr = primitive_->GetAttr("output_size"); - attr_ptr_->output_size = GetValue>(output_size_ptr); - helper_ptr_ = std::move(kernel_attr[index].second(kernel_name_, device_id_)); - helper_ptr_->SetKernelParam(attr_ptr_); - return true; -} - -int AdaptiveAvgPool3DGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - for (const auto &input : inputs) { - auto input_shape = input->GetShapeVector(); - if (!IsValidShape(input_shape)) { - return KRET_UNKNOWN_SHAPE; - } - } - std::vector> input_shapes; - std::vector> output_shapes; - std::vector inp_shape = inputs[0]->GetShapeVector(); - std::vector out_shape = outputs[0]->GetShapeVector(); - input_shapes.emplace_back(inp_shape); - output_shapes.emplace_back(out_shape); - if (helper_ptr_->CalMemSize(input_shapes, output_shapes) == -1) { - return KRET_RESIZE_FAILED; - } - output_size_list_ = helper_ptr_->GetOutputSizeList(); - workspace_size_list_ = helper_ptr_->GetWorkSizeList(); - return KRET_OK; -} - -std::vector AdaptiveAvgPool3DGpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(kernel_attr.begin(), kernel_attr.end(), std::back_inserter(support_list), - [](const std::pair &item) { return item.first; }); - return support_list; -} - -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, AdaptiveAvgPool3D, AdaptiveAvgPool3DGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/nn/adaptive_avg_pool3d_gpu_kernel.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/adaptive_avg_pool3d_helper.h" +#include +namespace mindspore { +namespace kernel { +namespace { +template +std::unique_ptr CreateAdaptiveAvgPool3DKernelPtr(const std::string &kernel_name, + const uint32_t &device_id) { + return std::make_unique>(kernel_name, device_id); +} +using AdaptiveAvgPool3DPtrCreatorFunc = + std::function(const std::string &, const uint32_t &)>; + +const std::vector> kernel_attr = { + {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + CreateAdaptiveAvgPool3DKernelPtr}, + {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + CreateAdaptiveAvgPool3DKernelPtr}, + {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + CreateAdaptiveAvgPool3DKernelPtr}}; +} // namespace + +bool AdaptiveAvgPool3DGpuKernelMod::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + std::vector input_ptrs = ConvertPtrs(inputs); + std::vector work_ptrs = ConvertPtrs(workspace); + std::vector output_ptrs = ConvertPtrs(outputs); + if (helper_ptr_->Process(input_ptrs, output_ptrs, work_ptrs, stream_ptr) != 0) { + return false; + } + return true; +} + +bool AdaptiveAvgPool3DGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(tensor_attr, GetOpSupport()); + if (!is_match) { + return false; + } + const auto &output_size_ptr = primitive_->GetAttr("output_size"); + attr_ptr_->output_size = GetValue>(output_size_ptr); + helper_ptr_ = std::move(kernel_attr[index].second(kernel_name_, device_id_)); + helper_ptr_->SetKernelParam(attr_ptr_); + return true; +} + +int AdaptiveAvgPool3DGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + for (const auto &input : inputs) { + auto input_shape = input->GetShapeVector(); + if (!IsValidShape(input_shape)) { + return KRET_UNKNOWN_SHAPE; + } + } + std::vector> input_shapes; + std::vector> output_shapes; + std::vector inp_shape = inputs[0]->GetShapeVector(); + std::vector out_shape = outputs[0]->GetShapeVector(); + input_shapes.emplace_back(inp_shape); + output_shapes.emplace_back(out_shape); + if (helper_ptr_->CalMemSize(input_shapes, output_shapes) == -1) { + return KRET_RESIZE_FAILED; + } + output_size_list_ = helper_ptr_->GetOutputSizeList(); + workspace_size_list_ = helper_ptr_->GetWorkSizeList(); + return KRET_OK; +} + +std::vector AdaptiveAvgPool3DGpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(kernel_attr.begin(), kernel_attr.end(), std::back_inserter(support_list), + [](const std::pair &item) { return item.first; }); + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, AdaptiveAvgPool3D, AdaptiveAvgPool3DGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/adaptive_avg_pool3d_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/adaptive_avg_pool3d_gpu_kernel.h index 4cc6ddd79de..84cee815f2f 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/adaptive_avg_pool3d_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/adaptive_avg_pool3d_gpu_kernel.h @@ -1,51 +1,51 @@ -/** - * Copyright 2020-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_ADAPTIVE_AVG_POOL3D_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_ADAPTIVE_AVG_POOL3D_GPU_KERNEL_H_ - -#include -#include -#include -#include -#include -#include -#include "mindspore/core/ops/adaptive_avg_pool_3d.h" -#include "plugin/device/gpu/kernel/gpu_kernel.h" -#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/adaptive_avg_pool3d_helper.h" -namespace mindspore { -namespace kernel { -class AdaptiveAvgPool3DGpuKernelMod : public NativeGpuKernelMod { - public: - AdaptiveAvgPool3DGpuKernelMod() { attr_ptr_ = std::make_shared(); } - ~AdaptiveAvgPool3DGpuKernelMod() override = default; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - int Resize(const std::vector &inputs, const std::vector &outputs) override; - std::vector GetOpSupport() override; - - private: - std::unique_ptr helper_ptr_{nullptr}; - std::shared_ptr attr_ptr_{nullptr}; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_ADAPTIVE_AVG_POOL3D_GPU_KERNEL_H_ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_ADAPTIVE_AVG_POOL3D_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_ADAPTIVE_AVG_POOL3D_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include +#include "mindspore/core/ops/adaptive_avg_pool_3d.h" +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/adaptive_avg_pool3d_helper.h" +namespace mindspore { +namespace kernel { +class AdaptiveAvgPool3DGpuKernelMod : public NativeGpuKernelMod { + public: + AdaptiveAvgPool3DGpuKernelMod() { attr_ptr_ = std::make_shared(); } + ~AdaptiveAvgPool3DGpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + int Resize(const std::vector &inputs, const std::vector &outputs) override; + std::vector GetOpSupport() override; + + private: + std::unique_ptr helper_ptr_{nullptr}; + std::shared_ptr attr_ptr_{nullptr}; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_ADAPTIVE_AVG_POOL3D_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/adaptive_max_pool3d_grad_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/adaptive_max_pool3d_grad_gpu_kernel.cc index 056b3afd7e5..e3a113dd687 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/adaptive_max_pool3d_grad_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/adaptive_max_pool3d_grad_gpu_kernel.cc @@ -1,130 +1,130 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/nn/adaptive_max_pool3d_grad_gpu_kernel.h" -#include - -namespace mindspore { -namespace kernel { -constexpr int64_t maxIndexIdx = 2; - -namespace { -template -std::unique_ptr CreateAdaptiveMaxPoolGradKernelPtr(const std::string &kernel_name, - const uint32_t &device_id) { - return std::make_unique>(kernel_name, device_id); -} - -using AdaptiveMaxPoolGradPtrCreatorFunc = - std::function(const std::string &, const uint32_t &)>; - -#define REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(TypeId1, TypeId2, Type1, Type2) \ - { \ - KernelAttr().AddInputAttr(TypeId1).AddInputAttr(TypeId1).AddInputAttr(TypeId2).AddOutputAttr(TypeId1), \ - CreateAdaptiveMaxPoolGradKernelPtr \ - } - -const std::vector> kernel_attr = { - REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeFloat16, kNumberTypeInt32, half, int32_t), - REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeFloat32, kNumberTypeInt32, float, int32_t), - REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeFloat64, kNumberTypeInt32, double, int32_t), - REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeInt8, kNumberTypeInt32, int8_t, int32_t), - REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeInt16, kNumberTypeInt32, int16_t, int32_t), - REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeInt32, kNumberTypeInt32, int32_t, int32_t), - REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeInt64, kNumberTypeInt32, int64_t, int32_t), - REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeUInt8, kNumberTypeInt32, uint8_t, int32_t), - REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeUInt16, kNumberTypeInt32, uint16_t, int32_t), - REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeUInt32, kNumberTypeInt32, uint32_t, int32_t), - REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeUInt64, kNumberTypeInt32, uint64_t, int32_t), - - REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeFloat16, kNumberTypeInt64, half, int64_t), - REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeFloat32, kNumberTypeInt64, float, int64_t), - REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeFloat64, kNumberTypeInt64, double, int64_t), - REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeInt8, kNumberTypeInt64, int8_t, int64_t), - REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeInt16, kNumberTypeInt64, int16_t, int64_t), - REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeInt32, kNumberTypeInt64, int32_t, int64_t), - REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeInt64, kNumberTypeInt64, int64_t, int64_t), - REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeUInt8, kNumberTypeInt64, uint8_t, int64_t), - REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeUInt16, kNumberTypeInt64, uint16_t, int64_t), - REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeUInt32, kNumberTypeInt64, uint32_t, int64_t), - REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeUInt64, kNumberTypeInt64, uint64_t, int64_t), -}; // namespace -} // namespace - -bool AdaptiveMaxPool3DGradGpuKernelMod::Launch(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - std::vector input_ptrs = ConvertPtrs(inputs); - std::vector work_ptrs = ConvertPtrs(workspace); - std::vector output_ptrs = ConvertPtrs(outputs); - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( - cudaMemsetAsync(output_ptrs[0], 0, outputs[0]->size(), reinterpret_cast(stream_ptr)), - "failed to set cuda memory with zeros."); - - if (helper_ptr_->Process(input_ptrs, output_ptrs, work_ptrs, stream_ptr) != 0) { - return false; - } - return true; -} - -bool AdaptiveMaxPool3DGradGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(tensor_attr, GetOpSupport()); - if (!is_match) { - return false; - } - helper_ptr_ = std::move(kernel_attr[index].second(kernel_name_, device_id_)); - helper_ptr_->SetKernelParam(attr_ptr_); - - return true; -} - -int AdaptiveMaxPool3DGradGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - int ret = KernelMod::Resize(inputs, outputs); - if (ret != KRET_OK) { - return ret; - } - - std::vector> input_shapes; - std::vector> output_shapes; - std::vector input_shape = inputs[0]->GetShapeVector(); - std::vector x_shape = inputs[1]->GetShapeVector(); - std::vector index_shape = inputs[maxIndexIdx]->GetShapeVector(); - std::vector out_shape = outputs[0]->GetShapeVector(); - - (void)input_shapes.emplace_back(input_shape); - (void)input_shapes.emplace_back(x_shape); - (void)input_shapes.emplace_back(index_shape); - (void)output_shapes.emplace_back(out_shape); - - if (helper_ptr_->CalMemSize(input_shapes, output_shapes) == -1) { - return KRET_RESIZE_FAILED; - } - return KRET_OK; -} - -std::vector AdaptiveMaxPool3DGradGpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(kernel_attr.begin(), kernel_attr.end(), std::back_inserter(support_list), - [](const std::pair &item) { return item.first; }); - return support_list; -} - -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, AdaptiveMaxPool3DGrad, AdaptiveMaxPool3DGradGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/nn/adaptive_max_pool3d_grad_gpu_kernel.h" +#include + +namespace mindspore { +namespace kernel { +constexpr int64_t maxIndexIdx = 2; + +namespace { +template +std::unique_ptr CreateAdaptiveMaxPoolGradKernelPtr(const std::string &kernel_name, + const uint32_t &device_id) { + return std::make_unique>(kernel_name, device_id); +} + +using AdaptiveMaxPoolGradPtrCreatorFunc = + std::function(const std::string &, const uint32_t &)>; + +#define REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(TypeId1, TypeId2, Type1, Type2) \ + { \ + KernelAttr().AddInputAttr(TypeId1).AddInputAttr(TypeId1).AddInputAttr(TypeId2).AddOutputAttr(TypeId1), \ + CreateAdaptiveMaxPoolGradKernelPtr \ + } + +const std::vector> kernel_attr = { + REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeFloat16, kNumberTypeInt32, half, int32_t), + REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeFloat32, kNumberTypeInt32, float, int32_t), + REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeFloat64, kNumberTypeInt32, double, int32_t), + REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeInt8, kNumberTypeInt32, int8_t, int32_t), + REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeInt16, kNumberTypeInt32, int16_t, int32_t), + REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeInt32, kNumberTypeInt32, int32_t, int32_t), + REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeInt64, kNumberTypeInt32, int64_t, int32_t), + REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeUInt8, kNumberTypeInt32, uint8_t, int32_t), + REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeUInt16, kNumberTypeInt32, uint16_t, int32_t), + REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeUInt32, kNumberTypeInt32, uint32_t, int32_t), + REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeUInt64, kNumberTypeInt32, uint64_t, int32_t), + + REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeFloat16, kNumberTypeInt64, half, int64_t), + REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeFloat32, kNumberTypeInt64, float, int64_t), + REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeFloat64, kNumberTypeInt64, double, int64_t), + REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeInt8, kNumberTypeInt64, int8_t, int64_t), + REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeInt16, kNumberTypeInt64, int16_t, int64_t), + REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeInt32, kNumberTypeInt64, int32_t, int64_t), + REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeInt64, kNumberTypeInt64, int64_t, int64_t), + REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeUInt8, kNumberTypeInt64, uint8_t, int64_t), + REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeUInt16, kNumberTypeInt64, uint16_t, int64_t), + REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeUInt32, kNumberTypeInt64, uint32_t, int64_t), + REG_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL(kNumberTypeUInt64, kNumberTypeInt64, uint64_t, int64_t), +}; // namespace +} // namespace + +bool AdaptiveMaxPool3DGradGpuKernelMod::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + std::vector input_ptrs = ConvertPtrs(inputs); + std::vector work_ptrs = ConvertPtrs(workspace); + std::vector output_ptrs = ConvertPtrs(outputs); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemsetAsync(output_ptrs[0], 0, outputs[0]->size(), reinterpret_cast(stream_ptr)), + "failed to set cuda memory with zeros."); + + if (helper_ptr_->Process(input_ptrs, output_ptrs, work_ptrs, stream_ptr) != 0) { + return false; + } + return true; +} + +bool AdaptiveMaxPool3DGradGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(tensor_attr, GetOpSupport()); + if (!is_match) { + return false; + } + helper_ptr_ = std::move(kernel_attr[index].second(kernel_name_, device_id_)); + helper_ptr_->SetKernelParam(attr_ptr_); + + return true; +} + +int AdaptiveMaxPool3DGradGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + int ret = KernelMod::Resize(inputs, outputs); + if (ret != KRET_OK) { + return ret; + } + + std::vector> input_shapes; + std::vector> output_shapes; + std::vector input_shape = inputs[0]->GetShapeVector(); + std::vector x_shape = inputs[1]->GetShapeVector(); + std::vector index_shape = inputs[maxIndexIdx]->GetShapeVector(); + std::vector out_shape = outputs[0]->GetShapeVector(); + + (void)input_shapes.emplace_back(input_shape); + (void)input_shapes.emplace_back(x_shape); + (void)input_shapes.emplace_back(index_shape); + (void)output_shapes.emplace_back(out_shape); + + if (helper_ptr_->CalMemSize(input_shapes, output_shapes) == -1) { + return KRET_RESIZE_FAILED; + } + return KRET_OK; +} + +std::vector AdaptiveMaxPool3DGradGpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(kernel_attr.begin(), kernel_attr.end(), std::back_inserter(support_list), + [](const std::pair &item) { return item.first; }); + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, AdaptiveMaxPool3DGrad, AdaptiveMaxPool3DGradGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/adaptive_max_pool3d_grad_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/adaptive_max_pool3d_grad_gpu_kernel.h index a09b2be11c3..a68d65050f4 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/adaptive_max_pool3d_grad_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/adaptive_max_pool3d_grad_gpu_kernel.h @@ -1,53 +1,53 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL_H_ - -#include -#include -#include -#include -#include -#include -#include "mindspore/core/ops/grad/adaptive_max_pool_3d_grad.h" -#include "plugin/device/gpu/kernel/gpu_kernel.h" -#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/adaptive_max_pool_grad_helper.h" -namespace mindspore { -namespace kernel { -class AdaptiveMaxPool3DGradGpuKernelMod : public NativeGpuKernelMod { - public: - AdaptiveMaxPool3DGradGpuKernelMod() { attr_ptr_ = std::make_shared(); } - ~AdaptiveMaxPool3DGradGpuKernelMod() override = default; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - std::vector GetOpSupport() override; - - private: - std::unique_ptr helper_ptr_{nullptr}; - std::shared_ptr attr_ptr_{nullptr}; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include +#include "mindspore/core/ops/grad/adaptive_max_pool_3d_grad.h" +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/adaptive_max_pool_grad_helper.h" +namespace mindspore { +namespace kernel { +class AdaptiveMaxPool3DGradGpuKernelMod : public NativeGpuKernelMod { + public: + AdaptiveMaxPool3DGradGpuKernelMod() { attr_ptr_ = std::make_shared(); } + ~AdaptiveMaxPool3DGradGpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + std::vector GetOpSupport() override; + + private: + std::unique_ptr helper_ptr_{nullptr}; + std::shared_ptr attr_ptr_{nullptr}; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/bias_dropout_add_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/bias_dropout_add_gpu_kernel.cc index 56d980ed4cb..027ce222772 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/bias_dropout_add_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/bias_dropout_add_gpu_kernel.cc @@ -1,141 +1,141 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "plugin/device/gpu/kernel/nn/bias_dropout_add_gpu_kernel.h" - -#include -#include -#include -#include -#include -#include -#include "abstract/utils.h" -#include "ops/fusion/bias_dropout_add_fusion.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/bias_dropout_add_impl.cuh" - -namespace mindspore { -namespace kernel { -namespace { -constexpr size_t kInputNum = 3; -constexpr size_t kInputXIndex = 0; -constexpr size_t kInputBiasIndex = 1; -constexpr size_t kInputResidualIndex = 2; - -constexpr size_t kOutputNum = 2; -constexpr size_t kOutputYIndex = 0; -constexpr size_t kOutputMaskIndex = 1; -} // namespace - -bool BiasDropoutAddGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - if (inputs.size() != kInputNum || outputs.size() != kOutputNum) { - MS_LOG(ERROR) << kernel_name_ << ": input and output size should be " << kInputNum << " and " << kOutputNum - << ", but get " << inputs.size() << " and " << outputs.size(); - return false; - } - - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "' does not support this kernel type: " << kernel_attr; - return false; - } - kernel_func_ = func_list_[index].second; - - keep_prob_ = GetValue(primitive_->GetAttr("keep_prob")); - int64_t seed = GetValue(primitive_->GetAttr("seed0")); - if (seed == 0) { - seed = GetValue(primitive_->GetAttr("seed1")); - if (seed == 0) { - seed = time(NULL); - } - } - seed_ = static_cast(seed); - return true; -} - -int BiasDropoutAddGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - int ret = KernelMod::Resize(inputs, outputs); - if (ret != 0) { - return ret; - } - if (inputs.size() != kInputNum || output_size_list_.size() != kOutputNum) { - MS_LOG(ERROR) << kernel_name_ << " resize : input and output size should be " << kInputNum << " and " << kOutputNum - << ", but get " << inputs.size() << " and " << output_size_list_.size(); - return KRET_RESIZE_FAILED; - } - auto x_shape = inputs[kInputXIndex]->GetShapeVector(); - num_count_ = 1; - n_strides_ = 1; - channel_strides_ = 1; - for (size_t i = 0; i < x_shape.size(); ++i) { - auto dim = x_shape[i]; - if (dim < 0) { - dim = 0; - } - auto dim_length = LongToSize(dim); - num_count_ *= dim_length; - if (i > 0) { - n_strides_ *= dim_length; - } - if (i > 1) { - channel_strides_ *= dim_length; - } - } - return KRET_OK; -} - -template -bool BiasDropoutAddGpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &outputs) { - T *x = GetDeviceAddress(inputs, kInputXIndex); - T *bias = GetDeviceAddress(inputs, kInputBiasIndex); - T *residual = GetDeviceAddress(inputs, kInputResidualIndex); - T *y = GetDeviceAddress(outputs, kOutputYIndex); - T *mask = GetDeviceAddress(outputs, kOutputMaskIndex); - auto status = BiasDropoutAdd(x, bias, residual, y, mask, num_count_, n_strides_, channel_strides_, keep_prob_, seed_, - seed_offset_, cuda_stream_); - CHECK_CUDA_STATUS(status, kernel_name_); - seed_offset_ += num_count_; - return true; -} - -std::vector> BiasDropoutAddGpuKernelMod::func_list_ = { - {KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - &BiasDropoutAddGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - &BiasDropoutAddGpuKernelMod::LaunchKernel}}; - -std::vector BiasDropoutAddGpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - return support_list; -} - -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, BiasDropoutAdd, BiasDropoutAddGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "plugin/device/gpu/kernel/nn/bias_dropout_add_gpu_kernel.h" + +#include +#include +#include +#include +#include +#include +#include "abstract/utils.h" +#include "ops/fusion/bias_dropout_add_fusion.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/bias_dropout_add_impl.cuh" + +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kInputNum = 3; +constexpr size_t kInputXIndex = 0; +constexpr size_t kInputBiasIndex = 1; +constexpr size_t kInputResidualIndex = 2; + +constexpr size_t kOutputNum = 2; +constexpr size_t kOutputYIndex = 0; +constexpr size_t kOutputMaskIndex = 1; +} // namespace + +bool BiasDropoutAddGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + if (inputs.size() != kInputNum || outputs.size() != kOutputNum) { + MS_LOG(ERROR) << kernel_name_ << ": input and output size should be " << kInputNum << " and " << kOutputNum + << ", but get " << inputs.size() << " and " << outputs.size(); + return false; + } + + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "' does not support this kernel type: " << kernel_attr; + return false; + } + kernel_func_ = func_list_[index].second; + + keep_prob_ = GetValue(primitive_->GetAttr("keep_prob")); + int64_t seed = GetValue(primitive_->GetAttr("seed0")); + if (seed == 0) { + seed = GetValue(primitive_->GetAttr("seed1")); + if (seed == 0) { + seed = time(NULL); + } + } + seed_ = static_cast(seed); + return true; +} + +int BiasDropoutAddGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + int ret = KernelMod::Resize(inputs, outputs); + if (ret != 0) { + return ret; + } + if (inputs.size() != kInputNum || output_size_list_.size() != kOutputNum) { + MS_LOG(ERROR) << kernel_name_ << " resize : input and output size should be " << kInputNum << " and " << kOutputNum + << ", but get " << inputs.size() << " and " << output_size_list_.size(); + return KRET_RESIZE_FAILED; + } + auto x_shape = inputs[kInputXIndex]->GetShapeVector(); + num_count_ = 1; + n_strides_ = 1; + channel_strides_ = 1; + for (size_t i = 0; i < x_shape.size(); ++i) { + auto dim = x_shape[i]; + if (dim < 0) { + dim = 0; + } + auto dim_length = LongToSize(dim); + num_count_ *= dim_length; + if (i > 0) { + n_strides_ *= dim_length; + } + if (i > 1) { + channel_strides_ *= dim_length; + } + } + return KRET_OK; +} + +template +bool BiasDropoutAddGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { + T *x = GetDeviceAddress(inputs, kInputXIndex); + T *bias = GetDeviceAddress(inputs, kInputBiasIndex); + T *residual = GetDeviceAddress(inputs, kInputResidualIndex); + T *y = GetDeviceAddress(outputs, kOutputYIndex); + T *mask = GetDeviceAddress(outputs, kOutputMaskIndex); + auto status = BiasDropoutAdd(x, bias, residual, y, mask, num_count_, n_strides_, channel_strides_, keep_prob_, seed_, + seed_offset_, cuda_stream_); + CHECK_CUDA_STATUS(status, kernel_name_); + seed_offset_ += num_count_; + return true; +} + +std::vector> BiasDropoutAddGpuKernelMod::func_list_ = { + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + &BiasDropoutAddGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + &BiasDropoutAddGpuKernelMod::LaunchKernel}}; + +std::vector BiasDropoutAddGpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, BiasDropoutAdd, BiasDropoutAddGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/binary_cross_entropy_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/binary_cross_entropy_gpu_kernel.cc index d1fd2a43151..89918910d21 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/binary_cross_entropy_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/binary_cross_entropy_gpu_kernel.cc @@ -1,111 +1,111 @@ -/** - * Copyright 2020-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "plugin/device/gpu/kernel/nn/binary_cross_entropy_gpu_kernel.h" -#include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/loss_with_reduction_impl.cuh" -#include "ops/op_name.h" - -namespace mindspore { -namespace kernel { -bool BinaryCrossEntropyGpuKernelMod::BinaryCrossEntropyGpuKernelMod::Launch( - const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - if (dtype_ == kNumberTypeFloat16) { - LaunchKernel(inputs, workspace, outputs, stream_ptr); - } else if (dtype_ == kNumberTypeFloat32) { - LaunchKernel(inputs, workspace, outputs, stream_ptr); - } else { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dtype of input must be float16 or float32, but got " - << TypeIdToType(dtype_)->ToString(); - } - return true; -} - -template -void BinaryCrossEntropyGpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - auto reduction = static_cast(inputs[kIndex3]->GetValueWithCheck()); - if (reduction == Reduction::NONE) { - reduction_ = ReductionMode::kNone; - } else if (reduction == Reduction::MEAN) { - reduction_ = ReductionMode::kMean; - } else { - reduction_ = ReductionMode::kSum; - } - T *input_x = GetDeviceAddress(inputs, kIndex0); - T *input_y = GetDeviceAddress(inputs, kIndex1); - T *weight = nullptr; - if (inputs[kIndex2]->type_id() != kMetaTypeNone) { - weight = GetDeviceAddress(inputs, kIndex2); - } - T *loss = GetDeviceAddress(outputs, kIndex0); - T *tmp_loss = GetDeviceAddress(workspace, kIndex0); - if (input_size_ > 0) { - auto status = BinaryCrossEntropyLoss(input_size_, reduction_, input_x, input_y, weight, loss, tmp_loss, - reinterpret_cast(stream_ptr)); - CHECK_CUDA_STATUS(status, kernel_name_); - } -} - -bool BinaryCrossEntropyGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto match = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!match.first) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr; - } - - dtype_ = inputs[kIndex0]->dtype_id(); - return true; -} - -int BinaryCrossEntropyGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - int ret = KernelMod::Resize(inputs, outputs); - if (ret != 0) { - return ret; - } - auto input_shape = inputs[kIndex0]->GetShapeVector(); - input_size_ = SizeOf(input_shape); - workspace_size_ = sizeof(TypeIdToType(inputs[kIndex0]->dtype_id())); - if (reduction_ != ReductionMode::kNone) { - workspace_size_ *= input_size_; - } - workspace_size_list_.push_back(workspace_size_); - return KRET_OK; -} - -std::vector BinaryCrossEntropyGpuKernelMod::GetOpSupport() { - static std::vector kernel_attr_list = {KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOptionalInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeFloat16), - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOptionalInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeFloat32)}; - - return kernel_attr_list; -} - -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, BinaryCrossEntropy, BinaryCrossEntropyGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "plugin/device/gpu/kernel/nn/binary_cross_entropy_gpu_kernel.h" +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/loss_with_reduction_impl.cuh" +#include "ops/op_name.h" + +namespace mindspore { +namespace kernel { +bool BinaryCrossEntropyGpuKernelMod::BinaryCrossEntropyGpuKernelMod::Launch( + const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + if (dtype_ == kNumberTypeFloat16) { + LaunchKernel(inputs, workspace, outputs, stream_ptr); + } else if (dtype_ == kNumberTypeFloat32) { + LaunchKernel(inputs, workspace, outputs, stream_ptr); + } else { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dtype of input must be float16 or float32, but got " + << TypeIdToType(dtype_)->ToString(); + } + return true; +} + +template +void BinaryCrossEntropyGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + auto reduction = static_cast(inputs[kIndex3]->GetValueWithCheck()); + if (reduction == Reduction::NONE) { + reduction_ = ReductionMode::kNone; + } else if (reduction == Reduction::MEAN) { + reduction_ = ReductionMode::kMean; + } else { + reduction_ = ReductionMode::kSum; + } + T *input_x = GetDeviceAddress(inputs, kIndex0); + T *input_y = GetDeviceAddress(inputs, kIndex1); + T *weight = nullptr; + if (inputs[kIndex2]->type_id() != kMetaTypeNone) { + weight = GetDeviceAddress(inputs, kIndex2); + } + T *loss = GetDeviceAddress(outputs, kIndex0); + T *tmp_loss = GetDeviceAddress(workspace, kIndex0); + if (input_size_ > 0) { + auto status = BinaryCrossEntropyLoss(input_size_, reduction_, input_x, input_y, weight, loss, tmp_loss, + reinterpret_cast(stream_ptr)); + CHECK_CUDA_STATUS(status, kernel_name_); + } +} + +bool BinaryCrossEntropyGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto match = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!match.first) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr; + } + + dtype_ = inputs[kIndex0]->dtype_id(); + return true; +} + +int BinaryCrossEntropyGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + int ret = KernelMod::Resize(inputs, outputs); + if (ret != 0) { + return ret; + } + auto input_shape = inputs[kIndex0]->GetShapeVector(); + input_size_ = SizeOf(input_shape); + workspace_size_ = sizeof(TypeIdToType(inputs[kIndex0]->dtype_id())); + if (reduction_ != ReductionMode::kNone) { + workspace_size_ *= input_size_; + } + workspace_size_list_.push_back(workspace_size_); + return KRET_OK; +} + +std::vector BinaryCrossEntropyGpuKernelMod::GetOpSupport() { + static std::vector kernel_attr_list = {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOptionalInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat16), + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOptionalInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat32)}; + + return kernel_attr_list; +} + +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, BinaryCrossEntropy, BinaryCrossEntropyGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/binary_cross_entropy_grad_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/binary_cross_entropy_grad_kernel.cc index c6486fea8a9..85447fc7793 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/binary_cross_entropy_grad_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/binary_cross_entropy_grad_kernel.cc @@ -1,106 +1,106 @@ -/** - * Copyright 2020-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "plugin/device/gpu/kernel/nn/binary_cross_entropy_grad_kernel.h" -#include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/loss_with_reduction_impl.cuh" -#include "ops/op_name.h" - -namespace mindspore { -namespace kernel { -bool BinaryCrossEntropyGradGpuKernelMod::Launch(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - if (dtype_ == kNumberTypeFloat16) { - LaunchKernel(inputs, outputs, stream_ptr); - } else if (dtype_ == kNumberTypeFloat32) { - LaunchKernel(inputs, outputs, stream_ptr); - } else { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dtype of input must be float16 or float32, but got " - << TypeIdToType(dtype_)->ToString(); - } - return true; -} - -template -void BinaryCrossEntropyGradGpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &outputs, void *stream_ptr) { - T *input_x = GetDeviceAddress(inputs, kIndex0); - T *input_y = GetDeviceAddress(inputs, kIndex1); - T *dloss = GetDeviceAddress(inputs, kIndex2); - T *weight = nullptr; - if (inputs[kIndex3]->type_id() != kMetaTypeNone) { - weight = GetDeviceAddress(inputs, kIndex3); - } - auto reduction = static_cast(inputs[kIndex4]->GetValueWithCheck()); - if (reduction == Reduction::NONE) { - reduction_ = ReductionMode::kNone; - } else if (reduction == Reduction::MEAN) { - reduction_ = ReductionMode::kMean; - } else { - reduction_ = ReductionMode::kSum; - } - T *dx = GetDeviceAddress(outputs, kIndex0); - if (input_size_ > 0) { - auto status = BinaryCrossEntropyLossGrad(input_size_, reduction_, input_x, input_y, weight, dloss, dx, - reinterpret_cast(stream_ptr)); - CHECK_CUDA_STATUS(status, kernel_name_); - } -} - -bool BinaryCrossEntropyGradGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto match = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!match.first) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr; - } - - dtype_ = inputs[kIndex0]->dtype_id(); - return true; -} - -int BinaryCrossEntropyGradGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - int ret = KernelMod::Resize(inputs, outputs); - if (ret != 0) { - return ret; - } - auto input_shape = inputs[kIndex0]->GetShapeVector(); - input_size_ = SizeOf(input_shape); - return KRET_OK; -} - -std::vector BinaryCrossEntropyGradGpuKernelMod::GetOpSupport() { - static std::vector kernel_attr_list = {KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOptionalInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeFloat16), - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOptionalInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeFloat32)}; - return kernel_attr_list; -} - -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, BinaryCrossEntropyGrad, BinaryCrossEntropyGradGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "plugin/device/gpu/kernel/nn/binary_cross_entropy_grad_kernel.h" +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/loss_with_reduction_impl.cuh" +#include "ops/op_name.h" + +namespace mindspore { +namespace kernel { +bool BinaryCrossEntropyGradGpuKernelMod::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + if (dtype_ == kNumberTypeFloat16) { + LaunchKernel(inputs, outputs, stream_ptr); + } else if (dtype_ == kNumberTypeFloat32) { + LaunchKernel(inputs, outputs, stream_ptr); + } else { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dtype of input must be float16 or float32, but got " + << TypeIdToType(dtype_)->ToString(); + } + return true; +} + +template +void BinaryCrossEntropyGradGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &outputs, void *stream_ptr) { + T *input_x = GetDeviceAddress(inputs, kIndex0); + T *input_y = GetDeviceAddress(inputs, kIndex1); + T *dloss = GetDeviceAddress(inputs, kIndex2); + T *weight = nullptr; + if (inputs[kIndex3]->type_id() != kMetaTypeNone) { + weight = GetDeviceAddress(inputs, kIndex3); + } + auto reduction = static_cast(inputs[kIndex4]->GetValueWithCheck()); + if (reduction == Reduction::NONE) { + reduction_ = ReductionMode::kNone; + } else if (reduction == Reduction::MEAN) { + reduction_ = ReductionMode::kMean; + } else { + reduction_ = ReductionMode::kSum; + } + T *dx = GetDeviceAddress(outputs, kIndex0); + if (input_size_ > 0) { + auto status = BinaryCrossEntropyLossGrad(input_size_, reduction_, input_x, input_y, weight, dloss, dx, + reinterpret_cast(stream_ptr)); + CHECK_CUDA_STATUS(status, kernel_name_); + } +} + +bool BinaryCrossEntropyGradGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto match = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!match.first) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr; + } + + dtype_ = inputs[kIndex0]->dtype_id(); + return true; +} + +int BinaryCrossEntropyGradGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + int ret = KernelMod::Resize(inputs, outputs); + if (ret != 0) { + return ret; + } + auto input_shape = inputs[kIndex0]->GetShapeVector(); + input_size_ = SizeOf(input_shape); + return KRET_OK; +} + +std::vector BinaryCrossEntropyGradGpuKernelMod::GetOpSupport() { + static std::vector kernel_attr_list = {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOptionalInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat16), + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOptionalInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat32)}; + return kernel_attr_list; +} + +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, BinaryCrossEntropyGrad, BinaryCrossEntropyGradGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/col2im_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/col2im_gpu_kernel.cc index 58f698d03f1..5d59fdae78d 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/col2im_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/col2im_gpu_kernel.cc @@ -1,148 +1,148 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/nn/col2im_gpu_kernel.h" -#include -#include -#include -#include -#include "mindspore/core/abstract/utils.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/col2im_impl.cuh" - -namespace mindspore { -namespace kernel { -namespace { -constexpr int kCol2ImInputsNum = 2; -constexpr int kPaddingDirection = 2; -} // namespace - -void Col2ImFwdGpuKernelMod::ResetResource() noexcept { - batch_size_ = 0; - channels_ = 0; - out_height_ = 0; - out_width_ = 0; - in_height_ = 0; - in_width_ = 0; - pad_height_ = 0; - pad_width_ = 0; - kernel_height_ = 0; - kernel_width_ = 0; - stride_height_ = 0; - stride_width_ = 0; - dilation_height_ = 0; - dilation_width_ = 0; - is_null_input_ = false; - output_size_list_.clear(); - workspace_size_list_.clear(); -} - -template -bool Col2ImFwdGpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs) { - T *input_addr = GetDeviceAddress(inputs, kIndex0); - T *output_addr = GetDeviceAddress(outputs, kIndex0); - Col2Im(input_addr, batch_size_, channels_, out_height_, out_width_, in_height_, in_width_, kernel_height_, - kernel_width_, pad_height_, pad_width_, stride_height_, stride_width_, dilation_height_, dilation_width_, - output_addr, reinterpret_cast(cuda_stream_)); - return true; -} - -bool Col2ImFwdGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - if (inputs.empty() || outputs.empty()) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid."; - return false; - } - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "' does not support this kernel type: " << kernel_attr; - return false; - } - kernel_func_ = func_list_[index].second; - return true; -} - -int Col2ImFwdGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - ResetResource(); - int ret = KernelMod::Resize(inputs, outputs); - if (ret != 0) { - return ret; - } - if (inputs.size() != kCol2ImInputsNum) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "' input size must be equal 2."; - return KRET_RESIZE_FAILED; - } - auto input_shape = inputs[kIndex0]->GetShapeVector(); - auto output_shape = outputs[kIndex0]->GetShapeVector(); - batch_size_ = static_cast(input_shape[kIndex0]); - channels_ = static_cast(input_shape[kIndex1]); - out_height_ = static_cast(output_shape[kIndex2]); - out_width_ = static_cast(output_shape[kIndex3]); - auto kernel_size = GetValue>(primitive_->GetAttr("kernel_size")); - auto dilation = GetValue>(primitive_->GetAttr("dilation")); - auto padding = GetValue>(primitive_->GetAttr("padding")); - auto stride = GetValue>(primitive_->GetAttr("stride")); - std::unordered_map> to_check{ - {"kernel_size", kernel_size}, {"dilation", dilation}, {"padding", padding}, {"stride", stride}}; - for (const auto &[name, vec] : to_check) { - if (std::any_of(vec.begin(), vec.end(), [](int64_t x) { return x > std::numeric_limits::max(); })) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', " << name << " value overflow."; - return KRET_RESIZE_FAILED; - } - } - pad_height_ = static_cast(padding[kIndex0]); - pad_width_ = static_cast(padding[kIndex1]); - kernel_height_ = static_cast(kernel_size[kIndex0]); - kernel_width_ = static_cast(kernel_size[kIndex1]); - stride_height_ = static_cast(stride[kIndex0]); - stride_width_ = static_cast(stride[kIndex1]); - dilation_height_ = static_cast(dilation[kIndex0]); - dilation_width_ = static_cast(dilation[kIndex1]); - in_height_ = static_cast( - (out_height_ + kPaddingDirection * pad_height_ - (dilation_height_ * (kernel_height_ - 1) + 1)) / stride_height_ + - 1); - in_width_ = static_cast( - (out_width_ + kPaddingDirection * pad_width_ - (dilation_width_ * (kernel_width_ - 1) + 1)) / stride_width_ + 1); - return KRET_OK; -} - -std::vector> Col2ImFwdGpuKernelMod::func_list_ = { - {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), - &Col2ImFwdGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), - &Col2ImFwdGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64), - &Col2ImFwdGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex64), - &Col2ImFwdGpuKernelMod::LaunchKernel, utils::Complex>}, - {KernelAttr().AddInputAttr(kNumberTypeComplex128).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex128), - &Col2ImFwdGpuKernelMod::LaunchKernel, utils::Complex>}, -}; - -std::vector Col2ImFwdGpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - return support_list; -} - -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Col2Im, Col2ImFwdGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/nn/col2im_gpu_kernel.h" +#include +#include +#include +#include +#include "mindspore/core/abstract/utils.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/col2im_impl.cuh" + +namespace mindspore { +namespace kernel { +namespace { +constexpr int kCol2ImInputsNum = 2; +constexpr int kPaddingDirection = 2; +} // namespace + +void Col2ImFwdGpuKernelMod::ResetResource() noexcept { + batch_size_ = 0; + channels_ = 0; + out_height_ = 0; + out_width_ = 0; + in_height_ = 0; + in_width_ = 0; + pad_height_ = 0; + pad_width_ = 0; + kernel_height_ = 0; + kernel_width_ = 0; + stride_height_ = 0; + stride_width_ = 0; + dilation_height_ = 0; + dilation_width_ = 0; + is_null_input_ = false; + output_size_list_.clear(); + workspace_size_list_.clear(); +} + +template +bool Col2ImFwdGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + T *input_addr = GetDeviceAddress(inputs, kIndex0); + T *output_addr = GetDeviceAddress(outputs, kIndex0); + Col2Im(input_addr, batch_size_, channels_, out_height_, out_width_, in_height_, in_width_, kernel_height_, + kernel_width_, pad_height_, pad_width_, stride_height_, stride_width_, dilation_height_, dilation_width_, + output_addr, reinterpret_cast(cuda_stream_)); + return true; +} + +bool Col2ImFwdGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + if (inputs.empty() || outputs.empty()) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid."; + return false; + } + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "' does not support this kernel type: " << kernel_attr; + return false; + } + kernel_func_ = func_list_[index].second; + return true; +} + +int Col2ImFwdGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + ResetResource(); + int ret = KernelMod::Resize(inputs, outputs); + if (ret != 0) { + return ret; + } + if (inputs.size() != kCol2ImInputsNum) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "' input size must be equal 2."; + return KRET_RESIZE_FAILED; + } + auto input_shape = inputs[kIndex0]->GetShapeVector(); + auto output_shape = outputs[kIndex0]->GetShapeVector(); + batch_size_ = static_cast(input_shape[kIndex0]); + channels_ = static_cast(input_shape[kIndex1]); + out_height_ = static_cast(output_shape[kIndex2]); + out_width_ = static_cast(output_shape[kIndex3]); + auto kernel_size = GetValue>(primitive_->GetAttr("kernel_size")); + auto dilation = GetValue>(primitive_->GetAttr("dilation")); + auto padding = GetValue>(primitive_->GetAttr("padding")); + auto stride = GetValue>(primitive_->GetAttr("stride")); + std::unordered_map> to_check{ + {"kernel_size", kernel_size}, {"dilation", dilation}, {"padding", padding}, {"stride", stride}}; + for (const auto &[name, vec] : to_check) { + if (std::any_of(vec.begin(), vec.end(), [](int64_t x) { return x > std::numeric_limits::max(); })) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', " << name << " value overflow."; + return KRET_RESIZE_FAILED; + } + } + pad_height_ = static_cast(padding[kIndex0]); + pad_width_ = static_cast(padding[kIndex1]); + kernel_height_ = static_cast(kernel_size[kIndex0]); + kernel_width_ = static_cast(kernel_size[kIndex1]); + stride_height_ = static_cast(stride[kIndex0]); + stride_width_ = static_cast(stride[kIndex1]); + dilation_height_ = static_cast(dilation[kIndex0]); + dilation_width_ = static_cast(dilation[kIndex1]); + in_height_ = static_cast( + (out_height_ + kPaddingDirection * pad_height_ - (dilation_height_ * (kernel_height_ - 1) + 1)) / stride_height_ + + 1); + in_width_ = static_cast( + (out_width_ + kPaddingDirection * pad_width_ - (dilation_width_ * (kernel_width_ - 1) + 1)) / stride_width_ + 1); + return KRET_OK; +} + +std::vector> Col2ImFwdGpuKernelMod::func_list_ = { + {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + &Col2ImFwdGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), + &Col2ImFwdGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64), + &Col2ImFwdGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex64), + &Col2ImFwdGpuKernelMod::LaunchKernel, utils::Complex>}, + {KernelAttr().AddInputAttr(kNumberTypeComplex128).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex128), + &Col2ImFwdGpuKernelMod::LaunchKernel, utils::Complex>}, +}; + +std::vector Col2ImFwdGpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Col2Im, Col2ImFwdGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/col2im_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/col2im_gpu_kernel.h index 93915ad3aa1..d18e8e777f9 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/col2im_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/col2im_gpu_kernel.h @@ -1,84 +1,84 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_COL2IM_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_COL2IM_GPU_KERNEL_H_ - -#include -#include -#include -#include -#include -#include "plugin/device/gpu/kernel/gpu_kernel.h" -#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/pad_impl.cuh" -#include "plugin/device/gpu/kernel/kernel_constants.h" - -namespace mindspore { -namespace kernel { -class Col2ImFwdGpuKernelMod : public NativeGpuKernelMod { - public: - Col2ImFwdGpuKernelMod() { ResetResource(); } - ~Col2ImFwdGpuKernelMod() = default; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - cuda_stream_ = stream_ptr; - return kernel_func_(this, inputs, workspace, outputs); - } - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - std::vector GetOpSupport() override; - - protected: - void ResetResource() noexcept; - - template - bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs); - using Col2ImFunc = - std::function &, - const std::vector &, const std::vector &)>; - - private: - uint32_t batch_size_{0}; - uint32_t channels_{0}; - uint32_t out_height_{0}; - uint32_t out_width_{0}; - uint32_t in_height_{0}; - uint32_t in_width_{0}; - uint32_t pad_width_{0}; - uint32_t pad_height_{0}; - uint32_t kernel_height_{0}; - uint32_t kernel_width_{0}; - uint32_t stride_height_{0}; - uint32_t stride_width_{0}; - uint32_t dilation_height_{0}; - uint32_t dilation_width_{0}; - bool is_null_input_{false}; - void *cuda_stream_{nullptr}; - Col2ImFunc kernel_func_{}; - static std::vector> func_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_COL2IM_GPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_COL2IM_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_COL2IM_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/pad_impl.cuh" +#include "plugin/device/gpu/kernel/kernel_constants.h" + +namespace mindspore { +namespace kernel { +class Col2ImFwdGpuKernelMod : public NativeGpuKernelMod { + public: + Col2ImFwdGpuKernelMod() { ResetResource(); } + ~Col2ImFwdGpuKernelMod() = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } + cuda_stream_ = stream_ptr; + return kernel_func_(this, inputs, workspace, outputs); + } + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + std::vector GetOpSupport() override; + + protected: + void ResetResource() noexcept; + + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs); + using Col2ImFunc = + std::function &, + const std::vector &, const std::vector &)>; + + private: + uint32_t batch_size_{0}; + uint32_t channels_{0}; + uint32_t out_height_{0}; + uint32_t out_width_{0}; + uint32_t in_height_{0}; + uint32_t in_width_{0}; + uint32_t pad_width_{0}; + uint32_t pad_height_{0}; + uint32_t kernel_height_{0}; + uint32_t kernel_width_{0}; + uint32_t stride_height_{0}; + uint32_t stride_width_{0}; + uint32_t dilation_height_{0}; + uint32_t dilation_width_{0}; + bool is_null_input_{false}; + void *cuda_stream_{nullptr}; + Col2ImFunc kernel_func_{}; + static std::vector> func_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_COL2IM_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/ctcgreedydecoder_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/ctcgreedydecoder_gpu_kernel.cc index 8915c5b4c91..3deb89115a0 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/ctcgreedydecoder_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/ctcgreedydecoder_gpu_kernel.cc @@ -1,231 +1,231 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "kernel/common_utils.h" -#include "plugin/device/gpu/kernel/nn/ctcgreedydecoder_gpu_kernel.h" -#include "mindspore/core/ops/ctc_greedy_decoder.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/ctcgreedydecoder_impl.cuh" - -namespace mindspore { -namespace kernel { -constexpr size_t kInputNum = 2; -constexpr size_t kOutputNum = 4; -constexpr size_t kInputsRank = 3; -constexpr size_t kDecodedIndicesRank = 2; -constexpr size_t kSeqLenRank = 1; - -void CTCGreedyDecoderGpuKernelMod::ResetResource() { - stream_ptr_ = nullptr; - is_null_input_ = false; - workspace_size_list_.clear(); - output_size_list_.clear(); -} - -void CTCGreedyDecoderGpuKernelMod::InitSizeLists() { - max_time_ = inputs_x_shape_[kIndex0]; - batch_size_ = inputs_x_shape_[kIndex1]; - bound_ = inputs_x_shape_[kIndex2]; - - workspace_size_list_.push_back(sizeof(int64_t)); - workspace_size_list_.push_back(batch_size_ * sizeof(int64_t)); - workspace_size_list_.push_back(max_time_ * batch_size_ * sizeof(int64_t)); - - output_size_list_.push_back(max_time_ * batch_size_ * sizeof(int64_t) * kDecodedIndicesRank); - output_size_list_.push_back(max_time_ * batch_size_ * sizeof(int64_t)); - output_size_list_.push_back(kDecodedIndicesRank * sizeof(int64_t)); - output_size_list_.push_back(max_time_ * batch_size_ * data_unit_size_); -} - -bool CTCGreedyDecoderGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - - merge_repeated_ = GetValue(primitive_->GetAttr("merge_repeated")); - - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "' does not support this kernel type: " << kernel_attr; - return false; - } - - if (inputs.empty() || outputs.empty()) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', it got empty inputs or outputs, which is invalid."; - return false; - } - - kernel_func_ = func_list_[index].second; - data_unit_size_ = abstract::TypeIdSize(inputs[kIndex0]->dtype_id()); - return true; -} - -int CTCGreedyDecoderGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - if (inputs.empty() || outputs.empty()) { - return KRET_OK; - } - for (const auto &input : inputs) { - auto input_shape = input->GetShapeVector(); - if (!IsValidShape(input_shape)) { - return KRET_UNKNOWN_SHAPE; - } - } - - ResetResource(); - - inputs_x_shape_ = inputs[kIndex0]->GetShapeVector(); - sequence_shape_ = inputs[kIndex1]->GetShapeVector(); - - if (inputs_x_shape_.size() != kInputsRank) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', inputs's dim must be 3, but got: " << inputs_x_shape_.size() - << "."; - return KRET_RESIZE_FAILED; - } - - if (sequence_shape_.size() != kSeqLenRank) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ - << "', sequence_length's dims must be 1, but got: " << sequence_shape_.size() << "."; - return KRET_RESIZE_FAILED; - } - - if (inputs_x_shape_[1] != sequence_shape_[0]) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ - << "', inputs batch_size must be the same with sequence_length batch_size" - << "."; - return KRET_RESIZE_FAILED; - } - - InitSizeLists(); - - if (inputs.size() != kInputNum) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', Input size list should be " << kInputNum << ", but got " - << inputs.size() << "."; - return KRET_RESIZE_FAILED; - } - - return KRET_OK; -} - -template -bool CTCGreedyDecoderGpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - if (is_null_input_) { - return true; - } - stream_ptr_ = stream_ptr; - T *inputs_x = GetDeviceAddress(inputs, kIndex0); - int *sequence_length = GetDeviceAddress(inputs, kIndex1); - int64_t *nums_count = GetDeviceAddress(workspace, kIndex1); - int64_t *decoded_values_temp = GetDeviceAddress(workspace, kIndex2); - - int64_t *decoded_indices = GetDeviceAddress(outputs, kIndex0); - int64_t *decoded_values = GetDeviceAddress(outputs, kIndex1); - int64_t *decoded_shape = GetDeviceAddress(outputs, kIndex2); - T *log_probability = GetDeviceAddress(outputs, kIndex3); - - std::vector seq_host(sequence_shape_[0]); - CHECK_CUDA_RET_WITH_ERROR_NOTRACE( - cudaMemcpyAsync(seq_host.data(), sequence_length, sequence_shape_[0] * sizeof(int32_t), cudaMemcpyDeviceToHost, - reinterpret_cast(stream_ptr_)), - "For 'CTCGreedyDecoder', cudaMemcpy beta failed"); - if (cudaStreamQuery(reinterpret_cast(stream_ptr_)) != cudaSuccess) { - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(reinterpret_cast(stream_ptr_)), - "For 'CTCGreedyDecoder', cudaStreamSyncFailed"); - } - for (int b = 0; b < sequence_shape_[0]; b++) { - if (seq_host[b] > static_cast(max_time_)) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', sequence_length[" << b << "] should be less than " - << max_time_ << ", but got " << seq_host[b] << "."; - } - } - - auto status = CalCTCGreedyDecoder(inputs_x, bound_, max_time_ * batch_size_, batch_size_, decoded_values_temp, - log_probability, device_id_, reinterpret_cast(stream_ptr)); - CHECK_CUDA_STATUS(status, kernel_name_); - - status = Calmerge(decoded_values_temp, sequence_length, batch_size_, bound_, merge_repeated_, log_probability, - nums_count, device_id_, reinterpret_cast(stream_ptr)); - CHECK_CUDA_STATUS(status, kernel_name_); - - status = Calindices(decoded_values_temp, nums_count, batch_size_, decoded_indices, decoded_values, decoded_shape, - device_id_, reinterpret_cast(stream_ptr), &element_cnt_); - CHECK_CUDA_STATUS(status, kernel_name_); - return true; -} - -void CTCGreedyDecoderGpuKernelMod::UpdateOutputShapeAndSize(const std::vector &inputs, - const std::vector &outputs) { - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(reinterpret_cast(stream_ptr_)), - "cudaStreamSynchronized failed"); - - std::vector indices_shape = outputs[kIndex0]->GetShapeVector(); - indices_shape[kIndex0] = element_cnt_; - outputs[kIndex0]->SetShapeVector(std::vector(indices_shape.begin(), indices_shape.end())); - outputs[kIndex0]->set_size( - LongToSize(std::accumulate(indices_shape.begin(), indices_shape.end(), - UnitSizeInBytes(outputs[kIndex0]->dtype_id()), std::multiplies()))); - - std::vector values_shape = outputs[kIndex1]->GetShapeVector(); - values_shape[kIndex0] = element_cnt_; - outputs[kIndex1]->SetShapeVector(std::vector(values_shape.begin(), values_shape.end())); - outputs[kIndex1]->set_size( - LongToSize(std::accumulate(values_shape.begin(), values_shape.end(), UnitSizeInBytes(outputs[kIndex1]->dtype_id()), - std::multiplies()))); - - std::vector log_shape = outputs[kIndex3]->GetShapeVector(); - log_shape[kIndex0] = inputs_x_shape_[1]; - outputs[kIndex3]->SetShapeVector(std::vector(log_shape.begin(), log_shape.end())); - outputs[kIndex3]->set_size(LongToSize(std::accumulate( - log_shape.begin(), log_shape.end(), UnitSizeInBytes(outputs[kIndex3]->dtype_id()), std::multiplies()))); -} - -std::vector> - CTCGreedyDecoderGpuKernelMod::func_list_ = {{KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeFloat32), - &CTCGreedyDecoderGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeFloat64), - &CTCGreedyDecoderGpuKernelMod::LaunchKernel}}; - -std::vector CTCGreedyDecoderGpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - return support_list; -} - -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, CTCGreedyDecoder, CTCGreedyDecoderGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "kernel/common_utils.h" +#include "plugin/device/gpu/kernel/nn/ctcgreedydecoder_gpu_kernel.h" +#include "mindspore/core/ops/ctc_greedy_decoder.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/ctcgreedydecoder_impl.cuh" + +namespace mindspore { +namespace kernel { +constexpr size_t kInputNum = 2; +constexpr size_t kOutputNum = 4; +constexpr size_t kInputsRank = 3; +constexpr size_t kDecodedIndicesRank = 2; +constexpr size_t kSeqLenRank = 1; + +void CTCGreedyDecoderGpuKernelMod::ResetResource() { + stream_ptr_ = nullptr; + is_null_input_ = false; + workspace_size_list_.clear(); + output_size_list_.clear(); +} + +void CTCGreedyDecoderGpuKernelMod::InitSizeLists() { + max_time_ = inputs_x_shape_[kIndex0]; + batch_size_ = inputs_x_shape_[kIndex1]; + bound_ = inputs_x_shape_[kIndex2]; + + workspace_size_list_.push_back(sizeof(int64_t)); + workspace_size_list_.push_back(batch_size_ * sizeof(int64_t)); + workspace_size_list_.push_back(max_time_ * batch_size_ * sizeof(int64_t)); + + output_size_list_.push_back(max_time_ * batch_size_ * sizeof(int64_t) * kDecodedIndicesRank); + output_size_list_.push_back(max_time_ * batch_size_ * sizeof(int64_t)); + output_size_list_.push_back(kDecodedIndicesRank * sizeof(int64_t)); + output_size_list_.push_back(max_time_ * batch_size_ * data_unit_size_); +} + +bool CTCGreedyDecoderGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + + merge_repeated_ = GetValue(primitive_->GetAttr("merge_repeated")); + + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "' does not support this kernel type: " << kernel_attr; + return false; + } + + if (inputs.empty() || outputs.empty()) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', it got empty inputs or outputs, which is invalid."; + return false; + } + + kernel_func_ = func_list_[index].second; + data_unit_size_ = abstract::TypeIdSize(inputs[kIndex0]->dtype_id()); + return true; +} + +int CTCGreedyDecoderGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + if (inputs.empty() || outputs.empty()) { + return KRET_OK; + } + for (const auto &input : inputs) { + auto input_shape = input->GetShapeVector(); + if (!IsValidShape(input_shape)) { + return KRET_UNKNOWN_SHAPE; + } + } + + ResetResource(); + + inputs_x_shape_ = inputs[kIndex0]->GetShapeVector(); + sequence_shape_ = inputs[kIndex1]->GetShapeVector(); + + if (inputs_x_shape_.size() != kInputsRank) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', inputs's dim must be 3, but got: " << inputs_x_shape_.size() + << "."; + return KRET_RESIZE_FAILED; + } + + if (sequence_shape_.size() != kSeqLenRank) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ + << "', sequence_length's dims must be 1, but got: " << sequence_shape_.size() << "."; + return KRET_RESIZE_FAILED; + } + + if (inputs_x_shape_[1] != sequence_shape_[0]) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ + << "', inputs batch_size must be the same with sequence_length batch_size" + << "."; + return KRET_RESIZE_FAILED; + } + + InitSizeLists(); + + if (inputs.size() != kInputNum) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', Input size list should be " << kInputNum << ", but got " + << inputs.size() << "."; + return KRET_RESIZE_FAILED; + } + + return KRET_OK; +} + +template +bool CTCGreedyDecoderGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + if (is_null_input_) { + return true; + } + stream_ptr_ = stream_ptr; + T *inputs_x = GetDeviceAddress(inputs, kIndex0); + int *sequence_length = GetDeviceAddress(inputs, kIndex1); + int64_t *nums_count = GetDeviceAddress(workspace, kIndex1); + int64_t *decoded_values_temp = GetDeviceAddress(workspace, kIndex2); + + int64_t *decoded_indices = GetDeviceAddress(outputs, kIndex0); + int64_t *decoded_values = GetDeviceAddress(outputs, kIndex1); + int64_t *decoded_shape = GetDeviceAddress(outputs, kIndex2); + T *log_probability = GetDeviceAddress(outputs, kIndex3); + + std::vector seq_host(sequence_shape_[0]); + CHECK_CUDA_RET_WITH_ERROR_NOTRACE( + cudaMemcpyAsync(seq_host.data(), sequence_length, sequence_shape_[0] * sizeof(int32_t), cudaMemcpyDeviceToHost, + reinterpret_cast(stream_ptr_)), + "For 'CTCGreedyDecoder', cudaMemcpy beta failed"); + if (cudaStreamQuery(reinterpret_cast(stream_ptr_)) != cudaSuccess) { + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(reinterpret_cast(stream_ptr_)), + "For 'CTCGreedyDecoder', cudaStreamSyncFailed"); + } + for (int b = 0; b < sequence_shape_[0]; b++) { + if (seq_host[b] > static_cast(max_time_)) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', sequence_length[" << b << "] should be less than " + << max_time_ << ", but got " << seq_host[b] << "."; + } + } + + auto status = CalCTCGreedyDecoder(inputs_x, bound_, max_time_ * batch_size_, batch_size_, decoded_values_temp, + log_probability, device_id_, reinterpret_cast(stream_ptr)); + CHECK_CUDA_STATUS(status, kernel_name_); + + status = Calmerge(decoded_values_temp, sequence_length, batch_size_, bound_, merge_repeated_, log_probability, + nums_count, device_id_, reinterpret_cast(stream_ptr)); + CHECK_CUDA_STATUS(status, kernel_name_); + + status = Calindices(decoded_values_temp, nums_count, batch_size_, decoded_indices, decoded_values, decoded_shape, + device_id_, reinterpret_cast(stream_ptr), &element_cnt_); + CHECK_CUDA_STATUS(status, kernel_name_); + return true; +} + +void CTCGreedyDecoderGpuKernelMod::UpdateOutputShapeAndSize(const std::vector &inputs, + const std::vector &outputs) { + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(reinterpret_cast(stream_ptr_)), + "cudaStreamSynchronized failed"); + + std::vector indices_shape = outputs[kIndex0]->GetShapeVector(); + indices_shape[kIndex0] = element_cnt_; + outputs[kIndex0]->SetShapeVector(std::vector(indices_shape.begin(), indices_shape.end())); + outputs[kIndex0]->set_size( + LongToSize(std::accumulate(indices_shape.begin(), indices_shape.end(), + UnitSizeInBytes(outputs[kIndex0]->dtype_id()), std::multiplies()))); + + std::vector values_shape = outputs[kIndex1]->GetShapeVector(); + values_shape[kIndex0] = element_cnt_; + outputs[kIndex1]->SetShapeVector(std::vector(values_shape.begin(), values_shape.end())); + outputs[kIndex1]->set_size( + LongToSize(std::accumulate(values_shape.begin(), values_shape.end(), UnitSizeInBytes(outputs[kIndex1]->dtype_id()), + std::multiplies()))); + + std::vector log_shape = outputs[kIndex3]->GetShapeVector(); + log_shape[kIndex0] = inputs_x_shape_[1]; + outputs[kIndex3]->SetShapeVector(std::vector(log_shape.begin(), log_shape.end())); + outputs[kIndex3]->set_size(LongToSize(std::accumulate( + log_shape.begin(), log_shape.end(), UnitSizeInBytes(outputs[kIndex3]->dtype_id()), std::multiplies()))); +} + +std::vector> + CTCGreedyDecoderGpuKernelMod::func_list_ = {{KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat32), + &CTCGreedyDecoderGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat64), + &CTCGreedyDecoderGpuKernelMod::LaunchKernel}}; + +std::vector CTCGreedyDecoderGpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, CTCGreedyDecoder, CTCGreedyDecoderGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/ctcgreedydecoder_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/ctcgreedydecoder_gpu_kernel.h index 8ded62a4a2f..23a6a6be664 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/ctcgreedydecoder_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/ctcgreedydecoder_gpu_kernel.h @@ -1,80 +1,80 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CTCGREEDYDECODER_CTCGREEDYDECODER_KERNEL_H_ -#define MINDSPORE_CTCGREEDYDECODER_CTCGREEDYDECODER_KERNEL_H_ - -#include -#include -#include -#include -#include -#include -#include -#include "mindspore/core/ops/ctc_greedy_decoder.h" -#include "mindspore/core/abstract/utils.h" -#include "plugin/device/gpu/kernel/gpu_kernel.h" -#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" -#include "plugin/device/gpu/kernel/kernel_constants.h" -#include "plugin/factory/ms_factory.h" -namespace mindspore { -namespace kernel { -class CTCGreedyDecoderGpuKernelMod : public NativeGpuKernelMod { - public: - CTCGreedyDecoderGpuKernelMod() = default; - ~CTCGreedyDecoderGpuKernelMod() override = default; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - return kernel_func_(this, inputs, workspace, outputs, stream_ptr); - } - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - protected: - std::vector GetOpSupport() override; - template - bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr); - using CTCGreedyDecoderFunc = std::function &, - const std::vector &, const std::vector &, void *)>; - static std::vector> func_list_; - bool IsNeedUpdateOutputShapeAndSize() override { return true; } - void UpdateOutputShapeAndSize(const std::vector &inputs, - const std::vector &outputs) override; - - CTCGreedyDecoderFunc kernel_func_; - - private: - std::vector inputs_x_shape_; - std::vector sequence_shape_; - size_t data_unit_size_; - size_t batch_size_; - size_t max_time_; - int bound_; - bool merge_repeated_{true}; - bool is_null_input_; - int64_t element_cnt_; - void *stream_ptr_; - void ResetResource(); - void InitSizeLists(); -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CTCGREEDYDECODER_CTCGREEDYDECODER_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CTCGREEDYDECODER_CTCGREEDYDECODER_KERNEL_H_ +#define MINDSPORE_CTCGREEDYDECODER_CTCGREEDYDECODER_KERNEL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "mindspore/core/ops/ctc_greedy_decoder.h" +#include "mindspore/core/abstract/utils.h" +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" +#include "plugin/device/gpu/kernel/kernel_constants.h" +#include "plugin/factory/ms_factory.h" +namespace mindspore { +namespace kernel { +class CTCGreedyDecoderGpuKernelMod : public NativeGpuKernelMod { + public: + CTCGreedyDecoderGpuKernelMod() = default; + ~CTCGreedyDecoderGpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + return kernel_func_(this, inputs, workspace, outputs, stream_ptr); + } + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + protected: + std::vector GetOpSupport() override; + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr); + using CTCGreedyDecoderFunc = std::function &, + const std::vector &, const std::vector &, void *)>; + static std::vector> func_list_; + bool IsNeedUpdateOutputShapeAndSize() override { return true; } + void UpdateOutputShapeAndSize(const std::vector &inputs, + const std::vector &outputs) override; + + CTCGreedyDecoderFunc kernel_func_; + + private: + std::vector inputs_x_shape_; + std::vector sequence_shape_; + size_t data_unit_size_; + size_t batch_size_; + size_t max_time_; + int bound_; + bool merge_repeated_{true}; + bool is_null_input_; + int64_t element_cnt_; + void *stream_ptr_; + void ResetResource(); + void InitSizeLists(); +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CTCGREEDYDECODER_CTCGREEDYDECODER_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/ctcloss_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/ctcloss_gpu_kernel.cc index 6646e2fade9..59bd5187f17 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/ctcloss_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/ctcloss_gpu_kernel.cc @@ -1,31 +1,31 @@ -/** - * Copyright 2020-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/nn/ctcloss_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(CTCLoss, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - CtcLossGpuKernelMod, float) -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/nn/ctcloss_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(CTCLoss, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + CtcLossGpuKernelMod, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/deformable_offsets_grad_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/deformable_offsets_grad_gpu_kernel.cc index 8e7b022bbb4..6dc37e7d41d 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/deformable_offsets_grad_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/deformable_offsets_grad_gpu_kernel.cc @@ -1,246 +1,246 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "plugin/device/gpu/kernel/nn/deformable_offsets_grad_gpu_kernel.h" - -#include -#include -#include -#include -#include -#include -#include "abstract/utils.h" -#include "mindspore/core/ops/grad/deformable_offsets_grad.h" - -namespace mindspore { -namespace kernel { -namespace { -constexpr size_t kInputNum = 3; -constexpr size_t kOutputNum = 2; -constexpr size_t kInputShapeSize = 4; - -constexpr size_t kGradIndex = 0; -constexpr size_t kXIndex = 1; -constexpr size_t kOffsetIndex = 2; -constexpr size_t kGradXIndex = 0; -constexpr size_t kGradOffsetIndex = 1; - -auto constexpr kPadStr = "pads"; -auto constexpr kStrideStr = "strides"; -auto constexpr kDilationStr = "dilation"; -auto constexpr kKernelSizeStr = "kernel size"; -auto constexpr kInputXStr = "input_x"; -auto constexpr kInputGradStr = "input_grad"; - -constexpr size_t kPadNum = 4; -constexpr size_t kStrideNum = 4; -constexpr size_t kDilationNum = 4; -constexpr size_t kKernelSizeNum = 2; - -constexpr size_t kCIndexForNCHW = 1; -constexpr size_t kHIndexForNCHW = 2; -constexpr size_t kWIndexForNCHW = 3; -constexpr size_t kHIndexForNHWC = 1; -constexpr size_t kWIndexForNHWC = 2; -constexpr size_t kCIndexForNHWC = 3; - -constexpr size_t kPadTopIndex = 0; -constexpr size_t kPadLeftIndex = 2; -constexpr size_t kStrideHIndex = 2; -constexpr size_t kStrideWIndex = 3; -constexpr size_t kDilationHIndex = 2; -constexpr size_t kDilationWIndex = 3; -constexpr size_t kKernelHIndex = 0; -constexpr size_t kKernelWIndex = 1; - -void CheckSize(const std::string &kernel_name_, const std::string &dim_name, size_t expect, size_t actual) { - if (actual != expect) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the length of '" << dim_name << "' must be " << expect - << ", but got " << actual; - } -} -} // namespace - -bool DeformableOffsetsGradGpuKernelMod::Launch(const std::vector &inputs, - const std::vector &, - const std::vector &outputs, void *stream_ptr) { - cuda_stream_ = reinterpret_cast(stream_ptr); - return kernel_func_(this, inputs, outputs); -} - -bool DeformableOffsetsGradGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - if (inputs.size() != kInputNum || outputs.size() != kOutputNum) { - MS_LOG(ERROR) << kernel_name_ << ": input and output size should be " << kInputNum << " and " << kOutputNum - << ", but get " << inputs.size() << " and " << outputs.size(); - return false; - } - - data_format_ = GetValue(primitive_->GetAttr("format")); - - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "' does not support this kernel type: " << kernel_attr; - return false; - } - kernel_func_ = func_list_[index].second; - type_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(0).dtype); - return true; -} - -void DeformableOffsetsGradGpuKernelMod::SetDims(const std::vector &inputs, - const std::vector &outputs) { - auto kernel_name_ = primitive_->name(); - dims_.deformable_group = LongToUint(GetValue(primitive_->GetAttr("deformable_groups"))); - if (dims_.deformable_group == 0) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', deformable group must be greater than 0."; - } - std::vector pad = GetValue>(primitive_->GetAttr("pads")); - CheckSize(kernel_name_, kPadStr, kPadNum, pad.size()); - dims_.pad_top = LongToUint(pad[kPadTopIndex]); - dims_.pad_left = LongToUint(pad[kPadLeftIndex]); - - std::vector stride = GetValue>(primitive_->GetAttr("strides")); - CheckSize(kernel_name_, kStrideStr, kStrideNum, stride.size()); - dims_.stride_h = LongToUint(stride[kStrideHIndex]); - dims_.stride_w = LongToUint(stride[kStrideWIndex]); - - std::vector dilation = GetValue>(primitive_->GetAttr("dilations")); - CheckSize(kernel_name_, kDilationStr, kDilationNum, dilation.size()); - dims_.dilation_h = LongToUint(dilation[kDilationHIndex]); - dims_.dilation_w = LongToUint(dilation[kDilationWIndex]); - - std::vector ksize = GetValue>(primitive_->GetAttr("ksize")); - CheckSize(kernel_name_, kKernelSizeStr, kKernelSizeNum, ksize.size()); - dims_.kernel_h = LongToUint(ksize[kKernelHIndex]); - dims_.kernel_w = LongToUint(ksize[kKernelWIndex]); - if (dims_.kernel_h == 0 || dims_.kernel_w == 0) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the value of 'ksize' must be larger than 0."; - } - auto x_shape = inputs[kXIndex]->GetShapeVector(); - CheckSize(kernel_name_, kInputXStr, kInputShapeSize, x_shape.size()); - dims_.x_n = LongToUint(x_shape[0]); - auto grad_shape = inputs[kGradIndex]->GetShapeVector(); - CheckSize(kernel_name_, kInputGradStr, kInputShapeSize, grad_shape.size()); - if (data_format_ == kOpFormat_NCHW) { - dims_.grad_h = LongToUint(grad_shape[kHIndexForNCHW]); - dims_.grad_w = LongToUint(grad_shape[kWIndexForNCHW]); - dims_.x_h = LongToUint(x_shape[kHIndexForNCHW]); - dims_.x_w = LongToUint(x_shape[kWIndexForNCHW]); - dims_.deformable_group_channel = LongToUint(x_shape[kCIndexForNCHW]) / dims_.deformable_group; - } else { - dims_.grad_h = LongToUint(grad_shape[kHIndexForNHWC]); - dims_.grad_w = LongToUint(grad_shape[kWIndexForNHWC]); - dims_.x_h = LongToUint(x_shape[kHIndexForNHWC]); - dims_.x_w = LongToUint(x_shape[kWIndexForNHWC]); - dims_.deformable_group_channel = LongToUint(x_shape[kCIndexForNHWC]) / dims_.deformable_group; - } - dims_.offset_h = dims_.grad_h / dims_.kernel_h; - dims_.offset_w = dims_.grad_w / dims_.kernel_w; - - auto grad_x_shape = outputs[kGradXIndex]->GetShapeVector(); - grad_x_size_ = std::accumulate(grad_x_shape.begin(), grad_x_shape.end(), type_size_, std::multiplies()); - - auto grad_offset_shape = outputs[kGradOffsetIndex]->GetShapeVector(); - grad_offset_size_ = - std::accumulate(grad_offset_shape.begin(), grad_offset_shape.end(), type_size_, std::multiplies()); -} - -int DeformableOffsetsGradGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - int ret = KernelMod::Resize(inputs, outputs); - if (ret != 0) { - MS_LOG(ERROR) << kernel_name_ << " kernel mode resize failed."; - return ret; - } - if (inputs.size() != kInputNum || output_size_list_.size() != kOutputNum) { - MS_LOG(ERROR) << kernel_name_ << " resize : input and output size should be " << kInputNum << " and " << kOutputNum - << ", but got " << inputs.size() << " and " << output_size_list_.size(); - return KRET_RESIZE_FAILED; - } - SetDims(inputs, outputs); - return KRET_OK; -} - -template -bool DeformableOffsetsGradGpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &outputs) { - T *grad_addr = GetDeviceAddress(inputs, kGradIndex); - T *x_addr = GetDeviceAddress(inputs, kXIndex); - T *offset_addr = GetDeviceAddress(inputs, kOffsetIndex); - T *grad_x_addr = GetDeviceAddress(outputs, kGradXIndex); - T *grad_offset_addr = GetDeviceAddress(outputs, kGradOffsetIndex); - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemsetAsync(grad_x_addr, 0, grad_x_size_, cuda_stream_), - "Call cudaMemsetAsync grad_x failed"); - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemsetAsync(grad_offset_addr, 0, grad_offset_size_, cuda_stream_), - "Call cudaMemsetAsync grad_x failed"); - uint dim_x_n = dims_.x_n; - uint dim_x_h = dims_.x_h; - uint dim_x_w = dims_.x_w; - uint dim_offset_h = dims_.offset_h; - uint dim_offset_w = dims_.offset_w; - uint dim_kernel_h = dims_.kernel_h; - uint dim_kernel_w = dims_.kernel_w; - uint dim_pad_top = dims_.pad_top; - uint dim_pad_left = dims_.pad_left; - uint dim_stride_h = dims_.stride_h; - uint dim_stride_w = dims_.stride_w; - uint dim_dilation_h = dims_.dilation_h; - uint dim_dilation_w = dims_.dilation_w; - uint dim_deformable_group = dims_.deformable_group; - uint dim_deformable_group_channel = dims_.deformable_group_channel; - cudaError_t status = cudaErrorNotReady; - if (data_format_ == kOpFormat_NCHW) { - status = ApplyDeformableOffsetGrad( - dim_x_n, dim_x_h, dim_x_w, dim_offset_h, dim_offset_w, dim_kernel_h, dim_kernel_w, dim_pad_top, dim_pad_left, - dim_stride_h, dim_stride_w, dim_dilation_h, dim_dilation_w, dim_deformable_group, dim_deformable_group_channel, - true, grad_addr, x_addr, offset_addr, grad_x_addr, grad_offset_addr, device_id_, cuda_stream_); - } else { - status = ApplyDeformableOffsetGrad( - dim_x_n, dim_x_h, dim_x_w, dim_offset_h, dim_offset_w, dim_kernel_h, dim_kernel_w, dim_pad_top, dim_pad_left, - dim_stride_h, dim_stride_w, dim_dilation_h, dim_dilation_w, dim_deformable_group, dim_deformable_group_channel, - false, grad_addr, x_addr, offset_addr, grad_x_addr, grad_offset_addr, device_id_, cuda_stream_); - } - CHECK_CUDA_STATUS(status, kernel_name_); - return true; -} - -std::vector> - DeformableOffsetsGradGpuKernelMod::func_list_ = {{KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - &DeformableOffsetsGradGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - &DeformableOffsetsGradGpuKernelMod::LaunchKernel}}; - -std::vector DeformableOffsetsGradGpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - return support_list; -} - -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, DeformableOffsetsGrad, DeformableOffsetsGradGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "plugin/device/gpu/kernel/nn/deformable_offsets_grad_gpu_kernel.h" + +#include +#include +#include +#include +#include +#include +#include "abstract/utils.h" +#include "mindspore/core/ops/grad/deformable_offsets_grad.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kInputNum = 3; +constexpr size_t kOutputNum = 2; +constexpr size_t kInputShapeSize = 4; + +constexpr size_t kGradIndex = 0; +constexpr size_t kXIndex = 1; +constexpr size_t kOffsetIndex = 2; +constexpr size_t kGradXIndex = 0; +constexpr size_t kGradOffsetIndex = 1; + +auto constexpr kPadStr = "pads"; +auto constexpr kStrideStr = "strides"; +auto constexpr kDilationStr = "dilation"; +auto constexpr kKernelSizeStr = "kernel size"; +auto constexpr kInputXStr = "input_x"; +auto constexpr kInputGradStr = "input_grad"; + +constexpr size_t kPadNum = 4; +constexpr size_t kStrideNum = 4; +constexpr size_t kDilationNum = 4; +constexpr size_t kKernelSizeNum = 2; + +constexpr size_t kCIndexForNCHW = 1; +constexpr size_t kHIndexForNCHW = 2; +constexpr size_t kWIndexForNCHW = 3; +constexpr size_t kHIndexForNHWC = 1; +constexpr size_t kWIndexForNHWC = 2; +constexpr size_t kCIndexForNHWC = 3; + +constexpr size_t kPadTopIndex = 0; +constexpr size_t kPadLeftIndex = 2; +constexpr size_t kStrideHIndex = 2; +constexpr size_t kStrideWIndex = 3; +constexpr size_t kDilationHIndex = 2; +constexpr size_t kDilationWIndex = 3; +constexpr size_t kKernelHIndex = 0; +constexpr size_t kKernelWIndex = 1; + +void CheckSize(const std::string &kernel_name_, const std::string &dim_name, size_t expect, size_t actual) { + if (actual != expect) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the length of '" << dim_name << "' must be " << expect + << ", but got " << actual; + } +} +} // namespace + +bool DeformableOffsetsGradGpuKernelMod::Launch(const std::vector &inputs, + const std::vector &, + const std::vector &outputs, void *stream_ptr) { + cuda_stream_ = reinterpret_cast(stream_ptr); + return kernel_func_(this, inputs, outputs); +} + +bool DeformableOffsetsGradGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + if (inputs.size() != kInputNum || outputs.size() != kOutputNum) { + MS_LOG(ERROR) << kernel_name_ << ": input and output size should be " << kInputNum << " and " << kOutputNum + << ", but get " << inputs.size() << " and " << outputs.size(); + return false; + } + + data_format_ = GetValue(primitive_->GetAttr("format")); + + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "' does not support this kernel type: " << kernel_attr; + return false; + } + kernel_func_ = func_list_[index].second; + type_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(0).dtype); + return true; +} + +void DeformableOffsetsGradGpuKernelMod::SetDims(const std::vector &inputs, + const std::vector &outputs) { + auto kernel_name_ = primitive_->name(); + dims_.deformable_group = LongToUint(GetValue(primitive_->GetAttr("deformable_groups"))); + if (dims_.deformable_group == 0) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', deformable group must be greater than 0."; + } + std::vector pad = GetValue>(primitive_->GetAttr("pads")); + CheckSize(kernel_name_, kPadStr, kPadNum, pad.size()); + dims_.pad_top = LongToUint(pad[kPadTopIndex]); + dims_.pad_left = LongToUint(pad[kPadLeftIndex]); + + std::vector stride = GetValue>(primitive_->GetAttr("strides")); + CheckSize(kernel_name_, kStrideStr, kStrideNum, stride.size()); + dims_.stride_h = LongToUint(stride[kStrideHIndex]); + dims_.stride_w = LongToUint(stride[kStrideWIndex]); + + std::vector dilation = GetValue>(primitive_->GetAttr("dilations")); + CheckSize(kernel_name_, kDilationStr, kDilationNum, dilation.size()); + dims_.dilation_h = LongToUint(dilation[kDilationHIndex]); + dims_.dilation_w = LongToUint(dilation[kDilationWIndex]); + + std::vector ksize = GetValue>(primitive_->GetAttr("ksize")); + CheckSize(kernel_name_, kKernelSizeStr, kKernelSizeNum, ksize.size()); + dims_.kernel_h = LongToUint(ksize[kKernelHIndex]); + dims_.kernel_w = LongToUint(ksize[kKernelWIndex]); + if (dims_.kernel_h == 0 || dims_.kernel_w == 0) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the value of 'ksize' must be larger than 0."; + } + auto x_shape = inputs[kXIndex]->GetShapeVector(); + CheckSize(kernel_name_, kInputXStr, kInputShapeSize, x_shape.size()); + dims_.x_n = LongToUint(x_shape[0]); + auto grad_shape = inputs[kGradIndex]->GetShapeVector(); + CheckSize(kernel_name_, kInputGradStr, kInputShapeSize, grad_shape.size()); + if (data_format_ == kOpFormat_NCHW) { + dims_.grad_h = LongToUint(grad_shape[kHIndexForNCHW]); + dims_.grad_w = LongToUint(grad_shape[kWIndexForNCHW]); + dims_.x_h = LongToUint(x_shape[kHIndexForNCHW]); + dims_.x_w = LongToUint(x_shape[kWIndexForNCHW]); + dims_.deformable_group_channel = LongToUint(x_shape[kCIndexForNCHW]) / dims_.deformable_group; + } else { + dims_.grad_h = LongToUint(grad_shape[kHIndexForNHWC]); + dims_.grad_w = LongToUint(grad_shape[kWIndexForNHWC]); + dims_.x_h = LongToUint(x_shape[kHIndexForNHWC]); + dims_.x_w = LongToUint(x_shape[kWIndexForNHWC]); + dims_.deformable_group_channel = LongToUint(x_shape[kCIndexForNHWC]) / dims_.deformable_group; + } + dims_.offset_h = dims_.grad_h / dims_.kernel_h; + dims_.offset_w = dims_.grad_w / dims_.kernel_w; + + auto grad_x_shape = outputs[kGradXIndex]->GetShapeVector(); + grad_x_size_ = std::accumulate(grad_x_shape.begin(), grad_x_shape.end(), type_size_, std::multiplies()); + + auto grad_offset_shape = outputs[kGradOffsetIndex]->GetShapeVector(); + grad_offset_size_ = + std::accumulate(grad_offset_shape.begin(), grad_offset_shape.end(), type_size_, std::multiplies()); +} + +int DeformableOffsetsGradGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + int ret = KernelMod::Resize(inputs, outputs); + if (ret != 0) { + MS_LOG(ERROR) << kernel_name_ << " kernel mode resize failed."; + return ret; + } + if (inputs.size() != kInputNum || output_size_list_.size() != kOutputNum) { + MS_LOG(ERROR) << kernel_name_ << " resize : input and output size should be " << kInputNum << " and " << kOutputNum + << ", but got " << inputs.size() << " and " << output_size_list_.size(); + return KRET_RESIZE_FAILED; + } + SetDims(inputs, outputs); + return KRET_OK; +} + +template +bool DeformableOffsetsGradGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { + T *grad_addr = GetDeviceAddress(inputs, kGradIndex); + T *x_addr = GetDeviceAddress(inputs, kXIndex); + T *offset_addr = GetDeviceAddress(inputs, kOffsetIndex); + T *grad_x_addr = GetDeviceAddress(outputs, kGradXIndex); + T *grad_offset_addr = GetDeviceAddress(outputs, kGradOffsetIndex); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemsetAsync(grad_x_addr, 0, grad_x_size_, cuda_stream_), + "Call cudaMemsetAsync grad_x failed"); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemsetAsync(grad_offset_addr, 0, grad_offset_size_, cuda_stream_), + "Call cudaMemsetAsync grad_x failed"); + uint dim_x_n = dims_.x_n; + uint dim_x_h = dims_.x_h; + uint dim_x_w = dims_.x_w; + uint dim_offset_h = dims_.offset_h; + uint dim_offset_w = dims_.offset_w; + uint dim_kernel_h = dims_.kernel_h; + uint dim_kernel_w = dims_.kernel_w; + uint dim_pad_top = dims_.pad_top; + uint dim_pad_left = dims_.pad_left; + uint dim_stride_h = dims_.stride_h; + uint dim_stride_w = dims_.stride_w; + uint dim_dilation_h = dims_.dilation_h; + uint dim_dilation_w = dims_.dilation_w; + uint dim_deformable_group = dims_.deformable_group; + uint dim_deformable_group_channel = dims_.deformable_group_channel; + cudaError_t status = cudaErrorNotReady; + if (data_format_ == kOpFormat_NCHW) { + status = ApplyDeformableOffsetGrad( + dim_x_n, dim_x_h, dim_x_w, dim_offset_h, dim_offset_w, dim_kernel_h, dim_kernel_w, dim_pad_top, dim_pad_left, + dim_stride_h, dim_stride_w, dim_dilation_h, dim_dilation_w, dim_deformable_group, dim_deformable_group_channel, + true, grad_addr, x_addr, offset_addr, grad_x_addr, grad_offset_addr, device_id_, cuda_stream_); + } else { + status = ApplyDeformableOffsetGrad( + dim_x_n, dim_x_h, dim_x_w, dim_offset_h, dim_offset_w, dim_kernel_h, dim_kernel_w, dim_pad_top, dim_pad_left, + dim_stride_h, dim_stride_w, dim_dilation_h, dim_dilation_w, dim_deformable_group, dim_deformable_group_channel, + false, grad_addr, x_addr, offset_addr, grad_x_addr, grad_offset_addr, device_id_, cuda_stream_); + } + CHECK_CUDA_STATUS(status, kernel_name_); + return true; +} + +std::vector> + DeformableOffsetsGradGpuKernelMod::func_list_ = {{KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + &DeformableOffsetsGradGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + &DeformableOffsetsGradGpuKernelMod::LaunchKernel}}; + +std::vector DeformableOffsetsGradGpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, DeformableOffsetsGrad, DeformableOffsetsGradGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/grid_sampler_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/grid_sampler_gpu_kernel.cc index 9166301ae0e..9d453aa44d2 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/grid_sampler_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/grid_sampler_gpu_kernel.cc @@ -1,91 +1,91 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/nn/grid_sampler_gpu_kernel.h" -#include -#include - -namespace mindspore { -namespace kernel { -std::vector> GridSampler2DGpuKernelMod::func_list_ = { - {KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeBool) - .AddOutputAttr(kNumberTypeFloat16), - &GridSampler2DGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeBool) - .AddOutputAttr(kNumberTypeFloat32), - &GridSampler2DGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeBool) - .AddOutputAttr(kNumberTypeFloat64), - &GridSampler2DGpuKernelMod::LaunchKernel}}; - -std::vector GridSampler2DGpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - return support_list; -} - -std::vector> GridSampler3DGpuKernelMod::func_list_ = { - {KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeBool) - .AddOutputAttr(kNumberTypeFloat16), - &GridSampler3DGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeBool) - .AddOutputAttr(kNumberTypeFloat32), - &GridSampler3DGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) - .AddInputAttr(kObjectTypeNumber, kNumberTypeBool) - .AddOutputAttr(kNumberTypeFloat64), - &GridSampler3DGpuKernelMod::LaunchKernel}}; - -std::vector GridSampler3DGpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - return support_list; -} -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, GridSampler2D, GridSampler2DGpuKernelMod); -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, GridSampler3D, GridSampler3DGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/nn/grid_sampler_gpu_kernel.h" +#include +#include + +namespace mindspore { +namespace kernel { +std::vector> GridSampler2DGpuKernelMod::func_list_ = { + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeBool) + .AddOutputAttr(kNumberTypeFloat16), + &GridSampler2DGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeBool) + .AddOutputAttr(kNumberTypeFloat32), + &GridSampler2DGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeBool) + .AddOutputAttr(kNumberTypeFloat64), + &GridSampler2DGpuKernelMod::LaunchKernel}}; + +std::vector GridSampler2DGpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} + +std::vector> GridSampler3DGpuKernelMod::func_list_ = { + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeBool) + .AddOutputAttr(kNumberTypeFloat16), + &GridSampler3DGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeBool) + .AddOutputAttr(kNumberTypeFloat32), + &GridSampler3DGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64) + .AddInputAttr(kObjectTypeNumber, kNumberTypeBool) + .AddOutputAttr(kNumberTypeFloat64), + &GridSampler3DGpuKernelMod::LaunchKernel}}; + +std::vector GridSampler3DGpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, GridSampler2D, GridSampler2DGpuKernelMod); +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, GridSampler3D, GridSampler3DGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/grid_sampler_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/grid_sampler_gpu_kernel.h index 95abee63141..7c4cf637d29 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/grid_sampler_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/grid_sampler_gpu_kernel.h @@ -1,292 +1,292 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GRID_SAMPLER_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GRID_SAMPLER_GPU_KERNEL_H_ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "mindspore/core/ops/ops_func_impl/grid_sampler_2d.h" -#include "mindspore/core/ops/ops_func_impl/grid_sampler_3d.h" -#include "mindspore/core/ops/op_enum.h" -#include "plugin/device/gpu/kernel/gpu_kernel.h" -#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/grid_sampler_impl.cuh" - -namespace mindspore { -namespace kernel { -class GridSampler2DGpuKernelMod : public NativeGpuKernelMod { - public: - GridSampler2DGpuKernelMod() { ResetResource(); } - ~GridSampler2DGpuKernelMod() override = default; - - template - bool LaunchKernel(const std::vector &inputs, const std::vector &outputs, - void *stream_ptr) { - if (is_null_input_) { - return true; - } - interpolation_mode_ = static_cast(inputs[kIndex2]->GetValueWithCheck()); - padding_mode_ = static_cast(inputs[kIndex3]->GetValueWithCheck()); - align_corners_ = inputs[kIndex4]->GetValueWithCheck(); - T *input_addr = GetDeviceAddress(inputs, kIndex0); - T *grid_addr = GetDeviceAddress(inputs, kIndex1); - T *output_addr = GetDeviceAddress(outputs, kIndex0); - auto status = GridSampler2D(size_, input_addr, grid_addr, output_addr, input_shape_, grid_shape_, output_shape_, - input_stride_, grid_stride_, output_stride_, interpolation_mode_, padding_mode_, - align_corners_, reinterpret_cast(stream_ptr)); - CHECK_CUDA_STATUS(status, kernel_name_); - return true; - } - - bool Init(const std::vector &inputs, const std::vector &outputs) override { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kGridSamplerInputNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kGridSamplerOutputNum, kernel_name_); - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "' does not support this kernel type: " << kernel_attr; - return false; - } - kernel_func_ = func_list_[index].second; - return true; - } - - int Resize(const std::vector &inputs, const std::vector &outputs) { - int ret = KernelMod::Resize(inputs, outputs); - if (ret != 0) { - return ret; - } - auto convert_int64_shape_to_sizet_shape = [=](std::vector int64_shape) -> std::vector { - std::vector size_t_shape; - (void)std::transform(int64_shape.begin(), int64_shape.end(), std::back_inserter(size_t_shape), LongToSize); - return size_t_shape; - }; - input_shape_ = convert_int64_shape_to_sizet_shape(inputs[kIndex0]->GetShapeVector()); - grid_shape_ = convert_int64_shape_to_sizet_shape(inputs[kIndex1]->GetShapeVector()); - output_shape_ = convert_int64_shape_to_sizet_shape(outputs[kIndex0]->GetShapeVector()); - - if (input_shape_.empty()) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', the 'input' must be at 4-D, but got scalar or None."; - return KRET_RESIZE_FAILED; - } - - if (grid_shape_.empty()) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', the 'grid' must be at 4-D, but got scalar or None."; - return KRET_RESIZE_FAILED; - } - - if (output_shape_.empty()) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', the 'output' must be at 4-D, but got scalar or None."; - return KRET_RESIZE_FAILED; - } - - size_t stride_tmp = 1; - auto stride_compute = [&](std::vector &stride, std::vector shape) { - for (int i = 3; i > -static_cast(1); i--) { - (void)stride.insert(stride.begin(), stride_tmp); - stride_tmp *= shape[static_cast(i)]; - } - stride_tmp = 1; - }; - input_stride_.clear(); - grid_stride_.clear(); - output_stride_.clear(); - stride_compute(input_stride_, input_shape_); - stride_compute(grid_stride_, grid_shape_); - stride_compute(output_stride_, output_shape_); - size_ = input_shape_[kIndex0] * grid_shape_[kIndex1] * grid_shape_[kIndex2]; - return KRET_OK; - } - - void ResetResource() noexcept { - input_shape_.clear(); - grid_shape_.clear(); - output_shape_.clear(); - input_stride_.clear(); - grid_stride_.clear(); - output_stride_.clear(); - size_ = 0; - interpolation_mode_ = GridSamplerInterpolationMode::BILINEAR; - padding_mode_ = GridSamplerPaddingMode::ZEROS; - align_corners_ = false; - is_null_input_ = false; - output_size_list_.clear(); - workspace_size_list_.clear(); - } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kGridSamplerInputNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kGridSamplerOutputNum, kernel_name_); - kernel_func_(this, inputs, outputs, stream_ptr); - return true; - } - - std::vector GetOpSupport() override; - - private: - using KernelFunc = std::function &, - const std::vector &, void *)>; - KernelFunc kernel_func_{}; - static std::vector> func_list_; - size_t size_; - std::vector input_shape_; - std::vector grid_shape_; - std::vector output_shape_; - std::vector input_stride_; - std::vector grid_stride_; - std::vector output_stride_; - GridSamplerInterpolationMode interpolation_mode_; - GridSamplerPaddingMode padding_mode_; - bool align_corners_; - bool is_null_input_; -}; - -class GridSampler3DGpuKernelMod : public NativeGpuKernelMod { - public: - GridSampler3DGpuKernelMod() { ResetResource(); } - ~GridSampler3DGpuKernelMod() override = default; - - template - bool LaunchKernel(const std::vector &inputs, const std::vector &outputs, - void *stream_ptr) { - if (is_null_input_) { - return true; - } - interpolation_mode_ = static_cast(inputs[kIndex2]->GetValueWithCheck()); - padding_mode_ = static_cast(inputs[kIndex3]->GetValueWithCheck()); - align_corners_ = inputs[kIndex4]->GetValueWithCheck(); - T *input_addr = GetDeviceAddress(inputs, kIndex0); - T *grid_addr = GetDeviceAddress(inputs, kIndex1); - T *output_addr = GetDeviceAddress(outputs, kIndex0); - GridSampler3D(size_, input_addr, grid_addr, output_addr, input_shape_, grid_shape_, output_shape_, input_stride_, - grid_stride_, output_stride_, interpolation_mode_, padding_mode_, align_corners_, - reinterpret_cast(stream_ptr)); - return true; - } - - bool Init(const std::vector &inputs, const std::vector &outputs) override { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kGridSamplerInputNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kGridSamplerOutputNum, kernel_name_); - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "' does not support this kernel type: " << kernel_attr; - return false; - } - kernel_func_ = func_list_[index].second; - return true; - } - - int Resize(const std::vector &inputs, const std::vector &outputs) { - int ret = KernelMod::Resize(inputs, outputs); - if (ret != 0) { - return ret; - } - auto convert_int64_shape_to_sizet_shape = [=](std::vector int64_shape) -> std::vector { - std::vector size_t_shape; - (void)std::transform(int64_shape.begin(), int64_shape.end(), std::back_inserter(size_t_shape), LongToSize); - return size_t_shape; - }; - input_shape_ = convert_int64_shape_to_sizet_shape(inputs[kIndex0]->GetShapeVector()); - grid_shape_ = convert_int64_shape_to_sizet_shape(inputs[kIndex1]->GetShapeVector()); - output_shape_ = convert_int64_shape_to_sizet_shape(outputs[kIndex0]->GetShapeVector()); - - if (input_shape_.empty()) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', the 'input' must be at 5-D, but got scalar or None."; - return KRET_RESIZE_FAILED; - } - - if (grid_shape_.empty()) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', the 'grid' must be at 5-D, but got scalar or None."; - return KRET_RESIZE_FAILED; - } - - if (output_shape_.empty()) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', the 'output' must be at 5-D, but got scalar or None."; - return KRET_RESIZE_FAILED; - } - - size_t stride_tmp = 1; - auto stride_compute = [&](std::vector &stride, std::vector shape) { - for (int i = 4; i > -static_cast(1); i--) { - (void)stride.insert(stride.begin(), stride_tmp); - stride_tmp *= shape[static_cast(i)]; - } - stride_tmp = 1; - }; - input_stride_.clear(); - grid_stride_.clear(); - output_stride_.clear(); - stride_compute(input_stride_, input_shape_); - stride_compute(grid_stride_, grid_shape_); - stride_compute(output_stride_, output_shape_); - size_ = input_shape_[kIndex0] * grid_shape_[kIndex1] * grid_shape_[kIndex2] * grid_shape_[kIndex3]; - return KRET_OK; - } - - void ResetResource() noexcept { - input_shape_.clear(); - grid_shape_.clear(); - output_shape_.clear(); - input_stride_.clear(); - grid_stride_.clear(); - output_stride_.clear(); - size_ = 0; - interpolation_mode_ = GridSamplerInterpolationMode::BILINEAR; - padding_mode_ = GridSamplerPaddingMode::ZEROS; - align_corners_ = false; - is_null_input_ = false; - workspace_size_list_.clear(); - } - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kGridSamplerInputNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kGridSamplerOutputNum, kernel_name_); - kernel_func_(this, inputs, outputs, stream_ptr); - return true; - } - - std::vector GetOpSupport() override; - - private: - using KernelFunc = std::function &, - const std::vector &, void *)>; - KernelFunc kernel_func_{}; - static std::vector> func_list_; - std::vector input_shape_; - std::vector grid_shape_; - std::vector output_shape_; - std::vector input_stride_; - std::vector grid_stride_; - std::vector output_stride_; - size_t size_; - GridSamplerInterpolationMode interpolation_mode_; - GridSamplerPaddingMode padding_mode_; - bool align_corners_; - bool is_null_input_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GRID_SAMPLER_GPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GRID_SAMPLER_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GRID_SAMPLER_GPU_KERNEL_H_ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "mindspore/core/ops/ops_func_impl/grid_sampler_2d.h" +#include "mindspore/core/ops/ops_func_impl/grid_sampler_3d.h" +#include "mindspore/core/ops/op_enum.h" +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/grid_sampler_impl.cuh" + +namespace mindspore { +namespace kernel { +class GridSampler2DGpuKernelMod : public NativeGpuKernelMod { + public: + GridSampler2DGpuKernelMod() { ResetResource(); } + ~GridSampler2DGpuKernelMod() override = default; + + template + bool LaunchKernel(const std::vector &inputs, const std::vector &outputs, + void *stream_ptr) { + if (is_null_input_) { + return true; + } + interpolation_mode_ = static_cast(inputs[kIndex2]->GetValueWithCheck()); + padding_mode_ = static_cast(inputs[kIndex3]->GetValueWithCheck()); + align_corners_ = inputs[kIndex4]->GetValueWithCheck(); + T *input_addr = GetDeviceAddress(inputs, kIndex0); + T *grid_addr = GetDeviceAddress(inputs, kIndex1); + T *output_addr = GetDeviceAddress(outputs, kIndex0); + auto status = GridSampler2D(size_, input_addr, grid_addr, output_addr, input_shape_, grid_shape_, output_shape_, + input_stride_, grid_stride_, output_stride_, interpolation_mode_, padding_mode_, + align_corners_, reinterpret_cast(stream_ptr)); + CHECK_CUDA_STATUS(status, kernel_name_); + return true; + } + + bool Init(const std::vector &inputs, const std::vector &outputs) override { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kGridSamplerInputNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kGridSamplerOutputNum, kernel_name_); + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "' does not support this kernel type: " << kernel_attr; + return false; + } + kernel_func_ = func_list_[index].second; + return true; + } + + int Resize(const std::vector &inputs, const std::vector &outputs) { + int ret = KernelMod::Resize(inputs, outputs); + if (ret != 0) { + return ret; + } + auto convert_int64_shape_to_sizet_shape = [=](std::vector int64_shape) -> std::vector { + std::vector size_t_shape; + (void)std::transform(int64_shape.begin(), int64_shape.end(), std::back_inserter(size_t_shape), LongToSize); + return size_t_shape; + }; + input_shape_ = convert_int64_shape_to_sizet_shape(inputs[kIndex0]->GetShapeVector()); + grid_shape_ = convert_int64_shape_to_sizet_shape(inputs[kIndex1]->GetShapeVector()); + output_shape_ = convert_int64_shape_to_sizet_shape(outputs[kIndex0]->GetShapeVector()); + + if (input_shape_.empty()) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', the 'input' must be at 4-D, but got scalar or None."; + return KRET_RESIZE_FAILED; + } + + if (grid_shape_.empty()) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', the 'grid' must be at 4-D, but got scalar or None."; + return KRET_RESIZE_FAILED; + } + + if (output_shape_.empty()) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', the 'output' must be at 4-D, but got scalar or None."; + return KRET_RESIZE_FAILED; + } + + size_t stride_tmp = 1; + auto stride_compute = [&](std::vector &stride, std::vector shape) { + for (int i = 3; i > -static_cast(1); i--) { + (void)stride.insert(stride.begin(), stride_tmp); + stride_tmp *= shape[static_cast(i)]; + } + stride_tmp = 1; + }; + input_stride_.clear(); + grid_stride_.clear(); + output_stride_.clear(); + stride_compute(input_stride_, input_shape_); + stride_compute(grid_stride_, grid_shape_); + stride_compute(output_stride_, output_shape_); + size_ = input_shape_[kIndex0] * grid_shape_[kIndex1] * grid_shape_[kIndex2]; + return KRET_OK; + } + + void ResetResource() noexcept { + input_shape_.clear(); + grid_shape_.clear(); + output_shape_.clear(); + input_stride_.clear(); + grid_stride_.clear(); + output_stride_.clear(); + size_ = 0; + interpolation_mode_ = GridSamplerInterpolationMode::BILINEAR; + padding_mode_ = GridSamplerPaddingMode::ZEROS; + align_corners_ = false; + is_null_input_ = false; + output_size_list_.clear(); + workspace_size_list_.clear(); + } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kGridSamplerInputNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kGridSamplerOutputNum, kernel_name_); + kernel_func_(this, inputs, outputs, stream_ptr); + return true; + } + + std::vector GetOpSupport() override; + + private: + using KernelFunc = std::function &, + const std::vector &, void *)>; + KernelFunc kernel_func_{}; + static std::vector> func_list_; + size_t size_; + std::vector input_shape_; + std::vector grid_shape_; + std::vector output_shape_; + std::vector input_stride_; + std::vector grid_stride_; + std::vector output_stride_; + GridSamplerInterpolationMode interpolation_mode_; + GridSamplerPaddingMode padding_mode_; + bool align_corners_; + bool is_null_input_; +}; + +class GridSampler3DGpuKernelMod : public NativeGpuKernelMod { + public: + GridSampler3DGpuKernelMod() { ResetResource(); } + ~GridSampler3DGpuKernelMod() override = default; + + template + bool LaunchKernel(const std::vector &inputs, const std::vector &outputs, + void *stream_ptr) { + if (is_null_input_) { + return true; + } + interpolation_mode_ = static_cast(inputs[kIndex2]->GetValueWithCheck()); + padding_mode_ = static_cast(inputs[kIndex3]->GetValueWithCheck()); + align_corners_ = inputs[kIndex4]->GetValueWithCheck(); + T *input_addr = GetDeviceAddress(inputs, kIndex0); + T *grid_addr = GetDeviceAddress(inputs, kIndex1); + T *output_addr = GetDeviceAddress(outputs, kIndex0); + GridSampler3D(size_, input_addr, grid_addr, output_addr, input_shape_, grid_shape_, output_shape_, input_stride_, + grid_stride_, output_stride_, interpolation_mode_, padding_mode_, align_corners_, + reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const std::vector &inputs, const std::vector &outputs) override { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kGridSamplerInputNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kGridSamplerOutputNum, kernel_name_); + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "' does not support this kernel type: " << kernel_attr; + return false; + } + kernel_func_ = func_list_[index].second; + return true; + } + + int Resize(const std::vector &inputs, const std::vector &outputs) { + int ret = KernelMod::Resize(inputs, outputs); + if (ret != 0) { + return ret; + } + auto convert_int64_shape_to_sizet_shape = [=](std::vector int64_shape) -> std::vector { + std::vector size_t_shape; + (void)std::transform(int64_shape.begin(), int64_shape.end(), std::back_inserter(size_t_shape), LongToSize); + return size_t_shape; + }; + input_shape_ = convert_int64_shape_to_sizet_shape(inputs[kIndex0]->GetShapeVector()); + grid_shape_ = convert_int64_shape_to_sizet_shape(inputs[kIndex1]->GetShapeVector()); + output_shape_ = convert_int64_shape_to_sizet_shape(outputs[kIndex0]->GetShapeVector()); + + if (input_shape_.empty()) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', the 'input' must be at 5-D, but got scalar or None."; + return KRET_RESIZE_FAILED; + } + + if (grid_shape_.empty()) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', the 'grid' must be at 5-D, but got scalar or None."; + return KRET_RESIZE_FAILED; + } + + if (output_shape_.empty()) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', the 'output' must be at 5-D, but got scalar or None."; + return KRET_RESIZE_FAILED; + } + + size_t stride_tmp = 1; + auto stride_compute = [&](std::vector &stride, std::vector shape) { + for (int i = 4; i > -static_cast(1); i--) { + (void)stride.insert(stride.begin(), stride_tmp); + stride_tmp *= shape[static_cast(i)]; + } + stride_tmp = 1; + }; + input_stride_.clear(); + grid_stride_.clear(); + output_stride_.clear(); + stride_compute(input_stride_, input_shape_); + stride_compute(grid_stride_, grid_shape_); + stride_compute(output_stride_, output_shape_); + size_ = input_shape_[kIndex0] * grid_shape_[kIndex1] * grid_shape_[kIndex2] * grid_shape_[kIndex3]; + return KRET_OK; + } + + void ResetResource() noexcept { + input_shape_.clear(); + grid_shape_.clear(); + output_shape_.clear(); + input_stride_.clear(); + grid_stride_.clear(); + output_stride_.clear(); + size_ = 0; + interpolation_mode_ = GridSamplerInterpolationMode::BILINEAR; + padding_mode_ = GridSamplerPaddingMode::ZEROS; + align_corners_ = false; + is_null_input_ = false; + workspace_size_list_.clear(); + } + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kGridSamplerInputNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kGridSamplerOutputNum, kernel_name_); + kernel_func_(this, inputs, outputs, stream_ptr); + return true; + } + + std::vector GetOpSupport() override; + + private: + using KernelFunc = std::function &, + const std::vector &, void *)>; + KernelFunc kernel_func_{}; + static std::vector> func_list_; + std::vector input_shape_; + std::vector grid_shape_; + std::vector output_shape_; + std::vector input_stride_; + std::vector grid_stride_; + std::vector output_stride_; + size_t size_; + GridSamplerInterpolationMode interpolation_mode_; + GridSamplerPaddingMode padding_mode_; + bool align_corners_; + bool is_null_input_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GRID_SAMPLER_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/kl_div_loss_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/kl_div_loss_gpu_kernel.cc index b890efed8d3..141c7fd4b50 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/kl_div_loss_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/kl_div_loss_gpu_kernel.cc @@ -1,100 +1,100 @@ -/** - * Copyright 2020-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/nn/kl_div_loss_gpu_kernel.h" -#include -#include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/loss_with_reduction_impl.cuh" -#include "kernel/common_utils.h" -#include "ops/kl_div_loss.h" - -namespace mindspore { -namespace kernel { -constexpr size_t kKLDivLossInputsNum = 2; - -template -bool KLDivLossGpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - T *input_x = GetDeviceAddress(inputs, 0); - T *input_y = GetDeviceAddress(inputs, 1); - T *loss = GetDeviceAddress(outputs, 0); - T *tmp_loss = GetDeviceAddress(workspace, 0); - auto status = - KLDivLoss(input_size_, reduction_, input_x, input_y, loss, tmp_loss, reinterpret_cast(stream_ptr)); - CHECK_CUDA_STATUS(status, kernel_name_); - return true; -} - -bool KLDivLossGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - string reduction = GetValue(primitive_->GetAttr(ops::kReduction)); - reduction_ = kReductionModeMap[reduction]; - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr; - return false; - } - kernel_func_ = func_list_[index].second; - type_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(0).dtype); - return true; -} - -int KLDivLossGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kKLDivLossInputsNum, kernel_name_); - int ret = KernelMod::Resize(inputs, outputs); - if (ret != KRET_OK) { - return ret; - } - - auto input_shape = inputs[0]->GetShapeVector(); - is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name_, "logits"); - if (is_null_input_) { - return ret; - } - - workspace_size_list_.clear(); - input_size_ = 1; - input_size_ *= SizeOf(input_shape); - size_t workspace_size = type_size_; - if (reduction_ != ReductionMode::kNone) { - workspace_size *= input_size_; - } - workspace_size_list_.push_back(workspace_size); - return ret; -} - -std::vector> KLDivLossGpuKernelMod::func_list_ = { - {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - &KLDivLossGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - &KLDivLossGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), - &KLDivLossGpuKernelMod::LaunchKernel}, -}; - -std::vector KLDivLossGpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &item) { return item.first; }); - return support_list; -} - -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, KLDivLoss, KLDivLossGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/nn/kl_div_loss_gpu_kernel.h" +#include +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/loss_with_reduction_impl.cuh" +#include "kernel/common_utils.h" +#include "ops/kl_div_loss.h" + +namespace mindspore { +namespace kernel { +constexpr size_t kKLDivLossInputsNum = 2; + +template +bool KLDivLossGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + T *input_x = GetDeviceAddress(inputs, 0); + T *input_y = GetDeviceAddress(inputs, 1); + T *loss = GetDeviceAddress(outputs, 0); + T *tmp_loss = GetDeviceAddress(workspace, 0); + auto status = + KLDivLoss(input_size_, reduction_, input_x, input_y, loss, tmp_loss, reinterpret_cast(stream_ptr)); + CHECK_CUDA_STATUS(status, kernel_name_); + return true; +} + +bool KLDivLossGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + string reduction = GetValue(primitive_->GetAttr(ops::kReduction)); + reduction_ = kReductionModeMap[reduction]; + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr; + return false; + } + kernel_func_ = func_list_[index].second; + type_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(0).dtype); + return true; +} + +int KLDivLossGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kKLDivLossInputsNum, kernel_name_); + int ret = KernelMod::Resize(inputs, outputs); + if (ret != KRET_OK) { + return ret; + } + + auto input_shape = inputs[0]->GetShapeVector(); + is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name_, "logits"); + if (is_null_input_) { + return ret; + } + + workspace_size_list_.clear(); + input_size_ = 1; + input_size_ *= SizeOf(input_shape); + size_t workspace_size = type_size_; + if (reduction_ != ReductionMode::kNone) { + workspace_size *= input_size_; + } + workspace_size_list_.push_back(workspace_size); + return ret; +} + +std::vector> KLDivLossGpuKernelMod::func_list_ = { + {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + &KLDivLossGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + &KLDivLossGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + &KLDivLossGpuKernelMod::LaunchKernel}, +}; + +std::vector KLDivLossGpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &item) { return item.first; }); + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, KLDivLoss, KLDivLossGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/kl_div_loss_grad_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/kl_div_loss_grad_kernel.cc index 687ae518db7..9c563b872d7 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/kl_div_loss_grad_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/kl_div_loss_grad_kernel.cc @@ -1,103 +1,103 @@ -/** - * Copyright 2020-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/nn/kl_div_loss_grad_kernel.h" -#include -#include -#include "mindspore/core/ops/grad/kl_div_loss_grad.h" - -namespace mindspore { -namespace kernel { -bool KLDivLossGradGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - constexpr size_t input_num = 3; - constexpr size_t output_num = 1; - - CHECK_KERNEL_INPUTS_NUM(inputs.size(), input_num, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), output_num, kernel_name_); - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr; - return false; - } - kernel_func_ = func_list_[index].second; - auto input_data_type = inputs[kIndex0]->dtype_id(); - type_id_size_ = abstract::TypeIdSize(input_data_type); - return true; -} - -int KLDivLossGradGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { - return ret; - } - auto input_shape = inputs[kIndex1]->GetShapeVector(); - input_size_ = 1; - input_size_ *= SizeOf(input_shape); - string reduction = GetValue(primitive_->GetAttr("reduction")); - reduction_ = kReductionModeMap[reduction]; - - return KRET_OK; -} - -template -bool KLDivLossGradGpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &, - const std::vector &outputs, void *stream_ptr) { - T *dloss = GetDeviceAddress(inputs, kIndex0); - T *input_x = GetDeviceAddress(inputs, kIndex1); - T *input_y = GetDeviceAddress(inputs, kIndex2); - T *dx = GetDeviceAddress(outputs, kIndex0); - auto status = - KLDivLossGrad(input_size_, reduction_, input_x, input_y, dloss, dx, reinterpret_cast(stream_ptr)); - CHECK_CUDA_STATUS(status, kernel_name_); - return true; -} - -std::vector> - KLDivLossGradGpuKernelMod::func_list_ = { - {KernelAttr() - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeFloat64) - .AddOutputAttr(kNumberTypeFloat64), - &KLDivLossGradGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - &KLDivLossGradGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - &KLDivLossGradGpuKernelMod::LaunchKernel}, -}; - -std::vector KLDivLossGradGpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform( - func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - return support_list; -} - -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, KLDivLossGrad, KLDivLossGradGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/nn/kl_div_loss_grad_kernel.h" +#include +#include +#include "mindspore/core/ops/grad/kl_div_loss_grad.h" + +namespace mindspore { +namespace kernel { +bool KLDivLossGradGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + constexpr size_t input_num = 3; + constexpr size_t output_num = 1; + + CHECK_KERNEL_INPUTS_NUM(inputs.size(), input_num, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), output_num, kernel_name_); + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr; + return false; + } + kernel_func_ = func_list_[index].second; + auto input_data_type = inputs[kIndex0]->dtype_id(); + type_id_size_ = abstract::TypeIdSize(input_data_type); + return true; +} + +int KLDivLossGradGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { + return ret; + } + auto input_shape = inputs[kIndex1]->GetShapeVector(); + input_size_ = 1; + input_size_ *= SizeOf(input_shape); + string reduction = GetValue(primitive_->GetAttr("reduction")); + reduction_ = kReductionModeMap[reduction]; + + return KRET_OK; +} + +template +bool KLDivLossGradGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &, + const std::vector &outputs, void *stream_ptr) { + T *dloss = GetDeviceAddress(inputs, kIndex0); + T *input_x = GetDeviceAddress(inputs, kIndex1); + T *input_y = GetDeviceAddress(inputs, kIndex2); + T *dx = GetDeviceAddress(outputs, kIndex0); + auto status = + KLDivLossGrad(input_size_, reduction_, input_x, input_y, dloss, dx, reinterpret_cast(stream_ptr)); + CHECK_CUDA_STATUS(status, kernel_name_); + return true; +} + +std::vector> + KLDivLossGradGpuKernelMod::func_list_ = { + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat64), + &KLDivLossGradGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + &KLDivLossGradGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + &KLDivLossGradGpuKernelMod::LaunchKernel}, +}; + +std::vector KLDivLossGradGpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform( + func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, KLDivLossGrad, KLDivLossGradGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/pdist_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/pdist_gpu_kernel.cc index c71618b7580..9daca6e1f54 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/pdist_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/pdist_gpu_kernel.cc @@ -1,96 +1,96 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/nn/pdist_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -constexpr size_t kColindex = 1; -constexpr size_t kRowindex = 2; -bool PDistGpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { - if (inputs.empty() || outputs.empty()) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid."; - return false; - } - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(ERROR) << "For '" << kernel_name_ - << "', the kernel type should be in [float32, double], but got: " << kernel_attr << "."; - return false; - } - kernel_func_ = func_list_[index].second; - p_ = GetValue(primitive_->GetAttr(ops::kP)); - input_type_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).dtype); - return true; -} - -int PDistGpuKernelMod::Resize(const std::vector &inputs, const std::vector &outputs) { - for (const auto &input : inputs) { - // If any input shape contains -1, means input shape is dynamic, so just return do nothing. - auto input_shape = input->GetShapeVector(); - if (!IsValidShape(input_shape)) { - return KRET_UNKNOWN_SHAPE; - } - } - ResetResource(); - std::vector input_shape = inputs[kIndex0]->GetShapeVector(); - std::vector output_shape = outputs[kIndex0]->GetShapeVector(); - x_size_ = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); - y_size_ = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); - int64_t x_dim = input_shape.size(); - if (x_dim != kRowindex) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of 'x' must be 2-D, but got " << x_dim << "-D."; - return KRET_RESIZE_FAILED; - } - if (y_size_ == 0) { - is_null_input_ = true; - } - matrix_row_ = input_shape[input_shape.size() - kRowindex]; - matrix_col_ = input_shape[input_shape.size() - kColindex]; - size_t output_size = y_size_ * input_type_size_; - output_size_list_.push_back(output_size); - return KRET_OK; -} - -template -bool PDistGpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs) { - T *input = GetDeviceAddress(inputs, 0); - T *output = GetDeviceAddress(outputs, 0); - auto status = CalPDist(x_size_, y_size_, input, output, p_, matrix_row_, matrix_col_, device_id_, - reinterpret_cast(cuda_stream_)); - CHECK_CUDA_STATUS(status, kernel_name_); - return true; -} - -std::vector> PDistGpuKernelMod::func_list_ = { - {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - &PDistGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), - &PDistGpuKernelMod::LaunchKernel}}; - -std::vector PDistGpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - return support_list; -} - -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Pdist, PDistGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/nn/pdist_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +constexpr size_t kColindex = 1; +constexpr size_t kRowindex = 2; +bool PDistGpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { + if (inputs.empty() || outputs.empty()) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid."; + return false; + } + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ + << "', the kernel type should be in [float32, double], but got: " << kernel_attr << "."; + return false; + } + kernel_func_ = func_list_[index].second; + p_ = GetValue(primitive_->GetAttr(ops::kP)); + input_type_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).dtype); + return true; +} + +int PDistGpuKernelMod::Resize(const std::vector &inputs, const std::vector &outputs) { + for (const auto &input : inputs) { + // If any input shape contains -1, means input shape is dynamic, so just return do nothing. + auto input_shape = input->GetShapeVector(); + if (!IsValidShape(input_shape)) { + return KRET_UNKNOWN_SHAPE; + } + } + ResetResource(); + std::vector input_shape = inputs[kIndex0]->GetShapeVector(); + std::vector output_shape = outputs[kIndex0]->GetShapeVector(); + x_size_ = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); + y_size_ = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + int64_t x_dim = input_shape.size(); + if (x_dim != kRowindex) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of 'x' must be 2-D, but got " << x_dim << "-D."; + return KRET_RESIZE_FAILED; + } + if (y_size_ == 0) { + is_null_input_ = true; + } + matrix_row_ = input_shape[input_shape.size() - kRowindex]; + matrix_col_ = input_shape[input_shape.size() - kColindex]; + size_t output_size = y_size_ * input_type_size_; + output_size_list_.push_back(output_size); + return KRET_OK; +} + +template +bool PDistGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + T *input = GetDeviceAddress(inputs, 0); + T *output = GetDeviceAddress(outputs, 0); + auto status = CalPDist(x_size_, y_size_, input, output, p_, matrix_row_, matrix_col_, device_id_, + reinterpret_cast(cuda_stream_)); + CHECK_CUDA_STATUS(status, kernel_name_); + return true; +} + +std::vector> PDistGpuKernelMod::func_list_ = { + {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + &PDistGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + &PDistGpuKernelMod::LaunchKernel}}; + +std::vector PDistGpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Pdist, PDistGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/pdist_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/pdist_gpu_kernel.h index 13b1a66c645..3f4ba2e0760 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/pdist_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/pdist_gpu_kernel.h @@ -1,82 +1,82 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_PDIST_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_PDIST_GPU_KERNEL_H_ - -#include -#include -#include -#include -#include -#include -#include -#include "mindspore/core/ops/pdist.h" -#include "abstract/utils.h" -#include "plugin/factory/ms_factory.h" -#include "plugin/device/gpu/kernel/gpu_kernel.h" -#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/pdist_impl.cuh" - -namespace mindspore { -namespace kernel { -class PDistGpuKernelMod : public NativeGpuKernelMod { - public: - PDistGpuKernelMod() { ResetResource(); } - ~PDistGpuKernelMod() override = default; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *cuda_stream) override { - if (is_null_input_) { - return true; - } - cuda_stream_ = cuda_stream; - return kernel_func_(this, inputs, workspace, outputs); - } - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - protected: - void ResetResource() noexcept { - is_null_input_ = false; - workspace_size_list_.clear(); - output_size_list_.clear(); - } - std::vector GetOpSupport() override; - - private: - template - bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs); - using PDistFunc = - std::function &, - const std::vector &, const std::vector &)>; - float p_{0}; - size_t input_type_size_{1}; - size_t x_size_{0}; - size_t y_size_{0}; - int64_t matrix_row_{0}; - int64_t matrix_col_{0}; - PDistFunc kernel_func_{}; - bool is_null_input_{false}; - void *cuda_stream_{nullptr}; - static std::vector> func_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_PDIST_GPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_PDIST_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_PDIST_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "mindspore/core/ops/pdist.h" +#include "abstract/utils.h" +#include "plugin/factory/ms_factory.h" +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/pdist_impl.cuh" + +namespace mindspore { +namespace kernel { +class PDistGpuKernelMod : public NativeGpuKernelMod { + public: + PDistGpuKernelMod() { ResetResource(); } + ~PDistGpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *cuda_stream) override { + if (is_null_input_) { + return true; + } + cuda_stream_ = cuda_stream; + return kernel_func_(this, inputs, workspace, outputs); + } + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + protected: + void ResetResource() noexcept { + is_null_input_ = false; + workspace_size_list_.clear(); + output_size_list_.clear(); + } + std::vector GetOpSupport() override; + + private: + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs); + using PDistFunc = + std::function &, + const std::vector &, const std::vector &)>; + float p_{0}; + size_t input_type_size_{1}; + size_t x_size_{0}; + size_t y_size_{0}; + int64_t matrix_row_{0}; + int64_t matrix_col_{0}; + PDistFunc kernel_func_{}; + bool is_null_input_{false}; + void *cuda_stream_{nullptr}; + static std::vector> func_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_PDIST_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/pdist_grad_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/pdist_grad_gpu_kernel.cc index ba444eef50d..ab9ae1b95fb 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/pdist_grad_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/pdist_grad_gpu_kernel.cc @@ -1,131 +1,131 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/nn/pdist_grad_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -constexpr size_t kZeroindex = 0; -constexpr size_t kOneindex = 1; -constexpr size_t kTwoindex = 2; -bool PDistGradGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - if (inputs.empty() || outputs.empty()) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid."; - return false; - } - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(ERROR) << "For '" << kernel_name_ - << "', the kernel type should be in [float16, float32, double], but got: " << kernel_attr << "."; - return false; - } - kernel_func_ = func_list_[index].second; - p_ = GetValue(primitive_->GetAttr(ops::kP)); - input_type_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).dtype); - return true; -} - -int PDistGradGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - for (const auto &input : inputs) { - // If any input shape contains -1, means input shape is dynamic, so just return do nothing. - auto input_shape = input->GetShapeVector(); - if (!IsValidShape(input_shape)) { - return KRET_UNKNOWN_SHAPE; - } - } - - ResetResource(); - std::vector y_grad_shape = inputs[kIndex0]->GetShapeVector(); - std::vector x_shape = inputs[kIndex1]->GetShapeVector(); - std::vector y_shape = inputs[kIndex2]->GetShapeVector(); - std::vector output_shape = outputs[kIndex0]->GetShapeVector(); - int64_t y_grad_dim = y_grad_shape.size(); - int64_t x_dim = x_shape.size(); - int64_t y_dim = y_shape.size(); - if (y_grad_dim != 1) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of 'y_grad' must be 1-D," - << " but got " << y_grad_dim << "-D."; - return false; - } - if (x_dim != kTwoindex) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of 'x' must be 2-D," - << " but got " << x_dim << "-D."; - return false; - } - if (y_dim != 1) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of 'y' must be 1-D," - << " but got " << y_dim << "-D."; - return false; - } - y_grad_size_ = std::accumulate(y_grad_shape.begin(), y_grad_shape.end(), 1, std::multiplies()); - x_size_ = std::accumulate(x_shape.begin(), x_shape.end(), 1, std::multiplies()); - y_size_ = std::accumulate(y_shape.begin(), y_shape.end(), 1, std::multiplies()); - size_t x_grad_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); - matrix_row_ = x_shape[x_shape.size() - kTwoindex]; - matrix_col_ = x_shape[x_shape.size() - kOneindex]; - if (x_grad_size == 0) { - is_null_input_ = true; - } - - size_t output_size = x_grad_size * input_type_size_; - size_t work_size = ((matrix_row_ - 1) * x_size_) * input_type_size_; - output_size_list_.push_back(output_size); - workspace_size_list_.push_back(work_size); - return KRET_OK; -} - -template -bool PDistGradGpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs) { - T *y_grad = GetDeviceAddress(inputs, kZeroindex); - T *x = GetDeviceAddress(inputs, kOneindex); - T *y = GetDeviceAddress(inputs, kTwoindex); - T *output = GetDeviceAddress(outputs, kZeroindex); - T *buffer = GetDeviceAddress(workspace, kZeroindex); - auto status = CalPDistGrad(x_size_, y_size_, y_grad_size_, y_grad, x, y, matrix_row_, matrix_col_, p_, output, buffer, - device_id_, reinterpret_cast(cuda_stream_)); - CHECK_CUDA_STATUS(status, kernel_name_); - return true; -} - -std::vector> PDistGradGpuKernelMod::func_list_ = { - {KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - &PDistGradGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeFloat64) - .AddOutputAttr(kNumberTypeFloat64), - &PDistGradGpuKernelMod::LaunchKernel}}; - -std::vector PDistGradGpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - return support_list; -} - -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, PdistGrad, PDistGradGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/nn/pdist_grad_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +constexpr size_t kZeroindex = 0; +constexpr size_t kOneindex = 1; +constexpr size_t kTwoindex = 2; +bool PDistGradGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + if (inputs.empty() || outputs.empty()) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid."; + return false; + } + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ + << "', the kernel type should be in [float16, float32, double], but got: " << kernel_attr << "."; + return false; + } + kernel_func_ = func_list_[index].second; + p_ = GetValue(primitive_->GetAttr(ops::kP)); + input_type_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).dtype); + return true; +} + +int PDistGradGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + for (const auto &input : inputs) { + // If any input shape contains -1, means input shape is dynamic, so just return do nothing. + auto input_shape = input->GetShapeVector(); + if (!IsValidShape(input_shape)) { + return KRET_UNKNOWN_SHAPE; + } + } + + ResetResource(); + std::vector y_grad_shape = inputs[kIndex0]->GetShapeVector(); + std::vector x_shape = inputs[kIndex1]->GetShapeVector(); + std::vector y_shape = inputs[kIndex2]->GetShapeVector(); + std::vector output_shape = outputs[kIndex0]->GetShapeVector(); + int64_t y_grad_dim = y_grad_shape.size(); + int64_t x_dim = x_shape.size(); + int64_t y_dim = y_shape.size(); + if (y_grad_dim != 1) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of 'y_grad' must be 1-D," + << " but got " << y_grad_dim << "-D."; + return false; + } + if (x_dim != kTwoindex) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of 'x' must be 2-D," + << " but got " << x_dim << "-D."; + return false; + } + if (y_dim != 1) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of 'y' must be 1-D," + << " but got " << y_dim << "-D."; + return false; + } + y_grad_size_ = std::accumulate(y_grad_shape.begin(), y_grad_shape.end(), 1, std::multiplies()); + x_size_ = std::accumulate(x_shape.begin(), x_shape.end(), 1, std::multiplies()); + y_size_ = std::accumulate(y_shape.begin(), y_shape.end(), 1, std::multiplies()); + size_t x_grad_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + matrix_row_ = x_shape[x_shape.size() - kTwoindex]; + matrix_col_ = x_shape[x_shape.size() - kOneindex]; + if (x_grad_size == 0) { + is_null_input_ = true; + } + + size_t output_size = x_grad_size * input_type_size_; + size_t work_size = ((matrix_row_ - 1) * x_size_) * input_type_size_; + output_size_list_.push_back(output_size); + workspace_size_list_.push_back(work_size); + return KRET_OK; +} + +template +bool PDistGradGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + T *y_grad = GetDeviceAddress(inputs, kZeroindex); + T *x = GetDeviceAddress(inputs, kOneindex); + T *y = GetDeviceAddress(inputs, kTwoindex); + T *output = GetDeviceAddress(outputs, kZeroindex); + T *buffer = GetDeviceAddress(workspace, kZeroindex); + auto status = CalPDistGrad(x_size_, y_size_, y_grad_size_, y_grad, x, y, matrix_row_, matrix_col_, p_, output, buffer, + device_id_, reinterpret_cast(cuda_stream_)); + CHECK_CUDA_STATUS(status, kernel_name_); + return true; +} + +std::vector> PDistGradGpuKernelMod::func_list_ = { + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + &PDistGradGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat64), + &PDistGradGpuKernelMod::LaunchKernel}}; + +std::vector PDistGradGpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, PdistGrad, PDistGradGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/pdist_grad_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/pdist_grad_gpu_kernel.h index 8baef7680e3..073f10e3b30 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/pdist_grad_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/pdist_grad_gpu_kernel.h @@ -1,84 +1,84 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_PDIST_GRAD_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_PDIST_GRAD_GPU_KERNEL_H_ - -#include -#include -#include -#include -#include -#include -#include -#include "mindspore/core/ops/grad/pdist_grad.h" -#include "abstract/utils.h" -#include "plugin/factory/ms_factory.h" -#include "plugin/device/gpu/kernel/gpu_kernel.h" -#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/pdist_grad_impl.cuh" - -namespace mindspore { -namespace kernel { -class PDistGradGpuKernelMod : public NativeGpuKernelMod { - public: - PDistGradGpuKernelMod() { ResetResource(); } - ~PDistGradGpuKernelMod() override = default; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *cuda_stream) override { - if (is_null_input_) { - return true; - } - cuda_stream_ = cuda_stream; - return kernel_func_(this, inputs, workspace, outputs); - } - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - protected: - void ResetResource() noexcept { - is_null_input_ = false; - workspace_size_list_.clear(); - output_size_list_.clear(); - } - std::vector GetOpSupport() override; - - private: - template - bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs); - using PDistGradFunc = - std::function &, - const std::vector &, const std::vector &)>; - float p_{0}; - size_t input_type_size_{1}; - size_t y_grad_size_{0}; - size_t x_size_{0}; - size_t y_size_{0}; - int64_t matrix_row_{0}; - int64_t matrix_col_{0}; - PDistGradFunc kernel_func_{}; - bool is_null_input_{false}; - void *cuda_stream_{nullptr}; - static std::vector> func_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_PDIST_GRAD_GPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_PDIST_GRAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_PDIST_GRAD_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "mindspore/core/ops/grad/pdist_grad.h" +#include "abstract/utils.h" +#include "plugin/factory/ms_factory.h" +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/pdist_grad_impl.cuh" + +namespace mindspore { +namespace kernel { +class PDistGradGpuKernelMod : public NativeGpuKernelMod { + public: + PDistGradGpuKernelMod() { ResetResource(); } + ~PDistGradGpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *cuda_stream) override { + if (is_null_input_) { + return true; + } + cuda_stream_ = cuda_stream; + return kernel_func_(this, inputs, workspace, outputs); + } + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + protected: + void ResetResource() noexcept { + is_null_input_ = false; + workspace_size_list_.clear(); + output_size_list_.clear(); + } + std::vector GetOpSupport() override; + + private: + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs); + using PDistGradFunc = + std::function &, + const std::vector &, const std::vector &)>; + float p_{0}; + size_t input_type_size_{1}; + size_t y_grad_size_{0}; + size_t x_size_{0}; + size_t y_size_{0}; + int64_t matrix_row_{0}; + int64_t matrix_col_{0}; + PDistGradFunc kernel_func_{}; + bool is_null_input_{false}; + void *cuda_stream_{nullptr}; + static std::vector> func_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_PDIST_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/sgd_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/sgd_gpu_kernel.cc index e72ba24536f..814c6bb1580 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/sgd_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/sgd_gpu_kernel.cc @@ -1,32 +1,32 @@ -/** - * Copyright 2020-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/nn/sgd_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(SGD, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - SGDGpuKernelMod, float) -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/nn/sgd_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(SGD, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + SGDGpuKernelMod, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/softmax_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/softmax_gpu_kernel.cc index edb824099b5..4c654303840 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/softmax_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/softmax_gpu_kernel.cc @@ -1,147 +1,147 @@ -/** - * Copyright 2019-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/nn/softmax_gpu_kernel.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl.cuh" - -namespace mindspore { -namespace kernel { -namespace { -int64_t MaybeWrapDim(int64_t dim, int64_t dim_post_expr, const std::string &kernel_name) { - int64_t min = -dim_post_expr; - int64_t max = dim_post_expr - 1; - if (dim < min || dim > max) { - MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the value of 'axis' must be in range [-" << dim_post_expr << ", " - << dim_post_expr << "), but got " << dim; - } - if (dim < 0) { - dim += dim_post_expr; - } - return dim; -} -} // namespace -bool SoftmaxGpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr; - } - kernel_func_ = func_list_[index].second; - is_log_softmax_ = kernel_name_ == "LogSoftmax"; - - return true; -} - -size_t SoftmaxGpuKernelMod::GetAccAxis(KernelTensor *axis_kernel_tensor) const noexcept { - std::vector axis; - if (is_log_softmax_) { - axis.push_back(axis_kernel_tensor->GetValueWithCheck()); - } else { - axis = axis_kernel_tensor->GetValueWithCheck>(); - // axis size must be 1 - if (axis.size() != 1) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the length of 'axis' cannot be equal to 0, but got " - << axis.size(); - } - } - // check axis value - auto axis_acc = static_cast(MaybeWrapDim(axis[0], shape_size_, kernel_name_)); - return axis_acc; -} - -int SoftmaxGpuKernelMod::Resize(const std::vector &inputs, const std::vector &outputs) { - if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { - return ret; - } - // input, workspace and output will be assign in InitSizeLists. - const auto &input_shape = inputs[kIndex0]->GetShapeVector(); - auto input_element_num = - std::accumulate(input_shape.begin(), input_shape.end(), size_t(1), std::multiplies()); - is_null_input_ = (input_element_num == 0); - if (is_null_input_) { - return KRET_OK; - } - - ResetResource(); - (void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(shape_), - [](const int64_t &value) { return LongToInt(value); }); - shape_size_ = input_shape.size(); - axis_acc_ = GetAccAxis(inputs[kIndex1]); - if (input_element_num > 0) { - // calculate outer and inner size - for (size_t i = 0; i < axis_acc_; ++i) { - outer_size_ *= shape_[i]; - } - for (size_t i = axis_acc_ + 1; i < shape_.size(); ++i) { - inner_size_ *= shape_[i]; - } - } - - return KRET_OK; -} - -template -bool SoftmaxGpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) noexcept { - if (is_null_input_) { - return true; - } - - T *input_addr = GetDeviceAddress(inputs, 0); - MS_ERROR_IF_NULL_W_RET_VAL(input_addr, false); - T *output_addr = GetDeviceAddress(outputs, 0); - MS_ERROR_IF_NULL_W_RET_VAL(output_addr, false); - - // kernel function - if (is_log_softmax_) { - Softmax(input_addr, output_addr, shape_[axis_acc_], outer_size_, inner_size_, device_id_, - reinterpret_cast(stream_ptr)); - } else { - Softmax(input_addr, output_addr, shape_[axis_acc_], outer_size_, inner_size_, device_id_, - reinterpret_cast(stream_ptr)); - } - - return true; -} - -#define SOFTMAX_GPU_REG(MT, T) \ - KernelAttr().AddInputAttr(MT).AddInputAttr(kObjectTypeTuple, kNumberTypeInt64).AddOutputAttr(MT), \ - &SoftmaxGpuKernelMod::LaunchKernel - -#define LOG_SOFTMAX_GPU_REG(MT, T) \ - KernelAttr().AddInputAttr(MT).AddInputAttr(kObjectTypeNumber, kNumberTypeInt64).AddOutputAttr(MT), \ - &SoftmaxGpuKernelMod::LaunchKernel - -std::vector> SoftmaxGpuKernelMod::func_list_ = { - {SOFTMAX_GPU_REG(kNumberTypeFloat64, double)}, {SOFTMAX_GPU_REG(kNumberTypeFloat32, float)}, - {SOFTMAX_GPU_REG(kNumberTypeFloat16, half)}, {LOG_SOFTMAX_GPU_REG(kNumberTypeFloat64, double)}, - {LOG_SOFTMAX_GPU_REG(kNumberTypeFloat32, float)}, {LOG_SOFTMAX_GPU_REG(kNumberTypeFloat16, half)}}; - -std::vector SoftmaxGpuKernelMod::GetOpSupport() { - static std::vector support_list; - if (support_list.empty()) { - (void)std::transform( - func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - } - return support_list; -} - -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Softmax, SoftmaxGpuKernelMod); -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, LogSoftmax, SoftmaxGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2019-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/nn/softmax_gpu_kernel.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl.cuh" + +namespace mindspore { +namespace kernel { +namespace { +int64_t MaybeWrapDim(int64_t dim, int64_t dim_post_expr, const std::string &kernel_name) { + int64_t min = -dim_post_expr; + int64_t max = dim_post_expr - 1; + if (dim < min || dim > max) { + MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the value of 'axis' must be in range [-" << dim_post_expr << ", " + << dim_post_expr << "), but got " << dim; + } + if (dim < 0) { + dim += dim_post_expr; + } + return dim; +} +} // namespace +bool SoftmaxGpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr; + } + kernel_func_ = func_list_[index].second; + is_log_softmax_ = kernel_name_ == "LogSoftmax"; + + return true; +} + +size_t SoftmaxGpuKernelMod::GetAccAxis(KernelTensor *axis_kernel_tensor) const noexcept { + std::vector axis; + if (is_log_softmax_) { + axis.push_back(axis_kernel_tensor->GetValueWithCheck()); + } else { + axis = axis_kernel_tensor->GetValueWithCheck>(); + // axis size must be 1 + if (axis.size() != 1) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the length of 'axis' cannot be equal to 0, but got " + << axis.size(); + } + } + // check axis value + auto axis_acc = static_cast(MaybeWrapDim(axis[0], shape_size_, kernel_name_)); + return axis_acc; +} + +int SoftmaxGpuKernelMod::Resize(const std::vector &inputs, const std::vector &outputs) { + if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { + return ret; + } + // input, workspace and output will be assign in InitSizeLists. + const auto &input_shape = inputs[kIndex0]->GetShapeVector(); + auto input_element_num = + std::accumulate(input_shape.begin(), input_shape.end(), size_t(1), std::multiplies()); + is_null_input_ = (input_element_num == 0); + if (is_null_input_) { + return KRET_OK; + } + + ResetResource(); + (void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(shape_), + [](const int64_t &value) { return LongToInt(value); }); + shape_size_ = input_shape.size(); + axis_acc_ = GetAccAxis(inputs[kIndex1]); + if (input_element_num > 0) { + // calculate outer and inner size + for (size_t i = 0; i < axis_acc_; ++i) { + outer_size_ *= shape_[i]; + } + for (size_t i = axis_acc_ + 1; i < shape_.size(); ++i) { + inner_size_ *= shape_[i]; + } + } + + return KRET_OK; +} + +template +bool SoftmaxGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) noexcept { + if (is_null_input_) { + return true; + } + + T *input_addr = GetDeviceAddress(inputs, 0); + MS_ERROR_IF_NULL_W_RET_VAL(input_addr, false); + T *output_addr = GetDeviceAddress(outputs, 0); + MS_ERROR_IF_NULL_W_RET_VAL(output_addr, false); + + // kernel function + if (is_log_softmax_) { + Softmax(input_addr, output_addr, shape_[axis_acc_], outer_size_, inner_size_, device_id_, + reinterpret_cast(stream_ptr)); + } else { + Softmax(input_addr, output_addr, shape_[axis_acc_], outer_size_, inner_size_, device_id_, + reinterpret_cast(stream_ptr)); + } + + return true; +} + +#define SOFTMAX_GPU_REG(MT, T) \ + KernelAttr().AddInputAttr(MT).AddInputAttr(kObjectTypeTuple, kNumberTypeInt64).AddOutputAttr(MT), \ + &SoftmaxGpuKernelMod::LaunchKernel + +#define LOG_SOFTMAX_GPU_REG(MT, T) \ + KernelAttr().AddInputAttr(MT).AddInputAttr(kObjectTypeNumber, kNumberTypeInt64).AddOutputAttr(MT), \ + &SoftmaxGpuKernelMod::LaunchKernel + +std::vector> SoftmaxGpuKernelMod::func_list_ = { + {SOFTMAX_GPU_REG(kNumberTypeFloat64, double)}, {SOFTMAX_GPU_REG(kNumberTypeFloat32, float)}, + {SOFTMAX_GPU_REG(kNumberTypeFloat16, half)}, {LOG_SOFTMAX_GPU_REG(kNumberTypeFloat64, double)}, + {LOG_SOFTMAX_GPU_REG(kNumberTypeFloat32, float)}, {LOG_SOFTMAX_GPU_REG(kNumberTypeFloat16, half)}}; + +std::vector SoftmaxGpuKernelMod::GetOpSupport() { + static std::vector support_list; + if (support_list.empty()) { + (void)std::transform( + func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + } + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Softmax, SoftmaxGpuKernelMod); +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, LogSoftmax, SoftmaxGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/upsample_trilinear_3d_grad_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/upsample_trilinear_3d_grad_gpu_kernel.h index 879dde7e8fc..392660e4c42 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/upsample_trilinear_3d_grad_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/upsample_trilinear_3d_grad_gpu_kernel.h @@ -1,81 +1,81 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_UPSAMPLE_TRILINEAR_3D_GRAD_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_UPSAMPLE_TRILINEAR_3D_GRAD_GPU_KERNEL_H_ - -#include -#include -#include -#include -#include -#include -#include -#include "plugin/device/gpu/kernel/gpu_kernel.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore { -namespace kernel { -class UpsampleTrilinear3DGradGpuKernelMod : public NativeGpuKernelMod { - public: - UpsampleTrilinear3DGradGpuKernelMod() = default; - ~UpsampleTrilinear3DGradGpuKernelMod() override = default; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *cuda_stream) override { - cuda_stream_ = cuda_stream; - return kernel_func_(this, inputs, workspace, outputs); - } - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - std::vector GetOpSupport() override; - - std::vector GetLaunchIgnoredInputAddressIdx() const override { return {kIndex2}; } - - private: - template - bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs); - - using UpsampleTrilinear3DGradFunc = - std::function &, - const std::vector &, const std::vector &)>; - UpsampleTrilinear3DGradFunc kernel_func_; - static std::vector> func_list_; - - void *cuda_stream_{nullptr}; - bool align_corners_{}; - // array dims -> reset these - int64_t n_{}; - int64_t c_{}; - int64_t grad_d_{}; - int64_t grad_h_{}; - int64_t grad_w_{}; - int64_t dinput_d_{}; - int64_t dinput_h_{}; - int64_t dinput_w_{}; - - // only need these - std::vector scales_{0., 0., 0.}; - std::vector none_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_UPSAMPLE_TRILINEAR_3D_GRAD_GPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_UPSAMPLE_TRILINEAR_3D_GRAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_UPSAMPLE_TRILINEAR_3D_GRAD_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class UpsampleTrilinear3DGradGpuKernelMod : public NativeGpuKernelMod { + public: + UpsampleTrilinear3DGradGpuKernelMod() = default; + ~UpsampleTrilinear3DGradGpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *cuda_stream) override { + cuda_stream_ = cuda_stream; + return kernel_func_(this, inputs, workspace, outputs); + } + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + std::vector GetOpSupport() override; + + std::vector GetLaunchIgnoredInputAddressIdx() const override { return {kIndex2}; } + + private: + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs); + + using UpsampleTrilinear3DGradFunc = + std::function &, + const std::vector &, const std::vector &)>; + UpsampleTrilinear3DGradFunc kernel_func_; + static std::vector> func_list_; + + void *cuda_stream_{nullptr}; + bool align_corners_{}; + // array dims -> reset these + int64_t n_{}; + int64_t c_{}; + int64_t grad_d_{}; + int64_t grad_h_{}; + int64_t grad_w_{}; + int64_t dinput_d_{}; + int64_t dinput_h_{}; + int64_t dinput_w_{}; + + // only need these + std::vector scales_{0., 0., 0.}; + std::vector none_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_UPSAMPLE_TRILINEAR_3D_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/other/bartlett_window_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/other/bartlett_window_gpu_kernel.cc index 4c976002f85..68151fcaf7f 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/other/bartlett_window_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/other/bartlett_window_gpu_kernel.cc @@ -1,104 +1,104 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/other/bartlett_window_gpu_kernel.h" -#include "ops/op_name.h" - -namespace mindspore { -namespace kernel { -bool BartlettWindowGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - if (inputs.empty() || outputs.empty()) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid."; - return false; - } - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', the kernel type should be in [int32, int64], " - << "but got: " << kernel_attr << "."; - return false; - } - kernel_func_ = func_list_[index].second; - unit_input_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).dtype); - unit_output_size_ = abstract::TypeIdSize(kernel_attr.GetOutputAttr(kIndex0).dtype); - periodic_ = GetValue(primitive_->GetAttr(ops::kPeriodic)); - return true; -} - -int BartlettWindowGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - for (const auto &input : inputs) { - // If any input shape contains -1, means input shape is dynamic, so just return do nothing. - auto input_shape = input->GetShapeVector(); - if (!IsValidShape(input_shape)) { - return KRET_UNKNOWN_SHAPE; - } - } - ResetResource(); - std::vector input_shape = std::vector(inputs.at(kIndex0)->GetDeviceShapeVector().begin(), - inputs.at(kIndex0)->GetDeviceShapeVector().end()); - int64_t input_dims = input_shape.size(); - if (input_dims != 0) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of 'x' must be 0-D, but got " << input_dims << "-D."; - return false; - } - std::vector output_shape = std::vector(outputs.at(kIndex0)->GetDeviceShapeVector().begin(), - outputs.at(kIndex0)->GetDeviceShapeVector().end()); - output_elements_ = std::accumulate(output_shape.begin(), output_shape.end(), int64_t(1), std::multiplies()); - if (output_elements_ == 0) { - is_null_input_ = true; - } - size_t output_size = output_elements_ * unit_output_size_; - output_size_list_.push_back(output_size); - return KRET_OK; -} - -template -bool BartlettWindowGpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs) { - T *input = GetDeviceAddress(inputs, 0); - S *output = GetDeviceAddress(outputs, 0); - auto status = CalBartlettWindow(output_elements_, input, periodic_, output, device_id_, - reinterpret_cast(cuda_stream_)); - CHECK_CUDA_STATUS(status, kernel_name_); - return true; -} - -std::vector> BartlettWindowGpuKernelMod::func_list_ = { - {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), - &BartlettWindowGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), - &BartlettWindowGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), - &BartlettWindowGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), - &BartlettWindowGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64), - &BartlettWindowGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64), - &BartlettWindowGpuKernelMod::LaunchKernel}}; - -std::vector BartlettWindowGpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - return support_list; -} -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, BartlettWindow, BartlettWindowGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/other/bartlett_window_gpu_kernel.h" +#include "ops/op_name.h" + +namespace mindspore { +namespace kernel { +bool BartlettWindowGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + if (inputs.empty() || outputs.empty()) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid."; + return false; + } + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', the kernel type should be in [int32, int64], " + << "but got: " << kernel_attr << "."; + return false; + } + kernel_func_ = func_list_[index].second; + unit_input_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).dtype); + unit_output_size_ = abstract::TypeIdSize(kernel_attr.GetOutputAttr(kIndex0).dtype); + periodic_ = GetValue(primitive_->GetAttr(ops::kPeriodic)); + return true; +} + +int BartlettWindowGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + for (const auto &input : inputs) { + // If any input shape contains -1, means input shape is dynamic, so just return do nothing. + auto input_shape = input->GetShapeVector(); + if (!IsValidShape(input_shape)) { + return KRET_UNKNOWN_SHAPE; + } + } + ResetResource(); + std::vector input_shape = std::vector(inputs.at(kIndex0)->GetDeviceShapeVector().begin(), + inputs.at(kIndex0)->GetDeviceShapeVector().end()); + int64_t input_dims = input_shape.size(); + if (input_dims != 0) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of 'x' must be 0-D, but got " << input_dims << "-D."; + return false; + } + std::vector output_shape = std::vector(outputs.at(kIndex0)->GetDeviceShapeVector().begin(), + outputs.at(kIndex0)->GetDeviceShapeVector().end()); + output_elements_ = std::accumulate(output_shape.begin(), output_shape.end(), int64_t(1), std::multiplies()); + if (output_elements_ == 0) { + is_null_input_ = true; + } + size_t output_size = output_elements_ * unit_output_size_; + output_size_list_.push_back(output_size); + return KRET_OK; +} + +template +bool BartlettWindowGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + T *input = GetDeviceAddress(inputs, 0); + S *output = GetDeviceAddress(outputs, 0); + auto status = CalBartlettWindow(output_elements_, input, periodic_, output, device_id_, + reinterpret_cast(cuda_stream_)); + CHECK_CUDA_STATUS(status, kernel_name_); + return true; +} + +std::vector> BartlettWindowGpuKernelMod::func_list_ = { + {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), + &BartlettWindowGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), + &BartlettWindowGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + &BartlettWindowGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), + &BartlettWindowGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64), + &BartlettWindowGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64), + &BartlettWindowGpuKernelMod::LaunchKernel}}; + +std::vector BartlettWindowGpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, BartlettWindow, BartlettWindowGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/other/bartlett_window_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/other/bartlett_window_gpu_kernel.h index 6cd9a66ab56..e408f53430b 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/other/bartlett_window_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/other/bartlett_window_gpu_kernel.h @@ -1,83 +1,83 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_BARTLETT_WINDOW_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_BARTLETT_WINDOW_GPU_KERNEL_H_ -#include -#include -#include -#include -#include -#include -#include -#include "mindspore/core/ops/bartlett_window.h" -#include "abstract/utils.h" -#include "plugin/factory/ms_factory.h" -#include "plugin/device/gpu/kernel/gpu_kernel.h" -#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/bartlett_window_impl.cuh" - -namespace mindspore { -namespace kernel { -class BartlettWindowGpuKernelMod : public NativeGpuKernelMod { - public: - BartlettWindowGpuKernelMod() { ResetResource(); } - ~BartlettWindowGpuKernelMod() override = default; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *cuda_stream) override { - if (is_null_input_) { - return true; - } - cuda_stream_ = cuda_stream; - return kernel_func_(this, inputs, workspace, outputs); - } - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - protected: - void ResetResource() noexcept { - output_elements_ = 0; - is_null_input_ = false; - output_size_list_.clear(); - } - - std::vector GetOpSupport() override; - - private: - template - bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs); - using BlWFunc = - std::function &, - const std::vector &, const std::vector &)>; - - private: - bool periodic_{true}; - size_t unit_input_size_{1}; - size_t unit_output_size_{1}; - size_t output_elements_; - BlWFunc kernel_func_{}; - bool is_null_input_{false}; - void *cuda_stream_{nullptr}; - static std::vector> func_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_BARTLETT_GPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_BARTLETT_WINDOW_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_BARTLETT_WINDOW_GPU_KERNEL_H_ +#include +#include +#include +#include +#include +#include +#include +#include "mindspore/core/ops/bartlett_window.h" +#include "abstract/utils.h" +#include "plugin/factory/ms_factory.h" +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/bartlett_window_impl.cuh" + +namespace mindspore { +namespace kernel { +class BartlettWindowGpuKernelMod : public NativeGpuKernelMod { + public: + BartlettWindowGpuKernelMod() { ResetResource(); } + ~BartlettWindowGpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *cuda_stream) override { + if (is_null_input_) { + return true; + } + cuda_stream_ = cuda_stream; + return kernel_func_(this, inputs, workspace, outputs); + } + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + protected: + void ResetResource() noexcept { + output_elements_ = 0; + is_null_input_ = false; + output_size_list_.clear(); + } + + std::vector GetOpSupport() override; + + private: + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs); + using BlWFunc = + std::function &, + const std::vector &, const std::vector &)>; + + private: + bool periodic_{true}; + size_t unit_input_size_{1}; + size_t unit_output_size_{1}; + size_t output_elements_; + BlWFunc kernel_func_{}; + bool is_null_input_{false}; + void *cuda_stream_{nullptr}; + static std::vector> func_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_BARTLETT_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/other/blackman_window_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/other/blackman_window_gpu_kernel.cc index 76cda30e4cc..55c29715338 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/other/blackman_window_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/other/blackman_window_gpu_kernel.cc @@ -1,102 +1,102 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/other/blackman_window_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -bool BlackmanWindowGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - if (inputs.empty() || outputs.empty()) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid."; - return false; - } - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', the kernel type should be in [int32, int64], " - << "but got: " << kernel_attr << "."; - return false; - } - kernel_func_ = func_list_[index].second; - unit_input_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).dtype); - unit_output_size_ = abstract::TypeIdSize(kernel_attr.GetOutputAttr(kIndex0).dtype); - periodic_ = GetValue(primitive_->GetAttr(ops::kPeriodic)); - return true; -} - -int BlackmanWindowGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - for (const auto &input : inputs) { - // If any input shape contains -1, means input shape is dynamic, so just return do nothing. - auto input_shape = input->GetShapeVector(); - if (!IsValidShape(input_shape)) { - return KRET_UNKNOWN_SHAPE; - } - } - ResetResource(); - std::vector input_shape = std::vector(inputs.at(kIndex0)->GetDeviceShapeVector().begin(), - inputs.at(kIndex0)->GetDeviceShapeVector().end()); - int64_t input_dims = input_shape.size(); - if (input_dims != 0) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of 'x' must be 0-D, but got " << input_dims << "-D."; - return false; - } - std::vector output_shape = std::vector(outputs.at(kIndex0)->GetDeviceShapeVector().begin(), - outputs.at(kIndex0)->GetDeviceShapeVector().end()); - output_elements_ = std::accumulate(output_shape.begin(), output_shape.end(), int64_t(1), std::multiplies()); - if (output_elements_ == 0) { - is_null_input_ = true; - } - size_t output_size = output_elements_ * unit_output_size_; - output_size_list_.push_back(output_size); - return KRET_OK; -} - -template -bool BlackmanWindowGpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs) { - T *input = GetDeviceAddress(inputs, 0); - S *output = GetDeviceAddress(outputs, 0); - CalBlackmanWindow(output_elements_, input, periodic_, output, device_id_, - reinterpret_cast(cuda_stream_)); - return true; -} - -std::vector> BlackmanWindowGpuKernelMod::func_list_ = { - {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), - &BlackmanWindowGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), - &BlackmanWindowGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), - &BlackmanWindowGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), - &BlackmanWindowGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64), - &BlackmanWindowGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64), - &BlackmanWindowGpuKernelMod::LaunchKernel}}; - -std::vector BlackmanWindowGpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - return support_list; -} -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, BlackmanWindow, BlackmanWindowGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/other/blackman_window_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +bool BlackmanWindowGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + if (inputs.empty() || outputs.empty()) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid."; + return false; + } + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', the kernel type should be in [int32, int64], " + << "but got: " << kernel_attr << "."; + return false; + } + kernel_func_ = func_list_[index].second; + unit_input_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).dtype); + unit_output_size_ = abstract::TypeIdSize(kernel_attr.GetOutputAttr(kIndex0).dtype); + periodic_ = GetValue(primitive_->GetAttr(ops::kPeriodic)); + return true; +} + +int BlackmanWindowGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + for (const auto &input : inputs) { + // If any input shape contains -1, means input shape is dynamic, so just return do nothing. + auto input_shape = input->GetShapeVector(); + if (!IsValidShape(input_shape)) { + return KRET_UNKNOWN_SHAPE; + } + } + ResetResource(); + std::vector input_shape = std::vector(inputs.at(kIndex0)->GetDeviceShapeVector().begin(), + inputs.at(kIndex0)->GetDeviceShapeVector().end()); + int64_t input_dims = input_shape.size(); + if (input_dims != 0) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of 'x' must be 0-D, but got " << input_dims << "-D."; + return false; + } + std::vector output_shape = std::vector(outputs.at(kIndex0)->GetDeviceShapeVector().begin(), + outputs.at(kIndex0)->GetDeviceShapeVector().end()); + output_elements_ = std::accumulate(output_shape.begin(), output_shape.end(), int64_t(1), std::multiplies()); + if (output_elements_ == 0) { + is_null_input_ = true; + } + size_t output_size = output_elements_ * unit_output_size_; + output_size_list_.push_back(output_size); + return KRET_OK; +} + +template +bool BlackmanWindowGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + T *input = GetDeviceAddress(inputs, 0); + S *output = GetDeviceAddress(outputs, 0); + CalBlackmanWindow(output_elements_, input, periodic_, output, device_id_, + reinterpret_cast(cuda_stream_)); + return true; +} + +std::vector> BlackmanWindowGpuKernelMod::func_list_ = { + {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), + &BlackmanWindowGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), + &BlackmanWindowGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + &BlackmanWindowGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), + &BlackmanWindowGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64), + &BlackmanWindowGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64), + &BlackmanWindowGpuKernelMod::LaunchKernel}}; + +std::vector BlackmanWindowGpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, BlackmanWindow, BlackmanWindowGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/other/blackman_window_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/other/blackman_window_gpu_kernel.h index 18e62b6e49a..1578f79a7d2 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/other/blackman_window_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/other/blackman_window_gpu_kernel.h @@ -1,83 +1,83 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_BLACKMAN_WINDOW_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_BLACKMAN_WINDOW_GPU_KERNEL_H_ -#include -#include -#include -#include -#include -#include -#include -#include "mindspore/core/ops/blackman_window.h" -#include "abstract/utils.h" -#include "plugin/factory/ms_factory.h" -#include "plugin/device/gpu/kernel/gpu_kernel.h" -#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/blackman_window_impl.cuh" - -namespace mindspore { -namespace kernel { -class BlackmanWindowGpuKernelMod : public NativeGpuKernelMod { - public: - BlackmanWindowGpuKernelMod() { ResetResource(); } - ~BlackmanWindowGpuKernelMod() override = default; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *cuda_stream) override { - if (is_null_input_) { - return true; - } - cuda_stream_ = cuda_stream; - return kernel_func_(this, inputs, workspace, outputs); - } - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - protected: - void ResetResource() noexcept { - output_elements_ = 0; - is_null_input_ = false; - output_size_list_.clear(); - } - - std::vector GetOpSupport() override; - - private: - template - bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs); - using BmWFunc = - std::function &, - const std::vector &, const std::vector &)>; - - private: - bool periodic_{true}; - size_t unit_input_size_{1}; - size_t unit_output_size_{1}; - size_t output_elements_; - BmWFunc kernel_func_{}; - bool is_null_input_{false}; - void *cuda_stream_{nullptr}; - static std::vector> func_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_BLACKMAN_GPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_BLACKMAN_WINDOW_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_BLACKMAN_WINDOW_GPU_KERNEL_H_ +#include +#include +#include +#include +#include +#include +#include +#include "mindspore/core/ops/blackman_window.h" +#include "abstract/utils.h" +#include "plugin/factory/ms_factory.h" +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/blackman_window_impl.cuh" + +namespace mindspore { +namespace kernel { +class BlackmanWindowGpuKernelMod : public NativeGpuKernelMod { + public: + BlackmanWindowGpuKernelMod() { ResetResource(); } + ~BlackmanWindowGpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *cuda_stream) override { + if (is_null_input_) { + return true; + } + cuda_stream_ = cuda_stream; + return kernel_func_(this, inputs, workspace, outputs); + } + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + protected: + void ResetResource() noexcept { + output_elements_ = 0; + is_null_input_ = false; + output_size_list_.clear(); + } + + std::vector GetOpSupport() override; + + private: + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs); + using BmWFunc = + std::function &, + const std::vector &, const std::vector &)>; + + private: + bool periodic_{true}; + size_t unit_input_size_{1}; + size_t unit_output_size_{1}; + size_t output_elements_; + BmWFunc kernel_func_{}; + bool is_null_input_{false}; + void *cuda_stream_{nullptr}; + static std::vector> func_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_BLACKMAN_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/other/boundingbox_decode_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/other/boundingbox_decode_gpu_kernel.cc index 294b14feaba..70bd0c29619 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/other/boundingbox_decode_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/other/boundingbox_decode_gpu_kernel.cc @@ -1,140 +1,140 @@ -/** - * Copyright 2020-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/other/boundingbox_decode_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -bool BoundingBoxDecodeGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - constexpr size_t input_num = 2; - CHECK_KERNEL_INPUTS_NUM(inputs.size(), input_num, kernel_name_); - - const size_t coordinate_size = 4; - auto means = primitive_->GetAttr("means"); - MS_EXCEPTION_IF_NULL(means); - if (means->isa()) { - means_ = GetValue>(means); - } else if (means->isa()) { - float mean = GetValue(means); - for (size_t i = 0; i < coordinate_size; i++) { - (void)means_.emplace_back(mean); - } - } else { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ - << "', the input 'means' must be a tuple or a list, and dtype must be float, but got is not."; - } - - auto stds = primitive_->GetAttr("stds"); - MS_EXCEPTION_IF_NULL(stds); - if (stds->isa()) { - stds_ = GetValue>(stds); - } else if (stds->isa()) { - float std = GetValue(stds); - for (size_t i = 0; i < coordinate_size; i++) { - (void)stds_.emplace_back(std); - } - } else { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ - << "', the input 'stds' must be a tuple or a list, and dtype must be float, but got is not."; - } - - if (means_.size() < coordinate_size || stds_.size() < coordinate_size) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ - << "', the length of input 'means' and 'stds' must be at least 4, " - "but got the length of 'means': " - << means_.size() << ", and the length of 'stds': " << stds_.size(); - } - - auto max_shape = primitive_->GetAttr("max_shape"); - std::vector max_shape_me = GetValue>(max_shape); - (void)std::transform(max_shape_me.begin(), max_shape_me.end(), std::back_inserter(max_shape_), - [](const int64_t &value) { return LongToInt(value); }); - auto wh_ratio_clip = primitive_->GetAttr("wh_ratio_clip"); - wh_ratio_clip_ = GetValue(wh_ratio_clip); - - if (max_shape_.size() < kMinMaxShapeSize) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ - << "', the length of 'max_shape' must be at least 2, but got: " << max_shape_.size(); - } - - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "' does not support this kernel type: " << kernel_attr; - return false; - } - kernel_func_ = func_list_[index].second; - return true; -} - -int BoundingBoxDecodeGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { - return ret; - } - return KRET_OK; -} - -template -bool BoundingBoxDecodeGpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - T *rois_addr = GetDeviceAddress(inputs, 0); - T *deltas_addr = GetDeviceAddress(inputs, 1); - T *bboxes_addr = GetDeviceAddress(outputs, 0); - - if (inputs[0]->size() != inputs[1]->size()) { - MS_LOG(ERROR) << "For '" << kernel_name_ - << "', rois box size must equal with deltas box size: " << inputs[1]->size() << ", but got " - << inputs[0]->size(); - return false; - } - - const size_t coordinate = 4; - const size_t block_size = inputs[0]->size() / sizeof(T); - if ((block_size % coordinate) != 0) { - MS_LOG(ERROR) << "For '" << kernel_name_ << ", the size of the box should be a multiple of 4."; - return false; - } - auto status = - BoundingBoxDecode(block_size / coordinate, rois_addr, deltas_addr, bboxes_addr, means_[0], means_[1], means_[2], - means_[3], stds_[0], stds_[1], stds_[2], stds_[3], max_shape_[0], max_shape_[1], wh_ratio_clip_, - device_id_, reinterpret_cast(stream_ptr)); - CHECK_CUDA_STATUS(status, kernel_name_); - return true; -} - -std::vector> - BoundingBoxDecodeGpuKernelMod::func_list_ = { - {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - &BoundingBoxDecodeGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - &BoundingBoxDecodeGpuKernelMod::LaunchKernel}}; - -std::vector BoundingBoxDecodeGpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform( - func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { - return pair.first; - }); - return support_list; -} - -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, BoundingBoxDecode, BoundingBoxDecodeGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/other/boundingbox_decode_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +bool BoundingBoxDecodeGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + constexpr size_t input_num = 2; + CHECK_KERNEL_INPUTS_NUM(inputs.size(), input_num, kernel_name_); + + const size_t coordinate_size = 4; + auto means = primitive_->GetAttr("means"); + MS_EXCEPTION_IF_NULL(means); + if (means->isa()) { + means_ = GetValue>(means); + } else if (means->isa()) { + float mean = GetValue(means); + for (size_t i = 0; i < coordinate_size; i++) { + (void)means_.emplace_back(mean); + } + } else { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ + << "', the input 'means' must be a tuple or a list, and dtype must be float, but got is not."; + } + + auto stds = primitive_->GetAttr("stds"); + MS_EXCEPTION_IF_NULL(stds); + if (stds->isa()) { + stds_ = GetValue>(stds); + } else if (stds->isa()) { + float std = GetValue(stds); + for (size_t i = 0; i < coordinate_size; i++) { + (void)stds_.emplace_back(std); + } + } else { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ + << "', the input 'stds' must be a tuple or a list, and dtype must be float, but got is not."; + } + + if (means_.size() < coordinate_size || stds_.size() < coordinate_size) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ + << "', the length of input 'means' and 'stds' must be at least 4, " + "but got the length of 'means': " + << means_.size() << ", and the length of 'stds': " << stds_.size(); + } + + auto max_shape = primitive_->GetAttr("max_shape"); + std::vector max_shape_me = GetValue>(max_shape); + (void)std::transform(max_shape_me.begin(), max_shape_me.end(), std::back_inserter(max_shape_), + [](const int64_t &value) { return LongToInt(value); }); + auto wh_ratio_clip = primitive_->GetAttr("wh_ratio_clip"); + wh_ratio_clip_ = GetValue(wh_ratio_clip); + + if (max_shape_.size() < kMinMaxShapeSize) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ + << "', the length of 'max_shape' must be at least 2, but got: " << max_shape_.size(); + } + + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "' does not support this kernel type: " << kernel_attr; + return false; + } + kernel_func_ = func_list_[index].second; + return true; +} + +int BoundingBoxDecodeGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { + return ret; + } + return KRET_OK; +} + +template +bool BoundingBoxDecodeGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + T *rois_addr = GetDeviceAddress(inputs, 0); + T *deltas_addr = GetDeviceAddress(inputs, 1); + T *bboxes_addr = GetDeviceAddress(outputs, 0); + + if (inputs[0]->size() != inputs[1]->size()) { + MS_LOG(ERROR) << "For '" << kernel_name_ + << "', rois box size must equal with deltas box size: " << inputs[1]->size() << ", but got " + << inputs[0]->size(); + return false; + } + + const size_t coordinate = 4; + const size_t block_size = inputs[0]->size() / sizeof(T); + if ((block_size % coordinate) != 0) { + MS_LOG(ERROR) << "For '" << kernel_name_ << ", the size of the box should be a multiple of 4."; + return false; + } + auto status = + BoundingBoxDecode(block_size / coordinate, rois_addr, deltas_addr, bboxes_addr, means_[0], means_[1], means_[2], + means_[3], stds_[0], stds_[1], stds_[2], stds_[3], max_shape_[0], max_shape_[1], wh_ratio_clip_, + device_id_, reinterpret_cast(stream_ptr)); + CHECK_CUDA_STATUS(status, kernel_name_); + return true; +} + +std::vector> + BoundingBoxDecodeGpuKernelMod::func_list_ = { + {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + &BoundingBoxDecodeGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + &BoundingBoxDecodeGpuKernelMod::LaunchKernel}}; + +std::vector BoundingBoxDecodeGpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform( + func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { + return pair.first; + }); + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, BoundingBoxDecode, BoundingBoxDecodeGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/other/boundingbox_encode_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/other/boundingbox_encode_gpu_kernel.cc index 6dd17af983a..255fc5171bf 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/other/boundingbox_encode_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/other/boundingbox_encode_gpu_kernel.cc @@ -1,128 +1,128 @@ -/** - * Copyright 2020-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/other/boundingbox_encode_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -bool BoundingBoxEncodeGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - constexpr size_t input_num = 2; - CHECK_KERNEL_INPUTS_NUM(inputs.size(), input_num, kernel_name_); - - const size_t coordinate_size = 4; - auto means = primitive_->GetAttr("means"); - MS_EXCEPTION_IF_NULL(means); - if (means->isa()) { - means_ = GetValue>(means); - } else if (means->isa()) { - float mean = GetValue(means); - for (size_t i = 0; i < coordinate_size; i++) { - (void)means_.emplace_back(mean); - } - } else { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ - << "', the input 'means' must be a tuple or a list, and dtype must be float, but got is not."; - } - - auto stds = primitive_->GetAttr("stds"); - MS_EXCEPTION_IF_NULL(stds); - if (stds->isa()) { - stds_ = GetValue>(stds); - } else if (stds->isa()) { - float std = GetValue(stds); - for (size_t i = 0; i < coordinate_size; i++) { - (void)stds_.emplace_back(std); - } - } else { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ - << "', the input 'stds' must be a tuple or a list, and dtype must be float, but got is not."; - } - - if (means_.size() < coordinate_size || stds_.size() < coordinate_size) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ - << "', the length of input 'means' and 'stds' must be at least 4, " - "but got the length of 'means': " - << means_.size() << ", and the length of 'stds': " << stds_.size(); - } - - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "' does not support this kernel type: " << kernel_attr; - return false; - } - kernel_func_ = func_list_[index].second; - return true; -} - -int BoundingBoxEncodeGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { - return ret; - } - return KRET_OK; -} - -template -bool BoundingBoxEncodeGpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - T *anchor_addr = GetDeviceAddress(inputs, 0); - T *groundtruth_addr = GetDeviceAddress(inputs, 1); - T *deltas_addr = GetDeviceAddress(outputs, 0); - - if (inputs[0]->size() != inputs[1]->size()) { - MS_LOG(ERROR) << "For '" << kernel_name_ - << "', anchor box size must equal with groundtruth box size: " << inputs[1]->size() << ", but got " - << inputs[0]->size(); - return false; - } - - const size_t coordinate = 4; - const size_t block_size = inputs[0]->size() / sizeof(T); - if ((block_size % coordinate) != 0) { - MS_LOG(ERROR) << "For '" << kernel_name_ << ", the size of the box should be a multiple of 4."; - return false; - } - - auto status = BoundingBoxEncode(block_size / coordinate, anchor_addr, groundtruth_addr, deltas_addr, means_[0], - means_[1], means_[2], means_[3], stds_[0], stds_[1], stds_[2], stds_[3], - reinterpret_cast(stream_ptr)); - CHECK_CUDA_STATUS(status, kernel_name_); - return true; -} - -std::vector> - BoundingBoxEncodeGpuKernelMod::func_list_ = { - {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - &BoundingBoxEncodeGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - &BoundingBoxEncodeGpuKernelMod::LaunchKernel}}; - -std::vector BoundingBoxEncodeGpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform( - func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { - return pair.first; - }); - return support_list; -} - -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, BoundingBoxEncode, BoundingBoxEncodeGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/other/boundingbox_encode_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +bool BoundingBoxEncodeGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + constexpr size_t input_num = 2; + CHECK_KERNEL_INPUTS_NUM(inputs.size(), input_num, kernel_name_); + + const size_t coordinate_size = 4; + auto means = primitive_->GetAttr("means"); + MS_EXCEPTION_IF_NULL(means); + if (means->isa()) { + means_ = GetValue>(means); + } else if (means->isa()) { + float mean = GetValue(means); + for (size_t i = 0; i < coordinate_size; i++) { + (void)means_.emplace_back(mean); + } + } else { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ + << "', the input 'means' must be a tuple or a list, and dtype must be float, but got is not."; + } + + auto stds = primitive_->GetAttr("stds"); + MS_EXCEPTION_IF_NULL(stds); + if (stds->isa()) { + stds_ = GetValue>(stds); + } else if (stds->isa()) { + float std = GetValue(stds); + for (size_t i = 0; i < coordinate_size; i++) { + (void)stds_.emplace_back(std); + } + } else { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ + << "', the input 'stds' must be a tuple or a list, and dtype must be float, but got is not."; + } + + if (means_.size() < coordinate_size || stds_.size() < coordinate_size) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ + << "', the length of input 'means' and 'stds' must be at least 4, " + "but got the length of 'means': " + << means_.size() << ", and the length of 'stds': " << stds_.size(); + } + + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "' does not support this kernel type: " << kernel_attr; + return false; + } + kernel_func_ = func_list_[index].second; + return true; +} + +int BoundingBoxEncodeGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { + return ret; + } + return KRET_OK; +} + +template +bool BoundingBoxEncodeGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + T *anchor_addr = GetDeviceAddress(inputs, 0); + T *groundtruth_addr = GetDeviceAddress(inputs, 1); + T *deltas_addr = GetDeviceAddress(outputs, 0); + + if (inputs[0]->size() != inputs[1]->size()) { + MS_LOG(ERROR) << "For '" << kernel_name_ + << "', anchor box size must equal with groundtruth box size: " << inputs[1]->size() << ", but got " + << inputs[0]->size(); + return false; + } + + const size_t coordinate = 4; + const size_t block_size = inputs[0]->size() / sizeof(T); + if ((block_size % coordinate) != 0) { + MS_LOG(ERROR) << "For '" << kernel_name_ << ", the size of the box should be a multiple of 4."; + return false; + } + + auto status = BoundingBoxEncode(block_size / coordinate, anchor_addr, groundtruth_addr, deltas_addr, means_[0], + means_[1], means_[2], means_[3], stds_[0], stds_[1], stds_[2], stds_[3], + reinterpret_cast(stream_ptr)); + CHECK_CUDA_STATUS(status, kernel_name_); + return true; +} + +std::vector> + BoundingBoxEncodeGpuKernelMod::func_list_ = { + {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + &BoundingBoxEncodeGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + &BoundingBoxEncodeGpuKernelMod::LaunchKernel}}; + +std::vector BoundingBoxEncodeGpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform( + func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { + return pair.first; + }); + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, BoundingBoxEncode, BoundingBoxEncodeGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/other/hamming_window_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/other/hamming_window_gpu_kernel.cc index cea7b7dcacd..da827a5480c 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/other/hamming_window_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/other/hamming_window_gpu_kernel.cc @@ -1,145 +1,145 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/other/hamming_window_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -bool HammingWindowGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - if (inputs.empty() || outputs.empty()) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid."; - return false; - } - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', the kernel type should be in [int32, int64], " - << "but got: " << kernel_attr << "."; - return false; - } - kernel_func_ = func_list_[index].second; - unit_input_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).dtype); - unit_output_size_ = abstract::TypeIdSize(kernel_attr.GetOutputAttr(kIndex0).dtype); - periodic_ = GetValue(primitive_->GetAttr("periodic")); - alpha_ = GetValue(primitive_->GetAttr("alpha")); - beta_ = GetValue(primitive_->GetAttr("beta")); - return true; -} - -int HammingWindowGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - for (const auto &input : inputs) { - MS_ERROR_IF_NULL_W_RET_VAL(input, KRET_RESIZE_FAILED); - auto input_shape = input->GetShapeVector(); - if (!IsValidShape(input_shape)) { - return KRET_UNKNOWN_SHAPE; - } - } - ResetResource(); - auto output = outputs.at(kIndex0); - MS_ERROR_IF_NULL_W_RET_VAL(output, KRET_RESIZE_FAILED); - std::vector output_shape = - std::vector(output->GetDeviceShapeVector().begin(), output->GetDeviceShapeVector().end()); - output_elements_ = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); - if (output_elements_ == 0) { - is_null_input_ = true; - } - size_t output_size = output_elements_ * unit_output_size_; - output_size_list_.push_back(output_size); - return KRET_OK; -} - -template -bool HammingWindowGpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs) { - T *input = GetDeviceAddress(inputs, 0); - S *output = GetDeviceAddress(outputs, 0); - T N = 0; - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( - cudaMemcpyAsync(&N, &input[0], sizeof(T), cudaMemcpyDeviceToHost, reinterpret_cast(cuda_stream_)), - "For 'HammingWindow', copy max_index failed"); - if (cudaStreamQuery(reinterpret_cast(cuda_stream_)) != cudaSuccess) { - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(reinterpret_cast(cuda_stream_)), - "For 'HammingWindow', cudaStreamSyncFailed"); - } - auto status = HammingWindow(output_elements_, N, alpha_, beta_, periodic_, output, device_id_, - reinterpret_cast(cuda_stream_)); - CHECK_CUDA_STATUS(status, kernel_name_); - return true; -} - -std::vector> HammingWindowGpuKernelMod::func_list_ = { - {KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat16), - &HammingWindowGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat16), - &HammingWindowGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), - &HammingWindowGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), - &HammingWindowGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat16), - &HammingWindowGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat16), - &HammingWindowGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat16), - &HammingWindowGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat16), - &HammingWindowGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat32), - &HammingWindowGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat32), - &HammingWindowGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), - &HammingWindowGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), - &HammingWindowGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat32), - &HammingWindowGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat32), - &HammingWindowGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat32), - &HammingWindowGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat32), - &HammingWindowGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat64), - &HammingWindowGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat64), - &HammingWindowGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64), - &HammingWindowGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64), - &HammingWindowGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat64), - &HammingWindowGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat64), - &HammingWindowGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat64), - &HammingWindowGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat64), - &HammingWindowGpuKernelMod::LaunchKernel}}; - -std::vector HammingWindowGpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - return support_list; -} - -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, HammingWindow, HammingWindowGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/other/hamming_window_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +bool HammingWindowGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + if (inputs.empty() || outputs.empty()) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid."; + return false; + } + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', the kernel type should be in [int32, int64], " + << "but got: " << kernel_attr << "."; + return false; + } + kernel_func_ = func_list_[index].second; + unit_input_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).dtype); + unit_output_size_ = abstract::TypeIdSize(kernel_attr.GetOutputAttr(kIndex0).dtype); + periodic_ = GetValue(primitive_->GetAttr("periodic")); + alpha_ = GetValue(primitive_->GetAttr("alpha")); + beta_ = GetValue(primitive_->GetAttr("beta")); + return true; +} + +int HammingWindowGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + for (const auto &input : inputs) { + MS_ERROR_IF_NULL_W_RET_VAL(input, KRET_RESIZE_FAILED); + auto input_shape = input->GetShapeVector(); + if (!IsValidShape(input_shape)) { + return KRET_UNKNOWN_SHAPE; + } + } + ResetResource(); + auto output = outputs.at(kIndex0); + MS_ERROR_IF_NULL_W_RET_VAL(output, KRET_RESIZE_FAILED); + std::vector output_shape = + std::vector(output->GetDeviceShapeVector().begin(), output->GetDeviceShapeVector().end()); + output_elements_ = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + if (output_elements_ == 0) { + is_null_input_ = true; + } + size_t output_size = output_elements_ * unit_output_size_; + output_size_list_.push_back(output_size); + return KRET_OK; +} + +template +bool HammingWindowGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + T *input = GetDeviceAddress(inputs, 0); + S *output = GetDeviceAddress(outputs, 0); + T N = 0; + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(&N, &input[0], sizeof(T), cudaMemcpyDeviceToHost, reinterpret_cast(cuda_stream_)), + "For 'HammingWindow', copy max_index failed"); + if (cudaStreamQuery(reinterpret_cast(cuda_stream_)) != cudaSuccess) { + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(reinterpret_cast(cuda_stream_)), + "For 'HammingWindow', cudaStreamSyncFailed"); + } + auto status = HammingWindow(output_elements_, N, alpha_, beta_, periodic_, output, device_id_, + reinterpret_cast(cuda_stream_)); + CHECK_CUDA_STATUS(status, kernel_name_); + return true; +} + +std::vector> HammingWindowGpuKernelMod::func_list_ = { + {KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat16), + &HammingWindowGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat16), + &HammingWindowGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), + &HammingWindowGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), + &HammingWindowGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat16), + &HammingWindowGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat16), + &HammingWindowGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat16), + &HammingWindowGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat16), + &HammingWindowGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat32), + &HammingWindowGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat32), + &HammingWindowGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + &HammingWindowGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), + &HammingWindowGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat32), + &HammingWindowGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat32), + &HammingWindowGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat32), + &HammingWindowGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat32), + &HammingWindowGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat64), + &HammingWindowGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat64), + &HammingWindowGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64), + &HammingWindowGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64), + &HammingWindowGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat64), + &HammingWindowGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat64), + &HammingWindowGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat64), + &HammingWindowGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat64), + &HammingWindowGpuKernelMod::LaunchKernel}}; + +std::vector HammingWindowGpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, HammingWindow, HammingWindowGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/other/hamming_window_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/other/hamming_window_gpu_kernel.h index d925761bb9e..314b7941f75 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/other/hamming_window_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/other/hamming_window_gpu_kernel.h @@ -1,85 +1,85 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_Hamming_WINDOW_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_Hamming_WINDOW_GPU_KERNEL_H_ -#include -#include -#include -#include -#include -#include -#include -#include "mindspore/core/ops/hamming_window.h" -#include "abstract/utils.h" -#include "plugin/factory/ms_factory.h" -#include "plugin/device/gpu/kernel/gpu_kernel.h" -#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/hamming_window_impl.cuh" - -namespace mindspore { -namespace kernel { -class HammingWindowGpuKernelMod : public NativeGpuKernelMod { - public: - HammingWindowGpuKernelMod() { ResetResource(); } - ~HammingWindowGpuKernelMod() override = default; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *cuda_stream) override { - if (is_null_input_) { - return true; - } - cuda_stream_ = cuda_stream; - return kernel_func_(this, inputs, workspace, outputs); - } - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - protected: - void ResetResource() noexcept { - output_elements_ = 0; - is_null_input_ = false; - output_size_list_.clear(); - } - - std::vector GetOpSupport() override; - - private: - template - bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs); - using Hamming_Func = - std::function &, - const std::vector &, const std::vector &)>; - - private: - bool periodic_{true}; - float alpha_{0.54}; - float beta_{0.46}; - size_t unit_input_size_{1}; - size_t unit_output_size_{1}; - size_t output_elements_; - Hamming_Func kernel_func_{}; - bool is_null_input_{false}; - void *cuda_stream_{nullptr}; - static std::vector> func_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_Hamming_GPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_Hamming_WINDOW_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_Hamming_WINDOW_GPU_KERNEL_H_ +#include +#include +#include +#include +#include +#include +#include +#include "mindspore/core/ops/hamming_window.h" +#include "abstract/utils.h" +#include "plugin/factory/ms_factory.h" +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/hamming_window_impl.cuh" + +namespace mindspore { +namespace kernel { +class HammingWindowGpuKernelMod : public NativeGpuKernelMod { + public: + HammingWindowGpuKernelMod() { ResetResource(); } + ~HammingWindowGpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *cuda_stream) override { + if (is_null_input_) { + return true; + } + cuda_stream_ = cuda_stream; + return kernel_func_(this, inputs, workspace, outputs); + } + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + protected: + void ResetResource() noexcept { + output_elements_ = 0; + is_null_input_ = false; + output_size_list_.clear(); + } + + std::vector GetOpSupport() override; + + private: + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs); + using Hamming_Func = + std::function &, + const std::vector &, const std::vector &)>; + + private: + bool periodic_{true}; + float alpha_{0.54}; + float beta_{0.46}; + size_t unit_input_size_{1}; + size_t unit_output_size_{1}; + size_t output_elements_; + Hamming_Func kernel_func_{}; + bool is_null_input_{false}; + void *cuda_stream_{nullptr}; + static std::vector> func_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_Hamming_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/other/iou_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/other/iou_gpu_kernel.cc index 7f77f0d9307..d03467fa32a 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/other/iou_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/other/iou_gpu_kernel.cc @@ -1,108 +1,108 @@ -/** - * Copyright 2020-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -#include "plugin/device/gpu/kernel/other/iou_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -namespace { -constexpr size_t kIOUInputsNum = 2; -constexpr size_t kIOUOutputsNum = 1; -constexpr size_t kBoxCoordinateLen = 4; -constexpr auto kIou = "iou"; -constexpr auto kIof = "iof"; -}; // namespace - -bool IOUGpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kIOUInputsNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kIOUOutputsNum, kernel_name_); - - auto mode_value_ptr = primitive_->GetAttr(kAttrMode); - MS_EXCEPTION_IF_NULL(mode_value_ptr); - auto mode = GetValue(mode_value_ptr); - if (mode == kIou) { - mode_ = IOU_MODE; - } else if (mode == kIof) { - mode_ = IOF_MODE; - } else { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', mode only support 'iou' or 'iof'."; - } - - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr; - return false; - } - kernel_func_ = func_list_[index].second; - return true; -} - -int IOUGpuKernelMod::Resize(const std::vector &inputs, const std::vector &outputs) { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kIOUInputsNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kIOUOutputsNum, kernel_name_); - if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { - return ret; - } - - size_t type_size = GetTypeByte(TypeIdToType(inputs[ANCHOR_BOXES]->dtype_id())); - const size_t anchor_boxes_size_ = inputs[ANCHOR_BOXES]->size() / type_size; - const size_t gt_boxes_size_ = inputs[GT_BOXES]->size() / type_size; - if ((anchor_boxes_size_ % kBoxCoordinateLen) != 0 || (gt_boxes_size_ % kBoxCoordinateLen) != 0) { - MS_LOG(ERROR) << "For '" << kernel_name_ << ", the size of the box should be a multiple of 4."; - return false; - } - anchor_boxes_len_ = anchor_boxes_size_ / kBoxCoordinateLen; - gt_boxes_len_ = gt_boxes_size_ / kBoxCoordinateLen; - return KRET_OK; -} - -template -bool IOUGpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &outputs, void *stream_ptr) { - auto *anchor_boxes_addr = GetDeviceAddress(inputs, ANCHOR_BOXES); - auto *gt_boxes_addr = GetDeviceAddress(inputs, GT_BOXES); - auto *iou_addr = GetDeviceAddress(outputs, IOU_VALUE); - - auto status = IOU(anchor_boxes_len_ * gt_boxes_len_, anchor_boxes_addr, gt_boxes_addr, iou_addr, mode_, - anchor_boxes_len_, reinterpret_cast(stream_ptr)); - CHECK_CUDA_STATUS(status, kernel_name_); - return true; -} - -std::vector> IOUGpuKernelMod::func_list_ = { - {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - &IOUGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - &IOUGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), - &IOUGpuKernelMod::LaunchKernel}, -}; - -std::vector IOUGpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - return support_list; -} - -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, IOU, IOUGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "plugin/device/gpu/kernel/other/iou_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kIOUInputsNum = 2; +constexpr size_t kIOUOutputsNum = 1; +constexpr size_t kBoxCoordinateLen = 4; +constexpr auto kIou = "iou"; +constexpr auto kIof = "iof"; +}; // namespace + +bool IOUGpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kIOUInputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kIOUOutputsNum, kernel_name_); + + auto mode_value_ptr = primitive_->GetAttr(kAttrMode); + MS_EXCEPTION_IF_NULL(mode_value_ptr); + auto mode = GetValue(mode_value_ptr); + if (mode == kIou) { + mode_ = IOU_MODE; + } else if (mode == kIof) { + mode_ = IOF_MODE; + } else { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', mode only support 'iou' or 'iof'."; + } + + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr; + return false; + } + kernel_func_ = func_list_[index].second; + return true; +} + +int IOUGpuKernelMod::Resize(const std::vector &inputs, const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kIOUInputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kIOUOutputsNum, kernel_name_); + if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { + return ret; + } + + size_t type_size = GetTypeByte(TypeIdToType(inputs[ANCHOR_BOXES]->dtype_id())); + const size_t anchor_boxes_size_ = inputs[ANCHOR_BOXES]->size() / type_size; + const size_t gt_boxes_size_ = inputs[GT_BOXES]->size() / type_size; + if ((anchor_boxes_size_ % kBoxCoordinateLen) != 0 || (gt_boxes_size_ % kBoxCoordinateLen) != 0) { + MS_LOG(ERROR) << "For '" << kernel_name_ << ", the size of the box should be a multiple of 4."; + return false; + } + anchor_boxes_len_ = anchor_boxes_size_ / kBoxCoordinateLen; + gt_boxes_len_ = gt_boxes_size_ / kBoxCoordinateLen; + return KRET_OK; +} + +template +bool IOUGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &outputs, void *stream_ptr) { + auto *anchor_boxes_addr = GetDeviceAddress(inputs, ANCHOR_BOXES); + auto *gt_boxes_addr = GetDeviceAddress(inputs, GT_BOXES); + auto *iou_addr = GetDeviceAddress(outputs, IOU_VALUE); + + auto status = IOU(anchor_boxes_len_ * gt_boxes_len_, anchor_boxes_addr, gt_boxes_addr, iou_addr, mode_, + anchor_boxes_len_, reinterpret_cast(stream_ptr)); + CHECK_CUDA_STATUS(status, kernel_name_); + return true; +} + +std::vector> IOUGpuKernelMod::func_list_ = { + {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + &IOUGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + &IOUGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + &IOUGpuKernelMod::LaunchKernel}, +}; + +std::vector IOUGpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, IOU, IOUGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/pyboost/customize/non_zero.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/pyboost/customize/non_zero.cc old mode 100755 new mode 100644 diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/pyboost/customize/non_zero.h b/mindspore/ccsrc/plugin/device/gpu/kernel/pyboost/customize/non_zero.h old mode 100755 new mode 100644 diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/quant/fake_learned_scale_quant_perchannel_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/quant/fake_learned_scale_quant_perchannel_gpu_kernel.cc index db14b2c1109..11f9a3bd8d2 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/quant/fake_learned_scale_quant_perchannel_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/quant/fake_learned_scale_quant_perchannel_gpu_kernel.cc @@ -1,109 +1,109 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/quant/fake_learned_scale_quant_perchannel_gpu_kernel.h" -#include -#include -#include -#include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_learned_scale_quant_perchannel_impl.cuh" -#include "plugin/device/gpu/kernel/quant/quant_op_const.h" - -namespace mindspore { -namespace kernel { -FakeLearnedScaleQuantPerChannelGpuKernelMod::FakeLearnedScaleQuantPerChannelGpuKernelMod() - : input_size_(0), - quant_num_(1), - global_step_(0), - quant_delay_(0), - training_(false), - neg_trunc_(false), - num_channels_(0) {} - -bool FakeLearnedScaleQuantPerChannelGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - quant_delay_ = static_cast(GetValue(primitive_->GetAttr("quant_delay"))); - training_ = GetValue(primitive_->GetAttr("training")); - neg_trunc_ = GetValue(primitive_->GetAttr("neg_trunc")); - return true; -} - -int FakeLearnedScaleQuantPerChannelGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - output_size_list_.clear(); - workspace_size_list_.clear(); - // init size - auto input_shape = inputs[kIndex0]->GetShapeVector(); - num_channels_ = LongToInt(input_shape[0]); - for (size_t i = 0; i < input_shape.size(); ++i) { - quant_num_ *= LongToInt(input_shape[i]); - } - input_size_ = sizeof(float) * quant_num_; - output_size_list_.push_back(input_size_); // y - workspace_size_list_.push_back(input_size_); // input_div_alpha - workspace_size_list_.push_back(input_size_); // input_quant - return KRET_OK; -} - -bool FakeLearnedScaleQuantPerChannelGpuKernelMod::Launch(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - float *input = GetDeviceAddress(inputs, kIndex0); - float *input_alpha = GetDeviceAddress(inputs, kIndex1); - float *input_quant_max = GetDeviceAddress(inputs, kIndex2); - float *output = GetDeviceAddress(outputs, kIndex0); - float *input_div_alpha = GetDeviceAddress(workspace, kIndex0); - float *input_quant = GetDeviceAddress(workspace, kIndex1); - - MS_EXCEPTION_IF_NULL(input); - MS_EXCEPTION_IF_NULL(input_alpha); - MS_EXCEPTION_IF_NULL(input_quant_max); - MS_EXCEPTION_IF_NULL(output); - MS_EXCEPTION_IF_NULL(input_div_alpha); - MS_EXCEPTION_IF_NULL(input_quant); - - if (training_) { - // control flow for quant_delay - if (global_step_ >= quant_delay_) { - // real launch - auto status = CalLSQNudgePerChannel(input, quant_num_, input_alpha, input_quant_max, input_div_alpha, input_quant, - neg_trunc_, num_channels_, reinterpret_cast(stream_ptr)); - CHECK_CUDA_STATUS(status, kernel_name_); - status = CalFakeLearnedScaleQuantPerChannel(output, quant_num_, input_alpha, input_quant, num_channels_, - reinterpret_cast(stream_ptr)); - CHECK_CUDA_STATUS(status, kernel_name_); - } else { - CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice, - reinterpret_cast(stream_ptr)), - "Copy gpu memory failed"); - } - global_step_++; - } else { - // real launch - auto status = CalLSQNudgePerChannel(input, quant_num_, input_alpha, input_quant_max, input_div_alpha, input_quant, - neg_trunc_, num_channels_, reinterpret_cast(stream_ptr)); - CHECK_CUDA_STATUS(status, kernel_name_); - status = CalFakeLearnedScaleQuantPerChannel(output, quant_num_, input_alpha, input_quant, num_channels_, - reinterpret_cast(stream_ptr)); - CHECK_CUDA_STATUS(status, kernel_name_); - } - - return true; -} - -MS_REG_GPU_KERNEL(FakeLearnedScaleQuantPerChannel, FakeLearnedScaleQuantPerChannelGpuKernelMod) -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/quant/fake_learned_scale_quant_perchannel_gpu_kernel.h" +#include +#include +#include +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_learned_scale_quant_perchannel_impl.cuh" +#include "plugin/device/gpu/kernel/quant/quant_op_const.h" + +namespace mindspore { +namespace kernel { +FakeLearnedScaleQuantPerChannelGpuKernelMod::FakeLearnedScaleQuantPerChannelGpuKernelMod() + : input_size_(0), + quant_num_(1), + global_step_(0), + quant_delay_(0), + training_(false), + neg_trunc_(false), + num_channels_(0) {} + +bool FakeLearnedScaleQuantPerChannelGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + quant_delay_ = static_cast(GetValue(primitive_->GetAttr("quant_delay"))); + training_ = GetValue(primitive_->GetAttr("training")); + neg_trunc_ = GetValue(primitive_->GetAttr("neg_trunc")); + return true; +} + +int FakeLearnedScaleQuantPerChannelGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + output_size_list_.clear(); + workspace_size_list_.clear(); + // init size + auto input_shape = inputs[kIndex0]->GetShapeVector(); + num_channels_ = LongToInt(input_shape[0]); + for (size_t i = 0; i < input_shape.size(); ++i) { + quant_num_ *= LongToInt(input_shape[i]); + } + input_size_ = sizeof(float) * quant_num_; + output_size_list_.push_back(input_size_); // y + workspace_size_list_.push_back(input_size_); // input_div_alpha + workspace_size_list_.push_back(input_size_); // input_quant + return KRET_OK; +} + +bool FakeLearnedScaleQuantPerChannelGpuKernelMod::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + float *input = GetDeviceAddress(inputs, kIndex0); + float *input_alpha = GetDeviceAddress(inputs, kIndex1); + float *input_quant_max = GetDeviceAddress(inputs, kIndex2); + float *output = GetDeviceAddress(outputs, kIndex0); + float *input_div_alpha = GetDeviceAddress(workspace, kIndex0); + float *input_quant = GetDeviceAddress(workspace, kIndex1); + + MS_EXCEPTION_IF_NULL(input); + MS_EXCEPTION_IF_NULL(input_alpha); + MS_EXCEPTION_IF_NULL(input_quant_max); + MS_EXCEPTION_IF_NULL(output); + MS_EXCEPTION_IF_NULL(input_div_alpha); + MS_EXCEPTION_IF_NULL(input_quant); + + if (training_) { + // control flow for quant_delay + if (global_step_ >= quant_delay_) { + // real launch + auto status = CalLSQNudgePerChannel(input, quant_num_, input_alpha, input_quant_max, input_div_alpha, input_quant, + neg_trunc_, num_channels_, reinterpret_cast(stream_ptr)); + CHECK_CUDA_STATUS(status, kernel_name_); + status = CalFakeLearnedScaleQuantPerChannel(output, quant_num_, input_alpha, input_quant, num_channels_, + reinterpret_cast(stream_ptr)); + CHECK_CUDA_STATUS(status, kernel_name_); + } else { + CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice, + reinterpret_cast(stream_ptr)), + "Copy gpu memory failed"); + } + global_step_++; + } else { + // real launch + auto status = CalLSQNudgePerChannel(input, quant_num_, input_alpha, input_quant_max, input_div_alpha, input_quant, + neg_trunc_, num_channels_, reinterpret_cast(stream_ptr)); + CHECK_CUDA_STATUS(status, kernel_name_); + status = CalFakeLearnedScaleQuantPerChannel(output, quant_num_, input_alpha, input_quant, num_channels_, + reinterpret_cast(stream_ptr)); + CHECK_CUDA_STATUS(status, kernel_name_); + } + + return true; +} + +MS_REG_GPU_KERNEL(FakeLearnedScaleQuantPerChannel, FakeLearnedScaleQuantPerChannelGpuKernelMod) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/quant/fake_learned_scale_quant_perchannel_grad_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/quant/fake_learned_scale_quant_perchannel_grad_gpu_kernel.cc index ac26b8a8e86..276d9f5f64b 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/quant/fake_learned_scale_quant_perchannel_grad_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/quant/fake_learned_scale_quant_perchannel_grad_gpu_kernel.cc @@ -1,113 +1,113 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "plugin/device/gpu/kernel/quant/fake_learned_scale_quant_perchannel_grad_gpu_kernel.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_learned_scale_quant_perchannel_impl.cuh" -#include "plugin/device/gpu/kernel/quant/quant_op_const.h" - -namespace mindspore { -namespace kernel { -FakeLearnedScaleQuantPerChannelGradGpuKernelMod::FakeLearnedScaleQuantPerChannelGradGpuKernelMod() - : input_size_(0), - workspace_size_(0), - quant_num_(1), - quant_delay_(0), - global_step_(0), - neg_trunc_(false), - num_channels_(0) {} - -bool FakeLearnedScaleQuantPerChannelGradGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - quant_delay_ = static_cast(GetValue(primitive_->GetAttr("quant_delay"))); - if (quant_delay_ < 0) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the value of quant_delay_ cannot be less than 0, but got " - << quant_delay_; - } - neg_trunc_ = GetValue(primitive_->GetAttr("neg_trunc")); - return true; -} - -int FakeLearnedScaleQuantPerChannelGradGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - output_size_list_.clear(); - workspace_size_list_.clear(); - // init size - auto input_shape = inputs[kIndex0]->GetShapeVector(); - num_channels_ = LongToInt(input_shape[0]); - for (size_t i = 0; i < input_shape.size(); ++i) { - quant_num_ *= LongToInt(input_shape[i]); - } - input_size_ = sizeof(float) * quant_num_; - output_size_list_.push_back(input_size_); // grad_input - output_size_list_.push_back(sizeof(float) * num_channels_); // grad_alpha - workspace_size_list_.push_back(input_size_); // input_div_alpha - workspace_size_list_.push_back(input_size_); // input_quant - return KRET_OK; -} - -bool FakeLearnedScaleQuantPerChannelGradGpuKernelMod::Launch(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs, - void *stream_ptr) { - float *grad_input = GetDeviceAddress(outputs, kIndex0); - float *grad_alpha = GetDeviceAddress(outputs, kIndex1); - float *gradient = GetDeviceAddress(inputs, kIndex0); - float *input = GetDeviceAddress(inputs, kIndex1); - float *input_alpha = GetDeviceAddress(inputs, kIndex2); - float *input_quant_max = GetDeviceAddress(inputs, kIndex3); - float *input_div_alpha = GetDeviceAddress(workspace, kIndex0); - float *input_quant = GetDeviceAddress(workspace, kIndex1); - - MS_EXCEPTION_IF_NULL(grad_input); - MS_EXCEPTION_IF_NULL(grad_alpha); - MS_EXCEPTION_IF_NULL(gradient); - MS_EXCEPTION_IF_NULL(input); - MS_EXCEPTION_IF_NULL(input_alpha); - MS_EXCEPTION_IF_NULL(input_quant_max); - MS_EXCEPTION_IF_NULL(input_div_alpha); - MS_EXCEPTION_IF_NULL(input_quant); - const int kChannelLen = num_channels_; - std::vector alpha_no_grad(kChannelLen); - if (memset_s(alpha_no_grad.data(), kChannelLen * sizeof(float), 0, kChannelLen * sizeof(float)) != EOK) { - MS_LOG(EXCEPTION) << "Failed to set memory."; - } - if (global_step_ >= quant_delay_) { - CHECK_CUDA_RET_WITH_ERROR_NOTRACE( - cudaMemcpyAsync(grad_alpha, alpha_no_grad.data(), sizeof(float) * kChannelLen, cudaMemcpyHostToDevice, - reinterpret_cast(stream_ptr)), - "Copy gpu memory failed"); - auto status = CalLSQNudgePerChannel(input, quant_num_, input_alpha, input_quant_max, input_div_alpha, input_quant, - neg_trunc_, num_channels_, reinterpret_cast(stream_ptr)); - CHECK_CUDA_STATUS(status, kernel_name_); - status = - CalFakeLearnedScaleQuantPerChannelGrad(grad_input, grad_alpha, gradient, quant_num_, input_div_alpha, input_quant, - neg_trunc_, num_channels_, reinterpret_cast(stream_ptr)); - CHECK_CUDA_STATUS(status, kernel_name_); - } else { - CHECK_CUDA_RET_WITH_ERROR_NOTRACE( - cudaMemcpyAsync(grad_alpha, alpha_no_grad.data(), sizeof(float) * kChannelLen, cudaMemcpyHostToDevice, - reinterpret_cast(stream_ptr)), - "Copy gpu memory failed"); - CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(grad_input, gradient, input_size_, cudaMemcpyDeviceToDevice, - reinterpret_cast(stream_ptr)), - "Copy gpu memory failed"); - } - global_step_++; - return true; -} - -MS_REG_GPU_KERNEL(FakeLearnedScaleQuantPerChannelGrad, FakeLearnedScaleQuantPerChannelGradGpuKernelMod) -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "plugin/device/gpu/kernel/quant/fake_learned_scale_quant_perchannel_grad_gpu_kernel.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_learned_scale_quant_perchannel_impl.cuh" +#include "plugin/device/gpu/kernel/quant/quant_op_const.h" + +namespace mindspore { +namespace kernel { +FakeLearnedScaleQuantPerChannelGradGpuKernelMod::FakeLearnedScaleQuantPerChannelGradGpuKernelMod() + : input_size_(0), + workspace_size_(0), + quant_num_(1), + quant_delay_(0), + global_step_(0), + neg_trunc_(false), + num_channels_(0) {} + +bool FakeLearnedScaleQuantPerChannelGradGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + quant_delay_ = static_cast(GetValue(primitive_->GetAttr("quant_delay"))); + if (quant_delay_ < 0) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the value of quant_delay_ cannot be less than 0, but got " + << quant_delay_; + } + neg_trunc_ = GetValue(primitive_->GetAttr("neg_trunc")); + return true; +} + +int FakeLearnedScaleQuantPerChannelGradGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + output_size_list_.clear(); + workspace_size_list_.clear(); + // init size + auto input_shape = inputs[kIndex0]->GetShapeVector(); + num_channels_ = LongToInt(input_shape[0]); + for (size_t i = 0; i < input_shape.size(); ++i) { + quant_num_ *= LongToInt(input_shape[i]); + } + input_size_ = sizeof(float) * quant_num_; + output_size_list_.push_back(input_size_); // grad_input + output_size_list_.push_back(sizeof(float) * num_channels_); // grad_alpha + workspace_size_list_.push_back(input_size_); // input_div_alpha + workspace_size_list_.push_back(input_size_); // input_quant + return KRET_OK; +} + +bool FakeLearnedScaleQuantPerChannelGradGpuKernelMod::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, + void *stream_ptr) { + float *grad_input = GetDeviceAddress(outputs, kIndex0); + float *grad_alpha = GetDeviceAddress(outputs, kIndex1); + float *gradient = GetDeviceAddress(inputs, kIndex0); + float *input = GetDeviceAddress(inputs, kIndex1); + float *input_alpha = GetDeviceAddress(inputs, kIndex2); + float *input_quant_max = GetDeviceAddress(inputs, kIndex3); + float *input_div_alpha = GetDeviceAddress(workspace, kIndex0); + float *input_quant = GetDeviceAddress(workspace, kIndex1); + + MS_EXCEPTION_IF_NULL(grad_input); + MS_EXCEPTION_IF_NULL(grad_alpha); + MS_EXCEPTION_IF_NULL(gradient); + MS_EXCEPTION_IF_NULL(input); + MS_EXCEPTION_IF_NULL(input_alpha); + MS_EXCEPTION_IF_NULL(input_quant_max); + MS_EXCEPTION_IF_NULL(input_div_alpha); + MS_EXCEPTION_IF_NULL(input_quant); + const int kChannelLen = num_channels_; + std::vector alpha_no_grad(kChannelLen); + if (memset_s(alpha_no_grad.data(), kChannelLen * sizeof(float), 0, kChannelLen * sizeof(float)) != EOK) { + MS_LOG(EXCEPTION) << "Failed to set memory."; + } + if (global_step_ >= quant_delay_) { + CHECK_CUDA_RET_WITH_ERROR_NOTRACE( + cudaMemcpyAsync(grad_alpha, alpha_no_grad.data(), sizeof(float) * kChannelLen, cudaMemcpyHostToDevice, + reinterpret_cast(stream_ptr)), + "Copy gpu memory failed"); + auto status = CalLSQNudgePerChannel(input, quant_num_, input_alpha, input_quant_max, input_div_alpha, input_quant, + neg_trunc_, num_channels_, reinterpret_cast(stream_ptr)); + CHECK_CUDA_STATUS(status, kernel_name_); + status = + CalFakeLearnedScaleQuantPerChannelGrad(grad_input, grad_alpha, gradient, quant_num_, input_div_alpha, input_quant, + neg_trunc_, num_channels_, reinterpret_cast(stream_ptr)); + CHECK_CUDA_STATUS(status, kernel_name_); + } else { + CHECK_CUDA_RET_WITH_ERROR_NOTRACE( + cudaMemcpyAsync(grad_alpha, alpha_no_grad.data(), sizeof(float) * kChannelLen, cudaMemcpyHostToDevice, + reinterpret_cast(stream_ptr)), + "Copy gpu memory failed"); + CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(grad_input, gradient, input_size_, cudaMemcpyDeviceToDevice, + reinterpret_cast(stream_ptr)), + "Copy gpu memory failed"); + } + global_step_++; + return true; +} + +MS_REG_GPU_KERNEL(FakeLearnedScaleQuantPerChannelGrad, FakeLearnedScaleQuantPerChannelGradGpuKernelMod) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/quant/fake_learned_scale_quant_perlayer_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/quant/fake_learned_scale_quant_perlayer_gpu_kernel.cc index 98e230b7d63..8505b5b4717 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/quant/fake_learned_scale_quant_perlayer_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/quant/fake_learned_scale_quant_perlayer_gpu_kernel.cc @@ -1,102 +1,102 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/quant/fake_learned_scale_quant_perlayer_gpu_kernel.h" -#include -#include -#include -#include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_learned_scale_quant_perlayer_impl.cuh" -#include "plugin/device/gpu/kernel/quant/quant_op_const.h" - -namespace mindspore { -namespace kernel { -FakeLearnedScaleQuantPerLayerGpuKernelMod::FakeLearnedScaleQuantPerLayerGpuKernelMod() - : input_size_(0), quant_num_(1), global_step_(0), quant_delay_(0), training_(false), neg_trunc_(false) {} - -bool FakeLearnedScaleQuantPerLayerGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - quant_delay_ = static_cast(GetValue(primitive_->GetAttr("quant_delay"))); - training_ = GetValue(primitive_->GetAttr("training")); - neg_trunc_ = GetValue(primitive_->GetAttr("neg_trunc")); - return true; -} - -int FakeLearnedScaleQuantPerLayerGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - output_size_list_.clear(); - workspace_size_list_.clear(); - // init size - auto input_shape = inputs[kIndex0]->GetShapeVector(); - for (size_t i = 0; i < input_shape.size(); ++i) { - quant_num_ *= LongToInt(input_shape[i]); - } - input_size_ = sizeof(float) * quant_num_; - output_size_list_.push_back(input_size_); // y - workspace_size_list_.push_back(input_size_); // input_div_alpha - workspace_size_list_.push_back(input_size_); // input_quant - return KRET_OK; -} - -bool FakeLearnedScaleQuantPerLayerGpuKernelMod::Launch(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - float *input = GetDeviceAddress(inputs, kIndex0); - float *input_alpha = GetDeviceAddress(inputs, kIndex1); - float *input_quant_max = GetDeviceAddress(inputs, kIndex2); - float *output = GetDeviceAddress(outputs, kIndex0); - float *input_div_alpha = GetDeviceAddress(workspace, kIndex0); - float *input_quant = GetDeviceAddress(workspace, kIndex1); - - MS_EXCEPTION_IF_NULL(input); - MS_EXCEPTION_IF_NULL(input_alpha); - MS_EXCEPTION_IF_NULL(input_quant_max); - MS_EXCEPTION_IF_NULL(output); - MS_EXCEPTION_IF_NULL(input_div_alpha); - MS_EXCEPTION_IF_NULL(input_quant); - - if (training_) { - // control flow for quant_delay - if (global_step_ >= quant_delay_) { - // real launch - auto status = CalLSQNudgePerLayer(input, quant_num_, input_alpha, input_quant_max, input_div_alpha, input_quant, - neg_trunc_, reinterpret_cast(stream_ptr)); - CHECK_CUDA_STATUS(status, kernel_name_); - status = CalFakeLearnedScaleQuantPerLayer(output, quant_num_, input_alpha, input_quant, - reinterpret_cast(stream_ptr)); - CHECK_CUDA_STATUS(status, kernel_name_); - } else { - CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice, - reinterpret_cast(stream_ptr)), - "Copy gpu memory failed"); - } - global_step_++; - } else { - // real launch - auto status = CalLSQNudgePerLayer(input, quant_num_, input_alpha, input_quant_max, input_div_alpha, input_quant, - neg_trunc_, reinterpret_cast(stream_ptr)); - CHECK_CUDA_STATUS(status, kernel_name_); - status = CalFakeLearnedScaleQuantPerLayer(output, quant_num_, input_alpha, input_quant, - reinterpret_cast(stream_ptr)); - CHECK_CUDA_STATUS(status, kernel_name_); - } - - return true; -} - -MS_REG_GPU_KERNEL(FakeLearnedScaleQuantPerLayer, FakeLearnedScaleQuantPerLayerGpuKernelMod) -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/quant/fake_learned_scale_quant_perlayer_gpu_kernel.h" +#include +#include +#include +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_learned_scale_quant_perlayer_impl.cuh" +#include "plugin/device/gpu/kernel/quant/quant_op_const.h" + +namespace mindspore { +namespace kernel { +FakeLearnedScaleQuantPerLayerGpuKernelMod::FakeLearnedScaleQuantPerLayerGpuKernelMod() + : input_size_(0), quant_num_(1), global_step_(0), quant_delay_(0), training_(false), neg_trunc_(false) {} + +bool FakeLearnedScaleQuantPerLayerGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + quant_delay_ = static_cast(GetValue(primitive_->GetAttr("quant_delay"))); + training_ = GetValue(primitive_->GetAttr("training")); + neg_trunc_ = GetValue(primitive_->GetAttr("neg_trunc")); + return true; +} + +int FakeLearnedScaleQuantPerLayerGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + output_size_list_.clear(); + workspace_size_list_.clear(); + // init size + auto input_shape = inputs[kIndex0]->GetShapeVector(); + for (size_t i = 0; i < input_shape.size(); ++i) { + quant_num_ *= LongToInt(input_shape[i]); + } + input_size_ = sizeof(float) * quant_num_; + output_size_list_.push_back(input_size_); // y + workspace_size_list_.push_back(input_size_); // input_div_alpha + workspace_size_list_.push_back(input_size_); // input_quant + return KRET_OK; +} + +bool FakeLearnedScaleQuantPerLayerGpuKernelMod::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + float *input = GetDeviceAddress(inputs, kIndex0); + float *input_alpha = GetDeviceAddress(inputs, kIndex1); + float *input_quant_max = GetDeviceAddress(inputs, kIndex2); + float *output = GetDeviceAddress(outputs, kIndex0); + float *input_div_alpha = GetDeviceAddress(workspace, kIndex0); + float *input_quant = GetDeviceAddress(workspace, kIndex1); + + MS_EXCEPTION_IF_NULL(input); + MS_EXCEPTION_IF_NULL(input_alpha); + MS_EXCEPTION_IF_NULL(input_quant_max); + MS_EXCEPTION_IF_NULL(output); + MS_EXCEPTION_IF_NULL(input_div_alpha); + MS_EXCEPTION_IF_NULL(input_quant); + + if (training_) { + // control flow for quant_delay + if (global_step_ >= quant_delay_) { + // real launch + auto status = CalLSQNudgePerLayer(input, quant_num_, input_alpha, input_quant_max, input_div_alpha, input_quant, + neg_trunc_, reinterpret_cast(stream_ptr)); + CHECK_CUDA_STATUS(status, kernel_name_); + status = CalFakeLearnedScaleQuantPerLayer(output, quant_num_, input_alpha, input_quant, + reinterpret_cast(stream_ptr)); + CHECK_CUDA_STATUS(status, kernel_name_); + } else { + CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice, + reinterpret_cast(stream_ptr)), + "Copy gpu memory failed"); + } + global_step_++; + } else { + // real launch + auto status = CalLSQNudgePerLayer(input, quant_num_, input_alpha, input_quant_max, input_div_alpha, input_quant, + neg_trunc_, reinterpret_cast(stream_ptr)); + CHECK_CUDA_STATUS(status, kernel_name_); + status = CalFakeLearnedScaleQuantPerLayer(output, quant_num_, input_alpha, input_quant, + reinterpret_cast(stream_ptr)); + CHECK_CUDA_STATUS(status, kernel_name_); + } + + return true; +} + +MS_REG_GPU_KERNEL(FakeLearnedScaleQuantPerLayer, FakeLearnedScaleQuantPerLayerGpuKernelMod) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/quant/fake_learned_scale_quant_perlayer_grad_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/quant/fake_learned_scale_quant_perlayer_grad_gpu_kernel.cc index fa8c22ae455..2da5bca3f39 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/quant/fake_learned_scale_quant_perlayer_grad_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/quant/fake_learned_scale_quant_perlayer_grad_gpu_kernel.cc @@ -1,99 +1,99 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "plugin/device/gpu/kernel/quant/fake_learned_scale_quant_perlayer_grad_gpu_kernel.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_learned_scale_quant_perlayer_impl.cuh" -#include "plugin/device/gpu/kernel/quant/quant_op_const.h" - -namespace mindspore { -namespace kernel { -FakeLearnedScaleQuantPerLayerGradGpuKernelMod::FakeLearnedScaleQuantPerLayerGradGpuKernelMod() - : input_size_(0), workspace_size_(0), quant_num_(1), quant_delay_(0), global_step_(0), neg_trunc_(false) {} - -bool FakeLearnedScaleQuantPerLayerGradGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - quant_delay_ = static_cast(GetValue(primitive_->GetAttr("quant_delay"))); - if (quant_delay_ < 0) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the value of quant_delay_ cannot be less than 0, but got " - << quant_delay_; - } - neg_trunc_ = GetValue(primitive_->GetAttr("neg_trunc")); - return true; -} - -int FakeLearnedScaleQuantPerLayerGradGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - output_size_list_.clear(); - workspace_size_list_.clear(); - // init size - auto input_shape = inputs[kIndex0]->GetShapeVector(); - auto size = SizeOf(input_shape); - quant_num_ = SizeToInt(size); - input_size_ = sizeof(float) * size; - output_size_list_.push_back(input_size_); // grad_input - output_size_list_.push_back(sizeof(float)); // grad_alpha - workspace_size_list_.push_back(input_size_); // input_div_alpha - workspace_size_list_.push_back(input_size_); // input_quant - return KRET_OK; -} - -bool FakeLearnedScaleQuantPerLayerGradGpuKernelMod::Launch(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs, - void *stream_ptr) { - float *grad_input = GetDeviceAddress(outputs, kIndex0); - float *grad_alpha = GetDeviceAddress(outputs, kIndex1); - float *gradient = GetDeviceAddress(inputs, kIndex0); - float *input = GetDeviceAddress(inputs, kIndex1); - float *input_alpha = GetDeviceAddress(inputs, kIndex2); - float *input_quant_max = GetDeviceAddress(inputs, kIndex3); - float *input_div_alpha = GetDeviceAddress(workspace, kIndex0); - float *input_quant = GetDeviceAddress(workspace, kIndex1); - - MS_EXCEPTION_IF_NULL(grad_input); - MS_EXCEPTION_IF_NULL(grad_alpha); - MS_EXCEPTION_IF_NULL(gradient); - MS_EXCEPTION_IF_NULL(input); - MS_EXCEPTION_IF_NULL(input_alpha); - MS_EXCEPTION_IF_NULL(input_quant_max); - MS_EXCEPTION_IF_NULL(input_div_alpha); - MS_EXCEPTION_IF_NULL(input_quant); - - const float alpha_no_grad[1] = {0.f}; - if (global_step_ >= quant_delay_) { - CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(grad_alpha, alpha_no_grad, sizeof(float), cudaMemcpyHostToDevice, - reinterpret_cast(stream_ptr)), - "Copy gpu memory failed"); - auto status = CalLSQNudgePerLayer(input, quant_num_, input_alpha, input_quant_max, input_div_alpha, input_quant, - neg_trunc_, reinterpret_cast(stream_ptr)); - CHECK_CUDA_STATUS(status, kernel_name_); - status = CalFakeLearnedScaleQuantPerLayerGrad(grad_input, grad_alpha, gradient, quant_num_, input_div_alpha, - input_quant, neg_trunc_, reinterpret_cast(stream_ptr)); - CHECK_CUDA_STATUS(status, kernel_name_); - } else { - CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(grad_alpha, alpha_no_grad, sizeof(float), cudaMemcpyHostToDevice, - reinterpret_cast(stream_ptr)), - "Copy gpu memory failed"); - CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(grad_input, gradient, input_size_, cudaMemcpyDeviceToDevice, - reinterpret_cast(stream_ptr)), - "Copy gpu memory failed"); - } - global_step_++; - return true; -} - -MS_REG_GPU_KERNEL(FakeLearnedScaleQuantPerLayerGrad, FakeLearnedScaleQuantPerLayerGradGpuKernelMod) -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "plugin/device/gpu/kernel/quant/fake_learned_scale_quant_perlayer_grad_gpu_kernel.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_learned_scale_quant_perlayer_impl.cuh" +#include "plugin/device/gpu/kernel/quant/quant_op_const.h" + +namespace mindspore { +namespace kernel { +FakeLearnedScaleQuantPerLayerGradGpuKernelMod::FakeLearnedScaleQuantPerLayerGradGpuKernelMod() + : input_size_(0), workspace_size_(0), quant_num_(1), quant_delay_(0), global_step_(0), neg_trunc_(false) {} + +bool FakeLearnedScaleQuantPerLayerGradGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + quant_delay_ = static_cast(GetValue(primitive_->GetAttr("quant_delay"))); + if (quant_delay_ < 0) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the value of quant_delay_ cannot be less than 0, but got " + << quant_delay_; + } + neg_trunc_ = GetValue(primitive_->GetAttr("neg_trunc")); + return true; +} + +int FakeLearnedScaleQuantPerLayerGradGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + output_size_list_.clear(); + workspace_size_list_.clear(); + // init size + auto input_shape = inputs[kIndex0]->GetShapeVector(); + auto size = SizeOf(input_shape); + quant_num_ = SizeToInt(size); + input_size_ = sizeof(float) * size; + output_size_list_.push_back(input_size_); // grad_input + output_size_list_.push_back(sizeof(float)); // grad_alpha + workspace_size_list_.push_back(input_size_); // input_div_alpha + workspace_size_list_.push_back(input_size_); // input_quant + return KRET_OK; +} + +bool FakeLearnedScaleQuantPerLayerGradGpuKernelMod::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, + void *stream_ptr) { + float *grad_input = GetDeviceAddress(outputs, kIndex0); + float *grad_alpha = GetDeviceAddress(outputs, kIndex1); + float *gradient = GetDeviceAddress(inputs, kIndex0); + float *input = GetDeviceAddress(inputs, kIndex1); + float *input_alpha = GetDeviceAddress(inputs, kIndex2); + float *input_quant_max = GetDeviceAddress(inputs, kIndex3); + float *input_div_alpha = GetDeviceAddress(workspace, kIndex0); + float *input_quant = GetDeviceAddress(workspace, kIndex1); + + MS_EXCEPTION_IF_NULL(grad_input); + MS_EXCEPTION_IF_NULL(grad_alpha); + MS_EXCEPTION_IF_NULL(gradient); + MS_EXCEPTION_IF_NULL(input); + MS_EXCEPTION_IF_NULL(input_alpha); + MS_EXCEPTION_IF_NULL(input_quant_max); + MS_EXCEPTION_IF_NULL(input_div_alpha); + MS_EXCEPTION_IF_NULL(input_quant); + + const float alpha_no_grad[1] = {0.f}; + if (global_step_ >= quant_delay_) { + CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(grad_alpha, alpha_no_grad, sizeof(float), cudaMemcpyHostToDevice, + reinterpret_cast(stream_ptr)), + "Copy gpu memory failed"); + auto status = CalLSQNudgePerLayer(input, quant_num_, input_alpha, input_quant_max, input_div_alpha, input_quant, + neg_trunc_, reinterpret_cast(stream_ptr)); + CHECK_CUDA_STATUS(status, kernel_name_); + status = CalFakeLearnedScaleQuantPerLayerGrad(grad_input, grad_alpha, gradient, quant_num_, input_div_alpha, + input_quant, neg_trunc_, reinterpret_cast(stream_ptr)); + CHECK_CUDA_STATUS(status, kernel_name_); + } else { + CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(grad_alpha, alpha_no_grad, sizeof(float), cudaMemcpyHostToDevice, + reinterpret_cast(stream_ptr)), + "Copy gpu memory failed"); + CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(grad_input, gradient, input_size_, cudaMemcpyDeviceToDevice, + reinterpret_cast(stream_ptr)), + "Copy gpu memory failed"); + } + global_step_++; + return true; +} + +MS_REG_GPU_KERNEL(FakeLearnedScaleQuantPerLayerGrad, FakeLearnedScaleQuantPerLayerGradGpuKernelMod) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/quant/fake_quant_perchannel_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/quant/fake_quant_perchannel_gpu_kernel.cc index e6493322e27..aacc34f1eb3 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/quant/fake_quant_perchannel_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/quant/fake_quant_perchannel_gpu_kernel.cc @@ -1,139 +1,139 @@ -/** - * Copyright 2020-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/quant/fake_quant_perchannel_gpu_kernel.h" -#include -#include -#include -#include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_quant_perchannel_impl.cuh" -#include "plugin/device/gpu/kernel/quant/quant_op_const.h" - -namespace mindspore { -namespace kernel { -FakeQuantPerChannelGpuKernelMod::FakeQuantPerChannelGpuKernelMod() - : input_size_(0), - num_channels_(0), - num_bits_(0), - training_(false), - symmetric_(false), - narrow_range_(false), - is_null_input_(false), - quant_delay_(0), - quant_min_(0), - quant_max_(0), - global_step_(0) {} - -bool FakeQuantPerChannelGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - // get attribute - num_bits_ = static_cast(GetValue(primitive_->GetAttr("num_bits"))); - training_ = GetValue(primitive_->GetAttr("training")); - symmetric_ = GetValue(primitive_->GetAttr("symmetric")); - narrow_range_ = GetValue(primitive_->GetAttr("narrow_range")); - quant_delay_ = static_cast(GetValue(primitive_->GetAttr("quant_delay"))); - - if (num_bits_ <= kMinQuantBit || num_bits_ >= kMaxQuantBit) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the value of num_bits should be in (2, 16), but got " - << num_bits_; - } - - if (quant_delay_ < 0) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the value of quant_delay_ cannot be less than 0, but got " - << quant_delay_; - } - - // quant min and max value - quant_min_ = 0; - quant_max_ = (1 << num_bits_) - 1; - if (narrow_range_) { - quant_min_++; - } - return true; -} - -void FakeQuantPerChannelGpuKernelMod::SetSizeLists() { - output_size_list_.push_back(input_size_); // output in tensor - workspace_size_list_.push_back(sizeof(float) * num_channels_); // scale in channel - workspace_size_list_.push_back(sizeof(float) * num_channels_); // min in channel - workspace_size_list_.push_back(sizeof(float) * num_channels_); // max in channel -} - -int FakeQuantPerChannelGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - output_size_list_.clear(); - workspace_size_list_.clear(); - // shape info for gpu - auto input_shape = inputs[kIndex0]->GetShapeVector(); - is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name_, "input"); - if (is_null_input_) { - SetSizeLists(); - return KRET_UNKNOWN_SHAPE; - } - if (input_shape.empty()) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', input cannot be empty, but got empty"; - } - num_channels_ = LongToInt(input_shape[0]); - input_size_ = sizeof(float) * SizeOf(input_shape); - SetSizeLists(); - return KRET_OK; -} - -void FakeQuantPerChannelGpuKernelMod::CalFakeQuantize(const float *input, float *output, float *input_min, - float *input_max, float *nudge_min, float *nudge_max, - float *scale, void *stream_ptr) { - auto status = CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, - num_channels_, symmetric_, reinterpret_cast(stream_ptr)); - CHECK_CUDA_STATUS(status, "CalNudgePerChannel called by " + kernel_name_); - status = CalFakeQuantPerChannel(input, output, input_size_ / sizeof(float), num_channels_, nudge_min, nudge_max, - scale, reinterpret_cast(stream_ptr)); - CHECK_CUDA_STATUS(status, kernel_name_); -} - -bool FakeQuantPerChannelGpuKernelMod::Launch(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - if (is_null_input_) { - return true; - } - (void)workspace; - float *output = GetDeviceAddress(outputs, kIndex0); - float *input = GetDeviceAddress(inputs, kIndex0); - float *input_min = GetDeviceAddress(inputs, kIndex1); - float *input_max = GetDeviceAddress(inputs, kIndex2); - float *scale = GetDeviceAddress(workspace, kIndex0); - float *nudge_min = GetDeviceAddress(workspace, kIndex1); - float *nudge_max = GetDeviceAddress(workspace, kIndex2); - - if (training_) { - if (global_step_ >= quant_delay_) { - CalFakeQuantize(input, output, input_min, input_max, nudge_min, nudge_max, scale, stream_ptr); - } else { - CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice, - reinterpret_cast(stream_ptr)), - "Copy gpu memory failed."); - } - global_step_++; - } else { - CalFakeQuantize(input, output, input_min, input_max, nudge_min, nudge_max, scale, stream_ptr); - } - - return true; -} - -MS_REG_GPU_KERNEL(FakeQuantPerChannel, FakeQuantPerChannelGpuKernelMod) -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/quant/fake_quant_perchannel_gpu_kernel.h" +#include +#include +#include +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_quant_perchannel_impl.cuh" +#include "plugin/device/gpu/kernel/quant/quant_op_const.h" + +namespace mindspore { +namespace kernel { +FakeQuantPerChannelGpuKernelMod::FakeQuantPerChannelGpuKernelMod() + : input_size_(0), + num_channels_(0), + num_bits_(0), + training_(false), + symmetric_(false), + narrow_range_(false), + is_null_input_(false), + quant_delay_(0), + quant_min_(0), + quant_max_(0), + global_step_(0) {} + +bool FakeQuantPerChannelGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + // get attribute + num_bits_ = static_cast(GetValue(primitive_->GetAttr("num_bits"))); + training_ = GetValue(primitive_->GetAttr("training")); + symmetric_ = GetValue(primitive_->GetAttr("symmetric")); + narrow_range_ = GetValue(primitive_->GetAttr("narrow_range")); + quant_delay_ = static_cast(GetValue(primitive_->GetAttr("quant_delay"))); + + if (num_bits_ <= kMinQuantBit || num_bits_ >= kMaxQuantBit) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the value of num_bits should be in (2, 16), but got " + << num_bits_; + } + + if (quant_delay_ < 0) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the value of quant_delay_ cannot be less than 0, but got " + << quant_delay_; + } + + // quant min and max value + quant_min_ = 0; + quant_max_ = (1 << num_bits_) - 1; + if (narrow_range_) { + quant_min_++; + } + return true; +} + +void FakeQuantPerChannelGpuKernelMod::SetSizeLists() { + output_size_list_.push_back(input_size_); // output in tensor + workspace_size_list_.push_back(sizeof(float) * num_channels_); // scale in channel + workspace_size_list_.push_back(sizeof(float) * num_channels_); // min in channel + workspace_size_list_.push_back(sizeof(float) * num_channels_); // max in channel +} + +int FakeQuantPerChannelGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + output_size_list_.clear(); + workspace_size_list_.clear(); + // shape info for gpu + auto input_shape = inputs[kIndex0]->GetShapeVector(); + is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name_, "input"); + if (is_null_input_) { + SetSizeLists(); + return KRET_UNKNOWN_SHAPE; + } + if (input_shape.empty()) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', input cannot be empty, but got empty"; + } + num_channels_ = LongToInt(input_shape[0]); + input_size_ = sizeof(float) * SizeOf(input_shape); + SetSizeLists(); + return KRET_OK; +} + +void FakeQuantPerChannelGpuKernelMod::CalFakeQuantize(const float *input, float *output, float *input_min, + float *input_max, float *nudge_min, float *nudge_max, + float *scale, void *stream_ptr) { + auto status = CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, + num_channels_, symmetric_, reinterpret_cast(stream_ptr)); + CHECK_CUDA_STATUS(status, "CalNudgePerChannel called by " + kernel_name_); + status = CalFakeQuantPerChannel(input, output, input_size_ / sizeof(float), num_channels_, nudge_min, nudge_max, + scale, reinterpret_cast(stream_ptr)); + CHECK_CUDA_STATUS(status, kernel_name_); +} + +bool FakeQuantPerChannelGpuKernelMod::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + if (is_null_input_) { + return true; + } + (void)workspace; + float *output = GetDeviceAddress(outputs, kIndex0); + float *input = GetDeviceAddress(inputs, kIndex0); + float *input_min = GetDeviceAddress(inputs, kIndex1); + float *input_max = GetDeviceAddress(inputs, kIndex2); + float *scale = GetDeviceAddress(workspace, kIndex0); + float *nudge_min = GetDeviceAddress(workspace, kIndex1); + float *nudge_max = GetDeviceAddress(workspace, kIndex2); + + if (training_) { + if (global_step_ >= quant_delay_) { + CalFakeQuantize(input, output, input_min, input_max, nudge_min, nudge_max, scale, stream_ptr); + } else { + CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice, + reinterpret_cast(stream_ptr)), + "Copy gpu memory failed."); + } + global_step_++; + } else { + CalFakeQuantize(input, output, input_min, input_max, nudge_min, nudge_max, scale, stream_ptr); + } + + return true; +} + +MS_REG_GPU_KERNEL(FakeQuantPerChannel, FakeQuantPerChannelGpuKernelMod) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/quant/fake_quant_perchannel_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/quant/fake_quant_perchannel_gpu_kernel.h old mode 100755 new mode 100644 diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/quant/fake_quant_perchannel_grad_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/quant/fake_quant_perchannel_grad_gpu_kernel.cc index 52f822f0649..3ea27eec5c0 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/quant/fake_quant_perchannel_grad_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/quant/fake_quant_perchannel_grad_gpu_kernel.cc @@ -1,122 +1,122 @@ -/** - * Copyright 2020-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/quant/fake_quant_perchannel_grad_gpu_kernel.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_quant_perchannel_impl.cuh" -#include "plugin/device/gpu/kernel/quant/quant_op_const.h" - -namespace mindspore { -namespace kernel { -FakeQuantPerChannelGradGpuKernelMod::FakeQuantPerChannelGradGpuKernelMod() - : input_size_(0), - num_bits_(0), - quant_min_(0), - quant_max_(0), - num_channels_(0), - quant_delay_(0), - global_step_(0), - narrow_range_(false), - is_null_input_(false), - symmetric_(false) {} - -bool FakeQuantPerChannelGradGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - num_bits_ = static_cast(GetValue(primitive_->GetAttr("num_bits"))); - if (num_bits_ <= kMinQuantBit || num_bits_ >= kMaxQuantBit) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the value of num_bits should be in (2, 16), but got " - << num_bits_; - } - - quant_delay_ = static_cast(GetValue(primitive_->GetAttr("quant_delay"))); - if (quant_delay_ < 0) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the value of quant_delay_ cannot be less than 0, but got " - << quant_delay_; - } - - symmetric_ = GetValue(primitive_->GetAttr("symmetric")); - narrow_range_ = GetValue(primitive_->GetAttr("narrow_range")); - - // quant min and max value - quant_min_ = 0; - quant_max_ = (1 << num_bits_) - 1; - if (narrow_range_) { - quant_min_++; - } - return true; -} - -void FakeQuantPerChannelGradGpuKernelMod::SetSizeLists() { - output_size_list_.push_back(input_size_); // output - workspace_size_list_.push_back(sizeof(float) * num_channels_); // scale in channel - workspace_size_list_.push_back(sizeof(float) * num_channels_); // min in channel - workspace_size_list_.push_back(sizeof(float) * num_channels_); // max in channel -} - -int FakeQuantPerChannelGradGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - output_size_list_.clear(); - workspace_size_list_.clear(); - auto input_shape = inputs[kIndex0]->GetShapeVector(); - is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name_, "input"); - if (is_null_input_) { - SetSizeLists(); - return KRET_UNKNOWN_SHAPE; - } - if (input_shape.empty()) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', input cannot be empty, but got empty"; - } - num_channels_ = LongToInt(input_shape[0]); - input_size_ = sizeof(float) * SizeOf(input_shape); - SetSizeLists(); - return KRET_OK; -} - -bool FakeQuantPerChannelGradGpuKernelMod::Launch(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - if (is_null_input_) { - return true; - } - (void)workspace; - float *output = GetDeviceAddress(outputs, kIndex0); - float *gradient = GetDeviceAddress(inputs, kIndex0); - float *input = GetDeviceAddress(inputs, kIndex1); - float *input_min = GetDeviceAddress(inputs, kIndex2); - float *input_max = GetDeviceAddress(inputs, kIndex3); - float *scale = GetDeviceAddress(workspace, kIndex0); - float *nudge_min = GetDeviceAddress(workspace, kIndex1); - float *nudge_max = GetDeviceAddress(workspace, kIndex2); - - int total_size = input_size_ / sizeof(float); - if (global_step_ >= quant_delay_) { - auto status = CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, - num_channels_, symmetric_, reinterpret_cast(stream_ptr)); - CHECK_CUDA_STATUS(status, kernel_name_); - status = CalFakeQuantPerChannelGrad(input, gradient, output, total_size, num_channels_, nudge_min, nudge_max, - reinterpret_cast(stream_ptr)); - CHECK_CUDA_STATUS(status, kernel_name_); - } else { - CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice, - reinterpret_cast(stream_ptr)), - "Copy gpu memory failed."); - } - global_step_++; - return true; -} - -MS_REG_GPU_KERNEL(FakeQuantPerChannelGrad, FakeQuantPerChannelGradGpuKernelMod) -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/quant/fake_quant_perchannel_grad_gpu_kernel.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_quant_perchannel_impl.cuh" +#include "plugin/device/gpu/kernel/quant/quant_op_const.h" + +namespace mindspore { +namespace kernel { +FakeQuantPerChannelGradGpuKernelMod::FakeQuantPerChannelGradGpuKernelMod() + : input_size_(0), + num_bits_(0), + quant_min_(0), + quant_max_(0), + num_channels_(0), + quant_delay_(0), + global_step_(0), + narrow_range_(false), + is_null_input_(false), + symmetric_(false) {} + +bool FakeQuantPerChannelGradGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + num_bits_ = static_cast(GetValue(primitive_->GetAttr("num_bits"))); + if (num_bits_ <= kMinQuantBit || num_bits_ >= kMaxQuantBit) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the value of num_bits should be in (2, 16), but got " + << num_bits_; + } + + quant_delay_ = static_cast(GetValue(primitive_->GetAttr("quant_delay"))); + if (quant_delay_ < 0) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the value of quant_delay_ cannot be less than 0, but got " + << quant_delay_; + } + + symmetric_ = GetValue(primitive_->GetAttr("symmetric")); + narrow_range_ = GetValue(primitive_->GetAttr("narrow_range")); + + // quant min and max value + quant_min_ = 0; + quant_max_ = (1 << num_bits_) - 1; + if (narrow_range_) { + quant_min_++; + } + return true; +} + +void FakeQuantPerChannelGradGpuKernelMod::SetSizeLists() { + output_size_list_.push_back(input_size_); // output + workspace_size_list_.push_back(sizeof(float) * num_channels_); // scale in channel + workspace_size_list_.push_back(sizeof(float) * num_channels_); // min in channel + workspace_size_list_.push_back(sizeof(float) * num_channels_); // max in channel +} + +int FakeQuantPerChannelGradGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + output_size_list_.clear(); + workspace_size_list_.clear(); + auto input_shape = inputs[kIndex0]->GetShapeVector(); + is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name_, "input"); + if (is_null_input_) { + SetSizeLists(); + return KRET_UNKNOWN_SHAPE; + } + if (input_shape.empty()) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', input cannot be empty, but got empty"; + } + num_channels_ = LongToInt(input_shape[0]); + input_size_ = sizeof(float) * SizeOf(input_shape); + SetSizeLists(); + return KRET_OK; +} + +bool FakeQuantPerChannelGradGpuKernelMod::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + if (is_null_input_) { + return true; + } + (void)workspace; + float *output = GetDeviceAddress(outputs, kIndex0); + float *gradient = GetDeviceAddress(inputs, kIndex0); + float *input = GetDeviceAddress(inputs, kIndex1); + float *input_min = GetDeviceAddress(inputs, kIndex2); + float *input_max = GetDeviceAddress(inputs, kIndex3); + float *scale = GetDeviceAddress(workspace, kIndex0); + float *nudge_min = GetDeviceAddress(workspace, kIndex1); + float *nudge_max = GetDeviceAddress(workspace, kIndex2); + + int total_size = input_size_ / sizeof(float); + if (global_step_ >= quant_delay_) { + auto status = CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, + num_channels_, symmetric_, reinterpret_cast(stream_ptr)); + CHECK_CUDA_STATUS(status, kernel_name_); + status = CalFakeQuantPerChannelGrad(input, gradient, output, total_size, num_channels_, nudge_min, nudge_max, + reinterpret_cast(stream_ptr)); + CHECK_CUDA_STATUS(status, kernel_name_); + } else { + CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice, + reinterpret_cast(stream_ptr)), + "Copy gpu memory failed."); + } + global_step_++; + return true; +} + +MS_REG_GPU_KERNEL(FakeQuantPerChannelGrad, FakeQuantPerChannelGradGpuKernelMod) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/quant/fake_quant_perlayer_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/quant/fake_quant_perlayer_gpu_kernel.cc index 0eb5e0cb06e..d0c1239f76d 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/quant/fake_quant_perlayer_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/quant/fake_quant_perlayer_gpu_kernel.cc @@ -1,136 +1,136 @@ -/** - * Copyright 2020-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/quant/fake_quant_perlayer_gpu_kernel.h" -#include -#include -#include -#include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_quant_perlayer_impl.cuh" -#include "plugin/device/gpu/kernel/quant/quant_op_const.h" - -namespace mindspore { -namespace kernel { -FakeQuantPerLayerGpuKernelMod::FakeQuantPerLayerGpuKernelMod() - : input_size_(0), - quant_min_(0), - quant_max_(0), - quant_num_(1), - global_step_(0), - num_bits_(0), - quant_delay_(0), - training_(false), - narrow_range_(false), - is_null_input_(false), - symmetric_(false) {} - -bool FakeQuantPerLayerGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - num_bits_ = static_cast(GetValue(primitive_->GetAttr("num_bits"))); - quant_delay_ = static_cast(GetValue(primitive_->GetAttr("quant_delay"))); - training_ = GetValue(primitive_->GetAttr("training")); - symmetric_ = GetValue(primitive_->GetAttr("symmetric")); - narrow_range_ = GetValue(primitive_->GetAttr("narrow_range")); - - if (num_bits_ <= kMinQuantBit || num_bits_ >= kMaxQuantBit) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the value of num_bits should be in (2, 16), but got " - << num_bits_; - } - - if (quant_delay_ < 0) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the value of quant_delay_ cannot be less than 0, but got " - << quant_delay_; - } - // quant min and max value - quant_min_ = 0; - quant_max_ = (1 << num_bits_) - 1; - if (narrow_range_) { - quant_min_++; - } - return true; -} - -void FakeQuantPerLayerGpuKernelMod::SetSizeLists() { - output_size_list_.push_back(input_size_); // y - workspace_size_list_.push_back(sizeof(float)); // scale - workspace_size_list_.push_back(sizeof(float)); // nudge_min - workspace_size_list_.push_back(sizeof(float)); // nudge_max -} - -int FakeQuantPerLayerGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - output_size_list_.clear(); - workspace_size_list_.clear(); - // init size - auto input_shape = inputs[kIndex0]->GetShapeVector(); - is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name_, "input"); - if (is_null_input_) { - SetSizeLists(); - return KRET_UNKNOWN_SHAPE; - } - auto size = SizeOf(input_shape); - quant_num_ = SizeToInt(size); - input_size_ = sizeof(float) * size; - SetSizeLists(); - return KRET_OK; -} - -bool FakeQuantPerLayerGpuKernelMod::Launch(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - if (is_null_input_) { - return true; - } - float *output = GetDeviceAddress(outputs, kIndex0); - float *input = GetDeviceAddress(inputs, kIndex0); - float *input_min = GetDeviceAddress(inputs, kIndex1); - float *input_max = GetDeviceAddress(inputs, kIndex2); - float *scale = GetDeviceAddress(workspace, kIndex0); - float *nudge_min = GetDeviceAddress(workspace, kIndex1); - float *nudge_max = GetDeviceAddress(workspace, kIndex2); - - if (training_) { - // control flow for quant_delay - if (global_step_ >= quant_delay_) { - // real launch - auto status = CalNudgePerLayer(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, - symmetric_, reinterpret_cast(stream_ptr)); - CHECK_CUDA_STATUS(status, kernel_name_); - status = CalFakeQuantPerLayer(input, output, quant_num_, nudge_min, nudge_max, scale, - reinterpret_cast(stream_ptr)); - CHECK_CUDA_STATUS(status, kernel_name_); - } else { - CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice, - reinterpret_cast(stream_ptr)), - "Copy gpu memory failed"); - } - global_step_++; - } else { - // real launch - auto status = CalNudgePerLayer(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, - symmetric_, reinterpret_cast(stream_ptr)); - CHECK_CUDA_STATUS(status, kernel_name_); - status = CalFakeQuantPerLayer(input, output, quant_num_, nudge_min, nudge_max, scale, - reinterpret_cast(stream_ptr)); - CHECK_CUDA_STATUS(status, kernel_name_); - } - - return true; -} - -MS_REG_GPU_KERNEL(FakeQuantPerLayer, FakeQuantPerLayerGpuKernelMod) -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/quant/fake_quant_perlayer_gpu_kernel.h" +#include +#include +#include +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_quant_perlayer_impl.cuh" +#include "plugin/device/gpu/kernel/quant/quant_op_const.h" + +namespace mindspore { +namespace kernel { +FakeQuantPerLayerGpuKernelMod::FakeQuantPerLayerGpuKernelMod() + : input_size_(0), + quant_min_(0), + quant_max_(0), + quant_num_(1), + global_step_(0), + num_bits_(0), + quant_delay_(0), + training_(false), + narrow_range_(false), + is_null_input_(false), + symmetric_(false) {} + +bool FakeQuantPerLayerGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + num_bits_ = static_cast(GetValue(primitive_->GetAttr("num_bits"))); + quant_delay_ = static_cast(GetValue(primitive_->GetAttr("quant_delay"))); + training_ = GetValue(primitive_->GetAttr("training")); + symmetric_ = GetValue(primitive_->GetAttr("symmetric")); + narrow_range_ = GetValue(primitive_->GetAttr("narrow_range")); + + if (num_bits_ <= kMinQuantBit || num_bits_ >= kMaxQuantBit) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the value of num_bits should be in (2, 16), but got " + << num_bits_; + } + + if (quant_delay_ < 0) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the value of quant_delay_ cannot be less than 0, but got " + << quant_delay_; + } + // quant min and max value + quant_min_ = 0; + quant_max_ = (1 << num_bits_) - 1; + if (narrow_range_) { + quant_min_++; + } + return true; +} + +void FakeQuantPerLayerGpuKernelMod::SetSizeLists() { + output_size_list_.push_back(input_size_); // y + workspace_size_list_.push_back(sizeof(float)); // scale + workspace_size_list_.push_back(sizeof(float)); // nudge_min + workspace_size_list_.push_back(sizeof(float)); // nudge_max +} + +int FakeQuantPerLayerGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + output_size_list_.clear(); + workspace_size_list_.clear(); + // init size + auto input_shape = inputs[kIndex0]->GetShapeVector(); + is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name_, "input"); + if (is_null_input_) { + SetSizeLists(); + return KRET_UNKNOWN_SHAPE; + } + auto size = SizeOf(input_shape); + quant_num_ = SizeToInt(size); + input_size_ = sizeof(float) * size; + SetSizeLists(); + return KRET_OK; +} + +bool FakeQuantPerLayerGpuKernelMod::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + if (is_null_input_) { + return true; + } + float *output = GetDeviceAddress(outputs, kIndex0); + float *input = GetDeviceAddress(inputs, kIndex0); + float *input_min = GetDeviceAddress(inputs, kIndex1); + float *input_max = GetDeviceAddress(inputs, kIndex2); + float *scale = GetDeviceAddress(workspace, kIndex0); + float *nudge_min = GetDeviceAddress(workspace, kIndex1); + float *nudge_max = GetDeviceAddress(workspace, kIndex2); + + if (training_) { + // control flow for quant_delay + if (global_step_ >= quant_delay_) { + // real launch + auto status = CalNudgePerLayer(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, + symmetric_, reinterpret_cast(stream_ptr)); + CHECK_CUDA_STATUS(status, kernel_name_); + status = CalFakeQuantPerLayer(input, output, quant_num_, nudge_min, nudge_max, scale, + reinterpret_cast(stream_ptr)); + CHECK_CUDA_STATUS(status, kernel_name_); + } else { + CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice, + reinterpret_cast(stream_ptr)), + "Copy gpu memory failed"); + } + global_step_++; + } else { + // real launch + auto status = CalNudgePerLayer(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, + symmetric_, reinterpret_cast(stream_ptr)); + CHECK_CUDA_STATUS(status, kernel_name_); + status = CalFakeQuantPerLayer(input, output, quant_num_, nudge_min, nudge_max, scale, + reinterpret_cast(stream_ptr)); + CHECK_CUDA_STATUS(status, kernel_name_); + } + + return true; +} + +MS_REG_GPU_KERNEL(FakeQuantPerLayer, FakeQuantPerLayerGpuKernelMod) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/quant/fake_quant_perlayer_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/quant/fake_quant_perlayer_gpu_kernel.h old mode 100755 new mode 100644 diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/quant/fake_quant_perlayer_grad_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/quant/fake_quant_perlayer_grad_gpu_kernel.cc index f583dec1c9e..4d7db778806 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/quant/fake_quant_perlayer_grad_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/quant/fake_quant_perlayer_grad_gpu_kernel.cc @@ -1,124 +1,124 @@ -/** - * Copyright 2020-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/quant/fake_quant_perlayer_grad_gpu_kernel.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_quant_perlayer_impl.cuh" -#include "plugin/device/gpu/kernel/quant/quant_op_const.h" - -namespace mindspore { -namespace kernel { -FakeQuantPerLayerGradGpuKernelMod::FakeQuantPerLayerGradGpuKernelMod() - : input_size_(0), - workspace_size_(0), - num_bits_(0), - quant_min_(0), - quant_max_(0), - quant_num_(1), - quant_delay_(0), - global_step_(0), - narrow_range_(false), - is_null_input_(false), - symmetric_(false) {} - -bool FakeQuantPerLayerGradGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - num_bits_ = static_cast(GetValue(primitive_->GetAttr("num_bits"))); - if (num_bits_ <= kMinQuantBit || num_bits_ >= kMaxQuantBit) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the value of num_bits should be in (2, 16), but got " - << num_bits_; - } - - quant_delay_ = static_cast(GetValue(primitive_->GetAttr("quant_delay"))); - if (quant_delay_ < 0) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the value of quant_delay_ cannot be less than 0, but got " - << quant_delay_; - } - - symmetric_ = GetValue(primitive_->GetAttr("symmetric")); - narrow_range_ = GetValue(primitive_->GetAttr("narrow_range")); - - // quant min and max value - quant_min_ = 0; - quant_max_ = (1 << num_bits_) - 1; - if (narrow_range_) { - quant_min_++; - } - return true; -} - -void FakeQuantPerLayerGradGpuKernelMod::SetSizeLists() { - output_size_list_.push_back(input_size_); // output - workspace_size_list_.push_back(sizeof(float)); // scale - workspace_size_list_.push_back(sizeof(float)); // nudge_min - workspace_size_list_.push_back(sizeof(float)); // nudge_max -} - -int FakeQuantPerLayerGradGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - output_size_list_.clear(); - workspace_size_list_.clear(); - // init size - auto input_shape = Convert2SizeT(inputs[kIndex0]->GetShapeVector()); - is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name_, "input"); - if (is_null_input_) { - SetSizeLists(); - return KRET_UNKNOWN_SHAPE; - } - for (size_t i = 0; i < input_shape.size(); ++i) { - quant_num_ *= SizeToInt(input_shape[i]); - } - input_size_ = sizeof(float); - for (size_t i = 0; i < input_shape.size(); i++) { - input_size_ *= input_shape[i]; - } - SetSizeLists(); - return KRET_OK; -} - -bool FakeQuantPerLayerGradGpuKernelMod::Launch(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - if (is_null_input_) { - return true; - } - float *output = GetDeviceAddress(outputs, kIndex0); - float *gradient = GetDeviceAddress(inputs, kIndex0); - float *input = GetDeviceAddress(inputs, kIndex1); - float *input_min = GetDeviceAddress(inputs, kIndex2); - float *input_max = GetDeviceAddress(inputs, kIndex3); - float *scale = GetDeviceAddress(workspace, kIndex0); - float *nudge_min = GetDeviceAddress(workspace, kIndex1); - float *nudge_max = GetDeviceAddress(workspace, kIndex2); - - if (global_step_ >= quant_delay_) { - auto status = CalNudgePerLayer(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, - symmetric_, reinterpret_cast(stream_ptr)); - CHECK_CUDA_STATUS(status, kernel_name_); - status = CalFakeQuantPerLayerGrad(input, gradient, output, quant_num_, nudge_min, nudge_max, - reinterpret_cast(stream_ptr)); - CHECK_CUDA_STATUS(status, kernel_name_); - } else { - CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice, - reinterpret_cast(stream_ptr)), - "Copy gpu memory failed"); - } - global_step_++; - return true; -} - -MS_REG_GPU_KERNEL(FakeQuantPerLayerGrad, FakeQuantPerLayerGradGpuKernelMod) -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/quant/fake_quant_perlayer_grad_gpu_kernel.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/fake_quant_perlayer_impl.cuh" +#include "plugin/device/gpu/kernel/quant/quant_op_const.h" + +namespace mindspore { +namespace kernel { +FakeQuantPerLayerGradGpuKernelMod::FakeQuantPerLayerGradGpuKernelMod() + : input_size_(0), + workspace_size_(0), + num_bits_(0), + quant_min_(0), + quant_max_(0), + quant_num_(1), + quant_delay_(0), + global_step_(0), + narrow_range_(false), + is_null_input_(false), + symmetric_(false) {} + +bool FakeQuantPerLayerGradGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + num_bits_ = static_cast(GetValue(primitive_->GetAttr("num_bits"))); + if (num_bits_ <= kMinQuantBit || num_bits_ >= kMaxQuantBit) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the value of num_bits should be in (2, 16), but got " + << num_bits_; + } + + quant_delay_ = static_cast(GetValue(primitive_->GetAttr("quant_delay"))); + if (quant_delay_ < 0) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the value of quant_delay_ cannot be less than 0, but got " + << quant_delay_; + } + + symmetric_ = GetValue(primitive_->GetAttr("symmetric")); + narrow_range_ = GetValue(primitive_->GetAttr("narrow_range")); + + // quant min and max value + quant_min_ = 0; + quant_max_ = (1 << num_bits_) - 1; + if (narrow_range_) { + quant_min_++; + } + return true; +} + +void FakeQuantPerLayerGradGpuKernelMod::SetSizeLists() { + output_size_list_.push_back(input_size_); // output + workspace_size_list_.push_back(sizeof(float)); // scale + workspace_size_list_.push_back(sizeof(float)); // nudge_min + workspace_size_list_.push_back(sizeof(float)); // nudge_max +} + +int FakeQuantPerLayerGradGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + output_size_list_.clear(); + workspace_size_list_.clear(); + // init size + auto input_shape = Convert2SizeT(inputs[kIndex0]->GetShapeVector()); + is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name_, "input"); + if (is_null_input_) { + SetSizeLists(); + return KRET_UNKNOWN_SHAPE; + } + for (size_t i = 0; i < input_shape.size(); ++i) { + quant_num_ *= SizeToInt(input_shape[i]); + } + input_size_ = sizeof(float); + for (size_t i = 0; i < input_shape.size(); i++) { + input_size_ *= input_shape[i]; + } + SetSizeLists(); + return KRET_OK; +} + +bool FakeQuantPerLayerGradGpuKernelMod::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + if (is_null_input_) { + return true; + } + float *output = GetDeviceAddress(outputs, kIndex0); + float *gradient = GetDeviceAddress(inputs, kIndex0); + float *input = GetDeviceAddress(inputs, kIndex1); + float *input_min = GetDeviceAddress(inputs, kIndex2); + float *input_max = GetDeviceAddress(inputs, kIndex3); + float *scale = GetDeviceAddress(workspace, kIndex0); + float *nudge_min = GetDeviceAddress(workspace, kIndex1); + float *nudge_max = GetDeviceAddress(workspace, kIndex2); + + if (global_step_ >= quant_delay_) { + auto status = CalNudgePerLayer(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, + symmetric_, reinterpret_cast(stream_ptr)); + CHECK_CUDA_STATUS(status, kernel_name_); + status = CalFakeQuantPerLayerGrad(input, gradient, output, quant_num_, nudge_min, nudge_max, + reinterpret_cast(stream_ptr)); + CHECK_CUDA_STATUS(status, kernel_name_); + } else { + CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice, + reinterpret_cast(stream_ptr)), + "Copy gpu memory failed"); + } + global_step_++; + return true; +} + +MS_REG_GPU_KERNEL(FakeQuantPerLayerGrad, FakeQuantPerLayerGradGpuKernelMod) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/sequence/sequence_stack_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/sequence/sequence_stack_gpu_kernel.cc index 2b8ac75ec6f..4839320cc58 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/sequence/sequence_stack_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/sequence/sequence_stack_gpu_kernel.cc @@ -1,129 +1,129 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/sequence/sequence_stack_gpu_kernel.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include "plugin/device/gpu/kernel/gpu_kernel.h" -#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/pack.cuh" -#include "mindspore/core/ops/sequence_stack.h" -#include "mindspore/ccsrc/kernel/format_utils.h" - -namespace mindspore { -namespace kernel { -namespace { -constexpr int kInputsNum = 1; -constexpr int kOutputsNum = 1; -} // namespace - -bool SequenceStackGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_); - - return MatchKernelFunc(kernel_name_, inputs, outputs); -} - -int SequenceStackGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - int ret = KernelMod::Resize(inputs, outputs); - if (ret != 0) { - return ret; - } - workspace_size_list_.clear(); - tuple_shape_ = inputs[0]->GetShapeVector(); - if (tuple_shape_.empty()) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << " the input tuple size must greater 0"; - } - std::vector shape_vec_item; - std::copy(tuple_shape_.begin() + 1, tuple_shape_.end(), std::back_inserter(shape_vec_item)); - axis_ = GetValue(primitive_->GetAttr(ops::kAxis)); - if (axis_ < 0) { - axis_ += (SizeToInt(shape_vec_item.size()) + 1); - } - auto origin_data_format = kOpFormat_DEFAULT; - auto input_format = GetFormatFromEnumToStr(inputs[0]->format()); - axis_ = AxisTransform(origin_data_format, input_format, axis_); - input_num_ = tuple_shape_[0]; - inputs_host_.resize(input_num_); - dims_behind_axis_ = 1; - for (size_t i = IntToSize(axis_); i < shape_vec_item.size(); i++) { - dims_behind_axis_ *= static_cast(shape_vec_item[i]); - } - workspace_size_list_.push_back(sizeof(void *) * input_num_); - auto output_shape = outputs[0]->GetShapeVector(); - output_size_ = 1; - for (size_t i = 0; i < output_shape.size(); i++) { - output_size_ *= static_cast(output_shape[i]); - } - return KRET_OK; -} - -template -bool SequenceStackGpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs) { - const auto input_addr = GetDeviceAddress(inputs, 0); - T *output = GetDeviceAddress(outputs, 0); - T **inputs_array = GetDeviceAddress(workspace, 0); - size_t element_num = outputs[0]->size() / sizeof(T) / input_num_; - for (int i = 0; i < input_num_; i++) { - T *tmp_addr = input_addr + i * element_num; - inputs_host_[i] = tmp_addr; - } - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( - cudaMemcpyAsync(inputs_array, inputs_host_.data(), sizeof(T *) * input_num_, cudaMemcpyHostToDevice, - reinterpret_cast(stream_ptr_)), - "SequenceStack opt cudaMemcpyAsync inputs failed"); - auto status = PackKernel(output_size_, input_num_, dims_behind_axis_, inputs_array, output, - reinterpret_cast(stream_ptr_)); - CHECK_CUDA_STATUS(status, kernel_name_); - return true; -} - -#define SEQUENCE_STACK_KERNEL_REG(ms_type, builtin_type) \ - { \ - KernelAttr().AddAllSameAttr(true).AddInputAttr(kObjectTypeTuple, ms_type).AddOutputAttr(ms_type), \ - &SequenceStackGpuKernelMod::LaunchKernel \ - } - -const SequenceStackGpuKernelMod::FuncList &SequenceStackGpuKernelMod::GetFuncList() const { - static const FuncList func_list = {SEQUENCE_STACK_KERNEL_REG(kNumberTypeInt8, int8_t), - SEQUENCE_STACK_KERNEL_REG(kNumberTypeInt16, int16_t), - SEQUENCE_STACK_KERNEL_REG(kNumberTypeInt32, int32_t), - SEQUENCE_STACK_KERNEL_REG(kNumberTypeInt64, int64_t), - SEQUENCE_STACK_KERNEL_REG(kNumberTypeUInt8, uint8_t), - SEQUENCE_STACK_KERNEL_REG(kNumberTypeUInt16, uint16_t), - SEQUENCE_STACK_KERNEL_REG(kNumberTypeUInt32, uint32_t), - SEQUENCE_STACK_KERNEL_REG(kNumberTypeUInt64, uint64_t), - SEQUENCE_STACK_KERNEL_REG(kNumberTypeFloat16, half), - SEQUENCE_STACK_KERNEL_REG(kNumberTypeFloat32, float), - SEQUENCE_STACK_KERNEL_REG(kNumberTypeFloat64, double), - SEQUENCE_STACK_KERNEL_REG(kNumberTypeComplex64, Complex), - SEQUENCE_STACK_KERNEL_REG(kNumberTypeComplex128, Complex), - SEQUENCE_STACK_KERNEL_REG(kNumberTypeBool, bool)}; - return func_list; -} -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, SequenceStack, SequenceStackGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/sequence/sequence_stack_gpu_kernel.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/pack.cuh" +#include "mindspore/core/ops/sequence_stack.h" +#include "mindspore/ccsrc/kernel/format_utils.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr int kInputsNum = 1; +constexpr int kOutputsNum = 1; +} // namespace + +bool SequenceStackGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_); + + return MatchKernelFunc(kernel_name_, inputs, outputs); +} + +int SequenceStackGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + int ret = KernelMod::Resize(inputs, outputs); + if (ret != 0) { + return ret; + } + workspace_size_list_.clear(); + tuple_shape_ = inputs[0]->GetShapeVector(); + if (tuple_shape_.empty()) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << " the input tuple size must greater 0"; + } + std::vector shape_vec_item; + std::copy(tuple_shape_.begin() + 1, tuple_shape_.end(), std::back_inserter(shape_vec_item)); + axis_ = GetValue(primitive_->GetAttr(ops::kAxis)); + if (axis_ < 0) { + axis_ += (SizeToInt(shape_vec_item.size()) + 1); + } + auto origin_data_format = kOpFormat_DEFAULT; + auto input_format = GetFormatFromEnumToStr(inputs[0]->format()); + axis_ = AxisTransform(origin_data_format, input_format, axis_); + input_num_ = tuple_shape_[0]; + inputs_host_.resize(input_num_); + dims_behind_axis_ = 1; + for (size_t i = IntToSize(axis_); i < shape_vec_item.size(); i++) { + dims_behind_axis_ *= static_cast(shape_vec_item[i]); + } + workspace_size_list_.push_back(sizeof(void *) * input_num_); + auto output_shape = outputs[0]->GetShapeVector(); + output_size_ = 1; + for (size_t i = 0; i < output_shape.size(); i++) { + output_size_ *= static_cast(output_shape[i]); + } + return KRET_OK; +} + +template +bool SequenceStackGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + const auto input_addr = GetDeviceAddress(inputs, 0); + T *output = GetDeviceAddress(outputs, 0); + T **inputs_array = GetDeviceAddress(workspace, 0); + size_t element_num = outputs[0]->size() / sizeof(T) / input_num_; + for (int i = 0; i < input_num_; i++) { + T *tmp_addr = input_addr + i * element_num; + inputs_host_[i] = tmp_addr; + } + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(inputs_array, inputs_host_.data(), sizeof(T *) * input_num_, cudaMemcpyHostToDevice, + reinterpret_cast(stream_ptr_)), + "SequenceStack opt cudaMemcpyAsync inputs failed"); + auto status = PackKernel(output_size_, input_num_, dims_behind_axis_, inputs_array, output, + reinterpret_cast(stream_ptr_)); + CHECK_CUDA_STATUS(status, kernel_name_); + return true; +} + +#define SEQUENCE_STACK_KERNEL_REG(ms_type, builtin_type) \ + { \ + KernelAttr().AddAllSameAttr(true).AddInputAttr(kObjectTypeTuple, ms_type).AddOutputAttr(ms_type), \ + &SequenceStackGpuKernelMod::LaunchKernel \ + } + +const SequenceStackGpuKernelMod::FuncList &SequenceStackGpuKernelMod::GetFuncList() const { + static const FuncList func_list = {SEQUENCE_STACK_KERNEL_REG(kNumberTypeInt8, int8_t), + SEQUENCE_STACK_KERNEL_REG(kNumberTypeInt16, int16_t), + SEQUENCE_STACK_KERNEL_REG(kNumberTypeInt32, int32_t), + SEQUENCE_STACK_KERNEL_REG(kNumberTypeInt64, int64_t), + SEQUENCE_STACK_KERNEL_REG(kNumberTypeUInt8, uint8_t), + SEQUENCE_STACK_KERNEL_REG(kNumberTypeUInt16, uint16_t), + SEQUENCE_STACK_KERNEL_REG(kNumberTypeUInt32, uint32_t), + SEQUENCE_STACK_KERNEL_REG(kNumberTypeUInt64, uint64_t), + SEQUENCE_STACK_KERNEL_REG(kNumberTypeFloat16, half), + SEQUENCE_STACK_KERNEL_REG(kNumberTypeFloat32, float), + SEQUENCE_STACK_KERNEL_REG(kNumberTypeFloat64, double), + SEQUENCE_STACK_KERNEL_REG(kNumberTypeComplex64, Complex), + SEQUENCE_STACK_KERNEL_REG(kNumberTypeComplex128, Complex), + SEQUENCE_STACK_KERNEL_REG(kNumberTypeBool, bool)}; + return func_list; +} +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, SequenceStack, SequenceStackGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/sequence/sequence_stack_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/sequence/sequence_stack_gpu_kernel.h index 25c7d75ce0a..d3eeb01feea 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/sequence/sequence_stack_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/sequence/sequence_stack_gpu_kernel.h @@ -1,80 +1,80 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SEQUENCE_STACK_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SEQUENCE_STACK_GPU_KERNEL_H_ -#include -#include -#include -#include -#include -#include "plugin/device/gpu/kernel/gpu_kernel.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h" -#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/pack.cuh" -#include "mindspore/core/ops/stack.h" - -namespace mindspore { -namespace kernel { -template -using Complex = mindspore::utils::Complex; -class SequenceStackGpuKernelMod : public NativeGpuKernelMod, public MatchKernelHelper { - public: - SequenceStackGpuKernelMod() = default; - ~SequenceStackGpuKernelMod() override = default; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - MS_EXCEPTION_IF_NULL(kernel_func_); - stream_ptr_ = stream_ptr; - return kernel_func_(this, inputs, workspace, outputs); - } - - using FuncList = std::vector>; - const FuncList &GetFuncList() const override; - - void ResetResource() noexcept { - axis_ = 0; - input_num_ = 1; - output_size_ = 0; - dims_behind_axis_ = 1; - output_size_list_.clear(); - workspace_size_list_.clear(); - } - - private: - std::vector GetOpSupport() override { return OpSupport(); } - - template - bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs); - int axis_; - int input_num_{1}; - size_t output_size_; - size_t dims_behind_axis_; - std::vector tuple_shape_; - std::vector inputs_shape_; - std::vector inputs_host_; - std::string kernel_name_; - void *stream_ptr_{nullptr}; -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SEQUENCE_STACK_GPU_KERNEL_H_ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SEQUENCE_STACK_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SEQUENCE_STACK_GPU_KERNEL_H_ +#include +#include +#include +#include +#include +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/pack.cuh" +#include "mindspore/core/ops/stack.h" + +namespace mindspore { +namespace kernel { +template +using Complex = mindspore::utils::Complex; +class SequenceStackGpuKernelMod : public NativeGpuKernelMod, public MatchKernelHelper { + public: + SequenceStackGpuKernelMod() = default; + ~SequenceStackGpuKernelMod() override = default; + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + MS_EXCEPTION_IF_NULL(kernel_func_); + stream_ptr_ = stream_ptr; + return kernel_func_(this, inputs, workspace, outputs); + } + + using FuncList = std::vector>; + const FuncList &GetFuncList() const override; + + void ResetResource() noexcept { + axis_ = 0; + input_num_ = 1; + output_size_ = 0; + dims_behind_axis_ = 1; + output_size_list_.clear(); + workspace_size_list_.clear(); + } + + private: + std::vector GetOpSupport() override { return OpSupport(); } + + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs); + int axis_; + int input_num_{1}; + size_t output_size_; + size_t dims_behind_axis_; + std::vector tuple_shape_; + std::vector inputs_shape_; + std::vector inputs_host_; + std::string kernel_name_; + void *stream_ptr_{nullptr}; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SEQUENCE_STACK_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_matrix_nnz_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_matrix_nnz_gpu_kernel.cc index 41162994e0d..1d5649b5a17 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_matrix_nnz_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_matrix_nnz_gpu_kernel.cc @@ -1,283 +1,283 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/sparse/sparse_matrix_nnz_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -bool SparseMatrixNNZGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - if (inputs.empty() || outputs.empty()) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid."; - return false; - } - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', the kernel type is unsupported, got: " << kernel_attr << "."; - return false; - } - kernel_func_ = func_list_[index].second; - unit_indices_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).dtype); - unit_values_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex4).dtype); - return true; -} - -int SparseMatrixNNZGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - for (const auto &input : inputs) { - // If any input shape contains -1, means input shape is dynamic, so just return do nothing. - auto input_shape = input->GetShapeVector(); - if (!IsValidShape(input_shape)) { - return KRET_UNKNOWN_SHAPE; - } - } - for (const auto &output : outputs) { - // If any output shape contains -1, means input shape is dynamic, so just return do nothing. - auto output_shape = output->GetShapeVector(); - if (!IsValidShape(output_shape)) { - return KRET_UNKNOWN_SHAPE; - } - } - ResetResource(); - std::vector output_shape = outputs.at(kIndex0)->GetShapeVector(); - output_elements_ = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); - if (output_elements_ == 0) { - is_null_input_ = true; - } - - output_size_list_.push_back(output_elements_ * sizeof(int32_t)); - return KRET_OK; -} - -template -bool SparseMatrixNNZGpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs) { - T *batch_pointers = GetDeviceAddress(inputs, kIndex1); - int32_t *output = GetDeviceAddress(outputs, kIndex0); - - auto status = CalSparseMatrixNNZ(output_elements_, batch_pointers, output, device_id_, - reinterpret_cast(cuda_stream_)); - CHECK_CUDA_STATUS(status, kernel_name_); - return true; -} - -std::vector> - SparseMatrixNNZGpuKernelMod::func_list_ = { - {KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt8) - .AddOutputAttr(kNumberTypeInt32), - &SparseMatrixNNZGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeUInt8) - .AddOutputAttr(kNumberTypeInt32), - &SparseMatrixNNZGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt16) - .AddOutputAttr(kNumberTypeInt32), - &SparseMatrixNNZGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeUInt16) - .AddOutputAttr(kNumberTypeInt32), - &SparseMatrixNNZGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt32), - &SparseMatrixNNZGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt32), - &SparseMatrixNNZGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeInt32), - &SparseMatrixNNZGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeInt32), - &SparseMatrixNNZGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeFloat64) - .AddOutputAttr(kNumberTypeInt32), - &SparseMatrixNNZGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeBool) - .AddOutputAttr(kNumberTypeInt32), - &SparseMatrixNNZGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeComplex64) - .AddOutputAttr(kNumberTypeInt32), - &SparseMatrixNNZGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeComplex128) - .AddOutputAttr(kNumberTypeInt32), - &SparseMatrixNNZGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt8) - .AddOutputAttr(kNumberTypeInt32), - &SparseMatrixNNZGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeUInt8) - .AddOutputAttr(kNumberTypeInt32), - &SparseMatrixNNZGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt16) - .AddOutputAttr(kNumberTypeInt32), - &SparseMatrixNNZGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeUInt16) - .AddOutputAttr(kNumberTypeInt32), - &SparseMatrixNNZGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt32), - &SparseMatrixNNZGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt32), - &SparseMatrixNNZGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeInt32), - &SparseMatrixNNZGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeInt32), - &SparseMatrixNNZGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeFloat64) - .AddOutputAttr(kNumberTypeInt32), - &SparseMatrixNNZGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeBool) - .AddOutputAttr(kNumberTypeInt32), - &SparseMatrixNNZGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeComplex64) - .AddOutputAttr(kNumberTypeInt32), - &SparseMatrixNNZGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeComplex128) - .AddOutputAttr(kNumberTypeInt32), - &SparseMatrixNNZGpuKernelMod::LaunchKernel}, -}; - -std::vector SparseMatrixNNZGpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - return support_list; -} -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, SparseMatrixNNZ, SparseMatrixNNZGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/sparse/sparse_matrix_nnz_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +bool SparseMatrixNNZGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + if (inputs.empty() || outputs.empty()) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid."; + return false; + } + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', the kernel type is unsupported, got: " << kernel_attr << "."; + return false; + } + kernel_func_ = func_list_[index].second; + unit_indices_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).dtype); + unit_values_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex4).dtype); + return true; +} + +int SparseMatrixNNZGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + for (const auto &input : inputs) { + // If any input shape contains -1, means input shape is dynamic, so just return do nothing. + auto input_shape = input->GetShapeVector(); + if (!IsValidShape(input_shape)) { + return KRET_UNKNOWN_SHAPE; + } + } + for (const auto &output : outputs) { + // If any output shape contains -1, means input shape is dynamic, so just return do nothing. + auto output_shape = output->GetShapeVector(); + if (!IsValidShape(output_shape)) { + return KRET_UNKNOWN_SHAPE; + } + } + ResetResource(); + std::vector output_shape = outputs.at(kIndex0)->GetShapeVector(); + output_elements_ = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + if (output_elements_ == 0) { + is_null_input_ = true; + } + + output_size_list_.push_back(output_elements_ * sizeof(int32_t)); + return KRET_OK; +} + +template +bool SparseMatrixNNZGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + T *batch_pointers = GetDeviceAddress(inputs, kIndex1); + int32_t *output = GetDeviceAddress(outputs, kIndex0); + + auto status = CalSparseMatrixNNZ(output_elements_, batch_pointers, output, device_id_, + reinterpret_cast(cuda_stream_)); + CHECK_CUDA_STATUS(status, kernel_name_); + return true; +} + +std::vector> + SparseMatrixNNZGpuKernelMod::func_list_ = { + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt8) + .AddOutputAttr(kNumberTypeInt32), + &SparseMatrixNNZGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeUInt8) + .AddOutputAttr(kNumberTypeInt32), + &SparseMatrixNNZGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt16) + .AddOutputAttr(kNumberTypeInt32), + &SparseMatrixNNZGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeUInt16) + .AddOutputAttr(kNumberTypeInt32), + &SparseMatrixNNZGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + &SparseMatrixNNZGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt32), + &SparseMatrixNNZGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeInt32), + &SparseMatrixNNZGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeInt32), + &SparseMatrixNNZGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeInt32), + &SparseMatrixNNZGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeBool) + .AddOutputAttr(kNumberTypeInt32), + &SparseMatrixNNZGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeComplex64) + .AddOutputAttr(kNumberTypeInt32), + &SparseMatrixNNZGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeComplex128) + .AddOutputAttr(kNumberTypeInt32), + &SparseMatrixNNZGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt8) + .AddOutputAttr(kNumberTypeInt32), + &SparseMatrixNNZGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeUInt8) + .AddOutputAttr(kNumberTypeInt32), + &SparseMatrixNNZGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt16) + .AddOutputAttr(kNumberTypeInt32), + &SparseMatrixNNZGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeUInt16) + .AddOutputAttr(kNumberTypeInt32), + &SparseMatrixNNZGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + &SparseMatrixNNZGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt32), + &SparseMatrixNNZGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeInt32), + &SparseMatrixNNZGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeInt32), + &SparseMatrixNNZGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeInt32), + &SparseMatrixNNZGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeBool) + .AddOutputAttr(kNumberTypeInt32), + &SparseMatrixNNZGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeComplex64) + .AddOutputAttr(kNumberTypeInt32), + &SparseMatrixNNZGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeComplex128) + .AddOutputAttr(kNumberTypeInt32), + &SparseMatrixNNZGpuKernelMod::LaunchKernel}, +}; + +std::vector SparseMatrixNNZGpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, SparseMatrixNNZ, SparseMatrixNNZGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_matrix_nnz_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_matrix_nnz_gpu_kernel.h index 5cd86f42d51..9db0efc3dad 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_matrix_nnz_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_matrix_nnz_gpu_kernel.h @@ -1,85 +1,85 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_SPARSE_MATRIX_NNZ_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_SPARSE_MATRIX_NNZ_GPU_KERNEL_H_ - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "abstract/utils.h" -#include "mindspore/core/ops/sparse_matrix_nnz.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_matrix_nnz_impl.cuh" -#include "plugin/device/gpu/kernel/gpu_kernel.h" -#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore { -namespace kernel { -class SparseMatrixNNZGpuKernelMod : public NativeGpuKernelMod { - public: - SparseMatrixNNZGpuKernelMod() { ResetResource(); } - ~SparseMatrixNNZGpuKernelMod() override = default; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *cuda_stream) override { - if (is_null_input_) { - return true; - } - cuda_stream_ = cuda_stream; - return kernel_func_(this, inputs, workspace, outputs); - } - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - protected: - void ResetResource() noexcept { - is_null_input_ = false; - output_elements_ = 0; - output_size_list_.clear(); - } - - std::vector GetOpSupport() override; - - private: - template - bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs); - using SparseMatrixNNZFunc = - std::function &, - const std::vector &, const std::vector &)>; - - private: - size_t unit_indices_size_{1}; - size_t unit_values_size_{1}; - size_t output_elements_{0}; - SparseMatrixNNZFunc kernel_func_{}; - bool is_null_input_{false}; - void *cuda_stream_{nullptr}; - static std::vector> func_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_SPARSE_MATRIX_NNZ_GPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_SPARSE_MATRIX_NNZ_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_SPARSE_MATRIX_NNZ_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "abstract/utils.h" +#include "mindspore/core/ops/sparse_matrix_nnz.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_matrix_nnz_impl.cuh" +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class SparseMatrixNNZGpuKernelMod : public NativeGpuKernelMod { + public: + SparseMatrixNNZGpuKernelMod() { ResetResource(); } + ~SparseMatrixNNZGpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *cuda_stream) override { + if (is_null_input_) { + return true; + } + cuda_stream_ = cuda_stream; + return kernel_func_(this, inputs, workspace, outputs); + } + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + protected: + void ResetResource() noexcept { + is_null_input_ = false; + output_elements_ = 0; + output_size_list_.clear(); + } + + std::vector GetOpSupport() override; + + private: + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs); + using SparseMatrixNNZFunc = + std::function &, + const std::vector &, const std::vector &)>; + + private: + size_t unit_indices_size_{1}; + size_t unit_values_size_{1}; + size_t output_elements_{0}; + SparseMatrixNNZFunc kernel_func_{}; + bool is_null_input_{false}; + void *cuda_stream_{nullptr}; + static std::vector> func_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_SPARSE_MATRIX_NNZ_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_segment_ops_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_segment_ops_gpu_kernel.cc index dbcb536b9cf..c32f448f41b 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_segment_ops_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_segment_ops_gpu_kernel.cc @@ -1,506 +1,506 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/sparse/sparse_segment_ops_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -namespace { -constexpr auto Sparse_Segment_Sum = "SparseSegmentSum"; -constexpr auto Sparse_Segment_Sum_With_Num_Segments = "SparseSegmentSumWithNumSegments"; -constexpr auto Sparse_Segment_Sqrt_N = "SparseSegmentSqrtN"; -constexpr auto Sparse_Segment_Sqrt_N_With_Num_Segments = "SparseSegmentSqrtNWithNumSegments"; -constexpr size_t kNumber1 = 1; -constexpr size_t kNumber3 = 3; -constexpr size_t kNumber4 = 4; -} // namespace - -bool SparseSegmentOpsGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - if (kernel_type_ == "SparseSegmentSum" || kernel_type_ == "SparseSegmentSqrtN") { - flag_ = true; - } else { - flag_ = false; - } - size_t inputs_num = flag_ ? kNumber3 : kNumber4; - size_t outputs_num = kNumber1; - CHECK_KERNEL_INPUTS_NUM(inputs.size(), inputs_num, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), outputs_num, kernel_name_); - - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(ERROR) << kernel_name_ << " does not support this kernel data type: " << kernel_attr << "."; - return false; - } - kernel_func_ = kernel_attr_map_.at(kernel_type_)[index].second; - unit_x_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).dtype); - unit_idx_seg_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex1).dtype); - return true; -} - -int SparseSegmentOpsGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - for (const auto &input : inputs) { - // If any input shape contains -1, means input shape is dynamic, so just return do nothing. - auto input_shape = input->GetShapeVector(); - if (!IsValidShape(input_shape)) { - return KRET_UNKNOWN_SHAPE; - } - } - for (const auto &output : outputs) { - // If any input shape contains -1, means input shape is dynamic, so just return do nothing. - auto output_shape = output->GetShapeVector(); - if (!IsValidShape(output_shape)) { - return KRET_UNKNOWN_SHAPE; - } - } - ResetResource(); - std::vector output_shape = outputs.at(kIndex0)->GetShapeVector(); - output_elements_ = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); - if (output_elements_ == 0) { - is_null_input_ = true; - } - std::vector x_shape = inputs.at(kIndex0)->GetShapeVector(); - x_shape_0_ = x_shape[0]; - x_elements_ = std::accumulate(x_shape.begin(), x_shape.end(), 1, std::multiplies{}); - outer_size_ = x_shape.front(); - if (outer_size_ == 0) { - return KRET_RESIZE_FAILED; - } - inner_size_ = x_elements_ / x_shape.front(); - std::vector indices_shape = inputs.at(kIndex1)->GetShapeVector(); - idx_seg_elements_ = std::accumulate(indices_shape.begin(), indices_shape.end(), 1, std::multiplies{}); - output_dim0_ = LongToSize(output_shape.front()); - - size_t output_size = output_elements_ * unit_x_size_; - output_size_list_.push_back(output_size); - workspace_size_list_.push_back((output_dim0_ + 1) * sizeof(size_t)); - return KRET_OK; -} - -template -bool SparseSegmentOpsGpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs) { - R *x_ptr = GetDeviceAddress(inputs, kIndex0); - S *indices_ptr = GetDeviceAddress(inputs, kIndex1); - S *segment_ids_ptr = GetDeviceAddress(inputs, kIndex2); - R *y_ptr = GetDeviceAddress(outputs, kIndex0); - size_t *segment_pos_ptr = GetDeviceAddress(workspace, kIndex0); - auto any = [](auto... args) -> bool { return ((args == nullptr) || ...); }; - if (any(x_ptr, indices_ptr, segment_ids_ptr, segment_pos_ptr, y_ptr)) { - return false; - } - cudaStream_t stream = reinterpret_cast(cuda_stream_); - std::vector indices_host; - std::vector segment_ids_host; - std::vector num_segments_host; - indices_host.resize(idx_seg_elements_); - segment_ids_host.resize(idx_seg_elements_); - num_segments_host.resize(kNumber1); - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( - cudaMemcpyAsync(indices_host.data(), indices_ptr, idx_seg_elements_ * sizeof(S), cudaMemcpyDeviceToHost, stream), - "For '" << kernel_name_ << "', cudaMemcpy failed."); - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(segment_ids_host.data(), segment_ids_ptr, - idx_seg_elements_ * sizeof(S), cudaMemcpyDeviceToHost, stream), - "For '" << kernel_name_ << "', cudaMemcpy failed."); - if (!flag_) { - auto num_segments_ptr = GetDeviceAddress(inputs, kIndex3); - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( - cudaMemcpyAsync(num_segments_host.data(), num_segments_ptr, sizeof(S), cudaMemcpyDeviceToHost, stream), - "For '" << kernel_name_ << "', cudaMemcpy failed."); - } - if (cudaStreamQuery(stream) != cudaSuccess) { - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(stream), - "For '" << kernel_name_ << "', cuda Stream Sync Failed."); - } - if (segment_ids_host[0] != 0 && flag_) { - MS_EXCEPTION(ValueError) << "For '" << kernel_name_ - << "', indices in 'segment_ids' should be contiguous and start from 0."; - } - for (size_t i = 1; i < idx_seg_elements_; i++) { - if (segment_ids_host[i] < segment_ids_host[i - 1]) { - MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', segment_ids should be sorted."; - } - if (segment_ids_host[i] - segment_ids_host[i - 1] > 1 && flag_) { - MS_EXCEPTION(ValueError) << "For '" << kernel_name_ - << "', indices in 'segment_ids' should be contiguous and start from 0."; - } - } - if (segment_ids_host[idx_seg_elements_ - 1] >= num_segments_host[kIndex0] && !flag_) { - MS_EXCEPTION(ValueError) << "For '" << kernel_name_ - << "', num_segments must bigger than the last number of segment_ids."; - } - for (size_t i = 0; i < idx_seg_elements_; i++) { - if (indices_host[i] >= static_cast(x_shape_0_)) { - MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', indices out of range of x's first shape."; - } - } - auto status = - CalSparseSegmentCombination(kernel_type_, x_ptr, indices_ptr, segment_ids_ptr, segment_pos_ptr, outer_size_, - inner_size_, idx_seg_elements_, output_dim0_, y_ptr, device_id_, stream); - CHECK_CUDA_STATUS(status, kernel_name_); - return true; -} - -std::map>> - SparseSegmentOpsGpuKernelMod::kernel_attr_map_ = { - {Sparse_Segment_Sum, - {{KernelAttr() - .AddInputAttr(kNumberTypeUInt8) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeUInt8), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeUInt8) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeUInt8), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeUInt16) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeUInt16), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeUInt16) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeUInt16), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt8) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt8), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt8) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt8), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt16) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt16), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt16) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt16), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt32), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt32), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt64), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt64), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat16), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeFloat16), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat32), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeFloat32), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat64), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeFloat64), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}}}, - {Sparse_Segment_Sum_With_Num_Segments, - {{KernelAttr() - .AddInputAttr(kNumberTypeUInt8) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeUInt8), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeUInt8) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeUInt8), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeUInt16) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeUInt16), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeUInt16) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeUInt16), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt8) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt8), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt8) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt8), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt16) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt16), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt16) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt16), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt32), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt32), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt64), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt64), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat16), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeFloat16), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat32), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeFloat32), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat64), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeFloat64), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}}}, - {Sparse_Segment_Sqrt_N, - {{KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat16), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeFloat16), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat32), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeFloat32), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat64), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeFloat64), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}}}, - {Sparse_Segment_Sqrt_N_With_Num_Segments, - {{KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat16), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeFloat16), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat32), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeFloat32), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat64), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeFloat64), - &SparseSegmentOpsGpuKernelMod::LaunchKernel}}}}; // kernel_attr_map_ - -std::vector SparseSegmentOpsGpuKernelMod::GetOpSupport() { - auto iter = kernel_attr_map_.find(kernel_type_); - if (iter == kernel_attr_map_.end()) { - MS_EXCEPTION(ValueError) << "For 'SparseSegmentOpsOp', only support these types: " - << kernel::Map2Str>>( - kernel_attr_map_) - << " currently, but got " << kernel_name_; - } - std::vector support_list; - (void)std::transform( - iter->second.begin(), iter->second.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - return support_list; -} - -MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SparseSegmentSum, - []() { return std::make_shared(Sparse_Segment_Sum); }); -MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SparseSegmentSumWithNumSegments, []() { - return std::make_shared(Sparse_Segment_Sum_With_Num_Segments); -}); -MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SparseSegmentSqrtN, []() { - return std::make_shared(Sparse_Segment_Sqrt_N); -}); -MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SparseSegmentSqrtNWithNumSegments, []() { - return std::make_shared(Sparse_Segment_Sqrt_N_With_Num_Segments); -}); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/sparse/sparse_segment_ops_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr auto Sparse_Segment_Sum = "SparseSegmentSum"; +constexpr auto Sparse_Segment_Sum_With_Num_Segments = "SparseSegmentSumWithNumSegments"; +constexpr auto Sparse_Segment_Sqrt_N = "SparseSegmentSqrtN"; +constexpr auto Sparse_Segment_Sqrt_N_With_Num_Segments = "SparseSegmentSqrtNWithNumSegments"; +constexpr size_t kNumber1 = 1; +constexpr size_t kNumber3 = 3; +constexpr size_t kNumber4 = 4; +} // namespace + +bool SparseSegmentOpsGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + if (kernel_type_ == "SparseSegmentSum" || kernel_type_ == "SparseSegmentSqrtN") { + flag_ = true; + } else { + flag_ = false; + } + size_t inputs_num = flag_ ? kNumber3 : kNumber4; + size_t outputs_num = kNumber1; + CHECK_KERNEL_INPUTS_NUM(inputs.size(), inputs_num, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), outputs_num, kernel_name_); + + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << kernel_name_ << " does not support this kernel data type: " << kernel_attr << "."; + return false; + } + kernel_func_ = kernel_attr_map_.at(kernel_type_)[index].second; + unit_x_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).dtype); + unit_idx_seg_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex1).dtype); + return true; +} + +int SparseSegmentOpsGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + for (const auto &input : inputs) { + // If any input shape contains -1, means input shape is dynamic, so just return do nothing. + auto input_shape = input->GetShapeVector(); + if (!IsValidShape(input_shape)) { + return KRET_UNKNOWN_SHAPE; + } + } + for (const auto &output : outputs) { + // If any input shape contains -1, means input shape is dynamic, so just return do nothing. + auto output_shape = output->GetShapeVector(); + if (!IsValidShape(output_shape)) { + return KRET_UNKNOWN_SHAPE; + } + } + ResetResource(); + std::vector output_shape = outputs.at(kIndex0)->GetShapeVector(); + output_elements_ = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + if (output_elements_ == 0) { + is_null_input_ = true; + } + std::vector x_shape = inputs.at(kIndex0)->GetShapeVector(); + x_shape_0_ = x_shape[0]; + x_elements_ = std::accumulate(x_shape.begin(), x_shape.end(), 1, std::multiplies{}); + outer_size_ = x_shape.front(); + if (outer_size_ == 0) { + return KRET_RESIZE_FAILED; + } + inner_size_ = x_elements_ / x_shape.front(); + std::vector indices_shape = inputs.at(kIndex1)->GetShapeVector(); + idx_seg_elements_ = std::accumulate(indices_shape.begin(), indices_shape.end(), 1, std::multiplies{}); + output_dim0_ = LongToSize(output_shape.front()); + + size_t output_size = output_elements_ * unit_x_size_; + output_size_list_.push_back(output_size); + workspace_size_list_.push_back((output_dim0_ + 1) * sizeof(size_t)); + return KRET_OK; +} + +template +bool SparseSegmentOpsGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + R *x_ptr = GetDeviceAddress(inputs, kIndex0); + S *indices_ptr = GetDeviceAddress(inputs, kIndex1); + S *segment_ids_ptr = GetDeviceAddress(inputs, kIndex2); + R *y_ptr = GetDeviceAddress(outputs, kIndex0); + size_t *segment_pos_ptr = GetDeviceAddress(workspace, kIndex0); + auto any = [](auto... args) -> bool { return ((args == nullptr) || ...); }; + if (any(x_ptr, indices_ptr, segment_ids_ptr, segment_pos_ptr, y_ptr)) { + return false; + } + cudaStream_t stream = reinterpret_cast(cuda_stream_); + std::vector indices_host; + std::vector segment_ids_host; + std::vector num_segments_host; + indices_host.resize(idx_seg_elements_); + segment_ids_host.resize(idx_seg_elements_); + num_segments_host.resize(kNumber1); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(indices_host.data(), indices_ptr, idx_seg_elements_ * sizeof(S), cudaMemcpyDeviceToHost, stream), + "For '" << kernel_name_ << "', cudaMemcpy failed."); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(segment_ids_host.data(), segment_ids_ptr, + idx_seg_elements_ * sizeof(S), cudaMemcpyDeviceToHost, stream), + "For '" << kernel_name_ << "', cudaMemcpy failed."); + if (!flag_) { + auto num_segments_ptr = GetDeviceAddress(inputs, kIndex3); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(num_segments_host.data(), num_segments_ptr, sizeof(S), cudaMemcpyDeviceToHost, stream), + "For '" << kernel_name_ << "', cudaMemcpy failed."); + } + if (cudaStreamQuery(stream) != cudaSuccess) { + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(stream), + "For '" << kernel_name_ << "', cuda Stream Sync Failed."); + } + if (segment_ids_host[0] != 0 && flag_) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ + << "', indices in 'segment_ids' should be contiguous and start from 0."; + } + for (size_t i = 1; i < idx_seg_elements_; i++) { + if (segment_ids_host[i] < segment_ids_host[i - 1]) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', segment_ids should be sorted."; + } + if (segment_ids_host[i] - segment_ids_host[i - 1] > 1 && flag_) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ + << "', indices in 'segment_ids' should be contiguous and start from 0."; + } + } + if (segment_ids_host[idx_seg_elements_ - 1] >= num_segments_host[kIndex0] && !flag_) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ + << "', num_segments must bigger than the last number of segment_ids."; + } + for (size_t i = 0; i < idx_seg_elements_; i++) { + if (indices_host[i] >= static_cast(x_shape_0_)) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', indices out of range of x's first shape."; + } + } + auto status = + CalSparseSegmentCombination(kernel_type_, x_ptr, indices_ptr, segment_ids_ptr, segment_pos_ptr, outer_size_, + inner_size_, idx_seg_elements_, output_dim0_, y_ptr, device_id_, stream); + CHECK_CUDA_STATUS(status, kernel_name_); + return true; +} + +std::map>> + SparseSegmentOpsGpuKernelMod::kernel_attr_map_ = { + {Sparse_Segment_Sum, + {{KernelAttr() + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeUInt8), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeUInt8), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeUInt16), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeUInt16), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt8), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt8), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt16), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt16), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt32), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt64), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat16), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat16), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat32), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat64), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat64), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}}}, + {Sparse_Segment_Sum_With_Num_Segments, + {{KernelAttr() + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeUInt8), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeUInt8), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeUInt16), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeUInt16), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt8), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt8), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt16), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt16), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt32), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt64), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat16), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat16), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat32), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat64), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat64), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}}}, + {Sparse_Segment_Sqrt_N, + {{KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat16), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat16), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat32), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat64), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat64), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}}}, + {Sparse_Segment_Sqrt_N_With_Num_Segments, + {{KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat16), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat16), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat32), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat64), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat64), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}}}}; // kernel_attr_map_ + +std::vector SparseSegmentOpsGpuKernelMod::GetOpSupport() { + auto iter = kernel_attr_map_.find(kernel_type_); + if (iter == kernel_attr_map_.end()) { + MS_EXCEPTION(ValueError) << "For 'SparseSegmentOpsOp', only support these types: " + << kernel::Map2Str>>( + kernel_attr_map_) + << " currently, but got " << kernel_name_; + } + std::vector support_list; + (void)std::transform( + iter->second.begin(), iter->second.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} + +MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SparseSegmentSum, + []() { return std::make_shared(Sparse_Segment_Sum); }); +MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SparseSegmentSumWithNumSegments, []() { + return std::make_shared(Sparse_Segment_Sum_With_Num_Segments); +}); +MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SparseSegmentSqrtN, []() { + return std::make_shared(Sparse_Segment_Sqrt_N); +}); +MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SparseSegmentSqrtNWithNumSegments, []() { + return std::make_shared(Sparse_Segment_Sqrt_N_With_Num_Segments); +}); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_segment_ops_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_segment_ops_gpu_kernel.h index 19fe45f269a..25459529517 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_segment_ops_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_segment_ops_gpu_kernel.h @@ -1,98 +1,98 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_SPARSE_SEGMENT_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_SPARSE_SEGMENT_GPU_KERNEL_H_ - -#include -#include -#include -#include -#include -#include -#include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/cuda_class_common.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_impl.cuh" -#include "plugin/device/gpu/kernel/gpu_kernel.h" -#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class SparseSegmentOpsGpuKernelMod : public NativeGpuKernelMod { - public: - explicit SparseSegmentOpsGpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {} - ~SparseSegmentOpsGpuKernelMod() override = default; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *cuda_stream) override { - if (is_null_input_) { - return true; - } - cuda_stream_ = cuda_stream; - return kernel_func_(this, inputs, workspace, outputs); - } - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - protected: - void ResetResource() noexcept { - outer_size_ = 0; - inner_size_ = 0; - x_elements_ = 0; - x_shape_0_ = 0; - idx_seg_elements_ = 0; - output_dim0_ = 0; - output_elements_ = 0; - is_null_input_ = false; - output_size_list_.clear(); - workspace_size_list_.clear(); - } - - std::vector GetOpSupport() override; - - private: - template - bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs); - - using SSLaunchFunc = - std::function &, - const std::vector &, const std::vector &)>; - - private: - size_t outer_size_{0}; - size_t inner_size_{0}; - size_t x_elements_{0}; - size_t x_shape_0_{0}; - size_t idx_seg_elements_{0}; - size_t output_dim0_{0}; - size_t output_elements_{0}; - size_t unit_x_size_{1}; - size_t unit_idx_seg_size_{1}; - std::string kernel_type_{"Unknown"}; - bool is_null_input_{false}; - size_t flag_{0}; - void *cuda_stream_{nullptr}; - SSLaunchFunc kernel_func_{}; - static std::map>> kernel_attr_map_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_SPARSE_SEGMENT_GPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_SPARSE_SEGMENT_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_SPARSE_SEGMENT_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/cuda_class_common.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_impl.cuh" +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class SparseSegmentOpsGpuKernelMod : public NativeGpuKernelMod { + public: + explicit SparseSegmentOpsGpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {} + ~SparseSegmentOpsGpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *cuda_stream) override { + if (is_null_input_) { + return true; + } + cuda_stream_ = cuda_stream; + return kernel_func_(this, inputs, workspace, outputs); + } + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + protected: + void ResetResource() noexcept { + outer_size_ = 0; + inner_size_ = 0; + x_elements_ = 0; + x_shape_0_ = 0; + idx_seg_elements_ = 0; + output_dim0_ = 0; + output_elements_ = 0; + is_null_input_ = false; + output_size_list_.clear(); + workspace_size_list_.clear(); + } + + std::vector GetOpSupport() override; + + private: + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs); + + using SSLaunchFunc = + std::function &, + const std::vector &, const std::vector &)>; + + private: + size_t outer_size_{0}; + size_t inner_size_{0}; + size_t x_elements_{0}; + size_t x_shape_0_{0}; + size_t idx_seg_elements_{0}; + size_t output_dim0_{0}; + size_t output_elements_{0}; + size_t unit_x_size_{1}; + size_t unit_idx_seg_size_{1}; + std::string kernel_type_{"Unknown"}; + bool is_null_input_{false}; + size_t flag_{0}; + void *cuda_stream_{nullptr}; + SSLaunchFunc kernel_func_{}; + static std::map>> kernel_attr_map_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_SPARSE_SEGMENT_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_sparse_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_sparse_gpu_kernel.cc index 5e50e6afcf2..b99e40da8fe 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_sparse_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_sparse_gpu_kernel.cc @@ -1,289 +1,289 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/sparse/sparse_sparse_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -namespace { -constexpr auto Sparse_Sparse_Maximum = "SparseSparseMaximum"; -constexpr auto Sparse_Sparse_Minimum = "SparseSparseMinimum"; -constexpr int kSparseSparseInputsNum = 6; -constexpr int kSparseSparseOutputsNum = 2; -constexpr size_t kSparseSparseIndex0 = 0; -constexpr size_t kSparseSparseIndex1 = 1; -constexpr size_t kSparseSparseIndex2 = 2; -constexpr size_t kSparseSparseIndex3 = 3; -constexpr size_t kSparseSparseIndex4 = 4; -constexpr size_t kSparseSparseIndex5 = 5; -} // namespace - -bool SparseSparseGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSparseSparseInputsNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSparseSparseOutputsNum, kernel_name_); - if (inputs.empty() || outputs.empty()) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid."; - return false; - } - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr; - return false; - } - kernel_func_ = func_list_[index].second; - indices_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kSparseSparseIndex0).dtype); - values_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kSparseSparseIndex1).dtype); - return true; -} - -int SparseSparseGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - for (const auto &input : inputs) { - auto input_shape = input->GetShapeVector(); - if (!IsValidShape(input_shape)) { - return KRET_UNKNOWN_SHAPE; - } - } - ResetResource(); - auto a_indices_shape = inputs.at(kSparseSparseIndex0)->GetShapeVector(); - auto a_values_shape = inputs.at(kSparseSparseIndex1)->GetShapeVector(); - auto dense_shape = inputs.at(kSparseSparseIndex2)->GetShapeVector(); - auto b_indices_shape = inputs.at(kSparseSparseIndex3)->GetShapeVector(); - auto b_values_shape = inputs.at(kSparseSparseIndex4)->GetShapeVector(); - rank_ = a_indices_shape.at(1); - auto a_indices_size = std::accumulate(a_indices_shape.begin(), a_indices_shape.end(), 1, std::multiplies()); - auto a_values_size = std::accumulate(a_values_shape.begin(), a_values_shape.end(), 1, std::multiplies()); - auto dense_shape_size = std::accumulate(dense_shape.begin(), dense_shape.end(), 1, std::multiplies()); - auto b_indices_size = std::accumulate(b_indices_shape.begin(), b_indices_shape.end(), 1, std::multiplies()); - auto b_values_size = std::accumulate(b_values_shape.begin(), b_values_shape.end(), 1, std::multiplies()); - if (a_indices_size == 0 || a_values_size == 0 || dense_shape_size == 0 || b_indices_size == 0 || b_values_size == 0) { - is_null_input_ = true; - } - - auto a_row_num = a_values_shape[0]; - auto b_row_num = b_values_shape[0]; - a_indices_num_ = a_values_shape[0]; - b_indices_num_ = b_values_shape[0]; - auto a_values_num = a_values_shape[0]; - auto b_values_num = b_values_shape[0]; - size_t ab_status = (a_values_num * b_values_num) * sizeof(int64_t); - size_t sum = sizeof(int64_t); - size_t ab_status1 = (a_values_num + b_values_num) * sizeof(int64_t); - size_t ab_status2 = (a_values_num + b_values_num) * sizeof(int64_t); - - output_size_list_.push_back((a_row_num + b_row_num) * rank_ * indices_size_); - output_size_list_.push_back((a_row_num + b_row_num) * values_size_); - workspace_size_list_.push_back(ab_status); - workspace_size_list_.push_back(sum); - workspace_size_list_.push_back(ab_status1); - workspace_size_list_.push_back(ab_status2); - return KRET_OK; -} - -template -bool SparseSparseGpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - cuda_stream_ = reinterpret_cast(stream_ptr); - auto a_indices_ptr = GetDeviceAddress(inputs, kSparseSparseIndex0); - MS_EXCEPTION_IF_NULL(a_indices_ptr); - auto a_values_ptr = GetDeviceAddress(inputs, kSparseSparseIndex1); - MS_EXCEPTION_IF_NULL(a_values_ptr); - auto dense_shape_ptr1 = GetDeviceAddress(inputs, kSparseSparseIndex2); - MS_EXCEPTION_IF_NULL(dense_shape_ptr1); - auto b_indices_ptr = GetDeviceAddress(inputs, kSparseSparseIndex3); - MS_EXCEPTION_IF_NULL(b_indices_ptr); - auto b_values_ptr = GetDeviceAddress(inputs, kSparseSparseIndex4); - MS_EXCEPTION_IF_NULL(b_values_ptr); - auto dense_shape_ptr2 = GetDeviceAddress(inputs, kSparseSparseIndex5); - MS_EXCEPTION_IF_NULL(dense_shape_ptr2); - auto sum_indices_ptr = GetDeviceAddress(outputs, kSparseSparseIndex0); - MS_EXCEPTION_IF_NULL(sum_indices_ptr); - auto sum_values_ptr = GetDeviceAddress(outputs, kSparseSparseIndex1); - MS_EXCEPTION_IF_NULL(sum_values_ptr); - auto ab_status_ptr = GetDeviceAddress(workspace, kSparseSparseIndex0); - MS_EXCEPTION_IF_NULL(ab_status_ptr); - auto sum_ptr = GetDeviceAddress(workspace, kSparseSparseIndex1); - MS_EXCEPTION_IF_NULL(sum_ptr); - auto ab_status_ptr1 = GetDeviceAddress(workspace, kSparseSparseIndex2); - MS_EXCEPTION_IF_NULL(ab_status_ptr1); - auto ab_status_ptr2 = GetDeviceAddress(workspace, kSparseSparseIndex3); - MS_EXCEPTION_IF_NULL(ab_status_ptr2); - CHECK_CUDA_RET_WITH_ERROR_NOTRACE( - cudaMemsetAsync(sum_ptr, static_cast(1), workspace.at(kSparseSparseIndex1)->size(), cuda_stream_), - "For SparseSparseOperators, failed to cudaMemset."); - CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemsetAsync(ab_status_ptr1, static_cast(kSparseSparseIndex3), - workspace.at(kSparseSparseIndex2)->size(), cuda_stream_), - "For SparseSparseOperators, failed to cudaMemset."); - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(cuda_stream_), - "For SparseSparseOperators, cudaStreamSynchronize failed."); - std::vector x1_shape(rank_); - std::vector x2_shape(rank_); - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( - cudaMemcpyAsync(&x1_shape[0], dense_shape_ptr1, rank_ * sizeof(int64_t), cudaMemcpyDeviceToHost, cuda_stream_), - "For SparseSparseOperators, cudaMemcpyAsync failed."); - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( - cudaMemcpyAsync(&x2_shape[0], dense_shape_ptr2, rank_ * sizeof(int64_t), cudaMemcpyDeviceToHost, cuda_stream_), - "For SparseSparseOperators, cudaMemcpyAsync failed."); - if (cudaStreamQuery(cuda_stream_) != cudaSuccess) { - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(cuda_stream_), - "For 'SparseSparseOperators', cuda Stream Sync Failed."); - } - for (int64_t n = 0; n < rank_; n++) { - if (x1_shape[n] != x2_shape[n]) { - MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', operands' shapes do not match."; - } - } - cudaError_t status = cudaErrorNotReady; - if (kernel_name_ == "SparseSparseMaximum") { - status = SparseSparseMaximum(a_indices_ptr, a_values_ptr, b_indices_ptr, b_values_ptr, sum_indices_ptr, - sum_values_ptr, ab_status_ptr, sum_ptr, a_indices_num_, b_indices_num_, rank_, - cuda_stream_, device_id_, ab_status_ptr1, ab_status_ptr2); - } else { - status = SparseSparseMinimum(a_indices_ptr, a_values_ptr, b_indices_ptr, b_values_ptr, sum_indices_ptr, - sum_values_ptr, ab_status_ptr, sum_ptr, a_indices_num_, b_indices_num_, rank_, - cuda_stream_, device_id_, ab_status_ptr1, ab_status_ptr2); - } - CHECK_CUDA_STATUS(status, kernel_name_); - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( - cudaMemcpyAsync(&real_output_size_, sum_ptr, sizeof(int64_t), cudaMemcpyDeviceToHost, cuda_stream_), - "For SparseSparseOperators, failed to cudaMemset."); - if (cudaStreamQuery(cuda_stream_) != cudaSuccess) { - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(cuda_stream_), - "For 'SparseSparseOperators', cuda Stream Sync Failed."); - } - return true; -} - -void SparseSparseGpuKernelMod::UpdateOutputShapeAndSize(const std::vector &inputs, - const std::vector &outputs) { - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(cuda_stream_), - "For SparseSparseOperators, cudaStreamSynchronize failed."); - std::vector sum_indices_shape = {real_output_size_, static_cast(rank_)}; - std::vector sum_values_shape = {real_output_size_}; - outputs[kSparseSparseIndex0]->SetShapeVector(sum_indices_shape); - outputs[kSparseSparseIndex1]->SetShapeVector(sum_values_shape); - outputs[kSparseSparseIndex0]->set_size(LongToSize(real_output_size_ * rank_) * - UnitSizeInBytes(outputs[kSparseSparseIndex0]->dtype_id())); - outputs[kSparseSparseIndex1]->set_size(LongToSize(real_output_size_) * - UnitSizeInBytes(outputs[kSparseSparseIndex1]->dtype_id())); -} - -std::vector> SparseSparseGpuKernelMod::func_list_ = { - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt8) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt8) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt8), - &SparseSparseGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt16) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt16) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt16), - &SparseSparseGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt32), - &SparseSparseGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt64), - &SparseSparseGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeUInt8) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeUInt8) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeUInt8), - &SparseSparseGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeUInt16) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeUInt16) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeUInt16), - &SparseSparseGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeFloat32), - &SparseSparseGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeFloat16), - &SparseSparseGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeFloat64), - &SparseSparseGpuKernelMod::LaunchKernel}, -}; - -std::vector SparseSparseGpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - return support_list; -} -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, SparseSparseMaximum, SparseSparseGpuKernelMod); -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, SparseSparseMinimum, SparseSparseGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/sparse/sparse_sparse_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr auto Sparse_Sparse_Maximum = "SparseSparseMaximum"; +constexpr auto Sparse_Sparse_Minimum = "SparseSparseMinimum"; +constexpr int kSparseSparseInputsNum = 6; +constexpr int kSparseSparseOutputsNum = 2; +constexpr size_t kSparseSparseIndex0 = 0; +constexpr size_t kSparseSparseIndex1 = 1; +constexpr size_t kSparseSparseIndex2 = 2; +constexpr size_t kSparseSparseIndex3 = 3; +constexpr size_t kSparseSparseIndex4 = 4; +constexpr size_t kSparseSparseIndex5 = 5; +} // namespace + +bool SparseSparseGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSparseSparseInputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSparseSparseOutputsNum, kernel_name_); + if (inputs.empty() || outputs.empty()) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid."; + return false; + } + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr; + return false; + } + kernel_func_ = func_list_[index].second; + indices_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kSparseSparseIndex0).dtype); + values_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kSparseSparseIndex1).dtype); + return true; +} + +int SparseSparseGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + for (const auto &input : inputs) { + auto input_shape = input->GetShapeVector(); + if (!IsValidShape(input_shape)) { + return KRET_UNKNOWN_SHAPE; + } + } + ResetResource(); + auto a_indices_shape = inputs.at(kSparseSparseIndex0)->GetShapeVector(); + auto a_values_shape = inputs.at(kSparseSparseIndex1)->GetShapeVector(); + auto dense_shape = inputs.at(kSparseSparseIndex2)->GetShapeVector(); + auto b_indices_shape = inputs.at(kSparseSparseIndex3)->GetShapeVector(); + auto b_values_shape = inputs.at(kSparseSparseIndex4)->GetShapeVector(); + rank_ = a_indices_shape.at(1); + auto a_indices_size = std::accumulate(a_indices_shape.begin(), a_indices_shape.end(), 1, std::multiplies()); + auto a_values_size = std::accumulate(a_values_shape.begin(), a_values_shape.end(), 1, std::multiplies()); + auto dense_shape_size = std::accumulate(dense_shape.begin(), dense_shape.end(), 1, std::multiplies()); + auto b_indices_size = std::accumulate(b_indices_shape.begin(), b_indices_shape.end(), 1, std::multiplies()); + auto b_values_size = std::accumulate(b_values_shape.begin(), b_values_shape.end(), 1, std::multiplies()); + if (a_indices_size == 0 || a_values_size == 0 || dense_shape_size == 0 || b_indices_size == 0 || b_values_size == 0) { + is_null_input_ = true; + } + + auto a_row_num = a_values_shape[0]; + auto b_row_num = b_values_shape[0]; + a_indices_num_ = a_values_shape[0]; + b_indices_num_ = b_values_shape[0]; + auto a_values_num = a_values_shape[0]; + auto b_values_num = b_values_shape[0]; + size_t ab_status = (a_values_num * b_values_num) * sizeof(int64_t); + size_t sum = sizeof(int64_t); + size_t ab_status1 = (a_values_num + b_values_num) * sizeof(int64_t); + size_t ab_status2 = (a_values_num + b_values_num) * sizeof(int64_t); + + output_size_list_.push_back((a_row_num + b_row_num) * rank_ * indices_size_); + output_size_list_.push_back((a_row_num + b_row_num) * values_size_); + workspace_size_list_.push_back(ab_status); + workspace_size_list_.push_back(sum); + workspace_size_list_.push_back(ab_status1); + workspace_size_list_.push_back(ab_status2); + return KRET_OK; +} + +template +bool SparseSparseGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + cuda_stream_ = reinterpret_cast(stream_ptr); + auto a_indices_ptr = GetDeviceAddress(inputs, kSparseSparseIndex0); + MS_EXCEPTION_IF_NULL(a_indices_ptr); + auto a_values_ptr = GetDeviceAddress(inputs, kSparseSparseIndex1); + MS_EXCEPTION_IF_NULL(a_values_ptr); + auto dense_shape_ptr1 = GetDeviceAddress(inputs, kSparseSparseIndex2); + MS_EXCEPTION_IF_NULL(dense_shape_ptr1); + auto b_indices_ptr = GetDeviceAddress(inputs, kSparseSparseIndex3); + MS_EXCEPTION_IF_NULL(b_indices_ptr); + auto b_values_ptr = GetDeviceAddress(inputs, kSparseSparseIndex4); + MS_EXCEPTION_IF_NULL(b_values_ptr); + auto dense_shape_ptr2 = GetDeviceAddress(inputs, kSparseSparseIndex5); + MS_EXCEPTION_IF_NULL(dense_shape_ptr2); + auto sum_indices_ptr = GetDeviceAddress(outputs, kSparseSparseIndex0); + MS_EXCEPTION_IF_NULL(sum_indices_ptr); + auto sum_values_ptr = GetDeviceAddress(outputs, kSparseSparseIndex1); + MS_EXCEPTION_IF_NULL(sum_values_ptr); + auto ab_status_ptr = GetDeviceAddress(workspace, kSparseSparseIndex0); + MS_EXCEPTION_IF_NULL(ab_status_ptr); + auto sum_ptr = GetDeviceAddress(workspace, kSparseSparseIndex1); + MS_EXCEPTION_IF_NULL(sum_ptr); + auto ab_status_ptr1 = GetDeviceAddress(workspace, kSparseSparseIndex2); + MS_EXCEPTION_IF_NULL(ab_status_ptr1); + auto ab_status_ptr2 = GetDeviceAddress(workspace, kSparseSparseIndex3); + MS_EXCEPTION_IF_NULL(ab_status_ptr2); + CHECK_CUDA_RET_WITH_ERROR_NOTRACE( + cudaMemsetAsync(sum_ptr, static_cast(1), workspace.at(kSparseSparseIndex1)->size(), cuda_stream_), + "For SparseSparseOperators, failed to cudaMemset."); + CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemsetAsync(ab_status_ptr1, static_cast(kSparseSparseIndex3), + workspace.at(kSparseSparseIndex2)->size(), cuda_stream_), + "For SparseSparseOperators, failed to cudaMemset."); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(cuda_stream_), + "For SparseSparseOperators, cudaStreamSynchronize failed."); + std::vector x1_shape(rank_); + std::vector x2_shape(rank_); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(&x1_shape[0], dense_shape_ptr1, rank_ * sizeof(int64_t), cudaMemcpyDeviceToHost, cuda_stream_), + "For SparseSparseOperators, cudaMemcpyAsync failed."); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(&x2_shape[0], dense_shape_ptr2, rank_ * sizeof(int64_t), cudaMemcpyDeviceToHost, cuda_stream_), + "For SparseSparseOperators, cudaMemcpyAsync failed."); + if (cudaStreamQuery(cuda_stream_) != cudaSuccess) { + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(cuda_stream_), + "For 'SparseSparseOperators', cuda Stream Sync Failed."); + } + for (int64_t n = 0; n < rank_; n++) { + if (x1_shape[n] != x2_shape[n]) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', operands' shapes do not match."; + } + } + cudaError_t status = cudaErrorNotReady; + if (kernel_name_ == "SparseSparseMaximum") { + status = SparseSparseMaximum(a_indices_ptr, a_values_ptr, b_indices_ptr, b_values_ptr, sum_indices_ptr, + sum_values_ptr, ab_status_ptr, sum_ptr, a_indices_num_, b_indices_num_, rank_, + cuda_stream_, device_id_, ab_status_ptr1, ab_status_ptr2); + } else { + status = SparseSparseMinimum(a_indices_ptr, a_values_ptr, b_indices_ptr, b_values_ptr, sum_indices_ptr, + sum_values_ptr, ab_status_ptr, sum_ptr, a_indices_num_, b_indices_num_, rank_, + cuda_stream_, device_id_, ab_status_ptr1, ab_status_ptr2); + } + CHECK_CUDA_STATUS(status, kernel_name_); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(&real_output_size_, sum_ptr, sizeof(int64_t), cudaMemcpyDeviceToHost, cuda_stream_), + "For SparseSparseOperators, failed to cudaMemset."); + if (cudaStreamQuery(cuda_stream_) != cudaSuccess) { + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(cuda_stream_), + "For 'SparseSparseOperators', cuda Stream Sync Failed."); + } + return true; +} + +void SparseSparseGpuKernelMod::UpdateOutputShapeAndSize(const std::vector &inputs, + const std::vector &outputs) { + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(cuda_stream_), + "For SparseSparseOperators, cudaStreamSynchronize failed."); + std::vector sum_indices_shape = {real_output_size_, static_cast(rank_)}; + std::vector sum_values_shape = {real_output_size_}; + outputs[kSparseSparseIndex0]->SetShapeVector(sum_indices_shape); + outputs[kSparseSparseIndex1]->SetShapeVector(sum_values_shape); + outputs[kSparseSparseIndex0]->set_size(LongToSize(real_output_size_ * rank_) * + UnitSizeInBytes(outputs[kSparseSparseIndex0]->dtype_id())); + outputs[kSparseSparseIndex1]->set_size(LongToSize(real_output_size_) * + UnitSizeInBytes(outputs[kSparseSparseIndex1]->dtype_id())); +} + +std::vector> SparseSparseGpuKernelMod::func_list_ = { + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt8), + &SparseSparseGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt16), + &SparseSparseGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt32), + &SparseSparseGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64), + &SparseSparseGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeUInt8), + &SparseSparseGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeUInt16), + &SparseSparseGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat32), + &SparseSparseGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat16), + &SparseSparseGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat64), + &SparseSparseGpuKernelMod::LaunchKernel}, +}; + +std::vector SparseSparseGpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, SparseSparseMaximum, SparseSparseGpuKernelMod); +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, SparseSparseMinimum, SparseSparseGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_sparse_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_sparse_gpu_kernel.h index 70d035fab64..4d70056f243 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_sparse_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_sparse_gpu_kernel.h @@ -1,91 +1,91 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_SPARSE_SPARSE_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_SPARSE_SPARSE_GPU_KERNEL_H_ -#include -#include -#include -#include -#include -#include -#include -#include "mindspore/core/ops/sparse_sparse_maximum.h" -#include "mindspore/core/ops/sparse_sparse_minimum.h" -#include "abstract/utils.h" -#include "plugin/factory/ms_factory.h" -#include "plugin/device/gpu/kernel/gpu_kernel.h" -#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_sparse_maximum_impl.cuh" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_sparse_minimum_impl.cuh" - -namespace mindspore { -namespace kernel { -class SparseSparseGpuKernelMod : public NativeGpuKernelMod { - public: - SparseSparseGpuKernelMod() {} - ~SparseSparseGpuKernelMod() override = default; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *cuda_stream) override { - return kernel_func_(this, inputs, workspace, outputs, cuda_stream); - } - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - protected: - void ResetResource() noexcept { - rank_ = 0; - a_indices_num_ = 0; - b_indices_num_ = 0; - real_output_size_ = 0; - is_null_input_ = false; - output_size_list_.clear(); - workspace_size_list_.clear(); - } - - protected: - bool IsNeedUpdateOutputShapeAndSize() override { return true; } - void UpdateOutputShapeAndSize(const std::vector &inputs, - const std::vector &outputs) override; - - std::vector GetOpSupport() override; - - private: - template - bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr); - using SparseSparseFunc = - std::function &, - const std::vector &, const std::vector &, void *)>; - - private: - static std::vector> func_list_; - SparseSparseFunc kernel_func_{}; - cudaStream_t cuda_stream_; - bool is_null_input_{false}; - size_t indices_size_ = 0; - size_t values_size_ = 0; - int64_t real_output_size_ = 0; - int64_t rank_ = 0; - int64_t a_indices_num_ = 0; - int64_t b_indices_num_ = 0; -}; -} // namespace kernel -} // namespace mindspore -#endif +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_SPARSE_SPARSE_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_SPARSE_SPARSE_GPU_KERNEL_H_ +#include +#include +#include +#include +#include +#include +#include +#include "mindspore/core/ops/sparse_sparse_maximum.h" +#include "mindspore/core/ops/sparse_sparse_minimum.h" +#include "abstract/utils.h" +#include "plugin/factory/ms_factory.h" +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_sparse_maximum_impl.cuh" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_sparse_minimum_impl.cuh" + +namespace mindspore { +namespace kernel { +class SparseSparseGpuKernelMod : public NativeGpuKernelMod { + public: + SparseSparseGpuKernelMod() {} + ~SparseSparseGpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *cuda_stream) override { + return kernel_func_(this, inputs, workspace, outputs, cuda_stream); + } + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + protected: + void ResetResource() noexcept { + rank_ = 0; + a_indices_num_ = 0; + b_indices_num_ = 0; + real_output_size_ = 0; + is_null_input_ = false; + output_size_list_.clear(); + workspace_size_list_.clear(); + } + + protected: + bool IsNeedUpdateOutputShapeAndSize() override { return true; } + void UpdateOutputShapeAndSize(const std::vector &inputs, + const std::vector &outputs) override; + + std::vector GetOpSupport() override; + + private: + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr); + using SparseSparseFunc = + std::function &, + const std::vector &, const std::vector &, void *)>; + + private: + static std::vector> func_list_; + SparseSparseFunc kernel_func_{}; + cudaStream_t cuda_stream_; + bool is_null_input_{false}; + size_t indices_size_ = 0; + size_t values_size_ = 0; + int64_t real_output_size_ = 0; + int64_t rank_ = 0; + int64_t a_indices_num_ = 0; + int64_t b_indices_num_ = 0; +}; +} // namespace kernel +} // namespace mindspore +#endif diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_split_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_split_gpu_kernel.cc index 5123f44a4c6..b5f1def3913 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_split_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_split_gpu_kernel.cc @@ -1,194 +1,194 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_split_gpu_kernel.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h" - -namespace mindspore { -namespace kernel { -constexpr size_t InputsNum = 4; -constexpr int64_t Kindex2 = 2; -constexpr int64_t Kindex3 = 3; -template -using Complex = mindspore::utils::Complex; -bool SparseSplitGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - num_split = static_cast(GetValue(primitive_->GetAttr("num_split"))); - - input_dtype_ = inputs[kIndex2]->dtype_id(); - size_t outputs_num = Kindex3 * num_split; - CHECK_KERNEL_INPUTS_NUM(inputs.size(), InputsNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), outputs_num, kernel_name_); - std::map kernel_list = { - {kNumberTypeUInt8, &SparseSplitGpuKernelMod::LaunchKernel}, - {kNumberTypeUInt16, &SparseSplitGpuKernelMod::LaunchKernel}, - {kNumberTypeInt64, &SparseSplitGpuKernelMod::LaunchKernel}, - {kNumberTypeInt32, &SparseSplitGpuKernelMod::LaunchKernel}, - {kNumberTypeInt16, &SparseSplitGpuKernelMod::LaunchKernel}, - {kNumberTypeInt8, &SparseSplitGpuKernelMod::LaunchKernel}, - {kNumberTypeFloat64, &SparseSplitGpuKernelMod::LaunchKernel}, - {kNumberTypeFloat32, &SparseSplitGpuKernelMod::LaunchKernel}, - {kNumberTypeFloat16, &SparseSplitGpuKernelMod::LaunchKernel}, - {kNumberTypeBool, &SparseSplitGpuKernelMod::LaunchKernel}, - }; - if (kernel_list.find(input_dtype_) == kernel_list.end()) { - MS_LOG(ERROR) << "SparseSplit does not support this data type."; - return false; - } - kernel_func_ = kernel_list[input_dtype_]; - return true; -} - -int SparseSplitGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { - return ret; - } - - auto input_indices_shape = inputs[kIndex1]->GetShapeVector(); - input_nnz_ = input_indices_shape[0]; - num_dim_ = input_indices_shape[1]; - - num_split = static_cast(GetValue(primitive_->GetAttr("num_split"))); - workspace_size_list_.push_back(num_split * sizeof(void *)); - workspace_size_list_.push_back(num_split * sizeof(void *)); - workspace_size_list_.push_back(num_split * sizeof(void *)); - workspace_size_list_.push_back(num_split * sizeof(int)); - workspace_size_list_.push_back((num_split + 1) * GetTypeByte(TypeIdToType(inputs[1]->dtype_id()))); - - return KRET_OK; -} - -template -bool SparseSplitGpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - auto cuda_stream = reinterpret_cast(stream_ptr); - MS_EXCEPTION_IF_NULL(cuda_stream); - auto split_dim_ptr = GetDeviceAddress(inputs, kIndex0); - auto indices_ptr = GetDeviceAddress(inputs, kIndex1); - auto values_ptr = GetDeviceAddress(inputs, kIndex2); - auto shape_ptr = GetDeviceAddress(inputs, kIndex3); - std::vector y_indices_vec; - std::vector y_values_ptr; - std::vector out_shape_ptr; - std::vector out_shape_value(num_split * Kindex2, 0); - for (size_t i = 0; i < num_split; i++) { - y_indices_vec.push_back(GetDeviceAddress(outputs, i)); - y_values_ptr.push_back(GetDeviceAddress(outputs, num_split + i)); - out_shape_ptr.push_back(GetDeviceAddress(outputs, num_split * Kindex2 + i)); - } - auto d_y_indices_vec = GetDeviceAddress(workspace, kIndex0); - auto d_y_values_ptr = GetDeviceAddress(workspace, kIndex1); - auto d_out_shape_ptr = GetDeviceAddress(workspace, kIndex2); - - std::vector h_shape(Kindex2); - - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( - cudaMemcpyAsync(h_shape.data(), shape_ptr, sizeof(IndexType) * h_shape.size(), cudaMemcpyDeviceToHost, cuda_stream), - "For SparseSplit, cudaMemcpyAsync shape failed."); - if (cudaStreamQuery(cuda_stream) != cudaSuccess) { - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(cuda_stream), - "For 'SparseSplit', cuda Stream Sync Failed."); - } - - h_block.resize(num_split + 1); - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( - cudaMemcpyAsync(&h_split_dim, split_dim_ptr, sizeof(IndexType), cudaMemcpyDeviceToHost, cuda_stream), - "For SparseSplit, cudaMemcpyAsync split_dim failed."); - if (cudaStreamQuery(cuda_stream) != cudaSuccess) { - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(cuda_stream), - "For 'SparseSplit', cuda Stream Sync Failed."); - } - h_block[0] = 0; - int base_range = h_shape[h_split_dim] / num_split; - size_t res = h_shape[h_split_dim] - base_range * num_split; - for (size_t i = 1; i < h_block.size(); i++) { - if (i > 1) { - h_block[i] = h_block[i - 1] + base_range; - } else { - h_block[i] = base_range; - } - if (i <= res) { - h_block[i] += 1; - } - } - - for (size_t i = 0; i < num_split; i++) { - if (i == 0) { - out_shape_value[i * Kindex2 + h_split_dim] = (IndexType)h_block[i + 1]; - } else { - out_shape_value[i * Kindex2 + h_split_dim] = (IndexType)(h_block[i + 1] - h_block[i]); - } - out_shape_value[i * Kindex2 + 1 - h_split_dim] = h_shape[1 - h_split_dim]; - } - - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( - cudaMemcpyAsync(d_y_indices_vec, y_indices_vec.data(), sizeof(IndexType *) * num_split, cudaMemcpyHostToDevice, - cuda_stream), - "For SparseSplit, cudaMemcpyAsync failed."); - - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( - cudaMemcpyAsync(d_y_values_ptr, y_values_ptr.data(), sizeof(DataType *) * num_split, cudaMemcpyHostToDevice, - cuda_stream), - "For SparseSplit, cudaMemcpyAsync failed."); - - for (size_t i = 0; i < num_split; i++) { - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( - cudaMemcpyAsync(out_shape_ptr[i], out_shape_value.data() + i * Kindex2, sizeof(IndexType) * Kindex2, - cudaMemcpyHostToDevice, cuda_stream), - "For SparseSplit out_shape_ptr, cudaMemcpyAsync failed."); - } - auto d_block_ptr = GetDeviceAddress(workspace, kIndex4); - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(d_block_ptr, h_block.data(), sizeof(IndexType) * h_block.size(), - cudaMemcpyHostToDevice, cuda_stream), - "For SparseSplit, cudaMemcpyAsync failed."); - - auto sum_count_ptr = GetDeviceAddress(workspace, kIndex3); - - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemsetAsync(sum_count_ptr, 0, workspace[kIndex3]->size(), cuda_stream), - "For SparseSplit, cudaMemsetAsync failed."); - - SparseSplit(split_dim_ptr, indices_ptr, values_ptr, shape_ptr, num_split, d_y_indices_vec, - d_y_values_ptr, d_out_shape_ptr, sum_count_ptr, input_nnz_, num_dim_, d_block_ptr, - cuda_stream); - h_blocks.resize(num_split); - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( - cudaMemcpyAsync(h_blocks.data(), sum_count_ptr, sizeof(int) * num_split, cudaMemcpyDeviceToHost, cuda_stream), - "For SparseSplit, cudaMemcpyAsync failed."); - - return true; -} - -void SparseSplitGpuKernelMod::UpdateOutputShapeAndSize(const std::vector &inputs, - const std::vector &outputs) { - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(cuda_stream), "SparseSplit cudaStreamSynchronized failed"); - for (size_t i = 0; i < num_split; i++) { - outputs[i]->SetShapeVector(ShapeVector({h_blocks[i], Kindex2})); // indices - outputs[i + num_split]->SetShapeVector(ShapeVector({h_blocks[i]})); // value - outputs[i + num_split * Kindex2]->SetShapeVector(ShapeVector({static_cast(num_dim_)})); // shape - outputs[i]->set_size(LongToSize(h_blocks[i] * Kindex2) * UnitSizeInBytes(outputs[i]->dtype_id())); - outputs[i + num_split]->set_size(LongToSize(h_blocks[i]) * UnitSizeInBytes(outputs[i + num_split]->dtype_id())); - outputs[i + num_split * Kindex2]->set_size(num_dim_ * - UnitSizeInBytes(outputs[i + num_split * Kindex2]->dtype_id())); - } -} - -std::vector SparseSplitGpuKernelMod::GetOpSupport() { return {KernelAttr().AddSkipCheckAttr(true)}; } - -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, SparseSplit, SparseSplitGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_split_gpu_kernel.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h" + +namespace mindspore { +namespace kernel { +constexpr size_t InputsNum = 4; +constexpr int64_t Kindex2 = 2; +constexpr int64_t Kindex3 = 3; +template +using Complex = mindspore::utils::Complex; +bool SparseSplitGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + num_split = static_cast(GetValue(primitive_->GetAttr("num_split"))); + + input_dtype_ = inputs[kIndex2]->dtype_id(); + size_t outputs_num = Kindex3 * num_split; + CHECK_KERNEL_INPUTS_NUM(inputs.size(), InputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), outputs_num, kernel_name_); + std::map kernel_list = { + {kNumberTypeUInt8, &SparseSplitGpuKernelMod::LaunchKernel}, + {kNumberTypeUInt16, &SparseSplitGpuKernelMod::LaunchKernel}, + {kNumberTypeInt64, &SparseSplitGpuKernelMod::LaunchKernel}, + {kNumberTypeInt32, &SparseSplitGpuKernelMod::LaunchKernel}, + {kNumberTypeInt16, &SparseSplitGpuKernelMod::LaunchKernel}, + {kNumberTypeInt8, &SparseSplitGpuKernelMod::LaunchKernel}, + {kNumberTypeFloat64, &SparseSplitGpuKernelMod::LaunchKernel}, + {kNumberTypeFloat32, &SparseSplitGpuKernelMod::LaunchKernel}, + {kNumberTypeFloat16, &SparseSplitGpuKernelMod::LaunchKernel}, + {kNumberTypeBool, &SparseSplitGpuKernelMod::LaunchKernel}, + }; + if (kernel_list.find(input_dtype_) == kernel_list.end()) { + MS_LOG(ERROR) << "SparseSplit does not support this data type."; + return false; + } + kernel_func_ = kernel_list[input_dtype_]; + return true; +} + +int SparseSplitGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) { + return ret; + } + + auto input_indices_shape = inputs[kIndex1]->GetShapeVector(); + input_nnz_ = input_indices_shape[0]; + num_dim_ = input_indices_shape[1]; + + num_split = static_cast(GetValue(primitive_->GetAttr("num_split"))); + workspace_size_list_.push_back(num_split * sizeof(void *)); + workspace_size_list_.push_back(num_split * sizeof(void *)); + workspace_size_list_.push_back(num_split * sizeof(void *)); + workspace_size_list_.push_back(num_split * sizeof(int)); + workspace_size_list_.push_back((num_split + 1) * GetTypeByte(TypeIdToType(inputs[1]->dtype_id()))); + + return KRET_OK; +} + +template +bool SparseSplitGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + auto cuda_stream = reinterpret_cast(stream_ptr); + MS_EXCEPTION_IF_NULL(cuda_stream); + auto split_dim_ptr = GetDeviceAddress(inputs, kIndex0); + auto indices_ptr = GetDeviceAddress(inputs, kIndex1); + auto values_ptr = GetDeviceAddress(inputs, kIndex2); + auto shape_ptr = GetDeviceAddress(inputs, kIndex3); + std::vector y_indices_vec; + std::vector y_values_ptr; + std::vector out_shape_ptr; + std::vector out_shape_value(num_split * Kindex2, 0); + for (size_t i = 0; i < num_split; i++) { + y_indices_vec.push_back(GetDeviceAddress(outputs, i)); + y_values_ptr.push_back(GetDeviceAddress(outputs, num_split + i)); + out_shape_ptr.push_back(GetDeviceAddress(outputs, num_split * Kindex2 + i)); + } + auto d_y_indices_vec = GetDeviceAddress(workspace, kIndex0); + auto d_y_values_ptr = GetDeviceAddress(workspace, kIndex1); + auto d_out_shape_ptr = GetDeviceAddress(workspace, kIndex2); + + std::vector h_shape(Kindex2); + + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(h_shape.data(), shape_ptr, sizeof(IndexType) * h_shape.size(), cudaMemcpyDeviceToHost, cuda_stream), + "For SparseSplit, cudaMemcpyAsync shape failed."); + if (cudaStreamQuery(cuda_stream) != cudaSuccess) { + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(cuda_stream), + "For 'SparseSplit', cuda Stream Sync Failed."); + } + + h_block.resize(num_split + 1); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(&h_split_dim, split_dim_ptr, sizeof(IndexType), cudaMemcpyDeviceToHost, cuda_stream), + "For SparseSplit, cudaMemcpyAsync split_dim failed."); + if (cudaStreamQuery(cuda_stream) != cudaSuccess) { + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(cuda_stream), + "For 'SparseSplit', cuda Stream Sync Failed."); + } + h_block[0] = 0; + int base_range = h_shape[h_split_dim] / num_split; + size_t res = h_shape[h_split_dim] - base_range * num_split; + for (size_t i = 1; i < h_block.size(); i++) { + if (i > 1) { + h_block[i] = h_block[i - 1] + base_range; + } else { + h_block[i] = base_range; + } + if (i <= res) { + h_block[i] += 1; + } + } + + for (size_t i = 0; i < num_split; i++) { + if (i == 0) { + out_shape_value[i * Kindex2 + h_split_dim] = (IndexType)h_block[i + 1]; + } else { + out_shape_value[i * Kindex2 + h_split_dim] = (IndexType)(h_block[i + 1] - h_block[i]); + } + out_shape_value[i * Kindex2 + 1 - h_split_dim] = h_shape[1 - h_split_dim]; + } + + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(d_y_indices_vec, y_indices_vec.data(), sizeof(IndexType *) * num_split, cudaMemcpyHostToDevice, + cuda_stream), + "For SparseSplit, cudaMemcpyAsync failed."); + + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(d_y_values_ptr, y_values_ptr.data(), sizeof(DataType *) * num_split, cudaMemcpyHostToDevice, + cuda_stream), + "For SparseSplit, cudaMemcpyAsync failed."); + + for (size_t i = 0; i < num_split; i++) { + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(out_shape_ptr[i], out_shape_value.data() + i * Kindex2, sizeof(IndexType) * Kindex2, + cudaMemcpyHostToDevice, cuda_stream), + "For SparseSplit out_shape_ptr, cudaMemcpyAsync failed."); + } + auto d_block_ptr = GetDeviceAddress(workspace, kIndex4); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(d_block_ptr, h_block.data(), sizeof(IndexType) * h_block.size(), + cudaMemcpyHostToDevice, cuda_stream), + "For SparseSplit, cudaMemcpyAsync failed."); + + auto sum_count_ptr = GetDeviceAddress(workspace, kIndex3); + + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemsetAsync(sum_count_ptr, 0, workspace[kIndex3]->size(), cuda_stream), + "For SparseSplit, cudaMemsetAsync failed."); + + SparseSplit(split_dim_ptr, indices_ptr, values_ptr, shape_ptr, num_split, d_y_indices_vec, + d_y_values_ptr, d_out_shape_ptr, sum_count_ptr, input_nnz_, num_dim_, d_block_ptr, + cuda_stream); + h_blocks.resize(num_split); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(h_blocks.data(), sum_count_ptr, sizeof(int) * num_split, cudaMemcpyDeviceToHost, cuda_stream), + "For SparseSplit, cudaMemcpyAsync failed."); + + return true; +} + +void SparseSplitGpuKernelMod::UpdateOutputShapeAndSize(const std::vector &inputs, + const std::vector &outputs) { + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(cuda_stream), "SparseSplit cudaStreamSynchronized failed"); + for (size_t i = 0; i < num_split; i++) { + outputs[i]->SetShapeVector(ShapeVector({h_blocks[i], Kindex2})); // indices + outputs[i + num_split]->SetShapeVector(ShapeVector({h_blocks[i]})); // value + outputs[i + num_split * Kindex2]->SetShapeVector(ShapeVector({static_cast(num_dim_)})); // shape + outputs[i]->set_size(LongToSize(h_blocks[i] * Kindex2) * UnitSizeInBytes(outputs[i]->dtype_id())); + outputs[i + num_split]->set_size(LongToSize(h_blocks[i]) * UnitSizeInBytes(outputs[i + num_split]->dtype_id())); + outputs[i + num_split * Kindex2]->set_size(num_dim_ * + UnitSizeInBytes(outputs[i + num_split * Kindex2]->dtype_id())); + } +} + +std::vector SparseSplitGpuKernelMod::GetOpSupport() { return {KernelAttr().AddSkipCheckAttr(true)}; } + +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, SparseSplit, SparseSplitGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_tensor_to_csr_sparse_matrix_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_tensor_to_csr_sparse_matrix_gpu_kernel.h index bc00cb0622d..759dfb2cebf 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_tensor_to_csr_sparse_matrix_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_tensor_to_csr_sparse_matrix_gpu_kernel.h @@ -1,73 +1,73 @@ -/** - * Copyright 2020-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPARSE_SPARSE_TENSOR_TO_CSR_SPARSE_MATRIX_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPARSE_SPARSE_TENSOR_TO_CSR_SPARSE_MATRIX_GPU_KERNEL_H_ - -#include -#include -#include -#include -#include "plugin/device/gpu/kernel/gpu_kernel.h" -#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" -#include "plugin/device/gpu/kernel/kernel_constants.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_tensor_to_csr_sparse_matrix_impl.cuh" - -namespace mindspore { -namespace kernel { -class SparseTensorToCSRSparseMatrixGpuKernelMod : public NativeGpuKernelMod { - public: - SparseTensorToCSRSparseMatrixGpuKernelMod() = default; - ~SparseTensorToCSRSparseMatrixGpuKernelMod() override = default; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *cuda_stream) override { - stream_ = reinterpret_cast(cuda_stream); - return kernel_func_(this, inputs, workspace, outputs); - } - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - protected: - std::vector GetOpSupport() override; - - private: - template - bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs); - using SparseTensorToCSRSparseMatrixFunc = - std::function &, - const std::vector &, const std::vector &)>; - - static std::vector> func_list_; - - private: - size_t unit_size_{1}; - size_t input_elements_{}; - int elements[3] = {0, 0, 0}; - cudaStream_t stream_; - cusparseHandle_t handle_{nullptr}; - int row_num; - int batch_size; - int temp_nnz; - int bapt; - SparseTensorToCSRSparseMatrixFunc kernel_func_{}; -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPARSE_SPARSE_TENSOR_TO_CSR_SPARSE_MATRIX_GPU_KERNEL_H_ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPARSE_SPARSE_TENSOR_TO_CSR_SPARSE_MATRIX_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPARSE_SPARSE_TENSOR_TO_CSR_SPARSE_MATRIX_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" +#include "plugin/device/gpu/kernel/kernel_constants.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_tensor_to_csr_sparse_matrix_impl.cuh" + +namespace mindspore { +namespace kernel { +class SparseTensorToCSRSparseMatrixGpuKernelMod : public NativeGpuKernelMod { + public: + SparseTensorToCSRSparseMatrixGpuKernelMod() = default; + ~SparseTensorToCSRSparseMatrixGpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *cuda_stream) override { + stream_ = reinterpret_cast(cuda_stream); + return kernel_func_(this, inputs, workspace, outputs); + } + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + protected: + std::vector GetOpSupport() override; + + private: + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs); + using SparseTensorToCSRSparseMatrixFunc = + std::function &, + const std::vector &, const std::vector &)>; + + static std::vector> func_list_; + + private: + size_t unit_size_{1}; + size_t input_elements_{}; + int elements[3] = {0, 0, 0}; + cudaStream_t stream_; + cusparseHandle_t handle_{nullptr}; + int row_num; + int batch_size; + int temp_nnz; + int bapt; + SparseTensorToCSRSparseMatrixFunc kernel_func_{}; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPARSE_SPARSE_TENSOR_TO_CSR_SPARSE_MATRIX_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sspaddmm_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sspaddmm_gpu_kernel.cc index 2799ca5b031..ce75c9152d4 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sspaddmm_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sspaddmm_gpu_kernel.cc @@ -1,234 +1,234 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/sparse/sspaddmm_gpu_kernel.h" -#include - -namespace mindspore { -namespace kernel { -constexpr int64_t kNumTwo = 2; -constexpr int INPUT_NUM = 9; -constexpr int OUTPUT_NUM = 3; - -bool SspaddmmGpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { - if (inputs.empty() || outputs.empty()) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid."; - return false; - } - if (inputs.size() != INPUT_NUM || outputs.size() != OUTPUT_NUM) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', input and output must be " << INPUT_NUM << " and " << OUTPUT_NUM - << ", but got " << inputs.size() << " and " << outputs.size(); - } - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr; - return false; - } - kernel_func_ = func_list_[index].second; - unit_indices_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).dtype); - unit_values_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex1).dtype); - return true; -} - -int SspaddmmGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - for (const auto &input : inputs) { - // If any input shape contains -1, means input shape is dynamic, so just return do nothing. - auto input_shape = input->GetShapeVector(); - if (!IsValidShape(input_shape)) { - return KRET_UNKNOWN_SHAPE; - } - } - ResetResource(); - std::vector x1_indices_shape = std::vector(inputs.at(kIndex0)->GetDeviceShapeVector().begin(), - inputs.at(kIndex0)->GetDeviceShapeVector().end()); - std::vector x2_indices_shape = std::vector(inputs.at(kIndex3)->GetDeviceShapeVector().begin(), - inputs.at(kIndex3)->GetDeviceShapeVector().end()); - std::vector x3_dense_shape = std::vector(inputs.at(kIndex6)->GetDeviceShapeVector().begin(), - inputs.at(kIndex6)->GetDeviceShapeVector().end()); - std::vector y_indices_shape = std::vector(outputs.at(kIndex0)->GetDeviceShapeVector().begin(), - outputs.at(kIndex0)->GetDeviceShapeVector().end()); - x1_values_num_ = x1_indices_shape[1]; - x2_values_num_ = x2_indices_shape[1]; - y_values_num_ = y_indices_shape[1]; - x3_dense_col_ = x3_dense_shape[1]; - if (y_values_num_ == 0) { - is_null_input_ = true; - } - - workspace_size_list_.emplace_back(x2_values_num_ * sizeof(int64_t)); // index - output_size_list_.emplace_back(y_values_num_ * sizeof(int64_t) * kNumTwo); // y_indices - output_size_list_.emplace_back(y_values_num_ * unit_values_size_); // y_values - output_size_list_.emplace_back(kNumTwo * sizeof(int64_t)); // y_shape - - return KRET_OK; -} - -template -bool SspaddmmGpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs) { - S *x1_indices = GetDeviceAddress(inputs, 0); - T *x1_values = GetDeviceAddress(inputs, 1); - S *x1_shape = GetDeviceAddress(inputs, 2); - S *x2_indices = GetDeviceAddress(inputs, 3); - T *x2_values = GetDeviceAddress(inputs, 4); - T *x3_dense = GetDeviceAddress(inputs, 6); - T *alpha = GetDeviceAddress(inputs, 7); - T *beta = GetDeviceAddress(inputs, 8); - - int64_t *index = GetDeviceAddress(workspace, 0); - - int64_t *y_indices = GetDeviceAddress(outputs, 0); - T *y_values = GetDeviceAddress(outputs, 1); - int64_t *y_shape = GetDeviceAddress(outputs, 2); - - const int64_t kSize = x2_values_num_; - const int64_t kSizeX2 = kNumTwo * kSize; - std::vector x2(kSizeX2); - S x1_devicetohost_shape[kNumTwo]; - int64_t x1_host_shape[kNumTwo]; - cudaStream_t stream = reinterpret_cast(cuda_stream_); - - CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemsetAsync(y_values, 0, y_values_num_ * unit_values_size_, stream), - "For SspaddmmGpuKernelMod, failed to cudaMemset for y_values."); - CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemsetAsync(index, 0, x2_values_num_ * unit_values_size_, stream), - "For SspaddmmGpuKernelMod, failed to cudaMemset for index."); - CHECK_CUDA_RET_WITH_ERROR_NOTRACE( - cudaMemcpyAsync(&x1_devicetohost_shape, x1_shape, sizeof(S) * kNumTwo, cudaMemcpyDeviceToHost, stream), - "For SspaddmmGpuKernelMod cudaMemcpyAsync x1_shape Fail"); - CHECK_CUDA_RET_WITH_ERROR_NOTRACE( - cudaMemcpyAsync(x2.data(), x2_indices, sizeof(S) * x2_values_num_ * kNumTwo, cudaMemcpyDeviceToHost, stream), - "For SspaddmmGpuKernelMod cudaMemcpyAsync x2_values Fail"); - if (cudaStreamQuery(stream) != cudaSuccess) { - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(stream), - "For 'SspaddmmGpuKernelMod', cuda Stream Sync Failed."); - } - - // cal y_shape - x1_host_shape[0] = static_cast(x1_devicetohost_shape[0]); - x1_host_shape[1] = static_cast(x1_devicetohost_shape[1]); - // cal index for y_values and y_indices - std::vector idx(kSize); - int64_t count = 0; - idx[0] = count; - for (int64_t i = 1; i < x2_values_num_; ++i) { - for (int64_t j = 0; j < i; ++j) { - if (x2[i] == x2[j]) { - idx[i] = idx[j]; - break; - } else if (i == j + 1) { - idx[i] = ++count; - break; - } - } - } - - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( - cudaMemcpyAsync(index, idx.data(), sizeof(int64_t) * x2_values_num_, cudaMemcpyHostToDevice, stream), - "For SspaddmmGpuKernelMod cudaMemcpyAsync index failed."); - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( - cudaMemcpyAsync(y_shape, x1_host_shape, kNumTwo * sizeof(int64_t), cudaMemcpyHostToDevice, stream), - "For SspaddmmGpuKernelMod cudaMemcpyAsync x1_shape failed."); - - // x1 + x2 @ x3_dense - auto status = CalSparseAddSparse(x1_indices, x1_values, x1_values_num_, y_indices, y_values, y_values_num_, beta, - device_id_, stream); - CHECK_CUDA_STATUS(status, kernel_name_); - // the result of x2 @ x3_dense will write to output directly - status = CalSparseMulDense(x2_indices, x2_values, x2_values_num_, x3_dense, y_indices, y_values, y_values_num_, - x3_dense_col_, x1_values_num_, alpha, index, device_id_, stream); - CHECK_CUDA_STATUS(status, kernel_name_); - return true; -} - -#define GPU_SSPADDMM_KERNEL_REGISTER_ONE(ms_value_type, ms_index_type) \ - { \ - KernelAttr() \ - .AddInputAttr(ms_index_type) \ - .AddInputAttr(ms_value_type) \ - .AddInputAttr(ms_index_type) \ - .AddInputAttr(ms_index_type) \ - .AddInputAttr(ms_value_type) \ - .AddInputAttr(ms_index_type) - -#define GPU_SSPADDMM_KERNEL_REGISTER_TWO(ms_value_type, value_type, index_type) \ - .AddInputAttr(ms_value_type) \ - .AddInputAttr(ms_value_type) \ - .AddInputAttr(ms_value_type) \ - .AddOutputAttr(kNumberTypeInt64) \ - .AddOutputAttr(ms_value_type) \ - .AddOutputAttr(kNumberTypeInt64), \ - &SspaddmmGpuKernelMod::LaunchKernel \ - } - -std::vector> SspaddmmGpuKernelMod::func_list_ = { - GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeInt8, kNumberTypeInt32) - GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeInt8, int8_t, int), - GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeInt16, kNumberTypeInt32) - GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeInt16, int16_t, int), - GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeInt32, kNumberTypeInt32) - GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeInt32, int32_t, int), - GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeInt64, kNumberTypeInt32) - GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeInt64, int64_t, int), - GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeUInt8, kNumberTypeInt32) - GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeUInt8, uint8_t, int), - GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeUInt16, kNumberTypeInt32) - GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeUInt16, uint16_t, int), - GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeUInt32, kNumberTypeInt32) - GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeUInt32, uint32_t, int), - GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeUInt64, kNumberTypeInt32) - GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeUInt64, uint64_t, int), - GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeFloat16, kNumberTypeInt32) - GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeFloat16, half, int), - GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeFloat32, kNumberTypeInt32) - GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeFloat32, float, int), - GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeFloat64, kNumberTypeInt32) - GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeFloat64, double, int), - GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeInt8, kNumberTypeInt64) - GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeInt8, int8_t, int64_t), - GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeInt16, kNumberTypeInt64) - GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeInt16, int16_t, int64_t), - GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeInt32, kNumberTypeInt64) - GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeInt32, int32_t, int64_t), - GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeInt64, kNumberTypeInt64) - GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeInt64, int64_t, int64_t), - GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeUInt8, kNumberTypeInt64) - GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeUInt8, uint8_t, int64_t), - GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeFloat32, kNumberTypeInt64) - GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeFloat32, float, int64_t), - GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeFloat64, kNumberTypeInt64) - GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeFloat64, double, int64_t), - GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeUInt16, kNumberTypeInt64) - GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeUInt16, uint16_t, int64_t), - GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeUInt32, kNumberTypeInt64) - GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeUInt32, uint32_t, int64_t), - GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeUInt64, kNumberTypeInt64) - GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeUInt64, uint64_t, int64_t), - GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeFloat16, kNumberTypeInt64) - GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeFloat16, half, int64_t)}; - -std::vector SspaddmmGpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - return support_list; -} - -MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Sspaddmm, SspaddmmGpuKernelMod); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/sparse/sspaddmm_gpu_kernel.h" +#include + +namespace mindspore { +namespace kernel { +constexpr int64_t kNumTwo = 2; +constexpr int INPUT_NUM = 9; +constexpr int OUTPUT_NUM = 3; + +bool SspaddmmGpuKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { + if (inputs.empty() || outputs.empty()) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid."; + return false; + } + if (inputs.size() != INPUT_NUM || outputs.size() != OUTPUT_NUM) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', input and output must be " << INPUT_NUM << " and " << OUTPUT_NUM + << ", but got " << inputs.size() << " and " << outputs.size(); + } + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr; + return false; + } + kernel_func_ = func_list_[index].second; + unit_indices_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).dtype); + unit_values_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex1).dtype); + return true; +} + +int SspaddmmGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + for (const auto &input : inputs) { + // If any input shape contains -1, means input shape is dynamic, so just return do nothing. + auto input_shape = input->GetShapeVector(); + if (!IsValidShape(input_shape)) { + return KRET_UNKNOWN_SHAPE; + } + } + ResetResource(); + std::vector x1_indices_shape = std::vector(inputs.at(kIndex0)->GetDeviceShapeVector().begin(), + inputs.at(kIndex0)->GetDeviceShapeVector().end()); + std::vector x2_indices_shape = std::vector(inputs.at(kIndex3)->GetDeviceShapeVector().begin(), + inputs.at(kIndex3)->GetDeviceShapeVector().end()); + std::vector x3_dense_shape = std::vector(inputs.at(kIndex6)->GetDeviceShapeVector().begin(), + inputs.at(kIndex6)->GetDeviceShapeVector().end()); + std::vector y_indices_shape = std::vector(outputs.at(kIndex0)->GetDeviceShapeVector().begin(), + outputs.at(kIndex0)->GetDeviceShapeVector().end()); + x1_values_num_ = x1_indices_shape[1]; + x2_values_num_ = x2_indices_shape[1]; + y_values_num_ = y_indices_shape[1]; + x3_dense_col_ = x3_dense_shape[1]; + if (y_values_num_ == 0) { + is_null_input_ = true; + } + + workspace_size_list_.emplace_back(x2_values_num_ * sizeof(int64_t)); // index + output_size_list_.emplace_back(y_values_num_ * sizeof(int64_t) * kNumTwo); // y_indices + output_size_list_.emplace_back(y_values_num_ * unit_values_size_); // y_values + output_size_list_.emplace_back(kNumTwo * sizeof(int64_t)); // y_shape + + return KRET_OK; +} + +template +bool SspaddmmGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + S *x1_indices = GetDeviceAddress(inputs, 0); + T *x1_values = GetDeviceAddress(inputs, 1); + S *x1_shape = GetDeviceAddress(inputs, 2); + S *x2_indices = GetDeviceAddress(inputs, 3); + T *x2_values = GetDeviceAddress(inputs, 4); + T *x3_dense = GetDeviceAddress(inputs, 6); + T *alpha = GetDeviceAddress(inputs, 7); + T *beta = GetDeviceAddress(inputs, 8); + + int64_t *index = GetDeviceAddress(workspace, 0); + + int64_t *y_indices = GetDeviceAddress(outputs, 0); + T *y_values = GetDeviceAddress(outputs, 1); + int64_t *y_shape = GetDeviceAddress(outputs, 2); + + const int64_t kSize = x2_values_num_; + const int64_t kSizeX2 = kNumTwo * kSize; + std::vector x2(kSizeX2); + S x1_devicetohost_shape[kNumTwo]; + int64_t x1_host_shape[kNumTwo]; + cudaStream_t stream = reinterpret_cast(cuda_stream_); + + CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemsetAsync(y_values, 0, y_values_num_ * unit_values_size_, stream), + "For SspaddmmGpuKernelMod, failed to cudaMemset for y_values."); + CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemsetAsync(index, 0, x2_values_num_ * unit_values_size_, stream), + "For SspaddmmGpuKernelMod, failed to cudaMemset for index."); + CHECK_CUDA_RET_WITH_ERROR_NOTRACE( + cudaMemcpyAsync(&x1_devicetohost_shape, x1_shape, sizeof(S) * kNumTwo, cudaMemcpyDeviceToHost, stream), + "For SspaddmmGpuKernelMod cudaMemcpyAsync x1_shape Fail"); + CHECK_CUDA_RET_WITH_ERROR_NOTRACE( + cudaMemcpyAsync(x2.data(), x2_indices, sizeof(S) * x2_values_num_ * kNumTwo, cudaMemcpyDeviceToHost, stream), + "For SspaddmmGpuKernelMod cudaMemcpyAsync x2_values Fail"); + if (cudaStreamQuery(stream) != cudaSuccess) { + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(stream), + "For 'SspaddmmGpuKernelMod', cuda Stream Sync Failed."); + } + + // cal y_shape + x1_host_shape[0] = static_cast(x1_devicetohost_shape[0]); + x1_host_shape[1] = static_cast(x1_devicetohost_shape[1]); + // cal index for y_values and y_indices + std::vector idx(kSize); + int64_t count = 0; + idx[0] = count; + for (int64_t i = 1; i < x2_values_num_; ++i) { + for (int64_t j = 0; j < i; ++j) { + if (x2[i] == x2[j]) { + idx[i] = idx[j]; + break; + } else if (i == j + 1) { + idx[i] = ++count; + break; + } + } + } + + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(index, idx.data(), sizeof(int64_t) * x2_values_num_, cudaMemcpyHostToDevice, stream), + "For SspaddmmGpuKernelMod cudaMemcpyAsync index failed."); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(y_shape, x1_host_shape, kNumTwo * sizeof(int64_t), cudaMemcpyHostToDevice, stream), + "For SspaddmmGpuKernelMod cudaMemcpyAsync x1_shape failed."); + + // x1 + x2 @ x3_dense + auto status = CalSparseAddSparse(x1_indices, x1_values, x1_values_num_, y_indices, y_values, y_values_num_, beta, + device_id_, stream); + CHECK_CUDA_STATUS(status, kernel_name_); + // the result of x2 @ x3_dense will write to output directly + status = CalSparseMulDense(x2_indices, x2_values, x2_values_num_, x3_dense, y_indices, y_values, y_values_num_, + x3_dense_col_, x1_values_num_, alpha, index, device_id_, stream); + CHECK_CUDA_STATUS(status, kernel_name_); + return true; +} + +#define GPU_SSPADDMM_KERNEL_REGISTER_ONE(ms_value_type, ms_index_type) \ + { \ + KernelAttr() \ + .AddInputAttr(ms_index_type) \ + .AddInputAttr(ms_value_type) \ + .AddInputAttr(ms_index_type) \ + .AddInputAttr(ms_index_type) \ + .AddInputAttr(ms_value_type) \ + .AddInputAttr(ms_index_type) + +#define GPU_SSPADDMM_KERNEL_REGISTER_TWO(ms_value_type, value_type, index_type) \ + .AddInputAttr(ms_value_type) \ + .AddInputAttr(ms_value_type) \ + .AddInputAttr(ms_value_type) \ + .AddOutputAttr(kNumberTypeInt64) \ + .AddOutputAttr(ms_value_type) \ + .AddOutputAttr(kNumberTypeInt64), \ + &SspaddmmGpuKernelMod::LaunchKernel \ + } + +std::vector> SspaddmmGpuKernelMod::func_list_ = { + GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeInt8, kNumberTypeInt32) + GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeInt8, int8_t, int), + GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeInt16, kNumberTypeInt32) + GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeInt16, int16_t, int), + GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeInt32, kNumberTypeInt32) + GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeInt32, int32_t, int), + GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeInt64, kNumberTypeInt32) + GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeInt64, int64_t, int), + GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeUInt8, kNumberTypeInt32) + GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeUInt8, uint8_t, int), + GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeUInt16, kNumberTypeInt32) + GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeUInt16, uint16_t, int), + GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeUInt32, kNumberTypeInt32) + GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeUInt32, uint32_t, int), + GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeUInt64, kNumberTypeInt32) + GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeUInt64, uint64_t, int), + GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeFloat16, kNumberTypeInt32) + GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeFloat16, half, int), + GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeFloat32, kNumberTypeInt32) + GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeFloat32, float, int), + GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeFloat64, kNumberTypeInt32) + GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeFloat64, double, int), + GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeInt8, kNumberTypeInt64) + GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeInt8, int8_t, int64_t), + GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeInt16, kNumberTypeInt64) + GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeInt16, int16_t, int64_t), + GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeInt32, kNumberTypeInt64) + GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeInt32, int32_t, int64_t), + GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeInt64, kNumberTypeInt64) + GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeInt64, int64_t, int64_t), + GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeUInt8, kNumberTypeInt64) + GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeUInt8, uint8_t, int64_t), + GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeFloat32, kNumberTypeInt64) + GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeFloat32, float, int64_t), + GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeFloat64, kNumberTypeInt64) + GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeFloat64, double, int64_t), + GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeUInt16, kNumberTypeInt64) + GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeUInt16, uint16_t, int64_t), + GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeUInt32, kNumberTypeInt64) + GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeUInt32, uint32_t, int64_t), + GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeUInt64, kNumberTypeInt64) + GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeUInt64, uint64_t, int64_t), + GPU_SSPADDMM_KERNEL_REGISTER_ONE(kNumberTypeFloat16, kNumberTypeInt64) + GPU_SSPADDMM_KERNEL_REGISTER_TWO(kNumberTypeFloat16, half, int64_t)}; + +std::vector SspaddmmGpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Sspaddmm, SspaddmmGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sspaddmm_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sspaddmm_gpu_kernel.h index eaf0d4c3e0d..75da749aa93 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sspaddmm_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sspaddmm_gpu_kernel.h @@ -1,89 +1,89 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SSPADDMM_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SSPADDMM_GPU_KERNEL_H_ -#include -#include -#include -#include -#include -#include -#include -#include "mindspore/core/ops/sspaddmm.h" -#include "abstract/utils.h" -#include "plugin/factory/ms_factory.h" -#include "plugin/device/gpu/kernel/gpu_kernel.h" -#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sspaddmm_impl.cuh" - -namespace mindspore { -namespace kernel { -class SspaddmmGpuKernelMod : public NativeGpuKernelMod { - public: - SspaddmmGpuKernelMod() { ResetResource(); } - ~SspaddmmGpuKernelMod() override = default; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *cuda_stream) override { - if (is_null_input_) { - return true; - } - cuda_stream_ = cuda_stream; - return kernel_func_(this, inputs, workspace, outputs); - } - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - std::vector GetOpSupport() override; - - protected: - void ResetResource() noexcept { - x1_values_num_ = 0; - x2_values_num_ = 0; - y_values_num_ = 0; - x3_dense_col_ = 0; - is_null_input_ = false; - workspace_size_list_.clear(); - output_size_list_.clear(); - } - - private: - template - bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs); - using SspaddmmFunc = - std::function &, - const std::vector &, const std::vector &)>; - - private: - int64_t unit_indices_size_{1}; - int64_t unit_values_size_{1}; - int64_t x1_values_num_{}; - int64_t x2_values_num_{}; - int64_t y_values_num_{}; - int64_t x3_dense_col_{}; - SspaddmmFunc kernel_func_{}; - bool is_null_input_{false}; - void *cuda_stream_{nullptr}; - static std::vector> func_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SSPADDMM_GPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SSPADDMM_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SSPADDMM_GPU_KERNEL_H_ +#include +#include +#include +#include +#include +#include +#include +#include "mindspore/core/ops/sspaddmm.h" +#include "abstract/utils.h" +#include "plugin/factory/ms_factory.h" +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sspaddmm_impl.cuh" + +namespace mindspore { +namespace kernel { +class SspaddmmGpuKernelMod : public NativeGpuKernelMod { + public: + SspaddmmGpuKernelMod() { ResetResource(); } + ~SspaddmmGpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *cuda_stream) override { + if (is_null_input_) { + return true; + } + cuda_stream_ = cuda_stream; + return kernel_func_(this, inputs, workspace, outputs); + } + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + std::vector GetOpSupport() override; + + protected: + void ResetResource() noexcept { + x1_values_num_ = 0; + x2_values_num_ = 0; + y_values_num_ = 0; + x3_dense_col_ = 0; + is_null_input_ = false; + workspace_size_list_.clear(); + output_size_list_.clear(); + } + + private: + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs); + using SspaddmmFunc = + std::function &, + const std::vector &, const std::vector &)>; + + private: + int64_t unit_indices_size_{1}; + int64_t unit_values_size_{1}; + int64_t x1_values_num_{}; + int64_t x2_values_num_{}; + int64_t y_values_num_{}; + int64_t x3_dense_col_{}; + SspaddmmFunc kernel_func_{}; + bool is_null_input_{false}; + void *cuda_stream_{nullptr}; + static std::vector> func_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SSPADDMM_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/sparse_grad/sparse_segment_grad_ops_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse_grad/sparse_segment_grad_ops_gpu_kernel.cc index 0e9c228237c..2b356bcdc4b 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/sparse_grad/sparse_segment_grad_ops_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse_grad/sparse_segment_grad_ops_gpu_kernel.cc @@ -1,244 +1,244 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/sparse_grad/sparse_segment_grad_ops_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -namespace { -constexpr auto Sparse_Segment_Sum_Grad = "SparseSegmentSumGrad"; -constexpr auto Sparse_Segment_Sqrt_N_Grad = "SparseSegmentSqrtNGrad"; -constexpr size_t kNumber1 = 1; -constexpr size_t kNumber4 = 4; -} // namespace - -bool SparseSegmentGradOpsGpuKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - size_t inputs_num = kNumber4; - size_t outputs_num = kNumber1; - CHECK_KERNEL_INPUTS_NUM(inputs.size(), inputs_num, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), outputs_num, kernel_name_); - - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(ERROR) << kernel_name_ << " does not support this kernel data type: " << kernel_attr << "."; - return false; - } - kernel_func_ = kernel_attr_map_.at(kernel_type_)[index].second; - unit_grad_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).dtype); - unit_idx_seg_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex1).dtype); - return true; -} - -int SparseSegmentGradOpsGpuKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - for (const auto &input : inputs) { - // If any input shape contains -1, means input shape is dynamic, so just return do nothing. - auto input_shape = input->GetShapeVector(); - if (!IsValidShape(input_shape)) { - return KRET_UNKNOWN_SHAPE; - } - } - for (const auto &output : outputs) { - // If any input shape contains -1, means input shape is dynamic, so just return do nothing. - auto output_shape = output->GetShapeVector(); - if (!IsValidShape(output_shape)) { - return KRET_UNKNOWN_SHAPE; - } - } - ResetResource(); - std::vector output_shape = outputs.at(kIndex0)->GetShapeVector(); - output_elements_ = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); - if (output_elements_ == 0) { - is_null_input_ = true; - } - std::vector grad_shape = inputs.at(kIndex0)->GetShapeVector(); - grad_shape_0_ = grad_shape[0]; - grad_elements_ = std::accumulate(grad_shape.begin(), grad_shape.end(), 1, std::multiplies{}); - outer_size_ = grad_shape.front(); - inner_size_ = grad_elements_ / outer_size_; - std::vector indices_shape = inputs.at(kIndex1)->GetShapeVector(); - idx_seg_elements_ = std::accumulate(indices_shape.begin(), indices_shape.end(), 1, std::multiplies{}); - output_dim0_ = LongToSize(output_shape.front()); - - size_t output_size = output_elements_ * unit_grad_size_; - output_size_list_.push_back(output_size); - workspace_size_list_.push_back((outer_size_ + 1) * sizeof(size_t)); - return KRET_OK; -} - -template -bool SparseSegmentGradOpsGpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs) { - R *grad_ptr = GetDeviceAddress(inputs, kIndex0); - S *indices_ptr = GetDeviceAddress(inputs, kIndex1); - S *segment_ids_ptr = GetDeviceAddress(inputs, kIndex2); - R *y_ptr = GetDeviceAddress(outputs, kIndex0); - size_t *segment_pos_ptr = GetDeviceAddress(workspace, kIndex0); - auto any = [](auto... args) -> bool { return ((args == nullptr) || ...); }; - if (any(grad_ptr, indices_ptr, segment_ids_ptr, segment_pos_ptr, y_ptr)) { - return false; - } - cudaStream_t stream = reinterpret_cast(cuda_stream_); - std::vector indices_host; - std::vector segment_ids_host; - indices_host.resize(idx_seg_elements_); - segment_ids_host.resize(idx_seg_elements_); - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( - cudaMemcpyAsync(indices_host.data(), indices_ptr, idx_seg_elements_ * sizeof(S), cudaMemcpyDeviceToHost, stream), - "For 'SparseSegmentGradOps', cudaMemcpy indices failed."); - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(segment_ids_host.data(), segment_ids_ptr, - idx_seg_elements_ * sizeof(S), cudaMemcpyDeviceToHost, stream), - "For 'SparseSegmentGradOps', cudaMemcpy segment_ids failed."); - if (cudaStreamQuery(stream) != cudaSuccess) { - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(stream), - "For 'SparseSegmentGradOps', cudaStreamSyncFailed"); - } - for (size_t i = 1; i < idx_seg_elements_; i++) { - if (segment_ids_host[i] < segment_ids_host[i - 1]) { - MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', segment_ids should be sorted."; - } - } - for (size_t i = 0; i < idx_seg_elements_; i++) { - if (indices_host[i] >= static_cast(output_dim0_)) { - MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', indices out of range of output_dim0."; - } - if (segment_ids_host[i] >= static_cast(grad_shape_0_)) { - MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', segment_ids out of range of grad's first shape."; - } - } - cudaMemset(y_ptr, 0, output_elements_ * unit_grad_size_); - auto status = - CalSparseSegmentGradCombination(kernel_type_, grad_ptr, segment_ids_ptr, indices_ptr, segment_pos_ptr, outer_size_, - inner_size_, idx_seg_elements_, output_dim0_, y_ptr, device_id_, stream); - CHECK_CUDA_STATUS(status, kernel_name_); - return true; -} - -std::map>> - SparseSegmentGradOpsGpuKernelMod::kernel_attr_map_ = { - {Sparse_Segment_Sum_Grad, - {{KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat16), - &SparseSegmentGradOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeFloat16), - &SparseSegmentGradOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat32), - &SparseSegmentGradOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeFloat32), - &SparseSegmentGradOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat64), - &SparseSegmentGradOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeFloat64), - &SparseSegmentGradOpsGpuKernelMod::LaunchKernel}}}, - {Sparse_Segment_Sqrt_N_Grad, - {{KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat16), - &SparseSegmentGradOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeFloat16), - &SparseSegmentGradOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat32), - &SparseSegmentGradOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeFloat32), - &SparseSegmentGradOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat64), - &SparseSegmentGradOpsGpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeFloat64), - &SparseSegmentGradOpsGpuKernelMod::LaunchKernel}}}}; // kernel_attr_map_ - -std::vector SparseSegmentGradOpsGpuKernelMod::GetOpSupport() { - auto iter = kernel_attr_map_.find(kernel_type_); - if (iter == kernel_attr_map_.end()) { - MS_EXCEPTION(ValueError) << "For 'SparseSegmentGradOpsOp', only support these types: " - << kernel::Map2Str>>( - kernel_attr_map_) - << " currently, but got " << kernel_name_; - } - std::vector support_list; - (void)std::transform( - iter->second.begin(), iter->second.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - return support_list; -} - -MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SparseSegmentSumGrad, []() { - return std::make_shared(Sparse_Segment_Sum_Grad); -}); -MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SparseSegmentSqrtNGrad, []() { - return std::make_shared(Sparse_Segment_Sqrt_N_Grad); -}); -} // namespace kernel -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/sparse_grad/sparse_segment_grad_ops_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr auto Sparse_Segment_Sum_Grad = "SparseSegmentSumGrad"; +constexpr auto Sparse_Segment_Sqrt_N_Grad = "SparseSegmentSqrtNGrad"; +constexpr size_t kNumber1 = 1; +constexpr size_t kNumber4 = 4; +} // namespace + +bool SparseSegmentGradOpsGpuKernelMod::Init(const std::vector &inputs, + const std::vector &outputs) { + size_t inputs_num = kNumber4; + size_t outputs_num = kNumber1; + CHECK_KERNEL_INPUTS_NUM(inputs.size(), inputs_num, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), outputs_num, kernel_name_); + + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << kernel_name_ << " does not support this kernel data type: " << kernel_attr << "."; + return false; + } + kernel_func_ = kernel_attr_map_.at(kernel_type_)[index].second; + unit_grad_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).dtype); + unit_idx_seg_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex1).dtype); + return true; +} + +int SparseSegmentGradOpsGpuKernelMod::Resize(const std::vector &inputs, + const std::vector &outputs) { + for (const auto &input : inputs) { + // If any input shape contains -1, means input shape is dynamic, so just return do nothing. + auto input_shape = input->GetShapeVector(); + if (!IsValidShape(input_shape)) { + return KRET_UNKNOWN_SHAPE; + } + } + for (const auto &output : outputs) { + // If any input shape contains -1, means input shape is dynamic, so just return do nothing. + auto output_shape = output->GetShapeVector(); + if (!IsValidShape(output_shape)) { + return KRET_UNKNOWN_SHAPE; + } + } + ResetResource(); + std::vector output_shape = outputs.at(kIndex0)->GetShapeVector(); + output_elements_ = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + if (output_elements_ == 0) { + is_null_input_ = true; + } + std::vector grad_shape = inputs.at(kIndex0)->GetShapeVector(); + grad_shape_0_ = grad_shape[0]; + grad_elements_ = std::accumulate(grad_shape.begin(), grad_shape.end(), 1, std::multiplies{}); + outer_size_ = grad_shape.front(); + inner_size_ = grad_elements_ / outer_size_; + std::vector indices_shape = inputs.at(kIndex1)->GetShapeVector(); + idx_seg_elements_ = std::accumulate(indices_shape.begin(), indices_shape.end(), 1, std::multiplies{}); + output_dim0_ = LongToSize(output_shape.front()); + + size_t output_size = output_elements_ * unit_grad_size_; + output_size_list_.push_back(output_size); + workspace_size_list_.push_back((outer_size_ + 1) * sizeof(size_t)); + return KRET_OK; +} + +template +bool SparseSegmentGradOpsGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + R *grad_ptr = GetDeviceAddress(inputs, kIndex0); + S *indices_ptr = GetDeviceAddress(inputs, kIndex1); + S *segment_ids_ptr = GetDeviceAddress(inputs, kIndex2); + R *y_ptr = GetDeviceAddress(outputs, kIndex0); + size_t *segment_pos_ptr = GetDeviceAddress(workspace, kIndex0); + auto any = [](auto... args) -> bool { return ((args == nullptr) || ...); }; + if (any(grad_ptr, indices_ptr, segment_ids_ptr, segment_pos_ptr, y_ptr)) { + return false; + } + cudaStream_t stream = reinterpret_cast(cuda_stream_); + std::vector indices_host; + std::vector segment_ids_host; + indices_host.resize(idx_seg_elements_); + segment_ids_host.resize(idx_seg_elements_); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(indices_host.data(), indices_ptr, idx_seg_elements_ * sizeof(S), cudaMemcpyDeviceToHost, stream), + "For 'SparseSegmentGradOps', cudaMemcpy indices failed."); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(segment_ids_host.data(), segment_ids_ptr, + idx_seg_elements_ * sizeof(S), cudaMemcpyDeviceToHost, stream), + "For 'SparseSegmentGradOps', cudaMemcpy segment_ids failed."); + if (cudaStreamQuery(stream) != cudaSuccess) { + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(stream), + "For 'SparseSegmentGradOps', cudaStreamSyncFailed"); + } + for (size_t i = 1; i < idx_seg_elements_; i++) { + if (segment_ids_host[i] < segment_ids_host[i - 1]) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', segment_ids should be sorted."; + } + } + for (size_t i = 0; i < idx_seg_elements_; i++) { + if (indices_host[i] >= static_cast(output_dim0_)) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', indices out of range of output_dim0."; + } + if (segment_ids_host[i] >= static_cast(grad_shape_0_)) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', segment_ids out of range of grad's first shape."; + } + } + cudaMemset(y_ptr, 0, output_elements_ * unit_grad_size_); + auto status = + CalSparseSegmentGradCombination(kernel_type_, grad_ptr, segment_ids_ptr, indices_ptr, segment_pos_ptr, outer_size_, + inner_size_, idx_seg_elements_, output_dim0_, y_ptr, device_id_, stream); + CHECK_CUDA_STATUS(status, kernel_name_); + return true; +} + +std::map>> + SparseSegmentGradOpsGpuKernelMod::kernel_attr_map_ = { + {Sparse_Segment_Sum_Grad, + {{KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat16), + &SparseSegmentGradOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat16), + &SparseSegmentGradOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32), + &SparseSegmentGradOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat32), + &SparseSegmentGradOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat64), + &SparseSegmentGradOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat64), + &SparseSegmentGradOpsGpuKernelMod::LaunchKernel}}}, + {Sparse_Segment_Sqrt_N_Grad, + {{KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat16), + &SparseSegmentGradOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat16), + &SparseSegmentGradOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32), + &SparseSegmentGradOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat32), + &SparseSegmentGradOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat64), + &SparseSegmentGradOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat64), + &SparseSegmentGradOpsGpuKernelMod::LaunchKernel}}}}; // kernel_attr_map_ + +std::vector SparseSegmentGradOpsGpuKernelMod::GetOpSupport() { + auto iter = kernel_attr_map_.find(kernel_type_); + if (iter == kernel_attr_map_.end()) { + MS_EXCEPTION(ValueError) << "For 'SparseSegmentGradOpsOp', only support these types: " + << kernel::Map2Str>>( + kernel_attr_map_) + << " currently, but got " << kernel_name_; + } + std::vector support_list; + (void)std::transform( + iter->second.begin(), iter->second.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} + +MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SparseSegmentSumGrad, []() { + return std::make_shared(Sparse_Segment_Sum_Grad); +}); +MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SparseSegmentSqrtNGrad, []() { + return std::make_shared(Sparse_Segment_Sqrt_N_Grad); +}); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/sparse_grad/sparse_segment_grad_ops_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse_grad/sparse_segment_grad_ops_gpu_kernel.h index 27d76e02d9e..56c715db452 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/sparse_grad/sparse_segment_grad_ops_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse_grad/sparse_segment_grad_ops_gpu_kernel.h @@ -1,96 +1,96 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_GRAD_SPARSE_SEGMENT_GRAD_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_GRAD_SPARSE_SEGMENT_GRAD_GPU_KERNEL_H_ - -#include -#include -#include -#include -#include -#include -#include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/cuda_class_common.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_grad_impl.cuh" -#include "plugin/device/gpu/kernel/gpu_kernel.h" -#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class SparseSegmentGradOpsGpuKernelMod : public NativeGpuKernelMod { - public: - explicit SparseSegmentGradOpsGpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {} - ~SparseSegmentGradOpsGpuKernelMod() override = default; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *cuda_stream) override { - if (is_null_input_) { - return true; - } - cuda_stream_ = cuda_stream; - return kernel_func_(this, inputs, workspace, outputs); - } - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - protected: - void ResetResource() noexcept { - outer_size_ = 0; - inner_size_ = 0; - grad_elements_ = 0; - idx_seg_elements_ = 0; - output_dim0_ = 0; - output_elements_ = 0; - is_null_input_ = false; - output_size_list_.clear(); - workspace_size_list_.clear(); - } - - std::vector GetOpSupport() override; - - private: - template - bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs); - - using SSGLaunchFunc = - std::function &, - const std::vector &, const std::vector &)>; - - private: - size_t outer_size_{0}; - size_t inner_size_{0}; - size_t grad_elements_{0}; - size_t grad_shape_0_{0}; - size_t idx_seg_elements_{0}; - size_t output_dim0_{0}; - size_t output_elements_{0}; - size_t unit_grad_size_{1}; - size_t unit_idx_seg_size_{1}; - std::string kernel_type_{"Unknown"}; - bool is_null_input_{false}; - void *cuda_stream_{nullptr}; - SSGLaunchFunc kernel_func_{}; - static std::map>> kernel_attr_map_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_GRAD_SPARSE_SEGMENT_GRAD_GPU_KERNEL_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_GRAD_SPARSE_SEGMENT_GRAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_GRAD_SPARSE_SEGMENT_GRAD_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/cuda_class_common.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_grad_impl.cuh" +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class SparseSegmentGradOpsGpuKernelMod : public NativeGpuKernelMod { + public: + explicit SparseSegmentGradOpsGpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {} + ~SparseSegmentGradOpsGpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *cuda_stream) override { + if (is_null_input_) { + return true; + } + cuda_stream_ = cuda_stream; + return kernel_func_(this, inputs, workspace, outputs); + } + + bool Init(const std::vector &inputs, const std::vector &outputs) override; + + int Resize(const std::vector &inputs, const std::vector &outputs) override; + + protected: + void ResetResource() noexcept { + outer_size_ = 0; + inner_size_ = 0; + grad_elements_ = 0; + idx_seg_elements_ = 0; + output_dim0_ = 0; + output_elements_ = 0; + is_null_input_ = false; + output_size_list_.clear(); + workspace_size_list_.clear(); + } + + std::vector GetOpSupport() override; + + private: + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs); + + using SSGLaunchFunc = + std::function &, + const std::vector &, const std::vector &)>; + + private: + size_t outer_size_{0}; + size_t inner_size_{0}; + size_t grad_elements_{0}; + size_t grad_shape_0_{0}; + size_t idx_seg_elements_{0}; + size_t output_dim0_{0}; + size_t output_elements_{0}; + size_t unit_grad_size_{1}; + size_t unit_idx_seg_size_{1}; + std::string kernel_type_{"Unknown"}; + bool is_null_input_{false}; + void *cuda_stream_{nullptr}; + SSGLaunchFunc kernel_func_{}; + static std::map>> kernel_attr_map_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_GRAD_SPARSE_SEGMENT_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/optimizer/batch_norm_add_relu_fusion.h b/mindspore/ccsrc/plugin/device/gpu/optimizer/batch_norm_add_relu_fusion.h index 4764153325a..f481aa3ef25 100644 --- a/mindspore/ccsrc/plugin/device/gpu/optimizer/batch_norm_add_relu_fusion.h +++ b/mindspore/ccsrc/plugin/device/gpu/optimizer/batch_norm_add_relu_fusion.h @@ -1,61 +1,61 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BATCH_NORM_ADD_RELU_FUSION_H_ -#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BATCH_NORM_ADD_RELU_FUSION_H_ - -#include -#include "include/backend/optimizer/optimizer.h" - -namespace mindspore { -namespace opt { -class BatchNormAddReluFusion : public PatternProcessPass { - public: - explicit BatchNormAddReluFusion(bool multigraph = true) - : PatternProcessPass("batch_norm_add_relu_fusion", multigraph) { - x_ = std::make_shared(); - scale_ = std::make_shared(); - bias_ = std::make_shared(); - mean_ = std::make_shared(); - var_ = std::make_shared(); - is_training_ = std::make_shared(); - eps_ = std::make_shared(); - momentum_ = std::make_shared(); - format_ = std::make_shared(); - index_ = std::make_shared(); - z_ = std::make_shared(); - umonad_ = std::make_shared(); - } - ~BatchNormAddReluFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - VarPtr x_; - VarPtr scale_; - VarPtr bias_; - VarPtr mean_; - VarPtr var_; - VarPtr is_training_; - VarPtr eps_; - VarPtr momentum_; - VarPtr format_; - VarPtr index_; - VarPtr z_; - VarPtr umonad_; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BATCH_NORM_ADD_RELU_FUSION_H_ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BATCH_NORM_ADD_RELU_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BATCH_NORM_ADD_RELU_FUSION_H_ + +#include +#include "include/backend/optimizer/optimizer.h" + +namespace mindspore { +namespace opt { +class BatchNormAddReluFusion : public PatternProcessPass { + public: + explicit BatchNormAddReluFusion(bool multigraph = true) + : PatternProcessPass("batch_norm_add_relu_fusion", multigraph) { + x_ = std::make_shared(); + scale_ = std::make_shared(); + bias_ = std::make_shared(); + mean_ = std::make_shared(); + var_ = std::make_shared(); + is_training_ = std::make_shared(); + eps_ = std::make_shared(); + momentum_ = std::make_shared(); + format_ = std::make_shared(); + index_ = std::make_shared(); + z_ = std::make_shared(); + umonad_ = std::make_shared(); + } + ~BatchNormAddReluFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + VarPtr x_; + VarPtr scale_; + VarPtr bias_; + VarPtr mean_; + VarPtr var_; + VarPtr is_training_; + VarPtr eps_; + VarPtr momentum_; + VarPtr format_; + VarPtr index_; + VarPtr z_; + VarPtr umonad_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BATCH_NORM_ADD_RELU_FUSION_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/optimizer/batch_norm_add_relu_grad_fusion.h b/mindspore/ccsrc/plugin/device/gpu/optimizer/batch_norm_add_relu_grad_fusion.h index 2486ef4d686..d7f3dace0c6 100644 --- a/mindspore/ccsrc/plugin/device/gpu/optimizer/batch_norm_add_relu_grad_fusion.h +++ b/mindspore/ccsrc/plugin/device/gpu/optimizer/batch_norm_add_relu_grad_fusion.h @@ -1,63 +1,63 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BATCH_NORM_ADD_RELU_GRAD_FUSION_H_ -#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BATCH_NORM_ADD_RELU_GRAD_FUSION_H_ - -#include -#include "include/backend/optimizer/optimizer.h" - -namespace mindspore { -namespace opt { -class BatchNormAddReluGradFusion : public PatternProcessPass { - public: - explicit BatchNormAddReluGradFusion(bool multigraph = true) - : PatternProcessPass("batch_norm_add_relu_grad_fusion", multigraph) { - dy_ = std::make_shared(); - y_ = std::make_shared(); - x_ = std::make_shared(); - scale_ = std::make_shared(); - bias_ = std::make_shared(); - mean_ = std::make_shared(); - var_ = std::make_shared(); - save_mean_ = std::make_shared(); - save_var_ = std::make_shared(); - is_training_ = std::make_shared(); - eps_ = std::make_shared(); - format_ = std::make_shared(); - reserve_ = std::make_shared(); - } - ~BatchNormAddReluGradFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - VarPtr dy_; - VarPtr y_; - VarPtr x_; - VarPtr scale_; - VarPtr bias_; - VarPtr mean_; - VarPtr var_; - VarPtr save_mean_; - VarPtr save_var_; - VarPtr is_training_; - VarPtr eps_; - VarPtr format_; - VarPtr reserve_; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BATCH_NORM_RELU_GRAD_FUSION_H_ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BATCH_NORM_ADD_RELU_GRAD_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BATCH_NORM_ADD_RELU_GRAD_FUSION_H_ + +#include +#include "include/backend/optimizer/optimizer.h" + +namespace mindspore { +namespace opt { +class BatchNormAddReluGradFusion : public PatternProcessPass { + public: + explicit BatchNormAddReluGradFusion(bool multigraph = true) + : PatternProcessPass("batch_norm_add_relu_grad_fusion", multigraph) { + dy_ = std::make_shared(); + y_ = std::make_shared(); + x_ = std::make_shared(); + scale_ = std::make_shared(); + bias_ = std::make_shared(); + mean_ = std::make_shared(); + var_ = std::make_shared(); + save_mean_ = std::make_shared(); + save_var_ = std::make_shared(); + is_training_ = std::make_shared(); + eps_ = std::make_shared(); + format_ = std::make_shared(); + reserve_ = std::make_shared(); + } + ~BatchNormAddReluGradFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + VarPtr dy_; + VarPtr y_; + VarPtr x_; + VarPtr scale_; + VarPtr bias_; + VarPtr mean_; + VarPtr var_; + VarPtr save_mean_; + VarPtr save_var_; + VarPtr is_training_; + VarPtr eps_; + VarPtr format_; + VarPtr reserve_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BATCH_NORM_RELU_GRAD_FUSION_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/optimizer/batch_norm_relu_fusion.h b/mindspore/ccsrc/plugin/device/gpu/optimizer/batch_norm_relu_fusion.h index bb0ac2c043f..f8c44e18f21 100644 --- a/mindspore/ccsrc/plugin/device/gpu/optimizer/batch_norm_relu_fusion.h +++ b/mindspore/ccsrc/plugin/device/gpu/optimizer/batch_norm_relu_fusion.h @@ -1,58 +1,58 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BATCH_NORM_RELU_FUSION_H_ -#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BATCH_NORM_RELU_FUSION_H_ - -#include -#include "include/backend/optimizer/optimizer.h" - -namespace mindspore { -namespace opt { -class BatchNormReluFusion : public PatternProcessPass { - public: - explicit BatchNormReluFusion(bool multigraph = true) : PatternProcessPass("batch_norm_relu_fusion", multigraph) { - x_ = std::make_shared(); - scale_ = std::make_shared(); - bias_ = std::make_shared(); - mean_ = std::make_shared(); - var_ = std::make_shared(); - is_training_ = std::make_shared(); - eps_ = std::make_shared(); - momentum_ = std::make_shared(); - format_ = std::make_shared(); - umonad_ = std::make_shared(); - index_ = std::make_shared(); - } - ~BatchNormReluFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - VarPtr x_; - VarPtr scale_; - VarPtr bias_; - VarPtr mean_; - VarPtr var_; - VarPtr is_training_; - VarPtr eps_; - VarPtr momentum_; - VarPtr format_; - VarPtr umonad_; - VarPtr index_; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BATCH_NORM_RELU_FUSION_H_ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BATCH_NORM_RELU_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BATCH_NORM_RELU_FUSION_H_ + +#include +#include "include/backend/optimizer/optimizer.h" + +namespace mindspore { +namespace opt { +class BatchNormReluFusion : public PatternProcessPass { + public: + explicit BatchNormReluFusion(bool multigraph = true) : PatternProcessPass("batch_norm_relu_fusion", multigraph) { + x_ = std::make_shared(); + scale_ = std::make_shared(); + bias_ = std::make_shared(); + mean_ = std::make_shared(); + var_ = std::make_shared(); + is_training_ = std::make_shared(); + eps_ = std::make_shared(); + momentum_ = std::make_shared(); + format_ = std::make_shared(); + umonad_ = std::make_shared(); + index_ = std::make_shared(); + } + ~BatchNormReluFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + VarPtr x_; + VarPtr scale_; + VarPtr bias_; + VarPtr mean_; + VarPtr var_; + VarPtr is_training_; + VarPtr eps_; + VarPtr momentum_; + VarPtr format_; + VarPtr umonad_; + VarPtr index_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BATCH_NORM_RELU_FUSION_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/optimizer/batch_norm_relu_grad_fusion.h b/mindspore/ccsrc/plugin/device/gpu/optimizer/batch_norm_relu_grad_fusion.h index 470877b2b5c..40a5722aa11 100644 --- a/mindspore/ccsrc/plugin/device/gpu/optimizer/batch_norm_relu_grad_fusion.h +++ b/mindspore/ccsrc/plugin/device/gpu/optimizer/batch_norm_relu_grad_fusion.h @@ -1,57 +1,57 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BATCH_NORM_RELU_GRAD_FUSION_H_ -#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BATCH_NORM_RELU_GRAD_FUSION_H_ - -#include -#include "include/backend/optimizer/optimizer.h" - -namespace mindspore { -namespace opt { -class BatchNormReluGradFusion : public PatternProcessPass { - public: - explicit BatchNormReluGradFusion(bool multigraph = true) - : PatternProcessPass("batch_norm_relu_grad_fusion", multigraph) { - dy_ = std::make_shared(); - y_ = std::make_shared(); - x_ = std::make_shared(); - scale_ = std::make_shared(); - save_mean_ = std::make_shared(); - save_var_ = std::make_shared(); - reserve_ = std::make_shared(); - is_training_ = std::make_shared(); - eps_ = std::make_shared(); - format_ = std::make_shared(); - } - ~BatchNormReluGradFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - VarPtr dy_; - VarPtr y_; - VarPtr x_; - VarPtr scale_; - VarPtr save_mean_; - VarPtr save_var_; - VarPtr reserve_; - VarPtr is_training_; - VarPtr eps_; - VarPtr format_; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BATCH_NORM_RELU_GRAD_FUSION_H_ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BATCH_NORM_RELU_GRAD_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BATCH_NORM_RELU_GRAD_FUSION_H_ + +#include +#include "include/backend/optimizer/optimizer.h" + +namespace mindspore { +namespace opt { +class BatchNormReluGradFusion : public PatternProcessPass { + public: + explicit BatchNormReluGradFusion(bool multigraph = true) + : PatternProcessPass("batch_norm_relu_grad_fusion", multigraph) { + dy_ = std::make_shared(); + y_ = std::make_shared(); + x_ = std::make_shared(); + scale_ = std::make_shared(); + save_mean_ = std::make_shared(); + save_var_ = std::make_shared(); + reserve_ = std::make_shared(); + is_training_ = std::make_shared(); + eps_ = std::make_shared(); + format_ = std::make_shared(); + } + ~BatchNormReluGradFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + VarPtr dy_; + VarPtr y_; + VarPtr x_; + VarPtr scale_; + VarPtr save_mean_; + VarPtr save_var_; + VarPtr reserve_; + VarPtr is_training_; + VarPtr eps_; + VarPtr format_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BATCH_NORM_RELU_GRAD_FUSION_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/optimizer/replace_addn_fusion.h b/mindspore/ccsrc/plugin/device/gpu/optimizer/replace_addn_fusion.h index bcaf692ce8f..17fbe0afa90 100644 --- a/mindspore/ccsrc/plugin/device/gpu/optimizer/replace_addn_fusion.h +++ b/mindspore/ccsrc/plugin/device/gpu/optimizer/replace_addn_fusion.h @@ -1,34 +1,34 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_ADDN_FUSION_H_ -#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_ADDN_FUSION_H_ - -#include -#include "include/backend/optimizer/pattern_to_pattern.h" - -namespace mindspore { -namespace opt { -class ReplaceAddNFusion : public PatternToPatternPass { - public: - ReplaceAddNFusion() : PatternToPatternPass("replace_addn") {} - ~ReplaceAddNFusion() override = default; - void DefineSrcPattern(SrcPattern *src_pattern) override; - void DefineDstPattern(DstPattern *dst_pattern) override; - bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_ADDN_FUSION_H_ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_ADDN_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_ADDN_FUSION_H_ + +#include +#include "include/backend/optimizer/pattern_to_pattern.h" + +namespace mindspore { +namespace opt { +class ReplaceAddNFusion : public PatternToPatternPass { + public: + ReplaceAddNFusion() : PatternToPatternPass("replace_addn") {} + ~ReplaceAddNFusion() override = default; + void DefineSrcPattern(SrcPattern *src_pattern) override; + void DefineDstPattern(DstPattern *dst_pattern) override; + bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_ADDN_FUSION_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/optimizer/replace_momentum_cast_fusion.h b/mindspore/ccsrc/plugin/device/gpu/optimizer/replace_momentum_cast_fusion.h index 43772c71924..d4a3621e649 100644 --- a/mindspore/ccsrc/plugin/device/gpu/optimizer/replace_momentum_cast_fusion.h +++ b/mindspore/ccsrc/plugin/device/gpu/optimizer/replace_momentum_cast_fusion.h @@ -1,46 +1,46 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_MOMENTUM_CAST_FUSION_H_ -#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_MOMENTUM_CAST_FUSION_H_ - -#include -#include "include/backend/optimizer/optimizer.h" - -namespace mindspore { -namespace opt { -class ReplaceMomentumCastFusion : public PatternProcessPass { - public: - explicit ReplaceMomentumCastFusion(bool multigraph = true) : PatternProcessPass("replace_momentum_cast", multigraph) { - var_ = std::make_shared(); - acc_ = std::make_shared(); - lr_ = std::make_shared(); - grad_ = std::make_shared(); - mom_ = std::make_shared(); - } - ~ReplaceMomentumCastFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - VarPtr var_; - VarPtr acc_; - VarPtr lr_; - VarPtr grad_; - VarPtr mom_; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_MOMENTUM_CAST_FUSION_H_ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_MOMENTUM_CAST_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_MOMENTUM_CAST_FUSION_H_ + +#include +#include "include/backend/optimizer/optimizer.h" + +namespace mindspore { +namespace opt { +class ReplaceMomentumCastFusion : public PatternProcessPass { + public: + explicit ReplaceMomentumCastFusion(bool multigraph = true) : PatternProcessPass("replace_momentum_cast", multigraph) { + var_ = std::make_shared(); + acc_ = std::make_shared(); + lr_ = std::make_shared(); + grad_ = std::make_shared(); + mom_ = std::make_shared(); + } + ~ReplaceMomentumCastFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + VarPtr var_; + VarPtr acc_; + VarPtr lr_; + VarPtr grad_; + VarPtr mom_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_MOMENTUM_CAST_FUSION_H_ diff --git a/mindspore/ccsrc/ps/scheduler.cc b/mindspore/ccsrc/ps/scheduler.cc old mode 100755 new mode 100644 index 0c0bd757b2b..c9344c41c5a --- a/mindspore/ccsrc/ps/scheduler.cc +++ b/mindspore/ccsrc/ps/scheduler.cc @@ -1,55 +1,55 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "include/backend/distributed/ps/scheduler.h" -#include "ps/core/scheduler_node.h" -#include "ps/core/ps_scheduler_node.h" - -namespace mindspore { -namespace ps { -Scheduler &Scheduler::GetInstance() { - static Scheduler instance{}; - return instance; -} - -Scheduler::Scheduler() { - if (scheduler_node_ == nullptr) { - scheduler_node_ = std::make_unique(); - } -} - -Scheduler::~Scheduler() = default; - -void Scheduler::Run() { - MS_LOG(INFO) << "Start scheduler."; - PSContext::instance()->cluster_config().scheduler_host = PSContext::instance()->scheduler_host(); - PSContext::instance()->cluster_config().scheduler_port = PSContext::instance()->scheduler_port(); - PSContext::instance()->cluster_config().initial_worker_num = PSContext::instance()->initial_worker_num(); - PSContext::instance()->cluster_config().initial_server_num = PSContext::instance()->initial_server_num(); - if (!scheduler_node_->Start()) { - MS_LOG(WARNING) << "Scheduler start failed."; - } - - if (!scheduler_node_->Finish()) { - MS_LOG(WARNING) << "Scheduler finish failed."; - } - - if (!scheduler_node_->Stop()) { - MS_LOG(WARNING) << "Scheduler stop failed."; - } -} -} // namespace ps -} // namespace mindspore +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "include/backend/distributed/ps/scheduler.h" +#include "ps/core/scheduler_node.h" +#include "ps/core/ps_scheduler_node.h" + +namespace mindspore { +namespace ps { +Scheduler &Scheduler::GetInstance() { + static Scheduler instance{}; + return instance; +} + +Scheduler::Scheduler() { + if (scheduler_node_ == nullptr) { + scheduler_node_ = std::make_unique(); + } +} + +Scheduler::~Scheduler() = default; + +void Scheduler::Run() { + MS_LOG(INFO) << "Start scheduler."; + PSContext::instance()->cluster_config().scheduler_host = PSContext::instance()->scheduler_host(); + PSContext::instance()->cluster_config().scheduler_port = PSContext::instance()->scheduler_port(); + PSContext::instance()->cluster_config().initial_worker_num = PSContext::instance()->initial_worker_num(); + PSContext::instance()->cluster_config().initial_server_num = PSContext::instance()->initial_server_num(); + if (!scheduler_node_->Start()) { + MS_LOG(WARNING) << "Scheduler start failed."; + } + + if (!scheduler_node_->Finish()) { + MS_LOG(WARNING) << "Scheduler finish failed."; + } + + if (!scheduler_node_->Stop()) { + MS_LOG(WARNING) << "Scheduler stop failed."; + } +} +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/pybind_api/CMakeLists.txt b/mindspore/ccsrc/pybind_api/CMakeLists.txt index c51d887309d..d5cdd22eae7 100644 --- a/mindspore/ccsrc/pybind_api/CMakeLists.txt +++ b/mindspore/ccsrc/pybind_api/CMakeLists.txt @@ -1,8 +1,8 @@ -file(GLOB_RECURSE _PYBIND_API_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") -set_property(SOURCE ${_PYBIND_API_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_COMMON) -add_library(_mindspore_pybind_api_obj OBJECT ${_PYBIND_API_SRC_LIST}) - -if("${ENABLE_HIDDEN}" STREQUAL "OFF" AND NOT MSVC) - string(REPLACE " -Werror " " " CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") - string(REPLACE " -fvisibility=hidden" " -fvisibility=default" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") +file(GLOB_RECURSE _PYBIND_API_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +set_property(SOURCE ${_PYBIND_API_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_COMMON) +add_library(_mindspore_pybind_api_obj OBJECT ${_PYBIND_API_SRC_LIST}) + +if("${ENABLE_HIDDEN}" STREQUAL "OFF" AND NOT MSVC) + string(REPLACE " -Werror " " " CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") + string(REPLACE " -fvisibility=hidden" " -fvisibility=default" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") endif() \ No newline at end of file diff --git a/mindspore/ccsrc/transform/express_ir/CMakeLists.txt b/mindspore/ccsrc/transform/express_ir/CMakeLists.txt index d8c80850ef4..74bd0727c9b 100644 --- a/mindspore/ccsrc/transform/express_ir/CMakeLists.txt +++ b/mindspore/ccsrc/transform/express_ir/CMakeLists.txt @@ -1,4 +1,4 @@ -file(GLOB_RECURSE _EXPORTER_IR_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") -set_property(SOURCE ${_EXPORTER_IR_SRC_FILES} PROPERTY COMPILE_DEFINITIONS - SUBMODULE_ID=mindspore::SubModuleId::SM_EXPRESS) -add_library(_mindspore_transform_express_ir_obj OBJECT ${_EXPORTER_IR_SRC_FILES}) +file(GLOB_RECURSE _EXPORTER_IR_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +set_property(SOURCE ${_EXPORTER_IR_SRC_FILES} PROPERTY COMPILE_DEFINITIONS + SUBMODULE_ID=mindspore::SubModuleId::SM_EXPRESS) +add_library(_mindspore_transform_express_ir_obj OBJECT ${_EXPORTER_IR_SRC_FILES}) diff --git a/mindspore/ccsrc/transform/graph_ir/CMakeLists.txt b/mindspore/ccsrc/transform/graph_ir/CMakeLists.txt index d9b4838f990..421b39eb618 100644 --- a/mindspore/ccsrc/transform/graph_ir/CMakeLists.txt +++ b/mindspore/ccsrc/transform/graph_ir/CMakeLists.txt @@ -1,64 +1,64 @@ -if(ENABLE_D OR ENABLE_ACL) - set(op_proto_src_gen_script "${CMAKE_CURRENT_BINARY_DIR}/op_proto/generate_op_proto.cmake") - file(WRITE ${op_proto_src_gen_script} "" - "get_filename_component(op_inc_file_name \${OP_PROTO_INC} NAME_WE) \n" - "set(OP_PROTO_INCLUDE_FILE \${OP_PROTO_INC}) \n" - "configure_file(${CMAKE_CURRENT_SOURCE_DIR}/op_declare/op_proto.cc.in \n" - " \${WORKSPACE_PATH}/\${op_inc_file_name}_op_proto.cc @ONLY) \n" - ) -endif() -function(op_proto_generate path c_var op_proto_include_file) - set(${c_var}) - get_filename_component(abs_file ${op_proto_include_file} ABSOLUTE) - get_filename_component(file_name ${op_proto_include_file} NAME_WE) - get_filename_component(file_dir ${abs_file} PATH) - file(RELATIVE_PATH rel_path ${CMAKE_CURRENT_SOURCE_DIR} ${file_dir}) - - list(APPEND ${c_var} "${path}/${file_name}_op_proto.cc") - add_custom_command( - OUTPUT "${path}/${file_name}_op_proto.cc" - WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} - COMMAND ${CMAKE_COMMAND} - -DOP_PROTO_INC=\"${op_proto_include_file}\" - -DWORKSPACE_PATH=${CMAKE_CURRENT_BINARY_DIR}/op_proto -P ${op_proto_src_gen_script} - DEPENDS ${op_proto_include_file} - ${CMAKE_CURRENT_SOURCE_DIR}/op_declare/op_proto.cc.in - COMMENT "Generating op proto source file: ${${c_var}}" VERBATIM) - - set_source_files_properties(${${c_var}} PROPERTIES GENERATED TRUE) - set(${c_var} ${${c_var}} PARENT_SCOPE) -endfunction() - -if(ENABLE_D OR ENABLE_ACL) - set(OPS_INC_DIR ${ASCEND_PATH}/latest/opp/built-in/op_proto/inc/) - file(GLOB_RECURSE CUSTOM_OPS_INC_LIST ${CMAKE_CURRENT_SOURCE_DIR}/custom_op_proto/*.h) - message("CANN ops include path: " ${OPS_INC_DIR}) - file(GLOB_RECURSE OPS_INC_LIST ${OPS_INC_DIR}/*.h) - # remove god.h - list(REMOVE_ITEM OPS_INC_LIST "${OPS_INC_DIR}/all_ops.h") - list(REMOVE_ITEM OPS_INC_LIST "${OPS_INC_DIR}/nn.h") - list(REMOVE_ITEM OPS_INC_LIST "${OPS_INC_DIR}/nn_math.h") - list(REMOVE_ITEM OPS_INC_LIST "${OPS_INC_DIR}/tensor.h") - # remove repeated header - list(REMOVE_ITEM OPS_INC_LIST "${OPS_INC_DIR}/outfeed_ops.h") - list(REMOVE_ITEM OPS_INC_LIST "${OPS_INC_DIR}/nn_pooling_ops.h") - list(REMOVE_ITEM OPS_INC_LIST "${OPS_INC_DIR}/nn_norm.h") - list(REMOVE_ITEM OPS_INC_LIST "${OPS_INC_DIR}/selection.h") - list(REMOVE_ITEM OPS_INC_LIST "${OPS_INC_DIR}/transformation.h") - list(REMOVE_ITEM OPS_INC_LIST "${OPS_INC_DIR}/experiment_ops.h") - #append custom op - list(APPEND OPS_INC_LIST ${CUSTOM_OPS_INC_LIST}) - set(OPS_PROTO_OBJECTS) - foreach(op ${OPS_INC_LIST}) - get_filename_component(op_inc_file_name ${op} NAME_WE) - op_proto_generate(${CMAKE_CURRENT_BINARY_DIR}/op_proto ${op_inc_file_name}_SRC ${op}) - list(APPEND OPS_PROTO_OBJECTS ${${op_inc_file_name}_SRC}) - endforeach() - file(GLOB_RECURSE _TRANSFORM_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") - if(BUILD_LITE) - list(REMOVE_ITEM _TRANSFORM_SRC_LIST "callbacks_ge.cc") - endif() - set_property(SOURCE ${_TRANSFORM_SRC_LIST} PROPERTY COMPILE_DEFINITIONS - SUBMODULE_ID=mindspore::SubModuleId::SM_GE_ADPT) - add_library(_mindspore_transform_graph_ir_obj OBJECT ${_TRANSFORM_SRC_LIST} ${OPS_PROTO_OBJECTS}) -endif() +if(ENABLE_D OR ENABLE_ACL) + set(op_proto_src_gen_script "${CMAKE_CURRENT_BINARY_DIR}/op_proto/generate_op_proto.cmake") + file(WRITE ${op_proto_src_gen_script} "" + "get_filename_component(op_inc_file_name \${OP_PROTO_INC} NAME_WE) \n" + "set(OP_PROTO_INCLUDE_FILE \${OP_PROTO_INC}) \n" + "configure_file(${CMAKE_CURRENT_SOURCE_DIR}/op_declare/op_proto.cc.in \n" + " \${WORKSPACE_PATH}/\${op_inc_file_name}_op_proto.cc @ONLY) \n" + ) +endif() +function(op_proto_generate path c_var op_proto_include_file) + set(${c_var}) + get_filename_component(abs_file ${op_proto_include_file} ABSOLUTE) + get_filename_component(file_name ${op_proto_include_file} NAME_WE) + get_filename_component(file_dir ${abs_file} PATH) + file(RELATIVE_PATH rel_path ${CMAKE_CURRENT_SOURCE_DIR} ${file_dir}) + + list(APPEND ${c_var} "${path}/${file_name}_op_proto.cc") + add_custom_command( + OUTPUT "${path}/${file_name}_op_proto.cc" + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} + COMMAND ${CMAKE_COMMAND} + -DOP_PROTO_INC=\"${op_proto_include_file}\" + -DWORKSPACE_PATH=${CMAKE_CURRENT_BINARY_DIR}/op_proto -P ${op_proto_src_gen_script} + DEPENDS ${op_proto_include_file} + ${CMAKE_CURRENT_SOURCE_DIR}/op_declare/op_proto.cc.in + COMMENT "Generating op proto source file: ${${c_var}}" VERBATIM) + + set_source_files_properties(${${c_var}} PROPERTIES GENERATED TRUE) + set(${c_var} ${${c_var}} PARENT_SCOPE) +endfunction() + +if(ENABLE_D OR ENABLE_ACL) + set(OPS_INC_DIR ${ASCEND_PATH}/latest/opp/built-in/op_proto/inc/) + file(GLOB_RECURSE CUSTOM_OPS_INC_LIST ${CMAKE_CURRENT_SOURCE_DIR}/custom_op_proto/*.h) + message("CANN ops include path: " ${OPS_INC_DIR}) + file(GLOB_RECURSE OPS_INC_LIST ${OPS_INC_DIR}/*.h) + # remove god.h + list(REMOVE_ITEM OPS_INC_LIST "${OPS_INC_DIR}/all_ops.h") + list(REMOVE_ITEM OPS_INC_LIST "${OPS_INC_DIR}/nn.h") + list(REMOVE_ITEM OPS_INC_LIST "${OPS_INC_DIR}/nn_math.h") + list(REMOVE_ITEM OPS_INC_LIST "${OPS_INC_DIR}/tensor.h") + # remove repeated header + list(REMOVE_ITEM OPS_INC_LIST "${OPS_INC_DIR}/outfeed_ops.h") + list(REMOVE_ITEM OPS_INC_LIST "${OPS_INC_DIR}/nn_pooling_ops.h") + list(REMOVE_ITEM OPS_INC_LIST "${OPS_INC_DIR}/nn_norm.h") + list(REMOVE_ITEM OPS_INC_LIST "${OPS_INC_DIR}/selection.h") + list(REMOVE_ITEM OPS_INC_LIST "${OPS_INC_DIR}/transformation.h") + list(REMOVE_ITEM OPS_INC_LIST "${OPS_INC_DIR}/experiment_ops.h") + #append custom op + list(APPEND OPS_INC_LIST ${CUSTOM_OPS_INC_LIST}) + set(OPS_PROTO_OBJECTS) + foreach(op ${OPS_INC_LIST}) + get_filename_component(op_inc_file_name ${op} NAME_WE) + op_proto_generate(${CMAKE_CURRENT_BINARY_DIR}/op_proto ${op_inc_file_name}_SRC ${op}) + list(APPEND OPS_PROTO_OBJECTS ${${op_inc_file_name}_SRC}) + endforeach() + file(GLOB_RECURSE _TRANSFORM_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") + if(BUILD_LITE) + list(REMOVE_ITEM _TRANSFORM_SRC_LIST "callbacks_ge.cc") + endif() + set_property(SOURCE ${_TRANSFORM_SRC_LIST} PROPERTY COMPILE_DEFINITIONS + SUBMODULE_ID=mindspore::SubModuleId::SM_GE_ADPT) + add_library(_mindspore_transform_graph_ir_obj OBJECT ${_TRANSFORM_SRC_LIST} ${OPS_PROTO_OBJECTS}) +endif() diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare/elewise_calculation_ops_declare.h b/mindspore/ccsrc/transform/graph_ir/op_declare/elewise_calculation_ops_declare.h old mode 100755 new mode 100644 diff --git a/mindspore/ccsrc/utils/dlopen_macro.h b/mindspore/ccsrc/utils/dlopen_macro.h index f7ed3fb2669..7c86066b78f 100644 --- a/mindspore/ccsrc/utils/dlopen_macro.h +++ b/mindspore/ccsrc/utils/dlopen_macro.h @@ -1,88 +1,88 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_UTILS_DLOPEN_MACRO_H -#define MINDSPORE_CCSRC_UTILS_DLOPEN_MACRO_H - -#ifndef _WIN32 -#include -#else -#include -#undef ERROR -#undef SM_DEBUG -#undef Yield -#endif -#include -#include -#include "utils/log_adapter.h" - -#ifndef _WIN32 -#define PORTABLE_EXPORT __attribute__((visibility("default"))) -#else -#define PORTABLE_EXPORT __declspec(dllexport) -#endif - -#define PLUGIN_METHOD(name, return_type, ...) \ - extern "C" { \ - PORTABLE_EXPORT return_type Plugin##name(__VA_ARGS__); \ - } \ - constexpr const char *k##name##Name = "Plugin" #name; \ - using name##FunObj = std::function; \ - using name##FunPtr = return_type (*)(__VA_ARGS__); - -#define ORIGIN_METHOD(name, return_type, ...) \ - extern "C" { \ - return_type name(__VA_ARGS__); \ - } \ - constexpr const char *k##name##Name = #name; \ - using name##FunObj = std::function; \ - using name##FunPtr = return_type (*)(__VA_ARGS__); - -inline static std::string GetDlErrorMsg() { -#ifndef _WIN32 - const char *result = dlerror(); - return (result == nullptr) ? "Unknown" : result; -#else - return std::to_string(GetLastError()); -#endif -} - -template -static T DlsymWithCast(void *handle, const char *symbol_name) { -#ifndef _WIN32 - T symbol = reinterpret_cast(reinterpret_cast(dlsym(handle, symbol_name))); -#else - T symbol = reinterpret_cast(GetProcAddress(reinterpret_cast(handle), symbol_name)); -#endif - if (symbol == nullptr) { - MS_LOG(EXCEPTION) << "Dynamically load symbol " << symbol_name << " failed, result = " << GetDlErrorMsg(); - } - return symbol; -} - -#define DlsymFuncObj(func_name, plugin_handle) DlsymWithCast(plugin_handle, k##func_name##Name); - -template -static T DlsymAscend(void *handle, const char *symbol_name) { - T symbol = reinterpret_cast(reinterpret_cast(dlsym(handle, symbol_name))); - if (symbol == nullptr) { - MS_LOG(WARNING) << "Dynamically load symbol " << symbol_name << " failed, result = " << GetDlErrorMsg(); - } - return symbol; -} - -#define DlsymAscendFuncObj(func_name, plugin_handle) DlsymAscend(plugin_handle, k##func_name##Name); -#endif // MINDSPORE_CCSRC_UTILS_DLOPEN_MACRO_H +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_UTILS_DLOPEN_MACRO_H +#define MINDSPORE_CCSRC_UTILS_DLOPEN_MACRO_H + +#ifndef _WIN32 +#include +#else +#include +#undef ERROR +#undef SM_DEBUG +#undef Yield +#endif +#include +#include +#include "utils/log_adapter.h" + +#ifndef _WIN32 +#define PORTABLE_EXPORT __attribute__((visibility("default"))) +#else +#define PORTABLE_EXPORT __declspec(dllexport) +#endif + +#define PLUGIN_METHOD(name, return_type, ...) \ + extern "C" { \ + PORTABLE_EXPORT return_type Plugin##name(__VA_ARGS__); \ + } \ + constexpr const char *k##name##Name = "Plugin" #name; \ + using name##FunObj = std::function; \ + using name##FunPtr = return_type (*)(__VA_ARGS__); + +#define ORIGIN_METHOD(name, return_type, ...) \ + extern "C" { \ + return_type name(__VA_ARGS__); \ + } \ + constexpr const char *k##name##Name = #name; \ + using name##FunObj = std::function; \ + using name##FunPtr = return_type (*)(__VA_ARGS__); + +inline static std::string GetDlErrorMsg() { +#ifndef _WIN32 + const char *result = dlerror(); + return (result == nullptr) ? "Unknown" : result; +#else + return std::to_string(GetLastError()); +#endif +} + +template +static T DlsymWithCast(void *handle, const char *symbol_name) { +#ifndef _WIN32 + T symbol = reinterpret_cast(reinterpret_cast(dlsym(handle, symbol_name))); +#else + T symbol = reinterpret_cast(GetProcAddress(reinterpret_cast(handle), symbol_name)); +#endif + if (symbol == nullptr) { + MS_LOG(EXCEPTION) << "Dynamically load symbol " << symbol_name << " failed, result = " << GetDlErrorMsg(); + } + return symbol; +} + +#define DlsymFuncObj(func_name, plugin_handle) DlsymWithCast(plugin_handle, k##func_name##Name); + +template +static T DlsymAscend(void *handle, const char *symbol_name) { + T symbol = reinterpret_cast(reinterpret_cast(dlsym(handle, symbol_name))); + if (symbol == nullptr) { + MS_LOG(WARNING) << "Dynamically load symbol " << symbol_name << " failed, result = " << GetDlErrorMsg(); + } + return symbol; +} + +#define DlsymAscendFuncObj(func_name, plugin_handle) DlsymAscend(plugin_handle, k##func_name##Name); +#endif // MINDSPORE_CCSRC_UTILS_DLOPEN_MACRO_H diff --git a/mindspore/ccsrc/utils/recompute_helper.cc b/mindspore/ccsrc/utils/recompute_helper.cc index 1482f8f248f..62fb5ddb2c4 100644 --- a/mindspore/ccsrc/utils/recompute_helper.cc +++ b/mindspore/ccsrc/utils/recompute_helper.cc @@ -1,634 +1,634 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "include/common/utils/recompute_helper.h" -#include -#include -#include -#include -#include -#include "mindspore/core/ops/sequence_ops.h" -#include "mindspore/core/ops/other_ops.h" -#include "mindspore/core/ops/nn_ops.h" -#include "mindspore/core/ops/framework_ops.h" -#include "mindspore/core/ops/array_op_name.h" -#include "include/common/utils/utils.h" - -namespace mindspore { -constexpr auto kGradientsFlag = "Gradients"; -const int64_t fusion_id_increasement_size = 2000; -bool CanNotRecomputed(const CNodePtr &node) { - static mindspore::HashSet not_recomputed_op_list{ - prim::kPrimDropoutGenMask, prim::kPrimLoad, prim::kPrimTupleGetItem, prim::kPrimSend, prim::kPrimReceive}; - - return std::any_of(not_recomputed_op_list.begin(), not_recomputed_op_list.end(), - [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }); -} - -bool IsBpropNode(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - return false; - } - if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { - return IsBpropNode(node->cast()->input(1)); - } - return node->fullname_with_scope().find(kGradientsFlag) == 0; -} - -bool WithRecomputedScope(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - return false; - } - auto full_name_with_scope = node->fullname_with_scope(); - return full_name_with_scope.find(kAttrRecompute) == 0; -} - -ValuePtr GetRecomputeCNodeAttr(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast_ptr(); - if (cnode == nullptr) { - return nullptr; - } - return cnode->GetAttr(kAttrRecompute); -} - -bool IsSetNoRecomputeCNodeAttr(const AnfNodePtr &node) { - auto cnode_recompute_val = GetRecomputeCNodeAttr(node); - return cnode_recompute_val != nullptr && cnode_recompute_val->isa() && !GetValue(cnode_recompute_val); -} - -bool IsSetRecomputeCNodeAttr(const AnfNodePtr &node) { - auto cnode_recompute_val = GetRecomputeCNodeAttr(node); - return cnode_recompute_val != nullptr && cnode_recompute_val->isa() && GetValue(cnode_recompute_val); -} - -bool IsCandidateRecomputedNode(const CNodePtr &node) { - // The tuple_getitem in the bprop function should also be recomputed. - return (!IsBpropNode(node) || IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) && IsSetRecomputeCNodeAttr(node); -} - -bool DynShapeOpInsertedInBprop(const AnfNodePtr &node) { - static const PrimitiveSet dyn_shape_ops_prim_set = {prim::kPrimShape, std::make_shared(kShapeCalcOpName)}; - if (IsOneOfPrimitiveCNode(node, dyn_shape_ops_prim_set)) { - return true; - } - return false; -} - -std::vector FindCandidateRecomputedNodes(const FuncGraphManagerPtr &mng, - const std::vector &cnodes) { - MS_EXCEPTION_IF_NULL(mng); - std::vector candidate_recomputed_nodes; - for (const auto &cnode : cnodes) { - MS_EXCEPTION_IF_NULL(cnode); - if (!IsCandidateRecomputedNode(cnode)) { - continue; - } - // Check outputs. - const auto &node_users = mng->node_users(); - auto output_set_iter = node_users.find(cnode); - if (output_set_iter == node_users.end()) { - continue; - } - const auto &node_index_set = output_set_iter->second; - if (!std::any_of(node_index_set.begin(), node_index_set.end(), - [](const auto &node_index) { return IsBpropNode(node_index.first); })) { - continue; - } - // Check inputs. - const auto &inputs = cnode->inputs(); - if (std::any_of(inputs.begin(), inputs.end(), [](const AnfNodePtr &node) { return IsBpropNode(node); })) { - continue; - } - (void)candidate_recomputed_nodes.emplace_back(cnode); - } - return candidate_recomputed_nodes; -} - -void GetMaxSubGraph(const FuncGraphManagerPtr &mng, mindspore::HashSet *recomputed_nodes, bool get_inputs, - bool get_outputs) { - MS_EXCEPTION_IF_NULL(mng); - MS_EXCEPTION_IF_NULL(recomputed_nodes); - std::queue nodes_to_visit; - for (const auto &node : *recomputed_nodes) { - nodes_to_visit.push(node); - } - recomputed_nodes->clear(); - while (!nodes_to_visit.empty()) { - auto current_node = nodes_to_visit.front(); - nodes_to_visit.pop(); - (void)recomputed_nodes->insert(current_node); - // No need to find nodes through side-effect dependency. - if (IsPrimitiveCNode(current_node, prim::kPrimUpdateState)) { - continue; - } - if (get_inputs) { - for (auto &weak_input : current_node->weak_inputs()) { - auto input = weak_input.lock(); - MS_EXCEPTION_IF_NULL(input); - if (input->isa()) { - auto input_cnode = input->cast(); - if (recomputed_nodes->find(input_cnode) == recomputed_nodes->end() && - IsCandidateRecomputedNode(input_cnode)) { - nodes_to_visit.push(input_cnode); - } - } - } - } - - if (get_outputs) { - const auto &node_users = mng->node_users(); - auto output_set_iter = node_users.find(current_node); - if (output_set_iter == node_users.end()) { - continue; - } - for (const auto &node_index_set : output_set_iter->second) { - auto output_node = node_index_set.first; - MS_EXCEPTION_IF_NULL(output_node); - if (output_node->isa()) { - auto output_cnode = output_node->cast(); - if (recomputed_nodes->find(output_cnode) == recomputed_nodes->end() && - IsCandidateRecomputedNode(output_cnode)) { - nodes_to_visit.push(output_cnode); - } - } - } - } - } -} - -void GetOriginRecomputeAndTargetNodes(const FuncGraphManagerPtr &mng, - const mindspore::HashSet &max_recomputed_sub_graph, - mindspore::HashSet *recompute_nodes, - mindspore::HashSet *target_nodes) { - MS_EXCEPTION_IF_NULL(mng); - MS_EXCEPTION_IF_NULL(recompute_nodes); - MS_EXCEPTION_IF_NULL(target_nodes); - const auto &node_users = mng->node_users(); - for (const auto &node : max_recomputed_sub_graph) { - bool inserted = false; - auto output_set_iter = node_users.find(node); - if (output_set_iter == node_users.end()) { - continue; - } - for (const auto &node_index_set : output_set_iter->second) { - auto output_node = node_index_set.first; - MS_EXCEPTION_IF_NULL(output_node); - // The tuple_getitem to be recomputed can be in the bprop function. - if (!IsBpropNode(output_node) || IsPrimitiveCNode(output_node, prim::kPrimTupleGetItem)) { - continue; - } - if (DynShapeOpInsertedInBprop(output_node)) { - continue; - } - (void)target_nodes->insert(output_node->cast()); - if (!inserted) { - (void)recompute_nodes->insert(node); - inserted = true; - } - } - } -} - -std::vector GetInputNodesWithFilter(const CNodePtr &node, std::function filter, - std::function push) { - auto func_graph = node->func_graph(); - MS_EXCEPTION_IF_NULL(func_graph); - std::vector res; - std::queue cnode_queue; - cnode_queue.push(node); - while (!cnode_queue.empty()) { - auto queue_end = cnode_queue.front(); - cnode_queue.pop(); - auto input_nodes = queue_end->inputs(); - bool is_filtered = false; - for (size_t i = 1; i < input_nodes.size(); ++i) { - if (push(input_nodes[i])) { - res.push_back(queue_end); - is_filtered = true; - break; - } - } - if (!is_filtered) { - for (size_t i = 1; i < input_nodes.size(); ++i) { - if (!input_nodes[i]->isa()) { - continue; - } - if (filter(input_nodes[i])) { - continue; - } - cnode_queue.push(input_nodes[i]->cast()); - } - } - } - return res; -} - -void GetNewFirstTargetInputs(const std::vector &recompute_input_border_bprop_nodes, - std::function push_func, - std::vector *first_target_inputs, bool *inserted) { - for (const auto &input_border_bprop_node : recompute_input_border_bprop_nodes) { - MS_LOG(INFO) << "input_border_bprop_node:" << input_border_bprop_node->DebugString() - << ", the fullname:" << input_border_bprop_node->fullname_with_scope(); - if (!input_border_bprop_node->isa()) { - (void)(*first_target_inputs).emplace_back(input_border_bprop_node); - *inserted = true; - continue; - } - auto input_border_bprop_cnode = input_border_bprop_node->cast(); - for (size_t k = 1; k < input_border_bprop_cnode->size(); ++k) { - if (!push_func(input_border_bprop_cnode->input(k))) { - continue; - } - (void)(*first_target_inputs).emplace_back(input_border_bprop_cnode->input(k)); - *inserted = true; - } - } -} - -bool HasTargetOrRecomputeInputs(const mindspore::HashSet &recomputed_origin_nodes, - const mindspore::HashSet &target_nodes, const CNodePtr &node, - mindspore::HashMap *has_target_or_recompute_inputs_map) { - auto iter = has_target_or_recompute_inputs_map->find(node); - if (iter != has_target_or_recompute_inputs_map->end()) { - return iter->second; - } - if (recomputed_origin_nodes.find(node) != recomputed_origin_nodes.end()) { - (void)has_target_or_recompute_inputs_map->emplace(node, true); - return true; - } - if (target_nodes.find(node) != target_nodes.end()) { - (void)has_target_or_recompute_inputs_map->emplace(node, true); - return true; - } - - if (IsBpropNode(node) && !DynShapeOpInsertedInBprop(node)) { - for (auto &weak_input : node->weak_inputs()) { - auto input = weak_input.lock(); - MS_EXCEPTION_IF_NULL(input); - if (input->isa() && - HasTargetOrRecomputeInputs(recomputed_origin_nodes, target_nodes, input->cast(), - has_target_or_recompute_inputs_map)) { - (void)has_target_or_recompute_inputs_map->emplace(node, true); - return true; - } - } - } - (void)has_target_or_recompute_inputs_map->emplace(node, false); - return false; -} - -std::vector GetFirstTargetInputs(const std::vector &origin_nodes_topological, - const mindspore::HashSet &max_recomputed_sub_graph, - const mindspore::HashSet &recomputed_origin_nodes, - const mindspore::HashSet &target_nodes) { - std::vector first_target_inputs; - mindspore::HashMap has_grad_inputs_map; - auto filt_func = [&](const AnfNodePtr &anode) { - if (!anode->isa() || !anode->cast()->HasPrimalAttr(kPrimalAttrForwardUniqueId)) { - return true; - } - auto c_node = anode->cast(); - auto forward_unique_id = GetValue(c_node->GetPrimalAttr(kPrimalAttrForwardUniqueId)); - return std::find_if(max_recomputed_sub_graph.begin(), max_recomputed_sub_graph.end(), [&](CNodePtr r_cnode) { - return r_cnode->HasPrimalAttr(kPrimalAttrUniqueId) && - GetValue(r_cnode->GetPrimalAttr(kPrimalAttrUniqueId)) == forward_unique_id; - }) == max_recomputed_sub_graph.end(); - }; - auto push_func = [&](const AnfNodePtr &anode) { - if (!anode->isa() || !anode->cast()->HasPrimalAttr(kPrimalAttrForwardUniqueId)) { - return false; - } - auto c_node = anode->cast(); - auto forward_unique_id = GetValue(c_node->GetPrimalAttr(kPrimalAttrForwardUniqueId)); - return std::find_if(max_recomputed_sub_graph.begin(), max_recomputed_sub_graph.end(), [&](CNodePtr r_cnode) { - return r_cnode->HasPrimalAttr(kPrimalAttrUniqueId) && - GetValue(r_cnode->GetPrimalAttr(kPrimalAttrUniqueId)) == forward_unique_id; - }) == max_recomputed_sub_graph.end(); - }; - for (const auto &node : origin_nodes_topological) { - MS_EXCEPTION_IF_NULL(node); - if (target_nodes.find(node) == target_nodes.end()) { - continue; - } - for (size_t i = 1; i < node->size(); ++i) { - auto input = node->input(i); - MS_EXCEPTION_IF_NULL(input); - if (!input->isa()) { - continue; - } - auto input_cnode = input->cast(); - if (!IsBpropNode(input_cnode)) { - continue; - } - if (HasTargetOrRecomputeInputs(recomputed_origin_nodes, target_nodes, input_cnode, &has_grad_inputs_map)) { - continue; - } - - bool inserted = false; - for (size_t j = 1; j < input_cnode->size(); ++j) { - if (filt_func(input_cnode->input(j))) { - continue; - } - auto select_node = input_cnode->input(j)->cast(); - auto recompute_input_border_bprop_nodes = GetInputNodesWithFilter(select_node, filt_func, push_func); - if (recompute_input_border_bprop_nodes.empty()) { - (void)first_target_inputs.emplace_back(input); - inserted = true; - continue; - } - GetNewFirstTargetInputs(recompute_input_border_bprop_nodes, push_func, &first_target_inputs, &inserted); - } - if (!inserted) { - (void)first_target_inputs.emplace_back(input); - } - } - if (!first_target_inputs.empty()) { - break; - } - } - return first_target_inputs; -} - -bool HasGradInputs(const AnfNodePtr &node, mindspore::HashMap *has_grad_inputs_map) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(has_grad_inputs_map); - if (has_grad_inputs_map->find(node) != has_grad_inputs_map->end()) { - return has_grad_inputs_map->find(node)->second; - } - auto cnode = node->cast_ptr(); - if (cnode == nullptr) { - (void)has_grad_inputs_map->emplace(node, false); - return false; - } - const auto &inputs = cnode->inputs(); - for (size_t i = 0; i < inputs.size(); ++i) { - // For the pipeline split case, the forward pass may depend on the backward pass. - if (cnode->IsApply(prim::kPrimDepend) && i == kDependAttachNodeIndex) { - continue; - } - if (IsBpropNode(inputs[i]) || HasGradInputs(inputs[i], has_grad_inputs_map)) { - (void)has_grad_inputs_map->emplace(node, true); - return true; - } - } - (void)has_grad_inputs_map->emplace(node, false); - return false; -} - -bool HasForwardOutput(const FuncGraphManagerPtr &mng, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(mng); - const auto &node_users = mng->node_users(); - auto output_set_iter = node_users.find(node); - if (output_set_iter == node_users.end()) { - return false; - } - - return std::any_of(output_set_iter->second.begin(), output_set_iter->second.end(), - [](const auto &node_index_set) { return !IsBpropNode(node_index_set.first); }); -} - -void GetTupleGetItemOutputNodes(const FuncGraphManagerPtr &mng, const AnfNodePtr &node, - std::vector *tuple_getitem_output_nodes) { - MS_EXCEPTION_IF_NULL(mng); - MS_EXCEPTION_IF_NULL(tuple_getitem_output_nodes); - const auto &node_users = mng->node_users(); - auto output_set_iter = node_users.find(node); - if (output_set_iter == node_users.end()) { - return; - } - for (const auto &node_index_set : output_set_iter->second) { - if (IsPrimitiveCNode(node_index_set.first, prim::kPrimTupleGetItem)) { - (void)tuple_getitem_output_nodes->emplace_back(node_index_set.first); - } - } -} - -bool SetRecomputedScope(const CNodePtr &node) { - return WithRecomputedScope(node) || - (IsPrimitiveCNode(node, prim::kPrimDepend) && WithRecomputedScope(node->input(kRealInputIndexInDepend))); -} - -void SetCkptOffloadAttr(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - node->AddAttr(kAttrCheckpoint, MakeValue(true)); -} - -// Set 'recompute' cnode attr for the nodes according to its scope. -// A node set 'recompute' cnode attr can become the candidate recomputed node. -void SetRecomputedAttr(const FuncGraphPtr &graph, const std::vector &origin_nodes_topological) { - MS_EXCEPTION_IF_NULL(graph); - auto mng = graph->manager(); - MS_EXCEPTION_IF_NULL(mng); - mindspore::HashMap has_grad_inputs_map; - for (const auto &node : origin_nodes_topological) { - MS_EXCEPTION_IF_NULL(node); - // The node may be set the non-recomputed before such as the cell outputs. - if (IsSetNoRecomputeCNodeAttr(node)) { - SetCkptOffloadAttr(node); - continue; - } - if (IsBpropNode(node)) { - SetCkptOffloadAttr(node); - continue; - } - // Filter some unrecomputable operators. - if (CanNotRecomputed(node)) { - SetCkptOffloadAttr(node); - continue; - } - if (!HasForwardOutput(mng, node) || HasGradInputs(node, &has_grad_inputs_map)) { - SetCkptOffloadAttr(node); - continue; - } - - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto prim = GetCNodePrimitive(cnode); - if (prim == nullptr) { - continue; - } - auto prim_recompute_attr = prim->GetAttr(kAttrRecompute); - int prim_recompute_val = -1; - if (prim_recompute_attr != nullptr && prim_recompute_attr->isa()) { - prim_recompute_val = static_cast(GetValue(prim_recompute_attr)); - } - if ((SetRecomputedScope(cnode) && prim_recompute_val != 0) || prim_recompute_val == 1) { - cnode->AddAttr(kAttrRecompute, MakeValue(true)); - } - if (!IsSetRecomputeCNodeAttr(node)) { - SetCkptOffloadAttr(node); - continue; - } - // Set attr for the tuple_getitem outputs. - std::vector tuple_getitem_output_nodes; - GetTupleGetItemOutputNodes(mng, node, &tuple_getitem_output_nodes); - for (const auto &output_node : tuple_getitem_output_nodes) { - auto output_cnode = output_node->cast_ptr(); - MS_EXCEPTION_IF_NULL(output_cnode); - output_cnode->AddAttr(kAttrRecompute, MakeValue(true)); - } - } -} - -CNodePtr CreateNewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &origin_node, - const std::vector &new_inputs) { - auto recomputed_node = graph->NewCNode(new_inputs); - MS_EXCEPTION_IF_NULL(recomputed_node); - recomputed_node->AddAttr(kAttrDuplicated, MakeValue(true)); - recomputed_node->AddAttr(kAttrNeedCseAfterRecompute, MakeValue(true)); - recomputed_node->set_abstract(origin_node->abstract()); - recomputed_node->set_scope(origin_node->scope()); - if (origin_node->HasPrimalAttr(kAttrMicro)) { - recomputed_node->AddPrimalAttr(kAttrMicro, origin_node->GetPrimalAttr(kAttrMicro)); - } - if (origin_node->HasPrimalAttr(kPrimalAttrForwardCommNodeUniqueId)) { - recomputed_node->AddPrimalAttr(kPrimalAttrForwardCommNodeUniqueId, - origin_node->GetPrimalAttr(kPrimalAttrForwardCommNodeUniqueId)); - } - if (origin_node->HasAttr(kAttrRecomputeSubGraph)) { - recomputed_node->AddAttr(kAttrRecomputeSubGraph, origin_node->GetAttr(kAttrRecomputeSubGraph)); - } - static int64_t recompute_id = 0; - ++recompute_id; - recomputed_node->AddAttr(kAttrRecomputeId, MakeValue(recompute_id)); - origin_node->AddAttr(kAttrRecomputeId, MakeValue(recompute_id)); - static const PrimitiveSet dropout_prims = {prim::kPrimDropout, prim::kPrimDropoutDoMask, prim::kPrimDropoutDoMaskV3}; - static const std::vector need_primal_attr = {kAttrFusion, kPrimalAttrUniqueId, - kPrimalAttrForwardUniqueId}; - if (IsOneOfPrimitiveCNode(origin_node, dropout_prims)) { - for (auto &primal_attr : need_primal_attr) { - if (origin_node->HasPrimalAttr(primal_attr)) { - recomputed_node->AddPrimalAttr(primal_attr, origin_node->GetPrimalAttr(primal_attr)); - } - } - } - return recomputed_node; -} - -CNodePtr NewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &origin_node, - const std::vector &first_target_inputs, - const mindspore::HashSet &recomputed_origin_nodes, - mindspore::HashMap *origin_to_recomputed_nodes) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(origin_node); - MS_EXCEPTION_IF_NULL(origin_to_recomputed_nodes); - auto iter = origin_to_recomputed_nodes->find(origin_node); - if (iter != origin_to_recomputed_nodes->end()) { - return iter->second; - } - MS_LOG(DEBUG) << "Begin to Duplicating origin recomputed node: " << origin_node->DebugString(); - std::vector new_inputs; - bool has_recomputed_inputs = false; - for (size_t i = 0; i < origin_node->size(); ++i) { - auto input = origin_node->input(i); - if (i == 0 && IsPrimitive(input, prim::kPrimAllGather)) { - auto prim = GetValuePtr(input); - auto instance_name = prim->instance_name(); - bool is_from_parallel_optimizer = instance_name.find("parallel_optimizer") != std::string::npos; - int64_t fusion_id = prim->HasAttr(kAttrFusion) ? GetValue(prim->GetAttr(kAttrFusion)) : 0; - if (is_from_parallel_optimizer && fusion_id > 0) { - auto new_prim = std::make_shared(prim::kPrimAllGather->name()); - (void)new_prim->SetAttrs(prim->attrs()); - new_prim->set_attr(kAttrFusion, MakeValue(fusion_id + fusion_id_increasement_size)); - new_prim->set_prim_type(prim->prim_type()); - new_prim->set_instance_name(instance_name); - auto value_node = NewValueNode(new_prim); - (void)new_inputs.emplace_back(value_node); - continue; - } - } - MS_EXCEPTION_IF_NULL(input); - if (!input->isa()) { - (void)new_inputs.emplace_back(input); - continue; - } - auto input_cnode = input->cast(); - if (recomputed_origin_nodes.find(input_cnode) == recomputed_origin_nodes.end()) { - if (IsPrimitiveCNode(input_cnode, prim::kPrimUpdateState)) { - auto u = NewValueNode(kUMonad); - u->set_abstract(kUMonad->ToAbstract()); - (void)new_inputs.emplace_back(u); - } else { - (void)new_inputs.emplace_back(input); - } - } else { - has_recomputed_inputs = true; - (void)new_inputs.emplace_back(NewRecomputedNode(graph, input_cnode, first_target_inputs, recomputed_origin_nodes, - origin_to_recomputed_nodes)); - } - } - // Add the execution dependency. - if (!has_recomputed_inputs && new_inputs.size() > 1) { - std::vector make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple)}; - (void)std::copy(first_target_inputs.begin(), first_target_inputs.end(), std::back_inserter(make_tuple_inputs)); - AbstractBasePtrList abstract_list; - (void)std::transform(first_target_inputs.begin(), first_target_inputs.end(), std::back_inserter(abstract_list), - [](const AnfNodePtr &node) -> AbstractBasePtr { return node->abstract(); }); - auto make_tuple = graph->NewCNode(make_tuple_inputs); - make_tuple->set_abstract(std::make_shared(abstract_list)); - auto first_input = new_inputs[1]; - MS_EXCEPTION_IF_NULL(first_input); - std::vector depend_inputs{NewValueNode(prim::kPrimDepend), first_input, make_tuple}; - auto depend_node = graph->NewCNode(depend_inputs); - MS_EXCEPTION_IF_NULL(depend_node); - depend_node->set_abstract(first_input->abstract()); - depend_node->AddAttr("recompute_depend", MakeValue(true)); - new_inputs[1] = depend_node; - } - auto recomputed_node = CreateNewRecomputedNode(graph, origin_node, new_inputs); - (void)origin_to_recomputed_nodes->emplace(origin_node, recomputed_node); - return recomputed_node; -} - -void DuplicateRecomputedNodes(const FuncGraphPtr &graph, const mindspore::HashSet &target_nodes, - const mindspore::HashSet &origin_recomputed_nodes, - const std::vector &first_target_inputs, - mindspore::HashMap *origin_to_new_target_nodes, - mindspore::HashMap *origin_to_recomputed_nodes) { - MS_EXCEPTION_IF_NULL(graph); - auto mng = graph->manager(); - MS_EXCEPTION_IF_NULL(mng); - for (const auto &target_node : target_nodes) { - MS_EXCEPTION_IF_NULL(target_node); - MS_LOG(DEBUG) << "Rebuild a new target_node " << target_node->DebugString() << " with the new recomputed input"; - std::vector new_target_inputs; - for (auto &weak_input : target_node->weak_inputs()) { - auto input = weak_input.lock(); - MS_EXCEPTION_IF_NULL(input); - if (!input->isa()) { - (void)new_target_inputs.emplace_back(input); - } else { - auto input_cnode = input->cast(); - if (origin_recomputed_nodes.find(input_cnode) != origin_recomputed_nodes.end()) { - (void)new_target_inputs.emplace_back(NewRecomputedNode(graph, input_cnode, first_target_inputs, - origin_recomputed_nodes, origin_to_recomputed_nodes)); - } else { - (void)new_target_inputs.emplace_back(input_cnode); - } - } - } - auto new_target_node = graph->NewCNode(new_target_inputs); - new_target_node->CloneCNodeInfo(target_node); - new_target_node->AddAttr("target_grad", MakeValue(true)); - new_target_node->set_scope(target_node->scope()); - (void)mng->Replace(target_node, new_target_node); - (void)origin_to_new_target_nodes->emplace(target_node, new_target_node); - } -} -} // namespace mindspore +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "include/common/utils/recompute_helper.h" +#include +#include +#include +#include +#include +#include "mindspore/core/ops/sequence_ops.h" +#include "mindspore/core/ops/other_ops.h" +#include "mindspore/core/ops/nn_ops.h" +#include "mindspore/core/ops/framework_ops.h" +#include "mindspore/core/ops/array_op_name.h" +#include "include/common/utils/utils.h" + +namespace mindspore { +constexpr auto kGradientsFlag = "Gradients"; +const int64_t fusion_id_increasement_size = 2000; +bool CanNotRecomputed(const CNodePtr &node) { + static mindspore::HashSet not_recomputed_op_list{ + prim::kPrimDropoutGenMask, prim::kPrimLoad, prim::kPrimTupleGetItem, prim::kPrimSend, prim::kPrimReceive}; + + return std::any_of(not_recomputed_op_list.begin(), not_recomputed_op_list.end(), + [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }); +} + +bool IsBpropNode(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + return false; + } + if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { + return IsBpropNode(node->cast()->input(1)); + } + return node->fullname_with_scope().find(kGradientsFlag) == 0; +} + +bool WithRecomputedScope(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + return false; + } + auto full_name_with_scope = node->fullname_with_scope(); + return full_name_with_scope.find(kAttrRecompute) == 0; +} + +ValuePtr GetRecomputeCNodeAttr(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast_ptr(); + if (cnode == nullptr) { + return nullptr; + } + return cnode->GetAttr(kAttrRecompute); +} + +bool IsSetNoRecomputeCNodeAttr(const AnfNodePtr &node) { + auto cnode_recompute_val = GetRecomputeCNodeAttr(node); + return cnode_recompute_val != nullptr && cnode_recompute_val->isa() && !GetValue(cnode_recompute_val); +} + +bool IsSetRecomputeCNodeAttr(const AnfNodePtr &node) { + auto cnode_recompute_val = GetRecomputeCNodeAttr(node); + return cnode_recompute_val != nullptr && cnode_recompute_val->isa() && GetValue(cnode_recompute_val); +} + +bool IsCandidateRecomputedNode(const CNodePtr &node) { + // The tuple_getitem in the bprop function should also be recomputed. + return (!IsBpropNode(node) || IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) && IsSetRecomputeCNodeAttr(node); +} + +bool DynShapeOpInsertedInBprop(const AnfNodePtr &node) { + static const PrimitiveSet dyn_shape_ops_prim_set = {prim::kPrimShape, std::make_shared(kShapeCalcOpName)}; + if (IsOneOfPrimitiveCNode(node, dyn_shape_ops_prim_set)) { + return true; + } + return false; +} + +std::vector FindCandidateRecomputedNodes(const FuncGraphManagerPtr &mng, + const std::vector &cnodes) { + MS_EXCEPTION_IF_NULL(mng); + std::vector candidate_recomputed_nodes; + for (const auto &cnode : cnodes) { + MS_EXCEPTION_IF_NULL(cnode); + if (!IsCandidateRecomputedNode(cnode)) { + continue; + } + // Check outputs. + const auto &node_users = mng->node_users(); + auto output_set_iter = node_users.find(cnode); + if (output_set_iter == node_users.end()) { + continue; + } + const auto &node_index_set = output_set_iter->second; + if (!std::any_of(node_index_set.begin(), node_index_set.end(), + [](const auto &node_index) { return IsBpropNode(node_index.first); })) { + continue; + } + // Check inputs. + const auto &inputs = cnode->inputs(); + if (std::any_of(inputs.begin(), inputs.end(), [](const AnfNodePtr &node) { return IsBpropNode(node); })) { + continue; + } + (void)candidate_recomputed_nodes.emplace_back(cnode); + } + return candidate_recomputed_nodes; +} + +void GetMaxSubGraph(const FuncGraphManagerPtr &mng, mindspore::HashSet *recomputed_nodes, bool get_inputs, + bool get_outputs) { + MS_EXCEPTION_IF_NULL(mng); + MS_EXCEPTION_IF_NULL(recomputed_nodes); + std::queue nodes_to_visit; + for (const auto &node : *recomputed_nodes) { + nodes_to_visit.push(node); + } + recomputed_nodes->clear(); + while (!nodes_to_visit.empty()) { + auto current_node = nodes_to_visit.front(); + nodes_to_visit.pop(); + (void)recomputed_nodes->insert(current_node); + // No need to find nodes through side-effect dependency. + if (IsPrimitiveCNode(current_node, prim::kPrimUpdateState)) { + continue; + } + if (get_inputs) { + for (auto &weak_input : current_node->weak_inputs()) { + auto input = weak_input.lock(); + MS_EXCEPTION_IF_NULL(input); + if (input->isa()) { + auto input_cnode = input->cast(); + if (recomputed_nodes->find(input_cnode) == recomputed_nodes->end() && + IsCandidateRecomputedNode(input_cnode)) { + nodes_to_visit.push(input_cnode); + } + } + } + } + + if (get_outputs) { + const auto &node_users = mng->node_users(); + auto output_set_iter = node_users.find(current_node); + if (output_set_iter == node_users.end()) { + continue; + } + for (const auto &node_index_set : output_set_iter->second) { + auto output_node = node_index_set.first; + MS_EXCEPTION_IF_NULL(output_node); + if (output_node->isa()) { + auto output_cnode = output_node->cast(); + if (recomputed_nodes->find(output_cnode) == recomputed_nodes->end() && + IsCandidateRecomputedNode(output_cnode)) { + nodes_to_visit.push(output_cnode); + } + } + } + } + } +} + +void GetOriginRecomputeAndTargetNodes(const FuncGraphManagerPtr &mng, + const mindspore::HashSet &max_recomputed_sub_graph, + mindspore::HashSet *recompute_nodes, + mindspore::HashSet *target_nodes) { + MS_EXCEPTION_IF_NULL(mng); + MS_EXCEPTION_IF_NULL(recompute_nodes); + MS_EXCEPTION_IF_NULL(target_nodes); + const auto &node_users = mng->node_users(); + for (const auto &node : max_recomputed_sub_graph) { + bool inserted = false; + auto output_set_iter = node_users.find(node); + if (output_set_iter == node_users.end()) { + continue; + } + for (const auto &node_index_set : output_set_iter->second) { + auto output_node = node_index_set.first; + MS_EXCEPTION_IF_NULL(output_node); + // The tuple_getitem to be recomputed can be in the bprop function. + if (!IsBpropNode(output_node) || IsPrimitiveCNode(output_node, prim::kPrimTupleGetItem)) { + continue; + } + if (DynShapeOpInsertedInBprop(output_node)) { + continue; + } + (void)target_nodes->insert(output_node->cast()); + if (!inserted) { + (void)recompute_nodes->insert(node); + inserted = true; + } + } + } +} + +std::vector GetInputNodesWithFilter(const CNodePtr &node, std::function filter, + std::function push) { + auto func_graph = node->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + std::vector res; + std::queue cnode_queue; + cnode_queue.push(node); + while (!cnode_queue.empty()) { + auto queue_end = cnode_queue.front(); + cnode_queue.pop(); + auto input_nodes = queue_end->inputs(); + bool is_filtered = false; + for (size_t i = 1; i < input_nodes.size(); ++i) { + if (push(input_nodes[i])) { + res.push_back(queue_end); + is_filtered = true; + break; + } + } + if (!is_filtered) { + for (size_t i = 1; i < input_nodes.size(); ++i) { + if (!input_nodes[i]->isa()) { + continue; + } + if (filter(input_nodes[i])) { + continue; + } + cnode_queue.push(input_nodes[i]->cast()); + } + } + } + return res; +} + +void GetNewFirstTargetInputs(const std::vector &recompute_input_border_bprop_nodes, + std::function push_func, + std::vector *first_target_inputs, bool *inserted) { + for (const auto &input_border_bprop_node : recompute_input_border_bprop_nodes) { + MS_LOG(INFO) << "input_border_bprop_node:" << input_border_bprop_node->DebugString() + << ", the fullname:" << input_border_bprop_node->fullname_with_scope(); + if (!input_border_bprop_node->isa()) { + (void)(*first_target_inputs).emplace_back(input_border_bprop_node); + *inserted = true; + continue; + } + auto input_border_bprop_cnode = input_border_bprop_node->cast(); + for (size_t k = 1; k < input_border_bprop_cnode->size(); ++k) { + if (!push_func(input_border_bprop_cnode->input(k))) { + continue; + } + (void)(*first_target_inputs).emplace_back(input_border_bprop_cnode->input(k)); + *inserted = true; + } + } +} + +bool HasTargetOrRecomputeInputs(const mindspore::HashSet &recomputed_origin_nodes, + const mindspore::HashSet &target_nodes, const CNodePtr &node, + mindspore::HashMap *has_target_or_recompute_inputs_map) { + auto iter = has_target_or_recompute_inputs_map->find(node); + if (iter != has_target_or_recompute_inputs_map->end()) { + return iter->second; + } + if (recomputed_origin_nodes.find(node) != recomputed_origin_nodes.end()) { + (void)has_target_or_recompute_inputs_map->emplace(node, true); + return true; + } + if (target_nodes.find(node) != target_nodes.end()) { + (void)has_target_or_recompute_inputs_map->emplace(node, true); + return true; + } + + if (IsBpropNode(node) && !DynShapeOpInsertedInBprop(node)) { + for (auto &weak_input : node->weak_inputs()) { + auto input = weak_input.lock(); + MS_EXCEPTION_IF_NULL(input); + if (input->isa() && + HasTargetOrRecomputeInputs(recomputed_origin_nodes, target_nodes, input->cast(), + has_target_or_recompute_inputs_map)) { + (void)has_target_or_recompute_inputs_map->emplace(node, true); + return true; + } + } + } + (void)has_target_or_recompute_inputs_map->emplace(node, false); + return false; +} + +std::vector GetFirstTargetInputs(const std::vector &origin_nodes_topological, + const mindspore::HashSet &max_recomputed_sub_graph, + const mindspore::HashSet &recomputed_origin_nodes, + const mindspore::HashSet &target_nodes) { + std::vector first_target_inputs; + mindspore::HashMap has_grad_inputs_map; + auto filt_func = [&](const AnfNodePtr &anode) { + if (!anode->isa() || !anode->cast()->HasPrimalAttr(kPrimalAttrForwardUniqueId)) { + return true; + } + auto c_node = anode->cast(); + auto forward_unique_id = GetValue(c_node->GetPrimalAttr(kPrimalAttrForwardUniqueId)); + return std::find_if(max_recomputed_sub_graph.begin(), max_recomputed_sub_graph.end(), [&](CNodePtr r_cnode) { + return r_cnode->HasPrimalAttr(kPrimalAttrUniqueId) && + GetValue(r_cnode->GetPrimalAttr(kPrimalAttrUniqueId)) == forward_unique_id; + }) == max_recomputed_sub_graph.end(); + }; + auto push_func = [&](const AnfNodePtr &anode) { + if (!anode->isa() || !anode->cast()->HasPrimalAttr(kPrimalAttrForwardUniqueId)) { + return false; + } + auto c_node = anode->cast(); + auto forward_unique_id = GetValue(c_node->GetPrimalAttr(kPrimalAttrForwardUniqueId)); + return std::find_if(max_recomputed_sub_graph.begin(), max_recomputed_sub_graph.end(), [&](CNodePtr r_cnode) { + return r_cnode->HasPrimalAttr(kPrimalAttrUniqueId) && + GetValue(r_cnode->GetPrimalAttr(kPrimalAttrUniqueId)) == forward_unique_id; + }) == max_recomputed_sub_graph.end(); + }; + for (const auto &node : origin_nodes_topological) { + MS_EXCEPTION_IF_NULL(node); + if (target_nodes.find(node) == target_nodes.end()) { + continue; + } + for (size_t i = 1; i < node->size(); ++i) { + auto input = node->input(i); + MS_EXCEPTION_IF_NULL(input); + if (!input->isa()) { + continue; + } + auto input_cnode = input->cast(); + if (!IsBpropNode(input_cnode)) { + continue; + } + if (HasTargetOrRecomputeInputs(recomputed_origin_nodes, target_nodes, input_cnode, &has_grad_inputs_map)) { + continue; + } + + bool inserted = false; + for (size_t j = 1; j < input_cnode->size(); ++j) { + if (filt_func(input_cnode->input(j))) { + continue; + } + auto select_node = input_cnode->input(j)->cast(); + auto recompute_input_border_bprop_nodes = GetInputNodesWithFilter(select_node, filt_func, push_func); + if (recompute_input_border_bprop_nodes.empty()) { + (void)first_target_inputs.emplace_back(input); + inserted = true; + continue; + } + GetNewFirstTargetInputs(recompute_input_border_bprop_nodes, push_func, &first_target_inputs, &inserted); + } + if (!inserted) { + (void)first_target_inputs.emplace_back(input); + } + } + if (!first_target_inputs.empty()) { + break; + } + } + return first_target_inputs; +} + +bool HasGradInputs(const AnfNodePtr &node, mindspore::HashMap *has_grad_inputs_map) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(has_grad_inputs_map); + if (has_grad_inputs_map->find(node) != has_grad_inputs_map->end()) { + return has_grad_inputs_map->find(node)->second; + } + auto cnode = node->cast_ptr(); + if (cnode == nullptr) { + (void)has_grad_inputs_map->emplace(node, false); + return false; + } + const auto &inputs = cnode->inputs(); + for (size_t i = 0; i < inputs.size(); ++i) { + // For the pipeline split case, the forward pass may depend on the backward pass. + if (cnode->IsApply(prim::kPrimDepend) && i == kDependAttachNodeIndex) { + continue; + } + if (IsBpropNode(inputs[i]) || HasGradInputs(inputs[i], has_grad_inputs_map)) { + (void)has_grad_inputs_map->emplace(node, true); + return true; + } + } + (void)has_grad_inputs_map->emplace(node, false); + return false; +} + +bool HasForwardOutput(const FuncGraphManagerPtr &mng, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(mng); + const auto &node_users = mng->node_users(); + auto output_set_iter = node_users.find(node); + if (output_set_iter == node_users.end()) { + return false; + } + + return std::any_of(output_set_iter->second.begin(), output_set_iter->second.end(), + [](const auto &node_index_set) { return !IsBpropNode(node_index_set.first); }); +} + +void GetTupleGetItemOutputNodes(const FuncGraphManagerPtr &mng, const AnfNodePtr &node, + std::vector *tuple_getitem_output_nodes) { + MS_EXCEPTION_IF_NULL(mng); + MS_EXCEPTION_IF_NULL(tuple_getitem_output_nodes); + const auto &node_users = mng->node_users(); + auto output_set_iter = node_users.find(node); + if (output_set_iter == node_users.end()) { + return; + } + for (const auto &node_index_set : output_set_iter->second) { + if (IsPrimitiveCNode(node_index_set.first, prim::kPrimTupleGetItem)) { + (void)tuple_getitem_output_nodes->emplace_back(node_index_set.first); + } + } +} + +bool SetRecomputedScope(const CNodePtr &node) { + return WithRecomputedScope(node) || + (IsPrimitiveCNode(node, prim::kPrimDepend) && WithRecomputedScope(node->input(kRealInputIndexInDepend))); +} + +void SetCkptOffloadAttr(const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + node->AddAttr(kAttrCheckpoint, MakeValue(true)); +} + +// Set 'recompute' cnode attr for the nodes according to its scope. +// A node set 'recompute' cnode attr can become the candidate recomputed node. +void SetRecomputedAttr(const FuncGraphPtr &graph, const std::vector &origin_nodes_topological) { + MS_EXCEPTION_IF_NULL(graph); + auto mng = graph->manager(); + MS_EXCEPTION_IF_NULL(mng); + mindspore::HashMap has_grad_inputs_map; + for (const auto &node : origin_nodes_topological) { + MS_EXCEPTION_IF_NULL(node); + // The node may be set the non-recomputed before such as the cell outputs. + if (IsSetNoRecomputeCNodeAttr(node)) { + SetCkptOffloadAttr(node); + continue; + } + if (IsBpropNode(node)) { + SetCkptOffloadAttr(node); + continue; + } + // Filter some unrecomputable operators. + if (CanNotRecomputed(node)) { + SetCkptOffloadAttr(node); + continue; + } + if (!HasForwardOutput(mng, node) || HasGradInputs(node, &has_grad_inputs_map)) { + SetCkptOffloadAttr(node); + continue; + } + + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto prim = GetCNodePrimitive(cnode); + if (prim == nullptr) { + continue; + } + auto prim_recompute_attr = prim->GetAttr(kAttrRecompute); + int prim_recompute_val = -1; + if (prim_recompute_attr != nullptr && prim_recompute_attr->isa()) { + prim_recompute_val = static_cast(GetValue(prim_recompute_attr)); + } + if ((SetRecomputedScope(cnode) && prim_recompute_val != 0) || prim_recompute_val == 1) { + cnode->AddAttr(kAttrRecompute, MakeValue(true)); + } + if (!IsSetRecomputeCNodeAttr(node)) { + SetCkptOffloadAttr(node); + continue; + } + // Set attr for the tuple_getitem outputs. + std::vector tuple_getitem_output_nodes; + GetTupleGetItemOutputNodes(mng, node, &tuple_getitem_output_nodes); + for (const auto &output_node : tuple_getitem_output_nodes) { + auto output_cnode = output_node->cast_ptr(); + MS_EXCEPTION_IF_NULL(output_cnode); + output_cnode->AddAttr(kAttrRecompute, MakeValue(true)); + } + } +} + +CNodePtr CreateNewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &origin_node, + const std::vector &new_inputs) { + auto recomputed_node = graph->NewCNode(new_inputs); + MS_EXCEPTION_IF_NULL(recomputed_node); + recomputed_node->AddAttr(kAttrDuplicated, MakeValue(true)); + recomputed_node->AddAttr(kAttrNeedCseAfterRecompute, MakeValue(true)); + recomputed_node->set_abstract(origin_node->abstract()); + recomputed_node->set_scope(origin_node->scope()); + if (origin_node->HasPrimalAttr(kAttrMicro)) { + recomputed_node->AddPrimalAttr(kAttrMicro, origin_node->GetPrimalAttr(kAttrMicro)); + } + if (origin_node->HasPrimalAttr(kPrimalAttrForwardCommNodeUniqueId)) { + recomputed_node->AddPrimalAttr(kPrimalAttrForwardCommNodeUniqueId, + origin_node->GetPrimalAttr(kPrimalAttrForwardCommNodeUniqueId)); + } + if (origin_node->HasAttr(kAttrRecomputeSubGraph)) { + recomputed_node->AddAttr(kAttrRecomputeSubGraph, origin_node->GetAttr(kAttrRecomputeSubGraph)); + } + static int64_t recompute_id = 0; + ++recompute_id; + recomputed_node->AddAttr(kAttrRecomputeId, MakeValue(recompute_id)); + origin_node->AddAttr(kAttrRecomputeId, MakeValue(recompute_id)); + static const PrimitiveSet dropout_prims = {prim::kPrimDropout, prim::kPrimDropoutDoMask, prim::kPrimDropoutDoMaskV3}; + static const std::vector need_primal_attr = {kAttrFusion, kPrimalAttrUniqueId, + kPrimalAttrForwardUniqueId}; + if (IsOneOfPrimitiveCNode(origin_node, dropout_prims)) { + for (auto &primal_attr : need_primal_attr) { + if (origin_node->HasPrimalAttr(primal_attr)) { + recomputed_node->AddPrimalAttr(primal_attr, origin_node->GetPrimalAttr(primal_attr)); + } + } + } + return recomputed_node; +} + +CNodePtr NewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &origin_node, + const std::vector &first_target_inputs, + const mindspore::HashSet &recomputed_origin_nodes, + mindspore::HashMap *origin_to_recomputed_nodes) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(origin_node); + MS_EXCEPTION_IF_NULL(origin_to_recomputed_nodes); + auto iter = origin_to_recomputed_nodes->find(origin_node); + if (iter != origin_to_recomputed_nodes->end()) { + return iter->second; + } + MS_LOG(DEBUG) << "Begin to Duplicating origin recomputed node: " << origin_node->DebugString(); + std::vector new_inputs; + bool has_recomputed_inputs = false; + for (size_t i = 0; i < origin_node->size(); ++i) { + auto input = origin_node->input(i); + if (i == 0 && IsPrimitive(input, prim::kPrimAllGather)) { + auto prim = GetValuePtr(input); + auto instance_name = prim->instance_name(); + bool is_from_parallel_optimizer = instance_name.find("parallel_optimizer") != std::string::npos; + int64_t fusion_id = prim->HasAttr(kAttrFusion) ? GetValue(prim->GetAttr(kAttrFusion)) : 0; + if (is_from_parallel_optimizer && fusion_id > 0) { + auto new_prim = std::make_shared(prim::kPrimAllGather->name()); + (void)new_prim->SetAttrs(prim->attrs()); + new_prim->set_attr(kAttrFusion, MakeValue(fusion_id + fusion_id_increasement_size)); + new_prim->set_prim_type(prim->prim_type()); + new_prim->set_instance_name(instance_name); + auto value_node = NewValueNode(new_prim); + (void)new_inputs.emplace_back(value_node); + continue; + } + } + MS_EXCEPTION_IF_NULL(input); + if (!input->isa()) { + (void)new_inputs.emplace_back(input); + continue; + } + auto input_cnode = input->cast(); + if (recomputed_origin_nodes.find(input_cnode) == recomputed_origin_nodes.end()) { + if (IsPrimitiveCNode(input_cnode, prim::kPrimUpdateState)) { + auto u = NewValueNode(kUMonad); + u->set_abstract(kUMonad->ToAbstract()); + (void)new_inputs.emplace_back(u); + } else { + (void)new_inputs.emplace_back(input); + } + } else { + has_recomputed_inputs = true; + (void)new_inputs.emplace_back(NewRecomputedNode(graph, input_cnode, first_target_inputs, recomputed_origin_nodes, + origin_to_recomputed_nodes)); + } + } + // Add the execution dependency. + if (!has_recomputed_inputs && new_inputs.size() > 1) { + std::vector make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple)}; + (void)std::copy(first_target_inputs.begin(), first_target_inputs.end(), std::back_inserter(make_tuple_inputs)); + AbstractBasePtrList abstract_list; + (void)std::transform(first_target_inputs.begin(), first_target_inputs.end(), std::back_inserter(abstract_list), + [](const AnfNodePtr &node) -> AbstractBasePtr { return node->abstract(); }); + auto make_tuple = graph->NewCNode(make_tuple_inputs); + make_tuple->set_abstract(std::make_shared(abstract_list)); + auto first_input = new_inputs[1]; + MS_EXCEPTION_IF_NULL(first_input); + std::vector depend_inputs{NewValueNode(prim::kPrimDepend), first_input, make_tuple}; + auto depend_node = graph->NewCNode(depend_inputs); + MS_EXCEPTION_IF_NULL(depend_node); + depend_node->set_abstract(first_input->abstract()); + depend_node->AddAttr("recompute_depend", MakeValue(true)); + new_inputs[1] = depend_node; + } + auto recomputed_node = CreateNewRecomputedNode(graph, origin_node, new_inputs); + (void)origin_to_recomputed_nodes->emplace(origin_node, recomputed_node); + return recomputed_node; +} + +void DuplicateRecomputedNodes(const FuncGraphPtr &graph, const mindspore::HashSet &target_nodes, + const mindspore::HashSet &origin_recomputed_nodes, + const std::vector &first_target_inputs, + mindspore::HashMap *origin_to_new_target_nodes, + mindspore::HashMap *origin_to_recomputed_nodes) { + MS_EXCEPTION_IF_NULL(graph); + auto mng = graph->manager(); + MS_EXCEPTION_IF_NULL(mng); + for (const auto &target_node : target_nodes) { + MS_EXCEPTION_IF_NULL(target_node); + MS_LOG(DEBUG) << "Rebuild a new target_node " << target_node->DebugString() << " with the new recomputed input"; + std::vector new_target_inputs; + for (auto &weak_input : target_node->weak_inputs()) { + auto input = weak_input.lock(); + MS_EXCEPTION_IF_NULL(input); + if (!input->isa()) { + (void)new_target_inputs.emplace_back(input); + } else { + auto input_cnode = input->cast(); + if (origin_recomputed_nodes.find(input_cnode) != origin_recomputed_nodes.end()) { + (void)new_target_inputs.emplace_back(NewRecomputedNode(graph, input_cnode, first_target_inputs, + origin_recomputed_nodes, origin_to_recomputed_nodes)); + } else { + (void)new_target_inputs.emplace_back(input_cnode); + } + } + } + auto new_target_node = graph->NewCNode(new_target_inputs); + new_target_node->CloneCNodeInfo(target_node); + new_target_node->AddAttr("target_grad", MakeValue(true)); + new_target_node->set_scope(target_node->scope()); + (void)mng->Replace(target_node, new_target_node); + (void)origin_to_new_target_nodes->emplace(target_node, new_target_node); + } +} +} // namespace mindspore diff --git a/mindspore/core/abstract/ops/infer_functions.cc b/mindspore/core/abstract/ops/infer_functions.cc old mode 100755 new mode 100644 diff --git a/mindspore/core/mindrt/src/async/uuid_generator.cc b/mindspore/core/mindrt/src/async/uuid_generator.cc index 977117dda0b..82326e51ccd 100644 --- a/mindspore/core/mindrt/src/async/uuid_generator.cc +++ b/mindspore/core/mindrt/src/async/uuid_generator.cc @@ -1,55 +1,55 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "async/uuid_generator.h" -#include - -namespace mindspore { -namespace uuid_generator { -std::string UUID::ToString() { - std::ostringstream ret; - ret << *this; - return ret.str(); -} -} // namespace uuid_generator - -namespace localid_generator { -int GenLocalActorId() { - static std::atomic localActorId(0); - return localActorId.fetch_add(1); -} - -#ifdef HTTP_ENABLED -// not support muti-thread -int GenHttpClientConnId() { - static int httpClientConnId = 1; - if (httpClientConnId == INT_MAX) { - httpClientConnId = 1; - } - return httpClientConnId++; -} - -// not support muti-thread -int GenHttpServerConnId() { - static int httpServerConnId = 1; - if (httpServerConnId == INT_MAX) { - httpServerConnId = 1; - } - return httpServerConnId++; -} -#endif -} // namespace localid_generator -} // namespace mindspore +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "async/uuid_generator.h" +#include + +namespace mindspore { +namespace uuid_generator { +std::string UUID::ToString() { + std::ostringstream ret; + ret << *this; + return ret.str(); +} +} // namespace uuid_generator + +namespace localid_generator { +int GenLocalActorId() { + static std::atomic localActorId(0); + return localActorId.fetch_add(1); +} + +#ifdef HTTP_ENABLED +// not support muti-thread +int GenHttpClientConnId() { + static int httpClientConnId = 1; + if (httpClientConnId == INT_MAX) { + httpClientConnId = 1; + } + return httpClientConnId++; +} + +// not support muti-thread +int GenHttpServerConnId() { + static int httpServerConnId = 1; + if (httpServerConnId == INT_MAX) { + httpServerConnId = 1; + } + return httpServerConnId++; +} +#endif +} // namespace localid_generator +} // namespace mindspore diff --git a/mindspore/core/ops/adjust_saturation.h b/mindspore/core/ops/adjust_saturation.h index 22dfa642f3b..d00a71f5d3c 100644 --- a/mindspore/core/ops/adjust_saturation.h +++ b/mindspore/core/ops/adjust_saturation.h @@ -1,46 +1,46 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CORE_OPS_ADJUST_SATURATION_H_ -#define MINDSPORE_CORE_OPS_ADJUST_SATURATION_H_ -#include -#include -#include -#include "ops/base_operator.h" -#include "ops/op_utils.h" -#include "ops/primitive_c.h" -#include "utils/check_convert_utils.h" - -namespace mindspore { -namespace ops { -constexpr auto kNameAdjustSaturation = "AdjustSaturation"; -/// \brief Convert the images to HSV and multiply the saturation (S) channel by `scale` and clipping. -/// Refer to Python API @ref mindspore.ops.AdjustSaturation for more details. -class MIND_API AdjustSaturation : public BaseOperator { - public: - MIND_API_BASE_MEMBER(AdjustSaturation); - /// \brief Constructor. - AdjustSaturation() : BaseOperator(kNameAdjustSaturation) { InitIOName({"image", "scale"}, {"y"}); } -}; - -MIND_API abstract::AbstractBasePtr AdjustSaturationInfer(const abstract::AnalysisEnginePtr &, - const PrimitivePtr &primitive, - const std::vector &input_args); -using PrimAdjustSaturationPtr = std::shared_ptr; -} // namespace ops -} // namespace mindspore - -#endif // MINDSPORE_CORE_OPS_ADJUST_SATURATION_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_ADJUST_SATURATION_H_ +#define MINDSPORE_CORE_OPS_ADJUST_SATURATION_H_ +#include +#include +#include +#include "ops/base_operator.h" +#include "ops/op_utils.h" +#include "ops/primitive_c.h" +#include "utils/check_convert_utils.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameAdjustSaturation = "AdjustSaturation"; +/// \brief Convert the images to HSV and multiply the saturation (S) channel by `scale` and clipping. +/// Refer to Python API @ref mindspore.ops.AdjustSaturation for more details. +class MIND_API AdjustSaturation : public BaseOperator { + public: + MIND_API_BASE_MEMBER(AdjustSaturation); + /// \brief Constructor. + AdjustSaturation() : BaseOperator(kNameAdjustSaturation) { InitIOName({"image", "scale"}, {"y"}); } +}; + +MIND_API abstract::AbstractBasePtr AdjustSaturationInfer(const abstract::AnalysisEnginePtr &, + const PrimitivePtr &primitive, + const std::vector &input_args); +using PrimAdjustSaturationPtr = std::shared_ptr; +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_ADJUST_SATURATION_H_ diff --git a/mindspore/core/ops/bartlett_window.h b/mindspore/core/ops/bartlett_window.h index f9c0869ccf9..626519651ee 100644 --- a/mindspore/core/ops/bartlett_window.h +++ b/mindspore/core/ops/bartlett_window.h @@ -1,44 +1,44 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CORE_OPS_BARTLETT_WINDOW_H_ -#define MINDSPORE_CORE_OPS_BARTLETT_WINDOW_H_ -#include -#include -#include "mindapi/base/types.h" -#include "ops/base_operator.h" - -namespace mindspore { -namespace ops { -constexpr auto kNameBartlettWindow = "BartlettWindow"; -class MIND_API BartlettWindow : public BaseOperator { - public: - MIND_API_BASE_MEMBER(BartlettWindow); - BartlettWindow() : BaseOperator(kNameBartlettWindow) { InitIOName({"window_length"}, {"y"}); } - /// \brief Init. - void Init(const bool periodic = true); - /// \brief Set periodic. - void set_periodic(const bool periodic); - bool get_periodic() const; -}; - -MIND_API abstract::AbstractBasePtr BartlettWindowInfer(const abstract::AnalysisEnginePtr &, - const PrimitivePtr &primitive, - const std::vector &input_args); -} // namespace ops -} // namespace mindspore - -#endif // MINDSPORE_CORE_OPS_BARTLETTWINDOW_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_BARTLETT_WINDOW_H_ +#define MINDSPORE_CORE_OPS_BARTLETT_WINDOW_H_ +#include +#include +#include "mindapi/base/types.h" +#include "ops/base_operator.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameBartlettWindow = "BartlettWindow"; +class MIND_API BartlettWindow : public BaseOperator { + public: + MIND_API_BASE_MEMBER(BartlettWindow); + BartlettWindow() : BaseOperator(kNameBartlettWindow) { InitIOName({"window_length"}, {"y"}); } + /// \brief Init. + void Init(const bool periodic = true); + /// \brief Set periodic. + void set_periodic(const bool periodic); + bool get_periodic() const; +}; + +MIND_API abstract::AbstractBasePtr BartlettWindowInfer(const abstract::AnalysisEnginePtr &, + const PrimitivePtr &primitive, + const std::vector &input_args); +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_BARTLETTWINDOW_H_ diff --git a/mindspore/core/ops/bessel_j0.h b/mindspore/core/ops/bessel_j0.h old mode 100755 new mode 100644 diff --git a/mindspore/core/ops/blackman_window.h b/mindspore/core/ops/blackman_window.h index ee33b2ff53c..451f5db0c07 100644 --- a/mindspore/core/ops/blackman_window.h +++ b/mindspore/core/ops/blackman_window.h @@ -1,44 +1,44 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CORE_OPS_BLACKMAN_WINDOW_H_ -#define MINDSPORE_CORE_OPS_BLACKMAN_WINDOW_H_ -#include -#include -#include "mindapi/base/types.h" -#include "ops/base_operator.h" - -namespace mindspore { -namespace ops { -constexpr auto kNameBlackmanWindow = "BlackmanWindow"; -class MIND_API BlackmanWindow : public BaseOperator { - public: - MIND_API_BASE_MEMBER(BlackmanWindow); - BlackmanWindow() : BaseOperator(kNameBlackmanWindow) { InitIOName({"window_length"}, {"y"}); } - /// \brief Init. - void Init(const bool periodic = true); - /// \brief Set periodic. - void set_periodic(const bool periodic); - bool get_periodic() const; -}; - -MIND_API abstract::AbstractBasePtr BlackmanWindowInfer(const abstract::AnalysisEnginePtr &, - const PrimitivePtr &primitive, - const std::vector &input_args); -} // namespace ops -} // namespace mindspore - -#endif // MINDSPORE_CORE_OPS_BLACKMAN_WINDOW_H_ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_BLACKMAN_WINDOW_H_ +#define MINDSPORE_CORE_OPS_BLACKMAN_WINDOW_H_ +#include +#include +#include "mindapi/base/types.h" +#include "ops/base_operator.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameBlackmanWindow = "BlackmanWindow"; +class MIND_API BlackmanWindow : public BaseOperator { + public: + MIND_API_BASE_MEMBER(BlackmanWindow); + BlackmanWindow() : BaseOperator(kNameBlackmanWindow) { InitIOName({"window_length"}, {"y"}); } + /// \brief Init. + void Init(const bool periodic = true); + /// \brief Set periodic. + void set_periodic(const bool periodic); + bool get_periodic() const; +}; + +MIND_API abstract::AbstractBasePtr BlackmanWindowInfer(const abstract::AnalysisEnginePtr &, + const PrimitivePtr &primitive, + const std::vector &input_args); +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_BLACKMAN_WINDOW_H_ diff --git a/mindspore/core/ops/bounding_box_decode.h b/mindspore/core/ops/bounding_box_decode.h index c28e743aa8c..887989399e2 100644 --- a/mindspore/core/ops/bounding_box_decode.h +++ b/mindspore/core/ops/bounding_box_decode.h @@ -1,43 +1,43 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CORE_OPS_BOUNDING_BOX_DECODE_H_ -#define MINDSPORE_CORE_OPS_BOUNDING_BOX_DECODE_H_ - -#include -#include -#include -#include - -#include "mindapi/base/types.h" -#include "ops/base_operator.h" - -namespace mindspore { -namespace ops { -constexpr auto kNameBoundingBoxDecode = "BoundingBoxDecode"; -class MIND_API BoundingBoxDecode : public BaseOperator { - public: - MIND_API_BASE_MEMBER(BoundingBoxDecode); - BoundingBoxDecode() : BaseOperator(kNameBoundingBoxDecode) { InitIOName({"anchor_box", "deltas"}, {"output"}); } -}; - -MIND_API abstract::AbstractBasePtr BoundingBoxDecodeInfer(const abstract::AnalysisEnginePtr &, - const PrimitivePtr &primitive, - const std::vector &input_args); -} // namespace ops -} // namespace mindspore - -#endif // MINDSPORE_CORE_OPS_BOUNDING_BOX_DECODE_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_BOUNDING_BOX_DECODE_H_ +#define MINDSPORE_CORE_OPS_BOUNDING_BOX_DECODE_H_ + +#include +#include +#include +#include + +#include "mindapi/base/types.h" +#include "ops/base_operator.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameBoundingBoxDecode = "BoundingBoxDecode"; +class MIND_API BoundingBoxDecode : public BaseOperator { + public: + MIND_API_BASE_MEMBER(BoundingBoxDecode); + BoundingBoxDecode() : BaseOperator(kNameBoundingBoxDecode) { InitIOName({"anchor_box", "deltas"}, {"output"}); } +}; + +MIND_API abstract::AbstractBasePtr BoundingBoxDecodeInfer(const abstract::AnalysisEnginePtr &, + const PrimitivePtr &primitive, + const std::vector &input_args); +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_BOUNDING_BOX_DECODE_H_ diff --git a/mindspore/core/ops/cumulative_logsumexp.h b/mindspore/core/ops/cumulative_logsumexp.h index 0da1fb46e7d..92be32dac27 100644 --- a/mindspore/core/ops/cumulative_logsumexp.h +++ b/mindspore/core/ops/cumulative_logsumexp.h @@ -1,44 +1,44 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CORE_OPS_CUMULATIVE_LOGSUMEXP_H_ -#define MINDSPORE_CORE_OPS_CUMULATIVE_LOGSUMEXP_H_ -#include -#include -#include -#include - -#include "mindapi/base/types.h" -#include "ops/base_operator.h" - -namespace mindspore { -namespace ops { -constexpr auto kNameCumulativeLogsumexp = "CumulativeLogsumexp"; -class MIND_API CumulativeLogsumexp : public BaseOperator { - public: - MIND_API_BASE_MEMBER(CumulativeLogsumexp); - CumulativeLogsumexp() : BaseOperator(kNameCumulativeLogsumexp) { InitIOName({"x", "axis"}, {"y"}); } - void Init() const {} - bool get_exclusive() const; - bool get_reverse() const; - int64_t get_axis() const; -}; - -MIND_API abstract::AbstractBasePtr CumulativeLogsumexpInfer(const abstract::AnalysisEnginePtr &, - const PrimitivePtr &primitive, - const std::vector &input_args); -} // namespace ops -} // namespace mindspore -#endif // MINDSPORE_CORE_OPS_CUMULATIVE_LOGSUMEXP_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CORE_OPS_CUMULATIVE_LOGSUMEXP_H_ +#define MINDSPORE_CORE_OPS_CUMULATIVE_LOGSUMEXP_H_ +#include +#include +#include +#include + +#include "mindapi/base/types.h" +#include "ops/base_operator.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameCumulativeLogsumexp = "CumulativeLogsumexp"; +class MIND_API CumulativeLogsumexp : public BaseOperator { + public: + MIND_API_BASE_MEMBER(CumulativeLogsumexp); + CumulativeLogsumexp() : BaseOperator(kNameCumulativeLogsumexp) { InitIOName({"x", "axis"}, {"y"}); } + void Init() const {} + bool get_exclusive() const; + bool get_reverse() const; + int64_t get_axis() const; +}; + +MIND_API abstract::AbstractBasePtr CumulativeLogsumexpInfer(const abstract::AnalysisEnginePtr &, + const PrimitivePtr &primitive, + const std::vector &input_args); +} // namespace ops +} // namespace mindspore +#endif // MINDSPORE_CORE_OPS_CUMULATIVE_LOGSUMEXP_H_ diff --git a/mindspore/core/ops/expand.h b/mindspore/core/ops/expand.h index 4a52b3c1c1b..70a663f00f5 100644 --- a/mindspore/core/ops/expand.h +++ b/mindspore/core/ops/expand.h @@ -1,47 +1,47 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CORE_OPS_EXPAND_H_ -#define MINDSPORE_CORE_OPS_EXPAND_H_ -#include -#include -#include - -#include "mindapi/base/types.h" -#include "ops/base_operator.h" - -namespace mindspore { -namespace ops { -constexpr auto kNameExpand = "Expand"; -/// \brief Expand the tensor ‘x‘ to the size of ‘shape‘ -///// Refer to Python API @ref mindspore.ops.Expand for more details. -class MIND_API Expand : public BaseOperator { - public: - MIND_API_BASE_MEMBER(Expand); - /// \brief Constructor. - Expand() : BaseOperator(kNameExpand) { InitIOName({"x", "shape"}, {"y"}); } - /// \brief Destructor. - void Init() const {} -}; - -MIND_API abstract::AbstractBasePtr ExpandInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args); - -using PrimExpand = std::shared_ptr; -} // namespace ops -} // namespace mindspore - -#endif // MINDSPORE_CORE_OPS_EXPAND_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_EXPAND_H_ +#define MINDSPORE_CORE_OPS_EXPAND_H_ +#include +#include +#include + +#include "mindapi/base/types.h" +#include "ops/base_operator.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameExpand = "Expand"; +/// \brief Expand the tensor ‘x‘ to the size of ‘shape‘ +///// Refer to Python API @ref mindspore.ops.Expand for more details. +class MIND_API Expand : public BaseOperator { + public: + MIND_API_BASE_MEMBER(Expand); + /// \brief Constructor. + Expand() : BaseOperator(kNameExpand) { InitIOName({"x", "shape"}, {"y"}); } + /// \brief Destructor. + void Init() const {} +}; + +MIND_API abstract::AbstractBasePtr ExpandInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); + +using PrimExpand = std::shared_ptr; +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_EXPAND_H_ diff --git a/mindspore/core/ops/fill_diagonal.h b/mindspore/core/ops/fill_diagonal.h index 5cdab6b412b..012a603d4fc 100644 --- a/mindspore/core/ops/fill_diagonal.h +++ b/mindspore/core/ops/fill_diagonal.h @@ -1,56 +1,56 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CORE_OPS_FILL_DIAGONAL_H_ -#define MINDSPORE_CORE_OPS_FILL_DIAGONAL_H_ - -#include -#include -#include -#include -#include - -#include "mindapi/base/types.h" -#include "ops/base_operator.h" - -namespace mindspore { -namespace ops { -constexpr auto kNameFillDiagonal = "FillDiagonal"; -/// \brief Fill the main diagonal of a tensor that has at least 2-dimensions. -/// Refer to Python API @ref mindspore.ops.FillDiagonal for more details. -class MIND_API FillDiagonal : public BaseOperator { - public: - MIND_API_BASE_MEMBER(FillDiagonal); - /// \brief Constructor. - FillDiagonal() : BaseOperator(kNameFillDiagonal) { InitIOName({"input_x"}, {"y"}); } - - /// \brief Init. - void Init(const float fill_value = 0.0, const bool wrap = false); - /// \brief Set fill_value & wrap. - void set_fill_value(const float fill_value); - void set_wrap(const bool wrap); - - /// \brief Get fill_value & wrap. - float get_fill_value() const; - bool get_wrap() const; -}; - -MIND_API abstract::AbstractBasePtr FillDiagonalInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args); -} // namespace ops -} // namespace mindspore - -#endif // MINDSPORE_CORE_OPS_FILL_DIAGONAL_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_FILL_DIAGONAL_H_ +#define MINDSPORE_CORE_OPS_FILL_DIAGONAL_H_ + +#include +#include +#include +#include +#include + +#include "mindapi/base/types.h" +#include "ops/base_operator.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameFillDiagonal = "FillDiagonal"; +/// \brief Fill the main diagonal of a tensor that has at least 2-dimensions. +/// Refer to Python API @ref mindspore.ops.FillDiagonal for more details. +class MIND_API FillDiagonal : public BaseOperator { + public: + MIND_API_BASE_MEMBER(FillDiagonal); + /// \brief Constructor. + FillDiagonal() : BaseOperator(kNameFillDiagonal) { InitIOName({"input_x"}, {"y"}); } + + /// \brief Init. + void Init(const float fill_value = 0.0, const bool wrap = false); + /// \brief Set fill_value & wrap. + void set_fill_value(const float fill_value); + void set_wrap(const bool wrap); + + /// \brief Get fill_value & wrap. + float get_fill_value() const; + bool get_wrap() const; +}; + +MIND_API abstract::AbstractBasePtr FillDiagonalInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_FILL_DIAGONAL_H_ diff --git a/mindspore/core/ops/fusion/scale_grad_fusion.h b/mindspore/core/ops/fusion/scale_grad_fusion.h index 90f10db1b71..96fac999846 100644 --- a/mindspore/core/ops/fusion/scale_grad_fusion.h +++ b/mindspore/core/ops/fusion/scale_grad_fusion.h @@ -1,44 +1,44 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CORE_OPS_SCALE_GRAD_FUSION_H_ -#define MINDSPORE_CORE_OPS_SCALE_GRAD_FUSION_H_ -#include -#include -#include -#include -#include "mindapi/base/types.h" -#include "ops/base_operator.h" - -namespace mindspore { -namespace ops { -constexpr auto kNameScaleGradFusion = "ScaleGrad"; -/// \brief SliceFusion defined Slice operator prototype of lite. -class MIND_API ScaleGrad : public BaseOperator { - public: - MIND_API_BASE_MEMBER(ScaleGrad); - /// \brief Constructor. - ScaleGrad() : BaseOperator(kNameScaleGradFusion) { InitIOName({"x"}, {"y"}); } - - /// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.ScaleGrad for the inputs. - void Init() const {} -}; -abstract::AbstractBasePtr ScaleGradInferShapeAndType(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args); -} // namespace ops -} // namespace mindspore - -#endif // MINDSPORE_CORE_OPS_SCALE_GRAD_FUSION_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_SCALE_GRAD_FUSION_H_ +#define MINDSPORE_CORE_OPS_SCALE_GRAD_FUSION_H_ +#include +#include +#include +#include +#include "mindapi/base/types.h" +#include "ops/base_operator.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameScaleGradFusion = "ScaleGrad"; +/// \brief SliceFusion defined Slice operator prototype of lite. +class MIND_API ScaleGrad : public BaseOperator { + public: + MIND_API_BASE_MEMBER(ScaleGrad); + /// \brief Constructor. + ScaleGrad() : BaseOperator(kNameScaleGradFusion) { InitIOName({"x"}, {"y"}); } + + /// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.ScaleGrad for the inputs. + void Init() const {} +}; +abstract::AbstractBasePtr ScaleGradInferShapeAndType(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_SCALE_GRAD_FUSION_H_ diff --git a/mindspore/core/ops/grad/adaptive_max_pool2d_grad.h b/mindspore/core/ops/grad/adaptive_max_pool2d_grad.h index 4d001817cec..983c9c27500 100644 --- a/mindspore/core/ops/grad/adaptive_max_pool2d_grad.h +++ b/mindspore/core/ops/grad/adaptive_max_pool2d_grad.h @@ -1,45 +1,45 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CORE_OPS_ADAPTIVE_MAX_POOL_2D_GRAD_H_ -#define MINDSPORE_CORE_OPS_ADAPTIVE_MAX_POOL_2D_GRAD_H_ -#include -#include -#include -#include - -#include "mindapi/base/types.h" -#include "ops/base_operator.h" - -namespace mindspore { -namespace ops { -constexpr auto kNameAdaptiveMaxPool2DGrad = "AdaptiveMaxPool2DGrad"; - -class MIND_API AdaptiveMaxPool2DGrad : public BaseOperator { - public: - MIND_API_BASE_MEMBER(AdaptiveMaxPool2DGrad); - - AdaptiveMaxPool2DGrad() : BaseOperator(kNameAdaptiveMaxPool2DGrad) { - InitIOName({"y_grad", "x", "argmax"}, {"x_grad"}); - } -}; -MIND_API abstract::AbstractBasePtr AdaptiveMaxPool2DGradInfer(const abstract::AnalysisEnginePtr &, - const PrimitivePtr &primitive, - const std::vector &input_args); -} // namespace ops -} // namespace mindspore - -#endif // MINDSPORE_CORE_OPS_ADAPTIVE_MAX_POOL_2D_GRAD_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_ADAPTIVE_MAX_POOL_2D_GRAD_H_ +#define MINDSPORE_CORE_OPS_ADAPTIVE_MAX_POOL_2D_GRAD_H_ +#include +#include +#include +#include + +#include "mindapi/base/types.h" +#include "ops/base_operator.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameAdaptiveMaxPool2DGrad = "AdaptiveMaxPool2DGrad"; + +class MIND_API AdaptiveMaxPool2DGrad : public BaseOperator { + public: + MIND_API_BASE_MEMBER(AdaptiveMaxPool2DGrad); + + AdaptiveMaxPool2DGrad() : BaseOperator(kNameAdaptiveMaxPool2DGrad) { + InitIOName({"y_grad", "x", "argmax"}, {"x_grad"}); + } +}; +MIND_API abstract::AbstractBasePtr AdaptiveMaxPool2DGradInfer(const abstract::AnalysisEnginePtr &, + const PrimitivePtr &primitive, + const std::vector &input_args); +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_ADAPTIVE_MAX_POOL_2D_GRAD_H_ diff --git a/mindspore/core/ops/grad/fractional_max_pool_grad_with_fixed_ksize.h b/mindspore/core/ops/grad/fractional_max_pool_grad_with_fixed_ksize.h index be3be151726..359c5448efb 100644 --- a/mindspore/core/ops/grad/fractional_max_pool_grad_with_fixed_ksize.h +++ b/mindspore/core/ops/grad/fractional_max_pool_grad_with_fixed_ksize.h @@ -1,57 +1,57 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CORE_OPS_FRACTIONAL_MAX_POOL_GRAD_WITH_FIXED_KSIZE_H_ -#define MINDSPORE_CORE_OPS_FRACTIONAL_MAX_POOL_GRAD_WITH_FIXED_KSIZE_H_ - -#include -#include -#include -#include - -#include "mindapi/base/types.h" -#include "ops/base_operator.h" - -namespace mindspore { -namespace ops { -constexpr auto kNameFractionalMaxPoolGradWithFixedKsize = "FractionalMaxPoolGradWithFixedKsize"; -class MIND_API FractionalMaxPoolGradWithFixedKsize : public BaseOperator { - public: - MIND_API_BASE_MEMBER(FractionalMaxPoolGradWithFixedKsize); - FractionalMaxPoolGradWithFixedKsize() : BaseOperator(kNameFractionalMaxPoolGradWithFixedKsize) { - InitIOName({"origin_input", "out_backprop", "argmax"}, {"y"}); - } - std::vector InputDynamic(const std::vector &out_backprop_shape_, - const std::vector &argmax_shape_, - const std::vector &origin_input_shape_, bool out_backprop_shape_dy_, - bool argmax_shape_dy_, bool origin_input_shape_dy_); - void Init(const std::string data_format); - /// \brief Init. Refer to the parameters of Python API @ref - /// mindspore.ops.operations._grad_ops.FractionalMaxPoolWithFixedKsize for the inputs. - void set_data_format(const std::string data_format); - /// \brief Set data format. - std::string get_data_format() const; - /// \brief Method to get data format attributes. - /// - /// \return data format attributes. -}; - -MIND_API abstract::AbstractBasePtr FractionalMaxPoolGradWithFixedKsizeInfer( - const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args); -} // namespace ops -} // namespace mindspore -#endif // MINDSPORE_CORE_OPS_FRACTIONAL_MAX_POOL_GRAD_WITH_FIXED_KSIZE_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_FRACTIONAL_MAX_POOL_GRAD_WITH_FIXED_KSIZE_H_ +#define MINDSPORE_CORE_OPS_FRACTIONAL_MAX_POOL_GRAD_WITH_FIXED_KSIZE_H_ + +#include +#include +#include +#include + +#include "mindapi/base/types.h" +#include "ops/base_operator.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameFractionalMaxPoolGradWithFixedKsize = "FractionalMaxPoolGradWithFixedKsize"; +class MIND_API FractionalMaxPoolGradWithFixedKsize : public BaseOperator { + public: + MIND_API_BASE_MEMBER(FractionalMaxPoolGradWithFixedKsize); + FractionalMaxPoolGradWithFixedKsize() : BaseOperator(kNameFractionalMaxPoolGradWithFixedKsize) { + InitIOName({"origin_input", "out_backprop", "argmax"}, {"y"}); + } + std::vector InputDynamic(const std::vector &out_backprop_shape_, + const std::vector &argmax_shape_, + const std::vector &origin_input_shape_, bool out_backprop_shape_dy_, + bool argmax_shape_dy_, bool origin_input_shape_dy_); + void Init(const std::string data_format); + /// \brief Init. Refer to the parameters of Python API @ref + /// mindspore.ops.operations._grad_ops.FractionalMaxPoolWithFixedKsize for the inputs. + void set_data_format(const std::string data_format); + /// \brief Set data format. + std::string get_data_format() const; + /// \brief Method to get data format attributes. + /// + /// \return data format attributes. +}; + +MIND_API abstract::AbstractBasePtr FractionalMaxPoolGradWithFixedKsizeInfer( + const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +} // namespace ops +} // namespace mindspore +#endif // MINDSPORE_CORE_OPS_FRACTIONAL_MAX_POOL_GRAD_WITH_FIXED_KSIZE_H_ diff --git a/mindspore/core/ops/grad/glu_grad.h b/mindspore/core/ops/grad/glu_grad.h index 2f0412a0765..42b17fce0cb 100644 --- a/mindspore/core/ops/grad/glu_grad.h +++ b/mindspore/core/ops/grad/glu_grad.h @@ -1,40 +1,40 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CORE_OPS_GLU_GRAD_H_ -#define MINDSPORE_CORE_OPS_GLU_GRAD_H_ -#include -#include -#include -#include -#include "mindapi/base/types.h" -#include "ops/base_operator.h" - -namespace mindspore { -namespace ops { -constexpr auto kNameGluGrad = "GluGrad"; - -class MIND_API GluGrad : public BaseOperator { - public: - MIND_API_BASE_MEMBER(GluGrad); - GluGrad() : BaseOperator(kNameGluGrad) { InitIOName({"grads", "x"}, {"y"}); } - void Init() const {} -}; - -MIND_API abstract::AbstractBasePtr GluGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args); -} // namespace ops -} // namespace mindspore -#endif // MINDSPORE_CORE_OPS_GLU_GRAD_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CORE_OPS_GLU_GRAD_H_ +#define MINDSPORE_CORE_OPS_GLU_GRAD_H_ +#include +#include +#include +#include +#include "mindapi/base/types.h" +#include "ops/base_operator.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameGluGrad = "GluGrad"; + +class MIND_API GluGrad : public BaseOperator { + public: + MIND_API_BASE_MEMBER(GluGrad); + GluGrad() : BaseOperator(kNameGluGrad) { InitIOName({"grads", "x"}, {"y"}); } + void Init() const {} +}; + +MIND_API abstract::AbstractBasePtr GluGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +} // namespace ops +} // namespace mindspore +#endif // MINDSPORE_CORE_OPS_GLU_GRAD_H_ diff --git a/mindspore/core/ops/grad/max_unpool2d_grad.h b/mindspore/core/ops/grad/max_unpool2d_grad.h index 7a6ee9fcc84..700d533b0ea 100644 --- a/mindspore/core/ops/grad/max_unpool2d_grad.h +++ b/mindspore/core/ops/grad/max_unpool2d_grad.h @@ -1,43 +1,43 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CORE_OPS_MAXUNPOOL2DGRAD_H_ -#define MINDSPORE_CORE_OPS_MAXUNPOOL2DGRAD_H_ -#include -#include -#include - -#include "mindapi/base/types.h" -#include "ops/base_operator.h" - -namespace mindspore { -namespace ops { -constexpr auto kNameMaxUnpool2DGrad = "MaxUnpool2DGrad"; -class MIND_API MaxUnpool2DGrad : public BaseOperator { - public: - MIND_API_BASE_MEMBER(MaxUnpool2DGrad); - MaxUnpool2DGrad() : BaseOperator(kNameMaxUnpool2DGrad) { InitIOName({"x", "grads", "argmax"}, {"y"}); } - std::string get_format() const; -}; - -MIND_API abstract::AbstractBasePtr MaxUnpool2DGradInfer(const abstract::AnalysisEnginePtr &, - const PrimitivePtr &primitive, - const std::vector &input_args); -using PrimMaxUnpool2DGradPtr = std::shared_ptr; -} // namespace ops -} // namespace mindspore - -#endif // MINDSPORE_CORE_OPS_MAXUNPOOL2DGRAD_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_MAXUNPOOL2DGRAD_H_ +#define MINDSPORE_CORE_OPS_MAXUNPOOL2DGRAD_H_ +#include +#include +#include + +#include "mindapi/base/types.h" +#include "ops/base_operator.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameMaxUnpool2DGrad = "MaxUnpool2DGrad"; +class MIND_API MaxUnpool2DGrad : public BaseOperator { + public: + MIND_API_BASE_MEMBER(MaxUnpool2DGrad); + MaxUnpool2DGrad() : BaseOperator(kNameMaxUnpool2DGrad) { InitIOName({"x", "grads", "argmax"}, {"y"}); } + std::string get_format() const; +}; + +MIND_API abstract::AbstractBasePtr MaxUnpool2DGradInfer(const abstract::AnalysisEnginePtr &, + const PrimitivePtr &primitive, + const std::vector &input_args); +using PrimMaxUnpool2DGradPtr = std::shared_ptr; +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_MAXUNPOOL2DGRAD_H_ diff --git a/mindspore/core/ops/grad/max_unpool3d_grad.h b/mindspore/core/ops/grad/max_unpool3d_grad.h index 61d7af222a0..09c1fa3224f 100644 --- a/mindspore/core/ops/grad/max_unpool3d_grad.h +++ b/mindspore/core/ops/grad/max_unpool3d_grad.h @@ -1,43 +1,43 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CORE_OPS_MAXUNPOOL3DGRAD_H_ -#define MINDSPORE_CORE_OPS_MAXUNPOOL3DGRAD_H_ -#include -#include -#include - -#include "mindapi/base/types.h" -#include "ops/base_operator.h" - -namespace mindspore { -namespace ops { -constexpr auto kNameMaxUnpool3DGrad = "MaxUnpool3DGrad"; -class MIND_API MaxUnpool3DGrad : public BaseOperator { - public: - MIND_API_BASE_MEMBER(MaxUnpool3DGrad); - MaxUnpool3DGrad() : BaseOperator(kNameMaxUnpool3DGrad) { InitIOName({"x", "grads", "argmax"}, {"y"}); } - std::string get_format() const; -}; - -MIND_API abstract::AbstractBasePtr MaxUnpool3DGradInfer(const abstract::AnalysisEnginePtr &, - const PrimitivePtr &primitive, - const std::vector &input_args); -using PrimMaxUnpool3DGradPtr = std::shared_ptr; -} // namespace ops -} // namespace mindspore - -#endif // MINDSPORE_CORE_OPS_MAXUNPOOL3DGRAD_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_MAXUNPOOL3DGRAD_H_ +#define MINDSPORE_CORE_OPS_MAXUNPOOL3DGRAD_H_ +#include +#include +#include + +#include "mindapi/base/types.h" +#include "ops/base_operator.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameMaxUnpool3DGrad = "MaxUnpool3DGrad"; +class MIND_API MaxUnpool3DGrad : public BaseOperator { + public: + MIND_API_BASE_MEMBER(MaxUnpool3DGrad); + MaxUnpool3DGrad() : BaseOperator(kNameMaxUnpool3DGrad) { InitIOName({"x", "grads", "argmax"}, {"y"}); } + std::string get_format() const; +}; + +MIND_API abstract::AbstractBasePtr MaxUnpool3DGradInfer(const abstract::AnalysisEnginePtr &, + const PrimitivePtr &primitive, + const std::vector &input_args); +using PrimMaxUnpool3DGradPtr = std::shared_ptr; +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_MAXUNPOOL3DGRAD_H_ diff --git a/mindspore/core/ops/grad/p_s_r_o_i_pooling_grad.h b/mindspore/core/ops/grad/p_s_r_o_i_pooling_grad.h index 8a45444934c..1167514c2ad 100644 --- a/mindspore/core/ops/grad/p_s_r_o_i_pooling_grad.h +++ b/mindspore/core/ops/grad/p_s_r_o_i_pooling_grad.h @@ -1,43 +1,43 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_P_S_R_O_I_POOLING_GRAD_H -#define MINDSPORE_P_S_R_O_I_POOLING_GRAD_H - -#include -#include -#include -#include -#include "mindapi/base/types.h" -#include "ops/base_operator.h" - -namespace mindspore { -namespace ops { -constexpr auto kNamePSROIPoolingGrad = "PSROIPoolingGrad"; - -class MIND_API PSROIPoolingGrad : public BaseOperator { - public: - MIND_API_BASE_MEMBER(PSROIPoolingGrad); - PSROIPoolingGrad() : BaseOperator(kNamePSROIPoolingGrad) {} -}; - -MIND_API abstract::AbstractBasePtr PSROIPoolingGradInfer(const abstract::AnalysisEnginePtr &, - const PrimitivePtr &primitive, - const std::vector &input_args); -} // namespace ops -} // namespace mindspore - -#endif // MINDSPORE_P_S_R_O_I_POOLING_GRAD_H +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_P_S_R_O_I_POOLING_GRAD_H +#define MINDSPORE_P_S_R_O_I_POOLING_GRAD_H + +#include +#include +#include +#include +#include "mindapi/base/types.h" +#include "ops/base_operator.h" + +namespace mindspore { +namespace ops { +constexpr auto kNamePSROIPoolingGrad = "PSROIPoolingGrad"; + +class MIND_API PSROIPoolingGrad : public BaseOperator { + public: + MIND_API_BASE_MEMBER(PSROIPoolingGrad); + PSROIPoolingGrad() : BaseOperator(kNamePSROIPoolingGrad) {} +}; + +MIND_API abstract::AbstractBasePtr PSROIPoolingGradInfer(const abstract::AnalysisEnginePtr &, + const PrimitivePtr &primitive, + const std::vector &input_args); +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_P_S_R_O_I_POOLING_GRAD_H diff --git a/mindspore/core/ops/grad/sparse_segment_sum_grad.h b/mindspore/core/ops/grad/sparse_segment_sum_grad.h index e2391b3c7a1..5386774131a 100644 --- a/mindspore/core/ops/grad/sparse_segment_sum_grad.h +++ b/mindspore/core/ops/grad/sparse_segment_sum_grad.h @@ -1,47 +1,47 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_GRAD_H_ -#define MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_GRAD_H_ - -#include -#include -#include -#include -#include -#include -#include "mindapi/base/types.h" -#include "ops/base_operator.h" - -namespace mindspore { -namespace ops { -constexpr auto kNameSparseSegmentSumGrad = "SparseSegmentSumGrad"; -class MIND_API SparseSegmentSumGrad : public BaseOperator { - public: - MIND_API_BASE_MEMBER(SparseSegmentSumGrad); - SparseSegmentSumGrad() : BaseOperator(kNameSparseSegmentSumGrad) { - InitIOName({"grad", "indices", "segment_ids", "output_dim0"}, {"output"}); - } -}; - -MIND_API abstract::AbstractBasePtr SparseSegmentSumGradInfer(const abstract::AnalysisEnginePtr &, - const PrimitivePtr &primitive, - const std::vector &input_args); -using PrimSparseSegmentSumGradPtr = std::shared_ptr; -} // namespace ops -} // namespace mindspore - -#endif // MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SQRT_N_GRAD_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_GRAD_H_ +#define MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_GRAD_H_ + +#include +#include +#include +#include +#include +#include +#include "mindapi/base/types.h" +#include "ops/base_operator.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameSparseSegmentSumGrad = "SparseSegmentSumGrad"; +class MIND_API SparseSegmentSumGrad : public BaseOperator { + public: + MIND_API_BASE_MEMBER(SparseSegmentSumGrad); + SparseSegmentSumGrad() : BaseOperator(kNameSparseSegmentSumGrad) { + InitIOName({"grad", "indices", "segment_ids", "output_dim0"}, {"output"}); + } +}; + +MIND_API abstract::AbstractBasePtr SparseSegmentSumGradInfer(const abstract::AnalysisEnginePtr &, + const PrimitivePtr &primitive, + const std::vector &input_args); +using PrimSparseSegmentSumGradPtr = std::shared_ptr; +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SQRT_N_GRAD_H_ diff --git a/mindspore/core/ops/grad/trace_grad.h b/mindspore/core/ops/grad/trace_grad.h index 11cc49245a6..32475fdf88a 100644 --- a/mindspore/core/ops/grad/trace_grad.h +++ b/mindspore/core/ops/grad/trace_grad.h @@ -1,38 +1,38 @@ -/** - * Copyright 2020-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CORE_OPS_TRACE_GRAD_H_ -#define MINDSPORE_CORE_OPS_TRACE_GRAD_H_ -#include -#include - -#include "mindapi/base/types.h" -#include "ops/base_operator.h" - -namespace mindspore { -namespace ops { -constexpr auto kNameTraceGrad = "TraceGrad"; -class MIND_API TraceGrad : public BaseOperator { - public: - MIND_API_BASE_MEMBER(TraceGrad); - TraceGrad() : BaseOperator(kNameTraceGrad) {} - void Init() const {} -}; -MIND_API abstract::AbstractBasePtr TraceGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args); -} // namespace ops -} // namespace mindspore -#endif // MINDSPORE_CORE_OPS_TRACE_GRAD_H_ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_TRACE_GRAD_H_ +#define MINDSPORE_CORE_OPS_TRACE_GRAD_H_ +#include +#include + +#include "mindapi/base/types.h" +#include "ops/base_operator.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameTraceGrad = "TraceGrad"; +class MIND_API TraceGrad : public BaseOperator { + public: + MIND_API_BASE_MEMBER(TraceGrad); + TraceGrad() : BaseOperator(kNameTraceGrad) {} + void Init() const {} +}; +MIND_API abstract::AbstractBasePtr TraceGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +} // namespace ops +} // namespace mindspore +#endif // MINDSPORE_CORE_OPS_TRACE_GRAD_H_ diff --git a/mindspore/core/ops/lstsq.h b/mindspore/core/ops/lstsq.h index e96ed114ad3..fe7b3d81edd 100644 --- a/mindspore/core/ops/lstsq.h +++ b/mindspore/core/ops/lstsq.h @@ -1,39 +1,39 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License - */ - -#ifndef MINDSPORE_CORE_OPS_LSTSQ_H_ -#define MINDSPORE_CORE_OPS_LSTSQ_H_ -#include -#include -#include - -#include "mindapi/base/types.h" -#include "ops/base_operator.h" - -namespace mindspore { -namespace ops { -constexpr auto kNameLstsq = "Lstsq"; -class MIND_API Lstsq : public BaseOperator { - public: - MIND_API_BASE_MEMBER(Lstsq); - Lstsq() : BaseOperator(kNameLstsq) { InitIOName({"matrix", "rhs"}, {"y"}); } -}; -MIND_API abstract::AbstractBasePtr LstsqInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args); -using PrimLstsqPtr = std::shared_ptr; -} // namespace ops -} // namespace mindspore -#endif // MINDSPORE_CORE_OPS_LSTSQ_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License + */ + +#ifndef MINDSPORE_CORE_OPS_LSTSQ_H_ +#define MINDSPORE_CORE_OPS_LSTSQ_H_ +#include +#include +#include + +#include "mindapi/base/types.h" +#include "ops/base_operator.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameLstsq = "Lstsq"; +class MIND_API Lstsq : public BaseOperator { + public: + MIND_API_BASE_MEMBER(Lstsq); + Lstsq() : BaseOperator(kNameLstsq) { InitIOName({"matrix", "rhs"}, {"y"}); } +}; +MIND_API abstract::AbstractBasePtr LstsqInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +using PrimLstsqPtr = std::shared_ptr; +} // namespace ops +} // namespace mindspore +#endif // MINDSPORE_CORE_OPS_LSTSQ_H_ diff --git a/mindspore/core/ops/lu_scipy.h b/mindspore/core/ops/lu_scipy.h index 393f4775146..35c948c0953 100644 --- a/mindspore/core/ops/lu_scipy.h +++ b/mindspore/core/ops/lu_scipy.h @@ -1,42 +1,42 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License - */ - -#ifndef MINDSPORE_CORE_OPS_LU_SCIPY_H_ -#define MINDSPORE_CORE_OPS_LU_SCIPY_H_ - -#include -#include -#include -#include -#include - -#include "mindapi/base/types.h" -#include "ops/base_operator.h" - -namespace mindspore { -namespace ops { -constexpr auto kNameLU = "LU"; -class MIND_API LU : public BaseOperator { - public: - MIND_API_BASE_MEMBER(LU); - LU() : BaseOperator(kNameLU) { InitIOName({"x"}, {"lu", "pivots", "permutation"}); } -}; -abstract::AbstractBasePtr LUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args); -using PrimLUPtr = std::shared_ptr; -} // namespace ops -} // namespace mindspore -#endif // MINDSPORE_CORE_OPS_LU_SCIPY_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License + */ + +#ifndef MINDSPORE_CORE_OPS_LU_SCIPY_H_ +#define MINDSPORE_CORE_OPS_LU_SCIPY_H_ + +#include +#include +#include +#include +#include + +#include "mindapi/base/types.h" +#include "ops/base_operator.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameLU = "LU"; +class MIND_API LU : public BaseOperator { + public: + MIND_API_BASE_MEMBER(LU); + LU() : BaseOperator(kNameLU) { InitIOName({"x"}, {"lu", "pivots", "permutation"}); } +}; +abstract::AbstractBasePtr LUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +using PrimLUPtr = std::shared_ptr; +} // namespace ops +} // namespace mindspore +#endif // MINDSPORE_CORE_OPS_LU_SCIPY_H_ diff --git a/mindspore/core/ops/lu_solve_.h b/mindspore/core/ops/lu_solve_.h index 886cd525d80..ff30cd71da4 100644 --- a/mindspore/core/ops/lu_solve_.h +++ b/mindspore/core/ops/lu_solve_.h @@ -1,42 +1,42 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License - */ - -#ifndef MINDSPORE_CORE_OPS_LUSOLVE_H_ -#define MINDSPORE_CORE_OPS_LUSOLVE_H_ - -#include -#include -#include -#include -#include - -#include "mindapi/base/types.h" -#include "ops/base_operator.h" - -namespace mindspore { -namespace ops { -constexpr auto kNameLuSolve = "LuSolve"; -class MIND_API LuSolve : public BaseOperator { - public: - MIND_API_BASE_MEMBER(LuSolve); - LuSolve() : BaseOperator(kNameLuSolve) { InitIOName({"x", "lu_data", "lu_pivots"}, {"output"}); } -}; -MIND_API abstract::AbstractBasePtr LuSolveInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args); -using PrimLuSolvePtr = std::shared_ptr; -} // namespace ops -} // namespace mindspore -#endif // MINDSPORE_CORE_OPS_LUSOLVE_H_ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License + */ + +#ifndef MINDSPORE_CORE_OPS_LUSOLVE_H_ +#define MINDSPORE_CORE_OPS_LUSOLVE_H_ + +#include +#include +#include +#include +#include + +#include "mindapi/base/types.h" +#include "ops/base_operator.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameLuSolve = "LuSolve"; +class MIND_API LuSolve : public BaseOperator { + public: + MIND_API_BASE_MEMBER(LuSolve); + LuSolve() : BaseOperator(kNameLuSolve) { InitIOName({"x", "lu_data", "lu_pivots"}, {"output"}); } +}; +MIND_API abstract::AbstractBasePtr LuSolveInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +using PrimLuSolvePtr = std::shared_ptr; +} // namespace ops +} // namespace mindspore +#endif // MINDSPORE_CORE_OPS_LUSOLVE_H_ diff --git a/mindspore/core/ops/max_unpool2d.h b/mindspore/core/ops/max_unpool2d.h index 7a68d75d900..76eb8aac4c0 100644 --- a/mindspore/core/ops/max_unpool2d.h +++ b/mindspore/core/ops/max_unpool2d.h @@ -1,42 +1,42 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CORE_OPS_MAXUNPOOL2D_H_ -#define MINDSPORE_CORE_OPS_MAXUNPOOL2D_H_ -#include -#include -#include - -#include "mindapi/base/types.h" -#include "ops/base_operator.h" - -namespace mindspore { -namespace ops { -constexpr auto kNameMaxUnpool2D = "MaxUnpool2D"; -class MIND_API MaxUnpool2D : public BaseOperator { - public: - MIND_API_BASE_MEMBER(MaxUnpool2D); - MaxUnpool2D() : BaseOperator(kNameMaxUnpool2D) { InitIOName({"x", "argmax"}, {"y"}); } - std::string get_format() const; -}; - -MIND_API abstract::AbstractBasePtr MaxUnpool2DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args); -using PrimMaxUnpool2DPtr = std::shared_ptr; -} // namespace ops -} // namespace mindspore - -#endif // MINDSPORE_CORE_OPS_MAXUNPOOL2D_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_MAXUNPOOL2D_H_ +#define MINDSPORE_CORE_OPS_MAXUNPOOL2D_H_ +#include +#include +#include + +#include "mindapi/base/types.h" +#include "ops/base_operator.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameMaxUnpool2D = "MaxUnpool2D"; +class MIND_API MaxUnpool2D : public BaseOperator { + public: + MIND_API_BASE_MEMBER(MaxUnpool2D); + MaxUnpool2D() : BaseOperator(kNameMaxUnpool2D) { InitIOName({"x", "argmax"}, {"y"}); } + std::string get_format() const; +}; + +MIND_API abstract::AbstractBasePtr MaxUnpool2DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +using PrimMaxUnpool2DPtr = std::shared_ptr; +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_MAXUNPOOL2D_H_ diff --git a/mindspore/core/ops/max_unpool3d.h b/mindspore/core/ops/max_unpool3d.h index 547614eb3c0..7ea5492e82e 100644 --- a/mindspore/core/ops/max_unpool3d.h +++ b/mindspore/core/ops/max_unpool3d.h @@ -1,42 +1,42 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CORE_OPS_MAXUNPOOL3D_H_ -#define MINDSPORE_CORE_OPS_MAXUNPOOL3D_H_ -#include -#include -#include - -#include "mindapi/base/types.h" -#include "ops/base_operator.h" - -namespace mindspore { -namespace ops { -constexpr auto kNameMaxUnpool3D = "MaxUnpool3D"; -class MIND_API MaxUnpool3D : public BaseOperator { - public: - MIND_API_BASE_MEMBER(MaxUnpool3D); - MaxUnpool3D() : BaseOperator(kNameMaxUnpool3D) { InitIOName({"x", "argmax"}, {"y"}); } - std::string get_format() const; -}; - -MIND_API abstract::AbstractBasePtr MaxUnpool3DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args); -using PrimMaxUnpool3DPtr = std::shared_ptr; -} // namespace ops -} // namespace mindspore - -#endif // MINDSPORE_CORE_OPS_MAXUNPOOL3D_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_MAXUNPOOL3D_H_ +#define MINDSPORE_CORE_OPS_MAXUNPOOL3D_H_ +#include +#include +#include + +#include "mindapi/base/types.h" +#include "ops/base_operator.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameMaxUnpool3D = "MaxUnpool3D"; +class MIND_API MaxUnpool3D : public BaseOperator { + public: + MIND_API_BASE_MEMBER(MaxUnpool3D); + MaxUnpool3D() : BaseOperator(kNameMaxUnpool3D) { InitIOName({"x", "argmax"}, {"y"}); } + std::string get_format() const; +}; + +MIND_API abstract::AbstractBasePtr MaxUnpool3DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +using PrimMaxUnpool3DPtr = std::shared_ptr; +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_MAXUNPOOL3D_H_ diff --git a/mindspore/core/ops/ops_def/README.md b/mindspore/core/ops/ops_def/README.md index 10c08861582..2f39da80dc1 100644 --- a/mindspore/core/ops/ops_def/README.md +++ b/mindspore/core/ops/ops_def/README.md @@ -1,116 +1,116 @@ -### Yaml file rules and explanations are as follows - -```yaml -# Defining the function name and Primitive name of operators, use the '_' to separate words. For example, op_name is 'word1_word2', then the function name is 'word1_word2', and the Primitive class name is 'Word1Word2'. -: - # The 'args' is a fixed key of yaml file to define input args of operators. - : - # Mandatory. For every arg, key is operators' argument name, and the value are some items, items' key name can be 'dtype', 'prim_init', 'default', 'type_cast','arg_handler'. - : - # Mandatory. The 'dtype' is a fixed key. - # Value is one of {int, float, bool, number, tensor, tuple, list, tuple[int], tuple[float], tuple[bool], tuple[number], tuple[tensor], list[int], list[float], list[bool], list[number], list[tensor]}. - # If value is 'number', arg can be 'int', 'float' or 'bool'. - : - - # Optional. The 'default' is a fixed key. - # This item means input arg can use default value. - : - - # Optional. The 'prim_init' is a fixed key. Value can be 'True' or 'False', arg is arg of '__init__' of Primitive if value is 'True'. - : - - # Optional. The 'type_cast' is a fixed key. This item means can accept unmatchable input by implicit conversion. Value is one of {int, float, bool, number, tensor, tuple, list, tuple[int], tuple[float], tuple[bool], tuple[number], tuple[tensor], list[int], list[float], list[bool], list[number], list[tensor]} - # Supported type cast now: - # 1. int, float, bool, number <-> tensor. - # 2. int, float, bool, number, tensor <-> list/tuple. - # 3. list <-> tuple. - : - - # Optional. The 'arg_handler' is a fixed key. Value is a function name used to convert arg. For example, converting kernel size from 2 to (2, 2). - : - - : - ... - - : #Optional - # Optional. The 'rw_write' is a fixed key, 'arg_name' is the corresponding arg name. - : - - # Optional. The 'rw_read' is a fixed key, 'arg_name' is the corresponding arg name. - : - - # Optional. The 'rw_ref' is a fixed key, 'arg_name' is the corresponding arg name. - : - - # Optional. arg1 and arg2 should has same dtype. arg3 and arg4 should has same dtype. - : (, , ...), (, , ...), ... - - # The 'returns' is a fixed key of yaml file to define output of operators. - : - # Mandatory. For every output, key is operators' output name, and the value is a item, item's key is 'dtype'. - : - # Mandatory. Just refer to key 'dtype' in args. - : - - # Optional. The 'inplace' is a fixed key. Value is input name of operator if the input is a inplace input. - : - - : - ... - - # Optional. Rename the function but not use function name from . - : - # Optional. The 'name' is a fixed key. Value is the new function name to replace . - # Default: `op_name`. - : - - # Optional. The 'disable' is a fixed key. Value is 'True' or 'False', the function will not be generated if it is 'True'. - # Default: False - : - - # Optional. Reaname the primitive class name but not use class name from . - : - # Optional. The 'name' is a fixed key. Value is the new class name to replace . - # Default: Transformed from `op_name`. For example, `op_name` is 'avg_pool', then the class name is 'AvgPool'. - : - # Optional. The 'disable' is a fixed key. Value is 'True' or 'False', the primitive definition will not be generated if it is 'True'. - # Default: False. - : - - # Optional. The 'view' is a fixed key. Value should be set as 'True' if this is a view operator. - # Default: False. - : - - # Optional. The 'dispatch' is a fixed key. The item is used to control whether generate pyboost codes. - : - # Optional. The 'enable' is a fixed key. Pyboost codes will be auto generated if value is True. - # Default: False. - : - - # Optional. The 'device_name' can be set as 'CPU', 'GPU' or 'Ascend' and the value is a function name. If this item eixst, it means pyboost function cannot be - # auto generated in specified device target and the specified function defined manually will act as pyboost function. - : -``` - -### operators definitions will be auto generated when build MindSpore package - -The auto generated operator definition python files are in path: - -'mindspore/python/mindspore/ops/auto_generate/'. - -The auto generated operator definition c++ files are in path: - -'mindspore/core/ops/auto_generate/'. - -The auto generated operator pyboost code files are in path: - -1. 'mindspore/ccsrc/kernel/pyboost/auto_generate'. - -2. 'mindspore/ccsrc/plugin/device/ascend/kernel/pyboost/auto_generate'. - -3. 'mindspore/ccsrc/plugin/device/gpu/kernel/pyboost/auto_generate'. - -4. 'mindspore/ccsrc/plugin/device/cpu/kernel/pyboost/auto_generate'. - -5. 'mindspore/ccsrc/pipeline/pynative/op_function/auto_generat'. - +### Yaml file rules and explanations are as follows + +```yaml +# Defining the function name and Primitive name of operators, use the '_' to separate words. For example, op_name is 'word1_word2', then the function name is 'word1_word2', and the Primitive class name is 'Word1Word2'. +: + # The 'args' is a fixed key of yaml file to define input args of operators. + : + # Mandatory. For every arg, key is operators' argument name, and the value are some items, items' key name can be 'dtype', 'prim_init', 'default', 'type_cast','arg_handler'. + : + # Mandatory. The 'dtype' is a fixed key. + # Value is one of {int, float, bool, number, tensor, tuple, list, tuple[int], tuple[float], tuple[bool], tuple[number], tuple[tensor], list[int], list[float], list[bool], list[number], list[tensor]}. + # If value is 'number', arg can be 'int', 'float' or 'bool'. + : + + # Optional. The 'default' is a fixed key. + # This item means input arg can use default value. + : + + # Optional. The 'prim_init' is a fixed key. Value can be 'True' or 'False', arg is arg of '__init__' of Primitive if value is 'True'. + : + + # Optional. The 'type_cast' is a fixed key. This item means can accept unmatchable input by implicit conversion. Value is one of {int, float, bool, number, tensor, tuple, list, tuple[int], tuple[float], tuple[bool], tuple[number], tuple[tensor], list[int], list[float], list[bool], list[number], list[tensor]} + # Supported type cast now: + # 1. int, float, bool, number <-> tensor. + # 2. int, float, bool, number, tensor <-> list/tuple. + # 3. list <-> tuple. + : + + # Optional. The 'arg_handler' is a fixed key. Value is a function name used to convert arg. For example, converting kernel size from 2 to (2, 2). + : + + : + ... + + : #Optional + # Optional. The 'rw_write' is a fixed key, 'arg_name' is the corresponding arg name. + : + + # Optional. The 'rw_read' is a fixed key, 'arg_name' is the corresponding arg name. + : + + # Optional. The 'rw_ref' is a fixed key, 'arg_name' is the corresponding arg name. + : + + # Optional. arg1 and arg2 should has same dtype. arg3 and arg4 should has same dtype. + : (, , ...), (, , ...), ... + + # The 'returns' is a fixed key of yaml file to define output of operators. + : + # Mandatory. For every output, key is operators' output name, and the value is a item, item's key is 'dtype'. + : + # Mandatory. Just refer to key 'dtype' in args. + : + + # Optional. The 'inplace' is a fixed key. Value is input name of operator if the input is a inplace input. + : + + : + ... + + # Optional. Rename the function but not use function name from . + : + # Optional. The 'name' is a fixed key. Value is the new function name to replace . + # Default: `op_name`. + : + + # Optional. The 'disable' is a fixed key. Value is 'True' or 'False', the function will not be generated if it is 'True'. + # Default: False + : + + # Optional. Reaname the primitive class name but not use class name from . + : + # Optional. The 'name' is a fixed key. Value is the new class name to replace . + # Default: Transformed from `op_name`. For example, `op_name` is 'avg_pool', then the class name is 'AvgPool'. + : + # Optional. The 'disable' is a fixed key. Value is 'True' or 'False', the primitive definition will not be generated if it is 'True'. + # Default: False. + : + + # Optional. The 'view' is a fixed key. Value should be set as 'True' if this is a view operator. + # Default: False. + : + + # Optional. The 'dispatch' is a fixed key. The item is used to control whether generate pyboost codes. + : + # Optional. The 'enable' is a fixed key. Pyboost codes will be auto generated if value is True. + # Default: False. + : + + # Optional. The 'device_name' can be set as 'CPU', 'GPU' or 'Ascend' and the value is a function name. If this item eixst, it means pyboost function cannot be + # auto generated in specified device target and the specified function defined manually will act as pyboost function. + : +``` + +### operators definitions will be auto generated when build MindSpore package + +The auto generated operator definition python files are in path: + +'mindspore/python/mindspore/ops/auto_generate/'. + +The auto generated operator definition c++ files are in path: + +'mindspore/core/ops/auto_generate/'. + +The auto generated operator pyboost code files are in path: + +1. 'mindspore/ccsrc/kernel/pyboost/auto_generate'. + +2. 'mindspore/ccsrc/plugin/device/ascend/kernel/pyboost/auto_generate'. + +3. 'mindspore/ccsrc/plugin/device/gpu/kernel/pyboost/auto_generate'. + +4. 'mindspore/ccsrc/plugin/device/cpu/kernel/pyboost/auto_generate'. + +5. 'mindspore/ccsrc/pipeline/pynative/op_function/auto_generat'. + diff --git a/mindspore/core/ops/ops_frontend_func_impl.cc b/mindspore/core/ops/ops_frontend_func_impl.cc index 99b910f2532..9d7fe5e8c5d 100644 --- a/mindspore/core/ops/ops_frontend_func_impl.cc +++ b/mindspore/core/ops/ops_frontend_func_impl.cc @@ -1,77 +1,77 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ops/ops_frontend_func_impl.h" -#include "utils/log_adapter.h" - -namespace mindspore::ops { -OpsFrontendFuncImplMap *GetOpsFrontendFuncImplMapPtr() { - static OpsFrontendFuncImplMap ops_frontend_func_impl_map; - return &ops_frontend_func_impl_map; -} - -OpFrontendFuncImplPtr GetOpFrontendFuncImplPtr(const std::string &name) { - auto iter = GetOpsFrontendFuncImplMapPtr()->find(name); - if (iter == GetOpsFrontendFuncImplMapPtr()->end()) { - return nullptr; - } - - return iter->second.get_func_impl(); -} - -RegFrontendFuncImplHelper::RegFrontendFuncImplHelper(const std::string &name, const OpFrontendFuncImplPtr &func_impl) { - const FrontendFuncImplHolder holder{func_impl}; - (void)GetOpsFrontendFuncImplMapPtr()->emplace(name, holder); -} - -InferValueCallback &InferValueCallback::GetInstance() { - static InferValueCallback instance{}; - return instance; -} - -void InferValueCallback::RegImpl(const std::string &impl_type, const InferValueFunc &func) { - if (impl_type == "python_impl") { - if (python_impl_) { - MS_LOG(ERROR) << "InferValueImpl for python_impl is already registered!"; - } - python_impl_ = func; - } else if (impl_type == "cpu_kernel_impl") { - if (kernel_impl_) { - MS_LOG(ERROR) << "InferValueImpl for cpu_kernel_impl is already registered!"; - } - kernel_impl_ = func; - } else { - MS_LOG(ERROR) << "Unsupported InferValue implement type " << impl_type << "!"; - } -} - -ValuePtr InferValueCallback::CallPyInferValue(const std::string &op_name, const AbstractBasePtrList &input_args) { - if (python_impl_) { - return python_impl_(op_name, input_args); - } - return nullptr; -} -ValuePtr InferValueCallback::CallKernelInferValue(const std::string &op_name, const AbstractBasePtrList &input_args) { - if (kernel_impl_) { - return kernel_impl_(op_name, input_args); - } - return nullptr; -} - -InferValueImplRegister::InferValueImplRegister(const std::string &impl_type, const InferValueFunc &fn) { - InferValueCallback::GetInstance().RegImpl(impl_type, fn); -} -} // namespace mindspore::ops +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ops/ops_frontend_func_impl.h" +#include "utils/log_adapter.h" + +namespace mindspore::ops { +OpsFrontendFuncImplMap *GetOpsFrontendFuncImplMapPtr() { + static OpsFrontendFuncImplMap ops_frontend_func_impl_map; + return &ops_frontend_func_impl_map; +} + +OpFrontendFuncImplPtr GetOpFrontendFuncImplPtr(const std::string &name) { + auto iter = GetOpsFrontendFuncImplMapPtr()->find(name); + if (iter == GetOpsFrontendFuncImplMapPtr()->end()) { + return nullptr; + } + + return iter->second.get_func_impl(); +} + +RegFrontendFuncImplHelper::RegFrontendFuncImplHelper(const std::string &name, const OpFrontendFuncImplPtr &func_impl) { + const FrontendFuncImplHolder holder{func_impl}; + (void)GetOpsFrontendFuncImplMapPtr()->emplace(name, holder); +} + +InferValueCallback &InferValueCallback::GetInstance() { + static InferValueCallback instance{}; + return instance; +} + +void InferValueCallback::RegImpl(const std::string &impl_type, const InferValueFunc &func) { + if (impl_type == "python_impl") { + if (python_impl_) { + MS_LOG(ERROR) << "InferValueImpl for python_impl is already registered!"; + } + python_impl_ = func; + } else if (impl_type == "cpu_kernel_impl") { + if (kernel_impl_) { + MS_LOG(ERROR) << "InferValueImpl for cpu_kernel_impl is already registered!"; + } + kernel_impl_ = func; + } else { + MS_LOG(ERROR) << "Unsupported InferValue implement type " << impl_type << "!"; + } +} + +ValuePtr InferValueCallback::CallPyInferValue(const std::string &op_name, const AbstractBasePtrList &input_args) { + if (python_impl_) { + return python_impl_(op_name, input_args); + } + return nullptr; +} +ValuePtr InferValueCallback::CallKernelInferValue(const std::string &op_name, const AbstractBasePtrList &input_args) { + if (kernel_impl_) { + return kernel_impl_(op_name, input_args); + } + return nullptr; +} + +InferValueImplRegister::InferValueImplRegister(const std::string &impl_type, const InferValueFunc &fn) { + InferValueCallback::GetInstance().RegImpl(impl_type, fn); +} +} // namespace mindspore::ops diff --git a/mindspore/core/ops/ops_frontend_func_impl.h b/mindspore/core/ops/ops_frontend_func_impl.h index 07724bccab3..28bee472c10 100644 --- a/mindspore/core/ops/ops_frontend_func_impl.h +++ b/mindspore/core/ops/ops_frontend_func_impl.h @@ -1,113 +1,113 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CORE_OPS_FRONTEND_FUNC_IMPL_H -#define MINDSPORE_CORE_OPS_FRONTEND_FUNC_IMPL_H - -#include -#include -#include -#include -#include "ir/cell.h" -#include "ir/primitive.h" -#include "abstract/abstract_value.h" -#include "ir/anf.h" -#include "mindapi/base/macros.h" - -namespace mindspore::ops { -class OpFrontendFuncImpl { - public: - OpFrontendFuncImpl() = default; - virtual ~OpFrontendFuncImpl() = default; - - /// \brief Infer the output value for target operator. Only override when needed. - /// - /// \param[in] primitive Operator's primitive. - /// \param[in] input_args Operator's inputs. - /// - /// \return Inferred Value based on given inputs. - virtual ValuePtr InferValue(const PrimitivePtr &primitive, const std::vector &input_args) const { - return nullptr; - } - - /// \brief Infer the related Abstract for target operator. - /// - /// \param[in] primitive Operator's primitive. - /// \param[in] input_args Operator's inputs. - /// - /// \return AbstractBasePtr with inferred shape and inferred type. - virtual AbstractBasePtr InferAbstract(const PrimitivePtr &primitive, - const std::vector &input_args) const { - return nullptr; - } -}; - -using OpFrontendFuncImplPtr = std::shared_ptr; - -class FrontendFuncImplHolder { - public: - explicit FrontendFuncImplHolder(const OpFrontendFuncImplPtr &func_impl) : func_impl_(func_impl) {} - ~FrontendFuncImplHolder() = default; - OpFrontendFuncImplPtr get_func_impl() { return func_impl_; } - - private: - OpFrontendFuncImplPtr func_impl_{nullptr}; -}; - -using OpsFrontendFuncImplMap = std::unordered_map; - -MS_CORE_API OpFrontendFuncImplPtr GetOpFrontendFuncImplPtr(const std::string &name); - -class MS_CORE_API RegFrontendFuncImplHelper { - public: - RegFrontendFuncImplHelper(const std::string &name, const OpFrontendFuncImplPtr &func_impl); - ~RegFrontendFuncImplHelper() = default; -}; - -#define REGISTER_PRIMITIVE_FUNCTION_FRONTEND_FUNC_IMPL(name, func_impl_class) \ - static auto helper_##func_impl_class = RegFrontendFuncImplHelper(name, std::make_shared()); - -using InferValueFunc = std::function; -class MS_CORE_API InferValueCallback { - public: - InferValueCallback(const InferValueCallback &) = delete; - InferValueCallback &operator=(const InferValueCallback &) = delete; - - static InferValueCallback &GetInstance(); - - void RegImpl(const std::string &impl_type, const InferValueFunc &py_func); - ValuePtr CallPyInferValue(const std::string &op_name, const AbstractBasePtrList &input_args); - ValuePtr CallKernelInferValue(const std::string &op_name, const AbstractBasePtrList &input_args); - - private: - InferValueCallback() = default; - ~InferValueCallback() {} - - private: - InferValueFunc python_impl_{nullptr}; - InferValueFunc kernel_impl_{nullptr}; -}; - -class MS_CORE_API InferValueImplRegister { - public: - InferValueImplRegister(const std::string &impl_type, const InferValueFunc &fn); - ~InferValueImplRegister() = default; -}; - -#define INFER_VALUE_IMPL_REGISTER(impl_type, func) \ - static auto reg_##impl_type##_##func = mindspore::ops::InferValueImplRegister(#impl_type, func) -} // namespace mindspore::ops -#endif // MINDSPORE_CORE_OPS_FRONTEND_FUNC_IMPL_H +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_FRONTEND_FUNC_IMPL_H +#define MINDSPORE_CORE_OPS_FRONTEND_FUNC_IMPL_H + +#include +#include +#include +#include +#include "ir/cell.h" +#include "ir/primitive.h" +#include "abstract/abstract_value.h" +#include "ir/anf.h" +#include "mindapi/base/macros.h" + +namespace mindspore::ops { +class OpFrontendFuncImpl { + public: + OpFrontendFuncImpl() = default; + virtual ~OpFrontendFuncImpl() = default; + + /// \brief Infer the output value for target operator. Only override when needed. + /// + /// \param[in] primitive Operator's primitive. + /// \param[in] input_args Operator's inputs. + /// + /// \return Inferred Value based on given inputs. + virtual ValuePtr InferValue(const PrimitivePtr &primitive, const std::vector &input_args) const { + return nullptr; + } + + /// \brief Infer the related Abstract for target operator. + /// + /// \param[in] primitive Operator's primitive. + /// \param[in] input_args Operator's inputs. + /// + /// \return AbstractBasePtr with inferred shape and inferred type. + virtual AbstractBasePtr InferAbstract(const PrimitivePtr &primitive, + const std::vector &input_args) const { + return nullptr; + } +}; + +using OpFrontendFuncImplPtr = std::shared_ptr; + +class FrontendFuncImplHolder { + public: + explicit FrontendFuncImplHolder(const OpFrontendFuncImplPtr &func_impl) : func_impl_(func_impl) {} + ~FrontendFuncImplHolder() = default; + OpFrontendFuncImplPtr get_func_impl() { return func_impl_; } + + private: + OpFrontendFuncImplPtr func_impl_{nullptr}; +}; + +using OpsFrontendFuncImplMap = std::unordered_map; + +MS_CORE_API OpFrontendFuncImplPtr GetOpFrontendFuncImplPtr(const std::string &name); + +class MS_CORE_API RegFrontendFuncImplHelper { + public: + RegFrontendFuncImplHelper(const std::string &name, const OpFrontendFuncImplPtr &func_impl); + ~RegFrontendFuncImplHelper() = default; +}; + +#define REGISTER_PRIMITIVE_FUNCTION_FRONTEND_FUNC_IMPL(name, func_impl_class) \ + static auto helper_##func_impl_class = RegFrontendFuncImplHelper(name, std::make_shared()); + +using InferValueFunc = std::function; +class MS_CORE_API InferValueCallback { + public: + InferValueCallback(const InferValueCallback &) = delete; + InferValueCallback &operator=(const InferValueCallback &) = delete; + + static InferValueCallback &GetInstance(); + + void RegImpl(const std::string &impl_type, const InferValueFunc &py_func); + ValuePtr CallPyInferValue(const std::string &op_name, const AbstractBasePtrList &input_args); + ValuePtr CallKernelInferValue(const std::string &op_name, const AbstractBasePtrList &input_args); + + private: + InferValueCallback() = default; + ~InferValueCallback() {} + + private: + InferValueFunc python_impl_{nullptr}; + InferValueFunc kernel_impl_{nullptr}; +}; + +class MS_CORE_API InferValueImplRegister { + public: + InferValueImplRegister(const std::string &impl_type, const InferValueFunc &fn); + ~InferValueImplRegister() = default; +}; + +#define INFER_VALUE_IMPL_REGISTER(impl_type, func) \ + static auto reg_##impl_type##_##func = mindspore::ops::InferValueImplRegister(#impl_type, func) +} // namespace mindspore::ops +#endif // MINDSPORE_CORE_OPS_FRONTEND_FUNC_IMPL_H diff --git a/mindspore/core/ops/ops_func_impl/betainc.h b/mindspore/core/ops/ops_func_impl/betainc.h index 574a3c23967..a815cbcef13 100644 --- a/mindspore/core/ops/ops_func_impl/betainc.h +++ b/mindspore/core/ops/ops_func_impl/betainc.h @@ -1,32 +1,32 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_BETAINC_H_ -#define MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_BETAINC_H_ - -#include -#include "ops/ops_func_impl/op_func_impl.h" - -namespace mindspore { -namespace ops { -class MIND_API BetaincFuncImpl : public OpFuncImpl { - public: - BaseShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) const override; - - TypePtr InferType(const PrimitivePtr &primitive, const std::vector &input_args) const override; -}; -} // namespace ops -} // namespace mindspore -#endif // MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_BETAINC_H_ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_BETAINC_H_ +#define MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_BETAINC_H_ + +#include +#include "ops/ops_func_impl/op_func_impl.h" + +namespace mindspore { +namespace ops { +class MIND_API BetaincFuncImpl : public OpFuncImpl { + public: + BaseShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) const override; + + TypePtr InferType(const PrimitivePtr &primitive, const std::vector &input_args) const override; +}; +} // namespace ops +} // namespace mindspore +#endif // MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_BETAINC_H_ diff --git a/mindspore/core/ops/ops_func_impl/tanh.cc b/mindspore/core/ops/ops_func_impl/tanh.cc index d6396ad310e..a289428c2c2 100644 --- a/mindspore/core/ops/ops_func_impl/tanh.cc +++ b/mindspore/core/ops/ops_func_impl/tanh.cc @@ -1,71 +1,71 @@ -/** - * Copyright 2024 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "ops/op_utils.h" -#include "utils/check_convert_utils.h" -#include "ops/ops_func_impl/simple_infer.h" -#include "ops/ops_func_impl/tanh.h" - -namespace mindspore { -namespace ops { -BaseShapePtr TanhFuncImpl::InferShape(const PrimitivePtr &primitive, - const std::vector &input_args) const { - MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]); - auto x_shape = input_args[kInputIndex0]->GetShape(); - return x_shape->Clone(); -} - -TypePtr TanhFuncImpl::InferType(const PrimitivePtr &primitive, const std::vector &input_args) const { - auto input_type = input_args[kIndex0]->GetType(); - auto input_type_id = input_type->cast()->element()->type_id(); - static const std::vector int_or_bool = {kNumberTypeUInt8, kNumberTypeInt8, kNumberTypeInt16, - kNumberTypeInt32, kNumberTypeInt64, kNumberTypeBool}; - bool is_int_or_bool = std::any_of(int_or_bool.begin(), int_or_bool.end(), - [&input_type_id](const TypeId &type_id) { return input_type_id == type_id; }); - if (is_int_or_bool) { - return std::make_shared(kFloat32); - } else { - return input_type->Clone(); - } -} - -ShapeArray TanhFuncImpl::InferShape(const PrimitivePtr &primitive, const ValuePtrList &input_values) const { - const auto &x_tensor = input_values[kIndex0]->cast(); - MS_EXCEPTION_IF_NULL(x_tensor); - return {x_tensor->shape()}; -} - -TypePtrList TanhFuncImpl::InferType(const PrimitivePtr &primitive, const ValuePtrList &input_values) const { - const auto &x_tensor = input_values[kIndex0]->cast(); - MS_EXCEPTION_IF_NULL(x_tensor); - const auto &input_type = x_tensor->Dtype(); - const auto &input_type_id = x_tensor->Dtype()->type_id(); - static const std::vector int_or_bool = {kNumberTypeUInt8, kNumberTypeInt8, kNumberTypeInt16, - kNumberTypeInt32, kNumberTypeInt64, kNumberTypeBool}; - bool is_int_or_bool = std::any_of(int_or_bool.begin(), int_or_bool.end(), - [&input_type_id](const TypeId &type_id) { return input_type_id == type_id; }); - if (is_int_or_bool) { - return {kFloat32}; - } else { - return {input_type}; - } -} - -REGISTER_SIMPLE_INFER(kNameTanh, TanhFuncImpl) - -} // namespace ops -} // namespace mindspore +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "ops/op_utils.h" +#include "utils/check_convert_utils.h" +#include "ops/ops_func_impl/simple_infer.h" +#include "ops/ops_func_impl/tanh.h" + +namespace mindspore { +namespace ops { +BaseShapePtr TanhFuncImpl::InferShape(const PrimitivePtr &primitive, + const std::vector &input_args) const { + MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]); + auto x_shape = input_args[kInputIndex0]->GetShape(); + return x_shape->Clone(); +} + +TypePtr TanhFuncImpl::InferType(const PrimitivePtr &primitive, const std::vector &input_args) const { + auto input_type = input_args[kIndex0]->GetType(); + auto input_type_id = input_type->cast()->element()->type_id(); + static const std::vector int_or_bool = {kNumberTypeUInt8, kNumberTypeInt8, kNumberTypeInt16, + kNumberTypeInt32, kNumberTypeInt64, kNumberTypeBool}; + bool is_int_or_bool = std::any_of(int_or_bool.begin(), int_or_bool.end(), + [&input_type_id](const TypeId &type_id) { return input_type_id == type_id; }); + if (is_int_or_bool) { + return std::make_shared(kFloat32); + } else { + return input_type->Clone(); + } +} + +ShapeArray TanhFuncImpl::InferShape(const PrimitivePtr &primitive, const ValuePtrList &input_values) const { + const auto &x_tensor = input_values[kIndex0]->cast(); + MS_EXCEPTION_IF_NULL(x_tensor); + return {x_tensor->shape()}; +} + +TypePtrList TanhFuncImpl::InferType(const PrimitivePtr &primitive, const ValuePtrList &input_values) const { + const auto &x_tensor = input_values[kIndex0]->cast(); + MS_EXCEPTION_IF_NULL(x_tensor); + const auto &input_type = x_tensor->Dtype(); + const auto &input_type_id = x_tensor->Dtype()->type_id(); + static const std::vector int_or_bool = {kNumberTypeUInt8, kNumberTypeInt8, kNumberTypeInt16, + kNumberTypeInt32, kNumberTypeInt64, kNumberTypeBool}; + bool is_int_or_bool = std::any_of(int_or_bool.begin(), int_or_bool.end(), + [&input_type_id](const TypeId &type_id) { return input_type_id == type_id; }); + if (is_int_or_bool) { + return {kFloat32}; + } else { + return {input_type}; + } +} + +REGISTER_SIMPLE_INFER(kNameTanh, TanhFuncImpl) + +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/ops_func_impl/tanh.h b/mindspore/core/ops/ops_func_impl/tanh.h index ce66784e37a..f6289b33e0c 100644 --- a/mindspore/core/ops/ops_func_impl/tanh.h +++ b/mindspore/core/ops/ops_func_impl/tanh.h @@ -1,39 +1,39 @@ -/** - * Copyright 2024 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_TANH_H_ -#define MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_TANH_H_ - -#include -#include "ops/ops_func_impl/op_func_impl.h" - -namespace mindspore { -namespace ops { -class MIND_API TanhFuncImpl : public OpFuncImpl { - public: - TanhFuncImpl() = default; - ~TanhFuncImpl() = default; - - BaseShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) const override; - - TypePtr InferType(const PrimitivePtr &primitive, const std::vector &input_args) const override; - - // simply infer - ShapeArray InferShape(const PrimitivePtr &primitive, const ValuePtrList &input_values) const override; - TypePtrList InferType(const PrimitivePtr &primitive, const ValuePtrList &input_values) const override; -}; -} // namespace ops -} // namespace mindspore -#endif // MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_TANH_H_ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_TANH_H_ +#define MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_TANH_H_ + +#include +#include "ops/ops_func_impl/op_func_impl.h" + +namespace mindspore { +namespace ops { +class MIND_API TanhFuncImpl : public OpFuncImpl { + public: + TanhFuncImpl() = default; + ~TanhFuncImpl() = default; + + BaseShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) const override; + + TypePtr InferType(const PrimitivePtr &primitive, const std::vector &input_args) const override; + + // simply infer + ShapeArray InferShape(const PrimitivePtr &primitive, const ValuePtrList &input_values) const override; + TypePtrList InferType(const PrimitivePtr &primitive, const ValuePtrList &input_values) const override; +}; +} // namespace ops +} // namespace mindspore +#endif // MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_TANH_H_ diff --git a/mindspore/core/ops/parallel_concat.h b/mindspore/core/ops/parallel_concat.h index 43727222917..dcece324707 100644 --- a/mindspore/core/ops/parallel_concat.h +++ b/mindspore/core/ops/parallel_concat.h @@ -1,43 +1,43 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CORE_OPS_PARALLEL_PARALLEL_CONCAT_H_ -#define MINDSPORE_CORE_OPS_PARALLEL_PARALLEL_CONCAT_H_ -#include -#include - -#include "mindapi/base/types.h" -#include "ops/base_operator.h" - -namespace mindspore { -namespace ops { -constexpr auto kNameParallelConcat = "ParallelConcat"; -/// \brief Connect tensor in the specified axis. -/// Refer to Python API @ref mindspore.ops.ParallelConcat for more details. -class MIND_API ParallelConcat : public BaseOperator { - public: - MIND_API_BASE_MEMBER(ParallelConcat); - /// \brief Constructor. - ParallelConcat() : BaseOperator(kNameParallelConcat) { InitIOName({"x"}, {"y"}); } - /// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.ParallelConcat for the inputs. - void Init() const {} -}; -MIND_API abstract::AbstractBasePtr ParallelConcatInfer(const abstract::AnalysisEnginePtr &, - const PrimitivePtr &primitive, - const std::vector &input_args); -} // namespace ops -} // namespace mindspore -#endif // MINDSPORE_CORE_OPS_PARALLEL_PARALLEL_CONCAT_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_PARALLEL_PARALLEL_CONCAT_H_ +#define MINDSPORE_CORE_OPS_PARALLEL_PARALLEL_CONCAT_H_ +#include +#include + +#include "mindapi/base/types.h" +#include "ops/base_operator.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameParallelConcat = "ParallelConcat"; +/// \brief Connect tensor in the specified axis. +/// Refer to Python API @ref mindspore.ops.ParallelConcat for more details. +class MIND_API ParallelConcat : public BaseOperator { + public: + MIND_API_BASE_MEMBER(ParallelConcat); + /// \brief Constructor. + ParallelConcat() : BaseOperator(kNameParallelConcat) { InitIOName({"x"}, {"y"}); } + /// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.ParallelConcat for the inputs. + void Init() const {} +}; +MIND_API abstract::AbstractBasePtr ParallelConcatInfer(const abstract::AnalysisEnginePtr &, + const PrimitivePtr &primitive, + const std::vector &input_args); +} // namespace ops +} // namespace mindspore +#endif // MINDSPORE_CORE_OPS_PARALLEL_PARALLEL_CONCAT_H_ diff --git a/mindspore/core/ops/scatter_add_with_axis.h b/mindspore/core/ops/scatter_add_with_axis.h index fb6199bd223..d43958c5adb 100644 --- a/mindspore/core/ops/scatter_add_with_axis.h +++ b/mindspore/core/ops/scatter_add_with_axis.h @@ -1,50 +1,50 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CORE_OPS_SCATTER_ADD_WITH_AXIS_H_ -#define MINDSPORE_CORE_OPS_SCATTER_ADD_WITH_AXIS_H_ -#include -#include - -#include "mindapi/base/types.h" -#include "ops/base_operator.h" - -namespace mindspore { -namespace ops { -constexpr auto kNameScatterAddWithAxis = "ScatterAddWithAxis"; -/// \brief Updates tensor values by using input indices and value. -/// Refer to Python API @ref mindspore.ops.ScatterAddWithAxis for more details. -class MIND_API ScatterAddWithAxis : public BaseOperator { - public: - MIND_API_BASE_MEMBER(ScatterAddWithAxis); - /// \brief Constructor. - ScatterAddWithAxis() : BaseOperator(kNameScatterAddWithAxis) { InitIOName({"input_x", "indices", "update"}, {"y"}); } - /// \brief Init. Refer to the parameters of Python API @ref - /// mindspore.ops.ScatterAddWithAxis for the inputs. - void Init(const int64_t axis = 0); - /// \brief Set axis. - void set_axis(const int64_t axis); - /// \brief Get axis. - int64_t get_axis() const; -}; -MIND_API abstract::AbstractBasePtr ScatterAddWithAxisInfer(const abstract::AnalysisEnginePtr &, - const PrimitivePtr &primitive, - const std::vector &input_args); -using kPrimScatterAddWithAxisPtr = std::shared_ptr; -} // namespace ops -} // namespace mindspore - -#endif // MINDSPORE_CORE_OPS_SCATTER_ADD_WITH_AXIS_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_SCATTER_ADD_WITH_AXIS_H_ +#define MINDSPORE_CORE_OPS_SCATTER_ADD_WITH_AXIS_H_ +#include +#include + +#include "mindapi/base/types.h" +#include "ops/base_operator.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameScatterAddWithAxis = "ScatterAddWithAxis"; +/// \brief Updates tensor values by using input indices and value. +/// Refer to Python API @ref mindspore.ops.ScatterAddWithAxis for more details. +class MIND_API ScatterAddWithAxis : public BaseOperator { + public: + MIND_API_BASE_MEMBER(ScatterAddWithAxis); + /// \brief Constructor. + ScatterAddWithAxis() : BaseOperator(kNameScatterAddWithAxis) { InitIOName({"input_x", "indices", "update"}, {"y"}); } + /// \brief Init. Refer to the parameters of Python API @ref + /// mindspore.ops.ScatterAddWithAxis for the inputs. + void Init(const int64_t axis = 0); + /// \brief Set axis. + void set_axis(const int64_t axis); + /// \brief Get axis. + int64_t get_axis() const; +}; +MIND_API abstract::AbstractBasePtr ScatterAddWithAxisInfer(const abstract::AnalysisEnginePtr &, + const PrimitivePtr &primitive, + const std::vector &input_args); +using kPrimScatterAddWithAxisPtr = std::shared_ptr; +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_SCATTER_ADD_WITH_AXIS_H_ diff --git a/mindspore/core/ops/self_adjoint_eig.h b/mindspore/core/ops/self_adjoint_eig.h index 3da1fdc9013..16ab14a4942 100644 --- a/mindspore/core/ops/self_adjoint_eig.h +++ b/mindspore/core/ops/self_adjoint_eig.h @@ -1,37 +1,37 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CORE_OPS_SELFADJOINTEIG_H_ -#define MINDSPORE_CORE_OPS_SELFADJOINTEIG_H_ -#include -#include - -#include "ops/base_operator.h" - -namespace mindspore { -namespace ops { -constexpr auto kNameSelfAdjointEig = "SelfAdjointEig"; -class MIND_API SelfAdjointEig : public BaseOperator { - public: - MIND_API_BASE_MEMBER(SelfAdjointEig); - SelfAdjointEig() : BaseOperator(kNameSelfAdjointEig) { InitIOName({"x"}, {"eigen_value", "eigen_vector"}); } -}; -MIND_API abstract::AbstractBasePtr SelfAdjointEigInfer(const abstract::AnalysisEnginePtr &, - const PrimitivePtr &primitive, - const std::vector &input_args); -} // namespace ops -} // namespace mindspore -#endif // MINDSPORE_CORE_OPS_SELFADJOINTEIG_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_SELFADJOINTEIG_H_ +#define MINDSPORE_CORE_OPS_SELFADJOINTEIG_H_ +#include +#include + +#include "ops/base_operator.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameSelfAdjointEig = "SelfAdjointEig"; +class MIND_API SelfAdjointEig : public BaseOperator { + public: + MIND_API_BASE_MEMBER(SelfAdjointEig); + SelfAdjointEig() : BaseOperator(kNameSelfAdjointEig) { InitIOName({"x"}, {"eigen_value", "eigen_vector"}); } +}; +MIND_API abstract::AbstractBasePtr SelfAdjointEigInfer(const abstract::AnalysisEnginePtr &, + const PrimitivePtr &primitive, + const std::vector &input_args); +} // namespace ops +} // namespace mindspore +#endif // MINDSPORE_CORE_OPS_SELFADJOINTEIG_H_ diff --git a/mindspore/core/ops/sequence_stack.cc b/mindspore/core/ops/sequence_stack.cc index 6df3fd21739..bb7e8ab45da 100644 --- a/mindspore/core/ops/sequence_stack.cc +++ b/mindspore/core/ops/sequence_stack.cc @@ -1,199 +1,199 @@ -/** - * Copyright 2020-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ops/sequence_stack.h" - -#include -#include -#include - -#include "abstract/abstract_value.h" -#include "abstract/dshape.h" -#include "abstract/ops/primitive_infer_map.h" -#include "abstract/utils.h" -#include "base/base.h" -#include "ir/anf.h" -#include "ir/dtype/type.h" -#include "ir/primitive.h" -#include "mindapi/base/shape_vector.h" -#include "mindapi/base/shared_ptr.h" -#include "mindapi/ir/value.h" -#include "mindapi/src/helper.h" -#include "ops/op_name.h" -#include "ops/primitive_c.h" -#include "ops/stack_comm.h" -#include "utils/check_convert_utils.h" -#include "utils/convert_utils_base.h" -#include "utils/log_adapter.h" -#include "utils/shape_utils.h" - -namespace mindspore { -namespace ops { -namespace { -constexpr int64_t kUnDim = -1; -constexpr int64_t kUnRank = -2; -} // namespace -void SequenceStack::set_axis(const int64_t axis) { (void)AddAttr(kAxis, api::MakeValue(axis)); } - -int64_t SequenceStack::get_axis() const { return GetValue(GetAttr(kAxis)); } - -void SequenceStack::Init(const int64_t axis) { this->set_axis(axis); } -namespace { -BaseShapePtr SequenceStackInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { - MS_EXCEPTION_IF_NULL(primitive); - auto op_name = primitive->name(); - auto queue = input_args[kIndex0]; - if (!CheckAndConvertUtils::IsSequence(queue)) { - MS_EXCEPTION(TypeError) << "For " << op_name << ", input[0] must be sequence, but got " << queue->ToString(); - } - - if (CheckAndConvertUtils::IsDynamicSequence(queue)) { - auto queue_shape = queue->GetShape()->cast(); - MS_EXCEPTION_IF_NULL(queue_shape); - return queue_shape->element_shape()->Clone(); - } - const int64_t kOneNum = 1; - auto queue_shape = queue->GetShape()->cast(); - MS_EXCEPTION_IF_NULL(queue_shape); - auto elements = queue_shape->shape(); - if (input_args.size() < 1) { - MS_LOG(ERROR) << "Invalid input size " << input_args.size(); - } - - if (input_args.size() == 1) { - if (!CheckAndConvertUtils::IsSequence(input_args[0])) { - MS_EXCEPTION(TypeError) << "For '" << op_name << "', the input data type must be list or tuple of tensors."; - } - } - - (void)CheckAndConvertUtils::CheckInteger("stack element num", SizeToLong(elements.size()), kGreaterEqual, kOneNum, - primitive->name()); - - bool has_rank_valid_shape = false; - ShapeVector input_shape; - size_t element_rank = 0; - for (size_t i = 0; i < elements.size(); ++i) { - auto input_shape_tmp = elements[i]->GetShapeVector(); - if (IsDynamicRank(input_shape_tmp)) { - continue; - } - - if (!has_rank_valid_shape) { - has_rank_valid_shape = true; - input_shape = input_shape_tmp; - element_rank = input_shape_tmp.size(); - continue; - } - if (input_shape_tmp.size() != input_shape.size()) { - MS_EXCEPTION(ValueError) << "All input shape size must be the same!"; - } - for (size_t j = 0; j < input_shape.size(); ++j) { - if (input_shape.at(j) == abstract::TensorShape::kShapeDimAny && - input_shape_tmp.at(j) != abstract::TensorShape::kShapeDimAny) { - input_shape[j] = input_shape_tmp.at(j); - continue; - } - if (input_shape_tmp.at(j) != input_shape.at(j)) { - MS_EXCEPTION(ValueError) << "All input shape must be the same! " << input_shape_tmp << " And " << input_shape; - } - } - } - - if (!has_rank_valid_shape) { - return std::make_shared(ShapeVector{abstract::TensorShape::kShapeRankAny}); - } - std::vector infer_shape = input_shape; - auto axis_temp = GetValue(primitive->GetAttr(kAxis)); - CheckAndConvertUtils::CheckInRange("Stack axis", axis_temp, kIncludeBoth, - {-SizeToLong(element_rank) - 1, SizeToLong(element_rank)}, - primitive->name()); - auto axis = axis_temp < 0 ? static_cast(axis_temp) + element_rank + 1 : LongToSize(axis_temp); - (void)infer_shape.insert(infer_shape.begin() + axis, elements.size()); - return std::make_shared(infer_shape); -} - -template -TypePtr GetOutputType(const PrimitivePtr &primitive, const AbstractBasePtr &queue) { - auto queue_type = queue->GetType()->cast(); - MS_EXCEPTION_IF_NULL(queue_type); - if (queue_type->dynamic_len()) { - return queue_type->dynamic_element_type()->Clone(); - } - if (queue_type->elements().empty()) { - MS_LOG(EXCEPTION) << "Sequence length should not be 0."; - } - - auto elements = queue_type->elements(); - primitive->AddAttr("num", MakeValue(SizeToLong(elements.size()))); - auto infer_type0 = elements[0]; - for (size_t i = 1; i < elements.size(); i++) { - auto infer_typei = elements[i]; - if (infer_typei == infer_type0) { - MS_EXCEPTION(TypeError) << "All input must have the same data type!input[" << i << "] data type = " << infer_typei - << "infer_type0= " << infer_type0; - } - } - return elements[kIndex0]->Clone(); -} - -TypePtr SequenceStackInferType(const PrimitivePtr &primitive, const std::vector &input_args) { - MS_EXCEPTION_IF_NULL(primitive); - auto op_name = primitive->name(); - auto queue = input_args[kIndex0]; - if (!CheckAndConvertUtils::IsTuple(queue) && !CheckAndConvertUtils::IsList(queue)) { - MS_EXCEPTION(TypeError) << "For " << op_name << ", input[0] must be sequence, but got " << queue->ToString(); - } - - if (CheckAndConvertUtils::IsTuple(queue)) { - return GetOutputType(primitive, queue); - } else { - return GetOutputType(primitive, queue); - } -} -} // namespace - -MIND_API_OPERATOR_IMPL(SequenceStack, BaseOperator); -AbstractBasePtr SequenceStackInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args) { - const int64_t kInputNum = 1; - CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, primitive->name()); - for (const auto &item : input_args) { - MS_EXCEPTION_IF_NULL(item); - } - auto infer_shape = SequenceStackInferShape(primitive, input_args); - auto infer_type = SequenceStackInferType(primitive, input_args); - return abstract::MakeAbstract(infer_shape, infer_type); -} - -// AG means auto generated -class MIND_API AGSequenceStackInfer : public abstract::OpInferBase { - public: - BaseShapePtr InferShape(const PrimitivePtr &primitive, - const std::vector &input_args) const override { - return SequenceStackInferShape(primitive, input_args); - } - TypePtr InferType(const PrimitivePtr &primitive, const std::vector &input_args) const override { - return SequenceStackInferType(primitive, input_args); - } - AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive, - const std::vector &input_args) const override { - return SequenceStackInfer(engine, primitive, input_args); - } -}; - -REGISTER_PRIMITIVE_OP_INFER_IMPL(SequenceStack, prim::kPrimSequenceStack, AGSequenceStackInfer, false); -} // namespace ops -} // namespace mindspore +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ops/sequence_stack.h" + +#include +#include +#include + +#include "abstract/abstract_value.h" +#include "abstract/dshape.h" +#include "abstract/ops/primitive_infer_map.h" +#include "abstract/utils.h" +#include "base/base.h" +#include "ir/anf.h" +#include "ir/dtype/type.h" +#include "ir/primitive.h" +#include "mindapi/base/shape_vector.h" +#include "mindapi/base/shared_ptr.h" +#include "mindapi/ir/value.h" +#include "mindapi/src/helper.h" +#include "ops/op_name.h" +#include "ops/primitive_c.h" +#include "ops/stack_comm.h" +#include "utils/check_convert_utils.h" +#include "utils/convert_utils_base.h" +#include "utils/log_adapter.h" +#include "utils/shape_utils.h" + +namespace mindspore { +namespace ops { +namespace { +constexpr int64_t kUnDim = -1; +constexpr int64_t kUnRank = -2; +} // namespace +void SequenceStack::set_axis(const int64_t axis) { (void)AddAttr(kAxis, api::MakeValue(axis)); } + +int64_t SequenceStack::get_axis() const { return GetValue(GetAttr(kAxis)); } + +void SequenceStack::Init(const int64_t axis) { this->set_axis(axis); } +namespace { +BaseShapePtr SequenceStackInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + auto op_name = primitive->name(); + auto queue = input_args[kIndex0]; + if (!CheckAndConvertUtils::IsSequence(queue)) { + MS_EXCEPTION(TypeError) << "For " << op_name << ", input[0] must be sequence, but got " << queue->ToString(); + } + + if (CheckAndConvertUtils::IsDynamicSequence(queue)) { + auto queue_shape = queue->GetShape()->cast(); + MS_EXCEPTION_IF_NULL(queue_shape); + return queue_shape->element_shape()->Clone(); + } + const int64_t kOneNum = 1; + auto queue_shape = queue->GetShape()->cast(); + MS_EXCEPTION_IF_NULL(queue_shape); + auto elements = queue_shape->shape(); + if (input_args.size() < 1) { + MS_LOG(ERROR) << "Invalid input size " << input_args.size(); + } + + if (input_args.size() == 1) { + if (!CheckAndConvertUtils::IsSequence(input_args[0])) { + MS_EXCEPTION(TypeError) << "For '" << op_name << "', the input data type must be list or tuple of tensors."; + } + } + + (void)CheckAndConvertUtils::CheckInteger("stack element num", SizeToLong(elements.size()), kGreaterEqual, kOneNum, + primitive->name()); + + bool has_rank_valid_shape = false; + ShapeVector input_shape; + size_t element_rank = 0; + for (size_t i = 0; i < elements.size(); ++i) { + auto input_shape_tmp = elements[i]->GetShapeVector(); + if (IsDynamicRank(input_shape_tmp)) { + continue; + } + + if (!has_rank_valid_shape) { + has_rank_valid_shape = true; + input_shape = input_shape_tmp; + element_rank = input_shape_tmp.size(); + continue; + } + if (input_shape_tmp.size() != input_shape.size()) { + MS_EXCEPTION(ValueError) << "All input shape size must be the same!"; + } + for (size_t j = 0; j < input_shape.size(); ++j) { + if (input_shape.at(j) == abstract::TensorShape::kShapeDimAny && + input_shape_tmp.at(j) != abstract::TensorShape::kShapeDimAny) { + input_shape[j] = input_shape_tmp.at(j); + continue; + } + if (input_shape_tmp.at(j) != input_shape.at(j)) { + MS_EXCEPTION(ValueError) << "All input shape must be the same! " << input_shape_tmp << " And " << input_shape; + } + } + } + + if (!has_rank_valid_shape) { + return std::make_shared(ShapeVector{abstract::TensorShape::kShapeRankAny}); + } + std::vector infer_shape = input_shape; + auto axis_temp = GetValue(primitive->GetAttr(kAxis)); + CheckAndConvertUtils::CheckInRange("Stack axis", axis_temp, kIncludeBoth, + {-SizeToLong(element_rank) - 1, SizeToLong(element_rank)}, + primitive->name()); + auto axis = axis_temp < 0 ? static_cast(axis_temp) + element_rank + 1 : LongToSize(axis_temp); + (void)infer_shape.insert(infer_shape.begin() + axis, elements.size()); + return std::make_shared(infer_shape); +} + +template +TypePtr GetOutputType(const PrimitivePtr &primitive, const AbstractBasePtr &queue) { + auto queue_type = queue->GetType()->cast(); + MS_EXCEPTION_IF_NULL(queue_type); + if (queue_type->dynamic_len()) { + return queue_type->dynamic_element_type()->Clone(); + } + if (queue_type->elements().empty()) { + MS_LOG(EXCEPTION) << "Sequence length should not be 0."; + } + + auto elements = queue_type->elements(); + primitive->AddAttr("num", MakeValue(SizeToLong(elements.size()))); + auto infer_type0 = elements[0]; + for (size_t i = 1; i < elements.size(); i++) { + auto infer_typei = elements[i]; + if (infer_typei == infer_type0) { + MS_EXCEPTION(TypeError) << "All input must have the same data type!input[" << i << "] data type = " << infer_typei + << "infer_type0= " << infer_type0; + } + } + return elements[kIndex0]->Clone(); +} + +TypePtr SequenceStackInferType(const PrimitivePtr &primitive, const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + auto op_name = primitive->name(); + auto queue = input_args[kIndex0]; + if (!CheckAndConvertUtils::IsTuple(queue) && !CheckAndConvertUtils::IsList(queue)) { + MS_EXCEPTION(TypeError) << "For " << op_name << ", input[0] must be sequence, but got " << queue->ToString(); + } + + if (CheckAndConvertUtils::IsTuple(queue)) { + return GetOutputType(primitive, queue); + } else { + return GetOutputType(primitive, queue); + } +} +} // namespace + +MIND_API_OPERATOR_IMPL(SequenceStack, BaseOperator); +AbstractBasePtr SequenceStackInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + const int64_t kInputNum = 1; + CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, primitive->name()); + for (const auto &item : input_args) { + MS_EXCEPTION_IF_NULL(item); + } + auto infer_shape = SequenceStackInferShape(primitive, input_args); + auto infer_type = SequenceStackInferType(primitive, input_args); + return abstract::MakeAbstract(infer_shape, infer_type); +} + +// AG means auto generated +class MIND_API AGSequenceStackInfer : public abstract::OpInferBase { + public: + BaseShapePtr InferShape(const PrimitivePtr &primitive, + const std::vector &input_args) const override { + return SequenceStackInferShape(primitive, input_args); + } + TypePtr InferType(const PrimitivePtr &primitive, const std::vector &input_args) const override { + return SequenceStackInferType(primitive, input_args); + } + AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive, + const std::vector &input_args) const override { + return SequenceStackInfer(engine, primitive, input_args); + } +}; + +REGISTER_PRIMITIVE_OP_INFER_IMPL(SequenceStack, prim::kPrimSequenceStack, AGSequenceStackInfer, false); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/sequence_stack.h b/mindspore/core/ops/sequence_stack.h index 9369d203ee0..a0af6981f6a 100644 --- a/mindspore/core/ops/sequence_stack.h +++ b/mindspore/core/ops/sequence_stack.h @@ -1,45 +1,45 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CORE_OPS_SEQUENCE_STACK_H_ -#define MINDSPORE_CORE_OPS_SEQUENCE_STACK_H_ - -#include "mindapi/base/types.h" -#include "mindspore/core/ops/sequence_ops.h" -#include "ops/base_operator.h" - -namespace mindspore { -namespace ops { -// constexpr auto kNameSequenceStack = "SequenceStack"; -/// \brief Sequence concat operation -class MIND_API SequenceStack : public BaseOperator { - public: - MIND_API_BASE_MEMBER(SequenceStack); - /// \brief Constructor. - SequenceStack() : BaseOperator(kSequenceStackOpName) {} - /// \brief Init function. - void Init(const int64_t axis); - /// \brief Set axis. - void set_axis(const int64_t axis); - /// \brief Get axis. - /// - /// \return axis. - int64_t get_axis() const; -}; -} // namespace ops -} // namespace mindspore - -#endif // MINDSPORE_CORE_OPS_SEQUENCE_STACK_H_ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_SEQUENCE_STACK_H_ +#define MINDSPORE_CORE_OPS_SEQUENCE_STACK_H_ + +#include "mindapi/base/types.h" +#include "mindspore/core/ops/sequence_ops.h" +#include "ops/base_operator.h" + +namespace mindspore { +namespace ops { +// constexpr auto kNameSequenceStack = "SequenceStack"; +/// \brief Sequence concat operation +class MIND_API SequenceStack : public BaseOperator { + public: + MIND_API_BASE_MEMBER(SequenceStack); + /// \brief Constructor. + SequenceStack() : BaseOperator(kSequenceStackOpName) {} + /// \brief Init function. + void Init(const int64_t axis); + /// \brief Set axis. + void set_axis(const int64_t axis); + /// \brief Get axis. + /// + /// \return axis. + int64_t get_axis() const; +}; +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_SEQUENCE_STACK_H_ diff --git a/mindspore/core/ops/sparse_addmm.h b/mindspore/core/ops/sparse_addmm.h index 378e98209d4..57cb7bbc63e 100644 --- a/mindspore/core/ops/sparse_addmm.h +++ b/mindspore/core/ops/sparse_addmm.h @@ -1,41 +1,41 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CORE_OPS_SPARSE_ADDMM_H_ -#define MINDSPORE_CORE_OPS_SPARSE_ADDMM_H_ -#include -#include -#include "mindapi/base/types.h" -#include "ops/base_operator.h" - -namespace mindspore { -namespace ops { -constexpr auto kNameSparseAddmm = "SparseAddmm"; -class MIND_API SparseAddmm : public BaseOperator { - public: - MIND_API_BASE_MEMBER(SparseAddmm); - SparseAddmm() : BaseOperator(kNameSparseAddmm) { - InitIOName({"indices", "values", "sparse_shape", "x2_dense", "x3_dense", "alpha", "beta"}, {"output"}); - } - void Init() const {} -}; -MIND_API abstract::AbstractBasePtr SparseAddmmInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args); -using PrimSparseAddmmPtr = std::shared_ptr; -} // namespace ops -} // namespace mindspore - -#endif +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_SPARSE_ADDMM_H_ +#define MINDSPORE_CORE_OPS_SPARSE_ADDMM_H_ +#include +#include +#include "mindapi/base/types.h" +#include "ops/base_operator.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameSparseAddmm = "SparseAddmm"; +class MIND_API SparseAddmm : public BaseOperator { + public: + MIND_API_BASE_MEMBER(SparseAddmm); + SparseAddmm() : BaseOperator(kNameSparseAddmm) { + InitIOName({"indices", "values", "sparse_shape", "x2_dense", "x3_dense", "alpha", "beta"}, {"output"}); + } + void Init() const {} +}; +MIND_API abstract::AbstractBasePtr SparseAddmmInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +using PrimSparseAddmmPtr = std::shared_ptr; +} // namespace ops +} // namespace mindspore + +#endif diff --git a/mindspore/core/ops/sparse_apply_momentum.h b/mindspore/core/ops/sparse_apply_momentum.h index 2608b9a86bb..f1b400e8f4a 100644 --- a/mindspore/core/ops/sparse_apply_momentum.h +++ b/mindspore/core/ops/sparse_apply_momentum.h @@ -1,55 +1,55 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CORE_OPS_SPARSE_APPLY_MOMENTUM_H_ -#define MINDSPORE_CORE_OPS_SPARSE_APPLY_MOMENTUM_H_ - -#include -#include -#include -#include - -#include "mindapi/base/types.h" -#include "ops/base_operator.h" -#include "utils/check_convert_utils.h" - -namespace mindspore { -namespace ops { -constexpr auto kNameSparseApplyMomentum = "SparseApplyMomentum"; -class MIND_API SparseApplyMomentum : public BaseOperator { - public: - MIND_API_BASE_MEMBER(SparseApplyMomentum); - SparseApplyMomentum() : BaseOperator(kNameSparseApplyMomentum) {} - - void Init(const bool use_locking = false, const bool use_nesterov = false); - - void set_use_locking(const bool use_locking); - - void set_use_nesterov(const bool use_nesterov); - - bool get_use_locking() const; - - bool get_use_nesterov() const; -}; - -MIND_API abstract::AbstractBasePtr SparseApplyMomentumInfer(const abstract::AnalysisEnginePtr &, - const PrimitivePtr &primitive, - const std::vector &input_args); -using kPrimSparseApplyMomentumPtr = std::shared_ptr; -} // namespace ops -} // namespace mindspore - -#endif // MINDSPORE_CORE_OPS_SPARSE_APPLY_MOMENTUM_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_SPARSE_APPLY_MOMENTUM_H_ +#define MINDSPORE_CORE_OPS_SPARSE_APPLY_MOMENTUM_H_ + +#include +#include +#include +#include + +#include "mindapi/base/types.h" +#include "ops/base_operator.h" +#include "utils/check_convert_utils.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameSparseApplyMomentum = "SparseApplyMomentum"; +class MIND_API SparseApplyMomentum : public BaseOperator { + public: + MIND_API_BASE_MEMBER(SparseApplyMomentum); + SparseApplyMomentum() : BaseOperator(kNameSparseApplyMomentum) {} + + void Init(const bool use_locking = false, const bool use_nesterov = false); + + void set_use_locking(const bool use_locking); + + void set_use_nesterov(const bool use_nesterov); + + bool get_use_locking() const; + + bool get_use_nesterov() const; +}; + +MIND_API abstract::AbstractBasePtr SparseApplyMomentumInfer(const abstract::AnalysisEnginePtr &, + const PrimitivePtr &primitive, + const std::vector &input_args); +using kPrimSparseApplyMomentumPtr = std::shared_ptr; +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_SPARSE_APPLY_MOMENTUM_H_ diff --git a/mindspore/core/ops/sparse_apply_proximal_gradient_descent.h b/mindspore/core/ops/sparse_apply_proximal_gradient_descent.h index b4b856d23aa..f69b18f48c1 100644 --- a/mindspore/core/ops/sparse_apply_proximal_gradient_descent.h +++ b/mindspore/core/ops/sparse_apply_proximal_gradient_descent.h @@ -1,50 +1,50 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CORE_OPS_SPARSE_APPLY_PROXIMAL_GRADIENT_DESCENT_H_ -#define MINDSPORE_CORE_OPS_SPARSE_APPLY_PROXIMAL_GRADIENT_DESCENT_H_ - -#include -#include -#include -#include - -#include "mindapi/base/types.h" -#include "ops/base_operator.h" -#include "utils/check_convert_utils.h" - -namespace mindspore { -namespace ops { -constexpr auto kNameSparseApplyProximalGradientDescent = "SparseApplyProximalGradientDescent"; -class MIND_API SparseApplyProximalGradientDescent : public BaseOperator { - public: - MIND_API_BASE_MEMBER(SparseApplyProximalGradientDescent); - SparseApplyProximalGradientDescent() : BaseOperator(kNameSparseApplyProximalGradientDescent) {} - - void Init(const bool use_locking = false); - - void set_use_locking(const bool use_locking); - - bool get_use_locking() const; -}; - -MIND_API abstract::AbstractBasePtr SparseApplyProximalGradientDescentInfer( - const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args); -using kPrimSparseApplyProximalGradientDescentPtr = std::shared_ptr; -} // namespace ops -} // namespace mindspore - -#endif // MINDSPORE_CORE_OPS_SPARSE_APPLY_PROXIMAL_GRADIENT_DESCENT_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_SPARSE_APPLY_PROXIMAL_GRADIENT_DESCENT_H_ +#define MINDSPORE_CORE_OPS_SPARSE_APPLY_PROXIMAL_GRADIENT_DESCENT_H_ + +#include +#include +#include +#include + +#include "mindapi/base/types.h" +#include "ops/base_operator.h" +#include "utils/check_convert_utils.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameSparseApplyProximalGradientDescent = "SparseApplyProximalGradientDescent"; +class MIND_API SparseApplyProximalGradientDescent : public BaseOperator { + public: + MIND_API_BASE_MEMBER(SparseApplyProximalGradientDescent); + SparseApplyProximalGradientDescent() : BaseOperator(kNameSparseApplyProximalGradientDescent) {} + + void Init(const bool use_locking = false); + + void set_use_locking(const bool use_locking); + + bool get_use_locking() const; +}; + +MIND_API abstract::AbstractBasePtr SparseApplyProximalGradientDescentInfer( + const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args); +using kPrimSparseApplyProximalGradientDescentPtr = std::shared_ptr; +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_SPARSE_APPLY_PROXIMAL_GRADIENT_DESCENT_H_ diff --git a/mindspore/core/ops/sparse_segment_sum.h b/mindspore/core/ops/sparse_segment_sum.h index 9f72b0c475a..2121fb714ff 100644 --- a/mindspore/core/ops/sparse_segment_sum.h +++ b/mindspore/core/ops/sparse_segment_sum.h @@ -1,45 +1,45 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_H_ -#define MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_H_ -#include -#include -#include -#include -#include -#include "abstract/abstract_value.h" -#include "ops/base_operator.h" -#include "utils/check_convert_utils.h" - -namespace mindspore { -namespace ops { -constexpr auto kNameSparseSegmentSum = "SparseSegmentSum"; -/// \brief Computes the sum along sparse segments of a tensor. -/// Refer to Python API @ref mindspore.ops.SparseSegmentSum for more details. -class MIND_API SparseSegmentSum : public BaseOperator { - public: - MIND_API_BASE_MEMBER(SparseSegmentSum); - /// \brief Constructor. - SparseSegmentSum() : BaseOperator(kNameSparseSegmentSum) { InitIOName({"x", "indices", "segment_ids"}, {"y"}); } -}; - -AbstractBasePtr SparseSegmentSumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args); -} // namespace ops -} // namespace mindspore - -#endif // MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_H_ +#define MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_H_ +#include +#include +#include +#include +#include +#include "abstract/abstract_value.h" +#include "ops/base_operator.h" +#include "utils/check_convert_utils.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameSparseSegmentSum = "SparseSegmentSum"; +/// \brief Computes the sum along sparse segments of a tensor. +/// Refer to Python API @ref mindspore.ops.SparseSegmentSum for more details. +class MIND_API SparseSegmentSum : public BaseOperator { + public: + MIND_API_BASE_MEMBER(SparseSegmentSum); + /// \brief Constructor. + SparseSegmentSum() : BaseOperator(kNameSparseSegmentSum) { InitIOName({"x", "indices", "segment_ids"}, {"y"}); } +}; + +AbstractBasePtr SparseSegmentSumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_H_ diff --git a/mindspore/core/ops/sparse_segment_sum_with_num_segments.h b/mindspore/core/ops/sparse_segment_sum_with_num_segments.h index ed0acc2c36c..9b119be78d8 100644 --- a/mindspore/core/ops/sparse_segment_sum_with_num_segments.h +++ b/mindspore/core/ops/sparse_segment_sum_with_num_segments.h @@ -1,47 +1,47 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_WITH_NUM_SEGMENTS_H_ -#define MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_WITH_NUM_SEGMENTS_H_ -#include -#include -#include -#include -#include -#include "abstract/abstract_value.h" -#include "ops/base_operator.h" -#include "utils/check_convert_utils.h" - -namespace mindspore { -namespace ops { -constexpr auto kNameSparseSegmentSumWithNumSegments = "SparseSegmentSumWithNumSegments"; -/// \brief Computes the sum along sparse segments of a tensor, but it is allowed to miss id in segment_ids. -/// Refer to Python API @ref mindspore.ops.SparseSegmentSumWithNumSegments for more details. -class MIND_API SparseSegmentSumWithNumSegments : public BaseOperator { - public: - MIND_API_BASE_MEMBER(SparseSegmentSumWithNumSegments); - /// \brief Constructor. - SparseSegmentSumWithNumSegments() : BaseOperator(kNameSparseSegmentSumWithNumSegments) { - InitIOName({"x", "indices", "segment_ids", "num_segments"}, {"y"}); - } -}; -AbstractBasePtr SparseSegmentSumWithNumSegmentsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args); -using PrimSparseSegmentSumWithNumSegmentsPtr = std::shared_ptr; -} // namespace ops -} // namespace mindspore - -#endif // MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_WITH_NUM_SEGMENTS_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_WITH_NUM_SEGMENTS_H_ +#define MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_WITH_NUM_SEGMENTS_H_ +#include +#include +#include +#include +#include +#include "abstract/abstract_value.h" +#include "ops/base_operator.h" +#include "utils/check_convert_utils.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameSparseSegmentSumWithNumSegments = "SparseSegmentSumWithNumSegments"; +/// \brief Computes the sum along sparse segments of a tensor, but it is allowed to miss id in segment_ids. +/// Refer to Python API @ref mindspore.ops.SparseSegmentSumWithNumSegments for more details. +class MIND_API SparseSegmentSumWithNumSegments : public BaseOperator { + public: + MIND_API_BASE_MEMBER(SparseSegmentSumWithNumSegments); + /// \brief Constructor. + SparseSegmentSumWithNumSegments() : BaseOperator(kNameSparseSegmentSumWithNumSegments) { + InitIOName({"x", "indices", "segment_ids", "num_segments"}, {"y"}); + } +}; +AbstractBasePtr SparseSegmentSumWithNumSegmentsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +using PrimSparseSegmentSumWithNumSegmentsPtr = std::shared_ptr; +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_WITH_NUM_SEGMENTS_H_ diff --git a/mindspore/core/ops/sparse_softmax.h b/mindspore/core/ops/sparse_softmax.h index c07ca0bc0d5..2d1563ed2e7 100644 --- a/mindspore/core/ops/sparse_softmax.h +++ b/mindspore/core/ops/sparse_softmax.h @@ -1,44 +1,44 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CORE_OPS_SPARSE_SOFTMAX_H_ -#define MINDSPORE_CORE_OPS_SPARSE_SOFTMAX_H_ -#include -#include -#include "mindapi/base/types.h" -#include "ops/base_operator.h" - -namespace mindspore { -namespace ops { -constexpr auto kNameSparseSoftmax = "SparseSoftmax"; -/// \brief Similar to softmax but with the catch that the implicitly zero -/// elements do not participate. Refer to Python API @ref -/// mindspore.ops.SparseSoftmax for more details. -class MIND_API SparseSoftmax : public BaseOperator { - public: - MIND_API_BASE_MEMBER(SparseSoftmax); - /// \brief Constructor. - SparseSoftmax() : BaseOperator(kNameSparseSoftmax) { InitIOName({"indices", "values", "shape"}, {"output"}); } - /// \brief Init. - void Init() const {} -}; - -MIND_API abstract::AbstractBasePtr SparseSoftmaxInfer(const abstract::AnalysisEnginePtr &, - const PrimitivePtr &primitive, - const std::vector &input_args); -} // namespace ops -} // namespace mindspore -#endif // MINDSPORE_CORE_OPS_SPARSE_SOFTMAX_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_SPARSE_SOFTMAX_H_ +#define MINDSPORE_CORE_OPS_SPARSE_SOFTMAX_H_ +#include +#include +#include "mindapi/base/types.h" +#include "ops/base_operator.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameSparseSoftmax = "SparseSoftmax"; +/// \brief Similar to softmax but with the catch that the implicitly zero +/// elements do not participate. Refer to Python API @ref +/// mindspore.ops.SparseSoftmax for more details. +class MIND_API SparseSoftmax : public BaseOperator { + public: + MIND_API_BASE_MEMBER(SparseSoftmax); + /// \brief Constructor. + SparseSoftmax() : BaseOperator(kNameSparseSoftmax) { InitIOName({"indices", "values", "shape"}, {"output"}); } + /// \brief Init. + void Init() const {} +}; + +MIND_API abstract::AbstractBasePtr SparseSoftmaxInfer(const abstract::AnalysisEnginePtr &, + const PrimitivePtr &primitive, + const std::vector &input_args); +} // namespace ops +} // namespace mindspore +#endif // MINDSPORE_CORE_OPS_SPARSE_SOFTMAX_H_ diff --git a/mindspore/core/ops/sparse_tensor_to_csr_sparse_matrix.h b/mindspore/core/ops/sparse_tensor_to_csr_sparse_matrix.h index d7c9b5bf8a9..111896148d3 100644 --- a/mindspore/core/ops/sparse_tensor_to_csr_sparse_matrix.h +++ b/mindspore/core/ops/sparse_tensor_to_csr_sparse_matrix.h @@ -1,43 +1,43 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CORE_OPS_SPARSE_TENSOR_TO_CSR_SPAESE_MATRIX_H_ -#define MINDSPORE_CORE_OPS_SPARSE_TENSOR_TO_CSR_SPAESE_MATRIX_H_ -#include -#include - -#include "mindapi/base/types.h" -#include "ops/base_operator.h" - -namespace mindspore { -namespace ops { -constexpr auto kNameSparseTensorToCSRSparseMatrix = "SparseTensorToCSRSparseMatrix"; -class MIND_API SparseTensorToCSRSparseMatrix : public BaseOperator { - public: - MIND_API_BASE_MEMBER(SparseTensorToCSRSparseMatrix); - SparseTensorToCSRSparseMatrix() : BaseOperator(kNameSparseTensorToCSRSparseMatrix) { - InitIOName({"x_indices", "x_values", "x_dense_shape"}, - {"y_dense_shape", "y_batch_pointers", "y_row_pointers", "y_col_indices", "y_values"}); - } -}; -MIND_API abstract::AbstractBasePtr SparseTensorToCSRSparseMatrixInfer( - const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args); -using PrimSparseTensorToCSRSparseMatrixPtr = std::shared_ptr; -} // namespace ops -} // namespace mindspore - -#endif // MINDSPORE_CORE_OPS_SPARSE_TENSOR_TO_CSR_SPAESE_MATRIX_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_SPARSE_TENSOR_TO_CSR_SPAESE_MATRIX_H_ +#define MINDSPORE_CORE_OPS_SPARSE_TENSOR_TO_CSR_SPAESE_MATRIX_H_ +#include +#include + +#include "mindapi/base/types.h" +#include "ops/base_operator.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameSparseTensorToCSRSparseMatrix = "SparseTensorToCSRSparseMatrix"; +class MIND_API SparseTensorToCSRSparseMatrix : public BaseOperator { + public: + MIND_API_BASE_MEMBER(SparseTensorToCSRSparseMatrix); + SparseTensorToCSRSparseMatrix() : BaseOperator(kNameSparseTensorToCSRSparseMatrix) { + InitIOName({"x_indices", "x_values", "x_dense_shape"}, + {"y_dense_shape", "y_batch_pointers", "y_row_pointers", "y_col_indices", "y_values"}); + } +}; +MIND_API abstract::AbstractBasePtr SparseTensorToCSRSparseMatrixInfer( + const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +using PrimSparseTensorToCSRSparseMatrixPtr = std::shared_ptr; +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_SPARSE_TENSOR_TO_CSR_SPAESE_MATRIX_H_ diff --git a/mindspore/core/ops/tensor_copy.h b/mindspore/core/ops/tensor_copy.h index 7e1147a766c..bfec7a31cd3 100644 --- a/mindspore/core/ops/tensor_copy.h +++ b/mindspore/core/ops/tensor_copy.h @@ -1,43 +1,43 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CORE_OPS_TENSOR_MOVE_ELEMENTS_H_ -#define MINDSPORE_CORE_OPS_TENSOR_MOVE_ELEMENTS_H_ -#include -#include -#include - -#include "mindapi/base/types.h" -#include "ops/base_operator.h" - -namespace mindspore { -namespace ops { -constexpr auto kNameTensorMove = "TensorMove"; -/// \brief Updates tensor values by using input indices and value. -/// Refer to Python API @ref mindspore.ops.TensorMove for more details. -class MIND_API TensorMove : public BaseOperator { - public: - MIND_API_BASE_MEMBER(TensorMove); - /// \brief Constructor. - TensorMove() : BaseOperator(kNameTensorMove) { InitIOName({"input"}, {"output"}); } -}; -MIND_API abstract::AbstractBasePtr TensorMoveInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args); -using kPrimTensorMovePtr = std::shared_ptr; -} // namespace ops -} // namespace mindspore - -#endif // MINDSPORE_CORE_OPS_TENSOR_SCATTER_ELEMENTS_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_TENSOR_MOVE_ELEMENTS_H_ +#define MINDSPORE_CORE_OPS_TENSOR_MOVE_ELEMENTS_H_ +#include +#include +#include + +#include "mindapi/base/types.h" +#include "ops/base_operator.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameTensorMove = "TensorMove"; +/// \brief Updates tensor values by using input indices and value. +/// Refer to Python API @ref mindspore.ops.TensorMove for more details. +class MIND_API TensorMove : public BaseOperator { + public: + MIND_API_BASE_MEMBER(TensorMove); + /// \brief Constructor. + TensorMove() : BaseOperator(kNameTensorMove) { InitIOName({"input"}, {"output"}); } +}; +MIND_API abstract::AbstractBasePtr TensorMoveInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +using kPrimTensorMovePtr = std::shared_ptr; +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_TENSOR_SCATTER_ELEMENTS_H_ diff --git a/mindspore/core/ops/tril.h b/mindspore/core/ops/tril.h index b640de99384..bc65e4a0c6c 100644 --- a/mindspore/core/ops/tril.h +++ b/mindspore/core/ops/tril.h @@ -1,47 +1,47 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CORE_OPS_TRIL_H_ -#define MINDSPORE_CORE_OPS_TRIL_H_ - -#include -#include -#include -#include - -#include "mindapi/base/types.h" -#include "ops/base_operator.h" - -namespace mindspore { -namespace ops { -constexpr auto kNameTril = "Tril"; -class MIND_API Tril : public BaseOperator { - public: - MIND_API_BASE_MEMBER(Tril); - Tril() : BaseOperator(kNameTril) { InitIOName({"x"}, {"y"}); } - /// \brief Init. - void Init(const int64_t diagonal = 0); - /// \brief Set diagonal. - void set_diagonal(const int64_t diagonal); - int64_t get_diagonal() const; -}; - -MIND_API abstract::AbstractBasePtr TrilInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args); -} // namespace ops -} // namespace mindspore - -#endif // MINDSPORE_CORE_OPS_TRIL_H_ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_TRIL_H_ +#define MINDSPORE_CORE_OPS_TRIL_H_ + +#include +#include +#include +#include + +#include "mindapi/base/types.h" +#include "ops/base_operator.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameTril = "Tril"; +class MIND_API Tril : public BaseOperator { + public: + MIND_API_BASE_MEMBER(Tril); + Tril() : BaseOperator(kNameTril) { InitIOName({"x"}, {"y"}); } + /// \brief Init. + void Init(const int64_t diagonal = 0); + /// \brief Set diagonal. + void set_diagonal(const int64_t diagonal); + int64_t get_diagonal() const; +}; + +MIND_API abstract::AbstractBasePtr TrilInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_TRIL_H_ diff --git a/mindspore/core/ops/tril_indices.h b/mindspore/core/ops/tril_indices.h index 24eeb9a20ed..46105e6fad5 100644 --- a/mindspore/core/ops/tril_indices.h +++ b/mindspore/core/ops/tril_indices.h @@ -1,64 +1,64 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CORE_OPS_TRIL_INDICES_H_ -#define MINDSPORE_CORE_OPS_TRIL_INDICES_H_ - -#include -#include - -#include "mindapi/base/types.h" -#include "ops/base_operator.h" - -namespace mindspore { -namespace ops { -constexpr auto kNameTrilIndices = "TrilIndices"; -/// \brief Returns the indices of the lower triangular part of a row-by- col matrix. -/// Refer to Python API @ref mindspore.ops.TrilIndices for more details. -class MIND_API TrilIndices : public BaseOperator { - public: - MIND_API_BASE_MEMBER(TrilIndices); - /// \brief Construct. - TrilIndices() : BaseOperator(kNameTrilIndices) { InitIOName({}, {"y"}); } - /// \brief Init. - void Init(const int64_t row, const int64_t col, const int64_t offset = 0); - /// \brief Set row. - void set_row(const int64_t row); - /// \brief Set col. - void set_col(const int64_t col); - /// \brief Set offset. - void set_offset(const int64_t offset); - - /// \brief Get row. - /// - /// \return row. - int64_t get_row() const; - /// \brief Get col. - /// - /// \return col. - int64_t get_col() const; - /// \brief Get offset. - /// - /// \return offset. - int64_t get_offset() const; -}; - -MIND_API abstract::AbstractBasePtr TrilIndicesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args); -} // namespace ops -} // namespace mindspore - -#endif // MINDSPORE_CORE_OPS_TRIL_INDICES_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_TRIL_INDICES_H_ +#define MINDSPORE_CORE_OPS_TRIL_INDICES_H_ + +#include +#include + +#include "mindapi/base/types.h" +#include "ops/base_operator.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameTrilIndices = "TrilIndices"; +/// \brief Returns the indices of the lower triangular part of a row-by- col matrix. +/// Refer to Python API @ref mindspore.ops.TrilIndices for more details. +class MIND_API TrilIndices : public BaseOperator { + public: + MIND_API_BASE_MEMBER(TrilIndices); + /// \brief Construct. + TrilIndices() : BaseOperator(kNameTrilIndices) { InitIOName({}, {"y"}); } + /// \brief Init. + void Init(const int64_t row, const int64_t col, const int64_t offset = 0); + /// \brief Set row. + void set_row(const int64_t row); + /// \brief Set col. + void set_col(const int64_t col); + /// \brief Set offset. + void set_offset(const int64_t offset); + + /// \brief Get row. + /// + /// \return row. + int64_t get_row() const; + /// \brief Get col. + /// + /// \return col. + int64_t get_col() const; + /// \brief Get offset. + /// + /// \return offset. + int64_t get_offset() const; +}; + +MIND_API abstract::AbstractBasePtr TrilIndicesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_TRIL_INDICES_H_ diff --git a/mindspore/core/ops/triu_indices.h b/mindspore/core/ops/triu_indices.h index 569c4b3a67f..6e50f4df4fb 100644 --- a/mindspore/core/ops/triu_indices.h +++ b/mindspore/core/ops/triu_indices.h @@ -1,64 +1,64 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CORE_OPS_TRIU_INDICES_H_ -#define MINDSPORE_CORE_OPS_TRIU_INDICES_H_ - -#include -#include - -#include "mindapi/base/types.h" -#include "ops/base_operator.h" - -namespace mindspore { -namespace ops { -constexpr auto kNameTriuIndices = "TriuIndices"; -/// \brief Returns the indices of the lower triangular part of a row-by- col matrix. -/// Refer to Python API @ref mindspore.ops.TriuIndices for more details. -class MIND_API TriuIndices : public BaseOperator { - public: - MIND_API_BASE_MEMBER(TriuIndices); - /// \brief Construct. - TriuIndices() : BaseOperator(kNameTriuIndices) { InitIOName({}, {"y"}); } - /// \brief Init. - void Init(const int64_t row, const int64_t col, const int64_t offset = 0); - /// \brief Set row. - void set_row(const int64_t row); - /// \brief Set col. - void set_col(const int64_t col); - /// \brief Set offset. - void set_offset(const int64_t offset); - - /// \brief Get row. - /// - /// \return row. - int64_t get_row() const; - /// \brief Get col. - /// - /// \return col. - int64_t get_col() const; - /// \brief Get offset. - /// - /// \return offset. - int64_t get_offset() const; -}; - -MIND_API abstract::AbstractBasePtr TriuIndicesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args); -} // namespace ops -} // namespace mindspore - -#endif // MINDSPORE_CORE_OPS_TRIU_INDICES_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_TRIU_INDICES_H_ +#define MINDSPORE_CORE_OPS_TRIU_INDICES_H_ + +#include +#include + +#include "mindapi/base/types.h" +#include "ops/base_operator.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameTriuIndices = "TriuIndices"; +/// \brief Returns the indices of the lower triangular part of a row-by- col matrix. +/// Refer to Python API @ref mindspore.ops.TriuIndices for more details. +class MIND_API TriuIndices : public BaseOperator { + public: + MIND_API_BASE_MEMBER(TriuIndices); + /// \brief Construct. + TriuIndices() : BaseOperator(kNameTriuIndices) { InitIOName({}, {"y"}); } + /// \brief Init. + void Init(const int64_t row, const int64_t col, const int64_t offset = 0); + /// \brief Set row. + void set_row(const int64_t row); + /// \brief Set col. + void set_col(const int64_t col); + /// \brief Set offset. + void set_offset(const int64_t offset); + + /// \brief Get row. + /// + /// \return row. + int64_t get_row() const; + /// \brief Get col. + /// + /// \return col. + int64_t get_col() const; + /// \brief Get offset. + /// + /// \return offset. + int64_t get_offset() const; +}; + +MIND_API abstract::AbstractBasePtr TriuIndicesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_TRIU_INDICES_H_ diff --git a/mindspore/lite/build_lite.sh b/mindspore/lite/build_lite.sh old mode 100755 new mode 100644 diff --git a/mindspore/lite/examples/export_models/models/NetworkInNetwork.py b/mindspore/lite/examples/export_models/models/NetworkInNetwork.py index ee084802cb8..8e6a4137940 100644 --- a/mindspore/lite/examples/export_models/models/NetworkInNetwork.py +++ b/mindspore/lite/examples/export_models/models/NetworkInNetwork.py @@ -1,81 +1,81 @@ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""NetworkInNetwork.""" - -import numpy as np -import mindspore.nn as nn -import mindspore.ops as ops -from mindspore import Tensor - - -# NiN block -class NiN(nn.Cell): - """class NiN""" - def __init__(self, num_classes=10, num_channel=3): - super().__init__() - self.size = ops.Size() - self.block0 = nn.SequentialCell( - # block 0 - nn.Conv2d(in_channels=num_channel, out_channels=192, kernel_size=5, stride=1, has_bias=False), - nn.ReLU(), - nn.Conv2d(in_channels=192, out_channels=160, kernel_size=1, stride=1, has_bias=True), - nn.ReLU(), - nn.Conv2d(in_channels=160, out_channels=96, kernel_size=1, stride=1, has_bias=True), - nn.ReLU(), - nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same'), - nn.Dropout(p=0.0) - ) - self.block1 = nn.SequentialCell( - # block 1 - nn.Conv2d(in_channels=96, out_channels=192, kernel_size=5, stride=1, has_bias=False), - nn.ReLU(), - nn.Conv2d(in_channels=192, out_channels=192, kernel_size=1, stride=1, has_bias=True), - nn.ReLU(), - nn.Conv2d(in_channels=192, out_channels=192, kernel_size=1, stride=1, has_bias=True), - nn.ReLU(), - nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same'), - nn.Dropout(p=0.0) - ) - self.block2 = nn.SequentialCell( - # block 2 - nn.Conv2d(in_channels=192, out_channels=192, kernel_size=3, stride=1, has_bias=False), - nn.ReLU(), - nn.Conv2d(in_channels=192, out_channels=192, kernel_size=1, stride=1, has_bias=True), - nn.ReLU(), - nn.Conv2d(in_channels=192, out_channels=num_classes, kernel_size=1, stride=1, has_bias=True), - nn.ReLU(), - nn.AvgPool2d(kernel_size=8, stride=1, pad_mode='valid') - ) - # flatten - self.flatten = nn.Flatten() - self._initialize_weights() - - def _initialize_weights(self): - self.init_parameters_data() - for _, m in self.cells_and_names(): - if isinstance(m, (nn.Conv2d)): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.set_data(Tensor(np.random.normal(0, np.sqrt(2. / n), - m.weight.data.shape).astype("float32"))) - if m.bias is not None: - m.bias.set_data( - Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) - - def construct(self, x): - out = self.block0(x) - out = self.block1(out) - out = self.block2(out) - out = self.flatten(out) - return out +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""NetworkInNetwork.""" + +import numpy as np +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor + + +# NiN block +class NiN(nn.Cell): + """class NiN""" + def __init__(self, num_classes=10, num_channel=3): + super().__init__() + self.size = ops.Size() + self.block0 = nn.SequentialCell( + # block 0 + nn.Conv2d(in_channels=num_channel, out_channels=192, kernel_size=5, stride=1, has_bias=False), + nn.ReLU(), + nn.Conv2d(in_channels=192, out_channels=160, kernel_size=1, stride=1, has_bias=True), + nn.ReLU(), + nn.Conv2d(in_channels=160, out_channels=96, kernel_size=1, stride=1, has_bias=True), + nn.ReLU(), + nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same'), + nn.Dropout(p=0.0) + ) + self.block1 = nn.SequentialCell( + # block 1 + nn.Conv2d(in_channels=96, out_channels=192, kernel_size=5, stride=1, has_bias=False), + nn.ReLU(), + nn.Conv2d(in_channels=192, out_channels=192, kernel_size=1, stride=1, has_bias=True), + nn.ReLU(), + nn.Conv2d(in_channels=192, out_channels=192, kernel_size=1, stride=1, has_bias=True), + nn.ReLU(), + nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same'), + nn.Dropout(p=0.0) + ) + self.block2 = nn.SequentialCell( + # block 2 + nn.Conv2d(in_channels=192, out_channels=192, kernel_size=3, stride=1, has_bias=False), + nn.ReLU(), + nn.Conv2d(in_channels=192, out_channels=192, kernel_size=1, stride=1, has_bias=True), + nn.ReLU(), + nn.Conv2d(in_channels=192, out_channels=num_classes, kernel_size=1, stride=1, has_bias=True), + nn.ReLU(), + nn.AvgPool2d(kernel_size=8, stride=1, pad_mode='valid') + ) + # flatten + self.flatten = nn.Flatten() + self._initialize_weights() + + def _initialize_weights(self): + self.init_parameters_data() + for _, m in self.cells_and_names(): + if isinstance(m, (nn.Conv2d)): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.set_data(Tensor(np.random.normal(0, np.sqrt(2. / n), + m.weight.data.shape).astype("float32"))) + if m.bias is not None: + m.bias.set_data( + Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) + + def construct(self, x): + out = self.block0(x) + out = self.block1(out) + out = self.block2(out) + out = self.flatten(out) + return out diff --git a/mindspore/lite/examples/export_models/models/effnet.py b/mindspore/lite/examples/export_models/models/effnet.py old mode 100755 new mode 100644 index fc498724f66..28c2a70fdcb --- a/mindspore/lite/examples/export_models/models/effnet.py +++ b/mindspore/lite/examples/export_models/models/effnet.py @@ -1,313 +1,313 @@ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""effnet.""" - -import numpy as np -import mindspore.nn as nn -from mindspore.ops import operations as P -from mindspore.common.initializer import TruncatedNormal -from mindspore import Tensor - - -def weight_variable(): - """weight initial""" - return TruncatedNormal(0.02) - - -def _make_value_divisible(value, factor, min_value=None): - """ - It ensures that all layers have a channel number that is divisible by 8 - :param v: value to process - :param factor: divisor - :param min_value: new value always greater than the min_value - :return: new value - """ - if min_value is None: - min_value = factor - new_value = max(int(value + factor / 2) // factor * factor, min_value) - if new_value < value * 0.9: - new_value += factor - return new_value - - -class Swish(nn.Cell): - def __init__(self): - super().__init__() - self.sigmoid = nn.Sigmoid() - - def construct(self, x): - s = self.sigmoid(x) - m = x*s - return m - - -class AdaptiveAvgPool(nn.Cell): - def __init__(self, output_size=None): - super().__init__() - self.mean = P.ReduceMean(keep_dims=True) - self.output_size = output_size - - def construct(self, x): - return self.mean(x, (2, 3)) # This is not a general case - - -class SELayer(nn.Cell): - """SELayer""" - def __init__(self, channel, reduction=4): - super().__init__() - reduced_chs = _make_value_divisible(channel/reduction, 1) - self.avg_pool = AdaptiveAvgPool(output_size=(1, 1)) - weight = weight_variable() - self.conv_reduce = nn.Conv2d(in_channels=channel, out_channels=reduced_chs, kernel_size=1, has_bias=True, - weight_init=weight) - self.act1 = Swish() - self.conv_expand = nn.Conv2d(in_channels=reduced_chs, out_channels=channel, kernel_size=1, has_bias=True) - self.act2 = nn.Sigmoid() - - def construct(self, x): - o = self.avg_pool(x) # .view(b,c) - o = self.conv_reduce(o) - o = self.act1(o) - o = self.conv_expand(o) - o = self.act2(o) # .view(b, c, 1,1) - return x * o - - -class DepthwiseSeparableConv(nn.Cell): - """DepthwiseSeparableConv""" - def __init__(self, in_chs, out_chs, dw_kernel_size=3, stride=1, noskip=False, se_ratio=0.0, drop_connect_rate=0.0): - super().__init__() - if stride not in [1, 2]: - print("ERROR stride param") - return - self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip - self.drop_connect_rate = drop_connect_rate - - self.conv_dw = nn.Conv2d(in_channels=in_chs, out_channels=in_chs, kernel_size=dw_kernel_size, stride=stride, - pad_mode="pad", padding=1, has_bias=False, group=in_chs) - self.bn1 = nn.BatchNorm2d(in_chs, eps=0.001) # momentum=0.1) - self.act1 = Swish() - - # Squeeze-and-excitation - if se_ratio is not None and se_ratio > 0.: - self.se = SELayer(in_chs, reduction=se_ratio) - else: - print("ERRRRRORRRR -- not prepared for this one\n") - - self.conv_pw = nn.Conv2d(in_channels=in_chs, out_channels=out_chs, kernel_size=1, stride=stride, has_bias=False) - self.bn2 = nn.BatchNorm2d(out_chs, eps=0.001) # momentum=0.1) - - def construct(self, x): - """construct""" - residual = x - - x = self.conv_dw(x) - x = self.bn1(x) - x = self.act1(x) - - x = self.se(x) - - x = self.conv_pw(x) - x = self.bn2(x) - - if self.has_residual: - x += residual - return x - - -def conv_3x3_bn(inp, oup, stride): - weight = weight_variable() - return nn.SequentialCell([ - nn.Conv2d(in_channels=inp, out_channels=oup, kernel_size=3, stride=stride, padding=1, weight_init=weight, - has_bias=False, pad_mode='pad'), - nn.BatchNorm2d(oup, eps=0.001), # momentum=0.1), - nn.HSwish()]) - - -def conv_1x1_bn(inp, oup): - weight = weight_variable() - return nn.SequentialCell([ - nn.Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, stride=1, padding=0, weight_init=weight, - has_bias=False), - nn.BatchNorm2d(oup, eps=0.001), - nn.HSwish()]) - - -class InvertedResidual(nn.Cell): - """InvertedResidual""" - def __init__(self, in_chs, out_chs, kernel_size, stride, padding, expansion, se_ratio): - super().__init__() - if stride not in [1, 2]: - print("ERROR stride param") - return - mid_chs: int = _make_value_divisible(in_chs * expansion, 1) - self.has_residual = (in_chs == out_chs and stride == 1) - self.drop_connect_rate = 0 - - # Point-wise expansion - self.conv_pw = nn.Conv2d(in_channels=in_chs, out_channels=mid_chs, kernel_size=1, stride=1, has_bias=False) - self.bn1 = nn.BatchNorm2d(mid_chs, eps=0.001) - self.act1 = Swish() - - # Depth-wise convolution - if stride > 1: - self.conv_dw = nn.Conv2d(in_channels=mid_chs, out_channels=mid_chs, kernel_size=kernel_size, stride=stride, - padding=padding, has_bias=False, group=mid_chs, pad_mode='same') - else: - self.conv_dw = nn.Conv2d(in_channels=mid_chs, out_channels=mid_chs, kernel_size=kernel_size, stride=stride, - padding=padding, has_bias=False, group=mid_chs, pad_mode='pad') - self.bn2 = nn.BatchNorm2d(mid_chs, eps=0.001) - self.act2 = Swish() - - # Squeeze-and-excitation - if se_ratio is not None and se_ratio > 0.: - self.se = SELayer(mid_chs, reduction=se_ratio) - else: - print("ERRRRRORRRR -- not prepared for this one\n") - - # Point-wise linear projection - self.conv_pwl = nn.Conv2d(in_channels=mid_chs, out_channels=out_chs, kernel_size=1, stride=1, has_bias=False) - self.bn3 = nn.BatchNorm2d(out_chs, eps=0.001) - - def construct(self, x): - """construct""" - residual = x - - # Point-wise expansion - x = self.conv_pw(x) - x = self.bn1(x) - x = self.act1(x) - - # Depth-wise convolution - x = self.conv_dw(x) - x = self.bn2(x) - x = self.act2(x) - - # Squeeze-and-excitation - x = self.se(x) - - # Point-wise linear projection - x = self.conv_pwl(x) - x = self.bn3(x) - - if self.has_residual: - x += residual - return x - - -class EfficientNet(nn.Cell): - """EfficientNet""" - def __init__(self, cfgs, num_classes=1000): - super().__init__() - # setting of inverted residual blocks - self.cfgs = cfgs - stem_size = 32 - self.num_classes_ = num_classes - self.num_features_ = 1280 - - self.conv_stem = nn.Conv2d(in_channels=3, out_channels=stem_size, kernel_size=3, stride=2, has_bias=False) - - self.bn1 = nn.BatchNorm2d(stem_size, eps=0.001) # momentum=0.1) - self.act1 = Swish() - in_chs = stem_size - - layers = [nn.SequentialCell([DepthwiseSeparableConv(in_chs, 16, 3, 1, se_ratio=4)]), - - nn.SequentialCell([InvertedResidual(16, 24, 3, 2, 0, 6, se_ratio=24), - InvertedResidual(24, 24, 3, 1, 1, 6, se_ratio=24)]), - - nn.SequentialCell([InvertedResidual(24, 40, 5, 2, 0, 6, se_ratio=24), - InvertedResidual(40, 40, 5, 1, 2, 6, se_ratio=24)]), - - nn.SequentialCell([InvertedResidual(40, 80, 3, 2, 0, 6, se_ratio=24), - InvertedResidual(80, 80, 3, 1, 1, 6, se_ratio=24), - InvertedResidual(80, 80, 3, 1, 1, 6, se_ratio=24)]), - - nn.SequentialCell([InvertedResidual(80, 112, 5, 1, 2, 6, se_ratio=24), - InvertedResidual(112, 112, 5, 1, 2, 6, se_ratio=24), - InvertedResidual(112, 112, 5, 1, 2, 6, se_ratio=24)]), - - nn.SequentialCell([InvertedResidual(112, 192, 5, 2, 0, 6, se_ratio=24), - InvertedResidual(192, 192, 5, 1, 2, 6, se_ratio=24), - InvertedResidual(192, 192, 5, 1, 2, 6, se_ratio=24), - InvertedResidual(192, 192, 5, 1, 2, 6, se_ratio=24)]), - - nn.SequentialCell([InvertedResidual(192, 320, 3, 1, 1, 6, se_ratio=24)]) - ] - self.blocks = nn.SequentialCell(layers) - - self.conv_head = nn.Conv2d(in_channels=320, out_channels=self.num_features_, kernel_size=1) - self.bn2 = nn.BatchNorm2d(self.num_features_, eps=0.001) # momentum=0.1) - self.act2 = Swish() - self.global_pool = AdaptiveAvgPool(output_size=(1, 1)) - self.classifier = nn.Dense(self.num_features_, num_classes) - - self._initialize_weights() - - def construct(self, x): - """construct""" - x = self.conv_stem(x) - x = self.bn1(x) - x = self.act1(x) - x = self.blocks(x) - x = self.conv_head(x) - x = self.bn2(x) - x = self.act2(x) - x = self.global_pool(x) - x = P.Reshape()(x, (-1, self.num_features_)) - x = self.classifier(x) - return x - - def _initialize_weights(self): - """_initialize_weights""" - def init_linear_weight(m): - m.weight.set_data(Tensor(np.random.normal(0, 0.01, m.weight.data.shape).astype("float32"))) - if m.bias is not None: - m.bias.set_data(Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) - - for m in self.cells(): - if isinstance(m, nn.Conv2d): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.set_data(Tensor(np.random.normal(0, np.sqrt(2. / n), m.weight.data.shape).astype("float32"))) - if m.bias is not None: - m.bias.data.zero_() - m.weight.requires_grad = True - elif isinstance(m, nn.BatchNorm2d): - m.gamma.set_data(Tensor(np.ones(m.gamma.data.shape, dtype="float32"))) - m.beta.set_data(Tensor(np.zeros(m.beta.data.shape, dtype="float32"))) - elif isinstance(m, nn.Dense): - init_linear_weight(m) - - -def effnet(**kwargs): - """ - Constructs a EfficientNet model - """ - cfgs = [ - # k, t, c, SE, HS, s - [3, 1, 16, 1, 0, 2], - [3, 4.5, 24, 0, 0, 2], - [3, 3.67, 24, 0, 0, 1], - [5, 4, 40, 1, 1, 2], - [5, 6, 40, 1, 1, 1], - [5, 6, 40, 1, 1, 1], - [5, 3, 48, 1, 1, 1], - [5, 3, 48, 1, 1, 1], - [5, 6, 96, 1, 1, 2], - [5, 6, 96, 1, 1, 1], - [5, 6, 96, 1, 1, 1], - ] - - return EfficientNet(cfgs, **kwargs) +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""effnet.""" + +import numpy as np +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore.common.initializer import TruncatedNormal +from mindspore import Tensor + + +def weight_variable(): + """weight initial""" + return TruncatedNormal(0.02) + + +def _make_value_divisible(value, factor, min_value=None): + """ + It ensures that all layers have a channel number that is divisible by 8 + :param v: value to process + :param factor: divisor + :param min_value: new value always greater than the min_value + :return: new value + """ + if min_value is None: + min_value = factor + new_value = max(int(value + factor / 2) // factor * factor, min_value) + if new_value < value * 0.9: + new_value += factor + return new_value + + +class Swish(nn.Cell): + def __init__(self): + super().__init__() + self.sigmoid = nn.Sigmoid() + + def construct(self, x): + s = self.sigmoid(x) + m = x*s + return m + + +class AdaptiveAvgPool(nn.Cell): + def __init__(self, output_size=None): + super().__init__() + self.mean = P.ReduceMean(keep_dims=True) + self.output_size = output_size + + def construct(self, x): + return self.mean(x, (2, 3)) # This is not a general case + + +class SELayer(nn.Cell): + """SELayer""" + def __init__(self, channel, reduction=4): + super().__init__() + reduced_chs = _make_value_divisible(channel/reduction, 1) + self.avg_pool = AdaptiveAvgPool(output_size=(1, 1)) + weight = weight_variable() + self.conv_reduce = nn.Conv2d(in_channels=channel, out_channels=reduced_chs, kernel_size=1, has_bias=True, + weight_init=weight) + self.act1 = Swish() + self.conv_expand = nn.Conv2d(in_channels=reduced_chs, out_channels=channel, kernel_size=1, has_bias=True) + self.act2 = nn.Sigmoid() + + def construct(self, x): + o = self.avg_pool(x) # .view(b,c) + o = self.conv_reduce(o) + o = self.act1(o) + o = self.conv_expand(o) + o = self.act2(o) # .view(b, c, 1,1) + return x * o + + +class DepthwiseSeparableConv(nn.Cell): + """DepthwiseSeparableConv""" + def __init__(self, in_chs, out_chs, dw_kernel_size=3, stride=1, noskip=False, se_ratio=0.0, drop_connect_rate=0.0): + super().__init__() + if stride not in [1, 2]: + print("ERROR stride param") + return + self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip + self.drop_connect_rate = drop_connect_rate + + self.conv_dw = nn.Conv2d(in_channels=in_chs, out_channels=in_chs, kernel_size=dw_kernel_size, stride=stride, + pad_mode="pad", padding=1, has_bias=False, group=in_chs) + self.bn1 = nn.BatchNorm2d(in_chs, eps=0.001) # momentum=0.1) + self.act1 = Swish() + + # Squeeze-and-excitation + if se_ratio is not None and se_ratio > 0.: + self.se = SELayer(in_chs, reduction=se_ratio) + else: + print("ERRRRRORRRR -- not prepared for this one\n") + + self.conv_pw = nn.Conv2d(in_channels=in_chs, out_channels=out_chs, kernel_size=1, stride=stride, has_bias=False) + self.bn2 = nn.BatchNorm2d(out_chs, eps=0.001) # momentum=0.1) + + def construct(self, x): + """construct""" + residual = x + + x = self.conv_dw(x) + x = self.bn1(x) + x = self.act1(x) + + x = self.se(x) + + x = self.conv_pw(x) + x = self.bn2(x) + + if self.has_residual: + x += residual + return x + + +def conv_3x3_bn(inp, oup, stride): + weight = weight_variable() + return nn.SequentialCell([ + nn.Conv2d(in_channels=inp, out_channels=oup, kernel_size=3, stride=stride, padding=1, weight_init=weight, + has_bias=False, pad_mode='pad'), + nn.BatchNorm2d(oup, eps=0.001), # momentum=0.1), + nn.HSwish()]) + + +def conv_1x1_bn(inp, oup): + weight = weight_variable() + return nn.SequentialCell([ + nn.Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, stride=1, padding=0, weight_init=weight, + has_bias=False), + nn.BatchNorm2d(oup, eps=0.001), + nn.HSwish()]) + + +class InvertedResidual(nn.Cell): + """InvertedResidual""" + def __init__(self, in_chs, out_chs, kernel_size, stride, padding, expansion, se_ratio): + super().__init__() + if stride not in [1, 2]: + print("ERROR stride param") + return + mid_chs: int = _make_value_divisible(in_chs * expansion, 1) + self.has_residual = (in_chs == out_chs and stride == 1) + self.drop_connect_rate = 0 + + # Point-wise expansion + self.conv_pw = nn.Conv2d(in_channels=in_chs, out_channels=mid_chs, kernel_size=1, stride=1, has_bias=False) + self.bn1 = nn.BatchNorm2d(mid_chs, eps=0.001) + self.act1 = Swish() + + # Depth-wise convolution + if stride > 1: + self.conv_dw = nn.Conv2d(in_channels=mid_chs, out_channels=mid_chs, kernel_size=kernel_size, stride=stride, + padding=padding, has_bias=False, group=mid_chs, pad_mode='same') + else: + self.conv_dw = nn.Conv2d(in_channels=mid_chs, out_channels=mid_chs, kernel_size=kernel_size, stride=stride, + padding=padding, has_bias=False, group=mid_chs, pad_mode='pad') + self.bn2 = nn.BatchNorm2d(mid_chs, eps=0.001) + self.act2 = Swish() + + # Squeeze-and-excitation + if se_ratio is not None and se_ratio > 0.: + self.se = SELayer(mid_chs, reduction=se_ratio) + else: + print("ERRRRRORRRR -- not prepared for this one\n") + + # Point-wise linear projection + self.conv_pwl = nn.Conv2d(in_channels=mid_chs, out_channels=out_chs, kernel_size=1, stride=1, has_bias=False) + self.bn3 = nn.BatchNorm2d(out_chs, eps=0.001) + + def construct(self, x): + """construct""" + residual = x + + # Point-wise expansion + x = self.conv_pw(x) + x = self.bn1(x) + x = self.act1(x) + + # Depth-wise convolution + x = self.conv_dw(x) + x = self.bn2(x) + x = self.act2(x) + + # Squeeze-and-excitation + x = self.se(x) + + # Point-wise linear projection + x = self.conv_pwl(x) + x = self.bn3(x) + + if self.has_residual: + x += residual + return x + + +class EfficientNet(nn.Cell): + """EfficientNet""" + def __init__(self, cfgs, num_classes=1000): + super().__init__() + # setting of inverted residual blocks + self.cfgs = cfgs + stem_size = 32 + self.num_classes_ = num_classes + self.num_features_ = 1280 + + self.conv_stem = nn.Conv2d(in_channels=3, out_channels=stem_size, kernel_size=3, stride=2, has_bias=False) + + self.bn1 = nn.BatchNorm2d(stem_size, eps=0.001) # momentum=0.1) + self.act1 = Swish() + in_chs = stem_size + + layers = [nn.SequentialCell([DepthwiseSeparableConv(in_chs, 16, 3, 1, se_ratio=4)]), + + nn.SequentialCell([InvertedResidual(16, 24, 3, 2, 0, 6, se_ratio=24), + InvertedResidual(24, 24, 3, 1, 1, 6, se_ratio=24)]), + + nn.SequentialCell([InvertedResidual(24, 40, 5, 2, 0, 6, se_ratio=24), + InvertedResidual(40, 40, 5, 1, 2, 6, se_ratio=24)]), + + nn.SequentialCell([InvertedResidual(40, 80, 3, 2, 0, 6, se_ratio=24), + InvertedResidual(80, 80, 3, 1, 1, 6, se_ratio=24), + InvertedResidual(80, 80, 3, 1, 1, 6, se_ratio=24)]), + + nn.SequentialCell([InvertedResidual(80, 112, 5, 1, 2, 6, se_ratio=24), + InvertedResidual(112, 112, 5, 1, 2, 6, se_ratio=24), + InvertedResidual(112, 112, 5, 1, 2, 6, se_ratio=24)]), + + nn.SequentialCell([InvertedResidual(112, 192, 5, 2, 0, 6, se_ratio=24), + InvertedResidual(192, 192, 5, 1, 2, 6, se_ratio=24), + InvertedResidual(192, 192, 5, 1, 2, 6, se_ratio=24), + InvertedResidual(192, 192, 5, 1, 2, 6, se_ratio=24)]), + + nn.SequentialCell([InvertedResidual(192, 320, 3, 1, 1, 6, se_ratio=24)]) + ] + self.blocks = nn.SequentialCell(layers) + + self.conv_head = nn.Conv2d(in_channels=320, out_channels=self.num_features_, kernel_size=1) + self.bn2 = nn.BatchNorm2d(self.num_features_, eps=0.001) # momentum=0.1) + self.act2 = Swish() + self.global_pool = AdaptiveAvgPool(output_size=(1, 1)) + self.classifier = nn.Dense(self.num_features_, num_classes) + + self._initialize_weights() + + def construct(self, x): + """construct""" + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + x = self.blocks(x) + x = self.conv_head(x) + x = self.bn2(x) + x = self.act2(x) + x = self.global_pool(x) + x = P.Reshape()(x, (-1, self.num_features_)) + x = self.classifier(x) + return x + + def _initialize_weights(self): + """_initialize_weights""" + def init_linear_weight(m): + m.weight.set_data(Tensor(np.random.normal(0, 0.01, m.weight.data.shape).astype("float32"))) + if m.bias is not None: + m.bias.set_data(Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) + + for m in self.cells(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.set_data(Tensor(np.random.normal(0, np.sqrt(2. / n), m.weight.data.shape).astype("float32"))) + if m.bias is not None: + m.bias.data.zero_() + m.weight.requires_grad = True + elif isinstance(m, nn.BatchNorm2d): + m.gamma.set_data(Tensor(np.ones(m.gamma.data.shape, dtype="float32"))) + m.beta.set_data(Tensor(np.zeros(m.beta.data.shape, dtype="float32"))) + elif isinstance(m, nn.Dense): + init_linear_weight(m) + + +def effnet(**kwargs): + """ + Constructs a EfficientNet model + """ + cfgs = [ + # k, t, c, SE, HS, s + [3, 1, 16, 1, 0, 2], + [3, 4.5, 24, 0, 0, 2], + [3, 3.67, 24, 0, 0, 1], + [5, 4, 40, 1, 1, 2], + [5, 6, 40, 1, 1, 1], + [5, 6, 40, 1, 1, 1], + [5, 3, 48, 1, 1, 1], + [5, 3, 48, 1, 1, 1], + [5, 6, 96, 1, 1, 2], + [5, 6, 96, 1, 1, 1], + [5, 6, 96, 1, 1, 1], + ] + + return EfficientNet(cfgs, **kwargs) diff --git a/mindspore/lite/examples/export_models/prepare.sh b/mindspore/lite/examples/export_models/prepare.sh old mode 100755 new mode 100644 diff --git a/mindspore/lite/examples/quick_start_micro/mobilenetv2_arm64/README.md b/mindspore/lite/examples/quick_start_micro/mobilenetv2_arm64/README.md old mode 100755 new mode 100644 diff --git a/mindspore/lite/examples/train_lenet_cpp/Makefile b/mindspore/lite/examples/train_lenet_cpp/Makefile deleted file mode 100644 index 465ece34d87..00000000000 --- a/mindspore/lite/examples/train_lenet_cpp/Makefile +++ /dev/null @@ -1,58 +0,0 @@ -BASE_DIR=$(realpath ../../../../) -APP:=bin/net_runner -INF_APP:=bin/infer -LMSTLIB:=-lmindspore-lite-train -lminddata-lite -LMSLIB:=-lmindspore-lite -MSDIR:=$(realpath package-$(TARGET)/lib) -ifneq ("$(wildcard $(MSDIR)/libhiai.so)","") - LHIAILIB:=-lhiai_ir_build -lhiai_ir -lhiai -else - LHIAILIB:= -endif - -SRC:=src/net_runner.cc -OBJ:=$(SRC:.cc=.o) - -INF_SRC:=src/inference.cc -INF_OBJ:=$(INF_SRC:.cc=.o) - -CFLAGS := -Ofast -std=c++17 \ - -I . \ - -I ./msl/runtime \ - -I ./msl/runtime/include \ - -I ./msl/runtime/minddata \ - -I ./msl/tools/third_party/flatbuffers/include - - -ifeq ($(TARGET),arm64) -CXX := ${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64/bin/clang++ -CFLAGS += --target=aarch64-none-linux-android21 --gcc-toolchain=${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64 --sysroot=${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64/sysroot -fdata-sections -ffunction-sections -LDFLAGS := --target=aarch64-none-linux-android21 --gcc-toolchain=${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64 --sysroot=${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64/sysroot -Wl,--gc-sections -LDFLAGS += -L$(MSDIR) $(LMSLIB) -pthread -llog -latomic -lm $(LHIAILIB) -Wl,-rpath,$(MSDIR) -else -CFLAGS += -g -LDFLAGS := -L$(MSDIR) $(LMSLIB) -lpthread -Wl,-rpath,$(MSDIR) -endif -LD := ${CXX} - - -all:$(APP) $(INF_APP) - -$(APP): $(OBJ) - @mkdir -p bin - $(LD) $(OBJ) $(LMSTLIB) $(LDFLAGS) -o $@ - -$(INF_APP): $(INF_OBJ) - @mkdir -p bin - $(LD) $(INF_OBJ) $(LDFLAGS) -o $@ - - -clean: - rm -rf src/*.o bin/ - - -mrproper: - rm -rf package* msl src/*.o bin/ model/*.mindir model/*.ms model/*.so* model/converter_lite - -%.o:%.cc - $(CXX) $(CFLAGS) -c $< -o $@ diff --git a/mindspore/lite/examples/train_lenet_cpp/model/prepare_model.sh b/mindspore/lite/examples/train_lenet_cpp/model/prepare_model.sh old mode 100755 new mode 100644 diff --git a/mindspore/lite/examples/train_lenet_cpp/prepare_and_run.sh b/mindspore/lite/examples/train_lenet_cpp/prepare_and_run.sh old mode 100755 new mode 100644 diff --git a/mindspore/lite/examples/train_lenet_cpp/scripts/batch_of32.dat b/mindspore/lite/examples/train_lenet_cpp/scripts/batch_of32.dat deleted file mode 100644 index 5e79e95ef23..00000000000 Binary files a/mindspore/lite/examples/train_lenet_cpp/scripts/batch_of32.dat and /dev/null differ diff --git a/mindspore/lite/examples/train_lenet_cpp/scripts/eval.sh b/mindspore/lite/examples/train_lenet_cpp/scripts/eval.sh old mode 100755 new mode 100644 diff --git a/mindspore/lite/examples/train_lenet_cpp/scripts/infer.sh b/mindspore/lite/examples/train_lenet_cpp/scripts/infer.sh old mode 100755 new mode 100644 diff --git a/mindspore/lite/examples/train_lenet_cpp/scripts/train.sh b/mindspore/lite/examples/train_lenet_cpp/scripts/train.sh old mode 100755 new mode 100644 diff --git a/mindspore/lite/examples/train_lenet_java/build.sh b/mindspore/lite/examples/train_lenet_java/build.sh old mode 100755 new mode 100644 diff --git a/mindspore/lite/examples/train_lenet_java/model/prepare_model.sh b/mindspore/lite/examples/train_lenet_java/model/prepare_model.sh old mode 100755 new mode 100644 diff --git a/mindspore/lite/examples/train_lenet_java/prepare_and_run.sh b/mindspore/lite/examples/train_lenet_java/prepare_and_run.sh old mode 100755 new mode 100644 diff --git a/mindspore/lite/examples/transfer_learning/Makefile b/mindspore/lite/examples/transfer_learning/Makefile deleted file mode 100644 index 2caefe48536..00000000000 --- a/mindspore/lite/examples/transfer_learning/Makefile +++ /dev/null @@ -1,48 +0,0 @@ -BASE_DIR=$(realpath ../../../../) -APP:=bin/net_runner -LMSLIB:=-lmindspore-lite-train -lmindspore-lite -LMDLIB:=-lminddata-lite -MSDIR:=$(realpath package-$(TARGET)/lib) -ifneq ("$(wildcard $(MSDIR)/libhiai.so)","") - LHIAILIB:=-lhiai_ir_build -lhiai_ir -lhiai -else - LHIAILIB:= -endif - -SRC:=src/net_runner.cc src/dataset.cc -OBJ:=$(SRC:.cc=.o) - -CFLAGS := -Ofast -std=c++17 \ - -I . \ - -I ./msl/runtime \ - -I ./msl/runtime/include \ - -I ./msl/runtime/minddata \ - -I ./msl/tools/third_party/flatbuffers/include - - -ifeq ($(TARGET),arm64) -CXX := ${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64/bin/clang++ -CFLAGS += --target=aarch64-none-linux-android21 --gcc-toolchain=${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64 --sysroot=${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64/sysroot -fdata-sections -ffunction-sections -LDFLAGS := --target=aarch64-none-linux-android21 --gcc-toolchain=${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64 --sysroot=${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64/sysroot -Wl,--gc-sections -LDFLAGS += -L$(MSDIR) $(LMSLIB) $(LMDLIB) $(LHIAILIB) -pthread -llog -latomic -lm -Wl,-rpath,$(MSDIR) -else -LDFLAGS := -L$(MSDIR) $(LMSLIB) $(LMDLIB) $(LHIAILIB) -lpthread -Wl,-rpath,$(MSDIR) -endif -LD := ${CXX} - - -all:$(APP) - -$(APP): $(OBJ) - @mkdir -p bin - $(LD) $(OBJ) $(LDFLAGS) -o $@ - -clean: - rm -rf src/*.o bin/ - - -mrproper: - rm -rf dataset package* msl src/*.o bin/ model/*.mindir model/*.ms model/*.so model/converter_lite - -%.o:%.cc - $(CXX) $(CFLAGS) -c $< -o $@ diff --git a/mindspore/lite/examples/transfer_learning/model/effnet.py b/mindspore/lite/examples/transfer_learning/model/effnet.py old mode 100755 new mode 100644 index dcbef08eb83..ff15bb40af5 --- a/mindspore/lite/examples/transfer_learning/model/effnet.py +++ b/mindspore/lite/examples/transfer_learning/model/effnet.py @@ -1,344 +1,344 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""EffNet model define""" -import numpy as np -import mindspore.nn as nn -from mindspore.ops import operations as P -from mindspore.common.initializer import TruncatedNormal -from mindspore import Tensor - - -__all__ = ['effnet'] - - -def weight_variable(): - """weight initial""" - return TruncatedNormal(0.02) - - -def _make_divisible(v, divisor, min_value=None): - """ - This function is taken from the original tf repo. - It ensures that all layers have a channel number that is divisible by 8 - It can be seen here: - https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py - :param v: - :param divisor: - iparam min_value: - :return: - """ - if min_value is None: - min_value = divisor - new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) - # Make sure that round down does not go down by more than 10%. - if new_v < 0.9 * v: - new_v += divisor - return new_v - - -class Swish(nn.Cell): - def __init__(self): - super().__init__(Swish) - self.sigmoid = nn.Sigmoid() - - def construct(self, x): - s = self.sigmoid(x) - m = x*s - return m - - -class AdaptiveAvgPool(nn.Cell): - def __init__(self, output_size=None): - super().__init__(AdaptiveAvgPool) - self.mean = P.ReduceMean(keep_dims=True) - self.output_size = output_size - - def construct(self, x): - return self.mean(x, (2, 3)) - - -class SELayer(nn.Cell): - """ - SELayer - """ - def __init__(self, channel, reduction=4): - super().__init__(SELayer) - reduced_chs = _make_divisible(channel/reduction, 1) - self.avg_pool = AdaptiveAvgPool(output_size=(1, 1)) - weight = weight_variable() - self.conv_reduce = nn.Conv2d( - in_channels=channel, out_channels=reduced_chs, kernel_size=1, has_bias=True, weight_init=weight) - self.act1 = Swish() - self.conv_expand = nn.Conv2d( - in_channels=reduced_chs, out_channels=channel, kernel_size=1, has_bias=True) - self.act2 = nn.Sigmoid() - - def construct(self, x): - o = self.avg_pool(x) - o = self.conv_reduce(o) - o = self.act1(o) - o = self.conv_expand(o) - o = self.act2(o) - return x * o - - -class DepthwiseSeparableConv(nn.Cell): - """ - DepthwiseSeparableConv - """ - def __init__(self, in_chs, out_chs, dw_kernel_size=3, - stride=1, noskip=False, se_ratio=0.0, drop_connect_rate=0.0): - super().__init__(DepthwiseSeparableConv) - if stride not in [1, 2]: - print("ERROR") - return - self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip - self.drop_connect_rate = drop_connect_rate - - self.conv_dw = nn.Conv2d(in_channels=in_chs, out_channels=in_chs, kernel_size=dw_kernel_size, - stride=stride, pad_mode="pad", padding=1, has_bias=False, group=in_chs) - self.bn1 = nn.BatchNorm2d(in_chs, eps=0.001) - self.act1 = Swish() - - if se_ratio is not None and se_ratio > 0.: - self.se = SELayer(in_chs, reduction=se_ratio) - else: - print("ERRRRRORRRR -- not prepared for this one\n") - - self.conv_pw = nn.Conv2d( - in_channels=in_chs, out_channels=out_chs, kernel_size=1, stride=stride, has_bias=False) - self.bn2 = nn.BatchNorm2d(out_chs, eps=0.001) - - def construct(self, x): - """ - construct - """ - residual = x - x = self.conv_dw(x) - x = self.bn1(x) - x = self.act1(x) - x = self.se(x) - x = self.conv_pw(x) - x = self.bn2(x) - if self.has_residual: - x += residual - return x - - -def conv_3x3_bn(inp, oup, stride): - weight = weight_variable() - return nn.SequentialCell([ - nn.Conv2d(in_channels=inp, out_channels=oup, kernel_size=3, stride=stride, - padding=1, weight_init=weight, has_bias=False, pad_mode='pad'), - nn.BatchNorm2d(oup, eps=0.001), # , momentum=0.1), - nn.HSwish()]) - - -def conv_1x1_bn(inp, oup): - weight = weight_variable() - return nn.SequentialCell([ - nn.Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, - stride=1, padding=0, weight_init=weight, has_bias=False), - nn.BatchNorm2d(oup, eps=0.001), - nn.HSwish()]) - - -class InvertedResidual(nn.Cell): - """ - InvertedResidual - """ - def __init__(self, in_chs, out_chs, kernel_size, stride, padding, expansion, se_ratio): - super().__init__(InvertedResidual) - if stride not in [1, 2]: - print("ERROR") - return - mid_chs: int = _make_divisible(in_chs * expansion, 1) - self.has_residual = (in_chs == out_chs and stride == 1) - self.drop_connect_rate = 0 - self.conv_pw = nn.Conv2d( - in_channels=in_chs, out_channels=mid_chs, kernel_size=1, stride=1, has_bias=False) - self.bn1 = nn.BatchNorm2d(mid_chs, eps=0.001) # ,momentum=0.1) - self.act1 = Swish() - if stride > 1: - self.conv_dw = nn.Conv2d(in_channels=mid_chs, out_channels=mid_chs, kernel_size=kernel_size, - stride=stride, padding=padding, has_bias=False, group=mid_chs, pad_mode='same') - else: - self.conv_dw = nn.Conv2d(in_channels=mid_chs, out_channels=mid_chs, kernel_size=kernel_size, - stride=stride, padding=padding, has_bias=False, group=mid_chs, pad_mode='pad') - self.bn2 = nn.BatchNorm2d(mid_chs, eps=0.001) # ,momentum=0.1) - self.act2 = Swish() - - # Squeeze-and-excitation - if se_ratio is not None and se_ratio > 0.: - self.se = SELayer(mid_chs, reduction=se_ratio) - else: - print("ERRRRRORRRR -- not prepared for this one\n") - - # Point-wise linear projection - self.conv_pwl = nn.Conv2d( - in_channels=mid_chs, out_channels=out_chs, kernel_size=1, stride=1, has_bias=False) - self.bn3 = nn.BatchNorm2d(out_chs, eps=0.001) # ,momentum=0.1) - - def construct(self, x): - """ - construct - """ - residual = x - - # Point-wise expansion - x = self.conv_pw(x) - x = self.bn1(x) - x = self.act1(x) - - # Depth-wise convolution - x = self.conv_dw(x) - x = self.bn2(x) - x = self.act2(x) - - # Squeeze-and-excitation - x = self.se(x) - - # Point-wise linear projection - x = self.conv_pwl(x) - x = self.bn3(x) - - if self.has_residual: - x += residual - return x - - -class EfficientNet(nn.Cell): - """ - EfficientNet - """ - def __init__(self, cfgs, num_classes=1000): - super().__init__(EfficientNet) - # setting of inverted residual blocks - self.cfgs = cfgs - stem_size = 32 - self.num_classes_ = num_classes - self.num_features_ = 1280 - - self.conv_stem = nn.Conv2d( - in_channels=3, out_channels=stem_size, kernel_size=3, stride=2, has_bias=False) - - self.bn1 = nn.BatchNorm2d(stem_size, eps=0.001) # momentum=0.1) - self.act1 = Swish() - in_chs = stem_size - - layers = [nn.SequentialCell([DepthwiseSeparableConv(in_chs, 16, 3, 1, se_ratio=4)]), - - nn.SequentialCell([InvertedResidual(16, 24, 3, 2, 0, 6, se_ratio=24), - InvertedResidual(24, 24, 3, 1, 1, 6, se_ratio=24)]), - - nn.SequentialCell([InvertedResidual(24, 40, 5, 2, 0, 6, se_ratio=24), - InvertedResidual(40, 40, 5, 1, 2, 6, se_ratio=24)]), - - nn.SequentialCell([InvertedResidual(40, 80, 3, 2, 0, 6, se_ratio=24), - InvertedResidual( - 80, 80, 3, 1, 1, 6, se_ratio=24), - InvertedResidual(80, 80, 3, 1, 1, 6, se_ratio=24)]), - - nn.SequentialCell([InvertedResidual(80, 112, 5, 1, 2, 6, se_ratio=24), - InvertedResidual( - 112, 112, 5, 1, 2, 6, se_ratio=24), - InvertedResidual(112, 112, 5, 1, 2, 6, se_ratio=24)]), - - nn.SequentialCell([InvertedResidual(112, 192, 5, 2, 0, 6, se_ratio=24), - InvertedResidual( - 192, 192, 5, 1, 2, 6, se_ratio=24), - InvertedResidual( - 192, 192, 5, 1, 2, 6, se_ratio=24), - InvertedResidual(192, 192, 5, 1, 2, 6, se_ratio=24)]), - - nn.SequentialCell( - [InvertedResidual(192, 320, 3, 1, 1, 6, se_ratio=24)]) - ] - self.blocks = nn.SequentialCell(layers) - - self.conv_head = nn.Conv2d( - in_channels=320, out_channels=self.num_features_, kernel_size=1) - self.bn2 = nn.BatchNorm2d(self.num_features_, eps=0.001) - self.act2 = Swish() - self.global_pool = AdaptiveAvgPool(output_size=(1, 1)) - self.classifier = nn.Dense(self.num_features_, num_classes) - - self._initialize_weights() - - def construct(self, x): - """ - construct - """ - x = self.conv_stem(x) - x = self.bn1(x) - x = self.act1(x) - x = self.blocks(x) - x = self.conv_head(x) - x = self.bn2(x) - x = self.act2(x) - x = self.global_pool(x) - x = P.Reshape()(x, (-1, self.num_features_)) - x = self.classifier(x) - return x - - def _initialize_weights(self): - """ - _initialize_weights - """ - def init_linear_weight(m): - m.weight.set_data(Tensor(np.random.normal( - 0, 0.01, m.weight.data.shape).astype("float32"))) - if m.bias is not None: - m.bias.set_data( - Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) - - for m in self.cells(): - if isinstance(m, nn.Conv2d): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.set_data(Tensor(np.random.normal(0, np.sqrt(2. / n), - m.weight.data.shape).astype("float32"))) - if m.bias is not None: - m.bias.data.zero_() - m.weight.requires_grad = True - elif isinstance(m, nn.BatchNorm2d): - m.gamma.set_data( - Tensor(np.ones(m.gamma.data.shape, dtype="float32"))) - m.beta.set_data( - Tensor(np.zeros(m.beta.data.shape, dtype="float32"))) - elif isinstance(m, nn.Dense): - init_linear_weight(m) - - -def effnet(**kwargs): - """ - Constructs a EfficientNet model - """ - cfgs = [ - # k, t, c, SE, HS, s - [3, 1, 16, 1, 0, 2], - [3, 4.5, 24, 0, 0, 2], - [3, 3.67, 24, 0, 0, 1], - [5, 4, 40, 1, 1, 2], - [5, 6, 40, 1, 1, 1], - [5, 6, 40, 1, 1, 1], - [5, 3, 48, 1, 1, 1], - [5, 3, 48, 1, 1, 1], - [5, 6, 96, 1, 1, 2], - [5, 6, 96, 1, 1, 1], - [5, 6, 96, 1, 1, 1], - ] - - return EfficientNet(cfgs, **kwargs) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""EffNet model define""" +import numpy as np +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore.common.initializer import TruncatedNormal +from mindspore import Tensor + + +__all__ = ['effnet'] + + +def weight_variable(): + """weight initial""" + return TruncatedNormal(0.02) + + +def _make_divisible(v, divisor, min_value=None): + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + :param v: + :param divisor: + iparam min_value: + :return: + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class Swish(nn.Cell): + def __init__(self): + super().__init__(Swish) + self.sigmoid = nn.Sigmoid() + + def construct(self, x): + s = self.sigmoid(x) + m = x*s + return m + + +class AdaptiveAvgPool(nn.Cell): + def __init__(self, output_size=None): + super().__init__(AdaptiveAvgPool) + self.mean = P.ReduceMean(keep_dims=True) + self.output_size = output_size + + def construct(self, x): + return self.mean(x, (2, 3)) + + +class SELayer(nn.Cell): + """ + SELayer + """ + def __init__(self, channel, reduction=4): + super().__init__(SELayer) + reduced_chs = _make_divisible(channel/reduction, 1) + self.avg_pool = AdaptiveAvgPool(output_size=(1, 1)) + weight = weight_variable() + self.conv_reduce = nn.Conv2d( + in_channels=channel, out_channels=reduced_chs, kernel_size=1, has_bias=True, weight_init=weight) + self.act1 = Swish() + self.conv_expand = nn.Conv2d( + in_channels=reduced_chs, out_channels=channel, kernel_size=1, has_bias=True) + self.act2 = nn.Sigmoid() + + def construct(self, x): + o = self.avg_pool(x) + o = self.conv_reduce(o) + o = self.act1(o) + o = self.conv_expand(o) + o = self.act2(o) + return x * o + + +class DepthwiseSeparableConv(nn.Cell): + """ + DepthwiseSeparableConv + """ + def __init__(self, in_chs, out_chs, dw_kernel_size=3, + stride=1, noskip=False, se_ratio=0.0, drop_connect_rate=0.0): + super().__init__(DepthwiseSeparableConv) + if stride not in [1, 2]: + print("ERROR") + return + self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip + self.drop_connect_rate = drop_connect_rate + + self.conv_dw = nn.Conv2d(in_channels=in_chs, out_channels=in_chs, kernel_size=dw_kernel_size, + stride=stride, pad_mode="pad", padding=1, has_bias=False, group=in_chs) + self.bn1 = nn.BatchNorm2d(in_chs, eps=0.001) + self.act1 = Swish() + + if se_ratio is not None and se_ratio > 0.: + self.se = SELayer(in_chs, reduction=se_ratio) + else: + print("ERRRRRORRRR -- not prepared for this one\n") + + self.conv_pw = nn.Conv2d( + in_channels=in_chs, out_channels=out_chs, kernel_size=1, stride=stride, has_bias=False) + self.bn2 = nn.BatchNorm2d(out_chs, eps=0.001) + + def construct(self, x): + """ + construct + """ + residual = x + x = self.conv_dw(x) + x = self.bn1(x) + x = self.act1(x) + x = self.se(x) + x = self.conv_pw(x) + x = self.bn2(x) + if self.has_residual: + x += residual + return x + + +def conv_3x3_bn(inp, oup, stride): + weight = weight_variable() + return nn.SequentialCell([ + nn.Conv2d(in_channels=inp, out_channels=oup, kernel_size=3, stride=stride, + padding=1, weight_init=weight, has_bias=False, pad_mode='pad'), + nn.BatchNorm2d(oup, eps=0.001), # , momentum=0.1), + nn.HSwish()]) + + +def conv_1x1_bn(inp, oup): + weight = weight_variable() + return nn.SequentialCell([ + nn.Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, + stride=1, padding=0, weight_init=weight, has_bias=False), + nn.BatchNorm2d(oup, eps=0.001), + nn.HSwish()]) + + +class InvertedResidual(nn.Cell): + """ + InvertedResidual + """ + def __init__(self, in_chs, out_chs, kernel_size, stride, padding, expansion, se_ratio): + super().__init__(InvertedResidual) + if stride not in [1, 2]: + print("ERROR") + return + mid_chs: int = _make_divisible(in_chs * expansion, 1) + self.has_residual = (in_chs == out_chs and stride == 1) + self.drop_connect_rate = 0 + self.conv_pw = nn.Conv2d( + in_channels=in_chs, out_channels=mid_chs, kernel_size=1, stride=1, has_bias=False) + self.bn1 = nn.BatchNorm2d(mid_chs, eps=0.001) # ,momentum=0.1) + self.act1 = Swish() + if stride > 1: + self.conv_dw = nn.Conv2d(in_channels=mid_chs, out_channels=mid_chs, kernel_size=kernel_size, + stride=stride, padding=padding, has_bias=False, group=mid_chs, pad_mode='same') + else: + self.conv_dw = nn.Conv2d(in_channels=mid_chs, out_channels=mid_chs, kernel_size=kernel_size, + stride=stride, padding=padding, has_bias=False, group=mid_chs, pad_mode='pad') + self.bn2 = nn.BatchNorm2d(mid_chs, eps=0.001) # ,momentum=0.1) + self.act2 = Swish() + + # Squeeze-and-excitation + if se_ratio is not None and se_ratio > 0.: + self.se = SELayer(mid_chs, reduction=se_ratio) + else: + print("ERRRRRORRRR -- not prepared for this one\n") + + # Point-wise linear projection + self.conv_pwl = nn.Conv2d( + in_channels=mid_chs, out_channels=out_chs, kernel_size=1, stride=1, has_bias=False) + self.bn3 = nn.BatchNorm2d(out_chs, eps=0.001) # ,momentum=0.1) + + def construct(self, x): + """ + construct + """ + residual = x + + # Point-wise expansion + x = self.conv_pw(x) + x = self.bn1(x) + x = self.act1(x) + + # Depth-wise convolution + x = self.conv_dw(x) + x = self.bn2(x) + x = self.act2(x) + + # Squeeze-and-excitation + x = self.se(x) + + # Point-wise linear projection + x = self.conv_pwl(x) + x = self.bn3(x) + + if self.has_residual: + x += residual + return x + + +class EfficientNet(nn.Cell): + """ + EfficientNet + """ + def __init__(self, cfgs, num_classes=1000): + super().__init__(EfficientNet) + # setting of inverted residual blocks + self.cfgs = cfgs + stem_size = 32 + self.num_classes_ = num_classes + self.num_features_ = 1280 + + self.conv_stem = nn.Conv2d( + in_channels=3, out_channels=stem_size, kernel_size=3, stride=2, has_bias=False) + + self.bn1 = nn.BatchNorm2d(stem_size, eps=0.001) # momentum=0.1) + self.act1 = Swish() + in_chs = stem_size + + layers = [nn.SequentialCell([DepthwiseSeparableConv(in_chs, 16, 3, 1, se_ratio=4)]), + + nn.SequentialCell([InvertedResidual(16, 24, 3, 2, 0, 6, se_ratio=24), + InvertedResidual(24, 24, 3, 1, 1, 6, se_ratio=24)]), + + nn.SequentialCell([InvertedResidual(24, 40, 5, 2, 0, 6, se_ratio=24), + InvertedResidual(40, 40, 5, 1, 2, 6, se_ratio=24)]), + + nn.SequentialCell([InvertedResidual(40, 80, 3, 2, 0, 6, se_ratio=24), + InvertedResidual( + 80, 80, 3, 1, 1, 6, se_ratio=24), + InvertedResidual(80, 80, 3, 1, 1, 6, se_ratio=24)]), + + nn.SequentialCell([InvertedResidual(80, 112, 5, 1, 2, 6, se_ratio=24), + InvertedResidual( + 112, 112, 5, 1, 2, 6, se_ratio=24), + InvertedResidual(112, 112, 5, 1, 2, 6, se_ratio=24)]), + + nn.SequentialCell([InvertedResidual(112, 192, 5, 2, 0, 6, se_ratio=24), + InvertedResidual( + 192, 192, 5, 1, 2, 6, se_ratio=24), + InvertedResidual( + 192, 192, 5, 1, 2, 6, se_ratio=24), + InvertedResidual(192, 192, 5, 1, 2, 6, se_ratio=24)]), + + nn.SequentialCell( + [InvertedResidual(192, 320, 3, 1, 1, 6, se_ratio=24)]) + ] + self.blocks = nn.SequentialCell(layers) + + self.conv_head = nn.Conv2d( + in_channels=320, out_channels=self.num_features_, kernel_size=1) + self.bn2 = nn.BatchNorm2d(self.num_features_, eps=0.001) + self.act2 = Swish() + self.global_pool = AdaptiveAvgPool(output_size=(1, 1)) + self.classifier = nn.Dense(self.num_features_, num_classes) + + self._initialize_weights() + + def construct(self, x): + """ + construct + """ + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + x = self.blocks(x) + x = self.conv_head(x) + x = self.bn2(x) + x = self.act2(x) + x = self.global_pool(x) + x = P.Reshape()(x, (-1, self.num_features_)) + x = self.classifier(x) + return x + + def _initialize_weights(self): + """ + _initialize_weights + """ + def init_linear_weight(m): + m.weight.set_data(Tensor(np.random.normal( + 0, 0.01, m.weight.data.shape).astype("float32"))) + if m.bias is not None: + m.bias.set_data( + Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) + + for m in self.cells(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.set_data(Tensor(np.random.normal(0, np.sqrt(2. / n), + m.weight.data.shape).astype("float32"))) + if m.bias is not None: + m.bias.data.zero_() + m.weight.requires_grad = True + elif isinstance(m, nn.BatchNorm2d): + m.gamma.set_data( + Tensor(np.ones(m.gamma.data.shape, dtype="float32"))) + m.beta.set_data( + Tensor(np.zeros(m.beta.data.shape, dtype="float32"))) + elif isinstance(m, nn.Dense): + init_linear_weight(m) + + +def effnet(**kwargs): + """ + Constructs a EfficientNet model + """ + cfgs = [ + # k, t, c, SE, HS, s + [3, 1, 16, 1, 0, 2], + [3, 4.5, 24, 0, 0, 2], + [3, 3.67, 24, 0, 0, 1], + [5, 4, 40, 1, 1, 2], + [5, 6, 40, 1, 1, 1], + [5, 6, 40, 1, 1, 1], + [5, 3, 48, 1, 1, 1], + [5, 3, 48, 1, 1, 1], + [5, 6, 96, 1, 1, 2], + [5, 6, 96, 1, 1, 1], + [5, 6, 96, 1, 1, 1], + ] + + return EfficientNet(cfgs, **kwargs) diff --git a/mindspore/lite/examples/transfer_learning/model/prepare_model.sh b/mindspore/lite/examples/transfer_learning/model/prepare_model.sh old mode 100755 new mode 100644 diff --git a/mindspore/lite/examples/transfer_learning/model/transfer_learning_export.py b/mindspore/lite/examples/transfer_learning/model/transfer_learning_export.py old mode 100755 new mode 100644 diff --git a/mindspore/lite/examples/transfer_learning/prepare_and_run.sh b/mindspore/lite/examples/transfer_learning/prepare_and_run.sh old mode 100755 new mode 100644 diff --git a/mindspore/lite/examples/transfer_learning/prepare_dataset.sh b/mindspore/lite/examples/transfer_learning/prepare_dataset.sh old mode 100755 new mode 100644 diff --git a/mindspore/lite/examples/transfer_learning/scripts/eval.sh b/mindspore/lite/examples/transfer_learning/scripts/eval.sh old mode 100755 new mode 100644 diff --git a/mindspore/lite/examples/transfer_learning/scripts/eval_untrained.sh b/mindspore/lite/examples/transfer_learning/scripts/eval_untrained.sh old mode 100755 new mode 100644 diff --git a/mindspore/lite/examples/transfer_learning/scripts/train.sh b/mindspore/lite/examples/transfer_learning/scripts/train.sh old mode 100755 new mode 100644 diff --git a/mindspore/lite/src/extendrt/delegate/tensorrt/cuda_impl/hash.cu b/mindspore/lite/src/extendrt/delegate/tensorrt/cuda_impl/hash.cu old mode 100755 new mode 100644 diff --git a/mindspore/lite/src/extendrt/delegate/tensorrt/cuda_impl/hash.cuh b/mindspore/lite/src/extendrt/delegate/tensorrt/cuda_impl/hash.cuh old mode 100755 new mode 100644 diff --git a/mindspore/lite/src/extendrt/delegate/tensorrt/op/depthtospace_tensorrt.cc b/mindspore/lite/src/extendrt/delegate/tensorrt/op/depthtospace_tensorrt.cc index 7416f24e6bc..e317bd9a919 100644 --- a/mindspore/lite/src/extendrt/delegate/tensorrt/op/depthtospace_tensorrt.cc +++ b/mindspore/lite/src/extendrt/delegate/tensorrt/op/depthtospace_tensorrt.cc @@ -1,130 +1,130 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/depthtospace_tensorrt.h" -#include -#include -#include -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "NvInferRuntimeCommon.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/depthtospace_impl.cuh" -#include "ops/depth_to_space.h" - -namespace mindspore::lite { -int DepthToSpaceTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - - if (out_tensors.size() < 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - return RET_ERROR; - } - return RET_OK; -} - -int DepthToSpaceTensorRT::AddInnerOp(TensorRTContext *ctx) { - nvinfer1::ITensor *input_tensor = input(ctx, 0).trt_tensor_; - auto op = AsOps(); - int block_size = op->get_block_size(); - - auto plugin = std::make_shared(input_tensor->getName(), block_size, device_id_); - if (plugin == nullptr) { - MS_LOG(ERROR) << "add depthtospace plugin failed for" << op_name_; - return RET_ERROR; - } - nvinfer1::ITensor *inputTensors[] = {input_tensor}; - nvinfer1::IPluginV2Layer *layer = ctx->network()->addPluginV2(inputTensors, 1, *plugin); - if (layer == nullptr) { - MS_LOG(ERROR) << "add depthtospace op failed for TensorRT."; - return RET_ERROR; - } - layer->setName(op_name_.c_str()); - nvinfer1::ITensor *out_tensor = layer->getOutput(0); - ctx->RegisterTensor(ITensorHelper{out_tensor, Format::NCHW, true}, out_tensors_[0].Name()); - this->layer_ = layer; - return RET_OK; -} - -REGISTER_TENSORRT_PLUGIN(DepthToSpacePluginCreater); -template class TensorRTPluginCreater; -template -nvinfer1::PluginFieldCollection TensorRTPluginCreater::field_collection_{}; -template -std::vector TensorRTPluginCreater::fields_; - -int DepthToSpacePlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workspace, cudaStream_t stream) noexcept { - return RunCudaDepthToSpace(inputDesc, inputs, outputs, stream); -} - -int DepthToSpacePlugin::RunCudaDepthToSpace(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, - void *const *outputs, cudaStream_t stream) { - nvinfer1::Dims input_dims = inputDesc[0].dims; - int in = input_dims.d[0]; - int ic = input_dims.d[1]; - int ih = input_dims.d[2]; - int iw = input_dims.d[3]; - int on = in; - int oc = ic / block_size_ / block_size_; - int oh = ih * block_size_; - int ow = iw * block_size_; - - int size = on * oc * oh * ow; - - CalDepthToSpace(size, static_cast(inputs[0]), in, ic, ih, iw, on, oc, oh, ow, block_size_, - static_cast(outputs[0]), device_id_, stream); - return RET_OK; -} - -nvinfer1::IPluginV2DynamicExt *DepthToSpacePlugin::clone() const noexcept { - auto *plugin = new (std::nothrow) DepthToSpacePlugin(*this); - if (plugin == nullptr) { - MS_LOG(ERROR) << "new plugin failed!"; - return nullptr; - } - plugin->setPluginNamespace(name_space_.c_str()); - return plugin; -} - -size_t DepthToSpacePlugin::getSerializationSize() const noexcept { return sizeof(int); } - -nvinfer1::DimsExprs DepthToSpacePlugin::getOutputDimensions(int index, const nvinfer1::DimsExprs *inputs, - int nbInputDims, - nvinfer1::IExprBuilder &exprBuilder) noexcept { - nvinfer1::DimsExprs dims; - dims.nbDims = inputs[0].nbDims; - dims.d[0] = inputs[0].d[0]; - dims.d[1] = inputs[0].d[1]; - auto block_size_sqrt = exprBuilder.constant(block_size_ * block_size_); - dims.d[1] = exprBuilder.operation(nvinfer1::DimensionOperation::kFLOOR_DIV, *inputs[0].d[1], *block_size_sqrt); - auto block_size = exprBuilder.constant(block_size_); - dims.d[INPUT_SIZE2] = - exprBuilder.operation(nvinfer1::DimensionOperation::kPROD, *inputs[0].d[INPUT_SIZE2], *block_size); - dims.d[INPUT_SIZE3] = - exprBuilder.operation(nvinfer1::DimensionOperation::kPROD, *inputs[0].d[INPUT_SIZE3], *block_size); - return dims; -} - -void DepthToSpacePlugin::serialize(void *buffer) const noexcept { SerializeValue(&buffer, &block_size_, sizeof(int)); } -REGISTER_TENSORRT_CREATOR(ops::kNameDepthToSpace, DepthToSpaceTensorRT) -} // namespace mindspore::lite +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/extendrt/delegate/tensorrt/op/depthtospace_tensorrt.h" +#include +#include +#include +#include +#include +#include +#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" +#include "NvInferRuntimeCommon.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/depthtospace_impl.cuh" +#include "ops/depth_to_space.h" + +namespace mindspore::lite { +int DepthToSpaceTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, + const std::vector &out_tensors) { + if (in_tensors.size() != 1) { + MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); + return RET_ERROR; + } + + if (out_tensors.size() < 1) { + MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); + return RET_ERROR; + } + return RET_OK; +} + +int DepthToSpaceTensorRT::AddInnerOp(TensorRTContext *ctx) { + nvinfer1::ITensor *input_tensor = input(ctx, 0).trt_tensor_; + auto op = AsOps(); + int block_size = op->get_block_size(); + + auto plugin = std::make_shared(input_tensor->getName(), block_size, device_id_); + if (plugin == nullptr) { + MS_LOG(ERROR) << "add depthtospace plugin failed for" << op_name_; + return RET_ERROR; + } + nvinfer1::ITensor *inputTensors[] = {input_tensor}; + nvinfer1::IPluginV2Layer *layer = ctx->network()->addPluginV2(inputTensors, 1, *plugin); + if (layer == nullptr) { + MS_LOG(ERROR) << "add depthtospace op failed for TensorRT."; + return RET_ERROR; + } + layer->setName(op_name_.c_str()); + nvinfer1::ITensor *out_tensor = layer->getOutput(0); + ctx->RegisterTensor(ITensorHelper{out_tensor, Format::NCHW, true}, out_tensors_[0].Name()); + this->layer_ = layer; + return RET_OK; +} + +REGISTER_TENSORRT_PLUGIN(DepthToSpacePluginCreater); +template class TensorRTPluginCreater; +template +nvinfer1::PluginFieldCollection TensorRTPluginCreater::field_collection_{}; +template +std::vector TensorRTPluginCreater::fields_; + +int DepthToSpacePlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, + const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, + void *const *outputs, void *workspace, cudaStream_t stream) noexcept { + return RunCudaDepthToSpace(inputDesc, inputs, outputs, stream); +} + +int DepthToSpacePlugin::RunCudaDepthToSpace(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, + void *const *outputs, cudaStream_t stream) { + nvinfer1::Dims input_dims = inputDesc[0].dims; + int in = input_dims.d[0]; + int ic = input_dims.d[1]; + int ih = input_dims.d[2]; + int iw = input_dims.d[3]; + int on = in; + int oc = ic / block_size_ / block_size_; + int oh = ih * block_size_; + int ow = iw * block_size_; + + int size = on * oc * oh * ow; + + CalDepthToSpace(size, static_cast(inputs[0]), in, ic, ih, iw, on, oc, oh, ow, block_size_, + static_cast(outputs[0]), device_id_, stream); + return RET_OK; +} + +nvinfer1::IPluginV2DynamicExt *DepthToSpacePlugin::clone() const noexcept { + auto *plugin = new (std::nothrow) DepthToSpacePlugin(*this); + if (plugin == nullptr) { + MS_LOG(ERROR) << "new plugin failed!"; + return nullptr; + } + plugin->setPluginNamespace(name_space_.c_str()); + return plugin; +} + +size_t DepthToSpacePlugin::getSerializationSize() const noexcept { return sizeof(int); } + +nvinfer1::DimsExprs DepthToSpacePlugin::getOutputDimensions(int index, const nvinfer1::DimsExprs *inputs, + int nbInputDims, + nvinfer1::IExprBuilder &exprBuilder) noexcept { + nvinfer1::DimsExprs dims; + dims.nbDims = inputs[0].nbDims; + dims.d[0] = inputs[0].d[0]; + dims.d[1] = inputs[0].d[1]; + auto block_size_sqrt = exprBuilder.constant(block_size_ * block_size_); + dims.d[1] = exprBuilder.operation(nvinfer1::DimensionOperation::kFLOOR_DIV, *inputs[0].d[1], *block_size_sqrt); + auto block_size = exprBuilder.constant(block_size_); + dims.d[INPUT_SIZE2] = + exprBuilder.operation(nvinfer1::DimensionOperation::kPROD, *inputs[0].d[INPUT_SIZE2], *block_size); + dims.d[INPUT_SIZE3] = + exprBuilder.operation(nvinfer1::DimensionOperation::kPROD, *inputs[0].d[INPUT_SIZE3], *block_size); + return dims; +} + +void DepthToSpacePlugin::serialize(void *buffer) const noexcept { SerializeValue(&buffer, &block_size_, sizeof(int)); } +REGISTER_TENSORRT_CREATOR(ops::kNameDepthToSpace, DepthToSpaceTensorRT) +} // namespace mindspore::lite diff --git a/mindspore/lite/src/extendrt/delegate/tensorrt/op/depthtospace_tensorrt.h b/mindspore/lite/src/extendrt/delegate/tensorrt/op/depthtospace_tensorrt.h index e483a9ca477..6f48cf8554c 100644 --- a/mindspore/lite/src/extendrt/delegate/tensorrt/op/depthtospace_tensorrt.h +++ b/mindspore/lite/src/extendrt/delegate/tensorrt/op/depthtospace_tensorrt.h @@ -1,79 +1,79 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_DEPTHTOSPACETENSORRT_PLUGIN_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_DEPTHTOSPACETENSORRT_PLUGIN_H_ - -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" -#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h" - -namespace mindspore::lite { -class DepthToSpaceTensorRT : public TensorRTOp { - public: - DepthToSpaceTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~DepthToSpaceTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; -}; - -constexpr auto DEPTHTOSPACETENSORRT_PLUGIN_NAME{"DepthToSpacePlugin"}; -class DepthToSpacePlugin : public TensorRTPlugin { - public: - DepthToSpacePlugin(const std::string name, int block_size, uint32_t device_id) - : TensorRTPlugin(name, std::string(DEPTHTOSPACETENSORRT_PLUGIN_NAME), device_id), block_size_(block_size) {} - - DepthToSpacePlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - : TensorRTPlugin(std::string(name), std::string(DEPTHTOSPACETENSORRT_PLUGIN_NAME)) { - const nvinfer1::PluginField *fields = fc->fields; - block_size_ = static_cast(fields[0].data)[0]; - } - - DepthToSpacePlugin(const char *name, const void *serialData, size_t serialLength) - : TensorRTPlugin(std::string(name), std::string(DEPTHTOSPACETENSORRT_PLUGIN_NAME)) { - DeserializeValue(&serialData, &serialLength, &block_size_, sizeof(int)); - } - - DepthToSpacePlugin() = delete; - - nvinfer1::IPluginV2DynamicExt *clone() const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override; - size_t getSerializationSize() const noexcept override; - void serialize(void *buffer) const noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int index, const nvinfer1::DimsExprs *inputs, int nbInputDims, - nvinfer1::IExprBuilder &exprBuilder) noexcept override; - - private: - int RunCudaDepthToSpace(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, void *const *outputs, - cudaStream_t stream); - int block_size_; - const std::string layer_name_; - std::string name_space_; -}; -class DepthToSpacePluginCreater : public TensorRTPluginCreater { - public: - DepthToSpacePluginCreater() : TensorRTPluginCreater(std::string(DEPTHTOSPACETENSORRT_PLUGIN_NAME)) {} -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_DEPTHTOSPACETENSORRT_PLUGIN_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_DEPTHTOSPACETENSORRT_PLUGIN_H_ +#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_DEPTHTOSPACETENSORRT_PLUGIN_H_ + +#include +#include +#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" +#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h" + +namespace mindspore::lite { +class DepthToSpaceTensorRT : public TensorRTOp { + public: + DepthToSpaceTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, + const std::vector &out_tensors, std::string name) + : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} + + ~DepthToSpaceTensorRT() override = default; + + int AddInnerOp(TensorRTContext *ctx) override; + + int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, + const std::vector &out_tensors) override; +}; + +constexpr auto DEPTHTOSPACETENSORRT_PLUGIN_NAME{"DepthToSpacePlugin"}; +class DepthToSpacePlugin : public TensorRTPlugin { + public: + DepthToSpacePlugin(const std::string name, int block_size, uint32_t device_id) + : TensorRTPlugin(name, std::string(DEPTHTOSPACETENSORRT_PLUGIN_NAME), device_id), block_size_(block_size) {} + + DepthToSpacePlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) + : TensorRTPlugin(std::string(name), std::string(DEPTHTOSPACETENSORRT_PLUGIN_NAME)) { + const nvinfer1::PluginField *fields = fc->fields; + block_size_ = static_cast(fields[0].data)[0]; + } + + DepthToSpacePlugin(const char *name, const void *serialData, size_t serialLength) + : TensorRTPlugin(std::string(name), std::string(DEPTHTOSPACETENSORRT_PLUGIN_NAME)) { + DeserializeValue(&serialData, &serialLength, &block_size_, sizeof(int)); + } + + DepthToSpacePlugin() = delete; + + nvinfer1::IPluginV2DynamicExt *clone() const noexcept override; + int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, + const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override; + size_t getSerializationSize() const noexcept override; + void serialize(void *buffer) const noexcept override; + nvinfer1::DimsExprs getOutputDimensions(int index, const nvinfer1::DimsExprs *inputs, int nbInputDims, + nvinfer1::IExprBuilder &exprBuilder) noexcept override; + + private: + int RunCudaDepthToSpace(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, void *const *outputs, + cudaStream_t stream); + int block_size_; + const std::string layer_name_; + std::string name_space_; +}; +class DepthToSpacePluginCreater : public TensorRTPluginCreater { + public: + DepthToSpacePluginCreater() : TensorRTPluginCreater(std::string(DEPTHTOSPACETENSORRT_PLUGIN_NAME)) {} +}; +} // namespace mindspore::lite +#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_DEPTHTOSPACETENSORRT_PLUGIN_H_ diff --git a/mindspore/lite/src/extendrt/delegate/tensorrt/tensorrt_allocator.cc b/mindspore/lite/src/extendrt/delegate/tensorrt/tensorrt_allocator.cc old mode 100755 new mode 100644 diff --git a/mindspore/lite/src/litert/cxx_api/kernel.cc b/mindspore/lite/src/litert/cxx_api/kernel.cc index 9d321b24b2f..a0ec26eb4e1 100644 --- a/mindspore/lite/src/litert/cxx_api/kernel.cc +++ b/mindspore/lite/src/litert/cxx_api/kernel.cc @@ -1,71 +1,71 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "include/api/kernel.h" -#include "include/errorcode.h" -#include "src/registry/kernel_interface_registry.h" -#include "src/common/log_adapter.h" - -namespace mindspore::kernel { -void Kernel::Initialize() { - if (primitive_ == nullptr) { - return; - } - type_ = primitive_->value_type(); - if (type_ == schema::PrimitiveType_Custom) { - auto param = primitive_->value_as_Custom(); - if (param != nullptr && param->type() != nullptr) { - SetAttr("type", param->type()->str()); - } - } -} - -int Kernel::InferShape() { -#ifndef CUSTOM_KERNEL_REGISTRY_CLIP - std::shared_ptr kernel_interface = nullptr; - if (type() == schema::PrimitiveType_Custom) { - kernel_interface = registry::KernelInterfaceRegistry::Instance()->GetKernelInterface("", nullptr, this); - } else { - auto device_list = const_cast(context_)->MutableDeviceInfo(); - for (auto &device : device_list) { - MS_CHECK_TRUE_RET(device != nullptr, lite::RET_NULL_PTR); - kernel_interface = - registry::KernelInterfaceRegistry::Instance()->GetKernelInterface(device->GetProvider(), nullptr, this); - if (kernel_interface != nullptr) { - break; - } - } - } - - if (kernel_interface == nullptr) { - MS_LOG(ERROR) << "op_type: " << schema::EnumNamePrimitiveType(type_) << " can not find infer interface."; - return lite::RET_NOT_SUPPORT; - } - auto ret = kernel_interface->Infer(&inputs_, &outputs_, static_cast(primitive_), this); - if (ret == kLiteInferInvalid) { - for (auto output : outputs_) { - output.SetShape({-1}); - } - return lite::RET_INFER_INVALID; - } - if (ret != kSuccess) { - MS_LOG(ERROR) << "op_type: " << schema::EnumNamePrimitiveType(type_) << " infer fail!ret: " << ret; - return lite::RET_ERROR; - } - return lite::RET_OK; -#endif - return lite::RET_NOT_SUPPORT; -} -} // namespace mindspore::kernel +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "include/api/kernel.h" +#include "include/errorcode.h" +#include "src/registry/kernel_interface_registry.h" +#include "src/common/log_adapter.h" + +namespace mindspore::kernel { +void Kernel::Initialize() { + if (primitive_ == nullptr) { + return; + } + type_ = primitive_->value_type(); + if (type_ == schema::PrimitiveType_Custom) { + auto param = primitive_->value_as_Custom(); + if (param != nullptr && param->type() != nullptr) { + SetAttr("type", param->type()->str()); + } + } +} + +int Kernel::InferShape() { +#ifndef CUSTOM_KERNEL_REGISTRY_CLIP + std::shared_ptr kernel_interface = nullptr; + if (type() == schema::PrimitiveType_Custom) { + kernel_interface = registry::KernelInterfaceRegistry::Instance()->GetKernelInterface("", nullptr, this); + } else { + auto device_list = const_cast(context_)->MutableDeviceInfo(); + for (auto &device : device_list) { + MS_CHECK_TRUE_RET(device != nullptr, lite::RET_NULL_PTR); + kernel_interface = + registry::KernelInterfaceRegistry::Instance()->GetKernelInterface(device->GetProvider(), nullptr, this); + if (kernel_interface != nullptr) { + break; + } + } + } + + if (kernel_interface == nullptr) { + MS_LOG(ERROR) << "op_type: " << schema::EnumNamePrimitiveType(type_) << " can not find infer interface."; + return lite::RET_NOT_SUPPORT; + } + auto ret = kernel_interface->Infer(&inputs_, &outputs_, static_cast(primitive_), this); + if (ret == kLiteInferInvalid) { + for (auto output : outputs_) { + output.SetShape({-1}); + } + return lite::RET_INFER_INVALID; + } + if (ret != kSuccess) { + MS_LOG(ERROR) << "op_type: " << schema::EnumNamePrimitiveType(type_) << " infer fail!ret: " << ret; + return lite::RET_ERROR; + } + return lite::RET_OK; +#endif + return lite::RET_NOT_SUPPORT; +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/litert/cxx_api/model/model_group.cc b/mindspore/lite/src/litert/cxx_api/model/model_group.cc old mode 100755 new mode 100644 diff --git a/mindspore/lite/src/litert/cxx_api/model/model_group_impl.cc b/mindspore/lite/src/litert/cxx_api/model/model_group_impl.cc old mode 100755 new mode 100644 diff --git a/mindspore/lite/src/litert/delegate/coreml/op/avg_pooling_coreml.cc b/mindspore/lite/src/litert/delegate/coreml/op/avg_pooling_coreml.cc index 6d53d10cf37..5d4ba51d508 100644 --- a/mindspore/lite/src/litert/delegate/coreml/op/avg_pooling_coreml.cc +++ b/mindspore/lite/src/litert/delegate/coreml/op/avg_pooling_coreml.cc @@ -1,72 +1,72 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/litert/delegate/coreml/op/avg_pooling_coreml.h" -namespace mindspore::lite { -int AvgPoolingCoreMLOp::InitParams() { - pooling_prim_ = op_primitive_->value_as_AvgPoolFusion(); - if (pooling_prim_ == nullptr) { - MS_LOG(ERROR) << "Get null primitive value for op ." << name_; - return RET_ERROR; - } - return RET_OK; -} - -int AvgPoolingCoreMLOp::BuildLayer() { - MS_ASSERT(op_ != nullptr); - auto pooling_param = op_->mutable_pooling(); - pooling_param->set_type(CoreML::Specification::PoolingLayerParams::AVERAGE); - if (pooling_prim_->global()) { - pooling_param->set_globalpooling(true); - pooling_param->mutable_valid(); - return RET_OK; - } - pooling_param->set_avgpoolexcludepadding(true); - auto kernel_h = static_cast(*(pooling_prim_->kernel_size()->begin())); - auto kernel_w = static_cast(*(pooling_prim_->kernel_size()->begin() + 1)); - auto stride_h = static_cast(*(pooling_prim_->strides()->begin())); - auto stride_w = static_cast(*(pooling_prim_->strides()->begin() + 1)); - pooling_param->add_stride(stride_h); - pooling_param->add_stride(stride_w); - pooling_param->add_kernelsize(kernel_h); - pooling_param->add_kernelsize(kernel_w); - if (pooling_prim_->pad_mode() == schema::PadMode_SAME) { - pooling_param->mutable_same(); - } else { - pooling_param->mutable_valid(); - if (pooling_prim_->pad() != nullptr) { - auto pad_u = static_cast(*(pooling_prim_->pad()->begin() + PAD_UP)); - auto pad_d = static_cast(*(pooling_prim_->pad()->begin() + PAD_DOWN)); - auto pad_l = static_cast(*(pooling_prim_->pad()->begin() + PAD_LEFT)); - auto pad_r = static_cast(*(pooling_prim_->pad()->begin() + PAD_RIGHT)); - auto ret = SetPadding({pad_u, pad_d, pad_l, pad_r}); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Fail to set padding for op: " << name_; - return RET_ERROR; - } - } - } - auto act_type = pooling_prim_->activation_type(); - if (act_type != schema::ActivationType_NO_ACTIVATION) { - auto ret = SetActivation(act_type); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Set pooling activation failed for op: " << name_; - return RET_ERROR; - } - } - return RET_OK; -} -} // namespace mindspore::lite +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/litert/delegate/coreml/op/avg_pooling_coreml.h" +namespace mindspore::lite { +int AvgPoolingCoreMLOp::InitParams() { + pooling_prim_ = op_primitive_->value_as_AvgPoolFusion(); + if (pooling_prim_ == nullptr) { + MS_LOG(ERROR) << "Get null primitive value for op ." << name_; + return RET_ERROR; + } + return RET_OK; +} + +int AvgPoolingCoreMLOp::BuildLayer() { + MS_ASSERT(op_ != nullptr); + auto pooling_param = op_->mutable_pooling(); + pooling_param->set_type(CoreML::Specification::PoolingLayerParams::AVERAGE); + if (pooling_prim_->global()) { + pooling_param->set_globalpooling(true); + pooling_param->mutable_valid(); + return RET_OK; + } + pooling_param->set_avgpoolexcludepadding(true); + auto kernel_h = static_cast(*(pooling_prim_->kernel_size()->begin())); + auto kernel_w = static_cast(*(pooling_prim_->kernel_size()->begin() + 1)); + auto stride_h = static_cast(*(pooling_prim_->strides()->begin())); + auto stride_w = static_cast(*(pooling_prim_->strides()->begin() + 1)); + pooling_param->add_stride(stride_h); + pooling_param->add_stride(stride_w); + pooling_param->add_kernelsize(kernel_h); + pooling_param->add_kernelsize(kernel_w); + if (pooling_prim_->pad_mode() == schema::PadMode_SAME) { + pooling_param->mutable_same(); + } else { + pooling_param->mutable_valid(); + if (pooling_prim_->pad() != nullptr) { + auto pad_u = static_cast(*(pooling_prim_->pad()->begin() + PAD_UP)); + auto pad_d = static_cast(*(pooling_prim_->pad()->begin() + PAD_DOWN)); + auto pad_l = static_cast(*(pooling_prim_->pad()->begin() + PAD_LEFT)); + auto pad_r = static_cast(*(pooling_prim_->pad()->begin() + PAD_RIGHT)); + auto ret = SetPadding({pad_u, pad_d, pad_l, pad_r}); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Fail to set padding for op: " << name_; + return RET_ERROR; + } + } + } + auto act_type = pooling_prim_->activation_type(); + if (act_type != schema::ActivationType_NO_ACTIVATION) { + auto ret = SetActivation(act_type); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Set pooling activation failed for op: " << name_; + return RET_ERROR; + } + } + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/litert/delegate/coreml/op/avg_pooling_coreml.h b/mindspore/lite/src/litert/delegate/coreml/op/avg_pooling_coreml.h index 1c5c421521b..eea64643531 100644 --- a/mindspore/lite/src/litert/delegate/coreml/op/avg_pooling_coreml.h +++ b/mindspore/lite/src/litert/delegate/coreml/op/avg_pooling_coreml.h @@ -1,39 +1,39 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_COREML_OP_AVG_POOLING_COREML_H_ -#define MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_COREML_OP_AVG_POOLING_COREML_H_ - -#include -#include -#include -#include -#include "src/litert/delegate/coreml/op/coreml_op.h" -namespace mindspore::lite { -class AvgPoolingCoreMLOp : public CoreMLOp { - public: - AvgPoolingCoreMLOp(const schema::Primitive *primitive, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : CoreMLOp(primitive, in_tensors, out_tensors, name) {} - - int InitParams() override; - - int BuildLayer() override; - - private: - const schema::AvgPoolFusion *pooling_prim_ = nullptr; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_COREML_OP_AVG_POOLING_COREML_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_COREML_OP_AVG_POOLING_COREML_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_COREML_OP_AVG_POOLING_COREML_H_ + +#include +#include +#include +#include +#include "src/litert/delegate/coreml/op/coreml_op.h" +namespace mindspore::lite { +class AvgPoolingCoreMLOp : public CoreMLOp { + public: + AvgPoolingCoreMLOp(const schema::Primitive *primitive, const std::vector &in_tensors, + const std::vector &out_tensors, std::string name) + : CoreMLOp(primitive, in_tensors, out_tensors, name) {} + + int InitParams() override; + + int BuildLayer() override; + + private: + const schema::AvgPoolFusion *pooling_prim_ = nullptr; +}; +} // namespace mindspore::lite +#endif // MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_COREML_OP_AVG_POOLING_COREML_H_ diff --git a/mindspore/lite/src/litert/delegate/coreml/op/convolution_base_coreml.cc b/mindspore/lite/src/litert/delegate/coreml/op/convolution_base_coreml.cc index 79214118bc2..5b78a6d149d 100644 --- a/mindspore/lite/src/litert/delegate/coreml/op/convolution_base_coreml.cc +++ b/mindspore/lite/src/litert/delegate/coreml/op/convolution_base_coreml.cc @@ -1,89 +1,89 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/litert/delegate/coreml/op/convolution_base_coreml.h" -#include "src/litert/delegate/delegate_utils.h" -namespace mindspore::lite { -int ConvolutionBaseCoreMLOp::SetConvWeight() { - auto weight_tensor = in_tensors_.at(kWeightIndex); - auto weight_shape = weight_tensor.Shape(); - conv_param_->set_kernelchannels(weight_shape.at(MS_WT_CIN)); - conv_param_->set_outputchannels(weight_shape.at(MS_WT_COUT)); - conv_param_->add_kernelsize(weight_shape.at(MS_WT_H)); - conv_param_->add_kernelsize(weight_shape.at(MS_WT_W)); - - // transpose the weight, (c_out, h, w, c_in) -> (c_out, c_in, h, w) - auto org_weight = weight_tensor.Data().get(); - MS_ASSERT(org_weight != nullptr); - if (weight_tensor.DataType() == DataType::kNumberTypeFloat32) { - auto *ml_weight_container = conv_param_->mutable_weights()->mutable_floatvalue(); - ml_weight_container->Resize(weight_tensor.ElementNum(), 0); - auto *ml_weight = reinterpret_cast(ml_weight_container->mutable_data()); - PackNHWCToNCHWFp32(org_weight, ml_weight, weight_shape[MS_WT_COUT], weight_shape[MS_WT_H] * weight_shape[MS_WT_W], - weight_shape[MS_WT_CIN]); - } else { - MS_LOG(ERROR) << "Unsupported data type of weight tensor for CoreML convolution."; - return RET_ERROR; - } - return RET_OK; -} - -int ConvolutionBaseCoreMLOp::SetConvBias() { - if (in_tensors_.size() >= kInputSize2) { - auto bias_tensor = in_tensors_.at(kBiasIndex); - auto org_bias = bias_tensor.Data().get(); - conv_param_->set_hasbias(true); - if (bias_tensor.DataType() == DataType::kNumberTypeFloat32) { - auto *ml_bias_container = conv_param_->mutable_bias()->mutable_floatvalue(); - ml_bias_container->Resize(bias_tensor.ElementNum(), 0); - auto *ml_bias = reinterpret_cast(ml_bias_container->mutable_data()); - memcpy(ml_bias, org_bias, bias_tensor.DataSize()); - } else { - MS_LOG(ERROR) << "Unsupported data type of bias tensor for CoreML convolution."; - return RET_ERROR; - } - } - return RET_OK; -} - -int ConvolutionBaseCoreMLOp::BuildLayer() { - MS_ASSERT(op_ != nullptr); - conv_param_ = op_->mutable_convolution(); - auto ret = SetConvParam(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Set conv param failed for op: " << name_; - return RET_ERROR; - } - ret = SetConvWeight(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Set conv weight failed for op: " << name_; - return RET_ERROR; - } - ret = SetConvBias(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Set conv bias failed for op: " << name_; - return RET_ERROR; - } - if (act_type_ != schema::ActivationType_NO_ACTIVATION) { - ret = SetActivation(act_type_); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Set conv activation failed for op: " << name_; - return RET_ERROR; - } - } - return RET_OK; -} -} // namespace mindspore::lite +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/litert/delegate/coreml/op/convolution_base_coreml.h" +#include "src/litert/delegate/delegate_utils.h" +namespace mindspore::lite { +int ConvolutionBaseCoreMLOp::SetConvWeight() { + auto weight_tensor = in_tensors_.at(kWeightIndex); + auto weight_shape = weight_tensor.Shape(); + conv_param_->set_kernelchannels(weight_shape.at(MS_WT_CIN)); + conv_param_->set_outputchannels(weight_shape.at(MS_WT_COUT)); + conv_param_->add_kernelsize(weight_shape.at(MS_WT_H)); + conv_param_->add_kernelsize(weight_shape.at(MS_WT_W)); + + // transpose the weight, (c_out, h, w, c_in) -> (c_out, c_in, h, w) + auto org_weight = weight_tensor.Data().get(); + MS_ASSERT(org_weight != nullptr); + if (weight_tensor.DataType() == DataType::kNumberTypeFloat32) { + auto *ml_weight_container = conv_param_->mutable_weights()->mutable_floatvalue(); + ml_weight_container->Resize(weight_tensor.ElementNum(), 0); + auto *ml_weight = reinterpret_cast(ml_weight_container->mutable_data()); + PackNHWCToNCHWFp32(org_weight, ml_weight, weight_shape[MS_WT_COUT], weight_shape[MS_WT_H] * weight_shape[MS_WT_W], + weight_shape[MS_WT_CIN]); + } else { + MS_LOG(ERROR) << "Unsupported data type of weight tensor for CoreML convolution."; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionBaseCoreMLOp::SetConvBias() { + if (in_tensors_.size() >= kInputSize2) { + auto bias_tensor = in_tensors_.at(kBiasIndex); + auto org_bias = bias_tensor.Data().get(); + conv_param_->set_hasbias(true); + if (bias_tensor.DataType() == DataType::kNumberTypeFloat32) { + auto *ml_bias_container = conv_param_->mutable_bias()->mutable_floatvalue(); + ml_bias_container->Resize(bias_tensor.ElementNum(), 0); + auto *ml_bias = reinterpret_cast(ml_bias_container->mutable_data()); + memcpy(ml_bias, org_bias, bias_tensor.DataSize()); + } else { + MS_LOG(ERROR) << "Unsupported data type of bias tensor for CoreML convolution."; + return RET_ERROR; + } + } + return RET_OK; +} + +int ConvolutionBaseCoreMLOp::BuildLayer() { + MS_ASSERT(op_ != nullptr); + conv_param_ = op_->mutable_convolution(); + auto ret = SetConvParam(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Set conv param failed for op: " << name_; + return RET_ERROR; + } + ret = SetConvWeight(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Set conv weight failed for op: " << name_; + return RET_ERROR; + } + ret = SetConvBias(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Set conv bias failed for op: " << name_; + return RET_ERROR; + } + if (act_type_ != schema::ActivationType_NO_ACTIVATION) { + ret = SetActivation(act_type_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Set conv activation failed for op: " << name_; + return RET_ERROR; + } + } + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/litert/delegate/coreml/op/convolution_base_coreml.h b/mindspore/lite/src/litert/delegate/coreml/op/convolution_base_coreml.h index 441885c3228..0fe5a4abea7 100644 --- a/mindspore/lite/src/litert/delegate/coreml/op/convolution_base_coreml.h +++ b/mindspore/lite/src/litert/delegate/coreml/op/convolution_base_coreml.h @@ -1,61 +1,61 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_COREML_OP_CONVOLUTION_BASE_COREML_H_ -#define MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_COREML_OP_CONVOLUTION_BASE_COREML_H_ - -#include -#include -#include -#include -#include -#include "src/litert/delegate/coreml/op/coreml_op.h" -namespace mindspore::lite { -class ConvolutionBaseCoreMLOp : public CoreMLOp { - public: - ConvolutionBaseCoreMLOp(const schema::Primitive *primitive, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : CoreMLOp(primitive, in_tensors, out_tensors, name) { - input_h_ = static_cast(in_tensors.at(0).Shape().at(kNHWC_H)); - input_w_ = static_cast(in_tensors.at(0).Shape().at(kNHWC_W)); - kernel_h_ = static_cast(in_tensors.at(1).Shape().at(MS_WT_H)); - kernel_w_ = static_cast(in_tensors.at(1).Shape().at(MS_WT_W)); - output_h_ = static_cast(out_tensors.at(0).Shape().at(kNHWC_H)); - output_w_ = static_cast(out_tensors.at(0).Shape().at(kNHWC_W)); - } - - int BuildLayer() override; - - protected: - virtual int SetConvParam() { return RET_OK; } - - virtual int SetConvWeight(); - - virtual int SetConvBias(); - - protected: - int input_h_; - int input_w_; - int kernel_h_; - int kernel_w_; - int output_h_; - int output_w_; - CoreML::Specification::ConvolutionLayerParams *conv_param_ = nullptr; - schema::ActivationType act_type_ = schema::ActivationType_NO_ACTIVATION; - std::unique_ptr trans_in_op_ = nullptr; - std::unique_ptr trans_out_op_ = nullptr; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_COREML_OP_CONVOLUTION_BASE_COREML_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_COREML_OP_CONVOLUTION_BASE_COREML_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_COREML_OP_CONVOLUTION_BASE_COREML_H_ + +#include +#include +#include +#include +#include +#include "src/litert/delegate/coreml/op/coreml_op.h" +namespace mindspore::lite { +class ConvolutionBaseCoreMLOp : public CoreMLOp { + public: + ConvolutionBaseCoreMLOp(const schema::Primitive *primitive, const std::vector &in_tensors, + const std::vector &out_tensors, std::string name) + : CoreMLOp(primitive, in_tensors, out_tensors, name) { + input_h_ = static_cast(in_tensors.at(0).Shape().at(kNHWC_H)); + input_w_ = static_cast(in_tensors.at(0).Shape().at(kNHWC_W)); + kernel_h_ = static_cast(in_tensors.at(1).Shape().at(MS_WT_H)); + kernel_w_ = static_cast(in_tensors.at(1).Shape().at(MS_WT_W)); + output_h_ = static_cast(out_tensors.at(0).Shape().at(kNHWC_H)); + output_w_ = static_cast(out_tensors.at(0).Shape().at(kNHWC_W)); + } + + int BuildLayer() override; + + protected: + virtual int SetConvParam() { return RET_OK; } + + virtual int SetConvWeight(); + + virtual int SetConvBias(); + + protected: + int input_h_; + int input_w_; + int kernel_h_; + int kernel_w_; + int output_h_; + int output_w_; + CoreML::Specification::ConvolutionLayerParams *conv_param_ = nullptr; + schema::ActivationType act_type_ = schema::ActivationType_NO_ACTIVATION; + std::unique_ptr trans_in_op_ = nullptr; + std::unique_ptr trans_out_op_ = nullptr; +}; +} // namespace mindspore::lite +#endif // MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_COREML_OP_CONVOLUTION_BASE_COREML_H_ diff --git a/mindspore/lite/src/litert/delegate/coreml/op/convolution_coreml.cc b/mindspore/lite/src/litert/delegate/coreml/op/convolution_coreml.cc index 429ab00084a..1883ccf15b8 100644 --- a/mindspore/lite/src/litert/delegate/coreml/op/convolution_coreml.cc +++ b/mindspore/lite/src/litert/delegate/coreml/op/convolution_coreml.cc @@ -1,71 +1,71 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/litert/delegate/coreml/op/convolution_coreml.h" -#include -#include "src/litert/delegate/delegate_utils.h" -namespace mindspore::lite { -int ConvolutionCoreMLOp::IsSupport() { - if (!in_tensors_[kWeightIndex].IsConst()) { - MS_LOG(WARNING) << "CoreML convolution does not support dynamic weight."; - return RET_NOT_SUPPORT; - } - conv_prim_ = op_primitive_->value_as_Conv2DFusion(); - if (conv_prim_ == nullptr) { - MS_LOG(ERROR) << "Get null primitive value for op ." << name_; - return RET_ERROR; - } - CHECK_NULL_RETURN(conv_prim_->stride()); - stride_h_ = static_cast(*(conv_prim_->stride()->begin())); - stride_w_ = static_cast(*(conv_prim_->stride()->begin() + 1)); - CHECK_NULL_RETURN(conv_prim_->dilation()); - dilation_h_ = static_cast(*(conv_prim_->dilation()->begin())); - dilation_w_ = static_cast(*(conv_prim_->dilation()->begin() + 1)); - // org conv format: NHWC - if (stride_h_ > in_tensors_[0].Shape()[kNHWC_H] || stride_w_ > in_tensors_[0].Shape()[kNHWC_W]) { - MS_LOG(WARNING) << "CoreML convolution does not support stride greater than input size."; - return RET_NOT_SUPPORT; - } - return RET_OK; -} - -int ConvolutionCoreMLOp::SetConvParam() { - auto group = static_cast(conv_prim_->group()); - conv_param_->set_ngroups(group); - conv_param_->add_stride(stride_h_); - conv_param_->add_stride(stride_w_); - conv_param_->add_dilationfactor(dilation_h_); - conv_param_->add_dilationfactor(dilation_w_); - if (conv_prim_->pad_mode() == schema::PadMode_SAME) { - conv_param_->mutable_same(); - } else { - conv_param_->mutable_valid(); - if (conv_prim_->pad_list() != nullptr) { - auto pad_u = static_cast(*(conv_prim_->pad_list()->begin() + PAD_UP)); - auto pad_d = static_cast(*(conv_prim_->pad_list()->begin() + PAD_DOWN)); - auto pad_l = static_cast(*(conv_prim_->pad_list()->begin() + PAD_LEFT)); - auto pad_r = static_cast(*(conv_prim_->pad_list()->begin() + PAD_RIGHT)); - auto ret = SetPadding({pad_u, pad_d, pad_l, pad_r}); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Fail to set padding for op: " << name_; - return RET_ERROR; - } - } - } - act_type_ = conv_prim_->activation_type(); - return RET_OK; -} -} // namespace mindspore::lite +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/litert/delegate/coreml/op/convolution_coreml.h" +#include +#include "src/litert/delegate/delegate_utils.h" +namespace mindspore::lite { +int ConvolutionCoreMLOp::IsSupport() { + if (!in_tensors_[kWeightIndex].IsConst()) { + MS_LOG(WARNING) << "CoreML convolution does not support dynamic weight."; + return RET_NOT_SUPPORT; + } + conv_prim_ = op_primitive_->value_as_Conv2DFusion(); + if (conv_prim_ == nullptr) { + MS_LOG(ERROR) << "Get null primitive value for op ." << name_; + return RET_ERROR; + } + CHECK_NULL_RETURN(conv_prim_->stride()); + stride_h_ = static_cast(*(conv_prim_->stride()->begin())); + stride_w_ = static_cast(*(conv_prim_->stride()->begin() + 1)); + CHECK_NULL_RETURN(conv_prim_->dilation()); + dilation_h_ = static_cast(*(conv_prim_->dilation()->begin())); + dilation_w_ = static_cast(*(conv_prim_->dilation()->begin() + 1)); + // org conv format: NHWC + if (stride_h_ > in_tensors_[0].Shape()[kNHWC_H] || stride_w_ > in_tensors_[0].Shape()[kNHWC_W]) { + MS_LOG(WARNING) << "CoreML convolution does not support stride greater than input size."; + return RET_NOT_SUPPORT; + } + return RET_OK; +} + +int ConvolutionCoreMLOp::SetConvParam() { + auto group = static_cast(conv_prim_->group()); + conv_param_->set_ngroups(group); + conv_param_->add_stride(stride_h_); + conv_param_->add_stride(stride_w_); + conv_param_->add_dilationfactor(dilation_h_); + conv_param_->add_dilationfactor(dilation_w_); + if (conv_prim_->pad_mode() == schema::PadMode_SAME) { + conv_param_->mutable_same(); + } else { + conv_param_->mutable_valid(); + if (conv_prim_->pad_list() != nullptr) { + auto pad_u = static_cast(*(conv_prim_->pad_list()->begin() + PAD_UP)); + auto pad_d = static_cast(*(conv_prim_->pad_list()->begin() + PAD_DOWN)); + auto pad_l = static_cast(*(conv_prim_->pad_list()->begin() + PAD_LEFT)); + auto pad_r = static_cast(*(conv_prim_->pad_list()->begin() + PAD_RIGHT)); + auto ret = SetPadding({pad_u, pad_d, pad_l, pad_r}); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Fail to set padding for op: " << name_; + return RET_ERROR; + } + } + } + act_type_ = conv_prim_->activation_type(); + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/litert/delegate/coreml/op/convolution_coreml.h b/mindspore/lite/src/litert/delegate/coreml/op/convolution_coreml.h index d6e77b3356e..726389816d3 100644 --- a/mindspore/lite/src/litert/delegate/coreml/op/convolution_coreml.h +++ b/mindspore/lite/src/litert/delegate/coreml/op/convolution_coreml.h @@ -1,46 +1,46 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_COREML_OP_CONVOLUTION_COREML_H_ -#define MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_COREML_OP_CONVOLUTION_COREML_H_ - -#include -#include -#include -#include -#include "src/litert/delegate/coreml/op/convolution_base_coreml.h" -namespace mindspore::lite { -class ConvolutionCoreMLOp : public ConvolutionBaseCoreMLOp { - public: - ConvolutionCoreMLOp(const schema::Primitive *primitive, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : ConvolutionBaseCoreMLOp(primitive, in_tensors, out_tensors, name) {} - - int IsSupport() override; - - private: - schema::PadMode GetPadMode(); - - int SetConvParam() override; - - private: - int stride_h_{0}; - int stride_w_{0}; - int dilation_h_{0}; - int dilation_w_{0}; - const schema::Conv2DFusion *conv_prim_ = nullptr; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_COREML_OP_CONVOLUTION_COREML_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_COREML_OP_CONVOLUTION_COREML_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_COREML_OP_CONVOLUTION_COREML_H_ + +#include +#include +#include +#include +#include "src/litert/delegate/coreml/op/convolution_base_coreml.h" +namespace mindspore::lite { +class ConvolutionCoreMLOp : public ConvolutionBaseCoreMLOp { + public: + ConvolutionCoreMLOp(const schema::Primitive *primitive, const std::vector &in_tensors, + const std::vector &out_tensors, std::string name) + : ConvolutionBaseCoreMLOp(primitive, in_tensors, out_tensors, name) {} + + int IsSupport() override; + + private: + schema::PadMode GetPadMode(); + + int SetConvParam() override; + + private: + int stride_h_{0}; + int stride_w_{0}; + int dilation_h_{0}; + int dilation_w_{0}; + const schema::Conv2DFusion *conv_prim_ = nullptr; +}; +} // namespace mindspore::lite +#endif // MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_COREML_OP_CONVOLUTION_COREML_H_ diff --git a/mindspore/lite/src/litert/delegate/coreml/op/deconvolution_coreml.cc b/mindspore/lite/src/litert/delegate/coreml/op/deconvolution_coreml.cc index 7ddf4b3168d..7d39f784fbb 100644 --- a/mindspore/lite/src/litert/delegate/coreml/op/deconvolution_coreml.cc +++ b/mindspore/lite/src/litert/delegate/coreml/op/deconvolution_coreml.cc @@ -1,70 +1,70 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/litert/delegate/coreml/op/deconvolution_coreml.h" -#include "src/litert/delegate/delegate_utils.h" -namespace mindspore::lite { -int DeconvolutionCoreMLOp::IsSupport() { - if (!in_tensors_[kWeightIndex].IsConst()) { - MS_LOG(WARNING) << "CoreML deconvolution does not support dynamic weight."; - return RET_NOT_SUPPORT; - } - deconv_prim_ = op_primitive_->value_as_Conv2dTransposeFusion(); - if (deconv_prim_ == nullptr) { - MS_LOG(ERROR) << "Get null primitive value for op ." << name_; - return RET_ERROR; - } - if (static_cast(deconv_prim_->group()) != 1) { - MS_LOG(WARNING) << "Only support group equals 1 for npu deconvolution op"; - return RET_NOT_SUPPORT; - } - return RET_OK; -} - -int DeconvolutionCoreMLOp::SetConvParam() { - conv_param_->set_isdeconvolution(true); - CHECK_NULL_RETURN(deconv_prim_->stride()); - auto stride_h = static_cast(*(deconv_prim_->stride()->begin())); - auto stride_w = static_cast(*(deconv_prim_->stride()->begin() + 1)); - conv_param_->add_stride(stride_h); - conv_param_->add_stride(stride_w); - CHECK_NULL_RETURN(deconv_prim_->dilation()); - auto dilation_h = static_cast(*(deconv_prim_->dilation()->begin())); - auto dilation_w = static_cast(*(deconv_prim_->dilation()->begin() + 1)); - conv_param_->add_dilationfactor(dilation_h); - conv_param_->add_dilationfactor(dilation_w); - conv_param_->add_outputshape(output_h_); - conv_param_->add_outputshape(output_w_); - if (deconv_prim_->pad_mode() == schema::PadMode_SAME) { - conv_param_->mutable_same(); - } else { - conv_param_->mutable_valid(); - if (deconv_prim_->pad_list() != nullptr) { - auto pad_u = static_cast(*(deconv_prim_->pad_list()->begin() + PAD_UP)); - auto pad_d = static_cast(*(deconv_prim_->pad_list()->begin() + PAD_DOWN)); - auto pad_l = static_cast(*(deconv_prim_->pad_list()->begin() + PAD_LEFT)); - auto pad_r = static_cast(*(deconv_prim_->pad_list()->begin() + PAD_RIGHT)); - auto ret = SetPadding({pad_u, pad_d, pad_l, pad_r}); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Fail to set padding for op: " << name_; - return RET_ERROR; - } - } - } - act_type_ = deconv_prim_->activation_type(); - return RET_OK; -} -} // namespace mindspore::lite +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/litert/delegate/coreml/op/deconvolution_coreml.h" +#include "src/litert/delegate/delegate_utils.h" +namespace mindspore::lite { +int DeconvolutionCoreMLOp::IsSupport() { + if (!in_tensors_[kWeightIndex].IsConst()) { + MS_LOG(WARNING) << "CoreML deconvolution does not support dynamic weight."; + return RET_NOT_SUPPORT; + } + deconv_prim_ = op_primitive_->value_as_Conv2dTransposeFusion(); + if (deconv_prim_ == nullptr) { + MS_LOG(ERROR) << "Get null primitive value for op ." << name_; + return RET_ERROR; + } + if (static_cast(deconv_prim_->group()) != 1) { + MS_LOG(WARNING) << "Only support group equals 1 for npu deconvolution op"; + return RET_NOT_SUPPORT; + } + return RET_OK; +} + +int DeconvolutionCoreMLOp::SetConvParam() { + conv_param_->set_isdeconvolution(true); + CHECK_NULL_RETURN(deconv_prim_->stride()); + auto stride_h = static_cast(*(deconv_prim_->stride()->begin())); + auto stride_w = static_cast(*(deconv_prim_->stride()->begin() + 1)); + conv_param_->add_stride(stride_h); + conv_param_->add_stride(stride_w); + CHECK_NULL_RETURN(deconv_prim_->dilation()); + auto dilation_h = static_cast(*(deconv_prim_->dilation()->begin())); + auto dilation_w = static_cast(*(deconv_prim_->dilation()->begin() + 1)); + conv_param_->add_dilationfactor(dilation_h); + conv_param_->add_dilationfactor(dilation_w); + conv_param_->add_outputshape(output_h_); + conv_param_->add_outputshape(output_w_); + if (deconv_prim_->pad_mode() == schema::PadMode_SAME) { + conv_param_->mutable_same(); + } else { + conv_param_->mutable_valid(); + if (deconv_prim_->pad_list() != nullptr) { + auto pad_u = static_cast(*(deconv_prim_->pad_list()->begin() + PAD_UP)); + auto pad_d = static_cast(*(deconv_prim_->pad_list()->begin() + PAD_DOWN)); + auto pad_l = static_cast(*(deconv_prim_->pad_list()->begin() + PAD_LEFT)); + auto pad_r = static_cast(*(deconv_prim_->pad_list()->begin() + PAD_RIGHT)); + auto ret = SetPadding({pad_u, pad_d, pad_l, pad_r}); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Fail to set padding for op: " << name_; + return RET_ERROR; + } + } + } + act_type_ = deconv_prim_->activation_type(); + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/litert/delegate/coreml/op/deconvolution_coreml.h b/mindspore/lite/src/litert/delegate/coreml/op/deconvolution_coreml.h index e368d5ac7fd..f510edbcce2 100644 --- a/mindspore/lite/src/litert/delegate/coreml/op/deconvolution_coreml.h +++ b/mindspore/lite/src/litert/delegate/coreml/op/deconvolution_coreml.h @@ -1,42 +1,42 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_COREML_OP_DECONVOLUTION_COREML_H_ -#define MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_COREML_OP_DECONVOLUTION_COREML_H_ - -#include -#include -#include -#include -#include "src/litert/delegate/coreml/op/convolution_base_coreml.h" -namespace mindspore::lite { -class DeconvolutionCoreMLOp : public ConvolutionBaseCoreMLOp { - public: - DeconvolutionCoreMLOp(const schema::Primitive *primitive, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : ConvolutionBaseCoreMLOp(primitive, in_tensors, out_tensors, name) {} - - int IsSupport() override; - - private: - schema::PadMode GetPadMode(); - - int SetConvParam() override; - - private: - const schema::Conv2dTransposeFusion *deconv_prim_ = nullptr; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_COREML_OP_DECONVOLUTION_COREML_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_COREML_OP_DECONVOLUTION_COREML_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_COREML_OP_DECONVOLUTION_COREML_H_ + +#include +#include +#include +#include +#include "src/litert/delegate/coreml/op/convolution_base_coreml.h" +namespace mindspore::lite { +class DeconvolutionCoreMLOp : public ConvolutionBaseCoreMLOp { + public: + DeconvolutionCoreMLOp(const schema::Primitive *primitive, const std::vector &in_tensors, + const std::vector &out_tensors, std::string name) + : ConvolutionBaseCoreMLOp(primitive, in_tensors, out_tensors, name) {} + + int IsSupport() override; + + private: + schema::PadMode GetPadMode(); + + int SetConvParam() override; + + private: + const schema::Conv2dTransposeFusion *deconv_prim_ = nullptr; +}; +} // namespace mindspore::lite +#endif // MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_COREML_OP_DECONVOLUTION_COREML_H_ diff --git a/mindspore/lite/src/litert/delegate/coreml/op/max_pooling_coreml.cc b/mindspore/lite/src/litert/delegate/coreml/op/max_pooling_coreml.cc index f2e71a20820..3f2fc359190 100644 --- a/mindspore/lite/src/litert/delegate/coreml/op/max_pooling_coreml.cc +++ b/mindspore/lite/src/litert/delegate/coreml/op/max_pooling_coreml.cc @@ -1,71 +1,71 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/litert/delegate/coreml/op/max_pooling_coreml.h" -namespace mindspore::lite { -int MaxPoolingCoreMLOp::InitParams() { - pooling_prim_ = op_primitive_->value_as_MaxPoolFusion(); - if (pooling_prim_ == nullptr) { - MS_LOG(ERROR) << "Get null primitive value for op ." << name_; - return RET_ERROR; - } - return RET_OK; -} - -int MaxPoolingCoreMLOp::BuildLayer() { - MS_ASSERT(op_ != nullptr); - auto pooling_param = op_->mutable_pooling(); - pooling_param->set_type(CoreML::Specification::PoolingLayerParams::MAX); - if (pooling_prim_->global()) { - pooling_param->set_globalpooling(true); - pooling_param->mutable_valid(); - return RET_OK; - } - auto kernel_h = static_cast(*(pooling_prim_->kernel_size()->begin())); - auto kernel_w = static_cast(*(pooling_prim_->kernel_size()->begin() + 1)); - auto stride_h = static_cast(*(pooling_prim_->strides()->begin())); - auto stride_w = static_cast(*(pooling_prim_->strides()->begin() + 1)); - pooling_param->add_stride(stride_h); - pooling_param->add_stride(stride_w); - pooling_param->add_kernelsize(kernel_h); - pooling_param->add_kernelsize(kernel_w); - if (pooling_prim_->pad_mode() == schema::PadMode_SAME) { - pooling_param->mutable_same(); - } else { - pooling_param->mutable_valid(); - if (pooling_prim_->pad() != nullptr) { - auto pad_u = static_cast(*(pooling_prim_->pad()->begin() + PAD_UP)); - auto pad_d = static_cast(*(pooling_prim_->pad()->begin() + PAD_DOWN)); - auto pad_l = static_cast(*(pooling_prim_->pad()->begin() + PAD_LEFT)); - auto pad_r = static_cast(*(pooling_prim_->pad()->begin() + PAD_RIGHT)); - auto ret = SetPadding({pad_u, pad_d, pad_l, pad_r}); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Fail to set padding for op: " << name_; - return RET_ERROR; - } - } - } - auto act_type = pooling_prim_->activation_type(); - if (act_type != schema::ActivationType_NO_ACTIVATION) { - auto ret = SetActivation(act_type); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Set pooling activation failed for op: " << name_; - return RET_ERROR; - } - } - return RET_OK; -} -} // namespace mindspore::lite +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/litert/delegate/coreml/op/max_pooling_coreml.h" +namespace mindspore::lite { +int MaxPoolingCoreMLOp::InitParams() { + pooling_prim_ = op_primitive_->value_as_MaxPoolFusion(); + if (pooling_prim_ == nullptr) { + MS_LOG(ERROR) << "Get null primitive value for op ." << name_; + return RET_ERROR; + } + return RET_OK; +} + +int MaxPoolingCoreMLOp::BuildLayer() { + MS_ASSERT(op_ != nullptr); + auto pooling_param = op_->mutable_pooling(); + pooling_param->set_type(CoreML::Specification::PoolingLayerParams::MAX); + if (pooling_prim_->global()) { + pooling_param->set_globalpooling(true); + pooling_param->mutable_valid(); + return RET_OK; + } + auto kernel_h = static_cast(*(pooling_prim_->kernel_size()->begin())); + auto kernel_w = static_cast(*(pooling_prim_->kernel_size()->begin() + 1)); + auto stride_h = static_cast(*(pooling_prim_->strides()->begin())); + auto stride_w = static_cast(*(pooling_prim_->strides()->begin() + 1)); + pooling_param->add_stride(stride_h); + pooling_param->add_stride(stride_w); + pooling_param->add_kernelsize(kernel_h); + pooling_param->add_kernelsize(kernel_w); + if (pooling_prim_->pad_mode() == schema::PadMode_SAME) { + pooling_param->mutable_same(); + } else { + pooling_param->mutable_valid(); + if (pooling_prim_->pad() != nullptr) { + auto pad_u = static_cast(*(pooling_prim_->pad()->begin() + PAD_UP)); + auto pad_d = static_cast(*(pooling_prim_->pad()->begin() + PAD_DOWN)); + auto pad_l = static_cast(*(pooling_prim_->pad()->begin() + PAD_LEFT)); + auto pad_r = static_cast(*(pooling_prim_->pad()->begin() + PAD_RIGHT)); + auto ret = SetPadding({pad_u, pad_d, pad_l, pad_r}); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Fail to set padding for op: " << name_; + return RET_ERROR; + } + } + } + auto act_type = pooling_prim_->activation_type(); + if (act_type != schema::ActivationType_NO_ACTIVATION) { + auto ret = SetActivation(act_type); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Set pooling activation failed for op: " << name_; + return RET_ERROR; + } + } + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/litert/delegate/coreml/op/max_pooling_coreml.h b/mindspore/lite/src/litert/delegate/coreml/op/max_pooling_coreml.h index 3904c7ccc68..ab0625f5e3e 100644 --- a/mindspore/lite/src/litert/delegate/coreml/op/max_pooling_coreml.h +++ b/mindspore/lite/src/litert/delegate/coreml/op/max_pooling_coreml.h @@ -1,39 +1,39 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_COREML_OP_MAX_POOLING_COREML_H_ -#define MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_COREML_OP_MAX_POOLING_COREML_H_ - -#include -#include -#include -#include -#include "src/litert/delegate/coreml/op/coreml_op.h" -namespace mindspore::lite { -class MaxPoolingCoreMLOp : public CoreMLOp { - public: - MaxPoolingCoreMLOp(const schema::Primitive *primitive, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : CoreMLOp(primitive, in_tensors, out_tensors, name) {} - - int InitParams() override; - - int BuildLayer() override; - - private: - const schema::MaxPoolFusion *pooling_prim_ = nullptr; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_COREML_OP_MAX_POOLING_COREML_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_COREML_OP_MAX_POOLING_COREML_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_COREML_OP_MAX_POOLING_COREML_H_ + +#include +#include +#include +#include +#include "src/litert/delegate/coreml/op/coreml_op.h" +namespace mindspore::lite { +class MaxPoolingCoreMLOp : public CoreMLOp { + public: + MaxPoolingCoreMLOp(const schema::Primitive *primitive, const std::vector &in_tensors, + const std::vector &out_tensors, std::string name) + : CoreMLOp(primitive, in_tensors, out_tensors, name) {} + + int InitParams() override; + + int BuildLayer() override; + + private: + const schema::MaxPoolFusion *pooling_prim_ = nullptr; +}; +} // namespace mindspore::lite +#endif // MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_COREML_OP_MAX_POOLING_COREML_H_ diff --git a/mindspore/lite/src/litert/kernel/cpu/bolt/bolt b/mindspore/lite/src/litert/kernel/cpu/bolt/bolt deleted file mode 160000 index 1c7f2642117..00000000000 --- a/mindspore/lite/src/litert/kernel/cpu/bolt/bolt +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 1c7f2642117b72c881c9696e188244a2de6c1203 diff --git a/mindspore/lite/test/st/scripts/experimental/config/models_accuracy.yaml b/mindspore/lite/test/st/scripts/experimental/config/models_accuracy.yaml index 817abd3fbc5..ff79b29ab3c 100644 --- a/mindspore/lite/test/st/scripts/experimental/config/models_accuracy.yaml +++ b/mindspore/lite/test/st/scripts/experimental/config/models_accuracy.yaml @@ -1,873 +1,873 @@ -tf: - browser_scene1_v2: - fmk: tf - input_number: 1 - benchmark_shapes: 75,19910 - acc_threshold: 0.5 - - browser_v15: - fmk: tf - input_number: 1 - benchmark_shapes: 75,10896 - acc_threshold: 0 - - browser_v58: - fmk: tf - input_number: 1 - benchmark_shapes: 75,26180 - acc_threshold: 0 - - densenet: - fmk: tf - input_number: 1 - benchmark_shapes: 1,224,224,3 - - fsr_270_mindspore: - fmk: tf - input_number: 1 - - fsr_360_mindspore: - fmk: tf - input_number: 1 - - fsr_720_mindspore: - fmk: tf - input_number: 1 - - hiai_AADB_HADB_MBV2_model: - fmk: tf - input_number: 1 - benchmark_shapes: 1,224,224,3 - - hiai_cn_recognize_modify_padv2: - fmk: tf - input_number: 1 - benchmark_shapes: 1,32,512,1 - - hiai_cpu_face_emotion: - fmk: tf - input_number: 1 - - hiai_cpu_face_gazing: - fmk: tf - input_number: 1 - - hiai_ctpn_feature_map: - fmk: tf - input_number: 1 - - hiai_cv_focusShootOCRModel_02: - fmk: tf - input_number: 1 - - hiai_cv_focusShootOCRModel_08: - fmk: tf - input_number: 1 - - hiai_cv_poseEstimation: - fmk: tf - input_number: 1 - - hiai_detectmodel_06_23_960_480_1180700: - fmk: tf - input_number: 1 - - hiai_ghostnet: - fmk: tf - input_number: 1 - - hiai_humanDetection: - fmk: tf - input_number: 1 - - hiai_iMaxDN_RGB: - fmk: tf - input_number: 1 - - hiai_iMaxSR_RGB: - fmk: tf - input_number: 1 - - hiai_latin_ocr: - fmk: tf - input_number: 1 - - hiai_model_normalize_object_scene_ps_20200519: - fmk: tf - input_number: 1 - benchmark_shapes: 1,224,224,3 - - hiai_PoseEstimation_Pcm: - fmk: tf - input_number: 1 - - inception_resnet_v2: - fmk: tf - input_number: 1 - benchmark_shapes: 1,299,299,3 - - inception_v3: - fmk: tf - input_number: 1 - benchmark_shapes: 1,299,299,3 - - inception_v4: - fmk: tf - input_number: 1 - benchmark_shapes: 1,299,299,3 - - matmul: - fmk: tf - input_number: 1 - - ml_ei_headpose: - fmk: tf - input_number: 1 - benchmark_shapes: 1,64,64,3 - - ml_face_openclose: - fmk: tf - input_number: 1 - benchmark_shapes: 1,32,32,3 - - ml_object_detect: - fmk: tf - input_number: 1 - benchmark_shapes: 1,288,288,3 - - ml_ocr_latin: - fmk: tf - input_number: 1 - - ml_video_edit_shot_selection_opticalFlow: - fmk: tf - input_number: 1 - - ml_vision_guide_detection1: - fmk: tf - input_number: 1 - - ml_vision_guide_detection2: - fmk: tf - input_number: 1 - benchmark_shapes: 1,320,320,1 - - ml_vision_guide_detection3: - fmk: tf - input_number: 1 - - mnasnet_1.0_224: - fmk: tf - input_number: 1 - - mnasnet_1.3_224: - fmk: tf - input_number: 1 - - mobilenet_v1_0.25_128_frozen: - fmk: tf - input_number: 1 - benchmark_shapes: 1,128,128,3 - - model_normalize_object_scene_ps_20200519: - fmk: tf - input_number: 1 - benchmark_shapes: 1,224,224,3 - - mtk_AADB_HADB_MBV2_model: - fmk: tf - input_number: 1 - benchmark_shapes: 1,224,224,3 - - mtk_AADB_HADB_MBV3_model: - fmk: tf - input_number: 1 - benchmark_shapes: 1,224,224,3 - - mtk_face_features_v1: - fmk: tf - input_number: 1 - - mtk_model_ckpt: - fmk: tf - input_number: 1 - - mtk_model_face_dress: - fmk: tf - input_number: 1 - benchmark_shapes: 1,128,128,3 - - mtk_model_normalize_object_scene_ps_20200519: - fmk: tf - input_number: 1 - benchmark_shapes: 1,224,224,3 - - nasnet_large: - fmk: tf - input_number: 1 - benchmark_shapes: 1,331,331,3 - - nasnet_mobile: - fmk: tf - input_number: 1 - benchmark_shapes: 1,224,224,3 - - Q_crnn_ori_75w_slim_norm: - fmk: tf - input_number: 1 - - Q_crnn_ori_v2_405001_notrans_nopre: - fmk: tf - input_number: 1 - - Q_crnn_screen_slim400w_more_20w: - fmk: tf - input_number: 1 - - Q_dila-small-mix-full-fineturn-390000-nopixel-nosigmoid: - fmk: tf - input_number: 1 - - Q_hand_0812: - fmk: tf - input_number: 1 - - Q_inception-249970-672-11-16: - fmk: tf - input_number: 1 - - scan_hms_angle: - fmk: tf - input_number: 1 - - siteAI_trans_nonlinear: - fmk: tf - input_number: 1 - benchmark_shapes: 1,137 - - siteAI_trans_nonlinear134g: - fmk: tf - input_number: 1 - benchmark_shapes: 1,137 - - siteAI_trans_nonlinear134g_nrz: - fmk: tf - input_number: 1 - benchmark_shapes: 1,182 - - siteAI_trans_nonlinear40g: - fmk: tf - input_number: 1 - benchmark_shapes: 1,271 - - siteAI_wireless_depress_w: - fmk: tf - input_number: 1 - benchmark_shapes: 1,36 - - siteAI_wireless_restore_w: - fmk: tf - input_number: 1 - benchmark_shapes: 1,36 - - tensor_dot: - fmk: tf - input_number: 1 - benchmark_shapes: 1,217 - ---- - -caffe: - 2012_ATLANTA_1class_20190621_v4.x_nomean: - fmk: caffe - input_number: 1 - - bank_card_recognition_fcny: - fmk: caffe - input_number: 1 - - hdc_contour_pose_128: - fmk: caffe - input_number: 1 - - hdc_fivembnet: - fmk: caffe - input_number: 1 - - hdc_mobilenetface: - fmk: caffe - input_number: 1 - - hdc_resnet: - fmk: caffe - input_number: 1 - - hiai_cpu_face_hat: - fmk: caffe - input_number: 1 - - hiai_cv_focusShootOCRModel_04: - fmk: caffe - input_number: 1 - - hiai_cv_focusShootOCRModel_06: - fmk: caffe - input_number: 1 - - hiai_face_landmark: - fmk: caffe - input_number: 1 - - hiai_video_seg: - fmk: caffe - input_number: 1 - - HWSR-s_256_256: - fmk: caffe - input_number: 1 - - ml_face_compare: - fmk: caffe - input_number: 1 - - ml_face_contour: - fmk: caffe - input_number: 1 - - ml_face_emotion: - fmk: caffe - input_number: 1 - - ml_face_hat: - fmk: caffe - input_number: 1 - - ml_face_landmark: - fmk: caffe - input_number: 1 - - ml_hand_3d_regression: - fmk: caffe - input_number: 1 - - ml_hardware_eyeclose: - fmk: caffe - input_number: 1 - - ml_liveness_detect_landmark_tmp: - fmk: caffe - input_number: 1 - - ml_ocr_cn: - fmk: caffe - input_number: 1 - - ml_text_division: - fmk: caffe - input_number: 1 - - ml_video_edit_hair_dyeing_segmodel_20211119: - fmk: caffe - input_number: 1 - - ml_video_edit_hair_dyeing_segmodel_v3: - fmk: caffe - input_number: 1 - - ml_video_edit_hairSeg_have_imageProcessLayer_interpTo145: - fmk: caffe - input_number: 1 - - ml_video_edit_person_divison_pic: - fmk: caffe - input_number: 1 - ---- - -onnx: - 01-face_det_400_400: - fmk: onnx - input_number: 1 - input_suffix: .onnx.bin - output_suffix: .onnx.out - benchmark_shapes: 1,400,400,3 - acc_threshold: 4.5 - - efficientnet-lite4-11: - fmk: onnx - input_number: 1 - - emotion-ferplus-8: - fmk: onnx - input_number: 1 - - gender_lstm_scd: - fmk: onnx - input_number: 1 - - gender_lstm_vad: - fmk: onnx - input_number: 1 - - gender_resnet34_lzl: - fmk: onnx - input_number: 1 - - hdc_Face_Landmark5_MTI_Aesthetic: - fmk: onnx - input_number: 1 - - hdc_ocr_attention: - fmk: onnx - input_number: 1 - - ml_2012_ocr_cn: - fmk: onnx - input_number: 1 - - ml_2012_ocr_cn_noLSTM: - fmk: onnx - input_number: 1 - - ml_table_segment: - fmk: onnx - input_number: 1 - - mnist-8: - fmk: onnx - input_number: 1 - - rcnn-ilsvrc13-9: - fmk: onnx - input_number: 1 - ---- - -tflite: - bloom_model_age_gender: - fmk: tflite - input_number: 1 - - bloom_new_detect: - fmk: tflite - input_number: 1 - - deeplabv3_1_default_1: - fmk: tflite - input_number: 1 - - deeplabv3_257_mv_gpu: - fmk: tflite - input_number: 1 - - efficientnet_lite0_fp32_2: - fmk: tflite - input_number: 1 - - hiai_AADB_HADB_MBV2_model_fp32: - fmk: tflite - input_number: 1 - - hiai_bigmodel_ghost_2_1_no_normalized_no_trans_tflite: - fmk: tflite - input_number: 1 - - hiai_bigmodel_ghost_5_1_no_normalized_no_trans_tflite: - fmk: tflite - input_number: 1 - - hiai_chinese_english_recognize_model_float32: - fmk: tflite - input_number: 1 - - hiai_cn_recognize_modify_padv2: - fmk: tflite - input_number: 1 - - hiai_cpu_face_emotion: - fmk: tflite - input_number: 1 - - hiai_cpu_face_gazing: - fmk: tflite - input_number: 1 - - hiai_cpu_face_headpose: - fmk: tflite - input_number: 1 - - hiai_ctpn_feature_map: - fmk: tflite - input_number: 1 - - hiai_cv_focusShootOCRModel_02: - fmk: tflite - input_number: 1 - - hiai_cv_focusShootOCRModel_08: - fmk: tflite - input_number: 1 - - hiai_cv_labelDetectorModel_v2: - fmk: tflite - input_number: 1 - - hiai_cv_labelDetectorModel_v4: - fmk: tflite - input_number: 1 - - hiai_cv_poseEstimation: - fmk: tflite - input_number: 1 - - hiai_detect_curve_model_float32: - fmk: tflite - input_number: 1 - - hiai_detectmodel_06_23_960_480_1180700: - fmk: tflite - input_number: 1 - - hiai_detectmodel_desnet_256_128_64_32: - fmk: tflite - input_number: 1 - - hiai_ghostnet: - fmk: tflite - input_number: 1 - - hiai_humanDetection: - fmk: tflite - input_number: 1 - - hiai_iMaxDN_RGB: - fmk: tflite - input_number: 1 - - hiai_iMaxSR_RGB: - fmk: tflite - input_number: 1 - - hiai_latin_ocr: - fmk: tflite - input_number: 1 - - hiai_latin_ocr_1: - fmk: tflite - input_number: 1 - - hiai_model_normalize_object_scene_ps_20200519: - fmk: tflite - input_number: 1 - - hiai_PoseEstimation_Pcm: - fmk: tflite - input_number: 1 - - ide_label_base: - fmk: tflite - input_number: 1 - - inception_resnet_v2: - fmk: tflite - input_number: 1 - - inception_v3: - fmk: tflite - input_number: 1 - - inception_v4: - fmk: tflite - input_number: 1 - - lite-model_on_device_vision_classifier_popular_us_products_V1_1: - fmk: tflite - input_number: 1 - - lite-model_on_device_vision_classifier_popular_wine_V1_1: - fmk: tflite - input_number: 1 - - lma_tsec_shallow_channels16_ds2.1.1_model-best-f1: - fmk: tflite - input_number: 1 - - mindspore_text_classification_tflite: - fmk: tflite - input_number: 1 - - ml_ei_headpose: - fmk: tflite - input_number: 1 - - ml_ei_landmark: - fmk: tflite - input_number: 1 - - ml_ei_landmark_pb2tflite: - fmk: tflite - input_number: 1 - - ml_face_openclose: - fmk: tflite - input_number: 1 - - ml_face_openclose_tflite: - fmk: tflite - input_number: 1 - - ml_object_detect: - fmk: tflite - input_number: 1 - - ml_object_detect_1: - fmk: tflite - input_number: 1 - - ml_object_detect_pb2tflite: - fmk: tflite - input_number: 1 - - ml_ocr_latin: - fmk: tflite - input_number: 1 - - ml_ocr_latin_pb2tflite: - fmk: tflite - input_number: 1 - - ml_pic_shopping: - fmk: tflite - input_number: 1 - - ml_pic_shopping_pb2tflite: - fmk: tflite - input_number: 1 - - ml_text_correction: - fmk: tflite - input_number: 1 - - ml_vision_guide_detection1_pb2tflite: - fmk: tflite - input_number: 1 - - ml_vision_guide_detection3_pb2tflite: - fmk: tflite - input_number: 1 - - mnasnet_0.50_224_1_metadata_1: - fmk: tflite - input_number: 1 - - mnasnet_1.3_224: - fmk: tflite - input_number: 1 - - mnist: - fmk: tflite - input_number: 1 - - mobilenet: - fmk: tflite - input_number: 1 - - mobilenet_v1_0.25_128: - fmk: tflite - input_number: 1 - - mobilenet_v2_1.0_224: - fmk: tflite - input_number: 1 - - model_emotions_0727_nosoftmax: - fmk: tflite - input_number: 1 - - mtk_AADB_HADB_MBV2_model_fp32: - fmk: tflite - input_number: 1 - - mtk_AADB_HADB_MBV3_model_fp32: - fmk: tflite - input_number: 1 - - mtk_convert_model: - fmk: tflite - input_number: 1 - - mtk_face_features_v1: - fmk: tflite - input_number: 1 - - mtk_face_recognition: - fmk: tflite - input_number: 1 - - mtk_model_ckpt: - fmk: tflite - input_number: 1 - - mtk_model_emotions_0727_nosoftmax: - fmk: tflite - input_number: 1 - - mtk_model_face_dress: - fmk: tflite - input_number: 1 - - mtk_model_normalize_object_scene_ps_20200519_f32: - fmk: tflite - input_number: 1 - - mtk_model_normalize_object_scene_ps_20200826_f32_no_softmax: - fmk: tflite - input_number: 1 - - mtk_new_detect: - fmk: tflite - input_number: 1 - - multi_person_mobilenet_v1_075_float: - fmk: tflite - input_number: 1 - - posenet_mobilenet_float_075_1_default_1: - fmk: tflite - input_number: 1 - - Q_AADB_HADB_MBV2_model: - fmk: tflite - input_number: 1 - - Q_convert: - fmk: tflite - input_number: 1 - - Q_crnn_ori_75w_slim_norm_pb2tflite: - fmk: tflite - input_number: 1 - - Q_crnn_ori_v2_405001_notrans_nopre_pb2tflite: - fmk: tflite - input_number: 1 - - Q_crnn_screen_slim400w_more_20w_pb2tflite: - fmk: tflite - input_number: 1 - - Q_detect_fpn_add_inception-1448650: - fmk: tflite - input_number: 1 - - Q_dila-small-mix-full-fineturn-390000-nopixel-nosigmoid_tflite: - fmk: tflite - input_number: 1 - - Q_focusocr_cn_recog: - fmk: tflite - input_number: 1 - - Q_focusocr_jk_recog: - fmk: tflite - input_number: 1 - - Q_hand_0812_pb2tflite: - fmk: tflite - input_number: 1 - - Q_inception-249970-672-11-16_pb2tflite: - fmk: tflite - input_number: 1 - - Q_object_scene: - fmk: tflite - input_number: 1 - - Q888_age_gender_orderd: - fmk: tflite - input_number: 1 - - Q888_face_dress_mv3y: - fmk: tflite - input_number: 1 - - Q888_face_emo_dress_mv3_orderd: - fmk: tflite - input_number: 1 - - Q888_HADB_AADB_MBV2_model_fp32: - fmk: tflite - input_number: 1 - - Q888_isface: - fmk: tflite - input_number: 1 - - Q888_landmark: - fmk: tflite - input_number: 1 - - Q888_model_normalize_object_scene_ps_20200826_f32_no_softmax: - fmk: tflite - input_number: 1 - - Q888_new_detect: - fmk: tflite - input_number: 1 - - Q888_pose: - fmk: tflite - input_number: 1 - - resnet: - fmk: tflite - input_number: 1 - - scan_hms_angle_pb2tflite: - fmk: tflite - input_number: 1 - - scan_hms_angle1: - fmk: tflite - input_number: 1 - - scan_hms_detect: - fmk: tflite - input_number: 1 - - scan_hms_detect_pb2tflite: - fmk: tflite - input_number: 1 - - siteAI_digcom_AI_ECN: - fmk: tflite - input_number: 1 - - siteAI_digcom_g2v_keras: - fmk: tflite - input_number: 1 - - siteAI_trans_nonlinear: - fmk: tflite - input_number: 1 - - siteAI_trans_tcpclassify: - fmk: tflite - input_number: 1 - - siteAI_wireless_depress_w: - fmk: tflite - input_number: 1 - - siteAI_wireless_restore_w: - fmk: tflite - input_number: 1 - - squeezenet: - fmk: tflite - input_number: 1 - - text_classification: - fmk: tflite - input_number: 1 +tf: + browser_scene1_v2: + fmk: tf + input_number: 1 + benchmark_shapes: 75,19910 + acc_threshold: 0.5 + + browser_v15: + fmk: tf + input_number: 1 + benchmark_shapes: 75,10896 + acc_threshold: 0 + + browser_v58: + fmk: tf + input_number: 1 + benchmark_shapes: 75,26180 + acc_threshold: 0 + + densenet: + fmk: tf + input_number: 1 + benchmark_shapes: 1,224,224,3 + + fsr_270_mindspore: + fmk: tf + input_number: 1 + + fsr_360_mindspore: + fmk: tf + input_number: 1 + + fsr_720_mindspore: + fmk: tf + input_number: 1 + + hiai_AADB_HADB_MBV2_model: + fmk: tf + input_number: 1 + benchmark_shapes: 1,224,224,3 + + hiai_cn_recognize_modify_padv2: + fmk: tf + input_number: 1 + benchmark_shapes: 1,32,512,1 + + hiai_cpu_face_emotion: + fmk: tf + input_number: 1 + + hiai_cpu_face_gazing: + fmk: tf + input_number: 1 + + hiai_ctpn_feature_map: + fmk: tf + input_number: 1 + + hiai_cv_focusShootOCRModel_02: + fmk: tf + input_number: 1 + + hiai_cv_focusShootOCRModel_08: + fmk: tf + input_number: 1 + + hiai_cv_poseEstimation: + fmk: tf + input_number: 1 + + hiai_detectmodel_06_23_960_480_1180700: + fmk: tf + input_number: 1 + + hiai_ghostnet: + fmk: tf + input_number: 1 + + hiai_humanDetection: + fmk: tf + input_number: 1 + + hiai_iMaxDN_RGB: + fmk: tf + input_number: 1 + + hiai_iMaxSR_RGB: + fmk: tf + input_number: 1 + + hiai_latin_ocr: + fmk: tf + input_number: 1 + + hiai_model_normalize_object_scene_ps_20200519: + fmk: tf + input_number: 1 + benchmark_shapes: 1,224,224,3 + + hiai_PoseEstimation_Pcm: + fmk: tf + input_number: 1 + + inception_resnet_v2: + fmk: tf + input_number: 1 + benchmark_shapes: 1,299,299,3 + + inception_v3: + fmk: tf + input_number: 1 + benchmark_shapes: 1,299,299,3 + + inception_v4: + fmk: tf + input_number: 1 + benchmark_shapes: 1,299,299,3 + + matmul: + fmk: tf + input_number: 1 + + ml_ei_headpose: + fmk: tf + input_number: 1 + benchmark_shapes: 1,64,64,3 + + ml_face_openclose: + fmk: tf + input_number: 1 + benchmark_shapes: 1,32,32,3 + + ml_object_detect: + fmk: tf + input_number: 1 + benchmark_shapes: 1,288,288,3 + + ml_ocr_latin: + fmk: tf + input_number: 1 + + ml_video_edit_shot_selection_opticalFlow: + fmk: tf + input_number: 1 + + ml_vision_guide_detection1: + fmk: tf + input_number: 1 + + ml_vision_guide_detection2: + fmk: tf + input_number: 1 + benchmark_shapes: 1,320,320,1 + + ml_vision_guide_detection3: + fmk: tf + input_number: 1 + + mnasnet_1.0_224: + fmk: tf + input_number: 1 + + mnasnet_1.3_224: + fmk: tf + input_number: 1 + + mobilenet_v1_0.25_128_frozen: + fmk: tf + input_number: 1 + benchmark_shapes: 1,128,128,3 + + model_normalize_object_scene_ps_20200519: + fmk: tf + input_number: 1 + benchmark_shapes: 1,224,224,3 + + mtk_AADB_HADB_MBV2_model: + fmk: tf + input_number: 1 + benchmark_shapes: 1,224,224,3 + + mtk_AADB_HADB_MBV3_model: + fmk: tf + input_number: 1 + benchmark_shapes: 1,224,224,3 + + mtk_face_features_v1: + fmk: tf + input_number: 1 + + mtk_model_ckpt: + fmk: tf + input_number: 1 + + mtk_model_face_dress: + fmk: tf + input_number: 1 + benchmark_shapes: 1,128,128,3 + + mtk_model_normalize_object_scene_ps_20200519: + fmk: tf + input_number: 1 + benchmark_shapes: 1,224,224,3 + + nasnet_large: + fmk: tf + input_number: 1 + benchmark_shapes: 1,331,331,3 + + nasnet_mobile: + fmk: tf + input_number: 1 + benchmark_shapes: 1,224,224,3 + + Q_crnn_ori_75w_slim_norm: + fmk: tf + input_number: 1 + + Q_crnn_ori_v2_405001_notrans_nopre: + fmk: tf + input_number: 1 + + Q_crnn_screen_slim400w_more_20w: + fmk: tf + input_number: 1 + + Q_dila-small-mix-full-fineturn-390000-nopixel-nosigmoid: + fmk: tf + input_number: 1 + + Q_hand_0812: + fmk: tf + input_number: 1 + + Q_inception-249970-672-11-16: + fmk: tf + input_number: 1 + + scan_hms_angle: + fmk: tf + input_number: 1 + + siteAI_trans_nonlinear: + fmk: tf + input_number: 1 + benchmark_shapes: 1,137 + + siteAI_trans_nonlinear134g: + fmk: tf + input_number: 1 + benchmark_shapes: 1,137 + + siteAI_trans_nonlinear134g_nrz: + fmk: tf + input_number: 1 + benchmark_shapes: 1,182 + + siteAI_trans_nonlinear40g: + fmk: tf + input_number: 1 + benchmark_shapes: 1,271 + + siteAI_wireless_depress_w: + fmk: tf + input_number: 1 + benchmark_shapes: 1,36 + + siteAI_wireless_restore_w: + fmk: tf + input_number: 1 + benchmark_shapes: 1,36 + + tensor_dot: + fmk: tf + input_number: 1 + benchmark_shapes: 1,217 + +--- + +caffe: + 2012_ATLANTA_1class_20190621_v4.x_nomean: + fmk: caffe + input_number: 1 + + bank_card_recognition_fcny: + fmk: caffe + input_number: 1 + + hdc_contour_pose_128: + fmk: caffe + input_number: 1 + + hdc_fivembnet: + fmk: caffe + input_number: 1 + + hdc_mobilenetface: + fmk: caffe + input_number: 1 + + hdc_resnet: + fmk: caffe + input_number: 1 + + hiai_cpu_face_hat: + fmk: caffe + input_number: 1 + + hiai_cv_focusShootOCRModel_04: + fmk: caffe + input_number: 1 + + hiai_cv_focusShootOCRModel_06: + fmk: caffe + input_number: 1 + + hiai_face_landmark: + fmk: caffe + input_number: 1 + + hiai_video_seg: + fmk: caffe + input_number: 1 + + HWSR-s_256_256: + fmk: caffe + input_number: 1 + + ml_face_compare: + fmk: caffe + input_number: 1 + + ml_face_contour: + fmk: caffe + input_number: 1 + + ml_face_emotion: + fmk: caffe + input_number: 1 + + ml_face_hat: + fmk: caffe + input_number: 1 + + ml_face_landmark: + fmk: caffe + input_number: 1 + + ml_hand_3d_regression: + fmk: caffe + input_number: 1 + + ml_hardware_eyeclose: + fmk: caffe + input_number: 1 + + ml_liveness_detect_landmark_tmp: + fmk: caffe + input_number: 1 + + ml_ocr_cn: + fmk: caffe + input_number: 1 + + ml_text_division: + fmk: caffe + input_number: 1 + + ml_video_edit_hair_dyeing_segmodel_20211119: + fmk: caffe + input_number: 1 + + ml_video_edit_hair_dyeing_segmodel_v3: + fmk: caffe + input_number: 1 + + ml_video_edit_hairSeg_have_imageProcessLayer_interpTo145: + fmk: caffe + input_number: 1 + + ml_video_edit_person_divison_pic: + fmk: caffe + input_number: 1 + +--- + +onnx: + 01-face_det_400_400: + fmk: onnx + input_number: 1 + input_suffix: .onnx.bin + output_suffix: .onnx.out + benchmark_shapes: 1,400,400,3 + acc_threshold: 4.5 + + efficientnet-lite4-11: + fmk: onnx + input_number: 1 + + emotion-ferplus-8: + fmk: onnx + input_number: 1 + + gender_lstm_scd: + fmk: onnx + input_number: 1 + + gender_lstm_vad: + fmk: onnx + input_number: 1 + + gender_resnet34_lzl: + fmk: onnx + input_number: 1 + + hdc_Face_Landmark5_MTI_Aesthetic: + fmk: onnx + input_number: 1 + + hdc_ocr_attention: + fmk: onnx + input_number: 1 + + ml_2012_ocr_cn: + fmk: onnx + input_number: 1 + + ml_2012_ocr_cn_noLSTM: + fmk: onnx + input_number: 1 + + ml_table_segment: + fmk: onnx + input_number: 1 + + mnist-8: + fmk: onnx + input_number: 1 + + rcnn-ilsvrc13-9: + fmk: onnx + input_number: 1 + +--- + +tflite: + bloom_model_age_gender: + fmk: tflite + input_number: 1 + + bloom_new_detect: + fmk: tflite + input_number: 1 + + deeplabv3_1_default_1: + fmk: tflite + input_number: 1 + + deeplabv3_257_mv_gpu: + fmk: tflite + input_number: 1 + + efficientnet_lite0_fp32_2: + fmk: tflite + input_number: 1 + + hiai_AADB_HADB_MBV2_model_fp32: + fmk: tflite + input_number: 1 + + hiai_bigmodel_ghost_2_1_no_normalized_no_trans_tflite: + fmk: tflite + input_number: 1 + + hiai_bigmodel_ghost_5_1_no_normalized_no_trans_tflite: + fmk: tflite + input_number: 1 + + hiai_chinese_english_recognize_model_float32: + fmk: tflite + input_number: 1 + + hiai_cn_recognize_modify_padv2: + fmk: tflite + input_number: 1 + + hiai_cpu_face_emotion: + fmk: tflite + input_number: 1 + + hiai_cpu_face_gazing: + fmk: tflite + input_number: 1 + + hiai_cpu_face_headpose: + fmk: tflite + input_number: 1 + + hiai_ctpn_feature_map: + fmk: tflite + input_number: 1 + + hiai_cv_focusShootOCRModel_02: + fmk: tflite + input_number: 1 + + hiai_cv_focusShootOCRModel_08: + fmk: tflite + input_number: 1 + + hiai_cv_labelDetectorModel_v2: + fmk: tflite + input_number: 1 + + hiai_cv_labelDetectorModel_v4: + fmk: tflite + input_number: 1 + + hiai_cv_poseEstimation: + fmk: tflite + input_number: 1 + + hiai_detect_curve_model_float32: + fmk: tflite + input_number: 1 + + hiai_detectmodel_06_23_960_480_1180700: + fmk: tflite + input_number: 1 + + hiai_detectmodel_desnet_256_128_64_32: + fmk: tflite + input_number: 1 + + hiai_ghostnet: + fmk: tflite + input_number: 1 + + hiai_humanDetection: + fmk: tflite + input_number: 1 + + hiai_iMaxDN_RGB: + fmk: tflite + input_number: 1 + + hiai_iMaxSR_RGB: + fmk: tflite + input_number: 1 + + hiai_latin_ocr: + fmk: tflite + input_number: 1 + + hiai_latin_ocr_1: + fmk: tflite + input_number: 1 + + hiai_model_normalize_object_scene_ps_20200519: + fmk: tflite + input_number: 1 + + hiai_PoseEstimation_Pcm: + fmk: tflite + input_number: 1 + + ide_label_base: + fmk: tflite + input_number: 1 + + inception_resnet_v2: + fmk: tflite + input_number: 1 + + inception_v3: + fmk: tflite + input_number: 1 + + inception_v4: + fmk: tflite + input_number: 1 + + lite-model_on_device_vision_classifier_popular_us_products_V1_1: + fmk: tflite + input_number: 1 + + lite-model_on_device_vision_classifier_popular_wine_V1_1: + fmk: tflite + input_number: 1 + + lma_tsec_shallow_channels16_ds2.1.1_model-best-f1: + fmk: tflite + input_number: 1 + + mindspore_text_classification_tflite: + fmk: tflite + input_number: 1 + + ml_ei_headpose: + fmk: tflite + input_number: 1 + + ml_ei_landmark: + fmk: tflite + input_number: 1 + + ml_ei_landmark_pb2tflite: + fmk: tflite + input_number: 1 + + ml_face_openclose: + fmk: tflite + input_number: 1 + + ml_face_openclose_tflite: + fmk: tflite + input_number: 1 + + ml_object_detect: + fmk: tflite + input_number: 1 + + ml_object_detect_1: + fmk: tflite + input_number: 1 + + ml_object_detect_pb2tflite: + fmk: tflite + input_number: 1 + + ml_ocr_latin: + fmk: tflite + input_number: 1 + + ml_ocr_latin_pb2tflite: + fmk: tflite + input_number: 1 + + ml_pic_shopping: + fmk: tflite + input_number: 1 + + ml_pic_shopping_pb2tflite: + fmk: tflite + input_number: 1 + + ml_text_correction: + fmk: tflite + input_number: 1 + + ml_vision_guide_detection1_pb2tflite: + fmk: tflite + input_number: 1 + + ml_vision_guide_detection3_pb2tflite: + fmk: tflite + input_number: 1 + + mnasnet_0.50_224_1_metadata_1: + fmk: tflite + input_number: 1 + + mnasnet_1.3_224: + fmk: tflite + input_number: 1 + + mnist: + fmk: tflite + input_number: 1 + + mobilenet: + fmk: tflite + input_number: 1 + + mobilenet_v1_0.25_128: + fmk: tflite + input_number: 1 + + mobilenet_v2_1.0_224: + fmk: tflite + input_number: 1 + + model_emotions_0727_nosoftmax: + fmk: tflite + input_number: 1 + + mtk_AADB_HADB_MBV2_model_fp32: + fmk: tflite + input_number: 1 + + mtk_AADB_HADB_MBV3_model_fp32: + fmk: tflite + input_number: 1 + + mtk_convert_model: + fmk: tflite + input_number: 1 + + mtk_face_features_v1: + fmk: tflite + input_number: 1 + + mtk_face_recognition: + fmk: tflite + input_number: 1 + + mtk_model_ckpt: + fmk: tflite + input_number: 1 + + mtk_model_emotions_0727_nosoftmax: + fmk: tflite + input_number: 1 + + mtk_model_face_dress: + fmk: tflite + input_number: 1 + + mtk_model_normalize_object_scene_ps_20200519_f32: + fmk: tflite + input_number: 1 + + mtk_model_normalize_object_scene_ps_20200826_f32_no_softmax: + fmk: tflite + input_number: 1 + + mtk_new_detect: + fmk: tflite + input_number: 1 + + multi_person_mobilenet_v1_075_float: + fmk: tflite + input_number: 1 + + posenet_mobilenet_float_075_1_default_1: + fmk: tflite + input_number: 1 + + Q_AADB_HADB_MBV2_model: + fmk: tflite + input_number: 1 + + Q_convert: + fmk: tflite + input_number: 1 + + Q_crnn_ori_75w_slim_norm_pb2tflite: + fmk: tflite + input_number: 1 + + Q_crnn_ori_v2_405001_notrans_nopre_pb2tflite: + fmk: tflite + input_number: 1 + + Q_crnn_screen_slim400w_more_20w_pb2tflite: + fmk: tflite + input_number: 1 + + Q_detect_fpn_add_inception-1448650: + fmk: tflite + input_number: 1 + + Q_dila-small-mix-full-fineturn-390000-nopixel-nosigmoid_tflite: + fmk: tflite + input_number: 1 + + Q_focusocr_cn_recog: + fmk: tflite + input_number: 1 + + Q_focusocr_jk_recog: + fmk: tflite + input_number: 1 + + Q_hand_0812_pb2tflite: + fmk: tflite + input_number: 1 + + Q_inception-249970-672-11-16_pb2tflite: + fmk: tflite + input_number: 1 + + Q_object_scene: + fmk: tflite + input_number: 1 + + Q888_age_gender_orderd: + fmk: tflite + input_number: 1 + + Q888_face_dress_mv3y: + fmk: tflite + input_number: 1 + + Q888_face_emo_dress_mv3_orderd: + fmk: tflite + input_number: 1 + + Q888_HADB_AADB_MBV2_model_fp32: + fmk: tflite + input_number: 1 + + Q888_isface: + fmk: tflite + input_number: 1 + + Q888_landmark: + fmk: tflite + input_number: 1 + + Q888_model_normalize_object_scene_ps_20200826_f32_no_softmax: + fmk: tflite + input_number: 1 + + Q888_new_detect: + fmk: tflite + input_number: 1 + + Q888_pose: + fmk: tflite + input_number: 1 + + resnet: + fmk: tflite + input_number: 1 + + scan_hms_angle_pb2tflite: + fmk: tflite + input_number: 1 + + scan_hms_angle1: + fmk: tflite + input_number: 1 + + scan_hms_detect: + fmk: tflite + input_number: 1 + + scan_hms_detect_pb2tflite: + fmk: tflite + input_number: 1 + + siteAI_digcom_AI_ECN: + fmk: tflite + input_number: 1 + + siteAI_digcom_g2v_keras: + fmk: tflite + input_number: 1 + + siteAI_trans_nonlinear: + fmk: tflite + input_number: 1 + + siteAI_trans_tcpclassify: + fmk: tflite + input_number: 1 + + siteAI_wireless_depress_w: + fmk: tflite + input_number: 1 + + siteAI_wireless_restore_w: + fmk: tflite + input_number: 1 + + squeezenet: + fmk: tflite + input_number: 1 + + text_classification: + fmk: tflite + input_number: 1 diff --git a/mindspore/lite/test/st/scripts/experimental/config/models_all.yaml b/mindspore/lite/test/st/scripts/experimental/config/models_all.yaml index 7cccf67eeb9..11d69a0b833 100644 --- a/mindspore/lite/test/st/scripts/experimental/config/models_all.yaml +++ b/mindspore/lite/test/st/scripts/experimental/config/models_all.yaml @@ -1,2209 +1,2209 @@ -mindir: - deepfm_criteo_bs_16000_Ascend: - fmk: mindir - input_number: 2 - - efficientnetb0_imagenet2012_bs1: - fmk: mindir - input_number: 1 - acc_threshold: 2 - - efficientnetb1_imagenet2012_bs1: - fmk: mindir - input_number: 1 - acc_threshold: 3 - - efficientnetb2_imagenet2012_bs1: - fmk: mindir - input_number: 1 - acc_threshold: 4 - - efficientnetb3_imagenet2012_bs1: - fmk: mindir - input_number: 1 - acc_threshold: 3 - - inceptionv3_ascend: - fmk: mindir - input_number: 1 - - inceptionV4: - fmk: mindir - input_number: 1 - - mobilenetv3large_imagenet2012_bs1: - fmk: mindir - input_number: 1 - - mobilenetv3small_imagenet2012_bs1: - fmk: mindir - input_number: 1 - - pix2pix_facades_bs1: - fmk: mindir - input_number: 1 - - resnet101_imagenet_bs_1_Ascend: - fmk: mindir - input_number: 1 - acc_threshold: 2 - - resnet101_imagenet_bs_1_GPU: - fmk: mindir - input_number: 1 - acc_threshold: 2 - - resnet18_imagenet_bs_1_Ascend: - fmk: mindir - input_number: 1 - acc_threshold: 2 - - resnet34_ascend_v190_imagenet2012_official_cv_top1acc73.61_top5acc91.74: - fmk: mindir - input_number: 1 - acc_threshold: 3 - - resnet50_cifar10_bs_1_Ascend: - fmk: mindir - input_number: 1 - acc_threshold: 3 - - resnet50_cifar10_bs_1_GPU: - fmk: mindir - input_number: 1 - acc_threshold: 2 - - resnet50_imagenet_bs_1_Ascend: - fmk: mindir - input_number: 1 - acc_threshold: 3 - - resnet50_imagenet_bs_1_GPU: - fmk: mindir - input_number: 1 - acc_threshold: 3 - - resnet50_thor_imagenet_bs_1_ascend: - fmk: mindir - input_number: 1 - acc_threshold: 2 - - se-resnet50_imagenet_bs_1_ascend: - fmk: mindir - input_number: 1 - acc_threshold: 3 - - shufflenetv1: - fmk: mindir - input_number: 1 - - ssimae_mvtecadbottle_bs1: - fmk: mindir - input_number: 1 - - unet_bs_1_input_2: - fmk: mindir - input_number: 1 - - unet_nested_cell_bs_1_input_2: - fmk: mindir - input_number: 1 - - vgg16_cifar10_bs_64_Ascend: - fmk: mindir - input_number: 1 - - vgg16_imagenet_bs_1_Ascend: - fmk: mindir - input_number: 1 - - vgg19_cifar10_bs1: - fmk: mindir - input_number: 1 - - vgg19_imagenet2012_bs1: - fmk: mindir - input_number: 1 - acc_threshold: 1 - - vit_imagenet2012_bs1: - fmk: mindir - input_number: 1 - ---- - -tf: - browser_deepfm_v7: - fmk: tf - input_number: 2 - benchmark_shapes: 200,94:200,94 - acc_threshold: 0.5 - - browser_deepfm_v7_int64: - fmk: tf - input_number: 2 - benchmark_shapes: 200,94:200,94 - acc_threshold: 0.5 - - browser_scene1_v2: - fmk: tf - input_number: 1 - benchmark_shapes: 75,19910 - acc_threshold: 0.5 - - browser_v15: - fmk: tf - input_number: 1 - benchmark_shapes: 75,10896 - acc_threshold: 0 - - browser_v36: - fmk: tf - input_number: 2 - benchmark_shapes: 75,190:75,9120 - acc_threshold: 0.00002 - - browser_v50: - fmk: tf - input_number: 2 - benchmark_shapes: 75,276:75,276 - acc_threshold: 0.5 - - browser_v50_int64: - fmk: tf - input_number: 2 - benchmark_shapes: 75,276:75,276 - acc_threshold: 0.5 - - browser_v58: - fmk: tf - input_number: 1 - benchmark_shapes: 75,26180 - acc_threshold: 0 - - browser_v7: - fmk: tf - input_number: 1 - benchmark_shapes: 75,39160 - acc_threshold: 0 - - browser_v79: - fmk: tf - input_number: 2 - benchmark_shapes: 10,294:10,294 - acc_threshold: 0.004 - - browser_v79_int32: - fmk: tf - input_number: 2 - benchmark_shapes: 10,294:10,294 - acc_threshold: 0.004 - - bolt_segment: - fmk: tf - input_number: 1 - - decoder_step_nocumsum_v5: - fmk: tf - input_number: 13 - benchmark_shapes: 1,512:1,512:1,512:1,512:1,512:1,127,320:1,1429,2:1,127:1:1,127:1,512:1,80:1,127 - - densenet: - fmk: tf - input_number: 1 - benchmark_shapes: 1,224,224,3 - - female_model_step2_int16_noiseout: - fmk: tf - input_number: 66 - - fsr_270_mindspore: - fmk: tf - input_number: 1 - - fsr_360_mindspore: - fmk: tf - input_number: 1 - - fsr_720_mindspore: - fmk: tf - input_number: 1 - - g_00730000_female10_frames_tf1: - fmk: tf - input_number: 150 - - hiai_AADB_HADB_MBV2_model: - fmk: tf - input_number: 1 - benchmark_shapes: 1,224,224,3 - - hiai_asr_ctc: - fmk: tf - input_number: 2 - - hiai_asr_last_e1_cpu_fast_wavenet_batch1_frame1_one_cache: - fmk: tf - input_number: 2 - - - hiai_cn_recognize_modify_padv2: - fmk: tf - input_number: 1 - benchmark_shapes: 1,32,512,1 - - hiai_cpu_face_emotion: - fmk: tf - input_number: 1 - - hiai_cpu_face_gazing: - fmk: tf - input_number: 1 - - hiai_cpu_face_headpose: - fmk: tf - input_number: 1 - - hiai_ctpn_feature_map: - fmk: tf - input_number: 1 - - hiai_cv_focusShootOCRModel_02: - fmk: tf - input_number: 1 - - hiai_cv_focusShootOCRModel_08: - fmk: tf - input_number: 1 - - hiai_cv_poseEstimation: - fmk: tf - input_number: 1 - - hiai_detectmodel_06_23_960_480_1180700: - fmk: tf - input_number: 1 - - hiai_dress_detect: - fmk: tf - input_number: 1 - benchmark_shapes: 1,960,960,3 - - hiai_face_model_npu: - fmk: tf - input_number: 1 - - hiai_frozen_inference_graph: - fmk: tf - input_number: 1 - benchmark_shapes: 1,300,300,3 - - hiai_ghostnet: - fmk: tf - input_number: 1 - - hiai_humanDetection: - fmk: tf - input_number: 1 - - hiai_iMaxDN_RGB: - fmk: tf - input_number: 1 - - hiai_iMaxSR_RGB: - fmk: tf - input_number: 1 - - hiai_label_and_video: - fmk: tf - input_number: 1 - benchmark_shapes: 1,224,224,3 - - hiai_latin_ocr: - fmk: tf - input_number: 1 - - hiai_latin_ocr_1: - fmk: tf - input_number: 1 - - hiai_lm_inference_graph: - fmk: tf - input_number: 1 - - hiai_model_0909_kd_rot_ps_softmax: - fmk: tf - input_number: 1 - benchmark_shapes: 1,224,224,3 - - hiai_model_normalize_object_scene_ps_20200519: - fmk: tf - input_number: 1 - benchmark_shapes: 1,224,224,3 - - hiai_nlu_model: - fmk: tf - input_number: 3 - benchmark_shapes: 1,16:1,16:1,16 - - hiai_nlu_model_multi: - fmk: tf - input_number: 6 - benchmark_shapes: 1,32:1,32:1,32:1,74:1,11:1,6 - - hiai_nlu_model_single: - fmk: tf - input_number: 3 - benchmark_shapes: 1,32:1,32:1,32 - - hiai_nlu_model_v2: - fmk: tf - input_number: 7 - benchmark_shapes: 1,5:1,5:1,5:1,98:1,174:1,6:1,5 - - hiai_PoseEstimation_Pcm: - fmk: tf - input_number: 1 - - hiai_ssd_mobilenetv2_object: - fmk: tf - input_number: 1 - - hiai_transformer_encoder: - fmk: tf - input_number: 15 - - inception_resnet_v2: - fmk: tf - input_number: 1 - benchmark_shapes: 1,299,299,3 - - inception_v3: - fmk: tf - input_number: 1 - benchmark_shapes: 1,299,299,3 - - inception_v4: - fmk: tf - input_number: 1 - benchmark_shapes: 1,299,299,3 - - matmul: - fmk: tf - input_number: 1 - - ml_ei_headpose: - fmk: tf - input_number: 1 - benchmark_shapes: 1,64,64,3 - - ml_ei_landmark: - fmk: tf - input_number: 1 - benchmark_shapes: 1,160,160,3 - - ml_face_openclose: - fmk: tf - input_number: 1 - benchmark_shapes: 1,32,32,3 - - ml_female_model_step6_noiseout: - fmk: tf - input_number: 66 - - ml_male_model_step6_noiseout: - fmk: tf - input_number: 66 - - ml_noya_tts_melgan: - fmk: tf - input_number: 1 - benchmark_shapes: 16,16,80 - - ml_object_detect: - fmk: tf - input_number: 1 - benchmark_shapes: 1,288,288,3 - - ml_ocr_jk: - fmk: tf - input_number: 1 - - ml_ocr_latin: - fmk: tf - input_number: 1 - - ml_tts_vocoder: - fmk: tf - input_number: 66 - - ml_video_edit_enhance: - fmk: tf - input_number: 1 - - ml_video_edit_generate_filter: - fmk: tf - input_number: 1 - - ml_video_edit_img_segment_adaptise: - fmk: tf - input_number: 2 - - ml_video_edit_oneclick_adaptis: - fmk: tf - input_number: 3 - - ml_video_edit_shot_selection_opticalFlow: - fmk: tf - input_number: 1 - - ml_video_edit_video_segment_gauss_adaptis_part2: - fmk: tf - input_number: 2 - - ml_vision_guide_detection1: - fmk: tf - input_number: 1 - - ml_vision_guide_detection2: - fmk: tf - input_number: 1 - benchmark_shapes: 1,320,320,1 - - ml_vision_guide_detection3: - fmk: tf - input_number: 1 - - mnasnet_1.0_224: - fmk: tf - input_number: 1 - - mnasnet_1.3_224: - fmk: tf - input_number: 1 - - mobilenet_v1_0.25_128_frozen: - fmk: tf - input_number: 1 - benchmark_shapes: 1,128,128,3 - - mobilenet_v2_1.0_224_frozen: - fmk: tf - input_number: 1 - benchmark_shapes: 1,224,224,3 - - model_normalize_object_scene_ps_20200519: - fmk: tf - input_number: 1 - benchmark_shapes: 1,224,224,3 - - mtk_AADB_HADB_MBV2_model: - fmk: tf - input_number: 1 - benchmark_shapes: 1,224,224,3 - - mtk_AADB_HADB_MBV3_model: - fmk: tf - input_number: 1 - benchmark_shapes: 1,224,224,3 - - mtk_age_gender: - fmk: tf - input_number: 1 - - mtk_face_features_v1: - fmk: tf - input_number: 1 - - mtk_model_ckpt: - fmk: tf - input_number: 1 - - mtk_model_face_dress: - fmk: tf - input_number: 1 - benchmark_shapes: 1,128,128,3 - - mtk_model_normalize_object_scene_ps_20200519: - fmk: tf - input_number: 1 - benchmark_shapes: 1,224,224,3 - - nasnet_large: - fmk: tf - input_number: 1 - benchmark_shapes: 1,331,331,3 - - nasnet_mobile: - fmk: tf - input_number: 1 - benchmark_shapes: 1,224,224,3 - - Q_crnn_ori_75w_slim_norm: - fmk: tf - input_number: 1 - - Q_crnn_ori_v2_405001_notrans_nopre: - fmk: tf - input_number: 1 - - Q_crnn_screen_slim400w_more_20w: - fmk: tf - input_number: 1 - - Q_dila-small-mix-full-fineturn-390000-nopixel-nosigmoid: - fmk: tf - input_number: 1 - - Q_hand_0812: - fmk: tf - input_number: 1 - - Q_inception-249970-672-11-16: - fmk: tf - input_number: 1 - - scan_hms_angle: - fmk: tf - input_number: 1 - - scan_hms_detect: - fmk: tf - input_number: 1 - - siteAI_trans_nonlinear: - fmk: tf - input_number: 1 - benchmark_shapes: 1,137 - - siteAI_trans_nonlinear134g: - fmk: tf - input_number: 1 - benchmark_shapes: 1,137 - - siteAI_trans_nonlinear134g_nrz: - fmk: tf - input_number: 1 - benchmark_shapes: 1,182 - - siteAI_trans_nonlinear40g: - fmk: tf - input_number: 1 - benchmark_shapes: 1,271 - - siteAI_wireless_depress_w: - fmk: tf - input_number: 1 - benchmark_shapes: 1,36 - - siteAI_wireless_restore_w: - fmk: tf - input_number: 1 - benchmark_shapes: 1,36 - - squeezenet: - fmk: tf - input_number: 1 - benchmark_shapes: 1,224,224,3 - - tensor_dot: - fmk: tf - input_number: 1 - benchmark_shapes: 1,217 - - tt_raw_h4800_mel80_ms_fe001_ex_20210506_joint_decoder: - fmk: tf - input_number: 14 - benchmark_shapes: 4:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:1,640 - - unet_model_reconstruct: - fmk: tf - input_number: 1 - benchmark_shapes: 1,256,256,3 - ---- - -caffe: - 2012_ATLANTA_10class_20190131_v4.0: - fmk: caffe - input_number: 1 - - 2012_ATLANTA_1class_20190621_v4.x_nomean: - fmk: caffe - input_number: 1 - - 6c_seg_nomean_20200610: - fmk: caffe - input_number: 1 - - age_new: - fmk: caffe - input_number: 1 - - bank_card_detection_inception_tmp: - fmk: caffe - input_number: 1 - - bank_card_recognition_fcny: - fmk: caffe - input_number: 1 - - bolt_deploy_color-server: - fmk: caffe - input_number: 1 - - deconv_test_model: - fmk: caffe - input_number: 1 - - deconvs_model: - fmk: caffe - input_number: 1 - - detect-deeper-halfdeeper-mbv1-shortcut-400-400_nopostprocess_simplified: - fmk: caffe - input_number: 1 - - detection_retinaface_fix: - fmk: caffe - input_number: 1 - - emotion: - fmk: caffe - input_number: 1 - - gender_res_large_deploy: - fmk: caffe - input_number: 1 - - hdc_age_medium: - fmk: caffe - input_number: 1 - - hdc_contour_pose_128: - fmk: caffe - input_number: 1 - - hdc_Face_Aesthetic_MTI_Aesthetic: - fmk: caffe - input_number: 1 - - hdc_fivembnet: - fmk: caffe - input_number: 1 - - hdc_mobilenetface: - fmk: caffe - input_number: 1 - - hdc_ocr_recog_horizontal: - fmk: caffe - input_number: 1 - - hdc_resnet: - fmk: caffe - input_number: 1 - - hdc_retinaface: - fmk: caffe - input_number: 1 - - hiai_cpu_face_attr: - fmk: caffe - input_number: 1 - - hiai_cpu_face_detect: - fmk: caffe - input_number: 1 - - hiai_cpu_face_hat: - fmk: caffe - input_number: 1 - - hiai_cv_aestheticsEngineModel_osp: - fmk: caffe - input_number: 1 - - hiai_cv_focusShootOCRModel_01: - fmk: caffe - input_number: 1 - - hiai_cv_focusShootOCRModel_03: - fmk: caffe - input_number: 1 - - hiai_cv_focusShootOCRModel_04: - fmk: caffe - input_number: 1 - - hiai_cv_focusShootOCRModel_06: - fmk: caffe - input_number: 1 - - hiai_cv_focusShootOCRModel_07: - fmk: caffe - input_number: 1 - - hiai_face_attr1: - fmk: caffe - input_number: 1 - - hiai_face_detect_rfb: - fmk: caffe - input_number: 1 - - hiai_face_landmark: - fmk: caffe - input_number: 1 - - hiai_face_pose_tuku: - fmk: caffe - input_number: 1 - - hiai_face_recognition_1: - fmk: caffe - input_number: 1 - - hiai_face_RFB-Epoch-170-no-transpose: - fmk: caffe - input_number: 1 - - hiai_human_seg: - fmk: caffe - input_number: 1 - - hiai_machine_vision_jfr_newmodel_2730_houduan_yolo: - fmk: caffe - input_number: 1 - - hiai_machine_vision_mobileNet101_nosoftce_mobilenet_resnet: - fmk: caffe - input_number: 1 - - hiai_semantic_seg: - fmk: caffe - input_number: 1 - - hiai_video_seg: - fmk: caffe - input_number: 1 - - HWSR-s_256_256: - fmk: caffe - input_number: 1 - - identify_card_detect_tmp: - fmk: caffe - input_number: 1 - - ml_2012_ocr_detection_caffe_tmp: - fmk: caffe - input_number: 1 - - ml_2012_ocr_rec_caffe: - fmk: caffe - input_number: 1 - - ml_ARengine23_bodypose: - fmk: caffe - input_number: 1 - - ml_bank_detect_0312_tmp: - fmk: caffe - input_number: 1 - - ml_bank_recog: - fmk: caffe - input_number: 1 - - ml_bodymask: - fmk: caffe - input_number: 1 - - ml_face_age: - fmk: caffe - input_number: 1 - - ml_face_beard: - fmk: caffe - input_number: 1 - - ml_face_compare: - fmk: caffe - input_number: 1 - - ml_face_contour: - fmk: caffe - input_number: 1 - - ml_face_div_parsing: - fmk: caffe - input_number: 1 - - ml_face_emotion: - fmk: caffe - input_number: 1 - - ml_face_glasses: - fmk: caffe - input_number: 1 - - ml_face_hat: - fmk: caffe - input_number: 1 - - ml_face_isface: - fmk: caffe - input_number: 1 - - ml_face_landmark: - fmk: caffe - input_number: 1 - - ml_face_mnet: - fmk: caffe - input_number: 1 - - ml_face_pose: - fmk: caffe - input_number: 1 - - ml_face_sex: - fmk: caffe - input_number: 1 - - ml_face_tracking: - fmk: caffe - input_number: 1 - - ml_hand_3d_detection: - fmk: caffe - input_number: 1 - - ml_hand_3d_regression: - fmk: caffe - input_number: 1 - - ml_Hand_deploy: - fmk: caffe - input_number: 1 - - ml_hand_detection: - fmk: caffe - input_number: 1 - - ml_handpose: - fmk: caffe - input_number: 1 - - ml_hardware_eyeclose: - fmk: caffe - input_number: 1 - - ml_hardware_liveness: - fmk: caffe - input_number: 1 - - ml_hardware_pose: - fmk: caffe - input_number: 1 - - ml_Heatmap_depth_180240: - fmk: caffe - input_number: 2 - - ml_Heatmap_depth_240180: - fmk: caffe - input_number: 2 - - ml_lable_model_hebing_device: - fmk: caffe - input_number: 1 - - ml_liveness_detect_landmark_tmp: - fmk: caffe - input_number: 1 - - ml_location_scene_division: - fmk: caffe - input_number: 1 - - ml_ocr_bank_card_detection_inception_tmp: - fmk: caffe - input_number: 1 - - ml_ocr_bank_card_recognition_fcny: - fmk: caffe - input_number: 1 - - ml_ocr_cn: - fmk: caffe - input_number: 1 - - ml_ocr_detect_20200305: - fmk: caffe - input_number: 1 - - ml_ocr_identify_card_detect_tmp: - fmk: caffe - input_number: 1 - - ml_ocr_identify_card_fcny: - fmk: caffe - input_number: 1 - - ml_ocr_sfz_add_final_0325: - fmk: caffe - input_number: 1 - - ml_ocr_sfz_detect_0325_tmp: - fmk: caffe - input_number: 1 - - ml_segmentation_atlanta_1: - fmk: caffe - input_number: 1 - - ml_segmentation_atlanta_10: - fmk: caffe - input_number: 1 - - ml_segmentation_matting: - fmk: caffe - input_number: 1 - - ml_tabel_recog: - fmk: caffe - input_number: 1 - - ml_text_division: - fmk: caffe - input_number: 1 - - ml_video_edit_detect_20211111: - fmk: caffe - input_number: 1 - - ml_video_edit_dynamic_effect_MTI_seg5c_v1: - fmk: caffe - input_number: 1 - - ml_video_edit_hair_dyeing_segmodel_20211119: - fmk: caffe - input_number: 1 - - ml_video_edit_hair_dyeing_segmodel_v3: - fmk: caffe - input_number: 1 - - ml_video_edit_hairSeg_have_imageProcessLayer_interpTo145: - fmk: caffe - input_number: 1 - - ml_video_edit_hairSeg_have_imageProcessLayer_interpTo145_20210121: - fmk: caffe - input_number: 1 - - ml_video_edit_have_imageProcessLayer_interpTo145_20201015: - fmk: caffe - input_number: 1 - - ml_video_edit_img_segment: - fmk: caffe - input_number: 1 - - ml_video_edit_Mnet: - fmk: caffe - input_number: 1 - - ml_video_edit_MnetN367_extract_1010_pay: - fmk: caffe - input_number: 1 - - ml_video_edit_moon_mode_moon_seg: - fmk: caffe - input_number: 1 - - ml_video_edit_moon_mode_MTI_9c_segmentation_v12: - fmk: caffe - input_number: 1 - - ml_video_edit_person_divison_pic: - fmk: caffe - input_number: 1 - - ml_video_edit_person_divison_video: - fmk: caffe - input_number: 2 - - ml_video_edit_reid: - fmk: caffe - input_number: 1 - - ml_video_edit_seg_320: - fmk: caffe - input_number: 1 - - ml_video_edit_v10_best_model_nomean_20200723: - fmk: caffe - input_number: 1 - - ml_video_edit_video_segment_gauss_adaptis_part1: - fmk: caffe - input_number: 1 - - mnet: - fmk: caffe - input_number: 1 - - Mnet6_0312_extract_pay: - fmk: caffe - input_number: 1 - - model_hebing_3branch: - fmk: caffe - input_number: 1 - - mtk_2012_ATLANTA_10class_20190614_v41: - fmk: caffe - input_number: 1 - - mtk_detect_mbv1_640_480_nopostprocess_simplified: - fmk: caffe - input_number: 1 - - mtk_detect-deeper-halfdeeper-mbv1-lastearlySSD-shortcut-400-400_nopostprocess_simplified: - fmk: caffe - input_number: 1 - - mtk_detect-mbv1-shortcut-400-400_nopostprocess_simplified: - fmk: caffe - input_number: 1 - - mtk_face_recognition_v1: - fmk: caffe - input_number: 1 - - plat_isface: - fmk: caffe - input_number: 1 - - pose_3d: - fmk: caffe - input_number: 1 - - PoseNet_dla_17_x512_tmp: - fmk: caffe - input_number: 1 - - recognition: - fmk: caffe - input_number: 1 - - retinaface: - fmk: caffe - input_number: 1 - - Sport_Health_Tech_pose_iter: - fmk: caffe - input_number: 1 - ---- - -onnx: - 01-face_det_400_400: - fmk: onnx - input_number: 1 - input_suffix: .onnx.bin - output_suffix: .onnx.out - benchmark_shapes: 1,400,400,3 - acc_threshold: 4.5 - - adversarial_pruning: - fmk: onnx - input_number: 1 - - bloom_hongmo_detection_tmp: - fmk: onnx - input_number: 1 - - candy-9: - fmk: onnx - input_number: 1 - - carbu_intelligent_cockpit_fasttext_best: - fmk: onnx - input_number: 1 - - CloudBU_FSRCNN_RTC_8ch_3450_QP9: - fmk: onnx - input_number: 1 - benchmark_shapes: 1,225,225,3 - - CloudBU_rfdn_rtc_x2_ver2_13: - fmk: onnx - input_number: 1 - benchmark_shapes: 1,225,225,3 - - CloudBU_rfdn_rtc_x2_ver2_3450: - fmk: onnx - input_number: 1 - benchmark_shapes: 1,225,225,3 - - crnn_lite_lstm_v2: - fmk: onnx - input_number: 1 - benchmark_shapes: 32,32,32,1 - - densenet-9: - fmk: onnx - input_number: 1 - - efficientnet-lite4-11: - fmk: onnx - input_number: 1 - - emotion-ferplus-8: - fmk: onnx - input_number: 1 - - encoder: - fmk: onnx - input_number: 1 - benchmark_shapes: 1,32,83 - - gender_lstm_scd: - fmk: onnx - input_number: 1 - - gender_lstm_vad: - fmk: onnx - input_number: 1 - - gender_resnet34_lzl: - fmk: onnx - input_number: 1 - - googlenet-9: - fmk: onnx - input_number: 1 - - gts_text_detection: - fmk: onnx - input_number: 1 - benchmark_shapes: 1,224,224,3 - - gts_version-RFB-320_simplified: - fmk: onnx - input_number: 1 - - Harmony_Voiceprint: - fmk: onnx - input_number: 1 - benchmark_shapes: 1,200,40,1 - - Harmony_Voiceprint_resnet18: - fmk: onnx - input_number: 1 - benchmark_shapes: 1,150,40,1 - - hdc_efficientnet_b3_1w_class: - fmk: onnx - input_number: 1 - - hdc_Face_Emotion_MTI_Aesthetic: - fmk: onnx - input_number: 1 - - hdc_Face_Landmark5_MTI_Aesthetic: - fmk: onnx - input_number: 1 - - hdc_Image_Aesthetic_MTI_Aesthetic: - fmk: onnx - input_number: 1 - - hdc_mobilenet_1w_class: - fmk: onnx - input_number: 1 - - hdc_ocr_attention: - fmk: onnx - input_number: 1 - - hdc_ocr_detect_tmp: - fmk: onnx - input_number: 1 - - hdc_resnet_1w_class: - fmk: onnx - input_number: 1 - - Huawei_video_rvm_mobilenetv3_192: - fmk: onnx - input_number: 6 - - inception-v1-9: - fmk: onnx - input_number: 1 - - inception-v2-9: - fmk: onnx - input_number: 1 - - Ireland_face_detector: - fmk: onnx - input_number: 1 - - Ireland_gaze_corrector: - fmk: onnx - input_number: 3 - acc_threshold: 1 - - Ireland_gaze_estimator_ng: - fmk: onnx - input_number: 1 - - Ireland_ulfgf: - fmk: onnx - input_number: 1 - benchmark_shapes: 1,240,320,3 - - ml_2012_ocr_cn: - fmk: onnx - input_number: 1 - - ml_2012_ocr_cn_noLSTM: - fmk: onnx - input_number: 1 - - ml_2012_ocr_detection_tmp: - fmk: onnx - input_number: 1 - - ml_asr_encoder_int8_202103: - fmk: onnx - input_number: 1 - - ml_audio_edit_rhythm_check_model: - fmk: onnx - input_number: 1 - benchmark_shapes: 1,1024,81,1 - - ml_audio_kit_vocals_test: - fmk: onnx - input_number: 1 - benchmark_shapes: 1,512,1024,2 - acc_threshold: 2 - - ml_edu_kit_hand_detection: - fmk: onnx - input_number: 1 - - ml_edu_kit_hand_key_position: - fmk: onnx - input_number: 1 - - ml_ei_facedetection: - fmk: onnx - input_number: 1 - - ml_face_3d: - fmk: onnx - input_number: 1 - - ml_facedetector: - fmk: onnx - input_number: 1 - - ml_location_lane_counter: - fmk: onnx - input_number: 1 - - ml_location_lane_counter0: - fmk: onnx - input_number: 1 - - ml_motion_capture_nanodet_m_0.5x_people_0928_sim: - fmk: onnx - input_number: 1 - - ml_motion_capture_smpl_0916: - fmk: onnx - input_number: 3 - - ml_motion_capture_spin_mobile_mv3_v3_57mm_sim: - fmk: onnx - input_number: 5 - - ml_table_detection_fp32_tmp: - fmk: onnx - input_number: 1 - - ml_table_segment: - fmk: onnx - input_number: 1 - - ml_video_edit_art_generate: - fmk: onnx - input_number: 1 - - ml_video_edit_art_generate_20210513: - fmk: onnx - input_number: 1 - - ml_video_edit_art_transfer_20210513: - fmk: onnx - input_number: 3 - - ml_video_edit_dimming_tech_model_345000_color: - fmk: onnx - input_number: 2 - - ml_video_edit_dimming_tech_model_studio_20: - fmk: onnx - input_number: 2 - - ml_video_edit_dimming_tech_model_styleGan: - fmk: onnx - input_number: 2 - - ml_video_edit_enhance_update_tmp: - fmk: onnx - input_number: 1 - - ml_video_edit_face_edit_face3d: - fmk: onnx - input_number: 1 - - ml_video_edit_face_edit_pix2pixHD_unet: - fmk: onnx - input_number: 1 - - ml_video_edit_face_edit_retinaface: - fmk: onnx - input_number: 1 - benchmark_shapes: 1,120,128,3 - - ml_video_edit_hair_dyeing_migrate_v2: - fmk: onnx - input_number: 4 - - ml_video_edit_hair_dyeing_migrate_v2_fix: - fmk: onnx - input_number: 4 - - - ml_video_edit_judge: - fmk: onnx - input_number: 1 - - ml_video_edit_makeup_mobilenetv203: - fmk: onnx - input_number: 1 - - ml_video_edit_moon_mode_sky_refine: - fmk: onnx - input_number: 2 - benchmark_shapes: 1,256,256,4:1,88,88,4 - - ml_video_edit_shot_selection_face_emotion: - fmk: onnx - input_number: 1 - - ml_video_edit_shot_selection_yolox_nano_coco_reduced: - fmk: onnx - input_number: 1 - - ml_video_edit_style_transfer_autoportrait: - fmk: onnx - input_number: 1 - - ml_video_edit_style_transfer_candy: - fmk: onnx - input_number: 1 - - ml_video_edit_style_transfer_gongnongbing: - fmk: onnx - input_number: 1 - - ml_video_edit_style_transfer_starry: - fmk: onnx - input_number: 1 - - ml_video_edit_styleCode_part1: - fmk: onnx - input_number: 1 - - ml_video_edit_styleCode_part2: - fmk: onnx - input_number: 9 - - ml_video_edit_vignet: - fmk: onnx - input_number: 1 - - ml_voice_detect: - fmk: onnx - input_number: 1 - - mnist-8: - fmk: onnx - input_number: 1 - - mobilenetv2-7: - fmk: onnx - input_number: 1 - - mosaic-9: - fmk: onnx - input_number: 1 - - mtk_detect_mbv1_640_480: - fmk: onnx - input_number: 1 - - mtk_detect_mbv1_640_480_nopostprocess_simplified_onnx: - fmk: onnx - input_number: 1 - benchmark_shapes: 1,480,640,3 - - mtk_detect-deeper-halfdeeper-mbv1-lastearlySSD-shortcut-400-400_nopostprocess_simplified_onnx: - fmk: onnx - input_number: 1 - - mtk_detect-deeper-halfdeeper-mbv1-shortcut-400-400_nopostprocess_simplified_onnx: - fmk: onnx - input_number: 1 - - mtk_detect-mbv1-shortcut-400-400: - fmk: onnx - input_number: 1 - - mtk_detect-mbv1-shortcut-400-400_nopostprocess_simplified_onnx: - fmk: onnx - input_number: 1 - - mtk_detect-mbv2-shortcut-400-400: - fmk: onnx - input_number: 1 - - mtk_detect-mbv2-shortcut-400-400-simplified: - fmk: onnx - input_number: 1 - - mtk_emotions-d2012-75: - fmk: onnx - input_number: 1 - - mtk_face_features_v2: - fmk: onnx - input_number: 1 - benchmark_shapes: 1,256,192,3 - - mtk_face_features_v3: - fmk: onnx - input_number: 1 - - mtk_face_recognition_v2: - fmk: onnx - input_number: 1 - - mtk_face_recognition_v3: - fmk: onnx - input_number: 1 - - pointilism-9: - fmk: onnx - input_number: 1 - - porseg_tmp: - fmk: onnx - input_number: 2 - - psenet_lite_mbv2: - fmk: onnx - input_number: 1 - benchmark_shapes: 1,32,32,3 - - Q888_CV_face_recognition_self: - fmk: onnx - input_number: 1 - - Q888_face_recognition: - fmk: onnx - input_number: 1 - - Q888_iris_detect: - fmk: onnx - input_number: 1 - - rain-princess-9: - fmk: onnx - input_number: 1 - - rcnn-ilsvrc13-9: - fmk: onnx - input_number: 1 - - residual_distill_bs_1: - fmk: onnx - input_number: 1 - - residual_distill_bs_32: - fmk: onnx - input_number: 1 - - residual_distill_cifar10_bs_1: - fmk: onnx - input_number: 1 - - residual_distill_cifar10_bs_32: - fmk: onnx - input_number: 1 - - residual_distill_res34_cifar10_bs_1_update: - fmk: onnx - input_number: 1 - - residual_distill_res50_cifar10_bs_1_update: - fmk: onnx - input_number: 1 - - rpnt_pdr_conv2d_16_fixed_last: - fmk: onnx - input_number: 1 - - rvm_mobilenetv3_192: - fmk: onnx - input_number: 6 - - shufflenet-9: - fmk: onnx - input_number: 1 - - shufflenet-v2-10: - fmk: onnx - input_number: 1 - - simple_IPS_model_4D_input: - fmk: onnx - input_number: 1 - - squeezenet1.0-9: - fmk: onnx - input_number: 1 - - squeezenet1.1-7: - fmk: onnx - input_number: 1 - - ssd-10: - fmk: onnx - input_number: 1 - - super-resolution-10: - fmk: onnx - input_number: 1 - benchmark_shapes: 1,224,224,1 - - tinyyolov2-8: - fmk: onnx - input_number: 1 - benchmark_shapes: 1,416,416,3 - - udnie-9: - fmk: onnx - input_number: 1 - - yolov5s: - fmk: onnx - input_number: 1 - ---- - -tflite: - albert_lite_base_squadv1_1: - fmk: tflite - input_number: 3 - - bloom_model_age_gender: - fmk: tflite - input_number: 1 - - bloom_new_detect: - fmk: tflite - input_number: 1 - - deeplabv3_1_default_1: - fmk: tflite - input_number: 1 - - deeplabv3_257_mv_gpu: - fmk: tflite - input_number: 1 - - densenet: - fmk: tflite - input_number: 1 - - efficientnet_lite0_fp32_2: - fmk: tflite - input_number: 1 - - gts_detect_5k_tf115: - fmk: tflite - input_number: 1 - - hdc_tb_cn_neg: - fmk: tflite - input_number: 3 - acc_threshold: 0.5 - - hiai_AADB_HADB_MBV2_model_fp32: - fmk: tflite - input_number: 1 - - hiai_asr_last_e1_cpu_fast_wavenet_batch1_frame1_one_cache_fp32: - fmk: tflite - input_number: 2 - - - hiai_bigmodel_ghost_2_1_no_normalized_no_trans_tflite: - fmk: tflite - input_number: 1 - - hiai_bigmodel_ghost_5_1_no_normalized_no_trans_tflite: - fmk: tflite - input_number: 1 - - hiai_chinese_english_recognize_model_float32: - fmk: tflite - input_number: 1 - - hiai_cn_recognize_modify_padv2: - fmk: tflite - input_number: 1 - - hiai_cpu_face_emotion: - fmk: tflite - input_number: 1 - - hiai_cpu_face_gazing: - fmk: tflite - input_number: 1 - - hiai_cpu_face_headpose: - fmk: tflite - input_number: 1 - - hiai_ctpn_feature_map: - fmk: tflite - input_number: 1 - - hiai_cv_focusShootOCRModel_02: - fmk: tflite - input_number: 1 - - hiai_cv_focusShootOCRModel_08: - fmk: tflite - input_number: 1 - - hiai_cv_labelDetectorModel_v2: - fmk: tflite - input_number: 1 - - hiai_cv_labelDetectorModel_v3: - fmk: tflite - input_number: 2 - - hiai_cv_labelDetectorModel_v4: - fmk: tflite - input_number: 1 - - hiai_cv_poseEstimation: - fmk: tflite - input_number: 1 - - hiai_cv_saliencyDetectorModel: - fmk: tflite - input_number: 1 - - hiai_detect_curve_model_float32: - fmk: tflite - input_number: 1 - - hiai_detectmodel_06_23_960_480_1180700: - fmk: tflite - input_number: 1 - - hiai_detectmodel_desnet_256_128_64_32: - fmk: tflite - input_number: 1 - - hiai_dress_detect: - fmk: tflite - input_number: 1 - - hiai_face_model_npu: - fmk: tflite - input_number: 1 - - hiai_frozen_inference_graph: - fmk: tflite - input_number: 1 - - hiai_ghostnet: - fmk: tflite - input_number: 1 - - hiai_humanDetection: - fmk: tflite - input_number: 1 - - hiai_iMaxDN_RGB: - fmk: tflite - input_number: 1 - - hiai_iMaxSR_RGB: - fmk: tflite - input_number: 1 - - hiai_label_and_video: - fmk: tflite - input_number: 1 - - hiai_latin_ocr: - fmk: tflite - input_number: 1 - - hiai_latin_ocr_1: - fmk: tflite - input_number: 1 - - hiai_lm_inference_graph: - fmk: tflite - input_number: 1 - - hiai_model_0909_kd_rot_ps_softmax: - fmk: tflite - input_number: 1 - - hiai_model_normalize_object_scene_ps_20200519: - fmk: tflite - input_number: 1 - - hiai_object_detect_814: - fmk: tflite - input_number: 1 - - hiai_PoseEstimation_Pcm: - fmk: tflite - input_number: 1 - - hiai_ssd_mobilenetv2_object: - fmk: tflite - input_number: 1 - - hiai_vad: - fmk: tflite - input_number: 2 - - ide_label_base: - fmk: tflite - input_number: 1 - - ide_label_retrained: - fmk: tflite - input_number: 1 - - inception_resnet_v2: - fmk: tflite - input_number: 1 - - inception_v3: - fmk: tflite - input_number: 1 - - inception_v4: - fmk: tflite - input_number: 1 - - lite-model_albert_lite_base_squadv1_metadata_1: - fmk: tflite - input_number: 3 - - lite-model_mobilebert_1_metadata_1: - fmk: tflite - input_number: 3 - - lite-model_on_device_vision_classifier_popular_us_products_V1_1: - fmk: tflite - input_number: 1 - - lite-model_on_device_vision_classifier_popular_wine_V1_1: - fmk: tflite - input_number: 1 - - lma_tsec_shallow_channels16_ds2.1.1_model-best-f1: - fmk: tflite - input_number: 1 - - mindspore_text_classification_tflite: - fmk: tflite - input_number: 1 - - ml_ei_headpose: - fmk: tflite - input_number: 1 - - ml_ei_headpose_pb2tflite: - fmk: tflite - input_number: 3 - benchmark_shapes: 1,64,64,3:16:16 - - ml_ei_landmark: - fmk: tflite - input_number: 1 - - ml_ei_landmark_pb2tflite: - fmk: tflite - input_number: 1 - - ml_face_openclose: - fmk: tflite - input_number: 1 - - ml_face_openclose_tflite: - fmk: tflite - input_number: 1 - - ml_headpose_pb2tflite: - fmk: tflite - input_number: 3 - benchmark_shapes: 1,64,64,3:16:16 - - ml_location: - fmk: tflite - input_number: 1 - - ml_object_detect: - fmk: tflite - input_number: 1 - - ml_object_detect_1: - fmk: tflite - input_number: 1 - - ml_object_detect_pb2tflite: - fmk: tflite - input_number: 1 - - ml_ocr_jk: - fmk: tflite - input_number: 1 - - ml_ocr_jk_pb2tflite: - fmk: tflite - input_number: 1 - - ml_ocr_latin: - fmk: tflite - input_number: 1 - - ml_ocr_latin_pb2tflite: - fmk: tflite - input_number: 1 - - ml_pic_shopping: - fmk: tflite - input_number: 1 - - ml_pic_shopping_pb2tflite: - fmk: tflite - input_number: 1 - - ml_tacotron_decoder_step_stf: - fmk: tflite - input_number: 9 - benchmark_shapes: 1,80:1,256:1,1024:1,1024:1,1024:1,1024:1,8:1,1,256:1 - - ml_text_correction: - fmk: tflite - input_number: 1 - - ml_video_edit_img_segment_adaptise_pb2tflite: - fmk: tflite - input_number: 2 - - ml_video_edit_video_segment_gauss_adaptis_part2_pb2tflite: - fmk: tflite - input_number: 2 - - ml_vision_guide_detection1_pb2tflite: - fmk: tflite - input_number: 1 - - ml_vision_guide_detection3_pb2tflite: - fmk: tflite - input_number: 1 - - mnasnet_0.50_224_1_metadata_1: - fmk: tflite - input_number: 1 - - mnasnet_1.3_224: - fmk: tflite - input_number: 1 - - mnist: - fmk: tflite - input_number: 1 - - mobilebert_1_default_1: - fmk: tflite - input_number: 3 - - mobilenet: - fmk: tflite - input_number: 1 - - mobilenet_v1_0.25_128: - fmk: tflite - input_number: 1 - - mobilenet_v2_1.0_224: - fmk: tflite - input_number: 1 - - model_emotions_0727_nosoftmax: - fmk: tflite - input_number: 1 - - mtk_276landmark_0913: - fmk: tflite - input_number: 1 - - mtk_AADB_HADB_MBV2_model_fp32: - fmk: tflite - input_number: 1 - - mtk_AADB_HADB_MBV3_model_fp32: - fmk: tflite - input_number: 1 - - mtk_age_gender: - fmk: tflite - input_number: 1 - - mtk_convert_model: - fmk: tflite - input_number: 1 - - mtk_face_features_v1: - fmk: tflite - input_number: 1 - - mtk_face_recognition: - fmk: tflite - input_number: 1 - - mtk_model_ckpt: - fmk: tflite - input_number: 1 - - mtk_model_emotions_0727_nosoftmax: - fmk: tflite - input_number: 1 - - mtk_model_face_dress: - fmk: tflite - input_number: 1 - - mtk_model_normalize_object_scene_ps_20200519_f32: - fmk: tflite - input_number: 1 - - mtk_model_normalize_object_scene_ps_20200826_f32_no_softmax: - fmk: tflite - input_number: 1 - - mtk_new_detect: - fmk: tflite - input_number: 1 - - multi_person_mobilenet_v1_075_float: - fmk: tflite - input_number: 1 - - nasnet_large: - fmk: tflite - input_number: 1 - - nasnet_mobile: - fmk: tflite - input_number: 1 - - posenet_mobilenet_float_075_1_default_1: - fmk: tflite - input_number: 1 - - Q_AADB_HADB_MBV2_model: - fmk: tflite - input_number: 1 - - Q_convert: - fmk: tflite - input_number: 1 - - Q_crnn_ori_75w_slim_norm_pb2tflite: - fmk: tflite - input_number: 1 - - Q_crnn_ori_v2_405001_notrans_nopre_pb2tflite: - fmk: tflite - input_number: 1 - - Q_crnn_screen_slim400w_more_20w_pb2tflite: - fmk: tflite - input_number: 1 - - Q_detect_fpn_add_inception-1448650: - fmk: tflite - input_number: 1 - - Q_dila-small-mix-full-fineturn-390000-nopixel-nosigmoid_tflite: - fmk: tflite - input_number: 1 - - Q_focusocr_cn_recog: - fmk: tflite - input_number: 1 - - Q_focusocr_jk_recog: - fmk: tflite - input_number: 1 - - Q_hand_0812_pb2tflite: - fmk: tflite - input_number: 1 - - Q_iMaxDN_RGB_385_p_RGB_RGB_pb2tflite: - fmk: tflite - input_number: 1 - - Q_iMaxSR_RGB_385_p_pb2tflite: - fmk: tflite - input_number: 1 - - Q_inception-249970-672-11-16_pb2tflite: - fmk: tflite - input_number: 1 - - Q_language_model_hrmini_Q4_b4_17w: - fmk: tflite - input_number: 1 - - Q_object_scene: - fmk: tflite - input_number: 1 - - Q888_age_gender_orderd: - fmk: tflite - input_number: 1 - - Q888_face_dress_mv3y: - fmk: tflite - input_number: 1 - - Q888_face_emo_dress_mv3_orderd: - fmk: tflite - input_number: 1 - - Q888_HADB_AADB_MBV2_model_fp32: - fmk: tflite - input_number: 1 - - Q888_isface: - fmk: tflite - input_number: 1 - - Q888_landmark: - fmk: tflite - input_number: 1 - - Q888_lapa158_unet_0924: - fmk: tflite - input_number: 1 - - Q888_model_normalize_object_scene_ps_20200826_f32_no_softmax: - fmk: tflite - input_number: 1 - - Q888_new_detect: - fmk: tflite - input_number: 1 - - Q888_pose: - fmk: tflite - input_number: 1 - - resnet: - fmk: tflite - input_number: 1 - - resnet_v2_101_299: - fmk: tflite - input_number: 1 - - scan_hms_angle_pb2tflite: - fmk: tflite - input_number: 1 - - scan_hms_angle1: - fmk: tflite - input_number: 1 - - scan_hms_detect: - fmk: tflite - input_number: 1 - - scan_hms_detect_pb2tflite: - fmk: tflite - input_number: 1 - - siteAI_digcom_AI_ECN: - fmk: tflite - input_number: 1 - - siteAI_digcom_g2v_keras: - fmk: tflite - input_number: 1 - - siteAI_trans_nonlinear: - fmk: tflite - input_number: 1 - - siteAI_trans_tcpclassify: - fmk: tflite - input_number: 1 - - siteAI_wireless_depress_w: - fmk: tflite - input_number: 1 - - siteAI_wireless_restore_w: - fmk: tflite - input_number: 1 - - squeezenet: - fmk: tflite - input_number: 1 - - text_classification: - fmk: tflite - input_number: 1 - - unet_mbv2_05_104pts: - fmk: tflite - input_number: 1 +mindir: + deepfm_criteo_bs_16000_Ascend: + fmk: mindir + input_number: 2 + + efficientnetb0_imagenet2012_bs1: + fmk: mindir + input_number: 1 + acc_threshold: 2 + + efficientnetb1_imagenet2012_bs1: + fmk: mindir + input_number: 1 + acc_threshold: 3 + + efficientnetb2_imagenet2012_bs1: + fmk: mindir + input_number: 1 + acc_threshold: 4 + + efficientnetb3_imagenet2012_bs1: + fmk: mindir + input_number: 1 + acc_threshold: 3 + + inceptionv3_ascend: + fmk: mindir + input_number: 1 + + inceptionV4: + fmk: mindir + input_number: 1 + + mobilenetv3large_imagenet2012_bs1: + fmk: mindir + input_number: 1 + + mobilenetv3small_imagenet2012_bs1: + fmk: mindir + input_number: 1 + + pix2pix_facades_bs1: + fmk: mindir + input_number: 1 + + resnet101_imagenet_bs_1_Ascend: + fmk: mindir + input_number: 1 + acc_threshold: 2 + + resnet101_imagenet_bs_1_GPU: + fmk: mindir + input_number: 1 + acc_threshold: 2 + + resnet18_imagenet_bs_1_Ascend: + fmk: mindir + input_number: 1 + acc_threshold: 2 + + resnet34_ascend_v190_imagenet2012_official_cv_top1acc73.61_top5acc91.74: + fmk: mindir + input_number: 1 + acc_threshold: 3 + + resnet50_cifar10_bs_1_Ascend: + fmk: mindir + input_number: 1 + acc_threshold: 3 + + resnet50_cifar10_bs_1_GPU: + fmk: mindir + input_number: 1 + acc_threshold: 2 + + resnet50_imagenet_bs_1_Ascend: + fmk: mindir + input_number: 1 + acc_threshold: 3 + + resnet50_imagenet_bs_1_GPU: + fmk: mindir + input_number: 1 + acc_threshold: 3 + + resnet50_thor_imagenet_bs_1_ascend: + fmk: mindir + input_number: 1 + acc_threshold: 2 + + se-resnet50_imagenet_bs_1_ascend: + fmk: mindir + input_number: 1 + acc_threshold: 3 + + shufflenetv1: + fmk: mindir + input_number: 1 + + ssimae_mvtecadbottle_bs1: + fmk: mindir + input_number: 1 + + unet_bs_1_input_2: + fmk: mindir + input_number: 1 + + unet_nested_cell_bs_1_input_2: + fmk: mindir + input_number: 1 + + vgg16_cifar10_bs_64_Ascend: + fmk: mindir + input_number: 1 + + vgg16_imagenet_bs_1_Ascend: + fmk: mindir + input_number: 1 + + vgg19_cifar10_bs1: + fmk: mindir + input_number: 1 + + vgg19_imagenet2012_bs1: + fmk: mindir + input_number: 1 + acc_threshold: 1 + + vit_imagenet2012_bs1: + fmk: mindir + input_number: 1 + +--- + +tf: + browser_deepfm_v7: + fmk: tf + input_number: 2 + benchmark_shapes: 200,94:200,94 + acc_threshold: 0.5 + + browser_deepfm_v7_int64: + fmk: tf + input_number: 2 + benchmark_shapes: 200,94:200,94 + acc_threshold: 0.5 + + browser_scene1_v2: + fmk: tf + input_number: 1 + benchmark_shapes: 75,19910 + acc_threshold: 0.5 + + browser_v15: + fmk: tf + input_number: 1 + benchmark_shapes: 75,10896 + acc_threshold: 0 + + browser_v36: + fmk: tf + input_number: 2 + benchmark_shapes: 75,190:75,9120 + acc_threshold: 0.00002 + + browser_v50: + fmk: tf + input_number: 2 + benchmark_shapes: 75,276:75,276 + acc_threshold: 0.5 + + browser_v50_int64: + fmk: tf + input_number: 2 + benchmark_shapes: 75,276:75,276 + acc_threshold: 0.5 + + browser_v58: + fmk: tf + input_number: 1 + benchmark_shapes: 75,26180 + acc_threshold: 0 + + browser_v7: + fmk: tf + input_number: 1 + benchmark_shapes: 75,39160 + acc_threshold: 0 + + browser_v79: + fmk: tf + input_number: 2 + benchmark_shapes: 10,294:10,294 + acc_threshold: 0.004 + + browser_v79_int32: + fmk: tf + input_number: 2 + benchmark_shapes: 10,294:10,294 + acc_threshold: 0.004 + + bolt_segment: + fmk: tf + input_number: 1 + + decoder_step_nocumsum_v5: + fmk: tf + input_number: 13 + benchmark_shapes: 1,512:1,512:1,512:1,512:1,512:1,127,320:1,1429,2:1,127:1:1,127:1,512:1,80:1,127 + + densenet: + fmk: tf + input_number: 1 + benchmark_shapes: 1,224,224,3 + + female_model_step2_int16_noiseout: + fmk: tf + input_number: 66 + + fsr_270_mindspore: + fmk: tf + input_number: 1 + + fsr_360_mindspore: + fmk: tf + input_number: 1 + + fsr_720_mindspore: + fmk: tf + input_number: 1 + + g_00730000_female10_frames_tf1: + fmk: tf + input_number: 150 + + hiai_AADB_HADB_MBV2_model: + fmk: tf + input_number: 1 + benchmark_shapes: 1,224,224,3 + + hiai_asr_ctc: + fmk: tf + input_number: 2 + + hiai_asr_last_e1_cpu_fast_wavenet_batch1_frame1_one_cache: + fmk: tf + input_number: 2 + + + hiai_cn_recognize_modify_padv2: + fmk: tf + input_number: 1 + benchmark_shapes: 1,32,512,1 + + hiai_cpu_face_emotion: + fmk: tf + input_number: 1 + + hiai_cpu_face_gazing: + fmk: tf + input_number: 1 + + hiai_cpu_face_headpose: + fmk: tf + input_number: 1 + + hiai_ctpn_feature_map: + fmk: tf + input_number: 1 + + hiai_cv_focusShootOCRModel_02: + fmk: tf + input_number: 1 + + hiai_cv_focusShootOCRModel_08: + fmk: tf + input_number: 1 + + hiai_cv_poseEstimation: + fmk: tf + input_number: 1 + + hiai_detectmodel_06_23_960_480_1180700: + fmk: tf + input_number: 1 + + hiai_dress_detect: + fmk: tf + input_number: 1 + benchmark_shapes: 1,960,960,3 + + hiai_face_model_npu: + fmk: tf + input_number: 1 + + hiai_frozen_inference_graph: + fmk: tf + input_number: 1 + benchmark_shapes: 1,300,300,3 + + hiai_ghostnet: + fmk: tf + input_number: 1 + + hiai_humanDetection: + fmk: tf + input_number: 1 + + hiai_iMaxDN_RGB: + fmk: tf + input_number: 1 + + hiai_iMaxSR_RGB: + fmk: tf + input_number: 1 + + hiai_label_and_video: + fmk: tf + input_number: 1 + benchmark_shapes: 1,224,224,3 + + hiai_latin_ocr: + fmk: tf + input_number: 1 + + hiai_latin_ocr_1: + fmk: tf + input_number: 1 + + hiai_lm_inference_graph: + fmk: tf + input_number: 1 + + hiai_model_0909_kd_rot_ps_softmax: + fmk: tf + input_number: 1 + benchmark_shapes: 1,224,224,3 + + hiai_model_normalize_object_scene_ps_20200519: + fmk: tf + input_number: 1 + benchmark_shapes: 1,224,224,3 + + hiai_nlu_model: + fmk: tf + input_number: 3 + benchmark_shapes: 1,16:1,16:1,16 + + hiai_nlu_model_multi: + fmk: tf + input_number: 6 + benchmark_shapes: 1,32:1,32:1,32:1,74:1,11:1,6 + + hiai_nlu_model_single: + fmk: tf + input_number: 3 + benchmark_shapes: 1,32:1,32:1,32 + + hiai_nlu_model_v2: + fmk: tf + input_number: 7 + benchmark_shapes: 1,5:1,5:1,5:1,98:1,174:1,6:1,5 + + hiai_PoseEstimation_Pcm: + fmk: tf + input_number: 1 + + hiai_ssd_mobilenetv2_object: + fmk: tf + input_number: 1 + + hiai_transformer_encoder: + fmk: tf + input_number: 15 + + inception_resnet_v2: + fmk: tf + input_number: 1 + benchmark_shapes: 1,299,299,3 + + inception_v3: + fmk: tf + input_number: 1 + benchmark_shapes: 1,299,299,3 + + inception_v4: + fmk: tf + input_number: 1 + benchmark_shapes: 1,299,299,3 + + matmul: + fmk: tf + input_number: 1 + + ml_ei_headpose: + fmk: tf + input_number: 1 + benchmark_shapes: 1,64,64,3 + + ml_ei_landmark: + fmk: tf + input_number: 1 + benchmark_shapes: 1,160,160,3 + + ml_face_openclose: + fmk: tf + input_number: 1 + benchmark_shapes: 1,32,32,3 + + ml_female_model_step6_noiseout: + fmk: tf + input_number: 66 + + ml_male_model_step6_noiseout: + fmk: tf + input_number: 66 + + ml_noya_tts_melgan: + fmk: tf + input_number: 1 + benchmark_shapes: 16,16,80 + + ml_object_detect: + fmk: tf + input_number: 1 + benchmark_shapes: 1,288,288,3 + + ml_ocr_jk: + fmk: tf + input_number: 1 + + ml_ocr_latin: + fmk: tf + input_number: 1 + + ml_tts_vocoder: + fmk: tf + input_number: 66 + + ml_video_edit_enhance: + fmk: tf + input_number: 1 + + ml_video_edit_generate_filter: + fmk: tf + input_number: 1 + + ml_video_edit_img_segment_adaptise: + fmk: tf + input_number: 2 + + ml_video_edit_oneclick_adaptis: + fmk: tf + input_number: 3 + + ml_video_edit_shot_selection_opticalFlow: + fmk: tf + input_number: 1 + + ml_video_edit_video_segment_gauss_adaptis_part2: + fmk: tf + input_number: 2 + + ml_vision_guide_detection1: + fmk: tf + input_number: 1 + + ml_vision_guide_detection2: + fmk: tf + input_number: 1 + benchmark_shapes: 1,320,320,1 + + ml_vision_guide_detection3: + fmk: tf + input_number: 1 + + mnasnet_1.0_224: + fmk: tf + input_number: 1 + + mnasnet_1.3_224: + fmk: tf + input_number: 1 + + mobilenet_v1_0.25_128_frozen: + fmk: tf + input_number: 1 + benchmark_shapes: 1,128,128,3 + + mobilenet_v2_1.0_224_frozen: + fmk: tf + input_number: 1 + benchmark_shapes: 1,224,224,3 + + model_normalize_object_scene_ps_20200519: + fmk: tf + input_number: 1 + benchmark_shapes: 1,224,224,3 + + mtk_AADB_HADB_MBV2_model: + fmk: tf + input_number: 1 + benchmark_shapes: 1,224,224,3 + + mtk_AADB_HADB_MBV3_model: + fmk: tf + input_number: 1 + benchmark_shapes: 1,224,224,3 + + mtk_age_gender: + fmk: tf + input_number: 1 + + mtk_face_features_v1: + fmk: tf + input_number: 1 + + mtk_model_ckpt: + fmk: tf + input_number: 1 + + mtk_model_face_dress: + fmk: tf + input_number: 1 + benchmark_shapes: 1,128,128,3 + + mtk_model_normalize_object_scene_ps_20200519: + fmk: tf + input_number: 1 + benchmark_shapes: 1,224,224,3 + + nasnet_large: + fmk: tf + input_number: 1 + benchmark_shapes: 1,331,331,3 + + nasnet_mobile: + fmk: tf + input_number: 1 + benchmark_shapes: 1,224,224,3 + + Q_crnn_ori_75w_slim_norm: + fmk: tf + input_number: 1 + + Q_crnn_ori_v2_405001_notrans_nopre: + fmk: tf + input_number: 1 + + Q_crnn_screen_slim400w_more_20w: + fmk: tf + input_number: 1 + + Q_dila-small-mix-full-fineturn-390000-nopixel-nosigmoid: + fmk: tf + input_number: 1 + + Q_hand_0812: + fmk: tf + input_number: 1 + + Q_inception-249970-672-11-16: + fmk: tf + input_number: 1 + + scan_hms_angle: + fmk: tf + input_number: 1 + + scan_hms_detect: + fmk: tf + input_number: 1 + + siteAI_trans_nonlinear: + fmk: tf + input_number: 1 + benchmark_shapes: 1,137 + + siteAI_trans_nonlinear134g: + fmk: tf + input_number: 1 + benchmark_shapes: 1,137 + + siteAI_trans_nonlinear134g_nrz: + fmk: tf + input_number: 1 + benchmark_shapes: 1,182 + + siteAI_trans_nonlinear40g: + fmk: tf + input_number: 1 + benchmark_shapes: 1,271 + + siteAI_wireless_depress_w: + fmk: tf + input_number: 1 + benchmark_shapes: 1,36 + + siteAI_wireless_restore_w: + fmk: tf + input_number: 1 + benchmark_shapes: 1,36 + + squeezenet: + fmk: tf + input_number: 1 + benchmark_shapes: 1,224,224,3 + + tensor_dot: + fmk: tf + input_number: 1 + benchmark_shapes: 1,217 + + tt_raw_h4800_mel80_ms_fe001_ex_20210506_joint_decoder: + fmk: tf + input_number: 14 + benchmark_shapes: 4:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:1,640 + + unet_model_reconstruct: + fmk: tf + input_number: 1 + benchmark_shapes: 1,256,256,3 + +--- + +caffe: + 2012_ATLANTA_10class_20190131_v4.0: + fmk: caffe + input_number: 1 + + 2012_ATLANTA_1class_20190621_v4.x_nomean: + fmk: caffe + input_number: 1 + + 6c_seg_nomean_20200610: + fmk: caffe + input_number: 1 + + age_new: + fmk: caffe + input_number: 1 + + bank_card_detection_inception_tmp: + fmk: caffe + input_number: 1 + + bank_card_recognition_fcny: + fmk: caffe + input_number: 1 + + bolt_deploy_color-server: + fmk: caffe + input_number: 1 + + deconv_test_model: + fmk: caffe + input_number: 1 + + deconvs_model: + fmk: caffe + input_number: 1 + + detect-deeper-halfdeeper-mbv1-shortcut-400-400_nopostprocess_simplified: + fmk: caffe + input_number: 1 + + detection_retinaface_fix: + fmk: caffe + input_number: 1 + + emotion: + fmk: caffe + input_number: 1 + + gender_res_large_deploy: + fmk: caffe + input_number: 1 + + hdc_age_medium: + fmk: caffe + input_number: 1 + + hdc_contour_pose_128: + fmk: caffe + input_number: 1 + + hdc_Face_Aesthetic_MTI_Aesthetic: + fmk: caffe + input_number: 1 + + hdc_fivembnet: + fmk: caffe + input_number: 1 + + hdc_mobilenetface: + fmk: caffe + input_number: 1 + + hdc_ocr_recog_horizontal: + fmk: caffe + input_number: 1 + + hdc_resnet: + fmk: caffe + input_number: 1 + + hdc_retinaface: + fmk: caffe + input_number: 1 + + hiai_cpu_face_attr: + fmk: caffe + input_number: 1 + + hiai_cpu_face_detect: + fmk: caffe + input_number: 1 + + hiai_cpu_face_hat: + fmk: caffe + input_number: 1 + + hiai_cv_aestheticsEngineModel_osp: + fmk: caffe + input_number: 1 + + hiai_cv_focusShootOCRModel_01: + fmk: caffe + input_number: 1 + + hiai_cv_focusShootOCRModel_03: + fmk: caffe + input_number: 1 + + hiai_cv_focusShootOCRModel_04: + fmk: caffe + input_number: 1 + + hiai_cv_focusShootOCRModel_06: + fmk: caffe + input_number: 1 + + hiai_cv_focusShootOCRModel_07: + fmk: caffe + input_number: 1 + + hiai_face_attr1: + fmk: caffe + input_number: 1 + + hiai_face_detect_rfb: + fmk: caffe + input_number: 1 + + hiai_face_landmark: + fmk: caffe + input_number: 1 + + hiai_face_pose_tuku: + fmk: caffe + input_number: 1 + + hiai_face_recognition_1: + fmk: caffe + input_number: 1 + + hiai_face_RFB-Epoch-170-no-transpose: + fmk: caffe + input_number: 1 + + hiai_human_seg: + fmk: caffe + input_number: 1 + + hiai_machine_vision_jfr_newmodel_2730_houduan_yolo: + fmk: caffe + input_number: 1 + + hiai_machine_vision_mobileNet101_nosoftce_mobilenet_resnet: + fmk: caffe + input_number: 1 + + hiai_semantic_seg: + fmk: caffe + input_number: 1 + + hiai_video_seg: + fmk: caffe + input_number: 1 + + HWSR-s_256_256: + fmk: caffe + input_number: 1 + + identify_card_detect_tmp: + fmk: caffe + input_number: 1 + + ml_2012_ocr_detection_caffe_tmp: + fmk: caffe + input_number: 1 + + ml_2012_ocr_rec_caffe: + fmk: caffe + input_number: 1 + + ml_ARengine23_bodypose: + fmk: caffe + input_number: 1 + + ml_bank_detect_0312_tmp: + fmk: caffe + input_number: 1 + + ml_bank_recog: + fmk: caffe + input_number: 1 + + ml_bodymask: + fmk: caffe + input_number: 1 + + ml_face_age: + fmk: caffe + input_number: 1 + + ml_face_beard: + fmk: caffe + input_number: 1 + + ml_face_compare: + fmk: caffe + input_number: 1 + + ml_face_contour: + fmk: caffe + input_number: 1 + + ml_face_div_parsing: + fmk: caffe + input_number: 1 + + ml_face_emotion: + fmk: caffe + input_number: 1 + + ml_face_glasses: + fmk: caffe + input_number: 1 + + ml_face_hat: + fmk: caffe + input_number: 1 + + ml_face_isface: + fmk: caffe + input_number: 1 + + ml_face_landmark: + fmk: caffe + input_number: 1 + + ml_face_mnet: + fmk: caffe + input_number: 1 + + ml_face_pose: + fmk: caffe + input_number: 1 + + ml_face_sex: + fmk: caffe + input_number: 1 + + ml_face_tracking: + fmk: caffe + input_number: 1 + + ml_hand_3d_detection: + fmk: caffe + input_number: 1 + + ml_hand_3d_regression: + fmk: caffe + input_number: 1 + + ml_Hand_deploy: + fmk: caffe + input_number: 1 + + ml_hand_detection: + fmk: caffe + input_number: 1 + + ml_handpose: + fmk: caffe + input_number: 1 + + ml_hardware_eyeclose: + fmk: caffe + input_number: 1 + + ml_hardware_liveness: + fmk: caffe + input_number: 1 + + ml_hardware_pose: + fmk: caffe + input_number: 1 + + ml_Heatmap_depth_180240: + fmk: caffe + input_number: 2 + + ml_Heatmap_depth_240180: + fmk: caffe + input_number: 2 + + ml_lable_model_hebing_device: + fmk: caffe + input_number: 1 + + ml_liveness_detect_landmark_tmp: + fmk: caffe + input_number: 1 + + ml_location_scene_division: + fmk: caffe + input_number: 1 + + ml_ocr_bank_card_detection_inception_tmp: + fmk: caffe + input_number: 1 + + ml_ocr_bank_card_recognition_fcny: + fmk: caffe + input_number: 1 + + ml_ocr_cn: + fmk: caffe + input_number: 1 + + ml_ocr_detect_20200305: + fmk: caffe + input_number: 1 + + ml_ocr_identify_card_detect_tmp: + fmk: caffe + input_number: 1 + + ml_ocr_identify_card_fcny: + fmk: caffe + input_number: 1 + + ml_ocr_sfz_add_final_0325: + fmk: caffe + input_number: 1 + + ml_ocr_sfz_detect_0325_tmp: + fmk: caffe + input_number: 1 + + ml_segmentation_atlanta_1: + fmk: caffe + input_number: 1 + + ml_segmentation_atlanta_10: + fmk: caffe + input_number: 1 + + ml_segmentation_matting: + fmk: caffe + input_number: 1 + + ml_tabel_recog: + fmk: caffe + input_number: 1 + + ml_text_division: + fmk: caffe + input_number: 1 + + ml_video_edit_detect_20211111: + fmk: caffe + input_number: 1 + + ml_video_edit_dynamic_effect_MTI_seg5c_v1: + fmk: caffe + input_number: 1 + + ml_video_edit_hair_dyeing_segmodel_20211119: + fmk: caffe + input_number: 1 + + ml_video_edit_hair_dyeing_segmodel_v3: + fmk: caffe + input_number: 1 + + ml_video_edit_hairSeg_have_imageProcessLayer_interpTo145: + fmk: caffe + input_number: 1 + + ml_video_edit_hairSeg_have_imageProcessLayer_interpTo145_20210121: + fmk: caffe + input_number: 1 + + ml_video_edit_have_imageProcessLayer_interpTo145_20201015: + fmk: caffe + input_number: 1 + + ml_video_edit_img_segment: + fmk: caffe + input_number: 1 + + ml_video_edit_Mnet: + fmk: caffe + input_number: 1 + + ml_video_edit_MnetN367_extract_1010_pay: + fmk: caffe + input_number: 1 + + ml_video_edit_moon_mode_moon_seg: + fmk: caffe + input_number: 1 + + ml_video_edit_moon_mode_MTI_9c_segmentation_v12: + fmk: caffe + input_number: 1 + + ml_video_edit_person_divison_pic: + fmk: caffe + input_number: 1 + + ml_video_edit_person_divison_video: + fmk: caffe + input_number: 2 + + ml_video_edit_reid: + fmk: caffe + input_number: 1 + + ml_video_edit_seg_320: + fmk: caffe + input_number: 1 + + ml_video_edit_v10_best_model_nomean_20200723: + fmk: caffe + input_number: 1 + + ml_video_edit_video_segment_gauss_adaptis_part1: + fmk: caffe + input_number: 1 + + mnet: + fmk: caffe + input_number: 1 + + Mnet6_0312_extract_pay: + fmk: caffe + input_number: 1 + + model_hebing_3branch: + fmk: caffe + input_number: 1 + + mtk_2012_ATLANTA_10class_20190614_v41: + fmk: caffe + input_number: 1 + + mtk_detect_mbv1_640_480_nopostprocess_simplified: + fmk: caffe + input_number: 1 + + mtk_detect-deeper-halfdeeper-mbv1-lastearlySSD-shortcut-400-400_nopostprocess_simplified: + fmk: caffe + input_number: 1 + + mtk_detect-mbv1-shortcut-400-400_nopostprocess_simplified: + fmk: caffe + input_number: 1 + + mtk_face_recognition_v1: + fmk: caffe + input_number: 1 + + plat_isface: + fmk: caffe + input_number: 1 + + pose_3d: + fmk: caffe + input_number: 1 + + PoseNet_dla_17_x512_tmp: + fmk: caffe + input_number: 1 + + recognition: + fmk: caffe + input_number: 1 + + retinaface: + fmk: caffe + input_number: 1 + + Sport_Health_Tech_pose_iter: + fmk: caffe + input_number: 1 + +--- + +onnx: + 01-face_det_400_400: + fmk: onnx + input_number: 1 + input_suffix: .onnx.bin + output_suffix: .onnx.out + benchmark_shapes: 1,400,400,3 + acc_threshold: 4.5 + + adversarial_pruning: + fmk: onnx + input_number: 1 + + bloom_hongmo_detection_tmp: + fmk: onnx + input_number: 1 + + candy-9: + fmk: onnx + input_number: 1 + + carbu_intelligent_cockpit_fasttext_best: + fmk: onnx + input_number: 1 + + CloudBU_FSRCNN_RTC_8ch_3450_QP9: + fmk: onnx + input_number: 1 + benchmark_shapes: 1,225,225,3 + + CloudBU_rfdn_rtc_x2_ver2_13: + fmk: onnx + input_number: 1 + benchmark_shapes: 1,225,225,3 + + CloudBU_rfdn_rtc_x2_ver2_3450: + fmk: onnx + input_number: 1 + benchmark_shapes: 1,225,225,3 + + crnn_lite_lstm_v2: + fmk: onnx + input_number: 1 + benchmark_shapes: 32,32,32,1 + + densenet-9: + fmk: onnx + input_number: 1 + + efficientnet-lite4-11: + fmk: onnx + input_number: 1 + + emotion-ferplus-8: + fmk: onnx + input_number: 1 + + encoder: + fmk: onnx + input_number: 1 + benchmark_shapes: 1,32,83 + + gender_lstm_scd: + fmk: onnx + input_number: 1 + + gender_lstm_vad: + fmk: onnx + input_number: 1 + + gender_resnet34_lzl: + fmk: onnx + input_number: 1 + + googlenet-9: + fmk: onnx + input_number: 1 + + gts_text_detection: + fmk: onnx + input_number: 1 + benchmark_shapes: 1,224,224,3 + + gts_version-RFB-320_simplified: + fmk: onnx + input_number: 1 + + Harmony_Voiceprint: + fmk: onnx + input_number: 1 + benchmark_shapes: 1,200,40,1 + + Harmony_Voiceprint_resnet18: + fmk: onnx + input_number: 1 + benchmark_shapes: 1,150,40,1 + + hdc_efficientnet_b3_1w_class: + fmk: onnx + input_number: 1 + + hdc_Face_Emotion_MTI_Aesthetic: + fmk: onnx + input_number: 1 + + hdc_Face_Landmark5_MTI_Aesthetic: + fmk: onnx + input_number: 1 + + hdc_Image_Aesthetic_MTI_Aesthetic: + fmk: onnx + input_number: 1 + + hdc_mobilenet_1w_class: + fmk: onnx + input_number: 1 + + hdc_ocr_attention: + fmk: onnx + input_number: 1 + + hdc_ocr_detect_tmp: + fmk: onnx + input_number: 1 + + hdc_resnet_1w_class: + fmk: onnx + input_number: 1 + + Huawei_video_rvm_mobilenetv3_192: + fmk: onnx + input_number: 6 + + inception-v1-9: + fmk: onnx + input_number: 1 + + inception-v2-9: + fmk: onnx + input_number: 1 + + Ireland_face_detector: + fmk: onnx + input_number: 1 + + Ireland_gaze_corrector: + fmk: onnx + input_number: 3 + acc_threshold: 1 + + Ireland_gaze_estimator_ng: + fmk: onnx + input_number: 1 + + Ireland_ulfgf: + fmk: onnx + input_number: 1 + benchmark_shapes: 1,240,320,3 + + ml_2012_ocr_cn: + fmk: onnx + input_number: 1 + + ml_2012_ocr_cn_noLSTM: + fmk: onnx + input_number: 1 + + ml_2012_ocr_detection_tmp: + fmk: onnx + input_number: 1 + + ml_asr_encoder_int8_202103: + fmk: onnx + input_number: 1 + + ml_audio_edit_rhythm_check_model: + fmk: onnx + input_number: 1 + benchmark_shapes: 1,1024,81,1 + + ml_audio_kit_vocals_test: + fmk: onnx + input_number: 1 + benchmark_shapes: 1,512,1024,2 + acc_threshold: 2 + + ml_edu_kit_hand_detection: + fmk: onnx + input_number: 1 + + ml_edu_kit_hand_key_position: + fmk: onnx + input_number: 1 + + ml_ei_facedetection: + fmk: onnx + input_number: 1 + + ml_face_3d: + fmk: onnx + input_number: 1 + + ml_facedetector: + fmk: onnx + input_number: 1 + + ml_location_lane_counter: + fmk: onnx + input_number: 1 + + ml_location_lane_counter0: + fmk: onnx + input_number: 1 + + ml_motion_capture_nanodet_m_0.5x_people_0928_sim: + fmk: onnx + input_number: 1 + + ml_motion_capture_smpl_0916: + fmk: onnx + input_number: 3 + + ml_motion_capture_spin_mobile_mv3_v3_57mm_sim: + fmk: onnx + input_number: 5 + + ml_table_detection_fp32_tmp: + fmk: onnx + input_number: 1 + + ml_table_segment: + fmk: onnx + input_number: 1 + + ml_video_edit_art_generate: + fmk: onnx + input_number: 1 + + ml_video_edit_art_generate_20210513: + fmk: onnx + input_number: 1 + + ml_video_edit_art_transfer_20210513: + fmk: onnx + input_number: 3 + + ml_video_edit_dimming_tech_model_345000_color: + fmk: onnx + input_number: 2 + + ml_video_edit_dimming_tech_model_studio_20: + fmk: onnx + input_number: 2 + + ml_video_edit_dimming_tech_model_styleGan: + fmk: onnx + input_number: 2 + + ml_video_edit_enhance_update_tmp: + fmk: onnx + input_number: 1 + + ml_video_edit_face_edit_face3d: + fmk: onnx + input_number: 1 + + ml_video_edit_face_edit_pix2pixHD_unet: + fmk: onnx + input_number: 1 + + ml_video_edit_face_edit_retinaface: + fmk: onnx + input_number: 1 + benchmark_shapes: 1,120,128,3 + + ml_video_edit_hair_dyeing_migrate_v2: + fmk: onnx + input_number: 4 + + ml_video_edit_hair_dyeing_migrate_v2_fix: + fmk: onnx + input_number: 4 + + + ml_video_edit_judge: + fmk: onnx + input_number: 1 + + ml_video_edit_makeup_mobilenetv203: + fmk: onnx + input_number: 1 + + ml_video_edit_moon_mode_sky_refine: + fmk: onnx + input_number: 2 + benchmark_shapes: 1,256,256,4:1,88,88,4 + + ml_video_edit_shot_selection_face_emotion: + fmk: onnx + input_number: 1 + + ml_video_edit_shot_selection_yolox_nano_coco_reduced: + fmk: onnx + input_number: 1 + + ml_video_edit_style_transfer_autoportrait: + fmk: onnx + input_number: 1 + + ml_video_edit_style_transfer_candy: + fmk: onnx + input_number: 1 + + ml_video_edit_style_transfer_gongnongbing: + fmk: onnx + input_number: 1 + + ml_video_edit_style_transfer_starry: + fmk: onnx + input_number: 1 + + ml_video_edit_styleCode_part1: + fmk: onnx + input_number: 1 + + ml_video_edit_styleCode_part2: + fmk: onnx + input_number: 9 + + ml_video_edit_vignet: + fmk: onnx + input_number: 1 + + ml_voice_detect: + fmk: onnx + input_number: 1 + + mnist-8: + fmk: onnx + input_number: 1 + + mobilenetv2-7: + fmk: onnx + input_number: 1 + + mosaic-9: + fmk: onnx + input_number: 1 + + mtk_detect_mbv1_640_480: + fmk: onnx + input_number: 1 + + mtk_detect_mbv1_640_480_nopostprocess_simplified_onnx: + fmk: onnx + input_number: 1 + benchmark_shapes: 1,480,640,3 + + mtk_detect-deeper-halfdeeper-mbv1-lastearlySSD-shortcut-400-400_nopostprocess_simplified_onnx: + fmk: onnx + input_number: 1 + + mtk_detect-deeper-halfdeeper-mbv1-shortcut-400-400_nopostprocess_simplified_onnx: + fmk: onnx + input_number: 1 + + mtk_detect-mbv1-shortcut-400-400: + fmk: onnx + input_number: 1 + + mtk_detect-mbv1-shortcut-400-400_nopostprocess_simplified_onnx: + fmk: onnx + input_number: 1 + + mtk_detect-mbv2-shortcut-400-400: + fmk: onnx + input_number: 1 + + mtk_detect-mbv2-shortcut-400-400-simplified: + fmk: onnx + input_number: 1 + + mtk_emotions-d2012-75: + fmk: onnx + input_number: 1 + + mtk_face_features_v2: + fmk: onnx + input_number: 1 + benchmark_shapes: 1,256,192,3 + + mtk_face_features_v3: + fmk: onnx + input_number: 1 + + mtk_face_recognition_v2: + fmk: onnx + input_number: 1 + + mtk_face_recognition_v3: + fmk: onnx + input_number: 1 + + pointilism-9: + fmk: onnx + input_number: 1 + + porseg_tmp: + fmk: onnx + input_number: 2 + + psenet_lite_mbv2: + fmk: onnx + input_number: 1 + benchmark_shapes: 1,32,32,3 + + Q888_CV_face_recognition_self: + fmk: onnx + input_number: 1 + + Q888_face_recognition: + fmk: onnx + input_number: 1 + + Q888_iris_detect: + fmk: onnx + input_number: 1 + + rain-princess-9: + fmk: onnx + input_number: 1 + + rcnn-ilsvrc13-9: + fmk: onnx + input_number: 1 + + residual_distill_bs_1: + fmk: onnx + input_number: 1 + + residual_distill_bs_32: + fmk: onnx + input_number: 1 + + residual_distill_cifar10_bs_1: + fmk: onnx + input_number: 1 + + residual_distill_cifar10_bs_32: + fmk: onnx + input_number: 1 + + residual_distill_res34_cifar10_bs_1_update: + fmk: onnx + input_number: 1 + + residual_distill_res50_cifar10_bs_1_update: + fmk: onnx + input_number: 1 + + rpnt_pdr_conv2d_16_fixed_last: + fmk: onnx + input_number: 1 + + rvm_mobilenetv3_192: + fmk: onnx + input_number: 6 + + shufflenet-9: + fmk: onnx + input_number: 1 + + shufflenet-v2-10: + fmk: onnx + input_number: 1 + + simple_IPS_model_4D_input: + fmk: onnx + input_number: 1 + + squeezenet1.0-9: + fmk: onnx + input_number: 1 + + squeezenet1.1-7: + fmk: onnx + input_number: 1 + + ssd-10: + fmk: onnx + input_number: 1 + + super-resolution-10: + fmk: onnx + input_number: 1 + benchmark_shapes: 1,224,224,1 + + tinyyolov2-8: + fmk: onnx + input_number: 1 + benchmark_shapes: 1,416,416,3 + + udnie-9: + fmk: onnx + input_number: 1 + + yolov5s: + fmk: onnx + input_number: 1 + +--- + +tflite: + albert_lite_base_squadv1_1: + fmk: tflite + input_number: 3 + + bloom_model_age_gender: + fmk: tflite + input_number: 1 + + bloom_new_detect: + fmk: tflite + input_number: 1 + + deeplabv3_1_default_1: + fmk: tflite + input_number: 1 + + deeplabv3_257_mv_gpu: + fmk: tflite + input_number: 1 + + densenet: + fmk: tflite + input_number: 1 + + efficientnet_lite0_fp32_2: + fmk: tflite + input_number: 1 + + gts_detect_5k_tf115: + fmk: tflite + input_number: 1 + + hdc_tb_cn_neg: + fmk: tflite + input_number: 3 + acc_threshold: 0.5 + + hiai_AADB_HADB_MBV2_model_fp32: + fmk: tflite + input_number: 1 + + hiai_asr_last_e1_cpu_fast_wavenet_batch1_frame1_one_cache_fp32: + fmk: tflite + input_number: 2 + + + hiai_bigmodel_ghost_2_1_no_normalized_no_trans_tflite: + fmk: tflite + input_number: 1 + + hiai_bigmodel_ghost_5_1_no_normalized_no_trans_tflite: + fmk: tflite + input_number: 1 + + hiai_chinese_english_recognize_model_float32: + fmk: tflite + input_number: 1 + + hiai_cn_recognize_modify_padv2: + fmk: tflite + input_number: 1 + + hiai_cpu_face_emotion: + fmk: tflite + input_number: 1 + + hiai_cpu_face_gazing: + fmk: tflite + input_number: 1 + + hiai_cpu_face_headpose: + fmk: tflite + input_number: 1 + + hiai_ctpn_feature_map: + fmk: tflite + input_number: 1 + + hiai_cv_focusShootOCRModel_02: + fmk: tflite + input_number: 1 + + hiai_cv_focusShootOCRModel_08: + fmk: tflite + input_number: 1 + + hiai_cv_labelDetectorModel_v2: + fmk: tflite + input_number: 1 + + hiai_cv_labelDetectorModel_v3: + fmk: tflite + input_number: 2 + + hiai_cv_labelDetectorModel_v4: + fmk: tflite + input_number: 1 + + hiai_cv_poseEstimation: + fmk: tflite + input_number: 1 + + hiai_cv_saliencyDetectorModel: + fmk: tflite + input_number: 1 + + hiai_detect_curve_model_float32: + fmk: tflite + input_number: 1 + + hiai_detectmodel_06_23_960_480_1180700: + fmk: tflite + input_number: 1 + + hiai_detectmodel_desnet_256_128_64_32: + fmk: tflite + input_number: 1 + + hiai_dress_detect: + fmk: tflite + input_number: 1 + + hiai_face_model_npu: + fmk: tflite + input_number: 1 + + hiai_frozen_inference_graph: + fmk: tflite + input_number: 1 + + hiai_ghostnet: + fmk: tflite + input_number: 1 + + hiai_humanDetection: + fmk: tflite + input_number: 1 + + hiai_iMaxDN_RGB: + fmk: tflite + input_number: 1 + + hiai_iMaxSR_RGB: + fmk: tflite + input_number: 1 + + hiai_label_and_video: + fmk: tflite + input_number: 1 + + hiai_latin_ocr: + fmk: tflite + input_number: 1 + + hiai_latin_ocr_1: + fmk: tflite + input_number: 1 + + hiai_lm_inference_graph: + fmk: tflite + input_number: 1 + + hiai_model_0909_kd_rot_ps_softmax: + fmk: tflite + input_number: 1 + + hiai_model_normalize_object_scene_ps_20200519: + fmk: tflite + input_number: 1 + + hiai_object_detect_814: + fmk: tflite + input_number: 1 + + hiai_PoseEstimation_Pcm: + fmk: tflite + input_number: 1 + + hiai_ssd_mobilenetv2_object: + fmk: tflite + input_number: 1 + + hiai_vad: + fmk: tflite + input_number: 2 + + ide_label_base: + fmk: tflite + input_number: 1 + + ide_label_retrained: + fmk: tflite + input_number: 1 + + inception_resnet_v2: + fmk: tflite + input_number: 1 + + inception_v3: + fmk: tflite + input_number: 1 + + inception_v4: + fmk: tflite + input_number: 1 + + lite-model_albert_lite_base_squadv1_metadata_1: + fmk: tflite + input_number: 3 + + lite-model_mobilebert_1_metadata_1: + fmk: tflite + input_number: 3 + + lite-model_on_device_vision_classifier_popular_us_products_V1_1: + fmk: tflite + input_number: 1 + + lite-model_on_device_vision_classifier_popular_wine_V1_1: + fmk: tflite + input_number: 1 + + lma_tsec_shallow_channels16_ds2.1.1_model-best-f1: + fmk: tflite + input_number: 1 + + mindspore_text_classification_tflite: + fmk: tflite + input_number: 1 + + ml_ei_headpose: + fmk: tflite + input_number: 1 + + ml_ei_headpose_pb2tflite: + fmk: tflite + input_number: 3 + benchmark_shapes: 1,64,64,3:16:16 + + ml_ei_landmark: + fmk: tflite + input_number: 1 + + ml_ei_landmark_pb2tflite: + fmk: tflite + input_number: 1 + + ml_face_openclose: + fmk: tflite + input_number: 1 + + ml_face_openclose_tflite: + fmk: tflite + input_number: 1 + + ml_headpose_pb2tflite: + fmk: tflite + input_number: 3 + benchmark_shapes: 1,64,64,3:16:16 + + ml_location: + fmk: tflite + input_number: 1 + + ml_object_detect: + fmk: tflite + input_number: 1 + + ml_object_detect_1: + fmk: tflite + input_number: 1 + + ml_object_detect_pb2tflite: + fmk: tflite + input_number: 1 + + ml_ocr_jk: + fmk: tflite + input_number: 1 + + ml_ocr_jk_pb2tflite: + fmk: tflite + input_number: 1 + + ml_ocr_latin: + fmk: tflite + input_number: 1 + + ml_ocr_latin_pb2tflite: + fmk: tflite + input_number: 1 + + ml_pic_shopping: + fmk: tflite + input_number: 1 + + ml_pic_shopping_pb2tflite: + fmk: tflite + input_number: 1 + + ml_tacotron_decoder_step_stf: + fmk: tflite + input_number: 9 + benchmark_shapes: 1,80:1,256:1,1024:1,1024:1,1024:1,1024:1,8:1,1,256:1 + + ml_text_correction: + fmk: tflite + input_number: 1 + + ml_video_edit_img_segment_adaptise_pb2tflite: + fmk: tflite + input_number: 2 + + ml_video_edit_video_segment_gauss_adaptis_part2_pb2tflite: + fmk: tflite + input_number: 2 + + ml_vision_guide_detection1_pb2tflite: + fmk: tflite + input_number: 1 + + ml_vision_guide_detection3_pb2tflite: + fmk: tflite + input_number: 1 + + mnasnet_0.50_224_1_metadata_1: + fmk: tflite + input_number: 1 + + mnasnet_1.3_224: + fmk: tflite + input_number: 1 + + mnist: + fmk: tflite + input_number: 1 + + mobilebert_1_default_1: + fmk: tflite + input_number: 3 + + mobilenet: + fmk: tflite + input_number: 1 + + mobilenet_v1_0.25_128: + fmk: tflite + input_number: 1 + + mobilenet_v2_1.0_224: + fmk: tflite + input_number: 1 + + model_emotions_0727_nosoftmax: + fmk: tflite + input_number: 1 + + mtk_276landmark_0913: + fmk: tflite + input_number: 1 + + mtk_AADB_HADB_MBV2_model_fp32: + fmk: tflite + input_number: 1 + + mtk_AADB_HADB_MBV3_model_fp32: + fmk: tflite + input_number: 1 + + mtk_age_gender: + fmk: tflite + input_number: 1 + + mtk_convert_model: + fmk: tflite + input_number: 1 + + mtk_face_features_v1: + fmk: tflite + input_number: 1 + + mtk_face_recognition: + fmk: tflite + input_number: 1 + + mtk_model_ckpt: + fmk: tflite + input_number: 1 + + mtk_model_emotions_0727_nosoftmax: + fmk: tflite + input_number: 1 + + mtk_model_face_dress: + fmk: tflite + input_number: 1 + + mtk_model_normalize_object_scene_ps_20200519_f32: + fmk: tflite + input_number: 1 + + mtk_model_normalize_object_scene_ps_20200826_f32_no_softmax: + fmk: tflite + input_number: 1 + + mtk_new_detect: + fmk: tflite + input_number: 1 + + multi_person_mobilenet_v1_075_float: + fmk: tflite + input_number: 1 + + nasnet_large: + fmk: tflite + input_number: 1 + + nasnet_mobile: + fmk: tflite + input_number: 1 + + posenet_mobilenet_float_075_1_default_1: + fmk: tflite + input_number: 1 + + Q_AADB_HADB_MBV2_model: + fmk: tflite + input_number: 1 + + Q_convert: + fmk: tflite + input_number: 1 + + Q_crnn_ori_75w_slim_norm_pb2tflite: + fmk: tflite + input_number: 1 + + Q_crnn_ori_v2_405001_notrans_nopre_pb2tflite: + fmk: tflite + input_number: 1 + + Q_crnn_screen_slim400w_more_20w_pb2tflite: + fmk: tflite + input_number: 1 + + Q_detect_fpn_add_inception-1448650: + fmk: tflite + input_number: 1 + + Q_dila-small-mix-full-fineturn-390000-nopixel-nosigmoid_tflite: + fmk: tflite + input_number: 1 + + Q_focusocr_cn_recog: + fmk: tflite + input_number: 1 + + Q_focusocr_jk_recog: + fmk: tflite + input_number: 1 + + Q_hand_0812_pb2tflite: + fmk: tflite + input_number: 1 + + Q_iMaxDN_RGB_385_p_RGB_RGB_pb2tflite: + fmk: tflite + input_number: 1 + + Q_iMaxSR_RGB_385_p_pb2tflite: + fmk: tflite + input_number: 1 + + Q_inception-249970-672-11-16_pb2tflite: + fmk: tflite + input_number: 1 + + Q_language_model_hrmini_Q4_b4_17w: + fmk: tflite + input_number: 1 + + Q_object_scene: + fmk: tflite + input_number: 1 + + Q888_age_gender_orderd: + fmk: tflite + input_number: 1 + + Q888_face_dress_mv3y: + fmk: tflite + input_number: 1 + + Q888_face_emo_dress_mv3_orderd: + fmk: tflite + input_number: 1 + + Q888_HADB_AADB_MBV2_model_fp32: + fmk: tflite + input_number: 1 + + Q888_isface: + fmk: tflite + input_number: 1 + + Q888_landmark: + fmk: tflite + input_number: 1 + + Q888_lapa158_unet_0924: + fmk: tflite + input_number: 1 + + Q888_model_normalize_object_scene_ps_20200826_f32_no_softmax: + fmk: tflite + input_number: 1 + + Q888_new_detect: + fmk: tflite + input_number: 1 + + Q888_pose: + fmk: tflite + input_number: 1 + + resnet: + fmk: tflite + input_number: 1 + + resnet_v2_101_299: + fmk: tflite + input_number: 1 + + scan_hms_angle_pb2tflite: + fmk: tflite + input_number: 1 + + scan_hms_angle1: + fmk: tflite + input_number: 1 + + scan_hms_detect: + fmk: tflite + input_number: 1 + + scan_hms_detect_pb2tflite: + fmk: tflite + input_number: 1 + + siteAI_digcom_AI_ECN: + fmk: tflite + input_number: 1 + + siteAI_digcom_g2v_keras: + fmk: tflite + input_number: 1 + + siteAI_trans_nonlinear: + fmk: tflite + input_number: 1 + + siteAI_trans_tcpclassify: + fmk: tflite + input_number: 1 + + siteAI_wireless_depress_w: + fmk: tflite + input_number: 1 + + siteAI_wireless_restore_w: + fmk: tflite + input_number: 1 + + squeezenet: + fmk: tflite + input_number: 1 + + text_classification: + fmk: tflite + input_number: 1 + + unet_mbv2_05_104pts: + fmk: tflite + input_number: 1 diff --git a/mindspore/lite/test/st/scripts/experimental/config/models_loop.yaml b/mindspore/lite/test/st/scripts/experimental/config/models_loop.yaml index 6413b2fedcd..6f1b22a1f83 100644 --- a/mindspore/lite/test/st/scripts/experimental/config/models_loop.yaml +++ b/mindspore/lite/test/st/scripts/experimental/config/models_loop.yaml @@ -1,1104 +1,1104 @@ -mindir: - deepfm_criteo_bs_16000_Ascend: - fmk: mindir - input_number: 2 - - efficientnetb0_imagenet2012_bs1: - fmk: mindir - input_number: 1 - acc_threshold: 2 - - efficientnetb1_imagenet2012_bs1: - fmk: mindir - input_number: 1 - acc_threshold: 3 - - efficientnetb2_imagenet2012_bs1: - fmk: mindir - input_number: 1 - acc_threshold: 4 - - efficientnetb3_imagenet2012_bs1: - fmk: mindir - input_number: 1 - acc_threshold: 3 - - inceptionv3_ascend: - fmk: mindir - input_number: 1 - - inceptionV4: - fmk: mindir - input_number: 1 - - mobilenetv3large_imagenet2012_bs1: - fmk: mindir - input_number: 1 - - mobilenetv3small_imagenet2012_bs1: - fmk: mindir - input_number: 1 - - pix2pix_facades_bs1: - fmk: mindir - input_number: 1 - - resnet101_imagenet_bs_1_Ascend: - fmk: mindir - input_number: 1 - acc_threshold: 2 - - resnet101_imagenet_bs_1_GPU: - fmk: mindir - input_number: 1 - acc_threshold: 2 - - resnet18_imagenet_bs_1_Ascend: - fmk: mindir - input_number: 1 - acc_threshold: 2 - - resnet34_ascend_v190_imagenet2012_official_cv_top1acc73.61_top5acc91.74: - fmk: mindir - input_number: 1 - acc_threshold: 3 - - resnet50_cifar10_bs_1_Ascend: - fmk: mindir - input_number: 1 - acc_threshold: 3 - - resnet50_cifar10_bs_1_GPU: - fmk: mindir - input_number: 1 - acc_threshold: 2 - - resnet50_imagenet_bs_1_Ascend: - fmk: mindir - input_number: 1 - acc_threshold: 3 - - resnet50_imagenet_bs_1_GPU: - fmk: mindir - input_number: 1 - acc_threshold: 3 - - resnet50_thor_imagenet_bs_1_ascend: - fmk: mindir - input_number: 1 - acc_threshold: 2 - - se-resnet50_imagenet_bs_1_ascend: - fmk: mindir - input_number: 1 - acc_threshold: 3 - - shufflenetv1: - fmk: mindir - input_number: 1 - - ssimae_mvtecadbottle_bs1: - fmk: mindir - input_number: 1 - - unet_bs_1_input_2: - fmk: mindir - input_number: 1 - - unet_nested_cell_bs_1_input_2: - fmk: mindir - input_number: 1 - - vgg16_cifar10_bs_64_Ascend: - fmk: mindir - input_number: 1 - - vgg16_imagenet_bs_1_Ascend: - fmk: mindir - input_number: 1 - - vgg19_cifar10_bs1: - fmk: mindir - input_number: 1 - - vgg19_imagenet2012_bs1: - fmk: mindir - input_number: 1 - acc_threshold: 1 - - vit_imagenet2012_bs1: - fmk: mindir - input_number: 1 - ---- - -tf: - browser_deepfm_v7: - fmk: tf - input_number: 2 - benchmark_shapes: 200,94:200,94 - acc_threshold: 0.5 - - browser_deepfm_v7_int64: - fmk: tf - input_number: 2 - benchmark_shapes: 200,94:200,94 - acc_threshold: 0.5 - - browser_v36: - fmk: tf - input_number: 2 - benchmark_shapes: 75,190:75,9120 - acc_threshold: 0.00002 - - browser_v79: - fmk: tf - input_number: 2 - benchmark_shapes: 10,294:10,294 - acc_threshold: 0.004 - - browser_v79_int32: - fmk: tf - input_number: 2 - benchmark_shapes: 10,294:10,294 - acc_threshold: 0.004 - - hiai_asr_last_e1_cpu_fast_wavenet_batch1_frame1_one_cache: - fmk: tf - input_number: 2 - - hiai_dress_detect: - fmk: tf - input_number: 1 - benchmark_shapes: 1,960,960,3 - - hiai_face_model_npu: - fmk: tf - input_number: 1 - - hiai_frozen_inference_graph: - fmk: tf - input_number: 1 - benchmark_shapes: 1,300,300,3 - - hiai_label_and_video: - fmk: tf - input_number: 1 - benchmark_shapes: 1,224,224,3 - - hiai_lm_inference_graph: - fmk: tf - input_number: 1 - - hiai_model_0909_kd_rot_ps_softmax: - fmk: tf - input_number: 1 - benchmark_shapes: 1,224,224,3 - - hiai_ssd_mobilenetv2_object: - fmk: tf - input_number: 1 - - hiai_transformer_encoder: - fmk: tf - input_number: 15 - - ml_ocr_jk: - fmk: tf - input_number: 1 - - ml_video_edit_img_segment_adaptise: - fmk: tf - input_number: 2 - - ml_video_edit_oneclick_adaptis: - fmk: tf - input_number: 3 - - ml_video_edit_video_segment_gauss_adaptis_part2: - fmk: tf - input_number: 2 - - mtk_age_gender: - fmk: tf - input_number: 1 - - squeezenet: - fmk: tf - input_number: 1 - benchmark_shapes: 1,224,224,3 - - tt_raw_h4800_mel80_ms_fe001_ex_20210506_joint_decoder: - fmk: tf - input_number: 14 - benchmark_shapes: 4:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:1,640 - ---- - -caffe: - 2012_ATLANTA_10class_20190131_v4.0: - fmk: caffe - input_number: 1 - - 6c_seg_nomean_20200610: - fmk: caffe - input_number: 1 - - age_new: - fmk: caffe - input_number: 1 - - bank_card_detection_inception_tmp: - fmk: caffe - input_number: 1 - - bolt_deploy_color-server: - fmk: caffe - input_number: 1 - - deconv_test_model: - fmk: caffe - input_number: 1 - - deconvs_model: - fmk: caffe - input_number: 1 - - detect-deeper-halfdeeper-mbv1-shortcut-400-400_nopostprocess_simplified: - fmk: caffe - input_number: 1 - - detection_retinaface_fix: - fmk: caffe - input_number: 1 - - emotion: - fmk: caffe - input_number: 1 - - gender_res_large_deploy: - fmk: caffe - input_number: 1 - - hdc_age_medium: - fmk: caffe - input_number: 1 - - hdc_Face_Aesthetic_MTI_Aesthetic: - fmk: caffe - input_number: 1 - - hdc_ocr_recog_horizontal: - fmk: caffe - input_number: 1 - - hdc_retinaface: - fmk: caffe - input_number: 1 - - hiai_cpu_face_attr: - fmk: caffe - input_number: 1 - - hiai_cpu_face_detect: - fmk: caffe - input_number: 1 - - hiai_cv_aestheticsEngineModel_osp: - fmk: caffe - input_number: 1 - - hiai_cv_focusShootOCRModel_01: - fmk: caffe - input_number: 1 - - hiai_cv_focusShootOCRModel_03: - fmk: caffe - input_number: 1 - - hiai_cv_focusShootOCRModel_07: - fmk: caffe - input_number: 1 - - hiai_face_attr1: - fmk: caffe - input_number: 1 - - hiai_face_detect_rfb: - fmk: caffe - input_number: 1 - - hiai_face_pose_tuku: - fmk: caffe - input_number: 1 - - hiai_face_recognition_1: - fmk: caffe - input_number: 1 - - hiai_face_RFB-Epoch-170-no-transpose: - fmk: caffe - input_number: 1 - - hiai_human_seg: - fmk: caffe - input_number: 1 - - hiai_machine_vision_jfr_newmodel_2730_houduan_yolo: - fmk: caffe - input_number: 1 - - hiai_machine_vision_mobileNet101_nosoftce_mobilenet_resnet: - fmk: caffe - input_number: 1 - - hiai_semantic_seg: - fmk: caffe - input_number: 1 - - identify_card_detect_tmp: - fmk: caffe - input_number: 1 - - ml_2012_ocr_detection_caffe_tmp: - fmk: caffe - input_number: 1 - - ml_2012_ocr_rec_caffe: - fmk: caffe - input_number: 1 - - ml_ARengine23_bodypose: - fmk: caffe - input_number: 1 - - ml_bank_detect_0312_tmp: - fmk: caffe - input_number: 1 - - ml_bank_recog: - fmk: caffe - input_number: 1 - - ml_bodymask: - fmk: caffe - input_number: 1 - - ml_face_age: - fmk: caffe - input_number: 1 - - ml_face_beard: - fmk: caffe - input_number: 1 - - ml_face_div_parsing: - fmk: caffe - input_number: 1 - - ml_face_glasses: - fmk: caffe - input_number: 1 - - ml_face_isface: - fmk: caffe - input_number: 1 - - ml_face_mnet: - fmk: caffe - input_number: 1 - - ml_face_pose: - fmk: caffe - input_number: 1 - - ml_face_sex: - fmk: caffe - input_number: 1 - - ml_face_tracking: - fmk: caffe - input_number: 1 - - ml_hand_3d_detection: - fmk: caffe - input_number: 1 - - ml_Hand_deploy: - fmk: caffe - input_number: 1 - - ml_hand_detection: - fmk: caffe - input_number: 1 - - ml_handpose: - fmk: caffe - input_number: 1 - - ml_hardware_liveness: - fmk: caffe - input_number: 1 - - ml_hardware_pose: - fmk: caffe - input_number: 1 - - ml_Heatmap_depth_180240: - fmk: caffe - input_number: 2 - - ml_Heatmap_depth_240180: - fmk: caffe - input_number: 2 - - ml_lable_model_hebing_device: - fmk: caffe - input_number: 1 - - ml_location_scene_division: - fmk: caffe - input_number: 1 - - ml_ocr_bank_card_detection_inception_tmp: - fmk: caffe - input_number: 1 - - ml_ocr_bank_card_recognition_fcny: - fmk: caffe - input_number: 1 - - ml_ocr_detect_20200305: - fmk: caffe - input_number: 1 - - ml_ocr_identify_card_detect_tmp: - fmk: caffe - input_number: 1 - - ml_ocr_identify_card_fcny: - fmk: caffe - input_number: 1 - - ml_ocr_sfz_add_final_0325: - fmk: caffe - input_number: 1 - - ml_ocr_sfz_detect_0325_tmp: - fmk: caffe - input_number: 1 - - ml_segmentation_atlanta_1: - fmk: caffe - input_number: 1 - - ml_segmentation_atlanta_10: - fmk: caffe - input_number: 1 - - ml_segmentation_matting: - fmk: caffe - input_number: 1 - - ml_tabel_recog: - fmk: caffe - input_number: 1 - - ml_video_edit_detect_20211111: - fmk: caffe - input_number: 1 - - ml_video_edit_dynamic_effect_MTI_seg5c_v1: - fmk: caffe - input_number: 1 - - ml_video_edit_hairSeg_have_imageProcessLayer_interpTo145_20210121: - fmk: caffe - input_number: 1 - - ml_video_edit_have_imageProcessLayer_interpTo145_20201015: - fmk: caffe - input_number: 1 - - ml_video_edit_img_segment: - fmk: caffe - input_number: 1 - - ml_video_edit_Mnet: - fmk: caffe - input_number: 1 - - ml_video_edit_MnetN367_extract_1010_pay: - fmk: caffe - input_number: 1 - - ml_video_edit_moon_mode_moon_seg: - fmk: caffe - input_number: 1 - - ml_video_edit_moon_mode_MTI_9c_segmentation_v12: - fmk: caffe - input_number: 1 - - ml_video_edit_person_divison_video: - fmk: caffe - input_number: 2 - - ml_video_edit_reid: - fmk: caffe - input_number: 1 - - ml_video_edit_seg_320: - fmk: caffe - input_number: 1 - - ml_video_edit_v10_best_model_nomean_20200723: - fmk: caffe - input_number: 1 - - ml_video_edit_video_segment_gauss_adaptis_part1: - fmk: caffe - input_number: 1 - - mnet: - fmk: caffe - input_number: 1 - - Mnet6_0312_extract_pay: - fmk: caffe - input_number: 1 - - model_hebing_3branch: - fmk: caffe - input_number: 1 - - mtk_2012_ATLANTA_10class_20190614_v41: - fmk: caffe - input_number: 1 - - mtk_detect_mbv1_640_480_nopostprocess_simplified: - fmk: caffe - input_number: 1 - - mtk_detect-deeper-halfdeeper-mbv1-lastearlySSD-shortcut-400-400_nopostprocess_simplified: - fmk: caffe - input_number: 1 - - mtk_detect-mbv1-shortcut-400-400_nopostprocess_simplified: - fmk: caffe - input_number: 1 - - mtk_face_recognition_v1: - fmk: caffe - input_number: 1 - - plat_isface: - fmk: caffe - input_number: 1 - - pose_3d: - fmk: caffe - input_number: 1 - - PoseNet_dla_17_x512_tmp: - fmk: caffe - input_number: 1 - - recognition: - fmk: caffe - input_number: 1 - - retinaface: - fmk: caffe - input_number: 1 - - Sport_Health_Tech_pose_iter: - fmk: caffe - input_number: 1 - ---- - -onnx: - adversarial_pruning: - fmk: onnx - input_number: 1 - - bloom_hongmo_detection_tmp: - fmk: onnx - input_number: 1 - - candy-9: - fmk: onnx - input_number: 1 - - carbu_intelligent_cockpit_fasttext_best: - fmk: onnx - input_number: 1 - - densenet-9: - fmk: onnx - input_number: 1 - - googlenet-9: - fmk: onnx - input_number: 1 - - gts_version-RFB-320_simplified: - fmk: onnx - input_number: 1 - - hdc_efficientnet_b3_1w_class: - fmk: onnx - input_number: 1 - - hdc_Face_Emotion_MTI_Aesthetic: - fmk: onnx - input_number: 1 - - hdc_Image_Aesthetic_MTI_Aesthetic: - fmk: onnx - input_number: 1 - - hdc_mobilenet_1w_class: - fmk: onnx - input_number: 1 - - hdc_ocr_detect_tmp: - fmk: onnx - input_number: 1 - - hdc_resnet_1w_class: - fmk: onnx - input_number: 1 - - inception-v1-9: - fmk: onnx - input_number: 1 - - inception-v2-9: - fmk: onnx - input_number: 1 - - Ireland_face_detector: - fmk: onnx - input_number: 1 - - Ireland_gaze_corrector: - fmk: onnx - input_number: 3 - acc_threshold: 1 - - Ireland_gaze_estimator_ng: - fmk: onnx - input_number: 1 - - ml_2012_ocr_detection_tmp: - fmk: onnx - input_number: 1 - - ml_edu_kit_hand_detection: - fmk: onnx - input_number: 1 - - ml_edu_kit_hand_key_position: - fmk: onnx - input_number: 1 - - ml_ei_facedetection: - fmk: onnx - input_number: 1 - - ml_face_3d: - fmk: onnx - input_number: 1 - - ml_facedetector: - fmk: onnx - input_number: 1 - - ml_location_lane_counter: - fmk: onnx - input_number: 1 - - ml_location_lane_counter0: - fmk: onnx - input_number: 1 - - ml_motion_capture_nanodet_m_0.5x_people_0928_sim: - fmk: onnx - input_number: 1 - - ml_motion_capture_smpl_0916: - fmk: onnx - input_number: 3 - - ml_motion_capture_spin_mobile_mv3_v3_57mm_sim: - fmk: onnx - input_number: 5 - - ml_table_detection_fp32_tmp: - fmk: onnx - input_number: 1 - - ml_video_edit_art_generate: - fmk: onnx - input_number: 1 - - ml_video_edit_art_generate_20210513: - fmk: onnx - input_number: 1 - - ml_video_edit_art_transfer_20210513: - fmk: onnx - input_number: 3 - - ml_video_edit_dimming_tech_model_345000_color: - fmk: onnx - input_number: 2 - - ml_video_edit_dimming_tech_model_studio_20: - fmk: onnx - input_number: 2 - - ml_video_edit_enhance_update_tmp: - fmk: onnx - input_number: 1 - - ml_video_edit_face_edit_face3d: - fmk: onnx - input_number: 1 - - ml_video_edit_face_edit_pix2pixHD_unet: - fmk: onnx - input_number: 1 - - ml_video_edit_hair_dyeing_migrate_v2: - fmk: onnx - input_number: 4 - - ml_video_edit_hair_dyeing_migrate_v2_fix: - fmk: onnx - input_number: 4 - - ml_video_edit_judge: - fmk: onnx - input_number: 1 - - ml_video_edit_makeup_mobilenetv203: - fmk: onnx - input_number: 1 - - ml_video_edit_shot_selection_face_emotion: - fmk: onnx - input_number: 1 - - ml_video_edit_shot_selection_yolox_nano_coco_reduced: - fmk: onnx - input_number: 1 - - ml_video_edit_style_transfer_autoportrait: - fmk: onnx - input_number: 1 - - ml_video_edit_style_transfer_candy: - fmk: onnx - input_number: 1 - - ml_video_edit_style_transfer_gongnongbing: - fmk: onnx - input_number: 1 - - ml_video_edit_style_transfer_starry: - fmk: onnx - input_number: 1 - - ml_video_edit_styleCode_part1: - fmk: onnx - input_number: 1 - - ml_video_edit_styleCode_part2: - fmk: onnx - input_number: 9 - - ml_video_edit_vignet: - fmk: onnx - input_number: 1 - - mobilenetv2-7: - fmk: onnx - input_number: 1 - - mosaic-9: - fmk: onnx - input_number: 1 - - mtk_detect_mbv1_640_480: - fmk: onnx - input_number: 1 - - mtk_detect-deeper-halfdeeper-mbv1-lastearlySSD-shortcut-400-400_nopostprocess_simplified_onnx: - fmk: onnx - input_number: 1 - - mtk_detect-deeper-halfdeeper-mbv1-shortcut-400-400_nopostprocess_simplified_onnx: - fmk: onnx - input_number: 1 - - mtk_detect-mbv1-shortcut-400-400: - fmk: onnx - input_number: 1 - - mtk_detect-mbv1-shortcut-400-400_nopostprocess_simplified_onnx: - fmk: onnx - input_number: 1 - - mtk_detect-mbv2-shortcut-400-400: - fmk: onnx - input_number: 1 - - mtk_detect-mbv2-shortcut-400-400-simplified: - fmk: onnx - input_number: 1 - - mtk_emotions-d2012-75: - fmk: onnx - input_number: 1 - - mtk_face_features_v3: - fmk: onnx - input_number: 1 - - mtk_face_recognition_v2: - fmk: onnx - input_number: 1 - - mtk_face_recognition_v3: - fmk: onnx - input_number: 1 - - pointilism-9: - fmk: onnx - input_number: 1 - - porseg_tmp: - fmk: onnx - input_number: 2 - - Q888_CV_face_recognition_self: - fmk: onnx - input_number: 1 - - Q888_face_recognition: - fmk: onnx - input_number: 1 - - Q888_iris_detect: - fmk: onnx - input_number: 1 - - rain-princess-9: - fmk: onnx - input_number: 1 - - residual_distill_bs_1: - fmk: onnx - input_number: 1 - - residual_distill_bs_32: - fmk: onnx - input_number: 1 - - residual_distill_cifar10_bs_1: - fmk: onnx - input_number: 1 - - residual_distill_cifar10_bs_32: - fmk: onnx - input_number: 1 - - residual_distill_res34_cifar10_bs_1_update: - fmk: onnx - input_number: 1 - - residual_distill_res50_cifar10_bs_1_update: - fmk: onnx - input_number: 1 - - rpnt_pdr_conv2d_16_fixed_last: - fmk: onnx - input_number: 1 - - shufflenet-9: - fmk: onnx - input_number: 1 - - shufflenet-v2-10: - fmk: onnx - input_number: 1 - - simple_IPS_model_4D_input: - fmk: onnx - input_number: 1 - - squeezenet1.0-9: - fmk: onnx - input_number: 1 - - squeezenet1.1-7: - fmk: onnx - input_number: 1 - - ssd-10: - fmk: onnx - input_number: 1 - - udnie-9: - fmk: onnx - input_number: 1 - - yolov5s: - fmk: onnx - input_number: 1 - ---- - -tflite: - albert_lite_base_squadv1_1: - fmk: tflite - input_number: 3 - - gts_detect_5k_tf115: - fmk: tflite - input_number: 1 - - hdc_tb_cn_neg: - fmk: tflite - input_number: 3 - acc_threshold: 0.5 - - hiai_asr_last_e1_cpu_fast_wavenet_batch1_frame1_one_cache_fp32: - fmk: tflite - input_number: 2 - - hiai_cv_labelDetectorModel_v3: - fmk: tflite - input_number: 2 - - hiai_cv_saliencyDetectorModel: - fmk: tflite - input_number: 1 - - hiai_dress_detect: - fmk: tflite - input_number: 1 - - hiai_face_model_npu: - fmk: tflite - input_number: 1 - - hiai_frozen_inference_graph: - fmk: tflite - input_number: 1 - - hiai_label_and_video: - fmk: tflite - input_number: 1 - - hiai_lm_inference_graph: - fmk: tflite - input_number: 1 - - hiai_model_0909_kd_rot_ps_softmax: - fmk: tflite - input_number: 1 - - hiai_object_detect_814: - fmk: tflite - input_number: 1 - - hiai_ssd_mobilenetv2_object: - fmk: tflite - input_number: 1 - - hiai_vad: - fmk: tflite - input_number: 2 - - ide_label_retrained: - fmk: tflite - input_number: 1 - - lite-model_albert_lite_base_squadv1_metadata_1: - fmk: tflite - input_number: 3 - - lite-model_mobilebert_1_metadata_1: - fmk: tflite - input_number: 3 - - ml_ei_headpose_pb2tflite: - fmk: tflite - input_number: 3 - benchmark_shapes: 1,64,64,3:16:16 - - ml_headpose_pb2tflite: - fmk: tflite - input_number: 3 - benchmark_shapes: 1,64,64,3:16:16 - - ml_location: - fmk: tflite - input_number: 1 - - ml_ocr_jk: - fmk: tflite - input_number: 1 - - ml_ocr_jk_pb2tflite: - fmk: tflite - input_number: 1 - - ml_tacotron_decoder_step_stf: - fmk: tflite - input_number: 9 - benchmark_shapes: 1,80:1,256:1,1024:1,1024:1,1024:1,1024:1,8:1,1,256:1 - - ml_video_edit_img_segment_adaptise_pb2tflite: - fmk: tflite - input_number: 2 - - ml_video_edit_video_segment_gauss_adaptis_part2_pb2tflite: - fmk: tflite - input_number: 2 - - mobilebert_1_default_1: - fmk: tflite - input_number: 3 - - mtk_276landmark_0913: - fmk: tflite - input_number: 1 - - mtk_age_gender: - fmk: tflite - input_number: 1 - - Q_language_model_hrmini_Q4_b4_17w: - fmk: tflite - input_number: 1 - - Q888_lapa158_unet_0924: - fmk: tflite - input_number: 1 - - resnet_v2_101_299: - fmk: tflite - input_number: 1 - - scan_hms_detect: - fmk: tf - input_number: 1 - - unet_mbv2_05_104pts: - fmk: tflite - input_number: 1 +mindir: + deepfm_criteo_bs_16000_Ascend: + fmk: mindir + input_number: 2 + + efficientnetb0_imagenet2012_bs1: + fmk: mindir + input_number: 1 + acc_threshold: 2 + + efficientnetb1_imagenet2012_bs1: + fmk: mindir + input_number: 1 + acc_threshold: 3 + + efficientnetb2_imagenet2012_bs1: + fmk: mindir + input_number: 1 + acc_threshold: 4 + + efficientnetb3_imagenet2012_bs1: + fmk: mindir + input_number: 1 + acc_threshold: 3 + + inceptionv3_ascend: + fmk: mindir + input_number: 1 + + inceptionV4: + fmk: mindir + input_number: 1 + + mobilenetv3large_imagenet2012_bs1: + fmk: mindir + input_number: 1 + + mobilenetv3small_imagenet2012_bs1: + fmk: mindir + input_number: 1 + + pix2pix_facades_bs1: + fmk: mindir + input_number: 1 + + resnet101_imagenet_bs_1_Ascend: + fmk: mindir + input_number: 1 + acc_threshold: 2 + + resnet101_imagenet_bs_1_GPU: + fmk: mindir + input_number: 1 + acc_threshold: 2 + + resnet18_imagenet_bs_1_Ascend: + fmk: mindir + input_number: 1 + acc_threshold: 2 + + resnet34_ascend_v190_imagenet2012_official_cv_top1acc73.61_top5acc91.74: + fmk: mindir + input_number: 1 + acc_threshold: 3 + + resnet50_cifar10_bs_1_Ascend: + fmk: mindir + input_number: 1 + acc_threshold: 3 + + resnet50_cifar10_bs_1_GPU: + fmk: mindir + input_number: 1 + acc_threshold: 2 + + resnet50_imagenet_bs_1_Ascend: + fmk: mindir + input_number: 1 + acc_threshold: 3 + + resnet50_imagenet_bs_1_GPU: + fmk: mindir + input_number: 1 + acc_threshold: 3 + + resnet50_thor_imagenet_bs_1_ascend: + fmk: mindir + input_number: 1 + acc_threshold: 2 + + se-resnet50_imagenet_bs_1_ascend: + fmk: mindir + input_number: 1 + acc_threshold: 3 + + shufflenetv1: + fmk: mindir + input_number: 1 + + ssimae_mvtecadbottle_bs1: + fmk: mindir + input_number: 1 + + unet_bs_1_input_2: + fmk: mindir + input_number: 1 + + unet_nested_cell_bs_1_input_2: + fmk: mindir + input_number: 1 + + vgg16_cifar10_bs_64_Ascend: + fmk: mindir + input_number: 1 + + vgg16_imagenet_bs_1_Ascend: + fmk: mindir + input_number: 1 + + vgg19_cifar10_bs1: + fmk: mindir + input_number: 1 + + vgg19_imagenet2012_bs1: + fmk: mindir + input_number: 1 + acc_threshold: 1 + + vit_imagenet2012_bs1: + fmk: mindir + input_number: 1 + +--- + +tf: + browser_deepfm_v7: + fmk: tf + input_number: 2 + benchmark_shapes: 200,94:200,94 + acc_threshold: 0.5 + + browser_deepfm_v7_int64: + fmk: tf + input_number: 2 + benchmark_shapes: 200,94:200,94 + acc_threshold: 0.5 + + browser_v36: + fmk: tf + input_number: 2 + benchmark_shapes: 75,190:75,9120 + acc_threshold: 0.00002 + + browser_v79: + fmk: tf + input_number: 2 + benchmark_shapes: 10,294:10,294 + acc_threshold: 0.004 + + browser_v79_int32: + fmk: tf + input_number: 2 + benchmark_shapes: 10,294:10,294 + acc_threshold: 0.004 + + hiai_asr_last_e1_cpu_fast_wavenet_batch1_frame1_one_cache: + fmk: tf + input_number: 2 + + hiai_dress_detect: + fmk: tf + input_number: 1 + benchmark_shapes: 1,960,960,3 + + hiai_face_model_npu: + fmk: tf + input_number: 1 + + hiai_frozen_inference_graph: + fmk: tf + input_number: 1 + benchmark_shapes: 1,300,300,3 + + hiai_label_and_video: + fmk: tf + input_number: 1 + benchmark_shapes: 1,224,224,3 + + hiai_lm_inference_graph: + fmk: tf + input_number: 1 + + hiai_model_0909_kd_rot_ps_softmax: + fmk: tf + input_number: 1 + benchmark_shapes: 1,224,224,3 + + hiai_ssd_mobilenetv2_object: + fmk: tf + input_number: 1 + + hiai_transformer_encoder: + fmk: tf + input_number: 15 + + ml_ocr_jk: + fmk: tf + input_number: 1 + + ml_video_edit_img_segment_adaptise: + fmk: tf + input_number: 2 + + ml_video_edit_oneclick_adaptis: + fmk: tf + input_number: 3 + + ml_video_edit_video_segment_gauss_adaptis_part2: + fmk: tf + input_number: 2 + + mtk_age_gender: + fmk: tf + input_number: 1 + + squeezenet: + fmk: tf + input_number: 1 + benchmark_shapes: 1,224,224,3 + + tt_raw_h4800_mel80_ms_fe001_ex_20210506_joint_decoder: + fmk: tf + input_number: 14 + benchmark_shapes: 4:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:1,640 + +--- + +caffe: + 2012_ATLANTA_10class_20190131_v4.0: + fmk: caffe + input_number: 1 + + 6c_seg_nomean_20200610: + fmk: caffe + input_number: 1 + + age_new: + fmk: caffe + input_number: 1 + + bank_card_detection_inception_tmp: + fmk: caffe + input_number: 1 + + bolt_deploy_color-server: + fmk: caffe + input_number: 1 + + deconv_test_model: + fmk: caffe + input_number: 1 + + deconvs_model: + fmk: caffe + input_number: 1 + + detect-deeper-halfdeeper-mbv1-shortcut-400-400_nopostprocess_simplified: + fmk: caffe + input_number: 1 + + detection_retinaface_fix: + fmk: caffe + input_number: 1 + + emotion: + fmk: caffe + input_number: 1 + + gender_res_large_deploy: + fmk: caffe + input_number: 1 + + hdc_age_medium: + fmk: caffe + input_number: 1 + + hdc_Face_Aesthetic_MTI_Aesthetic: + fmk: caffe + input_number: 1 + + hdc_ocr_recog_horizontal: + fmk: caffe + input_number: 1 + + hdc_retinaface: + fmk: caffe + input_number: 1 + + hiai_cpu_face_attr: + fmk: caffe + input_number: 1 + + hiai_cpu_face_detect: + fmk: caffe + input_number: 1 + + hiai_cv_aestheticsEngineModel_osp: + fmk: caffe + input_number: 1 + + hiai_cv_focusShootOCRModel_01: + fmk: caffe + input_number: 1 + + hiai_cv_focusShootOCRModel_03: + fmk: caffe + input_number: 1 + + hiai_cv_focusShootOCRModel_07: + fmk: caffe + input_number: 1 + + hiai_face_attr1: + fmk: caffe + input_number: 1 + + hiai_face_detect_rfb: + fmk: caffe + input_number: 1 + + hiai_face_pose_tuku: + fmk: caffe + input_number: 1 + + hiai_face_recognition_1: + fmk: caffe + input_number: 1 + + hiai_face_RFB-Epoch-170-no-transpose: + fmk: caffe + input_number: 1 + + hiai_human_seg: + fmk: caffe + input_number: 1 + + hiai_machine_vision_jfr_newmodel_2730_houduan_yolo: + fmk: caffe + input_number: 1 + + hiai_machine_vision_mobileNet101_nosoftce_mobilenet_resnet: + fmk: caffe + input_number: 1 + + hiai_semantic_seg: + fmk: caffe + input_number: 1 + + identify_card_detect_tmp: + fmk: caffe + input_number: 1 + + ml_2012_ocr_detection_caffe_tmp: + fmk: caffe + input_number: 1 + + ml_2012_ocr_rec_caffe: + fmk: caffe + input_number: 1 + + ml_ARengine23_bodypose: + fmk: caffe + input_number: 1 + + ml_bank_detect_0312_tmp: + fmk: caffe + input_number: 1 + + ml_bank_recog: + fmk: caffe + input_number: 1 + + ml_bodymask: + fmk: caffe + input_number: 1 + + ml_face_age: + fmk: caffe + input_number: 1 + + ml_face_beard: + fmk: caffe + input_number: 1 + + ml_face_div_parsing: + fmk: caffe + input_number: 1 + + ml_face_glasses: + fmk: caffe + input_number: 1 + + ml_face_isface: + fmk: caffe + input_number: 1 + + ml_face_mnet: + fmk: caffe + input_number: 1 + + ml_face_pose: + fmk: caffe + input_number: 1 + + ml_face_sex: + fmk: caffe + input_number: 1 + + ml_face_tracking: + fmk: caffe + input_number: 1 + + ml_hand_3d_detection: + fmk: caffe + input_number: 1 + + ml_Hand_deploy: + fmk: caffe + input_number: 1 + + ml_hand_detection: + fmk: caffe + input_number: 1 + + ml_handpose: + fmk: caffe + input_number: 1 + + ml_hardware_liveness: + fmk: caffe + input_number: 1 + + ml_hardware_pose: + fmk: caffe + input_number: 1 + + ml_Heatmap_depth_180240: + fmk: caffe + input_number: 2 + + ml_Heatmap_depth_240180: + fmk: caffe + input_number: 2 + + ml_lable_model_hebing_device: + fmk: caffe + input_number: 1 + + ml_location_scene_division: + fmk: caffe + input_number: 1 + + ml_ocr_bank_card_detection_inception_tmp: + fmk: caffe + input_number: 1 + + ml_ocr_bank_card_recognition_fcny: + fmk: caffe + input_number: 1 + + ml_ocr_detect_20200305: + fmk: caffe + input_number: 1 + + ml_ocr_identify_card_detect_tmp: + fmk: caffe + input_number: 1 + + ml_ocr_identify_card_fcny: + fmk: caffe + input_number: 1 + + ml_ocr_sfz_add_final_0325: + fmk: caffe + input_number: 1 + + ml_ocr_sfz_detect_0325_tmp: + fmk: caffe + input_number: 1 + + ml_segmentation_atlanta_1: + fmk: caffe + input_number: 1 + + ml_segmentation_atlanta_10: + fmk: caffe + input_number: 1 + + ml_segmentation_matting: + fmk: caffe + input_number: 1 + + ml_tabel_recog: + fmk: caffe + input_number: 1 + + ml_video_edit_detect_20211111: + fmk: caffe + input_number: 1 + + ml_video_edit_dynamic_effect_MTI_seg5c_v1: + fmk: caffe + input_number: 1 + + ml_video_edit_hairSeg_have_imageProcessLayer_interpTo145_20210121: + fmk: caffe + input_number: 1 + + ml_video_edit_have_imageProcessLayer_interpTo145_20201015: + fmk: caffe + input_number: 1 + + ml_video_edit_img_segment: + fmk: caffe + input_number: 1 + + ml_video_edit_Mnet: + fmk: caffe + input_number: 1 + + ml_video_edit_MnetN367_extract_1010_pay: + fmk: caffe + input_number: 1 + + ml_video_edit_moon_mode_moon_seg: + fmk: caffe + input_number: 1 + + ml_video_edit_moon_mode_MTI_9c_segmentation_v12: + fmk: caffe + input_number: 1 + + ml_video_edit_person_divison_video: + fmk: caffe + input_number: 2 + + ml_video_edit_reid: + fmk: caffe + input_number: 1 + + ml_video_edit_seg_320: + fmk: caffe + input_number: 1 + + ml_video_edit_v10_best_model_nomean_20200723: + fmk: caffe + input_number: 1 + + ml_video_edit_video_segment_gauss_adaptis_part1: + fmk: caffe + input_number: 1 + + mnet: + fmk: caffe + input_number: 1 + + Mnet6_0312_extract_pay: + fmk: caffe + input_number: 1 + + model_hebing_3branch: + fmk: caffe + input_number: 1 + + mtk_2012_ATLANTA_10class_20190614_v41: + fmk: caffe + input_number: 1 + + mtk_detect_mbv1_640_480_nopostprocess_simplified: + fmk: caffe + input_number: 1 + + mtk_detect-deeper-halfdeeper-mbv1-lastearlySSD-shortcut-400-400_nopostprocess_simplified: + fmk: caffe + input_number: 1 + + mtk_detect-mbv1-shortcut-400-400_nopostprocess_simplified: + fmk: caffe + input_number: 1 + + mtk_face_recognition_v1: + fmk: caffe + input_number: 1 + + plat_isface: + fmk: caffe + input_number: 1 + + pose_3d: + fmk: caffe + input_number: 1 + + PoseNet_dla_17_x512_tmp: + fmk: caffe + input_number: 1 + + recognition: + fmk: caffe + input_number: 1 + + retinaface: + fmk: caffe + input_number: 1 + + Sport_Health_Tech_pose_iter: + fmk: caffe + input_number: 1 + +--- + +onnx: + adversarial_pruning: + fmk: onnx + input_number: 1 + + bloom_hongmo_detection_tmp: + fmk: onnx + input_number: 1 + + candy-9: + fmk: onnx + input_number: 1 + + carbu_intelligent_cockpit_fasttext_best: + fmk: onnx + input_number: 1 + + densenet-9: + fmk: onnx + input_number: 1 + + googlenet-9: + fmk: onnx + input_number: 1 + + gts_version-RFB-320_simplified: + fmk: onnx + input_number: 1 + + hdc_efficientnet_b3_1w_class: + fmk: onnx + input_number: 1 + + hdc_Face_Emotion_MTI_Aesthetic: + fmk: onnx + input_number: 1 + + hdc_Image_Aesthetic_MTI_Aesthetic: + fmk: onnx + input_number: 1 + + hdc_mobilenet_1w_class: + fmk: onnx + input_number: 1 + + hdc_ocr_detect_tmp: + fmk: onnx + input_number: 1 + + hdc_resnet_1w_class: + fmk: onnx + input_number: 1 + + inception-v1-9: + fmk: onnx + input_number: 1 + + inception-v2-9: + fmk: onnx + input_number: 1 + + Ireland_face_detector: + fmk: onnx + input_number: 1 + + Ireland_gaze_corrector: + fmk: onnx + input_number: 3 + acc_threshold: 1 + + Ireland_gaze_estimator_ng: + fmk: onnx + input_number: 1 + + ml_2012_ocr_detection_tmp: + fmk: onnx + input_number: 1 + + ml_edu_kit_hand_detection: + fmk: onnx + input_number: 1 + + ml_edu_kit_hand_key_position: + fmk: onnx + input_number: 1 + + ml_ei_facedetection: + fmk: onnx + input_number: 1 + + ml_face_3d: + fmk: onnx + input_number: 1 + + ml_facedetector: + fmk: onnx + input_number: 1 + + ml_location_lane_counter: + fmk: onnx + input_number: 1 + + ml_location_lane_counter0: + fmk: onnx + input_number: 1 + + ml_motion_capture_nanodet_m_0.5x_people_0928_sim: + fmk: onnx + input_number: 1 + + ml_motion_capture_smpl_0916: + fmk: onnx + input_number: 3 + + ml_motion_capture_spin_mobile_mv3_v3_57mm_sim: + fmk: onnx + input_number: 5 + + ml_table_detection_fp32_tmp: + fmk: onnx + input_number: 1 + + ml_video_edit_art_generate: + fmk: onnx + input_number: 1 + + ml_video_edit_art_generate_20210513: + fmk: onnx + input_number: 1 + + ml_video_edit_art_transfer_20210513: + fmk: onnx + input_number: 3 + + ml_video_edit_dimming_tech_model_345000_color: + fmk: onnx + input_number: 2 + + ml_video_edit_dimming_tech_model_studio_20: + fmk: onnx + input_number: 2 + + ml_video_edit_enhance_update_tmp: + fmk: onnx + input_number: 1 + + ml_video_edit_face_edit_face3d: + fmk: onnx + input_number: 1 + + ml_video_edit_face_edit_pix2pixHD_unet: + fmk: onnx + input_number: 1 + + ml_video_edit_hair_dyeing_migrate_v2: + fmk: onnx + input_number: 4 + + ml_video_edit_hair_dyeing_migrate_v2_fix: + fmk: onnx + input_number: 4 + + ml_video_edit_judge: + fmk: onnx + input_number: 1 + + ml_video_edit_makeup_mobilenetv203: + fmk: onnx + input_number: 1 + + ml_video_edit_shot_selection_face_emotion: + fmk: onnx + input_number: 1 + + ml_video_edit_shot_selection_yolox_nano_coco_reduced: + fmk: onnx + input_number: 1 + + ml_video_edit_style_transfer_autoportrait: + fmk: onnx + input_number: 1 + + ml_video_edit_style_transfer_candy: + fmk: onnx + input_number: 1 + + ml_video_edit_style_transfer_gongnongbing: + fmk: onnx + input_number: 1 + + ml_video_edit_style_transfer_starry: + fmk: onnx + input_number: 1 + + ml_video_edit_styleCode_part1: + fmk: onnx + input_number: 1 + + ml_video_edit_styleCode_part2: + fmk: onnx + input_number: 9 + + ml_video_edit_vignet: + fmk: onnx + input_number: 1 + + mobilenetv2-7: + fmk: onnx + input_number: 1 + + mosaic-9: + fmk: onnx + input_number: 1 + + mtk_detect_mbv1_640_480: + fmk: onnx + input_number: 1 + + mtk_detect-deeper-halfdeeper-mbv1-lastearlySSD-shortcut-400-400_nopostprocess_simplified_onnx: + fmk: onnx + input_number: 1 + + mtk_detect-deeper-halfdeeper-mbv1-shortcut-400-400_nopostprocess_simplified_onnx: + fmk: onnx + input_number: 1 + + mtk_detect-mbv1-shortcut-400-400: + fmk: onnx + input_number: 1 + + mtk_detect-mbv1-shortcut-400-400_nopostprocess_simplified_onnx: + fmk: onnx + input_number: 1 + + mtk_detect-mbv2-shortcut-400-400: + fmk: onnx + input_number: 1 + + mtk_detect-mbv2-shortcut-400-400-simplified: + fmk: onnx + input_number: 1 + + mtk_emotions-d2012-75: + fmk: onnx + input_number: 1 + + mtk_face_features_v3: + fmk: onnx + input_number: 1 + + mtk_face_recognition_v2: + fmk: onnx + input_number: 1 + + mtk_face_recognition_v3: + fmk: onnx + input_number: 1 + + pointilism-9: + fmk: onnx + input_number: 1 + + porseg_tmp: + fmk: onnx + input_number: 2 + + Q888_CV_face_recognition_self: + fmk: onnx + input_number: 1 + + Q888_face_recognition: + fmk: onnx + input_number: 1 + + Q888_iris_detect: + fmk: onnx + input_number: 1 + + rain-princess-9: + fmk: onnx + input_number: 1 + + residual_distill_bs_1: + fmk: onnx + input_number: 1 + + residual_distill_bs_32: + fmk: onnx + input_number: 1 + + residual_distill_cifar10_bs_1: + fmk: onnx + input_number: 1 + + residual_distill_cifar10_bs_32: + fmk: onnx + input_number: 1 + + residual_distill_res34_cifar10_bs_1_update: + fmk: onnx + input_number: 1 + + residual_distill_res50_cifar10_bs_1_update: + fmk: onnx + input_number: 1 + + rpnt_pdr_conv2d_16_fixed_last: + fmk: onnx + input_number: 1 + + shufflenet-9: + fmk: onnx + input_number: 1 + + shufflenet-v2-10: + fmk: onnx + input_number: 1 + + simple_IPS_model_4D_input: + fmk: onnx + input_number: 1 + + squeezenet1.0-9: + fmk: onnx + input_number: 1 + + squeezenet1.1-7: + fmk: onnx + input_number: 1 + + ssd-10: + fmk: onnx + input_number: 1 + + udnie-9: + fmk: onnx + input_number: 1 + + yolov5s: + fmk: onnx + input_number: 1 + +--- + +tflite: + albert_lite_base_squadv1_1: + fmk: tflite + input_number: 3 + + gts_detect_5k_tf115: + fmk: tflite + input_number: 1 + + hdc_tb_cn_neg: + fmk: tflite + input_number: 3 + acc_threshold: 0.5 + + hiai_asr_last_e1_cpu_fast_wavenet_batch1_frame1_one_cache_fp32: + fmk: tflite + input_number: 2 + + hiai_cv_labelDetectorModel_v3: + fmk: tflite + input_number: 2 + + hiai_cv_saliencyDetectorModel: + fmk: tflite + input_number: 1 + + hiai_dress_detect: + fmk: tflite + input_number: 1 + + hiai_face_model_npu: + fmk: tflite + input_number: 1 + + hiai_frozen_inference_graph: + fmk: tflite + input_number: 1 + + hiai_label_and_video: + fmk: tflite + input_number: 1 + + hiai_lm_inference_graph: + fmk: tflite + input_number: 1 + + hiai_model_0909_kd_rot_ps_softmax: + fmk: tflite + input_number: 1 + + hiai_object_detect_814: + fmk: tflite + input_number: 1 + + hiai_ssd_mobilenetv2_object: + fmk: tflite + input_number: 1 + + hiai_vad: + fmk: tflite + input_number: 2 + + ide_label_retrained: + fmk: tflite + input_number: 1 + + lite-model_albert_lite_base_squadv1_metadata_1: + fmk: tflite + input_number: 3 + + lite-model_mobilebert_1_metadata_1: + fmk: tflite + input_number: 3 + + ml_ei_headpose_pb2tflite: + fmk: tflite + input_number: 3 + benchmark_shapes: 1,64,64,3:16:16 + + ml_headpose_pb2tflite: + fmk: tflite + input_number: 3 + benchmark_shapes: 1,64,64,3:16:16 + + ml_location: + fmk: tflite + input_number: 1 + + ml_ocr_jk: + fmk: tflite + input_number: 1 + + ml_ocr_jk_pb2tflite: + fmk: tflite + input_number: 1 + + ml_tacotron_decoder_step_stf: + fmk: tflite + input_number: 9 + benchmark_shapes: 1,80:1,256:1,1024:1,1024:1,1024:1,1024:1,8:1,1,256:1 + + ml_video_edit_img_segment_adaptise_pb2tflite: + fmk: tflite + input_number: 2 + + ml_video_edit_video_segment_gauss_adaptis_part2_pb2tflite: + fmk: tflite + input_number: 2 + + mobilebert_1_default_1: + fmk: tflite + input_number: 3 + + mtk_276landmark_0913: + fmk: tflite + input_number: 1 + + mtk_age_gender: + fmk: tflite + input_number: 1 + + Q_language_model_hrmini_Q4_b4_17w: + fmk: tflite + input_number: 1 + + Q888_lapa158_unet_0924: + fmk: tflite + input_number: 1 + + resnet_v2_101_299: + fmk: tflite + input_number: 1 + + scan_hms_detect: + fmk: tf + input_number: 1 + + unet_mbv2_05_104pts: + fmk: tflite + input_number: 1 diff --git a/mindspore/lite/test/st/scripts/experimental/config/models_onnx.yaml b/mindspore/lite/test/st/scripts/experimental/config/models_onnx.yaml index 46f3ef7c5e0..0f543dca999 100644 --- a/mindspore/lite/test/st/scripts/experimental/config/models_onnx.yaml +++ b/mindspore/lite/test/st/scripts/experimental/config/models_onnx.yaml @@ -1,76 +1,76 @@ -onnx: - 01-face_det_400_400: - fmk: onnx - input_number: 1 - input_suffix: .onnx.bin - output_suffix: .onnx.out - benchmark_shapes: 1,400,400,3 - acc_threshold: 4.5 - - gts_version-RFB-320_simplified: - fmk: onnx - input_number: 1 - - ml_edu_kit_hand_detection: - fmk: onnx - input_number: 1 - - ml_ei_facedetection: - fmk: onnx - input_number: 1 - - ml_facedetector: - fmk: onnx - input_number: 1 - - ml_motion_capture_nanodet_m_0.5x_people_0928_sim: - fmk: onnx - input_number: 1 - - ml_video_edit_judge: - fmk: onnx - input_number: 1 - - ml_video_edit_shot_selection_yolox_nano_coco_reduced: - fmk: onnx - input_number: 1 - - ml_video_edit_style_transfer_autoportrait: - fmk: onnx - input_number: 1 - - ml_video_edit_style_transfer_candy: - fmk: onnx - input_number: 1 - - ml_video_edit_style_transfer_gongnongbing: - fmk: onnx - input_number: 1 - - ml_video_edit_style_transfer_starry: - fmk: onnx - input_number: 1 - - ml_video_edit_vignet: - fmk: onnx - input_number: 1 - - mtk_detect-deeper-halfdeeper-mbv1-lastearlySSD-shortcut-400-400_nopostprocess_simplified_onnx: - fmk: onnx - input_number: 1 - - mtk_detect-deeper-halfdeeper-mbv1-shortcut-400-400_nopostprocess_simplified_onnx: - fmk: onnx - input_number: 1 - - mtk_detect-mbv1-shortcut-400-400_nopostprocess_simplified_onnx: - fmk: onnx - input_number: 1 - - mtk_detect-mbv2-shortcut-400-400-simplified: - fmk: onnx - input_number: 1 - - yolov5s: - fmk: onnx - input_number: 1 +onnx: + 01-face_det_400_400: + fmk: onnx + input_number: 1 + input_suffix: .onnx.bin + output_suffix: .onnx.out + benchmark_shapes: 1,400,400,3 + acc_threshold: 4.5 + + gts_version-RFB-320_simplified: + fmk: onnx + input_number: 1 + + ml_edu_kit_hand_detection: + fmk: onnx + input_number: 1 + + ml_ei_facedetection: + fmk: onnx + input_number: 1 + + ml_facedetector: + fmk: onnx + input_number: 1 + + ml_motion_capture_nanodet_m_0.5x_people_0928_sim: + fmk: onnx + input_number: 1 + + ml_video_edit_judge: + fmk: onnx + input_number: 1 + + ml_video_edit_shot_selection_yolox_nano_coco_reduced: + fmk: onnx + input_number: 1 + + ml_video_edit_style_transfer_autoportrait: + fmk: onnx + input_number: 1 + + ml_video_edit_style_transfer_candy: + fmk: onnx + input_number: 1 + + ml_video_edit_style_transfer_gongnongbing: + fmk: onnx + input_number: 1 + + ml_video_edit_style_transfer_starry: + fmk: onnx + input_number: 1 + + ml_video_edit_vignet: + fmk: onnx + input_number: 1 + + mtk_detect-deeper-halfdeeper-mbv1-lastearlySSD-shortcut-400-400_nopostprocess_simplified_onnx: + fmk: onnx + input_number: 1 + + mtk_detect-deeper-halfdeeper-mbv1-shortcut-400-400_nopostprocess_simplified_onnx: + fmk: onnx + input_number: 1 + + mtk_detect-mbv1-shortcut-400-400_nopostprocess_simplified_onnx: + fmk: onnx + input_number: 1 + + mtk_detect-mbv2-shortcut-400-400-simplified: + fmk: onnx + input_number: 1 + + yolov5s: + fmk: onnx + input_number: 1 diff --git a/mindspore/lite/test/st/scripts/experimental/config/models_release.yaml b/mindspore/lite/test/st/scripts/experimental/config/models_release.yaml index 5f953e0e9fd..527b4fdea4a 100644 --- a/mindspore/lite/test/st/scripts/experimental/config/models_release.yaml +++ b/mindspore/lite/test/st/scripts/experimental/config/models_release.yaml @@ -1,423 +1,423 @@ -tf: - densenet: - fmk: tf - input_number: 1 - benchmark_shapes: 1,224,224,3 - - fsr_270_mindspore: - fmk: tf - input_number: 1 - - fsr_360_mindspore: - fmk: tf - input_number: 1 - - fsr_720_mindspore: - fmk: tf - input_number: 1 - - hiai_cpu_face_emotion: - fmk: tf - input_number: 1 - - hiai_cpu_face_gazing: - fmk: tf - input_number: 1 - - hiai_cv_focusShootOCRModel_08: - fmk: tf - input_number: 1 - - hiai_detectmodel_06_23_960_480_1180700: - fmk: tf - input_number: 1 - - hiai_ghostnet: - fmk: tf - input_number: 1 - - hiai_iMaxDN_RGB: - fmk: tf - input_number: 1 - - hiai_iMaxSR_RGB: - fmk: tf - input_number: 1 - - hiai_latin_ocr: - fmk: tf - input_number: 1 - - ml_ei_headpose: - fmk: tf - input_number: 1 - benchmark_shapes: 1,64,64,3 - - ml_face_openclose: - fmk: tf - input_number: 1 - benchmark_shapes: 1,32,32,3 - - ml_ocr_latin: - fmk: tf - input_number: 1 - - ml_video_edit_shot_selection_opticalFlow: - fmk: tf - input_number: 1 - - mobilenet_v1_0.25_128_frozen: - fmk: tf - input_number: 1 - benchmark_shapes: 1,128,128,3 - - Q_crnn_ori_75w_slim_norm: - fmk: tf - input_number: 1 - - Q_crnn_screen_slim400w_more_20w: - fmk: tf - input_number: 1 - - Q_dila-small-mix-full-fineturn-390000-nopixel-nosigmoid: - fmk: tf - input_number: 1 - - scan_hms_angle: - fmk: tf - input_number: 1 - - siteAI_trans_nonlinear40g: - fmk: tf - input_number: 1 - benchmark_shapes: 1,271 - - siteAI_wireless_depress_w: - fmk: tf - input_number: 1 - benchmark_shapes: 1,36 - ---- - -caffe: - 2012_ATLANTA_1class_20190621_v4.x_nomean: - fmk: caffe - input_number: 1 - - bank_card_recognition_fcny: - fmk: caffe - input_number: 1 - - hdc_contour_pose_128: - fmk: caffe - input_number: 1 - - hdc_fivembnet: - fmk: caffe - input_number: 1 - - hdc_resnet: - fmk: caffe - input_number: 1 - - hiai_cpu_face_hat: - fmk: caffe - input_number: 1 - - hiai_cv_focusShootOCRModel_06: - fmk: caffe - input_number: 1 - - hiai_face_landmark: - fmk: caffe - input_number: 1 - - hiai_video_seg: - fmk: caffe - input_number: 1 - - ml_face_compare: - fmk: caffe - input_number: 1 - - ml_face_contour: - fmk: caffe - input_number: 1 - - ml_face_hat: - fmk: caffe - input_number: 1 - - ml_face_landmark: - fmk: caffe - input_number: 1 - - ml_hand_3d_regression: - fmk: caffe - input_number: 1 - - ml_hardware_eyeclose: - fmk: caffe - input_number: 1 - - ml_liveness_detect_landmark_tmp: - fmk: caffe - input_number: 1 - - ml_video_edit_hair_dyeing_segmodel_20211119: - fmk: caffe - input_number: 1 - - ml_video_edit_hair_dyeing_segmodel_v3: - fmk: caffe - input_number: 1 - - ml_video_edit_hairSeg_have_imageProcessLayer_interpTo145: - fmk: caffe - input_number: 1 - - ml_video_edit_person_divison_pic: - fmk: caffe - input_number: 1 - ---- - -onnx: - 01-face_det_400_400: - fmk: onnx - input_number: 1 - input_suffix: .onnx.bin - output_suffix: .onnx.out - benchmark_shapes: 1,400,400,3 - acc_threshold: 4.5 - - emotion-ferplus-8: - fmk: onnx - input_number: 1 - - gender_lstm_scd: - fmk: onnx - input_number: 1 - - gender_lstm_vad: - fmk: onnx - input_number: 1 - - gender_resnet34_lzl: - fmk: onnx - input_number: 1 - - hdc_Face_Landmark5_MTI_Aesthetic: - fmk: onnx - input_number: 1 - - ml_table_segment: - fmk: onnx - input_number: 1 - - mnist-8: - fmk: onnx - input_number: 1 - - rcnn-ilsvrc13-9: - fmk: onnx - input_number: 1 - ---- - -tflite: - bloom_model_age_gender: - fmk: tflite - input_number: 1 - - bloom_new_detect: - fmk: tflite - input_number: 1 - - hiai_bigmodel_ghost_2_1_no_normalized_no_trans_tflite: - fmk: tflite - input_number: 1 - - hiai_bigmodel_ghost_5_1_no_normalized_no_trans_tflite: - fmk: tflite - input_number: 1 - - hiai_cpu_face_emotion: - fmk: tflite - input_number: 1 - - hiai_cpu_face_gazing: - fmk: tflite - input_number: 1 - - hiai_cpu_face_headpose: - fmk: tflite - input_number: 1 - - hiai_cv_focusShootOCRModel_08: - fmk: tflite - input_number: 1 - - hiai_cv_labelDetectorModel_v4: - fmk: tflite - input_number: 1 - - hiai_detectmodel_06_23_960_480_1180700: - fmk: tflite - input_number: 1 - - hiai_ghostnet: - fmk: tflite - input_number: 1 - - hiai_iMaxDN_RGB: - fmk: tflite - input_number: 1 - - hiai_iMaxSR_RGB: - fmk: tflite - input_number: 1 - - hiai_latin_ocr: - fmk: tflite - input_number: 1 - - hiai_latin_ocr_1: - fmk: tflite - input_number: 1 - - lite-model_on_device_vision_classifier_popular_us_products_V1_1: - fmk: tflite - input_number: 1 - - lite-model_on_device_vision_classifier_popular_wine_V1_1: - fmk: tflite - input_number: 1 - - lma_tsec_shallow_channels16_ds2.1.1_model-best-f1: - fmk: tflite - input_number: 1 - - ml_ei_headpose: - fmk: tflite - input_number: 1 - - ml_ei_landmark: - fmk: tflite - input_number: 1 - - ml_ei_landmark_pb2tflite: - fmk: tflite - input_number: 1 - - ml_face_openclose: - fmk: tflite - input_number: 1 - - ml_face_openclose_tflite: - fmk: tflite - input_number: 1 - - ml_ocr_latin: - fmk: tflite - input_number: 1 - - ml_ocr_latin_pb2tflite: - fmk: tflite - input_number: 1 - - ml_pic_shopping_pb2tflite: - fmk: tflite - input_number: 1 - - ml_text_correction: - fmk: tflite - input_number: 1 - - mnasnet_0.50_224_1_metadata_1: - fmk: tflite - input_number: 1 - - mobilenet: - fmk: tflite - input_number: 1 - - mobilenet_v1_0.25_128: - fmk: tflite - input_number: 1 - - mtk_face_recognition: - fmk: tflite - input_number: 1 - - mtk_model_face_dress: - fmk: tflite - input_number: 1 - - mtk_new_detect: - fmk: tflite - input_number: 1 - - Q_convert: - fmk: tflite - input_number: 1 - - Q_crnn_ori_75w_slim_norm_pb2tflite: - fmk: tflite - input_number: 1 - - Q_crnn_screen_slim400w_more_20w_pb2tflite: - fmk: tflite - input_number: 1 - - Q_dila-small-mix-full-fineturn-390000-nopixel-nosigmoid_tflite: - fmk: tflite - input_number: 1 - - Q_focusocr_cn_recog: - fmk: tflite - input_number: 1 - - Q_focusocr_jk_recog: - fmk: tflite - input_number: 1 - - Q888_age_gender_orderd: - fmk: tflite - input_number: 1 - - Q888_face_emo_dress_mv3_orderd: - fmk: tflite - input_number: 1 - - Q888_isface: - fmk: tflite - input_number: 1 - - Q888_new_detect: - fmk: tflite - input_number: 1 - - Q888_pose: - fmk: tflite - input_number: 1 - - resnet: - fmk: tflite - input_number: 1 - - siteAI_digcom_g2v_keras: - fmk: tflite - input_number: 1 - - siteAI_wireless_depress_w: - fmk: tflite - input_number: 1 - - squeezenet: - fmk: tflite - input_number: 1 - - text_classification: - fmk: tflite - input_number: 1 +tf: + densenet: + fmk: tf + input_number: 1 + benchmark_shapes: 1,224,224,3 + + fsr_270_mindspore: + fmk: tf + input_number: 1 + + fsr_360_mindspore: + fmk: tf + input_number: 1 + + fsr_720_mindspore: + fmk: tf + input_number: 1 + + hiai_cpu_face_emotion: + fmk: tf + input_number: 1 + + hiai_cpu_face_gazing: + fmk: tf + input_number: 1 + + hiai_cv_focusShootOCRModel_08: + fmk: tf + input_number: 1 + + hiai_detectmodel_06_23_960_480_1180700: + fmk: tf + input_number: 1 + + hiai_ghostnet: + fmk: tf + input_number: 1 + + hiai_iMaxDN_RGB: + fmk: tf + input_number: 1 + + hiai_iMaxSR_RGB: + fmk: tf + input_number: 1 + + hiai_latin_ocr: + fmk: tf + input_number: 1 + + ml_ei_headpose: + fmk: tf + input_number: 1 + benchmark_shapes: 1,64,64,3 + + ml_face_openclose: + fmk: tf + input_number: 1 + benchmark_shapes: 1,32,32,3 + + ml_ocr_latin: + fmk: tf + input_number: 1 + + ml_video_edit_shot_selection_opticalFlow: + fmk: tf + input_number: 1 + + mobilenet_v1_0.25_128_frozen: + fmk: tf + input_number: 1 + benchmark_shapes: 1,128,128,3 + + Q_crnn_ori_75w_slim_norm: + fmk: tf + input_number: 1 + + Q_crnn_screen_slim400w_more_20w: + fmk: tf + input_number: 1 + + Q_dila-small-mix-full-fineturn-390000-nopixel-nosigmoid: + fmk: tf + input_number: 1 + + scan_hms_angle: + fmk: tf + input_number: 1 + + siteAI_trans_nonlinear40g: + fmk: tf + input_number: 1 + benchmark_shapes: 1,271 + + siteAI_wireless_depress_w: + fmk: tf + input_number: 1 + benchmark_shapes: 1,36 + +--- + +caffe: + 2012_ATLANTA_1class_20190621_v4.x_nomean: + fmk: caffe + input_number: 1 + + bank_card_recognition_fcny: + fmk: caffe + input_number: 1 + + hdc_contour_pose_128: + fmk: caffe + input_number: 1 + + hdc_fivembnet: + fmk: caffe + input_number: 1 + + hdc_resnet: + fmk: caffe + input_number: 1 + + hiai_cpu_face_hat: + fmk: caffe + input_number: 1 + + hiai_cv_focusShootOCRModel_06: + fmk: caffe + input_number: 1 + + hiai_face_landmark: + fmk: caffe + input_number: 1 + + hiai_video_seg: + fmk: caffe + input_number: 1 + + ml_face_compare: + fmk: caffe + input_number: 1 + + ml_face_contour: + fmk: caffe + input_number: 1 + + ml_face_hat: + fmk: caffe + input_number: 1 + + ml_face_landmark: + fmk: caffe + input_number: 1 + + ml_hand_3d_regression: + fmk: caffe + input_number: 1 + + ml_hardware_eyeclose: + fmk: caffe + input_number: 1 + + ml_liveness_detect_landmark_tmp: + fmk: caffe + input_number: 1 + + ml_video_edit_hair_dyeing_segmodel_20211119: + fmk: caffe + input_number: 1 + + ml_video_edit_hair_dyeing_segmodel_v3: + fmk: caffe + input_number: 1 + + ml_video_edit_hairSeg_have_imageProcessLayer_interpTo145: + fmk: caffe + input_number: 1 + + ml_video_edit_person_divison_pic: + fmk: caffe + input_number: 1 + +--- + +onnx: + 01-face_det_400_400: + fmk: onnx + input_number: 1 + input_suffix: .onnx.bin + output_suffix: .onnx.out + benchmark_shapes: 1,400,400,3 + acc_threshold: 4.5 + + emotion-ferplus-8: + fmk: onnx + input_number: 1 + + gender_lstm_scd: + fmk: onnx + input_number: 1 + + gender_lstm_vad: + fmk: onnx + input_number: 1 + + gender_resnet34_lzl: + fmk: onnx + input_number: 1 + + hdc_Face_Landmark5_MTI_Aesthetic: + fmk: onnx + input_number: 1 + + ml_table_segment: + fmk: onnx + input_number: 1 + + mnist-8: + fmk: onnx + input_number: 1 + + rcnn-ilsvrc13-9: + fmk: onnx + input_number: 1 + +--- + +tflite: + bloom_model_age_gender: + fmk: tflite + input_number: 1 + + bloom_new_detect: + fmk: tflite + input_number: 1 + + hiai_bigmodel_ghost_2_1_no_normalized_no_trans_tflite: + fmk: tflite + input_number: 1 + + hiai_bigmodel_ghost_5_1_no_normalized_no_trans_tflite: + fmk: tflite + input_number: 1 + + hiai_cpu_face_emotion: + fmk: tflite + input_number: 1 + + hiai_cpu_face_gazing: + fmk: tflite + input_number: 1 + + hiai_cpu_face_headpose: + fmk: tflite + input_number: 1 + + hiai_cv_focusShootOCRModel_08: + fmk: tflite + input_number: 1 + + hiai_cv_labelDetectorModel_v4: + fmk: tflite + input_number: 1 + + hiai_detectmodel_06_23_960_480_1180700: + fmk: tflite + input_number: 1 + + hiai_ghostnet: + fmk: tflite + input_number: 1 + + hiai_iMaxDN_RGB: + fmk: tflite + input_number: 1 + + hiai_iMaxSR_RGB: + fmk: tflite + input_number: 1 + + hiai_latin_ocr: + fmk: tflite + input_number: 1 + + hiai_latin_ocr_1: + fmk: tflite + input_number: 1 + + lite-model_on_device_vision_classifier_popular_us_products_V1_1: + fmk: tflite + input_number: 1 + + lite-model_on_device_vision_classifier_popular_wine_V1_1: + fmk: tflite + input_number: 1 + + lma_tsec_shallow_channels16_ds2.1.1_model-best-f1: + fmk: tflite + input_number: 1 + + ml_ei_headpose: + fmk: tflite + input_number: 1 + + ml_ei_landmark: + fmk: tflite + input_number: 1 + + ml_ei_landmark_pb2tflite: + fmk: tflite + input_number: 1 + + ml_face_openclose: + fmk: tflite + input_number: 1 + + ml_face_openclose_tflite: + fmk: tflite + input_number: 1 + + ml_ocr_latin: + fmk: tflite + input_number: 1 + + ml_ocr_latin_pb2tflite: + fmk: tflite + input_number: 1 + + ml_pic_shopping_pb2tflite: + fmk: tflite + input_number: 1 + + ml_text_correction: + fmk: tflite + input_number: 1 + + mnasnet_0.50_224_1_metadata_1: + fmk: tflite + input_number: 1 + + mobilenet: + fmk: tflite + input_number: 1 + + mobilenet_v1_0.25_128: + fmk: tflite + input_number: 1 + + mtk_face_recognition: + fmk: tflite + input_number: 1 + + mtk_model_face_dress: + fmk: tflite + input_number: 1 + + mtk_new_detect: + fmk: tflite + input_number: 1 + + Q_convert: + fmk: tflite + input_number: 1 + + Q_crnn_ori_75w_slim_norm_pb2tflite: + fmk: tflite + input_number: 1 + + Q_crnn_screen_slim400w_more_20w_pb2tflite: + fmk: tflite + input_number: 1 + + Q_dila-small-mix-full-fineturn-390000-nopixel-nosigmoid_tflite: + fmk: tflite + input_number: 1 + + Q_focusocr_cn_recog: + fmk: tflite + input_number: 1 + + Q_focusocr_jk_recog: + fmk: tflite + input_number: 1 + + Q888_age_gender_orderd: + fmk: tflite + input_number: 1 + + Q888_face_emo_dress_mv3_orderd: + fmk: tflite + input_number: 1 + + Q888_isface: + fmk: tflite + input_number: 1 + + Q888_new_detect: + fmk: tflite + input_number: 1 + + Q888_pose: + fmk: tflite + input_number: 1 + + resnet: + fmk: tflite + input_number: 1 + + siteAI_digcom_g2v_keras: + fmk: tflite + input_number: 1 + + siteAI_wireless_depress_w: + fmk: tflite + input_number: 1 + + squeezenet: + fmk: tflite + input_number: 1 + + text_classification: + fmk: tflite + input_number: 1 diff --git a/mindspore/lite/test/st/scripts/nnie/run_benchmark_nnie.sh b/mindspore/lite/test/st/scripts/nnie/run_benchmark_nnie.sh old mode 100755 new mode 100644 diff --git a/mindspore/lite/test/st/scripts/nnie/run_benchmark_nnie_micro.sh b/mindspore/lite/test/st/scripts/nnie/run_benchmark_nnie_micro.sh old mode 100755 new mode 100644 diff --git a/mindspore/lite/test/st/scripts/nnie/run_converter_nnie.sh b/mindspore/lite/test/st/scripts/nnie/run_converter_nnie.sh old mode 100755 new mode 100644 diff --git a/mindspore/lite/test/st/scripts/nnie/run_converter_nnie_micro.sh b/mindspore/lite/test/st/scripts/nnie/run_converter_nnie_micro.sh old mode 100755 new mode 100644 diff --git a/mindspore/lite/test/st/scripts/run_net_train.sh b/mindspore/lite/test/st/scripts/run_net_train.sh old mode 100755 new mode 100644 diff --git a/mindspore/lite/test/st/scripts/triton/test_triton.sh b/mindspore/lite/test/st/scripts/triton/test_triton.sh old mode 100755 new mode 100644 diff --git a/mindspore/lite/test/ut/scripts/run_ut_arm64.sh b/mindspore/lite/test/ut/scripts/run_ut_arm64.sh old mode 100755 new mode 100644 diff --git a/mindspore/lite/test/ut/src/dynamic_library_loader_test.cc b/mindspore/lite/test/ut/src/dynamic_library_loader_test.cc index 2b2fc74cb1c..5ca2b663686 100644 --- a/mindspore/lite/test/ut/src/dynamic_library_loader_test.cc +++ b/mindspore/lite/test/ut/src/dynamic_library_loader_test.cc @@ -1,43 +1,43 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "mindspore/lite/src/common/dynamic_library_loader.h" -#include "common/common_test.h" - -namespace mindspore { -class LoaderUtilTest : public mindspore::CommonTest { - public: - LoaderUtilTest() {} -}; - -/* - in file add.c, the code is: - int add(int a, int b) {return a + b;} - use this command to generate so file: - gcc add.cc -fPIC -shared -o libadd.so - use this command to see the symbol table: - nm -D libadd.so -*/ -TEST_F(LoaderUtilTest, TestAdd) { - lite::DynamicLibraryLoader loader; - loader.Open("./libadd.so"); - int (*add)(int a, int b); - add = (int (*)(int, int))loader.GetFunc("add"); - int res = add(7, 8); - loader.Close(); - ASSERT_EQ(15, res); -} -} // namespace mindspore +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/lite/src/common/dynamic_library_loader.h" +#include "common/common_test.h" + +namespace mindspore { +class LoaderUtilTest : public mindspore::CommonTest { + public: + LoaderUtilTest() {} +}; + +/* + in file add.c, the code is: + int add(int a, int b) {return a + b;} + use this command to generate so file: + gcc add.cc -fPIC -shared -o libadd.so + use this command to see the symbol table: + nm -D libadd.so +*/ +TEST_F(LoaderUtilTest, TestAdd) { + lite::DynamicLibraryLoader loader; + loader.Open("./libadd.so"); + int (*add)(int a, int b); + add = (int (*)(int, int))loader.GetFunc("add"); + int res = add(7, 8); + loader.Close(); + ASSERT_EQ(15, res); +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/effNet_input_x_1_3_224_224.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/effNet_input_x_1_3_224_224.bin old mode 100755 new mode 100644 diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/effNet_output_f_1_1280_7_7.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/effNet_output_f_1_1280_7_7.bin old mode 100755 new mode 100644 diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/effNet_output_y_1_1000.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/effNet_output_y_1_1000.bin old mode 100755 new mode 100644 diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/retinaface_input.f32 b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/retinaface_input.f32 old mode 100755 new mode 100644 diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/retinaface_out_0.f32 b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/retinaface_out_0.f32 old mode 100755 new mode 100644 diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/retinaface_out_1.f32 b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/retinaface_out_1.f32 old mode 100755 new mode 100644 diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/retinaface_out_2.f32 b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/retinaface_out_2.f32 old mode 100755 new mode 100644 diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc index a0988660ef2..402fd7884cc 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc @@ -1,253 +1,253 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "ut/src/runtime/kernel/opencl/common.h" -#include "nnacl/conv_parameter.h" - -namespace mindspore::lite::opencl::test { - -class TestOpenCL_DepthwiseConv2d : public CommonTest {}; - -namespace { -// Check and optimize -// PrimitiveType_DepthwiseConv2D: src/ops/populate/depthwise_conv2d_populate.cc -OpParameter *CreateParameter(int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_u, int pad_d, int pad_l, - int pad_r, int dilation_h, int dilation_w, ActType act_type, int input_channel) { - auto *param = test::CreateParameter(schema::PrimitiveType_Conv2DFusion); - param->kernel_h_ = kernel_h; - param->kernel_w_ = kernel_w; - param->stride_h_ = stride_h; - param->stride_w_ = stride_w; - param->pad_u_ = pad_u; - param->pad_d_ = pad_d; - param->pad_l_ = pad_l; - param->pad_r_ = pad_r; - param->input_channel_ = input_channel; - param->output_channel_ = input_channel; - param->group_ = input_channel; - param->dilation_h_ = dilation_h; - param->dilation_w_ = dilation_w; - param->act_type_ = act_type; - return reinterpret_cast(param); -} -} // namespace - -TEST_F(TestOpenCL_DepthwiseConv2d, NoPad) { - int kernel_h = 3; - int kernel_w = 3; - int stride_h = 1; - int stride_w = 1; - int pad_u = 0; - int pad_d = 0; - int pad_l = 0; - int pad_r = 0; - int dilation_h = 1; - int dilation_w = 1; - ActType act_type = ActType_No; - - std::vector input_shape = {1, 4, 4, 4}; - std::vector output_shape = {1, 2, 2, 4}; - std::vector weight_shape = {1, kernel_h, kernel_w, output_shape.back()}; - std::vector bias_shape = {output_shape.back()}; - float input_data[] = { - 0.5488135, 0.71518934, 0.60276335, 0.5448832, 0.4236548, 0.6458941, 0.4375872, 0.891773, - 0.96366274, 0.3834415, 0.79172504, 0.5288949, 0.56804454, 0.92559665, 0.07103606, 0.0871293, - 0.0202184, 0.83261985, 0.77815676, 0.87001216, 0.9786183, 0.7991586, 0.46147937, 0.7805292, - 0.11827443, 0.639921, 0.14335328, 0.9446689, 0.5218483, 0.41466194, 0.2645556, 0.7742337, - 0.45615032, 0.56843394, 0.0187898, 0.6176355, 0.6120957, 0.616934, 0.94374806, 0.6818203, - 0.3595079, 0.43703195, 0.6976312, 0.06022547, 0.6667667, 0.67063785, 0.21038257, 0.12892629, - 0.31542835, 0.36371076, 0.57019675, 0.43860152, 0.9883738, 0.10204481, 0.20887676, 0.16130951, - 0.6531083, 0.2532916, 0.46631077, 0.2444256, 0.15896958, 0.11037514, 0.6563296, 0.13818295, - }; - float bias_data[] = {0, 0, 0, 0}; - float weight_data[] = {0.19658236, 0.36872518, 0.82099324, 0.09710128, 0.8379449, 0.09609841, 0.97645944, 0.4686512, - 0.9767611, 0.6048455, 0.7392636, 0.03918779, 0.28280696, 0.12019656, 0.2961402, 0.11872772, - 0.31798318, 0.41426298, 0.06414749, 0.6924721, 0.56660146, 0.2653895, 0.5232481, 0.09394051, - 0.5759465, 0.9292962, 0.31856894, 0.6674104, 0.13179787, 0.7163272, 0.2894061, 0.18319136, - 0.5865129, 0.02010755, 0.82894003, 0.00469548}; - float output_data[] = {2.9720426, 1.890834, 2.3618119, 2.3867798, 2.5666943, 1.6261611, 2.0977764, 1.6445805, - 2.462798, 1.6643658, 1.6861027, 1.8428761, 2.5156446, 1.5366757, 1.6767557, 1.6905226}; - - for (auto fp16_enable : {false, true}) { - auto *param = CreateParameter(kernel_h, kernel_w, stride_h, stride_w, pad_u, pad_d, pad_l, pad_r, dilation_h, - dilation_w, act_type, input_shape.back()); - TestMain({{input_shape, input_data, VAR}, - {weight_shape, weight_data, CONST_TENSOR}, - {bias_shape, bias_data, CONST_TENSOR}}, - {output_shape, output_data}, param, fp16_enable, fp16_enable ? 1e-2 : 1e-5, 1e-1, true); - } -} - -TEST_F(TestOpenCL_DepthwiseConv2d, Pad) { - int kernel_h = 3; - int kernel_w = 3; - int stride_h = 1; - int stride_w = 1; - int pad_u = 1; - int pad_d = 1; - int pad_l = 1; - int pad_r = 1; - int dilation_h = 1; - int dilation_w = 1; - ActType act_type = ActType_No; - - std::vector input_shape = {1, 3, 3, 5}; - std::vector output_shape = {1, 3, 3, 5}; - std::vector weight_shape = {1, kernel_h, kernel_w, output_shape.back()}; - std::vector bias_shape = {output_shape.back()}; - float input_data[] = {0.5488135, 0.3834415, 0.77815676, 0.9446689, 0.6120957, 0.71518934, 0.79172504, 0.87001216, - 0.5218483, 0.616934, 0.60276335, 0.5288949, 0.9786183, 0.41466194, 0.94374806, 0.5448832, - 0.56804454, 0.7991586, 0.2645556, 0.6818203, 0.4236548, 0.92559665, 0.46147937, 0.7742337, - 0.3595079, 0.6458941, 0.07103606, 0.7805292, 0.45615032, 0.43703195, 0.4375872, 0.0871293, - 0.11827443, 0.56843394, 0.6976312, 0.891773, 0.0202184, 0.639921, 0.0187898, 0.06022547, - 0.96366274, 0.83261985, 0.14335328, 0.6176355, 0.6667667}; - float weight_data[] = {0.67063785, 0.21038257, 0.12892629, 0.31542835, 0.36371076, 0.57019675, 0.43860152, 0.9883738, - 0.10204481, 0.20887676, 0.16130951, 0.6531083, 0.2532916, 0.46631077, 0.2444256, 0.15896958, - 0.11037514, 0.6563296, 0.13818295, 0.19658236, 0.36872518, 0.82099324, 0.09710128, 0.8379449, - 0.09609841, 0.97645944, 0.4686512, 0.9767611, 0.6048455, 0.7392636, 0.03918779, 0.28280696, - 0.12019656, 0.2961402, 0.11872772, 0.31798318, 0.41426298, 0.06414749, 0.6924721, 0.56660146, - 0.2653895, 0.5232481, 0.09394051, 0.5759465, 0.9292962}; - float bias_data[] = {0, 0, 0, 0, 0}; - float output_data[] = {1.189188, 1.0425153, 1.8012011, 0.6074867, 1.2120346, 1.5005531, 0.8346756, 2.4365785, - 0.54975945, 1.6815965, 1.2690231, 0.60214907, 1.6158017, 0.42115876, 0.8854959, 1.1709145, - 1.0929465, 1.3534508, 1.1985044, 1.2932993, 2.4621446, 1.7086457, 2.6977584, 2.1960166, - 2.3769147, 2.3185873, 0.6133741, 0.9687358, 0.9987654, 1.0254729, 0.8368954, 0.74171704, - 0.8749627, 0.8953936, 0.5093431, 1.5496738, 0.54936385, 0.7683113, 1.165742, 1.3682933, - 1.0517888, 0.59817517, 0.75649744, 1.2075498, 0.38804203}; - - for (auto fp16_enable : {false, true}) { - auto *param = CreateParameter(kernel_h, kernel_w, stride_h, stride_w, pad_u, pad_d, pad_l, pad_r, dilation_h, - dilation_w, act_type, input_shape.back()); - TestMain({{input_shape, input_data, VAR}, - {weight_shape, weight_data, CONST_TENSOR}, - {bias_shape, bias_data, CONST_TENSOR}}, - {output_shape, output_data}, param, fp16_enable, fp16_enable ? 1e-2 : 1e-5, 1e-1, true); - } -} - -TEST_F(TestOpenCL_DepthwiseConv2d, NoPad1) { - int kernel_h = 2; - int kernel_w = 2; - int stride_h = 1; - int stride_w = 1; - int pad_u = 0; - int pad_d = 0; - int pad_l = 0; - int pad_r = 0; - int dilation_h = 1; - int dilation_w = 1; - ActType act_type = ActType_No; - - std::vector input_shape = {1, 4, 4, 4}; - std::vector output_shape = {1, 3, 3, 4}; - std::vector weight_shape = {1, kernel_h, kernel_w, output_shape.back()}; - std::vector bias_shape = {output_shape.back()}; - float input_data[] = {0.5488135, 0.71518934, 0.60276335, 0.5448832, 0.4236548, 0.6458941, 0.4375872, 0.891773, - 0.96366274, 0.3834415, 0.79172504, 0.5288949, 0.56804454, 0.92559665, 0.07103606, 0.0871293, - 0.0202184, 0.83261985, 0.77815676, 0.87001216, 0.9786183, 0.7991586, 0.46147937, 0.7805292, - 0.11827443, 0.639921, 0.14335328, 0.9446689, 0.5218483, 0.41466194, 0.2645556, 0.7742337, - 0.45615032, 0.56843394, 0.0187898, 0.6176355, 0.6120957, 0.616934, 0.94374806, 0.6818203, - 0.3595079, 0.43703195, 0.6976312, 0.06022547, 0.6667667, 0.67063785, 0.21038257, 0.12892629, - 0.31542835, 0.36371076, 0.57019675, 0.43860152, 0.9883738, 0.10204481, 0.20887676, 0.16130951, - 0.6531083, 0.2532916, 0.46631077, 0.2444256, 0.15896958, 0.11037514, 0.6563296, 0.13818295}; - float bias_data[] = {0, 0, 0, 0}; - float weight_data[] = {0.19658236, 0.36872517, 0.82099323, 0.09710128, 0.83794491, 0.09609841, - 0.97645947, 0.4686512, 0.97676109, 0.60484552, 0.73926358, 0.03918779, - 0.28280696, 0.12019656, 0.2961402, 0.11872772}; - float output_data[] = {0.3757235, 1.8489048, 1.4467758, 0.6116009, 1.2535334, 1.6583176, 1.2530621, 0.6590755, - 0.5466661, 1.22944, 0.93263525, 0.5317252, 0.7987474, 1.618667, 1.090071, 0.60372007, - 0.773425, 1.5383728, 1.262479, 0.54334986, 0.5755667, 1.3171062, 0.82401496, 0.39336145, - 0.6703031, 0.9385749, 1.018886, 0.40566355, 1.1277528, 0.7773028, 1.5164642, 0.27685273, - 0.86816025, 0.72971237, 1.1791146, 0.12131907}; - - for (auto fp16_enable : {false, true}) { - auto *param = CreateParameter(kernel_h, kernel_w, stride_h, stride_w, pad_u, pad_d, pad_l, pad_r, dilation_h, - dilation_w, act_type, input_shape.back()); - TestMain({{input_shape, input_data, VAR}, - {weight_shape, weight_data, CONST_TENSOR}, - {bias_shape, bias_data, CONST_TENSOR}}, - {output_shape, output_data}, param, fp16_enable, fp16_enable ? 1e-2 : 1e-5, 1e-1, true); - } -} -TEST_F(TestOpenCL_DepthwiseConv2d, Pad1) { - int kernel_h = 3; - int kernel_w = 3; - int stride_h = 1; - int stride_w = 1; - int pad_u = 1; - int pad_d = 1; - int pad_l = 1; - int pad_r = 1; - int dilation_h = 1; - int dilation_w = 1; - ActType act_type = ActType_No; - - std::vector input_shape = {1, 5, 5, 6}; - std::vector output_shape = {1, 5, 5, 6}; - std::vector weight_shape = {1, kernel_h, kernel_w, output_shape.back()}; - std::vector bias_shape = {output_shape.back()}; - float input_data[] = { - 0.5488135, 0.71518934, 0.60276335, 0.5448832, 0.4236548, 0.6458941, 0.4375872, 0.891773, 0.96366274, - 0.3834415, 0.79172504, 0.5288949, 0.56804454, 0.92559665, 0.07103606, 0.0871293, 0.0202184, 0.83261985, - 0.77815676, 0.87001216, 0.9786183, 0.7991586, 0.46147937, 0.7805292, 0.11827443, 0.639921, 0.14335328, - 0.9446689, 0.5218483, 0.41466194, 0.2645556, 0.7742337, 0.45615032, 0.56843394, 0.0187898, 0.6176355, - 0.6120957, 0.616934, 0.94374806, 0.6818203, 0.3595079, 0.43703195, 0.6976312, 0.06022547, 0.6667667, - 0.67063785, 0.21038257, 0.12892629, 0.31542835, 0.36371076, 0.57019675, 0.43860152, 0.9883738, 0.10204481, - 0.20887676, 0.16130951, 0.6531083, 0.2532916, 0.46631077, 0.2444256, 0.15896958, 0.11037514, 0.6563296, - 0.13818295, 0.19658236, 0.36872518, 0.82099324, 0.09710128, 0.8379449, 0.09609841, 0.97645944, 0.4686512, - 0.9767611, 0.6048455, 0.7392636, 0.03918779, 0.28280696, 0.12019656, 0.2961402, 0.11872772, 0.31798318, - 0.41426298, 0.06414749, 0.6924721, 0.56660146, 0.2653895, 0.5232481, 0.09394051, 0.5759465, 0.9292962, - 0.31856894, 0.6674104, 0.13179787, 0.7163272, 0.2894061, 0.18319136, 0.5865129, 0.02010755, 0.82894003, - 0.00469548, 0.6778165, 0.27000797, 0.735194, 0.96218854, 0.24875315, 0.57615733, 0.5920419, 0.5722519, - 0.22308163, 0.952749, 0.44712538, 0.84640867, 0.6994793, 0.29743695, 0.81379783, 0.39650574, 0.8811032, - 0.5812729, 0.8817354, 0.6925316, 0.7252543, 0.50132436, 0.95608366, 0.6439902, 0.42385504, 0.6063932, - 0.0191932, 0.30157483, 0.66017354, 0.2900776, 0.6180154, 0.4287687, 0.13547407, 0.29828233, 0.5699649, - 0.59087276, 0.57432526, 0.6532008, 0.65210325, 0.43141845, 0.8965466, 0.36756188, 0.43586493, 0.89192337, - 0.806194, 0.7038886, 0.10022689, 0.9194826, 0.7142413, 0.998847}; - float weight_data[] = {0.1494483, 0.86812606, 0.16249293, 0.61555956, 0.12381998, 0.84800823, 0.80731896, 0.56910074, - 0.4071833, 0.069167, 0.69742877, 0.45354268, 0.7220556, 0.86638233, 0.97552151, 0.85580334, - 0.01171408, 0.35997806, 0.72999056, 0.17162968, 0.52103661, 0.05433799, 0.19999652, 0.01852179, - 0.7936977, 0.22392469, 0.34535168, 0.92808129, 0.7044144, 0.03183893, 0.16469416, 0.6214784, - 0.57722859, 0.23789282, 0.934214, 0.61396596, 0.5356328, 0.58990998, 0.73012203, 0.311945, - 0.39822106, 0.20984375, 0.18619301, 0.94437239, 0.7395508, 0.49045881, 0.22741463, 0.25435648, - 0.05802916, 0.43441663, 0.31179588, 0.69634349, 0.37775184, 0.17960368}; - float bias_data[] = {0, 0, 0, 0, 0, 0}; - float output_data[] = { - 0.8388255, 1.7207233, 0.56646764, 1.50962, 0.6184657, 0.7572999, 1.7197044, 2.8834608, 1.0304408, 1.5622743, - 0.95027775, 1.1451806, 2.0191956, 2.9541533, 1.1799709, 1.6366025, 1.3484346, 1.0071151, 1.3740869, 2.1602216, - 1.0846798, 1.7810996, 1.6170096, 0.6889053, 0.8671698, 1.4957678, 0.68065727, 1.0596768, 0.9761665, 0.38881996, - 1.524128, 2.2121127, 1.1506181, 1.330961, 1.8186853, 0.9094476, 2.3777275, 2.5568333, 1.8321692, 1.8297466, - 2.069798, 1.3701197, 2.7548862, 2.0871775, 2.3611763, 1.5387508, 1.6725919, 1.2565864, 2.6130712, 2.0915375, - 1.2955335, 1.6571269, 1.7603228, 1.3315495, 1.0005323, 1.0135669, 1.2701392, 1.8230836, 1.6048919, 1.4224635, - 1.4651375, 1.0251865, 1.0325887, 1.2355556, 1.3313429, 0.6756204, 2.602416, 2.1827717, 1.4354478, 1.6628273, - 2.0171032, 1.0299077, 2.6085434, 1.3310422, 2.1677747, 2.457499, 2.6715999, 1.0225507, 2.5822947, 2.1068158, - 1.6401942, 2.5422354, 2.6937182, 1.3813802, 1.1241511, 1.273326, 1.2024405, 1.4564767, 2.016776, 1.0182433, - 1.228782, 0.83329916, 1.033041, 1.3280122, 1.9437144, 0.6729013, 2.438968, 2.3275855, 2.289177, 1.4376242, - 2.4595368, 1.325891, 2.018128, 2.676854, 1.9685578, 1.8240746, 2.3104675, 1.4958379, 2.474168, 2.6657124, - 1.6738743, 2.336092, 2.3048637, 1.802324, 1.7594845, 1.6022205, 1.2564734, 1.8977238, 1.6991055, 1.8674731, - 0.47793916, 1.2031221, 0.6579696, 1.0724078, 0.96408695, 0.5074543, 1.2399375, 1.410824, 0.56263226, 1.3138686, - 1.4859737, 0.7219256, 1.3437214, 2.0015993, 1.0472497, 1.064316, 1.7359762, 0.9249617, 1.2835678, 2.1866667, - 0.92954785, 2.005947, 1.8761289, 1.2612648, 1.2410495, 1.263778, 0.54638237, 1.8269669, 1.3152003, 0.7890457}; - - for (auto fp16_enable : {false, true}) { - auto *param = CreateParameter(kernel_h, kernel_w, stride_h, stride_w, pad_u, pad_d, pad_l, pad_r, dilation_h, - dilation_w, act_type, input_shape.back()); - TestMain({{input_shape, input_data, VAR}, - {weight_shape, weight_data, CONST_TENSOR}, - {bias_shape, bias_data, CONST_TENSOR}}, - {output_shape, output_data}, param, fp16_enable, fp16_enable ? 1e-2 : 1e-5, 1e-1, true); - } -} -} // namespace mindspore::lite::opencl::test +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ut/src/runtime/kernel/opencl/common.h" +#include "nnacl/conv_parameter.h" + +namespace mindspore::lite::opencl::test { + +class TestOpenCL_DepthwiseConv2d : public CommonTest {}; + +namespace { +// Check and optimize +// PrimitiveType_DepthwiseConv2D: src/ops/populate/depthwise_conv2d_populate.cc +OpParameter *CreateParameter(int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_u, int pad_d, int pad_l, + int pad_r, int dilation_h, int dilation_w, ActType act_type, int input_channel) { + auto *param = test::CreateParameter(schema::PrimitiveType_Conv2DFusion); + param->kernel_h_ = kernel_h; + param->kernel_w_ = kernel_w; + param->stride_h_ = stride_h; + param->stride_w_ = stride_w; + param->pad_u_ = pad_u; + param->pad_d_ = pad_d; + param->pad_l_ = pad_l; + param->pad_r_ = pad_r; + param->input_channel_ = input_channel; + param->output_channel_ = input_channel; + param->group_ = input_channel; + param->dilation_h_ = dilation_h; + param->dilation_w_ = dilation_w; + param->act_type_ = act_type; + return reinterpret_cast(param); +} +} // namespace + +TEST_F(TestOpenCL_DepthwiseConv2d, NoPad) { + int kernel_h = 3; + int kernel_w = 3; + int stride_h = 1; + int stride_w = 1; + int pad_u = 0; + int pad_d = 0; + int pad_l = 0; + int pad_r = 0; + int dilation_h = 1; + int dilation_w = 1; + ActType act_type = ActType_No; + + std::vector input_shape = {1, 4, 4, 4}; + std::vector output_shape = {1, 2, 2, 4}; + std::vector weight_shape = {1, kernel_h, kernel_w, output_shape.back()}; + std::vector bias_shape = {output_shape.back()}; + float input_data[] = { + 0.5488135, 0.71518934, 0.60276335, 0.5448832, 0.4236548, 0.6458941, 0.4375872, 0.891773, + 0.96366274, 0.3834415, 0.79172504, 0.5288949, 0.56804454, 0.92559665, 0.07103606, 0.0871293, + 0.0202184, 0.83261985, 0.77815676, 0.87001216, 0.9786183, 0.7991586, 0.46147937, 0.7805292, + 0.11827443, 0.639921, 0.14335328, 0.9446689, 0.5218483, 0.41466194, 0.2645556, 0.7742337, + 0.45615032, 0.56843394, 0.0187898, 0.6176355, 0.6120957, 0.616934, 0.94374806, 0.6818203, + 0.3595079, 0.43703195, 0.6976312, 0.06022547, 0.6667667, 0.67063785, 0.21038257, 0.12892629, + 0.31542835, 0.36371076, 0.57019675, 0.43860152, 0.9883738, 0.10204481, 0.20887676, 0.16130951, + 0.6531083, 0.2532916, 0.46631077, 0.2444256, 0.15896958, 0.11037514, 0.6563296, 0.13818295, + }; + float bias_data[] = {0, 0, 0, 0}; + float weight_data[] = {0.19658236, 0.36872518, 0.82099324, 0.09710128, 0.8379449, 0.09609841, 0.97645944, 0.4686512, + 0.9767611, 0.6048455, 0.7392636, 0.03918779, 0.28280696, 0.12019656, 0.2961402, 0.11872772, + 0.31798318, 0.41426298, 0.06414749, 0.6924721, 0.56660146, 0.2653895, 0.5232481, 0.09394051, + 0.5759465, 0.9292962, 0.31856894, 0.6674104, 0.13179787, 0.7163272, 0.2894061, 0.18319136, + 0.5865129, 0.02010755, 0.82894003, 0.00469548}; + float output_data[] = {2.9720426, 1.890834, 2.3618119, 2.3867798, 2.5666943, 1.6261611, 2.0977764, 1.6445805, + 2.462798, 1.6643658, 1.6861027, 1.8428761, 2.5156446, 1.5366757, 1.6767557, 1.6905226}; + + for (auto fp16_enable : {false, true}) { + auto *param = CreateParameter(kernel_h, kernel_w, stride_h, stride_w, pad_u, pad_d, pad_l, pad_r, dilation_h, + dilation_w, act_type, input_shape.back()); + TestMain({{input_shape, input_data, VAR}, + {weight_shape, weight_data, CONST_TENSOR}, + {bias_shape, bias_data, CONST_TENSOR}}, + {output_shape, output_data}, param, fp16_enable, fp16_enable ? 1e-2 : 1e-5, 1e-1, true); + } +} + +TEST_F(TestOpenCL_DepthwiseConv2d, Pad) { + int kernel_h = 3; + int kernel_w = 3; + int stride_h = 1; + int stride_w = 1; + int pad_u = 1; + int pad_d = 1; + int pad_l = 1; + int pad_r = 1; + int dilation_h = 1; + int dilation_w = 1; + ActType act_type = ActType_No; + + std::vector input_shape = {1, 3, 3, 5}; + std::vector output_shape = {1, 3, 3, 5}; + std::vector weight_shape = {1, kernel_h, kernel_w, output_shape.back()}; + std::vector bias_shape = {output_shape.back()}; + float input_data[] = {0.5488135, 0.3834415, 0.77815676, 0.9446689, 0.6120957, 0.71518934, 0.79172504, 0.87001216, + 0.5218483, 0.616934, 0.60276335, 0.5288949, 0.9786183, 0.41466194, 0.94374806, 0.5448832, + 0.56804454, 0.7991586, 0.2645556, 0.6818203, 0.4236548, 0.92559665, 0.46147937, 0.7742337, + 0.3595079, 0.6458941, 0.07103606, 0.7805292, 0.45615032, 0.43703195, 0.4375872, 0.0871293, + 0.11827443, 0.56843394, 0.6976312, 0.891773, 0.0202184, 0.639921, 0.0187898, 0.06022547, + 0.96366274, 0.83261985, 0.14335328, 0.6176355, 0.6667667}; + float weight_data[] = {0.67063785, 0.21038257, 0.12892629, 0.31542835, 0.36371076, 0.57019675, 0.43860152, 0.9883738, + 0.10204481, 0.20887676, 0.16130951, 0.6531083, 0.2532916, 0.46631077, 0.2444256, 0.15896958, + 0.11037514, 0.6563296, 0.13818295, 0.19658236, 0.36872518, 0.82099324, 0.09710128, 0.8379449, + 0.09609841, 0.97645944, 0.4686512, 0.9767611, 0.6048455, 0.7392636, 0.03918779, 0.28280696, + 0.12019656, 0.2961402, 0.11872772, 0.31798318, 0.41426298, 0.06414749, 0.6924721, 0.56660146, + 0.2653895, 0.5232481, 0.09394051, 0.5759465, 0.9292962}; + float bias_data[] = {0, 0, 0, 0, 0}; + float output_data[] = {1.189188, 1.0425153, 1.8012011, 0.6074867, 1.2120346, 1.5005531, 0.8346756, 2.4365785, + 0.54975945, 1.6815965, 1.2690231, 0.60214907, 1.6158017, 0.42115876, 0.8854959, 1.1709145, + 1.0929465, 1.3534508, 1.1985044, 1.2932993, 2.4621446, 1.7086457, 2.6977584, 2.1960166, + 2.3769147, 2.3185873, 0.6133741, 0.9687358, 0.9987654, 1.0254729, 0.8368954, 0.74171704, + 0.8749627, 0.8953936, 0.5093431, 1.5496738, 0.54936385, 0.7683113, 1.165742, 1.3682933, + 1.0517888, 0.59817517, 0.75649744, 1.2075498, 0.38804203}; + + for (auto fp16_enable : {false, true}) { + auto *param = CreateParameter(kernel_h, kernel_w, stride_h, stride_w, pad_u, pad_d, pad_l, pad_r, dilation_h, + dilation_w, act_type, input_shape.back()); + TestMain({{input_shape, input_data, VAR}, + {weight_shape, weight_data, CONST_TENSOR}, + {bias_shape, bias_data, CONST_TENSOR}}, + {output_shape, output_data}, param, fp16_enable, fp16_enable ? 1e-2 : 1e-5, 1e-1, true); + } +} + +TEST_F(TestOpenCL_DepthwiseConv2d, NoPad1) { + int kernel_h = 2; + int kernel_w = 2; + int stride_h = 1; + int stride_w = 1; + int pad_u = 0; + int pad_d = 0; + int pad_l = 0; + int pad_r = 0; + int dilation_h = 1; + int dilation_w = 1; + ActType act_type = ActType_No; + + std::vector input_shape = {1, 4, 4, 4}; + std::vector output_shape = {1, 3, 3, 4}; + std::vector weight_shape = {1, kernel_h, kernel_w, output_shape.back()}; + std::vector bias_shape = {output_shape.back()}; + float input_data[] = {0.5488135, 0.71518934, 0.60276335, 0.5448832, 0.4236548, 0.6458941, 0.4375872, 0.891773, + 0.96366274, 0.3834415, 0.79172504, 0.5288949, 0.56804454, 0.92559665, 0.07103606, 0.0871293, + 0.0202184, 0.83261985, 0.77815676, 0.87001216, 0.9786183, 0.7991586, 0.46147937, 0.7805292, + 0.11827443, 0.639921, 0.14335328, 0.9446689, 0.5218483, 0.41466194, 0.2645556, 0.7742337, + 0.45615032, 0.56843394, 0.0187898, 0.6176355, 0.6120957, 0.616934, 0.94374806, 0.6818203, + 0.3595079, 0.43703195, 0.6976312, 0.06022547, 0.6667667, 0.67063785, 0.21038257, 0.12892629, + 0.31542835, 0.36371076, 0.57019675, 0.43860152, 0.9883738, 0.10204481, 0.20887676, 0.16130951, + 0.6531083, 0.2532916, 0.46631077, 0.2444256, 0.15896958, 0.11037514, 0.6563296, 0.13818295}; + float bias_data[] = {0, 0, 0, 0}; + float weight_data[] = {0.19658236, 0.36872517, 0.82099323, 0.09710128, 0.83794491, 0.09609841, + 0.97645947, 0.4686512, 0.97676109, 0.60484552, 0.73926358, 0.03918779, + 0.28280696, 0.12019656, 0.2961402, 0.11872772}; + float output_data[] = {0.3757235, 1.8489048, 1.4467758, 0.6116009, 1.2535334, 1.6583176, 1.2530621, 0.6590755, + 0.5466661, 1.22944, 0.93263525, 0.5317252, 0.7987474, 1.618667, 1.090071, 0.60372007, + 0.773425, 1.5383728, 1.262479, 0.54334986, 0.5755667, 1.3171062, 0.82401496, 0.39336145, + 0.6703031, 0.9385749, 1.018886, 0.40566355, 1.1277528, 0.7773028, 1.5164642, 0.27685273, + 0.86816025, 0.72971237, 1.1791146, 0.12131907}; + + for (auto fp16_enable : {false, true}) { + auto *param = CreateParameter(kernel_h, kernel_w, stride_h, stride_w, pad_u, pad_d, pad_l, pad_r, dilation_h, + dilation_w, act_type, input_shape.back()); + TestMain({{input_shape, input_data, VAR}, + {weight_shape, weight_data, CONST_TENSOR}, + {bias_shape, bias_data, CONST_TENSOR}}, + {output_shape, output_data}, param, fp16_enable, fp16_enable ? 1e-2 : 1e-5, 1e-1, true); + } +} +TEST_F(TestOpenCL_DepthwiseConv2d, Pad1) { + int kernel_h = 3; + int kernel_w = 3; + int stride_h = 1; + int stride_w = 1; + int pad_u = 1; + int pad_d = 1; + int pad_l = 1; + int pad_r = 1; + int dilation_h = 1; + int dilation_w = 1; + ActType act_type = ActType_No; + + std::vector input_shape = {1, 5, 5, 6}; + std::vector output_shape = {1, 5, 5, 6}; + std::vector weight_shape = {1, kernel_h, kernel_w, output_shape.back()}; + std::vector bias_shape = {output_shape.back()}; + float input_data[] = { + 0.5488135, 0.71518934, 0.60276335, 0.5448832, 0.4236548, 0.6458941, 0.4375872, 0.891773, 0.96366274, + 0.3834415, 0.79172504, 0.5288949, 0.56804454, 0.92559665, 0.07103606, 0.0871293, 0.0202184, 0.83261985, + 0.77815676, 0.87001216, 0.9786183, 0.7991586, 0.46147937, 0.7805292, 0.11827443, 0.639921, 0.14335328, + 0.9446689, 0.5218483, 0.41466194, 0.2645556, 0.7742337, 0.45615032, 0.56843394, 0.0187898, 0.6176355, + 0.6120957, 0.616934, 0.94374806, 0.6818203, 0.3595079, 0.43703195, 0.6976312, 0.06022547, 0.6667667, + 0.67063785, 0.21038257, 0.12892629, 0.31542835, 0.36371076, 0.57019675, 0.43860152, 0.9883738, 0.10204481, + 0.20887676, 0.16130951, 0.6531083, 0.2532916, 0.46631077, 0.2444256, 0.15896958, 0.11037514, 0.6563296, + 0.13818295, 0.19658236, 0.36872518, 0.82099324, 0.09710128, 0.8379449, 0.09609841, 0.97645944, 0.4686512, + 0.9767611, 0.6048455, 0.7392636, 0.03918779, 0.28280696, 0.12019656, 0.2961402, 0.11872772, 0.31798318, + 0.41426298, 0.06414749, 0.6924721, 0.56660146, 0.2653895, 0.5232481, 0.09394051, 0.5759465, 0.9292962, + 0.31856894, 0.6674104, 0.13179787, 0.7163272, 0.2894061, 0.18319136, 0.5865129, 0.02010755, 0.82894003, + 0.00469548, 0.6778165, 0.27000797, 0.735194, 0.96218854, 0.24875315, 0.57615733, 0.5920419, 0.5722519, + 0.22308163, 0.952749, 0.44712538, 0.84640867, 0.6994793, 0.29743695, 0.81379783, 0.39650574, 0.8811032, + 0.5812729, 0.8817354, 0.6925316, 0.7252543, 0.50132436, 0.95608366, 0.6439902, 0.42385504, 0.6063932, + 0.0191932, 0.30157483, 0.66017354, 0.2900776, 0.6180154, 0.4287687, 0.13547407, 0.29828233, 0.5699649, + 0.59087276, 0.57432526, 0.6532008, 0.65210325, 0.43141845, 0.8965466, 0.36756188, 0.43586493, 0.89192337, + 0.806194, 0.7038886, 0.10022689, 0.9194826, 0.7142413, 0.998847}; + float weight_data[] = {0.1494483, 0.86812606, 0.16249293, 0.61555956, 0.12381998, 0.84800823, 0.80731896, 0.56910074, + 0.4071833, 0.069167, 0.69742877, 0.45354268, 0.7220556, 0.86638233, 0.97552151, 0.85580334, + 0.01171408, 0.35997806, 0.72999056, 0.17162968, 0.52103661, 0.05433799, 0.19999652, 0.01852179, + 0.7936977, 0.22392469, 0.34535168, 0.92808129, 0.7044144, 0.03183893, 0.16469416, 0.6214784, + 0.57722859, 0.23789282, 0.934214, 0.61396596, 0.5356328, 0.58990998, 0.73012203, 0.311945, + 0.39822106, 0.20984375, 0.18619301, 0.94437239, 0.7395508, 0.49045881, 0.22741463, 0.25435648, + 0.05802916, 0.43441663, 0.31179588, 0.69634349, 0.37775184, 0.17960368}; + float bias_data[] = {0, 0, 0, 0, 0, 0}; + float output_data[] = { + 0.8388255, 1.7207233, 0.56646764, 1.50962, 0.6184657, 0.7572999, 1.7197044, 2.8834608, 1.0304408, 1.5622743, + 0.95027775, 1.1451806, 2.0191956, 2.9541533, 1.1799709, 1.6366025, 1.3484346, 1.0071151, 1.3740869, 2.1602216, + 1.0846798, 1.7810996, 1.6170096, 0.6889053, 0.8671698, 1.4957678, 0.68065727, 1.0596768, 0.9761665, 0.38881996, + 1.524128, 2.2121127, 1.1506181, 1.330961, 1.8186853, 0.9094476, 2.3777275, 2.5568333, 1.8321692, 1.8297466, + 2.069798, 1.3701197, 2.7548862, 2.0871775, 2.3611763, 1.5387508, 1.6725919, 1.2565864, 2.6130712, 2.0915375, + 1.2955335, 1.6571269, 1.7603228, 1.3315495, 1.0005323, 1.0135669, 1.2701392, 1.8230836, 1.6048919, 1.4224635, + 1.4651375, 1.0251865, 1.0325887, 1.2355556, 1.3313429, 0.6756204, 2.602416, 2.1827717, 1.4354478, 1.6628273, + 2.0171032, 1.0299077, 2.6085434, 1.3310422, 2.1677747, 2.457499, 2.6715999, 1.0225507, 2.5822947, 2.1068158, + 1.6401942, 2.5422354, 2.6937182, 1.3813802, 1.1241511, 1.273326, 1.2024405, 1.4564767, 2.016776, 1.0182433, + 1.228782, 0.83329916, 1.033041, 1.3280122, 1.9437144, 0.6729013, 2.438968, 2.3275855, 2.289177, 1.4376242, + 2.4595368, 1.325891, 2.018128, 2.676854, 1.9685578, 1.8240746, 2.3104675, 1.4958379, 2.474168, 2.6657124, + 1.6738743, 2.336092, 2.3048637, 1.802324, 1.7594845, 1.6022205, 1.2564734, 1.8977238, 1.6991055, 1.8674731, + 0.47793916, 1.2031221, 0.6579696, 1.0724078, 0.96408695, 0.5074543, 1.2399375, 1.410824, 0.56263226, 1.3138686, + 1.4859737, 0.7219256, 1.3437214, 2.0015993, 1.0472497, 1.064316, 1.7359762, 0.9249617, 1.2835678, 2.1866667, + 0.92954785, 2.005947, 1.8761289, 1.2612648, 1.2410495, 1.263778, 0.54638237, 1.8269669, 1.3152003, 0.7890457}; + + for (auto fp16_enable : {false, true}) { + auto *param = CreateParameter(kernel_h, kernel_w, stride_h, stride_w, pad_u, pad_d, pad_l, pad_r, dilation_h, + dilation_w, act_type, input_shape.back()); + TestMain({{input_shape, input_data, VAR}, + {weight_shape, weight_data, CONST_TENSOR}, + {bias_shape, bias_data, CONST_TENSOR}}, + {output_shape, output_data}, param, fp16_enable, fp16_enable ? 1e-2 : 1e-5, 1e-1, true); + } +} +} // namespace mindspore::lite::opencl::test diff --git a/mindspore/lite/test/ut/src/utils_test.cc b/mindspore/lite/test/ut/src/utils_test.cc index 97abe5501ba..06bceb86eb1 100644 --- a/mindspore/lite/test/ut/src/utils_test.cc +++ b/mindspore/lite/test/ut/src/utils_test.cc @@ -1,66 +1,66 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include "schema/inner/model_generated.h" -#include "common/common_test.h" -#include "include/errorcode.h" -#include "src/common/log_adapter.h" -#include "mindspore/lite/src/executor/kernel_exec.h" -#include "mindspore/lite/src/litert/kernel_exec_util.h" - -namespace mindspore { -class UtilsTest : public mindspore::CommonTest { - public: - UtilsTest() {} -}; - -TEST_F(UtilsTest, TestSubgraph) { - auto kernel0 = std::make_shared(); - auto kernel1 = std::make_shared(); - auto kernel2 = std::make_shared(); - - auto tensor0 = std::make_shared(); - auto tensor1 = std::make_shared(); - auto tensor2 = std::make_shared(); - auto tensor3 = std::make_shared(); - auto tensor4 = std::make_shared(); - - kernel0->AddOutKernel(kernel1.get()); - kernel1->AddInKernel(kernel0.get()); - kernel1->AddOutKernel(kernel2.get()); - kernel2->AddInKernel(kernel1.get()); - - kernel0->set_in_tensors({tensor0.get(), tensor1.get()}); - kernel0->set_out_tensors({tensor2.get()}); - kernel1->set_in_tensors({tensor2.get()}); - kernel1->set_out_tensors({tensor3.get()}); - kernel2->set_in_tensors({tensor3.get()}); - kernel2->set_out_tensors({tensor4.get()}); - - std::vector kernels = {kernel0.get(), kernel1.get(), kernel2.get()}; - - auto input_kernels = kernel::KernelExecUtil::SubgraphInputNodes(kernels); - ASSERT_EQ(input_kernels.size(), 1); - auto output_kernels = kernel::KernelExecUtil::SubgraphOutputNodes(kernels); - ASSERT_EQ(output_kernels.size(), 1); - auto input_tensors = kernel::KernelExecUtil::SubgraphInputTensors(kernels); - ASSERT_EQ(input_tensors.size(), 2); - auto output_tensors = kernel::KernelExecUtil::SubgraphOutputTensors(kernels); - ASSERT_EQ(output_tensors.size(), 1); -} -} // namespace mindspore +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "schema/inner/model_generated.h" +#include "common/common_test.h" +#include "include/errorcode.h" +#include "src/common/log_adapter.h" +#include "mindspore/lite/src/executor/kernel_exec.h" +#include "mindspore/lite/src/litert/kernel_exec_util.h" + +namespace mindspore { +class UtilsTest : public mindspore::CommonTest { + public: + UtilsTest() {} +}; + +TEST_F(UtilsTest, TestSubgraph) { + auto kernel0 = std::make_shared(); + auto kernel1 = std::make_shared(); + auto kernel2 = std::make_shared(); + + auto tensor0 = std::make_shared(); + auto tensor1 = std::make_shared(); + auto tensor2 = std::make_shared(); + auto tensor3 = std::make_shared(); + auto tensor4 = std::make_shared(); + + kernel0->AddOutKernel(kernel1.get()); + kernel1->AddInKernel(kernel0.get()); + kernel1->AddOutKernel(kernel2.get()); + kernel2->AddInKernel(kernel1.get()); + + kernel0->set_in_tensors({tensor0.get(), tensor1.get()}); + kernel0->set_out_tensors({tensor2.get()}); + kernel1->set_in_tensors({tensor2.get()}); + kernel1->set_out_tensors({tensor3.get()}); + kernel2->set_in_tensors({tensor3.get()}); + kernel2->set_out_tensors({tensor4.get()}); + + std::vector kernels = {kernel0.get(), kernel1.get(), kernel2.get()}; + + auto input_kernels = kernel::KernelExecUtil::SubgraphInputNodes(kernels); + ASSERT_EQ(input_kernels.size(), 1); + auto output_kernels = kernel::KernelExecUtil::SubgraphOutputNodes(kernels); + ASSERT_EQ(output_kernels.size(), 1); + auto input_tensors = kernel::KernelExecUtil::SubgraphInputTensors(kernels); + ASSERT_EQ(input_tensors.size(), 2); + auto output_tensors = kernel::KernelExecUtil::SubgraphOutputTensors(kernels); + ASSERT_EQ(output_tensors.size(), 1); +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/registry/test_data/tf_add.pb b/mindspore/lite/test/ut/tools/converter/registry/test_data/tf_add.pb deleted file mode 100644 index 458a4856c77..00000000000 Binary files a/mindspore/lite/test/ut/tools/converter/registry/test_data/tf_add.pb and /dev/null differ diff --git a/mindspore/lite/tools/common/CMakeLists.txt b/mindspore/lite/tools/common/CMakeLists.txt old mode 100755 new mode 100644 diff --git a/mindspore/lite/tools/converter/adapter/acl/mapper/argmax_fusion_mapper.cc b/mindspore/lite/tools/converter/adapter/acl/mapper/argmax_fusion_mapper.cc index 7978caf1ad5..4e9fe084b5e 100644 --- a/mindspore/lite/tools/converter/adapter/acl/mapper/argmax_fusion_mapper.cc +++ b/mindspore/lite/tools/converter/adapter/acl/mapper/argmax_fusion_mapper.cc @@ -1,83 +1,83 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tools/converter/adapter/acl/mapper/argmax_fusion_mapper.h" -#include -#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h" -#include "src/common/log_util.h" -#include "tools/converter/adapter/acl/mapper/tbe_op_def.h" -#include "ops/op_utils.h" -#include "ops/auto_generate/gen_lite_ops.h" - -namespace mindspore { -namespace lite { -namespace { -constexpr size_t kNameInputNum = 2; -constexpr size_t kNumFlagThree = 3; -} // namespace - -STATUS ArgMaxFusionMapper::Mapper(const CNodePtr &cnode) { - ValueNodePtr value_node = nullptr; - PrimitivePtr src_prim = nullptr; - if (GetValueNodeAndPrimFromCnode(cnode, &value_node, &src_prim) != lite::RET_OK) { - MS_LOG(ERROR) << "Get primitive from cnode failed."; - return lite::RET_ERROR; - } - if (cnode->size() != kNameInputNum) { - MS_LOG(ERROR) << "Input size of argmax must be " << kNameInputNum << " real size: " << cnode->size(); - return lite::RET_ERROR; - } - // ArgMaxV2 doesn't have keep_dims attr, replace by ArgMaxWithValue - auto keep_dims_ptr = src_prim->GetAttr(ops::kKeepDims); - if (keep_dims_ptr != nullptr && GetValue(keep_dims_ptr)) { - // adjust axis and keep_dims to input to adapt mindir. - auto axis_ptr = src_prim->GetAttr(ops::kAxis); - CHECK_NULL_RETURN(axis_ptr); - auto axis_value_node = NewValueNode(GetValue(axis_ptr)); - MS_CHECK_TRUE_MSG(axis_value_node != nullptr, lite::RET_ERROR, "New value node for axis failed."); - std::vector shape_vec = {}; - auto axis_abstract = std::make_shared(kInt64, shape_vec); - CHECK_NULL_RETURN(axis_abstract); - axis_value_node->set_abstract(axis_abstract); - cnode->add_input(axis_value_node); - - auto keep_dims_value_node = NewValueNode(GetValue(keep_dims_ptr)); - MS_CHECK_TRUE_MSG(keep_dims_value_node != nullptr, lite::RET_ERROR, "New value node for keep_dims failed."); - auto keep_dims_abstract = std::make_shared(kBool, shape_vec); - CHECK_NULL_RETURN(keep_dims_abstract); - keep_dims_value_node->set_abstract(keep_dims_abstract); - cnode->add_input(keep_dims_value_node); - - auto argmax = std::make_shared(); - CHECK_NULL_RETURN(argmax); - auto dst_prim = argmax->GetPrim(); - CHECK_NULL_RETURN(dst_prim); - dst_prim->SetAttrs(src_prim->attrs()); - value_node->set_value(dst_prim); - return lite::RET_OK; - } - - auto dst_prim = std::make_shared(); - CHECK_NULL_RETURN(dst_prim); - dst_prim->AddAttr("output_type", TypeIdToType(kNumberTypeInt32)); - dst_prim->SetAttrs(src_prim->attrs()); - value_node->set_value(dst_prim); - return lite::RET_OK; -} - -REGISTER_PRIMITIVE_MAPPER(kNameArgMaxFusion, ArgMaxFusionMapper) -} // namespace lite -} // namespace mindspore +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/adapter/acl/mapper/argmax_fusion_mapper.h" +#include +#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h" +#include "src/common/log_util.h" +#include "tools/converter/adapter/acl/mapper/tbe_op_def.h" +#include "ops/op_utils.h" +#include "ops/auto_generate/gen_lite_ops.h" + +namespace mindspore { +namespace lite { +namespace { +constexpr size_t kNameInputNum = 2; +constexpr size_t kNumFlagThree = 3; +} // namespace + +STATUS ArgMaxFusionMapper::Mapper(const CNodePtr &cnode) { + ValueNodePtr value_node = nullptr; + PrimitivePtr src_prim = nullptr; + if (GetValueNodeAndPrimFromCnode(cnode, &value_node, &src_prim) != lite::RET_OK) { + MS_LOG(ERROR) << "Get primitive from cnode failed."; + return lite::RET_ERROR; + } + if (cnode->size() != kNameInputNum) { + MS_LOG(ERROR) << "Input size of argmax must be " << kNameInputNum << " real size: " << cnode->size(); + return lite::RET_ERROR; + } + // ArgMaxV2 doesn't have keep_dims attr, replace by ArgMaxWithValue + auto keep_dims_ptr = src_prim->GetAttr(ops::kKeepDims); + if (keep_dims_ptr != nullptr && GetValue(keep_dims_ptr)) { + // adjust axis and keep_dims to input to adapt mindir. + auto axis_ptr = src_prim->GetAttr(ops::kAxis); + CHECK_NULL_RETURN(axis_ptr); + auto axis_value_node = NewValueNode(GetValue(axis_ptr)); + MS_CHECK_TRUE_MSG(axis_value_node != nullptr, lite::RET_ERROR, "New value node for axis failed."); + std::vector shape_vec = {}; + auto axis_abstract = std::make_shared(kInt64, shape_vec); + CHECK_NULL_RETURN(axis_abstract); + axis_value_node->set_abstract(axis_abstract); + cnode->add_input(axis_value_node); + + auto keep_dims_value_node = NewValueNode(GetValue(keep_dims_ptr)); + MS_CHECK_TRUE_MSG(keep_dims_value_node != nullptr, lite::RET_ERROR, "New value node for keep_dims failed."); + auto keep_dims_abstract = std::make_shared(kBool, shape_vec); + CHECK_NULL_RETURN(keep_dims_abstract); + keep_dims_value_node->set_abstract(keep_dims_abstract); + cnode->add_input(keep_dims_value_node); + + auto argmax = std::make_shared(); + CHECK_NULL_RETURN(argmax); + auto dst_prim = argmax->GetPrim(); + CHECK_NULL_RETURN(dst_prim); + dst_prim->SetAttrs(src_prim->attrs()); + value_node->set_value(dst_prim); + return lite::RET_OK; + } + + auto dst_prim = std::make_shared(); + CHECK_NULL_RETURN(dst_prim); + dst_prim->AddAttr("output_type", TypeIdToType(kNumberTypeInt32)); + dst_prim->SetAttrs(src_prim->attrs()); + value_node->set_value(dst_prim); + return lite::RET_OK; +} + +REGISTER_PRIMITIVE_MAPPER(kNameArgMaxFusion, ArgMaxFusionMapper) +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/adapter/dpico/common/graph_output_name_keeper.h b/mindspore/lite/tools/converter/adapter/dpico/common/graph_output_name_keeper.h old mode 100755 new mode 100644 diff --git a/mindspore/lite/tools/converter/adapter/dpico/infer/dpico_lstm_onnx_infer.cc b/mindspore/lite/tools/converter/adapter/dpico/infer/dpico_lstm_onnx_infer.cc old mode 100755 new mode 100644 diff --git a/mindspore/lite/tools/converter/adapter/dpico/infer/dpico_lstm_onnx_infer.h b/mindspore/lite/tools/converter/adapter/dpico/infer/dpico_lstm_onnx_infer.h old mode 100755 new mode 100644 diff --git a/mindspore/lite/tools/converter/adapter/dpico/infer/dpico_maxunpool_infer.cc b/mindspore/lite/tools/converter/adapter/dpico/infer/dpico_maxunpool_infer.cc old mode 100755 new mode 100644 diff --git a/mindspore/lite/tools/converter/adapter/dpico/infer/dpico_maxunpool_infer.h b/mindspore/lite/tools/converter/adapter/dpico/infer/dpico_maxunpool_infer.h old mode 100755 new mode 100644 diff --git a/mindspore/lite/tools/converter/adapter/dpico/mapper/maxunpool_mapper.cc b/mindspore/lite/tools/converter/adapter/dpico/mapper/maxunpool_mapper.cc old mode 100755 new mode 100644 diff --git a/mindspore/lite/tools/converter/adapter/dpico/mapper/maxunpool_mapper.h b/mindspore/lite/tools/converter/adapter/dpico/mapper/maxunpool_mapper.h old mode 100755 new mode 100644 diff --git a/mindspore/lite/tools/converter/adapter/dpico/parser/onnx/onnx_lstm_parser.cc b/mindspore/lite/tools/converter/adapter/dpico/parser/onnx/onnx_lstm_parser.cc old mode 100755 new mode 100644 diff --git a/mindspore/lite/tools/converter/adapter/dpico/parser/onnx/onnx_lstm_parser.h b/mindspore/lite/tools/converter/adapter/dpico/parser/onnx/onnx_lstm_parser.h old mode 100755 new mode 100644 diff --git a/mindspore/lite/tools/converter/adapter/dpico/parser/onnx/onnx_maxunpool_parser.cc b/mindspore/lite/tools/converter/adapter/dpico/parser/onnx/onnx_maxunpool_parser.cc old mode 100755 new mode 100644 diff --git a/mindspore/lite/tools/converter/adapter/dpico/parser/onnx/onnx_maxunpool_parser.h b/mindspore/lite/tools/converter/adapter/dpico/parser/onnx/onnx_maxunpool_parser.h old mode 100755 new mode 100644 diff --git a/mindspore/lite/tools/converter/legacy_optimizer/CMakeLists.txt b/mindspore/lite/tools/converter/legacy_optimizer/CMakeLists.txt old mode 100755 new mode 100644 diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/CMakeLists.txt b/mindspore/lite/tools/converter/legacy_optimizer/fusion/CMakeLists.txt old mode 100755 new mode 100644 diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt b/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt old mode 100755 new mode 100644 diff --git a/mindspore/lite/tools/converter/ops/ops_def.h b/mindspore/lite/tools/converter/ops/ops_def.h index b941602d4e8..1fac0a608b1 100644 --- a/mindspore/lite/tools/converter/ops/ops_def.h +++ b/mindspore/lite/tools/converter/ops/ops_def.h @@ -1,65 +1,65 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_OPS_OPS_DEF_H_ -#define MINDSPORE_LITE_TOOLS_CONVERTER_OPS_OPS_DEF_H_ -#include "ops/primitive_c.h" -using mindspore::ops::PrimitiveC; - -namespace mindspore { -namespace lite { -#define ADD_CONVERTER_ONLY_OP(name) \ - constexpr auto kName##name = #name; \ - class name : public PrimitiveC { \ - public: \ - name() : PrimitiveC(kName##name) {} \ - ~name() = default; \ - MS_DECLARE_PARENT(name, PrimitiveC); \ - }; - -ADD_CONVERTER_ONLY_OP(Enter); -ADD_CONVERTER_ONLY_OP(Exit); -ADD_CONVERTER_ONLY_OP(If); -ADD_CONVERTER_ONLY_OP(LoopCond); -ADD_CONVERTER_ONLY_OP(NextIteration); -ADD_CONVERTER_ONLY_OP(TensorArrayGatherV3); -ADD_CONVERTER_ONLY_OP(TensorArrayReadV3); -ADD_CONVERTER_ONLY_OP(TensorArrayScatterV3); -ADD_CONVERTER_ONLY_OP(TensorArraySizeV3); -ADD_CONVERTER_ONLY_OP(TensorArrayV3); -ADD_CONVERTER_ONLY_OP(TensorArrayWriteV3); -ADD_CONVERTER_ONLY_OP(Constant); -ADD_CONVERTER_ONLY_OP(Merge); -ADD_CONVERTER_ONLY_OP(Einsum); -ADD_CONVERTER_ONLY_OP(QuantizeLinear); -ADD_CONVERTER_ONLY_OP(DequantizeLinear); -ADD_CONVERTER_ONLY_OP(FakeQuantWithMinMaxVars); -ADD_CONVERTER_ONLY_OP(MegatronAllReduce); -ADD_CONVERTER_ONLY_OP(MegatronLinearAllGather); -ADD_CONVERTER_ONLY_OP(MegatronMakeViewlessTensor); -ADD_CONVERTER_ONLY_OP(MegatronScaledMaskedSoftmax); -ADD_CONVERTER_ONLY_OP(Shrink); -ADD_CONVERTER_ONLY_OP(TfIdfVectorizer); -ADD_CONVERTER_ONLY_OP(MVN); -ADD_CONVERTER_ONLY_OP(RandomUniformLike); -ADD_CONVERTER_ONLY_OP(Rot90); -ADD_CONVERTER_ONLY_OP(SwinAttentionFFN); -ADD_CONVERTER_ONLY_OP(SwinTransformerLnQKV); -ADD_CONVERTER_ONLY_OP(SwinAttentionScore); -} // namespace lite -} // namespace mindspore - -#endif // LITE_MINDSPORE_LITE_C_OPS_NEXTITERATION_H_ +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_OPS_OPS_DEF_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_OPS_OPS_DEF_H_ +#include "ops/primitive_c.h" +using mindspore::ops::PrimitiveC; + +namespace mindspore { +namespace lite { +#define ADD_CONVERTER_ONLY_OP(name) \ + constexpr auto kName##name = #name; \ + class name : public PrimitiveC { \ + public: \ + name() : PrimitiveC(kName##name) {} \ + ~name() = default; \ + MS_DECLARE_PARENT(name, PrimitiveC); \ + }; + +ADD_CONVERTER_ONLY_OP(Enter); +ADD_CONVERTER_ONLY_OP(Exit); +ADD_CONVERTER_ONLY_OP(If); +ADD_CONVERTER_ONLY_OP(LoopCond); +ADD_CONVERTER_ONLY_OP(NextIteration); +ADD_CONVERTER_ONLY_OP(TensorArrayGatherV3); +ADD_CONVERTER_ONLY_OP(TensorArrayReadV3); +ADD_CONVERTER_ONLY_OP(TensorArrayScatterV3); +ADD_CONVERTER_ONLY_OP(TensorArraySizeV3); +ADD_CONVERTER_ONLY_OP(TensorArrayV3); +ADD_CONVERTER_ONLY_OP(TensorArrayWriteV3); +ADD_CONVERTER_ONLY_OP(Constant); +ADD_CONVERTER_ONLY_OP(Merge); +ADD_CONVERTER_ONLY_OP(Einsum); +ADD_CONVERTER_ONLY_OP(QuantizeLinear); +ADD_CONVERTER_ONLY_OP(DequantizeLinear); +ADD_CONVERTER_ONLY_OP(FakeQuantWithMinMaxVars); +ADD_CONVERTER_ONLY_OP(MegatronAllReduce); +ADD_CONVERTER_ONLY_OP(MegatronLinearAllGather); +ADD_CONVERTER_ONLY_OP(MegatronMakeViewlessTensor); +ADD_CONVERTER_ONLY_OP(MegatronScaledMaskedSoftmax); +ADD_CONVERTER_ONLY_OP(Shrink); +ADD_CONVERTER_ONLY_OP(TfIdfVectorizer); +ADD_CONVERTER_ONLY_OP(MVN); +ADD_CONVERTER_ONLY_OP(RandomUniformLike); +ADD_CONVERTER_ONLY_OP(Rot90); +ADD_CONVERTER_ONLY_OP(SwinAttentionFFN); +ADD_CONVERTER_ONLY_OP(SwinTransformerLnQKV); +ADD_CONVERTER_ONLY_OP(SwinAttentionScore); +} // namespace lite +} // namespace mindspore + +#endif // LITE_MINDSPORE_LITE_C_OPS_NEXTITERATION_H_ diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_custom_op_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_custom_op_parser.cc index 407a0fa6f67..eeb0bb5dc7d 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_custom_op_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_custom_op_parser.cc @@ -1,139 +1,139 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tools/converter/parser/onnx/onnx_custom_op_parser.h" -#include -#include -#include "tools/converter/parser/onnx/onnx_model_parser.h" -#include "tools/converter/ops/ops_def.h" -#include "mindspore/core/ops/op_name.h" -#include "nnacl/op_base.h" -#include "ops/affine_grid.h" -#include "ops/histogram.h" -#include "ops/auto_generate/gen_lite_ops.h" -#include "ops/xlogy.h" -#include "ops/op_name.h" - -namespace mindspore { -namespace lite { -PrimitiveCPtr OnnxAffineGridParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { - auto prim = std::make_shared(); - MS_CHECK_TRUE_RET(prim != nullptr, nullptr); - for (const auto &onnx_node_attr : onnx_node.attribute()) { - const auto &attribute_name = onnx_node_attr.name(); - if (attribute_name == "align_corners") { - prim->set_align_corners(static_cast(onnx_node_attr.i())); - } else if (attribute_name == "size") { - auto prim_c = prim->GetPrim(); - MS_CHECK_TRUE_RET(prim_c != nullptr, nullptr); - auto size_attr_size = onnx_node_attr.ints().size(); - std::vector size; - size.reserve(size_attr_size); - for (int idx = 0; idx < size_attr_size; idx++) { - size.emplace_back(onnx_node_attr.ints(idx)); - } - prim_c->AddAttr(ops::kSize, MakeValue>(size)); - } - } - MS_CHECK_TRUE_RET(prim != nullptr, nullptr); - - return prim->GetPrim(); -} - -PrimitiveCPtr OnnxHistogramParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { - auto prim = std::make_shared(); - MS_CHECK_TRUE_RET(prim != nullptr, nullptr); - for (const auto &onnx_node_attr : onnx_node.attribute()) { - const auto &attribute_name = onnx_node_attr.name(); - if (attribute_name == "bins") { - prim->set_bins(static_cast(onnx_node_attr.i())); - } else if (attribute_name == "max") { - prim->set_max(static_cast(onnx_node_attr.i())); - } else if (attribute_name == "min") { - prim->set_min(static_cast(onnx_node_attr.i())); - } - } - MS_CHECK_TRUE_RET(prim != nullptr, nullptr); - - return prim->GetPrim(); -} - -PrimitiveCPtr OnnxLogicalNotParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { - auto prim = std::make_shared(); - MS_CHECK_TRUE_RET(prim != nullptr, nullptr); - return prim->GetPrim(); -} - -PrimitiveCPtr OnnxRot90Parser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { - auto prim = std::make_shared(); - MS_CHECK_TRUE_RET(prim != nullptr, nullptr); - for (const auto &onnx_node_attr : onnx_node.attribute()) { - const auto &attribute_name = onnx_node_attr.name(); - if (attribute_name == "dims") { - auto dim_size = onnx_node_attr.ints().size(); - std::vector dims; - dims.reserve(dim_size); - for (int idx = 0; idx < dim_size; idx++) { - dims.emplace_back(onnx_node_attr.ints(idx)); - } - prim->AddAttr(ops::kDims, MakeValue(dims)); - } else if (attribute_name == "axis") { - auto dim_size = onnx_node_attr.ints().size(); - std::vector axis; - axis.reserve(dim_size); - for (int idx = 0; idx < dim_size; idx++) { - axis.emplace_back(onnx_node_attr.ints(idx)); - } - prim->AddAttr(ops::kAxis, MakeValue(axis)); - } - } - return prim; -} - -PrimitiveCPtr OnnxXlogyParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { - auto prim = std::make_shared(); - MS_CHECK_TRUE_RET(prim != nullptr, nullptr); - return prim->GetPrim(); -} - -PrimitiveCPtr OnnxRandomUniformLikeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { - auto prim = std::make_unique(); - MS_CHECK_TRUE_RET(prim != nullptr, nullptr); - for (const auto &onnx_node_attr : onnx_node.attribute()) { - const auto &attribute_name = onnx_node_attr.name(); - if (attribute_name == "dtype") { - auto onnx_dtype = static_cast(onnx_node_attr.i()); - auto data_type = OnnxNodeParser::GetDataTypeFromOnnx(onnx_dtype); - prim->AddAttr("dtype", MakeValue(static_cast(data_type))); - } else if (attribute_name == "high") { - prim->AddAttr("high", MakeValue(static_cast(onnx_node_attr.f()))); - } else if (attribute_name == "low") { - prim->AddAttr("low", MakeValue(static_cast(onnx_node_attr.f()))); - } else if (attribute_name == "seed") { - prim->AddAttr(ops::kSeed, MakeValue(static_cast(onnx_node_attr.f()))); - } - } - return prim; -} - -OnnxNodeRegistrar g_onnxAffineGridParser("affine_grid", new OnnxAffineGridParser()); -OnnxNodeRegistrar g_onnxHistogramParser("histc", new OnnxHistogramParser()); -OnnxNodeRegistrar g_onnxLogicalNotParser("logical_not", new OnnxLogicalNotParser()); -OnnxNodeRegistrar g_onnxRot90Parser("rot90", new OnnxRot90Parser()); -OnnxNodeRegistrar g_onnxXlogyParser("xlogy", new OnnxXlogyParser()); -OnnxNodeRegistrar g_onnxRandomUniformLikeParser("RandomUniformLike", new OnnxRandomUniformLikeParser()); -} // namespace lite -} // namespace mindspore +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/onnx/onnx_custom_op_parser.h" +#include +#include +#include "tools/converter/parser/onnx/onnx_model_parser.h" +#include "tools/converter/ops/ops_def.h" +#include "mindspore/core/ops/op_name.h" +#include "nnacl/op_base.h" +#include "ops/affine_grid.h" +#include "ops/histogram.h" +#include "ops/auto_generate/gen_lite_ops.h" +#include "ops/xlogy.h" +#include "ops/op_name.h" + +namespace mindspore { +namespace lite { +PrimitiveCPtr OnnxAffineGridParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { + auto prim = std::make_shared(); + MS_CHECK_TRUE_RET(prim != nullptr, nullptr); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "align_corners") { + prim->set_align_corners(static_cast(onnx_node_attr.i())); + } else if (attribute_name == "size") { + auto prim_c = prim->GetPrim(); + MS_CHECK_TRUE_RET(prim_c != nullptr, nullptr); + auto size_attr_size = onnx_node_attr.ints().size(); + std::vector size; + size.reserve(size_attr_size); + for (int idx = 0; idx < size_attr_size; idx++) { + size.emplace_back(onnx_node_attr.ints(idx)); + } + prim_c->AddAttr(ops::kSize, MakeValue>(size)); + } + } + MS_CHECK_TRUE_RET(prim != nullptr, nullptr); + + return prim->GetPrim(); +} + +PrimitiveCPtr OnnxHistogramParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { + auto prim = std::make_shared(); + MS_CHECK_TRUE_RET(prim != nullptr, nullptr); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "bins") { + prim->set_bins(static_cast(onnx_node_attr.i())); + } else if (attribute_name == "max") { + prim->set_max(static_cast(onnx_node_attr.i())); + } else if (attribute_name == "min") { + prim->set_min(static_cast(onnx_node_attr.i())); + } + } + MS_CHECK_TRUE_RET(prim != nullptr, nullptr); + + return prim->GetPrim(); +} + +PrimitiveCPtr OnnxLogicalNotParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { + auto prim = std::make_shared(); + MS_CHECK_TRUE_RET(prim != nullptr, nullptr); + return prim->GetPrim(); +} + +PrimitiveCPtr OnnxRot90Parser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { + auto prim = std::make_shared(); + MS_CHECK_TRUE_RET(prim != nullptr, nullptr); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "dims") { + auto dim_size = onnx_node_attr.ints().size(); + std::vector dims; + dims.reserve(dim_size); + for (int idx = 0; idx < dim_size; idx++) { + dims.emplace_back(onnx_node_attr.ints(idx)); + } + prim->AddAttr(ops::kDims, MakeValue(dims)); + } else if (attribute_name == "axis") { + auto dim_size = onnx_node_attr.ints().size(); + std::vector axis; + axis.reserve(dim_size); + for (int idx = 0; idx < dim_size; idx++) { + axis.emplace_back(onnx_node_attr.ints(idx)); + } + prim->AddAttr(ops::kAxis, MakeValue(axis)); + } + } + return prim; +} + +PrimitiveCPtr OnnxXlogyParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { + auto prim = std::make_shared(); + MS_CHECK_TRUE_RET(prim != nullptr, nullptr); + return prim->GetPrim(); +} + +PrimitiveCPtr OnnxRandomUniformLikeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { + auto prim = std::make_unique(); + MS_CHECK_TRUE_RET(prim != nullptr, nullptr); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "dtype") { + auto onnx_dtype = static_cast(onnx_node_attr.i()); + auto data_type = OnnxNodeParser::GetDataTypeFromOnnx(onnx_dtype); + prim->AddAttr("dtype", MakeValue(static_cast(data_type))); + } else if (attribute_name == "high") { + prim->AddAttr("high", MakeValue(static_cast(onnx_node_attr.f()))); + } else if (attribute_name == "low") { + prim->AddAttr("low", MakeValue(static_cast(onnx_node_attr.f()))); + } else if (attribute_name == "seed") { + prim->AddAttr(ops::kSeed, MakeValue(static_cast(onnx_node_attr.f()))); + } + } + return prim; +} + +OnnxNodeRegistrar g_onnxAffineGridParser("affine_grid", new OnnxAffineGridParser()); +OnnxNodeRegistrar g_onnxHistogramParser("histc", new OnnxHistogramParser()); +OnnxNodeRegistrar g_onnxLogicalNotParser("logical_not", new OnnxLogicalNotParser()); +OnnxNodeRegistrar g_onnxRot90Parser("rot90", new OnnxRot90Parser()); +OnnxNodeRegistrar g_onnxXlogyParser("xlogy", new OnnxXlogyParser()); +OnnxNodeRegistrar g_onnxRandomUniformLikeParser("RandomUniformLike", new OnnxRandomUniformLikeParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_custom_op_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_custom_op_parser.h index 04f50cc9697..d5d78cfc8f1 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_custom_op_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_custom_op_parser.h @@ -1,75 +1,75 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ONNX_CUSTOM_OP_PARSER_H_ -#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ONNX_CUSTOM_OP_PARSER_H_ - -#include "tools/converter/parser/onnx/onnx_node_parser.h" -#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" - -namespace mindspore { -namespace lite { -class OnnxAffineGridParser : public OnnxNodeParser { - public: - OnnxAffineGridParser() : OnnxNodeParser("affine_grid") {} - ~OnnxAffineGridParser() override = default; - - PrimitiveCPtr Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; -}; - -class OnnxHistogramParser : public OnnxNodeParser { - public: - OnnxHistogramParser() : OnnxNodeParser("Histogram") {} - ~OnnxHistogramParser() override = default; - - PrimitiveCPtr Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; -}; - -class OnnxLogicalNotParser : public OnnxNodeParser { - public: - OnnxLogicalNotParser() : OnnxNodeParser("logical_not") {} - ~OnnxLogicalNotParser() override = default; - - PrimitiveCPtr Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; -}; - -class OnnxRot90Parser : public OnnxNodeParser { - public: - OnnxRot90Parser() : OnnxNodeParser("Rot90") {} - ~OnnxRot90Parser() override = default; - - PrimitiveCPtr Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; -}; - -class OnnxXlogyParser : public OnnxNodeParser { - public: - OnnxXlogyParser() : OnnxNodeParser("xlogy") {} - ~OnnxXlogyParser() override = default; - - PrimitiveCPtr Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; -}; - -class OnnxRandomUniformLikeParser : public OnnxNodeParser { - public: - OnnxRandomUniformLikeParser() : OnnxNodeParser("RandomUniformLike") {} - ~OnnxRandomUniformLikeParser() override = default; - - PrimitiveCPtr Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; -}; -} // namespace lite -} // namespace mindspore - -#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ONNX_CUSTOM_OP_PARSER_H_ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ONNX_CUSTOM_OP_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ONNX_CUSTOM_OP_PARSER_H_ + +#include "tools/converter/parser/onnx/onnx_node_parser.h" +#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxAffineGridParser : public OnnxNodeParser { + public: + OnnxAffineGridParser() : OnnxNodeParser("affine_grid") {} + ~OnnxAffineGridParser() override = default; + + PrimitiveCPtr Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; +}; + +class OnnxHistogramParser : public OnnxNodeParser { + public: + OnnxHistogramParser() : OnnxNodeParser("Histogram") {} + ~OnnxHistogramParser() override = default; + + PrimitiveCPtr Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; +}; + +class OnnxLogicalNotParser : public OnnxNodeParser { + public: + OnnxLogicalNotParser() : OnnxNodeParser("logical_not") {} + ~OnnxLogicalNotParser() override = default; + + PrimitiveCPtr Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; +}; + +class OnnxRot90Parser : public OnnxNodeParser { + public: + OnnxRot90Parser() : OnnxNodeParser("Rot90") {} + ~OnnxRot90Parser() override = default; + + PrimitiveCPtr Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; +}; + +class OnnxXlogyParser : public OnnxNodeParser { + public: + OnnxXlogyParser() : OnnxNodeParser("xlogy") {} + ~OnnxXlogyParser() override = default; + + PrimitiveCPtr Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; +}; + +class OnnxRandomUniformLikeParser : public OnnxNodeParser { + public: + OnnxRandomUniformLikeParser() : OnnxNodeParser("RandomUniformLike") {} + ~OnnxRandomUniformLikeParser() override = default; + + PrimitiveCPtr Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ONNX_CUSTOM_OP_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_sparse_reshape.cc b/mindspore/lite/tools/converter/parser/tf/tf_sparse_reshape.cc index 0827f17aa2c..d609ca36be8 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_sparse_reshape.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_sparse_reshape.cc @@ -1,40 +1,40 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "tools/converter/parser/tf/tf_sparse_reshape.h" -#include -#include -#include -#include -#include "tools/converter/parser/tf/tf_node_parser_registry.h" -#include "ops/sparse_reshape.h" - -namespace mindspore { -namespace lite { -PrimitiveCPtr TFSparseReshapeParser::Parse(const tensorflow::NodeDef &tf_op, - const std::map &tf_node_map, - std::vector *inputs, int *output_size) { - auto prim = std::make_unique(); - MS_CHECK_TRUE_RET(prim != nullptr, nullptr); - - *output_size = C2NUM; - for (int i = 0; i < tf_op.input_size(); i++) { - inputs->emplace_back(tf_op.input(i)); - } - return prim->GetPrim(); -} -TFNodeRegistrar g_tfSparseReshapeParser("SparseReshape", new TFSparseReshapeParser()); -} // namespace lite -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_sparse_reshape.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" +#include "ops/sparse_reshape.h" + +namespace mindspore { +namespace lite { +PrimitiveCPtr TFSparseReshapeParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, + std::vector *inputs, int *output_size) { + auto prim = std::make_unique(); + MS_CHECK_TRUE_RET(prim != nullptr, nullptr); + + *output_size = C2NUM; + for (int i = 0; i < tf_op.input_size(); i++) { + inputs->emplace_back(tf_op.input(i)); + } + return prim->GetPrim(); +} +TFNodeRegistrar g_tfSparseReshapeParser("SparseReshape", new TFSparseReshapeParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_sparse_segment_sum.cc b/mindspore/lite/tools/converter/parser/tf/tf_sparse_segment_sum.cc index 2f392180067..c85714cea48 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_sparse_segment_sum.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_sparse_segment_sum.cc @@ -1,40 +1,40 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "tools/converter/parser/tf/tf_sparse_segment_sum.h" -#include -#include -#include -#include -#include "tools/converter/parser/tf/tf_node_parser_registry.h" -#include "ops/sparse_segment_sum.h" - -namespace mindspore { -namespace lite { -PrimitiveCPtr TFSparseSegmentSumParser::Parse(const tensorflow::NodeDef &tf_op, - const std::map &tf_node_map, - std::vector *inputs, int *output_size) { - auto prim = std::make_unique(); - MS_CHECK_TRUE_RET(prim != nullptr, nullptr); - - *output_size = C1NUM; - for (int i = 0; i < tf_op.input_size(); i++) { - inputs->emplace_back(tf_op.input(i)); - } - return prim->GetPrim(); -} -TFNodeRegistrar g_tfSparseSegmentSumParser("SparseSegmentSum", new TFSparseSegmentSumParser()); -} // namespace lite -} // namespace mindspore +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_sparse_segment_sum.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" +#include "ops/sparse_segment_sum.h" + +namespace mindspore { +namespace lite { +PrimitiveCPtr TFSparseSegmentSumParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, + std::vector *inputs, int *output_size) { + auto prim = std::make_unique(); + MS_CHECK_TRUE_RET(prim != nullptr, nullptr); + + *output_size = C1NUM; + for (int i = 0; i < tf_op.input_size(); i++) { + inputs->emplace_back(tf_op.input(i)); + } + return prim->GetPrim(); +} +TFNodeRegistrar g_tfSparseSegmentSumParser("SparseSegmentSum", new TFSparseSegmentSumParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_sparse_segment_sum.h b/mindspore/lite/tools/converter/parser/tf/tf_sparse_segment_sum.h index 84610055650..f36b3324fa3 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_sparse_segment_sum.h +++ b/mindspore/lite/tools/converter/parser/tf/tf_sparse_segment_sum.h @@ -1,38 +1,38 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SPARSE_SPARSE_SEGMENT_SUM_PARSER_H_ -#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SPARSE_SPARSE_SEGMENT_SUM_PARSER_H_ - -#include -#include -#include -#include -#include "tools/converter/parser/tf/tf_node_parser.h" - -namespace mindspore { -namespace lite { -class TFSparseSegmentSumParser : public TFNodeParser { - public: - TFSparseSegmentSumParser() = default; - ~TFSparseSegmentSumParser() override = default; - - PrimitiveCPtr Parse(const tensorflow::NodeDef &tf_op, - const std::map &tf_node_map, - std::vector *inputs, int *output_size) override; -}; -} // namespace lite -} // namespace mindspore -#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SPARSE_SPARSE_SEGMENT_SUM_PARSER_H_ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SPARSE_SPARSE_SEGMENT_SUM_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SPARSE_SPARSE_SEGMENT_SUM_PARSER_H_ + +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFSparseSegmentSumParser : public TFNodeParser { + public: + TFSparseSegmentSumParser() = default; + ~TFSparseSegmentSumParser() override = default; + + PrimitiveCPtr Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, + std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SPARSE_SPARSE_SEGMENT_SUM_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tflite/schema.fbs b/mindspore/lite/tools/converter/parser/tflite/schema.fbs index 1919b8bea14..a8bdf5e067a 100644 --- a/mindspore/lite/tools/converter/parser/tflite/schema.fbs +++ b/mindspore/lite/tools/converter/parser/tflite/schema.fbs @@ -1,1094 +1,1094 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Revision History -// Version 0: Initial version. -// Version 1: Add subgraphs to schema. -// Version 2: Rename operators to conform to NN API. -// Version 3: Move buffer data from Model.Subgraph.Tensors to Model.Buffers. - -namespace tflite; - -// This corresponds to the version. -file_identifier "TFL3"; -// File extension of any written files. -file_extension "tflite"; - -// IMPORTANT: All new members of tables, enums and unions must be added at the -// end to ensure backwards compatibility. - -// The type of data stored in a tensor. -enum TensorType : byte { - FLOAT32 = 0, - FLOAT16 = 1, - INT32 = 2, - UINT8 = 3, - INT64 = 4, - STRING = 5, - BOOL = 6, - INT16 = 7, - COMPLEX64 = 8, - INT8 = 9, - FLOAT64 = 10, -} - -// Custom quantization parameters for experimenting with new quantization -// techniques. -table CustomQuantization { - custom:[ubyte] (force_align: 16); -} - -// Represents a specific quantization technique's parameters. -union QuantizationDetails { - CustomQuantization, -} - -// Parameters for converting a quantized tensor back to float. -table QuantizationParameters { - // These four parameters are the asymmetric linear quantization parameters. - // Given a quantized value q, the corresponding float value f should be: - // f = scale * (q - zero_point) - // For other quantization types, the QuantizationDetails below is used. - min:[float]; // For importing back into tensorflow. - max:[float]; // For importing back into tensorflow. - scale:[float]; // For dequantizing the tensor's values. - zero_point:[long]; - - // If this is not none, the other quantization parameters (i.e. min, max, - // scale, zero_point fields above) are ignored and the value of the - // QuantizationDetails union should be used. - details:QuantizationDetails; - - // Specifies the dimension of the Tensor's shape that the scales and - // zero_points correspond to. For example, a tensor t, with dims=[4, 3, 2, 1] - // with quantization params: - // scale=[1.0, 2.0, 3.0], zero_point=[1, 2, 3], quantization_dimension=1 - // will be quantized across the second dimension of t. - // t[:, 0, :, :] will have scale[0]=1.0, zero_point[0]=1 - // t[:, 1, :, :] will have scale[1]=2.0, zero_point[0]=2 - // t[:, 2, :, :] will have scale[2]=3.0, zero_point[0]=3 - quantized_dimension:int; -} - -// Sparse tensors. -// We use a modification of the TACO format. -// Reference: http://tensor-compiler.org/kjolstad-oopsla17-tensor-compiler.pdf -// -// To encode a conceptual n-dimensional dense tensor with dims (d0, ..., dn-1), -// potentially with a k-dimensional block (0 <= k <= n) with dims -// (dn, ..., dn+k-1), the format needs to specify: -// 1. In what order to traverse these dimensions. For example, to store a 2-D -// matrix in row major order, the traversal order would be (d0, d1), -// whereas to store it in column major order, the traversal order would be -// (d1, d0). If the 2-D matrix has a 2-D inner block, the traversal order -// could be (d0, d1, d2, d3). -// 2. How each block dimension in (dn, ..., dn+k-1) maps to the original -// tensor dimension in (d0, ..., dn-1). -// 3. In the traversal order defined above, the format (dense vs. sparse) and -// index metadata for each dimension. For a dense dimension, this is just -// the size of that dimension. For a sparse dimension, it's the same as -// the compressed index defined in the Compressed Sparse Row (CSR) format. -// (http://scipy-lectures.org/advanced/scipy_sparse/csr_matrix.html) - -// The storage type for a dimension. Currently we support: -// 1. DENSE: each coordinate in this dimension is stored implicitly. -// 2. SPARSE_CSR: only the coordinates with non-zero elements are stored. The -// compression technique is the same what CSR uses. -// More types like a sparse dimension with a different compression technique -// could be added to the list in the future. -enum DimensionType : byte { - DENSE = 0, - SPARSE_CSR = 1, -} - -table Int32Vector { - values:[int]; -} - -table Uint16Vector { - values:[ushort] (force_align: 4); -} - -table Uint8Vector { - values:[ubyte] (force_align: 4); -} - -// Variable-typed buffer to store the index metadata for a sparse dimension. -// The widest type is Int32 instead of UInt32 because tensor's shape is a int32 -// vector. We don't want the per-dimensional index to overflow that range. -union SparseIndexVector { - Int32Vector, - Uint16Vector, - Uint8Vector -} - -table DimensionMetadata { - // Whether a dimension is dense or sparse. - format:DimensionType; - // Index metadata used for a dimension. - // - If format is DimensionType.DENSE then we use the dense_size field to - // store the size of that dimension. Each index in that dimension is - // stored implicitly. - // - If format is DimensionType.SPARSE_CSR then we use array_segments and - // array_indices to encode that dimension. array_segments represents how - // to segment the indices array, each segment corresponds to one element - // in the previous dimension. array_indices represents the index of the - // non-zero elements within this dimension (as those in the CSR matrix - // format, where the first array is row pointers and the second array is - // column indices). - dense_size:int; - array_segments:SparseIndexVector; - array_indices:SparseIndexVector; -} - -// Parameters to encode a sparse TfLite tensor. -table SparsityParameters { - // The traversal order of the dimensions defined in the `shape` field of the - // conceptual dense tensor. For a n-dimensional tensors with dims (d0, d1, - // ..., dn-1), - // - if not block sparse, the traversal_order is just a permutation of (d0, - // ..., dn-1). For example, a 2-D matrix stored in row-major order would - // have traversal_order = (d0, d1). - // - if block sparse with a k-dimensional block (0 <= k <= n), the - // traversal_order has n + k elements. The first n elements are still a - // permutation of (d0, ..., dn-1). The lask k elements are a permutation - // of (dn, ..., dn+k-1), defining how to traverse a block internally. For - // example, a 2-D matrix with 2-D blocks, both stored in row-major order - // would have traversal_order = (d0, d1, d2, d3). - traversal_order:[int]; - // For an n-dimensional tensor with a k-dimensional block (0 <= k <= n), - // stores how a block dimension in (dn, ..., dn+k-1) maps to the original - // tensor dimension in (d0, ..., dn). - // It's stored in the order of (dn, ..., dn+k-1). - // If not block-sparse, this field is NULL. - block_map:[int]; - // In the traversal order defined above, the metadata needed for - // each dimension to locate the non-zero values in the original dense tensor. - // The size of the dim_metadata array = the size of the traversal_order array - // = n + k. - dim_metadata:[DimensionMetadata]; -} - -table Tensor { - // The tensor shape. The meaning of each entry is operator-specific but - // builtin ops use: [batch size, height, width, number of channels] (That's - // Tensorflow's NHWC). - shape:[int]; - type:TensorType; - // An index that refers to the buffers table at the root of the model. Or, - // if there is no data buffer associated (i.e. intermediate results), then - // this is 0 (which refers to an always existent empty buffer). - // - // The data_buffer itself is an opaque container, with the assumption that the - // target device is little-endian. In addition, all builtin operators assume - // the memory is ordered such that if `shape` is [4, 3, 2], then index - // [i, j, k] maps to data_buffer[i*3*2 + j*2 + k]. - buffer:uint; - name:string; // For debugging and importing back into tensorflow. - quantization:QuantizationParameters; // Optional. - - is_variable:bool = false; - - // Parameters to encode a sparse tensor. See the example in - // tensorflow/lite/testdata/sparse_tensor.json. - sparsity:SparsityParameters; // Optional. - - // Encodes `shape` with unknown dimensions. Unknown dimensions are - // represented with -1. - shape_signature:[int]; // Optional. -} - -// A list of builtin operators. Builtin operators are slightly faster than custom -// ones, but not by much. Moreover, while custom operators accept an opaque -// object containing configuration parameters, builtins have a predetermined -// set of acceptable options. - -enum BuiltinOperator : byte { - ADD = 0, - AVERAGE_POOL_2D = 1, - CONCATENATION = 2, - CONV_2D = 3, - DEPTHWISE_CONV_2D = 4, - DEPTH_TO_SPACE = 5, - DEQUANTIZE = 6, - EMBEDDING_LOOKUP = 7, - FLOOR = 8, - FULLY_CONNECTED = 9, - HASHTABLE_LOOKUP = 10, - L2_NORMALIZATION = 11, - L2_POOL_2D = 12, - LOCAL_RESPONSE_NORMALIZATION = 13, - LOGISTIC = 14, - LSH_PROJECTION = 15, - LSTM = 16, - MAX_POOL_2D = 17, - MUL = 18, - RELU = 19, - // NOTE(aselle): RELU_N1_TO_1 used to be called RELU1, but it was renamed - // since different model developers use RELU1 in different ways. Never - // create another op called RELU1. - RELU_N1_TO_1 = 20, - RELU6 = 21, - RESHAPE = 22, - RESIZE_BILINEAR = 23, - RNN = 24, - SOFTMAX = 25, - SPACE_TO_DEPTH = 26, - SVDF = 27, - TANH = 28, - // Consider rename to CONCATENATE_EMBEDDINGS - CONCAT_EMBEDDINGS = 29, - SKIP_GRAM = 30, - CALL = 31, - CUSTOM = 32, - EMBEDDING_LOOKUP_SPARSE = 33, - PAD = 34, - UNIDIRECTIONAL_SEQUENCE_RNN = 35, - GATHER = 36, - BATCH_TO_SPACE_ND = 37, - SPACE_TO_BATCH_ND = 38, - TRANSPOSE = 39, - MEAN = 40, - SUB = 41, - DIV = 42, - SQUEEZE = 43, - UNIDIRECTIONAL_SEQUENCE_LSTM = 44, - STRIDED_SLICE = 45, - BIDIRECTIONAL_SEQUENCE_RNN = 46, - EXP = 47, - TOPK_V2 = 48, - SPLIT = 49, - LOG_SOFTMAX = 50, - // DELEGATE is a special op type for the operations which are delegated to - // other backends. - // WARNING: Experimental interface, subject to change - DELEGATE = 51, - BIDIRECTIONAL_SEQUENCE_LSTM = 52, - CAST = 53, - PRELU = 54, - MAXIMUM = 55, - ARG_MAX = 56, - MINIMUM = 57, - LESS = 58, - NEG = 59, - PADV2 = 60, - GREATER = 61, - GREATER_EQUAL = 62, - LESS_EQUAL = 63, - SELECT = 64, - SLICE = 65, - SIN = 66, - TRANSPOSE_CONV = 67, - SPARSE_TO_DENSE = 68, - TILE = 69, - EXPAND_DIMS = 70, - EQUAL = 71, - NOT_EQUAL = 72, - LOG = 73, - SUM = 74, - SQRT = 75, - RSQRT = 76, - SHAPE = 77, - POW = 78, - ARG_MIN = 79, - FAKE_QUANT = 80, - REDUCE_PROD = 81, - REDUCE_MAX = 82, - PACK = 83, - LOGICAL_OR = 84, - ONE_HOT = 85, - LOGICAL_AND = 86, - LOGICAL_NOT = 87, - UNPACK = 88, - REDUCE_MIN = 89, - FLOOR_DIV = 90, - REDUCE_ANY = 91, - SQUARE = 92, - ZEROS_LIKE = 93, - FILL = 94, - FLOOR_MOD = 95, - RANGE = 96, - RESIZE_NEAREST_NEIGHBOR = 97, - LEAKY_RELU = 98, - SQUARED_DIFFERENCE = 99, - MIRROR_PAD = 100, - ABS = 101, - SPLIT_V = 102, - UNIQUE = 103, - CEIL = 104, - REVERSE_V2 = 105, - ADD_N = 106, - GATHER_ND = 107, - COS = 108, - WHERE = 109, - RANK = 110, - ELU = 111, - REVERSE_SEQUENCE = 112, - MATRIX_DIAG = 113, - QUANTIZE = 114, - MATRIX_SET_DIAG = 115, - ROUND = 116, - HARD_SWISH = 117, - IF = 118, - WHILE = 119, - NON_MAX_SUPPRESSION_V4 = 120, - NON_MAX_SUPPRESSION_V5 = 121, - SCATTER_ND = 122, - SELECT_V2 = 123, - DENSIFY = 124, - SEGMENT_SUM = 125, - BATCH_MATMUL = 126 -} - - -// Options for the builtin operators. -union BuiltinOptions { - Conv2DOptions, - DepthwiseConv2DOptions, - ConcatEmbeddingsOptions, - LSHProjectionOptions, - Pool2DOptions, - SVDFOptions, - RNNOptions, - FullyConnectedOptions, - SoftmaxOptions, - ConcatenationOptions, - AddOptions, - L2NormOptions, - LocalResponseNormalizationOptions, - LSTMOptions, - ResizeBilinearOptions, - CallOptions, - ReshapeOptions, - SkipGramOptions, - SpaceToDepthOptions, - EmbeddingLookupSparseOptions, - MulOptions, - PadOptions, - GatherOptions, - BatchToSpaceNDOptions, - SpaceToBatchNDOptions, - TransposeOptions, - ReducerOptions, - SubOptions, - DivOptions, - SqueezeOptions, - SequenceRNNOptions, - StridedSliceOptions, - ExpOptions, - TopKV2Options, - SplitOptions, - LogSoftmaxOptions, - CastOptions, - DequantizeOptions, - MaximumMinimumOptions, - ArgMaxOptions, - LessOptions, - NegOptions, - PadV2Options, - GreaterOptions, - GreaterEqualOptions, - LessEqualOptions, - SelectOptions, - SliceOptions, - TransposeConvOptions, - SparseToDenseOptions, - TileOptions, - ExpandDimsOptions, - EqualOptions, - NotEqualOptions, - ShapeOptions, - PowOptions, - ArgMinOptions, - FakeQuantOptions, - PackOptions, - LogicalOrOptions, - OneHotOptions, - LogicalAndOptions, - LogicalNotOptions, - UnpackOptions, - FloorDivOptions, - SquareOptions, - ZerosLikeOptions, - FillOptions, - BidirectionalSequenceLSTMOptions, - BidirectionalSequenceRNNOptions, - UnidirectionalSequenceLSTMOptions, - FloorModOptions, - RangeOptions, - ResizeNearestNeighborOptions, - LeakyReluOptions, - SquaredDifferenceOptions, - MirrorPadOptions, - AbsOptions, - SplitVOptions, - UniqueOptions, - ReverseV2Options, - AddNOptions, - GatherNdOptions, - CosOptions, - WhereOptions, - RankOptions, - ReverseSequenceOptions, - MatrixDiagOptions, - QuantizeOptions, - MatrixSetDiagOptions, - HardSwishOptions, - IfOptions, - WhileOptions, - DepthToSpaceOptions, - NonMaxSuppressionV4Options, - NonMaxSuppressionV5Options, - ScatterNdOptions, - SelectV2Options, - DensifyOptions, - SegmentSumOptions, - BatchMatMulOptions -} - -enum Padding : byte { SAME, VALID } - -enum ActivationFunctionType : byte { - NONE = 0, - RELU = 1, - RELU_N1_TO_1 = 2, - RELU6 = 3, - TANH = 4, - SIGN_BIT = 5, -} - -table Conv2DOptions { - padding:Padding; - stride_w:int; - stride_h:int; - fused_activation_function:ActivationFunctionType; - dilation_w_factor:int = 1; - dilation_h_factor:int = 1; -} - -table Pool2DOptions { - padding:Padding; - stride_w:int; - stride_h:int; - filter_width:int; - filter_height:int; - fused_activation_function:ActivationFunctionType; -} - -table DepthwiseConv2DOptions { - // Parameters for DepthwiseConv version 1 or above. - padding:Padding; - stride_w:int; - stride_h:int; - // `depth_multiplier` is redundant. It's used by CPU kernels in - // TensorFlow 2.0 or below, but ignored in versions above. - // See comments in lite/c/builtin_op_data.h for more details. - depth_multiplier:int; - fused_activation_function:ActivationFunctionType; - // Parameters for DepthwiseConv version 2 or above. - dilation_w_factor:int = 1; - dilation_h_factor:int = 1; -} - -table ConcatEmbeddingsOptions { - num_channels:int; - num_columns_per_channel:[int]; - embedding_dim_per_channel:[int]; // This could be inferred from parameters. -} - -enum LSHProjectionType: byte { - UNKNOWN = 0, - SPARSE = 1, - DENSE = 2, -} - -table LSHProjectionOptions { - type: LSHProjectionType; -} - -table SVDFOptions { - rank:int; - fused_activation_function:ActivationFunctionType; - // For weights-only quantization, use asymmetric quantization for non - // constant inputs at evaluation time. - asymmetric_quantize_inputs:bool; -} - -// An implementation of TensorFlow RNNCell. -table RNNOptions { - fused_activation_function:ActivationFunctionType; - asymmetric_quantize_inputs:bool; -} - -// An implementation of TensorFlow dynamic_rnn with RNNCell. -table SequenceRNNOptions { - time_major:bool; - fused_activation_function:ActivationFunctionType; - asymmetric_quantize_inputs:bool; -} - -// An implementation of TensorFlow bidrectional_dynamic_rnn with RNNCell. -table BidirectionalSequenceRNNOptions { - time_major:bool; - fused_activation_function:ActivationFunctionType; - merge_outputs: bool; - asymmetric_quantize_inputs:bool; -} - -enum FullyConnectedOptionsWeightsFormat: byte { - DEFAULT = 0, - SHUFFLED4x16INT8 = 1, -} - -// An implementation of TensorFlow fully_connected (a.k.a Dense) layer. -table FullyConnectedOptions { - // Parameters for FullyConnected version 1 or above. - fused_activation_function:ActivationFunctionType; - - // Parameters for FullyConnected version 2 or above. - weights_format:FullyConnectedOptionsWeightsFormat = DEFAULT; - - // Parameters for FullyConnected version 5 or above. - // If set to true, then the number of dimension is preserved. Furthermore, - // all but the last dimension of the input and output shapes will be equal. - keep_num_dims: bool; - - // Parameters for FullyConnected version 7 or above. - // If set to true, then weights-only op will use asymmetric quantization for - // inputs. - asymmetric_quantize_inputs: bool; -} - -table SoftmaxOptions { - beta: float; -} - -// An implementation of TensorFlow concat. -table ConcatenationOptions { - axis:int; - fused_activation_function:ActivationFunctionType; -} - -table AddOptions { - fused_activation_function:ActivationFunctionType; -} - -table MulOptions { - fused_activation_function:ActivationFunctionType; -} - -table L2NormOptions { - fused_activation_function:ActivationFunctionType; -} - -table LocalResponseNormalizationOptions { - radius:int; - bias:float; - alpha:float; - beta:float; -} - -enum LSTMKernelType : byte { - // Full LSTM kernel which supports peephole and projection. - FULL = 0, - // Basic LSTM kernels. Equivalent to TensorFlow BasicLSTMCell. - BASIC = 1, -} - -// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell -table LSTMOptions { - // Parameters for LSTM version 1 or above. - fused_activation_function:ActivationFunctionType; - cell_clip: float; // Optional, 0.0 means no clipping - proj_clip: float; // Optional, 0.0 means no clipping - - // Parameters for LSTM version 2 or above. - // Basic kernel is only supported in version 2 or above. - kernel_type: LSTMKernelType = FULL; - - // Parameters for LSTM version 4 or above. - asymmetric_quantize_inputs: bool; -} - -// An implementation of TensorFlow dynamic_rnn with LSTMCell. -table UnidirectionalSequenceLSTMOptions { - fused_activation_function:ActivationFunctionType; - cell_clip: float; // Optional, 0.0 means no clipping - proj_clip: float; // Optional, 0.0 means no clipping - - // If true then first dimension is sequence, otherwise batch. - time_major:bool; - - // Parameter for Unidirectional Sequence LSTM version 4. - asymmetric_quantize_inputs:bool; -} - -table BidirectionalSequenceLSTMOptions { - // Parameters supported by version 1: - fused_activation_function:ActivationFunctionType; - cell_clip: float; // Optional, 0.0 means no clipping - proj_clip: float; // Optional, 0.0 means no clipping - - // If true, store the outputs of both directions into the first output. - merge_outputs: bool; - - // Parameters supported by version 2: - // If true then first dimension is sequence, otherwise batch. - // Version 1 implementations assumed time_major to be true, so this default - // value should never change. - time_major: bool = true; - - // Parameters for version 3 or above. - asymmetric_quantize_inputs:bool; -} - -table ResizeBilinearOptions { - new_height: int (deprecated); - new_width: int (deprecated); - align_corners: bool; - half_pixel_centers: bool; -} - -table ResizeNearestNeighborOptions { - align_corners: bool; - half_pixel_centers: bool; -} - -// A call operation options -table CallOptions { - // The subgraph index that needs to be called. - subgraph:uint; -} - -table PadOptions { -} - -table PadV2Options { -} - -table ReshapeOptions { - new_shape:[int]; -} - -table SpaceToBatchNDOptions { -} - -table BatchToSpaceNDOptions { -} - -table SkipGramOptions { - ngram_size: int; - max_skip_size: int; - include_all_ngrams: bool; -} - -table SpaceToDepthOptions { - block_size: int; -} - -table DepthToSpaceOptions { - block_size: int; -} - -table SubOptions { - fused_activation_function:ActivationFunctionType; -} - -table DivOptions { - fused_activation_function:ActivationFunctionType; -} - -table TopKV2Options { -} - -enum CombinerType : byte { - SUM = 0, - MEAN = 1, - SQRTN = 2, -} - -table EmbeddingLookupSparseOptions { - combiner:CombinerType; -} - -table GatherOptions { - axis: int; -} - -table TransposeOptions { -} - -table ExpOptions { -} - -table CosOptions { -} - -table ReducerOptions { - keep_dims: bool; -} - -table SqueezeOptions { - squeeze_dims:[int]; -} - -table SplitOptions { - num_splits: int; -} - -table SplitVOptions { - num_splits: int; -} - -table StridedSliceOptions { - begin_mask: int; - end_mask: int; - ellipsis_mask: int; - new_axis_mask: int; - shrink_axis_mask: int; -} - -table LogSoftmaxOptions { -} - -table CastOptions { - in_data_type: TensorType; - out_data_type: TensorType; -} - -table DequantizeOptions { -} - -table MaximumMinimumOptions { -} - -table TileOptions { -} - -table ArgMaxOptions { - output_type : TensorType; -} - -table ArgMinOptions { - output_type : TensorType; -} - -table GreaterOptions { -} - -table GreaterEqualOptions { -} - -table LessOptions { -} - -table LessEqualOptions { -} - -table NegOptions { -} - -table SelectOptions { -} - -table SliceOptions { -} - -table TransposeConvOptions { - padding:Padding; - stride_w:int; - stride_h:int; -} - -table ExpandDimsOptions { -} - -table SparseToDenseOptions { - validate_indices:bool; -} - -table EqualOptions { -} - -table NotEqualOptions { -} - -table ShapeOptions { - // Optional output type of the operation (int32 or int64). Defaults to int32. - out_type : TensorType; -} - -table RankOptions { -} - -table PowOptions { -} - -table FakeQuantOptions { - // Parameters supported by version 1: - min:float; - max:float; - num_bits:int; - - // Parameters supported by version 2: - narrow_range:bool; -} - -table PackOptions { - values_count:int; - axis:int; -} - -table LogicalOrOptions { -} - -table OneHotOptions { - axis:int; -} - -table AbsOptions { -} - - -table HardSwishOptions { -} - -table LogicalAndOptions { -} - -table LogicalNotOptions { -} - -table UnpackOptions { - num:int; - axis:int; -} - -table FloorDivOptions { -} - -table SquareOptions { -} - -table ZerosLikeOptions { -} - -table FillOptions { -} - -table FloorModOptions { -} - -table RangeOptions { -} - -table LeakyReluOptions { - alpha:float; -} - -table SquaredDifferenceOptions { -} - -enum MirrorPadMode : byte { - // Doesn't include borders. - REFLECT = 0, - // Includes borders. - SYMMETRIC = 1, -} - -table MirrorPadOptions { - mode:MirrorPadMode; -} - -table UniqueOptions { - idx_out_type:TensorType = INT32; -} - -table ReverseV2Options { -} - -table AddNOptions { -} - -table GatherNdOptions { -} - -table WhereOptions { -} - -table ReverseSequenceOptions { - seq_dim:int; - batch_dim:int = 0; -} - -table MatrixDiagOptions { -} - -table QuantizeOptions { -} - -table MatrixSetDiagOptions { -} - -table IfOptions { - then_subgraph_index:int; - else_subgraph_index:int; -} - -table WhileOptions { - cond_subgraph_index:int; - body_subgraph_index:int; -} - -table NonMaxSuppressionV4Options { -} - -table NonMaxSuppressionV5Options { -} - -table ScatterNdOptions { -} - -table SelectV2Options { -} - -table DensifyOptions { -} - -table SegmentSumOptions { -} - -table BatchMatMulOptions { - adj_x:bool; - adj_y:bool; -} - -// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a -// builtin, or a string if the operator is custom. -table OperatorCode { - builtin_code:BuiltinOperator; - custom_code:string; - - // The version of the operator. The version need to be bumped whenever new - // parameters are introduced into an op. - version:int = 1; -} - -enum CustomOptionsFormat : byte { - FLEXBUFFERS = 0, -} - -// An operator takes tensors as inputs and outputs. The type of operation being -// performed is determined by an index into the list of valid OperatorCodes, -// while the specifics of each operations is configured using builtin_options -// or custom_options. -table Operator { - // Index into the operator_codes array. Using an integer here avoids - // complicate map lookups. - opcode_index:uint; - - // Optional input are indicated by -1. - inputs:[int]; - outputs:[int]; - - builtin_options:BuiltinOptions; - custom_options:[ubyte]; - custom_options_format:CustomOptionsFormat; - - // A list of booleans indicating the input tensors which are being mutated by - // this operator.(e.g. used by RNN and LSTM). - // For example, if the "inputs" array refers to 5 tensors and the second and - // fifth are mutable variables, then this list will contain - // [false, true, false, false, true]. - // - // If the list is empty, no variable is mutated in this operator. - // The list either has the same length as `inputs`, or is empty. - mutating_variable_inputs:[bool]; - - // A list of indices to the subgraph's "tensors" that are internal to an Op. - // Internal tensors are those that do not flow in or out of the operation, - // but instead are part of internal computation. As such, the operation's - // implementation may manage its memory more efficiently. They are needed - // however (i.e. not just an implementation detail) since they are part of the - // computation, which may require relevant metadata such as quantization - // parameters. - intermediates:[int]; -} - -// The root type, defining a subgraph, which typically represents an entire -// model. -table SubGraph { - // A list of all tensors used in this subgraph. - tensors:[Tensor]; - - // Indices of the tensors that are inputs into this subgraph. Note this is - // the list of non-static tensors that feed into the subgraph for inference. - inputs:[int]; - - // Indices of the tensors that are outputs out of this subgraph. Note this is - // the list of output tensors that are considered the product of the - // subgraph's inference. - outputs:[int]; - - // All operators, in execution order. - operators:[Operator]; - - // Name of this subgraph (used for debugging). - name:string; -} - -// Table of raw data buffers (used for constant tensors). Referenced by tensors -// by index. The generous alignment accommodates mmap-friendly data structures. -table Buffer { - data:[ubyte] (force_align: 16); -} - -table Metadata { - // A human readable string to uniquely identify a Metadata. - name:string; - // An index to the buffers table. - buffer:uint; -} - -table Model { - // Version of the schema. - version:uint; - - // A list of all operator codes used in this model. This is - // kept in order because operators carry an index into this - // vector. - operator_codes:[OperatorCode]; - - // All the subgraphs of the model. The 0th is assumed to be the main - // model. - subgraphs:[SubGraph]; - - // A description of the model. - description:string; - - // Buffers of the model. - // Note the 0th entry of this array must be an empty buffer (sentinel). - // This is a convention so that tensors without a buffer can provide 0 as - // their buffer. - buffers:[Buffer]; - - // Metadata about the model. Indirects into the existings buffers list. - // Deprecated, prefer to use metadata field. - metadata_buffer:[int]; - - // Metadata about the model. - metadata:[Metadata]; -} - -root_type Model; +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Revision History +// Version 0: Initial version. +// Version 1: Add subgraphs to schema. +// Version 2: Rename operators to conform to NN API. +// Version 3: Move buffer data from Model.Subgraph.Tensors to Model.Buffers. + +namespace tflite; + +// This corresponds to the version. +file_identifier "TFL3"; +// File extension of any written files. +file_extension "tflite"; + +// IMPORTANT: All new members of tables, enums and unions must be added at the +// end to ensure backwards compatibility. + +// The type of data stored in a tensor. +enum TensorType : byte { + FLOAT32 = 0, + FLOAT16 = 1, + INT32 = 2, + UINT8 = 3, + INT64 = 4, + STRING = 5, + BOOL = 6, + INT16 = 7, + COMPLEX64 = 8, + INT8 = 9, + FLOAT64 = 10, +} + +// Custom quantization parameters for experimenting with new quantization +// techniques. +table CustomQuantization { + custom:[ubyte] (force_align: 16); +} + +// Represents a specific quantization technique's parameters. +union QuantizationDetails { + CustomQuantization, +} + +// Parameters for converting a quantized tensor back to float. +table QuantizationParameters { + // These four parameters are the asymmetric linear quantization parameters. + // Given a quantized value q, the corresponding float value f should be: + // f = scale * (q - zero_point) + // For other quantization types, the QuantizationDetails below is used. + min:[float]; // For importing back into tensorflow. + max:[float]; // For importing back into tensorflow. + scale:[float]; // For dequantizing the tensor's values. + zero_point:[long]; + + // If this is not none, the other quantization parameters (i.e. min, max, + // scale, zero_point fields above) are ignored and the value of the + // QuantizationDetails union should be used. + details:QuantizationDetails; + + // Specifies the dimension of the Tensor's shape that the scales and + // zero_points correspond to. For example, a tensor t, with dims=[4, 3, 2, 1] + // with quantization params: + // scale=[1.0, 2.0, 3.0], zero_point=[1, 2, 3], quantization_dimension=1 + // will be quantized across the second dimension of t. + // t[:, 0, :, :] will have scale[0]=1.0, zero_point[0]=1 + // t[:, 1, :, :] will have scale[1]=2.0, zero_point[0]=2 + // t[:, 2, :, :] will have scale[2]=3.0, zero_point[0]=3 + quantized_dimension:int; +} + +// Sparse tensors. +// We use a modification of the TACO format. +// Reference: http://tensor-compiler.org/kjolstad-oopsla17-tensor-compiler.pdf +// +// To encode a conceptual n-dimensional dense tensor with dims (d0, ..., dn-1), +// potentially with a k-dimensional block (0 <= k <= n) with dims +// (dn, ..., dn+k-1), the format needs to specify: +// 1. In what order to traverse these dimensions. For example, to store a 2-D +// matrix in row major order, the traversal order would be (d0, d1), +// whereas to store it in column major order, the traversal order would be +// (d1, d0). If the 2-D matrix has a 2-D inner block, the traversal order +// could be (d0, d1, d2, d3). +// 2. How each block dimension in (dn, ..., dn+k-1) maps to the original +// tensor dimension in (d0, ..., dn-1). +// 3. In the traversal order defined above, the format (dense vs. sparse) and +// index metadata for each dimension. For a dense dimension, this is just +// the size of that dimension. For a sparse dimension, it's the same as +// the compressed index defined in the Compressed Sparse Row (CSR) format. +// (http://scipy-lectures.org/advanced/scipy_sparse/csr_matrix.html) + +// The storage type for a dimension. Currently we support: +// 1. DENSE: each coordinate in this dimension is stored implicitly. +// 2. SPARSE_CSR: only the coordinates with non-zero elements are stored. The +// compression technique is the same what CSR uses. +// More types like a sparse dimension with a different compression technique +// could be added to the list in the future. +enum DimensionType : byte { + DENSE = 0, + SPARSE_CSR = 1, +} + +table Int32Vector { + values:[int]; +} + +table Uint16Vector { + values:[ushort] (force_align: 4); +} + +table Uint8Vector { + values:[ubyte] (force_align: 4); +} + +// Variable-typed buffer to store the index metadata for a sparse dimension. +// The widest type is Int32 instead of UInt32 because tensor's shape is a int32 +// vector. We don't want the per-dimensional index to overflow that range. +union SparseIndexVector { + Int32Vector, + Uint16Vector, + Uint8Vector +} + +table DimensionMetadata { + // Whether a dimension is dense or sparse. + format:DimensionType; + // Index metadata used for a dimension. + // - If format is DimensionType.DENSE then we use the dense_size field to + // store the size of that dimension. Each index in that dimension is + // stored implicitly. + // - If format is DimensionType.SPARSE_CSR then we use array_segments and + // array_indices to encode that dimension. array_segments represents how + // to segment the indices array, each segment corresponds to one element + // in the previous dimension. array_indices represents the index of the + // non-zero elements within this dimension (as those in the CSR matrix + // format, where the first array is row pointers and the second array is + // column indices). + dense_size:int; + array_segments:SparseIndexVector; + array_indices:SparseIndexVector; +} + +// Parameters to encode a sparse TfLite tensor. +table SparsityParameters { + // The traversal order of the dimensions defined in the `shape` field of the + // conceptual dense tensor. For a n-dimensional tensors with dims (d0, d1, + // ..., dn-1), + // - if not block sparse, the traversal_order is just a permutation of (d0, + // ..., dn-1). For example, a 2-D matrix stored in row-major order would + // have traversal_order = (d0, d1). + // - if block sparse with a k-dimensional block (0 <= k <= n), the + // traversal_order has n + k elements. The first n elements are still a + // permutation of (d0, ..., dn-1). The lask k elements are a permutation + // of (dn, ..., dn+k-1), defining how to traverse a block internally. For + // example, a 2-D matrix with 2-D blocks, both stored in row-major order + // would have traversal_order = (d0, d1, d2, d3). + traversal_order:[int]; + // For an n-dimensional tensor with a k-dimensional block (0 <= k <= n), + // stores how a block dimension in (dn, ..., dn+k-1) maps to the original + // tensor dimension in (d0, ..., dn). + // It's stored in the order of (dn, ..., dn+k-1). + // If not block-sparse, this field is NULL. + block_map:[int]; + // In the traversal order defined above, the metadata needed for + // each dimension to locate the non-zero values in the original dense tensor. + // The size of the dim_metadata array = the size of the traversal_order array + // = n + k. + dim_metadata:[DimensionMetadata]; +} + +table Tensor { + // The tensor shape. The meaning of each entry is operator-specific but + // builtin ops use: [batch size, height, width, number of channels] (That's + // Tensorflow's NHWC). + shape:[int]; + type:TensorType; + // An index that refers to the buffers table at the root of the model. Or, + // if there is no data buffer associated (i.e. intermediate results), then + // this is 0 (which refers to an always existent empty buffer). + // + // The data_buffer itself is an opaque container, with the assumption that the + // target device is little-endian. In addition, all builtin operators assume + // the memory is ordered such that if `shape` is [4, 3, 2], then index + // [i, j, k] maps to data_buffer[i*3*2 + j*2 + k]. + buffer:uint; + name:string; // For debugging and importing back into tensorflow. + quantization:QuantizationParameters; // Optional. + + is_variable:bool = false; + + // Parameters to encode a sparse tensor. See the example in + // tensorflow/lite/testdata/sparse_tensor.json. + sparsity:SparsityParameters; // Optional. + + // Encodes `shape` with unknown dimensions. Unknown dimensions are + // represented with -1. + shape_signature:[int]; // Optional. +} + +// A list of builtin operators. Builtin operators are slightly faster than custom +// ones, but not by much. Moreover, while custom operators accept an opaque +// object containing configuration parameters, builtins have a predetermined +// set of acceptable options. + +enum BuiltinOperator : byte { + ADD = 0, + AVERAGE_POOL_2D = 1, + CONCATENATION = 2, + CONV_2D = 3, + DEPTHWISE_CONV_2D = 4, + DEPTH_TO_SPACE = 5, + DEQUANTIZE = 6, + EMBEDDING_LOOKUP = 7, + FLOOR = 8, + FULLY_CONNECTED = 9, + HASHTABLE_LOOKUP = 10, + L2_NORMALIZATION = 11, + L2_POOL_2D = 12, + LOCAL_RESPONSE_NORMALIZATION = 13, + LOGISTIC = 14, + LSH_PROJECTION = 15, + LSTM = 16, + MAX_POOL_2D = 17, + MUL = 18, + RELU = 19, + // NOTE(aselle): RELU_N1_TO_1 used to be called RELU1, but it was renamed + // since different model developers use RELU1 in different ways. Never + // create another op called RELU1. + RELU_N1_TO_1 = 20, + RELU6 = 21, + RESHAPE = 22, + RESIZE_BILINEAR = 23, + RNN = 24, + SOFTMAX = 25, + SPACE_TO_DEPTH = 26, + SVDF = 27, + TANH = 28, + // Consider rename to CONCATENATE_EMBEDDINGS + CONCAT_EMBEDDINGS = 29, + SKIP_GRAM = 30, + CALL = 31, + CUSTOM = 32, + EMBEDDING_LOOKUP_SPARSE = 33, + PAD = 34, + UNIDIRECTIONAL_SEQUENCE_RNN = 35, + GATHER = 36, + BATCH_TO_SPACE_ND = 37, + SPACE_TO_BATCH_ND = 38, + TRANSPOSE = 39, + MEAN = 40, + SUB = 41, + DIV = 42, + SQUEEZE = 43, + UNIDIRECTIONAL_SEQUENCE_LSTM = 44, + STRIDED_SLICE = 45, + BIDIRECTIONAL_SEQUENCE_RNN = 46, + EXP = 47, + TOPK_V2 = 48, + SPLIT = 49, + LOG_SOFTMAX = 50, + // DELEGATE is a special op type for the operations which are delegated to + // other backends. + // WARNING: Experimental interface, subject to change + DELEGATE = 51, + BIDIRECTIONAL_SEQUENCE_LSTM = 52, + CAST = 53, + PRELU = 54, + MAXIMUM = 55, + ARG_MAX = 56, + MINIMUM = 57, + LESS = 58, + NEG = 59, + PADV2 = 60, + GREATER = 61, + GREATER_EQUAL = 62, + LESS_EQUAL = 63, + SELECT = 64, + SLICE = 65, + SIN = 66, + TRANSPOSE_CONV = 67, + SPARSE_TO_DENSE = 68, + TILE = 69, + EXPAND_DIMS = 70, + EQUAL = 71, + NOT_EQUAL = 72, + LOG = 73, + SUM = 74, + SQRT = 75, + RSQRT = 76, + SHAPE = 77, + POW = 78, + ARG_MIN = 79, + FAKE_QUANT = 80, + REDUCE_PROD = 81, + REDUCE_MAX = 82, + PACK = 83, + LOGICAL_OR = 84, + ONE_HOT = 85, + LOGICAL_AND = 86, + LOGICAL_NOT = 87, + UNPACK = 88, + REDUCE_MIN = 89, + FLOOR_DIV = 90, + REDUCE_ANY = 91, + SQUARE = 92, + ZEROS_LIKE = 93, + FILL = 94, + FLOOR_MOD = 95, + RANGE = 96, + RESIZE_NEAREST_NEIGHBOR = 97, + LEAKY_RELU = 98, + SQUARED_DIFFERENCE = 99, + MIRROR_PAD = 100, + ABS = 101, + SPLIT_V = 102, + UNIQUE = 103, + CEIL = 104, + REVERSE_V2 = 105, + ADD_N = 106, + GATHER_ND = 107, + COS = 108, + WHERE = 109, + RANK = 110, + ELU = 111, + REVERSE_SEQUENCE = 112, + MATRIX_DIAG = 113, + QUANTIZE = 114, + MATRIX_SET_DIAG = 115, + ROUND = 116, + HARD_SWISH = 117, + IF = 118, + WHILE = 119, + NON_MAX_SUPPRESSION_V4 = 120, + NON_MAX_SUPPRESSION_V5 = 121, + SCATTER_ND = 122, + SELECT_V2 = 123, + DENSIFY = 124, + SEGMENT_SUM = 125, + BATCH_MATMUL = 126 +} + + +// Options for the builtin operators. +union BuiltinOptions { + Conv2DOptions, + DepthwiseConv2DOptions, + ConcatEmbeddingsOptions, + LSHProjectionOptions, + Pool2DOptions, + SVDFOptions, + RNNOptions, + FullyConnectedOptions, + SoftmaxOptions, + ConcatenationOptions, + AddOptions, + L2NormOptions, + LocalResponseNormalizationOptions, + LSTMOptions, + ResizeBilinearOptions, + CallOptions, + ReshapeOptions, + SkipGramOptions, + SpaceToDepthOptions, + EmbeddingLookupSparseOptions, + MulOptions, + PadOptions, + GatherOptions, + BatchToSpaceNDOptions, + SpaceToBatchNDOptions, + TransposeOptions, + ReducerOptions, + SubOptions, + DivOptions, + SqueezeOptions, + SequenceRNNOptions, + StridedSliceOptions, + ExpOptions, + TopKV2Options, + SplitOptions, + LogSoftmaxOptions, + CastOptions, + DequantizeOptions, + MaximumMinimumOptions, + ArgMaxOptions, + LessOptions, + NegOptions, + PadV2Options, + GreaterOptions, + GreaterEqualOptions, + LessEqualOptions, + SelectOptions, + SliceOptions, + TransposeConvOptions, + SparseToDenseOptions, + TileOptions, + ExpandDimsOptions, + EqualOptions, + NotEqualOptions, + ShapeOptions, + PowOptions, + ArgMinOptions, + FakeQuantOptions, + PackOptions, + LogicalOrOptions, + OneHotOptions, + LogicalAndOptions, + LogicalNotOptions, + UnpackOptions, + FloorDivOptions, + SquareOptions, + ZerosLikeOptions, + FillOptions, + BidirectionalSequenceLSTMOptions, + BidirectionalSequenceRNNOptions, + UnidirectionalSequenceLSTMOptions, + FloorModOptions, + RangeOptions, + ResizeNearestNeighborOptions, + LeakyReluOptions, + SquaredDifferenceOptions, + MirrorPadOptions, + AbsOptions, + SplitVOptions, + UniqueOptions, + ReverseV2Options, + AddNOptions, + GatherNdOptions, + CosOptions, + WhereOptions, + RankOptions, + ReverseSequenceOptions, + MatrixDiagOptions, + QuantizeOptions, + MatrixSetDiagOptions, + HardSwishOptions, + IfOptions, + WhileOptions, + DepthToSpaceOptions, + NonMaxSuppressionV4Options, + NonMaxSuppressionV5Options, + ScatterNdOptions, + SelectV2Options, + DensifyOptions, + SegmentSumOptions, + BatchMatMulOptions +} + +enum Padding : byte { SAME, VALID } + +enum ActivationFunctionType : byte { + NONE = 0, + RELU = 1, + RELU_N1_TO_1 = 2, + RELU6 = 3, + TANH = 4, + SIGN_BIT = 5, +} + +table Conv2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + fused_activation_function:ActivationFunctionType; + dilation_w_factor:int = 1; + dilation_h_factor:int = 1; +} + +table Pool2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + filter_width:int; + filter_height:int; + fused_activation_function:ActivationFunctionType; +} + +table DepthwiseConv2DOptions { + // Parameters for DepthwiseConv version 1 or above. + padding:Padding; + stride_w:int; + stride_h:int; + // `depth_multiplier` is redundant. It's used by CPU kernels in + // TensorFlow 2.0 or below, but ignored in versions above. + // See comments in lite/c/builtin_op_data.h for more details. + depth_multiplier:int; + fused_activation_function:ActivationFunctionType; + // Parameters for DepthwiseConv version 2 or above. + dilation_w_factor:int = 1; + dilation_h_factor:int = 1; +} + +table ConcatEmbeddingsOptions { + num_channels:int; + num_columns_per_channel:[int]; + embedding_dim_per_channel:[int]; // This could be inferred from parameters. +} + +enum LSHProjectionType: byte { + UNKNOWN = 0, + SPARSE = 1, + DENSE = 2, +} + +table LSHProjectionOptions { + type: LSHProjectionType; +} + +table SVDFOptions { + rank:int; + fused_activation_function:ActivationFunctionType; + // For weights-only quantization, use asymmetric quantization for non + // constant inputs at evaluation time. + asymmetric_quantize_inputs:bool; +} + +// An implementation of TensorFlow RNNCell. +table RNNOptions { + fused_activation_function:ActivationFunctionType; + asymmetric_quantize_inputs:bool; +} + +// An implementation of TensorFlow dynamic_rnn with RNNCell. +table SequenceRNNOptions { + time_major:bool; + fused_activation_function:ActivationFunctionType; + asymmetric_quantize_inputs:bool; +} + +// An implementation of TensorFlow bidrectional_dynamic_rnn with RNNCell. +table BidirectionalSequenceRNNOptions { + time_major:bool; + fused_activation_function:ActivationFunctionType; + merge_outputs: bool; + asymmetric_quantize_inputs:bool; +} + +enum FullyConnectedOptionsWeightsFormat: byte { + DEFAULT = 0, + SHUFFLED4x16INT8 = 1, +} + +// An implementation of TensorFlow fully_connected (a.k.a Dense) layer. +table FullyConnectedOptions { + // Parameters for FullyConnected version 1 or above. + fused_activation_function:ActivationFunctionType; + + // Parameters for FullyConnected version 2 or above. + weights_format:FullyConnectedOptionsWeightsFormat = DEFAULT; + + // Parameters for FullyConnected version 5 or above. + // If set to true, then the number of dimension is preserved. Furthermore, + // all but the last dimension of the input and output shapes will be equal. + keep_num_dims: bool; + + // Parameters for FullyConnected version 7 or above. + // If set to true, then weights-only op will use asymmetric quantization for + // inputs. + asymmetric_quantize_inputs: bool; +} + +table SoftmaxOptions { + beta: float; +} + +// An implementation of TensorFlow concat. +table ConcatenationOptions { + axis:int; + fused_activation_function:ActivationFunctionType; +} + +table AddOptions { + fused_activation_function:ActivationFunctionType; +} + +table MulOptions { + fused_activation_function:ActivationFunctionType; +} + +table L2NormOptions { + fused_activation_function:ActivationFunctionType; +} + +table LocalResponseNormalizationOptions { + radius:int; + bias:float; + alpha:float; + beta:float; +} + +enum LSTMKernelType : byte { + // Full LSTM kernel which supports peephole and projection. + FULL = 0, + // Basic LSTM kernels. Equivalent to TensorFlow BasicLSTMCell. + BASIC = 1, +} + +// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell +table LSTMOptions { + // Parameters for LSTM version 1 or above. + fused_activation_function:ActivationFunctionType; + cell_clip: float; // Optional, 0.0 means no clipping + proj_clip: float; // Optional, 0.0 means no clipping + + // Parameters for LSTM version 2 or above. + // Basic kernel is only supported in version 2 or above. + kernel_type: LSTMKernelType = FULL; + + // Parameters for LSTM version 4 or above. + asymmetric_quantize_inputs: bool; +} + +// An implementation of TensorFlow dynamic_rnn with LSTMCell. +table UnidirectionalSequenceLSTMOptions { + fused_activation_function:ActivationFunctionType; + cell_clip: float; // Optional, 0.0 means no clipping + proj_clip: float; // Optional, 0.0 means no clipping + + // If true then first dimension is sequence, otherwise batch. + time_major:bool; + + // Parameter for Unidirectional Sequence LSTM version 4. + asymmetric_quantize_inputs:bool; +} + +table BidirectionalSequenceLSTMOptions { + // Parameters supported by version 1: + fused_activation_function:ActivationFunctionType; + cell_clip: float; // Optional, 0.0 means no clipping + proj_clip: float; // Optional, 0.0 means no clipping + + // If true, store the outputs of both directions into the first output. + merge_outputs: bool; + + // Parameters supported by version 2: + // If true then first dimension is sequence, otherwise batch. + // Version 1 implementations assumed time_major to be true, so this default + // value should never change. + time_major: bool = true; + + // Parameters for version 3 or above. + asymmetric_quantize_inputs:bool; +} + +table ResizeBilinearOptions { + new_height: int (deprecated); + new_width: int (deprecated); + align_corners: bool; + half_pixel_centers: bool; +} + +table ResizeNearestNeighborOptions { + align_corners: bool; + half_pixel_centers: bool; +} + +// A call operation options +table CallOptions { + // The subgraph index that needs to be called. + subgraph:uint; +} + +table PadOptions { +} + +table PadV2Options { +} + +table ReshapeOptions { + new_shape:[int]; +} + +table SpaceToBatchNDOptions { +} + +table BatchToSpaceNDOptions { +} + +table SkipGramOptions { + ngram_size: int; + max_skip_size: int; + include_all_ngrams: bool; +} + +table SpaceToDepthOptions { + block_size: int; +} + +table DepthToSpaceOptions { + block_size: int; +} + +table SubOptions { + fused_activation_function:ActivationFunctionType; +} + +table DivOptions { + fused_activation_function:ActivationFunctionType; +} + +table TopKV2Options { +} + +enum CombinerType : byte { + SUM = 0, + MEAN = 1, + SQRTN = 2, +} + +table EmbeddingLookupSparseOptions { + combiner:CombinerType; +} + +table GatherOptions { + axis: int; +} + +table TransposeOptions { +} + +table ExpOptions { +} + +table CosOptions { +} + +table ReducerOptions { + keep_dims: bool; +} + +table SqueezeOptions { + squeeze_dims:[int]; +} + +table SplitOptions { + num_splits: int; +} + +table SplitVOptions { + num_splits: int; +} + +table StridedSliceOptions { + begin_mask: int; + end_mask: int; + ellipsis_mask: int; + new_axis_mask: int; + shrink_axis_mask: int; +} + +table LogSoftmaxOptions { +} + +table CastOptions { + in_data_type: TensorType; + out_data_type: TensorType; +} + +table DequantizeOptions { +} + +table MaximumMinimumOptions { +} + +table TileOptions { +} + +table ArgMaxOptions { + output_type : TensorType; +} + +table ArgMinOptions { + output_type : TensorType; +} + +table GreaterOptions { +} + +table GreaterEqualOptions { +} + +table LessOptions { +} + +table LessEqualOptions { +} + +table NegOptions { +} + +table SelectOptions { +} + +table SliceOptions { +} + +table TransposeConvOptions { + padding:Padding; + stride_w:int; + stride_h:int; +} + +table ExpandDimsOptions { +} + +table SparseToDenseOptions { + validate_indices:bool; +} + +table EqualOptions { +} + +table NotEqualOptions { +} + +table ShapeOptions { + // Optional output type of the operation (int32 or int64). Defaults to int32. + out_type : TensorType; +} + +table RankOptions { +} + +table PowOptions { +} + +table FakeQuantOptions { + // Parameters supported by version 1: + min:float; + max:float; + num_bits:int; + + // Parameters supported by version 2: + narrow_range:bool; +} + +table PackOptions { + values_count:int; + axis:int; +} + +table LogicalOrOptions { +} + +table OneHotOptions { + axis:int; +} + +table AbsOptions { +} + + +table HardSwishOptions { +} + +table LogicalAndOptions { +} + +table LogicalNotOptions { +} + +table UnpackOptions { + num:int; + axis:int; +} + +table FloorDivOptions { +} + +table SquareOptions { +} + +table ZerosLikeOptions { +} + +table FillOptions { +} + +table FloorModOptions { +} + +table RangeOptions { +} + +table LeakyReluOptions { + alpha:float; +} + +table SquaredDifferenceOptions { +} + +enum MirrorPadMode : byte { + // Doesn't include borders. + REFLECT = 0, + // Includes borders. + SYMMETRIC = 1, +} + +table MirrorPadOptions { + mode:MirrorPadMode; +} + +table UniqueOptions { + idx_out_type:TensorType = INT32; +} + +table ReverseV2Options { +} + +table AddNOptions { +} + +table GatherNdOptions { +} + +table WhereOptions { +} + +table ReverseSequenceOptions { + seq_dim:int; + batch_dim:int = 0; +} + +table MatrixDiagOptions { +} + +table QuantizeOptions { +} + +table MatrixSetDiagOptions { +} + +table IfOptions { + then_subgraph_index:int; + else_subgraph_index:int; +} + +table WhileOptions { + cond_subgraph_index:int; + body_subgraph_index:int; +} + +table NonMaxSuppressionV4Options { +} + +table NonMaxSuppressionV5Options { +} + +table ScatterNdOptions { +} + +table SelectV2Options { +} + +table DensifyOptions { +} + +table SegmentSumOptions { +} + +table BatchMatMulOptions { + adj_x:bool; + adj_y:bool; +} + +// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a +// builtin, or a string if the operator is custom. +table OperatorCode { + builtin_code:BuiltinOperator; + custom_code:string; + + // The version of the operator. The version need to be bumped whenever new + // parameters are introduced into an op. + version:int = 1; +} + +enum CustomOptionsFormat : byte { + FLEXBUFFERS = 0, +} + +// An operator takes tensors as inputs and outputs. The type of operation being +// performed is determined by an index into the list of valid OperatorCodes, +// while the specifics of each operations is configured using builtin_options +// or custom_options. +table Operator { + // Index into the operator_codes array. Using an integer here avoids + // complicate map lookups. + opcode_index:uint; + + // Optional input are indicated by -1. + inputs:[int]; + outputs:[int]; + + builtin_options:BuiltinOptions; + custom_options:[ubyte]; + custom_options_format:CustomOptionsFormat; + + // A list of booleans indicating the input tensors which are being mutated by + // this operator.(e.g. used by RNN and LSTM). + // For example, if the "inputs" array refers to 5 tensors and the second and + // fifth are mutable variables, then this list will contain + // [false, true, false, false, true]. + // + // If the list is empty, no variable is mutated in this operator. + // The list either has the same length as `inputs`, or is empty. + mutating_variable_inputs:[bool]; + + // A list of indices to the subgraph's "tensors" that are internal to an Op. + // Internal tensors are those that do not flow in or out of the operation, + // but instead are part of internal computation. As such, the operation's + // implementation may manage its memory more efficiently. They are needed + // however (i.e. not just an implementation detail) since they are part of the + // computation, which may require relevant metadata such as quantization + // parameters. + intermediates:[int]; +} + +// The root type, defining a subgraph, which typically represents an entire +// model. +table SubGraph { + // A list of all tensors used in this subgraph. + tensors:[Tensor]; + + // Indices of the tensors that are inputs into this subgraph. Note this is + // the list of non-static tensors that feed into the subgraph for inference. + inputs:[int]; + + // Indices of the tensors that are outputs out of this subgraph. Note this is + // the list of output tensors that are considered the product of the + // subgraph's inference. + outputs:[int]; + + // All operators, in execution order. + operators:[Operator]; + + // Name of this subgraph (used for debugging). + name:string; +} + +// Table of raw data buffers (used for constant tensors). Referenced by tensors +// by index. The generous alignment accommodates mmap-friendly data structures. +table Buffer { + data:[ubyte] (force_align: 16); +} + +table Metadata { + // A human readable string to uniquely identify a Metadata. + name:string; + // An index to the buffers table. + buffer:uint; +} + +table Model { + // Version of the schema. + version:uint; + + // A list of all operator codes used in this model. This is + // kept in order because operators carry an index into this + // vector. + operator_codes:[OperatorCode]; + + // All the subgraphs of the model. The 0th is assumed to be the main + // model. + subgraphs:[SubGraph]; + + // A description of the model. + description:string; + + // Buffers of the model. + // Note the 0th entry of this array must be an empty buffer (sentinel). + // This is a convention so that tensors without a buffer can provide 0 as + // their buffer. + buffers:[Buffer]; + + // Metadata about the model. Indirects into the existings buffers list. + // Deprecated, prefer to use metadata field. + metadata_buffer:[int]; + + // Metadata about the model. + metadata:[Metadata]; +} + +root_type Model; diff --git a/mindspore/lite/tools/dataset/cropper/crop.sh b/mindspore/lite/tools/dataset/cropper/crop.sh old mode 100755 new mode 100644 diff --git a/mindspore/lite/tools/kernel_builder/ascend/ascendc/cmake/util/gen_impl_and_mrege_json.sh b/mindspore/lite/tools/kernel_builder/ascend/ascendc/cmake/util/gen_impl_and_mrege_json.sh old mode 100755 new mode 100644 diff --git a/mindspore/lite/tools/kernel_builder/ascend/ascendc/cmake/util/gen_ops_filter.sh b/mindspore/lite/tools/kernel_builder/ascend/ascendc/cmake/util/gen_ops_filter.sh old mode 100755 new mode 100644 diff --git a/mindspore/lite/tools/kernel_builder/ascend/ascendc/cmake/util/gen_version_info.sh b/mindspore/lite/tools/kernel_builder/ascend/ascendc/cmake/util/gen_version_info.sh old mode 100755 new mode 100644 diff --git a/mindspore/lite/tools/kernel_builder/ascend/ascendc/cmake/util/merge_aicpu_info_json.sh b/mindspore/lite/tools/kernel_builder/ascend/ascendc/cmake/util/merge_aicpu_info_json.sh old mode 100755 new mode 100644 diff --git a/mindspore/lite/tools/kernel_builder/ascend/scripts/install.sh b/mindspore/lite/tools/kernel_builder/ascend/scripts/install.sh old mode 100755 new mode 100644 diff --git a/mindspore/lite/tools/kernel_builder/ascend/tbe_and_aicpu/cmake/util/gen_impl_and_mrege_json.sh b/mindspore/lite/tools/kernel_builder/ascend/tbe_and_aicpu/cmake/util/gen_impl_and_mrege_json.sh old mode 100755 new mode 100644 diff --git a/mindspore/lite/tools/kernel_builder/ascend/tbe_and_aicpu/cmake/util/gen_ops_filter.sh b/mindspore/lite/tools/kernel_builder/ascend/tbe_and_aicpu/cmake/util/gen_ops_filter.sh old mode 100755 new mode 100644 diff --git a/mindspore/lite/tools/kernel_builder/ascend/tbe_and_aicpu/cmake/util/merge_aicpu_info_json.sh b/mindspore/lite/tools/kernel_builder/ascend/tbe_and_aicpu/cmake/util/merge_aicpu_info_json.sh old mode 100755 new mode 100644 diff --git a/mindspore/lite/tools/kernel_builder/ascend/tbe_and_aicpu/tbe/impl/matmul_tik.py b/mindspore/lite/tools/kernel_builder/ascend/tbe_and_aicpu/tbe/impl/matmul_tik.py index aec168eea3a..025a9373ba9 100644 --- a/mindspore/lite/tools/kernel_builder/ascend/tbe_and_aicpu/tbe/impl/matmul_tik.py +++ b/mindspore/lite/tools/kernel_builder/ascend/tbe_and_aicpu/tbe/impl/matmul_tik.py @@ -1,212 +1,212 @@ -""" -Copyright 2020 Huawei Technologies Co., Ltd. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - -matmul_tik -""" - -from tbe import tik -from tbe.common.platform import get_soc_spec - -DTYPE_SIZE = { - 'bool': 1, - 'uint8': 1, - 'int8': 1, - 'uint16': 2, - 'int16': 2, - 'int24': 3, - 'uint32': 4, - 'int32': 4, - 'float16': 2, - 'float32': 4, - 'int48': 6, - 'int64': 8, - 'uint64': 8, - 'float64': 8 -} - - -def MK_TO_K1MK0(tik_instance, mk_input_tensor, k1mk0_tensor, dtype, k1, m, k0): - """data move mk to k1mk0""" - src_ub = tik_instance.Tensor(dtype, (k1, m, k0), name='src_ub', scope=tik.scope_ubuf) - - # data_move(m, k) ---> (k1, m, k0) - with tik_instance.for_range(0, k1) as i: - tik_instance.data_move(src_ub[i * m * k0:], mk_input_tensor[i * k0:], 0, m, k0 * DTYPE_SIZE[dtype] // 32, - (k1 - 1) * k0 * DTYPE_SIZE[dtype] // 32, 0) - - tik_instance.data_move(k1mk0_tensor, src_ub, 0, 1, k1 * m * k0 * DTYPE_SIZE[dtype] // 32, 0, 0) - - -def KN_TO_K1NK0(tik_instance, kn_input_tensor, k1nk0_tensor, dtype, k1, n, k0): - """data move kn to k1nk0""" - - with tik_instance.for_range(0, k1) as index: - k1nk0_ub = tik_instance.Tensor(dtype, (n, k0), tik.scope_ubuf, "k1nk0_ub") - src_ub = tik_instance.Tensor(dtype, (k0, n), tik.scope_ubuf, "src_ub") - burst_len = k0 * n * DTYPE_SIZE[dtype] // 32 - tik_instance.data_move(src_ub, kn_input_tensor[index * k0 * n], 0, 1, burst_len, 0, 0) - dst_list = [k1nk0_ub[16 * i] for i in range(16)] - src_list = [src_ub[n * i] for i in range(16)] - rep_times = n // k0 - dst_rep_stride = k0 - src_rep_stride = 1 - tik_instance.vec_trans_scatter(False, False, dst_list, src_list, rep_times, dst_rep_stride, src_rep_stride) - tik_instance.data_move(k1nk0_tensor[index * k0 * n], k1nk0_ub, 0, 1, burst_len, 0, 0) - - -def N1MN0_TO_MN(tik_instance, mn_output_tensor, n1mn0_tensor, dtype, n1, m, n0): - """data move mn to n1mn0""" - src_ub = tik_instance.Tensor(dtype, (m, n1 * n0), name='src_ub', scope=tik.scope_ubuf) - - # data_move(n1, m, n0) ---> (m, n) - with tik_instance.for_range(0, n1) as i: - tik_instance.data_move(src_ub[i * n0:], n1mn0_tensor[i * m * n0:], 0, m, - n0 * DTYPE_SIZE[dtype] // 32, 0, (n1 - 1) * n0 * DTYPE_SIZE[dtype] // 32) - - tik_instance.data_move(mn_output_tensor, src_ub, 0, 1, m * n1 * n0 * DTYPE_SIZE[dtype] // 32, 0, 0) - - -def matmul_tik_compute(params, kernel_name): - """ - matmul tik compute - @param params: matmul data - @param kernel_name: kernel name - @return: tik instance - """ - tik_instance = tik.Tik() - if not isinstance(params, dict): - params = params.__dict__ - m_size, k_size, n_size = params['M'], params['K'], params['N'] - data_type = params["data_type"] - m_tiling_size = int(params["m_tiling_size"]) - n_tiling_size = int(params["n_tiling_size"]) - k_tiling_size = int(params['k_tiling_size']) - - m_cycle_times = params["m_cycle_times"] - n_cycle_times = params["n_cycle_times"] - k_cycle_times = params["k_cycle_times"] - - # Determine the output type - if data_type == "float16": - if get_soc_spec("SOC_VERSION") in ["SD3403", "OPTG", "Hi3796CV300CS", "TsnsC"]: - C_loc_out_type = "float16" - else: - C_loc_out_type = "float32" - K0 = 16 - else: - C_loc_out_type = "int32" - K0 = 32 - block_size = 16 - - n_thread_num = params['n_thread_num'] - m_thread_num = params['m_thread_num'] - k_thread_num = params['k_thread_num'] - - mk_gm_input = tik_instance.Tensor(data_type, (m_size, k_size), name="mk_input_gm", scope=tik.scope_gm) - kn_gm_input = tik_instance.Tensor(data_type, (k_size, n_size), name="kn_input_gm", scope=tik.scope_gm) - - k1mk0_workspace = tik_instance.Tensor(data_type, (k_size // K0, m_size, K0), name="k1mk0_workspace", - scope=tik.scope_gm, is_workspace=True) - - k1nk0_workspace = tik_instance.Tensor(data_type, (k_size // K0, n_size, K0), name="k1nk0_workspace", - scope=tik.scope_gm, is_workspace=True) - - mn_gm_output = tik_instance.Tensor(C_loc_out_type, (m_size, n_size), tik.scope_gm, name="mn_output_gm") - nmk0_workspace = tik_instance.Tensor(C_loc_out_type, (n_size // block_size, m_size, block_size), - name="nmk0_workspace", scope=tik.scope_gm, is_workspace=True) - - MK_TO_K1MK0(tik_instance, mk_gm_input, k1mk0_workspace, data_type, k_size // K0, m_size, K0) - KN_TO_K1NK0(tik_instance, kn_gm_input, k1nk0_workspace, data_type, k_size // K0, n_size, K0) - - # Tiling is realized through the for_range() loop. - with tik_instance.for_range(0, 2, block_num=1) as core_id: - with tik_instance.for_range(0, n_cycle_times // 2, thread_num=n_thread_num) as n_idx: - with tik_instance.for_range(0, m_cycle_times, thread_num=m_thread_num) as m_idx: - dst_l0c = tik_instance.Tensor(C_loc_out_type, [n_tiling_size // 16, m_tiling_size, 16], name='dst_l0c', - scope=tik.scope_cbuf_out) - with tik_instance.for_range(0, k_cycle_times, - thread_num=k_thread_num) as k_idx: - # Calculation result data transfer. - inputa_l1 = tik_instance.Tensor(params['data_type'], [k_tiling_size // K0, m_tiling_size, K0], - name="A_tiling_l1", scope=tik.scope_cbuf) - tik_instance.data_move(inputa_l1, - k1mk0_workspace[k_idx * k_tiling_size // K0, m_idx * m_tiling_size, :], - 0, k_tiling_size // K0, m_tiling_size, m_size - m_tiling_size, 0) - inputb_l1 = tik_instance.Tensor(params["data_type"], [k_tiling_size // K0, n_tiling_size, K0], - name="B_tiling_l1", scope=tik.scope_cbuf) - if n_size - n_tiling_size > 65535: - with tik_instance.for_range(0, k_tiling_size // K0) \ - as dma_k_idx: - tik_instance.data_move(inputb_l1[dma_k_idx, :, :], - k1nk0_workspace[k_idx * k_tiling_size // K0 + dma_k_idx, - (core_id * n_cycle_times // 2 + n_idx) - * n_tiling_size, :], - 0, 1, n_tiling_size, 0, 0) - else: - tik_instance.data_move(inputb_l1, k1nk0_workspace[k_idx * k_tiling_size // K0, - (core_id * n_cycle_times // 2 + n_idx) - * n_tiling_size, :], - 0, k_tiling_size // K0, n_tiling_size, n_size - n_tiling_size, 0) - # Call matmul API to matrix multiplication calculation. - with tik_instance.if_scope(k_idx == 0): - tik_instance.matmul(dst_l0c, inputa_l1, inputb_l1, m_tiling_size, k_tiling_size, n_tiling_size, - init_l1out=True) - with tik_instance.else_scope(): - tik_instance.matmul(dst_l0c, inputa_l1, inputb_l1, m_tiling_size, k_tiling_size, n_tiling_size, - init_l1out=False) - tik_instance.fixpipe(nmk0_workspace[n_tiling_size // 16 * (core_id * n_cycle_times // 2 + n_idx), - m_idx * m_tiling_size, :], - dst_l0c, n_tiling_size // 16, - m_tiling_size * 16 * DTYPE_SIZE[C_loc_out_type] // 32, - (m_size - m_tiling_size) * 16 * DTYPE_SIZE[C_loc_out_type] // 32, 0) - - N1MN0_TO_MN(tik_instance, mn_gm_output, nmk0_workspace, C_loc_out_type, n_size // K0, m_size, K0) - - tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[mk_gm_input, kn_gm_input], outputs=[mn_gm_output]) - return tik_instance - - -def matmul_tik(input_x1, input_x2, output_y=None, kernel_name="simple_matmul"): - """ - matmul_tik main func - Parameters - ---------- - input_x1: input data 1 - input_x2: input data 2 - output_y: output dta - """ - shape_a = input_x1.get("ori_shape") - shape_b = input_x2.get("ori_shape") - m = shape_a[0] - k = shape_a[1] - n = shape_b[1] - data_type = input_x1.get("dtype").lower() - params = { - 'M': m, - 'K': k, - 'N': n, - 'data_type': data_type, - 'm_tiling_size': 16, - 'm_cycle_times': 1, - 'm_thread_num': 1, - 'n_tiling_size': 64, - 'n_cycle_times': 16, - 'n_thread_num': 1, - 'k_tiling_size': 32, - 'k_cycle_times': 2, - 'k_thread_num': 2, - 'output_y': output_y - } - return matmul_tik_compute(params, kernel_name) +""" +Copyright 2020 Huawei Technologies Co., Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +matmul_tik +""" + +from tbe import tik +from tbe.common.platform import get_soc_spec + +DTYPE_SIZE = { + 'bool': 1, + 'uint8': 1, + 'int8': 1, + 'uint16': 2, + 'int16': 2, + 'int24': 3, + 'uint32': 4, + 'int32': 4, + 'float16': 2, + 'float32': 4, + 'int48': 6, + 'int64': 8, + 'uint64': 8, + 'float64': 8 +} + + +def MK_TO_K1MK0(tik_instance, mk_input_tensor, k1mk0_tensor, dtype, k1, m, k0): + """data move mk to k1mk0""" + src_ub = tik_instance.Tensor(dtype, (k1, m, k0), name='src_ub', scope=tik.scope_ubuf) + + # data_move(m, k) ---> (k1, m, k0) + with tik_instance.for_range(0, k1) as i: + tik_instance.data_move(src_ub[i * m * k0:], mk_input_tensor[i * k0:], 0, m, k0 * DTYPE_SIZE[dtype] // 32, + (k1 - 1) * k0 * DTYPE_SIZE[dtype] // 32, 0) + + tik_instance.data_move(k1mk0_tensor, src_ub, 0, 1, k1 * m * k0 * DTYPE_SIZE[dtype] // 32, 0, 0) + + +def KN_TO_K1NK0(tik_instance, kn_input_tensor, k1nk0_tensor, dtype, k1, n, k0): + """data move kn to k1nk0""" + + with tik_instance.for_range(0, k1) as index: + k1nk0_ub = tik_instance.Tensor(dtype, (n, k0), tik.scope_ubuf, "k1nk0_ub") + src_ub = tik_instance.Tensor(dtype, (k0, n), tik.scope_ubuf, "src_ub") + burst_len = k0 * n * DTYPE_SIZE[dtype] // 32 + tik_instance.data_move(src_ub, kn_input_tensor[index * k0 * n], 0, 1, burst_len, 0, 0) + dst_list = [k1nk0_ub[16 * i] for i in range(16)] + src_list = [src_ub[n * i] for i in range(16)] + rep_times = n // k0 + dst_rep_stride = k0 + src_rep_stride = 1 + tik_instance.vec_trans_scatter(False, False, dst_list, src_list, rep_times, dst_rep_stride, src_rep_stride) + tik_instance.data_move(k1nk0_tensor[index * k0 * n], k1nk0_ub, 0, 1, burst_len, 0, 0) + + +def N1MN0_TO_MN(tik_instance, mn_output_tensor, n1mn0_tensor, dtype, n1, m, n0): + """data move mn to n1mn0""" + src_ub = tik_instance.Tensor(dtype, (m, n1 * n0), name='src_ub', scope=tik.scope_ubuf) + + # data_move(n1, m, n0) ---> (m, n) + with tik_instance.for_range(0, n1) as i: + tik_instance.data_move(src_ub[i * n0:], n1mn0_tensor[i * m * n0:], 0, m, + n0 * DTYPE_SIZE[dtype] // 32, 0, (n1 - 1) * n0 * DTYPE_SIZE[dtype] // 32) + + tik_instance.data_move(mn_output_tensor, src_ub, 0, 1, m * n1 * n0 * DTYPE_SIZE[dtype] // 32, 0, 0) + + +def matmul_tik_compute(params, kernel_name): + """ + matmul tik compute + @param params: matmul data + @param kernel_name: kernel name + @return: tik instance + """ + tik_instance = tik.Tik() + if not isinstance(params, dict): + params = params.__dict__ + m_size, k_size, n_size = params['M'], params['K'], params['N'] + data_type = params["data_type"] + m_tiling_size = int(params["m_tiling_size"]) + n_tiling_size = int(params["n_tiling_size"]) + k_tiling_size = int(params['k_tiling_size']) + + m_cycle_times = params["m_cycle_times"] + n_cycle_times = params["n_cycle_times"] + k_cycle_times = params["k_cycle_times"] + + # Determine the output type + if data_type == "float16": + if get_soc_spec("SOC_VERSION") in ["SD3403", "OPTG", "Hi3796CV300CS", "TsnsC"]: + C_loc_out_type = "float16" + else: + C_loc_out_type = "float32" + K0 = 16 + else: + C_loc_out_type = "int32" + K0 = 32 + block_size = 16 + + n_thread_num = params['n_thread_num'] + m_thread_num = params['m_thread_num'] + k_thread_num = params['k_thread_num'] + + mk_gm_input = tik_instance.Tensor(data_type, (m_size, k_size), name="mk_input_gm", scope=tik.scope_gm) + kn_gm_input = tik_instance.Tensor(data_type, (k_size, n_size), name="kn_input_gm", scope=tik.scope_gm) + + k1mk0_workspace = tik_instance.Tensor(data_type, (k_size // K0, m_size, K0), name="k1mk0_workspace", + scope=tik.scope_gm, is_workspace=True) + + k1nk0_workspace = tik_instance.Tensor(data_type, (k_size // K0, n_size, K0), name="k1nk0_workspace", + scope=tik.scope_gm, is_workspace=True) + + mn_gm_output = tik_instance.Tensor(C_loc_out_type, (m_size, n_size), tik.scope_gm, name="mn_output_gm") + nmk0_workspace = tik_instance.Tensor(C_loc_out_type, (n_size // block_size, m_size, block_size), + name="nmk0_workspace", scope=tik.scope_gm, is_workspace=True) + + MK_TO_K1MK0(tik_instance, mk_gm_input, k1mk0_workspace, data_type, k_size // K0, m_size, K0) + KN_TO_K1NK0(tik_instance, kn_gm_input, k1nk0_workspace, data_type, k_size // K0, n_size, K0) + + # Tiling is realized through the for_range() loop. + with tik_instance.for_range(0, 2, block_num=1) as core_id: + with tik_instance.for_range(0, n_cycle_times // 2, thread_num=n_thread_num) as n_idx: + with tik_instance.for_range(0, m_cycle_times, thread_num=m_thread_num) as m_idx: + dst_l0c = tik_instance.Tensor(C_loc_out_type, [n_tiling_size // 16, m_tiling_size, 16], name='dst_l0c', + scope=tik.scope_cbuf_out) + with tik_instance.for_range(0, k_cycle_times, + thread_num=k_thread_num) as k_idx: + # Calculation result data transfer. + inputa_l1 = tik_instance.Tensor(params['data_type'], [k_tiling_size // K0, m_tiling_size, K0], + name="A_tiling_l1", scope=tik.scope_cbuf) + tik_instance.data_move(inputa_l1, + k1mk0_workspace[k_idx * k_tiling_size // K0, m_idx * m_tiling_size, :], + 0, k_tiling_size // K0, m_tiling_size, m_size - m_tiling_size, 0) + inputb_l1 = tik_instance.Tensor(params["data_type"], [k_tiling_size // K0, n_tiling_size, K0], + name="B_tiling_l1", scope=tik.scope_cbuf) + if n_size - n_tiling_size > 65535: + with tik_instance.for_range(0, k_tiling_size // K0) \ + as dma_k_idx: + tik_instance.data_move(inputb_l1[dma_k_idx, :, :], + k1nk0_workspace[k_idx * k_tiling_size // K0 + dma_k_idx, + (core_id * n_cycle_times // 2 + n_idx) + * n_tiling_size, :], + 0, 1, n_tiling_size, 0, 0) + else: + tik_instance.data_move(inputb_l1, k1nk0_workspace[k_idx * k_tiling_size // K0, + (core_id * n_cycle_times // 2 + n_idx) + * n_tiling_size, :], + 0, k_tiling_size // K0, n_tiling_size, n_size - n_tiling_size, 0) + # Call matmul API to matrix multiplication calculation. + with tik_instance.if_scope(k_idx == 0): + tik_instance.matmul(dst_l0c, inputa_l1, inputb_l1, m_tiling_size, k_tiling_size, n_tiling_size, + init_l1out=True) + with tik_instance.else_scope(): + tik_instance.matmul(dst_l0c, inputa_l1, inputb_l1, m_tiling_size, k_tiling_size, n_tiling_size, + init_l1out=False) + tik_instance.fixpipe(nmk0_workspace[n_tiling_size // 16 * (core_id * n_cycle_times // 2 + n_idx), + m_idx * m_tiling_size, :], + dst_l0c, n_tiling_size // 16, + m_tiling_size * 16 * DTYPE_SIZE[C_loc_out_type] // 32, + (m_size - m_tiling_size) * 16 * DTYPE_SIZE[C_loc_out_type] // 32, 0) + + N1MN0_TO_MN(tik_instance, mn_gm_output, nmk0_workspace, C_loc_out_type, n_size // K0, m_size, K0) + + tik_instance.BuildCCE(kernel_name=kernel_name, inputs=[mk_gm_input, kn_gm_input], outputs=[mn_gm_output]) + return tik_instance + + +def matmul_tik(input_x1, input_x2, output_y=None, kernel_name="simple_matmul"): + """ + matmul_tik main func + Parameters + ---------- + input_x1: input data 1 + input_x2: input data 2 + output_y: output dta + """ + shape_a = input_x1.get("ori_shape") + shape_b = input_x2.get("ori_shape") + m = shape_a[0] + k = shape_a[1] + n = shape_b[1] + data_type = input_x1.get("dtype").lower() + params = { + 'M': m, + 'K': k, + 'N': n, + 'data_type': data_type, + 'm_tiling_size': 16, + 'm_cycle_times': 1, + 'm_thread_num': 1, + 'n_tiling_size': 64, + 'n_cycle_times': 16, + 'n_thread_num': 1, + 'k_tiling_size': 32, + 'k_cycle_times': 2, + 'k_thread_num': 2, + 'output_y': output_y + } + return matmul_tik_compute(params, kernel_name) diff --git a/mindspore/lite/tools/kernel_builder/ascend/tbe_and_aicpu/tbe/op_info_cfg/ai_core/ascend310/matmul_tik.ini b/mindspore/lite/tools/kernel_builder/ascend/tbe_and_aicpu/tbe/op_info_cfg/ai_core/ascend310/matmul_tik.ini index 9f5f156633c..50c0b552a69 100644 --- a/mindspore/lite/tools/kernel_builder/ascend/tbe_and_aicpu/tbe/op_info_cfg/ai_core/ascend310/matmul_tik.ini +++ b/mindspore/lite/tools/kernel_builder/ascend/tbe_and_aicpu/tbe/op_info_cfg/ai_core/ascend310/matmul_tik.ini @@ -1,20 +1,20 @@ -[MatmulTik] -input0.name=x1 -input0.dtype=int8,uint8,float16 -input0.shape=all -input0.needCompile=false -input0.paramType=required -input0.format=ND,ND,ND -input1.name=x2 -input1.dtype=int8,int8,float16 -input1.shape=all -input1.needCompile=false -input1.paramType=required -input1.format=ND,ND,ND -output0.name=y -output0.dtype=int32,int32,float -output0.shape=all -output0.paramType=required -output0.format=ND,ND,ND -opFile.value=matmul_tik -opInterface.value=matmul_tik +[MatmulTik] +input0.name=x1 +input0.dtype=int8,uint8,float16 +input0.shape=all +input0.needCompile=false +input0.paramType=required +input0.format=ND,ND,ND +input1.name=x2 +input1.dtype=int8,int8,float16 +input1.shape=all +input1.needCompile=false +input1.paramType=required +input1.format=ND,ND,ND +output0.name=y +output0.dtype=int32,int32,float +output0.shape=all +output0.paramType=required +output0.format=ND,ND,ND +opFile.value=matmul_tik +opInterface.value=matmul_tik diff --git a/mindspore/lite/tools/kernel_builder/ascend/tbe_and_aicpu/tbe/op_info_cfg/ai_core/ascend310p/matmul_tik.ini b/mindspore/lite/tools/kernel_builder/ascend/tbe_and_aicpu/tbe/op_info_cfg/ai_core/ascend310p/matmul_tik.ini index 9f5f156633c..50c0b552a69 100644 --- a/mindspore/lite/tools/kernel_builder/ascend/tbe_and_aicpu/tbe/op_info_cfg/ai_core/ascend310p/matmul_tik.ini +++ b/mindspore/lite/tools/kernel_builder/ascend/tbe_and_aicpu/tbe/op_info_cfg/ai_core/ascend310p/matmul_tik.ini @@ -1,20 +1,20 @@ -[MatmulTik] -input0.name=x1 -input0.dtype=int8,uint8,float16 -input0.shape=all -input0.needCompile=false -input0.paramType=required -input0.format=ND,ND,ND -input1.name=x2 -input1.dtype=int8,int8,float16 -input1.shape=all -input1.needCompile=false -input1.paramType=required -input1.format=ND,ND,ND -output0.name=y -output0.dtype=int32,int32,float -output0.shape=all -output0.paramType=required -output0.format=ND,ND,ND -opFile.value=matmul_tik -opInterface.value=matmul_tik +[MatmulTik] +input0.name=x1 +input0.dtype=int8,uint8,float16 +input0.shape=all +input0.needCompile=false +input0.paramType=required +input0.format=ND,ND,ND +input1.name=x2 +input1.dtype=int8,int8,float16 +input1.shape=all +input1.needCompile=false +input1.paramType=required +input1.format=ND,ND,ND +output0.name=y +output0.dtype=int32,int32,float +output0.shape=all +output0.paramType=required +output0.format=ND,ND,ND +opFile.value=matmul_tik +opInterface.value=matmul_tik diff --git a/mindspore/lite/tools/kernel_builder/ascend/tbe_and_aicpu/tbe/op_info_cfg/ai_core/ascend910/matmul_tik.ini b/mindspore/lite/tools/kernel_builder/ascend/tbe_and_aicpu/tbe/op_info_cfg/ai_core/ascend910/matmul_tik.ini index 9f5f156633c..50c0b552a69 100644 --- a/mindspore/lite/tools/kernel_builder/ascend/tbe_and_aicpu/tbe/op_info_cfg/ai_core/ascend910/matmul_tik.ini +++ b/mindspore/lite/tools/kernel_builder/ascend/tbe_and_aicpu/tbe/op_info_cfg/ai_core/ascend910/matmul_tik.ini @@ -1,20 +1,20 @@ -[MatmulTik] -input0.name=x1 -input0.dtype=int8,uint8,float16 -input0.shape=all -input0.needCompile=false -input0.paramType=required -input0.format=ND,ND,ND -input1.name=x2 -input1.dtype=int8,int8,float16 -input1.shape=all -input1.needCompile=false -input1.paramType=required -input1.format=ND,ND,ND -output0.name=y -output0.dtype=int32,int32,float -output0.shape=all -output0.paramType=required -output0.format=ND,ND,ND -opFile.value=matmul_tik -opInterface.value=matmul_tik +[MatmulTik] +input0.name=x1 +input0.dtype=int8,uint8,float16 +input0.shape=all +input0.needCompile=false +input0.paramType=required +input0.format=ND,ND,ND +input1.name=x2 +input1.dtype=int8,int8,float16 +input1.shape=all +input1.needCompile=false +input1.paramType=required +input1.format=ND,ND,ND +output0.name=y +output0.dtype=int32,int32,float +output0.shape=all +output0.paramType=required +output0.format=ND,ND,ND +opFile.value=matmul_tik +opInterface.value=matmul_tik diff --git a/mindspore/lite/tools/mslite_bench/mslite_bench/__init__.py b/mindspore/lite/tools/mslite_bench/mslite_bench/__init__.py index 0fa0a0c2a26..e4539525f45 100644 --- a/mindspore/lite/tools/mslite_bench/mslite_bench/__init__.py +++ b/mindspore/lite/tools/mslite_bench/mslite_bench/__init__.py @@ -1,30 +1,30 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -mslite bench classes and functions -""" - -from mslite_bench.common.config import ( - MsliteConfig, PaddleConfig, OnnxConfig, TFConfig -) -from mslite_bench.infer_base.infer_session_factory import InferSessionFactory -from mslite_bench.tools.cross_framework_accuracy import CrossFrameworkAccSummary - -acc_info_between_features = CrossFrameworkAccSummary.acc_infos_between_features - -__all__ = [ - 'InferSessionFactory', 'MsliteConfig', 'PaddleConfig', 'OnnxConfig', 'TFConfig', - 'acc_info_between_features' -] +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +mslite bench classes and functions +""" + +from mslite_bench.common.config import ( + MsliteConfig, PaddleConfig, OnnxConfig, TFConfig +) +from mslite_bench.infer_base.infer_session_factory import InferSessionFactory +from mslite_bench.tools.cross_framework_accuracy import CrossFrameworkAccSummary + +acc_info_between_features = CrossFrameworkAccSummary.acc_infos_between_features + +__all__ = [ + 'InferSessionFactory', 'MsliteConfig', 'PaddleConfig', 'OnnxConfig', 'TFConfig', + 'acc_info_between_features' +] diff --git a/mindspore/lite/tools/mslite_bench/mslite_bench/__main__.py b/mindspore/lite/tools/mslite_bench/mslite_bench/__main__.py index 03097e945fd..4928138f478 100644 --- a/mindspore/lite/tools/mslite_bench/mslite_bench/__main__.py +++ b/mindspore/lite/tools/mslite_bench/mslite_bench/__main__.py @@ -1,80 +1,80 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -__main__ for mslite_bench -""" - -import os - -from mslite_bench.utils import ArgParser, InferLogger -from mslite_bench.common.model_info_enum import TaskType -from mslite_bench.common.task_common_func import CommonFunc - - -os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' -os.environ['KMP_WARNINGS'] = '0' -os.environ['GLOG_v'] = '3' - -if __name__ == '__main__': - args = ArgParser.parse_arguments() - mslite_logger = InferLogger(args.log_path) - mslite_logger.set_level(CommonFunc.logging_level(args.log_level)) - logger = mslite_logger.logger - logger.debug('Start model infer now!') - - if args.task_type.lower() == TaskType.FRAMEWORK_CMP.value: - try: - from mslite_bench.tools.cross_framework_accuracy import CrossFrameworkAccSummary - except ImportError as e: - logger.error('Failed to import CFA: %s', e) - raise - logger.debug('Start framework compare task!') - CrossFrameworkAccSummary.accuracy_compare_func(args, logger) - elif args.task_type.lower() in set( - [TaskType.NPU_DYNAMIC_INFER.value, TaskType.MODEL_INFER.value] - ): - try: - from mslite_bench.tools.easy_infer import EasyInfer - except ImportError as e: - logger.error('Failed to import easy infer: %s', e) - raise - if args.task_type.lower() == TaskType.NPU_DYNAMIC_INFER.value: - logger.debug('Start mslite model dynamic infer task!') - EasyInfer.ms_dynamic_input_infer(args, logger) - else: - logger.debug('Start model infer: %s', args.model_file) - EasyInfer.easy_infer(args, logger) - elif args.task_type.lower() == TaskType.CONVERTER.value: - try: - from mslite_bench.tools.converter import MsliteConverter - except ImportError as e: - logger.error('Failed to import MsliteConverter class') - raise - MsliteConverter.convert(args, logger) - elif args.task_type.lower() == TaskType.AUTO_CMP.value: - try: - from mslite_bench.tools.mslite_auto_cmp import MsliteAutoCMP - except ImportError as e: - logger.error('Failed to import MsliteAutoCMP class') - raise - if args.input_tensor_shapes is None: - logger.error('Shall input input_tensor_shapes for accuracy compare') - raise ValueError('input_tensor_shapes is None') - if args.input_tensor_dtypes is None: - logger.error('Shall input input_tensor_dtypes for accuracy compare') - raise ValueError('input_tensor_dtypes is None') - MsliteAutoCMP.acc_infos_in_specific_node(args, logger) - else: - raise NotImplementedError(f'Task Type {args.task_type} ') +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +__main__ for mslite_bench +""" + +import os + +from mslite_bench.utils import ArgParser, InferLogger +from mslite_bench.common.model_info_enum import TaskType +from mslite_bench.common.task_common_func import CommonFunc + + +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' +os.environ['KMP_WARNINGS'] = '0' +os.environ['GLOG_v'] = '3' + +if __name__ == '__main__': + args = ArgParser.parse_arguments() + mslite_logger = InferLogger(args.log_path) + mslite_logger.set_level(CommonFunc.logging_level(args.log_level)) + logger = mslite_logger.logger + logger.debug('Start model infer now!') + + if args.task_type.lower() == TaskType.FRAMEWORK_CMP.value: + try: + from mslite_bench.tools.cross_framework_accuracy import CrossFrameworkAccSummary + except ImportError as e: + logger.error('Failed to import CFA: %s', e) + raise + logger.debug('Start framework compare task!') + CrossFrameworkAccSummary.accuracy_compare_func(args, logger) + elif args.task_type.lower() in set( + [TaskType.NPU_DYNAMIC_INFER.value, TaskType.MODEL_INFER.value] + ): + try: + from mslite_bench.tools.easy_infer import EasyInfer + except ImportError as e: + logger.error('Failed to import easy infer: %s', e) + raise + if args.task_type.lower() == TaskType.NPU_DYNAMIC_INFER.value: + logger.debug('Start mslite model dynamic infer task!') + EasyInfer.ms_dynamic_input_infer(args, logger) + else: + logger.debug('Start model infer: %s', args.model_file) + EasyInfer.easy_infer(args, logger) + elif args.task_type.lower() == TaskType.CONVERTER.value: + try: + from mslite_bench.tools.converter import MsliteConverter + except ImportError as e: + logger.error('Failed to import MsliteConverter class') + raise + MsliteConverter.convert(args, logger) + elif args.task_type.lower() == TaskType.AUTO_CMP.value: + try: + from mslite_bench.tools.mslite_auto_cmp import MsliteAutoCMP + except ImportError as e: + logger.error('Failed to import MsliteAutoCMP class') + raise + if args.input_tensor_shapes is None: + logger.error('Shall input input_tensor_shapes for accuracy compare') + raise ValueError('input_tensor_shapes is None') + if args.input_tensor_dtypes is None: + logger.error('Shall input input_tensor_dtypes for accuracy compare') + raise ValueError('input_tensor_dtypes is None') + MsliteAutoCMP.acc_infos_in_specific_node(args, logger) + else: + raise NotImplementedError(f'Task Type {args.task_type} ') diff --git a/mindspore/lite/tools/mslite_bench/mslite_bench/common/config.py b/mindspore/lite/tools/mslite_bench/mslite_bench/common/config.py index b92be2dad4c..b60a4532e1b 100644 --- a/mindspore/lite/tools/mslite_bench/mslite_bench/common/config.py +++ b/mindspore/lite/tools/mslite_bench/mslite_bench/common/config.py @@ -1,85 +1,85 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -configs for mslite bench -""" -from dataclasses import dataclass -from typing import List, Tuple, Dict - - -from mslite_bench.common.model_info_enum import FrameworkType - - -@dataclass -class Config: - """base config""" - device: str = 'cpu' - device_id: int = 0 - log_path: str = None - batch_size: int = 1 - - -class ModelConfig(Config): - """model config""" - infer_framework: FrameworkType = FrameworkType.MSLITE.value - thread_num: int = 1 - input_tensor_shapes: Dict[str, Tuple] = None - input_tensor_dtypes: Dict[str, str] = None - output_tensor_names: List[str] = None - - -@dataclass -class MsliteConfig(ModelConfig): - """mslite config""" - thread_affinity_mode: int = 2 - - ascend_provider: str = '' - - -@dataclass -class PaddleConfig(ModelConfig): - """paddle config""" - infer_framework = FrameworkType.PADDLE.value - is_fp16: bool = False - is_int8: bool = False - - # for paddle infer - is_enable_tensorrt: bool = False - gpu_memory_size: int = 100 - tensorrt_optim_input_shape: Dict[str, List[int]] = None - tensorrt_min_input_shape: Dict[str, List[int]] = None - tensorrt_max_input_shape: Dict[str, List[int]] = None - - -@dataclass -class OnnxConfig(ModelConfig): - """onnx config""" - # for onnx export - infer_framework = FrameworkType.ONNX.value - - -@dataclass -class TFConfig(ModelConfig): - """tensorflow config""" - infer_framework = FrameworkType.TF.value - - -@dataclass -class BenchConfig(Config): - """benchmark config""" - eps: float = 1e-5 - random_input_flag: bool = False - cmp_model_file: str = None - input_data_file: str = None +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +configs for mslite bench +""" +from dataclasses import dataclass +from typing import List, Tuple, Dict + + +from mslite_bench.common.model_info_enum import FrameworkType + + +@dataclass +class Config: + """base config""" + device: str = 'cpu' + device_id: int = 0 + log_path: str = None + batch_size: int = 1 + + +class ModelConfig(Config): + """model config""" + infer_framework: FrameworkType = FrameworkType.MSLITE.value + thread_num: int = 1 + input_tensor_shapes: Dict[str, Tuple] = None + input_tensor_dtypes: Dict[str, str] = None + output_tensor_names: List[str] = None + + +@dataclass +class MsliteConfig(ModelConfig): + """mslite config""" + thread_affinity_mode: int = 2 + + ascend_provider: str = '' + + +@dataclass +class PaddleConfig(ModelConfig): + """paddle config""" + infer_framework = FrameworkType.PADDLE.value + is_fp16: bool = False + is_int8: bool = False + + # for paddle infer + is_enable_tensorrt: bool = False + gpu_memory_size: int = 100 + tensorrt_optim_input_shape: Dict[str, List[int]] = None + tensorrt_min_input_shape: Dict[str, List[int]] = None + tensorrt_max_input_shape: Dict[str, List[int]] = None + + +@dataclass +class OnnxConfig(ModelConfig): + """onnx config""" + # for onnx export + infer_framework = FrameworkType.ONNX.value + + +@dataclass +class TFConfig(ModelConfig): + """tensorflow config""" + infer_framework = FrameworkType.TF.value + + +@dataclass +class BenchConfig(Config): + """benchmark config""" + eps: float = 1e-5 + random_input_flag: bool = False + cmp_model_file: str = None + input_data_file: str = None diff --git a/mindspore/lite/tools/mslite_bench/mslite_bench/common/enum_class.py b/mindspore/lite/tools/mslite_bench/mslite_bench/common/enum_class.py index 43c5a9cd206..8a72014148b 100644 --- a/mindspore/lite/tools/mslite_bench/mslite_bench/common/enum_class.py +++ b/mindspore/lite/tools/mslite_bench/mslite_bench/common/enum_class.py @@ -1,30 +1,30 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -data type classes -""" -from enum import Enum -import numpy as np - - -class NumpyDtype(Enum): - """numpy data type class""" - INT32 = np.dtype('int32') - INT64 = np.dtype('int64') - FLOAT32 = np.dtype('float32') - FLOAT64 = np.dtype('float64') - FLOAT16 = np.dtype('float16') - UINT8 = np.dtype('uint8') - INT8 = np.dtype('int8') +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +data type classes +""" +from enum import Enum +import numpy as np + + +class NumpyDtype(Enum): + """numpy data type class""" + INT32 = np.dtype('int32') + INT64 = np.dtype('int64') + FLOAT32 = np.dtype('float32') + FLOAT64 = np.dtype('float64') + FLOAT16 = np.dtype('float16') + UINT8 = np.dtype('uint8') + INT8 = np.dtype('int8') diff --git a/mindspore/lite/tools/mslite_bench/mslite_bench/common/model_info_enum.py b/mindspore/lite/tools/mslite_bench/mslite_bench/common/model_info_enum.py index f50d2c3eb99..5523d3986e2 100644 --- a/mindspore/lite/tools/mslite_bench/mslite_bench/common/model_info_enum.py +++ b/mindspore/lite/tools/mslite_bench/mslite_bench/common/model_info_enum.py @@ -1,60 +1,60 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -model related enum infos -""" - -from enum import Enum - - -class TaskType(Enum): - """task type enum""" - MODEL_INFER = "infer" - FRAMEWORK_CMP = "framework_cmp" - CONVERTER = "convert" - AUTO_CMP = "auto_cmp" - NPU_DYNAMIC_INFER = "npu_dynamic_infer" - - -class DeviceType(Enum): - """device type enum""" - CPU = 'cpu' - ASCEND = 'ascend' - GPU = 'gpu' - - -class FrameworkType(Enum): - """framework type enum""" - TF = 'TF' - ONNX = 'ONNX' - MSLITE = 'MSLITE' - PADDLE = 'PADDLE' - - -class SaveFileType(Enum): - """save file type enum""" - DONT_SAVE = 'dont_save' - NPY = 'npy' - BIN = 'bin' - - -class ErrorAlgType(Enum): - """ - Algorithm types to calculate error between features - - MEAN_RELATIVE_ERROR: sum(abs(A-B) / A) / A.size - - COSINE_SIMILARITY: sum(A * B) / (sqrt(sum(A * A)) * sqrt(sum(B * B))) - """ - MEAN_RELATIVE_ERROR = "mean_relative_error" - COSINE_SIMILARITY = "cosine_similarity" +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +model related enum infos +""" + +from enum import Enum + + +class TaskType(Enum): + """task type enum""" + MODEL_INFER = "infer" + FRAMEWORK_CMP = "framework_cmp" + CONVERTER = "convert" + AUTO_CMP = "auto_cmp" + NPU_DYNAMIC_INFER = "npu_dynamic_infer" + + +class DeviceType(Enum): + """device type enum""" + CPU = 'cpu' + ASCEND = 'ascend' + GPU = 'gpu' + + +class FrameworkType(Enum): + """framework type enum""" + TF = 'TF' + ONNX = 'ONNX' + MSLITE = 'MSLITE' + PADDLE = 'PADDLE' + + +class SaveFileType(Enum): + """save file type enum""" + DONT_SAVE = 'dont_save' + NPY = 'npy' + BIN = 'bin' + + +class ErrorAlgType(Enum): + """ + Algorithm types to calculate error between features + - MEAN_RELATIVE_ERROR: sum(abs(A-B) / A) / A.size + - COSINE_SIMILARITY: sum(A * B) / (sqrt(sum(A * A)) * sqrt(sum(B * B))) + """ + MEAN_RELATIVE_ERROR = "mean_relative_error" + COSINE_SIMILARITY = "cosine_similarity" diff --git a/mindspore/lite/tools/mslite_bench/mslite_bench/common/task_common_func.py b/mindspore/lite/tools/mslite_bench/mslite_bench/common/task_common_func.py index f53d6209b35..b8c8fa08b21 100644 --- a/mindspore/lite/tools/mslite_bench/mslite_bench/common/task_common_func.py +++ b/mindspore/lite/tools/mslite_bench/mslite_bench/common/task_common_func.py @@ -1,225 +1,225 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -common functions -""" -import logging -import os -import stat -from typing import Dict, Tuple -import importlib - -import numpy as np - -from mslite_bench.common.model_info_enum import FrameworkType -from mslite_bench.common.enum_class import NumpyDtype -from mslite_bench.common.config import ( - MsliteConfig, TFConfig, PaddleConfig, OnnxConfig -) - - -class CommonFunc: - """common functions""" - @classmethod - def get_framework_config(cls, - model_path, - args): - """ - get framework config by model type and args - params: - model_path: path to model file - args: input arguments - return: model config - """ - if not os.path.exists(model_path): - raise ValueError(f'Create model session failed: {model_path} does not exist') - - if model_path.endswith('pb'): - cfg = cls.init_tf_cfg() - elif model_path.endswith('onnx'): - cfg = cls.init_onnx_cfg() - elif model_path.endswith('ms') or model_path.endswith('mindir'): - cfg = cls.init_mslite_cfg(args, model_path) - elif model_path.endswith('pdmodel'): - cfg = cls.init_paddle_cfg(args) - else: - raise ValueError(f'model {model_path} is not supported yet') - - cfg.input_tensor_shapes = cls.get_tensor_shapes(args.input_tensor_shapes) - cfg.device = args.device - cfg.device_id = args.device_id - cfg.batch_size = args.batch_size - cfg.output_tensor_names = args.output_tensor_names - cfg.thread_num = args.thread_num - - if cfg.input_tensor_shapes is None and args.input_data_file is not None: - input_data_map = cls.get_input_data_map_from_file(args.input_data_file) - cfg.input_tensor_shapes = { - key: val.shape for key, val in input_data_map.items() - } - - return cfg - - @classmethod - def create_numpy_data_map(cls, - args): - """ - create input tensor map, with key input tensor name, - value its numpy value - """ - if args.input_data_file is not None: - input_data_map = np.load(args.input_data_file, allow_pickle=True).item() - return input_data_map - - input_tensor_dtypes = CommonFunc.parse_dtype_infos(args.input_tensor_dtypes) - input_tensor_shapes = CommonFunc.get_tensor_shapes(args.input_tensor_shapes) - input_tensor_infos = { - key: (shape, input_tensor_dtypes.get(key)) - for key, shape in input_tensor_shapes.items() - } - try: - input_tensor_map = cls.create_numpy_data_map_out(input_tensor_infos) - except ValueError as e: - raise e - - return input_tensor_map - - @classmethod - def init_onnx_cfg(cls): - """init onnx config""" - cfg = OnnxConfig() - return cfg - - @classmethod - def init_mslite_cfg(cls, args, model_path): - """init mslite config""" - cfg = MsliteConfig() - cfg.infer_framework = FrameworkType.MSLITE.value - cfg.mslite_model_type = 4 if model_path.endswith('ms') else 0 - cfg.thread_affinity_mode = args.thread_affinity_mode - cfg.ascend_provider = args.ascend_provider - return cfg - - @classmethod - def init_paddle_cfg(cls, args): - """init paddle config""" - cfg = PaddleConfig() - cfg.infer_framework = FrameworkType.PADDLE.value - cfg.is_fp16 = args.is_fp16 - cfg.is_int8 = args.is_int8 - cfg.is_enable_tensorrt = args.is_enable_tensorrt - def tmp_func(x): - if x is None: - return None - return cls.get_tensor_shapes(x) - cfg.tensorrt_optim_input_shape = tmp_func(args.tensorrt_optim_input_shape) - cfg.tensorrt_min_input_shape = tmp_func(args.tensorrt_min_input_shape) - cfg.tensorrt_max_input_shape = tmp_func(args.tensorrt_max_input_shape) - if cfg.tensorrt_min_input_shape is None: - cfg.tensorrt_min_input_shape = cfg.tensorrt_optim_input_shape - if cfg.tensorrt_max_input_shape is None: - cfg.tensorrt_max_input_shape = cfg.tensorrt_optim_input_shape - return cfg - - @staticmethod - def get_tensor_shapes(tensor_shapes: str) -> Dict[str, Tuple[int]]: - """parse tensor shapes string into dict""" - if tensor_shapes is None: - return {} - - input_tensor_shape = {} - shape_list = tensor_shapes.split(';') - - for shapes in shape_list: - name, shape = shapes.split(':') - shape = [int(i) for i in shape.split(',')] - input_tensor_shape[name] = shape - - return input_tensor_shape - - @staticmethod - def import_module(module_name, file_path=None): - """import module functions""" - return importlib.import_module(module_name, package=file_path) - - @staticmethod - def get_input_data_map_from_file(input_data_file): - """get input data map from file""" - return np.load(input_data_file, allow_pickle=True).item() - - @staticmethod - def create_numpy_data_map_out(tensor_infos): - """create numpy data dict""" - np_data_map = {} - for tensor_name, infos in tensor_infos.items(): - if not isinstance(infos, tuple): - raise ValueError('input info shall contain tensor shape and tensor dtype') - shape, dtype = infos - np_dtype = getattr(NumpyDtype, dtype.upper()).value - tensor_data = np.random.rand(*shape).astype(np_dtype) - np_data_map[tensor_name] = tensor_data - - return np_data_map - - @staticmethod - def save_output_as_benchmark_txt(save_dir, - output_tensor): - """save output tensor as benchmark type text""" - for key, value in output_tensor.items(): - save_path = f'{save_dir}_{"".join(key.split("/"))}.txt' - shape = value.shape - shape_str = '' - for val in shape: - shape_str = shape_str.join(f'{val} ') - dim = len(shape) - flags = os.O_WRONLY - mode = stat.S_IWUSR | stat.S_IRUSR - with os.fdopen(os.open(save_path, flags, mode), 'w') as fi: - fi.write(f'{key} {dim} {shape_str}\n') - np.savetxt(fi, value.flatten(), newline=' ') - - @staticmethod - def init_tf_cfg(): - """init tensorflow config""" - cfg = TFConfig() - return cfg - - @staticmethod - def logging_level(level): - if level == 0: - return logging.DEBUG - if level == 1: - return logging.INFO - if level == 2: - return logging.WARNING - return logging.ERROR - - @staticmethod - def parse_dtype_infos(dtype_infos): - """ - parse input dtype infos string to dict, key is input tensor, - value is tensor dtype - params: - model_path: path to model file - args: input arguments - return: model config - """ - infos = dtype_infos.split(';') - ret = {} - for info in infos: - key, dtype = info.split(':') - ret[key] = dtype.strip() - - return ret +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +common functions +""" +import logging +import os +import stat +from typing import Dict, Tuple +import importlib + +import numpy as np + +from mslite_bench.common.model_info_enum import FrameworkType +from mslite_bench.common.enum_class import NumpyDtype +from mslite_bench.common.config import ( + MsliteConfig, TFConfig, PaddleConfig, OnnxConfig +) + + +class CommonFunc: + """common functions""" + @classmethod + def get_framework_config(cls, + model_path, + args): + """ + get framework config by model type and args + params: + model_path: path to model file + args: input arguments + return: model config + """ + if not os.path.exists(model_path): + raise ValueError(f'Create model session failed: {model_path} does not exist') + + if model_path.endswith('pb'): + cfg = cls.init_tf_cfg() + elif model_path.endswith('onnx'): + cfg = cls.init_onnx_cfg() + elif model_path.endswith('ms') or model_path.endswith('mindir'): + cfg = cls.init_mslite_cfg(args, model_path) + elif model_path.endswith('pdmodel'): + cfg = cls.init_paddle_cfg(args) + else: + raise ValueError(f'model {model_path} is not supported yet') + + cfg.input_tensor_shapes = cls.get_tensor_shapes(args.input_tensor_shapes) + cfg.device = args.device + cfg.device_id = args.device_id + cfg.batch_size = args.batch_size + cfg.output_tensor_names = args.output_tensor_names + cfg.thread_num = args.thread_num + + if cfg.input_tensor_shapes is None and args.input_data_file is not None: + input_data_map = cls.get_input_data_map_from_file(args.input_data_file) + cfg.input_tensor_shapes = { + key: val.shape for key, val in input_data_map.items() + } + + return cfg + + @classmethod + def create_numpy_data_map(cls, + args): + """ + create input tensor map, with key input tensor name, + value its numpy value + """ + if args.input_data_file is not None: + input_data_map = np.load(args.input_data_file, allow_pickle=True).item() + return input_data_map + + input_tensor_dtypes = CommonFunc.parse_dtype_infos(args.input_tensor_dtypes) + input_tensor_shapes = CommonFunc.get_tensor_shapes(args.input_tensor_shapes) + input_tensor_infos = { + key: (shape, input_tensor_dtypes.get(key)) + for key, shape in input_tensor_shapes.items() + } + try: + input_tensor_map = cls.create_numpy_data_map_out(input_tensor_infos) + except ValueError as e: + raise e + + return input_tensor_map + + @classmethod + def init_onnx_cfg(cls): + """init onnx config""" + cfg = OnnxConfig() + return cfg + + @classmethod + def init_mslite_cfg(cls, args, model_path): + """init mslite config""" + cfg = MsliteConfig() + cfg.infer_framework = FrameworkType.MSLITE.value + cfg.mslite_model_type = 4 if model_path.endswith('ms') else 0 + cfg.thread_affinity_mode = args.thread_affinity_mode + cfg.ascend_provider = args.ascend_provider + return cfg + + @classmethod + def init_paddle_cfg(cls, args): + """init paddle config""" + cfg = PaddleConfig() + cfg.infer_framework = FrameworkType.PADDLE.value + cfg.is_fp16 = args.is_fp16 + cfg.is_int8 = args.is_int8 + cfg.is_enable_tensorrt = args.is_enable_tensorrt + def tmp_func(x): + if x is None: + return None + return cls.get_tensor_shapes(x) + cfg.tensorrt_optim_input_shape = tmp_func(args.tensorrt_optim_input_shape) + cfg.tensorrt_min_input_shape = tmp_func(args.tensorrt_min_input_shape) + cfg.tensorrt_max_input_shape = tmp_func(args.tensorrt_max_input_shape) + if cfg.tensorrt_min_input_shape is None: + cfg.tensorrt_min_input_shape = cfg.tensorrt_optim_input_shape + if cfg.tensorrt_max_input_shape is None: + cfg.tensorrt_max_input_shape = cfg.tensorrt_optim_input_shape + return cfg + + @staticmethod + def get_tensor_shapes(tensor_shapes: str) -> Dict[str, Tuple[int]]: + """parse tensor shapes string into dict""" + if tensor_shapes is None: + return {} + + input_tensor_shape = {} + shape_list = tensor_shapes.split(';') + + for shapes in shape_list: + name, shape = shapes.split(':') + shape = [int(i) for i in shape.split(',')] + input_tensor_shape[name] = shape + + return input_tensor_shape + + @staticmethod + def import_module(module_name, file_path=None): + """import module functions""" + return importlib.import_module(module_name, package=file_path) + + @staticmethod + def get_input_data_map_from_file(input_data_file): + """get input data map from file""" + return np.load(input_data_file, allow_pickle=True).item() + + @staticmethod + def create_numpy_data_map_out(tensor_infos): + """create numpy data dict""" + np_data_map = {} + for tensor_name, infos in tensor_infos.items(): + if not isinstance(infos, tuple): + raise ValueError('input info shall contain tensor shape and tensor dtype') + shape, dtype = infos + np_dtype = getattr(NumpyDtype, dtype.upper()).value + tensor_data = np.random.rand(*shape).astype(np_dtype) + np_data_map[tensor_name] = tensor_data + + return np_data_map + + @staticmethod + def save_output_as_benchmark_txt(save_dir, + output_tensor): + """save output tensor as benchmark type text""" + for key, value in output_tensor.items(): + save_path = f'{save_dir}_{"".join(key.split("/"))}.txt' + shape = value.shape + shape_str = '' + for val in shape: + shape_str = shape_str.join(f'{val} ') + dim = len(shape) + flags = os.O_WRONLY + mode = stat.S_IWUSR | stat.S_IRUSR + with os.fdopen(os.open(save_path, flags, mode), 'w') as fi: + fi.write(f'{key} {dim} {shape_str}\n') + np.savetxt(fi, value.flatten(), newline=' ') + + @staticmethod + def init_tf_cfg(): + """init tensorflow config""" + cfg = TFConfig() + return cfg + + @staticmethod + def logging_level(level): + if level == 0: + return logging.DEBUG + if level == 1: + return logging.INFO + if level == 2: + return logging.WARNING + return logging.ERROR + + @staticmethod + def parse_dtype_infos(dtype_infos): + """ + parse input dtype infos string to dict, key is input tensor, + value is tensor dtype + params: + model_path: path to model file + args: input arguments + return: model config + """ + infos = dtype_infos.split(';') + ret = {} + for info in infos: + key, dtype = info.split(':') + ret[key] = dtype.strip() + + return ret diff --git a/mindspore/lite/tools/mslite_bench/mslite_bench/graphs/graph_modifier.py b/mindspore/lite/tools/mslite_bench/mslite_bench/graphs/graph_modifier.py index ef2798c2f2b..4ec962f8477 100644 --- a/mindspore/lite/tools/mslite_bench/mslite_bench/graphs/graph_modifier.py +++ b/mindspore/lite/tools/mslite_bench/mslite_bench/graphs/graph_modifier.py @@ -1,49 +1,49 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""abstract class for graph modifier""" - -from abc import ABC, abstractmethod - - -from mslite_bench.utils import InferLogger - -class ABCGraphModifier(ABC): - """abstract class for graph modifier""" - def __init__(self): - self.blocks_sorted = self._sorted_blocks() - self.logger = InferLogger().logger - - @property - def sorted_blocks(self): - """sorted blocks list based on feed-froward network""" - return self.blocks_sorted - - @abstractmethod - def extract_model(self, - save_path, - input_names=None, - output_names=None): - """ extract sub model based on input and output tensor names""" - raise NotImplementedError - - @abstractmethod - def _all_node_names(self): - """return all node names in network""" - raise NotImplementedError - - @abstractmethod - def _sorted_blocks(self): - """get sorted blocks based on feed-foward network""" - raise NotImplementedError +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""abstract class for graph modifier""" + +from abc import ABC, abstractmethod + + +from mslite_bench.utils import InferLogger + +class ABCGraphModifier(ABC): + """abstract class for graph modifier""" + def __init__(self): + self.blocks_sorted = self._sorted_blocks() + self.logger = InferLogger().logger + + @property + def sorted_blocks(self): + """sorted blocks list based on feed-froward network""" + return self.blocks_sorted + + @abstractmethod + def extract_model(self, + save_path, + input_names=None, + output_names=None): + """ extract sub model based on input and output tensor names""" + raise NotImplementedError + + @abstractmethod + def _all_node_names(self): + """return all node names in network""" + raise NotImplementedError + + @abstractmethod + def _sorted_blocks(self): + """get sorted blocks based on feed-foward network""" + raise NotImplementedError diff --git a/mindspore/lite/tools/mslite_bench/mslite_bench/graphs/graph_modifier_factory.py b/mindspore/lite/tools/mslite_bench/mslite_bench/graphs/graph_modifier_factory.py index dbadb18fe5d..a617b0573fa 100644 --- a/mindspore/lite/tools/mslite_bench/mslite_bench/graphs/graph_modifier_factory.py +++ b/mindspore/lite/tools/mslite_bench/mslite_bench/graphs/graph_modifier_factory.py @@ -1,39 +1,39 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""graph modifier factory""" -from mslite_bench.utils.infer_log import InferLogger -from mslite_bench.common.task_common_func import CommonFunc - - -_logger = InferLogger().logger - - -def create_graph_modifier(model_path): - """create graph modifier""" - if model_path.endswith('onnx'): - try: - infer_module = CommonFunc.import_module('mslite_bench.graphs.onnx_graph_modifier') - except ImportError as e: - _logger.info('import tf session failed: %s', e) - raise - return infer_module.OnnxModifier(model_path) - if model_path.endswith('xx'): - try: - infer_module = CommonFunc.import_module('mslite_bench.graphs.tf_graph_modifier') - except ImportError as e: - _logger.info('import tf session failed: %s', e) - raise - return infer_module.TFModifier(model_path) - raise NotImplementedError(f'model type of {model_path} is not supported ') +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""graph modifier factory""" +from mslite_bench.utils.infer_log import InferLogger +from mslite_bench.common.task_common_func import CommonFunc + + +_logger = InferLogger().logger + + +def create_graph_modifier(model_path): + """create graph modifier""" + if model_path.endswith('onnx'): + try: + infer_module = CommonFunc.import_module('mslite_bench.graphs.onnx_graph_modifier') + except ImportError as e: + _logger.info('import tf session failed: %s', e) + raise + return infer_module.OnnxModifier(model_path) + if model_path.endswith('xx'): + try: + infer_module = CommonFunc.import_module('mslite_bench.graphs.tf_graph_modifier') + except ImportError as e: + _logger.info('import tf session failed: %s', e) + raise + return infer_module.TFModifier(model_path) + raise NotImplementedError(f'model type of {model_path} is not supported ') diff --git a/mindspore/lite/tools/mslite_bench/mslite_bench/graphs/onnx_graph_modifier.py b/mindspore/lite/tools/mslite_bench/mslite_bench/graphs/onnx_graph_modifier.py index 0b5ff6b7578..13d15fed4d9 100644 --- a/mindspore/lite/tools/mslite_bench/mslite_bench/graphs/onnx_graph_modifier.py +++ b/mindspore/lite/tools/mslite_bench/mslite_bench/graphs/onnx_graph_modifier.py @@ -1,108 +1,108 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""onnx graph modifier""" -import os -from abc import ABC - -import onnx - -from mslite_bench.graphs.graph_modifier import ABCGraphModifier - - -class OnnxModifier(ABCGraphModifier, ABC): - """ modifier for onnx model""" - def __init__(self, model_path): - super().__init__() - if not os.path.exists(model_path): - raise FileNotFoundError(f'{model_path} does not exist!') - onnx.checker.check_model(model_path) - self.model_path = model_path - self.model = onnx.load(model_path) - self._check_model() - self.model_input_names = self._get_input_names() - self.model_output_names = [model_output.name for model_output in self.model.graph.output] - self.black_node_type_list = { - 'Pad', 'Div', 'Const', 'Shape', 'ConstOfShape', 'Slice', 'Cast', 'Gather', - 'Reshape', 'Unsqueeze', 'Mul', 'RandomNormalLike', 'Exp', 'InstanceNormalization', - 'Where', 'Equal', 'Greater', 'Clip', 'Range', 'IsInf', 'IsNaN', 'Less', 'Loop', - 'Not', 'Or', 'Xor', 'And', 'BitwiseNot', 'BitwiseAnd', 'BitwiseXor', 'BitwiseOr', - 'BitwiseNot', 'BatchNormalization', 'Constant' - } - - def extract_model(self, - save_path, - input_names=None, - output_names=None): - """ extract sub model based on input and output tensor names""" - if input_names is None: - input_names = self.model_input_names - - if output_names is None: - output_names = self.model_output_names - elif len(output_names) == 1 and output_names[0].lower() == 'mslite_bench_all': - output_names = list(set(self._all_node_names()) - set(self.model_input_names)) - else: - valid_node_names = set(self._all_node_names()) - invalid_out_names = [name for name in output_names if name not in valid_node_names] - if invalid_out_names: - output_names = list(set(output_names) - set(invalid_out_names)) - self.logger.warning('Output nodes %s are not supported for ' - 'accuracy compare', invalid_out_names) - if not output_names: - raise ValueError('Shall input valid output names, but it is empty or all invalid') - - if not isinstance(input_names, list) or not isinstance(output_names, list): - raise ValueError("input and output nodes name shall be a list") - - try: - onnx.utils.extract_model(self.model_path, - save_path, - input_names, - output_names, - check_model=False) - except KeyError as e: - self.logger.error('Extract sub model failed, this tensor name is not in graph.value_info: %s', e) - raise - return output_names - - def _all_node_names(self): - """return all node names in network""" - def is_in_black_node_list(node): - return node.op_type in self.black_node_type_list - - output_names = [node.output for node in self.model.graph.node if not is_in_black_node_list(node)] - ret_names = [] - for out_name in output_names: - for name in out_name: - ret_names.append(name) - return ret_names - - def _sorted_blocks(self): - """get sorted blocks based on feed-foward network""" - return [] - - def _get_input_names(self): - """get all input nodes""" - all_input_node_names = {item.name for item in self.model.graph.input} - input_initializer_node_names = {item.name for item in self.model.graph.initializer} - - return list(all_input_node_names - input_initializer_node_names) - - def _check_model(self): - """check model whether has value info""" - graph = onnx.shape_inference.infer_shapes(self.model).graph - if not graph.value_info: - self.logger.error('Model value info is empty, this model is not supported!') - raise ValueError('model value info is empty') +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""onnx graph modifier""" +import os +from abc import ABC + +import onnx + +from mslite_bench.graphs.graph_modifier import ABCGraphModifier + + +class OnnxModifier(ABCGraphModifier, ABC): + """ modifier for onnx model""" + def __init__(self, model_path): + super().__init__() + if not os.path.exists(model_path): + raise FileNotFoundError(f'{model_path} does not exist!') + onnx.checker.check_model(model_path) + self.model_path = model_path + self.model = onnx.load(model_path) + self._check_model() + self.model_input_names = self._get_input_names() + self.model_output_names = [model_output.name for model_output in self.model.graph.output] + self.black_node_type_list = { + 'Pad', 'Div', 'Const', 'Shape', 'ConstOfShape', 'Slice', 'Cast', 'Gather', + 'Reshape', 'Unsqueeze', 'Mul', 'RandomNormalLike', 'Exp', 'InstanceNormalization', + 'Where', 'Equal', 'Greater', 'Clip', 'Range', 'IsInf', 'IsNaN', 'Less', 'Loop', + 'Not', 'Or', 'Xor', 'And', 'BitwiseNot', 'BitwiseAnd', 'BitwiseXor', 'BitwiseOr', + 'BitwiseNot', 'BatchNormalization', 'Constant' + } + + def extract_model(self, + save_path, + input_names=None, + output_names=None): + """ extract sub model based on input and output tensor names""" + if input_names is None: + input_names = self.model_input_names + + if output_names is None: + output_names = self.model_output_names + elif len(output_names) == 1 and output_names[0].lower() == 'mslite_bench_all': + output_names = list(set(self._all_node_names()) - set(self.model_input_names)) + else: + valid_node_names = set(self._all_node_names()) + invalid_out_names = [name for name in output_names if name not in valid_node_names] + if invalid_out_names: + output_names = list(set(output_names) - set(invalid_out_names)) + self.logger.warning('Output nodes %s are not supported for ' + 'accuracy compare', invalid_out_names) + if not output_names: + raise ValueError('Shall input valid output names, but it is empty or all invalid') + + if not isinstance(input_names, list) or not isinstance(output_names, list): + raise ValueError("input and output nodes name shall be a list") + + try: + onnx.utils.extract_model(self.model_path, + save_path, + input_names, + output_names, + check_model=False) + except KeyError as e: + self.logger.error('Extract sub model failed, this tensor name is not in graph.value_info: %s', e) + raise + return output_names + + def _all_node_names(self): + """return all node names in network""" + def is_in_black_node_list(node): + return node.op_type in self.black_node_type_list + + output_names = [node.output for node in self.model.graph.node if not is_in_black_node_list(node)] + ret_names = [] + for out_name in output_names: + for name in out_name: + ret_names.append(name) + return ret_names + + def _sorted_blocks(self): + """get sorted blocks based on feed-foward network""" + return [] + + def _get_input_names(self): + """get all input nodes""" + all_input_node_names = {item.name for item in self.model.graph.input} + input_initializer_node_names = {item.name for item in self.model.graph.initializer} + + return list(all_input_node_names - input_initializer_node_names) + + def _check_model(self): + """check model whether has value info""" + graph = onnx.shape_inference.infer_shapes(self.model).graph + if not graph.value_info: + self.logger.error('Model value info is empty, this model is not supported!') + raise ValueError('model value info is empty') diff --git a/mindspore/lite/tools/mslite_bench/mslite_bench/graphs/tf_graph_modifier.py b/mindspore/lite/tools/mslite_bench/mslite_bench/graphs/tf_graph_modifier.py index 7d28a57f5af..db0d53c755a 100644 --- a/mindspore/lite/tools/mslite_bench/mslite_bench/graphs/tf_graph_modifier.py +++ b/mindspore/lite/tools/mslite_bench/mslite_bench/graphs/tf_graph_modifier.py @@ -1,72 +1,72 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" tensorflow graph modifier""" -import os -from abc import ABC - -import tensorflow as tf - -from mslite_bench.graphs.graph_modifier import ABCGraphModifier - - -class TFModifier(ABCGraphModifier, ABC): - """ modifier for onnx model""" - def __init__(self, model_path): - super().__init__() - if not os.path.exists(model_path): - raise FileNotFoundError(f'{model_path} does not exist!') - self.model_path = model_path - self.graph_def = self._get_tf_graph() - self.model_input_names = self._get_input_names() - - def extract_model(self, - save_path, - input_names=None, - output_names=None): - """ extract sub model based on input and output tensor names""" - if input_names is None: - input_names = self.model_input_names - - if not isinstance(input_names, list) or not isinstance(output_names, list): - raise ValueError("input and output nodes name shall be a list") - - - sub_graph_def = tf.compat.v1.graph_util.extract_sub_graph(self.graph_def, - output_names) - - with tf.io.gfile.GFile(save_path, 'wb') as f: - f.write(sub_graph_def.SerializeToString()) - return output_names - - def _all_node_names(self): - """return all node names in network""" - raise NotImplementedError - - - def _sorted_blocks(self): - """get sorted blocks based on feed-foward network""" - return [] - - def _get_input_names(self): - """get all input nodes""" - return [node.name for node in self.graph_def.node if node.op == 'Placeholder'] - - def _get_tf_graph(self): - """get tensorflow graph""" - graph_def = None - with tf.io.gfile.GFile(self.model_path, 'rb') as f: - graph_def = tf.compat.v1.GraphDef() - graph_def.ParseFromString(f.read()) - return graph_def +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" tensorflow graph modifier""" +import os +from abc import ABC + +import tensorflow as tf + +from mslite_bench.graphs.graph_modifier import ABCGraphModifier + + +class TFModifier(ABCGraphModifier, ABC): + """ modifier for onnx model""" + def __init__(self, model_path): + super().__init__() + if not os.path.exists(model_path): + raise FileNotFoundError(f'{model_path} does not exist!') + self.model_path = model_path + self.graph_def = self._get_tf_graph() + self.model_input_names = self._get_input_names() + + def extract_model(self, + save_path, + input_names=None, + output_names=None): + """ extract sub model based on input and output tensor names""" + if input_names is None: + input_names = self.model_input_names + + if not isinstance(input_names, list) or not isinstance(output_names, list): + raise ValueError("input and output nodes name shall be a list") + + + sub_graph_def = tf.compat.v1.graph_util.extract_sub_graph(self.graph_def, + output_names) + + with tf.io.gfile.GFile(save_path, 'wb') as f: + f.write(sub_graph_def.SerializeToString()) + return output_names + + def _all_node_names(self): + """return all node names in network""" + raise NotImplementedError + + + def _sorted_blocks(self): + """get sorted blocks based on feed-foward network""" + return [] + + def _get_input_names(self): + """get all input nodes""" + return [node.name for node in self.graph_def.node if node.op == 'Placeholder'] + + def _get_tf_graph(self): + """get tensorflow graph""" + graph_def = None + with tf.io.gfile.GFile(self.model_path, 'rb') as f: + graph_def = tf.compat.v1.GraphDef() + graph_def.ParseFromString(f.read()) + return graph_def diff --git a/mindspore/lite/tools/mslite_bench/mslite_bench/infer_base/__init__.py b/mindspore/lite/tools/mslite_bench/mslite_bench/infer_base/__init__.py index a54b7959a9f..52eccdccb84 100644 --- a/mindspore/lite/tools/mslite_bench/mslite_bench/infer_base/__init__.py +++ b/mindspore/lite/tools/mslite_bench/mslite_bench/infer_base/__init__.py @@ -1,18 +1,18 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""infer base""" -from mslite_bench.infer_base.infer_session_factory import InferSessionFactory - -__all__ = ['InferSessionFactory'] +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""infer base""" +from mslite_bench.infer_base.infer_session_factory import InferSessionFactory + +__all__ = ['InferSessionFactory'] diff --git a/mindspore/lite/tools/mslite_bench/mslite_bench/infer_base/abs_infer_session.py b/mindspore/lite/tools/mslite_bench/mslite_bench/infer_base/abs_infer_session.py index 3ac4243b3eb..ed24c14619e 100644 --- a/mindspore/lite/tools/mslite_bench/mslite_bench/infer_base/abs_infer_session.py +++ b/mindspore/lite/tools/mslite_bench/mslite_bench/infer_base/abs_infer_session.py @@ -1,61 +1,61 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""abstract infer session for mslite bench""" -from abc import ABC, abstractmethod -from typing import Dict - -import numpy as np - -from mslite_bench.utils import InferLogger - - -class AbcInferSession(ABC): - """ - abstract infer session - """ - def __init__(self, - model_file: str, - cfg=None): - self.model_file = model_file - self.cfg = cfg - self.input_tensor_shapes = cfg.input_tensor_shapes - self.output_tensor_names = cfg.output_tensor_names - self.batch_size = cfg.batch_size - self.logger = InferLogger(file_path=cfg.log_path).logger - self.data_type_class = None - self.input_tensor_infos = None - - def __call__(self, *args, **kwargs): - return self.infer(*args, **kwargs) - - @property - def input_infos(self): - """property input infos""" - return self.input_tensor_infos - - @property - def dtype_class(self): - """property dtype class""" - return self.data_type_class - - @abstractmethod - def infer(self, input_data_map: Dict[str, np.ndarray]): - """start model infer""" - raise NotImplementedError - - @abstractmethod - def _create_infer_session(self): - """create model infer""" - raise NotImplementedError +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""abstract infer session for mslite bench""" +from abc import ABC, abstractmethod +from typing import Dict + +import numpy as np + +from mslite_bench.utils import InferLogger + + +class AbcInferSession(ABC): + """ + abstract infer session + """ + def __init__(self, + model_file: str, + cfg=None): + self.model_file = model_file + self.cfg = cfg + self.input_tensor_shapes = cfg.input_tensor_shapes + self.output_tensor_names = cfg.output_tensor_names + self.batch_size = cfg.batch_size + self.logger = InferLogger(file_path=cfg.log_path).logger + self.data_type_class = None + self.input_tensor_infos = None + + def __call__(self, *args, **kwargs): + return self.infer(*args, **kwargs) + + @property + def input_infos(self): + """property input infos""" + return self.input_tensor_infos + + @property + def dtype_class(self): + """property dtype class""" + return self.data_type_class + + @abstractmethod + def infer(self, input_data_map: Dict[str, np.ndarray]): + """start model infer""" + raise NotImplementedError + + @abstractmethod + def _create_infer_session(self): + """create model infer""" + raise NotImplementedError diff --git a/mindspore/lite/tools/mslite_bench/mslite_bench/infer_base/infer_session_factory.py b/mindspore/lite/tools/mslite_bench/mslite_bench/infer_base/infer_session_factory.py index 8107a320872..9ade80e1290 100644 --- a/mindspore/lite/tools/mslite_bench/mslite_bench/infer_base/infer_session_factory.py +++ b/mindspore/lite/tools/mslite_bench/mslite_bench/infer_base/infer_session_factory.py @@ -1,103 +1,103 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -infer session factory for unified api -""" -import importlib - -from mslite_bench.common.model_info_enum import FrameworkType -from mslite_bench.utils.infer_log import InferLogger -from mslite_bench.common.task_common_func import CommonFunc - - -_logger = InferLogger().logger - - -class InferSessionFactory: - """ - infer session factory - """ - @classmethod - def create_infer_session_by_args(cls, - args, - logger=None): - """ - params: - args: input arguments - logger: logger for mslite bench - return: model session - """ - if logger is None: - logger = _logger - model_path = args.model_file - param_path = args.params_file - cfg = CommonFunc.get_framework_config(model_path, - args) - - model_session = InferSessionFactory.create_infer_session(model_path, - cfg, - params_file=param_path) - logger.debug('Create model session success') - return model_session - - @classmethod - def create_infer_session(cls, - model_file, - cfg, - params_file=None): - """ - params: - model_file: path to AI model - cfg: framework related config - params_file: path to model weight file, for paddle, caffe etc. - return: model session - """ - infer_framework_type = cfg.infer_framework - if infer_framework_type == FrameworkType.TF.value: - try: - infer_module = cls.import_module('mslite_bench.infer_base.tf_infer_session') - except ImportError as e: - _logger.info('import tf session failed: %s', e) - raise - infer_session = infer_module.TFSession(model_file, cfg) - elif infer_framework_type == FrameworkType.ONNX.value: - try: - infer_module = cls.import_module('mslite_bench.infer_base.onnx_infer_session') - except ImportError as e: - _logger.info('import onnx session failed: %s', e) - raise - infer_session = infer_module.OnnxSession(model_file, cfg) - elif infer_framework_type == FrameworkType.PADDLE.value: - try: - infer_module = cls.import_module('mslite_bench.infer_base.paddle_infer_session') - except ImportError as e: - _logger.info('import paddle session failed: %s', e) - raise - infer_session = infer_module.PaddleSession(model_file, cfg, params_file=params_file) - elif infer_framework_type == FrameworkType.MSLITE.value: - try: - infer_module = cls.import_module('mslite_bench.infer_base.mslite_infer_session') - except ImportError as e: - _logger.info('import paddle session failed: %s', e) - raise - infer_session = infer_module.MsliteSession(model_file, cfg) - else: - raise NotImplementedError(f'{infer_framework_type} is not supported yet') - return infer_session - - @staticmethod - def import_module(module_name, file_path=None): - """import module functions""" - return importlib.import_module(module_name, package=file_path) +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +infer session factory for unified api +""" +import importlib + +from mslite_bench.common.model_info_enum import FrameworkType +from mslite_bench.utils.infer_log import InferLogger +from mslite_bench.common.task_common_func import CommonFunc + + +_logger = InferLogger().logger + + +class InferSessionFactory: + """ + infer session factory + """ + @classmethod + def create_infer_session_by_args(cls, + args, + logger=None): + """ + params: + args: input arguments + logger: logger for mslite bench + return: model session + """ + if logger is None: + logger = _logger + model_path = args.model_file + param_path = args.params_file + cfg = CommonFunc.get_framework_config(model_path, + args) + + model_session = InferSessionFactory.create_infer_session(model_path, + cfg, + params_file=param_path) + logger.debug('Create model session success') + return model_session + + @classmethod + def create_infer_session(cls, + model_file, + cfg, + params_file=None): + """ + params: + model_file: path to AI model + cfg: framework related config + params_file: path to model weight file, for paddle, caffe etc. + return: model session + """ + infer_framework_type = cfg.infer_framework + if infer_framework_type == FrameworkType.TF.value: + try: + infer_module = cls.import_module('mslite_bench.infer_base.tf_infer_session') + except ImportError as e: + _logger.info('import tf session failed: %s', e) + raise + infer_session = infer_module.TFSession(model_file, cfg) + elif infer_framework_type == FrameworkType.ONNX.value: + try: + infer_module = cls.import_module('mslite_bench.infer_base.onnx_infer_session') + except ImportError as e: + _logger.info('import onnx session failed: %s', e) + raise + infer_session = infer_module.OnnxSession(model_file, cfg) + elif infer_framework_type == FrameworkType.PADDLE.value: + try: + infer_module = cls.import_module('mslite_bench.infer_base.paddle_infer_session') + except ImportError as e: + _logger.info('import paddle session failed: %s', e) + raise + infer_session = infer_module.PaddleSession(model_file, cfg, params_file=params_file) + elif infer_framework_type == FrameworkType.MSLITE.value: + try: + infer_module = cls.import_module('mslite_bench.infer_base.mslite_infer_session') + except ImportError as e: + _logger.info('import paddle session failed: %s', e) + raise + infer_session = infer_module.MsliteSession(model_file, cfg) + else: + raise NotImplementedError(f'{infer_framework_type} is not supported yet') + return infer_session + + @staticmethod + def import_module(module_name, file_path=None): + """import module functions""" + return importlib.import_module(module_name, package=file_path) diff --git a/mindspore/lite/tools/mslite_bench/mslite_bench/infer_base/mslite_infer_session.py b/mindspore/lite/tools/mslite_bench/mslite_bench/infer_base/mslite_infer_session.py index 936074994d0..9b5b30e71ec 100644 --- a/mindspore/lite/tools/mslite_bench/mslite_bench/infer_base/mslite_infer_session.py +++ b/mindspore/lite/tools/mslite_bench/mslite_bench/infer_base/mslite_infer_session.py @@ -1,148 +1,148 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -for mslite infer session -""" -from abc import ABC -from typing import Dict - -import mindspore_lite as mslite -from mindspore_lite import DataType -import numpy as np - -from mslite_bench.infer_base.abs_infer_session import AbcInferSession - - -class MsliteSession(AbcInferSession, ABC): - """ - mindspore lite infer session - """ - def __init__(self, - model_file, - cfg=None): - super().__init__(model_file, cfg) - self.thread_num = cfg.thread_num - mslite_model_type = self._set_ms_model_type() - self.model_type = mslite.ModelType(mslite_model_type) - self.device = cfg.device - self.thread_affinity_mode = cfg.thread_affinity_mode - self.context = self._init_context() - self.model_session = self._create_infer_session() - self.model_inputs = self.model_session.get_inputs() - self.dtype_map = { - DataType.BOOL: np.bool_, - DataType.INT8: np.int8, - DataType.INT16: np.int16, - DataType.INT32: np.int32, - DataType.INT64: np.int64, - DataType.UINT8: np.uint8, - DataType.UINT16: np.uint16, - DataType.UINT32: np.uint32, - DataType.UINT64: np.uint64, - DataType.FLOAT16: np.float16, - DataType.FLOAT32: np.float32, - DataType.FLOAT64: np.float64, - } - - def infer(self, input_data_map: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: - """model infer""" - self._check_and_resize_input_tensor(input_data_map) - for model_input in self.model_inputs: - tensor_name = model_input.name.rstrip() - input_data = input_data_map.get(tensor_name, None) - if self.dtype_map[model_input.dtype] != input_data.dtype: - self.logger.warning('Input data type %s is different ' - 'from input tensor dtype %s, would convert' - 'input data type to %s ', - input_data.dtype, - model_input.dtype, - model_input.dtype) - input_data = input_data.astype(self.dtype_map[model_input.dtype]) - model_input.set_data_from_numpy(input_data) - outputs = self.model_session.predict(self.model_inputs) - predict_results = { - tensor.name.rstrip(): tensor.get_data_to_numpy() - for tensor in outputs - } - return predict_results - - def _check_and_resize_input_tensor(self, input_data_map): - """check and resize input tensor""" - is_need_reshape = False - input_shape_list = [] - - for model_input in self.model_inputs: - tensor_name = model_input.name.rstrip() - input_data = input_data_map.get(tensor_name, None) - if input_data is None: - raise ValueError(f'{tensor_name} is not in model inputs') - if model_input.shape != list(input_data.shape): - self.logger.warning('model input shape: %s is not equal' - 'with input data shape: %s, model input shape' - 'would be reshaped', model_input.shape, input_data.shape) - is_need_reshape = True - input_shape_list.append(list(input_data.shape)) - - if is_need_reshape: - self.model_session.resize(self.model_inputs, input_shape_list) - self.model_inputs = self.model_session.get_inputs() - - def _create_infer_session(self): - """create mslite infer session""" - model_session = mslite.Model() - model_session.build_from_file(self.model_file, - self.model_type, - self.context) - return model_session - - def _get_input_tensor_infos(self): - """get infos about input tensors""" - input_tensor_infos = {} - tensor_shape_list = [] - resize_tensor_list = [] - for input_tensor in self.model_inputs: - tensor_name = input_tensor.name.rstrip() - dtype = input_tensor.dtype - shape = input_tensor.shape - if -1 in shape or not shape: - resize_tensor_list.append(input_tensor) - shape = self.input_tensor_shapes.get(tensor_name, None) - tensor_shape_list.append(list(shape)) - input_tensor_infos[tensor_name] = (shape, dtype) - - if not resize_tensor_list: - self.model_session.resize(resize_tensor_list, tensor_shape_list) - self.model_inputs = self.model_session.get_inputs() - - return input_tensor_infos - - def _init_context(self): - """init mslite context""" - context = mslite.Context() - context.target = [self.device] - if self.device == 'ascend': - context.ascend.device_id = 0 - context.provider = self.cfg.ascend_provider - context.cpu.thread_num = self.thread_num - context.cpu.thread_affinity_mode = self.thread_affinity_mode - return context - - def _set_ms_model_type(self): - """set mslite model type""" - if self.model_file.endswith('ms'): - mslite_model_type = 4 - else: - mslite_model_type = 0 - return mslite_model_type +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +for mslite infer session +""" +from abc import ABC +from typing import Dict + +import mindspore_lite as mslite +from mindspore_lite import DataType +import numpy as np + +from mslite_bench.infer_base.abs_infer_session import AbcInferSession + + +class MsliteSession(AbcInferSession, ABC): + """ + mindspore lite infer session + """ + def __init__(self, + model_file, + cfg=None): + super().__init__(model_file, cfg) + self.thread_num = cfg.thread_num + mslite_model_type = self._set_ms_model_type() + self.model_type = mslite.ModelType(mslite_model_type) + self.device = cfg.device + self.thread_affinity_mode = cfg.thread_affinity_mode + self.context = self._init_context() + self.model_session = self._create_infer_session() + self.model_inputs = self.model_session.get_inputs() + self.dtype_map = { + DataType.BOOL: np.bool_, + DataType.INT8: np.int8, + DataType.INT16: np.int16, + DataType.INT32: np.int32, + DataType.INT64: np.int64, + DataType.UINT8: np.uint8, + DataType.UINT16: np.uint16, + DataType.UINT32: np.uint32, + DataType.UINT64: np.uint64, + DataType.FLOAT16: np.float16, + DataType.FLOAT32: np.float32, + DataType.FLOAT64: np.float64, + } + + def infer(self, input_data_map: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + """model infer""" + self._check_and_resize_input_tensor(input_data_map) + for model_input in self.model_inputs: + tensor_name = model_input.name.rstrip() + input_data = input_data_map.get(tensor_name, None) + if self.dtype_map[model_input.dtype] != input_data.dtype: + self.logger.warning('Input data type %s is different ' + 'from input tensor dtype %s, would convert' + 'input data type to %s ', + input_data.dtype, + model_input.dtype, + model_input.dtype) + input_data = input_data.astype(self.dtype_map[model_input.dtype]) + model_input.set_data_from_numpy(input_data) + outputs = self.model_session.predict(self.model_inputs) + predict_results = { + tensor.name.rstrip(): tensor.get_data_to_numpy() + for tensor in outputs + } + return predict_results + + def _check_and_resize_input_tensor(self, input_data_map): + """check and resize input tensor""" + is_need_reshape = False + input_shape_list = [] + + for model_input in self.model_inputs: + tensor_name = model_input.name.rstrip() + input_data = input_data_map.get(tensor_name, None) + if input_data is None: + raise ValueError(f'{tensor_name} is not in model inputs') + if model_input.shape != list(input_data.shape): + self.logger.warning('model input shape: %s is not equal' + 'with input data shape: %s, model input shape' + 'would be reshaped', model_input.shape, input_data.shape) + is_need_reshape = True + input_shape_list.append(list(input_data.shape)) + + if is_need_reshape: + self.model_session.resize(self.model_inputs, input_shape_list) + self.model_inputs = self.model_session.get_inputs() + + def _create_infer_session(self): + """create mslite infer session""" + model_session = mslite.Model() + model_session.build_from_file(self.model_file, + self.model_type, + self.context) + return model_session + + def _get_input_tensor_infos(self): + """get infos about input tensors""" + input_tensor_infos = {} + tensor_shape_list = [] + resize_tensor_list = [] + for input_tensor in self.model_inputs: + tensor_name = input_tensor.name.rstrip() + dtype = input_tensor.dtype + shape = input_tensor.shape + if -1 in shape or not shape: + resize_tensor_list.append(input_tensor) + shape = self.input_tensor_shapes.get(tensor_name, None) + tensor_shape_list.append(list(shape)) + input_tensor_infos[tensor_name] = (shape, dtype) + + if not resize_tensor_list: + self.model_session.resize(resize_tensor_list, tensor_shape_list) + self.model_inputs = self.model_session.get_inputs() + + return input_tensor_infos + + def _init_context(self): + """init mslite context""" + context = mslite.Context() + context.target = [self.device] + if self.device == 'ascend': + context.ascend.device_id = 0 + context.provider = self.cfg.ascend_provider + context.cpu.thread_num = self.thread_num + context.cpu.thread_affinity_mode = self.thread_affinity_mode + return context + + def _set_ms_model_type(self): + """set mslite model type""" + if self.model_file.endswith('ms'): + mslite_model_type = 4 + else: + mslite_model_type = 0 + return mslite_model_type diff --git a/mindspore/lite/tools/mslite_bench/mslite_bench/infer_base/onnx_infer_session.py b/mindspore/lite/tools/mslite_bench/mslite_bench/infer_base/onnx_infer_session.py index 73ce45eb97a..a8d0a1bc153 100644 --- a/mindspore/lite/tools/mslite_bench/mslite_bench/infer_base/onnx_infer_session.py +++ b/mindspore/lite/tools/mslite_bench/mslite_bench/infer_base/onnx_infer_session.py @@ -1,70 +1,70 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -for onnx infer session -""" -from abc import ABC -from typing import Dict - -import onnx -import onnxruntime -import numpy as np - -from mslite_bench.infer_base.abs_infer_session import AbcInferSession - - -class OnnxSession(AbcInferSession, ABC): - """onnx infer session""" - def __init__(self, - model_file, - cfg=None): - super().__init__(model_file, cfg) - self.model = onnx.load(model_file) - self.output_nodes = self._get_all_output_nodes() - self.output_tensor_names = self._get_output_tensor_names() - self.model_session = self._create_infer_session() - - def infer(self, input_data_map: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: - """onnx infer""" - outputs = self.model_session.run(self.output_tensor_names, - input_data_map) - result = {} - for key, value in zip(self.output_tensor_names, outputs): - result[key] = value - return result - - def _create_infer_session(self): - """create infer session""" - model_session = onnxruntime.InferenceSession(self.model_file, - providers=['CPUExecutionProvider']) - self.logger.debug('onnx Session create successfully') - return model_session - - def _get_all_input_nodes(self): - """get all input nodes""" - all_input_nodes = self.model.graph.input - input_initializer_nodes = self.model.graph.initializer - - return list(set(all_input_nodes) - set(input_initializer_nodes)) - - def _get_all_output_nodes(self): - """get all output nodes""" - return self.model.graph.output - - def _get_output_tensor_names(self): - """get output tensor names""" - if self.output_tensor_names is None: - self.output_tensor_names = [node.name for node in self.output_nodes] - return self.output_tensor_names +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +for onnx infer session +""" +from abc import ABC +from typing import Dict + +import onnx +import onnxruntime +import numpy as np + +from mslite_bench.infer_base.abs_infer_session import AbcInferSession + + +class OnnxSession(AbcInferSession, ABC): + """onnx infer session""" + def __init__(self, + model_file, + cfg=None): + super().__init__(model_file, cfg) + self.model = onnx.load(model_file) + self.output_nodes = self._get_all_output_nodes() + self.output_tensor_names = self._get_output_tensor_names() + self.model_session = self._create_infer_session() + + def infer(self, input_data_map: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + """onnx infer""" + outputs = self.model_session.run(self.output_tensor_names, + input_data_map) + result = {} + for key, value in zip(self.output_tensor_names, outputs): + result[key] = value + return result + + def _create_infer_session(self): + """create infer session""" + model_session = onnxruntime.InferenceSession(self.model_file, + providers=['CPUExecutionProvider']) + self.logger.debug('onnx Session create successfully') + return model_session + + def _get_all_input_nodes(self): + """get all input nodes""" + all_input_nodes = self.model.graph.input + input_initializer_nodes = self.model.graph.initializer + + return list(set(all_input_nodes) - set(input_initializer_nodes)) + + def _get_all_output_nodes(self): + """get all output nodes""" + return self.model.graph.output + + def _get_output_tensor_names(self): + """get output tensor names""" + if self.output_tensor_names is None: + self.output_tensor_names = [node.name for node in self.output_nodes] + return self.output_tensor_names diff --git a/mindspore/lite/tools/mslite_bench/mslite_bench/infer_base/paddle_infer_session.py b/mindspore/lite/tools/mslite_bench/mslite_bench/infer_base/paddle_infer_session.py index dafc08faf19..788b6d2777b 100644 --- a/mindspore/lite/tools/mslite_bench/mslite_bench/infer_base/paddle_infer_session.py +++ b/mindspore/lite/tools/mslite_bench/mslite_bench/infer_base/paddle_infer_session.py @@ -1,100 +1,100 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -for paddle infer session -""" -from abc import ABC -from typing import Dict - -import paddle.inference as paddle_infer -import numpy as np - -from mslite_bench.infer_base.abs_infer_session import AbcInferSession -from mslite_bench.common.model_info_enum import ( - DeviceType, -) - - -class PaddleSession(AbcInferSession, ABC): - """paddle infer session""" - def __init__(self, - model_file, - cfg, - params_file=None): - super().__init__(model_file, cfg) - self.place = None - self.param_file = params_file - self.model_session = self._create_infer_session() - self.input_names = self.model_session.get_input_names() - self.output_names = self.model_session.get_output_names() - self.model_inputs = self._get_input_tensor() - - def infer(self, input_data_map: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: - """model infer""" - for name, input_tensor in self.model_inputs.items(): - input_data = input_data_map.get(name, None) - input_tensor.copy_from_cpu(input_data) - self.model_session.run() - predict_results = { - name: self.model_session.get_output_handle(name) - for name in self.output_names - } - return predict_results - - def destroy(self): - """model destroy""" - self.model_session.clear_intermediate_tensor() - self.model_session.try_shrink_memory() - - def _get_input_tensor(self): - """get input tensor""" - input_tensor_map = {} - for name in self.input_names: - input_tensor_map[name] = self.model_session.get_input_handle(name) - - return input_tensor_map - - def _create_infer_session(self): - """create infer session""" - config = paddle_infer.Config(self.model_file, - self.param_file) - if self.cfg.device == DeviceType.CPU.value: - config.set_cpu_math_library_num_threads(self.cfg.thread_num) - elif self.cfg.device == DeviceType.GPU.value: - config.enable_use_gpu(self.cfg.gpu_memory_size, - self.cfg.device_id) - if self.cfg.is_enable_tensorrt: - precision_type = paddle_infer.PrecisionType.Float32 - if self.cfg.is_fp16: - precision_type = paddle_infer.PrecisionType.Half - elif self.cfg.is_int8: - precision_type = paddle_infer.PrecisionType.Int8 - config.set_trt_dynamic_shape_info( - optim_input_shape=self.cfg.tensorrt_optim_input_shape, - min_input_shape=self.cfg.tensorrt_min_input_shape, - max_input_shape=self.cfg.tensorrt_max_input_shape - ) - config.enable_tensorrt_engine(workspace_size=1 << 28, - max_batch_size=self.cfg.batch_size, - min_subgraph_size=1, - precision_mode=precision_type, - use_static=False, - use_calib_mode=True) - - else: - raise ValueError(f'paddle do not work on device type {self.cfg.device}') - - model_session = paddle_infer.create_predictor(config) - return model_session +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +for paddle infer session +""" +from abc import ABC +from typing import Dict + +import paddle.inference as paddle_infer +import numpy as np + +from mslite_bench.infer_base.abs_infer_session import AbcInferSession +from mslite_bench.common.model_info_enum import ( + DeviceType, +) + + +class PaddleSession(AbcInferSession, ABC): + """paddle infer session""" + def __init__(self, + model_file, + cfg, + params_file=None): + super().__init__(model_file, cfg) + self.place = None + self.param_file = params_file + self.model_session = self._create_infer_session() + self.input_names = self.model_session.get_input_names() + self.output_names = self.model_session.get_output_names() + self.model_inputs = self._get_input_tensor() + + def infer(self, input_data_map: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + """model infer""" + for name, input_tensor in self.model_inputs.items(): + input_data = input_data_map.get(name, None) + input_tensor.copy_from_cpu(input_data) + self.model_session.run() + predict_results = { + name: self.model_session.get_output_handle(name) + for name in self.output_names + } + return predict_results + + def destroy(self): + """model destroy""" + self.model_session.clear_intermediate_tensor() + self.model_session.try_shrink_memory() + + def _get_input_tensor(self): + """get input tensor""" + input_tensor_map = {} + for name in self.input_names: + input_tensor_map[name] = self.model_session.get_input_handle(name) + + return input_tensor_map + + def _create_infer_session(self): + """create infer session""" + config = paddle_infer.Config(self.model_file, + self.param_file) + if self.cfg.device == DeviceType.CPU.value: + config.set_cpu_math_library_num_threads(self.cfg.thread_num) + elif self.cfg.device == DeviceType.GPU.value: + config.enable_use_gpu(self.cfg.gpu_memory_size, + self.cfg.device_id) + if self.cfg.is_enable_tensorrt: + precision_type = paddle_infer.PrecisionType.Float32 + if self.cfg.is_fp16: + precision_type = paddle_infer.PrecisionType.Half + elif self.cfg.is_int8: + precision_type = paddle_infer.PrecisionType.Int8 + config.set_trt_dynamic_shape_info( + optim_input_shape=self.cfg.tensorrt_optim_input_shape, + min_input_shape=self.cfg.tensorrt_min_input_shape, + max_input_shape=self.cfg.tensorrt_max_input_shape + ) + config.enable_tensorrt_engine(workspace_size=1 << 28, + max_batch_size=self.cfg.batch_size, + min_subgraph_size=1, + precision_mode=precision_type, + use_static=False, + use_calib_mode=True) + + else: + raise ValueError(f'paddle do not work on device type {self.cfg.device}') + + model_session = paddle_infer.create_predictor(config) + return model_session diff --git a/mindspore/lite/tools/mslite_bench/mslite_bench/infer_base/tf_infer_session.py b/mindspore/lite/tools/mslite_bench/mslite_bench/infer_base/tf_infer_session.py index e6108e55a1a..79f2efe11ef 100644 --- a/mindspore/lite/tools/mslite_bench/mslite_bench/infer_base/tf_infer_session.py +++ b/mindspore/lite/tools/mslite_bench/mslite_bench/infer_base/tf_infer_session.py @@ -1,84 +1,84 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -for paddle infer session -""" -import os -from abc import ABC -from typing import Dict - -import tensorflow as tf -import numpy as np - -from mslite_bench.infer_base.abs_infer_session import AbcInferSession - - -class TFSession(AbcInferSession, ABC): - """TF infer session""" - def __init__(self, - model_file, - cfg=None): - super().__init__(model_file, cfg) - self.graph = None - self.model_session = self._create_infer_session() - - self.input_tensor_map = { - tensor_name: self.graph.get_tensor_by_name(tensor_name + ': 0') for - tensor_name in self.input_tensor_shapes.keys() - } - - self.output_tensor_map = { - tensor_name: self.graph.get_tensor_by_name(tensor_name + ': 0') for - tensor_name in self.output_tensor_names - } - - def infer(self, input_data_map: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: - """model infer""" - results = { - key: self.model_session.run(output_tensor, - feed_dict={ - self.input_tensor_map.get(name): input_data_map.get(name) - for name in self.input_tensor_shapes.keys() - }) - for key, output_tensor in self.output_tensor_map.items() - } - - return results - - def _create_infer_session(self): - """create infer session""" - if not os.path.exists(self.model_file): - raise ValueError(f'TF model {self.model_file} does not exist') - with tf.io.gfile.GFile(self.model_file, 'rb') as f: - graph_def = tf.compat.v1.GraphDef() - graph_def.ParseFromString(f.read()) - input_tensor_map = self._get_tf_input_tensor_map(graph_def) - tf.import_graph_def(graph_def, input_map=input_tensor_map, name='') - self.logger.debug('Tensor map done') - self.graph = tf.compat.v1.get_default_graph() - model_session = tf.compat.v1.Session(graph=self.graph) - return model_session - - def _get_tf_input_tensor_map(self, graph_def): - """get tensorflow input tensor map""" - input_tensor_map = {} - tf.import_graph_def(graph_def, name='') - default_graph = tf.compat.v1.get_default_graph() - for key, shape in self.input_tensor_shapes.items(): - tensor_name = f'{key}:0' - input_tensor = default_graph.get_tensor_by_name(tensor_name) - input_tensor.set_shape(shape) - input_tensor_map[key] = input_tensor - return input_tensor_map +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +for paddle infer session +""" +import os +from abc import ABC +from typing import Dict + +import tensorflow as tf +import numpy as np + +from mslite_bench.infer_base.abs_infer_session import AbcInferSession + + +class TFSession(AbcInferSession, ABC): + """TF infer session""" + def __init__(self, + model_file, + cfg=None): + super().__init__(model_file, cfg) + self.graph = None + self.model_session = self._create_infer_session() + + self.input_tensor_map = { + tensor_name: self.graph.get_tensor_by_name(tensor_name + ': 0') for + tensor_name in self.input_tensor_shapes.keys() + } + + self.output_tensor_map = { + tensor_name: self.graph.get_tensor_by_name(tensor_name + ': 0') for + tensor_name in self.output_tensor_names + } + + def infer(self, input_data_map: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + """model infer""" + results = { + key: self.model_session.run(output_tensor, + feed_dict={ + self.input_tensor_map.get(name): input_data_map.get(name) + for name in self.input_tensor_shapes.keys() + }) + for key, output_tensor in self.output_tensor_map.items() + } + + return results + + def _create_infer_session(self): + """create infer session""" + if not os.path.exists(self.model_file): + raise ValueError(f'TF model {self.model_file} does not exist') + with tf.io.gfile.GFile(self.model_file, 'rb') as f: + graph_def = tf.compat.v1.GraphDef() + graph_def.ParseFromString(f.read()) + input_tensor_map = self._get_tf_input_tensor_map(graph_def) + tf.import_graph_def(graph_def, input_map=input_tensor_map, name='') + self.logger.debug('Tensor map done') + self.graph = tf.compat.v1.get_default_graph() + model_session = tf.compat.v1.Session(graph=self.graph) + return model_session + + def _get_tf_input_tensor_map(self, graph_def): + """get tensorflow input tensor map""" + input_tensor_map = {} + tf.import_graph_def(graph_def, name='') + default_graph = tf.compat.v1.get_default_graph() + for key, shape in self.input_tensor_shapes.items(): + tensor_name = f'{key}:0' + input_tensor = default_graph.get_tensor_by_name(tensor_name) + input_tensor.set_shape(shape) + input_tensor_map[key] = input_tensor + return input_tensor_map diff --git a/mindspore/lite/tools/mslite_bench/mslite_bench/tools/converter.py b/mindspore/lite/tools/mslite_bench/mslite_bench/tools/converter.py index bc014a7fd3c..a56a1b799e0 100644 --- a/mindspore/lite/tools/mslite_bench/mslite_bench/tools/converter.py +++ b/mindspore/lite/tools/mslite_bench/mslite_bench/tools/converter.py @@ -1,163 +1,163 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" converter for mslite """ - -import copy -from enum import Enum -import os -from dataclasses import dataclass - -import mindspore_lite as mslite - -from mslite_bench.utils import InferLogger -from mslite_bench.common.model_info_enum import DeviceType -from mslite_bench.common.task_common_func import CommonFunc -from mslite_bench.tools.cross_framework_accuracy import CrossFrameworkAccSummary - - -class InputModelType(Enum): - """ enum model type class for model to be convterted""" - PB = mslite.FmkType.TF - CAFFE = mslite.FmkType.CAFFE - ONNX = mslite.FmkType.ONNX - MINDIR = mslite.FmkType.MINDIR - TFLITE = mslite.FmkType.TFLITE - PTH = mslite.FmkType.PYTORCH - - -class MsliteModelType(Enum): - """ enum model type in mslite""" - MINDIR = mslite.ModelType.MINDIR - MINDIR_LITE = mslite.ModelType.MINDIR_LITE - - -class MsliteDataType(Enum): - """ enum data type for quant """ - FLOAT32 = mslite.DataType.FLOAT32 - INT8 = mslite.DataType.INT8 - UINT8 = mslite.DataType.UINT8 - UNKNOWN = mslite.DataType.UNKNOWN - - -class MsliteTensorFormat(Enum): - """ enum input data format""" - NCHW = mslite.Format.NCHW - NHWC = mslite.Format.NHWC - - -@dataclass -class ConverterParams: - """data class for converter input params""" - input_shape: str = None - input_data_type: str = "FLOAT32" - output_data_type: str = "FLOAT32" - input_format: str = "NCHW" - weight_fp16: bool = False - save_type: str = "MINDIR" - decrypt_key: str = None - decrypt_mode: str = None - enable_encryption: bool = False - encrypt_key: str = None - infer: bool = False - optimize: str = None - device: str = DeviceType.CPU.value - converter_output_file: str = None - model_file: str = None - params_file: str = None - - -class MsliteConverter: - """model converter for mindspore lite""" - @classmethod - def convert(cls, - args, - logger=None, - is_delete_ms_model=False): - """convert third party model to mslite model""" - if logger is None: - logger = InferLogger(args.log_path).logger - - logger.debug('Start to convert model') - cls.model_convert(args, logger) - - if not args.converter_is_analysis: - return - - if cls.enum_value(args.converter_save_type, MsliteModelType) == MsliteModelType.MINDIR.value: - output_file_name = f'{args.converter_output_file}.mindir' - else: - output_file_name = f'{args.converter_output_file}.ms' - - if os.path.exists(output_file_name): - args_copy = copy.deepcopy(args) - args_copy.cmp_model_file = args.model_file - args_copy.model_file = output_file_name - logger.info('Start accuracy compare procedure') - CrossFrameworkAccSummary.accuracy_compare_func(args_copy, logger) - - if is_delete_ms_model: - os.remove(output_file_name) - - @classmethod - def model_convert(cls, args, logger=None): - """model convert for mslite""" - if logger is None: - logger = InferLogger(args.log_path).logger - converter = cls._init_converter(args) - try: - extension = os.path.splitext(args.model_file)[1][1:].upper() - model_type = cls.enum_value(extension, InputModelType) - except NotImplementedError as e: - logger.error('Input model type error: %s', e) - return - - converter.convert(model_type, - args.model_file, - args.converter_output_file, - weight_file=args.params_file, - config_file=args.converter_config_file) - - - @classmethod - def _init_converter(cls, args): - """init mslite model converter with args""" - converter = mslite.Converter() - - converter.input_shape = CommonFunc.get_tensor_shapes(args.converter_input_shape) - converter.input_data_type = cls.enum_value(args.quant_input_data_type.upper(), - MsliteDataType) - converter.output_data_type = cls.enum_value(args.quant_output_data_type.upper(), - MsliteDataType) - converter.input_format = cls.enum_value(args.converter_input_format, - MsliteTensorFormat) - converter.weight_fp16 = args.converter_weight_fp16 - converter.save_type = cls.enum_value(args.converter_save_type, MsliteModelType) - converter.decrypt_key = args.converter_decrypt_key - converter.decrypt_mode = args.converter_decrypt_mode - converter.enable_encryption = args.converter_enable_encryption - converter.encrypt_key = args.converter_encrypt_key - converter.infer = False - converter.optimize = args.converter_optimize - if args.device.lower() == DeviceType.ASCEND.value: - converter.device = "Ascend" - - return converter - - @staticmethod - def enum_value(enum_key, enum_class): - """get value from key in enum class""" - if enum_key in enum_class.__members__: - return enum_class[enum_key].value - raise NotImplementedError(f"{enum_key} is not in class {enum_class}") +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" converter for mslite """ + +import copy +from enum import Enum +import os +from dataclasses import dataclass + +import mindspore_lite as mslite + +from mslite_bench.utils import InferLogger +from mslite_bench.common.model_info_enum import DeviceType +from mslite_bench.common.task_common_func import CommonFunc +from mslite_bench.tools.cross_framework_accuracy import CrossFrameworkAccSummary + + +class InputModelType(Enum): + """ enum model type class for model to be convterted""" + PB = mslite.FmkType.TF + CAFFE = mslite.FmkType.CAFFE + ONNX = mslite.FmkType.ONNX + MINDIR = mslite.FmkType.MINDIR + TFLITE = mslite.FmkType.TFLITE + PTH = mslite.FmkType.PYTORCH + + +class MsliteModelType(Enum): + """ enum model type in mslite""" + MINDIR = mslite.ModelType.MINDIR + MINDIR_LITE = mslite.ModelType.MINDIR_LITE + + +class MsliteDataType(Enum): + """ enum data type for quant """ + FLOAT32 = mslite.DataType.FLOAT32 + INT8 = mslite.DataType.INT8 + UINT8 = mslite.DataType.UINT8 + UNKNOWN = mslite.DataType.UNKNOWN + + +class MsliteTensorFormat(Enum): + """ enum input data format""" + NCHW = mslite.Format.NCHW + NHWC = mslite.Format.NHWC + + +@dataclass +class ConverterParams: + """data class for converter input params""" + input_shape: str = None + input_data_type: str = "FLOAT32" + output_data_type: str = "FLOAT32" + input_format: str = "NCHW" + weight_fp16: bool = False + save_type: str = "MINDIR" + decrypt_key: str = None + decrypt_mode: str = None + enable_encryption: bool = False + encrypt_key: str = None + infer: bool = False + optimize: str = None + device: str = DeviceType.CPU.value + converter_output_file: str = None + model_file: str = None + params_file: str = None + + +class MsliteConverter: + """model converter for mindspore lite""" + @classmethod + def convert(cls, + args, + logger=None, + is_delete_ms_model=False): + """convert third party model to mslite model""" + if logger is None: + logger = InferLogger(args.log_path).logger + + logger.debug('Start to convert model') + cls.model_convert(args, logger) + + if not args.converter_is_analysis: + return + + if cls.enum_value(args.converter_save_type, MsliteModelType) == MsliteModelType.MINDIR.value: + output_file_name = f'{args.converter_output_file}.mindir' + else: + output_file_name = f'{args.converter_output_file}.ms' + + if os.path.exists(output_file_name): + args_copy = copy.deepcopy(args) + args_copy.cmp_model_file = args.model_file + args_copy.model_file = output_file_name + logger.info('Start accuracy compare procedure') + CrossFrameworkAccSummary.accuracy_compare_func(args_copy, logger) + + if is_delete_ms_model: + os.remove(output_file_name) + + @classmethod + def model_convert(cls, args, logger=None): + """model convert for mslite""" + if logger is None: + logger = InferLogger(args.log_path).logger + converter = cls._init_converter(args) + try: + extension = os.path.splitext(args.model_file)[1][1:].upper() + model_type = cls.enum_value(extension, InputModelType) + except NotImplementedError as e: + logger.error('Input model type error: %s', e) + return + + converter.convert(model_type, + args.model_file, + args.converter_output_file, + weight_file=args.params_file, + config_file=args.converter_config_file) + + + @classmethod + def _init_converter(cls, args): + """init mslite model converter with args""" + converter = mslite.Converter() + + converter.input_shape = CommonFunc.get_tensor_shapes(args.converter_input_shape) + converter.input_data_type = cls.enum_value(args.quant_input_data_type.upper(), + MsliteDataType) + converter.output_data_type = cls.enum_value(args.quant_output_data_type.upper(), + MsliteDataType) + converter.input_format = cls.enum_value(args.converter_input_format, + MsliteTensorFormat) + converter.weight_fp16 = args.converter_weight_fp16 + converter.save_type = cls.enum_value(args.converter_save_type, MsliteModelType) + converter.decrypt_key = args.converter_decrypt_key + converter.decrypt_mode = args.converter_decrypt_mode + converter.enable_encryption = args.converter_enable_encryption + converter.encrypt_key = args.converter_encrypt_key + converter.infer = False + converter.optimize = args.converter_optimize + if args.device.lower() == DeviceType.ASCEND.value: + converter.device = "Ascend" + + return converter + + @staticmethod + def enum_value(enum_key, enum_class): + """get value from key in enum class""" + if enum_key in enum_class.__members__: + return enum_class[enum_key].value + raise NotImplementedError(f"{enum_key} is not in class {enum_class}") diff --git a/mindspore/lite/tools/mslite_bench/mslite_bench/tools/cross_framework_accuracy.py b/mindspore/lite/tools/mslite_bench/mslite_bench/tools/cross_framework_accuracy.py index b555c2ca84c..921b0626fef 100644 --- a/mindspore/lite/tools/mslite_bench/mslite_bench/tools/cross_framework_accuracy.py +++ b/mindspore/lite/tools/mslite_bench/mslite_bench/tools/cross_framework_accuracy.py @@ -1,438 +1,438 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -functions for cross framework model infer result accuracy compare and summary -""" -import os -import os.path -import stat -import csv - -import numpy as np - -from mslite_bench.infer_base.infer_session_factory import InferSessionFactory -from mslite_bench.utils.infer_log import InferLogger -from mslite_bench.common.task_common_func import CommonFunc -from mslite_bench.common.enum_class import ( - NumpyDtype -) -from mslite_bench.common.model_info_enum import ErrorAlgType - -_logger = InferLogger().logger - - -class CrossFrameworkAccSummary: - """ - functions for cross framework model infer result accuracy compare and summary - """ - @classmethod - def acc_infos_between_features(cls, - standard_feature, - compare_feature): - """ - get accuracy info between features, including mean error ratio - and cosine similarity. - params: - standard_feature(Dict[str, nunmpy.ndarray): standard features to be compared - compare_feature(Dict[str, nunmpy.ndarray): compare features to compare - return: - A dict, including mean_error_ratio and cosine similarity. - """ - - mean_relative_error = cls.get_mean_relative_error_between_features(compare_feature, - standard_feature) - - cosine_similarity = cls.get_cosine_distance_between_features(compare_feature, - standard_feature) - return { - ErrorAlgType.MEAN_RELATIVE_ERROR.value: mean_relative_error, - ErrorAlgType.COSINE_SIMILARITY.value: cosine_similarity - } - - @classmethod - def accuracy_compare_func(cls, - args, - logger=None): - """ - get outputs accuracy compare info between two different framework using same model. - params: - args: input arguments - logger: logger to recorder logs - return: - A dict, including mean_error_ratio and cosine similarity. - """ - cmp_result = None - - src_file_path = args.model_file - if not src_file_path.endswith('ms') and \ - not src_file_path.endswith('mindir'): - raise ValueError(f'{src_file_path} is not a valid mslite model') - - dst_file_path = args.cmp_model_file - - input_data_map = CommonFunc.create_numpy_data_map(args) - ms_config = CommonFunc.get_framework_config(src_file_path, - args) - try: - ms_session = InferSessionFactory.create_infer_session(src_file_path, - ms_config) - except ValueError as e: - logger.error('[Accuracy Compare] Create ms session failed: %s', e) - return cmp_result - - args.device = args.cmp_device - cmp_cfg = CommonFunc.get_framework_config(dst_file_path, - args) - try: - cmp_session = InferSessionFactory.create_infer_session(dst_file_path, - cmp_cfg, - args.params_file) - except ValueError as e: - logger.error(f'Create dst session failed %s', e) - return cmp_result - - try: - cmp_result = cls.real_accuracy_compare(ms_session, - cmp_session, - input_data_map) - except (NotImplementedError, ValueError) as e: - logger.error(f'Accuracy test failed, get accuracy failed %s', e) - raise - cmp_result = cls.is_acc_ok(cmp_result) - for key, val in cmp_result.items(): - logger.debug(f'{key}: {val}') - - if not args.cmp_result_file: - csv_path = os.path.join(os.path.dirname(src_file_path), 'accuracy_infos.csv') - else: - csv_path = f'{args.cmp_result_file}.csv' - csv_dir = os.path.dirname(csv_path) - os.makedirs(csv_dir, exist_ok=True) - logger.info(f'Accuracy compare done, save accuracy info in %s', csv_path) - cls.write_csv(cmp_result, csv_path) - return cmp_result - - @classmethod - def real_accuracy_compare(cls, - src_session, - dst_session, - input_tensor_map): - """ - get accuracy compare info between two different sessions with same input tensor map. - params: - src_session: session to be compared - dst_session: session to compare - input_tensor_map: tensor name and value dict for session input. - return: - A dict, including mean_error_ratio and cosine similarity. - """ - src_output = src_session(input_tensor_map) - dst_output = dst_session(input_tensor_map) - - result = { - ErrorAlgType.MEAN_RELATIVE_ERROR.value: cls.get_mean_relative_error_between_features(dst_output, - src_output), - ErrorAlgType.COSINE_SIMILARITY.value: cls.get_cosine_distance_between_features(dst_output, - src_output) - } - return result - - @classmethod - def specific_accuracy_compare(cls, - src_session, - dst_session, - args): - """ - get accuracy compare info between two different sessions with - specific input loading from files. - params: - src_session: session to be compared - dst_session: session to compare - args: input arguments. - return: - A dict, including mean_error_ratio and cosine similarity. - """ - input_tensor_map = np.load(args.input_data_file, - allow_pickle=True).item() - result = cls.real_accuracy_compare(src_session, - dst_session, - input_tensor_map) - return result - - @classmethod - def random_accuracy_compare(cls, - src_session, - dst_session, - args): - """ - get accuracy compare info between two different sessions with random inputs. - params: - src_session: session to be compared - dst_session: session to compare - args: input arguments. - return: - A dict, including mean_error_ratio and cosine similarity. - """ - input_tensor_dtypes = CommonFunc.parse_dtype_infos(args.input_tensor_dtypes) - input_tensor_shapes = CommonFunc.get_tensor_shapes(args.input_tensor_shapes) - input_tensor_infos = { - key: (shape, input_tensor_dtypes.get(key)) - for key, shape in input_tensor_shapes.items() - } - try: - input_tensor_map = CommonFunc.create_numpy_data_map(input_tensor_infos) - except ValueError as e: - _logger.error('Random accuracy compare failed: %s', e) - raise - result = cls.real_accuracy_compare(src_session, - dst_session, - input_tensor_map) - - return result - - @classmethod - def get_cosine_distance_between_features(cls, - calibrate_feature, - cmp_feature): - """ - calculate cosine distance between features. - params: - calibrate_feature: feature to be calibrated. - cmp_feature: feature to compare. - return: - cosine similarity values between features. - """ - cosine_similarity = {} - for key, dst_feature in calibrate_feature.items(): - src_feature = cmp_feature.get(key) - abs_eps = cls.absolute_tolerance() - dst_sum = np.sum(dst_feature * dst_feature) - src_sum = np.sum(src_feature * src_feature) - dot_sum = np.sum(dst_feature * src_feature) - - if dst_sum < abs_eps and src_sum < abs_eps: - value = 1.0 - elif dst_sum * src_sum < abs_eps: - if dst_sum < abs_eps or src_sum < abs_eps: - value = 1.0 - else: - value = 0.0 - else: - value = dot_sum / (np.sqrt(dst_sum) * np.sqrt(src_sum) + abs_eps) - - cosine_similarity[key] = cls.error_format(value) - - return cosine_similarity - - @classmethod - def get_mean_relative_error_between_features(cls, - dst_feature, - src_feature): - """ - calculate mean relative error between features. - params: - dst_feature: feature to be calibrated. - src_feature: feature to compare. - return: - mean relative error values between features. - """ - mean_relative_error_info = {} - np.seterr(divide='ignore', invalid='ignore') - - for key in dst_feature.keys(): - feat_a = dst_feature.get(key, None) - feat_b = src_feature.get(key, None) - if feat_b is None: - raise ValueError(f'Model Inference feature ' - f'is not consistent in tensor: {key}') - if feat_a.size == 0: - mean_relative_error_info[key] = '0.0' - continue - if feat_a.dtype != feat_b.dtype: - _logger.warning('layer %s : different dtypes between onnx out: %s ' - 'with mslite out: %s ', - key, - feat_a.dtype, - feat_b.dtype) - mean_relative_error_info[key] = '0.0' - continue - diff = np.abs(feat_b - feat_a) - abs_feat_a = np.abs(feat_a) - relative_index = diff > cls.relative_tolerance() - if relative_index.size == 0: - mean_relative_error_info[key] = '0.0' - continue - diff = diff[relative_index] - abs_feat_a = abs_feat_a[relative_index] - abs_index = abs_feat_a > cls.absolute_tolerance() - if abs_index.size == 0: - mean_relative_error_info[key] = cls.error_format(np.average(diff)) - continue - abs_feat_a = abs_feat_a[abs_index] - relative_diff = diff[abs_index] - abs_diff = diff[~abs_index] - relative_error = np.divide(relative_diff, abs_feat_a) - mean_relative_error_info[key] = (np.sum(relative_error) + np.sum(abs_diff)) \ - / (relative_error.size + abs_diff.size) - if np.isnan(mean_relative_error_info.get(key, None)): - _logger.warning('layer: %s has nan value, ' - '%s do not work', - key, - ErrorAlgType.MEAN_RELATIVE_ERROR.value) - mean_relative_error_info[key] = cls.error_format(mean_relative_error_info.get(key, None)) - - return mean_relative_error_info - - @classmethod - def get_mean_error_between_features(cls, - dst_feature, - src_feature): - """ - calculate mean error between features. - params: - dst_feature: feature to be calibrated. - src_feature: feature to compare. - return: - mean error values between features. - """ - absolute_tolerance = cls.absolute_tolerance() - relative_tolerance = cls.relative_tolerance() - mean_error_info = {} - - for key in dst_feature.keys(): - feat_a = dst_feature.get(key, None) - feat_b = src_feature.get(key, None) - if feat_b is None: - raise ValueError(f'Model Inference feature ' - f'is not consistent in tensor: {key}') - diff = abs(feat_a - feat_b) - gt_tolerance_index = diff > (absolute_tolerance + relative_tolerance * abs(feat_a)) - lt_tolerance_index = np.logical_and(gt_tolerance_index, abs(feat_a) > absolute_tolerance) - gt_tolerance_index = np.logical_and(gt_tolerance_index, abs(feat_a) < absolute_tolerance) - gt_tolerance_index = np.logical_and(gt_tolerance_index, diff > relative_tolerance) - gt_error = diff[gt_tolerance_index] - lt_error = diff / (abs(feat_a) + absolute_tolerance) - lt_error = lt_error[lt_tolerance_index] - if gt_error.size + lt_error.size == 0: - mean_error = 0.0 - else: - mean_error = (np.sum(gt_error) + np.sum(lt_error)) / \ - (gt_error.size + lt_error.size + 1 + cls.absolute_tolerance()) - mean_error_info[key] = cls.error_format(mean_error) - return mean_error_info - - @staticmethod - def check_np_dtype_with_model_input_dtype(tensor_map, - session): - """ - check input numpy data dtype with model input dtype - params: - tensor_map: a dict with key tensor name and value numpy data. - session: model infer session - return: - a dict with key tensor name and value revised numpy data. - """ - ret_map = tensor_map - input_tensor_infos = session.input_infos - dtype_class = session.dtype_class - - for key, np_data in tensor_map.items(): - np_dtype = np_data.dtype - np_dtype_name = NumpyDtype(np_dtype).name - session_dtype = input_tensor_infos.get(key, None) - - if session_dtype is None: - raise ValueError('Input tensor name is not consistent with model inputs') - session_dtype = session_dtype[1] - - session_dtype_name = dtype_class(session_dtype).name - if session_dtype_name != np_dtype_name: - _logger.warning('input tensor %s input dtype %s ' - 'is not consistent with model dtype(%s)', - key, - np_dtype_name, - session_dtype_name) - new_data = np_data.astype(getattr(NumpyDtype, session_dtype_name).value) - ret_map[key] = new_data - - return ret_map - - @staticmethod - def absolute_tolerance(): - """for const absolute tolerance""" - return 1e-4 - - @staticmethod - def relative_tolerance(): - """for const relative tolerance""" - return 1e-4 - - @staticmethod - def error_format(error): - return f'{error * 100:.4f}%' - - @staticmethod - def write_csv(contents, csv_file): - """write csv""" - contents_to_write = [] - error_names = [] - error_infos = [] - for key, value in contents.items(): - error_names.append(key) - error_infos.append(value) - - for layer_name in list(error_infos[0].keys()): - tmp_dict = {'layer_name': layer_name} - for error_name in error_names: - tmp_dict[error_name] = contents.get(error_name).get(layer_name) - contents_to_write.append(tmp_dict) - - fieldnames = ['layer_name'] + error_names - flags = os.O_WRONLY - mode = stat.S_IWUSR | stat.S_IRUSR - with os.fdopen(os.open(csv_file, flags, mode), 'w') as f: - writer = csv.DictWriter(f, fieldnames=fieldnames) - writer.writeheader() - writer.writerows(contents_to_write) - - @staticmethod - def is_acc_ok(acc_info): - """add is ok check for accuracy result""" - mre = acc_info.get(ErrorAlgType.MEAN_RELATIVE_ERROR.value, None) - cos = acc_info.get(ErrorAlgType.COSINE_SIMILARITY.value, None) - is_ok = {} - if mre is None or cos is None: - raise ValueError('MRE or cosine similarity is None') - mre_thred = 0.05 - cos_thred = 0.99 - cos_bad_thred = 0.9 - def error_format_to_float(num): - return float(num.strip('%')) / 100 - - nan_set = {'nan', 'nan%'} - for key, mre_val in mre.items(): - cos_val = cos.get(key, None) - if mre_val in nan_set or cos_val in nan_set: - is_ok[key] = 'Invalid' - elif error_format_to_float(cos_val) < cos_bad_thred: - is_ok[key] = 'Bad' - elif error_format_to_float(mre_val) > mre_thred \ - and error_format_to_float(cos_val) < cos_thred: - is_ok[key] = 'Bad' - else: - is_ok[key] = 'Good' - - acc_info['is_ok'] = is_ok - return acc_info +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +functions for cross framework model infer result accuracy compare and summary +""" +import os +import os.path +import stat +import csv + +import numpy as np + +from mslite_bench.infer_base.infer_session_factory import InferSessionFactory +from mslite_bench.utils.infer_log import InferLogger +from mslite_bench.common.task_common_func import CommonFunc +from mslite_bench.common.enum_class import ( + NumpyDtype +) +from mslite_bench.common.model_info_enum import ErrorAlgType + +_logger = InferLogger().logger + + +class CrossFrameworkAccSummary: + """ + functions for cross framework model infer result accuracy compare and summary + """ + @classmethod + def acc_infos_between_features(cls, + standard_feature, + compare_feature): + """ + get accuracy info between features, including mean error ratio + and cosine similarity. + params: + standard_feature(Dict[str, nunmpy.ndarray): standard features to be compared + compare_feature(Dict[str, nunmpy.ndarray): compare features to compare + return: + A dict, including mean_error_ratio and cosine similarity. + """ + + mean_relative_error = cls.get_mean_relative_error_between_features(compare_feature, + standard_feature) + + cosine_similarity = cls.get_cosine_distance_between_features(compare_feature, + standard_feature) + return { + ErrorAlgType.MEAN_RELATIVE_ERROR.value: mean_relative_error, + ErrorAlgType.COSINE_SIMILARITY.value: cosine_similarity + } + + @classmethod + def accuracy_compare_func(cls, + args, + logger=None): + """ + get outputs accuracy compare info between two different framework using same model. + params: + args: input arguments + logger: logger to recorder logs + return: + A dict, including mean_error_ratio and cosine similarity. + """ + cmp_result = None + + src_file_path = args.model_file + if not src_file_path.endswith('ms') and \ + not src_file_path.endswith('mindir'): + raise ValueError(f'{src_file_path} is not a valid mslite model') + + dst_file_path = args.cmp_model_file + + input_data_map = CommonFunc.create_numpy_data_map(args) + ms_config = CommonFunc.get_framework_config(src_file_path, + args) + try: + ms_session = InferSessionFactory.create_infer_session(src_file_path, + ms_config) + except ValueError as e: + logger.error('[Accuracy Compare] Create ms session failed: %s', e) + return cmp_result + + args.device = args.cmp_device + cmp_cfg = CommonFunc.get_framework_config(dst_file_path, + args) + try: + cmp_session = InferSessionFactory.create_infer_session(dst_file_path, + cmp_cfg, + args.params_file) + except ValueError as e: + logger.error(f'Create dst session failed %s', e) + return cmp_result + + try: + cmp_result = cls.real_accuracy_compare(ms_session, + cmp_session, + input_data_map) + except (NotImplementedError, ValueError) as e: + logger.error(f'Accuracy test failed, get accuracy failed %s', e) + raise + cmp_result = cls.is_acc_ok(cmp_result) + for key, val in cmp_result.items(): + logger.debug(f'{key}: {val}') + + if not args.cmp_result_file: + csv_path = os.path.join(os.path.dirname(src_file_path), 'accuracy_infos.csv') + else: + csv_path = f'{args.cmp_result_file}.csv' + csv_dir = os.path.dirname(csv_path) + os.makedirs(csv_dir, exist_ok=True) + logger.info(f'Accuracy compare done, save accuracy info in %s', csv_path) + cls.write_csv(cmp_result, csv_path) + return cmp_result + + @classmethod + def real_accuracy_compare(cls, + src_session, + dst_session, + input_tensor_map): + """ + get accuracy compare info between two different sessions with same input tensor map. + params: + src_session: session to be compared + dst_session: session to compare + input_tensor_map: tensor name and value dict for session input. + return: + A dict, including mean_error_ratio and cosine similarity. + """ + src_output = src_session(input_tensor_map) + dst_output = dst_session(input_tensor_map) + + result = { + ErrorAlgType.MEAN_RELATIVE_ERROR.value: cls.get_mean_relative_error_between_features(dst_output, + src_output), + ErrorAlgType.COSINE_SIMILARITY.value: cls.get_cosine_distance_between_features(dst_output, + src_output) + } + return result + + @classmethod + def specific_accuracy_compare(cls, + src_session, + dst_session, + args): + """ + get accuracy compare info between two different sessions with + specific input loading from files. + params: + src_session: session to be compared + dst_session: session to compare + args: input arguments. + return: + A dict, including mean_error_ratio and cosine similarity. + """ + input_tensor_map = np.load(args.input_data_file, + allow_pickle=True).item() + result = cls.real_accuracy_compare(src_session, + dst_session, + input_tensor_map) + return result + + @classmethod + def random_accuracy_compare(cls, + src_session, + dst_session, + args): + """ + get accuracy compare info between two different sessions with random inputs. + params: + src_session: session to be compared + dst_session: session to compare + args: input arguments. + return: + A dict, including mean_error_ratio and cosine similarity. + """ + input_tensor_dtypes = CommonFunc.parse_dtype_infos(args.input_tensor_dtypes) + input_tensor_shapes = CommonFunc.get_tensor_shapes(args.input_tensor_shapes) + input_tensor_infos = { + key: (shape, input_tensor_dtypes.get(key)) + for key, shape in input_tensor_shapes.items() + } + try: + input_tensor_map = CommonFunc.create_numpy_data_map(input_tensor_infos) + except ValueError as e: + _logger.error('Random accuracy compare failed: %s', e) + raise + result = cls.real_accuracy_compare(src_session, + dst_session, + input_tensor_map) + + return result + + @classmethod + def get_cosine_distance_between_features(cls, + calibrate_feature, + cmp_feature): + """ + calculate cosine distance between features. + params: + calibrate_feature: feature to be calibrated. + cmp_feature: feature to compare. + return: + cosine similarity values between features. + """ + cosine_similarity = {} + for key, dst_feature in calibrate_feature.items(): + src_feature = cmp_feature.get(key) + abs_eps = cls.absolute_tolerance() + dst_sum = np.sum(dst_feature * dst_feature) + src_sum = np.sum(src_feature * src_feature) + dot_sum = np.sum(dst_feature * src_feature) + + if dst_sum < abs_eps and src_sum < abs_eps: + value = 1.0 + elif dst_sum * src_sum < abs_eps: + if dst_sum < abs_eps or src_sum < abs_eps: + value = 1.0 + else: + value = 0.0 + else: + value = dot_sum / (np.sqrt(dst_sum) * np.sqrt(src_sum) + abs_eps) + + cosine_similarity[key] = cls.error_format(value) + + return cosine_similarity + + @classmethod + def get_mean_relative_error_between_features(cls, + dst_feature, + src_feature): + """ + calculate mean relative error between features. + params: + dst_feature: feature to be calibrated. + src_feature: feature to compare. + return: + mean relative error values between features. + """ + mean_relative_error_info = {} + np.seterr(divide='ignore', invalid='ignore') + + for key in dst_feature.keys(): + feat_a = dst_feature.get(key, None) + feat_b = src_feature.get(key, None) + if feat_b is None: + raise ValueError(f'Model Inference feature ' + f'is not consistent in tensor: {key}') + if feat_a.size == 0: + mean_relative_error_info[key] = '0.0' + continue + if feat_a.dtype != feat_b.dtype: + _logger.warning('layer %s : different dtypes between onnx out: %s ' + 'with mslite out: %s ', + key, + feat_a.dtype, + feat_b.dtype) + mean_relative_error_info[key] = '0.0' + continue + diff = np.abs(feat_b - feat_a) + abs_feat_a = np.abs(feat_a) + relative_index = diff > cls.relative_tolerance() + if relative_index.size == 0: + mean_relative_error_info[key] = '0.0' + continue + diff = diff[relative_index] + abs_feat_a = abs_feat_a[relative_index] + abs_index = abs_feat_a > cls.absolute_tolerance() + if abs_index.size == 0: + mean_relative_error_info[key] = cls.error_format(np.average(diff)) + continue + abs_feat_a = abs_feat_a[abs_index] + relative_diff = diff[abs_index] + abs_diff = diff[~abs_index] + relative_error = np.divide(relative_diff, abs_feat_a) + mean_relative_error_info[key] = (np.sum(relative_error) + np.sum(abs_diff)) \ + / (relative_error.size + abs_diff.size) + if np.isnan(mean_relative_error_info.get(key, None)): + _logger.warning('layer: %s has nan value, ' + '%s do not work', + key, + ErrorAlgType.MEAN_RELATIVE_ERROR.value) + mean_relative_error_info[key] = cls.error_format(mean_relative_error_info.get(key, None)) + + return mean_relative_error_info + + @classmethod + def get_mean_error_between_features(cls, + dst_feature, + src_feature): + """ + calculate mean error between features. + params: + dst_feature: feature to be calibrated. + src_feature: feature to compare. + return: + mean error values between features. + """ + absolute_tolerance = cls.absolute_tolerance() + relative_tolerance = cls.relative_tolerance() + mean_error_info = {} + + for key in dst_feature.keys(): + feat_a = dst_feature.get(key, None) + feat_b = src_feature.get(key, None) + if feat_b is None: + raise ValueError(f'Model Inference feature ' + f'is not consistent in tensor: {key}') + diff = abs(feat_a - feat_b) + gt_tolerance_index = diff > (absolute_tolerance + relative_tolerance * abs(feat_a)) + lt_tolerance_index = np.logical_and(gt_tolerance_index, abs(feat_a) > absolute_tolerance) + gt_tolerance_index = np.logical_and(gt_tolerance_index, abs(feat_a) < absolute_tolerance) + gt_tolerance_index = np.logical_and(gt_tolerance_index, diff > relative_tolerance) + gt_error = diff[gt_tolerance_index] + lt_error = diff / (abs(feat_a) + absolute_tolerance) + lt_error = lt_error[lt_tolerance_index] + if gt_error.size + lt_error.size == 0: + mean_error = 0.0 + else: + mean_error = (np.sum(gt_error) + np.sum(lt_error)) / \ + (gt_error.size + lt_error.size + 1 + cls.absolute_tolerance()) + mean_error_info[key] = cls.error_format(mean_error) + return mean_error_info + + @staticmethod + def check_np_dtype_with_model_input_dtype(tensor_map, + session): + """ + check input numpy data dtype with model input dtype + params: + tensor_map: a dict with key tensor name and value numpy data. + session: model infer session + return: + a dict with key tensor name and value revised numpy data. + """ + ret_map = tensor_map + input_tensor_infos = session.input_infos + dtype_class = session.dtype_class + + for key, np_data in tensor_map.items(): + np_dtype = np_data.dtype + np_dtype_name = NumpyDtype(np_dtype).name + session_dtype = input_tensor_infos.get(key, None) + + if session_dtype is None: + raise ValueError('Input tensor name is not consistent with model inputs') + session_dtype = session_dtype[1] + + session_dtype_name = dtype_class(session_dtype).name + if session_dtype_name != np_dtype_name: + _logger.warning('input tensor %s input dtype %s ' + 'is not consistent with model dtype(%s)', + key, + np_dtype_name, + session_dtype_name) + new_data = np_data.astype(getattr(NumpyDtype, session_dtype_name).value) + ret_map[key] = new_data + + return ret_map + + @staticmethod + def absolute_tolerance(): + """for const absolute tolerance""" + return 1e-4 + + @staticmethod + def relative_tolerance(): + """for const relative tolerance""" + return 1e-4 + + @staticmethod + def error_format(error): + return f'{error * 100:.4f}%' + + @staticmethod + def write_csv(contents, csv_file): + """write csv""" + contents_to_write = [] + error_names = [] + error_infos = [] + for key, value in contents.items(): + error_names.append(key) + error_infos.append(value) + + for layer_name in list(error_infos[0].keys()): + tmp_dict = {'layer_name': layer_name} + for error_name in error_names: + tmp_dict[error_name] = contents.get(error_name).get(layer_name) + contents_to_write.append(tmp_dict) + + fieldnames = ['layer_name'] + error_names + flags = os.O_WRONLY + mode = stat.S_IWUSR | stat.S_IRUSR + with os.fdopen(os.open(csv_file, flags, mode), 'w') as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(contents_to_write) + + @staticmethod + def is_acc_ok(acc_info): + """add is ok check for accuracy result""" + mre = acc_info.get(ErrorAlgType.MEAN_RELATIVE_ERROR.value, None) + cos = acc_info.get(ErrorAlgType.COSINE_SIMILARITY.value, None) + is_ok = {} + if mre is None or cos is None: + raise ValueError('MRE or cosine similarity is None') + mre_thred = 0.05 + cos_thred = 0.99 + cos_bad_thred = 0.9 + def error_format_to_float(num): + return float(num.strip('%')) / 100 + + nan_set = {'nan', 'nan%'} + for key, mre_val in mre.items(): + cos_val = cos.get(key, None) + if mre_val in nan_set or cos_val in nan_set: + is_ok[key] = 'Invalid' + elif error_format_to_float(cos_val) < cos_bad_thred: + is_ok[key] = 'Bad' + elif error_format_to_float(mre_val) > mre_thred \ + and error_format_to_float(cos_val) < cos_thred: + is_ok[key] = 'Bad' + else: + is_ok[key] = 'Good' + + acc_info['is_ok'] = is_ok + return acc_info diff --git a/mindspore/lite/tools/mslite_bench/mslite_bench/tools/easy_infer.py b/mindspore/lite/tools/mslite_bench/mslite_bench/tools/easy_infer.py index 01aba7308b3..b7110def84e 100644 --- a/mindspore/lite/tools/mslite_bench/mslite_bench/tools/easy_infer.py +++ b/mindspore/lite/tools/mslite_bench/mslite_bench/tools/easy_infer.py @@ -1,116 +1,116 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -functions for model easy infer -""" -import time -import random - -import numpy as np - -from mslite_bench.common.config import ModelConfig -from mslite_bench.utils import InferLogger -from mslite_bench.infer_base import InferSessionFactory -from mslite_bench.common.task_common_func import CommonFunc -from mslite_bench.common.model_info_enum import SaveFileType - - -class EasyInfer: - """ - functions for model easy infer - """ - @staticmethod - def easy_infer(args, logger=None): - """model easy infer""" - if logger is None: - logger = InferLogger(args.log_path).logger - output_data_dir = None - model_path = args.model_file - param_path = args.params_file - - cfg = CommonFunc.get_framework_config(model_path, - args) - if args.input_data_file is not None: - if args.input_data_file.endswith(SaveFileType.NPY.value): - input_data_map = np.load(args.input_data_file, allow_pickle=True).item() - else: - input_data_map = np.fromfile(args.input_data_file) - output_data_dir = f'{args.input_data_file}_output' - - cfg.input_tensor_shapes = { - key: value.shape for key, value in input_data_map.items() - } - else: - input_data_map = CommonFunc.create_numpy_data_map(args) - input_data_file = f'{args.model_file}_input' - output_data_dir = f'{args.model_file}_output' - if args.save_file_type == SaveFileType.NPY.value: - np.save(f'{input_data_file}.npy', input_data_map) - elif args.save_file_type == SaveFileType.BIN.value: - for key, value in input_data_map.items(): - value.tofile(f'{input_data_file}_{"".join(key.split("/"))}.bin') - else: - output_data_dir = None - - model_session = InferSessionFactory.create_infer_session(model_path, - cfg, - params_file=param_path) - logger.debug('Create model session success') - - for _ in range(args.warmup_times): - outputs = model_session(input_data_map) - - start = time.time() - for _ in range(args.loop_infer_times): - outputs = model_session(input_data_map) - end = time.time() - if args.loop_infer_times != 0: - logger.info('Model Infer %s times, ' - 'Avg infer time is %s ms', - args.loop_infer_times, - round((end - start) / args.loop_infer_times * 1000, 3)) - - if output_data_dir is not None: - if args.save_file_type == SaveFileType.NPY.value: - np.save(output_data_dir, outputs) - else: - CommonFunc.save_output_as_benchmark_txt(output_data_dir, - outputs) - return outputs - - @staticmethod - def ms_dynamic_input_infer(args, logger=None): - """conduct dynamic shape mindspore lite model infer""" - if logger is None: - logger = InferLogger(args.log_path).logger - - cfg = ModelConfig(device=args.device) - ms_session = InferSessionFactory.create_infer_session(args.model_path, - cfg) - model_inputs = ms_session.get_input() - - for _ in range(args.dynamic_infer_times): - input_tensor_infos = {} - random_batch_size = random.randint(args.min_random_batch_size, - args.max_random_batch_size) - for input_tensor in model_inputs: - input_shape = input_tensor.shape - input_shape[0] = random_batch_size - tensor_name = input_tensor.name.rstrip() - input_tensor_infos[tensor_name] = (input_shape, input_tensor.dtype) - input_tensor_map = CommonFunc.create_numpy_data_map(input_tensor_infos) - _ = ms_session(input_tensor_map) - - logger.debug('All dynamic input passed successfully') +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +functions for model easy infer +""" +import time +import random + +import numpy as np + +from mslite_bench.common.config import ModelConfig +from mslite_bench.utils import InferLogger +from mslite_bench.infer_base import InferSessionFactory +from mslite_bench.common.task_common_func import CommonFunc +from mslite_bench.common.model_info_enum import SaveFileType + + +class EasyInfer: + """ + functions for model easy infer + """ + @staticmethod + def easy_infer(args, logger=None): + """model easy infer""" + if logger is None: + logger = InferLogger(args.log_path).logger + output_data_dir = None + model_path = args.model_file + param_path = args.params_file + + cfg = CommonFunc.get_framework_config(model_path, + args) + if args.input_data_file is not None: + if args.input_data_file.endswith(SaveFileType.NPY.value): + input_data_map = np.load(args.input_data_file, allow_pickle=True).item() + else: + input_data_map = np.fromfile(args.input_data_file) + output_data_dir = f'{args.input_data_file}_output' + + cfg.input_tensor_shapes = { + key: value.shape for key, value in input_data_map.items() + } + else: + input_data_map = CommonFunc.create_numpy_data_map(args) + input_data_file = f'{args.model_file}_input' + output_data_dir = f'{args.model_file}_output' + if args.save_file_type == SaveFileType.NPY.value: + np.save(f'{input_data_file}.npy', input_data_map) + elif args.save_file_type == SaveFileType.BIN.value: + for key, value in input_data_map.items(): + value.tofile(f'{input_data_file}_{"".join(key.split("/"))}.bin') + else: + output_data_dir = None + + model_session = InferSessionFactory.create_infer_session(model_path, + cfg, + params_file=param_path) + logger.debug('Create model session success') + + for _ in range(args.warmup_times): + outputs = model_session(input_data_map) + + start = time.time() + for _ in range(args.loop_infer_times): + outputs = model_session(input_data_map) + end = time.time() + if args.loop_infer_times != 0: + logger.info('Model Infer %s times, ' + 'Avg infer time is %s ms', + args.loop_infer_times, + round((end - start) / args.loop_infer_times * 1000, 3)) + + if output_data_dir is not None: + if args.save_file_type == SaveFileType.NPY.value: + np.save(output_data_dir, outputs) + else: + CommonFunc.save_output_as_benchmark_txt(output_data_dir, + outputs) + return outputs + + @staticmethod + def ms_dynamic_input_infer(args, logger=None): + """conduct dynamic shape mindspore lite model infer""" + if logger is None: + logger = InferLogger(args.log_path).logger + + cfg = ModelConfig(device=args.device) + ms_session = InferSessionFactory.create_infer_session(args.model_path, + cfg) + model_inputs = ms_session.get_input() + + for _ in range(args.dynamic_infer_times): + input_tensor_infos = {} + random_batch_size = random.randint(args.min_random_batch_size, + args.max_random_batch_size) + for input_tensor in model_inputs: + input_shape = input_tensor.shape + input_shape[0] = random_batch_size + tensor_name = input_tensor.name.rstrip() + input_tensor_infos[tensor_name] = (input_shape, input_tensor.dtype) + input_tensor_map = CommonFunc.create_numpy_data_map(input_tensor_infos) + _ = ms_session(input_tensor_map) + + logger.debug('All dynamic input passed successfully') diff --git a/mindspore/lite/tools/mslite_bench/mslite_bench/tools/mslite_auto_cmp.py b/mindspore/lite/tools/mslite_bench/mslite_bench/tools/mslite_auto_cmp.py index 973d035a5a3..214ac56f785 100644 --- a/mindspore/lite/tools/mslite_bench/mslite_bench/tools/mslite_auto_cmp.py +++ b/mindspore/lite/tools/mslite_bench/mslite_bench/tools/mslite_auto_cmp.py @@ -1,73 +1,73 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" mindspore lite auto accuracy compare""" -import copy -import os - - -from mslite_bench.tools.converter import MsliteConverter -from mslite_bench.utils import InferLogger -from mslite_bench.graphs.graph_modifier_factory import create_graph_modifier - - -_logger = InferLogger().logger - - -class MsliteAutoCMP: - """Auto compare between third party model and mslite model""" - @classmethod - def acc_infos_in_specific_node(cls, - args, - logger=None): - """accuracy infos in specific node""" - if args.input_tensor_shapes is None: - logger.error('Shall input input_tensor_shapes for accuracy compare') - raise ValueError('input_tensor_shapes is None') - if args.input_tensor_dtypes is None: - logger.error('Shall input input_tensor_dtypes for accuracy compare') - raise ValueError('input_tensor_dtypes is None') - if not args.model_file.endswith('onnx'): - logger.error('Only onnx model accuracy compare is supported') - raise ValueError('input model file shall be .onnx') - graph_modifier = create_graph_modifier(args.model_file) - sub_model_name = f'sub_{args.peak_node_names.replace("/", "_")}_{os.path.basename(args.model_file)}' - sub_model_path = os.path.join(os.path.dirname(args.model_file), sub_model_name) - if logger is None: - logger = _logger - - peak_node_names = cls.parse_peak_node_names(args.peak_node_names) - - logger.info('Collect all node outputs to compare') - graph_out_names = graph_modifier.extract_model(sub_model_path, - output_names=peak_node_names) - logger.debug('Extract sub model successfully') - - args.converter_output_file = os.path.join(os.path.dirname(args.model_file), - f'sub_{args.peak_node_names.replace("/", "_")}') - - args_copy = copy.deepcopy(args) - args_copy.model_file = sub_model_path - args_copy.converter_is_analysis = True - args_copy.output_tensor_names = graph_out_names - args_copy.converter_input_shape = args.input_tensor_shapes - - logger.info('Start to convert model') - MsliteConverter.convert(args_copy, logger, is_delete_ms_model=True) - os.remove(sub_model_path) - - @staticmethod - def parse_peak_node_names(peak_node_str): - """parse peak node names""" - return [item.strip() for item in peak_node_str.split(',')] +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" mindspore lite auto accuracy compare""" +import copy +import os + + +from mslite_bench.tools.converter import MsliteConverter +from mslite_bench.utils import InferLogger +from mslite_bench.graphs.graph_modifier_factory import create_graph_modifier + + +_logger = InferLogger().logger + + +class MsliteAutoCMP: + """Auto compare between third party model and mslite model""" + @classmethod + def acc_infos_in_specific_node(cls, + args, + logger=None): + """accuracy infos in specific node""" + if args.input_tensor_shapes is None: + logger.error('Shall input input_tensor_shapes for accuracy compare') + raise ValueError('input_tensor_shapes is None') + if args.input_tensor_dtypes is None: + logger.error('Shall input input_tensor_dtypes for accuracy compare') + raise ValueError('input_tensor_dtypes is None') + if not args.model_file.endswith('onnx'): + logger.error('Only onnx model accuracy compare is supported') + raise ValueError('input model file shall be .onnx') + graph_modifier = create_graph_modifier(args.model_file) + sub_model_name = f'sub_{args.peak_node_names.replace("/", "_")}_{os.path.basename(args.model_file)}' + sub_model_path = os.path.join(os.path.dirname(args.model_file), sub_model_name) + if logger is None: + logger = _logger + + peak_node_names = cls.parse_peak_node_names(args.peak_node_names) + + logger.info('Collect all node outputs to compare') + graph_out_names = graph_modifier.extract_model(sub_model_path, + output_names=peak_node_names) + logger.debug('Extract sub model successfully') + + args.converter_output_file = os.path.join(os.path.dirname(args.model_file), + f'sub_{args.peak_node_names.replace("/", "_")}') + + args_copy = copy.deepcopy(args) + args_copy.model_file = sub_model_path + args_copy.converter_is_analysis = True + args_copy.output_tensor_names = graph_out_names + args_copy.converter_input_shape = args.input_tensor_shapes + + logger.info('Start to convert model') + MsliteConverter.convert(args_copy, logger, is_delete_ms_model=True) + os.remove(sub_model_path) + + @staticmethod + def parse_peak_node_names(peak_node_str): + """parse peak node names""" + return [item.strip() for item in peak_node_str.split(',')] diff --git a/mindspore/lite/tools/mslite_bench/mslite_bench/utils/__init__.py b/mindspore/lite/tools/mslite_bench/mslite_bench/utils/__init__.py index 1798cccecd9..749639e5e8a 100644 --- a/mindspore/lite/tools/mslite_bench/mslite_bench/utils/__init__.py +++ b/mindspore/lite/tools/mslite_bench/mslite_bench/utils/__init__.py @@ -1,22 +1,22 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -utils for mslite bench -""" - -from mslite_bench.utils.arg_parser import ArgParser -from mslite_bench.utils.infer_log import InferLogger - -__all__ = ['ArgParser', 'InferLogger'] +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +utils for mslite bench +""" + +from mslite_bench.utils.arg_parser import ArgParser +from mslite_bench.utils.infer_log import InferLogger + +__all__ = ['ArgParser', 'InferLogger'] diff --git a/mindspore/lite/tools/mslite_bench/mslite_bench/utils/arg_parser.py b/mindspore/lite/tools/mslite_bench/mslite_bench/utils/arg_parser.py index e128024151c..089562a8dd3 100644 --- a/mindspore/lite/tools/mslite_bench/mslite_bench/utils/arg_parser.py +++ b/mindspore/lite/tools/mslite_bench/mslite_bench/utils/arg_parser.py @@ -1,334 +1,334 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -input argument parser functions -""" -import argparse - - -class ArgParser: - """ - input argument parser functions - """ - @classmethod - def parse_arguments(cls): - """parse input arguments for mslite bench""" - parser = argparse.ArgumentParser(description='Easy Infer for model benchmark') - cls.base_arg_parse(parser) - cls.model_arg_parse(parser) - cls.task_arg_parse(parser) - cls.converter_arg_parse(parser) - cls.auto_cmp_arg_parse(parser) - args = parser.parse_args() - return args - - @classmethod - def task_arg_parse(cls, parser): - """parse task related arguments""" - # for task related - parser.add_argument('--task_type', - type=str, - choices=["infer", "convert", "framework_cmp", "auto_cmp"], - default='auto_cmp', - help='benchmark task type:' - 'infer for framework accuracy compare,' - 'framework_cmp for multiple frameworks accuracy compare' - 'convert for mslite model converter,' - 'auto_cmp for mslite auto compare with third party framework,' - 'both single op and all ops accuracy compare are supported') - # for inference related infos - parser.add_argument('--model_file', - type=str, - default=None, - help='path to model file') - parser.add_argument('--params_file', - type=str, - default="", - help='path to params file') - parser.add_argument('--cmp_model_file', - type=str, - default=None, - help="the model path for model to be compared") - parser.add_argument('--test_data', - type=str, - default=None, - help='path to data to do inference') - parser.add_argument('--test_label', - type=str, - default=None, - help='path to test labels to calculate accuracy of model infer') - - # for benchmark and random accuracy test - parser.add_argument('--input_tensor_shapes', - type=str, - default=None, - help="input tensor infos contain input tensor name and tensor shape" - "format 'tensor_name: tensor_shape;") - parser.add_argument('--input_data_file', - type=str, - default=None, - help="path to files contain input data, with key is input tensor name" - "value is input numpy data") - - parser.add_argument('--save_file_type', - type=str, - default='not_save', - choices=['npy', 'bin', 'not_save'], - help="file type to save input output tensor info, " - "default not save") - - parser.add_argument('--input_tensor_dtypes', - type=str, - default=None, - help="tensor dtype for each model input tensor, " - "choices=[INT8, INT32, INT64" - "FLOAT16, FLOAT, FLOAT64, UINT8]") - - parser.add_argument('--random_input_flag', - type=bool, - default=False, - help="flag indicate whether using random input to do inference") - - parser.add_argument('--loop_infer_times', - type=int, - default=1, - help="infer times for loop infer") - - parser.add_argument('--warmup_times', - type=int, - default=0, - help="warm times for model infer") - - @classmethod - def model_arg_parse(cls, parser): - """parse model and framework related arguments""" - # for mslite config - parser.add_argument('--thread_affinity_mode', - type=int, - default=2, - help='thread affinity number for mslite inference') - - parser.add_argument('--thread_num', - type=int, - default=1, - help='thread number for mslite inference') - - parser.add_argument('--mslite_model_type', - type=int, - default=0, - choices=[0, 4], - help='input model type for mslite inference, ' - '0 for MINDIR, 4 for MINDIR_LITE') - - parser.add_argument('--ascend_provider', - type=str, - default='', - choices=['', 'ge'], - help="Ascend infer method: '' for acl, 'ge' for GE") - - # for tensorrt infer - parser.add_argument('--tensorrt_optim_input_shape', - type=str, - default=None, - help='optim input shape for tensorrt' - 'with key tensor name (str) ' - 'and value shape info(List[int])') - - parser.add_argument('--tensorrt_min_input_shape', - type=str, - default=None, - help='optim input shape for tensorrt' - 'with key tensor name (str) ' - 'and value shape info(List[int])') - - parser.add_argument('--tensorrt_max_input_shape', - type=str, - default=None, - help='optim input shape for tensorrt' - 'with key tensor name (str) ' - 'and value shape info(List[int])') - - parser.add_argument('--gpu_memory_size', - type=int, - default=100, - help='gpu init memory size(M)') - - parser.add_argument('--is_enable_tensorrt', - type=bool, - default=False, - help="flag indicate whether use tensorrt engine") - - parser.add_argument('--is_fp16', - type=bool, - default=False, - help="flag indicate whether apply fp16 infer") - - parser.add_argument('--is_int8', - type=bool, - default=False, - help="flag indicate whether apply int8 infer") - - @staticmethod - def converter_arg_parse(parser): - """parse converter related arguments""" - parser.add_argument('--converter_decrypt_key', - type=str, - default="", - help='decrypt key for mindir, ' - 'only take effect when converter_save_type is MINDIR') - - parser.add_argument('--converter_decrypt_mode', - type=str, - choices=['AES-GCM', 'AES-CBC'], - default='AES-GCM', - help='decrypt model for mindir, ' - 'only take effect when converter_decrypt_key is set') - - parser.add_argument('--converter_enable_encryption', - type=bool, - default=False, - help='whether enable encryption') - - parser.add_argument('--converter_input_shape', - type=str, - default=None, - help="input tensor shape contain input tensor name and tensor shape" - "format 'tensor_name: tensor_shape;") - - parser.add_argument('--converter_weight_fp16', - type=bool, - default=False, - help='whether save model as fp16 model') - - parser.add_argument('--converter_encrypt_key', - type=str, - default="", - help='encrypt key for model encryption,' - 'only take effect when converter decrypt mode ' - 'is AES-GCM') - - parser.add_argument('--converter_optimize', - type=str, - choices=['none', 'general', 'ascend_oriented', 'gpu_oriented'], - default='general', - help='optimize mode for converter') - - parser.add_argument('--converter_save_type', - type=str, - choices=['MINDIR', 'MINDIR_LITE'], - default='MINDIR', - help='model save type for mindspore lite') - - parser.add_argument('--converter_output_file', - type=str, - default="", - help='output path to save mindspore model') - - parser.add_argument('--converter_config_file', - type=str, - default="", - help='config file for converter') - - parser.add_argument('--quant_output_data_type', - type=str, - choices=['float32', 'int8', 'unknown', 'uint8'], - default='unknown', - help='data type for quant model, if unknown is set, data type is ' - 'the same with model input') - - parser.add_argument('--quant_input_data_type', - type=str, - choices=['float32', 'int8', 'unknown', 'uint8'], - default='unknown', - help='data type for quant model, if unknown is set, data type is ' - 'the same with model output') - - parser.add_argument('--converter_is_analysis', - type=bool, - default=False, - help='whether analysis converter status after model convert' - 'if True, will give summary about converter accuracy between' - 'mslite model and third party model') - - parser.add_argument('--converter_input_format', - type=str, - choices=['NCHW', 'NHWC'], - default='NCHW', - help='whether analysis converter status after model convert' - 'if True, will give summary about converter accuracy between' - 'mslite model and third party model') - - @staticmethod - def auto_cmp_arg_parse(parser): - """parse auto cmp related arguments""" - parser.add_argument('--peak_node_names', - type=str, - default='mslite_bench_all', - help='network node name to compare accuracy between' - 'third party framework with mslite framework,' - 'if all is set, every node in network would run accuracy compare') - - parser.add_argument('--cmp_result_file', - type=str, - default='', - help='path to save accuracy info, default in csv format') - - @staticmethod - def base_arg_parse(parser): - """parse base related arguments""" - parser.add_argument('--device', - type=str, - default='ascend', - choices=['cpu', 'gpu', 'ascend'], - help='device type for model inference') - parser.add_argument('--cmp_device', - type=str, - default='cpu', - choices=['cpu', 'gpu', 'ascend'], - help='device type for cmp model inference') - parser.add_argument('--device_id', - type=int, - default=0, - help='device index for model inference') - parser.add_argument('--log_level', - type=int, - choices=[0, 1, 2, 3], - default=1, - help='logging info for mslite bench' - '0 for debug,' - '1 for info,' - '2 for warning' - '3 for error') - parser.add_argument('--frameworkType', - type=str, - default='MSLITE', - choices=['MSLITE', 'PB', 'ONNX', 'PADDLE'], - help='device type for model inference') - parser.add_argument('--log_path', - type=str, - default=None, - help='path to save model inference log') - parser.add_argument('--batch_size', - type=int, - default=1, - help='model inference batch size') - parser.add_argument('--input_tensor_names', - nargs='+', - default=None, - help='model input tensor name list') - parser.add_argument('--output_tensor_names', - nargs='+', - default=None, - help='model output tensor name list') +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +input argument parser functions +""" +import argparse + + +class ArgParser: + """ + input argument parser functions + """ + @classmethod + def parse_arguments(cls): + """parse input arguments for mslite bench""" + parser = argparse.ArgumentParser(description='Easy Infer for model benchmark') + cls.base_arg_parse(parser) + cls.model_arg_parse(parser) + cls.task_arg_parse(parser) + cls.converter_arg_parse(parser) + cls.auto_cmp_arg_parse(parser) + args = parser.parse_args() + return args + + @classmethod + def task_arg_parse(cls, parser): + """parse task related arguments""" + # for task related + parser.add_argument('--task_type', + type=str, + choices=["infer", "convert", "framework_cmp", "auto_cmp"], + default='auto_cmp', + help='benchmark task type:' + 'infer for framework accuracy compare,' + 'framework_cmp for multiple frameworks accuracy compare' + 'convert for mslite model converter,' + 'auto_cmp for mslite auto compare with third party framework,' + 'both single op and all ops accuracy compare are supported') + # for inference related infos + parser.add_argument('--model_file', + type=str, + default=None, + help='path to model file') + parser.add_argument('--params_file', + type=str, + default="", + help='path to params file') + parser.add_argument('--cmp_model_file', + type=str, + default=None, + help="the model path for model to be compared") + parser.add_argument('--test_data', + type=str, + default=None, + help='path to data to do inference') + parser.add_argument('--test_label', + type=str, + default=None, + help='path to test labels to calculate accuracy of model infer') + + # for benchmark and random accuracy test + parser.add_argument('--input_tensor_shapes', + type=str, + default=None, + help="input tensor infos contain input tensor name and tensor shape" + "format 'tensor_name: tensor_shape;") + parser.add_argument('--input_data_file', + type=str, + default=None, + help="path to files contain input data, with key is input tensor name" + "value is input numpy data") + + parser.add_argument('--save_file_type', + type=str, + default='not_save', + choices=['npy', 'bin', 'not_save'], + help="file type to save input output tensor info, " + "default not save") + + parser.add_argument('--input_tensor_dtypes', + type=str, + default=None, + help="tensor dtype for each model input tensor, " + "choices=[INT8, INT32, INT64" + "FLOAT16, FLOAT, FLOAT64, UINT8]") + + parser.add_argument('--random_input_flag', + type=bool, + default=False, + help="flag indicate whether using random input to do inference") + + parser.add_argument('--loop_infer_times', + type=int, + default=1, + help="infer times for loop infer") + + parser.add_argument('--warmup_times', + type=int, + default=0, + help="warm times for model infer") + + @classmethod + def model_arg_parse(cls, parser): + """parse model and framework related arguments""" + # for mslite config + parser.add_argument('--thread_affinity_mode', + type=int, + default=2, + help='thread affinity number for mslite inference') + + parser.add_argument('--thread_num', + type=int, + default=1, + help='thread number for mslite inference') + + parser.add_argument('--mslite_model_type', + type=int, + default=0, + choices=[0, 4], + help='input model type for mslite inference, ' + '0 for MINDIR, 4 for MINDIR_LITE') + + parser.add_argument('--ascend_provider', + type=str, + default='', + choices=['', 'ge'], + help="Ascend infer method: '' for acl, 'ge' for GE") + + # for tensorrt infer + parser.add_argument('--tensorrt_optim_input_shape', + type=str, + default=None, + help='optim input shape for tensorrt' + 'with key tensor name (str) ' + 'and value shape info(List[int])') + + parser.add_argument('--tensorrt_min_input_shape', + type=str, + default=None, + help='optim input shape for tensorrt' + 'with key tensor name (str) ' + 'and value shape info(List[int])') + + parser.add_argument('--tensorrt_max_input_shape', + type=str, + default=None, + help='optim input shape for tensorrt' + 'with key tensor name (str) ' + 'and value shape info(List[int])') + + parser.add_argument('--gpu_memory_size', + type=int, + default=100, + help='gpu init memory size(M)') + + parser.add_argument('--is_enable_tensorrt', + type=bool, + default=False, + help="flag indicate whether use tensorrt engine") + + parser.add_argument('--is_fp16', + type=bool, + default=False, + help="flag indicate whether apply fp16 infer") + + parser.add_argument('--is_int8', + type=bool, + default=False, + help="flag indicate whether apply int8 infer") + + @staticmethod + def converter_arg_parse(parser): + """parse converter related arguments""" + parser.add_argument('--converter_decrypt_key', + type=str, + default="", + help='decrypt key for mindir, ' + 'only take effect when converter_save_type is MINDIR') + + parser.add_argument('--converter_decrypt_mode', + type=str, + choices=['AES-GCM', 'AES-CBC'], + default='AES-GCM', + help='decrypt model for mindir, ' + 'only take effect when converter_decrypt_key is set') + + parser.add_argument('--converter_enable_encryption', + type=bool, + default=False, + help='whether enable encryption') + + parser.add_argument('--converter_input_shape', + type=str, + default=None, + help="input tensor shape contain input tensor name and tensor shape" + "format 'tensor_name: tensor_shape;") + + parser.add_argument('--converter_weight_fp16', + type=bool, + default=False, + help='whether save model as fp16 model') + + parser.add_argument('--converter_encrypt_key', + type=str, + default="", + help='encrypt key for model encryption,' + 'only take effect when converter decrypt mode ' + 'is AES-GCM') + + parser.add_argument('--converter_optimize', + type=str, + choices=['none', 'general', 'ascend_oriented', 'gpu_oriented'], + default='general', + help='optimize mode for converter') + + parser.add_argument('--converter_save_type', + type=str, + choices=['MINDIR', 'MINDIR_LITE'], + default='MINDIR', + help='model save type for mindspore lite') + + parser.add_argument('--converter_output_file', + type=str, + default="", + help='output path to save mindspore model') + + parser.add_argument('--converter_config_file', + type=str, + default="", + help='config file for converter') + + parser.add_argument('--quant_output_data_type', + type=str, + choices=['float32', 'int8', 'unknown', 'uint8'], + default='unknown', + help='data type for quant model, if unknown is set, data type is ' + 'the same with model input') + + parser.add_argument('--quant_input_data_type', + type=str, + choices=['float32', 'int8', 'unknown', 'uint8'], + default='unknown', + help='data type for quant model, if unknown is set, data type is ' + 'the same with model output') + + parser.add_argument('--converter_is_analysis', + type=bool, + default=False, + help='whether analysis converter status after model convert' + 'if True, will give summary about converter accuracy between' + 'mslite model and third party model') + + parser.add_argument('--converter_input_format', + type=str, + choices=['NCHW', 'NHWC'], + default='NCHW', + help='whether analysis converter status after model convert' + 'if True, will give summary about converter accuracy between' + 'mslite model and third party model') + + @staticmethod + def auto_cmp_arg_parse(parser): + """parse auto cmp related arguments""" + parser.add_argument('--peak_node_names', + type=str, + default='mslite_bench_all', + help='network node name to compare accuracy between' + 'third party framework with mslite framework,' + 'if all is set, every node in network would run accuracy compare') + + parser.add_argument('--cmp_result_file', + type=str, + default='', + help='path to save accuracy info, default in csv format') + + @staticmethod + def base_arg_parse(parser): + """parse base related arguments""" + parser.add_argument('--device', + type=str, + default='ascend', + choices=['cpu', 'gpu', 'ascend'], + help='device type for model inference') + parser.add_argument('--cmp_device', + type=str, + default='cpu', + choices=['cpu', 'gpu', 'ascend'], + help='device type for cmp model inference') + parser.add_argument('--device_id', + type=int, + default=0, + help='device index for model inference') + parser.add_argument('--log_level', + type=int, + choices=[0, 1, 2, 3], + default=1, + help='logging info for mslite bench' + '0 for debug,' + '1 for info,' + '2 for warning' + '3 for error') + parser.add_argument('--frameworkType', + type=str, + default='MSLITE', + choices=['MSLITE', 'PB', 'ONNX', 'PADDLE'], + help='device type for model inference') + parser.add_argument('--log_path', + type=str, + default=None, + help='path to save model inference log') + parser.add_argument('--batch_size', + type=int, + default=1, + help='model inference batch size') + parser.add_argument('--input_tensor_names', + nargs='+', + default=None, + help='model input tensor name list') + parser.add_argument('--output_tensor_names', + nargs='+', + default=None, + help='model output tensor name list') diff --git a/mindspore/lite/tools/mslite_bench/mslite_bench/utils/infer_log.py b/mindspore/lite/tools/mslite_bench/mslite_bench/utils/infer_log.py index 93f0d4c6523..177bddd3cdb 100644 --- a/mindspore/lite/tools/mslite_bench/mslite_bench/utils/infer_log.py +++ b/mindspore/lite/tools/mslite_bench/mslite_bench/utils/infer_log.py @@ -1,68 +1,68 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -log for mslite bench -""" -import logging -from functools import wraps - - -def singleton(cls): - """singleton decorator function""" - instances_ = {} - - @wraps(cls) - def _get_instances(*args, **kwargs): - if cls not in instances_: - instances_[cls] = cls(*args, **kwargs) - return instances_.get(cls, None) - - return _get_instances - - -@singleton -class InferLogger: - """ - logger for mslite bench, with singleton decorated - """ - def __init__(self, file_path: str = None): - self.file_path = file_path - self.logger_ = self._create_logger() - - @property - def logger(self): - return self.logger_ - - def set_level(self, level=logging.info): - self.logger_.setLevel(level) - - def _create_logger(self): - """create logger for mslite bench""" - logger = logging.getLogger('MSLITE_BENCH') - log_format = '%(asctime)s - [%(name)s-%(levelname)s' \ - '(%(filename)s:%(lineno)d)]: %(message)s' - formatter = logging.Formatter(log_format, - datefmt='%m/%d %I:%M:%S %p') - if self.file_path is not None: - file_handler = logging.FileHandler(self.file_path) - file_handler.setFormatter(formatter) - logger.addHandler(file_handler) - - stream_handler = logging.StreamHandler() - stream_handler.setFormatter(formatter) - logger.addHandler(stream_handler) - logger.setLevel(logging.INFO) - - return logger +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +log for mslite bench +""" +import logging +from functools import wraps + + +def singleton(cls): + """singleton decorator function""" + instances_ = {} + + @wraps(cls) + def _get_instances(*args, **kwargs): + if cls not in instances_: + instances_[cls] = cls(*args, **kwargs) + return instances_.get(cls, None) + + return _get_instances + + +@singleton +class InferLogger: + """ + logger for mslite bench, with singleton decorated + """ + def __init__(self, file_path: str = None): + self.file_path = file_path + self.logger_ = self._create_logger() + + @property + def logger(self): + return self.logger_ + + def set_level(self, level=logging.info): + self.logger_.setLevel(level) + + def _create_logger(self): + """create logger for mslite bench""" + logger = logging.getLogger('MSLITE_BENCH') + log_format = '%(asctime)s - [%(name)s-%(levelname)s' \ + '(%(filename)s:%(lineno)d)]: %(message)s' + formatter = logging.Formatter(log_format, + datefmt='%m/%d %I:%M:%S %p') + if self.file_path is not None: + file_handler = logging.FileHandler(self.file_path) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + stream_handler = logging.StreamHandler() + stream_handler.setFormatter(formatter) + logger.addHandler(stream_handler) + logger.setLevel(logging.INFO) + + return logger diff --git a/mindspore/lite/tools/mslite_bench/requirements.txt b/mindspore/lite/tools/mslite_bench/requirements.txt index 4eea5cb22d1..2492772607c 100644 --- a/mindspore/lite/tools/mslite_bench/requirements.txt +++ b/mindspore/lite/tools/mslite_bench/requirements.txt @@ -1,5 +1,5 @@ -# required -numpy>=1.17.0 -mindspore-lite -onnx>=1.13.0 +# required +numpy>=1.17.0 +mindspore-lite +onnx>=1.13.0 onnxruntime>=1.12.1 \ No newline at end of file diff --git a/mindspore/lite/tools/mslite_bench/setup.py b/mindspore/lite/tools/mslite_bench/setup.py index be591bee467..5493099a299 100644 --- a/mindspore/lite/tools/mslite_bench/setup.py +++ b/mindspore/lite/tools/mslite_bench/setup.py @@ -1,33 +1,33 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -build mslite_bench whl -""" -from setuptools import setup, find_packages -with open('requirements.txt', encoding='utf-8') as f: - required = f.read().splitlines() - -setup( - name='mslite_bench', - version='0.0.1-alpha', - description='performance and accuracy tools for multiple framework model infer', - long_description='Debug and optimizer tool for mindspore lite', - url='mslite_bench url', - packages=find_packages(), - py_modules=['mslite_bench'], - keywords='mslite_bench', - install_requires=required, - python_requires='>=3.7' -) +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +build mslite_bench whl +""" +from setuptools import setup, find_packages +with open('requirements.txt', encoding='utf-8') as f: + required = f.read().splitlines() + +setup( + name='mslite_bench', + version='0.0.1-alpha', + description='performance and accuracy tools for multiple framework model infer', + long_description='Debug and optimizer tool for mindspore lite', + url='mslite_bench url', + packages=find_packages(), + py_modules=['mslite_bench'], + keywords='mslite_bench', + install_requires=required, + python_requires='>=3.7' +) diff --git a/mindspore/lite/tools/obfuscator/lib/android-aarch32/libmsdeobfuscator-lite.so b/mindspore/lite/tools/obfuscator/lib/android-aarch32/libmsdeobfuscator-lite.so deleted file mode 100644 index fa14e3f467a..00000000000 Binary files a/mindspore/lite/tools/obfuscator/lib/android-aarch32/libmsdeobfuscator-lite.so and /dev/null differ diff --git a/mindspore/lite/tools/obfuscator/lib/android-aarch64/libmsdeobfuscator-lite.so b/mindspore/lite/tools/obfuscator/lib/android-aarch64/libmsdeobfuscator-lite.so deleted file mode 100644 index 2d307cfa607..00000000000 Binary files a/mindspore/lite/tools/obfuscator/lib/android-aarch64/libmsdeobfuscator-lite.so and /dev/null differ diff --git a/mindspore/lite/tools/obfuscator/lib/linux-x64/libmsdeobfuscator-lite.so b/mindspore/lite/tools/obfuscator/lib/linux-x64/libmsdeobfuscator-lite.so deleted file mode 100644 index bf6cbe963f5..00000000000 Binary files a/mindspore/lite/tools/obfuscator/lib/linux-x64/libmsdeobfuscator-lite.so and /dev/null differ diff --git a/mindspore/python/mindspore/__init__.py b/mindspore/python/mindspore/__init__.py old mode 100755 new mode 100644 diff --git a/mindspore/python/mindspore/boost/boost.py b/mindspore/python/mindspore/boost/boost.py index e854e9e7e1f..60bac30530d 100644 --- a/mindspore/python/mindspore/boost/boost.py +++ b/mindspore/python/mindspore/boost/boost.py @@ -1,400 +1,400 @@ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""boost""" -from __future__ import absolute_import - -import threading -from mindspore.nn.optim import SGD -from mindspore.boost.less_batch_normalization import LessBN -from mindspore.boost.grad_freeze import GradientFreeze -from mindspore.boost.base import OptimizerProcess, ParameterProcess -from mindspore.boost.base import _get_local_pca_mat_path - - -__all__ = ["AutoBoost"] - -_boost_config_mode = ["auto", "manual", "enable_all", "disable_all"] -_boost_config_level = { - "O0": { - "less_bn": False, - "grad_freeze": False, - "adasum": False, - "grad_accumulation": False, - "dim_reduce": False, - 'loss_scale_group': False}, - "O1": { - "less_bn": True, - "grad_freeze": True, - "adasum": False, - "grad_accumulation": False, - "dim_reduce": False, - 'loss_scale_group': False}, - "O2": { - "less_bn": True, - "grad_freeze": True, - "adasum": True, - "grad_accumulation": False, - "dim_reduce": False, - 'loss_scale_group': False} - } - - -class AutoBoost: - r""" - Provide auto accelerating for network. - - Args: - level (str): Boost config level. Default: ``"O0"`` . - boost_config_dict (dict): User config hyperparameter dict, recommended config format: - - .. code-block:: - - { - "boost": { - "mode": "auto", - "less_bn": False, - "grad_freeze": False, - "adasum": False, - "grad_accumulation": False, - "dim_reduce": False, - "loss_scale_group": False - }, - "common": { - "gradient_split_groups": [50, 100], - "device_number": 8 - }, - "less_bn": { - "fn_flag": True, - "gc_flag": True - }, - "grad_freeze": { - "param_groups": 10, - "freeze_type": 1, - "freeze_p": 0.7, - "total_steps": 65536 - } - "dim_reduce": { - "rho": 0.55, - "gamma": 0.9, - "alpha": 0.001, - "sigma": 0.4, - "n_components": 32, - "pca_mat_path": None, - "weight_load_dir": None, - "timeout": 1800 - } - } - - Default: ``""`` . - - - boost: - - - mode (str): How to set the boost. Supports ["auto", "manual", "enable_all", "disable_all"]. - Default: ``"auto"`` . - - - auto: Depend on the argument "boost_level" in class Model. - - manual: Depend on "boost_config_dict". - - enable_all: Set all boost functions true. - - disable_all: Set all boost functions false. - - - less_bn (bool): Whether to apply less_bn function. Default: ``False`` . - - grad_freeze: (bool): Whether to apply grad_freeze function. Default: ``False`` . - - adasum (bool): Whether to apply adasum function. Default: ``False`` . - - grad_accumulation (bool): Whether to apply grad_accumulation function. Default: ``False`` . - - dim_reduce (bool): Whether to apply dim_reduce function. Default: ``False`` . - - loss_scale_group (bool): Whether to apply loss_scale_group function. Default: ``False`` . - - If set dim_reduce true, other functions will be false. - If set grad_freeze true and dim_reduce false, other functions will be false. - - - common: - - - gradient_split_groups (list): The gradient split point of this network. Default: ``[50, 100]`` . - - device_number (int): Device number. Default: ``8`` . - - - less_bn: - - - fn_flag (bool): Whether changing fc to fn. Default: ``True`` . - - gc_flag (bool): Whether to apply gc. Default: ``True`` . - - - grad_freeze: - - - param_groups (int): The number of parameter groups. Default: ``10`` . - - freeze_type (int): Gradient freeze grouping strategy, select from [0, 1]. Default: ``1`` . - - freeze_p (float): Gradient freezing probability. Default: ``0.7`` . - - total_steps (int): Total training steps. Default: ``65536`` . - - - dim_reduce: - - The leading principles of dim_reduce: - - .. math:: - - \begin{align} - grad\_k &= pca\_mat \cdot grad\\ - dk &= - bk \cdot grad\_k\\ - sk &= rho ^ m \cdot dk\\ - delta\_loss &= sigma \cdot grad\_k.T \cdot sk - \end{align} - - Here: - - - pca_mat (array): Shape :math:`(k*n)`, k is part of n_components, n is the size of weight. - - bk (array): Shape :math:`(k*k)`, is the symmetric positive definite matrix in Quasi-Newton method. - - we need to find the m satisfy: - - .. math:: - new\_loss < old\_loss + delta\_loss - - Then, get delta_grad to update the weights for model: - - .. math:: - - \begin{align} - grad\_k\_proj &= pca\_mat.T \cdot grad\_k\\ - new\_grad\_momentum &= gamma \cdot old\_grad\_momentum + grad - grad\_k\_proj\\ - delta\_grad &= alpha \cdot new\_grad\_momentum - pca\_mat.T \cdot sk - \end{align} - - - rho (float): Generally, it does not need to be modified. Default: ``0.55`` . - - gamma (float): Generally, it does not need to be modified. Default: ``0.9`` . - - alpha (float): Generally, it does not need to be modified. Default: ``0.001`` . - - sigma (float): Generally, it does not need to be modified. Default: ``0.4`` . - - n_components (int): PCA component. Default: ``32`` . - - pca_mat_path (str): The path to load pca mat. Default: ``None`` . - - weight_load_dir (str): The directory to load weight files saved as ckpt. Default: ``None`` . - - timeout (int): Waiting time to load local pca mat. Default: ``1800 (second)`` . - - User can load the config through the JSON file or use the dictionary directly. - The unconfigured parameters will adopt the default values. - - Raises: - ValueError: The boost mode not in ["auto", "manual", "enable_all", "disable_all"]. - - Supported Platforms: - ``Ascend`` - - Examples: - >>> from mindspore.boost import AutoBoost - >>> #1) when configuring the dict directly: - >>> boost_config_dict = {"boost": {"mode": "auto"}} - >>> boost = AutoBoost("O1", boost_config_dict) - >>> - >>> #2) when loading the dict from a json file: - >>> import json - >>> boost_json = "/path/boost_config.json" - >>> with open(boost_json, 'r') as fp: - ... boost_config_dict = json.load(fp) - >>> boost = AutoBoost("O1", boost_config_dict) - """ - _instance_lock = threading.Lock() - _instance = None - - # pylint: disable=unused-argument - def __new__(cls, *args, **kwargs): - if AutoBoost._instance is None: - with AutoBoost._instance_lock: - if AutoBoost._instance is None: - AutoBoost._instance = object.__new__(cls) - AutoBoost._instance.level = None - AutoBoost._instance.boost_config_dict = None - return AutoBoost._instance - - def __init__(self, level="O0", boost_config_dict=""): - if level not in _boost_config_level.keys(): - level = "O0" - if self._instance.level is None: - self.level = level - self.boost_config_dict = boost_config_dict - self._fn_flag = True - self._gc_flag = True - self._param_groups = 10 - self._freeze_type = 1 - self._freeze_p = 0.7 - self._total_steps = 65536 - self.gradient_groups = None - self.device_number = 8 - self.grad_accumulation_step = 1 - self.rho = 0.55 - self.gamma = 0.9 - self.alpha = 0.001 - self.sigma = 0.4 - self.n_components = 32 - self.pca_mat_path = None - self.weight_load_dir = None - self.local_pca_mat_path = None - self.timeout = 1800 - self.boost_config = self._get_configuration(level, self.boost_config_dict) - self._param_processer = ParameterProcess() - - def network_auto_process_train(self, network, optimizer): - r""" - Boost network train. - - Args: - network (Cell): The training network. - optimizer (Cell): Optimizer for updating the weights. - """ - if self.boost_config.get("dim_reduce"): - self.local_pca_mat_path = _get_local_pca_mat_path(self.weight_load_dir, self.pca_mat_path, - self.n_components, self.device_number, network) - optimizer = SGD(network.trainable_params(), learning_rate=1) - setattr(optimizer, "dim_reduce", True) - return network, optimizer - - if self.boost_config.get("less_bn"): - network = LessBN(network, fn_flag=self._fn_flag) - optimizer_process = OptimizerProcess(optimizer) - group_params = self._param_processer.assign_parameter_group(network.trainable_params(), - self.gradient_groups) - optimizer_process.origin_params = \ - ParameterProcess.generate_group_params(group_params, optimizer_process.origin_params) - if self._gc_flag: - optimizer_process.add_grad_centralization(network) - optimizer = optimizer_process.generate_new_optimizer() - - if self.boost_config.get("grad_freeze"): - freeze_processer = GradientFreeze(self._param_groups, self._freeze_type, - self._freeze_p, self._total_steps) - network, optimizer = freeze_processer.freeze_generate(network, optimizer) - - if self.boost_config.get("adasum"): - setattr(optimizer, "adasum", True) - return network, optimizer - - def network_auto_process_eval(self, network): - r""" - Boost network eval. - - Args: - network (Cell): The inference network. - """ - if self.boost_config.get("dim_reduce"): - return network - if self.boost_config.get("less_bn"): - network = LessBN(network) - - return network - - def _set_fn_flag(self, fn_flag): - self._fn_flag = fn_flag - - def _set_gc_flag(self, gc_flag): - self._gc_flag = gc_flag - - def _set_param_groups(self, param_groups): - self._param_groups = param_groups - - def _set_freeze_type(self, freeze_type): - self._freeze_type = freeze_type - - def _set_freeze_p(self, freeze_p): - self._freeze_p = freeze_p - - def _set_total_steps(self, total_steps): - self._total_steps = total_steps - - def _set_device_number(self, device_number): - self.device_number = device_number - - def _set_grad_accumulation_step(self, grad_accumulation_step): - self.grad_accumulation_step = grad_accumulation_step - - def _set_gradient_split_groups(self, gradient_groups): - if not isinstance(gradient_groups, (list, int)): - raise ValueError(f"gradient_groups `{gradient_groups}` is not in (list, int)") - if isinstance(gradient_groups, int): - gradient_groups = list(gradient_groups) - self.gradient_groups = gradient_groups - - def _set_rho(self, rho): - self.rho = rho - - def _set_gamma(self, gamma): - self.gamma = gamma - - def _set_alpha(self, alpha): - self.alpha = alpha - - def _set_sigma(self, sigma): - self.sigma = sigma - - def _set_n_components(self, n_components): - self.n_components = n_components - - def _set_pca_mat_path(self, pca_mat_path): - self.pca_mat_path = pca_mat_path - - def _set_weight_load_dir(self, weight_load_dir): - self.weight_load_dir = weight_load_dir - - def _set_timeout(self, timeout): - self.timeout = timeout - - def _get_configuration(self, level, boost_config_dict): - """Get configuration.""" - level_config = _boost_config_level.get(level) - if not boost_config_dict: - return level_config - - mode = "auto" - if 'boost' in boost_config_dict and 'mode' in boost_config_dict['boost']: - mode = boost_config_dict['boost']['mode'] - if mode not in _boost_config_mode: - raise ValueError("The boost mode must be in {}, but got {}".format(_boost_config_mode, mode)) - - if mode == "manual": - for key, value in boost_config_dict["boost"].items(): - if key in level_config: - level_config[key] = value - elif mode == "enable_all": - level_config = {key: True for key in level_config} - elif mode == "disable_all": - level_config = {key: False for key in level_config} - - self._do_new_config_func(boost_config_dict, level_config) - return level_config - - def _do_new_config_func(self, boost_config_dict, level_config): - valid_boost_each_mode_config = [] - for key, boost_each_mode_config in boost_config_dict.items(): - if key in level_config.keys() and level_config[key] or key == "common": - valid_boost_each_mode_config.append(boost_each_mode_config) - - for boost_each_mode_config in valid_boost_each_mode_config: - for key_s in boost_each_mode_config.keys(): - if key_s in self._boost_config_func_map: - self._boost_config_func_map[key_s](self, boost_each_mode_config[key_s]) - - _boost_config_func_map = { - "fn_flag": _set_fn_flag, - "gc_flag": _set_gc_flag, - "param_groups": _set_param_groups, - "freeze_type": _set_freeze_type, - "freeze_p": _set_freeze_p, - "total_steps": _set_total_steps, - "device_number": _set_device_number, - "gradient_split_groups": _set_gradient_split_groups, - "grad_accumulation_step": _set_grad_accumulation_step, - "rho": _set_rho, - "gamma": _set_gamma, - "alpha": _set_alpha, - "sigma": _set_sigma, - "n_components": _set_n_components, - "pca_mat_path": _set_pca_mat_path, - "weight_load_dir": _set_weight_load_dir, - "timeout": _set_timeout - } +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""boost""" +from __future__ import absolute_import + +import threading +from mindspore.nn.optim import SGD +from mindspore.boost.less_batch_normalization import LessBN +from mindspore.boost.grad_freeze import GradientFreeze +from mindspore.boost.base import OptimizerProcess, ParameterProcess +from mindspore.boost.base import _get_local_pca_mat_path + + +__all__ = ["AutoBoost"] + +_boost_config_mode = ["auto", "manual", "enable_all", "disable_all"] +_boost_config_level = { + "O0": { + "less_bn": False, + "grad_freeze": False, + "adasum": False, + "grad_accumulation": False, + "dim_reduce": False, + 'loss_scale_group': False}, + "O1": { + "less_bn": True, + "grad_freeze": True, + "adasum": False, + "grad_accumulation": False, + "dim_reduce": False, + 'loss_scale_group': False}, + "O2": { + "less_bn": True, + "grad_freeze": True, + "adasum": True, + "grad_accumulation": False, + "dim_reduce": False, + 'loss_scale_group': False} + } + + +class AutoBoost: + r""" + Provide auto accelerating for network. + + Args: + level (str): Boost config level. Default: ``"O0"`` . + boost_config_dict (dict): User config hyperparameter dict, recommended config format: + + .. code-block:: + + { + "boost": { + "mode": "auto", + "less_bn": False, + "grad_freeze": False, + "adasum": False, + "grad_accumulation": False, + "dim_reduce": False, + "loss_scale_group": False + }, + "common": { + "gradient_split_groups": [50, 100], + "device_number": 8 + }, + "less_bn": { + "fn_flag": True, + "gc_flag": True + }, + "grad_freeze": { + "param_groups": 10, + "freeze_type": 1, + "freeze_p": 0.7, + "total_steps": 65536 + } + "dim_reduce": { + "rho": 0.55, + "gamma": 0.9, + "alpha": 0.001, + "sigma": 0.4, + "n_components": 32, + "pca_mat_path": None, + "weight_load_dir": None, + "timeout": 1800 + } + } + + Default: ``""`` . + + - boost: + + - mode (str): How to set the boost. Supports ["auto", "manual", "enable_all", "disable_all"]. + Default: ``"auto"`` . + + - auto: Depend on the argument "boost_level" in class Model. + - manual: Depend on "boost_config_dict". + - enable_all: Set all boost functions true. + - disable_all: Set all boost functions false. + + - less_bn (bool): Whether to apply less_bn function. Default: ``False`` . + - grad_freeze: (bool): Whether to apply grad_freeze function. Default: ``False`` . + - adasum (bool): Whether to apply adasum function. Default: ``False`` . + - grad_accumulation (bool): Whether to apply grad_accumulation function. Default: ``False`` . + - dim_reduce (bool): Whether to apply dim_reduce function. Default: ``False`` . + - loss_scale_group (bool): Whether to apply loss_scale_group function. Default: ``False`` . + + If set dim_reduce true, other functions will be false. + If set grad_freeze true and dim_reduce false, other functions will be false. + + - common: + + - gradient_split_groups (list): The gradient split point of this network. Default: ``[50, 100]`` . + - device_number (int): Device number. Default: ``8`` . + + - less_bn: + + - fn_flag (bool): Whether changing fc to fn. Default: ``True`` . + - gc_flag (bool): Whether to apply gc. Default: ``True`` . + + - grad_freeze: + + - param_groups (int): The number of parameter groups. Default: ``10`` . + - freeze_type (int): Gradient freeze grouping strategy, select from [0, 1]. Default: ``1`` . + - freeze_p (float): Gradient freezing probability. Default: ``0.7`` . + - total_steps (int): Total training steps. Default: ``65536`` . + + - dim_reduce: + + The leading principles of dim_reduce: + + .. math:: + + \begin{align} + grad\_k &= pca\_mat \cdot grad\\ + dk &= - bk \cdot grad\_k\\ + sk &= rho ^ m \cdot dk\\ + delta\_loss &= sigma \cdot grad\_k.T \cdot sk + \end{align} + + Here: + + - pca_mat (array): Shape :math:`(k*n)`, k is part of n_components, n is the size of weight. + - bk (array): Shape :math:`(k*k)`, is the symmetric positive definite matrix in Quasi-Newton method. + + we need to find the m satisfy: + + .. math:: + new\_loss < old\_loss + delta\_loss + + Then, get delta_grad to update the weights for model: + + .. math:: + + \begin{align} + grad\_k\_proj &= pca\_mat.T \cdot grad\_k\\ + new\_grad\_momentum &= gamma \cdot old\_grad\_momentum + grad - grad\_k\_proj\\ + delta\_grad &= alpha \cdot new\_grad\_momentum - pca\_mat.T \cdot sk + \end{align} + + - rho (float): Generally, it does not need to be modified. Default: ``0.55`` . + - gamma (float): Generally, it does not need to be modified. Default: ``0.9`` . + - alpha (float): Generally, it does not need to be modified. Default: ``0.001`` . + - sigma (float): Generally, it does not need to be modified. Default: ``0.4`` . + - n_components (int): PCA component. Default: ``32`` . + - pca_mat_path (str): The path to load pca mat. Default: ``None`` . + - weight_load_dir (str): The directory to load weight files saved as ckpt. Default: ``None`` . + - timeout (int): Waiting time to load local pca mat. Default: ``1800 (second)`` . + + User can load the config through the JSON file or use the dictionary directly. + The unconfigured parameters will adopt the default values. + + Raises: + ValueError: The boost mode not in ["auto", "manual", "enable_all", "disable_all"]. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindspore.boost import AutoBoost + >>> #1) when configuring the dict directly: + >>> boost_config_dict = {"boost": {"mode": "auto"}} + >>> boost = AutoBoost("O1", boost_config_dict) + >>> + >>> #2) when loading the dict from a json file: + >>> import json + >>> boost_json = "/path/boost_config.json" + >>> with open(boost_json, 'r') as fp: + ... boost_config_dict = json.load(fp) + >>> boost = AutoBoost("O1", boost_config_dict) + """ + _instance_lock = threading.Lock() + _instance = None + + # pylint: disable=unused-argument + def __new__(cls, *args, **kwargs): + if AutoBoost._instance is None: + with AutoBoost._instance_lock: + if AutoBoost._instance is None: + AutoBoost._instance = object.__new__(cls) + AutoBoost._instance.level = None + AutoBoost._instance.boost_config_dict = None + return AutoBoost._instance + + def __init__(self, level="O0", boost_config_dict=""): + if level not in _boost_config_level.keys(): + level = "O0" + if self._instance.level is None: + self.level = level + self.boost_config_dict = boost_config_dict + self._fn_flag = True + self._gc_flag = True + self._param_groups = 10 + self._freeze_type = 1 + self._freeze_p = 0.7 + self._total_steps = 65536 + self.gradient_groups = None + self.device_number = 8 + self.grad_accumulation_step = 1 + self.rho = 0.55 + self.gamma = 0.9 + self.alpha = 0.001 + self.sigma = 0.4 + self.n_components = 32 + self.pca_mat_path = None + self.weight_load_dir = None + self.local_pca_mat_path = None + self.timeout = 1800 + self.boost_config = self._get_configuration(level, self.boost_config_dict) + self._param_processer = ParameterProcess() + + def network_auto_process_train(self, network, optimizer): + r""" + Boost network train. + + Args: + network (Cell): The training network. + optimizer (Cell): Optimizer for updating the weights. + """ + if self.boost_config.get("dim_reduce"): + self.local_pca_mat_path = _get_local_pca_mat_path(self.weight_load_dir, self.pca_mat_path, + self.n_components, self.device_number, network) + optimizer = SGD(network.trainable_params(), learning_rate=1) + setattr(optimizer, "dim_reduce", True) + return network, optimizer + + if self.boost_config.get("less_bn"): + network = LessBN(network, fn_flag=self._fn_flag) + optimizer_process = OptimizerProcess(optimizer) + group_params = self._param_processer.assign_parameter_group(network.trainable_params(), + self.gradient_groups) + optimizer_process.origin_params = \ + ParameterProcess.generate_group_params(group_params, optimizer_process.origin_params) + if self._gc_flag: + optimizer_process.add_grad_centralization(network) + optimizer = optimizer_process.generate_new_optimizer() + + if self.boost_config.get("grad_freeze"): + freeze_processer = GradientFreeze(self._param_groups, self._freeze_type, + self._freeze_p, self._total_steps) + network, optimizer = freeze_processer.freeze_generate(network, optimizer) + + if self.boost_config.get("adasum"): + setattr(optimizer, "adasum", True) + return network, optimizer + + def network_auto_process_eval(self, network): + r""" + Boost network eval. + + Args: + network (Cell): The inference network. + """ + if self.boost_config.get("dim_reduce"): + return network + if self.boost_config.get("less_bn"): + network = LessBN(network) + + return network + + def _set_fn_flag(self, fn_flag): + self._fn_flag = fn_flag + + def _set_gc_flag(self, gc_flag): + self._gc_flag = gc_flag + + def _set_param_groups(self, param_groups): + self._param_groups = param_groups + + def _set_freeze_type(self, freeze_type): + self._freeze_type = freeze_type + + def _set_freeze_p(self, freeze_p): + self._freeze_p = freeze_p + + def _set_total_steps(self, total_steps): + self._total_steps = total_steps + + def _set_device_number(self, device_number): + self.device_number = device_number + + def _set_grad_accumulation_step(self, grad_accumulation_step): + self.grad_accumulation_step = grad_accumulation_step + + def _set_gradient_split_groups(self, gradient_groups): + if not isinstance(gradient_groups, (list, int)): + raise ValueError(f"gradient_groups `{gradient_groups}` is not in (list, int)") + if isinstance(gradient_groups, int): + gradient_groups = list(gradient_groups) + self.gradient_groups = gradient_groups + + def _set_rho(self, rho): + self.rho = rho + + def _set_gamma(self, gamma): + self.gamma = gamma + + def _set_alpha(self, alpha): + self.alpha = alpha + + def _set_sigma(self, sigma): + self.sigma = sigma + + def _set_n_components(self, n_components): + self.n_components = n_components + + def _set_pca_mat_path(self, pca_mat_path): + self.pca_mat_path = pca_mat_path + + def _set_weight_load_dir(self, weight_load_dir): + self.weight_load_dir = weight_load_dir + + def _set_timeout(self, timeout): + self.timeout = timeout + + def _get_configuration(self, level, boost_config_dict): + """Get configuration.""" + level_config = _boost_config_level.get(level) + if not boost_config_dict: + return level_config + + mode = "auto" + if 'boost' in boost_config_dict and 'mode' in boost_config_dict['boost']: + mode = boost_config_dict['boost']['mode'] + if mode not in _boost_config_mode: + raise ValueError("The boost mode must be in {}, but got {}".format(_boost_config_mode, mode)) + + if mode == "manual": + for key, value in boost_config_dict["boost"].items(): + if key in level_config: + level_config[key] = value + elif mode == "enable_all": + level_config = {key: True for key in level_config} + elif mode == "disable_all": + level_config = {key: False for key in level_config} + + self._do_new_config_func(boost_config_dict, level_config) + return level_config + + def _do_new_config_func(self, boost_config_dict, level_config): + valid_boost_each_mode_config = [] + for key, boost_each_mode_config in boost_config_dict.items(): + if key in level_config.keys() and level_config[key] or key == "common": + valid_boost_each_mode_config.append(boost_each_mode_config) + + for boost_each_mode_config in valid_boost_each_mode_config: + for key_s in boost_each_mode_config.keys(): + if key_s in self._boost_config_func_map: + self._boost_config_func_map[key_s](self, boost_each_mode_config[key_s]) + + _boost_config_func_map = { + "fn_flag": _set_fn_flag, + "gc_flag": _set_gc_flag, + "param_groups": _set_param_groups, + "freeze_type": _set_freeze_type, + "freeze_p": _set_freeze_p, + "total_steps": _set_total_steps, + "device_number": _set_device_number, + "gradient_split_groups": _set_gradient_split_groups, + "grad_accumulation_step": _set_grad_accumulation_step, + "rho": _set_rho, + "gamma": _set_gamma, + "alpha": _set_alpha, + "sigma": _set_sigma, + "n_components": _set_n_components, + "pca_mat_path": _set_pca_mat_path, + "weight_load_dir": _set_weight_load_dir, + "timeout": _set_timeout + } diff --git a/mindspore/python/mindspore/communication/_hccl_management.py b/mindspore/python/mindspore/communication/_hccl_management.py index 0070fcba371..275308c32ac 100644 --- a/mindspore/python/mindspore/communication/_hccl_management.py +++ b/mindspore/python/mindspore/communication/_hccl_management.py @@ -1,297 +1,297 @@ -# Copyright 2020 Huawei Technologies Co., Ltd - -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""HCCL management API""" -from __future__ import absolute_import -from __future__ import division - -import ctypes -import os - -from mindspore import context -from mindspore._c_expression import get_hccl_rank_id, get_hccl_rank_size - -MAX_GROUP_NAME_LEN = 127 -MAX_RANK_NUM = 4096 -HCCL_LIB = 'libhccl_plugin.so' -HCCL_LIB_CTYPES = "" - - -def check_group(group): - """ - A function that check if a collection communication group is legal. - - Returns: - None - """ - if isinstance(group, (str)): - group_len = len(group) - if group_len > MAX_GROUP_NAME_LEN or group_len == 0: - raise ValueError("The length of communication group name must be in range [1, 127), " - "but got the value : {} ".format(group_len)) - else: - raise TypeError("The type of communication group name must be type of string, " - "but got 'group' type : {}.".format(type(group))) - - -def check_rank_num(rank_num): - """ - A function that check if a collection communication rank number is legal.If not raise error. - - Returns: - None - """ - if isinstance(rank_num, (int)): - if rank_num > MAX_RANK_NUM or rank_num <= 0: - raise ValueError("For 'create_group', the size of argument 'rand_ids' should be greater than 0 and" - "less than {}, but got the size of 'rank_ids' : {}.".format(MAX_RANK_NUM, rank_num)) - else: - raise TypeError("The argument 'rank_num' must be type of int, " - "but got 'rank_num' type : {}.".format(type(rank_num))) - - -def check_rank_id(rank_id): - """ - A function that check if a collection communication rank id is legal.If not raise error. - - Returns: - None - """ - if isinstance(rank_id, (int)): - if rank_id >= MAX_RANK_NUM or rank_id < 0: - raise ValueError("The rand id in the communication group must be greater or equal 0 and " - "less than {}, but got type value : {}.".format(MAX_RANK_NUM, rank_id)) - else: - raise TypeError("The rand id in the communication group must be must be type of int, " - "but got type value : {}.".format(type(rank_id))) - - -def load_lib(): - """load hccl lib""" - try: - base_dir = os.path.dirname(os.path.realpath(__file__)) - lib_path = os.path.join(base_dir, "../lib/plugin/ascend", HCCL_LIB) - hccl_lib = ctypes.CDLL(lib_path) - except Exception: - raise RuntimeError('Get hccl lib error.') - - global HCCL_LIB_CTYPES - HCCL_LIB_CTYPES = hccl_lib - - -def c_str(string): - """Convert a python string to C string.""" - if not isinstance(string, str): - string = string.decode('ascii') - return ctypes.c_char_p(string.encode('utf-8')) - - -def c_array(ctype, values): - """Create ctypes array from a python array.""" - return (ctype * len(values))(*values) - - -def create_group(group, rank_num, rank_ids): - """ - Create group. - - A function that creates a collection communication group which includes 'rank_num' - device and 'rank_ids' is the list of these ranks of devices. - - Note: - The world group can not be created. - - Returns: - None - """ - check_group(group) - check_rank_num(rank_num) - if isinstance(rank_ids, (list)): - if rank_num != len(rank_ids): - raise ValueError("The argument 'rank_num' number should be equal to the length " - "of rank_ids, but got 'rank_num' value : {} and 'rank_ids' value : {}." - .format(rank_num, rank_ids)) - for rank_id in rank_ids: - if not isinstance(rank_id, (int)) or rank_id < 0: - raise ValueError("The elements of argument 'rank_ids' must be " - "unsigned integer, but got the type : {}".format(type(rank_id))) - c_array_rank_ids = c_array(ctypes.c_uint, rank_ids) - c_rank_num = ctypes.c_uint(rank_num) - c_group = c_str(group) - ret = HCCL_LIB_CTYPES.HcomCreateGroup(c_group, c_rank_num, c_array_rank_ids) - if ret != 0: - raise RuntimeError('Create group error, the error code is {}.'.format(ret)) - else: - raise TypeError("For 'create_group', the argument 'rank_ids' must be type of list, " - "but got 'rank_ids' type : {}.".format(type(rank_ids))) - - -def destroy_group(group): - """ - A function that destroy the group which created by user. - - Note: - The world group can not be destroy. - - Returns: - None - """ - check_group(group) - c_group = c_str(group) - ret = HCCL_LIB_CTYPES.HcomDestroyGroup(c_group) - if ret != 0: - raise RuntimeError('Destroy group error.') - - -def get_rank_size(group="hccl_world_group"): - """ - A function that returns the number of ranks within the given collection communication group. - - Note: - The default group is hccl_world_group. - - Returns: - An integer scalar with the num of ranks. - """ - - if context.get_context("mode") == context.PYNATIVE_MODE: - return get_hccl_rank_size() - - check_group(group) - c_group = c_str(group) - c_rank_size = ctypes.c_uint() - ret = HCCL_LIB_CTYPES.HcomGetRankSize(c_group, ctypes.byref(c_rank_size)) - if ret != 0: - raise RuntimeError('Get rank size error.') - - return c_rank_size.value - - -def get_rank_id(group="hccl_world_group"): - """ - A function that returns the rank id of the calling process, within the given collection communication group. - - Returns: - An integer scalar with the rank id of the calling process. - """ - - if context.get_context("mode") == context.PYNATIVE_MODE: - return get_hccl_rank_id() - - check_group(group) - c_group = c_str(group) - c_rank_id = ctypes.c_uint() - ret = HCCL_LIB_CTYPES.HcomGetRankId(c_group, ctypes.byref(c_rank_id)) - if ret != 0: - raise RuntimeError('Get rank id error.') - - return c_rank_id.value - - - -def get_local_rank_size(group="hccl_world_group"): - """ - A function that returns the number of local ranks within the given collection communication group. - - Note: - The default group is hccl_world_group. - - Returns: - An integer scalar with the num of local ranks. - """ - if context.get_context("mode") is context.PYNATIVE_MODE: - raise RuntimeError("The function 'get_local_rank_size' is not supported in PYNATIVE_MODE, " - "'get_local_rank_size' only support GRAPH_MODE") - check_group(group) - c_group = c_str(group) - c_local_rank_size = ctypes.c_uint() - ret = HCCL_LIB_CTYPES.HcomGetLocalRankSize(c_group, ctypes.byref(c_local_rank_size)) - if ret != 0: - raise RuntimeError('Get local rank size error.') - - return c_local_rank_size.value - - -def get_local_rank_id(group="hccl_world_group"): - """ - Get local rank id. - - A function that returns the local rank id of the calling process, within the given collection communication group. - - Returns: - An integer scalar with the local rank id of the calling process. - """ - - if context.get_context("mode") is context.PYNATIVE_MODE: - raise RuntimeError("The function 'get_local_rank_id' is not supported in PYNATIVE_MODE, " - "'get_local_rank_id' only support GRAPH_MODE") - check_group(group) - c_group = c_str(group) - c_local_rank_id = ctypes.c_uint() - ret = HCCL_LIB_CTYPES.HcomGetLocalRankId(c_group, ctypes.byref(c_local_rank_id)) - if ret != 0: - raise RuntimeError('Get local rank id error.') - - return c_local_rank_id.value - - -def get_world_rank_from_group_rank(group, group_rank_id): - """ - Get world rank from group rank. - - A function that returns the rank id in the world group corresponding to the - rank which id is 'group_rank_id' in the user group. - - Returns: - An integer scalar with the rank id in the world group. - """ - if context.get_context("mode") is context.PYNATIVE_MODE: - raise RuntimeError("The function 'get_world_rank_from_group_rank' is not supported in PYNATIVE_MODE, " - "'get_world_rank_from_group_rank' only support GRAPH_MODE") - check_group(group) - check_rank_id(group_rank_id) - c_group = c_str(group) - c_group_rank_id = ctypes.c_uint(group_rank_id) - c_world_rank_id = ctypes.c_uint() - ret = HCCL_LIB_CTYPES.HcomGetWorldRankFromGroupRank(c_group, c_group_rank_id, ctypes.byref(c_world_rank_id)) - if ret != 0: - raise RuntimeError('Get world rank from group rank error.') - - return c_world_rank_id.value - - -def get_group_rank_from_world_rank(world_rank_id, group): - """ - Get group rank from world rank. - - A function that returns the rank id in the user group corresponding to the - rank which id is 'world_rank_id' in the world group. - - Returns: - An integer scalar with the rank id in the user group. - """ - if context.get_context("mode") is context.PYNATIVE_MODE: - raise RuntimeError("The function 'get_group_rank_from_world_rank' is not supported in PYNATIVE_MODE, " - "'get_group_rank_from_world_rank' only support GRAPH_MODE") - check_group(group) - check_rank_id(world_rank_id) - c_group = c_str(group) - c_world_rank_id = ctypes.c_uint(world_rank_id) - c_group_rank_id = ctypes.c_uint() - ret = HCCL_LIB_CTYPES.HcomGetGroupRankFromWorldRank(c_world_rank_id, c_group, ctypes.byref(c_group_rank_id)) - if ret != 0: - raise RuntimeError('Get group rank from world rank error.') - - return c_group_rank_id.value +# Copyright 2020 Huawei Technologies Co., Ltd + +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""HCCL management API""" +from __future__ import absolute_import +from __future__ import division + +import ctypes +import os + +from mindspore import context +from mindspore._c_expression import get_hccl_rank_id, get_hccl_rank_size + +MAX_GROUP_NAME_LEN = 127 +MAX_RANK_NUM = 4096 +HCCL_LIB = 'libhccl_plugin.so' +HCCL_LIB_CTYPES = "" + + +def check_group(group): + """ + A function that check if a collection communication group is legal. + + Returns: + None + """ + if isinstance(group, (str)): + group_len = len(group) + if group_len > MAX_GROUP_NAME_LEN or group_len == 0: + raise ValueError("The length of communication group name must be in range [1, 127), " + "but got the value : {} ".format(group_len)) + else: + raise TypeError("The type of communication group name must be type of string, " + "but got 'group' type : {}.".format(type(group))) + + +def check_rank_num(rank_num): + """ + A function that check if a collection communication rank number is legal.If not raise error. + + Returns: + None + """ + if isinstance(rank_num, (int)): + if rank_num > MAX_RANK_NUM or rank_num <= 0: + raise ValueError("For 'create_group', the size of argument 'rand_ids' should be greater than 0 and" + "less than {}, but got the size of 'rank_ids' : {}.".format(MAX_RANK_NUM, rank_num)) + else: + raise TypeError("The argument 'rank_num' must be type of int, " + "but got 'rank_num' type : {}.".format(type(rank_num))) + + +def check_rank_id(rank_id): + """ + A function that check if a collection communication rank id is legal.If not raise error. + + Returns: + None + """ + if isinstance(rank_id, (int)): + if rank_id >= MAX_RANK_NUM or rank_id < 0: + raise ValueError("The rand id in the communication group must be greater or equal 0 and " + "less than {}, but got type value : {}.".format(MAX_RANK_NUM, rank_id)) + else: + raise TypeError("The rand id in the communication group must be must be type of int, " + "but got type value : {}.".format(type(rank_id))) + + +def load_lib(): + """load hccl lib""" + try: + base_dir = os.path.dirname(os.path.realpath(__file__)) + lib_path = os.path.join(base_dir, "../lib/plugin/ascend", HCCL_LIB) + hccl_lib = ctypes.CDLL(lib_path) + except Exception: + raise RuntimeError('Get hccl lib error.') + + global HCCL_LIB_CTYPES + HCCL_LIB_CTYPES = hccl_lib + + +def c_str(string): + """Convert a python string to C string.""" + if not isinstance(string, str): + string = string.decode('ascii') + return ctypes.c_char_p(string.encode('utf-8')) + + +def c_array(ctype, values): + """Create ctypes array from a python array.""" + return (ctype * len(values))(*values) + + +def create_group(group, rank_num, rank_ids): + """ + Create group. + + A function that creates a collection communication group which includes 'rank_num' + device and 'rank_ids' is the list of these ranks of devices. + + Note: + The world group can not be created. + + Returns: + None + """ + check_group(group) + check_rank_num(rank_num) + if isinstance(rank_ids, (list)): + if rank_num != len(rank_ids): + raise ValueError("The argument 'rank_num' number should be equal to the length " + "of rank_ids, but got 'rank_num' value : {} and 'rank_ids' value : {}." + .format(rank_num, rank_ids)) + for rank_id in rank_ids: + if not isinstance(rank_id, (int)) or rank_id < 0: + raise ValueError("The elements of argument 'rank_ids' must be " + "unsigned integer, but got the type : {}".format(type(rank_id))) + c_array_rank_ids = c_array(ctypes.c_uint, rank_ids) + c_rank_num = ctypes.c_uint(rank_num) + c_group = c_str(group) + ret = HCCL_LIB_CTYPES.HcomCreateGroup(c_group, c_rank_num, c_array_rank_ids) + if ret != 0: + raise RuntimeError('Create group error, the error code is {}.'.format(ret)) + else: + raise TypeError("For 'create_group', the argument 'rank_ids' must be type of list, " + "but got 'rank_ids' type : {}.".format(type(rank_ids))) + + +def destroy_group(group): + """ + A function that destroy the group which created by user. + + Note: + The world group can not be destroy. + + Returns: + None + """ + check_group(group) + c_group = c_str(group) + ret = HCCL_LIB_CTYPES.HcomDestroyGroup(c_group) + if ret != 0: + raise RuntimeError('Destroy group error.') + + +def get_rank_size(group="hccl_world_group"): + """ + A function that returns the number of ranks within the given collection communication group. + + Note: + The default group is hccl_world_group. + + Returns: + An integer scalar with the num of ranks. + """ + + if context.get_context("mode") == context.PYNATIVE_MODE: + return get_hccl_rank_size() + + check_group(group) + c_group = c_str(group) + c_rank_size = ctypes.c_uint() + ret = HCCL_LIB_CTYPES.HcomGetRankSize(c_group, ctypes.byref(c_rank_size)) + if ret != 0: + raise RuntimeError('Get rank size error.') + + return c_rank_size.value + + +def get_rank_id(group="hccl_world_group"): + """ + A function that returns the rank id of the calling process, within the given collection communication group. + + Returns: + An integer scalar with the rank id of the calling process. + """ + + if context.get_context("mode") == context.PYNATIVE_MODE: + return get_hccl_rank_id() + + check_group(group) + c_group = c_str(group) + c_rank_id = ctypes.c_uint() + ret = HCCL_LIB_CTYPES.HcomGetRankId(c_group, ctypes.byref(c_rank_id)) + if ret != 0: + raise RuntimeError('Get rank id error.') + + return c_rank_id.value + + + +def get_local_rank_size(group="hccl_world_group"): + """ + A function that returns the number of local ranks within the given collection communication group. + + Note: + The default group is hccl_world_group. + + Returns: + An integer scalar with the num of local ranks. + """ + if context.get_context("mode") is context.PYNATIVE_MODE: + raise RuntimeError("The function 'get_local_rank_size' is not supported in PYNATIVE_MODE, " + "'get_local_rank_size' only support GRAPH_MODE") + check_group(group) + c_group = c_str(group) + c_local_rank_size = ctypes.c_uint() + ret = HCCL_LIB_CTYPES.HcomGetLocalRankSize(c_group, ctypes.byref(c_local_rank_size)) + if ret != 0: + raise RuntimeError('Get local rank size error.') + + return c_local_rank_size.value + + +def get_local_rank_id(group="hccl_world_group"): + """ + Get local rank id. + + A function that returns the local rank id of the calling process, within the given collection communication group. + + Returns: + An integer scalar with the local rank id of the calling process. + """ + + if context.get_context("mode") is context.PYNATIVE_MODE: + raise RuntimeError("The function 'get_local_rank_id' is not supported in PYNATIVE_MODE, " + "'get_local_rank_id' only support GRAPH_MODE") + check_group(group) + c_group = c_str(group) + c_local_rank_id = ctypes.c_uint() + ret = HCCL_LIB_CTYPES.HcomGetLocalRankId(c_group, ctypes.byref(c_local_rank_id)) + if ret != 0: + raise RuntimeError('Get local rank id error.') + + return c_local_rank_id.value + + +def get_world_rank_from_group_rank(group, group_rank_id): + """ + Get world rank from group rank. + + A function that returns the rank id in the world group corresponding to the + rank which id is 'group_rank_id' in the user group. + + Returns: + An integer scalar with the rank id in the world group. + """ + if context.get_context("mode") is context.PYNATIVE_MODE: + raise RuntimeError("The function 'get_world_rank_from_group_rank' is not supported in PYNATIVE_MODE, " + "'get_world_rank_from_group_rank' only support GRAPH_MODE") + check_group(group) + check_rank_id(group_rank_id) + c_group = c_str(group) + c_group_rank_id = ctypes.c_uint(group_rank_id) + c_world_rank_id = ctypes.c_uint() + ret = HCCL_LIB_CTYPES.HcomGetWorldRankFromGroupRank(c_group, c_group_rank_id, ctypes.byref(c_world_rank_id)) + if ret != 0: + raise RuntimeError('Get world rank from group rank error.') + + return c_world_rank_id.value + + +def get_group_rank_from_world_rank(world_rank_id, group): + """ + Get group rank from world rank. + + A function that returns the rank id in the user group corresponding to the + rank which id is 'world_rank_id' in the world group. + + Returns: + An integer scalar with the rank id in the user group. + """ + if context.get_context("mode") is context.PYNATIVE_MODE: + raise RuntimeError("The function 'get_group_rank_from_world_rank' is not supported in PYNATIVE_MODE, " + "'get_group_rank_from_world_rank' only support GRAPH_MODE") + check_group(group) + check_rank_id(world_rank_id) + c_group = c_str(group) + c_world_rank_id = ctypes.c_uint(world_rank_id) + c_group_rank_id = ctypes.c_uint() + ret = HCCL_LIB_CTYPES.HcomGetGroupRankFromWorldRank(c_world_rank_id, c_group, ctypes.byref(c_group_rank_id)) + if ret != 0: + raise RuntimeError('Get group rank from world rank error.') + + return c_group_rank_id.value diff --git a/mindspore/python/mindspore/communication/management.py b/mindspore/python/mindspore/communication/management.py old mode 100755 new mode 100644 diff --git a/mindspore/python/mindspore/dataset/vision/__init__.py b/mindspore/python/mindspore/dataset/vision/__init__.py old mode 100755 new mode 100644 diff --git a/mindspore/python/mindspore/dataset/vision/utils.py b/mindspore/python/mindspore/dataset/vision/utils.py old mode 100755 new mode 100644 diff --git a/mindspore/python/mindspore/nn/cell.py b/mindspore/python/mindspore/nn/cell.py old mode 100755 new mode 100644 diff --git a/mindspore/python/mindspore/nn/layer/embedding.py b/mindspore/python/mindspore/nn/layer/embedding.py old mode 100755 new mode 100644 diff --git a/mindspore/python/mindspore/nn/optim/adam.py b/mindspore/python/mindspore/nn/optim/adam.py old mode 100755 new mode 100644 diff --git a/mindspore/python/mindspore/nn/optim/asgd.py b/mindspore/python/mindspore/nn/optim/asgd.py old mode 100755 new mode 100644 diff --git a/mindspore/python/mindspore/nn/optim/lamb.py b/mindspore/python/mindspore/nn/optim/lamb.py old mode 100755 new mode 100644 diff --git a/mindspore/python/mindspore/nn/optim/lars.py b/mindspore/python/mindspore/nn/optim/lars.py old mode 100755 new mode 100644 diff --git a/mindspore/python/mindspore/nn/optim/momentum.py b/mindspore/python/mindspore/nn/optim/momentum.py old mode 100755 new mode 100644 diff --git a/mindspore/python/mindspore/nn/optim/proximal_ada_grad.py b/mindspore/python/mindspore/nn/optim/proximal_ada_grad.py index 20a7f1e078d..b8a5edc2d35 100644 --- a/mindspore/python/mindspore/nn/optim/proximal_ada_grad.py +++ b/mindspore/python/mindspore/nn/optim/proximal_ada_grad.py @@ -1,242 +1,242 @@ -# Copyright 2020-2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""PROXIMAL_ADA_GRAD""" -from __future__ import absolute_import - -from mindspore.ops import functional as F, composite as C, operations as P -from mindspore.common import Tensor -import mindspore.common.dtype as mstype -from mindspore.common.api import jit -from mindspore import _checkparam as validator -from mindspore.nn.optim.optimizer import Optimizer -from mindspore.nn.optim.optimizer import opt_init_args_register - -_proximal_ada_grad_opt = C.MultitypeFuncGraph("proximal_ada_grad_opt") - -@_proximal_ada_grad_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", - "Tensor") -def _tensor_run_opt_with_sparse(opt, sparse_opt, l1, l2, learning_rate, gradient, weight, accum): - """Apply sparse proximal_ada_grad optimizer to the weight parameter.""" - success = True - success = F.depend(success, sparse_opt(weight, accum, learning_rate, l1, l2, gradient.values, gradient.indices)) - return success - - -@_proximal_ada_grad_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor") -def _tensor_run_opt(opt, sparse_opt, l1, l2, learning_rate, gradient, weight, accum): - """Apply proximal_ada_grad optimizer to the weight parameter.""" - success = True - success = F.depend(success, opt(weight, accum, learning_rate, l1, l2, gradient)) - return success - - -def _check_param_value(accum, l1, l2, use_locking, prim_name=None): - """Check inputs param.""" - validator.check_value_type("accum", accum, [float], prim_name) - validator.check_value_type("l1", l1, [float], prim_name) - validator.check_value_type("l2", l2, [float], prim_name) - validator.check_value_type("use_locking", use_locking, [bool], prim_name) - validator.check_non_negative_float(accum, "accum", prim_name) - validator.check_non_negative_float(l1, "l1", prim_name) - validator.check_non_negative_float(l2, "l2", prim_name) - - -class ProximalAdagrad(Optimizer): - r""" - Implements the ProximalAdagrad algorithm that is an online Learning and Stochastic Optimization. - Refer to paper `Efficient Learning using Forward-Backward Splitting - `_. - - .. math:: - accum_{t+1} = accum_{t} + g * g - - .. math:: - \text{prox_v} = w_{t} - \gamma * g * \frac{1}{\sqrt{accum_{t+1}}} - - .. math:: - w_{t+1} = \frac{sign(\text{prox_v})}{1 + \gamma * l2} * \max(\left| \text{prox_v} \right| - \gamma * l1, 0) - - Here : where :math:`g` , :math:`\gamma`, :math:`w` , :math:`accum` and :math:`t` denote the `grads`, - `learning_rate`, `params`, accumulation and current step respectively. - - Note: - The sparse strategy is applied while the SparseGatherV2 operator is used for forward network. If the sparse - strategy wants to be executed on the host, set the target to the CPU. - The sparse feature is under continuous development. - - If parameters are not grouped, the `weight_decay` in optimizer will be applied on the network parameters without - 'beta' or 'gamma' in their names. Users can group parameters to change the strategy of decaying weight. When - parameters are grouped, each group can set `weight_decay`. If not, the `weight_decay` in optimizer will be - applied. - - Args: - params (Union[list[Parameter], list[dict]]): Must be list of `Parameter` or list of `dict`. When the - `params` is a list of `dict`, the string "params", "lr", "weight_decay", "grad_centralization" and - "order_params" are the keys can be parsed. - - - params: Required. Parameters in current group. The value must be a list of `Parameter`. - - - lr: Optional. If "lr" in the keys, the value of corresponding learning rate will be used. - If not, the `learning_rate` in optimizer will be used. Fixed and dynamic learning rate are supported. - - - weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay - will be used. If not, the `weight_decay` in the optimizer will be used. It should be noted that weight - decay can be a constant value or a Cell. It is a Cell only when dynamic weight decay is applied. Dynamic - weight decay is similar to dynamic learning rate, users need to customize a weight decay schedule only - with global step as input, and during training, the optimizer calls the instance of WeightDecaySchedule - to get the weight decay value of current step. - - - grad_centralization: Optional. Must be Boolean. If "grad_centralization" is in the keys, the set value - will be used. If not, the `grad_centralization` is False by default. This configuration only works on the - convolution layer. - - - order_params: Optional. When parameters are grouped, this usually is used to maintain the order of - parameters that appeared in the network to improve performance. The value should be parameters whose - order will be followed in optimizer. - If `order_params` in the keys, other keys will be ignored and the element of 'order_params' must be in - one group of `params`. - - accum (float): The starting value for accumulators `accum`, must be zero or positive values. Default: ``0.1`` . - learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]): Default: ``0.001`` . - - - float: The fixed learning rate value. Must be equal to or greater than 0. - - - int: The fixed learning rate value. Must be equal to or greater than 0. It will be converted to float. - - - Tensor: Its value should be a scalar or a 1-D vector. For scalar, fixed learning rate will be applied. - For vector, learning rate is dynamic, then the i-th step will take the i-th value as the learning rate. - - - Iterable: Learning rate is dynamic. The i-th step will take the i-th value as the learning rate. - - - LearningRateSchedule: Learning rate is dynamic. During training, the optimizer calls the instance of - `LearningRateSchedule - `_ - with step as the input to get the learning rate of the current step. - - l1 (float): l1 regularization strength, must be greater than or equal to zero. Default: ``0.0`` . - l2 (float): l2 regularization strength, must be greater than or equal to zero. Default: ``0.0`` . - use_locking (bool): If true, use locks for updating operation. Default: ``False`` . - loss_scale (float): Value for the loss scale. It must be greater than 0.0. In general, use the default value. - Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in - `FixedLossScaleManager` is set to ``False`` , then this value needs to be the same as the `loss_scale` in - `FixedLossScaleManager`. Refer to class :class:`mindspore.amp.FixedLossScaleManager` for more details. - Default: ``1.0`` . - weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: ``0.0`` . - - - float: The fixed weight decay value. Must be equal to or greater than 0. - - - int: The fixed weight decay value. Must be equal to or greater than 0. It will be converted to float. - - - Cell: Weight decay is dynamic. During training, the optimizer calls the instance of - the Cell with step as the input to get the weight decay value of current step. - - Inputs: - - **grads** (tuple[Tensor]) - The gradients of `params` in the optimizer, the shape is the same as the `params` - in optimizer. - - Outputs: - Tensor[bool], the value is True. - - Raises: - TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule. - TypeError: If element of `parameters` is neither Parameter nor dict. - TypeError: If `accum`, `l1`, `l2` or `loss_scale` is not a float. - TypeError: If `weight_decay` is neither float nor int. - ValueError: If `loss_scale` is less than or equal to 0. - ValueError: If `accum`, `l1`, `l2` or `weight_decay` is less than 0. - - Supported Platforms: - ``Ascend`` ``GPU`` - - Examples: - >>> import mindspore as ms - >>> from mindspore import nn - >>> - >>> # Define the network structure of LeNet5. Refer to - >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py - >>> net = LeNet5() - >>> #1) All parameters use the same learning rate and weight decay - >>> optim = nn.ProximalAdagrad(params=net.trainable_params()) - >>> - >>> #2) Use parameter groups and set different values - >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) - >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) - >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization':True}, - ... {'params': no_conv_params, 'lr': 0.01}, - ... {'order_params': net.trainable_params()}] - >>> optim = nn.ProximalAdagrad(group_params, learning_rate=0.1, weight_decay=0.0) - >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01 and grad - >>> # centralization of True. - >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0 and grad - >>> # centralization of False. - >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'. - >>> - >>> loss = nn.SoftmaxCrossEntropyWithLogits() - >>> model = ms.train.Model(net, loss_fn=loss, optimizer=optim) - """ - - @opt_init_args_register - def __init__(self, params, accum=0.1, learning_rate=0.001, l1=0.0, l2=0.0, - use_locking=False, loss_scale=1.0, weight_decay=0.0): - super(ProximalAdagrad, self).__init__(learning_rate, params, weight_decay, loss_scale) - _check_param_value(accum, l1, l2, use_locking, self.cls_name) - self.accum = self._parameters.clone(prefix="accum", init=accum) - self.l1 = Tensor(l1, mstype.float32) - self.l2 = Tensor(l2, mstype.float32) - self.use_locking = use_locking - self.opt = P.ApplyProximalAdagrad(use_locking=use_locking) - self.sparse_opt = P.SparseApplyProximalAdagrad(use_locking=use_locking) - - @jit - def construct(self, grads): - params = self._parameters - accum = self.accum - grads = self.flatten_gradients(grads) - grads = self.decay_weight(grads) - grads = self.gradients_centralization(grads) - grads = self.scale_grad(grads) - grads = self._grad_sparse_indices_deduplicate(grads) - lr = self.get_lr() - self.assignadd(self.global_step, self.global_step_increase_tensor) - if self.is_group_lr: - success = self.map_reverse(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, self.l1, self.l2), - lr, grads, params, accum) - else: - success = self.map_reverse(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, self.l1, self.l2, - lr), - grads, params, accum) - return success - - @Optimizer.target.setter - def target(self, value): - """ - If the input value is set to "CPU", the parameters will be updated on the host using the Fused - optimizer operation. - """ - if not isinstance(value, str): - raise TypeError("For 'ProximalAdagrad', the property 'target' must be string type, " - "but got {}".format(type(value))) - - if value not in ('CPU', 'Ascend', 'GPU'): - raise ValueError("For 'ProximalAdagrad', the property 'target' must be 'CPU', 'Ascend' or 'GPU', " - "but got {}.".format(value)) - - if value == 'CPU': - self.sparse_opt = P.FusedSparseProximalAdagrad(self.use_locking) - self.sparse_opt.set_device("CPU") - else: - self.sparse_opt = P.SparseApplyProximalAdagrad(self.use_locking) - - self._target = value +# Copyright 2020-2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""PROXIMAL_ADA_GRAD""" +from __future__ import absolute_import + +from mindspore.ops import functional as F, composite as C, operations as P +from mindspore.common import Tensor +import mindspore.common.dtype as mstype +from mindspore.common.api import jit +from mindspore import _checkparam as validator +from mindspore.nn.optim.optimizer import Optimizer +from mindspore.nn.optim.optimizer import opt_init_args_register + +_proximal_ada_grad_opt = C.MultitypeFuncGraph("proximal_ada_grad_opt") + +@_proximal_ada_grad_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", + "Tensor") +def _tensor_run_opt_with_sparse(opt, sparse_opt, l1, l2, learning_rate, gradient, weight, accum): + """Apply sparse proximal_ada_grad optimizer to the weight parameter.""" + success = True + success = F.depend(success, sparse_opt(weight, accum, learning_rate, l1, l2, gradient.values, gradient.indices)) + return success + + +@_proximal_ada_grad_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor") +def _tensor_run_opt(opt, sparse_opt, l1, l2, learning_rate, gradient, weight, accum): + """Apply proximal_ada_grad optimizer to the weight parameter.""" + success = True + success = F.depend(success, opt(weight, accum, learning_rate, l1, l2, gradient)) + return success + + +def _check_param_value(accum, l1, l2, use_locking, prim_name=None): + """Check inputs param.""" + validator.check_value_type("accum", accum, [float], prim_name) + validator.check_value_type("l1", l1, [float], prim_name) + validator.check_value_type("l2", l2, [float], prim_name) + validator.check_value_type("use_locking", use_locking, [bool], prim_name) + validator.check_non_negative_float(accum, "accum", prim_name) + validator.check_non_negative_float(l1, "l1", prim_name) + validator.check_non_negative_float(l2, "l2", prim_name) + + +class ProximalAdagrad(Optimizer): + r""" + Implements the ProximalAdagrad algorithm that is an online Learning and Stochastic Optimization. + Refer to paper `Efficient Learning using Forward-Backward Splitting + `_. + + .. math:: + accum_{t+1} = accum_{t} + g * g + + .. math:: + \text{prox_v} = w_{t} - \gamma * g * \frac{1}{\sqrt{accum_{t+1}}} + + .. math:: + w_{t+1} = \frac{sign(\text{prox_v})}{1 + \gamma * l2} * \max(\left| \text{prox_v} \right| - \gamma * l1, 0) + + Here : where :math:`g` , :math:`\gamma`, :math:`w` , :math:`accum` and :math:`t` denote the `grads`, + `learning_rate`, `params`, accumulation and current step respectively. + + Note: + The sparse strategy is applied while the SparseGatherV2 operator is used for forward network. If the sparse + strategy wants to be executed on the host, set the target to the CPU. + The sparse feature is under continuous development. + + If parameters are not grouped, the `weight_decay` in optimizer will be applied on the network parameters without + 'beta' or 'gamma' in their names. Users can group parameters to change the strategy of decaying weight. When + parameters are grouped, each group can set `weight_decay`. If not, the `weight_decay` in optimizer will be + applied. + + Args: + params (Union[list[Parameter], list[dict]]): Must be list of `Parameter` or list of `dict`. When the + `params` is a list of `dict`, the string "params", "lr", "weight_decay", "grad_centralization" and + "order_params" are the keys can be parsed. + + - params: Required. Parameters in current group. The value must be a list of `Parameter`. + + - lr: Optional. If "lr" in the keys, the value of corresponding learning rate will be used. + If not, the `learning_rate` in optimizer will be used. Fixed and dynamic learning rate are supported. + + - weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay + will be used. If not, the `weight_decay` in the optimizer will be used. It should be noted that weight + decay can be a constant value or a Cell. It is a Cell only when dynamic weight decay is applied. Dynamic + weight decay is similar to dynamic learning rate, users need to customize a weight decay schedule only + with global step as input, and during training, the optimizer calls the instance of WeightDecaySchedule + to get the weight decay value of current step. + + - grad_centralization: Optional. Must be Boolean. If "grad_centralization" is in the keys, the set value + will be used. If not, the `grad_centralization` is False by default. This configuration only works on the + convolution layer. + + - order_params: Optional. When parameters are grouped, this usually is used to maintain the order of + parameters that appeared in the network to improve performance. The value should be parameters whose + order will be followed in optimizer. + If `order_params` in the keys, other keys will be ignored and the element of 'order_params' must be in + one group of `params`. + + accum (float): The starting value for accumulators `accum`, must be zero or positive values. Default: ``0.1`` . + learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]): Default: ``0.001`` . + + - float: The fixed learning rate value. Must be equal to or greater than 0. + + - int: The fixed learning rate value. Must be equal to or greater than 0. It will be converted to float. + + - Tensor: Its value should be a scalar or a 1-D vector. For scalar, fixed learning rate will be applied. + For vector, learning rate is dynamic, then the i-th step will take the i-th value as the learning rate. + + - Iterable: Learning rate is dynamic. The i-th step will take the i-th value as the learning rate. + + - LearningRateSchedule: Learning rate is dynamic. During training, the optimizer calls the instance of + `LearningRateSchedule + `_ + with step as the input to get the learning rate of the current step. + + l1 (float): l1 regularization strength, must be greater than or equal to zero. Default: ``0.0`` . + l2 (float): l2 regularization strength, must be greater than or equal to zero. Default: ``0.0`` . + use_locking (bool): If true, use locks for updating operation. Default: ``False`` . + loss_scale (float): Value for the loss scale. It must be greater than 0.0. In general, use the default value. + Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in + `FixedLossScaleManager` is set to ``False`` , then this value needs to be the same as the `loss_scale` in + `FixedLossScaleManager`. Refer to class :class:`mindspore.amp.FixedLossScaleManager` for more details. + Default: ``1.0`` . + weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: ``0.0`` . + + - float: The fixed weight decay value. Must be equal to or greater than 0. + + - int: The fixed weight decay value. Must be equal to or greater than 0. It will be converted to float. + + - Cell: Weight decay is dynamic. During training, the optimizer calls the instance of + the Cell with step as the input to get the weight decay value of current step. + + Inputs: + - **grads** (tuple[Tensor]) - The gradients of `params` in the optimizer, the shape is the same as the `params` + in optimizer. + + Outputs: + Tensor[bool], the value is True. + + Raises: + TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule. + TypeError: If element of `parameters` is neither Parameter nor dict. + TypeError: If `accum`, `l1`, `l2` or `loss_scale` is not a float. + TypeError: If `weight_decay` is neither float nor int. + ValueError: If `loss_scale` is less than or equal to 0. + ValueError: If `accum`, `l1`, `l2` or `weight_decay` is less than 0. + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import mindspore as ms + >>> from mindspore import nn + >>> + >>> # Define the network structure of LeNet5. Refer to + >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py + >>> net = LeNet5() + >>> #1) All parameters use the same learning rate and weight decay + >>> optim = nn.ProximalAdagrad(params=net.trainable_params()) + >>> + >>> #2) Use parameter groups and set different values + >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) + >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) + >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization':True}, + ... {'params': no_conv_params, 'lr': 0.01}, + ... {'order_params': net.trainable_params()}] + >>> optim = nn.ProximalAdagrad(group_params, learning_rate=0.1, weight_decay=0.0) + >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01 and grad + >>> # centralization of True. + >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0 and grad + >>> # centralization of False. + >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'. + >>> + >>> loss = nn.SoftmaxCrossEntropyWithLogits() + >>> model = ms.train.Model(net, loss_fn=loss, optimizer=optim) + """ + + @opt_init_args_register + def __init__(self, params, accum=0.1, learning_rate=0.001, l1=0.0, l2=0.0, + use_locking=False, loss_scale=1.0, weight_decay=0.0): + super(ProximalAdagrad, self).__init__(learning_rate, params, weight_decay, loss_scale) + _check_param_value(accum, l1, l2, use_locking, self.cls_name) + self.accum = self._parameters.clone(prefix="accum", init=accum) + self.l1 = Tensor(l1, mstype.float32) + self.l2 = Tensor(l2, mstype.float32) + self.use_locking = use_locking + self.opt = P.ApplyProximalAdagrad(use_locking=use_locking) + self.sparse_opt = P.SparseApplyProximalAdagrad(use_locking=use_locking) + + @jit + def construct(self, grads): + params = self._parameters + accum = self.accum + grads = self.flatten_gradients(grads) + grads = self.decay_weight(grads) + grads = self.gradients_centralization(grads) + grads = self.scale_grad(grads) + grads = self._grad_sparse_indices_deduplicate(grads) + lr = self.get_lr() + self.assignadd(self.global_step, self.global_step_increase_tensor) + if self.is_group_lr: + success = self.map_reverse(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, self.l1, self.l2), + lr, grads, params, accum) + else: + success = self.map_reverse(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, self.l1, self.l2, + lr), + grads, params, accum) + return success + + @Optimizer.target.setter + def target(self, value): + """ + If the input value is set to "CPU", the parameters will be updated on the host using the Fused + optimizer operation. + """ + if not isinstance(value, str): + raise TypeError("For 'ProximalAdagrad', the property 'target' must be string type, " + "but got {}".format(type(value))) + + if value not in ('CPU', 'Ascend', 'GPU'): + raise ValueError("For 'ProximalAdagrad', the property 'target' must be 'CPU', 'Ascend' or 'GPU', " + "but got {}.".format(value)) + + if value == 'CPU': + self.sparse_opt = P.FusedSparseProximalAdagrad(self.use_locking) + self.sparse_opt.set_device("CPU") + else: + self.sparse_opt = P.SparseApplyProximalAdagrad(self.use_locking) + + self._target = value diff --git a/mindspore/python/mindspore/nn/optim/rprop.py b/mindspore/python/mindspore/nn/optim/rprop.py old mode 100755 new mode 100644 diff --git a/mindspore/python/mindspore/nn/optim/sgd.py b/mindspore/python/mindspore/nn/optim/sgd.py old mode 100755 new mode 100644 diff --git a/mindspore/python/mindspore/nn/probability/bnn_layers/__init__.py b/mindspore/python/mindspore/nn/probability/bnn_layers/__init__.py index 905a5e1bebe..d11a5d47335 100644 --- a/mindspore/python/mindspore/nn/probability/bnn_layers/__init__.py +++ b/mindspore/python/mindspore/nn/probability/bnn_layers/__init__.py @@ -1,29 +1,29 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -`bnn_layers` are the high-level components used to construct the bayesian neural network. - -""" -from . import conv_variational, dense_variational, layer_distribution, bnn_cell_wrapper -from .conv_variational import ConvReparam -from .dense_variational import DenseReparam, DenseLocalReparam -from .layer_distribution import NormalPrior, NormalPosterior -from .bnn_cell_wrapper import WithBNNLossCell - -__all__ = [] -__all__.extend(conv_variational.__all__) -__all__.extend(dense_variational.__all__) -__all__.extend(layer_distribution.__all__) -__all__.extend(bnn_cell_wrapper.__all__) +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +`bnn_layers` are the high-level components used to construct the bayesian neural network. + +""" +from . import conv_variational, dense_variational, layer_distribution, bnn_cell_wrapper +from .conv_variational import ConvReparam +from .dense_variational import DenseReparam, DenseLocalReparam +from .layer_distribution import NormalPrior, NormalPosterior +from .bnn_cell_wrapper import WithBNNLossCell + +__all__ = [] +__all__.extend(conv_variational.__all__) +__all__.extend(dense_variational.__all__) +__all__.extend(layer_distribution.__all__) +__all__.extend(bnn_cell_wrapper.__all__) diff --git a/mindspore/python/mindspore/nn/probability/bnn_layers/_util.py b/mindspore/python/mindspore/nn/probability/bnn_layers/_util.py index 775f376072b..f0e07318034 100644 --- a/mindspore/python/mindspore/nn/probability/bnn_layers/_util.py +++ b/mindspore/python/mindspore/nn/probability/bnn_layers/_util.py @@ -1,46 +1,46 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Utility functions to help bnn layers.""" -from mindspore.common.tensor import Tensor -from ...cell import Cell - - -def check_prior(prior_fn, arg_name): - """check prior distribution of bnn layers.""" - if isinstance(prior_fn, Cell): - prior = prior_fn - else: - prior = prior_fn() - for prior_name, prior_dist in prior.name_cells().items(): - if prior_name != 'normal': - raise TypeError(f"The type of distribution of `{arg_name}` must be `normal`") - if not (isinstance(getattr(prior_dist, '_mean_value'), Tensor) and - isinstance(getattr(prior_dist, '_sd_value'), Tensor)): - raise TypeError(f"The input form of `{arg_name}` is incorrect") - return prior - - -def check_posterior(posterior_fn, shape, param_name, arg_name): - """check posterior distribution of bnn layers.""" - try: - posterior = posterior_fn(shape=shape, name=param_name) - except TypeError: - raise TypeError(f'The type of `{arg_name}` must be `NormalPosterior`') - finally: - pass - for posterior_name, _ in posterior.name_cells().items(): - if posterior_name != 'normal': - raise TypeError(f"The type of distribution of `{arg_name}` must be `normal`") - return posterior +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Utility functions to help bnn layers.""" +from mindspore.common.tensor import Tensor +from ...cell import Cell + + +def check_prior(prior_fn, arg_name): + """check prior distribution of bnn layers.""" + if isinstance(prior_fn, Cell): + prior = prior_fn + else: + prior = prior_fn() + for prior_name, prior_dist in prior.name_cells().items(): + if prior_name != 'normal': + raise TypeError(f"The type of distribution of `{arg_name}` must be `normal`") + if not (isinstance(getattr(prior_dist, '_mean_value'), Tensor) and + isinstance(getattr(prior_dist, '_sd_value'), Tensor)): + raise TypeError(f"The input form of `{arg_name}` is incorrect") + return prior + + +def check_posterior(posterior_fn, shape, param_name, arg_name): + """check posterior distribution of bnn layers.""" + try: + posterior = posterior_fn(shape=shape, name=param_name) + except TypeError: + raise TypeError(f'The type of `{arg_name}` must be `NormalPosterior`') + finally: + pass + for posterior_name, _ in posterior.name_cells().items(): + if posterior_name != 'normal': + raise TypeError(f"The type of distribution of `{arg_name}` must be `normal`") + return posterior diff --git a/mindspore/python/mindspore/nn/probability/bnn_layers/conv_variational.py b/mindspore/python/mindspore/nn/probability/bnn_layers/conv_variational.py index cec6898be7e..215dadf124a 100644 --- a/mindspore/python/mindspore/nn/probability/bnn_layers/conv_variational.py +++ b/mindspore/python/mindspore/nn/probability/bnn_layers/conv_variational.py @@ -1,267 +1,267 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Convolutional variational layers.""" -from mindspore.ops import operations as P -from mindspore._checkparam import twice -from ...layer.conv import _Conv -from .layer_distribution import NormalPrior, normal_post_fn -from ._util import check_prior, check_posterior - -__all__ = ['ConvReparam'] - - -class _ConvVariational(_Conv): - """ - Base class for all convolutional variational layers. - """ - - def __init__(self, - in_channels, - out_channels, - kernel_size, - stride=1, - pad_mode='same', - padding=0, - dilation=1, - group=1, - has_bias=False, - weight_prior_fn=NormalPrior, - weight_posterior_fn=normal_post_fn, - bias_prior_fn=NormalPrior, - bias_posterior_fn=normal_post_fn): - kernel_size = twice(kernel_size) - stride = twice(stride) - dilation = twice(dilation) - super(_ConvVariational, self).__init__( - in_channels, - out_channels, - kernel_size, - stride, - pad_mode, - padding, - dilation, - group, - has_bias, - weight_init='normal', - bias_init='zeros' - ) - if pad_mode not in ('valid', 'same', 'pad'): - raise ValueError('Attr \'pad_mode\' of \'Conv2d\' Op passed ' - + str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.') - - # convolution args - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.stride = stride - self.pad_mode = pad_mode - self.padding = padding - self.dilation = dilation - self.group = group - self.has_bias = has_bias - - self.shape = [self.out_channels, self.in_channels // self.group, *self.kernel_size] - - self.weight.requires_grad = False - self.weight_prior = check_prior(weight_prior_fn, "weight_prior_fn") - self.weight_posterior = check_posterior(weight_posterior_fn, shape=self.shape, param_name='bnn_weight', - arg_name="weight_posterior_fn") - - if self.has_bias: - self.bias.requires_grad = False - self.bias_prior = check_prior(bias_prior_fn, "bias_prior_fn") - self.bias_posterior = check_posterior(bias_posterior_fn, shape=[self.out_channels], param_name='bnn_bias', - arg_name="bias_posterior_fn") - - # mindspore operations - self.bias_add = P.BiasAdd() - self.conv2d = P.Conv2D(out_channel=self.out_channels, - kernel_size=self.kernel_size, - mode=1, - pad_mode=self.pad_mode, - pad=self.padding, - stride=self.stride, - dilation=self.dilation, - group=self.group) - - self.log = P.Log() - self.sum = P.ReduceSum() - - def construct(self, inputs): - outputs = self._apply_variational_weight(inputs) - if self.has_bias: - outputs = self.apply_variational_bias(outputs) - return outputs - - def extend_repr(self): - """Display instance object as string.""" - s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, pad_mode={}, ' \ - 'padding={}, dilation={}, group={}, weight_mean={}, weight_std={}, has_bias={}' \ - .format(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.pad_mode, self.padding, - self.dilation, self.group, self.weight_posterior.mean, self.weight_posterior.untransformed_std, - self.has_bias) - if self.has_bias: - s += ', bias_mean={}, bias_std={}' \ - .format(self.bias_posterior.mean, self.bias_posterior.untransformed_std) - return s - - def compute_kl_loss(self): - """Compute kl loss""" - weight_type = self.weight_posterior("get_dist_type") - weight_args_list = self.weight_posterior("get_dist_args") - - kl = self.weight_prior("kl_loss", weight_type, *weight_args_list) - kl_loss = self.sum(kl) - if self.has_bias: - bias_args_list = self.bias_posterior("get_dist_args") - bias_type = self.bias_posterior("get_dist_type") - - kl = self.bias_prior("kl_loss", bias_type, *bias_args_list) - kl = self.sum(kl) - kl_loss += kl - return kl_loss - - def apply_variational_bias(self, inputs): - """Calculate bias.""" - bias_posterior_tensor = self.bias_posterior("sample") - return self.bias_add(inputs, bias_posterior_tensor) - - -class ConvReparam(_ConvVariational): - r""" - Convolutional variational layers with Reparameterization. - - For more details, refer to the paper `Auto-Encoding Variational Bayes `_. - - Args: - in_channels (int): The number of input channel :math:`C_{in}`. - out_channels (int): The number of output channel :math:`C_{out}`. - kernel_size (Union[int, tuple[int]]): The data type is an integer or - a tuple of 2 integers. The kernel size specifies the height and - width of the 2D convolution window. a single integer stands for the - value is for both height and width of the kernel. With the `kernel_size` - being a tuple of 2 integers, the first value is for the height and the other - is the width of the kernel. - stride(Union[int, tuple[int]]): The distance of kernel moving, - an integer number represents that the height and width of movement - are both strides, or a tuple of two integers numbers represents that - height and width of movement respectively. Default: ``1`` . - pad_mode (str): Specifies the padding mode. The optional values are - ``"same"`` , ``"valid"`` , and ``"pad"`` . Default: ``"same"`` . - - - ``"same"``: Adopts the way of completion. Output height and width - will be the same as the input. - The total number of padding will be calculated for in horizontal and - vertical directions and evenly distributed to top and bottom, - left and right if possible. Otherwise, the last extra padding - will be done from the bottom and the right side. If this mode - is set, `padding` must be 0. - - - ``"valid"``: Adopts the way of discarding. The possible largest - height and width of the output will be returned without padding. - Extra pixels will be discarded. If this mode is set, `padding` - must be 0. - - - ``"pad"``: Implicit paddings on both sides of the input. The number - of `padding` will be padded to the input Tensor borders. - `padding` must be greater than or equal to 0. - - padding (Union[int, tuple[int]]): Implicit paddings on both sides of - the input. Default: ``0`` . - dilation (Union[int, tuple[int]]): The data type is an integer or a tuple - of 2 integers. This parameter specifies the dilation rate of the - dilated convolution. If set to be :math:`k > 1`, - there will be :math:`k - 1` pixels skipped for each sampling - location. Its value must be greater or equal to 1 and bounded - by the height and width of the input. Default: ``1`` . - group (int): Splits filter into groups, `in_ channels` and - `out_channels` must be divisible by the number of groups. - Default: ``1`` . - has_bias (bool): Specifies whether the layer uses a bias vector. - Default: ``False`` . - weight_prior_fn (Cell): The prior distribution for weight. - It must return a mindspore distribution instance. - Default: ``NormalPrior`` . (which creates an instance of standard - normal distribution). The current version only supports normal distribution. - weight_posterior_fn (function): The posterior distribution for sampling weight. - It must be a function handle which returns a mindspore - distribution instance. Default: ``normal_post_fn`` . - The current version only supports normal distribution. - bias_prior_fn (Cell): The prior distribution for bias vector. It must return - a mindspore distribution. Default: ``NormalPrior`` (which creates an - instance of standard normal distribution). The current version - only supports normal distribution. - bias_posterior_fn (function): The posterior distribution for sampling bias vector. - It must be a function handle which returns a mindspore - distribution instance. Default: ``normal_post_fn`` . - The current version only supports normal distribution. - - Inputs: - - **input** (Tensor) - The shape of the tensor is :math:`(N, C_{in}, H_{in}, W_{in})`. - - Outputs: - Tensor, with the shape being :math:`(N, C_{out}, H_{out}, W_{out})`. - - Supported Platforms: - ``Ascend`` ``GPU`` - - Examples: - >>> import numpy as np - >>> import mindspore - >>> from mindspore import Tensor - >>> from mindspore.nn.probability import bnn_layers - >>> net = bnn_layers.ConvReparam(120, 240, 4, has_bias=False) - >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32) - >>> output = net(input).shape - >>> print(output) - (1, 240, 1024, 640) - """ - - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride=1, - pad_mode='same', - padding=0, - dilation=1, - group=1, - has_bias=False, - weight_prior_fn=NormalPrior, - weight_posterior_fn=normal_post_fn, - bias_prior_fn=NormalPrior, - bias_posterior_fn=normal_post_fn): - super(ConvReparam, self).__init__( - in_channels, - out_channels, - kernel_size, - stride=stride, - pad_mode=pad_mode, - padding=padding, - dilation=dilation, - group=group, - has_bias=has_bias, - weight_prior_fn=weight_prior_fn, - weight_posterior_fn=weight_posterior_fn, - bias_prior_fn=bias_prior_fn, - bias_posterior_fn=bias_posterior_fn - ) - - def _apply_variational_weight(self, inputs): - """Calculate weight.""" - weight_posterior_tensor = self.weight_posterior("sample") - outputs = self.conv2d(inputs, weight_posterior_tensor) - return outputs +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Convolutional variational layers.""" +from mindspore.ops import operations as P +from mindspore._checkparam import twice +from ...layer.conv import _Conv +from .layer_distribution import NormalPrior, normal_post_fn +from ._util import check_prior, check_posterior + +__all__ = ['ConvReparam'] + + +class _ConvVariational(_Conv): + """ + Base class for all convolutional variational layers. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + pad_mode='same', + padding=0, + dilation=1, + group=1, + has_bias=False, + weight_prior_fn=NormalPrior, + weight_posterior_fn=normal_post_fn, + bias_prior_fn=NormalPrior, + bias_posterior_fn=normal_post_fn): + kernel_size = twice(kernel_size) + stride = twice(stride) + dilation = twice(dilation) + super(_ConvVariational, self).__init__( + in_channels, + out_channels, + kernel_size, + stride, + pad_mode, + padding, + dilation, + group, + has_bias, + weight_init='normal', + bias_init='zeros' + ) + if pad_mode not in ('valid', 'same', 'pad'): + raise ValueError('Attr \'pad_mode\' of \'Conv2d\' Op passed ' + + str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.') + + # convolution args + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.pad_mode = pad_mode + self.padding = padding + self.dilation = dilation + self.group = group + self.has_bias = has_bias + + self.shape = [self.out_channels, self.in_channels // self.group, *self.kernel_size] + + self.weight.requires_grad = False + self.weight_prior = check_prior(weight_prior_fn, "weight_prior_fn") + self.weight_posterior = check_posterior(weight_posterior_fn, shape=self.shape, param_name='bnn_weight', + arg_name="weight_posterior_fn") + + if self.has_bias: + self.bias.requires_grad = False + self.bias_prior = check_prior(bias_prior_fn, "bias_prior_fn") + self.bias_posterior = check_posterior(bias_posterior_fn, shape=[self.out_channels], param_name='bnn_bias', + arg_name="bias_posterior_fn") + + # mindspore operations + self.bias_add = P.BiasAdd() + self.conv2d = P.Conv2D(out_channel=self.out_channels, + kernel_size=self.kernel_size, + mode=1, + pad_mode=self.pad_mode, + pad=self.padding, + stride=self.stride, + dilation=self.dilation, + group=self.group) + + self.log = P.Log() + self.sum = P.ReduceSum() + + def construct(self, inputs): + outputs = self._apply_variational_weight(inputs) + if self.has_bias: + outputs = self.apply_variational_bias(outputs) + return outputs + + def extend_repr(self): + """Display instance object as string.""" + s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, pad_mode={}, ' \ + 'padding={}, dilation={}, group={}, weight_mean={}, weight_std={}, has_bias={}' \ + .format(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.pad_mode, self.padding, + self.dilation, self.group, self.weight_posterior.mean, self.weight_posterior.untransformed_std, + self.has_bias) + if self.has_bias: + s += ', bias_mean={}, bias_std={}' \ + .format(self.bias_posterior.mean, self.bias_posterior.untransformed_std) + return s + + def compute_kl_loss(self): + """Compute kl loss""" + weight_type = self.weight_posterior("get_dist_type") + weight_args_list = self.weight_posterior("get_dist_args") + + kl = self.weight_prior("kl_loss", weight_type, *weight_args_list) + kl_loss = self.sum(kl) + if self.has_bias: + bias_args_list = self.bias_posterior("get_dist_args") + bias_type = self.bias_posterior("get_dist_type") + + kl = self.bias_prior("kl_loss", bias_type, *bias_args_list) + kl = self.sum(kl) + kl_loss += kl + return kl_loss + + def apply_variational_bias(self, inputs): + """Calculate bias.""" + bias_posterior_tensor = self.bias_posterior("sample") + return self.bias_add(inputs, bias_posterior_tensor) + + +class ConvReparam(_ConvVariational): + r""" + Convolutional variational layers with Reparameterization. + + For more details, refer to the paper `Auto-Encoding Variational Bayes `_. + + Args: + in_channels (int): The number of input channel :math:`C_{in}`. + out_channels (int): The number of output channel :math:`C_{out}`. + kernel_size (Union[int, tuple[int]]): The data type is an integer or + a tuple of 2 integers. The kernel size specifies the height and + width of the 2D convolution window. a single integer stands for the + value is for both height and width of the kernel. With the `kernel_size` + being a tuple of 2 integers, the first value is for the height and the other + is the width of the kernel. + stride(Union[int, tuple[int]]): The distance of kernel moving, + an integer number represents that the height and width of movement + are both strides, or a tuple of two integers numbers represents that + height and width of movement respectively. Default: ``1`` . + pad_mode (str): Specifies the padding mode. The optional values are + ``"same"`` , ``"valid"`` , and ``"pad"`` . Default: ``"same"`` . + + - ``"same"``: Adopts the way of completion. Output height and width + will be the same as the input. + The total number of padding will be calculated for in horizontal and + vertical directions and evenly distributed to top and bottom, + left and right if possible. Otherwise, the last extra padding + will be done from the bottom and the right side. If this mode + is set, `padding` must be 0. + + - ``"valid"``: Adopts the way of discarding. The possible largest + height and width of the output will be returned without padding. + Extra pixels will be discarded. If this mode is set, `padding` + must be 0. + + - ``"pad"``: Implicit paddings on both sides of the input. The number + of `padding` will be padded to the input Tensor borders. + `padding` must be greater than or equal to 0. + + padding (Union[int, tuple[int]]): Implicit paddings on both sides of + the input. Default: ``0`` . + dilation (Union[int, tuple[int]]): The data type is an integer or a tuple + of 2 integers. This parameter specifies the dilation rate of the + dilated convolution. If set to be :math:`k > 1`, + there will be :math:`k - 1` pixels skipped for each sampling + location. Its value must be greater or equal to 1 and bounded + by the height and width of the input. Default: ``1`` . + group (int): Splits filter into groups, `in_ channels` and + `out_channels` must be divisible by the number of groups. + Default: ``1`` . + has_bias (bool): Specifies whether the layer uses a bias vector. + Default: ``False`` . + weight_prior_fn (Cell): The prior distribution for weight. + It must return a mindspore distribution instance. + Default: ``NormalPrior`` . (which creates an instance of standard + normal distribution). The current version only supports normal distribution. + weight_posterior_fn (function): The posterior distribution for sampling weight. + It must be a function handle which returns a mindspore + distribution instance. Default: ``normal_post_fn`` . + The current version only supports normal distribution. + bias_prior_fn (Cell): The prior distribution for bias vector. It must return + a mindspore distribution. Default: ``NormalPrior`` (which creates an + instance of standard normal distribution). The current version + only supports normal distribution. + bias_posterior_fn (function): The posterior distribution for sampling bias vector. + It must be a function handle which returns a mindspore + distribution instance. Default: ``normal_post_fn`` . + The current version only supports normal distribution. + + Inputs: + - **input** (Tensor) - The shape of the tensor is :math:`(N, C_{in}, H_{in}, W_{in})`. + + Outputs: + Tensor, with the shape being :math:`(N, C_{out}, H_{out}, W_{out})`. + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import numpy as np + >>> import mindspore + >>> from mindspore import Tensor + >>> from mindspore.nn.probability import bnn_layers + >>> net = bnn_layers.ConvReparam(120, 240, 4, has_bias=False) + >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32) + >>> output = net(input).shape + >>> print(output) + (1, 240, 1024, 640) + """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + pad_mode='same', + padding=0, + dilation=1, + group=1, + has_bias=False, + weight_prior_fn=NormalPrior, + weight_posterior_fn=normal_post_fn, + bias_prior_fn=NormalPrior, + bias_posterior_fn=normal_post_fn): + super(ConvReparam, self).__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + pad_mode=pad_mode, + padding=padding, + dilation=dilation, + group=group, + has_bias=has_bias, + weight_prior_fn=weight_prior_fn, + weight_posterior_fn=weight_posterior_fn, + bias_prior_fn=bias_prior_fn, + bias_posterior_fn=bias_posterior_fn + ) + + def _apply_variational_weight(self, inputs): + """Calculate weight.""" + weight_posterior_tensor = self.weight_posterior("sample") + outputs = self.conv2d(inputs, weight_posterior_tensor) + return outputs diff --git a/mindspore/python/mindspore/nn/probability/bnn_layers/layer_distribution.py b/mindspore/python/mindspore/nn/probability/bnn_layers/layer_distribution.py index e7114eac9b2..6efa1f4a2fc 100644 --- a/mindspore/python/mindspore/nn/probability/bnn_layers/layer_distribution.py +++ b/mindspore/python/mindspore/nn/probability/bnn_layers/layer_distribution.py @@ -1,123 +1,123 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Initialize normal distributions""" -import numpy as np -import mindspore.common.dtype as mstype -from mindspore.common.tensor import Tensor -from mindspore.common.parameter import Parameter -from mindspore.ops import operations as P -from ...cell import Cell -from ..distribution.normal import Normal - -__all__ = ['NormalPrior', 'NormalPosterior'] - - -class NormalPrior(Cell): - r""" - To initialize a normal distribution of mean 0 and standard deviation 0.1. - - Args: - dtype (mindspore.dtype): The argument is used to define the data type of the output tensor. - Default: ``mindspore.float32`` . - mean (int, float): Mean of normal distribution. Default: ``0`` . - std (int, float): Standard deviation of normal distribution. Default: ``0.1`` . - - Returns: - Cell, a normal distribution. - - Supported Platforms: - ``Ascend`` ``GPU`` - """ - def __init__(self, dtype=mstype.float32, mean=0, std=0.1): - super(NormalPrior, self).__init__() - self.normal = Normal(mean, std, dtype=dtype) - - def construct(self, *inputs): - return self.normal(*inputs) - - -class NormalPosterior(Cell): - r""" - Build Normal distributions with trainable parameters. - - Args: - name (str): Name prepended to trainable parameter. - shape (list, tuple): Shape of the mean and standard deviation. - dtype (mindspore.dtype): The argument is used to define the data type of the output tensor. - Default: ``mindspore.float32`` . - loc_mean (int, float): Mean of distribution to initialize trainable parameters. Default: ``0`` . - loc_std (int, float): Standard deviation of distribution to initialize trainable parameters. Default: ``0.1`` . - untransformed_scale_mean (int, float): Mean of distribution to initialize trainable parameters. - Default: ``-5`` . - untransformed_scale_std (int, float): Standard deviation of distribution to initialize trainable parameters. - Default: ``0.1`` . - - Returns: - Cell, a normal distribution. - - Supported Platforms: - ``Ascend`` ``GPU`` - """ - def __init__(self, - name, - shape, - dtype=mstype.float32, - loc_mean=0, - loc_std=0.1, - untransformed_scale_mean=-5, - untransformed_scale_std=0.1): - super(NormalPosterior, self).__init__() - if not isinstance(name, str): - raise TypeError('The type of `name` must be `str`') - - if not isinstance(shape, (tuple, list)): - raise TypeError('The type of `shape` must be `tuple` or `list`') - - if isinstance(loc_mean, bool) or not isinstance(loc_mean, (int, float)): - raise TypeError('The type of `loc_mean` must be `int` or `float`') - - if isinstance(untransformed_scale_mean, bool) or not isinstance(untransformed_scale_mean, (int, float)): - raise TypeError('The type of `untransformed_scale_mean` must be `int` or `float`') - - if isinstance(loc_std, bool) or not (isinstance(loc_std, (int, float)) and loc_std >= 0): - raise TypeError('The type of `loc_std` must be `int` or `float` and its value must > 0') - - if isinstance(loc_std, bool) or not (isinstance(untransformed_scale_std, (int, float)) and - untransformed_scale_std >= 0): - raise TypeError('The type of `untransformed_scale_std` must be `int` or `float` and ' - 'its value must > 0') - - self.mean = Parameter( - Tensor(np.random.normal(loc_mean, loc_std, shape), dtype=dtype), name=name + '_mean') - - self.untransformed_std = Parameter( - Tensor(np.random.normal(untransformed_scale_mean, untransformed_scale_std, shape), dtype=dtype), - name=name + '_untransformed_std') - - self.normal = Normal() - - def _std_trans(self, std_pre): - """Transform std_pre to prevent its value being zero.""" - std = 1e-6 + P.Log()(P.Exp()(std_pre) + 1) - return std - - def construct(self, *inputs): - std = self._std_trans(self.untransformed_std) - return self.normal(*inputs, mean=self.mean, sd=std) - - -def normal_post_fn(name, shape): - """Provide normal posterior distribution.""" - return NormalPosterior(name=name, shape=shape) +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Initialize normal distributions""" +import numpy as np +import mindspore.common.dtype as mstype +from mindspore.common.tensor import Tensor +from mindspore.common.parameter import Parameter +from mindspore.ops import operations as P +from ...cell import Cell +from ..distribution.normal import Normal + +__all__ = ['NormalPrior', 'NormalPosterior'] + + +class NormalPrior(Cell): + r""" + To initialize a normal distribution of mean 0 and standard deviation 0.1. + + Args: + dtype (mindspore.dtype): The argument is used to define the data type of the output tensor. + Default: ``mindspore.float32`` . + mean (int, float): Mean of normal distribution. Default: ``0`` . + std (int, float): Standard deviation of normal distribution. Default: ``0.1`` . + + Returns: + Cell, a normal distribution. + + Supported Platforms: + ``Ascend`` ``GPU`` + """ + def __init__(self, dtype=mstype.float32, mean=0, std=0.1): + super(NormalPrior, self).__init__() + self.normal = Normal(mean, std, dtype=dtype) + + def construct(self, *inputs): + return self.normal(*inputs) + + +class NormalPosterior(Cell): + r""" + Build Normal distributions with trainable parameters. + + Args: + name (str): Name prepended to trainable parameter. + shape (list, tuple): Shape of the mean and standard deviation. + dtype (mindspore.dtype): The argument is used to define the data type of the output tensor. + Default: ``mindspore.float32`` . + loc_mean (int, float): Mean of distribution to initialize trainable parameters. Default: ``0`` . + loc_std (int, float): Standard deviation of distribution to initialize trainable parameters. Default: ``0.1`` . + untransformed_scale_mean (int, float): Mean of distribution to initialize trainable parameters. + Default: ``-5`` . + untransformed_scale_std (int, float): Standard deviation of distribution to initialize trainable parameters. + Default: ``0.1`` . + + Returns: + Cell, a normal distribution. + + Supported Platforms: + ``Ascend`` ``GPU`` + """ + def __init__(self, + name, + shape, + dtype=mstype.float32, + loc_mean=0, + loc_std=0.1, + untransformed_scale_mean=-5, + untransformed_scale_std=0.1): + super(NormalPosterior, self).__init__() + if not isinstance(name, str): + raise TypeError('The type of `name` must be `str`') + + if not isinstance(shape, (tuple, list)): + raise TypeError('The type of `shape` must be `tuple` or `list`') + + if isinstance(loc_mean, bool) or not isinstance(loc_mean, (int, float)): + raise TypeError('The type of `loc_mean` must be `int` or `float`') + + if isinstance(untransformed_scale_mean, bool) or not isinstance(untransformed_scale_mean, (int, float)): + raise TypeError('The type of `untransformed_scale_mean` must be `int` or `float`') + + if isinstance(loc_std, bool) or not (isinstance(loc_std, (int, float)) and loc_std >= 0): + raise TypeError('The type of `loc_std` must be `int` or `float` and its value must > 0') + + if isinstance(loc_std, bool) or not (isinstance(untransformed_scale_std, (int, float)) and + untransformed_scale_std >= 0): + raise TypeError('The type of `untransformed_scale_std` must be `int` or `float` and ' + 'its value must > 0') + + self.mean = Parameter( + Tensor(np.random.normal(loc_mean, loc_std, shape), dtype=dtype), name=name + '_mean') + + self.untransformed_std = Parameter( + Tensor(np.random.normal(untransformed_scale_mean, untransformed_scale_std, shape), dtype=dtype), + name=name + '_untransformed_std') + + self.normal = Normal() + + def _std_trans(self, std_pre): + """Transform std_pre to prevent its value being zero.""" + std = 1e-6 + P.Log()(P.Exp()(std_pre) + 1) + return std + + def construct(self, *inputs): + std = self._std_trans(self.untransformed_std) + return self.normal(*inputs, mean=self.mean, sd=std) + + +def normal_post_fn(name, shape): + """Provide normal posterior distribution.""" + return NormalPosterior(name=name, shape=shape) diff --git a/mindspore/python/mindspore/nn/probability/distribution/categorical.py b/mindspore/python/mindspore/nn/probability/distribution/categorical.py index bf3781d9298..ce191ba7e57 100644 --- a/mindspore/python/mindspore/nn/probability/distribution/categorical.py +++ b/mindspore/python/mindspore/nn/probability/distribution/categorical.py @@ -1,435 +1,435 @@ -# Copyright 2020-2024 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Categorical Distribution""" -import numpy as np -from mindspore import context -from mindspore.ops import operations as P -from mindspore.ops import functional as F -from mindspore.ops import composite as C -from mindspore.ops.functional import stop_gradient -from mindspore.ops.operations import _inner_ops as inner -from mindspore import _checkparam as Validator -import mindspore.ops as ops -import mindspore.nn as nn -from mindspore.common import dtype as mstype -from .distribution import Distribution -from ._utils.utils import check_prob, check_sum_equal_one, check_rank,\ - check_distribution_name -from ._utils.custom_ops import exp_generic, log_generic, broadcast_to, log_generic_with_check - - -class Categorical(Distribution): - r""" - Categorical distribution. - A Categorical Distribution is a discrete distribution with the range :math:`\{1, 2, ..., k\}` - and the probability mass function as :math:`P(X = i) = p_i, i = 1, ..., k`. - - Args: - probs (Tensor, list, numpy.ndarray): Event probabilities. Default: ``None`` . - seed (int): The global seed is used in sampling. Global seed is used if it is None. Default: ``None`` . - dtype (mindspore.dtype): The type of the event samples. Default: ``mstype.int32`` . - name (str): The name of the distribution. Default: ``Categorical`` . - - Note: - `probs` must have rank at least 1, values are proper probabilities and sum to 1. - - Raises: - ValueError: When the sum of all elements in `probs` is not 1. - - Supported Platforms: - ``Ascend`` ``GPU`` - - Examples: - >>> import mindspore - >>> import mindspore.nn as nn - >>> import mindspore.nn.probability.distribution as msd - >>> from mindspore import Tensor - >>> # To initialize a Categorical distribution of probs [0.5, 0.5] - >>> ca1 = msd.Categorical(probs=[0.2, 0.8], dtype=mindspore.int32) - >>> # A Categorical distribution can be initialized without arguments. - >>> # In this case, `probs` must be passed in through arguments during function calls. - >>> ca2 = msd.Categorical(dtype=mindspore.int32) - >>> # Here are some tensors used below for testing - >>> value = Tensor([1, 0], dtype=mindspore.int32) - >>> probs_a = Tensor([0.5, 0.5], dtype=mindspore.float32) - >>> probs_b = Tensor([0.35, 0.65], dtype=mindspore.float32) - >>> # Private interfaces of probability functions corresponding to public interfaces, including - >>> # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, are the same as follows. - >>> # Args: - >>> # value (Tensor): the value to be evaluated. - >>> # probs (Tensor): event probabilities. Default: self.probs. - >>> # Examples of `prob`. - >>> # Similar calls can be made to other probability functions - >>> # by replacing `prob` by the name of the function. - >>> ans = ca1.prob(value) - >>> print(ans.shape) - (2,) - >>> # Evaluate `prob` with respect to distribution b. - >>> ans = ca1.prob(value, probs_b) - >>> print(ans.shape) - (2,) - >>> # `probs` must be passed in during function calls. - >>> ans = ca2.prob(value, probs_a) - >>> print(ans.shape) - (2,) - >>> # Functions `mean`, `sd`, `var`, and `entropy` have the same arguments. - >>> # Args: - >>> # probs (Tensor): event probabilities. Default: self.probs. - >>> # Examples of `mean`. `sd`, `var`, and `entropy` are similar. - >>> ans = ca1.mean() # return 0.8 - >>> print(ans.shape) - (1,) - >>> ans = ca1.mean(probs_b) - >>> print(ans.shape) - (1,) - >>> # `probs` must be passed in during function calls. - >>> ans = ca2.mean(probs_a) - >>> print(ans.shape) - (1,) - >>> # Interfaces of `kl_loss` and `cross_entropy` are the same as follows: - >>> # Args: - >>> # dist (str): the name of the distribution. Only 'Categorical' is supported. - >>> # probs_b (Tensor): event probabilities of distribution b. - >>> # probs (Tensor): event probabilities of distribution a. Default: self.probs. - >>> # Examples of `kl_loss`, `cross_entropy` is similar. - >>> ans = ca1.kl_loss('Categorical', probs_b) - >>> print(ans.shape) - () - >>> ans = ca1.kl_loss('Categorical', probs_b, probs_a) - >>> print(ans.shape) - () - >>> # An additional `probs` must be passed in. - >>> ans = ca2.kl_loss('Categorical', probs_b, probs_a) - >>> print(ans.shape) - () - """ - - def __init__(self, - probs=None, - seed=None, - dtype=mstype.int32, - name="Categorical"): - param = dict(locals()) - param['param_dict'] = {'probs': probs} - valid_dtype = mstype.uint_type + mstype.int_type + mstype.float_type - Validator.check_type_name( - "dtype", dtype, valid_dtype, type(self).__name__) - super(Categorical, self).__init__(seed, dtype, name, param) - - self._probs = self._add_parameter(probs, 'probs') - if self.probs is not None: - check_rank(self.probs) - check_prob(self.probs) - check_sum_equal_one(probs) - - # update is_scalar_batch and broadcast_shape - # drop one dimension - if self.probs.shape[:-1] == (): - self._is_scalar_batch = True - self._broadcast_shape = self._broadcast_shape[:-1] - - self.argmax = P.ArgMaxWithValue(axis=-1) - self.broadcast = broadcast_to - self.cast = P.Cast() - self.clip_by_value = ops.clip_by_value - self.concat = P.Concat(-1) - self.cumsum = P.CumSum() - self.dtypeop = P.DType() - self.exp = exp_generic - self.expand_dim = P.ExpandDims() - self.gather = P.GatherNd() - self.greater = P.Greater() - self.issubclass = inner.IsSubClass() - self.less = P.Less() - # when the graph kernel mode is enable - # use Log directly as akg will handle the corner cases - self.log = P.Log() if context.get_context("enable_graph_kernel") else log_generic - self.log_with_check = P.Log() if context.get_context("enable_graph_kernel") else log_generic_with_check - self.log_softmax = P.LogSoftmax() - self.logicor = P.LogicalOr() - self.logicand = P.LogicalAnd() - self.multinomial = P.Multinomial(seed=self.seed) - self.reshape = P.Reshape() - self.reduce_sum = P.ReduceSum(keep_dims=True) - self.select = P.Select() - self.shape = P.Shape() - self.softmax = P.Softmax() - self.squeeze = P.Squeeze() - self.squeeze_first_axis = P.Squeeze(0) - self.squeeze_last_axis = P.Squeeze(-1) - self.square = P.Square() - self.transpose = P.Transpose() - - self.index_type = mstype.int32 - self.nan = np.nan - - @property - def probs(self): - """ - Return the probability after casting to dtype. - - Output: - Tensor, the probs of the distribution. - """ - return self._probs - - def extend_repr(self): - """Display instance object as string.""" - if self.is_scalar_batch: - s = 'probs = {}'.format(self.probs) - else: - s = 'batch_shape = {}'.format(self._broadcast_shape) - return s - - def _get_dist_type(self): - return "Categorical" - - def _get_dist_args(self, probs=None): - if probs is not None: - self.checktensor(probs, 'probs') - else: - probs = self.probs - return (probs,) - - def _mean(self, probs=None): - r""" - .. math:: - E[X] = \sum_{i=0}^{num_classes-1} i*p_i - """ - probs = self._check_param_type(probs) - num_classes = self.shape(probs)[-1] - index = nn.Range(0., num_classes, 1.)() - return self.reduce_sum(index * probs, -1) - - def _mode(self, probs=None): - probs = self._check_param_type(probs) - index, _ = self.argmax(probs) - mode = self.cast(index, self.dtype) - return mode - - def _var(self, probs=None): - r""" - .. math:: - VAR(X) = E[X^{2}] - (E[X])^{2} - """ - probs = self._check_param_type(probs) - num_classes = self.shape(probs)[-1] - index = nn.Range(0., num_classes, 1.)() - return self.reduce_sum(self.square(index) * probs, -1) -\ - self.square(self.reduce_sum(index * probs, -1)) - - def _entropy(self, probs=None): - r""" - Evaluate entropy. - - .. math:: - H(X) = -\sum(logits * probs) - """ - probs = self._check_param_type(probs) - logits = self.log(probs) - return self.squeeze(P.Neg()(self.reduce_sum(logits * probs, -1))) - - def _kl_loss(self, dist, probs_b, probs=None): - """ - Evaluate KL divergence between Categorical distributions. - - Args: - dist (str): The type of the distributions. Should be "Categorical" in this case. - probs_b (Tensor): Event probabilities of distribution b. - probs (Tensor): Event probabilities of distribution a. Default: self.probs. - """ - check_distribution_name(dist, 'Categorical') - probs_b = self._check_value(probs_b, 'probs_b') - probs_b = self.cast(probs_b, self.parameter_type) - probs_a = self._check_param_type(probs) - if probs is None: - logits_a = self.log(probs_a) - else: - logits_a = self.log_with_check(probs_a) - logits_b = self.log_with_check(probs_b) - return self.squeeze(self.reduce_sum( - self.softmax(logits_a) * (self.log_softmax(logits_a) - (self.log_softmax(logits_b))), -1)) - - def _cross_entropy(self, dist, probs_b, probs=None): - """ - Evaluate cross entropy between Categorical distributions. - - Args: - dist (str): The type of the distributions. Should be "Categorical" in this case. - probs_b (Tensor): Event probabilities of distribution b. - probs (Tensor): Event probabilities of distribution a. Default: self.probs. - """ - check_distribution_name(dist, 'Categorical') - return self._entropy(probs) + self._kl_loss(dist, probs_b, probs) - - def _log_prob(self, value, probs=None): - r""" - Evaluate log probability. - - Args: - value (Tensor): The value to be evaluated. - probs (Tensor): Event probabilities. Default: self.probs. - """ - value = self._check_value(value, 'value') - - probs = self._check_param_type(probs) - logits = self.log(probs) - - # find the right integer to compute index - # here we simulate casting to int but still keeping float dtype - value = self.cast(value, self.dtypeop(probs)) - - zeros = F.fill(self.dtypeop(value), self.shape(value), 0.0) - between_zero_neone = self.logicand(self.less(value, 0,), - self.greater(value, -1.)) - value = self.select(between_zero_neone, - zeros, - P.Floor()(value)) - - # handle the case when value is of shape () and probs is a scalar batch - drop_dim = False - if self.shape(value) == () and self.shape(probs)[:-1] == (): - drop_dim = True - # manually add one more dimension: () -> (1,) - # drop this dimension before return - value = self.expand_dim(value, -1) - - value = self.expand_dim(value, -1) - - broadcast_shape_tensor = logits * value - broadcast_shape = self.shape(broadcast_shape_tensor) - num_classes = broadcast_shape[-1] - label_shape = broadcast_shape[:-1] - - # broadcasting logits and value - # logit_pmf shape (num of labels, C) - logits = self.broadcast(logits, broadcast_shape_tensor) - value = self.broadcast(value, broadcast_shape_tensor)[..., :1] - - # flatten value to shape (number of labels, 1) - # clip value to be in range from 0 to num_classes -1 and cast into int32 - value = self.reshape(value, (-1, 1)) - out_of_bound = self.squeeze_last_axis(self.logicor( - self.less(value, 0.0), self.less(num_classes-1, value))) - # deal with the case the there is only one class. - value_clipped = self.clip_by_value(value, 0.0, num_classes - 1) - value_clipped = self.cast(value_clipped, self.index_type) - # create index from 0 ... NumOfLabels - index = self.reshape( - ops.arange(0, self.shape(value)[0], 1, dtype=self.index_type), (-1, 1) - ) - index = self.concat((index, value_clipped)) - - # index into logit_pmf, fill in out_of_bound places with -inf - # reshape into label shape N - logits_pmf = self.gather(self.reshape( - logits, (-1, num_classes)), index) - nan = F.fill(self.dtypeop(logits_pmf), self.shape(logits_pmf), - self.nan) - logits_pmf = self.select(out_of_bound, nan, logits_pmf) - ans = self.reshape(logits_pmf, label_shape) - if drop_dim: - return self.squeeze(ans) - return ans - - def _cdf(self, value, probs=None): - r""" - Cumulative distribution function (cdf) of Categorical distributions. - - Args: - value (Tensor): The value to be evaluated. - probs (Tensor): Event probabilities. Default: self.probs. - """ - value = self._check_value(value, 'value') - probs = self._check_param_type(probs) - - value = self.cast(value, self.dtypeop(probs)) - - zeros = F.fill(self.dtypeop(value), self.shape(value), 0.0) - between_zero_neone = self.logicand( - self.less(value, 0,), self.greater(value, -1.)) - value = self.select(between_zero_neone, zeros, P.Floor()(value)) - - drop_dim = False - if self.shape(value) == () and self.shape(probs)[:-1] == (): - drop_dim = True - value = self.expand_dim(value, -1) - - value = self.expand_dim(value, -1) - - broadcast_shape_tensor = probs * value - broadcast_shape = self.shape(broadcast_shape_tensor) - num_classes = broadcast_shape[-1] - label_shape = broadcast_shape[:-1] - - probs = self.broadcast(probs, broadcast_shape_tensor) - value = self.broadcast(value, broadcast_shape_tensor)[..., :1] - - # flatten value to shape (number of labels, 1) - value = self.reshape(value, (-1, 1)) - - # drop one dimension to match cdf - # clip value to be in range from 0 to num_classes -1 and cast into int32 - less_than_zero = self.squeeze_last_axis(self.less(value, 0.0)) - value_clipped = self.clip_by_value(value, 0.0, num_classes - 1) - value_clipped = self.cast(value_clipped, self.index_type) - - index = self.reshape(nn.Range(0, self.shape(value)[0], 1)(), (-1, 1)) - index = self.concat((index, value_clipped)) - - # reshape probs and fill less_than_zero places with 0 - probs = self.reshape(probs, (-1, num_classes)) - cdf = self.gather(self.cumsum(probs, 1), index) - zeros = F.fill(self.dtypeop(cdf), self.shape(cdf), 0.0) - cdf = self.select(less_than_zero, zeros, cdf) - cdf = self.reshape(cdf, label_shape) - - if drop_dim: - return self.squeeze(cdf) - return cdf - - def _sample(self, shape=(), probs=None): - """ - Sampling. - - Args: - shape (tuple): The shape of the sample. Default: (). - probs (Tensor): Event probabilities. Default: self.probs. - - Returns: - Tensor, shape is shape(probs)[:-1] + sample_shape - """ - shape = self.checktuple(shape, 'shape') - probs = self._check_param_type(probs) - num_classes = self.shape(probs)[-1] - batch_shape = self.shape(probs)[:-1] - - sample_shape = shape + batch_shape - drop_dim = False - if sample_shape == (): - drop_dim = True - sample_shape = (1,) - - probs_2d = self.reshape(probs, (-1, num_classes)) - sample_tensor = F.fill(self.dtype, shape, 1.0) - sample_tensor = self.reshape(sample_tensor, (-1, 1)) - num_sample = self.shape(sample_tensor)[0] - samples = C.multinomial(probs_2d, num_sample, seed=self.seed) - samples = self.squeeze(self.transpose(samples, (1, 0))) - samples = self.cast(self.reshape(samples, sample_shape), self.dtype) - if drop_dim: - return self.squeeze_first_axis(samples) - samples = stop_gradient(samples) - return samples +# Copyright 2020-2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Categorical Distribution""" +import numpy as np +from mindspore import context +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.ops import composite as C +from mindspore.ops.functional import stop_gradient +from mindspore.ops.operations import _inner_ops as inner +from mindspore import _checkparam as Validator +import mindspore.ops as ops +import mindspore.nn as nn +from mindspore.common import dtype as mstype +from .distribution import Distribution +from ._utils.utils import check_prob, check_sum_equal_one, check_rank,\ + check_distribution_name +from ._utils.custom_ops import exp_generic, log_generic, broadcast_to, log_generic_with_check + + +class Categorical(Distribution): + r""" + Categorical distribution. + A Categorical Distribution is a discrete distribution with the range :math:`\{1, 2, ..., k\}` + and the probability mass function as :math:`P(X = i) = p_i, i = 1, ..., k`. + + Args: + probs (Tensor, list, numpy.ndarray): Event probabilities. Default: ``None`` . + seed (int): The global seed is used in sampling. Global seed is used if it is None. Default: ``None`` . + dtype (mindspore.dtype): The type of the event samples. Default: ``mstype.int32`` . + name (str): The name of the distribution. Default: ``Categorical`` . + + Note: + `probs` must have rank at least 1, values are proper probabilities and sum to 1. + + Raises: + ValueError: When the sum of all elements in `probs` is not 1. + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import mindspore + >>> import mindspore.nn as nn + >>> import mindspore.nn.probability.distribution as msd + >>> from mindspore import Tensor + >>> # To initialize a Categorical distribution of probs [0.5, 0.5] + >>> ca1 = msd.Categorical(probs=[0.2, 0.8], dtype=mindspore.int32) + >>> # A Categorical distribution can be initialized without arguments. + >>> # In this case, `probs` must be passed in through arguments during function calls. + >>> ca2 = msd.Categorical(dtype=mindspore.int32) + >>> # Here are some tensors used below for testing + >>> value = Tensor([1, 0], dtype=mindspore.int32) + >>> probs_a = Tensor([0.5, 0.5], dtype=mindspore.float32) + >>> probs_b = Tensor([0.35, 0.65], dtype=mindspore.float32) + >>> # Private interfaces of probability functions corresponding to public interfaces, including + >>> # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, are the same as follows. + >>> # Args: + >>> # value (Tensor): the value to be evaluated. + >>> # probs (Tensor): event probabilities. Default: self.probs. + >>> # Examples of `prob`. + >>> # Similar calls can be made to other probability functions + >>> # by replacing `prob` by the name of the function. + >>> ans = ca1.prob(value) + >>> print(ans.shape) + (2,) + >>> # Evaluate `prob` with respect to distribution b. + >>> ans = ca1.prob(value, probs_b) + >>> print(ans.shape) + (2,) + >>> # `probs` must be passed in during function calls. + >>> ans = ca2.prob(value, probs_a) + >>> print(ans.shape) + (2,) + >>> # Functions `mean`, `sd`, `var`, and `entropy` have the same arguments. + >>> # Args: + >>> # probs (Tensor): event probabilities. Default: self.probs. + >>> # Examples of `mean`. `sd`, `var`, and `entropy` are similar. + >>> ans = ca1.mean() # return 0.8 + >>> print(ans.shape) + (1,) + >>> ans = ca1.mean(probs_b) + >>> print(ans.shape) + (1,) + >>> # `probs` must be passed in during function calls. + >>> ans = ca2.mean(probs_a) + >>> print(ans.shape) + (1,) + >>> # Interfaces of `kl_loss` and `cross_entropy` are the same as follows: + >>> # Args: + >>> # dist (str): the name of the distribution. Only 'Categorical' is supported. + >>> # probs_b (Tensor): event probabilities of distribution b. + >>> # probs (Tensor): event probabilities of distribution a. Default: self.probs. + >>> # Examples of `kl_loss`, `cross_entropy` is similar. + >>> ans = ca1.kl_loss('Categorical', probs_b) + >>> print(ans.shape) + () + >>> ans = ca1.kl_loss('Categorical', probs_b, probs_a) + >>> print(ans.shape) + () + >>> # An additional `probs` must be passed in. + >>> ans = ca2.kl_loss('Categorical', probs_b, probs_a) + >>> print(ans.shape) + () + """ + + def __init__(self, + probs=None, + seed=None, + dtype=mstype.int32, + name="Categorical"): + param = dict(locals()) + param['param_dict'] = {'probs': probs} + valid_dtype = mstype.uint_type + mstype.int_type + mstype.float_type + Validator.check_type_name( + "dtype", dtype, valid_dtype, type(self).__name__) + super(Categorical, self).__init__(seed, dtype, name, param) + + self._probs = self._add_parameter(probs, 'probs') + if self.probs is not None: + check_rank(self.probs) + check_prob(self.probs) + check_sum_equal_one(probs) + + # update is_scalar_batch and broadcast_shape + # drop one dimension + if self.probs.shape[:-1] == (): + self._is_scalar_batch = True + self._broadcast_shape = self._broadcast_shape[:-1] + + self.argmax = P.ArgMaxWithValue(axis=-1) + self.broadcast = broadcast_to + self.cast = P.Cast() + self.clip_by_value = ops.clip_by_value + self.concat = P.Concat(-1) + self.cumsum = P.CumSum() + self.dtypeop = P.DType() + self.exp = exp_generic + self.expand_dim = P.ExpandDims() + self.gather = P.GatherNd() + self.greater = P.Greater() + self.issubclass = inner.IsSubClass() + self.less = P.Less() + # when the graph kernel mode is enable + # use Log directly as akg will handle the corner cases + self.log = P.Log() if context.get_context("enable_graph_kernel") else log_generic + self.log_with_check = P.Log() if context.get_context("enable_graph_kernel") else log_generic_with_check + self.log_softmax = P.LogSoftmax() + self.logicor = P.LogicalOr() + self.logicand = P.LogicalAnd() + self.multinomial = P.Multinomial(seed=self.seed) + self.reshape = P.Reshape() + self.reduce_sum = P.ReduceSum(keep_dims=True) + self.select = P.Select() + self.shape = P.Shape() + self.softmax = P.Softmax() + self.squeeze = P.Squeeze() + self.squeeze_first_axis = P.Squeeze(0) + self.squeeze_last_axis = P.Squeeze(-1) + self.square = P.Square() + self.transpose = P.Transpose() + + self.index_type = mstype.int32 + self.nan = np.nan + + @property + def probs(self): + """ + Return the probability after casting to dtype. + + Output: + Tensor, the probs of the distribution. + """ + return self._probs + + def extend_repr(self): + """Display instance object as string.""" + if self.is_scalar_batch: + s = 'probs = {}'.format(self.probs) + else: + s = 'batch_shape = {}'.format(self._broadcast_shape) + return s + + def _get_dist_type(self): + return "Categorical" + + def _get_dist_args(self, probs=None): + if probs is not None: + self.checktensor(probs, 'probs') + else: + probs = self.probs + return (probs,) + + def _mean(self, probs=None): + r""" + .. math:: + E[X] = \sum_{i=0}^{num_classes-1} i*p_i + """ + probs = self._check_param_type(probs) + num_classes = self.shape(probs)[-1] + index = nn.Range(0., num_classes, 1.)() + return self.reduce_sum(index * probs, -1) + + def _mode(self, probs=None): + probs = self._check_param_type(probs) + index, _ = self.argmax(probs) + mode = self.cast(index, self.dtype) + return mode + + def _var(self, probs=None): + r""" + .. math:: + VAR(X) = E[X^{2}] - (E[X])^{2} + """ + probs = self._check_param_type(probs) + num_classes = self.shape(probs)[-1] + index = nn.Range(0., num_classes, 1.)() + return self.reduce_sum(self.square(index) * probs, -1) -\ + self.square(self.reduce_sum(index * probs, -1)) + + def _entropy(self, probs=None): + r""" + Evaluate entropy. + + .. math:: + H(X) = -\sum(logits * probs) + """ + probs = self._check_param_type(probs) + logits = self.log(probs) + return self.squeeze(P.Neg()(self.reduce_sum(logits * probs, -1))) + + def _kl_loss(self, dist, probs_b, probs=None): + """ + Evaluate KL divergence between Categorical distributions. + + Args: + dist (str): The type of the distributions. Should be "Categorical" in this case. + probs_b (Tensor): Event probabilities of distribution b. + probs (Tensor): Event probabilities of distribution a. Default: self.probs. + """ + check_distribution_name(dist, 'Categorical') + probs_b = self._check_value(probs_b, 'probs_b') + probs_b = self.cast(probs_b, self.parameter_type) + probs_a = self._check_param_type(probs) + if probs is None: + logits_a = self.log(probs_a) + else: + logits_a = self.log_with_check(probs_a) + logits_b = self.log_with_check(probs_b) + return self.squeeze(self.reduce_sum( + self.softmax(logits_a) * (self.log_softmax(logits_a) - (self.log_softmax(logits_b))), -1)) + + def _cross_entropy(self, dist, probs_b, probs=None): + """ + Evaluate cross entropy between Categorical distributions. + + Args: + dist (str): The type of the distributions. Should be "Categorical" in this case. + probs_b (Tensor): Event probabilities of distribution b. + probs (Tensor): Event probabilities of distribution a. Default: self.probs. + """ + check_distribution_name(dist, 'Categorical') + return self._entropy(probs) + self._kl_loss(dist, probs_b, probs) + + def _log_prob(self, value, probs=None): + r""" + Evaluate log probability. + + Args: + value (Tensor): The value to be evaluated. + probs (Tensor): Event probabilities. Default: self.probs. + """ + value = self._check_value(value, 'value') + + probs = self._check_param_type(probs) + logits = self.log(probs) + + # find the right integer to compute index + # here we simulate casting to int but still keeping float dtype + value = self.cast(value, self.dtypeop(probs)) + + zeros = F.fill(self.dtypeop(value), self.shape(value), 0.0) + between_zero_neone = self.logicand(self.less(value, 0,), + self.greater(value, -1.)) + value = self.select(between_zero_neone, + zeros, + P.Floor()(value)) + + # handle the case when value is of shape () and probs is a scalar batch + drop_dim = False + if self.shape(value) == () and self.shape(probs)[:-1] == (): + drop_dim = True + # manually add one more dimension: () -> (1,) + # drop this dimension before return + value = self.expand_dim(value, -1) + + value = self.expand_dim(value, -1) + + broadcast_shape_tensor = logits * value + broadcast_shape = self.shape(broadcast_shape_tensor) + num_classes = broadcast_shape[-1] + label_shape = broadcast_shape[:-1] + + # broadcasting logits and value + # logit_pmf shape (num of labels, C) + logits = self.broadcast(logits, broadcast_shape_tensor) + value = self.broadcast(value, broadcast_shape_tensor)[..., :1] + + # flatten value to shape (number of labels, 1) + # clip value to be in range from 0 to num_classes -1 and cast into int32 + value = self.reshape(value, (-1, 1)) + out_of_bound = self.squeeze_last_axis(self.logicor( + self.less(value, 0.0), self.less(num_classes-1, value))) + # deal with the case the there is only one class. + value_clipped = self.clip_by_value(value, 0.0, num_classes - 1) + value_clipped = self.cast(value_clipped, self.index_type) + # create index from 0 ... NumOfLabels + index = self.reshape( + ops.arange(0, self.shape(value)[0], 1, dtype=self.index_type), (-1, 1) + ) + index = self.concat((index, value_clipped)) + + # index into logit_pmf, fill in out_of_bound places with -inf + # reshape into label shape N + logits_pmf = self.gather(self.reshape( + logits, (-1, num_classes)), index) + nan = F.fill(self.dtypeop(logits_pmf), self.shape(logits_pmf), + self.nan) + logits_pmf = self.select(out_of_bound, nan, logits_pmf) + ans = self.reshape(logits_pmf, label_shape) + if drop_dim: + return self.squeeze(ans) + return ans + + def _cdf(self, value, probs=None): + r""" + Cumulative distribution function (cdf) of Categorical distributions. + + Args: + value (Tensor): The value to be evaluated. + probs (Tensor): Event probabilities. Default: self.probs. + """ + value = self._check_value(value, 'value') + probs = self._check_param_type(probs) + + value = self.cast(value, self.dtypeop(probs)) + + zeros = F.fill(self.dtypeop(value), self.shape(value), 0.0) + between_zero_neone = self.logicand( + self.less(value, 0,), self.greater(value, -1.)) + value = self.select(between_zero_neone, zeros, P.Floor()(value)) + + drop_dim = False + if self.shape(value) == () and self.shape(probs)[:-1] == (): + drop_dim = True + value = self.expand_dim(value, -1) + + value = self.expand_dim(value, -1) + + broadcast_shape_tensor = probs * value + broadcast_shape = self.shape(broadcast_shape_tensor) + num_classes = broadcast_shape[-1] + label_shape = broadcast_shape[:-1] + + probs = self.broadcast(probs, broadcast_shape_tensor) + value = self.broadcast(value, broadcast_shape_tensor)[..., :1] + + # flatten value to shape (number of labels, 1) + value = self.reshape(value, (-1, 1)) + + # drop one dimension to match cdf + # clip value to be in range from 0 to num_classes -1 and cast into int32 + less_than_zero = self.squeeze_last_axis(self.less(value, 0.0)) + value_clipped = self.clip_by_value(value, 0.0, num_classes - 1) + value_clipped = self.cast(value_clipped, self.index_type) + + index = self.reshape(nn.Range(0, self.shape(value)[0], 1)(), (-1, 1)) + index = self.concat((index, value_clipped)) + + # reshape probs and fill less_than_zero places with 0 + probs = self.reshape(probs, (-1, num_classes)) + cdf = self.gather(self.cumsum(probs, 1), index) + zeros = F.fill(self.dtypeop(cdf), self.shape(cdf), 0.0) + cdf = self.select(less_than_zero, zeros, cdf) + cdf = self.reshape(cdf, label_shape) + + if drop_dim: + return self.squeeze(cdf) + return cdf + + def _sample(self, shape=(), probs=None): + """ + Sampling. + + Args: + shape (tuple): The shape of the sample. Default: (). + probs (Tensor): Event probabilities. Default: self.probs. + + Returns: + Tensor, shape is shape(probs)[:-1] + sample_shape + """ + shape = self.checktuple(shape, 'shape') + probs = self._check_param_type(probs) + num_classes = self.shape(probs)[-1] + batch_shape = self.shape(probs)[:-1] + + sample_shape = shape + batch_shape + drop_dim = False + if sample_shape == (): + drop_dim = True + sample_shape = (1,) + + probs_2d = self.reshape(probs, (-1, num_classes)) + sample_tensor = F.fill(self.dtype, shape, 1.0) + sample_tensor = self.reshape(sample_tensor, (-1, 1)) + num_sample = self.shape(sample_tensor)[0] + samples = C.multinomial(probs_2d, num_sample, seed=self.seed) + samples = self.squeeze(self.transpose(samples, (1, 0))) + samples = self.cast(self.reshape(samples, sample_shape), self.dtype) + if drop_dim: + return self.squeeze_first_axis(samples) + samples = stop_gradient(samples) + return samples diff --git a/mindspore/python/mindspore/ops/_op_impl/_custom_op/_basic.py b/mindspore/python/mindspore/ops/_op_impl/_custom_op/_basic.py index 72fc48986b7..6d69aeff353 100644 --- a/mindspore/python/mindspore/ops/_op_impl/_custom_op/_basic.py +++ b/mindspore/python/mindspore/ops/_op_impl/_custom_op/_basic.py @@ -1,158 +1,158 @@ -#!/usr/bin/env python -# -*- coding:utf-8 -*- -""" -copyright 2020 Huawei Technologies Co., Ltd - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License == distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - -_basic -""" -from __future__ import absolute_import -import te.platform.cce_params as cce - - -def _get_km_kn_shape(shape_a, shape_b, trans_a, trans_b): - """get_km_kn_shape""" - shape_len = len(shape_a) - if trans_a: - m_shape = shape_a[shape_len - 1] - km_shape = shape_a[shape_len - 2] - else: - m_shape = shape_a[shape_len - 2] - km_shape = shape_a[shape_len - 1] - - if trans_b: - kn_shape = shape_b[shape_len - 1] - n_shape = shape_b[shape_len - 2] - else: - kn_shape = shape_b[shape_len - 2] - n_shape = shape_b[shape_len - 1] - return m_shape, km_shape, n_shape, kn_shape - - -def _check_mn_shape(m_shape, n_shape, km_shape, kn_shape): - """_check_mn_shape""" - if m_shape == 1 and n_shape == 1: - raise RuntimeError("input shape M and N can't both be 1") - - if km_shape != kn_shape: - raise RuntimeError("reduce axis not same") - - if m_shape % cce.BLOCK_IN != 0 and m_shape != 1: - raise RuntimeError( - "input shape M should be 1 or multiple of %d" % cce.BLOCK_IN) - - if m_shape != 1: - if n_shape == 1 and km_shape % (cce.BLOCK_IN * cce.BLOCK_IN) != 0: - raise RuntimeError("input shape K1 must be multiple of %d" - % (cce.BLOCK_IN * cce.BLOCK_IN)) - if km_shape % cce.BLOCK_REDUCE != 0: - raise RuntimeError( - "input shape K1 should be multiple of %d" % cce.BLOCK_IN) - else: - if km_shape % (cce.BLOCK_IN * cce.BLOCK_IN) != 0: - raise RuntimeError("input shape K1 must be multiple of %d" - % (cce.BLOCK_IN * cce.BLOCK_IN)) - - -def _check_bias(shape_bias, shape_a, shape_b, m_shape, n_shape): - """_check_bias""" - is_gevm = True if shape_a[-2] == 1 or shape_a[-1] == 1 else False - is_gemv = True if shape_b[-2] == 1 or shape_b[-1] == 1 else False - if shape_bias: - if len(shape_bias) == 1: - if (is_gevm or is_gemv) and shape_bias[0] != m_shape * n_shape: - raise RuntimeError("broadcast case shape bias for gemv must be equal m*n") - if shape_bias[0] != n_shape: - raise RuntimeError("broadcast bias shape must be equal to shape n") - elif len(shape_bias) == len(shape_a): - if [i for i in shape_bias[-2:]] != [m_shape, n_shape]: - raise RuntimeError("non broadcast bias shape must be same as output shape") - else: - raise RuntimeError("Unsupported input shape now for batch bias case") - - -def _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b): - """ - Check the given input if legal - - Parameters: - shape_a: list or tuple - Shape of the first tensor a with rank > 1 - shape_b: list or tuple - Shape of the second tensor b with the same type with a, - and shape_a, shape_b must be 2 dims - shape_bias: list or tuple - Shape of bias, only support the input data format with ND - src_dtype: str - The data type of input, support "float32", "float16" - trans_a: bool - If True, shape_a == transposed before multiplication - trans_b: bool - If True, shape_b == transposed before multiplication - - Returns None - """ - shape_len = len(shape_a) - src_dtype = src_dtype.lower() - - check_list = ("float16",) - - if src_dtype not in check_list: - raise RuntimeError("matmul_cce only support %s while src_dtype == %s" - % (",".join(check_list), src_dtype)) - if shape_len != len(shape_b): - raise RuntimeError("length of a and b are not equal") - - if shape_len != 2: - raise RuntimeError( - "length of shape must be 2, more than 2 dimensions must use batch_matmul now!") - - m_shape, km_shape, n_shape, kn_shape = _get_km_kn_shape(shape_a, shape_b, trans_a, trans_b) - - if n_shape % cce.BLOCK_IN != 0 and n_shape != 1: - raise RuntimeError("input shape N must be 1 or multiple of %d" % cce.BLOCK_IN) - _check_mn_shape(m_shape, n_shape, km_shape, kn_shape) - _check_bias(shape_bias, shape_a, shape_b, m_shape, n_shape) - - -def _get_bias(shape_bias): - """_get_bias""" - bias_length = shape_bias[0] - if bias_length % 16 == 0: - shb = shape_bias - else: - bias_length = (bias_length // 16) * 16 + 16 - shape_bias = [] - shape_bias.append(bias_length) - shb = shape_bias - return shb - - -def _get_input_shape(shape_x): - """_get_input_shape""" - dim_a = shape_x[0] - dim_b = shape_x[1] - res = [] - if dim_a % 16 != 0: - dim_a = (dim_a // 16) * 16 + 16 - res.append(dim_a) - else: - res.append(dim_a) - - if dim_b % 16 != 0: - dim_b = (dim_b // 16) * 16 + 16 - res.append(dim_b) - else: - res.append(dim_b) - return res +#!/usr/bin/env python +# -*- coding:utf-8 -*- +""" +copyright 2020 Huawei Technologies Co., Ltd + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License == distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +_basic +""" +from __future__ import absolute_import +import te.platform.cce_params as cce + + +def _get_km_kn_shape(shape_a, shape_b, trans_a, trans_b): + """get_km_kn_shape""" + shape_len = len(shape_a) + if trans_a: + m_shape = shape_a[shape_len - 1] + km_shape = shape_a[shape_len - 2] + else: + m_shape = shape_a[shape_len - 2] + km_shape = shape_a[shape_len - 1] + + if trans_b: + kn_shape = shape_b[shape_len - 1] + n_shape = shape_b[shape_len - 2] + else: + kn_shape = shape_b[shape_len - 2] + n_shape = shape_b[shape_len - 1] + return m_shape, km_shape, n_shape, kn_shape + + +def _check_mn_shape(m_shape, n_shape, km_shape, kn_shape): + """_check_mn_shape""" + if m_shape == 1 and n_shape == 1: + raise RuntimeError("input shape M and N can't both be 1") + + if km_shape != kn_shape: + raise RuntimeError("reduce axis not same") + + if m_shape % cce.BLOCK_IN != 0 and m_shape != 1: + raise RuntimeError( + "input shape M should be 1 or multiple of %d" % cce.BLOCK_IN) + + if m_shape != 1: + if n_shape == 1 and km_shape % (cce.BLOCK_IN * cce.BLOCK_IN) != 0: + raise RuntimeError("input shape K1 must be multiple of %d" + % (cce.BLOCK_IN * cce.BLOCK_IN)) + if km_shape % cce.BLOCK_REDUCE != 0: + raise RuntimeError( + "input shape K1 should be multiple of %d" % cce.BLOCK_IN) + else: + if km_shape % (cce.BLOCK_IN * cce.BLOCK_IN) != 0: + raise RuntimeError("input shape K1 must be multiple of %d" + % (cce.BLOCK_IN * cce.BLOCK_IN)) + + +def _check_bias(shape_bias, shape_a, shape_b, m_shape, n_shape): + """_check_bias""" + is_gevm = True if shape_a[-2] == 1 or shape_a[-1] == 1 else False + is_gemv = True if shape_b[-2] == 1 or shape_b[-1] == 1 else False + if shape_bias: + if len(shape_bias) == 1: + if (is_gevm or is_gemv) and shape_bias[0] != m_shape * n_shape: + raise RuntimeError("broadcast case shape bias for gemv must be equal m*n") + if shape_bias[0] != n_shape: + raise RuntimeError("broadcast bias shape must be equal to shape n") + elif len(shape_bias) == len(shape_a): + if [i for i in shape_bias[-2:]] != [m_shape, n_shape]: + raise RuntimeError("non broadcast bias shape must be same as output shape") + else: + raise RuntimeError("Unsupported input shape now for batch bias case") + + +def _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b): + """ + Check the given input if legal + + Parameters: + shape_a: list or tuple + Shape of the first tensor a with rank > 1 + shape_b: list or tuple + Shape of the second tensor b with the same type with a, + and shape_a, shape_b must be 2 dims + shape_bias: list or tuple + Shape of bias, only support the input data format with ND + src_dtype: str + The data type of input, support "float32", "float16" + trans_a: bool + If True, shape_a == transposed before multiplication + trans_b: bool + If True, shape_b == transposed before multiplication + + Returns None + """ + shape_len = len(shape_a) + src_dtype = src_dtype.lower() + + check_list = ("float16",) + + if src_dtype not in check_list: + raise RuntimeError("matmul_cce only support %s while src_dtype == %s" + % (",".join(check_list), src_dtype)) + if shape_len != len(shape_b): + raise RuntimeError("length of a and b are not equal") + + if shape_len != 2: + raise RuntimeError( + "length of shape must be 2, more than 2 dimensions must use batch_matmul now!") + + m_shape, km_shape, n_shape, kn_shape = _get_km_kn_shape(shape_a, shape_b, trans_a, trans_b) + + if n_shape % cce.BLOCK_IN != 0 and n_shape != 1: + raise RuntimeError("input shape N must be 1 or multiple of %d" % cce.BLOCK_IN) + _check_mn_shape(m_shape, n_shape, km_shape, kn_shape) + _check_bias(shape_bias, shape_a, shape_b, m_shape, n_shape) + + +def _get_bias(shape_bias): + """_get_bias""" + bias_length = shape_bias[0] + if bias_length % 16 == 0: + shb = shape_bias + else: + bias_length = (bias_length // 16) * 16 + 16 + shape_bias = [] + shape_bias.append(bias_length) + shb = shape_bias + return shb + + +def _get_input_shape(shape_x): + """_get_input_shape""" + dim_a = shape_x[0] + dim_b = shape_x[1] + res = [] + if dim_a % 16 != 0: + dim_a = (dim_a // 16) * 16 + 16 + res.append(dim_a) + else: + res.append(dim_a) + + if dim_b % 16 != 0: + dim_b = (dim_b // 16) * 16 + 16 + res.append(dim_b) + else: + res.append(dim_b) + return res diff --git a/mindspore/python/mindspore/ops/_op_impl/_custom_op/dsd_impl.py b/mindspore/python/mindspore/ops/_op_impl/_custom_op/dsd_impl.py index a34d7e18820..e9f062d37e1 100644 --- a/mindspore/python/mindspore/ops/_op_impl/_custom_op/dsd_impl.py +++ b/mindspore/python/mindspore/ops/_op_impl/_custom_op/dsd_impl.py @@ -1,162 +1,162 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" dense sparse to densne matmul""" -from __future__ import absolute_import -from te import tik -from tbe.tvm.topi.cce import util -from mindspore.ops.op_info_register import DataType, TBERegOp, op_info_register - -dsd_matmul_info = TBERegOp('DSDMatmul') \ - .fusion_type("OPAQUE") \ - .async_flag(False) \ - .binfile_name("dsdmatmul.so") \ - .compute_cost(10) \ - .kernel_name("dsd_matmul") \ - .partial_flag(True) \ - .input(0, "input_w1", False, "required", "all") \ - .input(1, "input_w2", False, "required", "all") \ - .input(2, "input_v", False, "required", "all") \ - .output(0, "output_y", False, "required", "all") \ - .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ - .get_op_info() - - -@op_info_register(dsd_matmul_info) -def dsd_matmul(input_w1, input_w2, input_v, output_y={}, kernel_name='dsd_matmul'): - """ dense sparse to densne matmul""" - if util.get_product_version() == util.VERSION_MINI: - tik_inst = tik.Tik(tik.Dprofile("v100", "mini")) - else: - tik_inst = tik.Tik(tik.Dprofile("v100", "cloud")) - - # shape is: (batch_size, head, block_num, block_size//16, 16, head_size//16, 16) - input_w1_shape = input_w1.get('shape') - # shape is: (batch_size, head, block_num, head_size//16, 16, global_size//16, 16) - input_w2_shape = input_w2.get('shape') - input_v_shape = input_v.get('shape') - - batch_size = input_w1_shape[0] - head = input_w1_shape[1] - block_num = input_w1_shape[2] - block_size = input_w1_shape[4] * 16 - head_size = input_w1_shape[3] * 16 - global_size = input_w2_shape[3] * 16 - v_embedding = input_v_shape[1] * 16 // head - seq_len = input_v_shape[0] * 16 // batch_size - - block_bite_size = 32 - cpt_time = seq_len // 512 - - w1_gm = tik_inst.Tensor('float16', (batch_size, head, block_num, head_size // - 16, block_size // 16, 16, 16), name='w1_gm', scope=tik.scope_gm) - w2_gm = tik_inst.Tensor('float16', (batch_size, head, block_num, global_size // - 16, head_size // 16, 16, 16), name='w2_gm', scope=tik.scope_gm) - # - v_gm = tik_inst.Tensor('float16', (batch_size * seq_len // 16, - head * v_embedding // 16, 16, 16), name='v_gm', scope=tik.scope_gm) - # zN - output_gm = tik_inst.Tensor('float16', (batch_size, head, v_embedding // 16, seq_len // 16, 16, 16), - name='output_gm', - scope=tik.scope_gm) - - channel_num = batch_size * head - with tik_inst.for_range(0, channel_num, block_num=channel_num) as channel_idx: - head_idx = channel_idx // batch_size - bs_idx = channel_idx % batch_size - output_l0c = tik_inst.Tensor("float32", (v_embedding // 16, block_size // 16, 16, 16), name='output_l0c', - scope=tik.scope_cc) - output_ub_32 = tik_inst.Tensor('float32', (v_embedding // 16, block_size // 16, 16, 16), name='output_ub_32', - scope=tik.scope_ubuf) - output_ub = tik_inst.Tensor('float16', (v_embedding // 16, block_size // 16, 16, 16), name='output_ub', - scope=tik.scope_ubuf) - # zZ - w1_l1 = tik_inst.Tensor( - 'float16', (block_size // 16, head_size // 16, 16, 16), name='w1_l1', scope=tik.scope_cbuf) - # nZ - v_local_l1 = tik_inst.Tensor( - 'float16', (head_size // 16, v_embedding // 16, 16, 16), name='v_local_l1', scope=tik.scope_cbuf) - # zZ - w2_l1 = tik_inst.Tensor('float16', (head_size // 16, global_size // (16 * cpt_time), 16, 16), - name='w2_l1', scope=tik.scope_cbuf) - # nZ - # use same v_global - v_global_l1 = tik_inst.Tensor('float16', (global_size // 16, v_embedding // 16, 16, 16), - name='v_global_l1', scope=tik.scope_cbuf) - # global v - global_idx = 3 - head_idx % 4 - tik_inst.data_move(v_global_l1[0, 0, 0, 0], v_gm[bs_idx * seq_len // 16 + global_idx, - head_idx * v_embedding // 16, 0, 0], 0, seq_len // (4 * 16), - 16 * v_embedding * 2 // block_bite_size, - (4 * head * v_embedding * 16 - 16 * v_embedding) * 2 // block_bite_size, 0) - # every block size is 64, the output of the local and global is (1024,128) Zn - with tik_inst.for_range(0, block_num, thread_num=2) as w_idx: - # global - with tik_inst.new_stmt_scope(): - w2_l0a = tik_inst.Tensor('float16', (head_size // 16, global_size // (cpt_time * 16), 16, 16), - name='w2_l0a', scope=tik.scope_ca) - v_global_l0b = tik_inst.Tensor('float16', (global_size // (cpt_time * 16), v_embedding // 16, 16, 16), - name='v_global_l0b', scope=tik.scope_cb) - with tik_inst.for_range(0, cpt_time) as cpt_idx: - with tik_inst.for_range(0, head_size // 16) as brick_i: - tik_inst.data_move(w2_l1[brick_i, 0, 0, 0], - w2_gm[bs_idx, head_idx, w_idx, cpt_idx * - global_size // (16 * cpt_time), brick_i, 0, 0], 0, - global_size // (16 * cpt_time), 16 * 16 * 2 // block_bite_size, - (block_size // 16 - 1) * 16 * 16 * 2 // block_bite_size, 0) - tik_inst.load2dv1( - w2_l0a[0, 0, 0, 0], w2_l1[0, 0, 0, 0], 0, block_size * global_size // (cpt_time * 16 * 16), 1, - 0) - - tik_inst.load2dv1(v_global_l0b[0, 0, 0, 0], v_global_l1[cpt_idx * global_size // ( - 16 * cpt_time), 0, 0, 0], 0, global_size * v_embedding // (16 * 16 * cpt_time), 1, 0) - - with tik_inst.if_scope(cpt_idx == 0): - tik_inst.mmad(output_l0c, w2_l0a, v_global_l0b, - block_size, global_size // cpt_time, v_embedding, 0) - with tik_inst.else_scope(): - tik_inst.mmad(output_l0c, w2_l0a, v_global_l0b, - block_size, global_size // cpt_time, v_embedding, 1) - # local - with tik_inst.new_stmt_scope(): - w1_l0a = tik_inst.Tensor('float16', (block_size // 16, head_size // 16, 16, 16), - name='w1_l0a', scope=tik.scope_ca) - v_local_l0b = tik_inst.Tensor('float16', (head_size // 16, v_embedding // 16, 16, 16), - name='v_local_l0b', scope=tik.scope_cb) - tik_inst.data_move(v_local_l1[0, 0, 0, 0], - v_gm[bs_idx * seq_len // 16 + w_idx * 4, head_idx * - v_embedding // 16, 0, 0], 0, block_size // 16, - 16 * v_embedding * 2 // block_bite_size, - 16 * (head - 1) * v_embedding * 2 // block_bite_size, 0) - tik_inst.load2dv1(v_local_l0b[0, 0, 0, 0], v_local_l1[0, 0, 0, 0], 0, - head_size * v_embedding // (16 * 16), 1, 0) - # w - with tik_inst.for_range(0, block_size // 16) as brick_i: - tik_inst.data_move(w1_l1[brick_i, 0, 0, 0], w1_gm[bs_idx, head_idx, w_idx, 0, brick_i, 0, 0], 0, - head_size // 16, (16 * 16 * 2) // block_bite_size, - (block_size // 16 - 1) * 16 * 16 * 2 // block_bite_size, 0) - tik_inst.load2dv1(w1_l0a[0, 0, 0, 0], w1_l1[0, 0, 0, 0], 0, block_size * head_size // (16 * 16), 1, 0) - tik_inst.mmad(output_l0c, w1_l0a, v_local_l0b, - block_size, head_size, v_embedding, 1) - tik_inst.data_move(output_ub_32[0, 0, 0, 0], output_l0c[0, 0, 0, 0], 0, - 1, block_size * v_embedding * 4 // 1024, 0, 0) - tik_inst.vconv(64, '', output_ub[0, 0, 0, 0], output_ub_32[0, 0, 0, 0], - v_embedding * block_size // 64, 1, 1, 4, 8) - tik_inst.data_move(output_gm[bs_idx, head_idx, 0, w_idx * (block_size // 16), 0, 0], - output_ub[0, 0, 0, 0], - 0, v_embedding // 16, 16 * block_size * 2 // block_bite_size, 0, - (seq_len - block_size) * 16 * 2 // block_bite_size) - tik_inst.BuildCCE(kernel_name=kernel_name, inputs=[w1_gm, w2_gm, v_gm], - outputs=[output_gm]) - return tik_inst +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" dense sparse to densne matmul""" +from __future__ import absolute_import +from te import tik +from tbe.tvm.topi.cce import util +from mindspore.ops.op_info_register import DataType, TBERegOp, op_info_register + +dsd_matmul_info = TBERegOp('DSDMatmul') \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("dsdmatmul.so") \ + .compute_cost(10) \ + .kernel_name("dsd_matmul") \ + .partial_flag(True) \ + .input(0, "input_w1", False, "required", "all") \ + .input(1, "input_w2", False, "required", "all") \ + .input(2, "input_v", False, "required", "all") \ + .output(0, "output_y", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ + .get_op_info() + + +@op_info_register(dsd_matmul_info) +def dsd_matmul(input_w1, input_w2, input_v, output_y={}, kernel_name='dsd_matmul'): + """ dense sparse to densne matmul""" + if util.get_product_version() == util.VERSION_MINI: + tik_inst = tik.Tik(tik.Dprofile("v100", "mini")) + else: + tik_inst = tik.Tik(tik.Dprofile("v100", "cloud")) + + # shape is: (batch_size, head, block_num, block_size//16, 16, head_size//16, 16) + input_w1_shape = input_w1.get('shape') + # shape is: (batch_size, head, block_num, head_size//16, 16, global_size//16, 16) + input_w2_shape = input_w2.get('shape') + input_v_shape = input_v.get('shape') + + batch_size = input_w1_shape[0] + head = input_w1_shape[1] + block_num = input_w1_shape[2] + block_size = input_w1_shape[4] * 16 + head_size = input_w1_shape[3] * 16 + global_size = input_w2_shape[3] * 16 + v_embedding = input_v_shape[1] * 16 // head + seq_len = input_v_shape[0] * 16 // batch_size + + block_bite_size = 32 + cpt_time = seq_len // 512 + + w1_gm = tik_inst.Tensor('float16', (batch_size, head, block_num, head_size // + 16, block_size // 16, 16, 16), name='w1_gm', scope=tik.scope_gm) + w2_gm = tik_inst.Tensor('float16', (batch_size, head, block_num, global_size // + 16, head_size // 16, 16, 16), name='w2_gm', scope=tik.scope_gm) + # + v_gm = tik_inst.Tensor('float16', (batch_size * seq_len // 16, + head * v_embedding // 16, 16, 16), name='v_gm', scope=tik.scope_gm) + # zN + output_gm = tik_inst.Tensor('float16', (batch_size, head, v_embedding // 16, seq_len // 16, 16, 16), + name='output_gm', + scope=tik.scope_gm) + + channel_num = batch_size * head + with tik_inst.for_range(0, channel_num, block_num=channel_num) as channel_idx: + head_idx = channel_idx // batch_size + bs_idx = channel_idx % batch_size + output_l0c = tik_inst.Tensor("float32", (v_embedding // 16, block_size // 16, 16, 16), name='output_l0c', + scope=tik.scope_cc) + output_ub_32 = tik_inst.Tensor('float32', (v_embedding // 16, block_size // 16, 16, 16), name='output_ub_32', + scope=tik.scope_ubuf) + output_ub = tik_inst.Tensor('float16', (v_embedding // 16, block_size // 16, 16, 16), name='output_ub', + scope=tik.scope_ubuf) + # zZ + w1_l1 = tik_inst.Tensor( + 'float16', (block_size // 16, head_size // 16, 16, 16), name='w1_l1', scope=tik.scope_cbuf) + # nZ + v_local_l1 = tik_inst.Tensor( + 'float16', (head_size // 16, v_embedding // 16, 16, 16), name='v_local_l1', scope=tik.scope_cbuf) + # zZ + w2_l1 = tik_inst.Tensor('float16', (head_size // 16, global_size // (16 * cpt_time), 16, 16), + name='w2_l1', scope=tik.scope_cbuf) + # nZ + # use same v_global + v_global_l1 = tik_inst.Tensor('float16', (global_size // 16, v_embedding // 16, 16, 16), + name='v_global_l1', scope=tik.scope_cbuf) + # global v + global_idx = 3 - head_idx % 4 + tik_inst.data_move(v_global_l1[0, 0, 0, 0], v_gm[bs_idx * seq_len // 16 + global_idx, + head_idx * v_embedding // 16, 0, 0], 0, seq_len // (4 * 16), + 16 * v_embedding * 2 // block_bite_size, + (4 * head * v_embedding * 16 - 16 * v_embedding) * 2 // block_bite_size, 0) + # every block size is 64, the output of the local and global is (1024,128) Zn + with tik_inst.for_range(0, block_num, thread_num=2) as w_idx: + # global + with tik_inst.new_stmt_scope(): + w2_l0a = tik_inst.Tensor('float16', (head_size // 16, global_size // (cpt_time * 16), 16, 16), + name='w2_l0a', scope=tik.scope_ca) + v_global_l0b = tik_inst.Tensor('float16', (global_size // (cpt_time * 16), v_embedding // 16, 16, 16), + name='v_global_l0b', scope=tik.scope_cb) + with tik_inst.for_range(0, cpt_time) as cpt_idx: + with tik_inst.for_range(0, head_size // 16) as brick_i: + tik_inst.data_move(w2_l1[brick_i, 0, 0, 0], + w2_gm[bs_idx, head_idx, w_idx, cpt_idx * + global_size // (16 * cpt_time), brick_i, 0, 0], 0, + global_size // (16 * cpt_time), 16 * 16 * 2 // block_bite_size, + (block_size // 16 - 1) * 16 * 16 * 2 // block_bite_size, 0) + tik_inst.load2dv1( + w2_l0a[0, 0, 0, 0], w2_l1[0, 0, 0, 0], 0, block_size * global_size // (cpt_time * 16 * 16), 1, + 0) + + tik_inst.load2dv1(v_global_l0b[0, 0, 0, 0], v_global_l1[cpt_idx * global_size // ( + 16 * cpt_time), 0, 0, 0], 0, global_size * v_embedding // (16 * 16 * cpt_time), 1, 0) + + with tik_inst.if_scope(cpt_idx == 0): + tik_inst.mmad(output_l0c, w2_l0a, v_global_l0b, + block_size, global_size // cpt_time, v_embedding, 0) + with tik_inst.else_scope(): + tik_inst.mmad(output_l0c, w2_l0a, v_global_l0b, + block_size, global_size // cpt_time, v_embedding, 1) + # local + with tik_inst.new_stmt_scope(): + w1_l0a = tik_inst.Tensor('float16', (block_size // 16, head_size // 16, 16, 16), + name='w1_l0a', scope=tik.scope_ca) + v_local_l0b = tik_inst.Tensor('float16', (head_size // 16, v_embedding // 16, 16, 16), + name='v_local_l0b', scope=tik.scope_cb) + tik_inst.data_move(v_local_l1[0, 0, 0, 0], + v_gm[bs_idx * seq_len // 16 + w_idx * 4, head_idx * + v_embedding // 16, 0, 0], 0, block_size // 16, + 16 * v_embedding * 2 // block_bite_size, + 16 * (head - 1) * v_embedding * 2 // block_bite_size, 0) + tik_inst.load2dv1(v_local_l0b[0, 0, 0, 0], v_local_l1[0, 0, 0, 0], 0, + head_size * v_embedding // (16 * 16), 1, 0) + # w + with tik_inst.for_range(0, block_size // 16) as brick_i: + tik_inst.data_move(w1_l1[brick_i, 0, 0, 0], w1_gm[bs_idx, head_idx, w_idx, 0, brick_i, 0, 0], 0, + head_size // 16, (16 * 16 * 2) // block_bite_size, + (block_size // 16 - 1) * 16 * 16 * 2 // block_bite_size, 0) + tik_inst.load2dv1(w1_l0a[0, 0, 0, 0], w1_l1[0, 0, 0, 0], 0, block_size * head_size // (16 * 16), 1, 0) + tik_inst.mmad(output_l0c, w1_l0a, v_local_l0b, + block_size, head_size, v_embedding, 1) + tik_inst.data_move(output_ub_32[0, 0, 0, 0], output_l0c[0, 0, 0, 0], 0, + 1, block_size * v_embedding * 4 // 1024, 0, 0) + tik_inst.vconv(64, '', output_ub[0, 0, 0, 0], output_ub_32[0, 0, 0, 0], + v_embedding * block_size // 64, 1, 1, 4, 8) + tik_inst.data_move(output_gm[bs_idx, head_idx, 0, w_idx * (block_size // 16), 0, 0], + output_ub[0, 0, 0, 0], + 0, v_embedding // 16, 16 * block_size * 2 // block_bite_size, 0, + (seq_len - block_size) * 16 * 2 // block_bite_size) + tik_inst.BuildCCE(kernel_name=kernel_name, inputs=[w1_gm, w2_gm, v_gm], + outputs=[output_gm]) + return tik_inst diff --git a/mindspore/python/mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py b/mindspore/python/mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py index c1eca702227..b66c4042dfb 100644 --- a/mindspore/python/mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +++ b/mindspore/python/mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py @@ -1,644 +1,644 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""matmul dds impl""" -from te import tik -from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType - -matmul_dds_grad_op_info = TBERegOp("MatmulDDSGrad") \ - .fusion_type("OPAQUE") \ - .async_flag(False) \ - .binfile_name("matmul_dds_grad.so") \ - .compute_cost(10) \ - .kernel_name("matmul_dds_grad") \ - .partial_flag(True) \ - .input(0, "q", False, "required", "all") \ - .input(1, "k", False, "required", "all") \ - .input(2, "local_prob", False, "required", "all") \ - .input(3, "global_prob", False, "required", "all") \ - .input(4, "local_prob_grad", False, "required", "all") \ - .input(5, "global_prob_grad", False, "required", "all") \ - .output(0, "dq", False, "required", "all") \ - .output(1, "dk", False, "required", "all") \ - .dtype_format(DataType.F16_Default, DataType.F16_Default, - DataType.F16_Default, DataType.F16_Default, - DataType.F16_Default, DataType.F16_Default, - DataType.F16_Default, DataType.F16_Default) \ - .get_op_info() - - -@op_info_register(matmul_dds_grad_op_info) -def matmul_dds_grad(q, - k, - local_prob, - global_prob, - local_prob_grad, - global_prob_grad, - dq, - dk, - kernel_name="matmul_dds_grad"): - """ - :param q: the dict of input q (bs*seq_len, embedding_size) zN - :param k: the dict of input k (bs*seq_len, embedding_size) nZ - :param local_mask: the dict of input mask local (bs*16*64, 64) zN - :param global_mask: the dict of input mask global (heads*1024, 256) zN - :param local_prob: local output (bs, heads, block_num, block_size // 16, block_size // 16, 16, 16) zN - :param global_prob: global output (bs, heads, block_num, global_size // 16, block_size // 16, 16, 16) zN - :param local_prob_grad: local output grad (bs, heads, block_num, block_size // 16, block_size // 16, 16, 16) zN - :param global_prob_grad: global output grad (bs, heads, block_num, global_size // 16, block_size // 16, 16, 16) zN - """ - - shape_q = q.get( - 'shape') - shape_lc = local_prob.get( - 'shape') - shape_gc = global_prob.get( - 'shape') - bs = shape_lc[0] - heads = shape_gc[1] - global_size = shape_gc[3] * shape_gc[-1] - block_size = shape_lc[4] * shape_lc[5] - seq_len = shape_q[1] * shape_q[2] // bs - block_num = seq_len // block_size - size_per_head = shape_q[0] * shape_q[-1] // heads - - tik_inst = tik.Tik(tik.Dprofile('v100', 'cloud')) - mat_q = tik_inst.Tensor("float16", (size_per_head * heads // 16, bs * seq_len // 16, 16, 16), - name="mat_q", - scope=tik.scope_gm) # zN - mat_k = tik_inst.Tensor("float16", (size_per_head * heads // 16, bs * seq_len // 16, 16, 16), - name="mat_k", - scope=tik.scope_gm) # nZ - mat_lc = tik_inst.Tensor("float16", (bs, heads, block_num, block_size // 16, block_size // 16, 16, 16), - name="mat_lc", - scope=tik.scope_gm) # zN - mat_gc = tik_inst.Tensor("float16", (bs, heads, block_num, global_size // 16, block_size // 16, 16, 16), - name="mat_gc", - scope=tik.scope_gm) # zN - mat_lc_grad = tik_inst.Tensor("float16", (bs, heads, block_num, block_size // 16, block_size // 16, 16, 16), - name="mat_lc_grad", - scope=tik.scope_gm) # zN - mat_gc_grad = tik_inst.Tensor("float16", (bs, heads, block_num, global_size // 16, block_size // 16, 16, 16), - name="mat_gc_grad", - scope=tik.scope_gm) # zN - mat_dq = tik_inst.Tensor("float16", (size_per_head * heads // 16, bs * seq_len // 16, 16, 16), - name="mat_dq", - scope=tik.scope_gm) # zN - mat_dk = tik_inst.Tensor("float16", (bs * seq_len // 16, size_per_head * heads // 16, 16, 16), - name="mat_dk", - scope=tik.scope_gm) # zN - - channel_num = bs * heads - with tik_inst.for_range(0, channel_num, block_num=channel_num) as block_index: - # apply for tensor in L1 for fp 16 ones-like result (16, 320) zZ - mat_l1_ones = tik_inst.Tensor("float16", (1, (global_size + block_size) // 16, 16, 16), - name='mat_l1_ones', - scope=tik.scope_cbuf) - with tik_inst.new_stmt_scope(): - mat_ub_ones = tik_inst.Tensor("float16", (1, (global_size + block_size) // 16, 16, 16), - name='mat_ub_ones', - scope=tik.scope_ubuf) - tik_inst.vec_dup(128, mat_ub_ones, 1.0, - (global_size + block_size) * 16 // 128, 8) - tik_inst.data_move(mat_l1_ones[0, 0, 0, 0], mat_ub_ones[0, 0, 0, 0], - 0, (global_size + block_size) // 16, 16, 0, 0) - - b = tik_inst.Scalar(dtype="int32") - b.set_as(block_index // heads) - - head = tik_inst.Scalar(dtype="int32") - head.set_as(block_index - b * heads) - - s = tik_inst.Scalar(dtype="int32") - s.set_as(head // 4) - # formula: global_idx = 3 - (head - 4 * s) # global idx for global key extraction - global_idx = tik_inst.Scalar(dtype="int32") - global_idx.set_as(3 - (head - 4 * s)) - # apply tensor in l1 for global k (256, 128) nZ - mat_l1_gk = tik_inst.Tensor("float16", - (global_size // 16, size_per_head // 16, 16, 16), - name="mat_l1_gk", - scope=tik.scope_cbuf) - # apply for tensor in L0C for global dk (128, 256) zN - mat_l0c_dkg = tik_inst.Tensor("float32", - (global_size // 16, - size_per_head // 16, 16, 16), - name="mat_l0c_dkg", - scope=tik.scope_cc) - with tik_inst.for_range(0, global_size // 16) as gb: - # move global key from gm to L1 nZ - # the shape of k is nZ, move (16, 256) in one loop, the stride between each (16, 16) is 3*(16,16) - tik_inst.data_move(mat_l1_gk[gb, 0, 0, 0], - mat_k[ - head * size_per_head // 16, b * seq_len // 16 + - global_idx + gb * block_size // 16, 0, 0], - 0, size_per_head // 16, 16, bs * seq_len - 16, 0) - with tik_inst.for_range(0, block_num) as block: - # do backward softmax - # formula: grad_x = grad_softmax * softmax - sum(grad_softmax * softmax) * softmax - # apply for tensor in ub for grad_x out (64, 320) zN - mat_ub_lg_d = tik_inst.Tensor("float16", - ((global_size + block_size) // - 16, block_size // 16, 16, 16), - name='mat_ub_lg_d', - scope=tik.scope_ubuf) - with tik_inst.new_stmt_scope(): - # apply for tensor in ub for softmax out (64, 320) zN - mat_ub_lg = tik_inst.Tensor("float16", ((global_size + block_size) // 16, block_size // 16, 16, 16), - name='mat_ub_lg', - scope=tik.scope_ubuf) - # apply for tensor in ub for softmax out grad (64, 320) zN - mat_ub_lg_grad = tik_inst.Tensor("float16", - ((global_size + block_size) // - 16, block_size // 16, 16, 16), - name='mat_ub_lg_grad', - scope=tik.scope_ubuf) - # move local out from gm to ub zN - # the shape of local out in gm is zN - # the shape of local out in UB is zN - # the stride between each (64, 16) is 0 - # repeat 4 times - tik_inst.data_move(mat_ub_lg[0, 0, 0, 0], mat_lc[b, head, block, 0, 0, 0, 0], 0, - block_size // 16, block_size, - 0, 0) - # move global out from gm to ub zN - # the shape of global out in gm is zN - # the shape of global out in UB is zN - # the stride between each (64, 16) is 0 - # repeat 16 times - tik_inst.data_move(mat_ub_lg[block_size // 16, 0, 0, 0], mat_gc[b, head, block, 0, 0, 0, 0], 0, - global_size // 16, block_size, - 0, 0) - # move local out grad from gm to ub zN - # the shape of local out grad in gm is zN - # the shape of local out grad in UB is zN - # the stride between each (64, 16) is 0 - # repeat 4 times - tik_inst.data_move(mat_ub_lg_grad[0, 0, 0, 0], mat_lc_grad[b, head, block, 0, 0, 0, 0], 0, - block_size // 16, block_size, - 0, 0) - # move global out grad from gm to ub zN - # the shape of global out grad in gm is zN - # the shape of global out grad in UB is zN - # the stride between each (64, 16) is 0 - # repeat 16 times - tik_inst.data_move(mat_ub_lg_grad[block_size // 16, 0, 0, 0], - mat_gc_grad[b, head, block, 0, 0, 0, 0], 0, - global_size // 16, block_size, - 0, 0) - # apply for tensor in ub for softmax multiply out grad (64, 320) zN - mat_ub_ssg = tik_inst.Tensor("float16", - ((global_size + block_size) // - 16, block_size // 16, 16, 16), - name='mat_ub_ssg', - scope=tik.scope_ubuf) - # calculate softmax * softmax_grad - tik_inst.vmul(128, mat_ub_ssg[0, 0, 0, 0], mat_ub_lg_grad[0, 0, 0, 0], mat_ub_lg[0, 0, 0, 0], - (global_size + block_size) * block_size // 128, - 1, 1, 1, 8, 8, 8) - - # apply for tensor in L1 for dsoftmax*softmax result (320, 64) nZ - mat_l1_ssg_nz = tik_inst.Tensor("float16", ((global_size + block_size) // 16, - block_size // 16, 16, 16), - name='mat_l1_ssg_nz', - scope=tik.scope_cbuf) - # move ones from ub to L1 for CUBE mmad - # the shape of ones in ub is nZ - # the shape of ones in L0A is nZ - # the stride between each (16, 16) is 0 - # repeat 32 times - tik_inst.data_move(mat_l1_ssg_nz[0, 0, 0, 0], mat_ub_ssg[0, 0, 0, 0], 0, - (global_size + block_size) // 16, block_size, 0, 0) - # apply tensor in l0c for exp sum (16, 64) zN - mat_l0c_ssg_sum = tik_inst.Tensor("float32", (block_size // 16, 1, 16, 16), - name='mat_l0c_ssg_sum', - scope=tik.scope_cc) - # apply tensor in ub for exp sum (16, 64) zN - mat_ub_ssg_sum = tik_inst.Tensor("float32", (block_size // 16, 1, 16, 16), - name='mat_ub_ssg_sum', - scope=tik.scope_ubuf) - # apply for tensor in L0A for q (16, 320) zZ - mat_l0a_ones = tik_inst.Tensor('float16', (1, (global_size + block_size) // 16, 16, 16), - name='mat_l0a_ones', scope=tik.scope_ca) - # apply for tensor in L0B for exp (320, 64) nZ - mat_l0b_ssg = tik_inst.Tensor('float16', ((global_size + block_size) // 16, block_size // 16, 16, 16), - name='mat_l0b_exp', scope=tik.scope_cb) - # move ones from l1 to L0A for CUBE mmad - # the shape of ones in l1 is zZ - # the shape of ones in L0A is zZ - # the stride between each (16, 16) is 0 - # repeat 32 times - tik_inst.load2dv1(mat_l0a_ones[0, 0, 0, 0], mat_l1_ones[0, 0, 0, 0], 0, - (global_size + block_size) * 16 // (16 * 16), 1, 0, False) - # move ssg from l1 to L0B for CUBE mmad - # the shape of ssg in l1 is nZ - # the shape of ssg in L0B is nZ - # the stride between each (16, 16) is 0 - # repeat 128 times - tik_inst.load2dv1(mat_l0b_ssg[0, 0, 0, 0], mat_l1_ssg_nz[0, 0, 0, 0], 0, - (global_size + block_size) * block_size // (16 * 16), 1, 0, False) - tik_inst.mmad(mat_l0c_ssg_sum, mat_l0a_ones, mat_l0b_ssg, - 16, (global_size + block_size), block_size, 0) - tik_inst.data_move(mat_ub_ssg_sum[0, 0, 0, 0], mat_l0c_ssg_sum[0, 0, 0, 0], 0, - block_size // 16, 1, 0, 0) - # apply for tensor in UB for global prob sum (64,) - mat_ub_ssg_sums = tik_inst.Tensor("float32", (block_size,), - name='mat_ub_ssg_sums', - scope=tik.scope_ubuf) - tik_inst.data_move(mat_ub_ssg_sums[0], mat_ub_ssg_sum[0, 0, 0, 0], - 0, block_size // 16, 1 * 2, 15 * 2, 0) - # apply for tensor in UB for global prob sum (64,) - mat_ub_ssg_sums_16 = tik_inst.Tensor("float16", (block_size,), - name='mat_ub_ssg_sums_16', - scope=tik.scope_ubuf) - # convert fp32 to fp16 - tik_inst.vec_conv( - 64, "", mat_ub_ssg_sums_16[0], mat_ub_ssg_sums[0], 1, 4, 8) - - mat_ub_ssgs = tik_inst.Tensor("float16", - ((global_size + block_size) // - 16, block_size // 16, 16, 16), - name='mat_ub_ssgs', - scope=tik.scope_ubuf) - - with tik_inst.for_range(0, block_size) as bbs: - # apply for scalar in UB for prob sum rec - sum_ssg = tik_inst.Scalar("float16", - name='sum_ssg', - init_value=0) - # set value for scalar prob sum rec - sum_ssg.set_as(mat_ub_ssg_sums_16[bbs]) - tik_inst.vec_muls(16, mat_ub_ssgs[0, bbs // 16, bbs % 16, 0], - mat_ub_lg[0, bbs // 16, bbs % - 16, 0], sum_ssg, - (global_size + block_size) // 16, - block_size, block_size) - - tik_inst.vsub(128, mat_ub_lg_d[0, 0, 0, 0], mat_ub_ssg[0, 0, 0, 0], mat_ub_ssgs[0, 0, 0, 0], - (global_size + block_size) * block_size // 128, - 1, 1, 1, 8, 8, 8) - - # local dq calculation - # dw X K.T - # apply tensor in l1 for local k (64, 128) nZ - mat_l1_lk = tik_inst.Tensor("float16", - (block_size // 16, - size_per_head // 16, 16, 16), - name="mat_l1_lk", - scope=tik.scope_cbuf) - # move k from gm to l1 - # the shape of local k in gm is nZ - # the shape of local k in l1 is zZ - # the stride between each (16, 16) is 1024*bs-64 - # repeat 8 times - # LOOP 4 times - with tik_inst.for_range(0, block_size // 16) as lb: - tik_inst.data_move(mat_l1_lk[lb, 0, 0, 0], - mat_k[head * size_per_head // 16, b * seq_len // 16 + ( - block * block_size) // 16 + lb, 0, 0], - 0, size_per_head // 16, 16, bs * seq_len - 16, 0) - - # apply tensor in l1 for local dw (64, 128) zZ - mat_l1_ldw = tik_inst.Tensor("float16", - (block_size // 16, - block_size // 16, 16, 16), - name="mat_l1_ldw", - scope=tik.scope_cbuf) - # move local d-softmax from ub to l1 - # the shape of d-softmax in ub is zN - # the shape of d-softmax in l1 is zZ - # the stride between each (16, 64) is 0 - # repeat 16 times - with tik_inst.for_range(0, block_size // 16) as lb: - tik_inst.data_move(mat_l1_ldw[lb, 0, 0, 0], - mat_ub_lg_d[0, lb, 0, 0], - 0, block_size // 16, 16, block_size - 16, 0) - # apply for tensor in L0C for local d-q (64, 128) zN - mat_l0c_dq = tik_inst.Tensor("float32", - (size_per_head // 16, - block_size // 16, 16, 16), - name="mat_l0c_dq", - scope=tik.scope_cc) - with tik_inst.new_stmt_scope(): - # apply for tensor in L0A for q (64, 64) zZ - mat_l0a_ldw = tik_inst.Tensor('float16', (block_size // 16, block_size // 16, 16, 16), - name='mat_l0a_ldw', scope=tik.scope_ca) - # apply for tensor in L0B for global k (128, 256) nZ - mat_l0b_lk = tik_inst.Tensor('float16', (block_size // 16, size_per_head // 16, 16, 16), - name='mat_l0b_lk', scope=tik.scope_cb) - # move q from l1 to L0A for CUBE mmad - # the shape of q in l1 is zZ - # the shape of q in L0A is zZ - # the stride between each (16, 16) is 0 - # repeat 16 times - tik_inst.load2dv1(mat_l0a_ldw[0, 0, 0, 0], mat_l1_ldw[0, 0, 0, 0], 0, - block_size * block_size // (16 * 16), 1, 0, False) - # move local k from l1 to L0B for CUBE mmad - # the shape of local k in l1 is zZ - # the shape of local k in L0B is nZ - # the stride between each (16, 16) is 0 - # repeat 32 times - tik_inst.load2dv1(mat_l0b_lk[0, 0, 0, 0], mat_l1_lk[0, 0, 0, 0], 0, - block_size * size_per_head // (16 * 16), 1, 0, True) - # matmul q and local dw - # the shape of global scores in L0C is zN - tik_inst.mmad(mat_l0c_dq, mat_l0a_ldw, mat_l0b_lk, - block_size, block_size, size_per_head, 0) - - # global dq calculation - # apply tensor in l1 for global dw (64, 256) zZ - mat_l1_gdw = tik_inst.Tensor("float16", - (block_size // 16, - global_size // 16, 16, 16), - name="mat_l1_gdw", - scope=tik.scope_cbuf) - # move global dw from ub to l1 - # the shape of global dw in gm is zN - # the shape of global dw in l1 is zZ - # the stride between each (16, 16) is 1024*bs-64 - # repeat 8 times - # LOOP 4 times - with tik_inst.for_range(0, block_size // 16) as lb: - tik_inst.data_move(mat_l1_gdw[lb, 0, 0, 0], - mat_ub_lg_d[block_size // 16, lb, 0, 0], - 0, global_size // 16, 16, block_size - 16, 0) - # apply for tensor in ub for dq (64, 128) zN - mat_ub_dq = tik_inst.Tensor("float32", - (size_per_head // 16, - block_size // 16, 16, 16), - name="mat_ub_dq", - scope=tik.scope_ubuf) - with tik_inst.new_stmt_scope(): - # apply for tensor in L0A for global dw (64, 256) zZ - mat_l0a_gdw = tik_inst.Tensor('float16', (block_size // 16, global_size // 16, 16, 16), - name='mat_l0a_gdw', scope=tik.scope_ca) - # apply for tensor in L0B for global k (256, 128) nZ - mat_l0b_gk = tik_inst.Tensor('float16', (global_size // 16, size_per_head // 16, 16, 16), - name='mat_l0b_gk', scope=tik.scope_cb) - # move dw global from l1 to L0A for CUBE mmad - # the shape of q in l1 is zZ - # the shape of q in L0A is zZ - # the stride between each (16, 16) is 0 - # repeat 16 times - tik_inst.load2dv1(mat_l0a_gdw[0, 0, 0, 0], mat_l1_gdw[0, 0, 0, 0], 0, - block_size * global_size // (16 * 16), 1, 0, False) - # move local k from l1 to L0B for CUBE mmad - # the shape of local k in l1 is zZ - # the shape of local k in L0B is nZ - # the stride between each (16, 16) is 0 - # repeat 32 times - tik_inst.load2dv1(mat_l0b_gk[0, 0, 0, 0], mat_l1_gk[0, 0, 0, 0], 0, - global_size * size_per_head // (16 * 16), 1, 0, True) - # matmul k and local dw - # the shape of global scores in L0C is zN - tik_inst.mmad(mat_l0c_dq, mat_l0a_gdw, mat_l0b_gk, - block_size, global_size, size_per_head, 1) - # move dq from l0c to UB - # the shape of dq in l9c is zN - # the shape of dq in ub is zN - # the stride between each (16, 64) is 0 - # repeat 8 times - tik_inst.data_move(mat_ub_dq[0, 0, 0, 0], mat_l0c_dq[0, 0, 0, 0], 0, size_per_head // 16, - block_size // 16, 0, 0) - - # local dk calculation - # dk calculation q.T X dw - # apply for tensor in ub for dw (320, 64) nZ - mat_ub_lg_d_nz = tik_inst.Tensor("float16", - (block_size // 16, (global_size + - block_size) // 16, 16, 16), - name='mat_ub_lg_d_nz', - scope=tik.scope_ubuf) - # transpose dw from zN to nZ - with tik_inst.for_range(0, (global_size + block_size) // 16) as lb: - with tik_inst.for_range(0, block_size // 16) as gb: - tik_inst.vtranspose( - mat_ub_lg_d_nz[gb, lb, 0, 0], mat_ub_lg_d[lb, gb, 0, 0]) - - # apply tensor in l1 for local dw (64, 64) nZ - mat_l1_ldw_nz = tik_inst.Tensor("float16", - (block_size // 16, - block_size // 16, 16, 16), - name="mat_l1_ldw_nz", - scope=tik.scope_cbuf) - # move local dw from ub to l1 - # the shape of local dw in ub is nZ - # the shape of local dw in l1 is nZ - # the stride between each (16, 64) is 256 - # repeat 4 times - tik_inst.data_move(mat_l1_ldw_nz[0, 0, 0, 0], - mat_ub_lg_d_nz[0, 0, 0, 0], - 0, block_size // 16, block_size, global_size, 0) - # apply for tensor in L1 for q (128, 64) nZ - mat_l1_q_b = tik_inst.Tensor("float16", - (size_per_head // 16, - block_size // 16, 16, 16), - name="mat_l1_q_b", - scope=tik.scope_cbuf) - # move local q from gm to l1 - # the shape of local q in gm is zN - # the shape of local dw in l1 is zZ - # the stride between each (16, 16) is 48 - # repeat 4 times - # LOOP 8 times - with tik_inst.for_range(0, size_per_head // 16) as lb: - tik_inst.load2dv1(mat_l1_q_b[lb, 0, 0, 0], - mat_q[head * size_per_head // 16 + lb, - b * seq_len // 16 + (block * block_size) // 16, 0, 0], - 0, block_size // 16, 1, 0, False) - # apply for tensor in L0C for local dk (128, 64) zN - mat_l0c_dkl = tik_inst.Tensor("float32", - (block_size // 16, - size_per_head // 16, 16, 16), - name="mat_l0c_dkl", - scope=tik.scope_cc) - # apply for tensor in ub for local dk (128, 64) zN - mat_ub_ldk = tik_inst.Tensor("float32", - (block_size // 16, - size_per_head // 16, 16, 16), - name="mat_ub_ldk", - scope=tik.scope_ubuf) - with tik_inst.new_stmt_scope(): - # apply for tensor in L0A for q (128, 64) zZ - mat_l0a_q = tik_inst.Tensor('float16', (size_per_head // 16, block_size // 16, 16, 16), - name='mat_l0a_q', scope=tik.scope_ca) - # apply for tensor in L0B for local dw (64, 64) nZ - mat_l0b_ldw = tik_inst.Tensor('float16', (block_size // 16, block_size // 16, 16, 16), - name='mat_l0b_ldw', scope=tik.scope_cb) - # move q from l1 to L0A for CUBE mmad - # the shape of q in l1 is nZ - # the shape of q in L0A is zZ - # the stride between each (16, 16) is 0 - # repeat 4 times - # LOOP 8 times - tik_inst.load2dv1(mat_l0a_q[0, 0, 0, 0], - mat_l1_q_b[0, 0, 0, 0], - 0, block_size * size_per_head // 256, 1, 0, True) - # move local dw from l1 to L0B for CUBE mmad - # the shape of local dw in l1 is nZ - # the shape of local dw in L0B is nZ - # the stride between each (16, 16) is 0 - # repeat 32 times - tik_inst.load2dv1(mat_l0b_ldw[0, 0, 0, 0], mat_l1_ldw_nz[0, 0, 0, 0], 0, - block_size * block_size // (16 * 16), 1, 0, False) - # matmul q and local dw - # the shape of local k in L0C is zN - tik_inst.mmad(mat_l0c_dkl, mat_l0a_q, mat_l0b_ldw, - size_per_head, block_size, block_size, 0) - # move local dk from l0c to UB - # the shape of local dk in l0C is zN - # the shape of local dk in UB is zN - # the stride between each (16, 128) is 0 - # repeat 4 times - tik_inst.data_move(mat_ub_ldk[0, 0, 0, 0], mat_l0c_dkl[0, 0, 0, 0], 0, block_size // 16, - size_per_head // 16, 0, 0) - - # move global dw from UB to l1 - # apply for tensor in L1 for global dw (64, 256) nZ - mat_l1_dwg_b = tik_inst.Tensor("float16", - (block_size // 16, - global_size // 16, 16, 16), - name="mat_l1_dwg_b", - scope=tik.scope_cbuf) - # move global dw from UB to L1 - # the shape of global dw in gm is nZ - # the shape of global dw in gm is nZ - # the stride between each (16, 64) is 0 - # repeat 8 times - tik_inst.data_move(mat_l1_dwg_b[0, 0, 0, 0], - mat_ub_lg_d_nz[0, block_size // 16, 0, 0], - 0, block_size // 16, global_size, block_size, 0) - - with tik_inst.new_stmt_scope(): - # apply for tensor in L0A for q (128, 64) zZ - mat_l0a_q = tik_inst.Tensor('float16', (size_per_head // 16, block_size // 16, 16, 16), - name='mat_l0a_q', scope=tik.scope_ca) - # apply for tensor in L0B for local dw (64, 64) nZ - mat_l0b_gdw = tik_inst.Tensor('float16', (block_size // 16, global_size // 16, 16, 16), - name='mat_l0b_ldw', scope=tik.scope_cb) - # move q from l1 to L0A for CUBE mmad - # the shape of q in l1 is nZ - # the shape of q in L0A is zZ - # the stride between each (16, 16) is 0 - # repeat 4 times - # LOOP 8 times - tik_inst.load2dv1(mat_l0a_q[0, 0, 0, 0], - mat_l1_q_b[0, 0, 0, 0], - 0, block_size * size_per_head // 256, 1, 0, True) - # move local dw from l1 to L0B for CUBE mmad - # the shape of local dw in l1 is nZ - # the shape of local dw in L0B is nZ - # the stride between each (16, 16) is 0 - # repeat 32 times - tik_inst.load2dv1(mat_l0b_gdw[0, 0, 0, 0], mat_l1_dwg_b[0, 0, 0, 0], 0, - block_size * global_size // (16 * 16), 1, 0, False) - # matmul q and local dw - # the shape of local k in L0C is zN - with tik_inst.if_scope(block == 0): - tik_inst.mmad(mat_l0c_dkg, mat_l0a_q, mat_l0b_gdw, - size_per_head, block_size, global_size, 0) - with tik_inst.else_scope(): - tik_inst.mmad(mat_l0c_dkg, mat_l0a_q, mat_l0b_gdw, - size_per_head, block_size, global_size, 1) - - # cast dq from 32 to 16 - # apply for tensor in ub for dq (64, 128) zN - mat_ub_dq_16 = tik_inst.Tensor("float16", - (size_per_head // 16, - block_size // 16, 16, 16), - name="mat_ub_dq_16", - scope=tik.scope_ubuf) - # apply for tensor in ub for local dk (128, 64) zN - mat_ub_ldk_16 = tik_inst.Tensor("float16", - (block_size // 16, - size_per_head // 16, 16, 16), - name="mat_ub_ldk_16", - scope=tik.scope_ubuf) - tik_inst.vec_conv( - 64, "", mat_ub_ldk_16[0, 0, 0, 0], mat_ub_ldk[0, 0, 0, 0], size_per_head * block_size // 64, 4, 8) - tik_inst.vec_conv( - 64, "", mat_ub_dq_16[0, 0, 0, 0], mat_ub_dq[0, 0, 0, 0], size_per_head * block_size // 64, 4, 8) - - # move dq from UB to gm - # the shape of dq in UB is zN - # the shape of dq in gm is zN - # the stride between each (16, 64) is 0 - # repeat 8 times - tik_inst.data_move(mat_dq[head * size_per_head // 16, - b * seq_len // 16 + (block * block_size) // 16, 0, 0], - mat_ub_dq_16[0, 0, 0, - 0], 0, size_per_head // 16, block_size, 0, - bs * seq_len - block_size) - # move local dk from UB to gm - # the shape of local dk in UB is zN - # the shape of local dk in gm is zN - # the stride between each (16, 64) is 0 - # repeat 8 times - tik_inst.data_move(mat_dk[b * seq_len // 16 + (block * block_size) // 16, - head * size_per_head // 16, 0, 0], - mat_ub_ldk_16[0, 0, 0, - 0], 0, block_size // 16, size_per_head, 0, - heads * size_per_head - size_per_head) - with tik_inst.for_range(0, global_size // 16) as lb: - # apply for tensor in ub for global dk (128, 16) zN - mat_ub_gdk_32 = tik_inst.Tensor("float32", - (1, size_per_head // 16, 16, 16), - name="mat_ub_gdk", - scope=tik.scope_ubuf) - # apply for tensor in ub for global dk (128, 16) zN - mat_ub_gdk = tik_inst.Tensor("float16", - (1, size_per_head // 16, 16, 16), - name="mat_ub_gdk", - scope=tik.scope_ubuf) - # apply for tensor in ub for global dk (128, 16) zN - mat_ub_ldk2 = tik_inst.Tensor("float16", - (1, size_per_head // 16, 16, 16), - name="mat_ub_ldk2", - scope=tik.scope_ubuf) - # move global dk from l0c to UB - # the shape of global dk in l0C is zN - # the shape of global dk in UB is zN - # the stride between each (16, 128) is 0 - # repeat 1 times - tik_inst.data_move(mat_ub_gdk_32[0, 0, 0, 0], mat_l0c_dkg[lb, 0, 0, 0], 0, 1, - size_per_head // 16, 0, 0) - tik_inst.vec_conv( - 64, "", mat_ub_gdk[0, 0, 0, 0], mat_ub_gdk_32[0, 0, 0, 0], size_per_head * 16 // 64, 4, 8) - # move local dk from gm to UB - # the shape of local dk in gm is zN - # the shape of local dk in UB is zN - # the stride between each (16, 128) is 0 - # repeat 1 times - tik_inst.data_move(mat_ub_ldk2[0, 0, 0, 0], mat_dk[b * seq_len // 16 + 4 * lb + global_idx, - head * size_per_head // 16, 0, 0], 0, 1, - size_per_head, 0, 0) - # add local dk and global dk - mat_ub_dk = tik_inst.Tensor("float16", - (1, size_per_head // 16, 16, 16), - name="mat_ub_dk", - scope=tik.scope_ubuf) - tik_inst.vec_add(128, mat_ub_dk, mat_ub_ldk2, mat_ub_gdk, - size_per_head * 16 // 128, 8, 8, 8) - # move dk from UB to gm - # the shape of dk in UB is zN - # the shape of dk in gm is zN - # the stride between each (16, 128) is 0 - # repeat 1 times - tik_inst.data_move( - mat_dk[b * seq_len // 16 + 4 * lb + global_idx, - head * size_per_head // 16, 0, 0], - mat_ub_dk[0, 0, 0, 0], 0, 1, size_per_head, 0, 0) - tik_inst.BuildCCE(kernel_name=kernel_name, inputs=[mat_q, mat_k, mat_lc, mat_gc, mat_lc_grad, mat_gc_grad], - outputs=[mat_dq, mat_dk]) - return tik_inst +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""matmul dds impl""" +from te import tik +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +matmul_dds_grad_op_info = TBERegOp("MatmulDDSGrad") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("matmul_dds_grad.so") \ + .compute_cost(10) \ + .kernel_name("matmul_dds_grad") \ + .partial_flag(True) \ + .input(0, "q", False, "required", "all") \ + .input(1, "k", False, "required", "all") \ + .input(2, "local_prob", False, "required", "all") \ + .input(3, "global_prob", False, "required", "all") \ + .input(4, "local_prob_grad", False, "required", "all") \ + .input(5, "global_prob_grad", False, "required", "all") \ + .output(0, "dq", False, "required", "all") \ + .output(1, "dk", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, + DataType.F16_Default, DataType.F16_Default, + DataType.F16_Default, DataType.F16_Default, + DataType.F16_Default, DataType.F16_Default) \ + .get_op_info() + + +@op_info_register(matmul_dds_grad_op_info) +def matmul_dds_grad(q, + k, + local_prob, + global_prob, + local_prob_grad, + global_prob_grad, + dq, + dk, + kernel_name="matmul_dds_grad"): + """ + :param q: the dict of input q (bs*seq_len, embedding_size) zN + :param k: the dict of input k (bs*seq_len, embedding_size) nZ + :param local_mask: the dict of input mask local (bs*16*64, 64) zN + :param global_mask: the dict of input mask global (heads*1024, 256) zN + :param local_prob: local output (bs, heads, block_num, block_size // 16, block_size // 16, 16, 16) zN + :param global_prob: global output (bs, heads, block_num, global_size // 16, block_size // 16, 16, 16) zN + :param local_prob_grad: local output grad (bs, heads, block_num, block_size // 16, block_size // 16, 16, 16) zN + :param global_prob_grad: global output grad (bs, heads, block_num, global_size // 16, block_size // 16, 16, 16) zN + """ + + shape_q = q.get( + 'shape') + shape_lc = local_prob.get( + 'shape') + shape_gc = global_prob.get( + 'shape') + bs = shape_lc[0] + heads = shape_gc[1] + global_size = shape_gc[3] * shape_gc[-1] + block_size = shape_lc[4] * shape_lc[5] + seq_len = shape_q[1] * shape_q[2] // bs + block_num = seq_len // block_size + size_per_head = shape_q[0] * shape_q[-1] // heads + + tik_inst = tik.Tik(tik.Dprofile('v100', 'cloud')) + mat_q = tik_inst.Tensor("float16", (size_per_head * heads // 16, bs * seq_len // 16, 16, 16), + name="mat_q", + scope=tik.scope_gm) # zN + mat_k = tik_inst.Tensor("float16", (size_per_head * heads // 16, bs * seq_len // 16, 16, 16), + name="mat_k", + scope=tik.scope_gm) # nZ + mat_lc = tik_inst.Tensor("float16", (bs, heads, block_num, block_size // 16, block_size // 16, 16, 16), + name="mat_lc", + scope=tik.scope_gm) # zN + mat_gc = tik_inst.Tensor("float16", (bs, heads, block_num, global_size // 16, block_size // 16, 16, 16), + name="mat_gc", + scope=tik.scope_gm) # zN + mat_lc_grad = tik_inst.Tensor("float16", (bs, heads, block_num, block_size // 16, block_size // 16, 16, 16), + name="mat_lc_grad", + scope=tik.scope_gm) # zN + mat_gc_grad = tik_inst.Tensor("float16", (bs, heads, block_num, global_size // 16, block_size // 16, 16, 16), + name="mat_gc_grad", + scope=tik.scope_gm) # zN + mat_dq = tik_inst.Tensor("float16", (size_per_head * heads // 16, bs * seq_len // 16, 16, 16), + name="mat_dq", + scope=tik.scope_gm) # zN + mat_dk = tik_inst.Tensor("float16", (bs * seq_len // 16, size_per_head * heads // 16, 16, 16), + name="mat_dk", + scope=tik.scope_gm) # zN + + channel_num = bs * heads + with tik_inst.for_range(0, channel_num, block_num=channel_num) as block_index: + # apply for tensor in L1 for fp 16 ones-like result (16, 320) zZ + mat_l1_ones = tik_inst.Tensor("float16", (1, (global_size + block_size) // 16, 16, 16), + name='mat_l1_ones', + scope=tik.scope_cbuf) + with tik_inst.new_stmt_scope(): + mat_ub_ones = tik_inst.Tensor("float16", (1, (global_size + block_size) // 16, 16, 16), + name='mat_ub_ones', + scope=tik.scope_ubuf) + tik_inst.vec_dup(128, mat_ub_ones, 1.0, + (global_size + block_size) * 16 // 128, 8) + tik_inst.data_move(mat_l1_ones[0, 0, 0, 0], mat_ub_ones[0, 0, 0, 0], + 0, (global_size + block_size) // 16, 16, 0, 0) + + b = tik_inst.Scalar(dtype="int32") + b.set_as(block_index // heads) + + head = tik_inst.Scalar(dtype="int32") + head.set_as(block_index - b * heads) + + s = tik_inst.Scalar(dtype="int32") + s.set_as(head // 4) + # formula: global_idx = 3 - (head - 4 * s) # global idx for global key extraction + global_idx = tik_inst.Scalar(dtype="int32") + global_idx.set_as(3 - (head - 4 * s)) + # apply tensor in l1 for global k (256, 128) nZ + mat_l1_gk = tik_inst.Tensor("float16", + (global_size // 16, size_per_head // 16, 16, 16), + name="mat_l1_gk", + scope=tik.scope_cbuf) + # apply for tensor in L0C for global dk (128, 256) zN + mat_l0c_dkg = tik_inst.Tensor("float32", + (global_size // 16, + size_per_head // 16, 16, 16), + name="mat_l0c_dkg", + scope=tik.scope_cc) + with tik_inst.for_range(0, global_size // 16) as gb: + # move global key from gm to L1 nZ + # the shape of k is nZ, move (16, 256) in one loop, the stride between each (16, 16) is 3*(16,16) + tik_inst.data_move(mat_l1_gk[gb, 0, 0, 0], + mat_k[ + head * size_per_head // 16, b * seq_len // 16 + + global_idx + gb * block_size // 16, 0, 0], + 0, size_per_head // 16, 16, bs * seq_len - 16, 0) + with tik_inst.for_range(0, block_num) as block: + # do backward softmax + # formula: grad_x = grad_softmax * softmax - sum(grad_softmax * softmax) * softmax + # apply for tensor in ub for grad_x out (64, 320) zN + mat_ub_lg_d = tik_inst.Tensor("float16", + ((global_size + block_size) // + 16, block_size // 16, 16, 16), + name='mat_ub_lg_d', + scope=tik.scope_ubuf) + with tik_inst.new_stmt_scope(): + # apply for tensor in ub for softmax out (64, 320) zN + mat_ub_lg = tik_inst.Tensor("float16", ((global_size + block_size) // 16, block_size // 16, 16, 16), + name='mat_ub_lg', + scope=tik.scope_ubuf) + # apply for tensor in ub for softmax out grad (64, 320) zN + mat_ub_lg_grad = tik_inst.Tensor("float16", + ((global_size + block_size) // + 16, block_size // 16, 16, 16), + name='mat_ub_lg_grad', + scope=tik.scope_ubuf) + # move local out from gm to ub zN + # the shape of local out in gm is zN + # the shape of local out in UB is zN + # the stride between each (64, 16) is 0 + # repeat 4 times + tik_inst.data_move(mat_ub_lg[0, 0, 0, 0], mat_lc[b, head, block, 0, 0, 0, 0], 0, + block_size // 16, block_size, + 0, 0) + # move global out from gm to ub zN + # the shape of global out in gm is zN + # the shape of global out in UB is zN + # the stride between each (64, 16) is 0 + # repeat 16 times + tik_inst.data_move(mat_ub_lg[block_size // 16, 0, 0, 0], mat_gc[b, head, block, 0, 0, 0, 0], 0, + global_size // 16, block_size, + 0, 0) + # move local out grad from gm to ub zN + # the shape of local out grad in gm is zN + # the shape of local out grad in UB is zN + # the stride between each (64, 16) is 0 + # repeat 4 times + tik_inst.data_move(mat_ub_lg_grad[0, 0, 0, 0], mat_lc_grad[b, head, block, 0, 0, 0, 0], 0, + block_size // 16, block_size, + 0, 0) + # move global out grad from gm to ub zN + # the shape of global out grad in gm is zN + # the shape of global out grad in UB is zN + # the stride between each (64, 16) is 0 + # repeat 16 times + tik_inst.data_move(mat_ub_lg_grad[block_size // 16, 0, 0, 0], + mat_gc_grad[b, head, block, 0, 0, 0, 0], 0, + global_size // 16, block_size, + 0, 0) + # apply for tensor in ub for softmax multiply out grad (64, 320) zN + mat_ub_ssg = tik_inst.Tensor("float16", + ((global_size + block_size) // + 16, block_size // 16, 16, 16), + name='mat_ub_ssg', + scope=tik.scope_ubuf) + # calculate softmax * softmax_grad + tik_inst.vmul(128, mat_ub_ssg[0, 0, 0, 0], mat_ub_lg_grad[0, 0, 0, 0], mat_ub_lg[0, 0, 0, 0], + (global_size + block_size) * block_size // 128, + 1, 1, 1, 8, 8, 8) + + # apply for tensor in L1 for dsoftmax*softmax result (320, 64) nZ + mat_l1_ssg_nz = tik_inst.Tensor("float16", ((global_size + block_size) // 16, + block_size // 16, 16, 16), + name='mat_l1_ssg_nz', + scope=tik.scope_cbuf) + # move ones from ub to L1 for CUBE mmad + # the shape of ones in ub is nZ + # the shape of ones in L0A is nZ + # the stride between each (16, 16) is 0 + # repeat 32 times + tik_inst.data_move(mat_l1_ssg_nz[0, 0, 0, 0], mat_ub_ssg[0, 0, 0, 0], 0, + (global_size + block_size) // 16, block_size, 0, 0) + # apply tensor in l0c for exp sum (16, 64) zN + mat_l0c_ssg_sum = tik_inst.Tensor("float32", (block_size // 16, 1, 16, 16), + name='mat_l0c_ssg_sum', + scope=tik.scope_cc) + # apply tensor in ub for exp sum (16, 64) zN + mat_ub_ssg_sum = tik_inst.Tensor("float32", (block_size // 16, 1, 16, 16), + name='mat_ub_ssg_sum', + scope=tik.scope_ubuf) + # apply for tensor in L0A for q (16, 320) zZ + mat_l0a_ones = tik_inst.Tensor('float16', (1, (global_size + block_size) // 16, 16, 16), + name='mat_l0a_ones', scope=tik.scope_ca) + # apply for tensor in L0B for exp (320, 64) nZ + mat_l0b_ssg = tik_inst.Tensor('float16', ((global_size + block_size) // 16, block_size // 16, 16, 16), + name='mat_l0b_exp', scope=tik.scope_cb) + # move ones from l1 to L0A for CUBE mmad + # the shape of ones in l1 is zZ + # the shape of ones in L0A is zZ + # the stride between each (16, 16) is 0 + # repeat 32 times + tik_inst.load2dv1(mat_l0a_ones[0, 0, 0, 0], mat_l1_ones[0, 0, 0, 0], 0, + (global_size + block_size) * 16 // (16 * 16), 1, 0, False) + # move ssg from l1 to L0B for CUBE mmad + # the shape of ssg in l1 is nZ + # the shape of ssg in L0B is nZ + # the stride between each (16, 16) is 0 + # repeat 128 times + tik_inst.load2dv1(mat_l0b_ssg[0, 0, 0, 0], mat_l1_ssg_nz[0, 0, 0, 0], 0, + (global_size + block_size) * block_size // (16 * 16), 1, 0, False) + tik_inst.mmad(mat_l0c_ssg_sum, mat_l0a_ones, mat_l0b_ssg, + 16, (global_size + block_size), block_size, 0) + tik_inst.data_move(mat_ub_ssg_sum[0, 0, 0, 0], mat_l0c_ssg_sum[0, 0, 0, 0], 0, + block_size // 16, 1, 0, 0) + # apply for tensor in UB for global prob sum (64,) + mat_ub_ssg_sums = tik_inst.Tensor("float32", (block_size,), + name='mat_ub_ssg_sums', + scope=tik.scope_ubuf) + tik_inst.data_move(mat_ub_ssg_sums[0], mat_ub_ssg_sum[0, 0, 0, 0], + 0, block_size // 16, 1 * 2, 15 * 2, 0) + # apply for tensor in UB for global prob sum (64,) + mat_ub_ssg_sums_16 = tik_inst.Tensor("float16", (block_size,), + name='mat_ub_ssg_sums_16', + scope=tik.scope_ubuf) + # convert fp32 to fp16 + tik_inst.vec_conv( + 64, "", mat_ub_ssg_sums_16[0], mat_ub_ssg_sums[0], 1, 4, 8) + + mat_ub_ssgs = tik_inst.Tensor("float16", + ((global_size + block_size) // + 16, block_size // 16, 16, 16), + name='mat_ub_ssgs', + scope=tik.scope_ubuf) + + with tik_inst.for_range(0, block_size) as bbs: + # apply for scalar in UB for prob sum rec + sum_ssg = tik_inst.Scalar("float16", + name='sum_ssg', + init_value=0) + # set value for scalar prob sum rec + sum_ssg.set_as(mat_ub_ssg_sums_16[bbs]) + tik_inst.vec_muls(16, mat_ub_ssgs[0, bbs // 16, bbs % 16, 0], + mat_ub_lg[0, bbs // 16, bbs % + 16, 0], sum_ssg, + (global_size + block_size) // 16, + block_size, block_size) + + tik_inst.vsub(128, mat_ub_lg_d[0, 0, 0, 0], mat_ub_ssg[0, 0, 0, 0], mat_ub_ssgs[0, 0, 0, 0], + (global_size + block_size) * block_size // 128, + 1, 1, 1, 8, 8, 8) + + # local dq calculation + # dw X K.T + # apply tensor in l1 for local k (64, 128) nZ + mat_l1_lk = tik_inst.Tensor("float16", + (block_size // 16, + size_per_head // 16, 16, 16), + name="mat_l1_lk", + scope=tik.scope_cbuf) + # move k from gm to l1 + # the shape of local k in gm is nZ + # the shape of local k in l1 is zZ + # the stride between each (16, 16) is 1024*bs-64 + # repeat 8 times + # LOOP 4 times + with tik_inst.for_range(0, block_size // 16) as lb: + tik_inst.data_move(mat_l1_lk[lb, 0, 0, 0], + mat_k[head * size_per_head // 16, b * seq_len // 16 + ( + block * block_size) // 16 + lb, 0, 0], + 0, size_per_head // 16, 16, bs * seq_len - 16, 0) + + # apply tensor in l1 for local dw (64, 128) zZ + mat_l1_ldw = tik_inst.Tensor("float16", + (block_size // 16, + block_size // 16, 16, 16), + name="mat_l1_ldw", + scope=tik.scope_cbuf) + # move local d-softmax from ub to l1 + # the shape of d-softmax in ub is zN + # the shape of d-softmax in l1 is zZ + # the stride between each (16, 64) is 0 + # repeat 16 times + with tik_inst.for_range(0, block_size // 16) as lb: + tik_inst.data_move(mat_l1_ldw[lb, 0, 0, 0], + mat_ub_lg_d[0, lb, 0, 0], + 0, block_size // 16, 16, block_size - 16, 0) + # apply for tensor in L0C for local d-q (64, 128) zN + mat_l0c_dq = tik_inst.Tensor("float32", + (size_per_head // 16, + block_size // 16, 16, 16), + name="mat_l0c_dq", + scope=tik.scope_cc) + with tik_inst.new_stmt_scope(): + # apply for tensor in L0A for q (64, 64) zZ + mat_l0a_ldw = tik_inst.Tensor('float16', (block_size // 16, block_size // 16, 16, 16), + name='mat_l0a_ldw', scope=tik.scope_ca) + # apply for tensor in L0B for global k (128, 256) nZ + mat_l0b_lk = tik_inst.Tensor('float16', (block_size // 16, size_per_head // 16, 16, 16), + name='mat_l0b_lk', scope=tik.scope_cb) + # move q from l1 to L0A for CUBE mmad + # the shape of q in l1 is zZ + # the shape of q in L0A is zZ + # the stride between each (16, 16) is 0 + # repeat 16 times + tik_inst.load2dv1(mat_l0a_ldw[0, 0, 0, 0], mat_l1_ldw[0, 0, 0, 0], 0, + block_size * block_size // (16 * 16), 1, 0, False) + # move local k from l1 to L0B for CUBE mmad + # the shape of local k in l1 is zZ + # the shape of local k in L0B is nZ + # the stride between each (16, 16) is 0 + # repeat 32 times + tik_inst.load2dv1(mat_l0b_lk[0, 0, 0, 0], mat_l1_lk[0, 0, 0, 0], 0, + block_size * size_per_head // (16 * 16), 1, 0, True) + # matmul q and local dw + # the shape of global scores in L0C is zN + tik_inst.mmad(mat_l0c_dq, mat_l0a_ldw, mat_l0b_lk, + block_size, block_size, size_per_head, 0) + + # global dq calculation + # apply tensor in l1 for global dw (64, 256) zZ + mat_l1_gdw = tik_inst.Tensor("float16", + (block_size // 16, + global_size // 16, 16, 16), + name="mat_l1_gdw", + scope=tik.scope_cbuf) + # move global dw from ub to l1 + # the shape of global dw in gm is zN + # the shape of global dw in l1 is zZ + # the stride between each (16, 16) is 1024*bs-64 + # repeat 8 times + # LOOP 4 times + with tik_inst.for_range(0, block_size // 16) as lb: + tik_inst.data_move(mat_l1_gdw[lb, 0, 0, 0], + mat_ub_lg_d[block_size // 16, lb, 0, 0], + 0, global_size // 16, 16, block_size - 16, 0) + # apply for tensor in ub for dq (64, 128) zN + mat_ub_dq = tik_inst.Tensor("float32", + (size_per_head // 16, + block_size // 16, 16, 16), + name="mat_ub_dq", + scope=tik.scope_ubuf) + with tik_inst.new_stmt_scope(): + # apply for tensor in L0A for global dw (64, 256) zZ + mat_l0a_gdw = tik_inst.Tensor('float16', (block_size // 16, global_size // 16, 16, 16), + name='mat_l0a_gdw', scope=tik.scope_ca) + # apply for tensor in L0B for global k (256, 128) nZ + mat_l0b_gk = tik_inst.Tensor('float16', (global_size // 16, size_per_head // 16, 16, 16), + name='mat_l0b_gk', scope=tik.scope_cb) + # move dw global from l1 to L0A for CUBE mmad + # the shape of q in l1 is zZ + # the shape of q in L0A is zZ + # the stride between each (16, 16) is 0 + # repeat 16 times + tik_inst.load2dv1(mat_l0a_gdw[0, 0, 0, 0], mat_l1_gdw[0, 0, 0, 0], 0, + block_size * global_size // (16 * 16), 1, 0, False) + # move local k from l1 to L0B for CUBE mmad + # the shape of local k in l1 is zZ + # the shape of local k in L0B is nZ + # the stride between each (16, 16) is 0 + # repeat 32 times + tik_inst.load2dv1(mat_l0b_gk[0, 0, 0, 0], mat_l1_gk[0, 0, 0, 0], 0, + global_size * size_per_head // (16 * 16), 1, 0, True) + # matmul k and local dw + # the shape of global scores in L0C is zN + tik_inst.mmad(mat_l0c_dq, mat_l0a_gdw, mat_l0b_gk, + block_size, global_size, size_per_head, 1) + # move dq from l0c to UB + # the shape of dq in l9c is zN + # the shape of dq in ub is zN + # the stride between each (16, 64) is 0 + # repeat 8 times + tik_inst.data_move(mat_ub_dq[0, 0, 0, 0], mat_l0c_dq[0, 0, 0, 0], 0, size_per_head // 16, + block_size // 16, 0, 0) + + # local dk calculation + # dk calculation q.T X dw + # apply for tensor in ub for dw (320, 64) nZ + mat_ub_lg_d_nz = tik_inst.Tensor("float16", + (block_size // 16, (global_size + + block_size) // 16, 16, 16), + name='mat_ub_lg_d_nz', + scope=tik.scope_ubuf) + # transpose dw from zN to nZ + with tik_inst.for_range(0, (global_size + block_size) // 16) as lb: + with tik_inst.for_range(0, block_size // 16) as gb: + tik_inst.vtranspose( + mat_ub_lg_d_nz[gb, lb, 0, 0], mat_ub_lg_d[lb, gb, 0, 0]) + + # apply tensor in l1 for local dw (64, 64) nZ + mat_l1_ldw_nz = tik_inst.Tensor("float16", + (block_size // 16, + block_size // 16, 16, 16), + name="mat_l1_ldw_nz", + scope=tik.scope_cbuf) + # move local dw from ub to l1 + # the shape of local dw in ub is nZ + # the shape of local dw in l1 is nZ + # the stride between each (16, 64) is 256 + # repeat 4 times + tik_inst.data_move(mat_l1_ldw_nz[0, 0, 0, 0], + mat_ub_lg_d_nz[0, 0, 0, 0], + 0, block_size // 16, block_size, global_size, 0) + # apply for tensor in L1 for q (128, 64) nZ + mat_l1_q_b = tik_inst.Tensor("float16", + (size_per_head // 16, + block_size // 16, 16, 16), + name="mat_l1_q_b", + scope=tik.scope_cbuf) + # move local q from gm to l1 + # the shape of local q in gm is zN + # the shape of local dw in l1 is zZ + # the stride between each (16, 16) is 48 + # repeat 4 times + # LOOP 8 times + with tik_inst.for_range(0, size_per_head // 16) as lb: + tik_inst.load2dv1(mat_l1_q_b[lb, 0, 0, 0], + mat_q[head * size_per_head // 16 + lb, + b * seq_len // 16 + (block * block_size) // 16, 0, 0], + 0, block_size // 16, 1, 0, False) + # apply for tensor in L0C for local dk (128, 64) zN + mat_l0c_dkl = tik_inst.Tensor("float32", + (block_size // 16, + size_per_head // 16, 16, 16), + name="mat_l0c_dkl", + scope=tik.scope_cc) + # apply for tensor in ub for local dk (128, 64) zN + mat_ub_ldk = tik_inst.Tensor("float32", + (block_size // 16, + size_per_head // 16, 16, 16), + name="mat_ub_ldk", + scope=tik.scope_ubuf) + with tik_inst.new_stmt_scope(): + # apply for tensor in L0A for q (128, 64) zZ + mat_l0a_q = tik_inst.Tensor('float16', (size_per_head // 16, block_size // 16, 16, 16), + name='mat_l0a_q', scope=tik.scope_ca) + # apply for tensor in L0B for local dw (64, 64) nZ + mat_l0b_ldw = tik_inst.Tensor('float16', (block_size // 16, block_size // 16, 16, 16), + name='mat_l0b_ldw', scope=tik.scope_cb) + # move q from l1 to L0A for CUBE mmad + # the shape of q in l1 is nZ + # the shape of q in L0A is zZ + # the stride between each (16, 16) is 0 + # repeat 4 times + # LOOP 8 times + tik_inst.load2dv1(mat_l0a_q[0, 0, 0, 0], + mat_l1_q_b[0, 0, 0, 0], + 0, block_size * size_per_head // 256, 1, 0, True) + # move local dw from l1 to L0B for CUBE mmad + # the shape of local dw in l1 is nZ + # the shape of local dw in L0B is nZ + # the stride between each (16, 16) is 0 + # repeat 32 times + tik_inst.load2dv1(mat_l0b_ldw[0, 0, 0, 0], mat_l1_ldw_nz[0, 0, 0, 0], 0, + block_size * block_size // (16 * 16), 1, 0, False) + # matmul q and local dw + # the shape of local k in L0C is zN + tik_inst.mmad(mat_l0c_dkl, mat_l0a_q, mat_l0b_ldw, + size_per_head, block_size, block_size, 0) + # move local dk from l0c to UB + # the shape of local dk in l0C is zN + # the shape of local dk in UB is zN + # the stride between each (16, 128) is 0 + # repeat 4 times + tik_inst.data_move(mat_ub_ldk[0, 0, 0, 0], mat_l0c_dkl[0, 0, 0, 0], 0, block_size // 16, + size_per_head // 16, 0, 0) + + # move global dw from UB to l1 + # apply for tensor in L1 for global dw (64, 256) nZ + mat_l1_dwg_b = tik_inst.Tensor("float16", + (block_size // 16, + global_size // 16, 16, 16), + name="mat_l1_dwg_b", + scope=tik.scope_cbuf) + # move global dw from UB to L1 + # the shape of global dw in gm is nZ + # the shape of global dw in gm is nZ + # the stride between each (16, 64) is 0 + # repeat 8 times + tik_inst.data_move(mat_l1_dwg_b[0, 0, 0, 0], + mat_ub_lg_d_nz[0, block_size // 16, 0, 0], + 0, block_size // 16, global_size, block_size, 0) + + with tik_inst.new_stmt_scope(): + # apply for tensor in L0A for q (128, 64) zZ + mat_l0a_q = tik_inst.Tensor('float16', (size_per_head // 16, block_size // 16, 16, 16), + name='mat_l0a_q', scope=tik.scope_ca) + # apply for tensor in L0B for local dw (64, 64) nZ + mat_l0b_gdw = tik_inst.Tensor('float16', (block_size // 16, global_size // 16, 16, 16), + name='mat_l0b_ldw', scope=tik.scope_cb) + # move q from l1 to L0A for CUBE mmad + # the shape of q in l1 is nZ + # the shape of q in L0A is zZ + # the stride between each (16, 16) is 0 + # repeat 4 times + # LOOP 8 times + tik_inst.load2dv1(mat_l0a_q[0, 0, 0, 0], + mat_l1_q_b[0, 0, 0, 0], + 0, block_size * size_per_head // 256, 1, 0, True) + # move local dw from l1 to L0B for CUBE mmad + # the shape of local dw in l1 is nZ + # the shape of local dw in L0B is nZ + # the stride between each (16, 16) is 0 + # repeat 32 times + tik_inst.load2dv1(mat_l0b_gdw[0, 0, 0, 0], mat_l1_dwg_b[0, 0, 0, 0], 0, + block_size * global_size // (16 * 16), 1, 0, False) + # matmul q and local dw + # the shape of local k in L0C is zN + with tik_inst.if_scope(block == 0): + tik_inst.mmad(mat_l0c_dkg, mat_l0a_q, mat_l0b_gdw, + size_per_head, block_size, global_size, 0) + with tik_inst.else_scope(): + tik_inst.mmad(mat_l0c_dkg, mat_l0a_q, mat_l0b_gdw, + size_per_head, block_size, global_size, 1) + + # cast dq from 32 to 16 + # apply for tensor in ub for dq (64, 128) zN + mat_ub_dq_16 = tik_inst.Tensor("float16", + (size_per_head // 16, + block_size // 16, 16, 16), + name="mat_ub_dq_16", + scope=tik.scope_ubuf) + # apply for tensor in ub for local dk (128, 64) zN + mat_ub_ldk_16 = tik_inst.Tensor("float16", + (block_size // 16, + size_per_head // 16, 16, 16), + name="mat_ub_ldk_16", + scope=tik.scope_ubuf) + tik_inst.vec_conv( + 64, "", mat_ub_ldk_16[0, 0, 0, 0], mat_ub_ldk[0, 0, 0, 0], size_per_head * block_size // 64, 4, 8) + tik_inst.vec_conv( + 64, "", mat_ub_dq_16[0, 0, 0, 0], mat_ub_dq[0, 0, 0, 0], size_per_head * block_size // 64, 4, 8) + + # move dq from UB to gm + # the shape of dq in UB is zN + # the shape of dq in gm is zN + # the stride between each (16, 64) is 0 + # repeat 8 times + tik_inst.data_move(mat_dq[head * size_per_head // 16, + b * seq_len // 16 + (block * block_size) // 16, 0, 0], + mat_ub_dq_16[0, 0, 0, + 0], 0, size_per_head // 16, block_size, 0, + bs * seq_len - block_size) + # move local dk from UB to gm + # the shape of local dk in UB is zN + # the shape of local dk in gm is zN + # the stride between each (16, 64) is 0 + # repeat 8 times + tik_inst.data_move(mat_dk[b * seq_len // 16 + (block * block_size) // 16, + head * size_per_head // 16, 0, 0], + mat_ub_ldk_16[0, 0, 0, + 0], 0, block_size // 16, size_per_head, 0, + heads * size_per_head - size_per_head) + with tik_inst.for_range(0, global_size // 16) as lb: + # apply for tensor in ub for global dk (128, 16) zN + mat_ub_gdk_32 = tik_inst.Tensor("float32", + (1, size_per_head // 16, 16, 16), + name="mat_ub_gdk", + scope=tik.scope_ubuf) + # apply for tensor in ub for global dk (128, 16) zN + mat_ub_gdk = tik_inst.Tensor("float16", + (1, size_per_head // 16, 16, 16), + name="mat_ub_gdk", + scope=tik.scope_ubuf) + # apply for tensor in ub for global dk (128, 16) zN + mat_ub_ldk2 = tik_inst.Tensor("float16", + (1, size_per_head // 16, 16, 16), + name="mat_ub_ldk2", + scope=tik.scope_ubuf) + # move global dk from l0c to UB + # the shape of global dk in l0C is zN + # the shape of global dk in UB is zN + # the stride between each (16, 128) is 0 + # repeat 1 times + tik_inst.data_move(mat_ub_gdk_32[0, 0, 0, 0], mat_l0c_dkg[lb, 0, 0, 0], 0, 1, + size_per_head // 16, 0, 0) + tik_inst.vec_conv( + 64, "", mat_ub_gdk[0, 0, 0, 0], mat_ub_gdk_32[0, 0, 0, 0], size_per_head * 16 // 64, 4, 8) + # move local dk from gm to UB + # the shape of local dk in gm is zN + # the shape of local dk in UB is zN + # the stride between each (16, 128) is 0 + # repeat 1 times + tik_inst.data_move(mat_ub_ldk2[0, 0, 0, 0], mat_dk[b * seq_len // 16 + 4 * lb + global_idx, + head * size_per_head // 16, 0, 0], 0, 1, + size_per_head, 0, 0) + # add local dk and global dk + mat_ub_dk = tik_inst.Tensor("float16", + (1, size_per_head // 16, 16, 16), + name="mat_ub_dk", + scope=tik.scope_ubuf) + tik_inst.vec_add(128, mat_ub_dk, mat_ub_ldk2, mat_ub_gdk, + size_per_head * 16 // 128, 8, 8, 8) + # move dk from UB to gm + # the shape of dk in UB is zN + # the shape of dk in gm is zN + # the stride between each (16, 128) is 0 + # repeat 1 times + tik_inst.data_move( + mat_dk[b * seq_len // 16 + 4 * lb + global_idx, + head * size_per_head // 16, 0, 0], + mat_ub_dk[0, 0, 0, 0], 0, 1, size_per_head, 0, 0) + tik_inst.BuildCCE(kernel_name=kernel_name, inputs=[mat_q, mat_k, mat_lc, mat_gc, mat_lc_grad, mat_gc_grad], + outputs=[mat_dq, mat_dk]) + return tik_inst diff --git a/mindspore/python/mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py b/mindspore/python/mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py index b43b509068a..a3fb9467341 100644 --- a/mindspore/python/mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +++ b/mindspore/python/mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py @@ -1,488 +1,488 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""matmul dds impl""" -from te import tik -from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType - -matmul_dds_op_info = TBERegOp("MatmulDDS") \ - .fusion_type("OPAQUE") \ - .async_flag(False) \ - .binfile_name("matmul_dds.so") \ - .compute_cost(10) \ - .kernel_name("matmul_dds") \ - .partial_flag(True) \ - .attr("bs", "required", "int", "all") \ - .attr("heads", "required", "int", "all") \ - .input(0, "q", False, "required", "all") \ - .input(1, "k", False, "required", "all") \ - .input(2, "local_mask", False, "required", "all") \ - .input(3, "global_mask", False, "required", "all") \ - .output(0, "local_prob", False, "required", "all") \ - .output(1, "global_prob", False, "required", "all") \ - .dtype_format(DataType.F16_Default, DataType.F16_Default, - DataType.F32_Default, DataType.F32_Default, - DataType.F16_Default, DataType.F16_Default) \ - .get_op_info() - - -@op_info_register(matmul_dds_op_info) -def matmul_dds(q, - k, - local_mask, - global_mask, - local_prob, - global_prob, - bs, - heads, - kernel_name="matmul_dds"): - """ - :param q: the dict of input q (bs*seq_len, embedding_size) zN - :param k: the dict of input k (bs*seq_len, embedding_size) nZ - :param bs: batch size int - :param heads: number of heads int - :param local_mask: the dict of input mask local (bs*16*64, 64) zN - :param global_mask: the dict of input mask global (heads*1024, 256) zN - :param kernel_name: dds_softmax - :return: None - """ - - shape_q = q.get( - 'shape') # shape_q (embedding_size, bs*seq_length) > (embedding_size//16, bs*seq_length//16, 16, 16) zN - shape_local_mask = local_mask.get( - 'shape') # shape_local_mask (16*64, bs*64) > (64, bs*4, 16, 16) zN - # sequence length only support 1024 for now - seq_len = shape_q[1] * shape_q[2] // bs - # size per head assume 128 - size_per_head = shape_q[0] * shape_q[-1] // heads - block_size = shape_local_mask[0] # block size only support 64 for now - block_num = seq_len // block_size # block number only support 16 for now - global_size = seq_len // 4 # global size only support 256 for now - - tik_inst = tik.Tik(tik.Dprofile('v100', 'cloud')) - - mat_q = tik_inst.Tensor("float16", (size_per_head * heads // 16, bs * seq_len // 16, 16, 16), - name="mat_q", - scope=tik.scope_gm) # zN - mat_k = tik_inst.Tensor("float16", (size_per_head * heads // 16, bs * seq_len // 16, 16, 16), - name="mat_k", - scope=tik.scope_gm) # nZ - mat_lm = tik_inst.Tensor("float32", (block_num * block_size // 16, bs * block_size // 16, 16, 16), - name="mat_lm", - scope=tik.scope_gm) # zN - mat_gm = tik_inst.Tensor("float32", (bs * global_size // 16, seq_len // 16, 16, 16), - name="mat_gm", - scope=tik.scope_gm) # zN - mat_lc = tik_inst.Tensor("float16", (bs, heads, block_num, block_size // 16, block_size // 16, 16, 16), - name="mat_lc", - scope=tik.scope_gm) # zN - mat_gc = tik_inst.Tensor("float16", (bs, heads, block_num, global_size // 16, block_size // 16, 16, 16), - name="mat_gc", - scope=tik.scope_gm) # zN - - channel_num = bs * heads - - with tik_inst.for_range(0, channel_num, block_num=channel_num) as block_index: - # apply for tensor in L1 for fp 16 ones-like result (16, 320) zZ - mat_l1_ones = tik_inst.Tensor("float16", (1, (global_size + block_size) // 16, 16, 16), - name='mat_l1_ones', - scope=tik.scope_cbuf) - - with tik_inst.new_stmt_scope(): - mat_ub_ones = tik_inst.Tensor("float16", (1, (global_size + block_size) // 16, 16, 16), - name='mat_ub_ones', - scope=tik.scope_ubuf) - tik_inst.vec_dup(128, mat_ub_ones, 1.0, - (global_size + block_size) * 16 // 128, 8) - tik_inst.data_move(mat_l1_ones[0, 0, 0, 0], mat_ub_ones[0, 0, 0, 0], - 0, (global_size + block_size) // 16, 16, 0, 0) - - b = tik_inst.Scalar(dtype="int32") - b.set_as(block_index // heads) - - head = tik_inst.Scalar(dtype="int32") - head.set_as(block_index - b * heads) - s = tik_inst.Scalar(dtype="int32") - s.set_as(head // 4) - global_idx = tik_inst.Scalar(dtype="int32") - global_idx.set_as(3 - (head - 4 * s)) - # apply tensor for global key which is (128, 256) in L1 nZ - # for each head, global k is the same, put global k in L1 in order of reuse - mat_l1_gk = tik_inst.Tensor("float16", - (size_per_head // 16, - global_size // 16, 16, 16), - name="mat_l1_gk", - scope=tik.scope_cbuf) - with tik_inst.for_range(0, size_per_head // 16) as gb: - # move global key from gm to L1 nZ - # the shape of k is nZ, move (16, 256) in one loop, the stride between each (16, 16) is 3*(16,16) - tik_inst.data_move(mat_l1_gk[gb, 0, 0, 0], - mat_k[head * size_per_head // 16 + gb, - b * seq_len // 16 + global_idx, 0, 0], - 0, block_num, 16, 48, 0) - - with tik_inst.for_range(0, block_num) as block: - # calculate qk matmul block by block - - # apply tensor in l0c for local mask (64, 64) zN - mat_l0c_l = tik_inst.Tensor("float32", (block_size // 16, block_size // 16, 16, 16), - name='mat_l0c_l', - scope=tik.scope_cc) - # apply tensor in l0c for global mask (256, 64) zN - mat_l0c_g = tik_inst.Tensor("float32", (global_size // 16, block_size // 16, 16, 16), - name='mat_l0c_g', - scope=tik.scope_cc) - # apply tensor in l1 for local k (128, 64) nZ - mat_l1_lk = tik_inst.Tensor("float16", - (size_per_head // 16, - block_size // 16, 16, 16), - name="mat_l1_lk", - scope=tik.scope_cbuf) - # apply for tensor in L1 for fp 16 exp result (320, 64) zN - mat_l1_lg_exp_16 = tik_inst.Tensor("float16", ((global_size + block_size) // 16, - block_size // 16, 16, 16), - name='mat_l1_lg_exp_16', - scope=tik.scope_cbuf) - # convert exp out to fp 16 - # apply for tensor in UB for fp 16 exp result (64, 320) zN - mat_ub_lg_exp_16 = tik_inst.Tensor("float16", ((global_size + block_size) // 16, - block_size // 16, 16, 16), - name='mat_ub_lg_exp_16', - scope=tik.scope_ubuf) - # move local k from gm to l1 nZ - # the shape of local k in gm is nZ - # the shape of local k in l1 is nZ - # the stride between each (16, 64) is 1024*bs-64 - # repeat 8 times - tik_inst.data_move(mat_l1_lk, - mat_k[head * size_per_head // 16, b * seq_len // 16 + ( - block * block_size) // 16, 0, 0], - 0, size_per_head // 16, block_size, bs * seq_len - block_size, 0) - # apply tensor in l1 for q (64, 128) zN - mat_l1_q = tik_inst.Tensor("float16", - (block_size // 16, - size_per_head // 16, 16, 16), - name="mat_l1_q", - scope=tik.scope_cbuf) - # move q from gm to l1 - # the shape of local k in gm is zN - # the shape of local k in l1 is zZ - # the stride between each (16, 16) is 1024*bs-64 - # repeat 8 times - # LOOP 4 times - with tik_inst.for_range(0, block_size // 16) as lb: - tik_inst.data_move(mat_l1_q[lb, 0, 0, 0], - mat_q[head * size_per_head // 16, b * seq_len // 16 + ( - block * block_size) // 16 + lb, 0, 0], - 0, size_per_head // 16, 16, bs * seq_len - 16, 0) - - # global - # apply a new scope - with tik_inst.new_stmt_scope(): - # apply tensor in ub for global mask (256, 64) zN - mat_ub_gm = tik_inst.Tensor("float32", (global_size // 16, block_size // 16, 16, 16), - name='mat_ub_gm', - scope=tik.scope_ubuf) - # move global mask from gm to ub zN - # the shape of global mask in gm is zN - # the shape of global mask in UB is zN - # the stride between each (64, 16) is 960 - # repeat 16 times - tik_inst.data_move(mat_ub_gm, - mat_gm[b * global_size // 16, - block * block_size // 16, 0, 0], - 0, global_size // 16, block_size * 2, seq_len * 2 - block_size * 2, 0) - # move global mask from ub to l0c for bias add - # the shape of global mask in ub is zN - # the shape of global mask in l0c is zN - # the stride between each (16, 64) is 0 - # repeat 16 times - tik_inst.data_move(mat_l0c_g[0, 0, 0, 0], - mat_ub_gm[0, 0, 0, 0], - 0, global_size // 16, block_size // 16, 0, 0) - with tik_inst.for_range(0, 4, thread_num=2) as gb: - # apply for tensor in L0A for q (64, 128) zZ - mat_l0a_g = tik_inst.Tensor('float16', - (block_size // 16, size_per_head // - (16 * 4), 16, 16), - name='mat_l0a_g', scope=tik.scope_ca) - # apply for tensor in L0B for global k (128, 256) nZ - mat_l0b_g = tik_inst.Tensor('float16', - (size_per_head // (16 * 4), - global_size // 16, 16, 16), - name='mat_l0b_g', scope=tik.scope_cb) - # move q from l1 to L0A for CUBE mmad - # the shape of q in l1 is zZ - # the shape of q in L0A is zZ - # the stride between each (16, 16) is 0 - # repeat 32 times - with tik_inst.for_range(0, block_size // 16) as bl: - tik_inst.load2dv1(mat_l0a_g[bl, 0, 0, 0], mat_l1_q[bl, size_per_head * gb // 64, 0, 0], 0, - 16 * size_per_head // (4 * 16 * 16), 1, 0, False) - # move global k from l1 to L0B for CUBE mmad - # the shape of global k in l1 is nZ - # the shape of global k in L0B is nZ - # the stride between each (16, 16) is 0 - # repeat 128 times - tik_inst.load2dv1(mat_l0b_g[0, 0, 0, 0], mat_l1_gk[size_per_head * gb // 64, 0, 0, 0], 0, - global_size * size_per_head // (4 * 16 * 16), 1, 0, False) - # matmul q and global k - # the shape of global scores in L0C is zN - tik_inst.mmad(mat_l0c_g, mat_l0a_g, mat_l0b_g, - block_size, size_per_head // 4, global_size, 1) - - # local - # apply a new scope - with tik_inst.new_stmt_scope(): - # apply tensor in ub for local mask (64, 64) zN - mat_ub_lm = tik_inst.Tensor("float32", (block_size // 16, block_size // 16, 16, 16), - name='mat_ub_lm', - scope=tik.scope_ubuf) - # move local mask from gm to ub zN - # the shape of local mask in gm is zN - # the shape of local mask in UB is zN - # the stride between each (64, 16) is 0 - # repeat 4 times - tik_inst.data_move(mat_ub_lm, - mat_lm[block * block_size // 16, - b * block_size // 16, 0, 0], - 0, block_size // 16, block_size * 2, (bs * block_size - block_size) * 2, 0) - # move local mask from ub to l0c for bias add - # the shape of local mask in ub is zN - # the shape of local mask in l0c is zN - # the stride between each (16, 64) is 0 - # repeat 4 times - tik_inst.data_move(mat_l0c_l[0, 0, 0, 0], - mat_ub_lm[0, 0, 0, 0], - 0, block_size // 16, block_size // 16, 0, 0) - with tik_inst.for_range(0, 4, thread_num=2) as gb: - # apply for tensor in L0A for q (64, 128) zZ - mat_l0a_l = tik_inst.Tensor('float16', (block_size // 16, size_per_head // (16 * 4), 16, 16), - name='mat_l0a_l', scope=tik.scope_ca) - # apply for tensor in L0B for local k (128, 64) nZ - mat_l0b_l = tik_inst.Tensor('float16', (size_per_head // (16 * 4), block_size // 16, 16, 16), - name='mat_l0b_l', scope=tik.scope_cb) - # move q from l1 to L0A for CUBE mmad - # the shape of q in l1 is zZ - # the shape of q in L0A is zZ - # the stride between each (16, 16) is 0 - # repeat 32 times - with tik_inst.for_range(0, block_size // 16) as bl: - tik_inst.load2dv1(mat_l0a_l[bl, 0, 0, 0], mat_l1_q[bl, size_per_head * gb // 64, 0, 0], 0, - 16 * size_per_head // (4 * 16 * 16), 1, 0, False) - # move local k from l1 to L0B for CUBE mmad - # the shape of local k in l1 is nZ - # the shape of local k in L0B is nZ - # the stride between each (16, 16) is 0 - # repeat 32 times - tik_inst.load2dv1(mat_l0b_l[0, 0, 0, 0], mat_l1_lk[size_per_head * gb // 64, 0, 0, 0], 0, - block_size * size_per_head // (16 * 16 * 4), 1, 0, False) - # matmul q and local k - # the shape of local scores in L0C is (64, 64) zN - tik_inst.mmad(mat_l0c_l, mat_l0a_l, mat_l0b_l, - block_size, size_per_head // 4, block_size, 1) - - with tik_inst.new_stmt_scope(): - with tik_inst.for_range(0, block_size // 16, thread_num=2) as gb: - mat_ub_lg = tik_inst.Tensor("float32", (1, (block_size + global_size) // 16, 16, 16), - name='mat_ub_lg', - scope=tik.scope_ubuf) - tik_inst.data_move(mat_ub_lg[0, 0, 0, 0], mat_l0c_g[0, gb, 0, 0], 0, - global_size // 16, 1, block_size // 16 - 1, 0) - tik_inst.data_move(mat_ub_lg[0, global_size // 16, 0, 0], mat_l0c_l[0, gb, 0, 0], 0, - block_size // 16, 1, block_size // 16 - 1, 0) - mat_ub_lg_16 = tik_inst.Tensor("float16", (1, (block_size + global_size) // 16, 16, 16), - name='mat_ub_lg_16', - scope=tik.scope_ubuf) - tik_inst.vec_conv(64, "", mat_ub_lg_16[0, 0, 0, 0], - mat_ub_lg[0, 0, 0, 0], - (block_size + global_size) * 16 // 64, 4, 8) - with tik_inst.for_range(0, 16) as lb: - mat_ub_lg_lb = tik_inst.Tensor("float16", (block_size + global_size,), - name='mat_ub_lg_lb', - scope=tik.scope_ubuf) - mat_ub_lg_lb_subs = tik_inst.Tensor("float16", (block_size + global_size,), - name='mat_ub_lg_lb_subs', - scope=tik.scope_ubuf) - - tik_inst.data_move(mat_ub_lg_lb[0], mat_ub_lg_16[0, 0, lb, 0], 0, - (block_size + global_size) // 16, 1, 15, 0) - max_value = tik_inst.Scalar("float16", - name='max_value', - init_value=0) - with tik_inst.for_range(0, (block_size + global_size) // 64) as nb: - mat_ub_lg_max = tik_inst.Tensor("float16", (2,), - name='mat_ub_lg_max', - scope=tik.scope_ubuf) - tik_inst.vcmax(64, mat_ub_lg_max[0], mat_ub_lg_lb[64 * nb], 1, - 1, 1, 4) - mat_ub_lg_max_sub = tik_inst.Tensor("float16", (2,), - name='mat_ub_lg_max_sub', - scope=tik.scope_ubuf) - tik_inst.vmuls( - 2, mat_ub_lg_max_sub[0], mat_ub_lg_max[0], -1.0, 1, 1, 1, 1, 1) - block_max_value = tik_inst.Scalar("float16", - name='block_max_value', - init_value=0) - block_max_value.set_as(mat_ub_lg_max_sub[0]) - max_value_int8 = tik_inst.Scalar("int8", - name='max_value_int8', - init_value=0) - max_value_int = tik_inst.Tensor("int8", (1,), - name='max_value_int', - scope=tik.scope_ubuf) - max_value_fp16 = tik_inst.Tensor("float16", (1,), - name='max_value_fp16', - scope=tik.scope_ubuf) - max_value_fp16[0].set_as(max_value) - block_max_value_int = tik_inst.Tensor("int8", (1,), - name='block_max_value_int', - scope=tik.scope_ubuf) - block_max_value_int8 = tik_inst.Scalar("int8", - name='block_max_value_int8', - init_value=0) - tik_inst.vec_conv( - 1, "", max_value_int, max_value_fp16[0], 1, 1, 1) - tik_inst.vec_conv( - 1, "", block_max_value_int, mat_ub_lg_max_sub[0], 1, 1, 1) - max_value_int8.set_as(max_value_int[0]) - block_max_value_int8.set_as(block_max_value_int[0]) - with tik_inst.if_scope(block_max_value_int8 < max_value_int8): - max_value.set_as(block_max_value) - with tik_inst.else_scope(): - block_max_value.set_as(max_value) - tik_inst.vadds(64, mat_ub_lg_lb_subs[0], mat_ub_lg_lb[0], - max_value, (block_size + global_size) // 64, 1, 1, 4, 4) - mat_ub_lg_exp_lb = tik_inst.Tensor("float16", (block_size + global_size,), - name='mat_ub_lg_exp_lb', - scope=tik.scope_ubuf) - tik_inst.vexp(64, mat_ub_lg_exp_lb[0], - mat_ub_lg_lb_subs[0], (block_size + global_size) // 64, 1, 1, 4, 4) - tik_inst.data_move(mat_ub_lg_exp_16[0, gb, lb, 0], mat_ub_lg_exp_lb[0], 0, - (block_size + global_size) // 16, 1, 0, block_size - 1) - - # move exp fp16 from ub to L1 for CUBE mmad - # the shape of exp fp16 in ub is zN - # the shape of exp fp16 in L1 is zN - # the stride between each (16, 16) is 0 - # repeat 4 times - tik_inst.data_move(mat_l1_lg_exp_16[0, 0, 0, 0], mat_ub_lg_exp_16[0, 0, 0, 0], - 0, (global_size + block_size) // 16, block_size, 0, 0) - # apply for tensor in UB for local attention out (64, 64) zN - mat_ub_l_out = tik_inst.Tensor("float16", (block_size // 16, block_size // 16, 16, 16), - name='mat_ub_l_out', - scope=tik.scope_ubuf) - # apply for tensor in UB for global attention out (64, 256) zN - mat_ub_g_out = tik_inst.Tensor("float16", (global_size // 16, block_size // 16, 16, 16), - name='mat_ub_g_out', - scope=tik.scope_ubuf) - # apply tensor in l0c for exp sum (16, 64) zN - mat_l0c_exp = tik_inst.Tensor("float32", (block_size // 16, 1, 16, 16), - name='mat_l0c_exp', - scope=tik.scope_cc) - # apply tensor in ub for exp sum (16, 64) zN - mat_ub_exp_sum = tik_inst.Tensor("float32", (block_size // 16, 1, 16, 16), - name='mat_ub_exp_sum', - scope=tik.scope_ubuf) - - with tik_inst.new_stmt_scope(): - with tik_inst.for_range(0, 4, thread_num=2) as gb: - # apply for tensor in L0A for q (64, 128) zZ - mat_l0a_ones = tik_inst.Tensor('float16', (1, (global_size + block_size) // 64, 16, 16), - name='mat_l0a_ones', scope=tik.scope_ca) - # apply for tensor in L0B for exp (350, 64) nZ - mat_l0b_exp = tik_inst.Tensor('float16', - ((global_size + block_size) // - 64, block_size // 16, 16, 16), - name='mat_l0b_exp', scope=tik.scope_cb) - # move ones from l1 to L0A for CUBE mmad - # the shape of ones in l1 is zZ - # the shape of ones in L0A is zZ - # the stride between each (16, 16) is 0 - # repeat 32 times - tik_inst.load2dv1(mat_l0a_ones[0, 0, 0, 0], mat_l1_ones[0, 0, 0, 0], 0, - (global_size + block_size) * 16 // (4 * 16 * 16), 1, 0, False) - # move global k from l1 to L0B for CUBE mmad - # the shape of global k in l1 is nZ - # the shape of global k in L0B is nZ - # the stride between each (16, 16) is 0 - # repeat 128 times - tik_inst.load2dv1(mat_l0b_exp[0, 0, 0, 0], - mat_l1_lg_exp_16[(global_size + block_size) * gb // 64, 0, 0, 0], 0, - (global_size + block_size) * block_size // (4 * 16 * 16), 1, 0, False) - with tik_inst.if_scope(gb == 0): - tik_inst.mmad(mat_l0c_exp, mat_l0a_ones, mat_l0b_exp, 16, - (global_size + block_size) // 4, block_size, 0) - with tik_inst.else_scope(): - tik_inst.mmad(mat_l0c_exp, mat_l0a_ones, mat_l0b_exp, 16, - (global_size + block_size) // 4, block_size, 1) - - tik_inst.data_move(mat_ub_exp_sum[0, 0, 0, 0], mat_l0c_exp[0, 0, 0, 0], 0, - block_size // 16, 1, 0, 0) - # apply for tensor in UB for global prob sum (64,) - mat_ub_lg_exp_sum = tik_inst.Tensor("float32", (block_size,), - name='mat_ub_lg_exp_sum', - scope=tik.scope_ubuf) - tik_inst.data_move(mat_ub_lg_exp_sum[0], mat_ub_exp_sum[0, 0, 0, 0], - 0, block_size // 16, 1 * 2, 15 * 2, 0) - # apply for tensor in UB for attention prob sum rec (64,) - mat_ub_exp_sum_rec = tik_inst.Tensor("float32", (block_size,), - name='mat_ub_exp_sum_rec', - scope=tik.scope_ubuf) - mat_ub_exp_sum_rec_16 = tik_inst.Tensor("float16", (block_size,), - name='mat_ub_exp_sum_rec_16', - scope=tik.scope_ubuf) - worker_tensor = tik_inst.Tensor("float32", (block_size * 2,), - name='worker_tensor', - scope=tik.scope_ubuf) - # calculate attention prob sum vec (64,) - tik_inst.vec_rec_high_preci( - 64, mat_ub_exp_sum_rec, mat_ub_lg_exp_sum, worker_tensor, 1, 8, 8) - tik_inst.vec_conv(block_size, "", mat_ub_exp_sum_rec_16[0], - mat_ub_exp_sum_rec[0], - block_size // 64, 4, 8) - with tik_inst.for_range(0, block_size) as bbs: - # apply for scalar in UB for prob sum rec - sum_exp = tik_inst.Scalar("float16", - name='sum_exp', - init_value=0) - # set value for scalar prob sum rec - sum_exp.set_as(mat_ub_exp_sum_rec_16[bbs]) - tik_inst.vec_muls(16, mat_ub_l_out[0, bbs // 16, bbs % 16, 0], - mat_ub_lg_exp_16[global_size // - 16, bbs // 16, bbs % 16, 0], - sum_exp, block_size // 16, - block_size, block_size) - tik_inst.vec_muls(16, mat_ub_g_out[0, bbs // 16, bbs % 16, 0], - mat_ub_lg_exp_16[0, bbs // 16, bbs % - 16, 0], sum_exp, global_size // 16, - block_size, block_size) - # move local out from UB to gm - # the shape of local out in UB is zN - # the shape of local out in gm is zN - # the stride between each (16, 64) is 0 - # repeat 4 times - tik_inst.data_move(mat_lc[b, head, block, 0, 0, 0, 0], mat_ub_l_out[0, 0, 0, 0], 0, - block_size // 16, block_size, 0, 0) - # move global out from UB to gm - # the shape of global out in UB is zN - # the shape of global out in gm is zN - # the stride between each (16, 64) is 0 - # repeat 16 times - tik_inst.data_move(mat_gc[b, head, block, 0, 0, 0, 0], mat_ub_g_out[0, 0, 0, 0], 0, - global_size // 16, block_size, 0, 0) - - tik_inst.BuildCCE(kernel_name=kernel_name, inputs=[mat_q, mat_k, mat_lm, mat_gm], - outputs=[mat_lc, mat_gc]) - return tik_inst +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""matmul dds impl""" +from te import tik +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +matmul_dds_op_info = TBERegOp("MatmulDDS") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("matmul_dds.so") \ + .compute_cost(10) \ + .kernel_name("matmul_dds") \ + .partial_flag(True) \ + .attr("bs", "required", "int", "all") \ + .attr("heads", "required", "int", "all") \ + .input(0, "q", False, "required", "all") \ + .input(1, "k", False, "required", "all") \ + .input(2, "local_mask", False, "required", "all") \ + .input(3, "global_mask", False, "required", "all") \ + .output(0, "local_prob", False, "required", "all") \ + .output(1, "global_prob", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, + DataType.F32_Default, DataType.F32_Default, + DataType.F16_Default, DataType.F16_Default) \ + .get_op_info() + + +@op_info_register(matmul_dds_op_info) +def matmul_dds(q, + k, + local_mask, + global_mask, + local_prob, + global_prob, + bs, + heads, + kernel_name="matmul_dds"): + """ + :param q: the dict of input q (bs*seq_len, embedding_size) zN + :param k: the dict of input k (bs*seq_len, embedding_size) nZ + :param bs: batch size int + :param heads: number of heads int + :param local_mask: the dict of input mask local (bs*16*64, 64) zN + :param global_mask: the dict of input mask global (heads*1024, 256) zN + :param kernel_name: dds_softmax + :return: None + """ + + shape_q = q.get( + 'shape') # shape_q (embedding_size, bs*seq_length) > (embedding_size//16, bs*seq_length//16, 16, 16) zN + shape_local_mask = local_mask.get( + 'shape') # shape_local_mask (16*64, bs*64) > (64, bs*4, 16, 16) zN + # sequence length only support 1024 for now + seq_len = shape_q[1] * shape_q[2] // bs + # size per head assume 128 + size_per_head = shape_q[0] * shape_q[-1] // heads + block_size = shape_local_mask[0] # block size only support 64 for now + block_num = seq_len // block_size # block number only support 16 for now + global_size = seq_len // 4 # global size only support 256 for now + + tik_inst = tik.Tik(tik.Dprofile('v100', 'cloud')) + + mat_q = tik_inst.Tensor("float16", (size_per_head * heads // 16, bs * seq_len // 16, 16, 16), + name="mat_q", + scope=tik.scope_gm) # zN + mat_k = tik_inst.Tensor("float16", (size_per_head * heads // 16, bs * seq_len // 16, 16, 16), + name="mat_k", + scope=tik.scope_gm) # nZ + mat_lm = tik_inst.Tensor("float32", (block_num * block_size // 16, bs * block_size // 16, 16, 16), + name="mat_lm", + scope=tik.scope_gm) # zN + mat_gm = tik_inst.Tensor("float32", (bs * global_size // 16, seq_len // 16, 16, 16), + name="mat_gm", + scope=tik.scope_gm) # zN + mat_lc = tik_inst.Tensor("float16", (bs, heads, block_num, block_size // 16, block_size // 16, 16, 16), + name="mat_lc", + scope=tik.scope_gm) # zN + mat_gc = tik_inst.Tensor("float16", (bs, heads, block_num, global_size // 16, block_size // 16, 16, 16), + name="mat_gc", + scope=tik.scope_gm) # zN + + channel_num = bs * heads + + with tik_inst.for_range(0, channel_num, block_num=channel_num) as block_index: + # apply for tensor in L1 for fp 16 ones-like result (16, 320) zZ + mat_l1_ones = tik_inst.Tensor("float16", (1, (global_size + block_size) // 16, 16, 16), + name='mat_l1_ones', + scope=tik.scope_cbuf) + + with tik_inst.new_stmt_scope(): + mat_ub_ones = tik_inst.Tensor("float16", (1, (global_size + block_size) // 16, 16, 16), + name='mat_ub_ones', + scope=tik.scope_ubuf) + tik_inst.vec_dup(128, mat_ub_ones, 1.0, + (global_size + block_size) * 16 // 128, 8) + tik_inst.data_move(mat_l1_ones[0, 0, 0, 0], mat_ub_ones[0, 0, 0, 0], + 0, (global_size + block_size) // 16, 16, 0, 0) + + b = tik_inst.Scalar(dtype="int32") + b.set_as(block_index // heads) + + head = tik_inst.Scalar(dtype="int32") + head.set_as(block_index - b * heads) + s = tik_inst.Scalar(dtype="int32") + s.set_as(head // 4) + global_idx = tik_inst.Scalar(dtype="int32") + global_idx.set_as(3 - (head - 4 * s)) + # apply tensor for global key which is (128, 256) in L1 nZ + # for each head, global k is the same, put global k in L1 in order of reuse + mat_l1_gk = tik_inst.Tensor("float16", + (size_per_head // 16, + global_size // 16, 16, 16), + name="mat_l1_gk", + scope=tik.scope_cbuf) + with tik_inst.for_range(0, size_per_head // 16) as gb: + # move global key from gm to L1 nZ + # the shape of k is nZ, move (16, 256) in one loop, the stride between each (16, 16) is 3*(16,16) + tik_inst.data_move(mat_l1_gk[gb, 0, 0, 0], + mat_k[head * size_per_head // 16 + gb, + b * seq_len // 16 + global_idx, 0, 0], + 0, block_num, 16, 48, 0) + + with tik_inst.for_range(0, block_num) as block: + # calculate qk matmul block by block + + # apply tensor in l0c for local mask (64, 64) zN + mat_l0c_l = tik_inst.Tensor("float32", (block_size // 16, block_size // 16, 16, 16), + name='mat_l0c_l', + scope=tik.scope_cc) + # apply tensor in l0c for global mask (256, 64) zN + mat_l0c_g = tik_inst.Tensor("float32", (global_size // 16, block_size // 16, 16, 16), + name='mat_l0c_g', + scope=tik.scope_cc) + # apply tensor in l1 for local k (128, 64) nZ + mat_l1_lk = tik_inst.Tensor("float16", + (size_per_head // 16, + block_size // 16, 16, 16), + name="mat_l1_lk", + scope=tik.scope_cbuf) + # apply for tensor in L1 for fp 16 exp result (320, 64) zN + mat_l1_lg_exp_16 = tik_inst.Tensor("float16", ((global_size + block_size) // 16, + block_size // 16, 16, 16), + name='mat_l1_lg_exp_16', + scope=tik.scope_cbuf) + # convert exp out to fp 16 + # apply for tensor in UB for fp 16 exp result (64, 320) zN + mat_ub_lg_exp_16 = tik_inst.Tensor("float16", ((global_size + block_size) // 16, + block_size // 16, 16, 16), + name='mat_ub_lg_exp_16', + scope=tik.scope_ubuf) + # move local k from gm to l1 nZ + # the shape of local k in gm is nZ + # the shape of local k in l1 is nZ + # the stride between each (16, 64) is 1024*bs-64 + # repeat 8 times + tik_inst.data_move(mat_l1_lk, + mat_k[head * size_per_head // 16, b * seq_len // 16 + ( + block * block_size) // 16, 0, 0], + 0, size_per_head // 16, block_size, bs * seq_len - block_size, 0) + # apply tensor in l1 for q (64, 128) zN + mat_l1_q = tik_inst.Tensor("float16", + (block_size // 16, + size_per_head // 16, 16, 16), + name="mat_l1_q", + scope=tik.scope_cbuf) + # move q from gm to l1 + # the shape of local k in gm is zN + # the shape of local k in l1 is zZ + # the stride between each (16, 16) is 1024*bs-64 + # repeat 8 times + # LOOP 4 times + with tik_inst.for_range(0, block_size // 16) as lb: + tik_inst.data_move(mat_l1_q[lb, 0, 0, 0], + mat_q[head * size_per_head // 16, b * seq_len // 16 + ( + block * block_size) // 16 + lb, 0, 0], + 0, size_per_head // 16, 16, bs * seq_len - 16, 0) + + # global + # apply a new scope + with tik_inst.new_stmt_scope(): + # apply tensor in ub for global mask (256, 64) zN + mat_ub_gm = tik_inst.Tensor("float32", (global_size // 16, block_size // 16, 16, 16), + name='mat_ub_gm', + scope=tik.scope_ubuf) + # move global mask from gm to ub zN + # the shape of global mask in gm is zN + # the shape of global mask in UB is zN + # the stride between each (64, 16) is 960 + # repeat 16 times + tik_inst.data_move(mat_ub_gm, + mat_gm[b * global_size // 16, + block * block_size // 16, 0, 0], + 0, global_size // 16, block_size * 2, seq_len * 2 - block_size * 2, 0) + # move global mask from ub to l0c for bias add + # the shape of global mask in ub is zN + # the shape of global mask in l0c is zN + # the stride between each (16, 64) is 0 + # repeat 16 times + tik_inst.data_move(mat_l0c_g[0, 0, 0, 0], + mat_ub_gm[0, 0, 0, 0], + 0, global_size // 16, block_size // 16, 0, 0) + with tik_inst.for_range(0, 4, thread_num=2) as gb: + # apply for tensor in L0A for q (64, 128) zZ + mat_l0a_g = tik_inst.Tensor('float16', + (block_size // 16, size_per_head // + (16 * 4), 16, 16), + name='mat_l0a_g', scope=tik.scope_ca) + # apply for tensor in L0B for global k (128, 256) nZ + mat_l0b_g = tik_inst.Tensor('float16', + (size_per_head // (16 * 4), + global_size // 16, 16, 16), + name='mat_l0b_g', scope=tik.scope_cb) + # move q from l1 to L0A for CUBE mmad + # the shape of q in l1 is zZ + # the shape of q in L0A is zZ + # the stride between each (16, 16) is 0 + # repeat 32 times + with tik_inst.for_range(0, block_size // 16) as bl: + tik_inst.load2dv1(mat_l0a_g[bl, 0, 0, 0], mat_l1_q[bl, size_per_head * gb // 64, 0, 0], 0, + 16 * size_per_head // (4 * 16 * 16), 1, 0, False) + # move global k from l1 to L0B for CUBE mmad + # the shape of global k in l1 is nZ + # the shape of global k in L0B is nZ + # the stride between each (16, 16) is 0 + # repeat 128 times + tik_inst.load2dv1(mat_l0b_g[0, 0, 0, 0], mat_l1_gk[size_per_head * gb // 64, 0, 0, 0], 0, + global_size * size_per_head // (4 * 16 * 16), 1, 0, False) + # matmul q and global k + # the shape of global scores in L0C is zN + tik_inst.mmad(mat_l0c_g, mat_l0a_g, mat_l0b_g, + block_size, size_per_head // 4, global_size, 1) + + # local + # apply a new scope + with tik_inst.new_stmt_scope(): + # apply tensor in ub for local mask (64, 64) zN + mat_ub_lm = tik_inst.Tensor("float32", (block_size // 16, block_size // 16, 16, 16), + name='mat_ub_lm', + scope=tik.scope_ubuf) + # move local mask from gm to ub zN + # the shape of local mask in gm is zN + # the shape of local mask in UB is zN + # the stride between each (64, 16) is 0 + # repeat 4 times + tik_inst.data_move(mat_ub_lm, + mat_lm[block * block_size // 16, + b * block_size // 16, 0, 0], + 0, block_size // 16, block_size * 2, (bs * block_size - block_size) * 2, 0) + # move local mask from ub to l0c for bias add + # the shape of local mask in ub is zN + # the shape of local mask in l0c is zN + # the stride between each (16, 64) is 0 + # repeat 4 times + tik_inst.data_move(mat_l0c_l[0, 0, 0, 0], + mat_ub_lm[0, 0, 0, 0], + 0, block_size // 16, block_size // 16, 0, 0) + with tik_inst.for_range(0, 4, thread_num=2) as gb: + # apply for tensor in L0A for q (64, 128) zZ + mat_l0a_l = tik_inst.Tensor('float16', (block_size // 16, size_per_head // (16 * 4), 16, 16), + name='mat_l0a_l', scope=tik.scope_ca) + # apply for tensor in L0B for local k (128, 64) nZ + mat_l0b_l = tik_inst.Tensor('float16', (size_per_head // (16 * 4), block_size // 16, 16, 16), + name='mat_l0b_l', scope=tik.scope_cb) + # move q from l1 to L0A for CUBE mmad + # the shape of q in l1 is zZ + # the shape of q in L0A is zZ + # the stride between each (16, 16) is 0 + # repeat 32 times + with tik_inst.for_range(0, block_size // 16) as bl: + tik_inst.load2dv1(mat_l0a_l[bl, 0, 0, 0], mat_l1_q[bl, size_per_head * gb // 64, 0, 0], 0, + 16 * size_per_head // (4 * 16 * 16), 1, 0, False) + # move local k from l1 to L0B for CUBE mmad + # the shape of local k in l1 is nZ + # the shape of local k in L0B is nZ + # the stride between each (16, 16) is 0 + # repeat 32 times + tik_inst.load2dv1(mat_l0b_l[0, 0, 0, 0], mat_l1_lk[size_per_head * gb // 64, 0, 0, 0], 0, + block_size * size_per_head // (16 * 16 * 4), 1, 0, False) + # matmul q and local k + # the shape of local scores in L0C is (64, 64) zN + tik_inst.mmad(mat_l0c_l, mat_l0a_l, mat_l0b_l, + block_size, size_per_head // 4, block_size, 1) + + with tik_inst.new_stmt_scope(): + with tik_inst.for_range(0, block_size // 16, thread_num=2) as gb: + mat_ub_lg = tik_inst.Tensor("float32", (1, (block_size + global_size) // 16, 16, 16), + name='mat_ub_lg', + scope=tik.scope_ubuf) + tik_inst.data_move(mat_ub_lg[0, 0, 0, 0], mat_l0c_g[0, gb, 0, 0], 0, + global_size // 16, 1, block_size // 16 - 1, 0) + tik_inst.data_move(mat_ub_lg[0, global_size // 16, 0, 0], mat_l0c_l[0, gb, 0, 0], 0, + block_size // 16, 1, block_size // 16 - 1, 0) + mat_ub_lg_16 = tik_inst.Tensor("float16", (1, (block_size + global_size) // 16, 16, 16), + name='mat_ub_lg_16', + scope=tik.scope_ubuf) + tik_inst.vec_conv(64, "", mat_ub_lg_16[0, 0, 0, 0], + mat_ub_lg[0, 0, 0, 0], + (block_size + global_size) * 16 // 64, 4, 8) + with tik_inst.for_range(0, 16) as lb: + mat_ub_lg_lb = tik_inst.Tensor("float16", (block_size + global_size,), + name='mat_ub_lg_lb', + scope=tik.scope_ubuf) + mat_ub_lg_lb_subs = tik_inst.Tensor("float16", (block_size + global_size,), + name='mat_ub_lg_lb_subs', + scope=tik.scope_ubuf) + + tik_inst.data_move(mat_ub_lg_lb[0], mat_ub_lg_16[0, 0, lb, 0], 0, + (block_size + global_size) // 16, 1, 15, 0) + max_value = tik_inst.Scalar("float16", + name='max_value', + init_value=0) + with tik_inst.for_range(0, (block_size + global_size) // 64) as nb: + mat_ub_lg_max = tik_inst.Tensor("float16", (2,), + name='mat_ub_lg_max', + scope=tik.scope_ubuf) + tik_inst.vcmax(64, mat_ub_lg_max[0], mat_ub_lg_lb[64 * nb], 1, + 1, 1, 4) + mat_ub_lg_max_sub = tik_inst.Tensor("float16", (2,), + name='mat_ub_lg_max_sub', + scope=tik.scope_ubuf) + tik_inst.vmuls( + 2, mat_ub_lg_max_sub[0], mat_ub_lg_max[0], -1.0, 1, 1, 1, 1, 1) + block_max_value = tik_inst.Scalar("float16", + name='block_max_value', + init_value=0) + block_max_value.set_as(mat_ub_lg_max_sub[0]) + max_value_int8 = tik_inst.Scalar("int8", + name='max_value_int8', + init_value=0) + max_value_int = tik_inst.Tensor("int8", (1,), + name='max_value_int', + scope=tik.scope_ubuf) + max_value_fp16 = tik_inst.Tensor("float16", (1,), + name='max_value_fp16', + scope=tik.scope_ubuf) + max_value_fp16[0].set_as(max_value) + block_max_value_int = tik_inst.Tensor("int8", (1,), + name='block_max_value_int', + scope=tik.scope_ubuf) + block_max_value_int8 = tik_inst.Scalar("int8", + name='block_max_value_int8', + init_value=0) + tik_inst.vec_conv( + 1, "", max_value_int, max_value_fp16[0], 1, 1, 1) + tik_inst.vec_conv( + 1, "", block_max_value_int, mat_ub_lg_max_sub[0], 1, 1, 1) + max_value_int8.set_as(max_value_int[0]) + block_max_value_int8.set_as(block_max_value_int[0]) + with tik_inst.if_scope(block_max_value_int8 < max_value_int8): + max_value.set_as(block_max_value) + with tik_inst.else_scope(): + block_max_value.set_as(max_value) + tik_inst.vadds(64, mat_ub_lg_lb_subs[0], mat_ub_lg_lb[0], + max_value, (block_size + global_size) // 64, 1, 1, 4, 4) + mat_ub_lg_exp_lb = tik_inst.Tensor("float16", (block_size + global_size,), + name='mat_ub_lg_exp_lb', + scope=tik.scope_ubuf) + tik_inst.vexp(64, mat_ub_lg_exp_lb[0], + mat_ub_lg_lb_subs[0], (block_size + global_size) // 64, 1, 1, 4, 4) + tik_inst.data_move(mat_ub_lg_exp_16[0, gb, lb, 0], mat_ub_lg_exp_lb[0], 0, + (block_size + global_size) // 16, 1, 0, block_size - 1) + + # move exp fp16 from ub to L1 for CUBE mmad + # the shape of exp fp16 in ub is zN + # the shape of exp fp16 in L1 is zN + # the stride between each (16, 16) is 0 + # repeat 4 times + tik_inst.data_move(mat_l1_lg_exp_16[0, 0, 0, 0], mat_ub_lg_exp_16[0, 0, 0, 0], + 0, (global_size + block_size) // 16, block_size, 0, 0) + # apply for tensor in UB for local attention out (64, 64) zN + mat_ub_l_out = tik_inst.Tensor("float16", (block_size // 16, block_size // 16, 16, 16), + name='mat_ub_l_out', + scope=tik.scope_ubuf) + # apply for tensor in UB for global attention out (64, 256) zN + mat_ub_g_out = tik_inst.Tensor("float16", (global_size // 16, block_size // 16, 16, 16), + name='mat_ub_g_out', + scope=tik.scope_ubuf) + # apply tensor in l0c for exp sum (16, 64) zN + mat_l0c_exp = tik_inst.Tensor("float32", (block_size // 16, 1, 16, 16), + name='mat_l0c_exp', + scope=tik.scope_cc) + # apply tensor in ub for exp sum (16, 64) zN + mat_ub_exp_sum = tik_inst.Tensor("float32", (block_size // 16, 1, 16, 16), + name='mat_ub_exp_sum', + scope=tik.scope_ubuf) + + with tik_inst.new_stmt_scope(): + with tik_inst.for_range(0, 4, thread_num=2) as gb: + # apply for tensor in L0A for q (64, 128) zZ + mat_l0a_ones = tik_inst.Tensor('float16', (1, (global_size + block_size) // 64, 16, 16), + name='mat_l0a_ones', scope=tik.scope_ca) + # apply for tensor in L0B for exp (350, 64) nZ + mat_l0b_exp = tik_inst.Tensor('float16', + ((global_size + block_size) // + 64, block_size // 16, 16, 16), + name='mat_l0b_exp', scope=tik.scope_cb) + # move ones from l1 to L0A for CUBE mmad + # the shape of ones in l1 is zZ + # the shape of ones in L0A is zZ + # the stride between each (16, 16) is 0 + # repeat 32 times + tik_inst.load2dv1(mat_l0a_ones[0, 0, 0, 0], mat_l1_ones[0, 0, 0, 0], 0, + (global_size + block_size) * 16 // (4 * 16 * 16), 1, 0, False) + # move global k from l1 to L0B for CUBE mmad + # the shape of global k in l1 is nZ + # the shape of global k in L0B is nZ + # the stride between each (16, 16) is 0 + # repeat 128 times + tik_inst.load2dv1(mat_l0b_exp[0, 0, 0, 0], + mat_l1_lg_exp_16[(global_size + block_size) * gb // 64, 0, 0, 0], 0, + (global_size + block_size) * block_size // (4 * 16 * 16), 1, 0, False) + with tik_inst.if_scope(gb == 0): + tik_inst.mmad(mat_l0c_exp, mat_l0a_ones, mat_l0b_exp, 16, + (global_size + block_size) // 4, block_size, 0) + with tik_inst.else_scope(): + tik_inst.mmad(mat_l0c_exp, mat_l0a_ones, mat_l0b_exp, 16, + (global_size + block_size) // 4, block_size, 1) + + tik_inst.data_move(mat_ub_exp_sum[0, 0, 0, 0], mat_l0c_exp[0, 0, 0, 0], 0, + block_size // 16, 1, 0, 0) + # apply for tensor in UB for global prob sum (64,) + mat_ub_lg_exp_sum = tik_inst.Tensor("float32", (block_size,), + name='mat_ub_lg_exp_sum', + scope=tik.scope_ubuf) + tik_inst.data_move(mat_ub_lg_exp_sum[0], mat_ub_exp_sum[0, 0, 0, 0], + 0, block_size // 16, 1 * 2, 15 * 2, 0) + # apply for tensor in UB for attention prob sum rec (64,) + mat_ub_exp_sum_rec = tik_inst.Tensor("float32", (block_size,), + name='mat_ub_exp_sum_rec', + scope=tik.scope_ubuf) + mat_ub_exp_sum_rec_16 = tik_inst.Tensor("float16", (block_size,), + name='mat_ub_exp_sum_rec_16', + scope=tik.scope_ubuf) + worker_tensor = tik_inst.Tensor("float32", (block_size * 2,), + name='worker_tensor', + scope=tik.scope_ubuf) + # calculate attention prob sum vec (64,) + tik_inst.vec_rec_high_preci( + 64, mat_ub_exp_sum_rec, mat_ub_lg_exp_sum, worker_tensor, 1, 8, 8) + tik_inst.vec_conv(block_size, "", mat_ub_exp_sum_rec_16[0], + mat_ub_exp_sum_rec[0], + block_size // 64, 4, 8) + with tik_inst.for_range(0, block_size) as bbs: + # apply for scalar in UB for prob sum rec + sum_exp = tik_inst.Scalar("float16", + name='sum_exp', + init_value=0) + # set value for scalar prob sum rec + sum_exp.set_as(mat_ub_exp_sum_rec_16[bbs]) + tik_inst.vec_muls(16, mat_ub_l_out[0, bbs // 16, bbs % 16, 0], + mat_ub_lg_exp_16[global_size // + 16, bbs // 16, bbs % 16, 0], + sum_exp, block_size // 16, + block_size, block_size) + tik_inst.vec_muls(16, mat_ub_g_out[0, bbs // 16, bbs % 16, 0], + mat_ub_lg_exp_16[0, bbs // 16, bbs % + 16, 0], sum_exp, global_size // 16, + block_size, block_size) + # move local out from UB to gm + # the shape of local out in UB is zN + # the shape of local out in gm is zN + # the stride between each (16, 64) is 0 + # repeat 4 times + tik_inst.data_move(mat_lc[b, head, block, 0, 0, 0, 0], mat_ub_l_out[0, 0, 0, 0], 0, + block_size // 16, block_size, 0, 0) + # move global out from UB to gm + # the shape of global out in UB is zN + # the shape of global out in gm is zN + # the stride between each (16, 64) is 0 + # repeat 16 times + tik_inst.data_move(mat_gc[b, head, block, 0, 0, 0, 0], mat_ub_g_out[0, 0, 0, 0], 0, + global_size // 16, block_size, 0, 0) + + tik_inst.BuildCCE(kernel_name=kernel_name, inputs=[mat_q, mat_k, mat_lm, mat_gm], + outputs=[mat_lc, mat_gc]) + return tik_inst diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py index 9d818d738c8..8515fca86fb 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py @@ -1,37 +1,37 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""AdaptiveMaxPool2DGrad op""" -from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType - -adaptive_max_pool_2d_grad_op_info = AiCPURegOp("AdaptiveMaxPool2DGrad") \ - .fusion_type("OPAQUE") \ - .input(0, "y_grad", "required") \ - .input(1, "x", "required") \ - .input(2, "argmax", "required") \ - .output(0, "x_grad", "required") \ - .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \ - .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \ - .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.I32_Default, DataType.F32_Default) \ - .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \ - .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \ - .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.I64_Default, DataType.F32_Default) \ - .get_op_info() - - -@op_info_register(adaptive_max_pool_2d_grad_op_info) -def _adaptive_max_pool_2d_grad_aicpu(): - """AdaptiveMaxPool2DGrad aicpu register""" - return +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""AdaptiveMaxPool2DGrad op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +adaptive_max_pool_2d_grad_op_info = AiCPURegOp("AdaptiveMaxPool2DGrad") \ + .fusion_type("OPAQUE") \ + .input(0, "y_grad", "required") \ + .input(1, "x", "required") \ + .input(2, "argmax", "required") \ + .output(0, "x_grad", "required") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.I32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.I64_Default, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(adaptive_max_pool_2d_grad_op_info) +def _adaptive_max_pool_2d_grad_aicpu(): + """AdaptiveMaxPool2DGrad aicpu register""" + return diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/betainc.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/betainc.py index d5e4ee81711..2f03f85f42c 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/betainc.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/betainc.py @@ -1,31 +1,31 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Betainc op""" -from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType -betainc_op_info = AiCPURegOp("Betainc") \ - .fusion_type("OPAQUE") \ - .input(0, "a", "required") \ - .input(1, "b", "required") \ - .input(2, "x", "required") \ - .output(0, "z", "required") \ - .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ - .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \ - .get_op_info() - - -@op_info_register(betainc_op_info) -def _betainc_aicpu(): - """Betanic aicpu register""" - return +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Betainc op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType +betainc_op_info = AiCPURegOp("Betainc") \ + .fusion_type("OPAQUE") \ + .input(0, "a", "required") \ + .input(1, "b", "required") \ + .input(2, "x", "required") \ + .output(0, "z", "required") \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \ + .get_op_info() + + +@op_info_register(betainc_op_info) +def _betainc_aicpu(): + """Betanic aicpu register""" + return diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/blackman_window.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/blackman_window.py index e764d5ab4fc..456b542849f 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/blackman_window.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/blackman_window.py @@ -1,36 +1,36 @@ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""BlackmanWindow op""" -from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType - -blackman_window_op_info = AiCPURegOp("BlackmanWindow") \ - .fusion_type("OPAQUE") \ - .input(0, "window_length", "required") \ - .output(0, "y", "dynamic") \ - .attr("periodic", "bool") \ - .dtype_format(DataType.I32_Default, DataType.F16_Default) \ - .dtype_format(DataType.I32_Default, DataType.F32_Default) \ - .dtype_format(DataType.I32_Default, DataType.F64_Default) \ - .dtype_format(DataType.I64_Default, DataType.F16_Default) \ - .dtype_format(DataType.I64_Default, DataType.F32_Default) \ - .dtype_format(DataType.I64_Default, DataType.F64_Default) \ - .get_op_info() - - -@op_info_register(blackman_window_op_info) -def _blackman_window_aicpu(): - """BlackmanWindow aicpu register""" - return +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""BlackmanWindow op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +blackman_window_op_info = AiCPURegOp("BlackmanWindow") \ + .fusion_type("OPAQUE") \ + .input(0, "window_length", "required") \ + .output(0, "y", "dynamic") \ + .attr("periodic", "bool") \ + .dtype_format(DataType.I32_Default, DataType.F16_Default) \ + .dtype_format(DataType.I32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I32_Default, DataType.F64_Default) \ + .dtype_format(DataType.I64_Default, DataType.F16_Default) \ + .dtype_format(DataType.I64_Default, DataType.F32_Default) \ + .dtype_format(DataType.I64_Default, DataType.F64_Default) \ + .get_op_info() + + +@op_info_register(blackman_window_op_info) +def _blackman_window_aicpu(): + """BlackmanWindow aicpu register""" + return diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/cholesky_inverse.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/cholesky_inverse.py index 77389f8c893..9cb8d374be5 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/cholesky_inverse.py @@ -1,31 +1,31 @@ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""CholeskyInverse op""" -from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType - -cholesky_inverse_op_info = AiCPURegOp("CholeskyInverse") \ - .fusion_type("OPAQUE") \ - .input(0, "x", "required") \ - .output(0, "y", "required") \ - .attr("upper", "bool") \ - .dtype_format(DataType.F32_Default, DataType.F32_Default) \ - .dtype_format(DataType.F64_Default, DataType.F64_Default) \ - .get_op_info() - -@op_info_register(cholesky_inverse_op_info) -def _cholesky_inverse_aicpu(): - """CholeskyInverse aicpu register""" - return +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""CholeskyInverse op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +cholesky_inverse_op_info = AiCPURegOp("CholeskyInverse") \ + .fusion_type("OPAQUE") \ + .input(0, "x", "required") \ + .output(0, "y", "required") \ + .attr("upper", "bool") \ + .dtype_format(DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.F64_Default) \ + .get_op_info() + +@op_info_register(cholesky_inverse_op_info) +def _cholesky_inverse_aicpu(): + """CholeskyInverse aicpu register""" + return diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/cos.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/cos.py index 4c15b7afe90..c2c059e5a73 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/cos.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/cos.py @@ -1,34 +1,34 @@ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Cos op""" -from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType - -cos_op_info = AiCPURegOp("Cos") \ - .fusion_type("ELEMWISE") \ - .input(0, "x", "required") \ - .output(0, "y", "required") \ - .dtype_format(DataType.F16_Default, DataType.F16_Default) \ - .dtype_format(DataType.F32_Default, DataType.F32_Default) \ - .dtype_format(DataType.F64_Default, DataType.F64_Default) \ - .dtype_format(DataType.C64_Default, DataType.C64_Default) \ - .dtype_format(DataType.C128_Default, DataType.C128_Default) \ - .get_op_info() - - -@op_info_register(cos_op_info) -def _cos_aicpu(): - """Cos AiCPU register""" - return +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Cos op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +cos_op_info = AiCPURegOp("Cos") \ + .fusion_type("ELEMWISE") \ + .input(0, "x", "required") \ + .output(0, "y", "required") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.F64_Default) \ + .dtype_format(DataType.C64_Default, DataType.C64_Default) \ + .dtype_format(DataType.C128_Default, DataType.C128_Default) \ + .get_op_info() + + +@op_info_register(cos_op_info) +def _cos_aicpu(): + """Cos AiCPU register""" + return diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py index 262bcc47efb..ab7cc2c23be 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py @@ -1,36 +1,36 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""CumulativeLogsumexp op""" -from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType -cumulative_logsumexp_op_info = AiCPURegOp("CumulativeLogsumexp") \ - .fusion_type("OPAQUE") \ - .attr("exclusive", "bool") \ - .attr("reverse", "bool") \ - .input(0, "x", "required") \ - .input(1, "axis", "required")\ - .output(0, "y", "required") \ - .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \ - .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \ - .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.F64_Default) \ - .dtype_format(DataType.F16_Default, DataType.I16_Default, DataType.F16_Default) \ - .dtype_format(DataType.F32_Default, DataType.I16_Default, DataType.F32_Default) \ - .dtype_format(DataType.F64_Default, DataType.I16_Default, DataType.F64_Default) \ - .get_op_info() - - -@op_info_register(cumulative_logsumexp_op_info) -def _cumulative_logsumexp_aicpu(): - """CumulativeLogsumexp aicpu register""" - return +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""CumulativeLogsumexp op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType +cumulative_logsumexp_op_info = AiCPURegOp("CumulativeLogsumexp") \ + .fusion_type("OPAQUE") \ + .attr("exclusive", "bool") \ + .attr("reverse", "bool") \ + .input(0, "x", "required") \ + .input(1, "axis", "required")\ + .output(0, "y", "required") \ + .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.F64_Default) \ + .dtype_format(DataType.F16_Default, DataType.I16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.I16_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.I16_Default, DataType.F64_Default) \ + .get_op_info() + + +@op_info_register(cumulative_logsumexp_op_info) +def _cumulative_logsumexp_aicpu(): + """CumulativeLogsumexp aicpu register""" + return diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/expand.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/expand.py index bede1d5f647..036b410de43 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/expand.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/expand.py @@ -1,45 +1,45 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Expand op""" -from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType - -expand_op_info = AiCPURegOp("Expand") \ - .fusion_type("OPAQUE") \ - .input(0, "x", "required") \ - .input(1, "shape", "required") \ - .output(0, "y", "required") \ - .dtype_format(DataType.F16_Default, DataType.I16_Default, DataType.F16_Default) \ - .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \ - .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \ - .dtype_format(DataType.F32_Default, DataType.I16_Default, DataType.F32_Default) \ - .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \ - .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \ - .dtype_format(DataType.I32_Default, DataType.I16_Default, DataType.I32_Default) \ - .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ - .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \ - .dtype_format(DataType.I8_Default, DataType.I16_Default, DataType.I8_Default) \ - .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default) \ - .dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I8_Default) \ - .dtype_format(DataType.U8_Default, DataType.I16_Default, DataType.U8_Default) \ - .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default) \ - .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U8_Default) \ - .get_op_info() - - -@op_info_register(expand_op_info) -def _expand_aicpu(): - """Expand aicpu register""" - return +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Expand op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +expand_op_info = AiCPURegOp("Expand") \ + .fusion_type("OPAQUE") \ + .input(0, "x", "required") \ + .input(1, "shape", "required") \ + .output(0, "y", "required") \ + .dtype_format(DataType.F16_Default, DataType.I16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \ + .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.I16_Default, DataType.F32_Default) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \ + .dtype_format(DataType.I32_Default, DataType.I16_Default, DataType.I32_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \ + .dtype_format(DataType.I8_Default, DataType.I16_Default, DataType.I8_Default) \ + .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default) \ + .dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I8_Default) \ + .dtype_format(DataType.U8_Default, DataType.I16_Default, DataType.U8_Default) \ + .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default) \ + .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U8_Default) \ + .get_op_info() + + +@op_info_register(expand_op_info) +def _expand_aicpu(): + """Expand aicpu register""" + return diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/fft_with_size.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/fft_with_size.py index 408c87ce0b5..0a9a6f65a82 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/fft_with_size.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/fft_with_size.py @@ -1,47 +1,47 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""FFTWithSize op""" -from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType - -fft_with_size_op_info = AiCPURegOp("FFTWithSize") \ - .fusion_type("OPAQUE") \ - .input(0, "x", "required") \ - .output(0, "y", "required") \ - .attr("signal_ndim", "int") \ - .attr("inverse", "bool") \ - .attr("signal_sizes", "listInt") \ - .attr("norm", "str") \ - .attr("onesided", "bool") \ - .attr("real", "bool") \ - .dtype_format(DataType.BOOL_Default, DataType.C64_Default) \ - .dtype_format(DataType.U8_Default, DataType.C64_Default) \ - .dtype_format(DataType.I8_Default, DataType.C64_Default) \ - .dtype_format(DataType.I16_Default, DataType.C64_Default) \ - .dtype_format(DataType.I32_Default, DataType.C64_Default) \ - .dtype_format(DataType.I64_Default, DataType.C64_Default) \ - .dtype_format(DataType.C64_Default, DataType.C64_Default) \ - .dtype_format(DataType.C128_Default, DataType.C128_Default) \ - .dtype_format(DataType.F32_Default, DataType.C64_Default) \ - .dtype_format(DataType.F64_Default, DataType.C128_Default) \ - .dtype_format(DataType.C64_Default, DataType.F32_Default) \ - .dtype_format(DataType.C128_Default, DataType.F64_Default) \ - .get_op_info() - - -@op_info_register(fft_with_size_op_info) -def _fft_with_size_aicpu(): - """FFTWithSize aicpu register""" - return +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""FFTWithSize op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +fft_with_size_op_info = AiCPURegOp("FFTWithSize") \ + .fusion_type("OPAQUE") \ + .input(0, "x", "required") \ + .output(0, "y", "required") \ + .attr("signal_ndim", "int") \ + .attr("inverse", "bool") \ + .attr("signal_sizes", "listInt") \ + .attr("norm", "str") \ + .attr("onesided", "bool") \ + .attr("real", "bool") \ + .dtype_format(DataType.BOOL_Default, DataType.C64_Default) \ + .dtype_format(DataType.U8_Default, DataType.C64_Default) \ + .dtype_format(DataType.I8_Default, DataType.C64_Default) \ + .dtype_format(DataType.I16_Default, DataType.C64_Default) \ + .dtype_format(DataType.I32_Default, DataType.C64_Default) \ + .dtype_format(DataType.I64_Default, DataType.C64_Default) \ + .dtype_format(DataType.C64_Default, DataType.C64_Default) \ + .dtype_format(DataType.C128_Default, DataType.C128_Default) \ + .dtype_format(DataType.F32_Default, DataType.C64_Default) \ + .dtype_format(DataType.F64_Default, DataType.C128_Default) \ + .dtype_format(DataType.C64_Default, DataType.F32_Default) \ + .dtype_format(DataType.C128_Default, DataType.F64_Default) \ + .get_op_info() + + +@op_info_register(fft_with_size_op_info) +def _fft_with_size_aicpu(): + """FFTWithSize aicpu register""" + return diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/fill_diagonal.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/fill_diagonal.py index 685f9779623..5162f5da150 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/fill_diagonal.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/fill_diagonal.py @@ -1,39 +1,39 @@ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""FillDiagonal op""" -from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType - -fill_diagonal_op_info = AiCPURegOp("FillDiagonal") \ - .fusion_type("OPAQUE") \ - .attr("fill_value", "float") \ - .attr("wrap", "bool") \ - .input(0, "input_x", "required") \ - .output(0, "y", "required") \ - .dtype_format(DataType.F16_Default, DataType.F16_Default) \ - .dtype_format(DataType.F32_Default, DataType.F32_Default) \ - .dtype_format(DataType.F64_Default, DataType.F64_Default) \ - .dtype_format(DataType.U8_Default, DataType.U8_Default) \ - .dtype_format(DataType.I8_Default, DataType.I8_Default) \ - .dtype_format(DataType.I16_Default, DataType.I16_Default) \ - .dtype_format(DataType.I32_Default, DataType.I32_Default) \ - .dtype_format(DataType.I64_Default, DataType.I64_Default) \ - .get_op_info() - - -@op_info_register(fill_diagonal_op_info) -def _fill_diagonal_aicpu(): - """FillDiagonal aicpu register""" - return +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""FillDiagonal op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +fill_diagonal_op_info = AiCPURegOp("FillDiagonal") \ + .fusion_type("OPAQUE") \ + .attr("fill_value", "float") \ + .attr("wrap", "bool") \ + .input(0, "input_x", "required") \ + .output(0, "y", "required") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.F64_Default) \ + .dtype_format(DataType.U8_Default, DataType.U8_Default) \ + .dtype_format(DataType.I8_Default, DataType.I8_Default) \ + .dtype_format(DataType.I16_Default, DataType.I16_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I64_Default, DataType.I64_Default) \ + .get_op_info() + + +@op_info_register(fill_diagonal_op_info) +def _fill_diagonal_aicpu(): + """FillDiagonal aicpu register""" + return diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py index fed6ff90ca0..78fc233abf1 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py @@ -1,42 +1,42 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""FractionalMaxPoolGradWithFixedKsize op""" -from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType - -fractional_max_pool_grad_with_fixed_ksize_op_info = AiCPURegOp("FractionalMaxPoolGradWithFixedKsize") \ - .fusion_type("OPAQUE") \ - .attr("data_format", "str", "NCHW") \ - .input(0, "origin_input", "required") \ - .input(1, "out_backprop", "required") \ - .input(2, "argmax", "required") \ - .output(0, "y", "required") \ - .dtype_format(DataType.I32_NCHW, DataType.F16_NCHW, DataType.I64_Default, DataType.F16_NCHW) \ - .dtype_format(DataType.I32_NCHW, DataType.F32_NCHW, DataType.I64_Default, DataType.F32_NCHW) \ - .dtype_format(DataType.I32_NCHW, DataType.F64_NCHW, DataType.I64_Default, DataType.F64_NCHW) \ - .dtype_format(DataType.I32_NCHW, DataType.I32_NCHW, DataType.I64_Default, DataType.I32_NCHW) \ - .dtype_format(DataType.I32_NCHW, DataType.I64_NCHW, DataType.I64_Default, DataType.I64_NCHW) \ - .dtype_format(DataType.I64_NCHW, DataType.F16_NCHW, DataType.I64_Default, DataType.F16_NCHW) \ - .dtype_format(DataType.I64_NCHW, DataType.F32_NCHW, DataType.I64_Default, DataType.F32_NCHW) \ - .dtype_format(DataType.I64_NCHW, DataType.F64_NCHW, DataType.I64_Default, DataType.F64_NCHW) \ - .dtype_format(DataType.I64_NCHW, DataType.I32_NCHW, DataType.I64_Default, DataType.I32_NCHW) \ - .dtype_format(DataType.I64_NCHW, DataType.I64_NCHW, DataType.I64_Default, DataType.I64_NCHW) \ - .get_op_info() - - -@op_info_register(fractional_max_pool_grad_with_fixed_ksize_op_info) -def _fractional_max_pool_grad_with_fixed_ksize_aicpu(): - """FractionalMaxPoolGradWithFixedKsize aicpu register""" - return +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""FractionalMaxPoolGradWithFixedKsize op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +fractional_max_pool_grad_with_fixed_ksize_op_info = AiCPURegOp("FractionalMaxPoolGradWithFixedKsize") \ + .fusion_type("OPAQUE") \ + .attr("data_format", "str", "NCHW") \ + .input(0, "origin_input", "required") \ + .input(1, "out_backprop", "required") \ + .input(2, "argmax", "required") \ + .output(0, "y", "required") \ + .dtype_format(DataType.I32_NCHW, DataType.F16_NCHW, DataType.I64_Default, DataType.F16_NCHW) \ + .dtype_format(DataType.I32_NCHW, DataType.F32_NCHW, DataType.I64_Default, DataType.F32_NCHW) \ + .dtype_format(DataType.I32_NCHW, DataType.F64_NCHW, DataType.I64_Default, DataType.F64_NCHW) \ + .dtype_format(DataType.I32_NCHW, DataType.I32_NCHW, DataType.I64_Default, DataType.I32_NCHW) \ + .dtype_format(DataType.I32_NCHW, DataType.I64_NCHW, DataType.I64_Default, DataType.I64_NCHW) \ + .dtype_format(DataType.I64_NCHW, DataType.F16_NCHW, DataType.I64_Default, DataType.F16_NCHW) \ + .dtype_format(DataType.I64_NCHW, DataType.F32_NCHW, DataType.I64_Default, DataType.F32_NCHW) \ + .dtype_format(DataType.I64_NCHW, DataType.F64_NCHW, DataType.I64_Default, DataType.F64_NCHW) \ + .dtype_format(DataType.I64_NCHW, DataType.I32_NCHW, DataType.I64_Default, DataType.I32_NCHW) \ + .dtype_format(DataType.I64_NCHW, DataType.I64_NCHW, DataType.I64_Default, DataType.I64_NCHW) \ + .get_op_info() + + +@op_info_register(fractional_max_pool_grad_with_fixed_ksize_op_info) +def _fractional_max_pool_grad_with_fixed_ksize_aicpu(): + """FractionalMaxPoolGradWithFixedKsize aicpu register""" + return diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py index 17415297b39..c61bd172860 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py @@ -1,49 +1,49 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""FractionalMaxPoolWithFixedKsize op""" -from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType - -fractional_max_pool_with_fixed_ksize_op_info = AiCPURegOp("FractionalMaxPoolWithFixedKsize") \ - .fusion_type("OPAQUE") \ - .attr("ksize", "listInt") \ - .attr("output_shape", "listInt") \ - .attr("data_format", "str", "NCHW") \ - .input(0, "input_x", "required") \ - .input(1, "random_samples", "required") \ - .output(0, "y", "required") \ - .output(1, "argmax", "optional") \ - .dtype_format(DataType.F16_NCHW, DataType.F16_Default, DataType.F16_NCHW, DataType.I64_NCHW) \ - .dtype_format(DataType.F16_NCHW, DataType.F32_Default, DataType.F16_NCHW, DataType.I64_NCHW) \ - .dtype_format(DataType.F16_NCHW, DataType.F64_Default, DataType.F16_NCHW, DataType.I64_NCHW) \ - .dtype_format(DataType.F32_NCHW, DataType.F16_Default, DataType.F32_NCHW, DataType.I64_NCHW) \ - .dtype_format(DataType.F32_NCHW, DataType.F32_Default, DataType.F32_NCHW, DataType.I64_NCHW) \ - .dtype_format(DataType.F32_NCHW, DataType.F64_Default, DataType.F32_NCHW, DataType.I64_NCHW) \ - .dtype_format(DataType.F64_NCHW, DataType.F16_Default, DataType.F64_NCHW, DataType.I64_NCHW) \ - .dtype_format(DataType.F64_NCHW, DataType.F32_Default, DataType.F64_NCHW, DataType.I64_NCHW) \ - .dtype_format(DataType.F64_NCHW, DataType.F64_Default, DataType.F64_NCHW, DataType.I64_NCHW) \ - .dtype_format(DataType.I32_NCHW, DataType.F16_Default, DataType.I32_NCHW, DataType.I64_NCHW) \ - .dtype_format(DataType.I32_NCHW, DataType.F32_Default, DataType.I32_NCHW, DataType.I64_NCHW) \ - .dtype_format(DataType.I32_NCHW, DataType.F64_Default, DataType.I32_NCHW, DataType.I64_NCHW) \ - .dtype_format(DataType.I64_NCHW, DataType.F16_Default, DataType.I64_NCHW, DataType.I64_NCHW) \ - .dtype_format(DataType.I64_NCHW, DataType.F32_Default, DataType.I64_NCHW, DataType.I64_NCHW) \ - .dtype_format(DataType.I64_NCHW, DataType.F64_Default, DataType.I64_NCHW, DataType.I64_NCHW) \ - .get_op_info() - - -@op_info_register(fractional_max_pool_with_fixed_ksize_op_info) -def _fractional_max_pool_with_fixed_ksize_aicpu(): - """FractionalMaxPoolWithFixedKsize aicpu register""" - return +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""FractionalMaxPoolWithFixedKsize op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +fractional_max_pool_with_fixed_ksize_op_info = AiCPURegOp("FractionalMaxPoolWithFixedKsize") \ + .fusion_type("OPAQUE") \ + .attr("ksize", "listInt") \ + .attr("output_shape", "listInt") \ + .attr("data_format", "str", "NCHW") \ + .input(0, "input_x", "required") \ + .input(1, "random_samples", "required") \ + .output(0, "y", "required") \ + .output(1, "argmax", "optional") \ + .dtype_format(DataType.F16_NCHW, DataType.F16_Default, DataType.F16_NCHW, DataType.I64_NCHW) \ + .dtype_format(DataType.F16_NCHW, DataType.F32_Default, DataType.F16_NCHW, DataType.I64_NCHW) \ + .dtype_format(DataType.F16_NCHW, DataType.F64_Default, DataType.F16_NCHW, DataType.I64_NCHW) \ + .dtype_format(DataType.F32_NCHW, DataType.F16_Default, DataType.F32_NCHW, DataType.I64_NCHW) \ + .dtype_format(DataType.F32_NCHW, DataType.F32_Default, DataType.F32_NCHW, DataType.I64_NCHW) \ + .dtype_format(DataType.F32_NCHW, DataType.F64_Default, DataType.F32_NCHW, DataType.I64_NCHW) \ + .dtype_format(DataType.F64_NCHW, DataType.F16_Default, DataType.F64_NCHW, DataType.I64_NCHW) \ + .dtype_format(DataType.F64_NCHW, DataType.F32_Default, DataType.F64_NCHW, DataType.I64_NCHW) \ + .dtype_format(DataType.F64_NCHW, DataType.F64_Default, DataType.F64_NCHW, DataType.I64_NCHW) \ + .dtype_format(DataType.I32_NCHW, DataType.F16_Default, DataType.I32_NCHW, DataType.I64_NCHW) \ + .dtype_format(DataType.I32_NCHW, DataType.F32_Default, DataType.I32_NCHW, DataType.I64_NCHW) \ + .dtype_format(DataType.I32_NCHW, DataType.F64_Default, DataType.I32_NCHW, DataType.I64_NCHW) \ + .dtype_format(DataType.I64_NCHW, DataType.F16_Default, DataType.I64_NCHW, DataType.I64_NCHW) \ + .dtype_format(DataType.I64_NCHW, DataType.F32_Default, DataType.I64_NCHW, DataType.I64_NCHW) \ + .dtype_format(DataType.I64_NCHW, DataType.F64_Default, DataType.I64_NCHW, DataType.I64_NCHW) \ + .get_op_info() + + +@op_info_register(fractional_max_pool_with_fixed_ksize_op_info) +def _fractional_max_pool_with_fixed_ksize_aicpu(): + """FractionalMaxPoolWithFixedKsize aicpu register""" + return diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/geqrf.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/geqrf.py index 1d4737b95e2..ebd7742f704 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/geqrf.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/geqrf.py @@ -1,32 +1,32 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""InitDataSetQueue op""" -from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType - -geqrf_k_op_info = AiCPURegOp("Geqrf") \ - .fusion_type("ELEMWISE") \ - .input(0, "x", "required") \ - .output(0, "y", "required") \ - .output(1, "tau", "required") \ - .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ - .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \ - .get_op_info() - - -@op_info_register(geqrf_k_op_info) -def _geqrf_aicpu(): - """Geqrf AiCPU register""" - return +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""InitDataSetQueue op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +geqrf_k_op_info = AiCPURegOp("Geqrf") \ + .fusion_type("ELEMWISE") \ + .input(0, "x", "required") \ + .output(0, "y", "required") \ + .output(1, "tau", "required") \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \ + .get_op_info() + + +@op_info_register(geqrf_k_op_info) +def _geqrf_aicpu(): + """Geqrf AiCPU register""" + return diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/glu_grad.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/glu_grad.py index 2a85c973347..da9036120c6 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/glu_grad.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/glu_grad.py @@ -1,34 +1,34 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""GluGrad op""" -from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType - -glu_grad_op_info = AiCPURegOp("GluGrad") \ - .fusion_type("OPAQUE") \ - .attr("axis", "int") \ - .input(0, "grads", "required") \ - .input(1, "x", "required") \ - .output(0, "y", "required") \ - .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ - .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ - .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \ - .get_op_info() - - -@op_info_register(glu_grad_op_info) -def _glu_grad_aicpu(): - """GluGrad aicpu register""" - return +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""GluGrad op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +glu_grad_op_info = AiCPURegOp("GluGrad") \ + .fusion_type("OPAQUE") \ + .attr("axis", "int") \ + .input(0, "grads", "required") \ + .input(1, "x", "required") \ + .output(0, "y", "required") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \ + .get_op_info() + + +@op_info_register(glu_grad_op_info) +def _glu_grad_aicpu(): + """GluGrad aicpu register""" + return diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/left_shift.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/left_shift.py index 96aa155e6ac..9536326f97d 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/left_shift.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/left_shift.py @@ -1,38 +1,38 @@ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""LeftShift op""" -from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType - -left_shift_op_info = AiCPURegOp("LeftShift") \ - .fusion_type("OPAQUE") \ - .input(0, "x1", "required") \ - .input(1, "x2", "required") \ - .output(0, "y", "required") \ - .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \ - .dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I16_Default) \ - .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ - .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ - .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \ - .dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.U16_Default) \ - .dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \ - .dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default) \ - .get_op_info() - - -@op_info_register(left_shift_op_info) -def _left_shift_aicpu(): - """LeftShift aicpu register""" - return +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""LeftShift op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +left_shift_op_info = AiCPURegOp("LeftShift") \ + .fusion_type("OPAQUE") \ + .input(0, "x1", "required") \ + .input(1, "x2", "required") \ + .output(0, "y", "required") \ + .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \ + .dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I16_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ + .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \ + .dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.U16_Default) \ + .dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \ + .dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default) \ + .get_op_info() + + +@op_info_register(left_shift_op_info) +def _left_shift_aicpu(): + """LeftShift aicpu register""" + return diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/lstsq.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/lstsq.py index dc53dc1adbf..f7955ec001d 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/lstsq.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/lstsq.py @@ -1,34 +1,34 @@ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Lstsq op""" -from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType - -lstsq_op_info = AiCPURegOp("Lstsq") \ - .fusion_type("OPAQUE") \ - .input(0, "matrix", "required") \ - .input(1, "rhs", "required") \ - .output(0, "y", "required") \ - .attr("l2_regularizer", "float", "0.0") \ - .attr("fast", "bool", "True") \ - .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ - .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ - .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \ - .get_op_info() - -@op_info_register(lstsq_op_info) -def _lstsq_aicpu(): - """Lstsq aicpu register""" - return +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Lstsq op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +lstsq_op_info = AiCPURegOp("Lstsq") \ + .fusion_type("OPAQUE") \ + .input(0, "matrix", "required") \ + .input(1, "rhs", "required") \ + .output(0, "y", "required") \ + .attr("l2_regularizer", "float", "0.0") \ + .attr("fast", "bool", "True") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \ + .get_op_info() + +@op_info_register(lstsq_op_info) +def _lstsq_aicpu(): + """Lstsq aicpu register""" + return diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/lu_solve.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/lu_solve.py index d0165d96a49..adc2461da1d 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/lu_solve.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/lu_solve.py @@ -1,32 +1,32 @@ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""LuSolve op""" -from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType - -lu_solve_op_info = AiCPURegOp("LuSolve") \ - .fusion_type("OPAQUE") \ - .input(0, "x", "required") \ - .input(1, "lu_data", "required") \ - .input(2, "lu_pivots", "required") \ - .output(0, "output", "required") \ - .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \ - .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \ - .get_op_info() - -@op_info_register(lu_solve_op_info) -def _lu_solve_aicpu(): - """LuSolve aicpu register""" - return +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""LuSolve op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +lu_solve_op_info = AiCPURegOp("LuSolve") \ + .fusion_type("OPAQUE") \ + .input(0, "x", "required") \ + .input(1, "lu_data", "required") \ + .input(2, "lu_pivots", "required") \ + .output(0, "output", "required") \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \ + .get_op_info() + +@op_info_register(lu_solve_op_info) +def _lu_solve_aicpu(): + """LuSolve aicpu register""" + return diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/neg.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/neg.py index 0b75eb5c08b..7a475f031ce 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/neg.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/neg.py @@ -1,36 +1,36 @@ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Neg op""" -from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType - -neg_op_info = AiCPURegOp("Neg") \ - .fusion_type("OPAQUE") \ - .input(0, "x", "required") \ - .output(0, "y", "required") \ - .dtype_format(DataType.F16_Default, DataType.F16_Default) \ - .dtype_format(DataType.F32_Default, DataType.F32_Default) \ - .dtype_format(DataType.F64_Default, DataType.F64_Default) \ - .dtype_format(DataType.I32_Default, DataType.I32_Default) \ - .dtype_format(DataType.I64_Default, DataType.I64_Default) \ - .dtype_format(DataType.C64_Default, DataType.C64_Default) \ - .dtype_format(DataType.C128_Default, DataType.C128_Default) \ - .get_op_info() - - -@op_info_register(neg_op_info) -def _neg_aicpu(): - """Neg AiCPU register""" - return +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Neg op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +neg_op_info = AiCPURegOp("Neg") \ + .fusion_type("OPAQUE") \ + .input(0, "x", "required") \ + .output(0, "y", "required") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.F64_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I64_Default, DataType.I64_Default) \ + .dtype_format(DataType.C64_Default, DataType.C64_Default) \ + .dtype_format(DataType.C128_Default, DataType.C128_Default) \ + .get_op_info() + + +@op_info_register(neg_op_info) +def _neg_aicpu(): + """Neg AiCPU register""" + return diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/non_zero.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/non_zero.py index 1e3890d5f37..a956c5a7abf 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/non_zero.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/non_zero.py @@ -1,43 +1,43 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""NonZero op""" -from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType - -non_zero_op_info = AiCPURegOp("NonZero") \ - .fusion_type("OPAQUE") \ - .input(0, "x", "required") \ - .output(0, "y", "required") \ - .dtype_format(DataType.BOOL_Default, DataType.I64_Default) \ - .dtype_format(DataType.I8_Default, DataType.I64_Default) \ - .dtype_format(DataType.I16_Default, DataType.I64_Default) \ - .dtype_format(DataType.I32_Default, DataType.I64_Default) \ - .dtype_format(DataType.I64_Default, DataType.I64_Default) \ - .dtype_format(DataType.U8_Default, DataType.I64_Default) \ - .dtype_format(DataType.U16_Default, DataType.I64_Default) \ - .dtype_format(DataType.U32_Default, DataType.I64_Default) \ - .dtype_format(DataType.U64_Default, DataType.I64_Default) \ - .dtype_format(DataType.F16_Default, DataType.I64_Default) \ - .dtype_format(DataType.F32_Default, DataType.I64_Default) \ - .dtype_format(DataType.F64_Default, DataType.I64_Default) \ - .dtype_format(DataType.C64_Default, DataType.I64_Default) \ - .dtype_format(DataType.C128_Default, DataType.I64_Default) \ - .get_op_info() - - -@op_info_register(non_zero_op_info) -def _non_zero_aicpu(): - """Non_Zero aicpu register""" - return +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""NonZero op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +non_zero_op_info = AiCPURegOp("NonZero") \ + .fusion_type("OPAQUE") \ + .input(0, "x", "required") \ + .output(0, "y", "required") \ + .dtype_format(DataType.BOOL_Default, DataType.I64_Default) \ + .dtype_format(DataType.I8_Default, DataType.I64_Default) \ + .dtype_format(DataType.I16_Default, DataType.I64_Default) \ + .dtype_format(DataType.I32_Default, DataType.I64_Default) \ + .dtype_format(DataType.I64_Default, DataType.I64_Default) \ + .dtype_format(DataType.U8_Default, DataType.I64_Default) \ + .dtype_format(DataType.U16_Default, DataType.I64_Default) \ + .dtype_format(DataType.U32_Default, DataType.I64_Default) \ + .dtype_format(DataType.U64_Default, DataType.I64_Default) \ + .dtype_format(DataType.F16_Default, DataType.I64_Default) \ + .dtype_format(DataType.F32_Default, DataType.I64_Default) \ + .dtype_format(DataType.F64_Default, DataType.I64_Default) \ + .dtype_format(DataType.C64_Default, DataType.I64_Default) \ + .dtype_format(DataType.C128_Default, DataType.I64_Default) \ + .get_op_info() + + +@op_info_register(non_zero_op_info) +def _non_zero_aicpu(): + """Non_Zero aicpu register""" + return diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/not_equal.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/not_equal.py index e22fcf8190d..fa974e9e08e 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/not_equal.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/not_equal.py @@ -1,39 +1,39 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""NotEqual op""" -from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType - -not_equal_op_info = AiCPURegOp("NotEqual") \ - .fusion_type("OPAQUE") \ - .input(0, "x1", "required") \ - .input(1, "x2", "required") \ - .output(0, "y", "required") \ - .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.BOOL_Default) \ - .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default, DataType.BOOL_Default) \ - .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \ - .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \ - .dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.BOOL_Default) \ - .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \ - .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.BOOL_Default) \ - .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.BOOL_Default) \ - .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.BOOL_Default) \ - .get_op_info() - - -@op_info_register(not_equal_op_info) -def _not_equal_aicpu(): - """NotEqual aicpu register""" - return +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""NotEqual op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +not_equal_op_info = AiCPURegOp("NotEqual") \ + .fusion_type("OPAQUE") \ + .input(0, "x1", "required") \ + .input(1, "x2", "required") \ + .output(0, "y", "required") \ + .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.BOOL_Default) \ + .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default, DataType.BOOL_Default) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \ + .dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.BOOL_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \ + .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.BOOL_Default) \ + .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.BOOL_Default) \ + .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.BOOL_Default) \ + .get_op_info() + + +@op_info_register(not_equal_op_info) +def _not_equal_aicpu(): + """NotEqual aicpu register""" + return diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/pow.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/pow.py index 913763f38ef..5314dcba2f2 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/pow.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/pow.py @@ -1,39 +1,39 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Pow op""" -from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType - -pow_op_info = AiCPURegOp("Pow") \ - .fusion_type("OPAQUE") \ - .input(0, "x1", "required") \ - .input(1, "x2", "required") \ - .output(0, "y", "required") \ - .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ - .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ - .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \ - .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \ - .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \ - .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ - .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ - .dtype_format(DataType.C64_Default, DataType.C64_Default, DataType.C64_Default) \ - .dtype_format(DataType.C128_Default, DataType.C128_Default, DataType.C128_Default) \ - .get_op_info() - - -@op_info_register(pow_op_info) -def _pow_aicpu(): - """Pow AiCPU register""" - return +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Pow op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +pow_op_info = AiCPURegOp("Pow") \ + .fusion_type("OPAQUE") \ + .input(0, "x1", "required") \ + .input(1, "x2", "required") \ + .output(0, "y", "required") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \ + .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \ + .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ + .dtype_format(DataType.C64_Default, DataType.C64_Default, DataType.C64_Default) \ + .dtype_format(DataType.C128_Default, DataType.C128_Default, DataType.C128_Default) \ + .get_op_info() + + +@op_info_register(pow_op_info) +def _pow_aicpu(): + """Pow AiCPU register""" + return diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py index d0c8af540d4..12b59846bd0 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py @@ -1,53 +1,53 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""ScatterAddWithAxis op""" -from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType - -scatter_add_with_axis_op_info = AiCPURegOp("ScatterAddWithAxis") \ - .fusion_type("OPAQUE") \ - .attr("axis", "int") \ - .input(0, "input_x", "required") \ - .input(1, "indices", "required") \ - .input(2, "updates", "required") \ - .output(0, "y", "required") \ - .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \ - .dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I16_Default, DataType.I16_Default) \ - .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ - .dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I64_Default, DataType.I64_Default) \ - .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \ - .dtype_format(DataType.U16_Default, DataType.I32_Default, DataType.U16_Default, DataType.U16_Default) \ - .dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.U32_Default, DataType.U32_Default) \ - .dtype_format(DataType.U64_Default, DataType.I32_Default, DataType.U64_Default, DataType.U64_Default) \ - .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \ - .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \ - .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.F64_Default, DataType.F64_Default) \ - .dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I8_Default, DataType.I8_Default) \ - .dtype_format(DataType.I16_Default, DataType.I64_Default, DataType.I16_Default, DataType.I16_Default) \ - .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default, DataType.I32_Default) \ - .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ - .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U8_Default, DataType.U8_Default) \ - .dtype_format(DataType.U16_Default, DataType.I64_Default, DataType.U16_Default, DataType.U16_Default) \ - .dtype_format(DataType.U32_Default, DataType.I64_Default, DataType.U32_Default, DataType.U32_Default) \ - .dtype_format(DataType.U64_Default, DataType.I64_Default, DataType.U64_Default, DataType.U64_Default) \ - .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default, DataType.F16_Default) \ - .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default, DataType.F32_Default) \ - .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.F64_Default, DataType.F64_Default) \ - .get_op_info() - - -@op_info_register(scatter_add_with_axis_op_info) -def _scatter_add_with_axis_aicpu(): - """ScatterAddWithAxis AiCPU register""" - return +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""ScatterAddWithAxis op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +scatter_add_with_axis_op_info = AiCPURegOp("ScatterAddWithAxis") \ + .fusion_type("OPAQUE") \ + .attr("axis", "int") \ + .input(0, "input_x", "required") \ + .input(1, "indices", "required") \ + .input(2, "updates", "required") \ + .output(0, "y", "required") \ + .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \ + .dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I16_Default, DataType.I16_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I64_Default, DataType.I64_Default) \ + .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \ + .dtype_format(DataType.U16_Default, DataType.I32_Default, DataType.U16_Default, DataType.U16_Default) \ + .dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.U32_Default, DataType.U32_Default) \ + .dtype_format(DataType.U64_Default, DataType.I32_Default, DataType.U64_Default, DataType.U64_Default) \ + .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.F64_Default, DataType.F64_Default) \ + .dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I8_Default, DataType.I8_Default) \ + .dtype_format(DataType.I16_Default, DataType.I64_Default, DataType.I16_Default, DataType.I16_Default) \ + .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ + .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U8_Default, DataType.U8_Default) \ + .dtype_format(DataType.U16_Default, DataType.I64_Default, DataType.U16_Default, DataType.U16_Default) \ + .dtype_format(DataType.U32_Default, DataType.I64_Default, DataType.U32_Default, DataType.U32_Default) \ + .dtype_format(DataType.U64_Default, DataType.I64_Default, DataType.U64_Default, DataType.U64_Default) \ + .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.F64_Default, DataType.F64_Default) \ + .get_op_info() + + +@op_info_register(scatter_add_with_axis_op_info) +def _scatter_add_with_axis_aicpu(): + """ScatterAddWithAxis AiCPU register""" + return diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py index b9163189e4a..d5db0b3414f 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py @@ -1,34 +1,34 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""SelfAdjointEig op""" -from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType - -self_adjoint_eig_op_info = AiCPURegOp("SelfAdjointEig") \ - .fusion_type("OPAQUE") \ - .attr("compute_v", "bool")\ - .input(0, "x", "required") \ - .output(0, "eigen_value", "required") \ - .output(1, "eigen_vector", "required") \ - .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ - .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \ - .dtype_format(DataType.C64_Default, DataType.C64_Default, DataType.C64_Default) \ - .dtype_format(DataType.C128_Default, DataType.C128_Default, DataType.C128_Default) \ - .get_op_info() - - -@op_info_register(self_adjoint_eig_op_info) -def _self_adjoint_eig_aicpu(): - """SelfAdjointEig aicpu register""" - return +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""SelfAdjointEig op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +self_adjoint_eig_op_info = AiCPURegOp("SelfAdjointEig") \ + .fusion_type("OPAQUE") \ + .attr("compute_v", "bool")\ + .input(0, "x", "required") \ + .output(0, "eigen_value", "required") \ + .output(1, "eigen_vector", "required") \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \ + .dtype_format(DataType.C64_Default, DataType.C64_Default, DataType.C64_Default) \ + .dtype_format(DataType.C128_Default, DataType.C128_Default, DataType.C128_Default) \ + .get_op_info() + + +@op_info_register(self_adjoint_eig_op_info) +def _self_adjoint_eig_aicpu(): + """SelfAdjointEig aicpu register""" + return diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/sequence_stack.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/sequence_stack.py index 7518649d501..e36a2f3ad34 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/sequence_stack.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/sequence_stack.py @@ -1,40 +1,40 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""SequenceStack op""" -from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType - -sequence_stack_op_info = AiCPURegOp("SequenceStack") \ - .fusion_type("OPAQUE") \ - .attr("axis", "int") \ - .input(0, "input_0", "required") \ - .output(0, "output_data", "required") \ - .dtype_format(DataType.I32_Default_Tuple, DataType.I32_Default) \ - .dtype_format(DataType.I64_Default_Tuple, DataType.I64_Default) \ - .dtype_format(DataType.U32_Default_Tuple, DataType.U32_Default) \ - .dtype_format(DataType.U64_Default_Tuple, DataType.U64_Default) \ - .dtype_format(DataType.F16_Default_Tuple, DataType.F16_Default) \ - .dtype_format(DataType.F32_Default_Tuple, DataType.F32_Default) \ - .dtype_format(DataType.F64_Default_Tuple, DataType.F64_Default) \ - .dtype_format(DataType.BOOL_Default_Tuple, DataType.BOOL_Default) \ - .dtype_format(DataType.C64_Default_Tuple, DataType.C64_Default) \ - .dtype_format(DataType.C128_Default_Tuple, DataType.C128_Default) \ - .get_op_info() - - -@op_info_register(sequence_stack_op_info) -def _sequence_stack_aicpu(): - """SequenceStack AiCPU register""" - return +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""SequenceStack op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +sequence_stack_op_info = AiCPURegOp("SequenceStack") \ + .fusion_type("OPAQUE") \ + .attr("axis", "int") \ + .input(0, "input_0", "required") \ + .output(0, "output_data", "required") \ + .dtype_format(DataType.I32_Default_Tuple, DataType.I32_Default) \ + .dtype_format(DataType.I64_Default_Tuple, DataType.I64_Default) \ + .dtype_format(DataType.U32_Default_Tuple, DataType.U32_Default) \ + .dtype_format(DataType.U64_Default_Tuple, DataType.U64_Default) \ + .dtype_format(DataType.F16_Default_Tuple, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default_Tuple, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default_Tuple, DataType.F64_Default) \ + .dtype_format(DataType.BOOL_Default_Tuple, DataType.BOOL_Default) \ + .dtype_format(DataType.C64_Default_Tuple, DataType.C64_Default) \ + .dtype_format(DataType.C128_Default_Tuple, DataType.C128_Default) \ + .get_op_info() + + +@op_info_register(sequence_stack_op_info) +def _sequence_stack_aicpu(): + """SequenceStack AiCPU register""" + return diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/sin.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/sin.py index 86525bc0401..b08aff30c18 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/sin.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/sin.py @@ -1,34 +1,34 @@ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Sin op""" -from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType - -sin_op_info = AiCPURegOp("Sin") \ - .fusion_type("ELEMWISE") \ - .input(0, "x", "required") \ - .output(0, "y", "required") \ - .dtype_format(DataType.F16_Default, DataType.F16_Default) \ - .dtype_format(DataType.F32_Default, DataType.F32_Default) \ - .dtype_format(DataType.F64_Default, DataType.F64_Default) \ - .dtype_format(DataType.C64_Default, DataType.C64_Default) \ - .dtype_format(DataType.C128_Default, DataType.C128_Default) \ - .get_op_info() - - -@op_info_register(sin_op_info) -def _sin_aicpu(): - """Sin AiCPU register""" - return +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Sin op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +sin_op_info = AiCPURegOp("Sin") \ + .fusion_type("ELEMWISE") \ + .input(0, "x", "required") \ + .output(0, "y", "required") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.F64_Default) \ + .dtype_format(DataType.C64_Default, DataType.C64_Default) \ + .dtype_format(DataType.C128_Default, DataType.C128_Default) \ + .get_op_info() + + +@op_info_register(sin_op_info) +def _sin_aicpu(): + """Sin AiCPU register""" + return diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/sinc.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/sinc.py index 42e47c927e9..0df6b214d6d 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/sinc.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/sinc.py @@ -1,43 +1,43 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Sinc op""" -from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType - -sinc_op_info = AiCPURegOp("Sinc") \ - .fusion_type("ELEMWISE") \ - .input(0, "x", "required") \ - .output(0, "y", "required") \ - .dtype_format(DataType.U8_Default, DataType.F32_Default) \ - .dtype_format(DataType.I8_Default, DataType.F32_Default) \ - .dtype_format(DataType.U16_Default, DataType.F32_Default) \ - .dtype_format(DataType.I16_Default, DataType.F32_Default) \ - .dtype_format(DataType.U32_Default, DataType.F32_Default) \ - .dtype_format(DataType.I32_Default, DataType.F32_Default) \ - .dtype_format(DataType.U64_Default, DataType.F32_Default) \ - .dtype_format(DataType.I64_Default, DataType.F32_Default) \ - .dtype_format(DataType.F16_Default, DataType.F16_Default) \ - .dtype_format(DataType.F32_Default, DataType.F32_Default) \ - .dtype_format(DataType.F64_Default, DataType.F64_Default) \ - .dtype_format(DataType.C64_Default, DataType.C64_Default) \ - .dtype_format(DataType.C128_Default, DataType.C128_Default) \ - .dtype_format(DataType.BOOL_Default, DataType.F32_Default) \ - .get_op_info() - - -@op_info_register(sinc_op_info) -def _sinc_aicpu(): - """Sinc AiCPU register""" - return +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Sinc op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +sinc_op_info = AiCPURegOp("Sinc") \ + .fusion_type("ELEMWISE") \ + .input(0, "x", "required") \ + .output(0, "y", "required") \ + .dtype_format(DataType.U8_Default, DataType.F32_Default) \ + .dtype_format(DataType.I8_Default, DataType.F32_Default) \ + .dtype_format(DataType.U16_Default, DataType.F32_Default) \ + .dtype_format(DataType.I16_Default, DataType.F32_Default) \ + .dtype_format(DataType.U32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I32_Default, DataType.F32_Default) \ + .dtype_format(DataType.U64_Default, DataType.F32_Default) \ + .dtype_format(DataType.I64_Default, DataType.F32_Default) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.F64_Default) \ + .dtype_format(DataType.C64_Default, DataType.C64_Default) \ + .dtype_format(DataType.C128_Default, DataType.C128_Default) \ + .dtype_format(DataType.BOOL_Default, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(sinc_op_info) +def _sinc_aicpu(): + """Sinc AiCPU register""" + return diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/sparse_addmm.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/sparse_addmm.py index 508d6405346..8a955b31e0f 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/sparse_addmm.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/sparse_addmm.py @@ -1,87 +1,87 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""SparseAddmm op""" -from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType - -sparse_addmm_op_info = AiCPURegOp("SparseAddmm") \ - .fusion_type("OPAQUE") \ - .input(0, "x1_indices", "required") \ - .input(1, "x1_values", "required") \ - .input(2, "x1_shape", "required") \ - .input(3, "x2", "required") \ - .input(4, "x3", "required") \ - .input(5, "alpha", "required") \ - .input(6, "beta", "required") \ - .output(0, "y", "required") \ - .dtype_format(DataType.I32_Default, DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, - DataType.I8_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \ - .dtype_format(DataType.I64_Default, DataType.I8_Default, DataType.I64_Default, DataType.I8_Default, - DataType.I8_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \ - .dtype_format(DataType.I32_Default, DataType.I16_Default, DataType.I32_Default, DataType.I16_Default, - DataType.I16_Default, DataType.I16_Default, DataType.I16_Default, DataType.I16_Default) \ - .dtype_format(DataType.I64_Default, DataType.I16_Default, DataType.I64_Default, DataType.I16_Default, - DataType.I16_Default, DataType.I16_Default, DataType.I16_Default, DataType.I16_Default) \ - .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, - DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ - .dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I64_Default, DataType.I32_Default, - DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ - .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default, DataType.I64_Default, - DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ - .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, - DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ - .dtype_format(DataType.I32_Default, DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, - DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \ - .dtype_format(DataType.I64_Default, DataType.U8_Default, DataType.I64_Default, DataType.U8_Default, - DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \ - .dtype_format(DataType.I32_Default, DataType.U16_Default, DataType.I32_Default, DataType.U16_Default, - DataType.U16_Default, DataType.U16_Default, DataType.U16_Default, DataType.U16_Default) \ - .dtype_format(DataType.I64_Default, DataType.U16_Default, DataType.I64_Default, DataType.U16_Default, - DataType.U16_Default, DataType.U16_Default, DataType.U16_Default, DataType.U16_Default) \ - .dtype_format(DataType.I32_Default, DataType.U32_Default, DataType.I32_Default, DataType.U32_Default, - DataType.U32_Default, DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \ - .dtype_format(DataType.I64_Default, DataType.U32_Default, DataType.I64_Default, DataType.U32_Default, - DataType.U32_Default, DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \ - .dtype_format(DataType.I32_Default, DataType.U64_Default, DataType.I32_Default, DataType.U64_Default, - DataType.U64_Default, DataType.U64_Default, DataType.U64_Default, DataType.U64_Default) \ - .dtype_format(DataType.I64_Default, DataType.U64_Default, DataType.I64_Default, DataType.U64_Default, - DataType.U64_Default, DataType.U64_Default, DataType.U64_Default, DataType.U64_Default) \ - .dtype_format(DataType.I32_Default, DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, - DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ - .dtype_format(DataType.I64_Default, DataType.F16_Default, DataType.I64_Default, DataType.F16_Default, - DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ - .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, - DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ - .dtype_format(DataType.I64_Default, DataType.F32_Default, DataType.I64_Default, DataType.F32_Default, - DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ - .dtype_format(DataType.I32_Default, DataType.F64_Default, DataType.I32_Default, DataType.F64_Default, - DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \ - .dtype_format(DataType.I64_Default, DataType.F64_Default, DataType.I64_Default, DataType.F64_Default, - DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \ - .dtype_format(DataType.I32_Default, DataType.C64_Default, DataType.I32_Default, DataType.C64_Default, - DataType.C64_Default, DataType.C64_Default, DataType.C64_Default, DataType.C64_Default) \ - .dtype_format(DataType.I64_Default, DataType.C64_Default, DataType.I64_Default, DataType.C64_Default, - DataType.C64_Default, DataType.C64_Default, DataType.C64_Default, DataType.C64_Default) \ - .dtype_format(DataType.I32_Default, DataType.C128_Default, DataType.I32_Default, DataType.C128_Default, - DataType.C128_Default, DataType.C128_Default, DataType.C128_Default, DataType.C128_Default) \ - .dtype_format(DataType.I64_Default, DataType.C128_Default, DataType.I64_Default, DataType.C128_Default, - DataType.C128_Default, DataType.C128_Default, DataType.C128_Default, DataType.C128_Default) \ - .get_op_info() - - -@op_info_register(sparse_addmm_op_info) -def _sparse_addmm_aicpu(): - """SparseAddmm AiCPU register""" - return +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""SparseAddmm op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +sparse_addmm_op_info = AiCPURegOp("SparseAddmm") \ + .fusion_type("OPAQUE") \ + .input(0, "x1_indices", "required") \ + .input(1, "x1_values", "required") \ + .input(2, "x1_shape", "required") \ + .input(3, "x2", "required") \ + .input(4, "x3", "required") \ + .input(5, "alpha", "required") \ + .input(6, "beta", "required") \ + .output(0, "y", "required") \ + .dtype_format(DataType.I32_Default, DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, + DataType.I8_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \ + .dtype_format(DataType.I64_Default, DataType.I8_Default, DataType.I64_Default, DataType.I8_Default, + DataType.I8_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \ + .dtype_format(DataType.I32_Default, DataType.I16_Default, DataType.I32_Default, DataType.I16_Default, + DataType.I16_Default, DataType.I16_Default, DataType.I16_Default, DataType.I16_Default) \ + .dtype_format(DataType.I64_Default, DataType.I16_Default, DataType.I64_Default, DataType.I16_Default, + DataType.I16_Default, DataType.I16_Default, DataType.I16_Default, DataType.I16_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, + DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I64_Default, DataType.I32_Default, + DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default, DataType.I64_Default, + DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ + .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, + DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ + .dtype_format(DataType.I32_Default, DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, + DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \ + .dtype_format(DataType.I64_Default, DataType.U8_Default, DataType.I64_Default, DataType.U8_Default, + DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \ + .dtype_format(DataType.I32_Default, DataType.U16_Default, DataType.I32_Default, DataType.U16_Default, + DataType.U16_Default, DataType.U16_Default, DataType.U16_Default, DataType.U16_Default) \ + .dtype_format(DataType.I64_Default, DataType.U16_Default, DataType.I64_Default, DataType.U16_Default, + DataType.U16_Default, DataType.U16_Default, DataType.U16_Default, DataType.U16_Default) \ + .dtype_format(DataType.I32_Default, DataType.U32_Default, DataType.I32_Default, DataType.U32_Default, + DataType.U32_Default, DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \ + .dtype_format(DataType.I64_Default, DataType.U32_Default, DataType.I64_Default, DataType.U32_Default, + DataType.U32_Default, DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \ + .dtype_format(DataType.I32_Default, DataType.U64_Default, DataType.I32_Default, DataType.U64_Default, + DataType.U64_Default, DataType.U64_Default, DataType.U64_Default, DataType.U64_Default) \ + .dtype_format(DataType.I64_Default, DataType.U64_Default, DataType.I64_Default, DataType.U64_Default, + DataType.U64_Default, DataType.U64_Default, DataType.U64_Default, DataType.U64_Default) \ + .dtype_format(DataType.I32_Default, DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, + DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.I64_Default, DataType.F16_Default, DataType.I64_Default, DataType.F16_Default, + DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I64_Default, DataType.F32_Default, DataType.I64_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I32_Default, DataType.F64_Default, DataType.I32_Default, DataType.F64_Default, + DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \ + .dtype_format(DataType.I64_Default, DataType.F64_Default, DataType.I64_Default, DataType.F64_Default, + DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \ + .dtype_format(DataType.I32_Default, DataType.C64_Default, DataType.I32_Default, DataType.C64_Default, + DataType.C64_Default, DataType.C64_Default, DataType.C64_Default, DataType.C64_Default) \ + .dtype_format(DataType.I64_Default, DataType.C64_Default, DataType.I64_Default, DataType.C64_Default, + DataType.C64_Default, DataType.C64_Default, DataType.C64_Default, DataType.C64_Default) \ + .dtype_format(DataType.I32_Default, DataType.C128_Default, DataType.I32_Default, DataType.C128_Default, + DataType.C128_Default, DataType.C128_Default, DataType.C128_Default, DataType.C128_Default) \ + .dtype_format(DataType.I64_Default, DataType.C128_Default, DataType.I64_Default, DataType.C128_Default, + DataType.C128_Default, DataType.C128_Default, DataType.C128_Default, DataType.C128_Default) \ + .get_op_info() + + +@op_info_register(sparse_addmm_op_info) +def _sparse_addmm_aicpu(): + """SparseAddmm AiCPU register""" + return diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py index 77ec471297a..7746ec27e8d 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py @@ -1,80 +1,80 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""SparseApplyMomentum op""" -from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType - -sparse_apply_momentum_op_info = AiCPURegOp("SparseApplyMomentum") \ - .fusion_type("OPAQUE") \ - .attr("use_locking", "bool") \ - .attr("use_nesterov", "bool") \ - .input(0, "var", "required") \ - .input(1, "accum", "required") \ - .input(2, "lr", "required") \ - .input(3, "grad", "required") \ - .input(4, "indices", "required") \ - .input(5, "momentum", "required") \ - .output(0, "var", "required") \ - .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default, \ - DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \ - .dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I16_Default, DataType.I16_Default, \ - DataType.I32_Default, DataType.I16_Default, DataType.I16_Default) \ - .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, \ - DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ - .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, \ - DataType.I32_Default, DataType.I64_Default, DataType.I64_Default) \ - .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, \ - DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \ - .dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.U16_Default, DataType.U16_Default, \ - DataType.I32_Default, DataType.U16_Default, DataType.U16_Default) \ - .dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default, DataType.U32_Default, \ - DataType.I32_Default, DataType.U32_Default, DataType.U32_Default) \ - .dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default, DataType.U64_Default, \ - DataType.I32_Default, DataType.U64_Default, DataType.U64_Default) \ - .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, \ - DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \ - .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, \ - DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \ - .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, \ - DataType.I32_Default, DataType.F64_Default, DataType.F64_Default) \ - .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default, \ - DataType.I64_Default, DataType.I8_Default, DataType.I8_Default) \ - .dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I16_Default, DataType.I16_Default, \ - DataType.I64_Default, DataType.I16_Default, DataType.I16_Default) \ - .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, \ - DataType.I64_Default, DataType.I32_Default, DataType.I32_Default) \ - .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, \ - DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ - .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, \ - DataType.I64_Default, DataType.U8_Default, DataType.U8_Default) \ - .dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.U16_Default, DataType.U16_Default, \ - DataType.I64_Default, DataType.U16_Default, DataType.U16_Default) \ - .dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default, DataType.U32_Default, \ - DataType.I64_Default, DataType.U32_Default, DataType.U32_Default) \ - .dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default, DataType.U64_Default, \ - DataType.I64_Default, DataType.U64_Default, DataType.U64_Default) \ - .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, \ - DataType.I64_Default, DataType.I16_Default, DataType.F16_Default) \ - .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, \ - DataType.I64_Default, DataType.F32_Default, DataType.F32_Default) \ - .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, \ - DataType.I64_Default, DataType.F64_Default, DataType.F64_Default) \ - .get_op_info() - - -@op_info_register(sparse_apply_momentum_op_info) -def _sparse_apply_momentum_aicpu(): - """SparseApplyMomentum AiCPU register""" - return +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""SparseApplyMomentum op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +sparse_apply_momentum_op_info = AiCPURegOp("SparseApplyMomentum") \ + .fusion_type("OPAQUE") \ + .attr("use_locking", "bool") \ + .attr("use_nesterov", "bool") \ + .input(0, "var", "required") \ + .input(1, "accum", "required") \ + .input(2, "lr", "required") \ + .input(3, "grad", "required") \ + .input(4, "indices", "required") \ + .input(5, "momentum", "required") \ + .output(0, "var", "required") \ + .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default, \ + DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \ + .dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I16_Default, DataType.I16_Default, \ + DataType.I32_Default, DataType.I16_Default, DataType.I16_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, \ + DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, \ + DataType.I32_Default, DataType.I64_Default, DataType.I64_Default) \ + .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, \ + DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \ + .dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.U16_Default, DataType.U16_Default, \ + DataType.I32_Default, DataType.U16_Default, DataType.U16_Default) \ + .dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default, DataType.U32_Default, \ + DataType.I32_Default, DataType.U32_Default, DataType.U32_Default) \ + .dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default, DataType.U64_Default, \ + DataType.I32_Default, DataType.U64_Default, DataType.U64_Default) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, \ + DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, \ + DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, \ + DataType.I32_Default, DataType.F64_Default, DataType.F64_Default) \ + .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default, \ + DataType.I64_Default, DataType.I8_Default, DataType.I8_Default) \ + .dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I16_Default, DataType.I16_Default, \ + DataType.I64_Default, DataType.I16_Default, DataType.I16_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, \ + DataType.I64_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, \ + DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ + .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, \ + DataType.I64_Default, DataType.U8_Default, DataType.U8_Default) \ + .dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.U16_Default, DataType.U16_Default, \ + DataType.I64_Default, DataType.U16_Default, DataType.U16_Default) \ + .dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default, DataType.U32_Default, \ + DataType.I64_Default, DataType.U32_Default, DataType.U32_Default) \ + .dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default, DataType.U64_Default, \ + DataType.I64_Default, DataType.U64_Default, DataType.U64_Default) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, \ + DataType.I64_Default, DataType.I16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, \ + DataType.I64_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, \ + DataType.I64_Default, DataType.F64_Default, DataType.F64_Default) \ + .get_op_info() + + +@op_info_register(sparse_apply_momentum_op_info) +def _sparse_apply_momentum_aicpu(): + """SparseApplyMomentum AiCPU register""" + return diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py index 90854a34a21..e0bdd86e434 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py @@ -1,79 +1,79 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""SparseApplyProximalGradientDescent""" -from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType - -sparse_apply_proximal_gradient_descent_op_info = AiCPURegOp("SparseApplyProximalGradientDescent") \ - .fusion_type("OPAQUE") \ - .attr("use_locking", "bool") \ - .input(0, "var", "required") \ - .input(1, "alpha", "required") \ - .input(2, "l1", "required") \ - .input(3, "l2", "required") \ - .input(4, "grad", "required") \ - .input(5, "indices", "required") \ - .output(0, "var", "required") \ - .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default, \ - DataType.I8_Default, DataType.I32_Default, DataType.I8_Default) \ - .dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I16_Default, DataType.I16_Default, \ - DataType.I16_Default, DataType.I32_Default, DataType.I16_Default) \ - .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, \ - DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ - .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, \ - DataType.I64_Default, DataType.I32_Default, DataType.I64_Default) \ - .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, \ - DataType.U8_Default, DataType.I32_Default, DataType.U8_Default) \ - .dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.U16_Default, DataType.U16_Default, \ - DataType.U16_Default, DataType.I32_Default, DataType.U16_Default) \ - .dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default, DataType.U32_Default, \ - DataType.U32_Default, DataType.I32_Default, DataType.U32_Default) \ - .dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default, DataType.U64_Default, \ - DataType.U64_Default, DataType.I32_Default, DataType.U64_Default) \ - .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, \ - DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \ - .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, \ - DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \ - .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, \ - DataType.F64_Default, DataType.I32_Default, DataType.F64_Default) \ - .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default, \ - DataType.I8_Default, DataType.I64_Default, DataType.I8_Default) \ - .dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I16_Default, DataType.I16_Default, \ - DataType.I16_Default, DataType.I64_Default, DataType.I16_Default) \ - .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, \ - DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \ - .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, \ - DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ - .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, \ - DataType.U8_Default, DataType.I64_Default, DataType.U8_Default) \ - .dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.U16_Default, DataType.U16_Default, \ - DataType.U16_Default, DataType.I64_Default, DataType.U16_Default) \ - .dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default, DataType.U32_Default, \ - DataType.U32_Default, DataType.I64_Default, DataType.U32_Default) \ - .dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default, DataType.U64_Default, \ - DataType.U64_Default, DataType.I64_Default, DataType.U64_Default) \ - .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, \ - DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \ - .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, \ - DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \ - .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, \ - DataType.F64_Default, DataType.I64_Default, DataType.F64_Default) \ - .get_op_info() - - -@op_info_register(sparse_apply_proximal_gradient_descent_op_info) -def _sparse_apply_proximal_gradient_descent_aicpu(): - """SparseApplyProximalGradientDescent""" - return +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""SparseApplyProximalGradientDescent""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +sparse_apply_proximal_gradient_descent_op_info = AiCPURegOp("SparseApplyProximalGradientDescent") \ + .fusion_type("OPAQUE") \ + .attr("use_locking", "bool") \ + .input(0, "var", "required") \ + .input(1, "alpha", "required") \ + .input(2, "l1", "required") \ + .input(3, "l2", "required") \ + .input(4, "grad", "required") \ + .input(5, "indices", "required") \ + .output(0, "var", "required") \ + .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default, \ + DataType.I8_Default, DataType.I32_Default, DataType.I8_Default) \ + .dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I16_Default, DataType.I16_Default, \ + DataType.I16_Default, DataType.I32_Default, DataType.I16_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, \ + DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, \ + DataType.I64_Default, DataType.I32_Default, DataType.I64_Default) \ + .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, \ + DataType.U8_Default, DataType.I32_Default, DataType.U8_Default) \ + .dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.U16_Default, DataType.U16_Default, \ + DataType.U16_Default, DataType.I32_Default, DataType.U16_Default) \ + .dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default, DataType.U32_Default, \ + DataType.U32_Default, DataType.I32_Default, DataType.U32_Default) \ + .dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default, DataType.U64_Default, \ + DataType.U64_Default, DataType.I32_Default, DataType.U64_Default) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, \ + DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, \ + DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, \ + DataType.F64_Default, DataType.I32_Default, DataType.F64_Default) \ + .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default, \ + DataType.I8_Default, DataType.I64_Default, DataType.I8_Default) \ + .dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I16_Default, DataType.I16_Default, \ + DataType.I16_Default, DataType.I64_Default, DataType.I16_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, \ + DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \ + .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, \ + DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ + .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, \ + DataType.U8_Default, DataType.I64_Default, DataType.U8_Default) \ + .dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.U16_Default, DataType.U16_Default, \ + DataType.U16_Default, DataType.I64_Default, DataType.U16_Default) \ + .dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default, DataType.U32_Default, \ + DataType.U32_Default, DataType.I64_Default, DataType.U32_Default) \ + .dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default, DataType.U64_Default, \ + DataType.U64_Default, DataType.I64_Default, DataType.U64_Default) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, \ + DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, \ + DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, \ + DataType.F64_Default, DataType.I64_Default, DataType.F64_Default) \ + .get_op_info() + + +@op_info_register(sparse_apply_proximal_gradient_descent_op_info) +def _sparse_apply_proximal_gradient_descent_aicpu(): + """SparseApplyProximalGradientDescent""" + return diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/sparse_softmax.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/sparse_softmax.py index 7323e16df57..2ef2f83e2fd 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/sparse_softmax.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/sparse_softmax.py @@ -1,33 +1,33 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""SparseSoftmax op""" -from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType - -sparse_softmax_op_info = AiCPURegOp("SparseSoftmax") \ - .fusion_type("OPAQUE") \ - .input(0, "indices", "required") \ - .input(1, "values", "required") \ - .input(2, "shape", "required") \ - .output(0, "y", "required") \ - .dtype_format(DataType.I64_Default, DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \ - .dtype_format(DataType.I64_Default, DataType.F64_Default, DataType.I64_Default, DataType.F64_Default) \ - .get_op_info() - - -@op_info_register(sparse_softmax_op_info) -def _sparse_softmax_aicpu(): - """SparseSoftmax AiCPU register""" - return +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""SparseSoftmax op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +sparse_softmax_op_info = AiCPURegOp("SparseSoftmax") \ + .fusion_type("OPAQUE") \ + .input(0, "indices", "required") \ + .input(1, "values", "required") \ + .input(2, "shape", "required") \ + .output(0, "y", "required") \ + .dtype_format(DataType.I64_Default, DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \ + .dtype_format(DataType.I64_Default, DataType.F64_Default, DataType.I64_Default, DataType.F64_Default) \ + .get_op_info() + + +@op_info_register(sparse_softmax_op_info) +def _sparse_softmax_aicpu(): + """SparseSoftmax AiCPU register""" + return diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py index aef55d0a6e4..019d966142c 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py @@ -1,51 +1,51 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""SparseTensorToCSRSparseMatrix op""" -from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType - -sparse_tensor_to_csr_sparse_matrix_op_info = AiCPURegOp("SparseTensorToCSRSparseMatrix") \ - .fusion_type("OPAQUE") \ - .input(0, "x_indices", "required") \ - .input(1, "x_values", "required") \ - .input(2, "x_dense_shape", "required") \ - .output(0, "y_dense_shape", "required") \ - .output(1, "y_batch_pointers", "required") \ - .output(2, "y_row_pointers", "required") \ - .output(3, "y_col_indices", "required") \ - .output(4, "y_values", "required") \ - .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, - DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.F32_Default) \ - .dtype_format(DataType.I32_Default, DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, - DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.F64_Default) \ - .dtype_format(DataType.I32_Default, DataType.C64_Default, DataType.I32_Default, DataType.I32_Default, - DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.C64_Default) \ - .dtype_format(DataType.I32_Default, DataType.C128_Default, DataType.I32_Default, DataType.I32_Default, - DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.C128_Default) \ - .dtype_format(DataType.I64_Default, DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, - DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.F32_Default) \ - .dtype_format(DataType.I64_Default, DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, - DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.F64_Default) \ - .dtype_format(DataType.I64_Default, DataType.C64_Default, DataType.I64_Default, DataType.I64_Default, - DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.C64_Default) \ - .dtype_format(DataType.I64_Default, DataType.C128_Default, DataType.I64_Default, DataType.I64_Default, - DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.C128_Default) \ - .get_op_info() - - -@op_info_register(sparse_tensor_to_csr_sparse_matrix_op_info) -def _sparse_tensor_to_csr_sparse_matrix_aicpu(): - """SparseTensorToCSRSparseMatrix AiCPU register""" - return +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""SparseTensorToCSRSparseMatrix op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +sparse_tensor_to_csr_sparse_matrix_op_info = AiCPURegOp("SparseTensorToCSRSparseMatrix") \ + .fusion_type("OPAQUE") \ + .input(0, "x_indices", "required") \ + .input(1, "x_values", "required") \ + .input(2, "x_dense_shape", "required") \ + .output(0, "y_dense_shape", "required") \ + .output(1, "y_batch_pointers", "required") \ + .output(2, "y_row_pointers", "required") \ + .output(3, "y_col_indices", "required") \ + .output(4, "y_values", "required") \ + .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, + DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I32_Default, DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, + DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.F64_Default) \ + .dtype_format(DataType.I32_Default, DataType.C64_Default, DataType.I32_Default, DataType.I32_Default, + DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.C64_Default) \ + .dtype_format(DataType.I32_Default, DataType.C128_Default, DataType.I32_Default, DataType.I32_Default, + DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.C128_Default) \ + .dtype_format(DataType.I64_Default, DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, + DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.F32_Default) \ + .dtype_format(DataType.I64_Default, DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, + DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.F64_Default) \ + .dtype_format(DataType.I64_Default, DataType.C64_Default, DataType.I64_Default, DataType.I64_Default, + DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.C64_Default) \ + .dtype_format(DataType.I64_Default, DataType.C128_Default, DataType.I64_Default, DataType.I64_Default, + DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.C128_Default) \ + .get_op_info() + + +@op_info_register(sparse_tensor_to_csr_sparse_matrix_op_info) +def _sparse_tensor_to_csr_sparse_matrix_aicpu(): + """SparseTensorToCSRSparseMatrix AiCPU register""" + return diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/tril.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/tril.py index 9fb5fdaf9e2..e389b2a27ac 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/tril.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/tril.py @@ -1,42 +1,42 @@ -# Copyright 2021-2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Tril op""" -from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType - -tril_op_info = AiCPURegOp("Tril") \ - .fusion_type("OPAQUE") \ - .attr("diagonal", "int") \ - .input(0, "x", "required") \ - .output(0, "y", "required") \ - .dtype_format(DataType.U8_Default, DataType.U8_Default) \ - .dtype_format(DataType.U16_Default, DataType.U16_Default) \ - .dtype_format(DataType.U32_Default, DataType.U32_Default) \ - .dtype_format(DataType.U64_Default, DataType.U64_Default) \ - .dtype_format(DataType.I8_Default, DataType.I8_Default) \ - .dtype_format(DataType.I16_Default, DataType.I16_Default) \ - .dtype_format(DataType.I32_Default, DataType.I32_Default) \ - .dtype_format(DataType.I64_Default, DataType.I64_Default) \ - .dtype_format(DataType.F16_Default, DataType.F16_Default) \ - .dtype_format(DataType.F32_Default, DataType.F32_Default) \ - .dtype_format(DataType.F64_Default, DataType.F64_Default) \ - .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ - .get_op_info() - - -@op_info_register(tril_op_info) -def _tril_aicpu(): - """Tril AiCPU register""" - return +# Copyright 2021-2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tril op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +tril_op_info = AiCPURegOp("Tril") \ + .fusion_type("OPAQUE") \ + .attr("diagonal", "int") \ + .input(0, "x", "required") \ + .output(0, "y", "required") \ + .dtype_format(DataType.U8_Default, DataType.U8_Default) \ + .dtype_format(DataType.U16_Default, DataType.U16_Default) \ + .dtype_format(DataType.U32_Default, DataType.U32_Default) \ + .dtype_format(DataType.U64_Default, DataType.U64_Default) \ + .dtype_format(DataType.I8_Default, DataType.I8_Default) \ + .dtype_format(DataType.I16_Default, DataType.I16_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I64_Default, DataType.I64_Default) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.F64_Default) \ + .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ + .get_op_info() + + +@op_info_register(tril_op_info) +def _tril_aicpu(): + """Tril AiCPU register""" + return diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/unravel_index.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/unravel_index.py index 6cfc7c970e2..4376a3acacd 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/unravel_index.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/unravel_index.py @@ -1,32 +1,32 @@ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""UnravelIndex op""" -from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType - -unravel_index_op_info = AiCPURegOp("UnravelIndex") \ - .fusion_type("OPAQUE") \ - .input(0, "indices", "required") \ - .input(1, "dims", "required") \ - .output(0, "y", "required") \ - .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ - .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ - .get_op_info() - - -@op_info_register(unravel_index_op_info) -def _unravel_index_aicpu(): - """UnravelIndex AiCPU register""" - return +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""UnravelIndex op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +unravel_index_op_info = AiCPURegOp("UnravelIndex") \ + .fusion_type("OPAQUE") \ + .input(0, "indices", "required") \ + .input(1, "dims", "required") \ + .output(0, "y", "required") \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ + .get_op_info() + + +@op_info_register(unravel_index_op_info) +def _unravel_index_aicpu(): + """UnravelIndex AiCPU register""" + return diff --git a/mindspore/python/mindspore/ops/function/random_func.py b/mindspore/python/mindspore/ops/function/random_func.py old mode 100755 new mode 100644 diff --git a/mindspore/python/mindspore/ops/function/sparse_unary_func.py b/mindspore/python/mindspore/ops/function/sparse_unary_func.py old mode 100755 new mode 100644 diff --git a/mindspore/python/mindspore/ops/operations/_embedding_cache_ops.py b/mindspore/python/mindspore/ops/operations/_embedding_cache_ops.py old mode 100755 new mode 100644 diff --git a/mindspore/python/mindspore/ops/operations/_inner_ops.py b/mindspore/python/mindspore/ops/operations/_inner_ops.py old mode 100755 new mode 100644 diff --git a/mindspore/python/mindspore/ops/operations/_quant_ops.py b/mindspore/python/mindspore/ops/operations/_quant_ops.py old mode 100755 new mode 100644 diff --git a/mindspore/python/mindspore/ops/operations/array_ops.py b/mindspore/python/mindspore/ops/operations/array_ops.py old mode 100755 new mode 100644 diff --git a/mindspore/python/mindspore/ops/operations/inner_ops.py b/mindspore/python/mindspore/ops/operations/inner_ops.py old mode 100755 new mode 100644 diff --git a/mindspore/python/mindspore/ops/operations/random_ops.py b/mindspore/python/mindspore/ops/operations/random_ops.py old mode 100755 new mode 100644 diff --git a/mindspore/python/mindspore/ops_generate/arg_dtype_cast.py b/mindspore/python/mindspore/ops_generate/arg_dtype_cast.py old mode 100755 new mode 100644 diff --git a/mindspore/python/mindspore/profiler/common/util.py b/mindspore/python/mindspore/profiler/common/util.py index d0567508e22..dc4c7f9d30e 100644 --- a/mindspore/python/mindspore/profiler/common/util.py +++ b/mindspore/python/mindspore/profiler/common/util.py @@ -1,444 +1,444 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -Profiler util. - -This module provides the utils. -""" -import os - -# one sys count takes 10 ns, 1 ms has 100000 system count -import re -import shutil -import stat - -from mindspore import log as logger - - -def to_int(param, param_name): - """ - Transfer param to int type. - - Args: - param (Any): A param transformed. - param_name (str): Param name. - - Returns: - int, value after transformed. - - """ - try: - param = int(param) - except ValueError as err: - raise TypeError('Must be Integer: ' + param_name) from err - return param - - -def fwrite_format(output_data_path, data_source=None, is_print=False, is_start=False): - """ - Write data to the output file. - - Args: - output_data_path (str): The output file path of the data. - data_source (str, list, tuple): The data to write. - is_print (bool): whether to print the data to stdout. - is_start (bool): Whether is the first line of the output file, will remove the old file if True." - """ - - if is_start and os.path.exists(output_data_path): - os.remove(output_data_path) - - if isinstance(data_source, str) and data_source.startswith("title:"): - title_label = '=' * 20 - data_source = title_label + data_source[6:] + title_label - - with open(output_data_path, 'a+') as f: - if isinstance(data_source, (list, tuple)): - for raw_data in data_source: - if isinstance(raw_data, (list, tuple)): - raw_data = map(str, raw_data) - raw_data = " ".join(raw_data) - f.write(raw_data) - f.write("\n") - else: - f.write(data_source) - f.write("\n") - os.chmod(output_data_path, stat.S_IREAD | stat.S_IWRITE) - - if is_print: - if isinstance(data_source, (list, tuple)): - for raw_data in data_source: - if isinstance(raw_data, (list, tuple)): - raw_data = map(str, raw_data) - raw_data = " ".join(raw_data) - logger.info(raw_data) - else: - logger.info(data_source) - - -def get_log_slice_id(file_name): - """Get log slice id.""" - pattern = re.compile(r'(?<=slice_)\d+') - slice_list = pattern.findall(file_name) - index = re.findall(r'\d+', slice_list[0]) - return int(index[0]) - - -def get_file_join_name(input_path, file_name): - """ - Search files under the special path, and will join all the files to one file. - - Args: - input_path (str): The source path, will search files under it. - file_name (str): The target of the filename, such as 'hwts.log.data.45.dev'. - - Returns: - str, the join file name. - """ - name_list = [] - file_join_name = '' - input_path = os.path.realpath(input_path) - if os.path.exists(input_path): - files = os.listdir(input_path) - for f in files: - if file_name in f and not f.endswith('.done') and not f.endswith('.join') \ - and not f.endswith('.zip'): - name_list.append(f) - - # resort name_list - name_list.sort(key=get_log_slice_id) - - if len(name_list) == 1: - file_join_name = os.path.join(input_path, name_list[0]) - elif len(name_list) > 1: - file_join_name = os.path.join(input_path, '%s.join' % file_name) - if os.path.exists(file_join_name): - os.remove(file_join_name) - file_join_name = os.path.realpath(file_join_name) - with open(file_join_name, 'ab') as bin_data: - for i in name_list: - file = input_path + os.sep + i - with open(file, 'rb') as txt: - bin_data.write(txt.read()) - return file_join_name - - -def get_file_path(input_path, file_name): - """ - Search files under the special path. - - Args: - input_path (str): The source path, will search files under it. - file_name (str): The target of the filename, such as 'host_start_log'. - - Returns: - str, a special file path. If there can not find the special path, will return None. - """ - - input_path = os.path.realpath(input_path) - if os.path.exists(input_path): - files = os.listdir(input_path) - for f in files: - if file_name in f and not f.endswith('.done') \ - and not f.endswith('.zip'): - return os.path.join(input_path, f) - - return None - - -def parse_device_id(filename, device_id_list, profiler_file_prefix): - """Parse device id from filename.""" - items = filename.split("_") - if filename.startswith("step_trace_raw"): - device_num = "" - if len(items) > 3: - device_num = items[3] - else: - device_num = items[-1].split(".")[0] if items[-1].split(".") else "" - - if device_num.isdigit() and '_'.join(items[:-1]) in profiler_file_prefix: - device_id_list.add(device_num) - - -def analyse_device_list_from_profiler_dir(profiler_dir): - """ - Analyse device list from profiler dir. - - Args: - profiler_dir (str): The profiler data dir. - - Returns: - list, the device_id list. - """ - profiler_file_prefix = ["timeline_display", "output_op_compute_time"] - - device_id_list = set() - for _, _, filenames in os.walk(profiler_dir): - for filename in filenames: - parse_device_id(filename, device_id_list, profiler_file_prefix) - - return sorted(list(device_id_list)) - - -def query_latest_trace_time_file(profiler_dir, device_id=0): - """ - Query the latest trace time file. - - Args: - profiler_dir (str): The profiler directory. - device_id (int): The id of device. - - Returns: - str, the latest trace time file path. - """ - files = os.listdir(profiler_dir) - target_file = f'step_trace_raw_{device_id}_detail_time.csv' - try: - latest_file = max( - filter( - lambda file: file == target_file, - files - ), - key=lambda file: os.stat(os.path.join(profiler_dir, file)).st_mtime - ) - except ValueError: - return None - return os.path.join(profiler_dir, latest_file) - - -def query_step_trace_file(profiler_dir): - """ - Query for all step trace file. - - Args: - profiler_dir (str): The directory that contains all step trace files. - - Returns: - str, the file path of step trace time. - """ - files = os.listdir(profiler_dir) - training_trace_file = list( - filter( - lambda file: file.startswith('training_trace') and not file.endswith('.done'), - files - ) - ) - if training_trace_file: - return os.path.join(profiler_dir, training_trace_file[0]) - return None - - -def get_summary_for_step_trace(average_info, header, is_training_mode=True): - """The property of summary info.""" - if not average_info or not header: - return {} - total_time = get_field_value(average_info, 'total', header) - iteration_interval = get_field_value(average_info, 'iteration_interval', - header) - summary_part = { - 'total_time': total_time, - 'iteration_interval': iteration_interval, - 'iteration_interval_percent': calculate_percent(iteration_interval, total_time), - } - if is_training_mode: - fp_and_bp = get_field_value(average_info, 'fp_and_bp', header) - tail = get_field_value(average_info, 'tail', header) - summary = { - 'fp_and_bp': fp_and_bp, - 'fp_and_bp_percent': calculate_percent(fp_and_bp, total_time), - 'tail': tail, - 'tail_percent': calculate_percent(tail, total_time) - } - else: - fp = get_field_value(average_info, 'fp', header) - summary = { - 'fp': fp, - 'fp_percent': calculate_percent(fp, total_time) - } - summary.update(summary_part) - return summary - - -def calculate_percent(partial, total): - """Calculate percent value.""" - if total: - percent = round(partial / total * 100, 2) - else: - percent = 0 - return f'{percent}%' - - -def to_millisecond(sys_count, limit=4): - """Translate system count to millisecond.""" - per_ms_syscnt = 100000 - return round(sys_count / per_ms_syscnt, limit) - - -def get_field_value(row_info, field_name, header, time_type='realtime'): - """ - Extract basic info through row_info. - - Args: - row_info (list): The list of data info in one row. - field_name (str): The name in header. - header (list[str]): The list of field names. - time_type (str): The type of value, `realtime` or `systime`. Default: `realtime`. - - Returns: - dict, step trace info in dict format. - """ - field_index = header.index(field_name) - value = row_info[field_index] - value = to_int(value, field_name) - if time_type == 'realtime': - value = to_millisecond(value) - - return value - - -def get_options(options): - if options is None: - options = {} - - return options - - -def combine_stream_task_id(stream_id, task_id): - """Combine Stream ID and task ID into unique values.""" - return f'{stream_id}_{task_id}' - - -def get_newest_file(file_list): - """ - Find the newest files - :param file_list: - :return: - """ - newest_file_list = [] - newest_timestamp = '0' - for file_path in file_list: - timestamp = file_path.split('.')[0].split('/')[-1].split('_')[-1] - newest_timestamp = max(timestamp, newest_timestamp) - - for file_path in file_list: - if file_path.split('.')[0].split('/')[-1].split('_')[-1] == newest_timestamp: - newest_file_list.append(file_path) - - newest_file_list.sort() - return newest_file_list - - -class ProfilerPathManager: - """A path manager to manage profiler path""" - - FRAMEWORK_DIR = "FRAMEWORK" - INVALID_VALUE = -1 - - @classmethod - def get_fwk_path(cls, profiler_path: str) -> str: - """Get FRAMEWORK directory path""" - fwk_path = os.path.join(profiler_path, cls.FRAMEWORK_DIR) - if os.path.isdir(fwk_path): - return fwk_path - return "" - - @classmethod - def get_cann_path(cls, profiler_path: str) -> str: - """Get CANN Prof directory path""" - sub_dirs = os.listdir(os.path.realpath(profiler_path)) - for sub_dir in sub_dirs: - sub_path = os.path.join(profiler_path, sub_dir) - if os.path.isdir(sub_path) and re.match(r"^PROF_\d+_\d+_[0-9a-zA-Z]+", sub_dir): - return sub_path - return "" - - @classmethod - def get_host_path(cls, cann_path: str) -> str: - """Get CANN Prof host directory path""" - host_path = os.path.join(cann_path, 'host') - if os.path.exists(host_path): - return host_path - return "" - - @classmethod - def get_device_path(cls, cann_path: str) -> str: - """Get CANN Prof device directory path""" - sub_dirs = os.listdir(os.path.realpath(cann_path)) - for sub_dir in sub_dirs: - sub_path = os.path.join(cann_path, sub_dir) - if os.path.isdir(sub_path) and re.match(r"^device_\d", sub_dir): - return sub_path - return "" - - @classmethod - def remove_path_safety(cls, path: str): - """Remove directory""" - msg = f"Failed to remove path: {path}" - if os.path.islink(path): - raise RuntimeError(msg) - if not os.path.exists(path): - return - try: - shutil.rmtree(path) - except FileNotFoundError: - return - except Exception as err: - raise RuntimeError(msg) from err - - @classmethod - def remove_file_safety(cls, file: str): - """Remove file""" - msg = f"Failed to remove file: {file}" - if os.path.islink(file): - raise RuntimeError(msg) - if not os.path.exists(file): - return - try: - os.remove(file) - except FileExistsError: - return - except Exception as err: - raise RuntimeError(msg) from err - - @classmethod - def simplify_data(cls, profiler_path: str, simplify_flag: bool): - """Profiler simplify temporary data""" - cann_path = cls.get_cann_path(profiler_path) - device_path = cls.get_device_path(cann_path) - host_path = cls.get_host_path(cann_path) - rm_dirs = ['sqlite', 'summary', 'timeline'] if simplify_flag else ['sqlite'] - for rm_dir in rm_dirs: - if device_path: - target_path = os.path.join(device_path, rm_dir) - cls.remove_path_safety(target_path) - if host_path: - target_path = os.path.join(host_path, rm_dir) - cls.remove_path_safety(target_path) - if simplify_flag: - fwk_path = cls.get_fwk_path(profiler_path) - cls.remove_path_safety(fwk_path) - if not cann_path: - return - cann_rm_dirs = ['analyze', 'mindstudio_profiler_log', 'mindstudio_profiler_output'] - for cann_rm_dir in cann_rm_dirs: - target_path = os.path.join(cann_path, cann_rm_dir) - cls.remove_path_safety(target_path) - log_patten = r'msprof_anlysis_\d+\.log$' - for cann_file in os.listdir(cann_path): - file_path = os.path.join(cann_path, cann_file) - if not os.path.isfile(file_path): - continue - if re.match(log_patten, cann_file): - cls.remove_file_safety(file_path) +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Profiler util. + +This module provides the utils. +""" +import os + +# one sys count takes 10 ns, 1 ms has 100000 system count +import re +import shutil +import stat + +from mindspore import log as logger + + +def to_int(param, param_name): + """ + Transfer param to int type. + + Args: + param (Any): A param transformed. + param_name (str): Param name. + + Returns: + int, value after transformed. + + """ + try: + param = int(param) + except ValueError as err: + raise TypeError('Must be Integer: ' + param_name) from err + return param + + +def fwrite_format(output_data_path, data_source=None, is_print=False, is_start=False): + """ + Write data to the output file. + + Args: + output_data_path (str): The output file path of the data. + data_source (str, list, tuple): The data to write. + is_print (bool): whether to print the data to stdout. + is_start (bool): Whether is the first line of the output file, will remove the old file if True." + """ + + if is_start and os.path.exists(output_data_path): + os.remove(output_data_path) + + if isinstance(data_source, str) and data_source.startswith("title:"): + title_label = '=' * 20 + data_source = title_label + data_source[6:] + title_label + + with open(output_data_path, 'a+') as f: + if isinstance(data_source, (list, tuple)): + for raw_data in data_source: + if isinstance(raw_data, (list, tuple)): + raw_data = map(str, raw_data) + raw_data = " ".join(raw_data) + f.write(raw_data) + f.write("\n") + else: + f.write(data_source) + f.write("\n") + os.chmod(output_data_path, stat.S_IREAD | stat.S_IWRITE) + + if is_print: + if isinstance(data_source, (list, tuple)): + for raw_data in data_source: + if isinstance(raw_data, (list, tuple)): + raw_data = map(str, raw_data) + raw_data = " ".join(raw_data) + logger.info(raw_data) + else: + logger.info(data_source) + + +def get_log_slice_id(file_name): + """Get log slice id.""" + pattern = re.compile(r'(?<=slice_)\d+') + slice_list = pattern.findall(file_name) + index = re.findall(r'\d+', slice_list[0]) + return int(index[0]) + + +def get_file_join_name(input_path, file_name): + """ + Search files under the special path, and will join all the files to one file. + + Args: + input_path (str): The source path, will search files under it. + file_name (str): The target of the filename, such as 'hwts.log.data.45.dev'. + + Returns: + str, the join file name. + """ + name_list = [] + file_join_name = '' + input_path = os.path.realpath(input_path) + if os.path.exists(input_path): + files = os.listdir(input_path) + for f in files: + if file_name in f and not f.endswith('.done') and not f.endswith('.join') \ + and not f.endswith('.zip'): + name_list.append(f) + + # resort name_list + name_list.sort(key=get_log_slice_id) + + if len(name_list) == 1: + file_join_name = os.path.join(input_path, name_list[0]) + elif len(name_list) > 1: + file_join_name = os.path.join(input_path, '%s.join' % file_name) + if os.path.exists(file_join_name): + os.remove(file_join_name) + file_join_name = os.path.realpath(file_join_name) + with open(file_join_name, 'ab') as bin_data: + for i in name_list: + file = input_path + os.sep + i + with open(file, 'rb') as txt: + bin_data.write(txt.read()) + return file_join_name + + +def get_file_path(input_path, file_name): + """ + Search files under the special path. + + Args: + input_path (str): The source path, will search files under it. + file_name (str): The target of the filename, such as 'host_start_log'. + + Returns: + str, a special file path. If there can not find the special path, will return None. + """ + + input_path = os.path.realpath(input_path) + if os.path.exists(input_path): + files = os.listdir(input_path) + for f in files: + if file_name in f and not f.endswith('.done') \ + and not f.endswith('.zip'): + return os.path.join(input_path, f) + + return None + + +def parse_device_id(filename, device_id_list, profiler_file_prefix): + """Parse device id from filename.""" + items = filename.split("_") + if filename.startswith("step_trace_raw"): + device_num = "" + if len(items) > 3: + device_num = items[3] + else: + device_num = items[-1].split(".")[0] if items[-1].split(".") else "" + + if device_num.isdigit() and '_'.join(items[:-1]) in profiler_file_prefix: + device_id_list.add(device_num) + + +def analyse_device_list_from_profiler_dir(profiler_dir): + """ + Analyse device list from profiler dir. + + Args: + profiler_dir (str): The profiler data dir. + + Returns: + list, the device_id list. + """ + profiler_file_prefix = ["timeline_display", "output_op_compute_time"] + + device_id_list = set() + for _, _, filenames in os.walk(profiler_dir): + for filename in filenames: + parse_device_id(filename, device_id_list, profiler_file_prefix) + + return sorted(list(device_id_list)) + + +def query_latest_trace_time_file(profiler_dir, device_id=0): + """ + Query the latest trace time file. + + Args: + profiler_dir (str): The profiler directory. + device_id (int): The id of device. + + Returns: + str, the latest trace time file path. + """ + files = os.listdir(profiler_dir) + target_file = f'step_trace_raw_{device_id}_detail_time.csv' + try: + latest_file = max( + filter( + lambda file: file == target_file, + files + ), + key=lambda file: os.stat(os.path.join(profiler_dir, file)).st_mtime + ) + except ValueError: + return None + return os.path.join(profiler_dir, latest_file) + + +def query_step_trace_file(profiler_dir): + """ + Query for all step trace file. + + Args: + profiler_dir (str): The directory that contains all step trace files. + + Returns: + str, the file path of step trace time. + """ + files = os.listdir(profiler_dir) + training_trace_file = list( + filter( + lambda file: file.startswith('training_trace') and not file.endswith('.done'), + files + ) + ) + if training_trace_file: + return os.path.join(profiler_dir, training_trace_file[0]) + return None + + +def get_summary_for_step_trace(average_info, header, is_training_mode=True): + """The property of summary info.""" + if not average_info or not header: + return {} + total_time = get_field_value(average_info, 'total', header) + iteration_interval = get_field_value(average_info, 'iteration_interval', + header) + summary_part = { + 'total_time': total_time, + 'iteration_interval': iteration_interval, + 'iteration_interval_percent': calculate_percent(iteration_interval, total_time), + } + if is_training_mode: + fp_and_bp = get_field_value(average_info, 'fp_and_bp', header) + tail = get_field_value(average_info, 'tail', header) + summary = { + 'fp_and_bp': fp_and_bp, + 'fp_and_bp_percent': calculate_percent(fp_and_bp, total_time), + 'tail': tail, + 'tail_percent': calculate_percent(tail, total_time) + } + else: + fp = get_field_value(average_info, 'fp', header) + summary = { + 'fp': fp, + 'fp_percent': calculate_percent(fp, total_time) + } + summary.update(summary_part) + return summary + + +def calculate_percent(partial, total): + """Calculate percent value.""" + if total: + percent = round(partial / total * 100, 2) + else: + percent = 0 + return f'{percent}%' + + +def to_millisecond(sys_count, limit=4): + """Translate system count to millisecond.""" + per_ms_syscnt = 100000 + return round(sys_count / per_ms_syscnt, limit) + + +def get_field_value(row_info, field_name, header, time_type='realtime'): + """ + Extract basic info through row_info. + + Args: + row_info (list): The list of data info in one row. + field_name (str): The name in header. + header (list[str]): The list of field names. + time_type (str): The type of value, `realtime` or `systime`. Default: `realtime`. + + Returns: + dict, step trace info in dict format. + """ + field_index = header.index(field_name) + value = row_info[field_index] + value = to_int(value, field_name) + if time_type == 'realtime': + value = to_millisecond(value) + + return value + + +def get_options(options): + if options is None: + options = {} + + return options + + +def combine_stream_task_id(stream_id, task_id): + """Combine Stream ID and task ID into unique values.""" + return f'{stream_id}_{task_id}' + + +def get_newest_file(file_list): + """ + Find the newest files + :param file_list: + :return: + """ + newest_file_list = [] + newest_timestamp = '0' + for file_path in file_list: + timestamp = file_path.split('.')[0].split('/')[-1].split('_')[-1] + newest_timestamp = max(timestamp, newest_timestamp) + + for file_path in file_list: + if file_path.split('.')[0].split('/')[-1].split('_')[-1] == newest_timestamp: + newest_file_list.append(file_path) + + newest_file_list.sort() + return newest_file_list + + +class ProfilerPathManager: + """A path manager to manage profiler path""" + + FRAMEWORK_DIR = "FRAMEWORK" + INVALID_VALUE = -1 + + @classmethod + def get_fwk_path(cls, profiler_path: str) -> str: + """Get FRAMEWORK directory path""" + fwk_path = os.path.join(profiler_path, cls.FRAMEWORK_DIR) + if os.path.isdir(fwk_path): + return fwk_path + return "" + + @classmethod + def get_cann_path(cls, profiler_path: str) -> str: + """Get CANN Prof directory path""" + sub_dirs = os.listdir(os.path.realpath(profiler_path)) + for sub_dir in sub_dirs: + sub_path = os.path.join(profiler_path, sub_dir) + if os.path.isdir(sub_path) and re.match(r"^PROF_\d+_\d+_[0-9a-zA-Z]+", sub_dir): + return sub_path + return "" + + @classmethod + def get_host_path(cls, cann_path: str) -> str: + """Get CANN Prof host directory path""" + host_path = os.path.join(cann_path, 'host') + if os.path.exists(host_path): + return host_path + return "" + + @classmethod + def get_device_path(cls, cann_path: str) -> str: + """Get CANN Prof device directory path""" + sub_dirs = os.listdir(os.path.realpath(cann_path)) + for sub_dir in sub_dirs: + sub_path = os.path.join(cann_path, sub_dir) + if os.path.isdir(sub_path) and re.match(r"^device_\d", sub_dir): + return sub_path + return "" + + @classmethod + def remove_path_safety(cls, path: str): + """Remove directory""" + msg = f"Failed to remove path: {path}" + if os.path.islink(path): + raise RuntimeError(msg) + if not os.path.exists(path): + return + try: + shutil.rmtree(path) + except FileNotFoundError: + return + except Exception as err: + raise RuntimeError(msg) from err + + @classmethod + def remove_file_safety(cls, file: str): + """Remove file""" + msg = f"Failed to remove file: {file}" + if os.path.islink(file): + raise RuntimeError(msg) + if not os.path.exists(file): + return + try: + os.remove(file) + except FileExistsError: + return + except Exception as err: + raise RuntimeError(msg) from err + + @classmethod + def simplify_data(cls, profiler_path: str, simplify_flag: bool): + """Profiler simplify temporary data""" + cann_path = cls.get_cann_path(profiler_path) + device_path = cls.get_device_path(cann_path) + host_path = cls.get_host_path(cann_path) + rm_dirs = ['sqlite', 'summary', 'timeline'] if simplify_flag else ['sqlite'] + for rm_dir in rm_dirs: + if device_path: + target_path = os.path.join(device_path, rm_dir) + cls.remove_path_safety(target_path) + if host_path: + target_path = os.path.join(host_path, rm_dir) + cls.remove_path_safety(target_path) + if simplify_flag: + fwk_path = cls.get_fwk_path(profiler_path) + cls.remove_path_safety(fwk_path) + if not cann_path: + return + cann_rm_dirs = ['analyze', 'mindstudio_profiler_log', 'mindstudio_profiler_output'] + for cann_rm_dir in cann_rm_dirs: + target_path = os.path.join(cann_path, cann_rm_dir) + cls.remove_path_safety(target_path) + log_patten = r'msprof_anlysis_\d+\.log$' + for cann_file in os.listdir(cann_path): + file_path = os.path.join(cann_path, cann_file) + if not os.path.isfile(file_path): + continue + if re.match(log_patten, cann_file): + cls.remove_file_safety(file_path) diff --git a/mindspore/python/mindspore/profiler/parser/minddata_parser.py b/mindspore/python/mindspore/profiler/parser/minddata_parser.py index dce0be556d4..978db1698d5 100644 --- a/mindspore/python/mindspore/profiler/parser/minddata_parser.py +++ b/mindspore/python/mindspore/profiler/parser/minddata_parser.py @@ -1,186 +1,186 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Minddata aicpu parser.""" -import os -import glob -import csv - -from mindspore.profiler.common.util import get_file_join_name, fwrite_format -from mindspore import log as logger -from mindspore.profiler.common.validator.validate_path import \ - validate_and_normalize_path - - -class MinddataParser: - """Minddata Aicpu Parser.""" - - @staticmethod - def parse_step_minddata_aicpu_data(one_step, result): - """ - Parse step mind_data ai_cpu data. - - Args: - one_step (str): The mind_data step info text, it is one of two structures. - - Type queue: node_name,queue_size,run_start,run_end - Type run: node_name,run_start,run_end,queue_size - - result ([[node_name, node_start, node_end, queue_size]]): Step info list. - """ - - if not one_step: - return - node_info = one_step.split(", ") - node_name, node_start, node_end, queue_size = "", 0, 0, 0 - if node_info: - node_name = node_info[0].replace("Node:", "") - - if len(node_info) > 3: - if "queue" in node_info[1]: - queue_size = node_info[1].replace("queue size:", "") - node_start = node_info[2].replace("Run start:", "") - node_end = node_info[3].replace("Run end:", "") - elif "Run" in node_info[1]: - queue_size = node_info[3].replace("queue size:", "") - node_start = node_info[1].replace("Run start:", "") - node_end = node_info[2].replace("Run end:", "") - queue_size = int(queue_size) if queue_size.isdigit() else queue_size - node_start = int(node_start) if node_start.isdigit() else node_start - node_end = int(node_end) if node_end.isdigit() else node_end - - one_step_list = [node_name, node_start, node_end, queue_size] - result.append(one_step_list) - - @staticmethod - def parse_minddata_aicpu_data(minddata_aicpu_source_path): - """ - Parse minddata get_next info which contains queue size and execute time. - - Args: - minddata_aicpu_source_path (str): the source file path. - - Returns: - list[Union[str, float]], the converted data. - """ - result = list() - try: - minddata_aicpu_source_path = validate_and_normalize_path(minddata_aicpu_source_path) - with open(minddata_aicpu_source_path) as source_data_file: - source_data = source_data_file.read() - step_data = source_data.split("\x00") - for one_step in step_data: - MinddataParser.parse_step_minddata_aicpu_data(one_step, result) - except OSError: - logger.error("Open get_next profiling file error.") - - return result - - @staticmethod - def execute(source_path, output_path, job_id, device_id): - """ - Execute the parser. - - Args: - source_path (str): the source file path, eg: profiler. - output_path (str): the output file path, eg: profiler. - job_id (str): the job id, eg: PROF_XXX/device_* - device_id (str): the device id. - """ - if MinddataParser._is_legacy_aicpu_data(source_path, job_id): - logger.warning("The aicpu data is legacy, which will be deprecated in the future, please update your " - "CANN and driver version.") - MinddataParser._execute_legacy(os.path.join(source_path, job_id), output_path, device_id) - return - - MinddataParser._execute(source_path, output_path, job_id, device_id) - - @staticmethod - def _is_legacy_aicpu_data(source_path, job_id) -> bool: - """ - Check whether the aicpu data is legacy. - - Args: - source_path (str): the source file path, eg: profiler. - job_id (str): the job id, eg: PROF_XXX/device_* - Returns: - bool, True if the aicpu data is legacy, False otherwise. - """ - legacy_files = glob.glob(os.path.join(source_path, job_id, "data", "DATA_PREPROCESS.*")) - return len(legacy_files) > 0 - - @staticmethod - def _execute(source_path, output_path, job_id, device_id): - """ - Execute the parser when using newest CANN and driver version. - - Args: - source_path (str): the source file path, eg: profiler. - output_path (str): the output file path, eg: profiler. - job_id (str): the job id, eg: PROF_XXX/device_* - device_id (str): the device id. - """ - minddata_aicpu_data = [] - prof_path = job_id.split("/")[0] - if not prof_path: - logger.error("The job_id is invalid: %s", job_id) - return - - prof_output_path = os.path.join(source_path, prof_path, "mindstudio_profiler_output") - aicpu_file = glob.glob(os.path.join(prof_output_path, "aicpu_mi_*.csv")) - if not aicpu_file: - return - - # aicpu_file len is 1 - for file_path in aicpu_file: - file_path = validate_and_normalize_path(file_path) - with open(file_path, "r", newline='') as f: - reader = csv.reader(f) - minddata_aicpu_data = [[line[1], line[2][:-2], line[3][:-2], line[4]] for line in reader] - - if minddata_aicpu_data: - minddata_aicpu_output_path = os.path.join(output_path, "minddata_aicpu_" + str(device_id) + ".txt") - fwrite_format(minddata_aicpu_output_path, minddata_aicpu_data[1:], is_start=True) - logger.info("Minddata aicpu data has been saved to %s", minddata_aicpu_output_path) - - @staticmethod - def _execute_legacy(source_path, output_path, device_id): - """ - Execute the parser when using legacy CANN and driver version. - - Args: - source_path (str): the source file path, eg: profiler/PROF_XXX/device_*. - output_path (str): the output file path, eg: profiler. - device_id (str): the device id. - """ - col_names = ["node_name", "start_time", "end_time", "queue_size"] - source_path = validate_and_normalize_path(source_path) - minddata_aicpu_source_path = get_file_join_name( - input_path=source_path, file_name='DATA_PREPROCESS.AICPUMI') - if not minddata_aicpu_source_path: - minddata_aicpu_source_path = get_file_join_name( - input_path=source_path, file_name='DATA_PREPROCESS.dev.AICPUMI') - if not minddata_aicpu_source_path: - minddata_aicpu_source_path = get_file_join_name( - input_path=os.path.join(source_path, "data"), file_name='DATA_PREPROCESS.AICPUMI') - if not minddata_aicpu_source_path: - minddata_aicpu_source_path = get_file_join_name( - input_path=os.path.join(source_path, "data"), file_name='DATA_PREPROCESS.dev.AICPUMI') - if not minddata_aicpu_source_path: - return - minddata_aicpu_output_path = os.path.join(output_path, "minddata_aicpu_" + str(device_id) + ".txt") - minddata_aicpu_data = MinddataParser.parse_minddata_aicpu_data(minddata_aicpu_source_path) - if minddata_aicpu_data: - fwrite_format(minddata_aicpu_output_path, " ".join(col_names), is_start=True) - fwrite_format(minddata_aicpu_output_path, minddata_aicpu_data, is_start=True) +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Minddata aicpu parser.""" +import os +import glob +import csv + +from mindspore.profiler.common.util import get_file_join_name, fwrite_format +from mindspore import log as logger +from mindspore.profiler.common.validator.validate_path import \ + validate_and_normalize_path + + +class MinddataParser: + """Minddata Aicpu Parser.""" + + @staticmethod + def parse_step_minddata_aicpu_data(one_step, result): + """ + Parse step mind_data ai_cpu data. + + Args: + one_step (str): The mind_data step info text, it is one of two structures. + + Type queue: node_name,queue_size,run_start,run_end + Type run: node_name,run_start,run_end,queue_size + + result ([[node_name, node_start, node_end, queue_size]]): Step info list. + """ + + if not one_step: + return + node_info = one_step.split(", ") + node_name, node_start, node_end, queue_size = "", 0, 0, 0 + if node_info: + node_name = node_info[0].replace("Node:", "") + + if len(node_info) > 3: + if "queue" in node_info[1]: + queue_size = node_info[1].replace("queue size:", "") + node_start = node_info[2].replace("Run start:", "") + node_end = node_info[3].replace("Run end:", "") + elif "Run" in node_info[1]: + queue_size = node_info[3].replace("queue size:", "") + node_start = node_info[1].replace("Run start:", "") + node_end = node_info[2].replace("Run end:", "") + queue_size = int(queue_size) if queue_size.isdigit() else queue_size + node_start = int(node_start) if node_start.isdigit() else node_start + node_end = int(node_end) if node_end.isdigit() else node_end + + one_step_list = [node_name, node_start, node_end, queue_size] + result.append(one_step_list) + + @staticmethod + def parse_minddata_aicpu_data(minddata_aicpu_source_path): + """ + Parse minddata get_next info which contains queue size and execute time. + + Args: + minddata_aicpu_source_path (str): the source file path. + + Returns: + list[Union[str, float]], the converted data. + """ + result = list() + try: + minddata_aicpu_source_path = validate_and_normalize_path(minddata_aicpu_source_path) + with open(minddata_aicpu_source_path) as source_data_file: + source_data = source_data_file.read() + step_data = source_data.split("\x00") + for one_step in step_data: + MinddataParser.parse_step_minddata_aicpu_data(one_step, result) + except OSError: + logger.error("Open get_next profiling file error.") + + return result + + @staticmethod + def execute(source_path, output_path, job_id, device_id): + """ + Execute the parser. + + Args: + source_path (str): the source file path, eg: profiler. + output_path (str): the output file path, eg: profiler. + job_id (str): the job id, eg: PROF_XXX/device_* + device_id (str): the device id. + """ + if MinddataParser._is_legacy_aicpu_data(source_path, job_id): + logger.warning("The aicpu data is legacy, which will be deprecated in the future, please update your " + "CANN and driver version.") + MinddataParser._execute_legacy(os.path.join(source_path, job_id), output_path, device_id) + return + + MinddataParser._execute(source_path, output_path, job_id, device_id) + + @staticmethod + def _is_legacy_aicpu_data(source_path, job_id) -> bool: + """ + Check whether the aicpu data is legacy. + + Args: + source_path (str): the source file path, eg: profiler. + job_id (str): the job id, eg: PROF_XXX/device_* + Returns: + bool, True if the aicpu data is legacy, False otherwise. + """ + legacy_files = glob.glob(os.path.join(source_path, job_id, "data", "DATA_PREPROCESS.*")) + return len(legacy_files) > 0 + + @staticmethod + def _execute(source_path, output_path, job_id, device_id): + """ + Execute the parser when using newest CANN and driver version. + + Args: + source_path (str): the source file path, eg: profiler. + output_path (str): the output file path, eg: profiler. + job_id (str): the job id, eg: PROF_XXX/device_* + device_id (str): the device id. + """ + minddata_aicpu_data = [] + prof_path = job_id.split("/")[0] + if not prof_path: + logger.error("The job_id is invalid: %s", job_id) + return + + prof_output_path = os.path.join(source_path, prof_path, "mindstudio_profiler_output") + aicpu_file = glob.glob(os.path.join(prof_output_path, "aicpu_mi_*.csv")) + if not aicpu_file: + return + + # aicpu_file len is 1 + for file_path in aicpu_file: + file_path = validate_and_normalize_path(file_path) + with open(file_path, "r", newline='') as f: + reader = csv.reader(f) + minddata_aicpu_data = [[line[1], line[2][:-2], line[3][:-2], line[4]] for line in reader] + + if minddata_aicpu_data: + minddata_aicpu_output_path = os.path.join(output_path, "minddata_aicpu_" + str(device_id) + ".txt") + fwrite_format(minddata_aicpu_output_path, minddata_aicpu_data[1:], is_start=True) + logger.info("Minddata aicpu data has been saved to %s", minddata_aicpu_output_path) + + @staticmethod + def _execute_legacy(source_path, output_path, device_id): + """ + Execute the parser when using legacy CANN and driver version. + + Args: + source_path (str): the source file path, eg: profiler/PROF_XXX/device_*. + output_path (str): the output file path, eg: profiler. + device_id (str): the device id. + """ + col_names = ["node_name", "start_time", "end_time", "queue_size"] + source_path = validate_and_normalize_path(source_path) + minddata_aicpu_source_path = get_file_join_name( + input_path=source_path, file_name='DATA_PREPROCESS.AICPUMI') + if not minddata_aicpu_source_path: + minddata_aicpu_source_path = get_file_join_name( + input_path=source_path, file_name='DATA_PREPROCESS.dev.AICPUMI') + if not minddata_aicpu_source_path: + minddata_aicpu_source_path = get_file_join_name( + input_path=os.path.join(source_path, "data"), file_name='DATA_PREPROCESS.AICPUMI') + if not minddata_aicpu_source_path: + minddata_aicpu_source_path = get_file_join_name( + input_path=os.path.join(source_path, "data"), file_name='DATA_PREPROCESS.dev.AICPUMI') + if not minddata_aicpu_source_path: + return + minddata_aicpu_output_path = os.path.join(output_path, "minddata_aicpu_" + str(device_id) + ".txt") + minddata_aicpu_data = MinddataParser.parse_minddata_aicpu_data(minddata_aicpu_source_path) + if minddata_aicpu_data: + fwrite_format(minddata_aicpu_output_path, " ".join(col_names), is_start=True) + fwrite_format(minddata_aicpu_output_path, minddata_aicpu_data, is_start=True) diff --git a/mindspore/python/mindspore/safeguard/OWNERS b/mindspore/python/mindspore/safeguard/OWNERS index 98c6abf617a..bf8608a2826 100644 --- a/mindspore/python/mindspore/safeguard/OWNERS +++ b/mindspore/python/mindspore/safeguard/OWNERS @@ -1,5 +1,5 @@ -approvers: -- jxlang910 # - -reviewers: +approvers: +- jxlang910 # + +reviewers: - ZhangZheng_99 \ No newline at end of file diff --git a/mindspore/python/mindspore/safeguard/rewrite_obfuscation.py b/mindspore/python/mindspore/safeguard/rewrite_obfuscation.py index 7bb004423dd..2b608d3c175 100644 --- a/mindspore/python/mindspore/safeguard/rewrite_obfuscation.py +++ b/mindspore/python/mindspore/safeguard/rewrite_obfuscation.py @@ -1,531 +1,531 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""obfuscate network based on rewrite interfaces.""" -import os -import re -import secrets -from pathlib import Path - -from mindspore import ops, nn -from mindspore.common.tensor import Tensor -from mindspore import log as logger -from mindspore import load_checkpoint, save_checkpoint -from mindspore.rewrite import SymbolTree, Node, NodeType, ScopedValue -from mindspore.rewrite.parsers import ClassDefParser -from mindspore.rewrite.parsers import ModuleParser - -OBF_RATIOS_LENGTH = 1 -MAX_OBF_RATIOS_NUM = 50 -OBF_RATIOS_WIDTH = 0 -OBF_RATIOS_INSERT_INDEX = 0 - - -def obfuscate_ckpt(network, ckpt_files, target_modules=None, saved_path='./', obfuscate_scale=100): - """ - obfuscate the plaintext checkpoint files. Usually used in conjunction with - :func:`mindspore.load_obf_params_into_net`. - interface. - - Args: - network (nn.Cell): The original network that need to be obfuscated. - ckpt_files (str): The directory path of original ckpt files. - target_modules (list[str]): The target module of network that need to be obfuscated. The first string - represents the network path of target module in original network, which should be in form of ``'A/B/C'``. - The second string represents the obfuscation target module, which should be in form of ``'D|E|F'``. For - example, thr target_modules of GPT2 can be ``['backbone/blocks/attention', 'dense1|dense2|dense3']``. - If target_modules has the third value, it should be in the format of 'obfuscate_layers:all' or - 'obfuscate_layers:int', which represents the number of layers need to be obfuscated of duplicate layers - (such as transformer layers or resnet blocks). If target_modules is ``None``, the function would search - target modules by itself. If found, the searched target module would be used, otherwise suggested target - modules would be given with warning log. Default: ``None``. - saved_path (str): The directory path for saving obfuscated ckpt files. Default: ``'./'``. - obfuscate_scale (Union[float, int]): Obfuscate scale of weights. The generated random obf_ratios will be in - range of (1 / obfuscate_scale, obfuscate_scale). Default: 100. - - Raises: - TypeError: If `network` is not nn.Cell. - TypeError: If `ckpt_files` is not string or `saved_path` is not string. - TypeError: If `target_modules` is not list. - TypeError: If target_modules's elements are not string. - ValueError: If `ckpt_files` is not exist or `saved_path` is not exist. - ValueError: If the number of elements of `target_modules` is less than ``2``. - ValueError: If the first string of `target_modules` contains characters other than uppercase and lowercase - letters, numbers, ``'_'`` and ``'/'``. - ValueError: If the second string of `target_modules` is empty or contains characters other than uppercase and - lowercase letters, numbers, ``'_'`` and ``'|'``. - ValueError: If the third string of `target_modules` is not in the format of 'obfuscate_layers:all' or - 'obfuscate_layers:int'. - - Returns: - list[float], obf_ratios, which is the necessary data that needs to be load when running obfuscated network. - - Examples: - >>> from mindspore import obfuscate_ckpt, save_checkpoint - >>> # Refer to https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py - >>> net = LeNet5() - >>> save_checkpoint(net, './test_net.ckpt') - >>> target_modules = ['', 'fc1|fc2'] - >>> obfuscate_ckpt(net, target_modules, './', './') - """ - if not isinstance(network, nn.Cell): - raise TypeError("network must be nn.Cell, but got {}.".format(type(network))) - _check_dir_path('ckpt_files', ckpt_files) - _check_dir_path('saved_path', saved_path) - # Try to find default target modules - if target_modules is None: - to_split_modules = _get_default_target_modules(ckpt_files) - else: - if len(target_modules) >= 1 and target_modules[0] == '/': - target_modules[0] = '' - to_split_modules = target_modules - if not _check_valid_target(network, to_split_modules): - raise ValueError("The obfuscate module path {} is not exist, please check the input 'target_modules'." - .format(to_split_modules)) - if (not isinstance(obfuscate_scale, (float, int))) or (obfuscate_scale <= 1): - raise ValueError("obfuscate_scale must be float or int, and larger than 1, but got {}." - .format(obfuscate_scale)) - # generate and save obf_ratios to saved_path - path_list = to_split_modules[0].split('/') - target_list = to_split_modules[1].split('|') - global OBF_RATIOS_LENGTH - number_of_ratios = OBF_RATIOS_LENGTH * OBF_RATIOS_WIDTH - if number_of_ratios > MAX_OBF_RATIOS_NUM: - OBF_RATIOS_LENGTH = MAX_OBF_RATIOS_NUM // OBF_RATIOS_WIDTH - number_of_ratios = OBF_RATIOS_LENGTH * OBF_RATIOS_WIDTH - obf_ratios = [] - secrets_generator = secrets.SystemRandom() - for _ in range(number_of_ratios): - secure_float = secrets_generator.uniform(1 / obfuscate_scale, obfuscate_scale) - obf_ratios.append(secure_float) - # start obfuscate ckpt - ckpt_dir_files = os.listdir(ckpt_files) - for ckpt_name in ckpt_dir_files: - sub_path = os.path.abspath(ckpt_files) + '/' + ckpt_name - if Path(sub_path).is_dir(): - sub_ckpt_file_list = os.listdir(sub_path) - new_saved_path = os.path.abspath(saved_path) + '/' + ckpt_name - if not os.path.exists(new_saved_path): - try: - os.mkdir(new_saved_path, mode=0o700) - except FileExistsError: - pass - for sub_ckpt_name in sub_ckpt_file_list: - if not sub_ckpt_name.endswith('.ckpt'): - continue - _obfuscate_single_ckpt(os.path.abspath(sub_path) + '/' + sub_ckpt_name, obf_ratios, path_list, - target_list, new_saved_path) - else: - if not ckpt_name.endswith('.ckpt'): - continue - _obfuscate_single_ckpt(os.path.abspath(ckpt_files) + '/' + ckpt_name, obf_ratios, path_list, - target_list, saved_path) - return obf_ratios - - -def _obfuscate_single_ckpt(ckpt_name, obf_ratios, path_list, target_list, saved_path): - """Obfuscate single ckpt file""" - module_has_been_obfuscated = set() - try: - ckpt_param = load_checkpoint(ckpt_name) - except (ValueError, TypeError, OSError): - logger.error("Load checkpoint failed for file {}.".format(ckpt_name)) - return None - obf_ratios_index = -1 - for item in ckpt_param: - module = _get_valid_module(item, path_list, target_list) - if module: - layer_index = _judge_layer_index(item) - if layer_index >= OBF_RATIOS_LENGTH: - continue - if module not in module_has_been_obfuscated: - module_has_been_obfuscated.add(module) - obf_ratios_index += 1 - ratio_total_index = layer_index * OBF_RATIOS_WIDTH + obf_ratios_index % OBF_RATIOS_WIDTH - ckpt_param[item].set_data(ckpt_param[item].value() / obf_ratios[ratio_total_index]) - # save the obfuscated model to saved_path - obf_param_list = [] - for item in ckpt_param: - obf_param_list.append({'name': item, 'data': ckpt_param[item]}) - ckpt_file_name = ckpt_name.split('/')[-1] - obf_ckpt_file_name = ckpt_file_name.split('.')[0] + '_obf' + '.ckpt' - save_checkpoint(obf_param_list, os.path.abspath(saved_path) + '/' + obf_ckpt_file_name) - return None - - -def load_obf_params_into_net(network, target_modules, obf_ratios, data_parallel_num=1, **kwargs): - """ - load obfuscate ratios into obfuscated network. Usually used in conjunction with :func:`mindspore.obfuscate_ckpt` - interface. - - Args: - network (nn.Cell): The original network that need to be obfuscated. - target_modules (list[str]): The target module of network that need to be obfuscated. The first string - represents the network path of target module in original network, which should be in form of ``'A/B/C'``. - The second string represents the obfuscation target module, which should be in form of ``'D|E|F'``. For - example, thr target_modules of GPT2 can be ``['backbone/blocks/attention', 'dense1|dense2|dense3']``. - If target_modules has the third value, it should be in the format of 'obfuscate_layers:all' or - 'obfuscate_layers:int', which represents the number of layers need to be obfuscated of duplicate layers - (such as transformer layers or resnet blocks). - data_parallel_num (int): The data parallel number of parallel training. Default: 1. - obf_ratios (Tensor): The obf ratios generated when execute :func:`mindspore.obfuscate_ckpt`. - kwargs (dict): Configuration options dictionary. - - - ignored_func_decorators (list[str]): The name list of function decorators in network's python code. - - ignored_class_decorators (list[str]): The name list of class decorators in network's python code. - - Raises: - TypeError: If `network` is not nn.Cell. - TypeError: If `obf_ratios` is not Tensor. - TypeError: If `target_modules` is not list. - TypeError: If target_modules's elements are not string. - ValueError: If the number of elements of `target_modules` is less than ``2``. - ValueError: If `obf_ratios` is empty Tensor. - ValueError: If the first string of `target_modules` contains characters other than uppercase and lowercase - letters, numbers, ``'_'`` and ``'/'``. - ValueError: If the second string of `target_modules` is empty or contains characters other than uppercase and - lowercase letters, numbers, ``'_'`` and ``'|'``. - ValueError: If the third string of `target_modules` is not in the format of 'obfuscate_layers:all' or - 'obfuscate_layers:int'. - TypeError: If `ignored_func_decorators` is not list[str] or `ignored_class_decorators` is not list[str]. - - Examples: - >>> from mindspore import obfuscate_ckpt, save_checkpoint, load_checkpoint, Tensor - >>> import mindspore.common.dtype as mstype - >>> import numpy as np - >>> # Refer to https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py - >>> net = LeNet5() - >>> save_checkpoint(net, './test_net.ckpt') - >>> target_modules = ['', 'fc1|fc2'] - >>> # obfuscate ckpt files - >>> obfuscate_ckpt(net, target_modules, './', './') - >>> # load obf ckpt into network - >>> new_net = LeNet5() - >>> load_checkpoint('./test_net_obf.ckpt', new_net) - >>> obf_ratios = Tensor(np.load('./obf_ratios.npy'), mstype.float16) - >>> obf_net = load_obf_params_into_net(new_net, target_modules, obf_ratios) - """ - if not isinstance(network, nn.Cell): - raise TypeError("network must be nn.Cell, but got {}.".format(type(network))) - if not isinstance(obf_ratios, Tensor): - raise TypeError("obf_ratios must be MindSpore Tensor, but got {}.".format(type(obf_ratios))) - if obf_ratios.size == 0: - raise ValueError("obf_ratios can not be empty.") - if not _check_valid_target(network, target_modules): - raise ValueError("{} is not exist, please check the input 'target_modules'.".format(target_modules)) - if (not isinstance(data_parallel_num, int)) or (data_parallel_num <= 0): - raise ValueError("data_parallel_num must be positive number, but got {}.".format(data_parallel_num)) - if len(target_modules) >= 1 and target_modules[0] == '/': - target_modules[0] = '' - path_list = target_modules[0].split('/') - path_len = len(path_list) - target_list = [] - for _ in range(path_len): - target_list.append([]) - target_list.append(target_modules[1].split('|')) - global MAX_OBF_RATIOS_NUM, OBF_RATIOS_LENGTH - number_of_ratios = OBF_RATIOS_LENGTH * OBF_RATIOS_WIDTH - if number_of_ratios > MAX_OBF_RATIOS_NUM: - OBF_RATIOS_LENGTH = MAX_OBF_RATIOS_NUM // OBF_RATIOS_WIDTH - number_of_ratios = OBF_RATIOS_LENGTH * OBF_RATIOS_WIDTH - MAX_OBF_RATIOS_NUM = number_of_ratios - rewrite_network = _obfuscate_network(network, path_list, target_list, data_parallel_num=data_parallel_num, **kwargs) - setattr(rewrite_network, 'obf_ratios', obf_ratios) - return rewrite_network - - -def _check_dir_path(name, dir_path): - """check directory path""" - if not isinstance(dir_path, str): - raise TypeError("{} must be string, but got {}.".format(name, type(dir_path))) - if not os.path.exists(dir_path): - raise ValueError("{} is not exist, please check the input {}.".format(dir_path, name)) - if not Path(dir_path).is_dir(): - raise TypeError("{} must be a directory path, but got {}.".format(name, dir_path)) - - -def _judge_layer_index(layer_name): - """Judge the layer index of target layers""" - split_name = layer_name.split('.') - for split_str in split_name[:]: - if split_str.isdigit(): - return int(split_str) - return 0 - - -def _check_valid_target(network, target_modules): - """check whether the input 'target_modules' exists""" - if not isinstance(target_modules, list): - raise TypeError("target_modules type should be list, but got {}.".format(type(target_modules))) - if len(target_modules) < 2: - raise ValueError("target_modules should contain at least two string values, in the form of ['A/B/C', 'D1|D2']," - "but got {}.".format(target_modules)) - if (not isinstance(target_modules[0], str)) or (not isinstance(target_modules[1], str)): - raise TypeError("The values of target_modules should be string, but got {} and {}.". - format(type(target_modules[0]), type(target_modules[1]))) - - if not target_modules[1]: - raise ValueError("{} should be a non-empty string value, in the form of 'D1|D2'" - .format(target_modules[1])) - if not re.fullmatch(pattern=r'([a-zA-Z]*[0-9]*\/*_*)*', string=target_modules[0]) \ - or not re.fullmatch(pattern=r'([a-zA-Z]*[0-9]*\|*_*)*', string=target_modules[1]): - raise ValueError("please check the input 'target_modules'{},it should be in the form of ['A/B/C', 'D1|D2']." - "target_modules[0] can only contain uppercase and lowercase letters, numbers, '_' and '/'," - "target_modules[1] can only contain uppercase and lowercase letters, numbers, '_' and '|'" - .format(target_modules)) - # target_modules[0] is allowed to be '', it means the main network path - path_list = target_modules[0].split('/') - target_list = target_modules[1].split('|') - net = network - # DFS check whether path_list is valid - stk = [net] - i = 0 - global OBF_RATIOS_LENGTH - OBF_RATIOS_LENGTH = 1 - while stk and i < len(path_list): - net = stk.pop() - if hasattr(net, path_list[i]): - net = getattr(net, path_list[i]) - i += 1 - if isinstance(net, nn.CellList): - OBF_RATIOS_LENGTH *= len(net) - for n in net: - stk.append(n) - elif isinstance(net, nn.Cell): - stk.append(net) - else: - raise TypeError("Target_modules[0] should be a subgraph and it's type should be nn.Cell(nn.CellList)," - "but got type {}".format(type(net))) - if target_modules[0] != '' and i != len(path_list): - raise ValueError("the path {} does not exist.".format(target_modules[0])) - # check whether target_list is valid - global OBF_RATIOS_WIDTH - OBF_RATIOS_WIDTH = 0 - for target in target_list: - if not hasattr(net, target): - logger.warning("{} does not exist in the path {}".format(target, target_modules[0])) - else: - OBF_RATIOS_WIDTH += 1 - if OBF_RATIOS_WIDTH == 0: - raise ValueError("all targets {} do not exist in the path {}.".format(target_list, target_modules[0])) - _update_max_obf_ratios_num(target_modules) - return True - - -def _update_max_obf_ratios_num(target_modules): - """Update MAX_OBF_RATIOS_NUM""" - if len(target_modules) >= 3: - obfuscate_layers = target_modules[2].split(':') - if len(obfuscate_layers) != 2 or obfuscate_layers[0] != 'obfuscate_layers': - raise ValueError("The third value of target_modules should be in the format of 'obfuscate_layers:all' or" - "'obfuscate_layers:int'") - global MAX_OBF_RATIOS_NUM - if obfuscate_layers[1] == 'all': - MAX_OBF_RATIOS_NUM = OBF_RATIOS_LENGTH * OBF_RATIOS_WIDTH - else: - if not obfuscate_layers[1].isdigit(): - raise ValueError( - "The third value of target_modules should be in the format of 'obfuscate_layers:all' or" - "'obfuscate_layers:int'") - MAX_OBF_RATIOS_NUM = int(obfuscate_layers[1]) * OBF_RATIOS_WIDTH - - -def _get_default_target_modules(ckpt_files): - """Get the default or suggested target modules, if the target modules is None.""" - - def _split_to_path_and_target(module, target): - # split module into path list and target list - target_index = module.index(target) - path = module[:target_index - 1] - target = module[target_index:].split('/')[0] - return path, target - - def _find_default_obfuscate_modules(net_path): - # find modules including the default paths - default_module = {'attention'} - for module in default_module: - if module in net_path and module not in candidate_modules: - candidate_modules.append(net_path) - # find the default targets in the default module - default_target = {'dense', 'query', 'key', 'value'} - for target in default_target: - for candidate in candidate_modules: - if target in candidate: - path, target = _split_to_path_and_target(candidate, target) - if path not in paths: - paths.append(path) - if target not in targets: - targets.append(target) - - def _find_suggested_obfuscate_modules(net_path): - default_target = {'dense', 'query', 'key', 'value'} - for target in default_target: - # find the suggest modules - if target in net_path: - path, target = _split_to_path_and_target(net_path, target) - if [path, target] not in suggest_modules: - suggest_modules.append([path, target]) - - # store the potential candidate_modules - candidate_modules = [] - suggest_modules = [] - paths = [] - targets = [] - ckpt_dir_files = os.listdir(ckpt_files) - for ckpt_name in ckpt_dir_files: - if not ckpt_name.endswith('.ckpt'): - continue - try: - ckpt_param = load_checkpoint(os.path.abspath(ckpt_files) + '/' + ckpt_name) - except (ValueError, TypeError, OSError): - logger.error("Load checkpoint failed for file {}.".format(os.path.abspath(ckpt_files) + '/' + ckpt_name)) - return None - for item in ckpt_param: - param_path = _remove_digit(item) - param_path = '/'.join(param_path) - # find candidate modules including the default paths and append candidate_modules - _find_default_obfuscate_modules(param_path) - # give the suggested modules and find the default targets in the default module - _find_suggested_obfuscate_modules(param_path) - if paths and targets: - target_modules = [paths[0], '|'.join(targets)] - logger.warning("The default obfuscate modules is obtained:{}".format(target_modules)) - return target_modules - # logging the suggested target module - logger.warning("The default obfuscate modules can not be obtained. The suggested possible paths are given below: {}" - .format(suggest_modules)) - raise ValueError("Can not get the default path, please specify the path in the form of ['A/B/C', 'D1|D2']") - - -def _get_valid_module(item, path_list, target_list): - """get the valid module""" - number_path = len(path_list) - net_path = _remove_digit(item) - net_path = '/'.join(net_path[:number_path]) - tar_path = '/'.join(path_list) - # update the weights with obf_ratios in target module - if net_path == tar_path: - for target in target_list: - if target in item.split('.'): - target_index = item.split('.').index(target) - module = ''.join(item.split('.')[:target_index + 1]) - return module - return None - - -def _remove_digit(item): - """remove digit in the parameter path""" - param_path = item.split('.') - for tmp_str in param_path[:]: - if tmp_str.isdigit(): - param_path.remove(tmp_str) - return param_path - - -def _obfuscate_network(model, path_list, target_list, data_parallel_num=1, **kwargs): - """obfuscate original network, including add mul operation and add inputs for passing obf_ratio.""" - - def _insert_input(stree: SymbolTree, arg_name: str = 'y_obf'): - """add inputs for passing obf_ratio""" - last_input = None - for node in stree.nodes(): - if node.get_node_type() == NodeType.Input: - last_input = node - position = stree.after(last_input) - # the insert input node name would be 'input_y_obf' - new_input_node = last_input.create_input(arg_name) - stree.insert(position, new_input_node) - - def _insert_mul(stree: SymbolTree, node: Node, index: int): - """add mul operation for original network""" - arg_list = node.get_targets().copy() - input_y_node = stree.get_node("input_y_obf") - v: str = input_y_node.get_targets()[0].value - sv: ScopedValue = ScopedValue.create_naming_value(v + f'[{index}]') - arg_list.append(sv) - target_list = node.get_targets().copy() - if data_parallel_num > 1: - logger.info("Data parallel number is: {}".format(data_parallel_num)) - new_mul_node = node.create_call_cell(cell=ops.Mul().shard(((data_parallel_num, 1), ())), - targets=target_list, args=arg_list, name='mul') - else: - new_mul_node = node.create_call_cell(cell=ops.Mul(), targets=target_list, args=arg_list, name='mul') - position = stree.after(node) - stree.insert(position, new_mul_node) - - def _insert_mul_by_name(stree: SymbolTree, after_name_list: list): - """add mul operation after the target nodes according the name of them""" - if not after_name_list: - return - for node in stree.nodes(): - for after_name in after_name_list: - if node.get_name() == after_name: - global OBF_RATIOS_INSERT_INDEX - if OBF_RATIOS_INSERT_INDEX < MAX_OBF_RATIOS_NUM: - _insert_mul(stree, node, OBF_RATIOS_INSERT_INDEX) - OBF_RATIOS_INSERT_INDEX += 1 - - def _update_subnet(substree: SymbolTree, subnode: Node): - """update the network once the subnet is obfuscated""" - input_y_node = substree.get_node("input_y_obf") - if input_y_node is None: - return - subnode.get_handler().append_kwarg({"y_obf": input_y_node.get_targets()[0]}) - - def _traverse(stree, i=0): - """traverse and obfuscate the original network""" - if len(path_list) == i: - return - for node in stree.nodes(): - node_name = node.get_name() - if node.get_node_type() == NodeType.Tree and node_name.startswith(path_list[i]): - sub_stree = node.get_sub_tree() - _traverse(sub_stree, i + 1) - _insert_input(sub_stree, arg_name='y_obf') - _insert_mul_by_name(sub_stree, after_name_list=target_list[i + 1]) - _update_subnet(sub_stree, node) - - def _register_denied_func_decorators(fn): - """set the function decorators which should be denied for parse""" - name = "denied_function_decorator_list" - setattr(ClassDefParser, name, fn) - - def _register_denied_class_decorators(fn): - """set the class decorators which should be denied for parse""" - name = "denied_class_decorator_list" - setattr(ModuleParser, name, fn) - - if 'ignored_func_decorators' in kwargs.keys(): - kw_func_dec = kwargs["ignored_func_decorators"] - if not isinstance(kw_func_dec, list): - raise TypeError('{} should be list, but got {}'.format(kw_func_dec, type(kw_func_dec))) - if kw_func_dec and not isinstance(kw_func_dec[0], str): - raise TypeError('elements of {} should be str, but got {}'.format(kw_func_dec, type(kw_func_dec[0]))) - _register_denied_func_decorators(kw_func_dec) - else: - _register_denied_func_decorators(["_args_type_validator_check", "_LogActionOnce", "cell_attr_register"]) - if 'ignored_class_decorators' in kwargs.keys(): - kw_class_dec = kwargs["ignored_class_decorators"] - _register_denied_class_decorators(kw_class_dec) - if not isinstance(kw_class_dec, list): - raise TypeError('{} should be list[str] type, but got {}'.format(kw_class_dec, type(kw_class_dec))) - if kw_class_dec and not isinstance(kw_class_dec[0], str): - raise TypeError('elements of {} should be str, but got {}'.format(kw_class_dec, type(kw_class_dec[0]))) - - main_stree = SymbolTree.create(model) - _traverse(main_stree, 0) - _insert_input(main_stree, arg_name='y_obf') - _insert_mul_by_name(main_stree, after_name_list=target_list[0]) - new_net = main_stree.get_network() - return new_net +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""obfuscate network based on rewrite interfaces.""" +import os +import re +import secrets +from pathlib import Path + +from mindspore import ops, nn +from mindspore.common.tensor import Tensor +from mindspore import log as logger +from mindspore import load_checkpoint, save_checkpoint +from mindspore.rewrite import SymbolTree, Node, NodeType, ScopedValue +from mindspore.rewrite.parsers import ClassDefParser +from mindspore.rewrite.parsers import ModuleParser + +OBF_RATIOS_LENGTH = 1 +MAX_OBF_RATIOS_NUM = 50 +OBF_RATIOS_WIDTH = 0 +OBF_RATIOS_INSERT_INDEX = 0 + + +def obfuscate_ckpt(network, ckpt_files, target_modules=None, saved_path='./', obfuscate_scale=100): + """ + obfuscate the plaintext checkpoint files. Usually used in conjunction with + :func:`mindspore.load_obf_params_into_net`. + interface. + + Args: + network (nn.Cell): The original network that need to be obfuscated. + ckpt_files (str): The directory path of original ckpt files. + target_modules (list[str]): The target module of network that need to be obfuscated. The first string + represents the network path of target module in original network, which should be in form of ``'A/B/C'``. + The second string represents the obfuscation target module, which should be in form of ``'D|E|F'``. For + example, thr target_modules of GPT2 can be ``['backbone/blocks/attention', 'dense1|dense2|dense3']``. + If target_modules has the third value, it should be in the format of 'obfuscate_layers:all' or + 'obfuscate_layers:int', which represents the number of layers need to be obfuscated of duplicate layers + (such as transformer layers or resnet blocks). If target_modules is ``None``, the function would search + target modules by itself. If found, the searched target module would be used, otherwise suggested target + modules would be given with warning log. Default: ``None``. + saved_path (str): The directory path for saving obfuscated ckpt files. Default: ``'./'``. + obfuscate_scale (Union[float, int]): Obfuscate scale of weights. The generated random obf_ratios will be in + range of (1 / obfuscate_scale, obfuscate_scale). Default: 100. + + Raises: + TypeError: If `network` is not nn.Cell. + TypeError: If `ckpt_files` is not string or `saved_path` is not string. + TypeError: If `target_modules` is not list. + TypeError: If target_modules's elements are not string. + ValueError: If `ckpt_files` is not exist or `saved_path` is not exist. + ValueError: If the number of elements of `target_modules` is less than ``2``. + ValueError: If the first string of `target_modules` contains characters other than uppercase and lowercase + letters, numbers, ``'_'`` and ``'/'``. + ValueError: If the second string of `target_modules` is empty or contains characters other than uppercase and + lowercase letters, numbers, ``'_'`` and ``'|'``. + ValueError: If the third string of `target_modules` is not in the format of 'obfuscate_layers:all' or + 'obfuscate_layers:int'. + + Returns: + list[float], obf_ratios, which is the necessary data that needs to be load when running obfuscated network. + + Examples: + >>> from mindspore import obfuscate_ckpt, save_checkpoint + >>> # Refer to https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py + >>> net = LeNet5() + >>> save_checkpoint(net, './test_net.ckpt') + >>> target_modules = ['', 'fc1|fc2'] + >>> obfuscate_ckpt(net, target_modules, './', './') + """ + if not isinstance(network, nn.Cell): + raise TypeError("network must be nn.Cell, but got {}.".format(type(network))) + _check_dir_path('ckpt_files', ckpt_files) + _check_dir_path('saved_path', saved_path) + # Try to find default target modules + if target_modules is None: + to_split_modules = _get_default_target_modules(ckpt_files) + else: + if len(target_modules) >= 1 and target_modules[0] == '/': + target_modules[0] = '' + to_split_modules = target_modules + if not _check_valid_target(network, to_split_modules): + raise ValueError("The obfuscate module path {} is not exist, please check the input 'target_modules'." + .format(to_split_modules)) + if (not isinstance(obfuscate_scale, (float, int))) or (obfuscate_scale <= 1): + raise ValueError("obfuscate_scale must be float or int, and larger than 1, but got {}." + .format(obfuscate_scale)) + # generate and save obf_ratios to saved_path + path_list = to_split_modules[0].split('/') + target_list = to_split_modules[1].split('|') + global OBF_RATIOS_LENGTH + number_of_ratios = OBF_RATIOS_LENGTH * OBF_RATIOS_WIDTH + if number_of_ratios > MAX_OBF_RATIOS_NUM: + OBF_RATIOS_LENGTH = MAX_OBF_RATIOS_NUM // OBF_RATIOS_WIDTH + number_of_ratios = OBF_RATIOS_LENGTH * OBF_RATIOS_WIDTH + obf_ratios = [] + secrets_generator = secrets.SystemRandom() + for _ in range(number_of_ratios): + secure_float = secrets_generator.uniform(1 / obfuscate_scale, obfuscate_scale) + obf_ratios.append(secure_float) + # start obfuscate ckpt + ckpt_dir_files = os.listdir(ckpt_files) + for ckpt_name in ckpt_dir_files: + sub_path = os.path.abspath(ckpt_files) + '/' + ckpt_name + if Path(sub_path).is_dir(): + sub_ckpt_file_list = os.listdir(sub_path) + new_saved_path = os.path.abspath(saved_path) + '/' + ckpt_name + if not os.path.exists(new_saved_path): + try: + os.mkdir(new_saved_path, mode=0o700) + except FileExistsError: + pass + for sub_ckpt_name in sub_ckpt_file_list: + if not sub_ckpt_name.endswith('.ckpt'): + continue + _obfuscate_single_ckpt(os.path.abspath(sub_path) + '/' + sub_ckpt_name, obf_ratios, path_list, + target_list, new_saved_path) + else: + if not ckpt_name.endswith('.ckpt'): + continue + _obfuscate_single_ckpt(os.path.abspath(ckpt_files) + '/' + ckpt_name, obf_ratios, path_list, + target_list, saved_path) + return obf_ratios + + +def _obfuscate_single_ckpt(ckpt_name, obf_ratios, path_list, target_list, saved_path): + """Obfuscate single ckpt file""" + module_has_been_obfuscated = set() + try: + ckpt_param = load_checkpoint(ckpt_name) + except (ValueError, TypeError, OSError): + logger.error("Load checkpoint failed for file {}.".format(ckpt_name)) + return None + obf_ratios_index = -1 + for item in ckpt_param: + module = _get_valid_module(item, path_list, target_list) + if module: + layer_index = _judge_layer_index(item) + if layer_index >= OBF_RATIOS_LENGTH: + continue + if module not in module_has_been_obfuscated: + module_has_been_obfuscated.add(module) + obf_ratios_index += 1 + ratio_total_index = layer_index * OBF_RATIOS_WIDTH + obf_ratios_index % OBF_RATIOS_WIDTH + ckpt_param[item].set_data(ckpt_param[item].value() / obf_ratios[ratio_total_index]) + # save the obfuscated model to saved_path + obf_param_list = [] + for item in ckpt_param: + obf_param_list.append({'name': item, 'data': ckpt_param[item]}) + ckpt_file_name = ckpt_name.split('/')[-1] + obf_ckpt_file_name = ckpt_file_name.split('.')[0] + '_obf' + '.ckpt' + save_checkpoint(obf_param_list, os.path.abspath(saved_path) + '/' + obf_ckpt_file_name) + return None + + +def load_obf_params_into_net(network, target_modules, obf_ratios, data_parallel_num=1, **kwargs): + """ + load obfuscate ratios into obfuscated network. Usually used in conjunction with :func:`mindspore.obfuscate_ckpt` + interface. + + Args: + network (nn.Cell): The original network that need to be obfuscated. + target_modules (list[str]): The target module of network that need to be obfuscated. The first string + represents the network path of target module in original network, which should be in form of ``'A/B/C'``. + The second string represents the obfuscation target module, which should be in form of ``'D|E|F'``. For + example, thr target_modules of GPT2 can be ``['backbone/blocks/attention', 'dense1|dense2|dense3']``. + If target_modules has the third value, it should be in the format of 'obfuscate_layers:all' or + 'obfuscate_layers:int', which represents the number of layers need to be obfuscated of duplicate layers + (such as transformer layers or resnet blocks). + data_parallel_num (int): The data parallel number of parallel training. Default: 1. + obf_ratios (Tensor): The obf ratios generated when execute :func:`mindspore.obfuscate_ckpt`. + kwargs (dict): Configuration options dictionary. + + - ignored_func_decorators (list[str]): The name list of function decorators in network's python code. + - ignored_class_decorators (list[str]): The name list of class decorators in network's python code. + + Raises: + TypeError: If `network` is not nn.Cell. + TypeError: If `obf_ratios` is not Tensor. + TypeError: If `target_modules` is not list. + TypeError: If target_modules's elements are not string. + ValueError: If the number of elements of `target_modules` is less than ``2``. + ValueError: If `obf_ratios` is empty Tensor. + ValueError: If the first string of `target_modules` contains characters other than uppercase and lowercase + letters, numbers, ``'_'`` and ``'/'``. + ValueError: If the second string of `target_modules` is empty or contains characters other than uppercase and + lowercase letters, numbers, ``'_'`` and ``'|'``. + ValueError: If the third string of `target_modules` is not in the format of 'obfuscate_layers:all' or + 'obfuscate_layers:int'. + TypeError: If `ignored_func_decorators` is not list[str] or `ignored_class_decorators` is not list[str]. + + Examples: + >>> from mindspore import obfuscate_ckpt, save_checkpoint, load_checkpoint, Tensor + >>> import mindspore.common.dtype as mstype + >>> import numpy as np + >>> # Refer to https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py + >>> net = LeNet5() + >>> save_checkpoint(net, './test_net.ckpt') + >>> target_modules = ['', 'fc1|fc2'] + >>> # obfuscate ckpt files + >>> obfuscate_ckpt(net, target_modules, './', './') + >>> # load obf ckpt into network + >>> new_net = LeNet5() + >>> load_checkpoint('./test_net_obf.ckpt', new_net) + >>> obf_ratios = Tensor(np.load('./obf_ratios.npy'), mstype.float16) + >>> obf_net = load_obf_params_into_net(new_net, target_modules, obf_ratios) + """ + if not isinstance(network, nn.Cell): + raise TypeError("network must be nn.Cell, but got {}.".format(type(network))) + if not isinstance(obf_ratios, Tensor): + raise TypeError("obf_ratios must be MindSpore Tensor, but got {}.".format(type(obf_ratios))) + if obf_ratios.size == 0: + raise ValueError("obf_ratios can not be empty.") + if not _check_valid_target(network, target_modules): + raise ValueError("{} is not exist, please check the input 'target_modules'.".format(target_modules)) + if (not isinstance(data_parallel_num, int)) or (data_parallel_num <= 0): + raise ValueError("data_parallel_num must be positive number, but got {}.".format(data_parallel_num)) + if len(target_modules) >= 1 and target_modules[0] == '/': + target_modules[0] = '' + path_list = target_modules[0].split('/') + path_len = len(path_list) + target_list = [] + for _ in range(path_len): + target_list.append([]) + target_list.append(target_modules[1].split('|')) + global MAX_OBF_RATIOS_NUM, OBF_RATIOS_LENGTH + number_of_ratios = OBF_RATIOS_LENGTH * OBF_RATIOS_WIDTH + if number_of_ratios > MAX_OBF_RATIOS_NUM: + OBF_RATIOS_LENGTH = MAX_OBF_RATIOS_NUM // OBF_RATIOS_WIDTH + number_of_ratios = OBF_RATIOS_LENGTH * OBF_RATIOS_WIDTH + MAX_OBF_RATIOS_NUM = number_of_ratios + rewrite_network = _obfuscate_network(network, path_list, target_list, data_parallel_num=data_parallel_num, **kwargs) + setattr(rewrite_network, 'obf_ratios', obf_ratios) + return rewrite_network + + +def _check_dir_path(name, dir_path): + """check directory path""" + if not isinstance(dir_path, str): + raise TypeError("{} must be string, but got {}.".format(name, type(dir_path))) + if not os.path.exists(dir_path): + raise ValueError("{} is not exist, please check the input {}.".format(dir_path, name)) + if not Path(dir_path).is_dir(): + raise TypeError("{} must be a directory path, but got {}.".format(name, dir_path)) + + +def _judge_layer_index(layer_name): + """Judge the layer index of target layers""" + split_name = layer_name.split('.') + for split_str in split_name[:]: + if split_str.isdigit(): + return int(split_str) + return 0 + + +def _check_valid_target(network, target_modules): + """check whether the input 'target_modules' exists""" + if not isinstance(target_modules, list): + raise TypeError("target_modules type should be list, but got {}.".format(type(target_modules))) + if len(target_modules) < 2: + raise ValueError("target_modules should contain at least two string values, in the form of ['A/B/C', 'D1|D2']," + "but got {}.".format(target_modules)) + if (not isinstance(target_modules[0], str)) or (not isinstance(target_modules[1], str)): + raise TypeError("The values of target_modules should be string, but got {} and {}.". + format(type(target_modules[0]), type(target_modules[1]))) + + if not target_modules[1]: + raise ValueError("{} should be a non-empty string value, in the form of 'D1|D2'" + .format(target_modules[1])) + if not re.fullmatch(pattern=r'([a-zA-Z]*[0-9]*\/*_*)*', string=target_modules[0]) \ + or not re.fullmatch(pattern=r'([a-zA-Z]*[0-9]*\|*_*)*', string=target_modules[1]): + raise ValueError("please check the input 'target_modules'{},it should be in the form of ['A/B/C', 'D1|D2']." + "target_modules[0] can only contain uppercase and lowercase letters, numbers, '_' and '/'," + "target_modules[1] can only contain uppercase and lowercase letters, numbers, '_' and '|'" + .format(target_modules)) + # target_modules[0] is allowed to be '', it means the main network path + path_list = target_modules[0].split('/') + target_list = target_modules[1].split('|') + net = network + # DFS check whether path_list is valid + stk = [net] + i = 0 + global OBF_RATIOS_LENGTH + OBF_RATIOS_LENGTH = 1 + while stk and i < len(path_list): + net = stk.pop() + if hasattr(net, path_list[i]): + net = getattr(net, path_list[i]) + i += 1 + if isinstance(net, nn.CellList): + OBF_RATIOS_LENGTH *= len(net) + for n in net: + stk.append(n) + elif isinstance(net, nn.Cell): + stk.append(net) + else: + raise TypeError("Target_modules[0] should be a subgraph and it's type should be nn.Cell(nn.CellList)," + "but got type {}".format(type(net))) + if target_modules[0] != '' and i != len(path_list): + raise ValueError("the path {} does not exist.".format(target_modules[0])) + # check whether target_list is valid + global OBF_RATIOS_WIDTH + OBF_RATIOS_WIDTH = 0 + for target in target_list: + if not hasattr(net, target): + logger.warning("{} does not exist in the path {}".format(target, target_modules[0])) + else: + OBF_RATIOS_WIDTH += 1 + if OBF_RATIOS_WIDTH == 0: + raise ValueError("all targets {} do not exist in the path {}.".format(target_list, target_modules[0])) + _update_max_obf_ratios_num(target_modules) + return True + + +def _update_max_obf_ratios_num(target_modules): + """Update MAX_OBF_RATIOS_NUM""" + if len(target_modules) >= 3: + obfuscate_layers = target_modules[2].split(':') + if len(obfuscate_layers) != 2 or obfuscate_layers[0] != 'obfuscate_layers': + raise ValueError("The third value of target_modules should be in the format of 'obfuscate_layers:all' or" + "'obfuscate_layers:int'") + global MAX_OBF_RATIOS_NUM + if obfuscate_layers[1] == 'all': + MAX_OBF_RATIOS_NUM = OBF_RATIOS_LENGTH * OBF_RATIOS_WIDTH + else: + if not obfuscate_layers[1].isdigit(): + raise ValueError( + "The third value of target_modules should be in the format of 'obfuscate_layers:all' or" + "'obfuscate_layers:int'") + MAX_OBF_RATIOS_NUM = int(obfuscate_layers[1]) * OBF_RATIOS_WIDTH + + +def _get_default_target_modules(ckpt_files): + """Get the default or suggested target modules, if the target modules is None.""" + + def _split_to_path_and_target(module, target): + # split module into path list and target list + target_index = module.index(target) + path = module[:target_index - 1] + target = module[target_index:].split('/')[0] + return path, target + + def _find_default_obfuscate_modules(net_path): + # find modules including the default paths + default_module = {'attention'} + for module in default_module: + if module in net_path and module not in candidate_modules: + candidate_modules.append(net_path) + # find the default targets in the default module + default_target = {'dense', 'query', 'key', 'value'} + for target in default_target: + for candidate in candidate_modules: + if target in candidate: + path, target = _split_to_path_and_target(candidate, target) + if path not in paths: + paths.append(path) + if target not in targets: + targets.append(target) + + def _find_suggested_obfuscate_modules(net_path): + default_target = {'dense', 'query', 'key', 'value'} + for target in default_target: + # find the suggest modules + if target in net_path: + path, target = _split_to_path_and_target(net_path, target) + if [path, target] not in suggest_modules: + suggest_modules.append([path, target]) + + # store the potential candidate_modules + candidate_modules = [] + suggest_modules = [] + paths = [] + targets = [] + ckpt_dir_files = os.listdir(ckpt_files) + for ckpt_name in ckpt_dir_files: + if not ckpt_name.endswith('.ckpt'): + continue + try: + ckpt_param = load_checkpoint(os.path.abspath(ckpt_files) + '/' + ckpt_name) + except (ValueError, TypeError, OSError): + logger.error("Load checkpoint failed for file {}.".format(os.path.abspath(ckpt_files) + '/' + ckpt_name)) + return None + for item in ckpt_param: + param_path = _remove_digit(item) + param_path = '/'.join(param_path) + # find candidate modules including the default paths and append candidate_modules + _find_default_obfuscate_modules(param_path) + # give the suggested modules and find the default targets in the default module + _find_suggested_obfuscate_modules(param_path) + if paths and targets: + target_modules = [paths[0], '|'.join(targets)] + logger.warning("The default obfuscate modules is obtained:{}".format(target_modules)) + return target_modules + # logging the suggested target module + logger.warning("The default obfuscate modules can not be obtained. The suggested possible paths are given below: {}" + .format(suggest_modules)) + raise ValueError("Can not get the default path, please specify the path in the form of ['A/B/C', 'D1|D2']") + + +def _get_valid_module(item, path_list, target_list): + """get the valid module""" + number_path = len(path_list) + net_path = _remove_digit(item) + net_path = '/'.join(net_path[:number_path]) + tar_path = '/'.join(path_list) + # update the weights with obf_ratios in target module + if net_path == tar_path: + for target in target_list: + if target in item.split('.'): + target_index = item.split('.').index(target) + module = ''.join(item.split('.')[:target_index + 1]) + return module + return None + + +def _remove_digit(item): + """remove digit in the parameter path""" + param_path = item.split('.') + for tmp_str in param_path[:]: + if tmp_str.isdigit(): + param_path.remove(tmp_str) + return param_path + + +def _obfuscate_network(model, path_list, target_list, data_parallel_num=1, **kwargs): + """obfuscate original network, including add mul operation and add inputs for passing obf_ratio.""" + + def _insert_input(stree: SymbolTree, arg_name: str = 'y_obf'): + """add inputs for passing obf_ratio""" + last_input = None + for node in stree.nodes(): + if node.get_node_type() == NodeType.Input: + last_input = node + position = stree.after(last_input) + # the insert input node name would be 'input_y_obf' + new_input_node = last_input.create_input(arg_name) + stree.insert(position, new_input_node) + + def _insert_mul(stree: SymbolTree, node: Node, index: int): + """add mul operation for original network""" + arg_list = node.get_targets().copy() + input_y_node = stree.get_node("input_y_obf") + v: str = input_y_node.get_targets()[0].value + sv: ScopedValue = ScopedValue.create_naming_value(v + f'[{index}]') + arg_list.append(sv) + target_list = node.get_targets().copy() + if data_parallel_num > 1: + logger.info("Data parallel number is: {}".format(data_parallel_num)) + new_mul_node = node.create_call_cell(cell=ops.Mul().shard(((data_parallel_num, 1), ())), + targets=target_list, args=arg_list, name='mul') + else: + new_mul_node = node.create_call_cell(cell=ops.Mul(), targets=target_list, args=arg_list, name='mul') + position = stree.after(node) + stree.insert(position, new_mul_node) + + def _insert_mul_by_name(stree: SymbolTree, after_name_list: list): + """add mul operation after the target nodes according the name of them""" + if not after_name_list: + return + for node in stree.nodes(): + for after_name in after_name_list: + if node.get_name() == after_name: + global OBF_RATIOS_INSERT_INDEX + if OBF_RATIOS_INSERT_INDEX < MAX_OBF_RATIOS_NUM: + _insert_mul(stree, node, OBF_RATIOS_INSERT_INDEX) + OBF_RATIOS_INSERT_INDEX += 1 + + def _update_subnet(substree: SymbolTree, subnode: Node): + """update the network once the subnet is obfuscated""" + input_y_node = substree.get_node("input_y_obf") + if input_y_node is None: + return + subnode.get_handler().append_kwarg({"y_obf": input_y_node.get_targets()[0]}) + + def _traverse(stree, i=0): + """traverse and obfuscate the original network""" + if len(path_list) == i: + return + for node in stree.nodes(): + node_name = node.get_name() + if node.get_node_type() == NodeType.Tree and node_name.startswith(path_list[i]): + sub_stree = node.get_sub_tree() + _traverse(sub_stree, i + 1) + _insert_input(sub_stree, arg_name='y_obf') + _insert_mul_by_name(sub_stree, after_name_list=target_list[i + 1]) + _update_subnet(sub_stree, node) + + def _register_denied_func_decorators(fn): + """set the function decorators which should be denied for parse""" + name = "denied_function_decorator_list" + setattr(ClassDefParser, name, fn) + + def _register_denied_class_decorators(fn): + """set the class decorators which should be denied for parse""" + name = "denied_class_decorator_list" + setattr(ModuleParser, name, fn) + + if 'ignored_func_decorators' in kwargs.keys(): + kw_func_dec = kwargs["ignored_func_decorators"] + if not isinstance(kw_func_dec, list): + raise TypeError('{} should be list, but got {}'.format(kw_func_dec, type(kw_func_dec))) + if kw_func_dec and not isinstance(kw_func_dec[0], str): + raise TypeError('elements of {} should be str, but got {}'.format(kw_func_dec, type(kw_func_dec[0]))) + _register_denied_func_decorators(kw_func_dec) + else: + _register_denied_func_decorators(["_args_type_validator_check", "_LogActionOnce", "cell_attr_register"]) + if 'ignored_class_decorators' in kwargs.keys(): + kw_class_dec = kwargs["ignored_class_decorators"] + _register_denied_class_decorators(kw_class_dec) + if not isinstance(kw_class_dec, list): + raise TypeError('{} should be list[str] type, but got {}'.format(kw_class_dec, type(kw_class_dec))) + if kw_class_dec and not isinstance(kw_class_dec[0], str): + raise TypeError('elements of {} should be str, but got {}'.format(kw_class_dec, type(kw_class_dec[0]))) + + main_stree = SymbolTree.create(model) + _traverse(main_stree, 0) + _insert_input(main_stree, arg_name='y_obf') + _insert_mul_by_name(main_stree, after_name_list=target_list[0]) + new_net = main_stree.get_network() + return new_net diff --git a/mindspore/python/mindspore/scipy/linalg.py b/mindspore/python/mindspore/scipy/linalg.py old mode 100755 new mode 100644 diff --git a/mindspore/python/mindspore/train/metrics/__init__.py b/mindspore/python/mindspore/train/metrics/__init__.py old mode 100755 new mode 100644 diff --git a/mindspore/python/mindspore/train/metrics/fbeta.py b/mindspore/python/mindspore/train/metrics/fbeta.py old mode 100755 new mode 100644 diff --git a/scripts/build/akg_find_llvm.sh b/scripts/build/akg_find_llvm.sh deleted file mode 100755 index 7e7eb6b7346..00000000000 --- a/scripts/build/akg_find_llvm.sh +++ /dev/null @@ -1,45 +0,0 @@ -#!/bin/bash -# Copyright 2021-2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -# Find a suitable LLVM version for AKG. -# -# This file generates a temporary cmake script file -# and executes it by `cmake -P` (cmake script mode). -# -# If no suitable LLVM is found, the `find_package` function runs normally, -# the `cmake` command exits with status `0`. -# -# If suitable LLVM is found, the `find_package` will encounter the error -# "add_library command is not scriptable" in `LLVMExports.cmake` of LLVM library. -# This error is caused because of running `cmake` in script mode. -# Finally the `cmake` command exit with status `1`. - -echo "find_package(LLVM 16 QUIET)" > akg_llvm_tmp.cmake -echo "find_package(LLVM 15 QUIET)" >> akg_llvm_tmp.cmake -echo "find_package(LLVM 14 QUIET)" >> akg_llvm_tmp.cmake -echo "find_package(LLVM 13 QUIET)" >> akg_llvm_tmp.cmake -echo "find_package(LLVM 12 QUIET)" >> akg_llvm_tmp.cmake -cmake -P akg_llvm_tmp.cmake > /dev/null 2>&1 -result=$? -rm akg_llvm_tmp.cmake - -if [ ${result} -eq 0 ]; then - echo "off" -else - echo "on" -fi - - diff --git a/scripts/build/build_mindspore.sh b/scripts/build/build_mindspore.sh deleted file mode 100755 index be4c9b30326..00000000000 --- a/scripts/build/build_mindspore.sh +++ /dev/null @@ -1,130 +0,0 @@ -#!/bin/bash -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -set -e - -# Create building path -build_mindspore() -{ - echo "start build mindspore project." - mkdir -pv "${BUILD_PATH}/mindspore" - cd "${BUILD_PATH}/mindspore" - CMAKE_ARGS="-DDEBUG_MODE=$DEBUG_MODE -DBUILD_PATH=$BUILD_PATH" - if [[ "X$ENABLE_COVERAGE" = "Xon" ]]; then - CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_COVERAGE=ON" - fi - if [[ "X$RUN_TESTCASES" = "Xon" ]]; then - CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_TESTCASES=ON" - fi - if [[ "X$RUN_CPP_ST_TESTS" = "Xon" ]]; then - CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_CPP_ST=ON" - fi - if [[ -n "$ENABLE_BACKEND" ]]; then - CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_${ENABLE_BACKEND}=ON" - fi - if [[ "X$ENABLE_SYM_FILE" = "Xon" ]]; then - CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_SYM_FILE=ON" - fi - if [[ "X$ENABLE_ASAN" = "Xon" ]]; then - CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_ASAN=ON" - fi - if [[ "X$ENABLE_PROFILE" = "Xon" ]]; then - CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_PROFILE=ON" - fi - if [[ "X$ENABLE_SECURITY" = "Xon" ]]; then - CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_SECURITY=ON" - fi - if [[ "X$ENABLE_TIMELINE" = "Xon" ]]; then - CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_TIMELINE=ON" - fi - if [[ "X$ENABLE_DUMP2PROTO" = "Xon" ]]; then - CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_DUMP_PROTO=ON" - fi - if [[ "X$ENABLE_GITEE" = "Xon" ]]; then - CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_GITEE=ON" - fi - CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_DUMP_IR=${ENABLE_DUMP_IR}" - CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_PYTHON=${ENABLE_PYTHON}" - if [[ "X$ENABLE_MPI" = "Xon" ]]; then - CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_MPI=ON" - fi - if [[ "X$ENABLE_D" = "Xon" ]]; then - CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_D=ON -DASCEND_VERSION=${ASCEND_VERSION}" - fi - if [[ "X$ENABLE_GPU" = "Xon" ]]; then - CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_GPU=ON" - fi - if [[ "X$GPU_BACKEND" = "Xrocm" ]]; then - CMAKE_ARGS="${CMAKE_ARGS} -DGPU_BACKEND_ROCM=ON -DROCM_PATH=$ROCM_PATH" - fi - if [[ "X$GPU_BACKEND" = "Xcuda" ]]; then - CMAKE_ARGS="${CMAKE_ARGS} -DGPU_BACKEND_CUDA=ON -DUSE_CUDA=ON -DCUDA_PATH=$CUDA_PATH -DMS_REQUIRE_CUDA_VERSION=${CUDA_VERSION}" - fi - if [[ "X$ENABLE_CPU" = "Xon" ]]; then - CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_CPU=ON -DX86_64_SIMD=${X86_64_SIMD} -DARM_SIMD=${ARM_SIMD}" - fi - if [[ "X$COMPILE_MINDDATA" = "Xon" ]]; then - CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_MINDDATA=ON" - fi - if [[ "X$USE_GLOG" = "Xon" ]]; then - CMAKE_ARGS="${CMAKE_ARGS} -DUSE_GLOG=ON" - fi - if [[ "X$ENABLE_AKG" = "Xon" ]]; then - CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_AKG=ON" - if [[ "X$USE_LLVM" = "Xon" ]]; then - CMAKE_ARGS="${CMAKE_ARGS} -DUSE_LLVM=ON" - fi - fi - if [[ "X$ENABLE_ACL" = "Xon" ]]; then - CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_ACL=ON" - fi - if [[ "X$ENABLE_DEBUGGER" = "Xon" ]]; then - CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_DEBUGGER=ON" - fi - - if [[ "X$ENABLE_RDMA" = "Xon" ]]; then - CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_RDMA=ON" - fi - if [[ "X$ENABLE_HIDDEN" = "Xoff" ]]; then - CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_HIDDEN=OFF" - fi - if [[ "X$ENABLE_TRT" == "Xon" ]]; then - CMAKE_ARGS="${CMAKE_ARGS} -DTENSORRT_HOME=${TENSORRT_HOME}" - fi - if [[ "X$ENABLE_FAST_HASH_TABLE" == "Xon" ]]; then - CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_FAST_HASH_TABLE=ON" - else - CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_FAST_HASH_TABLE=OFF" - fi - if [[ "X$FASTER_BUILD_FOR_PLUGINS" == "Xon" ]]; then - CMAKE_ARGS="${CMAKE_ARGS} -DONLY_BUILD_DEVICE_PLUGINS=ON" - fi - if [[ "X$ENABLE_AIO" = "Xon" ]]; then - CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_AIO=ON" - fi - if [[ "X$ENABLE_DVM" = "Xon" ]]; then - CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_DVM=ON" - fi - echo "${CMAKE_ARGS}" - if [[ "X$INC_BUILD" = "Xoff" ]]; then - cmake ${CMAKE_ARGS} ${BASEPATH} - fi - if [[ -n "$VERBOSE" ]]; then - CMAKE_VERBOSE="--verbose" - fi - cmake --build . --target package ${CMAKE_VERBOSE} -j$THREAD_NUM - echo "success building mindspore project!" -} diff --git a/scripts/build/check_and_build_ms_kernels_internal.sh b/scripts/build/check_and_build_ms_kernels_internal.sh deleted file mode 100644 index e60a9b05a0c..00000000000 --- a/scripts/build/check_and_build_ms_kernels_internal.sh +++ /dev/null @@ -1,44 +0,0 @@ -#!/bin/bash -# Copyright 2024 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -if [[ -n "${MS_INTERNAL_KERNEL_HOME}" ]]; then - echo "Use local MS_INTERNAL_KERNEL_HOME : ${MS_INTERNAL_KERNEL_HOME}" - return -fi -if [[ "$(uname)" != Linux || ("$(arch)" != x86_64 && "$(arch)" != aarch64) ]]; then - echo "[WARNING] Internal kernels only supports linux system, x86_64 or aarch64 CPU arch." - return -fi -file_path=${BASEPATH}/mindspore/ccsrc/plugin/device/ascend/kernel/internal/prebuild/$(arch) -file_name=${file_path}/ms_kernels_internal.tar.gz -if [[ ! -f "${file_name}" ]]; then - echo "[WARNING] The file ${file_name} does NOT EXIST." - return -fi -file_lines=`cat "${file_name}" | wc -l` -if [[ ${file_lines} -eq 3 ]]; then - echo "[WARNING] The file ms_kernel_internal.tar.gz is not pulled. Please ensure git-lfs is installed by" - echo "[WARNING] 'git lfs install' and retry downloading using 'git lfs pull'." - return -fi -tar -zxf ${file_name} -C ${file_path} -if [[ $? -ne 0 ]]; then - echo "[WARNING] Unzip ms_kernel_internal.tar.gz FAILED!" - return -fi -echo "Unzip ms_kernel_internal.tar.gz SUCCESS!" -export MS_INTERNAL_KERNEL_HOME="${file_path}/ms_kernels_internal" -echo "MS_INTERNAL_KERNEL_HOME = ${MS_INTERNAL_KERNEL_HOME}" \ No newline at end of file diff --git a/scripts/build/check_binary_file.sh b/scripts/build/check_binary_file.sh deleted file mode 100644 index 7a11cf3d3de..00000000000 --- a/scripts/build/check_binary_file.sh +++ /dev/null @@ -1,25 +0,0 @@ -#!/bin/bash -# Copyright 2024 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -arch_name=`uname -m` -lib_file="${BASEPATH}/mindspore/ccsrc/plugin/device/ascend/kernel/dvm/prebuild/${arch_name}/libdvm.a" -if [ -f "${lib_file}" ]; then - file_lines=`cat "${lib_file}" | wc -l` - if [ ${file_lines} -ne 3 ]; then - export ENABLE_DVM="on" - export DVM_LIB="${lib_file}" - fi -fi diff --git a/scripts/build/default_options.sh b/scripts/build/default_options.sh deleted file mode 100755 index 449178e5192..00000000000 --- a/scripts/build/default_options.sh +++ /dev/null @@ -1,69 +0,0 @@ -#!/bin/bash -# Copyright 2021-2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -# shellcheck disable=SC2034 - -set -e - -init_default_options() -{ - # Init default values of build options - export THREAD_NUM=8 - export DEBUG_MODE="off" - VERBOSE="" - export ENABLE_SECURITY="off" - export ENABLE_COVERAGE="off" - export RUN_TESTCASES="off" - export RUN_CPP_ST_TESTS="off" - export ENABLE_BACKEND="" - export ENABLE_ASAN="off" - export ENABLE_PROFILE="off" - export INC_BUILD="off" - export ENABLE_TIMELINE="off" - export ENABLE_DUMP2PROTO="on" - export ENABLE_DUMP_IR="on" - export COMPILE_MINDDATA="on" - export COMPILE_MINDDATA_LITE="lite_cv" - export ENABLE_MPI="off" - export CUDA_VERSION="10.1" - export ASCEND_VERSION="910" - export COMPILE_LITE="off" - export LITE_PLATFORM="" - export LITE_ENABLE_AAR="off" - export USE_GLOG="on" - export ENABLE_AKG="on" - export ENABLE_ACL="off" - export ENABLE_D="off" - export ENABLE_DEBUGGER="on" - export ENABLE_RDMA="off" - export ENABLE_PYTHON="on" - export ENABLE_GPU="off" - export ENABLE_VERBOSE="off" - export ENABLE_GITEE="off" - export ENABLE_MAKE_CLEAN="off" - export X86_64_SIMD="off" - export ARM_SIMD="off" - export DEVICE_VERSION="" - export DEVICE="" - export ENABLE_HIDDEN="on" - export TENSORRT_HOME="" - export USER_ENABLE_DUMP_IR=false - export USER_ENABLE_DEBUGGER=false - export ENABLE_SYM_FILE="off" - export ENABLE_FAST_HASH_TABLE="on" - export CUDA_ARCH="auto" - export ENABLE_AIO="off" -} diff --git a/scripts/build/merge_whl_package.sh b/scripts/build/merge_whl_package.sh deleted file mode 100755 index b140fb1fed6..00000000000 --- a/scripts/build/merge_whl_package.sh +++ /dev/null @@ -1,121 +0,0 @@ -#!/bin/bash -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -set -e -BASEPATH=$(cd "$(dirname $0)"; pwd) -MINDSPORE_ROOT_PATH=`realpath $BASEPATH/../../` -BASE_PACKAGE_UNZIP_DIR=./0 -PACKAGE_FILE_NAME=`basename $1` - -counter=0 -for whl in "$@"; do - echo "Unzip $whl ..." - unzip -q $whl -d $counter - ((++counter)) -done - -MAX_GPU_VERSION=0 -declare -A GPU_VERSION_MAP -for ((i=1;i<$counter;i=$i+1)) -do - echo "Rename $i dirname to mindspore ..." - mv ./$i/mindspore.py* "./$i/mindspore" - echo "Copy $i plugin files to 0 ..." - if [ -d "./$i/mindspore/lib/plugin" ]; then - \cp -rf ./$i/mindspore/lib/plugin/* $BASE_PACKAGE_UNZIP_DIR/mindspore/lib/plugin - fi; - if [ -f "./$i/mindspore/lib/libmpi_collective.so" ]; then - \cp -rf ./$i/mindspore/lib/libmpi_collective.so $BASE_PACKAGE_UNZIP_DIR/mindspore/lib/ - fi; - if [ -f "./$i/mindspore/lib/libmpi_adapter.so" ]; then - \cp -rf ./$i/mindspore/lib/libmpi_adapter.so $BASE_PACKAGE_UNZIP_DIR/mindspore/lib/ - fi; - if [ -f "./$i/mindspore/lib/libmindspore.so" ]; then - \cp -rf ./$i/mindspore/lib/libmindspore.so $BASE_PACKAGE_UNZIP_DIR/mindspore/lib/ - fi; - if [ -f "./$i/mindspore/lib/libmindspore_shared_lib.so" ]; then - \cp -rf ./$i/mindspore/lib/libmindspore_shared_lib.so $BASE_PACKAGE_UNZIP_DIR/mindspore/lib/ - fi; - - # dataset library "mindspore/_c_dataengine.*.so" with 910b dvpp which is biggest should be used - file_size_src=`du ./$i/mindspore/_c_dataengine.*.so | awk '{print $1;}'` - file_size_dst=`du $BASE_PACKAGE_UNZIP_DIR/mindspore/_c_dataengine.*.so | awk '{print $1;}'` - echo "_c_dataengine.*.so, file_size_src: ${file_size_src}, file_size_dst: ${file_size_dst}" - if [ $file_size_src -gt $file_size_dst ]; then - \cp -rf ./$i/mindspore/_c_dataengine.*.so $BASE_PACKAGE_UNZIP_DIR/mindspore/ - fi; - - CUR_GPU_VERSION=`find "./$i/mindspore/lib/plugin" -name 'gpu*' -exec sh -c 'echo ${0##*gpu}' {} \;` - if [ -n "$CUR_GPU_VERSION" ]; then - GPU_VERSION_MAP[$CUR_GPU_VERSION]=$i - else - rm -rf $i - fi; -done - -for key in $(for x in "${!GPU_VERSION_MAP[@]}"; do echo $x; done | sort) -do - i=${GPU_VERSION_MAP[$key]} - CUR_GPU_VERSION=$key - CUDA_OPS_FILE=`basename ./$i/mindspore/lib/plugin/gpu$CUR_GPU_VERSION/libcuda_ops.so*` - if [ "`echo "$CUR_GPU_VERSION > $MAX_GPU_VERSION" | bc`" -eq 1 ]; then - if [ "`echo "$CUR_GPU_VERSION > $MAX_GPU_VERSION" | bc`" -eq 1 ]; then - MAX_GPU_VERSION=$CUR_GPU_VERSION - fi; - if [ ! -d "$BASE_PACKAGE_UNZIP_DIR/mindspore/lib/plugin/gpu" ]; then - mkdir -p $BASE_PACKAGE_UNZIP_DIR/mindspore/lib/plugin/gpu - fi; - \cp -rf ./$i/mindspore/lib/plugin/gpu$CUR_GPU_VERSION/$CUDA_OPS_FILE \ - $BASE_PACKAGE_UNZIP_DIR/mindspore/lib/plugin/gpu/$CUDA_OPS_FILE - fi; - rm -f $BASE_PACKAGE_UNZIP_DIR/mindspore/lib/plugin/gpu$CUR_GPU_VERSION/libcuda_ops.so* - rm -rf $i -done - - -export COMMIT_ID=`cat $BASE_PACKAGE_UNZIP_DIR/mindspore/.commit_id | awk '{print $3}' | sed $'s/\'//g'` -VERSION=`cat $BASE_PACKAGE_UNZIP_DIR/mindspore/version.py | awk '{print $3}' | sed $'s/\'//g'` -echo -n "$VERSION" > $MINDSPORE_ROOT_PATH/version.txt - -echo "Delete useless file ..." -rm -f $BASE_PACKAGE_UNZIP_DIR/mindspore/version.py -rm -f $BASE_PACKAGE_UNZIP_DIR/mindspore/default_config.py -rm -f $BASE_PACKAGE_UNZIP_DIR/mindspore/.commit_id -rm -f $BASE_PACKAGE_UNZIP_DIR/mindspore/lib/libakg.so - -echo "Repacking new wheel package ..." -PACKAGE_WORK_DIR=$MINDSPORE_ROOT_PATH/build -if [ -d "$PACKAGE_WORK_DIR/package" ]; then - rm -rf $PACKAGE_WORK_DIR/package -fi -mkdir -p $MINDSPORE_ROOT_PATH/mindspore/python/mindspore -mkdir -p $PACKAGE_WORK_DIR -PACKAGE_WORK_DIR=`realpath $PACKAGE_WORK_DIR` -mv $BASE_PACKAGE_UNZIP_DIR $PACKAGE_WORK_DIR/package -export MS_PACKAGE_NAME="mindspore" -export BACKEND_POLICY="ms" -export BUILD_PATH=$PACKAGE_WORK_DIR -cd $BUILD_PATH/package -python $MINDSPORE_ROOT_PATH/setup.py bdist_wheel -if [ -d "$MINDSPORE_ROOT_PATH/output" ]; then - rm -rf $MINDSPORE_ROOT_PATH/output -fi -mkdir -p $MINDSPORE_ROOT_PATH/output -mv dist/*.whl $MINDSPORE_ROOT_PATH/output/$PACKAGE_FILE_NAME -cd - -cd $MINDSPORE_ROOT_PATH/output/ -echo "$(sha256sum $PACKAGE_FILE_NAME)" > $PACKAGE_FILE_NAME.sha256 -cd - diff --git a/scripts/build/option_proc_debug.sh b/scripts/build/option_proc_debug.sh deleted file mode 100755 index b8ef0ca758a..00000000000 --- a/scripts/build/option_proc_debug.sh +++ /dev/null @@ -1,97 +0,0 @@ -#!/bin/bash -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -set -e - -build_option_proc_v() -{ - export ENABLE_VERBOSE="on" - export VERBOSE="VERBOSE=1" -} - -build_option_proc_c() -{ - check_on_off $OPTARG c - export ENABLE_COVERAGE="$OPTARG" -} - -build_option_proc_t() -{ - if [[ "X$OPTARG" == "Xon" || "X$OPTARG" == "Xut" ]]; then - export RUN_TESTCASES="on" - elif [[ "X$OPTARG" == "Xoff" ]]; then - export RUN_TESTCASES="off" - elif [[ "X$OPTARG" == "Xst" ]]; then - export RUN_CPP_ST_TESTS="on" - else - echo "Invalid value ${OPTARG} for option -t" - usage - exit 1 - fi -} - -build_option_proc_g() -{ - check_on_off $OPTARG g - export USE_GLOG="$OPTARG" -} - -build_option_proc_h() -{ - usage - exit 0 -} - -build_option_proc_a() -{ - check_on_off $OPTARG a - export ENABLE_ASAN="$OPTARG" -} - -build_option_proc_p() -{ - check_on_off $OPTARG p - export ENABLE_PROFILE="$OPTARG" -} - -build_option_proc_upper_d() -{ - check_on_off $OPTARG D - if [[ "X$OPTARG" == "Xon" ]]; then - if [[ "X$ENABLE_SECURITY" == "Xon" ]]; then - echo "enable security, the dump ir is not available" - usage - exit 1 - fi - export USER_ENABLE_DUMP_IR=true - fi - export ENABLE_DUMP_IR="$OPTARG" - echo "enable dump function graph ir" -} - -build_option_proc_upper_b() -{ - check_on_off $OPTARG B - if [[ "X$OPTARG" == "Xon" ]]; then - if [[ "X$ENABLE_SECURITY" == "Xon" ]]; then - echo "enable security, the debugger is not available" - usage - exit 1 - fi - export USER_ENABLE_DEBUGGER=true - fi - export ENABLE_DEBUGGER="$OPTARG" -} \ No newline at end of file diff --git a/scripts/build/option_proc_lite.sh b/scripts/build/option_proc_lite.sh deleted file mode 100755 index 8711f89e5c2..00000000000 --- a/scripts/build/option_proc_lite.sh +++ /dev/null @@ -1,66 +0,0 @@ -#!/bin/bash -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -set -e - -build_option_proc_n() -{ - if [[ "X$OPTARG" == "Xoff" || "X$OPTARG" == "Xlite" || "X$OPTARG" == "Xfull" || "X$OPTARG" == "Xlite_cv" || "X$OPTARG" == "Xwrapper" ]]; then - export COMPILE_MINDDATA_LITE="$OPTARG" - else - echo "Invalid value ${OPTARG} for option -n" - usage - exit 1 - fi -} - -build_option_proc_upper_i() -{ - COMPILE_LITE="on" - if [[ "$OPTARG" == "arm64" ]]; then - LITE_PLATFORM="arm64" - elif [[ "$OPTARG" == "arm32" ]]; then - LITE_PLATFORM="arm32" - elif [[ "$OPTARG" == "x86_64" ]]; then - export LITE_PLATFORM="x86_64" - else - echo "-I parameter must be arm64、arm32 or x86_64" - exit 1 - fi -} - -build_option_proc_upper_a() -{ - export COMPILE_LITE="on" - if [[ "$OPTARG" == "on" ]]; then - export LITE_ENABLE_AAR="on" - fi -} - -build_option_proc_upper_w() -{ - if [[ "$OPTARG" != "sse" && "$OPTARG" != "off" && "$OPTARG" != "avx" && "$OPTARG" != "avx512" && "$OPTARG" != "neon" ]]; then - echo "Invalid value ${OPTARG} for option -W, -W parameter must be sse|neon|avx|avx512|off" - usage - exit 1 - fi - if [[ "$OPTARG" == "sse" || "$OPTARG" == "avx" || "$OPTARG" == "avx512" ]]; then - export X86_64_SIMD="$OPTARG" - fi - if [[ "$OPTARG" == "neon" ]]; then - export ARM_SIMD="$OPTARG" - fi -} \ No newline at end of file diff --git a/scripts/build/option_proc_mindspore.sh b/scripts/build/option_proc_mindspore.sh deleted file mode 100755 index 0387fc0f1ff..00000000000 --- a/scripts/build/option_proc_mindspore.sh +++ /dev/null @@ -1,97 +0,0 @@ -#!/bin/bash -# Copyright 2021-2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -set -e - -build_option_proc_b() -{ - if [ "X$OPTARG" != "Xcpu" ]; then - echo "Invalid value ${OPTARG} for option -b" - usage - exit 1 - fi - ENABLE_BACKEND=$(echo "$OPTARG" | tr '[a-z]' '[A-Z]') - if [[ "X$ENABLE_BACKEND" != "XCPU" ]]; then - export ENABLE_CPU="on" - fi -} - -build_option_proc_l() -{ - check_on_off $OPTARG l - export ENABLE_PYTHON="$OPTARG" -} - -build_option_proc_s() -{ - check_on_off $OPTARG s - if [[ "X$OPTARG" == "Xon" ]]; then - if [[ "$USER_ENABLE_DUMP_IR" == true ]]; then - echo "enable security, the dump ir is not available" - usage - exit 1 - fi - if [[ "$USER_ENABLE_DEBUGGER" == true ]]; then - echo "enable security, the debugger is not available" - usage - exit 1 - fi - export ENABLE_DUMP_IR="off" - export ENABLE_DEBUGGER="off" - fi - export ENABLE_SECURITY="$OPTARG" - echo "enable security" -} - -build_option_proc_upper_s() -{ - check_on_off $OPTARG S - export ENABLE_GITEE="$OPTARG" - echo "enable download from gitee" -} - -build_option_proc_upper_f() -{ - check_on_off $OPTARG F - export ENABLE_FAST_HASH_TABLE="$OPTARG" -} - -build_option_proc_z() -{ - eval ARG=\$\{$OPTIND\} - if [[ -n "$ARG" && "$ARG" != -* ]]; then - OPTARG="$ARG" - check_on_off $OPTARG z - OPTIND=$((OPTIND + 1)) - else - OPTARG="" - fi - if [[ "X$OPTARG" == "Xoff" ]]; then - export COMPILE_MINDDATA="off" - fi -} - -build_option_proc_upper_g() -{ - if [[ "X$OPTARG" == "Xcommon" || "X$OPTARG" == "Xauto" || "X$OPTARG" == "Xptx" ]]; then - export CUDA_ARCH=$OPTARG - else - echo "Invalid value $OPTARG for option -G" - usage - exit 1 - fi - echo "build gpu for arch $OPTARG" -} diff --git a/scripts/build/parse_device.sh b/scripts/build/parse_device.sh deleted file mode 100755 index e7be774abe3..00000000000 --- a/scripts/build/parse_device.sh +++ /dev/null @@ -1,105 +0,0 @@ -#!/bin/bash -# Copyright 2021-2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -set -e - -# check and set options -parse_device() -{ - if [[ "X$RUN_TESTCASES" == "Xon" && "X$DEVICE" != "X" ]]; then - echo "WARNING:Option -e can't be set while option -t on/ut is set, reset device to empty." - DEVICE="" - fi - - # Parse device - # Process build option - export IFS_ORIGIN=$IFS - export IFS=":" - for D in $DEVICE; - do - if [[ "X$D" == "Xgpu" ]]; then - export ENABLE_GPU="on" - export GPU_BACKEND="cuda" - ENABLE_CPU="on" - ENABLE_MPI="on" - # version default 10.1 - if [[ "X$DEVICE_VERSION" == "X" ]]; then - DEVICE_VERSION=10.1 - fi - if [[ "X$DEVICE_VERSION" != "X11.6" && "X$DEVICE_VERSION" != "X11.1" && "X$DEVICE_VERSION" != "X10.1" ]]; then - echo "Invalid value ${DEVICE_VERSION} for option -V" - usage - exit 1 - fi - export CUDA_VERSION="$DEVICE_VERSION" - export DEVICE_VERSION= - elif [[ "X$D" == "Xrocm" ]]; then - export ENABLE_GPU="on" - export GPU_BACKEND="rocm" - ENABLE_CPU="on" - ENABLE_MPI="on" - export ENABLE_AKG="off" - elif [[ "X$D" == "Xd" || "X$D" == "Xascend" ]]; then - # version default 910 - if [[ "X$DEVICE_VERSION" == "X" ]]; then - DEVICE_VERSION=910 - fi - # building 310 package by giving specific -V 310 instruction - if [[ "X$DEVICE_VERSION" == "X310" ]]; then - export ENABLE_D="on" - export ENABLE_AKG="on" - export ENABLE_ACL="on" - ENABLE_CPU="on" - export ENABLE_MPI="on" - export ENABLE_INTERNAL_KERNELS="on" - # universal ascend package, building 910b package by giving specific -V 910b instruction - elif [[ "X$DEVICE_VERSION" == "X910" || "X$DEVICE_VERSION" == "X910b" ]]; then - export ENABLE_D="on" - export ENABLE_ACL="on" - ENABLE_CPU="on" - export ENABLE_MPI="on" - export ENABLE_INTERNAL_KERNELS="on" - export ASCEND_GLOBAL_LOG_LEVEL=3 - export ASCEND_SLOG_PRINT_TO_STDOUT=1 - else - echo "Invalid value ${DEVICE_VERSION} for option -V" - usage - exit 1 - fi - export DEVICE_VERSION= - elif [[ "X$D" == "Xcpu" ]]; then - export ENABLE_CPU="on" - export ENABLE_MPI="on" - elif [[ "X$D" == "X" ]]; then - : - else - echo "Invalid value ${DEVICE} for option -e" - usage - exit 1 - fi - done - export IFS=$IFS_ORIGIN - if [[ "X$ENABLE_AKG" == "Xon" && "X$ENABLE_D" != "Xon" && "X$ENABLE_CPU" == "Xon" ]]; then - # check llvm version for akg - HAS_LLVM=`bash ${BASEPATH}/scripts/build/akg_find_llvm.sh` - export USE_LLVM=$HAS_LLVM - fi - export ENABLE_DVM="off" - source ${BASEPATH}/scripts/build/check_binary_file.sh - if [[ "X$ENABLE_INTERNAL_KERNELS" == "Xon" ]]; then - source ${BASEPATH}/scripts/build/check_and_build_ms_kernels_internal.sh - fi -} diff --git a/scripts/build/process_options.sh b/scripts/build/process_options.sh deleted file mode 100755 index 81f16f999e1..00000000000 --- a/scripts/build/process_options.sh +++ /dev/null @@ -1,121 +0,0 @@ -#!/bin/bash -# Copyright 2021-2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -set -e - -# check and set options -process_options() -{ - # Process the options - while getopts 'RdfhiorvyzA:B:D:E:F:G:H:I:K:L:M:P:S:V:W:a:b:c:e:g:j:k:l:n:p:s:t:' opt - do - CASE_SENSIVE_ARG=${OPTARG} - OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]') - case "${opt}" in - d) - DEBUG_MODE="on" ;; - n) - build_option_proc_n ;; - y) - export ENABLE_SYM_FILE="on" ;; - r) - export DEBUG_MODE="off" ;; - v) - build_option_proc_v ;; - j) - export THREAD_NUM=$OPTARG ;; - c) - build_option_proc_c ;; - t) - build_option_proc_t ;; - g) - build_option_proc_g ;; - h) - build_option_proc_h ;; - b) - build_option_proc_b ;; - a) - build_option_proc_a ;; - p) - build_option_proc_p ;; - l) - build_option_proc_l ;; - i) - export INC_BUILD="on" ;; - s) - build_option_proc_s ;; - R) - export ENABLE_TIMELINE="on" - echo "enable time_line record" ;; - S) - build_option_proc_upper_s ;; - k) - check_on_off $OPTARG k - export ENABLE_MAKE_CLEAN="$OPTARG" - echo "enable make clean" ;; - e) - export DEVICE=$DEVICE:$OPTARG ;; - M) - check_on_off $OPTARG M - export ENABLE_MPI="$OPTARG" ;; - V) - export DEVICE_VERSION=$OPTARG ;; - P) - check_on_off $OPTARG p - export ENABLE_DUMP2PROTO="$OPTARG" - echo "enable dump anf graph to proto file" ;; - D) - build_option_proc_upper_d ;; - z) - build_option_proc_z ;; - I) - build_option_proc_upper_i ;; - K) - check_on_off $OPTARG K - export ENABLE_AKG="$OPTARG" ;; - B) - build_option_proc_upper_b ;; - E) - check_on_off $OPTARG E - export ENABLE_RDMA="$OPTARG" - echo "RDMA for RPC $ENABLE_RDMA" ;; - A) - build_option_proc_upper_a ;; - W) - build_option_proc_upper_w ;; - F) - build_option_proc_upper_f ;; - H) - check_on_off $OPTARG H - export ENABLE_HIDDEN="$OPTARG" - echo "${OPTARG} hidden" ;; - L) - export ENABLE_TRT="on" - export TENSORRT_HOME="$CASE_SENSIVE_ARG" - echo "Link Tensor-RT library. Path: ${CASE_SENSIVE_ARG}" ;; - G) - build_option_proc_upper_g ;; - f) - export FASTER_BUILD_FOR_PLUGINS="on" ;; - o) - export ENABLE_AIO="on" ;; - *) - echo "Unknown option ${opt}!" - usage - exit 1 - esac - done -} diff --git a/scripts/build/usage.sh b/scripts/build/usage.sh deleted file mode 100755 index 29c0d6c9f96..00000000000 --- a/scripts/build/usage.sh +++ /dev/null @@ -1,70 +0,0 @@ -#!/bin/bash -# Copyright 2021-2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -set -e - -usage() -{ - echo "Usage:" - echo "bash build.sh [-d] [-r] [-v] [-c on|off] [-t ut|st] [-g on|off] [-h] [-b ge] [-m infer|train] \\" - echo " [-a on|off] [-p on|off] [-i] [-R] [-D on|off] [-j[n]] [-e gpu|ascend|d|cpu] \\" - echo " [-P on|off] [-z [on|off]] [-M on|off] [-V 10.1|11.1|310|910|910b] [-I arm64|arm32|x86_64] [-K on|off] \\" - echo " [-B on|off] [-E] [-l on|off] [-n full|lite|off] [-H on|off] \\" - echo " [-A on|off] [-S on|off] [-k on|off] [-W sse|neon|avx|avx512|off] \\" - echo " [-L Tensor-RT path] [-y on|off] [-F on|off] [-G common|auto|ptx] [-o]\\" - echo "" - echo "Options:" - echo " -d Debug mode" - echo " -r Release mode, default mode" - echo " -v Display build command" - echo " -c Enable code coverage, default off" - echo " -t Run testcases, default off" - echo " -g Use glog to output log, default on" - echo " -h Print usage" - echo " -b Select other backend, available: \\" - echo " ge:graph engine" - echo " -m Select graph engine backend mode, available: infer, train, default is infer" - echo " -a Enable ASAN, default off" - echo " -p Enable pipeline profile, print to stdout, default off" - echo " -R Enable pipeline profile, record to json, default off" - echo " -i Enable increment building, default off" - echo " -j[n] Set the threads when building (Default: -j8)" - echo " -e Use cpu, gpu or ascend[ascend|d]" - echo " -s Enable security, default off" - echo " -P Enable dump anf graph to file in ProtoBuffer format, default on" - echo " -D Enable dumping of function graph ir, default on" - echo " -z Compile dataset & mindrecord, default on" - echo " -n Compile minddata with mindspore lite, available: off, lite, full, lite_cv, full mode in lite train and lite_cv, wrapper mode in lite predict" - echo " -M Enable MPI and NCCL for GPU training, gpu default on" - echo " -V Specify the device version, if -e gpu, default CUDA 10.1, if -e ascend, default Ascend 910" - echo " -I Enable compiling mindspore lite for arm64, arm32 or x86_64, default disable mindspore lite compilation" - echo " -A Enable compiling mindspore lite aar package, option: on/off, default: off" - echo " -K Compile with AKG, default on" - echo " -B Enable debugger, default on" - echo " -E Enable IBVERBS for parameter server, default off" - echo " -l Compile with python dependency, default on" - echo " -S Enable enable download cmake compile dependency from gitee , default off" - echo " -k Enable make clean, clean up compilation generated cache " - echo " -W Enable SIMD instruction set, use [sse|neon|avx|avx512|off], default avx for cloud CPU backend" - echo " -H Enable hidden" - echo " -L Link and specify Tensor-RT library path, default disable Tensor-RT lib linking" - echo " -y Compile the symbol table switch and save the symbol table to the directory output" - echo " -F Use fast hash table in mindspore compiler, default on" - echo " -G Select an architecture to build, set 'common' to build with common architectures(eg. gpu: 5.3, 6.0, 6.2, 7.0, 7.2, 7.5),\\" - echo " set auto to detect automatically, set 'ptx' to only build ptx, default: 'auto'. Only effective for GPU currently." - echo " -f Faster build process for device plugins, only build plugin." - echo " -o Compile aio plugin, default off" -} diff --git a/scripts/check_clang_format.sh b/scripts/check_clang_format.sh old mode 100755 new mode 100644 diff --git a/scripts/check_tid.sh b/scripts/check_tid.sh old mode 100755 new mode 100644 diff --git a/scripts/dot2svg.sh b/scripts/dot2svg.sh old mode 100755 new mode 100644 diff --git a/scripts/format_source_code.sh b/scripts/format_source_code.sh old mode 100755 new mode 100644 diff --git a/scripts/get_bert_shape_from_pytest.sh b/scripts/get_bert_shape_from_pytest.sh old mode 100755 new mode 100644 diff --git a/scripts/get_op_use_count.sh b/scripts/get_op_use_count.sh old mode 100755 new mode 100644 diff --git a/scripts/get_shape_from_ir.sh b/scripts/get_shape_from_ir.sh old mode 100755 new mode 100644 diff --git a/scripts/pre_commit/githooks/pre-push b/scripts/pre_commit/githooks/pre-push old mode 100755 new mode 100644 diff --git a/scripts/pre_commit/install_generic_tools.sh b/scripts/pre_commit/install_generic_tools.sh old mode 100755 new mode 100644 diff --git a/scripts/pre_commit/install_system_specific_tools.sh b/scripts/pre_commit/install_system_specific_tools.sh old mode 100755 new mode 100644 diff --git a/scripts/run_perf_test.sh b/scripts/run_perf_test.sh old mode 100755 new mode 100644 diff --git a/scripts/setdotlabelwidth b/scripts/setdotlabelwidth old mode 100755 new mode 100644 diff --git a/tests/models b/tests/models deleted file mode 160000 index 29047f66c5a..00000000000 --- a/tests/models +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 29047f66c5ab63c4109b57992a95f4d58f0ab526 diff --git a/tests/st/cpp/cxx_api/runtest.sh b/tests/st/cpp/cxx_api/runtest.sh old mode 100755 new mode 100644 diff --git a/tests/st/dump/test_save_kernel_args.py b/tests/st/dump/test_save_kernel_args.py index 4fcce7c0d34..276fc05bb57 100644 --- a/tests/st/dump/test_save_kernel_args.py +++ b/tests/st/dump/test_save_kernel_args.py @@ -1,75 +1,75 @@ -# Copyright 2024 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import os -import glob -import json -import mindspore.context as context -import tempfile -import time - -import mindspore -from mindspore import JitConfig, Tensor, nn -from mindspore.ops import operations as P -from pathlib import Path -import numpy as np -from tests.mark_utils import arg_mark -from dump_test_utils import generate_dump_json - - -def check_kernel_args_dump(dump_file_path): - output_name = "MatMul.*.json" - output_path = glob.glob(os.path.join(dump_file_path, output_name))[0] - real_path = os.path.realpath(output_path) - with open(real_path, 'r') as f: - net_args = json.load(f) - assert net_args.get("transpose_a") == "True" - assert net_args.get("transpose_b") == "False" - - -@arg_mark(plat_marks=['platform_ascend'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_e2e_dump_save_kernel_args_true(): - """ - Feature: kbyk dump support kernel args. - Description: Test kbyk dump kernel args on device. - Expectation: dump real kernel args. - """ - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") - test_dir = tempfile.TemporaryDirectory(suffix="save_kernel_args") - - path = Path(test_dir.name) - dump_path = str(path / "dump_data") - dump_config_path = str(path / "config.json") - - generate_dump_json(dump_path, dump_config_path, "test_e2e_dump_save_kernel_args_true", "Net") - os.environ['MINDSPORE_DUMP_CONFIG'] = dump_config_path - try: - class Net(nn.Cell): - def __init__(self, transpose_a=False, transpose_b=False): - super(Net, self).__init__() - self.matmul = P.MatMul(transpose_a, transpose_b) - - def construct(self, x, y): - return self.matmul(x, y) - - jit_config = JitConfig(jit_level="O0") - net = Net(transpose_a=True) - net.set_jit_config(jit_config) - x = Tensor(np.ones(shape=[3, 3]), mindspore.float32) - y = Tensor(np.ones(shape=[3, 4]), mindspore.float32) - _ = net(x, y) - time.sleep(2) - check_kernel_args_dump(path / "dump_data" / "rank_0" / "Net" / "0" / "0") - finally: - del os.environ['MINDSPORE_DUMP_CONFIG'] +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import os +import glob +import json +import mindspore.context as context +import tempfile +import time + +import mindspore +from mindspore import JitConfig, Tensor, nn +from mindspore.ops import operations as P +from pathlib import Path +import numpy as np +from tests.mark_utils import arg_mark +from dump_test_utils import generate_dump_json + + +def check_kernel_args_dump(dump_file_path): + output_name = "MatMul.*.json" + output_path = glob.glob(os.path.join(dump_file_path, output_name))[0] + real_path = os.path.realpath(output_path) + with open(real_path, 'r') as f: + net_args = json.load(f) + assert net_args.get("transpose_a") == "True" + assert net_args.get("transpose_b") == "False" + + +@arg_mark(plat_marks=['platform_ascend'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_e2e_dump_save_kernel_args_true(): + """ + Feature: kbyk dump support kernel args. + Description: Test kbyk dump kernel args on device. + Expectation: dump real kernel args. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + test_dir = tempfile.TemporaryDirectory(suffix="save_kernel_args") + + path = Path(test_dir.name) + dump_path = str(path / "dump_data") + dump_config_path = str(path / "config.json") + + generate_dump_json(dump_path, dump_config_path, "test_e2e_dump_save_kernel_args_true", "Net") + os.environ['MINDSPORE_DUMP_CONFIG'] = dump_config_path + try: + class Net(nn.Cell): + def __init__(self, transpose_a=False, transpose_b=False): + super(Net, self).__init__() + self.matmul = P.MatMul(transpose_a, transpose_b) + + def construct(self, x, y): + return self.matmul(x, y) + + jit_config = JitConfig(jit_level="O0") + net = Net(transpose_a=True) + net.set_jit_config(jit_config) + x = Tensor(np.ones(shape=[3, 3]), mindspore.float32) + y = Tensor(np.ones(shape=[3, 4]), mindspore.float32) + _ = net(x, y) + time.sleep(2) + check_kernel_args_dump(path / "dump_data" / "rank_0" / "Net" / "0" / "0") + finally: + del os.environ['MINDSPORE_DUMP_CONFIG'] diff --git a/tests/st/graph_kernel/custom/aot_test_files/add.cc b/tests/st/graph_kernel/custom/aot_test_files/add.cc old mode 100755 new mode 100644 diff --git a/tests/st/graph_kernel/custom/aot_test_files/add.cu b/tests/st/graph_kernel/custom/aot_test_files/add.cu old mode 100755 new mode 100644 diff --git a/tests/st/graph_kernel/custom/aot_test_files/add_mul_div.cu b/tests/st/graph_kernel/custom/aot_test_files/add_mul_div.cu old mode 100755 new mode 100644 diff --git a/tests/st/graph_kernel/custom/aot_test_files/add_mul_div_bprop.cu b/tests/st/graph_kernel/custom/aot_test_files/add_mul_div_bprop.cu old mode 100755 new mode 100644 diff --git a/tests/st/graph_kernel/custom/aot_test_files/add_with_attr.cc b/tests/st/graph_kernel/custom/aot_test_files/add_with_attr.cc old mode 100755 new mode 100644 diff --git a/tests/st/graph_kernel/custom/aot_test_files/add_with_attr.cu b/tests/st/graph_kernel/custom/aot_test_files/add_with_attr.cu old mode 100755 new mode 100644 diff --git a/tests/st/graph_kernel/custom/aot_test_files/hetero_square_mul.cu b/tests/st/graph_kernel/custom/aot_test_files/hetero_square_mul.cu old mode 100755 new mode 100644 diff --git a/tests/st/graph_kernel/custom/aot_test_files/reorganize.cu b/tests/st/graph_kernel/custom/aot_test_files/reorganize.cu old mode 100755 new mode 100644 diff --git a/tests/st/graph_kernel/custom/aot_test_files/square.cu b/tests/st/graph_kernel/custom/aot_test_files/square.cu old mode 100755 new mode 100644 diff --git a/tests/st/graph_kernel/custom/aot_test_files/square_bprop.cu b/tests/st/graph_kernel/custom/aot_test_files/square_bprop.cu old mode 100755 new mode 100644 diff --git a/tests/st/graph_kernel/custom/test_custom_aot.py b/tests/st/graph_kernel/custom/test_custom_aot.py old mode 100755 new mode 100644 diff --git a/tests/st/graph_kernel/test_assign_add.py b/tests/st/graph_kernel/test_assign_add.py index ec54acef7cc..8d7fe131cd4 100644 --- a/tests/st/graph_kernel/test_assign_add.py +++ b/tests/st/graph_kernel/test_assign_add.py @@ -1,78 +1,78 @@ -# Copyright 2021-2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import numpy as np -from tests.mark_utils import arg_mark -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor, Parameter -from mindspore.ops import operations as P - - -class AssignAdd(nn.Cell): - def __init__(self, value): - super(AssignAdd, self).__init__() - self.var = Parameter(value, name="var") - self.add = P.AssignAdd() - - def construct(self, y): - self.add(self.var, y) - return self.var - - -def get_output(x2, y2, enable_graph_kernel=False): - context.set_context(enable_graph_kernel=enable_graph_kernel) - add = AssignAdd(x2) - result_gk_on_1 = add(y2) - add_2 = AssignAdd(result_gk_on_1) - result_gk_on_2 = add_2(y2) - output = [result_gk_on_1, result_gk_on_2] - return output - - -def assign_add(): - x2 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32)) - y2 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float16)) - - expect = get_output(x2, y2, False) - output = get_output(x2, y2, True) - e1, e2 = list(expect) - o1, o2 = list(output) - - assert np.allclose(o1.asnumpy(), e1.asnumpy()) - assert np.allclose(o2.asnumpy(), e2.asnumpy()) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level0', card_mark='onecard', essential_mark='unessential') -def test_assign_add_gpu(): - """ - Feature: test graph kernel AssignAdd - Description: run test case on GPU - Expectation: the result match with expect - """ - context.set_context(mode=context.GRAPH_MODE) - assign_add() - - -@arg_mark(plat_marks=['platform_ascend910b'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_assign_add_ascend(): - """ - Feature: test graph kernel AssignAdd - Description: run test case on Ascend - Expectation: the result match with expect - """ - context.set_context(jit_level='O0') - context.set_context(mode=context.GRAPH_MODE) - assign_add() +# Copyright 2021-2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +from tests.mark_utils import arg_mark +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor, Parameter +from mindspore.ops import operations as P + + +class AssignAdd(nn.Cell): + def __init__(self, value): + super(AssignAdd, self).__init__() + self.var = Parameter(value, name="var") + self.add = P.AssignAdd() + + def construct(self, y): + self.add(self.var, y) + return self.var + + +def get_output(x2, y2, enable_graph_kernel=False): + context.set_context(enable_graph_kernel=enable_graph_kernel) + add = AssignAdd(x2) + result_gk_on_1 = add(y2) + add_2 = AssignAdd(result_gk_on_1) + result_gk_on_2 = add_2(y2) + output = [result_gk_on_1, result_gk_on_2] + return output + + +def assign_add(): + x2 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32)) + y2 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float16)) + + expect = get_output(x2, y2, False) + output = get_output(x2, y2, True) + e1, e2 = list(expect) + o1, o2 = list(output) + + assert np.allclose(o1.asnumpy(), e1.asnumpy()) + assert np.allclose(o2.asnumpy(), e2.asnumpy()) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level0', card_mark='onecard', essential_mark='unessential') +def test_assign_add_gpu(): + """ + Feature: test graph kernel AssignAdd + Description: run test case on GPU + Expectation: the result match with expect + """ + context.set_context(mode=context.GRAPH_MODE) + assign_add() + + +@arg_mark(plat_marks=['platform_ascend910b'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_assign_add_ascend(): + """ + Feature: test graph kernel AssignAdd + Description: run test case on Ascend + Expectation: the result match with expect + """ + context.set_context(jit_level='O0') + context.set_context(mode=context.GRAPH_MODE) + assign_add() diff --git a/tests/st/graph_kernel/test_sigmoid.py b/tests/st/graph_kernel/test_sigmoid.py index e21e55236a5..4f92d325d94 100644 --- a/tests/st/graph_kernel/test_sigmoid.py +++ b/tests/st/graph_kernel/test_sigmoid.py @@ -1,92 +1,92 @@ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import numpy as np -from tests.mark_utils import arg_mark - -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.ops import operations as P -from mindspore.ops.operations import _grad_ops as G - - -class NetSigmoid(nn.Cell): - def __init__(self): - super(NetSigmoid, self).__init__() - self.sigmoid = P.Sigmoid() - - def construct(self, x): - return self.sigmoid(x) - - -class NetSigmoidGrad(nn.Cell): - def __init__(self): - super(NetSigmoidGrad, self).__init__() - self.sigmoid_grad = G.SigmoidGrad() - - def construct(self, y, dy): - return self.sigmoid_grad(y, dy) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_sigmoid(): - """ - Feature: todo - Description: todo - Expectation: todo - """ - x = Tensor(np.array([[[[-1, 1, 10], - [1, -1, 1], - [10, 1, -1]]]]).astype(np.float32)) - - error = np.ones(shape=[1, 1, 3, 3]) * 1.0e-6 - - context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True) - net = NetSigmoid() - result_open_gk = net(x) - - context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=False) - net_beta = NetSigmoid() - result_close_gk = net_beta(x) - diff = result_open_gk.asnumpy() - result_close_gk.asnumpy() - assert np.all(abs(diff) < error) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_sigmoid_grad(): - """ - Feature: todo - Description: todo - Expectation: todo - """ - y = Tensor(np.array([[[[-1, 1, 2], - [1, -1, 1], - [2, 1, -1]]]]).astype(np.float32)) - dy = Tensor(np.array([[[[-11, 2, 4], - [-1, 1, -1], - [-4, 4, -4]]]]).astype(np.float32)) - - error = np.ones(shape=[1, 1, 3, 3]) * 1.0e-6 - - context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True) - net = NetSigmoidGrad() - result_open_gk = net(y, dy) - - context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=False) - net_beta = NetSigmoidGrad() - result_close_gk = net_beta(y, dy) - diff = result_open_gk.asnumpy() - result_close_gk.asnumpy() - assert np.all(abs(diff) < error) +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +from tests.mark_utils import arg_mark + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P +from mindspore.ops.operations import _grad_ops as G + + +class NetSigmoid(nn.Cell): + def __init__(self): + super(NetSigmoid, self).__init__() + self.sigmoid = P.Sigmoid() + + def construct(self, x): + return self.sigmoid(x) + + +class NetSigmoidGrad(nn.Cell): + def __init__(self): + super(NetSigmoidGrad, self).__init__() + self.sigmoid_grad = G.SigmoidGrad() + + def construct(self, y, dy): + return self.sigmoid_grad(y, dy) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_sigmoid(): + """ + Feature: todo + Description: todo + Expectation: todo + """ + x = Tensor(np.array([[[[-1, 1, 10], + [1, -1, 1], + [10, 1, -1]]]]).astype(np.float32)) + + error = np.ones(shape=[1, 1, 3, 3]) * 1.0e-6 + + context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True) + net = NetSigmoid() + result_open_gk = net(x) + + context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=False) + net_beta = NetSigmoid() + result_close_gk = net_beta(x) + diff = result_open_gk.asnumpy() - result_close_gk.asnumpy() + assert np.all(abs(diff) < error) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_sigmoid_grad(): + """ + Feature: todo + Description: todo + Expectation: todo + """ + y = Tensor(np.array([[[[-1, 1, 2], + [1, -1, 1], + [2, 1, -1]]]]).astype(np.float32)) + dy = Tensor(np.array([[[[-11, 2, 4], + [-1, 1, -1], + [-4, 4, -4]]]]).astype(np.float32)) + + error = np.ones(shape=[1, 1, 3, 3]) * 1.0e-6 + + context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True) + net = NetSigmoidGrad() + result_open_gk = net(y, dy) + + context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=False) + net_beta = NetSigmoidGrad() + result_close_gk = net_beta(y, dy) + diff = result_open_gk.asnumpy() - result_close_gk.asnumpy() + assert np.all(abs(diff) < error) diff --git a/tests/st/graph_kernel/test_sigmoid_cross_entropy_with_logits.py b/tests/st/graph_kernel/test_sigmoid_cross_entropy_with_logits.py index 080ea109c50..f2157c1e695 100644 --- a/tests/st/graph_kernel/test_sigmoid_cross_entropy_with_logits.py +++ b/tests/st/graph_kernel/test_sigmoid_cross_entropy_with_logits.py @@ -1,100 +1,100 @@ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import numpy as np -from tests.mark_utils import arg_mark - -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.ops import operations as P -from mindspore.ops.operations import _grad_ops as G - - -class NetSigmoidCrossEntropyWithLogits(nn.Cell): - def __init__(self): - super(NetSigmoidCrossEntropyWithLogits, self).__init__() - self.loss = P.SigmoidCrossEntropyWithLogits() - - def construct(self, logits, labels): - return self.loss(logits, labels) - - -class NetSigmoidCrossEntropyWithLogitsGrad(nn.Cell): - def __init__(self): - super(NetSigmoidCrossEntropyWithLogitsGrad, self).__init__() - self.sigmoid_cross_entropy_with_logits_grad = G.SigmoidCrossEntropyWithLogitsGrad() - - def construct(self, logits, labels, dout): - return self.sigmoid_cross_entropy_with_logits_grad(logits, labels, dout) - - -@arg_mark(plat_marks=['platform_ascend910b', 'platform_gpu'], - level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_sigmoid_cross_entropy_with_logits(): - """ - Feature: test graph kernel SigmoidCrossEntropyWithLogits expander - Description: SigmoidCrossEntropyWithLogits expander - Expectation: the result match with the expected result - """ - context.set_context(jit_level='O0') - logits = Tensor(np.array([[1, 1, 2], - [1, 2, 1], - [2, 1, 1]]).astype(np.float32)) - labels = Tensor(np.array([[0, 0, 1], - [0, 1, 0], - [1, 0, 0]]).astype(np.float32)) - - error = np.ones(shape=[3, 3]) * 1.0e-6 - - context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True) - sigmoid_cross_entropy_with_logits = NetSigmoidCrossEntropyWithLogits() - result_open_gk = sigmoid_cross_entropy_with_logits(logits, labels) - - context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=False) - sigmoid_cross_entropy_with_logits_beta = NetSigmoidCrossEntropyWithLogits() - result_close_gk = sigmoid_cross_entropy_with_logits_beta(logits, labels) - diff = result_open_gk.asnumpy() - result_close_gk.asnumpy() - assert np.all(abs(diff) < error) - - -@arg_mark(plat_marks=['platform_ascend910b', 'platform_gpu'], - level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_sigmoid_cross_entropy_with_logits_grad(): - """ - Feature: test graph kernel SigmoidCrossEntropyWithLogitsGrad expander - Description: SigmoidCrossEntropyWithLogitsGrad expander - Expectation: the result match with the expected result - """ - context.set_context(jit_level='O0') - logits = Tensor(np.array([[1, 1, 2], - [1, 2, 1], - [2, 1, 1]]).astype(np.float32)) - labels = Tensor(np.array([[0, 0, 1], - [0, 1, 0], - [1, 0, 0]]).astype(np.float32)) - dout = Tensor(np.ones(shape=[3, 3]).astype(np.float32)) - - error = np.ones(shape=[3, 3]) * 1.0e-6 - - context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True) - sigmoid_cross_entropy_with_logits_grad = NetSigmoidCrossEntropyWithLogitsGrad() - result_open_gk = sigmoid_cross_entropy_with_logits_grad(logits, labels, dout) - - context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=False) - sigmoid_cross_entropy_with_logits_grad_beta = NetSigmoidCrossEntropyWithLogitsGrad() - result_close_gk = sigmoid_cross_entropy_with_logits_grad_beta(logits, labels, dout) - diff = result_open_gk.asnumpy() - result_close_gk.asnumpy() - assert np.all(abs(diff) < error) +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +from tests.mark_utils import arg_mark + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P +from mindspore.ops.operations import _grad_ops as G + + +class NetSigmoidCrossEntropyWithLogits(nn.Cell): + def __init__(self): + super(NetSigmoidCrossEntropyWithLogits, self).__init__() + self.loss = P.SigmoidCrossEntropyWithLogits() + + def construct(self, logits, labels): + return self.loss(logits, labels) + + +class NetSigmoidCrossEntropyWithLogitsGrad(nn.Cell): + def __init__(self): + super(NetSigmoidCrossEntropyWithLogitsGrad, self).__init__() + self.sigmoid_cross_entropy_with_logits_grad = G.SigmoidCrossEntropyWithLogitsGrad() + + def construct(self, logits, labels, dout): + return self.sigmoid_cross_entropy_with_logits_grad(logits, labels, dout) + + +@arg_mark(plat_marks=['platform_ascend910b', 'platform_gpu'], + level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_sigmoid_cross_entropy_with_logits(): + """ + Feature: test graph kernel SigmoidCrossEntropyWithLogits expander + Description: SigmoidCrossEntropyWithLogits expander + Expectation: the result match with the expected result + """ + context.set_context(jit_level='O0') + logits = Tensor(np.array([[1, 1, 2], + [1, 2, 1], + [2, 1, 1]]).astype(np.float32)) + labels = Tensor(np.array([[0, 0, 1], + [0, 1, 0], + [1, 0, 0]]).astype(np.float32)) + + error = np.ones(shape=[3, 3]) * 1.0e-6 + + context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True) + sigmoid_cross_entropy_with_logits = NetSigmoidCrossEntropyWithLogits() + result_open_gk = sigmoid_cross_entropy_with_logits(logits, labels) + + context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=False) + sigmoid_cross_entropy_with_logits_beta = NetSigmoidCrossEntropyWithLogits() + result_close_gk = sigmoid_cross_entropy_with_logits_beta(logits, labels) + diff = result_open_gk.asnumpy() - result_close_gk.asnumpy() + assert np.all(abs(diff) < error) + + +@arg_mark(plat_marks=['platform_ascend910b', 'platform_gpu'], + level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_sigmoid_cross_entropy_with_logits_grad(): + """ + Feature: test graph kernel SigmoidCrossEntropyWithLogitsGrad expander + Description: SigmoidCrossEntropyWithLogitsGrad expander + Expectation: the result match with the expected result + """ + context.set_context(jit_level='O0') + logits = Tensor(np.array([[1, 1, 2], + [1, 2, 1], + [2, 1, 1]]).astype(np.float32)) + labels = Tensor(np.array([[0, 0, 1], + [0, 1, 0], + [1, 0, 0]]).astype(np.float32)) + dout = Tensor(np.ones(shape=[3, 3]).astype(np.float32)) + + error = np.ones(shape=[3, 3]) * 1.0e-6 + + context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True) + sigmoid_cross_entropy_with_logits_grad = NetSigmoidCrossEntropyWithLogitsGrad() + result_open_gk = sigmoid_cross_entropy_with_logits_grad(logits, labels, dout) + + context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=False) + sigmoid_cross_entropy_with_logits_grad_beta = NetSigmoidCrossEntropyWithLogitsGrad() + result_close_gk = sigmoid_cross_entropy_with_logits_grad_beta(logits, labels, dout) + diff = result_open_gk.asnumpy() - result_close_gk.asnumpy() + assert np.all(abs(diff) < error) diff --git a/tests/st/graph_kernel/test_softmax_cross_entropy_with_logits.py b/tests/st/graph_kernel/test_softmax_cross_entropy_with_logits.py index 61f027c7799..f4bfda639cd 100644 --- a/tests/st/graph_kernel/test_softmax_cross_entropy_with_logits.py +++ b/tests/st/graph_kernel/test_softmax_cross_entropy_with_logits.py @@ -1,60 +1,60 @@ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import numpy as np -from tests.mark_utils import arg_mark - -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.ops import operations as P - - -class NetSoftmaxCrossEntropyWithLogits(nn.Cell): - def __init__(self): - super(NetSoftmaxCrossEntropyWithLogits, self).__init__() - self.loss = P.SoftmaxCrossEntropyWithLogits() - - def construct(self, logits, labels): - return self.loss(logits, labels) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_softmax_cross_entropy_with_logits(): - """ - Feature: todo - Description: todo - Expectation: todo - """ - logits = Tensor(np.array([[1, 1, 10], - [1, 10, 1], - [10, 1, 1]]).astype(np.float32)) - labels = Tensor(np.array([[0, 0, 1], - [0, 1, 0], - [1, 0, 0]]).astype(np.float32)) - - context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True) - softmax_cross_entropy_with_logits = NetSoftmaxCrossEntropyWithLogits() - result_open_gk = softmax_cross_entropy_with_logits(logits, labels) - - context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=False) - softmax_cross_entropy_with_logits_beta = NetSoftmaxCrossEntropyWithLogits() - result_close_gk = softmax_cross_entropy_with_logits_beta(logits, labels) - - error0 = 1.0e-6 - diff0 = result_open_gk[0].asnumpy() - result_close_gk[0].asnumpy() - diff1 = result_open_gk[1].asnumpy() - result_close_gk[1].asnumpy() - assert np.all(abs(diff0) < error0) - assert np.all(abs(diff1) < error0) +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +from tests.mark_utils import arg_mark + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + + +class NetSoftmaxCrossEntropyWithLogits(nn.Cell): + def __init__(self): + super(NetSoftmaxCrossEntropyWithLogits, self).__init__() + self.loss = P.SoftmaxCrossEntropyWithLogits() + + def construct(self, logits, labels): + return self.loss(logits, labels) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_softmax_cross_entropy_with_logits(): + """ + Feature: todo + Description: todo + Expectation: todo + """ + logits = Tensor(np.array([[1, 1, 10], + [1, 10, 1], + [10, 1, 1]]).astype(np.float32)) + labels = Tensor(np.array([[0, 0, 1], + [0, 1, 0], + [1, 0, 0]]).astype(np.float32)) + + context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True) + softmax_cross_entropy_with_logits = NetSoftmaxCrossEntropyWithLogits() + result_open_gk = softmax_cross_entropy_with_logits(logits, labels) + + context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=False) + softmax_cross_entropy_with_logits_beta = NetSoftmaxCrossEntropyWithLogits() + result_close_gk = softmax_cross_entropy_with_logits_beta(logits, labels) + + error0 = 1.0e-6 + diff0 = result_open_gk[0].asnumpy() - result_close_gk[0].asnumpy() + diff1 = result_open_gk[1].asnumpy() - result_close_gk[1].asnumpy() + assert np.all(abs(diff0) < error0) + assert np.all(abs(diff1) < error0) diff --git a/tests/st/heterogeneous/test_fused_cast_adam_weight_decay_cpu.py b/tests/st/heterogeneous/test_fused_cast_adam_weight_decay_cpu.py index ead5fb9e82e..6a45cc6ec93 100644 --- a/tests/st/heterogeneous/test_fused_cast_adam_weight_decay_cpu.py +++ b/tests/st/heterogeneous/test_fused_cast_adam_weight_decay_cpu.py @@ -1,206 +1,206 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import numpy as np - -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.ops import operations as P -from mindspore.ops import composite as C -from mindspore.ops import functional as F -from mindspore.common import dtype as mstype -from mindspore.common.initializer import initializer -from mindspore.common.parameter import Parameter, ParameterTuple -from mindspore.nn import TrainOneStepCell, WithLossCell -from mindspore.nn.optim.optimizer import Optimizer -from mindspore.ops.function.clip_func import get_square_sum - -from tests.mark_utils import arg_mark - - -class LeNet(nn.Cell): - """ - Implements lenet. - """ - - def __init__(self): - super(LeNet, self).__init__() - self.relu = P.ReLU() - self.batch_size = 1 - weight1 = Tensor(np.ones([6, 3, 5, 5]).astype(np.float32) * 0.01) - weight2 = Tensor(np.ones([16, 6, 5, 5]).astype(np.float16) * 0.01) - self.conv1 = nn.Conv2d(3, 6, (5, 5), weight_init=weight1, stride=1, padding=0, pad_mode='valid') - self.conv2 = nn.Conv2d(6, 16, (5, 5), weight_init=weight2, pad_mode='valid', stride=1, padding=0) - self.pool = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode="valid") - - self.reshape = P.Reshape() - self.reshape1 = P.Reshape() - - self.fc1 = nn.Dense(400, 120) - self.fc2 = nn.Dense(120, 84) - self.fc3 = nn.Dense(84, 10) - - def construct(self, input_x): - output = self.conv1(input_x) - output = self.relu(output) - output = self.pool(output) - output = P.Cast()(output, mstype.float16) - output = self.conv2(output) - output = P.Cast()(output, mstype.float32) - output = self.relu(output) - output = self.pool(output) - output = self.reshape(output, (self.batch_size, -1)) - output = self.fc1(output) - output = self.fc2(output) - output = self.fc3(output) - return output - - -_adam_opt = C.MultitypeFuncGraph("adam_opt") - - -@_adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", - "Tensor", "Tensor", "Bool", "Bool") -def _fused_update_with_global_norm(opt, global_norm, beta1, beta2, eps, lr, weight_decay, - param, m, v, gradient, decay_flags, optim_filter): - """ - Update parameters by FusedAdamWeightDecay. - """ - success = True - if optim_filter: - if decay_flags: - next_param = opt(param, m, v, lr, beta1, beta2, eps, weight_decay, gradient, global_norm) - else: - next_param = opt(param, m, v, lr, beta1, beta2, eps, 0.0, gradient, global_norm) - return F.depend(success, next_param) - return success - - -def clone_state(parameter_tuple, prefix, init): - new = [] - for old_param in parameter_tuple: - new_state = Parameter(initializer(init, shape=old_param.shape, dtype=mstype.float32)) - new_state.param_info = old_param.param_info.clone() - new_state.is_init = False - new_state.name = prefix + '.' + new_state.name - new.append(new_state) - return ParameterTuple(new) - - -apply_global_norm = C.MultitypeFuncGraph("apply_global_norm") - - -@apply_global_norm.register("Tensor", "Tensor", "Tensor") -def _apply_global_norm(clip_norm, global_norm, grad): - return grad * clip_norm / global_norm - - -class GlobalNorm(nn.Cell): - """ - Calculate the global norm value of given tensors - """ - - def __init__(self): - super(GlobalNorm, self).__init__() - self.norm = nn.Norm() - self.hyper_map = C.HyperMap() - self.sqrt = P.Sqrt() - - def construct(self, grads): - """Calculate global norm construct""" - square_sum = self.hyper_map(get_square_sum, grads) - global_norms = self.sqrt(F.addn(square_sum)) - return global_norms - - -class FusedAdamWeightDecayWithGlobalNorm(Optimizer): - """ - Implements the gradient clipping by global norm for a AdamWeightDecay optimizer. - """ - - def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0): - super(FusedAdamWeightDecayWithGlobalNorm, self).__init__(learning_rate, params, weight_decay) - self.beta1 = Tensor(np.array([beta1]).astype(np.float32)) - self.beta2 = Tensor(np.array([beta2]).astype(np.float32)) - self.eps = Tensor(np.array([eps]).astype(np.float32)) - self.moments1 = clone_state(self._parameters, prefix="adam_m", init='zeros') - self.moments2 = clone_state(self._parameters, prefix="adam_v", init='zeros') - self.norm = GlobalNorm() - self.opt = P.FusedCastAdamWeightDecay() - self.opt.set_device("CPU") - - def construct(self, gradients): - """construct with gradients""" - global_norm = self.norm(gradients) - lr = self.get_lr() - optim_result = self.map_reverse(F.partial(_adam_opt, self.opt, global_norm, - self.beta1, self.beta2, self.eps, lr, self.weight_decay), - self._parameters, self.moments1, self.moments2, - gradients, self.decay_flags, self.optim_filter) - return optim_result - - -@arg_mark(plat_marks=["platform_ascend", "platform_gpu"], level_mark="level2", card_mark="onecard", - essential_mark="essential") -def test_fused_cast_adam_weight_decay(): - ''' - Feature: FusedCastAdamWeightDecay - Description: Test FusedCastAdamWeightDecay - Expectation: Run lenet success - ''' - context.set_context(mode=context.GRAPH_MODE) - data = Tensor(np.ones([32, 3, 32, 32]).astype(np.float32) * 0.01) - label = Tensor(np.ones([32]).astype(np.int32)) - net = LeNet() - net.batch_size = 32 - learning_rate = 0.01 - optimizer = FusedAdamWeightDecayWithGlobalNorm(filter(lambda x: x.requires_grad, net.get_parameters()), - learning_rate) - criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') - net_with_criterion = WithLossCell(net, criterion) - train_network = TrainOneStepCell(net_with_criterion, optimizer) - train_network.set_train() - loss = [] - for _ in range(10): - res = train_network(data, label) - loss.append(res.asnumpy()) - assert np.all(loss[-1] < 0.1) - - -@arg_mark(plat_marks=["platform_ascend", "platform_gpu"], level_mark="level2", card_mark="onecard", - essential_mark="essential") -def test_fused_cast_adam_weight_decay_with_memory_optimize(): - ''' - Feature: Integration of dynamic and static memory in the heterogeneous scene - Description: Test FusedCastAdamWeightDecay - Expectation: Run lenet success - ''' - context.set_context(mode=context.GRAPH_MODE, memory_optimize_level="O1") - data = Tensor(np.ones([32, 3, 32, 32]).astype(np.float32) * 0.01) - label = Tensor(np.ones([32]).astype(np.int32)) - net = LeNet() - net.batch_size = 32 - learning_rate = 0.01 - optimizer = FusedAdamWeightDecayWithGlobalNorm(filter(lambda x: x.requires_grad, net.get_parameters()), - learning_rate) - criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') - net_with_criterion = WithLossCell(net, criterion) - train_network = TrainOneStepCell(net_with_criterion, optimizer) - train_network.set_train() - loss = [] - for _ in range(10): - res = train_network(data, label) - loss.append(res.asnumpy()) - assert np.all(loss[-1] < 0.1) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P +from mindspore.ops import composite as C +from mindspore.ops import functional as F +from mindspore.common import dtype as mstype +from mindspore.common.initializer import initializer +from mindspore.common.parameter import Parameter, ParameterTuple +from mindspore.nn import TrainOneStepCell, WithLossCell +from mindspore.nn.optim.optimizer import Optimizer +from mindspore.ops.function.clip_func import get_square_sum + +from tests.mark_utils import arg_mark + + +class LeNet(nn.Cell): + """ + Implements lenet. + """ + + def __init__(self): + super(LeNet, self).__init__() + self.relu = P.ReLU() + self.batch_size = 1 + weight1 = Tensor(np.ones([6, 3, 5, 5]).astype(np.float32) * 0.01) + weight2 = Tensor(np.ones([16, 6, 5, 5]).astype(np.float16) * 0.01) + self.conv1 = nn.Conv2d(3, 6, (5, 5), weight_init=weight1, stride=1, padding=0, pad_mode='valid') + self.conv2 = nn.Conv2d(6, 16, (5, 5), weight_init=weight2, pad_mode='valid', stride=1, padding=0) + self.pool = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode="valid") + + self.reshape = P.Reshape() + self.reshape1 = P.Reshape() + + self.fc1 = nn.Dense(400, 120) + self.fc2 = nn.Dense(120, 84) + self.fc3 = nn.Dense(84, 10) + + def construct(self, input_x): + output = self.conv1(input_x) + output = self.relu(output) + output = self.pool(output) + output = P.Cast()(output, mstype.float16) + output = self.conv2(output) + output = P.Cast()(output, mstype.float32) + output = self.relu(output) + output = self.pool(output) + output = self.reshape(output, (self.batch_size, -1)) + output = self.fc1(output) + output = self.fc2(output) + output = self.fc3(output) + return output + + +_adam_opt = C.MultitypeFuncGraph("adam_opt") + + +@_adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", + "Tensor", "Tensor", "Bool", "Bool") +def _fused_update_with_global_norm(opt, global_norm, beta1, beta2, eps, lr, weight_decay, + param, m, v, gradient, decay_flags, optim_filter): + """ + Update parameters by FusedAdamWeightDecay. + """ + success = True + if optim_filter: + if decay_flags: + next_param = opt(param, m, v, lr, beta1, beta2, eps, weight_decay, gradient, global_norm) + else: + next_param = opt(param, m, v, lr, beta1, beta2, eps, 0.0, gradient, global_norm) + return F.depend(success, next_param) + return success + + +def clone_state(parameter_tuple, prefix, init): + new = [] + for old_param in parameter_tuple: + new_state = Parameter(initializer(init, shape=old_param.shape, dtype=mstype.float32)) + new_state.param_info = old_param.param_info.clone() + new_state.is_init = False + new_state.name = prefix + '.' + new_state.name + new.append(new_state) + return ParameterTuple(new) + + +apply_global_norm = C.MultitypeFuncGraph("apply_global_norm") + + +@apply_global_norm.register("Tensor", "Tensor", "Tensor") +def _apply_global_norm(clip_norm, global_norm, grad): + return grad * clip_norm / global_norm + + +class GlobalNorm(nn.Cell): + """ + Calculate the global norm value of given tensors + """ + + def __init__(self): + super(GlobalNorm, self).__init__() + self.norm = nn.Norm() + self.hyper_map = C.HyperMap() + self.sqrt = P.Sqrt() + + def construct(self, grads): + """Calculate global norm construct""" + square_sum = self.hyper_map(get_square_sum, grads) + global_norms = self.sqrt(F.addn(square_sum)) + return global_norms + + +class FusedAdamWeightDecayWithGlobalNorm(Optimizer): + """ + Implements the gradient clipping by global norm for a AdamWeightDecay optimizer. + """ + + def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0): + super(FusedAdamWeightDecayWithGlobalNorm, self).__init__(learning_rate, params, weight_decay) + self.beta1 = Tensor(np.array([beta1]).astype(np.float32)) + self.beta2 = Tensor(np.array([beta2]).astype(np.float32)) + self.eps = Tensor(np.array([eps]).astype(np.float32)) + self.moments1 = clone_state(self._parameters, prefix="adam_m", init='zeros') + self.moments2 = clone_state(self._parameters, prefix="adam_v", init='zeros') + self.norm = GlobalNorm() + self.opt = P.FusedCastAdamWeightDecay() + self.opt.set_device("CPU") + + def construct(self, gradients): + """construct with gradients""" + global_norm = self.norm(gradients) + lr = self.get_lr() + optim_result = self.map_reverse(F.partial(_adam_opt, self.opt, global_norm, + self.beta1, self.beta2, self.eps, lr, self.weight_decay), + self._parameters, self.moments1, self.moments2, + gradients, self.decay_flags, self.optim_filter) + return optim_result + + +@arg_mark(plat_marks=["platform_ascend", "platform_gpu"], level_mark="level2", card_mark="onecard", + essential_mark="essential") +def test_fused_cast_adam_weight_decay(): + ''' + Feature: FusedCastAdamWeightDecay + Description: Test FusedCastAdamWeightDecay + Expectation: Run lenet success + ''' + context.set_context(mode=context.GRAPH_MODE) + data = Tensor(np.ones([32, 3, 32, 32]).astype(np.float32) * 0.01) + label = Tensor(np.ones([32]).astype(np.int32)) + net = LeNet() + net.batch_size = 32 + learning_rate = 0.01 + optimizer = FusedAdamWeightDecayWithGlobalNorm(filter(lambda x: x.requires_grad, net.get_parameters()), + learning_rate) + criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') + net_with_criterion = WithLossCell(net, criterion) + train_network = TrainOneStepCell(net_with_criterion, optimizer) + train_network.set_train() + loss = [] + for _ in range(10): + res = train_network(data, label) + loss.append(res.asnumpy()) + assert np.all(loss[-1] < 0.1) + + +@arg_mark(plat_marks=["platform_ascend", "platform_gpu"], level_mark="level2", card_mark="onecard", + essential_mark="essential") +def test_fused_cast_adam_weight_decay_with_memory_optimize(): + ''' + Feature: Integration of dynamic and static memory in the heterogeneous scene + Description: Test FusedCastAdamWeightDecay + Expectation: Run lenet success + ''' + context.set_context(mode=context.GRAPH_MODE, memory_optimize_level="O1") + data = Tensor(np.ones([32, 3, 32, 32]).astype(np.float32) * 0.01) + label = Tensor(np.ones([32]).astype(np.int32)) + net = LeNet() + net.batch_size = 32 + learning_rate = 0.01 + optimizer = FusedAdamWeightDecayWithGlobalNorm(filter(lambda x: x.requires_grad, net.get_parameters()), + learning_rate) + criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') + net_with_criterion = WithLossCell(net, criterion) + train_network = TrainOneStepCell(net_with_criterion, optimizer) + train_network.set_train() + loss = [] + for _ in range(10): + res = train_network(data, label) + loss.append(res.asnumpy()) + assert np.all(loss[-1] < 0.1) diff --git a/tests/st/lccl/test_lccl_allgather.py b/tests/st/lccl/test_lccl_allgather.py index 5a50bffad66..956593abead 100644 --- a/tests/st/lccl/test_lccl_allgather.py +++ b/tests/st/lccl/test_lccl_allgather.py @@ -1,63 +1,63 @@ -# Copyright 2024 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""test lccl allgather with 8p""" - -import numpy as np - -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.common.initializer import initializer -from mindspore.common.parameter import Parameter -from mindspore.communication.management import init, HCCL_WORLD_COMM_GROUP, get_rank, get_group_size -from mindspore.ops import operations as P - -context.set_context(mode=context.GRAPH_MODE, device_target='Ascend') -context.set_context(jit_level='O0') - -init() -rank = get_rank() -size = get_group_size() -x = np.ones([1, 1, 3, 3]).astype(np.float32) * 0.01 * (rank + 1) - - -class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - self.all_gather = P.AllGather(group=HCCL_WORLD_COMM_GROUP) - self.x = Parameter(initializer(Tensor(x), x.shape), name='x') - - def construct(self): - return self.all_gather(self.x) - - -def test_AllGather(): - """ - Feature: lccl operator test. - Description: msrun lccl all_gather 8P case. - Expectation: success - """ - all_gather = Net() - output = all_gather() - - expect = np.ones([1, 1, 3, 3]).astype(np.float32) * 0.01 * (0 + 1) - for i in range(size - 1): - tmp = np.ones([1, 1, 3, 3]).astype(np.float32) * 0.01 * (i + 2) - expect = np.concatenate((expect, tmp)) - diff = np.absolute(output.asnumpy() - expect) - error = np.ones(shape=expect.shape) * 1.0e-5 - assert np.all(diff < error) - assert output.shape == expect.shape +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""test lccl allgather with 8p""" + +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common.initializer import initializer +from mindspore.common.parameter import Parameter +from mindspore.communication.management import init, HCCL_WORLD_COMM_GROUP, get_rank, get_group_size +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target='Ascend') +context.set_context(jit_level='O0') + +init() +rank = get_rank() +size = get_group_size() +x = np.ones([1, 1, 3, 3]).astype(np.float32) * 0.01 * (rank + 1) + + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.all_gather = P.AllGather(group=HCCL_WORLD_COMM_GROUP) + self.x = Parameter(initializer(Tensor(x), x.shape), name='x') + + def construct(self): + return self.all_gather(self.x) + + +def test_AllGather(): + """ + Feature: lccl operator test. + Description: msrun lccl all_gather 8P case. + Expectation: success + """ + all_gather = Net() + output = all_gather() + + expect = np.ones([1, 1, 3, 3]).astype(np.float32) * 0.01 * (0 + 1) + for i in range(size - 1): + tmp = np.ones([1, 1, 3, 3]).astype(np.float32) * 0.01 * (i + 2) + expect = np.concatenate((expect, tmp)) + diff = np.absolute(output.asnumpy() - expect) + error = np.ones(shape=expect.shape) * 1.0e-5 + assert np.all(diff < error) + assert output.shape == expect.shape diff --git a/tests/st/lccl/test_lccl_allreduce.py b/tests/st/lccl/test_lccl_allreduce.py index 1ab4a34e983..1d48e88f29b 100644 --- a/tests/st/lccl/test_lccl_allreduce.py +++ b/tests/st/lccl/test_lccl_allreduce.py @@ -1,85 +1,85 @@ -# Copyright 2024 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""test lccl allreduce with 8p""" - -import numpy as np - -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.common.initializer import initializer -from mindspore.common.parameter import Parameter -from mindspore.communication.management import init, HCCL_WORLD_COMM_GROUP, get_rank, get_group_size -from mindspore.ops import operations as P - -context.set_context(mode=context.GRAPH_MODE, device_target='Ascend') -context.set_context(jit_level='O0') - -init() -rank = get_rank() -size = get_group_size() -x = np.ones([3, 1, 3, 3]).astype(np.float32) * 0.01 * (rank + 1) - -class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - self.x1 = Parameter(initializer(Tensor(x), x.shape), name='x1') - self.x2 = Parameter(initializer(Tensor(x), x.shape), name='x2') - self.x3 = Parameter(initializer(Tensor(x), x.shape), name='x3') - - self.op0 = "sum" - self.op1 = "sum" - self.op2 = "sum" - - self.all_reduce1 = P.AllReduce(self.op0, group=HCCL_WORLD_COMM_GROUP) - self.all_reduce2 = P.AllReduce(self.op1, group=HCCL_WORLD_COMM_GROUP) - self.all_reduce3 = P.AllReduce(self.op2, group=HCCL_WORLD_COMM_GROUP) - - def construct(self): - return (self.all_reduce1(self.x1), - self.all_reduce2(self.x2), - self.all_reduce3(self.x3)) - - -def test_AllReduce(): - """ - Feature: lccl operator test. - Description: msrun lccl all_reduce 8P case. - Expectation: success - """ - all_reduce = Net() - output = all_reduce() - - expect0 = np.ones([3, 1, 3, 3]).astype(np.float32) * 0 - for i in range(size): - part = np.ones([3, 1, 3, 3]).astype(np.float32) * 0.01 * (i + 1) - expect0 += part - diff0 = output[0].asnumpy() - expect0 - error0 = np.ones(shape=expect0.shape) * 1.0e-5 - assert np.all(diff0 < error0) - assert output[0].shape == expect0.shape - - expect1 = expect0 - diff1 = output[1].asnumpy() - expect1 - error1 = np.ones(shape=expect1.shape) * 1.0e-5 - assert np.all(diff1 < error1) - assert output[1].shape == expect1.shape - - expect2 = expect1 - diff2 = output[2].asnumpy() - expect2 - error2 = np.ones(shape=expect2.shape) * 1.0e-5 - assert np.all(diff2 < error2) - assert output[2].shape == expect2.shape +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""test lccl allreduce with 8p""" + +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common.initializer import initializer +from mindspore.common.parameter import Parameter +from mindspore.communication.management import init, HCCL_WORLD_COMM_GROUP, get_rank, get_group_size +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target='Ascend') +context.set_context(jit_level='O0') + +init() +rank = get_rank() +size = get_group_size() +x = np.ones([3, 1, 3, 3]).astype(np.float32) * 0.01 * (rank + 1) + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.x1 = Parameter(initializer(Tensor(x), x.shape), name='x1') + self.x2 = Parameter(initializer(Tensor(x), x.shape), name='x2') + self.x3 = Parameter(initializer(Tensor(x), x.shape), name='x3') + + self.op0 = "sum" + self.op1 = "sum" + self.op2 = "sum" + + self.all_reduce1 = P.AllReduce(self.op0, group=HCCL_WORLD_COMM_GROUP) + self.all_reduce2 = P.AllReduce(self.op1, group=HCCL_WORLD_COMM_GROUP) + self.all_reduce3 = P.AllReduce(self.op2, group=HCCL_WORLD_COMM_GROUP) + + def construct(self): + return (self.all_reduce1(self.x1), + self.all_reduce2(self.x2), + self.all_reduce3(self.x3)) + + +def test_AllReduce(): + """ + Feature: lccl operator test. + Description: msrun lccl all_reduce 8P case. + Expectation: success + """ + all_reduce = Net() + output = all_reduce() + + expect0 = np.ones([3, 1, 3, 3]).astype(np.float32) * 0 + for i in range(size): + part = np.ones([3, 1, 3, 3]).astype(np.float32) * 0.01 * (i + 1) + expect0 += part + diff0 = output[0].asnumpy() - expect0 + error0 = np.ones(shape=expect0.shape) * 1.0e-5 + assert np.all(diff0 < error0) + assert output[0].shape == expect0.shape + + expect1 = expect0 + diff1 = output[1].asnumpy() - expect1 + error1 = np.ones(shape=expect1.shape) * 1.0e-5 + assert np.all(diff1 < error1) + assert output[1].shape == expect1.shape + + expect2 = expect1 + diff2 = output[2].asnumpy() - expect2 + error2 = np.ones(shape=expect2.shape) * 1.0e-5 + assert np.all(diff2 < error2) + assert output[2].shape == expect2.shape diff --git a/tests/st/lccl/test_lccl_broadcast.py b/tests/st/lccl/test_lccl_broadcast.py index 01d45bd1210..1441792a821 100644 --- a/tests/st/lccl/test_lccl_broadcast.py +++ b/tests/st/lccl/test_lccl_broadcast.py @@ -1,77 +1,77 @@ -# Copyright 2024 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import numpy as np - -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.common.initializer import initializer -from mindspore.common.parameter import Parameter -from mindspore.communication.management import init, get_rank, get_group_size -from mindspore.ops import operations as P - -context.set_context(mode=context.GRAPH_MODE, device_target='Ascend') -context.set_context(jit_level='O0') - -init() -rank = get_rank() -size = get_group_size() -x = np.ones([3, 1, 3, 3]).astype(np.float32) * 0.01 * (rank + 1) - - -class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - self.x1 = Parameter(initializer(Tensor(x), x.shape), name='x1') - self.x2 = Parameter(initializer(Tensor(x), x.shape), name='x2') - self.x3 = Parameter(initializer(Tensor(x), x.shape), name='x3') - - self.broadcast1 = P.Broadcast(0) - self.broadcast2 = P.Broadcast(1) - self.broadcast3 = P.Broadcast(2) - - def construct(self): - return (self.broadcast1((self.x1,)), - self.broadcast2((self.x2,)), - self.broadcast3((self.x3,))) - - -def test_Broadcast(): - """ - Feature: lccl operator test. - Description: msrun lccl broadcast 8P case. - Expectation: success - """ - broadcast = Net() - output = broadcast() - - expect0 = np.ones([3, 1, 3, 3]).astype(np.float32) * 1 - expect1 = np.ones([3, 1, 3, 3]).astype(np.float32) * 2 - expect2 = np.ones([3, 1, 3, 3]).astype(np.float32) * 3 - - diff0 = output[0][0].asnumpy() - expect0 - error0 = np.ones(shape=expect0.shape) * 1.0e-5 - assert np.all(diff0 < error0) - assert output[0][0].shape == expect0.shape - - diff1 = output[1][0].asnumpy() - expect1 - error1 = np.ones(shape=expect1.shape) * 1.0e-5 - assert np.all(diff1 < error1) - assert output[1][0].shape == expect1.shape - - diff2 = output[2][0].asnumpy() - expect2 - error2 = np.ones(shape=expect2.shape) * 1.0e-5 - assert np.all(diff2 < error2) - assert output[2][0].shape == expect2.shape +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common.initializer import initializer +from mindspore.common.parameter import Parameter +from mindspore.communication.management import init, get_rank, get_group_size +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target='Ascend') +context.set_context(jit_level='O0') + +init() +rank = get_rank() +size = get_group_size() +x = np.ones([3, 1, 3, 3]).astype(np.float32) * 0.01 * (rank + 1) + + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.x1 = Parameter(initializer(Tensor(x), x.shape), name='x1') + self.x2 = Parameter(initializer(Tensor(x), x.shape), name='x2') + self.x3 = Parameter(initializer(Tensor(x), x.shape), name='x3') + + self.broadcast1 = P.Broadcast(0) + self.broadcast2 = P.Broadcast(1) + self.broadcast3 = P.Broadcast(2) + + def construct(self): + return (self.broadcast1((self.x1,)), + self.broadcast2((self.x2,)), + self.broadcast3((self.x3,))) + + +def test_Broadcast(): + """ + Feature: lccl operator test. + Description: msrun lccl broadcast 8P case. + Expectation: success + """ + broadcast = Net() + output = broadcast() + + expect0 = np.ones([3, 1, 3, 3]).astype(np.float32) * 1 + expect1 = np.ones([3, 1, 3, 3]).astype(np.float32) * 2 + expect2 = np.ones([3, 1, 3, 3]).astype(np.float32) * 3 + + diff0 = output[0][0].asnumpy() - expect0 + error0 = np.ones(shape=expect0.shape) * 1.0e-5 + assert np.all(diff0 < error0) + assert output[0][0].shape == expect0.shape + + diff1 = output[1][0].asnumpy() - expect1 + error1 = np.ones(shape=expect1.shape) * 1.0e-5 + assert np.all(diff1 < error1) + assert output[1][0].shape == expect1.shape + + diff2 = output[2][0].asnumpy() - expect2 + error2 = np.ones(shape=expect2.shape) * 1.0e-5 + assert np.all(diff2 < error2) + assert output[2][0].shape == expect2.shape diff --git a/tests/st/mindscience/mindsponge/mindsponge/cell/mask.py b/tests/st/mindscience/mindsponge/mindsponge/cell/mask.py index b818c781623..32f4f0a7a5d 100644 --- a/tests/st/mindscience/mindsponge/mindsponge/cell/mask.py +++ b/tests/st/mindscience/mindsponge/mindsponge/cell/mask.py @@ -1,54 +1,54 @@ -# Copyright 2023 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Mask""" -from mindspore.ops import operations as P -from mindspore.ops import functional as F -import mindspore.nn as nn - - -class LayerNormProcess(nn.Cell): - def __init__(self,): - super(LayerNormProcess, self).__init__() - self.layernorm = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=1e-5) - - def construct(self, msa_act, query_norm_gamma, query_norm_beta): - output, _, _ = self.layernorm(msa_act, query_norm_gamma, query_norm_beta) - return output - - -class MaskedLayerNorm(nn.Cell): - '''masked_layer_norm''' - - def __init__(self): - super(MaskedLayerNorm, self).__init__() - self.norm = LayerNormProcess() - - def construct(self, act, gamma, beta, mask=None): - '''construct''' - act = act - gamma = gamma - beta = beta - - ones = P.Ones()(act.shape[:-1] + (1,), act.dtype) - if mask is not None: - mask = F.expand_dims(mask, -1) - mask = mask * ones - else: - mask = ones - - act = act * mask - act = self.norm(act, gamma, beta) - act = act * mask - return act +# Copyright 2023 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Mask""" +from mindspore.ops import operations as P +from mindspore.ops import functional as F +import mindspore.nn as nn + + +class LayerNormProcess(nn.Cell): + def __init__(self,): + super(LayerNormProcess, self).__init__() + self.layernorm = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=1e-5) + + def construct(self, msa_act, query_norm_gamma, query_norm_beta): + output, _, _ = self.layernorm(msa_act, query_norm_gamma, query_norm_beta) + return output + + +class MaskedLayerNorm(nn.Cell): + '''masked_layer_norm''' + + def __init__(self): + super(MaskedLayerNorm, self).__init__() + self.norm = LayerNormProcess() + + def construct(self, act, gamma, beta, mask=None): + '''construct''' + act = act + gamma = gamma + beta = beta + + ones = P.Ones()(act.shape[:-1] + (1,), act.dtype) + if mask is not None: + mask = F.expand_dims(mask, -1) + mask = mask * ones + else: + mask = ones + + act = act * mask + act = self.norm(act, gamma, beta) + act = act * mask + return act diff --git a/tests/st/mindscience/mindsponge/mindsponge/data_transform/data.py b/tests/st/mindscience/mindsponge/mindsponge/data_transform/data.py index 09df2fd610a..0db94d6b769 100644 --- a/tests/st/mindscience/mindsponge/mindsponge/data_transform/data.py +++ b/tests/st/mindscience/mindsponge/mindsponge/data_transform/data.py @@ -1,37 +1,37 @@ - -# Copyright 2021 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""data transform MSA TEMPLATE""" -import numpy as np -from tests.st.mindscience.mindsponge.mindsponge.common.residue_constants import restype_1to3, chi_angles_atoms, \ - atom_order, restypes - - -def get_chi_atom_pos_indices(): - """get the atom indices for computing chi angles for all residue types""" - chi_atom_pos_indices = [] - for residue_name in restypes: - residue_name = restype_1to3.get(residue_name) - residue_chi_angles = chi_angles_atoms.get(residue_name) - atom_pos_indices = [] - for chi_angle in residue_chi_angles: - atom_pos_indices.append([atom_order[atom] for atom in chi_angle]) - for _ in range(4 - len(atom_pos_indices)): - atom_pos_indices.append([0, 0, 0, 0]) # For chi angles not defined on the AA. - chi_atom_pos_indices.append(atom_pos_indices) - - chi_atom_pos_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue. - - return np.array(chi_atom_pos_indices) + +# Copyright 2021 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""data transform MSA TEMPLATE""" +import numpy as np +from tests.st.mindscience.mindsponge.mindsponge.common.residue_constants import restype_1to3, chi_angles_atoms, \ + atom_order, restypes + + +def get_chi_atom_pos_indices(): + """get the atom indices for computing chi angles for all residue types""" + chi_atom_pos_indices = [] + for residue_name in restypes: + residue_name = restype_1to3.get(residue_name) + residue_chi_angles = chi_angles_atoms.get(residue_name) + atom_pos_indices = [] + for chi_angle in residue_chi_angles: + atom_pos_indices.append([atom_order[atom] for atom in chi_angle]) + for _ in range(4 - len(atom_pos_indices)): + atom_pos_indices.append([0, 0, 0, 0]) # For chi angles not defined on the AA. + chi_atom_pos_indices.append(atom_pos_indices) + + chi_atom_pos_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue. + + return np.array(chi_atom_pos_indices) diff --git a/tests/st/mindscience/mindsponge/mindsponge/data_transform/data_transform.py b/tests/st/mindscience/mindsponge/mindsponge/data_transform/data_transform.py index 7d4dcaed57a..242d2bea202 100644 --- a/tests/st/mindscience/mindsponge/mindsponge/data_transform/data_transform.py +++ b/tests/st/mindscience/mindsponge/mindsponge/data_transform/data_transform.py @@ -1,402 +1,402 @@ -# Copyright 2021 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""data transform MSA TEMPLATE""" -import numpy as np -from tests.st.mindscience.mindsponge.mindsponge.common.residue_constants import \ - restype_1to3, atom_order, MAP_HHBLITS_AATYPE_TO_OUR_AATYPE, restype_order, restypes, \ - restype_name_to_atom14_names, atom_types, \ - residue_atoms - -MS_MIN32 = -2147483648 -MS_MAX32 = 2147483647 - - -def one_hot(depth, indices): - """one hot compute""" - res = np.eye(depth)[indices.reshape(-1)] - return res.reshape(list(indices.shape) + [depth]) - - -def correct_msa_restypes(msa, deletion_matrix=None, is_evogen=False): - """Correct MSA restype to have the same order as residue_constants.""" - new_order_list = MAP_HHBLITS_AATYPE_TO_OUR_AATYPE - new_order = np.array(new_order_list, dtype=msa.dtype) - msa = new_order[msa] - if is_evogen: - msa_input = np.concatenate((msa, deletion_matrix), axis=-1).astype(np.int32) - result = msa, msa_input - else: - result = msa - return result - - -def randomly_replace_msa_with_unknown(msa, aatype, replace_proportion): - """Replace a proportion of the MSA with 'X'.""" - msa_mask = np.random.uniform(size=msa.shape, low=0, high=1) < replace_proportion - x_idx = 20 - gap_idx = 21 - msa_mask = np.logical_and(msa_mask, msa != gap_idx) - msa = np.where(msa_mask, np.ones_like(msa) * x_idx, msa) - aatype_mask = np.random.uniform(size=aatype.shape, low=0, high=1) < replace_proportion - aatype = np.where(aatype_mask, np.ones_like(aatype) * x_idx, aatype) - return msa, aatype - - -def fix_templates_aatype(template_aatype): - """Fixes aatype encoding of templates.""" - # Map one-hot to indices. - template_aatype = np.argmax(template_aatype, axis=-1).astype(np.int32) - # Map hhsearch-aatype to our aatype. - new_order_list = MAP_HHBLITS_AATYPE_TO_OUR_AATYPE - new_order = np.array(new_order_list, np.int32) - template_aatype = new_order[template_aatype] - return template_aatype - - -def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks): - """compute pseudo beta features from atom positions""" - is_gly = np.equal(aatype, restype_order['G']) - ca_idx = atom_order['CA'] - cb_idx = atom_order['CB'] - pseudo_beta = np.where( - np.tile(is_gly[..., None].astype("int32"), [1] * len(is_gly.shape) + [3]).astype("bool"), - all_atom_positions[..., ca_idx, :], - all_atom_positions[..., cb_idx, :]) - if all_atom_masks is not None: - pseudo_beta_mask = np.where(is_gly, all_atom_masks[..., ca_idx], all_atom_masks[..., cb_idx]) - pseudo_beta_mask = pseudo_beta_mask.astype(np.float32) - return pseudo_beta, pseudo_beta_mask - return pseudo_beta - - -def make_atom14_masks(aatype): - """create atom 14 position features from aatype""" - rt_atom14_to_atom37 = [] - rt_atom37_to_atom14 = [] - rt_atom14_mask = [] - - for restype in restypes: - atom_names = restype_name_to_atom14_names.get(restype_1to3.get(restype)) - - rt_atom14_to_atom37.append([(atom_order[name] if name else 0) for name in atom_names]) - - atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)} - rt_atom37_to_atom14.append([(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) - for name in atom_types]) - - rt_atom14_mask.append([(1. if name else 0.) for name in atom_names]) - - # Add dummy mapping for restype 'UNK' - rt_atom14_to_atom37.append([0] * 14) - rt_atom37_to_atom14.append([0] * 37) - rt_atom14_mask.append([0.] * 14) - - rt_atom14_to_atom37 = np.array(rt_atom14_to_atom37, np.int32) - rt_atom37_to_atom14 = np.array(rt_atom37_to_atom14, np.int32) - rt_atom14_mask = np.array(rt_atom14_mask, np.float32) - - ri_atom14_to_atom37 = rt_atom14_to_atom37[aatype] - ri_atom14_mask = rt_atom14_mask[aatype] - - atom14_atom_exists = ri_atom14_mask - ri_atom14_to_atom37 = ri_atom14_to_atom37 - - # create the gather indices for mapping back - ri_atom37_to_atom14 = rt_atom37_to_atom14[aatype] - ri_atom37_to_atom14 = ri_atom37_to_atom14 - - # create the corresponding mask - restype_atom37_mask = np.zeros([21, 37], np.float32) - for restype, restype_letter in enumerate(restypes): - restype_name = restype_1to3.get(restype_letter) - atom_names = residue_atoms.get(restype_name) - for atom_name in atom_names: - atom_type = atom_order[atom_name] - restype_atom37_mask[restype, atom_type] = 1 - - atom37_atom_exists = restype_atom37_mask[aatype] - res = [atom14_atom_exists, ri_atom14_to_atom37, ri_atom37_to_atom14, atom37_atom_exists] - return res - - -def block_delete_msa_indices(msa, msa_fraction_per_block, randomize_num_blocks, num_blocks): - """Sample MSA by deleting contiguous blocks. - """ - - num_seq = msa.shape[0] - block_num_seq = np.floor(num_seq * msa_fraction_per_block).astype(np.int32) - - if randomize_num_blocks: - nb = int(np.random.uniform(0, num_blocks + 1)) - else: - nb = num_blocks - del_block_starts = np.random.uniform(0, num_seq, nb).astype(np.int32) - del_blocks = del_block_starts[:, None] + np.array([_ for _ in range(block_num_seq)]).astype(np.int32) - del_blocks = np.clip(del_blocks, 0, num_seq - 1) - del_indices = np.unique(np.sort(np.reshape(del_blocks, (-1,)))) - - # Make sure we keep the original sequence - keep_indices = np.setdiff1d(np.array([_ for _ in range(1, num_seq)]), - del_indices) - keep_indices = np.concatenate([[0], keep_indices], axis=0) - keep_indices = [int(x) for x in keep_indices] - return keep_indices - - -def sample_msa(msa, max_seq): - """Sample MSA randomly, remaining sequences are stored as `extra_*`.""" - num_seq = msa.shape[0] - - shuffled = list(range(1, num_seq)) - np.random.shuffle(shuffled) - shuffled.insert(0, 0) - index_order = np.array(shuffled, np.int32) - num_sel = min(max_seq, num_seq) - - sel_seq = index_order[:num_sel] - not_sel_seq = index_order[num_sel:] - is_sel = num_seq - num_sel - return is_sel, not_sel_seq, sel_seq - -def shape_list(x): - """get the list of dimensions of an array""" - x = np.array(x) - if x.ndim is None: - return x.shape - - static = x.shape - ret = [] - for _, dimension in enumerate(static): - ret.append(dimension) - return ret - - -def shaped_categorical(probability): - """get categorical shape""" - ds = shape_list(probability) - num_classes = ds[-1] - flat_probs = np.reshape(probability, (-1, num_classes)) - numbers = list(range(num_classes)) - res = [] - for flat_prob in flat_probs: - res.append(np.random.choice(numbers, p=flat_prob)) - return np.reshape(np.array(res, np.int32), ds[:-1]) - - -def make_masked_msa(inputs): - """create masked msa for BERT on raw MSA features""" - msa = inputs[0] - hhblits_profile = inputs[1] - uniform_prob = inputs[2] - profile_prob = inputs[3] - same_prob = inputs[4] - replace_fraction = inputs[5] - - random_aatype = np.array([0.05] * 20 + [0., 0.], dtype=np.float32) - - probability = uniform_prob * random_aatype + profile_prob * hhblits_profile + same_prob * one_hot(22, msa) - - pad_shapes = [[0, 0] for _ in range(len(probability.shape))] - pad_shapes[-1][1] = 1 - mask_prob = 1. - profile_prob - same_prob - uniform_prob - - probability = np.pad(probability, pad_shapes, constant_values=(mask_prob,)) - - masked_aatype = np.random.uniform(size=msa.shape, low=0, high=1) < replace_fraction - - bert_msa = shaped_categorical(probability) - bert_msa = np.where(masked_aatype, bert_msa, msa) - - bert_mask = masked_aatype.astype(np.int32) - true_msa = msa - msa = bert_msa - make_masked_msa_result = bert_mask, true_msa, msa - return make_masked_msa_result - -def nearest_neighbor_clusters(msa_mask, msa, extra_msa_mask, extra_msa, gap_agreement_weight=0.): - """Assign each extra MSA sequence to its nearest neighbor in sampled MSA.""" - - # Determine how much weight we assign to each agreement. In theory, we could - # use a full blosum matrix here, but right now let's just down-weight gap - # agreement because it could be spurious. - # Never put weight on agreeing on BERT mask - weights = np.concatenate([np.ones(21), gap_agreement_weight * np.ones(1), np.zeros(1)], 0) - - # Make agreement score as weighted Hamming distance - sample_one_hot = msa_mask[:, :, None] * one_hot(23, msa) - num_seq, num_res, _ = sample_one_hot.shape - - array_extra_msa_mask = extra_msa_mask - if array_extra_msa_mask.any(): - extra_one_hot = extra_msa_mask[:, :, None] * one_hot(23, extra_msa) - extra_num_seq, _, _ = extra_one_hot.shape - - agreement = np.matmul( - np.reshape(extra_one_hot, [extra_num_seq, num_res * 23]), - np.reshape(sample_one_hot * weights, [num_seq, num_res * 23]).T) - # Assign each sequence in the extra sequences to the closest MSA sample - extra_cluster_assignment = np.argmax(agreement, axis=1) - else: - extra_cluster_assignment = np.array([]) - return extra_cluster_assignment - - -def summarize_clusters(inputs): - """Produce profile and deletion_matrix_mean within each cluster.""" - msa, msa_mask, extra_cluster_assignment, extra_msa_mask, \ - extra_msa, extra_deletion_matrix, deletion_matrix = inputs - num_seq = msa.shape[0] - - def csum(x): - result = [] - for i in range(num_seq): - result.append(np.sum(x[np.where(extra_cluster_assignment == i)], axis=0)) - return np.array(result) - - mask = extra_msa_mask - mask_counts = 1e-6 + msa_mask + csum(mask) # Include center - - msa_sum = csum(mask[:, :, None] * one_hot(23, extra_msa)) - msa_sum += one_hot(23, msa) # Original sequence - cluster_profile = msa_sum / mask_counts[:, :, None] - - del msa_sum - - del_sum = csum(mask * extra_deletion_matrix) - del_sum += deletion_matrix # Original sequence - cluster_deletion_mean = del_sum / mask_counts - del del_sum - - return cluster_profile, cluster_deletion_mean - - -def crop_extra_msa(extra_msa, max_extra_msa): - """MSA features are cropped so only `max_extra_msa` sequences are kept.""" - num_seq = extra_msa.shape[0] - num_sel = np.minimum(max_extra_msa, num_seq) - shuffled = list(range(num_seq)) - np.random.shuffle(shuffled) - select_indices = shuffled[:num_sel] - return select_indices - - -def make_msa_feat(inputs): - """Create and concatenate MSA features.""" - # Whether there is a domain break. Always zero for chains, but keeping - # for compatibility with domain datasets. - between_segment_residues, aatype, msa, deletion_matrix, cluster_deletion_mean, cluster_profile, \ - extra_deletion_matrix = inputs - has_break = np.clip(between_segment_residues.astype(np.float32), np.array(0), np.array(1)) - aatype_1hot = one_hot(21, aatype) - - target_feat = [np.expand_dims(has_break, axis=-1), aatype_1hot] - - msa_1hot = one_hot(23, msa) - has_deletion = np.clip(deletion_matrix, np.array(0), np.array(1)) - deletion_value = np.arctan(deletion_matrix / 3.) * (2. / np.pi) - - msa_feat = [msa_1hot, np.expand_dims(has_deletion, axis=-1), np.expand_dims(deletion_value, axis=-1)] - - if cluster_profile is not None: - deletion_mean_value = (np.arctan(cluster_deletion_mean / 3.) * (2. / np.pi)) - msa_feat.extend([cluster_profile, np.expand_dims(deletion_mean_value, axis=-1)]) - extra_has_deletion = None - extra_deletion_value = None - if extra_deletion_matrix is not None: - extra_has_deletion = np.clip(extra_deletion_matrix, np.array(0), np.array(1)) - extra_deletion_value = np.arctan(extra_deletion_matrix / 3.) * (2. / np.pi) - - msa_feat = np.concatenate(msa_feat, axis=-1) - target_feat = np.concatenate(target_feat, axis=-1) - res = [extra_has_deletion, extra_deletion_value, msa_feat, target_feat] - return res - - -def make_random_seed(size, seed_maker_t, low=MS_MIN32, high=MS_MAX32, random_recycle=False): - if random_recycle: - r = np.random.RandomState(seed_maker_t) - return r.uniform(size=size, low=low, high=high) - np.random.seed(seed_maker_t) - return np.random.uniform(size=size, low=low, high=high) - - -def random_crop_to_size(inputs): - """Crop randomly to `crop_size`, or keep as is if shorter than that.""" - seq_length, template_mask, crop_size, max_templates, \ - subsample_templates, seed, random_recycle = inputs - seq_length = seq_length - seq_length_int = int(seq_length) - if template_mask is not None: - num_templates = np.array(template_mask.shape[0], np.int32) - else: - num_templates = np.array(0, np.int32) - num_res_crop_size = np.minimum(seq_length, crop_size) - num_res_crop_size_int = int(num_res_crop_size) - - # Ensures that the cropping of residues and templates happens in the same way - # across ensembling iterations. - # Do not use for randomness that should vary in ensembling. - - if subsample_templates: - templates_crop_start = int(make_random_seed(size=(), seed_maker_t=seed, low=0, high=num_templates + 1, - random_recycle=random_recycle)) - else: - templates_crop_start = 0 - - num_templates_crop_size = np.minimum(num_templates - templates_crop_start, max_templates) - num_templates_crop_size_int = int(num_templates_crop_size) - - num_res_crop_start = int(make_random_seed(size=(), seed_maker_t=seed, low=0, - high=seq_length_int - num_res_crop_size_int + 1, - random_recycle=random_recycle)) - - templates_select_indices = np.argsort(make_random_seed(size=[num_templates], seed_maker_t=seed, - random_recycle=random_recycle)) - res = [num_res_crop_size, num_templates_crop_size_int, num_res_crop_start, num_res_crop_size_int, \ - templates_crop_start, templates_select_indices] - return res - - -def generate_random_sample(cfg, model_config): - '''generate_random_sample''' - np.random.seed(0) - num_noise = model_config.model.latent.num_noise - latent_dim = model_config.model.latent.latent_dim - - context_true_prob = np.absolute(model_config.train.context_true_prob) - keep_prob = np.absolute(model_config.train.keep_prob) - - available_msa = int(model_config.train.available_msa_fraction * model_config.train.max_msa_clusters) - available_msa = min(available_msa, model_config.train.max_msa_clusters) - - evogen_random_data = np.random.normal( - size=(num_noise, model_config.train.max_msa_clusters, cfg.eval.crop_size, latent_dim)).astype(np.float32) - - # (Nseq,): - context_mask = np.zeros((model_config.train.max_msa_clusters,), np.int32) - z1 = np.random.random(model_config.train.max_msa_clusters) - context_mask = np.asarray([1 if x < context_true_prob else 0 for x in z1], np.int32) - context_mask[available_msa:] *= 0 - - # (Nseq,): - target_mask = np.zeros((model_config.train.max_msa_clusters,), np.int32) - z2 = np.random.random(model_config.train.max_msa_clusters) - target_mask = np.asarray([1 if x < keep_prob else 0 for x in z2], np.int32) - - context_mask[0] = 1 - target_mask[0] = 1 - - evogen_context_mask = np.stack((context_mask, target_mask), -1) - return evogen_random_data, evogen_context_mask +# Copyright 2021 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""data transform MSA TEMPLATE""" +import numpy as np +from tests.st.mindscience.mindsponge.mindsponge.common.residue_constants import \ + restype_1to3, atom_order, MAP_HHBLITS_AATYPE_TO_OUR_AATYPE, restype_order, restypes, \ + restype_name_to_atom14_names, atom_types, \ + residue_atoms + +MS_MIN32 = -2147483648 +MS_MAX32 = 2147483647 + + +def one_hot(depth, indices): + """one hot compute""" + res = np.eye(depth)[indices.reshape(-1)] + return res.reshape(list(indices.shape) + [depth]) + + +def correct_msa_restypes(msa, deletion_matrix=None, is_evogen=False): + """Correct MSA restype to have the same order as residue_constants.""" + new_order_list = MAP_HHBLITS_AATYPE_TO_OUR_AATYPE + new_order = np.array(new_order_list, dtype=msa.dtype) + msa = new_order[msa] + if is_evogen: + msa_input = np.concatenate((msa, deletion_matrix), axis=-1).astype(np.int32) + result = msa, msa_input + else: + result = msa + return result + + +def randomly_replace_msa_with_unknown(msa, aatype, replace_proportion): + """Replace a proportion of the MSA with 'X'.""" + msa_mask = np.random.uniform(size=msa.shape, low=0, high=1) < replace_proportion + x_idx = 20 + gap_idx = 21 + msa_mask = np.logical_and(msa_mask, msa != gap_idx) + msa = np.where(msa_mask, np.ones_like(msa) * x_idx, msa) + aatype_mask = np.random.uniform(size=aatype.shape, low=0, high=1) < replace_proportion + aatype = np.where(aatype_mask, np.ones_like(aatype) * x_idx, aatype) + return msa, aatype + + +def fix_templates_aatype(template_aatype): + """Fixes aatype encoding of templates.""" + # Map one-hot to indices. + template_aatype = np.argmax(template_aatype, axis=-1).astype(np.int32) + # Map hhsearch-aatype to our aatype. + new_order_list = MAP_HHBLITS_AATYPE_TO_OUR_AATYPE + new_order = np.array(new_order_list, np.int32) + template_aatype = new_order[template_aatype] + return template_aatype + + +def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks): + """compute pseudo beta features from atom positions""" + is_gly = np.equal(aatype, restype_order['G']) + ca_idx = atom_order['CA'] + cb_idx = atom_order['CB'] + pseudo_beta = np.where( + np.tile(is_gly[..., None].astype("int32"), [1] * len(is_gly.shape) + [3]).astype("bool"), + all_atom_positions[..., ca_idx, :], + all_atom_positions[..., cb_idx, :]) + if all_atom_masks is not None: + pseudo_beta_mask = np.where(is_gly, all_atom_masks[..., ca_idx], all_atom_masks[..., cb_idx]) + pseudo_beta_mask = pseudo_beta_mask.astype(np.float32) + return pseudo_beta, pseudo_beta_mask + return pseudo_beta + + +def make_atom14_masks(aatype): + """create atom 14 position features from aatype""" + rt_atom14_to_atom37 = [] + rt_atom37_to_atom14 = [] + rt_atom14_mask = [] + + for restype in restypes: + atom_names = restype_name_to_atom14_names.get(restype_1to3.get(restype)) + + rt_atom14_to_atom37.append([(atom_order[name] if name else 0) for name in atom_names]) + + atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)} + rt_atom37_to_atom14.append([(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) + for name in atom_types]) + + rt_atom14_mask.append([(1. if name else 0.) for name in atom_names]) + + # Add dummy mapping for restype 'UNK' + rt_atom14_to_atom37.append([0] * 14) + rt_atom37_to_atom14.append([0] * 37) + rt_atom14_mask.append([0.] * 14) + + rt_atom14_to_atom37 = np.array(rt_atom14_to_atom37, np.int32) + rt_atom37_to_atom14 = np.array(rt_atom37_to_atom14, np.int32) + rt_atom14_mask = np.array(rt_atom14_mask, np.float32) + + ri_atom14_to_atom37 = rt_atom14_to_atom37[aatype] + ri_atom14_mask = rt_atom14_mask[aatype] + + atom14_atom_exists = ri_atom14_mask + ri_atom14_to_atom37 = ri_atom14_to_atom37 + + # create the gather indices for mapping back + ri_atom37_to_atom14 = rt_atom37_to_atom14[aatype] + ri_atom37_to_atom14 = ri_atom37_to_atom14 + + # create the corresponding mask + restype_atom37_mask = np.zeros([21, 37], np.float32) + for restype, restype_letter in enumerate(restypes): + restype_name = restype_1to3.get(restype_letter) + atom_names = residue_atoms.get(restype_name) + for atom_name in atom_names: + atom_type = atom_order[atom_name] + restype_atom37_mask[restype, atom_type] = 1 + + atom37_atom_exists = restype_atom37_mask[aatype] + res = [atom14_atom_exists, ri_atom14_to_atom37, ri_atom37_to_atom14, atom37_atom_exists] + return res + + +def block_delete_msa_indices(msa, msa_fraction_per_block, randomize_num_blocks, num_blocks): + """Sample MSA by deleting contiguous blocks. + """ + + num_seq = msa.shape[0] + block_num_seq = np.floor(num_seq * msa_fraction_per_block).astype(np.int32) + + if randomize_num_blocks: + nb = int(np.random.uniform(0, num_blocks + 1)) + else: + nb = num_blocks + del_block_starts = np.random.uniform(0, num_seq, nb).astype(np.int32) + del_blocks = del_block_starts[:, None] + np.array([_ for _ in range(block_num_seq)]).astype(np.int32) + del_blocks = np.clip(del_blocks, 0, num_seq - 1) + del_indices = np.unique(np.sort(np.reshape(del_blocks, (-1,)))) + + # Make sure we keep the original sequence + keep_indices = np.setdiff1d(np.array([_ for _ in range(1, num_seq)]), + del_indices) + keep_indices = np.concatenate([[0], keep_indices], axis=0) + keep_indices = [int(x) for x in keep_indices] + return keep_indices + + +def sample_msa(msa, max_seq): + """Sample MSA randomly, remaining sequences are stored as `extra_*`.""" + num_seq = msa.shape[0] + + shuffled = list(range(1, num_seq)) + np.random.shuffle(shuffled) + shuffled.insert(0, 0) + index_order = np.array(shuffled, np.int32) + num_sel = min(max_seq, num_seq) + + sel_seq = index_order[:num_sel] + not_sel_seq = index_order[num_sel:] + is_sel = num_seq - num_sel + return is_sel, not_sel_seq, sel_seq + +def shape_list(x): + """get the list of dimensions of an array""" + x = np.array(x) + if x.ndim is None: + return x.shape + + static = x.shape + ret = [] + for _, dimension in enumerate(static): + ret.append(dimension) + return ret + + +def shaped_categorical(probability): + """get categorical shape""" + ds = shape_list(probability) + num_classes = ds[-1] + flat_probs = np.reshape(probability, (-1, num_classes)) + numbers = list(range(num_classes)) + res = [] + for flat_prob in flat_probs: + res.append(np.random.choice(numbers, p=flat_prob)) + return np.reshape(np.array(res, np.int32), ds[:-1]) + + +def make_masked_msa(inputs): + """create masked msa for BERT on raw MSA features""" + msa = inputs[0] + hhblits_profile = inputs[1] + uniform_prob = inputs[2] + profile_prob = inputs[3] + same_prob = inputs[4] + replace_fraction = inputs[5] + + random_aatype = np.array([0.05] * 20 + [0., 0.], dtype=np.float32) + + probability = uniform_prob * random_aatype + profile_prob * hhblits_profile + same_prob * one_hot(22, msa) + + pad_shapes = [[0, 0] for _ in range(len(probability.shape))] + pad_shapes[-1][1] = 1 + mask_prob = 1. - profile_prob - same_prob - uniform_prob + + probability = np.pad(probability, pad_shapes, constant_values=(mask_prob,)) + + masked_aatype = np.random.uniform(size=msa.shape, low=0, high=1) < replace_fraction + + bert_msa = shaped_categorical(probability) + bert_msa = np.where(masked_aatype, bert_msa, msa) + + bert_mask = masked_aatype.astype(np.int32) + true_msa = msa + msa = bert_msa + make_masked_msa_result = bert_mask, true_msa, msa + return make_masked_msa_result + +def nearest_neighbor_clusters(msa_mask, msa, extra_msa_mask, extra_msa, gap_agreement_weight=0.): + """Assign each extra MSA sequence to its nearest neighbor in sampled MSA.""" + + # Determine how much weight we assign to each agreement. In theory, we could + # use a full blosum matrix here, but right now let's just down-weight gap + # agreement because it could be spurious. + # Never put weight on agreeing on BERT mask + weights = np.concatenate([np.ones(21), gap_agreement_weight * np.ones(1), np.zeros(1)], 0) + + # Make agreement score as weighted Hamming distance + sample_one_hot = msa_mask[:, :, None] * one_hot(23, msa) + num_seq, num_res, _ = sample_one_hot.shape + + array_extra_msa_mask = extra_msa_mask + if array_extra_msa_mask.any(): + extra_one_hot = extra_msa_mask[:, :, None] * one_hot(23, extra_msa) + extra_num_seq, _, _ = extra_one_hot.shape + + agreement = np.matmul( + np.reshape(extra_one_hot, [extra_num_seq, num_res * 23]), + np.reshape(sample_one_hot * weights, [num_seq, num_res * 23]).T) + # Assign each sequence in the extra sequences to the closest MSA sample + extra_cluster_assignment = np.argmax(agreement, axis=1) + else: + extra_cluster_assignment = np.array([]) + return extra_cluster_assignment + + +def summarize_clusters(inputs): + """Produce profile and deletion_matrix_mean within each cluster.""" + msa, msa_mask, extra_cluster_assignment, extra_msa_mask, \ + extra_msa, extra_deletion_matrix, deletion_matrix = inputs + num_seq = msa.shape[0] + + def csum(x): + result = [] + for i in range(num_seq): + result.append(np.sum(x[np.where(extra_cluster_assignment == i)], axis=0)) + return np.array(result) + + mask = extra_msa_mask + mask_counts = 1e-6 + msa_mask + csum(mask) # Include center + + msa_sum = csum(mask[:, :, None] * one_hot(23, extra_msa)) + msa_sum += one_hot(23, msa) # Original sequence + cluster_profile = msa_sum / mask_counts[:, :, None] + + del msa_sum + + del_sum = csum(mask * extra_deletion_matrix) + del_sum += deletion_matrix # Original sequence + cluster_deletion_mean = del_sum / mask_counts + del del_sum + + return cluster_profile, cluster_deletion_mean + + +def crop_extra_msa(extra_msa, max_extra_msa): + """MSA features are cropped so only `max_extra_msa` sequences are kept.""" + num_seq = extra_msa.shape[0] + num_sel = np.minimum(max_extra_msa, num_seq) + shuffled = list(range(num_seq)) + np.random.shuffle(shuffled) + select_indices = shuffled[:num_sel] + return select_indices + + +def make_msa_feat(inputs): + """Create and concatenate MSA features.""" + # Whether there is a domain break. Always zero for chains, but keeping + # for compatibility with domain datasets. + between_segment_residues, aatype, msa, deletion_matrix, cluster_deletion_mean, cluster_profile, \ + extra_deletion_matrix = inputs + has_break = np.clip(between_segment_residues.astype(np.float32), np.array(0), np.array(1)) + aatype_1hot = one_hot(21, aatype) + + target_feat = [np.expand_dims(has_break, axis=-1), aatype_1hot] + + msa_1hot = one_hot(23, msa) + has_deletion = np.clip(deletion_matrix, np.array(0), np.array(1)) + deletion_value = np.arctan(deletion_matrix / 3.) * (2. / np.pi) + + msa_feat = [msa_1hot, np.expand_dims(has_deletion, axis=-1), np.expand_dims(deletion_value, axis=-1)] + + if cluster_profile is not None: + deletion_mean_value = (np.arctan(cluster_deletion_mean / 3.) * (2. / np.pi)) + msa_feat.extend([cluster_profile, np.expand_dims(deletion_mean_value, axis=-1)]) + extra_has_deletion = None + extra_deletion_value = None + if extra_deletion_matrix is not None: + extra_has_deletion = np.clip(extra_deletion_matrix, np.array(0), np.array(1)) + extra_deletion_value = np.arctan(extra_deletion_matrix / 3.) * (2. / np.pi) + + msa_feat = np.concatenate(msa_feat, axis=-1) + target_feat = np.concatenate(target_feat, axis=-1) + res = [extra_has_deletion, extra_deletion_value, msa_feat, target_feat] + return res + + +def make_random_seed(size, seed_maker_t, low=MS_MIN32, high=MS_MAX32, random_recycle=False): + if random_recycle: + r = np.random.RandomState(seed_maker_t) + return r.uniform(size=size, low=low, high=high) + np.random.seed(seed_maker_t) + return np.random.uniform(size=size, low=low, high=high) + + +def random_crop_to_size(inputs): + """Crop randomly to `crop_size`, or keep as is if shorter than that.""" + seq_length, template_mask, crop_size, max_templates, \ + subsample_templates, seed, random_recycle = inputs + seq_length = seq_length + seq_length_int = int(seq_length) + if template_mask is not None: + num_templates = np.array(template_mask.shape[0], np.int32) + else: + num_templates = np.array(0, np.int32) + num_res_crop_size = np.minimum(seq_length, crop_size) + num_res_crop_size_int = int(num_res_crop_size) + + # Ensures that the cropping of residues and templates happens in the same way + # across ensembling iterations. + # Do not use for randomness that should vary in ensembling. + + if subsample_templates: + templates_crop_start = int(make_random_seed(size=(), seed_maker_t=seed, low=0, high=num_templates + 1, + random_recycle=random_recycle)) + else: + templates_crop_start = 0 + + num_templates_crop_size = np.minimum(num_templates - templates_crop_start, max_templates) + num_templates_crop_size_int = int(num_templates_crop_size) + + num_res_crop_start = int(make_random_seed(size=(), seed_maker_t=seed, low=0, + high=seq_length_int - num_res_crop_size_int + 1, + random_recycle=random_recycle)) + + templates_select_indices = np.argsort(make_random_seed(size=[num_templates], seed_maker_t=seed, + random_recycle=random_recycle)) + res = [num_res_crop_size, num_templates_crop_size_int, num_res_crop_start, num_res_crop_size_int, \ + templates_crop_start, templates_select_indices] + return res + + +def generate_random_sample(cfg, model_config): + '''generate_random_sample''' + np.random.seed(0) + num_noise = model_config.model.latent.num_noise + latent_dim = model_config.model.latent.latent_dim + + context_true_prob = np.absolute(model_config.train.context_true_prob) + keep_prob = np.absolute(model_config.train.keep_prob) + + available_msa = int(model_config.train.available_msa_fraction * model_config.train.max_msa_clusters) + available_msa = min(available_msa, model_config.train.max_msa_clusters) + + evogen_random_data = np.random.normal( + size=(num_noise, model_config.train.max_msa_clusters, cfg.eval.crop_size, latent_dim)).astype(np.float32) + + # (Nseq,): + context_mask = np.zeros((model_config.train.max_msa_clusters,), np.int32) + z1 = np.random.random(model_config.train.max_msa_clusters) + context_mask = np.asarray([1 if x < context_true_prob else 0 for x in z1], np.int32) + context_mask[available_msa:] *= 0 + + # (Nseq,): + target_mask = np.zeros((model_config.train.max_msa_clusters,), np.int32) + z2 = np.random.random(model_config.train.max_msa_clusters) + target_mask = np.asarray([1 if x < keep_prob else 0 for x in z2], np.int32) + + context_mask[0] = 1 + target_mask[0] = 1 + + evogen_context_mask = np.stack((context_mask, target_mask), -1) + return evogen_random_data, evogen_context_mask diff --git a/tests/st/mindscience/mindsponge/module/head.py b/tests/st/mindscience/mindsponge/module/head.py index 3456019f688..0170e1fc54f 100644 --- a/tests/st/mindscience/mindsponge/module/head.py +++ b/tests/st/mindscience/mindsponge/module/head.py @@ -1,43 +1,43 @@ -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""structure module""" -import mindspore.common.dtype as mstype -import mindspore.nn as nn -from tests.st.mindscience.mindsponge.mindsponge.cell.initializer import lecun_init - - -class PredictedLDDTHead(nn.Cell): - """Head to predict the per-residue LDDT to be used as a confidence measure.""" - - def __init__(self, config, seq_channel): - super().__init__() - self.config = config - self.input_layer_norm = nn.LayerNorm([seq_channel,], epsilon=1e-5) - self.act_0 = nn.Dense(seq_channel, self.config.num_channels, - weight_init=lecun_init(seq_channel, initializer_name='relu') - ).to_float(mstype.float16) - self.act_1 = nn.Dense(self.config.num_channels, self.config.num_channels, - weight_init=lecun_init(self.config.num_channels, initializer_name='relu') - ).to_float(mstype.float16) - self.logits = nn.Dense(self.config.num_channels, self.config.num_bins, weight_init='zeros' - ).to_float(mstype.float16) - self.relu = nn.ReLU() - - def construct(self, rp_structure_module): - """Builds ExperimentallyResolvedHead module.""" - act = rp_structure_module - act = self.input_layer_norm(act.astype(mstype.float32)) - act = self.act_0(act) - act = self.relu(act.astype(mstype.float32)) - act = self.act_1(act) - act = self.relu(act.astype(mstype.float32)) - logits = self.logits(act) - return logits +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""structure module""" +import mindspore.common.dtype as mstype +import mindspore.nn as nn +from tests.st.mindscience.mindsponge.mindsponge.cell.initializer import lecun_init + + +class PredictedLDDTHead(nn.Cell): + """Head to predict the per-residue LDDT to be used as a confidence measure.""" + + def __init__(self, config, seq_channel): + super().__init__() + self.config = config + self.input_layer_norm = nn.LayerNorm([seq_channel,], epsilon=1e-5) + self.act_0 = nn.Dense(seq_channel, self.config.num_channels, + weight_init=lecun_init(seq_channel, initializer_name='relu') + ).to_float(mstype.float16) + self.act_1 = nn.Dense(self.config.num_channels, self.config.num_channels, + weight_init=lecun_init(self.config.num_channels, initializer_name='relu') + ).to_float(mstype.float16) + self.logits = nn.Dense(self.config.num_channels, self.config.num_bins, weight_init='zeros' + ).to_float(mstype.float16) + self.relu = nn.ReLU() + + def construct(self, rp_structure_module): + """Builds ExperimentallyResolvedHead module.""" + act = rp_structure_module + act = self.input_layer_norm(act.astype(mstype.float32)) + act = self.act_0(act) + act = self.relu(act.astype(mstype.float32)) + act = self.act_1(act) + act = self.relu(act.astype(mstype.float32)) + logits = self.logits(act) + return logits diff --git a/tests/st/mindscience/mindsponge/test_fold.py b/tests/st/mindscience/mindsponge/test_fold.py index cbe6ffa3a0e..2c9999884ce 100644 --- a/tests/st/mindscience/mindsponge/test_fold.py +++ b/tests/st/mindscience/mindsponge/test_fold.py @@ -1,110 +1,110 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""eval script""" -import os -import numpy as np -import time -import mindspore.context as context -from mindspore import Tensor, nn, load_checkpoint -from tests.st.mindscience.mindsponge.mindsponge.cell.amp import amp_convert -from tests.st.mindscience.mindsponge.mindsponge.common.config_load import load_config -from tests.mark_utils import arg_mark - -from data import Feature -from model import MegaFold, compute_confidence - - -def fold_infer(mixed_precision, crop_size, is_ge_only=False): - '''mega fold inference''' - data_config = "./config/data.yaml" - model_config = "./config/model.yaml" - checkpoint_path = "/home/workspace/mindspore_ckpt/ckpt/MEGA_Fold_1.ckpt" - data_cfg = load_config(data_config) - model_cfg = load_config(model_config) - data_cfg.eval.crop_size = crop_size - model_cfg.seq_length = data_cfg.eval.crop_size - slice_key = "seq_" + str(model_cfg.seq_length) - slice_val = vars(model_cfg.slice)[slice_key] - model_cfg.slice = slice_val - megafold = MegaFold(model_cfg, mixed_precision=mixed_precision) - if is_ge_only: - context.set_context(jit_level="O2") - load_checkpoint(checkpoint_path, megafold) - fp32_white_list = (nn.Softmax, nn.LayerNorm) - amp_convert(megafold, fp32_white_list) - time_list = [] - raw_feature = np.load("/home/workspace/mindspore_dataset/mindsponge_data/pkl/raw_feature.npy", allow_pickle=True) - raw_feature = raw_feature.item() - ori_res_length = raw_feature['msa'].shape[1] - processed_feature = Feature(data_cfg, raw_feature) - feat, prev_pos, prev_msa_first_row, prev_pair = processed_feature.pipeline(data_cfg, - mixed_precision=mixed_precision) - prev_pos = Tensor(prev_pos) - prev_msa_first_row = Tensor(prev_msa_first_row) - prev_pair = Tensor(prev_pair) - for i in range(2): - feat_i = [Tensor(x[i]) for x in feat] - t_start = time.time() - result = megafold(*feat_i, prev_pos, prev_msa_first_row, prev_pair) - t_end = time.time() - time_list.append(t_end - t_start) - prev_pos, prev_msa_first_row, prev_pair, predicted_lddt_logits = result - predicted_lddt_logits = predicted_lddt_logits.asnumpy()[:ori_res_length] - confidence, _ = compute_confidence(predicted_lddt_logits, return_lddt=True) - return confidence, time_list - - -@arg_mark(plat_marks=['platform_ascend910b'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_910B_Ascend_fold(): - """ - Feature: 910B Megaflod - Description: test train and eval - Expectation: success - """ - os.environ["MS_ASCEND_CHECK_OVERFLOW_MODE"] = "SATURATION_MODE" - context.set_context(mode=context.GRAPH_MODE, - device_target="Ascend", - memory_optimize_level="O1", - max_call_depth=6000) - mixed_precision = 1 - crop_size = 512 - confidence, time_list = fold_infer(mixed_precision, crop_size) - compile_time, exectue_time = time_list - compile_time = compile_time - exectue_time - os.environ.pop("MS_ASCEND_CHECK_OVERFLOW_MODE") - assert confidence > 0.9 - assert compile_time < 500 - assert exectue_time < 100 - -@arg_mark(plat_marks=['platform_ascend'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_910A_Ascend_fold(): - """ - Feature: 910A Megaflod - Description: test train and eval - Expectation: success - """ - context.set_context(mode=context.GRAPH_MODE, - device_target="Ascend", - memory_optimize_level="O1", - max_call_depth=6000) - context.set_context(jit_level="O2") - mixed_precision = 1 - crop_size = 1024 - confidence, time_list = fold_infer(mixed_precision, crop_size, True) - compile_time, exectue_time = time_list - compile_time = compile_time - exectue_time - assert confidence > 0.9 - assert compile_time < 500 - assert exectue_time < 100 +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""eval script""" +import os +import numpy as np +import time +import mindspore.context as context +from mindspore import Tensor, nn, load_checkpoint +from tests.st.mindscience.mindsponge.mindsponge.cell.amp import amp_convert +from tests.st.mindscience.mindsponge.mindsponge.common.config_load import load_config +from tests.mark_utils import arg_mark + +from data import Feature +from model import MegaFold, compute_confidence + + +def fold_infer(mixed_precision, crop_size, is_ge_only=False): + '''mega fold inference''' + data_config = "./config/data.yaml" + model_config = "./config/model.yaml" + checkpoint_path = "/home/workspace/mindspore_ckpt/ckpt/MEGA_Fold_1.ckpt" + data_cfg = load_config(data_config) + model_cfg = load_config(model_config) + data_cfg.eval.crop_size = crop_size + model_cfg.seq_length = data_cfg.eval.crop_size + slice_key = "seq_" + str(model_cfg.seq_length) + slice_val = vars(model_cfg.slice)[slice_key] + model_cfg.slice = slice_val + megafold = MegaFold(model_cfg, mixed_precision=mixed_precision) + if is_ge_only: + context.set_context(jit_level="O2") + load_checkpoint(checkpoint_path, megafold) + fp32_white_list = (nn.Softmax, nn.LayerNorm) + amp_convert(megafold, fp32_white_list) + time_list = [] + raw_feature = np.load("/home/workspace/mindspore_dataset/mindsponge_data/pkl/raw_feature.npy", allow_pickle=True) + raw_feature = raw_feature.item() + ori_res_length = raw_feature['msa'].shape[1] + processed_feature = Feature(data_cfg, raw_feature) + feat, prev_pos, prev_msa_first_row, prev_pair = processed_feature.pipeline(data_cfg, + mixed_precision=mixed_precision) + prev_pos = Tensor(prev_pos) + prev_msa_first_row = Tensor(prev_msa_first_row) + prev_pair = Tensor(prev_pair) + for i in range(2): + feat_i = [Tensor(x[i]) for x in feat] + t_start = time.time() + result = megafold(*feat_i, prev_pos, prev_msa_first_row, prev_pair) + t_end = time.time() + time_list.append(t_end - t_start) + prev_pos, prev_msa_first_row, prev_pair, predicted_lddt_logits = result + predicted_lddt_logits = predicted_lddt_logits.asnumpy()[:ori_res_length] + confidence, _ = compute_confidence(predicted_lddt_logits, return_lddt=True) + return confidence, time_list + + +@arg_mark(plat_marks=['platform_ascend910b'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_910B_Ascend_fold(): + """ + Feature: 910B Megaflod + Description: test train and eval + Expectation: success + """ + os.environ["MS_ASCEND_CHECK_OVERFLOW_MODE"] = "SATURATION_MODE" + context.set_context(mode=context.GRAPH_MODE, + device_target="Ascend", + memory_optimize_level="O1", + max_call_depth=6000) + mixed_precision = 1 + crop_size = 512 + confidence, time_list = fold_infer(mixed_precision, crop_size) + compile_time, exectue_time = time_list + compile_time = compile_time - exectue_time + os.environ.pop("MS_ASCEND_CHECK_OVERFLOW_MODE") + assert confidence > 0.9 + assert compile_time < 500 + assert exectue_time < 100 + +@arg_mark(plat_marks=['platform_ascend'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_910A_Ascend_fold(): + """ + Feature: 910A Megaflod + Description: test train and eval + Expectation: success + """ + context.set_context(mode=context.GRAPH_MODE, + device_target="Ascend", + memory_optimize_level="O1", + max_call_depth=6000) + context.set_context(jit_level="O2") + mixed_precision = 1 + crop_size = 1024 + confidence, time_list = fold_infer(mixed_precision, crop_size, True) + compile_time, exectue_time = time_list + compile_time = compile_time - exectue_time + assert confidence > 0.9 + assert compile_time < 500 + assert exectue_time < 100 diff --git a/tests/st/nccl/test_nccl_broadcast_op.py b/tests/st/nccl/test_nccl_broadcast_op.py index fe3955dc396..ae12df6a654 100644 --- a/tests/st/nccl/test_nccl_broadcast_op.py +++ b/tests/st/nccl/test_nccl_broadcast_op.py @@ -1,71 +1,71 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import numpy as np - -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.common.initializer import initializer -from mindspore.common.parameter import Parameter -from mindspore.communication.management import init, get_rank, get_group_size -from mindspore.ops import operations as P - -context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - -init() -rank = get_rank() -size = get_group_size() -x = np.ones([3, 1, 3, 3]).astype(np.float32) * 0.01 * (rank + 1) - - -class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - self.x1 = Parameter(initializer(Tensor(x), x.shape), name='x1') - self.x2 = Parameter(initializer(Tensor(x), x.shape), name='x2') - self.x3 = Parameter(initializer(Tensor(x), x.shape), name='x3') - - self.broadcast1 = P.Broadcast(0) - self.broadcast2 = P.Broadcast(1) - self.broadcast3 = P.Broadcast(2) - - def construct(self): - return (self.broadcast1((self.x1,)), - self.broadcast2((self.x2,)), - self.broadcast3((self.x3,))) - - -def test_Broadcast(): - broadcast = Net() - output = broadcast() - - expect0 = np.ones([3, 1, 3, 3]).astype(np.float32) * 1 - expect1 = np.ones([3, 1, 3, 3]).astype(np.float32) * 2 - expect2 = np.ones([3, 1, 3, 3]).astype(np.float32) * 3 - - diff0 = output[0][0].asnumpy() - expect0 - error0 = np.ones(shape=expect0.shape) * 1.0e-5 - assert np.all(diff0 < error0) - assert output[0][0].shape == expect0.shape - - diff1 = output[1][0].asnumpy() - expect1 - error1 = np.ones(shape=expect1.shape) * 1.0e-5 - assert np.all(diff1 < error1) - assert output[1][0].shape == expect1.shape - - diff2 = output[2][0].asnumpy() - expect2 - error2 = np.ones(shape=expect2.shape) * 1.0e-5 - assert np.all(diff2 < error2) - assert output[2][0].shape == expect2.shape +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common.initializer import initializer +from mindspore.common.parameter import Parameter +from mindspore.communication.management import init, get_rank, get_group_size +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + +init() +rank = get_rank() +size = get_group_size() +x = np.ones([3, 1, 3, 3]).astype(np.float32) * 0.01 * (rank + 1) + + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.x1 = Parameter(initializer(Tensor(x), x.shape), name='x1') + self.x2 = Parameter(initializer(Tensor(x), x.shape), name='x2') + self.x3 = Parameter(initializer(Tensor(x), x.shape), name='x3') + + self.broadcast1 = P.Broadcast(0) + self.broadcast2 = P.Broadcast(1) + self.broadcast3 = P.Broadcast(2) + + def construct(self): + return (self.broadcast1((self.x1,)), + self.broadcast2((self.x2,)), + self.broadcast3((self.x3,))) + + +def test_Broadcast(): + broadcast = Net() + output = broadcast() + + expect0 = np.ones([3, 1, 3, 3]).astype(np.float32) * 1 + expect1 = np.ones([3, 1, 3, 3]).astype(np.float32) * 2 + expect2 = np.ones([3, 1, 3, 3]).astype(np.float32) * 3 + + diff0 = output[0][0].asnumpy() - expect0 + error0 = np.ones(shape=expect0.shape) * 1.0e-5 + assert np.all(diff0 < error0) + assert output[0][0].shape == expect0.shape + + diff1 = output[1][0].asnumpy() - expect1 + error1 = np.ones(shape=expect1.shape) * 1.0e-5 + assert np.all(diff1 < error1) + assert output[1][0].shape == expect1.shape + + diff2 = output[2][0].asnumpy() - expect2 + error2 = np.ones(shape=expect2.shape) * 1.0e-5 + assert np.all(diff2 < error2) + assert output[2][0].shape == expect2.shape diff --git a/tests/st/networks/mindcv b/tests/st/networks/mindcv deleted file mode 160000 index 9c6dc67cc53..00000000000 --- a/tests/st/networks/mindcv +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 9c6dc67cc53dca7584799815ec001f3edee732eb diff --git a/tests/st/networks/mindformers b/tests/st/networks/mindformers deleted file mode 160000 index a58d16d6d84..00000000000 --- a/tests/st/networks/mindformers +++ /dev/null @@ -1 +0,0 @@ -Subproject commit a58d16d6d849490207ddde5dc8d73ebceb4798df diff --git a/tests/st/networks/mindocr b/tests/st/networks/mindocr deleted file mode 160000 index 372675bc542..00000000000 --- a/tests/st/networks/mindocr +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 372675bc542dc356e7b0ab752d08137e91ec9c8d diff --git a/tests/st/networks/mindone b/tests/st/networks/mindone deleted file mode 160000 index 5e36d3a8fca..00000000000 --- a/tests/st/networks/mindone +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 5e36d3a8fca0d3e1c06a2fac5a3e04625a1ed00e diff --git a/tests/st/networks/models/alexnet.py b/tests/st/networks/models/alexnet.py index bb11954fdb9..9d1daccc655 100644 --- a/tests/st/networks/models/alexnet.py +++ b/tests/st/networks/models/alexnet.py @@ -1,55 +1,55 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import mindspore.nn as nn -from mindspore.ops import operations as P - - -class AlexNet(nn.Cell): - def __init__(self, num_classes=10): - super(AlexNet, self).__init__() - self.batch_size = 32 - self.conv1 = nn.Conv2d(3, 96, 11, stride=4, pad_mode="valid") - self.conv2 = nn.Conv2d(96, 256, 5, stride=1, pad_mode="same") - self.conv3 = nn.Conv2d(256, 384, 3, stride=1, pad_mode="same") - self.conv4 = nn.Conv2d(384, 384, 3, stride=1, pad_mode="same") - self.conv5 = nn.Conv2d(384, 256, 3, stride=1, pad_mode="same") - self.relu = P.ReLU() - self.max_pool2d = nn.MaxPool2d(kernel_size=3, stride=2) - self.flatten = nn.Flatten() - self.fc1 = nn.Dense(66256, 4096) - self.fc2 = nn.Dense(4096, 4096) - self.fc3 = nn.Dense(4096, num_classes) - - def construct(self, x): - x = self.conv1(x) - x = self.relu(x) - x = self.max_pool2d(x) - x = self.conv2(x) - x = self.relu(x) - x = self.max_pool2d(x) - x = self.conv3(x) - x = self.relu(x) - x = self.conv4(x) - x = self.relu(x) - x = self.conv5(x) - x = self.relu(x) - x = self.max_pool2d(x) - x = self.flatten(x) - x = self.fc1(x) - x = self.relu(x) - x = self.fc2(x) - x = self.relu(x) - x = self.fc3(x) - return x +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import mindspore.nn as nn +from mindspore.ops import operations as P + + +class AlexNet(nn.Cell): + def __init__(self, num_classes=10): + super(AlexNet, self).__init__() + self.batch_size = 32 + self.conv1 = nn.Conv2d(3, 96, 11, stride=4, pad_mode="valid") + self.conv2 = nn.Conv2d(96, 256, 5, stride=1, pad_mode="same") + self.conv3 = nn.Conv2d(256, 384, 3, stride=1, pad_mode="same") + self.conv4 = nn.Conv2d(384, 384, 3, stride=1, pad_mode="same") + self.conv5 = nn.Conv2d(384, 256, 3, stride=1, pad_mode="same") + self.relu = P.ReLU() + self.max_pool2d = nn.MaxPool2d(kernel_size=3, stride=2) + self.flatten = nn.Flatten() + self.fc1 = nn.Dense(66256, 4096) + self.fc2 = nn.Dense(4096, 4096) + self.fc3 = nn.Dense(4096, num_classes) + + def construct(self, x): + x = self.conv1(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.conv2(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.conv3(x) + x = self.relu(x) + x = self.conv4(x) + x = self.relu(x) + x = self.conv5(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.flatten(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.relu(x) + x = self.fc3(x) + return x diff --git a/tests/st/networks/models/bert/bert_performance/src/bert_for_pre_training.py b/tests/st/networks/models/bert/bert_performance/src/bert_for_pre_training.py index fe1331b593b..89fce7c6ae7 100644 --- a/tests/st/networks/models/bert/bert_performance/src/bert_for_pre_training.py +++ b/tests/st/networks/models/bert/bert_performance/src/bert_for_pre_training.py @@ -1,913 +1,913 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Bert for pretraining.""" -import numpy as np -import mindspore.ops as ops -import mindspore.nn as nn -from mindspore import context -from mindspore.common import dtype as mstype -from mindspore.common.initializer import initializer, TruncatedNormal -from mindspore.common.parameter import Parameter -from mindspore.common.tensor import Tensor -from mindspore.communication.management import get_group_size -from mindspore.context import ParallelMode -from mindspore.nn.wrap.grad_reducer import DistributedGradReducer -from mindspore.ops import composite as C -from mindspore.ops import functional as F -from mindspore.ops import operations as P -from src.bert_model import BertModel - -GRADIENT_CLIP_TYPE = 1 -GRADIENT_CLIP_VALUE = 1.0 - -clip_grad = C.MultitypeFuncGraph("clip_grad") - - -@clip_grad.register("Number", "Number", "Tensor") -def _clip_grad(clip_type, clip_value, grad): - """ - Clip gradients. - - Inputs: - clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'. - clip_value (float): Specifies how much to clip. - grad (tuple[Tensor]): Gradients. - - Outputs: - tuple[Tensor], clipped gradients. - """ - if clip_type not in (0, 1): - return grad - dt = F.dtype(grad) - if clip_type == 0: - new_grad = ops.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt), - F.cast(F.tuple_to_array((clip_value,)), dt)) - else: - new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt)) - return new_grad - - -class GetMaskedLMOutput(nn.Cell): - """ - Get masked lm output. - - Args: - config (BertConfig): The config of BertModel. - - Returns: - Tensor, masked lm output. - """ - - def __init__(self, config): - super(GetMaskedLMOutput, self).__init__() - self.width = config.hidden_size - self.reshape = P.Reshape() - self.gather = P.Gather() - - weight_init = TruncatedNormal(config.initializer_range) - self.dense = nn.Dense(self.width, - config.hidden_size, - weight_init=weight_init, - activation=config.hidden_act).to_float(config.compute_type) - self.layernorm = nn.LayerNorm((config.hidden_size,)).to_float(config.compute_type) - self.output_bias = Parameter( - initializer( - 'zero', - config.vocab_size)) - self.matmul = P.MatMul(transpose_b=True) - self.log_softmax = nn.LogSoftmax(axis=-1) - self.shape_flat_offsets = (-1, 1) - self.last_idx = (-1,) - self.shape_flat_sequence_tensor = (-1, self.width) - self.cast = P.Cast() - self.compute_type = config.compute_type - self.dtype = config.dtype - - def construct(self, - input_tensor, - output_weights, - positions): - """Get output log_probs""" - input_shape = P.Shape()(input_tensor) - rng = F.tuple_to_array(F.make_range(input_shape[0])) - flat_offsets = self.reshape(rng * input_shape[1], self.shape_flat_offsets) - flat_position = self.reshape(positions + flat_offsets, self.last_idx) - flat_sequence_tensor = self.reshape(input_tensor, self.shape_flat_sequence_tensor) - input_tensor = self.gather(flat_sequence_tensor, flat_position, 0) - input_tensor = self.cast(input_tensor, self.compute_type) - output_weights = self.cast(output_weights, self.compute_type) - input_tensor = self.dense(input_tensor) - input_tensor = self.layernorm(input_tensor) - logits = self.matmul(input_tensor, output_weights) - logits = self.cast(logits, self.dtype) - logits = logits + self.output_bias - log_probs = self.log_softmax(logits) - return log_probs - - -class GetNextSentenceOutput(nn.Cell): - """ - Get next sentence output. - - Args: - config (BertConfig): The config of Bert. - - Returns: - Tensor, next sentence output. - """ - - def __init__(self, config): - super(GetNextSentenceOutput, self).__init__() - self.log_softmax = P.LogSoftmax() - weight_init = TruncatedNormal(config.initializer_range) - self.dense = nn.Dense(config.hidden_size, 2, - weight_init=weight_init, has_bias=True).to_float(config.compute_type) - self.dtype = config.dtype - self.cast = P.Cast() - - def construct(self, input_tensor): - logits = self.dense(input_tensor) - logits = self.cast(logits, self.dtype) - log_prob = self.log_softmax(logits) - return log_prob - - -class BertPreTraining(nn.Cell): - """ - Bert pretraining network. - - Args: - config (BertConfig): The config of BertModel. - is_training (bool): Specifies whether to use the training mode. - use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. - - Returns: - Tensor, prediction_scores, seq_relationship_score. - """ - - def __init__(self, config, is_training, use_one_hot_embeddings): - super(BertPreTraining, self).__init__() - self.bert = BertModel(config, is_training, use_one_hot_embeddings) - self.cls1 = GetMaskedLMOutput(config) - self.cls2 = GetNextSentenceOutput(config) - - def construct(self, input_ids, input_mask, token_type_id, - masked_lm_positions): - sequence_output, pooled_output, embedding_table = \ - self.bert(input_ids, token_type_id, input_mask) - prediction_scores = self.cls1(sequence_output, - embedding_table, - masked_lm_positions) - seq_relationship_score = self.cls2(pooled_output) - return prediction_scores, seq_relationship_score - - -class BertPretrainingLoss(nn.Cell): - """ - Provide bert pre-training loss. - - Args: - config (BertConfig): The config of BertModel. - - Returns: - Tensor, total loss. - """ - - def __init__(self, config): - super(BertPretrainingLoss, self).__init__() - self.vocab_size = config.vocab_size - self.onehot = P.OneHot() - self.on_value = Tensor(1.0, mstype.float32) - self.off_value = Tensor(0.0, mstype.float32) - self.reduce_sum = P.ReduceSum() - self.reduce_mean = P.ReduceMean() - self.reshape = P.Reshape() - self.last_idx = (-1,) - self.neg = P.Neg() - self.cast = P.Cast() - - def construct(self, prediction_scores, seq_relationship_score, masked_lm_ids, - masked_lm_weights, next_sentence_labels): - """Defines the computation performed.""" - label_ids = self.reshape(masked_lm_ids, self.last_idx) - label_weights = self.cast(self.reshape(masked_lm_weights, self.last_idx), mstype.float32) - one_hot_labels = self.onehot(label_ids, self.vocab_size, self.on_value, self.off_value) - - per_example_loss = self.neg(self.reduce_sum(prediction_scores * one_hot_labels, self.last_idx)) - numerator = self.reduce_sum(label_weights * per_example_loss, ()) - denominator = self.reduce_sum(label_weights, ()) + self.cast(F.tuple_to_array((1e-5,)), mstype.float32) - masked_lm_loss = numerator / denominator - - # next_sentence_loss - labels = self.reshape(next_sentence_labels, self.last_idx) - one_hot_labels = self.onehot(labels, 2, self.on_value, self.off_value) - per_example_loss = self.neg(self.reduce_sum( - one_hot_labels * seq_relationship_score, self.last_idx)) - next_sentence_loss = self.reduce_mean(per_example_loss, self.last_idx) - - # total_loss - total_loss = masked_lm_loss + next_sentence_loss - - return total_loss - - -class BertNetworkWithLoss(nn.Cell): - """ - Provide bert pre-training loss through network. - - Args: - config (BertConfig): The config of BertModel. - is_training (bool): Specifies whether to use the training mode. - use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. Default: False. - - Returns: - Tensor, the loss of the network. - """ - - def __init__(self, config, is_training, use_one_hot_embeddings=False): - super(BertNetworkWithLoss, self).__init__() - self.bert = BertPreTraining(config, is_training, use_one_hot_embeddings) - self.loss = BertPretrainingLoss(config) - self.cast = P.Cast() - - def construct(self, - input_ids, - input_mask, - token_type_id, - next_sentence_labels, - masked_lm_positions, - masked_lm_ids, - masked_lm_weights): - """Get pre-training loss""" - prediction_scores, seq_relationship_score = \ - self.bert(input_ids, input_mask, token_type_id, masked_lm_positions) - total_loss = self.loss(prediction_scores, seq_relationship_score, - masked_lm_ids, masked_lm_weights, next_sentence_labels) - return self.cast(total_loss, mstype.float32) - - -class BertTrainOneStepCell(nn.TrainOneStepCell): - """ - Encapsulation class of bert network training. - - Append an optimizer to the training network after that the construct - function can be called to create the backward graph. - - Args: - network (Cell): The training network. Note that loss function should have been added. - optimizer (Optimizer): Optimizer for updating the weights. - sens (Number): The adjust parameter. Default: 1.0. - enable_clip_grad (boolean): If True, clip gradients in BertTrainOneStepCell. Default: True. - """ - - def __init__(self, network, optimizer, sens=1.0, enable_clip_grad=True): - super(BertTrainOneStepCell, self).__init__(network, optimizer, sens) - self.cast = P.Cast() - self.hyper_map = C.HyperMap() - self.enable_clip_grad = enable_clip_grad - - def construct(self, - input_ids, - input_mask, - token_type_id, - next_sentence_labels, - masked_lm_positions, - masked_lm_ids, - masked_lm_weights): - """Defines the computation performed.""" - weights = self.weights - - loss = self.network(input_ids, - input_mask, - token_type_id, - next_sentence_labels, - masked_lm_positions, - masked_lm_ids, - masked_lm_weights) - grads = self.grad(self.network, weights)(input_ids, - input_mask, - token_type_id, - next_sentence_labels, - masked_lm_positions, - masked_lm_ids, - masked_lm_weights, - self.cast(F.tuple_to_array((self.sens,)), - mstype.float32)) - if self.enable_clip_grad: - grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) - grads = self.grad_reducer(grads) - self.optimizer(grads) - return loss - - -grad_scale = C.MultitypeFuncGraph("grad_scale") -reciprocal = P.Reciprocal() - - -@grad_scale.register("Tensor", "Tensor") -def tensor_grad_scale(scale, grad): - return grad * reciprocal(scale) - - -_grad_overflow = C.MultitypeFuncGraph("_grad_overflow") -grad_overflow = P.FloatStatus() - - -@_grad_overflow.register("Tensor") -def _tensor_grad_overflow(grad): - return grad_overflow(grad) - - -class BertTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell): - """ - Encapsulation class of bert network training. - - Append an optimizer to the training network after that the construct - function can be called to create the backward graph. - - Args: - network (Cell): The training network. Note that loss function should have been added. - optimizer (Optimizer): Optimizer for updating the weights. - scale_update_cell (Cell): Cell to do the loss scale. Default: None. - """ - - def __init__(self, network, optimizer, scale_update_cell=None): - super(BertTrainOneStepWithLossScaleCell, self).__init__(network, optimizer, scale_update_cell) - self.cast = P.Cast() - self.degree = 1 - if self.reducer_flag: - self.degree = get_group_size() - self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) - - self.loss_scale = None - self.loss_scaling_manager = scale_update_cell - if scale_update_cell: - self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32)) - self.load = P.Load() - - def construct(self, - input_ids, - input_mask, - token_type_id, - next_sentence_labels, - masked_lm_positions, - masked_lm_ids, - masked_lm_weights, - sens=None): - """Defines the computation performed.""" - weights = self.weights - loss = self.network(input_ids, - input_mask, - token_type_id, - next_sentence_labels, - masked_lm_positions, - masked_lm_ids, - masked_lm_weights) - if sens is None: - scaling_sens = self.loss_scale - else: - scaling_sens = sens - status, scaling_sens = self.start_overflow_check(loss, scaling_sens) - grads = self.grad(self.network, weights)(input_ids, - input_mask, - token_type_id, - next_sentence_labels, - masked_lm_positions, - masked_lm_ids, - masked_lm_weights, - self.cast(scaling_sens, - mstype.float32)) - # apply grad reducer on grads - grads = self.grad_reducer(grads) - degree_sens = self.cast(scaling_sens * self.degree, mstype.float32) - grads = self.hyper_map(F.partial(grad_scale, degree_sens), grads) - grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) - - cond = self.get_overflow_status(status, grads) - overflow = cond - if sens is None: - overflow = self.loss_scaling_manager(self.loss_scale, cond) - if not overflow: - self.optimizer(grads) - return loss, cond, scaling_sens.value() - - -class BertTrainOneStepWithLossScaleCellForAdam(nn.TrainOneStepWithLossScaleCell): - """ - Encapsulation class of bert network training. - - Append an optimizer to the training network after that the construct - function can be called to create the backward graph. - Different from BertTrainOneStepWithLossScaleCell, the optimizer takes the overflow - condition as input. - - Args: - network (Cell): The training network. Note that loss function should have been added. - optimizer (Optimizer): Optimizer for updating the weights. - scale_update_cell (Cell): Cell to do the loss scale. Default: None. - """ - - def __init__(self, network, optimizer, scale_update_cell=None): - super(BertTrainOneStepWithLossScaleCellForAdam, self).__init__(network, optimizer, scale_update_cell) - self.cast = P.Cast() - self.degree = 1 - if self.reducer_flag: - self.degree = get_group_size() - self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) - self.loss_scale = None - self.loss_scaling_manager = scale_update_cell - if scale_update_cell: - self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32)) - - def construct(self, - input_ids, - input_mask, - token_type_id, - next_sentence_labels, - masked_lm_positions, - masked_lm_ids, - masked_lm_weights, - sens=None): - """Defines the computation performed.""" - weights = self.weights - loss = self.network(input_ids, - input_mask, - token_type_id, - next_sentence_labels, - masked_lm_positions, - masked_lm_ids, - masked_lm_weights) - if sens is None: - scaling_sens = self.loss_scale - else: - scaling_sens = sens - - status, scaling_sens = self.start_overflow_check(loss, scaling_sens) - grads = self.grad(self.network, weights)(input_ids, - input_mask, - token_type_id, - next_sentence_labels, - masked_lm_positions, - masked_lm_ids, - masked_lm_weights, - self.cast(scaling_sens, - mstype.float32)) - # apply grad reducer on grads - grads = self.grad_reducer(grads) - grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads) - grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) - cond = self.get_overflow_status(status, grads) - overflow = cond - if self.loss_scaling_manager is not None: - overflow = self.loss_scaling_manager(scaling_sens, cond) - self.optimizer(grads, overflow) - return (loss, cond, scaling_sens.value()) - - -cast = P.Cast() -add_grads = C.MultitypeFuncGraph("add_grads") - - -@add_grads.register("Tensor", "Tensor") -def _add_grads(accu_grad, grad): - return accu_grad + cast(grad, mstype.float32) - - -update_accu_grads = C.MultitypeFuncGraph("update_accu_grads") - - -@update_accu_grads.register("Tensor", "Tensor") -def _update_accu_grads(accu_grad, grad): - F.assign(accu_grad, cast(grad, mstype.float32)) - return True - - -accumulate_accu_grads = C.MultitypeFuncGraph("accumulate_accu_grads") - - -@accumulate_accu_grads.register("Tensor", "Tensor") -def _accumulate_accu_grads(accu_grad, grad): - F.assign_add(accu_grad, cast(grad, mstype.float32)) - return True - - -zeroslike = P.ZerosLike() -reset_accu_grads = C.MultitypeFuncGraph("reset_accu_grads") - - -@reset_accu_grads.register("Tensor") -def _reset_accu_grads(accu_grad): - F.assign(accu_grad, zeroslike(accu_grad)) - return True - - -class BertTrainAccumulationAllReducePostWithLossScaleCell(nn.Cell): - """ - Encapsulation class of bert network training. - - Append an optimizer to the training network after that the construct - function can be called to create the backward graph. - - To mimic higher batch size, gradients are accumulated N times before weight update. - - For distribution mode, allreduce will only be implemented in the weight updated step, - i.e. the sub-step after gradients accumulated N times. - - Args: - network (Cell): The training network. Note that loss function should have been added. - optimizer (Optimizer): Optimizer for updating the weights. - scale_update_cell (Cell): Cell to do the loss scale. Default: None. - accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size = - batch_size * accumulation_steps. Default: 1. - """ - - def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False): - super(BertTrainAccumulationAllReducePostWithLossScaleCell, self).__init__(auto_prefix=False) - self.network = network - self.network.set_grad() - self.weights = optimizer.parameters - self.optimizer = optimizer - self.accumulation_steps = accumulation_steps - self.enable_global_norm = enable_global_norm - self.one = Tensor(np.array([1]).astype(np.int32)) - self.zero = Tensor(np.array([0]).astype(np.int32)) - self.local_step = Parameter(initializer(0, [1], mstype.int32)) - self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros') - self.accu_overflow = Parameter(initializer(0, [1], mstype.int32)) - self.accu_loss = Parameter(initializer(0, [1], mstype.float32)) - - self.grad = C.GradOperation(get_by_list=True, sens_param=True) - self.reducer_flag = False - self.parallel_mode = context.get_auto_parallel_context("parallel_mode") - if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: - self.reducer_flag = True - self.grad_reducer = F.identity - self.degree = 1 - if self.reducer_flag: - self.degree = get_group_size() - self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) - self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) - self.overflow_reducer = F.identity - if self.is_distributed: - self.overflow_reducer = P.AllReduce() - self.cast = P.Cast() - self.alloc_status = P.NPUAllocFloatStatus() - self.get_status = P.NPUGetFloatStatus() - self.clear_status = P.NPUClearFloatStatus() - self.reduce_sum = P.ReduceSum(keep_dims=False) - self.base = Tensor(1, mstype.float32) - self.less_equal = P.LessEqual() - self.logical_or = P.LogicalOr() - self.not_equal = P.NotEqual() - self.select = P.Select() - self.reshape = P.Reshape() - self.hyper_map = C.HyperMap() - self.loss_scale = None - self.loss_scaling_manager = scale_update_cell - if scale_update_cell: - self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32)) - - def construct(self, - input_ids, - input_mask, - token_type_id, - next_sentence_labels, - masked_lm_positions, - masked_lm_ids, - masked_lm_weights, - sens=None): - """Defines the computation performed.""" - weights = self.weights - loss = self.network(input_ids, - input_mask, - token_type_id, - next_sentence_labels, - masked_lm_positions, - masked_lm_ids, - masked_lm_weights) - if sens is None: - scaling_sens = self.loss_scale - else: - scaling_sens = sens - # alloc status and clear should be right before gradoperation - init = self.alloc_status() - init = F.depend(init, loss) - clear_status = self.clear_status(init) - scaling_sens = F.depend(scaling_sens, clear_status) - # update accumulation parameters - is_accu_step = self.not_equal(self.local_step, self.accumulation_steps) - self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one) - self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss) - mean_loss = self.accu_loss / self.local_step - is_accu_step = self.not_equal(self.local_step, self.accumulation_steps) - - grads = self.grad(self.network, weights)(input_ids, - input_mask, - token_type_id, - next_sentence_labels, - masked_lm_positions, - masked_lm_ids, - masked_lm_weights, - self.cast(scaling_sens, - mstype.float32)) - - accu_succ = self.hyper_map(accumulate_accu_grads, self.accu_grads, grads) - mean_loss = F.depend(mean_loss, accu_succ) - - init = F.depend(init, mean_loss) - get_status = self.get_status(init) - init = F.depend(init, get_status) - flag_sum = self.reduce_sum(init, (0,)) - overflow = self.less_equal(self.base, flag_sum) - overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow) - accu_overflow = self.select(overflow, self.one, self.zero) - self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero) - - if not is_accu_step: - # apply grad reducer on grads - grads = self.grad_reducer(self.accu_grads) - scaling = scaling_sens * self.degree * self.accumulation_steps - grads = self.hyper_map(F.partial(grad_scale, scaling), grads) - if self.enable_global_norm: - grads = C.clip_by_global_norm(grads, 1.0, None) - else: - grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) - accu_overflow = F.depend(accu_overflow, grads) - accu_overflow = self.overflow_reducer(accu_overflow) - overflow = self.less_equal(self.base, accu_overflow) - accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads) - overflow = F.depend(overflow, accu_succ) - overflow = self.reshape(overflow, (())) - if sens is None: - overflow = self.loss_scaling_manager(self.loss_scale, overflow) - if not overflow: - self.optimizer(grads) - - return (mean_loss, overflow, scaling_sens.value()) - - -class BertTrainAccumulationAllReduceEachWithLossScaleCell(nn.Cell): - """ - Encapsulation class of bert network training. - - Append an optimizer to the training network after that the construct - function can be called to create the backward graph. - - To mimic higher batch size, gradients are accumulated N times before weight update. - - For distribution mode, allreduce will be implemented after each sub-step and the trailing time - will be overided by backend optimization pass. - - Args: - network (Cell): The training network. Note that loss function should have been added. - optimizer (Optimizer): Optimizer for updating the weights. - scale_update_cell (Cell): Cell to do the loss scale. Default: None. - accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size = - batch_size * accumulation_steps. Default: 1. - """ - - def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False): - super(BertTrainAccumulationAllReduceEachWithLossScaleCell, self).__init__(auto_prefix=False) - self.network = network - self.network.set_grad() - self.weights = optimizer.parameters - self.optimizer = optimizer - self.accumulation_steps = accumulation_steps - self.enable_global_norm = enable_global_norm - self.one = Tensor(np.array([1]).astype(np.int32)) - self.zero = Tensor(np.array([0]).astype(np.int32)) - self.local_step = Parameter(initializer(0, [1], mstype.int32)) - self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros') - self.accu_overflow = Parameter(initializer(0, [1], mstype.int32)) - self.accu_loss = Parameter(initializer(0, [1], mstype.float32)) - - self.grad = C.GradOperation(get_by_list=True, sens_param=True) - self.reducer_flag = False - self.parallel_mode = context.get_auto_parallel_context("parallel_mode") - if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: - self.reducer_flag = True - self.grad_reducer = F.identity - self.degree = 1 - if self.reducer_flag: - self.degree = get_group_size() - self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) - self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) - self.overflow_reducer = F.identity - if self.is_distributed: - self.overflow_reducer = P.AllReduce() - self.cast = P.Cast() - self.alloc_status = P.NPUAllocFloatStatus() - self.get_status = P.NPUGetFloatStatus() - self.clear_before_grad = P.NPUClearFloatStatus() - self.reduce_sum = P.ReduceSum(keep_dims=False) - self.base = Tensor(1, mstype.float32) - self.less_equal = P.LessEqual() - self.logical_or = P.LogicalOr() - self.not_equal = P.NotEqual() - self.select = P.Select() - self.reshape = P.Reshape() - self.hyper_map = C.HyperMap() - self.loss_scale = None - self.loss_scaling_manager = scale_update_cell - if scale_update_cell: - self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32)) - - @C.add_flags(has_effect=True) - def construct(self, - input_ids, - input_mask, - token_type_id, - next_sentence_labels, - masked_lm_positions, - masked_lm_ids, - masked_lm_weights, - sens=None): - """Defines the computation performed.""" - weights = self.weights - loss = self.network(input_ids, - input_mask, - token_type_id, - next_sentence_labels, - masked_lm_positions, - masked_lm_ids, - masked_lm_weights) - if sens is None: - scaling_sens = self.loss_scale - else: - scaling_sens = sens - - # update accumulation parameters - is_accu_step = self.not_equal(self.local_step, self.accumulation_steps) - self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one) - self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss) - mean_loss = self.accu_loss / self.local_step - is_accu_step = self.not_equal(self.local_step, self.accumulation_steps) - - # alloc status and clear should be right before gradoperation - init = self.alloc_status() - self.clear_before_grad(init) - grads = self.grad(self.network, weights)(input_ids, - input_mask, - token_type_id, - next_sentence_labels, - masked_lm_positions, - masked_lm_ids, - masked_lm_weights, - self.cast(scaling_sens, - mstype.float32)) - - accu_grads = self.hyper_map(add_grads, self.accu_grads, grads) - scaling = scaling_sens * self.degree * self.accumulation_steps - grads = self.hyper_map(F.partial(grad_scale, scaling), accu_grads) - grads = self.grad_reducer(grads) - - self.get_status(init) - flag_sum = self.reduce_sum(init, (0,)) - flag_reduce = self.overflow_reducer(flag_sum) - overflow = self.less_equal(self.base, flag_reduce) - overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow) - accu_overflow = self.select(overflow, self.one, self.zero) - self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero) - overflow = self.reshape(overflow, (())) - - if is_accu_step: - succ = False - accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, accu_grads) - succ = F.depend(succ, accu_succ) - else: - if sens is None: - overflow = self.loss_scaling_manager(self.loss_scale, overflow) - if overflow: - succ = False - else: - if self.enable_global_norm: - grads = C.clip_by_global_norm(grads, 1.0, None) - else: - grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) - - succ = self.optimizer(grads) - - accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads) - succ = F.depend(succ, accu_succ) - - ret = (mean_loss, overflow, scaling_sens.value()) - return F.depend(ret, succ) - - -class BertNetworkMatchBucket(nn.Cell): - ''' - Bert execute according to different sentence lengths. - ''' - - def __init__(self, network, seq_length, bucket_list=None): - super(BertNetworkMatchBucket, self).__init__() - self.network = network - if not bucket_list or not isinstance(bucket_list, list): - bucket_list = [seq_length] - self.bucket_list = [bucket for bucket in bucket_list if bucket <= seq_length] - - if network.reducer_flag: - reuse_attr = 'reuse_communication_node' - if not network.grad_reducer.split_fusion: - hccl_op = network.grad_reducer.allreduce - network.grad_reducer.allreduce = hccl_op.add_prim_attr(reuse_attr, getattr(hccl_op, 'fusion')) - else: - new_op_list = [] - for hccl_op in network.grad_reducer.op_list: - new_op = hccl_op.add_prim_attr(reuse_attr, getattr(hccl_op, 'fusion')) - new_op_list.append(new_op) - network.grad_reducer.op_list = new_op_list - - def construct(self, - input_ids, - input_mask, - token_type_id, - next_sentence_labels, - masked_lm_positions, - masked_lm_ids, - masked_lm_weights, - sentence_flag): - """Switch network according to sentence length.""" - for bucket in self.bucket_list: - if sentence_flag == bucket: - input_ids = input_ids[:, :bucket] - input_mask = input_mask[:, :bucket] - token_type_id = token_type_id[:, :bucket] - loss = self.network(input_ids, - input_mask, - token_type_id, - next_sentence_labels, - masked_lm_positions, - masked_lm_ids, - masked_lm_weights) - return loss - - loss = self.network(input_ids, - input_mask, - token_type_id, - next_sentence_labels, - masked_lm_positions, - masked_lm_ids, - masked_lm_weights) - return loss - - -class BertPretrainEval(nn.Cell): - ''' - Evaluate MaskedLM prediction scores - ''' - - def __init__(self, config, network=None): - super(BertPretrainEval, self).__init__(auto_prefix=False) - if network is None: - self.network = BertPreTraining(config, False, False) - else: - self.network = network - self.argmax = P.Argmax(axis=-1, output_type=mstype.int32) - self.equal = P.Equal() - self.sum = P.ReduceSum() - self.reshape = P.Reshape() - self.shape = P.Shape() - self.cast = P.Cast() - self.allreduce = P.AllReduce() - self.reduce_flag = False - parallel_mode = context.get_auto_parallel_context("parallel_mode") - if parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: - self.reduce_flag = True - - def construct(self, - input_ids, - input_mask, - token_type_id, - next_sentence_labels, - masked_lm_positions, - masked_lm_ids, - masked_lm_weights): - """Calculate prediction scores""" - bs, _ = self.shape(input_ids) - mlm, _ = self.network(input_ids, input_mask, token_type_id, masked_lm_positions) - index = self.argmax(mlm) - index = self.reshape(index, (bs, -1)) - eval_acc = self.equal(index, masked_lm_ids) - eval_acc = self.cast(eval_acc, mstype.float32) - real_acc = eval_acc * masked_lm_weights - acc = self.sum(real_acc) - total = self.sum(masked_lm_weights) - - if self.reduce_flag: - acc = self.allreduce(acc) - total = self.allreduce(total) - - return acc, total +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Bert for pretraining.""" +import numpy as np +import mindspore.ops as ops +import mindspore.nn as nn +from mindspore import context +from mindspore.common import dtype as mstype +from mindspore.common.initializer import initializer, TruncatedNormal +from mindspore.common.parameter import Parameter +from mindspore.common.tensor import Tensor +from mindspore.communication.management import get_group_size +from mindspore.context import ParallelMode +from mindspore.nn.wrap.grad_reducer import DistributedGradReducer +from mindspore.ops import composite as C +from mindspore.ops import functional as F +from mindspore.ops import operations as P +from src.bert_model import BertModel + +GRADIENT_CLIP_TYPE = 1 +GRADIENT_CLIP_VALUE = 1.0 + +clip_grad = C.MultitypeFuncGraph("clip_grad") + + +@clip_grad.register("Number", "Number", "Tensor") +def _clip_grad(clip_type, clip_value, grad): + """ + Clip gradients. + + Inputs: + clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'. + clip_value (float): Specifies how much to clip. + grad (tuple[Tensor]): Gradients. + + Outputs: + tuple[Tensor], clipped gradients. + """ + if clip_type not in (0, 1): + return grad + dt = F.dtype(grad) + if clip_type == 0: + new_grad = ops.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt), + F.cast(F.tuple_to_array((clip_value,)), dt)) + else: + new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt)) + return new_grad + + +class GetMaskedLMOutput(nn.Cell): + """ + Get masked lm output. + + Args: + config (BertConfig): The config of BertModel. + + Returns: + Tensor, masked lm output. + """ + + def __init__(self, config): + super(GetMaskedLMOutput, self).__init__() + self.width = config.hidden_size + self.reshape = P.Reshape() + self.gather = P.Gather() + + weight_init = TruncatedNormal(config.initializer_range) + self.dense = nn.Dense(self.width, + config.hidden_size, + weight_init=weight_init, + activation=config.hidden_act).to_float(config.compute_type) + self.layernorm = nn.LayerNorm((config.hidden_size,)).to_float(config.compute_type) + self.output_bias = Parameter( + initializer( + 'zero', + config.vocab_size)) + self.matmul = P.MatMul(transpose_b=True) + self.log_softmax = nn.LogSoftmax(axis=-1) + self.shape_flat_offsets = (-1, 1) + self.last_idx = (-1,) + self.shape_flat_sequence_tensor = (-1, self.width) + self.cast = P.Cast() + self.compute_type = config.compute_type + self.dtype = config.dtype + + def construct(self, + input_tensor, + output_weights, + positions): + """Get output log_probs""" + input_shape = P.Shape()(input_tensor) + rng = F.tuple_to_array(F.make_range(input_shape[0])) + flat_offsets = self.reshape(rng * input_shape[1], self.shape_flat_offsets) + flat_position = self.reshape(positions + flat_offsets, self.last_idx) + flat_sequence_tensor = self.reshape(input_tensor, self.shape_flat_sequence_tensor) + input_tensor = self.gather(flat_sequence_tensor, flat_position, 0) + input_tensor = self.cast(input_tensor, self.compute_type) + output_weights = self.cast(output_weights, self.compute_type) + input_tensor = self.dense(input_tensor) + input_tensor = self.layernorm(input_tensor) + logits = self.matmul(input_tensor, output_weights) + logits = self.cast(logits, self.dtype) + logits = logits + self.output_bias + log_probs = self.log_softmax(logits) + return log_probs + + +class GetNextSentenceOutput(nn.Cell): + """ + Get next sentence output. + + Args: + config (BertConfig): The config of Bert. + + Returns: + Tensor, next sentence output. + """ + + def __init__(self, config): + super(GetNextSentenceOutput, self).__init__() + self.log_softmax = P.LogSoftmax() + weight_init = TruncatedNormal(config.initializer_range) + self.dense = nn.Dense(config.hidden_size, 2, + weight_init=weight_init, has_bias=True).to_float(config.compute_type) + self.dtype = config.dtype + self.cast = P.Cast() + + def construct(self, input_tensor): + logits = self.dense(input_tensor) + logits = self.cast(logits, self.dtype) + log_prob = self.log_softmax(logits) + return log_prob + + +class BertPreTraining(nn.Cell): + """ + Bert pretraining network. + + Args: + config (BertConfig): The config of BertModel. + is_training (bool): Specifies whether to use the training mode. + use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. + + Returns: + Tensor, prediction_scores, seq_relationship_score. + """ + + def __init__(self, config, is_training, use_one_hot_embeddings): + super(BertPreTraining, self).__init__() + self.bert = BertModel(config, is_training, use_one_hot_embeddings) + self.cls1 = GetMaskedLMOutput(config) + self.cls2 = GetNextSentenceOutput(config) + + def construct(self, input_ids, input_mask, token_type_id, + masked_lm_positions): + sequence_output, pooled_output, embedding_table = \ + self.bert(input_ids, token_type_id, input_mask) + prediction_scores = self.cls1(sequence_output, + embedding_table, + masked_lm_positions) + seq_relationship_score = self.cls2(pooled_output) + return prediction_scores, seq_relationship_score + + +class BertPretrainingLoss(nn.Cell): + """ + Provide bert pre-training loss. + + Args: + config (BertConfig): The config of BertModel. + + Returns: + Tensor, total loss. + """ + + def __init__(self, config): + super(BertPretrainingLoss, self).__init__() + self.vocab_size = config.vocab_size + self.onehot = P.OneHot() + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.0, mstype.float32) + self.reduce_sum = P.ReduceSum() + self.reduce_mean = P.ReduceMean() + self.reshape = P.Reshape() + self.last_idx = (-1,) + self.neg = P.Neg() + self.cast = P.Cast() + + def construct(self, prediction_scores, seq_relationship_score, masked_lm_ids, + masked_lm_weights, next_sentence_labels): + """Defines the computation performed.""" + label_ids = self.reshape(masked_lm_ids, self.last_idx) + label_weights = self.cast(self.reshape(masked_lm_weights, self.last_idx), mstype.float32) + one_hot_labels = self.onehot(label_ids, self.vocab_size, self.on_value, self.off_value) + + per_example_loss = self.neg(self.reduce_sum(prediction_scores * one_hot_labels, self.last_idx)) + numerator = self.reduce_sum(label_weights * per_example_loss, ()) + denominator = self.reduce_sum(label_weights, ()) + self.cast(F.tuple_to_array((1e-5,)), mstype.float32) + masked_lm_loss = numerator / denominator + + # next_sentence_loss + labels = self.reshape(next_sentence_labels, self.last_idx) + one_hot_labels = self.onehot(labels, 2, self.on_value, self.off_value) + per_example_loss = self.neg(self.reduce_sum( + one_hot_labels * seq_relationship_score, self.last_idx)) + next_sentence_loss = self.reduce_mean(per_example_loss, self.last_idx) + + # total_loss + total_loss = masked_lm_loss + next_sentence_loss + + return total_loss + + +class BertNetworkWithLoss(nn.Cell): + """ + Provide bert pre-training loss through network. + + Args: + config (BertConfig): The config of BertModel. + is_training (bool): Specifies whether to use the training mode. + use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. Default: False. + + Returns: + Tensor, the loss of the network. + """ + + def __init__(self, config, is_training, use_one_hot_embeddings=False): + super(BertNetworkWithLoss, self).__init__() + self.bert = BertPreTraining(config, is_training, use_one_hot_embeddings) + self.loss = BertPretrainingLoss(config) + self.cast = P.Cast() + + def construct(self, + input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights): + """Get pre-training loss""" + prediction_scores, seq_relationship_score = \ + self.bert(input_ids, input_mask, token_type_id, masked_lm_positions) + total_loss = self.loss(prediction_scores, seq_relationship_score, + masked_lm_ids, masked_lm_weights, next_sentence_labels) + return self.cast(total_loss, mstype.float32) + + +class BertTrainOneStepCell(nn.TrainOneStepCell): + """ + Encapsulation class of bert network training. + + Append an optimizer to the training network after that the construct + function can be called to create the backward graph. + + Args: + network (Cell): The training network. Note that loss function should have been added. + optimizer (Optimizer): Optimizer for updating the weights. + sens (Number): The adjust parameter. Default: 1.0. + enable_clip_grad (boolean): If True, clip gradients in BertTrainOneStepCell. Default: True. + """ + + def __init__(self, network, optimizer, sens=1.0, enable_clip_grad=True): + super(BertTrainOneStepCell, self).__init__(network, optimizer, sens) + self.cast = P.Cast() + self.hyper_map = C.HyperMap() + self.enable_clip_grad = enable_clip_grad + + def construct(self, + input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights): + """Defines the computation performed.""" + weights = self.weights + + loss = self.network(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights) + grads = self.grad(self.network, weights)(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights, + self.cast(F.tuple_to_array((self.sens,)), + mstype.float32)) + if self.enable_clip_grad: + grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) + grads = self.grad_reducer(grads) + self.optimizer(grads) + return loss + + +grad_scale = C.MultitypeFuncGraph("grad_scale") +reciprocal = P.Reciprocal() + + +@grad_scale.register("Tensor", "Tensor") +def tensor_grad_scale(scale, grad): + return grad * reciprocal(scale) + + +_grad_overflow = C.MultitypeFuncGraph("_grad_overflow") +grad_overflow = P.FloatStatus() + + +@_grad_overflow.register("Tensor") +def _tensor_grad_overflow(grad): + return grad_overflow(grad) + + +class BertTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell): + """ + Encapsulation class of bert network training. + + Append an optimizer to the training network after that the construct + function can be called to create the backward graph. + + Args: + network (Cell): The training network. Note that loss function should have been added. + optimizer (Optimizer): Optimizer for updating the weights. + scale_update_cell (Cell): Cell to do the loss scale. Default: None. + """ + + def __init__(self, network, optimizer, scale_update_cell=None): + super(BertTrainOneStepWithLossScaleCell, self).__init__(network, optimizer, scale_update_cell) + self.cast = P.Cast() + self.degree = 1 + if self.reducer_flag: + self.degree = get_group_size() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) + + self.loss_scale = None + self.loss_scaling_manager = scale_update_cell + if scale_update_cell: + self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32)) + self.load = P.Load() + + def construct(self, + input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights, + sens=None): + """Defines the computation performed.""" + weights = self.weights + loss = self.network(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights) + if sens is None: + scaling_sens = self.loss_scale + else: + scaling_sens = sens + status, scaling_sens = self.start_overflow_check(loss, scaling_sens) + grads = self.grad(self.network, weights)(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights, + self.cast(scaling_sens, + mstype.float32)) + # apply grad reducer on grads + grads = self.grad_reducer(grads) + degree_sens = self.cast(scaling_sens * self.degree, mstype.float32) + grads = self.hyper_map(F.partial(grad_scale, degree_sens), grads) + grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) + + cond = self.get_overflow_status(status, grads) + overflow = cond + if sens is None: + overflow = self.loss_scaling_manager(self.loss_scale, cond) + if not overflow: + self.optimizer(grads) + return loss, cond, scaling_sens.value() + + +class BertTrainOneStepWithLossScaleCellForAdam(nn.TrainOneStepWithLossScaleCell): + """ + Encapsulation class of bert network training. + + Append an optimizer to the training network after that the construct + function can be called to create the backward graph. + Different from BertTrainOneStepWithLossScaleCell, the optimizer takes the overflow + condition as input. + + Args: + network (Cell): The training network. Note that loss function should have been added. + optimizer (Optimizer): Optimizer for updating the weights. + scale_update_cell (Cell): Cell to do the loss scale. Default: None. + """ + + def __init__(self, network, optimizer, scale_update_cell=None): + super(BertTrainOneStepWithLossScaleCellForAdam, self).__init__(network, optimizer, scale_update_cell) + self.cast = P.Cast() + self.degree = 1 + if self.reducer_flag: + self.degree = get_group_size() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) + self.loss_scale = None + self.loss_scaling_manager = scale_update_cell + if scale_update_cell: + self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32)) + + def construct(self, + input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights, + sens=None): + """Defines the computation performed.""" + weights = self.weights + loss = self.network(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights) + if sens is None: + scaling_sens = self.loss_scale + else: + scaling_sens = sens + + status, scaling_sens = self.start_overflow_check(loss, scaling_sens) + grads = self.grad(self.network, weights)(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights, + self.cast(scaling_sens, + mstype.float32)) + # apply grad reducer on grads + grads = self.grad_reducer(grads) + grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads) + grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) + cond = self.get_overflow_status(status, grads) + overflow = cond + if self.loss_scaling_manager is not None: + overflow = self.loss_scaling_manager(scaling_sens, cond) + self.optimizer(grads, overflow) + return (loss, cond, scaling_sens.value()) + + +cast = P.Cast() +add_grads = C.MultitypeFuncGraph("add_grads") + + +@add_grads.register("Tensor", "Tensor") +def _add_grads(accu_grad, grad): + return accu_grad + cast(grad, mstype.float32) + + +update_accu_grads = C.MultitypeFuncGraph("update_accu_grads") + + +@update_accu_grads.register("Tensor", "Tensor") +def _update_accu_grads(accu_grad, grad): + F.assign(accu_grad, cast(grad, mstype.float32)) + return True + + +accumulate_accu_grads = C.MultitypeFuncGraph("accumulate_accu_grads") + + +@accumulate_accu_grads.register("Tensor", "Tensor") +def _accumulate_accu_grads(accu_grad, grad): + F.assign_add(accu_grad, cast(grad, mstype.float32)) + return True + + +zeroslike = P.ZerosLike() +reset_accu_grads = C.MultitypeFuncGraph("reset_accu_grads") + + +@reset_accu_grads.register("Tensor") +def _reset_accu_grads(accu_grad): + F.assign(accu_grad, zeroslike(accu_grad)) + return True + + +class BertTrainAccumulationAllReducePostWithLossScaleCell(nn.Cell): + """ + Encapsulation class of bert network training. + + Append an optimizer to the training network after that the construct + function can be called to create the backward graph. + + To mimic higher batch size, gradients are accumulated N times before weight update. + + For distribution mode, allreduce will only be implemented in the weight updated step, + i.e. the sub-step after gradients accumulated N times. + + Args: + network (Cell): The training network. Note that loss function should have been added. + optimizer (Optimizer): Optimizer for updating the weights. + scale_update_cell (Cell): Cell to do the loss scale. Default: None. + accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size = + batch_size * accumulation_steps. Default: 1. + """ + + def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False): + super(BertTrainAccumulationAllReducePostWithLossScaleCell, self).__init__(auto_prefix=False) + self.network = network + self.network.set_grad() + self.weights = optimizer.parameters + self.optimizer = optimizer + self.accumulation_steps = accumulation_steps + self.enable_global_norm = enable_global_norm + self.one = Tensor(np.array([1]).astype(np.int32)) + self.zero = Tensor(np.array([0]).astype(np.int32)) + self.local_step = Parameter(initializer(0, [1], mstype.int32)) + self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros') + self.accu_overflow = Parameter(initializer(0, [1], mstype.int32)) + self.accu_loss = Parameter(initializer(0, [1], mstype.float32)) + + self.grad = C.GradOperation(get_by_list=True, sens_param=True) + self.reducer_flag = False + self.parallel_mode = context.get_auto_parallel_context("parallel_mode") + if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + self.reducer_flag = True + self.grad_reducer = F.identity + self.degree = 1 + if self.reducer_flag: + self.degree = get_group_size() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) + self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) + self.overflow_reducer = F.identity + if self.is_distributed: + self.overflow_reducer = P.AllReduce() + self.cast = P.Cast() + self.alloc_status = P.NPUAllocFloatStatus() + self.get_status = P.NPUGetFloatStatus() + self.clear_status = P.NPUClearFloatStatus() + self.reduce_sum = P.ReduceSum(keep_dims=False) + self.base = Tensor(1, mstype.float32) + self.less_equal = P.LessEqual() + self.logical_or = P.LogicalOr() + self.not_equal = P.NotEqual() + self.select = P.Select() + self.reshape = P.Reshape() + self.hyper_map = C.HyperMap() + self.loss_scale = None + self.loss_scaling_manager = scale_update_cell + if scale_update_cell: + self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32)) + + def construct(self, + input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights, + sens=None): + """Defines the computation performed.""" + weights = self.weights + loss = self.network(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights) + if sens is None: + scaling_sens = self.loss_scale + else: + scaling_sens = sens + # alloc status and clear should be right before gradoperation + init = self.alloc_status() + init = F.depend(init, loss) + clear_status = self.clear_status(init) + scaling_sens = F.depend(scaling_sens, clear_status) + # update accumulation parameters + is_accu_step = self.not_equal(self.local_step, self.accumulation_steps) + self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one) + self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss) + mean_loss = self.accu_loss / self.local_step + is_accu_step = self.not_equal(self.local_step, self.accumulation_steps) + + grads = self.grad(self.network, weights)(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights, + self.cast(scaling_sens, + mstype.float32)) + + accu_succ = self.hyper_map(accumulate_accu_grads, self.accu_grads, grads) + mean_loss = F.depend(mean_loss, accu_succ) + + init = F.depend(init, mean_loss) + get_status = self.get_status(init) + init = F.depend(init, get_status) + flag_sum = self.reduce_sum(init, (0,)) + overflow = self.less_equal(self.base, flag_sum) + overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow) + accu_overflow = self.select(overflow, self.one, self.zero) + self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero) + + if not is_accu_step: + # apply grad reducer on grads + grads = self.grad_reducer(self.accu_grads) + scaling = scaling_sens * self.degree * self.accumulation_steps + grads = self.hyper_map(F.partial(grad_scale, scaling), grads) + if self.enable_global_norm: + grads = C.clip_by_global_norm(grads, 1.0, None) + else: + grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) + accu_overflow = F.depend(accu_overflow, grads) + accu_overflow = self.overflow_reducer(accu_overflow) + overflow = self.less_equal(self.base, accu_overflow) + accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads) + overflow = F.depend(overflow, accu_succ) + overflow = self.reshape(overflow, (())) + if sens is None: + overflow = self.loss_scaling_manager(self.loss_scale, overflow) + if not overflow: + self.optimizer(grads) + + return (mean_loss, overflow, scaling_sens.value()) + + +class BertTrainAccumulationAllReduceEachWithLossScaleCell(nn.Cell): + """ + Encapsulation class of bert network training. + + Append an optimizer to the training network after that the construct + function can be called to create the backward graph. + + To mimic higher batch size, gradients are accumulated N times before weight update. + + For distribution mode, allreduce will be implemented after each sub-step and the trailing time + will be overided by backend optimization pass. + + Args: + network (Cell): The training network. Note that loss function should have been added. + optimizer (Optimizer): Optimizer for updating the weights. + scale_update_cell (Cell): Cell to do the loss scale. Default: None. + accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size = + batch_size * accumulation_steps. Default: 1. + """ + + def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False): + super(BertTrainAccumulationAllReduceEachWithLossScaleCell, self).__init__(auto_prefix=False) + self.network = network + self.network.set_grad() + self.weights = optimizer.parameters + self.optimizer = optimizer + self.accumulation_steps = accumulation_steps + self.enable_global_norm = enable_global_norm + self.one = Tensor(np.array([1]).astype(np.int32)) + self.zero = Tensor(np.array([0]).astype(np.int32)) + self.local_step = Parameter(initializer(0, [1], mstype.int32)) + self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros') + self.accu_overflow = Parameter(initializer(0, [1], mstype.int32)) + self.accu_loss = Parameter(initializer(0, [1], mstype.float32)) + + self.grad = C.GradOperation(get_by_list=True, sens_param=True) + self.reducer_flag = False + self.parallel_mode = context.get_auto_parallel_context("parallel_mode") + if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + self.reducer_flag = True + self.grad_reducer = F.identity + self.degree = 1 + if self.reducer_flag: + self.degree = get_group_size() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) + self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) + self.overflow_reducer = F.identity + if self.is_distributed: + self.overflow_reducer = P.AllReduce() + self.cast = P.Cast() + self.alloc_status = P.NPUAllocFloatStatus() + self.get_status = P.NPUGetFloatStatus() + self.clear_before_grad = P.NPUClearFloatStatus() + self.reduce_sum = P.ReduceSum(keep_dims=False) + self.base = Tensor(1, mstype.float32) + self.less_equal = P.LessEqual() + self.logical_or = P.LogicalOr() + self.not_equal = P.NotEqual() + self.select = P.Select() + self.reshape = P.Reshape() + self.hyper_map = C.HyperMap() + self.loss_scale = None + self.loss_scaling_manager = scale_update_cell + if scale_update_cell: + self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32)) + + @C.add_flags(has_effect=True) + def construct(self, + input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights, + sens=None): + """Defines the computation performed.""" + weights = self.weights + loss = self.network(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights) + if sens is None: + scaling_sens = self.loss_scale + else: + scaling_sens = sens + + # update accumulation parameters + is_accu_step = self.not_equal(self.local_step, self.accumulation_steps) + self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one) + self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss) + mean_loss = self.accu_loss / self.local_step + is_accu_step = self.not_equal(self.local_step, self.accumulation_steps) + + # alloc status and clear should be right before gradoperation + init = self.alloc_status() + self.clear_before_grad(init) + grads = self.grad(self.network, weights)(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights, + self.cast(scaling_sens, + mstype.float32)) + + accu_grads = self.hyper_map(add_grads, self.accu_grads, grads) + scaling = scaling_sens * self.degree * self.accumulation_steps + grads = self.hyper_map(F.partial(grad_scale, scaling), accu_grads) + grads = self.grad_reducer(grads) + + self.get_status(init) + flag_sum = self.reduce_sum(init, (0,)) + flag_reduce = self.overflow_reducer(flag_sum) + overflow = self.less_equal(self.base, flag_reduce) + overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow) + accu_overflow = self.select(overflow, self.one, self.zero) + self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero) + overflow = self.reshape(overflow, (())) + + if is_accu_step: + succ = False + accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, accu_grads) + succ = F.depend(succ, accu_succ) + else: + if sens is None: + overflow = self.loss_scaling_manager(self.loss_scale, overflow) + if overflow: + succ = False + else: + if self.enable_global_norm: + grads = C.clip_by_global_norm(grads, 1.0, None) + else: + grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) + + succ = self.optimizer(grads) + + accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads) + succ = F.depend(succ, accu_succ) + + ret = (mean_loss, overflow, scaling_sens.value()) + return F.depend(ret, succ) + + +class BertNetworkMatchBucket(nn.Cell): + ''' + Bert execute according to different sentence lengths. + ''' + + def __init__(self, network, seq_length, bucket_list=None): + super(BertNetworkMatchBucket, self).__init__() + self.network = network + if not bucket_list or not isinstance(bucket_list, list): + bucket_list = [seq_length] + self.bucket_list = [bucket for bucket in bucket_list if bucket <= seq_length] + + if network.reducer_flag: + reuse_attr = 'reuse_communication_node' + if not network.grad_reducer.split_fusion: + hccl_op = network.grad_reducer.allreduce + network.grad_reducer.allreduce = hccl_op.add_prim_attr(reuse_attr, getattr(hccl_op, 'fusion')) + else: + new_op_list = [] + for hccl_op in network.grad_reducer.op_list: + new_op = hccl_op.add_prim_attr(reuse_attr, getattr(hccl_op, 'fusion')) + new_op_list.append(new_op) + network.grad_reducer.op_list = new_op_list + + def construct(self, + input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights, + sentence_flag): + """Switch network according to sentence length.""" + for bucket in self.bucket_list: + if sentence_flag == bucket: + input_ids = input_ids[:, :bucket] + input_mask = input_mask[:, :bucket] + token_type_id = token_type_id[:, :bucket] + loss = self.network(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights) + return loss + + loss = self.network(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights) + return loss + + +class BertPretrainEval(nn.Cell): + ''' + Evaluate MaskedLM prediction scores + ''' + + def __init__(self, config, network=None): + super(BertPretrainEval, self).__init__(auto_prefix=False) + if network is None: + self.network = BertPreTraining(config, False, False) + else: + self.network = network + self.argmax = P.Argmax(axis=-1, output_type=mstype.int32) + self.equal = P.Equal() + self.sum = P.ReduceSum() + self.reshape = P.Reshape() + self.shape = P.Shape() + self.cast = P.Cast() + self.allreduce = P.AllReduce() + self.reduce_flag = False + parallel_mode = context.get_auto_parallel_context("parallel_mode") + if parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + self.reduce_flag = True + + def construct(self, + input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights): + """Calculate prediction scores""" + bs, _ = self.shape(input_ids) + mlm, _ = self.network(input_ids, input_mask, token_type_id, masked_lm_positions) + index = self.argmax(mlm) + index = self.reshape(index, (bs, -1)) + eval_acc = self.equal(index, masked_lm_ids) + eval_acc = self.cast(eval_acc, mstype.float32) + real_acc = eval_acc * masked_lm_weights + acc = self.sum(real_acc) + total = self.sum(masked_lm_weights) + + if self.reduce_flag: + acc = self.allreduce(acc) + total = self.allreduce(total) + + return acc, total diff --git a/tests/st/networks/models/bert/bert_performance/src/bert_model.py b/tests/st/networks/models/bert/bert_performance/src/bert_model.py index 972aa717862..ca30e3fe2a1 100644 --- a/tests/st/networks/models/bert/bert_performance/src/bert_model.py +++ b/tests/st/networks/models/bert/bert_performance/src/bert_model.py @@ -1,857 +1,857 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Bert model.""" - -import copy -import math -import numpy as np -import mindspore.common.dtype as mstype -import mindspore.ops as ops -import mindspore.nn as nn -import mindspore.ops.functional as F -from mindspore.common.initializer import TruncatedNormal, initializer -from mindspore.common.parameter import Parameter -from mindspore.common.tensor import Tensor -from mindspore.ops import operations as P - - -class BertConfig: - """ - Configuration for `BertModel`. - - Args: - seq_length (int): Length of input sequence. Default: 128. - vocab_size (int): The shape of each embedding vector. Default: 32000. - hidden_size (int): Size of the bert encoder layers. Default: 768. - num_hidden_layers (int): Number of hidden layers in the BertTransformer encoder - cell. Default: 12. - num_attention_heads (int): Number of attention heads in the BertTransformer - encoder cell. Default: 12. - intermediate_size (int): Size of intermediate layer in the BertTransformer - encoder cell. Default: 3072. - hidden_act (str): Activation function used in the BertTransformer encoder - cell. Default: "gelu". - hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. - attention_probs_dropout_prob (float): The dropout probability for - BertAttention. Default: 0.1. - max_position_embeddings (int): Maximum length of sequences used in this - model. Default: 512. - type_vocab_size (int): Size of token type vocab. Default: 16. - initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. - use_relative_positions (bool): Specifies whether to use relative positions. Default: False. - dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32. - compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32. - """ - - def __init__(self, - seq_length=128, - vocab_size=32000, - hidden_size=768, - num_hidden_layers=12, - num_attention_heads=12, - intermediate_size=3072, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=16, - initializer_range=0.02, - use_relative_positions=False, - dtype=mstype.float32, - compute_type=mstype.float32): - self.seq_length = seq_length - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.hidden_act = hidden_act - self.intermediate_size = intermediate_size - self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob - self.max_position_embeddings = max_position_embeddings - self.type_vocab_size = type_vocab_size - self.initializer_range = initializer_range - self.use_relative_positions = use_relative_positions - self.dtype = dtype - self.compute_type = compute_type - - -class EmbeddingLookup(nn.Cell): - """ - A embeddings lookup table with a fixed dictionary and size. - - Args: - vocab_size (int): Size of the dictionary of embeddings. - embedding_size (int): The size of each embedding vector. - embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of - each embedding vector. - use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. - initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. - """ - - def __init__(self, - vocab_size, - embedding_size, - embedding_shape, - use_one_hot_embeddings=False, - initializer_range=0.02): - super(EmbeddingLookup, self).__init__() - self.vocab_size = vocab_size - self.use_one_hot_embeddings = use_one_hot_embeddings - self.embedding_table = Parameter(initializer - (TruncatedNormal(initializer_range), - [vocab_size, embedding_size])) - self.expand = P.ExpandDims() - self.shape_flat = (-1,) - self.gather = P.Gather() - self.one_hot = P.OneHot() - self.on_value = Tensor(1.0, mstype.float32) - self.off_value = Tensor(0.0, mstype.float32) - self.array_mul = P.MatMul() - self.reshape = P.Reshape() - self.shape = tuple(embedding_shape) - - def construct(self, input_ids): - """Get output and embeddings lookup table""" - extended_ids = self.expand(input_ids, -1) - flat_ids = self.reshape(extended_ids, self.shape_flat) - if self.use_one_hot_embeddings: - one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value) - output_for_reshape = self.array_mul( - one_hot_ids, self.embedding_table) - else: - output_for_reshape = self.gather(self.embedding_table, flat_ids, 0) - output = self.reshape(output_for_reshape, self.shape) - return output, self.embedding_table.value() - - -class EmbeddingPostprocessor(nn.Cell): - """ - Postprocessors apply positional and token type embeddings to word embeddings. - - Args: - embedding_size (int): The size of each embedding vector. - embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of - each embedding vector. - use_token_type (bool): Specifies whether to use token type embeddings. Default: False. - token_type_vocab_size (int): Size of token type vocab. Default: 16. - use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. - initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. - max_position_embeddings (int): Maximum length of sequences used in this - model. Default: 512. - dropout_prob (float): The dropout probability. Default: 0.1. - """ - - def __init__(self, - embedding_size, - embedding_shape, - use_relative_positions=False, - use_token_type=False, - token_type_vocab_size=16, - use_one_hot_embeddings=False, - initializer_range=0.02, - max_position_embeddings=512, - dropout_prob=0.1): - super(EmbeddingPostprocessor, self).__init__() - self.use_token_type = use_token_type - self.token_type_vocab_size = token_type_vocab_size - self.use_one_hot_embeddings = use_one_hot_embeddings - self.max_position_embeddings = max_position_embeddings - self.token_type_embedding = nn.Embedding( - vocab_size=token_type_vocab_size, - embedding_size=embedding_size, - use_one_hot=use_one_hot_embeddings) - self.shape_flat = (-1,) - self.one_hot = P.OneHot() - self.on_value = Tensor(1.0, mstype.float32) - self.off_value = Tensor(0.1, mstype.float32) - self.array_mul = P.MatMul() - self.reshape = P.Reshape() - self.shape = tuple(embedding_shape) - self.dropout = nn.Dropout(p=dropout_prob) - self.gather = P.Gather() - self.use_relative_positions = use_relative_positions - self.slice = P.StridedSlice() - _, seq, _ = self.shape - self.full_position_embedding = nn.Embedding( - vocab_size=max_position_embeddings, - embedding_size=embedding_size, - use_one_hot=False) - self.layernorm = nn.LayerNorm((embedding_size,)) - self.position_ids = Tensor(np.arange(seq).reshape(-1, seq).astype(np.int32)) - self.add = P.Add() - - def construct(self, token_type_ids, word_embeddings): - """Postprocessors apply positional and token type embeddings to word embeddings.""" - output = word_embeddings - if self.use_token_type: - token_type_embeddings = self.token_type_embedding(token_type_ids) - output = self.add(output, token_type_embeddings) - if not self.use_relative_positions: - shape = F.shape(output) - position_ids = self.position_ids[:, :shape[1]] - position_embeddings = self.full_position_embedding(position_ids) - output = self.add(output, position_embeddings) - output = self.layernorm(output) - output = self.dropout(output) - return output - - -class BertOutput(nn.Cell): - """ - Apply a linear computation to hidden status and a residual computation to input. - - Args: - in_channels (int): Input channels. - out_channels (int): Output channels. - initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. - dropout_prob (float): The dropout probability. Default: 0.1. - compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32. - """ - - def __init__(self, - in_channels, - out_channels, - initializer_range=0.02, - dropout_prob=0.1, - compute_type=mstype.float32): - super(BertOutput, self).__init__() - self.dense = nn.Dense(in_channels, out_channels, - weight_init=TruncatedNormal(initializer_range)).to_float(compute_type) - self.dropout = nn.Dropout(p=dropout_prob) - self.dropout_prob = dropout_prob - self.add = P.Add() - self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type) - self.cast = P.Cast() - - def construct(self, hidden_status, input_tensor): - output = self.dense(hidden_status) - output = self.dropout(output) - output = self.add(input_tensor, output) - output = self.layernorm(output) - return output - - -class RelaPosMatrixGenerator(nn.Cell): - """ - Generates matrix of relative positions between inputs. - - Args: - length (int): Length of one dim for the matrix to be generated. - max_relative_position (int): Max value of relative position. - """ - - def __init__(self, max_relative_position): - super(RelaPosMatrixGenerator, self).__init__() - self._max_relative_position = max_relative_position - self._min_relative_position = -max_relative_position - - self.tile = P.Tile() - self.range_mat = P.Reshape() - self.sub = P.Sub() - self.expanddims = P.ExpandDims() - self.cast = P.Cast() - - def construct(self, length): - """Generates matrix of relative positions between inputs.""" - range_vec_row_out = self.cast(F.tuple_to_array(F.make_range(length)), mstype.int32) - range_vec_col_out = self.range_mat(range_vec_row_out, (length, -1)) - tile_row_out = self.tile(range_vec_row_out, (length,)) - tile_col_out = self.tile(range_vec_col_out, (1, length)) - range_mat_out = self.range_mat(tile_row_out, (length, length)) - transpose_out = self.range_mat(tile_col_out, (length, length)) - distance_mat = self.sub(range_mat_out, transpose_out) - - distance_mat_clipped = ops.clip_by_value(distance_mat, - self._min_relative_position, - self._max_relative_position) - - # Shift values to be >=0. Each integer still uniquely identifies a - # relative position difference. - final_mat = distance_mat_clipped + self._max_relative_position - return final_mat - - -class RelaPosEmbeddingsGenerator(nn.Cell): - """ - Generates tensor of size [length, length, depth]. - - Args: - length (int): Length of one dim for the matrix to be generated. - depth (int): Size of each attention head. - max_relative_position (int): Maxmum value of relative position. - initializer_range (float): Initialization value of TruncatedNormal. - use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. - """ - - def __init__(self, - depth, - max_relative_position, - initializer_range, - use_one_hot_embeddings=False): - super(RelaPosEmbeddingsGenerator, self).__init__() - self.depth = depth - self.vocab_size = max_relative_position * 2 + 1 - self.use_one_hot_embeddings = use_one_hot_embeddings - - self.embeddings_table = Parameter( - initializer(TruncatedNormal(initializer_range), - [self.vocab_size, self.depth])) - - self.relative_positions_matrix = RelaPosMatrixGenerator(max_relative_position=max_relative_position) - self.reshape = P.Reshape() - self.one_hot = nn.OneHot(depth=self.vocab_size) - self.shape = P.Shape() - self.gather = P.Gather() # index_select - self.matmul = P.BatchMatMul() - - def construct(self, length): - """Generate embedding for each relative position of dimension depth.""" - relative_positions_matrix_out = self.relative_positions_matrix(length) - - if self.use_one_hot_embeddings: - flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,)) - one_hot_relative_positions_matrix = self.one_hot( - flat_relative_positions_matrix) - embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table) - my_shape = self.shape(relative_positions_matrix_out) + (self.depth,) - embeddings = self.reshape(embeddings, my_shape) - else: - embeddings = self.gather(self.embeddings_table, - relative_positions_matrix_out, 0) - return embeddings - - -class SaturateCast(nn.Cell): - """ - Performs a safe saturating cast. This operation applies proper clamping before casting to prevent - the danger that the value will overflow or underflow. - - Args: - src_type (:class:`mindspore.dtype`): The type of the elements of the input tensor. Default: mstype.float32. - dst_type (:class:`mindspore.dtype`): The type of the elements of the output tensor. Default: mstype.float32. - """ - - def __init__(self, src_type=mstype.float32, dst_type=mstype.float32): - super(SaturateCast, self).__init__() - np_type = mstype.dtype_to_nptype(dst_type) - - self.tensor_min_type = float(np.finfo(np_type).min) - self.tensor_max_type = float(np.finfo(np_type).max) - - self.min_op = P.Minimum() - self.max_op = P.Maximum() - self.cast = P.Cast() - self.dst_type = dst_type - - def construct(self, x): - out = self.max_op(x, self.tensor_min_type) - out = self.min_op(out, self.tensor_max_type) - return self.cast(out, self.dst_type) - - -class BertAttention(nn.Cell): - """ - Apply multi-headed attention from "from_tensor" to "to_tensor". - - Args: - from_tensor_width (int): Size of last dim of from_tensor. - to_tensor_width (int): Size of last dim of to_tensor. - num_attention_heads (int): Number of attention heads. Default: 1. - size_per_head (int): Size of each attention head. Default: 512. - query_act (str): Activation function for the query transform. Default: None. - key_act (str): Activation function for the key transform. Default: None. - value_act (str): Activation function for the value transform. Default: None. - has_attention_mask (bool): Specifies whether to use attention mask. Default: False. - attention_probs_dropout_prob (float): The dropout probability for - BertAttention. Default: 0.0. - use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. - initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. - use_relative_positions (bool): Specifies whether to use relative positions. Default: False. - compute_type (:class:`mindspore.dtype`): Compute type in BertAttention. Default: mstype.float32. - """ - - def __init__(self, - from_tensor_width, - to_tensor_width, - num_attention_heads=1, - size_per_head=512, - query_act=None, - key_act=None, - value_act=None, - has_attention_mask=False, - attention_probs_dropout_prob=0.0, - use_one_hot_embeddings=False, - initializer_range=0.02, - use_relative_positions=False, - compute_type=mstype.float32): - - super(BertAttention, self).__init__() - self.num_attention_heads = num_attention_heads - self.size_per_head = size_per_head - self.has_attention_mask = has_attention_mask - self.use_relative_positions = use_relative_positions - - self.scores_mul = 1.0 / math.sqrt(float(self.size_per_head)) - self.reshape = P.Reshape() - self.shape_from_2d = (-1, from_tensor_width) - self.shape_to_2d = (-1, to_tensor_width) - weight = TruncatedNormal(initializer_range) - units = num_attention_heads * size_per_head - self.query_layer = nn.Dense(from_tensor_width, - units, - activation=query_act, - weight_init=weight).to_float(compute_type) - self.key_layer = nn.Dense(to_tensor_width, - units, - activation=key_act, - weight_init=weight).to_float(compute_type) - self.value_layer = nn.Dense(to_tensor_width, - units, - activation=value_act, - weight_init=weight).to_float(compute_type) - - self.matmul_trans_b = P.BatchMatMul(transpose_b=True) - self.multiply = P.Mul() - self.transpose = P.Transpose() - self.trans_shape = (0, 2, 1, 3) - self.trans_shape_relative = (2, 0, 1, 3) - self.trans_shape_position = (1, 2, 0, 3) - self.multiply_data = -10000.0 - self.matmul = P.BatchMatMul() - - self.softmax = nn.Softmax() - self.dropout = nn.Dropout(p=attention_probs_dropout_prob) - - if self.has_attention_mask: - self.expand_dims = P.ExpandDims() - self.sub = P.Sub() - self.add = P.Add() - self.cast = P.Cast() - self.get_dtype = P.DType() - - self.shape_return = (-1, num_attention_heads * size_per_head) - - self.cast_compute_type = SaturateCast(dst_type=compute_type) - if self.use_relative_positions: - self._generate_relative_positions_embeddings = \ - RelaPosEmbeddingsGenerator(depth=size_per_head, - max_relative_position=16, - initializer_range=initializer_range, - use_one_hot_embeddings=use_one_hot_embeddings) - - def construct(self, from_tensor, to_tensor, attention_mask): - """reshape 2d/3d input tensors to 2d""" - shape_from = F.shape(attention_mask)[2] - from_tensor = F.depend(from_tensor, shape_from) - from_tensor_2d = self.reshape(from_tensor, self.shape_from_2d) - to_tensor_2d = self.reshape(to_tensor, self.shape_to_2d) - query_out = self.query_layer(from_tensor_2d) - key_out = self.key_layer(to_tensor_2d) - value_out = self.value_layer(to_tensor_2d) - - query_layer = self.reshape(query_out, (-1, shape_from, self.num_attention_heads, self.size_per_head)) - query_layer = self.transpose(query_layer, self.trans_shape) - key_layer = self.reshape(key_out, (-1, shape_from, self.num_attention_heads, self.size_per_head)) - key_layer = self.transpose(key_layer, self.trans_shape) - - attention_scores = self.matmul_trans_b(query_layer, key_layer) - - # use_relative_position, supplementary logic - if self.use_relative_positions: - relations_keys = self._generate_relative_positions_embeddings(shape_from) - relations_keys = self.cast_compute_type(relations_keys) - - query_layer_t = self.transpose(query_layer, self.trans_shape_relative) - - query_layer_r = self.reshape(query_layer_t, - (shape_from, - -1, - self.size_per_head)) - - key_position_scores = self.matmul_trans_b(query_layer_r, - relations_keys) - - key_position_scores_r = self.reshape(key_position_scores, - (shape_from, - -1, - self.num_attention_heads, - shape_from)) - - key_position_scores_r_t = self.transpose(key_position_scores_r, - self.trans_shape_position) - attention_scores = attention_scores + key_position_scores_r_t - - attention_scores = self.multiply(self.scores_mul, attention_scores) - - if self.has_attention_mask: - attention_mask = self.expand_dims(attention_mask, 1) - multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), self.get_dtype(attention_scores)), - self.cast(attention_mask, self.get_dtype(attention_scores))) - - adder = self.multiply(multiply_out, self.multiply_data) - attention_scores = self.add(adder, attention_scores) - - attention_probs = self.softmax(attention_scores) - attention_probs = self.dropout(attention_probs) - - value_layer = self.reshape(value_out, (-1, shape_from, self.num_attention_heads, self.size_per_head)) - value_layer = self.transpose(value_layer, self.trans_shape) - context_layer = self.matmul(attention_probs, value_layer) - - # use_relative_position, supplementary logic - if self.use_relative_positions: - relations_values = self._generate_relative_positions_embeddings(shape_from) - relations_values = self.cast_compute_type(relations_values) - - attention_probs_t = self.transpose(attention_probs, self.trans_shape_relative) - - attention_probs_r = self.reshape( - attention_probs_t, - (shape_from, - -1, - shape_from)) - - value_position_scores = self.matmul(attention_probs_r, - relations_values) - - value_position_scores_r = self.reshape(value_position_scores, - (shape_from, - -1, - self.num_attention_heads, - self.size_per_head)) - - value_position_scores_r_t = self.transpose(value_position_scores_r, - self.trans_shape_position) - context_layer = context_layer + value_position_scores_r_t - - context_layer = self.transpose(context_layer, self.trans_shape) - context_layer = self.reshape(context_layer, self.shape_return) - - return context_layer - - -class BertSelfAttention(nn.Cell): - """ - Apply self-attention. - - Args: - hidden_size (int): Size of the bert encoder layers. - num_attention_heads (int): Number of attention heads. Default: 12. - attention_probs_dropout_prob (float): The dropout probability for - BertAttention. Default: 0.1. - use_one_hot_embeddings (bool): Specifies whether to use one_hot encoding form. Default: False. - initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. - hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. - use_relative_positions (bool): Specifies whether to use relative positions. Default: False. - compute_type (:class:`mindspore.dtype`): Compute type in BertSelfAttention. Default: mstype.float32. - """ - - def __init__(self, - hidden_size, - num_attention_heads=12, - attention_probs_dropout_prob=0.1, - use_one_hot_embeddings=False, - initializer_range=0.02, - hidden_dropout_prob=0.1, - use_relative_positions=False, - compute_type=mstype.float32): - super(BertSelfAttention, self).__init__() - if hidden_size % num_attention_heads != 0: - raise ValueError("The hidden size (%d) is not a multiple of the number " - "of attention heads (%d)" % (hidden_size, num_attention_heads)) - - self.size_per_head = int(hidden_size / num_attention_heads) - - self.attention = BertAttention( - from_tensor_width=hidden_size, - to_tensor_width=hidden_size, - num_attention_heads=num_attention_heads, - size_per_head=self.size_per_head, - attention_probs_dropout_prob=attention_probs_dropout_prob, - use_one_hot_embeddings=use_one_hot_embeddings, - initializer_range=initializer_range, - use_relative_positions=use_relative_positions, - has_attention_mask=True, - compute_type=compute_type) - - self.output = BertOutput(in_channels=hidden_size, - out_channels=hidden_size, - initializer_range=initializer_range, - dropout_prob=hidden_dropout_prob, - compute_type=compute_type) - self.reshape = P.Reshape() - self.shape = (-1, hidden_size) - - def construct(self, input_tensor, attention_mask): - attention_output = self.attention(input_tensor, input_tensor, attention_mask) - output = self.output(attention_output, input_tensor) - return output - - -class BertEncoderCell(nn.Cell): - """ - Encoder cells used in BertTransformer. - - Args: - hidden_size (int): Size of the bert encoder layers. Default: 768. - num_attention_heads (int): Number of attention heads. Default: 12. - intermediate_size (int): Size of intermediate layer. Default: 3072. - attention_probs_dropout_prob (float): The dropout probability for - BertAttention. Default: 0.02. - use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. - initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. - hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. - use_relative_positions (bool): Specifies whether to use relative positions. Default: False. - hidden_act (str): Activation function. Default: "gelu". - compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: mstype.float32. - """ - - def __init__(self, - hidden_size=768, - num_attention_heads=12, - intermediate_size=3072, - attention_probs_dropout_prob=0.02, - use_one_hot_embeddings=False, - initializer_range=0.02, - hidden_dropout_prob=0.1, - use_relative_positions=False, - hidden_act="gelu", - compute_type=mstype.float32): - super(BertEncoderCell, self).__init__() - self.attention = BertSelfAttention( - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - attention_probs_dropout_prob=attention_probs_dropout_prob, - use_one_hot_embeddings=use_one_hot_embeddings, - initializer_range=initializer_range, - hidden_dropout_prob=hidden_dropout_prob, - use_relative_positions=use_relative_positions, - compute_type=compute_type) - self.intermediate = nn.Dense(in_channels=hidden_size, - out_channels=intermediate_size, - activation=hidden_act, - weight_init=TruncatedNormal(initializer_range)).to_float(compute_type) - self.output = BertOutput(in_channels=intermediate_size, - out_channels=hidden_size, - initializer_range=initializer_range, - dropout_prob=hidden_dropout_prob, - compute_type=compute_type) - - def construct(self, hidden_states, attention_mask): - # self-attention - attention_output = self.attention(hidden_states, attention_mask) - # feed construct - intermediate_output = self.intermediate(attention_output) - # add and normalize - output = self.output(intermediate_output, attention_output) - return output - - -class BertTransformer(nn.Cell): - """ - Multi-layer bert transformer. - - Args: - hidden_size (int): Size of the encoder layers. - num_hidden_layers (int): Number of hidden layers in encoder cells. - num_attention_heads (int): Number of attention heads in encoder cells. Default: 12. - intermediate_size (int): Size of intermediate layer in encoder cells. Default: 3072. - attention_probs_dropout_prob (float): The dropout probability for - BertAttention. Default: 0.1. - use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. - initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. - hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. - use_relative_positions (bool): Specifies whether to use relative positions. Default: False. - hidden_act (str): Activation function used in the encoder cells. Default: "gelu". - compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32. - return_all_encoders (bool): Specifies whether to return all encoders. Default: False. - """ - - def __init__(self, - hidden_size, - num_hidden_layers, - num_attention_heads=12, - intermediate_size=3072, - attention_probs_dropout_prob=0.1, - use_one_hot_embeddings=False, - initializer_range=0.02, - hidden_dropout_prob=0.1, - use_relative_positions=False, - hidden_act="gelu", - compute_type=mstype.float32, - return_all_encoders=False): - super(BertTransformer, self).__init__() - self.return_all_encoders = return_all_encoders - - layers = [] - for _ in range(num_hidden_layers): - layer = BertEncoderCell(hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - intermediate_size=intermediate_size, - attention_probs_dropout_prob=attention_probs_dropout_prob, - use_one_hot_embeddings=use_one_hot_embeddings, - initializer_range=initializer_range, - hidden_dropout_prob=hidden_dropout_prob, - use_relative_positions=use_relative_positions, - hidden_act=hidden_act, - compute_type=compute_type) - layers.append(layer) - - self.layers = nn.CellList(layers) - - self.reshape = P.Reshape() - self.shape = (-1, hidden_size) - - def construct(self, input_tensor, attention_mask): - """Multi-layer bert transformer.""" - prev_output = self.reshape(input_tensor, self.shape) - - all_encoder_layers = () - for layer_module in self.layers: - layer_output = layer_module(prev_output, attention_mask) - prev_output = layer_output - - if self.return_all_encoders: - shape = F.shape(input_tensor) - layer_output = self.reshape(layer_output, shape) - all_encoder_layers = all_encoder_layers + (layer_output,) - - if not self.return_all_encoders: - shape = F.shape(input_tensor) - prev_output = self.reshape(prev_output, shape) - all_encoder_layers = all_encoder_layers + (prev_output,) - return all_encoder_layers - - -class CreateAttentionMaskFromInputMask(nn.Cell): - """ - Create attention mask according to input mask. - - Args: - config (Class): Configuration for BertModel. - """ - - def __init__(self, config): - super(CreateAttentionMaskFromInputMask, self).__init__() - self.input_mask = None - - self.cast = P.Cast() - self.reshape = P.Reshape() - - def construct(self, input_mask): - seq_length = F.shape(input_mask)[1] - attention_mask = self.cast(self.reshape(input_mask, (-1, 1, seq_length)), mstype.float32) - return attention_mask - - -class BertModel(nn.Cell): - """ - Bidirectional Encoder Representations from Transformers. - - Args: - config (Class): Configuration for BertModel. - is_training (bool): True for training mode. False for eval mode. - use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. - """ - - def __init__(self, - config, - is_training, - use_one_hot_embeddings=False): - super(BertModel, self).__init__() - config = copy.deepcopy(config) - if not is_training: - config.hidden_dropout_prob = 0.0 - config.attention_probs_dropout_prob = 0.0 - - self.hidden_size = config.hidden_size - self.num_hidden_layers = config.num_hidden_layers - self.embedding_size = config.hidden_size - self.token_type_ids = None - - self.last_idx = self.num_hidden_layers - 1 - output_embedding_shape = [-1, config.seq_length, self.embedding_size] - - self.bert_embedding_lookup = nn.Embedding( - vocab_size=config.vocab_size, - embedding_size=self.embedding_size, - use_one_hot=use_one_hot_embeddings, - embedding_table=TruncatedNormal(config.initializer_range)) - - self.bert_embedding_postprocessor = EmbeddingPostprocessor( - embedding_size=self.embedding_size, - embedding_shape=output_embedding_shape, - use_relative_positions=config.use_relative_positions, - use_token_type=True, - token_type_vocab_size=config.type_vocab_size, - use_one_hot_embeddings=use_one_hot_embeddings, - initializer_range=0.02, - max_position_embeddings=config.max_position_embeddings, - dropout_prob=config.hidden_dropout_prob) - - self.bert_encoder = BertTransformer( - hidden_size=self.hidden_size, - num_attention_heads=config.num_attention_heads, - num_hidden_layers=self.num_hidden_layers, - intermediate_size=config.intermediate_size, - attention_probs_dropout_prob=config.attention_probs_dropout_prob, - use_one_hot_embeddings=use_one_hot_embeddings, - initializer_range=config.initializer_range, - hidden_dropout_prob=config.hidden_dropout_prob, - use_relative_positions=config.use_relative_positions, - hidden_act=config.hidden_act, - compute_type=config.compute_type, - return_all_encoders=True) - - self.cast = P.Cast() - self.dtype = config.dtype - self.cast_compute_type = SaturateCast(dst_type=config.compute_type) - self.slice = P.StridedSlice() - - self.squeeze_1 = P.Squeeze(axis=1) - self.dense = nn.Dense(self.hidden_size, self.hidden_size, - activation="tanh", - weight_init=TruncatedNormal(config.initializer_range)).to_float(config.compute_type) - self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config) - - def construct(self, input_ids, token_type_ids, input_mask): - """Bidirectional Encoder Representations from Transformers.""" - # embedding - embedding_tables = self.bert_embedding_lookup.embedding_table - word_embeddings = self.bert_embedding_lookup(input_ids) - embedding_output = self.bert_embedding_postprocessor(token_type_ids, - word_embeddings) - - # attention mask [batch_size, seq_length, seq_length] - attention_mask = self._create_attention_mask_from_input_mask(input_mask) - - # bert encoder - encoder_output = self.bert_encoder(self.cast_compute_type(embedding_output), - attention_mask) - - sequence_output = self.cast(encoder_output[self.last_idx], self.dtype) - - # pooler - batch_size = P.Shape()(input_ids)[0] - sequence_slice = self.slice(sequence_output, - (0, 0, 0), - (batch_size, 1, self.hidden_size), - (1, 1, 1)) - first_token = self.squeeze_1(sequence_slice) - pooled_output = self.dense(first_token) - pooled_output = self.cast(pooled_output, self.dtype) - - return sequence_output, pooled_output, embedding_tables +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Bert model.""" + +import copy +import math +import numpy as np +import mindspore.common.dtype as mstype +import mindspore.ops as ops +import mindspore.nn as nn +import mindspore.ops.functional as F +from mindspore.common.initializer import TruncatedNormal, initializer +from mindspore.common.parameter import Parameter +from mindspore.common.tensor import Tensor +from mindspore.ops import operations as P + + +class BertConfig: + """ + Configuration for `BertModel`. + + Args: + seq_length (int): Length of input sequence. Default: 128. + vocab_size (int): The shape of each embedding vector. Default: 32000. + hidden_size (int): Size of the bert encoder layers. Default: 768. + num_hidden_layers (int): Number of hidden layers in the BertTransformer encoder + cell. Default: 12. + num_attention_heads (int): Number of attention heads in the BertTransformer + encoder cell. Default: 12. + intermediate_size (int): Size of intermediate layer in the BertTransformer + encoder cell. Default: 3072. + hidden_act (str): Activation function used in the BertTransformer encoder + cell. Default: "gelu". + hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. + attention_probs_dropout_prob (float): The dropout probability for + BertAttention. Default: 0.1. + max_position_embeddings (int): Maximum length of sequences used in this + model. Default: 512. + type_vocab_size (int): Size of token type vocab. Default: 16. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + use_relative_positions (bool): Specifies whether to use relative positions. Default: False. + dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32. + compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32. + """ + + def __init__(self, + seq_length=128, + vocab_size=32000, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + initializer_range=0.02, + use_relative_positions=False, + dtype=mstype.float32, + compute_type=mstype.float32): + self.seq_length = seq_length + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.use_relative_positions = use_relative_positions + self.dtype = dtype + self.compute_type = compute_type + + +class EmbeddingLookup(nn.Cell): + """ + A embeddings lookup table with a fixed dictionary and size. + + Args: + vocab_size (int): Size of the dictionary of embeddings. + embedding_size (int): The size of each embedding vector. + embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of + each embedding vector. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + """ + + def __init__(self, + vocab_size, + embedding_size, + embedding_shape, + use_one_hot_embeddings=False, + initializer_range=0.02): + super(EmbeddingLookup, self).__init__() + self.vocab_size = vocab_size + self.use_one_hot_embeddings = use_one_hot_embeddings + self.embedding_table = Parameter(initializer + (TruncatedNormal(initializer_range), + [vocab_size, embedding_size])) + self.expand = P.ExpandDims() + self.shape_flat = (-1,) + self.gather = P.Gather() + self.one_hot = P.OneHot() + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.0, mstype.float32) + self.array_mul = P.MatMul() + self.reshape = P.Reshape() + self.shape = tuple(embedding_shape) + + def construct(self, input_ids): + """Get output and embeddings lookup table""" + extended_ids = self.expand(input_ids, -1) + flat_ids = self.reshape(extended_ids, self.shape_flat) + if self.use_one_hot_embeddings: + one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value) + output_for_reshape = self.array_mul( + one_hot_ids, self.embedding_table) + else: + output_for_reshape = self.gather(self.embedding_table, flat_ids, 0) + output = self.reshape(output_for_reshape, self.shape) + return output, self.embedding_table.value() + + +class EmbeddingPostprocessor(nn.Cell): + """ + Postprocessors apply positional and token type embeddings to word embeddings. + + Args: + embedding_size (int): The size of each embedding vector. + embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of + each embedding vector. + use_token_type (bool): Specifies whether to use token type embeddings. Default: False. + token_type_vocab_size (int): Size of token type vocab. Default: 16. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + max_position_embeddings (int): Maximum length of sequences used in this + model. Default: 512. + dropout_prob (float): The dropout probability. Default: 0.1. + """ + + def __init__(self, + embedding_size, + embedding_shape, + use_relative_positions=False, + use_token_type=False, + token_type_vocab_size=16, + use_one_hot_embeddings=False, + initializer_range=0.02, + max_position_embeddings=512, + dropout_prob=0.1): + super(EmbeddingPostprocessor, self).__init__() + self.use_token_type = use_token_type + self.token_type_vocab_size = token_type_vocab_size + self.use_one_hot_embeddings = use_one_hot_embeddings + self.max_position_embeddings = max_position_embeddings + self.token_type_embedding = nn.Embedding( + vocab_size=token_type_vocab_size, + embedding_size=embedding_size, + use_one_hot=use_one_hot_embeddings) + self.shape_flat = (-1,) + self.one_hot = P.OneHot() + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.1, mstype.float32) + self.array_mul = P.MatMul() + self.reshape = P.Reshape() + self.shape = tuple(embedding_shape) + self.dropout = nn.Dropout(p=dropout_prob) + self.gather = P.Gather() + self.use_relative_positions = use_relative_positions + self.slice = P.StridedSlice() + _, seq, _ = self.shape + self.full_position_embedding = nn.Embedding( + vocab_size=max_position_embeddings, + embedding_size=embedding_size, + use_one_hot=False) + self.layernorm = nn.LayerNorm((embedding_size,)) + self.position_ids = Tensor(np.arange(seq).reshape(-1, seq).astype(np.int32)) + self.add = P.Add() + + def construct(self, token_type_ids, word_embeddings): + """Postprocessors apply positional and token type embeddings to word embeddings.""" + output = word_embeddings + if self.use_token_type: + token_type_embeddings = self.token_type_embedding(token_type_ids) + output = self.add(output, token_type_embeddings) + if not self.use_relative_positions: + shape = F.shape(output) + position_ids = self.position_ids[:, :shape[1]] + position_embeddings = self.full_position_embedding(position_ids) + output = self.add(output, position_embeddings) + output = self.layernorm(output) + output = self.dropout(output) + return output + + +class BertOutput(nn.Cell): + """ + Apply a linear computation to hidden status and a residual computation to input. + + Args: + in_channels (int): Input channels. + out_channels (int): Output channels. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + dropout_prob (float): The dropout probability. Default: 0.1. + compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32. + """ + + def __init__(self, + in_channels, + out_channels, + initializer_range=0.02, + dropout_prob=0.1, + compute_type=mstype.float32): + super(BertOutput, self).__init__() + self.dense = nn.Dense(in_channels, out_channels, + weight_init=TruncatedNormal(initializer_range)).to_float(compute_type) + self.dropout = nn.Dropout(p=dropout_prob) + self.dropout_prob = dropout_prob + self.add = P.Add() + self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type) + self.cast = P.Cast() + + def construct(self, hidden_status, input_tensor): + output = self.dense(hidden_status) + output = self.dropout(output) + output = self.add(input_tensor, output) + output = self.layernorm(output) + return output + + +class RelaPosMatrixGenerator(nn.Cell): + """ + Generates matrix of relative positions between inputs. + + Args: + length (int): Length of one dim for the matrix to be generated. + max_relative_position (int): Max value of relative position. + """ + + def __init__(self, max_relative_position): + super(RelaPosMatrixGenerator, self).__init__() + self._max_relative_position = max_relative_position + self._min_relative_position = -max_relative_position + + self.tile = P.Tile() + self.range_mat = P.Reshape() + self.sub = P.Sub() + self.expanddims = P.ExpandDims() + self.cast = P.Cast() + + def construct(self, length): + """Generates matrix of relative positions between inputs.""" + range_vec_row_out = self.cast(F.tuple_to_array(F.make_range(length)), mstype.int32) + range_vec_col_out = self.range_mat(range_vec_row_out, (length, -1)) + tile_row_out = self.tile(range_vec_row_out, (length,)) + tile_col_out = self.tile(range_vec_col_out, (1, length)) + range_mat_out = self.range_mat(tile_row_out, (length, length)) + transpose_out = self.range_mat(tile_col_out, (length, length)) + distance_mat = self.sub(range_mat_out, transpose_out) + + distance_mat_clipped = ops.clip_by_value(distance_mat, + self._min_relative_position, + self._max_relative_position) + + # Shift values to be >=0. Each integer still uniquely identifies a + # relative position difference. + final_mat = distance_mat_clipped + self._max_relative_position + return final_mat + + +class RelaPosEmbeddingsGenerator(nn.Cell): + """ + Generates tensor of size [length, length, depth]. + + Args: + length (int): Length of one dim for the matrix to be generated. + depth (int): Size of each attention head. + max_relative_position (int): Maxmum value of relative position. + initializer_range (float): Initialization value of TruncatedNormal. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + """ + + def __init__(self, + depth, + max_relative_position, + initializer_range, + use_one_hot_embeddings=False): + super(RelaPosEmbeddingsGenerator, self).__init__() + self.depth = depth + self.vocab_size = max_relative_position * 2 + 1 + self.use_one_hot_embeddings = use_one_hot_embeddings + + self.embeddings_table = Parameter( + initializer(TruncatedNormal(initializer_range), + [self.vocab_size, self.depth])) + + self.relative_positions_matrix = RelaPosMatrixGenerator(max_relative_position=max_relative_position) + self.reshape = P.Reshape() + self.one_hot = nn.OneHot(depth=self.vocab_size) + self.shape = P.Shape() + self.gather = P.Gather() # index_select + self.matmul = P.BatchMatMul() + + def construct(self, length): + """Generate embedding for each relative position of dimension depth.""" + relative_positions_matrix_out = self.relative_positions_matrix(length) + + if self.use_one_hot_embeddings: + flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,)) + one_hot_relative_positions_matrix = self.one_hot( + flat_relative_positions_matrix) + embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table) + my_shape = self.shape(relative_positions_matrix_out) + (self.depth,) + embeddings = self.reshape(embeddings, my_shape) + else: + embeddings = self.gather(self.embeddings_table, + relative_positions_matrix_out, 0) + return embeddings + + +class SaturateCast(nn.Cell): + """ + Performs a safe saturating cast. This operation applies proper clamping before casting to prevent + the danger that the value will overflow or underflow. + + Args: + src_type (:class:`mindspore.dtype`): The type of the elements of the input tensor. Default: mstype.float32. + dst_type (:class:`mindspore.dtype`): The type of the elements of the output tensor. Default: mstype.float32. + """ + + def __init__(self, src_type=mstype.float32, dst_type=mstype.float32): + super(SaturateCast, self).__init__() + np_type = mstype.dtype_to_nptype(dst_type) + + self.tensor_min_type = float(np.finfo(np_type).min) + self.tensor_max_type = float(np.finfo(np_type).max) + + self.min_op = P.Minimum() + self.max_op = P.Maximum() + self.cast = P.Cast() + self.dst_type = dst_type + + def construct(self, x): + out = self.max_op(x, self.tensor_min_type) + out = self.min_op(out, self.tensor_max_type) + return self.cast(out, self.dst_type) + + +class BertAttention(nn.Cell): + """ + Apply multi-headed attention from "from_tensor" to "to_tensor". + + Args: + from_tensor_width (int): Size of last dim of from_tensor. + to_tensor_width (int): Size of last dim of to_tensor. + num_attention_heads (int): Number of attention heads. Default: 1. + size_per_head (int): Size of each attention head. Default: 512. + query_act (str): Activation function for the query transform. Default: None. + key_act (str): Activation function for the key transform. Default: None. + value_act (str): Activation function for the value transform. Default: None. + has_attention_mask (bool): Specifies whether to use attention mask. Default: False. + attention_probs_dropout_prob (float): The dropout probability for + BertAttention. Default: 0.0. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + use_relative_positions (bool): Specifies whether to use relative positions. Default: False. + compute_type (:class:`mindspore.dtype`): Compute type in BertAttention. Default: mstype.float32. + """ + + def __init__(self, + from_tensor_width, + to_tensor_width, + num_attention_heads=1, + size_per_head=512, + query_act=None, + key_act=None, + value_act=None, + has_attention_mask=False, + attention_probs_dropout_prob=0.0, + use_one_hot_embeddings=False, + initializer_range=0.02, + use_relative_positions=False, + compute_type=mstype.float32): + + super(BertAttention, self).__init__() + self.num_attention_heads = num_attention_heads + self.size_per_head = size_per_head + self.has_attention_mask = has_attention_mask + self.use_relative_positions = use_relative_positions + + self.scores_mul = 1.0 / math.sqrt(float(self.size_per_head)) + self.reshape = P.Reshape() + self.shape_from_2d = (-1, from_tensor_width) + self.shape_to_2d = (-1, to_tensor_width) + weight = TruncatedNormal(initializer_range) + units = num_attention_heads * size_per_head + self.query_layer = nn.Dense(from_tensor_width, + units, + activation=query_act, + weight_init=weight).to_float(compute_type) + self.key_layer = nn.Dense(to_tensor_width, + units, + activation=key_act, + weight_init=weight).to_float(compute_type) + self.value_layer = nn.Dense(to_tensor_width, + units, + activation=value_act, + weight_init=weight).to_float(compute_type) + + self.matmul_trans_b = P.BatchMatMul(transpose_b=True) + self.multiply = P.Mul() + self.transpose = P.Transpose() + self.trans_shape = (0, 2, 1, 3) + self.trans_shape_relative = (2, 0, 1, 3) + self.trans_shape_position = (1, 2, 0, 3) + self.multiply_data = -10000.0 + self.matmul = P.BatchMatMul() + + self.softmax = nn.Softmax() + self.dropout = nn.Dropout(p=attention_probs_dropout_prob) + + if self.has_attention_mask: + self.expand_dims = P.ExpandDims() + self.sub = P.Sub() + self.add = P.Add() + self.cast = P.Cast() + self.get_dtype = P.DType() + + self.shape_return = (-1, num_attention_heads * size_per_head) + + self.cast_compute_type = SaturateCast(dst_type=compute_type) + if self.use_relative_positions: + self._generate_relative_positions_embeddings = \ + RelaPosEmbeddingsGenerator(depth=size_per_head, + max_relative_position=16, + initializer_range=initializer_range, + use_one_hot_embeddings=use_one_hot_embeddings) + + def construct(self, from_tensor, to_tensor, attention_mask): + """reshape 2d/3d input tensors to 2d""" + shape_from = F.shape(attention_mask)[2] + from_tensor = F.depend(from_tensor, shape_from) + from_tensor_2d = self.reshape(from_tensor, self.shape_from_2d) + to_tensor_2d = self.reshape(to_tensor, self.shape_to_2d) + query_out = self.query_layer(from_tensor_2d) + key_out = self.key_layer(to_tensor_2d) + value_out = self.value_layer(to_tensor_2d) + + query_layer = self.reshape(query_out, (-1, shape_from, self.num_attention_heads, self.size_per_head)) + query_layer = self.transpose(query_layer, self.trans_shape) + key_layer = self.reshape(key_out, (-1, shape_from, self.num_attention_heads, self.size_per_head)) + key_layer = self.transpose(key_layer, self.trans_shape) + + attention_scores = self.matmul_trans_b(query_layer, key_layer) + + # use_relative_position, supplementary logic + if self.use_relative_positions: + relations_keys = self._generate_relative_positions_embeddings(shape_from) + relations_keys = self.cast_compute_type(relations_keys) + + query_layer_t = self.transpose(query_layer, self.trans_shape_relative) + + query_layer_r = self.reshape(query_layer_t, + (shape_from, + -1, + self.size_per_head)) + + key_position_scores = self.matmul_trans_b(query_layer_r, + relations_keys) + + key_position_scores_r = self.reshape(key_position_scores, + (shape_from, + -1, + self.num_attention_heads, + shape_from)) + + key_position_scores_r_t = self.transpose(key_position_scores_r, + self.trans_shape_position) + attention_scores = attention_scores + key_position_scores_r_t + + attention_scores = self.multiply(self.scores_mul, attention_scores) + + if self.has_attention_mask: + attention_mask = self.expand_dims(attention_mask, 1) + multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), self.get_dtype(attention_scores)), + self.cast(attention_mask, self.get_dtype(attention_scores))) + + adder = self.multiply(multiply_out, self.multiply_data) + attention_scores = self.add(adder, attention_scores) + + attention_probs = self.softmax(attention_scores) + attention_probs = self.dropout(attention_probs) + + value_layer = self.reshape(value_out, (-1, shape_from, self.num_attention_heads, self.size_per_head)) + value_layer = self.transpose(value_layer, self.trans_shape) + context_layer = self.matmul(attention_probs, value_layer) + + # use_relative_position, supplementary logic + if self.use_relative_positions: + relations_values = self._generate_relative_positions_embeddings(shape_from) + relations_values = self.cast_compute_type(relations_values) + + attention_probs_t = self.transpose(attention_probs, self.trans_shape_relative) + + attention_probs_r = self.reshape( + attention_probs_t, + (shape_from, + -1, + shape_from)) + + value_position_scores = self.matmul(attention_probs_r, + relations_values) + + value_position_scores_r = self.reshape(value_position_scores, + (shape_from, + -1, + self.num_attention_heads, + self.size_per_head)) + + value_position_scores_r_t = self.transpose(value_position_scores_r, + self.trans_shape_position) + context_layer = context_layer + value_position_scores_r_t + + context_layer = self.transpose(context_layer, self.trans_shape) + context_layer = self.reshape(context_layer, self.shape_return) + + return context_layer + + +class BertSelfAttention(nn.Cell): + """ + Apply self-attention. + + Args: + hidden_size (int): Size of the bert encoder layers. + num_attention_heads (int): Number of attention heads. Default: 12. + attention_probs_dropout_prob (float): The dropout probability for + BertAttention. Default: 0.1. + use_one_hot_embeddings (bool): Specifies whether to use one_hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. + use_relative_positions (bool): Specifies whether to use relative positions. Default: False. + compute_type (:class:`mindspore.dtype`): Compute type in BertSelfAttention. Default: mstype.float32. + """ + + def __init__(self, + hidden_size, + num_attention_heads=12, + attention_probs_dropout_prob=0.1, + use_one_hot_embeddings=False, + initializer_range=0.02, + hidden_dropout_prob=0.1, + use_relative_positions=False, + compute_type=mstype.float32): + super(BertSelfAttention, self).__init__() + if hidden_size % num_attention_heads != 0: + raise ValueError("The hidden size (%d) is not a multiple of the number " + "of attention heads (%d)" % (hidden_size, num_attention_heads)) + + self.size_per_head = int(hidden_size / num_attention_heads) + + self.attention = BertAttention( + from_tensor_width=hidden_size, + to_tensor_width=hidden_size, + num_attention_heads=num_attention_heads, + size_per_head=self.size_per_head, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + use_relative_positions=use_relative_positions, + has_attention_mask=True, + compute_type=compute_type) + + self.output = BertOutput(in_channels=hidden_size, + out_channels=hidden_size, + initializer_range=initializer_range, + dropout_prob=hidden_dropout_prob, + compute_type=compute_type) + self.reshape = P.Reshape() + self.shape = (-1, hidden_size) + + def construct(self, input_tensor, attention_mask): + attention_output = self.attention(input_tensor, input_tensor, attention_mask) + output = self.output(attention_output, input_tensor) + return output + + +class BertEncoderCell(nn.Cell): + """ + Encoder cells used in BertTransformer. + + Args: + hidden_size (int): Size of the bert encoder layers. Default: 768. + num_attention_heads (int): Number of attention heads. Default: 12. + intermediate_size (int): Size of intermediate layer. Default: 3072. + attention_probs_dropout_prob (float): The dropout probability for + BertAttention. Default: 0.02. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. + use_relative_positions (bool): Specifies whether to use relative positions. Default: False. + hidden_act (str): Activation function. Default: "gelu". + compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: mstype.float32. + """ + + def __init__(self, + hidden_size=768, + num_attention_heads=12, + intermediate_size=3072, + attention_probs_dropout_prob=0.02, + use_one_hot_embeddings=False, + initializer_range=0.02, + hidden_dropout_prob=0.1, + use_relative_positions=False, + hidden_act="gelu", + compute_type=mstype.float32): + super(BertEncoderCell, self).__init__() + self.attention = BertSelfAttention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + hidden_dropout_prob=hidden_dropout_prob, + use_relative_positions=use_relative_positions, + compute_type=compute_type) + self.intermediate = nn.Dense(in_channels=hidden_size, + out_channels=intermediate_size, + activation=hidden_act, + weight_init=TruncatedNormal(initializer_range)).to_float(compute_type) + self.output = BertOutput(in_channels=intermediate_size, + out_channels=hidden_size, + initializer_range=initializer_range, + dropout_prob=hidden_dropout_prob, + compute_type=compute_type) + + def construct(self, hidden_states, attention_mask): + # self-attention + attention_output = self.attention(hidden_states, attention_mask) + # feed construct + intermediate_output = self.intermediate(attention_output) + # add and normalize + output = self.output(intermediate_output, attention_output) + return output + + +class BertTransformer(nn.Cell): + """ + Multi-layer bert transformer. + + Args: + hidden_size (int): Size of the encoder layers. + num_hidden_layers (int): Number of hidden layers in encoder cells. + num_attention_heads (int): Number of attention heads in encoder cells. Default: 12. + intermediate_size (int): Size of intermediate layer in encoder cells. Default: 3072. + attention_probs_dropout_prob (float): The dropout probability for + BertAttention. Default: 0.1. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. + use_relative_positions (bool): Specifies whether to use relative positions. Default: False. + hidden_act (str): Activation function used in the encoder cells. Default: "gelu". + compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32. + return_all_encoders (bool): Specifies whether to return all encoders. Default: False. + """ + + def __init__(self, + hidden_size, + num_hidden_layers, + num_attention_heads=12, + intermediate_size=3072, + attention_probs_dropout_prob=0.1, + use_one_hot_embeddings=False, + initializer_range=0.02, + hidden_dropout_prob=0.1, + use_relative_positions=False, + hidden_act="gelu", + compute_type=mstype.float32, + return_all_encoders=False): + super(BertTransformer, self).__init__() + self.return_all_encoders = return_all_encoders + + layers = [] + for _ in range(num_hidden_layers): + layer = BertEncoderCell(hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + intermediate_size=intermediate_size, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + hidden_dropout_prob=hidden_dropout_prob, + use_relative_positions=use_relative_positions, + hidden_act=hidden_act, + compute_type=compute_type) + layers.append(layer) + + self.layers = nn.CellList(layers) + + self.reshape = P.Reshape() + self.shape = (-1, hidden_size) + + def construct(self, input_tensor, attention_mask): + """Multi-layer bert transformer.""" + prev_output = self.reshape(input_tensor, self.shape) + + all_encoder_layers = () + for layer_module in self.layers: + layer_output = layer_module(prev_output, attention_mask) + prev_output = layer_output + + if self.return_all_encoders: + shape = F.shape(input_tensor) + layer_output = self.reshape(layer_output, shape) + all_encoder_layers = all_encoder_layers + (layer_output,) + + if not self.return_all_encoders: + shape = F.shape(input_tensor) + prev_output = self.reshape(prev_output, shape) + all_encoder_layers = all_encoder_layers + (prev_output,) + return all_encoder_layers + + +class CreateAttentionMaskFromInputMask(nn.Cell): + """ + Create attention mask according to input mask. + + Args: + config (Class): Configuration for BertModel. + """ + + def __init__(self, config): + super(CreateAttentionMaskFromInputMask, self).__init__() + self.input_mask = None + + self.cast = P.Cast() + self.reshape = P.Reshape() + + def construct(self, input_mask): + seq_length = F.shape(input_mask)[1] + attention_mask = self.cast(self.reshape(input_mask, (-1, 1, seq_length)), mstype.float32) + return attention_mask + + +class BertModel(nn.Cell): + """ + Bidirectional Encoder Representations from Transformers. + + Args: + config (Class): Configuration for BertModel. + is_training (bool): True for training mode. False for eval mode. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + """ + + def __init__(self, + config, + is_training, + use_one_hot_embeddings=False): + super(BertModel, self).__init__() + config = copy.deepcopy(config) + if not is_training: + config.hidden_dropout_prob = 0.0 + config.attention_probs_dropout_prob = 0.0 + + self.hidden_size = config.hidden_size + self.num_hidden_layers = config.num_hidden_layers + self.embedding_size = config.hidden_size + self.token_type_ids = None + + self.last_idx = self.num_hidden_layers - 1 + output_embedding_shape = [-1, config.seq_length, self.embedding_size] + + self.bert_embedding_lookup = nn.Embedding( + vocab_size=config.vocab_size, + embedding_size=self.embedding_size, + use_one_hot=use_one_hot_embeddings, + embedding_table=TruncatedNormal(config.initializer_range)) + + self.bert_embedding_postprocessor = EmbeddingPostprocessor( + embedding_size=self.embedding_size, + embedding_shape=output_embedding_shape, + use_relative_positions=config.use_relative_positions, + use_token_type=True, + token_type_vocab_size=config.type_vocab_size, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=0.02, + max_position_embeddings=config.max_position_embeddings, + dropout_prob=config.hidden_dropout_prob) + + self.bert_encoder = BertTransformer( + hidden_size=self.hidden_size, + num_attention_heads=config.num_attention_heads, + num_hidden_layers=self.num_hidden_layers, + intermediate_size=config.intermediate_size, + attention_probs_dropout_prob=config.attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=config.initializer_range, + hidden_dropout_prob=config.hidden_dropout_prob, + use_relative_positions=config.use_relative_positions, + hidden_act=config.hidden_act, + compute_type=config.compute_type, + return_all_encoders=True) + + self.cast = P.Cast() + self.dtype = config.dtype + self.cast_compute_type = SaturateCast(dst_type=config.compute_type) + self.slice = P.StridedSlice() + + self.squeeze_1 = P.Squeeze(axis=1) + self.dense = nn.Dense(self.hidden_size, self.hidden_size, + activation="tanh", + weight_init=TruncatedNormal(config.initializer_range)).to_float(config.compute_type) + self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config) + + def construct(self, input_ids, token_type_ids, input_mask): + """Bidirectional Encoder Representations from Transformers.""" + # embedding + embedding_tables = self.bert_embedding_lookup.embedding_table + word_embeddings = self.bert_embedding_lookup(input_ids) + embedding_output = self.bert_embedding_postprocessor(token_type_ids, + word_embeddings) + + # attention mask [batch_size, seq_length, seq_length] + attention_mask = self._create_attention_mask_from_input_mask(input_mask) + + # bert encoder + encoder_output = self.bert_encoder(self.cast_compute_type(embedding_output), + attention_mask) + + sequence_output = self.cast(encoder_output[self.last_idx], self.dtype) + + # pooler + batch_size = P.Shape()(input_ids)[0] + sequence_slice = self.slice(sequence_output, + (0, 0, 0), + (batch_size, 1, self.hidden_size), + (1, 1, 1)) + first_token = self.squeeze_1(sequence_slice) + pooled_output = self.dense(first_token) + pooled_output = self.cast(pooled_output, self.dtype) + + return sequence_output, pooled_output, embedding_tables diff --git a/tests/st/networks/models/bert/bert_performance/src/utils.py b/tests/st/networks/models/bert/bert_performance/src/utils.py index 7a2617d524c..db491b88b72 100644 --- a/tests/st/networks/models/bert/bert_performance/src/utils.py +++ b/tests/st/networks/models/bert/bert_performance/src/utils.py @@ -1,273 +1,273 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -""" -Functional Cells used in Bert finetune and evaluation. -""" - -import os -import collections -import math -import numpy as np -import mindspore.nn as nn -from mindspore import log as logger -from mindspore.common import dtype as mstype -from mindspore.common.tensor import Tensor -from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR -from mindspore.ops import operations as P -from mindspore.train import Callback, Metric - - -class CrossEntropyCalculation(nn.Cell): - """ - Cross Entropy loss - """ - - def __init__(self, is_training=True): - super(CrossEntropyCalculation, self).__init__() - self.onehot = P.OneHot() - self.on_value = Tensor(1.0, mstype.float32) - self.off_value = Tensor(0.0, mstype.float32) - self.reduce_sum = P.ReduceSum() - self.reduce_mean = P.ReduceMean() - self.reshape = P.Reshape() - self.last_idx = (-1,) - self.neg = P.Neg() - self.cast = P.Cast() - self.is_training = is_training - - def construct(self, logits, label_ids, num_labels): - if self.is_training: - label_ids = self.reshape(label_ids, self.last_idx) - one_hot_labels = self.onehot(label_ids, num_labels, self.on_value, self.off_value) - per_example_loss = self.neg(self.reduce_sum(one_hot_labels * logits, self.last_idx)) - loss = self.reduce_mean(per_example_loss, self.last_idx) - return_value = self.cast(loss, mstype.float32) - else: - return_value = logits * 1.0 - return return_value - - -def make_directory(path: str): - """Make directory.""" - if path is None or not isinstance(path, str) or path.strip() == "": - logger.error("The path(%r) is invalid type.", path) - raise TypeError("Input path is invalid type") - - # convert the relative paths - path = os.path.realpath(path) - logger.debug("The abs path is %r", path) - - # check the path is exist and write permissions? - if os.path.exists(path): - real_path = path - else: - # All exceptions need to be caught because create directory maybe have some limit(permissions) - logger.debug("The directory(%s) doesn't exist, will create it", path) - try: - os.makedirs(path, exist_ok=True) - real_path = path - except PermissionError as e: - logger.error("No write permission on the directory(%r), error = %r", path, e) - raise TypeError("No write permission on the directory.") - return real_path - - -class LossCallBack(Callback): - """ - Monitor the loss in training. - If the loss in NAN or INF terminating training. - Note: - if per_print_times is 0 do not print loss. - Args: - per_print_times (int): Print loss every times. Default: 1. - """ - - def __init__(self, dataset_size=-1): - super(LossCallBack, self).__init__() - self._dataset_size = dataset_size - - def step_end(self, run_context): - """ - Print loss after each step - """ - cb_params = run_context.original_args() - if self._dataset_size > 0: - percent, epoch_num = math.modf(cb_params.cur_step_num / self._dataset_size) - if percent == 0: - percent = 1 - epoch_num -= 1 - print("epoch: {}, current epoch percent: {}, step: {}, outputs are {}" - .format(int(epoch_num), "%.3f" % percent, cb_params.cur_step_num, str(cb_params.net_outputs)), - flush=True) - else: - print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num, - str(cb_params.net_outputs)), flush=True) - - -class BertLearningRate(LearningRateSchedule): - """ - Warmup-decay learning rate for Bert network. - """ - - def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power): - super(BertLearningRate, self).__init__() - self.warmup_flag = False - if warmup_steps > 0: - self.warmup_flag = True - self.warmup_lr = WarmUpLR(learning_rate, warmup_steps) - self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power) - self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32)) - - self.greater = P.Greater() - self.one = Tensor(np.array([1.0]).astype(np.float32)) - self.cast = P.Cast() - - def construct(self, global_step): - decay_lr = self.decay_lr(global_step) - if self.warmup_flag: - is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32) - warmup_lr = self.warmup_lr(global_step) - lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr - else: - lr = decay_lr - return lr - - -def convert_labels_to_index(label_list): - """ - Convert label_list to indices for NER task. - """ - label2id = collections.OrderedDict() - label2id["O"] = 0 - prefix = ["S_", "B_", "M_", "E_"] - index = 0 - for label in label_list: - for pre in prefix: - index += 1 - sub_label = pre + label - label2id[sub_label] = index - return label2id - - -def _get_poly_lr(global_step, lr_init, lr_end, lr_max, warmup_steps, total_steps, poly_power): - """ - generate learning rate array - - Args: - global_step(int): current step - lr_init(float): init learning rate - lr_end(float): end learning rate - lr_max(float): max learning rate - warmup_steps(int): number of warmup epochs - total_steps(int): total epoch of training - poly_power(int): poly learning rate power - - Returns: - np.array, learning rate array - """ - lr_each_step = [] - if warmup_steps != 0: - inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps) - else: - inc_each_step = 0 - for i in range(total_steps): - if i < warmup_steps: - lr = float(lr_init) + inc_each_step * float(i) - else: - base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps))) - lr = float(lr_max - lr_end) * (base ** poly_power) - lr = lr + lr_end - if lr < 0.0: - lr = 0.0 - lr_each_step.append(lr) - - learning_rate = np.array(lr_each_step).astype(np.float32) - current_step = global_step - learning_rate = learning_rate[current_step:] - return learning_rate - - -def get_bert_thor_lr(lr_max=0.0034, lr_min=3.244e-05, lr_power=1.0, lr_total_steps=30000): - learning_rate = _get_poly_lr(global_step=0, lr_init=0.0, lr_end=lr_min, lr_max=lr_max, warmup_steps=0, - total_steps=lr_total_steps, poly_power=lr_power) - return Tensor(learning_rate) - - -def get_bert_thor_damping(damping_max=5e-2, damping_min=1e-6, damping_power=1.0, damping_total_steps=30000): - damping = _get_poly_lr(global_step=0, lr_init=0.0, lr_end=damping_min, lr_max=damping_max, warmup_steps=0, - total_steps=damping_total_steps, poly_power=damping_power) - return Tensor(damping) - - -class EvalCallBack(Callback): - """ - Evaluate after a certain amount of training samples. - Args: - model (Model): The network model. - eval_ds (Dataset): The eval dataset. - global_batch (int): The batchsize of the sum of all devices. - eval_samples (int): The number of eval interval samples. - """ - - def __init__(self, model, eval_ds, global_batch, eval_samples): - super(EvalCallBack, self).__init__() - self.model = model - self.eval_ds = eval_ds - self.global_batch = global_batch - self.eval_samples = eval_samples - self.last_eval_step = 0 - - def epoch_end(self, run_context): - """ - Evaluate after training a certain number of samples. - """ - cb_params = run_context.original_args() - num_samples = (cb_params.cur_step_num - self.last_eval_step) * self.global_batch - if num_samples < self.eval_samples: - return - self.last_eval_step = cb_params.cur_step_num - total_sumples = cb_params.cur_step_num * self.global_batch - res = self.model.eval(self.eval_ds, dataset_sink_mode=True) - res = res['bert_acc'] - print("====================================", flush=True) - print("Accuracy is: ", "%.6f" % res, ", current samples is: ", total_sumples) - print("====================================", flush=True) - - -class BertMetric(Metric): - """ - The metric of bert network. - Args: - batch_size (int): The batchsize of each device. - """ - - def __init__(self, batch_size): - super(BertMetric, self).__init__() - self.clear() - self.batch_size = batch_size - - def clear(self): - self.mlm_total = 0 - self.mlm_acc = 0 - - def update(self, *inputs): - mlm_acc = self._convert_data(inputs[0]) - mlm_total = self._convert_data(inputs[1]) - self.mlm_acc += mlm_acc - self.mlm_total += mlm_total - - def eval(self): - return self.mlm_acc / self.mlm_total +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +""" +Functional Cells used in Bert finetune and evaluation. +""" + +import os +import collections +import math +import numpy as np +import mindspore.nn as nn +from mindspore import log as logger +from mindspore.common import dtype as mstype +from mindspore.common.tensor import Tensor +from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR +from mindspore.ops import operations as P +from mindspore.train import Callback, Metric + + +class CrossEntropyCalculation(nn.Cell): + """ + Cross Entropy loss + """ + + def __init__(self, is_training=True): + super(CrossEntropyCalculation, self).__init__() + self.onehot = P.OneHot() + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.0, mstype.float32) + self.reduce_sum = P.ReduceSum() + self.reduce_mean = P.ReduceMean() + self.reshape = P.Reshape() + self.last_idx = (-1,) + self.neg = P.Neg() + self.cast = P.Cast() + self.is_training = is_training + + def construct(self, logits, label_ids, num_labels): + if self.is_training: + label_ids = self.reshape(label_ids, self.last_idx) + one_hot_labels = self.onehot(label_ids, num_labels, self.on_value, self.off_value) + per_example_loss = self.neg(self.reduce_sum(one_hot_labels * logits, self.last_idx)) + loss = self.reduce_mean(per_example_loss, self.last_idx) + return_value = self.cast(loss, mstype.float32) + else: + return_value = logits * 1.0 + return return_value + + +def make_directory(path: str): + """Make directory.""" + if path is None or not isinstance(path, str) or path.strip() == "": + logger.error("The path(%r) is invalid type.", path) + raise TypeError("Input path is invalid type") + + # convert the relative paths + path = os.path.realpath(path) + logger.debug("The abs path is %r", path) + + # check the path is exist and write permissions? + if os.path.exists(path): + real_path = path + else: + # All exceptions need to be caught because create directory maybe have some limit(permissions) + logger.debug("The directory(%s) doesn't exist, will create it", path) + try: + os.makedirs(path, exist_ok=True) + real_path = path + except PermissionError as e: + logger.error("No write permission on the directory(%r), error = %r", path, e) + raise TypeError("No write permission on the directory.") + return real_path + + +class LossCallBack(Callback): + """ + Monitor the loss in training. + If the loss in NAN or INF terminating training. + Note: + if per_print_times is 0 do not print loss. + Args: + per_print_times (int): Print loss every times. Default: 1. + """ + + def __init__(self, dataset_size=-1): + super(LossCallBack, self).__init__() + self._dataset_size = dataset_size + + def step_end(self, run_context): + """ + Print loss after each step + """ + cb_params = run_context.original_args() + if self._dataset_size > 0: + percent, epoch_num = math.modf(cb_params.cur_step_num / self._dataset_size) + if percent == 0: + percent = 1 + epoch_num -= 1 + print("epoch: {}, current epoch percent: {}, step: {}, outputs are {}" + .format(int(epoch_num), "%.3f" % percent, cb_params.cur_step_num, str(cb_params.net_outputs)), + flush=True) + else: + print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num, + str(cb_params.net_outputs)), flush=True) + + +class BertLearningRate(LearningRateSchedule): + """ + Warmup-decay learning rate for Bert network. + """ + + def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power): + super(BertLearningRate, self).__init__() + self.warmup_flag = False + if warmup_steps > 0: + self.warmup_flag = True + self.warmup_lr = WarmUpLR(learning_rate, warmup_steps) + self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power) + self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32)) + + self.greater = P.Greater() + self.one = Tensor(np.array([1.0]).astype(np.float32)) + self.cast = P.Cast() + + def construct(self, global_step): + decay_lr = self.decay_lr(global_step) + if self.warmup_flag: + is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32) + warmup_lr = self.warmup_lr(global_step) + lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr + else: + lr = decay_lr + return lr + + +def convert_labels_to_index(label_list): + """ + Convert label_list to indices for NER task. + """ + label2id = collections.OrderedDict() + label2id["O"] = 0 + prefix = ["S_", "B_", "M_", "E_"] + index = 0 + for label in label_list: + for pre in prefix: + index += 1 + sub_label = pre + label + label2id[sub_label] = index + return label2id + + +def _get_poly_lr(global_step, lr_init, lr_end, lr_max, warmup_steps, total_steps, poly_power): + """ + generate learning rate array + + Args: + global_step(int): current step + lr_init(float): init learning rate + lr_end(float): end learning rate + lr_max(float): max learning rate + warmup_steps(int): number of warmup epochs + total_steps(int): total epoch of training + poly_power(int): poly learning rate power + + Returns: + np.array, learning rate array + """ + lr_each_step = [] + if warmup_steps != 0: + inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps) + else: + inc_each_step = 0 + for i in range(total_steps): + if i < warmup_steps: + lr = float(lr_init) + inc_each_step * float(i) + else: + base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps))) + lr = float(lr_max - lr_end) * (base ** poly_power) + lr = lr + lr_end + if lr < 0.0: + lr = 0.0 + lr_each_step.append(lr) + + learning_rate = np.array(lr_each_step).astype(np.float32) + current_step = global_step + learning_rate = learning_rate[current_step:] + return learning_rate + + +def get_bert_thor_lr(lr_max=0.0034, lr_min=3.244e-05, lr_power=1.0, lr_total_steps=30000): + learning_rate = _get_poly_lr(global_step=0, lr_init=0.0, lr_end=lr_min, lr_max=lr_max, warmup_steps=0, + total_steps=lr_total_steps, poly_power=lr_power) + return Tensor(learning_rate) + + +def get_bert_thor_damping(damping_max=5e-2, damping_min=1e-6, damping_power=1.0, damping_total_steps=30000): + damping = _get_poly_lr(global_step=0, lr_init=0.0, lr_end=damping_min, lr_max=damping_max, warmup_steps=0, + total_steps=damping_total_steps, poly_power=damping_power) + return Tensor(damping) + + +class EvalCallBack(Callback): + """ + Evaluate after a certain amount of training samples. + Args: + model (Model): The network model. + eval_ds (Dataset): The eval dataset. + global_batch (int): The batchsize of the sum of all devices. + eval_samples (int): The number of eval interval samples. + """ + + def __init__(self, model, eval_ds, global_batch, eval_samples): + super(EvalCallBack, self).__init__() + self.model = model + self.eval_ds = eval_ds + self.global_batch = global_batch + self.eval_samples = eval_samples + self.last_eval_step = 0 + + def epoch_end(self, run_context): + """ + Evaluate after training a certain number of samples. + """ + cb_params = run_context.original_args() + num_samples = (cb_params.cur_step_num - self.last_eval_step) * self.global_batch + if num_samples < self.eval_samples: + return + self.last_eval_step = cb_params.cur_step_num + total_sumples = cb_params.cur_step_num * self.global_batch + res = self.model.eval(self.eval_ds, dataset_sink_mode=True) + res = res['bert_acc'] + print("====================================", flush=True) + print("Accuracy is: ", "%.6f" % res, ", current samples is: ", total_sumples) + print("====================================", flush=True) + + +class BertMetric(Metric): + """ + The metric of bert network. + Args: + batch_size (int): The batchsize of each device. + """ + + def __init__(self, batch_size): + super(BertMetric, self).__init__() + self.clear() + self.batch_size = batch_size + + def clear(self): + self.mlm_total = 0 + self.mlm_acc = 0 + + def update(self, *inputs): + mlm_acc = self._convert_data(inputs[0]) + mlm_total = self._convert_data(inputs[1]) + self.mlm_acc += mlm_acc + self.mlm_total += mlm_total + + def eval(self): + return self.mlm_acc / self.mlm_total diff --git a/tests/st/networks/models/resnet50/src/config.py b/tests/st/networks/models/resnet50/src/config.py old mode 100755 new mode 100644 diff --git a/tests/st/networks/models/resnet50/src/dataset.py b/tests/st/networks/models/resnet50/src/dataset.py old mode 100755 new mode 100644 diff --git a/tests/st/networks/models/resnet50/src/lr_generator.py b/tests/st/networks/models/resnet50/src/lr_generator.py old mode 100755 new mode 100644 diff --git a/tests/st/networks/models/resnet50/src/resnet.py b/tests/st/networks/models/resnet50/src/resnet.py old mode 100755 new mode 100644 diff --git a/tests/st/networks/models/resnetv1_5.py b/tests/st/networks/models/resnetv1_5.py index 1a6b3ae2503..805b5db7610 100644 --- a/tests/st/networks/models/resnetv1_5.py +++ b/tests/st/networks/models/resnetv1_5.py @@ -1,296 +1,296 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import numpy as np - -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.ops import operations as P - - -def weight_variable(shape): - ones = np.ones(shape).astype(np.float32) - return Tensor(ones * 0.01) - - -def weight_variable_0(shape): - zeros = np.zeros(shape).astype(np.float32) - return Tensor(zeros) - - -def weight_variable_1(shape): - ones = np.ones(shape).astype(np.float32) - return Tensor(ones) - - -def conv3x3(in_channels, out_channels, stride=1, padding=0): - """3x3 convolution """ - weight_shape = (out_channels, in_channels, 3, 3) - weight = weight_variable(weight_shape) - return nn.Conv2d(in_channels, out_channels, - kernel_size=3, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same") - - -def conv1x1(in_channels, out_channels, stride=1, padding=0): - """1x1 convolution""" - weight_shape = (out_channels, in_channels, 1, 1) - weight = weight_variable(weight_shape) - return nn.Conv2d(in_channels, out_channels, - kernel_size=1, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same") - - -def conv7x7(in_channels, out_channels, stride=1, padding=0): - """1x1 convolution""" - weight_shape = (out_channels, in_channels, 7, 7) - weight = weight_variable(weight_shape) - return nn.Conv2d(in_channels, out_channels, - kernel_size=7, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same") - - -def bn_with_initialize(out_channels): - shape = (out_channels) - mean = weight_variable_0(shape) - var = weight_variable_1(shape) - beta = weight_variable_0(shape) - gamma = weight_variable_1(shape) - bn = nn.BatchNorm2d(out_channels, momentum=0.1, eps=0.0001, gamma_init=gamma, - beta_init=beta, moving_mean_init=mean, moving_var_init=var) - return bn - - -def bn_with_initialize_last(out_channels): - shape = (out_channels) - mean = weight_variable_0(shape) - var = weight_variable_1(shape) - beta = weight_variable_0(shape) - gamma = weight_variable_0(shape) - bn = nn.BatchNorm2d(out_channels, momentum=0.1, eps=0.0001, gamma_init=gamma, - beta_init=beta, moving_mean_init=mean, moving_var_init=var) - return bn - - -def fc_with_initialize(input_channels, out_channels): - weight_shape = (out_channels, input_channels) - bias_shape = (out_channels) - weight = weight_variable(weight_shape) - bias = weight_variable_0(bias_shape) - - return nn.Dense(input_channels, out_channels, weight, bias) - - -class ResidualBlock(nn.Cell): - expansion = 4 - - def __init__(self, - in_channels, - out_channels, - stride=1): - super(ResidualBlock, self).__init__() - - out_chls = out_channels // self.expansion - self.conv1 = conv1x1(in_channels, out_chls, stride=1, padding=0) - self.bn1 = bn_with_initialize(out_chls) - - self.conv2 = conv3x3(out_chls, out_chls, stride=stride, padding=0) - self.bn2 = bn_with_initialize(out_chls) - - self.conv3 = conv1x1(out_chls, out_channels, stride=1, padding=0) - self.bn3 = bn_with_initialize_last(out_channels) - - self.relu = P.ReLU() - self.add = P.Add() - - def construct(self, x): - identity = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.relu(out) - - out = self.conv3(out) - out = self.bn3(out) - - out = self.add(out, identity) - out = self.relu(out) - - return out - - -class ResidualBlockWithDown(nn.Cell): - expansion = 4 - - def __init__(self, - in_channels, - out_channels, - stride=1, - down_sample=False): - super(ResidualBlockWithDown, self).__init__() - - out_chls = out_channels // self.expansion - self.conv1 = conv1x1(in_channels, out_chls, stride=1, padding=0) - self.bn1 = bn_with_initialize(out_chls) - - self.conv2 = conv3x3(out_chls, out_chls, stride=stride, padding=0) - self.bn2 = bn_with_initialize(out_chls) - - self.conv3 = conv1x1(out_chls, out_channels, stride=1, padding=0) - self.bn3 = bn_with_initialize_last(out_channels) - - self.relu = P.ReLU() - self.downSample = down_sample - - self.conv_down_sample = conv1x1(in_channels, out_channels, stride=stride, padding=0) - self.bn_down_sample = bn_with_initialize(out_channels) - self.add = P.Add() - - def construct(self, x): - identity = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.relu(out) - - out = self.conv3(out) - out = self.bn3(out) - - identity = self.conv_down_sample(identity) - identity = self.bn_down_sample(identity) - - out = self.add(out, identity) - out = self.relu(out) - - return out - - -class MakeLayer0(nn.Cell): - - def __init__(self, block, in_channels, out_channels, stride): - super(MakeLayer0, self).__init__() - self.a = ResidualBlockWithDown(in_channels, out_channels, stride=1, down_sample=True) - self.b = block(out_channels, out_channels, stride=stride) - self.c = block(out_channels, out_channels, stride=1) - - def construct(self, x): - x = self.a(x) - x = self.b(x) - x = self.c(x) - - return x - - -class MakeLayer1(nn.Cell): - - def __init__(self, block, in_channels, out_channels, stride): - super(MakeLayer1, self).__init__() - self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True) - self.b = block(out_channels, out_channels, stride=1) - self.c = block(out_channels, out_channels, stride=1) - self.d = block(out_channels, out_channels, stride=1) - - def construct(self, x): - x = self.a(x) - x = self.b(x) - x = self.c(x) - x = self.d(x) - - return x - - -class MakeLayer2(nn.Cell): - - def __init__(self, block, in_channels, out_channels, stride): - super(MakeLayer2, self).__init__() - self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True) - self.b = block(out_channels, out_channels, stride=1) - self.c = block(out_channels, out_channels, stride=1) - self.d = block(out_channels, out_channels, stride=1) - self.e = block(out_channels, out_channels, stride=1) - self.f = block(out_channels, out_channels, stride=1) - - def construct(self, x): - x = self.a(x) - x = self.b(x) - x = self.c(x) - x = self.d(x) - x = self.e(x) - x = self.f(x) - - return x - - -class MakeLayer3(nn.Cell): - - def __init__(self, block, in_channels, out_channels, stride): - super(MakeLayer3, self).__init__() - self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True) - self.b = block(out_channels, out_channels, stride=1) - self.c = block(out_channels, out_channels, stride=1) - - def construct(self, x): - x = self.a(x) - x = self.b(x) - x = self.c(x) - - return x - - -class ResNet(nn.Cell): - - def __init__(self, block, num_classes=100, batch_size=32): - super(ResNet, self).__init__() - self.batch_size = batch_size - self.num_classes = num_classes - - self.conv1 = conv7x7(3, 64, stride=2, padding=0) - - self.bn1 = bn_with_initialize(64) - self.relu = P.ReLU() - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="SAME") - - self.layer1 = MakeLayer0(block, in_channels=64, out_channels=256, stride=1) - self.layer2 = MakeLayer1(block, in_channels=256, out_channels=512, stride=2) - self.layer3 = MakeLayer2(block, in_channels=512, out_channels=1024, stride=2) - self.layer4 = MakeLayer3(block, in_channels=1024, out_channels=2048, stride=2) - - self.pool = P.ReduceMean(keep_dims=True) - self.fc = fc_with_initialize(512 * block.expansion, num_classes) - self.flatten = nn.Flatten() - - def construct(self, x): - x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) - x = self.maxpool(x) - - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) - - x = self.pool(x, (-2, -1)) - x = self.flatten(x) - x = self.fc(x) - return x - - -def resnet50(batch_size, num_classes): - return ResNet(ResidualBlock, num_classes, batch_size) +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np + +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + + +def weight_variable(shape): + ones = np.ones(shape).astype(np.float32) + return Tensor(ones * 0.01) + + +def weight_variable_0(shape): + zeros = np.zeros(shape).astype(np.float32) + return Tensor(zeros) + + +def weight_variable_1(shape): + ones = np.ones(shape).astype(np.float32) + return Tensor(ones) + + +def conv3x3(in_channels, out_channels, stride=1, padding=0): + """3x3 convolution """ + weight_shape = (out_channels, in_channels, 3, 3) + weight = weight_variable(weight_shape) + return nn.Conv2d(in_channels, out_channels, + kernel_size=3, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same") + + +def conv1x1(in_channels, out_channels, stride=1, padding=0): + """1x1 convolution""" + weight_shape = (out_channels, in_channels, 1, 1) + weight = weight_variable(weight_shape) + return nn.Conv2d(in_channels, out_channels, + kernel_size=1, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same") + + +def conv7x7(in_channels, out_channels, stride=1, padding=0): + """1x1 convolution""" + weight_shape = (out_channels, in_channels, 7, 7) + weight = weight_variable(weight_shape) + return nn.Conv2d(in_channels, out_channels, + kernel_size=7, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same") + + +def bn_with_initialize(out_channels): + shape = (out_channels) + mean = weight_variable_0(shape) + var = weight_variable_1(shape) + beta = weight_variable_0(shape) + gamma = weight_variable_1(shape) + bn = nn.BatchNorm2d(out_channels, momentum=0.1, eps=0.0001, gamma_init=gamma, + beta_init=beta, moving_mean_init=mean, moving_var_init=var) + return bn + + +def bn_with_initialize_last(out_channels): + shape = (out_channels) + mean = weight_variable_0(shape) + var = weight_variable_1(shape) + beta = weight_variable_0(shape) + gamma = weight_variable_0(shape) + bn = nn.BatchNorm2d(out_channels, momentum=0.1, eps=0.0001, gamma_init=gamma, + beta_init=beta, moving_mean_init=mean, moving_var_init=var) + return bn + + +def fc_with_initialize(input_channels, out_channels): + weight_shape = (out_channels, input_channels) + bias_shape = (out_channels) + weight = weight_variable(weight_shape) + bias = weight_variable_0(bias_shape) + + return nn.Dense(input_channels, out_channels, weight, bias) + + +class ResidualBlock(nn.Cell): + expansion = 4 + + def __init__(self, + in_channels, + out_channels, + stride=1): + super(ResidualBlock, self).__init__() + + out_chls = out_channels // self.expansion + self.conv1 = conv1x1(in_channels, out_chls, stride=1, padding=0) + self.bn1 = bn_with_initialize(out_chls) + + self.conv2 = conv3x3(out_chls, out_chls, stride=stride, padding=0) + self.bn2 = bn_with_initialize(out_chls) + + self.conv3 = conv1x1(out_chls, out_channels, stride=1, padding=0) + self.bn3 = bn_with_initialize_last(out_channels) + + self.relu = P.ReLU() + self.add = P.Add() + + def construct(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + out = self.add(out, identity) + out = self.relu(out) + + return out + + +class ResidualBlockWithDown(nn.Cell): + expansion = 4 + + def __init__(self, + in_channels, + out_channels, + stride=1, + down_sample=False): + super(ResidualBlockWithDown, self).__init__() + + out_chls = out_channels // self.expansion + self.conv1 = conv1x1(in_channels, out_chls, stride=1, padding=0) + self.bn1 = bn_with_initialize(out_chls) + + self.conv2 = conv3x3(out_chls, out_chls, stride=stride, padding=0) + self.bn2 = bn_with_initialize(out_chls) + + self.conv3 = conv1x1(out_chls, out_channels, stride=1, padding=0) + self.bn3 = bn_with_initialize_last(out_channels) + + self.relu = P.ReLU() + self.downSample = down_sample + + self.conv_down_sample = conv1x1(in_channels, out_channels, stride=stride, padding=0) + self.bn_down_sample = bn_with_initialize(out_channels) + self.add = P.Add() + + def construct(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + identity = self.conv_down_sample(identity) + identity = self.bn_down_sample(identity) + + out = self.add(out, identity) + out = self.relu(out) + + return out + + +class MakeLayer0(nn.Cell): + + def __init__(self, block, in_channels, out_channels, stride): + super(MakeLayer0, self).__init__() + self.a = ResidualBlockWithDown(in_channels, out_channels, stride=1, down_sample=True) + self.b = block(out_channels, out_channels, stride=stride) + self.c = block(out_channels, out_channels, stride=1) + + def construct(self, x): + x = self.a(x) + x = self.b(x) + x = self.c(x) + + return x + + +class MakeLayer1(nn.Cell): + + def __init__(self, block, in_channels, out_channels, stride): + super(MakeLayer1, self).__init__() + self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True) + self.b = block(out_channels, out_channels, stride=1) + self.c = block(out_channels, out_channels, stride=1) + self.d = block(out_channels, out_channels, stride=1) + + def construct(self, x): + x = self.a(x) + x = self.b(x) + x = self.c(x) + x = self.d(x) + + return x + + +class MakeLayer2(nn.Cell): + + def __init__(self, block, in_channels, out_channels, stride): + super(MakeLayer2, self).__init__() + self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True) + self.b = block(out_channels, out_channels, stride=1) + self.c = block(out_channels, out_channels, stride=1) + self.d = block(out_channels, out_channels, stride=1) + self.e = block(out_channels, out_channels, stride=1) + self.f = block(out_channels, out_channels, stride=1) + + def construct(self, x): + x = self.a(x) + x = self.b(x) + x = self.c(x) + x = self.d(x) + x = self.e(x) + x = self.f(x) + + return x + + +class MakeLayer3(nn.Cell): + + def __init__(self, block, in_channels, out_channels, stride): + super(MakeLayer3, self).__init__() + self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True) + self.b = block(out_channels, out_channels, stride=1) + self.c = block(out_channels, out_channels, stride=1) + + def construct(self, x): + x = self.a(x) + x = self.b(x) + x = self.c(x) + + return x + + +class ResNet(nn.Cell): + + def __init__(self, block, num_classes=100, batch_size=32): + super(ResNet, self).__init__() + self.batch_size = batch_size + self.num_classes = num_classes + + self.conv1 = conv7x7(3, 64, stride=2, padding=0) + + self.bn1 = bn_with_initialize(64) + self.relu = P.ReLU() + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="SAME") + + self.layer1 = MakeLayer0(block, in_channels=64, out_channels=256, stride=1) + self.layer2 = MakeLayer1(block, in_channels=256, out_channels=512, stride=2) + self.layer3 = MakeLayer2(block, in_channels=512, out_channels=1024, stride=2) + self.layer4 = MakeLayer3(block, in_channels=1024, out_channels=2048, stride=2) + + self.pool = P.ReduceMean(keep_dims=True) + self.fc = fc_with_initialize(512 * block.expansion, num_classes) + self.flatten = nn.Flatten() + + def construct(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.pool(x, (-2, -1)) + x = self.flatten(x) + x = self.fc(x) + return x + + +def resnet50(batch_size, num_classes): + return ResNet(ResidualBlock, num_classes, batch_size) diff --git a/tests/st/networks/test_network_main.py b/tests/st/networks/test_network_main.py index b83a5c23472..c17a900adf9 100644 --- a/tests/st/networks/test_network_main.py +++ b/tests/st/networks/test_network_main.py @@ -1,84 +1,84 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -Function: - test network -Usage: - python test_network_main.py --net lenet --target Ascend -""" -import argparse - -import numpy as np - -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.nn import TrainOneStepCell, WithLossCell -from mindspore.nn.optim import Momentum - -from .models.alexnet import AlexNet -from .models.lenet import LeNet -from .models.resnetv1_5 import resnet50 - - -def train(net, data, label): - learning_rate = 0.01 - momentum = 0.9 - - optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum) - criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True) - net_with_criterion = WithLossCell(net, criterion) - train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer - train_network.set_train() - res = train_network(data, label) - print(res) - assert res - - -def test_resnet50(): - data = Tensor(np.ones([32, 3, 224, 224]).astype(np.float32) * 0.01) - label = Tensor(np.ones([32]).astype(np.int32)) - net = resnet50(32, 10) - train(net, data, label) - - -def test_lenet(): - net = LeNet() - data = Tensor(np.ones([net.batch_size, 3, 32, 32]).astype(np.float32) * 0.01) - label = Tensor(np.ones([net.batch_size]).astype(np.int32)) - train(net, data, label) - - -def test_alexnet(): - data = Tensor(np.ones([32, 3, 227, 227]).astype(np.float32) * 0.01) - label = Tensor(np.ones([32]).astype(np.int32)) - net = AlexNet() - train(net, data, label) - - -parser = argparse.ArgumentParser(description='MindSpore Testing Network') -parser.add_argument('--net', default='resnet50', type=str, help='net name') -parser.add_argument('--device', default='Ascend', type=str, help='device target') -if __name__ == "__main__": - args = parser.parse_args() - context.set_context(device_target=args.device) - if args.net == 'resnet50': - test_resnet50() - elif args.net == 'lenet': - test_lenet() - elif args.net == 'alexnet': - test_alexnet() - else: - print("Please add net name like --net lenet") +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Function: + test network +Usage: + python test_network_main.py --net lenet --target Ascend +""" +import argparse + +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.nn import TrainOneStepCell, WithLossCell +from mindspore.nn.optim import Momentum + +from .models.alexnet import AlexNet +from .models.lenet import LeNet +from .models.resnetv1_5 import resnet50 + + +def train(net, data, label): + learning_rate = 0.01 + momentum = 0.9 + + optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum) + criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True) + net_with_criterion = WithLossCell(net, criterion) + train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer + train_network.set_train() + res = train_network(data, label) + print(res) + assert res + + +def test_resnet50(): + data = Tensor(np.ones([32, 3, 224, 224]).astype(np.float32) * 0.01) + label = Tensor(np.ones([32]).astype(np.int32)) + net = resnet50(32, 10) + train(net, data, label) + + +def test_lenet(): + net = LeNet() + data = Tensor(np.ones([net.batch_size, 3, 32, 32]).astype(np.float32) * 0.01) + label = Tensor(np.ones([net.batch_size]).astype(np.int32)) + train(net, data, label) + + +def test_alexnet(): + data = Tensor(np.ones([32, 3, 227, 227]).astype(np.float32) * 0.01) + label = Tensor(np.ones([32]).astype(np.int32)) + net = AlexNet() + train(net, data, label) + + +parser = argparse.ArgumentParser(description='MindSpore Testing Network') +parser.add_argument('--net', default='resnet50', type=str, help='net name') +parser.add_argument('--device', default='Ascend', type=str, help='device target') +if __name__ == "__main__": + args = parser.parse_args() + context.set_context(device_target=args.device) + if args.net == 'resnet50': + test_resnet50() + elif args.net == 'lenet': + test_lenet() + elif args.net == 'alexnet': + test_alexnet() + else: + print("Please add net name like --net lenet") diff --git a/tests/st/nn/test_avgpool3d.py b/tests/st/nn/test_avgpool3d.py index b86ddbaeab8..73a58933209 100644 --- a/tests/st/nn/test_avgpool3d.py +++ b/tests/st/nn/test_avgpool3d.py @@ -1,58 +1,58 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import pytest - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops -from tests.mark_utils import arg_mark - - -class Net(nn.Cell): - def __init__(self, kernel_size=1, stride=1, pad_mode="valid", padding=0, ceil_mode=False, count_include_pad=True, - divisor_override=None): - super(Net, self).__init__() - self.pool = nn.AvgPool3d(kernel_size=kernel_size, stride=stride, pad_mode=pad_mode, padding=padding, - ceil_mode=ceil_mode, count_include_pad=count_include_pad, - divisor_override=divisor_override) - - def construct(self, x): - out = self.pool(x) - return out - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu'], - level_mark='level2', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_avgpool3d_normal(mode): - """ - Feature: AvgPool3d - Description: Verify the result of AvgPool3d - Expectation: success - """ - ms.set_context(mode=mode) - x1 = ops.randn(1, 2, 4, 4, 5).astype(ms.float32) - pool1 = Net(kernel_size=3, stride=1) - output1 = pool1(x1) - - x2 = ops.randn(6, 5, 7, 7, 5).astype(ms.float32) - pool2 = Net(kernel_size=4, stride=2, pad_mode='pad', padding=(2, 2, 1), divisor_override=10) - output2 = pool2(x2) - - assert output1.shape == (1, 2, 2, 2, 3) - assert output2.shape == (6, 5, 4, 4, 2) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import pytest + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +from tests.mark_utils import arg_mark + + +class Net(nn.Cell): + def __init__(self, kernel_size=1, stride=1, pad_mode="valid", padding=0, ceil_mode=False, count_include_pad=True, + divisor_override=None): + super(Net, self).__init__() + self.pool = nn.AvgPool3d(kernel_size=kernel_size, stride=stride, pad_mode=pad_mode, padding=padding, + ceil_mode=ceil_mode, count_include_pad=count_include_pad, + divisor_override=divisor_override) + + def construct(self, x): + out = self.pool(x) + return out + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu'], + level_mark='level2', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_avgpool3d_normal(mode): + """ + Feature: AvgPool3d + Description: Verify the result of AvgPool3d + Expectation: success + """ + ms.set_context(mode=mode) + x1 = ops.randn(1, 2, 4, 4, 5).astype(ms.float32) + pool1 = Net(kernel_size=3, stride=1) + output1 = pool1(x1) + + x2 = ops.randn(6, 5, 7, 7, 5).astype(ms.float32) + pool2 = Net(kernel_size=4, stride=2, pad_mode='pad', padding=(2, 2, 1), divisor_override=10) + output2 = pool2(x2) + + assert output1.shape == (1, 2, 2, 2, 3) + assert output2.shape == (6, 5, 4, 4, 2) diff --git a/tests/st/nn/test_lppool1d.py b/tests/st/nn/test_lppool1d.py index 3ce6a2bc558..4fb7d4e5ef4 100644 --- a/tests/st/nn/test_lppool1d.py +++ b/tests/st/nn/test_lppool1d.py @@ -1,61 +1,61 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import numpy as np -import pytest - -import mindspore as ms -import mindspore.nn as nn -from tests.mark_utils import arg_mark - - -class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - self.pool = nn.LPPool1d(norm_type=1, kernel_size=3, stride=1) - - def construct(self, x): - out = self.pool(x) - return out - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu'], - level_mark='level2', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_lppool1d_normal(mode): - """ - Feature: LPPool1d - Description: Verify the result of LPPool1d - Expectation: success - """ - ms.set_context(mode=mode) - net = Net() - x = ms.Tensor(np.arange(2 * 3 * 4).reshape((2, 3, 4)), dtype=ms.float32) - y = ms.Tensor(np.arange(3 * 4).reshape((3, 4)), dtype=ms.float32) - out = net(x) - out2 = net(y) - expect_out = np.array([[[3., 6.], - [15., 18.], - [27., 30.]], - [[39., 42.], - [51., 54.], - [63., 66.]]]) - expect_out2 = np.array([[3., 6.], - [15., 18.], - [27., 30.]]) - assert np.allclose(out.asnumpy(), expect_out) - assert np.allclose(out2.asnumpy(), expect_out2) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore as ms +import mindspore.nn as nn +from tests.mark_utils import arg_mark + + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.pool = nn.LPPool1d(norm_type=1, kernel_size=3, stride=1) + + def construct(self, x): + out = self.pool(x) + return out + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu'], + level_mark='level2', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_lppool1d_normal(mode): + """ + Feature: LPPool1d + Description: Verify the result of LPPool1d + Expectation: success + """ + ms.set_context(mode=mode) + net = Net() + x = ms.Tensor(np.arange(2 * 3 * 4).reshape((2, 3, 4)), dtype=ms.float32) + y = ms.Tensor(np.arange(3 * 4).reshape((3, 4)), dtype=ms.float32) + out = net(x) + out2 = net(y) + expect_out = np.array([[[3., 6.], + [15., 18.], + [27., 30.]], + [[39., 42.], + [51., 54.], + [63., 66.]]]) + expect_out2 = np.array([[3., 6.], + [15., 18.], + [27., 30.]]) + assert np.allclose(out.asnumpy(), expect_out) + assert np.allclose(out2.asnumpy(), expect_out2) diff --git a/tests/st/nn/test_lppool2d.py b/tests/st/nn/test_lppool2d.py index 942489498e1..774da481822 100644 --- a/tests/st/nn/test_lppool2d.py +++ b/tests/st/nn/test_lppool2d.py @@ -1,61 +1,61 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import numpy as np -import pytest - -import mindspore as ms -import mindspore.nn as nn -from tests.mark_utils import arg_mark - - -class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - self.pool = nn.LPPool2d(norm_type=1, kernel_size=3, stride=1) - - def construct(self, x): - out = self.pool(x) - return out - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu'], - level_mark='level1', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_lppool2d_normal(mode): - """ - Feature: LPPool2d - Description: Verify the result of LPPool2d - Expectation: success - """ - ms.set_context(mode=mode) - net = Net() - x = ms.Tensor(np.arange(2 * 3 * 4 * 5).reshape((2, 3, 4, 5)), dtype=ms.float32) - out = net(x) - expect_out = np.array([[[[54., 63., 72.], - [99., 108., 117.]], - [[234., 243., 252.], - [279., 288., 297.]], - [[414., 423., 432.], - [459., 468., 477.]]], - [[[594., 603., 612.], - [639., 648., 657.]], - [[774., 783., 792.], - [819., 828., 837.]], - [[954., 963., 972.], - [999., 1008., 1017.]]]]) - assert np.allclose(out.asnumpy(), expect_out) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore as ms +import mindspore.nn as nn +from tests.mark_utils import arg_mark + + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.pool = nn.LPPool2d(norm_type=1, kernel_size=3, stride=1) + + def construct(self, x): + out = self.pool(x) + return out + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu'], + level_mark='level1', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_lppool2d_normal(mode): + """ + Feature: LPPool2d + Description: Verify the result of LPPool2d + Expectation: success + """ + ms.set_context(mode=mode) + net = Net() + x = ms.Tensor(np.arange(2 * 3 * 4 * 5).reshape((2, 3, 4, 5)), dtype=ms.float32) + out = net(x) + expect_out = np.array([[[[54., 63., 72.], + [99., 108., 117.]], + [[234., 243., 252.], + [279., 288., 297.]], + [[414., 423., 432.], + [459., 468., 477.]]], + [[[594., 603., 612.], + [639., 648., 657.]], + [[774., 783., 792.], + [819., 828., 837.]], + [[954., 963., 972.], + [999., 1008., 1017.]]]]) + assert np.allclose(out.asnumpy(), expect_out) diff --git a/tests/st/nn/test_margin_ranking_loss.py b/tests/st/nn/test_margin_ranking_loss.py index 36943ec28e9..ba3f021a24e 100644 --- a/tests/st/nn/test_margin_ranking_loss.py +++ b/tests/st/nn/test_margin_ranking_loss.py @@ -1,88 +1,88 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import numpy as np -import pytest - -import mindspore as ms -import mindspore.nn as nn -from mindspore import Tensor -from tests.mark_utils import arg_mark - - -class MarginRankingLoss(nn.Cell): - def __init__(self, reduction="none"): - super(MarginRankingLoss, self).__init__() - self.margin_ranking_loss = nn.MarginRankingLoss(margin=0.0, reduction=reduction) - - def construct(self, x, y, label): - return self.margin_ranking_loss(x, y, label) - - -input1 = Tensor(np.array([0.3864, -2.4093, -1.4076]), ms.float32) -input2 = Tensor(np.array([-0.6012, -1.6681, 1.2928]), ms.float32) -target = Tensor(np.array([-1, -1, 1]), ms.float32) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], - level_mark='level2', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -@pytest.mark.parametrize('reduction', ["none", "mean", "sum"]) -def test_margin_ranking_loss(mode, reduction): - """ - Feature: test MarginRankingLoss op with reduction none. - Description: Verify the result of MarginRankingLoss. - Expectation: expect correct forward result. - """ - ms.set_context(mode=mode) - loss = MarginRankingLoss(reduction=reduction) - output = loss(input1, input2, target) - if reduction == 'none': - expect_output = np.array([0.98759997, 0., 2.7003999]) - elif reduction == 'sum': - expect_output = np.array(3.6879997) - else: - expect_output = np.array(1.2293333) - - assert np.allclose(output.asnumpy(), expect_output) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], - level_mark='level2', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE]) -def test_tensor_dim(mode): - """ - Feature: test tensor dim - Description: Verify the result of dim. - Expectation: expect correct forward result. - """ - - class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - self.tensor = Tensor([[1, 2, 3], [4, 5, 6]]) - - def construct(self, x): - return x.dim(), self.tensor.dim() - - net = Net() - input11 = Tensor([[1, 2, 3], [4, 5, 6]]) - input22 = Tensor([[[1, 2, 3], [4, 5, 6]]]) - net(input11) - net(input22) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore as ms +import mindspore.nn as nn +from mindspore import Tensor +from tests.mark_utils import arg_mark + + +class MarginRankingLoss(nn.Cell): + def __init__(self, reduction="none"): + super(MarginRankingLoss, self).__init__() + self.margin_ranking_loss = nn.MarginRankingLoss(margin=0.0, reduction=reduction) + + def construct(self, x, y, label): + return self.margin_ranking_loss(x, y, label) + + +input1 = Tensor(np.array([0.3864, -2.4093, -1.4076]), ms.float32) +input2 = Tensor(np.array([-0.6012, -1.6681, 1.2928]), ms.float32) +target = Tensor(np.array([-1, -1, 1]), ms.float32) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], + level_mark='level2', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +@pytest.mark.parametrize('reduction', ["none", "mean", "sum"]) +def test_margin_ranking_loss(mode, reduction): + """ + Feature: test MarginRankingLoss op with reduction none. + Description: Verify the result of MarginRankingLoss. + Expectation: expect correct forward result. + """ + ms.set_context(mode=mode) + loss = MarginRankingLoss(reduction=reduction) + output = loss(input1, input2, target) + if reduction == 'none': + expect_output = np.array([0.98759997, 0., 2.7003999]) + elif reduction == 'sum': + expect_output = np.array(3.6879997) + else: + expect_output = np.array(1.2293333) + + assert np.allclose(output.asnumpy(), expect_output) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], + level_mark='level2', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE]) +def test_tensor_dim(mode): + """ + Feature: test tensor dim + Description: Verify the result of dim. + Expectation: expect correct forward result. + """ + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.tensor = Tensor([[1, 2, 3], [4, 5, 6]]) + + def construct(self, x): + return x.dim(), self.tensor.dim() + + net = Net() + input11 = Tensor([[1, 2, 3], [4, 5, 6]]) + input22 = Tensor([[[1, 2, 3], [4, 5, 6]]]) + net(input11) + net(input22) diff --git a/tests/st/nn/test_maxpool3d.py b/tests/st/nn/test_maxpool3d.py index 36df976bc6c..821a8af60d9 100644 --- a/tests/st/nn/test_maxpool3d.py +++ b/tests/st/nn/test_maxpool3d.py @@ -1,57 +1,57 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import numpy as np -import pytest - -import mindspore as ms -import mindspore.nn as nn -from tests.mark_utils import arg_mark - - -class Net(nn.Cell): - def __init__(self, kernel_size=1, stride=1, pad_mode="valid", padding=0, dilation=1, return_indices=False, - ceil_mode=False): - super(Net, self).__init__() - self.pool = nn.MaxPool3d(kernel_size=kernel_size, stride=stride, pad_mode=pad_mode, padding=padding, - dilation=dilation, return_indices=return_indices, ceil_mode=ceil_mode) - - def construct(self, x): - out = self.pool(x) - return out - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], - level_mark='level2', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_maxpool3d_normal(mode): - """ - Feature: MaxPool3d - Description: Verify the result of MaxPool3d - Expectation: success - """ - ms.set_context(mode=mode) - x = ms.Tensor(np.random.randint(0, 10, [5, 3, 4, 6, 7]), dtype=ms.float32) - pool1 = Net(kernel_size=2, stride=1, pad_mode='pad', padding=1, dilation=3, return_indices=False) - output1 = pool1(x) - - pool2 = Net(kernel_size=2, stride=1, pad_mode='pad', padding=1, dilation=3, return_indices=True) - output2 = pool2(x) - - assert output1.shape == (5, 3, 3, 5, 6) - assert output2[0].shape == (5, 3, 3, 5, 6) - assert output2[1].shape == (5, 3, 3, 5, 6) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore as ms +import mindspore.nn as nn +from tests.mark_utils import arg_mark + + +class Net(nn.Cell): + def __init__(self, kernel_size=1, stride=1, pad_mode="valid", padding=0, dilation=1, return_indices=False, + ceil_mode=False): + super(Net, self).__init__() + self.pool = nn.MaxPool3d(kernel_size=kernel_size, stride=stride, pad_mode=pad_mode, padding=padding, + dilation=dilation, return_indices=return_indices, ceil_mode=ceil_mode) + + def construct(self, x): + out = self.pool(x) + return out + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], + level_mark='level2', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_maxpool3d_normal(mode): + """ + Feature: MaxPool3d + Description: Verify the result of MaxPool3d + Expectation: success + """ + ms.set_context(mode=mode) + x = ms.Tensor(np.random.randint(0, 10, [5, 3, 4, 6, 7]), dtype=ms.float32) + pool1 = Net(kernel_size=2, stride=1, pad_mode='pad', padding=1, dilation=3, return_indices=False) + output1 = pool1(x) + + pool2 = Net(kernel_size=2, stride=1, pad_mode='pad', padding=1, dilation=3, return_indices=True) + output2 = pool2(x) + + assert output1.shape == (5, 3, 3, 5, 6) + assert output2[0].shape == (5, 3, 3, 5, 6) + assert output2[1].shape == (5, 3, 3, 5, 6) diff --git a/tests/st/nn/test_nn_reflectionpad.py b/tests/st/nn/test_nn_reflectionpad.py index 45f4cfcd780..07d8c2a468e 100644 --- a/tests/st/nn/test_nn_reflectionpad.py +++ b/tests/st/nn/test_nn_reflectionpad.py @@ -1,171 +1,171 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import numpy as np -import pytest - -import mindspore.nn as nn -import mindspore.context as context -from mindspore import Tensor -from tests.mark_utils import arg_mark - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], - level_mark='level2', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -def test_reflection_pad1d_input3d(mode): - """ - Feature: ReflectionPad1d - Description: Test ReflectionPad1d with 3D input. - Expectation: success - """ - context.set_context(mode=mode) - x = Tensor(np.array([[[0, 1, 2, 3], [4, 5, 6, 7]]]).astype(np.float32)) - padding = (3, 1) - net = nn.ReflectionPad1d(padding) - output = net(x) - expected_output = Tensor(np.array([[[3, 2, 1, 0, 1, 2, 3, 2], - [7, 6, 5, 4, 5, 6, 7, 6]]]).astype(np.float32)) - - assert np.array_equal(output.asnumpy(), expected_output) - - padding = 2 - expected_output = Tensor(np.array([[[2, 1, 0, 1, 2, 3, 2, 1], - [6, 5, 4, 5, 6, 7, 6, 5]]]).astype(np.float32)) - net = nn.ReflectionPad1d(padding) - output = net(x) - assert np.array_equal(output.asnumpy(), expected_output) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], - level_mark='level2', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -def test_reflection_pad1d_input2d(mode): - """ - Feature: ReflectionPad1d - Description: Test ReflectionPad1d with 2D input. - Expectation: success - """ - context.set_context(mode=mode) - x = Tensor(np.array([[0, 1, 2, 3], [4, 5, 6, 7]]).astype(np.float32)) - padding = (3, 1) - net = nn.ReflectionPad1d(padding) - output = net(x) - expected_output = Tensor(np.array([[3, 2, 1, 0, 1, 2, 3, 2], - [7, 6, 5, 4, 5, 6, 7, 6]]).astype(np.float32)) - assert np.array_equal(output.asnumpy(), expected_output) - - padding = 2 - expected_output = Tensor(np.array([[2, 1, 0, 1, 2, 3, 2, 1], - [6, 5, 4, 5, 6, 7, 6, 5]]).astype(np.float32)) - net = nn.ReflectionPad1d(padding) - output = net(x) - assert np.array_equal(output.asnumpy(), expected_output) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], - level_mark='level2', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -def test_reflection_pad2d_input4d(mode): - r""" - Feature: ReflectionPad2d - Description: Test ReflectionPad2d with 4D input. - Expectation: success - """ - context.set_context(mode=mode) - x = Tensor(np.array([[[[0, 1, 2], [3, 4, 5], [6, 7, 8]]]]).astype(np.float32)) - padding = (1, 1, 2, 0) - net = nn.ReflectionPad2d(padding) - output = net(x) - expected_output = Tensor(np.array([[[[7, 6, 7, 8, 7], [4, 3, 4, 5, 4], [1, 0, 1, 2, 1], - [4, 3, 4, 5, 4], [7, 6, 7, 8, 7]]]]).astype(np.float32)) - assert np.array_equal(output.asnumpy(), expected_output) - - padding = 2 - output = nn.ReflectionPad2d(padding)(x) - expected_output = Tensor(np.array([[[[8, 7, 6, 7, 8, 7, 6], [5, 4, 3, 4, 5, 4, 3], - [2, 1, 0, 1, 2, 1, 0], [5, 4, 3, 4, 5, 4, 3], - [8, 7, 6, 7, 8, 7, 6], [5, 4, 3, 4, 5, 4, 3], - [2, 1, 0, 1, 2, 1, 0]]]]).astype(np.float32)) - assert np.array_equal(output.asnumpy(), expected_output) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], - level_mark='level2', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -def test_reflection_pad2d_input3d(mode): - r""" - Feature: ReflectionPad2d - Description: Test ReflectionPad2d with 3D input. - Expectation: success - """ - context.set_context(mode=mode) - x = Tensor(np.array([[[0, 1, 2], [3, 4, 5], [6, 7, 8]]]).astype(np.float32)) - padding = (1, 1, 2, 0) - net = nn.ReflectionPad2d(padding) - output = net(x) - expected_output = Tensor(np.array([[[7, 6, 7, 8, 7], [4, 3, 4, 5, 4], [1, 0, 1, 2, 1], - [4, 3, 4, 5, 4], [7, 6, 7, 8, 7]]]).astype(np.float32)) - - padding = 2 - output = nn.ReflectionPad2d(padding)(x) - - expected_output = Tensor(np.array([[[8, 7, 6, 7, 8, 7, 6], [5, 4, 3, 4, 5, 4, 3], - [2, 1, 0, 1, 2, 1, 0], [5, 4, 3, 4, 5, 4, 3], - [8, 7, 6, 7, 8, 7, 6], [5, 4, 3, 4, 5, 4, 3], - [2, 1, 0, 1, 2, 1, 0]]]).astype(np.float32)) - assert np.array_equal(output.asnumpy(), expected_output) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], - level_mark='level2', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -def test_reflection_pad_3d(mode): - """ - Feature: ReflectionPad3d - Description: Infer process of ReflectionPad3d with three type parameters. - Expectation: success - """ - context.set_context(mode=mode) - arr = np.arange(8).astype(np.float32).reshape((1, 2, 2, 2)) - x = Tensor(arr) - padding = (1, 1, 1, 0, 0, 1) - net3d = nn.ReflectionPad3d(padding) - output = net3d(x) - expected_output = Tensor(np.array([[[[3, 2, 3, 2], [1, 0, 1, 0], [3, 2, 3, 2]], - [[7, 6, 7, 6], [5, 4, 5, 4], [7, 6, 7, 6]], - [[3, 2, 3, 2], [1, 0, 1, 0], [3, 2, 3, 2]]]]).astype(np.float32)) - assert np.array_equal(output.asnumpy(), expected_output) - - padding = 1 - output = nn.ReflectionPad3d(padding)(x) - expected_output = Tensor(np.array([[[[7., 6., 7., 6.], [5., 4., 5., 4.], - [7., 6., 7., 6.], [5., 4., 5., 4.]], - [[3., 2., 3., 2.], [1., 0., 1., 0.], - [3., 2., 3., 2.], [1., 0., 1., 0.]], - [[7., 6., 7., 6.], [5., 4., 5., 4.], - [7., 6., 7., 6.], [5., 4., 5., 4.]], - [[3., 2., 3., 2.], [1., 0., 1., 0.], - [3., 2., 3., 2.], [1., 0., 1., 0.]]]]).astype(np.float32)) - assert np.array_equal(output.asnumpy(), expected_output) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest + +import mindspore.nn as nn +import mindspore.context as context +from mindspore import Tensor +from tests.mark_utils import arg_mark + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], + level_mark='level2', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_reflection_pad1d_input3d(mode): + """ + Feature: ReflectionPad1d + Description: Test ReflectionPad1d with 3D input. + Expectation: success + """ + context.set_context(mode=mode) + x = Tensor(np.array([[[0, 1, 2, 3], [4, 5, 6, 7]]]).astype(np.float32)) + padding = (3, 1) + net = nn.ReflectionPad1d(padding) + output = net(x) + expected_output = Tensor(np.array([[[3, 2, 1, 0, 1, 2, 3, 2], + [7, 6, 5, 4, 5, 6, 7, 6]]]).astype(np.float32)) + + assert np.array_equal(output.asnumpy(), expected_output) + + padding = 2 + expected_output = Tensor(np.array([[[2, 1, 0, 1, 2, 3, 2, 1], + [6, 5, 4, 5, 6, 7, 6, 5]]]).astype(np.float32)) + net = nn.ReflectionPad1d(padding) + output = net(x) + assert np.array_equal(output.asnumpy(), expected_output) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], + level_mark='level2', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_reflection_pad1d_input2d(mode): + """ + Feature: ReflectionPad1d + Description: Test ReflectionPad1d with 2D input. + Expectation: success + """ + context.set_context(mode=mode) + x = Tensor(np.array([[0, 1, 2, 3], [4, 5, 6, 7]]).astype(np.float32)) + padding = (3, 1) + net = nn.ReflectionPad1d(padding) + output = net(x) + expected_output = Tensor(np.array([[3, 2, 1, 0, 1, 2, 3, 2], + [7, 6, 5, 4, 5, 6, 7, 6]]).astype(np.float32)) + assert np.array_equal(output.asnumpy(), expected_output) + + padding = 2 + expected_output = Tensor(np.array([[2, 1, 0, 1, 2, 3, 2, 1], + [6, 5, 4, 5, 6, 7, 6, 5]]).astype(np.float32)) + net = nn.ReflectionPad1d(padding) + output = net(x) + assert np.array_equal(output.asnumpy(), expected_output) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], + level_mark='level2', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_reflection_pad2d_input4d(mode): + r""" + Feature: ReflectionPad2d + Description: Test ReflectionPad2d with 4D input. + Expectation: success + """ + context.set_context(mode=mode) + x = Tensor(np.array([[[[0, 1, 2], [3, 4, 5], [6, 7, 8]]]]).astype(np.float32)) + padding = (1, 1, 2, 0) + net = nn.ReflectionPad2d(padding) + output = net(x) + expected_output = Tensor(np.array([[[[7, 6, 7, 8, 7], [4, 3, 4, 5, 4], [1, 0, 1, 2, 1], + [4, 3, 4, 5, 4], [7, 6, 7, 8, 7]]]]).astype(np.float32)) + assert np.array_equal(output.asnumpy(), expected_output) + + padding = 2 + output = nn.ReflectionPad2d(padding)(x) + expected_output = Tensor(np.array([[[[8, 7, 6, 7, 8, 7, 6], [5, 4, 3, 4, 5, 4, 3], + [2, 1, 0, 1, 2, 1, 0], [5, 4, 3, 4, 5, 4, 3], + [8, 7, 6, 7, 8, 7, 6], [5, 4, 3, 4, 5, 4, 3], + [2, 1, 0, 1, 2, 1, 0]]]]).astype(np.float32)) + assert np.array_equal(output.asnumpy(), expected_output) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], + level_mark='level2', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_reflection_pad2d_input3d(mode): + r""" + Feature: ReflectionPad2d + Description: Test ReflectionPad2d with 3D input. + Expectation: success + """ + context.set_context(mode=mode) + x = Tensor(np.array([[[0, 1, 2], [3, 4, 5], [6, 7, 8]]]).astype(np.float32)) + padding = (1, 1, 2, 0) + net = nn.ReflectionPad2d(padding) + output = net(x) + expected_output = Tensor(np.array([[[7, 6, 7, 8, 7], [4, 3, 4, 5, 4], [1, 0, 1, 2, 1], + [4, 3, 4, 5, 4], [7, 6, 7, 8, 7]]]).astype(np.float32)) + + padding = 2 + output = nn.ReflectionPad2d(padding)(x) + + expected_output = Tensor(np.array([[[8, 7, 6, 7, 8, 7, 6], [5, 4, 3, 4, 5, 4, 3], + [2, 1, 0, 1, 2, 1, 0], [5, 4, 3, 4, 5, 4, 3], + [8, 7, 6, 7, 8, 7, 6], [5, 4, 3, 4, 5, 4, 3], + [2, 1, 0, 1, 2, 1, 0]]]).astype(np.float32)) + assert np.array_equal(output.asnumpy(), expected_output) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], + level_mark='level2', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_reflection_pad_3d(mode): + """ + Feature: ReflectionPad3d + Description: Infer process of ReflectionPad3d with three type parameters. + Expectation: success + """ + context.set_context(mode=mode) + arr = np.arange(8).astype(np.float32).reshape((1, 2, 2, 2)) + x = Tensor(arr) + padding = (1, 1, 1, 0, 0, 1) + net3d = nn.ReflectionPad3d(padding) + output = net3d(x) + expected_output = Tensor(np.array([[[[3, 2, 3, 2], [1, 0, 1, 0], [3, 2, 3, 2]], + [[7, 6, 7, 6], [5, 4, 5, 4], [7, 6, 7, 6]], + [[3, 2, 3, 2], [1, 0, 1, 0], [3, 2, 3, 2]]]]).astype(np.float32)) + assert np.array_equal(output.asnumpy(), expected_output) + + padding = 1 + output = nn.ReflectionPad3d(padding)(x) + expected_output = Tensor(np.array([[[[7., 6., 7., 6.], [5., 4., 5., 4.], + [7., 6., 7., 6.], [5., 4., 5., 4.]], + [[3., 2., 3., 2.], [1., 0., 1., 0.], + [3., 2., 3., 2.], [1., 0., 1., 0.]], + [[7., 6., 7., 6.], [5., 4., 5., 4.], + [7., 6., 7., 6.], [5., 4., 5., 4.]], + [[3., 2., 3., 2.], [1., 0., 1., 0.], + [3., 2., 3., 2.], [1., 0., 1., 0.]]]]).astype(np.float32)) + assert np.array_equal(output.asnumpy(), expected_output) diff --git a/tests/st/nontask_sink/test_allreduce.py b/tests/st/nontask_sink/test_allreduce.py index 338715c0d5f..0a654dbd885 100644 --- a/tests/st/nontask_sink/test_allreduce.py +++ b/tests/st/nontask_sink/test_allreduce.py @@ -1,100 +1,100 @@ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""test hccl AllReduce and all_reduce with 8p""" - -import os -import numpy as np -import mindspore.nn as nn -from mindspore import Tensor -from mindspore import dtype as mstype -from mindspore.ops import operations as P -from mindspore.communication.management import init -from mindspore.communication.comm_func import all_reduce -from mindspore import context -from mindspore.ops import ReduceOp - -np.random.seed(1) -os.environ['HCCL_WHITELIST_DISABLE'] = str(1) -context.set_context(jit_level='O0') -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") -init() - - -class AllReduceNet(nn.Cell): - def __init__(self): - super(AllReduceNet, self).__init__() - self.mul = P.Mul() - self.all_reduce = P.AllReduce() - self.add = P.Add() - self.y1 = Tensor(np.array([[2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2]])).astype(np.float32) - self.y2 = Tensor(np.array([[-16, -16, -16, -16], [-16, -16, -16, -16], \ - [-16, -16, -16, -16]])).astype(np.float32) - - def construct(self, x): - x = self.mul(x, 2) - z = self.add(x, self.y1) - z = self.all_reduce(z) - out = self.add(z, self.y2) - out = self.all_reduce(out) - out = self.mul(out, 2) - return out - - -class AllReduceFuncNet(nn.Cell): - def __init__(self, op=ReduceOp.SUM): - super(AllReduceFuncNet, self).__init__() - self.op = op - - def construct(self, x): - return all_reduce(x) - - -def test_hccl_allreduce_8p(): - """ - Feature: test 'AllReduce' communication operation. - Description: test 'AllReduce' communication operation. - Expectation: expect correct result. - """ - net = AllReduceNet() - input_x = np.ones([3, 4]).astype(np.float32) - expect_output = [[256, 256, 256, 256], [256, 256, 256, 256], [256, 256, 256, 256]] - output = net(Tensor(input_x, mstype.float32)) - assert np.allclose(output.asnumpy(), expect_output) - - -def test_hccl_allreduce_func_net_8p(): - """ - Feature: test 'all_reduce' communication function in cell. - Description: test 'all_reduce' communication function in cell. - Expectation: expect correct result. - """ - net = AllReduceFuncNet() - input_x = np.ones([3, 4]).astype(np.float32) - expect_output = [[8, 8, 8, 8], [8, 8, 8, 8], [8, 8, 8, 8]] - output = net(Tensor(input_x, mstype.float32)) - assert np.allclose(output.asnumpy(), expect_output) - - -def test_hccl_allreduce_func_8p(): - """ - Feature: test 'all_reduce' communication function. - Description: test 'all_reduce' communication function. - Expectation: expect correct result. - """ - x = np.ones([3, 4]).astype(np.float32) - expect_output = [[8, 8, 8, 8], [8, 8, 8, 8], [8, 8, 8, 8]] - output = all_reduce(Tensor(x, mstype.float32)) - assert np.allclose(output.asnumpy(), expect_output) +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""test hccl AllReduce and all_reduce with 8p""" + +import os +import numpy as np +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import dtype as mstype +from mindspore.ops import operations as P +from mindspore.communication.management import init +from mindspore.communication.comm_func import all_reduce +from mindspore import context +from mindspore.ops import ReduceOp + +np.random.seed(1) +os.environ['HCCL_WHITELIST_DISABLE'] = str(1) +context.set_context(jit_level='O0') +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") +init() + + +class AllReduceNet(nn.Cell): + def __init__(self): + super(AllReduceNet, self).__init__() + self.mul = P.Mul() + self.all_reduce = P.AllReduce() + self.add = P.Add() + self.y1 = Tensor(np.array([[2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2]])).astype(np.float32) + self.y2 = Tensor(np.array([[-16, -16, -16, -16], [-16, -16, -16, -16], \ + [-16, -16, -16, -16]])).astype(np.float32) + + def construct(self, x): + x = self.mul(x, 2) + z = self.add(x, self.y1) + z = self.all_reduce(z) + out = self.add(z, self.y2) + out = self.all_reduce(out) + out = self.mul(out, 2) + return out + + +class AllReduceFuncNet(nn.Cell): + def __init__(self, op=ReduceOp.SUM): + super(AllReduceFuncNet, self).__init__() + self.op = op + + def construct(self, x): + return all_reduce(x) + + +def test_hccl_allreduce_8p(): + """ + Feature: test 'AllReduce' communication operation. + Description: test 'AllReduce' communication operation. + Expectation: expect correct result. + """ + net = AllReduceNet() + input_x = np.ones([3, 4]).astype(np.float32) + expect_output = [[256, 256, 256, 256], [256, 256, 256, 256], [256, 256, 256, 256]] + output = net(Tensor(input_x, mstype.float32)) + assert np.allclose(output.asnumpy(), expect_output) + + +def test_hccl_allreduce_func_net_8p(): + """ + Feature: test 'all_reduce' communication function in cell. + Description: test 'all_reduce' communication function in cell. + Expectation: expect correct result. + """ + net = AllReduceFuncNet() + input_x = np.ones([3, 4]).astype(np.float32) + expect_output = [[8, 8, 8, 8], [8, 8, 8, 8], [8, 8, 8, 8]] + output = net(Tensor(input_x, mstype.float32)) + assert np.allclose(output.asnumpy(), expect_output) + + +def test_hccl_allreduce_func_8p(): + """ + Feature: test 'all_reduce' communication function. + Description: test 'all_reduce' communication function. + Expectation: expect correct result. + """ + x = np.ones([3, 4]).astype(np.float32) + expect_output = [[8, 8, 8, 8], [8, 8, 8, 8], [8, 8, 8, 8]] + output = all_reduce(Tensor(x, mstype.float32)) + assert np.allclose(output.asnumpy(), expect_output) diff --git a/tests/st/nontask_sink/test_barrier.py b/tests/st/nontask_sink/test_barrier.py index d2a08947503..71c5d05f0b0 100644 --- a/tests/st/nontask_sink/test_barrier.py +++ b/tests/st/nontask_sink/test_barrier.py @@ -1,72 +1,72 @@ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""test hccl reduce with 8p""" - -import time -import numpy as np -import mindspore as ms -import mindspore.nn as nn -from mindspore.ops.operations import comm_ops -from mindspore.communication.comm_func import barrier -from mindspore.communication.management import init, get_rank - -# 'Barrier' operator only supports KernelByKernel mode by now. -np.random.seed(1) -ms.set_context(jit_level='O0') -init() -rank = get_rank() - -class BarrierNet(nn.Cell): - def __init__(self): - super(BarrierNet, self).__init__() - self.barrier = comm_ops.Barrier() - - def construct(self): - self.barrier() - -class BarrierFuncNet(nn.Cell): - def construct(self): - barrier() - -def test_hccl_barrier_8p(): - """ - Feature: test 'Barrier' collective communication operator. - Description: test 'Barrier' collective communication operator. - Expectation: all processes in the group synchronize in this operator. - """ - net = BarrierNet() - if rank == 3: - time.sleep(3) - if rank == 4: - time.sleep(6) - print("Process {} start time: {}".format(rank, time.strftime('%Y-%m-%d-%H:%M:%S', time.localtime(time.time())))) - net() - print("Process {} end time: {}".format(rank, time.strftime('%Y-%m-%d-%H:%M:%S', time.localtime(time.time())))) - -def test_hccl_barrier_func_8p(): - """ - Feature: test 'Barrier' collective communication operator. - Description: test 'Barrier' collective communication operator. - Expectation: all processes in the group synchronize in this operator. - """ - net = BarrierFuncNet() - if rank == 3: - time.sleep(3) - if rank == 4: - time.sleep(6) - print("Process {} start time: {}".format(rank, time.strftime('%Y-%m-%d-%H:%M:%S', time.localtime(time.time())))) - net() - print("Process {} end time: {}".format(rank, time.strftime('%Y-%m-%d-%H:%M:%S', time.localtime(time.time())))) +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""test hccl reduce with 8p""" + +import time +import numpy as np +import mindspore as ms +import mindspore.nn as nn +from mindspore.ops.operations import comm_ops +from mindspore.communication.comm_func import barrier +from mindspore.communication.management import init, get_rank + +# 'Barrier' operator only supports KernelByKernel mode by now. +np.random.seed(1) +ms.set_context(jit_level='O0') +init() +rank = get_rank() + +class BarrierNet(nn.Cell): + def __init__(self): + super(BarrierNet, self).__init__() + self.barrier = comm_ops.Barrier() + + def construct(self): + self.barrier() + +class BarrierFuncNet(nn.Cell): + def construct(self): + barrier() + +def test_hccl_barrier_8p(): + """ + Feature: test 'Barrier' collective communication operator. + Description: test 'Barrier' collective communication operator. + Expectation: all processes in the group synchronize in this operator. + """ + net = BarrierNet() + if rank == 3: + time.sleep(3) + if rank == 4: + time.sleep(6) + print("Process {} start time: {}".format(rank, time.strftime('%Y-%m-%d-%H:%M:%S', time.localtime(time.time())))) + net() + print("Process {} end time: {}".format(rank, time.strftime('%Y-%m-%d-%H:%M:%S', time.localtime(time.time())))) + +def test_hccl_barrier_func_8p(): + """ + Feature: test 'Barrier' collective communication operator. + Description: test 'Barrier' collective communication operator. + Expectation: all processes in the group synchronize in this operator. + """ + net = BarrierFuncNet() + if rank == 3: + time.sleep(3) + if rank == 4: + time.sleep(6) + print("Process {} start time: {}".format(rank, time.strftime('%Y-%m-%d-%H:%M:%S', time.localtime(time.time())))) + net() + print("Process {} end time: {}".format(rank, time.strftime('%Y-%m-%d-%H:%M:%S', time.localtime(time.time())))) diff --git a/tests/st/nontask_sink/test_lenet.py b/tests/st/nontask_sink/test_lenet.py index 9e607e304b9..e0186aa1f5f 100644 --- a/tests/st/nontask_sink/test_lenet.py +++ b/tests/st/nontask_sink/test_lenet.py @@ -1,297 +1,297 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import os -import time -import numpy as np - -import mindspore.nn as nn -from mindspore import context, Tensor, ParameterTuple -from mindspore.common import dtype as mstype -from mindspore.common.initializer import TruncatedNormal -from mindspore.nn.optim import Momentum -from mindspore.nn.wrap.cell_wrapper import WithLossCell -from mindspore.ops import composite as C -from mindspore.ops import functional as F -from mindspore.ops import operations as P - -from tests.mark_utils import arg_mark - -np.random.seed(1) -grad_by_list = C.GradOperation(get_by_list=True) - - -def weight_variable(): - """weight initial""" - return TruncatedNormal(0.02) - - -def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): - """weight initial for conv layer""" - weight = weight_variable() - return nn.Conv2d(in_channels, out_channels, - kernel_size=kernel_size, stride=stride, padding=padding, - weight_init=weight, has_bias=False, pad_mode="valid") - - -def fc_with_initialize(input_channels, out_channels): - """weight initial for fc layer""" - weight = weight_variable() - bias = weight_variable() - return nn.Dense(input_channels, out_channels, weight, bias) - - -class LeNet(nn.Cell): - """ - Lenet network - Args: - num_class (int): Num classes, Default: 10. - Returns: - Tensor, output tensor - Examples: - >>> LeNet(num_class=10) - """ - - def __init__(self, num_class=10): - super(LeNet, self).__init__() - self.num_class = num_class - self.batch_size = 32 - self.conv1 = conv(1, 6, 5) - self.conv2 = conv(6, 16, 5) - self.fc1 = fc_with_initialize(16 * 5 * 5, 120) - self.fc2 = fc_with_initialize(120, 84) - self.fc3 = fc_with_initialize(84, self.num_class) - self.relu = nn.ReLU() - self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) - self.reshape = P.Reshape() - - def construct(self, x): - x = self.conv1(x) - x = self.relu(x) - x = self.max_pool2d(x) - x = self.conv2(x) - x = self.relu(x) - x = self.max_pool2d(x) - x = self.reshape(x, (self.batch_size, -1)) - x = self.fc1(x) - x = self.relu(x) - x = self.fc2(x) - x = self.relu(x) - x = self.fc3(x) - return x - - -class CrossEntropyLoss(nn.Cell): - """ - Define loss for network - """ - - def __init__(self): - super(CrossEntropyLoss, self).__init__() - self.cross_entropy = P.SoftmaxCrossEntropyWithLogits() - self.mean = P.ReduceMean() - self.one_hot = P.OneHot() - self.on_value = Tensor(1.0, mstype.float32) - self.off_value = Tensor(0.0, mstype.float32) - self.num = Tensor(32.0, mstype.float32) - - def construct(self, logits, label): - label = self.one_hot(label, F.shape(logits)[1], self.on_value, self.off_value) - loss = self.cross_entropy(logits, label)[0] - loss = P.RealDiv()(P.ReduceSum()(loss, -1), self.num) - return loss - - -class GradWrap(nn.Cell): - """ - GradWrap definition - """ - - def __init__(self, network): - super(GradWrap, self).__init__() - self.network = network - self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters())) - - def construct(self, x, label): - weights = self.weights - return grad_by_list(self.network, weights)(x, label) - - -def test_ascend_lenet(): - epoch_size = 20 - batch_size = 32 - inputs = Tensor(np.ones([batch_size, 1, 32, 32]).astype(np.float32)) - labels = Tensor(np.ones([batch_size]).astype(np.int32)) - - net = LeNet() - criterion = CrossEntropyLoss() - optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9) - - net_with_criterion = WithLossCell(net, criterion) - train_network = GradWrap(net_with_criterion) - train_network.set_train() - total_time = 0 - - for epoch in range(0, epoch_size): - start_time = time.time() - fw_output = net(inputs) - loss_output = criterion(fw_output, labels) - grads = train_network(inputs, labels) - optimizer(grads) - end_time = time.time() - cost_time = end_time - start_time - total_time = total_time + cost_time - - print("======epoch: ", epoch, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time) - return loss_output - - -@arg_mark(plat_marks=["platform_ascend"], level_mark="level1", card_mark="onecard", essential_mark="essential") -def test_ascend_lenet1(): - context.set_context(jit_level='O0') - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") - loss_output = test_ascend_lenet() - assert loss_output.asnumpy() < 0.004 - assert loss_output.asnumpy() > 0.003 - - -@arg_mark(plat_marks=["platform_ascend"], level_mark="level1", card_mark="onecard", essential_mark="essential") -def test_ascend_lenet2(): - context.set_context(jit_level='O0') - context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") - loss_output = test_ascend_lenet() - assert loss_output.asnumpy() < 0.004 - assert loss_output.asnumpy() > 0.003 - - -@arg_mark(plat_marks=["platform_ascend"], level_mark="level1", card_mark="onecard", essential_mark="essential") -def test_ascend_lenet3(): - """ - Feature: Somas Ascend kernel by kernel. - Description: LeNet with Somas Ascend kernel by kernel. - Expectation: No exception. - """ - context.set_context(jit_level='O0') - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", memory_optimize_level='O1') - loss_output = test_ascend_lenet() - assert loss_output.asnumpy() < 0.004 - assert loss_output.asnumpy() > 0.003 - - -@arg_mark(plat_marks=["platform_ascend"], level_mark="level0", card_mark="onecard", essential_mark="essential") -def test_ascend_lenet4(): - """ - Feature: Ascend kernel by kernel and Ascend VMM. - Description: LeNet with Ascend kernel by kernel and VMM. - Expectation: No exception. - """ - context.set_context(jit_level='O0') - os.environ['MS_DEV_ENABLE_ASCEND_VMM'] = str(1) - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") - loss_output = test_ascend_lenet() - assert loss_output.asnumpy() < 0.004 - assert loss_output.asnumpy() > 0.003 - - -@arg_mark(plat_marks=["platform_ascend"], level_mark="level0", card_mark="onecard", essential_mark="essential") -def test_ascend_lenet5(): - """ - Feature: Ascend kernel by kernel and Ascend VMM. - Description: LeNet with Ascend kernel by kernel and VMM. - Expectation: No exception. - """ - context.set_context(jit_level='O0') - os.environ['MS_DEV_ENABLE_ASCEND_VMM'] = str(1) - os.environ['MS_DEV_ASCEND_VMM_ALIGN_SIZE'] = "20MB" - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") - loss_output = test_ascend_lenet() - assert loss_output.asnumpy() < 0.004 - assert loss_output.asnumpy() > 0.003 - - -class GradWrapTuple(nn.Cell): - """ - GradWrapTuple definition - """ - - def __init__(self, network): - super(GradWrapTuple, self).__init__() - self.network = network - self.weights = tuple(filter(lambda x: x.requires_grad, network.get_parameters())) - - def construct(self, x, label): - weights = self.weights - return grad_by_list(self.network, weights)(x, label) - - -def test_ascend_lenet_grad_by_list_tuple(): - """ - Feature: GradOperation get_by_list pass tuple/list - Description: Grad with Parameters as input type and fv. list or tuple as fv of grad. - Expectation: No exception. - """ - epoch_size = 20 - batch_size = 32 - inputs = Tensor(np.ones([batch_size, 1, 32, 32]).astype(np.float32)) - labels = Tensor(np.ones([batch_size]).astype(np.int32)) - - net = LeNet() - criterion = CrossEntropyLoss() - optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9) - - net_with_criterion = WithLossCell(net, criterion) - train_network = GradWrapTuple(net_with_criterion) - train_network.set_train() - total_time = 0 - - for epoch in range(0, epoch_size): - start_time = time.time() - fw_output = net(inputs) - loss_output = criterion(fw_output, labels) - grads = train_network(inputs, labels) - optimizer(grads) - end_time = time.time() - cost_time = end_time - start_time - total_time = total_time + cost_time - - print("======epoch: ", epoch, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time) - return loss_output - - -@arg_mark(plat_marks=["platform_ascend"], level_mark="level0", card_mark="onecard", essential_mark="essential") -def test_ascend_lenet_grad_by_list_tuple1(): - """ - Feature: GradOperation get_by_list pass tuple/list - Description: Grad with Parameters as input type and fv. list or tuple as fv of grad. - Expectation: No exception. - """ - context.set_context(jit_level='O0') - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") - loss_output = test_ascend_lenet_grad_by_list_tuple() - assert loss_output.asnumpy() < 0.004 - assert loss_output.asnumpy() > 0.003 - - -@arg_mark(plat_marks=["platform_ascend"], level_mark="level1", card_mark="onecard", essential_mark="essential") -def test_ascend_lenet_grad_by_list_tuple2(): - """ - Feature: GradOperation get_by_list pass tuple/list with Ascend kernel by kernel Somas. - Description: Grad with Parameters as input type and fv. list or tuple as fv of grad. - Expectation: No exception. - """ - context.set_context(jit_level='O0') - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", memory_optimize_level='O1') - loss_output = test_ascend_lenet_grad_by_list_tuple() - assert loss_output.asnumpy() < 0.004 - assert loss_output.asnumpy() > 0.003 +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import os +import time +import numpy as np + +import mindspore.nn as nn +from mindspore import context, Tensor, ParameterTuple +from mindspore.common import dtype as mstype +from mindspore.common.initializer import TruncatedNormal +from mindspore.nn.optim import Momentum +from mindspore.nn.wrap.cell_wrapper import WithLossCell +from mindspore.ops import composite as C +from mindspore.ops import functional as F +from mindspore.ops import operations as P + +from tests.mark_utils import arg_mark + +np.random.seed(1) +grad_by_list = C.GradOperation(get_by_list=True) + + +def weight_variable(): + """weight initial""" + return TruncatedNormal(0.02) + + +def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): + """weight initial for conv layer""" + weight = weight_variable() + return nn.Conv2d(in_channels, out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, + weight_init=weight, has_bias=False, pad_mode="valid") + + +def fc_with_initialize(input_channels, out_channels): + """weight initial for fc layer""" + weight = weight_variable() + bias = weight_variable() + return nn.Dense(input_channels, out_channels, weight, bias) + + +class LeNet(nn.Cell): + """ + Lenet network + Args: + num_class (int): Num classes, Default: 10. + Returns: + Tensor, output tensor + Examples: + >>> LeNet(num_class=10) + """ + + def __init__(self, num_class=10): + super(LeNet, self).__init__() + self.num_class = num_class + self.batch_size = 32 + self.conv1 = conv(1, 6, 5) + self.conv2 = conv(6, 16, 5) + self.fc1 = fc_with_initialize(16 * 5 * 5, 120) + self.fc2 = fc_with_initialize(120, 84) + self.fc3 = fc_with_initialize(84, self.num_class) + self.relu = nn.ReLU() + self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) + self.reshape = P.Reshape() + + def construct(self, x): + x = self.conv1(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.conv2(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.reshape(x, (self.batch_size, -1)) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.relu(x) + x = self.fc3(x) + return x + + +class CrossEntropyLoss(nn.Cell): + """ + Define loss for network + """ + + def __init__(self): + super(CrossEntropyLoss, self).__init__() + self.cross_entropy = P.SoftmaxCrossEntropyWithLogits() + self.mean = P.ReduceMean() + self.one_hot = P.OneHot() + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.0, mstype.float32) + self.num = Tensor(32.0, mstype.float32) + + def construct(self, logits, label): + label = self.one_hot(label, F.shape(logits)[1], self.on_value, self.off_value) + loss = self.cross_entropy(logits, label)[0] + loss = P.RealDiv()(P.ReduceSum()(loss, -1), self.num) + return loss + + +class GradWrap(nn.Cell): + """ + GradWrap definition + """ + + def __init__(self, network): + super(GradWrap, self).__init__() + self.network = network + self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters())) + + def construct(self, x, label): + weights = self.weights + return grad_by_list(self.network, weights)(x, label) + + +def test_ascend_lenet(): + epoch_size = 20 + batch_size = 32 + inputs = Tensor(np.ones([batch_size, 1, 32, 32]).astype(np.float32)) + labels = Tensor(np.ones([batch_size]).astype(np.int32)) + + net = LeNet() + criterion = CrossEntropyLoss() + optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9) + + net_with_criterion = WithLossCell(net, criterion) + train_network = GradWrap(net_with_criterion) + train_network.set_train() + total_time = 0 + + for epoch in range(0, epoch_size): + start_time = time.time() + fw_output = net(inputs) + loss_output = criterion(fw_output, labels) + grads = train_network(inputs, labels) + optimizer(grads) + end_time = time.time() + cost_time = end_time - start_time + total_time = total_time + cost_time + + print("======epoch: ", epoch, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time) + return loss_output + + +@arg_mark(plat_marks=["platform_ascend"], level_mark="level1", card_mark="onecard", essential_mark="essential") +def test_ascend_lenet1(): + context.set_context(jit_level='O0') + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + loss_output = test_ascend_lenet() + assert loss_output.asnumpy() < 0.004 + assert loss_output.asnumpy() > 0.003 + + +@arg_mark(plat_marks=["platform_ascend"], level_mark="level1", card_mark="onecard", essential_mark="essential") +def test_ascend_lenet2(): + context.set_context(jit_level='O0') + context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + loss_output = test_ascend_lenet() + assert loss_output.asnumpy() < 0.004 + assert loss_output.asnumpy() > 0.003 + + +@arg_mark(plat_marks=["platform_ascend"], level_mark="level1", card_mark="onecard", essential_mark="essential") +def test_ascend_lenet3(): + """ + Feature: Somas Ascend kernel by kernel. + Description: LeNet with Somas Ascend kernel by kernel. + Expectation: No exception. + """ + context.set_context(jit_level='O0') + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", memory_optimize_level='O1') + loss_output = test_ascend_lenet() + assert loss_output.asnumpy() < 0.004 + assert loss_output.asnumpy() > 0.003 + + +@arg_mark(plat_marks=["platform_ascend"], level_mark="level0", card_mark="onecard", essential_mark="essential") +def test_ascend_lenet4(): + """ + Feature: Ascend kernel by kernel and Ascend VMM. + Description: LeNet with Ascend kernel by kernel and VMM. + Expectation: No exception. + """ + context.set_context(jit_level='O0') + os.environ['MS_DEV_ENABLE_ASCEND_VMM'] = str(1) + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + loss_output = test_ascend_lenet() + assert loss_output.asnumpy() < 0.004 + assert loss_output.asnumpy() > 0.003 + + +@arg_mark(plat_marks=["platform_ascend"], level_mark="level0", card_mark="onecard", essential_mark="essential") +def test_ascend_lenet5(): + """ + Feature: Ascend kernel by kernel and Ascend VMM. + Description: LeNet with Ascend kernel by kernel and VMM. + Expectation: No exception. + """ + context.set_context(jit_level='O0') + os.environ['MS_DEV_ENABLE_ASCEND_VMM'] = str(1) + os.environ['MS_DEV_ASCEND_VMM_ALIGN_SIZE'] = "20MB" + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + loss_output = test_ascend_lenet() + assert loss_output.asnumpy() < 0.004 + assert loss_output.asnumpy() > 0.003 + + +class GradWrapTuple(nn.Cell): + """ + GradWrapTuple definition + """ + + def __init__(self, network): + super(GradWrapTuple, self).__init__() + self.network = network + self.weights = tuple(filter(lambda x: x.requires_grad, network.get_parameters())) + + def construct(self, x, label): + weights = self.weights + return grad_by_list(self.network, weights)(x, label) + + +def test_ascend_lenet_grad_by_list_tuple(): + """ + Feature: GradOperation get_by_list pass tuple/list + Description: Grad with Parameters as input type and fv. list or tuple as fv of grad. + Expectation: No exception. + """ + epoch_size = 20 + batch_size = 32 + inputs = Tensor(np.ones([batch_size, 1, 32, 32]).astype(np.float32)) + labels = Tensor(np.ones([batch_size]).astype(np.int32)) + + net = LeNet() + criterion = CrossEntropyLoss() + optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9) + + net_with_criterion = WithLossCell(net, criterion) + train_network = GradWrapTuple(net_with_criterion) + train_network.set_train() + total_time = 0 + + for epoch in range(0, epoch_size): + start_time = time.time() + fw_output = net(inputs) + loss_output = criterion(fw_output, labels) + grads = train_network(inputs, labels) + optimizer(grads) + end_time = time.time() + cost_time = end_time - start_time + total_time = total_time + cost_time + + print("======epoch: ", epoch, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time) + return loss_output + + +@arg_mark(plat_marks=["platform_ascend"], level_mark="level0", card_mark="onecard", essential_mark="essential") +def test_ascend_lenet_grad_by_list_tuple1(): + """ + Feature: GradOperation get_by_list pass tuple/list + Description: Grad with Parameters as input type and fv. list or tuple as fv of grad. + Expectation: No exception. + """ + context.set_context(jit_level='O0') + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + loss_output = test_ascend_lenet_grad_by_list_tuple() + assert loss_output.asnumpy() < 0.004 + assert loss_output.asnumpy() > 0.003 + + +@arg_mark(plat_marks=["platform_ascend"], level_mark="level1", card_mark="onecard", essential_mark="essential") +def test_ascend_lenet_grad_by_list_tuple2(): + """ + Feature: GradOperation get_by_list pass tuple/list with Ascend kernel by kernel Somas. + Description: Grad with Parameters as input type and fv. list or tuple as fv of grad. + Expectation: No exception. + """ + context.set_context(jit_level='O0') + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", memory_optimize_level='O1') + loss_output = test_ascend_lenet_grad_by_list_tuple() + assert loss_output.asnumpy() < 0.004 + assert loss_output.asnumpy() > 0.003 diff --git a/tests/st/nontask_sink/test_reduce.py b/tests/st/nontask_sink/test_reduce.py index 6f059f2c5fe..72ebb451ac2 100644 --- a/tests/st/nontask_sink/test_reduce.py +++ b/tests/st/nontask_sink/test_reduce.py @@ -1,109 +1,109 @@ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""test hccl reduce with 8p""" - -import numpy as np -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.ops.operations import comm_ops -from mindspore.communication.management import init, get_rank -from mindspore.communication.comm_func import reduce -from mindspore import context -from mindspore.communication import GlobalComm - -# 'Reduce' operator only supports KernelByKernel mode by now. -np.random.seed(1) -context.set_context(jit_level='O0') -context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") -init() -this_rank = get_rank() - - -class ReduceNet(nn.Cell): - def __init__(self): - super(ReduceNet, self).__init__() - self.reduce1 = comm_ops.Reduce(2) - self.reduce2 = comm_ops.Reduce(6) - - def construct(self, x): - output1 = self.reduce1(x) - output2 = self.reduce2(x) - return output1, output2 - - -class ReduceFuncNet(nn.Cell): - def __init__(self, group=GlobalComm.WORLD_COMM_GROUP): - super(ReduceFuncNet, self).__init__() - self.group = group - - def construct(self, x): - output1 = reduce(x, 2) - output2 = reduce(x, 6) - return output1, output2 - - -def test_hccl_reduce_8p(): - """ - Feature: test 'Reduce' communication operator. - Description: test 'Reduce' communication operator. - Expectation: expect correct result. - """ - net = ReduceNet() - input_x = np.array([0, 1, 2, 3]).astype(np.float32) - expect_output = np.array([0, 8, 16, 24]).astype(np.float32) - output1, output2 = net(Tensor(input_x)) - if this_rank == 2: - assert np.allclose(output1.asnumpy(), expect_output) - - if this_rank == 6: - assert np.allclose(output2.asnumpy(), expect_output) - print("outputs are", output1, output2) - - -def test_hccl_reduce_func_net_8p(): - """ - Feature: test 'Reduce' communication operator. - Description: test 'Reduce' communication operator. - Expectation: expect correct result. - """ - net = ReduceFuncNet() - input_x = np.array([0, 1, 2, 3]).astype(np.float32) - expect_output = np.array([0, 8, 16, 24]).astype(np.float32) - output1, output2 = net(Tensor(input_x)) - if this_rank == 2: - assert np.allclose(output1.asnumpy(), expect_output) - - if this_rank == 6: - assert np.allclose(output2.asnumpy(), expect_output) - print("outputs are", output1, output2) - - -def test_hccl_reduce_func_8p(): - """ - Feature: test 'reduce' communication function. - Description: test 'reduce' communication function. - Expectation: expect correct result. - """ - input_x = np.array([0, 1, 2, 3]).astype(np.float32) - expect_output = np.array([0, 8, 16, 24]).astype(np.float32) - output1 = reduce(Tensor(input_x), 2) - output2 = reduce(Tensor(input_x), 6) - if this_rank == 2: - assert np.allclose(output1.asnumpy(), expect_output) - - if this_rank == 6: - assert np.allclose(output2.asnumpy(), expect_output) - print("outputs are", output1, output2) +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""test hccl reduce with 8p""" + +import numpy as np +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops.operations import comm_ops +from mindspore.communication.management import init, get_rank +from mindspore.communication.comm_func import reduce +from mindspore import context +from mindspore.communication import GlobalComm + +# 'Reduce' operator only supports KernelByKernel mode by now. +np.random.seed(1) +context.set_context(jit_level='O0') +context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") +init() +this_rank = get_rank() + + +class ReduceNet(nn.Cell): + def __init__(self): + super(ReduceNet, self).__init__() + self.reduce1 = comm_ops.Reduce(2) + self.reduce2 = comm_ops.Reduce(6) + + def construct(self, x): + output1 = self.reduce1(x) + output2 = self.reduce2(x) + return output1, output2 + + +class ReduceFuncNet(nn.Cell): + def __init__(self, group=GlobalComm.WORLD_COMM_GROUP): + super(ReduceFuncNet, self).__init__() + self.group = group + + def construct(self, x): + output1 = reduce(x, 2) + output2 = reduce(x, 6) + return output1, output2 + + +def test_hccl_reduce_8p(): + """ + Feature: test 'Reduce' communication operator. + Description: test 'Reduce' communication operator. + Expectation: expect correct result. + """ + net = ReduceNet() + input_x = np.array([0, 1, 2, 3]).astype(np.float32) + expect_output = np.array([0, 8, 16, 24]).astype(np.float32) + output1, output2 = net(Tensor(input_x)) + if this_rank == 2: + assert np.allclose(output1.asnumpy(), expect_output) + + if this_rank == 6: + assert np.allclose(output2.asnumpy(), expect_output) + print("outputs are", output1, output2) + + +def test_hccl_reduce_func_net_8p(): + """ + Feature: test 'Reduce' communication operator. + Description: test 'Reduce' communication operator. + Expectation: expect correct result. + """ + net = ReduceFuncNet() + input_x = np.array([0, 1, 2, 3]).astype(np.float32) + expect_output = np.array([0, 8, 16, 24]).astype(np.float32) + output1, output2 = net(Tensor(input_x)) + if this_rank == 2: + assert np.allclose(output1.asnumpy(), expect_output) + + if this_rank == 6: + assert np.allclose(output2.asnumpy(), expect_output) + print("outputs are", output1, output2) + + +def test_hccl_reduce_func_8p(): + """ + Feature: test 'reduce' communication function. + Description: test 'reduce' communication function. + Expectation: expect correct result. + """ + input_x = np.array([0, 1, 2, 3]).astype(np.float32) + expect_output = np.array([0, 8, 16, 24]).astype(np.float32) + output1 = reduce(Tensor(input_x), 2) + output2 = reduce(Tensor(input_x), 6) + if this_rank == 2: + assert np.allclose(output1.asnumpy(), expect_output) + + if this_rank == 6: + assert np.allclose(output2.asnumpy(), expect_output) + print("outputs are", output1, output2) diff --git a/tests/st/ops/ascend/test_adam_weight_decay.py b/tests/st/ops/ascend/test_adam_weight_decay.py index 3ff0b4eabed..b6a08af1ee5 100644 --- a/tests/st/ops/ascend/test_adam_weight_decay.py +++ b/tests/st/ops/ascend/test_adam_weight_decay.py @@ -1,180 +1,180 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import copy -import numpy as np -import mindspore.nn as nn -from mindspore import Tensor, Parameter, context -from mindspore.ops import operations as P -from mindspore.nn.optim.adam import _update_run_op -from tests.mark_utils import arg_mark - - -class OriNet(nn.Cell): - """Origin net uses _update_run_op""" - - def __init__(self, decay_flag): - super(OriNet, self).__init__() - self.decay_flag = decay_flag - self.optim_filter = True - - def construct(self, param, m, v, lr, beta1, beta2, eps, weight_decay, gradient): - next_param = _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, - self.decay_flag, self.optim_filter) - return next_param - - -class FissionNet(nn.Cell): - """Fission net uses P.AdamWeightDecay()""" - - def __init__(self): - super(FissionNet, self).__init__() - self.optim_filter = True - - def construct(self, param, m, v, lr, beta1, beta2, eps, weight_decay, gradient): - if self.optim_filter: - adam = P.AdamWeightDecay() - next_param = adam(param, m, v, lr, beta1, beta2, eps, weight_decay, gradient) - return next_param - return gradient - - -def test_adam_weight_decay_fission_1_decay_flag_is_true(): - """ - Feature: AdamWeightDecay op - Description: test the rightness of AdamWeightDecay kernel, decay_flag is true - Expectation: the output is wrong - """ - decay_flag = True # equivalent to weight_decay is not zero - weight_decay = Parameter(Tensor(np.array([0.9]).astype(np.float32)), name="weight_decay") - beta1 = Parameter(Tensor(np.array([0.9]).astype(np.float32)), name="beta1") - beta2 = Parameter(Tensor(np.array([0.999]).astype(np.float32)), name="beta2") - eps = Parameter(Tensor(np.array([1e-8]).astype(np.float32)), name="eps") - lr = Parameter(Tensor(np.array([0.001]).astype(np.float32)), name="lr") - gradient = Parameter(Tensor(np.array([[2, 3], [1, 5]]).astype(np.float32)), name="gradient") - - # The inputs: param, m and v will be modified in-place by P.AdamWeightDecay() or _update_run_op(), - # so here defines two copied of them: (param1, m1, v1) and (param2, m2, v2) - param1 = Parameter(Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32)), name="param1") - m1 = Parameter(Tensor(np.array([[5, 6], [7, 8]]).astype(np.float32)), name="m1") - v1 = Parameter(Tensor(np.array([[3, 1], [7, 4]]).astype(np.float32)), name="v1") - - param2 = copy.deepcopy(param1) - m2 = copy.deepcopy(m1) - v2 = copy.deepcopy(v1) - - context.set_context(mode=context.PYNATIVE_MODE, device_target='Ascend') - origin_net = OriNet(decay_flag) - output1 = origin_net(param1, m1, v1, lr, beta1, beta2, eps, weight_decay, gradient) - fission_net = FissionNet() - output2 = fission_net(param2, m2, v2, lr, beta1, beta2, eps, weight_decay, gradient) - assert (output1.asnumpy() == output2[0].asnumpy()).all() - - -def test_adam_weight_decay_fission_2_decay_flag_is_false(): - """ - Feature: AdamWeightDecay op - Description: test the rightness of ScaleGrad kernel, decay_flag is false - Expectation: the output is wrong - """ - decay_flag = False # equivalent to weight_decay is zero - weight_decay = Parameter(Tensor(np.array([0]).astype(np.float32)), name="weight_decay") - beta1 = Parameter(Tensor(np.array([0.9]).astype(np.float32)), name="beta1") - beta2 = Parameter(Tensor(np.array([0.999]).astype(np.float32)), name="beta2") - eps = Parameter(Tensor(np.array([1e-8]).astype(np.float32)), name="eps") - lr = Parameter(Tensor(np.array([0.001]).astype(np.float32)), name="lr") - gradient = Parameter(Tensor(np.array([[2, 3], [1, 5]]).astype(np.float32)), name="gradient") - - # The inputs: param, m and v will be modified in-place by P.AdamWeightDecay() or _update_run_op(), - # so here defines two copied of them: (param1, m1, v1) and (param2, m2, v2) - param1 = Parameter(Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32)), name="param1") - m1 = Parameter(Tensor(np.array([[5, 6], [7, 8]]).astype(np.float32)), name="m1") - v1 = Parameter(Tensor(np.array([[3, 1], [7, 4]]).astype(np.float32)), name="v1") - - param2 = copy.deepcopy(param1) - m2 = copy.deepcopy(m1) - v2 = copy.deepcopy(v1) - - context.set_context(mode=context.PYNATIVE_MODE, device_target='Ascend') - origin_net = OriNet(decay_flag) - output1 = origin_net(param1, m1, v1, lr, beta1, beta2, eps, weight_decay, gradient) - fission_net = FissionNet() - output2 = fission_net(param2, m2, v2, lr, beta1, beta2, eps, weight_decay, gradient) - assert (output1.asnumpy() == output2[0].asnumpy()).all() - - -@arg_mark(plat_marks=['platform_ascend'], level_mark='level1', card_mark='onecard', essential_mark='essential') -def test_adam_weight_decay_pass_without_same_type(): - """ - Feature: AdamWeightDecay op - Description: test the rightness of AdamWeightDecay kernel, decay_flag is true - Expectation: the output is same - """ - decay_flag = True # equivalent to weight_decay is not zero - weight_decay = Parameter(Tensor(np.array([0.9]).astype(np.float32)), name="weight_decay") - beta1 = Parameter(Tensor(np.array([0.9]).astype(np.float32)), name="beta1") - beta2 = Parameter(Tensor(np.array([0.999]).astype(np.float32)), name="beta2") - eps = Parameter(Tensor(np.array([1e-8]).astype(np.float32)), name="eps") - lr = Parameter(Tensor(np.array([0.001]).astype(np.float32)), name="lr") - gradient = Parameter(Tensor(np.array([[2, 3], [1, 5]]).astype(np.float32)), name="gradient") - - # The inputs: param, m and v will be modified in-place by P.AdamWeightDecay() or _update_run_op(), - # so here defines two copied of them: (param1, m1, v1) and (param2, m2, v2) - param1 = Parameter(Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32)), name="param1") - m1 = Parameter(Tensor(np.array([[5, 6], [7, 8]]).astype(np.float16)), name="m1") - v1 = Parameter(Tensor(np.array([[3, 1], [7, 4]]).astype(np.float16)), name="v1") - - param2 = copy.deepcopy(param1) - m2 = copy.deepcopy(m1) - v2 = copy.deepcopy(v1) - - context.set_context(mode=context.GRAPH_MODE, device_target='Ascend') - origin_net = OriNet(decay_flag) - output1 = origin_net(param1, m1, v1, lr, beta1, beta2, eps, weight_decay, gradient) - fission_net = FissionNet() - output2 = fission_net(param2, m2, v2, lr, beta1, beta2, eps, weight_decay, gradient) - assert (output1.asnumpy() == output2[0].asnumpy()).all() - - -@arg_mark(plat_marks=['platform_ascend'], level_mark='level1', card_mark='onecard', essential_mark='essential') -def test_adam_weight_decay_pass_with_same_type_to_assign(): - """ - Feature: AdamWeightDecay op - Description: test the rightness of AdamWeightDecay kernel, decay_flag is true - Expectation: the output is same - """ - decay_flag = True # equivalent to weight_decay is not zero - weight_decay = Parameter(Tensor(np.array([0.9]).astype(np.float32)), name="weight_decay") - beta1 = Parameter(Tensor(np.array([0.9]).astype(np.float32)), name="beta1") - beta2 = Parameter(Tensor(np.array([0.999]).astype(np.float32)), name="beta2") - eps = Parameter(Tensor(np.array([1e-8]).astype(np.float32)), name="eps") - lr = Parameter(Tensor(np.array([0.001]).astype(np.float32)), name="lr") - gradient = Parameter(Tensor(np.array([[2, 3], [1, 5]]).astype(np.float32)), name="gradient") - - # The inputs: param, m and v will be modified in-place by P.AdamWeightDecay() or _update_run_op(), - # so here defines two copied of them: (param1, m1, v1) and (param2, m2, v2) - param1 = Parameter(Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32)), name="param1") - m1 = Parameter(Tensor(np.array([[5, 6], [7, 8]]).astype(np.float32)), name="m1") - v1 = Parameter(Tensor(np.array([[3, 1], [7, 4]]).astype(np.float32)), name="v1") - - param2 = copy.deepcopy(param1) - m2 = copy.deepcopy(m1) - v2 = copy.deepcopy(v1) - - context.set_context(mode=context.GRAPH_MODE, device_target='Ascend') - origin_net = OriNet(decay_flag) - output1 = origin_net(param1, m1, v1, lr, beta1, beta2, eps, weight_decay, gradient) - fission_net = FissionNet() - output2 = fission_net(param2, m2, v2, lr, beta1, beta2, eps, weight_decay, gradient) - assert (output1.asnumpy() == output2[0].asnumpy()).all() +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import copy +import numpy as np +import mindspore.nn as nn +from mindspore import Tensor, Parameter, context +from mindspore.ops import operations as P +from mindspore.nn.optim.adam import _update_run_op +from tests.mark_utils import arg_mark + + +class OriNet(nn.Cell): + """Origin net uses _update_run_op""" + + def __init__(self, decay_flag): + super(OriNet, self).__init__() + self.decay_flag = decay_flag + self.optim_filter = True + + def construct(self, param, m, v, lr, beta1, beta2, eps, weight_decay, gradient): + next_param = _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, + self.decay_flag, self.optim_filter) + return next_param + + +class FissionNet(nn.Cell): + """Fission net uses P.AdamWeightDecay()""" + + def __init__(self): + super(FissionNet, self).__init__() + self.optim_filter = True + + def construct(self, param, m, v, lr, beta1, beta2, eps, weight_decay, gradient): + if self.optim_filter: + adam = P.AdamWeightDecay() + next_param = adam(param, m, v, lr, beta1, beta2, eps, weight_decay, gradient) + return next_param + return gradient + + +def test_adam_weight_decay_fission_1_decay_flag_is_true(): + """ + Feature: AdamWeightDecay op + Description: test the rightness of AdamWeightDecay kernel, decay_flag is true + Expectation: the output is wrong + """ + decay_flag = True # equivalent to weight_decay is not zero + weight_decay = Parameter(Tensor(np.array([0.9]).astype(np.float32)), name="weight_decay") + beta1 = Parameter(Tensor(np.array([0.9]).astype(np.float32)), name="beta1") + beta2 = Parameter(Tensor(np.array([0.999]).astype(np.float32)), name="beta2") + eps = Parameter(Tensor(np.array([1e-8]).astype(np.float32)), name="eps") + lr = Parameter(Tensor(np.array([0.001]).astype(np.float32)), name="lr") + gradient = Parameter(Tensor(np.array([[2, 3], [1, 5]]).astype(np.float32)), name="gradient") + + # The inputs: param, m and v will be modified in-place by P.AdamWeightDecay() or _update_run_op(), + # so here defines two copied of them: (param1, m1, v1) and (param2, m2, v2) + param1 = Parameter(Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32)), name="param1") + m1 = Parameter(Tensor(np.array([[5, 6], [7, 8]]).astype(np.float32)), name="m1") + v1 = Parameter(Tensor(np.array([[3, 1], [7, 4]]).astype(np.float32)), name="v1") + + param2 = copy.deepcopy(param1) + m2 = copy.deepcopy(m1) + v2 = copy.deepcopy(v1) + + context.set_context(mode=context.PYNATIVE_MODE, device_target='Ascend') + origin_net = OriNet(decay_flag) + output1 = origin_net(param1, m1, v1, lr, beta1, beta2, eps, weight_decay, gradient) + fission_net = FissionNet() + output2 = fission_net(param2, m2, v2, lr, beta1, beta2, eps, weight_decay, gradient) + assert (output1.asnumpy() == output2[0].asnumpy()).all() + + +def test_adam_weight_decay_fission_2_decay_flag_is_false(): + """ + Feature: AdamWeightDecay op + Description: test the rightness of ScaleGrad kernel, decay_flag is false + Expectation: the output is wrong + """ + decay_flag = False # equivalent to weight_decay is zero + weight_decay = Parameter(Tensor(np.array([0]).astype(np.float32)), name="weight_decay") + beta1 = Parameter(Tensor(np.array([0.9]).astype(np.float32)), name="beta1") + beta2 = Parameter(Tensor(np.array([0.999]).astype(np.float32)), name="beta2") + eps = Parameter(Tensor(np.array([1e-8]).astype(np.float32)), name="eps") + lr = Parameter(Tensor(np.array([0.001]).astype(np.float32)), name="lr") + gradient = Parameter(Tensor(np.array([[2, 3], [1, 5]]).astype(np.float32)), name="gradient") + + # The inputs: param, m and v will be modified in-place by P.AdamWeightDecay() or _update_run_op(), + # so here defines two copied of them: (param1, m1, v1) and (param2, m2, v2) + param1 = Parameter(Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32)), name="param1") + m1 = Parameter(Tensor(np.array([[5, 6], [7, 8]]).astype(np.float32)), name="m1") + v1 = Parameter(Tensor(np.array([[3, 1], [7, 4]]).astype(np.float32)), name="v1") + + param2 = copy.deepcopy(param1) + m2 = copy.deepcopy(m1) + v2 = copy.deepcopy(v1) + + context.set_context(mode=context.PYNATIVE_MODE, device_target='Ascend') + origin_net = OriNet(decay_flag) + output1 = origin_net(param1, m1, v1, lr, beta1, beta2, eps, weight_decay, gradient) + fission_net = FissionNet() + output2 = fission_net(param2, m2, v2, lr, beta1, beta2, eps, weight_decay, gradient) + assert (output1.asnumpy() == output2[0].asnumpy()).all() + + +@arg_mark(plat_marks=['platform_ascend'], level_mark='level1', card_mark='onecard', essential_mark='essential') +def test_adam_weight_decay_pass_without_same_type(): + """ + Feature: AdamWeightDecay op + Description: test the rightness of AdamWeightDecay kernel, decay_flag is true + Expectation: the output is same + """ + decay_flag = True # equivalent to weight_decay is not zero + weight_decay = Parameter(Tensor(np.array([0.9]).astype(np.float32)), name="weight_decay") + beta1 = Parameter(Tensor(np.array([0.9]).astype(np.float32)), name="beta1") + beta2 = Parameter(Tensor(np.array([0.999]).astype(np.float32)), name="beta2") + eps = Parameter(Tensor(np.array([1e-8]).astype(np.float32)), name="eps") + lr = Parameter(Tensor(np.array([0.001]).astype(np.float32)), name="lr") + gradient = Parameter(Tensor(np.array([[2, 3], [1, 5]]).astype(np.float32)), name="gradient") + + # The inputs: param, m and v will be modified in-place by P.AdamWeightDecay() or _update_run_op(), + # so here defines two copied of them: (param1, m1, v1) and (param2, m2, v2) + param1 = Parameter(Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32)), name="param1") + m1 = Parameter(Tensor(np.array([[5, 6], [7, 8]]).astype(np.float16)), name="m1") + v1 = Parameter(Tensor(np.array([[3, 1], [7, 4]]).astype(np.float16)), name="v1") + + param2 = copy.deepcopy(param1) + m2 = copy.deepcopy(m1) + v2 = copy.deepcopy(v1) + + context.set_context(mode=context.GRAPH_MODE, device_target='Ascend') + origin_net = OriNet(decay_flag) + output1 = origin_net(param1, m1, v1, lr, beta1, beta2, eps, weight_decay, gradient) + fission_net = FissionNet() + output2 = fission_net(param2, m2, v2, lr, beta1, beta2, eps, weight_decay, gradient) + assert (output1.asnumpy() == output2[0].asnumpy()).all() + + +@arg_mark(plat_marks=['platform_ascend'], level_mark='level1', card_mark='onecard', essential_mark='essential') +def test_adam_weight_decay_pass_with_same_type_to_assign(): + """ + Feature: AdamWeightDecay op + Description: test the rightness of AdamWeightDecay kernel, decay_flag is true + Expectation: the output is same + """ + decay_flag = True # equivalent to weight_decay is not zero + weight_decay = Parameter(Tensor(np.array([0.9]).astype(np.float32)), name="weight_decay") + beta1 = Parameter(Tensor(np.array([0.9]).astype(np.float32)), name="beta1") + beta2 = Parameter(Tensor(np.array([0.999]).astype(np.float32)), name="beta2") + eps = Parameter(Tensor(np.array([1e-8]).astype(np.float32)), name="eps") + lr = Parameter(Tensor(np.array([0.001]).astype(np.float32)), name="lr") + gradient = Parameter(Tensor(np.array([[2, 3], [1, 5]]).astype(np.float32)), name="gradient") + + # The inputs: param, m and v will be modified in-place by P.AdamWeightDecay() or _update_run_op(), + # so here defines two copied of them: (param1, m1, v1) and (param2, m2, v2) + param1 = Parameter(Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32)), name="param1") + m1 = Parameter(Tensor(np.array([[5, 6], [7, 8]]).astype(np.float32)), name="m1") + v1 = Parameter(Tensor(np.array([[3, 1], [7, 4]]).astype(np.float32)), name="v1") + + param2 = copy.deepcopy(param1) + m2 = copy.deepcopy(m1) + v2 = copy.deepcopy(v1) + + context.set_context(mode=context.GRAPH_MODE, device_target='Ascend') + origin_net = OriNet(decay_flag) + output1 = origin_net(param1, m1, v1, lr, beta1, beta2, eps, weight_decay, gradient) + fission_net = FissionNet() + output2 = fission_net(param2, m2, v2, lr, beta1, beta2, eps, weight_decay, gradient) + assert (output1.asnumpy() == output2[0].asnumpy()).all() diff --git a/tests/st/ops/ascend/test_scale_grad.py b/tests/st/ops/ascend/test_scale_grad.py index 6a733e4561f..7278bcdb53c 100644 --- a/tests/st/ops/ascend/test_scale_grad.py +++ b/tests/st/ops/ascend/test_scale_grad.py @@ -1,129 +1,129 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import numpy as np - -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.ops.operations.inner_ops import ScaleGrad -from mindspore.common import dtype - -context.set_context(device_target="Ascend") - - -class Net(nn.Cell): - def __init__(self, scale): - super(Net, self).__init__() - self.scale_grad = ScaleGrad() - self.scale = scale - - def construct(self, origin_grads): - return self.scale_grad(origin_grads, self.scale) - - -def test_scale_grad_grad_float32_scale_float32(): - """ - Feature: Scale Grad fusion operation - Description: test the rightness of ScaleGrad kernel, gradient's dtype is float32, scale's dtype is float32 - Expectation: the output is wrong - """ - scale = Tensor(1024.0, dtype.float32) - gradients = [] - for _ in range(3): - gradients.append(Tensor(np.ones([3, 3]).astype(np.float32))) - gradients_input = tuple(gradients) - scale_grad = Net(scale) - scale_grad(gradients_input) - - -def test_scale_grad_grad_float32_scale_float16(): - """ - Feature: Scale Grad fusion operation - Description: test the rightness of ScaleGrad kernel, gradient's dtype is float32, scale's dtype is float16 - Expectation: the output is wrong - """ - scale = Tensor(1024.0, dtype.float32) - gradients = [] - for _ in range(3): - gradients.append(Tensor(np.ones([3, 3]).astype(np.float32))) - gradients_input = tuple(gradients) - scale_grad = Net(scale) - scale_grad(gradients_input) - - -def test_scale_grad_grad_float16_scale_float32(): - """ - Feature: Scale Grad fusion operation - Description: test the rightness of ScaleGrad kernel, gradient's dtype is float16, scale's dtype is float32 - Expectation: the output is wrong - """ - scale = Tensor(1024.0, dtype.float32) - gradients = [] - for _ in range(3): - gradients.append(Tensor(np.ones([3, 3]).astype(np.float16))) - gradients_input = tuple(gradients) - scale_grad = Net(scale) - scale_grad(gradients_input) - - -def test_scale_grad_grad_float16_scale_float16(): - """ - Feature: Scale Grad fusion operation - Description: test the rightness of ScaleGrad kernel, gradient's dtype is float16, scale's dtype is float16 - Expectation: the output is wrong - """ - scale = Tensor(1024.0, dtype.float16) - gradients = [] - for _ in range(3): - gradients.append(Tensor(np.ones([3, 3]).astype(np.float16))) - gradients_input = tuple(gradients) - scale_grad = Net(scale) - scale_grad(gradients_input) - - -def test_scale_grad_grad_mixed_scale_float32(): - """ - Feature: Scale Grad fusion operation - Description: test the rightness of ScaleGrad kernel, gradient's dtype is mixed, scale's dtype is float32 - Expectation: the output is wrong - """ - scale = Tensor(1024.0, dtype.float32) - gradients = [] - for i in range(3): - if (i % 2) == 0: - gradients.append(Tensor(np.ones([3, 3]).astype(np.float32))) - else: - gradients.append(Tensor(np.ones([3, 3]).astype(np.float16))) - gradients_input = tuple(gradients) - scale_grad = Net(scale) - scale_grad(gradients_input) - - -def test_scale_grad_grad_mixed_scale_float16(): - """ - Feature: Scale Grad fusion operation - Description: test the rightness of ScaleGrad kernel, gradient's dtype is mixed, scale's dtype is float16 - Expectation: the output is wrong - """ - scale = Tensor(1024.0, dtype.float16) - gradients = [] - for i in range(3): - if (i % 2) == 0: - gradients.append(Tensor(np.ones([3, 3]).astype(np.float32))) - else: - gradients.append(Tensor(np.ones([3, 3]).astype(np.float16))) - gradients_input = tuple(gradients) - scale_grad = Net(scale) - scale_grad(gradients_input) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops.operations.inner_ops import ScaleGrad +from mindspore.common import dtype + +context.set_context(device_target="Ascend") + + +class Net(nn.Cell): + def __init__(self, scale): + super(Net, self).__init__() + self.scale_grad = ScaleGrad() + self.scale = scale + + def construct(self, origin_grads): + return self.scale_grad(origin_grads, self.scale) + + +def test_scale_grad_grad_float32_scale_float32(): + """ + Feature: Scale Grad fusion operation + Description: test the rightness of ScaleGrad kernel, gradient's dtype is float32, scale's dtype is float32 + Expectation: the output is wrong + """ + scale = Tensor(1024.0, dtype.float32) + gradients = [] + for _ in range(3): + gradients.append(Tensor(np.ones([3, 3]).astype(np.float32))) + gradients_input = tuple(gradients) + scale_grad = Net(scale) + scale_grad(gradients_input) + + +def test_scale_grad_grad_float32_scale_float16(): + """ + Feature: Scale Grad fusion operation + Description: test the rightness of ScaleGrad kernel, gradient's dtype is float32, scale's dtype is float16 + Expectation: the output is wrong + """ + scale = Tensor(1024.0, dtype.float32) + gradients = [] + for _ in range(3): + gradients.append(Tensor(np.ones([3, 3]).astype(np.float32))) + gradients_input = tuple(gradients) + scale_grad = Net(scale) + scale_grad(gradients_input) + + +def test_scale_grad_grad_float16_scale_float32(): + """ + Feature: Scale Grad fusion operation + Description: test the rightness of ScaleGrad kernel, gradient's dtype is float16, scale's dtype is float32 + Expectation: the output is wrong + """ + scale = Tensor(1024.0, dtype.float32) + gradients = [] + for _ in range(3): + gradients.append(Tensor(np.ones([3, 3]).astype(np.float16))) + gradients_input = tuple(gradients) + scale_grad = Net(scale) + scale_grad(gradients_input) + + +def test_scale_grad_grad_float16_scale_float16(): + """ + Feature: Scale Grad fusion operation + Description: test the rightness of ScaleGrad kernel, gradient's dtype is float16, scale's dtype is float16 + Expectation: the output is wrong + """ + scale = Tensor(1024.0, dtype.float16) + gradients = [] + for _ in range(3): + gradients.append(Tensor(np.ones([3, 3]).astype(np.float16))) + gradients_input = tuple(gradients) + scale_grad = Net(scale) + scale_grad(gradients_input) + + +def test_scale_grad_grad_mixed_scale_float32(): + """ + Feature: Scale Grad fusion operation + Description: test the rightness of ScaleGrad kernel, gradient's dtype is mixed, scale's dtype is float32 + Expectation: the output is wrong + """ + scale = Tensor(1024.0, dtype.float32) + gradients = [] + for i in range(3): + if (i % 2) == 0: + gradients.append(Tensor(np.ones([3, 3]).astype(np.float32))) + else: + gradients.append(Tensor(np.ones([3, 3]).astype(np.float16))) + gradients_input = tuple(gradients) + scale_grad = Net(scale) + scale_grad(gradients_input) + + +def test_scale_grad_grad_mixed_scale_float16(): + """ + Feature: Scale Grad fusion operation + Description: test the rightness of ScaleGrad kernel, gradient's dtype is mixed, scale's dtype is float16 + Expectation: the output is wrong + """ + scale = Tensor(1024.0, dtype.float16) + gradients = [] + for i in range(3): + if (i % 2) == 0: + gradients.append(Tensor(np.ones([3, 3]).astype(np.float32))) + else: + gradients.append(Tensor(np.ones([3, 3]).astype(np.float16))) + gradients_input = tuple(gradients) + scale_grad = Net(scale) + scale_grad(gradients_input) diff --git a/tests/st/ops/ascend/test_sigmoid.py b/tests/st/ops/ascend/test_sigmoid.py old mode 100755 new mode 100644 diff --git a/tests/st/ops/ascend/test_sparse_attention.py b/tests/st/ops/ascend/test_sparse_attention.py index fdebe406e15..18d63468ba1 100644 --- a/tests/st/ops/ascend/test_sparse_attention.py +++ b/tests/st/ops/ascend/test_sparse_attention.py @@ -1,27 +1,27 @@ -import numpy as np -from mindspore import Tensor -from mindspore.parallel.nn.layers import FixedSparseAttention -import mindspore.context as context - -context.set_context(device_target="Ascend") - - -def test_net(): - np.random.seed(0) - bs = 2 # batch size - heads = 2 - seq_len = 1024 # this op is designed for seq_len = 1024 - size_per_head = 128 # maximum size per head value is 128 - - block_size = 64 # block size is designed to be 64 - fixed_sparse = FixedSparseAttention(bs, heads, size_per_head, block_size) - q = np.random.rand(bs, seq_len, heads * size_per_head) - q = q.astype(np.float16) - k = np.random.rand(bs, seq_len, heads * size_per_head) - k = k.astype(np.float16) - v = np.random.rand(bs, seq_len, heads * size_per_head) - v = v.astype(np.float16) - attention_mask = np.ones((bs, seq_len, seq_len), dtype=np.float32) - out = fixed_sparse(Tensor(q), Tensor(k), Tensor(v), Tensor(attention_mask)) - out_np = out.asnumpy() - print("local output: ", out_np[0, 0]) +import numpy as np +from mindspore import Tensor +from mindspore.parallel.nn.layers import FixedSparseAttention +import mindspore.context as context + +context.set_context(device_target="Ascend") + + +def test_net(): + np.random.seed(0) + bs = 2 # batch size + heads = 2 + seq_len = 1024 # this op is designed for seq_len = 1024 + size_per_head = 128 # maximum size per head value is 128 + + block_size = 64 # block size is designed to be 64 + fixed_sparse = FixedSparseAttention(bs, heads, size_per_head, block_size) + q = np.random.rand(bs, seq_len, heads * size_per_head) + q = q.astype(np.float16) + k = np.random.rand(bs, seq_len, heads * size_per_head) + k = k.astype(np.float16) + v = np.random.rand(bs, seq_len, heads * size_per_head) + v = v.astype(np.float16) + attention_mask = np.ones((bs, seq_len, seq_len), dtype=np.float32) + out = fixed_sparse(Tensor(q), Tensor(k), Tensor(v), Tensor(attention_mask)) + out_np = out.asnumpy() + print("local output: ", out_np[0, 0]) diff --git a/tests/st/ops/ascend/test_tbe_ops/Initialize.info b/tests/st/ops/ascend/test_tbe_ops/Initialize.info deleted file mode 100644 index ddbb043dbe3..00000000000 --- a/tests/st/ops/ascend/test_tbe_ops/Initialize.info +++ /dev/null @@ -1,43 +0,0 @@ -{ - "job_content": { - "SocInfo": { - "autoTilingMode": "NO_TUNE", - "coreNum": "", - "coreType": "", - "deviceId": "1", - "l1Fusion": "false", - "l2Fusion": "false", - "l2Mode": "2", - "mdl_bank_path": "", - "offlineTune": false, - "op_bank_path": "", - "op_bank_update": false, - "op_debug_dir": "./rank_0/", - "op_debug_level": "3", - "op_impl_mode": "", - "op_impl_mode_list": [], - "socVersion": "Ascend910A", - "vector_fp_ceiling": "" - }, - "TuneInfo": { - "tune_bank_path": "", - "tune_dump_path": "", - "tune_op_list": [] - }, - "LicInfo": { - "rl_tune_switch": "on", - "rl_tune_list": "ALL", - "op_tune_switch": "on", - "op_tune_list": "ALL", - "pass_list": "ALL" - }, - "enable_event": false, - "log_level": 1, - "para_debug_path": "", - "process_num": 8, - "tbe_impl_path": "/usr/local/Ascend/opp/op_impl/built-in/ai_core/tbe" - }, - "job_id": 1, - "job_type": "Initialize", - "source_id": 1 -} \ No newline at end of file diff --git a/tests/st/ops/ascend/test_tbe_ops/op.info b/tests/st/ops/ascend/test_tbe_ops/op.info deleted file mode 100644 index 9e26dfeeb6e..00000000000 --- a/tests/st/ops/ascend/test_tbe_ops/op.info +++ /dev/null @@ -1 +0,0 @@ -{} \ No newline at end of file diff --git a/tests/st/ops/ascend/test_tbe_ops/test_fast_gelu_grad_sens.py b/tests/st/ops/ascend/test_tbe_ops/test_fast_gelu_grad_sens.py old mode 100755 new mode 100644 diff --git a/tests/st/ops/ascend/test_tbe_ops/test_gelu_grad_sens.py b/tests/st/ops/ascend/test_tbe_ops/test_gelu_grad_sens.py old mode 100755 new mode 100644 diff --git a/tests/st/ops/ascend/test_tbe_ops/test_p_s_r_o_i_pooling_grad.py b/tests/st/ops/ascend/test_tbe_ops/test_p_s_r_o_i_pooling_grad.py index f24fdd15fb4..91b74867b64 100644 --- a/tests/st/ops/ascend/test_tbe_ops/test_p_s_r_o_i_pooling_grad.py +++ b/tests/st/ops/ascend/test_tbe_ops/test_p_s_r_o_i_pooling_grad.py @@ -1,70 +1,70 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import numpy as np - -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.common.api import jit -from mindspore.ops.operations import _grad_ops as G - -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") - - -class Net(nn.Cell): - def __init__(self, input_size, spatial_scale, group_size, output_dim): - super(Net, self).__init__() - self.roi_pooling = G.PSROIPoolingGrad(input_size, spatial_scale, group_size, output_dim) - - @jit - def construct(self, x, rois): - return self.roi_pooling(x, rois) - - -def test_net(x_shape, rois_shape, input_size, spatial_scale, group_size, output_dim): - """ - Feature: test PSROIPoolingGrad. - Description: - Input: - x: shape is: [n, c, out_shape, out_shape]. - rois: shape is: [n1, c2, n2], where c2 is 5, n1 represent the batch size, - and n2 = n // n1 (also mean n = n1 * n2). - - input_size: should contain 2 value, refers to width and height, refers to h1 and w1 in output shape. - spatial_scale: default is 1/16.0. - group_size: should equal to out_shape. - output_dim: (output_dim + C0 - 1) // C0 == c, where c0 is 16 in davinci. - - output: - output shape: [n1, c1, h1, w1], where n1 is the batch_size in rois, c1 = c * out_shape * out_shape - h1 and w1 is output resolution. - Expectation: Run successfully. - """ - - np_x = np.random.random(x_shape).astype(np.float32) - input_x = Tensor(np_x) - - np_rois = np.random.random(rois_shape).astype(np.float32) - input_rois = Tensor(np_rois) - - roi_grad = Net(input_size, spatial_scale, group_size, output_dim) - output = roi_grad(input_x, input_rois) - print(output.asnumpy()) - - -if __name__ == "__main__": - test_net((512, 22, 7, 7), (4, 5, 128), (84, 84), 0.0625, 7, 22) - test_net((16, 2, 7, 7), (1, 5, 16), (28, 28), 0.0625, 7, 2) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common.api import jit +from mindspore.ops.operations import _grad_ops as G + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + + +class Net(nn.Cell): + def __init__(self, input_size, spatial_scale, group_size, output_dim): + super(Net, self).__init__() + self.roi_pooling = G.PSROIPoolingGrad(input_size, spatial_scale, group_size, output_dim) + + @jit + def construct(self, x, rois): + return self.roi_pooling(x, rois) + + +def test_net(x_shape, rois_shape, input_size, spatial_scale, group_size, output_dim): + """ + Feature: test PSROIPoolingGrad. + Description: + Input: + x: shape is: [n, c, out_shape, out_shape]. + rois: shape is: [n1, c2, n2], where c2 is 5, n1 represent the batch size, + and n2 = n // n1 (also mean n = n1 * n2). + + input_size: should contain 2 value, refers to width and height, refers to h1 and w1 in output shape. + spatial_scale: default is 1/16.0. + group_size: should equal to out_shape. + output_dim: (output_dim + C0 - 1) // C0 == c, where c0 is 16 in davinci. + + output: + output shape: [n1, c1, h1, w1], where n1 is the batch_size in rois, c1 = c * out_shape * out_shape + h1 and w1 is output resolution. + Expectation: Run successfully. + """ + + np_x = np.random.random(x_shape).astype(np.float32) + input_x = Tensor(np_x) + + np_rois = np.random.random(rois_shape).astype(np.float32) + input_rois = Tensor(np_rois) + + roi_grad = Net(input_size, spatial_scale, group_size, output_dim) + output = roi_grad(input_x, input_rois) + print(output.asnumpy()) + + +if __name__ == "__main__": + test_net((512, 22, 7, 7), (4, 5, 128), (84, 84), 0.0625, 7, 22) + test_net((16, 2, 7, 7), (1, 5, 16), (28, 28), 0.0625, 7, 2) diff --git a/tests/st/ops/cpu/test_apply_power_sign.py b/tests/st/ops/cpu/test_apply_power_sign.py index 65238f1da72..e084e21609d 100644 --- a/tests/st/ops/cpu/test_apply_power_sign.py +++ b/tests/st/ops/cpu/test_apply_power_sign.py @@ -1,158 +1,158 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import numpy as np - -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor, Parameter -import mindspore.ops as ops -from mindspore.ops.functional import vmap - -context.set_context(mode=context.GRAPH_MODE, device_target="CPU") - - -class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - self.apply_power_sign = ops.ApplyPowerSign() - self.var = Parameter(Tensor(np.array([[0.6, 0.4], - [0.1, 0.5]]).astype(np.float32)), name="var") - self.m = Parameter(Tensor(np.array([[0.6, 0.5], - [0.2, 0.6]]).astype(np.float32)), name="m") - self.lr = 0.001 - self.logbase = np.e - self.sign_decay = 0.99 - self.beta = 0.9 - - def construct(self, grad): - out = self.apply_power_sign(self.var, self.m, self.lr, self.logbase, - self.sign_decay, self.beta, grad) - return out - - -def test_apply_power_assign(): - """ - Feature: test ops ApplyPowerSign. - Description: Update var and m by ApplyPowerSign op. - Expectation: match to expected benchmark output. - """ - grad = Tensor(np.array([[0.3, 0.7], [0.1, 0.8]]).astype(np.float32)) - net = Net() - net(grad) - expect_var = [[5.95575690e-01, 3.89676481e-01], - [9.85252112e-02, 4.88201708e-01]] - expect_m = [[5.70000052e-01, 5.19999981e-01], - [1.89999998e-01, 6.20000064e-01]] - assert np.allclose(net.var.asnumpy(), expect_var, atol=0.0001, rtol=0.0001, equal_nan=True) - assert np.allclose(net.m.asnumpy(), expect_m, atol=0.0001, rtol=0.0001, equal_nan=True) - - -class PowerSignNetVmap(nn.Cell): - def __init__(self, net): - super(PowerSignNetVmap, self).__init__() - self.net = net - self.var = Parameter( - Tensor(np.array([[[0.6, 0.4], [0.1, 0.5]], [[0.6, 0.4], [0.1, 0.5]]]).astype(np.float32)), name="var") - self.m = Parameter( - Tensor(np.array([[[0.6, 0.5], [0.2, 0.6]], [[0.6, 0.5], [0.2, 0.6]]]).astype(np.float32)), name="m") - self.vmap_grad = vmap(self.net, in_axes=(0, 0, 0, None, None, None, 0), out_axes=0) - - def construct(self, lr, logbase, sign_decay, beta, grad): - return self.vmap_grad(self.var, self.m, lr, logbase, sign_decay, beta, grad) - - -def test_apply_power_sign_op_vmap(): - """ - Feature: ApplyPowerSign cpu kernel - Description: test the ApplyPowerSign vmap. - Expectation: match to expected benchmark output. - """ - def cal_grad(var, m, lr, logbase, sign_decay, beta, grad): - return ops.ApplyPowerSign()(var, m, lr, logbase, sign_decay, beta, grad) - error = 1e-3 - grad = Tensor(np.array([[[0.3, 0.7], [0.1, 0.8]], - [[0.3, 0.7], [0.1, 0.8]]]).astype(np.float32)) - - lr = Tensor(np.array([0.01, 0.01]).astype(np.float32)) - logbase = np.e - sign_decay = 0.99 - beta = 0.9 - - vmap_agrad = PowerSignNetVmap(cal_grad) - output = vmap_agrad(lr, logbase, sign_decay, beta, grad) - mindspore_var_out = output[0].asnumpy() - mindspore_m_out = output[1].asnumpy() - - expect_var = np.array([[[0.5557564, 0.29676488], [0.08525213, 0.38201702]], - [[0.5557564, 0.29676488], [0.08525213, 0.38201702]]]).astype(np.float32) - - expect_m = np.array([[[0.57, 0.52], [0.19, 0.62]], - [[0.57, 0.52], [0.19, 0.62]]]).astype(np.float32) - - np.testing.assert_allclose(mindspore_var_out, expect_var, rtol=error) - np.testing.assert_allclose(mindspore_m_out, expect_m, rtol=error) - - -class PowerSignNetVmap2(nn.Cell): - def __init__(self, net): - super(PowerSignNetVmap2, self).__init__() - self.net = net - self.var = Parameter( - Tensor(np.array([[[[0.6, 0.4], [0.1, 0.5]], [[0.7, 0.4], [0.1, 0.5]]], - [[[0.8, 0.4], [0.1, 0.5]], [[0.9, 0.4], [0.1, 0.5]]]]).astype(np.float32)), name="var") - self.m = Parameter( - Tensor(np.array([[[[0.6, 0.5], [0.2, 0.6]], [[0.7, 0.5], [0.2, 0.6]]], - [[[0.8, 0.5], [0.2, 0.6]], [[0.9, 0.5], [0.2, 0.6]]]]).astype(np.float32)), name="m") - self.vmap_grad = vmap(vmap(self.net, in_axes=( - 0, 0, 0, None, 0, None, 0), out_axes=0), in_axes=(0, 0, 0, None, None, None, 0), out_axes=0) - - def construct(self, lr, logbase, sign_decay, beta, grad): - return self.vmap_grad(self.var, self.m, lr, logbase, sign_decay, beta, grad) - - -def test_apply_power_sign_op_vmap2(): - """ - Feature: ApplyPowerSign cpu kernel - Description: test the ApplyPowerSign vmap. - Expectation: match to expected benchmark output. - """ - def cal_grad(var, m, lr, logbase, sign_decay, beta, grad): - return ops.ApplyPowerSign()(var, m, lr, logbase, sign_decay, beta, grad) - error = 1e-3 - grad = Tensor(np.array([[[[0.3, 0.7], [0.1, 0.8]], [[0.3, 0.7], [0.1, 0.8]]], - [[[0.3, 0.7], [0.1, 0.8]], [[0.3, 0.7], [0.1, 0.8]]]]).astype(np.float32)) - lr = Tensor(np.array([[0.01, 0.02], [0.03, 0.04]]).astype(np.float32)) - logbase = np.e - sign_decay = Tensor(np.array([0.99, 0.9]).astype(np.float32)) - beta = 0.9 - - vmap_agrad = PowerSignNetVmap2(cal_grad) - output = vmap_agrad(lr, logbase, sign_decay, beta, grad) - - mindspore_var_out = output[0].asnumpy() - mindspore_m_out = output[1].asnumpy() - - expect_var = np.array([[[[0.5557564, 0.29676488], [0.08525213, 0.38201702]], - [[0.630716, 0.2383375], [0.07690535, 0.31524283]]], - [[[0.6672691, 0.09029466], [0.05575638, 0.14605102]], - [[0.7614321, 0.076675], [0.05381071, 0.13048568]]]]).astype(np.float32) - - expect_m = np.array([[[[0.57, 0.52], [0.19, 0.62]], - [[0.66, 0.52], [0.19, 0.62]]], - [[[0.75, 0.52], [0.19, 0.62]], - [[0.84, 0.52], [0.19, 0.62]]]]).astype(np.float32) - - np.testing.assert_allclose(mindspore_var_out, expect_var, rtol=error) - np.testing.assert_allclose(mindspore_m_out, expect_m, rtol=error) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor, Parameter +import mindspore.ops as ops +from mindspore.ops.functional import vmap + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.apply_power_sign = ops.ApplyPowerSign() + self.var = Parameter(Tensor(np.array([[0.6, 0.4], + [0.1, 0.5]]).astype(np.float32)), name="var") + self.m = Parameter(Tensor(np.array([[0.6, 0.5], + [0.2, 0.6]]).astype(np.float32)), name="m") + self.lr = 0.001 + self.logbase = np.e + self.sign_decay = 0.99 + self.beta = 0.9 + + def construct(self, grad): + out = self.apply_power_sign(self.var, self.m, self.lr, self.logbase, + self.sign_decay, self.beta, grad) + return out + + +def test_apply_power_assign(): + """ + Feature: test ops ApplyPowerSign. + Description: Update var and m by ApplyPowerSign op. + Expectation: match to expected benchmark output. + """ + grad = Tensor(np.array([[0.3, 0.7], [0.1, 0.8]]).astype(np.float32)) + net = Net() + net(grad) + expect_var = [[5.95575690e-01, 3.89676481e-01], + [9.85252112e-02, 4.88201708e-01]] + expect_m = [[5.70000052e-01, 5.19999981e-01], + [1.89999998e-01, 6.20000064e-01]] + assert np.allclose(net.var.asnumpy(), expect_var, atol=0.0001, rtol=0.0001, equal_nan=True) + assert np.allclose(net.m.asnumpy(), expect_m, atol=0.0001, rtol=0.0001, equal_nan=True) + + +class PowerSignNetVmap(nn.Cell): + def __init__(self, net): + super(PowerSignNetVmap, self).__init__() + self.net = net + self.var = Parameter( + Tensor(np.array([[[0.6, 0.4], [0.1, 0.5]], [[0.6, 0.4], [0.1, 0.5]]]).astype(np.float32)), name="var") + self.m = Parameter( + Tensor(np.array([[[0.6, 0.5], [0.2, 0.6]], [[0.6, 0.5], [0.2, 0.6]]]).astype(np.float32)), name="m") + self.vmap_grad = vmap(self.net, in_axes=(0, 0, 0, None, None, None, 0), out_axes=0) + + def construct(self, lr, logbase, sign_decay, beta, grad): + return self.vmap_grad(self.var, self.m, lr, logbase, sign_decay, beta, grad) + + +def test_apply_power_sign_op_vmap(): + """ + Feature: ApplyPowerSign cpu kernel + Description: test the ApplyPowerSign vmap. + Expectation: match to expected benchmark output. + """ + def cal_grad(var, m, lr, logbase, sign_decay, beta, grad): + return ops.ApplyPowerSign()(var, m, lr, logbase, sign_decay, beta, grad) + error = 1e-3 + grad = Tensor(np.array([[[0.3, 0.7], [0.1, 0.8]], + [[0.3, 0.7], [0.1, 0.8]]]).astype(np.float32)) + + lr = Tensor(np.array([0.01, 0.01]).astype(np.float32)) + logbase = np.e + sign_decay = 0.99 + beta = 0.9 + + vmap_agrad = PowerSignNetVmap(cal_grad) + output = vmap_agrad(lr, logbase, sign_decay, beta, grad) + mindspore_var_out = output[0].asnumpy() + mindspore_m_out = output[1].asnumpy() + + expect_var = np.array([[[0.5557564, 0.29676488], [0.08525213, 0.38201702]], + [[0.5557564, 0.29676488], [0.08525213, 0.38201702]]]).astype(np.float32) + + expect_m = np.array([[[0.57, 0.52], [0.19, 0.62]], + [[0.57, 0.52], [0.19, 0.62]]]).astype(np.float32) + + np.testing.assert_allclose(mindspore_var_out, expect_var, rtol=error) + np.testing.assert_allclose(mindspore_m_out, expect_m, rtol=error) + + +class PowerSignNetVmap2(nn.Cell): + def __init__(self, net): + super(PowerSignNetVmap2, self).__init__() + self.net = net + self.var = Parameter( + Tensor(np.array([[[[0.6, 0.4], [0.1, 0.5]], [[0.7, 0.4], [0.1, 0.5]]], + [[[0.8, 0.4], [0.1, 0.5]], [[0.9, 0.4], [0.1, 0.5]]]]).astype(np.float32)), name="var") + self.m = Parameter( + Tensor(np.array([[[[0.6, 0.5], [0.2, 0.6]], [[0.7, 0.5], [0.2, 0.6]]], + [[[0.8, 0.5], [0.2, 0.6]], [[0.9, 0.5], [0.2, 0.6]]]]).astype(np.float32)), name="m") + self.vmap_grad = vmap(vmap(self.net, in_axes=( + 0, 0, 0, None, 0, None, 0), out_axes=0), in_axes=(0, 0, 0, None, None, None, 0), out_axes=0) + + def construct(self, lr, logbase, sign_decay, beta, grad): + return self.vmap_grad(self.var, self.m, lr, logbase, sign_decay, beta, grad) + + +def test_apply_power_sign_op_vmap2(): + """ + Feature: ApplyPowerSign cpu kernel + Description: test the ApplyPowerSign vmap. + Expectation: match to expected benchmark output. + """ + def cal_grad(var, m, lr, logbase, sign_decay, beta, grad): + return ops.ApplyPowerSign()(var, m, lr, logbase, sign_decay, beta, grad) + error = 1e-3 + grad = Tensor(np.array([[[[0.3, 0.7], [0.1, 0.8]], [[0.3, 0.7], [0.1, 0.8]]], + [[[0.3, 0.7], [0.1, 0.8]], [[0.3, 0.7], [0.1, 0.8]]]]).astype(np.float32)) + lr = Tensor(np.array([[0.01, 0.02], [0.03, 0.04]]).astype(np.float32)) + logbase = np.e + sign_decay = Tensor(np.array([0.99, 0.9]).astype(np.float32)) + beta = 0.9 + + vmap_agrad = PowerSignNetVmap2(cal_grad) + output = vmap_agrad(lr, logbase, sign_decay, beta, grad) + + mindspore_var_out = output[0].asnumpy() + mindspore_m_out = output[1].asnumpy() + + expect_var = np.array([[[[0.5557564, 0.29676488], [0.08525213, 0.38201702]], + [[0.630716, 0.2383375], [0.07690535, 0.31524283]]], + [[[0.6672691, 0.09029466], [0.05575638, 0.14605102]], + [[0.7614321, 0.076675], [0.05381071, 0.13048568]]]]).astype(np.float32) + + expect_m = np.array([[[[0.57, 0.52], [0.19, 0.62]], + [[0.66, 0.52], [0.19, 0.62]]], + [[[0.75, 0.52], [0.19, 0.62]], + [[0.84, 0.52], [0.19, 0.62]]]]).astype(np.float32) + + np.testing.assert_allclose(mindspore_var_out, expect_var, rtol=error) + np.testing.assert_allclose(mindspore_m_out, expect_m, rtol=error) diff --git a/tests/st/ops/cpu/test_argmin_op.py b/tests/st/ops/cpu/test_argmin_op.py index 6835f8aafdb..61a586f35ce 100644 --- a/tests/st/ops/cpu/test_argmin_op.py +++ b/tests/st/ops/cpu/test_argmin_op.py @@ -1,184 +1,184 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import random -from functools import reduce -import numpy as np -import pytest - -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor, ops -from mindspore.common import dtype as mstype - -context.set_context(mode=context.GRAPH_MODE, device_target="CPU") - - -class NetArgmin(nn.Cell): - def __init__(self, axis=0): - super(NetArgmin, self).__init__() - self.argmin = ops.Argmin(axis=axis, output_type=mstype.int32) - - def construct(self, x): - return self.argmin(x) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_argmin_1d(): - """ - Features: The ops Argmin on CPU. - Description: Test Argmin with 1d-input. - Expectation: No exception. - """ - x = Tensor(np.array([1., 20., 5.]).astype(np.float32)) - output = NetArgmin(axis=0)(x) - expect = np.array([0]).astype(np.float32) - assert (output.asnumpy() == expect).all() - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_argmin_2d(): - """ - Features: The ops Argmin on CPU. - Description: Test Argmin with 2d-input. - Expectation: No exception. - """ - x = Tensor(np.array([[1., 20., 5.], - [67., 8., 9.], - [130., 24., 15.]]).astype(np.float32)) - output = NetArgmin(axis=0)(x) - expect = np.array([0, 1, 0]).astype(np.float32) - assert (output.asnumpy() == expect).all() - output = NetArgmin(axis=1)(x) - expect = np.array([0, 1, 2]).astype(np.float32) - assert (output.asnumpy() == expect).all() - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_argmin_high_dims(): - """ - Features: The ops Argmin on CPU. - Description: Test Argmin with random input. - Expectation: No exception. - """ - for dim in range(3, 10): - shape = np.random.randint(1, 10, size=dim) - x = np.random.randn(reduce(lambda x, y: x * y, shape)).astype(np.float32) - x = x.reshape(shape) - - rnd_axis = random.randint(-dim + 1, dim - 1) - ms_output = NetArgmin(axis=rnd_axis)(Tensor(x)) - np_output = np.argmin(x, axis=rnd_axis) - assert (ms_output.asnumpy() == np_output).all() - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_function_argmin(): - """ - Features: The function argmin on CPU. - Description: Test function argmin with random input. - Expectation: No exception. - """ - for dim in range(2, 5): - shape = np.random.randint(1, 10, size=dim) - x = np.random.randn(reduce(lambda x, y: x * y, shape)).astype(np.float32) - x = x.reshape(shape) - - rnd_axis = random.randint(-dim + 1, dim - 1) - ms_output = ops.argmin(Tensor(x), axis=rnd_axis) - np_output = np.argmin(x, axis=rnd_axis) - assert (ms_output.asnumpy() == np_output).all() - - -def cal_argmin_axis_zero(x): - return ops.Argmin(axis=0)(x) - - -def cal_argmin_axis_negative(x): - return ops.Argmin(axis=-1)(x) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_argmin_vmap_axis_zero(): - """ - Features: The argmin vmap on CPU. - Description: Test basic vmap of argmin op. - Expectation: No exception. - """ - x = Tensor([[5., 3., 4.], [2., 4., 3.], [3., 1., 4.]], dtype=mstype.float32) - outputs = ops.vmap(cal_argmin_axis_zero, in_axes=0, out_axes=0)(x) - expect = np.array([1, 0, 1]).astype(np.int32) - assert np.allclose(outputs.asnumpy(), expect) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_argmin_vmap_basic_axis_negative(): - """ - Features: The argmin vmap on CPU. - Description: Test basic vmap of argmin op. - Expectation: No exception. - """ - x = Tensor([[[5., 3., 4.], [2., 4., 3.], [3., 1., 4.]], - [[4., 2., 1.], [3., 4., 5.], [1., 2., 3.]]], dtype=mstype.float32) - outputs = ops.vmap(cal_argmin_axis_negative, in_axes=0, out_axes=0)(x) - expect = np.array([[1, 0, 1], [2, 0, 0]]).astype(np.int32) - assert np.allclose(outputs.asnumpy(), expect) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_argmin_functional(): - """ - Feature: test ops.argmin. - Description: test ops.argmin functional api. - Expectation: the result match with expected result. - """ - x = Tensor([[5., 3., 4.], [2., 4., 3.], [3., 1., 4.]], mstype.int32) - out_dim_none = ops.argmin(x, axis=None, keepdims=False) - out_dim_0 = ops.argmin(x, axis=0, keepdims=False) - out_dim_1 = ops.argmin(x, axis=1, keepdims=False) - out_dim_none_keepdim = ops.argmin(x, axis=None, keepdims=True) - out_dim_0_keepdim = ops.argmin(x, axis=0, keepdims=True) - out_dim_1_keepdim = ops.argmin(x, axis=1, keepdims=True) - - assert out_dim_none.asnumpy() == 7 - assert np.all(out_dim_0.asnumpy() == np.array([1, 2, 1])) - assert np.all(out_dim_1.asnumpy() == np.array([1, 0, 1])) - assert out_dim_none_keepdim.asnumpy() == 7 - assert np.all(out_dim_0_keepdim.asnumpy() == np.array([[1, 2, 1]])) - assert np.all(out_dim_1_keepdim.asnumpy() == np.array([[1], [0], [1]])) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_argmin_tensor(): - """ - Feature: test tensor.argmin. - Description: test argmin tensor api. - Expectation: the result match with expected result. - """ - x = Tensor([[5., 3., 4.], [2., 4., 3.], [3., 1., 4.]], mstype.int32) - out_dim_none = x.argmin(axis=None, keepdims=False) - out_dim_0 = x.argmin(axis=0, keepdims=False) - out_dim_1 = x.argmin(axis=1, keepdims=False) - out_dim_none_keepdim = x.argmin(axis=None, keepdims=True) - out_dim_0_keepdim = x.argmin(axis=0, keepdims=True) - out_dim_1_keepdim = x.argmin(axis=1, keepdims=True) - - assert out_dim_none.asnumpy() == 7 - assert np.all(out_dim_0.asnumpy() == np.array([1, 2, 1])) - assert np.all(out_dim_1.asnumpy() == np.array([1, 0, 1])) - assert out_dim_none_keepdim.asnumpy() == 7 - assert np.all(out_dim_0_keepdim.asnumpy() == np.array([[1, 2, 1]])) - assert np.all(out_dim_1_keepdim.asnumpy() == np.array([[1], [0], [1]])) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import random +from functools import reduce +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor, ops +from mindspore.common import dtype as mstype + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + +class NetArgmin(nn.Cell): + def __init__(self, axis=0): + super(NetArgmin, self).__init__() + self.argmin = ops.Argmin(axis=axis, output_type=mstype.int32) + + def construct(self, x): + return self.argmin(x) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_argmin_1d(): + """ + Features: The ops Argmin on CPU. + Description: Test Argmin with 1d-input. + Expectation: No exception. + """ + x = Tensor(np.array([1., 20., 5.]).astype(np.float32)) + output = NetArgmin(axis=0)(x) + expect = np.array([0]).astype(np.float32) + assert (output.asnumpy() == expect).all() + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_argmin_2d(): + """ + Features: The ops Argmin on CPU. + Description: Test Argmin with 2d-input. + Expectation: No exception. + """ + x = Tensor(np.array([[1., 20., 5.], + [67., 8., 9.], + [130., 24., 15.]]).astype(np.float32)) + output = NetArgmin(axis=0)(x) + expect = np.array([0, 1, 0]).astype(np.float32) + assert (output.asnumpy() == expect).all() + output = NetArgmin(axis=1)(x) + expect = np.array([0, 1, 2]).astype(np.float32) + assert (output.asnumpy() == expect).all() + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_argmin_high_dims(): + """ + Features: The ops Argmin on CPU. + Description: Test Argmin with random input. + Expectation: No exception. + """ + for dim in range(3, 10): + shape = np.random.randint(1, 10, size=dim) + x = np.random.randn(reduce(lambda x, y: x * y, shape)).astype(np.float32) + x = x.reshape(shape) + + rnd_axis = random.randint(-dim + 1, dim - 1) + ms_output = NetArgmin(axis=rnd_axis)(Tensor(x)) + np_output = np.argmin(x, axis=rnd_axis) + assert (ms_output.asnumpy() == np_output).all() + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_function_argmin(): + """ + Features: The function argmin on CPU. + Description: Test function argmin with random input. + Expectation: No exception. + """ + for dim in range(2, 5): + shape = np.random.randint(1, 10, size=dim) + x = np.random.randn(reduce(lambda x, y: x * y, shape)).astype(np.float32) + x = x.reshape(shape) + + rnd_axis = random.randint(-dim + 1, dim - 1) + ms_output = ops.argmin(Tensor(x), axis=rnd_axis) + np_output = np.argmin(x, axis=rnd_axis) + assert (ms_output.asnumpy() == np_output).all() + + +def cal_argmin_axis_zero(x): + return ops.Argmin(axis=0)(x) + + +def cal_argmin_axis_negative(x): + return ops.Argmin(axis=-1)(x) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_argmin_vmap_axis_zero(): + """ + Features: The argmin vmap on CPU. + Description: Test basic vmap of argmin op. + Expectation: No exception. + """ + x = Tensor([[5., 3., 4.], [2., 4., 3.], [3., 1., 4.]], dtype=mstype.float32) + outputs = ops.vmap(cal_argmin_axis_zero, in_axes=0, out_axes=0)(x) + expect = np.array([1, 0, 1]).astype(np.int32) + assert np.allclose(outputs.asnumpy(), expect) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_argmin_vmap_basic_axis_negative(): + """ + Features: The argmin vmap on CPU. + Description: Test basic vmap of argmin op. + Expectation: No exception. + """ + x = Tensor([[[5., 3., 4.], [2., 4., 3.], [3., 1., 4.]], + [[4., 2., 1.], [3., 4., 5.], [1., 2., 3.]]], dtype=mstype.float32) + outputs = ops.vmap(cal_argmin_axis_negative, in_axes=0, out_axes=0)(x) + expect = np.array([[1, 0, 1], [2, 0, 0]]).astype(np.int32) + assert np.allclose(outputs.asnumpy(), expect) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_argmin_functional(): + """ + Feature: test ops.argmin. + Description: test ops.argmin functional api. + Expectation: the result match with expected result. + """ + x = Tensor([[5., 3., 4.], [2., 4., 3.], [3., 1., 4.]], mstype.int32) + out_dim_none = ops.argmin(x, axis=None, keepdims=False) + out_dim_0 = ops.argmin(x, axis=0, keepdims=False) + out_dim_1 = ops.argmin(x, axis=1, keepdims=False) + out_dim_none_keepdim = ops.argmin(x, axis=None, keepdims=True) + out_dim_0_keepdim = ops.argmin(x, axis=0, keepdims=True) + out_dim_1_keepdim = ops.argmin(x, axis=1, keepdims=True) + + assert out_dim_none.asnumpy() == 7 + assert np.all(out_dim_0.asnumpy() == np.array([1, 2, 1])) + assert np.all(out_dim_1.asnumpy() == np.array([1, 0, 1])) + assert out_dim_none_keepdim.asnumpy() == 7 + assert np.all(out_dim_0_keepdim.asnumpy() == np.array([[1, 2, 1]])) + assert np.all(out_dim_1_keepdim.asnumpy() == np.array([[1], [0], [1]])) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_argmin_tensor(): + """ + Feature: test tensor.argmin. + Description: test argmin tensor api. + Expectation: the result match with expected result. + """ + x = Tensor([[5., 3., 4.], [2., 4., 3.], [3., 1., 4.]], mstype.int32) + out_dim_none = x.argmin(axis=None, keepdims=False) + out_dim_0 = x.argmin(axis=0, keepdims=False) + out_dim_1 = x.argmin(axis=1, keepdims=False) + out_dim_none_keepdim = x.argmin(axis=None, keepdims=True) + out_dim_0_keepdim = x.argmin(axis=0, keepdims=True) + out_dim_1_keepdim = x.argmin(axis=1, keepdims=True) + + assert out_dim_none.asnumpy() == 7 + assert np.all(out_dim_0.asnumpy() == np.array([1, 2, 1])) + assert np.all(out_dim_1.asnumpy() == np.array([1, 0, 1])) + assert out_dim_none_keepdim.asnumpy() == 7 + assert np.all(out_dim_0_keepdim.asnumpy() == np.array([[1, 2, 1]])) + assert np.all(out_dim_1_keepdim.asnumpy() == np.array([[1], [0], [1]])) diff --git a/tests/st/ops/cpu/test_argminwithvalue_op.py b/tests/st/ops/cpu/test_argminwithvalue_op.py index b4712134f19..438f5a43cf1 100644 --- a/tests/st/ops/cpu/test_argminwithvalue_op.py +++ b/tests/st/ops/cpu/test_argminwithvalue_op.py @@ -1,161 +1,161 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore as ms -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.ops import operations as P - -context.set_context(mode=context.GRAPH_MODE, device_target="CPU") - - -class NetArgminWithValue(nn.Cell): - - def __init__(self, axis=0, keep_dims=False): - super(NetArgminWithValue, self).__init__() - self.argmin = P.ArgMinWithValue(axis=axis, keep_dims=keep_dims) - - def construct(self, x): - return self.argmin(x) - - -def dyn_case(): - net = NetArgminWithValue() - - x_dyn = Tensor(shape=[None, None], dtype=ms.float32) - net.set_inputs(x_dyn) - - x = Tensor( - np.array([[1., 20., 5.], [67., 8., 9.], [130., 24., 15.], - [-0.5, 25, 100]]).astype(np.float32)) - out = net(x) - - expect_shape = (3,) - for i in range(2): - assert out[i].asnumpy().shape == expect_shape - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_argminwithvalue_dyn(): - """ - Feature: test ArgminWithValue dynamic shape in cpu. - Description: inputs is dynamic shape. - Expectation: expect correct shape result. - """ - context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - dyn_case() - context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU') - dyn_case() - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_argminwithvalue_fp32(): - x = np.array([[1., 20., 5.], [67., 8., 9.], [130., 24., 15.], - [-0.5, 25, 100]]).astype(np.float32) - argmin_a0 = NetArgminWithValue(axis=0, keep_dims=False) - - output0, output1 = argmin_a0(Tensor(x)) - expect0 = np.array([3, 1, 0]).astype(np.int32) - expect1 = np.array([-0.5, 8., 5.]).astype(np.float32) - error = np.ones(shape=expect1.shape) * 1.0e-6 - assert np.all(output0.asnumpy() == expect0) - assert np.all(np.abs(output1.asnumpy() - expect1) < error) - - argmin_a0k = NetArgminWithValue(axis=0, keep_dims=True) - - output0, output1 = argmin_a0k(Tensor(x)) - expect0 = np.array([[3, 1, 0]]).astype(np.int32) - expect1 = np.array([[-0.5, 8., 5.]]).astype(np.float32) - error = np.ones(shape=expect1.shape) * 1.0e-6 - assert np.all(output0.asnumpy() == expect0) - assert np.all(np.abs(output1.asnumpy() - expect1) < error) - - argmin_a1 = NetArgminWithValue(axis=1, keep_dims=False) - - output0, output1 = argmin_a1(Tensor(x)) - expect0 = np.array([0, 1, 2, 0]).astype(np.int32) - expect1 = np.array([1., 8., 15., -0.5]).astype(np.float32) - error = np.ones(shape=expect1.shape) * 1.0e-6 - assert np.all(output0.asnumpy() == expect0) - assert np.all(np.abs(output1.asnumpy() - expect1) < error) - - argmin_a1k = NetArgminWithValue(axis=-1, keep_dims=True) - - output0, output1 = argmin_a1k(Tensor(x)) - expect0 = np.array([[0], [1], [2], [0]]).astype(np.int32) - expect1 = np.array([[1.], [8.], [15.], [-0.5]]).astype(np.float32) - error = np.ones(shape=expect1.shape) * 1.0e-6 - assert np.all(output0.asnumpy() == expect0) - assert np.all(np.abs(output1.asnumpy() - expect1) < error) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_argminwithvalue_fp16(): - x = np.array([[1., 20., 5.], [67., 8., 9.], [130., 24., 15.], - [-0.5, 25, 100]]).astype(np.float16) - argmin_a0 = NetArgminWithValue(axis=0, keep_dims=False) - - output0, output1 = argmin_a0(Tensor(x)) - expect0 = np.array([3, 1, 0]).astype(np.int32) - expect1 = np.array([-0.5, 8., 5.]).astype(np.float16) - error = np.ones(shape=expect1.shape) * 1.0e-6 - assert np.all(output0.asnumpy() == expect0) - assert np.all(np.abs(output1.asnumpy() - expect1) < error) - - argmin_a0k = NetArgminWithValue(axis=0, keep_dims=True) - - output0, output1 = argmin_a0k(Tensor(x)) - expect0 = np.array([[3, 1, 0]]).astype(np.int32) - expect1 = np.array([[-0.5, 8., 5.]]).astype(np.float16) - error = np.ones(shape=expect1.shape) * 1.0e-6 - assert np.all(output0.asnumpy() == expect0) - assert np.all(np.abs(output1.asnumpy() - expect1) < error) - - argmin_a1 = NetArgminWithValue(axis=1, keep_dims=False) - - output0, output1 = argmin_a1(Tensor(x)) - expect0 = np.array([0, 1, 2, 0]).astype(np.int32) - expect1 = np.array([1., 8., 15., -0.5]).astype(np.float16) - error = np.ones(shape=expect1.shape) * 1.0e-6 - assert np.all(output0.asnumpy() == expect0) - assert np.all(np.abs(output1.asnumpy() - expect1) < error) - - argmin_a1k = NetArgminWithValue(axis=-1, keep_dims=True) - - output0, output1 = argmin_a1k(Tensor(x)) - expect0 = np.array([[0], [1], [2], [0]]).astype(np.int32) - expect1 = np.array([[1.], [8.], [15.], [-0.5]]).astype(np.float16) - error = np.ones(shape=expect1.shape) * 1.0e-6 - assert np.all(output0.asnumpy() == expect0) - assert np.all(np.abs(output1.asnumpy() - expect1) < error) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_argminwithvalue_tensor(): - prop = 100 if np.random.random() > 0.5 else -100 - x = np.random.randn(3, 4, 5, 6).astype(np.float16) * prop - argmin_a0 = NetArgminWithValue(axis=-2, keep_dims=False) - - output0, output1 = argmin_a0(Tensor(x)) - expect0 = np.argmin(x, axis=-2) - expect1 = np.min(x, axis=-2).astype(np.float16) - error = np.ones(shape=expect1.shape) * 1.0e-6 - assert np.all(output0.asnumpy() == expect0) - assert np.all(np.abs(output1.asnumpy() - expect1) < error) +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore as ms +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + +class NetArgminWithValue(nn.Cell): + + def __init__(self, axis=0, keep_dims=False): + super(NetArgminWithValue, self).__init__() + self.argmin = P.ArgMinWithValue(axis=axis, keep_dims=keep_dims) + + def construct(self, x): + return self.argmin(x) + + +def dyn_case(): + net = NetArgminWithValue() + + x_dyn = Tensor(shape=[None, None], dtype=ms.float32) + net.set_inputs(x_dyn) + + x = Tensor( + np.array([[1., 20., 5.], [67., 8., 9.], [130., 24., 15.], + [-0.5, 25, 100]]).astype(np.float32)) + out = net(x) + + expect_shape = (3,) + for i in range(2): + assert out[i].asnumpy().shape == expect_shape + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_argminwithvalue_dyn(): + """ + Feature: test ArgminWithValue dynamic shape in cpu. + Description: inputs is dynamic shape. + Expectation: expect correct shape result. + """ + context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + dyn_case() + context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU') + dyn_case() + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_argminwithvalue_fp32(): + x = np.array([[1., 20., 5.], [67., 8., 9.], [130., 24., 15.], + [-0.5, 25, 100]]).astype(np.float32) + argmin_a0 = NetArgminWithValue(axis=0, keep_dims=False) + + output0, output1 = argmin_a0(Tensor(x)) + expect0 = np.array([3, 1, 0]).astype(np.int32) + expect1 = np.array([-0.5, 8., 5.]).astype(np.float32) + error = np.ones(shape=expect1.shape) * 1.0e-6 + assert np.all(output0.asnumpy() == expect0) + assert np.all(np.abs(output1.asnumpy() - expect1) < error) + + argmin_a0k = NetArgminWithValue(axis=0, keep_dims=True) + + output0, output1 = argmin_a0k(Tensor(x)) + expect0 = np.array([[3, 1, 0]]).astype(np.int32) + expect1 = np.array([[-0.5, 8., 5.]]).astype(np.float32) + error = np.ones(shape=expect1.shape) * 1.0e-6 + assert np.all(output0.asnumpy() == expect0) + assert np.all(np.abs(output1.asnumpy() - expect1) < error) + + argmin_a1 = NetArgminWithValue(axis=1, keep_dims=False) + + output0, output1 = argmin_a1(Tensor(x)) + expect0 = np.array([0, 1, 2, 0]).astype(np.int32) + expect1 = np.array([1., 8., 15., -0.5]).astype(np.float32) + error = np.ones(shape=expect1.shape) * 1.0e-6 + assert np.all(output0.asnumpy() == expect0) + assert np.all(np.abs(output1.asnumpy() - expect1) < error) + + argmin_a1k = NetArgminWithValue(axis=-1, keep_dims=True) + + output0, output1 = argmin_a1k(Tensor(x)) + expect0 = np.array([[0], [1], [2], [0]]).astype(np.int32) + expect1 = np.array([[1.], [8.], [15.], [-0.5]]).astype(np.float32) + error = np.ones(shape=expect1.shape) * 1.0e-6 + assert np.all(output0.asnumpy() == expect0) + assert np.all(np.abs(output1.asnumpy() - expect1) < error) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_argminwithvalue_fp16(): + x = np.array([[1., 20., 5.], [67., 8., 9.], [130., 24., 15.], + [-0.5, 25, 100]]).astype(np.float16) + argmin_a0 = NetArgminWithValue(axis=0, keep_dims=False) + + output0, output1 = argmin_a0(Tensor(x)) + expect0 = np.array([3, 1, 0]).astype(np.int32) + expect1 = np.array([-0.5, 8., 5.]).astype(np.float16) + error = np.ones(shape=expect1.shape) * 1.0e-6 + assert np.all(output0.asnumpy() == expect0) + assert np.all(np.abs(output1.asnumpy() - expect1) < error) + + argmin_a0k = NetArgminWithValue(axis=0, keep_dims=True) + + output0, output1 = argmin_a0k(Tensor(x)) + expect0 = np.array([[3, 1, 0]]).astype(np.int32) + expect1 = np.array([[-0.5, 8., 5.]]).astype(np.float16) + error = np.ones(shape=expect1.shape) * 1.0e-6 + assert np.all(output0.asnumpy() == expect0) + assert np.all(np.abs(output1.asnumpy() - expect1) < error) + + argmin_a1 = NetArgminWithValue(axis=1, keep_dims=False) + + output0, output1 = argmin_a1(Tensor(x)) + expect0 = np.array([0, 1, 2, 0]).astype(np.int32) + expect1 = np.array([1., 8., 15., -0.5]).astype(np.float16) + error = np.ones(shape=expect1.shape) * 1.0e-6 + assert np.all(output0.asnumpy() == expect0) + assert np.all(np.abs(output1.asnumpy() - expect1) < error) + + argmin_a1k = NetArgminWithValue(axis=-1, keep_dims=True) + + output0, output1 = argmin_a1k(Tensor(x)) + expect0 = np.array([[0], [1], [2], [0]]).astype(np.int32) + expect1 = np.array([[1.], [8.], [15.], [-0.5]]).astype(np.float16) + error = np.ones(shape=expect1.shape) * 1.0e-6 + assert np.all(output0.asnumpy() == expect0) + assert np.all(np.abs(output1.asnumpy() - expect1) < error) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_argminwithvalue_tensor(): + prop = 100 if np.random.random() > 0.5 else -100 + x = np.random.randn(3, 4, 5, 6).astype(np.float16) * prop + argmin_a0 = NetArgminWithValue(axis=-2, keep_dims=False) + + output0, output1 = argmin_a0(Tensor(x)) + expect0 = np.argmin(x, axis=-2) + expect1 = np.min(x, axis=-2).astype(np.float16) + error = np.ones(shape=expect1.shape) * 1.0e-6 + assert np.all(output0.asnumpy() == expect0) + assert np.all(np.abs(output1.asnumpy() - expect1) < error) diff --git a/tests/st/ops/cpu/test_binary_cross_entropy_op.py b/tests/st/ops/cpu/test_binary_cross_entropy_op.py index bd72d4813f3..0ca3835e724 100644 --- a/tests/st/ops/cpu/test_binary_cross_entropy_op.py +++ b/tests/st/ops/cpu/test_binary_cross_entropy_op.py @@ -1,228 +1,228 @@ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore.context as context -import mindspore.nn as nn -import mindspore as ms -from mindspore import Tensor -from mindspore.ops import composite as C -from mindspore.ops import operations as P -from mindspore.ops import functional as F - -context.set_context(mode=context.GRAPH_MODE, device_target="CPU") - - -class Net(nn.Cell): - def __init__(self, reduction="none"): - super(Net, self).__init__() - self.bce = P.BinaryCrossEntropy(reduction) - - def construct(self, x, y, weight=None): - return self.bce(x, y, weight) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_binary_cross_entropy_loss(): - np.random.seed(42) - prediction = np.random.rand(20).astype(np.float32) - target = np.random.rand(20).astype(np.float32) - weight = np.random.rand(20).astype(np.float32) - reduction = "none" - net = Net(reduction) - loss = net(Tensor(prediction), Tensor(target), Tensor(weight)) - expect = [0.09555826, 1.2861121, 0.03518666, 0.6969416, 0.24313456, 0.99062896, - 0.19205657, 0.5465214, 0.36964455, 0.21999404, 2.2953863, 2.2566645, - 1.5803775, 1.3266402, 0.9883408, 1.2997618, 0.05439841, 0.14389999, - 0.03405444, 0.23934692] - assert np.allclose(loss.asnumpy(), expect) - - -def test_binary_cross_entropy_loss_mean(): - np.random.seed(42) - prediction = np.random.rand(20).astype(np.float32) - target = np.random.rand(20).astype(np.float32) - weight = np.random.rand(20).astype(np.float32) - reduction = "mean" - net = Net(reduction) - loss = net(Tensor(prediction), Tensor(target), Tensor(weight)) - expect = [0.7447324991226196] - assert loss.asnumpy() == expect - - -def test_binary_cross_entropy_loss_sum(): - np.random.seed(42) - prediction = np.random.rand(20).astype(np.float32) - target = np.random.rand(20).astype(np.float32) - weight = np.random.rand(20).astype(np.float32) - reduction = "sum" - net = Net(reduction) - loss = net(Tensor(prediction), Tensor(target), Tensor(weight)) - expect = [14.894649505615234] - assert loss.asnumpy() == expect - - -def test_binary_cross_entropy_loss_sum_without_weight(): - np.random.seed(42) - prediction = np.random.rand(20).astype(np.float32) - target = np.random.rand(20).astype(np.float32) - reduction = "sum" - net = Net(reduction) - loss = net(Tensor(prediction), Tensor(target)) - expect = [25.48195216753522] - assert np.allclose(loss.asnumpy(), expect) - - -def test_binary_cross_entropy_loss_16(): - np.random.seed(42) - prediction = np.random.rand(20).astype(np.float16) - target = np.random.rand(20).astype(np.float16) - weight = np.random.rand(20).astype(np.float16) - reduction = "none" - net = Net(reduction) - loss = net(Tensor(prediction), Tensor(target), Tensor(weight)) - expect = [0.09552, 1.28613, 0.0351868, 0.696777, 0.243164, 0.990234, - 0.192139, 0.546875, 0.370117, 0.219971, 2.29492, 2.25391, - 1.58105, 1.32812, 0.987305, 1.30078, 0.0544434, 0.143921, - 0.0340576, 0.239258] - assert np.allclose(loss.asnumpy(), expect) - - -def test_binary_cross_entropy_loss_mean_16(): - np.random.seed(42) - prediction = np.random.rand(20).astype(np.float16) - target = np.random.rand(20).astype(np.float16) - weight = np.random.rand(20).astype(np.float16) - reduction = "mean" - net = Net(reduction) - loss = net(Tensor(prediction), Tensor(target), Tensor(weight)) - expect = [0.74462890625] - assert loss.asnumpy() == expect - - -def test_binary_cross_entropy_loss_sum_16(): - np.random.seed(42) - prediction = np.random.rand(20).astype(np.float16) - target = np.random.rand(20).astype(np.float16) - weight = np.random.rand(20).astype(np.float16) - reduction = "sum" - net = Net(reduction) - loss = net(Tensor(prediction), Tensor(target), Tensor(weight)) - expect = [14.890625] - assert loss.asnumpy() == expect - - -class Grad(nn.Cell): - def __init__(self, network): - super(Grad, self).__init__() - self.grad = C.GradOperation(get_all=True, sens_param=True) - self.network = network - - def construct(self, x1, x2, sens, weight=None): - gout = self.grad(self.network)(x1, x2, sens, weight) - return gout - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_binary_cross_entropy_loss_grad(): - np.random.seed(42) - prediction = np.random.rand(20).astype(np.float32) - target = np.random.rand(20).astype(np.float32) - sens = np.random.rand(20).astype(np.float32) - weight = np.random.rand(20).astype(np.float32) - reduction = "none" - grad = Grad(Net(reduction)) - dx = grad(Tensor(prediction), Tensor(target), Tensor(sens), Tensor(weight)) - - dx1_expect = [-4.80516590e-02, 2.32625079e+00, 6.38972521e-02, 3.13642323e-01, - -1.65661633e-01, -1.71821892e+00, -1.13685496e-01, 1.26669514e+00, - 1.47891801e-03, 5.83921909e-01, -2.17992840e+01, 4.21899414e+00, - 2.85430793e-02, -3.21346498e+00, -2.22674108e+00, -2.80453944e+00, - -1.19787852e-04, 2.48514321e-02, -1.66696273e-02, -2.71965731e-02] - - assert np.allclose(dx[0].asnumpy(), dx1_expect) - - -def test_binary_cross_entropy_forward_functional(nptype): - """ - Feature: test binary_cross_entropy forward for given input dtype. - Description: test inputs for given input dtype. - Expectation: the result match with expected result. - """ - logits = Tensor(np.array([0.2, 0.7, 0.1]).astype(nptype)) - labels = Tensor(np.array([0., 1., 0.]).astype(nptype)) - weight = Tensor(np.array([1, 2, 2]).astype(nptype)) - output = F.binary_cross_entropy(logits, labels, weight) - expected = Tensor(np.array([0.38240486]).astype(nptype)) - np.testing.assert_array_almost_equal(output.asnumpy(), expected) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_binary_cross_entropy_forward_float32_functional(): - """ - Feature: test binary_cross_entropy forward. - Description: test float32 inputs. - Expectation: the result match with expected result. - """ - context.set_context(mode=context.GRAPH_MODE, device_target="CPU") - test_binary_cross_entropy_forward_functional(np.float32) - context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU") - test_binary_cross_entropy_forward_functional(np.float32) - -@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_binary_cross_entropy_forward_float32_functional_with_optional(mode): - """ - Feature: test binary_cross_entropy forward with optional input. - Description: test float32 inputs with optional input. - Expectation: without error. - """ - context.set_context(mode=mode) - logits = Tensor(np.array([0.2, 0.7, 0.1]).astype(np.float32)) - labels = Tensor(np.array([0., 1., 0.]).astype(np.float32)) - weight = None - output = F.binary_cross_entropy(logits, labels, weight) - print(output.asnumpy()) - -class GradNet(nn.Cell): - def __init__(self, network): - super(GradNet, self).__init__() - self.grad = C.GradOperation(get_all=True) - self.network = network - - def construct(self, x1, x2, weight=None): - gout = self.grad(self.network)(x1, x2, weight) - return gout - -@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_binary_cross_entropy_loss_grad_with_optional(mode): - """ - Feature: test binary_cross_entropy backward with optional input. - Description: test float32 inputs with optional input. - Expectation: without error. - """ - context.set_context(mode=mode) - np.random.seed(42) - prediction = np.random.rand(20).astype(np.float32) - target = np.random.rand(20).astype(np.float32) - weight = None - reduction = "none" - grad = GradNet(Net(reduction)) - dx = grad(Tensor(prediction), Tensor(target), weight) - print(dx[0].asnumpy()) +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +import mindspore as ms +from mindspore import Tensor +from mindspore.ops import composite as C +from mindspore.ops import operations as P +from mindspore.ops import functional as F + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + +class Net(nn.Cell): + def __init__(self, reduction="none"): + super(Net, self).__init__() + self.bce = P.BinaryCrossEntropy(reduction) + + def construct(self, x, y, weight=None): + return self.bce(x, y, weight) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_binary_cross_entropy_loss(): + np.random.seed(42) + prediction = np.random.rand(20).astype(np.float32) + target = np.random.rand(20).astype(np.float32) + weight = np.random.rand(20).astype(np.float32) + reduction = "none" + net = Net(reduction) + loss = net(Tensor(prediction), Tensor(target), Tensor(weight)) + expect = [0.09555826, 1.2861121, 0.03518666, 0.6969416, 0.24313456, 0.99062896, + 0.19205657, 0.5465214, 0.36964455, 0.21999404, 2.2953863, 2.2566645, + 1.5803775, 1.3266402, 0.9883408, 1.2997618, 0.05439841, 0.14389999, + 0.03405444, 0.23934692] + assert np.allclose(loss.asnumpy(), expect) + + +def test_binary_cross_entropy_loss_mean(): + np.random.seed(42) + prediction = np.random.rand(20).astype(np.float32) + target = np.random.rand(20).astype(np.float32) + weight = np.random.rand(20).astype(np.float32) + reduction = "mean" + net = Net(reduction) + loss = net(Tensor(prediction), Tensor(target), Tensor(weight)) + expect = [0.7447324991226196] + assert loss.asnumpy() == expect + + +def test_binary_cross_entropy_loss_sum(): + np.random.seed(42) + prediction = np.random.rand(20).astype(np.float32) + target = np.random.rand(20).astype(np.float32) + weight = np.random.rand(20).astype(np.float32) + reduction = "sum" + net = Net(reduction) + loss = net(Tensor(prediction), Tensor(target), Tensor(weight)) + expect = [14.894649505615234] + assert loss.asnumpy() == expect + + +def test_binary_cross_entropy_loss_sum_without_weight(): + np.random.seed(42) + prediction = np.random.rand(20).astype(np.float32) + target = np.random.rand(20).astype(np.float32) + reduction = "sum" + net = Net(reduction) + loss = net(Tensor(prediction), Tensor(target)) + expect = [25.48195216753522] + assert np.allclose(loss.asnumpy(), expect) + + +def test_binary_cross_entropy_loss_16(): + np.random.seed(42) + prediction = np.random.rand(20).astype(np.float16) + target = np.random.rand(20).astype(np.float16) + weight = np.random.rand(20).astype(np.float16) + reduction = "none" + net = Net(reduction) + loss = net(Tensor(prediction), Tensor(target), Tensor(weight)) + expect = [0.09552, 1.28613, 0.0351868, 0.696777, 0.243164, 0.990234, + 0.192139, 0.546875, 0.370117, 0.219971, 2.29492, 2.25391, + 1.58105, 1.32812, 0.987305, 1.30078, 0.0544434, 0.143921, + 0.0340576, 0.239258] + assert np.allclose(loss.asnumpy(), expect) + + +def test_binary_cross_entropy_loss_mean_16(): + np.random.seed(42) + prediction = np.random.rand(20).astype(np.float16) + target = np.random.rand(20).astype(np.float16) + weight = np.random.rand(20).astype(np.float16) + reduction = "mean" + net = Net(reduction) + loss = net(Tensor(prediction), Tensor(target), Tensor(weight)) + expect = [0.74462890625] + assert loss.asnumpy() == expect + + +def test_binary_cross_entropy_loss_sum_16(): + np.random.seed(42) + prediction = np.random.rand(20).astype(np.float16) + target = np.random.rand(20).astype(np.float16) + weight = np.random.rand(20).astype(np.float16) + reduction = "sum" + net = Net(reduction) + loss = net(Tensor(prediction), Tensor(target), Tensor(weight)) + expect = [14.890625] + assert loss.asnumpy() == expect + + +class Grad(nn.Cell): + def __init__(self, network): + super(Grad, self).__init__() + self.grad = C.GradOperation(get_all=True, sens_param=True) + self.network = network + + def construct(self, x1, x2, sens, weight=None): + gout = self.grad(self.network)(x1, x2, sens, weight) + return gout + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_binary_cross_entropy_loss_grad(): + np.random.seed(42) + prediction = np.random.rand(20).astype(np.float32) + target = np.random.rand(20).astype(np.float32) + sens = np.random.rand(20).astype(np.float32) + weight = np.random.rand(20).astype(np.float32) + reduction = "none" + grad = Grad(Net(reduction)) + dx = grad(Tensor(prediction), Tensor(target), Tensor(sens), Tensor(weight)) + + dx1_expect = [-4.80516590e-02, 2.32625079e+00, 6.38972521e-02, 3.13642323e-01, + -1.65661633e-01, -1.71821892e+00, -1.13685496e-01, 1.26669514e+00, + 1.47891801e-03, 5.83921909e-01, -2.17992840e+01, 4.21899414e+00, + 2.85430793e-02, -3.21346498e+00, -2.22674108e+00, -2.80453944e+00, + -1.19787852e-04, 2.48514321e-02, -1.66696273e-02, -2.71965731e-02] + + assert np.allclose(dx[0].asnumpy(), dx1_expect) + + +def test_binary_cross_entropy_forward_functional(nptype): + """ + Feature: test binary_cross_entropy forward for given input dtype. + Description: test inputs for given input dtype. + Expectation: the result match with expected result. + """ + logits = Tensor(np.array([0.2, 0.7, 0.1]).astype(nptype)) + labels = Tensor(np.array([0., 1., 0.]).astype(nptype)) + weight = Tensor(np.array([1, 2, 2]).astype(nptype)) + output = F.binary_cross_entropy(logits, labels, weight) + expected = Tensor(np.array([0.38240486]).astype(nptype)) + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_binary_cross_entropy_forward_float32_functional(): + """ + Feature: test binary_cross_entropy forward. + Description: test float32 inputs. + Expectation: the result match with expected result. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + test_binary_cross_entropy_forward_functional(np.float32) + context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU") + test_binary_cross_entropy_forward_functional(np.float32) + +@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_binary_cross_entropy_forward_float32_functional_with_optional(mode): + """ + Feature: test binary_cross_entropy forward with optional input. + Description: test float32 inputs with optional input. + Expectation: without error. + """ + context.set_context(mode=mode) + logits = Tensor(np.array([0.2, 0.7, 0.1]).astype(np.float32)) + labels = Tensor(np.array([0., 1., 0.]).astype(np.float32)) + weight = None + output = F.binary_cross_entropy(logits, labels, weight) + print(output.asnumpy()) + +class GradNet(nn.Cell): + def __init__(self, network): + super(GradNet, self).__init__() + self.grad = C.GradOperation(get_all=True) + self.network = network + + def construct(self, x1, x2, weight=None): + gout = self.grad(self.network)(x1, x2, weight) + return gout + +@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_binary_cross_entropy_loss_grad_with_optional(mode): + """ + Feature: test binary_cross_entropy backward with optional input. + Description: test float32 inputs with optional input. + Expectation: without error. + """ + context.set_context(mode=mode) + np.random.seed(42) + prediction = np.random.rand(20).astype(np.float32) + target = np.random.rand(20).astype(np.float32) + weight = None + reduction = "none" + grad = GradNet(Net(reduction)) + dx = grad(Tensor(prediction), Tensor(target), weight) + print(dx[0].asnumpy()) diff --git a/tests/st/ops/cpu/test_cpu_type.py b/tests/st/ops/cpu/test_cpu_type.py index 25d0820b877..eaac5329d38 100644 --- a/tests/st/ops/cpu/test_cpu_type.py +++ b/tests/st/ops/cpu/test_cpu_type.py @@ -1,113 +1,113 @@ -from tests.mark_utils import arg_mark -import numpy as np -import pytest -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.ops import operations as P -import mindspore.context as context -from mindspore.nn import Dense -from mindspore.nn import TrainOneStepCell, WithLossCell -from mindspore.nn.optim import Momentum - -context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - - -class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - self.bias_add = P.BiasAdd() - self.bias_add1 = P.BiasAdd() - - def construct(self, x, b, c): - return self.bias_add1(self.bias_add(x, b), c) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_bias_add1(): - x = np.ones([2, 2]).astype(np.float16) - b = np.array([1, 1]).astype(np.float16) - c = np.array([1, 1]).astype(np.float16) - bias_add = Net() - output = bias_add(Tensor(x), Tensor(b), Tensor(c)) - expect_output = np.ones([2, 2]).astype(np.float16) * 3 - assert np.all(output.asnumpy() == expect_output) - - -class Net1(nn.Cell): - def __init__(self): - super(Net1, self).__init__() - self.bias_add = P.BiasAdd() - self.mul = P.Mul() - - def construct(self, x, a, b): - p1 = self.bias_add(x, b) - p2 = self.bias_add(x, a) - p3 = self.mul(p1, p2) - return p3 - - -class Net2(nn.Cell): - def __init__(self): - super(Net2, self).__init__() - self.bias_add = P.BiasAdd() - self.bias_add1 = P.BiasAdd() - - def construct(self, x, b, c): - return self.bias_add1(self.bias_add(x, b), c) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_bias_add2(): - x = np.ones([2, 2]).astype(np.float32) - a = np.array([1, 1]).astype(np.float32) - b = np.array([1, 1]).astype(np.float32) - c = np.array([1, 1]).astype(np.float32) - bias_add = Net1() - output = bias_add(Tensor(x), Tensor(a), Tensor(b)) - print(output) - - net2 = Net2() - output2 = net2(Tensor(x), Tensor(b), Tensor(c)) - print(output2) - - -context.set_context(mode=context.GRAPH_MODE, device_target="CPU") - - -class MomentumNet(nn.Cell): - def __init__(self): - super(MomentumNet, self).__init__() - self.batch_size = 1 - - self.reshape = P.Reshape() - weight = Tensor(np.ones([10, 16]).astype(np.float32) * 0.01) - self.fc1 = Dense(16, 10, weight_init=weight) - - def construct(self, input_x): - output = self.reshape(input_x, (self.batch_size, -1)) - output = self.fc1(output) - return output - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_momentum(): - epoch = 1 - net = MomentumNet() - learning_rate = (0.1, 0.2) - momentum = 0.9 - - optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum) - criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') - net_with_criterion = WithLossCell(net, criterion) - train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer - train_network.set_train() - losses = [] - for _ in range(epoch): - data = Tensor(np.arange(0, 16).reshape(1, 1, 4, 4).astype(np.float32) * 0.01) - label = Tensor(np.array([0]).astype(np.int32)) - loss = train_network(data, label) - losses.append(loss) - print("================================") - print(losses) - - return losses +from tests.mark_utils import arg_mark +import numpy as np +import pytest +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P +import mindspore.context as context +from mindspore.nn import Dense +from mindspore.nn import TrainOneStepCell, WithLossCell +from mindspore.nn.optim import Momentum + +context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.bias_add = P.BiasAdd() + self.bias_add1 = P.BiasAdd() + + def construct(self, x, b, c): + return self.bias_add1(self.bias_add(x, b), c) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_bias_add1(): + x = np.ones([2, 2]).astype(np.float16) + b = np.array([1, 1]).astype(np.float16) + c = np.array([1, 1]).astype(np.float16) + bias_add = Net() + output = bias_add(Tensor(x), Tensor(b), Tensor(c)) + expect_output = np.ones([2, 2]).astype(np.float16) * 3 + assert np.all(output.asnumpy() == expect_output) + + +class Net1(nn.Cell): + def __init__(self): + super(Net1, self).__init__() + self.bias_add = P.BiasAdd() + self.mul = P.Mul() + + def construct(self, x, a, b): + p1 = self.bias_add(x, b) + p2 = self.bias_add(x, a) + p3 = self.mul(p1, p2) + return p3 + + +class Net2(nn.Cell): + def __init__(self): + super(Net2, self).__init__() + self.bias_add = P.BiasAdd() + self.bias_add1 = P.BiasAdd() + + def construct(self, x, b, c): + return self.bias_add1(self.bias_add(x, b), c) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_bias_add2(): + x = np.ones([2, 2]).astype(np.float32) + a = np.array([1, 1]).astype(np.float32) + b = np.array([1, 1]).astype(np.float32) + c = np.array([1, 1]).astype(np.float32) + bias_add = Net1() + output = bias_add(Tensor(x), Tensor(a), Tensor(b)) + print(output) + + net2 = Net2() + output2 = net2(Tensor(x), Tensor(b), Tensor(c)) + print(output2) + + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + +class MomentumNet(nn.Cell): + def __init__(self): + super(MomentumNet, self).__init__() + self.batch_size = 1 + + self.reshape = P.Reshape() + weight = Tensor(np.ones([10, 16]).astype(np.float32) * 0.01) + self.fc1 = Dense(16, 10, weight_init=weight) + + def construct(self, input_x): + output = self.reshape(input_x, (self.batch_size, -1)) + output = self.fc1(output) + return output + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_momentum(): + epoch = 1 + net = MomentumNet() + learning_rate = (0.1, 0.2) + momentum = 0.9 + + optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum) + criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') + net_with_criterion = WithLossCell(net, criterion) + train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer + train_network.set_train() + losses = [] + for _ in range(epoch): + data = Tensor(np.arange(0, 16).reshape(1, 1, 4, 4).astype(np.float32) * 0.01) + label = Tensor(np.array([0]).astype(np.int32)) + loss = train_network(data, label) + losses.append(loss) + print("================================") + print(losses) + + return losses diff --git a/tests/st/ops/cpu/test_equal_op.py b/tests/st/ops/cpu/test_equal_op.py index 9f012b8e856..22bca864b25 100644 --- a/tests/st/ops/cpu/test_equal_op.py +++ b/tests/st/ops/cpu/test_equal_op.py @@ -1,123 +1,123 @@ -# Copyright 2020-2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.common.initializer import initializer -from mindspore.common.parameter import Parameter -from mindspore.ops import operations as P -from mindspore.common import dtype as mstype - -context.set_context(mode=context.GRAPH_MODE, device_target="CPU") - - -class NetEqBool(nn.Cell): - def __init__(self): - super(NetEqBool, self).__init__() - self.equal = P.Equal() - x = Tensor(np.array([True, True, False]).astype(np.bool)) - y = Tensor(np.array([True, False, True]).astype(np.bool)) - self.x = Parameter(initializer(x, x.shape), name="x") - self.y = Parameter(initializer(y, y.shape), name="y") - - def construct(self): - return self.equal(self.x, self.y) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_equal_bool(): - equal_net = NetEqBool() - output = equal_net() - print("================================") - expect = np.array([True, False, False]).astype(np.bool) - print(output) - assert (output.asnumpy() == expect).all() - - -class NetEqInt(nn.Cell): - def __init__(self): - super(NetEqInt, self).__init__() - self.equal = P.Equal() - x = Tensor(np.array([1, 20, 5]).astype(np.int32)) - y = Tensor(np.array([2, 20, 5]).astype(np.int32)) - self.x = Parameter(initializer(x, x.shape), name="x") - self.y = Parameter(initializer(y, y.shape), name="y") - - def construct(self): - return self.equal(self.x, self.y) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_equal_int(): - equal_net = NetEqInt() - output = equal_net() - print("================================") - expect = np.array([False, True, True]).astype(np.bool) - print(output) - assert (output.asnumpy() == expect).all() - - -class NetEqFloat(nn.Cell): - def __init__(self): - super(NetEqFloat, self).__init__() - self.equal = P.Equal() - x = Tensor(np.array([1.2, 10.4, 5.5]).astype(np.float32)) - y = Tensor(np.array([1.2, 10.3, 5.4]).astype(np.float32)) - self.x = Parameter(initializer(x, x.shape), name="x") - self.y = Parameter(initializer(y, y.shape), name="y") - - def construct(self): - return self.equal(self.x, self.y) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_equal_float(): - equal_net = NetEqFloat() - output = equal_net() - print("================================") - expect = np.array([True, False, False]).astype(np.bool) - print(output) - assert (output.asnumpy() == expect).all() - - -def test_equal_tensor_api(): - """ - Feature: test equal tensor API. - Description: testcase for equal tensor API. - Expectation: the result match with expected result. - """ - x = Tensor(np.array([1, 2, 3]), mstype.int32) - y = Tensor(np.array([1, 2, 4]), mstype.int32) - output = x.equal(y) - expected = np.array([True, True, False]) - np.testing.assert_array_equal(output.asnumpy(), expected) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_equal_tensor_modes(): - """ - Feature: test equal tensor API in PyNative and Graph modes. - Description: test case for equal tensor API. - Expectation: the result match with expected result. - """ - context.set_context(mode=context.GRAPH_MODE, device_target="CPU") - test_equal_tensor_api() - context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU") - test_equal_tensor_api() +# Copyright 2020-2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common.initializer import initializer +from mindspore.common.parameter import Parameter +from mindspore.ops import operations as P +from mindspore.common import dtype as mstype + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + +class NetEqBool(nn.Cell): + def __init__(self): + super(NetEqBool, self).__init__() + self.equal = P.Equal() + x = Tensor(np.array([True, True, False]).astype(np.bool)) + y = Tensor(np.array([True, False, True]).astype(np.bool)) + self.x = Parameter(initializer(x, x.shape), name="x") + self.y = Parameter(initializer(y, y.shape), name="y") + + def construct(self): + return self.equal(self.x, self.y) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_equal_bool(): + equal_net = NetEqBool() + output = equal_net() + print("================================") + expect = np.array([True, False, False]).astype(np.bool) + print(output) + assert (output.asnumpy() == expect).all() + + +class NetEqInt(nn.Cell): + def __init__(self): + super(NetEqInt, self).__init__() + self.equal = P.Equal() + x = Tensor(np.array([1, 20, 5]).astype(np.int32)) + y = Tensor(np.array([2, 20, 5]).astype(np.int32)) + self.x = Parameter(initializer(x, x.shape), name="x") + self.y = Parameter(initializer(y, y.shape), name="y") + + def construct(self): + return self.equal(self.x, self.y) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_equal_int(): + equal_net = NetEqInt() + output = equal_net() + print("================================") + expect = np.array([False, True, True]).astype(np.bool) + print(output) + assert (output.asnumpy() == expect).all() + + +class NetEqFloat(nn.Cell): + def __init__(self): + super(NetEqFloat, self).__init__() + self.equal = P.Equal() + x = Tensor(np.array([1.2, 10.4, 5.5]).astype(np.float32)) + y = Tensor(np.array([1.2, 10.3, 5.4]).astype(np.float32)) + self.x = Parameter(initializer(x, x.shape), name="x") + self.y = Parameter(initializer(y, y.shape), name="y") + + def construct(self): + return self.equal(self.x, self.y) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_equal_float(): + equal_net = NetEqFloat() + output = equal_net() + print("================================") + expect = np.array([True, False, False]).astype(np.bool) + print(output) + assert (output.asnumpy() == expect).all() + + +def test_equal_tensor_api(): + """ + Feature: test equal tensor API. + Description: testcase for equal tensor API. + Expectation: the result match with expected result. + """ + x = Tensor(np.array([1, 2, 3]), mstype.int32) + y = Tensor(np.array([1, 2, 4]), mstype.int32) + output = x.equal(y) + expected = np.array([True, True, False]) + np.testing.assert_array_equal(output.asnumpy(), expected) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_equal_tensor_modes(): + """ + Feature: test equal tensor API in PyNative and Graph modes. + Description: test case for equal tensor API. + Expectation: the result match with expected result. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + test_equal_tensor_api() + context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU") + test_equal_tensor_api() diff --git a/tests/st/ops/cpu/test_equalcount_op.py b/tests/st/ops/cpu/test_equalcount_op.py index 0fe080e11c9..fdf10327af6 100644 --- a/tests/st/ops/cpu/test_equalcount_op.py +++ b/tests/st/ops/cpu/test_equalcount_op.py @@ -1,81 +1,81 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.common.initializer import initializer -from mindspore.common.parameter import Parameter -from mindspore.ops import operations as P - -context.set_context(mode=context.GRAPH_MODE, device_target="CPU") - - -class NetEqualCount(nn.Cell): - def __init__(self): - super(NetEqualCount, self).__init__() - self.equalcount = P.EqualCount() - x = Tensor(np.array([1, 20, 5]).astype(np.int32)) - y = Tensor(np.array([2, 20, 5]).astype(np.int32)) - self.x = Parameter(initializer(x, x.shape), name='x') - self.y = Parameter(initializer(y, y.shape), name='y') - - def construct(self): - return self.equalcount(self.x, self.y) - - -class NetEqualCount2(nn.Cell): - def __init__(self): - super(NetEqualCount2, self).__init__() - self.equalcount = P.EqualCount() - - def construct(self, x, y): - return self.equalcount(x, y) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_equalcount(): - net = NetEqualCount() - output = net() - print("================================") - expect = np.array([2]).astype(np.int32) - print(output) - assert (output.asnumpy() == expect).all() - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_equalcount_dynamic(): - """ - Feature: EqualCount ops - Description: dynamic shape in cpu - Expectation: success - """ - net = NetEqualCount2() - xx = Tensor(shape=[None], dtype=mindspore.int32) - yy = Tensor(shape=[None], dtype=mindspore.int32) - net.set_inputs(xx, yy) - - x = Tensor(np.array([1, 20, 5]).astype(np.int32)) - y = Tensor(np.array([2, 20, 5]).astype(np.int32)) - output = net(x, y) - print("================================") - expect = np.array([2]).astype(np.int32) - print(output) - assert (output.asnumpy() == expect).all() +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common.initializer import initializer +from mindspore.common.parameter import Parameter +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + +class NetEqualCount(nn.Cell): + def __init__(self): + super(NetEqualCount, self).__init__() + self.equalcount = P.EqualCount() + x = Tensor(np.array([1, 20, 5]).astype(np.int32)) + y = Tensor(np.array([2, 20, 5]).astype(np.int32)) + self.x = Parameter(initializer(x, x.shape), name='x') + self.y = Parameter(initializer(y, y.shape), name='y') + + def construct(self): + return self.equalcount(self.x, self.y) + + +class NetEqualCount2(nn.Cell): + def __init__(self): + super(NetEqualCount2, self).__init__() + self.equalcount = P.EqualCount() + + def construct(self, x, y): + return self.equalcount(x, y) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_equalcount(): + net = NetEqualCount() + output = net() + print("================================") + expect = np.array([2]).astype(np.int32) + print(output) + assert (output.asnumpy() == expect).all() + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_equalcount_dynamic(): + """ + Feature: EqualCount ops + Description: dynamic shape in cpu + Expectation: success + """ + net = NetEqualCount2() + xx = Tensor(shape=[None], dtype=mindspore.int32) + yy = Tensor(shape=[None], dtype=mindspore.int32) + net.set_inputs(xx, yy) + + x = Tensor(np.array([1, 20, 5]).astype(np.int32)) + y = Tensor(np.array([2, 20, 5]).astype(np.int32)) + output = net(x, y) + print("================================") + expect = np.array([2]).astype(np.int32) + print(output) + assert (output.asnumpy() == expect).all() diff --git a/tests/st/ops/cpu/test_fftwithsize.py b/tests/st/ops/cpu/test_fftwithsize.py old mode 100755 new mode 100644 index 98741e7d3d6..dcb136a2e89 --- a/tests/st/ops/cpu/test_fftwithsize.py +++ b/tests/st/ops/cpu/test_fftwithsize.py @@ -1,188 +1,188 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import numpy as np -import pytest - -import mindspore as ms -import mindspore.context as context -from mindspore import Tensor, ops -from tests.st.utils import test_utils -from tests.mark_utils import arg_mark - - -@test_utils.run_with_cell -def fft_forward_func(x, signal_ndim, inverse, real, norm='backward', onesided=True, signal_sizes=()): - return ops.FFTWithSize(signal_ndim, inverse, real, norm, onesided, signal_sizes)(x) - - -@test_utils.run_with_cell -def rfft_and_irfft_forward_func(x, signal_ndim, inverse, real, norm='backward', onesided=True, signal_sizes=()): - x = ops.FFTWithSize(signal_ndim, inverse, real, norm, onesided, signal_sizes)(x) - return ops.FFTWithSize(signal_ndim, not inverse, real, norm, onesided, signal_sizes)(x) - - -@test_utils.run_with_cell -def fft_backward_func(x, signal_ndim, inverse, real, norm='backward', onesided=True, signal_sizes=()): - return ops.grad(fft_forward_func, (0,))(x, signal_ndim, inverse, real, norm, onesided, signal_sizes) - - -@test_utils.run_with_cell -def rfft_and_irfft_backward_func(x, signal_ndim, inverse, real, norm='backward', onesided=True, signal_sizes=()): - return ops.grad(rfft_and_irfft_forward_func, (0,))(x, signal_ndim, inverse, real, norm, onesided, signal_sizes) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', - essential_mark='essential') -@pytest.mark.parametrize('dtype, eps', [(np.complex64, 1e-6), (np.complex128, 1e-6)]) -def test_fftwithsize_fft_ifft(dtype, eps): - """ - Feature: fft & ifft function - Description: test cases for fft & ifft - Expectation: the result matches pytorch - """ - x = Tensor(np.array([1.6243454+0.j, -0.6117564+0.j, -0.5281718+0.j, -1.0729686+0.j]).astype(dtype)) - expect = np.array([-0.5885514+0.j, 2.1525173-0.46121222j, 2.7808986+0.j, 2.1525173+0.46121222j]).astype(dtype) - error = np.ones(shape=[4]) * eps - context.set_context(mode=context.GRAPH_MODE, device_target="CPU") - - output = fft_forward_func(x, 1, False, False) - diff = np.abs(output.asnumpy() - expect) - assert np.all(diff < error) - - output_ifft = fft_forward_func(output, 1, True, False) - diff_ifft = np.abs(output_ifft.asnumpy() - x.asnumpy()) - assert np.all(diff_ifft < error) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', - essential_mark='essential') -@pytest.mark.parametrize('dtype, eps', [(np.complex64, 1e-6), (np.complex128, 1e-6)]) -def test_fftwithsize_fft2_ifft2(dtype, eps): - """ - Feature: fft2 & ifft2 function - Description: test cases for fft2 & ifft2 - Expectation: the result matches pytorch - """ - x = Tensor(np.array([[1.6243454+0.j, -0.6117564+0.j], [-0.5281718+0.j, -1.0729686+0.j]]).astype(dtype)) - expect = np.array([[-0.5885514+0.j, 2.7808986+0.j], [2.6137295+0.j, 1.6913052+0.j]]).astype(dtype) - error = np.ones(shape=[2, 2]) * eps - context.set_context(mode=context.GRAPH_MODE, device_target="CPU") - - output = fft_forward_func(x, 2, False, False) - diff = np.abs(output.asnumpy() - expect) - assert np.all(diff < error) - - output_ifft2 = fft_forward_func(output, 2, True, False) - diff_ifft2 = np.abs(output_ifft2.asnumpy() - x.asnumpy()) - assert np.all(diff_ifft2 < error) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', - essential_mark='essential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_fft_with_size_rfft3_forward(mode): - """ - Feature: rfft3 forward function - Description: test cases for rfft - Expectation: the result matches pytorch - """ - ms.context.set_context(mode=mode) - x = np.arange(1 * 2 * 3 * 4, dtype=np.float64).reshape(1, 2, 3, 4) - ms_x = ms.Tensor(x) - output = fft_forward_func(ms_x, 3, False, True) - expect = np.fft.rfftn(x, s=(2, 3, 4)) - assert np.allclose(output.asnumpy(), expect) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', - essential_mark='essential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_fft_with_size_irfft3_forward(mode): - """ - Feature: irfft3 forward function - Description: test cases for irfft3 - Expectation: the result matches pytorch - """ - ms.context.set_context(mode=mode) - x = np.arange(1 * 2 * 3 * 3, dtype=np.complex128).reshape(1, 2, 3, 3) - ms_x = ms.Tensor(x) - output = fft_forward_func(ms_x, 3, True, True) - expect = np.fft.irfftn(x, s=(2, 3, 4)) - assert np.allclose(output.asnumpy(), expect) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', - essential_mark='essential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_fft_with_size_rfft3_backward(mode): - """ - Feature: rfft3 backward function - Description: test cases for rfft3 - Expectation: the result matches pytorch - """ - ms.context.set_context(mode=mode) - dim1 = 1 - dim2 = 2 - dim3 = 3 - dim4 = 4 - offset_size = dim1 * dim2 * dim3 * dim4 - x = np.arange(offset_size, dtype=np.float64).reshape(dim1, dim2, dim3, dim4) - ms_x = ms.Tensor(x) - output = fft_backward_func(ms_x, 3, False, True) - dout = np.ones((dim1, dim2, dim3, dim4 // 2 + 1), dtype=np.complex128) - concat_array = np.zeros((dim1, dim2, dim3, dim4 - dim4 // 2 - 1)) - concat_array = concat_array.astype(np.complex128) - dout = np.concatenate((dout, concat_array), axis=-1) - expect = np.fft.ifftn(dout, s=(dim2, dim3, dim4)) * offset_size - assert np.allclose(output.asnumpy(), expect.real) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', - essential_mark='essential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_fft_with_size_rfft3_and_irfft3_backward(mode): - """ - Feature: rfft3_and_irfft3 function - Description: test cases for rfft3_and_irfft3 - Expectation: the result matches pytorch - """ - ms.context.set_context(mode=mode) - dim1 = 1 - dim2 = 2 - dim3 = 3 - dim4 = 4 - offset_size = dim1 * dim2 * dim3 * dim4 - x = np.arange(offset_size, dtype=np.float64).reshape(dim1, dim2, dim3, dim4) - ms_x = ms.Tensor(x) - output = rfft_and_irfft_backward_func(ms_x, 3, False, True) - expect = np.ones((dim1, dim2, dim3, dim4)) - assert np.allclose(output.asnumpy(), expect) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', - essential_mark='essential') -def test_fftwithsize_exception(): - """ - Feature: FFTWithSize op. - Description: Test FFTWithSize operator input when last dimension is 1. - Expectation: The result match to the expect value. - """ - signal_ndim = 2 - inverse = True - real = True - x = Tensor(np.random.uniform(-10, 10, size=[2, 1])).astype(ms.complex64) - with pytest.raises(ValueError, match="For 'FFTWithSize', the last dimension of the input cannot be 1"): - fft_forward_func(x, signal_ndim, inverse, real) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore as ms +import mindspore.context as context +from mindspore import Tensor, ops +from tests.st.utils import test_utils +from tests.mark_utils import arg_mark + + +@test_utils.run_with_cell +def fft_forward_func(x, signal_ndim, inverse, real, norm='backward', onesided=True, signal_sizes=()): + return ops.FFTWithSize(signal_ndim, inverse, real, norm, onesided, signal_sizes)(x) + + +@test_utils.run_with_cell +def rfft_and_irfft_forward_func(x, signal_ndim, inverse, real, norm='backward', onesided=True, signal_sizes=()): + x = ops.FFTWithSize(signal_ndim, inverse, real, norm, onesided, signal_sizes)(x) + return ops.FFTWithSize(signal_ndim, not inverse, real, norm, onesided, signal_sizes)(x) + + +@test_utils.run_with_cell +def fft_backward_func(x, signal_ndim, inverse, real, norm='backward', onesided=True, signal_sizes=()): + return ops.grad(fft_forward_func, (0,))(x, signal_ndim, inverse, real, norm, onesided, signal_sizes) + + +@test_utils.run_with_cell +def rfft_and_irfft_backward_func(x, signal_ndim, inverse, real, norm='backward', onesided=True, signal_sizes=()): + return ops.grad(rfft_and_irfft_forward_func, (0,))(x, signal_ndim, inverse, real, norm, onesided, signal_sizes) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', + essential_mark='essential') +@pytest.mark.parametrize('dtype, eps', [(np.complex64, 1e-6), (np.complex128, 1e-6)]) +def test_fftwithsize_fft_ifft(dtype, eps): + """ + Feature: fft & ifft function + Description: test cases for fft & ifft + Expectation: the result matches pytorch + """ + x = Tensor(np.array([1.6243454+0.j, -0.6117564+0.j, -0.5281718+0.j, -1.0729686+0.j]).astype(dtype)) + expect = np.array([-0.5885514+0.j, 2.1525173-0.46121222j, 2.7808986+0.j, 2.1525173+0.46121222j]).astype(dtype) + error = np.ones(shape=[4]) * eps + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + output = fft_forward_func(x, 1, False, False) + diff = np.abs(output.asnumpy() - expect) + assert np.all(diff < error) + + output_ifft = fft_forward_func(output, 1, True, False) + diff_ifft = np.abs(output_ifft.asnumpy() - x.asnumpy()) + assert np.all(diff_ifft < error) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', + essential_mark='essential') +@pytest.mark.parametrize('dtype, eps', [(np.complex64, 1e-6), (np.complex128, 1e-6)]) +def test_fftwithsize_fft2_ifft2(dtype, eps): + """ + Feature: fft2 & ifft2 function + Description: test cases for fft2 & ifft2 + Expectation: the result matches pytorch + """ + x = Tensor(np.array([[1.6243454+0.j, -0.6117564+0.j], [-0.5281718+0.j, -1.0729686+0.j]]).astype(dtype)) + expect = np.array([[-0.5885514+0.j, 2.7808986+0.j], [2.6137295+0.j, 1.6913052+0.j]]).astype(dtype) + error = np.ones(shape=[2, 2]) * eps + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + output = fft_forward_func(x, 2, False, False) + diff = np.abs(output.asnumpy() - expect) + assert np.all(diff < error) + + output_ifft2 = fft_forward_func(output, 2, True, False) + diff_ifft2 = np.abs(output_ifft2.asnumpy() - x.asnumpy()) + assert np.all(diff_ifft2 < error) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', + essential_mark='essential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_fft_with_size_rfft3_forward(mode): + """ + Feature: rfft3 forward function + Description: test cases for rfft + Expectation: the result matches pytorch + """ + ms.context.set_context(mode=mode) + x = np.arange(1 * 2 * 3 * 4, dtype=np.float64).reshape(1, 2, 3, 4) + ms_x = ms.Tensor(x) + output = fft_forward_func(ms_x, 3, False, True) + expect = np.fft.rfftn(x, s=(2, 3, 4)) + assert np.allclose(output.asnumpy(), expect) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', + essential_mark='essential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_fft_with_size_irfft3_forward(mode): + """ + Feature: irfft3 forward function + Description: test cases for irfft3 + Expectation: the result matches pytorch + """ + ms.context.set_context(mode=mode) + x = np.arange(1 * 2 * 3 * 3, dtype=np.complex128).reshape(1, 2, 3, 3) + ms_x = ms.Tensor(x) + output = fft_forward_func(ms_x, 3, True, True) + expect = np.fft.irfftn(x, s=(2, 3, 4)) + assert np.allclose(output.asnumpy(), expect) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', + essential_mark='essential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_fft_with_size_rfft3_backward(mode): + """ + Feature: rfft3 backward function + Description: test cases for rfft3 + Expectation: the result matches pytorch + """ + ms.context.set_context(mode=mode) + dim1 = 1 + dim2 = 2 + dim3 = 3 + dim4 = 4 + offset_size = dim1 * dim2 * dim3 * dim4 + x = np.arange(offset_size, dtype=np.float64).reshape(dim1, dim2, dim3, dim4) + ms_x = ms.Tensor(x) + output = fft_backward_func(ms_x, 3, False, True) + dout = np.ones((dim1, dim2, dim3, dim4 // 2 + 1), dtype=np.complex128) + concat_array = np.zeros((dim1, dim2, dim3, dim4 - dim4 // 2 - 1)) + concat_array = concat_array.astype(np.complex128) + dout = np.concatenate((dout, concat_array), axis=-1) + expect = np.fft.ifftn(dout, s=(dim2, dim3, dim4)) * offset_size + assert np.allclose(output.asnumpy(), expect.real) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', + essential_mark='essential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_fft_with_size_rfft3_and_irfft3_backward(mode): + """ + Feature: rfft3_and_irfft3 function + Description: test cases for rfft3_and_irfft3 + Expectation: the result matches pytorch + """ + ms.context.set_context(mode=mode) + dim1 = 1 + dim2 = 2 + dim3 = 3 + dim4 = 4 + offset_size = dim1 * dim2 * dim3 * dim4 + x = np.arange(offset_size, dtype=np.float64).reshape(dim1, dim2, dim3, dim4) + ms_x = ms.Tensor(x) + output = rfft_and_irfft_backward_func(ms_x, 3, False, True) + expect = np.ones((dim1, dim2, dim3, dim4)) + assert np.allclose(output.asnumpy(), expect) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', + essential_mark='essential') +def test_fftwithsize_exception(): + """ + Feature: FFTWithSize op. + Description: Test FFTWithSize operator input when last dimension is 1. + Expectation: The result match to the expect value. + """ + signal_ndim = 2 + inverse = True + real = True + x = Tensor(np.random.uniform(-10, 10, size=[2, 1])).astype(ms.complex64) + with pytest.raises(ValueError, match="For 'FFTWithSize', the last dimension of the input cannot be 1"): + fft_forward_func(x, signal_ndim, inverse, real) diff --git a/tests/st/ops/cpu/test_flatten_op.py b/tests/st/ops/cpu/test_flatten_op.py index 9ed09b55722..d6680ebb18c 100644 --- a/tests/st/ops/cpu/test_flatten_op.py +++ b/tests/st/ops/cpu/test_flatten_op.py @@ -1,96 +1,96 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor, ops - - -class FlattenNet(nn.Cell): - def __init__(self): - super().__init__() - self.flatten = ops.Flatten() - - def construct(self, x): - out = self.flatten(x) - return out - - -class FlattenFunc(nn.Cell): - def construct(self, x): - out = ops.flatten(x) - return out - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -@pytest.mark.parametrize("dtype", [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, - np.uint32, np.uint64, np.float16, np.float32, np.float64, - np.bool, np.complex64, np.complex128]) -def test_flatten_op_dtype(mode, dtype): - """ - Feature: cpu Flatten op. - Description: test flatten with the different types. - Expectation: success. - """ - context.set_context(mode=mode, device_target="CPU") - - x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]).astype(dtype)) - expect = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]).astype(dtype) - - net = FlattenNet() - out = net(x) - - assert np.allclose(expect, out.asnumpy()) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -def test_flatten_op_functional(mode): - """ - Feature: cpu Flatten op. - Description: test flatten with functional interface. - Expectation: success. - """ - context.set_context(mode=mode, device_target="CPU") - x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]).astype(np.float32)) - expect = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]).astype(np.float32) - - net = FlattenFunc() - out = net(x) - - assert np.allclose(expect, out.asnumpy()) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -def test_flatten_op_nn(mode): - """ - Feature: cpu Flatten ops. - Description: test flatten with nn interface. - Expectation: success. - """ - context.set_context(mode=mode, device_target="CPU") - x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]).astype(np.float32)) - expect = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]).astype(np.float32) - - net = nn.Flatten() - out = net(x) - - assert np.allclose(expect, out.asnumpy()) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor, ops + + +class FlattenNet(nn.Cell): + def __init__(self): + super().__init__() + self.flatten = ops.Flatten() + + def construct(self, x): + out = self.flatten(x) + return out + + +class FlattenFunc(nn.Cell): + def construct(self, x): + out = ops.flatten(x) + return out + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize("dtype", [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, + np.uint32, np.uint64, np.float16, np.float32, np.float64, + np.bool, np.complex64, np.complex128]) +def test_flatten_op_dtype(mode, dtype): + """ + Feature: cpu Flatten op. + Description: test flatten with the different types. + Expectation: success. + """ + context.set_context(mode=mode, device_target="CPU") + + x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]).astype(dtype)) + expect = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]).astype(dtype) + + net = FlattenNet() + out = net(x) + + assert np.allclose(expect, out.asnumpy()) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_flatten_op_functional(mode): + """ + Feature: cpu Flatten op. + Description: test flatten with functional interface. + Expectation: success. + """ + context.set_context(mode=mode, device_target="CPU") + x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]).astype(np.float32)) + expect = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]).astype(np.float32) + + net = FlattenFunc() + out = net(x) + + assert np.allclose(expect, out.asnumpy()) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_flatten_op_nn(mode): + """ + Feature: cpu Flatten ops. + Description: test flatten with nn interface. + Expectation: success. + """ + context.set_context(mode=mode, device_target="CPU") + x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]).astype(np.float32)) + expect = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]).astype(np.float32) + + net = nn.Flatten() + out = net(x) + + assert np.allclose(expect, out.asnumpy()) diff --git a/tests/st/ops/cpu/test_floordiv_op.py b/tests/st/ops/cpu/test_floordiv_op.py index d83b63e1d10..517f7ecc14c 100644 --- a/tests/st/ops/cpu/test_floordiv_op.py +++ b/tests/st/ops/cpu/test_floordiv_op.py @@ -1,71 +1,71 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import numpy as np -import pytest - -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.ops import operations as P -context.set_context(mode=context.GRAPH_MODE, device_target="CPU") - - -class NetFloorDiv(nn.Cell): - def __init__(self): - super(NetFloorDiv, self).__init__() - self.floordiv = P.FloorDiv() - - def construct(self, x, y): - return self.floordiv(x, y) - - -@pytest.mark.skip(reason="never run on ci or smoke test") -@pytest.mark.parametrize('dtype', [np.float16, np.float32, np.float64, np.int8, np.int16, np.int32, - np.int64, np.uint8, np.uint16, np.uint32, np.uint64]) -def testtype_floor_div_int_float(dtype): - """ - Feature: ALL To ALL - Description: test cases for FloorDiv - Expectation: the result match to numpy - """ - x_np = np.random.rand(1, 5).astype(dtype) - y_np = np.random.rand(1, 5).astype(dtype) - expect = np.floor_divide(x_np, y_np) - x_input = Tensor(x_np) - y_input = Tensor(y_np) - floor_div = NetFloorDiv() - output = floor_div(x_input, y_input) - assert np.allclose(output.asnumpy(), expect) - - -@pytest.mark.skip(reason="never run on ci or smoke test") -@pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) -def testtype_floor_div_complex(dtype): - """ - Feature: ALL To ALL - Description: test cases for FloorDiv - Expectation: the result match to numpy - """ - x_np = np.random.rand(1, 5).astype(dtype) - x_np = x_np + 0.5j * x_np - y_np = np.random.rand(1, 5).astype(dtype) - y_np = y_np + 0.4j * y_np - expect = np.floor_divide(x_np, y_np) - x_input = Tensor(x_np) - y_input = Tensor(y_np) - floor_div = NetFloorDiv() - output = floor_div(x_input, y_input) - assert np.allclose(output.asnumpy(), expect) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + +class NetFloorDiv(nn.Cell): + def __init__(self): + super(NetFloorDiv, self).__init__() + self.floordiv = P.FloorDiv() + + def construct(self, x, y): + return self.floordiv(x, y) + + +@pytest.mark.skip(reason="never run on ci or smoke test") +@pytest.mark.parametrize('dtype', [np.float16, np.float32, np.float64, np.int8, np.int16, np.int32, + np.int64, np.uint8, np.uint16, np.uint32, np.uint64]) +def testtype_floor_div_int_float(dtype): + """ + Feature: ALL To ALL + Description: test cases for FloorDiv + Expectation: the result match to numpy + """ + x_np = np.random.rand(1, 5).astype(dtype) + y_np = np.random.rand(1, 5).astype(dtype) + expect = np.floor_divide(x_np, y_np) + x_input = Tensor(x_np) + y_input = Tensor(y_np) + floor_div = NetFloorDiv() + output = floor_div(x_input, y_input) + assert np.allclose(output.asnumpy(), expect) + + +@pytest.mark.skip(reason="never run on ci or smoke test") +@pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) +def testtype_floor_div_complex(dtype): + """ + Feature: ALL To ALL + Description: test cases for FloorDiv + Expectation: the result match to numpy + """ + x_np = np.random.rand(1, 5).astype(dtype) + x_np = x_np + 0.5j * x_np + y_np = np.random.rand(1, 5).astype(dtype) + y_np = y_np + 0.4j * y_np + expect = np.floor_divide(x_np, y_np) + x_input = Tensor(x_np) + y_input = Tensor(y_np) + floor_div = NetFloorDiv() + output = floor_div(x_input, y_input) + assert np.allclose(output.asnumpy(), expect) diff --git a/tests/st/ops/cpu/test_fractionalavgpool_op.py b/tests/st/ops/cpu/test_fractionalavgpool_op.py index fcfda947149..9920a899959 100644 --- a/tests/st/ops/cpu/test_fractionalavgpool_op.py +++ b/tests/st/ops/cpu/test_fractionalavgpool_op.py @@ -1,370 +1,370 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest -import mindspore.nn as nn -import mindspore.context as context -from mindspore import Tensor -import mindspore.ops.operations.nn_ops as ops -import mindspore.ops.operations._grad_ops as grad_ops - - -class NetFractionalAvgPool(nn.Cell): - def __init__(self): - super(NetFractionalAvgPool, self).__init__() - self.fractional_avg_pool = ops.FractionalAvgPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0]) - - def construct(self, x): - return self.fractional_avg_pool(x) - - -class NetFractionalAvgPoolRealRandom(nn.Cell): - def __init__(self): - super(NetFractionalAvgPoolRealRandom, self).__init__() - self.fractional_avg_pool = ops.FractionalAvgPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0], deterministic=True, - pseudo_random=False, seed=5454, seed2=144) - - def construct(self, x): - return self.fractional_avg_pool(x) - - -class NetFractionalAvgPoolOverlapPing(nn.Cell): - def __init__(self): - super(NetFractionalAvgPoolOverlapPing, self).__init__() - self.fractional_avg_pool = ops.FractionalAvgPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0], overlapping=True) - - def construct(self, x): - return self.fractional_avg_pool(x) - - -class NetFractionalAvgPoolGrad(nn.Cell): - def __init__(self): - super(NetFractionalAvgPoolGrad, self).__init__() - self.fractional_avg_pool_grad = grad_ops.FractionalAvgPoolGrad() - - def construct(self, orig_input, out_backprop, row_pooling_sequence, col_pooling_sequence): - return self.fractional_avg_pool_grad(orig_input, out_backprop, row_pooling_sequence, - col_pooling_sequence) - - -class NetFractionalAvgPoolGradOverlapping(nn.Cell): - def __init__(self): - super(NetFractionalAvgPoolGradOverlapping, self).__init__() - self.fractional_avg_pool_grad = grad_ops.FractionalAvgPoolGrad(overlapping=True) - - def construct(self, orig_input, out_backprop, row_pooling_sequence, col_pooling_sequence): - return self.fractional_avg_pool_grad(orig_input, out_backprop, row_pooling_sequence, - col_pooling_sequence) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') -def test_fractionalavgpool_graph(): - """ - Feature: FractionalAvgPool - Description: Test of input - Expectation: The results are as expected - """ - context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - types = [np.float32, np.float64, np.int32, np.int64] - for type_i in types: - x = Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, - 10, 11, 12, 13, 14, 15, 16]).reshape([1, 4, 4, 1]).astype(type_i)) - net = NetFractionalAvgPool() - output = net(x) - output_y = output[0].asnumpy() - output_row_pooling_sequence = output[1].asnumpy() - output_col_pooling_sequence = output[2].asnumpy() - expect_output_y = np.array([[[[3.5], [5.5]], [[11.5], [13.5]]]]).astype(type_i) - expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - assert np.allclose(output_y, expect_output_y) - assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) - assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) - - net = NetFractionalAvgPoolRealRandom() - output = net(x) - type0 = output[0].asnumpy().dtype - assert type0 == type_i - - net = NetFractionalAvgPoolOverlapPing() - output = net(x) - output_y = output[0].asnumpy() - output_row_pooling_sequence = output[1].asnumpy() - output_col_pooling_sequence = output[2].asnumpy() - expect_output_y = np.array([[[[6], [7.5]], [[12], [13.5]]]]).astype(type_i) - expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - assert np.allclose(output_y, expect_output_y) - assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) - assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) - - netgrad = NetFractionalAvgPoolGrad() - x_shape = Tensor(np.array([1, 4, 4, 1]).astype(np.int64)) - out_backprop = Tensor(np.ones([1, 2, 2, 1]).astype(type_i)) - output_grad = netgrad(x_shape, out_backprop, output[1], output[2]) - output_grad_y = output_grad[0].asnumpy() - expect_output_grad_y = np.array([[[[0.25], [0.25], [0.25], [0.25]], - [[0.25], [0.25], [0.25], [0.25]], - [[0.25], [0.25], [0.25], [0.25]], - [[0.25], [0.25], [0.25], [0.25]]]]).astype(type_i) - assert np.allclose(output_grad_y, expect_output_grad_y) - - netgrad = NetFractionalAvgPoolGradOverlapping() - out_backprop = Tensor(np.ones([1, 2, 2, 1]).astype(type_i)) - output_grad = netgrad(x_shape, out_backprop, output[1], output[2]) - output_grad_y = output_grad[0].asnumpy() - expect_output_grad_y = np.array([[[[0.11111111], [0.11111111], [0.2777778], [0.16666667]], - [[0.11111111], [0.11111111], [0.2777778], [0.16666667]], - [[0.2777778], [0.2777778], [0.6944444], [0.41666666]], - [[0.16666667], [0.16666667], [0.41666666], [0.25]]]]).astype(type_i) - assert np.allclose(output_grad_y, expect_output_grad_y) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') -def test_fractionalavgpool_pynative(): - """ - Feature: FractionalAvgPool - Description: Test of input - Expectation: The results are as expected - """ - context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU') - types = [np.float32, np.float64, np.int32, np.int64] - for type_i in types: - x = Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, - 10, 11, 12, 13, 14, 15, 16]).reshape([1, 4, 4, 1]).astype(type_i)) - fractionalavgpool = ops.FractionalAvgPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0]) - output = fractionalavgpool(x) - output_y = output[0].asnumpy() - output_row_pooling_sequence = output[1].asnumpy() - output_col_pooling_sequence = output[2].asnumpy() - expect_output_y = np.array([[[[3.5], [5.5]], [[11.5], [13.5]]]]).astype(type_i) - expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - assert np.allclose(output_y, expect_output_y) - assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) - assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) - - fractionalavgpool = ops.FractionalAvgPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0], - deterministic=True, pseudo_random=False, seed=5454, seed2=144) - output = fractionalavgpool(x) - type0 = output[0].asnumpy().dtype - assert type0 == type_i - - fractionalavgpool = ops.FractionalAvgPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0], overlapping=True) - output = fractionalavgpool(x) - output_y = output[0].asnumpy() - output_row_pooling_sequence = output[1].asnumpy() - output_col_pooling_sequence = output[2].asnumpy() - expect_output_y = np.array([[[[6], [7.5]], [[12], [13.5]]]]).astype(type_i) - expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - assert np.allclose(output_y, expect_output_y) - assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) - assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) - - fractionalavgpoolgrad = grad_ops.FractionalAvgPoolGrad() - x_shape = Tensor(np.array([1, 4, 4, 1]).astype(np.int64)) - out_backprop = Tensor(np.ones([1, 2, 2, 1]).astype(type_i)) - output_grad = fractionalavgpoolgrad(x_shape, out_backprop, output[1], output[2]) - expect_output_grad_y = np.array([[[[0.25], [0.25], [0.25], [0.25]], - [[0.25], [0.25], [0.25], [0.25]], - [[0.25], [0.25], [0.25], [0.25]], - [[0.25], [0.25], [0.25], [0.25]]]]).astype(type_i) - assert np.allclose(output_grad.asnumpy(), expect_output_grad_y) - - fractionalavgpoolgrad = grad_ops.FractionalAvgPoolGrad(overlapping=True) - out_backprop = Tensor(np.ones([1, 2, 2, 1]).astype(type_i)) - output_grad = fractionalavgpoolgrad(x_shape, out_backprop, output[1], output[2]) - expect_output_grad_y = np.array([[[[0.11111111], [0.11111111], [0.2777778], [0.16666667]], - [[0.11111111], [0.11111111], [0.2777778], [0.16666667]], - [[0.2777778], [0.2777778], [0.6944444], [0.41666666]], - [[0.16666667], [0.16666667], [0.41666666], [0.25]]]]).astype(type_i) - assert np.allclose(output_grad.asnumpy(), expect_output_grad_y) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') -def test_fractionalavgpool_pynative_dynamic(): - """ - Feature: FractionalAvgPool - Description: Test of input - Expectation: The results are as expected - """ - context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU') - types = [np.float32, np.float64, np.int32, np.int64] - for type_i in types: - x = Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, - 10, 11, 12, 13, 14, 15, 16]).reshape([1, 4, 4, 1]).astype(type_i)) - net = NetFractionalAvgPool() - dy_shape = [None for _ in x.shape] - input_dyn = Tensor(shape=dy_shape, dtype=x.dtype) - net.set_inputs(input_dyn) - output = net(x) - output_y = output[0].asnumpy() - output_row_pooling_sequence = output[1].asnumpy() - output_col_pooling_sequence = output[2].asnumpy() - expect_output_y = np.array([[[[3.5], [5.5]], [[11.5], [13.5]]]]).astype(type_i) - expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - assert np.allclose(output_y, expect_output_y) - assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) - assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') -def test_fractionalavgpoolgrad_graph_dynamic(): - """ - Feature: FractionalAvgPool - Description: Test of input - Expectation: The results are as expected - """ - context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - types = [np.float32, np.float64, np.int32, np.int64] - for type_i in types: - x = Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, - 10, 11, 12, 13, 14, 15, 16]).reshape([1, 4, 4, 1]).astype(type_i)) - net = NetFractionalAvgPool() - output = net(x) - output_y = output[0].asnumpy() - output_row_pooling_sequence = output[1].asnumpy() - output_col_pooling_sequence = output[2].asnumpy() - expect_output_y = np.array([[[[3.5], [5.5]], [[11.5], [13.5]]]]).astype(type_i) - expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - assert np.allclose(output_y, expect_output_y) - assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) - assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) - - net = NetFractionalAvgPoolRealRandom() - output = net(x) - type0 = output[0].asnumpy().dtype - assert type0 == type_i - - net = NetFractionalAvgPoolOverlapPing() - output = net(x) - output_y = output[0].asnumpy() - output_row_pooling_sequence = output[1].asnumpy() - output_col_pooling_sequence = output[2].asnumpy() - expect_output_y = np.array([[[[6], [7.5]], [[12], [13.5]]]]).astype(type_i) - expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - assert np.allclose(output_y, expect_output_y) - assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) - assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) - - netgrad = NetFractionalAvgPoolGrad() - x_shape = Tensor(np.array([1, 4, 4, 1]).astype(np.int64)) - out_backprop = Tensor(np.ones([1, 2, 2, 1]).astype(type_i)) - - set_input_dyn = [] - dy_shape_0 = [None for _ in x_shape.shape] - input_dyn_0 = Tensor(shape=dy_shape_0, dtype=x_shape.dtype) - set_input_dyn.append(input_dyn_0) - - dy_shape_1 = [None for _ in out_backprop.shape] - input_dyn_1 = Tensor(shape=dy_shape_1, dtype=out_backprop.dtype) - set_input_dyn.append(input_dyn_1) - - dy_shape_2 = [None for _ in output[1].shape] - input_dyn_2 = Tensor(shape=dy_shape_2, dtype=output[1].dtype) - set_input_dyn.append(input_dyn_2) - - dy_shape_3 = [None for _ in output[2].shape] - input_dyn_3 = Tensor(shape=dy_shape_3, dtype=output[2].dtype) - set_input_dyn.append(input_dyn_3) - - netgrad.set_inputs(*set_input_dyn) - - output_grad = netgrad(x_shape, out_backprop, output[1], output[2]) - output_grad_y = output_grad[0].asnumpy() - expect_output_grad_y = np.array([[[[0.25], [0.25], [0.25], [0.25]], - [[0.25], [0.25], [0.25], [0.25]], - [[0.25], [0.25], [0.25], [0.25]], - [[0.25], [0.25], [0.25], [0.25]]]]).astype(type_i) - assert np.allclose(output_grad_y, expect_output_grad_y) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') -def test_fractionalavgpoolgrad_pynative_dynamic(): - """ - Feature: FractionalAvgPool - Description: Test of input - Expectation: The results are as expected - """ - context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU') - types = [np.float32, np.float64, np.int32, np.int64] - for type_i in types: - x = Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, - 10, 11, 12, 13, 14, 15, 16]).reshape([1, 4, 4, 1]).astype(type_i)) - net = NetFractionalAvgPool() - output = net(x) - output_y = output[0].asnumpy() - output_row_pooling_sequence = output[1].asnumpy() - output_col_pooling_sequence = output[2].asnumpy() - expect_output_y = np.array([[[[3.5], [5.5]], [[11.5], [13.5]]]]).astype(type_i) - expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - assert np.allclose(output_y, expect_output_y) - assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) - assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) - - net = NetFractionalAvgPoolRealRandom() - output = net(x) - type0 = output[0].asnumpy().dtype - assert type0 == type_i - - net = NetFractionalAvgPoolOverlapPing() - output = net(x) - output_y = output[0].asnumpy() - output_row_pooling_sequence = output[1].asnumpy() - output_col_pooling_sequence = output[2].asnumpy() - expect_output_y = np.array([[[[6], [7.5]], [[12], [13.5]]]]).astype(type_i) - expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - assert np.allclose(output_y, expect_output_y) - assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) - assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) - - netgrad = NetFractionalAvgPoolGrad() - x_shape = Tensor(np.array([1, 4, 4, 1]).astype(np.int64)) - out_backprop = Tensor(np.ones([1, 2, 2, 1]).astype(type_i)) - - set_input_dyn = [] - dy_shape_0 = [None for _ in x_shape.shape] - input_dyn_0 = Tensor(shape=dy_shape_0, dtype=x_shape.dtype) - set_input_dyn.append(input_dyn_0) - - dy_shape_1 = [None for _ in out_backprop.shape] - input_dyn_1 = Tensor(shape=dy_shape_1, dtype=out_backprop.dtype) - set_input_dyn.append(input_dyn_1) - - dy_shape_2 = [None for _ in output[1].shape] - input_dyn_2 = Tensor(shape=dy_shape_2, dtype=output[1].dtype) - set_input_dyn.append(input_dyn_2) - - dy_shape_3 = [None for _ in output[2].shape] - input_dyn_3 = Tensor(shape=dy_shape_3, dtype=output[2].dtype) - set_input_dyn.append(input_dyn_3) - - netgrad.set_inputs(*set_input_dyn) - - output_grad = netgrad(x_shape, out_backprop, output[1], output[2]) - output_grad_y = output_grad[0].asnumpy() - expect_output_grad_y = np.array([[[[0.25], [0.25], [0.25], [0.25]], - [[0.25], [0.25], [0.25], [0.25]], - [[0.25], [0.25], [0.25], [0.25]], - [[0.25], [0.25], [0.25], [0.25]]]]).astype(type_i) - assert np.allclose(output_grad_y, expect_output_grad_y) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest +import mindspore.nn as nn +import mindspore.context as context +from mindspore import Tensor +import mindspore.ops.operations.nn_ops as ops +import mindspore.ops.operations._grad_ops as grad_ops + + +class NetFractionalAvgPool(nn.Cell): + def __init__(self): + super(NetFractionalAvgPool, self).__init__() + self.fractional_avg_pool = ops.FractionalAvgPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0]) + + def construct(self, x): + return self.fractional_avg_pool(x) + + +class NetFractionalAvgPoolRealRandom(nn.Cell): + def __init__(self): + super(NetFractionalAvgPoolRealRandom, self).__init__() + self.fractional_avg_pool = ops.FractionalAvgPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0], deterministic=True, + pseudo_random=False, seed=5454, seed2=144) + + def construct(self, x): + return self.fractional_avg_pool(x) + + +class NetFractionalAvgPoolOverlapPing(nn.Cell): + def __init__(self): + super(NetFractionalAvgPoolOverlapPing, self).__init__() + self.fractional_avg_pool = ops.FractionalAvgPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0], overlapping=True) + + def construct(self, x): + return self.fractional_avg_pool(x) + + +class NetFractionalAvgPoolGrad(nn.Cell): + def __init__(self): + super(NetFractionalAvgPoolGrad, self).__init__() + self.fractional_avg_pool_grad = grad_ops.FractionalAvgPoolGrad() + + def construct(self, orig_input, out_backprop, row_pooling_sequence, col_pooling_sequence): + return self.fractional_avg_pool_grad(orig_input, out_backprop, row_pooling_sequence, + col_pooling_sequence) + + +class NetFractionalAvgPoolGradOverlapping(nn.Cell): + def __init__(self): + super(NetFractionalAvgPoolGradOverlapping, self).__init__() + self.fractional_avg_pool_grad = grad_ops.FractionalAvgPoolGrad(overlapping=True) + + def construct(self, orig_input, out_backprop, row_pooling_sequence, col_pooling_sequence): + return self.fractional_avg_pool_grad(orig_input, out_backprop, row_pooling_sequence, + col_pooling_sequence) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') +def test_fractionalavgpool_graph(): + """ + Feature: FractionalAvgPool + Description: Test of input + Expectation: The results are as expected + """ + context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + types = [np.float32, np.float64, np.int32, np.int64] + for type_i in types: + x = Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16]).reshape([1, 4, 4, 1]).astype(type_i)) + net = NetFractionalAvgPool() + output = net(x) + output_y = output[0].asnumpy() + output_row_pooling_sequence = output[1].asnumpy() + output_col_pooling_sequence = output[2].asnumpy() + expect_output_y = np.array([[[[3.5], [5.5]], [[11.5], [13.5]]]]).astype(type_i) + expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + assert np.allclose(output_y, expect_output_y) + assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) + assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) + + net = NetFractionalAvgPoolRealRandom() + output = net(x) + type0 = output[0].asnumpy().dtype + assert type0 == type_i + + net = NetFractionalAvgPoolOverlapPing() + output = net(x) + output_y = output[0].asnumpy() + output_row_pooling_sequence = output[1].asnumpy() + output_col_pooling_sequence = output[2].asnumpy() + expect_output_y = np.array([[[[6], [7.5]], [[12], [13.5]]]]).astype(type_i) + expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + assert np.allclose(output_y, expect_output_y) + assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) + assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) + + netgrad = NetFractionalAvgPoolGrad() + x_shape = Tensor(np.array([1, 4, 4, 1]).astype(np.int64)) + out_backprop = Tensor(np.ones([1, 2, 2, 1]).astype(type_i)) + output_grad = netgrad(x_shape, out_backprop, output[1], output[2]) + output_grad_y = output_grad[0].asnumpy() + expect_output_grad_y = np.array([[[[0.25], [0.25], [0.25], [0.25]], + [[0.25], [0.25], [0.25], [0.25]], + [[0.25], [0.25], [0.25], [0.25]], + [[0.25], [0.25], [0.25], [0.25]]]]).astype(type_i) + assert np.allclose(output_grad_y, expect_output_grad_y) + + netgrad = NetFractionalAvgPoolGradOverlapping() + out_backprop = Tensor(np.ones([1, 2, 2, 1]).astype(type_i)) + output_grad = netgrad(x_shape, out_backprop, output[1], output[2]) + output_grad_y = output_grad[0].asnumpy() + expect_output_grad_y = np.array([[[[0.11111111], [0.11111111], [0.2777778], [0.16666667]], + [[0.11111111], [0.11111111], [0.2777778], [0.16666667]], + [[0.2777778], [0.2777778], [0.6944444], [0.41666666]], + [[0.16666667], [0.16666667], [0.41666666], [0.25]]]]).astype(type_i) + assert np.allclose(output_grad_y, expect_output_grad_y) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') +def test_fractionalavgpool_pynative(): + """ + Feature: FractionalAvgPool + Description: Test of input + Expectation: The results are as expected + """ + context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU') + types = [np.float32, np.float64, np.int32, np.int64] + for type_i in types: + x = Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16]).reshape([1, 4, 4, 1]).astype(type_i)) + fractionalavgpool = ops.FractionalAvgPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0]) + output = fractionalavgpool(x) + output_y = output[0].asnumpy() + output_row_pooling_sequence = output[1].asnumpy() + output_col_pooling_sequence = output[2].asnumpy() + expect_output_y = np.array([[[[3.5], [5.5]], [[11.5], [13.5]]]]).astype(type_i) + expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + assert np.allclose(output_y, expect_output_y) + assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) + assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) + + fractionalavgpool = ops.FractionalAvgPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0], + deterministic=True, pseudo_random=False, seed=5454, seed2=144) + output = fractionalavgpool(x) + type0 = output[0].asnumpy().dtype + assert type0 == type_i + + fractionalavgpool = ops.FractionalAvgPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0], overlapping=True) + output = fractionalavgpool(x) + output_y = output[0].asnumpy() + output_row_pooling_sequence = output[1].asnumpy() + output_col_pooling_sequence = output[2].asnumpy() + expect_output_y = np.array([[[[6], [7.5]], [[12], [13.5]]]]).astype(type_i) + expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + assert np.allclose(output_y, expect_output_y) + assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) + assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) + + fractionalavgpoolgrad = grad_ops.FractionalAvgPoolGrad() + x_shape = Tensor(np.array([1, 4, 4, 1]).astype(np.int64)) + out_backprop = Tensor(np.ones([1, 2, 2, 1]).astype(type_i)) + output_grad = fractionalavgpoolgrad(x_shape, out_backprop, output[1], output[2]) + expect_output_grad_y = np.array([[[[0.25], [0.25], [0.25], [0.25]], + [[0.25], [0.25], [0.25], [0.25]], + [[0.25], [0.25], [0.25], [0.25]], + [[0.25], [0.25], [0.25], [0.25]]]]).astype(type_i) + assert np.allclose(output_grad.asnumpy(), expect_output_grad_y) + + fractionalavgpoolgrad = grad_ops.FractionalAvgPoolGrad(overlapping=True) + out_backprop = Tensor(np.ones([1, 2, 2, 1]).astype(type_i)) + output_grad = fractionalavgpoolgrad(x_shape, out_backprop, output[1], output[2]) + expect_output_grad_y = np.array([[[[0.11111111], [0.11111111], [0.2777778], [0.16666667]], + [[0.11111111], [0.11111111], [0.2777778], [0.16666667]], + [[0.2777778], [0.2777778], [0.6944444], [0.41666666]], + [[0.16666667], [0.16666667], [0.41666666], [0.25]]]]).astype(type_i) + assert np.allclose(output_grad.asnumpy(), expect_output_grad_y) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') +def test_fractionalavgpool_pynative_dynamic(): + """ + Feature: FractionalAvgPool + Description: Test of input + Expectation: The results are as expected + """ + context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU') + types = [np.float32, np.float64, np.int32, np.int64] + for type_i in types: + x = Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16]).reshape([1, 4, 4, 1]).astype(type_i)) + net = NetFractionalAvgPool() + dy_shape = [None for _ in x.shape] + input_dyn = Tensor(shape=dy_shape, dtype=x.dtype) + net.set_inputs(input_dyn) + output = net(x) + output_y = output[0].asnumpy() + output_row_pooling_sequence = output[1].asnumpy() + output_col_pooling_sequence = output[2].asnumpy() + expect_output_y = np.array([[[[3.5], [5.5]], [[11.5], [13.5]]]]).astype(type_i) + expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + assert np.allclose(output_y, expect_output_y) + assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) + assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') +def test_fractionalavgpoolgrad_graph_dynamic(): + """ + Feature: FractionalAvgPool + Description: Test of input + Expectation: The results are as expected + """ + context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + types = [np.float32, np.float64, np.int32, np.int64] + for type_i in types: + x = Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16]).reshape([1, 4, 4, 1]).astype(type_i)) + net = NetFractionalAvgPool() + output = net(x) + output_y = output[0].asnumpy() + output_row_pooling_sequence = output[1].asnumpy() + output_col_pooling_sequence = output[2].asnumpy() + expect_output_y = np.array([[[[3.5], [5.5]], [[11.5], [13.5]]]]).astype(type_i) + expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + assert np.allclose(output_y, expect_output_y) + assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) + assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) + + net = NetFractionalAvgPoolRealRandom() + output = net(x) + type0 = output[0].asnumpy().dtype + assert type0 == type_i + + net = NetFractionalAvgPoolOverlapPing() + output = net(x) + output_y = output[0].asnumpy() + output_row_pooling_sequence = output[1].asnumpy() + output_col_pooling_sequence = output[2].asnumpy() + expect_output_y = np.array([[[[6], [7.5]], [[12], [13.5]]]]).astype(type_i) + expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + assert np.allclose(output_y, expect_output_y) + assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) + assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) + + netgrad = NetFractionalAvgPoolGrad() + x_shape = Tensor(np.array([1, 4, 4, 1]).astype(np.int64)) + out_backprop = Tensor(np.ones([1, 2, 2, 1]).astype(type_i)) + + set_input_dyn = [] + dy_shape_0 = [None for _ in x_shape.shape] + input_dyn_0 = Tensor(shape=dy_shape_0, dtype=x_shape.dtype) + set_input_dyn.append(input_dyn_0) + + dy_shape_1 = [None for _ in out_backprop.shape] + input_dyn_1 = Tensor(shape=dy_shape_1, dtype=out_backprop.dtype) + set_input_dyn.append(input_dyn_1) + + dy_shape_2 = [None for _ in output[1].shape] + input_dyn_2 = Tensor(shape=dy_shape_2, dtype=output[1].dtype) + set_input_dyn.append(input_dyn_2) + + dy_shape_3 = [None for _ in output[2].shape] + input_dyn_3 = Tensor(shape=dy_shape_3, dtype=output[2].dtype) + set_input_dyn.append(input_dyn_3) + + netgrad.set_inputs(*set_input_dyn) + + output_grad = netgrad(x_shape, out_backprop, output[1], output[2]) + output_grad_y = output_grad[0].asnumpy() + expect_output_grad_y = np.array([[[[0.25], [0.25], [0.25], [0.25]], + [[0.25], [0.25], [0.25], [0.25]], + [[0.25], [0.25], [0.25], [0.25]], + [[0.25], [0.25], [0.25], [0.25]]]]).astype(type_i) + assert np.allclose(output_grad_y, expect_output_grad_y) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') +def test_fractionalavgpoolgrad_pynative_dynamic(): + """ + Feature: FractionalAvgPool + Description: Test of input + Expectation: The results are as expected + """ + context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU') + types = [np.float32, np.float64, np.int32, np.int64] + for type_i in types: + x = Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16]).reshape([1, 4, 4, 1]).astype(type_i)) + net = NetFractionalAvgPool() + output = net(x) + output_y = output[0].asnumpy() + output_row_pooling_sequence = output[1].asnumpy() + output_col_pooling_sequence = output[2].asnumpy() + expect_output_y = np.array([[[[3.5], [5.5]], [[11.5], [13.5]]]]).astype(type_i) + expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + assert np.allclose(output_y, expect_output_y) + assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) + assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) + + net = NetFractionalAvgPoolRealRandom() + output = net(x) + type0 = output[0].asnumpy().dtype + assert type0 == type_i + + net = NetFractionalAvgPoolOverlapPing() + output = net(x) + output_y = output[0].asnumpy() + output_row_pooling_sequence = output[1].asnumpy() + output_col_pooling_sequence = output[2].asnumpy() + expect_output_y = np.array([[[[6], [7.5]], [[12], [13.5]]]]).astype(type_i) + expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + assert np.allclose(output_y, expect_output_y) + assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) + assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) + + netgrad = NetFractionalAvgPoolGrad() + x_shape = Tensor(np.array([1, 4, 4, 1]).astype(np.int64)) + out_backprop = Tensor(np.ones([1, 2, 2, 1]).astype(type_i)) + + set_input_dyn = [] + dy_shape_0 = [None for _ in x_shape.shape] + input_dyn_0 = Tensor(shape=dy_shape_0, dtype=x_shape.dtype) + set_input_dyn.append(input_dyn_0) + + dy_shape_1 = [None for _ in out_backprop.shape] + input_dyn_1 = Tensor(shape=dy_shape_1, dtype=out_backprop.dtype) + set_input_dyn.append(input_dyn_1) + + dy_shape_2 = [None for _ in output[1].shape] + input_dyn_2 = Tensor(shape=dy_shape_2, dtype=output[1].dtype) + set_input_dyn.append(input_dyn_2) + + dy_shape_3 = [None for _ in output[2].shape] + input_dyn_3 = Tensor(shape=dy_shape_3, dtype=output[2].dtype) + set_input_dyn.append(input_dyn_3) + + netgrad.set_inputs(*set_input_dyn) + + output_grad = netgrad(x_shape, out_backprop, output[1], output[2]) + output_grad_y = output_grad[0].asnumpy() + expect_output_grad_y = np.array([[[[0.25], [0.25], [0.25], [0.25]], + [[0.25], [0.25], [0.25], [0.25]], + [[0.25], [0.25], [0.25], [0.25]], + [[0.25], [0.25], [0.25], [0.25]]]]).astype(type_i) + assert np.allclose(output_grad_y, expect_output_grad_y) diff --git a/tests/st/ops/cpu/test_fractionalmaxpool_op.py b/tests/st/ops/cpu/test_fractionalmaxpool_op.py index a0f49a4eff0..ab63041744a 100644 --- a/tests/st/ops/cpu/test_fractionalmaxpool_op.py +++ b/tests/st/ops/cpu/test_fractionalmaxpool_op.py @@ -1,242 +1,242 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest -import mindspore.nn as nn -import mindspore.context as context -from mindspore import Tensor -import mindspore.ops.operations.nn_ops as ops -import mindspore.ops.operations._grad_ops as grad_ops - - -class NetFractionalMaxPool(nn.Cell): - def __init__(self): - super(NetFractionalMaxPool, self).__init__() - self.fractional_max_pool = ops.FractionalMaxPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0]) - - def construct(self, x): - return self.fractional_max_pool(x) - - -class NetFractionalMaxPoolRealRandom(nn.Cell): - def __init__(self): - super(NetFractionalMaxPoolRealRandom, self).__init__() - self.fractional_max_pool = ops.FractionalMaxPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0], deterministic=True, - pseudo_random=False, seed=5454, seed2=144) - - def construct(self, x): - return self.fractional_max_pool(x) - - -class NetFractionalMaxPoolOverlapPing(nn.Cell): - def __init__(self): - super(NetFractionalMaxPoolOverlapPing, self).__init__() - self.fractional_max_pool = ops.FractionalMaxPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0], overlapping=True) - - def construct(self, x): - return self.fractional_max_pool(x) - - -class NetFractionalMaxPoolGrad(nn.Cell): - def __init__(self): - super(NetFractionalMaxPoolGrad, self).__init__() - self.fractional_max_pool_grad = grad_ops.FractionalMaxPoolGrad() - - def construct(self, orig_input, orig_output, out_backprop, row_pooling_sequence, col_pooling_sequence): - return self.fractional_max_pool_grad(orig_input, orig_output, out_backprop, row_pooling_sequence, - col_pooling_sequence) - - -class NetFractionalMaxPoolGradOverlapping(nn.Cell): - def __init__(self): - super(NetFractionalMaxPoolGradOverlapping, self).__init__() - self.fractional_max_pool_grad = grad_ops.FractionalMaxPoolGrad(overlapping=True) - - def construct(self, orig_input, orig_output, out_backprop, row_pooling_sequence, col_pooling_sequence): - return self.fractional_max_pool_grad(orig_input, orig_output, out_backprop, row_pooling_sequence, - col_pooling_sequence) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') -def test_fractionalmaxpool_graph(): - """ - Feature: FractionalMaxPool - Description: Test of input - Expectation: The results are as expected - """ - context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - types = [np.float32, np.float64, np.int32, np.int64] - for type_i in types: - x = Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, - 10, 11, 12, 13, 14, 15, 16]).reshape([1, 4, 4, 1]).astype(type_i)) - net = NetFractionalMaxPool() - output = net(x) - output_y = output[0].asnumpy() - output_row_pooling_sequence = output[1].asnumpy() - output_col_pooling_sequence = output[2].asnumpy() - expect_output_y = np.array([[[[6], [8]], [[14], [16]]]]).astype(type_i) - expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - assert np.allclose(output_y, expect_output_y) - assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) - assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) - - net = NetFractionalMaxPoolRealRandom() - output = net(x) - type0 = output[0].asnumpy().dtype - assert type0 == type_i - - net = NetFractionalMaxPoolOverlapPing() - output = net(x) - output_y = output[0].asnumpy() - output_row_pooling_sequence = output[1].asnumpy() - output_col_pooling_sequence = output[2].asnumpy() - expect_output_y = np.array([[[[11], [12]], [[15], [16]]]]).astype(type_i) - expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - assert np.allclose(output_y, expect_output_y) - assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) - assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) - - netgrad = NetFractionalMaxPoolGrad() - out_backprop = Tensor(np.ones([1, 2, 2, 1]).astype(type_i)) - output_grad = netgrad(x, output[0], out_backprop, output[1], output[2]) - output_grad_y = output_grad[0].asnumpy() - expect_output_grad_y = np.array([[[[0], [0], [0], [0]], [[0], [1], [0], [1]], - [[0], [0], [0], [0]], [[0], [1], [0], [1]]]]).astype(type_i) - assert np.allclose(output_grad_y, expect_output_grad_y) - - netgrad = NetFractionalMaxPoolGradOverlapping() - out_backprop = Tensor(np.ones([1, 2, 2, 1]).astype(type_i)) - output_grad = netgrad(x, output[0], out_backprop, output[1], output[2]) - output_grad_y = output_grad[0].asnumpy() - expect_output_grad_y = np.array([[[[0], [0], [0], [0]], [[0], [0], [0], [0]], - [[0], [0], [1], [1]], [[0], [0], [1], [1]]]]).astype(type_i) - assert np.allclose(output_grad_y, expect_output_grad_y) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') -def test_fractionalmaxpool_pynative_dynamic(): - """ - Feature: FractionalMaxPool - Description: Test of input - Expectation: The results are as expected - """ - context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU') - types = [np.float32, np.float64, np.int32, np.int64] - for type_i in types: - x = Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, - 10, 11, 12, 13, 14, 15, 16]).reshape([1, 4, 4, 1]).astype(type_i)) - # case1 - net = NetFractionalMaxPool() - dy_shape = [None for _ in x.shape] - input_dyn = Tensor(shape=dy_shape, dtype=x.dtype) - net.set_inputs(input_dyn) - output = net(x) - output_y = output[0].asnumpy() - output_row_pooling_sequence = output[1].asnumpy() - output_col_pooling_sequence = output[2].asnumpy() - expect_output_y = np.array([[[[6], [8]], [[14], [16]]]]).astype(type_i) - expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - assert np.allclose(output_y, expect_output_y) - assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) - assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) - - # case2 - net = NetFractionalMaxPoolRealRandom() - dy_shape = [None for _ in x.shape] - input_dyn = Tensor(shape=dy_shape, dtype=x.dtype) - net.set_inputs(input_dyn) - output = net(x) - type0 = output[0].asnumpy().dtype - assert type0 == type_i - - # case3 - net = NetFractionalMaxPoolOverlapPing() - dy_shape = [None for _ in x.shape] - input_dyn = Tensor(shape=dy_shape, dtype=x.dtype) - net.set_inputs(input_dyn) - output = net(x) - output_y = output[0].asnumpy() - output_row_pooling_sequence = output[1].asnumpy() - output_col_pooling_sequence = output[2].asnumpy() - expect_output_y = np.array([[[[11], [12]], [[15], [16]]]]).astype(type_i) - expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - assert np.allclose(output_y, expect_output_y) - assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) - assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') -def test_fractionalmaxpool_pynative(): - """ - Feature: FractionalMaxPool - Description: Test of input - Expectation: The results are as expected - """ - context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU') - types = [np.float32, np.float64, np.int32, np.int64] - for type_i in types: - x = Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, - 10, 11, 12, 13, 14, 15, 16]).reshape([1, 4, 4, 1]).astype(type_i)) - fractionalmaxpool = ops.FractionalMaxPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0]) - output = fractionalmaxpool(x) - output_y = output[0].asnumpy() - output_row_pooling_sequence = output[1].asnumpy() - output_col_pooling_sequence = output[2].asnumpy() - expect_output_y = np.array([[[[6], [8]], [[14], [16]]]]).astype(type_i) - expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - assert np.allclose(output_y, expect_output_y) - assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) - assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) - - fractionalmaxpool = ops.FractionalMaxPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0], - deterministic=True, pseudo_random=False, seed=5454, seed2=144) - output = fractionalmaxpool(x) - type0 = output[0].asnumpy().dtype - assert type0 == type_i - - fractionalmaxpool = ops.FractionalMaxPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0], overlapping=True) - output = fractionalmaxpool(x) - output_y = output[0].asnumpy() - output_row_pooling_sequence = output[1].asnumpy() - output_col_pooling_sequence = output[2].asnumpy() - expect_output_y = np.array([[[[11], [12]], [[15], [16]]]]).astype(type_i) - expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - assert np.allclose(output_y, expect_output_y) - assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) - assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) - - fractionalmaxpoolgrad = grad_ops.FractionalMaxPoolGrad() - out_backprop = Tensor(np.ones([1, 2, 2, 1]).astype(type_i)) - output_grad = fractionalmaxpoolgrad(x, output[0], out_backprop, output[1], output[2]) - output_grad_y = output_grad[0].asnumpy() - expect_output_grad_y = np.array([[[[0], [0], [0], [0]], [[0], [1], [0], [1]], - [[0], [0], [0], [0]], [[0], [1], [0], [1]]]]).astype(type_i) - assert np.allclose(output_grad_y, expect_output_grad_y) - - fractionalmaxpoolgrad = grad_ops.FractionalMaxPoolGrad(overlapping=True) - out_backprop = Tensor(np.ones([1, 2, 2, 1]).astype(type_i)) - output_grad = fractionalmaxpoolgrad(x, output[0], out_backprop, output[1], output[2]) - output_grad_y = output_grad[0].asnumpy() - expect_output_grad_y = np.array([[[[0], [0], [0], [0]], [[0], [0], [0], [0]], - [[0], [0], [1], [1]], [[0], [0], [1], [1]]]]).astype(type_i) - assert np.allclose(output_grad_y, expect_output_grad_y) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest +import mindspore.nn as nn +import mindspore.context as context +from mindspore import Tensor +import mindspore.ops.operations.nn_ops as ops +import mindspore.ops.operations._grad_ops as grad_ops + + +class NetFractionalMaxPool(nn.Cell): + def __init__(self): + super(NetFractionalMaxPool, self).__init__() + self.fractional_max_pool = ops.FractionalMaxPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0]) + + def construct(self, x): + return self.fractional_max_pool(x) + + +class NetFractionalMaxPoolRealRandom(nn.Cell): + def __init__(self): + super(NetFractionalMaxPoolRealRandom, self).__init__() + self.fractional_max_pool = ops.FractionalMaxPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0], deterministic=True, + pseudo_random=False, seed=5454, seed2=144) + + def construct(self, x): + return self.fractional_max_pool(x) + + +class NetFractionalMaxPoolOverlapPing(nn.Cell): + def __init__(self): + super(NetFractionalMaxPoolOverlapPing, self).__init__() + self.fractional_max_pool = ops.FractionalMaxPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0], overlapping=True) + + def construct(self, x): + return self.fractional_max_pool(x) + + +class NetFractionalMaxPoolGrad(nn.Cell): + def __init__(self): + super(NetFractionalMaxPoolGrad, self).__init__() + self.fractional_max_pool_grad = grad_ops.FractionalMaxPoolGrad() + + def construct(self, orig_input, orig_output, out_backprop, row_pooling_sequence, col_pooling_sequence): + return self.fractional_max_pool_grad(orig_input, orig_output, out_backprop, row_pooling_sequence, + col_pooling_sequence) + + +class NetFractionalMaxPoolGradOverlapping(nn.Cell): + def __init__(self): + super(NetFractionalMaxPoolGradOverlapping, self).__init__() + self.fractional_max_pool_grad = grad_ops.FractionalMaxPoolGrad(overlapping=True) + + def construct(self, orig_input, orig_output, out_backprop, row_pooling_sequence, col_pooling_sequence): + return self.fractional_max_pool_grad(orig_input, orig_output, out_backprop, row_pooling_sequence, + col_pooling_sequence) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') +def test_fractionalmaxpool_graph(): + """ + Feature: FractionalMaxPool + Description: Test of input + Expectation: The results are as expected + """ + context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + types = [np.float32, np.float64, np.int32, np.int64] + for type_i in types: + x = Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16]).reshape([1, 4, 4, 1]).astype(type_i)) + net = NetFractionalMaxPool() + output = net(x) + output_y = output[0].asnumpy() + output_row_pooling_sequence = output[1].asnumpy() + output_col_pooling_sequence = output[2].asnumpy() + expect_output_y = np.array([[[[6], [8]], [[14], [16]]]]).astype(type_i) + expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + assert np.allclose(output_y, expect_output_y) + assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) + assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) + + net = NetFractionalMaxPoolRealRandom() + output = net(x) + type0 = output[0].asnumpy().dtype + assert type0 == type_i + + net = NetFractionalMaxPoolOverlapPing() + output = net(x) + output_y = output[0].asnumpy() + output_row_pooling_sequence = output[1].asnumpy() + output_col_pooling_sequence = output[2].asnumpy() + expect_output_y = np.array([[[[11], [12]], [[15], [16]]]]).astype(type_i) + expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + assert np.allclose(output_y, expect_output_y) + assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) + assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) + + netgrad = NetFractionalMaxPoolGrad() + out_backprop = Tensor(np.ones([1, 2, 2, 1]).astype(type_i)) + output_grad = netgrad(x, output[0], out_backprop, output[1], output[2]) + output_grad_y = output_grad[0].asnumpy() + expect_output_grad_y = np.array([[[[0], [0], [0], [0]], [[0], [1], [0], [1]], + [[0], [0], [0], [0]], [[0], [1], [0], [1]]]]).astype(type_i) + assert np.allclose(output_grad_y, expect_output_grad_y) + + netgrad = NetFractionalMaxPoolGradOverlapping() + out_backprop = Tensor(np.ones([1, 2, 2, 1]).astype(type_i)) + output_grad = netgrad(x, output[0], out_backprop, output[1], output[2]) + output_grad_y = output_grad[0].asnumpy() + expect_output_grad_y = np.array([[[[0], [0], [0], [0]], [[0], [0], [0], [0]], + [[0], [0], [1], [1]], [[0], [0], [1], [1]]]]).astype(type_i) + assert np.allclose(output_grad_y, expect_output_grad_y) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') +def test_fractionalmaxpool_pynative_dynamic(): + """ + Feature: FractionalMaxPool + Description: Test of input + Expectation: The results are as expected + """ + context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU') + types = [np.float32, np.float64, np.int32, np.int64] + for type_i in types: + x = Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16]).reshape([1, 4, 4, 1]).astype(type_i)) + # case1 + net = NetFractionalMaxPool() + dy_shape = [None for _ in x.shape] + input_dyn = Tensor(shape=dy_shape, dtype=x.dtype) + net.set_inputs(input_dyn) + output = net(x) + output_y = output[0].asnumpy() + output_row_pooling_sequence = output[1].asnumpy() + output_col_pooling_sequence = output[2].asnumpy() + expect_output_y = np.array([[[[6], [8]], [[14], [16]]]]).astype(type_i) + expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + assert np.allclose(output_y, expect_output_y) + assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) + assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) + + # case2 + net = NetFractionalMaxPoolRealRandom() + dy_shape = [None for _ in x.shape] + input_dyn = Tensor(shape=dy_shape, dtype=x.dtype) + net.set_inputs(input_dyn) + output = net(x) + type0 = output[0].asnumpy().dtype + assert type0 == type_i + + # case3 + net = NetFractionalMaxPoolOverlapPing() + dy_shape = [None for _ in x.shape] + input_dyn = Tensor(shape=dy_shape, dtype=x.dtype) + net.set_inputs(input_dyn) + output = net(x) + output_y = output[0].asnumpy() + output_row_pooling_sequence = output[1].asnumpy() + output_col_pooling_sequence = output[2].asnumpy() + expect_output_y = np.array([[[[11], [12]], [[15], [16]]]]).astype(type_i) + expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + assert np.allclose(output_y, expect_output_y) + assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) + assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') +def test_fractionalmaxpool_pynative(): + """ + Feature: FractionalMaxPool + Description: Test of input + Expectation: The results are as expected + """ + context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU') + types = [np.float32, np.float64, np.int32, np.int64] + for type_i in types: + x = Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16]).reshape([1, 4, 4, 1]).astype(type_i)) + fractionalmaxpool = ops.FractionalMaxPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0]) + output = fractionalmaxpool(x) + output_y = output[0].asnumpy() + output_row_pooling_sequence = output[1].asnumpy() + output_col_pooling_sequence = output[2].asnumpy() + expect_output_y = np.array([[[[6], [8]], [[14], [16]]]]).astype(type_i) + expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + assert np.allclose(output_y, expect_output_y) + assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) + assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) + + fractionalmaxpool = ops.FractionalMaxPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0], + deterministic=True, pseudo_random=False, seed=5454, seed2=144) + output = fractionalmaxpool(x) + type0 = output[0].asnumpy().dtype + assert type0 == type_i + + fractionalmaxpool = ops.FractionalMaxPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0], overlapping=True) + output = fractionalmaxpool(x) + output_y = output[0].asnumpy() + output_row_pooling_sequence = output[1].asnumpy() + output_col_pooling_sequence = output[2].asnumpy() + expect_output_y = np.array([[[[11], [12]], [[15], [16]]]]).astype(type_i) + expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + assert np.allclose(output_y, expect_output_y) + assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) + assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) + + fractionalmaxpoolgrad = grad_ops.FractionalMaxPoolGrad() + out_backprop = Tensor(np.ones([1, 2, 2, 1]).astype(type_i)) + output_grad = fractionalmaxpoolgrad(x, output[0], out_backprop, output[1], output[2]) + output_grad_y = output_grad[0].asnumpy() + expect_output_grad_y = np.array([[[[0], [0], [0], [0]], [[0], [1], [0], [1]], + [[0], [0], [0], [0]], [[0], [1], [0], [1]]]]).astype(type_i) + assert np.allclose(output_grad_y, expect_output_grad_y) + + fractionalmaxpoolgrad = grad_ops.FractionalMaxPoolGrad(overlapping=True) + out_backprop = Tensor(np.ones([1, 2, 2, 1]).astype(type_i)) + output_grad = fractionalmaxpoolgrad(x, output[0], out_backprop, output[1], output[2]) + output_grad_y = output_grad[0].asnumpy() + expect_output_grad_y = np.array([[[[0], [0], [0], [0]], [[0], [0], [0], [0]], + [[0], [0], [1], [1]], [[0], [0], [1], [1]]]]).astype(type_i) + assert np.allclose(output_grad_y, expect_output_grad_y) diff --git a/tests/st/ops/cpu/test_gather_d_grad_op.py b/tests/st/ops/cpu/test_gather_d_grad_op.py index 2a6279889f7..20fe775c781 100644 --- a/tests/st/ops/cpu/test_gather_d_grad_op.py +++ b/tests/st/ops/cpu/test_gather_d_grad_op.py @@ -1,132 +1,132 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore as ms -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.ops import operations as P -from mindspore.common.api import jit -from mindspore.ops.composite import GradOperation - -context.set_context(mode=context.GRAPH_MODE, device_target="CPU") - - -class NetGatherD(nn.Cell): - def __init__(self, dim=1): - super(NetGatherD, self).__init__() - self.gatherd = P.GatherD() - self.dim = int(dim) - - def construct(self, x, index): - return self.gatherd(x, self.dim, index) - - -class NetGatherDGrad(nn.Cell): - def __init__(self, network): - super(NetGatherDGrad, self).__init__() - self.grad = GradOperation(get_all=True) - self.network = network - - @jit - def construct(self, inputx, index): - return self.grad(self.network)(inputx, index) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_gatherd_grad_fp32(): - prop = 100 if np.random.random() > 0.5 else -100 - x = np.random.randn(5, 5, 5).astype(np.float32) * prop - index = np.random.randint(0, 5, (5, 3, 5)).astype(np.int32) - dim = 1 - - gatherd = NetGatherD(dim) - grad = NetGatherDGrad(gatherd) - output_grad = grad(Tensor(x), Tensor(index)) - if isinstance(output_grad, (tuple, list)): - output_grad = output_grad[0] - print(output_grad.asnumpy()) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_gatherd_grad_fp16(): - prop = 100 if np.random.random() > 0.5 else -100 - x = np.random.randn(5, 5, 5).astype(np.float16) * prop - index = np.random.randint(0, 5, (3, 5, 5)).astype(np.int32) - dim = 0 - - gatherd = NetGatherD(dim) - grad = NetGatherDGrad(gatherd) - output_grad = grad(Tensor(x), Tensor(index)) - if isinstance(output_grad, (tuple, list)): - output_grad = output_grad[0] - print(output_grad.asnumpy()) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_gatherd_grad_int32(): - prop = 100 if np.random.random() > 0.5 else -100 - x = np.random.randn(5, 5, 5).astype(np.int32) * prop - index = np.random.randint(0, 5, (5, 5, 7)).astype(np.int64) - dim = -1 - - gatherd = NetGatherD(dim) - grad = NetGatherDGrad(gatherd) - output_grad = grad(Tensor(x), Tensor(index)) - if isinstance(output_grad, (tuple, list)): - output_grad = output_grad[0] - print(output_grad.asnumpy()) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_gatherd_grad_checkresult(): - x = np.array([[[-146.76097, 119.84371], [91.22607, -166.12923]], - [[37.67479, -8.696029], [43.804962, -23.369316]]], np.float32) - index = np.array([[[0, 1], [0, 0]], [[0, 0], [0, 1]]], np.int32) - dim = 1 - - gatherd = NetGatherD(dim) - grad = NetGatherDGrad(gatherd) - output = grad(Tensor(x), Tensor(index)) - - if isinstance(output, (tuple, list)): - output = output[0] - expect = np.array([[[2., 1.], [0., 1.]], [[2., 1.], [0., 1.]]], np.float32) - error = np.ones(shape=expect.shape) * 1.0e-6 - assert np.all(np.abs(output.asnumpy() - expect) < error) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_gatherd_grad_dynamic_shape(): - """ - Feature: dynamic shape support of GatherDGrad. - Description: input Tensor with dynamic shape. - Expectation: output shape coincide with expect_shape. - """ - context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU') - x_dyn = Tensor(shape=[2, None], dtype=ms.float16) - x = Tensor(np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]), dtype=ms.float16) - dim = 0 - index_dyn = Tensor(shape=[None, 5], dtype=ms.int64) - index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), dtype=ms.int64) - except_shape = (2, 5) - grad_net = NetGatherDGrad(NetGatherD(dim)) - grad_net.set_inputs(x_dyn, index_dyn) - output = grad_net(x, index) - assert output[0].asnumpy().shape == except_shape +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore as ms +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P +from mindspore.common.api import jit +from mindspore.ops.composite import GradOperation + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + +class NetGatherD(nn.Cell): + def __init__(self, dim=1): + super(NetGatherD, self).__init__() + self.gatherd = P.GatherD() + self.dim = int(dim) + + def construct(self, x, index): + return self.gatherd(x, self.dim, index) + + +class NetGatherDGrad(nn.Cell): + def __init__(self, network): + super(NetGatherDGrad, self).__init__() + self.grad = GradOperation(get_all=True) + self.network = network + + @jit + def construct(self, inputx, index): + return self.grad(self.network)(inputx, index) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_gatherd_grad_fp32(): + prop = 100 if np.random.random() > 0.5 else -100 + x = np.random.randn(5, 5, 5).astype(np.float32) * prop + index = np.random.randint(0, 5, (5, 3, 5)).astype(np.int32) + dim = 1 + + gatherd = NetGatherD(dim) + grad = NetGatherDGrad(gatherd) + output_grad = grad(Tensor(x), Tensor(index)) + if isinstance(output_grad, (tuple, list)): + output_grad = output_grad[0] + print(output_grad.asnumpy()) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_gatherd_grad_fp16(): + prop = 100 if np.random.random() > 0.5 else -100 + x = np.random.randn(5, 5, 5).astype(np.float16) * prop + index = np.random.randint(0, 5, (3, 5, 5)).astype(np.int32) + dim = 0 + + gatherd = NetGatherD(dim) + grad = NetGatherDGrad(gatherd) + output_grad = grad(Tensor(x), Tensor(index)) + if isinstance(output_grad, (tuple, list)): + output_grad = output_grad[0] + print(output_grad.asnumpy()) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_gatherd_grad_int32(): + prop = 100 if np.random.random() > 0.5 else -100 + x = np.random.randn(5, 5, 5).astype(np.int32) * prop + index = np.random.randint(0, 5, (5, 5, 7)).astype(np.int64) + dim = -1 + + gatherd = NetGatherD(dim) + grad = NetGatherDGrad(gatherd) + output_grad = grad(Tensor(x), Tensor(index)) + if isinstance(output_grad, (tuple, list)): + output_grad = output_grad[0] + print(output_grad.asnumpy()) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_gatherd_grad_checkresult(): + x = np.array([[[-146.76097, 119.84371], [91.22607, -166.12923]], + [[37.67479, -8.696029], [43.804962, -23.369316]]], np.float32) + index = np.array([[[0, 1], [0, 0]], [[0, 0], [0, 1]]], np.int32) + dim = 1 + + gatherd = NetGatherD(dim) + grad = NetGatherDGrad(gatherd) + output = grad(Tensor(x), Tensor(index)) + + if isinstance(output, (tuple, list)): + output = output[0] + expect = np.array([[[2., 1.], [0., 1.]], [[2., 1.], [0., 1.]]], np.float32) + error = np.ones(shape=expect.shape) * 1.0e-6 + assert np.all(np.abs(output.asnumpy() - expect) < error) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_gatherd_grad_dynamic_shape(): + """ + Feature: dynamic shape support of GatherDGrad. + Description: input Tensor with dynamic shape. + Expectation: output shape coincide with expect_shape. + """ + context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU') + x_dyn = Tensor(shape=[2, None], dtype=ms.float16) + x = Tensor(np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]), dtype=ms.float16) + dim = 0 + index_dyn = Tensor(shape=[None, 5], dtype=ms.int64) + index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), dtype=ms.int64) + except_shape = (2, 5) + grad_net = NetGatherDGrad(NetGatherD(dim)) + grad_net.set_inputs(x_dyn, index_dyn) + output = grad_net(x, index) + assert output[0].asnumpy().shape == except_shape diff --git a/tests/st/ops/cpu/test_gather_d_op.py b/tests/st/ops/cpu/test_gather_d_op.py index 362ad0c6d0b..ab46d85c9a7 100644 --- a/tests/st/ops/cpu/test_gather_d_op.py +++ b/tests/st/ops/cpu/test_gather_d_op.py @@ -1,162 +1,162 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import os -import stat -import numpy as np -import pytest - -import mindspore -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.ops import operations as P -from mindspore.train.serialization import export - -context.set_context(mode=context.GRAPH_MODE, device_target="CPU") - - -class NetGatherD(nn.Cell): - def __init__(self, dim=1): - super(NetGatherD, self).__init__() - self.gatherd = P.GatherD() - self.dim = int(dim) - - def construct(self, x, index): - return self.gatherd(x, self.dim, index) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_gatherd_fp32(): - prop = 100 if np.random.random() > 0.5 else -100 - x = np.random.randn(5, 5, 5).astype(np.float32) * prop - index = np.random.randint(0, 5, (5, 3, 5)).astype(np.int32) - dim = 1 - - gatherd = NetGatherD(dim) - output = gatherd(Tensor(x), Tensor(index)) - - expect = np.zeros(index.shape).astype(np.float32) - for i in range(index.shape[0]): - for j in range(index.shape[1]): - for k in range(index.shape[2]): - expect[i, j, k] = x[i, index[i, j, k], k] - error = np.ones(shape=expect.shape) * 1.0e-6 - assert np.all(np.abs(output.asnumpy() - expect) < error) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_gatherd_fp16(): - prop = 100 if np.random.random() > 0.5 else -100 - x = np.random.randn(5, 5, 5).astype(np.float16) * prop - index = np.random.randint(0, 5, (3, 5, 5)).astype(np.int64) - dim = 0 - - gatherd = NetGatherD(dim) - output = gatherd(Tensor(x), Tensor(index)) - - expect = np.zeros(index.shape).astype(np.float16) - for i in range(index.shape[0]): - for j in range(index.shape[1]): - for k in range(index.shape[2]): - expect[i, j, k] = x[index[i, j, k], j, k] - error = np.ones(shape=expect.shape) * 1.0e-6 - assert np.all(np.abs(output.asnumpy() - expect) < error) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_gatherd_int32(): - prop = 100 if np.random.random() > 0.5 else -100 - x = np.random.randn(5, 5, 5).astype(np.int32) * prop - index = np.random.randint(0, 5, (5, 5, 8)).astype(np.int32) - dim = -1 - - gatherd = NetGatherD(dim) - output = gatherd(Tensor(x), Tensor(index)) - - expect = np.zeros(index.shape).astype(np.int32) - for i in range(index.shape[0]): - for j in range(index.shape[1]): - for k in range(index.shape[2]): - expect[i, j, k] = x[i, j, index[i, j, k]] - assert np.all(output.asnumpy() == expect) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_gatherd_bool(): - prop = 100 if np.random.random() > 0.5 else -100 - x = np.random.randn(5, 5, 5).astype(np.int32) * prop - x = (x >= 0).astype(np.bool) - index = np.random.randint(0, 5, (5, 5, 8)).astype(np.int32) - dim = -1 - - gatherd = NetGatherD(dim) - output = gatherd(Tensor(x), Tensor(index)) - - expect = np.zeros(index.shape).astype(np.bool) - for i in range(index.shape[0]): - for j in range(index.shape[1]): - for k in range(index.shape[2]): - expect[i, j, k] = x[i, j, index[i, j, k]] - assert np.all(output.asnumpy() == expect) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_gatherd_cpu_dynamic_shape(): - """ - Feature: test GatherD op in cpu. - Description: test the ops in dynamic shape. - Expectation: expect correct shape result. - """ - context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - dim = -1 - gatherd = NetGatherD(dim) - x_dyn = Tensor(shape=[None, 5, 5], dtype=mindspore.float32) - index_dyn = Tensor(shape=[5, 5, None], dtype=mindspore.int32) - gatherd.set_inputs(x_dyn, index_dyn) - x = np.random.randn(5, 5, 5) - y = np.random.randn(5, 5, 8) - output = gatherd(Tensor(x, mindspore.float32), Tensor(y, mindspore.int32)) - expect_shape = (5, 5, 8) - assert output.asnumpy().shape == expect_shape - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_gatherd_cpu_onnx(): - """ - Feature: test GatherD op in cpu. - Description: test the ops export onnx. - Expectation: expect correct shape result. - """ - context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - dim = 1 - net = NetGatherD(dim) - data = np.array([[1, 2], [3, 4]], dtype=np.float32) - indices = np.array([[0, 0], [1, 0]], dtype=np.int32) - out_ms = net(Tensor(data), Tensor(indices)).asnumpy() - file = 'gatherd.onnx' - export(net, Tensor(data), Tensor(indices), file_name=file, file_format="ONNX") - assert os.path.exists(file) - - import onnxruntime - sess = onnxruntime.InferenceSession(file) - input_x = sess.get_inputs()[0].name - input_indices = sess.get_inputs()[1].name - result = sess.run([], {input_x: data, input_indices: indices})[0] - assert np.all(out_ms == result) - - os.chmod(file, stat.S_IWRITE) - os.remove(file) +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import os +import stat +import numpy as np +import pytest + +import mindspore +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P +from mindspore.train.serialization import export + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + +class NetGatherD(nn.Cell): + def __init__(self, dim=1): + super(NetGatherD, self).__init__() + self.gatherd = P.GatherD() + self.dim = int(dim) + + def construct(self, x, index): + return self.gatherd(x, self.dim, index) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_gatherd_fp32(): + prop = 100 if np.random.random() > 0.5 else -100 + x = np.random.randn(5, 5, 5).astype(np.float32) * prop + index = np.random.randint(0, 5, (5, 3, 5)).astype(np.int32) + dim = 1 + + gatherd = NetGatherD(dim) + output = gatherd(Tensor(x), Tensor(index)) + + expect = np.zeros(index.shape).astype(np.float32) + for i in range(index.shape[0]): + for j in range(index.shape[1]): + for k in range(index.shape[2]): + expect[i, j, k] = x[i, index[i, j, k], k] + error = np.ones(shape=expect.shape) * 1.0e-6 + assert np.all(np.abs(output.asnumpy() - expect) < error) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_gatherd_fp16(): + prop = 100 if np.random.random() > 0.5 else -100 + x = np.random.randn(5, 5, 5).astype(np.float16) * prop + index = np.random.randint(0, 5, (3, 5, 5)).astype(np.int64) + dim = 0 + + gatherd = NetGatherD(dim) + output = gatherd(Tensor(x), Tensor(index)) + + expect = np.zeros(index.shape).astype(np.float16) + for i in range(index.shape[0]): + for j in range(index.shape[1]): + for k in range(index.shape[2]): + expect[i, j, k] = x[index[i, j, k], j, k] + error = np.ones(shape=expect.shape) * 1.0e-6 + assert np.all(np.abs(output.asnumpy() - expect) < error) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_gatherd_int32(): + prop = 100 if np.random.random() > 0.5 else -100 + x = np.random.randn(5, 5, 5).astype(np.int32) * prop + index = np.random.randint(0, 5, (5, 5, 8)).astype(np.int32) + dim = -1 + + gatherd = NetGatherD(dim) + output = gatherd(Tensor(x), Tensor(index)) + + expect = np.zeros(index.shape).astype(np.int32) + for i in range(index.shape[0]): + for j in range(index.shape[1]): + for k in range(index.shape[2]): + expect[i, j, k] = x[i, j, index[i, j, k]] + assert np.all(output.asnumpy() == expect) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_gatherd_bool(): + prop = 100 if np.random.random() > 0.5 else -100 + x = np.random.randn(5, 5, 5).astype(np.int32) * prop + x = (x >= 0).astype(np.bool) + index = np.random.randint(0, 5, (5, 5, 8)).astype(np.int32) + dim = -1 + + gatherd = NetGatherD(dim) + output = gatherd(Tensor(x), Tensor(index)) + + expect = np.zeros(index.shape).astype(np.bool) + for i in range(index.shape[0]): + for j in range(index.shape[1]): + for k in range(index.shape[2]): + expect[i, j, k] = x[i, j, index[i, j, k]] + assert np.all(output.asnumpy() == expect) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_gatherd_cpu_dynamic_shape(): + """ + Feature: test GatherD op in cpu. + Description: test the ops in dynamic shape. + Expectation: expect correct shape result. + """ + context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + dim = -1 + gatherd = NetGatherD(dim) + x_dyn = Tensor(shape=[None, 5, 5], dtype=mindspore.float32) + index_dyn = Tensor(shape=[5, 5, None], dtype=mindspore.int32) + gatherd.set_inputs(x_dyn, index_dyn) + x = np.random.randn(5, 5, 5) + y = np.random.randn(5, 5, 8) + output = gatherd(Tensor(x, mindspore.float32), Tensor(y, mindspore.int32)) + expect_shape = (5, 5, 8) + assert output.asnumpy().shape == expect_shape + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_gatherd_cpu_onnx(): + """ + Feature: test GatherD op in cpu. + Description: test the ops export onnx. + Expectation: expect correct shape result. + """ + context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + dim = 1 + net = NetGatherD(dim) + data = np.array([[1, 2], [3, 4]], dtype=np.float32) + indices = np.array([[0, 0], [1, 0]], dtype=np.int32) + out_ms = net(Tensor(data), Tensor(indices)).asnumpy() + file = 'gatherd.onnx' + export(net, Tensor(data), Tensor(indices), file_name=file, file_format="ONNX") + assert os.path.exists(file) + + import onnxruntime + sess = onnxruntime.InferenceSession(file) + input_x = sess.get_inputs()[0].name + input_indices = sess.get_inputs()[1].name + result = sess.run([], {input_x: data, input_indices: indices})[0] + assert np.all(out_ms == result) + + os.chmod(file, stat.S_IWRITE) + os.remove(file) diff --git a/tests/st/ops/cpu/test_isinf_op.py b/tests/st/ops/cpu/test_isinf_op.py old mode 100755 new mode 100644 diff --git a/tests/st/ops/cpu/test_lstm_op.py b/tests/st/ops/cpu/test_lstm_op.py index d03150034a3..087b741111a 100644 --- a/tests/st/ops/cpu/test_lstm_op.py +++ b/tests/st/ops/cpu/test_lstm_op.py @@ -1,292 +1,292 @@ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -from tests.mark_utils import arg_mark - -import math -import pytest -import numpy as np -import mindspore -from mindspore import context -from mindspore import nn, ops -from mindspore import Tensor -from mindspore.common.parameter import ParameterTuple -from mindspore.common.parameter import Parameter -from mindspore.ops import composite as c - - -class GradOfAllInputsAndParams(nn.Cell): - def __init__(self, network, sens_param): - super().__init__() - self.grad = c.GradOperation(get_all=True, get_by_list=True, sens_param=sens_param) - self.network = network - self.params = ParameterTuple(self.network.trainable_params()) - - def construct(self, *inputs): - gout = self.grad(self.network, self.params)(*inputs) - return gout - - -class LSTMP(nn.Cell): - def __init__(self, input_s, hidden_s, num_layers, has_bias, batch_first, bidirectional, dropout, proj_size=0): - super().__init__() - self.lstm = ops.LSTM(input_s, hidden_s, num_layers, has_bias, bidirectional, dropout, proj_size) - real_hidden_size = proj_size if proj_size > 0 else hidden_s - weights_size = 4 * hidden_s * (input_s + real_hidden_size) - if proj_size > 0: - weights_size += proj_size * hidden_s - if has_bias: - weights_size += 4 * hidden_s - stdv = 1 / math.sqrt(hidden_s) - self.weights = Parameter(Tensor(np.random.uniform(-stdv, stdv, (weights_size)).astype(np.float32))) - - def construct(self, inp, h0, c0): - return self.lstm(inp, h0, c0, self.weights) - - -class LSTM(nn.Cell): - def __init__(self, input_s, hidden_s, num_layers, has_bias, batch_first, bidirectional, dropout): - super().__init__() - self.lstm = nn.LSTM(input_size=input_s, hidden_size=hidden_s, num_layers=num_layers, has_bias=has_bias, - batch_first=batch_first, bidirectional=bidirectional, dropout=dropout) - - def construct(self, inp, h0, c0): - return self.lstm(inp, (h0, c0)) - - -class LSTMWeightBias(): - def __init__(self, num_layers, has_bias, input_size, num_directions, hidden_size, bidirectional): - self.num_layers = num_layers - self.has_bias = has_bias - self.input_size = input_size - self.num_directions = num_directions - self.hidden_size = hidden_size - self.bidirectional = bidirectional - - def get_weight_bias(self): - gate_size = 4 * self.hidden_size - - w_ih_list = [] - w_hh_list = [] - b_ih_list = [] - b_hh_list = [] - stdv = 1 / math.sqrt(self.hidden_size) - for layer in range(self.num_layers): - for direction in range(self.num_directions): - layer_input_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions - suffix = '_reverse' if direction == 1 else '' - - w_ih_list.append(Parameter( - Tensor(np.random.uniform(-stdv, stdv, (gate_size, layer_input_size)).astype(np.float32)), - name='weight_ih_l{}{}'.format(layer, suffix))) - w_hh_list.append(Parameter( - Tensor(np.random.uniform(-stdv, stdv, (gate_size, self.hidden_size)).astype(np.float32)), - name='weight_hh_l{}{}'.format(layer, suffix))) - if self.has_bias: - b_ih_list.append(Parameter( - Tensor(np.random.uniform(-stdv, stdv, (gate_size)).astype(np.float32)), - name='bias_ih_l{}{}'.format(layer, suffix))) - b_hh_list.append(Parameter( - Tensor(np.random.uniform(-stdv, stdv, (gate_size)).astype(np.float32)), - name='bias_hh_l{}{}'.format(layer, suffix))) - w_ih_list = ParameterTuple(w_ih_list) - w_hh_list = ParameterTuple(w_hh_list) - b_ih_list = ParameterTuple(b_ih_list) - b_hh_list = ParameterTuple(b_hh_list) - return w_ih_list, w_hh_list, b_ih_list, b_hh_list - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_sit_lstm_forward_input_3_32_32_is_32_hs_16(): - """ - Feature: LSTM forward - Description: LSTM with input (3, 32, 32) - Expectation: Graph mode equal to pynative mode - """ - input_s = 32 - hidden_s = 16 - has_bias = True - bidirectional = False - num_layers = 1 - num_directions = 1 - - fact = LSTMWeightBias(num_layers, has_bias, input_s, num_directions, hidden_s, bidirectional) - w_ih_list, w_hh_list, b_ih_list, b_hh_list = fact.get_weight_bias() - - h0 = Tensor(np.random.randn(num_layers * 1, 32, 16).astype(np.float32)) - c0 = Tensor(np.random.randn(num_layers * 1, 32, 16).astype(np.float32)) - input_ms = Tensor(np.random.randn(3, 32, 32).astype(np.float32)) - - # graph mode - context.set_context(mode=context.GRAPH_MODE) - net = LSTM(input_s=input_s, hidden_s=16, num_layers=num_layers, has_bias=has_bias, batch_first=False, - bidirectional=bidirectional, dropout=0.0) - net.lstm.w_ih_list = w_ih_list - net.lstm.w_hh_list = w_hh_list - net.lstm.b_ih_list = b_ih_list - net.lstm.b_hh_list = b_hh_list - out, (hy, cy) = net(input_ms, h0, c0) - - # pynative mode - context.set_context(mode=context.PYNATIVE_MODE) - net_pynative = LSTM(input_s=input_s, hidden_s=16, num_layers=num_layers, has_bias=has_bias, batch_first=False, - bidirectional=bidirectional, dropout=0.0) - net_pynative.lstm.w_ih_list = w_ih_list - net_pynative.lstm.w_hh_list = w_hh_list - net_pynative.lstm.b_ih_list = b_ih_list - net_pynative.lstm.b_hh_list = b_hh_list - out_pynative, (hy_pynative, cy_pynative) = net_pynative(input_ms, h0, c0) - context.set_context(mode=context.GRAPH_MODE) - - assert np.allclose(out.asnumpy(), out_pynative.asnumpy(), 0.0001, 0.0001) - assert np.allclose(hy.asnumpy(), hy_pynative.asnumpy(), 0.0001, 0.0001) - assert np.allclose(cy.asnumpy(), cy_pynative.asnumpy(), 0.0001, 0.0001) - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_sit_lstm_grad_input_3_32_32_is_32_hs_16(): - """ - Feature: LSTM backward - Description: LSTM with input (3, 32, 32) - Expectation: Graph mode equal to pynative mode - """ - input_s = 32 - hidden_s = 16 - has_bias = True - bidirectional = False - num_layers = 1 - num_directions = 1 - - fact = LSTMWeightBias(num_layers, has_bias, input_s, num_directions, hidden_s, bidirectional) - w_ih_list, w_hh_list, b_ih_list, b_hh_list = fact.get_weight_bias() - - h0 = Tensor(np.random.randn(num_layers * 1, 32, 16).astype(np.float32)) - c0 = Tensor(np.random.randn(num_layers * 1, 32, 16).astype(np.float32)) - input_ms = Tensor(np.random.randn(3, 32, 32).astype(np.float32)) - - # graph mode - context.set_context(mode=context.GRAPH_MODE) - net = LSTM(input_s=input_s, hidden_s=16, num_layers=num_layers, has_bias=has_bias, batch_first=False, - bidirectional=bidirectional, dropout=0.0) - net.lstm.w_ih_list = w_ih_list - net.lstm.w_hh_list = w_hh_list - net.lstm.b_ih_list = b_ih_list - net.lstm.b_hh_list = b_hh_list - - grad_net_inp = GradOfAllInputsAndParams(net, sens_param=False) - grad_net_inp.set_train() - out_grad, _ = grad_net_inp(input_ms, h0, c0) - x_grad = out_grad[0].asnumpy() - h_grad = out_grad[1].asnumpy() - c_grad = out_grad[2].asnumpy() - - # pynative mode - context.set_context(mode=context.PYNATIVE_MODE) - net_pynative = LSTM(input_s=input_s, hidden_s=16, num_layers=num_layers, has_bias=has_bias, batch_first=False, - bidirectional=bidirectional, dropout=0.0) - net_pynative.lstm.w_ih_list = w_ih_list - net_pynative.lstm.w_hh_list = w_hh_list - net_pynative.lstm.b_ih_list = b_ih_list - net_pynative.lstm.b_hh_list = b_hh_list - - grad_net_inp_pynative = GradOfAllInputsAndParams(net_pynative, sens_param=False) - grad_net_inp_pynative.set_train() - out_grad_pynative, _ = grad_net_inp_pynative(input_ms, h0, c0) - x_grad_pynative = out_grad_pynative[0].asnumpy() - h_grad_pynative = out_grad_pynative[1].asnumpy() - c_grad_pynative = out_grad_pynative[2].asnumpy() - context.set_context(mode=context.GRAPH_MODE) - - assert np.allclose(x_grad, x_grad_pynative, 0.001, 0.001) - assert np.allclose(h_grad, h_grad_pynative, 0.001, 0.001) - assert np.allclose(c_grad, c_grad_pynative, 0.001, 0.001) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_lstm_cpu_dynamic_shape(): - """ - Feature: test LSTM op in cpu. - Description: test the ops in dynamic shape. - Expectation: expect correct shape result. - """ - context.set_context(mode=context.GRAPH_MODE, device_target="CPU") - input_s = 32 - hidden_s = 16 - has_bias = True - bidirectional = False - num_layers = 1 - num_directions = 1 - - fact = LSTMWeightBias(num_layers, has_bias, input_s, num_directions, hidden_s, bidirectional) - w_ih_list, w_hh_list, b_ih_list, b_hh_list = fact.get_weight_bias() - net = LSTM(input_s=input_s, hidden_s=16, num_layers=num_layers, has_bias=has_bias, batch_first=False, - bidirectional=bidirectional, dropout=0.0) - net.lstm.w_ih_list = w_ih_list - net.lstm.w_hh_list = w_hh_list - net.lstm.b_ih_list = b_ih_list - net.lstm.b_hh_list = b_hh_list - - h0_dyn = Tensor(shape=[None, 32, 16], dtype=mindspore.float32) - c0_dyn = Tensor(shape=[num_layers * 1, None, 16], dtype=mindspore.float32) - input_dyn = Tensor(shape=[3, 32, None], dtype=mindspore.float32) - net.set_inputs(input_dyn, h0_dyn, c0_dyn) - - h0 = Tensor(np.random.randn(num_layers * 1, 32, 16).astype(np.float32)) - c0 = Tensor(np.random.randn(num_layers * 1, 32, 16).astype(np.float32)) - input_ms = Tensor(np.random.randn(3, 32, 32).astype(np.float32)) - out, (hy, cy) = net(input_ms, h0, c0) - out_shape = (3, 32, 16) - assert out.asnumpy().shape == out_shape - hy_shape = (1, 32, 16) - assert hy.asnumpy().shape == hy_shape - cy_shape = (1, 32, 16) - assert cy.asnumpy().shape == cy_shape - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_lstm_cpu_proj_size(): - """ - Feature: test LSTM op in cpu. - Description: test the ops with proj_size input. - Expectation: expect correct result. - """ - np.random.seed(1) - context.set_context(mode=context.GRAPH_MODE, device_target="CPU") - input_size = 4 - hidden_size = 4 - num_layers = 1 - seq_len = 2 - batch_size = 1 - dropout = 0.0 - proj_size = 2 - has_bias = False - batch_first = False - bidirectional = False - - x = Tensor(np.random.randn(seq_len, batch_size, input_size), mindspore.float32) - h0 = Tensor(np.random.randn(num_layers, batch_size, proj_size), mindspore.float32) - c0 = Tensor(np.random.randn(num_layers, batch_size, hidden_size), mindspore.float32) - net = LSTMP(input_size, hidden_size, num_layers, has_bias, batch_first, bidirectional, dropout, proj_size) - grad_net = GradOfAllInputsAndParams(net, sens_param=False) - grad_net.set_train() - out_grad, _ = grad_net(x, h0, c0) - x_grad = out_grad[0].asnumpy() - h_grad = out_grad[1].asnumpy() - c_grad = out_grad[2].asnumpy() - - expect_x_grad = np.array([[[-0.02324772, -0.09717661, 0.06087979, -0.00883127]], - [[-0.11961889, 0.0196102, 0.02770284, -0.13316777]]], np.float32) - expect_h_grad = np.array([[[0.04825277, 0.00618415]]], np.float32) - expect_c_grad = np.array([[[-0.01589189, 0.17060986, 0.07265963, -0.05466095]]], np.float32) - assert np.allclose(x_grad, expect_x_grad) - assert np.allclose(h_grad, expect_h_grad) - assert np.allclose(c_grad, expect_c_grad) +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from tests.mark_utils import arg_mark + +import math +import pytest +import numpy as np +import mindspore +from mindspore import context +from mindspore import nn, ops +from mindspore import Tensor +from mindspore.common.parameter import ParameterTuple +from mindspore.common.parameter import Parameter +from mindspore.ops import composite as c + + +class GradOfAllInputsAndParams(nn.Cell): + def __init__(self, network, sens_param): + super().__init__() + self.grad = c.GradOperation(get_all=True, get_by_list=True, sens_param=sens_param) + self.network = network + self.params = ParameterTuple(self.network.trainable_params()) + + def construct(self, *inputs): + gout = self.grad(self.network, self.params)(*inputs) + return gout + + +class LSTMP(nn.Cell): + def __init__(self, input_s, hidden_s, num_layers, has_bias, batch_first, bidirectional, dropout, proj_size=0): + super().__init__() + self.lstm = ops.LSTM(input_s, hidden_s, num_layers, has_bias, bidirectional, dropout, proj_size) + real_hidden_size = proj_size if proj_size > 0 else hidden_s + weights_size = 4 * hidden_s * (input_s + real_hidden_size) + if proj_size > 0: + weights_size += proj_size * hidden_s + if has_bias: + weights_size += 4 * hidden_s + stdv = 1 / math.sqrt(hidden_s) + self.weights = Parameter(Tensor(np.random.uniform(-stdv, stdv, (weights_size)).astype(np.float32))) + + def construct(self, inp, h0, c0): + return self.lstm(inp, h0, c0, self.weights) + + +class LSTM(nn.Cell): + def __init__(self, input_s, hidden_s, num_layers, has_bias, batch_first, bidirectional, dropout): + super().__init__() + self.lstm = nn.LSTM(input_size=input_s, hidden_size=hidden_s, num_layers=num_layers, has_bias=has_bias, + batch_first=batch_first, bidirectional=bidirectional, dropout=dropout) + + def construct(self, inp, h0, c0): + return self.lstm(inp, (h0, c0)) + + +class LSTMWeightBias(): + def __init__(self, num_layers, has_bias, input_size, num_directions, hidden_size, bidirectional): + self.num_layers = num_layers + self.has_bias = has_bias + self.input_size = input_size + self.num_directions = num_directions + self.hidden_size = hidden_size + self.bidirectional = bidirectional + + def get_weight_bias(self): + gate_size = 4 * self.hidden_size + + w_ih_list = [] + w_hh_list = [] + b_ih_list = [] + b_hh_list = [] + stdv = 1 / math.sqrt(self.hidden_size) + for layer in range(self.num_layers): + for direction in range(self.num_directions): + layer_input_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions + suffix = '_reverse' if direction == 1 else '' + + w_ih_list.append(Parameter( + Tensor(np.random.uniform(-stdv, stdv, (gate_size, layer_input_size)).astype(np.float32)), + name='weight_ih_l{}{}'.format(layer, suffix))) + w_hh_list.append(Parameter( + Tensor(np.random.uniform(-stdv, stdv, (gate_size, self.hidden_size)).astype(np.float32)), + name='weight_hh_l{}{}'.format(layer, suffix))) + if self.has_bias: + b_ih_list.append(Parameter( + Tensor(np.random.uniform(-stdv, stdv, (gate_size)).astype(np.float32)), + name='bias_ih_l{}{}'.format(layer, suffix))) + b_hh_list.append(Parameter( + Tensor(np.random.uniform(-stdv, stdv, (gate_size)).astype(np.float32)), + name='bias_hh_l{}{}'.format(layer, suffix))) + w_ih_list = ParameterTuple(w_ih_list) + w_hh_list = ParameterTuple(w_hh_list) + b_ih_list = ParameterTuple(b_ih_list) + b_hh_list = ParameterTuple(b_hh_list) + return w_ih_list, w_hh_list, b_ih_list, b_hh_list + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_sit_lstm_forward_input_3_32_32_is_32_hs_16(): + """ + Feature: LSTM forward + Description: LSTM with input (3, 32, 32) + Expectation: Graph mode equal to pynative mode + """ + input_s = 32 + hidden_s = 16 + has_bias = True + bidirectional = False + num_layers = 1 + num_directions = 1 + + fact = LSTMWeightBias(num_layers, has_bias, input_s, num_directions, hidden_s, bidirectional) + w_ih_list, w_hh_list, b_ih_list, b_hh_list = fact.get_weight_bias() + + h0 = Tensor(np.random.randn(num_layers * 1, 32, 16).astype(np.float32)) + c0 = Tensor(np.random.randn(num_layers * 1, 32, 16).astype(np.float32)) + input_ms = Tensor(np.random.randn(3, 32, 32).astype(np.float32)) + + # graph mode + context.set_context(mode=context.GRAPH_MODE) + net = LSTM(input_s=input_s, hidden_s=16, num_layers=num_layers, has_bias=has_bias, batch_first=False, + bidirectional=bidirectional, dropout=0.0) + net.lstm.w_ih_list = w_ih_list + net.lstm.w_hh_list = w_hh_list + net.lstm.b_ih_list = b_ih_list + net.lstm.b_hh_list = b_hh_list + out, (hy, cy) = net(input_ms, h0, c0) + + # pynative mode + context.set_context(mode=context.PYNATIVE_MODE) + net_pynative = LSTM(input_s=input_s, hidden_s=16, num_layers=num_layers, has_bias=has_bias, batch_first=False, + bidirectional=bidirectional, dropout=0.0) + net_pynative.lstm.w_ih_list = w_ih_list + net_pynative.lstm.w_hh_list = w_hh_list + net_pynative.lstm.b_ih_list = b_ih_list + net_pynative.lstm.b_hh_list = b_hh_list + out_pynative, (hy_pynative, cy_pynative) = net_pynative(input_ms, h0, c0) + context.set_context(mode=context.GRAPH_MODE) + + assert np.allclose(out.asnumpy(), out_pynative.asnumpy(), 0.0001, 0.0001) + assert np.allclose(hy.asnumpy(), hy_pynative.asnumpy(), 0.0001, 0.0001) + assert np.allclose(cy.asnumpy(), cy_pynative.asnumpy(), 0.0001, 0.0001) + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_sit_lstm_grad_input_3_32_32_is_32_hs_16(): + """ + Feature: LSTM backward + Description: LSTM with input (3, 32, 32) + Expectation: Graph mode equal to pynative mode + """ + input_s = 32 + hidden_s = 16 + has_bias = True + bidirectional = False + num_layers = 1 + num_directions = 1 + + fact = LSTMWeightBias(num_layers, has_bias, input_s, num_directions, hidden_s, bidirectional) + w_ih_list, w_hh_list, b_ih_list, b_hh_list = fact.get_weight_bias() + + h0 = Tensor(np.random.randn(num_layers * 1, 32, 16).astype(np.float32)) + c0 = Tensor(np.random.randn(num_layers * 1, 32, 16).astype(np.float32)) + input_ms = Tensor(np.random.randn(3, 32, 32).astype(np.float32)) + + # graph mode + context.set_context(mode=context.GRAPH_MODE) + net = LSTM(input_s=input_s, hidden_s=16, num_layers=num_layers, has_bias=has_bias, batch_first=False, + bidirectional=bidirectional, dropout=0.0) + net.lstm.w_ih_list = w_ih_list + net.lstm.w_hh_list = w_hh_list + net.lstm.b_ih_list = b_ih_list + net.lstm.b_hh_list = b_hh_list + + grad_net_inp = GradOfAllInputsAndParams(net, sens_param=False) + grad_net_inp.set_train() + out_grad, _ = grad_net_inp(input_ms, h0, c0) + x_grad = out_grad[0].asnumpy() + h_grad = out_grad[1].asnumpy() + c_grad = out_grad[2].asnumpy() + + # pynative mode + context.set_context(mode=context.PYNATIVE_MODE) + net_pynative = LSTM(input_s=input_s, hidden_s=16, num_layers=num_layers, has_bias=has_bias, batch_first=False, + bidirectional=bidirectional, dropout=0.0) + net_pynative.lstm.w_ih_list = w_ih_list + net_pynative.lstm.w_hh_list = w_hh_list + net_pynative.lstm.b_ih_list = b_ih_list + net_pynative.lstm.b_hh_list = b_hh_list + + grad_net_inp_pynative = GradOfAllInputsAndParams(net_pynative, sens_param=False) + grad_net_inp_pynative.set_train() + out_grad_pynative, _ = grad_net_inp_pynative(input_ms, h0, c0) + x_grad_pynative = out_grad_pynative[0].asnumpy() + h_grad_pynative = out_grad_pynative[1].asnumpy() + c_grad_pynative = out_grad_pynative[2].asnumpy() + context.set_context(mode=context.GRAPH_MODE) + + assert np.allclose(x_grad, x_grad_pynative, 0.001, 0.001) + assert np.allclose(h_grad, h_grad_pynative, 0.001, 0.001) + assert np.allclose(c_grad, c_grad_pynative, 0.001, 0.001) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_lstm_cpu_dynamic_shape(): + """ + Feature: test LSTM op in cpu. + Description: test the ops in dynamic shape. + Expectation: expect correct shape result. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + input_s = 32 + hidden_s = 16 + has_bias = True + bidirectional = False + num_layers = 1 + num_directions = 1 + + fact = LSTMWeightBias(num_layers, has_bias, input_s, num_directions, hidden_s, bidirectional) + w_ih_list, w_hh_list, b_ih_list, b_hh_list = fact.get_weight_bias() + net = LSTM(input_s=input_s, hidden_s=16, num_layers=num_layers, has_bias=has_bias, batch_first=False, + bidirectional=bidirectional, dropout=0.0) + net.lstm.w_ih_list = w_ih_list + net.lstm.w_hh_list = w_hh_list + net.lstm.b_ih_list = b_ih_list + net.lstm.b_hh_list = b_hh_list + + h0_dyn = Tensor(shape=[None, 32, 16], dtype=mindspore.float32) + c0_dyn = Tensor(shape=[num_layers * 1, None, 16], dtype=mindspore.float32) + input_dyn = Tensor(shape=[3, 32, None], dtype=mindspore.float32) + net.set_inputs(input_dyn, h0_dyn, c0_dyn) + + h0 = Tensor(np.random.randn(num_layers * 1, 32, 16).astype(np.float32)) + c0 = Tensor(np.random.randn(num_layers * 1, 32, 16).astype(np.float32)) + input_ms = Tensor(np.random.randn(3, 32, 32).astype(np.float32)) + out, (hy, cy) = net(input_ms, h0, c0) + out_shape = (3, 32, 16) + assert out.asnumpy().shape == out_shape + hy_shape = (1, 32, 16) + assert hy.asnumpy().shape == hy_shape + cy_shape = (1, 32, 16) + assert cy.asnumpy().shape == cy_shape + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_lstm_cpu_proj_size(): + """ + Feature: test LSTM op in cpu. + Description: test the ops with proj_size input. + Expectation: expect correct result. + """ + np.random.seed(1) + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + input_size = 4 + hidden_size = 4 + num_layers = 1 + seq_len = 2 + batch_size = 1 + dropout = 0.0 + proj_size = 2 + has_bias = False + batch_first = False + bidirectional = False + + x = Tensor(np.random.randn(seq_len, batch_size, input_size), mindspore.float32) + h0 = Tensor(np.random.randn(num_layers, batch_size, proj_size), mindspore.float32) + c0 = Tensor(np.random.randn(num_layers, batch_size, hidden_size), mindspore.float32) + net = LSTMP(input_size, hidden_size, num_layers, has_bias, batch_first, bidirectional, dropout, proj_size) + grad_net = GradOfAllInputsAndParams(net, sens_param=False) + grad_net.set_train() + out_grad, _ = grad_net(x, h0, c0) + x_grad = out_grad[0].asnumpy() + h_grad = out_grad[1].asnumpy() + c_grad = out_grad[2].asnumpy() + + expect_x_grad = np.array([[[-0.02324772, -0.09717661, 0.06087979, -0.00883127]], + [[-0.11961889, 0.0196102, 0.02770284, -0.13316777]]], np.float32) + expect_h_grad = np.array([[[0.04825277, 0.00618415]]], np.float32) + expect_c_grad = np.array([[[-0.01589189, 0.17060986, 0.07265963, -0.05466095]]], np.float32) + assert np.allclose(x_grad, expect_x_grad) + assert np.allclose(h_grad, expect_h_grad) + assert np.allclose(c_grad, expect_c_grad) diff --git a/tests/st/ops/cpu/test_parallel_concat_op.py b/tests/st/ops/cpu/test_parallel_concat_op.py index 42e9ba12e25..36d8fbf60e6 100644 --- a/tests/st/ops/cpu/test_parallel_concat_op.py +++ b/tests/st/ops/cpu/test_parallel_concat_op.py @@ -1,80 +1,80 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark -import numpy as np -import pytest -import mindspore.context as context -import mindspore.nn as nn -import mindspore.ops.operations.array_ops as P -from mindspore import Tensor -from mindspore.common.api import jit - - -class ParallelConcatNet(nn.Cell): - def __init__(self): - super(ParallelConcatNet, self).__init__() - self.net = P.ParallelConcat() - - @jit - def construct(self, inputs): - return self.net(inputs) - - -def parallel_concat(loss): - context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - data1 = Tensor(np.array([[[1, 2, 3, 4], [5, 6, 7, 8]]], dtype=np.float32)) - data2 = Tensor(np.array([[[9, 10, 11, 12], [13, 14, 15, 16]]], dtype=np.float32)) - inputs = [data1, data2] - net_ms = ParallelConcatNet() - out_ms = net_ms(inputs) - expected = np.array([[[1, 2, 3, 4], - [5, 6, 7, 8]], - [[9, 10, 11, 12], - [13, 14, 15, 16]]], dtype=np.float32) - assert np.allclose(out_ms.asnumpy(), expected, loss, loss) - - -def parallel_concat_pynative(loss): - context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU') - data1 = Tensor(np.array([[[1, 2, 3, 4], [5, 6, 7, 8]]], dtype=np.float64)) - data2 = Tensor(np.array([[[9, 10, 11, 12], [13, 14, 15, 16]]], dtype=np.float64)) - inputs = [data1, data2] - net_ms = ParallelConcatNet() - out_ms = net_ms(inputs) - expected = np.array([[[1, 2, 3, 4], - [5, 6, 7, 8]], - [[9, 10, 11, 12], - [13, 14, 15, 16]]], dtype=np.float64) - assert np.allclose(out_ms.asnumpy(), expected, loss, loss) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_sparse_segment_sqrt_n_grad_graph_float32_int32_int32(): - """ - Feature: ALL To ALL - Description: test cases for ParallelConcat - Expectation: the result match to tensorflow - """ - parallel_concat(loss=1.0e-4) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_sparse_segment_sqrt_n_grad_pynative_float64_int64_int64(): - """ - Feature: ALL To ALL - Description: test cases for ParallelConcat - Expectation: the result match to tensorflow - """ - parallel_concat_pynative(loss=1.0e-5) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark +import numpy as np +import pytest +import mindspore.context as context +import mindspore.nn as nn +import mindspore.ops.operations.array_ops as P +from mindspore import Tensor +from mindspore.common.api import jit + + +class ParallelConcatNet(nn.Cell): + def __init__(self): + super(ParallelConcatNet, self).__init__() + self.net = P.ParallelConcat() + + @jit + def construct(self, inputs): + return self.net(inputs) + + +def parallel_concat(loss): + context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + data1 = Tensor(np.array([[[1, 2, 3, 4], [5, 6, 7, 8]]], dtype=np.float32)) + data2 = Tensor(np.array([[[9, 10, 11, 12], [13, 14, 15, 16]]], dtype=np.float32)) + inputs = [data1, data2] + net_ms = ParallelConcatNet() + out_ms = net_ms(inputs) + expected = np.array([[[1, 2, 3, 4], + [5, 6, 7, 8]], + [[9, 10, 11, 12], + [13, 14, 15, 16]]], dtype=np.float32) + assert np.allclose(out_ms.asnumpy(), expected, loss, loss) + + +def parallel_concat_pynative(loss): + context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU') + data1 = Tensor(np.array([[[1, 2, 3, 4], [5, 6, 7, 8]]], dtype=np.float64)) + data2 = Tensor(np.array([[[9, 10, 11, 12], [13, 14, 15, 16]]], dtype=np.float64)) + inputs = [data1, data2] + net_ms = ParallelConcatNet() + out_ms = net_ms(inputs) + expected = np.array([[[1, 2, 3, 4], + [5, 6, 7, 8]], + [[9, 10, 11, 12], + [13, 14, 15, 16]]], dtype=np.float64) + assert np.allclose(out_ms.asnumpy(), expected, loss, loss) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_sparse_segment_sqrt_n_grad_graph_float32_int32_int32(): + """ + Feature: ALL To ALL + Description: test cases for ParallelConcat + Expectation: the result match to tensorflow + """ + parallel_concat(loss=1.0e-4) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_sparse_segment_sqrt_n_grad_pynative_float64_int64_int64(): + """ + Feature: ALL To ALL + Description: test cases for ParallelConcat + Expectation: the result match to tensorflow + """ + parallel_concat_pynative(loss=1.0e-5) diff --git a/tests/st/ops/cpu/test_reduce_op.py b/tests/st/ops/cpu/test_reduce_op.py index 7cdd99cfa71..7dc29a414f4 100644 --- a/tests/st/ops/cpu/test_reduce_op.py +++ b/tests/st/ops/cpu/test_reduce_op.py @@ -1,313 +1,313 @@ -# Copyright 2020-2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import os -import stat -import pytest -import numpy as np -from mindspore import Tensor -from mindspore.ops import operations as P -import mindspore.nn as nn -import mindspore.context as context -from mindspore.common.api import jit -from mindspore.train.serialization import export - -context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU") - - -class NetReduce(nn.Cell): - def __init__(self): - super(NetReduce, self).__init__() - self.axis0 = 0 - self.axis1 = 1 - self.axis2 = -1 - self.axis3 = (0, 1, -2) - self.axis4 = (0, 1, 2) - self.axis5 = (-1,) - self.axis6 = () - self.reduce_mean = P.ReduceMean(False) - self.reduce_sum = P.ReduceSum(False) - self.reduce_max = P.ReduceMax(False) - self.reduce_min = P.ReduceMin(False) - - @jit - def construct(self, indice): - return (self.reduce_mean(indice, self.axis0), - self.reduce_mean(indice, self.axis1), - self.reduce_mean(indice, self.axis2), - self.reduce_mean(indice, self.axis3), - self.reduce_mean(indice, self.axis4), - self.reduce_sum(indice, self.axis0), - self.reduce_sum(indice, self.axis2), - self.reduce_max(indice, self.axis0), - self.reduce_max(indice, self.axis2), - self.reduce_max(indice, self.axis5), - self.reduce_max(indice, self.axis6), - self.reduce_min(indice, self.axis0), - self.reduce_min(indice, self.axis1), - self.reduce_min(indice, self.axis2), - self.reduce_min(indice, self.axis3), - self.reduce_min(indice, self.axis4), - self.reduce_min(indice, self.axis5), - self.reduce_min(indice, self.axis6)) - - -class NetReduceLogic(nn.Cell): - def __init__(self): - super(NetReduceLogic, self).__init__() - self.axis0 = 0 - self.axis1 = -1 - self.axis2 = (0, 1, 2) - self.axis3 = () - self.reduce_all = P.ReduceAll(False) - self.reduce_any = P.ReduceAny(False) - - @jit - def construct(self, indice): - return (self.reduce_all(indice, self.axis0), - self.reduce_all(indice, self.axis1), - self.reduce_all(indice, self.axis2), - self.reduce_all(indice, self.axis3), - self.reduce_any(indice, self.axis0), - self.reduce_any(indice, self.axis1), - self.reduce_any(indice, self.axis2), - self.reduce_any(indice, self.axis3),) - - -class NetReduceProd(nn.Cell): - def __init__(self): - super(NetReduceProd, self).__init__() - self.axis0 = 0 - self.axis1 = 1 - self.axis2 = -1 - self.axis3 = (0, 1) - self.axis4 = () - self.reduce_prod = P.ReduceProd(False) - self.reduce_prod_keep = P.ReduceProd(True) - - @jit - def construct(self, indices): - return (self.reduce_prod(indices, self.axis0), - self.reduce_prod(indices, self.axis1), - self.reduce_prod(indices, self.axis2), - self.reduce_prod(indices, self.axis3), - self.reduce_prod_keep(indices, self.axis4)) - - -class NetReduceAny(nn.Cell): - def __init__(self, axis=()): - super(NetReduceAny, self).__init__() - self.op = P.ReduceAny(keep_dims=False) - self.axis = axis - - @jit - def construct(self, x): - return self.op(x, self.axis) - - -class NetReduceAll(nn.Cell): - def __init__(self, axis=()): - super(NetReduceAll, self).__init__() - self.op = P.ReduceAll(keep_dims=False) - self.axis = axis - - @jit - def construct(self, x): - return self.op(x, self.axis) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_reduce_any_onnx(): - """ - Feature: test ReduceAll op in cpu. - Description: test the ops export onnx. - Expectation: expect correct value result. - """ - context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - axis = 1 - net = NetReduceAny(axis) - data = np.array([[True, True, True], [False, False, False], [True, False, False]]) - out_ms = net(Tensor(data)).asnumpy() - file = 'reduceAny.onnx' - export(net, Tensor(data), file_name=file, file_format="ONNX") - assert os.path.exists(file) - - import onnxruntime - sess = onnxruntime.InferenceSession(file) - input_x = sess.get_inputs()[0].name - result = sess.run([], {input_x: data})[0] - assert np.all(out_ms == result) - - os.chmod(file, stat.S_IWRITE) - os.remove(file) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_reduce_all_onnx(): - """ - Feature: test ReduceAll op in cpu. - Description: test the ops export onnx. - Expectation: expect correct value result. - """ - context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - axis = 1 - net = NetReduceAll(axis) - data = np.array([[True, True, True], [False, False, False], [True, False, False]]) - out_ms = net(Tensor(data)).asnumpy() - file = 'reduceAll.onnx' - export(net, Tensor(data), file_name=file, file_format="ONNX") - assert os.path.exists(file) - - import onnxruntime - sess = onnxruntime.InferenceSession(file) - input_x = sess.get_inputs()[0].name - result = sess.run([], {input_x: data})[0] - assert np.all(out_ms == result) - - os.chmod(file, stat.S_IWRITE) - os.remove(file) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_reduce(): - """ - /// Feature: Reduce - /// Description: reduce tensor elements, include reduce_mean, reduce_max, etc. - /// Expectation: Euqal to numpy results - """ - reduce = NetReduce() - indice = Tensor(np.array([ - [[0., 2., 1., 4., 0., 2.], [3., 1., 2., 2., 4., 0.]], - [[2., 0., 1., 5., 0., 1.], [1., 0., 0., 4., 4., 3.]], - [[4., 1., 4., 0., 0., 0.], [2., 5., 1., 0., 1., 3.]] - ]).astype(np.float32)) - output = reduce(indice) - print(output[0]) - print(output[1]) - print(output[2]) - print(output[3]) - print(output[4]) - print(output[5]) - print(output[6]) - print(output[7]) - print(output[8]) - print(output[9]) - print(output[10]) - print(output[11]) - print(output[12]) - print(output[13]) - print(output[14]) - print(output[15]) - print(output[16]) - print(output[17]) - expect_0 = np.array([[2., 1., 2., 3., 0., 1], [2., 2., 1., 2., 3., 2.]]).astype(np.float32) - expect_1 = np.array([[1.5, 1.5, 1.5, 3., 2., 1.], [1.5, 0., 0.5, 4.5, 2., 2.], [3., 3., 2.5, 0., 0.5, 1.5]]).astype( - np.float32) - expect_2 = np.array([[1.5, 2.], [1.5, 2.], [1.5, 2.]]).astype(np.float32) - expect_3 = np.array([2, 1.5, 1.5, 2.5, 1.5, 1.5]).astype(np.float32) - expect_4 = np.array([1.75]).astype(np.float32) - expect_5 = np.array([[6., 3., 6., 9., 0., 3.], [6., 6., 3., 6., 9., 6.]]).astype(np.float32) - expect_6 = np.array([[9., 12.], [9., 12.], [9., 12.]]).astype(np.float32) - expect_7 = np.array([[4., 2., 4., 5., 0., 2.], [3., 5., 2., 4., 4., 3.]]).astype(np.float32) - expect_8 = np.array([[4., 4.], [5., 4.], [4., 5.]]).astype(np.float32) - expect_9 = np.array([[0., 0., 1., 0., 0., 0.], [1., 0., 0., 0., 1., 0.]]).astype(np.float32) - expect_10 = np.array([[0., 1., 1., 2., 0., 0.], [1., 0., 0., 4., 0., 1.], [2., 1., 1., 0., 0., 0.]]).astype( - np.float32) - expect_11 = np.array([[0., 0.], [0., 0.], [0., 0.]]).astype(np.float32) - expect_12 = np.array([0., 0., 0., 0., 0., 0.]).astype(np.float32) - assert (output[0].asnumpy() == expect_0).all() - assert (output[1].asnumpy() == expect_1).all() - assert (output[2].asnumpy() == expect_2).all() - assert (output[3].asnumpy() == expect_3).all() - assert (output[4].asnumpy() == expect_4).all() - assert (output[5].asnumpy() == expect_5).all() - assert (output[6].asnumpy() == expect_6).all() - assert (output[7].asnumpy() == expect_7).all() - assert (output[8].asnumpy() == expect_8).all() - assert (output[9].asnumpy() == expect_8).all() - assert (output[10].asnumpy() == 5.0).all() - assert (output[11].asnumpy() == expect_9).all() - assert (output[12].asnumpy() == expect_10).all() - assert (output[13].asnumpy() == expect_11).all() - assert (output[14].asnumpy() == expect_12).all() - assert (output[15].asnumpy() == 0.0).all() - assert (output[16].asnumpy() == expect_11).all() - assert (output[17].asnumpy() == 0.0).all() - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_reduce_logic(): - """ - /// Feature: Reduce logic - /// Description: Include reduce_all, reduce_any - /// Expectation: Euqal to numpy results - """ - reduce_logic = NetReduceLogic() - indice_bool = Tensor([[[False, True, True, True, False, True], - [True, True, True, True, True, False]], - [[True, False, True, True, False, True], - [True, False, False, True, True, True]], - [[True, True, True, False, False, False], - [True, True, True, False, True, True]]]) - output = reduce_logic(indice_bool) - expect_all_1 = np.array([[False, False, True, False, False, False], - [True, False, False, False, True, False]]) - expect_all_2 = np.array([[False, False], [False, False], [False, False]]) - expect_all_3 = False - expect_all_4 = False - expect_any_1 = np.array([[True, True, True, True, False, True], [True, True, True, True, True, True]]) - expect_any_2 = np.array([[True, True], [True, True], [True, True]]) - expect_any_3 = True - expect_any_4 = True - - assert (output[0].asnumpy() == expect_all_1).all() - assert (output[1].asnumpy() == expect_all_2).all() - assert (output[2].asnumpy() == expect_all_3).all() - assert (output[3].asnumpy() == expect_all_4).all() - assert (output[4].asnumpy() == expect_any_1).all() - assert (output[5].asnumpy() == expect_any_2).all() - assert (output[6].asnumpy() == expect_any_3).all() - assert (output[7].asnumpy() == expect_any_4).all() - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_reduce_prod(): - """ - /// Feature: Reduce prod - /// Description: Product of tensor elements - /// Expectation: Euqal to numpy results - """ - reduce_prod = NetReduceProd() - indices = Tensor(np.array([[[1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2], [3, 3, 3, 3, 3, 3]], - [[4, 4, 4, 4, 4, 4], [5, 5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]], - [[7, 7, 7, 7, 7, 7], [8, 8, 8, 8, 8, 8], [9, 9, 9, 9, 9, 9]]]).astype(np.float32)) - output = reduce_prod(indices) - expect_prod_0 = np.array([[28, 28, 28, 28, 28, 28], - [80, 80, 80, 80, 80, 80], - [162, 162, 162, 162, 162, 162]]).astype(np.float32) - expect_prod_1 = np.array([[6, 6, 6, 6, 6, 6], - [120, 120, 120, 120, 120, 120], - [504, 504, 504, 504, 504, 504]]).astype(np.float32) - expect_prod_2 = np.array([[1.00000e+00, 6.40000e+01, 7.29000e+02], - [4.09600e+03, 1.56250e+04, 4.66560e+04], - [1.17649e+05, 2.62144e+05, 5.31441e+05]]).astype(np.float32) - expect_prod_3 = np.array([362880, 362880, 362880, 362880, 362880, 362880]).astype(np.float32) - expect_prod_4 = np.array([[[2.2833798e+33]]]).astype(np.float32) - assert (output[0].asnumpy() == expect_prod_0).all() - assert (output[1].asnumpy() == expect_prod_1).all() - assert (output[2].asnumpy() == expect_prod_2).all() - assert (output[3].asnumpy() == expect_prod_3).all() - assert (output[4].asnumpy() == expect_prod_4).all() +# Copyright 2020-2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import os +import stat +import pytest +import numpy as np +from mindspore import Tensor +from mindspore.ops import operations as P +import mindspore.nn as nn +import mindspore.context as context +from mindspore.common.api import jit +from mindspore.train.serialization import export + +context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU") + + +class NetReduce(nn.Cell): + def __init__(self): + super(NetReduce, self).__init__() + self.axis0 = 0 + self.axis1 = 1 + self.axis2 = -1 + self.axis3 = (0, 1, -2) + self.axis4 = (0, 1, 2) + self.axis5 = (-1,) + self.axis6 = () + self.reduce_mean = P.ReduceMean(False) + self.reduce_sum = P.ReduceSum(False) + self.reduce_max = P.ReduceMax(False) + self.reduce_min = P.ReduceMin(False) + + @jit + def construct(self, indice): + return (self.reduce_mean(indice, self.axis0), + self.reduce_mean(indice, self.axis1), + self.reduce_mean(indice, self.axis2), + self.reduce_mean(indice, self.axis3), + self.reduce_mean(indice, self.axis4), + self.reduce_sum(indice, self.axis0), + self.reduce_sum(indice, self.axis2), + self.reduce_max(indice, self.axis0), + self.reduce_max(indice, self.axis2), + self.reduce_max(indice, self.axis5), + self.reduce_max(indice, self.axis6), + self.reduce_min(indice, self.axis0), + self.reduce_min(indice, self.axis1), + self.reduce_min(indice, self.axis2), + self.reduce_min(indice, self.axis3), + self.reduce_min(indice, self.axis4), + self.reduce_min(indice, self.axis5), + self.reduce_min(indice, self.axis6)) + + +class NetReduceLogic(nn.Cell): + def __init__(self): + super(NetReduceLogic, self).__init__() + self.axis0 = 0 + self.axis1 = -1 + self.axis2 = (0, 1, 2) + self.axis3 = () + self.reduce_all = P.ReduceAll(False) + self.reduce_any = P.ReduceAny(False) + + @jit + def construct(self, indice): + return (self.reduce_all(indice, self.axis0), + self.reduce_all(indice, self.axis1), + self.reduce_all(indice, self.axis2), + self.reduce_all(indice, self.axis3), + self.reduce_any(indice, self.axis0), + self.reduce_any(indice, self.axis1), + self.reduce_any(indice, self.axis2), + self.reduce_any(indice, self.axis3),) + + +class NetReduceProd(nn.Cell): + def __init__(self): + super(NetReduceProd, self).__init__() + self.axis0 = 0 + self.axis1 = 1 + self.axis2 = -1 + self.axis3 = (0, 1) + self.axis4 = () + self.reduce_prod = P.ReduceProd(False) + self.reduce_prod_keep = P.ReduceProd(True) + + @jit + def construct(self, indices): + return (self.reduce_prod(indices, self.axis0), + self.reduce_prod(indices, self.axis1), + self.reduce_prod(indices, self.axis2), + self.reduce_prod(indices, self.axis3), + self.reduce_prod_keep(indices, self.axis4)) + + +class NetReduceAny(nn.Cell): + def __init__(self, axis=()): + super(NetReduceAny, self).__init__() + self.op = P.ReduceAny(keep_dims=False) + self.axis = axis + + @jit + def construct(self, x): + return self.op(x, self.axis) + + +class NetReduceAll(nn.Cell): + def __init__(self, axis=()): + super(NetReduceAll, self).__init__() + self.op = P.ReduceAll(keep_dims=False) + self.axis = axis + + @jit + def construct(self, x): + return self.op(x, self.axis) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_reduce_any_onnx(): + """ + Feature: test ReduceAll op in cpu. + Description: test the ops export onnx. + Expectation: expect correct value result. + """ + context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + axis = 1 + net = NetReduceAny(axis) + data = np.array([[True, True, True], [False, False, False], [True, False, False]]) + out_ms = net(Tensor(data)).asnumpy() + file = 'reduceAny.onnx' + export(net, Tensor(data), file_name=file, file_format="ONNX") + assert os.path.exists(file) + + import onnxruntime + sess = onnxruntime.InferenceSession(file) + input_x = sess.get_inputs()[0].name + result = sess.run([], {input_x: data})[0] + assert np.all(out_ms == result) + + os.chmod(file, stat.S_IWRITE) + os.remove(file) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_reduce_all_onnx(): + """ + Feature: test ReduceAll op in cpu. + Description: test the ops export onnx. + Expectation: expect correct value result. + """ + context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + axis = 1 + net = NetReduceAll(axis) + data = np.array([[True, True, True], [False, False, False], [True, False, False]]) + out_ms = net(Tensor(data)).asnumpy() + file = 'reduceAll.onnx' + export(net, Tensor(data), file_name=file, file_format="ONNX") + assert os.path.exists(file) + + import onnxruntime + sess = onnxruntime.InferenceSession(file) + input_x = sess.get_inputs()[0].name + result = sess.run([], {input_x: data})[0] + assert np.all(out_ms == result) + + os.chmod(file, stat.S_IWRITE) + os.remove(file) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_reduce(): + """ + /// Feature: Reduce + /// Description: reduce tensor elements, include reduce_mean, reduce_max, etc. + /// Expectation: Euqal to numpy results + """ + reduce = NetReduce() + indice = Tensor(np.array([ + [[0., 2., 1., 4., 0., 2.], [3., 1., 2., 2., 4., 0.]], + [[2., 0., 1., 5., 0., 1.], [1., 0., 0., 4., 4., 3.]], + [[4., 1., 4., 0., 0., 0.], [2., 5., 1., 0., 1., 3.]] + ]).astype(np.float32)) + output = reduce(indice) + print(output[0]) + print(output[1]) + print(output[2]) + print(output[3]) + print(output[4]) + print(output[5]) + print(output[6]) + print(output[7]) + print(output[8]) + print(output[9]) + print(output[10]) + print(output[11]) + print(output[12]) + print(output[13]) + print(output[14]) + print(output[15]) + print(output[16]) + print(output[17]) + expect_0 = np.array([[2., 1., 2., 3., 0., 1], [2., 2., 1., 2., 3., 2.]]).astype(np.float32) + expect_1 = np.array([[1.5, 1.5, 1.5, 3., 2., 1.], [1.5, 0., 0.5, 4.5, 2., 2.], [3., 3., 2.5, 0., 0.5, 1.5]]).astype( + np.float32) + expect_2 = np.array([[1.5, 2.], [1.5, 2.], [1.5, 2.]]).astype(np.float32) + expect_3 = np.array([2, 1.5, 1.5, 2.5, 1.5, 1.5]).astype(np.float32) + expect_4 = np.array([1.75]).astype(np.float32) + expect_5 = np.array([[6., 3., 6., 9., 0., 3.], [6., 6., 3., 6., 9., 6.]]).astype(np.float32) + expect_6 = np.array([[9., 12.], [9., 12.], [9., 12.]]).astype(np.float32) + expect_7 = np.array([[4., 2., 4., 5., 0., 2.], [3., 5., 2., 4., 4., 3.]]).astype(np.float32) + expect_8 = np.array([[4., 4.], [5., 4.], [4., 5.]]).astype(np.float32) + expect_9 = np.array([[0., 0., 1., 0., 0., 0.], [1., 0., 0., 0., 1., 0.]]).astype(np.float32) + expect_10 = np.array([[0., 1., 1., 2., 0., 0.], [1., 0., 0., 4., 0., 1.], [2., 1., 1., 0., 0., 0.]]).astype( + np.float32) + expect_11 = np.array([[0., 0.], [0., 0.], [0., 0.]]).astype(np.float32) + expect_12 = np.array([0., 0., 0., 0., 0., 0.]).astype(np.float32) + assert (output[0].asnumpy() == expect_0).all() + assert (output[1].asnumpy() == expect_1).all() + assert (output[2].asnumpy() == expect_2).all() + assert (output[3].asnumpy() == expect_3).all() + assert (output[4].asnumpy() == expect_4).all() + assert (output[5].asnumpy() == expect_5).all() + assert (output[6].asnumpy() == expect_6).all() + assert (output[7].asnumpy() == expect_7).all() + assert (output[8].asnumpy() == expect_8).all() + assert (output[9].asnumpy() == expect_8).all() + assert (output[10].asnumpy() == 5.0).all() + assert (output[11].asnumpy() == expect_9).all() + assert (output[12].asnumpy() == expect_10).all() + assert (output[13].asnumpy() == expect_11).all() + assert (output[14].asnumpy() == expect_12).all() + assert (output[15].asnumpy() == 0.0).all() + assert (output[16].asnumpy() == expect_11).all() + assert (output[17].asnumpy() == 0.0).all() + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_reduce_logic(): + """ + /// Feature: Reduce logic + /// Description: Include reduce_all, reduce_any + /// Expectation: Euqal to numpy results + """ + reduce_logic = NetReduceLogic() + indice_bool = Tensor([[[False, True, True, True, False, True], + [True, True, True, True, True, False]], + [[True, False, True, True, False, True], + [True, False, False, True, True, True]], + [[True, True, True, False, False, False], + [True, True, True, False, True, True]]]) + output = reduce_logic(indice_bool) + expect_all_1 = np.array([[False, False, True, False, False, False], + [True, False, False, False, True, False]]) + expect_all_2 = np.array([[False, False], [False, False], [False, False]]) + expect_all_3 = False + expect_all_4 = False + expect_any_1 = np.array([[True, True, True, True, False, True], [True, True, True, True, True, True]]) + expect_any_2 = np.array([[True, True], [True, True], [True, True]]) + expect_any_3 = True + expect_any_4 = True + + assert (output[0].asnumpy() == expect_all_1).all() + assert (output[1].asnumpy() == expect_all_2).all() + assert (output[2].asnumpy() == expect_all_3).all() + assert (output[3].asnumpy() == expect_all_4).all() + assert (output[4].asnumpy() == expect_any_1).all() + assert (output[5].asnumpy() == expect_any_2).all() + assert (output[6].asnumpy() == expect_any_3).all() + assert (output[7].asnumpy() == expect_any_4).all() + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_reduce_prod(): + """ + /// Feature: Reduce prod + /// Description: Product of tensor elements + /// Expectation: Euqal to numpy results + """ + reduce_prod = NetReduceProd() + indices = Tensor(np.array([[[1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2], [3, 3, 3, 3, 3, 3]], + [[4, 4, 4, 4, 4, 4], [5, 5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]], + [[7, 7, 7, 7, 7, 7], [8, 8, 8, 8, 8, 8], [9, 9, 9, 9, 9, 9]]]).astype(np.float32)) + output = reduce_prod(indices) + expect_prod_0 = np.array([[28, 28, 28, 28, 28, 28], + [80, 80, 80, 80, 80, 80], + [162, 162, 162, 162, 162, 162]]).astype(np.float32) + expect_prod_1 = np.array([[6, 6, 6, 6, 6, 6], + [120, 120, 120, 120, 120, 120], + [504, 504, 504, 504, 504, 504]]).astype(np.float32) + expect_prod_2 = np.array([[1.00000e+00, 6.40000e+01, 7.29000e+02], + [4.09600e+03, 1.56250e+04, 4.66560e+04], + [1.17649e+05, 2.62144e+05, 5.31441e+05]]).astype(np.float32) + expect_prod_3 = np.array([362880, 362880, 362880, 362880, 362880, 362880]).astype(np.float32) + expect_prod_4 = np.array([[[2.2833798e+33]]]).astype(np.float32) + assert (output[0].asnumpy() == expect_prod_0).all() + assert (output[1].asnumpy() == expect_prod_1).all() + assert (output[2].asnumpy() == expect_prod_2).all() + assert (output[3].asnumpy() == expect_prod_3).all() + assert (output[4].asnumpy() == expect_prod_4).all() diff --git a/tests/st/ops/cpu/test_resize_nearest_neighbor_op.py b/tests/st/ops/cpu/test_resize_nearest_neighbor_op.py old mode 100755 new mode 100644 diff --git a/tests/st/ops/cpu/test_scatter_arithmetic_op.py b/tests/st/ops/cpu/test_scatter_arithmetic_op.py index 9c2487aca54..f710d4a2d06 100644 --- a/tests/st/ops/cpu/test_scatter_arithmetic_op.py +++ b/tests/st/ops/cpu/test_scatter_arithmetic_op.py @@ -1,745 +1,745 @@ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest -import mindspore -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor, Parameter -from mindspore.ops import operations as P -from mindspore.ops import functional as F - -context.set_context(mode=context.GRAPH_MODE, device_target="CPU") - - -class TestScatterAddNet(nn.Cell): - def __init__(self, lock, inputx, indices, updates): - super(TestScatterAddNet, self).__init__() - self.scatter_add = P.ScatterAdd(use_locking=lock) - self.inputx = Parameter(inputx, name="inputx") - self.indices = Parameter(indices, name="indices") - self.updates = Parameter(updates, name="updates") - - def construct(self): - out = self.scatter_add(self.inputx, self.indices, self.updates) - return out - - -def scatter_add_net(inputx, indices, updates): - lock = True - net = TestScatterAddNet(lock, inputx, indices, updates) - return net() - - -def scatter_add_use_locking_false_net(inputx, indices, updates): - lock = False - net = TestScatterAddNet(lock, inputx, indices, updates) - return net() - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_scatter_add_small_float32(): - inputx = Tensor(np.zeros((2, 3)).astype(np.float32)) - indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) - updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) - output = scatter_add_net(inputx, indices, updates) - expected = np.array([[6., 8., 10.], - [12., 14., 16.]]) - np.testing.assert_array_almost_equal(output.asnumpy(), expected) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_scatter_add_input_updated(): - inputx = Tensor(np.zeros((2, 3)).astype(np.float32)) - indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) - updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) - lock = True - net = TestScatterAddNet(lock, inputx, indices, updates) - net() - expected = np.array([[6., 8., 10.], - [12., 14., 16.]]) - np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_scatter_add_large_shape_float32(): - inputx = Tensor(np.ones((4, 2, 3, 4)).astype(np.float32)) - indices = Tensor(np.array([[0, 2], [3, 1]]).astype(np.int32)) - updates = Tensor(np.arange(96).reshape((2, 2, 2, 3, 4)).astype(np.float32)) - output = scatter_add_net(inputx, indices, updates) - expected = np.array([[[[1., 2., 3., 4.], - [5., 6., 7., 8.], - [9., 10., 11., 12.]], - [[13., 14., 15., 16.], - [17., 18., 19., 20.], - [21., 22., 23., 24.]]], - [[[73., 74., 75., 76.], - [77., 78., 79., 80.], - [81., 82., 83., 84.]], - [[85., 86., 87., 88.], - [89., 90., 91., 92.], - [93., 94., 95., 96.]]], - [[[25., 26., 27., 28.], - [29., 30., 31., 32.], - [33., 34., 35., 36.]], - [[37., 38., 39., 40.], - [41., 42., 43., 44.], - [45., 46., 47., 48.]]], - [[[49., 50., 51., 52.], - [53., 54., 55., 56.], - [57., 58., 59., 60.]], - [[61., 62., 63., 64.], - [65., 66., 67., 68.], - [69., 70., 71., 72.]]]]) - np.testing.assert_array_almost_equal(output.asnumpy(), expected) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_scatter_add_small_float32_use_locking_false(): - inputx = Tensor(np.zeros((2, 3)).astype(np.float32)) - indices = Tensor(np.array([1, 0]).astype(np.int32)) - updates = Tensor(np.arange(6).reshape((2, 3)).astype(np.float32)) - output = scatter_add_use_locking_false_net(inputx, indices, updates) - expected = np.array([[3., 4., 5.], - [0., 1., 2.]]) - np.testing.assert_array_almost_equal(output.asnumpy(), expected) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_scatter_add_input_less_than_1_float32(): - inputx = Tensor(np.array([[0.214141, 0.415151, 0.51516], - [0.876542, 0.451611, 0.55112], - [0.111244, 0.633333, 0.34444]]).astype(np.float32)) - indices = Tensor(np.array([[[1, 0, 2], - [2, 2, 0]], - [[1, 0, 1], - [2, 1, 2]]]).astype(np.int32)) - updates = Tensor(np.arange(34, 70).reshape((2, 2, 3, 3)).astype(np.float32)) - output = scatter_add_net(inputx, indices, updates) - expected = np.array([[141.21414, 144.41515, 147.51517], - [208.87654, 212.45161, 216.55112], - [257.11124, 262.63333, 267.34442]], dtype=np.float32) - np.testing.assert_array_almost_equal(output.asnumpy(), expected) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_scatter_add_float16(): - inputx = Tensor(np.zeros((2, 3)).astype(np.float16)) - indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) - updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float16)) - output = scatter_add_net(inputx, indices, updates) - expected = np.array([[6., 8., 10.], - [12., 14., 16.]]) - np.testing.assert_array_almost_equal(output.asnumpy(), expected) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_scatter_add_large_float16(): - inputx = Tensor(np.zeros((2, 3, 4)).astype(np.float16)) - indices = Tensor(np.array([[0, 0], [1, 1]]).astype(np.int32)) - updates = Tensor(np.arange(63, 111).reshape((2, 2, 3, 4)).astype(np.float16)) - output = scatter_add_net(inputx, indices, updates) - expected = np.array([[[138., 140., 142., 144.], - [146., 148., 150., 152.], - [154., 156., 158., 160.]], - [[186., 188., 190., 192.], - [194., 196., 198., 200.], - [202., 204., 206., 208.]]]) - np.testing.assert_array_almost_equal(output.asnumpy(), expected) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_scatter_add_disordered_float16(): - inputx = Tensor(np.flip(np.arange(34, 46).reshape(3, 4).astype(np.float16))) - indices = Tensor(np.array([[[0, 1, 2], - [2, 1, 0]], - [[0, 0, 0], - [2, 2, 2]]]).astype(np.int32)) - updates = Tensor(np.arange(63, 111).reshape((2, 2, 3, 4)).astype(np.float16)) - output = scatter_add_net(inputx, indices, updates) - expected = np.array([[464., 468., 472., 476.], - [187., 188., 189., 190.], - [492., 496., 500., 504.]]) - np.testing.assert_array_almost_equal(output.asnumpy(), expected) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_scatter_add_large_int32(): - inputx = Tensor(np.zeros((2, 3, 4)).astype(np.int32)) - indices = Tensor(np.array([[0, 0], [1, 1]]).astype(np.int32)) - updates = Tensor(np.arange(63, 111).reshape((2, 2, 3, 4)).astype(np.int32)) - output = scatter_add_net(inputx, indices, updates) - expected = np.array([[[138., 140., 142., 144.], - [146., 148., 150., 152.], - [154., 156., 158., 160.]], - [[186., 188., 190., 192.], - [194., 196., 198., 200.], - [202., 204., 206., 208.]]]).astype(np.int32) - np.testing.assert_array_almost_equal(output.asnumpy(), expected) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_scatter_add_disordered_int32(): - inputx = Tensor(np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int32))) - indices = Tensor(np.array([[[0, 1, 2], - [2, 1, 0]], - [[0, 0, 0], - [2, 2, 2]]]).astype(np.int32)) - updates = Tensor(np.arange(63, 111).reshape((2, 2, 3, 4)).astype(np.int32)) - output = scatter_add_net(inputx, indices, updates) - expected = np.array([[464., 468., 472., 476.], - [187., 188., 189., 190.], - [492., 496., 500., 504.]]).astype(np.int32) - np.testing.assert_array_almost_equal(output.asnumpy(), expected) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_scatter_add_function(): - """ - Feature: test_scatter_add_function. - Description: test cases for scatter add functinal - Expectation: the result match numpy implementation. - """ - input_x = Tensor(np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int32))) - indices = Tensor(np.array([[[0, 1, 2], - [2, 1, 0]], - [[0, 0, 0], - [2, 2, 2]]]).astype(np.int32)) - updates = Tensor(np.arange(63, 111).reshape((2, 2, 3, 4)).astype(np.int32)) - output = F.scatter_add(input_x, indices, updates) - expected = np.array([[464., 468., 472., 476.], - [187., 188., 189., 190.], - [492., 496., 500., 504.]]).astype(np.int32) - np.testing.assert_allclose(output.asnumpy(), expected, rtol=1e-6) - - -class TestScatterSubNet(nn.Cell): - def __init__(self, lock, inputx, indices, updates): - super(TestScatterSubNet, self).__init__() - self.scatter_sub = P.ScatterSub(use_locking=lock) - self.inputx = Parameter(inputx, name="inputx") - self.indices = Parameter(indices, name="indices") - self.updates = Parameter(updates, name="updates") - - def construct(self): - out = self.scatter_sub(self.inputx, self.indices, self.updates) - return out - - -def scatter_sub_net(inputx, indices, updates): - lock = True - net = TestScatterSubNet(lock, inputx, indices, updates) - return net() - - -def scatter_sub_use_locking_false_net(inputx, indices, updates): - lock = False - net = TestScatterSubNet(lock, inputx, indices, updates) - return net() - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_scatter_sub_input_updated(): - inputx = Tensor(np.zeros((2, 3)).astype(np.float32)) - indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) - updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) - lock = True - net = TestScatterSubNet(lock, inputx, indices, updates) - net() - expected = np.array([[-6., -8., -10.], - [-12., -14., -16.]]) - np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_scatter_sub_large_shape_float32(): - inputx = Tensor(np.ones((4, 2, 3, 4)).astype(np.float32)) - indices = Tensor(np.array([[0, 2], [3, 1]]).astype(np.int32)) - updates = Tensor(np.arange(96).reshape((2, 2, 2, 3, 4)).astype(np.float32)) - output = scatter_sub_net(inputx, indices, updates) - expected = np.array( - [[[[1.0, 0.0, -1.0, -2.0], - [-3.0, -4.0, -5.0, -6.0], - [-7.0, -8.0, -9.0, -10.0]], - [[-11.0, -12.0, -13.0, -14.0], - [-15.0, -16.0, -17.0, -18.0], - [-19.0, -20.0, -21.0, -22.0]]], - [[[-71.0, -72.0, -73.0, -74.0], - [-75.0, -76.0, -77.0, -78.0], - [-79.0, -80.0, -81.0, -82.0]], - [[-83.0, -84.0, -85.0, -86.0], - [-87.0, -88.0, -89.0, -90.0], - [-91.0, -92.0, -93.0, -94.0]]], - [[[-23.0, -24.0, -25.0, -26.0], - [-27.0, -28.0, -29.0, -30.0], - [-31.0, -32.0, -33.0, -34.0]], - [[-35.0, -36.0, -37.0, -38.0], - [-39.0, -40.0, -41.0, -42.0], - [-43.0, -44.0, -45.0, -46.0]]], - [[[-47.0, -48.0, -49.0, -50.0], - [-51.0, -52.0, -53.0, -54.0], - [-55.0, -56.0, -57.0, -58.0]], - [[-59.0, -60.0, -61.0, -62.0], - [-63.0, -64.0, -65.0, -66.0], - [-67.0, -68.0, -69.0, -70.0]]]]) - np.testing.assert_array_almost_equal(output.asnumpy(), expected) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_scatter_sub_small_float32_use_locking_false(): - inputx = Tensor(np.zeros((2, 3)).astype(np.float32)) - indices = Tensor(np.array([1, 0]).astype(np.int32)) - updates = Tensor(np.arange(6).reshape((2, 3)).astype(np.float32)) - output = scatter_sub_use_locking_false_net(inputx, indices, updates) - expected = np.array([[-3., -4., -5.], - [-0., -1., -2.]]) - np.testing.assert_array_almost_equal(output.asnumpy(), expected) - - -class TestScatterMulNet(nn.Cell): - def __init__(self, lock, inputx, indices, updates): - super(TestScatterMulNet, self).__init__() - self.scatter_mul = P.ScatterMul(use_locking=lock) - self.inputx = Parameter(inputx, name="inputx") - self.indices = Parameter(indices, name="indices") - self.updates = Parameter(updates, name="updates") - - def construct(self): - out = self.scatter_mul(self.inputx, self.indices, self.updates) - return out - - -def scatter_mul_net(inputx, indices, updates): - lock = True - net = TestScatterMulNet(lock, inputx, indices, updates) - return net() - - -def scatter_mul_use_locking_false_net(inputx, indices, updates): - lock = False - net = TestScatterMulNet(lock, inputx, indices, updates) - return net() - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_scatter_mul_input_updated(): - inputx = Tensor(np.ones((2, 3)).astype(np.float32)) - indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) - updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) - lock = True - net = TestScatterMulNet(lock, inputx, indices, updates) - net() - expected = np.array([[0., 7., 16.], - [27., 40., 55.]]) - np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_scatter_mul_output_updated_float32(): - inputx = Tensor(np.ones((2, 3)).astype(np.float32)) - indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) - updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) - output = scatter_mul_net(inputx, indices, updates) - expected = np.array([[0., 7., 16.], - [27., 40., 55.]]) - np.testing.assert_array_almost_equal(output.asnumpy(), expected) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_scatter_mul_small_float32_use_locking_false(): - inputx = Tensor(np.ones((2, 3)).astype(np.float32)) - indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) - updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) - output = scatter_mul_use_locking_false_net(inputx, indices, updates) - expected = np.array([[0., 7., 16.], - [27., 40., 55.]]) - np.testing.assert_array_almost_equal(output.asnumpy(), expected) - - -class TestScatterDivNet(nn.Cell): - def __init__(self, lock, inputx, indices, updates): - super(TestScatterDivNet, self).__init__() - self.scatter_div = P.ScatterDiv(use_locking=lock) - self.inputx = Parameter(inputx, name="inputx") - self.indices = Parameter(indices, name="indices") - self.updates = Parameter(updates, name="updates") - - def construct(self): - out = self.scatter_div(self.inputx, self.indices, self.updates) - return out - - -def scatter_div_net(inputx, indices, updates): - lock = True - net = TestScatterDivNet(lock, inputx, indices, updates) - return net() - - -def scatter_div_use_locking_false_net(inputx, indices, updates): - lock = False - net = TestScatterDivNet(lock, inputx, indices, updates) - return net() - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_scatter_div_input_updated(): - inputx = Tensor(np.zeros((2, 3)).astype(np.float32)) - indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) - updates = Tensor(np.arange(1, 13).reshape((2, 2, 3)).astype(np.float32)) - lock = True - net = TestScatterDivNet(lock, inputx, indices, updates) - net() - expected = np.array([[0., 0., 0.], - [0., 0., 0.]]) - np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_scatter_div_output_updated_float32(): - inputx = Tensor(np.zeros((2, 3)).astype(np.float32)) - indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) - updates = Tensor(np.arange(1, 13).reshape((2, 2, 3)).astype(np.float32)) - output = scatter_div_net(inputx, indices, updates) - expected = np.array([[0., 0., 0.], - [0., 0., 0.]]) - np.testing.assert_array_almost_equal(output.asnumpy(), expected) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_scatter_div_small_float32_use_locking_false(): - inputx = Tensor(np.ones((2, 3)).astype(np.float32) * 10) - indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) - updates = Tensor(np.ones(12).reshape((2, 2, 3)).astype(np.float32)) - output = scatter_div_use_locking_false_net(inputx, indices, updates) - expected = np.array([[10., 10., 10.], - [10., 10., 10.]]) - np.testing.assert_array_almost_equal(output.asnumpy(), expected) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_scatter_div_output_int16(): - """ - Feature: test ScatterDiv output and input_x same value. - Description: input is int16. - Expectation: output and input_x have same value - """ - input_x = Parameter(Tensor(np.array([[6, 6, 6], [2, 2, 2]]), mindspore.int16), name="x") - indices = Tensor(np.array([0, 1]), mindspore.int32) - updates = Tensor(np.array([[2, 2, 2], [2, 2, 2]]), mindspore.int16) - output = P.ScatterDiv()(input_x, indices, updates) - assert np.allclose(output.asnumpy(), input_x.asnumpy(), 0.0001, 0.0001) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_scatter_div_output_float64(): - """ - Feature: test ScatterDiv output and input_x same value. - Description: input is float64. - Expectation: output and input_x have same value - """ - input_x = Parameter(Tensor(np.array([[6.0, 6.0, 6.0], [2.0, 2.0, 2.0]]), mindspore.float64), name="x") - indices = Tensor(np.array([0, 1]), mindspore.int32) - updates = Tensor(np.array([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]), mindspore.float64) - output = P.ScatterDiv()(input_x, indices, updates) - assert np.allclose(output.asnumpy(), input_x.asnumpy(), 0.0001, 0.0001) - - -class TestScatterMaxNet(nn.Cell): - def __init__(self, lock, inputx, indices, updates): - super(TestScatterMaxNet, self).__init__() - self.scatter_max = P.ScatterMax(use_locking=lock) - self.inputx = Parameter(inputx, name="inputx") - self.indices = Parameter(indices, name="indices") - self.updates = Parameter(updates, name="updates") - - def construct(self): - out = self.scatter_max(self.inputx, self.indices, self.updates) - return out - - -def scatter_max_net(inputx, indices, updates): - lock = True - net = TestScatterMaxNet(lock, inputx, indices, updates) - return net() - - -def scatter_max_use_locking_false_net(inputx, indices, updates): - lock = False - net = TestScatterMaxNet(lock, inputx, indices, updates) - return net() - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_scatter_max_input_updated(): - inputx = Tensor(np.zeros((2, 3)).astype(np.float32)) - indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) - updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) - lock = True - net = TestScatterMaxNet(lock, inputx, indices, updates) - net() - expected = np.array([[6., 7., 8.], - [9., 10., 11.]]) - np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_scatter_max_output_updated_float32(): - inputx = Tensor(np.zeros((2, 3)).astype(np.float32)) - indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) - updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) - output = scatter_max_net(inputx, indices, updates) - expected = np.array([[6., 7., 8.], - [9., 10., 11.]]) - np.testing.assert_array_almost_equal(output.asnumpy(), expected) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_scatter_max_small_float32_use_locking_false(): - inputx = Tensor(np.ones((2, 3)).astype(np.float32) * 10) - indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) - updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) - output = scatter_max_use_locking_false_net(inputx, indices, updates) - expected = np.array([[10., 10., 10.], - [10., 10., 11.]]) - np.testing.assert_array_almost_equal(output.asnumpy(), expected) - - -class TestScatterMinNet(nn.Cell): - def __init__(self, lock, inputx, indices, updates): - super(TestScatterMinNet, self).__init__() - self.scatter_min = P.ScatterMin(use_locking=lock) - self.inputx = Parameter(inputx, name="inputx") - self.indices = Parameter(indices, name="indices") - self.updates = Parameter(updates, name="updates") - - def construct(self): - out = self.scatter_min(self.inputx, self.indices, self.updates) - return out - - -def scatter_min_net(inputx, indices, updates): - lock = True - net = TestScatterMinNet(lock, inputx, indices, updates) - return net() - - -def scatter_min_use_locking_false_net(inputx, indices, updates): - lock = False - net = TestScatterMinNet(lock, inputx, indices, updates) - return net() - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_scatter_min_input_updated(): - inputx = Tensor(np.zeros((2, 3)).astype(np.float32)) - indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) - updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) - lock = True - net = TestScatterMinNet(lock, inputx, indices, updates) - net() - expected = np.array([[0., 0., 0.], - [0., 0., 0.]]) - np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_scatter_min_output_updated_float32(): - inputx = Tensor(np.ones((2, 3)).astype(np.float32)) - indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) - updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) - output = scatter_min_net(inputx, indices, updates) - expected = np.array([[0., 1., 1.], - [1., 1., 1.]]) - np.testing.assert_array_almost_equal(output.asnumpy(), expected) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_scatter_min_small_float32_use_locking_false(): - inputx = Tensor(np.ones((2, 3)).astype(np.float32)) - indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) - updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) - output = scatter_min_use_locking_false_net(inputx, indices, updates) - expected = np.array([[0., 1., 1.], - [1., 1., 1.]]) - np.testing.assert_array_almost_equal(output.asnumpy(), expected) - - -class TestScatterUpdateNet(nn.Cell): - def __init__(self, lock, inputx, indices, updates): - super(TestScatterUpdateNet, self).__init__() - self.scatter_update = P.ScatterUpdate(use_locking=lock) - self.inputx = Parameter(inputx, name="inputx") - self.indices = Parameter(indices, name="indices") - self.updates = Parameter(updates, name="updates") - - def construct(self): - out = self.scatter_update(self.inputx, self.indices, self.updates) - return out - - -def scatter_update_net(inputx, indices, updates): - lock = True - net = TestScatterUpdateNet(lock, inputx, indices, updates) - return net() - - -def scatter_update_use_locking_false_net(inputx, indices, updates): - lock = False - net = TestScatterUpdateNet(lock, inputx, indices, updates) - return net() - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_scatter_update_input_updated(): - inputx = Tensor(np.zeros((2, 3)).astype(np.float32)) - indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) - updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) - lock = True - net = TestScatterUpdateNet(lock, inputx, indices, updates) - net() - expected = np.array([[6., 7., 8.], - [9., 10., 11.]]) - np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_scatter_update_output_updated_float32(): - inputx = Tensor(np.ones((2, 3)).astype(np.float32)) - indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) - updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) - output = scatter_update_net(inputx, indices, updates) - expected = np.array([[6., 7., 8.], - [9., 10., 11.]]) - np.testing.assert_array_almost_equal(output.asnumpy(), expected) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_scatter_update_output_updated_huge_tensor_float32(): - """ - Feature: Test huge input tensor case of cpu kernel ScatterUpdate. - Description: The first input tensor for cpu kernel ScatterUpdate is huge, and - the memory size of this tensor should be greater than 2147483647. - In this case, memory size of inputx tensor is 2147483652 (178956971 * 3 * sizeof(float32)) - Expectation: success. - """ - inputx = Tensor(np.ones((178956971, 3)).astype(np.float32)) - indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) - updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) - output = scatter_update_net(inputx, indices, updates) - expected = np.array([[6., 7., 8.], - [9., 10., 11.]]) - np.testing.assert_array_almost_equal(output.asnumpy()[0:2], expected) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_scatter_update_small_float32_use_locking_false(): - inputx = Tensor(np.ones((2, 3)).astype(np.float32)) - indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) - updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) - output = scatter_update_use_locking_false_net(inputx, indices, updates) - expected = np.array([[6., 7., 8.], - [9., 10., 11.]]) - np.testing.assert_array_almost_equal(output.asnumpy(), expected) - - -class TestScatterAddNetDynamic(nn.Cell): - def __init__(self, lock): - super(TestScatterAddNetDynamic, self).__init__() - self.scatter_add = P.ScatterAdd(use_locking=lock) - - def construct(self, inputx, indices, updates): - out = self.scatter_add(inputx, indices, updates) - return out - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_scatter_add_dynamic_shape(): - """ - Feature: op dynamic shape - Description: set input_shape None and input real tensor - Expectation: success - """ - inputx = Parameter(Tensor(np.zeros((2, 3)).astype(np.float32))) - indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) - updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) - net = TestScatterAddNetDynamic(False) - indices_dyn = Tensor(shape=[None, None], dtype=indices.dtype) - updates_dyn = Tensor(shape=[None, None, None], dtype=updates.dtype) - net.set_inputs(inputx, indices_dyn, updates_dyn) - output = net(inputx, indices, updates) - expected_shape = (2, 3) - assert expected_shape == output.asnumpy().shape - - -class TestScatterSubNetDynamic(nn.Cell): - def __init__(self, lock): - super(TestScatterSubNetDynamic, self).__init__() - self.scatter_sub = P.ScatterSub(use_locking=lock) - - def construct(self, inputx, indices, updates): - out = self.scatter_sub(inputx, indices, updates) - return out - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_scatter_sub_dynamic_shape(): - """ - Feature: op dynamic shape - Description: set input_shape None and input real tensor - Expectation: success - """ - - inputx = Parameter(Tensor(np.zeros((2, 3)).astype(np.float32))) - indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) - updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) - net = TestScatterSubNetDynamic(False) - indices_dyn = Tensor(shape=[None, None], dtype=indices.dtype) - updates_dyn = Tensor(shape=[None, None, None], dtype=updates.dtype) - net.set_inputs(inputx, indices_dyn, updates_dyn) - output = net(inputx, indices, updates) - expected_shape = (2, 3) - assert expected_shape == output.asnumpy().shape - - -class TestScatterUpdateNetDynamic(nn.Cell): - def __init__(self, lock): - super(TestScatterUpdateNetDynamic, self).__init__() - self.scatter_update = P.ScatterUpdate(use_locking=lock) - - def construct(self, inputx, indices, updates): - out = self.scatter_update(inputx, indices, updates) - return out - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_scatter_update_dynamic_shape(): - """ - Feature: op dynamic shape - Description: set input_shape None and input real tensor - Expectation: success - """ - - inputx = Parameter(Tensor(np.zeros((2, 3)).astype(np.float32))) - indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) - updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) - net = TestScatterUpdateNetDynamic(False) - indices_dyn = Tensor(shape=[None, None], dtype=indices.dtype) - updates_dyn = Tensor(shape=[None, None, None], dtype=updates.dtype) - net.set_inputs(inputx, indices_dyn, updates_dyn) - output = net(inputx, indices, updates) - expected_shape = (2, 3) - assert expected_shape == output.asnumpy().shape +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest +import mindspore +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor, Parameter +from mindspore.ops import operations as P +from mindspore.ops import functional as F + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + +class TestScatterAddNet(nn.Cell): + def __init__(self, lock, inputx, indices, updates): + super(TestScatterAddNet, self).__init__() + self.scatter_add = P.ScatterAdd(use_locking=lock) + self.inputx = Parameter(inputx, name="inputx") + self.indices = Parameter(indices, name="indices") + self.updates = Parameter(updates, name="updates") + + def construct(self): + out = self.scatter_add(self.inputx, self.indices, self.updates) + return out + + +def scatter_add_net(inputx, indices, updates): + lock = True + net = TestScatterAddNet(lock, inputx, indices, updates) + return net() + + +def scatter_add_use_locking_false_net(inputx, indices, updates): + lock = False + net = TestScatterAddNet(lock, inputx, indices, updates) + return net() + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_scatter_add_small_float32(): + inputx = Tensor(np.zeros((2, 3)).astype(np.float32)) + indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) + updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) + output = scatter_add_net(inputx, indices, updates) + expected = np.array([[6., 8., 10.], + [12., 14., 16.]]) + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_scatter_add_input_updated(): + inputx = Tensor(np.zeros((2, 3)).astype(np.float32)) + indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) + updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) + lock = True + net = TestScatterAddNet(lock, inputx, indices, updates) + net() + expected = np.array([[6., 8., 10.], + [12., 14., 16.]]) + np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_scatter_add_large_shape_float32(): + inputx = Tensor(np.ones((4, 2, 3, 4)).astype(np.float32)) + indices = Tensor(np.array([[0, 2], [3, 1]]).astype(np.int32)) + updates = Tensor(np.arange(96).reshape((2, 2, 2, 3, 4)).astype(np.float32)) + output = scatter_add_net(inputx, indices, updates) + expected = np.array([[[[1., 2., 3., 4.], + [5., 6., 7., 8.], + [9., 10., 11., 12.]], + [[13., 14., 15., 16.], + [17., 18., 19., 20.], + [21., 22., 23., 24.]]], + [[[73., 74., 75., 76.], + [77., 78., 79., 80.], + [81., 82., 83., 84.]], + [[85., 86., 87., 88.], + [89., 90., 91., 92.], + [93., 94., 95., 96.]]], + [[[25., 26., 27., 28.], + [29., 30., 31., 32.], + [33., 34., 35., 36.]], + [[37., 38., 39., 40.], + [41., 42., 43., 44.], + [45., 46., 47., 48.]]], + [[[49., 50., 51., 52.], + [53., 54., 55., 56.], + [57., 58., 59., 60.]], + [[61., 62., 63., 64.], + [65., 66., 67., 68.], + [69., 70., 71., 72.]]]]) + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_scatter_add_small_float32_use_locking_false(): + inputx = Tensor(np.zeros((2, 3)).astype(np.float32)) + indices = Tensor(np.array([1, 0]).astype(np.int32)) + updates = Tensor(np.arange(6).reshape((2, 3)).astype(np.float32)) + output = scatter_add_use_locking_false_net(inputx, indices, updates) + expected = np.array([[3., 4., 5.], + [0., 1., 2.]]) + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_scatter_add_input_less_than_1_float32(): + inputx = Tensor(np.array([[0.214141, 0.415151, 0.51516], + [0.876542, 0.451611, 0.55112], + [0.111244, 0.633333, 0.34444]]).astype(np.float32)) + indices = Tensor(np.array([[[1, 0, 2], + [2, 2, 0]], + [[1, 0, 1], + [2, 1, 2]]]).astype(np.int32)) + updates = Tensor(np.arange(34, 70).reshape((2, 2, 3, 3)).astype(np.float32)) + output = scatter_add_net(inputx, indices, updates) + expected = np.array([[141.21414, 144.41515, 147.51517], + [208.87654, 212.45161, 216.55112], + [257.11124, 262.63333, 267.34442]], dtype=np.float32) + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_scatter_add_float16(): + inputx = Tensor(np.zeros((2, 3)).astype(np.float16)) + indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) + updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float16)) + output = scatter_add_net(inputx, indices, updates) + expected = np.array([[6., 8., 10.], + [12., 14., 16.]]) + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_scatter_add_large_float16(): + inputx = Tensor(np.zeros((2, 3, 4)).astype(np.float16)) + indices = Tensor(np.array([[0, 0], [1, 1]]).astype(np.int32)) + updates = Tensor(np.arange(63, 111).reshape((2, 2, 3, 4)).astype(np.float16)) + output = scatter_add_net(inputx, indices, updates) + expected = np.array([[[138., 140., 142., 144.], + [146., 148., 150., 152.], + [154., 156., 158., 160.]], + [[186., 188., 190., 192.], + [194., 196., 198., 200.], + [202., 204., 206., 208.]]]) + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_scatter_add_disordered_float16(): + inputx = Tensor(np.flip(np.arange(34, 46).reshape(3, 4).astype(np.float16))) + indices = Tensor(np.array([[[0, 1, 2], + [2, 1, 0]], + [[0, 0, 0], + [2, 2, 2]]]).astype(np.int32)) + updates = Tensor(np.arange(63, 111).reshape((2, 2, 3, 4)).astype(np.float16)) + output = scatter_add_net(inputx, indices, updates) + expected = np.array([[464., 468., 472., 476.], + [187., 188., 189., 190.], + [492., 496., 500., 504.]]) + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_scatter_add_large_int32(): + inputx = Tensor(np.zeros((2, 3, 4)).astype(np.int32)) + indices = Tensor(np.array([[0, 0], [1, 1]]).astype(np.int32)) + updates = Tensor(np.arange(63, 111).reshape((2, 2, 3, 4)).astype(np.int32)) + output = scatter_add_net(inputx, indices, updates) + expected = np.array([[[138., 140., 142., 144.], + [146., 148., 150., 152.], + [154., 156., 158., 160.]], + [[186., 188., 190., 192.], + [194., 196., 198., 200.], + [202., 204., 206., 208.]]]).astype(np.int32) + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_scatter_add_disordered_int32(): + inputx = Tensor(np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int32))) + indices = Tensor(np.array([[[0, 1, 2], + [2, 1, 0]], + [[0, 0, 0], + [2, 2, 2]]]).astype(np.int32)) + updates = Tensor(np.arange(63, 111).reshape((2, 2, 3, 4)).astype(np.int32)) + output = scatter_add_net(inputx, indices, updates) + expected = np.array([[464., 468., 472., 476.], + [187., 188., 189., 190.], + [492., 496., 500., 504.]]).astype(np.int32) + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_scatter_add_function(): + """ + Feature: test_scatter_add_function. + Description: test cases for scatter add functinal + Expectation: the result match numpy implementation. + """ + input_x = Tensor(np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int32))) + indices = Tensor(np.array([[[0, 1, 2], + [2, 1, 0]], + [[0, 0, 0], + [2, 2, 2]]]).astype(np.int32)) + updates = Tensor(np.arange(63, 111).reshape((2, 2, 3, 4)).astype(np.int32)) + output = F.scatter_add(input_x, indices, updates) + expected = np.array([[464., 468., 472., 476.], + [187., 188., 189., 190.], + [492., 496., 500., 504.]]).astype(np.int32) + np.testing.assert_allclose(output.asnumpy(), expected, rtol=1e-6) + + +class TestScatterSubNet(nn.Cell): + def __init__(self, lock, inputx, indices, updates): + super(TestScatterSubNet, self).__init__() + self.scatter_sub = P.ScatterSub(use_locking=lock) + self.inputx = Parameter(inputx, name="inputx") + self.indices = Parameter(indices, name="indices") + self.updates = Parameter(updates, name="updates") + + def construct(self): + out = self.scatter_sub(self.inputx, self.indices, self.updates) + return out + + +def scatter_sub_net(inputx, indices, updates): + lock = True + net = TestScatterSubNet(lock, inputx, indices, updates) + return net() + + +def scatter_sub_use_locking_false_net(inputx, indices, updates): + lock = False + net = TestScatterSubNet(lock, inputx, indices, updates) + return net() + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_scatter_sub_input_updated(): + inputx = Tensor(np.zeros((2, 3)).astype(np.float32)) + indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) + updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) + lock = True + net = TestScatterSubNet(lock, inputx, indices, updates) + net() + expected = np.array([[-6., -8., -10.], + [-12., -14., -16.]]) + np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_scatter_sub_large_shape_float32(): + inputx = Tensor(np.ones((4, 2, 3, 4)).astype(np.float32)) + indices = Tensor(np.array([[0, 2], [3, 1]]).astype(np.int32)) + updates = Tensor(np.arange(96).reshape((2, 2, 2, 3, 4)).astype(np.float32)) + output = scatter_sub_net(inputx, indices, updates) + expected = np.array( + [[[[1.0, 0.0, -1.0, -2.0], + [-3.0, -4.0, -5.0, -6.0], + [-7.0, -8.0, -9.0, -10.0]], + [[-11.0, -12.0, -13.0, -14.0], + [-15.0, -16.0, -17.0, -18.0], + [-19.0, -20.0, -21.0, -22.0]]], + [[[-71.0, -72.0, -73.0, -74.0], + [-75.0, -76.0, -77.0, -78.0], + [-79.0, -80.0, -81.0, -82.0]], + [[-83.0, -84.0, -85.0, -86.0], + [-87.0, -88.0, -89.0, -90.0], + [-91.0, -92.0, -93.0, -94.0]]], + [[[-23.0, -24.0, -25.0, -26.0], + [-27.0, -28.0, -29.0, -30.0], + [-31.0, -32.0, -33.0, -34.0]], + [[-35.0, -36.0, -37.0, -38.0], + [-39.0, -40.0, -41.0, -42.0], + [-43.0, -44.0, -45.0, -46.0]]], + [[[-47.0, -48.0, -49.0, -50.0], + [-51.0, -52.0, -53.0, -54.0], + [-55.0, -56.0, -57.0, -58.0]], + [[-59.0, -60.0, -61.0, -62.0], + [-63.0, -64.0, -65.0, -66.0], + [-67.0, -68.0, -69.0, -70.0]]]]) + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_scatter_sub_small_float32_use_locking_false(): + inputx = Tensor(np.zeros((2, 3)).astype(np.float32)) + indices = Tensor(np.array([1, 0]).astype(np.int32)) + updates = Tensor(np.arange(6).reshape((2, 3)).astype(np.float32)) + output = scatter_sub_use_locking_false_net(inputx, indices, updates) + expected = np.array([[-3., -4., -5.], + [-0., -1., -2.]]) + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + + +class TestScatterMulNet(nn.Cell): + def __init__(self, lock, inputx, indices, updates): + super(TestScatterMulNet, self).__init__() + self.scatter_mul = P.ScatterMul(use_locking=lock) + self.inputx = Parameter(inputx, name="inputx") + self.indices = Parameter(indices, name="indices") + self.updates = Parameter(updates, name="updates") + + def construct(self): + out = self.scatter_mul(self.inputx, self.indices, self.updates) + return out + + +def scatter_mul_net(inputx, indices, updates): + lock = True + net = TestScatterMulNet(lock, inputx, indices, updates) + return net() + + +def scatter_mul_use_locking_false_net(inputx, indices, updates): + lock = False + net = TestScatterMulNet(lock, inputx, indices, updates) + return net() + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_scatter_mul_input_updated(): + inputx = Tensor(np.ones((2, 3)).astype(np.float32)) + indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) + updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) + lock = True + net = TestScatterMulNet(lock, inputx, indices, updates) + net() + expected = np.array([[0., 7., 16.], + [27., 40., 55.]]) + np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_scatter_mul_output_updated_float32(): + inputx = Tensor(np.ones((2, 3)).astype(np.float32)) + indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) + updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) + output = scatter_mul_net(inputx, indices, updates) + expected = np.array([[0., 7., 16.], + [27., 40., 55.]]) + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_scatter_mul_small_float32_use_locking_false(): + inputx = Tensor(np.ones((2, 3)).astype(np.float32)) + indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) + updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) + output = scatter_mul_use_locking_false_net(inputx, indices, updates) + expected = np.array([[0., 7., 16.], + [27., 40., 55.]]) + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + + +class TestScatterDivNet(nn.Cell): + def __init__(self, lock, inputx, indices, updates): + super(TestScatterDivNet, self).__init__() + self.scatter_div = P.ScatterDiv(use_locking=lock) + self.inputx = Parameter(inputx, name="inputx") + self.indices = Parameter(indices, name="indices") + self.updates = Parameter(updates, name="updates") + + def construct(self): + out = self.scatter_div(self.inputx, self.indices, self.updates) + return out + + +def scatter_div_net(inputx, indices, updates): + lock = True + net = TestScatterDivNet(lock, inputx, indices, updates) + return net() + + +def scatter_div_use_locking_false_net(inputx, indices, updates): + lock = False + net = TestScatterDivNet(lock, inputx, indices, updates) + return net() + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_scatter_div_input_updated(): + inputx = Tensor(np.zeros((2, 3)).astype(np.float32)) + indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) + updates = Tensor(np.arange(1, 13).reshape((2, 2, 3)).astype(np.float32)) + lock = True + net = TestScatterDivNet(lock, inputx, indices, updates) + net() + expected = np.array([[0., 0., 0.], + [0., 0., 0.]]) + np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_scatter_div_output_updated_float32(): + inputx = Tensor(np.zeros((2, 3)).astype(np.float32)) + indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) + updates = Tensor(np.arange(1, 13).reshape((2, 2, 3)).astype(np.float32)) + output = scatter_div_net(inputx, indices, updates) + expected = np.array([[0., 0., 0.], + [0., 0., 0.]]) + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_scatter_div_small_float32_use_locking_false(): + inputx = Tensor(np.ones((2, 3)).astype(np.float32) * 10) + indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) + updates = Tensor(np.ones(12).reshape((2, 2, 3)).astype(np.float32)) + output = scatter_div_use_locking_false_net(inputx, indices, updates) + expected = np.array([[10., 10., 10.], + [10., 10., 10.]]) + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_scatter_div_output_int16(): + """ + Feature: test ScatterDiv output and input_x same value. + Description: input is int16. + Expectation: output and input_x have same value + """ + input_x = Parameter(Tensor(np.array([[6, 6, 6], [2, 2, 2]]), mindspore.int16), name="x") + indices = Tensor(np.array([0, 1]), mindspore.int32) + updates = Tensor(np.array([[2, 2, 2], [2, 2, 2]]), mindspore.int16) + output = P.ScatterDiv()(input_x, indices, updates) + assert np.allclose(output.asnumpy(), input_x.asnumpy(), 0.0001, 0.0001) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_scatter_div_output_float64(): + """ + Feature: test ScatterDiv output and input_x same value. + Description: input is float64. + Expectation: output and input_x have same value + """ + input_x = Parameter(Tensor(np.array([[6.0, 6.0, 6.0], [2.0, 2.0, 2.0]]), mindspore.float64), name="x") + indices = Tensor(np.array([0, 1]), mindspore.int32) + updates = Tensor(np.array([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]), mindspore.float64) + output = P.ScatterDiv()(input_x, indices, updates) + assert np.allclose(output.asnumpy(), input_x.asnumpy(), 0.0001, 0.0001) + + +class TestScatterMaxNet(nn.Cell): + def __init__(self, lock, inputx, indices, updates): + super(TestScatterMaxNet, self).__init__() + self.scatter_max = P.ScatterMax(use_locking=lock) + self.inputx = Parameter(inputx, name="inputx") + self.indices = Parameter(indices, name="indices") + self.updates = Parameter(updates, name="updates") + + def construct(self): + out = self.scatter_max(self.inputx, self.indices, self.updates) + return out + + +def scatter_max_net(inputx, indices, updates): + lock = True + net = TestScatterMaxNet(lock, inputx, indices, updates) + return net() + + +def scatter_max_use_locking_false_net(inputx, indices, updates): + lock = False + net = TestScatterMaxNet(lock, inputx, indices, updates) + return net() + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_scatter_max_input_updated(): + inputx = Tensor(np.zeros((2, 3)).astype(np.float32)) + indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) + updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) + lock = True + net = TestScatterMaxNet(lock, inputx, indices, updates) + net() + expected = np.array([[6., 7., 8.], + [9., 10., 11.]]) + np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_scatter_max_output_updated_float32(): + inputx = Tensor(np.zeros((2, 3)).astype(np.float32)) + indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) + updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) + output = scatter_max_net(inputx, indices, updates) + expected = np.array([[6., 7., 8.], + [9., 10., 11.]]) + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_scatter_max_small_float32_use_locking_false(): + inputx = Tensor(np.ones((2, 3)).astype(np.float32) * 10) + indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) + updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) + output = scatter_max_use_locking_false_net(inputx, indices, updates) + expected = np.array([[10., 10., 10.], + [10., 10., 11.]]) + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + + +class TestScatterMinNet(nn.Cell): + def __init__(self, lock, inputx, indices, updates): + super(TestScatterMinNet, self).__init__() + self.scatter_min = P.ScatterMin(use_locking=lock) + self.inputx = Parameter(inputx, name="inputx") + self.indices = Parameter(indices, name="indices") + self.updates = Parameter(updates, name="updates") + + def construct(self): + out = self.scatter_min(self.inputx, self.indices, self.updates) + return out + + +def scatter_min_net(inputx, indices, updates): + lock = True + net = TestScatterMinNet(lock, inputx, indices, updates) + return net() + + +def scatter_min_use_locking_false_net(inputx, indices, updates): + lock = False + net = TestScatterMinNet(lock, inputx, indices, updates) + return net() + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_scatter_min_input_updated(): + inputx = Tensor(np.zeros((2, 3)).astype(np.float32)) + indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) + updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) + lock = True + net = TestScatterMinNet(lock, inputx, indices, updates) + net() + expected = np.array([[0., 0., 0.], + [0., 0., 0.]]) + np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_scatter_min_output_updated_float32(): + inputx = Tensor(np.ones((2, 3)).astype(np.float32)) + indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) + updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) + output = scatter_min_net(inputx, indices, updates) + expected = np.array([[0., 1., 1.], + [1., 1., 1.]]) + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_scatter_min_small_float32_use_locking_false(): + inputx = Tensor(np.ones((2, 3)).astype(np.float32)) + indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) + updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) + output = scatter_min_use_locking_false_net(inputx, indices, updates) + expected = np.array([[0., 1., 1.], + [1., 1., 1.]]) + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + + +class TestScatterUpdateNet(nn.Cell): + def __init__(self, lock, inputx, indices, updates): + super(TestScatterUpdateNet, self).__init__() + self.scatter_update = P.ScatterUpdate(use_locking=lock) + self.inputx = Parameter(inputx, name="inputx") + self.indices = Parameter(indices, name="indices") + self.updates = Parameter(updates, name="updates") + + def construct(self): + out = self.scatter_update(self.inputx, self.indices, self.updates) + return out + + +def scatter_update_net(inputx, indices, updates): + lock = True + net = TestScatterUpdateNet(lock, inputx, indices, updates) + return net() + + +def scatter_update_use_locking_false_net(inputx, indices, updates): + lock = False + net = TestScatterUpdateNet(lock, inputx, indices, updates) + return net() + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_scatter_update_input_updated(): + inputx = Tensor(np.zeros((2, 3)).astype(np.float32)) + indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) + updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) + lock = True + net = TestScatterUpdateNet(lock, inputx, indices, updates) + net() + expected = np.array([[6., 7., 8.], + [9., 10., 11.]]) + np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_scatter_update_output_updated_float32(): + inputx = Tensor(np.ones((2, 3)).astype(np.float32)) + indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) + updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) + output = scatter_update_net(inputx, indices, updates) + expected = np.array([[6., 7., 8.], + [9., 10., 11.]]) + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_scatter_update_output_updated_huge_tensor_float32(): + """ + Feature: Test huge input tensor case of cpu kernel ScatterUpdate. + Description: The first input tensor for cpu kernel ScatterUpdate is huge, and + the memory size of this tensor should be greater than 2147483647. + In this case, memory size of inputx tensor is 2147483652 (178956971 * 3 * sizeof(float32)) + Expectation: success. + """ + inputx = Tensor(np.ones((178956971, 3)).astype(np.float32)) + indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) + updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) + output = scatter_update_net(inputx, indices, updates) + expected = np.array([[6., 7., 8.], + [9., 10., 11.]]) + np.testing.assert_array_almost_equal(output.asnumpy()[0:2], expected) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_scatter_update_small_float32_use_locking_false(): + inputx = Tensor(np.ones((2, 3)).astype(np.float32)) + indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) + updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) + output = scatter_update_use_locking_false_net(inputx, indices, updates) + expected = np.array([[6., 7., 8.], + [9., 10., 11.]]) + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + + +class TestScatterAddNetDynamic(nn.Cell): + def __init__(self, lock): + super(TestScatterAddNetDynamic, self).__init__() + self.scatter_add = P.ScatterAdd(use_locking=lock) + + def construct(self, inputx, indices, updates): + out = self.scatter_add(inputx, indices, updates) + return out + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_scatter_add_dynamic_shape(): + """ + Feature: op dynamic shape + Description: set input_shape None and input real tensor + Expectation: success + """ + inputx = Parameter(Tensor(np.zeros((2, 3)).astype(np.float32))) + indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) + updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) + net = TestScatterAddNetDynamic(False) + indices_dyn = Tensor(shape=[None, None], dtype=indices.dtype) + updates_dyn = Tensor(shape=[None, None, None], dtype=updates.dtype) + net.set_inputs(inputx, indices_dyn, updates_dyn) + output = net(inputx, indices, updates) + expected_shape = (2, 3) + assert expected_shape == output.asnumpy().shape + + +class TestScatterSubNetDynamic(nn.Cell): + def __init__(self, lock): + super(TestScatterSubNetDynamic, self).__init__() + self.scatter_sub = P.ScatterSub(use_locking=lock) + + def construct(self, inputx, indices, updates): + out = self.scatter_sub(inputx, indices, updates) + return out + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_scatter_sub_dynamic_shape(): + """ + Feature: op dynamic shape + Description: set input_shape None and input real tensor + Expectation: success + """ + + inputx = Parameter(Tensor(np.zeros((2, 3)).astype(np.float32))) + indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) + updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) + net = TestScatterSubNetDynamic(False) + indices_dyn = Tensor(shape=[None, None], dtype=indices.dtype) + updates_dyn = Tensor(shape=[None, None, None], dtype=updates.dtype) + net.set_inputs(inputx, indices_dyn, updates_dyn) + output = net(inputx, indices, updates) + expected_shape = (2, 3) + assert expected_shape == output.asnumpy().shape + + +class TestScatterUpdateNetDynamic(nn.Cell): + def __init__(self, lock): + super(TestScatterUpdateNetDynamic, self).__init__() + self.scatter_update = P.ScatterUpdate(use_locking=lock) + + def construct(self, inputx, indices, updates): + out = self.scatter_update(inputx, indices, updates) + return out + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_scatter_update_dynamic_shape(): + """ + Feature: op dynamic shape + Description: set input_shape None and input real tensor + Expectation: success + """ + + inputx = Parameter(Tensor(np.zeros((2, 3)).astype(np.float32))) + indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)) + updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)) + net = TestScatterUpdateNetDynamic(False) + indices_dyn = Tensor(shape=[None, None], dtype=indices.dtype) + updates_dyn = Tensor(shape=[None, None, None], dtype=updates.dtype) + net.set_inputs(inputx, indices_dyn, updates_dyn) + output = net(inputx, indices, updates) + expected_shape = (2, 3) + assert expected_shape == output.asnumpy().shape diff --git a/tests/st/ops/cpu/test_softmax_cross_entropy_with_logits_op.py b/tests/st/ops/cpu/test_softmax_cross_entropy_with_logits_op.py index e2bd5e66260..109285782ba 100644 --- a/tests/st/ops/cpu/test_softmax_cross_entropy_with_logits_op.py +++ b/tests/st/ops/cpu/test_softmax_cross_entropy_with_logits_op.py @@ -1,54 +1,54 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor - - -class NetSoftmaxCrossEntropyWithLogits(nn.Cell): - def __init__(self): - super(NetSoftmaxCrossEntropyWithLogits, self).__init__() - self.loss = nn.SoftmaxCrossEntropyWithLogits(sparse=False) - - def construct(self, logits, labels): - return self.loss(logits, labels) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_softmax_cross_entropy_with_logits(): - """ - Feature: template - Description: template - Expectation: template - """ - logits = Tensor(np.array([[1, 1, 10], - [1, 10, 1], - [10, 1, 1]]).astype(np.float32)) - labels = Tensor(np.array([[0, 0, 1], - [0, 1, 0], - [1, 0, 0]]).astype(np.float32)) - expect_loss = [0.00024673, 0.00024673, 0.00024673] - - context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - softmax_cross_entropy_with_logits = NetSoftmaxCrossEntropyWithLogits() - output = softmax_cross_entropy_with_logits(logits, labels) - error0 = 1.0e-6 - diff0 = output.asnumpy() - expect_loss - assert np.all(abs(diff0) < error0) +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor + + +class NetSoftmaxCrossEntropyWithLogits(nn.Cell): + def __init__(self): + super(NetSoftmaxCrossEntropyWithLogits, self).__init__() + self.loss = nn.SoftmaxCrossEntropyWithLogits(sparse=False) + + def construct(self, logits, labels): + return self.loss(logits, labels) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_softmax_cross_entropy_with_logits(): + """ + Feature: template + Description: template + Expectation: template + """ + logits = Tensor(np.array([[1, 1, 10], + [1, 10, 1], + [10, 1, 1]]).astype(np.float32)) + labels = Tensor(np.array([[0, 0, 1], + [0, 1, 0], + [1, 0, 0]]).astype(np.float32)) + expect_loss = [0.00024673, 0.00024673, 0.00024673] + + context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + softmax_cross_entropy_with_logits = NetSoftmaxCrossEntropyWithLogits() + output = softmax_cross_entropy_with_logits(logits, labels) + error0 = 1.0e-6 + diff0 = output.asnumpy() - expect_loss + assert np.all(abs(diff0) < error0) diff --git a/tests/st/ops/cpu/test_softmax_op.py b/tests/st/ops/cpu/test_softmax_op.py index ca73e6ab960..3c871f5428d 100644 --- a/tests/st/ops/cpu/test_softmax_op.py +++ b/tests/st/ops/cpu/test_softmax_op.py @@ -1,79 +1,79 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.common.initializer import initializer -from mindspore.common.parameter import Parameter -from mindspore.ops import operations as P - -context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - - -class NetSoftmax(nn.Cell): - def __init__(self): - super(NetSoftmax, self).__init__() - self.softmax = P.Softmax(axis=-1) - x = Tensor(np.array([[0.1, 0.3, 0.6], - [0.2, -0.6, 0.8], - [0.6, 1, 0.4]]).astype(np.float32)) - self.x = Parameter(initializer(x, x.shape), name='x') - - def construct(self): - return self.softmax(self.x) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_softmax(): - Softmax = NetSoftmax() - output = Softmax() - output = output.asnumpy() - outputSum = output.sum(axis=1) - expect = np.ones(3) - error = expect * 1.0e-6 - diff = np.abs(outputSum - expect) - print(diff) - assert np.all(diff < error) - - -class NetSoftmax1(nn.Cell): - def __init__(self): - super(NetSoftmax1, self).__init__() - self.softmax = P.Softmax(axis=-2) - x = Tensor(np.array([[0.1, 0.3, 0.6], - [0.2, -0.6, 0.8], - [0.6, 1, 0.4]]).astype(np.float32)) - self.x = Parameter(initializer(x, x.shape), name='x') - - def construct(self): - return self.softmax(self.x) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_softmax1(): - Softmax = NetSoftmax1() - output = Softmax() - output = output.asnumpy() - outputSum = output.sum(axis=0) - expect = np.ones(3) - error = expect * 1.0e-6 - diff = np.abs(outputSum - expect) - print(diff) - assert np.all(diff < error) +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common.initializer import initializer +from mindspore.common.parameter import Parameter +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + + +class NetSoftmax(nn.Cell): + def __init__(self): + super(NetSoftmax, self).__init__() + self.softmax = P.Softmax(axis=-1) + x = Tensor(np.array([[0.1, 0.3, 0.6], + [0.2, -0.6, 0.8], + [0.6, 1, 0.4]]).astype(np.float32)) + self.x = Parameter(initializer(x, x.shape), name='x') + + def construct(self): + return self.softmax(self.x) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_softmax(): + Softmax = NetSoftmax() + output = Softmax() + output = output.asnumpy() + outputSum = output.sum(axis=1) + expect = np.ones(3) + error = expect * 1.0e-6 + diff = np.abs(outputSum - expect) + print(diff) + assert np.all(diff < error) + + +class NetSoftmax1(nn.Cell): + def __init__(self): + super(NetSoftmax1, self).__init__() + self.softmax = P.Softmax(axis=-2) + x = Tensor(np.array([[0.1, 0.3, 0.6], + [0.2, -0.6, 0.8], + [0.6, 1, 0.4]]).astype(np.float32)) + self.x = Parameter(initializer(x, x.shape), name='x') + + def construct(self): + return self.softmax(self.x) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_softmax1(): + Softmax = NetSoftmax1() + output = Softmax() + output = output.asnumpy() + outputSum = output.sum(axis=0) + expect = np.ones(3) + error = expect * 1.0e-6 + diff = np.abs(outputSum - expect) + print(diff) + assert np.all(diff < error) diff --git a/tests/st/ops/cpu/test_tensor_scatter_element_op.py b/tests/st/ops/cpu/test_tensor_scatter_element_op.py index 4ca5089da98..6feb27ff704 100644 --- a/tests/st/ops/cpu/test_tensor_scatter_element_op.py +++ b/tests/st/ops/cpu/test_tensor_scatter_element_op.py @@ -1,118 +1,118 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore import Parameter -from mindspore.ops import functional as F -from mindspore.ops.operations.array_ops import TensorScatterElements - -context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - - -def scatter_element_np(input_x, indices, updates, axis, reduction="none"): - result = input_x.asnumpy().copy() - indices_np = indices.asnumpy().copy() - updates_np = updates.asnumpy().copy() - - i_len = indices_np.shape[0] - j_len = indices_np.shape[1] - - if axis < 0: - axis += len(result.shape) - - for i in range(i_len): - for j in range(j_len): - if axis == 0: - if reduction == "none": - result[indices_np[i][j]][j] = updates_np[i][j] - if reduction == "add": - result[indices_np[i][j]][j] += updates_np[i][j] - if axis == 1: - if reduction == "none": - result[i][indices_np[i][j]] = updates_np[i][j] - if reduction == "add": - result[i][indices_np[i][j]] += updates_np[i][j] - - return result - - -class TestTensorScatterElements(nn.Cell): - def __init__(self, input_x, indices, updates, axis, reduction): - super(TestTensorScatterElements, self).__init__() - self.axis = axis - self.reduction = reduction - self.input_x = Parameter(input_x, name="input_x") - self.indices = Parameter(indices, name="indices") - self.updates = Parameter(updates, name="updates") - self.scatter_elements = TensorScatterElements( - self.axis, self.reduction) - - def construct(self): - return self.scatter_elements(self.input_x, self.indices, self.updates) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.int32]) -@pytest.mark.parametrize('index_dtype', [np.int32, np.int64]) -@pytest.mark.parametrize('axis', [0, 1, -1]) -@pytest.mark.parametrize('reduction', ["none", "add"]) -def test_scatter_elements(dtype, index_dtype, axis, reduction): - """ - Feature: Op TensorScatterElements - Description: Scatter update value according indices to output. - output[indices[i][j]][j] = updates[i][j] if axis = 0, reduction="none" - output[i][indices[i][j]] += updates[i][j] if axis = 1, reduction="add" - Expectation: Ans is same as expected. - """ - x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype)) - indices = Tensor(np.array([[-1, 0, 1], [0, 1, 2]], dtype=index_dtype)) - update = Tensor(np.array([[1, 2, 2], [4, 5, 8]], dtype=dtype)) - - ms_output = TestTensorScatterElements( - x, indices, update, axis, reduction)() - np_output = scatter_element_np(x, indices, update, axis, reduction) - print("ms_output:\n", ms_output.asnumpy()) - assert np.allclose(ms_output.asnumpy(), np_output) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -@pytest.mark.parametrize('dtype', [np.float32]) -@pytest.mark.parametrize('index_dtype', [np.int32]) -@pytest.mark.parametrize('axis', [0]) -@pytest.mark.parametrize('reduction', ["none", "add"]) -def test_scatter_add_with_axis_func(dtype, index_dtype, axis, reduction): - """ - Feature: test scatter_add_with_axis functional interface(scatter_add). - Description: Scatter update value according indices to output. - output[indices[i][j]][j] += updates[i][j] if axis = 0, - output[i][indices[i][j]] += updates[i][j] if axis = 1. - Expectation: Ans is same as expected. - """ - x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype)) - indices = Tensor(np.array([[1, -1, 2], [0, 2, 1]], dtype=index_dtype)) - update = Tensor(np.array([[1, 2, 2], [4, 5, 8]], dtype=dtype)) - - #cause scatter_add will change the value of input, so we first calculate numpy output. - np_output = scatter_element_np(x, indices, update, axis, reduction) - ms_output = F.tensor_scatter_elements(x, indices, update, axis, reduction) - print("np_output:\n", np_output) - print("ms_output:\n", ms_output.asnumpy()) - assert np.allclose(ms_output.asnumpy(), np_output) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import Parameter +from mindspore.ops import functional as F +from mindspore.ops.operations.array_ops import TensorScatterElements + +context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + + +def scatter_element_np(input_x, indices, updates, axis, reduction="none"): + result = input_x.asnumpy().copy() + indices_np = indices.asnumpy().copy() + updates_np = updates.asnumpy().copy() + + i_len = indices_np.shape[0] + j_len = indices_np.shape[1] + + if axis < 0: + axis += len(result.shape) + + for i in range(i_len): + for j in range(j_len): + if axis == 0: + if reduction == "none": + result[indices_np[i][j]][j] = updates_np[i][j] + if reduction == "add": + result[indices_np[i][j]][j] += updates_np[i][j] + if axis == 1: + if reduction == "none": + result[i][indices_np[i][j]] = updates_np[i][j] + if reduction == "add": + result[i][indices_np[i][j]] += updates_np[i][j] + + return result + + +class TestTensorScatterElements(nn.Cell): + def __init__(self, input_x, indices, updates, axis, reduction): + super(TestTensorScatterElements, self).__init__() + self.axis = axis + self.reduction = reduction + self.input_x = Parameter(input_x, name="input_x") + self.indices = Parameter(indices, name="indices") + self.updates = Parameter(updates, name="updates") + self.scatter_elements = TensorScatterElements( + self.axis, self.reduction) + + def construct(self): + return self.scatter_elements(self.input_x, self.indices, self.updates) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.int32]) +@pytest.mark.parametrize('index_dtype', [np.int32, np.int64]) +@pytest.mark.parametrize('axis', [0, 1, -1]) +@pytest.mark.parametrize('reduction', ["none", "add"]) +def test_scatter_elements(dtype, index_dtype, axis, reduction): + """ + Feature: Op TensorScatterElements + Description: Scatter update value according indices to output. + output[indices[i][j]][j] = updates[i][j] if axis = 0, reduction="none" + output[i][indices[i][j]] += updates[i][j] if axis = 1, reduction="add" + Expectation: Ans is same as expected. + """ + x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype)) + indices = Tensor(np.array([[-1, 0, 1], [0, 1, 2]], dtype=index_dtype)) + update = Tensor(np.array([[1, 2, 2], [4, 5, 8]], dtype=dtype)) + + ms_output = TestTensorScatterElements( + x, indices, update, axis, reduction)() + np_output = scatter_element_np(x, indices, update, axis, reduction) + print("ms_output:\n", ms_output.asnumpy()) + assert np.allclose(ms_output.asnumpy(), np_output) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +@pytest.mark.parametrize('dtype', [np.float32]) +@pytest.mark.parametrize('index_dtype', [np.int32]) +@pytest.mark.parametrize('axis', [0]) +@pytest.mark.parametrize('reduction', ["none", "add"]) +def test_scatter_add_with_axis_func(dtype, index_dtype, axis, reduction): + """ + Feature: test scatter_add_with_axis functional interface(scatter_add). + Description: Scatter update value according indices to output. + output[indices[i][j]][j] += updates[i][j] if axis = 0, + output[i][indices[i][j]] += updates[i][j] if axis = 1. + Expectation: Ans is same as expected. + """ + x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype)) + indices = Tensor(np.array([[1, -1, 2], [0, 2, 1]], dtype=index_dtype)) + update = Tensor(np.array([[1, 2, 2], [4, 5, 8]], dtype=dtype)) + + #cause scatter_add will change the value of input, so we first calculate numpy output. + np_output = scatter_element_np(x, indices, update, axis, reduction) + ms_output = F.tensor_scatter_elements(x, indices, update, axis, reduction) + print("np_output:\n", np_output) + print("ms_output:\n", ms_output.asnumpy()) + assert np.allclose(ms_output.asnumpy(), np_output) diff --git a/tests/st/ops/cpu/test_transpose_op.py b/tests/st/ops/cpu/test_transpose_op.py index cdf9fcd2079..d3e198d555b 100644 --- a/tests/st/ops/cpu/test_transpose_op.py +++ b/tests/st/ops/cpu/test_transpose_op.py @@ -1,379 +1,379 @@ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import pytest -import numpy as np -from mindspore import Tensor -from mindspore.ops import operations as P -from mindspore.common.api import jit -from mindspore.common.initializer import initializer -from mindspore.common.parameter import Parameter -import mindspore.nn as nn -import mindspore.context as context - -context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU') - - -class Transpose(nn.Cell): - def __init__(self): - super(Transpose, self).__init__() - self.transpose = P.Transpose() - - self.x_2D = Parameter(initializer(Tensor(np.arange(5 * 6).reshape(5, 6).astype(np.float32)), [5, 6]), - name='x_2D') - self.perm_2D = (1, 0) - - self.x_3D = Parameter(initializer(Tensor(np.arange(2 * 2 * 4).reshape(2, 2, 4).astype(np.float32)), [2, 2, 4]), - name='x_3D') - self.perm_3D = (1, 0, 2) - - self.x_4D = Parameter( - initializer(Tensor(np.arange(2 * 3 * 4 * 5).reshape(2, - 3, 4, 5).astype(np.float32)), [2, 3, 4, 5]), - name='x_4D') - self.perm_4D = (0, 1, 2, 3) - - self.x_5D = Parameter( - initializer(Tensor(np.arange(1 * 2 * 3 * 4 * 5).reshape(1, 2, 3, 4, 5).astype(np.float32)), - [1, 2, 3, 4, 5]), name='x_5D') - self.perm_5D = (1, 0, 3, 4, 2) - - @jit - def construct(self): - return (self.transpose(self.x_2D, self.perm_2D), self.transpose(self.x_3D, self.perm_3D), - self.transpose(self.x_4D, self.perm_4D), self.transpose(self.x_5D, self.perm_5D)) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_transpose(): - transpose = Transpose() - output = transpose() - - expect0 = np.array([[[0, 6, 12, 18, 24], - [1, 7, 13, 19, 25], - [2, 8, 14, 20, 26], - [3, 9, 15, 21, 27], - [4, 10, 16, 22, 28], - [5, 11, 17, 23, 29]]]).astype(np.float32) - expect1 = np.array([[[[0, 1, 2, 3], - [8, 9, 10, 11]], - [[4, 5, 6, 7], - [12, 13, 14, 15]]]]).astype(np.float32) - expect2 = np.array([[[[[0, 1, 2, 3, 4], - [5, 6, 7, 8, 9], - [10, 11, 12, 13, 14], - [15, 16, 17, 18, 19]], - [[20, 21, 22, 23, 24], - [25, 26, 27, 28, 29], - [30, 31, 32, 33, 34], - [35, 36, 37, 38, 39]], - [[40, 41, 42, 43, 44], - [45, 46, 47, 48, 49], - [50, 51, 52, 53, 54], - [55, 56, 57, 58, 59]]], - - [[[60, 61, 62, 63, 64], - [65, 66, 67, 68, 69], - [70, 71, 72, 73, 74], - [75, 76, 77, 78, 79]], - [[80, 81, 82, 83, 84], - [85, 86, 87, 88, 89], - [90, 91, 92, 93, 94], - [95, 96, 97, 98, 99]], - [[100, 101, 102, 103, 104], - [105, 106, 107, 108, 109], - [110, 111, 112, 113, 114], - [115, 116, 117, 118, 119]]]]]).astype(np.float32) - expect3 = np.array([[[[[[0, 20, 40], - [1, 21, 41], - [2, 22, 42], - [3, 23, 43], - [4, 24, 44]], - [[5, 25, 45], - [6, 26, 46], - [7, 27, 47], - [8, 28, 48], - [9, 29, 49]], - [[10, 30, 50], - [11, 31, 51], - [12, 32, 52], - [13, 33, 53], - [14, 34, 54]], - [[15, 35, 55], - [16, 36, 56], - [17, 37, 57], - [18, 38, 58], - [19, 39, 59]]]], - - [[[[60, 80, 100], - [61, 81, 101], - [62, 82, 102], - [63, 83, 103], - [64, 84, 104]], - [[65, 85, 105], - [66, 86, 106], - [67, 87, 107], - [68, 88, 108], - [69, 89, 109]], - [[70, 90, 110], - [71, 91, 111], - [72, 92, 112], - [73, 93, 113], - [74, 94, 114]], - [[75, 95, 115], - [76, 96, 116], - [77, 97, 117], - [78, 98, 118], - [79, 99, 119]]]]]]).astype(np.float32) - assert (output[0].asnumpy() == expect0).all() - assert (output[1].asnumpy() == expect1).all() - assert (output[2].asnumpy() == expect2).all() - assert (output[3].asnumpy() == expect3).all() - - -class Transpose_int64(nn.Cell): - def __init__(self): - super(Transpose_int64, self).__init__() - self.transpose = P.Transpose() - - self.x_2D = Parameter(initializer(Tensor(np.arange(5 * 6).reshape(5, 6).astype(np.int64)), [5, 6]), - name='x_2D') - self.perm_2D = (1, 0) - - self.x_3D = Parameter(initializer(Tensor(np.arange(2 * 2 * 4).reshape(2, 2, 4).astype(np.int64)), [2, 2, 4]), - name='x_3D') - self.perm_3D = (1, 0, 2) - - self.x_4D = Parameter( - initializer(Tensor(np.arange(2 * 3 * 4 * 5).reshape(2, - 3, 4, 5).astype(np.int64)), [2, 3, 4, 5]), - name='x_4D') - self.perm_4D = (0, 1, 2, 3) - - self.x_5D = Parameter( - initializer(Tensor(np.arange(1 * 2 * 3 * 4 * 5).reshape(1, 2, 3, 4, 5).astype(np.int64)), - [1, 2, 3, 4, 5]), name='x_5D') - self.perm_5D = (1, 0, 3, 4, 2) - - @jit - def construct(self): - return (self.transpose(self.x_2D, self.perm_2D), self.transpose(self.x_3D, self.perm_3D), - self.transpose(self.x_4D, self.perm_4D), self.transpose(self.x_5D, self.perm_5D)) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_transpose_int64(): - transpose = Transpose_int64() - output = transpose() - - expect0 = np.array([[[0, 6, 12, 18, 24], - [1, 7, 13, 19, 25], - [2, 8, 14, 20, 26], - [3, 9, 15, 21, 27], - [4, 10, 16, 22, 28], - [5, 11, 17, 23, 29]]]).astype(np.int64) - expect1 = np.array([[[[0, 1, 2, 3], - [8, 9, 10, 11]], - [[4, 5, 6, 7], - [12, 13, 14, 15]]]]).astype(np.int64) - expect2 = np.array([[[[[0, 1, 2, 3, 4], - [5, 6, 7, 8, 9], - [10, 11, 12, 13, 14], - [15, 16, 17, 18, 19]], - [[20, 21, 22, 23, 24], - [25, 26, 27, 28, 29], - [30, 31, 32, 33, 34], - [35, 36, 37, 38, 39]], - [[40, 41, 42, 43, 44], - [45, 46, 47, 48, 49], - [50, 51, 52, 53, 54], - [55, 56, 57, 58, 59]]], - - [[[60, 61, 62, 63, 64], - [65, 66, 67, 68, 69], - [70, 71, 72, 73, 74], - [75, 76, 77, 78, 79]], - [[80, 81, 82, 83, 84], - [85, 86, 87, 88, 89], - [90, 91, 92, 93, 94], - [95, 96, 97, 98, 99]], - [[100, 101, 102, 103, 104], - [105, 106, 107, 108, 109], - [110, 111, 112, 113, 114], - [115, 116, 117, 118, 119]]]]]).astype(np.int64) - expect3 = np.array([[[[[[0, 20, 40], - [1, 21, 41], - [2, 22, 42], - [3, 23, 43], - [4, 24, 44]], - [[5, 25, 45], - [6, 26, 46], - [7, 27, 47], - [8, 28, 48], - [9, 29, 49]], - [[10, 30, 50], - [11, 31, 51], - [12, 32, 52], - [13, 33, 53], - [14, 34, 54]], - [[15, 35, 55], - [16, 36, 56], - [17, 37, 57], - [18, 38, 58], - [19, 39, 59]]]], - - [[[[60, 80, 100], - [61, 81, 101], - [62, 82, 102], - [63, 83, 103], - [64, 84, 104]], - [[65, 85, 105], - [66, 86, 106], - [67, 87, 107], - [68, 88, 108], - [69, 89, 109]], - [[70, 90, 110], - [71, 91, 111], - [72, 92, 112], - [73, 93, 113], - [74, 94, 114]], - [[75, 95, 115], - [76, 96, 116], - [77, 97, 117], - [78, 98, 118], - [79, 99, 119]]]]]]).astype(np.int64) - assert (output[0].asnumpy() == expect0).all() - assert (output[1].asnumpy() == expect1).all() - assert (output[2].asnumpy() == expect2).all() - assert (output[3].asnumpy() == expect3).all() - - - -class Transpose_uint8(nn.Cell): - def __init__(self): - super(Transpose_uint8, self).__init__() - self.transpose = P.Transpose() - - self.x_2D = Parameter(initializer(Tensor(np.arange(5 * 6).reshape(5, 6).astype(np.uint8)), [5, 6]), - name='x_2D') - self.perm_2D = (1, 0) - - self.x_3D = Parameter(initializer(Tensor(np.arange(2 * 2 * 4).reshape(2, 2, 4).astype(np.uint8)), [2, 2, 4]), - name='x_3D') - self.perm_3D = (1, 0, 2) - - self.x_4D = Parameter( - initializer(Tensor(np.arange(2 * 3 * 4 * 5).reshape(2, - 3, 4, 5).astype(np.uint8)), [2, 3, 4, 5]), - name='x_4D') - self.perm_4D = (0, 1, 2, 3) - - self.x_5D = Parameter( - initializer(Tensor(np.arange(1 * 2 * 3 * 4 * 5).reshape(1, 2, 3, 4, 5).astype(np.uint8)), - [1, 2, 3, 4, 5]), name='x_5D') - self.perm_5D = (1, 0, 3, 4, 2) - - @jit - def construct(self): - return (self.transpose(self.x_2D, self.perm_2D), self.transpose(self.x_3D, self.perm_3D), - self.transpose(self.x_4D, self.perm_4D), self.transpose(self.x_5D, self.perm_5D)) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_transpose_uint8(): - transpose = Transpose_uint8() - output = transpose() - - expect0 = np.array([[[0, 6, 12, 18, 24], - [1, 7, 13, 19, 25], - [2, 8, 14, 20, 26], - [3, 9, 15, 21, 27], - [4, 10, 16, 22, 28], - [5, 11, 17, 23, 29]]]).astype(np.uint8) - expect1 = np.array([[[[0, 1, 2, 3], - [8, 9, 10, 11]], - [[4, 5, 6, 7], - [12, 13, 14, 15]]]]).astype(np.uint8) - expect2 = np.array([[[[[0, 1, 2, 3, 4], - [5, 6, 7, 8, 9], - [10, 11, 12, 13, 14], - [15, 16, 17, 18, 19]], - [[20, 21, 22, 23, 24], - [25, 26, 27, 28, 29], - [30, 31, 32, 33, 34], - [35, 36, 37, 38, 39]], - [[40, 41, 42, 43, 44], - [45, 46, 47, 48, 49], - [50, 51, 52, 53, 54], - [55, 56, 57, 58, 59]]], - - [[[60, 61, 62, 63, 64], - [65, 66, 67, 68, 69], - [70, 71, 72, 73, 74], - [75, 76, 77, 78, 79]], - [[80, 81, 82, 83, 84], - [85, 86, 87, 88, 89], - [90, 91, 92, 93, 94], - [95, 96, 97, 98, 99]], - [[100, 101, 102, 103, 104], - [105, 106, 107, 108, 109], - [110, 111, 112, 113, 114], - [115, 116, 117, 118, 119]]]]]).astype(np.uint8) - expect3 = np.array([[[[[[0, 20, 40], - [1, 21, 41], - [2, 22, 42], - [3, 23, 43], - [4, 24, 44]], - [[5, 25, 45], - [6, 26, 46], - [7, 27, 47], - [8, 28, 48], - [9, 29, 49]], - [[10, 30, 50], - [11, 31, 51], - [12, 32, 52], - [13, 33, 53], - [14, 34, 54]], - [[15, 35, 55], - [16, 36, 56], - [17, 37, 57], - [18, 38, 58], - [19, 39, 59]]]], - - [[[[60, 80, 100], - [61, 81, 101], - [62, 82, 102], - [63, 83, 103], - [64, 84, 104]], - [[65, 85, 105], - [66, 86, 106], - [67, 87, 107], - [68, 88, 108], - [69, 89, 109]], - [[70, 90, 110], - [71, 91, 111], - [72, 92, 112], - [73, 93, 113], - [74, 94, 114]], - [[75, 95, 115], - [76, 96, 116], - [77, 97, 117], - [78, 98, 118], - [79, 99, 119]]]]]]).astype(np.uint8) - assert (output[0].asnumpy() == expect0).all() - assert (output[1].asnumpy() == expect1).all() - assert (output[2].asnumpy() == expect2).all() - assert (output[3].asnumpy() == expect3).all() +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import pytest +import numpy as np +from mindspore import Tensor +from mindspore.ops import operations as P +from mindspore.common.api import jit +from mindspore.common.initializer import initializer +from mindspore.common.parameter import Parameter +import mindspore.nn as nn +import mindspore.context as context + +context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU') + + +class Transpose(nn.Cell): + def __init__(self): + super(Transpose, self).__init__() + self.transpose = P.Transpose() + + self.x_2D = Parameter(initializer(Tensor(np.arange(5 * 6).reshape(5, 6).astype(np.float32)), [5, 6]), + name='x_2D') + self.perm_2D = (1, 0) + + self.x_3D = Parameter(initializer(Tensor(np.arange(2 * 2 * 4).reshape(2, 2, 4).astype(np.float32)), [2, 2, 4]), + name='x_3D') + self.perm_3D = (1, 0, 2) + + self.x_4D = Parameter( + initializer(Tensor(np.arange(2 * 3 * 4 * 5).reshape(2, + 3, 4, 5).astype(np.float32)), [2, 3, 4, 5]), + name='x_4D') + self.perm_4D = (0, 1, 2, 3) + + self.x_5D = Parameter( + initializer(Tensor(np.arange(1 * 2 * 3 * 4 * 5).reshape(1, 2, 3, 4, 5).astype(np.float32)), + [1, 2, 3, 4, 5]), name='x_5D') + self.perm_5D = (1, 0, 3, 4, 2) + + @jit + def construct(self): + return (self.transpose(self.x_2D, self.perm_2D), self.transpose(self.x_3D, self.perm_3D), + self.transpose(self.x_4D, self.perm_4D), self.transpose(self.x_5D, self.perm_5D)) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_transpose(): + transpose = Transpose() + output = transpose() + + expect0 = np.array([[[0, 6, 12, 18, 24], + [1, 7, 13, 19, 25], + [2, 8, 14, 20, 26], + [3, 9, 15, 21, 27], + [4, 10, 16, 22, 28], + [5, 11, 17, 23, 29]]]).astype(np.float32) + expect1 = np.array([[[[0, 1, 2, 3], + [8, 9, 10, 11]], + [[4, 5, 6, 7], + [12, 13, 14, 15]]]]).astype(np.float32) + expect2 = np.array([[[[[0, 1, 2, 3, 4], + [5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19]], + [[20, 21, 22, 23, 24], + [25, 26, 27, 28, 29], + [30, 31, 32, 33, 34], + [35, 36, 37, 38, 39]], + [[40, 41, 42, 43, 44], + [45, 46, 47, 48, 49], + [50, 51, 52, 53, 54], + [55, 56, 57, 58, 59]]], + + [[[60, 61, 62, 63, 64], + [65, 66, 67, 68, 69], + [70, 71, 72, 73, 74], + [75, 76, 77, 78, 79]], + [[80, 81, 82, 83, 84], + [85, 86, 87, 88, 89], + [90, 91, 92, 93, 94], + [95, 96, 97, 98, 99]], + [[100, 101, 102, 103, 104], + [105, 106, 107, 108, 109], + [110, 111, 112, 113, 114], + [115, 116, 117, 118, 119]]]]]).astype(np.float32) + expect3 = np.array([[[[[[0, 20, 40], + [1, 21, 41], + [2, 22, 42], + [3, 23, 43], + [4, 24, 44]], + [[5, 25, 45], + [6, 26, 46], + [7, 27, 47], + [8, 28, 48], + [9, 29, 49]], + [[10, 30, 50], + [11, 31, 51], + [12, 32, 52], + [13, 33, 53], + [14, 34, 54]], + [[15, 35, 55], + [16, 36, 56], + [17, 37, 57], + [18, 38, 58], + [19, 39, 59]]]], + + [[[[60, 80, 100], + [61, 81, 101], + [62, 82, 102], + [63, 83, 103], + [64, 84, 104]], + [[65, 85, 105], + [66, 86, 106], + [67, 87, 107], + [68, 88, 108], + [69, 89, 109]], + [[70, 90, 110], + [71, 91, 111], + [72, 92, 112], + [73, 93, 113], + [74, 94, 114]], + [[75, 95, 115], + [76, 96, 116], + [77, 97, 117], + [78, 98, 118], + [79, 99, 119]]]]]]).astype(np.float32) + assert (output[0].asnumpy() == expect0).all() + assert (output[1].asnumpy() == expect1).all() + assert (output[2].asnumpy() == expect2).all() + assert (output[3].asnumpy() == expect3).all() + + +class Transpose_int64(nn.Cell): + def __init__(self): + super(Transpose_int64, self).__init__() + self.transpose = P.Transpose() + + self.x_2D = Parameter(initializer(Tensor(np.arange(5 * 6).reshape(5, 6).astype(np.int64)), [5, 6]), + name='x_2D') + self.perm_2D = (1, 0) + + self.x_3D = Parameter(initializer(Tensor(np.arange(2 * 2 * 4).reshape(2, 2, 4).astype(np.int64)), [2, 2, 4]), + name='x_3D') + self.perm_3D = (1, 0, 2) + + self.x_4D = Parameter( + initializer(Tensor(np.arange(2 * 3 * 4 * 5).reshape(2, + 3, 4, 5).astype(np.int64)), [2, 3, 4, 5]), + name='x_4D') + self.perm_4D = (0, 1, 2, 3) + + self.x_5D = Parameter( + initializer(Tensor(np.arange(1 * 2 * 3 * 4 * 5).reshape(1, 2, 3, 4, 5).astype(np.int64)), + [1, 2, 3, 4, 5]), name='x_5D') + self.perm_5D = (1, 0, 3, 4, 2) + + @jit + def construct(self): + return (self.transpose(self.x_2D, self.perm_2D), self.transpose(self.x_3D, self.perm_3D), + self.transpose(self.x_4D, self.perm_4D), self.transpose(self.x_5D, self.perm_5D)) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_transpose_int64(): + transpose = Transpose_int64() + output = transpose() + + expect0 = np.array([[[0, 6, 12, 18, 24], + [1, 7, 13, 19, 25], + [2, 8, 14, 20, 26], + [3, 9, 15, 21, 27], + [4, 10, 16, 22, 28], + [5, 11, 17, 23, 29]]]).astype(np.int64) + expect1 = np.array([[[[0, 1, 2, 3], + [8, 9, 10, 11]], + [[4, 5, 6, 7], + [12, 13, 14, 15]]]]).astype(np.int64) + expect2 = np.array([[[[[0, 1, 2, 3, 4], + [5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19]], + [[20, 21, 22, 23, 24], + [25, 26, 27, 28, 29], + [30, 31, 32, 33, 34], + [35, 36, 37, 38, 39]], + [[40, 41, 42, 43, 44], + [45, 46, 47, 48, 49], + [50, 51, 52, 53, 54], + [55, 56, 57, 58, 59]]], + + [[[60, 61, 62, 63, 64], + [65, 66, 67, 68, 69], + [70, 71, 72, 73, 74], + [75, 76, 77, 78, 79]], + [[80, 81, 82, 83, 84], + [85, 86, 87, 88, 89], + [90, 91, 92, 93, 94], + [95, 96, 97, 98, 99]], + [[100, 101, 102, 103, 104], + [105, 106, 107, 108, 109], + [110, 111, 112, 113, 114], + [115, 116, 117, 118, 119]]]]]).astype(np.int64) + expect3 = np.array([[[[[[0, 20, 40], + [1, 21, 41], + [2, 22, 42], + [3, 23, 43], + [4, 24, 44]], + [[5, 25, 45], + [6, 26, 46], + [7, 27, 47], + [8, 28, 48], + [9, 29, 49]], + [[10, 30, 50], + [11, 31, 51], + [12, 32, 52], + [13, 33, 53], + [14, 34, 54]], + [[15, 35, 55], + [16, 36, 56], + [17, 37, 57], + [18, 38, 58], + [19, 39, 59]]]], + + [[[[60, 80, 100], + [61, 81, 101], + [62, 82, 102], + [63, 83, 103], + [64, 84, 104]], + [[65, 85, 105], + [66, 86, 106], + [67, 87, 107], + [68, 88, 108], + [69, 89, 109]], + [[70, 90, 110], + [71, 91, 111], + [72, 92, 112], + [73, 93, 113], + [74, 94, 114]], + [[75, 95, 115], + [76, 96, 116], + [77, 97, 117], + [78, 98, 118], + [79, 99, 119]]]]]]).astype(np.int64) + assert (output[0].asnumpy() == expect0).all() + assert (output[1].asnumpy() == expect1).all() + assert (output[2].asnumpy() == expect2).all() + assert (output[3].asnumpy() == expect3).all() + + + +class Transpose_uint8(nn.Cell): + def __init__(self): + super(Transpose_uint8, self).__init__() + self.transpose = P.Transpose() + + self.x_2D = Parameter(initializer(Tensor(np.arange(5 * 6).reshape(5, 6).astype(np.uint8)), [5, 6]), + name='x_2D') + self.perm_2D = (1, 0) + + self.x_3D = Parameter(initializer(Tensor(np.arange(2 * 2 * 4).reshape(2, 2, 4).astype(np.uint8)), [2, 2, 4]), + name='x_3D') + self.perm_3D = (1, 0, 2) + + self.x_4D = Parameter( + initializer(Tensor(np.arange(2 * 3 * 4 * 5).reshape(2, + 3, 4, 5).astype(np.uint8)), [2, 3, 4, 5]), + name='x_4D') + self.perm_4D = (0, 1, 2, 3) + + self.x_5D = Parameter( + initializer(Tensor(np.arange(1 * 2 * 3 * 4 * 5).reshape(1, 2, 3, 4, 5).astype(np.uint8)), + [1, 2, 3, 4, 5]), name='x_5D') + self.perm_5D = (1, 0, 3, 4, 2) + + @jit + def construct(self): + return (self.transpose(self.x_2D, self.perm_2D), self.transpose(self.x_3D, self.perm_3D), + self.transpose(self.x_4D, self.perm_4D), self.transpose(self.x_5D, self.perm_5D)) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_transpose_uint8(): + transpose = Transpose_uint8() + output = transpose() + + expect0 = np.array([[[0, 6, 12, 18, 24], + [1, 7, 13, 19, 25], + [2, 8, 14, 20, 26], + [3, 9, 15, 21, 27], + [4, 10, 16, 22, 28], + [5, 11, 17, 23, 29]]]).astype(np.uint8) + expect1 = np.array([[[[0, 1, 2, 3], + [8, 9, 10, 11]], + [[4, 5, 6, 7], + [12, 13, 14, 15]]]]).astype(np.uint8) + expect2 = np.array([[[[[0, 1, 2, 3, 4], + [5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19]], + [[20, 21, 22, 23, 24], + [25, 26, 27, 28, 29], + [30, 31, 32, 33, 34], + [35, 36, 37, 38, 39]], + [[40, 41, 42, 43, 44], + [45, 46, 47, 48, 49], + [50, 51, 52, 53, 54], + [55, 56, 57, 58, 59]]], + + [[[60, 61, 62, 63, 64], + [65, 66, 67, 68, 69], + [70, 71, 72, 73, 74], + [75, 76, 77, 78, 79]], + [[80, 81, 82, 83, 84], + [85, 86, 87, 88, 89], + [90, 91, 92, 93, 94], + [95, 96, 97, 98, 99]], + [[100, 101, 102, 103, 104], + [105, 106, 107, 108, 109], + [110, 111, 112, 113, 114], + [115, 116, 117, 118, 119]]]]]).astype(np.uint8) + expect3 = np.array([[[[[[0, 20, 40], + [1, 21, 41], + [2, 22, 42], + [3, 23, 43], + [4, 24, 44]], + [[5, 25, 45], + [6, 26, 46], + [7, 27, 47], + [8, 28, 48], + [9, 29, 49]], + [[10, 30, 50], + [11, 31, 51], + [12, 32, 52], + [13, 33, 53], + [14, 34, 54]], + [[15, 35, 55], + [16, 36, 56], + [17, 37, 57], + [18, 38, 58], + [19, 39, 59]]]], + + [[[[60, 80, 100], + [61, 81, 101], + [62, 82, 102], + [63, 83, 103], + [64, 84, 104]], + [[65, 85, 105], + [66, 86, 106], + [67, 87, 107], + [68, 88, 108], + [69, 89, 109]], + [[70, 90, 110], + [71, 91, 111], + [72, 92, 112], + [73, 93, 113], + [74, 94, 114]], + [[75, 95, 115], + [76, 96, 116], + [77, 97, 117], + [78, 98, 118], + [79, 99, 119]]]]]]).astype(np.uint8) + assert (output[0].asnumpy() == expect0).all() + assert (output[1].asnumpy() == expect1).all() + assert (output[2].asnumpy() == expect2).all() + assert (output[3].asnumpy() == expect3).all() diff --git a/tests/st/ops/cpu/test_xdivy_op.py b/tests/st/ops/cpu/test_xdivy_op.py index 70d851f2112..9a24fc14d33 100644 --- a/tests/st/ops/cpu/test_xdivy_op.py +++ b/tests/st/ops/cpu/test_xdivy_op.py @@ -1,179 +1,179 @@ -# Copyright 2020-2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.ops import operations as P -import mindspore as ms - -TF_INSTALL_FLG = 1 -try: - import tensorflow as tf -except ImportError: - TF_INSTALL_FLG = 0 - - -class NetXDivy(nn.Cell): - def __init__(self): - super(NetXDivy, self).__init__() - self.xdivy = P.Xdivy() - - def construct(self, x, y): - return self.xdivy(x, y) - - -def xdivy(nptype): - x0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(nptype) - y0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(nptype) - x1_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(nptype) - y1_np = np.random.randint(1, 5, (2, 1, 4, 4)).astype(nptype) - x2_np = np.random.randint(1, 5, (2, 1, 1, 4)).astype(nptype) - y2_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(nptype) - x3_np = np.random.randint(1, 5, 1).astype(nptype) - y3_np = np.random.randint(1, 5, 1).astype(nptype) - x4_np = np.array(78).astype(nptype) - y4_np = np.array(37.5).astype(nptype) - - x0 = Tensor(x0_np) - y0 = Tensor(y0_np) - x1 = Tensor(x1_np) - y1 = Tensor(y1_np) - x2 = Tensor(x2_np) - y2 = Tensor(y2_np) - x3 = Tensor(x3_np) - y3 = Tensor(y3_np) - x4 = Tensor(x4_np) - y4 = Tensor(y4_np) - - context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - div_net = NetXDivy() - output0 = div_net(x0, y0) - expect0 = np.divide(x0_np, y0_np) - diff0 = output0.asnumpy() - expect0 - error0 = np.ones(shape=expect0.shape) * 1.0e-5 - assert np.all(diff0 < error0) - assert output0.shape == expect0.shape - - output1 = div_net(x1, y1) - expect1 = np.divide(x1_np, y1_np) - diff1 = output1.asnumpy() - expect1 - error1 = np.ones(shape=expect1.shape) * 1.0e-5 - assert np.all(diff1 < error1) - assert output1.shape == expect1.shape - - output2 = div_net(x2, y2) - expect2 = np.divide(x2_np, y2_np) - diff2 = output2.asnumpy() - expect2 - error2 = np.ones(shape=expect2.shape) * 1.0e-5 - assert np.all(diff2 < error2) - assert output2.shape == expect2.shape - - context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU') - output3 = div_net(x3, y3) - expect3 = np.divide(x3_np, y3_np) - diff3 = output3.asnumpy() - expect3 - error3 = np.ones(shape=expect3.shape) * 1.0e-5 - assert np.all(diff3 < error3) - assert output3.shape == expect3.shape - - output4 = div_net(x4, y4) - expect4 = np.divide(x4_np, y4_np) - diff4 = output4.asnumpy() - expect4 - error4 = np.ones(shape=expect4.shape) * 1.0e-5 - assert np.all(diff4 < error4) - assert output4.shape == expect4.shape - - -def xdivy_sf_check(mstype, tftype): - # test divided zero - tx = tf.constant([-4.0, 0.0, 1.0, 0.0], dtype=tftype) - ty = tf.constant([3.0, 2.0, 0.0, 0.0], dtype=tftype) - tz = tf.math.xdivy(tx, ty) - - x = ms.Tensor(np.array([-4.0, 0.0, 1.0, 0.0]), dtype=mstype) - y = ms.Tensor(np.array([3.0, 2.0, 0.0, 0.0]), dtype=mstype) - z = ms.ops.xdivy(x, y) - assert tz.numpy().all() == z.asnumpy().all() - - # test broadcast - tx = tf.constant([-4.0, 5.0, 0.0], dtype=tftype) - ty = tf.constant([[3.0], [2.0]], dtype=tftype) - tz = tf.math.xdivy(tx, ty) - x = ms.Tensor(np.array([-4.0, 5.0, 0.0]), dtype=mstype) - y = ms.Tensor(np.array([[3.0], [2.0]]), dtype=mstype) - z = ms.ops.xdivy(x, y) - assert tz.numpy().all() == z.asnumpy().all() - - # test broadcast - tx = tf.constant([-4.0], dtype=tftype) - ty = tf.constant([[3.0, 1.0, 1.0], [2.0, 3.0, 5.0]], dtype=tftype) - tz = tf.math.xdivy(tx, ty) - x = ms.Tensor(np.array([-4.0]), dtype=mstype) - y = ms.Tensor(np.array([[3.0, 1.0, 1.0], [2.0, 3.0, 5.0]]), dtype=mstype) - z = ms.ops.xdivy(x, y) - assert tz.numpy().all() == z.asnumpy().all() - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_xdivy_float64(): - """ - Feature: test xdivy primitive use float64 - Description: compare result with numpy&& tensorflow - Expectation: calculate result same to numpy&&tensorflow - """ - xdivy(np.float64) - if TF_INSTALL_FLG == 0: - return - context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU') - xdivy_sf_check(ms.float64, tf.dtypes.float64) - context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - xdivy_sf_check(ms.float64, tf.dtypes.float64) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_xdivy_float32(): - """ - Feature: test xdivy primitive use float32 - Description: compare result with numpy&& tensorflow - Expectation: calculate result same to numpy&&tensorflow - """ - xdivy(np.float32) - if TF_INSTALL_FLG == 0: - return - context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU') - xdivy_sf_check(ms.float32, tf.dtypes.float32) - context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - xdivy_sf_check(ms.float32, tf.dtypes.float32) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_xdivy_float16(): - """ - Feature: test xdivy primitive use float16 - Description: compare result with numpy&& tensorflow - Expectation: calculate result same to numpy&&tensorflow - """ - xdivy(np.float16) - if TF_INSTALL_FLG == 0: - return - context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU') - xdivy_sf_check(ms.float16, tf.dtypes.float16) - context.set_context(mode=context.GRAPH_MODE, device_target='CPU') - xdivy_sf_check(ms.float16, tf.dtypes.float16) +# Copyright 2020-2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P +import mindspore as ms + +TF_INSTALL_FLG = 1 +try: + import tensorflow as tf +except ImportError: + TF_INSTALL_FLG = 0 + + +class NetXDivy(nn.Cell): + def __init__(self): + super(NetXDivy, self).__init__() + self.xdivy = P.Xdivy() + + def construct(self, x, y): + return self.xdivy(x, y) + + +def xdivy(nptype): + x0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(nptype) + y0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(nptype) + x1_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(nptype) + y1_np = np.random.randint(1, 5, (2, 1, 4, 4)).astype(nptype) + x2_np = np.random.randint(1, 5, (2, 1, 1, 4)).astype(nptype) + y2_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(nptype) + x3_np = np.random.randint(1, 5, 1).astype(nptype) + y3_np = np.random.randint(1, 5, 1).astype(nptype) + x4_np = np.array(78).astype(nptype) + y4_np = np.array(37.5).astype(nptype) + + x0 = Tensor(x0_np) + y0 = Tensor(y0_np) + x1 = Tensor(x1_np) + y1 = Tensor(y1_np) + x2 = Tensor(x2_np) + y2 = Tensor(y2_np) + x3 = Tensor(x3_np) + y3 = Tensor(y3_np) + x4 = Tensor(x4_np) + y4 = Tensor(y4_np) + + context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + div_net = NetXDivy() + output0 = div_net(x0, y0) + expect0 = np.divide(x0_np, y0_np) + diff0 = output0.asnumpy() - expect0 + error0 = np.ones(shape=expect0.shape) * 1.0e-5 + assert np.all(diff0 < error0) + assert output0.shape == expect0.shape + + output1 = div_net(x1, y1) + expect1 = np.divide(x1_np, y1_np) + diff1 = output1.asnumpy() - expect1 + error1 = np.ones(shape=expect1.shape) * 1.0e-5 + assert np.all(diff1 < error1) + assert output1.shape == expect1.shape + + output2 = div_net(x2, y2) + expect2 = np.divide(x2_np, y2_np) + diff2 = output2.asnumpy() - expect2 + error2 = np.ones(shape=expect2.shape) * 1.0e-5 + assert np.all(diff2 < error2) + assert output2.shape == expect2.shape + + context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU') + output3 = div_net(x3, y3) + expect3 = np.divide(x3_np, y3_np) + diff3 = output3.asnumpy() - expect3 + error3 = np.ones(shape=expect3.shape) * 1.0e-5 + assert np.all(diff3 < error3) + assert output3.shape == expect3.shape + + output4 = div_net(x4, y4) + expect4 = np.divide(x4_np, y4_np) + diff4 = output4.asnumpy() - expect4 + error4 = np.ones(shape=expect4.shape) * 1.0e-5 + assert np.all(diff4 < error4) + assert output4.shape == expect4.shape + + +def xdivy_sf_check(mstype, tftype): + # test divided zero + tx = tf.constant([-4.0, 0.0, 1.0, 0.0], dtype=tftype) + ty = tf.constant([3.0, 2.0, 0.0, 0.0], dtype=tftype) + tz = tf.math.xdivy(tx, ty) + + x = ms.Tensor(np.array([-4.0, 0.0, 1.0, 0.0]), dtype=mstype) + y = ms.Tensor(np.array([3.0, 2.0, 0.0, 0.0]), dtype=mstype) + z = ms.ops.xdivy(x, y) + assert tz.numpy().all() == z.asnumpy().all() + + # test broadcast + tx = tf.constant([-4.0, 5.0, 0.0], dtype=tftype) + ty = tf.constant([[3.0], [2.0]], dtype=tftype) + tz = tf.math.xdivy(tx, ty) + x = ms.Tensor(np.array([-4.0, 5.0, 0.0]), dtype=mstype) + y = ms.Tensor(np.array([[3.0], [2.0]]), dtype=mstype) + z = ms.ops.xdivy(x, y) + assert tz.numpy().all() == z.asnumpy().all() + + # test broadcast + tx = tf.constant([-4.0], dtype=tftype) + ty = tf.constant([[3.0, 1.0, 1.0], [2.0, 3.0, 5.0]], dtype=tftype) + tz = tf.math.xdivy(tx, ty) + x = ms.Tensor(np.array([-4.0]), dtype=mstype) + y = ms.Tensor(np.array([[3.0, 1.0, 1.0], [2.0, 3.0, 5.0]]), dtype=mstype) + z = ms.ops.xdivy(x, y) + assert tz.numpy().all() == z.asnumpy().all() + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_xdivy_float64(): + """ + Feature: test xdivy primitive use float64 + Description: compare result with numpy&& tensorflow + Expectation: calculate result same to numpy&&tensorflow + """ + xdivy(np.float64) + if TF_INSTALL_FLG == 0: + return + context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU') + xdivy_sf_check(ms.float64, tf.dtypes.float64) + context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + xdivy_sf_check(ms.float64, tf.dtypes.float64) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_xdivy_float32(): + """ + Feature: test xdivy primitive use float32 + Description: compare result with numpy&& tensorflow + Expectation: calculate result same to numpy&&tensorflow + """ + xdivy(np.float32) + if TF_INSTALL_FLG == 0: + return + context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU') + xdivy_sf_check(ms.float32, tf.dtypes.float32) + context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + xdivy_sf_check(ms.float32, tf.dtypes.float32) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_xdivy_float16(): + """ + Feature: test xdivy primitive use float16 + Description: compare result with numpy&& tensorflow + Expectation: calculate result same to numpy&&tensorflow + """ + xdivy(np.float16) + if TF_INSTALL_FLG == 0: + return + context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU') + xdivy_sf_check(ms.float16, tf.dtypes.float16) + context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + xdivy_sf_check(ms.float16, tf.dtypes.float16) diff --git a/tests/st/ops/custom_ops_tbe/conv_layer.py b/tests/st/ops/custom_ops_tbe/conv_layer.py old mode 100755 new mode 100644 diff --git a/tests/st/ops/dynamic_sequence/test_dynamic_sequence_stack.py b/tests/st/ops/dynamic_sequence/test_dynamic_sequence_stack.py index 4cb60b22796..f100ef67026 100644 --- a/tests/st/ops/dynamic_sequence/test_dynamic_sequence_stack.py +++ b/tests/st/ops/dynamic_sequence/test_dynamic_sequence_stack.py @@ -1,74 +1,74 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark -import numpy as np -import pytest - -import mindspore.nn as nn -import mindspore.ops as ops -from mindspore import Tensor, context -from mindspore.common import mutable -from mindspore.ops.operations._sequence_ops import SequenceStack -from sequence_help import context_prepare - -context.set_context(mode=context.GRAPH_MODE) -context_prepare() - - -class NetSequenceStack(nn.Cell): - def __init__(self, axis=0): - super().__init__() - self.op = SequenceStack(axis=axis) - - def construct(self, seq): - return self.op(seq) - - -@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_seq_tensor_stack0(): - """ - Feature: test sequence stack op - Description: setitem operation on tuple type - Expectation: the behavior is matched to python style - """ - dtype = np.float32 - data_np = np.array([0] * 16).astype(dtype) - data_np = np.reshape(data_np, (2, 2, 2, 2)) - x1 = Tensor(data_np) - x2 = Tensor(np.arange(16).reshape(2, 2, 2, 2).astype(dtype)) - x = mutable((x1, x2), True) - y = ops.stack((x1, x2), 0) - net = NetSequenceStack(axis=0) - res = net(x) - assert np.all(res.asnumpy() == y.asnumpy()) - - -@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_seq_tensor_stack1(): - """ - Feature: test sequence stack op - Description: setitem operation on tuple type - Expectation: the behavior is matched to python style - """ - dtype = np.float32 - data_np = np.array([0] * 16).astype(dtype) - data_np = np.reshape(data_np, (2, 2, 2, 2)) - x1 = Tensor(data_np) - x2 = Tensor(np.arange(16).reshape(2, 2, 2, 2).astype(dtype)) - x = mutable((x1, x2), True) - y = ops.stack((x1, x2), 1) - net = NetSequenceStack(axis=1) - res = net(x) - assert np.all(res.asnumpy() == y.asnumpy()) +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark +import numpy as np +import pytest + +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor, context +from mindspore.common import mutable +from mindspore.ops.operations._sequence_ops import SequenceStack +from sequence_help import context_prepare + +context.set_context(mode=context.GRAPH_MODE) +context_prepare() + + +class NetSequenceStack(nn.Cell): + def __init__(self, axis=0): + super().__init__() + self.op = SequenceStack(axis=axis) + + def construct(self, seq): + return self.op(seq) + + +@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_seq_tensor_stack0(): + """ + Feature: test sequence stack op + Description: setitem operation on tuple type + Expectation: the behavior is matched to python style + """ + dtype = np.float32 + data_np = np.array([0] * 16).astype(dtype) + data_np = np.reshape(data_np, (2, 2, 2, 2)) + x1 = Tensor(data_np) + x2 = Tensor(np.arange(16).reshape(2, 2, 2, 2).astype(dtype)) + x = mutable((x1, x2), True) + y = ops.stack((x1, x2), 0) + net = NetSequenceStack(axis=0) + res = net(x) + assert np.all(res.asnumpy() == y.asnumpy()) + + +@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_seq_tensor_stack1(): + """ + Feature: test sequence stack op + Description: setitem operation on tuple type + Expectation: the behavior is matched to python style + """ + dtype = np.float32 + data_np = np.array([0] * 16).astype(dtype) + data_np = np.reshape(data_np, (2, 2, 2, 2)) + x1 = Tensor(data_np) + x2 = Tensor(np.arange(16).reshape(2, 2, 2, 2).astype(dtype)) + x = mutable((x1, x2), True) + y = ops.stack((x1, x2), 1) + net = NetSequenceStack(axis=1) + res = net(x) + assert np.all(res.asnumpy() == y.asnumpy()) diff --git a/tests/st/ops/gpu/test_adaptive_avg_pool3d_op.py b/tests/st/ops/gpu/test_adaptive_avg_pool3d_op.py index c537be991d9..c5a69d213e0 100644 --- a/tests/st/ops/gpu/test_adaptive_avg_pool3d_op.py +++ b/tests/st/ops/gpu/test_adaptive_avg_pool3d_op.py @@ -1,226 +1,226 @@ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import torch -from torch.nn.functional import adaptive_avg_pool3d - -import numpy as np -import pytest - -import mindspore -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor, ops -import mindspore.ops.operations.nn_ops as P -from mindspore.ops.operations import _grad_ops as G -from mindspore.common.api import jit - -context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') - - -class Net(nn.Cell): - def __init__(self, output_size): - super(Net, self).__init__() - self.adaptive_avg_pool3d = P.AdaptiveAvgPool3D(output_size) - - @jit - def construct(self, x): - return self.adaptive_avg_pool3d(x) - - -class GradNet(nn.Cell): - def __init__(self): - super(GradNet, self).__init__() - self.adaptive_avg_pool3d_grad = G.AdaptiveAvgPool3DGrad() - - @jit - def construct(self, x, dy): - return self.adaptive_avg_pool3d_grad(x, dy) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level2', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -@pytest.mark.parametrize("shape", [(1, 32, 9, 9, 9), (3, 9, 5, 4)]) -def test_net_normal_with_functional(mode, shape): - ''' - Feature: Test adaptive_avg_pool3d functional interface - Description: A randomly generated 5-dimensional matrix, Expected pooled output size - Expectation: Successfully get output with expected output size - ''' - context.set_context(mode=mode) - x = Tensor(np.random.randn(*shape).astype(np.float32)) - output_size = (3, 4, 5) - output = ops.adaptive_avg_pool3d(x, output_size) - expect_shape = shape[:-3] + output_size - assert output.asnumpy().shape == expect_shape - - output_size = 3 - output = ops.adaptive_avg_pool3d(x, output_size) - expect_shape = shape[:-3] + (output_size, output_size, output_size) - assert output.asnumpy().shape == expect_shape - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -@pytest.mark.parametrize("shape", [(1, 32, 9, 9, 9), (3, 9, 5, 4)]) -def test_net_normal_with_nn(mode, shape): - ''' - Feature: Test AdaptiveAvgPool3d nn interface - Description: A randomly generated 5-dimensional matrix, Expected pooled output size - Expectation: Successfully get output with expected output size - ''' - context.set_context(mode=mode) - x = Tensor(np.random.randn(*shape).astype(np.float32)) - output_size = (3, 4, 5) - net = nn.AdaptiveAvgPool3d(output_size) - output = net(x) - expect_shape = shape[:-3] + output_size - assert output.asnumpy().shape == expect_shape - - output_size = 3 - output = ops.adaptive_avg_pool3d(x, output_size) - expect_shape = shape[:-3] + (output_size, output_size, output_size) - assert output.asnumpy().shape == expect_shape - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_net_normal(): - ''' - Feature: If AdaptiveAvgPool3D is normal - Description: A randomly generated 5-dimensional matrix, Expected pooled output size - Expectation: Successfully get output with expected output size - ''' - x = np.random.randn(1, 32, 9, 9, 9) - net = Net((3, 4, 5)) - output = net(Tensor(x, mindspore.float32)) - expect_shape = (1, 32, 3, 4, 5) - assert output.asnumpy().shape == expect_shape - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_net_graph_mode_fp64(): - ''' - Feature: If every value type of AdaptiveAvgPool3D and AdaptiveAvgPool3DGrad are normal - Description: A 4-dimensional matrix with different types, Expected pooled output size - Expectation: Successfully get output with expected output value - ''' - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - x = np.array([[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], - [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], - [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]], - - [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], - [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], - [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]], - - [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], - [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], - [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]]]) - - adaptive_avg_pool_3d = P.AdaptiveAvgPool3D((2, 2, 2)) - output_fp16 = adaptive_avg_pool_3d(Tensor(x, mindspore.float16)) - output_fp32 = adaptive_avg_pool_3d(Tensor(x, mindspore.float32)) - output_fp64 = adaptive_avg_pool_3d(Tensor(x, mindspore.float64)) - - torchx_fp16 = torch.tensor(x, requires_grad=True, dtype=torch.half) - output_torch_fp16 = adaptive_avg_pool3d(torchx_fp16, (2, 2, 2)) - torchx_fp32 = torch.tensor(x, requires_grad=True, dtype=torch.float) - output_torch_fp32 = adaptive_avg_pool3d(torchx_fp32, (2, 2, 2)) - torchx_fp64 = torch.tensor(x, requires_grad=True, dtype=torch.double) - output_torch_fp64 = adaptive_avg_pool3d(torchx_fp64, (2, 2, 2)) - - expect_shape = (3, 2, 2, 2) - expect_output = np.array([[[[3.0, 4.0], [6.0, 7.0]], - [[3.0, 4.0], [6.0, 7.0]]], - - [[[3.0, 4.0], [6.0, 7.0]], - [[3.0, 4.0], [6.0, 7.0]]], - - [[[3.0, 4.0], [6.0, 7.0]], - [[3.0, 4.0], [6.0, 7.0]]]]) - - assert (output_fp16.asnumpy() == expect_output).all - assert output_fp32.asnumpy().shape == expect_shape - assert (output_fp32.asnumpy() == expect_output).all - assert output_fp64.asnumpy().shape == expect_shape - assert (output_fp64.asnumpy() == expect_output).all - - - assert output_torch_fp16.detach().numpy().shape == expect_shape - assert (output_fp16.asnumpy() - output_torch_fp16.detach().numpy() == 0).all - assert output_torch_fp32.detach().numpy().shape == expect_shape - assert (output_fp32.asnumpy() - output_torch_fp32.detach().numpy() == 0).all - assert output_torch_fp64.detach().numpy().shape == expect_shape - assert (output_fp64.asnumpy() - output_torch_fp64.detach().numpy() == 0).all - - expect_dx = np.array([[[[0.75, 1.75, 1.0], [2.25, 5.0, 2.75], [1.5, 3.25, 1.75]], - [[0.75, 1.75, 1.0], [2.25, 5.0, 2.75], [1.5, 3.25, 1.75]], - [[0.75, 1.75, 1.0], [2.25, 5.0, 2.75], [1.5, 3.25, 1.75]]], - - [[[0.75, 1.75, 1.0], [2.25, 5.0, 2.75], [1.5, 3.25, 1.75]], - [[0.75, 1.75, 1.0], [2.25, 5.0, 2.75], [1.5, 3.25, 1.75]], - [[0.75, 1.75, 1.0], [2.25, 5.0, 2.75], [1.5, 3.25, 1.75]]], - - [[[0.75, 1.75, 1.0], [2.25, 5.0, 2.75], [1.5, 3.25, 1.75]], - [[0.75, 1.75, 1.0], [2.25, 5.0, 2.75], [1.5, 3.25, 1.75]], - [[0.75, 1.75, 1.0], [2.25, 5.0, 2.75], [1.5, 3.25, 1.75]]]]) - grad_net = GradNet() - - dx_fp16 = grad_net(output_fp16, Tensor(np.array([3, 3, 3, 3])).astype(np.int32)) - dx_fp32 = grad_net(output_fp32, Tensor(np.array([3, 3, 3, 3])).astype(np.int32)) - dx_fp64 = grad_net(output_fp64, Tensor(np.array([3, 3, 3, 3])).astype(np.int32)) - - output_torch_fp16.backward(torch.DoubleTensor([[[[1.0, 1.0], [1.0, 1.0]], - [[1.0, 1.0], [1.0, 1.0]]], - - [[[1.0, 1.0], [1.0, 1.0]], - [[1.0, 1.0], [1.0, 1.0]]], - - [[[1.0, 1.0], [1.0, 1.0]], - [[1.0, 1.0], [1.0, 1.0]]]])) - dx_torch_fp16 = torchx_fp16.grad - output_torch_fp32.backward(torch.DoubleTensor([[[[1.0, 1.0], [1.0, 1.0]], - [[1.0, 1.0], [1.0, 1.0]]], - - [[[1.0, 1.0], [1.0, 1.0]], - [[1.0, 1.0], [1.0, 1.0]]], - - [[[1.0, 1.0], [1.0, 1.0]], - [[1.0, 1.0], [1.0, 1.0]]]])) - dx_torch_fp32 = torchx_fp32.grad - output_torch_fp64.backward(torch.DoubleTensor([[[[1.0, 1.0], [1.0, 1.0]], - [[1.0, 1.0], [1.0, 1.0]]], - - [[[1.0, 1.0], [1.0, 1.0]], - [[1.0, 1.0], [1.0, 1.0]]], - - [[[1.0, 1.0], [1.0, 1.0]], - [[1.0, 1.0], [1.0, 1.0]]]])) - dx_torch_fp64 = torchx_fp64.grad - - assert dx_fp16.asnumpy().shape == x.shape - assert (dx_fp16.asnumpy() == expect_dx).all - assert dx_fp32.asnumpy().shape == x.shape - assert (dx_fp32.asnumpy() == expect_dx).all - assert dx_fp64.asnumpy().shape == x.shape - assert (dx_fp64.asnumpy() == expect_dx).all - - assert dx_torch_fp16.detach().numpy().shape == x.shape - assert (dx_fp16.asnumpy() - dx_torch_fp16.detach().numpy() == 0).all - assert dx_torch_fp32.detach().numpy().shape == x.shape - assert (dx_fp32.asnumpy() - dx_torch_fp32.detach().numpy() == 0).all - assert dx_torch_fp64.detach().numpy().shape == x.shape - assert (dx_fp64.asnumpy() - dx_torch_fp64.detach().numpy() == 0).all +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import torch +from torch.nn.functional import adaptive_avg_pool3d + +import numpy as np +import pytest + +import mindspore +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor, ops +import mindspore.ops.operations.nn_ops as P +from mindspore.ops.operations import _grad_ops as G +from mindspore.common.api import jit + +context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + + +class Net(nn.Cell): + def __init__(self, output_size): + super(Net, self).__init__() + self.adaptive_avg_pool3d = P.AdaptiveAvgPool3D(output_size) + + @jit + def construct(self, x): + return self.adaptive_avg_pool3d(x) + + +class GradNet(nn.Cell): + def __init__(self): + super(GradNet, self).__init__() + self.adaptive_avg_pool3d_grad = G.AdaptiveAvgPool3DGrad() + + @jit + def construct(self, x, dy): + return self.adaptive_avg_pool3d_grad(x, dy) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level2', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize("shape", [(1, 32, 9, 9, 9), (3, 9, 5, 4)]) +def test_net_normal_with_functional(mode, shape): + ''' + Feature: Test adaptive_avg_pool3d functional interface + Description: A randomly generated 5-dimensional matrix, Expected pooled output size + Expectation: Successfully get output with expected output size + ''' + context.set_context(mode=mode) + x = Tensor(np.random.randn(*shape).astype(np.float32)) + output_size = (3, 4, 5) + output = ops.adaptive_avg_pool3d(x, output_size) + expect_shape = shape[:-3] + output_size + assert output.asnumpy().shape == expect_shape + + output_size = 3 + output = ops.adaptive_avg_pool3d(x, output_size) + expect_shape = shape[:-3] + (output_size, output_size, output_size) + assert output.asnumpy().shape == expect_shape + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize("shape", [(1, 32, 9, 9, 9), (3, 9, 5, 4)]) +def test_net_normal_with_nn(mode, shape): + ''' + Feature: Test AdaptiveAvgPool3d nn interface + Description: A randomly generated 5-dimensional matrix, Expected pooled output size + Expectation: Successfully get output with expected output size + ''' + context.set_context(mode=mode) + x = Tensor(np.random.randn(*shape).astype(np.float32)) + output_size = (3, 4, 5) + net = nn.AdaptiveAvgPool3d(output_size) + output = net(x) + expect_shape = shape[:-3] + output_size + assert output.asnumpy().shape == expect_shape + + output_size = 3 + output = ops.adaptive_avg_pool3d(x, output_size) + expect_shape = shape[:-3] + (output_size, output_size, output_size) + assert output.asnumpy().shape == expect_shape + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_net_normal(): + ''' + Feature: If AdaptiveAvgPool3D is normal + Description: A randomly generated 5-dimensional matrix, Expected pooled output size + Expectation: Successfully get output with expected output size + ''' + x = np.random.randn(1, 32, 9, 9, 9) + net = Net((3, 4, 5)) + output = net(Tensor(x, mindspore.float32)) + expect_shape = (1, 32, 3, 4, 5) + assert output.asnumpy().shape == expect_shape + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_net_graph_mode_fp64(): + ''' + Feature: If every value type of AdaptiveAvgPool3D and AdaptiveAvgPool3DGrad are normal + Description: A 4-dimensional matrix with different types, Expected pooled output size + Expectation: Successfully get output with expected output value + ''' + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + x = np.array([[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]], + + [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]], + + [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]]]) + + adaptive_avg_pool_3d = P.AdaptiveAvgPool3D((2, 2, 2)) + output_fp16 = adaptive_avg_pool_3d(Tensor(x, mindspore.float16)) + output_fp32 = adaptive_avg_pool_3d(Tensor(x, mindspore.float32)) + output_fp64 = adaptive_avg_pool_3d(Tensor(x, mindspore.float64)) + + torchx_fp16 = torch.tensor(x, requires_grad=True, dtype=torch.half) + output_torch_fp16 = adaptive_avg_pool3d(torchx_fp16, (2, 2, 2)) + torchx_fp32 = torch.tensor(x, requires_grad=True, dtype=torch.float) + output_torch_fp32 = adaptive_avg_pool3d(torchx_fp32, (2, 2, 2)) + torchx_fp64 = torch.tensor(x, requires_grad=True, dtype=torch.double) + output_torch_fp64 = adaptive_avg_pool3d(torchx_fp64, (2, 2, 2)) + + expect_shape = (3, 2, 2, 2) + expect_output = np.array([[[[3.0, 4.0], [6.0, 7.0]], + [[3.0, 4.0], [6.0, 7.0]]], + + [[[3.0, 4.0], [6.0, 7.0]], + [[3.0, 4.0], [6.0, 7.0]]], + + [[[3.0, 4.0], [6.0, 7.0]], + [[3.0, 4.0], [6.0, 7.0]]]]) + + assert (output_fp16.asnumpy() == expect_output).all + assert output_fp32.asnumpy().shape == expect_shape + assert (output_fp32.asnumpy() == expect_output).all + assert output_fp64.asnumpy().shape == expect_shape + assert (output_fp64.asnumpy() == expect_output).all + + + assert output_torch_fp16.detach().numpy().shape == expect_shape + assert (output_fp16.asnumpy() - output_torch_fp16.detach().numpy() == 0).all + assert output_torch_fp32.detach().numpy().shape == expect_shape + assert (output_fp32.asnumpy() - output_torch_fp32.detach().numpy() == 0).all + assert output_torch_fp64.detach().numpy().shape == expect_shape + assert (output_fp64.asnumpy() - output_torch_fp64.detach().numpy() == 0).all + + expect_dx = np.array([[[[0.75, 1.75, 1.0], [2.25, 5.0, 2.75], [1.5, 3.25, 1.75]], + [[0.75, 1.75, 1.0], [2.25, 5.0, 2.75], [1.5, 3.25, 1.75]], + [[0.75, 1.75, 1.0], [2.25, 5.0, 2.75], [1.5, 3.25, 1.75]]], + + [[[0.75, 1.75, 1.0], [2.25, 5.0, 2.75], [1.5, 3.25, 1.75]], + [[0.75, 1.75, 1.0], [2.25, 5.0, 2.75], [1.5, 3.25, 1.75]], + [[0.75, 1.75, 1.0], [2.25, 5.0, 2.75], [1.5, 3.25, 1.75]]], + + [[[0.75, 1.75, 1.0], [2.25, 5.0, 2.75], [1.5, 3.25, 1.75]], + [[0.75, 1.75, 1.0], [2.25, 5.0, 2.75], [1.5, 3.25, 1.75]], + [[0.75, 1.75, 1.0], [2.25, 5.0, 2.75], [1.5, 3.25, 1.75]]]]) + grad_net = GradNet() + + dx_fp16 = grad_net(output_fp16, Tensor(np.array([3, 3, 3, 3])).astype(np.int32)) + dx_fp32 = grad_net(output_fp32, Tensor(np.array([3, 3, 3, 3])).astype(np.int32)) + dx_fp64 = grad_net(output_fp64, Tensor(np.array([3, 3, 3, 3])).astype(np.int32)) + + output_torch_fp16.backward(torch.DoubleTensor([[[[1.0, 1.0], [1.0, 1.0]], + [[1.0, 1.0], [1.0, 1.0]]], + + [[[1.0, 1.0], [1.0, 1.0]], + [[1.0, 1.0], [1.0, 1.0]]], + + [[[1.0, 1.0], [1.0, 1.0]], + [[1.0, 1.0], [1.0, 1.0]]]])) + dx_torch_fp16 = torchx_fp16.grad + output_torch_fp32.backward(torch.DoubleTensor([[[[1.0, 1.0], [1.0, 1.0]], + [[1.0, 1.0], [1.0, 1.0]]], + + [[[1.0, 1.0], [1.0, 1.0]], + [[1.0, 1.0], [1.0, 1.0]]], + + [[[1.0, 1.0], [1.0, 1.0]], + [[1.0, 1.0], [1.0, 1.0]]]])) + dx_torch_fp32 = torchx_fp32.grad + output_torch_fp64.backward(torch.DoubleTensor([[[[1.0, 1.0], [1.0, 1.0]], + [[1.0, 1.0], [1.0, 1.0]]], + + [[[1.0, 1.0], [1.0, 1.0]], + [[1.0, 1.0], [1.0, 1.0]]], + + [[[1.0, 1.0], [1.0, 1.0]], + [[1.0, 1.0], [1.0, 1.0]]]])) + dx_torch_fp64 = torchx_fp64.grad + + assert dx_fp16.asnumpy().shape == x.shape + assert (dx_fp16.asnumpy() == expect_dx).all + assert dx_fp32.asnumpy().shape == x.shape + assert (dx_fp32.asnumpy() == expect_dx).all + assert dx_fp64.asnumpy().shape == x.shape + assert (dx_fp64.asnumpy() == expect_dx).all + + assert dx_torch_fp16.detach().numpy().shape == x.shape + assert (dx_fp16.asnumpy() - dx_torch_fp16.detach().numpy() == 0).all + assert dx_torch_fp32.detach().numpy().shape == x.shape + assert (dx_fp32.asnumpy() - dx_torch_fp32.detach().numpy() == 0).all + assert dx_torch_fp64.detach().numpy().shape == x.shape + assert (dx_fp64.asnumpy() - dx_torch_fp64.detach().numpy() == 0).all diff --git a/tests/st/ops/gpu/test_apply_adagrad_a_d_op.py b/tests/st/ops/gpu/test_apply_adagrad_a_d_op.py old mode 100755 new mode 100644 diff --git a/tests/st/ops/gpu/test_binary_cross_entropy_op.py b/tests/st/ops/gpu/test_binary_cross_entropy_op.py index 2064f68947a..0435556b19d 100644 --- a/tests/st/ops/gpu/test_binary_cross_entropy_op.py +++ b/tests/st/ops/gpu/test_binary_cross_entropy_op.py @@ -1,119 +1,119 @@ -# Copyright 2020-2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.ops import composite as C -from mindspore.ops import operations as P -from mindspore.ops import functional as F - -context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - - -class Net(nn.Cell): - def __init__(self, reduction="none"): - super(Net, self).__init__() - self.bce = P.BinaryCrossEntropy(reduction) - - def construct(self, x, y, weight=None): - return self.bce(x, y, weight) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_binary_cross_entropy_loss(): - np.random.seed(42) - prediction = np.random.rand(20).astype(np.float32) - target = np.random.rand(20).astype(np.float32) - weight = np.random.rand(20).astype(np.float32) - net = Net() - loss = net(Tensor(prediction), Tensor(target), Tensor(weight)) - expect = [0.09555826, 1.2861121, 0.03518666, 0.6969416, 0.24313456, 0.99062896, - 0.19205657, 0.5465214, 0.36964455, 0.21999404, 2.2953863, 2.2566645, - 1.5803775, 1.3266402, 0.9883408, 1.2997618, 0.05439841, 0.14389999, - 0.03405444, 0.23934692] - assert np.allclose(loss.asnumpy(), expect) - - -def test_binary_cross_entropy_loss_sum_without_weight(): - np.random.seed(42) - prediction = np.random.rand(20).astype(np.float32) - target = np.random.rand(20).astype(np.float32) - reduction = "sum" - net = Net(reduction) - loss = net(Tensor(prediction), Tensor(target)) - expect = [25.48195216753522] - assert np.allclose(loss.asnumpy(), expect) - - -class Grad(nn.Cell): - def __init__(self, network): - super(Grad, self).__init__() - self.grad = C.GradOperation(get_all=True, sens_param=True) - self.network = network - - def construct(self, x1, x2, sens, weight=None): - gout = self.grad(self.network)(x1, x2, sens, weight) - return gout - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_binary_cross_entropy_loss_grad(): - np.random.seed(42) - prediction = np.random.rand(20).astype(np.float32) - target = np.random.rand(20).astype(np.float32) - sens = np.random.rand(20).astype(np.float32) - weight = np.random.rand(20).astype(np.float32) - grad = Grad(Net()) - dx = grad(Tensor(prediction), Tensor(target), Tensor(sens), Tensor(weight)) - - dx1_expect = [-4.80516590e-02, 2.32625079e+00, 6.38972521e-02, 3.13642323e-01, - -1.65661633e-01, -1.71821892e+00, -1.13685496e-01, 1.26669514e+00, - 1.47891801e-03, 5.83921909e-01, -2.17992840e+01, 4.21899414e+00, - 2.85430793e-02, -3.21346498e+00, -2.22674108e+00, -2.80453944e+00, - -1.19787852e-04, 2.48514321e-02, -1.66696273e-02, -2.71965731e-02] - - assert np.allclose(dx[0].asnumpy(), dx1_expect) - - -def test_binary_cross_entropy_forward_functional(nptype): - """ - Feature: test binary_cross_entropy forward for given input dtype. - Description: test inputs for given input dtype. - Expectation: the result match with expected result. - """ - logits = Tensor(np.array([0.2, 0.7, 0.1]).astype(nptype)) - labels = Tensor(np.array([0., 1., 0.]).astype(nptype)) - weight = Tensor(np.array([1, 2, 2]).astype(nptype)) - output = F.binary_cross_entropy(logits, labels, weight) - expected = Tensor(np.array([0.38240486]).astype(nptype)) - np.testing.assert_array_almost_equal(output.asnumpy(), expected) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_binary_cross_entropy_forward_float32_functional(): - """ - Feature: test binary_cross_entropy forward. - Description: test float32 inputs. - Expectation: the result match with expected result. - """ - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - test_binary_cross_entropy_forward_functional(np.float32) - context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") - test_binary_cross_entropy_forward_functional(np.float32) +# Copyright 2020-2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import composite as C +from mindspore.ops import operations as P +from mindspore.ops import functional as F + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + +class Net(nn.Cell): + def __init__(self, reduction="none"): + super(Net, self).__init__() + self.bce = P.BinaryCrossEntropy(reduction) + + def construct(self, x, y, weight=None): + return self.bce(x, y, weight) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_binary_cross_entropy_loss(): + np.random.seed(42) + prediction = np.random.rand(20).astype(np.float32) + target = np.random.rand(20).astype(np.float32) + weight = np.random.rand(20).astype(np.float32) + net = Net() + loss = net(Tensor(prediction), Tensor(target), Tensor(weight)) + expect = [0.09555826, 1.2861121, 0.03518666, 0.6969416, 0.24313456, 0.99062896, + 0.19205657, 0.5465214, 0.36964455, 0.21999404, 2.2953863, 2.2566645, + 1.5803775, 1.3266402, 0.9883408, 1.2997618, 0.05439841, 0.14389999, + 0.03405444, 0.23934692] + assert np.allclose(loss.asnumpy(), expect) + + +def test_binary_cross_entropy_loss_sum_without_weight(): + np.random.seed(42) + prediction = np.random.rand(20).astype(np.float32) + target = np.random.rand(20).astype(np.float32) + reduction = "sum" + net = Net(reduction) + loss = net(Tensor(prediction), Tensor(target)) + expect = [25.48195216753522] + assert np.allclose(loss.asnumpy(), expect) + + +class Grad(nn.Cell): + def __init__(self, network): + super(Grad, self).__init__() + self.grad = C.GradOperation(get_all=True, sens_param=True) + self.network = network + + def construct(self, x1, x2, sens, weight=None): + gout = self.grad(self.network)(x1, x2, sens, weight) + return gout + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_binary_cross_entropy_loss_grad(): + np.random.seed(42) + prediction = np.random.rand(20).astype(np.float32) + target = np.random.rand(20).astype(np.float32) + sens = np.random.rand(20).astype(np.float32) + weight = np.random.rand(20).astype(np.float32) + grad = Grad(Net()) + dx = grad(Tensor(prediction), Tensor(target), Tensor(sens), Tensor(weight)) + + dx1_expect = [-4.80516590e-02, 2.32625079e+00, 6.38972521e-02, 3.13642323e-01, + -1.65661633e-01, -1.71821892e+00, -1.13685496e-01, 1.26669514e+00, + 1.47891801e-03, 5.83921909e-01, -2.17992840e+01, 4.21899414e+00, + 2.85430793e-02, -3.21346498e+00, -2.22674108e+00, -2.80453944e+00, + -1.19787852e-04, 2.48514321e-02, -1.66696273e-02, -2.71965731e-02] + + assert np.allclose(dx[0].asnumpy(), dx1_expect) + + +def test_binary_cross_entropy_forward_functional(nptype): + """ + Feature: test binary_cross_entropy forward for given input dtype. + Description: test inputs for given input dtype. + Expectation: the result match with expected result. + """ + logits = Tensor(np.array([0.2, 0.7, 0.1]).astype(nptype)) + labels = Tensor(np.array([0., 1., 0.]).astype(nptype)) + weight = Tensor(np.array([1, 2, 2]).astype(nptype)) + output = F.binary_cross_entropy(logits, labels, weight) + expected = Tensor(np.array([0.38240486]).astype(nptype)) + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_binary_cross_entropy_forward_float32_functional(): + """ + Feature: test binary_cross_entropy forward. + Description: test float32 inputs. + Expectation: the result match with expected result. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + test_binary_cross_entropy_forward_functional(np.float32) + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + test_binary_cross_entropy_forward_functional(np.float32) diff --git a/tests/st/ops/gpu/test_blackman_window_op.py b/tests/st/ops/gpu/test_blackman_window_op.py index 346482a0a84..073acc41e3c 100644 --- a/tests/st/ops/gpu/test_blackman_window_op.py +++ b/tests/st/ops/gpu/test_blackman_window_op.py @@ -1,124 +1,124 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark -import numpy as np -import torch -import pytest -import mindspore.context as context -import mindspore.nn as nn -from mindspore.ops import functional as F -import mindspore.ops.operations.spectral_ops as P -from mindspore import Tensor -from mindspore.common import dtype as mstype -from mindspore.common.api import jit - - -class BlackmanWindowNet(nn.Cell): - def __init__(self, periodic=True, dtype=mstype.float32): - super(BlackmanWindowNet, self).__init__() - self.blackmanwindow = P.BlackmanWindow(periodic=periodic, dtype=dtype) - - @jit - def construct(self, input_x): - return self.blackmanwindow(input_x) - - -def get_dtype(dtype="float16"): - if dtype == "float16": - nptype = np.float16 - msptype = mstype.float16 - pttype = torch.float32 - elif dtype == "float32": - nptype = np.float32 - msptype = mstype.float32 - pttype = torch.float32 - elif dtype == "float64": - nptype = np.float64 - msptype = mstype.float64 - pttype = torch.float64 - else: - print("The attr 'dtype' must in [float16, float32, float64]") - return nptype, msptype, pttype - - -def blackman_window(periodic, dtype, loss): - context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - nptype, msptype, pttype = get_dtype(dtype) - input_x_np = np.array(200, dtype=np.int32) - input_x_ms = Tensor(input_x_np) - input_x_torch = torch.tensor(input_x_np) - blackman_window_net = BlackmanWindowNet(periodic, msptype) - blackman_window_output = blackman_window_net(input_x_ms) - blackman_window_expect = torch.blackman_window(input_x_torch, periodic=periodic, dtype=pttype) - assert np.allclose(blackman_window_output.asnumpy(), blackman_window_expect.numpy().astype(nptype), loss, loss) - - -def blackman_window_pynative(periodic, dtype, loss): - context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') - nptype, msptype, pttype = get_dtype(dtype) - input_x_np = np.array(200, dtype=np.int64) - input_x_ms = Tensor(input_x_np) - input_x_torch = torch.tensor(input_x_np) - blackman_window_net = BlackmanWindowNet(periodic, msptype) - blackman_window_output = blackman_window_net(input_x_ms) - blackman_window_expect = torch.blackman_window(input_x_torch, periodic=periodic, dtype=pttype) - assert np.allclose(blackman_window_output.asnumpy(), blackman_window_expect.numpy().astype(nptype), loss, loss) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_blackman_window_graph_int32_true_float32(): - """ - Feature: ALL To ALL - Description: test cases for BlackmanWindow - Expectation: the result match to torch - """ - blackman_window(periodic=True, dtype="float32", loss=1.0e-4) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_blackman_window_pynative_int64_false_float64(): - """ - Feature: ALL To ALL - Description: test cases for BlackmanWindow - Expectation: the result match to torch - """ - blackman_window_pynative(periodic=False, dtype="float64", loss=1.0e-5) - - -def test_blackman_window_functional(): - """ - Feature: test blackman_window functional API. - Description: test case for blackman_window functional API. - Expectation: the result match with expected result. - """ - window_length = Tensor(10, mstype.int32) - output = F.blackman_window(window_length, periodic=True, dtype=mstype.float32) - expected = np.array([-2.9802322e-08, 4.0212840e-02, 2.0077014e-01, 5.0978714e-01, - 8.4922993e-01, 1.0000000e+00, 8.4922981e-01, 5.0978690e-01, - 2.0077008e-01, 4.0212870e-02]).astype(np.float32) - np.testing.assert_array_almost_equal(output.asnumpy(), expected) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_blackman_window_functional_modes(): - """ - Feature: test blackman_window functional API in PyNative and Graph modes. - Description: test case for blackman_window functional API. - Expectation: the result match with expected result. - """ - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - test_blackman_window_functional() - context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") - test_blackman_window_functional() +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark +import numpy as np +import torch +import pytest +import mindspore.context as context +import mindspore.nn as nn +from mindspore.ops import functional as F +import mindspore.ops.operations.spectral_ops as P +from mindspore import Tensor +from mindspore.common import dtype as mstype +from mindspore.common.api import jit + + +class BlackmanWindowNet(nn.Cell): + def __init__(self, periodic=True, dtype=mstype.float32): + super(BlackmanWindowNet, self).__init__() + self.blackmanwindow = P.BlackmanWindow(periodic=periodic, dtype=dtype) + + @jit + def construct(self, input_x): + return self.blackmanwindow(input_x) + + +def get_dtype(dtype="float16"): + if dtype == "float16": + nptype = np.float16 + msptype = mstype.float16 + pttype = torch.float32 + elif dtype == "float32": + nptype = np.float32 + msptype = mstype.float32 + pttype = torch.float32 + elif dtype == "float64": + nptype = np.float64 + msptype = mstype.float64 + pttype = torch.float64 + else: + print("The attr 'dtype' must in [float16, float32, float64]") + return nptype, msptype, pttype + + +def blackman_window(periodic, dtype, loss): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + nptype, msptype, pttype = get_dtype(dtype) + input_x_np = np.array(200, dtype=np.int32) + input_x_ms = Tensor(input_x_np) + input_x_torch = torch.tensor(input_x_np) + blackman_window_net = BlackmanWindowNet(periodic, msptype) + blackman_window_output = blackman_window_net(input_x_ms) + blackman_window_expect = torch.blackman_window(input_x_torch, periodic=periodic, dtype=pttype) + assert np.allclose(blackman_window_output.asnumpy(), blackman_window_expect.numpy().astype(nptype), loss, loss) + + +def blackman_window_pynative(periodic, dtype, loss): + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + nptype, msptype, pttype = get_dtype(dtype) + input_x_np = np.array(200, dtype=np.int64) + input_x_ms = Tensor(input_x_np) + input_x_torch = torch.tensor(input_x_np) + blackman_window_net = BlackmanWindowNet(periodic, msptype) + blackman_window_output = blackman_window_net(input_x_ms) + blackman_window_expect = torch.blackman_window(input_x_torch, periodic=periodic, dtype=pttype) + assert np.allclose(blackman_window_output.asnumpy(), blackman_window_expect.numpy().astype(nptype), loss, loss) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_blackman_window_graph_int32_true_float32(): + """ + Feature: ALL To ALL + Description: test cases for BlackmanWindow + Expectation: the result match to torch + """ + blackman_window(periodic=True, dtype="float32", loss=1.0e-4) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_blackman_window_pynative_int64_false_float64(): + """ + Feature: ALL To ALL + Description: test cases for BlackmanWindow + Expectation: the result match to torch + """ + blackman_window_pynative(periodic=False, dtype="float64", loss=1.0e-5) + + +def test_blackman_window_functional(): + """ + Feature: test blackman_window functional API. + Description: test case for blackman_window functional API. + Expectation: the result match with expected result. + """ + window_length = Tensor(10, mstype.int32) + output = F.blackman_window(window_length, periodic=True, dtype=mstype.float32) + expected = np.array([-2.9802322e-08, 4.0212840e-02, 2.0077014e-01, 5.0978714e-01, + 8.4922993e-01, 1.0000000e+00, 8.4922981e-01, 5.0978690e-01, + 2.0077008e-01, 4.0212870e-02]).astype(np.float32) + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_blackman_window_functional_modes(): + """ + Feature: test blackman_window functional API in PyNative and Graph modes. + Description: test case for blackman_window functional API. + Expectation: the result match with expected result. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + test_blackman_window_functional() + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + test_blackman_window_functional() diff --git a/tests/st/ops/gpu/test_check_numerics_op.py b/tests/st/ops/gpu/test_check_numerics_op.py index 5712e78dc70..b8d2530b1d9 100644 --- a/tests/st/ops/gpu/test_check_numerics_op.py +++ b/tests/st/ops/gpu/test_check_numerics_op.py @@ -1,108 +1,108 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark -import numpy as np -import pytest -import mindspore as ms -import mindspore.context as context -import mindspore.nn as nn -import mindspore.ops.operations.array_ops as P -from mindspore import Tensor, jit - - -class CheckNumericsNet(nn.Cell): - - def __init__(self): - super(CheckNumericsNet, self).__init__() - self.checknumerics = P.CheckNumerics() - - @jit - def construct(self, input_x): - return self.checknumerics(input_x) - - -def check_numerics(loss): - context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - input_x_np = np.array([[6, 3], [2, 5]], dtype=np.float32) - input_x_ms = Tensor(input_x_np) - check_numerics_net = CheckNumericsNet() - check_numerics_output = check_numerics_net(input_x_ms) - check_numerics_expect = np.array([[6, 3], [2, 5]], dtype=np.float32) - assert np.allclose(check_numerics_output.asnumpy(), check_numerics_expect, - loss, loss) - - -def check_numerics_pynative(loss): - context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') - input_x_np = np.array([[1, 5], [2, 4]], dtype=np.float64) - input_x_ms = Tensor(input_x_np) - check_numerics_net = CheckNumericsNet() - check_numerics_output = check_numerics_net(input_x_ms) - check_numerics_expect = np.array([[1, 5], [2, 4]], dtype=np.float64) - print(check_numerics_output) - print(check_numerics_expect) - assert np.allclose(check_numerics_output.asnumpy(), check_numerics_expect, - loss, loss) - - -def dyn_case(): - net = CheckNumericsNet() - - x_dyn = Tensor(shape=[None, None], dtype=ms.float64) - net.set_inputs(x_dyn) - - x = Tensor( - np.array([[0.42987306, 0.02847828, 0.59385591, 0.7040952, 0.27390435], - [0.32904094, 0.63063352, 0.70752448, 0.24763578, 0.99662956], - [0.66478424, 0.70580542, 0.92749155, 0.72736302, 0.24973136], - [0.79918445, 0.68613469, 0.9526593, 0.12412648, - 0.15175918]]).astype(np.float64)) - out = net(x) - - expect_shape = (4, 5) - assert out.asnumpy().shape == expect_shape - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_check_numerics_dyn(): - """ - Feature: test CheckNumerics ops in gpu. - Description: Test the ops in dynamic shape. - Expectation: expect correct shape result. - """ - context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - dyn_case() - context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') - dyn_case() - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_check_numerics_graph_float32(): - """ - Feature: ALL To ALL - Description: test cases for CheckNumerics - Expectation: the result match to tensorflow - """ - check_numerics(loss=1.0e-4) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_check_numerics_pynative_float64(): - """ - Feature: ALL To ALL - Description: test cases for CheckNumerics - Expectation: the result match to tensorflow - """ - check_numerics_pynative(loss=1.0e-5) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark +import numpy as np +import pytest +import mindspore as ms +import mindspore.context as context +import mindspore.nn as nn +import mindspore.ops.operations.array_ops as P +from mindspore import Tensor, jit + + +class CheckNumericsNet(nn.Cell): + + def __init__(self): + super(CheckNumericsNet, self).__init__() + self.checknumerics = P.CheckNumerics() + + @jit + def construct(self, input_x): + return self.checknumerics(input_x) + + +def check_numerics(loss): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + input_x_np = np.array([[6, 3], [2, 5]], dtype=np.float32) + input_x_ms = Tensor(input_x_np) + check_numerics_net = CheckNumericsNet() + check_numerics_output = check_numerics_net(input_x_ms) + check_numerics_expect = np.array([[6, 3], [2, 5]], dtype=np.float32) + assert np.allclose(check_numerics_output.asnumpy(), check_numerics_expect, + loss, loss) + + +def check_numerics_pynative(loss): + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + input_x_np = np.array([[1, 5], [2, 4]], dtype=np.float64) + input_x_ms = Tensor(input_x_np) + check_numerics_net = CheckNumericsNet() + check_numerics_output = check_numerics_net(input_x_ms) + check_numerics_expect = np.array([[1, 5], [2, 4]], dtype=np.float64) + print(check_numerics_output) + print(check_numerics_expect) + assert np.allclose(check_numerics_output.asnumpy(), check_numerics_expect, + loss, loss) + + +def dyn_case(): + net = CheckNumericsNet() + + x_dyn = Tensor(shape=[None, None], dtype=ms.float64) + net.set_inputs(x_dyn) + + x = Tensor( + np.array([[0.42987306, 0.02847828, 0.59385591, 0.7040952, 0.27390435], + [0.32904094, 0.63063352, 0.70752448, 0.24763578, 0.99662956], + [0.66478424, 0.70580542, 0.92749155, 0.72736302, 0.24973136], + [0.79918445, 0.68613469, 0.9526593, 0.12412648, + 0.15175918]]).astype(np.float64)) + out = net(x) + + expect_shape = (4, 5) + assert out.asnumpy().shape == expect_shape + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_check_numerics_dyn(): + """ + Feature: test CheckNumerics ops in gpu. + Description: Test the ops in dynamic shape. + Expectation: expect correct shape result. + """ + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + dyn_case() + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + dyn_case() + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_check_numerics_graph_float32(): + """ + Feature: ALL To ALL + Description: test cases for CheckNumerics + Expectation: the result match to tensorflow + """ + check_numerics(loss=1.0e-4) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_check_numerics_pynative_float64(): + """ + Feature: ALL To ALL + Description: test cases for CheckNumerics + Expectation: the result match to tensorflow + """ + check_numerics_pynative(loss=1.0e-5) diff --git a/tests/st/ops/gpu/test_col2im_op.py b/tests/st/ops/gpu/test_col2im_op.py index 386222b5821..9f5cf9d5472 100644 --- a/tests/st/ops/gpu/test_col2im_op.py +++ b/tests/st/ops/gpu/test_col2im_op.py @@ -1,114 +1,114 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore.context as context -import mindspore.nn as nn -from mindspore import dtype as mstype -from mindspore import Tensor -from mindspore.ops.operations.array_ops import Col2Im -from mindspore.ops import functional as F -from mindspore.common import dtype as mstype - -np.random.seed(1) - - -class Col2ImTest(nn.Cell): - - def __init__(self, kernel_size, dilation, padding, stride): - super(Col2ImTest, self).__init__() - self.c2i = Col2Im(kernel_size, dilation, padding, stride) - - def construct(self, x, output_size): - return self.c2i(x, output_size) - - -def dyn_case(): - kernel_size = [2, 2] - dilation = [2, 2] - padding = [2, 2] - stride = [2, 2] - col2im = Col2ImTest(kernel_size=kernel_size, - dilation=dilation, - padding=padding, - stride=stride) - - x_dyn = Tensor(shape=[None, None, None, None], dtype=mstype.float32) - output_size_dyn = Tensor(shape=[None], dtype=mstype.int32) - col2im.set_inputs(x_dyn, output_size_dyn) - - x = Tensor(np.random.rand(16, 16, 4, 25).astype(np.float32)) - output_size = Tensor([8, 8], dtype=mstype.int32) - output = col2im(x, output_size) - - expect_shape = (16, 16, 8, 8) - assert output.shape == expect_shape - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_col2im_dyn(): - """ - Feature: Col2Im function. - Description: test the ops in dynamic shape. - Expectation: expect correct shape result. - """ - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - dyn_case() - context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") - dyn_case() - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize("mode, input_type", - [(context.GRAPH_MODE, np.float32), - (context.PYNATIVE_MODE, np.float32), - (context.GRAPH_MODE, np.float16), - (context.PYNATIVE_MODE, np.float16), - (context.GRAPH_MODE, np.float64), - (context.PYNATIVE_MODE, np.float64), - (context.GRAPH_MODE, np.complex64), - (context.PYNATIVE_MODE, np.complex64), - (context.GRAPH_MODE, np.complex128), - (context.PYNATIVE_MODE, np.complex128)]) -def test_col2im_op(mode, input_type): - """ - Feature: Celu cpu kernel - Description: test the celu alpha = 1.0. - Expectation: match to np benchmark. - """ - context.set_context(mode=mode, device_target='GPU') - x = Tensor(np.random.rand(16, 16, 4, 25).astype(input_type)) - output_size = Tensor([8, 8], dtype=mstype.int32) - kernel_size = [2, 2] - dilation = [2, 2] - padding = [2, 2] - stride = [2, 2] - expect_shape = (16, 16, 8, 8) - col2im = Col2ImTest(kernel_size=kernel_size, - dilation=dilation, - padding=padding, - stride=stride) - output = col2im(x, output_size) - assert output.shape == expect_shape - - output_func = F.col2im(x, output_size, kernel_size, dilation, padding, - stride) - assert output_func.shape == expect_shape - - assert x.col2im(output_size, kernel_size, dilation, padding, - stride).shape == expect_shape +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import dtype as mstype +from mindspore import Tensor +from mindspore.ops.operations.array_ops import Col2Im +from mindspore.ops import functional as F +from mindspore.common import dtype as mstype + +np.random.seed(1) + + +class Col2ImTest(nn.Cell): + + def __init__(self, kernel_size, dilation, padding, stride): + super(Col2ImTest, self).__init__() + self.c2i = Col2Im(kernel_size, dilation, padding, stride) + + def construct(self, x, output_size): + return self.c2i(x, output_size) + + +def dyn_case(): + kernel_size = [2, 2] + dilation = [2, 2] + padding = [2, 2] + stride = [2, 2] + col2im = Col2ImTest(kernel_size=kernel_size, + dilation=dilation, + padding=padding, + stride=stride) + + x_dyn = Tensor(shape=[None, None, None, None], dtype=mstype.float32) + output_size_dyn = Tensor(shape=[None], dtype=mstype.int32) + col2im.set_inputs(x_dyn, output_size_dyn) + + x = Tensor(np.random.rand(16, 16, 4, 25).astype(np.float32)) + output_size = Tensor([8, 8], dtype=mstype.int32) + output = col2im(x, output_size) + + expect_shape = (16, 16, 8, 8) + assert output.shape == expect_shape + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_col2im_dyn(): + """ + Feature: Col2Im function. + Description: test the ops in dynamic shape. + Expectation: expect correct shape result. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + dyn_case() + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + dyn_case() + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize("mode, input_type", + [(context.GRAPH_MODE, np.float32), + (context.PYNATIVE_MODE, np.float32), + (context.GRAPH_MODE, np.float16), + (context.PYNATIVE_MODE, np.float16), + (context.GRAPH_MODE, np.float64), + (context.PYNATIVE_MODE, np.float64), + (context.GRAPH_MODE, np.complex64), + (context.PYNATIVE_MODE, np.complex64), + (context.GRAPH_MODE, np.complex128), + (context.PYNATIVE_MODE, np.complex128)]) +def test_col2im_op(mode, input_type): + """ + Feature: Celu cpu kernel + Description: test the celu alpha = 1.0. + Expectation: match to np benchmark. + """ + context.set_context(mode=mode, device_target='GPU') + x = Tensor(np.random.rand(16, 16, 4, 25).astype(input_type)) + output_size = Tensor([8, 8], dtype=mstype.int32) + kernel_size = [2, 2] + dilation = [2, 2] + padding = [2, 2] + stride = [2, 2] + expect_shape = (16, 16, 8, 8) + col2im = Col2ImTest(kernel_size=kernel_size, + dilation=dilation, + padding=padding, + stride=stride) + output = col2im(x, output_size) + assert output.shape == expect_shape + + output_func = F.col2im(x, output_size, kernel_size, dilation, padding, + stride) + assert output_func.shape == expect_shape + + assert x.col2im(output_size, kernel_size, dilation, padding, + stride).shape == expect_shape diff --git a/tests/st/ops/gpu/test_cumprod_op.py b/tests/st/ops/gpu/test_cumprod_op.py index f974c14b801..3da4e0e6ccc 100644 --- a/tests/st/ops/gpu/test_cumprod_op.py +++ b/tests/st/ops/gpu/test_cumprod_op.py @@ -1,224 +1,224 @@ -# Copyright 2020-2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore.nn as nn -from mindspore import dtype -from mindspore import Tensor -import mindspore.context as context -from mindspore.ops import functional as F -from mindspore.ops import operations as P -from mindspore.common.api import jit - - -def cum_prod(nptype): - context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') - x0 = np.random.rand(2, 3, 4, 4).astype(nptype) - axis0 = 3 - - x1 = np.random.rand(2, 3, 4, 4).astype(nptype) - axis1 = 3 - - x2 = np.random.rand(2, 3, 1, 4).astype(nptype) - axis2 = 2 - - x3 = np.random.rand(2, 3, 1, 4).astype(nptype) - axis3 = 2 - - x4 = np.random.rand(2, 3, 4, 4).astype(nptype) - axis4 = 1 - - x5 = np.random.rand(2, 3).astype(nptype) - axis5 = 1 - - x6 = np.random.rand(1, 1, 1, 1).astype(nptype) - axis6 = 0 - - class CumProd(nn.Cell): - def __init__(self, nptype): - super(CumProd, self).__init__() - - self.x0 = Tensor(x0) - self.axis0 = axis0 - - self.x1 = Tensor(x1) - self.axis1 = axis1 - - self.x2 = Tensor(x2) - self.axis2 = axis2 - - self.x3 = Tensor(x3) - self.axis3 = axis3 - - self.x4 = Tensor(x4) - self.axis4 = axis4 - - self.x5 = Tensor(x5) - self.axis5 = axis5 - - self.x6 = Tensor(x6) - self.axis6 = axis6 - - @jit - def construct(self): - output = (P.CumProd()(self.x0, self.axis0), - P.CumProd()(self.x1, self.axis1), - P.CumProd()(self.x2, self.axis2), - P.CumProd()(self.x3, self.axis3), - P.CumProd()(self.x4, self.axis4), - P.CumProd()(self.x5, self.axis5), - P.CumProd()(self.x6, self.axis6)) - return output - - cumprod = CumProd(nptype) - output = cumprod() - - expect0 = np.cumprod(x0, axis=axis0) - diff0 = abs(output[0].asnumpy() - expect0) - error0 = np.ones(shape=expect0.shape) * 1.0e-5 - assert np.all(diff0 < error0) - assert output[0].shape == expect0.shape - - expect1 = np.cumprod(x1, axis=axis1) - diff1 = abs(output[1].asnumpy() - expect1) - error1 = np.ones(shape=expect1.shape) * 1.0e-5 - assert np.all(diff1 < error1) - assert output[1].shape == expect1.shape - - expect2 = np.cumprod(x2, axis=axis2) - diff2 = abs(output[2].asnumpy() - expect2) - error2 = np.ones(shape=expect2.shape) * 1.0e-5 - assert np.all(diff2 < error2) - assert output[2].shape == expect2.shape - - expect3 = np.cumprod(x3, axis=axis3) - diff3 = abs(output[3].asnumpy() - expect3) - error3 = np.ones(shape=expect3.shape) * 1.0e-5 - assert np.all(diff3 < error3) - assert output[3].shape == expect3.shape - - expect4 = np.cumprod(x4, axis=axis4) - diff4 = abs(output[4].asnumpy() - expect4) - error4 = np.ones(shape=expect4.shape) * 1.0e-5 - assert np.all(diff4 < error4) - assert output[4].shape == expect4.shape - - expect5 = np.cumprod(x5, axis=axis5) - diff5 = abs(output[5].asnumpy() - expect5) - error5 = np.ones(shape=expect5.shape) * 1.0e-5 - assert np.all(diff5 < error5) - assert output[5].shape == expect5.shape - - expect6 = np.cumprod(x6, axis=axis6) - diff6 = abs(output[6].asnumpy() - expect6) - error6 = np.ones(shape=expect6.shape) * 1.0e-5 - assert np.all(diff6 < error6) - assert output[6].shape == expect6.shape - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_cum_prod_uint8(): - cum_prod(np.uint8) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_cum_prod_int8(): - cum_prod(np.int8) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_cum_prod_int32(): - cum_prod(np.int32) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_cum_prod_float16(): - cum_prod(np.float16) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_cum_prod_float32(): - cum_prod(np.float32) - - -class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - self.op = P.CumProd() - - def construct(self, x): - return self.op(x, 0) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_cumprod_dshape(): - """ - Feature: Test cumprod dynamic shape. - Description: Test cumprod dynamic shape. - Expectation: Success. - """ - net = Net() - input_x_dyn = Tensor(shape=[3, None], dtype=dtype.float32) - net.set_inputs(input_x_dyn) - input_x = Tensor(np.random.random(([3, 10])), dtype=dtype.float32) - output = net(input_x) - expect_shape = (3, 10) - assert output.asnumpy().shape == expect_shape - - -def test_cumprod_functional_api(): - """ - Feature: test cumprod functional API. - Description: testcase for cumprod functional API. - Expectation: the result match with expected result. - """ - dtype_op = P.DType() - x = Tensor(np.array([1, 2, 3]), dtype.float32) - output = F.cumprod(x, 0, dtype.int32) - expected = np.array([1, 2, 6], np.int32) - assert dtype_op(output) == dtype.int32 - np.testing.assert_array_equal(output.asnumpy(), expected) - - -def test_cumprod_tensor_api(): - """ - Feature: test cumprod tensor API. - Description: testcase for cumprod tensor API. - Expectation: the result match with expected result. - """ - dtype_op = P.DType() - x = Tensor(np.array([1, 2, 3]), dtype.float32) - output = x.cumprod(0, dtype.int32) - expected = np.array([1, 2, 6], np.int32) - assert dtype_op(output) == dtype.int32 - np.testing.assert_array_equal(output.asnumpy(), expected) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_cumprod_functional_tensor_modes(): - """ - Feature: test cumprod functional and tensor APIs in PyNative and Graph modes. - Description: test case for cumprod functional and tensor APIs. - Expectation: the result match with expected result. - """ - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - test_cumprod_functional_api() - test_cumprod_tensor_api() - context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") - test_cumprod_functional_api() - test_cumprod_tensor_api() +# Copyright 2020-2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore.nn as nn +from mindspore import dtype +from mindspore import Tensor +import mindspore.context as context +from mindspore.ops import functional as F +from mindspore.ops import operations as P +from mindspore.common.api import jit + + +def cum_prod(nptype): + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + x0 = np.random.rand(2, 3, 4, 4).astype(nptype) + axis0 = 3 + + x1 = np.random.rand(2, 3, 4, 4).astype(nptype) + axis1 = 3 + + x2 = np.random.rand(2, 3, 1, 4).astype(nptype) + axis2 = 2 + + x3 = np.random.rand(2, 3, 1, 4).astype(nptype) + axis3 = 2 + + x4 = np.random.rand(2, 3, 4, 4).astype(nptype) + axis4 = 1 + + x5 = np.random.rand(2, 3).astype(nptype) + axis5 = 1 + + x6 = np.random.rand(1, 1, 1, 1).astype(nptype) + axis6 = 0 + + class CumProd(nn.Cell): + def __init__(self, nptype): + super(CumProd, self).__init__() + + self.x0 = Tensor(x0) + self.axis0 = axis0 + + self.x1 = Tensor(x1) + self.axis1 = axis1 + + self.x2 = Tensor(x2) + self.axis2 = axis2 + + self.x3 = Tensor(x3) + self.axis3 = axis3 + + self.x4 = Tensor(x4) + self.axis4 = axis4 + + self.x5 = Tensor(x5) + self.axis5 = axis5 + + self.x6 = Tensor(x6) + self.axis6 = axis6 + + @jit + def construct(self): + output = (P.CumProd()(self.x0, self.axis0), + P.CumProd()(self.x1, self.axis1), + P.CumProd()(self.x2, self.axis2), + P.CumProd()(self.x3, self.axis3), + P.CumProd()(self.x4, self.axis4), + P.CumProd()(self.x5, self.axis5), + P.CumProd()(self.x6, self.axis6)) + return output + + cumprod = CumProd(nptype) + output = cumprod() + + expect0 = np.cumprod(x0, axis=axis0) + diff0 = abs(output[0].asnumpy() - expect0) + error0 = np.ones(shape=expect0.shape) * 1.0e-5 + assert np.all(diff0 < error0) + assert output[0].shape == expect0.shape + + expect1 = np.cumprod(x1, axis=axis1) + diff1 = abs(output[1].asnumpy() - expect1) + error1 = np.ones(shape=expect1.shape) * 1.0e-5 + assert np.all(diff1 < error1) + assert output[1].shape == expect1.shape + + expect2 = np.cumprod(x2, axis=axis2) + diff2 = abs(output[2].asnumpy() - expect2) + error2 = np.ones(shape=expect2.shape) * 1.0e-5 + assert np.all(diff2 < error2) + assert output[2].shape == expect2.shape + + expect3 = np.cumprod(x3, axis=axis3) + diff3 = abs(output[3].asnumpy() - expect3) + error3 = np.ones(shape=expect3.shape) * 1.0e-5 + assert np.all(diff3 < error3) + assert output[3].shape == expect3.shape + + expect4 = np.cumprod(x4, axis=axis4) + diff4 = abs(output[4].asnumpy() - expect4) + error4 = np.ones(shape=expect4.shape) * 1.0e-5 + assert np.all(diff4 < error4) + assert output[4].shape == expect4.shape + + expect5 = np.cumprod(x5, axis=axis5) + diff5 = abs(output[5].asnumpy() - expect5) + error5 = np.ones(shape=expect5.shape) * 1.0e-5 + assert np.all(diff5 < error5) + assert output[5].shape == expect5.shape + + expect6 = np.cumprod(x6, axis=axis6) + diff6 = abs(output[6].asnumpy() - expect6) + error6 = np.ones(shape=expect6.shape) * 1.0e-5 + assert np.all(diff6 < error6) + assert output[6].shape == expect6.shape + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_cum_prod_uint8(): + cum_prod(np.uint8) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_cum_prod_int8(): + cum_prod(np.int8) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_cum_prod_int32(): + cum_prod(np.int32) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_cum_prod_float16(): + cum_prod(np.float16) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_cum_prod_float32(): + cum_prod(np.float32) + + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.op = P.CumProd() + + def construct(self, x): + return self.op(x, 0) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_cumprod_dshape(): + """ + Feature: Test cumprod dynamic shape. + Description: Test cumprod dynamic shape. + Expectation: Success. + """ + net = Net() + input_x_dyn = Tensor(shape=[3, None], dtype=dtype.float32) + net.set_inputs(input_x_dyn) + input_x = Tensor(np.random.random(([3, 10])), dtype=dtype.float32) + output = net(input_x) + expect_shape = (3, 10) + assert output.asnumpy().shape == expect_shape + + +def test_cumprod_functional_api(): + """ + Feature: test cumprod functional API. + Description: testcase for cumprod functional API. + Expectation: the result match with expected result. + """ + dtype_op = P.DType() + x = Tensor(np.array([1, 2, 3]), dtype.float32) + output = F.cumprod(x, 0, dtype.int32) + expected = np.array([1, 2, 6], np.int32) + assert dtype_op(output) == dtype.int32 + np.testing.assert_array_equal(output.asnumpy(), expected) + + +def test_cumprod_tensor_api(): + """ + Feature: test cumprod tensor API. + Description: testcase for cumprod tensor API. + Expectation: the result match with expected result. + """ + dtype_op = P.DType() + x = Tensor(np.array([1, 2, 3]), dtype.float32) + output = x.cumprod(0, dtype.int32) + expected = np.array([1, 2, 6], np.int32) + assert dtype_op(output) == dtype.int32 + np.testing.assert_array_equal(output.asnumpy(), expected) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_cumprod_functional_tensor_modes(): + """ + Feature: test cumprod functional and tensor APIs in PyNative and Graph modes. + Description: test case for cumprod functional and tensor APIs. + Expectation: the result match with expected result. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + test_cumprod_functional_api() + test_cumprod_tensor_api() + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + test_cumprod_functional_api() + test_cumprod_tensor_api() diff --git a/tests/st/ops/gpu/test_cumsum_op.py b/tests/st/ops/gpu/test_cumsum_op.py index 4d978d98056..e7d843f26bb 100644 --- a/tests/st/ops/gpu/test_cumsum_op.py +++ b/tests/st/ops/gpu/test_cumsum_op.py @@ -1,148 +1,148 @@ -# Copyright 2020-2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.common.api import jit -from mindspore.ops import operations as P - -def cum_sum(nptype): - context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') - x0 = np.random.rand(2, 3, 4, 4).astype(nptype) - axis0 = 3 - - x1 = np.random.rand(2, 3, 4, 4).astype(nptype) - axis1 = 3 - - x2 = np.random.rand(2, 3, 1, 4).astype(nptype) - axis2 = 2 - - x3 = np.random.rand(2, 3, 1, 4).astype(nptype) - axis3 = 2 - - x4 = np.random.rand(2, 3, 4, 4).astype(nptype) - axis4 = 1 - - x5 = np.random.rand(2, 3).astype(nptype) - axis5 = 1 - - x6 = np.random.rand(1, 1, 1, 1).astype(nptype) - axis6 = 0 - - class CumSum(nn.Cell): - def __init__(self, nptype): - super(CumSum, self).__init__() - - self.x0 = Tensor(x0) - self.axis0 = axis0 - - self.x1 = Tensor(x1) - self.axis1 = axis1 - - self.x2 = Tensor(x2) - self.axis2 = axis2 - - self.x3 = Tensor(x3) - self.axis3 = axis3 - - self.x4 = Tensor(x4) - self.axis4 = axis4 - - self.x5 = Tensor(x5) - self.axis5 = axis5 - - self.x6 = Tensor(x6) - self.axis6 = axis6 - - @jit - def construct(self): - return (P.CumSum()(self.x0, self.axis0), - P.CumSum()(self.x1, self.axis1), - P.CumSum()(self.x2, self.axis2), - P.CumSum()(self.x3, self.axis3), - P.CumSum()(self.x4, self.axis4), - P.CumSum()(self.x5, self.axis5), - P.CumSum()(self.x6, self.axis6)) - - - cumsum = CumSum(nptype) - output = cumsum() - - expect0 = np.cumsum(x0, axis=axis0) - diff0 = abs(output[0].asnumpy() - expect0) - error0 = np.ones(shape=expect0.shape) * 1.0e-5 - assert np.all(diff0 < error0) - assert output[0].shape == expect0.shape - - expect1 = np.cumsum(x1, axis=axis1) - diff1 = abs(output[1].asnumpy() - expect1) - error1 = np.ones(shape=expect1.shape) * 1.0e-5 - assert np.all(diff1 < error1) - assert output[1].shape == expect1.shape - - expect2 = np.cumsum(x2, axis=axis2) - diff2 = abs(output[2].asnumpy() - expect2) - error2 = np.ones(shape=expect2.shape) * 1.0e-5 - assert np.all(diff2 < error2) - assert output[2].shape == expect2.shape - - expect3 = np.cumsum(x3, axis=axis3) - diff3 = abs(output[3].asnumpy() - expect3) - error3 = np.ones(shape=expect3.shape) * 1.0e-5 - assert np.all(diff3 < error3) - assert output[3].shape == expect3.shape - - expect4 = np.cumsum(x4, axis=axis4) - diff4 = abs(output[4].asnumpy() - expect4) - error4 = np.ones(shape=expect4.shape) * 1.0e-5 - assert np.all(diff4 < error4) - assert output[4].shape == expect4.shape - - expect5 = np.cumsum(x5, axis=axis5) - diff5 = abs(output[5].asnumpy() - expect5) - error5 = np.ones(shape=expect5.shape) * 1.0e-5 - assert np.all(diff5 < error5) - assert output[5].shape == expect5.shape - - expect6 = np.cumsum(x6, axis=axis6) - diff6 = abs(output[6].asnumpy() - expect6) - error6 = np.ones(shape=expect6.shape) * 1.0e-5 - assert np.all(diff6 < error6) - assert output[6].shape == expect6.shape - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_cum_sum_uint8(): - cum_sum(np.uint8) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_cum_sum_int8(): - cum_sum(np.int8) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_cum_sum_int32(): - cum_sum(np.int32) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_cum_sum_float16(): - cum_sum(np.float16) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_cum_sum_float32(): - cum_sum(np.float32) +# Copyright 2020-2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common.api import jit +from mindspore.ops import operations as P + +def cum_sum(nptype): + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + x0 = np.random.rand(2, 3, 4, 4).astype(nptype) + axis0 = 3 + + x1 = np.random.rand(2, 3, 4, 4).astype(nptype) + axis1 = 3 + + x2 = np.random.rand(2, 3, 1, 4).astype(nptype) + axis2 = 2 + + x3 = np.random.rand(2, 3, 1, 4).astype(nptype) + axis3 = 2 + + x4 = np.random.rand(2, 3, 4, 4).astype(nptype) + axis4 = 1 + + x5 = np.random.rand(2, 3).astype(nptype) + axis5 = 1 + + x6 = np.random.rand(1, 1, 1, 1).astype(nptype) + axis6 = 0 + + class CumSum(nn.Cell): + def __init__(self, nptype): + super(CumSum, self).__init__() + + self.x0 = Tensor(x0) + self.axis0 = axis0 + + self.x1 = Tensor(x1) + self.axis1 = axis1 + + self.x2 = Tensor(x2) + self.axis2 = axis2 + + self.x3 = Tensor(x3) + self.axis3 = axis3 + + self.x4 = Tensor(x4) + self.axis4 = axis4 + + self.x5 = Tensor(x5) + self.axis5 = axis5 + + self.x6 = Tensor(x6) + self.axis6 = axis6 + + @jit + def construct(self): + return (P.CumSum()(self.x0, self.axis0), + P.CumSum()(self.x1, self.axis1), + P.CumSum()(self.x2, self.axis2), + P.CumSum()(self.x3, self.axis3), + P.CumSum()(self.x4, self.axis4), + P.CumSum()(self.x5, self.axis5), + P.CumSum()(self.x6, self.axis6)) + + + cumsum = CumSum(nptype) + output = cumsum() + + expect0 = np.cumsum(x0, axis=axis0) + diff0 = abs(output[0].asnumpy() - expect0) + error0 = np.ones(shape=expect0.shape) * 1.0e-5 + assert np.all(diff0 < error0) + assert output[0].shape == expect0.shape + + expect1 = np.cumsum(x1, axis=axis1) + diff1 = abs(output[1].asnumpy() - expect1) + error1 = np.ones(shape=expect1.shape) * 1.0e-5 + assert np.all(diff1 < error1) + assert output[1].shape == expect1.shape + + expect2 = np.cumsum(x2, axis=axis2) + diff2 = abs(output[2].asnumpy() - expect2) + error2 = np.ones(shape=expect2.shape) * 1.0e-5 + assert np.all(diff2 < error2) + assert output[2].shape == expect2.shape + + expect3 = np.cumsum(x3, axis=axis3) + diff3 = abs(output[3].asnumpy() - expect3) + error3 = np.ones(shape=expect3.shape) * 1.0e-5 + assert np.all(diff3 < error3) + assert output[3].shape == expect3.shape + + expect4 = np.cumsum(x4, axis=axis4) + diff4 = abs(output[4].asnumpy() - expect4) + error4 = np.ones(shape=expect4.shape) * 1.0e-5 + assert np.all(diff4 < error4) + assert output[4].shape == expect4.shape + + expect5 = np.cumsum(x5, axis=axis5) + diff5 = abs(output[5].asnumpy() - expect5) + error5 = np.ones(shape=expect5.shape) * 1.0e-5 + assert np.all(diff5 < error5) + assert output[5].shape == expect5.shape + + expect6 = np.cumsum(x6, axis=axis6) + diff6 = abs(output[6].asnumpy() - expect6) + error6 = np.ones(shape=expect6.shape) * 1.0e-5 + assert np.all(diff6 < error6) + assert output[6].shape == expect6.shape + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_cum_sum_uint8(): + cum_sum(np.uint8) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_cum_sum_int8(): + cum_sum(np.int8) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_cum_sum_int32(): + cum_sum(np.int32) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_cum_sum_float16(): + cum_sum(np.float16) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_cum_sum_float32(): + cum_sum(np.float32) diff --git a/tests/st/ops/gpu/test_div_op.py b/tests/st/ops/gpu/test_div_op.py index 87f01fd9f90..e5fa2ff6550 100644 --- a/tests/st/ops/gpu/test_div_op.py +++ b/tests/st/ops/gpu/test_div_op.py @@ -1,251 +1,251 @@ -# Copyright 2020-2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.ops import functional as F -from mindspore.ops import operations as P - - -class NetDiv(nn.Cell): - def __init__(self): - super(NetDiv, self).__init__() - self.div = P.Div() - - def construct(self, x, y): - return self.div(x, y) - - -def div(nptype): - x0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(nptype) - y0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(nptype) - x1_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(nptype) - y1_np = np.random.randint(1, 5, (2, 1, 4, 4)).astype(nptype) - x2_np = np.random.randint(1, 5, (2, 1, 1, 4)).astype(nptype) - y2_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(nptype) - x3_np = np.random.randint(1, 5, 1).astype(nptype) - y3_np = np.random.randint(1, 5, 1).astype(nptype) - x4_np = np.array(78).astype(nptype) - y4_np = np.array(37.5).astype(nptype) - - x0 = Tensor(x0_np) - y0 = Tensor(y0_np) - x1 = Tensor(x1_np) - y1 = Tensor(y1_np) - x2 = Tensor(x2_np) - y2 = Tensor(y2_np) - x3 = Tensor(x3_np) - y3 = Tensor(y3_np) - x4 = Tensor(x4_np) - y4 = Tensor(y4_np) - - context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - div_net = NetDiv() - output0 = div_net(x0, y0) - expect0 = np.divide(x0_np, y0_np) - diff0 = output0.asnumpy() - expect0 - error0 = np.ones(shape=expect0.shape) * 1.0e-5 - assert np.all(diff0 < error0) - assert output0.shape == expect0.shape - - output1 = div_net(x1, y1) - expect1 = np.divide(x1_np, y1_np) - diff1 = output1.asnumpy() - expect1 - error1 = np.ones(shape=expect1.shape) * 1.0e-5 - assert np.all(diff1 < error1) - assert output1.shape == expect1.shape - - output2 = div_net(x2, y2) - expect2 = np.divide(x2_np, y2_np) - diff2 = output2.asnumpy() - expect2 - error2 = np.ones(shape=expect2.shape) * 1.0e-5 - assert np.all(diff2 < error2) - assert output2.shape == expect2.shape - - context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') - output3 = div_net(x3, y3) - expect3 = np.divide(x3_np, y3_np) - diff3 = output3.asnumpy() - expect3 - error3 = np.ones(shape=expect3.shape) * 1.0e-5 - assert np.all(diff3 < error3) - assert output3.shape == expect3.shape - - output4 = div_net(x4, y4) - expect4 = np.divide(x4_np, y4_np) - diff4 = output4.asnumpy() - expect4 - error4 = np.ones(shape=expect4.shape) * 1.0e-5 - assert np.all(diff4 < error4) - assert output4.shape == expect4.shape - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_div_float64(): - div(np.float64) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_div_float32(): - div(np.float32) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_div_float16(): - div(np.float16) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_div_int64(): - div(np.int64) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_div_int32(): - div(np.int32) - - -def test_div_tensor_api(): - """ - Feature: test div tensor API. - Description: testcase for div tensor API. - Expectation: the result match with expected result. - """ - x = Tensor(np.array([[-0.3711, -1.9353, -0.4605, -0.2917], - [0.1815, -1.0111, 0.9805, -1.5923], - [0.1062, 1.4581, 0.7759, -1.2344], - [-0.1830, -0.0313, 1.1908, -1.4757]])) - y = Tensor(np.array([0.8032, 0.2930, -0.8113, -0.2308])) - output = x.div(y) - expected = np.array([[-0.4620, -6.6051, 0.5676, 1.2639], - [0.2260, -3.4509, -1.2086, 6.8990], - [0.1322, 4.9764, -0.9564, 5.3484], - [-0.2278, -0.1068, -1.4678, 6.3938]]) - np.testing.assert_array_almost_equal(output.asnumpy(), expected, decimal=2) - - -def test_div_trunc_tensor_api(): - """ - Feature: test div tensor API. - Description: testcase for div tensor API. - Expectation: the result match with expected result. - """ - x = Tensor(np.array([[0.0385, 0.2672, 0.2781, -0.4063], - [0.9276, -0.5893, -0.0838, 0.4097], - [-0.2601, -0.2397, 0.5832, 0.2250], - [0.0322, 0.7103, 0.6315, -0.8621]])) - y = Tensor(np.array([0.6962, -0.4668, -0.2971, -0.6389])) - output = x.div(y, rounding_mode='trunc') - expected = np.array([[0., -0., -0., 0.], - [1., 1., 0., -0.], - [-0., 0., -1., -0.], - [0., -1., -2., 1.]]) - np.testing.assert_array_equal(output.asnumpy(), expected) - - -def test_div_floor_tensor_api(): - """ - Feature: test div tensor API. - Description: testcase for div tensor API. - Expectation: the result match with expected result. - """ - x = Tensor(np.array([[0.0385, 0.2672, 0.2781, -0.4063], - [0.9276, -0.5893, -0.0838, 0.4097], - [-0.2601, -0.2397, 0.5832, 0.2250], - [0.0322, 0.7103, 0.6315, -0.8621]])) - y = Tensor(np.array([0.6962, -0.4668, -0.2971, -0.6389])) - output = x.div(y, rounding_mode='floor') - expected = np.array([[0., -1., -1., 0.], - [1., 1., 0., -1.], - [-1., 0., -2., -1.], - [0., -2., -3., 1.]]) - np.testing.assert_array_equal(output.asnumpy(), expected) - - -def test_div_functional_api(): - """ - Feature: test div functional API. - Description: testcase for div functional API. - Expectation: the result match with expected result. - """ - x = Tensor(np.array([[-0.3711, -1.9353, -0.4605, -0.2917], - [0.1815, -1.0111, 0.9805, -1.5923], - [0.1062, 1.4581, 0.7759, -1.2344], - [-0.1830, -0.0313, 1.1908, -1.4757]])) - y = Tensor(np.array([0.8032, 0.2930, -0.8113, -0.2308])) - output = F.div(x, y, rounding_mode=None) - expected = np.array([[-0.4620, -6.6051, 0.5676, 1.2639], - [0.2260, -3.4509, -1.2086, 6.8990], - [0.1322, 4.9764, -0.9564, 5.3484], - [-0.2278, -0.1068, -1.4678, 6.3938]]) - np.testing.assert_array_almost_equal(output.asnumpy(), expected, decimal=2) - - -def test_div_trunc_functional_api(): - """ - Feature: test div functional API. - Description: testcase for div functional API. - Expectation: the result match with expected result. - """ - x = Tensor(np.array([[0.0385, 0.2672, 0.2781, -0.4063], - [0.9276, -0.5893, -0.0838, 0.4097], - [-0.2601, -0.2397, 0.5832, 0.2250], - [0.0322, 0.7103, 0.6315, -0.8621]])) - y = Tensor(np.array([0.6962, -0.4668, -0.2971, -0.6389])) - output = F.div(x, y, rounding_mode='trunc') - expected = np.array([[0., -0., -0., 0.], - [1., 1., 0., -0.], - [-0., 0., -1., -0.], - [0., -1., -2., 1.]]) - np.testing.assert_array_equal(output.asnumpy(), expected) - - -def test_div_floor_functional_api(): - """ - Feature: test div functional API. - Description: testcase for div functional API. - Expectation: the result match with expected result. - """ - x = Tensor(np.array([[0.0385, 0.2672, 0.2781, -0.4063], - [0.9276, -0.5893, -0.0838, 0.4097], - [-0.2601, -0.2397, 0.5832, 0.2250], - [0.0322, 0.7103, 0.6315, -0.8621]])) - y = Tensor(np.array([0.6962, -0.4668, -0.2971, -0.6389])) - output = F.div(x, y, rounding_mode='floor') - expected = np.array([[0., -1., -1., 0.], - [1., 1., 0., -1.], - [-1., 0., -2., -1.], - [0., -2., -3., 1.]]) - np.testing.assert_array_equal(output.asnumpy(), expected) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -def test_div_functional_tensor_modes(mode): - """ - Feature: test div functional and tensor APIs in PyNative and Graph modes. - Description: test case for div functional and tensor APIs. - Expectation: the result match with expected result. - """ - context.set_context(mode=mode, device_target="GPU") - test_div_tensor_api() - test_div_trunc_tensor_api() - test_div_floor_tensor_api() - test_div_functional_api() - test_div_trunc_functional_api() - test_div_floor_functional_api() +# Copyright 2020-2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import functional as F +from mindspore.ops import operations as P + + +class NetDiv(nn.Cell): + def __init__(self): + super(NetDiv, self).__init__() + self.div = P.Div() + + def construct(self, x, y): + return self.div(x, y) + + +def div(nptype): + x0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(nptype) + y0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(nptype) + x1_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(nptype) + y1_np = np.random.randint(1, 5, (2, 1, 4, 4)).astype(nptype) + x2_np = np.random.randint(1, 5, (2, 1, 1, 4)).astype(nptype) + y2_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(nptype) + x3_np = np.random.randint(1, 5, 1).astype(nptype) + y3_np = np.random.randint(1, 5, 1).astype(nptype) + x4_np = np.array(78).astype(nptype) + y4_np = np.array(37.5).astype(nptype) + + x0 = Tensor(x0_np) + y0 = Tensor(y0_np) + x1 = Tensor(x1_np) + y1 = Tensor(y1_np) + x2 = Tensor(x2_np) + y2 = Tensor(y2_np) + x3 = Tensor(x3_np) + y3 = Tensor(y3_np) + x4 = Tensor(x4_np) + y4 = Tensor(y4_np) + + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + div_net = NetDiv() + output0 = div_net(x0, y0) + expect0 = np.divide(x0_np, y0_np) + diff0 = output0.asnumpy() - expect0 + error0 = np.ones(shape=expect0.shape) * 1.0e-5 + assert np.all(diff0 < error0) + assert output0.shape == expect0.shape + + output1 = div_net(x1, y1) + expect1 = np.divide(x1_np, y1_np) + diff1 = output1.asnumpy() - expect1 + error1 = np.ones(shape=expect1.shape) * 1.0e-5 + assert np.all(diff1 < error1) + assert output1.shape == expect1.shape + + output2 = div_net(x2, y2) + expect2 = np.divide(x2_np, y2_np) + diff2 = output2.asnumpy() - expect2 + error2 = np.ones(shape=expect2.shape) * 1.0e-5 + assert np.all(diff2 < error2) + assert output2.shape == expect2.shape + + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + output3 = div_net(x3, y3) + expect3 = np.divide(x3_np, y3_np) + diff3 = output3.asnumpy() - expect3 + error3 = np.ones(shape=expect3.shape) * 1.0e-5 + assert np.all(diff3 < error3) + assert output3.shape == expect3.shape + + output4 = div_net(x4, y4) + expect4 = np.divide(x4_np, y4_np) + diff4 = output4.asnumpy() - expect4 + error4 = np.ones(shape=expect4.shape) * 1.0e-5 + assert np.all(diff4 < error4) + assert output4.shape == expect4.shape + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_div_float64(): + div(np.float64) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_div_float32(): + div(np.float32) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_div_float16(): + div(np.float16) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_div_int64(): + div(np.int64) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_div_int32(): + div(np.int32) + + +def test_div_tensor_api(): + """ + Feature: test div tensor API. + Description: testcase for div tensor API. + Expectation: the result match with expected result. + """ + x = Tensor(np.array([[-0.3711, -1.9353, -0.4605, -0.2917], + [0.1815, -1.0111, 0.9805, -1.5923], + [0.1062, 1.4581, 0.7759, -1.2344], + [-0.1830, -0.0313, 1.1908, -1.4757]])) + y = Tensor(np.array([0.8032, 0.2930, -0.8113, -0.2308])) + output = x.div(y) + expected = np.array([[-0.4620, -6.6051, 0.5676, 1.2639], + [0.2260, -3.4509, -1.2086, 6.8990], + [0.1322, 4.9764, -0.9564, 5.3484], + [-0.2278, -0.1068, -1.4678, 6.3938]]) + np.testing.assert_array_almost_equal(output.asnumpy(), expected, decimal=2) + + +def test_div_trunc_tensor_api(): + """ + Feature: test div tensor API. + Description: testcase for div tensor API. + Expectation: the result match with expected result. + """ + x = Tensor(np.array([[0.0385, 0.2672, 0.2781, -0.4063], + [0.9276, -0.5893, -0.0838, 0.4097], + [-0.2601, -0.2397, 0.5832, 0.2250], + [0.0322, 0.7103, 0.6315, -0.8621]])) + y = Tensor(np.array([0.6962, -0.4668, -0.2971, -0.6389])) + output = x.div(y, rounding_mode='trunc') + expected = np.array([[0., -0., -0., 0.], + [1., 1., 0., -0.], + [-0., 0., -1., -0.], + [0., -1., -2., 1.]]) + np.testing.assert_array_equal(output.asnumpy(), expected) + + +def test_div_floor_tensor_api(): + """ + Feature: test div tensor API. + Description: testcase for div tensor API. + Expectation: the result match with expected result. + """ + x = Tensor(np.array([[0.0385, 0.2672, 0.2781, -0.4063], + [0.9276, -0.5893, -0.0838, 0.4097], + [-0.2601, -0.2397, 0.5832, 0.2250], + [0.0322, 0.7103, 0.6315, -0.8621]])) + y = Tensor(np.array([0.6962, -0.4668, -0.2971, -0.6389])) + output = x.div(y, rounding_mode='floor') + expected = np.array([[0., -1., -1., 0.], + [1., 1., 0., -1.], + [-1., 0., -2., -1.], + [0., -2., -3., 1.]]) + np.testing.assert_array_equal(output.asnumpy(), expected) + + +def test_div_functional_api(): + """ + Feature: test div functional API. + Description: testcase for div functional API. + Expectation: the result match with expected result. + """ + x = Tensor(np.array([[-0.3711, -1.9353, -0.4605, -0.2917], + [0.1815, -1.0111, 0.9805, -1.5923], + [0.1062, 1.4581, 0.7759, -1.2344], + [-0.1830, -0.0313, 1.1908, -1.4757]])) + y = Tensor(np.array([0.8032, 0.2930, -0.8113, -0.2308])) + output = F.div(x, y, rounding_mode=None) + expected = np.array([[-0.4620, -6.6051, 0.5676, 1.2639], + [0.2260, -3.4509, -1.2086, 6.8990], + [0.1322, 4.9764, -0.9564, 5.3484], + [-0.2278, -0.1068, -1.4678, 6.3938]]) + np.testing.assert_array_almost_equal(output.asnumpy(), expected, decimal=2) + + +def test_div_trunc_functional_api(): + """ + Feature: test div functional API. + Description: testcase for div functional API. + Expectation: the result match with expected result. + """ + x = Tensor(np.array([[0.0385, 0.2672, 0.2781, -0.4063], + [0.9276, -0.5893, -0.0838, 0.4097], + [-0.2601, -0.2397, 0.5832, 0.2250], + [0.0322, 0.7103, 0.6315, -0.8621]])) + y = Tensor(np.array([0.6962, -0.4668, -0.2971, -0.6389])) + output = F.div(x, y, rounding_mode='trunc') + expected = np.array([[0., -0., -0., 0.], + [1., 1., 0., -0.], + [-0., 0., -1., -0.], + [0., -1., -2., 1.]]) + np.testing.assert_array_equal(output.asnumpy(), expected) + + +def test_div_floor_functional_api(): + """ + Feature: test div functional API. + Description: testcase for div functional API. + Expectation: the result match with expected result. + """ + x = Tensor(np.array([[0.0385, 0.2672, 0.2781, -0.4063], + [0.9276, -0.5893, -0.0838, 0.4097], + [-0.2601, -0.2397, 0.5832, 0.2250], + [0.0322, 0.7103, 0.6315, -0.8621]])) + y = Tensor(np.array([0.6962, -0.4668, -0.2971, -0.6389])) + output = F.div(x, y, rounding_mode='floor') + expected = np.array([[0., -1., -1., 0.], + [1., 1., 0., -1.], + [-1., 0., -2., -1.], + [0., -2., -3., 1.]]) + np.testing.assert_array_equal(output.asnumpy(), expected) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_div_functional_tensor_modes(mode): + """ + Feature: test div functional and tensor APIs in PyNative and Graph modes. + Description: test case for div functional and tensor APIs. + Expectation: the result match with expected result. + """ + context.set_context(mode=mode, device_target="GPU") + test_div_tensor_api() + test_div_trunc_tensor_api() + test_div_floor_tensor_api() + test_div_functional_api() + test_div_trunc_functional_api() + test_div_floor_functional_api() diff --git a/tests/st/ops/gpu/test_equalcount_op.py b/tests/st/ops/gpu/test_equalcount_op.py index 0cc0324d1d3..a57c90908c3 100644 --- a/tests/st/ops/gpu/test_equalcount_op.py +++ b/tests/st/ops/gpu/test_equalcount_op.py @@ -1,76 +1,76 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.ops import operations as P - - -class NetEqualCount(nn.Cell): - def __init__(self): - super(NetEqualCount, self).__init__() - self.equalcount = P.EqualCount() - - def construct(self, x, y): - return self.equalcount(x, y) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_equalcount(): - x = Tensor(np.array([1, 20, 5]).astype(np.int32)) - y = Tensor(np.array([2, 20, 5]).astype(np.int32)) - expect = np.array([2]).astype(np.int32) - - context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") - equal_count = NetEqualCount() - output = equal_count(x, y) - assert (output.asnumpy() == expect).all() - - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - equal_count = NetEqualCount() - output = equal_count(x, y) - assert (output.asnumpy() == expect).all() - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_equalcount_dynamic(): - """ - Feature: EqualCount ops - Description: dynamic shape in pynative and graph mode in GPU - Expectation: success - """ - x = Tensor(np.array([1, 20, 5]).astype(np.int32)) - y = Tensor(np.array([2, 20, 5]).astype(np.int32)) - xx = Tensor(shape=[None], dtype=mindspore.int32) - yy = Tensor(shape=[None], dtype=mindspore.int32) - expect = np.array([2]).astype(np.int32) - - context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") - equal_count = NetEqualCount() - equal_count.set_inputs(xx, yy) - output = equal_count(x, y) - assert (output.asnumpy() == expect).all() - - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - equal_count = NetEqualCount() - equal_count.set_inputs(xx, yy) - output = equal_count(x, y) - assert (output.asnumpy() == expect).all() +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + + +class NetEqualCount(nn.Cell): + def __init__(self): + super(NetEqualCount, self).__init__() + self.equalcount = P.EqualCount() + + def construct(self, x, y): + return self.equalcount(x, y) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_equalcount(): + x = Tensor(np.array([1, 20, 5]).astype(np.int32)) + y = Tensor(np.array([2, 20, 5]).astype(np.int32)) + expect = np.array([2]).astype(np.int32) + + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + equal_count = NetEqualCount() + output = equal_count(x, y) + assert (output.asnumpy() == expect).all() + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + equal_count = NetEqualCount() + output = equal_count(x, y) + assert (output.asnumpy() == expect).all() + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_equalcount_dynamic(): + """ + Feature: EqualCount ops + Description: dynamic shape in pynative and graph mode in GPU + Expectation: success + """ + x = Tensor(np.array([1, 20, 5]).astype(np.int32)) + y = Tensor(np.array([2, 20, 5]).astype(np.int32)) + xx = Tensor(shape=[None], dtype=mindspore.int32) + yy = Tensor(shape=[None], dtype=mindspore.int32) + expect = np.array([2]).astype(np.int32) + + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + equal_count = NetEqualCount() + equal_count.set_inputs(xx, yy) + output = equal_count(x, y) + assert (output.asnumpy() == expect).all() + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + equal_count = NetEqualCount() + equal_count.set_inputs(xx, yy) + output = equal_count(x, y) + assert (output.asnumpy() == expect).all() diff --git a/tests/st/ops/gpu/test_error_on_dynamic_shape_input_op.py b/tests/st/ops/gpu/test_error_on_dynamic_shape_input_op.py index ed2982076d1..18298e4b6b7 100644 --- a/tests/st/ops/gpu/test_error_on_dynamic_shape_input_op.py +++ b/tests/st/ops/gpu/test_error_on_dynamic_shape_input_op.py @@ -1,55 +1,55 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import pytest - -from mindspore.ops.operations import _inner_ops as inner -import mindspore.context as context - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_error_on_dynamic_shape_input_is_dynamic(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - - error_on_dynamic_shape_input = inner.ErrorOnDynamicShapeInput() - - with pytest.raises(ValueError) as info: - error_on_dynamic_shape_input.infer_shape([-1]) - assert "Input is dynamically shaped" in str(info.value) - - with pytest.raises(ValueError) as info: - error_on_dynamic_shape_input.infer_shape([1, 1, -1]) - assert "Input is dynamically shaped" in str(info.value) - - with pytest.raises(ValueError) as info: - error_on_dynamic_shape_input.infer_shape([-1, 1, 1]) - assert "Input is dynamically shaped" in str(info.value) - - with pytest.raises(ValueError) as info: - error_on_dynamic_shape_input.infer_shape([1, -1, 1]) - assert "Input is dynamically shaped" in str(info.value) - - with pytest.raises(ValueError) as info: - error_on_dynamic_shape_input.infer_shape([-1, -1, -1]) - assert "Input is dynamically shaped" in str(info.value) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_error_on_dynamic_shape_input_not_dynamic(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - - error_on_dynamic_shape_input = inner.ErrorOnDynamicShapeInput() - error_on_dynamic_shape_input([1]) - error_on_dynamic_shape_input([1, 1]) - error_on_dynamic_shape_input([23, 12, 9712]) +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import pytest + +from mindspore.ops.operations import _inner_ops as inner +import mindspore.context as context + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_error_on_dynamic_shape_input_is_dynamic(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + error_on_dynamic_shape_input = inner.ErrorOnDynamicShapeInput() + + with pytest.raises(ValueError) as info: + error_on_dynamic_shape_input.infer_shape([-1]) + assert "Input is dynamically shaped" in str(info.value) + + with pytest.raises(ValueError) as info: + error_on_dynamic_shape_input.infer_shape([1, 1, -1]) + assert "Input is dynamically shaped" in str(info.value) + + with pytest.raises(ValueError) as info: + error_on_dynamic_shape_input.infer_shape([-1, 1, 1]) + assert "Input is dynamically shaped" in str(info.value) + + with pytest.raises(ValueError) as info: + error_on_dynamic_shape_input.infer_shape([1, -1, 1]) + assert "Input is dynamically shaped" in str(info.value) + + with pytest.raises(ValueError) as info: + error_on_dynamic_shape_input.infer_shape([-1, -1, -1]) + assert "Input is dynamically shaped" in str(info.value) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_error_on_dynamic_shape_input_not_dynamic(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + error_on_dynamic_shape_input = inner.ErrorOnDynamicShapeInput() + error_on_dynamic_shape_input([1]) + error_on_dynamic_shape_input([1, 1]) + error_on_dynamic_shape_input([23, 12, 9712]) diff --git a/tests/st/ops/gpu/test_fftwithsize.py b/tests/st/ops/gpu/test_fftwithsize.py old mode 100755 new mode 100644 index c439c0b03fb..6d13bfabd02 --- a/tests/st/ops/gpu/test_fftwithsize.py +++ b/tests/st/ops/gpu/test_fftwithsize.py @@ -1,161 +1,161 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import numpy as np -import pytest - -import mindspore as ms -import mindspore.context as context -from mindspore import Tensor, ops -from tests.st.utils import test_utils -from tests.mark_utils import arg_mark - -@test_utils.run_with_cell -def fft_forward_func(x, signal_ndim, inverse, real, norm='backward', onesided=True, signal_sizes=()): - return ops.FFTWithSize(signal_ndim, inverse, real, norm, onesided, signal_sizes)(x) - -@test_utils.run_with_cell -def rfft_and_irfft_forward_func(x, signal_ndim, inverse, real, norm='backward', onesided=True, signal_sizes=()): - x = ops.FFTWithSize(signal_ndim, inverse, real, norm, onesided, signal_sizes)(x) - return ops.FFTWithSize(signal_ndim, not inverse, real, norm, onesided, signal_sizes)(x) - -@test_utils.run_with_cell -def fft_backward_func(x, signal_ndim, inverse, real, norm='backward', onesided=True, signal_sizes=()): - return ops.grad(fft_forward_func, (0,))(x, signal_ndim, inverse, real, norm, onesided, signal_sizes) - -@test_utils.run_with_cell -def rfft_and_irfft_backward_func(x, signal_ndim, inverse, real, norm='backward', onesided=True, signal_sizes=()): - return ops.grad(rfft_and_irfft_forward_func, (0,))(x, signal_ndim, inverse, real, norm, onesided, signal_sizes) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level0', card_mark='onecard', essential_mark='essential') -@pytest.mark.parametrize('dtype, eps', [(np.complex64, 1e-6), (np.complex128, 1e-6)]) -def test_fftwithsize_fft_ifft(dtype, eps): - """ - Feature: fft & ifft function - Description: test cases for fft & ifft - Expectation: the result matches pytorch - """ - x = Tensor(np.array([1.6243454+0.j, -0.6117564+0.j, -0.5281718+0.j, -1.0729686+0.j]).astype(dtype)) - expect = np.array([-0.5885514+0.j, 2.1525173-0.46121222j, 2.7808986+0.j, 2.1525173+0.46121222j]).astype(dtype) - error = np.ones(shape=[4]) * eps - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - - output = fft_forward_func(x, 1, False, False) - diff = np.abs(output.asnumpy() - expect) - assert np.all(diff < error) - - output_ifft = fft_forward_func(output, 1, True, False) - diff_ifft = np.abs(output_ifft.asnumpy() - x.asnumpy()) - assert np.all(diff_ifft < error) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level0', card_mark='onecard', essential_mark='essential') -@pytest.mark.parametrize('dtype, eps', [(np.complex64, 1e-6), (np.complex128, 1e-6)]) -def test_fftwithsize_fft2_ifft2(dtype, eps): - """ - Feature: fft2 & ifft2 function - Description: test cases for fft2 & ifft2 - Expectation: the result matches pytorch - """ - x = Tensor(np.array([[1.6243454+0.j, -0.6117564+0.j], [-0.5281718+0.j, -1.0729686+0.j]]).astype(dtype)) - expect = np.array([[-0.5885514+0.j, 2.7808986+0.j], [2.6137295+0.j, 1.6913052+0.j]]).astype(dtype) - error = np.ones(shape=[2, 2]) * eps - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - - output = fft_forward_func(x, 2, False, False) - diff = np.abs(output.asnumpy() - expect) - assert np.all(diff < error) - - output_ifft2 = fft_forward_func(output, 2, True, False) - diff_ifft2 = np.abs(output_ifft2.asnumpy() - x.asnumpy()) - assert np.all(diff_ifft2 < error) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level0', card_mark='onecard', essential_mark='essential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_fft_with_size_rfft3_forward(mode): - """ - Feature: rfft3 forward function - Description: test cases for rfft - Expectation: the result matches pytorch - """ - ms.context.set_context(mode=mode) - x = np.arange(1 * 2 * 3 * 4, dtype=np.float64).reshape(1, 2, 3, 4) - ms_x = ms.Tensor(x) - output = fft_forward_func(ms_x, 3, False, True) - expect = np.fft.rfftn(x, s=(2, 3, 4)) - assert np.allclose(output.asnumpy(), expect) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level0', card_mark='onecard', essential_mark='essential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_fft_with_size_irfft3_forward(mode): - """ - Feature: irfft3 forward function - Description: test cases for irfft3 - Expectation: the result matches pytorch - """ - ms.context.set_context(mode=mode) - x = np.arange(1 * 2 * 3 * 3, dtype=np.complex128).reshape(1, 2, 3, 3) - ms_x = ms.Tensor(x) - output = fft_forward_func(ms_x, 3, True, True) - expect = np.fft.irfftn(x, s=(2, 3, 4)) - assert np.allclose(output.asnumpy(), expect) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level0', card_mark='onecard', essential_mark='essential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_fft_with_size_rfft3_backward(mode): - """ - Feature: rfft3 backward function - Description: test cases for rfft3 - Expectation: the result matches pytorch - """ - ms.context.set_context(mode=mode) - dim1 = 1 - dim2 = 2 - dim3 = 3 - dim4 = 4 - offset_size = dim1 * dim2 * dim3 * dim4 - x = np.arange(offset_size, dtype=np.float64).reshape(dim1, dim2, dim3, dim4) - ms_x = ms.Tensor(x) - output = fft_backward_func(ms_x, 3, False, True) - dout = np.ones((dim1, dim2, dim3, dim4 // 2 + 1), dtype=np.complex128) - concat_array = np.zeros((dim1, dim2, dim3, dim4 - dim4 // 2 - 1)) - concat_array = concat_array.astype(np.complex128) - dout = np.concatenate((dout, concat_array), axis=-1) - expect = np.fft.ifftn(dout, s=(dim2, dim3, dim4)) * offset_size - assert np.allclose(output.asnumpy(), expect.real) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level0', card_mark='onecard', essential_mark='essential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_fft_with_size_rfft3_and_irfft3_backward(mode): - """ - Feature: rfft3_and_irfft3 function - Description: test cases for rfft3_and_irfft3 - Expectation: the result matches pytorch - """ - ms.context.set_context(mode=mode) - dim1 = 1 - dim2 = 2 - dim3 = 3 - dim4 = 4 - offset_size = dim1 * dim2 * dim3 * dim4 - x = np.arange(offset_size, dtype=np.float64).reshape(dim1, dim2, dim3, dim4) - ms_x = ms.Tensor(x) - output = rfft_and_irfft_backward_func(ms_x, 3, False, True) - expect = np.ones((dim1, dim2, dim3, dim4)) - assert np.allclose(output.asnumpy(), expect) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore as ms +import mindspore.context as context +from mindspore import Tensor, ops +from tests.st.utils import test_utils +from tests.mark_utils import arg_mark + +@test_utils.run_with_cell +def fft_forward_func(x, signal_ndim, inverse, real, norm='backward', onesided=True, signal_sizes=()): + return ops.FFTWithSize(signal_ndim, inverse, real, norm, onesided, signal_sizes)(x) + +@test_utils.run_with_cell +def rfft_and_irfft_forward_func(x, signal_ndim, inverse, real, norm='backward', onesided=True, signal_sizes=()): + x = ops.FFTWithSize(signal_ndim, inverse, real, norm, onesided, signal_sizes)(x) + return ops.FFTWithSize(signal_ndim, not inverse, real, norm, onesided, signal_sizes)(x) + +@test_utils.run_with_cell +def fft_backward_func(x, signal_ndim, inverse, real, norm='backward', onesided=True, signal_sizes=()): + return ops.grad(fft_forward_func, (0,))(x, signal_ndim, inverse, real, norm, onesided, signal_sizes) + +@test_utils.run_with_cell +def rfft_and_irfft_backward_func(x, signal_ndim, inverse, real, norm='backward', onesided=True, signal_sizes=()): + return ops.grad(rfft_and_irfft_forward_func, (0,))(x, signal_ndim, inverse, real, norm, onesided, signal_sizes) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level0', card_mark='onecard', essential_mark='essential') +@pytest.mark.parametrize('dtype, eps', [(np.complex64, 1e-6), (np.complex128, 1e-6)]) +def test_fftwithsize_fft_ifft(dtype, eps): + """ + Feature: fft & ifft function + Description: test cases for fft & ifft + Expectation: the result matches pytorch + """ + x = Tensor(np.array([1.6243454+0.j, -0.6117564+0.j, -0.5281718+0.j, -1.0729686+0.j]).astype(dtype)) + expect = np.array([-0.5885514+0.j, 2.1525173-0.46121222j, 2.7808986+0.j, 2.1525173+0.46121222j]).astype(dtype) + error = np.ones(shape=[4]) * eps + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + output = fft_forward_func(x, 1, False, False) + diff = np.abs(output.asnumpy() - expect) + assert np.all(diff < error) + + output_ifft = fft_forward_func(output, 1, True, False) + diff_ifft = np.abs(output_ifft.asnumpy() - x.asnumpy()) + assert np.all(diff_ifft < error) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level0', card_mark='onecard', essential_mark='essential') +@pytest.mark.parametrize('dtype, eps', [(np.complex64, 1e-6), (np.complex128, 1e-6)]) +def test_fftwithsize_fft2_ifft2(dtype, eps): + """ + Feature: fft2 & ifft2 function + Description: test cases for fft2 & ifft2 + Expectation: the result matches pytorch + """ + x = Tensor(np.array([[1.6243454+0.j, -0.6117564+0.j], [-0.5281718+0.j, -1.0729686+0.j]]).astype(dtype)) + expect = np.array([[-0.5885514+0.j, 2.7808986+0.j], [2.6137295+0.j, 1.6913052+0.j]]).astype(dtype) + error = np.ones(shape=[2, 2]) * eps + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + output = fft_forward_func(x, 2, False, False) + diff = np.abs(output.asnumpy() - expect) + assert np.all(diff < error) + + output_ifft2 = fft_forward_func(output, 2, True, False) + diff_ifft2 = np.abs(output_ifft2.asnumpy() - x.asnumpy()) + assert np.all(diff_ifft2 < error) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level0', card_mark='onecard', essential_mark='essential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_fft_with_size_rfft3_forward(mode): + """ + Feature: rfft3 forward function + Description: test cases for rfft + Expectation: the result matches pytorch + """ + ms.context.set_context(mode=mode) + x = np.arange(1 * 2 * 3 * 4, dtype=np.float64).reshape(1, 2, 3, 4) + ms_x = ms.Tensor(x) + output = fft_forward_func(ms_x, 3, False, True) + expect = np.fft.rfftn(x, s=(2, 3, 4)) + assert np.allclose(output.asnumpy(), expect) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level0', card_mark='onecard', essential_mark='essential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_fft_with_size_irfft3_forward(mode): + """ + Feature: irfft3 forward function + Description: test cases for irfft3 + Expectation: the result matches pytorch + """ + ms.context.set_context(mode=mode) + x = np.arange(1 * 2 * 3 * 3, dtype=np.complex128).reshape(1, 2, 3, 3) + ms_x = ms.Tensor(x) + output = fft_forward_func(ms_x, 3, True, True) + expect = np.fft.irfftn(x, s=(2, 3, 4)) + assert np.allclose(output.asnumpy(), expect) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level0', card_mark='onecard', essential_mark='essential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_fft_with_size_rfft3_backward(mode): + """ + Feature: rfft3 backward function + Description: test cases for rfft3 + Expectation: the result matches pytorch + """ + ms.context.set_context(mode=mode) + dim1 = 1 + dim2 = 2 + dim3 = 3 + dim4 = 4 + offset_size = dim1 * dim2 * dim3 * dim4 + x = np.arange(offset_size, dtype=np.float64).reshape(dim1, dim2, dim3, dim4) + ms_x = ms.Tensor(x) + output = fft_backward_func(ms_x, 3, False, True) + dout = np.ones((dim1, dim2, dim3, dim4 // 2 + 1), dtype=np.complex128) + concat_array = np.zeros((dim1, dim2, dim3, dim4 - dim4 // 2 - 1)) + concat_array = concat_array.astype(np.complex128) + dout = np.concatenate((dout, concat_array), axis=-1) + expect = np.fft.ifftn(dout, s=(dim2, dim3, dim4)) * offset_size + assert np.allclose(output.asnumpy(), expect.real) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level0', card_mark='onecard', essential_mark='essential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_fft_with_size_rfft3_and_irfft3_backward(mode): + """ + Feature: rfft3_and_irfft3 function + Description: test cases for rfft3_and_irfft3 + Expectation: the result matches pytorch + """ + ms.context.set_context(mode=mode) + dim1 = 1 + dim2 = 2 + dim3 = 3 + dim4 = 4 + offset_size = dim1 * dim2 * dim3 * dim4 + x = np.arange(offset_size, dtype=np.float64).reshape(dim1, dim2, dim3, dim4) + ms_x = ms.Tensor(x) + output = rfft_and_irfft_backward_func(ms_x, 3, False, True) + expect = np.ones((dim1, dim2, dim3, dim4)) + assert np.allclose(output.asnumpy(), expect) diff --git a/tests/st/ops/gpu/test_floordiv_op.py b/tests/st/ops/gpu/test_floordiv_op.py index bcf8bdf239a..f5b0dd0ddeb 100644 --- a/tests/st/ops/gpu/test_floordiv_op.py +++ b/tests/st/ops/gpu/test_floordiv_op.py @@ -1,115 +1,115 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.ops import operations as P - -class NetFloorDiv(nn.Cell): - def __init__(self): - super(NetFloorDiv, self).__init__() - self.floordiv = P.FloorDiv() - - def construct(self, x, y): - return self.floordiv(x, y) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_floor_div(): - x0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32) - y0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32) - x1_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32) - y1_np = np.random.randint(1, 5, (2, 1, 4, 4)).astype(np.float32) - x2_np = np.random.randint(1, 5, (2, 1, 1, 4, 9)).astype(np.float32) - y2_np = np.random.randint(1, 5, (2, 3, 4, 4, 9)).astype(np.float32) - x3_np = np.random.randint(1, 5, 1).astype(np.float32) - y3_np = np.random.randint(1, 5, 1).astype(np.float32) - x4_np = np.array(768).astype(np.float32) - y4_np = np.array(3072.5).astype(np.float32) - x5_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float16) - y5_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float16) - x6_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.int32) - y6_np = np.random.randint(1, 5, (2, 1, 4, 4)).astype(np.int32) - - x0 = Tensor(x0_np) - y0 = Tensor(y0_np) - x1 = Tensor(x1_np) - y1 = Tensor(y1_np) - x2 = Tensor(x2_np) - y2 = Tensor(y2_np) - x3 = Tensor(x3_np) - y3 = Tensor(y3_np) - x4 = Tensor(x4_np) - y4 = Tensor(y4_np) - x5 = Tensor(x5_np) - y5 = Tensor(y5_np) - x6 = Tensor(x6_np) - y6 = Tensor(y6_np) - - context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - floor_div = NetFloorDiv() - output0 = floor_div(x0, y0) - expect0 = np.floor_divide(x0_np, y0_np) - diff0 = output0.asnumpy() - expect0 - error0 = np.ones(shape=expect0.shape) * 1.0e-5 - assert np.all(diff0 < error0) - assert output0.shape == expect0.shape - - output1 = floor_div(x1, y1) - expect1 = np.floor_divide(x1_np, y1_np) - diff1 = output1.asnumpy() - expect1 - error1 = np.ones(shape=expect1.shape) * 1.0e-5 - assert np.all(diff1 < error1) - assert output1.shape == expect1.shape - - output2 = floor_div(x2, y2) - expect2 = np.floor_divide(x2_np, y2_np) - diff2 = output2.asnumpy() - expect2 - error2 = np.ones(shape=expect2.shape) * 1.0e-5 - assert np.all(diff2 < error2) - assert output2.shape == expect2.shape - - context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') - output3 = floor_div(x3, y3) - expect3 = np.floor_divide(x3_np, y3_np) - diff3 = output3.asnumpy() - expect3 - error3 = np.ones(shape=expect3.shape) * 1.0e-5 - assert np.all(diff3 < error3) - assert output3.shape == expect3.shape - - output4 = floor_div(x4, y4) - expect4 = np.floor_divide(x4_np, y4_np) - diff4 = output4.asnumpy() - expect4 - error4 = np.ones(shape=expect4.shape) * 1.0e-5 - assert np.all(diff4 < error4) - assert output4.shape == expect4.shape - - output5 = floor_div(x5, y5) - expect5 = np.floor_divide(x5_np, y5_np) - diff5 = output5.asnumpy() - expect5 - error5 = np.ones(shape=expect5.shape) * 1.0e-5 - assert np.all(diff5 < error5) - assert output5.shape == expect5.shape - - output6 = floor_div(x6, y6) - expect6 = np.floor_divide(x6_np, y6_np) - diff6 = output6.asnumpy() - expect6 - error6 = np.ones(shape=expect6.shape) * 1.0e-5 - assert np.all(diff6 < error6) - assert output6.shape == expect6.shape +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + +class NetFloorDiv(nn.Cell): + def __init__(self): + super(NetFloorDiv, self).__init__() + self.floordiv = P.FloorDiv() + + def construct(self, x, y): + return self.floordiv(x, y) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_floor_div(): + x0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32) + y0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32) + x1_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32) + y1_np = np.random.randint(1, 5, (2, 1, 4, 4)).astype(np.float32) + x2_np = np.random.randint(1, 5, (2, 1, 1, 4, 9)).astype(np.float32) + y2_np = np.random.randint(1, 5, (2, 3, 4, 4, 9)).astype(np.float32) + x3_np = np.random.randint(1, 5, 1).astype(np.float32) + y3_np = np.random.randint(1, 5, 1).astype(np.float32) + x4_np = np.array(768).astype(np.float32) + y4_np = np.array(3072.5).astype(np.float32) + x5_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float16) + y5_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float16) + x6_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.int32) + y6_np = np.random.randint(1, 5, (2, 1, 4, 4)).astype(np.int32) + + x0 = Tensor(x0_np) + y0 = Tensor(y0_np) + x1 = Tensor(x1_np) + y1 = Tensor(y1_np) + x2 = Tensor(x2_np) + y2 = Tensor(y2_np) + x3 = Tensor(x3_np) + y3 = Tensor(y3_np) + x4 = Tensor(x4_np) + y4 = Tensor(y4_np) + x5 = Tensor(x5_np) + y5 = Tensor(y5_np) + x6 = Tensor(x6_np) + y6 = Tensor(y6_np) + + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + floor_div = NetFloorDiv() + output0 = floor_div(x0, y0) + expect0 = np.floor_divide(x0_np, y0_np) + diff0 = output0.asnumpy() - expect0 + error0 = np.ones(shape=expect0.shape) * 1.0e-5 + assert np.all(diff0 < error0) + assert output0.shape == expect0.shape + + output1 = floor_div(x1, y1) + expect1 = np.floor_divide(x1_np, y1_np) + diff1 = output1.asnumpy() - expect1 + error1 = np.ones(shape=expect1.shape) * 1.0e-5 + assert np.all(diff1 < error1) + assert output1.shape == expect1.shape + + output2 = floor_div(x2, y2) + expect2 = np.floor_divide(x2_np, y2_np) + diff2 = output2.asnumpy() - expect2 + error2 = np.ones(shape=expect2.shape) * 1.0e-5 + assert np.all(diff2 < error2) + assert output2.shape == expect2.shape + + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + output3 = floor_div(x3, y3) + expect3 = np.floor_divide(x3_np, y3_np) + diff3 = output3.asnumpy() - expect3 + error3 = np.ones(shape=expect3.shape) * 1.0e-5 + assert np.all(diff3 < error3) + assert output3.shape == expect3.shape + + output4 = floor_div(x4, y4) + expect4 = np.floor_divide(x4_np, y4_np) + diff4 = output4.asnumpy() - expect4 + error4 = np.ones(shape=expect4.shape) * 1.0e-5 + assert np.all(diff4 < error4) + assert output4.shape == expect4.shape + + output5 = floor_div(x5, y5) + expect5 = np.floor_divide(x5_np, y5_np) + diff5 = output5.asnumpy() - expect5 + error5 = np.ones(shape=expect5.shape) * 1.0e-5 + assert np.all(diff5 < error5) + assert output5.shape == expect5.shape + + output6 = floor_div(x6, y6) + expect6 = np.floor_divide(x6_np, y6_np) + diff6 = output6.asnumpy() - expect6 + error6 = np.ones(shape=expect6.shape) * 1.0e-5 + assert np.all(diff6 < error6) + assert output6.shape == expect6.shape diff --git a/tests/st/ops/gpu/test_fractionalavgpool_op.py b/tests/st/ops/gpu/test_fractionalavgpool_op.py index 3e69c770743..a7e91e9b717 100644 --- a/tests/st/ops/gpu/test_fractionalavgpool_op.py +++ b/tests/st/ops/gpu/test_fractionalavgpool_op.py @@ -1,196 +1,196 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest -import mindspore.nn as nn -import mindspore.context as context -from mindspore import Tensor -import mindspore.ops.operations.nn_ops as ops -import mindspore.ops.operations._grad_ops as grad_ops - - -class NetFractionalAvgPool(nn.Cell): - def __init__(self): - super(NetFractionalAvgPool, self).__init__() - self.fractional_avg_pool = ops.FractionalAvgPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0]) - - def construct(self, x): - return self.fractional_avg_pool(x) - - -class NetFractionalAvgPoolRealRandom(nn.Cell): - def __init__(self): - super(NetFractionalAvgPoolRealRandom, self).__init__() - self.fractional_avg_pool = ops.FractionalAvgPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0], deterministic=True, - pseudo_random=False, seed=5454, seed2=144) - - def construct(self, x): - return self.fractional_avg_pool(x) - - -class NetFractionalAvgPoolOverlapPing(nn.Cell): - def __init__(self): - super(NetFractionalAvgPoolOverlapPing, self).__init__() - self.fractional_avg_pool = ops.FractionalAvgPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0], overlapping=True) - - def construct(self, x): - return self.fractional_avg_pool(x) - - -class NetFractionalAvgPoolGrad(nn.Cell): - def __init__(self): - super(NetFractionalAvgPoolGrad, self).__init__() - self.fractional_avg_pool_grad = grad_ops.FractionalAvgPoolGrad() - - def construct(self, orig_input, out_backprop, row_pooling_sequence, col_pooling_sequence): - return self.fractional_avg_pool_grad(orig_input, out_backprop, row_pooling_sequence, - col_pooling_sequence) - - -class NetFractionalAvgPoolGradOverlapping(nn.Cell): - def __init__(self): - super(NetFractionalAvgPoolGradOverlapping, self).__init__() - self.fractional_avg_pool_grad = grad_ops.FractionalAvgPoolGrad(overlapping=True) - - def construct(self, orig_input, out_backprop, row_pooling_sequence, col_pooling_sequence): - return self.fractional_avg_pool_grad(orig_input, out_backprop, row_pooling_sequence, - col_pooling_sequence) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_fractionalavgpool_graph(): - """ - Feature: FractionalAvgPool - Description: Test of input - Expectation: The results are as expected - """ - context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - types = [np.float32, np.float64, np.int32, np.int64] - for type_i in types: - x = Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, - 10, 11, 12, 13, 14, 15, 16]).reshape([1, 4, 4, 1]).astype(type_i)) - net = NetFractionalAvgPool() - output = net(x) - output_y = output[0].asnumpy() - output_row_pooling_sequence = output[1].asnumpy() - output_col_pooling_sequence = output[2].asnumpy() - expect_output_y = np.array([[[[3.5], [5.5]], [[11.5], [13.5]]]]).astype(type_i) - expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - assert np.allclose(output_y, expect_output_y) - assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) - assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) - - net = NetFractionalAvgPoolRealRandom() - output = net(x) - type0 = output[0].asnumpy().dtype - assert type0 == type_i - - net = NetFractionalAvgPoolOverlapPing() - output = net(x) - output_y = output[0].asnumpy() - output_row_pooling_sequence = output[1].asnumpy() - output_col_pooling_sequence = output[2].asnumpy() - expect_output_y = np.array([[[[6], [7.5]], [[12], [13.5]]]]).astype(type_i) - expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - assert np.allclose(output_y, expect_output_y) - assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) - assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) - - netgrad = NetFractionalAvgPoolGrad() - x_shape = Tensor(np.array([1, 4, 4, 1]).astype(np.int64)) - out_backprop = Tensor(np.ones([1, 2, 2, 1]).astype(type_i)) - output_grad = netgrad(x_shape, out_backprop, output[1], output[2]) - output_grad_y = output_grad[0].asnumpy() - expect_output_grad_y = np.array([[[[0.25], [0.25], [0.25], [0.25]], - [[0.25], [0.25], [0.25], [0.25]], - [[0.25], [0.25], [0.25], [0.25]], - [[0.25], [0.25], [0.25], [0.25]]]]).astype(type_i) - assert np.allclose(output_grad_y, expect_output_grad_y) - - netgrad = NetFractionalAvgPoolGradOverlapping() - out_backprop = Tensor(np.ones([1, 2, 2, 1]).astype(type_i)) - output_grad = netgrad(x_shape, out_backprop, output[1], output[2]) - output_grad_y = output_grad[0].asnumpy() - expect_output_grad_y = np.array([[[[0.11111111], [0.11111111], [0.2777778], [0.16666667]], - [[0.11111111], [0.11111111], [0.2777778], [0.16666667]], - [[0.2777778], [0.2777778], [0.6944444], [0.41666666]], - [[0.16666667], [0.16666667], [0.41666666], [0.25]]]]).astype(type_i) - assert np.allclose(output_grad_y, expect_output_grad_y) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_fractionalavgpool_pynative(): - """ - Feature: FractionalAvgPool - Description: Test of input - Expectation: The results are as expected - """ - context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') - types = [np.float32, np.float64, np.int32, np.int64] - for type_i in types: - x = Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, - 10, 11, 12, 13, 14, 15, 16]).reshape([1, 4, 4, 1]).astype(type_i)) - fractionalavgpool = ops.FractionalAvgPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0]) - output = fractionalavgpool(x) - output_y = output[0].asnumpy() - output_row_pooling_sequence = output[1].asnumpy() - output_col_pooling_sequence = output[2].asnumpy() - expect_output_y = np.array([[[[3.5], [5.5]], [[11.5], [13.5]]]]).astype(type_i) - expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - assert np.allclose(output_y, expect_output_y) - assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) - assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) - - fractionalavgpool = ops.FractionalAvgPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0], - deterministic=True, pseudo_random=False, seed=5454, seed2=144) - output = fractionalavgpool(x) - type0 = output[0].asnumpy().dtype - assert type0 == type_i - - fractionalavgpool = ops.FractionalAvgPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0], overlapping=True) - output = fractionalavgpool(x) - output_y = output[0].asnumpy() - output_row_pooling_sequence = output[1].asnumpy() - output_col_pooling_sequence = output[2].asnumpy() - expect_output_y = np.array([[[[6], [7.5]], [[12], [13.5]]]]).astype(type_i) - expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - assert np.allclose(output_y, expect_output_y) - assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) - assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) - - fractionalavgpoolgrad = grad_ops.FractionalAvgPoolGrad() - x_shape = Tensor(np.array([1, 4, 4, 1]).astype(np.int64)) - out_backprop = Tensor(np.ones([1, 2, 2, 1]).astype(type_i)) - output_grad = fractionalavgpoolgrad(x_shape, out_backprop, output[1], output[2]) - expect_output_grad_y = np.array([[[[0.25], [0.25], [0.25], [0.25]], - [[0.25], [0.25], [0.25], [0.25]], - [[0.25], [0.25], [0.25], [0.25]], - [[0.25], [0.25], [0.25], [0.25]]]]).astype(type_i) - assert np.allclose(output_grad.asnumpy(), expect_output_grad_y) - - fractionalavgpoolgrad = grad_ops.FractionalAvgPoolGrad(overlapping=True) - out_backprop = Tensor(np.ones([1, 2, 2, 1]).astype(type_i)) - output_grad = fractionalavgpoolgrad(x_shape, out_backprop, output[1], output[2]) - expect_output_grad_y = np.array([[[[0.11111111], [0.11111111], [0.2777778], [0.16666667]], - [[0.11111111], [0.11111111], [0.2777778], [0.16666667]], - [[0.2777778], [0.2777778], [0.6944444], [0.41666666]], - [[0.16666667], [0.16666667], [0.41666666], [0.25]]]]).astype(type_i) - assert np.allclose(output_grad.asnumpy(), expect_output_grad_y) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest +import mindspore.nn as nn +import mindspore.context as context +from mindspore import Tensor +import mindspore.ops.operations.nn_ops as ops +import mindspore.ops.operations._grad_ops as grad_ops + + +class NetFractionalAvgPool(nn.Cell): + def __init__(self): + super(NetFractionalAvgPool, self).__init__() + self.fractional_avg_pool = ops.FractionalAvgPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0]) + + def construct(self, x): + return self.fractional_avg_pool(x) + + +class NetFractionalAvgPoolRealRandom(nn.Cell): + def __init__(self): + super(NetFractionalAvgPoolRealRandom, self).__init__() + self.fractional_avg_pool = ops.FractionalAvgPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0], deterministic=True, + pseudo_random=False, seed=5454, seed2=144) + + def construct(self, x): + return self.fractional_avg_pool(x) + + +class NetFractionalAvgPoolOverlapPing(nn.Cell): + def __init__(self): + super(NetFractionalAvgPoolOverlapPing, self).__init__() + self.fractional_avg_pool = ops.FractionalAvgPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0], overlapping=True) + + def construct(self, x): + return self.fractional_avg_pool(x) + + +class NetFractionalAvgPoolGrad(nn.Cell): + def __init__(self): + super(NetFractionalAvgPoolGrad, self).__init__() + self.fractional_avg_pool_grad = grad_ops.FractionalAvgPoolGrad() + + def construct(self, orig_input, out_backprop, row_pooling_sequence, col_pooling_sequence): + return self.fractional_avg_pool_grad(orig_input, out_backprop, row_pooling_sequence, + col_pooling_sequence) + + +class NetFractionalAvgPoolGradOverlapping(nn.Cell): + def __init__(self): + super(NetFractionalAvgPoolGradOverlapping, self).__init__() + self.fractional_avg_pool_grad = grad_ops.FractionalAvgPoolGrad(overlapping=True) + + def construct(self, orig_input, out_backprop, row_pooling_sequence, col_pooling_sequence): + return self.fractional_avg_pool_grad(orig_input, out_backprop, row_pooling_sequence, + col_pooling_sequence) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_fractionalavgpool_graph(): + """ + Feature: FractionalAvgPool + Description: Test of input + Expectation: The results are as expected + """ + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + types = [np.float32, np.float64, np.int32, np.int64] + for type_i in types: + x = Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16]).reshape([1, 4, 4, 1]).astype(type_i)) + net = NetFractionalAvgPool() + output = net(x) + output_y = output[0].asnumpy() + output_row_pooling_sequence = output[1].asnumpy() + output_col_pooling_sequence = output[2].asnumpy() + expect_output_y = np.array([[[[3.5], [5.5]], [[11.5], [13.5]]]]).astype(type_i) + expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + assert np.allclose(output_y, expect_output_y) + assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) + assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) + + net = NetFractionalAvgPoolRealRandom() + output = net(x) + type0 = output[0].asnumpy().dtype + assert type0 == type_i + + net = NetFractionalAvgPoolOverlapPing() + output = net(x) + output_y = output[0].asnumpy() + output_row_pooling_sequence = output[1].asnumpy() + output_col_pooling_sequence = output[2].asnumpy() + expect_output_y = np.array([[[[6], [7.5]], [[12], [13.5]]]]).astype(type_i) + expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + assert np.allclose(output_y, expect_output_y) + assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) + assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) + + netgrad = NetFractionalAvgPoolGrad() + x_shape = Tensor(np.array([1, 4, 4, 1]).astype(np.int64)) + out_backprop = Tensor(np.ones([1, 2, 2, 1]).astype(type_i)) + output_grad = netgrad(x_shape, out_backprop, output[1], output[2]) + output_grad_y = output_grad[0].asnumpy() + expect_output_grad_y = np.array([[[[0.25], [0.25], [0.25], [0.25]], + [[0.25], [0.25], [0.25], [0.25]], + [[0.25], [0.25], [0.25], [0.25]], + [[0.25], [0.25], [0.25], [0.25]]]]).astype(type_i) + assert np.allclose(output_grad_y, expect_output_grad_y) + + netgrad = NetFractionalAvgPoolGradOverlapping() + out_backprop = Tensor(np.ones([1, 2, 2, 1]).astype(type_i)) + output_grad = netgrad(x_shape, out_backprop, output[1], output[2]) + output_grad_y = output_grad[0].asnumpy() + expect_output_grad_y = np.array([[[[0.11111111], [0.11111111], [0.2777778], [0.16666667]], + [[0.11111111], [0.11111111], [0.2777778], [0.16666667]], + [[0.2777778], [0.2777778], [0.6944444], [0.41666666]], + [[0.16666667], [0.16666667], [0.41666666], [0.25]]]]).astype(type_i) + assert np.allclose(output_grad_y, expect_output_grad_y) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_fractionalavgpool_pynative(): + """ + Feature: FractionalAvgPool + Description: Test of input + Expectation: The results are as expected + """ + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + types = [np.float32, np.float64, np.int32, np.int64] + for type_i in types: + x = Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16]).reshape([1, 4, 4, 1]).astype(type_i)) + fractionalavgpool = ops.FractionalAvgPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0]) + output = fractionalavgpool(x) + output_y = output[0].asnumpy() + output_row_pooling_sequence = output[1].asnumpy() + output_col_pooling_sequence = output[2].asnumpy() + expect_output_y = np.array([[[[3.5], [5.5]], [[11.5], [13.5]]]]).astype(type_i) + expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + assert np.allclose(output_y, expect_output_y) + assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) + assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) + + fractionalavgpool = ops.FractionalAvgPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0], + deterministic=True, pseudo_random=False, seed=5454, seed2=144) + output = fractionalavgpool(x) + type0 = output[0].asnumpy().dtype + assert type0 == type_i + + fractionalavgpool = ops.FractionalAvgPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0], overlapping=True) + output = fractionalavgpool(x) + output_y = output[0].asnumpy() + output_row_pooling_sequence = output[1].asnumpy() + output_col_pooling_sequence = output[2].asnumpy() + expect_output_y = np.array([[[[6], [7.5]], [[12], [13.5]]]]).astype(type_i) + expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + assert np.allclose(output_y, expect_output_y) + assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) + assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) + + fractionalavgpoolgrad = grad_ops.FractionalAvgPoolGrad() + x_shape = Tensor(np.array([1, 4, 4, 1]).astype(np.int64)) + out_backprop = Tensor(np.ones([1, 2, 2, 1]).astype(type_i)) + output_grad = fractionalavgpoolgrad(x_shape, out_backprop, output[1], output[2]) + expect_output_grad_y = np.array([[[[0.25], [0.25], [0.25], [0.25]], + [[0.25], [0.25], [0.25], [0.25]], + [[0.25], [0.25], [0.25], [0.25]], + [[0.25], [0.25], [0.25], [0.25]]]]).astype(type_i) + assert np.allclose(output_grad.asnumpy(), expect_output_grad_y) + + fractionalavgpoolgrad = grad_ops.FractionalAvgPoolGrad(overlapping=True) + out_backprop = Tensor(np.ones([1, 2, 2, 1]).astype(type_i)) + output_grad = fractionalavgpoolgrad(x_shape, out_backprop, output[1], output[2]) + expect_output_grad_y = np.array([[[[0.11111111], [0.11111111], [0.2777778], [0.16666667]], + [[0.11111111], [0.11111111], [0.2777778], [0.16666667]], + [[0.2777778], [0.2777778], [0.6944444], [0.41666666]], + [[0.16666667], [0.16666667], [0.41666666], [0.25]]]]).astype(type_i) + assert np.allclose(output_grad.asnumpy(), expect_output_grad_y) diff --git a/tests/st/ops/gpu/test_fractionalmaxpool3dwithfixedksize_op.py b/tests/st/ops/gpu/test_fractionalmaxpool3dwithfixedksize_op.py index 5b37d66ca79..c6ea9df7430 100644 --- a/tests/st/ops/gpu/test_fractionalmaxpool3dwithfixedksize_op.py +++ b/tests/st/ops/gpu/test_fractionalmaxpool3dwithfixedksize_op.py @@ -1,80 +1,80 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest -import mindspore.nn as nn -import mindspore.context as context -from mindspore import Tensor -import mindspore.ops.operations.nn_ops as ops -import mindspore.ops.operations._grad_ops as grad_ops - - -class NetFractionalMaxPool3DWithFixedKsize(nn.Cell): - def __init__(self, ksize, output_shape): - super(NetFractionalMaxPool3DWithFixedKsize, self).__init__() - self.fractional_max_pool_3d_with_fixed_ksize = ops.FractionalMaxPool3DWithFixedKsize(ksize, output_shape) - - def construct(self, x, random_sapmples): - return self.fractional_max_pool_3d_with_fixed_ksize(x, random_sapmples) - - -class NetFractionalMaxPool3DGradWithFixedKsize(nn.Cell): - def __init__(self): - super(NetFractionalMaxPool3DGradWithFixedKsize, self).__init__() - self.fractional_max_pool_3d_grad_with_fixed_ksize = grad_ops.FractionalMaxPool3DGradWithFixedKsize() - - def construct(self, origin_input, out_backprop, argmax): - return self.fractional_max_pool_3d_grad_with_fixed_ksize(origin_input, out_backprop, argmax) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_fractionalmaxpool3dwithfixedksize(): - """ - Feature: FractionalMaxPool3DWithFixedKsize - Description: Test of input - Expectation: The results are as expected - """ - context_mode_types = [context.GRAPH_MODE, context.PYNATIVE_MODE] - types_input1 = [np.float16, np.float32, np.int32, np.int64] - types_input2 = [np.float16, np.float32] - for context_mode_type in context_mode_types: - context.set_context(mode=context_mode_type, device_target='GPU') - for type_input1 in types_input1: - for type_input2 in types_input2: - x_np = np.array([i+1 for i in range(64)]).reshape([1, 1, 4, 4, 4]).astype(type_input1) - x_ms = Tensor(x_np) - random_samples = Tensor(np.array([0.5, 0.5, 0.8]).reshape([1, 1, 3]).astype(type_input2)) - ksize = (1, 1, 1) - output_shape = (2, 2, 3) - net = NetFractionalMaxPool3DWithFixedKsize(ksize, output_shape) - output_ms, argmax = net(x_ms, random_samples) - expect_output = np.array([[[[[1, 2, 4], [13, 14, 16]], - [[49, 50, 52], [61, 62, 64]]]]]).astype(type_input1) - expect_output_argmax = np.array([[[[[0, 1, 3], [12, 13, 15]], - [[48, 49, 51], [60, 61, 63]]]]]).astype(type_input2) - assert np.allclose(output_ms.asnumpy(), expect_output) - assert np.allclose(argmax.asnumpy(), expect_output_argmax) - - out_backprop = Tensor(np.array([i+1 for i in range(12)]).reshape([1, 1, 2, 2, 3]).astype(type_input1)) - net_grad = NetFractionalMaxPool3DGradWithFixedKsize() - output_grad = net_grad(x_ms, out_backprop, argmax) - expect_output_grad = np.array([[[[[1, 2, 0, 3], [0, 0, 0, 0], [0, 0, 0, 0], [4, 5, 0, 6]], - [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], - [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], - [[7, 8, 0, 9], [0, 0, 0, 0], [0, 0, 0, 0], - [10, 11, 0, 12]]]]]).astype(type_input2) - assert np.allclose(output_grad.asnumpy(), expect_output_grad) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest +import mindspore.nn as nn +import mindspore.context as context +from mindspore import Tensor +import mindspore.ops.operations.nn_ops as ops +import mindspore.ops.operations._grad_ops as grad_ops + + +class NetFractionalMaxPool3DWithFixedKsize(nn.Cell): + def __init__(self, ksize, output_shape): + super(NetFractionalMaxPool3DWithFixedKsize, self).__init__() + self.fractional_max_pool_3d_with_fixed_ksize = ops.FractionalMaxPool3DWithFixedKsize(ksize, output_shape) + + def construct(self, x, random_sapmples): + return self.fractional_max_pool_3d_with_fixed_ksize(x, random_sapmples) + + +class NetFractionalMaxPool3DGradWithFixedKsize(nn.Cell): + def __init__(self): + super(NetFractionalMaxPool3DGradWithFixedKsize, self).__init__() + self.fractional_max_pool_3d_grad_with_fixed_ksize = grad_ops.FractionalMaxPool3DGradWithFixedKsize() + + def construct(self, origin_input, out_backprop, argmax): + return self.fractional_max_pool_3d_grad_with_fixed_ksize(origin_input, out_backprop, argmax) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_fractionalmaxpool3dwithfixedksize(): + """ + Feature: FractionalMaxPool3DWithFixedKsize + Description: Test of input + Expectation: The results are as expected + """ + context_mode_types = [context.GRAPH_MODE, context.PYNATIVE_MODE] + types_input1 = [np.float16, np.float32, np.int32, np.int64] + types_input2 = [np.float16, np.float32] + for context_mode_type in context_mode_types: + context.set_context(mode=context_mode_type, device_target='GPU') + for type_input1 in types_input1: + for type_input2 in types_input2: + x_np = np.array([i+1 for i in range(64)]).reshape([1, 1, 4, 4, 4]).astype(type_input1) + x_ms = Tensor(x_np) + random_samples = Tensor(np.array([0.5, 0.5, 0.8]).reshape([1, 1, 3]).astype(type_input2)) + ksize = (1, 1, 1) + output_shape = (2, 2, 3) + net = NetFractionalMaxPool3DWithFixedKsize(ksize, output_shape) + output_ms, argmax = net(x_ms, random_samples) + expect_output = np.array([[[[[1, 2, 4], [13, 14, 16]], + [[49, 50, 52], [61, 62, 64]]]]]).astype(type_input1) + expect_output_argmax = np.array([[[[[0, 1, 3], [12, 13, 15]], + [[48, 49, 51], [60, 61, 63]]]]]).astype(type_input2) + assert np.allclose(output_ms.asnumpy(), expect_output) + assert np.allclose(argmax.asnumpy(), expect_output_argmax) + + out_backprop = Tensor(np.array([i+1 for i in range(12)]).reshape([1, 1, 2, 2, 3]).astype(type_input1)) + net_grad = NetFractionalMaxPool3DGradWithFixedKsize() + output_grad = net_grad(x_ms, out_backprop, argmax) + expect_output_grad = np.array([[[[[1, 2, 0, 3], [0, 0, 0, 0], [0, 0, 0, 0], [4, 5, 0, 6]], + [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], + [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], + [[7, 8, 0, 9], [0, 0, 0, 0], [0, 0, 0, 0], + [10, 11, 0, 12]]]]]).astype(type_input2) + assert np.allclose(output_grad.asnumpy(), expect_output_grad) diff --git a/tests/st/ops/gpu/test_fractionalmaxpool_op.py b/tests/st/ops/gpu/test_fractionalmaxpool_op.py index 2f2629a14ba..1f8210ee8c6 100644 --- a/tests/st/ops/gpu/test_fractionalmaxpool_op.py +++ b/tests/st/ops/gpu/test_fractionalmaxpool_op.py @@ -1,188 +1,188 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest -import mindspore.nn as nn -import mindspore.context as context -from mindspore import Tensor -import mindspore.ops.operations.nn_ops as ops -import mindspore.ops.operations._grad_ops as grad_ops - - -class NetFractionalMaxPool(nn.Cell): - def __init__(self): - super(NetFractionalMaxPool, self).__init__() - self.fractional_max_pool = ops.FractionalMaxPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0]) - - def construct(self, x): - return self.fractional_max_pool(x) - - -class NetFractionalMaxPoolRealRandom(nn.Cell): - def __init__(self): - super(NetFractionalMaxPoolRealRandom, self).__init__() - self.fractional_max_pool = ops.FractionalMaxPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0], deterministic=True, - pseudo_random=False, seed=5454, seed2=144) - - def construct(self, x): - return self.fractional_max_pool(x) - - -class NetFractionalMaxPoolOverlapPing(nn.Cell): - def __init__(self): - super(NetFractionalMaxPoolOverlapPing, self).__init__() - self.fractional_max_pool = ops.FractionalMaxPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0], overlapping=True) - - def construct(self, x): - return self.fractional_max_pool(x) - - -class NetFractionalMaxPoolGrad(nn.Cell): - def __init__(self): - super(NetFractionalMaxPoolGrad, self).__init__() - self.fractional_max_pool_grad = grad_ops.FractionalMaxPoolGrad() - - def construct(self, orig_input, orig_output, out_backprop, row_pooling_sequence, col_pooling_sequence): - return self.fractional_max_pool_grad(orig_input, orig_output, out_backprop, row_pooling_sequence, - col_pooling_sequence) - - -class NetFractionalMaxPoolGradOverlapping(nn.Cell): - def __init__(self): - super(NetFractionalMaxPoolGradOverlapping, self).__init__() - self.fractional_max_pool_grad = grad_ops.FractionalMaxPoolGrad(overlapping=True) - - def construct(self, orig_input, orig_output, out_backprop, row_pooling_sequence, col_pooling_sequence): - return self.fractional_max_pool_grad(orig_input, orig_output, out_backprop, row_pooling_sequence, - col_pooling_sequence) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_fractionalmaxpool_graph(): - """ - Feature: FractionalMaxPool - Description: Test of input - Expectation: The results are as expected - """ - context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - types = [np.float32, np.float64, np.int32, np.int64] - for type_i in types: - x = Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, - 10, 11, 12, 13, 14, 15, 16]).reshape([1, 4, 4, 1]).astype(type_i)) - net = NetFractionalMaxPool() - output = net(x) - output_y = output[0].asnumpy() - output_row_pooling_sequence = output[1].asnumpy() - output_col_pooling_sequence = output[2].asnumpy() - expect_output_y = np.array([[[[6], [8]], [[14], [16]]]]).astype(type_i) - expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - assert np.allclose(output_y, expect_output_y) - assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) - assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) - - net = NetFractionalMaxPoolRealRandom() - output = net(x) - type0 = output[0].asnumpy().dtype - assert type0 == type_i - - net = NetFractionalMaxPoolOverlapPing() - output = net(x) - output_y = output[0].asnumpy() - output_row_pooling_sequence = output[1].asnumpy() - output_col_pooling_sequence = output[2].asnumpy() - expect_output_y = np.array([[[[11], [12]], [[15], [16]]]]).astype(type_i) - expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - assert np.allclose(output_y, expect_output_y) - assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) - assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) - - netgrad = NetFractionalMaxPoolGrad() - out_backprop = Tensor(np.ones([1, 2, 2, 1]).astype(type_i)) - output_grad = netgrad(x, output[0], out_backprop, output[1], output[2]) - output_grad_y = output_grad[0].asnumpy() - expect_output_grad_y = np.array([[[[0], [0], [0], [0]], [[0], [1], [0], [1]], - [[0], [0], [0], [0]], [[0], [1], [0], [1]]]]).astype(type_i) - assert np.allclose(output_grad_y, expect_output_grad_y) - - netgrad = NetFractionalMaxPoolGradOverlapping() - out_backprop = Tensor(np.ones([1, 2, 2, 1]).astype(type_i)) - output_grad = netgrad(x, output[0], out_backprop, output[1], output[2]) - output_grad_y = output_grad[0].asnumpy() - expect_output_grad_y = np.array([[[[0], [0], [0], [0]], [[0], [0], [0], [0]], - [[0], [0], [1], [1]], [[0], [0], [1], [1]]]]).astype(type_i) - assert np.allclose(output_grad_y, expect_output_grad_y) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_fractionalmaxpool_pynative(): - """ - Feature: FractionalMaxPool - Description: Test of input - Expectation: The results are as expected - """ - context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') - types = [np.float32, np.float64, np.int32, np.int64] - for type_i in types: - x = Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, - 10, 11, 12, 13, 14, 15, 16]).reshape([1, 4, 4, 1]).astype(type_i)) - fractionalmaxpool = ops.FractionalMaxPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0]) - output = fractionalmaxpool(x) - output_y = output[0].asnumpy() - output_row_pooling_sequence = output[1].asnumpy() - output_col_pooling_sequence = output[2].asnumpy() - expect_output_y = np.array([[[[6], [8]], [[14], [16]]]]).astype(type_i) - expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - assert np.allclose(output_y, expect_output_y) - assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) - assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) - - fractionalmaxpool = ops.FractionalMaxPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0], - deterministic=True, pseudo_random=False, seed=5454, seed2=144) - output = fractionalmaxpool(x) - type0 = output[0].asnumpy().dtype - assert type0 == type_i - - fractionalmaxpool = ops.FractionalMaxPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0], overlapping=True) - output = fractionalmaxpool(x) - output_y = output[0].asnumpy() - output_row_pooling_sequence = output[1].asnumpy() - output_col_pooling_sequence = output[2].asnumpy() - expect_output_y = np.array([[[[11], [12]], [[15], [16]]]]).astype(type_i) - expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) - assert np.allclose(output_y, expect_output_y) - assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) - assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) - - fractionalmaxpoolgrad = grad_ops.FractionalMaxPoolGrad() - out_backprop = Tensor(np.ones([1, 2, 2, 1]).astype(type_i)) - output_grad = fractionalmaxpoolgrad(x, output[0], out_backprop, output[1], output[2]) - output_grad_y = output_grad[0].asnumpy() - expect_output_grad_y = np.array([[[[0], [0], [0], [0]], [[0], [1], [0], [1]], - [[0], [0], [0], [0]], [[0], [1], [0], [1]]]]).astype(type_i) - assert np.allclose(output_grad_y, expect_output_grad_y) - - fractionalmaxpoolgrad = grad_ops.FractionalMaxPoolGrad(overlapping=True) - out_backprop = Tensor(np.ones([1, 2, 2, 1]).astype(type_i)) - output_grad = fractionalmaxpoolgrad(x, output[0], out_backprop, output[1], output[2]) - output_grad_y = output_grad[0].asnumpy() - expect_output_grad_y = np.array([[[[0], [0], [0], [0]], [[0], [0], [0], [0]], - [[0], [0], [1], [1]], [[0], [0], [1], [1]]]]).astype(type_i) - assert np.allclose(output_grad_y, expect_output_grad_y) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest +import mindspore.nn as nn +import mindspore.context as context +from mindspore import Tensor +import mindspore.ops.operations.nn_ops as ops +import mindspore.ops.operations._grad_ops as grad_ops + + +class NetFractionalMaxPool(nn.Cell): + def __init__(self): + super(NetFractionalMaxPool, self).__init__() + self.fractional_max_pool = ops.FractionalMaxPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0]) + + def construct(self, x): + return self.fractional_max_pool(x) + + +class NetFractionalMaxPoolRealRandom(nn.Cell): + def __init__(self): + super(NetFractionalMaxPoolRealRandom, self).__init__() + self.fractional_max_pool = ops.FractionalMaxPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0], deterministic=True, + pseudo_random=False, seed=5454, seed2=144) + + def construct(self, x): + return self.fractional_max_pool(x) + + +class NetFractionalMaxPoolOverlapPing(nn.Cell): + def __init__(self): + super(NetFractionalMaxPoolOverlapPing, self).__init__() + self.fractional_max_pool = ops.FractionalMaxPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0], overlapping=True) + + def construct(self, x): + return self.fractional_max_pool(x) + + +class NetFractionalMaxPoolGrad(nn.Cell): + def __init__(self): + super(NetFractionalMaxPoolGrad, self).__init__() + self.fractional_max_pool_grad = grad_ops.FractionalMaxPoolGrad() + + def construct(self, orig_input, orig_output, out_backprop, row_pooling_sequence, col_pooling_sequence): + return self.fractional_max_pool_grad(orig_input, orig_output, out_backprop, row_pooling_sequence, + col_pooling_sequence) + + +class NetFractionalMaxPoolGradOverlapping(nn.Cell): + def __init__(self): + super(NetFractionalMaxPoolGradOverlapping, self).__init__() + self.fractional_max_pool_grad = grad_ops.FractionalMaxPoolGrad(overlapping=True) + + def construct(self, orig_input, orig_output, out_backprop, row_pooling_sequence, col_pooling_sequence): + return self.fractional_max_pool_grad(orig_input, orig_output, out_backprop, row_pooling_sequence, + col_pooling_sequence) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_fractionalmaxpool_graph(): + """ + Feature: FractionalMaxPool + Description: Test of input + Expectation: The results are as expected + """ + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + types = [np.float32, np.float64, np.int32, np.int64] + for type_i in types: + x = Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16]).reshape([1, 4, 4, 1]).astype(type_i)) + net = NetFractionalMaxPool() + output = net(x) + output_y = output[0].asnumpy() + output_row_pooling_sequence = output[1].asnumpy() + output_col_pooling_sequence = output[2].asnumpy() + expect_output_y = np.array([[[[6], [8]], [[14], [16]]]]).astype(type_i) + expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + assert np.allclose(output_y, expect_output_y) + assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) + assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) + + net = NetFractionalMaxPoolRealRandom() + output = net(x) + type0 = output[0].asnumpy().dtype + assert type0 == type_i + + net = NetFractionalMaxPoolOverlapPing() + output = net(x) + output_y = output[0].asnumpy() + output_row_pooling_sequence = output[1].asnumpy() + output_col_pooling_sequence = output[2].asnumpy() + expect_output_y = np.array([[[[11], [12]], [[15], [16]]]]).astype(type_i) + expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + assert np.allclose(output_y, expect_output_y) + assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) + assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) + + netgrad = NetFractionalMaxPoolGrad() + out_backprop = Tensor(np.ones([1, 2, 2, 1]).astype(type_i)) + output_grad = netgrad(x, output[0], out_backprop, output[1], output[2]) + output_grad_y = output_grad[0].asnumpy() + expect_output_grad_y = np.array([[[[0], [0], [0], [0]], [[0], [1], [0], [1]], + [[0], [0], [0], [0]], [[0], [1], [0], [1]]]]).astype(type_i) + assert np.allclose(output_grad_y, expect_output_grad_y) + + netgrad = NetFractionalMaxPoolGradOverlapping() + out_backprop = Tensor(np.ones([1, 2, 2, 1]).astype(type_i)) + output_grad = netgrad(x, output[0], out_backprop, output[1], output[2]) + output_grad_y = output_grad[0].asnumpy() + expect_output_grad_y = np.array([[[[0], [0], [0], [0]], [[0], [0], [0], [0]], + [[0], [0], [1], [1]], [[0], [0], [1], [1]]]]).astype(type_i) + assert np.allclose(output_grad_y, expect_output_grad_y) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_fractionalmaxpool_pynative(): + """ + Feature: FractionalMaxPool + Description: Test of input + Expectation: The results are as expected + """ + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + types = [np.float32, np.float64, np.int32, np.int64] + for type_i in types: + x = Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16]).reshape([1, 4, 4, 1]).astype(type_i)) + fractionalmaxpool = ops.FractionalMaxPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0]) + output = fractionalmaxpool(x) + output_y = output[0].asnumpy() + output_row_pooling_sequence = output[1].asnumpy() + output_col_pooling_sequence = output[2].asnumpy() + expect_output_y = np.array([[[[6], [8]], [[14], [16]]]]).astype(type_i) + expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + assert np.allclose(output_y, expect_output_y) + assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) + assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) + + fractionalmaxpool = ops.FractionalMaxPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0], + deterministic=True, pseudo_random=False, seed=5454, seed2=144) + output = fractionalmaxpool(x) + type0 = output[0].asnumpy().dtype + assert type0 == type_i + + fractionalmaxpool = ops.FractionalMaxPool(pooling_ratio=[1.0, 1.5, 1.5, 1.0], overlapping=True) + output = fractionalmaxpool(x) + output_y = output[0].asnumpy() + output_row_pooling_sequence = output[1].asnumpy() + output_col_pooling_sequence = output[2].asnumpy() + expect_output_y = np.array([[[[11], [12]], [[15], [16]]]]).astype(type_i) + expect_output_row_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + expect_output_col_pooling_sequence = np.array([0, 2, 4]).astype(np.int64) + assert np.allclose(output_y, expect_output_y) + assert np.allclose(output_row_pooling_sequence, expect_output_row_pooling_sequence) + assert np.allclose(output_col_pooling_sequence, expect_output_col_pooling_sequence) + + fractionalmaxpoolgrad = grad_ops.FractionalMaxPoolGrad() + out_backprop = Tensor(np.ones([1, 2, 2, 1]).astype(type_i)) + output_grad = fractionalmaxpoolgrad(x, output[0], out_backprop, output[1], output[2]) + output_grad_y = output_grad[0].asnumpy() + expect_output_grad_y = np.array([[[[0], [0], [0], [0]], [[0], [1], [0], [1]], + [[0], [0], [0], [0]], [[0], [1], [0], [1]]]]).astype(type_i) + assert np.allclose(output_grad_y, expect_output_grad_y) + + fractionalmaxpoolgrad = grad_ops.FractionalMaxPoolGrad(overlapping=True) + out_backprop = Tensor(np.ones([1, 2, 2, 1]).astype(type_i)) + output_grad = fractionalmaxpoolgrad(x, output[0], out_backprop, output[1], output[2]) + output_grad_y = output_grad[0].asnumpy() + expect_output_grad_y = np.array([[[[0], [0], [0], [0]], [[0], [0], [0], [0]], + [[0], [0], [1], [1]], [[0], [0], [1], [1]]]]).astype(type_i) + assert np.allclose(output_grad_y, expect_output_grad_y) diff --git a/tests/st/ops/gpu/test_gathernd_op.py b/tests/st/ops/gpu/test_gathernd_op.py index 94813c38604..ec8f7b46e8f 100644 --- a/tests/st/ops/gpu/test_gathernd_op.py +++ b/tests/st/ops/gpu/test_gathernd_op.py @@ -1,201 +1,201 @@ -# Copyright 2020-2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest -from mindspore import Tensor -from mindspore.ops import operations as P -import mindspore.nn as nn -import mindspore.context as context - -class GatherNdNet(nn.Cell): - def __init__(self): - super(GatherNdNet, self).__init__() - self.gathernd = P.GatherNd() - - def construct(self, x, indices): - return self.gathernd(x, indices) - - -def gathernd0(nptype): - x = Tensor(np.arange(3 * 2, dtype=nptype).reshape(3, 2)) - indices = Tensor(np.array([[1, 1], [0, 1]]).astype(np.int32)) - expect = np.array([3, 1]).astype(nptype) - - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - gathernd = GatherNdNet() - output = gathernd(x, indices) - - assert np.array_equal(output.asnumpy(), expect) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_gathernd0_float64(): - gathernd0(np.float64) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_gathernd0_float32(): - gathernd0(np.float32) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_gathernd0_float16(): - gathernd0(np.float16) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_gathernd0_int32(): - gathernd0(np.int32) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_gathernd0_int16(): - gathernd0(np.int16) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_gathernd0_uint8(): - gathernd0(np.uint8) - -def gathernd1(nptype): - x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=nptype).reshape(2, 3, 4, 5)) - indices = Tensor(np.array([[[[[l, k, j, i] for i in [1, 3, 4]] for j in range(4)] - for k in range(3)] for l in range(2)], dtype='i4')) - expect = np.array([[[[1., 3., 4.], - [6., 8., 9.], - [11., 13., 14.], - [16., 18., 19.]], - - [[21., 23., 24.], - [26., 28., 29.], - [31., 33., 34.], - [36., 38., 39.]], - - [[41., 43., 44.], - [46., 48., 49.], - [51., 53., 54.], - [56., 58., 59.]]], - - [[[61., 63., 64.], - [66., 68., 69.], - [71., 73., 74.], - [76., 78., 79.]], - - [[81., 83., 84.], - [86., 88., 89.], - [91., 93., 94.], - [96., 98., 99.]], - - [[101., 103., 104.], - [106., 108., 109.], - [111., 113., 114.], - [116., 118., 119.]]]]).astype(nptype) - - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - gather = GatherNdNet() - output = gather(x, indices) - - assert np.array_equal(output.asnumpy(), expect) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_gathernd1_float64(): - gathernd1(np.float64) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_gathernd1_float32(): - gathernd1(np.float32) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_gathernd1_float16(): - gathernd1(np.float16) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_gathernd1_int32(): - gathernd1(np.int32) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_gathernd1_int16(): - gathernd1(np.int16) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_gathernd1_uint8(): - gathernd1(np.uint8) - -def gathernd2(nptype): - x = Tensor(np.array([[4., 5., 4., 1., 5.], - [4., 9., 5., 6., 4.], - [9., 8., 4., 3., 6.], - [0., 4., 2., 2., 8.], - [1., 8., 6., 2., 8.], - [8., 1., 9., 7., 3.], - [7., 9., 2., 5., 7.], - [9., 8., 6., 8., 5.], - [3., 7., 2., 7., 4.], - [4., 2., 8., 2., 9.]]).astype(np.float16)) - - indices = Tensor(np.array([[0], [1], [3]]).astype(np.int32)) - expect = np.array([[4., 5., 4., 1., 5.], - [4., 9., 5., 6., 4.], - [0., 4., 2., 2., 8.]]) - - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - gathernd = GatherNdNet() - output = gathernd(x, indices) - - assert np.array_equal(output.asnumpy(), expect) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_gathernd2_float64(): - gathernd2(np.float64) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_gathernd2_float32(): - gathernd2(np.float32) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_gathernd2_float16(): - gathernd2(np.float16) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_gathernd2_int32(): - gathernd2(np.int32) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_gathernd2_int16(): - gathernd2(np.int16) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_gathernd2_uint8(): - gathernd2(np.uint8) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_gathernd_bool(): - x = Tensor(np.array([[True, False], [False, False]]).astype(np.bool)) - indices = Tensor(np.array([[0, 0], [0, 1], [1, 0], [1, 1]]).astype(np.int32)) - expect = np.array([True, False, False, False]).astype(np.bool) - - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - gathernd = GatherNdNet() - output = gathernd(x, indices) - - assert np.array_equal(output.asnumpy(), expect) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_gathernd_indices_int64(): - x = Tensor(np.array([[True, False], [False, False]]).astype(np.bool)) - indices = Tensor(np.array([[0, 0], [0, 1], [1, 0], [1, 1]]).astype(np.int64)) - expect = np.array([True, False, False, False]).astype(np.bool) - - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - gathernd = GatherNdNet() - output = gathernd(x, indices) - - assert np.array_equal(output.asnumpy(), expect) +# Copyright 2020-2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest +from mindspore import Tensor +from mindspore.ops import operations as P +import mindspore.nn as nn +import mindspore.context as context + +class GatherNdNet(nn.Cell): + def __init__(self): + super(GatherNdNet, self).__init__() + self.gathernd = P.GatherNd() + + def construct(self, x, indices): + return self.gathernd(x, indices) + + +def gathernd0(nptype): + x = Tensor(np.arange(3 * 2, dtype=nptype).reshape(3, 2)) + indices = Tensor(np.array([[1, 1], [0, 1]]).astype(np.int32)) + expect = np.array([3, 1]).astype(nptype) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + gathernd = GatherNdNet() + output = gathernd(x, indices) + + assert np.array_equal(output.asnumpy(), expect) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_gathernd0_float64(): + gathernd0(np.float64) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_gathernd0_float32(): + gathernd0(np.float32) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_gathernd0_float16(): + gathernd0(np.float16) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_gathernd0_int32(): + gathernd0(np.int32) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_gathernd0_int16(): + gathernd0(np.int16) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_gathernd0_uint8(): + gathernd0(np.uint8) + +def gathernd1(nptype): + x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=nptype).reshape(2, 3, 4, 5)) + indices = Tensor(np.array([[[[[l, k, j, i] for i in [1, 3, 4]] for j in range(4)] + for k in range(3)] for l in range(2)], dtype='i4')) + expect = np.array([[[[1., 3., 4.], + [6., 8., 9.], + [11., 13., 14.], + [16., 18., 19.]], + + [[21., 23., 24.], + [26., 28., 29.], + [31., 33., 34.], + [36., 38., 39.]], + + [[41., 43., 44.], + [46., 48., 49.], + [51., 53., 54.], + [56., 58., 59.]]], + + [[[61., 63., 64.], + [66., 68., 69.], + [71., 73., 74.], + [76., 78., 79.]], + + [[81., 83., 84.], + [86., 88., 89.], + [91., 93., 94.], + [96., 98., 99.]], + + [[101., 103., 104.], + [106., 108., 109.], + [111., 113., 114.], + [116., 118., 119.]]]]).astype(nptype) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + gather = GatherNdNet() + output = gather(x, indices) + + assert np.array_equal(output.asnumpy(), expect) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_gathernd1_float64(): + gathernd1(np.float64) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_gathernd1_float32(): + gathernd1(np.float32) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_gathernd1_float16(): + gathernd1(np.float16) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_gathernd1_int32(): + gathernd1(np.int32) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_gathernd1_int16(): + gathernd1(np.int16) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_gathernd1_uint8(): + gathernd1(np.uint8) + +def gathernd2(nptype): + x = Tensor(np.array([[4., 5., 4., 1., 5.], + [4., 9., 5., 6., 4.], + [9., 8., 4., 3., 6.], + [0., 4., 2., 2., 8.], + [1., 8., 6., 2., 8.], + [8., 1., 9., 7., 3.], + [7., 9., 2., 5., 7.], + [9., 8., 6., 8., 5.], + [3., 7., 2., 7., 4.], + [4., 2., 8., 2., 9.]]).astype(np.float16)) + + indices = Tensor(np.array([[0], [1], [3]]).astype(np.int32)) + expect = np.array([[4., 5., 4., 1., 5.], + [4., 9., 5., 6., 4.], + [0., 4., 2., 2., 8.]]) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + gathernd = GatherNdNet() + output = gathernd(x, indices) + + assert np.array_equal(output.asnumpy(), expect) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_gathernd2_float64(): + gathernd2(np.float64) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_gathernd2_float32(): + gathernd2(np.float32) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_gathernd2_float16(): + gathernd2(np.float16) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_gathernd2_int32(): + gathernd2(np.int32) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_gathernd2_int16(): + gathernd2(np.int16) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_gathernd2_uint8(): + gathernd2(np.uint8) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_gathernd_bool(): + x = Tensor(np.array([[True, False], [False, False]]).astype(np.bool)) + indices = Tensor(np.array([[0, 0], [0, 1], [1, 0], [1, 1]]).astype(np.int32)) + expect = np.array([True, False, False, False]).astype(np.bool) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + gathernd = GatherNdNet() + output = gathernd(x, indices) + + assert np.array_equal(output.asnumpy(), expect) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_gathernd_indices_int64(): + x = Tensor(np.array([[True, False], [False, False]]).astype(np.bool)) + indices = Tensor(np.array([[0, 0], [0, 1], [1, 0], [1, 1]]).astype(np.int64)) + expect = np.array([True, False, False, False]).astype(np.bool) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + gathernd = GatherNdNet() + output = gathernd(x, indices) + + assert np.array_equal(output.asnumpy(), expect) diff --git a/tests/st/ops/gpu/test_ge_op.py b/tests/st/ops/gpu/test_ge_op.py index bd630238309..aa33c4cce02 100644 --- a/tests/st/ops/gpu/test_ge_op.py +++ b/tests/st/ops/gpu/test_ge_op.py @@ -1,141 +1,141 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor, ops -from mindspore.ops import operations as P - - -class OpNetWrapper(nn.Cell): - def __init__(self, op): - super(OpNetWrapper, self).__init__() - self.op = op - - def construct(self, *inputs): - return self.op(*inputs) - - -class GreaterEqualFunc(nn.Cell): - def construct(self, *inputs): - return ops.ge(*inputs) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -@pytest.mark.parametrize('dtype', [np.int8, np.int16, np.int32, np.int64, np.float16, np.float32, np.float64]) -def test_greater_equal_op_dtype_1(mode, dtype): - """ - Feature: Test GreaterEqual op. - Description: Test GreaterEqual with dtype input. - Expectation: The result match to the expect value. - """ - context.set_context(mode=mode, device_target="GPU") - - op = P.GreaterEqual() - op_wrapper = OpNetWrapper(op) - - input_x = Tensor(np.array([1, -2, 3]).astype(dtype)) - input_y = Tensor(np.array([3, -2, 1]).astype(dtype)) - outputs = op_wrapper(input_x, input_y) - - assert outputs.shape == (3,) - assert np.allclose(outputs.asnumpy(), [False, True, True]) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -@pytest.mark.parametrize('dtype', [np.uint8, np.uint16, np.uint32, np.uint64]) -def test_greater_equal_op_dtype_2(mode, dtype): - """ - Feature: Test GreaterEqual op. - Description: Test GreaterEqual with dtype input. - Expectation: The result match to the expect value. - """ - context.set_context(mode=mode, device_target="GPU") - - op = P.GreaterEqual() - op_wrapper = OpNetWrapper(op) - - input_x = Tensor(np.array([1, 2, 3]).astype(dtype)) - input_y = Tensor(np.array([3, 2, 1]).astype(dtype)) - outputs = op_wrapper(input_x, input_y) - - assert outputs.shape == (3,) - assert np.allclose(outputs.asnumpy(), [False, True, True]) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -@pytest.mark.parametrize('dtype', [np.bool]) -def test_greater_equal_op_dtype_3(mode, dtype): - """ - Feature: Test GreaterEqual op. - Description: Test GreaterEqual with dtype input. - Expectation: The result match to the expect value. - """ - context.set_context(mode=mode, device_target="GPU") - - op = P.GreaterEqual() - op_wrapper = OpNetWrapper(op) - - input_x = Tensor(np.array([False, True, True]).astype(dtype)) - input_y = Tensor(np.array([True, True, False]).astype(dtype)) - outputs = op_wrapper(input_x, input_y) - - assert outputs.shape == (3,) - assert np.allclose(outputs.asnumpy(), [False, True, True]) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -def test_greater_equal_op_functional(mode): - """ - Feature: Test GreaterEqual op. - Description: Test GreaterEqual with functional. - Expectation: The result match to the expect value. - """ - context.set_context(mode=mode, device_target="GPU") - - op_wrapper = GreaterEqualFunc() - - input_x = Tensor(np.array([1, 2, 3]).astype(np.float32)) - input_y = Tensor(np.array([3, 2, 1]).astype(np.float32)) - outputs = op_wrapper(input_x, input_y) - - assert outputs.shape == (3,) - assert np.allclose(outputs.asnumpy(), [False, True, True]) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -def test_greater_equal_op_tensor(mode): - """ - Feature: Test GreaterEqual op. - Description: Test GreaterEqual with Tensor. - Expectation: The result match to the expect value. - """ - context.set_context(mode=mode, device_target="GPU") - - input_x = Tensor(np.array([1, 2, 3]).astype(np.float32)) - input_y = Tensor(np.array([3, 2, 1]).astype(np.float32)) - outputs = input_x.ge(input_y) - - assert outputs.shape == (3,) - assert np.allclose(outputs.asnumpy(), [False, True, True]) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor, ops +from mindspore.ops import operations as P + + +class OpNetWrapper(nn.Cell): + def __init__(self, op): + super(OpNetWrapper, self).__init__() + self.op = op + + def construct(self, *inputs): + return self.op(*inputs) + + +class GreaterEqualFunc(nn.Cell): + def construct(self, *inputs): + return ops.ge(*inputs) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('dtype', [np.int8, np.int16, np.int32, np.int64, np.float16, np.float32, np.float64]) +def test_greater_equal_op_dtype_1(mode, dtype): + """ + Feature: Test GreaterEqual op. + Description: Test GreaterEqual with dtype input. + Expectation: The result match to the expect value. + """ + context.set_context(mode=mode, device_target="GPU") + + op = P.GreaterEqual() + op_wrapper = OpNetWrapper(op) + + input_x = Tensor(np.array([1, -2, 3]).astype(dtype)) + input_y = Tensor(np.array([3, -2, 1]).astype(dtype)) + outputs = op_wrapper(input_x, input_y) + + assert outputs.shape == (3,) + assert np.allclose(outputs.asnumpy(), [False, True, True]) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('dtype', [np.uint8, np.uint16, np.uint32, np.uint64]) +def test_greater_equal_op_dtype_2(mode, dtype): + """ + Feature: Test GreaterEqual op. + Description: Test GreaterEqual with dtype input. + Expectation: The result match to the expect value. + """ + context.set_context(mode=mode, device_target="GPU") + + op = P.GreaterEqual() + op_wrapper = OpNetWrapper(op) + + input_x = Tensor(np.array([1, 2, 3]).astype(dtype)) + input_y = Tensor(np.array([3, 2, 1]).astype(dtype)) + outputs = op_wrapper(input_x, input_y) + + assert outputs.shape == (3,) + assert np.allclose(outputs.asnumpy(), [False, True, True]) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('dtype', [np.bool]) +def test_greater_equal_op_dtype_3(mode, dtype): + """ + Feature: Test GreaterEqual op. + Description: Test GreaterEqual with dtype input. + Expectation: The result match to the expect value. + """ + context.set_context(mode=mode, device_target="GPU") + + op = P.GreaterEqual() + op_wrapper = OpNetWrapper(op) + + input_x = Tensor(np.array([False, True, True]).astype(dtype)) + input_y = Tensor(np.array([True, True, False]).astype(dtype)) + outputs = op_wrapper(input_x, input_y) + + assert outputs.shape == (3,) + assert np.allclose(outputs.asnumpy(), [False, True, True]) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_greater_equal_op_functional(mode): + """ + Feature: Test GreaterEqual op. + Description: Test GreaterEqual with functional. + Expectation: The result match to the expect value. + """ + context.set_context(mode=mode, device_target="GPU") + + op_wrapper = GreaterEqualFunc() + + input_x = Tensor(np.array([1, 2, 3]).astype(np.float32)) + input_y = Tensor(np.array([3, 2, 1]).astype(np.float32)) + outputs = op_wrapper(input_x, input_y) + + assert outputs.shape == (3,) + assert np.allclose(outputs.asnumpy(), [False, True, True]) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_greater_equal_op_tensor(mode): + """ + Feature: Test GreaterEqual op. + Description: Test GreaterEqual with Tensor. + Expectation: The result match to the expect value. + """ + context.set_context(mode=mode, device_target="GPU") + + input_x = Tensor(np.array([1, 2, 3]).astype(np.float32)) + input_y = Tensor(np.array([3, 2, 1]).astype(np.float32)) + outputs = input_x.ge(input_y) + + assert outputs.shape == (3,) + assert np.allclose(outputs.asnumpy(), [False, True, True]) diff --git a/tests/st/ops/gpu/test_gpu_convert_to_dynamic_shape_op.py b/tests/st/ops/gpu/test_gpu_convert_to_dynamic_shape_op.py index 5378114da47..bc2566a8a96 100644 --- a/tests/st/ops/gpu/test_gpu_convert_to_dynamic_shape_op.py +++ b/tests/st/ops/gpu/test_gpu_convert_to_dynamic_shape_op.py @@ -1,139 +1,139 @@ -# Copyright 2020-2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -from mindspore import Tensor -from mindspore.ops.operations import _inner_ops as inner -import mindspore.nn as nn -import mindspore.context as context - -# test to make sure this op actually generates a dynamically shaped output -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_gpu_convert_to_dyanamic_shape_confirm_dynamic(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - - class AssertDynamicShapeNet(nn.Cell): - def __init__(self): - super(AssertDynamicShapeNet, self).__init__() - self.gpu_convert_to_dynamic_shape = inner.GpuConvertToDynamicShape() - self.error_on_dynamic_shape_input = inner.ErrorOnDynamicShapeInput() - - def construct(self, x): - output = self.gpu_convert_to_dynamic_shape(x) - self.error_on_dynamic_shape_input(output) - return output - - assert_dynamic_shape_net = AssertDynamicShapeNet() - x = Tensor(np.array([0, 0, 0, 0]).astype(np.float32)) - - with pytest.raises(ValueError) as info: - assert_dynamic_shape_net(x) - assert "Input is dynamically shaped" in str(info.value) - -def gpu_convert_to_dynamic_shape(x): - class GpuConvertToDynamicShapeNet(nn.Cell): - def __init__(self): - super(GpuConvertToDynamicShapeNet, self).__init__() - self.gpu_convert_to_dynamic_shape = inner.GpuConvertToDynamicShape() - - def construct(self, x): - return self.gpu_convert_to_dynamic_shape(x) - - gpu_convert_to_dynamic_shape_net = GpuConvertToDynamicShapeNet() - return gpu_convert_to_dynamic_shape_net(Tensor(x)).asnumpy() - -def gpu_convert_to_dynamic_shape_float(dtype): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - - np.random.seed(0) - finfo = np.finfo(dtype) - - # np.random.uniform will overflow if we use min/max for float64, so we use - # the finfo for float32, but still test the operator with float64 input. - if dtype == np.float64: - finfo = np.finfo(np.float32) - - float_min = finfo.min - float_max = finfo.max - x = np.random.uniform(low=float_min, high=float_max, size=12).astype(dtype) - ms_out = gpu_convert_to_dynamic_shape(x) - np.testing.assert_array_equal(x, ms_out) - -def gpu_convert_to_dynamic_shape_int(dtype): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - - np.random.seed(0) - iinfo = np.iinfo(dtype) - int_min = iinfo.min - int_max = iinfo.max - x = np.random.uniform(low=int_min, high=int_max, size=12).astype(dtype) - ms_out = gpu_convert_to_dynamic_shape(x) - np.testing.assert_array_equal(x, ms_out) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_gpu_convert_to_dynamic_shape_bool(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - - np.random.seed(0) - x = np.random.choice([False, True], 12) - ms_out = gpu_convert_to_dynamic_shape(x) - np.testing.assert_array_equal(x, ms_out) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_gpu_convert_to_dynamic_shape_float16(): - gpu_convert_to_dynamic_shape_float(np.float16) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_gpu_convert_to_dynamic_shape_float32(): - gpu_convert_to_dynamic_shape_float(np.float32) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_gpu_convert_to_dynamic_shape_float64(): - gpu_convert_to_dynamic_shape_float(np.float64) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_gpu_convert_to_dynamic_shape_int8(): - gpu_convert_to_dynamic_shape_int(np.int8) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_gpu_convert_to_dynamic_shape_int16(): - gpu_convert_to_dynamic_shape_int(np.int16) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_gpu_convert_to_dynamic_shape_int32(): - gpu_convert_to_dynamic_shape_int(np.int32) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_gpu_convert_to_dynamic_shape_int64(): - gpu_convert_to_dynamic_shape_int(np.int64) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_gpu_convert_to_dynamic_shape_uint8(): - gpu_convert_to_dynamic_shape_int(np.uint8) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_gpu_convert_to_dynamic_shape_uint16(): - gpu_convert_to_dynamic_shape_int(np.uint16) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_gpu_convert_to_dynamic_shape_uint32(): - gpu_convert_to_dynamic_shape_int(np.uint32) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_gpu_convert_to_dynamic_shape_uint64(): - gpu_convert_to_dynamic_shape_int(np.uint64) +# Copyright 2020-2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +from mindspore import Tensor +from mindspore.ops.operations import _inner_ops as inner +import mindspore.nn as nn +import mindspore.context as context + +# test to make sure this op actually generates a dynamically shaped output +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_gpu_convert_to_dyanamic_shape_confirm_dynamic(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + class AssertDynamicShapeNet(nn.Cell): + def __init__(self): + super(AssertDynamicShapeNet, self).__init__() + self.gpu_convert_to_dynamic_shape = inner.GpuConvertToDynamicShape() + self.error_on_dynamic_shape_input = inner.ErrorOnDynamicShapeInput() + + def construct(self, x): + output = self.gpu_convert_to_dynamic_shape(x) + self.error_on_dynamic_shape_input(output) + return output + + assert_dynamic_shape_net = AssertDynamicShapeNet() + x = Tensor(np.array([0, 0, 0, 0]).astype(np.float32)) + + with pytest.raises(ValueError) as info: + assert_dynamic_shape_net(x) + assert "Input is dynamically shaped" in str(info.value) + +def gpu_convert_to_dynamic_shape(x): + class GpuConvertToDynamicShapeNet(nn.Cell): + def __init__(self): + super(GpuConvertToDynamicShapeNet, self).__init__() + self.gpu_convert_to_dynamic_shape = inner.GpuConvertToDynamicShape() + + def construct(self, x): + return self.gpu_convert_to_dynamic_shape(x) + + gpu_convert_to_dynamic_shape_net = GpuConvertToDynamicShapeNet() + return gpu_convert_to_dynamic_shape_net(Tensor(x)).asnumpy() + +def gpu_convert_to_dynamic_shape_float(dtype): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + np.random.seed(0) + finfo = np.finfo(dtype) + + # np.random.uniform will overflow if we use min/max for float64, so we use + # the finfo for float32, but still test the operator with float64 input. + if dtype == np.float64: + finfo = np.finfo(np.float32) + + float_min = finfo.min + float_max = finfo.max + x = np.random.uniform(low=float_min, high=float_max, size=12).astype(dtype) + ms_out = gpu_convert_to_dynamic_shape(x) + np.testing.assert_array_equal(x, ms_out) + +def gpu_convert_to_dynamic_shape_int(dtype): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + np.random.seed(0) + iinfo = np.iinfo(dtype) + int_min = iinfo.min + int_max = iinfo.max + x = np.random.uniform(low=int_min, high=int_max, size=12).astype(dtype) + ms_out = gpu_convert_to_dynamic_shape(x) + np.testing.assert_array_equal(x, ms_out) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_gpu_convert_to_dynamic_shape_bool(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + np.random.seed(0) + x = np.random.choice([False, True], 12) + ms_out = gpu_convert_to_dynamic_shape(x) + np.testing.assert_array_equal(x, ms_out) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_gpu_convert_to_dynamic_shape_float16(): + gpu_convert_to_dynamic_shape_float(np.float16) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_gpu_convert_to_dynamic_shape_float32(): + gpu_convert_to_dynamic_shape_float(np.float32) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_gpu_convert_to_dynamic_shape_float64(): + gpu_convert_to_dynamic_shape_float(np.float64) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_gpu_convert_to_dynamic_shape_int8(): + gpu_convert_to_dynamic_shape_int(np.int8) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_gpu_convert_to_dynamic_shape_int16(): + gpu_convert_to_dynamic_shape_int(np.int16) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_gpu_convert_to_dynamic_shape_int32(): + gpu_convert_to_dynamic_shape_int(np.int32) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_gpu_convert_to_dynamic_shape_int64(): + gpu_convert_to_dynamic_shape_int(np.int64) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_gpu_convert_to_dynamic_shape_uint8(): + gpu_convert_to_dynamic_shape_int(np.uint8) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_gpu_convert_to_dynamic_shape_uint16(): + gpu_convert_to_dynamic_shape_int(np.uint16) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_gpu_convert_to_dynamic_shape_uint32(): + gpu_convert_to_dynamic_shape_int(np.uint32) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_gpu_convert_to_dynamic_shape_uint64(): + gpu_convert_to_dynamic_shape_int(np.uint64) diff --git a/tests/st/ops/gpu/test_gt_op.py b/tests/st/ops/gpu/test_gt_op.py index 94f683a2450..aa6ae4c975b 100644 --- a/tests/st/ops/gpu/test_gt_op.py +++ b/tests/st/ops/gpu/test_gt_op.py @@ -1,141 +1,141 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor, ops -from mindspore.ops import operations as P - - -class OpNetWrapper(nn.Cell): - def __init__(self, op): - super(OpNetWrapper, self).__init__() - self.op = op - - def construct(self, *inputs): - return self.op(*inputs) - - -class GreaterFunc(nn.Cell): - def construct(self, *inputs): - return ops.gt(*inputs) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -@pytest.mark.parametrize('dtype', [np.int8, np.int16, np.int32, np.int64, np.float16, np.float32, np.float64]) -def test_greater_op_dtype_1(mode, dtype): - """ - Feature: Test Greater op. - Description: Test Greater with dtype input. - Expectation: The result match to the expect value. - """ - context.set_context(mode=mode, device_target="GPU") - - op = P.Greater() - op_wrapper = OpNetWrapper(op) - - input_x = Tensor(np.array([1, -2, 3]).astype(dtype)) - input_y = Tensor(np.array([3, 2, 1]).astype(dtype)) - outputs = op_wrapper(input_x, input_y) - - assert outputs.shape == (3,) - assert np.allclose(outputs.asnumpy(), [False, False, True]) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -@pytest.mark.parametrize('dtype', [np.uint8, np.uint16, np.uint32, np.uint64]) -def test_greater_op_dtype_2(mode, dtype): - """ - Feature: Test Greater op. - Description: Test Greater with dtype input. - Expectation: The result match to the expect value. - """ - context.set_context(mode=mode, device_target="GPU") - - op = P.Greater() - op_wrapper = OpNetWrapper(op) - - input_x = Tensor(np.array([1, 0, 3]).astype(dtype)) - input_y = Tensor(np.array([3, 2, 1]).astype(dtype)) - outputs = op_wrapper(input_x, input_y) - - assert outputs.shape == (3,) - assert np.allclose(outputs.asnumpy(), [False, False, True]) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -@pytest.mark.parametrize('dtype', [np.bool]) -def test_greater_op_dtype_3(mode, dtype): - """ - Feature: Test Greater op. - Description: Test Greater with dtype input. - Expectation: The result match to the expect value. - """ - context.set_context(mode=mode, device_target="GPU") - - op = P.Greater() - op_wrapper = OpNetWrapper(op) - - input_x = Tensor(np.array([False, False, True]).astype(dtype)) - input_y = Tensor(np.array([True, True, False]).astype(dtype)) - outputs = op_wrapper(input_x, input_y) - - assert outputs.shape == (3,) - assert np.allclose(outputs.asnumpy(), [False, False, True]) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -def test_greater_op_functional(mode): - """ - Feature: Test Greater op. - Description: Test Greater with with functional. - Expectation: The result match to the expect value. - """ - context.set_context(mode=mode, device_target="GPU") - - op_wrapper = GreaterFunc() - - input_x = Tensor(np.array([1, -2, 3]).astype(np.float32)) - input_y = Tensor(np.array([3, 2, 1]).astype(np.float32)) - outputs = op_wrapper(input_x, input_y) - - assert outputs.shape == (3,) - assert np.allclose(outputs.asnumpy(), [False, False, True]) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -def test_greater_op_tensor(mode): - """ - Feature: Test Greater op. - Description: Test Greater with Tensor. - Expectation: The result match to the expect value. - """ - context.set_context(mode=mode, device_target="GPU") - - input_x = Tensor(np.array([1, -2, 3]).astype(np.float32)) - input_y = Tensor(np.array([3, 2, 1]).astype(np.float32)) - outputs = input_x.gt(input_y) - - assert outputs.shape == (3,) - assert np.allclose(outputs.asnumpy(), [False, False, True]) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor, ops +from mindspore.ops import operations as P + + +class OpNetWrapper(nn.Cell): + def __init__(self, op): + super(OpNetWrapper, self).__init__() + self.op = op + + def construct(self, *inputs): + return self.op(*inputs) + + +class GreaterFunc(nn.Cell): + def construct(self, *inputs): + return ops.gt(*inputs) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('dtype', [np.int8, np.int16, np.int32, np.int64, np.float16, np.float32, np.float64]) +def test_greater_op_dtype_1(mode, dtype): + """ + Feature: Test Greater op. + Description: Test Greater with dtype input. + Expectation: The result match to the expect value. + """ + context.set_context(mode=mode, device_target="GPU") + + op = P.Greater() + op_wrapper = OpNetWrapper(op) + + input_x = Tensor(np.array([1, -2, 3]).astype(dtype)) + input_y = Tensor(np.array([3, 2, 1]).astype(dtype)) + outputs = op_wrapper(input_x, input_y) + + assert outputs.shape == (3,) + assert np.allclose(outputs.asnumpy(), [False, False, True]) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('dtype', [np.uint8, np.uint16, np.uint32, np.uint64]) +def test_greater_op_dtype_2(mode, dtype): + """ + Feature: Test Greater op. + Description: Test Greater with dtype input. + Expectation: The result match to the expect value. + """ + context.set_context(mode=mode, device_target="GPU") + + op = P.Greater() + op_wrapper = OpNetWrapper(op) + + input_x = Tensor(np.array([1, 0, 3]).astype(dtype)) + input_y = Tensor(np.array([3, 2, 1]).astype(dtype)) + outputs = op_wrapper(input_x, input_y) + + assert outputs.shape == (3,) + assert np.allclose(outputs.asnumpy(), [False, False, True]) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('dtype', [np.bool]) +def test_greater_op_dtype_3(mode, dtype): + """ + Feature: Test Greater op. + Description: Test Greater with dtype input. + Expectation: The result match to the expect value. + """ + context.set_context(mode=mode, device_target="GPU") + + op = P.Greater() + op_wrapper = OpNetWrapper(op) + + input_x = Tensor(np.array([False, False, True]).astype(dtype)) + input_y = Tensor(np.array([True, True, False]).astype(dtype)) + outputs = op_wrapper(input_x, input_y) + + assert outputs.shape == (3,) + assert np.allclose(outputs.asnumpy(), [False, False, True]) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_greater_op_functional(mode): + """ + Feature: Test Greater op. + Description: Test Greater with with functional. + Expectation: The result match to the expect value. + """ + context.set_context(mode=mode, device_target="GPU") + + op_wrapper = GreaterFunc() + + input_x = Tensor(np.array([1, -2, 3]).astype(np.float32)) + input_y = Tensor(np.array([3, 2, 1]).astype(np.float32)) + outputs = op_wrapper(input_x, input_y) + + assert outputs.shape == (3,) + assert np.allclose(outputs.asnumpy(), [False, False, True]) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_greater_op_tensor(mode): + """ + Feature: Test Greater op. + Description: Test Greater with Tensor. + Expectation: The result match to the expect value. + """ + context.set_context(mode=mode, device_target="GPU") + + input_x = Tensor(np.array([1, -2, 3]).astype(np.float32)) + input_y = Tensor(np.array([3, 2, 1]).astype(np.float32)) + outputs = input_x.gt(input_y) + + assert outputs.shape == (3,) + assert np.allclose(outputs.asnumpy(), [False, False, True]) diff --git a/tests/st/ops/gpu/test_hamming_window_op.py b/tests/st/ops/gpu/test_hamming_window_op.py index 8a4ee94597f..d45957162f1 100644 --- a/tests/st/ops/gpu/test_hamming_window_op.py +++ b/tests/st/ops/gpu/test_hamming_window_op.py @@ -1,75 +1,75 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark -import numpy as np -import torch -import pytest -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.common import dtype as mstype -from mindspore.common.api import jit -import mindspore.ops.operations.array_ops as P2 - - -class HammingWindowNet(nn.Cell): - def __init__(self, periodic=True, alpha=0.54, beta=0.46, dtype=mstype.Int): - super(HammingWindowNet, self).__init__() - self.hammingwindow = P2.HammingWindow(periodic=periodic, alpha=alpha, beta=beta, dtype=dtype) - - @jit - def construct(self, input_x): - return self.hammingwindow(input_x) - - - -def hamming_window(periodic, loss): - context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - input_x_np = np.array([10]).astype(np.int32) - input_x_ms = Tensor(input_x_np) - hamming_window_net = HammingWindowNet(periodic, 0.54, 0.46, mstype.float32) - hamming_window_output = hamming_window_net(input_x_ms) - hamming_window_expect = torch.hamming_window(10, periodic=periodic) - assert np.allclose(hamming_window_output.asnumpy(), hamming_window_expect.numpy().astype(np.float32), loss, loss) - - -def hamming_window_pynative(periodic, loss): - context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') - input_x_np = np.array([10]).astype(np.int32) - input_x_ms = Tensor(input_x_np) - hamming_window_net = HammingWindowNet(periodic, 0.54, 0.46, mstype.float32) - hamming_window_output = hamming_window_net(input_x_ms) - hamming_window_expect = torch.hamming_window(10, periodic=periodic) - assert np.allclose(hamming_window_output.asnumpy(), hamming_window_expect.numpy().astype(np.float32), loss, loss) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_hamming_window_graph_int32_true_float32(): - """ - Feature: ALL To ALL - Description: test cases for HammingWindow - Expectation: the result match to torch - """ - hamming_window(periodic=True, loss=1.0e-4) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_hamming_window_pynative_int64_false_float64(): - """ - Feature: ALL To ALL - Description: test cases for HammingWindow - Expectation: the result match to torch - """ - hamming_window_pynative(periodic=False, loss=1.0e-4) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark +import numpy as np +import torch +import pytest +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common import dtype as mstype +from mindspore.common.api import jit +import mindspore.ops.operations.array_ops as P2 + + +class HammingWindowNet(nn.Cell): + def __init__(self, periodic=True, alpha=0.54, beta=0.46, dtype=mstype.Int): + super(HammingWindowNet, self).__init__() + self.hammingwindow = P2.HammingWindow(periodic=periodic, alpha=alpha, beta=beta, dtype=dtype) + + @jit + def construct(self, input_x): + return self.hammingwindow(input_x) + + + +def hamming_window(periodic, loss): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + input_x_np = np.array([10]).astype(np.int32) + input_x_ms = Tensor(input_x_np) + hamming_window_net = HammingWindowNet(periodic, 0.54, 0.46, mstype.float32) + hamming_window_output = hamming_window_net(input_x_ms) + hamming_window_expect = torch.hamming_window(10, periodic=periodic) + assert np.allclose(hamming_window_output.asnumpy(), hamming_window_expect.numpy().astype(np.float32), loss, loss) + + +def hamming_window_pynative(periodic, loss): + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + input_x_np = np.array([10]).astype(np.int32) + input_x_ms = Tensor(input_x_np) + hamming_window_net = HammingWindowNet(periodic, 0.54, 0.46, mstype.float32) + hamming_window_output = hamming_window_net(input_x_ms) + hamming_window_expect = torch.hamming_window(10, periodic=periodic) + assert np.allclose(hamming_window_output.asnumpy(), hamming_window_expect.numpy().astype(np.float32), loss, loss) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_hamming_window_graph_int32_true_float32(): + """ + Feature: ALL To ALL + Description: test cases for HammingWindow + Expectation: the result match to torch + """ + hamming_window(periodic=True, loss=1.0e-4) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_hamming_window_pynative_int64_false_float64(): + """ + Feature: ALL To ALL + Description: test cases for HammingWindow + Expectation: the result match to torch + """ + hamming_window_pynative(periodic=False, loss=1.0e-4) diff --git a/tests/st/ops/gpu/test_kl_div_op.py b/tests/st/ops/gpu/test_kl_div_op.py index 2ee0a5d64ff..9474d3cd160 100644 --- a/tests/st/ops/gpu/test_kl_div_op.py +++ b/tests/st/ops/gpu/test_kl_div_op.py @@ -1,109 +1,109 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.ops import composite as C -from mindspore.ops import operations as P - -context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - - -class Net(nn.Cell): - def __init__(self, reduction="none"): - super(Net, self).__init__() - self.KLDivLoss = P.KLDivLoss("none") - - def construct(self, x, y): - return self.KLDivLoss(x, y) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_kl_div_loss(): - """ - Feature: Test KLDivLoss. - Description: Test KLDivLoss op with float inputs. - Expectation: The result match to expect. - """ - np.random.seed(42) - prediction = np.random.rand(20).astype(np.float32) - target = np.random.rand(20).astype(np.float32) - net = Net() - loss = net(Tensor(prediction), Tensor(target)) - expect = [-0.5297444, -0.40738472, -0.5733339, -0.58720195, -0.42922008, -0.31237593, - -0.3332863, -0.78742254, -0.6662671, -0.17546377, -0.31526336, -0.46702948, - -0.23191005, -0.2512708, -0.20934652, -0.32021108, -0.45477402, -0.278453, - -0.5551879, -0.48938933] - assert np.allclose(loss.asnumpy(), expect) - - -class Grad(nn.Cell): - def __init__(self, network): - super(Grad, self).__init__() - self.grad = C.GradOperation(get_all=True, sens_param=True) - self.network = network - - def construct(self, x1, x2, sens): - gout = self.grad(self.network)(x1, x2, sens) - return gout - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_kl_div_loss_grad(): - """ - Feature: Test KLDivLossGrad. - Description: Test KLDivLossGrad op with float inputs. - Expectation: The result match to expect. - """ - np.random.seed(42) - prediction = np.random.rand(20).astype(np.float32) - target = np.random.rand(20).astype(np.float32) - sens = np.random.rand(20).astype(np.float32) - grad = Grad(Net()) - dx = grad(Tensor(prediction), Tensor(target), Tensor(sens)) - - dx1_expect = [-0.07466945, -0.06907414, -0.01004642, -0.3331403, -0.11802178, -0.52019656, - -0.06224053, -0.2674369, -0.32387912, -0.00858657, -0.58906615, -0.13217884, - -0.06111591, -0.8490888, -0.57735133, -0.7452407, -0.02695603, -0.01914206, - -0.03094601, -0.14319494] - - assert np.allclose(dx[0].asnumpy(), dx1_expect) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_kl_div_loss_grad_float64(): - """ - Feature: Test KLDivLossGrad. - Description: Test KLDivLossGrad op with float inputs. - Expectation: The result match to expect. - """ - np.random.seed(42) - prediction = np.random.rand(20).astype(np.float64) - target = np.random.rand(20).astype(np.float64) - sens = np.random.rand(20).astype(np.float64) - grad = Grad(Net()) - dx = grad(Tensor(prediction), Tensor(target), Tensor(sens)) - - dx1_expect = [-0.07466945, -0.06907414, -0.01004642, -0.3331403, -0.11802178, -0.52019656, - -0.06224053, -0.2674369, -0.32387912, -0.00858657, -0.58906615, -0.13217884, - -0.06111591, -0.8490888, -0.57735133, -0.7452407, -0.02695603, -0.01914206, - -0.03094601, -0.14319494] - - assert np.allclose(dx[0].asnumpy(), dx1_expect) +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import composite as C +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + +class Net(nn.Cell): + def __init__(self, reduction="none"): + super(Net, self).__init__() + self.KLDivLoss = P.KLDivLoss("none") + + def construct(self, x, y): + return self.KLDivLoss(x, y) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_kl_div_loss(): + """ + Feature: Test KLDivLoss. + Description: Test KLDivLoss op with float inputs. + Expectation: The result match to expect. + """ + np.random.seed(42) + prediction = np.random.rand(20).astype(np.float32) + target = np.random.rand(20).astype(np.float32) + net = Net() + loss = net(Tensor(prediction), Tensor(target)) + expect = [-0.5297444, -0.40738472, -0.5733339, -0.58720195, -0.42922008, -0.31237593, + -0.3332863, -0.78742254, -0.6662671, -0.17546377, -0.31526336, -0.46702948, + -0.23191005, -0.2512708, -0.20934652, -0.32021108, -0.45477402, -0.278453, + -0.5551879, -0.48938933] + assert np.allclose(loss.asnumpy(), expect) + + +class Grad(nn.Cell): + def __init__(self, network): + super(Grad, self).__init__() + self.grad = C.GradOperation(get_all=True, sens_param=True) + self.network = network + + def construct(self, x1, x2, sens): + gout = self.grad(self.network)(x1, x2, sens) + return gout + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_kl_div_loss_grad(): + """ + Feature: Test KLDivLossGrad. + Description: Test KLDivLossGrad op with float inputs. + Expectation: The result match to expect. + """ + np.random.seed(42) + prediction = np.random.rand(20).astype(np.float32) + target = np.random.rand(20).astype(np.float32) + sens = np.random.rand(20).astype(np.float32) + grad = Grad(Net()) + dx = grad(Tensor(prediction), Tensor(target), Tensor(sens)) + + dx1_expect = [-0.07466945, -0.06907414, -0.01004642, -0.3331403, -0.11802178, -0.52019656, + -0.06224053, -0.2674369, -0.32387912, -0.00858657, -0.58906615, -0.13217884, + -0.06111591, -0.8490888, -0.57735133, -0.7452407, -0.02695603, -0.01914206, + -0.03094601, -0.14319494] + + assert np.allclose(dx[0].asnumpy(), dx1_expect) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_kl_div_loss_grad_float64(): + """ + Feature: Test KLDivLossGrad. + Description: Test KLDivLossGrad op with float inputs. + Expectation: The result match to expect. + """ + np.random.seed(42) + prediction = np.random.rand(20).astype(np.float64) + target = np.random.rand(20).astype(np.float64) + sens = np.random.rand(20).astype(np.float64) + grad = Grad(Net()) + dx = grad(Tensor(prediction), Tensor(target), Tensor(sens)) + + dx1_expect = [-0.07466945, -0.06907414, -0.01004642, -0.3331403, -0.11802178, -0.52019656, + -0.06224053, -0.2674369, -0.32387912, -0.00858657, -0.58906615, -0.13217884, + -0.06111591, -0.8490888, -0.57735133, -0.7452407, -0.02695603, -0.01914206, + -0.03094601, -0.14319494] + + assert np.allclose(dx[0].asnumpy(), dx1_expect) diff --git a/tests/st/ops/gpu/test_list_diff_op.py b/tests/st/ops/gpu/test_list_diff_op.py index dc18c3441b0..076125edb87 100644 --- a/tests/st/ops/gpu/test_list_diff_op.py +++ b/tests/st/ops/gpu/test_list_diff_op.py @@ -1,73 +1,73 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -from mindspore.common import dtype as mstype -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.ops.operations.array_ops import ListDiff - - -class NetListDiff(nn.Cell): - def __init__(self, out_idx=mstype.int64): - super(NetListDiff, self).__init__() - self.list_diff = ListDiff(out_idx=out_idx) - - def construct(self, x, y): - return self.list_diff(x, y) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_list_diff_int32(): - """ - Feature: ListDiff gpu TEST. - Description: 1d test case for ListDiff - Expectation: the result match to expect - """ - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - x = np.array([1, 2, 3, 4, 5, 6]).astype(np.int32) - y = np.array([1, 3, 6]).astype(np.int32) - res_out = np.array([2, 4, 5]).astype(np.int32) - res_idx = np.array([1, 3, 4]).astype(np.int64) - x1 = Tensor(x) - y1 = Tensor(y) - net = NetListDiff(out_idx=mstype.int64) - out, idx = net(x1, y1) - assert np.allclose(res_out, out.asnumpy()) - assert np.allclose(res_idx, idx.asnumpy()) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_list_diff_fp32(): - """ - Feature: ListDiff gpu TEST. - Description: 1d test case for ListDiff - Expectation: the result match to expect - """ - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - x = np.array([1.5, 2.0, 3.1, 4.5, 5, 6]).astype(np.float32) - y = np.array([1.5, 3.1, 6]).astype(np.float32) - res_out = np.array([2.0, 4.5, 5]).astype(np.float32) - res_idx = np.array([1, 3, 4]).astype(np.int64) - x1 = Tensor(x) - y1 = Tensor(y) - net = NetListDiff(out_idx=mstype.int64) - out, idx = net(x1, y1) - assert np.allclose(res_out, out.asnumpy()) - assert np.allclose(res_idx, idx.asnumpy()) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +from mindspore.common import dtype as mstype +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops.operations.array_ops import ListDiff + + +class NetListDiff(nn.Cell): + def __init__(self, out_idx=mstype.int64): + super(NetListDiff, self).__init__() + self.list_diff = ListDiff(out_idx=out_idx) + + def construct(self, x, y): + return self.list_diff(x, y) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_list_diff_int32(): + """ + Feature: ListDiff gpu TEST. + Description: 1d test case for ListDiff + Expectation: the result match to expect + """ + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + x = np.array([1, 2, 3, 4, 5, 6]).astype(np.int32) + y = np.array([1, 3, 6]).astype(np.int32) + res_out = np.array([2, 4, 5]).astype(np.int32) + res_idx = np.array([1, 3, 4]).astype(np.int64) + x1 = Tensor(x) + y1 = Tensor(y) + net = NetListDiff(out_idx=mstype.int64) + out, idx = net(x1, y1) + assert np.allclose(res_out, out.asnumpy()) + assert np.allclose(res_idx, idx.asnumpy()) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_list_diff_fp32(): + """ + Feature: ListDiff gpu TEST. + Description: 1d test case for ListDiff + Expectation: the result match to expect + """ + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + x = np.array([1.5, 2.0, 3.1, 4.5, 5, 6]).astype(np.float32) + y = np.array([1.5, 3.1, 6]).astype(np.float32) + res_out = np.array([2.0, 4.5, 5]).astype(np.float32) + res_idx = np.array([1, 3, 4]).astype(np.int64) + x1 = Tensor(x) + y1 = Tensor(y) + net = NetListDiff(out_idx=mstype.int64) + out, idx = net(x1, y1) + assert np.allclose(res_out, out.asnumpy()) + assert np.allclose(res_idx, idx.asnumpy()) diff --git a/tests/st/ops/gpu/test_minimum_grad_grad.py b/tests/st/ops/gpu/test_minimum_grad_grad.py index e705c91bb84..a6c30de211d 100644 --- a/tests/st/ops/gpu/test_minimum_grad_grad.py +++ b/tests/st/ops/gpu/test_minimum_grad_grad.py @@ -1,182 +1,182 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import pytest -import numpy as np -import mindspore.context as context -import mindspore.ops.operations._grad_ops as G -from mindspore import Tensor -from mindspore.nn import Cell -from mindspore.common import dtype as ms_type -from mindspore.ops.functional import vmap - - -class MinimumGradGradNet(Cell): - def __init__(self): - super(MinimumGradGradNet, self).__init__() - self.minimum_grad_grad = G.MinimumGradGrad() - - def construct(self, x1, x2, dx1, dx2): - return self.minimum_grad_grad(x1, x2, dx1, dx2) - - -def minimum_grad_grad_np_bencmark(x1, x2, dx1, dx2, shape): - """ - Feature: generate a minimum grad grad numpy benchmark. - Description: The input shape may need to broadcast. - Expectation: match to np mindspore MinimumGradGrad. - """ - b_f_x1 = np.broadcast_to(x1, shape).flatten() - b_f_x2 = np.broadcast_to(x2, shape).flatten() - b_f_dx1 = np.broadcast_to(dx1, shape).flatten() - b_f_dx2 = np.broadcast_to(dx2, shape).flatten() - sopd_x1 = np.zeros_like(x1) - sopd_x2 = np.zeros_like(x2) - b_f_sopd_grad = np.zeros_like(b_f_x1).flatten() - for index, _ in enumerate(b_f_x1): - if b_f_x1[index] < b_f_x2[index]: - b_f_sopd_grad[index] = b_f_dx1[index] - else: - b_f_sopd_grad[index] = b_f_dx2[index] - return sopd_x1, sopd_x2, b_f_sopd_grad.reshape(shape) - - -class MinimumGradGradVMapNet(Cell): - def __init__(self, forward_net, in_axes, out_axes): - super(MinimumGradGradVMapNet, self).__init__() - self.net = forward_net - self.in_axes = in_axes - self.out_axes = out_axes - - def construct(self, x1, x2, dx1, dx2): - return vmap(self.net, self.in_axes, self.out_axes)(x1, x2, dx1, dx2) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_minimum_grad_grad_random(): - """ - Feature: Test MinimumGradGrad. - Description: The input shape need support broadcast. - Expectation: match to np benchmark. - """ - np.random.seed(0) - loss = 1e-4 - x1 = np.random.normal(0, 1, [3, 5]).astype(np.float32) - x2 = np.random.normal(0, 1, [3, 5]).astype(np.float32) - dout = np.minimum(x1, x2).astype(np.float32) - dx1 = dout * (x1 <= x2) - dx2 = dout - dx1 - np_result = minimum_grad_grad_np_bencmark(x1, x2, dx1, dx2, shape=dout.shape) - context.set_context(mode=context.GRAPH_MODE) - net = MinimumGradGradNet() - result = net(Tensor(x1), Tensor(x2), Tensor(dx1), Tensor(dx2)) - assert np.allclose(result[0].asnumpy(), np_result[0], rtol=loss, atol=loss, equal_nan=True) - assert np.allclose(result[1].asnumpy(), np_result[1], rtol=loss, atol=loss, equal_nan=True) - assert np.allclose(result[2].asnumpy(), np_result[2], rtol=loss, atol=loss, equal_nan=True) - context.set_context(mode=context.PYNATIVE_MODE) - net = MinimumGradGradNet() - result = net(Tensor(x1), Tensor(x2), Tensor(dx1), Tensor(dx2)) - assert np.allclose(result[0].asnumpy(), np_result[0], rtol=loss, atol=loss, equal_nan=True) - assert np.allclose(result[1].asnumpy(), np_result[1], rtol=loss, atol=loss, equal_nan=True) - assert np.allclose(result[2].asnumpy(), np_result[2], rtol=loss, atol=loss, equal_nan=True) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_minimum_grad_grad_broadcast(): - """ - Feature: Test MinimumGradGrad. - Description: The input shape need support broadcast. - Expectation: match to np benchmark. - """ - np.random.seed(0) - loss = 1e-4 - x1 = np.random.normal(0, 1, [1, 3, 5]).astype(np.float32) - x2 = np.random.normal(0, 1, [1, 3, 5]).astype(np.float32) - dout = np.minimum(x1, x2).astype(np.float32) - dx1 = dout * (x1 <= x2) - dx2 = dout - dx1 - reduce_x2 = x2[0][:] - reduce_dx2 = dx2[0][:] - context.set_context(mode=context.GRAPH_MODE) - net = MinimumGradGradNet() - benchmark_result = net(Tensor(x1), Tensor(x2), Tensor(dx1), Tensor(dx2)) - result = net(Tensor(x1), Tensor(reduce_x2), Tensor(dx1), Tensor(reduce_dx2)) - assert np.allclose(benchmark_result[2].asnumpy(), result[2].asnumpy(), rtol=loss, atol=loss, equal_nan=True) - context.set_context(mode=context.PYNATIVE_MODE) - net = MinimumGradGradNet() - result = net(Tensor(x1), Tensor(reduce_x2), Tensor(dx1), Tensor(reduce_dx2)) - assert np.allclose(benchmark_result[2].asnumpy(), result[2].asnumpy(), rtol=loss, atol=loss, equal_nan=True) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize("data_type", [np.float32]) -def test_minimum_grad_grad_dy_shape(data_type): - """ - Feature: Test MinimumGradGrad DynamicShape. - Description: The input data type only float16 and float32. - Expectation: match to np benchmark. - """ - context.set_context(mode=context.GRAPH_MODE) - loss = 1e-4 - x1 = np.random.normal(0, 1, [3, 5]).astype(data_type) - x2 = np.random.normal(0, 1, [3, 5]).astype(data_type) - dout = np.minimum(x1, x2).astype(data_type) - dx1 = dout * (x1 <= x2) - dx2 = dout - dx1 - np_result = minimum_grad_grad_np_bencmark(x1, x2, dx1, dx2, shape=dout.shape) - context.set_context(mode=context.GRAPH_MODE) - minimum_grad_grad_net = MinimumGradGradNet() - ms_data_type = ms_type.float32 - x1_dyn = Tensor(shape=[3, None], dtype=ms_data_type) - x2_dyn = Tensor(shape=[3, None], dtype=ms_data_type) - dx1_dyn = Tensor(shape=[3, None], dtype=ms_data_type) - dx2_dyn = Tensor(shape=[3, None], dtype=ms_data_type) - minimum_grad_grad_net.set_inputs(x1_dyn, x2_dyn, dx1_dyn, dx2_dyn) - result = minimum_grad_grad_net(Tensor(x1), Tensor(x2), Tensor(dx1), Tensor(dx2)) - assert np.allclose(result[0].asnumpy(), np_result[0], rtol=loss, atol=loss, equal_nan=True) - assert np.allclose(result[1].asnumpy(), np_result[1], rtol=loss, atol=loss, equal_nan=True) - assert np.allclose(result[2].asnumpy(), np_result[2], rtol=loss, atol=loss, equal_nan=True) - context.set_context(mode=context.PYNATIVE_MODE) - minimum_grad_grad_net.set_inputs(x1_dyn, x2_dyn, dx1_dyn, dx2_dyn) - assert np.allclose(result[0].asnumpy(), np_result[0], rtol=loss, atol=loss, equal_nan=True) - assert np.allclose(result[1].asnumpy(), np_result[1], rtol=loss, atol=loss, equal_nan=True) - assert np.allclose(result[2].asnumpy(), np_result[2], rtol=loss, atol=loss, equal_nan=True) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_minimum_grad_grad_vmap(): - """ - Feature: test MinimumGradGrad vmap on GPU. - Description: The input data type only float16 and float32. - Expectation: match to np benchmark. - """ - context.set_context(mode=context.GRAPH_MODE) - loss = 1e-4 - # Case : in_axes input_x batch remains 0 - x1 = np.random.normal(0, 1, [2, 3, 5]).astype(np.float32) - x2 = np.random.normal(0, 1, [2, 3, 5]).astype(np.float32) - dout = np.minimum(x1, x2).astype(np.float32) - dx1 = dout * (x1 <= x2) - dx2 = dout - dx1 - vmap_np_result = minimum_grad_grad_np_bencmark(x1, x2, dx1, dx2, shape=dout.shape) - in_axes = 0 - out_axes = 0 - minimum_grad_grad = MinimumGradGradNet() - result = MinimumGradGradVMapNet(minimum_grad_grad, in_axes, out_axes)(Tensor(x1), Tensor(x2), Tensor(dx1), - Tensor(dx2)) - assert np.allclose(result[0].asnumpy(), vmap_np_result[0], rtol=loss, atol=loss, equal_nan=True) - assert np.allclose(result[1].asnumpy(), vmap_np_result[1], rtol=loss, atol=loss, equal_nan=True) - assert np.allclose(result[2].asnumpy(), vmap_np_result[2], rtol=loss, atol=loss, equal_nan=True) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import pytest +import numpy as np +import mindspore.context as context +import mindspore.ops.operations._grad_ops as G +from mindspore import Tensor +from mindspore.nn import Cell +from mindspore.common import dtype as ms_type +from mindspore.ops.functional import vmap + + +class MinimumGradGradNet(Cell): + def __init__(self): + super(MinimumGradGradNet, self).__init__() + self.minimum_grad_grad = G.MinimumGradGrad() + + def construct(self, x1, x2, dx1, dx2): + return self.minimum_grad_grad(x1, x2, dx1, dx2) + + +def minimum_grad_grad_np_bencmark(x1, x2, dx1, dx2, shape): + """ + Feature: generate a minimum grad grad numpy benchmark. + Description: The input shape may need to broadcast. + Expectation: match to np mindspore MinimumGradGrad. + """ + b_f_x1 = np.broadcast_to(x1, shape).flatten() + b_f_x2 = np.broadcast_to(x2, shape).flatten() + b_f_dx1 = np.broadcast_to(dx1, shape).flatten() + b_f_dx2 = np.broadcast_to(dx2, shape).flatten() + sopd_x1 = np.zeros_like(x1) + sopd_x2 = np.zeros_like(x2) + b_f_sopd_grad = np.zeros_like(b_f_x1).flatten() + for index, _ in enumerate(b_f_x1): + if b_f_x1[index] < b_f_x2[index]: + b_f_sopd_grad[index] = b_f_dx1[index] + else: + b_f_sopd_grad[index] = b_f_dx2[index] + return sopd_x1, sopd_x2, b_f_sopd_grad.reshape(shape) + + +class MinimumGradGradVMapNet(Cell): + def __init__(self, forward_net, in_axes, out_axes): + super(MinimumGradGradVMapNet, self).__init__() + self.net = forward_net + self.in_axes = in_axes + self.out_axes = out_axes + + def construct(self, x1, x2, dx1, dx2): + return vmap(self.net, self.in_axes, self.out_axes)(x1, x2, dx1, dx2) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_minimum_grad_grad_random(): + """ + Feature: Test MinimumGradGrad. + Description: The input shape need support broadcast. + Expectation: match to np benchmark. + """ + np.random.seed(0) + loss = 1e-4 + x1 = np.random.normal(0, 1, [3, 5]).astype(np.float32) + x2 = np.random.normal(0, 1, [3, 5]).astype(np.float32) + dout = np.minimum(x1, x2).astype(np.float32) + dx1 = dout * (x1 <= x2) + dx2 = dout - dx1 + np_result = minimum_grad_grad_np_bencmark(x1, x2, dx1, dx2, shape=dout.shape) + context.set_context(mode=context.GRAPH_MODE) + net = MinimumGradGradNet() + result = net(Tensor(x1), Tensor(x2), Tensor(dx1), Tensor(dx2)) + assert np.allclose(result[0].asnumpy(), np_result[0], rtol=loss, atol=loss, equal_nan=True) + assert np.allclose(result[1].asnumpy(), np_result[1], rtol=loss, atol=loss, equal_nan=True) + assert np.allclose(result[2].asnumpy(), np_result[2], rtol=loss, atol=loss, equal_nan=True) + context.set_context(mode=context.PYNATIVE_MODE) + net = MinimumGradGradNet() + result = net(Tensor(x1), Tensor(x2), Tensor(dx1), Tensor(dx2)) + assert np.allclose(result[0].asnumpy(), np_result[0], rtol=loss, atol=loss, equal_nan=True) + assert np.allclose(result[1].asnumpy(), np_result[1], rtol=loss, atol=loss, equal_nan=True) + assert np.allclose(result[2].asnumpy(), np_result[2], rtol=loss, atol=loss, equal_nan=True) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_minimum_grad_grad_broadcast(): + """ + Feature: Test MinimumGradGrad. + Description: The input shape need support broadcast. + Expectation: match to np benchmark. + """ + np.random.seed(0) + loss = 1e-4 + x1 = np.random.normal(0, 1, [1, 3, 5]).astype(np.float32) + x2 = np.random.normal(0, 1, [1, 3, 5]).astype(np.float32) + dout = np.minimum(x1, x2).astype(np.float32) + dx1 = dout * (x1 <= x2) + dx2 = dout - dx1 + reduce_x2 = x2[0][:] + reduce_dx2 = dx2[0][:] + context.set_context(mode=context.GRAPH_MODE) + net = MinimumGradGradNet() + benchmark_result = net(Tensor(x1), Tensor(x2), Tensor(dx1), Tensor(dx2)) + result = net(Tensor(x1), Tensor(reduce_x2), Tensor(dx1), Tensor(reduce_dx2)) + assert np.allclose(benchmark_result[2].asnumpy(), result[2].asnumpy(), rtol=loss, atol=loss, equal_nan=True) + context.set_context(mode=context.PYNATIVE_MODE) + net = MinimumGradGradNet() + result = net(Tensor(x1), Tensor(reduce_x2), Tensor(dx1), Tensor(reduce_dx2)) + assert np.allclose(benchmark_result[2].asnumpy(), result[2].asnumpy(), rtol=loss, atol=loss, equal_nan=True) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize("data_type", [np.float32]) +def test_minimum_grad_grad_dy_shape(data_type): + """ + Feature: Test MinimumGradGrad DynamicShape. + Description: The input data type only float16 and float32. + Expectation: match to np benchmark. + """ + context.set_context(mode=context.GRAPH_MODE) + loss = 1e-4 + x1 = np.random.normal(0, 1, [3, 5]).astype(data_type) + x2 = np.random.normal(0, 1, [3, 5]).astype(data_type) + dout = np.minimum(x1, x2).astype(data_type) + dx1 = dout * (x1 <= x2) + dx2 = dout - dx1 + np_result = minimum_grad_grad_np_bencmark(x1, x2, dx1, dx2, shape=dout.shape) + context.set_context(mode=context.GRAPH_MODE) + minimum_grad_grad_net = MinimumGradGradNet() + ms_data_type = ms_type.float32 + x1_dyn = Tensor(shape=[3, None], dtype=ms_data_type) + x2_dyn = Tensor(shape=[3, None], dtype=ms_data_type) + dx1_dyn = Tensor(shape=[3, None], dtype=ms_data_type) + dx2_dyn = Tensor(shape=[3, None], dtype=ms_data_type) + minimum_grad_grad_net.set_inputs(x1_dyn, x2_dyn, dx1_dyn, dx2_dyn) + result = minimum_grad_grad_net(Tensor(x1), Tensor(x2), Tensor(dx1), Tensor(dx2)) + assert np.allclose(result[0].asnumpy(), np_result[0], rtol=loss, atol=loss, equal_nan=True) + assert np.allclose(result[1].asnumpy(), np_result[1], rtol=loss, atol=loss, equal_nan=True) + assert np.allclose(result[2].asnumpy(), np_result[2], rtol=loss, atol=loss, equal_nan=True) + context.set_context(mode=context.PYNATIVE_MODE) + minimum_grad_grad_net.set_inputs(x1_dyn, x2_dyn, dx1_dyn, dx2_dyn) + assert np.allclose(result[0].asnumpy(), np_result[0], rtol=loss, atol=loss, equal_nan=True) + assert np.allclose(result[1].asnumpy(), np_result[1], rtol=loss, atol=loss, equal_nan=True) + assert np.allclose(result[2].asnumpy(), np_result[2], rtol=loss, atol=loss, equal_nan=True) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_minimum_grad_grad_vmap(): + """ + Feature: test MinimumGradGrad vmap on GPU. + Description: The input data type only float16 and float32. + Expectation: match to np benchmark. + """ + context.set_context(mode=context.GRAPH_MODE) + loss = 1e-4 + # Case : in_axes input_x batch remains 0 + x1 = np.random.normal(0, 1, [2, 3, 5]).astype(np.float32) + x2 = np.random.normal(0, 1, [2, 3, 5]).astype(np.float32) + dout = np.minimum(x1, x2).astype(np.float32) + dx1 = dout * (x1 <= x2) + dx2 = dout - dx1 + vmap_np_result = minimum_grad_grad_np_bencmark(x1, x2, dx1, dx2, shape=dout.shape) + in_axes = 0 + out_axes = 0 + minimum_grad_grad = MinimumGradGradNet() + result = MinimumGradGradVMapNet(minimum_grad_grad, in_axes, out_axes)(Tensor(x1), Tensor(x2), Tensor(dx1), + Tensor(dx2)) + assert np.allclose(result[0].asnumpy(), vmap_np_result[0], rtol=loss, atol=loss, equal_nan=True) + assert np.allclose(result[1].asnumpy(), vmap_np_result[1], rtol=loss, atol=loss, equal_nan=True) + assert np.allclose(result[2].asnumpy(), vmap_np_result[2], rtol=loss, atol=loss, equal_nan=True) diff --git a/tests/st/ops/gpu/test_multinomial_op.py b/tests/st/ops/gpu/test_multinomial_op.py index 70d4b861dbf..9e3a4ec1f6c 100644 --- a/tests/st/ops/gpu/test_multinomial_op.py +++ b/tests/st/ops/gpu/test_multinomial_op.py @@ -1,114 +1,114 @@ -# Copyright 2020-2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest -from mindspore.ops import composite as C -from mindspore.ops import operations as P -from mindspore.ops.functional import vmap -from mindspore import context -from mindspore import nn -import mindspore as ms -from mindspore import Tensor - -context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - - -class Net(nn.Cell): - def __init__(self, sample, replacement, seed=0): - super(Net, self).__init__() - self.sample = sample - self.replacement = replacement - self.seed = seed - - def construct(self, x): - return C.multinomial(x, self.sample, self.replacement, self.seed) - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_multinomial_exception1(): - """ - Feature: test Multinomial exception case. - Description: test Multinomial exception case and GPU kernel exception handling feature. - Expectation: success. - """ - x = Tensor(np.array([0.9, 0.5, 0.2, 0]).astype(np.float32)) - net = Net(2, True) - try: - net(x) - except RuntimeError: - assert True - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_multinomial__exception2(): - """ - Feature: test Multinomial exception case. - Description: test Multinomial exception case and GPU kernel exception handling feature. - Expectation: success. - """ - x = Tensor(np.array([9, 4, 2, 1]).astype(np.float32)) - net = Net(2, True) - try: - net(x) - except RuntimeError: - assert True - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_multinomial(): - """ - Feature: test Multinomial common call. - Description: test Multinomial common call. - Expectation: success. - """ - x0 = Tensor(np.array([0.9, 0.2]).astype(np.float32)) - x1 = Tensor(np.array([[0.9, 0.2], [0.9, 0.2]]).astype(np.float32)) - net0 = Net(1, True, 20) - net1 = Net(2, True, 20) - net2 = Net(6, True, 20) - out0 = net0(x0) - out1 = net1(x0) - out2 = net2(x1) - assert out0.asnumpy().shape == (1,) - assert out1.asnumpy().shape == (2,) - assert out2.asnumpy().shape == (2, 6) - -class BatchedMultinomial(nn.Cell): - def __init__(self): - super().__init__() - self.multinomial = P.Multinomial(seed=5, seed2=6) - - def construct(self, prob, num_sample): - return self.multinomial(prob, num_sample) - - -def multinomial(prob, num_sample): - return P.Multinomial(seed=5, seed2=6)(prob, num_sample) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_multinomial_vmap(): - """ - Feature: test Multinomial vmap feature. - Description: test Multinomial vmap feature. - Expectation: success. - """ - prob = Tensor([[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]], ms.float32) - num_sample = 3 - - batched_multinomial = BatchedMultinomial() - batched_out = batched_multinomial(prob, num_sample) - vmap_out = vmap(multinomial, in_axes=(0, None), out_axes=0)(prob, num_sample) - - assert (batched_out.asnumpy() == vmap_out.asnumpy()).all() +# Copyright 2020-2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest +from mindspore.ops import composite as C +from mindspore.ops import operations as P +from mindspore.ops.functional import vmap +from mindspore import context +from mindspore import nn +import mindspore as ms +from mindspore import Tensor + +context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + + +class Net(nn.Cell): + def __init__(self, sample, replacement, seed=0): + super(Net, self).__init__() + self.sample = sample + self.replacement = replacement + self.seed = seed + + def construct(self, x): + return C.multinomial(x, self.sample, self.replacement, self.seed) + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_multinomial_exception1(): + """ + Feature: test Multinomial exception case. + Description: test Multinomial exception case and GPU kernel exception handling feature. + Expectation: success. + """ + x = Tensor(np.array([0.9, 0.5, 0.2, 0]).astype(np.float32)) + net = Net(2, True) + try: + net(x) + except RuntimeError: + assert True + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_multinomial__exception2(): + """ + Feature: test Multinomial exception case. + Description: test Multinomial exception case and GPU kernel exception handling feature. + Expectation: success. + """ + x = Tensor(np.array([9, 4, 2, 1]).astype(np.float32)) + net = Net(2, True) + try: + net(x) + except RuntimeError: + assert True + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_multinomial(): + """ + Feature: test Multinomial common call. + Description: test Multinomial common call. + Expectation: success. + """ + x0 = Tensor(np.array([0.9, 0.2]).astype(np.float32)) + x1 = Tensor(np.array([[0.9, 0.2], [0.9, 0.2]]).astype(np.float32)) + net0 = Net(1, True, 20) + net1 = Net(2, True, 20) + net2 = Net(6, True, 20) + out0 = net0(x0) + out1 = net1(x0) + out2 = net2(x1) + assert out0.asnumpy().shape == (1,) + assert out1.asnumpy().shape == (2,) + assert out2.asnumpy().shape == (2, 6) + +class BatchedMultinomial(nn.Cell): + def __init__(self): + super().__init__() + self.multinomial = P.Multinomial(seed=5, seed2=6) + + def construct(self, prob, num_sample): + return self.multinomial(prob, num_sample) + + +def multinomial(prob, num_sample): + return P.Multinomial(seed=5, seed2=6)(prob, num_sample) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_multinomial_vmap(): + """ + Feature: test Multinomial vmap feature. + Description: test Multinomial vmap feature. + Expectation: success. + """ + prob = Tensor([[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]], ms.float32) + num_sample = 3 + + batched_multinomial = BatchedMultinomial() + batched_out = batched_multinomial(prob, num_sample) + vmap_out = vmap(multinomial, in_axes=(0, None), out_axes=0)(prob, num_sample) + + assert (batched_out.asnumpy() == vmap_out.asnumpy()).all() diff --git a/tests/st/ops/gpu/test_pdist_grad_op.py b/tests/st/ops/gpu/test_pdist_grad_op.py index 3da2bc43c6d..20946c8c975 100644 --- a/tests/st/ops/gpu/test_pdist_grad_op.py +++ b/tests/st/ops/gpu/test_pdist_grad_op.py @@ -1,83 +1,83 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest -import mindspore.nn as nn -import mindspore.context as context -from mindspore import Tensor -from mindspore.ops.operations import _grad_ops as G - - -class PdistGradNet(nn.Cell): - def __init__(self, p=2.0): - super().__init__() - self.pdistgrad = G.PdistGrad(p=p) - - def construct(self, y_grad, x, y): - return self.pdistgrad(y_grad, x, y) - - -def pdist_grad_graph(y_grad, x, y, p): - context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - net = PdistGradNet(p) - output_ms = net(y_grad, x, y) - return output_ms - - -def pdist_grad_pynative(y_grad, x, y, p): - context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') - net = PdistGradNet(p) - output_ms = net(y_grad, x, y) - return output_ms - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('dtype, eps', [(np.float32, 1.0e-4), (np.float64, 1.0e-5)]) -def test_pdist_grad_graph(dtype, eps): - """ - Feature: test PdistGrad operation in result - Description: test the Pdist p = 2.0 - Expectation: the output matches numpy - """ - y_grad = Tensor(np.array([1., 1., 2.]).astype(dtype)) - x = Tensor(np.array([[1., 1.], [2., 2.], [3., 3.]]).astype(dtype)) - y = Tensor(np.array([1.41421356, 2.82842712, 1.41421356]).astype(dtype)) - p = 2.0 - error = np.ones(shape=(3, 2)) * eps - output_ms_graph = pdist_grad_graph(y_grad, x, y, p) - out_pt = np.array([[-1.41421356, -1.41421356], [-0.70710678, -0.70710678], [2.12132034, 2.12132034]]).astype(dtype) - diff_graph = np.abs(output_ms_graph.asnumpy() - out_pt) - assert np.all(diff_graph < error) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('dtype, eps', [(np.float32, 1.0e-4), (np.float64, 1.0e-5)]) -def test_pdist_grad_pynative(dtype, eps): - """ - Feature: test PdistGrad operation in result - Description: test the Pdist p = 2.0 - Expectation: the output matches numpy - """ - y_grad = Tensor(np.array([1., 1., 2.]).astype(dtype)) - x = Tensor(np.array([[1., 1.], [2., 2.], [3., 3.]]).astype(dtype)) - y = Tensor(np.array([1.41421356, 2.82842712, 1.41421356]).astype(dtype)) - p = 2.0 - error = np.ones(shape=(3, 2)) * eps - output_ms_pynative = pdist_grad_pynative(y_grad, x, y, p) - out_pt = np.array([[-1.41421356, -1.41421356], [-0.70710678, -0.70710678], [2.12132034, 2.12132034]]).astype(dtype) - diff_pynative = np.abs(output_ms_pynative.asnumpy() - out_pt) - assert np.all(diff_pynative < error) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest +import mindspore.nn as nn +import mindspore.context as context +from mindspore import Tensor +from mindspore.ops.operations import _grad_ops as G + + +class PdistGradNet(nn.Cell): + def __init__(self, p=2.0): + super().__init__() + self.pdistgrad = G.PdistGrad(p=p) + + def construct(self, y_grad, x, y): + return self.pdistgrad(y_grad, x, y) + + +def pdist_grad_graph(y_grad, x, y, p): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + net = PdistGradNet(p) + output_ms = net(y_grad, x, y) + return output_ms + + +def pdist_grad_pynative(y_grad, x, y, p): + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + net = PdistGradNet(p) + output_ms = net(y_grad, x, y) + return output_ms + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('dtype, eps', [(np.float32, 1.0e-4), (np.float64, 1.0e-5)]) +def test_pdist_grad_graph(dtype, eps): + """ + Feature: test PdistGrad operation in result + Description: test the Pdist p = 2.0 + Expectation: the output matches numpy + """ + y_grad = Tensor(np.array([1., 1., 2.]).astype(dtype)) + x = Tensor(np.array([[1., 1.], [2., 2.], [3., 3.]]).astype(dtype)) + y = Tensor(np.array([1.41421356, 2.82842712, 1.41421356]).astype(dtype)) + p = 2.0 + error = np.ones(shape=(3, 2)) * eps + output_ms_graph = pdist_grad_graph(y_grad, x, y, p) + out_pt = np.array([[-1.41421356, -1.41421356], [-0.70710678, -0.70710678], [2.12132034, 2.12132034]]).astype(dtype) + diff_graph = np.abs(output_ms_graph.asnumpy() - out_pt) + assert np.all(diff_graph < error) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('dtype, eps', [(np.float32, 1.0e-4), (np.float64, 1.0e-5)]) +def test_pdist_grad_pynative(dtype, eps): + """ + Feature: test PdistGrad operation in result + Description: test the Pdist p = 2.0 + Expectation: the output matches numpy + """ + y_grad = Tensor(np.array([1., 1., 2.]).astype(dtype)) + x = Tensor(np.array([[1., 1.], [2., 2.], [3., 3.]]).astype(dtype)) + y = Tensor(np.array([1.41421356, 2.82842712, 1.41421356]).astype(dtype)) + p = 2.0 + error = np.ones(shape=(3, 2)) * eps + output_ms_pynative = pdist_grad_pynative(y_grad, x, y, p) + out_pt = np.array([[-1.41421356, -1.41421356], [-0.70710678, -0.70710678], [2.12132034, 2.12132034]]).astype(dtype) + diff_pynative = np.abs(output_ms_pynative.asnumpy() - out_pt) + assert np.all(diff_pynative < error) diff --git a/tests/st/ops/gpu/test_pdist_op.py b/tests/st/ops/gpu/test_pdist_op.py index 693d38f3e85..999d670838c 100644 --- a/tests/st/ops/gpu/test_pdist_op.py +++ b/tests/st/ops/gpu/test_pdist_op.py @@ -1,79 +1,79 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest -import mindspore.nn as nn -import mindspore.context as context -from mindspore import Tensor -from mindspore.ops.operations import nn_ops as P - - -class PdistNet(nn.Cell): - def __init__(self, p=2.0): - super().__init__() - self.pdist = P.Pdist(p=p) - - def construct(self, x): - return self.pdist(x) - - -def pdist_graph(x, p): - context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - net = PdistNet(p) - output_ms = net(x) - return output_ms - - -def pdist_pynative(x, p): - context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') - net = PdistNet(p) - output_ms = net(x) - return output_ms - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('dtype, eps', [(np.float32, 1.0e-4), (np.float64, 1.0e-5)]) -def test_pdist_graph(dtype, eps): - """ - Feature: test Pdist operation in result - Description: test the Pdist p = 2.0 - Expectation: the output matches numpy - """ - x = Tensor(np.array([[1., 1.], [2., 2.], [3., 3.]]).astype(dtype)) - error = np.ones(shape=(3,)) * eps - p = 2.0 - output_ms_graph = pdist_graph(x, p) - out_expect = np.array([1.41421356, 2.82842712, 1.41421356]).astype(dtype) - diff_graph = np.abs(output_ms_graph.asnumpy() - out_expect) - assert np.all(diff_graph < error) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('dtype, eps', [(np.float32, 1.0e-4), (np.float64, 1.0e-5)]) -def test_pdist_pynative(dtype, eps): - """ - Feature: test Pdist operation in result - Description: test the Pdist p = 2.0 - Expectation: the output matches numpy - """ - x = Tensor(np.array([[1., 1.], [2., 2.], [3., 3.]]).astype(dtype)) - error = np.ones(shape=(3,)) * eps - p = 2.0 - output_ms_pynative = pdist_pynative(x, p) - out_expect = np.array([1.41421356, 2.82842712, 1.41421356]).astype(dtype) - diff_pynative = np.abs(output_ms_pynative.asnumpy() - out_expect) - assert np.all(diff_pynative < error) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest +import mindspore.nn as nn +import mindspore.context as context +from mindspore import Tensor +from mindspore.ops.operations import nn_ops as P + + +class PdistNet(nn.Cell): + def __init__(self, p=2.0): + super().__init__() + self.pdist = P.Pdist(p=p) + + def construct(self, x): + return self.pdist(x) + + +def pdist_graph(x, p): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + net = PdistNet(p) + output_ms = net(x) + return output_ms + + +def pdist_pynative(x, p): + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + net = PdistNet(p) + output_ms = net(x) + return output_ms + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('dtype, eps', [(np.float32, 1.0e-4), (np.float64, 1.0e-5)]) +def test_pdist_graph(dtype, eps): + """ + Feature: test Pdist operation in result + Description: test the Pdist p = 2.0 + Expectation: the output matches numpy + """ + x = Tensor(np.array([[1., 1.], [2., 2.], [3., 3.]]).astype(dtype)) + error = np.ones(shape=(3,)) * eps + p = 2.0 + output_ms_graph = pdist_graph(x, p) + out_expect = np.array([1.41421356, 2.82842712, 1.41421356]).astype(dtype) + diff_graph = np.abs(output_ms_graph.asnumpy() - out_expect) + assert np.all(diff_graph < error) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('dtype, eps', [(np.float32, 1.0e-4), (np.float64, 1.0e-5)]) +def test_pdist_pynative(dtype, eps): + """ + Feature: test Pdist operation in result + Description: test the Pdist p = 2.0 + Expectation: the output matches numpy + """ + x = Tensor(np.array([[1., 1.], [2., 2.], [3., 3.]]).astype(dtype)) + error = np.ones(shape=(3,)) * eps + p = 2.0 + output_ms_pynative = pdist_pynative(x, p) + out_expect = np.array([1.41421356, 2.82842712, 1.41421356]).astype(dtype) + diff_pynative = np.abs(output_ms_pynative.asnumpy() - out_expect) + assert np.all(diff_pynative < error) diff --git a/tests/st/ops/gpu/test_polar_op.py b/tests/st/ops/gpu/test_polar_op.py index dedb70d8734..f7d8308cb38 100644 --- a/tests/st/ops/gpu/test_polar_op.py +++ b/tests/st/ops/gpu/test_polar_op.py @@ -1,78 +1,78 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest -import mindspore.context as context -import mindspore.nn as nn -import mindspore.ops.operations.math_ops as P -from mindspore import Tensor -from mindspore.common.api import ms_function - - -class PolarNet(nn.Cell): - def __init__(self): - super(PolarNet, self).__init__() - self.polar = P.Polar() - - @ms_function - def construct(self, ms_abs, ms_angle): - return self.polar(ms_abs, ms_angle) - - -def polar(loss): - context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - np_abs = np.array([1, 2, 3, 4]).astype(np.float32) - np_angle = np.array([np.pi/2, 5*np.pi/4, 3*np.pi/2, 2*np.pi/3]).astype(np.float32) - ms_abs = Tensor(np_abs) - ms_angle = Tensor(np_angle) - net = PolarNet() - output = net(ms_abs, ms_angle) - expected = [-4.3711388e-08+1.j, -1.4142137e+00-1.4142134j, 3.5774640e-08-3.j, -2.0000002e+00+3.4641016j] - assert np.allclose(output.asnumpy(), expected, loss, loss) - - -def polar_pynative(loss): - context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') - np_abs = np.array([1, 2, 3, 4]).astype(np.float64) - np_angle = np.array([np.pi/2, 5*np.pi/4, 3*np.pi/2, 2*np.pi/3]).astype(np.float64) - ms_abs = Tensor(np_abs) - ms_angle = Tensor(np_angle) - net = PolarNet() - output = net(ms_abs, ms_angle) - expected = [6.12323400e-17+1.j, -1.41421356e+00-1.41421356j, -5.51091060e-16-3.j, -2.00000000e+00+3.46410162j] - assert np.allclose(output.asnumpy(), expected, loss, loss) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_polar_graph_float(): - """ - Feature: ALL To ALL - Description: test cases for Polar - Expectation: the result match to pytorch - """ - polar(loss=1.0e-4) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_polar_pynative_double(): - """ - Feature: ALL To ALL - Description: test cases for Polar - Expectation: the result match to pytorch - """ - polar_pynative(loss=1.0e-5) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest +import mindspore.context as context +import mindspore.nn as nn +import mindspore.ops.operations.math_ops as P +from mindspore import Tensor +from mindspore.common.api import ms_function + + +class PolarNet(nn.Cell): + def __init__(self): + super(PolarNet, self).__init__() + self.polar = P.Polar() + + @ms_function + def construct(self, ms_abs, ms_angle): + return self.polar(ms_abs, ms_angle) + + +def polar(loss): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + np_abs = np.array([1, 2, 3, 4]).astype(np.float32) + np_angle = np.array([np.pi/2, 5*np.pi/4, 3*np.pi/2, 2*np.pi/3]).astype(np.float32) + ms_abs = Tensor(np_abs) + ms_angle = Tensor(np_angle) + net = PolarNet() + output = net(ms_abs, ms_angle) + expected = [-4.3711388e-08+1.j, -1.4142137e+00-1.4142134j, 3.5774640e-08-3.j, -2.0000002e+00+3.4641016j] + assert np.allclose(output.asnumpy(), expected, loss, loss) + + +def polar_pynative(loss): + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + np_abs = np.array([1, 2, 3, 4]).astype(np.float64) + np_angle = np.array([np.pi/2, 5*np.pi/4, 3*np.pi/2, 2*np.pi/3]).astype(np.float64) + ms_abs = Tensor(np_abs) + ms_angle = Tensor(np_angle) + net = PolarNet() + output = net(ms_abs, ms_angle) + expected = [6.12323400e-17+1.j, -1.41421356e+00-1.41421356j, -5.51091060e-16-3.j, -2.00000000e+00+3.46410162j] + assert np.allclose(output.asnumpy(), expected, loss, loss) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_polar_graph_float(): + """ + Feature: ALL To ALL + Description: test cases for Polar + Expectation: the result match to pytorch + """ + polar(loss=1.0e-4) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_polar_pynative_double(): + """ + Feature: ALL To ALL + Description: test cases for Polar + Expectation: the result match to pytorch + """ + polar_pynative(loss=1.0e-5) \ No newline at end of file diff --git a/tests/st/ops/gpu/test_reduce_min_op.py b/tests/st/ops/gpu/test_reduce_min_op.py index 73eb7b5d738..18ac0486d99 100644 --- a/tests/st/ops/gpu/test_reduce_min_op.py +++ b/tests/st/ops/gpu/test_reduce_min_op.py @@ -1,94 +1,94 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.ops import operations as P -from mindspore.ops.operations import _inner_ops as inner - - -class ReduceMin(nn.Cell): - def __init__(self, keep_dims): - super(ReduceMin, self).__init__() - self.reduce_min = P.ReduceMin(keep_dims=keep_dims) - - def construct(self, x, axis): - return self.reduce_min(x, axis) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('dtype', [np.float16, np.float32, np.float64]) -@pytest.mark.parametrize('shape, axis, keep_dims', - [((2, 3, 4, 4), 3, True), ((2, 3, 4, 4), 3, False), ((2, 3, 1, 4), 2, True), - ((2, 3, 1, 4), 2, False), ((2, 3, 4, 4), None, True), ((2, 3, 4, 4), None, False), - ((2, 3, 4, 4), -2, False), ((2, 3, 4, 4), (-2, -1), False), ((1, 1, 1, 1), None, True)]) -def test_reduce_min(dtype, shape, axis, keep_dims): - """ - Feature: ALL To ALL - Description: test cases for ReduceMin - Expectation: the result match to numpy - """ - context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') - x = np.random.rand(*shape).astype(dtype) - tensor_x = Tensor(x) - - reduce_min = ReduceMin(keep_dims) - ms_axis = axis if axis is not None else () - output = reduce_min(tensor_x, ms_axis) - - expect = np.min(x, axis=axis, keepdims=keep_dims) - diff = abs(output.asnumpy() - expect) - error = np.ones(shape=expect.shape) * 1.0e-5 - assert np.all(diff < error) - assert output.shape == expect.shape - - -class ReduceMinDynamic(nn.Cell): - def __init__(self, x, axis): - super(ReduceMinDynamic, self).__init__() - self.reduce_min = P.ReduceMin(False) - self.test_dynamic = inner.GpuConvertToDynamicShape() - self.x = x - self.axis = axis - - def construct(self): - dynamic_x = self.test_dynamic(self.x) - return self.reduce_min(dynamic_x, self.axis) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('dtype', [np.float32]) -@pytest.mark.parametrize('shape, axis, keep_dims', - [((1, 1, 1, 1), 0, False), ((2, 3, 4, 4), 0, False)]) -def test_reduce_min_dynamic(dtype, shape, axis, keep_dims): - """ - Feature: ALL To ALL - Description: test cases for ReduceMin with dynamic shape - Expectation: the result match to numpy - """ - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - x = np.random.rand(*shape).astype(dtype) - ms_axis = axis if axis is not None else () - net = ReduceMinDynamic(Tensor(x), ms_axis) - - expect = np.min(x, axis=axis, keepdims=keep_dims) - output = net() - - np.testing.assert_almost_equal(output.asnumpy(), expect) +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P +from mindspore.ops.operations import _inner_ops as inner + + +class ReduceMin(nn.Cell): + def __init__(self, keep_dims): + super(ReduceMin, self).__init__() + self.reduce_min = P.ReduceMin(keep_dims=keep_dims) + + def construct(self, x, axis): + return self.reduce_min(x, axis) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('dtype', [np.float16, np.float32, np.float64]) +@pytest.mark.parametrize('shape, axis, keep_dims', + [((2, 3, 4, 4), 3, True), ((2, 3, 4, 4), 3, False), ((2, 3, 1, 4), 2, True), + ((2, 3, 1, 4), 2, False), ((2, 3, 4, 4), None, True), ((2, 3, 4, 4), None, False), + ((2, 3, 4, 4), -2, False), ((2, 3, 4, 4), (-2, -1), False), ((1, 1, 1, 1), None, True)]) +def test_reduce_min(dtype, shape, axis, keep_dims): + """ + Feature: ALL To ALL + Description: test cases for ReduceMin + Expectation: the result match to numpy + """ + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + x = np.random.rand(*shape).astype(dtype) + tensor_x = Tensor(x) + + reduce_min = ReduceMin(keep_dims) + ms_axis = axis if axis is not None else () + output = reduce_min(tensor_x, ms_axis) + + expect = np.min(x, axis=axis, keepdims=keep_dims) + diff = abs(output.asnumpy() - expect) + error = np.ones(shape=expect.shape) * 1.0e-5 + assert np.all(diff < error) + assert output.shape == expect.shape + + +class ReduceMinDynamic(nn.Cell): + def __init__(self, x, axis): + super(ReduceMinDynamic, self).__init__() + self.reduce_min = P.ReduceMin(False) + self.test_dynamic = inner.GpuConvertToDynamicShape() + self.x = x + self.axis = axis + + def construct(self): + dynamic_x = self.test_dynamic(self.x) + return self.reduce_min(dynamic_x, self.axis) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('dtype', [np.float32]) +@pytest.mark.parametrize('shape, axis, keep_dims', + [((1, 1, 1, 1), 0, False), ((2, 3, 4, 4), 0, False)]) +def test_reduce_min_dynamic(dtype, shape, axis, keep_dims): + """ + Feature: ALL To ALL + Description: test cases for ReduceMin with dynamic shape + Expectation: the result match to numpy + """ + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + x = np.random.rand(*shape).astype(dtype) + ms_axis = axis if axis is not None else () + net = ReduceMinDynamic(Tensor(x), ms_axis) + + expect = np.min(x, axis=axis, keepdims=keep_dims) + output = net() + + np.testing.assert_almost_equal(output.asnumpy(), expect) diff --git a/tests/st/ops/gpu/test_repeat_elements_op.py b/tests/st/ops/gpu/test_repeat_elements_op.py index 5e938b7dcbb..58be12891bd 100644 --- a/tests/st/ops/gpu/test_repeat_elements_op.py +++ b/tests/st/ops/gpu/test_repeat_elements_op.py @@ -1,606 +1,606 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -from mindspore import Tensor -from mindspore.ops import composite as C -import mindspore.nn as nn -import mindspore.context as context - - -class RepeatElementsNet(nn.Cell): - def __init__(self, rep, axis): - super(RepeatElementsNet, self).__init__() - self.rep = rep - self.axis = axis - - def construct(self, x): - return C.repeat_elements(x, self.rep, self.axis) - - -def repeat_elements(x, rep, axis): - repeat_elements_net = RepeatElementsNet(rep, axis) - return repeat_elements_net(Tensor(x.astype(np.int32))).asnumpy() - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_repeat_elements_1d_one_element_rep_1(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(1) - - ms_out = repeat_elements(a, 1, 0) - np_out = a.repeat(1, 0) - np.testing.assert_array_equal(np_out, ms_out) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_repeat_elements_1d_one_element_rep_many(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(1) - - ms_out = repeat_elements(a, 5, 0) - np_out = a.repeat(5, 0) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 513, 0) - np_out = a.repeat(513, 0) - np.testing.assert_array_equal(np_out, ms_out) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_repeat_elements_1d_rep_1(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(24) - - ms_out = repeat_elements(a, 1, 0) - np_out = a.repeat(1, 0) - np.testing.assert_array_equal(np_out, ms_out) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_repeat_elements_1d_rep_many(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(24) - - ms_out = repeat_elements(a, 231, 0) - np_out = a.repeat(231, 0) - np.testing.assert_array_equal(np_out, ms_out) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_repeat_elements_2d_one_element_rep_1(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(1).reshape(1, 1) - - ms_out = repeat_elements(a, 1, 0) - np_out = a.repeat(1, 0) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 1, 1) - np_out = a.repeat(1, 1) - np.testing.assert_array_equal(np_out, ms_out) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_repeat_elements_2d_one_element_rep_many(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(1).reshape(1, 1) - - ms_out = repeat_elements(a, 13, 0) - np_out = a.repeat(13, 0) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 13, 1) - np_out = a.repeat(13, 1) - np.testing.assert_array_equal(np_out, ms_out) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_repeat_elements_2d_rep_1(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(24).reshape(12, 2) - - ms_out = repeat_elements(a, 1, 0) - np_out = a.repeat(1, 0) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 1, 1) - np_out = a.repeat(1, 1) - np.testing.assert_array_equal(np_out, ms_out) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_repeat_elements_2d_rep_many(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(24).reshape(8, 3) - - ms_out = repeat_elements(a, 23, 0) - np_out = a.repeat(23, 0) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 23, 1) - np_out = a.repeat(23, 1) - np.testing.assert_array_equal(np_out, ms_out) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_repeat_elements_3d_one_element_rep_1(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(1).reshape(1, 1, 1) - - ms_out = repeat_elements(a, 1, 0) - np_out = a.repeat(1, 0) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 1, 1) - np_out = a.repeat(1, 1) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 1, 2) - np_out = a.repeat(1, 2) - np.testing.assert_array_equal(np_out, ms_out) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_repeat_elements_3d_one_element_rep_many(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(1).reshape(1, 1, 1) - - ms_out = repeat_elements(a, 43, 0) - np_out = a.repeat(43, 0) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 43, 1) - np_out = a.repeat(43, 1) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 43, 2) - np_out = a.repeat(43, 2) - np.testing.assert_array_equal(np_out, ms_out) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_repeat_elements_3d_rep_1(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(60).reshape(6, 2, 5) - - ms_out = repeat_elements(a, 1, 0) - np_out = a.repeat(1, 0) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 1, 1) - np_out = a.repeat(1, 1) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 1, 2) - np_out = a.repeat(1, 2) - np.testing.assert_array_equal(np_out, ms_out) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_repeat_elements_3d_rep_many(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(60).reshape(3, 4, 5) - - ms_out = repeat_elements(a, 14, 0) - np_out = a.repeat(14, 0) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 14, 1) - np_out = a.repeat(14, 1) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 14, 2) - np_out = a.repeat(14, 2) - np.testing.assert_array_equal(np_out, ms_out) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_repeat_elements_4d_one_element_rep_1(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(1).reshape(1, 1, 1, 1) - - ms_out = repeat_elements(a, 1, 0) - np_out = a.repeat(1, 0) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 1, 1) - np_out = a.repeat(1, 1) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 1, 2) - np_out = a.repeat(1, 2) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 1, 3) - np_out = a.repeat(1, 3) - np.testing.assert_array_equal(np_out, ms_out) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_repeat_elements_4d_one_element_rep_many(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(1).reshape(1, 1, 1, 1) - - ms_out = repeat_elements(a, 17, 0) - np_out = a.repeat(17, 0) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 17, 1) - np_out = a.repeat(17, 1) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 17, 2) - np_out = a.repeat(17, 2) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 17, 3) - np_out = a.repeat(17, 3) - np.testing.assert_array_equal(np_out, ms_out) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_repeat_elements_4d_rep_1(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(24).reshape(4, 3, 2, 1) - - ms_out = repeat_elements(a, 1, 0) - np_out = a.repeat(1, 0) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 1, 1) - np_out = a.repeat(1, 1) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 1, 2) - np_out = a.repeat(1, 2) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 1, 3) - np_out = a.repeat(1, 3) - np.testing.assert_array_equal(np_out, ms_out) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_repeat_elements_4d_rep_many(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(24).reshape(2, 2, 2, 3) - - ms_out = repeat_elements(a, 23, 0) - np_out = a.repeat(23, 0) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 23, 1) - np_out = a.repeat(23, 1) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 23, 2) - np_out = a.repeat(23, 2) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 23, 3) - np_out = a.repeat(23, 3) - np.testing.assert_array_equal(np_out, ms_out) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_repeat_elements_5d_one_element_rep_1(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(1).reshape(1, 1, 1, 1, 1) - - ms_out = repeat_elements(a, 1, 0) - np_out = a.repeat(1, 0) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 1, 1) - np_out = a.repeat(1, 1) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 1, 2) - np_out = a.repeat(1, 2) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 1, 3) - np_out = a.repeat(1, 3) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 1, 4) - np_out = a.repeat(1, 4) - np.testing.assert_array_equal(np_out, ms_out) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_repeat_elements_5d_one_element_rep_many(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(1).reshape(1, 1, 1, 1, 1) - - ms_out = repeat_elements(a, 19, 0) - np_out = a.repeat(19, 0) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 19, 1) - np_out = a.repeat(19, 1) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 19, 2) - np_out = a.repeat(19, 2) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 19, 3) - np_out = a.repeat(19, 3) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 19, 4) - np_out = a.repeat(19, 4) - np.testing.assert_array_equal(np_out, ms_out) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_repeat_elements_5d_rep_1(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(224).reshape(8, 2, 1, 7, 2) - - ms_out = repeat_elements(a, 1, 0) - np_out = a.repeat(1, 0) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 1, 1) - np_out = a.repeat(1, 1) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 1, 2) - np_out = a.repeat(1, 2) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 1, 3) - np_out = a.repeat(1, 3) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 1, 4) - np_out = a.repeat(1, 4) - np.testing.assert_array_equal(np_out, ms_out) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_repeat_elements_5d_rep_many(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(224).reshape(1, 7, 4, 4, 2) - - ms_out = repeat_elements(a, 7, 0) - np_out = a.repeat(7, 0) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 7, 1) - np_out = a.repeat(7, 1) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 7, 2) - np_out = a.repeat(7, 2) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 7, 3) - np_out = a.repeat(7, 3) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 7, 4) - np_out = a.repeat(7, 4) - np.testing.assert_array_equal(np_out, ms_out) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_repeat_elements_large_one_element_rep_1(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(1).reshape(1, 1, 1, 1, 1, 1, 1, 1) - - ms_out = repeat_elements(a, 1, 0) - np_out = a.repeat(1, 0) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 1, 1) - np_out = a.repeat(1, 1) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 1, 2) - np_out = a.repeat(1, 2) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 1, 3) - np_out = a.repeat(1, 3) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 1, 4) - np_out = a.repeat(1, 4) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 1, 5) - np_out = a.repeat(1, 5) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 1, 6) - np_out = a.repeat(1, 6) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 1, 7) - np_out = a.repeat(1, 7) - np.testing.assert_array_equal(np_out, ms_out) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_repeat_elements_large_one_element_rep_many(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(1).reshape(1, 1, 1, 1, 1, 1) - - ms_out = repeat_elements(a, 42, 0) - np_out = a.repeat(42, 0) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 42, 1) - np_out = a.repeat(42, 1) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 42, 2) - np_out = a.repeat(42, 2) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 42, 3) - np_out = a.repeat(42, 3) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 42, 4) - np_out = a.repeat(42, 4) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 42, 5) - np_out = a.repeat(42, 5) - np.testing.assert_array_equal(np_out, ms_out) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_repeat_elements_large_rep_1(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(1152).reshape(2, 3, 4, 8, 1, 1, 2, 3) - - ms_out = repeat_elements(a, 1, 0) - np_out = a.repeat(1, 0) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 1, 1) - np_out = a.repeat(1, 1) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 1, 2) - np_out = a.repeat(1, 2) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 1, 3) - np_out = a.repeat(1, 3) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 1, 4) - np_out = a.repeat(1, 4) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 1, 5) - np_out = a.repeat(1, 5) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 1, 6) - np_out = a.repeat(1, 6) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 1, 7) - np_out = a.repeat(1, 7) - np.testing.assert_array_equal(np_out, ms_out) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_repeat_elements_large_rep_many(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(1152).reshape(4, 3, 4, 2, 4, 3) - - ms_out = repeat_elements(a, 4, 0) - np_out = a.repeat(4, 0) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 4, 1) - np_out = a.repeat(4, 1) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 4, 2) - np_out = a.repeat(4, 2) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 4, 3) - np_out = a.repeat(4, 3) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 4, 4) - np_out = a.repeat(4, 4) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 4, 5) - np_out = a.repeat(4, 5) - np.testing.assert_array_equal(np_out, ms_out) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_repeat_elements_half(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(1152).astype(np.float16).reshape(4, 3, 4, 2, 4, 3) - - ms_out = repeat_elements(a, 4, 0) - np_out = a.repeat(4, 0) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 4, 1) - np_out = a.repeat(4, 1) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 4, 2) - np_out = a.repeat(4, 2) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 4, 3) - np_out = a.repeat(4, 3) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 4, 4) - np_out = a.repeat(4, 4) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements(a, 4, 5) - np_out = a.repeat(4, 5) - np.testing.assert_array_equal(np_out, ms_out) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_repeat_elements_net_multi_use(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - - rep = 3 - axis = 4 - repeat_elements_net = RepeatElementsNet(rep, axis) - - a = np.arange(64).reshape(2, 2, 2, 2, 2, 2) - ms_out = repeat_elements_net(Tensor(a.astype(np.int32))).asnumpy() - np_out = a.repeat(rep, axis) - np.testing.assert_array_equal(np_out, ms_out) - - a = np.arange(128).reshape(2, 2, 4, 2, 2, 2) - ms_out = repeat_elements_net(Tensor(a.astype(np.int32))).asnumpy() - np_out = a.repeat(rep, axis) - np.testing.assert_array_equal(np_out, ms_out) - - a = np.arange(18).reshape(1, 1, 3, 2, 3, 1) - ms_out = repeat_elements_net(Tensor(a.astype(np.int32))).asnumpy() - np_out = a.repeat(rep, axis) - np.testing.assert_array_equal(np_out, ms_out) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_repeat_elements_invalid_input(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(64).reshape(2, 2, 2, 2, 2, 2) - with pytest.raises(ValueError): - _ = repeat_elements(a, 0, 0) - - with pytest.raises(ValueError): - _ = repeat_elements(a, 1, 6) - - with pytest.raises(ValueError): - _ = repeat_elements(a, 1, -7) +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +from mindspore import Tensor +from mindspore.ops import composite as C +import mindspore.nn as nn +import mindspore.context as context + + +class RepeatElementsNet(nn.Cell): + def __init__(self, rep, axis): + super(RepeatElementsNet, self).__init__() + self.rep = rep + self.axis = axis + + def construct(self, x): + return C.repeat_elements(x, self.rep, self.axis) + + +def repeat_elements(x, rep, axis): + repeat_elements_net = RepeatElementsNet(rep, axis) + return repeat_elements_net(Tensor(x.astype(np.int32))).asnumpy() + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_repeat_elements_1d_one_element_rep_1(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(1) + + ms_out = repeat_elements(a, 1, 0) + np_out = a.repeat(1, 0) + np.testing.assert_array_equal(np_out, ms_out) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_repeat_elements_1d_one_element_rep_many(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(1) + + ms_out = repeat_elements(a, 5, 0) + np_out = a.repeat(5, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 513, 0) + np_out = a.repeat(513, 0) + np.testing.assert_array_equal(np_out, ms_out) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_repeat_elements_1d_rep_1(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(24) + + ms_out = repeat_elements(a, 1, 0) + np_out = a.repeat(1, 0) + np.testing.assert_array_equal(np_out, ms_out) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_repeat_elements_1d_rep_many(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(24) + + ms_out = repeat_elements(a, 231, 0) + np_out = a.repeat(231, 0) + np.testing.assert_array_equal(np_out, ms_out) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_repeat_elements_2d_one_element_rep_1(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(1).reshape(1, 1) + + ms_out = repeat_elements(a, 1, 0) + np_out = a.repeat(1, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 1) + np_out = a.repeat(1, 1) + np.testing.assert_array_equal(np_out, ms_out) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_repeat_elements_2d_one_element_rep_many(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(1).reshape(1, 1) + + ms_out = repeat_elements(a, 13, 0) + np_out = a.repeat(13, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 13, 1) + np_out = a.repeat(13, 1) + np.testing.assert_array_equal(np_out, ms_out) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_repeat_elements_2d_rep_1(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(24).reshape(12, 2) + + ms_out = repeat_elements(a, 1, 0) + np_out = a.repeat(1, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 1) + np_out = a.repeat(1, 1) + np.testing.assert_array_equal(np_out, ms_out) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_repeat_elements_2d_rep_many(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(24).reshape(8, 3) + + ms_out = repeat_elements(a, 23, 0) + np_out = a.repeat(23, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 23, 1) + np_out = a.repeat(23, 1) + np.testing.assert_array_equal(np_out, ms_out) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_repeat_elements_3d_one_element_rep_1(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(1).reshape(1, 1, 1) + + ms_out = repeat_elements(a, 1, 0) + np_out = a.repeat(1, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 1) + np_out = a.repeat(1, 1) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 2) + np_out = a.repeat(1, 2) + np.testing.assert_array_equal(np_out, ms_out) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_repeat_elements_3d_one_element_rep_many(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(1).reshape(1, 1, 1) + + ms_out = repeat_elements(a, 43, 0) + np_out = a.repeat(43, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 43, 1) + np_out = a.repeat(43, 1) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 43, 2) + np_out = a.repeat(43, 2) + np.testing.assert_array_equal(np_out, ms_out) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_repeat_elements_3d_rep_1(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(60).reshape(6, 2, 5) + + ms_out = repeat_elements(a, 1, 0) + np_out = a.repeat(1, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 1) + np_out = a.repeat(1, 1) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 2) + np_out = a.repeat(1, 2) + np.testing.assert_array_equal(np_out, ms_out) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_repeat_elements_3d_rep_many(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(60).reshape(3, 4, 5) + + ms_out = repeat_elements(a, 14, 0) + np_out = a.repeat(14, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 14, 1) + np_out = a.repeat(14, 1) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 14, 2) + np_out = a.repeat(14, 2) + np.testing.assert_array_equal(np_out, ms_out) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_repeat_elements_4d_one_element_rep_1(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(1).reshape(1, 1, 1, 1) + + ms_out = repeat_elements(a, 1, 0) + np_out = a.repeat(1, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 1) + np_out = a.repeat(1, 1) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 2) + np_out = a.repeat(1, 2) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 3) + np_out = a.repeat(1, 3) + np.testing.assert_array_equal(np_out, ms_out) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_repeat_elements_4d_one_element_rep_many(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(1).reshape(1, 1, 1, 1) + + ms_out = repeat_elements(a, 17, 0) + np_out = a.repeat(17, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 17, 1) + np_out = a.repeat(17, 1) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 17, 2) + np_out = a.repeat(17, 2) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 17, 3) + np_out = a.repeat(17, 3) + np.testing.assert_array_equal(np_out, ms_out) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_repeat_elements_4d_rep_1(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(24).reshape(4, 3, 2, 1) + + ms_out = repeat_elements(a, 1, 0) + np_out = a.repeat(1, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 1) + np_out = a.repeat(1, 1) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 2) + np_out = a.repeat(1, 2) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 3) + np_out = a.repeat(1, 3) + np.testing.assert_array_equal(np_out, ms_out) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_repeat_elements_4d_rep_many(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(24).reshape(2, 2, 2, 3) + + ms_out = repeat_elements(a, 23, 0) + np_out = a.repeat(23, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 23, 1) + np_out = a.repeat(23, 1) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 23, 2) + np_out = a.repeat(23, 2) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 23, 3) + np_out = a.repeat(23, 3) + np.testing.assert_array_equal(np_out, ms_out) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_repeat_elements_5d_one_element_rep_1(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(1).reshape(1, 1, 1, 1, 1) + + ms_out = repeat_elements(a, 1, 0) + np_out = a.repeat(1, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 1) + np_out = a.repeat(1, 1) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 2) + np_out = a.repeat(1, 2) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 3) + np_out = a.repeat(1, 3) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 4) + np_out = a.repeat(1, 4) + np.testing.assert_array_equal(np_out, ms_out) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_repeat_elements_5d_one_element_rep_many(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(1).reshape(1, 1, 1, 1, 1) + + ms_out = repeat_elements(a, 19, 0) + np_out = a.repeat(19, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 19, 1) + np_out = a.repeat(19, 1) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 19, 2) + np_out = a.repeat(19, 2) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 19, 3) + np_out = a.repeat(19, 3) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 19, 4) + np_out = a.repeat(19, 4) + np.testing.assert_array_equal(np_out, ms_out) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_repeat_elements_5d_rep_1(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(224).reshape(8, 2, 1, 7, 2) + + ms_out = repeat_elements(a, 1, 0) + np_out = a.repeat(1, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 1) + np_out = a.repeat(1, 1) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 2) + np_out = a.repeat(1, 2) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 3) + np_out = a.repeat(1, 3) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 4) + np_out = a.repeat(1, 4) + np.testing.assert_array_equal(np_out, ms_out) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_repeat_elements_5d_rep_many(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(224).reshape(1, 7, 4, 4, 2) + + ms_out = repeat_elements(a, 7, 0) + np_out = a.repeat(7, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 7, 1) + np_out = a.repeat(7, 1) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 7, 2) + np_out = a.repeat(7, 2) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 7, 3) + np_out = a.repeat(7, 3) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 7, 4) + np_out = a.repeat(7, 4) + np.testing.assert_array_equal(np_out, ms_out) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_repeat_elements_large_one_element_rep_1(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(1).reshape(1, 1, 1, 1, 1, 1, 1, 1) + + ms_out = repeat_elements(a, 1, 0) + np_out = a.repeat(1, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 1) + np_out = a.repeat(1, 1) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 2) + np_out = a.repeat(1, 2) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 3) + np_out = a.repeat(1, 3) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 4) + np_out = a.repeat(1, 4) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 5) + np_out = a.repeat(1, 5) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 6) + np_out = a.repeat(1, 6) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 7) + np_out = a.repeat(1, 7) + np.testing.assert_array_equal(np_out, ms_out) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_repeat_elements_large_one_element_rep_many(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(1).reshape(1, 1, 1, 1, 1, 1) + + ms_out = repeat_elements(a, 42, 0) + np_out = a.repeat(42, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 42, 1) + np_out = a.repeat(42, 1) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 42, 2) + np_out = a.repeat(42, 2) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 42, 3) + np_out = a.repeat(42, 3) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 42, 4) + np_out = a.repeat(42, 4) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 42, 5) + np_out = a.repeat(42, 5) + np.testing.assert_array_equal(np_out, ms_out) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_repeat_elements_large_rep_1(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(1152).reshape(2, 3, 4, 8, 1, 1, 2, 3) + + ms_out = repeat_elements(a, 1, 0) + np_out = a.repeat(1, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 1) + np_out = a.repeat(1, 1) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 2) + np_out = a.repeat(1, 2) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 3) + np_out = a.repeat(1, 3) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 4) + np_out = a.repeat(1, 4) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 5) + np_out = a.repeat(1, 5) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 6) + np_out = a.repeat(1, 6) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 1, 7) + np_out = a.repeat(1, 7) + np.testing.assert_array_equal(np_out, ms_out) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_repeat_elements_large_rep_many(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(1152).reshape(4, 3, 4, 2, 4, 3) + + ms_out = repeat_elements(a, 4, 0) + np_out = a.repeat(4, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 4, 1) + np_out = a.repeat(4, 1) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 4, 2) + np_out = a.repeat(4, 2) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 4, 3) + np_out = a.repeat(4, 3) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 4, 4) + np_out = a.repeat(4, 4) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 4, 5) + np_out = a.repeat(4, 5) + np.testing.assert_array_equal(np_out, ms_out) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_repeat_elements_half(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(1152).astype(np.float16).reshape(4, 3, 4, 2, 4, 3) + + ms_out = repeat_elements(a, 4, 0) + np_out = a.repeat(4, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 4, 1) + np_out = a.repeat(4, 1) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 4, 2) + np_out = a.repeat(4, 2) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 4, 3) + np_out = a.repeat(4, 3) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 4, 4) + np_out = a.repeat(4, 4) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements(a, 4, 5) + np_out = a.repeat(4, 5) + np.testing.assert_array_equal(np_out, ms_out) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_repeat_elements_net_multi_use(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + rep = 3 + axis = 4 + repeat_elements_net = RepeatElementsNet(rep, axis) + + a = np.arange(64).reshape(2, 2, 2, 2, 2, 2) + ms_out = repeat_elements_net(Tensor(a.astype(np.int32))).asnumpy() + np_out = a.repeat(rep, axis) + np.testing.assert_array_equal(np_out, ms_out) + + a = np.arange(128).reshape(2, 2, 4, 2, 2, 2) + ms_out = repeat_elements_net(Tensor(a.astype(np.int32))).asnumpy() + np_out = a.repeat(rep, axis) + np.testing.assert_array_equal(np_out, ms_out) + + a = np.arange(18).reshape(1, 1, 3, 2, 3, 1) + ms_out = repeat_elements_net(Tensor(a.astype(np.int32))).asnumpy() + np_out = a.repeat(rep, axis) + np.testing.assert_array_equal(np_out, ms_out) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_repeat_elements_invalid_input(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(64).reshape(2, 2, 2, 2, 2, 2) + with pytest.raises(ValueError): + _ = repeat_elements(a, 0, 0) + + with pytest.raises(ValueError): + _ = repeat_elements(a, 1, 6) + + with pytest.raises(ValueError): + _ = repeat_elements(a, 1, -7) diff --git a/tests/st/ops/gpu/test_resize_nearest_neighbor_op.py b/tests/st/ops/gpu/test_resize_nearest_neighbor_op.py old mode 100755 new mode 100644 diff --git a/tests/st/ops/gpu/test_sgd_op.py b/tests/st/ops/gpu/test_sgd_op.py index dcff87106aa..4d966709580 100644 --- a/tests/st/ops/gpu/test_sgd_op.py +++ b/tests/st/ops/gpu/test_sgd_op.py @@ -1,72 +1,72 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.nn import Dense -from mindspore.nn import TrainOneStepCell, WithLossCell -from mindspore.nn.optim import SGD -from mindspore.ops import operations as P - -context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - -class NetSGD(nn.Cell): - def __init__(self): - super(NetSGD, self).__init__() - self.batch_size = 1 - self.reshape = P.Reshape() - weight = Tensor(np.ones([10, 16]).astype(np.float32) * 0.01) - self.fc1 = Dense(16, 10, weight_init=weight) - - def construct(self, input_x): - output = self.reshape(input_x, (self.batch_size, -1)) - output = self.fc1(output) - return output - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_SGD(): - epoch = 3 - net = NetSGD() - learning_rate = 0.1 - momentum = 0.9 - dampening = 0.0 - weight_decay = 0.0 - nesterov = True - loss_scale = 1.0 - - optimizer = SGD(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum, dampening, - weight_decay, nesterov, loss_scale) - criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') - net_with_criterion = WithLossCell(net, criterion) - train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer - train_network.set_train() - losses = [] - for _ in range(epoch): - data = Tensor(np.arange(0, 16).reshape(1, 1, 4, 4).astype(np.float32) * 0.01) - label = Tensor(np.array([0]).astype(np.int32)) - loss = train_network(data, label) - losses.append(loss.asnumpy()) - - last_loss = 100.0 - for loss in losses: - assert last_loss > loss - last_loss = loss - return losses +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.nn import Dense +from mindspore.nn import TrainOneStepCell, WithLossCell +from mindspore.nn.optim import SGD +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + +class NetSGD(nn.Cell): + def __init__(self): + super(NetSGD, self).__init__() + self.batch_size = 1 + self.reshape = P.Reshape() + weight = Tensor(np.ones([10, 16]).astype(np.float32) * 0.01) + self.fc1 = Dense(16, 10, weight_init=weight) + + def construct(self, input_x): + output = self.reshape(input_x, (self.batch_size, -1)) + output = self.fc1(output) + return output + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_SGD(): + epoch = 3 + net = NetSGD() + learning_rate = 0.1 + momentum = 0.9 + dampening = 0.0 + weight_decay = 0.0 + nesterov = True + loss_scale = 1.0 + + optimizer = SGD(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum, dampening, + weight_decay, nesterov, loss_scale) + criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') + net_with_criterion = WithLossCell(net, criterion) + train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer + train_network.set_train() + losses = [] + for _ in range(epoch): + data = Tensor(np.arange(0, 16).reshape(1, 1, 4, 4).astype(np.float32) * 0.01) + label = Tensor(np.array([0]).astype(np.int32)) + loss = train_network(data, label) + losses.append(loss.asnumpy()) + + last_loss = 100.0 + for loss in losses: + assert last_loss > loss + last_loss = loss + return losses diff --git a/tests/st/ops/gpu/test_sparse_apply_centered_rms_prop_op.py b/tests/st/ops/gpu/test_sparse_apply_centered_rms_prop_op.py index a6f1ae4a2ef..172f4b58ab0 100644 --- a/tests/st/ops/gpu/test_sparse_apply_centered_rms_prop_op.py +++ b/tests/st/ops/gpu/test_sparse_apply_centered_rms_prop_op.py @@ -1,91 +1,91 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest -import mindspore.context as context -import mindspore.common.dtype as mstype -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.common.parameter import Parameter -import mindspore.ops.operations.nn_ops as P - - -class SparseApplyCenteredRMSPropNet(nn.Cell): - def __init__(self, use_locking=False): - super(SparseApplyCenteredRMSPropNet, self).__init__() - self.sparse_apply_centered_rms_prop = P.SparseApplyCenteredRMSProp(use_locking=False) - - def construct(self, var, mg, ms, mom, lr, rho, momentum, epsilon, grad, indices): - out = self.sparse_apply_centered_rms_prop(var, mg, ms, mom, lr, rho, momentum, epsilon, grad, indices) - return out - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_sparse_apply_centered_rms_prop_graph_1(): - """ - Feature: Test whether the output of Var calculated by mindspore and tensorflow are equal. - Description: Inputs are Tensors in shape [2, 2]for mutable tensors, value for scalar and shape [2] for indices. - Expectation: Success. - """ - context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - var = Parameter(Tensor(np.array([[0.6, 0.4], [0.1, 0.5]]).astype(np.float32)), name="var") - mg = Parameter(Tensor(np.array([[0.1, 0.3], [0.1, 0.5]]).astype(np.float32)), name="mg") - ms = Parameter(Tensor(np.array([[0.2, 0.1], [0.1, 0.2]]).astype(np.float32)), name="ms") - mom = Parameter(Tensor(np.array([[0.2, 0.1], [0.1, 0.2]]).astype(np.float32)), name="mom") - lr = Tensor(0.001, mstype.float32) - rho = Tensor(1e-10, mstype.float32) - momentum = Tensor(0.001, mstype.float32) - epsilon = Tensor(0.01, mstype.float32) - grad = Parameter(Tensor(np.array([[0.3, 0.4], [0.1, 0.2]]).astype(np.float32))) - indices = Tensor(np.array([0, 1]).astype(np.int32)) - sparse_apply_centered_rms_prop_net = SparseApplyCenteredRMSPropNet(use_locking=False) - sparse_apply_centered_rms_prop_output = sparse_apply_centered_rms_prop_net(var, mg, ms, mom, lr, rho, \ - momentum, epsilon, grad, indices) - sparse_apply_centered_rms_prop_expected_output = np.array([[0.5968, 0.3959], [0.0989, 0.4978]]).astype(np.float32) - - print(sparse_apply_centered_rms_prop_output) - print(sparse_apply_centered_rms_prop_expected_output) - assert np.allclose(sparse_apply_centered_rms_prop_output.asnumpy(), \ - sparse_apply_centered_rms_prop_expected_output, rtol=1e-3) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_sparse_apply_centered_rms_prop_graph_2(): - """ - Feature: Test whether the output of Var calculated by mindspore and tensorflow are equal. - Description: Inputs are Tensors in shape [2, 2]for mutable tensors, value for scalar and shape [2] for indices. - Expectation: Success. - """ - var = Parameter(Tensor(np.array([[0.6, 0.4], [0.1, 0.5]]).astype(np.float32)), name="var") - mg = Parameter(Tensor(np.array([[0.1, 0.3], [0.1, 0.5]]).astype(np.float32)), name="mg") - ms = Parameter(Tensor(np.array([[0.2, 0.1], [0.1, 0.2]]).astype(np.float32)), name="ms") - mom = Parameter(Tensor(np.array([[0.2, 0.1], [0.1, 0.2]]).astype(np.float32)), name="mom") - lr = Tensor(0.001, mstype.float32) - rho = Tensor(1e-10, mstype.float32) - momentum = Tensor(0.001, mstype.float32) - epsilon = Tensor(0.01, mstype.float32) - grad = Parameter(Tensor(np.array([[0.3, 0.4], [0.1, 0.2]]).astype(np.float32))) - indices = Tensor(np.array([0, 1]).astype(np.int32)) - sparse_apply_centered_rms_prop_net = SparseApplyCenteredRMSPropNet(use_locking=False) - sparse_apply_centered_rms_prop_output = sparse_apply_centered_rms_prop_net(var, mg, ms, mom, lr, rho, \ - momentum, epsilon, grad, indices) - sparse_apply_centered_rms_prop_expected_output = np.array([[0.5968, 0.3959], [0.0989, 0.4978]]).astype(np.float32) - - print(sparse_apply_centered_rms_prop_output) - print(sparse_apply_centered_rms_prop_expected_output) - assert np.allclose(sparse_apply_centered_rms_prop_output.asnumpy(), \ - sparse_apply_centered_rms_prop_expected_output, rtol=1e-3) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest +import mindspore.context as context +import mindspore.common.dtype as mstype +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common.parameter import Parameter +import mindspore.ops.operations.nn_ops as P + + +class SparseApplyCenteredRMSPropNet(nn.Cell): + def __init__(self, use_locking=False): + super(SparseApplyCenteredRMSPropNet, self).__init__() + self.sparse_apply_centered_rms_prop = P.SparseApplyCenteredRMSProp(use_locking=False) + + def construct(self, var, mg, ms, mom, lr, rho, momentum, epsilon, grad, indices): + out = self.sparse_apply_centered_rms_prop(var, mg, ms, mom, lr, rho, momentum, epsilon, grad, indices) + return out + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_sparse_apply_centered_rms_prop_graph_1(): + """ + Feature: Test whether the output of Var calculated by mindspore and tensorflow are equal. + Description: Inputs are Tensors in shape [2, 2]for mutable tensors, value for scalar and shape [2] for indices. + Expectation: Success. + """ + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + var = Parameter(Tensor(np.array([[0.6, 0.4], [0.1, 0.5]]).astype(np.float32)), name="var") + mg = Parameter(Tensor(np.array([[0.1, 0.3], [0.1, 0.5]]).astype(np.float32)), name="mg") + ms = Parameter(Tensor(np.array([[0.2, 0.1], [0.1, 0.2]]).astype(np.float32)), name="ms") + mom = Parameter(Tensor(np.array([[0.2, 0.1], [0.1, 0.2]]).astype(np.float32)), name="mom") + lr = Tensor(0.001, mstype.float32) + rho = Tensor(1e-10, mstype.float32) + momentum = Tensor(0.001, mstype.float32) + epsilon = Tensor(0.01, mstype.float32) + grad = Parameter(Tensor(np.array([[0.3, 0.4], [0.1, 0.2]]).astype(np.float32))) + indices = Tensor(np.array([0, 1]).astype(np.int32)) + sparse_apply_centered_rms_prop_net = SparseApplyCenteredRMSPropNet(use_locking=False) + sparse_apply_centered_rms_prop_output = sparse_apply_centered_rms_prop_net(var, mg, ms, mom, lr, rho, \ + momentum, epsilon, grad, indices) + sparse_apply_centered_rms_prop_expected_output = np.array([[0.5968, 0.3959], [0.0989, 0.4978]]).astype(np.float32) + + print(sparse_apply_centered_rms_prop_output) + print(sparse_apply_centered_rms_prop_expected_output) + assert np.allclose(sparse_apply_centered_rms_prop_output.asnumpy(), \ + sparse_apply_centered_rms_prop_expected_output, rtol=1e-3) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_sparse_apply_centered_rms_prop_graph_2(): + """ + Feature: Test whether the output of Var calculated by mindspore and tensorflow are equal. + Description: Inputs are Tensors in shape [2, 2]for mutable tensors, value for scalar and shape [2] for indices. + Expectation: Success. + """ + var = Parameter(Tensor(np.array([[0.6, 0.4], [0.1, 0.5]]).astype(np.float32)), name="var") + mg = Parameter(Tensor(np.array([[0.1, 0.3], [0.1, 0.5]]).astype(np.float32)), name="mg") + ms = Parameter(Tensor(np.array([[0.2, 0.1], [0.1, 0.2]]).astype(np.float32)), name="ms") + mom = Parameter(Tensor(np.array([[0.2, 0.1], [0.1, 0.2]]).astype(np.float32)), name="mom") + lr = Tensor(0.001, mstype.float32) + rho = Tensor(1e-10, mstype.float32) + momentum = Tensor(0.001, mstype.float32) + epsilon = Tensor(0.01, mstype.float32) + grad = Parameter(Tensor(np.array([[0.3, 0.4], [0.1, 0.2]]).astype(np.float32))) + indices = Tensor(np.array([0, 1]).astype(np.int32)) + sparse_apply_centered_rms_prop_net = SparseApplyCenteredRMSPropNet(use_locking=False) + sparse_apply_centered_rms_prop_output = sparse_apply_centered_rms_prop_net(var, mg, ms, mom, lr, rho, \ + momentum, epsilon, grad, indices) + sparse_apply_centered_rms_prop_expected_output = np.array([[0.5968, 0.3959], [0.0989, 0.4978]]).astype(np.float32) + + print(sparse_apply_centered_rms_prop_output) + print(sparse_apply_centered_rms_prop_expected_output) + assert np.allclose(sparse_apply_centered_rms_prop_output.asnumpy(), \ + sparse_apply_centered_rms_prop_expected_output, rtol=1e-3) diff --git a/tests/st/ops/gpu/test_sparse_segment_sqrt_n_grad_op.py b/tests/st/ops/gpu/test_sparse_segment_sqrt_n_grad_op.py index 303781bae1d..9c1c8edb569 100644 --- a/tests/st/ops/gpu/test_sparse_segment_sqrt_n_grad_op.py +++ b/tests/st/ops/gpu/test_sparse_segment_sqrt_n_grad_op.py @@ -1,98 +1,98 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark -import numpy as np -import pytest -import mindspore.context as context -import mindspore.nn as nn -import mindspore.ops.operations._grad_ops as P -from mindspore import Tensor -from mindspore.common.api import jit - - -class SparseSegmentSqrtNGradNet(nn.Cell): - def __init__(self): - super(SparseSegmentSqrtNGradNet, self).__init__() - self.net = P.SparseSegmentSqrtNGrad() - - @jit - def construct(self, grad, indices, segment_ids, output_dim0): - return self.net(grad, indices, segment_ids, output_dim0) - - -def sparse_segment_sqrt_n_grad(loss): - context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - grad_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float32) - indices_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int32) - segment_ids_np = np.array([0, 1, 2, 2, 3, 3], dtype=np.int32) - output_dim0_np = np.array(8, dtype=np.int32) - grad_ms = Tensor(grad_np) - indices_ms = Tensor(indices_np) - segment_ids_ms = Tensor(segment_ids_np) - output_dim0_ms = Tensor(output_dim0_np) - net_ms = SparseSegmentSqrtNGradNet() - out_ms = net_ms(grad_ms, indices_ms, segment_ids_ms, output_dim0_ms) - expected = np.array([[6, 8, 10, 12], - [6.363961, 7.071068, 7.7781744, 8.485281], - [15.55635, 16.970562, 18.384777, 19.798988], - [9.192389, 9.899495, 10.606602, 11.313708], - [0, 0, 0, 0], - [0, 0, 0, 0], - [0, 0, 0, 0], - [0, 0, 0, 0]], dtype=np.float32) - assert np.allclose(out_ms.asnumpy(), expected, loss, loss) - - -def sparse_segment_sqrt_n_grad_pynative(loss): - context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') - grad_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float64) - indices_np = np.array([0, 1, 1, 2, 3, 3], dtype=np.int64) - segment_ids_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int64) - output_dim0_np = np.array(8, dtype=np.int64) - grad_ms = Tensor(grad_np) - indices_ms = Tensor(indices_np) - segment_ids_ms = Tensor(segment_ids_np) - output_dim0_ms = Tensor(output_dim0_np) - net_ms = SparseSegmentSqrtNGradNet() - out_ms = net_ms(grad_ms, indices_ms, segment_ids_ms, output_dim0_ms) - expected = np.array([[0.70710678, 1.41421356, 2.12132034, 2.82842712], - [5.70710678, 7.41421356, 9.12132034, 10.82842712], - [6.36396103, 7.07106781, 7.77817459, 8.48528137], - [19.36396103, 21.07106781, 22.77817459, 24.48528137], - [0, 0, 0, 0], - [0, 0, 0, 0], - [0, 0, 0, 0], - [0, 0, 0, 0]], dtype=np.float64) - assert np.allclose(out_ms.asnumpy(), expected, loss, loss) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_sparse_segment_sqrt_n_grad_graph_float32_int32_int32(): - """ - Feature: ALL To ALL - Description: test cases for SparseSegmentSqrtNGrad - Expectation: the result match to tensorflow - """ - sparse_segment_sqrt_n_grad(loss=1.0e-4) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_sparse_segment_sqrt_n_grad_pynative_float64_int64_int64(): - """ - Feature: ALL To ALL - Description: test cases for SparseSegmentSqrtNGrad - Expectation: the result match to tensorflow - """ - sparse_segment_sqrt_n_grad_pynative(loss=1.0e-5) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark +import numpy as np +import pytest +import mindspore.context as context +import mindspore.nn as nn +import mindspore.ops.operations._grad_ops as P +from mindspore import Tensor +from mindspore.common.api import jit + + +class SparseSegmentSqrtNGradNet(nn.Cell): + def __init__(self): + super(SparseSegmentSqrtNGradNet, self).__init__() + self.net = P.SparseSegmentSqrtNGrad() + + @jit + def construct(self, grad, indices, segment_ids, output_dim0): + return self.net(grad, indices, segment_ids, output_dim0) + + +def sparse_segment_sqrt_n_grad(loss): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + grad_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float32) + indices_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int32) + segment_ids_np = np.array([0, 1, 2, 2, 3, 3], dtype=np.int32) + output_dim0_np = np.array(8, dtype=np.int32) + grad_ms = Tensor(grad_np) + indices_ms = Tensor(indices_np) + segment_ids_ms = Tensor(segment_ids_np) + output_dim0_ms = Tensor(output_dim0_np) + net_ms = SparseSegmentSqrtNGradNet() + out_ms = net_ms(grad_ms, indices_ms, segment_ids_ms, output_dim0_ms) + expected = np.array([[6, 8, 10, 12], + [6.363961, 7.071068, 7.7781744, 8.485281], + [15.55635, 16.970562, 18.384777, 19.798988], + [9.192389, 9.899495, 10.606602, 11.313708], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0]], dtype=np.float32) + assert np.allclose(out_ms.asnumpy(), expected, loss, loss) + + +def sparse_segment_sqrt_n_grad_pynative(loss): + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + grad_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float64) + indices_np = np.array([0, 1, 1, 2, 3, 3], dtype=np.int64) + segment_ids_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int64) + output_dim0_np = np.array(8, dtype=np.int64) + grad_ms = Tensor(grad_np) + indices_ms = Tensor(indices_np) + segment_ids_ms = Tensor(segment_ids_np) + output_dim0_ms = Tensor(output_dim0_np) + net_ms = SparseSegmentSqrtNGradNet() + out_ms = net_ms(grad_ms, indices_ms, segment_ids_ms, output_dim0_ms) + expected = np.array([[0.70710678, 1.41421356, 2.12132034, 2.82842712], + [5.70710678, 7.41421356, 9.12132034, 10.82842712], + [6.36396103, 7.07106781, 7.77817459, 8.48528137], + [19.36396103, 21.07106781, 22.77817459, 24.48528137], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0]], dtype=np.float64) + assert np.allclose(out_ms.asnumpy(), expected, loss, loss) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_sparse_segment_sqrt_n_grad_graph_float32_int32_int32(): + """ + Feature: ALL To ALL + Description: test cases for SparseSegmentSqrtNGrad + Expectation: the result match to tensorflow + """ + sparse_segment_sqrt_n_grad(loss=1.0e-4) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_sparse_segment_sqrt_n_grad_pynative_float64_int64_int64(): + """ + Feature: ALL To ALL + Description: test cases for SparseSegmentSqrtNGrad + Expectation: the result match to tensorflow + """ + sparse_segment_sqrt_n_grad_pynative(loss=1.0e-5) diff --git a/tests/st/ops/gpu/test_sparse_segment_sqrt_n_op.py b/tests/st/ops/gpu/test_sparse_segment_sqrt_n_op.py index e2bb970665c..5dcd016414d 100644 --- a/tests/st/ops/gpu/test_sparse_segment_sqrt_n_op.py +++ b/tests/st/ops/gpu/test_sparse_segment_sqrt_n_op.py @@ -1,117 +1,117 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark -import numpy as np -import pytest -import mindspore as ms -import mindspore.context as context -import mindspore.nn as nn -import mindspore.ops.operations.sparse_ops as P -from mindspore import Tensor -from mindspore.common.api import jit - - -class SparseSegmentSqrtNNet(nn.Cell): - - def __init__(self): - super(SparseSegmentSqrtNNet, self).__init__() - self.net = P.SparseSegmentSqrtN() - - @jit - def construct(self, x, indices, segment_ids): - return self.net(x, indices, segment_ids) - - -def sparse_segment_sqrt_n(loss): - context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - x_np = np.array( - [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], - dtype=np.float32) - indices_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int32) - segment_ids_np = np.array([0, 1, 2, 2, 3, 3], dtype=np.int32) - x_ms = Tensor(x_np) - indices_ms = Tensor(indices_np) - segment_ids_ms = Tensor(segment_ids_np) - net_ms = SparseSegmentSqrtNNet() - out_ms = net_ms(x_ms, indices_ms, segment_ids_ms) - expected = np.array([[1, 2, 3, 4], [1, 2, 3, 4], - [9.899495, 11.313708, 12.727922, 14.142136], - [15.556349, 16.970562, 18.384777, 19.79899]], - dtype=np.float32) - assert np.allclose(out_ms.asnumpy(), expected, loss, loss) - - -def sparse_segment_sqrt_n_pynative(loss): - context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') - x_np = np.array( - [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], - dtype=np.float64) - indices_np = np.array([0, 1, 1, 2, 3, 3], dtype=np.int64) - segment_ids_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int64) - x_ms = Tensor(x_np) - indices_ms = Tensor(indices_np) - segment_ids_ms = Tensor(segment_ids_np) - net_ms = SparseSegmentSqrtNNet() - out_ms = net_ms(x_ms, indices_ms, segment_ids_ms) - expected = np.array( - [[4.24264069, 5.65685425, 7.07106781, 8.48528137], [5, 6, 7, 8], - [15.55634919, 16.97056275, 18.38477631, 19.79898987], - [13, 14, 15, 16]], - dtype=np.float64) - assert np.allclose(out_ms.asnumpy(), expected, loss, loss) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_sparse_segment_sqrt_n_graph_float32_int32_int32(): - """ - Feature: ALL To ALL - Description: test cases for SparseSegmentSqrtN - Expectation: the result match to tensorflow - """ - sparse_segment_sqrt_n(loss=1.0e-4) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_sparse_segment_sqrt_n_pynative_float64_int64_int64(): - """ - Feature: ALL To ALL - Description: test cases for SparseSegmentSqrtN - Expectation: the result match to tensorflow - """ - sparse_segment_sqrt_n_pynative(loss=1.0e-5) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_sparse_segment_sqrt_n_dyn(): - """ - Feature: test SparseSegmentSqrtN ops in gpu. - Description: test the ops in dynamic shape. - Expectation: expect correct shape result. - """ - context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - net = SparseSegmentSqrtNNet() - - x_dyn = Tensor(shape=[None, None], dtype=ms.float32) - indices_dyn = Tensor(shape=[None], dtype=ms.int32) - segment_ids_dyn = Tensor(shape=[None], dtype=ms.int32) - net.set_inputs(x_dyn, indices_dyn, segment_ids_dyn) - - x = Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=ms.float32) - indices = Tensor([0, 1, 2], dtype=ms.int32) - segment_ids = Tensor([0, 1, 2], dtype=ms.int32) - output = net(x, indices, segment_ids) - - expect_shape = (3, 4) - assert output.asnumpy().shape == expect_shape +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark +import numpy as np +import pytest +import mindspore as ms +import mindspore.context as context +import mindspore.nn as nn +import mindspore.ops.operations.sparse_ops as P +from mindspore import Tensor +from mindspore.common.api import jit + + +class SparseSegmentSqrtNNet(nn.Cell): + + def __init__(self): + super(SparseSegmentSqrtNNet, self).__init__() + self.net = P.SparseSegmentSqrtN() + + @jit + def construct(self, x, indices, segment_ids): + return self.net(x, indices, segment_ids) + + +def sparse_segment_sqrt_n(loss): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + x_np = np.array( + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], + dtype=np.float32) + indices_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int32) + segment_ids_np = np.array([0, 1, 2, 2, 3, 3], dtype=np.int32) + x_ms = Tensor(x_np) + indices_ms = Tensor(indices_np) + segment_ids_ms = Tensor(segment_ids_np) + net_ms = SparseSegmentSqrtNNet() + out_ms = net_ms(x_ms, indices_ms, segment_ids_ms) + expected = np.array([[1, 2, 3, 4], [1, 2, 3, 4], + [9.899495, 11.313708, 12.727922, 14.142136], + [15.556349, 16.970562, 18.384777, 19.79899]], + dtype=np.float32) + assert np.allclose(out_ms.asnumpy(), expected, loss, loss) + + +def sparse_segment_sqrt_n_pynative(loss): + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + x_np = np.array( + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], + dtype=np.float64) + indices_np = np.array([0, 1, 1, 2, 3, 3], dtype=np.int64) + segment_ids_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int64) + x_ms = Tensor(x_np) + indices_ms = Tensor(indices_np) + segment_ids_ms = Tensor(segment_ids_np) + net_ms = SparseSegmentSqrtNNet() + out_ms = net_ms(x_ms, indices_ms, segment_ids_ms) + expected = np.array( + [[4.24264069, 5.65685425, 7.07106781, 8.48528137], [5, 6, 7, 8], + [15.55634919, 16.97056275, 18.38477631, 19.79898987], + [13, 14, 15, 16]], + dtype=np.float64) + assert np.allclose(out_ms.asnumpy(), expected, loss, loss) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_sparse_segment_sqrt_n_graph_float32_int32_int32(): + """ + Feature: ALL To ALL + Description: test cases for SparseSegmentSqrtN + Expectation: the result match to tensorflow + """ + sparse_segment_sqrt_n(loss=1.0e-4) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_sparse_segment_sqrt_n_pynative_float64_int64_int64(): + """ + Feature: ALL To ALL + Description: test cases for SparseSegmentSqrtN + Expectation: the result match to tensorflow + """ + sparse_segment_sqrt_n_pynative(loss=1.0e-5) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_sparse_segment_sqrt_n_dyn(): + """ + Feature: test SparseSegmentSqrtN ops in gpu. + Description: test the ops in dynamic shape. + Expectation: expect correct shape result. + """ + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + net = SparseSegmentSqrtNNet() + + x_dyn = Tensor(shape=[None, None], dtype=ms.float32) + indices_dyn = Tensor(shape=[None], dtype=ms.int32) + segment_ids_dyn = Tensor(shape=[None], dtype=ms.int32) + net.set_inputs(x_dyn, indices_dyn, segment_ids_dyn) + + x = Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=ms.float32) + indices = Tensor([0, 1, 2], dtype=ms.int32) + segment_ids = Tensor([0, 1, 2], dtype=ms.int32) + output = net(x, indices, segment_ids) + + expect_shape = (3, 4) + assert output.asnumpy().shape == expect_shape diff --git a/tests/st/ops/gpu/test_sparse_segment_sqrt_n_with_num_segments_op.py b/tests/st/ops/gpu/test_sparse_segment_sqrt_n_with_num_segments_op.py index aef3bade666..6894683d693 100644 --- a/tests/st/ops/gpu/test_sparse_segment_sqrt_n_with_num_segments_op.py +++ b/tests/st/ops/gpu/test_sparse_segment_sqrt_n_with_num_segments_op.py @@ -1,125 +1,125 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark -import numpy as np -import pytest -import mindspore as ms -import mindspore.context as context -import mindspore.nn as nn -import mindspore.ops.operations.sparse_ops as P -from mindspore import Tensor -from mindspore.common.api import jit - - -class SparseSegmentSqrtNWithNumSegmentsNet(nn.Cell): - - def __init__(self): - super(SparseSegmentSqrtNWithNumSegmentsNet, self).__init__() - self.net = P.SparseSegmentSqrtNWithNumSegments() - - @jit - def construct(self, x, indices, segment_ids, num_segments): - return self.net(x, indices, segment_ids, num_segments) - - -def sparse_segment_sqrt_n_with_num_segments(loss): - context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - x_np = np.array( - [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], - dtype=np.float32) - indices_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int32) - segment_ids_np = np.array([0, 3, 3, 5, 7, 7], dtype=np.int32) - num_segments_np = np.array(8, dtype=np.int32) - x_ms = Tensor(x_np) - indices_ms = Tensor(indices_np) - segment_ids_ms = Tensor(segment_ids_np) - num_segments_ms = Tensor(num_segments_np) - net_ms = SparseSegmentSqrtNWithNumSegmentsNet() - out_ms = net_ms(x_ms, indices_ms, segment_ids_ms, num_segments_ms) - expected = np.array([[1, 2, 3, 4], [0, 0, 0, 0], [0, 0, 0, 0], - [4.2426405, 5.656854, 7.071068, 8.485281], - [0, 0, 0, 0], [9, 10, 11, 12], [0, 0, 0, 0], - [15.556349, 16.970562, 18.384777, 19.79899]], - dtype=np.float32) - assert np.allclose(out_ms.asnumpy(), expected, loss, loss) - - -def sparse_segment_sqrt_n_with_num_segments_pynative(loss): - context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') - x_np = np.array( - [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], - dtype=np.float64) - indices_np = np.array([0, 1, 1, 2, 3, 3], dtype=np.int64) - segment_ids_np = np.array([0, 0, 3, 5, 5, 7], dtype=np.int64) - num_segments_np = np.array(8, dtype=np.int64) - x_ms = Tensor(x_np) - indices_ms = Tensor(indices_np) - segment_ids_ms = Tensor(segment_ids_np) - num_segments_ms = Tensor(num_segments_np) - net_ms = SparseSegmentSqrtNWithNumSegmentsNet() - out_ms = net_ms(x_ms, indices_ms, segment_ids_ms, num_segments_ms) - expected = np.array( - [[4.24264069, 5.65685425, 7.07106781, 8.48528137], [0, 0, 0, 0], - [0, 0, 0, 0], [5, 6, 7, 8], [0, 0, 0, 0], - [15.55634919, 16.97056275, 18.38477631, 19.79898987], [0, 0, 0, 0], - [13, 14, 15, 16]], - dtype=np.float64) - assert np.allclose(out_ms.asnumpy(), expected, loss, loss) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_sparse_segment_sqrt_n_with_num_segments_graph_float32_int32_int32_int32(): - """ - Feature: ALL To ALL - Description: test cases for SparseSegmentSqrtNWithNumSegments - Expectation: the result match to tensorflow - """ - sparse_segment_sqrt_n_with_num_segments(loss=1.0e-4) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_sparse_segment_sqrt_n_with_num_segments_pynative_float64_int64_int64_int64(): - """ - Feature: ALL To ALL - Description: test cases for SparseSegmentSqrtNWithNumSegments - Expectation: the result match to tensorflow - """ - sparse_segment_sqrt_n_with_num_segments_pynative(loss=1.0e-5) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_sparse_segment_sqrt_n_with_num_segments_dyn(): - """ - Feature: test SparseSegmentSqrtNWithNumSegments ops in gpu. - Description: test the ops in dynamic shape. - Expectation: expect correct shape result. - """ - context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - net = SparseSegmentSqrtNWithNumSegmentsNet() - - x_dyn = Tensor(shape=[None, None], dtype=ms.float32) - indices_dyn = Tensor(shape=[None], dtype=ms.int32) - segment_ids_dyn = Tensor(shape=[None], dtype=ms.int32) - num_segments_dyn = Tensor(shape=[None], dtype=ms.int32) - net.set_inputs(x_dyn, indices_dyn, segment_ids_dyn, num_segments_dyn) - - x = Tensor([[0, 1, 0, 0], [0, 1, 1, 0], [1, 0, 1, 0]], dtype=ms.float32) - indices = Tensor([0, 2, 1], dtype=ms.int32) - segment_ids = Tensor([0, 1, 2], dtype=ms.int32) - num_segments = Tensor([4], dtype=ms.int32) - output = net(x, indices, segment_ids, num_segments) - - expect_shape = (4, 4) - assert output.asnumpy().shape == expect_shape +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark +import numpy as np +import pytest +import mindspore as ms +import mindspore.context as context +import mindspore.nn as nn +import mindspore.ops.operations.sparse_ops as P +from mindspore import Tensor +from mindspore.common.api import jit + + +class SparseSegmentSqrtNWithNumSegmentsNet(nn.Cell): + + def __init__(self): + super(SparseSegmentSqrtNWithNumSegmentsNet, self).__init__() + self.net = P.SparseSegmentSqrtNWithNumSegments() + + @jit + def construct(self, x, indices, segment_ids, num_segments): + return self.net(x, indices, segment_ids, num_segments) + + +def sparse_segment_sqrt_n_with_num_segments(loss): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + x_np = np.array( + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], + dtype=np.float32) + indices_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int32) + segment_ids_np = np.array([0, 3, 3, 5, 7, 7], dtype=np.int32) + num_segments_np = np.array(8, dtype=np.int32) + x_ms = Tensor(x_np) + indices_ms = Tensor(indices_np) + segment_ids_ms = Tensor(segment_ids_np) + num_segments_ms = Tensor(num_segments_np) + net_ms = SparseSegmentSqrtNWithNumSegmentsNet() + out_ms = net_ms(x_ms, indices_ms, segment_ids_ms, num_segments_ms) + expected = np.array([[1, 2, 3, 4], [0, 0, 0, 0], [0, 0, 0, 0], + [4.2426405, 5.656854, 7.071068, 8.485281], + [0, 0, 0, 0], [9, 10, 11, 12], [0, 0, 0, 0], + [15.556349, 16.970562, 18.384777, 19.79899]], + dtype=np.float32) + assert np.allclose(out_ms.asnumpy(), expected, loss, loss) + + +def sparse_segment_sqrt_n_with_num_segments_pynative(loss): + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + x_np = np.array( + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], + dtype=np.float64) + indices_np = np.array([0, 1, 1, 2, 3, 3], dtype=np.int64) + segment_ids_np = np.array([0, 0, 3, 5, 5, 7], dtype=np.int64) + num_segments_np = np.array(8, dtype=np.int64) + x_ms = Tensor(x_np) + indices_ms = Tensor(indices_np) + segment_ids_ms = Tensor(segment_ids_np) + num_segments_ms = Tensor(num_segments_np) + net_ms = SparseSegmentSqrtNWithNumSegmentsNet() + out_ms = net_ms(x_ms, indices_ms, segment_ids_ms, num_segments_ms) + expected = np.array( + [[4.24264069, 5.65685425, 7.07106781, 8.48528137], [0, 0, 0, 0], + [0, 0, 0, 0], [5, 6, 7, 8], [0, 0, 0, 0], + [15.55634919, 16.97056275, 18.38477631, 19.79898987], [0, 0, 0, 0], + [13, 14, 15, 16]], + dtype=np.float64) + assert np.allclose(out_ms.asnumpy(), expected, loss, loss) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_sparse_segment_sqrt_n_with_num_segments_graph_float32_int32_int32_int32(): + """ + Feature: ALL To ALL + Description: test cases for SparseSegmentSqrtNWithNumSegments + Expectation: the result match to tensorflow + """ + sparse_segment_sqrt_n_with_num_segments(loss=1.0e-4) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_sparse_segment_sqrt_n_with_num_segments_pynative_float64_int64_int64_int64(): + """ + Feature: ALL To ALL + Description: test cases for SparseSegmentSqrtNWithNumSegments + Expectation: the result match to tensorflow + """ + sparse_segment_sqrt_n_with_num_segments_pynative(loss=1.0e-5) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_sparse_segment_sqrt_n_with_num_segments_dyn(): + """ + Feature: test SparseSegmentSqrtNWithNumSegments ops in gpu. + Description: test the ops in dynamic shape. + Expectation: expect correct shape result. + """ + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + net = SparseSegmentSqrtNWithNumSegmentsNet() + + x_dyn = Tensor(shape=[None, None], dtype=ms.float32) + indices_dyn = Tensor(shape=[None], dtype=ms.int32) + segment_ids_dyn = Tensor(shape=[None], dtype=ms.int32) + num_segments_dyn = Tensor(shape=[None], dtype=ms.int32) + net.set_inputs(x_dyn, indices_dyn, segment_ids_dyn, num_segments_dyn) + + x = Tensor([[0, 1, 0, 0], [0, 1, 1, 0], [1, 0, 1, 0]], dtype=ms.float32) + indices = Tensor([0, 2, 1], dtype=ms.int32) + segment_ids = Tensor([0, 1, 2], dtype=ms.int32) + num_segments = Tensor([4], dtype=ms.int32) + output = net(x, indices, segment_ids, num_segments) + + expect_shape = (4, 4) + assert output.asnumpy().shape == expect_shape diff --git a/tests/st/ops/gpu/test_sparse_segment_sum_grad_op.py b/tests/st/ops/gpu/test_sparse_segment_sum_grad_op.py index 2b776698a34..c5e69d70201 100644 --- a/tests/st/ops/gpu/test_sparse_segment_sum_grad_op.py +++ b/tests/st/ops/gpu/test_sparse_segment_sum_grad_op.py @@ -1,98 +1,98 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark -import numpy as np -import pytest -import mindspore.context as context -import mindspore.nn as nn -import mindspore.ops.operations._grad_ops as P -from mindspore import Tensor -from mindspore.common.api import jit - - -class SparseSegmentSumGradNet(nn.Cell): - def __init__(self): - super(SparseSegmentSumGradNet, self).__init__() - self.net = P.SparseSegmentSumGrad() - - @jit - def construct(self, grad, indices, segment_ids, output_dim0): - return self.net(grad, indices, segment_ids, output_dim0) - - -def sparse_segment_sum_grad(loss): - context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - grad_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float32) - indices_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int32) - segment_ids_np = np.array([0, 1, 2, 2, 3, 3], dtype=np.int32) - output_dim0_np = np.array(8, dtype=np.int32) - grad_ms = Tensor(grad_np) - indices_ms = Tensor(indices_np) - segment_ids_ms = Tensor(segment_ids_np) - output_dim0_ms = Tensor(output_dim0_np) - net_ms = SparseSegmentSumGradNet() - out_ms = net_ms(grad_ms, indices_ms, segment_ids_ms, output_dim0_ms) - expected = np.array([[6, 8, 10, 12], - [9, 10, 11, 12], - [22, 24, 26, 28], - [13, 14, 15, 16], - [0, 0, 0, 0], - [0, 0, 0, 0], - [0, 0, 0, 0], - [0, 0, 0, 0]], dtype=np.float32) - assert np.allclose(out_ms.asnumpy(), expected, loss, loss) - - -def sparse_segment_sum_grad_pynative(loss): - context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') - grad_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float64) - indices_np = np.array([0, 1, 1, 2, 3, 3], dtype=np.int64) - segment_ids_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int64) - output_dim0_np = np.array(8, dtype=np.int64) - grad_ms = Tensor(grad_np) - indices_ms = Tensor(indices_np) - segment_ids_ms = Tensor(segment_ids_np) - output_dim0_ms = Tensor(output_dim0_np) - net_ms = SparseSegmentSumGradNet() - out_ms = net_ms(grad_ms, indices_ms, segment_ids_ms, output_dim0_ms) - expected = np.array([[1, 2, 3, 4.], - [6, 8, 10, 12], - [9, 10, 11, 12], - [22, 24, 26, 28], - [0, 0, 0, 0], - [0, 0, 0, 0], - [0, 0, 0, 0], - [0, 0, 0, 0]], dtype=np.float64) - assert np.allclose(out_ms.asnumpy(), expected, loss, loss) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_sparse_segment_sum_grad_graph_float32_int32_int32(): - """ - Feature: ALL To ALL - Description: test cases for SparseSegmentSumGrad - Expectation: the result match to tensorflow - """ - sparse_segment_sum_grad(loss=1.0e-4) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_sparse_segment_sum_grad_pynative_float64_int64_int64(): - """ - Feature: ALL To ALL - Description: test cases for SparseSegmentSumGrad - Expectation: the result match to tensorflow - """ - sparse_segment_sum_grad_pynative(loss=1.0e-5) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark +import numpy as np +import pytest +import mindspore.context as context +import mindspore.nn as nn +import mindspore.ops.operations._grad_ops as P +from mindspore import Tensor +from mindspore.common.api import jit + + +class SparseSegmentSumGradNet(nn.Cell): + def __init__(self): + super(SparseSegmentSumGradNet, self).__init__() + self.net = P.SparseSegmentSumGrad() + + @jit + def construct(self, grad, indices, segment_ids, output_dim0): + return self.net(grad, indices, segment_ids, output_dim0) + + +def sparse_segment_sum_grad(loss): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + grad_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float32) + indices_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int32) + segment_ids_np = np.array([0, 1, 2, 2, 3, 3], dtype=np.int32) + output_dim0_np = np.array(8, dtype=np.int32) + grad_ms = Tensor(grad_np) + indices_ms = Tensor(indices_np) + segment_ids_ms = Tensor(segment_ids_np) + output_dim0_ms = Tensor(output_dim0_np) + net_ms = SparseSegmentSumGradNet() + out_ms = net_ms(grad_ms, indices_ms, segment_ids_ms, output_dim0_ms) + expected = np.array([[6, 8, 10, 12], + [9, 10, 11, 12], + [22, 24, 26, 28], + [13, 14, 15, 16], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0]], dtype=np.float32) + assert np.allclose(out_ms.asnumpy(), expected, loss, loss) + + +def sparse_segment_sum_grad_pynative(loss): + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + grad_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float64) + indices_np = np.array([0, 1, 1, 2, 3, 3], dtype=np.int64) + segment_ids_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int64) + output_dim0_np = np.array(8, dtype=np.int64) + grad_ms = Tensor(grad_np) + indices_ms = Tensor(indices_np) + segment_ids_ms = Tensor(segment_ids_np) + output_dim0_ms = Tensor(output_dim0_np) + net_ms = SparseSegmentSumGradNet() + out_ms = net_ms(grad_ms, indices_ms, segment_ids_ms, output_dim0_ms) + expected = np.array([[1, 2, 3, 4.], + [6, 8, 10, 12], + [9, 10, 11, 12], + [22, 24, 26, 28], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0]], dtype=np.float64) + assert np.allclose(out_ms.asnumpy(), expected, loss, loss) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_sparse_segment_sum_grad_graph_float32_int32_int32(): + """ + Feature: ALL To ALL + Description: test cases for SparseSegmentSumGrad + Expectation: the result match to tensorflow + """ + sparse_segment_sum_grad(loss=1.0e-4) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_sparse_segment_sum_grad_pynative_float64_int64_int64(): + """ + Feature: ALL To ALL + Description: test cases for SparseSegmentSumGrad + Expectation: the result match to tensorflow + """ + sparse_segment_sum_grad_pynative(loss=1.0e-5) diff --git a/tests/st/ops/gpu/test_sparse_segment_sum_op.py b/tests/st/ops/gpu/test_sparse_segment_sum_op.py index 4327811a765..29ffbddca0b 100644 --- a/tests/st/ops/gpu/test_sparse_segment_sum_op.py +++ b/tests/st/ops/gpu/test_sparse_segment_sum_op.py @@ -1,124 +1,124 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark -import numpy as np -import pytest -import mindspore.context as context -import mindspore.nn as nn -import mindspore as ms -import mindspore.ops.operations.sparse_ops as P -from mindspore import Tensor -from mindspore.common.api import jit - - -class SparseSegmentSumNet(nn.Cell): - - def __init__(self): - super(SparseSegmentSumNet, self).__init__() - self.net = P.SparseSegmentSum() - - @jit - def construct(self, x, indices, segment_ids): - return self.net(x, indices, segment_ids) - - -def dyn_case(): - net = SparseSegmentSumNet() - - x_dyn = Tensor(shape=[None, 4], dtype=ms.float32) - indices_dyn = Tensor(shape=[None], dtype=ms.int32) - segment_ids_dyn = Tensor(shape=[None], dtype=ms.int32) - net.set_inputs(x_dyn, indices_dyn, segment_ids_dyn) - - x_np = np.array( - [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], - dtype=np.float32) - indices_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int32) - segment_ids_np = np.array([0, 1, 2, 2, 3, 3], dtype=np.int32) - x_ms = Tensor(x_np) - indices_ms = Tensor(indices_np) - segment_ids_ms = Tensor(segment_ids_np) - out_ms = net(x_ms, indices_ms, segment_ids_ms) - - assert out_ms.asnumpy().shape == (4, 4) - - -def sparse_segment_sum(loss): - context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - x_np = np.array( - [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], - dtype=np.float32) - indices_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int32) - segment_ids_np = np.array([0, 1, 2, 2, 3, 3], dtype=np.int32) - x_ms = Tensor(x_np) - indices_ms = Tensor(indices_np) - segment_ids_ms = Tensor(segment_ids_np) - net_ms = SparseSegmentSumNet() - out_ms = net_ms(x_ms, indices_ms, segment_ids_ms) - expected = np.array( - [[1, 2, 3, 4], [1, 2, 3, 4], [14, 16, 18, 20], [22, 24, 26, 28]], - dtype=np.float32) - assert np.allclose(out_ms.asnumpy(), expected, loss, loss) - - -def sparse_segment_sum_pynative(loss): - context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') - x_np = np.array( - [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], - dtype=np.float64) - indices_np = np.array([0, 1, 1, 2, 3, 3], dtype=np.int64) - segment_ids_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int64) - x_ms = Tensor(x_np) - indices_ms = Tensor(indices_np) - segment_ids_ms = Tensor(segment_ids_np) - net_ms = SparseSegmentSumNet() - out_ms = net_ms(x_ms, indices_ms, segment_ids_ms) - expected = np.array( - [[6, 8, 10, 12], [5, 6, 7, 8], [22, 24, 26, 28], [13, 14, 15, 16]], - dtype=np.float64) - assert np.allclose(out_ms.asnumpy(), expected, loss, loss) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_sparse_segment_sum_graph_float32_int32_int32(): - """ - Feature: ALL To ALL - Description: test cases for SparseSegmentSum - Expectation: the result match to tensorflow - """ - sparse_segment_sum(loss=1.0e-4) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_sparse_segment_sum_pynative_float64_int64_int64(): - """ - Feature: ALL To ALL - Description: test cases for SparseSegmentSum - Expectation: the result match to tensorflow - """ - sparse_segment_sum_pynative(loss=1.0e-5) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_sparse_segment_sum_dyn(): - """ - Feature: test SparseSegmentSum in gpu. - Description: test the ops in dynamic case. - Expectation: success. - """ - context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - dyn_case() - context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') - dyn_case() +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark +import numpy as np +import pytest +import mindspore.context as context +import mindspore.nn as nn +import mindspore as ms +import mindspore.ops.operations.sparse_ops as P +from mindspore import Tensor +from mindspore.common.api import jit + + +class SparseSegmentSumNet(nn.Cell): + + def __init__(self): + super(SparseSegmentSumNet, self).__init__() + self.net = P.SparseSegmentSum() + + @jit + def construct(self, x, indices, segment_ids): + return self.net(x, indices, segment_ids) + + +def dyn_case(): + net = SparseSegmentSumNet() + + x_dyn = Tensor(shape=[None, 4], dtype=ms.float32) + indices_dyn = Tensor(shape=[None], dtype=ms.int32) + segment_ids_dyn = Tensor(shape=[None], dtype=ms.int32) + net.set_inputs(x_dyn, indices_dyn, segment_ids_dyn) + + x_np = np.array( + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], + dtype=np.float32) + indices_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int32) + segment_ids_np = np.array([0, 1, 2, 2, 3, 3], dtype=np.int32) + x_ms = Tensor(x_np) + indices_ms = Tensor(indices_np) + segment_ids_ms = Tensor(segment_ids_np) + out_ms = net(x_ms, indices_ms, segment_ids_ms) + + assert out_ms.asnumpy().shape == (4, 4) + + +def sparse_segment_sum(loss): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + x_np = np.array( + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], + dtype=np.float32) + indices_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int32) + segment_ids_np = np.array([0, 1, 2, 2, 3, 3], dtype=np.int32) + x_ms = Tensor(x_np) + indices_ms = Tensor(indices_np) + segment_ids_ms = Tensor(segment_ids_np) + net_ms = SparseSegmentSumNet() + out_ms = net_ms(x_ms, indices_ms, segment_ids_ms) + expected = np.array( + [[1, 2, 3, 4], [1, 2, 3, 4], [14, 16, 18, 20], [22, 24, 26, 28]], + dtype=np.float32) + assert np.allclose(out_ms.asnumpy(), expected, loss, loss) + + +def sparse_segment_sum_pynative(loss): + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + x_np = np.array( + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], + dtype=np.float64) + indices_np = np.array([0, 1, 1, 2, 3, 3], dtype=np.int64) + segment_ids_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int64) + x_ms = Tensor(x_np) + indices_ms = Tensor(indices_np) + segment_ids_ms = Tensor(segment_ids_np) + net_ms = SparseSegmentSumNet() + out_ms = net_ms(x_ms, indices_ms, segment_ids_ms) + expected = np.array( + [[6, 8, 10, 12], [5, 6, 7, 8], [22, 24, 26, 28], [13, 14, 15, 16]], + dtype=np.float64) + assert np.allclose(out_ms.asnumpy(), expected, loss, loss) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_sparse_segment_sum_graph_float32_int32_int32(): + """ + Feature: ALL To ALL + Description: test cases for SparseSegmentSum + Expectation: the result match to tensorflow + """ + sparse_segment_sum(loss=1.0e-4) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_sparse_segment_sum_pynative_float64_int64_int64(): + """ + Feature: ALL To ALL + Description: test cases for SparseSegmentSum + Expectation: the result match to tensorflow + """ + sparse_segment_sum_pynative(loss=1.0e-5) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_sparse_segment_sum_dyn(): + """ + Feature: test SparseSegmentSum in gpu. + Description: test the ops in dynamic case. + Expectation: success. + """ + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + dyn_case() + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + dyn_case() diff --git a/tests/st/ops/gpu/test_sparse_segment_sum_with_num_segments_op.py b/tests/st/ops/gpu/test_sparse_segment_sum_with_num_segments_op.py index 96cdc6d5111..8f50a804850 100644 --- a/tests/st/ops/gpu/test_sparse_segment_sum_with_num_segments_op.py +++ b/tests/st/ops/gpu/test_sparse_segment_sum_with_num_segments_op.py @@ -1,98 +1,98 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark -import numpy as np -import pytest -import mindspore.context as context -import mindspore.nn as nn -import mindspore.ops.operations.sparse_ops as P -from mindspore import Tensor -from mindspore.common.api import jit - - -class SparseSegmentSumWithNumSegmentsNet(nn.Cell): - def __init__(self): - super(SparseSegmentSumWithNumSegmentsNet, self).__init__() - self.net = P.SparseSegmentSumWithNumSegments() - - @jit - def construct(self, x, indices, segment_ids, num_segments): - return self.net(x, indices, segment_ids, num_segments) - - -def sparse_segment_sum_with_num_segments(loss): - context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - x_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float32) - indices_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int32) - segment_ids_np = np.array([0, 3, 3, 5, 7, 7], dtype=np.int32) - num_segments_np = np.array(8, dtype=np.int32) - x_ms = Tensor(x_np) - indices_ms = Tensor(indices_np) - segment_ids_ms = Tensor(segment_ids_np) - num_segments_ms = Tensor(num_segments_np) - net_ms = SparseSegmentSumWithNumSegmentsNet() - out_ms = net_ms(x_ms, indices_ms, segment_ids_ms, num_segments_ms) - expected = np.array([[1, 2, 3, 4], - [0, 0, 0, 0], - [0, 0, 0, 0], - [6, 8, 10, 12], - [0, 0, 0, 0], - [9, 10, 11, 12], - [0, 0, 0, 0], - [22, 24, 26, 28]], dtype=np.float32) - assert np.allclose(out_ms.asnumpy(), expected, loss, loss) - - -def sparse_segment_sum_with_num_segments_pynative(loss): - context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') - x_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float64) - indices_np = np.array([0, 1, 1, 2, 3, 3], dtype=np.int64) - segment_ids_np = np.array([0, 0, 3, 5, 5, 7], dtype=np.int64) - num_segments_np = np.array(8, dtype=np.int64) - x_ms = Tensor(x_np) - indices_ms = Tensor(indices_np) - segment_ids_ms = Tensor(segment_ids_np) - num_segments_ms = Tensor(num_segments_np) - net_ms = SparseSegmentSumWithNumSegmentsNet() - out_ms = net_ms(x_ms, indices_ms, segment_ids_ms, num_segments_ms) - expected = np.array([[6, 8, 10, 12], - [0, 0, 0, 0], - [0, 0, 0, 0], - [5, 6, 7, 8], - [0, 0, 0, 0], - [22, 24, 26, 28], - [0, 0, 0, 0], - [13, 14, 15, 16]], dtype=np.float64) - assert np.allclose(out_ms.asnumpy(), expected, loss, loss) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_sparse_segment_sum_with_num_segments_graph_float32_int32_int32_int32(): - """ - Feature: ALL To ALL - Description: test cases for SparseSegmentSumWithNumSegments - Expectation: the result match to tensorflow - """ - sparse_segment_sum_with_num_segments(loss=1.0e-4) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_sparse_segment_sum_with_num_segments_pynative_float64_int64_int64_int64(): - """ - Feature: ALL To ALL - Description: test cases for SparseSegmentSumWithNumSegments - Expectation: the result match to tensorflow - """ - sparse_segment_sum_with_num_segments_pynative(loss=1.0e-5) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark +import numpy as np +import pytest +import mindspore.context as context +import mindspore.nn as nn +import mindspore.ops.operations.sparse_ops as P +from mindspore import Tensor +from mindspore.common.api import jit + + +class SparseSegmentSumWithNumSegmentsNet(nn.Cell): + def __init__(self): + super(SparseSegmentSumWithNumSegmentsNet, self).__init__() + self.net = P.SparseSegmentSumWithNumSegments() + + @jit + def construct(self, x, indices, segment_ids, num_segments): + return self.net(x, indices, segment_ids, num_segments) + + +def sparse_segment_sum_with_num_segments(loss): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + x_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float32) + indices_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int32) + segment_ids_np = np.array([0, 3, 3, 5, 7, 7], dtype=np.int32) + num_segments_np = np.array(8, dtype=np.int32) + x_ms = Tensor(x_np) + indices_ms = Tensor(indices_np) + segment_ids_ms = Tensor(segment_ids_np) + num_segments_ms = Tensor(num_segments_np) + net_ms = SparseSegmentSumWithNumSegmentsNet() + out_ms = net_ms(x_ms, indices_ms, segment_ids_ms, num_segments_ms) + expected = np.array([[1, 2, 3, 4], + [0, 0, 0, 0], + [0, 0, 0, 0], + [6, 8, 10, 12], + [0, 0, 0, 0], + [9, 10, 11, 12], + [0, 0, 0, 0], + [22, 24, 26, 28]], dtype=np.float32) + assert np.allclose(out_ms.asnumpy(), expected, loss, loss) + + +def sparse_segment_sum_with_num_segments_pynative(loss): + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + x_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float64) + indices_np = np.array([0, 1, 1, 2, 3, 3], dtype=np.int64) + segment_ids_np = np.array([0, 0, 3, 5, 5, 7], dtype=np.int64) + num_segments_np = np.array(8, dtype=np.int64) + x_ms = Tensor(x_np) + indices_ms = Tensor(indices_np) + segment_ids_ms = Tensor(segment_ids_np) + num_segments_ms = Tensor(num_segments_np) + net_ms = SparseSegmentSumWithNumSegmentsNet() + out_ms = net_ms(x_ms, indices_ms, segment_ids_ms, num_segments_ms) + expected = np.array([[6, 8, 10, 12], + [0, 0, 0, 0], + [0, 0, 0, 0], + [5, 6, 7, 8], + [0, 0, 0, 0], + [22, 24, 26, 28], + [0, 0, 0, 0], + [13, 14, 15, 16]], dtype=np.float64) + assert np.allclose(out_ms.asnumpy(), expected, loss, loss) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_sparse_segment_sum_with_num_segments_graph_float32_int32_int32_int32(): + """ + Feature: ALL To ALL + Description: test cases for SparseSegmentSumWithNumSegments + Expectation: the result match to tensorflow + """ + sparse_segment_sum_with_num_segments(loss=1.0e-4) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_sparse_segment_sum_with_num_segments_pynative_float64_int64_int64_int64(): + """ + Feature: ALL To ALL + Description: test cases for SparseSegmentSumWithNumSegments + Expectation: the result match to tensorflow + """ + sparse_segment_sum_with_num_segments_pynative(loss=1.0e-5) diff --git a/tests/st/ops/gpu/test_sparse_sparse_maximum.py b/tests/st/ops/gpu/test_sparse_sparse_maximum.py index 132f7e7c1c2..1d257c76dc9 100644 --- a/tests/st/ops/gpu/test_sparse_sparse_maximum.py +++ b/tests/st/ops/gpu/test_sparse_sparse_maximum.py @@ -1,86 +1,86 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark -import pytest -import numpy as np -import mindspore -from mindspore import Tensor -import mindspore.context as context -import mindspore.nn as nn -import mindspore.ops.operations.sparse_ops as P - - -class SparseSparseMaximumNet(nn.Cell): - def __init__(self): - super(SparseSparseMaximumNet, self).__init__() - self.sparse_sparse_maximum = P.SparseSparseMaximum() - - def construct(self, x1_indices, x1_values, x1_shape, x2_indices, x2_values, x2_shape): - return self.sparse_sparse_maximum(x1_indices, x1_values, x1_shape, x2_indices, x2_values, x2_shape) - - -def sparse_sparse_maximum(loss): - loss1 = loss - context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - indices1 = Tensor([[0, 1], [0, 1], [2, 2], [0, 2]]) - values1 = Tensor([4, 2, 3, 4], dtype=mindspore.float32) - shape1 = Tensor([3, 4]) - indices2 = Tensor([[0, 1], [2, 3]]) - values2 = Tensor([2, 3], dtype=mindspore.float32) - shape2 = Tensor([3, 4]) - net = SparseSparseMaximumNet() - m, n = net(indices1, values1, shape1, indices2, values2, shape2) - expected_m = np.array([[0, 1], [0, 1], [2, 2], [0, 2], [2, 3]], dtype=np.int64) - expected_n = np.array([4, 2, 3, 4, 3], dtype=np.float32) - assert np.allclose(m.asnumpy(), expected_m, loss, loss) - assert np.allclose(n.asnumpy(), expected_n, loss1, loss1) - - -def sparse_sparse_maximum_pynative(loss): - loss1 = loss - context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') - indices1 = Tensor([[0, 1], [0, 1], [2, 2], [0, 2]]) - values1 = Tensor([4, 2, 3, 4], dtype=mindspore.float32) - shape1 = Tensor([3, 4]) - indices2 = Tensor([[0, 1], [2, 3]]) - values2 = Tensor([2, 3], dtype=mindspore.float32) - shape2 = Tensor([3, 4]) - net = SparseSparseMaximumNet() - m, n = net(indices1, values1, shape1, indices2, values2, shape2) - print(m) - expected_m = np.array([[0, 1], [0, 1], [2, 2], [0, 2], [2, 3]], dtype=np.int64) - expected_n = np.array([4, 2, 3, 4, 3], dtype=np.float32) - assert np.allclose(m.asnumpy(), expected_m, loss, loss) - assert np.allclose(n.asnumpy(), expected_n, loss1, loss1) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_sparse_sparse_maximum_graph_float32(): - """ - Feature: ALL To ALL - Description: test cases for SparseSparseMaximum - Expectation: the result match to tensorflow - """ - sparse_sparse_maximum(loss=1.0e-4) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_sparse_sparse_maximum_pynative_float32(): - """ - Feature: ALL To ALL - Description: test cases for SparseSparseMaximum - Expectation: the result match to tensorflow - """ - sparse_sparse_maximum_pynative(loss=1.0e-5) +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark +import pytest +import numpy as np +import mindspore +from mindspore import Tensor +import mindspore.context as context +import mindspore.nn as nn +import mindspore.ops.operations.sparse_ops as P + + +class SparseSparseMaximumNet(nn.Cell): + def __init__(self): + super(SparseSparseMaximumNet, self).__init__() + self.sparse_sparse_maximum = P.SparseSparseMaximum() + + def construct(self, x1_indices, x1_values, x1_shape, x2_indices, x2_values, x2_shape): + return self.sparse_sparse_maximum(x1_indices, x1_values, x1_shape, x2_indices, x2_values, x2_shape) + + +def sparse_sparse_maximum(loss): + loss1 = loss + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + indices1 = Tensor([[0, 1], [0, 1], [2, 2], [0, 2]]) + values1 = Tensor([4, 2, 3, 4], dtype=mindspore.float32) + shape1 = Tensor([3, 4]) + indices2 = Tensor([[0, 1], [2, 3]]) + values2 = Tensor([2, 3], dtype=mindspore.float32) + shape2 = Tensor([3, 4]) + net = SparseSparseMaximumNet() + m, n = net(indices1, values1, shape1, indices2, values2, shape2) + expected_m = np.array([[0, 1], [0, 1], [2, 2], [0, 2], [2, 3]], dtype=np.int64) + expected_n = np.array([4, 2, 3, 4, 3], dtype=np.float32) + assert np.allclose(m.asnumpy(), expected_m, loss, loss) + assert np.allclose(n.asnumpy(), expected_n, loss1, loss1) + + +def sparse_sparse_maximum_pynative(loss): + loss1 = loss + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + indices1 = Tensor([[0, 1], [0, 1], [2, 2], [0, 2]]) + values1 = Tensor([4, 2, 3, 4], dtype=mindspore.float32) + shape1 = Tensor([3, 4]) + indices2 = Tensor([[0, 1], [2, 3]]) + values2 = Tensor([2, 3], dtype=mindspore.float32) + shape2 = Tensor([3, 4]) + net = SparseSparseMaximumNet() + m, n = net(indices1, values1, shape1, indices2, values2, shape2) + print(m) + expected_m = np.array([[0, 1], [0, 1], [2, 2], [0, 2], [2, 3]], dtype=np.int64) + expected_n = np.array([4, 2, 3, 4, 3], dtype=np.float32) + assert np.allclose(m.asnumpy(), expected_m, loss, loss) + assert np.allclose(n.asnumpy(), expected_n, loss1, loss1) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_sparse_sparse_maximum_graph_float32(): + """ + Feature: ALL To ALL + Description: test cases for SparseSparseMaximum + Expectation: the result match to tensorflow + """ + sparse_sparse_maximum(loss=1.0e-4) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_sparse_sparse_maximum_pynative_float32(): + """ + Feature: ALL To ALL + Description: test cases for SparseSparseMaximum + Expectation: the result match to tensorflow + """ + sparse_sparse_maximum_pynative(loss=1.0e-5) diff --git a/tests/st/ops/gpu/test_sparse_sparse_minimum.py b/tests/st/ops/gpu/test_sparse_sparse_minimum.py index f3095dc8e29..8d1a5264055 100644 --- a/tests/st/ops/gpu/test_sparse_sparse_minimum.py +++ b/tests/st/ops/gpu/test_sparse_sparse_minimum.py @@ -1,85 +1,85 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark -import pytest -import mindspore -from mindspore import Tensor -import mindspore.context as context -import mindspore.nn as nn -import mindspore.ops.operations.sparse_ops as P -import numpy as np - - -class SparseSparseMinimumNet(nn.Cell): - def __init__(self): - super(SparseSparseMinimumNet, self).__init__() - self.sparsesparseminimum = P.SparseSparseMinimum() - - def construct(self, x1_indices, x1_values, x1_shape, x2_indices, x2_values, x2_shape): - return self.sparsesparseminimum(x1_indices, x1_values, x1_shape, x2_indices, x2_values, x2_shape) - - -def sparse_sparse_minimum(loss): - loss1 = loss - context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - indices1 = Tensor([[0, 1], [0, 1], [2, 2], [0, 2]]) - values1 = Tensor([4, 2, 3, 4], dtype=mindspore.float32) - shape1 = Tensor([3, 4]) - indices2 = Tensor([[0, 1], [2, 3]]) - values2 = Tensor([2, 3], dtype=mindspore.float32) - shape2 = Tensor([3, 4]) - net = SparseSparseMinimumNet() - m, n = net(indices1, values1, shape1, indices2, values2, shape2) - expected_m = np.array([[0, 1], [0, 1], [2, 2], [0, 2], [2, 3]], dtype=np.int64) - expected_n = np.array([2, 0, 0, 0, 0], dtype=np.float32) - assert np.allclose(m.asnumpy(), expected_m, loss, loss) - assert np.allclose(n.asnumpy(), expected_n, loss1, loss1) - - -def sparse_sparse_minimum_pynative(loss): - loss1 = loss - context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') - indices1 = Tensor([[0, 1], [0, 1], [2, 2], [0, 2]]) - values1 = Tensor([4, 2, 3, 4], dtype=mindspore.float32) - shape1 = Tensor([3, 4]) - indices2 = Tensor([[0, 1], [2, 3]]) - values2 = Tensor([2, 3], dtype=mindspore.float32) - shape2 = Tensor([3, 4]) - net = SparseSparseMinimumNet() - m, n = net(indices1, values1, shape1, indices2, values2, shape2) - expected_m = np.array([[0, 1], [0, 1], [2, 2], [0, 2], [2, 3]], dtype=np.int64) - expected_n = np.array([2, 0, 0, 0, 0], dtype=np.float32) - assert np.allclose(m.asnumpy(), expected_m, loss, loss) - assert np.allclose(n.asnumpy(), expected_n, loss1, loss1) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_sparse_sparse_maximum_graph_float32(): - """ - Feature: ALL To ALL - Description: test cases for SparseSparseMinimum - Expectation: the result match to tensorflow - """ - sparse_sparse_minimum(loss=1.0e-4) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_sparse_sparse_minimum_pynative_float32(): - """ - Feature: ALL To ALL - Description: test cases for SparseSparseMinimum - Expectation: the result match to tensorflow - """ - sparse_sparse_minimum_pynative(loss=1.0e-5) +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark +import pytest +import mindspore +from mindspore import Tensor +import mindspore.context as context +import mindspore.nn as nn +import mindspore.ops.operations.sparse_ops as P +import numpy as np + + +class SparseSparseMinimumNet(nn.Cell): + def __init__(self): + super(SparseSparseMinimumNet, self).__init__() + self.sparsesparseminimum = P.SparseSparseMinimum() + + def construct(self, x1_indices, x1_values, x1_shape, x2_indices, x2_values, x2_shape): + return self.sparsesparseminimum(x1_indices, x1_values, x1_shape, x2_indices, x2_values, x2_shape) + + +def sparse_sparse_minimum(loss): + loss1 = loss + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + indices1 = Tensor([[0, 1], [0, 1], [2, 2], [0, 2]]) + values1 = Tensor([4, 2, 3, 4], dtype=mindspore.float32) + shape1 = Tensor([3, 4]) + indices2 = Tensor([[0, 1], [2, 3]]) + values2 = Tensor([2, 3], dtype=mindspore.float32) + shape2 = Tensor([3, 4]) + net = SparseSparseMinimumNet() + m, n = net(indices1, values1, shape1, indices2, values2, shape2) + expected_m = np.array([[0, 1], [0, 1], [2, 2], [0, 2], [2, 3]], dtype=np.int64) + expected_n = np.array([2, 0, 0, 0, 0], dtype=np.float32) + assert np.allclose(m.asnumpy(), expected_m, loss, loss) + assert np.allclose(n.asnumpy(), expected_n, loss1, loss1) + + +def sparse_sparse_minimum_pynative(loss): + loss1 = loss + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + indices1 = Tensor([[0, 1], [0, 1], [2, 2], [0, 2]]) + values1 = Tensor([4, 2, 3, 4], dtype=mindspore.float32) + shape1 = Tensor([3, 4]) + indices2 = Tensor([[0, 1], [2, 3]]) + values2 = Tensor([2, 3], dtype=mindspore.float32) + shape2 = Tensor([3, 4]) + net = SparseSparseMinimumNet() + m, n = net(indices1, values1, shape1, indices2, values2, shape2) + expected_m = np.array([[0, 1], [0, 1], [2, 2], [0, 2], [2, 3]], dtype=np.int64) + expected_n = np.array([2, 0, 0, 0, 0], dtype=np.float32) + assert np.allclose(m.asnumpy(), expected_m, loss, loss) + assert np.allclose(n.asnumpy(), expected_n, loss1, loss1) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_sparse_sparse_maximum_graph_float32(): + """ + Feature: ALL To ALL + Description: test cases for SparseSparseMinimum + Expectation: the result match to tensorflow + """ + sparse_sparse_minimum(loss=1.0e-4) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_sparse_sparse_minimum_pynative_float32(): + """ + Feature: ALL To ALL + Description: test cases for SparseSparseMinimum + Expectation: the result match to tensorflow + """ + sparse_sparse_minimum_pynative(loss=1.0e-5) diff --git a/tests/st/ops/gpu/test_sspaddmm_op.py b/tests/st/ops/gpu/test_sspaddmm_op.py index c95bf0a66ba..18f50713b8b 100644 --- a/tests/st/ops/gpu/test_sspaddmm_op.py +++ b/tests/st/ops/gpu/test_sspaddmm_op.py @@ -1,146 +1,146 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore import dtype as mstype -from mindspore.ops.operations.sparse_ops import Sspaddmm - - -class SspaddmmNet(nn.Cell): - - def __init__(self): - super(SspaddmmNet, self).__init__() - self.sspaddmm = Sspaddmm() - - def construct(self, x1_indices, x1_values, x1_shape, x2_indices, x2_values, - x2_shape, x3_dense, alpha, beta): - return self.sspaddmm(x1_indices, x1_values, x1_shape, x2_indices, - x2_values, x2_shape, x3_dense, alpha, beta) - - -@pytest.mark.skip(reason="never run on ci or smoke test") -def test_sspaddmm_dyn(): - """ - Feature: test Sspaddmm ops in gpu. - Description: test the ops in dynamic shape. - Expectation: expect correct shape result. - """ - context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - net = SspaddmmNet() - - x1_indices_dyn = Tensor(shape=[2, None], dtype=mstype.int64) - x1_values_dyn = Tensor(shape=[None], dtype=mstype.int32) - x1_shape_dyn = Tensor(shape=[None], dtype=mstype.int64) - x2_indices_dyn = Tensor(shape=[None, None], dtype=mstype.int64) - x2_values_dyn = Tensor(shape=[None], dtype=mstype.int32) - x2_shape_dyn = Tensor(shape=[None], dtype=mstype.int64) - x3_dense_dyn = Tensor(shape=[None, None], dtype=mstype.int32) - alpha = Tensor(1, dtype=mstype.int32) - beta = Tensor(1, dtype=mstype.int32) - - net.set_inputs(x1_indices_dyn, x1_values_dyn, x1_shape_dyn, x2_indices_dyn, - x2_values_dyn, x2_shape_dyn, x3_dense_dyn, alpha, beta) - - x1_indices = Tensor(np.array([[0, 1], [0, 1]]), mstype.int64) - x1_values = Tensor(np.array([1, 2]), mstype.int32) - x1_shape = Tensor(np.array([3, 3]), mstype.int64) - x2_indices = Tensor(np.array([[0, 1], [2, 2]]), mstype.int64) - x2_values = Tensor(np.array([3, 4]), mstype.int32) - x2_shape = Tensor(np.array([3, 3]), mstype.int64) - x3_dense = Tensor(np.array([[1, 2, 3], [1, 3, 2], [3, 2, 1]]), - mstype.int32) - - out = net(x1_indices, x1_values, x1_shape, x2_indices, x2_values, x2_shape, - x3_dense, alpha, beta) - expect_shapes = [(2, 8), (8,), (2,)] - for i in range(3): - assert out[i].asnumpy().shape == expect_shapes[i] - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_sspaddmm_input_int32(): - """ - Feature: Sspaddmm gpu TEST. - Description: 2d int32 test case for Sspaddmm - Expectation: The value and shape of output are the expected values. - """ - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - - x1_indices = Tensor(np.array([[0, 1], [0, 1]]), mstype.int32) - x1_values = Tensor(np.array([1, 2]), mstype.int32) - x1_shape = Tensor(np.array([3, 3]), mstype.int32) - x2_indices = Tensor(np.array([[0, 1], [2, 2]]), mstype.int32) - x2_values = Tensor(np.array([3, 4]), mstype.int32) - x2_shape = Tensor(np.array([3, 3]), mstype.int32) - x3_dense = Tensor(np.array([[1, 2, 3], [1, 3, 2], [3, 2, 1]]), - mstype.int32) - alpha = Tensor(np.array([1]), mstype.int32) - beta = Tensor(np.array([1]), mstype.int32) - net = SspaddmmNet() - y_indices, y_values, y_shape = net(x1_indices, x1_values, x1_shape, - x2_indices, x2_values, x2_shape, - x3_dense, alpha, beta) - y_indices_expect = np.array( - [[0, 1, 0, 0, 0, 1, 1, 1], [0, 1, 0, 1, 2, 0, 1, 2]], dtype=np.int64) - y_values_expect = np.array([1, 2, 9, 6, 3, 12, 8, 4], dtype=np.int32) - y_shape_expect = np.array([3, 3], dtype=np.int64) - - assert np.allclose(y_indices.asnumpy(), y_indices_expect.astype(np.int64), - 0.0001, 0.0001) - assert np.allclose(y_values.asnumpy(), y_values_expect.astype(np.int32), - 0.0001, 0.0001) - assert np.allclose(y_shape.asnumpy(), y_shape_expect.astype(np.int64), - 0.0001, 0.0001) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_sspaddmm_input_int64(): - """ - Feature: Sspaddmm gpu TEST. - Description: 2d int64 test case for Sspaddmm - Expectation: The value and shape of output are the expected values. - """ - context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") - - x1_indices = Tensor(np.array([[0, 1], [0, 1]]), mstype.int32) - x1_values = Tensor(np.array([7, 6]), mstype.int32) - x1_shape = Tensor(np.array([3, 3]), mstype.int32) - x2_indices = Tensor(np.array([[0, 1], [2, 2]]), mstype.int32) - x2_values = Tensor(np.array([11, 23]), mstype.int32) - x2_shape = Tensor(np.array([3, 3]), mstype.int32) - x3_dense = Tensor(np.array([[1, 2, 3], [1, 3, 2], [3, 2, 1]]), - mstype.int32) - alpha = Tensor(np.array([2]), mstype.int32) - beta = Tensor(np.array([2]), mstype.int32) - net = SspaddmmNet() - y_indices, y_values, y_shape = net(x1_indices, x1_values, x1_shape, - x2_indices, x2_values, x2_shape, - x3_dense, alpha, beta) - y_indices_expect = np.array([[0, 1, 0, 0, 0, 1, 1, 1], - [0, 1, 0, 1, 2, 0, 1, 2]]) - y_values_expect = np.array([14, 12, 66, 44, 22, 138, 92, 46]) - y_shape_expect = np.array([3, 3]) - - assert np.allclose(y_indices.asnumpy(), y_indices_expect.astype(np.int64), - 0.0001, 0.0001) - assert np.allclose(y_values.asnumpy(), y_values_expect.astype(np.int32), - 0.0001, 0.0001) - assert np.allclose(y_shape.asnumpy(), y_shape_expect.astype(np.int64), - 0.0001, 0.0001) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import dtype as mstype +from mindspore.ops.operations.sparse_ops import Sspaddmm + + +class SspaddmmNet(nn.Cell): + + def __init__(self): + super(SspaddmmNet, self).__init__() + self.sspaddmm = Sspaddmm() + + def construct(self, x1_indices, x1_values, x1_shape, x2_indices, x2_values, + x2_shape, x3_dense, alpha, beta): + return self.sspaddmm(x1_indices, x1_values, x1_shape, x2_indices, + x2_values, x2_shape, x3_dense, alpha, beta) + + +@pytest.mark.skip(reason="never run on ci or smoke test") +def test_sspaddmm_dyn(): + """ + Feature: test Sspaddmm ops in gpu. + Description: test the ops in dynamic shape. + Expectation: expect correct shape result. + """ + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + net = SspaddmmNet() + + x1_indices_dyn = Tensor(shape=[2, None], dtype=mstype.int64) + x1_values_dyn = Tensor(shape=[None], dtype=mstype.int32) + x1_shape_dyn = Tensor(shape=[None], dtype=mstype.int64) + x2_indices_dyn = Tensor(shape=[None, None], dtype=mstype.int64) + x2_values_dyn = Tensor(shape=[None], dtype=mstype.int32) + x2_shape_dyn = Tensor(shape=[None], dtype=mstype.int64) + x3_dense_dyn = Tensor(shape=[None, None], dtype=mstype.int32) + alpha = Tensor(1, dtype=mstype.int32) + beta = Tensor(1, dtype=mstype.int32) + + net.set_inputs(x1_indices_dyn, x1_values_dyn, x1_shape_dyn, x2_indices_dyn, + x2_values_dyn, x2_shape_dyn, x3_dense_dyn, alpha, beta) + + x1_indices = Tensor(np.array([[0, 1], [0, 1]]), mstype.int64) + x1_values = Tensor(np.array([1, 2]), mstype.int32) + x1_shape = Tensor(np.array([3, 3]), mstype.int64) + x2_indices = Tensor(np.array([[0, 1], [2, 2]]), mstype.int64) + x2_values = Tensor(np.array([3, 4]), mstype.int32) + x2_shape = Tensor(np.array([3, 3]), mstype.int64) + x3_dense = Tensor(np.array([[1, 2, 3], [1, 3, 2], [3, 2, 1]]), + mstype.int32) + + out = net(x1_indices, x1_values, x1_shape, x2_indices, x2_values, x2_shape, + x3_dense, alpha, beta) + expect_shapes = [(2, 8), (8,), (2,)] + for i in range(3): + assert out[i].asnumpy().shape == expect_shapes[i] + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_sspaddmm_input_int32(): + """ + Feature: Sspaddmm gpu TEST. + Description: 2d int32 test case for Sspaddmm + Expectation: The value and shape of output are the expected values. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + x1_indices = Tensor(np.array([[0, 1], [0, 1]]), mstype.int32) + x1_values = Tensor(np.array([1, 2]), mstype.int32) + x1_shape = Tensor(np.array([3, 3]), mstype.int32) + x2_indices = Tensor(np.array([[0, 1], [2, 2]]), mstype.int32) + x2_values = Tensor(np.array([3, 4]), mstype.int32) + x2_shape = Tensor(np.array([3, 3]), mstype.int32) + x3_dense = Tensor(np.array([[1, 2, 3], [1, 3, 2], [3, 2, 1]]), + mstype.int32) + alpha = Tensor(np.array([1]), mstype.int32) + beta = Tensor(np.array([1]), mstype.int32) + net = SspaddmmNet() + y_indices, y_values, y_shape = net(x1_indices, x1_values, x1_shape, + x2_indices, x2_values, x2_shape, + x3_dense, alpha, beta) + y_indices_expect = np.array( + [[0, 1, 0, 0, 0, 1, 1, 1], [0, 1, 0, 1, 2, 0, 1, 2]], dtype=np.int64) + y_values_expect = np.array([1, 2, 9, 6, 3, 12, 8, 4], dtype=np.int32) + y_shape_expect = np.array([3, 3], dtype=np.int64) + + assert np.allclose(y_indices.asnumpy(), y_indices_expect.astype(np.int64), + 0.0001, 0.0001) + assert np.allclose(y_values.asnumpy(), y_values_expect.astype(np.int32), + 0.0001, 0.0001) + assert np.allclose(y_shape.asnumpy(), y_shape_expect.astype(np.int64), + 0.0001, 0.0001) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_sspaddmm_input_int64(): + """ + Feature: Sspaddmm gpu TEST. + Description: 2d int64 test case for Sspaddmm + Expectation: The value and shape of output are the expected values. + """ + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + + x1_indices = Tensor(np.array([[0, 1], [0, 1]]), mstype.int32) + x1_values = Tensor(np.array([7, 6]), mstype.int32) + x1_shape = Tensor(np.array([3, 3]), mstype.int32) + x2_indices = Tensor(np.array([[0, 1], [2, 2]]), mstype.int32) + x2_values = Tensor(np.array([11, 23]), mstype.int32) + x2_shape = Tensor(np.array([3, 3]), mstype.int32) + x3_dense = Tensor(np.array([[1, 2, 3], [1, 3, 2], [3, 2, 1]]), + mstype.int32) + alpha = Tensor(np.array([2]), mstype.int32) + beta = Tensor(np.array([2]), mstype.int32) + net = SspaddmmNet() + y_indices, y_values, y_shape = net(x1_indices, x1_values, x1_shape, + x2_indices, x2_values, x2_shape, + x3_dense, alpha, beta) + y_indices_expect = np.array([[0, 1, 0, 0, 0, 1, 1, 1], + [0, 1, 0, 1, 2, 0, 1, 2]]) + y_values_expect = np.array([14, 12, 66, 44, 22, 138, 92, 46]) + y_shape_expect = np.array([3, 3]) + + assert np.allclose(y_indices.asnumpy(), y_indices_expect.astype(np.int64), + 0.0001, 0.0001) + assert np.allclose(y_values.asnumpy(), y_values_expect.astype(np.int32), + 0.0001, 0.0001) + assert np.allclose(y_shape.asnumpy(), y_shape_expect.astype(np.int64), + 0.0001, 0.0001) diff --git a/tests/st/ops/gpu/test_tensor_scatter_element_op.py b/tests/st/ops/gpu/test_tensor_scatter_element_op.py index afe5bceda22..a2b7d1f698f 100644 --- a/tests/st/ops/gpu/test_tensor_scatter_element_op.py +++ b/tests/st/ops/gpu/test_tensor_scatter_element_op.py @@ -1,118 +1,118 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore import Parameter -from mindspore.ops import functional as F -from mindspore.ops.operations.array_ops import TensorScatterElements - -context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - - -def scatter_element_np(input_x, indices, updates, axis, reduction="none"): - result = input_x.asnumpy().copy() - indices_np = indices.asnumpy().copy() - updates_np = updates.asnumpy().copy() - - i_len = indices_np.shape[0] - j_len = indices_np.shape[1] - - if axis < 0: - axis += len(result.shape) - - for i in range(i_len): - for j in range(j_len): - if axis == 0: - if reduction == "none": - result[indices_np[i][j]][j] = updates_np[i][j] - if reduction == "add": - result[indices_np[i][j]][j] += updates_np[i][j] - if axis == 1: - if reduction == "none": - result[i][indices_np[i][j]] = updates_np[i][j] - if reduction == "add": - result[i][indices_np[i][j]] += updates_np[i][j] - - return result - - -class TestTensorScatterElements(nn.Cell): - def __init__(self, input_x, indices, updates, axis, reduction): - super(TestTensorScatterElements, self).__init__() - self.axis = axis - self.reduction = reduction - self.input_x = Parameter(input_x, name="input_x") - self.indices = Parameter(indices, name="indices") - self.updates = Parameter(updates, name="updates") - self.scatter_elements = TensorScatterElements( - self.axis, self.reduction) - - def construct(self): - return self.scatter_elements(self.input_x, self.indices, self.updates) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.int32]) -@pytest.mark.parametrize('index_dtype', [np.int32, np.int64]) -@pytest.mark.parametrize('axis', [0, 1, -1]) -@pytest.mark.parametrize('reduction', ["none", "add"]) -def test_scatter_elements(dtype, index_dtype, axis, reduction): - """ - Feature: Op TensorScatterElements - Description: Scatter update value according indices to output. - output[indices[i][j]][j] = updates[i][j] if axis = 0, reduction="none" - output[i][indices[i][j]] += updates[i][j] if axis = 1, reduction="add" - Expectation: Ans is same as expected. - """ - x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype)) - indices = Tensor(np.array([[-1, 0, 1], [0, 1, 2]], dtype=index_dtype)) - update = Tensor(np.array([[1, 2, 2], [4, 5, 8]], dtype=dtype)) - - ms_output = TestTensorScatterElements( - x, indices, update, axis, reduction)() - np_output = scatter_element_np(x, indices, update, axis, reduction) - print("ms_output:\n", ms_output.asnumpy()) - assert np.allclose(ms_output.asnumpy(), np_output) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('dtype', [np.float32]) -@pytest.mark.parametrize('index_dtype', [np.int32]) -@pytest.mark.parametrize('axis', [0]) -@pytest.mark.parametrize('reduction', ["none", "add"]) -def test_scatter_add_with_axis_func(dtype, index_dtype, axis, reduction): - """ - Feature: test scatter_add_with_axis functional interface(scatter_add). - Description: Scatter update value according indices to output. - output[indices[i][j]][j] += updates[i][j] if axis = 0, - output[i][indices[i][j]] += updates[i][j] if axis = 1. - Expectation: Ans is same as expected. - """ - x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype)) - indices = Tensor(np.array([[-1, 0, 1], [0, 1, 2]], dtype=index_dtype)) - update = Tensor(np.array([[1, 2, 2], [4, 5, 8]], dtype=dtype)) - - #cause scatter_add will change the value of input, so we first calculate numpy output. - np_output = scatter_element_np(x, indices, update, axis, reduction) - ms_output = F.tensor_scatter_elements(x, indices, update, axis, reduction) - print("np_output:\n", np_output) - print("ms_output:\n", ms_output.asnumpy()) - assert np.allclose(ms_output.asnumpy(), np_output) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import Parameter +from mindspore.ops import functional as F +from mindspore.ops.operations.array_ops import TensorScatterElements + +context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + + +def scatter_element_np(input_x, indices, updates, axis, reduction="none"): + result = input_x.asnumpy().copy() + indices_np = indices.asnumpy().copy() + updates_np = updates.asnumpy().copy() + + i_len = indices_np.shape[0] + j_len = indices_np.shape[1] + + if axis < 0: + axis += len(result.shape) + + for i in range(i_len): + for j in range(j_len): + if axis == 0: + if reduction == "none": + result[indices_np[i][j]][j] = updates_np[i][j] + if reduction == "add": + result[indices_np[i][j]][j] += updates_np[i][j] + if axis == 1: + if reduction == "none": + result[i][indices_np[i][j]] = updates_np[i][j] + if reduction == "add": + result[i][indices_np[i][j]] += updates_np[i][j] + + return result + + +class TestTensorScatterElements(nn.Cell): + def __init__(self, input_x, indices, updates, axis, reduction): + super(TestTensorScatterElements, self).__init__() + self.axis = axis + self.reduction = reduction + self.input_x = Parameter(input_x, name="input_x") + self.indices = Parameter(indices, name="indices") + self.updates = Parameter(updates, name="updates") + self.scatter_elements = TensorScatterElements( + self.axis, self.reduction) + + def construct(self): + return self.scatter_elements(self.input_x, self.indices, self.updates) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.int32]) +@pytest.mark.parametrize('index_dtype', [np.int32, np.int64]) +@pytest.mark.parametrize('axis', [0, 1, -1]) +@pytest.mark.parametrize('reduction', ["none", "add"]) +def test_scatter_elements(dtype, index_dtype, axis, reduction): + """ + Feature: Op TensorScatterElements + Description: Scatter update value according indices to output. + output[indices[i][j]][j] = updates[i][j] if axis = 0, reduction="none" + output[i][indices[i][j]] += updates[i][j] if axis = 1, reduction="add" + Expectation: Ans is same as expected. + """ + x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype)) + indices = Tensor(np.array([[-1, 0, 1], [0, 1, 2]], dtype=index_dtype)) + update = Tensor(np.array([[1, 2, 2], [4, 5, 8]], dtype=dtype)) + + ms_output = TestTensorScatterElements( + x, indices, update, axis, reduction)() + np_output = scatter_element_np(x, indices, update, axis, reduction) + print("ms_output:\n", ms_output.asnumpy()) + assert np.allclose(ms_output.asnumpy(), np_output) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('dtype', [np.float32]) +@pytest.mark.parametrize('index_dtype', [np.int32]) +@pytest.mark.parametrize('axis', [0]) +@pytest.mark.parametrize('reduction', ["none", "add"]) +def test_scatter_add_with_axis_func(dtype, index_dtype, axis, reduction): + """ + Feature: test scatter_add_with_axis functional interface(scatter_add). + Description: Scatter update value according indices to output. + output[indices[i][j]][j] += updates[i][j] if axis = 0, + output[i][indices[i][j]] += updates[i][j] if axis = 1. + Expectation: Ans is same as expected. + """ + x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype)) + indices = Tensor(np.array([[-1, 0, 1], [0, 1, 2]], dtype=index_dtype)) + update = Tensor(np.array([[1, 2, 2], [4, 5, 8]], dtype=dtype)) + + #cause scatter_add will change the value of input, so we first calculate numpy output. + np_output = scatter_element_np(x, indices, update, axis, reduction) + ms_output = F.tensor_scatter_elements(x, indices, update, axis, reduction) + print("np_output:\n", np_output) + print("ms_output:\n", ms_output.asnumpy()) + assert np.allclose(ms_output.asnumpy(), np_output) diff --git a/tests/st/ops/gpu/test_tril_indices_op.py b/tests/st/ops/gpu/test_tril_indices_op.py index a47b447326f..914ea8307ee 100644 --- a/tests/st/ops/gpu/test_tril_indices_op.py +++ b/tests/st/ops/gpu/test_tril_indices_op.py @@ -1,62 +1,62 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore.context as context -import mindspore.nn as nn -import mindspore.ops.operations.math_ops as ops -from mindspore.common import dtype as mstype - - -class TrilIndicesNet(nn.Cell): - def __init__(self, row, col, offset=0, dtype=mstype.int32): - super().__init__() - self.tril_indices = ops.TrilIndices(row, col, offset, dtype) - - def construct(self): - return self.tril_indices() - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_tril_indices_int32_positive_offset(): - """ - Feature: TrilIndcies GPU TEST. - Description: dtype int32 and positive offset for TrilIndices. - Expectation: the result match to numpy. - """ - for mode in [context.PYNATIVE_MODE, context.GRAPH_MODE]: - context.set_context(mode=mode, device_target="GPU") - tril_indices = TrilIndicesNet(row=300, col=200, offset=50, dtype=mstype.int32) - output = tril_indices() - expect = np.array(np.tril_indices(n=300, m=200, k=50)).astype(np.int32) - assert(output.asnumpy() == expect).all() - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_tril_indices_int64_negative_offset(): - """ - Feature: TrilIndcies GPU TEST. - Description: dtype int64 and negative offset for TrilIndices. - Expectation: the result match to numpy. - """ - for mode in [context.PYNATIVE_MODE, context.GRAPH_MODE]: - context.set_context(mode=mode, device_target="GPU") - tril_indices = TrilIndicesNet(row=500, col=700, offset=-200, dtype=mstype.int64) - output = tril_indices() - expect = np.array(np.tril_indices(n=500, m=700, k=-200)).astype(np.int64) - assert(output.asnumpy() == expect).all() +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +import mindspore.ops.operations.math_ops as ops +from mindspore.common import dtype as mstype + + +class TrilIndicesNet(nn.Cell): + def __init__(self, row, col, offset=0, dtype=mstype.int32): + super().__init__() + self.tril_indices = ops.TrilIndices(row, col, offset, dtype) + + def construct(self): + return self.tril_indices() + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_tril_indices_int32_positive_offset(): + """ + Feature: TrilIndcies GPU TEST. + Description: dtype int32 and positive offset for TrilIndices. + Expectation: the result match to numpy. + """ + for mode in [context.PYNATIVE_MODE, context.GRAPH_MODE]: + context.set_context(mode=mode, device_target="GPU") + tril_indices = TrilIndicesNet(row=300, col=200, offset=50, dtype=mstype.int32) + output = tril_indices() + expect = np.array(np.tril_indices(n=300, m=200, k=50)).astype(np.int32) + assert(output.asnumpy() == expect).all() + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_tril_indices_int64_negative_offset(): + """ + Feature: TrilIndcies GPU TEST. + Description: dtype int64 and negative offset for TrilIndices. + Expectation: the result match to numpy. + """ + for mode in [context.PYNATIVE_MODE, context.GRAPH_MODE]: + context.set_context(mode=mode, device_target="GPU") + tril_indices = TrilIndicesNet(row=500, col=700, offset=-200, dtype=mstype.int64) + output = tril_indices() + expect = np.array(np.tril_indices(n=500, m=700, k=-200)).astype(np.int64) + assert(output.asnumpy() == expect).all() diff --git a/tests/st/ops/gpu/test_triu_indices_op.py b/tests/st/ops/gpu/test_triu_indices_op.py index 613f57eeff6..7c61b137380 100644 --- a/tests/st/ops/gpu/test_triu_indices_op.py +++ b/tests/st/ops/gpu/test_triu_indices_op.py @@ -1,62 +1,62 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore.context as context -import mindspore.nn as nn -import mindspore.ops.operations.math_ops as ops -from mindspore.common import dtype as mstype - - -class TriuIndicesNet(nn.Cell): - def __init__(self, row, col, offset=0, dtype=mstype.int32): - super().__init__() - self.triu_indices = ops.TriuIndices(row, col, offset, dtype) - - def construct(self): - return self.triu_indices() - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_triu_indices_int32_positive_offset(): - """ - Feature: TriuIndcies GPU TEST. - Description: dtype int32 and positive offset for TriuIndices. - Expectation: the result match to numpy. - """ - for mode in [context.PYNATIVE_MODE, context.GRAPH_MODE]: - context.set_context(mode=mode, device_target="GPU") - triu_indices = TriuIndicesNet(row=300, col=200, offset=50, dtype=mstype.int32) - output = triu_indices() - expect = np.array(np.triu_indices(n=300, m=200, k=50)).astype(np.int32) - assert(output.asnumpy() == expect).all() - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_triu_indices_int64_negative_offset(): - """ - Feature: TriuIndcies GPU TEST. - Description: dtype int64 and negative offset for TriuIndices. - Expectation: the result match to numpy. - """ - for mode in [context.PYNATIVE_MODE, context.GRAPH_MODE]: - context.set_context(mode=mode, device_target="GPU") - triu_indices = TriuIndicesNet(row=500, col=700, offset=-200, dtype=mstype.int64) - output = triu_indices() - expect = np.array(np.triu_indices(n=500, m=700, k=-200)).astype(np.int64) - assert(output.asnumpy() == expect).all() +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +import mindspore.ops.operations.math_ops as ops +from mindspore.common import dtype as mstype + + +class TriuIndicesNet(nn.Cell): + def __init__(self, row, col, offset=0, dtype=mstype.int32): + super().__init__() + self.triu_indices = ops.TriuIndices(row, col, offset, dtype) + + def construct(self): + return self.triu_indices() + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_triu_indices_int32_positive_offset(): + """ + Feature: TriuIndcies GPU TEST. + Description: dtype int32 and positive offset for TriuIndices. + Expectation: the result match to numpy. + """ + for mode in [context.PYNATIVE_MODE, context.GRAPH_MODE]: + context.set_context(mode=mode, device_target="GPU") + triu_indices = TriuIndicesNet(row=300, col=200, offset=50, dtype=mstype.int32) + output = triu_indices() + expect = np.array(np.triu_indices(n=300, m=200, k=50)).astype(np.int32) + assert(output.asnumpy() == expect).all() + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_triu_indices_int64_negative_offset(): + """ + Feature: TriuIndcies GPU TEST. + Description: dtype int64 and negative offset for TriuIndices. + Expectation: the result match to numpy. + """ + for mode in [context.PYNATIVE_MODE, context.GRAPH_MODE]: + context.set_context(mode=mode, device_target="GPU") + triu_indices = TriuIndicesNet(row=500, col=700, offset=-200, dtype=mstype.int64) + output = triu_indices() + expect = np.array(np.triu_indices(n=500, m=700, k=-200)).astype(np.int64) + assert(output.asnumpy() == expect).all() diff --git a/tests/st/ops/gpu/test_xdivy_op.py b/tests/st/ops/gpu/test_xdivy_op.py index a7e7020fcbd..7331f00e4c2 100644 --- a/tests/st/ops/gpu/test_xdivy_op.py +++ b/tests/st/ops/gpu/test_xdivy_op.py @@ -1,182 +1,182 @@ -# Copyright 2020-2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.ops import operations as P -import mindspore as ms - -TF_INSTALL_FLG = 1 -try: - import tensorflow as tf -except ImportError: - TF_INSTALL_FLG = 0 - - -class NetXDivy(nn.Cell): - def __init__(self): - super(NetXDivy, self).__init__() - self.xdivy = P.Xdivy() - - def construct(self, x, y): - return self.xdivy(x, y) - - -def xdivy(nptype): - x0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(nptype) - y0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(nptype) - x1_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(nptype) - y1_np = np.random.randint(1, 5, (2, 1, 4, 4)).astype(nptype) - x2_np = np.random.randint(1, 5, (2, 1, 1, 4)).astype(nptype) - y2_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(nptype) - x3_np = np.random.randint(1, 5, 1).astype(nptype) - y3_np = np.random.randint(1, 5, 1).astype(nptype) - x4_np = np.array(78).astype(nptype) - y4_np = np.array(37.5).astype(nptype) - - x0 = Tensor(x0_np) - y0 = Tensor(y0_np) - x1 = Tensor(x1_np) - y1 = Tensor(y1_np) - x2 = Tensor(x2_np) - y2 = Tensor(y2_np) - x3 = Tensor(x3_np) - y3 = Tensor(y3_np) - x4 = Tensor(x4_np) - y4 = Tensor(y4_np) - - context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - div_net = NetXDivy() - output0 = div_net(x0, y0) - expect0 = np.divide(x0_np, y0_np) - diff0 = output0.asnumpy() - expect0 - error0 = np.ones(shape=expect0.shape) * 1.0e-5 - assert np.all(diff0 < error0) - assert output0.shape == expect0.shape - - output1 = div_net(x1, y1) - expect1 = np.divide(x1_np, y1_np) - diff1 = output1.asnumpy() - expect1 - error1 = np.ones(shape=expect1.shape) * 1.0e-5 - assert np.all(diff1 < error1) - assert output1.shape == expect1.shape - - output2 = div_net(x2, y2) - expect2 = np.divide(x2_np, y2_np) - diff2 = output2.asnumpy() - expect2 - error2 = np.ones(shape=expect2.shape) * 1.0e-5 - assert np.all(diff2 < error2) - assert output2.shape == expect2.shape - - context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') - output3 = div_net(x3, y3) - expect3 = np.divide(x3_np, y3_np) - diff3 = output3.asnumpy() - expect3 - error3 = np.ones(shape=expect3.shape) * 1.0e-5 - assert np.all(diff3 < error3) - assert output3.shape == expect3.shape - - output4 = div_net(x4, y4) - expect4 = np.divide(x4_np, y4_np) - diff4 = output4.asnumpy() - expect4 - error4 = np.ones(shape=expect4.shape) * 1.0e-5 - assert np.all(diff4 < error4) - assert output4.shape == expect4.shape - - -def xdivy_sf_check(mstype, tftype): - # test divided zero - with tf.device('/cpu:0'): - tx = tf.constant([-4.0, 0.0, 1.0, 0.0], dtype=tftype) - ty = tf.constant([3.0, 2.0, 0.0, 0.0], dtype=tftype) - tz = tf.math.xdivy(tx, ty) - - x = ms.Tensor(np.array([-4.0, 0.0, 1.0, 0.0]), dtype=mstype) - y = ms.Tensor(np.array([3.0, 2.0, 0.0, 0.0]), dtype=mstype) - z = ms.ops.xdivy(x, y) - assert tz.numpy().all() == z.asnumpy().all() - - # test broadcast - with tf.device('/cpu:0'): - tx = tf.constant([-4.0, 5.0, 0.0], dtype=tftype) - ty = tf.constant([[3.0], [2.0]], dtype=tftype) - tz = tf.math.xdivy(tx, ty) - x = ms.Tensor(np.array([-4.0, 5.0, 0.0]), dtype=mstype) - y = ms.Tensor(np.array([[3.0], [2.0]]), dtype=mstype) - z = ms.ops.xdivy(x, y) - assert tz.numpy().all() == z.asnumpy().all() - - # test broadcast - with tf.device('/cpu:0'): - tx = tf.constant([-4.0], dtype=tftype) - ty = tf.constant([[3.0, 1.0, 1.0], [2.0, 3.0, 5.0]], dtype=tftype) - tz = tf.math.xdivy(tx, ty) - x = ms.Tensor(np.array([-4.0]), dtype=mstype) - y = ms.Tensor(np.array([[3.0, 1.0, 1.0], [2.0, 3.0, 5.0]]), dtype=mstype) - z = ms.ops.xdivy(x, y) - assert tz.numpy().all() == z.asnumpy().all() - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_div_float64(): - """ - Feature: test xdivy primitive use float64 - Description: compare result with numpy&& tensorflow - Expectation: calculate result same to numpy&&tensorflow - """ - xdivy(np.float64) - if TF_INSTALL_FLG == 0: - return - context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') - xdivy_sf_check(ms.float64, tf.dtypes.float64) - context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - xdivy_sf_check(ms.float64, tf.dtypes.float64) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_div_float32(): - """ - Feature: test xdivy primitive use float32 - Description: compare result with numpy&& tensorflow - Expectation: calculate result same to numpy&&tensorflow - """ - xdivy(np.float32) - if TF_INSTALL_FLG == 0: - return - context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') - xdivy_sf_check(ms.float32, tf.dtypes.float32) - context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - xdivy_sf_check(ms.float32, tf.dtypes.float32) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_div_float16(): - """ - Feature: test xdivy primitive use float16 - Description: compare result with numpy&& tensorflow - Expectation: calculate result same to numpy&&tensorflow - """ - xdivy(np.float16) - if TF_INSTALL_FLG == 0: - return - context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') - xdivy_sf_check(ms.float16, tf.dtypes.float16) - context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - xdivy_sf_check(ms.float16, tf.dtypes.float16) +# Copyright 2020-2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P +import mindspore as ms + +TF_INSTALL_FLG = 1 +try: + import tensorflow as tf +except ImportError: + TF_INSTALL_FLG = 0 + + +class NetXDivy(nn.Cell): + def __init__(self): + super(NetXDivy, self).__init__() + self.xdivy = P.Xdivy() + + def construct(self, x, y): + return self.xdivy(x, y) + + +def xdivy(nptype): + x0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(nptype) + y0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(nptype) + x1_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(nptype) + y1_np = np.random.randint(1, 5, (2, 1, 4, 4)).astype(nptype) + x2_np = np.random.randint(1, 5, (2, 1, 1, 4)).astype(nptype) + y2_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(nptype) + x3_np = np.random.randint(1, 5, 1).astype(nptype) + y3_np = np.random.randint(1, 5, 1).astype(nptype) + x4_np = np.array(78).astype(nptype) + y4_np = np.array(37.5).astype(nptype) + + x0 = Tensor(x0_np) + y0 = Tensor(y0_np) + x1 = Tensor(x1_np) + y1 = Tensor(y1_np) + x2 = Tensor(x2_np) + y2 = Tensor(y2_np) + x3 = Tensor(x3_np) + y3 = Tensor(y3_np) + x4 = Tensor(x4_np) + y4 = Tensor(y4_np) + + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + div_net = NetXDivy() + output0 = div_net(x0, y0) + expect0 = np.divide(x0_np, y0_np) + diff0 = output0.asnumpy() - expect0 + error0 = np.ones(shape=expect0.shape) * 1.0e-5 + assert np.all(diff0 < error0) + assert output0.shape == expect0.shape + + output1 = div_net(x1, y1) + expect1 = np.divide(x1_np, y1_np) + diff1 = output1.asnumpy() - expect1 + error1 = np.ones(shape=expect1.shape) * 1.0e-5 + assert np.all(diff1 < error1) + assert output1.shape == expect1.shape + + output2 = div_net(x2, y2) + expect2 = np.divide(x2_np, y2_np) + diff2 = output2.asnumpy() - expect2 + error2 = np.ones(shape=expect2.shape) * 1.0e-5 + assert np.all(diff2 < error2) + assert output2.shape == expect2.shape + + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + output3 = div_net(x3, y3) + expect3 = np.divide(x3_np, y3_np) + diff3 = output3.asnumpy() - expect3 + error3 = np.ones(shape=expect3.shape) * 1.0e-5 + assert np.all(diff3 < error3) + assert output3.shape == expect3.shape + + output4 = div_net(x4, y4) + expect4 = np.divide(x4_np, y4_np) + diff4 = output4.asnumpy() - expect4 + error4 = np.ones(shape=expect4.shape) * 1.0e-5 + assert np.all(diff4 < error4) + assert output4.shape == expect4.shape + + +def xdivy_sf_check(mstype, tftype): + # test divided zero + with tf.device('/cpu:0'): + tx = tf.constant([-4.0, 0.0, 1.0, 0.0], dtype=tftype) + ty = tf.constant([3.0, 2.0, 0.0, 0.0], dtype=tftype) + tz = tf.math.xdivy(tx, ty) + + x = ms.Tensor(np.array([-4.0, 0.0, 1.0, 0.0]), dtype=mstype) + y = ms.Tensor(np.array([3.0, 2.0, 0.0, 0.0]), dtype=mstype) + z = ms.ops.xdivy(x, y) + assert tz.numpy().all() == z.asnumpy().all() + + # test broadcast + with tf.device('/cpu:0'): + tx = tf.constant([-4.0, 5.0, 0.0], dtype=tftype) + ty = tf.constant([[3.0], [2.0]], dtype=tftype) + tz = tf.math.xdivy(tx, ty) + x = ms.Tensor(np.array([-4.0, 5.0, 0.0]), dtype=mstype) + y = ms.Tensor(np.array([[3.0], [2.0]]), dtype=mstype) + z = ms.ops.xdivy(x, y) + assert tz.numpy().all() == z.asnumpy().all() + + # test broadcast + with tf.device('/cpu:0'): + tx = tf.constant([-4.0], dtype=tftype) + ty = tf.constant([[3.0, 1.0, 1.0], [2.0, 3.0, 5.0]], dtype=tftype) + tz = tf.math.xdivy(tx, ty) + x = ms.Tensor(np.array([-4.0]), dtype=mstype) + y = ms.Tensor(np.array([[3.0, 1.0, 1.0], [2.0, 3.0, 5.0]]), dtype=mstype) + z = ms.ops.xdivy(x, y) + assert tz.numpy().all() == z.asnumpy().all() + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_div_float64(): + """ + Feature: test xdivy primitive use float64 + Description: compare result with numpy&& tensorflow + Expectation: calculate result same to numpy&&tensorflow + """ + xdivy(np.float64) + if TF_INSTALL_FLG == 0: + return + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + xdivy_sf_check(ms.float64, tf.dtypes.float64) + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + xdivy_sf_check(ms.float64, tf.dtypes.float64) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_div_float32(): + """ + Feature: test xdivy primitive use float32 + Description: compare result with numpy&& tensorflow + Expectation: calculate result same to numpy&&tensorflow + """ + xdivy(np.float32) + if TF_INSTALL_FLG == 0: + return + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + xdivy_sf_check(ms.float32, tf.dtypes.float32) + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + xdivy_sf_check(ms.float32, tf.dtypes.float32) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_div_float16(): + """ + Feature: test xdivy primitive use float16 + Description: compare result with numpy&& tensorflow + Expectation: calculate result same to numpy&&tensorflow + """ + xdivy(np.float16) + if TF_INSTALL_FLG == 0: + return + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + xdivy_sf_check(ms.float16, tf.dtypes.float16) + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + xdivy_sf_check(ms.float16, tf.dtypes.float16) diff --git a/tests/st/ops/gpu/test_zeta_op.py b/tests/st/ops/gpu/test_zeta_op.py index 1abc48bc37e..81798c87032 100644 --- a/tests/st/ops/gpu/test_zeta_op.py +++ b/tests/st/ops/gpu/test_zeta_op.py @@ -1,68 +1,68 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.ops.operations.math_ops import Zeta - - -class NetZeta(nn.Cell): - - def __init__(self): - super(NetZeta, self).__init__() - self.zeta = Zeta() - - def construct(self, x, y): - return self.zeta(x, y) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_zeta_1d_input_float32_output_float32(): - """ - Feature: Zeta gpu TEST. - Description: 1d test case for Zeta - Expectation: The value and shape of output are the expected values. - """ - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - - x_ms = Tensor(np.array([3, 3, 9]).astype(np.float32)) - q_ms = Tensor(np.array([4, 2, 9]).astype(np.float32)) - net = NetZeta() - z_ms = net(x_ms, q_ms) - expect = np.array([4.0019866e-02, 2.02056915e-01, 4.4048485e-09]) - - assert np.allclose(z_ms.asnumpy(), expect.astype(np.float32), 0.0001, 0.0001) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_zeta_1d_input_float64_output_float64(): - """ - Feature: Zeta gpu TEST. - Description: 1d test case for Zeta - Expectation: The value and shape of output are the expected values. - """ - context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") - - x_ms = Tensor(np.array([13, 5.3, 6, 10]).astype(np.float64)) - q_ms = Tensor(np.array([4.4, 21.2, -4.7, -3.7]).astype(np.float64)) - net = NetZeta() - z_ms = net(x_ms, q_ms) - expect = np.array([4.6569e-09, 5.0921e-07, 1.3805e+03, 1.6939e+05]) - - assert np.allclose(z_ms.asnumpy(), expect.astype(np.float64), 0.0001, 0.0001) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops.operations.math_ops import Zeta + + +class NetZeta(nn.Cell): + + def __init__(self): + super(NetZeta, self).__init__() + self.zeta = Zeta() + + def construct(self, x, y): + return self.zeta(x, y) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_zeta_1d_input_float32_output_float32(): + """ + Feature: Zeta gpu TEST. + Description: 1d test case for Zeta + Expectation: The value and shape of output are the expected values. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + x_ms = Tensor(np.array([3, 3, 9]).astype(np.float32)) + q_ms = Tensor(np.array([4, 2, 9]).astype(np.float32)) + net = NetZeta() + z_ms = net(x_ms, q_ms) + expect = np.array([4.0019866e-02, 2.02056915e-01, 4.4048485e-09]) + + assert np.allclose(z_ms.asnumpy(), expect.astype(np.float32), 0.0001, 0.0001) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_zeta_1d_input_float64_output_float64(): + """ + Feature: Zeta gpu TEST. + Description: 1d test case for Zeta + Expectation: The value and shape of output are the expected values. + """ + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + + x_ms = Tensor(np.array([13, 5.3, 6, 10]).astype(np.float64)) + q_ms = Tensor(np.array([4.4, 21.2, -4.7, -3.7]).astype(np.float64)) + net = NetZeta() + z_ms = net(x_ms, q_ms) + expect = np.array([4.6569e-09, 5.0921e-07, 1.3805e+03, 1.6939e+05]) + + assert np.allclose(z_ms.asnumpy(), expect.astype(np.float64), 0.0001, 0.0001) diff --git a/tests/st/ops/test_bessel.py b/tests/st/ops/test_bessel.py old mode 100755 new mode 100644 index 3331a4e5c5a..5d9336836e3 --- a/tests/st/ops/test_bessel.py +++ b/tests/st/ops/test_bessel.py @@ -1,250 +1,250 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore.context as context -from mindspore import Tensor -import mindspore.ops as F - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -@pytest.mark.parametrize('dtype, eps', [(np.float16, 1.0e-3), (np.float32, 1.0e-6), (np.float64, 1.0e-6)]) -def test_bessel_j0(dtype, eps): - """ - Feature: bessel j0 function - Description: test cases for BesselJ0 - Expectation: the result matches scipy - """ - x = Tensor(np.array([0.5, 1., 2., 4.]).astype(dtype)) - expect = np.array( - [0.9384698, 0.7651977, 0.22389078, -0.3971498]).astype(dtype) - error = np.ones(shape=[4]) * eps - context.set_context(mode=context.GRAPH_MODE, device_target="CPU") - - output = F.bessel_j0(x) - diff = np.abs(output.asnumpy() - expect) - assert np.all(diff < error) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -@pytest.mark.parametrize('dtype, eps', [(np.float16, 1.0e-3), (np.float32, 1.0e-6), (np.float64, 1.0e-6)]) -def test_bessel_j1(dtype, eps): - """ - Feature: bessel j1 function - Description: test cases for BesselJ1 - Expectation: the result matches scipy - """ - x = Tensor(np.array([0.5, 1., 2., 4.]).astype(dtype)) - expect = np.array( - [0.24226846, 0.44005057, 0.5767248, -0.06604332]).astype(dtype) - error = np.ones(shape=[4]) * eps - context.set_context(mode=context.GRAPH_MODE, device_target="CPU") - - output = F.bessel_j1(x) - diff = np.abs(output.asnumpy() - expect) - assert np.all(diff < error) - - -@arg_mark(plat_marks=['platform_ascend', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -@pytest.mark.parametrize('dtype, eps', [(np.float16, 1.0e-3), (np.float32, 1.0e-6), (np.float64, 1.0e-6)]) -def test_bessel_i0(dtype, eps): - """ - Feature: bessel i0 function - Description: test cases for BesselI0 - Expectation: the result matches scipy - """ - x = Tensor(np.array([-1, -0.5, 0.5, 1]).astype(dtype)) - expect = np.array( - [1.2660658, 1.0634834, 1.0634834, 1.2660658]).astype(dtype) - error = np.ones(shape=[4]) * eps - context.set_context(mode=context.GRAPH_MODE) - - output = F.bessel_i0(x) - diff = np.abs(output.asnumpy() - expect) - assert np.all(diff < error) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -@pytest.mark.parametrize('dtype, eps', [(np.float16, 1.0e-3), (np.float32, 1.0e-6), (np.float64, 1.0e-6)]) -def test_bessel_i0e(dtype, eps): - """ - Feature: bessel i0e function - Description: test cases for BesselI0e - Expectation: the result matches scipy - """ - x = Tensor(np.array([-1, -0.5, 0.5, 1]).astype(dtype)) - expect = np.array( - [0.4657596, 0.64503527, 0.64503527, 0.4657596]).astype(dtype) - error = np.ones(shape=[4]) * eps - context.set_context(mode=context.GRAPH_MODE, device_target="CPU") - - output = F.bessel_i0e(x) - diff = np.abs(output.asnumpy() - expect) - assert np.all(diff < error) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -@pytest.mark.parametrize('dtype, eps', [(np.float16, 1.0e-3), (np.float32, 1.0e-6), (np.float64, 1.0e-6)]) -def test_bessel_k0(dtype, eps): - """ - Feature: bessel k0 function - Description: test cases for BesselK0 - Expectation: the result matches scipy - """ - x = Tensor(np.array([0.5, 1., 2., 4.]).astype(dtype)) - expect = np.array( - [0.92441905, 0.42102444, 0.11389387, 0.01115968]).astype(dtype) - error = np.ones(shape=[4]) * eps - context.set_context(mode=context.GRAPH_MODE, device_target="CPU") - - output = F.bessel_k0(x) - diff = np.abs(output.asnumpy() - expect) - assert np.all(diff < error) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -@pytest.mark.parametrize('dtype, eps', [(np.float16, 1.0e-3), (np.float32, 1.0e-6), (np.float64, 1.0e-6)]) -def test_bessel_k0e(dtype, eps): - """ - Feature: bessel k-e function - Description: test cases for BesselK0e - Expectation: the result matches scipy - """ - x = Tensor(np.array([0.5, 1., 2., 4.]).astype(dtype)) - expect = np.array( - [1.5241094, 1.1444631, 0.84156823, 0.6092977]).astype(dtype) - error = np.ones(shape=[4]) * eps - context.set_context(mode=context.GRAPH_MODE, device_target="CPU") - - output = F.bessel_k0e(x) - diff = np.abs(output.asnumpy() - expect) - assert np.all(diff < error) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -@pytest.mark.parametrize('dtype, eps', [(np.float16, 1.0e-3), (np.float32, 1.0e-6), (np.float64, 1.0e-6)]) -def test_bessel_y0(dtype, eps): - """ - Feature: bessel y0 function - Description: test cases for BesselY0 - Expectation: the result matches scipy - """ - x = Tensor(np.array([0.5, 1., 2., 4.]).astype(dtype)) - expect = np.array( - [-0.44451874, 0.08825696, 0.51037567, -0.01694074]).astype(dtype) - error = np.ones(shape=[4]) * eps - context.set_context(mode=context.GRAPH_MODE, device_target="CPU") - - output = F.bessel_y0(x) - diff = np.abs(output.asnumpy() - expect) - assert np.all(diff < error) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -@pytest.mark.parametrize('dtype, eps', [(np.float16, 1.0e-3), (np.float32, 1.0e-6), (np.float64, 1.0e-6)]) -def test_bessel_y1(dtype, eps): - """ - Feature: bessel y1 function - Description: test cases for BesselY1 - Expectation: the result matches scipy - """ - x = Tensor(np.array([0.5, 1., 2., 4.]).astype(dtype)) - expect = np.array([-1.47147239, -0.78121282, - -0.10703243, 0.39792571]).astype(dtype) - error = np.ones(shape=[4]) * eps - context.set_context(mode=context.GRAPH_MODE, device_target="CPU") - - output = F.bessel_y1(x) - diff = np.abs(output.asnumpy() - expect) - assert np.all(diff < error) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -@pytest.mark.parametrize('dtype, eps', [(np.float16, 1.0e-3), (np.float32, 1.0e-6), (np.float64, 1.0e-6)]) -def test_bessel_i1(dtype, eps): - """ - Feature: bessel i1 function - Description: test cases for BesselI1 - Expectation: the result matches scipy - """ - x = Tensor(np.array([-1, -0.5, 0.5, 1]).astype(dtype)) - expect = np.array( - [-0.5651591, -0.25789431, 0.25789431, 0.5651591]).astype(dtype) - error = np.ones(shape=[4]) * eps - context.set_context(mode=context.GRAPH_MODE, device_target="CPU") - - output = F.bessel_i1(x) - diff = np.abs(output.asnumpy() - expect) - assert np.all(diff < error) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -@pytest.mark.parametrize('dtype, eps', [(np.float16, 1.0e-3), (np.float32, 1.0e-6), (np.float64, 1.0e-6)]) -def test_bessel_i1e(dtype, eps): - """ - Feature: bessel i1e function - Description: test cases for BesselI1e - Expectation: the result matches scipy - """ - x = Tensor(np.array([-1, -0.5, 0.5, 1]).astype(dtype)) - expect = np.array( - [-0.20791042, -0.15642083, 0.15642083, 0.20791042]).astype(dtype) - error = np.ones(shape=[4]) * eps - context.set_context(mode=context.GRAPH_MODE, device_target="CPU") - - output = F.bessel_i1e(x) - diff = np.abs(output.asnumpy() - expect) - assert np.all(diff < error) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -@pytest.mark.parametrize('dtype, eps', [(np.float16, 1.0e-3), (np.float32, 1.0e-6), (np.float64, 1.0e-6)]) -def test_bessel_k1(dtype, eps): - """ - Feature: bessel k1 function - Description: test cases for BesselK1 - Expectation: the result matches scipy - """ - x = Tensor(np.array([0.5, 1., 2., 4.]).astype(dtype)) - expect = np.array( - [1.65644112, 0.60190723, 0.13986588, 0.0124835]).astype(dtype) - error = np.ones(shape=[4]) * eps - context.set_context(mode=context.GRAPH_MODE, device_target="CPU") - - output = F.bessel_k1(x) - diff = np.abs(output.asnumpy() - expect) - assert np.all(diff < error) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -@pytest.mark.parametrize('dtype, eps', [(np.float16, 1.0e-3), (np.float32, 1.0e-6), (np.float64, 1.0e-6)]) -def test_bessel_k1e(dtype, eps): - """ - Feature: bessel k1e function - Description: test cases for BesselK1e - Expectation: the result matches scipy - """ - x = Tensor(np.array([0.5, 1., 2., 4.]).astype(dtype)) - expect = np.array( - [2.73100971, 1.63615349, 1.03347685, 0.68157595]).astype(dtype) - error = np.ones(shape=[4]) * eps - context.set_context(mode=context.GRAPH_MODE, device_target="CPU") - - output = F.bessel_k1e(x) - diff = np.abs(output.asnumpy() - expect) - assert np.all(diff < error) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore.context as context +from mindspore import Tensor +import mindspore.ops as F + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +@pytest.mark.parametrize('dtype, eps', [(np.float16, 1.0e-3), (np.float32, 1.0e-6), (np.float64, 1.0e-6)]) +def test_bessel_j0(dtype, eps): + """ + Feature: bessel j0 function + Description: test cases for BesselJ0 + Expectation: the result matches scipy + """ + x = Tensor(np.array([0.5, 1., 2., 4.]).astype(dtype)) + expect = np.array( + [0.9384698, 0.7651977, 0.22389078, -0.3971498]).astype(dtype) + error = np.ones(shape=[4]) * eps + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + output = F.bessel_j0(x) + diff = np.abs(output.asnumpy() - expect) + assert np.all(diff < error) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +@pytest.mark.parametrize('dtype, eps', [(np.float16, 1.0e-3), (np.float32, 1.0e-6), (np.float64, 1.0e-6)]) +def test_bessel_j1(dtype, eps): + """ + Feature: bessel j1 function + Description: test cases for BesselJ1 + Expectation: the result matches scipy + """ + x = Tensor(np.array([0.5, 1., 2., 4.]).astype(dtype)) + expect = np.array( + [0.24226846, 0.44005057, 0.5767248, -0.06604332]).astype(dtype) + error = np.ones(shape=[4]) * eps + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + output = F.bessel_j1(x) + diff = np.abs(output.asnumpy() - expect) + assert np.all(diff < error) + + +@arg_mark(plat_marks=['platform_ascend', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +@pytest.mark.parametrize('dtype, eps', [(np.float16, 1.0e-3), (np.float32, 1.0e-6), (np.float64, 1.0e-6)]) +def test_bessel_i0(dtype, eps): + """ + Feature: bessel i0 function + Description: test cases for BesselI0 + Expectation: the result matches scipy + """ + x = Tensor(np.array([-1, -0.5, 0.5, 1]).astype(dtype)) + expect = np.array( + [1.2660658, 1.0634834, 1.0634834, 1.2660658]).astype(dtype) + error = np.ones(shape=[4]) * eps + context.set_context(mode=context.GRAPH_MODE) + + output = F.bessel_i0(x) + diff = np.abs(output.asnumpy() - expect) + assert np.all(diff < error) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +@pytest.mark.parametrize('dtype, eps', [(np.float16, 1.0e-3), (np.float32, 1.0e-6), (np.float64, 1.0e-6)]) +def test_bessel_i0e(dtype, eps): + """ + Feature: bessel i0e function + Description: test cases for BesselI0e + Expectation: the result matches scipy + """ + x = Tensor(np.array([-1, -0.5, 0.5, 1]).astype(dtype)) + expect = np.array( + [0.4657596, 0.64503527, 0.64503527, 0.4657596]).astype(dtype) + error = np.ones(shape=[4]) * eps + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + output = F.bessel_i0e(x) + diff = np.abs(output.asnumpy() - expect) + assert np.all(diff < error) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +@pytest.mark.parametrize('dtype, eps', [(np.float16, 1.0e-3), (np.float32, 1.0e-6), (np.float64, 1.0e-6)]) +def test_bessel_k0(dtype, eps): + """ + Feature: bessel k0 function + Description: test cases for BesselK0 + Expectation: the result matches scipy + """ + x = Tensor(np.array([0.5, 1., 2., 4.]).astype(dtype)) + expect = np.array( + [0.92441905, 0.42102444, 0.11389387, 0.01115968]).astype(dtype) + error = np.ones(shape=[4]) * eps + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + output = F.bessel_k0(x) + diff = np.abs(output.asnumpy() - expect) + assert np.all(diff < error) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +@pytest.mark.parametrize('dtype, eps', [(np.float16, 1.0e-3), (np.float32, 1.0e-6), (np.float64, 1.0e-6)]) +def test_bessel_k0e(dtype, eps): + """ + Feature: bessel k-e function + Description: test cases for BesselK0e + Expectation: the result matches scipy + """ + x = Tensor(np.array([0.5, 1., 2., 4.]).astype(dtype)) + expect = np.array( + [1.5241094, 1.1444631, 0.84156823, 0.6092977]).astype(dtype) + error = np.ones(shape=[4]) * eps + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + output = F.bessel_k0e(x) + diff = np.abs(output.asnumpy() - expect) + assert np.all(diff < error) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +@pytest.mark.parametrize('dtype, eps', [(np.float16, 1.0e-3), (np.float32, 1.0e-6), (np.float64, 1.0e-6)]) +def test_bessel_y0(dtype, eps): + """ + Feature: bessel y0 function + Description: test cases for BesselY0 + Expectation: the result matches scipy + """ + x = Tensor(np.array([0.5, 1., 2., 4.]).astype(dtype)) + expect = np.array( + [-0.44451874, 0.08825696, 0.51037567, -0.01694074]).astype(dtype) + error = np.ones(shape=[4]) * eps + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + output = F.bessel_y0(x) + diff = np.abs(output.asnumpy() - expect) + assert np.all(diff < error) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +@pytest.mark.parametrize('dtype, eps', [(np.float16, 1.0e-3), (np.float32, 1.0e-6), (np.float64, 1.0e-6)]) +def test_bessel_y1(dtype, eps): + """ + Feature: bessel y1 function + Description: test cases for BesselY1 + Expectation: the result matches scipy + """ + x = Tensor(np.array([0.5, 1., 2., 4.]).astype(dtype)) + expect = np.array([-1.47147239, -0.78121282, + -0.10703243, 0.39792571]).astype(dtype) + error = np.ones(shape=[4]) * eps + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + output = F.bessel_y1(x) + diff = np.abs(output.asnumpy() - expect) + assert np.all(diff < error) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +@pytest.mark.parametrize('dtype, eps', [(np.float16, 1.0e-3), (np.float32, 1.0e-6), (np.float64, 1.0e-6)]) +def test_bessel_i1(dtype, eps): + """ + Feature: bessel i1 function + Description: test cases for BesselI1 + Expectation: the result matches scipy + """ + x = Tensor(np.array([-1, -0.5, 0.5, 1]).astype(dtype)) + expect = np.array( + [-0.5651591, -0.25789431, 0.25789431, 0.5651591]).astype(dtype) + error = np.ones(shape=[4]) * eps + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + output = F.bessel_i1(x) + diff = np.abs(output.asnumpy() - expect) + assert np.all(diff < error) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +@pytest.mark.parametrize('dtype, eps', [(np.float16, 1.0e-3), (np.float32, 1.0e-6), (np.float64, 1.0e-6)]) +def test_bessel_i1e(dtype, eps): + """ + Feature: bessel i1e function + Description: test cases for BesselI1e + Expectation: the result matches scipy + """ + x = Tensor(np.array([-1, -0.5, 0.5, 1]).astype(dtype)) + expect = np.array( + [-0.20791042, -0.15642083, 0.15642083, 0.20791042]).astype(dtype) + error = np.ones(shape=[4]) * eps + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + output = F.bessel_i1e(x) + diff = np.abs(output.asnumpy() - expect) + assert np.all(diff < error) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +@pytest.mark.parametrize('dtype, eps', [(np.float16, 1.0e-3), (np.float32, 1.0e-6), (np.float64, 1.0e-6)]) +def test_bessel_k1(dtype, eps): + """ + Feature: bessel k1 function + Description: test cases for BesselK1 + Expectation: the result matches scipy + """ + x = Tensor(np.array([0.5, 1., 2., 4.]).astype(dtype)) + expect = np.array( + [1.65644112, 0.60190723, 0.13986588, 0.0124835]).astype(dtype) + error = np.ones(shape=[4]) * eps + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + output = F.bessel_k1(x) + diff = np.abs(output.asnumpy() - expect) + assert np.all(diff < error) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +@pytest.mark.parametrize('dtype, eps', [(np.float16, 1.0e-3), (np.float32, 1.0e-6), (np.float64, 1.0e-6)]) +def test_bessel_k1e(dtype, eps): + """ + Feature: bessel k1e function + Description: test cases for BesselK1e + Expectation: the result matches scipy + """ + x = Tensor(np.array([0.5, 1., 2., 4.]).astype(dtype)) + expect = np.array( + [2.73100971, 1.63615349, 1.03347685, 0.68157595]).astype(dtype) + error = np.ones(shape=[4]) * eps + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + output = F.bessel_k1e(x) + diff = np.abs(output.asnumpy() - expect) + assert np.all(diff < error) diff --git a/tests/st/ops/test_cauchy.py b/tests/st/ops/test_cauchy.py index 74eb1662e56..7bc39fe37a4 100644 --- a/tests/st/ops/test_cauchy.py +++ b/tests/st/ops/test_cauchy.py @@ -1,44 +1,44 @@ -# Copyright 2024 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import pytest - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops - - -class Net(nn.Cell): - def __init__(self, size): - super().__init__() - self.cauchy = ops.Cauchy(size) - - def construct(self): - return self.cauchy() - -@arg_mark(plat_marks=['platform_ascend', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_cauchy(mode): - """ - Feature: Cauchy op - Description: Verify the result of cauchy - Expectation: success - """ - ms.set_context(mode=mode) - size = [2, 3] - net = Net(size) - out = net() - assert out.shape == tuple(size) +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import pytest + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops + + +class Net(nn.Cell): + def __init__(self, size): + super().__init__() + self.cauchy = ops.Cauchy(size) + + def construct(self): + return self.cauchy() + +@arg_mark(plat_marks=['platform_ascend', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_cauchy(mode): + """ + Feature: Cauchy op + Description: Verify the result of cauchy + Expectation: success + """ + ms.set_context(mode=mode) + size = [2, 3] + net = Net(size) + out = net() + assert out.shape == tuple(size) diff --git a/tests/st/ops/test_flip.py b/tests/st/ops/test_flip.py index de2501d63f1..8b1565bf5b6 100644 --- a/tests/st/ops/test_flip.py +++ b/tests/st/ops/test_flip.py @@ -1,47 +1,47 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops - - -class Net(nn.Cell): - def construct(self, x): - output = ops.flip(x, (0, 2)) - return output - - -@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_flip_normal(mode): - """ - Feature: flip - Description: Verify the result of flip - Expectation: success - """ - ms.set_context(mode=mode) - net = Net() - x = ms.Tensor(np.arange(8).reshape((2, 2, 2))) - out = net(x) - expect_out = np.array([[[5., 4.], - [7., 6.]], - [[1., 0.], - [3., 2.]]]) - assert np.allclose(out.asnumpy(), expect_out) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops + + +class Net(nn.Cell): + def construct(self, x): + output = ops.flip(x, (0, 2)) + return output + + +@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_flip_normal(mode): + """ + Feature: flip + Description: Verify the result of flip + Expectation: success + """ + ms.set_context(mode=mode) + net = Net() + x = ms.Tensor(np.arange(8).reshape((2, 2, 2))) + out = net(x) + expect_out = np.array([[[5., 4.], + [7., 6.]], + [[1., 0.], + [3., 2.]]]) + assert np.allclose(out.asnumpy(), expect_out) diff --git a/tests/st/ops/test_fliplr.py b/tests/st/ops/test_fliplr.py index 2e9aed5b7b3..681880c8c55 100644 --- a/tests/st/ops/test_fliplr.py +++ b/tests/st/ops/test_fliplr.py @@ -1,47 +1,47 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops - - -class Net(nn.Cell): - def construct(self, x): - output = ops.fliplr(x) - return output - - -@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_fliplr_normal(mode): - """ - Feature: fliplr - Description: Verify the result of fliplr - Expectation: success - """ - ms.set_context(mode=mode) - net = Net() - x = ms.Tensor(np.arange(8).reshape((2, 2, 2))) - out = net(x) - expect_out = np.array([[[2., 3.], - [0., 1.]], - [[6., 7.], - [4., 5.]]]) - assert np.allclose(out.asnumpy(), expect_out) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops + + +class Net(nn.Cell): + def construct(self, x): + output = ops.fliplr(x) + return output + + +@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_fliplr_normal(mode): + """ + Feature: fliplr + Description: Verify the result of fliplr + Expectation: success + """ + ms.set_context(mode=mode) + net = Net() + x = ms.Tensor(np.arange(8).reshape((2, 2, 2))) + out = net(x) + expect_out = np.array([[[2., 3.], + [0., 1.]], + [[6., 7.], + [4., 5.]]]) + assert np.allclose(out.asnumpy(), expect_out) diff --git a/tests/st/ops/test_flipud.py b/tests/st/ops/test_flipud.py index 34a99c95032..64b3c9f66a0 100644 --- a/tests/st/ops/test_flipud.py +++ b/tests/st/ops/test_flipud.py @@ -1,47 +1,47 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops - - -class Net(nn.Cell): - def construct(self, x): - output = ops.flipud(x) - return output - - -@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_flipud_normal(mode): - """ - Feature: flipud - Description: Verify the result of flipud - Expectation: success - """ - ms.set_context(mode=mode) - net = Net() - x = ms.Tensor(np.arange(8).reshape((2, 2, 2))) - out = net(x) - expect_out = np.array([[[4., 5.], - [6., 7.]], - [[0., 1.], - [2., 3.]]]) - assert np.allclose(out.asnumpy(), expect_out) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops + + +class Net(nn.Cell): + def construct(self, x): + output = ops.flipud(x) + return output + + +@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_flipud_normal(mode): + """ + Feature: flipud + Description: Verify the result of flipud + Expectation: success + """ + ms.set_context(mode=mode) + net = Net() + x = ms.Tensor(np.arange(8).reshape((2, 2, 2))) + out = net(x) + expect_out = np.array([[[4., 5.], + [6., 7.]], + [[0., 1.], + [2., 3.]]]) + assert np.allclose(out.asnumpy(), expect_out) diff --git a/tests/st/ops/test_fractionalmaxpool3dwithfixedksize.py b/tests/st/ops/test_fractionalmaxpool3dwithfixedksize.py index d1246c31562..1dd01116cb2 100644 --- a/tests/st/ops/test_fractionalmaxpool3dwithfixedksize.py +++ b/tests/st/ops/test_fractionalmaxpool3dwithfixedksize.py @@ -1,76 +1,76 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest -import mindspore.nn as nn -import mindspore.context as context -from mindspore import Tensor -import mindspore.ops.operations.nn_ops as ops -import mindspore.ops.operations._grad_ops as grad_ops - - -class NetFractionalMaxPool3DWithFixedKsize(nn.Cell): - def __init__(self, ksize, output_shape): - super(NetFractionalMaxPool3DWithFixedKsize, self).__init__() - self.fractional_max_pool_3d_with_fixed_ksize = ops.FractionalMaxPool3DWithFixedKsize( - ksize, output_shape) - - def construct(self, x, random_sapmples): - return self.fractional_max_pool_3d_with_fixed_ksize(x, random_sapmples) - - -class NetFractionalMaxPool3DGradWithFixedKsize(nn.Cell): - def __init__(self): - super(NetFractionalMaxPool3DGradWithFixedKsize, self).__init__() - self.fractional_max_pool_3d_grad_with_fixed_ksize = grad_ops.FractionalMaxPool3DGradWithFixedKsize() - - def construct(self, origin_input, out_backprop, argmax): - return self.fractional_max_pool_3d_grad_with_fixed_ksize(origin_input, out_backprop, argmax) - - -@arg_mark(plat_marks=['platform_ascend', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_fractionalmaxpool3dwithfixedksize(): - """ - Feature: FractionalMaxPool3DWithFixedKsize - Description: Test of input - Expectation: The results are as expected - """ - context_mode_types = [context.GRAPH_MODE, context.PYNATIVE_MODE] - types_input1 = [np.float16, np.float32, np.int32, np.int64] - types_input2 = [np.float16, np.float32] - for context_mode_type in context_mode_types: - context.set_context(mode=context_mode_type) - for type_input1 in types_input1: - for type_input2 in types_input2: - x_np = np.array([i+1 for i in range(64)] - ).reshape([1, 1, 4, 4, 4]).astype(type_input1) - x_ms = Tensor(x_np) - x_dyn = Tensor(shape=(1, 1, None, None, None), - dtype=x_ms.dtype) - random_samples = Tensor(np.array([0.5, 0.5, 0.8]).reshape( - [1, 1, 3]).astype(type_input2)) - ksize = (1, 1, 1) - output_shape = (2, 2, 3) - net = NetFractionalMaxPool3DWithFixedKsize(ksize, output_shape) - net.set_inputs(x_dyn, random_samples) - output_ms, argmax = net(x_ms, random_samples) - expect_output = np.array([[[[[1, 2, 4], [13, 14, 16]], - [[49, 50, 52], [61, 62, 64]]]]]).astype(type_input1) - expect_output_argmax = np.array([[[[[0, 1, 3], [12, 13, 15]], - [[48, 49, 51], [60, 61, 63]]]]]).astype(type_input2) - assert np.allclose(output_ms.asnumpy(), expect_output) - assert np.allclose(argmax.asnumpy(), expect_output_argmax) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest +import mindspore.nn as nn +import mindspore.context as context +from mindspore import Tensor +import mindspore.ops.operations.nn_ops as ops +import mindspore.ops.operations._grad_ops as grad_ops + + +class NetFractionalMaxPool3DWithFixedKsize(nn.Cell): + def __init__(self, ksize, output_shape): + super(NetFractionalMaxPool3DWithFixedKsize, self).__init__() + self.fractional_max_pool_3d_with_fixed_ksize = ops.FractionalMaxPool3DWithFixedKsize( + ksize, output_shape) + + def construct(self, x, random_sapmples): + return self.fractional_max_pool_3d_with_fixed_ksize(x, random_sapmples) + + +class NetFractionalMaxPool3DGradWithFixedKsize(nn.Cell): + def __init__(self): + super(NetFractionalMaxPool3DGradWithFixedKsize, self).__init__() + self.fractional_max_pool_3d_grad_with_fixed_ksize = grad_ops.FractionalMaxPool3DGradWithFixedKsize() + + def construct(self, origin_input, out_backprop, argmax): + return self.fractional_max_pool_3d_grad_with_fixed_ksize(origin_input, out_backprop, argmax) + + +@arg_mark(plat_marks=['platform_ascend', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_fractionalmaxpool3dwithfixedksize(): + """ + Feature: FractionalMaxPool3DWithFixedKsize + Description: Test of input + Expectation: The results are as expected + """ + context_mode_types = [context.GRAPH_MODE, context.PYNATIVE_MODE] + types_input1 = [np.float16, np.float32, np.int32, np.int64] + types_input2 = [np.float16, np.float32] + for context_mode_type in context_mode_types: + context.set_context(mode=context_mode_type) + for type_input1 in types_input1: + for type_input2 in types_input2: + x_np = np.array([i+1 for i in range(64)] + ).reshape([1, 1, 4, 4, 4]).astype(type_input1) + x_ms = Tensor(x_np) + x_dyn = Tensor(shape=(1, 1, None, None, None), + dtype=x_ms.dtype) + random_samples = Tensor(np.array([0.5, 0.5, 0.8]).reshape( + [1, 1, 3]).astype(type_input2)) + ksize = (1, 1, 1) + output_shape = (2, 2, 3) + net = NetFractionalMaxPool3DWithFixedKsize(ksize, output_shape) + net.set_inputs(x_dyn, random_samples) + output_ms, argmax = net(x_ms, random_samples) + expect_output = np.array([[[[[1, 2, 4], [13, 14, 16]], + [[49, 50, 52], [61, 62, 64]]]]]).astype(type_input1) + expect_output_argmax = np.array([[[[[0, 1, 3], [12, 13, 15]], + [[48, 49, 51], [60, 61, 63]]]]]).astype(type_input2) + assert np.allclose(output_ms.asnumpy(), expect_output) + assert np.allclose(argmax.asnumpy(), expect_output_argmax) diff --git a/tests/st/ops/test_func_addbmm.py b/tests/st/ops/test_func_addbmm.py index e041c377bc1..d13305d9146 100644 --- a/tests/st/ops/test_func_addbmm.py +++ b/tests/st/ops/test_func_addbmm.py @@ -1,43 +1,43 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore.nn as nn -import mindspore.ops as ops -from mindspore import Tensor - - -class Net(nn.Cell): - def construct(self, x, batch1, batch2, alpha=0.1, beta=0.5): - output = ops.addbmm(x, batch1, batch2, alpha=alpha, beta=beta) - return output - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_real_normal(): - """ - Feature: ops.addbmm - Description: Test 4D input - Expectation: raise ValueError - """ - x = Tensor(np.random.randn(6, 8).astype(np.float32)) - b1 = Tensor(np.random.randn(12, 10, 6, 4).astype(np.float32)) - b2 = Tensor(np.random.randn(12, 8, 4, 8).astype(np.float32)) - net = Net() - with pytest.raises(ValueError): - net(x, b1, b2) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor + + +class Net(nn.Cell): + def construct(self, x, batch1, batch2, alpha=0.1, beta=0.5): + output = ops.addbmm(x, batch1, batch2, alpha=alpha, beta=beta) + return output + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_real_normal(): + """ + Feature: ops.addbmm + Description: Test 4D input + Expectation: raise ValueError + """ + x = Tensor(np.random.randn(6, 8).astype(np.float32)) + b1 = Tensor(np.random.randn(12, 10, 6, 4).astype(np.float32)) + b2 = Tensor(np.random.randn(12, 8, 4, 8).astype(np.float32)) + net = Net() + with pytest.raises(ValueError): + net(x, b1, b2) diff --git a/tests/st/ops/test_func_arange.py b/tests/st/ops/test_func_arange.py index 2b57c690531..0907976530c 100644 --- a/tests/st/ops/test_func_arange.py +++ b/tests/st/ops/test_func_arange.py @@ -1,48 +1,48 @@ -# Copyright 2022-2024 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops - - -class Net(nn.Cell): - def construct(self, start=0, end=None, step=1, dtype=None): - return ops.arange(start, end, step, dtype=dtype) - - -@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_arange_normal(mode): - """ - Feature: arange - Description: Verify the result of arange - Expectation: success - """ - ms.set_context(mode=mode) - net = Net() - output1 = net(1, 6) - output2 = net(10.0, dtype=ms.int32) - output3 = net(ms.Tensor(12.0, dtype=ms.float64), 2, ms.Tensor(-1.0, dtype=ms.float32)) - assert np.allclose(output1.asnumpy(), np.array([1., 2., 3., 4., 5.])) - assert output1.dtype == ms.int64 - assert np.allclose(output2.asnumpy(), np.array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])) - assert output2.dtype == ms.int32 - assert np.allclose(output3.asnumpy(), np.array([12., 11., 10., 9., 8., 7., 6., 5., 4., 3.])) - assert output3.dtype == ms.float32 +# Copyright 2022-2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops + + +class Net(nn.Cell): + def construct(self, start=0, end=None, step=1, dtype=None): + return ops.arange(start, end, step, dtype=dtype) + + +@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_arange_normal(mode): + """ + Feature: arange + Description: Verify the result of arange + Expectation: success + """ + ms.set_context(mode=mode) + net = Net() + output1 = net(1, 6) + output2 = net(10.0, dtype=ms.int32) + output3 = net(ms.Tensor(12.0, dtype=ms.float64), 2, ms.Tensor(-1.0, dtype=ms.float32)) + assert np.allclose(output1.asnumpy(), np.array([1., 2., 3., 4., 5.])) + assert output1.dtype == ms.int64 + assert np.allclose(output2.asnumpy(), np.array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])) + assert output2.dtype == ms.int32 + assert np.allclose(output3.asnumpy(), np.array([12., 11., 10., 9., 8., 7., 6., 5., 4., 3.])) + assert output3.dtype == ms.float32 diff --git a/tests/st/ops/test_func_chunk.py b/tests/st/ops/test_func_chunk.py index 22ae6565970..8754259a6ff 100644 --- a/tests/st/ops/test_func_chunk.py +++ b/tests/st/ops/test_func_chunk.py @@ -1,53 +1,53 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops -from mindspore import Tensor - - -class Net(nn.Cell): - def construct(self, x, chunks, axis): - output = ops.chunk(x, chunks, axis) - return output - - -@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_chunk_normal(mode): - """ - Feature: ops.chunk - Description: Verify the result of chunk - Expectation: success - """ - ms.set_context(mode=mode) - net = Net() - x = Tensor([[[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]]) - chunks = 6 - axis = 1 - out = net(x, chunks, axis) - expect_out_1 = np.array([[[[0, 1, 2, 3], - [4, 5, 6, 7], - [8, 9, 10, 11]]]]) - expect_out_2 = np.array([[[[0, 1, 2, 3], - [4, 5, 6, 7], - [8, 9, 10, 11]]]]) - assert np.allclose(out[0].asnumpy(), expect_out_1) - assert np.allclose(out[1].asnumpy(), expect_out_2) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor + + +class Net(nn.Cell): + def construct(self, x, chunks, axis): + output = ops.chunk(x, chunks, axis) + return output + + +@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_chunk_normal(mode): + """ + Feature: ops.chunk + Description: Verify the result of chunk + Expectation: success + """ + ms.set_context(mode=mode) + net = Net() + x = Tensor([[[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]]) + chunks = 6 + axis = 1 + out = net(x, chunks, axis) + expect_out_1 = np.array([[[[0, 1, 2, 3], + [4, 5, 6, 7], + [8, 9, 10, 11]]]]) + expect_out_2 = np.array([[[[0, 1, 2, 3], + [4, 5, 6, 7], + [8, 9, 10, 11]]]]) + assert np.allclose(out[0].asnumpy(), expect_out_1) + assert np.allclose(out[1].asnumpy(), expect_out_2) diff --git a/tests/st/ops/test_func_full.py b/tests/st/ops/test_func_full.py index c69b082a13d..711d922b36c 100644 --- a/tests/st/ops/test_func_full.py +++ b/tests/st/ops/test_func_full.py @@ -1,46 +1,46 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops - - -class Net(nn.Cell): - def construct(self, size, fill_value): - output = ops.full(size, fill_value) - return output - - -@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_full_normal(mode): - """ - Feature: ops.full - Description: Verify the result of ops.full - Expectation: success - """ - ms.set_context(mode=mode) - net = Net() - size = (1, 2, 3) - fill_value = 11 - out = net(size, fill_value) - expect_out = np.array([[[11, 11, 11], - [11, 11, 11]]]) - assert np.allclose(out.asnumpy(), expect_out) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops + + +class Net(nn.Cell): + def construct(self, size, fill_value): + output = ops.full(size, fill_value) + return output + + +@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_full_normal(mode): + """ + Feature: ops.full + Description: Verify the result of ops.full + Expectation: success + """ + ms.set_context(mode=mode) + net = Net() + size = (1, 2, 3) + fill_value = 11 + out = net(size, fill_value) + expect_out = np.array([[[11, 11, 11], + [11, 11, 11]]]) + assert np.allclose(out.asnumpy(), expect_out) diff --git a/tests/st/ops/test_func_full_like.py b/tests/st/ops/test_func_full_like.py index a6fd7826b2c..e7bb7b3a5c2 100644 --- a/tests/st/ops/test_func_full_like.py +++ b/tests/st/ops/test_func_full_like.py @@ -1,46 +1,46 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops -from mindspore import Tensor - - -class Net(nn.Cell): - def construct(self, x, fill_value): - output = ops.full_like(x, fill_value) - return output - - -@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_full_normal(mode): - """ - Feature: ops.full - Description: Verify the result of ops.full - Expectation: success - """ - ms.set_context(mode=mode) - net = Net() - x = Tensor([[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]) - fill_value = 11 - out = net(x, fill_value) - expect_out = np.array([[[11, 11, 11, 11], [11, 11, 11, 11], [11, 11, 11, 11]]]) - assert np.allclose(out.asnumpy(), expect_out) +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor + + +class Net(nn.Cell): + def construct(self, x, fill_value): + output = ops.full_like(x, fill_value) + return output + + +@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_full_normal(mode): + """ + Feature: ops.full + Description: Verify the result of ops.full + Expectation: success + """ + ms.set_context(mode=mode) + net = Net() + x = Tensor([[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]) + fill_value = 11 + out = net(x, fill_value) + expect_out = np.array([[[11, 11, 11, 11], [11, 11, 11, 11], [11, 11, 11, 11]]]) + assert np.allclose(out.asnumpy(), expect_out) diff --git a/tests/st/ops/test_func_is_tensor.py b/tests/st/ops/test_func_is_tensor.py index 766ea401a58..e3863a62ba2 100644 --- a/tests/st/ops/test_func_is_tensor.py +++ b/tests/st/ops/test_func_is_tensor.py @@ -1,45 +1,45 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import pytest - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops -from mindspore import Tensor - - -class Net(nn.Cell): - def construct(self, x): - return ops.is_tensor(x) - - -@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_is_tensor(mode): - """ - Feature: ops.is_tensor - Description: Verify the result of ops.is_tensor - Expectation: success - """ - ms.set_context(mode=mode) - net = Net() - a = Tensor([1, 2]) - output1 = net(a) - assert output1 - b = [1, 2] - output2 = net(b) - assert not output2 +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import pytest + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor + + +class Net(nn.Cell): + def construct(self, x): + return ops.is_tensor(x) + + +@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_is_tensor(mode): + """ + Feature: ops.is_tensor + Description: Verify the result of ops.is_tensor + Expectation: success + """ + ms.set_context(mode=mode) + net = Net() + a = Tensor([1, 2]) + output1 = net(a) + assert output1 + b = [1, 2] + output2 = net(b) + assert not output2 diff --git a/tests/st/ops/test_func_logsigmoid.py b/tests/st/ops/test_func_logsigmoid.py index 1581dd0f39b..696b423c91e 100644 --- a/tests/st/ops/test_func_logsigmoid.py +++ b/tests/st/ops/test_func_logsigmoid.py @@ -1,44 +1,44 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops - - -class Net(nn.Cell): - def construct(self, x): - output = ops.logsigmoid(x) - return output - - -@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_logsigmoid(mode): - """ - Feature: logsigmoid - Description: Verify the result of logsigmoid - Expectation: success - """ - ms.set_context(mode=mode) - net = Net() - x = ms.Tensor(np.array([1.0, 2.0, 3.0]), ms.float32) - out = net(x) - expect_out = np.array([-0.31326166, -0.12692806, -0.04858734]) - assert np.allclose(out.asnumpy(), expect_out) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops + + +class Net(nn.Cell): + def construct(self, x): + output = ops.logsigmoid(x) + return output + + +@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_logsigmoid(mode): + """ + Feature: logsigmoid + Description: Verify the result of logsigmoid + Expectation: success + """ + ms.set_context(mode=mode) + net = Net() + x = ms.Tensor(np.array([1.0, 2.0, 3.0]), ms.float32) + out = net(x) + expect_out = np.array([-0.31326166, -0.12692806, -0.04858734]) + assert np.allclose(out.asnumpy(), expect_out) diff --git a/tests/st/ops/test_func_mm.py b/tests/st/ops/test_func_mm.py index 5f885a59c81..f786388f3af 100644 --- a/tests/st/ops/test_func_mm.py +++ b/tests/st/ops/test_func_mm.py @@ -1,46 +1,46 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops - - -class Net(nn.Cell): - def construct(self, x1, x2): - output = ops.mm(x1, x2) - return output - - -@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_mm_normal(mode): - """ - Feature: mm - Description: Verify the result of mm - Expectation: success - """ - ms.set_context(mode=mode) - net = Net() - x1 = ms.Tensor(np.arange(6).reshape((2, 3)), dtype=ms.float32) - x2 = ms.Tensor(np.arange(12).reshape((3, 4)), dtype=ms.float32) - out = net(x1, x2) - expect_out = np.array([[20, 23, 26, 29], - [56, 68, 80, 92]]) - assert np.allclose(out.asnumpy(), expect_out) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops + + +class Net(nn.Cell): + def construct(self, x1, x2): + output = ops.mm(x1, x2) + return output + + +@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_mm_normal(mode): + """ + Feature: mm + Description: Verify the result of mm + Expectation: success + """ + ms.set_context(mode=mode) + net = Net() + x1 = ms.Tensor(np.arange(6).reshape((2, 3)), dtype=ms.float32) + x2 = ms.Tensor(np.arange(12).reshape((3, 4)), dtype=ms.float32) + out = net(x1, x2) + expect_out = np.array([[20, 23, 26, 29], + [56, 68, 80, 92]]) + assert np.allclose(out.asnumpy(), expect_out) diff --git a/tests/st/ops/test_func_msort.py b/tests/st/ops/test_func_msort.py index 90743b74a60..f368026c6f2 100644 --- a/tests/st/ops/test_func_msort.py +++ b/tests/st/ops/test_func_msort.py @@ -1,46 +1,46 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops - - -class Net(nn.Cell): - def construct(self, x): - output = ops.msort(x) - return output - - -@arg_mark(plat_marks=['platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_msort_normal(mode): - """ - Feature: msort - Description: Verify the result of msort - Expectation: success - """ - ms.set_context(mode=mode) - net = Net() - x = ms.Tensor(np.array([[8, 2, 1], [5, 9, 3], [4, 6, 7]]), ms.float16) - out = net(x) - expect_out = np.array([[4., 2., 1.], - [5., 6., 3.], - [8., 9., 7.]]) - assert np.allclose(out.asnumpy(), expect_out) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops + + +class Net(nn.Cell): + def construct(self, x): + output = ops.msort(x) + return output + + +@arg_mark(plat_marks=['platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_msort_normal(mode): + """ + Feature: msort + Description: Verify the result of msort + Expectation: success + """ + ms.set_context(mode=mode) + net = Net() + x = ms.Tensor(np.array([[8, 2, 1], [5, 9, 3], [4, 6, 7]]), ms.float16) + out = net(x) + expect_out = np.array([[4., 2., 1.], + [5., 6., 3.], + [8., 9., 7.]]) + assert np.allclose(out.asnumpy(), expect_out) diff --git a/tests/st/ops/test_func_nan_to_num.py b/tests/st/ops/test_func_nan_to_num.py index 8cfe0350b1b..40a33cad64d 100644 --- a/tests/st/ops/test_func_nan_to_num.py +++ b/tests/st/ops/test_func_nan_to_num.py @@ -1,44 +1,44 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops - - -class Net(nn.Cell): - def construct(self, x, nan, posinf, neginf): - output = ops.nan_to_num(x, nan, posinf, neginf) - return output - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_nan_to_num_normal(mode): - """ - Feature: nan_to_num - Description: Verify the result of nan_to_num - Expectation: success - """ - ms.set_context(mode=mode) - net = Net() - x = ms.Tensor(np.array([float('nan'), float('inf'), -float('inf'), 3.14]), ms.float32) - out = net(x, 1.0, 2.0, 3.0) - expect_out = np.array([1., 2., 3., 3.14]) - assert np.allclose(out.asnumpy(), expect_out) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops + + +class Net(nn.Cell): + def construct(self, x, nan, posinf, neginf): + output = ops.nan_to_num(x, nan, posinf, neginf) + return output + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_nan_to_num_normal(mode): + """ + Feature: nan_to_num + Description: Verify the result of nan_to_num + Expectation: success + """ + ms.set_context(mode=mode) + net = Net() + x = ms.Tensor(np.array([float('nan'), float('inf'), -float('inf'), 3.14]), ms.float32) + out = net(x, 1.0, 2.0, 3.0) + expect_out = np.array([1., 2., 3., 3.14]) + assert np.allclose(out.asnumpy(), expect_out) diff --git a/tests/st/ops/test_func_pad.py b/tests/st/ops/test_func_pad.py index 10e790865f6..a485cd9c38c 100644 --- a/tests/st/ops/test_func_pad.py +++ b/tests/st/ops/test_func_pad.py @@ -1,157 +1,157 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops - - -class Net(nn.Cell): - def __init__(self, mode): - super(Net, self).__init__() - self.mode = mode - - def construct(self, x, padding, value=None): - output = ops.pad(x, padding, self.mode, value) - return output - - -@arg_mark(plat_marks=['platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -@pytest.mark.parametrize('pad_mode', ["constant", "reflect", "replicate"]) -@pytest.mark.parametrize('padding', [[1, 2, 2, 1], (1, 2, 2, 1), ms.Tensor([1, 2, 2, 1])]) -def test_pad_normal(mode, pad_mode, padding): - """ - Feature: pad - Description: Verify the result of pad - Expectation: success - """ - ms.set_context(mode=mode) - net = Net(pad_mode) - x = ms.Tensor(np.arange(1 * 2 * 3 * 4).reshape((1, 2, 3, 4)), dtype=ms.float64) - - if pad_mode == "constant": - output = net(x, padding, 6) - expect_output = np.array([[[[6., 6., 6., 6., 6., 6., 6.], - [6., 6., 6., 6., 6., 6., 6.], - [6., 0., 1., 2., 3., 6., 6.], - [6., 4., 5., 6., 7., 6., 6.], - [6., 8., 9., 10., 11., 6., 6.], - [6., 6., 6., 6., 6., 6., 6.]], - - [[6., 6., 6., 6., 6., 6., 6.], - [6., 6., 6., 6., 6., 6., 6.], - [6., 12., 13., 14., 15., 6., 6.], - [6., 16., 17., 18., 19., 6., 6.], - [6., 20., 21., 22., 23., 6., 6.], - [6., 6., 6., 6., 6., 6., 6.]]]]) - elif pad_mode == "reflect": - output = net(x, padding) - expect_output = np.array([[[[9., 8., 9., 10., 11., 10., 9.], - [5., 4., 5., 6., 7., 6., 5.], - [1., 0., 1., 2., 3., 2., 1.], - [5., 4., 5., 6., 7., 6., 5.], - [9., 8., 9., 10., 11., 10., 9.], - [5., 4., 5., 6., 7., 6., 5.]], - - [[21., 20., 21., 22., 23., 22., 21.], - [17., 16., 17., 18., 19., 18., 17.], - [13., 12., 13., 14., 15., 14., 13.], - [17., 16., 17., 18., 19., 18., 17.], - [21., 20., 21., 22., 23., 22., 21.], - [17., 16., 17., 18., 19., 18., 17.]]]]) - else: - output = net(x, padding) - expect_output = np.array([[[[0., 0., 1., 2., 3., 3., 3.], - [0., 0., 1., 2., 3., 3., 3.], - [0., 0., 1., 2., 3., 3., 3.], - [4., 4., 5., 6., 7., 7., 7.], - [8., 8., 9., 10., 11., 11., 11.], - [8., 8., 9., 10., 11., 11., 11.]], - - [[12., 12., 13., 14., 15., 15., 15.], - [12., 12., 13., 14., 15., 15., 15.], - [12., 12., 13., 14., 15., 15., 15.], - [16., 16., 17., 18., 19., 19., 19.], - [20., 20., 21., 22., 23., 23., 23.], - [20., 20., 21., 22., 23., 23., 23.]]]]) - assert np.allclose(output.asnumpy(), expect_output) - - -@arg_mark(plat_marks=['platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -@pytest.mark.parametrize('pad_mode', ["constant", "reflect", "replicate"]) -@pytest.mark.parametrize('padding', [[-1, 2, 2, 1]]) -def test_pad_negative(mode, pad_mode, padding): - """ - Feature: pad - Description: Verify the result of pad when padding is negative - Expectation: success - """ - ms.set_context(mode=mode) - net = Net(pad_mode) - x = ms.Tensor(np.arange(1 * 2 * 3 * 4).reshape((1, 2, 3, 4)), dtype=ms.float64) - - if pad_mode == "constant": - output = net(x, padding, 6) - expect_output = np.array([[[[6., 6., 6., 6., 6.], - [6., 6., 6., 6., 6.], - [1., 2., 3., 6., 6.], - [5., 6., 7., 6., 6.], - [9., 10., 11., 6., 6.], - [6., 6., 6., 6., 6.]], - - [[6., 6., 6., 6., 6.], - [6., 6., 6., 6., 6.], - [13., 14., 15., 6., 6.], - [17., 18., 19., 6., 6.], - [21., 22., 23., 6., 6.], - [6., 6., 6., 6., 6.]]]]) - elif pad_mode == "reflect": - output = net(x, padding) - expect_output = np.array([[[[9., 10., 11., 10., 9.], - [5., 6., 7., 6., 5.], - [1., 2., 3., 2., 1.], - [5., 6., 7., 6., 5.], - [9., 10., 11., 10., 9.], - [5., 6., 7., 6., 5.]], - - [[21., 22., 23., 22., 21.], - [17., 18., 19., 18., 17.], - [13., 14., 15., 14., 13.], - [17., 18., 19., 18., 17.], - [21., 22., 23., 22., 21.], - [17., 18., 19., 18., 17.]]]]) - - else: - output = net(x, padding) - expect_output = np.array([[[[1., 2., 3., 3., 3.], - [1., 2., 3., 3., 3.], - [1., 2., 3., 3., 3.], - [5., 6., 7., 7., 7.], - [9., 10., 11., 11., 11.], - [9., 10., 11., 11., 11.]], - - [[13., 14., 15., 15., 15.], - [13., 14., 15., 15., 15.], - [13., 14., 15., 15., 15.], - [17., 18., 19., 19., 19.], - [21., 22., 23., 23., 23.], - [21., 22., 23., 23., 23.]]]]) - assert np.allclose(output.asnumpy(), expect_output) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops + + +class Net(nn.Cell): + def __init__(self, mode): + super(Net, self).__init__() + self.mode = mode + + def construct(self, x, padding, value=None): + output = ops.pad(x, padding, self.mode, value) + return output + + +@arg_mark(plat_marks=['platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +@pytest.mark.parametrize('pad_mode', ["constant", "reflect", "replicate"]) +@pytest.mark.parametrize('padding', [[1, 2, 2, 1], (1, 2, 2, 1), ms.Tensor([1, 2, 2, 1])]) +def test_pad_normal(mode, pad_mode, padding): + """ + Feature: pad + Description: Verify the result of pad + Expectation: success + """ + ms.set_context(mode=mode) + net = Net(pad_mode) + x = ms.Tensor(np.arange(1 * 2 * 3 * 4).reshape((1, 2, 3, 4)), dtype=ms.float64) + + if pad_mode == "constant": + output = net(x, padding, 6) + expect_output = np.array([[[[6., 6., 6., 6., 6., 6., 6.], + [6., 6., 6., 6., 6., 6., 6.], + [6., 0., 1., 2., 3., 6., 6.], + [6., 4., 5., 6., 7., 6., 6.], + [6., 8., 9., 10., 11., 6., 6.], + [6., 6., 6., 6., 6., 6., 6.]], + + [[6., 6., 6., 6., 6., 6., 6.], + [6., 6., 6., 6., 6., 6., 6.], + [6., 12., 13., 14., 15., 6., 6.], + [6., 16., 17., 18., 19., 6., 6.], + [6., 20., 21., 22., 23., 6., 6.], + [6., 6., 6., 6., 6., 6., 6.]]]]) + elif pad_mode == "reflect": + output = net(x, padding) + expect_output = np.array([[[[9., 8., 9., 10., 11., 10., 9.], + [5., 4., 5., 6., 7., 6., 5.], + [1., 0., 1., 2., 3., 2., 1.], + [5., 4., 5., 6., 7., 6., 5.], + [9., 8., 9., 10., 11., 10., 9.], + [5., 4., 5., 6., 7., 6., 5.]], + + [[21., 20., 21., 22., 23., 22., 21.], + [17., 16., 17., 18., 19., 18., 17.], + [13., 12., 13., 14., 15., 14., 13.], + [17., 16., 17., 18., 19., 18., 17.], + [21., 20., 21., 22., 23., 22., 21.], + [17., 16., 17., 18., 19., 18., 17.]]]]) + else: + output = net(x, padding) + expect_output = np.array([[[[0., 0., 1., 2., 3., 3., 3.], + [0., 0., 1., 2., 3., 3., 3.], + [0., 0., 1., 2., 3., 3., 3.], + [4., 4., 5., 6., 7., 7., 7.], + [8., 8., 9., 10., 11., 11., 11.], + [8., 8., 9., 10., 11., 11., 11.]], + + [[12., 12., 13., 14., 15., 15., 15.], + [12., 12., 13., 14., 15., 15., 15.], + [12., 12., 13., 14., 15., 15., 15.], + [16., 16., 17., 18., 19., 19., 19.], + [20., 20., 21., 22., 23., 23., 23.], + [20., 20., 21., 22., 23., 23., 23.]]]]) + assert np.allclose(output.asnumpy(), expect_output) + + +@arg_mark(plat_marks=['platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +@pytest.mark.parametrize('pad_mode', ["constant", "reflect", "replicate"]) +@pytest.mark.parametrize('padding', [[-1, 2, 2, 1]]) +def test_pad_negative(mode, pad_mode, padding): + """ + Feature: pad + Description: Verify the result of pad when padding is negative + Expectation: success + """ + ms.set_context(mode=mode) + net = Net(pad_mode) + x = ms.Tensor(np.arange(1 * 2 * 3 * 4).reshape((1, 2, 3, 4)), dtype=ms.float64) + + if pad_mode == "constant": + output = net(x, padding, 6) + expect_output = np.array([[[[6., 6., 6., 6., 6.], + [6., 6., 6., 6., 6.], + [1., 2., 3., 6., 6.], + [5., 6., 7., 6., 6.], + [9., 10., 11., 6., 6.], + [6., 6., 6., 6., 6.]], + + [[6., 6., 6., 6., 6.], + [6., 6., 6., 6., 6.], + [13., 14., 15., 6., 6.], + [17., 18., 19., 6., 6.], + [21., 22., 23., 6., 6.], + [6., 6., 6., 6., 6.]]]]) + elif pad_mode == "reflect": + output = net(x, padding) + expect_output = np.array([[[[9., 10., 11., 10., 9.], + [5., 6., 7., 6., 5.], + [1., 2., 3., 2., 1.], + [5., 6., 7., 6., 5.], + [9., 10., 11., 10., 9.], + [5., 6., 7., 6., 5.]], + + [[21., 22., 23., 22., 21.], + [17., 18., 19., 18., 17.], + [13., 14., 15., 14., 13.], + [17., 18., 19., 18., 17.], + [21., 22., 23., 22., 21.], + [17., 18., 19., 18., 17.]]]]) + + else: + output = net(x, padding) + expect_output = np.array([[[[1., 2., 3., 3., 3.], + [1., 2., 3., 3., 3.], + [1., 2., 3., 3., 3.], + [5., 6., 7., 7., 7.], + [9., 10., 11., 11., 11.], + [9., 10., 11., 11., 11.]], + + [[13., 14., 15., 15., 15.], + [13., 14., 15., 15., 15.], + [13., 14., 15., 15., 15.], + [17., 18., 19., 19., 19.], + [21., 22., 23., 23., 23.], + [21., 22., 23., 23., 23.]]]]) + assert np.allclose(output.asnumpy(), expect_output) diff --git a/tests/st/ops/test_func_random_functions.py b/tests/st/ops/test_func_random_functions.py index 96998dce5e8..182d0c40641 100644 --- a/tests/st/ops/test_func_random_functions.py +++ b/tests/st/ops/test_func_random_functions.py @@ -1,116 +1,116 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops - - -class Rand(nn.Cell): - def construct(self, size, dtype): - return ops.rand(size, dtype=dtype) - - -class RandLike(nn.Cell): - def construct(self, x, dtype): - return ops.rand_like(x, dtype=dtype) - - -class Randn(nn.Cell): - def construct(self, size, dtype): - return ops.randn(size, dtype=dtype) - - -class RandnLike(nn.Cell): - def construct(self, x, dtype): - return ops.randn_like(x, dtype=dtype) - - -class RandInt(nn.Cell): - def construct(self, low, high, size, dtype): - return ops.randint(low, high, size, dtype=dtype) - - -class RandIntLike(nn.Cell): - def construct(self, x, low, high, dtype): - return ops.randint_like(x, low, high, dtype=dtype) - - -@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -@pytest.mark.parametrize('dtype', [None, ms.float32]) -def test_rand_functions(mode, dtype): - r""" - Feature: ops.rand, ops.randn, ops.rand_like, ops.randn_like - Description: Verify the result of ops.rand, ops.randn, ops.rand_like, ops.randn_like - Expectation: success - """ - ms.set_context(mode=mode) - x = ms.Tensor(np.array([[8, 2, 1], [5, 9, 3], [4, 6, 7]]), ms.float16) - size = (2, 3) - net1 = Rand() - net2 = Randn() - net3 = RandLike() - net4 = RandnLike() - out1 = net1(size, dtype) - out2 = net2(size, dtype) - out3 = net3(x, dtype) - out4 = net4(x, dtype) - if dtype is None: - assert out1.dtype == ms.float32 - assert out2.dtype == ms.float32 - assert out3.dtype == ms.float16 - assert out4.dtype == ms.float32 - else: - assert out1.dtype == dtype - assert out2.dtype == dtype - assert out3.dtype == dtype - assert out4.dtype == dtype - - assert out1.shape == size - assert out2.shape == size - assert out3.shape == x.shape - assert out4.shape == x.shape - - -@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -@pytest.mark.parametrize('dtype', [None, ms.int32]) -def test_randint_functions(mode, dtype): - r""" - Feature: ops.randint, ops.randint_like - Description: Verify the result of ops.randint, ops.randint_like - Expectation: success - """ - ms.set_context(mode=mode) - x = ms.Tensor(np.array([[8, 2, 1], [5, 9, 3], [4, 6, 7]]), ms.int32) - net = RandInt() - net2 = RandIntLike() - out = net(0, 10, (2, 3), dtype=dtype) - out2 = net2(x, low=0, high=15, dtype=dtype) - if dtype is None: - assert out.dtype == ms.int64 - assert out2.dtype == ms.int32 - else: - assert out.dtype == dtype - assert out2.dtype == dtype - assert out.shape == (2, 3) - assert out2.shape == x.shape - assert out.max() < 10 and out.min() >= 0 - assert out2.max() < 15 and out2.min() >= 0 +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops + + +class Rand(nn.Cell): + def construct(self, size, dtype): + return ops.rand(size, dtype=dtype) + + +class RandLike(nn.Cell): + def construct(self, x, dtype): + return ops.rand_like(x, dtype=dtype) + + +class Randn(nn.Cell): + def construct(self, size, dtype): + return ops.randn(size, dtype=dtype) + + +class RandnLike(nn.Cell): + def construct(self, x, dtype): + return ops.randn_like(x, dtype=dtype) + + +class RandInt(nn.Cell): + def construct(self, low, high, size, dtype): + return ops.randint(low, high, size, dtype=dtype) + + +class RandIntLike(nn.Cell): + def construct(self, x, low, high, dtype): + return ops.randint_like(x, low, high, dtype=dtype) + + +@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +@pytest.mark.parametrize('dtype', [None, ms.float32]) +def test_rand_functions(mode, dtype): + r""" + Feature: ops.rand, ops.randn, ops.rand_like, ops.randn_like + Description: Verify the result of ops.rand, ops.randn, ops.rand_like, ops.randn_like + Expectation: success + """ + ms.set_context(mode=mode) + x = ms.Tensor(np.array([[8, 2, 1], [5, 9, 3], [4, 6, 7]]), ms.float16) + size = (2, 3) + net1 = Rand() + net2 = Randn() + net3 = RandLike() + net4 = RandnLike() + out1 = net1(size, dtype) + out2 = net2(size, dtype) + out3 = net3(x, dtype) + out4 = net4(x, dtype) + if dtype is None: + assert out1.dtype == ms.float32 + assert out2.dtype == ms.float32 + assert out3.dtype == ms.float16 + assert out4.dtype == ms.float32 + else: + assert out1.dtype == dtype + assert out2.dtype == dtype + assert out3.dtype == dtype + assert out4.dtype == dtype + + assert out1.shape == size + assert out2.shape == size + assert out3.shape == x.shape + assert out4.shape == x.shape + + +@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +@pytest.mark.parametrize('dtype', [None, ms.int32]) +def test_randint_functions(mode, dtype): + r""" + Feature: ops.randint, ops.randint_like + Description: Verify the result of ops.randint, ops.randint_like + Expectation: success + """ + ms.set_context(mode=mode) + x = ms.Tensor(np.array([[8, 2, 1], [5, 9, 3], [4, 6, 7]]), ms.int32) + net = RandInt() + net2 = RandIntLike() + out = net(0, 10, (2, 3), dtype=dtype) + out2 = net2(x, low=0, high=15, dtype=dtype) + if dtype is None: + assert out.dtype == ms.int64 + assert out2.dtype == ms.int32 + else: + assert out.dtype == dtype + assert out2.dtype == dtype + assert out.shape == (2, 3) + assert out2.shape == x.shape + assert out.max() < 10 and out.min() >= 0 + assert out2.max() < 15 and out2.min() >= 0 diff --git a/tests/st/ops/test_func_real.py b/tests/st/ops/test_func_real.py index 6a85e07666e..92d142f01ee 100644 --- a/tests/st/ops/test_func_real.py +++ b/tests/st/ops/test_func_real.py @@ -1,44 +1,44 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops - - -class Net(nn.Cell): - def construct(self, x): - output = ops.real(x) - return output - - -@arg_mark(plat_marks=['platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_real_normal(mode): - """ - Feature: real - Description: Verify the result of real - Expectation: success - """ - ms.set_context(mode=mode) - net = Net() - x = ms.Tensor(np.asarray(np.complex(1.3+0.4j)), ms.complex64) - out = net(x) - expect_out = np.array(1.3) - assert np.allclose(out.asnumpy(), expect_out) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops + + +class Net(nn.Cell): + def construct(self, x): + output = ops.real(x) + return output + + +@arg_mark(plat_marks=['platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_real_normal(mode): + """ + Feature: real + Description: Verify the result of real + Expectation: success + """ + ms.set_context(mode=mode) + net = Net() + x = ms.Tensor(np.asarray(np.complex(1.3+0.4j)), ms.complex64) + out = net(x) + expect_out = np.array(1.3) + assert np.allclose(out.asnumpy(), expect_out) diff --git a/tests/st/ops/test_func_reciprocal.py b/tests/st/ops/test_func_reciprocal.py index 4c23df33891..41c5d2968ac 100644 --- a/tests/st/ops/test_func_reciprocal.py +++ b/tests/st/ops/test_func_reciprocal.py @@ -1,44 +1,44 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops - - -class Net(nn.Cell): - def construct(self, x): - output = ops.reciprocal(x) - return output - - -@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_reciprocal_normal(mode): - """ - Feature: reciprocal - Description: Verify the result of reciprocal - Expectation: success - """ - ms.set_context(mode=mode) - net = Net() - x = ms.Tensor(np.array([1.0, 2.0, 4.0]), ms.float32) - out = net(x) - expect_out = np.array([1., 0.5, 0.25]) - assert np.allclose(out.asnumpy(), expect_out) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops + + +class Net(nn.Cell): + def construct(self, x): + output = ops.reciprocal(x) + return output + + +@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_reciprocal_normal(mode): + """ + Feature: reciprocal + Description: Verify the result of reciprocal + Expectation: success + """ + ms.set_context(mode=mode) + net = Net() + x = ms.Tensor(np.array([1.0, 2.0, 4.0]), ms.float32) + out = net(x) + expect_out = np.array([1., 0.5, 0.25]) + assert np.allclose(out.asnumpy(), expect_out) diff --git a/tests/st/ops/test_func_repeat_elements.py b/tests/st/ops/test_func_repeat_elements.py index dd82f3f352c..9653aac994d 100644 --- a/tests/st/ops/test_func_repeat_elements.py +++ b/tests/st/ops/test_func_repeat_elements.py @@ -1,47 +1,47 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops - - -class Net(nn.Cell): - def construct(self, x, rep, axis): - output = ops.repeat_elements(x, rep, axis) - return output - - -@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_repeat_elements(mode): - """ - Feature: real - Description: Verify the result of repeat_elements when axis is less than 0. - Expectation: success - """ - ms.set_context(mode=mode) - net = Net() - x = ms.Tensor(np.array([[0, 1, 2], [3, 4, 5]]), ms.int32) - out = net(x, 2, -1) - expect_out = np.array([[0, 0, 1, 1, 2, 2], [3, 3, 4, 4, 5, 5]]) - assert np.allclose(out.asnumpy(), expect_out) - out = net(x, 2, -2) - expect_out = np.array([[0, 1, 2], [0, 1, 2], [3, 4, 5], [3, 4, 5]]) - assert np.allclose(out.asnumpy(), expect_out) +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops + + +class Net(nn.Cell): + def construct(self, x, rep, axis): + output = ops.repeat_elements(x, rep, axis) + return output + + +@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', card_mark='onecard', essential_mark='essential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_repeat_elements(mode): + """ + Feature: real + Description: Verify the result of repeat_elements when axis is less than 0. + Expectation: success + """ + ms.set_context(mode=mode) + net = Net() + x = ms.Tensor(np.array([[0, 1, 2], [3, 4, 5]]), ms.int32) + out = net(x, 2, -1) + expect_out = np.array([[0, 0, 1, 1, 2, 2], [3, 3, 4, 4, 5, 5]]) + assert np.allclose(out.asnumpy(), expect_out) + out = net(x, 2, -2) + expect_out = np.array([[0, 1, 2], [0, 1, 2], [3, 4, 5], [3, 4, 5]]) + assert np.allclose(out.asnumpy(), expect_out) diff --git a/tests/st/ops/test_func_rrelu.py b/tests/st/ops/test_func_rrelu.py index 4dd7ad592a4..d2b791615c9 100644 --- a/tests/st/ops/test_func_rrelu.py +++ b/tests/st/ops/test_func_rrelu.py @@ -1,45 +1,45 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops - - -class Net(nn.Cell): - def construct(self, x): - output = ops.rrelu(x) - return output - - -@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_rrelu(mode): - """ - Feature: rrelu - Description: Verify the result of rrelu - Expectation: success - """ - ms.set_context(mode=mode) - net = Net() - x = ms.Tensor([[1.0, 4.0], [2.0, 0]], dtype=ms.float32) - out = net(x) - expect_out = np.array([[1., 4.], - [2., 0.]]) - assert np.allclose(out.asnumpy(), expect_out) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops + + +class Net(nn.Cell): + def construct(self, x): + output = ops.rrelu(x) + return output + + +@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_rrelu(mode): + """ + Feature: rrelu + Description: Verify the result of rrelu + Expectation: success + """ + ms.set_context(mode=mode) + net = Net() + x = ms.Tensor([[1.0, 4.0], [2.0, 0]], dtype=ms.float32) + out = net(x) + expect_out = np.array([[1., 4.], + [2., 0.]]) + assert np.allclose(out.asnumpy(), expect_out) diff --git a/tests/st/ops/test_heaviside.py b/tests/st/ops/test_heaviside.py index 03e0afdbf15..0ddb349535b 100644 --- a/tests/st/ops/test_heaviside.py +++ b/tests/st/ops/test_heaviside.py @@ -1,45 +1,45 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import pytest -import numpy as np - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops - - -class Net(nn.Cell): - def construct(self, x, values): - output = ops.heaviside(x, values) - return output - - -@arg_mark(plat_marks=['platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_heaviside_normal(mode): - """ - Feature: heaviside - Description: Verify the result of heaviside - Expectation: success - """ - ms.set_context(mode=mode) - net = Net() - x = ms.Tensor([-1.5, 0., 2.], ms.float32) - values = ms.Tensor([0.5], ms.float32) - expect_output = np.array([0., 0.5, 1.], dtype=np.float32) - out = net(x, values) - assert np.allclose(out.asnumpy(), expect_output) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import pytest +import numpy as np + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops + + +class Net(nn.Cell): + def construct(self, x, values): + output = ops.heaviside(x, values) + return output + + +@arg_mark(plat_marks=['platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_heaviside_normal(mode): + """ + Feature: heaviside + Description: Verify the result of heaviside + Expectation: success + """ + ms.set_context(mode=mode) + net = Net() + x = ms.Tensor([-1.5, 0., 2.], ms.float32) + values = ms.Tensor([0.5], ms.float32) + expect_output = np.array([0., 0.5, 1.], dtype=np.float32) + out = net(x, values) + assert np.allclose(out.asnumpy(), expect_output) diff --git a/tests/st/ops/test_hypot.py b/tests/st/ops/test_hypot.py index 16e8c404ce5..ed90f336971 100644 --- a/tests/st/ops/test_hypot.py +++ b/tests/st/ops/test_hypot.py @@ -1,45 +1,45 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import pytest -import numpy as np - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops - - -class Net(nn.Cell): - def construct(self, x, other): - output = ops.hypot(x, other) - return output - - -@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_hypot_normal(mode): - """ - Feature: hypot - Description: Verify the result of hypot - Expectation: success - """ - ms.set_context(mode=mode) - net = Net() - x = ms.Tensor([4], ms.float32) - other = ms.Tensor([3, 4, 5], ms.float64) - out = net(x, other) - expect_out = np.array([5.0000, 5.6569, 6.4031], dtype=np.float64) - assert np.allclose(out.asnumpy(), expect_out) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import pytest +import numpy as np + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops + + +class Net(nn.Cell): + def construct(self, x, other): + output = ops.hypot(x, other) + return output + + +@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_hypot_normal(mode): + """ + Feature: hypot + Description: Verify the result of hypot + Expectation: success + """ + ms.set_context(mode=mode) + net = Net() + x = ms.Tensor([4], ms.float32) + other = ms.Tensor([3, 4, 5], ms.float64) + out = net(x, other) + expect_out = np.array([5.0000, 5.6569, 6.4031], dtype=np.float64) + assert np.allclose(out.asnumpy(), expect_out) diff --git a/tests/st/ops/test_i0.py b/tests/st/ops/test_i0.py index a51809fb3d4..a8c3dbae6a3 100644 --- a/tests/st/ops/test_i0.py +++ b/tests/st/ops/test_i0.py @@ -1,44 +1,44 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import pytest -import numpy as np - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops - - -class Net(nn.Cell): - def construct(self, x): - output = ops.i0(x) - return output - - -@arg_mark(plat_marks=['platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_i0_normal(mode): - """ - Feature: i0 - Description: Verify the result of i0 - Expectation: success - """ - ms.set_context(mode=mode) - net = Net() - x = ms.Tensor([0, 1, 2, 3, 4], ms.float32) - expect_output = np.array([1.0000, 1.26606588, 2.2795853, 4.88079259, 11.30192195], dtype=np.float32) - out = net(x) - assert np.allclose(out.asnumpy(), expect_output) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import pytest +import numpy as np + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops + + +class Net(nn.Cell): + def construct(self, x): + output = ops.i0(x) + return output + + +@arg_mark(plat_marks=['platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_i0_normal(mode): + """ + Feature: i0 + Description: Verify the result of i0 + Expectation: success + """ + ms.set_context(mode=mode) + net = Net() + x = ms.Tensor([0, 1, 2, 3, 4], ms.float32) + expect_output = np.array([1.0000, 1.26606588, 2.2795853, 4.88079259, 11.30192195], dtype=np.float32) + out = net(x) + assert np.allclose(out.asnumpy(), expect_output) diff --git a/tests/st/ops/test_is_floating_point.py b/tests/st/ops/test_is_floating_point.py index c14ca0c233b..eaeaae9ecce 100644 --- a/tests/st/ops/test_is_floating_point.py +++ b/tests/st/ops/test_is_floating_point.py @@ -1,45 +1,45 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import pytest - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops - - -class Net(nn.Cell): - def construct(self, x): - output = ops.is_floating_point(x) - return output - - -@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_is_floating_point_normal(mode): - """ - Feature: is_floating_point - Description: Verify the result of is_floating_point - Expectation: success - """ - ms.set_context(mode=mode) - net = Net() - x = ms.Tensor([1, 2, 3], ms.float32) - y = ms.Tensor([1, 2, 3], ms.int64) - out1 = net(x) - out2 = net(y) - assert out1 - assert not out2 +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import pytest + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops + + +class Net(nn.Cell): + def construct(self, x): + output = ops.is_floating_point(x) + return output + + +@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_is_floating_point_normal(mode): + """ + Feature: is_floating_point + Description: Verify the result of is_floating_point + Expectation: success + """ + ms.set_context(mode=mode) + net = Net() + x = ms.Tensor([1, 2, 3], ms.float32) + y = ms.Tensor([1, 2, 3], ms.int64) + out1 = net(x) + out2 = net(y) + assert out1 + assert not out2 diff --git a/tests/st/ops/test_lppool1d.py b/tests/st/ops/test_lppool1d.py index 7b97a666ab9..e5baeb43dc9 100644 --- a/tests/st/ops/test_lppool1d.py +++ b/tests/st/ops/test_lppool1d.py @@ -1,55 +1,55 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops - - -class Net(nn.Cell): - def construct(self, x): - out = ops.lp_pool1d(x, norm_type=1, kernel_size=3, stride=1) - return out - - -@arg_mark(plat_marks=['platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_lppool1d_normal(mode): - """ - Feature: LPPool1d - Description: Verify the result of LPPool1d - Expectation: success - """ - ms.set_context(mode=mode) - net = Net() - x = ms.Tensor(np.arange(2 * 3 * 4).reshape((2, 3, 4)), dtype=ms.float32) - y = ms.Tensor(np.arange(3 * 4).reshape((3, 4)), dtype=ms.float32) - out = net(x) - out2 = net(y) - expect_out = np.array([[[3., 6.], - [15., 18.], - [27., 30.]], - [[39., 42.], - [51., 54.], - [63., 66.]]]) - expect_out2 = np.array([[3., 6.], - [15., 18.], - [27., 30.]]) - assert np.allclose(out.asnumpy(), expect_out) - assert np.allclose(out2.asnumpy(), expect_out2) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops + + +class Net(nn.Cell): + def construct(self, x): + out = ops.lp_pool1d(x, norm_type=1, kernel_size=3, stride=1) + return out + + +@arg_mark(plat_marks=['platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_lppool1d_normal(mode): + """ + Feature: LPPool1d + Description: Verify the result of LPPool1d + Expectation: success + """ + ms.set_context(mode=mode) + net = Net() + x = ms.Tensor(np.arange(2 * 3 * 4).reshape((2, 3, 4)), dtype=ms.float32) + y = ms.Tensor(np.arange(3 * 4).reshape((3, 4)), dtype=ms.float32) + out = net(x) + out2 = net(y) + expect_out = np.array([[[3., 6.], + [15., 18.], + [27., 30.]], + [[39., 42.], + [51., 54.], + [63., 66.]]]) + expect_out2 = np.array([[3., 6.], + [15., 18.], + [27., 30.]]) + assert np.allclose(out.asnumpy(), expect_out) + assert np.allclose(out2.asnumpy(), expect_out2) diff --git a/tests/st/ops/test_lppool2d.py b/tests/st/ops/test_lppool2d.py index d754670b3e8..ded5dc0fdac 100644 --- a/tests/st/ops/test_lppool2d.py +++ b/tests/st/ops/test_lppool2d.py @@ -1,55 +1,55 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops - - -class Net(nn.Cell): - def construct(self, x): - out = ops.lp_pool2d(x, norm_type=1, kernel_size=3, stride=1) - return out - - -@arg_mark(plat_marks=['platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_lppool2d_normal(mode): - """ - Feature: LPPool2d - Description: Verify the result of LPPool2d - Expectation: success - """ - ms.set_context(mode=mode) - net = Net() - x = ms.Tensor(np.arange(2 * 3 * 4 * 5).reshape((2, 3, 4, 5)), dtype=ms.float32) - out = net(x) - expect_out = np.array([[[[54., 63., 72.], - [99., 108., 117.]], - [[234., 243., 252.], - [279., 288., 297.]], - [[414., 423., 432.], - [459., 468., 477.]]], - [[[594., 603., 612.], - [639., 648., 657.]], - [[774., 783., 792.], - [819., 828., 837.]], - [[954., 963., 972.], - [999., 1008., 1017.]]]]) - assert np.allclose(out.asnumpy(), expect_out) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops + + +class Net(nn.Cell): + def construct(self, x): + out = ops.lp_pool2d(x, norm_type=1, kernel_size=3, stride=1) + return out + + +@arg_mark(plat_marks=['platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_lppool2d_normal(mode): + """ + Feature: LPPool2d + Description: Verify the result of LPPool2d + Expectation: success + """ + ms.set_context(mode=mode) + net = Net() + x = ms.Tensor(np.arange(2 * 3 * 4 * 5).reshape((2, 3, 4, 5)), dtype=ms.float32) + out = net(x) + expect_out = np.array([[[[54., 63., 72.], + [99., 108., 117.]], + [[234., 243., 252.], + [279., 288., 297.]], + [[414., 423., 432.], + [459., 468., 477.]]], + [[[594., 603., 612.], + [639., 648., 657.]], + [[774., 783., 792.], + [819., 828., 837.]], + [[954., 963., 972.], + [999., 1008., 1017.]]]]) + assert np.allclose(out.asnumpy(), expect_out) diff --git a/tests/st/ops/test_margin_ranking_loss.py b/tests/st/ops/test_margin_ranking_loss.py index 85f26ca30b1..f769b8ee026 100644 --- a/tests/st/ops/test_margin_ranking_loss.py +++ b/tests/st/ops/test_margin_ranking_loss.py @@ -1,57 +1,57 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops -from mindspore import Tensor - - -class MarginRankingLoss(nn.Cell): - def __init__(self, reduction): - super(MarginRankingLoss, self).__init__() - self.reduction = reduction - - def construct(self, x, y, label, margin): - return ops.margin_ranking_loss(x, y, label, margin=margin, reduction=self.reduction) - - -@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -@pytest.mark.parametrize('reduction', ["none", "mean", "sum"]) -def test_margin_ranking_loss(mode, reduction): - """ - Feature: test MarginRankingLoss op. - Description: Verify the result of MarginRankingLoss. - Expectation: expect correct forward result. - """ - ms.set_context(mode=mode) - loss = MarginRankingLoss(reduction) - input1 = Tensor(np.array([0.3864, -2.4093, -1.4076]), ms.float32) - input2 = Tensor(np.array([-0.6012, -1.6681, 1.2928]), ms.float32) - target = Tensor(np.array([-1, -1, 1]), ms.float32) - output = loss(input1, input2, target, 0.0) - if reduction == 'none': - expect_output = np.array([0.98759997, 0., 2.7003999]) - elif reduction == 'sum': - expect_output = np.array(3.6879997) - else: - expect_output = np.array(1.2293333) - - assert np.allclose(output.asnumpy(), expect_output) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor + + +class MarginRankingLoss(nn.Cell): + def __init__(self, reduction): + super(MarginRankingLoss, self).__init__() + self.reduction = reduction + + def construct(self, x, y, label, margin): + return ops.margin_ranking_loss(x, y, label, margin=margin, reduction=self.reduction) + + +@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +@pytest.mark.parametrize('reduction', ["none", "mean", "sum"]) +def test_margin_ranking_loss(mode, reduction): + """ + Feature: test MarginRankingLoss op. + Description: Verify the result of MarginRankingLoss. + Expectation: expect correct forward result. + """ + ms.set_context(mode=mode) + loss = MarginRankingLoss(reduction) + input1 = Tensor(np.array([0.3864, -2.4093, -1.4076]), ms.float32) + input2 = Tensor(np.array([-0.6012, -1.6681, 1.2928]), ms.float32) + target = Tensor(np.array([-1, -1, 1]), ms.float32) + output = loss(input1, input2, target, 0.0) + if reduction == 'none': + expect_output = np.array([0.98759997, 0., 2.7003999]) + elif reduction == 'sum': + expect_output = np.array(3.6879997) + else: + expect_output = np.array(1.2293333) + + assert np.allclose(output.asnumpy(), expect_output) diff --git a/tests/st/ops/test_repeat_interleave.py b/tests/st/ops/test_repeat_interleave.py index e84c2848821..5761b9720bc 100644 --- a/tests/st/ops/test_repeat_interleave.py +++ b/tests/st/ops/test_repeat_interleave.py @@ -1,45 +1,45 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark -import numpy as np -import pytest -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops -from mindspore import Tensor - - -class RepeatInterleave(nn.Cell): - def construct(self, x): - return ops.repeat_interleave(x, repeats=2, axis=0) - - -@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_repeat_interleave(mode): - """ - Feature: tensor.repeat_interleave - Description: Verify the result of repeat_interleave - Expectation: success - """ - ms.set_context(mode=mode) - x = Tensor(np.array([[0, 1, 2], [3, 4, 5]]), ms.int32) - net = RepeatInterleave() - output = net(x) - expect_output = [[0, 1, 2], - [0, 1, 2], - [3, 4, 5], - [3, 4, 5]] - assert np.allclose(output.asnumpy(), expect_output) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark +import numpy as np +import pytest +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor + + +class RepeatInterleave(nn.Cell): + def construct(self, x): + return ops.repeat_interleave(x, repeats=2, axis=0) + + +@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level2', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_repeat_interleave(mode): + """ + Feature: tensor.repeat_interleave + Description: Verify the result of repeat_interleave + Expectation: success + """ + ms.set_context(mode=mode) + x = Tensor(np.array([[0, 1, 2], [3, 4, 5]]), ms.int32) + net = RepeatInterleave() + output = net(x) + expect_output = [[0, 1, 2], + [0, 1, 2], + [3, 4, 5], + [3, 4, 5]] + assert np.allclose(output.asnumpy(), expect_output) diff --git a/tests/st/ops/test_roll.py b/tests/st/ops/test_roll.py index e20b9e4e0dc..28ef7ec898e 100644 --- a/tests/st/ops/test_roll.py +++ b/tests/st/ops/test_roll.py @@ -1,42 +1,42 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark -import numpy as np -import pytest -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops -from mindspore import Tensor - - -class Roll(nn.Cell): - def construct(self, x): - return ops.roll(x, shifts=2, dims=0) - - -@arg_mark(plat_marks=['platform_ascend'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_tensor_roll(mode): - """ - Feature: tensor.roll - Description: Verify the result of roll - Expectation: success - """ - ms.set_context(mode=mode) - input_x = Tensor(np.array([0, 1, 2, 3, 4]).astype(np.float32)) - net = Roll() - output = net(input_x) - expect_output = [3., 4., 0., 1., 2.] - assert np.allclose(output.asnumpy(), expect_output) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark +import numpy as np +import pytest +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor + + +class Roll(nn.Cell): + def construct(self, x): + return ops.roll(x, shifts=2, dims=0) + + +@arg_mark(plat_marks=['platform_ascend'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_tensor_roll(mode): + """ + Feature: tensor.roll + Description: Verify the result of roll + Expectation: success + """ + ms.set_context(mode=mode) + input_x = Tensor(np.array([0, 1, 2, 3, 4]).astype(np.float32)) + net = Roll() + output = net(input_x) + expect_output = [3., 4., 0., 1., 2.] + assert np.allclose(output.asnumpy(), expect_output) diff --git a/tests/st/ops/test_tril_triu_op.py b/tests/st/ops/test_tril_triu_op.py index 90252aa1227..2d984efa4af 100644 --- a/tests/st/ops/test_tril_triu_op.py +++ b/tests/st/ops/test_tril_triu_op.py @@ -1,378 +1,378 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from tests.mark_utils import arg_mark - -import numpy as np -import pytest -import mindspore.context as context -import mindspore.nn as nn -import mindspore as ms -import mindspore.ops.operations.array_ops as P -from mindspore import Tensor -from mindspore.common.api import jit - - -class TrilNet(nn.Cell): - - def __init__(self, nptype, diagonal): - super(TrilNet, self).__init__() - self.tril = P.Tril(diagonal=diagonal) - self.x_np = np.random.randn(2, 3, 4).astype(nptype) - self.x_ms = Tensor(self.x_np) - - @jit - def construct(self): - return self.tril(self.x_ms) - - -class TrilDynNet(nn.Cell): - - def __init__(self, diagonal=0): - super(TrilDynNet, self).__init__() - self.op = P.Tril(diagonal=diagonal) - - def construct(self, x): - return self.op(x) - - -class TriuDynNet(nn.Cell): - - def __init__(self, diagonal=0): - super(TriuDynNet, self).__init__() - self.op = P.Triu(diagonal=diagonal) - - def construct(self, x): - return self.op(x) - - -class TriuNet(nn.Cell): - - def __init__(self, nptype, diagonal): - super(TriuNet, self).__init__() - self.triu = P.Triu(diagonal=diagonal) - self.x_np = np.random.randn(2, 3, 4).astype(nptype) - self.x_ms = Tensor(self.x_np) - - @jit - def construct(self): - return self.triu(self.x_ms) - - -def tril_triu(nptype, diagonal): - context.set_context(mode=context.GRAPH_MODE) - tril_ = TrilNet(nptype, diagonal) - triu_ = TriuNet(nptype, diagonal) - tril_output = tril_() - triu_output = triu_() - tril_expect = np.tril(tril_.x_np, diagonal).astype(nptype) - triu_expect = np.triu(triu_.x_np, diagonal).astype(nptype) - assert (tril_output.asnumpy() == tril_expect).all() - assert (triu_output.asnumpy() == triu_expect).all() - - -def tril_triu_pynative(nptype, diagonal): - context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') - tril_ = TrilNet(nptype, diagonal) - triu_ = TriuNet(nptype, diagonal) - tril_output = tril_() - triu_output = triu_() - tril_expect = np.tril(tril_.x_np, diagonal).astype(nptype) - triu_expect = np.triu(triu_.x_np, diagonal).astype(nptype) - assert (tril_output.asnumpy() == tril_expect).all() - assert (triu_output.asnumpy() == triu_expect).all() - - -@arg_mark(plat_marks=['platform_ascend', 'platform_gpu'], level_mark='level0', card_mark='onecard', essential_mark='essential') -def test_tril_triu_graph_float32(): - """ - Feature: ALL To ALL - Description: test cases for Tril and Triu - Expectation: the result match to numpy - """ - tril_triu(np.float32, 4) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_tril_dyn(): - """ - Feature: test Tril op in gpu. - Description: test the ops in dynamic shape. - Expectation: expect correct shape result. - """ - context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - for i in range(-1, 2, 1): - net = TrilDynNet(diagonal=i) - x_dyn = Tensor(shape=[None, None], dtype=ms.float32) - net.set_inputs(x_dyn) - - x = Tensor( - [[1, 2, 3, 4], [5, 6, 7, 8], [10, 11, 12, 13], [14, 15, 16, 17]], - dtype=ms.float32) - out = net(x) - - expect_shape = (4, 4) - assert out.asnumpy().shape == expect_shape - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_triu_dyn(): - """ - Feature: test Triu op in gpu. - Description: test the ops in dynamic shape. - Expectation: expect correct shape result. - """ - context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - for i in range(-1, 2, 1): - net = TriuDynNet(diagonal=i) - x_dyn = Tensor(shape=[None, None], dtype=ms.float32) - net.set_inputs(x_dyn) - - x = Tensor( - [[1, 2, 3, 4], [5, 6, 7, 8], [10, 11, 12, 13], [14, 15, 16, 17]], - dtype=ms.float32) - out = net(x) - - expect_shape = (4, 4) - assert out.asnumpy().shape == expect_shape - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_tril_triu_graph_uint8(): - """ - Feature: ALL To ALL - Description: test cases for Tril and Triu - Expectation: the result match to numpy - """ - tril_triu(np.uint8, -5) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_tril_triu_graph_uint16(): - """ - Feature: ALL To ALL - Description: test cases for Tril and Triu - Expectation: the result match to numpy - """ - tril_triu(np.uint16, -4) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_tril_triu_graph_uint32(): - """ - Feature: ALL To ALL - Description: test cases for Tril and Triu - Expectation: the result match to numpy - """ - tril_triu(np.uint32, -3) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_tril_triu_graph_uint64(): - """ - Feature: ALL To ALL - Description: test cases for Tril and Triu - Expectation: the result match to numpy - """ - tril_triu(np.uint64, -2) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_tril_triu_graph_int8(): - """ - Feature: ALL To ALL - Description: test cases for Tril and Triu - Expectation: the result match to numpy - """ - tril_triu(np.int8, -1) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_tril_triu_graph_int16(): - """ - Feature: ALL To ALL - Description: test cases for Tril and Triu - Expectation: the result match to numpy - """ - tril_triu(np.int16, 0) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_tril_triu_graph_int32(): - """ - Feature: ALL To ALL - Description: test cases for Tril and Triu - Expectation: the result match to numpy - """ - tril_triu(np.int32, 1) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_tril_triu_graph_int64(): - """ - Feature: ALL To ALL - Description: test cases for Tril and Triu - Expectation: the result match to numpy - """ - tril_triu(np.int64, 2) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_tril_triu_graph_float16(): - """ - Feature: ALL To ALL - Description: test cases for Tril and Triu - Expectation: the result match to numpy - """ - tril_triu(np.float16, 3) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_tril_triu_graph_float64(): - """ - Feature: ALL To ALL - Description: test cases for Tril and Triu - Expectation: the result match to numpy - """ - tril_triu(np.float64, 5) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_tril_triu_graph_bool(): - """ - Feature: ALL To ALL - Description: test cases for Tril and Triu - Expectation: the result match to numpy - """ - tril_triu(np.bool, 6) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_tril_triu_pynative_uint8(): - """ - Feature: ALL To ALL - Description: test cases for Tril and Triu - Expectation: the result match to numpy - """ - tril_triu_pynative(np.uint8, -5) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_tril_triu_pynative_uint16(): - """ - Feature: ALL To ALL - Description: test cases for Tril and Triu - Expectation: the result match to numpy - """ - tril_triu_pynative(np.uint16, -4) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_tril_triu_pynative_uint32(): - """ - Feature: ALL To ALL - Description: test cases for Tril and Triu - Expectation: the result match to numpy - """ - tril_triu_pynative(np.uint32, -3) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_tril_triu_pynative_uint64(): - """ - Feature: ALL To ALL - Description: test cases for Tril and Triu - Expectation: the result match to numpy - """ - tril_triu_pynative(np.uint64, -2) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_tril_triu_pynative_int8(): - """ - Feature: ALL To ALL - Description: test cases for Tril and Triu - Expectation: the result match to numpy - """ - tril_triu_pynative(np.int8, -1) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_tril_triu_pynative_int16(): - """ - Feature: ALL To ALL - Description: test cases for Tril and Triu - Expectation: the result match to numpy - """ - tril_triu_pynative(np.int16, 0) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_tril_triu_pynative_int32(): - """ - Feature: ALL To ALL - Description: test cases for Tril and Triu - Expectation: the result match to numpy - """ - tril_triu_pynative(np.int32, 1) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_tril_triu_pynative_int64(): - """ - Feature: ALL To ALL - Description: test cases for Tril and Triu - Expectation: the result match to numpy - """ - tril_triu_pynative(np.int64, 2) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_tril_triu_pynative_float16(): - """ - Feature: ALL To ALL - Description: test cases for Tril and Triu - Expectation: the result match to numpy - """ - tril_triu_pynative(np.float16, 3) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_tril_triu_pynative_float32(): - """ - Feature: ALL To ALL - Description: test cases for Tril and Triu - Expectation: the result match to numpy - """ - tril_triu_pynative(np.float32, 4) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_tril_triu_pynative_float64(): - """ - Feature: ALL To ALL - Description: test cases for Tril and Triu - Expectation: the result match to numpy - """ - tril_triu_pynative(np.float64, 5) - - -@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -def test_tril_triu_pynative_bool(): - """ - Feature: ALL To ALL - Description: test cases for Tril and Triu - Expectation: the result match to numpy - """ - tril_triu_pynative(np.bool, 6) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from tests.mark_utils import arg_mark + +import numpy as np +import pytest +import mindspore.context as context +import mindspore.nn as nn +import mindspore as ms +import mindspore.ops.operations.array_ops as P +from mindspore import Tensor +from mindspore.common.api import jit + + +class TrilNet(nn.Cell): + + def __init__(self, nptype, diagonal): + super(TrilNet, self).__init__() + self.tril = P.Tril(diagonal=diagonal) + self.x_np = np.random.randn(2, 3, 4).astype(nptype) + self.x_ms = Tensor(self.x_np) + + @jit + def construct(self): + return self.tril(self.x_ms) + + +class TrilDynNet(nn.Cell): + + def __init__(self, diagonal=0): + super(TrilDynNet, self).__init__() + self.op = P.Tril(diagonal=diagonal) + + def construct(self, x): + return self.op(x) + + +class TriuDynNet(nn.Cell): + + def __init__(self, diagonal=0): + super(TriuDynNet, self).__init__() + self.op = P.Triu(diagonal=diagonal) + + def construct(self, x): + return self.op(x) + + +class TriuNet(nn.Cell): + + def __init__(self, nptype, diagonal): + super(TriuNet, self).__init__() + self.triu = P.Triu(diagonal=diagonal) + self.x_np = np.random.randn(2, 3, 4).astype(nptype) + self.x_ms = Tensor(self.x_np) + + @jit + def construct(self): + return self.triu(self.x_ms) + + +def tril_triu(nptype, diagonal): + context.set_context(mode=context.GRAPH_MODE) + tril_ = TrilNet(nptype, diagonal) + triu_ = TriuNet(nptype, diagonal) + tril_output = tril_() + triu_output = triu_() + tril_expect = np.tril(tril_.x_np, diagonal).astype(nptype) + triu_expect = np.triu(triu_.x_np, diagonal).astype(nptype) + assert (tril_output.asnumpy() == tril_expect).all() + assert (triu_output.asnumpy() == triu_expect).all() + + +def tril_triu_pynative(nptype, diagonal): + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + tril_ = TrilNet(nptype, diagonal) + triu_ = TriuNet(nptype, diagonal) + tril_output = tril_() + triu_output = triu_() + tril_expect = np.tril(tril_.x_np, diagonal).astype(nptype) + triu_expect = np.triu(triu_.x_np, diagonal).astype(nptype) + assert (tril_output.asnumpy() == tril_expect).all() + assert (triu_output.asnumpy() == triu_expect).all() + + +@arg_mark(plat_marks=['platform_ascend', 'platform_gpu'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_tril_triu_graph_float32(): + """ + Feature: ALL To ALL + Description: test cases for Tril and Triu + Expectation: the result match to numpy + """ + tril_triu(np.float32, 4) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_tril_dyn(): + """ + Feature: test Tril op in gpu. + Description: test the ops in dynamic shape. + Expectation: expect correct shape result. + """ + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + for i in range(-1, 2, 1): + net = TrilDynNet(diagonal=i) + x_dyn = Tensor(shape=[None, None], dtype=ms.float32) + net.set_inputs(x_dyn) + + x = Tensor( + [[1, 2, 3, 4], [5, 6, 7, 8], [10, 11, 12, 13], [14, 15, 16, 17]], + dtype=ms.float32) + out = net(x) + + expect_shape = (4, 4) + assert out.asnumpy().shape == expect_shape + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_triu_dyn(): + """ + Feature: test Triu op in gpu. + Description: test the ops in dynamic shape. + Expectation: expect correct shape result. + """ + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + for i in range(-1, 2, 1): + net = TriuDynNet(diagonal=i) + x_dyn = Tensor(shape=[None, None], dtype=ms.float32) + net.set_inputs(x_dyn) + + x = Tensor( + [[1, 2, 3, 4], [5, 6, 7, 8], [10, 11, 12, 13], [14, 15, 16, 17]], + dtype=ms.float32) + out = net(x) + + expect_shape = (4, 4) + assert out.asnumpy().shape == expect_shape + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_tril_triu_graph_uint8(): + """ + Feature: ALL To ALL + Description: test cases for Tril and Triu + Expectation: the result match to numpy + """ + tril_triu(np.uint8, -5) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_tril_triu_graph_uint16(): + """ + Feature: ALL To ALL + Description: test cases for Tril and Triu + Expectation: the result match to numpy + """ + tril_triu(np.uint16, -4) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_tril_triu_graph_uint32(): + """ + Feature: ALL To ALL + Description: test cases for Tril and Triu + Expectation: the result match to numpy + """ + tril_triu(np.uint32, -3) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_tril_triu_graph_uint64(): + """ + Feature: ALL To ALL + Description: test cases for Tril and Triu + Expectation: the result match to numpy + """ + tril_triu(np.uint64, -2) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_tril_triu_graph_int8(): + """ + Feature: ALL To ALL + Description: test cases for Tril and Triu + Expectation: the result match to numpy + """ + tril_triu(np.int8, -1) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_tril_triu_graph_int16(): + """ + Feature: ALL To ALL + Description: test cases for Tril and Triu + Expectation: the result match to numpy + """ + tril_triu(np.int16, 0) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_tril_triu_graph_int32(): + """ + Feature: ALL To ALL + Description: test cases for Tril and Triu + Expectation: the result match to numpy + """ + tril_triu(np.int32, 1) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_tril_triu_graph_int64(): + """ + Feature: ALL To ALL + Description: test cases for Tril and Triu + Expectation: the result match to numpy + """ + tril_triu(np.int64, 2) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_tril_triu_graph_float16(): + """ + Feature: ALL To ALL + Description: test cases for Tril and Triu + Expectation: the result match to numpy + """ + tril_triu(np.float16, 3) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_tril_triu_graph_float64(): + """ + Feature: ALL To ALL + Description: test cases for Tril and Triu + Expectation: the result match to numpy + """ + tril_triu(np.float64, 5) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_tril_triu_graph_bool(): + """ + Feature: ALL To ALL + Description: test cases for Tril and Triu + Expectation: the result match to numpy + """ + tril_triu(np.bool, 6) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_tril_triu_pynative_uint8(): + """ + Feature: ALL To ALL + Description: test cases for Tril and Triu + Expectation: the result match to numpy + """ + tril_triu_pynative(np.uint8, -5) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_tril_triu_pynative_uint16(): + """ + Feature: ALL To ALL + Description: test cases for Tril and Triu + Expectation: the result match to numpy + """ + tril_triu_pynative(np.uint16, -4) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_tril_triu_pynative_uint32(): + """ + Feature: ALL To ALL + Description: test cases for Tril and Triu + Expectation: the result match to numpy + """ + tril_triu_pynative(np.uint32, -3) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_tril_triu_pynative_uint64(): + """ + Feature: ALL To ALL + Description: test cases for Tril and Triu + Expectation: the result match to numpy + """ + tril_triu_pynative(np.uint64, -2) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_tril_triu_pynative_int8(): + """ + Feature: ALL To ALL + Description: test cases for Tril and Triu + Expectation: the result match to numpy + """ + tril_triu_pynative(np.int8, -1) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_tril_triu_pynative_int16(): + """ + Feature: ALL To ALL + Description: test cases for Tril and Triu + Expectation: the result match to numpy + """ + tril_triu_pynative(np.int16, 0) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_tril_triu_pynative_int32(): + """ + Feature: ALL To ALL + Description: test cases for Tril and Triu + Expectation: the result match to numpy + """ + tril_triu_pynative(np.int32, 1) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_tril_triu_pynative_int64(): + """ + Feature: ALL To ALL + Description: test cases for Tril and Triu + Expectation: the result match to numpy + """ + tril_triu_pynative(np.int64, 2) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_tril_triu_pynative_float16(): + """ + Feature: ALL To ALL + Description: test cases for Tril and Triu + Expectation: the result match to numpy + """ + tril_triu_pynative(np.float16, 3) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_tril_triu_pynative_float32(): + """ + Feature: ALL To ALL + Description: test cases for Tril and Triu + Expectation: the result match to numpy + """ + tril_triu_pynative(np.float32, 4) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_tril_triu_pynative_float64(): + """ + Feature: ALL To ALL + Description: test cases for Tril and Triu + Expectation: the result match to numpy + """ + tril_triu_pynative(np.float64, 5) + + +@arg_mark(plat_marks=['platform_gpu'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_tril_triu_pynative_bool(): + """ + Feature: ALL To ALL + Description: test cases for Tril and Triu + Expectation: the result match to numpy + """ + tril_triu_pynative(np.bool, 6) diff --git a/tests/st/optimizer/optimizer_utils.py b/tests/st/optimizer/optimizer_utils.py index 03d9c0601ce..d218c318150 100644 --- a/tests/st/optimizer/optimizer_utils.py +++ b/tests/st/optimizer/optimizer_utils.py @@ -1,251 +1,251 @@ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import numpy as np -import mindspore -from mindspore import nn, Tensor -from mindspore.ops import operations as P -from mindspore.nn.optim import ASGD -from mindspore.nn.optim import Rprop -from mindspore.nn.optim import AdaMax - -np.random.seed(1024) - -fc1_weight = np.array([[0.72346634, 0.95608497, 0.4084163, 0.18627149, - 0.6942514, 0.39767185, 0.24918061, 0.4548748], - [0.7203382, 0.19086994, 0.76286614, 0.87920564, - 0.3169892, 0.9462494, 0.62827677, 0.27504718], - [0.3544535, 0.2524781, 0.5370583, 0.8313121, - 0.6670143, 0.0488653, 0.62225235, 0.7546456], - [0.17985944, 0.05106374, 0.31064633, 0.4863033, - 0.848814, 0.5523157, 0.20295663, 0.7213356]]).astype("float32") - -fc1_bias = np.array([0.79708564, 0.13728078, 0.66322654, 0.88128525]).astype("float32") - -fc2_weight = np.array([[0.8473515, 0.50923985, 0.42287776, 0.29769543]]).astype("float32") - -fc2_bias = np.array([0.09996348]).astype("float32") - - -def make_fake_data(): - """ - make fake data - """ - data, label = [], [] - for i in range(20): - data.append(mindspore.Tensor(np.array(np.ones((2, 8)) * i, dtype=np.float32))) - label.append(mindspore.Tensor(np.array(np.ones((2, 1)) * (i + 1), dtype=np.float32))) - return data, label - - -class NetWithLoss(nn.Cell): - """ - build net with loss - """ - - def __init__(self, network, loss_fn): - super(NetWithLoss, self).__init__() - self.network = network - self.loss = loss_fn - - def construct(self, x, label): - out = self.network(x) - loss = self.loss(out, label) - return loss - - -class FakeNet(nn.Cell): - """ - build fake net - """ - - def __init__(self): - super(FakeNet, self).__init__() - self.fc1 = nn.Dense(in_channels=8, out_channels=4, weight_init=Tensor(fc1_weight), bias_init=Tensor(fc1_bias)) - self.fc2 = nn.Dense(in_channels=4, out_channels=1, weight_init=Tensor(fc2_weight), bias_init=Tensor(fc2_bias)) - self.relu = nn.ReLU() - self.reducemean = P.ReduceMean() - - def construct(self, x): - x = self.relu(self.fc1(x)) - x = self.fc2(x) - return x - - def _initialize_weights(self): - """ - parameter initialization - """ - self.init_parameters_data() - for name, m in self.cells_and_names(): - if name == 'fc1': - m.weight.set_data(Tensor(fc1_weight)) - m.bias.set_data(Tensor(fc1_bias)) - elif name == 'fc2': - m.weight.set_data(Tensor(fc2_weight)) - m.bias.set_data(Tensor(fc2_bias)) - - -def build_network(opt_config, net, is_group=None, loss_fn=None): - """ - Construct training - """ - if is_group is None: - is_group = False - if loss_fn is None: - loss_fn = nn.L1Loss(reduction='mean') - losses = [] - networkwithloss = NetWithLoss(net, loss_fn) - networkwithloss.set_train() - - if is_group: - fc1_params = list(filter(lambda x: 'fc1' in x.name, networkwithloss.trainable_params())) - fc2_params = list(filter(lambda x: 'fc1' not in x.name, networkwithloss.trainable_params())) - if opt_config['name'] == 'ASGD': - params = [{'params': fc1_params, 'weight_decay': 0.01, 'lr': 0.01}, {'params': fc2_params, 'lr': 0.1}] - elif opt_config['name'] == 'adamax': - params = [{'params': fc1_params, 'lr': 0.0018}, {'params': fc2_params, 'lr': 0.0022}] - elif opt_config['name'] == 'SGD': - params = [{'params': fc1_params, 'weight_decay': 0.2}, {'params': fc2_params}] - else: - params = [{'params': fc1_params, 'lr': 0.01}, {'params': fc2_params, 'lr': 0.01}] - else: - params = networkwithloss.trainable_params() - - if opt_config['name'] == 'ASGD': - net_opt = ASGD(params, learning_rate=opt_config['lr'], lambd=opt_config['lambd'], alpha=opt_config['alpha'], - t0=opt_config['t0'], weight_decay=opt_config['weight_decay']) - - elif opt_config['name'] == 'Rprop': - net_opt = Rprop(params, learning_rate=opt_config['lr'], etas=opt_config['etas'], - step_sizes=opt_config['step_sizes'], weight_decay=0.0) - - elif opt_config['name'] == 'adamax': - net_opt = AdaMax(params, learning_rate=opt_config['lr'], beta1=opt_config['beta1'], - beta2=opt_config['beta2'], eps=opt_config['eps'], weight_decay=0.0) - elif opt_config['name'] == 'SGD': - net_opt = nn.SGD(params, weight_decay=opt_config['weight_decay'], dampening=0.3, momentum=0.1) - trainonestepcell = mindspore.nn.TrainOneStepCell(networkwithloss, net_opt) - data, label = make_fake_data() - for i in range(20): - loss = trainonestepcell(data[i], label[i]) - losses.append(loss.asnumpy()) - return np.array(losses), net_opt - - -default_fc1_weight_asgd = np.array([[0.460443, 0.693057, 0.145399, -0.076741, 0.431228, 0.134655, - -0.013833, 0.191857], - [0.391073, -0.138385, 0.433600, 0.549937, -0.012268, 0.616980, - 0.299013, -0.054209], - [0.064144, -0.037829, 0.246745, 0.540993, 0.376698, -0.241438, - 0.331937, 0.464328], - [-0.066224, -0.195017, 0.064560, 0.240214, 0.602717, 0.306225, - -0.043127, 0.475241]], dtype=np.float32) -default_fc1_bias_asgd = np.array([0.740427, 0.091827, 0.624849, 0.851911], dtype=np.float32) -default_fc2_weight_asgd = np.array([[0.585555, 0.512303, 0.424419, 0.323499]], dtype=np.float32) -default_fc2_bias_asgd = np.array([0.059962], dtype=np.float32) - -no_default_fc1_weight_asgd = np.array([[0.645291, 0.877900, 0.330253, 0.108117, 0.616077, 0.319509, 0.171024, - 0.376710], - [0.687056, 0.157610, 0.729583, 0.845918, 0.283724, 0.912958, 0.594999, - 0.241783], - [0.328432, 0.226461, 0.511030, 0.805272, 0.640981, 0.022857, 0.596221, - 0.728608], - [0.165102, 0.036311, 0.295884, 0.471533, 0.834030, 0.537543, 0.188198, - 0.706556]], dtype=np.float32) -no_default_fc1_bias_asgd = np.array([0.785650, 0.131580, 0.658614, 0.878328], dtype=np.float32) -no_default_fc2_weight_asgd = np.array([[0.374859, -0.049370, -0.068307, -0.115195]], dtype=np.float32) -no_default_fc2_bias_asgd = np.array([0.083960], dtype=np.float32) - -no_default_group_fc1_weight_asgd = np.array([[0.197470, 0.429578, -0.116887, -0.338544, 0.168320, -0.127608, - -0.275773, -0.070531], - [0.119964, -0.408341, 0.162399, 0.278482, -0.282498, 0.345379, - 0.028105, -0.324348], - [-0.168310, -0.270062, 0.013893, 0.307500, 0.143563, -0.473227, - 0.098900, 0.231002], - [-0.254349, -0.382861, -0.123849, 0.051422, 0.413136, 0.117289, - -0.231302, 0.285938]], dtype=np.float32) -no_default_group_fc1_bias_asgd = np.array([0.706595, 0.042866, 0.579553, 0.811499], dtype=np.float32) -no_default_group_fc2_weight_asgd = np.array([[-0.076689, -0.092399, -0.072100, -0.054189]], dtype=np.float32) -no_default_group_fc2_bias_asgd = np.array([0.698678], dtype=np.float32) - -default_fc1_weight_sgd = np.array([[0.00533873, 0.03210080, -0.03090680, -0.05646387, 0.00197765, - -0.03214293, -0.04922638, -0.02556189], - [-0.00658702, -0.06750072, -0.00169432, 0.01169018, -0.05299109, - 0.01940336, -0.01717841, -0.05781638], - [-0.03723934, -0.04897130, -0.01623122, 0.01762178, -0.00128018, - -0.07239634, -0.00642990, 0.00880153], - [-0.04421479, -0.05903235, -0.02916817, -0.00895938, 0.03274637, - -0.00136485, -0.04155754, 0.01808037]], dtype=np.float32) -default_fc2_weight_sgd = np.array([[-0.01070179, -0.00702989, -0.00210839, 0.00160410]], dtype=np.float32) - -default_fc1_weight_adamax = np.array([[0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, - 0.00000000, 0.00000000, 0.00000000], - [11.18415642, 11.18415642, 11.18415642, 11.18415642, 11.18415642, - 11.18415642, 11.18415642, 11.18415642], - [-6.70855522, -6.70855522, -6.70855522, -6.70855522, -6.70855522, - -6.70855522, -6.70855522, -6.70855522], - [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, - 0.00000000, 0.00000000, 0.00000000]], dtype=np.float32) -default_fc1_bias_adamax = np.array([0.00000000, 0.86349380, -0.51633584, 0.00000000], dtype=np.float32) - -no_default_fc1_weight_adamax = np.array([[0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, - 0.00000000, 0.00000000, 0.00000000], - [-4.02891350, -4.02891350, -4.02891350, -4.02891350, -4.02891350, - -4.02891350, -4.02891350, -4.02891350], - [3.10859227, 3.10859227, 3.10859227, 3.10859227, 3.10859227, - 3.10859227, 3.10859227, 3.10859227], - [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, - 0.00000000, 0.00000000, 0.00000000]], dtype=np.float32) -no_default_fc1_bias_adamax = np.array([0.00000000, -0.04809491, 0.06205747, 0.00000000], dtype=np.float32) - -default_group_fc1_weight_adamax = np.array([[0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, - 0.00000000, 0.00000000, 0.00000000], - [11.07278919, 11.07278919, 11.07278919, 11.07278919, 11.07278919, - 11.07278919, 11.07278919, 11.07278919], - [-6.81674862, -6.81674862, -6.81674862, -6.81674862, -6.81674862, - -6.81674862, -6.81674862, -6.81674862], - [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, - 0.00000000, 0.00000000, 0.00000000]], dtype=np.float32) -default_group_fc1_bias_adamax = np.array([0.00000000, 0.85614461, -0.52348828, 0.00000000], dtype=np.float32) - -default_fc1_weight_rprop = np.array([[9.10877514, 9.10877514, 9.10877514, 9.10877514, 9.10877514, - 9.10877514, 9.10877514, 9.10877514], - [2.68465400, 2.68465400, 2.68465400, 2.68465400, 2.68465400, - 2.68465400, 2.68465400, 2.68465400], - [1.04377401, 1.04377401, 1.04377401, 1.04377401, 1.04377401, - 1.04377401, 1.04377401, 1.04377401], - [-1.33468997, -1.33468997, -1.33468997, -1.33468997, -1.33468997, - -1.33468997, -1.33468997, -1.33468997]], dtype=np.float32) -default_fc1_bias_rprop = np.array([0.47940922, 0.14129758, 0.05493547, -0.07024684], dtype=np.float32) - -no_default_fc1_weight_rprop = np.array([[8.41605091, 8.41605091, 8.41605091, 8.41605091, 8.41605091, 8.41605091, - 8.41605091, 8.41605091], - [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, - 0.00000000, 0.00000000], - [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, - 0.00000000, 0.00000000], - [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, - 0.00000000, 0.00000000]], dtype=np.float32) -no_default_fc1_bias_rprop = np.array([0.44295004, 0.00000000, 0.00000000, 0.00000000], dtype=np.float32) - -default_group_fc1_weight_rprop = np.array([[8.41605091, 8.41605091, 8.41605091, 8.41605091, 8.41605091, 8.41605091, - 8.41605091, 8.41605091], - [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, - 0.00000000, 0.00000000], - [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, - 0.00000000, 0.00000000], - [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, - 0.00000000, 0.00000000]], dtype=np.float32) -default_group_fc1_bias_rprop = np.array([0.44295004, 0.00000000, 0.00000000, 0.00000000], dtype=np.float32) +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import mindspore +from mindspore import nn, Tensor +from mindspore.ops import operations as P +from mindspore.nn.optim import ASGD +from mindspore.nn.optim import Rprop +from mindspore.nn.optim import AdaMax + +np.random.seed(1024) + +fc1_weight = np.array([[0.72346634, 0.95608497, 0.4084163, 0.18627149, + 0.6942514, 0.39767185, 0.24918061, 0.4548748], + [0.7203382, 0.19086994, 0.76286614, 0.87920564, + 0.3169892, 0.9462494, 0.62827677, 0.27504718], + [0.3544535, 0.2524781, 0.5370583, 0.8313121, + 0.6670143, 0.0488653, 0.62225235, 0.7546456], + [0.17985944, 0.05106374, 0.31064633, 0.4863033, + 0.848814, 0.5523157, 0.20295663, 0.7213356]]).astype("float32") + +fc1_bias = np.array([0.79708564, 0.13728078, 0.66322654, 0.88128525]).astype("float32") + +fc2_weight = np.array([[0.8473515, 0.50923985, 0.42287776, 0.29769543]]).astype("float32") + +fc2_bias = np.array([0.09996348]).astype("float32") + + +def make_fake_data(): + """ + make fake data + """ + data, label = [], [] + for i in range(20): + data.append(mindspore.Tensor(np.array(np.ones((2, 8)) * i, dtype=np.float32))) + label.append(mindspore.Tensor(np.array(np.ones((2, 1)) * (i + 1), dtype=np.float32))) + return data, label + + +class NetWithLoss(nn.Cell): + """ + build net with loss + """ + + def __init__(self, network, loss_fn): + super(NetWithLoss, self).__init__() + self.network = network + self.loss = loss_fn + + def construct(self, x, label): + out = self.network(x) + loss = self.loss(out, label) + return loss + + +class FakeNet(nn.Cell): + """ + build fake net + """ + + def __init__(self): + super(FakeNet, self).__init__() + self.fc1 = nn.Dense(in_channels=8, out_channels=4, weight_init=Tensor(fc1_weight), bias_init=Tensor(fc1_bias)) + self.fc2 = nn.Dense(in_channels=4, out_channels=1, weight_init=Tensor(fc2_weight), bias_init=Tensor(fc2_bias)) + self.relu = nn.ReLU() + self.reducemean = P.ReduceMean() + + def construct(self, x): + x = self.relu(self.fc1(x)) + x = self.fc2(x) + return x + + def _initialize_weights(self): + """ + parameter initialization + """ + self.init_parameters_data() + for name, m in self.cells_and_names(): + if name == 'fc1': + m.weight.set_data(Tensor(fc1_weight)) + m.bias.set_data(Tensor(fc1_bias)) + elif name == 'fc2': + m.weight.set_data(Tensor(fc2_weight)) + m.bias.set_data(Tensor(fc2_bias)) + + +def build_network(opt_config, net, is_group=None, loss_fn=None): + """ + Construct training + """ + if is_group is None: + is_group = False + if loss_fn is None: + loss_fn = nn.L1Loss(reduction='mean') + losses = [] + networkwithloss = NetWithLoss(net, loss_fn) + networkwithloss.set_train() + + if is_group: + fc1_params = list(filter(lambda x: 'fc1' in x.name, networkwithloss.trainable_params())) + fc2_params = list(filter(lambda x: 'fc1' not in x.name, networkwithloss.trainable_params())) + if opt_config['name'] == 'ASGD': + params = [{'params': fc1_params, 'weight_decay': 0.01, 'lr': 0.01}, {'params': fc2_params, 'lr': 0.1}] + elif opt_config['name'] == 'adamax': + params = [{'params': fc1_params, 'lr': 0.0018}, {'params': fc2_params, 'lr': 0.0022}] + elif opt_config['name'] == 'SGD': + params = [{'params': fc1_params, 'weight_decay': 0.2}, {'params': fc2_params}] + else: + params = [{'params': fc1_params, 'lr': 0.01}, {'params': fc2_params, 'lr': 0.01}] + else: + params = networkwithloss.trainable_params() + + if opt_config['name'] == 'ASGD': + net_opt = ASGD(params, learning_rate=opt_config['lr'], lambd=opt_config['lambd'], alpha=opt_config['alpha'], + t0=opt_config['t0'], weight_decay=opt_config['weight_decay']) + + elif opt_config['name'] == 'Rprop': + net_opt = Rprop(params, learning_rate=opt_config['lr'], etas=opt_config['etas'], + step_sizes=opt_config['step_sizes'], weight_decay=0.0) + + elif opt_config['name'] == 'adamax': + net_opt = AdaMax(params, learning_rate=opt_config['lr'], beta1=opt_config['beta1'], + beta2=opt_config['beta2'], eps=opt_config['eps'], weight_decay=0.0) + elif opt_config['name'] == 'SGD': + net_opt = nn.SGD(params, weight_decay=opt_config['weight_decay'], dampening=0.3, momentum=0.1) + trainonestepcell = mindspore.nn.TrainOneStepCell(networkwithloss, net_opt) + data, label = make_fake_data() + for i in range(20): + loss = trainonestepcell(data[i], label[i]) + losses.append(loss.asnumpy()) + return np.array(losses), net_opt + + +default_fc1_weight_asgd = np.array([[0.460443, 0.693057, 0.145399, -0.076741, 0.431228, 0.134655, + -0.013833, 0.191857], + [0.391073, -0.138385, 0.433600, 0.549937, -0.012268, 0.616980, + 0.299013, -0.054209], + [0.064144, -0.037829, 0.246745, 0.540993, 0.376698, -0.241438, + 0.331937, 0.464328], + [-0.066224, -0.195017, 0.064560, 0.240214, 0.602717, 0.306225, + -0.043127, 0.475241]], dtype=np.float32) +default_fc1_bias_asgd = np.array([0.740427, 0.091827, 0.624849, 0.851911], dtype=np.float32) +default_fc2_weight_asgd = np.array([[0.585555, 0.512303, 0.424419, 0.323499]], dtype=np.float32) +default_fc2_bias_asgd = np.array([0.059962], dtype=np.float32) + +no_default_fc1_weight_asgd = np.array([[0.645291, 0.877900, 0.330253, 0.108117, 0.616077, 0.319509, 0.171024, + 0.376710], + [0.687056, 0.157610, 0.729583, 0.845918, 0.283724, 0.912958, 0.594999, + 0.241783], + [0.328432, 0.226461, 0.511030, 0.805272, 0.640981, 0.022857, 0.596221, + 0.728608], + [0.165102, 0.036311, 0.295884, 0.471533, 0.834030, 0.537543, 0.188198, + 0.706556]], dtype=np.float32) +no_default_fc1_bias_asgd = np.array([0.785650, 0.131580, 0.658614, 0.878328], dtype=np.float32) +no_default_fc2_weight_asgd = np.array([[0.374859, -0.049370, -0.068307, -0.115195]], dtype=np.float32) +no_default_fc2_bias_asgd = np.array([0.083960], dtype=np.float32) + +no_default_group_fc1_weight_asgd = np.array([[0.197470, 0.429578, -0.116887, -0.338544, 0.168320, -0.127608, + -0.275773, -0.070531], + [0.119964, -0.408341, 0.162399, 0.278482, -0.282498, 0.345379, + 0.028105, -0.324348], + [-0.168310, -0.270062, 0.013893, 0.307500, 0.143563, -0.473227, + 0.098900, 0.231002], + [-0.254349, -0.382861, -0.123849, 0.051422, 0.413136, 0.117289, + -0.231302, 0.285938]], dtype=np.float32) +no_default_group_fc1_bias_asgd = np.array([0.706595, 0.042866, 0.579553, 0.811499], dtype=np.float32) +no_default_group_fc2_weight_asgd = np.array([[-0.076689, -0.092399, -0.072100, -0.054189]], dtype=np.float32) +no_default_group_fc2_bias_asgd = np.array([0.698678], dtype=np.float32) + +default_fc1_weight_sgd = np.array([[0.00533873, 0.03210080, -0.03090680, -0.05646387, 0.00197765, + -0.03214293, -0.04922638, -0.02556189], + [-0.00658702, -0.06750072, -0.00169432, 0.01169018, -0.05299109, + 0.01940336, -0.01717841, -0.05781638], + [-0.03723934, -0.04897130, -0.01623122, 0.01762178, -0.00128018, + -0.07239634, -0.00642990, 0.00880153], + [-0.04421479, -0.05903235, -0.02916817, -0.00895938, 0.03274637, + -0.00136485, -0.04155754, 0.01808037]], dtype=np.float32) +default_fc2_weight_sgd = np.array([[-0.01070179, -0.00702989, -0.00210839, 0.00160410]], dtype=np.float32) + +default_fc1_weight_adamax = np.array([[0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, + 0.00000000, 0.00000000, 0.00000000], + [11.18415642, 11.18415642, 11.18415642, 11.18415642, 11.18415642, + 11.18415642, 11.18415642, 11.18415642], + [-6.70855522, -6.70855522, -6.70855522, -6.70855522, -6.70855522, + -6.70855522, -6.70855522, -6.70855522], + [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, + 0.00000000, 0.00000000, 0.00000000]], dtype=np.float32) +default_fc1_bias_adamax = np.array([0.00000000, 0.86349380, -0.51633584, 0.00000000], dtype=np.float32) + +no_default_fc1_weight_adamax = np.array([[0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, + 0.00000000, 0.00000000, 0.00000000], + [-4.02891350, -4.02891350, -4.02891350, -4.02891350, -4.02891350, + -4.02891350, -4.02891350, -4.02891350], + [3.10859227, 3.10859227, 3.10859227, 3.10859227, 3.10859227, + 3.10859227, 3.10859227, 3.10859227], + [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, + 0.00000000, 0.00000000, 0.00000000]], dtype=np.float32) +no_default_fc1_bias_adamax = np.array([0.00000000, -0.04809491, 0.06205747, 0.00000000], dtype=np.float32) + +default_group_fc1_weight_adamax = np.array([[0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, + 0.00000000, 0.00000000, 0.00000000], + [11.07278919, 11.07278919, 11.07278919, 11.07278919, 11.07278919, + 11.07278919, 11.07278919, 11.07278919], + [-6.81674862, -6.81674862, -6.81674862, -6.81674862, -6.81674862, + -6.81674862, -6.81674862, -6.81674862], + [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, + 0.00000000, 0.00000000, 0.00000000]], dtype=np.float32) +default_group_fc1_bias_adamax = np.array([0.00000000, 0.85614461, -0.52348828, 0.00000000], dtype=np.float32) + +default_fc1_weight_rprop = np.array([[9.10877514, 9.10877514, 9.10877514, 9.10877514, 9.10877514, + 9.10877514, 9.10877514, 9.10877514], + [2.68465400, 2.68465400, 2.68465400, 2.68465400, 2.68465400, + 2.68465400, 2.68465400, 2.68465400], + [1.04377401, 1.04377401, 1.04377401, 1.04377401, 1.04377401, + 1.04377401, 1.04377401, 1.04377401], + [-1.33468997, -1.33468997, -1.33468997, -1.33468997, -1.33468997, + -1.33468997, -1.33468997, -1.33468997]], dtype=np.float32) +default_fc1_bias_rprop = np.array([0.47940922, 0.14129758, 0.05493547, -0.07024684], dtype=np.float32) + +no_default_fc1_weight_rprop = np.array([[8.41605091, 8.41605091, 8.41605091, 8.41605091, 8.41605091, 8.41605091, + 8.41605091, 8.41605091], + [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, + 0.00000000, 0.00000000], + [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, + 0.00000000, 0.00000000], + [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, + 0.00000000, 0.00000000]], dtype=np.float32) +no_default_fc1_bias_rprop = np.array([0.44295004, 0.00000000, 0.00000000, 0.00000000], dtype=np.float32) + +default_group_fc1_weight_rprop = np.array([[8.41605091, 8.41605091, 8.41605091, 8.41605091, 8.41605091, 8.41605091, + 8.41605091, 8.41605091], + [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, + 0.00000000, 0.00000000], + [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, + 0.00000000, 0.00000000], + [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, + 0.00000000, 0.00000000]], dtype=np.float32) +default_group_fc1_bias_rprop = np.array([0.44295004, 0.00000000, 0.00000000, 0.00000000], dtype=np.float32) diff --git a/tests/st/optimizer/test_asgd.py b/tests/st/optimizer/test_asgd.py index 023daff94e8..1a8c33c53cf 100644 --- a/tests/st/optimizer/test_asgd.py +++ b/tests/st/optimizer/test_asgd.py @@ -1,82 +1,82 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import numpy as np -import pytest - -import mindspore.context as context -from .optimizer_utils import FakeNet, build_network -from tests.st.utils import test_utils -from tests.mark_utils import arg_mark - - -@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', - card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -@test_utils.run_test_with_On -def test_default_asgd(mode): - """ - Feature: Test ASGD optimizer - Description: Test ASGD with default parameter - Expectation: Loss values and parameters conform to preset values. - """ - from .optimizer_utils import default_fc1_weight_asgd, \ - default_fc1_bias_asgd, default_fc2_weight_asgd, default_fc2_bias_asgd - context.set_context(mode=mode) - config = {'name': 'ASGD', 'lr': 0.01, 'lambd': 1e-4, 'alpha': 0.75, 't0': 1e6, 'weight_decay': 0.0} - _, cells = build_network(config, FakeNet()) - assert np.allclose(cells.ax[0].asnumpy(), default_fc1_weight_asgd, atol=1.e-3) - assert np.allclose(cells.ax[1].asnumpy(), default_fc1_bias_asgd, atol=1.e-3) - assert np.allclose(cells.ax[2].asnumpy(), default_fc2_weight_asgd, atol=1.e-3) - assert np.allclose(cells.ax[3].asnumpy(), default_fc2_bias_asgd, atol=1.e-3) - - -@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', - card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -def test_no_default_asgd(mode): - """ - Feature: Test ASGD optimizer - Description: Test ASGD with another set of parameter - Expectation: Loss values and parameters conform to preset values. - """ - from .optimizer_utils import no_default_fc1_weight_asgd, \ - no_default_fc1_bias_asgd, no_default_fc2_weight_asgd, no_default_fc2_bias_asgd - config = {'name': 'ASGD', 'lr': 0.001, 'lambd': 1e-3, 'alpha': 0.8, 't0': 50., 'weight_decay': 0.001} - context.set_context(mode=mode) - _, cells = build_network(config, FakeNet()) - assert np.allclose(cells.ax[0].asnumpy(), no_default_fc1_weight_asgd, atol=1.e-3) - assert np.allclose(cells.ax[1].asnumpy(), no_default_fc1_bias_asgd, atol=1.e-3) - assert np.allclose(cells.ax[2].asnumpy(), no_default_fc2_weight_asgd, atol=1.e-3) - assert np.allclose(cells.ax[3].asnumpy(), no_default_fc2_bias_asgd, atol=1.e-3) - - -@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', - card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -def test_default_asgd_group(mode): - """ - Feature: Test ASGD optimizer - Description: Test ASGD with parameter grouping - Expectation: Loss values and parameters conform to preset values. - """ - from .optimizer_utils import no_default_group_fc1_weight_asgd, no_default_group_fc1_bias_asgd, \ - no_default_group_fc2_weight_asgd, no_default_group_fc2_bias_asgd - context.set_context(mode=mode) - config = {'name': 'ASGD', 'lr': 0.1, 'lambd': 1e-3, 'alpha': 0.8, 't0': 50., 'weight_decay': 0.001} - _, cells = build_network(config, FakeNet(), is_group=True) - assert np.allclose(cells.ax[0].asnumpy(), no_default_group_fc1_weight_asgd, atol=1.e-3) - assert np.allclose(cells.ax[1].asnumpy(), no_default_group_fc1_bias_asgd, atol=1.e-3) - assert np.allclose(cells.ax[2].asnumpy(), no_default_group_fc2_weight_asgd, atol=1.e-3) - assert np.allclose(cells.ax[3].asnumpy(), no_default_group_fc2_bias_asgd, atol=1.e-3) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest + +import mindspore.context as context +from .optimizer_utils import FakeNet, build_network +from tests.st.utils import test_utils +from tests.mark_utils import arg_mark + + +@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', + card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@test_utils.run_test_with_On +def test_default_asgd(mode): + """ + Feature: Test ASGD optimizer + Description: Test ASGD with default parameter + Expectation: Loss values and parameters conform to preset values. + """ + from .optimizer_utils import default_fc1_weight_asgd, \ + default_fc1_bias_asgd, default_fc2_weight_asgd, default_fc2_bias_asgd + context.set_context(mode=mode) + config = {'name': 'ASGD', 'lr': 0.01, 'lambd': 1e-4, 'alpha': 0.75, 't0': 1e6, 'weight_decay': 0.0} + _, cells = build_network(config, FakeNet()) + assert np.allclose(cells.ax[0].asnumpy(), default_fc1_weight_asgd, atol=1.e-3) + assert np.allclose(cells.ax[1].asnumpy(), default_fc1_bias_asgd, atol=1.e-3) + assert np.allclose(cells.ax[2].asnumpy(), default_fc2_weight_asgd, atol=1.e-3) + assert np.allclose(cells.ax[3].asnumpy(), default_fc2_bias_asgd, atol=1.e-3) + + +@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', + card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_no_default_asgd(mode): + """ + Feature: Test ASGD optimizer + Description: Test ASGD with another set of parameter + Expectation: Loss values and parameters conform to preset values. + """ + from .optimizer_utils import no_default_fc1_weight_asgd, \ + no_default_fc1_bias_asgd, no_default_fc2_weight_asgd, no_default_fc2_bias_asgd + config = {'name': 'ASGD', 'lr': 0.001, 'lambd': 1e-3, 'alpha': 0.8, 't0': 50., 'weight_decay': 0.001} + context.set_context(mode=mode) + _, cells = build_network(config, FakeNet()) + assert np.allclose(cells.ax[0].asnumpy(), no_default_fc1_weight_asgd, atol=1.e-3) + assert np.allclose(cells.ax[1].asnumpy(), no_default_fc1_bias_asgd, atol=1.e-3) + assert np.allclose(cells.ax[2].asnumpy(), no_default_fc2_weight_asgd, atol=1.e-3) + assert np.allclose(cells.ax[3].asnumpy(), no_default_fc2_bias_asgd, atol=1.e-3) + + +@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', + card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_default_asgd_group(mode): + """ + Feature: Test ASGD optimizer + Description: Test ASGD with parameter grouping + Expectation: Loss values and parameters conform to preset values. + """ + from .optimizer_utils import no_default_group_fc1_weight_asgd, no_default_group_fc1_bias_asgd, \ + no_default_group_fc2_weight_asgd, no_default_group_fc2_bias_asgd + context.set_context(mode=mode) + config = {'name': 'ASGD', 'lr': 0.1, 'lambd': 1e-3, 'alpha': 0.8, 't0': 50., 'weight_decay': 0.001} + _, cells = build_network(config, FakeNet(), is_group=True) + assert np.allclose(cells.ax[0].asnumpy(), no_default_group_fc1_weight_asgd, atol=1.e-3) + assert np.allclose(cells.ax[1].asnumpy(), no_default_group_fc1_bias_asgd, atol=1.e-3) + assert np.allclose(cells.ax[2].asnumpy(), no_default_group_fc2_weight_asgd, atol=1.e-3) + assert np.allclose(cells.ax[3].asnumpy(), no_default_group_fc2_bias_asgd, atol=1.e-3) diff --git a/tests/st/optimizer/test_fused_adafactor.py b/tests/st/optimizer/test_fused_adafactor.py index fe1a6d7662c..3f7718f5f75 100644 --- a/tests/st/optimizer/test_fused_adafactor.py +++ b/tests/st/optimizer/test_fused_adafactor.py @@ -1,52 +1,52 @@ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import numpy as np -import pytest - -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.nn import TrainOneStepCell, WithLossCell - -from tests.mark_utils import arg_mark -from tests.st.networks.models.lenet import LeNet - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE]) -def test_lenet(mode): - ''' - Feature: AdaFactor - Description: Test AdaFactor - Expectation: Run lenet success - ''' - context.set_context(mode=mode) - data = Tensor(np.ones([32, 3, 32, 32]).astype(np.float32) * 0.01) - label = Tensor(np.ones([32]).astype(np.int32)) - net = LeNet() - net.batch_size = 32 - learning_rate = 0.01 - optimizer = nn.AdaFactor(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, - scale_parameter=False, relative_step=False, beta1=0) - criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') - net_with_criterion = WithLossCell(net, criterion) - train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer - train_network.set_train() - loss = [] - for _ in range(10): - res = train_network(data, label) - loss.append(res.asnumpy()) - assert np.all(loss[-1] < 0.1) +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.nn import TrainOneStepCell, WithLossCell + +from tests.mark_utils import arg_mark +from tests.st.networks.models.lenet import LeNet + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE]) +def test_lenet(mode): + ''' + Feature: AdaFactor + Description: Test AdaFactor + Expectation: Run lenet success + ''' + context.set_context(mode=mode) + data = Tensor(np.ones([32, 3, 32, 32]).astype(np.float32) * 0.01) + label = Tensor(np.ones([32]).astype(np.int32)) + net = LeNet() + net.batch_size = 32 + learning_rate = 0.01 + optimizer = nn.AdaFactor(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, + scale_parameter=False, relative_step=False, beta1=0) + criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') + net_with_criterion = WithLossCell(net, criterion) + train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer + train_network.set_train() + loss = [] + for _ in range(10): + res = train_network(data, label) + loss.append(res.asnumpy()) + assert np.all(loss[-1] < 0.1) diff --git a/tests/st/optimizer/test_fused_adam_with_flatten_weights.py b/tests/st/optimizer/test_fused_adam_with_flatten_weights.py index 24ff1d2247e..7945b03b52f 100644 --- a/tests/st/optimizer/test_fused_adam_with_flatten_weights.py +++ b/tests/st/optimizer/test_fused_adam_with_flatten_weights.py @@ -1,81 +1,81 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import numpy as np -import pytest - -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.nn import TrainOneStepCell, WithLossCell -from mindspore.common import set_seed - -from tests.mark_utils import arg_mark -from tests.st.networks.models.lenet import LeNet - -set_seed(1) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE]) -def test_lenet_flatten_weight_with_adam(mode): - ''' - Feature: Fused optimizer - Description: Test fused adam with flatten weights - Expectation: Run lenet success and loss < 2.2 - ''' - context.set_context(mode=mode) - data = Tensor(np.ones([32, 3, 32, 32]).astype(np.float32) * 0.01) - label = Tensor(np.ones([32]).astype(np.int32)) - net = LeNet() - net.flatten_weights() - net.batch_size = 32 - optimizer = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters())) - criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') - net_with_criterion = WithLossCell(net, criterion) - train_network = TrainOneStepCell(net_with_criterion, optimizer) - train_network.set_train() - loss = [] - for _ in range(10): - res = train_network(data, label) - loss.append(res.asnumpy()) - assert np.all(loss[-1] < 2.2) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE]) -def test_lenet_flatten_weight_with_adam_weight_decay(mode): - ''' - Feature: Fused optimizer - Description: Test fused adam weight decay with flatten weights - Expectation: Run lenet success and loss < 0.1 - ''' - context.set_context(mode=mode) - data = Tensor(np.ones([32, 3, 32, 32]).astype(np.float32) * 0.01) - label = Tensor(np.ones([32]).astype(np.int32)) - net = LeNet() - net.flatten_weights() - net.batch_size = 32 - optimizer = nn.AdamWeightDecay(filter(lambda x: x.requires_grad, net.get_parameters())) - criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') - net_with_criterion = WithLossCell(net, criterion) - train_network = TrainOneStepCell(net_with_criterion, optimizer) - train_network.set_train() - loss = [] - for _ in range(10): - res = train_network(data, label) - loss.append(res.asnumpy()) - assert np.all(loss[-1] < 0.1) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.nn import TrainOneStepCell, WithLossCell +from mindspore.common import set_seed + +from tests.mark_utils import arg_mark +from tests.st.networks.models.lenet import LeNet + +set_seed(1) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE]) +def test_lenet_flatten_weight_with_adam(mode): + ''' + Feature: Fused optimizer + Description: Test fused adam with flatten weights + Expectation: Run lenet success and loss < 2.2 + ''' + context.set_context(mode=mode) + data = Tensor(np.ones([32, 3, 32, 32]).astype(np.float32) * 0.01) + label = Tensor(np.ones([32]).astype(np.int32)) + net = LeNet() + net.flatten_weights() + net.batch_size = 32 + optimizer = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters())) + criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') + net_with_criterion = WithLossCell(net, criterion) + train_network = TrainOneStepCell(net_with_criterion, optimizer) + train_network.set_train() + loss = [] + for _ in range(10): + res = train_network(data, label) + loss.append(res.asnumpy()) + assert np.all(loss[-1] < 2.2) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level1', card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE]) +def test_lenet_flatten_weight_with_adam_weight_decay(mode): + ''' + Feature: Fused optimizer + Description: Test fused adam weight decay with flatten weights + Expectation: Run lenet success and loss < 0.1 + ''' + context.set_context(mode=mode) + data = Tensor(np.ones([32, 3, 32, 32]).astype(np.float32) * 0.01) + label = Tensor(np.ones([32]).astype(np.int32)) + net = LeNet() + net.flatten_weights() + net.batch_size = 32 + optimizer = nn.AdamWeightDecay(filter(lambda x: x.requires_grad, net.get_parameters())) + criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') + net_with_criterion = WithLossCell(net, criterion) + train_network = TrainOneStepCell(net_with_criterion, optimizer) + train_network.set_train() + loss = [] + for _ in range(10): + res = train_network(data, label) + loss.append(res.asnumpy()) + assert np.all(loss[-1] < 0.1) diff --git a/tests/st/optimizer/test_rprop.py b/tests/st/optimizer/test_rprop.py index dd27a26ea9a..598a1a139c1 100644 --- a/tests/st/optimizer/test_rprop.py +++ b/tests/st/optimizer/test_rprop.py @@ -1,70 +1,70 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import numpy as np -import pytest - -import mindspore.context as context -from .optimizer_utils import FakeNet, build_network, default_fc1_weight_rprop, default_fc1_bias_rprop, \ - no_default_fc1_weight_rprop, no_default_fc1_bias_rprop, default_group_fc1_weight_rprop, default_group_fc1_bias_rprop -from tests.mark_utils import arg_mark - - -@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', - card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -def test_default_rprop(mode): - """ - Feature: Test Rprop optimizer - Description: Test Rprop with default parameter - Expectation: Loss values and parameters conform to preset values. - """ - context.set_context(mode=mode) - config = {'name': 'Rprop', 'lr': 0.01, 'etas': (0.5, 1.2), 'step_sizes': (1e-6, 50.), 'weight_decay': 0.0} - _, cells = build_network(config, net=FakeNet()) - assert np.allclose(cells.prev[0].asnumpy(), default_fc1_weight_rprop, atol=1.e-2) - assert np.allclose(cells.prev[1].asnumpy(), default_fc1_bias_rprop, atol=1.e-2) - - -@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', - card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -def test_no_default_rprop(mode): - """ - Feature: Test Rprop optimizer - Description: Test Rprop with another set of parameter - Expectation: Loss values and parameters conform to preset values. - """ - context.set_context(mode=mode) - config = {'name': 'Rprop', 'lr': 0.01, 'etas': (0.6, 1.9), 'step_sizes': (1e-3, 20.), 'weight_decay': 0.0} - _, cells = build_network(config, net=FakeNet()) - assert np.allclose(cells.prev[0].asnumpy(), no_default_fc1_weight_rprop, atol=1.e-2) - assert np.allclose(cells.prev[1].asnumpy(), no_default_fc1_bias_rprop, atol=1.e-2) - - -@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', - card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -def test_default_rprop_group(mode): - """ - Feature: Test Rprop optimizer - Description: Test Rprop with parameter grouping - Expectation: Loss values and parameters conform to preset values. - """ - context.set_context(mode=mode) - config = {'name': 'Rprop', 'lr': 0.1, 'etas': (0.6, 1.9), 'step_sizes': (1e-2, 10.), 'weight_decay': 0.0} - _, cells = build_network(config, net=FakeNet(), is_group=True) - assert np.allclose(cells.prev[0].asnumpy(), default_group_fc1_weight_rprop, atol=1.e-2) - assert np.allclose(cells.prev[1].asnumpy(), default_group_fc1_bias_rprop, atol=1.e-2) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +from .optimizer_utils import FakeNet, build_network, default_fc1_weight_rprop, default_fc1_bias_rprop, \ + no_default_fc1_weight_rprop, no_default_fc1_bias_rprop, default_group_fc1_weight_rprop, default_group_fc1_bias_rprop +from tests.mark_utils import arg_mark + + +@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', + card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_default_rprop(mode): + """ + Feature: Test Rprop optimizer + Description: Test Rprop with default parameter + Expectation: Loss values and parameters conform to preset values. + """ + context.set_context(mode=mode) + config = {'name': 'Rprop', 'lr': 0.01, 'etas': (0.5, 1.2), 'step_sizes': (1e-6, 50.), 'weight_decay': 0.0} + _, cells = build_network(config, net=FakeNet()) + assert np.allclose(cells.prev[0].asnumpy(), default_fc1_weight_rprop, atol=1.e-2) + assert np.allclose(cells.prev[1].asnumpy(), default_fc1_bias_rprop, atol=1.e-2) + + +@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', + card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_no_default_rprop(mode): + """ + Feature: Test Rprop optimizer + Description: Test Rprop with another set of parameter + Expectation: Loss values and parameters conform to preset values. + """ + context.set_context(mode=mode) + config = {'name': 'Rprop', 'lr': 0.01, 'etas': (0.6, 1.9), 'step_sizes': (1e-3, 20.), 'weight_decay': 0.0} + _, cells = build_network(config, net=FakeNet()) + assert np.allclose(cells.prev[0].asnumpy(), no_default_fc1_weight_rprop, atol=1.e-2) + assert np.allclose(cells.prev[1].asnumpy(), no_default_fc1_bias_rprop, atol=1.e-2) + + +@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', + card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_default_rprop_group(mode): + """ + Feature: Test Rprop optimizer + Description: Test Rprop with parameter grouping + Expectation: Loss values and parameters conform to preset values. + """ + context.set_context(mode=mode) + config = {'name': 'Rprop', 'lr': 0.1, 'etas': (0.6, 1.9), 'step_sizes': (1e-2, 10.), 'weight_decay': 0.0} + _, cells = build_network(config, net=FakeNet(), is_group=True) + assert np.allclose(cells.prev[0].asnumpy(), default_group_fc1_weight_rprop, atol=1.e-2) + assert np.allclose(cells.prev[1].asnumpy(), default_group_fc1_bias_rprop, atol=1.e-2) diff --git a/tests/st/optimizer/test_sgd.py b/tests/st/optimizer/test_sgd.py index 48fe8854482..5f572b3d197 100644 --- a/tests/st/optimizer/test_sgd.py +++ b/tests/st/optimizer/test_sgd.py @@ -1,39 +1,39 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import numpy as np -import pytest - -import mindspore.context as context -from .optimizer_utils import FakeNet, build_network -from tests.st.utils import test_utils -from tests.mark_utils import arg_mark - - -@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', - card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -@test_utils.run_test_with_On -def test_default_asgd(mode): - """ - Feature: Test SGD optimizer - Description: Test SGD with group weight decay - Expectation: Parameters conform to preset values. - """ - from .optimizer_utils import default_fc1_weight_sgd, default_fc2_weight_sgd - context.set_context(mode=mode) - config = {'name': 'SGD', "weight_decay": 0.1} - _, cells = build_network(config, FakeNet(), is_group=True) - assert np.allclose(cells.accum[0].asnumpy(), default_fc1_weight_sgd, atol=1.e-4) - assert np.allclose(cells.accum[2].asnumpy(), default_fc2_weight_sgd, atol=1.e-4) +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest + +import mindspore.context as context +from .optimizer_utils import FakeNet, build_network +from tests.st.utils import test_utils +from tests.mark_utils import arg_mark + + +@arg_mark(plat_marks=['platform_ascend', 'platform_gpu', 'cpu_linux', 'cpu_windows', 'cpu_macos'], level_mark='level0', + card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@test_utils.run_test_with_On +def test_default_asgd(mode): + """ + Feature: Test SGD optimizer + Description: Test SGD with group weight decay + Expectation: Parameters conform to preset values. + """ + from .optimizer_utils import default_fc1_weight_sgd, default_fc2_weight_sgd + context.set_context(mode=mode) + config = {'name': 'SGD', "weight_decay": 0.1} + _, cells = build_network(config, FakeNet(), is_group=True) + assert np.allclose(cells.accum[0].asnumpy(), default_fc1_weight_sgd, atol=1.e-4) + assert np.allclose(cells.accum[2].asnumpy(), default_fc2_weight_sgd, atol=1.e-4) diff --git a/tests/st/optimizer_ex/test_adam_cmp.py b/tests/st/optimizer_ex/test_adam_cmp.py index 7d7637dcb60..cbba14a7c44 100644 --- a/tests/st/optimizer_ex/test_adam_cmp.py +++ b/tests/st/optimizer_ex/test_adam_cmp.py @@ -1,249 +1,249 @@ -from __future__ import absolute_import -import pytest -import numpy as np -import torch -import mindspore -from mindspore import nn -from mindspore import Tensor, context -from mindspore.experimental.optim import Adam -from mindspore.experimental.optim.lr_scheduler import StepLR -from tests.mark_utils import arg_mark - - -class Network(nn.Cell): - def __init__(self, lin_weight, lin_bias): - super().__init__() - self.lin = nn.Dense(2, 3, weight_init=lin_weight, bias_init=lin_bias) - self.relu = nn.ReLU() - - def construct(self, x): - out = self.lin(x) - out = self.relu(out) - return out - - -class NetworkPt(torch.nn.Module): - def __init__(self): - super().__init__() - self.lin = torch.nn.Linear(2, 3) - self.relu = torch.nn.ReLU() - - def forward(self, x): - out = self.lin(x) - out = self.relu(out) - return out - - -class AdamFactory(): - def __init__(self, group=True, lr_dynamic=False, if_change=False, dtype=np.float32): - super().__init__() - np.random.seed(1024) - self.lin_weight_np = np.random.randn(3, 2).astype(dtype) - self.lin_bias_np = np.random.randn(3,).astype(dtype) - - self.group = group - self.lr_dynamic = lr_dynamic - self.if_change = if_change - self.data = np.random.rand(2, 2).astype(np.float32) - self.label = np.random.rand(2, 3).astype(np.float32) - self.epochs = 1 - self.steps = 1 - self.lr = 0.002 - - self.betas = (0.9, 0.999) - self.eps = 1e-8 - self.weight_decay = 0 - self.amsgrad = False - self.maximize = False - - def forward_pytorch_impl(self): - lin_weight = torch.Tensor(self.lin_weight_np.copy()) - lin_bias = torch.Tensor(self.lin_bias_np.copy()) - - model = NetworkPt() - model.lin.weight = torch.nn.Parameter(lin_weight) - model.lin.bias = torch.nn.Parameter(lin_bias) - - data = torch.from_numpy(self.data.copy()) - label = torch.from_numpy(self.label.copy()) - - if not self.group: - optimizer = torch.optim.Adam(model.parameters(), self.lr) - else: - bias_params, no_bias_params = [], [] - for param in model.named_parameters(): - if "bias" in param[0]: - bias_params.append(param[1]) - else: - no_bias_params.append(param[1]) - group_params = [{'params': bias_params, 'weight_decay': 0.0, 'lr': 0.9, "betas": (0.88, 0.8)}, - {'params': no_bias_params, 'lr': 0.66, "amsgrad": True}] - optimizer = torch.optim.Adam(params=group_params, lr=self.lr) - - criterion = torch.nn.L1Loss(reduction='mean') - lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 2, gamma=0.5, last_epoch=-1) - - for _ in range(self.epochs): - for _ in range(self.steps): - optimizer.zero_grad() - loss = criterion(model(data), label) - loss.backward() - optimizer.step() - if self.lr_dynamic: - lr_scheduler.step() - if self.if_change: - optimizer.param_groups[1]["betas"] = (0.77, 0.7) - optimizer.param_groups[1]["amsgrad"] = False - output = model(data) - return output.detach().numpy() - - - def forward_mindspore_impl(self): - lin_weight = Tensor(self.lin_weight_np.copy()) - lin_bias = Tensor(self.lin_bias_np.copy()) - model = Network(lin_weight, lin_bias) - - data = Tensor(self.data) - label = Tensor(self.label) - - if not self.group: - optimizer = Adam(model.trainable_params(), self.lr) - else: - bias_params = list(filter(lambda x: 'bias' in x.name, model.trainable_params())) - no_bias_params = list(filter(lambda x: 'bias' not in x.name, model.trainable_params())) - group_params = [{'params': bias_params, 'weight_decay': 0.0, 'lr': 0.9, "betas": (0.88, 0.8)}, - {'params': no_bias_params, 'lr': 0.66, "amsgrad": True}] - optimizer = Adam(params=group_params, lr=self.lr) - - criterion = nn.MAELoss(reduction="mean") - lr_scheduler = StepLR(optimizer, 2, gamma=0.5, last_epoch=-1) - - def forward_fn(data, label): - logits = model(data) - loss = criterion(logits, label) - return loss, logits - - grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True) - - def train_step(data, label): - (loss, _), grads = grad_fn(data, label) - optimizer(grads) - return loss - - def train(epochs, steps, lr_dynamic, if_change): - for _ in range(epochs): - for _ in range(steps): - train_step(data, label) - if lr_dynamic: - lr_scheduler.step() - if if_change: - optimizer.param_groups[1]["betas"] = (0.77, 0.7) - optimizer.param_groups[1]["amsgrad"] = False - train(self.epochs, self.steps, self.lr_dynamic, self.if_change) - output = model(data) - return output.asnumpy() - - def result_cmp(self): - loss_expect = self.forward_pytorch_impl() - loss_out = self.forward_mindspore_impl() - allclose_nparray(loss_expect, loss_out, 0.005, 0.005) - - -def _count_unequal_element(data_expected, data_me, rtol, atol): - assert data_expected.shape == data_me.shape - total_count = len(data_expected.flatten()) - error = np.abs(data_expected - data_me) - greater = np.greater(error, atol + np.abs(data_me) * rtol) - loss_count = np.count_nonzero(greater) - assert (loss_count / total_count) < rtol, \ - "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \ - format(data_expected[greater], data_me[greater], error[greater]) - - -def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True): - if np.any(np.isnan(data_expected)) or np.any(np.isnan(data_me)): - assert np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan) - elif not np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan): - _count_unequal_element(data_expected, data_me, rtol, atol) - else: - assert np.array(data_expected).shape == np.array(data_me).shape - - -@arg_mark(plat_marks=['platform_gpu', 'platform_ascend'], - level_mark='level0', - card_mark='onecard', - essential_mark='essential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -def test_adam_basic(mode): - """ - Feature: Test adam. - Description: Test adam with default parameter. - Expectation: success. - """ - mindspore.set_context(mode=mode, jit_syntax_level=mindspore.STRICT) - fact = AdamFactory(False, False) - fact.result_cmp() - - -@arg_mark(plat_marks=['platform_gpu', 'platform_ascend'], - level_mark='level1', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -def test_adam_group(mode): - """ - Feature: Test adam. - Description: Test adam with grouped params. - Expectation: success. - """ - mindspore.set_context(mode=mode, jit_syntax_level=mindspore.STRICT) - fact = AdamFactory(True, False) - fact.result_cmp() - - -@arg_mark(plat_marks=['platform_gpu', 'platform_ascend'], - level_mark='level1', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -def test_adam_lr_dynamic(mode): - """ - Feature: Test adam. - Description: Test adam when lr is dynamic. - Expectation: success. - """ - mindspore.set_context(mode=mode, jit_syntax_level=mindspore.STRICT) - fact = AdamFactory(False, True) - fact.result_cmp() - - -@arg_mark(plat_marks=['platform_gpu', 'platform_ascend'], - level_mark='level1', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -def test_adam_group_lr_dynamic(mode): - """ - Feature: Test adam. - Description: Test adam with grouped params when lr is dynamic. - Expectation: success. - """ - mindspore.set_context(mode=mode, jit_syntax_level=mindspore.STRICT) - fact = AdamFactory(True, True) - fact.result_cmp() - - -@arg_mark(plat_marks=['platform_gpu', 'platform_ascend'], - level_mark='level1', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.PYNATIVE_MODE]) -def test_adam_group_lr_dynamic_change_param(mode): - """ - Feature: Test adam. - Description: Test adam with grouped params when optimizer params are changed. - Expectation: success. - """ - mindspore.set_context(mode=mode, jit_syntax_level=mindspore.STRICT) - fact = AdamFactory(True, True, True) - fact.result_cmp() +from __future__ import absolute_import +import pytest +import numpy as np +import torch +import mindspore +from mindspore import nn +from mindspore import Tensor, context +from mindspore.experimental.optim import Adam +from mindspore.experimental.optim.lr_scheduler import StepLR +from tests.mark_utils import arg_mark + + +class Network(nn.Cell): + def __init__(self, lin_weight, lin_bias): + super().__init__() + self.lin = nn.Dense(2, 3, weight_init=lin_weight, bias_init=lin_bias) + self.relu = nn.ReLU() + + def construct(self, x): + out = self.lin(x) + out = self.relu(out) + return out + + +class NetworkPt(torch.nn.Module): + def __init__(self): + super().__init__() + self.lin = torch.nn.Linear(2, 3) + self.relu = torch.nn.ReLU() + + def forward(self, x): + out = self.lin(x) + out = self.relu(out) + return out + + +class AdamFactory(): + def __init__(self, group=True, lr_dynamic=False, if_change=False, dtype=np.float32): + super().__init__() + np.random.seed(1024) + self.lin_weight_np = np.random.randn(3, 2).astype(dtype) + self.lin_bias_np = np.random.randn(3,).astype(dtype) + + self.group = group + self.lr_dynamic = lr_dynamic + self.if_change = if_change + self.data = np.random.rand(2, 2).astype(np.float32) + self.label = np.random.rand(2, 3).astype(np.float32) + self.epochs = 1 + self.steps = 1 + self.lr = 0.002 + + self.betas = (0.9, 0.999) + self.eps = 1e-8 + self.weight_decay = 0 + self.amsgrad = False + self.maximize = False + + def forward_pytorch_impl(self): + lin_weight = torch.Tensor(self.lin_weight_np.copy()) + lin_bias = torch.Tensor(self.lin_bias_np.copy()) + + model = NetworkPt() + model.lin.weight = torch.nn.Parameter(lin_weight) + model.lin.bias = torch.nn.Parameter(lin_bias) + + data = torch.from_numpy(self.data.copy()) + label = torch.from_numpy(self.label.copy()) + + if not self.group: + optimizer = torch.optim.Adam(model.parameters(), self.lr) + else: + bias_params, no_bias_params = [], [] + for param in model.named_parameters(): + if "bias" in param[0]: + bias_params.append(param[1]) + else: + no_bias_params.append(param[1]) + group_params = [{'params': bias_params, 'weight_decay': 0.0, 'lr': 0.9, "betas": (0.88, 0.8)}, + {'params': no_bias_params, 'lr': 0.66, "amsgrad": True}] + optimizer = torch.optim.Adam(params=group_params, lr=self.lr) + + criterion = torch.nn.L1Loss(reduction='mean') + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 2, gamma=0.5, last_epoch=-1) + + for _ in range(self.epochs): + for _ in range(self.steps): + optimizer.zero_grad() + loss = criterion(model(data), label) + loss.backward() + optimizer.step() + if self.lr_dynamic: + lr_scheduler.step() + if self.if_change: + optimizer.param_groups[1]["betas"] = (0.77, 0.7) + optimizer.param_groups[1]["amsgrad"] = False + output = model(data) + return output.detach().numpy() + + + def forward_mindspore_impl(self): + lin_weight = Tensor(self.lin_weight_np.copy()) + lin_bias = Tensor(self.lin_bias_np.copy()) + model = Network(lin_weight, lin_bias) + + data = Tensor(self.data) + label = Tensor(self.label) + + if not self.group: + optimizer = Adam(model.trainable_params(), self.lr) + else: + bias_params = list(filter(lambda x: 'bias' in x.name, model.trainable_params())) + no_bias_params = list(filter(lambda x: 'bias' not in x.name, model.trainable_params())) + group_params = [{'params': bias_params, 'weight_decay': 0.0, 'lr': 0.9, "betas": (0.88, 0.8)}, + {'params': no_bias_params, 'lr': 0.66, "amsgrad": True}] + optimizer = Adam(params=group_params, lr=self.lr) + + criterion = nn.MAELoss(reduction="mean") + lr_scheduler = StepLR(optimizer, 2, gamma=0.5, last_epoch=-1) + + def forward_fn(data, label): + logits = model(data) + loss = criterion(logits, label) + return loss, logits + + grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True) + + def train_step(data, label): + (loss, _), grads = grad_fn(data, label) + optimizer(grads) + return loss + + def train(epochs, steps, lr_dynamic, if_change): + for _ in range(epochs): + for _ in range(steps): + train_step(data, label) + if lr_dynamic: + lr_scheduler.step() + if if_change: + optimizer.param_groups[1]["betas"] = (0.77, 0.7) + optimizer.param_groups[1]["amsgrad"] = False + train(self.epochs, self.steps, self.lr_dynamic, self.if_change) + output = model(data) + return output.asnumpy() + + def result_cmp(self): + loss_expect = self.forward_pytorch_impl() + loss_out = self.forward_mindspore_impl() + allclose_nparray(loss_expect, loss_out, 0.005, 0.005) + + +def _count_unequal_element(data_expected, data_me, rtol, atol): + assert data_expected.shape == data_me.shape + total_count = len(data_expected.flatten()) + error = np.abs(data_expected - data_me) + greater = np.greater(error, atol + np.abs(data_me) * rtol) + loss_count = np.count_nonzero(greater) + assert (loss_count / total_count) < rtol, \ + "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \ + format(data_expected[greater], data_me[greater], error[greater]) + + +def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True): + if np.any(np.isnan(data_expected)) or np.any(np.isnan(data_me)): + assert np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan) + elif not np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan): + _count_unequal_element(data_expected, data_me, rtol, atol) + else: + assert np.array(data_expected).shape == np.array(data_me).shape + + +@arg_mark(plat_marks=['platform_gpu', 'platform_ascend'], + level_mark='level0', + card_mark='onecard', + essential_mark='essential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_adam_basic(mode): + """ + Feature: Test adam. + Description: Test adam with default parameter. + Expectation: success. + """ + mindspore.set_context(mode=mode, jit_syntax_level=mindspore.STRICT) + fact = AdamFactory(False, False) + fact.result_cmp() + + +@arg_mark(plat_marks=['platform_gpu', 'platform_ascend'], + level_mark='level1', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_adam_group(mode): + """ + Feature: Test adam. + Description: Test adam with grouped params. + Expectation: success. + """ + mindspore.set_context(mode=mode, jit_syntax_level=mindspore.STRICT) + fact = AdamFactory(True, False) + fact.result_cmp() + + +@arg_mark(plat_marks=['platform_gpu', 'platform_ascend'], + level_mark='level1', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_adam_lr_dynamic(mode): + """ + Feature: Test adam. + Description: Test adam when lr is dynamic. + Expectation: success. + """ + mindspore.set_context(mode=mode, jit_syntax_level=mindspore.STRICT) + fact = AdamFactory(False, True) + fact.result_cmp() + + +@arg_mark(plat_marks=['platform_gpu', 'platform_ascend'], + level_mark='level1', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_adam_group_lr_dynamic(mode): + """ + Feature: Test adam. + Description: Test adam with grouped params when lr is dynamic. + Expectation: success. + """ + mindspore.set_context(mode=mode, jit_syntax_level=mindspore.STRICT) + fact = AdamFactory(True, True) + fact.result_cmp() + + +@arg_mark(plat_marks=['platform_gpu', 'platform_ascend'], + level_mark='level1', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.PYNATIVE_MODE]) +def test_adam_group_lr_dynamic_change_param(mode): + """ + Feature: Test adam. + Description: Test adam with grouped params when optimizer params are changed. + Expectation: success. + """ + mindspore.set_context(mode=mode, jit_syntax_level=mindspore.STRICT) + fact = AdamFactory(True, True, True) + fact.result_cmp() diff --git a/tests/st/optimizer_ex/test_adamw_cmp.py b/tests/st/optimizer_ex/test_adamw_cmp.py index b2ab17e8d89..f16bd36444c 100644 --- a/tests/st/optimizer_ex/test_adamw_cmp.py +++ b/tests/st/optimizer_ex/test_adamw_cmp.py @@ -1,251 +1,251 @@ -from __future__ import absolute_import -import pytest -import numpy as np -import torch -import mindspore -from mindspore import nn -from mindspore import Tensor, context -from mindspore.experimental.optim import AdamW -from mindspore.experimental.optim.lr_scheduler import StepLR -from tests.mark_utils import arg_mark - - -class Network(nn.Cell): - def __init__(self, lin_weight, lin_bias): - super().__init__() - self.lin = nn.Dense(2, 3, weight_init=lin_weight, bias_init=lin_bias) - self.relu = nn.ReLU() - - def construct(self, x): - out = self.lin(x) - out = self.relu(out) - return out - - -class NetworkPt(torch.nn.Module): - def __init__(self): - super().__init__() - self.lin = torch.nn.Linear(2, 3) - self.relu = torch.nn.ReLU() - - def forward(self, x): - out = self.lin(x) - out = self.relu(out) - return out - - -class AdamWFactory(): - def __init__(self, group=True, lr_dynamic=False, if_change=False, dtype=np.float32): - super().__init__() - np.random.seed(1024) - self.lin_weight_np = np.random.randn(3, 2).astype(dtype) - self.lin_bias_np = np.random.randn(3,).astype(dtype) - - self.data = np.random.rand(2, 2).astype(np.float32) - self.label = np.random.rand(2, 3).astype(np.float32) - - self.group = group - self.lr_dynamic = lr_dynamic - self.if_change = if_change - self.epochs = 1 - self.steps = 1 - self.lr = 0.002 - self.betas = (0.9, 0.999) - self.eps = 1e-8 - self.weight_decay = 0 - self.amsgrad = False - self.maximize = False - - def forward_pytorch_impl(self): - lin_weight = torch.Tensor(self.lin_weight_np.copy()) - lin_bias = torch.Tensor(self.lin_bias_np.copy()) - - model = NetworkPt() - model.lin.weight = torch.nn.Parameter(lin_weight) - model.lin.bias = torch.nn.Parameter(lin_bias) - - data = torch.from_numpy(self.data.copy()) - label = torch.from_numpy(self.label.copy()) - - if not self.group: - optimizer = torch.optim.AdamW(model.parameters(), self.lr) - else: - bias_params, no_bias_params = [], [] - for param in model.named_parameters(): - if "bias" in param[0]: - bias_params.append(param[1]) - else: - no_bias_params.append(param[1]) - group_params = [{'params': bias_params, 'weight_decay': 0.01, 'lr': 0.9, "betas": (0.88, 0.8)}, - {'params': no_bias_params, 'lr': 0.66, "amsgrad": True}] - optimizer = torch.optim.AdamW(params=group_params, lr=self.lr) - - criterion = torch.nn.L1Loss(reduction='mean') - lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 2, gamma=0.5, last_epoch=-1) - - for _ in range(self.epochs): - for _ in range(self.steps): - optimizer.zero_grad() - loss = criterion(model(data), label) - loss.backward() - optimizer.step() - if self.lr_dynamic: - lr_scheduler.step() - if self.if_change: - optimizer.param_groups[1]["betas"] = (0.77, 0.7) - optimizer.param_groups[1]["amsgrad"] = False - - output = model(data) - return output.detach().numpy() - - - def forward_mindspore_impl(self): - lin_weight = Tensor(self.lin_weight_np.copy()) - lin_bias = Tensor(self.lin_bias_np.copy()) - model = Network(lin_weight, lin_bias) - - data = Tensor(self.data) - label = Tensor(self.label) - - if not self.group: - optimizer = AdamW(model.trainable_params(), self.lr) - else: - bias_params = list(filter(lambda x: 'bias' in x.name, model.trainable_params())) - no_bias_params = list(filter(lambda x: 'bias' not in x.name, model.trainable_params())) - group_params = [{'params': bias_params, 'weight_decay': 0.01, 'lr': 0.9, "betas": (0.88, 0.8)}, - {'params': no_bias_params, 'lr': 0.66, "amsgrad": True}] - optimizer = AdamW(params=group_params, lr=self.lr) - - criterion = nn.MAELoss(reduction="mean") - lr_scheduler = StepLR(optimizer, 2, gamma=0.5, last_epoch=-1) - - def forward_fn(data, label): - logits = model(data) - loss = criterion(logits, label) - return loss, logits - - grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True) - - def train_step(data, label): - (loss, _), grads = grad_fn(data, label) - optimizer(grads) - return loss - - def train(epochs, steps, lr_dynamic, if_change): - for _ in range(epochs): - for _ in range(steps): - train_step(data, label) - if lr_dynamic: - lr_scheduler.step() - if if_change: - optimizer.param_groups[1]["betas"] = (0.77, 0.7) - optimizer.param_groups[1]["amsgrad"] = False - - train(self.epochs, self.steps, self.lr_dynamic, self.if_change) - output = model(data) - return output.asnumpy() - - def result_cmp(self): - loss_expect = self.forward_pytorch_impl() - loss_out = self.forward_mindspore_impl() - allclose_nparray(loss_expect, loss_out, 0.005, 0.005) - - -def _count_unequal_element(data_expected, data_me, rtol, atol): - assert data_expected.shape == data_me.shape - total_count = len(data_expected.flatten()) - error = np.abs(data_expected - data_me) - greater = np.greater(error, atol + np.abs(data_me) * rtol) - loss_count = np.count_nonzero(greater) - assert (loss_count / total_count) < rtol, \ - "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \ - format(data_expected[greater], data_me[greater], error[greater]) - - -def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True): - if np.any(np.isnan(data_expected)) or np.any(np.isnan(data_me)): - assert np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan) - elif not np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan): - _count_unequal_element(data_expected, data_me, rtol, atol) - else: - assert np.array(data_expected).shape == np.array(data_me).shape - - -@arg_mark(plat_marks=['platform_gpu', 'platform_ascend'], - level_mark='level0', - card_mark='onecard', - essential_mark='essential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -def test_adamw_basic(mode): - """ - Feature: Test adamw. - Description: Test adamw with default parameter. - Expectation: success. - """ - mindspore.set_context(mode=mode, jit_syntax_level=mindspore.STRICT) - fact = AdamWFactory(False, False) - fact.result_cmp() - - -@arg_mark(plat_marks=['platform_gpu', 'platform_ascend'], - level_mark='level1', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -def test_adamw_group(mode): - """ - Feature: Test adamw. - Description: Test adamw with grouped params. - Expectation: success. - """ - mindspore.set_context(mode=mode, jit_syntax_level=mindspore.STRICT) - fact = AdamWFactory(True, False) - fact.result_cmp() - - -@arg_mark(plat_marks=['platform_gpu', 'platform_ascend'], - level_mark='level1', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -def test_adamw_lr_dynamic(mode): - """ - Feature: Test adamw. - Description: Test adamw when lr is dynamic. - Expectation: success. - """ - mindspore.set_context(mode=mode, jit_syntax_level=mindspore.STRICT) - fact = AdamWFactory(False, True) - fact.result_cmp() - - -@arg_mark(plat_marks=['platform_gpu', 'platform_ascend'], - level_mark='level1', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -def test_adamw_group_lr_dynamic(mode): - """ - Feature: Test adamw. - Description: Test adamw with grouped params when lr is dynamic. - Expectation: success. - """ - mindspore.set_context(mode=mode, jit_syntax_level=mindspore.STRICT) - fact = AdamWFactory(True, True) - fact.result_cmp() - - -@arg_mark(plat_marks=['platform_gpu', 'platform_ascend'], - level_mark='level1', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.PYNATIVE_MODE]) -def test_adamw_group_lr_dynamic_change_param(mode): - """ - Feature: Test adamw. - Description: Test adamw with grouped params when optimizer params are changed. - Expectation: success. - """ - mindspore.set_context(mode=mode, jit_syntax_level=mindspore.STRICT) - fact = AdamWFactory(True, True, True) - fact.result_cmp() +from __future__ import absolute_import +import pytest +import numpy as np +import torch +import mindspore +from mindspore import nn +from mindspore import Tensor, context +from mindspore.experimental.optim import AdamW +from mindspore.experimental.optim.lr_scheduler import StepLR +from tests.mark_utils import arg_mark + + +class Network(nn.Cell): + def __init__(self, lin_weight, lin_bias): + super().__init__() + self.lin = nn.Dense(2, 3, weight_init=lin_weight, bias_init=lin_bias) + self.relu = nn.ReLU() + + def construct(self, x): + out = self.lin(x) + out = self.relu(out) + return out + + +class NetworkPt(torch.nn.Module): + def __init__(self): + super().__init__() + self.lin = torch.nn.Linear(2, 3) + self.relu = torch.nn.ReLU() + + def forward(self, x): + out = self.lin(x) + out = self.relu(out) + return out + + +class AdamWFactory(): + def __init__(self, group=True, lr_dynamic=False, if_change=False, dtype=np.float32): + super().__init__() + np.random.seed(1024) + self.lin_weight_np = np.random.randn(3, 2).astype(dtype) + self.lin_bias_np = np.random.randn(3,).astype(dtype) + + self.data = np.random.rand(2, 2).astype(np.float32) + self.label = np.random.rand(2, 3).astype(np.float32) + + self.group = group + self.lr_dynamic = lr_dynamic + self.if_change = if_change + self.epochs = 1 + self.steps = 1 + self.lr = 0.002 + self.betas = (0.9, 0.999) + self.eps = 1e-8 + self.weight_decay = 0 + self.amsgrad = False + self.maximize = False + + def forward_pytorch_impl(self): + lin_weight = torch.Tensor(self.lin_weight_np.copy()) + lin_bias = torch.Tensor(self.lin_bias_np.copy()) + + model = NetworkPt() + model.lin.weight = torch.nn.Parameter(lin_weight) + model.lin.bias = torch.nn.Parameter(lin_bias) + + data = torch.from_numpy(self.data.copy()) + label = torch.from_numpy(self.label.copy()) + + if not self.group: + optimizer = torch.optim.AdamW(model.parameters(), self.lr) + else: + bias_params, no_bias_params = [], [] + for param in model.named_parameters(): + if "bias" in param[0]: + bias_params.append(param[1]) + else: + no_bias_params.append(param[1]) + group_params = [{'params': bias_params, 'weight_decay': 0.01, 'lr': 0.9, "betas": (0.88, 0.8)}, + {'params': no_bias_params, 'lr': 0.66, "amsgrad": True}] + optimizer = torch.optim.AdamW(params=group_params, lr=self.lr) + + criterion = torch.nn.L1Loss(reduction='mean') + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 2, gamma=0.5, last_epoch=-1) + + for _ in range(self.epochs): + for _ in range(self.steps): + optimizer.zero_grad() + loss = criterion(model(data), label) + loss.backward() + optimizer.step() + if self.lr_dynamic: + lr_scheduler.step() + if self.if_change: + optimizer.param_groups[1]["betas"] = (0.77, 0.7) + optimizer.param_groups[1]["amsgrad"] = False + + output = model(data) + return output.detach().numpy() + + + def forward_mindspore_impl(self): + lin_weight = Tensor(self.lin_weight_np.copy()) + lin_bias = Tensor(self.lin_bias_np.copy()) + model = Network(lin_weight, lin_bias) + + data = Tensor(self.data) + label = Tensor(self.label) + + if not self.group: + optimizer = AdamW(model.trainable_params(), self.lr) + else: + bias_params = list(filter(lambda x: 'bias' in x.name, model.trainable_params())) + no_bias_params = list(filter(lambda x: 'bias' not in x.name, model.trainable_params())) + group_params = [{'params': bias_params, 'weight_decay': 0.01, 'lr': 0.9, "betas": (0.88, 0.8)}, + {'params': no_bias_params, 'lr': 0.66, "amsgrad": True}] + optimizer = AdamW(params=group_params, lr=self.lr) + + criterion = nn.MAELoss(reduction="mean") + lr_scheduler = StepLR(optimizer, 2, gamma=0.5, last_epoch=-1) + + def forward_fn(data, label): + logits = model(data) + loss = criterion(logits, label) + return loss, logits + + grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True) + + def train_step(data, label): + (loss, _), grads = grad_fn(data, label) + optimizer(grads) + return loss + + def train(epochs, steps, lr_dynamic, if_change): + for _ in range(epochs): + for _ in range(steps): + train_step(data, label) + if lr_dynamic: + lr_scheduler.step() + if if_change: + optimizer.param_groups[1]["betas"] = (0.77, 0.7) + optimizer.param_groups[1]["amsgrad"] = False + + train(self.epochs, self.steps, self.lr_dynamic, self.if_change) + output = model(data) + return output.asnumpy() + + def result_cmp(self): + loss_expect = self.forward_pytorch_impl() + loss_out = self.forward_mindspore_impl() + allclose_nparray(loss_expect, loss_out, 0.005, 0.005) + + +def _count_unequal_element(data_expected, data_me, rtol, atol): + assert data_expected.shape == data_me.shape + total_count = len(data_expected.flatten()) + error = np.abs(data_expected - data_me) + greater = np.greater(error, atol + np.abs(data_me) * rtol) + loss_count = np.count_nonzero(greater) + assert (loss_count / total_count) < rtol, \ + "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \ + format(data_expected[greater], data_me[greater], error[greater]) + + +def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True): + if np.any(np.isnan(data_expected)) or np.any(np.isnan(data_me)): + assert np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan) + elif not np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan): + _count_unequal_element(data_expected, data_me, rtol, atol) + else: + assert np.array(data_expected).shape == np.array(data_me).shape + + +@arg_mark(plat_marks=['platform_gpu', 'platform_ascend'], + level_mark='level0', + card_mark='onecard', + essential_mark='essential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_adamw_basic(mode): + """ + Feature: Test adamw. + Description: Test adamw with default parameter. + Expectation: success. + """ + mindspore.set_context(mode=mode, jit_syntax_level=mindspore.STRICT) + fact = AdamWFactory(False, False) + fact.result_cmp() + + +@arg_mark(plat_marks=['platform_gpu', 'platform_ascend'], + level_mark='level1', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_adamw_group(mode): + """ + Feature: Test adamw. + Description: Test adamw with grouped params. + Expectation: success. + """ + mindspore.set_context(mode=mode, jit_syntax_level=mindspore.STRICT) + fact = AdamWFactory(True, False) + fact.result_cmp() + + +@arg_mark(plat_marks=['platform_gpu', 'platform_ascend'], + level_mark='level1', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_adamw_lr_dynamic(mode): + """ + Feature: Test adamw. + Description: Test adamw when lr is dynamic. + Expectation: success. + """ + mindspore.set_context(mode=mode, jit_syntax_level=mindspore.STRICT) + fact = AdamWFactory(False, True) + fact.result_cmp() + + +@arg_mark(plat_marks=['platform_gpu', 'platform_ascend'], + level_mark='level1', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_adamw_group_lr_dynamic(mode): + """ + Feature: Test adamw. + Description: Test adamw with grouped params when lr is dynamic. + Expectation: success. + """ + mindspore.set_context(mode=mode, jit_syntax_level=mindspore.STRICT) + fact = AdamWFactory(True, True) + fact.result_cmp() + + +@arg_mark(plat_marks=['platform_gpu', 'platform_ascend'], + level_mark='level1', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.PYNATIVE_MODE]) +def test_adamw_group_lr_dynamic_change_param(mode): + """ + Feature: Test adamw. + Description: Test adamw with grouped params when optimizer params are changed. + Expectation: success. + """ + mindspore.set_context(mode=mode, jit_syntax_level=mindspore.STRICT) + fact = AdamWFactory(True, True, True) + fact.result_cmp() diff --git a/tests/st/optimizer_ex/test_sgd_cmp.py b/tests/st/optimizer_ex/test_sgd_cmp.py index 8d795080615..c94e5399865 100644 --- a/tests/st/optimizer_ex/test_sgd_cmp.py +++ b/tests/st/optimizer_ex/test_sgd_cmp.py @@ -1,244 +1,244 @@ -from __future__ import absolute_import -import pytest -import numpy as np -import torch -import mindspore -from mindspore import nn -from mindspore import Tensor, context -from mindspore.experimental.optim import SGD -from mindspore.experimental.optim.lr_scheduler import StepLR -from tests.mark_utils import arg_mark - - -class Network(nn.Cell): - def __init__(self, lin_weight, lin_bias): - super().__init__() - self.lin = nn.Dense(2, 3, weight_init=lin_weight, bias_init=lin_bias) - self.relu = nn.ReLU() - - def construct(self, x): - out = self.lin(x) - out = self.relu(out) - return out - - -class NetworkPt(torch.nn.Module): - def __init__(self): - super().__init__() - self.lin = torch.nn.Linear(2, 3) - self.relu = torch.nn.ReLU() - - def forward(self, x): - out = self.lin(x) - out = self.relu(out) - return out - - -class SGDFactory(): - def __init__(self, group=True, lr_dynamic=False, if_change=False, dtype=np.float32): - super().__init__() - np.random.seed(1024) - self.lin_weight_np = np.random.randn(3, 2).astype(dtype) - self.lin_bias_np = np.random.randn(3,).astype(dtype) - - self.group = group - self.lr_dynamic = lr_dynamic - self.if_change = if_change - self.data = np.random.rand(2, 2).astype(np.float32) - self.label = np.random.rand(2, 3).astype(np.float32) - self.epochs = 1 - self.steps = 1 - self.lr = 0.002 - - def forward_pytorch_impl(self): - lin_weight = torch.Tensor(self.lin_weight_np.copy()) - lin_bias = torch.Tensor(self.lin_bias_np.copy()) - - model = NetworkPt() - model.lin.weight = torch.nn.Parameter(lin_weight) - model.lin.bias = torch.nn.Parameter(lin_bias) - - data = torch.from_numpy(self.data.copy()) - label = torch.from_numpy(self.label.copy()) - - if not self.group: - optimizer = torch.optim.SGD(model.parameters(), lr=self.lr) - else: - bias_params, no_bias_params = [], [] - for param in model.named_parameters(): - if "bias" in param[0]: - bias_params.append(param[1]) - else: - no_bias_params.append(param[1]) - group_params = [{'params': bias_params, 'weight_decay': 0.01, 'lr': 0.009, "dampening": 1}, - {'params': no_bias_params, 'lr': 0.006, "momentum": 0.7, "nesterov": True}] - optimizer = torch.optim.SGD(params=group_params, lr=self.lr) - - criterion = torch.nn.L1Loss(reduction='mean') - lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 2, gamma=0.5, last_epoch=-1) - - for _ in range(self.epochs): - for _ in range(self.steps): - optimizer.zero_grad() - loss = criterion(model(data), label) - loss.backward() - optimizer.step() - if self.lr_dynamic: - lr_scheduler.step() - if self.if_change: - optimizer.param_groups[1]["nesterov"] = False - optimizer.param_groups[1]["momentum"] = 0.2 - - output = model(data) - return output.detach().numpy() - - def forward_mindspore_impl(self): - lin_weight = Tensor(self.lin_weight_np.copy()) - lin_bias = Tensor(self.lin_bias_np.copy()) - model_ms = Network(lin_weight, lin_bias) - - data = Tensor(self.data) - label = Tensor(self.label) - - if not self.group: - optimizer = SGD(params=model_ms.trainable_params(), lr=self.lr) - else: - bias_params = list(filter(lambda x: 'bias' in x.name, model_ms.trainable_params())) - no_bias_params = list(filter(lambda x: 'bias' not in x.name, model_ms.trainable_params())) - group_params = [{'params': bias_params, 'weight_decay': 0.01, 'lr': 0.009, "dampening": 1}, - {'params': no_bias_params, 'lr': 0.006, "momentum": 0.7, "nesterov": True}] - optimizer = SGD(params=group_params, lr=self.lr) - - criterion = nn.MAELoss(reduction="mean") - - lr_scheduler = StepLR(optimizer, 2, gamma=0.5, last_epoch=-1) - - def forward_fn(data, label): - logits = model_ms(data) - loss = criterion(logits, label) - return loss, logits - - grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True) - - def train_step(data, label): - (loss, _), grads = grad_fn(data, label) - optimizer(grads) - return loss - - def train(epochs, steps, lr_dynamic, if_change): - for _ in range(epochs): - for _ in range(steps): - train_step(data, label) - if lr_dynamic: - lr_scheduler.step() - if if_change: - optimizer.param_groups[1]["nesterov"] = False - optimizer.param_groups[1]["momentum"] = 0.2 - train(self.epochs, self.steps, self.lr_dynamic, self.if_change) - output = model_ms(data) - return output.asnumpy() - - def result_cmp(self): - loss_expect = self.forward_pytorch_impl() - loss_out = self.forward_mindspore_impl() - allclose_nparray(loss_expect, loss_out, 0.005, 0.005) - - -def _count_unequal_element(data_expected, data_me, rtol, atol): - assert data_expected.shape == data_me.shape - total_count = len(data_expected.flatten()) - error = np.abs(data_expected - data_me) - greater = np.greater(error, atol + np.abs(data_me) * rtol) - loss_count = np.count_nonzero(greater) - assert (loss_count / total_count) < rtol, \ - "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \ - format(data_expected[greater], data_me[greater], error[greater]) - - -def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True): - if np.any(np.isnan(data_expected)) or np.any(np.isnan(data_me)): - assert np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan) - elif not np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan): - _count_unequal_element(data_expected, data_me, rtol, atol) - else: - assert np.array(data_expected).shape == np.array(data_me).shape - - -@arg_mark(plat_marks=['platform_gpu', 'platform_ascend'], - level_mark='level0', - card_mark='onecard', - essential_mark='essential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -def test_sgd_basic(mode): - """ - Feature: Test sgd. - Description: Test sgd with default parameter. - Expectation: success. - """ - mindspore.set_context(mode=mode, jit_syntax_level=mindspore.STRICT) - fact = SGDFactory(False, False) - fact.result_cmp() - - -@arg_mark(plat_marks=['platform_gpu', 'platform_ascend'], - level_mark='level1', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -def test_sgd_group(mode): - """ - Feature: Test sgd. - Description: Test sgd with grouped params. - Expectation: success. - """ - mindspore.set_context(mode=mode, jit_syntax_level=mindspore.STRICT) - fact = SGDFactory(True, False) - fact.result_cmp() - - -@arg_mark(plat_marks=['platform_gpu', 'platform_ascend'], - level_mark='level1', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -def test_sgd_lr_dynamic(mode): - """ - Feature: Test sgd. - Description: Test sgd when lr is dynamic. - Expectation: success. - """ - mindspore.set_context(mode=mode, jit_syntax_level=mindspore.STRICT) - fact = SGDFactory(False, True) - fact.result_cmp() - - -@arg_mark(plat_marks=['platform_gpu', 'platform_ascend'], - level_mark='level1', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -def test_sgd_group_lr_dynamic(mode): - """ - Feature: Test sgd. - Description: Test sgd with grouped params when lr is dynamic. - Expectation: success. - """ - mindspore.set_context(mode=mode, jit_syntax_level=mindspore.STRICT) - fact = SGDFactory(True, True) - fact.result_cmp() - - -@arg_mark(plat_marks=['platform_gpu', 'platform_ascend'], - level_mark='level1', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [context.PYNATIVE_MODE]) -def test_sgd_group_lr_dynamic_change_param(mode): - """ - Feature: Test sgd. - Description: Test sgd with grouped params when optimizer params are changed. - Expectation: success. - """ - mindspore.set_context(mode=mode, jit_syntax_level=mindspore.STRICT) - fact = SGDFactory(True, True, True) - fact.result_cmp() +from __future__ import absolute_import +import pytest +import numpy as np +import torch +import mindspore +from mindspore import nn +from mindspore import Tensor, context +from mindspore.experimental.optim import SGD +from mindspore.experimental.optim.lr_scheduler import StepLR +from tests.mark_utils import arg_mark + + +class Network(nn.Cell): + def __init__(self, lin_weight, lin_bias): + super().__init__() + self.lin = nn.Dense(2, 3, weight_init=lin_weight, bias_init=lin_bias) + self.relu = nn.ReLU() + + def construct(self, x): + out = self.lin(x) + out = self.relu(out) + return out + + +class NetworkPt(torch.nn.Module): + def __init__(self): + super().__init__() + self.lin = torch.nn.Linear(2, 3) + self.relu = torch.nn.ReLU() + + def forward(self, x): + out = self.lin(x) + out = self.relu(out) + return out + + +class SGDFactory(): + def __init__(self, group=True, lr_dynamic=False, if_change=False, dtype=np.float32): + super().__init__() + np.random.seed(1024) + self.lin_weight_np = np.random.randn(3, 2).astype(dtype) + self.lin_bias_np = np.random.randn(3,).astype(dtype) + + self.group = group + self.lr_dynamic = lr_dynamic + self.if_change = if_change + self.data = np.random.rand(2, 2).astype(np.float32) + self.label = np.random.rand(2, 3).astype(np.float32) + self.epochs = 1 + self.steps = 1 + self.lr = 0.002 + + def forward_pytorch_impl(self): + lin_weight = torch.Tensor(self.lin_weight_np.copy()) + lin_bias = torch.Tensor(self.lin_bias_np.copy()) + + model = NetworkPt() + model.lin.weight = torch.nn.Parameter(lin_weight) + model.lin.bias = torch.nn.Parameter(lin_bias) + + data = torch.from_numpy(self.data.copy()) + label = torch.from_numpy(self.label.copy()) + + if not self.group: + optimizer = torch.optim.SGD(model.parameters(), lr=self.lr) + else: + bias_params, no_bias_params = [], [] + for param in model.named_parameters(): + if "bias" in param[0]: + bias_params.append(param[1]) + else: + no_bias_params.append(param[1]) + group_params = [{'params': bias_params, 'weight_decay': 0.01, 'lr': 0.009, "dampening": 1}, + {'params': no_bias_params, 'lr': 0.006, "momentum": 0.7, "nesterov": True}] + optimizer = torch.optim.SGD(params=group_params, lr=self.lr) + + criterion = torch.nn.L1Loss(reduction='mean') + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 2, gamma=0.5, last_epoch=-1) + + for _ in range(self.epochs): + for _ in range(self.steps): + optimizer.zero_grad() + loss = criterion(model(data), label) + loss.backward() + optimizer.step() + if self.lr_dynamic: + lr_scheduler.step() + if self.if_change: + optimizer.param_groups[1]["nesterov"] = False + optimizer.param_groups[1]["momentum"] = 0.2 + + output = model(data) + return output.detach().numpy() + + def forward_mindspore_impl(self): + lin_weight = Tensor(self.lin_weight_np.copy()) + lin_bias = Tensor(self.lin_bias_np.copy()) + model_ms = Network(lin_weight, lin_bias) + + data = Tensor(self.data) + label = Tensor(self.label) + + if not self.group: + optimizer = SGD(params=model_ms.trainable_params(), lr=self.lr) + else: + bias_params = list(filter(lambda x: 'bias' in x.name, model_ms.trainable_params())) + no_bias_params = list(filter(lambda x: 'bias' not in x.name, model_ms.trainable_params())) + group_params = [{'params': bias_params, 'weight_decay': 0.01, 'lr': 0.009, "dampening": 1}, + {'params': no_bias_params, 'lr': 0.006, "momentum": 0.7, "nesterov": True}] + optimizer = SGD(params=group_params, lr=self.lr) + + criterion = nn.MAELoss(reduction="mean") + + lr_scheduler = StepLR(optimizer, 2, gamma=0.5, last_epoch=-1) + + def forward_fn(data, label): + logits = model_ms(data) + loss = criterion(logits, label) + return loss, logits + + grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True) + + def train_step(data, label): + (loss, _), grads = grad_fn(data, label) + optimizer(grads) + return loss + + def train(epochs, steps, lr_dynamic, if_change): + for _ in range(epochs): + for _ in range(steps): + train_step(data, label) + if lr_dynamic: + lr_scheduler.step() + if if_change: + optimizer.param_groups[1]["nesterov"] = False + optimizer.param_groups[1]["momentum"] = 0.2 + train(self.epochs, self.steps, self.lr_dynamic, self.if_change) + output = model_ms(data) + return output.asnumpy() + + def result_cmp(self): + loss_expect = self.forward_pytorch_impl() + loss_out = self.forward_mindspore_impl() + allclose_nparray(loss_expect, loss_out, 0.005, 0.005) + + +def _count_unequal_element(data_expected, data_me, rtol, atol): + assert data_expected.shape == data_me.shape + total_count = len(data_expected.flatten()) + error = np.abs(data_expected - data_me) + greater = np.greater(error, atol + np.abs(data_me) * rtol) + loss_count = np.count_nonzero(greater) + assert (loss_count / total_count) < rtol, \ + "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \ + format(data_expected[greater], data_me[greater], error[greater]) + + +def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True): + if np.any(np.isnan(data_expected)) or np.any(np.isnan(data_me)): + assert np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan) + elif not np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan): + _count_unequal_element(data_expected, data_me, rtol, atol) + else: + assert np.array(data_expected).shape == np.array(data_me).shape + + +@arg_mark(plat_marks=['platform_gpu', 'platform_ascend'], + level_mark='level0', + card_mark='onecard', + essential_mark='essential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_sgd_basic(mode): + """ + Feature: Test sgd. + Description: Test sgd with default parameter. + Expectation: success. + """ + mindspore.set_context(mode=mode, jit_syntax_level=mindspore.STRICT) + fact = SGDFactory(False, False) + fact.result_cmp() + + +@arg_mark(plat_marks=['platform_gpu', 'platform_ascend'], + level_mark='level1', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_sgd_group(mode): + """ + Feature: Test sgd. + Description: Test sgd with grouped params. + Expectation: success. + """ + mindspore.set_context(mode=mode, jit_syntax_level=mindspore.STRICT) + fact = SGDFactory(True, False) + fact.result_cmp() + + +@arg_mark(plat_marks=['platform_gpu', 'platform_ascend'], + level_mark='level1', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_sgd_lr_dynamic(mode): + """ + Feature: Test sgd. + Description: Test sgd when lr is dynamic. + Expectation: success. + """ + mindspore.set_context(mode=mode, jit_syntax_level=mindspore.STRICT) + fact = SGDFactory(False, True) + fact.result_cmp() + + +@arg_mark(plat_marks=['platform_gpu', 'platform_ascend'], + level_mark='level1', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_sgd_group_lr_dynamic(mode): + """ + Feature: Test sgd. + Description: Test sgd with grouped params when lr is dynamic. + Expectation: success. + """ + mindspore.set_context(mode=mode, jit_syntax_level=mindspore.STRICT) + fact = SGDFactory(True, True) + fact.result_cmp() + + +@arg_mark(plat_marks=['platform_gpu', 'platform_ascend'], + level_mark='level1', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [context.PYNATIVE_MODE]) +def test_sgd_group_lr_dynamic_change_param(mode): + """ + Feature: Test sgd. + Description: Test sgd with grouped params when optimizer params are changed. + Expectation: success. + """ + mindspore.set_context(mode=mode, jit_syntax_level=mindspore.STRICT) + fact = SGDFactory(True, True, True) + fact.result_cmp() diff --git a/tests/st/pynative/dynamic_shape/test_pynative_graph_structure_changed.py b/tests/st/pynative/dynamic_shape/test_pynative_graph_structure_changed.py index 14726b2370e..3af91a8b90a 100644 --- a/tests/st/pynative/dynamic_shape/test_pynative_graph_structure_changed.py +++ b/tests/st/pynative/dynamic_shape/test_pynative_graph_structure_changed.py @@ -1,45 +1,45 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import numpy as np -import mindspore as ms -from mindspore import nn, value_and_grad -import torch -from tests.mark_utils import arg_mark - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_linux'], - level_mark='level0', - card_mark='onecard', - essential_mark='essential') -def test_pynative_graph_structure_changed(): - """ - Feature: PyNative dynamic shape for Ascend. - Description: Test PyNative dynamic shape if set kernel info. - Expectation: The calculation result is correct and have no exceptip in process. - """ - input_np = np.random.rand(1, 3, 2, 2).astype(np.float32) - - ms_net = nn.BatchNorm2d(num_features=3) - ms_net.set_train() - ms_x = ms.Tensor(input_np) - x_dyn = ms.Tensor(shape=[None for _ in ms_x.shape], dtype=ms_x.dtype) - ms_net.set_inputs(x_dyn) - grad_fn = value_and_grad(ms_net) - ms_output, _ = grad_fn(ms_x) - - torch_net = torch.nn.BatchNorm2d(num_features=3) - torch_x = torch.tensor(input_np) - torch_output = torch_net(torch_x) - assert np.allclose(torch_output.detach().numpy(), ms_output.asnumpy(), 0.01, 0.01) +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import mindspore as ms +from mindspore import nn, value_and_grad +import torch +from tests.mark_utils import arg_mark + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_linux'], + level_mark='level0', + card_mark='onecard', + essential_mark='essential') +def test_pynative_graph_structure_changed(): + """ + Feature: PyNative dynamic shape for Ascend. + Description: Test PyNative dynamic shape if set kernel info. + Expectation: The calculation result is correct and have no exceptip in process. + """ + input_np = np.random.rand(1, 3, 2, 2).astype(np.float32) + + ms_net = nn.BatchNorm2d(num_features=3) + ms_net.set_train() + ms_x = ms.Tensor(input_np) + x_dyn = ms.Tensor(shape=[None for _ in ms_x.shape], dtype=ms_x.dtype) + ms_net.set_inputs(x_dyn) + grad_fn = value_and_grad(ms_net) + ms_output, _ = grad_fn(ms_x) + + torch_net = torch.nn.BatchNorm2d(num_features=3) + torch_x = torch.tensor(input_np) + torch_output = torch_net(torch_x) + assert np.allclose(torch_output.detach().numpy(), ms_output.asnumpy(), 0.01, 0.01) diff --git a/tests/st/pynative/network/test_pynative_resnet50_ascend_8p.py b/tests/st/pynative/network/test_pynative_resnet50_ascend_8p.py index 71b1fc9994a..3e33d5ba328 100644 --- a/tests/st/pynative/network/test_pynative_resnet50_ascend_8p.py +++ b/tests/st/pynative/network/test_pynative_resnet50_ascend_8p.py @@ -1,32 +1,32 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import os -from tests.mark_utils import arg_mark - - -@arg_mark(plat_marks=['platform_ascend'], - level_mark='level0', - card_mark='allcards', - essential_mark='essential') -def test_pynative_resnet50_ascend_8p_mpi(): - """ - Feature: PyNative ResNet50 8P - Description: test PyNative ResNet50 8p with mpirun - Expectation: success, return_code==0 - """ - os.system("mpirun -n 8 pytest -s test_pynative_resnet50_ascend.py::test_train_tensor" - " >stdout.log 2>&1") - return_code = os.system(r"grep '1 passed' stdout.log") - assert return_code == 0 +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import os +from tests.mark_utils import arg_mark + + +@arg_mark(plat_marks=['platform_ascend'], + level_mark='level0', + card_mark='allcards', + essential_mark='essential') +def test_pynative_resnet50_ascend_8p_mpi(): + """ + Feature: PyNative ResNet50 8P + Description: test PyNative ResNet50 8p with mpirun + Expectation: success, return_code==0 + """ + os.system("mpirun -n 8 pytest -s test_pynative_resnet50_ascend.py::test_train_tensor" + " >stdout.log 2>&1") + return_code = os.system(r"grep '1 passed' stdout.log") + assert return_code == 0 diff --git a/tests/st/pynative/test_pynative_heterogeneous.py b/tests/st/pynative/test_pynative_heterogeneous.py index 3ab7a4c384d..c6b21edf32e 100644 --- a/tests/st/pynative/test_pynative_heterogeneous.py +++ b/tests/st/pynative/test_pynative_heterogeneous.py @@ -1,115 +1,115 @@ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" test_pynative_heterogeneous """ -import numpy as np - -from mindspore import context, Tensor -from mindspore.nn import Cell -import mindspore.ops as ops -from tests.mark_utils import arg_mark - - -class MulRelu(Cell): - def __init__(self): - super(MulRelu, self).__init__() - self.relu1 = ops.ReLU() - self.relu2 = ops.ReLU() - self.mul = ops.Mul() - - def construct(self, inp1, inp2): - x1 = self.relu1(inp1) - x2 = self.relu2(inp2) - y = self.mul(x1, x2) - return y - - -@arg_mark(plat_marks=['platform_ascend'], - level_mark='level0', - card_mark='onecard', - essential_mark='essential') -def test_heterogeneous_default_ascend_prim_cpu(): - """ - Feature: PyNative heterogeneous. - Description: Default device target is Ascend, the relu1 set to CPU. - Expectation: The output of device is equal to the output of heterogeneous. - """ - context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") - net = MulRelu() - inp1 = Tensor(np.random.randn(2, 2).astype(np.float32)) - inp2 = Tensor(np.random.randn(2, 2).astype(np.float32)) - output_device = net(inp1, inp2) - net.relu1.set_device("CPU") - output_heter = net(inp1, inp2) - assert np.allclose(output_device.asnumpy(), output_heter.asnumpy(), 1e-6, 1e-6) - - -@arg_mark(plat_marks=['platform_ascend'], - level_mark='level0', - card_mark='onecard', - essential_mark='essential') -def test_heterogeneous_default_cpu_prim_ascend(): - """ - Feature: PyNative heterogeneous. - Description: Default device target is CPU, the relu1 set to Ascend. - Expectation: The output of device is equal to the output of heterogeneous. - """ - context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU") - net = MulRelu() - inp1 = Tensor(np.random.randn(2, 2).astype(np.float32)) - inp2 = Tensor(np.random.randn(2, 2).astype(np.float32)) - output_device = net(inp1, inp2) - net.relu1.set_device("Ascend") - output_heter = net(inp1, inp2) - assert np.allclose(output_device.asnumpy(), output_heter.asnumpy(), 1e-6, 1e-6) - - -@arg_mark(plat_marks=['platform_gpu'], - level_mark='level1', - card_mark='onecard', - essential_mark='essential') -def test_heterogeneous_default_gpu_prim_cpu(): - """ - Feature: PyNative heterogeneous. - Description: Default device target is GPU, the relu1 set to CPU. - Expectation: The output of device is equal to the output of heterogeneous. - """ - context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") - net = MulRelu() - inp1 = Tensor(np.random.randn(2, 2).astype(np.float32)) - inp2 = Tensor(np.random.randn(2, 2).astype(np.float32)) - output_device = net(inp1, inp2) - net.relu1.set_device("CPU") - output_heter = net(inp1, inp2) - assert np.allclose(output_device.asnumpy(), output_heter.asnumpy(), 1e-6, 1e-6) - - -@arg_mark(plat_marks=['platform_gpu'], - level_mark='level1', - card_mark='onecard', - essential_mark='essential') -def test_heterogeneous_default_cpu_prim_gpu(): - """ - Feature: PyNative heterogeneous. - Description: Default device target is CPU, the relu1 set to GPU. - Expectation: The output of device is equal to the output of heterogeneous. - """ - context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU") - net = MulRelu() - inp1 = Tensor(np.random.randn(2, 2).astype(np.float32)) - inp2 = Tensor(np.random.randn(2, 2).astype(np.float32)) - output_device = net(inp1, inp2) - net.relu1.set_device("GPU") - output_heter = net(inp1, inp2) - assert np.allclose(output_device.asnumpy(), output_heter.asnumpy(), 1e-6, 1e-6) +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test_pynative_heterogeneous """ +import numpy as np + +from mindspore import context, Tensor +from mindspore.nn import Cell +import mindspore.ops as ops +from tests.mark_utils import arg_mark + + +class MulRelu(Cell): + def __init__(self): + super(MulRelu, self).__init__() + self.relu1 = ops.ReLU() + self.relu2 = ops.ReLU() + self.mul = ops.Mul() + + def construct(self, inp1, inp2): + x1 = self.relu1(inp1) + x2 = self.relu2(inp2) + y = self.mul(x1, x2) + return y + + +@arg_mark(plat_marks=['platform_ascend'], + level_mark='level0', + card_mark='onecard', + essential_mark='essential') +def test_heterogeneous_default_ascend_prim_cpu(): + """ + Feature: PyNative heterogeneous. + Description: Default device target is Ascend, the relu1 set to CPU. + Expectation: The output of device is equal to the output of heterogeneous. + """ + context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + net = MulRelu() + inp1 = Tensor(np.random.randn(2, 2).astype(np.float32)) + inp2 = Tensor(np.random.randn(2, 2).astype(np.float32)) + output_device = net(inp1, inp2) + net.relu1.set_device("CPU") + output_heter = net(inp1, inp2) + assert np.allclose(output_device.asnumpy(), output_heter.asnumpy(), 1e-6, 1e-6) + + +@arg_mark(plat_marks=['platform_ascend'], + level_mark='level0', + card_mark='onecard', + essential_mark='essential') +def test_heterogeneous_default_cpu_prim_ascend(): + """ + Feature: PyNative heterogeneous. + Description: Default device target is CPU, the relu1 set to Ascend. + Expectation: The output of device is equal to the output of heterogeneous. + """ + context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU") + net = MulRelu() + inp1 = Tensor(np.random.randn(2, 2).astype(np.float32)) + inp2 = Tensor(np.random.randn(2, 2).astype(np.float32)) + output_device = net(inp1, inp2) + net.relu1.set_device("Ascend") + output_heter = net(inp1, inp2) + assert np.allclose(output_device.asnumpy(), output_heter.asnumpy(), 1e-6, 1e-6) + + +@arg_mark(plat_marks=['platform_gpu'], + level_mark='level1', + card_mark='onecard', + essential_mark='essential') +def test_heterogeneous_default_gpu_prim_cpu(): + """ + Feature: PyNative heterogeneous. + Description: Default device target is GPU, the relu1 set to CPU. + Expectation: The output of device is equal to the output of heterogeneous. + """ + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + net = MulRelu() + inp1 = Tensor(np.random.randn(2, 2).astype(np.float32)) + inp2 = Tensor(np.random.randn(2, 2).astype(np.float32)) + output_device = net(inp1, inp2) + net.relu1.set_device("CPU") + output_heter = net(inp1, inp2) + assert np.allclose(output_device.asnumpy(), output_heter.asnumpy(), 1e-6, 1e-6) + + +@arg_mark(plat_marks=['platform_gpu'], + level_mark='level1', + card_mark='onecard', + essential_mark='essential') +def test_heterogeneous_default_cpu_prim_gpu(): + """ + Feature: PyNative heterogeneous. + Description: Default device target is CPU, the relu1 set to GPU. + Expectation: The output of device is equal to the output of heterogeneous. + """ + context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU") + net = MulRelu() + inp1 = Tensor(np.random.randn(2, 2).astype(np.float32)) + inp2 = Tensor(np.random.randn(2, 2).astype(np.float32)) + output_device = net(inp1, inp2) + net.relu1.set_device("GPU") + output_heter = net(inp1, inp2) + assert np.allclose(output_device.asnumpy(), output_heter.asnumpy(), 1e-6, 1e-6) diff --git a/tests/st/pynative/test_pynative_sync_control.py b/tests/st/pynative/test_pynative_sync_control.py index a2fb44ff58d..72e2ce68283 100644 --- a/tests/st/pynative/test_pynative_sync_control.py +++ b/tests/st/pynative/test_pynative_sync_control.py @@ -1,51 +1,51 @@ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import numpy as np -import pytest -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.common import dtype as mstype -from mindspore.ops import operations as P - -context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") - -class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - self.get_next = P.GetNext([mstype.float32], [(1, 1)], 1, "test") - - def construct(self, x1,): - x = self.get_next() - x = x + x1 - return x - -def test_pynative_synchronize_true(): - context.set_context(pynative_synchronize=True) - with pytest.raises(RuntimeError) as execinfo: - x1 = np.random.randn(1, 1).astype(np.float32) - net = Net() - output = net(Tensor(x1)) - print(output.asnumpy()) - assert "GetNext" in str(execinfo.value) - -def test_pynative_synchronize_false(): - context.set_context(pynative_synchronize=False) - with pytest.raises(RuntimeError) as execinfo: - x1 = np.random.randn(1, 1).astype(np.float32) - net = Net() - output = net(Tensor(x1)) - print(output.asnumpy()) - assert "Sync stream error" in str(execinfo.value) +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common import dtype as mstype +from mindspore.ops import operations as P + +context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.get_next = P.GetNext([mstype.float32], [(1, 1)], 1, "test") + + def construct(self, x1,): + x = self.get_next() + x = x + x1 + return x + +def test_pynative_synchronize_true(): + context.set_context(pynative_synchronize=True) + with pytest.raises(RuntimeError) as execinfo: + x1 = np.random.randn(1, 1).astype(np.float32) + net = Net() + output = net(Tensor(x1)) + print(output.asnumpy()) + assert "GetNext" in str(execinfo.value) + +def test_pynative_synchronize_false(): + context.set_context(pynative_synchronize=False) + with pytest.raises(RuntimeError) as execinfo: + x1 = np.random.randn(1, 1).astype(np.float32) + net = Net() + output = net(Tensor(x1)) + print(output.asnumpy()) + assert "Sync stream error" in str(execinfo.value) diff --git a/tests/st/sparse/test_sparse_unary_ops.py b/tests/st/sparse/test_sparse_unary_ops.py old mode 100755 new mode 100644 diff --git a/tests/st/summary/test_summary_collector.py b/tests/st/summary/test_summary_collector.py index 151df8886c4..c6573c22241 100644 --- a/tests/st/summary/test_summary_collector.py +++ b/tests/st/summary/test_summary_collector.py @@ -1,324 +1,324 @@ -# Copyright 2020-2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""test SummaryCollector.""" -import json -import os -import re -import tempfile -from collections import Counter - -import numpy as np -import pytest -from mindspore.common import set_seed -from mindspore.common.initializer import Normal -from mindspore.nn.optim import Momentum -from mindspore.ops import operations as P -from mindspore.train import Loss, Model - -from mindspore import SummaryCollector, SummaryLandscape, SummaryRecord, Tensor, context, nn -from tests.st.summary.dataset import create_mnist_dataset -from tests.summary_utils import SummaryReader -from tests.mark_utils import arg_mark - - -def callback_fn(): - """A python function job""" - network = LeNet5() - loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") - metrics = {"Loss": Loss()} - model = Model(network, loss, metrics=metrics) - ds_train = create_mnist_dataset("train", num_samples=6) - return model, network, ds_train, metrics - - -class LeNet5(nn.Cell): - """ - Lenet network - - Args: - num_class (int): Number of classes. Default: 10. - num_channel (int): Number of channels. Default: 1. - - Returns: - Tensor, output tensor - Examples: - >>> LeNet(num_class=10) - - """ - - def __init__(self, num_class=10, num_channel=1, include_top=True): - super(LeNet5, self).__init__() - self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid', weight_init="normal", bias_init="zeros") - self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid', weight_init="normal", bias_init="zeros") - self.relu = nn.ReLU() - self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) - self.include_top = include_top - if self.include_top: - self.flatten = nn.Flatten() - self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02), bias_init="zeros") - self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02), bias_init="zeros") - self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02), bias_init="zeros") - - self.scalar_summary = P.ScalarSummary() - self.image_summary = P.ImageSummary() - self.histogram_summary = P.HistogramSummary() - self.tensor_summary = P.TensorSummary() - self.channel = Tensor(num_channel) - - def construct(self, x): - """construct.""" - self.image_summary('image', x) - x = self.conv1(x) - self.histogram_summary('histogram', x) - x = self.relu(x) - self.tensor_summary('tensor', x) - x = self.relu(x) - x = self.max_pool2d(x) - self.scalar_summary('scalar', self.channel) - x = self.conv2(x) - x = self.relu(x) - x = self.max_pool2d(x) - if not self.include_top: - return x - x = self.flatten(x) - x = self.relu(self.fc1(x)) - x = self.relu(self.fc2(x)) - x = self.fc3(x) - return x - - -def run_network(dataset_sink_mode=False, num_samples=2, dir_suffix="summary", **kwargs): - """run network.""" - lenet = LeNet5() - loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") - optim = Momentum(lenet.trainable_params(), learning_rate=0.1, momentum=0.9) - model = Model(lenet, loss_fn=loss, optimizer=optim, metrics={'loss': Loss()}) - summary_dir = tempfile.TemporaryDirectory(suffix=dir_suffix) - summary_collector = SummaryCollector(summary_dir=summary_dir.name, collect_freq=2, **kwargs) - - ds_train = create_mnist_dataset("train", num_samples=num_samples) - model.train(3, ds_train, callbacks=[summary_collector], dataset_sink_mode=dataset_sink_mode) - - ds_eval = create_mnist_dataset("test") - model.eval(ds_eval, dataset_sink_mode=dataset_sink_mode, callbacks=[summary_collector]) - return summary_dir - - -def train_network(epoch=3, dataset_sink_mode=False, num_samples=2, dir_suffix="summary", **kwargs): - """run network.""" - lenet = LeNet5() - loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") - optim = Momentum(lenet.trainable_params(), learning_rate=0.01, momentum=0.9) - model = Model(lenet, loss_fn=loss, optimizer=optim) - summary_dir = tempfile.TemporaryDirectory(suffix=dir_suffix) - summary_collector = SummaryCollector(summary_dir=summary_dir.name, collect_freq=2, **kwargs) - - ds_train = create_mnist_dataset("train", num_samples=num_samples) - model.train(epoch, ds_train, callbacks=[summary_collector], dataset_sink_mode=dataset_sink_mode) - return summary_dir - - -class TestSummary: - """Test summary collector the basic function.""" - - @staticmethod - def list_summary_tags(summary_dir): - """list summary tags.""" - summary_file_path = '' - for file in os.listdir(summary_dir): - if re.search("_MS", file): - summary_file_path = os.path.join(summary_dir, file) - break - assert summary_file_path - - tags = list() - with SummaryReader(summary_file_path) as summary_reader: - - while True: - summary_event = summary_reader.read_event() - if not summary_event: - break - for value in summary_event.summary.value: - tags.append(value.tag) - return tags - - @staticmethod - def list_tensor_files(summary_dir): - """list tensor tags.""" - export_file_path = '' - for file in os.listdir(summary_dir): - if re.search("export_", file): - export_file_path = os.path.join(summary_dir, file) - break - assert export_file_path - tensor_file_path = os.path.join(export_file_path, "tensor") - assert tensor_file_path - - tensors = list() - for file in os.listdir(tensor_file_path): - tensors.append(file) - - return tensors - - @staticmethod - def list_summary_collect_landscape_tags(summary_dir): - """list summary landscape tags.""" - summary_dir_path = '' - for file in os.listdir(summary_dir): - if re.search("ckpt_dir", file): - summary_dir_path = os.path.join(summary_dir, file) - break - assert summary_dir_path - - summary_file_path = '' - for file in os.listdir(summary_dir_path): - if re.search(".json", file): - summary_file_path = os.path.join(summary_dir_path, file) - break - assert summary_file_path - - tags = list() - with open(summary_file_path, 'r') as file: - data = json.load(file) - for key, value in data.items(): - tags.append(key) - - assert value - return tags - - @staticmethod - def list_landscape_tags(summary_dir): - """list landscape tags.""" - expected_tags = {'landscape_[1, 3]', 'landscape_[3]'} - summary_list = [] - for file in os.listdir(summary_dir): - if re.search("_MS", file): - summary_file_path = os.path.join(summary_dir, file) - summary_list = summary_list + [summary_file_path] - else: - continue - - assert summary_list - - tags = [] - for summary_path in summary_list: - with SummaryReader(summary_path) as summary_reader: - - while True: - summary_event = summary_reader.read_event() - if not summary_event: - break - for value in summary_event.summary.value: - if value.tag in expected_tags: - tags.append(value.loss_landscape.landscape.z.float_data) - break - return tags - - -@arg_mark(plat_marks=["platform_ascend", "platform_gpu"], level_mark="level0", card_mark="onecard", - essential_mark="essential") -def test_summary_with_sink_mode_false(): - """ - Feature: Test summary with sink mode false, and num samples is 64. - Description: Test summary with sink mode false, and num samples is 64. - Expectation: Passed. - """ - context.set_context(mode=context.GRAPH_MODE) - summary_dir = run_network(num_samples=10, dir_suffix="test_summary_with_sink_mode_false") - - tag_list = TestSummary.list_summary_tags(summary_dir.name) - - expected_summary_tag_set = {'conv1.weight/auto', 'conv2.weight/auto', 'fc1.weight/auto', 'fc1.bias/auto', - 'fc2.weight/auto', 'input_data/auto', 'loss/auto'} - expected_op_tag_set = {'histogram', 'image', 'scalar', 'tensor'} - assert set(expected_op_tag_set | expected_summary_tag_set) == set(tag_list) - - op_tag_count, summary_tag_count = 8, 9 - for key, value in Counter(tag_list).items(): # pylint: disable=E1121 - if key in expected_op_tag_set: - assert value == op_tag_count - if key in expected_summary_tag_set: - assert value == summary_tag_count - - -@arg_mark(plat_marks=["platform_ascend"], level_mark="level0", card_mark="onecard", essential_mark="essential") -def test_summarycollector_user_defind(): - """ - Feature: Test SummaryCollector with user-defined. - Description: Test SummaryCollector with user-defined. - Expectation: Passed. - """ - context.set_context(mode=context.GRAPH_MODE) - summary_dir = run_network(dataset_sink_mode=False, num_samples=2, dir_suffix="test_summarycollector_user_defind", - custom_lineage_data={'test': 'self test'}, - export_options={'tensor_format': 'npy'}) - - tag_list = TestSummary.list_summary_tags(summary_dir.name) - file_list = TestSummary.list_tensor_files(summary_dir.name) - # There will not record input data when dataset sink mode is True - expected_tags = {'conv1.weight/auto', 'conv2.weight/auto', 'fc1.weight/auto', 'fc1.bias/auto', - 'fc2.weight/auto', 'loss/auto', 'input_data/auto', 'histogram', 'image', 'scalar', 'tensor'} - assert set(expected_tags) == set(tag_list) - expected_files = {'tensor_1.npy', 'tensor_2.npy'} - assert set(expected_files) == set(file_list) - - -@arg_mark(plat_marks=["platform_gpu"], level_mark="level0", card_mark="onecard", essential_mark="essential") -def test_summary_collector_landscape(): - """ - Feature: Summary collector with landscape. - Description: Test summary collector with landscape. - Expectation: Landscape data collected with expected value. - """ - context.set_context(mode=context.GRAPH_MODE) - set_seed(1) - interval_1 = [1, 2, 3] - num_samples = 6 - summary_dir = train_network(epoch=3, num_samples=num_samples, dir_suffix="test_summary_collector_landscape", - collect_specified_data={'collect_landscape': - {'landscape_size': 4, - 'unit': 'epoch', - 'create_landscape': {'train': True, 'result': True}, - 'num_samples': num_samples, - 'intervals': [interval_1]}}) - - tag_list = TestSummary.list_summary_collect_landscape_tags(summary_dir.name) - expected_tags = {'epoch_group', 'model_params_file_map', 'step_per_epoch', 'unit', 'num_samples', - 'landscape_size', 'create_landscape'} - assert set(expected_tags) == set(tag_list) - device_id = int(os.getenv('DEVICE_ID')) if os.getenv('DEVICE_ID') else 0 - summary_landscape = SummaryLandscape(summary_dir.name) - summary_landscape.gen_landscapes_with_multi_process(callback_fn, device_ids=[device_id]) - tag_list_landscape = TestSummary.list_landscape_tags(summary_dir.name) - assert np.allclose(tag_list_landscape[0], 2.28, atol=0.03) - assert np.allclose(tag_list_landscape[1], 2.28, atol=0.03) - - -@arg_mark(plat_marks=["platform_ascend", "platform_gpu"], level_mark="level1", card_mark="onecard", - essential_mark="essential") -def test_summary_of_more_than_one_instance(): - """ - Feature: Test the multi instances of SummaryRecord in a script. - Description: Multi instances of SummaryRecord in a script. - Expectation: Throw RuntimeError. - """ - context.set_context(mode=context.GRAPH_MODE) - with pytest.raises(RuntimeError) as errinfo: - summary_dir1 = tempfile.TemporaryDirectory(suffix="test_summary_of_more_than_one_instance") - summary_record1 = SummaryRecord(log_dir=summary_dir1.name) - summary_dir2 = tempfile.TemporaryDirectory(suffix="test_summary_of_more_than_one_instance") - _ = SummaryRecord(log_dir=summary_dir2.name) - summary_record1.close() - assert "only one instance is supported in a training process" in str(errinfo.value) +# Copyright 2020-2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test SummaryCollector.""" +import json +import os +import re +import tempfile +from collections import Counter + +import numpy as np +import pytest +from mindspore.common import set_seed +from mindspore.common.initializer import Normal +from mindspore.nn.optim import Momentum +from mindspore.ops import operations as P +from mindspore.train import Loss, Model + +from mindspore import SummaryCollector, SummaryLandscape, SummaryRecord, Tensor, context, nn +from tests.st.summary.dataset import create_mnist_dataset +from tests.summary_utils import SummaryReader +from tests.mark_utils import arg_mark + + +def callback_fn(): + """A python function job""" + network = LeNet5() + loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") + metrics = {"Loss": Loss()} + model = Model(network, loss, metrics=metrics) + ds_train = create_mnist_dataset("train", num_samples=6) + return model, network, ds_train, metrics + + +class LeNet5(nn.Cell): + """ + Lenet network + + Args: + num_class (int): Number of classes. Default: 10. + num_channel (int): Number of channels. Default: 1. + + Returns: + Tensor, output tensor + Examples: + >>> LeNet(num_class=10) + + """ + + def __init__(self, num_class=10, num_channel=1, include_top=True): + super(LeNet5, self).__init__() + self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid', weight_init="normal", bias_init="zeros") + self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid', weight_init="normal", bias_init="zeros") + self.relu = nn.ReLU() + self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) + self.include_top = include_top + if self.include_top: + self.flatten = nn.Flatten() + self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02), bias_init="zeros") + self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02), bias_init="zeros") + self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02), bias_init="zeros") + + self.scalar_summary = P.ScalarSummary() + self.image_summary = P.ImageSummary() + self.histogram_summary = P.HistogramSummary() + self.tensor_summary = P.TensorSummary() + self.channel = Tensor(num_channel) + + def construct(self, x): + """construct.""" + self.image_summary('image', x) + x = self.conv1(x) + self.histogram_summary('histogram', x) + x = self.relu(x) + self.tensor_summary('tensor', x) + x = self.relu(x) + x = self.max_pool2d(x) + self.scalar_summary('scalar', self.channel) + x = self.conv2(x) + x = self.relu(x) + x = self.max_pool2d(x) + if not self.include_top: + return x + x = self.flatten(x) + x = self.relu(self.fc1(x)) + x = self.relu(self.fc2(x)) + x = self.fc3(x) + return x + + +def run_network(dataset_sink_mode=False, num_samples=2, dir_suffix="summary", **kwargs): + """run network.""" + lenet = LeNet5() + loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") + optim = Momentum(lenet.trainable_params(), learning_rate=0.1, momentum=0.9) + model = Model(lenet, loss_fn=loss, optimizer=optim, metrics={'loss': Loss()}) + summary_dir = tempfile.TemporaryDirectory(suffix=dir_suffix) + summary_collector = SummaryCollector(summary_dir=summary_dir.name, collect_freq=2, **kwargs) + + ds_train = create_mnist_dataset("train", num_samples=num_samples) + model.train(3, ds_train, callbacks=[summary_collector], dataset_sink_mode=dataset_sink_mode) + + ds_eval = create_mnist_dataset("test") + model.eval(ds_eval, dataset_sink_mode=dataset_sink_mode, callbacks=[summary_collector]) + return summary_dir + + +def train_network(epoch=3, dataset_sink_mode=False, num_samples=2, dir_suffix="summary", **kwargs): + """run network.""" + lenet = LeNet5() + loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") + optim = Momentum(lenet.trainable_params(), learning_rate=0.01, momentum=0.9) + model = Model(lenet, loss_fn=loss, optimizer=optim) + summary_dir = tempfile.TemporaryDirectory(suffix=dir_suffix) + summary_collector = SummaryCollector(summary_dir=summary_dir.name, collect_freq=2, **kwargs) + + ds_train = create_mnist_dataset("train", num_samples=num_samples) + model.train(epoch, ds_train, callbacks=[summary_collector], dataset_sink_mode=dataset_sink_mode) + return summary_dir + + +class TestSummary: + """Test summary collector the basic function.""" + + @staticmethod + def list_summary_tags(summary_dir): + """list summary tags.""" + summary_file_path = '' + for file in os.listdir(summary_dir): + if re.search("_MS", file): + summary_file_path = os.path.join(summary_dir, file) + break + assert summary_file_path + + tags = list() + with SummaryReader(summary_file_path) as summary_reader: + + while True: + summary_event = summary_reader.read_event() + if not summary_event: + break + for value in summary_event.summary.value: + tags.append(value.tag) + return tags + + @staticmethod + def list_tensor_files(summary_dir): + """list tensor tags.""" + export_file_path = '' + for file in os.listdir(summary_dir): + if re.search("export_", file): + export_file_path = os.path.join(summary_dir, file) + break + assert export_file_path + tensor_file_path = os.path.join(export_file_path, "tensor") + assert tensor_file_path + + tensors = list() + for file in os.listdir(tensor_file_path): + tensors.append(file) + + return tensors + + @staticmethod + def list_summary_collect_landscape_tags(summary_dir): + """list summary landscape tags.""" + summary_dir_path = '' + for file in os.listdir(summary_dir): + if re.search("ckpt_dir", file): + summary_dir_path = os.path.join(summary_dir, file) + break + assert summary_dir_path + + summary_file_path = '' + for file in os.listdir(summary_dir_path): + if re.search(".json", file): + summary_file_path = os.path.join(summary_dir_path, file) + break + assert summary_file_path + + tags = list() + with open(summary_file_path, 'r') as file: + data = json.load(file) + for key, value in data.items(): + tags.append(key) + + assert value + return tags + + @staticmethod + def list_landscape_tags(summary_dir): + """list landscape tags.""" + expected_tags = {'landscape_[1, 3]', 'landscape_[3]'} + summary_list = [] + for file in os.listdir(summary_dir): + if re.search("_MS", file): + summary_file_path = os.path.join(summary_dir, file) + summary_list = summary_list + [summary_file_path] + else: + continue + + assert summary_list + + tags = [] + for summary_path in summary_list: + with SummaryReader(summary_path) as summary_reader: + + while True: + summary_event = summary_reader.read_event() + if not summary_event: + break + for value in summary_event.summary.value: + if value.tag in expected_tags: + tags.append(value.loss_landscape.landscape.z.float_data) + break + return tags + + +@arg_mark(plat_marks=["platform_ascend", "platform_gpu"], level_mark="level0", card_mark="onecard", + essential_mark="essential") +def test_summary_with_sink_mode_false(): + """ + Feature: Test summary with sink mode false, and num samples is 64. + Description: Test summary with sink mode false, and num samples is 64. + Expectation: Passed. + """ + context.set_context(mode=context.GRAPH_MODE) + summary_dir = run_network(num_samples=10, dir_suffix="test_summary_with_sink_mode_false") + + tag_list = TestSummary.list_summary_tags(summary_dir.name) + + expected_summary_tag_set = {'conv1.weight/auto', 'conv2.weight/auto', 'fc1.weight/auto', 'fc1.bias/auto', + 'fc2.weight/auto', 'input_data/auto', 'loss/auto'} + expected_op_tag_set = {'histogram', 'image', 'scalar', 'tensor'} + assert set(expected_op_tag_set | expected_summary_tag_set) == set(tag_list) + + op_tag_count, summary_tag_count = 8, 9 + for key, value in Counter(tag_list).items(): # pylint: disable=E1121 + if key in expected_op_tag_set: + assert value == op_tag_count + if key in expected_summary_tag_set: + assert value == summary_tag_count + + +@arg_mark(plat_marks=["platform_ascend"], level_mark="level0", card_mark="onecard", essential_mark="essential") +def test_summarycollector_user_defind(): + """ + Feature: Test SummaryCollector with user-defined. + Description: Test SummaryCollector with user-defined. + Expectation: Passed. + """ + context.set_context(mode=context.GRAPH_MODE) + summary_dir = run_network(dataset_sink_mode=False, num_samples=2, dir_suffix="test_summarycollector_user_defind", + custom_lineage_data={'test': 'self test'}, + export_options={'tensor_format': 'npy'}) + + tag_list = TestSummary.list_summary_tags(summary_dir.name) + file_list = TestSummary.list_tensor_files(summary_dir.name) + # There will not record input data when dataset sink mode is True + expected_tags = {'conv1.weight/auto', 'conv2.weight/auto', 'fc1.weight/auto', 'fc1.bias/auto', + 'fc2.weight/auto', 'loss/auto', 'input_data/auto', 'histogram', 'image', 'scalar', 'tensor'} + assert set(expected_tags) == set(tag_list) + expected_files = {'tensor_1.npy', 'tensor_2.npy'} + assert set(expected_files) == set(file_list) + + +@arg_mark(plat_marks=["platform_gpu"], level_mark="level0", card_mark="onecard", essential_mark="essential") +def test_summary_collector_landscape(): + """ + Feature: Summary collector with landscape. + Description: Test summary collector with landscape. + Expectation: Landscape data collected with expected value. + """ + context.set_context(mode=context.GRAPH_MODE) + set_seed(1) + interval_1 = [1, 2, 3] + num_samples = 6 + summary_dir = train_network(epoch=3, num_samples=num_samples, dir_suffix="test_summary_collector_landscape", + collect_specified_data={'collect_landscape': + {'landscape_size': 4, + 'unit': 'epoch', + 'create_landscape': {'train': True, 'result': True}, + 'num_samples': num_samples, + 'intervals': [interval_1]}}) + + tag_list = TestSummary.list_summary_collect_landscape_tags(summary_dir.name) + expected_tags = {'epoch_group', 'model_params_file_map', 'step_per_epoch', 'unit', 'num_samples', + 'landscape_size', 'create_landscape'} + assert set(expected_tags) == set(tag_list) + device_id = int(os.getenv('DEVICE_ID')) if os.getenv('DEVICE_ID') else 0 + summary_landscape = SummaryLandscape(summary_dir.name) + summary_landscape.gen_landscapes_with_multi_process(callback_fn, device_ids=[device_id]) + tag_list_landscape = TestSummary.list_landscape_tags(summary_dir.name) + assert np.allclose(tag_list_landscape[0], 2.28, atol=0.03) + assert np.allclose(tag_list_landscape[1], 2.28, atol=0.03) + + +@arg_mark(plat_marks=["platform_ascend", "platform_gpu"], level_mark="level1", card_mark="onecard", + essential_mark="essential") +def test_summary_of_more_than_one_instance(): + """ + Feature: Test the multi instances of SummaryRecord in a script. + Description: Multi instances of SummaryRecord in a script. + Expectation: Throw RuntimeError. + """ + context.set_context(mode=context.GRAPH_MODE) + with pytest.raises(RuntimeError) as errinfo: + summary_dir1 = tempfile.TemporaryDirectory(suffix="test_summary_of_more_than_one_instance") + summary_record1 = SummaryRecord(log_dir=summary_dir1.name) + summary_dir2 = tempfile.TemporaryDirectory(suffix="test_summary_of_more_than_one_instance") + _ = SummaryRecord(log_dir=summary_dir2.name) + summary_record1.close() + assert "only one instance is supported in a training process" in str(errinfo.value) diff --git a/tests/st/tensor/test_arcsinh.py b/tests/st/tensor/test_arcsinh.py index 8e5d6210f15..16ee24a0516 100644 --- a/tests/st/tensor/test_arcsinh.py +++ b/tests/st/tensor/test_arcsinh.py @@ -1,44 +1,44 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import numpy as np -import pytest -from tests.mark_utils import arg_mark -import mindspore as ms -import mindspore.nn as nn -from mindspore import Tensor - - -class Arcsinh(nn.Cell): - def construct(self, x): - return x.arcsinh() - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], - level_mark='level2', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_tensor_arcsinh(mode): - """ - Feature: tensor.arcsinh - Description: Verify the result of arcsinh - Expectation: success - """ - ms.set_context(mode=mode) - x = Tensor(np.array([-5.0, 1.5, 3.0, 100.0]), ms.float32) - net = Arcsinh() - output = net(x) - expect_output = [-2.3124382, 1.1947632, 1.8184465, 5.298342] - assert np.allclose(output.asnumpy(), expect_output) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest +from tests.mark_utils import arg_mark +import mindspore as ms +import mindspore.nn as nn +from mindspore import Tensor + + +class Arcsinh(nn.Cell): + def construct(self, x): + return x.arcsinh() + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], + level_mark='level2', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_tensor_arcsinh(mode): + """ + Feature: tensor.arcsinh + Description: Verify the result of arcsinh + Expectation: success + """ + ms.set_context(mode=mode) + x = Tensor(np.array([-5.0, 1.5, 3.0, 100.0]), ms.float32) + net = Arcsinh() + output = net(x) + expect_output = [-2.3124382, 1.1947632, 1.8184465, 5.298342] + assert np.allclose(output.asnumpy(), expect_output) diff --git a/tests/st/tensor/test_arctanh.py b/tests/st/tensor/test_arctanh.py index e2878923710..f28deeaa9f6 100644 --- a/tests/st/tensor/test_arctanh.py +++ b/tests/st/tensor/test_arctanh.py @@ -1,44 +1,44 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import numpy as np -import pytest -from tests.mark_utils import arg_mark -import mindspore as ms -import mindspore.nn as nn -from mindspore import Tensor - - -class Arctanh(nn.Cell): - def construct(self, x): - return x.arctanh() - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], - level_mark='level2', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_tensor_arctanh(mode): - """ - Feature: tensor.arctanh - Description: Verify the result of arctanh - Expectation: success - """ - ms.set_context(mode=mode) - x = Tensor(np.array([0, -0.5]), ms.float32) - net = Arctanh() - output = net(x) - expect_output = [0., -0.54930615] - assert np.allclose(output.asnumpy(), expect_output) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest +from tests.mark_utils import arg_mark +import mindspore as ms +import mindspore.nn as nn +from mindspore import Tensor + + +class Arctanh(nn.Cell): + def construct(self, x): + return x.arctanh() + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], + level_mark='level2', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_tensor_arctanh(mode): + """ + Feature: tensor.arctanh + Description: Verify the result of arctanh + Expectation: success + """ + ms.set_context(mode=mode) + x = Tensor(np.array([0, -0.5]), ms.float32) + net = Arctanh() + output = net(x) + expect_output = [0., -0.54930615] + assert np.allclose(output.asnumpy(), expect_output) diff --git a/tests/st/tensor/test_baddmm.py b/tests/st/tensor/test_baddmm.py index 7bed299b58b..a839dcb71eb 100644 --- a/tests/st/tensor/test_baddmm.py +++ b/tests/st/tensor/test_baddmm.py @@ -1,54 +1,54 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import numpy as np -import pytest -from tests.mark_utils import arg_mark -import mindspore as ms -import mindspore.nn as nn -from mindspore import Tensor - - -class Net(nn.Cell): - def construct(self, x, y, z, beta, alpha): - return x.baddbmm(y, z, beta=beta, alpha=alpha) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], - level_mark='level2', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_tensor_baddbmm(mode): - """ - Feature: tensor.baddbmm - Description: Verify the result of baddbmm - Expectation: success - """ - ms.set_context(mode=mode) - arr1 = np.arange(18).astype(np.float32).reshape((2, 3, 3)) - arr2 = np.arange(24).astype(np.float32).reshape((2, 3, 4)) - arr3 = np.arange(24).astype(np.float32).reshape((2, 4, 3)) - x = Tensor(arr1) - y = Tensor(arr2) - z = Tensor(arr3) - net = Net() - output = net(x, y, z, 2, 0.4) - expect_output = np.array([[[16.8000, 21.2000, 25.6000], - [51.6000, 62.4000, 73.2000], - [86.4000, 103.6000, 120.8000]], - [[380.4000, 404.0000, 427.6000], - [492.0000, 522.0000, 552.0000], - [603.6000, 640.0000, 676.4000]]]) - assert np.allclose(output.asnumpy(), expect_output) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest +from tests.mark_utils import arg_mark +import mindspore as ms +import mindspore.nn as nn +from mindspore import Tensor + + +class Net(nn.Cell): + def construct(self, x, y, z, beta, alpha): + return x.baddbmm(y, z, beta=beta, alpha=alpha) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], + level_mark='level2', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_tensor_baddbmm(mode): + """ + Feature: tensor.baddbmm + Description: Verify the result of baddbmm + Expectation: success + """ + ms.set_context(mode=mode) + arr1 = np.arange(18).astype(np.float32).reshape((2, 3, 3)) + arr2 = np.arange(24).astype(np.float32).reshape((2, 3, 4)) + arr3 = np.arange(24).astype(np.float32).reshape((2, 4, 3)) + x = Tensor(arr1) + y = Tensor(arr2) + z = Tensor(arr3) + net = Net() + output = net(x, y, z, 2, 0.4) + expect_output = np.array([[[16.8000, 21.2000, 25.6000], + [51.6000, 62.4000, 73.2000], + [86.4000, 103.6000, 120.8000]], + [[380.4000, 404.0000, 427.6000], + [492.0000, 522.0000, 552.0000], + [603.6000, 640.0000, 676.4000]]]) + assert np.allclose(output.asnumpy(), expect_output) diff --git a/tests/st/tensor/test_fill_diagonal.py b/tests/st/tensor/test_fill_diagonal.py index 163a845fdb2..d2c8b09165d 100644 --- a/tests/st/tensor/test_fill_diagonal.py +++ b/tests/st/tensor/test_fill_diagonal.py @@ -1,49 +1,49 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import numpy as np -import pytest -from tests.mark_utils import arg_mark -import mindspore as ms -import mindspore.nn as nn -from mindspore import Tensor - - -class Net(nn.Cell): - def construct(self, x, fill_value, wrap): - return x.fill_diagonal(fill_value, wrap) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu'], - level_mark='level2', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -@pytest.mark.parametrize('wrap', [True, False]) -def test_tensor_fill_diag(mode, wrap): - """ - Feature: tensor.fill_diagonal - Description: Verify the result of fill_diagonal - Expectation: success - """ - ms.set_context(mode=mode) - a = Tensor(np.zeros((6, 3)), ms.float32) - output = a.fill_diagonal(5.0, wrap=wrap) - expect_output = np.array([[5.0, 0.0, 0.0], [0.0, 5.0, 0.0], [0.0, 0.0, 5.0], - [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) - if wrap: - expect_output = np.array([[5.0, 0.0, 0.0], [0.0, 5.0, 0.0], [0.0, 0.0, 5.0], - [0.0, 0.0, 0.0], [5.0, 0.0, 0.0], [0.0, 5.0, 0.0]]) - - assert np.allclose(output.asnumpy(), expect_output) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest +from tests.mark_utils import arg_mark +import mindspore as ms +import mindspore.nn as nn +from mindspore import Tensor + + +class Net(nn.Cell): + def construct(self, x, fill_value, wrap): + return x.fill_diagonal(fill_value, wrap) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu'], + level_mark='level2', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +@pytest.mark.parametrize('wrap', [True, False]) +def test_tensor_fill_diag(mode, wrap): + """ + Feature: tensor.fill_diagonal + Description: Verify the result of fill_diagonal + Expectation: success + """ + ms.set_context(mode=mode) + a = Tensor(np.zeros((6, 3)), ms.float32) + output = a.fill_diagonal(5.0, wrap=wrap) + expect_output = np.array([[5.0, 0.0, 0.0], [0.0, 5.0, 0.0], [0.0, 0.0, 5.0], + [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) + if wrap: + expect_output = np.array([[5.0, 0.0, 0.0], [0.0, 5.0, 0.0], [0.0, 0.0, 5.0], + [0.0, 0.0, 0.0], [5.0, 0.0, 0.0], [0.0, 5.0, 0.0]]) + + assert np.allclose(output.asnumpy(), expect_output) diff --git a/tests/st/tensor/test_flip.py b/tests/st/tensor/test_flip.py index 8f169b80d46..b048b2bece6 100644 --- a/tests/st/tensor/test_flip.py +++ b/tests/st/tensor/test_flip.py @@ -1,49 +1,49 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import numpy as np -import pytest -from tests.mark_utils import arg_mark - -import mindspore as ms -import mindspore.nn as nn - - -class Net(nn.Cell): - def construct(self, x): - output = x.flip((0, 2)) - return output - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], - level_mark='level2', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_flip_normal(mode): - """ - Feature: tensor.flip - Description: Verify the result of flip - Expectation: success - """ - ms.set_context(mode=mode) - net = Net() - x = ms.Tensor(np.arange(8).reshape((2, 2, 2))) - out = net(x) - expect_out = np.array([[[5., 4.], - [7., 6.]], - [[1., 0.], - [3., 2.]]]) - assert np.allclose(out.asnumpy(), expect_out) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest +from tests.mark_utils import arg_mark + +import mindspore as ms +import mindspore.nn as nn + + +class Net(nn.Cell): + def construct(self, x): + output = x.flip((0, 2)) + return output + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], + level_mark='level2', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_flip_normal(mode): + """ + Feature: tensor.flip + Description: Verify the result of flip + Expectation: success + """ + ms.set_context(mode=mode) + net = Net() + x = ms.Tensor(np.arange(8).reshape((2, 2, 2))) + out = net(x) + expect_out = np.array([[[5., 4.], + [7., 6.]], + [[1., 0.], + [3., 2.]]]) + assert np.allclose(out.asnumpy(), expect_out) diff --git a/tests/st/tensor/test_fliplr.py b/tests/st/tensor/test_fliplr.py index a37b52902c9..e1459b84bad 100644 --- a/tests/st/tensor/test_fliplr.py +++ b/tests/st/tensor/test_fliplr.py @@ -1,49 +1,49 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import numpy as np -import pytest -from tests.mark_utils import arg_mark - -import mindspore as ms -import mindspore.nn as nn - - -class Net(nn.Cell): - def construct(self, x): - output = x.fliplr() - return output - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], - level_mark='level2', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_fliplr_normal(mode): - """ - Feature: tensor.fliplr - Description: Verify the result of fliplr - Expectation: success - """ - ms.set_context(mode=mode) - net = Net() - x = ms.Tensor(np.arange(8).reshape((2, 2, 2))) - out = net(x) - expect_out = np.array([[[2., 3.], - [0., 1.]], - [[6., 7.], - [4., 5.]]]) - assert np.allclose(out.asnumpy(), expect_out) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest +from tests.mark_utils import arg_mark + +import mindspore as ms +import mindspore.nn as nn + + +class Net(nn.Cell): + def construct(self, x): + output = x.fliplr() + return output + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], + level_mark='level2', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_fliplr_normal(mode): + """ + Feature: tensor.fliplr + Description: Verify the result of fliplr + Expectation: success + """ + ms.set_context(mode=mode) + net = Net() + x = ms.Tensor(np.arange(8).reshape((2, 2, 2))) + out = net(x) + expect_out = np.array([[[2., 3.], + [0., 1.]], + [[6., 7.], + [4., 5.]]]) + assert np.allclose(out.asnumpy(), expect_out) diff --git a/tests/st/tensor/test_flipud.py b/tests/st/tensor/test_flipud.py index ca86ca08956..92ee90ba53a 100644 --- a/tests/st/tensor/test_flipud.py +++ b/tests/st/tensor/test_flipud.py @@ -1,49 +1,49 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import numpy as np -import pytest -from tests.mark_utils import arg_mark - -import mindspore as ms -import mindspore.nn as nn - - -class Net(nn.Cell): - def construct(self, x): - output = x.flipud() - return output - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], - level_mark='level2', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_flipud_normal(mode): - """ - Feature: tensor.flipud - Description: Verify the result of flipud - Expectation: success - """ - ms.set_context(mode=mode) - net = Net() - x = ms.Tensor(np.arange(8).reshape((2, 2, 2))) - out = net(x) - expect_out = np.array([[[4., 5.], - [6., 7.]], - [[0., 1.], - [2., 3.]]]) - assert np.allclose(out.asnumpy(), expect_out) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest +from tests.mark_utils import arg_mark + +import mindspore as ms +import mindspore.nn as nn + + +class Net(nn.Cell): + def construct(self, x): + output = x.flipud() + return output + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], + level_mark='level2', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_flipud_normal(mode): + """ + Feature: tensor.flipud + Description: Verify the result of flipud + Expectation: success + """ + ms.set_context(mode=mode) + net = Net() + x = ms.Tensor(np.arange(8).reshape((2, 2, 2))) + out = net(x) + expect_out = np.array([[[4., 5.], + [6., 7.]], + [[0., 1.], + [2., 3.]]]) + assert np.allclose(out.asnumpy(), expect_out) diff --git a/tests/st/tensor/test_heaviside.py b/tests/st/tensor/test_heaviside.py index b7db36b3230..32481f0e146 100644 --- a/tests/st/tensor/test_heaviside.py +++ b/tests/st/tensor/test_heaviside.py @@ -1,47 +1,47 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import pytest -from tests.mark_utils import arg_mark -import numpy as np - -import mindspore as ms -import mindspore.nn as nn - - -class Net(nn.Cell): - def construct(self, x, values): - output = x.heaviside(values) - return output - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu'], - level_mark='level2', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_heaviside_normal(mode): - """ - Feature: tensor.heaviside - Description: Verify the result of heaviside - Expectation: success - """ - ms.set_context(mode=mode) - net = Net() - x = ms.Tensor([-1.5, 0., 2.], ms.float32) - values = ms.Tensor([0.5], ms.float32) - expect_output = np.array([0., 0.5, 1.], dtype=np.float32) - out = net(x, values) - assert np.allclose(out.asnumpy(), expect_output) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import pytest +from tests.mark_utils import arg_mark +import numpy as np + +import mindspore as ms +import mindspore.nn as nn + + +class Net(nn.Cell): + def construct(self, x, values): + output = x.heaviside(values) + return output + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu'], + level_mark='level2', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_heaviside_normal(mode): + """ + Feature: tensor.heaviside + Description: Verify the result of heaviside + Expectation: success + """ + ms.set_context(mode=mode) + net = Net() + x = ms.Tensor([-1.5, 0., 2.], ms.float32) + values = ms.Tensor([0.5], ms.float32) + expect_output = np.array([0., 0.5, 1.], dtype=np.float32) + out = net(x, values) + assert np.allclose(out.asnumpy(), expect_output) diff --git a/tests/st/tensor/test_hypot.py b/tests/st/tensor/test_hypot.py index a5414a6ccc9..bb2942d1c82 100644 --- a/tests/st/tensor/test_hypot.py +++ b/tests/st/tensor/test_hypot.py @@ -1,47 +1,47 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import pytest -from tests.mark_utils import arg_mark -import numpy as np - -import mindspore as ms -import mindspore.nn as nn - - -class Net(nn.Cell): - def construct(self, x, other): - output = x.hypot(other) - return output - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], - level_mark='level2', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_hypot_normal(mode): - """ - Feature: tensor.hypot - Description: Verify the result of hypot - Expectation: success - """ - ms.set_context(mode=mode) - net = Net() - x = ms.Tensor([4.], ms.float32) - other = ms.Tensor([3., 4., 5.], ms.float64) - out = net(x, other) - expect_out = np.array([5.0000, 5.6569, 6.4031], dtype=np.float64) - assert np.allclose(out.asnumpy(), expect_out) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import pytest +from tests.mark_utils import arg_mark +import numpy as np + +import mindspore as ms +import mindspore.nn as nn + + +class Net(nn.Cell): + def construct(self, x, other): + output = x.hypot(other) + return output + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], + level_mark='level2', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_hypot_normal(mode): + """ + Feature: tensor.hypot + Description: Verify the result of hypot + Expectation: success + """ + ms.set_context(mode=mode) + net = Net() + x = ms.Tensor([4.], ms.float32) + other = ms.Tensor([3., 4., 5.], ms.float64) + out = net(x, other) + expect_out = np.array([5.0000, 5.6569, 6.4031], dtype=np.float64) + assert np.allclose(out.asnumpy(), expect_out) diff --git a/tests/st/tensor/test_i0.py b/tests/st/tensor/test_i0.py index a813794286e..196398db6aa 100644 --- a/tests/st/tensor/test_i0.py +++ b/tests/st/tensor/test_i0.py @@ -1,46 +1,46 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import pytest -from tests.mark_utils import arg_mark -import numpy as np - -import mindspore as ms -import mindspore.nn as nn - - -class Net(nn.Cell): - def construct(self, x): - output = x.i0() - return output - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu'], - level_mark='level2', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_i0_normal(mode): - """ - Feature: tensor.i0 - Description: Verify the result of i0 - Expectation: success - """ - ms.set_context(mode=mode) - net = Net() - x = ms.Tensor([0, 1, 2, 3, 4], ms.float32) - expect_output = np.array([1.0000, 1.26606588, 2.2795853, 4.88079259, 11.30192195], dtype=np.float32) - out = net(x) - assert np.allclose(out.asnumpy(), expect_output) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import pytest +from tests.mark_utils import arg_mark +import numpy as np + +import mindspore as ms +import mindspore.nn as nn + + +class Net(nn.Cell): + def construct(self, x): + output = x.i0() + return output + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu'], + level_mark='level2', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_i0_normal(mode): + """ + Feature: tensor.i0 + Description: Verify the result of i0 + Expectation: success + """ + ms.set_context(mode=mode) + net = Net() + x = ms.Tensor([0, 1, 2, 3, 4], ms.float32) + expect_output = np.array([1.0000, 1.26606588, 2.2795853, 4.88079259, 11.30192195], dtype=np.float32) + out = net(x) + assert np.allclose(out.asnumpy(), expect_output) diff --git a/tests/st/tensor/test_is_floating_point.py b/tests/st/tensor/test_is_floating_point.py index 3541c4c21b9..ebf5c1988b3 100644 --- a/tests/st/tensor/test_is_floating_point.py +++ b/tests/st/tensor/test_is_floating_point.py @@ -1,47 +1,47 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import pytest -from tests.mark_utils import arg_mark - -import mindspore as ms -import mindspore.nn as nn - - -class Net(nn.Cell): - def construct(self, x): - output = x.is_floating_point() - return output - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], - level_mark='level2', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_is_floating_point_normal(mode): - """ - Feature: tensor.is_floating_point - Description: Verify the result of is_floating_point - Expectation: success - """ - ms.set_context(mode=mode) - net = Net() - x = ms.Tensor([1, 2, 3], ms.float32) - y = ms.Tensor([1, 2, 3], ms.int64) - out1 = net(x) - out2 = net(y) - assert out1 - assert not out2 +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import pytest +from tests.mark_utils import arg_mark + +import mindspore as ms +import mindspore.nn as nn + + +class Net(nn.Cell): + def construct(self, x): + output = x.is_floating_point() + return output + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], + level_mark='level2', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_is_floating_point_normal(mode): + """ + Feature: tensor.is_floating_point + Description: Verify the result of is_floating_point + Expectation: success + """ + ms.set_context(mode=mode) + net = Net() + x = ms.Tensor([1, 2, 3], ms.float32) + y = ms.Tensor([1, 2, 3], ms.int64) + out1 = net(x) + out2 = net(y) + assert out1 + assert not out2 diff --git a/tests/st/tensor/test_is_signed.py b/tests/st/tensor/test_is_signed.py index 77b0ee6e71e..96834f5da0a 100644 --- a/tests/st/tensor/test_is_signed.py +++ b/tests/st/tensor/test_is_signed.py @@ -1,47 +1,47 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import pytest -from tests.mark_utils import arg_mark - -import mindspore as ms -import mindspore.nn as nn - - -class Net(nn.Cell): - def construct(self, x): - output = x.is_signed() - return output - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], - level_mark='level2', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_is_signed_normal(mode): - """ - Feature: tensor.is_signed - Description: Verify the result of is_signed - Expectation: success - """ - ms.set_context(mode=mode) - net = Net() - x = ms.Tensor([1, 2, 3], ms.int64) - y = ms.Tensor([1, 2, 3], ms.uint64) - out1 = net(x) - out2 = net(y) - assert out1 - assert not out2 +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import pytest +from tests.mark_utils import arg_mark + +import mindspore as ms +import mindspore.nn as nn + + +class Net(nn.Cell): + def construct(self, x): + output = x.is_signed() + return output + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], + level_mark='level2', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_is_signed_normal(mode): + """ + Feature: tensor.is_signed + Description: Verify the result of is_signed + Expectation: success + """ + ms.set_context(mode=mode) + net = Net() + x = ms.Tensor([1, 2, 3], ms.int64) + y = ms.Tensor([1, 2, 3], ms.uint64) + out1 = net(x) + out2 = net(y) + assert out1 + assert not out2 diff --git a/tests/st/tensor/test_log_normal.py b/tests/st/tensor/test_log_normal.py index 39d0c0f51ed..ff85c5a2009 100644 --- a/tests/st/tensor/test_log_normal.py +++ b/tests/st/tensor/test_log_normal.py @@ -1,43 +1,43 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import numpy as np -import pytest -from tests.mark_utils import arg_mark -import mindspore as ms -import mindspore.nn as nn -from mindspore import Tensor - - -class Net(nn.Cell): - def construct(self, x, mean, std): - return x.log_normal(mean, std) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], - level_mark='level2', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_tensor_log_normal(mode): - """ - Feature: tensor.log_normal - Description: Verify the result of log_normal - Expectation: success - """ - ms.set_context(mode=mode) - a = Tensor(np.zeros((6, 3)), ms.float32) - output = a.log_normal() - assert a.dtype == output.dtype - assert a.shape == output.shape +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest +from tests.mark_utils import arg_mark +import mindspore as ms +import mindspore.nn as nn +from mindspore import Tensor + + +class Net(nn.Cell): + def construct(self, x, mean, std): + return x.log_normal(mean, std) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], + level_mark='level2', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_tensor_log_normal(mode): + """ + Feature: tensor.log_normal + Description: Verify the result of log_normal + Expectation: success + """ + ms.set_context(mode=mode) + a = Tensor(np.zeros((6, 3)), ms.float32) + output = a.log_normal() + assert a.dtype == output.dtype + assert a.shape == output.shape diff --git a/tests/st/tensor/test_mH.py b/tests/st/tensor/test_mH.py index d14efdbc596..533f6fbac63 100644 --- a/tests/st/tensor/test_mH.py +++ b/tests/st/tensor/test_mH.py @@ -1,46 +1,46 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import numpy as np -import pytest -from tests.mark_utils import arg_mark - -import mindspore as ms -import mindspore.nn as nn - - -class Net(nn.Cell): - def construct(self, x): - return x.mH - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], - level_mark='level2', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_mH_normal(mode): - """ - Feature: mH - Description: Verify the result of mH - Expectation: success - """ - ms.set_context(mode=mode) - x = ms.Tensor(np.array([[0., 1.], [2., 3.]]), ms.float32) - net = Net() - output = net(x) - expect_output = np.array([[0., 2.], - [1., 3.]]) - assert np.allclose(output.asnumpy(), expect_output) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest +from tests.mark_utils import arg_mark + +import mindspore as ms +import mindspore.nn as nn + + +class Net(nn.Cell): + def construct(self, x): + return x.mH + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], + level_mark='level2', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_mH_normal(mode): + """ + Feature: mH + Description: Verify the result of mH + Expectation: success + """ + ms.set_context(mode=mode) + x = ms.Tensor(np.array([[0., 1.], [2., 3.]]), ms.float32) + net = Net() + output = net(x) + expect_output = np.array([[0., 2.], + [1., 3.]]) + assert np.allclose(output.asnumpy(), expect_output) diff --git a/tests/st/tensor/test_mT.py b/tests/st/tensor/test_mT.py index fc3c2ff7aa3..9ea9bff2b0d 100644 --- a/tests/st/tensor/test_mT.py +++ b/tests/st/tensor/test_mT.py @@ -1,51 +1,51 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import numpy as np -import pytest -from tests.mark_utils import arg_mark - -import mindspore as ms -import mindspore.nn as nn - - -class Net(nn.Cell): - def construct(self, x): - return x.mT - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], - level_mark='level2', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_mT_normal(mode): - """ - Feature: mT - Description: Verify the result of mT - Expectation: success - """ - ms.set_context(mode=mode) - x = ms.Tensor(np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]), ms.float32) - net = Net() - output = net(x) - expect_output = np.array([[[1, 4], - [2, 5], - [3, 6]], - - [[7, 10], - [8, 11], - [9, 12]]]) - assert np.allclose(output.asnumpy(), expect_output) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest +from tests.mark_utils import arg_mark + +import mindspore as ms +import mindspore.nn as nn + + +class Net(nn.Cell): + def construct(self, x): + return x.mT + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], + level_mark='level2', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_mT_normal(mode): + """ + Feature: mT + Description: Verify the result of mT + Expectation: success + """ + ms.set_context(mode=mode) + x = ms.Tensor(np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]), ms.float32) + net = Net() + output = net(x) + expect_output = np.array([[[1, 4], + [2, 5], + [3, 6]], + + [[7, 10], + [8, 11], + [9, 12]]]) + assert np.allclose(output.asnumpy(), expect_output) diff --git a/tests/st/tensor/test_mm.py b/tests/st/tensor/test_mm.py index deb0e5774a4..fecb9ee426b 100644 --- a/tests/st/tensor/test_mm.py +++ b/tests/st/tensor/test_mm.py @@ -1,48 +1,48 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import numpy as np -import pytest -from tests.mark_utils import arg_mark - -import mindspore as ms -import mindspore.nn as nn - - -class Net(nn.Cell): - def construct(self, x1, x2): - output = x1.mm(x2) - return output - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], - level_mark='level2', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_mm_normal(mode): - """ - Feature: mm - Description: Verify the result of mm - Expectation: success - """ - ms.set_context(mode=mode) - net = Net() - x1 = ms.Tensor(np.arange(6).reshape((2, 3)), dtype=ms.float32) - x2 = ms.Tensor(np.arange(12).reshape((3, 4)), dtype=ms.float32) - out = net(x1, x2) - expect_out = np.array([[20, 23, 26, 29], - [56, 68, 80, 92]]) - assert np.allclose(out.asnumpy(), expect_out) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest +from tests.mark_utils import arg_mark + +import mindspore as ms +import mindspore.nn as nn + + +class Net(nn.Cell): + def construct(self, x1, x2): + output = x1.mm(x2) + return output + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], + level_mark='level2', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_mm_normal(mode): + """ + Feature: mm + Description: Verify the result of mm + Expectation: success + """ + ms.set_context(mode=mode) + net = Net() + x1 = ms.Tensor(np.arange(6).reshape((2, 3)), dtype=ms.float32) + x2 = ms.Tensor(np.arange(12).reshape((3, 4)), dtype=ms.float32) + out = net(x1, x2) + expect_out = np.array([[20, 23, 26, 29], + [56, 68, 80, 92]]) + assert np.allclose(out.asnumpy(), expect_out) diff --git a/tests/st/tensor/test_msort.py b/tests/st/tensor/test_msort.py index 80bcf2611ec..3f897106b01 100644 --- a/tests/st/tensor/test_msort.py +++ b/tests/st/tensor/test_msort.py @@ -1,48 +1,48 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import numpy as np -import pytest -from tests.mark_utils import arg_mark - -import mindspore as ms -import mindspore.nn as nn - - -class Net(nn.Cell): - def construct(self, x): - output = x.msort() - return output - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], - level_mark='level2', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_msort_normal(mode): - """ - Feature: msort - Description: Verify the result of msort - Expectation: success - """ - ms.set_context(mode=mode) - net = Net() - x = ms.Tensor(np.array([[8, 2, 1], [5, 9, 3], [4, 6, 7]]), ms.float16) - out = net(x) - expect_out = np.array([[4., 2., 1.], - [5., 6., 3.], - [8., 9., 7.]]) - assert np.allclose(out.asnumpy(), expect_out) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest +from tests.mark_utils import arg_mark + +import mindspore as ms +import mindspore.nn as nn + + +class Net(nn.Cell): + def construct(self, x): + output = x.msort() + return output + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], + level_mark='level2', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_msort_normal(mode): + """ + Feature: msort + Description: Verify the result of msort + Expectation: success + """ + ms.set_context(mode=mode) + net = Net() + x = ms.Tensor(np.array([[8, 2, 1], [5, 9, 3], [4, 6, 7]]), ms.float16) + out = net(x) + expect_out = np.array([[4., 2., 1.], + [5., 6., 3.], + [8., 9., 7.]]) + assert np.allclose(out.asnumpy(), expect_out) diff --git a/tests/st/tensor/test_nan_to_num.py b/tests/st/tensor/test_nan_to_num.py index 199594e9f4d..029a20f5887 100644 --- a/tests/st/tensor/test_nan_to_num.py +++ b/tests/st/tensor/test_nan_to_num.py @@ -1,46 +1,46 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import numpy as np -import pytest -from tests.mark_utils import arg_mark - -import mindspore as ms -import mindspore.nn as nn - - -class Net(nn.Cell): - def construct(self, x, nan, posinf, neginf): - output = x.nan_to_num(nan, posinf, neginf) - return output - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], - level_mark='level2', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_nan_to_num_normal(mode): - """ - Feature: nan_to_num - Description: Verify the result of nan_to_num - Expectation: success - """ - ms.set_context(mode=mode) - net = Net() - x = ms.Tensor(np.array([float('nan'), float('inf'), -float('inf'), 3.14]), ms.float32) - out = net(x, 1.0, 2.0, 3.0) - expect_out = np.array([1., 2., 3., 3.14]) - assert np.allclose(out.asnumpy(), expect_out) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest +from tests.mark_utils import arg_mark + +import mindspore as ms +import mindspore.nn as nn + + +class Net(nn.Cell): + def construct(self, x, nan, posinf, neginf): + output = x.nan_to_num(nan, posinf, neginf) + return output + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos'], + level_mark='level2', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_nan_to_num_normal(mode): + """ + Feature: nan_to_num + Description: Verify the result of nan_to_num + Expectation: success + """ + ms.set_context(mode=mode) + net = Net() + x = ms.Tensor(np.array([float('nan'), float('inf'), -float('inf'), 3.14]), ms.float32) + out = net(x, 1.0, 2.0, 3.0) + expect_out = np.array([1., 2., 3., 3.14]) + assert np.allclose(out.asnumpy(), expect_out) diff --git a/tests/st/tensor/test_real.py b/tests/st/tensor/test_real.py index 50a4c1263e2..dc6680d83eb 100644 --- a/tests/st/tensor/test_real.py +++ b/tests/st/tensor/test_real.py @@ -1,46 +1,46 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import numpy as np -import pytest -from tests.mark_utils import arg_mark - -import mindspore as ms -import mindspore.nn as nn - - -class Net(nn.Cell): - def construct(self, x): - output = x.real() - return output - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu'], - level_mark='level2', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_real_normal(mode): - """ - Feature: real - Description: Verify the result of real - Expectation: success - """ - ms.set_context(mode=mode) - net = Net() - x = ms.Tensor(np.asarray(np.complex(1.3+0.4j)), ms.complex64) - out = net(x) - expect_out = np.array(1.3) - assert np.allclose(out.asnumpy(), expect_out) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest +from tests.mark_utils import arg_mark + +import mindspore as ms +import mindspore.nn as nn + + +class Net(nn.Cell): + def construct(self, x): + output = x.real() + return output + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu'], + level_mark='level2', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_real_normal(mode): + """ + Feature: real + Description: Verify the result of real + Expectation: success + """ + ms.set_context(mode=mode) + net = Net() + x = ms.Tensor(np.asarray(np.complex(1.3+0.4j)), ms.complex64) + out = net(x) + expect_out = np.array(1.3) + assert np.allclose(out.asnumpy(), expect_out) diff --git a/tests/st/tensor/test_reciprocal.py b/tests/st/tensor/test_reciprocal.py index a08af0206de..b3f8d9efe4f 100644 --- a/tests/st/tensor/test_reciprocal.py +++ b/tests/st/tensor/test_reciprocal.py @@ -1,46 +1,46 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import numpy as np -import pytest -from tests.mark_utils import arg_mark - -import mindspore as ms -import mindspore.nn as nn - - -class Net(nn.Cell): - def construct(self, x): - output = x.reciprocal() - return output - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], - level_mark='level2', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_reciprocal_normal(mode): - """ - Feature: reciprocal - Description: Verify the result of reciprocal - Expectation: success - """ - ms.set_context(mode=mode) - net = Net() - x = ms.Tensor(np.array([1.0, 2.0, 4.0]), ms.float32) - out = net(x) - expect_out = np.array([1., 0.5, 0.25]) - assert np.allclose(out.asnumpy(), expect_out) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest +from tests.mark_utils import arg_mark + +import mindspore as ms +import mindspore.nn as nn + + +class Net(nn.Cell): + def construct(self, x): + output = x.reciprocal() + return output + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], + level_mark='level2', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_reciprocal_normal(mode): + """ + Feature: reciprocal + Description: Verify the result of reciprocal + Expectation: success + """ + ms.set_context(mode=mode) + net = Net() + x = ms.Tensor(np.array([1.0, 2.0, 4.0]), ms.float32) + out = net(x) + expect_out = np.array([1., 0.5, 0.25]) + assert np.allclose(out.asnumpy(), expect_out) diff --git a/tests/st/tensor/test_repeat_interleave.py b/tests/st/tensor/test_repeat_interleave.py index 9a828f84f93..63498a28eac 100644 --- a/tests/st/tensor/test_repeat_interleave.py +++ b/tests/st/tensor/test_repeat_interleave.py @@ -1,47 +1,47 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import numpy as np -import pytest -from tests.mark_utils import arg_mark -import mindspore as ms -import mindspore.nn as nn -from mindspore import Tensor - - -class RepeatInterleave(nn.Cell): - def construct(self, x): - return x.repeat_interleave(repeats=2, dim=0) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], - level_mark='level2', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_tensor_repeat_interleave(mode): - """ - Feature: tensor.repeat_interleave - Description: Verify the result of repeat_interleave - Expectation: success - """ - ms.set_context(mode=mode) - x = Tensor(np.array([[0, 1, 2], [3, 4, 5]]), ms.int32) - net = RepeatInterleave() - output = net(x) - expect_output = [[0, 1, 2], - [0, 1, 2], - [3, 4, 5], - [3, 4, 5]] - assert np.allclose(output.asnumpy(), expect_output) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest +from tests.mark_utils import arg_mark +import mindspore as ms +import mindspore.nn as nn +from mindspore import Tensor + + +class RepeatInterleave(nn.Cell): + def construct(self, x): + return x.repeat_interleave(repeats=2, dim=0) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], + level_mark='level2', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_tensor_repeat_interleave(mode): + """ + Feature: tensor.repeat_interleave + Description: Verify the result of repeat_interleave + Expectation: success + """ + ms.set_context(mode=mode) + x = Tensor(np.array([[0, 1, 2], [3, 4, 5]]), ms.int32) + net = RepeatInterleave() + output = net(x) + expect_output = [[0, 1, 2], + [0, 1, 2], + [3, 4, 5], + [3, 4, 5]] + assert np.allclose(output.asnumpy(), expect_output) diff --git a/tests/st/tensor/test_reshape_as.py b/tests/st/tensor/test_reshape_as.py index a772805d499..41a7f35db1f 100644 --- a/tests/st/tensor/test_reshape_as.py +++ b/tests/st/tensor/test_reshape_as.py @@ -1,47 +1,47 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import numpy as np -import pytest -from tests.mark_utils import arg_mark -import mindspore as ms -import mindspore.nn as nn -from mindspore import Tensor - - -class ReshapeAs(nn.Cell): - def construct(self, x, y): - return x.reshape_as(y) - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], - level_mark='level2', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_tensor_reshape_as(mode): - """ - Feature: tensor.reshape_as - Description: Verify the result of output - Expectation: success - """ - ms.set_context(mode=mode) - x = Tensor([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]], dtype=ms.float32) - y = Tensor(np.arange(6).reshape(3, 2)) - net = ReshapeAs() - output = net(x, y) - expect_output = np.array([[-0.1, 0.3], - [3.6, 0.4], - [0.5, -3.2]]) - assert np.allclose(output.asnumpy(), expect_output) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest +from tests.mark_utils import arg_mark +import mindspore as ms +import mindspore.nn as nn +from mindspore import Tensor + + +class ReshapeAs(nn.Cell): + def construct(self, x, y): + return x.reshape_as(y) + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], + level_mark='level2', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_tensor_reshape_as(mode): + """ + Feature: tensor.reshape_as + Description: Verify the result of output + Expectation: success + """ + ms.set_context(mode=mode) + x = Tensor([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]], dtype=ms.float32) + y = Tensor(np.arange(6).reshape(3, 2)) + net = ReshapeAs() + output = net(x, y) + expect_output = np.array([[-0.1, 0.3], + [3.6, 0.4], + [0.5, -3.2]]) + assert np.allclose(output.asnumpy(), expect_output) diff --git a/tests/st/tensor/test_roll.py b/tests/st/tensor/test_roll.py index 666c9274f38..b2abf926a3a 100644 --- a/tests/st/tensor/test_roll.py +++ b/tests/st/tensor/test_roll.py @@ -1,41 +1,41 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import numpy as np -import pytest -from tests.mark_utils import arg_mark -import mindspore as ms -import mindspore.nn as nn -from mindspore import Tensor - - -class Roll(nn.Cell): - def construct(self, x): - return x.roll(shifts=2, dims=0) - - -@arg_mark(plat_marks=['platform_ascend'], level_mark='level1', card_mark='onecard', essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_tensor_roll(mode): - """ - Feature: tensor.roll - Description: Verify the result of roll - Expectation: success - """ - ms.set_context(mode=mode) - input_x = Tensor(np.array([0, 1, 2, 3, 4]).astype(np.float32)) - net = Roll() - output = net(input_x) - expect_output = [3., 4., 0., 1., 2.] - assert np.allclose(output.asnumpy(), expect_output) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest +from tests.mark_utils import arg_mark +import mindspore as ms +import mindspore.nn as nn +from mindspore import Tensor + + +class Roll(nn.Cell): + def construct(self, x): + return x.roll(shifts=2, dims=0) + + +@arg_mark(plat_marks=['platform_ascend'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_tensor_roll(mode): + """ + Feature: tensor.roll + Description: Verify the result of roll + Expectation: success + """ + ms.set_context(mode=mode) + input_x = Tensor(np.array([0, 1, 2, 3, 4]).astype(np.float32)) + net = Roll() + output = net(input_x) + expect_output = [3., 4., 0., 1., 2.] + assert np.allclose(output.asnumpy(), expect_output) diff --git a/tests/st/tensor/test_rot90.py b/tests/st/tensor/test_rot90.py index 678b496d385..cb0a324f67a 100644 --- a/tests/st/tensor/test_rot90.py +++ b/tests/st/tensor/test_rot90.py @@ -1,45 +1,45 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import numpy as np -import pytest -from tests.mark_utils import arg_mark -import mindspore as ms -import mindspore.nn as nn -from mindspore import Tensor - - -class Rot90(nn.Cell): - def construct(self, x): - return x.rot90(k=1, dims=[0, 1]) - - -@arg_mark(plat_marks=['platform_gpu', 'platform_ascend'], - level_mark='level2', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_tensor_rot90(mode): - """ - Feature: tensor.rot90 - Description: Verify the result of rot90 - Expectation: success - """ - ms.set_context(mode=mode) - x = Tensor(np.array([[0, 1], [2, 3]]), dtype=ms.float32) - net = Rot90() - output = net(x) - expect_output = np.array([[1., 3.], - [0., 2.]]) - assert np.allclose(output.asnumpy(), expect_output) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest +from tests.mark_utils import arg_mark +import mindspore as ms +import mindspore.nn as nn +from mindspore import Tensor + + +class Rot90(nn.Cell): + def construct(self, x): + return x.rot90(k=1, dims=[0, 1]) + + +@arg_mark(plat_marks=['platform_gpu', 'platform_ascend'], + level_mark='level2', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_tensor_rot90(mode): + """ + Feature: tensor.rot90 + Description: Verify the result of rot90 + Expectation: success + """ + ms.set_context(mode=mode) + x = Tensor(np.array([[0, 1], [2, 3]]), dtype=ms.float32) + net = Rot90() + output = net(x) + expect_output = np.array([[1., 3.], + [0., 2.]]) + assert np.allclose(output.asnumpy(), expect_output) diff --git a/tests/st/tensor/test_short.py b/tests/st/tensor/test_short.py index d25ad40fa2b..291f2d4aedc 100644 --- a/tests/st/tensor/test_short.py +++ b/tests/st/tensor/test_short.py @@ -1,43 +1,43 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import numpy as np -import pytest -from tests.mark_utils import arg_mark -import mindspore as ms -import mindspore.nn as nn -from mindspore import Tensor - - -class Short(nn.Cell): - def construct(self, x): - return x.short() - - -@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], - level_mark='level2', - card_mark='onecard', - essential_mark='unessential') -@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -def test_tensor_short(mode): - """ - Feature: tensor.short - Description: Verify the type of output - Expectation: success - """ - ms.set_context(mode=mode) - x = Tensor(np.array([1, 2, 3, 4, 5]), ms.int32) - net = Short() - output = net(x) - assert output.dtype == ms.int16 +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest +from tests.mark_utils import arg_mark +import mindspore as ms +import mindspore.nn as nn +from mindspore import Tensor + + +class Short(nn.Cell): + def construct(self, x): + return x.short() + + +@arg_mark(plat_marks=['cpu_linux', 'cpu_windows', 'cpu_macos', 'platform_gpu', 'platform_ascend'], + level_mark='level2', + card_mark='onecard', + essential_mark='unessential') +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_tensor_short(mode): + """ + Feature: tensor.short + Description: Verify the type of output + Expectation: success + """ + ms.set_context(mode=mode) + x = Tensor(np.array([1, 2, 3, 4, 5]), ms.int32) + net = Short() + output = net(x) + assert output.dtype == ms.int16 diff --git a/tests/ut/cpp/dataset/c_api_dataset_ag_news_test.cc b/tests/ut/cpp/dataset/c_api_dataset_ag_news_test.cc index 24cac9ba8bf..db92aa1a9b3 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_ag_news_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_ag_news_test.cc @@ -1,486 +1,486 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "common/common.h" -#include "minddata/dataset/core/global_context.h" -#include "minddata/dataset/include/dataset/datasets.h" - -#include "minddata/dataset/engine/ir/datasetops/source/ag_news_node.h" - -using namespace mindspore::dataset; - -class MindDataTestPipeline : public UT::DatasetOpTesting { -protected: -}; - -/// Feature: AGNewsDataset -/// Description: Basic test for AGNewsDataset -/// Expectation: The data is processed successfully -TEST_F(MindDataTestPipeline, DISABLED_TestAGNewsDatasetBasic) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAGNewsDatasetBasic."; - - std::string dataset_dir = datasets_root_path_ + "/testAGNews"; - std::vector column_names = {"index", "title", "description"}; - std::shared_ptr ds = - AGNews(dataset_dir, "test", 0, ShuffleMode::kFalse); - EXPECT_NE(ds, nullptr); - // Create an iterator over the result of the above dataset. - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - // Iterate the dataset and get each row. - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - EXPECT_NE(row.find("index"), row.end()); - std::vector> expected_result = { - {"3", "Background of the selection", - "In this day and age, the internet is growing rapidly, " - "the total number of connected devices is increasing and " - "we are entering the era of big data."}, - {"4", "Related technologies", - "\"Leaflet is the leading open source JavaScript library " - "for mobile-friendly interactive maps.\""}, - }; - - uint64_t i = 0; - while (row.size() != 0) { - for (int j = 0; j < column_names.size(); j++) { - auto text = row[column_names[j]]; - std::shared_ptr de_text; - ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); - std::string_view sv; - ASSERT_OK(de_text->GetItemAt(&sv, {})); - std::string ss(sv); - EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); - } - ASSERT_OK(iter->GetNextRow(&row)); - i++; - } - // Expect 2 samples. - EXPECT_EQ(i, 2); - // Manually terminate the pipeline. - iter->Stop(); -} - -/// Feature: AGNewsDataset -/// Description: Test AGNewsDataset in pipeline mode -/// Expectation: Output is equal to the expected output -TEST_F(MindDataTestPipeline, DISABLED_TestAGNewsGetters) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAGNewsGetters."; - - std::string dataset_dir = datasets_root_path_ + "/testAGNews"; - std::shared_ptr ds = - AGNews(dataset_dir, "test", 0, ShuffleMode::kFalse); - std::vector column_names = {"index", "title", "description"}; - EXPECT_NE(ds, nullptr); - std::vector types = ToDETypes(ds->GetOutputTypes()); - std::vector shapes = ToTensorShapeVec(ds->GetOutputShapes()); - EXPECT_EQ(types.size(), 3); - EXPECT_EQ(types[0].ToString(), "string"); - EXPECT_EQ(types[1].ToString(), "string"); - EXPECT_EQ(types[2].ToString(), "string"); - EXPECT_EQ(shapes.size(), 3); - EXPECT_EQ(shapes[0].ToString(), "<>"); - EXPECT_EQ(shapes[1].ToString(), "<>"); - EXPECT_EQ(shapes[2].ToString(), "<>"); - EXPECT_EQ(ds->GetColumnNames(), column_names); - EXPECT_EQ(ds->GetDatasetSize(), 2); - EXPECT_EQ(ds->GetColumnNames(), column_names); -} - -/// Feature: AGNewsDataset -/// Description: Test AGNewsDataset with invalid inputs -/// Expectation: Correct error and message are thrown -TEST_F(MindDataTestPipeline, DISABLED_TestAGNewsDatasetFail) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAGNewsDatasetFail."; - - std::string dataset_dir = datasets_root_path_ + "/testAGNews"; - std::string invalid_csv_file = "./NotExistFile"; - std::vector column_names = {"index", "title", "description"}; - std::shared_ptr ds0 = AGNews("", "test", 0); - EXPECT_NE(ds0, nullptr); - // Create an iterator over the result of the above dataset. - std::shared_ptr iter0 = ds0->CreateIterator(); - // Expect failure: invalid AGNews input. - EXPECT_EQ(iter0, nullptr); - // Create a AGNews Dataset with invalid usage. - std::shared_ptr ds1 = AGNews(invalid_csv_file); - EXPECT_NE(ds1, nullptr); - // Create an iterator over the result of the above dataset. - std::shared_ptr iter1 = ds1->CreateIterator(); - // Expect failure: invalid AGNews input. - EXPECT_EQ(iter1, nullptr); - // Test invalid num_samples < -1. - std::shared_ptr ds2 = - AGNews(dataset_dir, "test", -1, ShuffleMode::kFalse); - EXPECT_NE(ds2, nullptr); - // Create an iterator over the result of the above dataset. - std::shared_ptr iter2 = ds2->CreateIterator(); - // Expect failure: invalid AGNews input. - EXPECT_EQ(iter2, nullptr); - // Test invalid num_shards < 1. - std::shared_ptr ds3 = - AGNews(dataset_dir, "test", 0, ShuffleMode::kFalse, 0); - EXPECT_NE(ds3, nullptr); - // Create an iterator over the result of the above dataset. - std::shared_ptr iter3 = ds3->CreateIterator(); - // Expect failure: invalid AGNews input. - EXPECT_EQ(iter3, nullptr); - // Test invalid shard_id >= num_shards. - std::shared_ptr ds4 = - AGNews(dataset_dir, "test", 0, ShuffleMode::kFalse, 2, 2); - EXPECT_NE(ds4, nullptr); - // Create an iterator over the result of the above dataset. - std::shared_ptr iter4 = ds4->CreateIterator(); - // Expect failure: invalid AGNews input. - EXPECT_EQ(iter4, nullptr); -} - -/// Feature: AGNewsDataset -/// Description: Test AGNewsDataset with valid num_samples -/// Expectation: The data is processed successfully -TEST_F(MindDataTestPipeline, DISABLED_TestAGNewsDatasetNumSamples) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAGNewsDatasetNumSamples."; - - // Create a AGNewsDataset, with single CSV file. - std::string dataset_dir = datasets_root_path_ + "/testAGNews"; - std::shared_ptr ds = - AGNews(dataset_dir, "test", 2, ShuffleMode::kFalse); - std::vector column_names = {"index", "title", "description"}; - EXPECT_NE(ds, nullptr); - // Create an iterator over the result of the above dataset. - // This will trigger the creation of the Execution Tree and launch it.. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - // Iterate the dataset and get each row. - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - EXPECT_NE(row.find("index"), row.end()); - std::vector> expected_result = { - {"3", "Background of the selection", - "In this day and age, the internet is growing rapidly, " - "the total number of connected devices is increasing and " - "we are entering the era of big data."}, - {"4", "Related technologies", - "\"Leaflet is the leading open source JavaScript library " - "for mobile-friendly interactive maps.\""}, - }; - uint64_t i = 0; - while (row.size() != 0) { - for (int j = 0; j < column_names.size(); j++) { - auto text = row[column_names[j]]; - std::shared_ptr de_text; - ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); - std::string_view sv; - ASSERT_OK(de_text->GetItemAt(&sv, {})); - std::string ss(sv); - EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); - } - ASSERT_OK(iter->GetNextRow(&row)); - i++; - } - // Expect 2 samples. - EXPECT_EQ(i, 2); - // Manually terminate the pipeline. - iter->Stop(); -} - -/// Feature: AGNewsDataset -/// Description: Test distributed AGNewsDataset (with num_shards and shard_id) -/// Expectation: The data is processed successfully -TEST_F(MindDataTestPipeline, DISABLED_TestAGNewsDatasetDistribution) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAGNewsDatasetDistribution."; - - // Create a AGNewsDataset, with single CSV file. - std::string dataset_dir = datasets_root_path_ + "/testAGNews"; - std::shared_ptr ds = - AGNews(dataset_dir, "test", 0, ShuffleMode::kFalse, 2, 0); - std::vector column_names = {"index", "title", "description"}; - EXPECT_NE(ds, nullptr); - // Create an iterator over the result of the above dataset. - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - // Iterate the dataset and get each row. - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - EXPECT_NE(row.find("index"), row.end()); - std::vector> expected_result = { - {"3", "Background of the selection", - "In this day and age, the internet is growing rapidly, " - "the total number of connected devices is increasing and " - "we are entering the era of big data."}, - {"4", "Related technologies", - "\"Leaflet is the leading open source JavaScript library " - "for mobile-friendly interactive maps.\""}, - }; - uint64_t i = 0; - while (row.size() != 0) { - for (int j = 0; j < column_names.size(); j++) { - auto text = row[column_names[j]]; - std::shared_ptr de_text; - ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); - std::string_view sv; - ASSERT_OK(de_text->GetItemAt(&sv, {})); - std::string ss(sv); - EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); - } - ASSERT_OK(iter->GetNextRow(&row)); - i++; - } - // Expect 1 samples. - EXPECT_EQ(i, 1); - // Manually terminate the pipeline. - iter->Stop(); -} - -/// Feature: AGNewsDataset -/// Description: Test AGNewsDataset with all as usage -/// Expectation: The data is processed successfully -TEST_F(MindDataTestPipeline, DISABLED_TestAGNewsDatasetMultiFiles) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAGNewsDatasetMultiFiles."; - - // Create a AGNewsDataset, with single CSV file. - std::string dataset_dir = datasets_root_path_ + "/testAGNews"; - std::shared_ptr ds = - AGNews(dataset_dir, "all", 0, ShuffleMode::kFalse); - std::vector column_names = {"index", "title", "description"}; - EXPECT_NE(ds, nullptr); - // Create an iterator over the result of the above dataset. - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - // Iterate the dataset and get each row. - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - EXPECT_NE(row.find("index"), row.end()); - std::vector> expected_result = { - {"3", "Background of the selection", - "In this day and age, the internet is growing rapidly, " - "the total number of connected devices is increasing and " - "we are entering the era of big data."}, - {"3", "Demand analysis", - "\"Users simply click on the module they want to view to " - "browse information about that module.\""}, - {"4", "Related technologies", - "\"Leaflet is the leading open source JavaScript library " - "for mobile-friendly interactive maps.\""}, - {"3", "UML Timing Diagram", - "Information is mainly displayed using locally stored data and mapping, " - "which is not timely and does not have the ability to update itself."}, - {"3", "In summary", - "This paper implements a map visualization system for Hangzhou city " - "information, using extensive knowledge of visualization techniques."}, - }; - uint64_t i = 0; - while (row.size() != 0) { - for (int j = 0; j < column_names.size(); j++) { - auto text = row[column_names[j]]; - std::shared_ptr de_text; - ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); - std::string_view sv; - ASSERT_OK(de_text->GetItemAt(&sv, {})); - std::string ss(sv); - EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); - } - ASSERT_OK(iter->GetNextRow(&row)); - i++; - } - // Expect 5 samples. - EXPECT_EQ(i, 5); - // Manually terminate the pipeline. - iter->Stop(); -} - -/// Feature: AGNewsDataset -/// Description: Test AGNewsDataset header -/// Expectation: The data is processed successfully -TEST_F(MindDataTestPipeline, DISABLED_TestAGNewsDatasetHeader) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAGNewsDatasetHeader."; - - // Create a AGNewsDataset, with single CSV file. - std::string dataset_dir = datasets_root_path_ + "/testAGNews"; - std::shared_ptr ds = - AGNews(dataset_dir, "test", 0, ShuffleMode::kFalse); - std::vector column_names = {"index", "title", "description"}; - EXPECT_NE(ds, nullptr); - // Create an iterator over the result of the above dataset. - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - // Iterate the dataset and get each row. - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - EXPECT_NE(row.find("index"), row.end()); - std::vector> expected_result = { - {"3", "Background of the selection", - "In this day and age, the internet is growing rapidly, " - "the total number of connected devices is increasing and " - "we are entering the era of big data."}, - {"4", "Related technologies", - "\"Leaflet is the leading open source JavaScript library " - "for mobile-friendly interactive maps.\""}, - }; - uint64_t i = 0; - while (row.size() != 0) { - for (int j = 0; j < column_names.size(); j++) { - auto text = row[column_names[j]]; - std::shared_ptr de_text; - ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); - std::string_view sv; - ASSERT_OK(de_text->GetItemAt(&sv, {})); - std::string ss(sv); - EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); - } - ASSERT_OK(iter->GetNextRow(&row)); - i++; - } - // Expect 2 samples. - EXPECT_EQ(i, 2); - // Manually terminate the pipeline. - iter->Stop(); -} - -/// Feature: AGNewsDataset -/// Description: Test AGNewsDataset using ShuffleMode::kFiles -/// Expectation: The data is processed successfully -TEST_F(MindDataTestPipeline, DISABLED_TestAGNewsDatasetShuffleFilesA) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAGNewsDatasetShuffleFilesA."; - - // Set configuration. - uint32_t original_seed = GlobalContext::config_manager()->seed(); - uint32_t original_num_parallel_workers = - GlobalContext::config_manager()->num_parallel_workers(); - MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed - << ", num_parallel_workers: " << original_num_parallel_workers; - GlobalContext::config_manager()->set_seed(130); - GlobalContext::config_manager()->set_num_parallel_workers(4); - std::string dataset_dir = datasets_root_path_ + "/testAGNews"; - std::shared_ptr ds = - AGNews(dataset_dir, "all", 0, ShuffleMode::kFiles); - std::vector column_names = {"index", "title", "description"}; - EXPECT_NE(ds, nullptr); - // Create an iterator over the result of the above dataset. - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - // Iterate the dataset and get each row. - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - EXPECT_NE(row.find("index"), row.end()); - std::vector> expected_result = { - {"3", "Demand analysis", - "\"Users simply click on the module they want to view to " - "browse information about that module.\""}, - {"3", "Background of the selection", - "In this day and age, the internet is growing rapidly, " - "the total number of connected devices is increasing and " - "we are entering the era of big data."}, - {"3", "UML Timing Diagram", - "Information is mainly displayed using locally stored data and mapping, " - "which is not timely and does not have the ability to update itself."}, - {"4", "Related technologies", - "\"Leaflet is the leading open source JavaScript library " - "for mobile-friendly interactive maps.\""}, - {"3", "In summary", - "This paper implements a map visualization system for Hangzhou city " - "information, using extensive knowledge of visualization techniques."}, - }; - uint64_t i = 0; - while (row.size() != 0) { - for (int j = 0; j < column_names.size(); j++) { - auto text = row[column_names[j]]; - std::shared_ptr de_text; - ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); - std::string_view sv; - ASSERT_OK(de_text->GetItemAt(&sv, {})); - std::string ss(sv); - EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); - } - ASSERT_OK(iter->GetNextRow(&row)); - i++; - } - // Expect 5 samples. - EXPECT_EQ(i, 5); - // Manually terminate the pipeline. - iter->Stop(); - // Restore configuration. - GlobalContext::config_manager()->set_seed(original_seed); - GlobalContext::config_manager()->set_num_parallel_workers( - original_num_parallel_workers); -} - -/// Feature: AGNewsDataset -/// Description: Test AGNewsDataset using ShuffleMode::kGlobal -/// Expectation: The data is processed successfully -TEST_F(MindDataTestPipeline, DISABLED_TestAGNewsDatasetShuffleGlobal) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAGNewsDatasetShuffleGlobal."; - // Test AGNews Dataset with GLOBLE shuffle. - uint32_t original_seed = GlobalContext::config_manager()->seed(); - uint32_t original_num_parallel_workers = - GlobalContext::config_manager()->num_parallel_workers(); - MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed - << ", num_parallel_workers: " << original_num_parallel_workers; - GlobalContext::config_manager()->set_seed(135); - GlobalContext::config_manager()->set_num_parallel_workers(4); - - std::string dataset_dir = datasets_root_path_ + "/testAGNews"; - std::shared_ptr ds = - AGNews(dataset_dir, "train", 0, ShuffleMode::kGlobal); - std::vector column_names = {"index", "title", "description"}; - EXPECT_NE(ds, nullptr); - // Create an iterator over the result of the above dataset. - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - // Iterate the dataset and get each row. - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - EXPECT_NE(row.find("index"), row.end()); - std::vector> expected_result = { - {"3", "UML Timing Diagram", - "Information is mainly displayed using locally stored data and mapping, " - "which is not timely and does not have the ability to update itself."}, - {"3", "In summary", - "This paper implements a map visualization system for Hangzhou city " - "information, using extensive knowledge of visualization techniques."}, - {"3", "Demand analysis", - "\"Users simply click on the module they want to view to " - "browse information about that module.\""}, - }; - - uint64_t i = 0; - while (row.size() != 0) { - for (int j = 0; j < column_names.size(); j++) { - auto text = row[column_names[j]]; - std::shared_ptr de_text; - ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); - std::string_view sv; - ASSERT_OK(de_text->GetItemAt(&sv, {})); - std::string ss(sv); - EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); - } - ASSERT_OK(iter->GetNextRow(&row)); - i++; - } - // Expect 3 samples. - EXPECT_EQ(i, 3); - // Manually terminate the pipeline. - iter->Stop(); - // Restore configuration. - GlobalContext::config_manager()->set_seed(original_seed); - GlobalContext::config_manager()->set_num_parallel_workers( - original_num_parallel_workers); -} +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/common.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/include/dataset/datasets.h" + +#include "minddata/dataset/engine/ir/datasetops/source/ag_news_node.h" + +using namespace mindspore::dataset; + +class MindDataTestPipeline : public UT::DatasetOpTesting { +protected: +}; + +/// Feature: AGNewsDataset +/// Description: Basic test for AGNewsDataset +/// Expectation: The data is processed successfully +TEST_F(MindDataTestPipeline, DISABLED_TestAGNewsDatasetBasic) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAGNewsDatasetBasic."; + + std::string dataset_dir = datasets_root_path_ + "/testAGNews"; + std::vector column_names = {"index", "title", "description"}; + std::shared_ptr ds = + AGNews(dataset_dir, "test", 0, ShuffleMode::kFalse); + EXPECT_NE(ds, nullptr); + // Create an iterator over the result of the above dataset. + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + // Iterate the dataset and get each row. + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + EXPECT_NE(row.find("index"), row.end()); + std::vector> expected_result = { + {"3", "Background of the selection", + "In this day and age, the internet is growing rapidly, " + "the total number of connected devices is increasing and " + "we are entering the era of big data."}, + {"4", "Related technologies", + "\"Leaflet is the leading open source JavaScript library " + "for mobile-friendly interactive maps.\""}, + }; + + uint64_t i = 0; + while (row.size() != 0) { + for (int j = 0; j < column_names.size(); j++) { + auto text = row[column_names[j]]; + std::shared_ptr de_text; + ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); + std::string_view sv; + ASSERT_OK(de_text->GetItemAt(&sv, {})); + std::string ss(sv); + EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); + } + ASSERT_OK(iter->GetNextRow(&row)); + i++; + } + // Expect 2 samples. + EXPECT_EQ(i, 2); + // Manually terminate the pipeline. + iter->Stop(); +} + +/// Feature: AGNewsDataset +/// Description: Test AGNewsDataset in pipeline mode +/// Expectation: Output is equal to the expected output +TEST_F(MindDataTestPipeline, DISABLED_TestAGNewsGetters) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAGNewsGetters."; + + std::string dataset_dir = datasets_root_path_ + "/testAGNews"; + std::shared_ptr ds = + AGNews(dataset_dir, "test", 0, ShuffleMode::kFalse); + std::vector column_names = {"index", "title", "description"}; + EXPECT_NE(ds, nullptr); + std::vector types = ToDETypes(ds->GetOutputTypes()); + std::vector shapes = ToTensorShapeVec(ds->GetOutputShapes()); + EXPECT_EQ(types.size(), 3); + EXPECT_EQ(types[0].ToString(), "string"); + EXPECT_EQ(types[1].ToString(), "string"); + EXPECT_EQ(types[2].ToString(), "string"); + EXPECT_EQ(shapes.size(), 3); + EXPECT_EQ(shapes[0].ToString(), "<>"); + EXPECT_EQ(shapes[1].ToString(), "<>"); + EXPECT_EQ(shapes[2].ToString(), "<>"); + EXPECT_EQ(ds->GetColumnNames(), column_names); + EXPECT_EQ(ds->GetDatasetSize(), 2); + EXPECT_EQ(ds->GetColumnNames(), column_names); +} + +/// Feature: AGNewsDataset +/// Description: Test AGNewsDataset with invalid inputs +/// Expectation: Correct error and message are thrown +TEST_F(MindDataTestPipeline, DISABLED_TestAGNewsDatasetFail) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAGNewsDatasetFail."; + + std::string dataset_dir = datasets_root_path_ + "/testAGNews"; + std::string invalid_csv_file = "./NotExistFile"; + std::vector column_names = {"index", "title", "description"}; + std::shared_ptr ds0 = AGNews("", "test", 0); + EXPECT_NE(ds0, nullptr); + // Create an iterator over the result of the above dataset. + std::shared_ptr iter0 = ds0->CreateIterator(); + // Expect failure: invalid AGNews input. + EXPECT_EQ(iter0, nullptr); + // Create a AGNews Dataset with invalid usage. + std::shared_ptr ds1 = AGNews(invalid_csv_file); + EXPECT_NE(ds1, nullptr); + // Create an iterator over the result of the above dataset. + std::shared_ptr iter1 = ds1->CreateIterator(); + // Expect failure: invalid AGNews input. + EXPECT_EQ(iter1, nullptr); + // Test invalid num_samples < -1. + std::shared_ptr ds2 = + AGNews(dataset_dir, "test", -1, ShuffleMode::kFalse); + EXPECT_NE(ds2, nullptr); + // Create an iterator over the result of the above dataset. + std::shared_ptr iter2 = ds2->CreateIterator(); + // Expect failure: invalid AGNews input. + EXPECT_EQ(iter2, nullptr); + // Test invalid num_shards < 1. + std::shared_ptr ds3 = + AGNews(dataset_dir, "test", 0, ShuffleMode::kFalse, 0); + EXPECT_NE(ds3, nullptr); + // Create an iterator over the result of the above dataset. + std::shared_ptr iter3 = ds3->CreateIterator(); + // Expect failure: invalid AGNews input. + EXPECT_EQ(iter3, nullptr); + // Test invalid shard_id >= num_shards. + std::shared_ptr ds4 = + AGNews(dataset_dir, "test", 0, ShuffleMode::kFalse, 2, 2); + EXPECT_NE(ds4, nullptr); + // Create an iterator over the result of the above dataset. + std::shared_ptr iter4 = ds4->CreateIterator(); + // Expect failure: invalid AGNews input. + EXPECT_EQ(iter4, nullptr); +} + +/// Feature: AGNewsDataset +/// Description: Test AGNewsDataset with valid num_samples +/// Expectation: The data is processed successfully +TEST_F(MindDataTestPipeline, DISABLED_TestAGNewsDatasetNumSamples) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAGNewsDatasetNumSamples."; + + // Create a AGNewsDataset, with single CSV file. + std::string dataset_dir = datasets_root_path_ + "/testAGNews"; + std::shared_ptr ds = + AGNews(dataset_dir, "test", 2, ShuffleMode::kFalse); + std::vector column_names = {"index", "title", "description"}; + EXPECT_NE(ds, nullptr); + // Create an iterator over the result of the above dataset. + // This will trigger the creation of the Execution Tree and launch it.. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + // Iterate the dataset and get each row. + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + EXPECT_NE(row.find("index"), row.end()); + std::vector> expected_result = { + {"3", "Background of the selection", + "In this day and age, the internet is growing rapidly, " + "the total number of connected devices is increasing and " + "we are entering the era of big data."}, + {"4", "Related technologies", + "\"Leaflet is the leading open source JavaScript library " + "for mobile-friendly interactive maps.\""}, + }; + uint64_t i = 0; + while (row.size() != 0) { + for (int j = 0; j < column_names.size(); j++) { + auto text = row[column_names[j]]; + std::shared_ptr de_text; + ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); + std::string_view sv; + ASSERT_OK(de_text->GetItemAt(&sv, {})); + std::string ss(sv); + EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); + } + ASSERT_OK(iter->GetNextRow(&row)); + i++; + } + // Expect 2 samples. + EXPECT_EQ(i, 2); + // Manually terminate the pipeline. + iter->Stop(); +} + +/// Feature: AGNewsDataset +/// Description: Test distributed AGNewsDataset (with num_shards and shard_id) +/// Expectation: The data is processed successfully +TEST_F(MindDataTestPipeline, DISABLED_TestAGNewsDatasetDistribution) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAGNewsDatasetDistribution."; + + // Create a AGNewsDataset, with single CSV file. + std::string dataset_dir = datasets_root_path_ + "/testAGNews"; + std::shared_ptr ds = + AGNews(dataset_dir, "test", 0, ShuffleMode::kFalse, 2, 0); + std::vector column_names = {"index", "title", "description"}; + EXPECT_NE(ds, nullptr); + // Create an iterator over the result of the above dataset. + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + // Iterate the dataset and get each row. + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + EXPECT_NE(row.find("index"), row.end()); + std::vector> expected_result = { + {"3", "Background of the selection", + "In this day and age, the internet is growing rapidly, " + "the total number of connected devices is increasing and " + "we are entering the era of big data."}, + {"4", "Related technologies", + "\"Leaflet is the leading open source JavaScript library " + "for mobile-friendly interactive maps.\""}, + }; + uint64_t i = 0; + while (row.size() != 0) { + for (int j = 0; j < column_names.size(); j++) { + auto text = row[column_names[j]]; + std::shared_ptr de_text; + ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); + std::string_view sv; + ASSERT_OK(de_text->GetItemAt(&sv, {})); + std::string ss(sv); + EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); + } + ASSERT_OK(iter->GetNextRow(&row)); + i++; + } + // Expect 1 samples. + EXPECT_EQ(i, 1); + // Manually terminate the pipeline. + iter->Stop(); +} + +/// Feature: AGNewsDataset +/// Description: Test AGNewsDataset with all as usage +/// Expectation: The data is processed successfully +TEST_F(MindDataTestPipeline, DISABLED_TestAGNewsDatasetMultiFiles) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAGNewsDatasetMultiFiles."; + + // Create a AGNewsDataset, with single CSV file. + std::string dataset_dir = datasets_root_path_ + "/testAGNews"; + std::shared_ptr ds = + AGNews(dataset_dir, "all", 0, ShuffleMode::kFalse); + std::vector column_names = {"index", "title", "description"}; + EXPECT_NE(ds, nullptr); + // Create an iterator over the result of the above dataset. + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + // Iterate the dataset and get each row. + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + EXPECT_NE(row.find("index"), row.end()); + std::vector> expected_result = { + {"3", "Background of the selection", + "In this day and age, the internet is growing rapidly, " + "the total number of connected devices is increasing and " + "we are entering the era of big data."}, + {"3", "Demand analysis", + "\"Users simply click on the module they want to view to " + "browse information about that module.\""}, + {"4", "Related technologies", + "\"Leaflet is the leading open source JavaScript library " + "for mobile-friendly interactive maps.\""}, + {"3", "UML Timing Diagram", + "Information is mainly displayed using locally stored data and mapping, " + "which is not timely and does not have the ability to update itself."}, + {"3", "In summary", + "This paper implements a map visualization system for Hangzhou city " + "information, using extensive knowledge of visualization techniques."}, + }; + uint64_t i = 0; + while (row.size() != 0) { + for (int j = 0; j < column_names.size(); j++) { + auto text = row[column_names[j]]; + std::shared_ptr de_text; + ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); + std::string_view sv; + ASSERT_OK(de_text->GetItemAt(&sv, {})); + std::string ss(sv); + EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); + } + ASSERT_OK(iter->GetNextRow(&row)); + i++; + } + // Expect 5 samples. + EXPECT_EQ(i, 5); + // Manually terminate the pipeline. + iter->Stop(); +} + +/// Feature: AGNewsDataset +/// Description: Test AGNewsDataset header +/// Expectation: The data is processed successfully +TEST_F(MindDataTestPipeline, DISABLED_TestAGNewsDatasetHeader) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAGNewsDatasetHeader."; + + // Create a AGNewsDataset, with single CSV file. + std::string dataset_dir = datasets_root_path_ + "/testAGNews"; + std::shared_ptr ds = + AGNews(dataset_dir, "test", 0, ShuffleMode::kFalse); + std::vector column_names = {"index", "title", "description"}; + EXPECT_NE(ds, nullptr); + // Create an iterator over the result of the above dataset. + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + // Iterate the dataset and get each row. + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + EXPECT_NE(row.find("index"), row.end()); + std::vector> expected_result = { + {"3", "Background of the selection", + "In this day and age, the internet is growing rapidly, " + "the total number of connected devices is increasing and " + "we are entering the era of big data."}, + {"4", "Related technologies", + "\"Leaflet is the leading open source JavaScript library " + "for mobile-friendly interactive maps.\""}, + }; + uint64_t i = 0; + while (row.size() != 0) { + for (int j = 0; j < column_names.size(); j++) { + auto text = row[column_names[j]]; + std::shared_ptr de_text; + ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); + std::string_view sv; + ASSERT_OK(de_text->GetItemAt(&sv, {})); + std::string ss(sv); + EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); + } + ASSERT_OK(iter->GetNextRow(&row)); + i++; + } + // Expect 2 samples. + EXPECT_EQ(i, 2); + // Manually terminate the pipeline. + iter->Stop(); +} + +/// Feature: AGNewsDataset +/// Description: Test AGNewsDataset using ShuffleMode::kFiles +/// Expectation: The data is processed successfully +TEST_F(MindDataTestPipeline, DISABLED_TestAGNewsDatasetShuffleFilesA) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAGNewsDatasetShuffleFilesA."; + + // Set configuration. + uint32_t original_seed = GlobalContext::config_manager()->seed(); + uint32_t original_num_parallel_workers = + GlobalContext::config_manager()->num_parallel_workers(); + MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed + << ", num_parallel_workers: " << original_num_parallel_workers; + GlobalContext::config_manager()->set_seed(130); + GlobalContext::config_manager()->set_num_parallel_workers(4); + std::string dataset_dir = datasets_root_path_ + "/testAGNews"; + std::shared_ptr ds = + AGNews(dataset_dir, "all", 0, ShuffleMode::kFiles); + std::vector column_names = {"index", "title", "description"}; + EXPECT_NE(ds, nullptr); + // Create an iterator over the result of the above dataset. + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + // Iterate the dataset and get each row. + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + EXPECT_NE(row.find("index"), row.end()); + std::vector> expected_result = { + {"3", "Demand analysis", + "\"Users simply click on the module they want to view to " + "browse information about that module.\""}, + {"3", "Background of the selection", + "In this day and age, the internet is growing rapidly, " + "the total number of connected devices is increasing and " + "we are entering the era of big data."}, + {"3", "UML Timing Diagram", + "Information is mainly displayed using locally stored data and mapping, " + "which is not timely and does not have the ability to update itself."}, + {"4", "Related technologies", + "\"Leaflet is the leading open source JavaScript library " + "for mobile-friendly interactive maps.\""}, + {"3", "In summary", + "This paper implements a map visualization system for Hangzhou city " + "information, using extensive knowledge of visualization techniques."}, + }; + uint64_t i = 0; + while (row.size() != 0) { + for (int j = 0; j < column_names.size(); j++) { + auto text = row[column_names[j]]; + std::shared_ptr de_text; + ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); + std::string_view sv; + ASSERT_OK(de_text->GetItemAt(&sv, {})); + std::string ss(sv); + EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); + } + ASSERT_OK(iter->GetNextRow(&row)); + i++; + } + // Expect 5 samples. + EXPECT_EQ(i, 5); + // Manually terminate the pipeline. + iter->Stop(); + // Restore configuration. + GlobalContext::config_manager()->set_seed(original_seed); + GlobalContext::config_manager()->set_num_parallel_workers( + original_num_parallel_workers); +} + +/// Feature: AGNewsDataset +/// Description: Test AGNewsDataset using ShuffleMode::kGlobal +/// Expectation: The data is processed successfully +TEST_F(MindDataTestPipeline, DISABLED_TestAGNewsDatasetShuffleGlobal) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAGNewsDatasetShuffleGlobal."; + // Test AGNews Dataset with GLOBLE shuffle. + uint32_t original_seed = GlobalContext::config_manager()->seed(); + uint32_t original_num_parallel_workers = + GlobalContext::config_manager()->num_parallel_workers(); + MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed + << ", num_parallel_workers: " << original_num_parallel_workers; + GlobalContext::config_manager()->set_seed(135); + GlobalContext::config_manager()->set_num_parallel_workers(4); + + std::string dataset_dir = datasets_root_path_ + "/testAGNews"; + std::shared_ptr ds = + AGNews(dataset_dir, "train", 0, ShuffleMode::kGlobal); + std::vector column_names = {"index", "title", "description"}; + EXPECT_NE(ds, nullptr); + // Create an iterator over the result of the above dataset. + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + // Iterate the dataset and get each row. + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + EXPECT_NE(row.find("index"), row.end()); + std::vector> expected_result = { + {"3", "UML Timing Diagram", + "Information is mainly displayed using locally stored data and mapping, " + "which is not timely and does not have the ability to update itself."}, + {"3", "In summary", + "This paper implements a map visualization system for Hangzhou city " + "information, using extensive knowledge of visualization techniques."}, + {"3", "Demand analysis", + "\"Users simply click on the module they want to view to " + "browse information about that module.\""}, + }; + + uint64_t i = 0; + while (row.size() != 0) { + for (int j = 0; j < column_names.size(); j++) { + auto text = row[column_names[j]]; + std::shared_ptr de_text; + ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); + std::string_view sv; + ASSERT_OK(de_text->GetItemAt(&sv, {})); + std::string ss(sv); + EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); + } + ASSERT_OK(iter->GetNextRow(&row)); + i++; + } + // Expect 3 samples. + EXPECT_EQ(i, 3); + // Manually terminate the pipeline. + iter->Stop(); + // Restore configuration. + GlobalContext::config_manager()->set_seed(original_seed); + GlobalContext::config_manager()->set_num_parallel_workers( + original_num_parallel_workers); +} diff --git a/tests/ut/cpp/dataset/c_api_dataset_amazon_review_test.cc b/tests/ut/cpp/dataset/c_api_dataset_amazon_review_test.cc old mode 100755 new mode 100644 index 665a2dcf5d0..74ce672b2e7 --- a/tests/ut/cpp/dataset/c_api_dataset_amazon_review_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_amazon_review_test.cc @@ -1,584 +1,584 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "common/common.h" -#include "minddata/dataset/core/global_context.h" -#include "minddata/dataset/include/dataset/datasets.h" -#include "minddata/dataset/engine/ir/datasetops/source/amazon_review_node.h" - -using namespace mindspore::dataset; - -class MindDataTestPipeline : public UT::DatasetOpTesting { -protected: -}; - -/// Feature: AmazonReview -/// Description: Read AmazonReviewPolarityDataset data and get data. -/// Expectation: The data is processed successfully. -TEST_F(MindDataTestPipeline, DISABLED_TestAmazonReviewPolarityDatasetBasic) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAmazonReviewPolarityDatasetBasic."; - - std::string dataset_dir = datasets_root_path_ + "/testAmazonReview/polarity"; - std::vector column_names = {"label", "title", "content"}; - - std::shared_ptr ds = AmazonReview(dataset_dir, "test", 0, ShuffleMode::kFalse); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset. - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row. - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - EXPECT_NE(row.find("label"), row.end()); - std::vector> expected_result = { - {"1", "DVD", "It is very good!"}, - {"2", "Book", "I would read it again lol."} - }; - - uint64_t i = 0; - while (row.size() != 0) { - for (int j = 0; j < column_names.size(); j++) { - auto text = row[column_names[j]]; - std::shared_ptr de_text; - ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); - std::string_view sv; - ASSERT_OK(de_text->GetItemAt(&sv, {})); - std::string ss(sv); - EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); - } - ASSERT_OK(iter->GetNextRow(&row)); - i++; - } - - // Expect 2 samples. - EXPECT_EQ(i, 2); - - // Manually terminate the pipeline. - iter->Stop(); -} - -/// Feature: AmazonReview -/// Description: Read AmazonReviewFullDataset data and get data. -/// Expectation: The data is processed successfully. -TEST_F(MindDataTestPipeline, DISABLED_TestAmazonReviewFullDatasetBasic) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAmazonReviewFullDatasetBasic."; - - std::string dataset_dir = datasets_root_path_ + "/testAmazonReview/full"; - std::vector column_names = {"label", "title", "content"}; - - std::shared_ptr ds = AmazonReview(dataset_dir, "test", 0, ShuffleMode::kFalse); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - EXPECT_NE(row.find("label"), row.end()); - std::vector> expected_result = { - {"1", "amazing", "unlimited buyback!"}, - {"4", "delightful", "a funny book!"}, - {"3", "Small", "It is a small ball!"} - }; - - - uint64_t i = 0; - while (row.size() != 0) { - for (int j = 0; j < column_names.size(); j++) { - auto text = row[column_names[j]]; - std::shared_ptr de_text; - ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); - std::string_view sv; - ASSERT_OK(de_text->GetItemAt(&sv, {})); - std::string ss(sv); - EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); - } - ASSERT_OK(iter->GetNextRow(&row)); - i++; - } - - // Expect 3 samples - EXPECT_EQ(i, 3); - - // Manually terminate the pipeline - iter->Stop(); -} - -/// Feature: AmazonReview(usage=all). -/// Description: Read train data and test data. -/// Expectation: The data is processed successfully. -TEST_F(MindDataTestPipeline, DISABLED_TestAmazonReviewDatasetUsageAll) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAmazonReviewDatasetUsageAll."; - - std::string dataset_dir = datasets_root_path_ + "/testAmazonReview/full"; - std::vector column_names = {"label", "title", "content"}; - - std::shared_ptr ds = AmazonReview(dataset_dir, "all" , 0, ShuffleMode::kFalse); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - EXPECT_NE(row.find("label"), row.end()); - std::vector> expected_result = { - {"1", "amazing", "unlimited buyback!"}, - {"3", "Satisfied", "good quality."}, - {"4", "delightful", "a funny book!"}, - {"5", "good", "This is an very good product."}, - {"3", "Small", "It is a small ball!"}, - {"1", "bad", "work badly."} - }; - - uint64_t i = 0; - while (row.size() != 0) { - for (int j = 0; j < column_names.size(); j++) { - auto text = row[column_names[j]]; - std::shared_ptr de_text; - ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); - std::string_view sv; - ASSERT_OK(de_text->GetItemAt(&sv, {})); - std::string ss(sv); - EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); - } - ASSERT_OK(iter->GetNextRow(&row)); - i++; - } - - // Expect 6 samples - EXPECT_EQ(i, 6); - - // Manually terminate the pipeline - iter->Stop(); -} - -/// Feature: AmazonReview -/// Description: Test Getter methods -/// Expectation: Output is equal to the expected output -TEST_F(MindDataTestPipeline, DISABLED_TestAmazonReviewGetters) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAmazonReviewGetters."; - - std::string dataset_dir = datasets_root_path_ + "/testAmazonReview/full"; - std::shared_ptr ds = AmazonReview(dataset_dir, "test", 0, ShuffleMode::kFalse); - std::vector column_names = {"label", "title", "content"}; - EXPECT_NE(ds, nullptr); - - EXPECT_EQ(ds-> GetDatasetSize(),3); - EXPECT_EQ(ds->GetColumnNames(),column_names); -} - -/// Feature: AmazonReview(num_samples = 3). -/// Description: Test whether the interface meets expectations when NumSamples is equal to 2. -/// Expectation: The data is processed successfully. -TEST_F(MindDataTestPipeline, DISABLED_TestAmazonReviewNumSamples) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAmazonReviewNumSamples."; - - std::string dataset_dir = datasets_root_path_ + "/testAmazonReview/full"; - std::vector column_names = {"label", "title", "content"}; - - std::shared_ptr ds = AmazonReview(dataset_dir, "test", 3, ShuffleMode::kFalse); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset. - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - EXPECT_NE(row.find("label"), row.end()); - std::vector> expected_result = { - {"1", "amazing", "unlimited buyback!"}, - {"4", "delightful", "a funny book!"}, - {"3", "Small", "It is a small ball!"} - }; - - uint64_t i = 0; - while (row.size() != 0) { - for (int j = 0; j < column_names.size(); j++) { - auto text = row[column_names[j]]; - std::shared_ptr de_text; - ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); - std::string_view sv; - ASSERT_OK(de_text->GetItemAt(&sv, {})); - std::string ss(sv); - EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); - } - ASSERT_OK(iter->GetNextRow(&row)); - i++; - } - - // Expect 3 samples. - EXPECT_EQ(i, 3); - - // Manually terminate the pipeline. - iter->Stop(); -} - -/// Feature: AmazonReview -/// Description: Test interface in a distributed state. -/// Expectation: The data is processed successfully. -TEST_F(MindDataTestPipeline, DISABLED_TestAmazonReviewDatasetDistribution) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAmazonReviewDatasetDistribution."; - - std::string dataset_dir = datasets_root_path_ + "/testAmazonReview/full"; - std::vector column_names = {"label", "title", "content"}; - - std::shared_ptr ds = AmazonReview(dataset_dir, "test", 0, ShuffleMode::kFalse, 2, 0); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset. - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row. - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - EXPECT_NE(row.find("label"), row.end()); - std::vector> expected_result = { - {"1", "amazing", "unlimited buyback!"}, - {"4", "delightful", "a funny book!"}, - {"3", "Small", "It is a small ball!"} - }; - - uint64_t i = 0; - while (row.size() != 0) { - for (int j = 0; j < column_names.size(); j++) { - auto text = row[column_names[j]]; - std::shared_ptr de_text; - ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); - std::string_view sv; - ASSERT_OK(de_text->GetItemAt(&sv, {})); - std::string ss(sv); - EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); - } - ASSERT_OK(iter->GetNextRow(&row)); - i++; - } - - // Expect 2 samples. - EXPECT_EQ(i, 2); - - // Manually terminate the pipeline. - iter->Stop(); -} - -/// Feature: AmazonReview -/// Description: Test the wrong input. -/// Expectation: Unable to read in data. -TEST_F(MindDataTestPipeline, DISABLED_TestAmazonReviewDatasetFail) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAmazonReviewDatasetFail."; - - std::string dataset_dir = datasets_root_path_ + "/testAmazonReview/full"; - std::string invalid_csv_file = "./NotExistFile"; - std::vector column_names = {"label", "title", "content"}; - - std::shared_ptr ds0 = AmazonReview("", "test", 0); - EXPECT_NE(ds0, nullptr); - // Create an iterator over the result of the above dataset. - std::shared_ptr iter0 = ds0->CreateIterator(); - // Expect failure: invalid AmazonReview input. - EXPECT_EQ(iter0, nullptr); - - // Create a AmazonReview Dataset with invalid usage. - std::shared_ptr ds1 = AmazonReview(invalid_csv_file); - EXPECT_NE(ds1, nullptr); - - // Create an iterator over the result of the above dataset. - std::shared_ptr iter1 = ds1->CreateIterator(); - // Expect failure: invalid SogouNews input. - EXPECT_EQ(iter1, nullptr); - - // Test invalid num_samples < -1. - std::shared_ptr ds2 = AmazonReview(dataset_dir, "test", -1, ShuffleMode::kFalse); - EXPECT_NE(ds2, nullptr); - // Create an iterator over the result of the above dataset. - std::shared_ptr iter2 = ds2->CreateIterator(); - // Expect failure: invalid AmazonReviewNews input. - EXPECT_EQ(iter2, nullptr); - - // Test invalid num_shards < 1. - std::shared_ptr ds3 = AmazonReview(dataset_dir, "test", 0, ShuffleMode::kFalse, 0); - EXPECT_NE(ds3, nullptr); - // Create an iterator over the result of the above dataset. - std::shared_ptr iter3 = ds3->CreateIterator(); - // Expect failure: invalid AmazonReview input. - EXPECT_EQ(iter3, nullptr); - - // Test invalid shard_id >= num_shards. - std::shared_ptr ds4 = AmazonReview(dataset_dir, "test", 0, ShuffleMode::kFalse, 2, 2); - EXPECT_NE(ds4, nullptr); - // Create an iterator over the result of the above dataset. - std::shared_ptr iter4 = ds4->CreateIterator(); - // Expect failure: invalid AmazonReview input. - EXPECT_EQ(iter4, nullptr); -} - -/// Feature: AmazonReview -/// Description: Test AmazonReview Dataset interface in pipeline. -/// Expectation: The data is processed successfully. -TEST_F(MindDataTestPipeline, DISABLED_TestAmazonReviewDatasetBasicWithPipeline) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAmazonReviewDatasetBasicWithPipeline."; - - // Create two AmazonReview Dataset, with single AmazonReview file. - std::string dataset_dir = datasets_root_path_ + "/testAmazonReview/full"; - - std::shared_ptr ds1 = AmazonReview(dataset_dir, "test", 0, ShuffleMode::kFalse); - std::shared_ptr ds2 = AmazonReview(dataset_dir, "test", 0, ShuffleMode::kFalse); - EXPECT_NE(ds1, nullptr); - EXPECT_NE(ds2, nullptr); - - // Create two Repeat operation on ds. - int32_t repeat_num = 2; - ds1 = ds1->Repeat(repeat_num); - EXPECT_NE(ds1, nullptr); - repeat_num = 3; - ds2 = ds2->Repeat(repeat_num); - EXPECT_NE(ds2, nullptr); - - // Create two Project operation on ds. - std::vector column_project = {"label"}; - ds1 = ds1->Project(column_project); - EXPECT_NE(ds1, nullptr); - ds2 = ds2->Project(column_project); - EXPECT_NE(ds2, nullptr); - - // Create a Concat operation on the ds. - ds1 = ds1->Concat({ds2}); - EXPECT_NE(ds1, nullptr); - - // Create an iterator over the result of the above dataset. - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds1->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row. - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - - EXPECT_NE(row.find("label"), row.end()); - uint64_t i = 0; - while (row.size() != 0) { - auto text = row["label"]; - MS_LOG(INFO) << "Tensor text shape: " << text.Shape(); - i++; - ASSERT_OK(iter->GetNextRow(&row)); - } - - // Expect 10 samples. - EXPECT_EQ(i, 15); - - // Manually terminate the pipeline. - iter->Stop(); -} - -/// Feature: AmazonReview(ShuffleMode=kFiles). -/// Description: Test AmazonReview Dataset interface with different ShuffleMode. -/// Expectation: The data is processed successfully. -TEST_F(MindDataTestPipeline, DISABLED_TestAmazonReviewDatasetShuffleFilesA) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-AmazonReviewDatasetShuffleFilesA."; - - // Set configuration. - uint32_t original_seed = GlobalContext::config_manager()->seed(); - uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); - MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; - GlobalContext::config_manager()->set_seed(130); - GlobalContext::config_manager()->set_num_parallel_workers(4); - - std::string dataset_dir = datasets_root_path_ + "/testAmazonReview/full"; - std::vector column_names = {"label", "title", "content"}; - - std::shared_ptr ds = AmazonReview(dataset_dir, "all" , 0, ShuffleMode::kFiles); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset. - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row. - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - EXPECT_NE(row.find("label"), row.end()); - std::vector> expected_result = { - {"3", "Satisfied", "good quality."}, - {"1", "amazing", "unlimited buyback!"}, - {"5", "good", "This is an very good product."}, - {"4", "delightful", "a funny book!"}, - {"1", "bad", "work badly."}, - {"3", "Small", "It is a small ball!"} - }; - - uint64_t i = 0; - while (row.size() != 0) { - for (int j = 0; j < column_names.size(); j++) { - auto text = row[column_names[j]]; - std::shared_ptr de_text; - ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); - std::string_view sv; - ASSERT_OK(de_text->GetItemAt(&sv, {})); - std::string ss(sv); - EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); - } - ASSERT_OK(iter->GetNextRow(&row)); - i++; - } - // Expect 6 samples. - EXPECT_EQ(i, 6); - - // Manually terminate the pipeline. - iter->Stop(); - - // Restore configuration - GlobalContext::config_manager()->set_seed(original_seed); - GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); -} - -/// Feature: AmazonReview(ShuffleMode=kInfile). -/// Description: Test AmazonReview Dataset interface with different ShuffleMode. -/// Expectation: The data is processed successfully. -TEST_F(MindDataTestPipeline, DISABLED_TestAmazonReviewDatasetShuffleFilesB) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAmazonReviewDatasetShuffleFilesB."; - - // Set configuration. - uint32_t original_seed = GlobalContext::config_manager()->seed(); - uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); - MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; - GlobalContext::config_manager()->set_seed(130); - GlobalContext::config_manager()->set_num_parallel_workers(4); - - std::string dataset_dir = datasets_root_path_ + "/testAmazonReview/full"; - std::vector column_names = {"label", "title", "content"}; - - std::shared_ptr ds = AmazonReview(dataset_dir, "all" , 0, ShuffleMode::kFiles); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset. - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row. - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - EXPECT_NE(row.find("label"), row.end()); - std::vector> expected_result = { - {"3", "Satisfied", "good quality."}, - {"1", "amazing", "unlimited buyback!"}, - {"5", "good", "This is an very good product."}, - {"4", "delightful", "a funny book!"}, - {"1", "bad", "work badly."}, - {"3", "Small", "It is a small ball!"} - }; - - uint64_t i = 0; - while (row.size() != 0) { - for (int j = 0; j < column_names.size(); j++) { - auto text = row[column_names[j]]; - std::shared_ptr de_text; - ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); - std::string_view sv; - ASSERT_OK(de_text->GetItemAt(&sv, {})); - std::string ss(sv); - EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); - } - ASSERT_OK(iter->GetNextRow(&row)); - i++; - } - // Expect 6 samples - EXPECT_EQ(i, 6); - - // Manually terminate the pipeline - iter->Stop(); - - // Restore configuration - GlobalContext::config_manager()->set_seed(original_seed); - GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); -} - -/// Feature: AmazonReview(ShuffleMode=kGlobal). -/// Description: Test AmazonReview Dataset interface with different ShuffleMode. -/// Expectation: The data is processed successfully. -TEST_F(MindDataTestPipeline, DISABLED_TestAmazonReviewDatasetShuffleFilesGlobal) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAmazonReviewDatasetShuffleFilesGlobal."; - - // Set configuration - uint32_t original_seed = GlobalContext::config_manager()->seed(); - uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); - MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; - GlobalContext::config_manager()->set_seed(130); - GlobalContext::config_manager()->set_num_parallel_workers(4); - - std::string dataset_dir = datasets_root_path_ + "/testAmazonReview/full"; - std::vector column_names = {"label", "title", "content"}; - - std::shared_ptr ds = AmazonReview(dataset_dir, "all" , 0, ShuffleMode::kGlobal); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset. - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row. - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - EXPECT_NE(row.find("label"), row.end()); - std::vector> expected_result = { - {"3", "Satisfied", "good quality."}, - {"1", "amazing", "unlimited buyback!"}, - {"5", "good", "This is an very good product."}, - {"3", "Small", "It is a small ball!"}, - {"4", "delightful", "a funny book!"}, - {"1", "bad", "work badly."} -}; - - uint64_t i = 0; - while (row.size() != 0) { - for (int j = 0; j < column_names.size(); j++) { - auto text = row[column_names[j]]; - std::shared_ptr de_text; - ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); - std::string_view sv; - ASSERT_OK(de_text->GetItemAt(&sv, {})); - std::string ss(sv); - EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); - } - ASSERT_OK(iter->GetNextRow(&row)); - i++; - } - // Expect 6 samples. - EXPECT_EQ(i, 6); - - // Manually terminate the pipeline. - iter->Stop(); - - // Restore configuration - GlobalContext::config_manager()->set_seed(original_seed); - GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/common.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/include/dataset/datasets.h" +#include "minddata/dataset/engine/ir/datasetops/source/amazon_review_node.h" + +using namespace mindspore::dataset; + +class MindDataTestPipeline : public UT::DatasetOpTesting { +protected: +}; + +/// Feature: AmazonReview +/// Description: Read AmazonReviewPolarityDataset data and get data. +/// Expectation: The data is processed successfully. +TEST_F(MindDataTestPipeline, DISABLED_TestAmazonReviewPolarityDatasetBasic) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAmazonReviewPolarityDatasetBasic."; + + std::string dataset_dir = datasets_root_path_ + "/testAmazonReview/polarity"; + std::vector column_names = {"label", "title", "content"}; + + std::shared_ptr ds = AmazonReview(dataset_dir, "test", 0, ShuffleMode::kFalse); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset. + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row. + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + EXPECT_NE(row.find("label"), row.end()); + std::vector> expected_result = { + {"1", "DVD", "It is very good!"}, + {"2", "Book", "I would read it again lol."} + }; + + uint64_t i = 0; + while (row.size() != 0) { + for (int j = 0; j < column_names.size(); j++) { + auto text = row[column_names[j]]; + std::shared_ptr de_text; + ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); + std::string_view sv; + ASSERT_OK(de_text->GetItemAt(&sv, {})); + std::string ss(sv); + EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); + } + ASSERT_OK(iter->GetNextRow(&row)); + i++; + } + + // Expect 2 samples. + EXPECT_EQ(i, 2); + + // Manually terminate the pipeline. + iter->Stop(); +} + +/// Feature: AmazonReview +/// Description: Read AmazonReviewFullDataset data and get data. +/// Expectation: The data is processed successfully. +TEST_F(MindDataTestPipeline, DISABLED_TestAmazonReviewFullDatasetBasic) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAmazonReviewFullDatasetBasic."; + + std::string dataset_dir = datasets_root_path_ + "/testAmazonReview/full"; + std::vector column_names = {"label", "title", "content"}; + + std::shared_ptr ds = AmazonReview(dataset_dir, "test", 0, ShuffleMode::kFalse); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + EXPECT_NE(row.find("label"), row.end()); + std::vector> expected_result = { + {"1", "amazing", "unlimited buyback!"}, + {"4", "delightful", "a funny book!"}, + {"3", "Small", "It is a small ball!"} + }; + + + uint64_t i = 0; + while (row.size() != 0) { + for (int j = 0; j < column_names.size(); j++) { + auto text = row[column_names[j]]; + std::shared_ptr de_text; + ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); + std::string_view sv; + ASSERT_OK(de_text->GetItemAt(&sv, {})); + std::string ss(sv); + EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); + } + ASSERT_OK(iter->GetNextRow(&row)); + i++; + } + + // Expect 3 samples + EXPECT_EQ(i, 3); + + // Manually terminate the pipeline + iter->Stop(); +} + +/// Feature: AmazonReview(usage=all). +/// Description: Read train data and test data. +/// Expectation: The data is processed successfully. +TEST_F(MindDataTestPipeline, DISABLED_TestAmazonReviewDatasetUsageAll) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAmazonReviewDatasetUsageAll."; + + std::string dataset_dir = datasets_root_path_ + "/testAmazonReview/full"; + std::vector column_names = {"label", "title", "content"}; + + std::shared_ptr ds = AmazonReview(dataset_dir, "all" , 0, ShuffleMode::kFalse); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + EXPECT_NE(row.find("label"), row.end()); + std::vector> expected_result = { + {"1", "amazing", "unlimited buyback!"}, + {"3", "Satisfied", "good quality."}, + {"4", "delightful", "a funny book!"}, + {"5", "good", "This is an very good product."}, + {"3", "Small", "It is a small ball!"}, + {"1", "bad", "work badly."} + }; + + uint64_t i = 0; + while (row.size() != 0) { + for (int j = 0; j < column_names.size(); j++) { + auto text = row[column_names[j]]; + std::shared_ptr de_text; + ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); + std::string_view sv; + ASSERT_OK(de_text->GetItemAt(&sv, {})); + std::string ss(sv); + EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); + } + ASSERT_OK(iter->GetNextRow(&row)); + i++; + } + + // Expect 6 samples + EXPECT_EQ(i, 6); + + // Manually terminate the pipeline + iter->Stop(); +} + +/// Feature: AmazonReview +/// Description: Test Getter methods +/// Expectation: Output is equal to the expected output +TEST_F(MindDataTestPipeline, DISABLED_TestAmazonReviewGetters) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAmazonReviewGetters."; + + std::string dataset_dir = datasets_root_path_ + "/testAmazonReview/full"; + std::shared_ptr ds = AmazonReview(dataset_dir, "test", 0, ShuffleMode::kFalse); + std::vector column_names = {"label", "title", "content"}; + EXPECT_NE(ds, nullptr); + + EXPECT_EQ(ds-> GetDatasetSize(),3); + EXPECT_EQ(ds->GetColumnNames(),column_names); +} + +/// Feature: AmazonReview(num_samples = 3). +/// Description: Test whether the interface meets expectations when NumSamples is equal to 2. +/// Expectation: The data is processed successfully. +TEST_F(MindDataTestPipeline, DISABLED_TestAmazonReviewNumSamples) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAmazonReviewNumSamples."; + + std::string dataset_dir = datasets_root_path_ + "/testAmazonReview/full"; + std::vector column_names = {"label", "title", "content"}; + + std::shared_ptr ds = AmazonReview(dataset_dir, "test", 3, ShuffleMode::kFalse); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset. + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + EXPECT_NE(row.find("label"), row.end()); + std::vector> expected_result = { + {"1", "amazing", "unlimited buyback!"}, + {"4", "delightful", "a funny book!"}, + {"3", "Small", "It is a small ball!"} + }; + + uint64_t i = 0; + while (row.size() != 0) { + for (int j = 0; j < column_names.size(); j++) { + auto text = row[column_names[j]]; + std::shared_ptr de_text; + ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); + std::string_view sv; + ASSERT_OK(de_text->GetItemAt(&sv, {})); + std::string ss(sv); + EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); + } + ASSERT_OK(iter->GetNextRow(&row)); + i++; + } + + // Expect 3 samples. + EXPECT_EQ(i, 3); + + // Manually terminate the pipeline. + iter->Stop(); +} + +/// Feature: AmazonReview +/// Description: Test interface in a distributed state. +/// Expectation: The data is processed successfully. +TEST_F(MindDataTestPipeline, DISABLED_TestAmazonReviewDatasetDistribution) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAmazonReviewDatasetDistribution."; + + std::string dataset_dir = datasets_root_path_ + "/testAmazonReview/full"; + std::vector column_names = {"label", "title", "content"}; + + std::shared_ptr ds = AmazonReview(dataset_dir, "test", 0, ShuffleMode::kFalse, 2, 0); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset. + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row. + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + EXPECT_NE(row.find("label"), row.end()); + std::vector> expected_result = { + {"1", "amazing", "unlimited buyback!"}, + {"4", "delightful", "a funny book!"}, + {"3", "Small", "It is a small ball!"} + }; + + uint64_t i = 0; + while (row.size() != 0) { + for (int j = 0; j < column_names.size(); j++) { + auto text = row[column_names[j]]; + std::shared_ptr de_text; + ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); + std::string_view sv; + ASSERT_OK(de_text->GetItemAt(&sv, {})); + std::string ss(sv); + EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); + } + ASSERT_OK(iter->GetNextRow(&row)); + i++; + } + + // Expect 2 samples. + EXPECT_EQ(i, 2); + + // Manually terminate the pipeline. + iter->Stop(); +} + +/// Feature: AmazonReview +/// Description: Test the wrong input. +/// Expectation: Unable to read in data. +TEST_F(MindDataTestPipeline, DISABLED_TestAmazonReviewDatasetFail) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAmazonReviewDatasetFail."; + + std::string dataset_dir = datasets_root_path_ + "/testAmazonReview/full"; + std::string invalid_csv_file = "./NotExistFile"; + std::vector column_names = {"label", "title", "content"}; + + std::shared_ptr ds0 = AmazonReview("", "test", 0); + EXPECT_NE(ds0, nullptr); + // Create an iterator over the result of the above dataset. + std::shared_ptr iter0 = ds0->CreateIterator(); + // Expect failure: invalid AmazonReview input. + EXPECT_EQ(iter0, nullptr); + + // Create a AmazonReview Dataset with invalid usage. + std::shared_ptr ds1 = AmazonReview(invalid_csv_file); + EXPECT_NE(ds1, nullptr); + + // Create an iterator over the result of the above dataset. + std::shared_ptr iter1 = ds1->CreateIterator(); + // Expect failure: invalid SogouNews input. + EXPECT_EQ(iter1, nullptr); + + // Test invalid num_samples < -1. + std::shared_ptr ds2 = AmazonReview(dataset_dir, "test", -1, ShuffleMode::kFalse); + EXPECT_NE(ds2, nullptr); + // Create an iterator over the result of the above dataset. + std::shared_ptr iter2 = ds2->CreateIterator(); + // Expect failure: invalid AmazonReviewNews input. + EXPECT_EQ(iter2, nullptr); + + // Test invalid num_shards < 1. + std::shared_ptr ds3 = AmazonReview(dataset_dir, "test", 0, ShuffleMode::kFalse, 0); + EXPECT_NE(ds3, nullptr); + // Create an iterator over the result of the above dataset. + std::shared_ptr iter3 = ds3->CreateIterator(); + // Expect failure: invalid AmazonReview input. + EXPECT_EQ(iter3, nullptr); + + // Test invalid shard_id >= num_shards. + std::shared_ptr ds4 = AmazonReview(dataset_dir, "test", 0, ShuffleMode::kFalse, 2, 2); + EXPECT_NE(ds4, nullptr); + // Create an iterator over the result of the above dataset. + std::shared_ptr iter4 = ds4->CreateIterator(); + // Expect failure: invalid AmazonReview input. + EXPECT_EQ(iter4, nullptr); +} + +/// Feature: AmazonReview +/// Description: Test AmazonReview Dataset interface in pipeline. +/// Expectation: The data is processed successfully. +TEST_F(MindDataTestPipeline, DISABLED_TestAmazonReviewDatasetBasicWithPipeline) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAmazonReviewDatasetBasicWithPipeline."; + + // Create two AmazonReview Dataset, with single AmazonReview file. + std::string dataset_dir = datasets_root_path_ + "/testAmazonReview/full"; + + std::shared_ptr ds1 = AmazonReview(dataset_dir, "test", 0, ShuffleMode::kFalse); + std::shared_ptr ds2 = AmazonReview(dataset_dir, "test", 0, ShuffleMode::kFalse); + EXPECT_NE(ds1, nullptr); + EXPECT_NE(ds2, nullptr); + + // Create two Repeat operation on ds. + int32_t repeat_num = 2; + ds1 = ds1->Repeat(repeat_num); + EXPECT_NE(ds1, nullptr); + repeat_num = 3; + ds2 = ds2->Repeat(repeat_num); + EXPECT_NE(ds2, nullptr); + + // Create two Project operation on ds. + std::vector column_project = {"label"}; + ds1 = ds1->Project(column_project); + EXPECT_NE(ds1, nullptr); + ds2 = ds2->Project(column_project); + EXPECT_NE(ds2, nullptr); + + // Create a Concat operation on the ds. + ds1 = ds1->Concat({ds2}); + EXPECT_NE(ds1, nullptr); + + // Create an iterator over the result of the above dataset. + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds1->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row. + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + + EXPECT_NE(row.find("label"), row.end()); + uint64_t i = 0; + while (row.size() != 0) { + auto text = row["label"]; + MS_LOG(INFO) << "Tensor text shape: " << text.Shape(); + i++; + ASSERT_OK(iter->GetNextRow(&row)); + } + + // Expect 10 samples. + EXPECT_EQ(i, 15); + + // Manually terminate the pipeline. + iter->Stop(); +} + +/// Feature: AmazonReview(ShuffleMode=kFiles). +/// Description: Test AmazonReview Dataset interface with different ShuffleMode. +/// Expectation: The data is processed successfully. +TEST_F(MindDataTestPipeline, DISABLED_TestAmazonReviewDatasetShuffleFilesA) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-AmazonReviewDatasetShuffleFilesA."; + + // Set configuration. + uint32_t original_seed = GlobalContext::config_manager()->seed(); + uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); + MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; + GlobalContext::config_manager()->set_seed(130); + GlobalContext::config_manager()->set_num_parallel_workers(4); + + std::string dataset_dir = datasets_root_path_ + "/testAmazonReview/full"; + std::vector column_names = {"label", "title", "content"}; + + std::shared_ptr ds = AmazonReview(dataset_dir, "all" , 0, ShuffleMode::kFiles); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset. + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row. + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + EXPECT_NE(row.find("label"), row.end()); + std::vector> expected_result = { + {"3", "Satisfied", "good quality."}, + {"1", "amazing", "unlimited buyback!"}, + {"5", "good", "This is an very good product."}, + {"4", "delightful", "a funny book!"}, + {"1", "bad", "work badly."}, + {"3", "Small", "It is a small ball!"} + }; + + uint64_t i = 0; + while (row.size() != 0) { + for (int j = 0; j < column_names.size(); j++) { + auto text = row[column_names[j]]; + std::shared_ptr de_text; + ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); + std::string_view sv; + ASSERT_OK(de_text->GetItemAt(&sv, {})); + std::string ss(sv); + EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); + } + ASSERT_OK(iter->GetNextRow(&row)); + i++; + } + // Expect 6 samples. + EXPECT_EQ(i, 6); + + // Manually terminate the pipeline. + iter->Stop(); + + // Restore configuration + GlobalContext::config_manager()->set_seed(original_seed); + GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); +} + +/// Feature: AmazonReview(ShuffleMode=kInfile). +/// Description: Test AmazonReview Dataset interface with different ShuffleMode. +/// Expectation: The data is processed successfully. +TEST_F(MindDataTestPipeline, DISABLED_TestAmazonReviewDatasetShuffleFilesB) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAmazonReviewDatasetShuffleFilesB."; + + // Set configuration. + uint32_t original_seed = GlobalContext::config_manager()->seed(); + uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); + MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; + GlobalContext::config_manager()->set_seed(130); + GlobalContext::config_manager()->set_num_parallel_workers(4); + + std::string dataset_dir = datasets_root_path_ + "/testAmazonReview/full"; + std::vector column_names = {"label", "title", "content"}; + + std::shared_ptr ds = AmazonReview(dataset_dir, "all" , 0, ShuffleMode::kFiles); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset. + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row. + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + EXPECT_NE(row.find("label"), row.end()); + std::vector> expected_result = { + {"3", "Satisfied", "good quality."}, + {"1", "amazing", "unlimited buyback!"}, + {"5", "good", "This is an very good product."}, + {"4", "delightful", "a funny book!"}, + {"1", "bad", "work badly."}, + {"3", "Small", "It is a small ball!"} + }; + + uint64_t i = 0; + while (row.size() != 0) { + for (int j = 0; j < column_names.size(); j++) { + auto text = row[column_names[j]]; + std::shared_ptr de_text; + ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); + std::string_view sv; + ASSERT_OK(de_text->GetItemAt(&sv, {})); + std::string ss(sv); + EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); + } + ASSERT_OK(iter->GetNextRow(&row)); + i++; + } + // Expect 6 samples + EXPECT_EQ(i, 6); + + // Manually terminate the pipeline + iter->Stop(); + + // Restore configuration + GlobalContext::config_manager()->set_seed(original_seed); + GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); +} + +/// Feature: AmazonReview(ShuffleMode=kGlobal). +/// Description: Test AmazonReview Dataset interface with different ShuffleMode. +/// Expectation: The data is processed successfully. +TEST_F(MindDataTestPipeline, DISABLED_TestAmazonReviewDatasetShuffleFilesGlobal) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAmazonReviewDatasetShuffleFilesGlobal."; + + // Set configuration + uint32_t original_seed = GlobalContext::config_manager()->seed(); + uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); + MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; + GlobalContext::config_manager()->set_seed(130); + GlobalContext::config_manager()->set_num_parallel_workers(4); + + std::string dataset_dir = datasets_root_path_ + "/testAmazonReview/full"; + std::vector column_names = {"label", "title", "content"}; + + std::shared_ptr ds = AmazonReview(dataset_dir, "all" , 0, ShuffleMode::kGlobal); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset. + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row. + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + EXPECT_NE(row.find("label"), row.end()); + std::vector> expected_result = { + {"3", "Satisfied", "good quality."}, + {"1", "amazing", "unlimited buyback!"}, + {"5", "good", "This is an very good product."}, + {"3", "Small", "It is a small ball!"}, + {"4", "delightful", "a funny book!"}, + {"1", "bad", "work badly."} +}; + + uint64_t i = 0; + while (row.size() != 0) { + for (int j = 0; j < column_names.size(); j++) { + auto text = row[column_names[j]]; + std::shared_ptr de_text; + ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); + std::string_view sv; + ASSERT_OK(de_text->GetItemAt(&sv, {})); + std::string ss(sv); + EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); + } + ASSERT_OK(iter->GetNextRow(&row)); + i++; + } + // Expect 6 samples. + EXPECT_EQ(i, 6); + + // Manually terminate the pipeline. + iter->Stop(); + + // Restore configuration + GlobalContext::config_manager()->set_seed(original_seed); + GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); } \ No newline at end of file diff --git a/tests/ut/cpp/dataset/c_api_dataset_caltech256_test.cc b/tests/ut/cpp/dataset/c_api_dataset_caltech256_test.cc old mode 100755 new mode 100644 diff --git a/tests/ut/cpp/dataset/c_api_dataset_cmu_arctic_test.cc b/tests/ut/cpp/dataset/c_api_dataset_cmu_arctic_test.cc index c60f47cc416..359d8899ca5 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_cmu_arctic_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_cmu_arctic_test.cc @@ -1,287 +1,287 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "common/common.h" - -#include "include/dataset/datasets.h" -#include "include/dataset/transforms.h" - -using namespace mindspore::dataset; -using mindspore::dataset::Tensor; - -class MindDataTestPipeline : public UT::DatasetOpTesting { - protected: -}; - -/// Feature: CMUArcticDataset -/// Description: Test CMUArcticDataset basic usage -/// Expectation: The data is processed successfully -TEST_F(MindDataTestPipeline, TestCMUArcticBasic) { - MS_LOG(INFO) << "Doing CMUArcticDataTestPipeline-TestCMUArcticBasic."; - - std::string folder_path = datasets_root_path_ + "/testCMUArcticData"; - // Create a CMUArctic Dataset. - std::shared_ptr ds = CMUArctic(folder_path); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset. - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row. - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - std::string_view transcript_idx, utterance_id_idx; - uint32_t rate = 0; - uint64_t i = 0; - - while (row.size() != 0) { - auto waveform = row["waveform"]; - auto sample_rate = row["sample_rate"]; - auto transcript = row["transcript"]; - auto utterance_id = row["utterance_id"]; - - MS_LOG(INFO) << "Tensor waveform shape: " << waveform.Shape(); - - std::shared_ptr trate; - ASSERT_OK(Tensor::CreateFromMSTensor(sample_rate, &trate)); - ASSERT_OK(trate->GetItemAt(&rate, {})); - MS_LOG(INFO) << "Audio sample rate: " << rate; - - std::shared_ptr de_transcript; - ASSERT_OK(Tensor::CreateFromMSTensor(transcript, &de_transcript)); - ASSERT_OK(de_transcript->GetItemAt(&transcript_idx, {})); - std::string s_transcript(transcript_idx); - MS_LOG(INFO) << "Tensor transcript value: " << transcript_idx; - - std::shared_ptr de_utterance_id; - ASSERT_OK(Tensor::CreateFromMSTensor(utterance_id, &de_utterance_id)); - ASSERT_OK(de_utterance_id->GetItemAt(&utterance_id_idx, {})); - std::string s_utterance_id(utterance_id_idx); - MS_LOG(INFO) << "Tensor utterance_id value: " << utterance_id_idx; - i++; - ASSERT_OK(iter->GetNextRow(&row)); - } - - EXPECT_EQ(i, 3); - - iter->Stop(); -} - -/// Feature: CMUArcticDataset -/// Description: Test CMUArcticDataset in pipeline mode -/// Expectation: The data is processed successfully -TEST_F(MindDataTestPipeline, TestCMUArcticBasicWithPipeline) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCMUArcticBasicWithPipeline."; - - // Create a CMUArcticDataset Dataset - std::string folder_path = datasets_root_path_ + "/testCMUArcticData"; - std::shared_ptr ds = CMUArctic(folder_path, "aew", std::make_shared(0, 2)); - EXPECT_NE(ds, nullptr); - auto op = transforms::PadEnd({1, 50000});; - std::vector input_columns = {"waveform"}; - std::vector output_columns = {"waveform"}; - ds = ds->Map({op}, input_columns, output_columns); - EXPECT_NE(ds, nullptr); - ds = ds->Repeat(10); - EXPECT_NE(ds, nullptr); - ds = ds->Batch(2); - EXPECT_NE(ds, nullptr); - // Create an iterator over the result of the above dataset. - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - // Iterate the dataset and get each row. - std::unordered_map row; - iter->GetNextRow(&row); - std::vector expected_utterance = {"Dog.", "Cat."}; - std::vector expected_utterance_id = {"a0001", "a0002"}; - uint64_t i = 0; - while (row.size() != 0) { - i++; - auto waveform = row["waveform"]; - auto transcript = row["transcript"]; - auto utterance_id = row["utterance_id"]; - - std::shared_ptr de_expected_transcript; - ASSERT_OK(Tensor::CreateFromVector(expected_utterance, &de_expected_transcript)); - mindspore::MSTensor fix_expected_transcript = - mindspore::MSTensor(std::make_shared(de_expected_transcript)); - EXPECT_MSTENSOR_EQ(transcript, fix_expected_transcript); - - std::shared_ptr de_expected_utterance_id; - ASSERT_OK(Tensor::CreateFromVector(expected_utterance_id, &de_expected_utterance_id)); - mindspore::MSTensor fix_expected_utterance_id = - mindspore::MSTensor(std::make_shared(de_expected_utterance_id)); - EXPECT_MSTENSOR_EQ(utterance_id, fix_expected_utterance_id); - - ASSERT_OK(iter->GetNextRow(&row)); - } - - EXPECT_EQ(i, 10); - - iter->Stop(); -} - -/// Feature: CMUArcticDataset -/// Description: Test CMUArcticDataset with non-existing dataset directory -/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr -TEST_F(MindDataTestPipeline, TestCMUArcticError) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCMUArcticError."; - - // Create a CMUArctic Dataset with non-existing dataset dir - std::shared_ptr ds0 = CMUArctic("NotExistFile"); - EXPECT_NE(ds0, nullptr); - - // Create an iterator over the result of the above dataset - std::shared_ptr iter0 = ds0->CreateIterator(); - // Expect failure: invalid CMUArctic input - EXPECT_EQ(iter0, nullptr); - - // Create a CMUArctic Dataset with invalid string of dataset dir - std::shared_ptr ds1 = CMUArctic(":*?\"<>|`&;'"); - EXPECT_NE(ds1, nullptr); - - // Create an iterator over the result of the above dataset - std::shared_ptr iter1 = ds1->CreateIterator(); - // Expect failure: invalid CMUArctic input - EXPECT_EQ(iter1, nullptr); -} - -/// Feature: CMUArcticDataset -/// Description: Test CMUArcticDataset Getters method -/// Expectation: Output is equal to the expected output -TEST_F(MindDataTestPipeline, TestCMUArcticGetters) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCMUArcticGetters."; - - std::string folder_path = datasets_root_path_ + "/testCMUArcticData"; - // Create a CMUArctic Dataset. - std::shared_ptr ds1 = CMUArctic(folder_path); - std::shared_ptr ds2 = CMUArctic(folder_path, "aew"); - - std::vector column_names = {"waveform", "sample_rate", "transcript", "utterance_id"}; - - EXPECT_NE(ds1, nullptr); - EXPECT_EQ(ds1->GetDatasetSize(), 3); - EXPECT_EQ(ds1->GetColumnNames(), column_names); - - EXPECT_NE(ds2, nullptr); - EXPECT_EQ(ds2->GetDatasetSize(), 3); - EXPECT_EQ(ds2->GetColumnNames(), column_names); -} - -/// Feature: CMUArcticDataset -/// Description: Test CMUArcticDataset with invalid name -/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr -TEST_F(MindDataTestPipeline, TestCMUArcticWithInvalidNameError) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCMUArcticWithInvalidNameError."; - - std::string folder_path = datasets_root_path_ + "/testCMUArcticData"; - // Create a CMUArctic Dataset. - std::shared_ptr ds1 = CMUArctic(folder_path, "----"); - EXPECT_NE(ds1, nullptr); - - // Create an iterator over the result of the above dataset - std::shared_ptr iter1 = ds1->CreateIterator(); - // Expect failure: invalid CMUArctic input, invalid name - EXPECT_EQ(iter1, nullptr); - - std::shared_ptr ds2 = CMUArctic(folder_path, "csacs"); - EXPECT_NE(ds2, nullptr); - - // Create an iterator over the result of the above dataset - std::shared_ptr iter2 = ds2->CreateIterator(); - // Expect failure: invalid CMUArctic input, invalid name - EXPECT_EQ(iter2, nullptr); -} - -/// Feature: CMUArcticDataset -/// Description: Test CMUArcticDataset with null sampler -/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr -TEST_F(MindDataTestPipeline, TestCMUArcticWithNullSamplerError) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCMUArcticWithNullSamplerError."; - - std::string folder_path = datasets_root_path_ + "/testCMUArcticData"; - // Create a CMUArctic Dataset. - std::shared_ptr ds = CMUArctic(folder_path, "aew", nullptr); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset - std::shared_ptr iter = ds->CreateIterator(); - // Expect failure: invalid CMUArctic input, sampler cannot be nullptr - EXPECT_EQ(iter, nullptr); -} - -/// Feature: CMUArcticDataset -/// Description: Test CMUArcticDataset with SequentialSampler -/// Expectation: The data is processed successfully -TEST_F(MindDataTestPipeline, TestCMUArcticNumSamplers) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCMUArcticWithSequentialSampler."; - - std::string folder_path = datasets_root_path_ + "/testCMUArcticData"; - // Create a CMUArctic Dataset. - std::shared_ptr ds = CMUArctic(folder_path, "aew", std::make_shared(0, 2)); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset. - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row. - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - std::string_view transcript_idx, utterance_id_idx; - std::vector expected_utterance = {"Dog.", "Cat."}; - std::vector expected_utterance_id = {"a0001", "a0002"}; - uint32_t rate = 0; - uint64_t i = 0; - - while (row.size() != 0) { - auto waveform = row["waveform"]; - auto sample_rate = row["sample_rate"]; - auto transcript = row["transcript"]; - auto utterance_id = row["utterance_id"]; - - MS_LOG(INFO) << "Tensor waveform shape: " << waveform.Shape(); - - std::shared_ptr trate; - ASSERT_OK(Tensor::CreateFromMSTensor(sample_rate, &trate)); - ASSERT_OK(trate->GetItemAt(&rate, {})); - EXPECT_EQ(rate, 16000); - MS_LOG(INFO) << "Tensor sample rate: " << rate; - - std::shared_ptr de_transcript; - ASSERT_OK(Tensor::CreateFromMSTensor(transcript, &de_transcript)); - ASSERT_OK(de_transcript->GetItemAt(&transcript_idx, {})); - std::string s_transcript(transcript_idx); - EXPECT_STREQ(s_transcript.c_str(), expected_utterance[i].c_str()); - MS_LOG(INFO) << "Tensor transcript value: " << transcript_idx; - - std::shared_ptr de_utterance_id; - ASSERT_OK(Tensor::CreateFromMSTensor(utterance_id, &de_utterance_id)); - ASSERT_OK(de_utterance_id->GetItemAt(&utterance_id_idx, {})); - std::string s_utterance_id(utterance_id_idx); - EXPECT_STREQ(s_utterance_id.c_str(), expected_utterance_id[i].c_str()); - MS_LOG(INFO) << "Tensor utterance_id value: " << utterance_id_idx; - i++; - ASSERT_OK(iter->GetNextRow(&row)); - } - - EXPECT_EQ(i, 2); - - iter->Stop(); -} +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/common.h" + +#include "include/dataset/datasets.h" +#include "include/dataset/transforms.h" + +using namespace mindspore::dataset; +using mindspore::dataset::Tensor; + +class MindDataTestPipeline : public UT::DatasetOpTesting { + protected: +}; + +/// Feature: CMUArcticDataset +/// Description: Test CMUArcticDataset basic usage +/// Expectation: The data is processed successfully +TEST_F(MindDataTestPipeline, TestCMUArcticBasic) { + MS_LOG(INFO) << "Doing CMUArcticDataTestPipeline-TestCMUArcticBasic."; + + std::string folder_path = datasets_root_path_ + "/testCMUArcticData"; + // Create a CMUArctic Dataset. + std::shared_ptr ds = CMUArctic(folder_path); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset. + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row. + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + std::string_view transcript_idx, utterance_id_idx; + uint32_t rate = 0; + uint64_t i = 0; + + while (row.size() != 0) { + auto waveform = row["waveform"]; + auto sample_rate = row["sample_rate"]; + auto transcript = row["transcript"]; + auto utterance_id = row["utterance_id"]; + + MS_LOG(INFO) << "Tensor waveform shape: " << waveform.Shape(); + + std::shared_ptr trate; + ASSERT_OK(Tensor::CreateFromMSTensor(sample_rate, &trate)); + ASSERT_OK(trate->GetItemAt(&rate, {})); + MS_LOG(INFO) << "Audio sample rate: " << rate; + + std::shared_ptr de_transcript; + ASSERT_OK(Tensor::CreateFromMSTensor(transcript, &de_transcript)); + ASSERT_OK(de_transcript->GetItemAt(&transcript_idx, {})); + std::string s_transcript(transcript_idx); + MS_LOG(INFO) << "Tensor transcript value: " << transcript_idx; + + std::shared_ptr de_utterance_id; + ASSERT_OK(Tensor::CreateFromMSTensor(utterance_id, &de_utterance_id)); + ASSERT_OK(de_utterance_id->GetItemAt(&utterance_id_idx, {})); + std::string s_utterance_id(utterance_id_idx); + MS_LOG(INFO) << "Tensor utterance_id value: " << utterance_id_idx; + i++; + ASSERT_OK(iter->GetNextRow(&row)); + } + + EXPECT_EQ(i, 3); + + iter->Stop(); +} + +/// Feature: CMUArcticDataset +/// Description: Test CMUArcticDataset in pipeline mode +/// Expectation: The data is processed successfully +TEST_F(MindDataTestPipeline, TestCMUArcticBasicWithPipeline) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCMUArcticBasicWithPipeline."; + + // Create a CMUArcticDataset Dataset + std::string folder_path = datasets_root_path_ + "/testCMUArcticData"; + std::shared_ptr ds = CMUArctic(folder_path, "aew", std::make_shared(0, 2)); + EXPECT_NE(ds, nullptr); + auto op = transforms::PadEnd({1, 50000});; + std::vector input_columns = {"waveform"}; + std::vector output_columns = {"waveform"}; + ds = ds->Map({op}, input_columns, output_columns); + EXPECT_NE(ds, nullptr); + ds = ds->Repeat(10); + EXPECT_NE(ds, nullptr); + ds = ds->Batch(2); + EXPECT_NE(ds, nullptr); + // Create an iterator over the result of the above dataset. + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + // Iterate the dataset and get each row. + std::unordered_map row; + iter->GetNextRow(&row); + std::vector expected_utterance = {"Dog.", "Cat."}; + std::vector expected_utterance_id = {"a0001", "a0002"}; + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto waveform = row["waveform"]; + auto transcript = row["transcript"]; + auto utterance_id = row["utterance_id"]; + + std::shared_ptr de_expected_transcript; + ASSERT_OK(Tensor::CreateFromVector(expected_utterance, &de_expected_transcript)); + mindspore::MSTensor fix_expected_transcript = + mindspore::MSTensor(std::make_shared(de_expected_transcript)); + EXPECT_MSTENSOR_EQ(transcript, fix_expected_transcript); + + std::shared_ptr de_expected_utterance_id; + ASSERT_OK(Tensor::CreateFromVector(expected_utterance_id, &de_expected_utterance_id)); + mindspore::MSTensor fix_expected_utterance_id = + mindspore::MSTensor(std::make_shared(de_expected_utterance_id)); + EXPECT_MSTENSOR_EQ(utterance_id, fix_expected_utterance_id); + + ASSERT_OK(iter->GetNextRow(&row)); + } + + EXPECT_EQ(i, 10); + + iter->Stop(); +} + +/// Feature: CMUArcticDataset +/// Description: Test CMUArcticDataset with non-existing dataset directory +/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr +TEST_F(MindDataTestPipeline, TestCMUArcticError) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCMUArcticError."; + + // Create a CMUArctic Dataset with non-existing dataset dir + std::shared_ptr ds0 = CMUArctic("NotExistFile"); + EXPECT_NE(ds0, nullptr); + + // Create an iterator over the result of the above dataset + std::shared_ptr iter0 = ds0->CreateIterator(); + // Expect failure: invalid CMUArctic input + EXPECT_EQ(iter0, nullptr); + + // Create a CMUArctic Dataset with invalid string of dataset dir + std::shared_ptr ds1 = CMUArctic(":*?\"<>|`&;'"); + EXPECT_NE(ds1, nullptr); + + // Create an iterator over the result of the above dataset + std::shared_ptr iter1 = ds1->CreateIterator(); + // Expect failure: invalid CMUArctic input + EXPECT_EQ(iter1, nullptr); +} + +/// Feature: CMUArcticDataset +/// Description: Test CMUArcticDataset Getters method +/// Expectation: Output is equal to the expected output +TEST_F(MindDataTestPipeline, TestCMUArcticGetters) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCMUArcticGetters."; + + std::string folder_path = datasets_root_path_ + "/testCMUArcticData"; + // Create a CMUArctic Dataset. + std::shared_ptr ds1 = CMUArctic(folder_path); + std::shared_ptr ds2 = CMUArctic(folder_path, "aew"); + + std::vector column_names = {"waveform", "sample_rate", "transcript", "utterance_id"}; + + EXPECT_NE(ds1, nullptr); + EXPECT_EQ(ds1->GetDatasetSize(), 3); + EXPECT_EQ(ds1->GetColumnNames(), column_names); + + EXPECT_NE(ds2, nullptr); + EXPECT_EQ(ds2->GetDatasetSize(), 3); + EXPECT_EQ(ds2->GetColumnNames(), column_names); +} + +/// Feature: CMUArcticDataset +/// Description: Test CMUArcticDataset with invalid name +/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr +TEST_F(MindDataTestPipeline, TestCMUArcticWithInvalidNameError) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCMUArcticWithInvalidNameError."; + + std::string folder_path = datasets_root_path_ + "/testCMUArcticData"; + // Create a CMUArctic Dataset. + std::shared_ptr ds1 = CMUArctic(folder_path, "----"); + EXPECT_NE(ds1, nullptr); + + // Create an iterator over the result of the above dataset + std::shared_ptr iter1 = ds1->CreateIterator(); + // Expect failure: invalid CMUArctic input, invalid name + EXPECT_EQ(iter1, nullptr); + + std::shared_ptr ds2 = CMUArctic(folder_path, "csacs"); + EXPECT_NE(ds2, nullptr); + + // Create an iterator over the result of the above dataset + std::shared_ptr iter2 = ds2->CreateIterator(); + // Expect failure: invalid CMUArctic input, invalid name + EXPECT_EQ(iter2, nullptr); +} + +/// Feature: CMUArcticDataset +/// Description: Test CMUArcticDataset with null sampler +/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr +TEST_F(MindDataTestPipeline, TestCMUArcticWithNullSamplerError) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCMUArcticWithNullSamplerError."; + + std::string folder_path = datasets_root_path_ + "/testCMUArcticData"; + // Create a CMUArctic Dataset. + std::shared_ptr ds = CMUArctic(folder_path, "aew", nullptr); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + std::shared_ptr iter = ds->CreateIterator(); + // Expect failure: invalid CMUArctic input, sampler cannot be nullptr + EXPECT_EQ(iter, nullptr); +} + +/// Feature: CMUArcticDataset +/// Description: Test CMUArcticDataset with SequentialSampler +/// Expectation: The data is processed successfully +TEST_F(MindDataTestPipeline, TestCMUArcticNumSamplers) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCMUArcticWithSequentialSampler."; + + std::string folder_path = datasets_root_path_ + "/testCMUArcticData"; + // Create a CMUArctic Dataset. + std::shared_ptr ds = CMUArctic(folder_path, "aew", std::make_shared(0, 2)); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset. + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row. + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + std::string_view transcript_idx, utterance_id_idx; + std::vector expected_utterance = {"Dog.", "Cat."}; + std::vector expected_utterance_id = {"a0001", "a0002"}; + uint32_t rate = 0; + uint64_t i = 0; + + while (row.size() != 0) { + auto waveform = row["waveform"]; + auto sample_rate = row["sample_rate"]; + auto transcript = row["transcript"]; + auto utterance_id = row["utterance_id"]; + + MS_LOG(INFO) << "Tensor waveform shape: " << waveform.Shape(); + + std::shared_ptr trate; + ASSERT_OK(Tensor::CreateFromMSTensor(sample_rate, &trate)); + ASSERT_OK(trate->GetItemAt(&rate, {})); + EXPECT_EQ(rate, 16000); + MS_LOG(INFO) << "Tensor sample rate: " << rate; + + std::shared_ptr de_transcript; + ASSERT_OK(Tensor::CreateFromMSTensor(transcript, &de_transcript)); + ASSERT_OK(de_transcript->GetItemAt(&transcript_idx, {})); + std::string s_transcript(transcript_idx); + EXPECT_STREQ(s_transcript.c_str(), expected_utterance[i].c_str()); + MS_LOG(INFO) << "Tensor transcript value: " << transcript_idx; + + std::shared_ptr de_utterance_id; + ASSERT_OK(Tensor::CreateFromMSTensor(utterance_id, &de_utterance_id)); + ASSERT_OK(de_utterance_id->GetItemAt(&utterance_id_idx, {})); + std::string s_utterance_id(utterance_id_idx); + EXPECT_STREQ(s_utterance_id.c_str(), expected_utterance_id[i].c_str()); + MS_LOG(INFO) << "Tensor utterance_id value: " << utterance_id_idx; + i++; + ASSERT_OK(iter->GetNextRow(&row)); + } + + EXPECT_EQ(i, 2); + + iter->Stop(); +} diff --git a/tests/ut/cpp/dataset/c_api_dataset_gtzan_test.cc b/tests/ut/cpp/dataset/c_api_dataset_gtzan_test.cc index 21b953786f2..73d49ecb8c2 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_gtzan_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_gtzan_test.cc @@ -1,270 +1,270 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "common/common.h" - -#include "include/dataset/datasets.h" -#include "include/dataset/transforms.h" - -using namespace mindspore::dataset; -using mindspore::dataset::Tensor; - -class MindDataTestPipeline : public UT::DatasetOpTesting { - protected: -}; - -/// Feature: GTZANDataset -/// Description: Test GTZAN -/// Expectation: Get correct GTZAN dataset -TEST_F(MindDataTestPipeline, TestGTZANBasic) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGTZANBasic."; - - std::string file_path = datasets_root_path_ + "/testGTZANData"; - // Create a GTZAN Dataset - std::shared_ptr ds = GTZAN(file_path); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset. - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row. - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - std::string_view label_idx; - uint32_t rate = 0; - uint64_t i = 0; - - while (row.size() != 0) { - i++; - auto waveform = row["waveform"]; - auto label = row["label"]; - auto sample_rate = row["sample_rate"]; - MS_LOG(INFO) << "Tensor waveform shape: " << waveform.Shape(); - - std::shared_ptr trate; - ASSERT_OK(Tensor::CreateFromMSTensor(sample_rate, &trate)); - ASSERT_OK(trate->GetItemAt(&rate, {})); - EXPECT_EQ(rate, 22050); - MS_LOG(INFO) << "Tensor label rate: " << rate; - - std::shared_ptr de_label; - ASSERT_OK(Tensor::CreateFromMSTensor(label, &de_label)); - ASSERT_OK(de_label->GetItemAt(&label_idx, {})); - std::string s_label(label_idx); - std::string expected("blues"); - EXPECT_STREQ(s_label.c_str(), expected.c_str()); - MS_LOG(INFO) << "Tensor label value: " << label_idx; - ASSERT_OK(iter->GetNextRow(&row)); - } - - EXPECT_EQ(i, 3); - - // Manually terminate the pipeline - iter->Stop(); -} - -/// Feature: GTZANDataset -/// Description: Test GTZAN with Pipeline -/// Expectation: Get correct GTZAN dataset -TEST_F(MindDataTestPipeline, TestGTZANBasicWithPipeline) { - MS_LOG(INFO) << "Doing DataSetOpBatchTest-TestGTZANBasicWithPipeline."; - - // Create a GTZANDataset Dataset. - std::string folder_path = datasets_root_path_ + "/testGTZANData"; - std::shared_ptr ds = GTZAN(folder_path, "all", std::make_shared(false, 2)); - EXPECT_NE(ds, nullptr); - auto op = transforms::PadEnd({1, 50000}); - std::vector input_columns = {"waveform"}; - std::vector output_columns = {"waveform"}; - ds = ds->Map({op}, input_columns, output_columns); - EXPECT_NE(ds, nullptr); - ds = ds->Repeat(10); - EXPECT_NE(ds, nullptr); - ds = ds->Batch(5); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset. - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - // Iterate the dataset and get each row. - std::unordered_map row; - iter->GetNextRow(&row); - std::vector expected_rate = {22050, 22050, 22050, 22050, 22050}; - std::vector expected_label = {"blues", "blues", "blues", "blues", "blues"}; - - uint64_t i = 0; - while (row.size() != 0) { - i++; - auto waveform = row["waveform"]; - auto label = row["label"]; - auto sample_rate = row["sample_rate"]; - - std::shared_ptr de_expected_rate; - ASSERT_OK(Tensor::CreateFromVector(expected_rate, &de_expected_rate)); - mindspore::MSTensor fix_expected_rate = - mindspore::MSTensor(std::make_shared(de_expected_rate)); - EXPECT_MSTENSOR_EQ(sample_rate, fix_expected_rate); - - std::shared_ptr de_expected_label; - ASSERT_OK(Tensor::CreateFromVector(expected_label, &de_expected_label)); - mindspore::MSTensor fix_expected_label = - mindspore::MSTensor(std::make_shared(de_expected_label)); - EXPECT_MSTENSOR_EQ(label, fix_expected_label); - - ASSERT_OK(iter->GetNextRow(&row)); - } - - EXPECT_EQ(i, 4); - // Manually terminate the pipeline. - iter->Stop(); -} - -/// Feature: GTZANDataset -/// Description: Test GTZAN with invalid directory -/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr -TEST_F(MindDataTestPipeline, TestGTZANError) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGTZANError."; - - // Create a GTZAN Dataset with non-existing dataset dir. - std::shared_ptr ds0 = GTZAN("NotExistFile"); - EXPECT_NE(ds0, nullptr); - - // Create an iterator over the result of the above dataset. - std::shared_ptr iter0 = ds0->CreateIterator(); - // Expect failure: invalid GTZAN30k input. - EXPECT_EQ(iter0, nullptr); - - // Create a GTZAN Dataset with invalid string of dataset dir. - std::shared_ptr ds1 = GTZAN(":*?\"<>|`&;'"); - EXPECT_NE(ds1, nullptr); - - // Create an iterator over the result of the above dataset. - std::shared_ptr iter1 = ds1->CreateIterator(); - // Expect failure: invalid GTZAN input. - EXPECT_EQ(iter1, nullptr); -} - -/// Feature: GTZANDataset -/// Description: Test GTZAN with Getters -/// Expectation: Output is equal to the expected output -TEST_F(MindDataTestPipeline, TestGTZANGetters) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGTZANGetters."; - - std::string folder_path = datasets_root_path_ + "/testGTZANData"; - // Create a GTZAN Dataset. - std::shared_ptr ds1 = GTZAN(folder_path); - std::shared_ptr ds2 = GTZAN(folder_path, "all"); - std::shared_ptr ds3 = GTZAN(folder_path, "valid"); - - std::vector column_names = {"waveform", "sample_rate", "label"}; - - EXPECT_NE(ds1, nullptr); - EXPECT_EQ(ds1->GetDatasetSize(), 3); - EXPECT_EQ(ds1->GetColumnNames(), column_names); - - EXPECT_NE(ds2, nullptr); - EXPECT_EQ(ds2->GetDatasetSize(), 3); - EXPECT_EQ(ds2->GetColumnNames(), column_names); - - EXPECT_NE(ds3, nullptr); - EXPECT_EQ(ds3->GetDatasetSize(), 3); - EXPECT_EQ(ds3->GetColumnNames(), column_names); -} - -/// Feature: GTZANDataset -/// Description: Test GTZAN dataset with invalid usage -/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr -TEST_F(MindDataTestPipeline, TestGTZANWithInvalidUsageError) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGTZANWithInvalidUsageError."; - - std::string folder_path = datasets_root_path_ + "/testGTZANData"; - // Create a GTZAN Dataset. - std::shared_ptr ds1 = GTZAN(folder_path, "----"); - EXPECT_NE(ds1, nullptr); - - // Create an iterator over the result of the above dataset. - std::shared_ptr iter1 = ds1->CreateIterator(); - - EXPECT_EQ(iter1, nullptr); - - std::shared_ptr ds2 = GTZAN(folder_path, "csacs"); - EXPECT_NE(ds2, nullptr); - - // Create an iterator over the result of the above dataset. - std::shared_ptr iter2 = ds2->CreateIterator(); - EXPECT_EQ(iter2, nullptr); -} - -/// Feature: GTZANDataset -/// Description: Test GTZAN dataset with null sampler -/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr -TEST_F(MindDataTestPipeline, TestGTZANWithNullSamplerError) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGTZANWithNullSamplerError."; - - std::string folder_path = datasets_root_path_ + "/testGTZANData"; - // Create a GTZAN Dataset. - std::shared_ptr ds = GTZAN(folder_path, "all ", nullptr); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset. - std::shared_ptr iter = ds->CreateIterator(); - // Expect failure: invalid GTZAN input, sampler cannot be nullptr. - EXPECT_EQ(iter, nullptr); -} - -/// Feature: GTZANDataset -/// Description: Test GTZAN with sequential sampler -/// Expectation: Get correct GTZAN dataset -TEST_F(MindDataTestPipeline, TestGTZANNumSamplers) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGTZANWithSequentialSampler."; - - std::string folder_path = datasets_root_path_ + "/testGTZANData"; - // Create a GTZAN Dataset. - std::shared_ptr ds = GTZAN(folder_path, "all", std::make_shared(0, 2)); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset. - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row. - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - uint32_t rate = 0; - uint64_t i = 0; - - while (row.size() != 0) { - auto waveform = row["waveform"]; - auto sample_rate = row["sample_rate"]; - - MS_LOG(INFO) << "Tensor waveform shape: " << waveform.Shape(); - - std::shared_ptr t_rate; - ASSERT_OK(Tensor::CreateFromMSTensor(sample_rate, &t_rate)); - ASSERT_OK(t_rate->GetItemAt(&rate, {})); - EXPECT_EQ(rate, 22050); - MS_LOG(INFO) << "Tensor sample rate: " << rate; - i++; - ASSERT_OK(iter->GetNextRow(&row)); - } - - EXPECT_EQ(i, 2); - - iter->Stop(); -} +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/common.h" + +#include "include/dataset/datasets.h" +#include "include/dataset/transforms.h" + +using namespace mindspore::dataset; +using mindspore::dataset::Tensor; + +class MindDataTestPipeline : public UT::DatasetOpTesting { + protected: +}; + +/// Feature: GTZANDataset +/// Description: Test GTZAN +/// Expectation: Get correct GTZAN dataset +TEST_F(MindDataTestPipeline, TestGTZANBasic) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGTZANBasic."; + + std::string file_path = datasets_root_path_ + "/testGTZANData"; + // Create a GTZAN Dataset + std::shared_ptr ds = GTZAN(file_path); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset. + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row. + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + std::string_view label_idx; + uint32_t rate = 0; + uint64_t i = 0; + + while (row.size() != 0) { + i++; + auto waveform = row["waveform"]; + auto label = row["label"]; + auto sample_rate = row["sample_rate"]; + MS_LOG(INFO) << "Tensor waveform shape: " << waveform.Shape(); + + std::shared_ptr trate; + ASSERT_OK(Tensor::CreateFromMSTensor(sample_rate, &trate)); + ASSERT_OK(trate->GetItemAt(&rate, {})); + EXPECT_EQ(rate, 22050); + MS_LOG(INFO) << "Tensor label rate: " << rate; + + std::shared_ptr de_label; + ASSERT_OK(Tensor::CreateFromMSTensor(label, &de_label)); + ASSERT_OK(de_label->GetItemAt(&label_idx, {})); + std::string s_label(label_idx); + std::string expected("blues"); + EXPECT_STREQ(s_label.c_str(), expected.c_str()); + MS_LOG(INFO) << "Tensor label value: " << label_idx; + ASSERT_OK(iter->GetNextRow(&row)); + } + + EXPECT_EQ(i, 3); + + // Manually terminate the pipeline + iter->Stop(); +} + +/// Feature: GTZANDataset +/// Description: Test GTZAN with Pipeline +/// Expectation: Get correct GTZAN dataset +TEST_F(MindDataTestPipeline, TestGTZANBasicWithPipeline) { + MS_LOG(INFO) << "Doing DataSetOpBatchTest-TestGTZANBasicWithPipeline."; + + // Create a GTZANDataset Dataset. + std::string folder_path = datasets_root_path_ + "/testGTZANData"; + std::shared_ptr ds = GTZAN(folder_path, "all", std::make_shared(false, 2)); + EXPECT_NE(ds, nullptr); + auto op = transforms::PadEnd({1, 50000}); + std::vector input_columns = {"waveform"}; + std::vector output_columns = {"waveform"}; + ds = ds->Map({op}, input_columns, output_columns); + EXPECT_NE(ds, nullptr); + ds = ds->Repeat(10); + EXPECT_NE(ds, nullptr); + ds = ds->Batch(5); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset. + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + // Iterate the dataset and get each row. + std::unordered_map row; + iter->GetNextRow(&row); + std::vector expected_rate = {22050, 22050, 22050, 22050, 22050}; + std::vector expected_label = {"blues", "blues", "blues", "blues", "blues"}; + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto waveform = row["waveform"]; + auto label = row["label"]; + auto sample_rate = row["sample_rate"]; + + std::shared_ptr de_expected_rate; + ASSERT_OK(Tensor::CreateFromVector(expected_rate, &de_expected_rate)); + mindspore::MSTensor fix_expected_rate = + mindspore::MSTensor(std::make_shared(de_expected_rate)); + EXPECT_MSTENSOR_EQ(sample_rate, fix_expected_rate); + + std::shared_ptr de_expected_label; + ASSERT_OK(Tensor::CreateFromVector(expected_label, &de_expected_label)); + mindspore::MSTensor fix_expected_label = + mindspore::MSTensor(std::make_shared(de_expected_label)); + EXPECT_MSTENSOR_EQ(label, fix_expected_label); + + ASSERT_OK(iter->GetNextRow(&row)); + } + + EXPECT_EQ(i, 4); + // Manually terminate the pipeline. + iter->Stop(); +} + +/// Feature: GTZANDataset +/// Description: Test GTZAN with invalid directory +/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr +TEST_F(MindDataTestPipeline, TestGTZANError) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGTZANError."; + + // Create a GTZAN Dataset with non-existing dataset dir. + std::shared_ptr ds0 = GTZAN("NotExistFile"); + EXPECT_NE(ds0, nullptr); + + // Create an iterator over the result of the above dataset. + std::shared_ptr iter0 = ds0->CreateIterator(); + // Expect failure: invalid GTZAN30k input. + EXPECT_EQ(iter0, nullptr); + + // Create a GTZAN Dataset with invalid string of dataset dir. + std::shared_ptr ds1 = GTZAN(":*?\"<>|`&;'"); + EXPECT_NE(ds1, nullptr); + + // Create an iterator over the result of the above dataset. + std::shared_ptr iter1 = ds1->CreateIterator(); + // Expect failure: invalid GTZAN input. + EXPECT_EQ(iter1, nullptr); +} + +/// Feature: GTZANDataset +/// Description: Test GTZAN with Getters +/// Expectation: Output is equal to the expected output +TEST_F(MindDataTestPipeline, TestGTZANGetters) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGTZANGetters."; + + std::string folder_path = datasets_root_path_ + "/testGTZANData"; + // Create a GTZAN Dataset. + std::shared_ptr ds1 = GTZAN(folder_path); + std::shared_ptr ds2 = GTZAN(folder_path, "all"); + std::shared_ptr ds3 = GTZAN(folder_path, "valid"); + + std::vector column_names = {"waveform", "sample_rate", "label"}; + + EXPECT_NE(ds1, nullptr); + EXPECT_EQ(ds1->GetDatasetSize(), 3); + EXPECT_EQ(ds1->GetColumnNames(), column_names); + + EXPECT_NE(ds2, nullptr); + EXPECT_EQ(ds2->GetDatasetSize(), 3); + EXPECT_EQ(ds2->GetColumnNames(), column_names); + + EXPECT_NE(ds3, nullptr); + EXPECT_EQ(ds3->GetDatasetSize(), 3); + EXPECT_EQ(ds3->GetColumnNames(), column_names); +} + +/// Feature: GTZANDataset +/// Description: Test GTZAN dataset with invalid usage +/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr +TEST_F(MindDataTestPipeline, TestGTZANWithInvalidUsageError) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGTZANWithInvalidUsageError."; + + std::string folder_path = datasets_root_path_ + "/testGTZANData"; + // Create a GTZAN Dataset. + std::shared_ptr ds1 = GTZAN(folder_path, "----"); + EXPECT_NE(ds1, nullptr); + + // Create an iterator over the result of the above dataset. + std::shared_ptr iter1 = ds1->CreateIterator(); + + EXPECT_EQ(iter1, nullptr); + + std::shared_ptr ds2 = GTZAN(folder_path, "csacs"); + EXPECT_NE(ds2, nullptr); + + // Create an iterator over the result of the above dataset. + std::shared_ptr iter2 = ds2->CreateIterator(); + EXPECT_EQ(iter2, nullptr); +} + +/// Feature: GTZANDataset +/// Description: Test GTZAN dataset with null sampler +/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr +TEST_F(MindDataTestPipeline, TestGTZANWithNullSamplerError) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGTZANWithNullSamplerError."; + + std::string folder_path = datasets_root_path_ + "/testGTZANData"; + // Create a GTZAN Dataset. + std::shared_ptr ds = GTZAN(folder_path, "all ", nullptr); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset. + std::shared_ptr iter = ds->CreateIterator(); + // Expect failure: invalid GTZAN input, sampler cannot be nullptr. + EXPECT_EQ(iter, nullptr); +} + +/// Feature: GTZANDataset +/// Description: Test GTZAN with sequential sampler +/// Expectation: Get correct GTZAN dataset +TEST_F(MindDataTestPipeline, TestGTZANNumSamplers) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGTZANWithSequentialSampler."; + + std::string folder_path = datasets_root_path_ + "/testGTZANData"; + // Create a GTZAN Dataset. + std::shared_ptr ds = GTZAN(folder_path, "all", std::make_shared(0, 2)); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset. + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row. + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + uint32_t rate = 0; + uint64_t i = 0; + + while (row.size() != 0) { + auto waveform = row["waveform"]; + auto sample_rate = row["sample_rate"]; + + MS_LOG(INFO) << "Tensor waveform shape: " << waveform.Shape(); + + std::shared_ptr t_rate; + ASSERT_OK(Tensor::CreateFromMSTensor(sample_rate, &t_rate)); + ASSERT_OK(t_rate->GetItemAt(&rate, {})); + EXPECT_EQ(rate, 22050); + MS_LOG(INFO) << "Tensor sample rate: " << rate; + i++; + ASSERT_OK(iter->GetNextRow(&row)); + } + + EXPECT_EQ(i, 2); + + iter->Stop(); +} diff --git a/tests/ut/cpp/dataset/c_api_dataset_libri_tts_test.cc b/tests/ut/cpp/dataset/c_api_dataset_libri_tts_test.cc index 2cdc8960f2e..33b25d9d572 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_libri_tts_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_libri_tts_test.cc @@ -1,309 +1,309 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "common/common.h" -#include "minddata/dataset/include/dataset/datasets.h" -#include "include/dataset/transforms.h" - -using namespace mindspore::dataset; -using mindspore::dataset::Tensor; - -class MindDataTestPipeline : public UT::DatasetOpTesting { - protected: -}; - -/// Feature: LibriTTSDataset -/// Description: Test LibriTTSDataset basic usage -/// Expectation: Get correct LibriTTS dataset -TEST_F(MindDataTestPipeline, TestLibriTTSBasic) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLibriTTSBasic."; - - std::string folder_path = datasets_root_path_ + "/testLibriTTSData"; - std::shared_ptr ds = LibriTTS(folder_path); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset. - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row. - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - uint64_t i = 0; - - while (row.size() != 0) { - auto waveform = row["waveform"]; - auto sample_rate = row["sample_rate"]; - auto original_text = row["original_text"]; - auto normalized_text = row["normalized_text"]; - auto speaker_id = row["speaker_id"]; - auto chapter_id = row["chapter_id"]; - auto utterance_id = row["utterance_id"]; - i++; - ASSERT_OK(iter->GetNextRow(&row)); - } - EXPECT_EQ(i, 3); - iter->Stop(); -} - -/// Feature: LibriTTSDataset -/// Description: Test LibriTTSDataset with pipeline mode -/// Expectation: Get correct LibriTTS dataset -TEST_F(MindDataTestPipeline, TestLibriTTSBasicWithPipeline) { - MS_LOG(INFO) << "Doing DataSetOpBatchTest-TestLibriTTSBasicWithPipeline."; - - // Create a LibriTTSDataset Dataset - std::string folder_path = datasets_root_path_ + "/testLibriTTSData"; - std::shared_ptr ds = LibriTTS(folder_path, "train-clean-100", std::make_shared(0, 2)); - EXPECT_NE(ds, nullptr); - auto op = transforms::PadEnd({1, 500000}); - std::vector input_columns = {"waveform"}; - std::vector output_columns = {"waveform"}; - ds = ds->Map({op}, input_columns, output_columns); - EXPECT_NE(ds, nullptr); - ds = ds->Repeat(5); - EXPECT_NE(ds, nullptr); - ds = ds->Batch(2); - EXPECT_NE(ds, nullptr); - // Create an iterator over the result of the above dataset. - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - // Iterate the dataset and get each row. - std::unordered_map row; - iter->GetNextRow(&row); - std::vector expected_original_text = {"good morning", "good afternoon"}; - std::vector expected_normalized_text = {"Good morning", "Good afternoon"}; - std::vector expected_speaker_id = {2506, 2506}; - std::vector expected_chapter_id = {11267, 11267}; - std::vector expected_utterance_id = {"2506_11267_000001_000000", "2506_11267_000002_000000"}; - uint64_t i = 0; - while (row.size() != 0) { - i++; - auto waveform = row["waveform"]; - auto original_text = row["original_text"]; - auto normalized_text = row["normalized_text"]; - auto sample_rate = row["sample_rate"]; - auto speaker_id = row["speaker_id"]; - auto chapter_id = row["chapter_id"]; - auto utterance_id = row["utterance_id"]; - - std::shared_ptr de_original_text; - ASSERT_OK(Tensor::CreateFromVector(expected_original_text, &de_original_text)); - mindspore::MSTensor fix_original_text = - mindspore::MSTensor(std::make_shared(de_original_text)); - EXPECT_MSTENSOR_EQ(original_text, fix_original_text); - - std::shared_ptr de_normalized_text; - ASSERT_OK(Tensor::CreateFromVector(expected_normalized_text, &de_normalized_text)); - mindspore::MSTensor fix_normalized_text = - mindspore::MSTensor(std::make_shared(de_normalized_text)); - EXPECT_MSTENSOR_EQ(normalized_text, fix_normalized_text); - - std::shared_ptr de_expected_speaker_id; - ASSERT_OK(Tensor::CreateFromVector(expected_speaker_id, &de_expected_speaker_id)); - mindspore::MSTensor fix_expected_speaker_id = - mindspore::MSTensor(std::make_shared(de_expected_speaker_id)); - EXPECT_MSTENSOR_EQ(speaker_id, fix_expected_speaker_id); - - std::shared_ptr de_expected_chapter_id; - ASSERT_OK(Tensor::CreateFromVector(expected_chapter_id, &de_expected_chapter_id)); - mindspore::MSTensor fix_expected_chapter_id = - mindspore::MSTensor(std::make_shared(de_expected_chapter_id)); - EXPECT_MSTENSOR_EQ(chapter_id, fix_expected_chapter_id); - - std::shared_ptr de_expected_utterance_id; - ASSERT_OK(Tensor::CreateFromVector(expected_utterance_id, &de_expected_utterance_id)); - mindspore::MSTensor fix_expected_utterance_id = - mindspore::MSTensor(std::make_shared(de_expected_utterance_id)); - EXPECT_MSTENSOR_EQ(utterance_id, fix_expected_utterance_id); - - ASSERT_OK(iter->GetNextRow(&row)); - } - - EXPECT_EQ(i, 5); - iter->Stop(); -} - -/// Feature: LibriTTSDataset -/// Description: Test LibriTTSDataset with invalid directory -/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr -TEST_F(MindDataTestPipeline, TestLibriTTSError) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLibriTTSError."; - - // Create a LibriTTS Dataset with non-existing dataset dir - std::shared_ptr ds0 = LibriTTS("NotExistFile"); - EXPECT_NE(ds0, nullptr); - - // Create an iterator over the result of the above dataset - std::shared_ptr iter0 = ds0->CreateIterator(); - // Expect failure: invalid LibriTTS input - EXPECT_EQ(iter0, nullptr); - - // Create a LibriTTS Dataset with invalid string of dataset dir - std::shared_ptr ds1 = LibriTTS(":*?\"<>|`&;'"); - EXPECT_NE(ds1, nullptr); - - // Create an iterator over the result of the above dataset - std::shared_ptr iter1 = ds1->CreateIterator(); - // Expect failure: invalid LibriTTS input - EXPECT_EQ(iter1, nullptr); -} - -/// Feature: LibriTTSDataset -/// Description: Test LibriTTSDataset with Getters -/// Expectation: Output is equal to the expected output -TEST_F(MindDataTestPipeline, TestLibriTTSGetters) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLibriTTSGetters."; - - std::string folder_path = datasets_root_path_ + "/testLibriTTSData"; - // Create a LibriTTS Dataset. - std::shared_ptr ds1 = LibriTTS(folder_path); - std::shared_ptr ds2 = LibriTTS(folder_path, "train-clean-100"); - - std::vector column_names = {"waveform", "sample_rate", "original_text", "normalized_text", - "speaker_id", "chapter_id", "utterance_id"}; - - EXPECT_NE(ds1, nullptr); - EXPECT_EQ(ds1->GetDatasetSize(), 3); - EXPECT_EQ(ds1->GetColumnNames(), column_names); - - EXPECT_NE(ds2, nullptr); - EXPECT_EQ(ds2->GetDatasetSize(), 3); - EXPECT_EQ(ds2->GetColumnNames(), column_names); -} - -/// Feature: LibriTTSDataset -/// Description: Test LibriTTSDataset with invalid usage -/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr -TEST_F(MindDataTestPipeline, TestLibriTTSWithInvalidUsageError) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLibriTTSWithInvalidUsageError."; - - std::string folder_path = datasets_root_path_ + "/testLibriTTSData"; - // Create a LibriTTS Dataset. - std::shared_ptr ds1 = LibriTTS(folder_path, "----"); - EXPECT_NE(ds1, nullptr); - - // Create an iterator over the result of the above dataset - std::shared_ptr iter1 = ds1->CreateIterator(); - // Expect failure: invalid LibriTTS input, sampler cannot be nullptr - EXPECT_EQ(iter1, nullptr); - - std::shared_ptr ds2 = LibriTTS(folder_path, "csacs"); - EXPECT_NE(ds2, nullptr); - - // Create an iterator over the result of the above dataset - std::shared_ptr iter2 = ds2->CreateIterator(); - // Expect failure: invalid LibriTTS input, sampler cannot be nullptr - EXPECT_EQ(iter2, nullptr); -} - -/// Feature: LibriTTSDataset -/// Description: Test LibriTTSDataset with null sampler -/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr -TEST_F(MindDataTestPipeline, TestLibriTTSWithNullSamplerError) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLibriTTSWithNullSamplerError."; - - std::string folder_path = datasets_root_path_ + "/testLibriTTSData"; - // Create a LibriTTS Dataset. - std::shared_ptr ds = LibriTTS(folder_path, "all", nullptr); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset - std::shared_ptr iter = ds->CreateIterator(); - // Expect failure: invalid LibriTTS input, sampler cannot be nullptr - EXPECT_EQ(iter, nullptr); -} - -/// Feature: LibriTTSDataset -/// Description: Test LibriTTSDataset with SequentialSampler -/// Expectation: Get correct LibriTTS dataset -TEST_F(MindDataTestPipeline, TestLibriTTSSequentialSamplers) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLibriTTSSequentialSamplers."; - - std::string folder_path = datasets_root_path_ + "/testLibriTTSData"; - std::shared_ptr ds = LibriTTS(folder_path, "all", std::make_shared(0, 2)); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset. - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row. - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - std::string_view original_text_idx, normalized_text_idx, utterance_id_idx; - uint32_t speaker_idx_id = 0, chapter_idx_id = 0; - std::vector expected_original_text = {"good morning", "good afternoon"}; - std::vector expected_normalized_text = {"Good morning", "Good afternoon"}; - std::vector expected_speaker_id = {2506, 2506}; - std::vector expected_chapter_id = {11267, 11267}; - std::vector expected_utterance_id = {"2506_11267_000001_000000", "2506_11267_000002_000000"}; - uint32_t rate = 0; - uint64_t i = 0; - while (row.size() != 0) { - auto waveform = row["waveform"]; - auto sample_rate = row["sample_rate"]; - auto original_text = row["original_text"]; - auto normalized_text = row["normalized_text"]; - auto speaker_id = row["speaker_id"]; - auto chapter_id = row["chapter_id"]; - auto utterance_id = row["utterance_id"]; - - MS_LOG(INFO) << "Tensor waveform shape: " << waveform.Shape(); - - std::shared_ptr trate; - ASSERT_OK(Tensor::CreateFromMSTensor(sample_rate, &trate)); - ASSERT_OK(trate->GetItemAt(&rate, {})); - EXPECT_EQ(rate, 24000); - - std::shared_ptr de_original_text; - ASSERT_OK(Tensor::CreateFromMSTensor(original_text, &de_original_text)); - ASSERT_OK(de_original_text->GetItemAt(&original_text_idx, {})); - std::string s_original_text(original_text_idx); - EXPECT_STREQ(s_original_text.c_str(), expected_original_text[i].c_str()); - - std::shared_ptr de_normalized_text; - ASSERT_OK(Tensor::CreateFromMSTensor(normalized_text, &de_normalized_text)); - ASSERT_OK(de_normalized_text->GetItemAt(&normalized_text_idx, {})); - std::string s_normalized_text(normalized_text_idx); - EXPECT_STREQ(s_normalized_text.c_str(), expected_normalized_text[i].c_str()); - - std::shared_ptr de_speaker_id; - ASSERT_OK(Tensor::CreateFromMSTensor(speaker_id, &de_speaker_id)); - ASSERT_OK(de_speaker_id->GetItemAt(&speaker_idx_id, {})); - EXPECT_EQ(speaker_idx_id, expected_speaker_id[i]); - - std::shared_ptr de_chapter_id; - ASSERT_OK(Tensor::CreateFromMSTensor(chapter_id, &de_chapter_id)); - ASSERT_OK(de_chapter_id->GetItemAt(&chapter_idx_id, {})); - EXPECT_EQ(chapter_idx_id, expected_chapter_id[i]); - - std::shared_ptr de_utterance_id; - ASSERT_OK(Tensor::CreateFromMSTensor(utterance_id, &de_utterance_id)); - ASSERT_OK(de_utterance_id->GetItemAt(&utterance_id_idx, {})); - std::string s_utterance_id(utterance_id_idx); - EXPECT_STREQ(s_utterance_id.c_str(), expected_utterance_id[i].c_str()); - - i++; - ASSERT_OK(iter->GetNextRow(&row)); - } - - EXPECT_EQ(i, 2); - - iter->Stop(); -} +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/common.h" +#include "minddata/dataset/include/dataset/datasets.h" +#include "include/dataset/transforms.h" + +using namespace mindspore::dataset; +using mindspore::dataset::Tensor; + +class MindDataTestPipeline : public UT::DatasetOpTesting { + protected: +}; + +/// Feature: LibriTTSDataset +/// Description: Test LibriTTSDataset basic usage +/// Expectation: Get correct LibriTTS dataset +TEST_F(MindDataTestPipeline, TestLibriTTSBasic) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLibriTTSBasic."; + + std::string folder_path = datasets_root_path_ + "/testLibriTTSData"; + std::shared_ptr ds = LibriTTS(folder_path); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset. + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row. + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + uint64_t i = 0; + + while (row.size() != 0) { + auto waveform = row["waveform"]; + auto sample_rate = row["sample_rate"]; + auto original_text = row["original_text"]; + auto normalized_text = row["normalized_text"]; + auto speaker_id = row["speaker_id"]; + auto chapter_id = row["chapter_id"]; + auto utterance_id = row["utterance_id"]; + i++; + ASSERT_OK(iter->GetNextRow(&row)); + } + EXPECT_EQ(i, 3); + iter->Stop(); +} + +/// Feature: LibriTTSDataset +/// Description: Test LibriTTSDataset with pipeline mode +/// Expectation: Get correct LibriTTS dataset +TEST_F(MindDataTestPipeline, TestLibriTTSBasicWithPipeline) { + MS_LOG(INFO) << "Doing DataSetOpBatchTest-TestLibriTTSBasicWithPipeline."; + + // Create a LibriTTSDataset Dataset + std::string folder_path = datasets_root_path_ + "/testLibriTTSData"; + std::shared_ptr ds = LibriTTS(folder_path, "train-clean-100", std::make_shared(0, 2)); + EXPECT_NE(ds, nullptr); + auto op = transforms::PadEnd({1, 500000}); + std::vector input_columns = {"waveform"}; + std::vector output_columns = {"waveform"}; + ds = ds->Map({op}, input_columns, output_columns); + EXPECT_NE(ds, nullptr); + ds = ds->Repeat(5); + EXPECT_NE(ds, nullptr); + ds = ds->Batch(2); + EXPECT_NE(ds, nullptr); + // Create an iterator over the result of the above dataset. + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + // Iterate the dataset and get each row. + std::unordered_map row; + iter->GetNextRow(&row); + std::vector expected_original_text = {"good morning", "good afternoon"}; + std::vector expected_normalized_text = {"Good morning", "Good afternoon"}; + std::vector expected_speaker_id = {2506, 2506}; + std::vector expected_chapter_id = {11267, 11267}; + std::vector expected_utterance_id = {"2506_11267_000001_000000", "2506_11267_000002_000000"}; + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto waveform = row["waveform"]; + auto original_text = row["original_text"]; + auto normalized_text = row["normalized_text"]; + auto sample_rate = row["sample_rate"]; + auto speaker_id = row["speaker_id"]; + auto chapter_id = row["chapter_id"]; + auto utterance_id = row["utterance_id"]; + + std::shared_ptr de_original_text; + ASSERT_OK(Tensor::CreateFromVector(expected_original_text, &de_original_text)); + mindspore::MSTensor fix_original_text = + mindspore::MSTensor(std::make_shared(de_original_text)); + EXPECT_MSTENSOR_EQ(original_text, fix_original_text); + + std::shared_ptr de_normalized_text; + ASSERT_OK(Tensor::CreateFromVector(expected_normalized_text, &de_normalized_text)); + mindspore::MSTensor fix_normalized_text = + mindspore::MSTensor(std::make_shared(de_normalized_text)); + EXPECT_MSTENSOR_EQ(normalized_text, fix_normalized_text); + + std::shared_ptr de_expected_speaker_id; + ASSERT_OK(Tensor::CreateFromVector(expected_speaker_id, &de_expected_speaker_id)); + mindspore::MSTensor fix_expected_speaker_id = + mindspore::MSTensor(std::make_shared(de_expected_speaker_id)); + EXPECT_MSTENSOR_EQ(speaker_id, fix_expected_speaker_id); + + std::shared_ptr de_expected_chapter_id; + ASSERT_OK(Tensor::CreateFromVector(expected_chapter_id, &de_expected_chapter_id)); + mindspore::MSTensor fix_expected_chapter_id = + mindspore::MSTensor(std::make_shared(de_expected_chapter_id)); + EXPECT_MSTENSOR_EQ(chapter_id, fix_expected_chapter_id); + + std::shared_ptr de_expected_utterance_id; + ASSERT_OK(Tensor::CreateFromVector(expected_utterance_id, &de_expected_utterance_id)); + mindspore::MSTensor fix_expected_utterance_id = + mindspore::MSTensor(std::make_shared(de_expected_utterance_id)); + EXPECT_MSTENSOR_EQ(utterance_id, fix_expected_utterance_id); + + ASSERT_OK(iter->GetNextRow(&row)); + } + + EXPECT_EQ(i, 5); + iter->Stop(); +} + +/// Feature: LibriTTSDataset +/// Description: Test LibriTTSDataset with invalid directory +/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr +TEST_F(MindDataTestPipeline, TestLibriTTSError) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLibriTTSError."; + + // Create a LibriTTS Dataset with non-existing dataset dir + std::shared_ptr ds0 = LibriTTS("NotExistFile"); + EXPECT_NE(ds0, nullptr); + + // Create an iterator over the result of the above dataset + std::shared_ptr iter0 = ds0->CreateIterator(); + // Expect failure: invalid LibriTTS input + EXPECT_EQ(iter0, nullptr); + + // Create a LibriTTS Dataset with invalid string of dataset dir + std::shared_ptr ds1 = LibriTTS(":*?\"<>|`&;'"); + EXPECT_NE(ds1, nullptr); + + // Create an iterator over the result of the above dataset + std::shared_ptr iter1 = ds1->CreateIterator(); + // Expect failure: invalid LibriTTS input + EXPECT_EQ(iter1, nullptr); +} + +/// Feature: LibriTTSDataset +/// Description: Test LibriTTSDataset with Getters +/// Expectation: Output is equal to the expected output +TEST_F(MindDataTestPipeline, TestLibriTTSGetters) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLibriTTSGetters."; + + std::string folder_path = datasets_root_path_ + "/testLibriTTSData"; + // Create a LibriTTS Dataset. + std::shared_ptr ds1 = LibriTTS(folder_path); + std::shared_ptr ds2 = LibriTTS(folder_path, "train-clean-100"); + + std::vector column_names = {"waveform", "sample_rate", "original_text", "normalized_text", + "speaker_id", "chapter_id", "utterance_id"}; + + EXPECT_NE(ds1, nullptr); + EXPECT_EQ(ds1->GetDatasetSize(), 3); + EXPECT_EQ(ds1->GetColumnNames(), column_names); + + EXPECT_NE(ds2, nullptr); + EXPECT_EQ(ds2->GetDatasetSize(), 3); + EXPECT_EQ(ds2->GetColumnNames(), column_names); +} + +/// Feature: LibriTTSDataset +/// Description: Test LibriTTSDataset with invalid usage +/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr +TEST_F(MindDataTestPipeline, TestLibriTTSWithInvalidUsageError) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLibriTTSWithInvalidUsageError."; + + std::string folder_path = datasets_root_path_ + "/testLibriTTSData"; + // Create a LibriTTS Dataset. + std::shared_ptr ds1 = LibriTTS(folder_path, "----"); + EXPECT_NE(ds1, nullptr); + + // Create an iterator over the result of the above dataset + std::shared_ptr iter1 = ds1->CreateIterator(); + // Expect failure: invalid LibriTTS input, sampler cannot be nullptr + EXPECT_EQ(iter1, nullptr); + + std::shared_ptr ds2 = LibriTTS(folder_path, "csacs"); + EXPECT_NE(ds2, nullptr); + + // Create an iterator over the result of the above dataset + std::shared_ptr iter2 = ds2->CreateIterator(); + // Expect failure: invalid LibriTTS input, sampler cannot be nullptr + EXPECT_EQ(iter2, nullptr); +} + +/// Feature: LibriTTSDataset +/// Description: Test LibriTTSDataset with null sampler +/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr +TEST_F(MindDataTestPipeline, TestLibriTTSWithNullSamplerError) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLibriTTSWithNullSamplerError."; + + std::string folder_path = datasets_root_path_ + "/testLibriTTSData"; + // Create a LibriTTS Dataset. + std::shared_ptr ds = LibriTTS(folder_path, "all", nullptr); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + std::shared_ptr iter = ds->CreateIterator(); + // Expect failure: invalid LibriTTS input, sampler cannot be nullptr + EXPECT_EQ(iter, nullptr); +} + +/// Feature: LibriTTSDataset +/// Description: Test LibriTTSDataset with SequentialSampler +/// Expectation: Get correct LibriTTS dataset +TEST_F(MindDataTestPipeline, TestLibriTTSSequentialSamplers) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLibriTTSSequentialSamplers."; + + std::string folder_path = datasets_root_path_ + "/testLibriTTSData"; + std::shared_ptr ds = LibriTTS(folder_path, "all", std::make_shared(0, 2)); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset. + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row. + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + std::string_view original_text_idx, normalized_text_idx, utterance_id_idx; + uint32_t speaker_idx_id = 0, chapter_idx_id = 0; + std::vector expected_original_text = {"good morning", "good afternoon"}; + std::vector expected_normalized_text = {"Good morning", "Good afternoon"}; + std::vector expected_speaker_id = {2506, 2506}; + std::vector expected_chapter_id = {11267, 11267}; + std::vector expected_utterance_id = {"2506_11267_000001_000000", "2506_11267_000002_000000"}; + uint32_t rate = 0; + uint64_t i = 0; + while (row.size() != 0) { + auto waveform = row["waveform"]; + auto sample_rate = row["sample_rate"]; + auto original_text = row["original_text"]; + auto normalized_text = row["normalized_text"]; + auto speaker_id = row["speaker_id"]; + auto chapter_id = row["chapter_id"]; + auto utterance_id = row["utterance_id"]; + + MS_LOG(INFO) << "Tensor waveform shape: " << waveform.Shape(); + + std::shared_ptr trate; + ASSERT_OK(Tensor::CreateFromMSTensor(sample_rate, &trate)); + ASSERT_OK(trate->GetItemAt(&rate, {})); + EXPECT_EQ(rate, 24000); + + std::shared_ptr de_original_text; + ASSERT_OK(Tensor::CreateFromMSTensor(original_text, &de_original_text)); + ASSERT_OK(de_original_text->GetItemAt(&original_text_idx, {})); + std::string s_original_text(original_text_idx); + EXPECT_STREQ(s_original_text.c_str(), expected_original_text[i].c_str()); + + std::shared_ptr de_normalized_text; + ASSERT_OK(Tensor::CreateFromMSTensor(normalized_text, &de_normalized_text)); + ASSERT_OK(de_normalized_text->GetItemAt(&normalized_text_idx, {})); + std::string s_normalized_text(normalized_text_idx); + EXPECT_STREQ(s_normalized_text.c_str(), expected_normalized_text[i].c_str()); + + std::shared_ptr de_speaker_id; + ASSERT_OK(Tensor::CreateFromMSTensor(speaker_id, &de_speaker_id)); + ASSERT_OK(de_speaker_id->GetItemAt(&speaker_idx_id, {})); + EXPECT_EQ(speaker_idx_id, expected_speaker_id[i]); + + std::shared_ptr de_chapter_id; + ASSERT_OK(Tensor::CreateFromMSTensor(chapter_id, &de_chapter_id)); + ASSERT_OK(de_chapter_id->GetItemAt(&chapter_idx_id, {})); + EXPECT_EQ(chapter_idx_id, expected_chapter_id[i]); + + std::shared_ptr de_utterance_id; + ASSERT_OK(Tensor::CreateFromMSTensor(utterance_id, &de_utterance_id)); + ASSERT_OK(de_utterance_id->GetItemAt(&utterance_id_idx, {})); + std::string s_utterance_id(utterance_id_idx); + EXPECT_STREQ(s_utterance_id.c_str(), expected_utterance_id[i].c_str()); + + i++; + ASSERT_OK(iter->GetNextRow(&row)); + } + + EXPECT_EQ(i, 2); + + iter->Stop(); +} diff --git a/tests/ut/cpp/dataset/c_api_dataset_penn_treebank_test.cc b/tests/ut/cpp/dataset/c_api_dataset_penn_treebank_test.cc index f4926da2315..a73b3e5e997 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_penn_treebank_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_penn_treebank_test.cc @@ -1,592 +1,592 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "common/common.h" -#include "minddata/dataset/core/global_context.h" -#include "minddata/dataset/include/dataset/datasets.h" - -using namespace mindspore::dataset; - -using mindspore::dataset::ShuffleMode; - -class MindDataTestPipeline : public UT::DatasetOpTesting { - protected: -}; - -/// Feature: PennTreebankDataset -/// Description: Test PennTreebankDataset basic usage -/// Expectation: The data is processed successfully -TEST_F(MindDataTestPipeline, TestPennTreebankDatasetBasic) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetBasic."; - // Test PennTreebank Dataset with single text file and many default inputs - - // Set configuration - uint32_t original_seed = GlobalContext::config_manager()->seed(); - uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); - MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; - GlobalContext::config_manager()->set_seed(987); - GlobalContext::config_manager()->set_num_parallel_workers(4); - - std::string dataset_dir = datasets_root_path_ + "/testPennTreebank"; - std::shared_ptr ds = PennTreebank(dataset_dir, "test", 0, ShuffleMode::kFalse); - EXPECT_NE(ds, nullptr); - - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - - EXPECT_NE(row.find("text"), row.end()); - std::vector expected_result = { - {" no it was black friday "}, - {" clash twits poetry formulate flip loyalty splash "}, - {" you pay less for the supermaket's own brands "}, - }; - - uint64_t i = 0; - while (row.size() != 0) { - auto text = row["text"]; - MS_LOG(INFO) << "Tensor text shape: " << text.Shape(); - std::shared_ptr de_text; - ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); - std::string_view sv; - ASSERT_OK(de_text->GetItemAt(&sv, {})); - std::string ss(sv); - MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50); - // Compare against expected result - EXPECT_STREQ(ss.c_str(), expected_result[i].c_str()); - - i++; - ASSERT_OK(iter->GetNextRow(&row)); - } - - // Expect 3 samples - EXPECT_EQ(i, 3); - - // Manually terminate the pipeline - iter->Stop(); - - // Restore configuration - GlobalContext::config_manager()->set_seed(original_seed); - GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); -} - -/// Feature: PennTreebankDataset -/// Description: Test PennTreebankDataset in pipeline mode -/// Expectation: The data is processed successfully -TEST_F(MindDataTestPipeline, TestPennTreebankDatasetBasicWithPipeline) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetBasicWithPipeline."; - // Test PennTreebank Dataset with single text file and many default inputs - - // Set configuration - uint32_t original_seed = GlobalContext::config_manager()->seed(); - uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); - MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; - GlobalContext::config_manager()->set_seed(987); - GlobalContext::config_manager()->set_num_parallel_workers(4); - - std::string dataset_dir = datasets_root_path_ + "/testPennTreebank"; - std::shared_ptr ds1 = PennTreebank(dataset_dir, "test", 0, ShuffleMode::kFalse); - std::shared_ptr ds2 = PennTreebank(dataset_dir, "test", 0, ShuffleMode::kFalse); - EXPECT_NE(ds1, nullptr); - EXPECT_NE(ds2, nullptr); - - // Create two Repeat operation on ds - int32_t repeat_num = 2; - ds1 = ds1->Repeat(repeat_num); - EXPECT_NE(ds1, nullptr); - repeat_num = 3; - ds2 = ds2->Repeat(repeat_num); - EXPECT_NE(ds2, nullptr); - - // Create a Concat operation on the ds - ds1 = ds1->Concat({ds2}); - EXPECT_NE(ds1, nullptr); - - // Create an iterator over the result of the above dataset. - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds1->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - - EXPECT_NE(row.find("text"), row.end()); - std::vector expected_result = { - {" no it was black friday "}, - {" clash twits poetry formulate flip loyalty splash "}, - {" you pay less for the supermaket's own brands "}, - }; - - uint64_t i = 0; - while (row.size() != 0) { - auto text = row["text"]; - MS_LOG(INFO) << "Tensor text shape: " << text.Shape(); - i++; - ASSERT_OK(iter->GetNextRow(&row)); - } - - // Expect 15 samples - EXPECT_EQ(i, 15); - - // Manually terminate the pipeline - iter->Stop(); - - // Restore configuration - GlobalContext::config_manager()->set_seed(original_seed); - GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); -} - -/// Feature: PennTreebankDataset -/// Description: Test iterator of PennTreebankDataset with only text column -/// Expectation: The data is processed successfully -TEST_F(MindDataTestPipeline, TestPennTreebankDatasetIteratorOneColumn) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetIteratorOneColumn."; - // Create a PennTreebank dataset - // Set configuration - uint32_t original_seed = GlobalContext::config_manager()->seed(); - uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); - MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; - GlobalContext::config_manager()->set_seed(987); - GlobalContext::config_manager()->set_num_parallel_workers(4); - - std::string dataset_dir = datasets_root_path_ + "/testPennTreebank"; - std::shared_ptr ds = PennTreebank(dataset_dir, "test", 0, ShuffleMode::kFalse); - EXPECT_NE(ds, nullptr); - - // Create a Batch operation on ds - int32_t batch_size = 1; - ds = ds->Batch(batch_size); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset - // Only select "text" column and drop others - std::vector columns = {"text"}; - std::shared_ptr project_ds = ds->Project(columns); - std::shared_ptr iter = project_ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - - uint64_t i = 0; - while (row.size() != 0) { - auto audio = row["text"]; - MS_LOG(INFO) << "Tensor text shape: " << audio.Shape(); - ASSERT_OK(iter->GetNextRow(&row)); - i++; - } - - EXPECT_EQ(i, 3); - - // Manually terminate the pipeline - iter->Stop(); -} - -/// Feature: PennTreebankDataset -/// Description: Test iterator of PennTreebankDataset with wrong column -/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr -TEST_F(MindDataTestPipeline, TestPennTreebankDatasetIteratorWrongColumn) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetIteratorWrongColumn."; - // Create a PennTreebank dataset - // Set configuration - uint32_t original_seed = GlobalContext::config_manager()->seed(); - uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); - MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; - GlobalContext::config_manager()->set_seed(987); - GlobalContext::config_manager()->set_num_parallel_workers(4); - - std::string dataset_dir = datasets_root_path_ + "/testPennTreebank"; - std::shared_ptr ds = PennTreebank(dataset_dir, "test", 0, ShuffleMode::kFalse); - EXPECT_NE(ds, nullptr); - - // Pass wrong column name - std::vector columns = {"digital"}; - std::shared_ptr project_ds = ds->Project(columns); - std::shared_ptr iter = project_ds->CreateIterator(); - EXPECT_EQ(iter, nullptr); -} - -/// Feature: PennTreebankDataset -/// Description: Test PennTreebankDataset Getters method -/// Expectation: Output is equal to the expected output -TEST_F(MindDataTestPipeline, TestPennTreebankGetters) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankGetters."; - // Test PennTreebank Dataset with single text file and many default inputs - - // Set configuration - uint32_t original_seed = GlobalContext::config_manager()->seed(); - uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); - MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; - GlobalContext::config_manager()->set_seed(987); - GlobalContext::config_manager()->set_num_parallel_workers(4); - - std::string dataset_dir = datasets_root_path_ + "/testPennTreebank"; - std::shared_ptr ds = PennTreebank(dataset_dir, "test", 2, ShuffleMode::kFalse); - EXPECT_NE(ds, nullptr); - - std::vector column_names = {"text"}; - EXPECT_EQ(ds->GetDatasetSize(), 2); - EXPECT_EQ(ds->GetColumnNames(), column_names); - - ds = PennTreebank(dataset_dir, "test", 0, ShuffleMode::kFalse); - EXPECT_NE(ds, nullptr); - - EXPECT_EQ(ds->GetDatasetSize(), 3); - - std::vector types = ToDETypes(ds->GetOutputTypes()); - std::vector shapes = ToTensorShapeVec(ds->GetOutputShapes()); - EXPECT_EQ(types.size(), 1); - EXPECT_EQ(types[0].ToString(), "string"); - EXPECT_EQ(shapes.size(), 1); - EXPECT_EQ(shapes[0].ToString(), "<>"); - // Restore configuration - GlobalContext::config_manager()->set_seed(original_seed); - GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); -} - -/// Feature: PennTreebankDataset -/// Description: Test PennTreebankDataset with invalid samplers -/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr -TEST_F(MindDataTestPipeline, TestPennTreebankDatasetFail1) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetFail1."; - - // Create a PennTreebank Dataset - // with invalid samplers=-1 - std::string dataset_dir = datasets_root_path_ + "/testPennTreebank"; - std::shared_ptr ds = PennTreebank(dataset_dir, "test", -1, ShuffleMode::kFalse); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset. - std::shared_ptr iter = ds->CreateIterator(); - // Expect failure: PennTreebank number of samples cannot be negative - EXPECT_EQ(iter, nullptr); -} - -/// Feature: PennTreebankDataset -/// Description: Test PennTreebankDataset with empty dataset_files input -/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr -TEST_F(MindDataTestPipeline, TestPennTreebankDatasetFail2) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetFail2."; - - // Attempt to create a PennTreebank Dataset - // with wrongful empty dataset_files input - std::string dataset_dir = datasets_root_path_ + "/testPennTreebank"; - std::shared_ptr ds = PennTreebank("123", "test", 2, ShuffleMode::kFalse); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset. - std::shared_ptr iter = ds->CreateIterator(); - // Expect failure: dataset_dir is not specified - EXPECT_EQ(iter, nullptr); -} - -/// Feature: PennTreebankDataset -/// Description: Test PennTreebankDataset with non-existent dataset_files input -/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr -TEST_F(MindDataTestPipeline, TestPennTreebankDatasetFail3) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetFail3."; - - // Create a PennTreebank Dataset - // with non-existent dataset_files input - std::string dataset_dir = datasets_root_path_ + "/testPennTreebank"; - std::shared_ptr ds = PennTreebank(dataset_dir, "asd", 2, ShuffleMode::kFalse); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset. - std::shared_ptr iter = ds->CreateIterator(); - // Expect failure: invalid usage - EXPECT_EQ(iter, nullptr); -} - -/// Feature: PennTreebankDataset -/// Description: Test PennTreebankDataset with empty string dataset_files input -/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr -TEST_F(MindDataTestPipeline, TestPennTreebankDatasetFail4) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetFail4."; - - // Create a PennTreebank Dataset - // with empty string dataset_files input - std::string dataset_dir = datasets_root_path_ + "/testPennTreebank"; - std::shared_ptr ds = PennTreebank("", "test", 2, ShuffleMode::kFalse); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset. - std::shared_ptr iter = ds->CreateIterator(); - // Expect failure: specified dataset_files does not exist - EXPECT_EQ(iter, nullptr); -} - -/// Feature: PennTreebankDataset -/// Description: Test PennTreebankDataset with invalid num_shards=0 -/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr -TEST_F(MindDataTestPipeline, TestPennTreebankDatasetFail5) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetFail5."; - - // Create a PennTreebank Dataset - // with invalid num_shards=0 value - std::string dataset_dir = datasets_root_path_ + "/testPennTreebank"; - std::shared_ptr ds = PennTreebank(dataset_dir, "test", 2, ShuffleMode::kFalse, 0); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset. - std::shared_ptr iter = ds->CreateIterator(); - // Expect failure: Number of shards cannot be <=0 - EXPECT_EQ(iter, nullptr); -} - -/// Feature: PennTreebankDataset -/// Description: Test PennTreebankDataset with invalid shard_id=-1 -/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr -TEST_F(MindDataTestPipeline, TestPennTreebankDatasetFail6) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetFail6."; - - // Create a PennTreebank Dataset - // with invalid shard_id=-1 value - std::string dataset_dir = datasets_root_path_ + "/testPennTreebank"; - std::shared_ptr ds = PennTreebank(dataset_dir, "test", 2, ShuffleMode::kFalse, 1, -1); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset. - std::shared_ptr iter = ds->CreateIterator(); - // Expect failure: shard_id cannot be negative - EXPECT_EQ(iter, nullptr); -} - -/// Feature: PennTreebankDataset -/// Description: Test PennTreebankDataset with invalid shard_id=2 and num_shards=2 combination -/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr -TEST_F(MindDataTestPipeline, TestPennTreebankDatasetFail7) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetFail7."; - - // Create a PennTreebank Dataset - // with invalid shard_id=2 and num_shards=2 combination - std::string dataset_dir = datasets_root_path_ + "/testPennTreebank"; - std::shared_ptr ds = PennTreebank(dataset_dir, "test", 2, ShuffleMode::kFalse, 2, 2); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset. - std::shared_ptr iter = ds->CreateIterator(); - // Expect failure: Cannot have shard_id >= num_shards - EXPECT_EQ(iter, nullptr); -} - -/// Feature: PennTreebankDataset -/// Description: Test PennTreebankDataset with ShuffleMode::kFalse -/// Expectation: The data is processed successfully -TEST_F(MindDataTestPipeline, TestPennTreebankDatasetShuffleFalse) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetShuffleFalse."; - - // Set configuration - uint32_t original_seed = GlobalContext::config_manager()->seed(); - uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); - MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; - GlobalContext::config_manager()->set_seed(246); - GlobalContext::config_manager()->set_num_parallel_workers(2); - - std::string dataset_dir = datasets_root_path_ + "/testPennTreebank"; - std::shared_ptr ds = PennTreebank(dataset_dir, "all", 0, ShuffleMode::kFalse); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset. - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - - EXPECT_NE(row.find("text"), row.end()); - std::vector expected_result = { - {" no it was black friday "}, - {" does the bank charge a fee for setting up the account "}, - {" clash twits poetry formulate flip loyalty splash "}, - {" the wardrobe was very small in our room "}, - {" you pay less for the supermaket's own brands "}, - {" black white grapes "}, - {" just ahead of them there was a huge fissure "}, - {" the proportion of female workers in this company "}, - {" everyone in our football team is fuming "}, - }; - - uint64_t i = 0; - while (row.size() != 0) { - auto text = row["text"]; - MS_LOG(INFO) << "Tensor text shape: " << text.Shape(); - std::shared_ptr de_text; - ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); - std::string_view sv; - ASSERT_OK(de_text->GetItemAt(&sv, {})); - std::string ss(sv); - MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50); - // Compare against expected result - EXPECT_STREQ(ss.c_str(), expected_result[i].c_str()); - - i++; - ASSERT_OK(iter->GetNextRow(&row)); - } - - // Expect 9 samples - EXPECT_EQ(i, 9); - - // Manually terminate the pipeline - iter->Stop(); - - // Restore configuration - GlobalContext::config_manager()->set_seed(original_seed); - GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); -} - -/// Feature: PennTreebankDataset -/// Description: Test PennTreebankDataset with ShuffleMode::kFiles -/// Expectation: The data is processed successfully -TEST_F(MindDataTestPipeline, TestPennTreebankDatasetShuffleFilesA) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetShuffleFilesA."; - - // Set configuration - uint32_t original_seed = GlobalContext::config_manager()->seed(); - uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); - MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; - GlobalContext::config_manager()->set_seed(654); - GlobalContext::config_manager()->set_num_parallel_workers(1); - - std::string dataset_dir = datasets_root_path_ + "/testPennTreebank"; - std::shared_ptr ds = PennTreebank(dataset_dir, "all", 0, ShuffleMode::kFiles); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset. - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - - EXPECT_NE(row.find("text"), row.end()); - std::vector expected_result = { - {" does the bank charge a fee for setting up the account "}, - {" the wardrobe was very small in our room "}, - {" black white grapes "}, - {" no it was black friday "}, - {" clash twits poetry formulate flip loyalty splash "}, - {" you pay less for the supermaket's own brands "}, - {" just ahead of them there was a huge fissure "}, - {" the proportion of female workers in this company "}, - {" everyone in our football team is fuming "}, - }; - - uint64_t i = 0; - while (row.size() != 0) { - auto text = row["text"]; - MS_LOG(INFO) << "Tensor text shape: " << text.Shape(); - std::shared_ptr de_text; - ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); - std::string_view sv; - ASSERT_OK(de_text->GetItemAt(&sv, {})); - std::string ss(sv); - MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50); - // Compare against expected result - EXPECT_STREQ(ss.c_str(), expected_result[i].c_str()); - - i++; - ASSERT_OK(iter->GetNextRow(&row)); - } - - // Expect 9 samples - EXPECT_EQ(i, 9); - - // Manually terminate the pipeline - iter->Stop(); - - // Restore configuration - GlobalContext::config_manager()->set_seed(original_seed); - GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); -} - -/// Feature: PennTreebankDataset -/// Description: Test PennTreebankDataset with ShuffleMode::kGlobal -/// Expectation: The data is processed successfully -TEST_F(MindDataTestPipeline, TestPennTreebankDatasetShuffleGlobal) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetShuffleGlobal."; - - // Set configuration - uint32_t original_seed = GlobalContext::config_manager()->seed(); - uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); - MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; - GlobalContext::config_manager()->set_seed(246); - GlobalContext::config_manager()->set_num_parallel_workers(4); - - // Create a TextFile Dataset, with two text files - // Note: 1.txt has 3 rows - // Note: 2.txt has 2 rows - // Set shuffle to global shuffle - std::string dataset_dir = datasets_root_path_ + "/testPennTreebank"; - std::shared_ptr ds = PennTreebank(dataset_dir, "all", 0, ShuffleMode::kGlobal); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset. - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - - EXPECT_NE(row.find("text"), row.end()); - std::vector expected_result = { - {" everyone in our football team is fuming "}, - {" does the bank charge a fee for setting up the account "}, - {" clash twits poetry formulate flip loyalty splash "}, - {" no it was black friday "}, - {" just ahead of them there was a huge fissure "}, - {" the proportion of female workers in this company "}, - {" you pay less for the supermaket's own brands "}, - {" the wardrobe was very small in our room "}, - {" black white grapes "}, - }; - - uint64_t i = 0; - while (row.size() != 0) { - auto text = row["text"]; - MS_LOG(INFO) << "Tensor text shape: " << text.Shape(); - std::shared_ptr de_text; - ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); - std::string_view sv; - ASSERT_OK(de_text->GetItemAt(&sv, {})); - std::string ss(sv); - MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50); - // Compare against expected result - EXPECT_STREQ(ss.c_str(), expected_result[i].c_str()); - - i++; - ASSERT_OK(iter->GetNextRow(&row)); - } - - // Expect 9 samples - EXPECT_EQ(i, 9); - - // Manually terminate the pipeline - iter->Stop(); - - // Restore configuration - GlobalContext::config_manager()->set_seed(original_seed); - GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); -} +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/common.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/include/dataset/datasets.h" + +using namespace mindspore::dataset; + +using mindspore::dataset::ShuffleMode; + +class MindDataTestPipeline : public UT::DatasetOpTesting { + protected: +}; + +/// Feature: PennTreebankDataset +/// Description: Test PennTreebankDataset basic usage +/// Expectation: The data is processed successfully +TEST_F(MindDataTestPipeline, TestPennTreebankDatasetBasic) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetBasic."; + // Test PennTreebank Dataset with single text file and many default inputs + + // Set configuration + uint32_t original_seed = GlobalContext::config_manager()->seed(); + uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); + MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; + GlobalContext::config_manager()->set_seed(987); + GlobalContext::config_manager()->set_num_parallel_workers(4); + + std::string dataset_dir = datasets_root_path_ + "/testPennTreebank"; + std::shared_ptr ds = PennTreebank(dataset_dir, "test", 0, ShuffleMode::kFalse); + EXPECT_NE(ds, nullptr); + + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + + EXPECT_NE(row.find("text"), row.end()); + std::vector expected_result = { + {" no it was black friday "}, + {" clash twits poetry formulate flip loyalty splash "}, + {" you pay less for the supermaket's own brands "}, + }; + + uint64_t i = 0; + while (row.size() != 0) { + auto text = row["text"]; + MS_LOG(INFO) << "Tensor text shape: " << text.Shape(); + std::shared_ptr de_text; + ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); + std::string_view sv; + ASSERT_OK(de_text->GetItemAt(&sv, {})); + std::string ss(sv); + MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50); + // Compare against expected result + EXPECT_STREQ(ss.c_str(), expected_result[i].c_str()); + + i++; + ASSERT_OK(iter->GetNextRow(&row)); + } + + // Expect 3 samples + EXPECT_EQ(i, 3); + + // Manually terminate the pipeline + iter->Stop(); + + // Restore configuration + GlobalContext::config_manager()->set_seed(original_seed); + GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); +} + +/// Feature: PennTreebankDataset +/// Description: Test PennTreebankDataset in pipeline mode +/// Expectation: The data is processed successfully +TEST_F(MindDataTestPipeline, TestPennTreebankDatasetBasicWithPipeline) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetBasicWithPipeline."; + // Test PennTreebank Dataset with single text file and many default inputs + + // Set configuration + uint32_t original_seed = GlobalContext::config_manager()->seed(); + uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); + MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; + GlobalContext::config_manager()->set_seed(987); + GlobalContext::config_manager()->set_num_parallel_workers(4); + + std::string dataset_dir = datasets_root_path_ + "/testPennTreebank"; + std::shared_ptr ds1 = PennTreebank(dataset_dir, "test", 0, ShuffleMode::kFalse); + std::shared_ptr ds2 = PennTreebank(dataset_dir, "test", 0, ShuffleMode::kFalse); + EXPECT_NE(ds1, nullptr); + EXPECT_NE(ds2, nullptr); + + // Create two Repeat operation on ds + int32_t repeat_num = 2; + ds1 = ds1->Repeat(repeat_num); + EXPECT_NE(ds1, nullptr); + repeat_num = 3; + ds2 = ds2->Repeat(repeat_num); + EXPECT_NE(ds2, nullptr); + + // Create a Concat operation on the ds + ds1 = ds1->Concat({ds2}); + EXPECT_NE(ds1, nullptr); + + // Create an iterator over the result of the above dataset. + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds1->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + + EXPECT_NE(row.find("text"), row.end()); + std::vector expected_result = { + {" no it was black friday "}, + {" clash twits poetry formulate flip loyalty splash "}, + {" you pay less for the supermaket's own brands "}, + }; + + uint64_t i = 0; + while (row.size() != 0) { + auto text = row["text"]; + MS_LOG(INFO) << "Tensor text shape: " << text.Shape(); + i++; + ASSERT_OK(iter->GetNextRow(&row)); + } + + // Expect 15 samples + EXPECT_EQ(i, 15); + + // Manually terminate the pipeline + iter->Stop(); + + // Restore configuration + GlobalContext::config_manager()->set_seed(original_seed); + GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); +} + +/// Feature: PennTreebankDataset +/// Description: Test iterator of PennTreebankDataset with only text column +/// Expectation: The data is processed successfully +TEST_F(MindDataTestPipeline, TestPennTreebankDatasetIteratorOneColumn) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetIteratorOneColumn."; + // Create a PennTreebank dataset + // Set configuration + uint32_t original_seed = GlobalContext::config_manager()->seed(); + uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); + MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; + GlobalContext::config_manager()->set_seed(987); + GlobalContext::config_manager()->set_num_parallel_workers(4); + + std::string dataset_dir = datasets_root_path_ + "/testPennTreebank"; + std::shared_ptr ds = PennTreebank(dataset_dir, "test", 0, ShuffleMode::kFalse); + EXPECT_NE(ds, nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 1; + ds = ds->Batch(batch_size); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // Only select "text" column and drop others + std::vector columns = {"text"}; + std::shared_ptr project_ds = ds->Project(columns); + std::shared_ptr iter = project_ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + + uint64_t i = 0; + while (row.size() != 0) { + auto audio = row["text"]; + MS_LOG(INFO) << "Tensor text shape: " << audio.Shape(); + ASSERT_OK(iter->GetNextRow(&row)); + i++; + } + + EXPECT_EQ(i, 3); + + // Manually terminate the pipeline + iter->Stop(); +} + +/// Feature: PennTreebankDataset +/// Description: Test iterator of PennTreebankDataset with wrong column +/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr +TEST_F(MindDataTestPipeline, TestPennTreebankDatasetIteratorWrongColumn) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetIteratorWrongColumn."; + // Create a PennTreebank dataset + // Set configuration + uint32_t original_seed = GlobalContext::config_manager()->seed(); + uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); + MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; + GlobalContext::config_manager()->set_seed(987); + GlobalContext::config_manager()->set_num_parallel_workers(4); + + std::string dataset_dir = datasets_root_path_ + "/testPennTreebank"; + std::shared_ptr ds = PennTreebank(dataset_dir, "test", 0, ShuffleMode::kFalse); + EXPECT_NE(ds, nullptr); + + // Pass wrong column name + std::vector columns = {"digital"}; + std::shared_ptr project_ds = ds->Project(columns); + std::shared_ptr iter = project_ds->CreateIterator(); + EXPECT_EQ(iter, nullptr); +} + +/// Feature: PennTreebankDataset +/// Description: Test PennTreebankDataset Getters method +/// Expectation: Output is equal to the expected output +TEST_F(MindDataTestPipeline, TestPennTreebankGetters) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankGetters."; + // Test PennTreebank Dataset with single text file and many default inputs + + // Set configuration + uint32_t original_seed = GlobalContext::config_manager()->seed(); + uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); + MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; + GlobalContext::config_manager()->set_seed(987); + GlobalContext::config_manager()->set_num_parallel_workers(4); + + std::string dataset_dir = datasets_root_path_ + "/testPennTreebank"; + std::shared_ptr ds = PennTreebank(dataset_dir, "test", 2, ShuffleMode::kFalse); + EXPECT_NE(ds, nullptr); + + std::vector column_names = {"text"}; + EXPECT_EQ(ds->GetDatasetSize(), 2); + EXPECT_EQ(ds->GetColumnNames(), column_names); + + ds = PennTreebank(dataset_dir, "test", 0, ShuffleMode::kFalse); + EXPECT_NE(ds, nullptr); + + EXPECT_EQ(ds->GetDatasetSize(), 3); + + std::vector types = ToDETypes(ds->GetOutputTypes()); + std::vector shapes = ToTensorShapeVec(ds->GetOutputShapes()); + EXPECT_EQ(types.size(), 1); + EXPECT_EQ(types[0].ToString(), "string"); + EXPECT_EQ(shapes.size(), 1); + EXPECT_EQ(shapes[0].ToString(), "<>"); + // Restore configuration + GlobalContext::config_manager()->set_seed(original_seed); + GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); +} + +/// Feature: PennTreebankDataset +/// Description: Test PennTreebankDataset with invalid samplers +/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr +TEST_F(MindDataTestPipeline, TestPennTreebankDatasetFail1) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetFail1."; + + // Create a PennTreebank Dataset + // with invalid samplers=-1 + std::string dataset_dir = datasets_root_path_ + "/testPennTreebank"; + std::shared_ptr ds = PennTreebank(dataset_dir, "test", -1, ShuffleMode::kFalse); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset. + std::shared_ptr iter = ds->CreateIterator(); + // Expect failure: PennTreebank number of samples cannot be negative + EXPECT_EQ(iter, nullptr); +} + +/// Feature: PennTreebankDataset +/// Description: Test PennTreebankDataset with empty dataset_files input +/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr +TEST_F(MindDataTestPipeline, TestPennTreebankDatasetFail2) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetFail2."; + + // Attempt to create a PennTreebank Dataset + // with wrongful empty dataset_files input + std::string dataset_dir = datasets_root_path_ + "/testPennTreebank"; + std::shared_ptr ds = PennTreebank("123", "test", 2, ShuffleMode::kFalse); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset. + std::shared_ptr iter = ds->CreateIterator(); + // Expect failure: dataset_dir is not specified + EXPECT_EQ(iter, nullptr); +} + +/// Feature: PennTreebankDataset +/// Description: Test PennTreebankDataset with non-existent dataset_files input +/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr +TEST_F(MindDataTestPipeline, TestPennTreebankDatasetFail3) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetFail3."; + + // Create a PennTreebank Dataset + // with non-existent dataset_files input + std::string dataset_dir = datasets_root_path_ + "/testPennTreebank"; + std::shared_ptr ds = PennTreebank(dataset_dir, "asd", 2, ShuffleMode::kFalse); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset. + std::shared_ptr iter = ds->CreateIterator(); + // Expect failure: invalid usage + EXPECT_EQ(iter, nullptr); +} + +/// Feature: PennTreebankDataset +/// Description: Test PennTreebankDataset with empty string dataset_files input +/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr +TEST_F(MindDataTestPipeline, TestPennTreebankDatasetFail4) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetFail4."; + + // Create a PennTreebank Dataset + // with empty string dataset_files input + std::string dataset_dir = datasets_root_path_ + "/testPennTreebank"; + std::shared_ptr ds = PennTreebank("", "test", 2, ShuffleMode::kFalse); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset. + std::shared_ptr iter = ds->CreateIterator(); + // Expect failure: specified dataset_files does not exist + EXPECT_EQ(iter, nullptr); +} + +/// Feature: PennTreebankDataset +/// Description: Test PennTreebankDataset with invalid num_shards=0 +/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr +TEST_F(MindDataTestPipeline, TestPennTreebankDatasetFail5) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetFail5."; + + // Create a PennTreebank Dataset + // with invalid num_shards=0 value + std::string dataset_dir = datasets_root_path_ + "/testPennTreebank"; + std::shared_ptr ds = PennTreebank(dataset_dir, "test", 2, ShuffleMode::kFalse, 0); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset. + std::shared_ptr iter = ds->CreateIterator(); + // Expect failure: Number of shards cannot be <=0 + EXPECT_EQ(iter, nullptr); +} + +/// Feature: PennTreebankDataset +/// Description: Test PennTreebankDataset with invalid shard_id=-1 +/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr +TEST_F(MindDataTestPipeline, TestPennTreebankDatasetFail6) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetFail6."; + + // Create a PennTreebank Dataset + // with invalid shard_id=-1 value + std::string dataset_dir = datasets_root_path_ + "/testPennTreebank"; + std::shared_ptr ds = PennTreebank(dataset_dir, "test", 2, ShuffleMode::kFalse, 1, -1); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset. + std::shared_ptr iter = ds->CreateIterator(); + // Expect failure: shard_id cannot be negative + EXPECT_EQ(iter, nullptr); +} + +/// Feature: PennTreebankDataset +/// Description: Test PennTreebankDataset with invalid shard_id=2 and num_shards=2 combination +/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr +TEST_F(MindDataTestPipeline, TestPennTreebankDatasetFail7) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetFail7."; + + // Create a PennTreebank Dataset + // with invalid shard_id=2 and num_shards=2 combination + std::string dataset_dir = datasets_root_path_ + "/testPennTreebank"; + std::shared_ptr ds = PennTreebank(dataset_dir, "test", 2, ShuffleMode::kFalse, 2, 2); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset. + std::shared_ptr iter = ds->CreateIterator(); + // Expect failure: Cannot have shard_id >= num_shards + EXPECT_EQ(iter, nullptr); +} + +/// Feature: PennTreebankDataset +/// Description: Test PennTreebankDataset with ShuffleMode::kFalse +/// Expectation: The data is processed successfully +TEST_F(MindDataTestPipeline, TestPennTreebankDatasetShuffleFalse) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetShuffleFalse."; + + // Set configuration + uint32_t original_seed = GlobalContext::config_manager()->seed(); + uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); + MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; + GlobalContext::config_manager()->set_seed(246); + GlobalContext::config_manager()->set_num_parallel_workers(2); + + std::string dataset_dir = datasets_root_path_ + "/testPennTreebank"; + std::shared_ptr ds = PennTreebank(dataset_dir, "all", 0, ShuffleMode::kFalse); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset. + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + + EXPECT_NE(row.find("text"), row.end()); + std::vector expected_result = { + {" no it was black friday "}, + {" does the bank charge a fee for setting up the account "}, + {" clash twits poetry formulate flip loyalty splash "}, + {" the wardrobe was very small in our room "}, + {" you pay less for the supermaket's own brands "}, + {" black white grapes "}, + {" just ahead of them there was a huge fissure "}, + {" the proportion of female workers in this company "}, + {" everyone in our football team is fuming "}, + }; + + uint64_t i = 0; + while (row.size() != 0) { + auto text = row["text"]; + MS_LOG(INFO) << "Tensor text shape: " << text.Shape(); + std::shared_ptr de_text; + ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); + std::string_view sv; + ASSERT_OK(de_text->GetItemAt(&sv, {})); + std::string ss(sv); + MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50); + // Compare against expected result + EXPECT_STREQ(ss.c_str(), expected_result[i].c_str()); + + i++; + ASSERT_OK(iter->GetNextRow(&row)); + } + + // Expect 9 samples + EXPECT_EQ(i, 9); + + // Manually terminate the pipeline + iter->Stop(); + + // Restore configuration + GlobalContext::config_manager()->set_seed(original_seed); + GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); +} + +/// Feature: PennTreebankDataset +/// Description: Test PennTreebankDataset with ShuffleMode::kFiles +/// Expectation: The data is processed successfully +TEST_F(MindDataTestPipeline, TestPennTreebankDatasetShuffleFilesA) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetShuffleFilesA."; + + // Set configuration + uint32_t original_seed = GlobalContext::config_manager()->seed(); + uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); + MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; + GlobalContext::config_manager()->set_seed(654); + GlobalContext::config_manager()->set_num_parallel_workers(1); + + std::string dataset_dir = datasets_root_path_ + "/testPennTreebank"; + std::shared_ptr ds = PennTreebank(dataset_dir, "all", 0, ShuffleMode::kFiles); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset. + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + + EXPECT_NE(row.find("text"), row.end()); + std::vector expected_result = { + {" does the bank charge a fee for setting up the account "}, + {" the wardrobe was very small in our room "}, + {" black white grapes "}, + {" no it was black friday "}, + {" clash twits poetry formulate flip loyalty splash "}, + {" you pay less for the supermaket's own brands "}, + {" just ahead of them there was a huge fissure "}, + {" the proportion of female workers in this company "}, + {" everyone in our football team is fuming "}, + }; + + uint64_t i = 0; + while (row.size() != 0) { + auto text = row["text"]; + MS_LOG(INFO) << "Tensor text shape: " << text.Shape(); + std::shared_ptr de_text; + ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); + std::string_view sv; + ASSERT_OK(de_text->GetItemAt(&sv, {})); + std::string ss(sv); + MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50); + // Compare against expected result + EXPECT_STREQ(ss.c_str(), expected_result[i].c_str()); + + i++; + ASSERT_OK(iter->GetNextRow(&row)); + } + + // Expect 9 samples + EXPECT_EQ(i, 9); + + // Manually terminate the pipeline + iter->Stop(); + + // Restore configuration + GlobalContext::config_manager()->set_seed(original_seed); + GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); +} + +/// Feature: PennTreebankDataset +/// Description: Test PennTreebankDataset with ShuffleMode::kGlobal +/// Expectation: The data is processed successfully +TEST_F(MindDataTestPipeline, TestPennTreebankDatasetShuffleGlobal) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPennTreebankDatasetShuffleGlobal."; + + // Set configuration + uint32_t original_seed = GlobalContext::config_manager()->seed(); + uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); + MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; + GlobalContext::config_manager()->set_seed(246); + GlobalContext::config_manager()->set_num_parallel_workers(4); + + // Create a TextFile Dataset, with two text files + // Note: 1.txt has 3 rows + // Note: 2.txt has 2 rows + // Set shuffle to global shuffle + std::string dataset_dir = datasets_root_path_ + "/testPennTreebank"; + std::shared_ptr ds = PennTreebank(dataset_dir, "all", 0, ShuffleMode::kGlobal); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset. + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + + EXPECT_NE(row.find("text"), row.end()); + std::vector expected_result = { + {" everyone in our football team is fuming "}, + {" does the bank charge a fee for setting up the account "}, + {" clash twits poetry formulate flip loyalty splash "}, + {" no it was black friday "}, + {" just ahead of them there was a huge fissure "}, + {" the proportion of female workers in this company "}, + {" you pay less for the supermaket's own brands "}, + {" the wardrobe was very small in our room "}, + {" black white grapes "}, + }; + + uint64_t i = 0; + while (row.size() != 0) { + auto text = row["text"]; + MS_LOG(INFO) << "Tensor text shape: " << text.Shape(); + std::shared_ptr de_text; + ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); + std::string_view sv; + ASSERT_OK(de_text->GetItemAt(&sv, {})); + std::string ss(sv); + MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50); + // Compare against expected result + EXPECT_STREQ(ss.c_str(), expected_result[i].c_str()); + + i++; + ASSERT_OK(iter->GetNextRow(&row)); + } + + // Expect 9 samples + EXPECT_EQ(i, 9); + + // Manually terminate the pipeline + iter->Stop(); + + // Restore configuration + GlobalContext::config_manager()->set_seed(original_seed); + GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); +} diff --git a/tests/ut/cpp/dataset/c_api_dataset_qmnist_test.cc b/tests/ut/cpp/dataset/c_api_dataset_qmnist_test.cc index c7db37dc5ef..41d984a6257 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_qmnist_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_qmnist_test.cc @@ -1,436 +1,436 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "common/common.h" -#include "minddata/dataset/include/dataset/datasets.h" - -using namespace mindspore::dataset; -using mindspore::dataset::DataType; -using mindspore::dataset::Tensor; -using mindspore::dataset::TensorShape; - -class MindDataTestPipeline : public UT::DatasetOpTesting { - protected: -}; - -/// Feature: QMnistTrainDataset. -/// Description: Test basic usage of QMnistTrainDataset. -/// Expectation: Get correct number of data. -TEST_F(MindDataTestPipeline, TestQMnistTrainDataset) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistTrainDataset."; - - // Create a QMNIST Train Dataset - std::string folder_path = datasets_root_path_ + "/testQMnistData/"; - std::shared_ptr ds = QMnist(folder_path, "train", true, std::make_shared(false, 5)); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - - EXPECT_NE(row.find("image"), row.end()); - EXPECT_NE(row.find("label"), row.end()); - - uint64_t i = 0; - while (row.size() != 0) { - i++; - auto image = row["image"]; - MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); - ASSERT_OK(iter->GetNextRow(&row)); - } - - EXPECT_EQ(i, 5); - - // Manually terminate the pipeline - iter->Stop(); -} - -/// Feature: QMnistTestDataset -/// Description: Test basic usage of QMnistDataset with test dataset -/// Expectation: The data is processed successfully -TEST_F(MindDataTestPipeline, TestQMnistTestDataset) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistTestDataset."; - - // Create a QMNIST Test Dataset - std::string folder_path = datasets_root_path_ + "/testQMnistData/"; - std::shared_ptr ds = QMnist(folder_path, "test", true, std::make_shared(false, 5)); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - - EXPECT_NE(row.find("image"), row.end()); - EXPECT_NE(row.find("label"), row.end()); - - uint64_t i = 0; - while (row.size() != 0) { - i++; - auto image = row["image"]; - MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); - ASSERT_OK(iter->GetNextRow(&row)); - } - - EXPECT_EQ(i, 5); - - // Manually terminate the pipeline - iter->Stop(); -} - -/// Feature: QMnistTestDataset -/// Description: Test basic usage of QMnistDataset with nist dataset -/// Expectation: The data is processed successfully -TEST_F(MindDataTestPipeline, TestQMnistNistDataset) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistNistDataset."; - - // Create a QMNIST Nist Dataset - std::string folder_path = datasets_root_path_ + "/testQMnistData/"; - std::shared_ptr ds = QMnist(folder_path, "nist", true, std::make_shared(false, 5)); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - - EXPECT_NE(row.find("image"), row.end()); - EXPECT_NE(row.find("label"), row.end()); - - uint64_t i = 0; - while (row.size() != 0) { - i++; - auto image = row["image"]; - MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); - ASSERT_OK(iter->GetNextRow(&row)); - } - - EXPECT_EQ(i, 5); - - // Manually terminate the pipeline - iter->Stop(); -} - -/// Feature: QMnistTestDataset -/// Description: Test basic usage of QMnistDataset with all dataset -/// Expectation: The data is processed successfully -TEST_F(MindDataTestPipeline, TestQMnistAllDataset) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistAllDataset."; - - // Create a QMNIST All Dataset - std::string folder_path = datasets_root_path_ + "/testQMnistData/"; - std::shared_ptr ds = QMnist(folder_path, "all", true, std::make_shared(false, 20)); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - - EXPECT_NE(row.find("image"), row.end()); - EXPECT_NE(row.find("label"), row.end()); - - uint64_t i = 0; - while (row.size() != 0) { - i++; - auto image = row["image"]; - MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); - ASSERT_OK(iter->GetNextRow(&row)); - } - - EXPECT_EQ(i, 20); - - // Manually terminate the pipeline - iter->Stop(); -} - -/// Feature: QMnistTestDataset -/// Description: Test basic usage of QMnistDataset with all and compat dataset -/// Expectation: The data is processed successfully -TEST_F(MindDataTestPipeline, TestQMnistCompatDataset) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistCompatDataset."; - - // Create a QMNIST All Dataset - std::string folder_path = datasets_root_path_ + "/testQMnistData/"; - std::shared_ptr ds = QMnist(folder_path, "all", false, std::make_shared(false, 20)); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - - EXPECT_NE(row.find("image"), row.end()); - EXPECT_NE(row.find("label"), row.end()); - - uint64_t i = 0; - while (row.size() != 0) { - i++; - auto image = row["image"]; - auto label = row["label"]; - MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); - MS_LOG(INFO) << "Tensor label shape: " << label.Shape(); - ASSERT_OK(iter->GetNextRow(&row)); - } - - EXPECT_EQ(i, 20); - - // Manually terminate the pipeline - iter->Stop(); -} - -/// Feature: QMnistTestDataset -/// Description: Test usage of QMnistDataset with pipeline mode -/// Expectation: The data is processed successfully -TEST_F(MindDataTestPipeline, TestQMnistDatasetWithPipeline) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistTrainDatasetWithPipeline."; - - // Create two QMNIST Train Dataset - std::string folder_path = datasets_root_path_ + "/testQMnistData/"; - std::shared_ptr ds1 = QMnist(folder_path, "train", true, std::make_shared(false, 5)); - std::shared_ptr ds2 = QMnist(folder_path, "train", true, std::make_shared(false, 5)); - EXPECT_NE(ds1, nullptr); - EXPECT_NE(ds2, nullptr); - - // Create two Repeat operation on ds - int32_t repeat_num = 1; - ds1 = ds1->Repeat(repeat_num); - EXPECT_NE(ds1, nullptr); - repeat_num = 1; - ds2 = ds2->Repeat(repeat_num); - EXPECT_NE(ds2, nullptr); - - // Create two Project operation on ds - std::vector column_project = {"image", "label"}; - ds1 = ds1->Project(column_project); - EXPECT_NE(ds1, nullptr); - ds2 = ds2->Project(column_project); - EXPECT_NE(ds2, nullptr); - - // Create a Concat operation on the ds - ds1 = ds1->Concat({ds2}); - EXPECT_NE(ds1, nullptr); - - // Create an iterator over the result of the above dataset - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds1->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - - EXPECT_NE(row.find("image"), row.end()); - EXPECT_NE(row.find("label"), row.end()); - - uint64_t i = 0; - while (row.size() != 0) { - i++; - auto image = row["image"]; - MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); - ASSERT_OK(iter->GetNextRow(&row)); - } - - EXPECT_EQ(i, 10); - - // Manually terminate the pipeline - iter->Stop(); -} - -/// Feature: QMnistTestDataset -/// Description: Test iterator of QMnistDataset with only the image column -/// Expectation: The data is processed successfully -TEST_F(MindDataTestPipeline, TestQMnistIteratorOneColumn) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistIteratorOneColumn."; - // Create a QMnist Dataset - std::string folder_path = datasets_root_path_ + "/testQMnistData/"; - std::shared_ptr ds = QMnist(folder_path, "train", true, std::make_shared(false, 5)); - EXPECT_NE(ds, nullptr); - - // Create a Batch operation on ds - int32_t batch_size = 1; - ds = ds->Batch(batch_size); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset - // Only select "image" column and drop others - std::vector columns = {"image"}; - std::shared_ptr project_ds = ds->Project(columns); - std::shared_ptr iter = project_ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row - std::vector row; - ASSERT_OK(iter->GetNextRow(&row)); - std::vector expect_image = {1, 28, 28, 1}; - - uint64_t i = 0; - while (row.size() != 0) { - for (auto &v : row) { - MS_LOG(INFO) << "image shape:" << v.Shape(); - EXPECT_EQ(expect_image, v.Shape()); - } - ASSERT_OK(iter->GetNextRow(&row)); - i++; - } - - EXPECT_EQ(i, 5); - - // Manually terminate the pipeline - iter->Stop(); -} - -/// Feature: QMnistTestDataset -/// Description: Test iterator of QMnistDataset with wrong column -/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr -TEST_F(MindDataTestPipeline, TestQMnistIteratorWrongColumn) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistIteratorWrongColumn."; - // Create a QMnist Dataset - std::string folder_path = datasets_root_path_ + "/testQMnistData/"; - std::shared_ptr ds = QMnist(folder_path, "train", true, std::make_shared(false, 5)); - EXPECT_NE(ds, nullptr); - - // Pass wrong column name - std::vector columns = {"digital"}; - std::shared_ptr project_ds = ds->Project(columns); - std::shared_ptr iter = project_ds->CreateIterator(); - EXPECT_EQ(iter, nullptr); -} - -/// Feature: QMnistTestDataset -/// Description: Test QMnistDataset GetDatasetSize -/// Expectation: Output is equal to the expected output -TEST_F(MindDataTestPipeline, TestGetQMnistDatasetSize) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetQMnistTrainDatasetSize."; - - // Create a QMNIST Train Dataset - std::string folder_path = datasets_root_path_ + "/testQMnistData/"; - std::shared_ptr ds = QMnist(folder_path, "train", true); - EXPECT_NE(ds, nullptr); - - EXPECT_EQ(ds->GetDatasetSize(), 10); -} - -/// Feature: QMnistTestDataset -/// Description: Test QMnistDataset Getters method -/// Expectation: Output is equal to the expected output -TEST_F(MindDataTestPipeline, TestQMnistDatasetGetters) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistTrainDatasetGetters."; - - // Create a QMNIST Train Dataset - std::string folder_path = datasets_root_path_ + "/testQMnistData/"; - std::shared_ptr ds = QMnist(folder_path, "train", true); - EXPECT_NE(ds, nullptr); - - EXPECT_EQ(ds->GetDatasetSize(), 10); - std::vector types = ToDETypes(ds->GetOutputTypes()); - std::vector shapes = ToTensorShapeVec(ds->GetOutputShapes()); - std::vector column_names = {"image", "label"}; - int64_t num_classes = ds->GetNumClasses(); - EXPECT_EQ(types.size(), 2); - EXPECT_EQ(types[0].ToString(), "uint8"); - EXPECT_EQ(types[1].ToString(), "uint32"); - EXPECT_EQ(shapes.size(), 2); - EXPECT_EQ(shapes[0].ToString(), "<28,28,1>"); - EXPECT_EQ(shapes[1].ToString(), "<>"); - EXPECT_EQ(num_classes, -1); - EXPECT_EQ(ds->GetBatchSize(), 1); - EXPECT_EQ(ds->GetRepeatCount(), 1); - - EXPECT_EQ(ds->GetDatasetSize(), 10); - EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types); - EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes); - EXPECT_EQ(ds->GetNumClasses(), -1); - - EXPECT_EQ(ds->GetColumnNames(), column_names); - EXPECT_EQ(ds->GetDatasetSize(), 10); - EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types); - EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes); - EXPECT_EQ(ds->GetBatchSize(), 1); - EXPECT_EQ(ds->GetRepeatCount(), 1); - EXPECT_EQ(ds->GetNumClasses(), -1); - EXPECT_EQ(ds->GetDatasetSize(), 10); -} - -/// Feature: QMnistTestDataset -/// Description: Test QMnistDataset with invalid folder path input -/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr -TEST_F(MindDataTestPipeline, TestQMnistDataFail) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistDataFail."; - - // Create a QMNIST Dataset - std::shared_ptr ds = QMnist("", "train", true, std::make_shared(false, 5)); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset - std::shared_ptr iter = ds->CreateIterator(); - // Expect failure: invalid QMNIST input - EXPECT_EQ(iter, nullptr); -} - -/// Feature: QMnistTestDataset -/// Description: Test QMnistDataset with invalid usage -/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr -TEST_F(MindDataTestPipeline, TestQMnistDataWithInvalidUsageFail) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistDataWithInvalidUsageFail."; - - // Create a QMNIST Dataset - std::string folder_path = datasets_root_path_ + "/testQMnistData/"; - std::shared_ptr ds = QMnist(folder_path, "validation", true); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset - std::shared_ptr iter = ds->CreateIterator(); - // Expect failure: invalid QMNIST input, validation is not a valid usage - EXPECT_EQ(iter, nullptr); -} - -/// Feature: QMnistTestDataset -/// Description: Test QMnistDataset with null sampler -/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr -TEST_F(MindDataTestPipeline, TestQMnistDataWithNullSamplerFail) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistDataWithNullSamplerFail."; - - // Create a QMNIST Dataset - std::string folder_path = datasets_root_path_ + "/testQMnistData/"; - std::shared_ptr ds = QMnist(folder_path, "train", true, nullptr); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset - std::shared_ptr iter = ds->CreateIterator(); - // Expect failure: invalid QMNIST input, sampler cannot be nullptr - EXPECT_EQ(iter, nullptr); -} +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/common.h" +#include "minddata/dataset/include/dataset/datasets.h" + +using namespace mindspore::dataset; +using mindspore::dataset::DataType; +using mindspore::dataset::Tensor; +using mindspore::dataset::TensorShape; + +class MindDataTestPipeline : public UT::DatasetOpTesting { + protected: +}; + +/// Feature: QMnistTrainDataset. +/// Description: Test basic usage of QMnistTrainDataset. +/// Expectation: Get correct number of data. +TEST_F(MindDataTestPipeline, TestQMnistTrainDataset) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistTrainDataset."; + + // Create a QMNIST Train Dataset + std::string folder_path = datasets_root_path_ + "/testQMnistData/"; + std::shared_ptr ds = QMnist(folder_path, "train", true, std::make_shared(false, 5)); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + + EXPECT_NE(row.find("image"), row.end()); + EXPECT_NE(row.find("label"), row.end()); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); + ASSERT_OK(iter->GetNextRow(&row)); + } + + EXPECT_EQ(i, 5); + + // Manually terminate the pipeline + iter->Stop(); +} + +/// Feature: QMnistTestDataset +/// Description: Test basic usage of QMnistDataset with test dataset +/// Expectation: The data is processed successfully +TEST_F(MindDataTestPipeline, TestQMnistTestDataset) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistTestDataset."; + + // Create a QMNIST Test Dataset + std::string folder_path = datasets_root_path_ + "/testQMnistData/"; + std::shared_ptr ds = QMnist(folder_path, "test", true, std::make_shared(false, 5)); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + + EXPECT_NE(row.find("image"), row.end()); + EXPECT_NE(row.find("label"), row.end()); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); + ASSERT_OK(iter->GetNextRow(&row)); + } + + EXPECT_EQ(i, 5); + + // Manually terminate the pipeline + iter->Stop(); +} + +/// Feature: QMnistTestDataset +/// Description: Test basic usage of QMnistDataset with nist dataset +/// Expectation: The data is processed successfully +TEST_F(MindDataTestPipeline, TestQMnistNistDataset) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistNistDataset."; + + // Create a QMNIST Nist Dataset + std::string folder_path = datasets_root_path_ + "/testQMnistData/"; + std::shared_ptr ds = QMnist(folder_path, "nist", true, std::make_shared(false, 5)); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + + EXPECT_NE(row.find("image"), row.end()); + EXPECT_NE(row.find("label"), row.end()); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); + ASSERT_OK(iter->GetNextRow(&row)); + } + + EXPECT_EQ(i, 5); + + // Manually terminate the pipeline + iter->Stop(); +} + +/// Feature: QMnistTestDataset +/// Description: Test basic usage of QMnistDataset with all dataset +/// Expectation: The data is processed successfully +TEST_F(MindDataTestPipeline, TestQMnistAllDataset) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistAllDataset."; + + // Create a QMNIST All Dataset + std::string folder_path = datasets_root_path_ + "/testQMnistData/"; + std::shared_ptr ds = QMnist(folder_path, "all", true, std::make_shared(false, 20)); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + + EXPECT_NE(row.find("image"), row.end()); + EXPECT_NE(row.find("label"), row.end()); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); + ASSERT_OK(iter->GetNextRow(&row)); + } + + EXPECT_EQ(i, 20); + + // Manually terminate the pipeline + iter->Stop(); +} + +/// Feature: QMnistTestDataset +/// Description: Test basic usage of QMnistDataset with all and compat dataset +/// Expectation: The data is processed successfully +TEST_F(MindDataTestPipeline, TestQMnistCompatDataset) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistCompatDataset."; + + // Create a QMNIST All Dataset + std::string folder_path = datasets_root_path_ + "/testQMnistData/"; + std::shared_ptr ds = QMnist(folder_path, "all", false, std::make_shared(false, 20)); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + + EXPECT_NE(row.find("image"), row.end()); + EXPECT_NE(row.find("label"), row.end()); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + auto label = row["label"]; + MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); + MS_LOG(INFO) << "Tensor label shape: " << label.Shape(); + ASSERT_OK(iter->GetNextRow(&row)); + } + + EXPECT_EQ(i, 20); + + // Manually terminate the pipeline + iter->Stop(); +} + +/// Feature: QMnistTestDataset +/// Description: Test usage of QMnistDataset with pipeline mode +/// Expectation: The data is processed successfully +TEST_F(MindDataTestPipeline, TestQMnistDatasetWithPipeline) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistTrainDatasetWithPipeline."; + + // Create two QMNIST Train Dataset + std::string folder_path = datasets_root_path_ + "/testQMnistData/"; + std::shared_ptr ds1 = QMnist(folder_path, "train", true, std::make_shared(false, 5)); + std::shared_ptr ds2 = QMnist(folder_path, "train", true, std::make_shared(false, 5)); + EXPECT_NE(ds1, nullptr); + EXPECT_NE(ds2, nullptr); + + // Create two Repeat operation on ds + int32_t repeat_num = 1; + ds1 = ds1->Repeat(repeat_num); + EXPECT_NE(ds1, nullptr); + repeat_num = 1; + ds2 = ds2->Repeat(repeat_num); + EXPECT_NE(ds2, nullptr); + + // Create two Project operation on ds + std::vector column_project = {"image", "label"}; + ds1 = ds1->Project(column_project); + EXPECT_NE(ds1, nullptr); + ds2 = ds2->Project(column_project); + EXPECT_NE(ds2, nullptr); + + // Create a Concat operation on the ds + ds1 = ds1->Concat({ds2}); + EXPECT_NE(ds1, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds1->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + + EXPECT_NE(row.find("image"), row.end()); + EXPECT_NE(row.find("label"), row.end()); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); + ASSERT_OK(iter->GetNextRow(&row)); + } + + EXPECT_EQ(i, 10); + + // Manually terminate the pipeline + iter->Stop(); +} + +/// Feature: QMnistTestDataset +/// Description: Test iterator of QMnistDataset with only the image column +/// Expectation: The data is processed successfully +TEST_F(MindDataTestPipeline, TestQMnistIteratorOneColumn) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistIteratorOneColumn."; + // Create a QMnist Dataset + std::string folder_path = datasets_root_path_ + "/testQMnistData/"; + std::shared_ptr ds = QMnist(folder_path, "train", true, std::make_shared(false, 5)); + EXPECT_NE(ds, nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 1; + ds = ds->Batch(batch_size); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // Only select "image" column and drop others + std::vector columns = {"image"}; + std::shared_ptr project_ds = ds->Project(columns); + std::shared_ptr iter = project_ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::vector row; + ASSERT_OK(iter->GetNextRow(&row)); + std::vector expect_image = {1, 28, 28, 1}; + + uint64_t i = 0; + while (row.size() != 0) { + for (auto &v : row) { + MS_LOG(INFO) << "image shape:" << v.Shape(); + EXPECT_EQ(expect_image, v.Shape()); + } + ASSERT_OK(iter->GetNextRow(&row)); + i++; + } + + EXPECT_EQ(i, 5); + + // Manually terminate the pipeline + iter->Stop(); +} + +/// Feature: QMnistTestDataset +/// Description: Test iterator of QMnistDataset with wrong column +/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr +TEST_F(MindDataTestPipeline, TestQMnistIteratorWrongColumn) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistIteratorWrongColumn."; + // Create a QMnist Dataset + std::string folder_path = datasets_root_path_ + "/testQMnistData/"; + std::shared_ptr ds = QMnist(folder_path, "train", true, std::make_shared(false, 5)); + EXPECT_NE(ds, nullptr); + + // Pass wrong column name + std::vector columns = {"digital"}; + std::shared_ptr project_ds = ds->Project(columns); + std::shared_ptr iter = project_ds->CreateIterator(); + EXPECT_EQ(iter, nullptr); +} + +/// Feature: QMnistTestDataset +/// Description: Test QMnistDataset GetDatasetSize +/// Expectation: Output is equal to the expected output +TEST_F(MindDataTestPipeline, TestGetQMnistDatasetSize) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetQMnistTrainDatasetSize."; + + // Create a QMNIST Train Dataset + std::string folder_path = datasets_root_path_ + "/testQMnistData/"; + std::shared_ptr ds = QMnist(folder_path, "train", true); + EXPECT_NE(ds, nullptr); + + EXPECT_EQ(ds->GetDatasetSize(), 10); +} + +/// Feature: QMnistTestDataset +/// Description: Test QMnistDataset Getters method +/// Expectation: Output is equal to the expected output +TEST_F(MindDataTestPipeline, TestQMnistDatasetGetters) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistTrainDatasetGetters."; + + // Create a QMNIST Train Dataset + std::string folder_path = datasets_root_path_ + "/testQMnistData/"; + std::shared_ptr ds = QMnist(folder_path, "train", true); + EXPECT_NE(ds, nullptr); + + EXPECT_EQ(ds->GetDatasetSize(), 10); + std::vector types = ToDETypes(ds->GetOutputTypes()); + std::vector shapes = ToTensorShapeVec(ds->GetOutputShapes()); + std::vector column_names = {"image", "label"}; + int64_t num_classes = ds->GetNumClasses(); + EXPECT_EQ(types.size(), 2); + EXPECT_EQ(types[0].ToString(), "uint8"); + EXPECT_EQ(types[1].ToString(), "uint32"); + EXPECT_EQ(shapes.size(), 2); + EXPECT_EQ(shapes[0].ToString(), "<28,28,1>"); + EXPECT_EQ(shapes[1].ToString(), "<>"); + EXPECT_EQ(num_classes, -1); + EXPECT_EQ(ds->GetBatchSize(), 1); + EXPECT_EQ(ds->GetRepeatCount(), 1); + + EXPECT_EQ(ds->GetDatasetSize(), 10); + EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types); + EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes); + EXPECT_EQ(ds->GetNumClasses(), -1); + + EXPECT_EQ(ds->GetColumnNames(), column_names); + EXPECT_EQ(ds->GetDatasetSize(), 10); + EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types); + EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes); + EXPECT_EQ(ds->GetBatchSize(), 1); + EXPECT_EQ(ds->GetRepeatCount(), 1); + EXPECT_EQ(ds->GetNumClasses(), -1); + EXPECT_EQ(ds->GetDatasetSize(), 10); +} + +/// Feature: QMnistTestDataset +/// Description: Test QMnistDataset with invalid folder path input +/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr +TEST_F(MindDataTestPipeline, TestQMnistDataFail) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistDataFail."; + + // Create a QMNIST Dataset + std::shared_ptr ds = QMnist("", "train", true, std::make_shared(false, 5)); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + std::shared_ptr iter = ds->CreateIterator(); + // Expect failure: invalid QMNIST input + EXPECT_EQ(iter, nullptr); +} + +/// Feature: QMnistTestDataset +/// Description: Test QMnistDataset with invalid usage +/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr +TEST_F(MindDataTestPipeline, TestQMnistDataWithInvalidUsageFail) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistDataWithInvalidUsageFail."; + + // Create a QMNIST Dataset + std::string folder_path = datasets_root_path_ + "/testQMnistData/"; + std::shared_ptr ds = QMnist(folder_path, "validation", true); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + std::shared_ptr iter = ds->CreateIterator(); + // Expect failure: invalid QMNIST input, validation is not a valid usage + EXPECT_EQ(iter, nullptr); +} + +/// Feature: QMnistTestDataset +/// Description: Test QMnistDataset with null sampler +/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr +TEST_F(MindDataTestPipeline, TestQMnistDataWithNullSamplerFail) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestQMnistDataWithNullSamplerFail."; + + // Create a QMNIST Dataset + std::string folder_path = datasets_root_path_ + "/testQMnistData/"; + std::shared_ptr ds = QMnist(folder_path, "train", true, nullptr); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + std::shared_ptr iter = ds->CreateIterator(); + // Expect failure: invalid QMNIST input, sampler cannot be nullptr + EXPECT_EQ(iter, nullptr); +} diff --git a/tests/ut/cpp/dataset/c_api_dataset_sbu_test.cc b/tests/ut/cpp/dataset/c_api_dataset_sbu_test.cc index b8d3322e261..5933f657beb 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_sbu_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_sbu_test.cc @@ -1,264 +1,264 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "common/common.h" - -#include "minddata/dataset/include/dataset/datasets.h" - -using namespace mindspore::dataset; -using mindspore::dataset::DataType; -using mindspore::dataset::Tensor; -using mindspore::dataset::TensorShape; - -class MindDataTestPipeline : public UT::DatasetOpTesting { - protected: -}; - -/// Feature: SBUDataset -/// Description: Test basic usage of SBUDataset -/// Expectation: The data is processed successfully -TEST_F(MindDataTestPipeline, TestSBUDataset) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSBUDataset."; - - // Create a SBU Dataset - std::string folder_path = datasets_root_path_ + "/testSBUDataset/"; - std::shared_ptr ds = SBU(folder_path, true, std::make_shared(false, 5)); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - - EXPECT_NE(row.find("image"), row.end()); - EXPECT_NE(row.find("caption"), row.end()); - - uint64_t i = 0; - while (row.size() != 0) { - i++; - auto image = row["image"]; - MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); - ASSERT_OK(iter->GetNextRow(&row)); - } - - EXPECT_EQ(i, 5); - - // Manually terminate the pipeline - iter->Stop(); -} - -/// Feature: SBUDataset -/// Description: Test SBUDataset with pipeline mode -/// Expectation: The data is processed successfully -TEST_F(MindDataTestPipeline, TestSBUDatasetWithPipeline) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSBUDatasetWithPipeline."; - - // Create two SBU Dataset - std::string folder_path = datasets_root_path_ + "/testSBUDataset/"; - std::shared_ptr ds1 = SBU(folder_path, true, std::make_shared(false, 5)); - std::shared_ptr ds2 = SBU(folder_path, true, std::make_shared(false, 5)); - EXPECT_NE(ds1, nullptr); - EXPECT_NE(ds2, nullptr); - - // Create two Repeat operation on ds - int32_t repeat_num = 1; - ds1 = ds1->Repeat(repeat_num); - EXPECT_NE(ds1, nullptr); - repeat_num = 1; - ds2 = ds2->Repeat(repeat_num); - EXPECT_NE(ds2, nullptr); - - // Create two Project operation on ds - std::vector column_project = {"image", "caption"}; - ds1 = ds1->Project(column_project); - EXPECT_NE(ds1, nullptr); - ds2 = ds2->Project(column_project); - EXPECT_NE(ds2, nullptr); - - // Create a Concat operation on the ds - ds1 = ds1->Concat({ds2}); - EXPECT_NE(ds1, nullptr); - - // Create an iterator over the result of the above dataset - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds1->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - - EXPECT_NE(row.find("image"), row.end()); - EXPECT_NE(row.find("caption"), row.end()); - - uint64_t i = 0; - while (row.size() != 0) { - i++; - auto image = row["image"]; - MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); - ASSERT_OK(iter->GetNextRow(&row)); - } - - EXPECT_EQ(i, 10); - - // Manually terminate the pipeline - iter->Stop(); -} - -/// Feature: SBUDataset -/// Description: Test iterator of SBUDataset with only the image column -/// Expectation: The data is processed successfully -TEST_F(MindDataTestPipeline, TestSBUIteratorOneColumn) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSBUIteratorOneColumn."; - // Create a SBU Dataset - std::string folder_path = datasets_root_path_ + "/testSBUDataset/"; - std::shared_ptr ds = SBU(folder_path, true, std::make_shared(false, 5)); - EXPECT_NE(ds, nullptr); - - // Create a Batch operation on ds - int32_t batch_size = 1; - ds = ds->Batch(batch_size); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset - // Only select "image" column and drop others - std::vector columns = {"image"}; - std::shared_ptr project_ds = ds->Project(columns); - std::shared_ptr iter = project_ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row - std::vector row; - ASSERT_OK(iter->GetNextRow(&row)); - - uint64_t i = 0; - while (row.size() != 0) { - for (auto &v : row) { - MS_LOG(INFO) << "image shape:" << v.Shape(); - } - ASSERT_OK(iter->GetNextRow(&row)); - i++; - } - - EXPECT_EQ(i, 5); - - // Manually terminate the pipeline - iter->Stop(); -} - -/// Feature: SBUDataset -/// Description: Test iterator of SBUDataset with wrong column -/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr -TEST_F(MindDataTestPipeline, TestSBUIteratorWrongColumn) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSBUIteratorWrongColumn."; - // Create a SBU Dataset - std::string folder_path = datasets_root_path_ + "/testSBUDataset/"; - std::shared_ptr ds = SBU(folder_path, true, std::make_shared(false, 5)); - EXPECT_NE(ds, nullptr); - - // Pass wrong column name - std::vector columns = {"digital"}; - std::shared_ptr project_ds = ds->Project(columns); - std::shared_ptr iter = project_ds->CreateIterator(); - EXPECT_EQ(iter, nullptr); -} - -/// Feature: SBUDataset -/// Description: Test SBUDataset GetDatasetSize -/// Expectation: Output is equal to the expected output -TEST_F(MindDataTestPipeline, TestGetSBUDatasetSize) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetSBUDatasetSize."; - - // Create a SBU Dataset - std::string folder_path = datasets_root_path_ + "/testSBUDataset/"; - std::shared_ptr ds = SBU(folder_path, true); - EXPECT_NE(ds, nullptr); - - EXPECT_EQ(ds->GetDatasetSize(), 5); -} - -/// Feature: SBUDataset -/// Description: Test SBUDataset Getters method -/// Expectation: Output is equal to the expected output -TEST_F(MindDataTestPipeline, TestSBUDatasetGetters) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSBUDatasetGetters."; - - // Create a SBU Dataset - std::string folder_path = datasets_root_path_ + "/testSBUDataset/"; - std::shared_ptr ds = SBU(folder_path, true); - EXPECT_NE(ds, nullptr); - - EXPECT_EQ(ds->GetDatasetSize(), 5); - std::vector types = ToDETypes(ds->GetOutputTypes()); - std::vector shapes = ToTensorShapeVec(ds->GetOutputShapes()); - std::vector column_names = {"image", "caption"}; - EXPECT_EQ(types.size(), 2); - EXPECT_EQ(types[0].ToString(), "uint8"); - EXPECT_EQ(types[1].ToString(), "string"); - - EXPECT_EQ(ds->GetBatchSize(), 1); - EXPECT_EQ(ds->GetRepeatCount(), 1); - - EXPECT_EQ(ds->GetDatasetSize(), 5); - EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types); - EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes); - EXPECT_EQ(ds->GetNumClasses(), -1); - - EXPECT_EQ(ds->GetColumnNames(), column_names); - EXPECT_EQ(ds->GetDatasetSize(), 5); - EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types); - EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes); - EXPECT_EQ(ds->GetBatchSize(), 1); - EXPECT_EQ(ds->GetRepeatCount(), 1); - EXPECT_EQ(ds->GetNumClasses(), -1); - EXPECT_EQ(ds->GetDatasetSize(), 5); -} - -/// Feature: SBUDataset -/// Description: Test SBUDataset with invalid folder path input -/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr -TEST_F(MindDataTestPipeline, TestSBUDatasetFail) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSBUDatasetFail."; - - // Create a SBU Dataset - std::shared_ptr ds = SBU("", true, std::make_shared(false, 10)); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset - std::shared_ptr iter = ds->CreateIterator(); - // Expect failure: invalid SBU input - EXPECT_EQ(iter, nullptr); -} - -/// Feature: SBUDataset -/// Description: Test SBUDataset with null sampler -/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr -TEST_F(MindDataTestPipeline, TestSBUDatasetWithNullSamplerFail) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSBUDatasetWithNullSamplerFail."; - - // Create a SBU Dataset - std::string folder_path = datasets_root_path_ + "/testSBUDataset/"; - std::shared_ptr ds = SBU(folder_path, true, nullptr); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset - std::shared_ptr iter = ds->CreateIterator(); - // Expect failure: invalid SBU input, sampler cannot be nullptr - EXPECT_EQ(iter, nullptr); -} +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/common.h" + +#include "minddata/dataset/include/dataset/datasets.h" + +using namespace mindspore::dataset; +using mindspore::dataset::DataType; +using mindspore::dataset::Tensor; +using mindspore::dataset::TensorShape; + +class MindDataTestPipeline : public UT::DatasetOpTesting { + protected: +}; + +/// Feature: SBUDataset +/// Description: Test basic usage of SBUDataset +/// Expectation: The data is processed successfully +TEST_F(MindDataTestPipeline, TestSBUDataset) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSBUDataset."; + + // Create a SBU Dataset + std::string folder_path = datasets_root_path_ + "/testSBUDataset/"; + std::shared_ptr ds = SBU(folder_path, true, std::make_shared(false, 5)); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + + EXPECT_NE(row.find("image"), row.end()); + EXPECT_NE(row.find("caption"), row.end()); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); + ASSERT_OK(iter->GetNextRow(&row)); + } + + EXPECT_EQ(i, 5); + + // Manually terminate the pipeline + iter->Stop(); +} + +/// Feature: SBUDataset +/// Description: Test SBUDataset with pipeline mode +/// Expectation: The data is processed successfully +TEST_F(MindDataTestPipeline, TestSBUDatasetWithPipeline) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSBUDatasetWithPipeline."; + + // Create two SBU Dataset + std::string folder_path = datasets_root_path_ + "/testSBUDataset/"; + std::shared_ptr ds1 = SBU(folder_path, true, std::make_shared(false, 5)); + std::shared_ptr ds2 = SBU(folder_path, true, std::make_shared(false, 5)); + EXPECT_NE(ds1, nullptr); + EXPECT_NE(ds2, nullptr); + + // Create two Repeat operation on ds + int32_t repeat_num = 1; + ds1 = ds1->Repeat(repeat_num); + EXPECT_NE(ds1, nullptr); + repeat_num = 1; + ds2 = ds2->Repeat(repeat_num); + EXPECT_NE(ds2, nullptr); + + // Create two Project operation on ds + std::vector column_project = {"image", "caption"}; + ds1 = ds1->Project(column_project); + EXPECT_NE(ds1, nullptr); + ds2 = ds2->Project(column_project); + EXPECT_NE(ds2, nullptr); + + // Create a Concat operation on the ds + ds1 = ds1->Concat({ds2}); + EXPECT_NE(ds1, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds1->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + + EXPECT_NE(row.find("image"), row.end()); + EXPECT_NE(row.find("caption"), row.end()); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); + ASSERT_OK(iter->GetNextRow(&row)); + } + + EXPECT_EQ(i, 10); + + // Manually terminate the pipeline + iter->Stop(); +} + +/// Feature: SBUDataset +/// Description: Test iterator of SBUDataset with only the image column +/// Expectation: The data is processed successfully +TEST_F(MindDataTestPipeline, TestSBUIteratorOneColumn) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSBUIteratorOneColumn."; + // Create a SBU Dataset + std::string folder_path = datasets_root_path_ + "/testSBUDataset/"; + std::shared_ptr ds = SBU(folder_path, true, std::make_shared(false, 5)); + EXPECT_NE(ds, nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 1; + ds = ds->Batch(batch_size); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // Only select "image" column and drop others + std::vector columns = {"image"}; + std::shared_ptr project_ds = ds->Project(columns); + std::shared_ptr iter = project_ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::vector row; + ASSERT_OK(iter->GetNextRow(&row)); + + uint64_t i = 0; + while (row.size() != 0) { + for (auto &v : row) { + MS_LOG(INFO) << "image shape:" << v.Shape(); + } + ASSERT_OK(iter->GetNextRow(&row)); + i++; + } + + EXPECT_EQ(i, 5); + + // Manually terminate the pipeline + iter->Stop(); +} + +/// Feature: SBUDataset +/// Description: Test iterator of SBUDataset with wrong column +/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr +TEST_F(MindDataTestPipeline, TestSBUIteratorWrongColumn) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSBUIteratorWrongColumn."; + // Create a SBU Dataset + std::string folder_path = datasets_root_path_ + "/testSBUDataset/"; + std::shared_ptr ds = SBU(folder_path, true, std::make_shared(false, 5)); + EXPECT_NE(ds, nullptr); + + // Pass wrong column name + std::vector columns = {"digital"}; + std::shared_ptr project_ds = ds->Project(columns); + std::shared_ptr iter = project_ds->CreateIterator(); + EXPECT_EQ(iter, nullptr); +} + +/// Feature: SBUDataset +/// Description: Test SBUDataset GetDatasetSize +/// Expectation: Output is equal to the expected output +TEST_F(MindDataTestPipeline, TestGetSBUDatasetSize) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetSBUDatasetSize."; + + // Create a SBU Dataset + std::string folder_path = datasets_root_path_ + "/testSBUDataset/"; + std::shared_ptr ds = SBU(folder_path, true); + EXPECT_NE(ds, nullptr); + + EXPECT_EQ(ds->GetDatasetSize(), 5); +} + +/// Feature: SBUDataset +/// Description: Test SBUDataset Getters method +/// Expectation: Output is equal to the expected output +TEST_F(MindDataTestPipeline, TestSBUDatasetGetters) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSBUDatasetGetters."; + + // Create a SBU Dataset + std::string folder_path = datasets_root_path_ + "/testSBUDataset/"; + std::shared_ptr ds = SBU(folder_path, true); + EXPECT_NE(ds, nullptr); + + EXPECT_EQ(ds->GetDatasetSize(), 5); + std::vector types = ToDETypes(ds->GetOutputTypes()); + std::vector shapes = ToTensorShapeVec(ds->GetOutputShapes()); + std::vector column_names = {"image", "caption"}; + EXPECT_EQ(types.size(), 2); + EXPECT_EQ(types[0].ToString(), "uint8"); + EXPECT_EQ(types[1].ToString(), "string"); + + EXPECT_EQ(ds->GetBatchSize(), 1); + EXPECT_EQ(ds->GetRepeatCount(), 1); + + EXPECT_EQ(ds->GetDatasetSize(), 5); + EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types); + EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes); + EXPECT_EQ(ds->GetNumClasses(), -1); + + EXPECT_EQ(ds->GetColumnNames(), column_names); + EXPECT_EQ(ds->GetDatasetSize(), 5); + EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types); + EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes); + EXPECT_EQ(ds->GetBatchSize(), 1); + EXPECT_EQ(ds->GetRepeatCount(), 1); + EXPECT_EQ(ds->GetNumClasses(), -1); + EXPECT_EQ(ds->GetDatasetSize(), 5); +} + +/// Feature: SBUDataset +/// Description: Test SBUDataset with invalid folder path input +/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr +TEST_F(MindDataTestPipeline, TestSBUDatasetFail) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSBUDatasetFail."; + + // Create a SBU Dataset + std::shared_ptr ds = SBU("", true, std::make_shared(false, 10)); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + std::shared_ptr iter = ds->CreateIterator(); + // Expect failure: invalid SBU input + EXPECT_EQ(iter, nullptr); +} + +/// Feature: SBUDataset +/// Description: Test SBUDataset with null sampler +/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr +TEST_F(MindDataTestPipeline, TestSBUDatasetWithNullSamplerFail) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSBUDatasetWithNullSamplerFail."; + + // Create a SBU Dataset + std::string folder_path = datasets_root_path_ + "/testSBUDataset/"; + std::shared_ptr ds = SBU(folder_path, true, nullptr); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + std::shared_ptr iter = ds->CreateIterator(); + // Expect failure: invalid SBU input, sampler cannot be nullptr + EXPECT_EQ(iter, nullptr); +} diff --git a/tests/ut/cpp/dataset/c_api_dataset_sogou_news_test.cc b/tests/ut/cpp/dataset/c_api_dataset_sogou_news_test.cc index e50e516d382..ea26220a875 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_sogou_news_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_sogou_news_test.cc @@ -1,459 +1,459 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "common/common.h" -#include "minddata/dataset/core/global_context.h" -#include "minddata/dataset/include/dataset/datasets.h" - -using namespace mindspore::dataset; - -class MindDataTestPipeline : public UT::DatasetOpTesting { -protected: -}; - -/// Feature: Test SogouNews Dataset. -/// Description: Read SogouNewsDataset data and get data. -/// Expectation: The data is processed successfully. -TEST_F(MindDataTestPipeline, TestSogouNewsDatasetBasic) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSogouNewsDatasetBasic."; - - std::string dataset_dir = datasets_root_path_ + "/testSogouNews/"; - std::vector column_names = {"index", "title", "content"}; - - std::shared_ptr ds = SogouNews(dataset_dir, "test", 0, ShuffleMode::kFalse); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - EXPECT_NE(row.find("index"), row.end()); - std::vector> expected_result = { - {"1","Make history","Su Bingtian's 100m breakthrough\\n 9.83"}, - {"4","Tesla price","Tesla reduced its price by 70000 yuan"}, - {"1","Opening ceremony of the 14th National Games","On the evening of September 15, Beijing time, " - "the 14th games of the people's Republic of China opened in Xi'an Olympic Sports Center Stadium, " - "Shaanxi Province. Yang Qian, the first gold medalist of the Chinese delegation in the Tokyo Olympic" - " Games and a Post-00 shooter, lit the main torch platform. From then on, to September 27, the 14th" - " National Games flame will burn here for 12 days."} - }; - - uint64_t i = 0; - while (row.size() != 0) { - for (int j = 0; j < column_names.size(); j++) { - auto text = row[column_names[j]]; - std::shared_ptr de_text; - ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); - std::string_view sv; - ASSERT_OK(de_text->GetItemAt(&sv, {})); - std::string ss(sv); - EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); - } - ASSERT_OK(iter->GetNextRow(&row)); - i++; - } - - // Expect 3 samples - EXPECT_EQ(i, 3); - - // Manually terminate the pipeline - iter->Stop(); -} - -/// Feature: Test SogouNews Dataset(usage=all). -/// Description: Read train data and test data. -/// Expectation: The data is processed successfully. -TEST_F(MindDataTestPipeline, TestSogouNewsDatasetUsageAll) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSogouNewsDatasetUsageAll."; - - std::string dataset_dir = datasets_root_path_ + "/testSogouNews/"; - std::vector column_names = {"index", "title", "content"}; - - std::shared_ptr ds = SogouNews(dataset_dir, "all" , 0, ShuffleMode::kFalse); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - EXPECT_NE(row.find("index"), row.end()); - std::vector> expected_result = { - {"1","Jefferson commented on thick eyebrow: he has the top five talents in the league, but he is not the" - " top five","They say he has the talent of the top five in the league. The talent of the top five in the" - " league is one of the most disrespectful statements. I say he has the talent of the top five in the league," - " but he is not the top five players because the top five players play every night."}, - {"1","Make history","Su Bingtian's 100m breakthrough\\n 9.83"}, - {"3","Group pictures: Liu Shishi's temperament in early autumn released a large piece of micro curly long" - " hair, elegant, lazy, gentle and capable","Liu Shishi's latest group of cover magazine blockbusters are" - " released. In the photos, Liu Shishi's long hair is slightly curly, or camel colored belted woolen coat," - " or plaid suit, which is gentle and elegant and beautiful to a new height."}, - {"4","Tesla price","Tesla reduced its price by 70000 yuan"}, - {"3","Ni Ni deduces elegant retro style in different styles","Ni Ni's latest group of magazine cover" - " blockbusters released that wearing gift hats is cool, retro, unique and full of fashion expression."}, - {"1","Opening ceremony of the 14th National Games","On the evening of September 15, Beijing time, " - "the 14th games of the people's Republic of China opened in Xi'an Olympic Sports Center Stadium, " - "Shaanxi Province. Yang Qian, the first gold medalist of the Chinese delegation in the Tokyo Olympic" - " Games and a Post-00 shooter, lit the main torch platform. From then on, to September 27, the 14th" - " National Games flame will burn here for 12 days."} - }; - - uint64_t i = 0; - while (row.size() != 0) { - for (int j = 0; j < column_names.size(); j++) { - auto text = row[column_names[j]]; - std::shared_ptr de_text; - ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); - std::string_view sv; - ASSERT_OK(de_text->GetItemAt(&sv, {})); - std::string ss(sv); - EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); - } - ASSERT_OK(iter->GetNextRow(&row)); - i++; - } - - // Expect 6 samples - EXPECT_EQ(i, 6); - - // Manually terminate the pipeline - iter->Stop(); -} - -/// Feature: Test Getters. -/// Description: Includes tests for shape, type, size. -/// Expectation: The data is processed successfully. -TEST_F(MindDataTestPipeline, TestSogouNewsGetters) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSogouNewsGetters."; - - std::string dataset_dir = datasets_root_path_ + "/testSogouNews/"; - std::shared_ptr ds = SogouNews(dataset_dir, "test", 0, ShuffleMode::kFalse); - std::vector column_names = {"index", "title", "content"}; - EXPECT_NE(ds, nullptr); - - EXPECT_EQ(ds-> GetDatasetSize(),3); - EXPECT_EQ(ds->GetColumnNames(),column_names); -} - -/// Feature: Test SogouNews Dataset(num_samples = 3). -/// Description: Test whether the interface meets expectations when NumSamples is equal to 3. -/// Expectation: The data is processed successfully. -TEST_F(MindDataTestPipeline, TestSogouNewsNumSamples) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSogouNewsNumSamples."; - - std::string dataset_dir = datasets_root_path_ + "/testSogouNews/"; - std::vector column_names = {"index", "title", "content"}; - - std::shared_ptr ds = SogouNews(dataset_dir, "test", 3, ShuffleMode::kFalse); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - EXPECT_NE(row.find("index"), row.end()); - std::vector> expected_result = { - {"1","Make history","Su Bingtian's 100m breakthrough\\n 9.83"}, - {"4","Tesla price","Tesla reduced its price by 70000 yuan"}, - {"1","Opening ceremony of the 14th National Games","On the evening of September 15, Beijing time, " - "the 14th games of the people's Republic of China opened in Xi'an Olympic Sports Center Stadium, " - "Shaanxi Province. Yang Qian, the first gold medalist of the Chinese delegation in the Tokyo Olympic" - " Games and a Post-00 shooter, lit the main torch platform. From then on, to September 27, the 14th" - " National Games flame will burn here for 12 days."} - }; - - uint64_t i = 0; - while (row.size() != 0) { - for (int j = 0; j < column_names.size(); j++) { - auto text = row[column_names[j]]; - std::shared_ptr de_text; - ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); - std::string_view sv; - ASSERT_OK(de_text->GetItemAt(&sv, {})); - std::string ss(sv); - EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); - } - ASSERT_OK(iter->GetNextRow(&row)); - i++; - } - - // Expect 3 samples - EXPECT_EQ(i, 3); - - // Manually terminate the pipeline - iter->Stop(); -} - -/// Feature: Test SogouNewsDataset in distribution. -/// Description: Test interface in a distributed state. -/// Expectation: The data is processed successfully. -TEST_F(MindDataTestPipeline, TestSogouNewsDatasetDistribution) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSogouNewsDatasetDistribution."; - - std::string dataset_dir = datasets_root_path_ + "/testSogouNews/"; - std::vector column_names = {"index", "title", "content"}; - - std::shared_ptr ds = SogouNews(dataset_dir, "test", 0, ShuffleMode::kFalse, 2, 0); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - EXPECT_NE(row.find("index"), row.end()); - std::vector> expected_result = { - {"1","Make history","Su Bingtian's 100m breakthrough\\n 9.83"}, - {"4","Tesla price","Tesla reduced its price by 70000 yuan"}, - {"1","Opening ceremony of the 14th National Games","On the evening of September 15, Beijing time, " - "the 14th games of the people's Republic of China opened in Xi'an Olympic Sports Center Stadium, " - "Shaanxi Province. Yang Qian, the first gold medalist of the Chinese delegation in the Tokyo Olympic" - " Games and a Post-00 shooter, lit the main torch platform. From then on, to September 27, the 14th" - " National Games flame will burn here for 12 days."} - }; - - uint64_t i = 0; - while (row.size() != 0) { - for (int j = 0; j < column_names.size(); j++) { - auto text = row[column_names[j]]; - std::shared_ptr de_text; - ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); - std::string_view sv; - ASSERT_OK(de_text->GetItemAt(&sv, {})); - std::string ss(sv); - EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); - } - ASSERT_OK(iter->GetNextRow(&row)); - i++; - } - - // Expect 2 samples - EXPECT_EQ(i, 2); - - // Manually terminate the pipeline - iter->Stop(); -} - -/// Feature: Error Test. -/// Description: Test the wrong input. -/// Expectation: Unable to read in data. -TEST_F(MindDataTestPipeline, TestSogouNewsDatasetFail) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSogouNewsDatasetFail."; - - std::string dataset_dir = datasets_root_path_ + "/testSogouNews/"; - std::string invalid_csv_file = "./NotExistFile"; - std::vector column_names = {"index", "title", "content"}; - - std::shared_ptr ds0 = SogouNews("", "test", 0); - EXPECT_NE(ds0, nullptr); - // Create an iterator over the result of the above dataset - std::shared_ptr iter0 = ds0->CreateIterator(); - // Expect failure: invalid SogouNews input - EXPECT_EQ(iter0, nullptr); - - // Create a SogouNews Dataset with invalid usage - std::shared_ptr ds1 = SogouNews(invalid_csv_file); - EXPECT_NE(ds1, nullptr); - - // Create an iterator over the result of the above dataset - std::shared_ptr iter1 = ds1->CreateIterator(); - // Expect failure: invalid SogouNews input - EXPECT_EQ(iter1, nullptr); - - // Test invalid num_samples < -1 - std::shared_ptr ds2 = SogouNews(dataset_dir, "test", -1, ShuffleMode::kFalse); - EXPECT_NE(ds2, nullptr); - // Create an iterator over the result of the above dataset - std::shared_ptr iter2 = ds2->CreateIterator(); - // Expect failure: invalid SogouNews input - EXPECT_EQ(iter2, nullptr); - - // Test invalid num_shards < 1 - std::shared_ptr ds3 = SogouNews(dataset_dir, "test", 0, ShuffleMode::kFalse, 0); - EXPECT_NE(ds3, nullptr); - // Create an iterator over the result of the above dataset - std::shared_ptr iter3 = ds3->CreateIterator(); - // Expect failure: invalid SogouNews input - EXPECT_EQ(iter3, nullptr); - - // Test invalid shard_id >= num_shards - std::shared_ptr ds4 = SogouNews(dataset_dir, "test", 0, ShuffleMode::kFalse, 2, 2); - EXPECT_NE(ds4, nullptr); - // Create an iterator over the result of the above dataset - std::shared_ptr iter4 = ds4->CreateIterator(); - // Expect failure: invalid SogouNews input - EXPECT_EQ(iter4, nullptr); -} - -/// Feature: Test SogouNews Dataset(ShuffleMode=kFiles). -/// Description: Test SogouNews Dataset interface with different ShuffleMode. -/// Expectation: The data is processed successfully. -TEST_F(MindDataTestPipeline, TestSogouNewsDatasetShuffleFilesA) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSogouNewsDatasetShuffleFilesA."; - - // Set configuration - uint32_t original_seed = GlobalContext::config_manager()->seed(); - uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); - MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; - GlobalContext::config_manager()->set_seed(130); - GlobalContext::config_manager()->set_num_parallel_workers(4); - - std::string dataset_dir = datasets_root_path_ + "/testSogouNews/"; - std::vector column_names = {"index", "title", "content"}; - - std::shared_ptr ds = SogouNews(dataset_dir, "all" , 0, ShuffleMode::kFiles); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - EXPECT_NE(row.find("index"), row.end()); - std::vector> expected_result = { - {"1","Make history","Su Bingtian's 100m breakthrough\\n 9.83"}, - {"1","Jefferson commented on thick eyebrow: he has the top five talents in the league, but he is not the" - " top five","They say he has the talent of the top five in the league. The talent of the top five in the" - " league is one of the most disrespectful statements. I say he has the talent of the top five in the league," - " but he is not the top five players because the top five players play every night."}, - {"4","Tesla price","Tesla reduced its price by 70000 yuan"}, - {"3","Group pictures: Liu Shishi's temperament in early autumn released a large piece of micro curly long" - " hair, elegant, lazy, gentle and capable","Liu Shishi's latest group of cover magazine blockbusters are" - " released. In the photos, Liu Shishi's long hair is slightly curly, or camel colored belted woolen coat," - " or plaid suit, which is gentle and elegant and beautiful to a new height."}, - {"1","Opening ceremony of the 14th National Games","On the evening of September 15, Beijing time, " - "the 14th games of the people's Republic of China opened in Xi'an Olympic Sports Center Stadium, " - "Shaanxi Province. Yang Qian, the first gold medalist of the Chinese delegation in the Tokyo Olympic" - " Games and a Post-00 shooter, lit the main torch platform. From then on, to September 27, the 14th" - " National Games flame will burn here for 12 days."}, - {"3","Ni Ni deduces elegant retro style in different styles","Ni Ni's latest group of magazine cover" - " blockbusters released that wearing gift hats is cool, retro, unique and full of fashion expression."} - }; - - uint64_t i = 0; - while (row.size() != 0) { - for (int j = 0; j < column_names.size(); j++) { - auto text = row[column_names[j]]; - std::shared_ptr de_text; - ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); - std::string_view sv; - ASSERT_OK(de_text->GetItemAt(&sv, {})); - std::string ss(sv); - EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); - } - ASSERT_OK(iter->GetNextRow(&row)); - i++; - } - // Expect 6 samples - EXPECT_EQ(i, 6); - - // Manually terminate the pipeline - iter->Stop(); - - // Restore configuration - GlobalContext::config_manager()->set_seed(original_seed); - GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); -} - -/// Feature: Test SogouNews Dataset(ShuffleMode=kGlobal). -/// Description: Test SogouNews Dataset interface with different ShuffleMode. -/// Expectation: The data is processed successfully. -TEST_F(MindDataTestPipeline, TestSogouNewsDatasetShuffleFilesGlobal) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSogouNewsDatasetShuffleFilesGlobal."; - - // Set configuration - uint32_t original_seed = GlobalContext::config_manager()->seed(); - uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); - MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; - GlobalContext::config_manager()->set_seed(130); - GlobalContext::config_manager()->set_num_parallel_workers(4); - - std::string dataset_dir = datasets_root_path_ + "/testSogouNews/"; - std::vector column_names = {"index", "title", "content"}; - - std::shared_ptr ds = SogouNews(dataset_dir, "all" , 0, ShuffleMode::kGlobal); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - EXPECT_NE(row.find("index"), row.end()); - std::vector> expected_result = { - {"1","Make history","Su Bingtian's 100m breakthrough\\n 9.83"}, - {"1","Jefferson commented on thick eyebrow: he has the top five talents in the league, but he is not the" - " top five","They say he has the talent of the top five in the league. The talent of the top five in the" - " league is one of the most disrespectful statements. I say he has the talent of the top five in the league," - " but he is not the top five players because the top five players play every night."}, - {"4","Tesla price","Tesla reduced its price by 70000 yuan"}, - {"3","Ni Ni deduces elegant retro style in different styles","Ni Ni's latest group of magazine cover" - " blockbusters released that wearing gift hats is cool, retro, unique and full of fashion expression."}, - {"3","Group pictures: Liu Shishi's temperament in early autumn released a large piece of micro curly long" - " hair, elegant, lazy, gentle and capable","Liu Shishi's latest group of cover magazine blockbusters are" - " released. In the photos, Liu Shishi's long hair is slightly curly, or camel colored belted woolen coat," - " or plaid suit, which is gentle and elegant and beautiful to a new height."}, - {"1","Opening ceremony of the 14th National Games","On the evening of September 15, Beijing time, " - "the 14th games of the people's Republic of China opened in Xi'an Olympic Sports Center Stadium, " - "Shaanxi Province. Yang Qian, the first gold medalist of the Chinese delegation in the Tokyo Olympic" - " Games and a Post-00 shooter, lit the main torch platform. From then on, to September 27, the 14th" - " National Games flame will burn here for 12 days."} - }; - - uint64_t i = 0; - while (row.size() != 0) { - for (int j = 0; j < column_names.size(); j++) { - auto text = row[column_names[j]]; - std::shared_ptr de_text; - ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); - std::string_view sv; - ASSERT_OK(de_text->GetItemAt(&sv, {})); - std::string ss(sv); - EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); - } - ASSERT_OK(iter->GetNextRow(&row)); - i++; - } - // Expect 6 samples - EXPECT_EQ(i, 6); - - // Manually terminate the pipeline - iter->Stop(); - - // Restore configuration - GlobalContext::config_manager()->set_seed(original_seed); - GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/common.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/include/dataset/datasets.h" + +using namespace mindspore::dataset; + +class MindDataTestPipeline : public UT::DatasetOpTesting { +protected: +}; + +/// Feature: Test SogouNews Dataset. +/// Description: Read SogouNewsDataset data and get data. +/// Expectation: The data is processed successfully. +TEST_F(MindDataTestPipeline, TestSogouNewsDatasetBasic) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSogouNewsDatasetBasic."; + + std::string dataset_dir = datasets_root_path_ + "/testSogouNews/"; + std::vector column_names = {"index", "title", "content"}; + + std::shared_ptr ds = SogouNews(dataset_dir, "test", 0, ShuffleMode::kFalse); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + EXPECT_NE(row.find("index"), row.end()); + std::vector> expected_result = { + {"1","Make history","Su Bingtian's 100m breakthrough\\n 9.83"}, + {"4","Tesla price","Tesla reduced its price by 70000 yuan"}, + {"1","Opening ceremony of the 14th National Games","On the evening of September 15, Beijing time, " + "the 14th games of the people's Republic of China opened in Xi'an Olympic Sports Center Stadium, " + "Shaanxi Province. Yang Qian, the first gold medalist of the Chinese delegation in the Tokyo Olympic" + " Games and a Post-00 shooter, lit the main torch platform. From then on, to September 27, the 14th" + " National Games flame will burn here for 12 days."} + }; + + uint64_t i = 0; + while (row.size() != 0) { + for (int j = 0; j < column_names.size(); j++) { + auto text = row[column_names[j]]; + std::shared_ptr de_text; + ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); + std::string_view sv; + ASSERT_OK(de_text->GetItemAt(&sv, {})); + std::string ss(sv); + EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); + } + ASSERT_OK(iter->GetNextRow(&row)); + i++; + } + + // Expect 3 samples + EXPECT_EQ(i, 3); + + // Manually terminate the pipeline + iter->Stop(); +} + +/// Feature: Test SogouNews Dataset(usage=all). +/// Description: Read train data and test data. +/// Expectation: The data is processed successfully. +TEST_F(MindDataTestPipeline, TestSogouNewsDatasetUsageAll) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSogouNewsDatasetUsageAll."; + + std::string dataset_dir = datasets_root_path_ + "/testSogouNews/"; + std::vector column_names = {"index", "title", "content"}; + + std::shared_ptr ds = SogouNews(dataset_dir, "all" , 0, ShuffleMode::kFalse); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + EXPECT_NE(row.find("index"), row.end()); + std::vector> expected_result = { + {"1","Jefferson commented on thick eyebrow: he has the top five talents in the league, but he is not the" + " top five","They say he has the talent of the top five in the league. The talent of the top five in the" + " league is one of the most disrespectful statements. I say he has the talent of the top five in the league," + " but he is not the top five players because the top five players play every night."}, + {"1","Make history","Su Bingtian's 100m breakthrough\\n 9.83"}, + {"3","Group pictures: Liu Shishi's temperament in early autumn released a large piece of micro curly long" + " hair, elegant, lazy, gentle and capable","Liu Shishi's latest group of cover magazine blockbusters are" + " released. In the photos, Liu Shishi's long hair is slightly curly, or camel colored belted woolen coat," + " or plaid suit, which is gentle and elegant and beautiful to a new height."}, + {"4","Tesla price","Tesla reduced its price by 70000 yuan"}, + {"3","Ni Ni deduces elegant retro style in different styles","Ni Ni's latest group of magazine cover" + " blockbusters released that wearing gift hats is cool, retro, unique and full of fashion expression."}, + {"1","Opening ceremony of the 14th National Games","On the evening of September 15, Beijing time, " + "the 14th games of the people's Republic of China opened in Xi'an Olympic Sports Center Stadium, " + "Shaanxi Province. Yang Qian, the first gold medalist of the Chinese delegation in the Tokyo Olympic" + " Games and a Post-00 shooter, lit the main torch platform. From then on, to September 27, the 14th" + " National Games flame will burn here for 12 days."} + }; + + uint64_t i = 0; + while (row.size() != 0) { + for (int j = 0; j < column_names.size(); j++) { + auto text = row[column_names[j]]; + std::shared_ptr de_text; + ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); + std::string_view sv; + ASSERT_OK(de_text->GetItemAt(&sv, {})); + std::string ss(sv); + EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); + } + ASSERT_OK(iter->GetNextRow(&row)); + i++; + } + + // Expect 6 samples + EXPECT_EQ(i, 6); + + // Manually terminate the pipeline + iter->Stop(); +} + +/// Feature: Test Getters. +/// Description: Includes tests for shape, type, size. +/// Expectation: The data is processed successfully. +TEST_F(MindDataTestPipeline, TestSogouNewsGetters) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSogouNewsGetters."; + + std::string dataset_dir = datasets_root_path_ + "/testSogouNews/"; + std::shared_ptr ds = SogouNews(dataset_dir, "test", 0, ShuffleMode::kFalse); + std::vector column_names = {"index", "title", "content"}; + EXPECT_NE(ds, nullptr); + + EXPECT_EQ(ds-> GetDatasetSize(),3); + EXPECT_EQ(ds->GetColumnNames(),column_names); +} + +/// Feature: Test SogouNews Dataset(num_samples = 3). +/// Description: Test whether the interface meets expectations when NumSamples is equal to 3. +/// Expectation: The data is processed successfully. +TEST_F(MindDataTestPipeline, TestSogouNewsNumSamples) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSogouNewsNumSamples."; + + std::string dataset_dir = datasets_root_path_ + "/testSogouNews/"; + std::vector column_names = {"index", "title", "content"}; + + std::shared_ptr ds = SogouNews(dataset_dir, "test", 3, ShuffleMode::kFalse); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + EXPECT_NE(row.find("index"), row.end()); + std::vector> expected_result = { + {"1","Make history","Su Bingtian's 100m breakthrough\\n 9.83"}, + {"4","Tesla price","Tesla reduced its price by 70000 yuan"}, + {"1","Opening ceremony of the 14th National Games","On the evening of September 15, Beijing time, " + "the 14th games of the people's Republic of China opened in Xi'an Olympic Sports Center Stadium, " + "Shaanxi Province. Yang Qian, the first gold medalist of the Chinese delegation in the Tokyo Olympic" + " Games and a Post-00 shooter, lit the main torch platform. From then on, to September 27, the 14th" + " National Games flame will burn here for 12 days."} + }; + + uint64_t i = 0; + while (row.size() != 0) { + for (int j = 0; j < column_names.size(); j++) { + auto text = row[column_names[j]]; + std::shared_ptr de_text; + ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); + std::string_view sv; + ASSERT_OK(de_text->GetItemAt(&sv, {})); + std::string ss(sv); + EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); + } + ASSERT_OK(iter->GetNextRow(&row)); + i++; + } + + // Expect 3 samples + EXPECT_EQ(i, 3); + + // Manually terminate the pipeline + iter->Stop(); +} + +/// Feature: Test SogouNewsDataset in distribution. +/// Description: Test interface in a distributed state. +/// Expectation: The data is processed successfully. +TEST_F(MindDataTestPipeline, TestSogouNewsDatasetDistribution) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSogouNewsDatasetDistribution."; + + std::string dataset_dir = datasets_root_path_ + "/testSogouNews/"; + std::vector column_names = {"index", "title", "content"}; + + std::shared_ptr ds = SogouNews(dataset_dir, "test", 0, ShuffleMode::kFalse, 2, 0); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + EXPECT_NE(row.find("index"), row.end()); + std::vector> expected_result = { + {"1","Make history","Su Bingtian's 100m breakthrough\\n 9.83"}, + {"4","Tesla price","Tesla reduced its price by 70000 yuan"}, + {"1","Opening ceremony of the 14th National Games","On the evening of September 15, Beijing time, " + "the 14th games of the people's Republic of China opened in Xi'an Olympic Sports Center Stadium, " + "Shaanxi Province. Yang Qian, the first gold medalist of the Chinese delegation in the Tokyo Olympic" + " Games and a Post-00 shooter, lit the main torch platform. From then on, to September 27, the 14th" + " National Games flame will burn here for 12 days."} + }; + + uint64_t i = 0; + while (row.size() != 0) { + for (int j = 0; j < column_names.size(); j++) { + auto text = row[column_names[j]]; + std::shared_ptr de_text; + ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); + std::string_view sv; + ASSERT_OK(de_text->GetItemAt(&sv, {})); + std::string ss(sv); + EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); + } + ASSERT_OK(iter->GetNextRow(&row)); + i++; + } + + // Expect 2 samples + EXPECT_EQ(i, 2); + + // Manually terminate the pipeline + iter->Stop(); +} + +/// Feature: Error Test. +/// Description: Test the wrong input. +/// Expectation: Unable to read in data. +TEST_F(MindDataTestPipeline, TestSogouNewsDatasetFail) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSogouNewsDatasetFail."; + + std::string dataset_dir = datasets_root_path_ + "/testSogouNews/"; + std::string invalid_csv_file = "./NotExistFile"; + std::vector column_names = {"index", "title", "content"}; + + std::shared_ptr ds0 = SogouNews("", "test", 0); + EXPECT_NE(ds0, nullptr); + // Create an iterator over the result of the above dataset + std::shared_ptr iter0 = ds0->CreateIterator(); + // Expect failure: invalid SogouNews input + EXPECT_EQ(iter0, nullptr); + + // Create a SogouNews Dataset with invalid usage + std::shared_ptr ds1 = SogouNews(invalid_csv_file); + EXPECT_NE(ds1, nullptr); + + // Create an iterator over the result of the above dataset + std::shared_ptr iter1 = ds1->CreateIterator(); + // Expect failure: invalid SogouNews input + EXPECT_EQ(iter1, nullptr); + + // Test invalid num_samples < -1 + std::shared_ptr ds2 = SogouNews(dataset_dir, "test", -1, ShuffleMode::kFalse); + EXPECT_NE(ds2, nullptr); + // Create an iterator over the result of the above dataset + std::shared_ptr iter2 = ds2->CreateIterator(); + // Expect failure: invalid SogouNews input + EXPECT_EQ(iter2, nullptr); + + // Test invalid num_shards < 1 + std::shared_ptr ds3 = SogouNews(dataset_dir, "test", 0, ShuffleMode::kFalse, 0); + EXPECT_NE(ds3, nullptr); + // Create an iterator over the result of the above dataset + std::shared_ptr iter3 = ds3->CreateIterator(); + // Expect failure: invalid SogouNews input + EXPECT_EQ(iter3, nullptr); + + // Test invalid shard_id >= num_shards + std::shared_ptr ds4 = SogouNews(dataset_dir, "test", 0, ShuffleMode::kFalse, 2, 2); + EXPECT_NE(ds4, nullptr); + // Create an iterator over the result of the above dataset + std::shared_ptr iter4 = ds4->CreateIterator(); + // Expect failure: invalid SogouNews input + EXPECT_EQ(iter4, nullptr); +} + +/// Feature: Test SogouNews Dataset(ShuffleMode=kFiles). +/// Description: Test SogouNews Dataset interface with different ShuffleMode. +/// Expectation: The data is processed successfully. +TEST_F(MindDataTestPipeline, TestSogouNewsDatasetShuffleFilesA) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSogouNewsDatasetShuffleFilesA."; + + // Set configuration + uint32_t original_seed = GlobalContext::config_manager()->seed(); + uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); + MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; + GlobalContext::config_manager()->set_seed(130); + GlobalContext::config_manager()->set_num_parallel_workers(4); + + std::string dataset_dir = datasets_root_path_ + "/testSogouNews/"; + std::vector column_names = {"index", "title", "content"}; + + std::shared_ptr ds = SogouNews(dataset_dir, "all" , 0, ShuffleMode::kFiles); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + EXPECT_NE(row.find("index"), row.end()); + std::vector> expected_result = { + {"1","Make history","Su Bingtian's 100m breakthrough\\n 9.83"}, + {"1","Jefferson commented on thick eyebrow: he has the top five talents in the league, but he is not the" + " top five","They say he has the talent of the top five in the league. The talent of the top five in the" + " league is one of the most disrespectful statements. I say he has the talent of the top five in the league," + " but he is not the top five players because the top five players play every night."}, + {"4","Tesla price","Tesla reduced its price by 70000 yuan"}, + {"3","Group pictures: Liu Shishi's temperament in early autumn released a large piece of micro curly long" + " hair, elegant, lazy, gentle and capable","Liu Shishi's latest group of cover magazine blockbusters are" + " released. In the photos, Liu Shishi's long hair is slightly curly, or camel colored belted woolen coat," + " or plaid suit, which is gentle and elegant and beautiful to a new height."}, + {"1","Opening ceremony of the 14th National Games","On the evening of September 15, Beijing time, " + "the 14th games of the people's Republic of China opened in Xi'an Olympic Sports Center Stadium, " + "Shaanxi Province. Yang Qian, the first gold medalist of the Chinese delegation in the Tokyo Olympic" + " Games and a Post-00 shooter, lit the main torch platform. From then on, to September 27, the 14th" + " National Games flame will burn here for 12 days."}, + {"3","Ni Ni deduces elegant retro style in different styles","Ni Ni's latest group of magazine cover" + " blockbusters released that wearing gift hats is cool, retro, unique and full of fashion expression."} + }; + + uint64_t i = 0; + while (row.size() != 0) { + for (int j = 0; j < column_names.size(); j++) { + auto text = row[column_names[j]]; + std::shared_ptr de_text; + ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); + std::string_view sv; + ASSERT_OK(de_text->GetItemAt(&sv, {})); + std::string ss(sv); + EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); + } + ASSERT_OK(iter->GetNextRow(&row)); + i++; + } + // Expect 6 samples + EXPECT_EQ(i, 6); + + // Manually terminate the pipeline + iter->Stop(); + + // Restore configuration + GlobalContext::config_manager()->set_seed(original_seed); + GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); +} + +/// Feature: Test SogouNews Dataset(ShuffleMode=kGlobal). +/// Description: Test SogouNews Dataset interface with different ShuffleMode. +/// Expectation: The data is processed successfully. +TEST_F(MindDataTestPipeline, TestSogouNewsDatasetShuffleFilesGlobal) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSogouNewsDatasetShuffleFilesGlobal."; + + // Set configuration + uint32_t original_seed = GlobalContext::config_manager()->seed(); + uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); + MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; + GlobalContext::config_manager()->set_seed(130); + GlobalContext::config_manager()->set_num_parallel_workers(4); + + std::string dataset_dir = datasets_root_path_ + "/testSogouNews/"; + std::vector column_names = {"index", "title", "content"}; + + std::shared_ptr ds = SogouNews(dataset_dir, "all" , 0, ShuffleMode::kGlobal); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + EXPECT_NE(row.find("index"), row.end()); + std::vector> expected_result = { + {"1","Make history","Su Bingtian's 100m breakthrough\\n 9.83"}, + {"1","Jefferson commented on thick eyebrow: he has the top five talents in the league, but he is not the" + " top five","They say he has the talent of the top five in the league. The talent of the top five in the" + " league is one of the most disrespectful statements. I say he has the talent of the top five in the league," + " but he is not the top five players because the top five players play every night."}, + {"4","Tesla price","Tesla reduced its price by 70000 yuan"}, + {"3","Ni Ni deduces elegant retro style in different styles","Ni Ni's latest group of magazine cover" + " blockbusters released that wearing gift hats is cool, retro, unique and full of fashion expression."}, + {"3","Group pictures: Liu Shishi's temperament in early autumn released a large piece of micro curly long" + " hair, elegant, lazy, gentle and capable","Liu Shishi's latest group of cover magazine blockbusters are" + " released. In the photos, Liu Shishi's long hair is slightly curly, or camel colored belted woolen coat," + " or plaid suit, which is gentle and elegant and beautiful to a new height."}, + {"1","Opening ceremony of the 14th National Games","On the evening of September 15, Beijing time, " + "the 14th games of the people's Republic of China opened in Xi'an Olympic Sports Center Stadium, " + "Shaanxi Province. Yang Qian, the first gold medalist of the Chinese delegation in the Tokyo Olympic" + " Games and a Post-00 shooter, lit the main torch platform. From then on, to September 27, the 14th" + " National Games flame will burn here for 12 days."} + }; + + uint64_t i = 0; + while (row.size() != 0) { + for (int j = 0; j < column_names.size(); j++) { + auto text = row[column_names[j]]; + std::shared_ptr de_text; + ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); + std::string_view sv; + ASSERT_OK(de_text->GetItemAt(&sv, {})); + std::string ss(sv); + EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); + } + ASSERT_OK(iter->GetNextRow(&row)); + i++; + } + // Expect 6 samples + EXPECT_EQ(i, 6); + + // Manually terminate the pipeline + iter->Stop(); + + // Restore configuration + GlobalContext::config_manager()->set_seed(original_seed); + GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); } \ No newline at end of file diff --git a/tests/ut/cpp/dataset/c_api_dataset_usps_test.cc b/tests/ut/cpp/dataset/c_api_dataset_usps_test.cc index 650185a30d1..6deeb9fdbf1 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_usps_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_usps_test.cc @@ -1,339 +1,339 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "common/common.h" -#include "minddata/dataset/include/dataset/datasets.h" - -using namespace mindspore::dataset; -using mindspore::dataset::DataType; -using mindspore::dataset::Tensor; -using mindspore::dataset::TensorShape; - -class MindDataTestPipeline : public UT::DatasetOpTesting { - protected: -}; - -/// Feature: USPSDataset -/// Description: Test basic usage of USPSDataset with train dataset -/// Expectation: Get correct number of data -TEST_F(MindDataTestPipeline, TestUSPSTrainDataset) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUSPSTrainDataset."; - - // Create a USPS Train Dataset - std::string folder_path = datasets_root_path_ + "/testUSPSDataset/"; - std::shared_ptr ds = USPS(folder_path, "train"); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - - EXPECT_NE(row.find("image"), row.end()); - EXPECT_NE(row.find("label"), row.end()); - - uint64_t i = 0; - while (row.size() != 0) { - i++; - auto image = row["image"]; - MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); - ASSERT_OK(iter->GetNextRow(&row)); - } - - EXPECT_EQ(i, 3); - - // Manually terminate the pipeline - iter->Stop(); -} - -/// Feature: USPSDataset -/// Description: Test basic usage of USPSDataset with test dataset -/// Expectation: Get correct number of data -TEST_F(MindDataTestPipeline, TestUSPSTestDataset) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUSPSTestDataset."; - - // Create a USPS Test Dataset - std::string folder_path = datasets_root_path_ + "/testUSPSDataset/"; - std::shared_ptr ds = USPS(folder_path, "test"); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - - EXPECT_NE(row.find("image"), row.end()); - EXPECT_NE(row.find("label"), row.end()); - - uint64_t i = 0; - while (row.size() != 0) { - i++; - auto image = row["image"]; - MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); - ASSERT_OK(iter->GetNextRow(&row)); - } - - EXPECT_EQ(i, 3); - - // Manually terminate the pipeline - iter->Stop(); -} - -/// Feature: USPSDataset -/// Description: Test basic usage of USPSDataset with all dataset -/// Expectation: Get correct number of data -TEST_F(MindDataTestPipeline, TestUSPSAllDataset) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUSPSAllDataset."; - - // Create a USPS Test Dataset - std::string folder_path = datasets_root_path_ + "/testUSPSDataset/"; - std::shared_ptr ds = USPS(folder_path, "all"); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - - EXPECT_NE(row.find("image"), row.end()); - EXPECT_NE(row.find("label"), row.end()); - - uint64_t i = 0; - while (row.size() != 0) { - i++; - auto image = row["image"]; - MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); - ASSERT_OK(iter->GetNextRow(&row)); - } - - EXPECT_EQ(i, 6); - - // Manually terminate the pipeline - iter->Stop(); -} - -/// Feature: USPSDataset -/// Description: Test usage of USPSDataset with pipeline mode -/// Expectation: Get correct number of data -TEST_F(MindDataTestPipeline, TestUSPSDatasetWithPipeline) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUSPSTrainDatasetWithPipeline."; - - // Create two USPS Train Dataset - std::string folder_path = datasets_root_path_ + "/testUSPSDataset/"; - std::shared_ptr ds1 = USPS(folder_path, "train"); - std::shared_ptr ds2 = USPS(folder_path, "train"); - EXPECT_NE(ds1, nullptr); - EXPECT_NE(ds2, nullptr); - - // Create two Repeat operation on ds - int32_t repeat_num = 1; - ds1 = ds1->Repeat(repeat_num); - EXPECT_NE(ds1, nullptr); - repeat_num = 1; - ds2 = ds2->Repeat(repeat_num); - EXPECT_NE(ds2, nullptr); - - // Create two Project operation on ds - std::vector column_project = {"image", "label"}; - ds1 = ds1->Project(column_project); - EXPECT_NE(ds1, nullptr); - ds2 = ds2->Project(column_project); - EXPECT_NE(ds2, nullptr); - - // Create a Concat operation on the ds - ds1 = ds1->Concat({ds2}); - EXPECT_NE(ds1, nullptr); - - // Create an iterator over the result of the above dataset - // This will trigger the creation of the Execution Tree and launch it. - std::shared_ptr iter = ds1->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row - std::unordered_map row; - ASSERT_OK(iter->GetNextRow(&row)); - - EXPECT_NE(row.find("image"), row.end()); - EXPECT_NE(row.find("label"), row.end()); - - uint64_t i = 0; - while (row.size() != 0) { - i++; - auto image = row["image"]; - MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); - ASSERT_OK(iter->GetNextRow(&row)); - } - - EXPECT_EQ(i, 6); - - // Manually terminate the pipeline - iter->Stop(); -} - -/// Feature: USPSDataset -/// Description: Test iterator of USPSDataset with only the "image" column -/// Expectation: Get correct data -TEST_F(MindDataTestPipeline, TestUSPSIteratorOneColumn) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUSPSIteratorOneColumn."; - // Create a USPS Dataset - std::string folder_path = datasets_root_path_ + "/testUSPSDataset/"; - std::shared_ptr ds = USPS(folder_path, "train"); - EXPECT_NE(ds, nullptr); - - // Create a Batch operation on ds - int32_t batch_size = 1; - ds = ds->Batch(batch_size); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset - // Only select "image" column and drop others - std::vector columns = {"image"}; - std::shared_ptr project_ds = ds->Project(columns); - std::shared_ptr iter = project_ds->CreateIterator(); - EXPECT_NE(iter, nullptr); - - // Iterate the dataset and get each row - std::vector row; - ASSERT_OK(iter->GetNextRow(&row)); - std::vector expect_image = {1, 16, 16, 1}; - - uint64_t i = 0; - while (row.size() != 0) { - for (auto &v : row) { - MS_LOG(INFO) << "image shape:" << v.Shape(); - EXPECT_EQ(expect_image, v.Shape()); - } - ASSERT_OK(iter->GetNextRow(&row)); - i++; - } - - EXPECT_EQ(i, 3); - - // Manually terminate the pipeline - iter->Stop(); -} - -/// Feature: USPSDataset -/// Description: Test iterator of USPSDataset with wrong column -/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr -TEST_F(MindDataTestPipeline, TestUSPSIteratorWrongColumn) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUSPSIteratorWrongColumn."; - // Create a USPS Dataset - std::string folder_path = datasets_root_path_ + "/testUSPSDataset/"; - std::shared_ptr ds = USPS(folder_path, "train"); - EXPECT_NE(ds, nullptr); - - // Pass wrong column name - std::vector columns = {"digital"}; - std::shared_ptr project_ds = ds->Project(columns); - std::shared_ptr iter = project_ds->CreateIterator(); - EXPECT_EQ(iter, nullptr); -} - -/// Feature: USPSDataset -/// Description: Test GetDatasetSize of USPSDataset -/// Expectation: Output is equal to the expected output -TEST_F(MindDataTestPipeline, TestGetUSPSDatasetSize) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetUSPSTrainDatasetSize."; - - // Create a USPS Train Dataset - std::string folder_path = datasets_root_path_ + "/testUSPSDataset/"; - std::shared_ptr ds = USPS(folder_path, "train"); - EXPECT_NE(ds, nullptr); - - EXPECT_EQ(ds->GetDatasetSize(), 3); -} - -/// Feature: USPSDataset -/// Description: Test usage of getters USPSDataset -/// Expectation: Get correct number of data and correct tensor shape -TEST_F(MindDataTestPipeline, TestUSPSDatasetGetters) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUSPSTrainDatasetGetters."; - - // Create a USPS Train Dataset - std::string folder_path = datasets_root_path_ + "/testUSPSDataset/"; - std::shared_ptr ds = USPS(folder_path, "train"); - EXPECT_NE(ds, nullptr); - - EXPECT_EQ(ds->GetDatasetSize(), 3); - std::vector types = ToDETypes(ds->GetOutputTypes()); - std::vector shapes = ToTensorShapeVec(ds->GetOutputShapes()); - std::vector column_names = {"image", "label"}; - EXPECT_EQ(types.size(), 2); - EXPECT_EQ(types[0].ToString(), "uint8"); - EXPECT_EQ(types[1].ToString(), "uint32"); - EXPECT_EQ(shapes.size(), 2); - EXPECT_EQ(shapes[0].ToString(), "<16,16,1>"); - EXPECT_EQ(shapes[1].ToString(), "<>"); - EXPECT_EQ(ds->GetBatchSize(), 1); - EXPECT_EQ(ds->GetRepeatCount(), 1); - - EXPECT_EQ(ds->GetDatasetSize(), 3); - EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types); - EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes); - - EXPECT_EQ(ds->GetColumnNames(), column_names); - EXPECT_EQ(ds->GetDatasetSize(), 3); - EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types); - EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes); - EXPECT_EQ(ds->GetBatchSize(), 1); - EXPECT_EQ(ds->GetRepeatCount(), 1); - EXPECT_EQ(ds->GetDatasetSize(), 3); -} - -/// Feature: USPSDataset -/// Description: Test failure of USPSDataset with empty string as folder path -/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr -TEST_F(MindDataTestPipeline, TestUSPSDatasetFail) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUSPSDatasetFail."; - - // Create a USPS Dataset - std::shared_ptr ds = USPS("", "train"); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset - std::shared_ptr iter = ds->CreateIterator(); - // Expect failure: invalid USPS input - EXPECT_EQ(iter, nullptr); -} - -/// Feature: USPSDataset -/// Description: Test failure of USPSDataset with invalid usage -/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr -TEST_F(MindDataTestPipeline, TestUSPSDatasetWithInvalidUsageFail) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUSPSDatasetWithInvalidUsageFail."; - - // Create a USPS Dataset - std::string folder_path = datasets_root_path_ + "/testUSPSDataset/"; - std::shared_ptr ds = USPS(folder_path, "validation"); - EXPECT_NE(ds, nullptr); - - // Create an iterator over the result of the above dataset - std::shared_ptr iter = ds->CreateIterator(); - // Expect failure: invalid USPS input, validation is not a valid usage - EXPECT_EQ(iter, nullptr); -} +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/common.h" +#include "minddata/dataset/include/dataset/datasets.h" + +using namespace mindspore::dataset; +using mindspore::dataset::DataType; +using mindspore::dataset::Tensor; +using mindspore::dataset::TensorShape; + +class MindDataTestPipeline : public UT::DatasetOpTesting { + protected: +}; + +/// Feature: USPSDataset +/// Description: Test basic usage of USPSDataset with train dataset +/// Expectation: Get correct number of data +TEST_F(MindDataTestPipeline, TestUSPSTrainDataset) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUSPSTrainDataset."; + + // Create a USPS Train Dataset + std::string folder_path = datasets_root_path_ + "/testUSPSDataset/"; + std::shared_ptr ds = USPS(folder_path, "train"); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + + EXPECT_NE(row.find("image"), row.end()); + EXPECT_NE(row.find("label"), row.end()); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); + ASSERT_OK(iter->GetNextRow(&row)); + } + + EXPECT_EQ(i, 3); + + // Manually terminate the pipeline + iter->Stop(); +} + +/// Feature: USPSDataset +/// Description: Test basic usage of USPSDataset with test dataset +/// Expectation: Get correct number of data +TEST_F(MindDataTestPipeline, TestUSPSTestDataset) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUSPSTestDataset."; + + // Create a USPS Test Dataset + std::string folder_path = datasets_root_path_ + "/testUSPSDataset/"; + std::shared_ptr ds = USPS(folder_path, "test"); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + + EXPECT_NE(row.find("image"), row.end()); + EXPECT_NE(row.find("label"), row.end()); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); + ASSERT_OK(iter->GetNextRow(&row)); + } + + EXPECT_EQ(i, 3); + + // Manually terminate the pipeline + iter->Stop(); +} + +/// Feature: USPSDataset +/// Description: Test basic usage of USPSDataset with all dataset +/// Expectation: Get correct number of data +TEST_F(MindDataTestPipeline, TestUSPSAllDataset) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUSPSAllDataset."; + + // Create a USPS Test Dataset + std::string folder_path = datasets_root_path_ + "/testUSPSDataset/"; + std::shared_ptr ds = USPS(folder_path, "all"); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + + EXPECT_NE(row.find("image"), row.end()); + EXPECT_NE(row.find("label"), row.end()); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); + ASSERT_OK(iter->GetNextRow(&row)); + } + + EXPECT_EQ(i, 6); + + // Manually terminate the pipeline + iter->Stop(); +} + +/// Feature: USPSDataset +/// Description: Test usage of USPSDataset with pipeline mode +/// Expectation: Get correct number of data +TEST_F(MindDataTestPipeline, TestUSPSDatasetWithPipeline) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUSPSTrainDatasetWithPipeline."; + + // Create two USPS Train Dataset + std::string folder_path = datasets_root_path_ + "/testUSPSDataset/"; + std::shared_ptr ds1 = USPS(folder_path, "train"); + std::shared_ptr ds2 = USPS(folder_path, "train"); + EXPECT_NE(ds1, nullptr); + EXPECT_NE(ds2, nullptr); + + // Create two Repeat operation on ds + int32_t repeat_num = 1; + ds1 = ds1->Repeat(repeat_num); + EXPECT_NE(ds1, nullptr); + repeat_num = 1; + ds2 = ds2->Repeat(repeat_num); + EXPECT_NE(ds2, nullptr); + + // Create two Project operation on ds + std::vector column_project = {"image", "label"}; + ds1 = ds1->Project(column_project); + EXPECT_NE(ds1, nullptr); + ds2 = ds2->Project(column_project); + EXPECT_NE(ds2, nullptr); + + // Create a Concat operation on the ds + ds1 = ds1->Concat({ds2}); + EXPECT_NE(ds1, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds1->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + + EXPECT_NE(row.find("image"), row.end()); + EXPECT_NE(row.find("label"), row.end()); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); + ASSERT_OK(iter->GetNextRow(&row)); + } + + EXPECT_EQ(i, 6); + + // Manually terminate the pipeline + iter->Stop(); +} + +/// Feature: USPSDataset +/// Description: Test iterator of USPSDataset with only the "image" column +/// Expectation: Get correct data +TEST_F(MindDataTestPipeline, TestUSPSIteratorOneColumn) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUSPSIteratorOneColumn."; + // Create a USPS Dataset + std::string folder_path = datasets_root_path_ + "/testUSPSDataset/"; + std::shared_ptr ds = USPS(folder_path, "train"); + EXPECT_NE(ds, nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 1; + ds = ds->Batch(batch_size); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // Only select "image" column and drop others + std::vector columns = {"image"}; + std::shared_ptr project_ds = ds->Project(columns); + std::shared_ptr iter = project_ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::vector row; + ASSERT_OK(iter->GetNextRow(&row)); + std::vector expect_image = {1, 16, 16, 1}; + + uint64_t i = 0; + while (row.size() != 0) { + for (auto &v : row) { + MS_LOG(INFO) << "image shape:" << v.Shape(); + EXPECT_EQ(expect_image, v.Shape()); + } + ASSERT_OK(iter->GetNextRow(&row)); + i++; + } + + EXPECT_EQ(i, 3); + + // Manually terminate the pipeline + iter->Stop(); +} + +/// Feature: USPSDataset +/// Description: Test iterator of USPSDataset with wrong column +/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr +TEST_F(MindDataTestPipeline, TestUSPSIteratorWrongColumn) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUSPSIteratorWrongColumn."; + // Create a USPS Dataset + std::string folder_path = datasets_root_path_ + "/testUSPSDataset/"; + std::shared_ptr ds = USPS(folder_path, "train"); + EXPECT_NE(ds, nullptr); + + // Pass wrong column name + std::vector columns = {"digital"}; + std::shared_ptr project_ds = ds->Project(columns); + std::shared_ptr iter = project_ds->CreateIterator(); + EXPECT_EQ(iter, nullptr); +} + +/// Feature: USPSDataset +/// Description: Test GetDatasetSize of USPSDataset +/// Expectation: Output is equal to the expected output +TEST_F(MindDataTestPipeline, TestGetUSPSDatasetSize) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetUSPSTrainDatasetSize."; + + // Create a USPS Train Dataset + std::string folder_path = datasets_root_path_ + "/testUSPSDataset/"; + std::shared_ptr ds = USPS(folder_path, "train"); + EXPECT_NE(ds, nullptr); + + EXPECT_EQ(ds->GetDatasetSize(), 3); +} + +/// Feature: USPSDataset +/// Description: Test usage of getters USPSDataset +/// Expectation: Get correct number of data and correct tensor shape +TEST_F(MindDataTestPipeline, TestUSPSDatasetGetters) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUSPSTrainDatasetGetters."; + + // Create a USPS Train Dataset + std::string folder_path = datasets_root_path_ + "/testUSPSDataset/"; + std::shared_ptr ds = USPS(folder_path, "train"); + EXPECT_NE(ds, nullptr); + + EXPECT_EQ(ds->GetDatasetSize(), 3); + std::vector types = ToDETypes(ds->GetOutputTypes()); + std::vector shapes = ToTensorShapeVec(ds->GetOutputShapes()); + std::vector column_names = {"image", "label"}; + EXPECT_EQ(types.size(), 2); + EXPECT_EQ(types[0].ToString(), "uint8"); + EXPECT_EQ(types[1].ToString(), "uint32"); + EXPECT_EQ(shapes.size(), 2); + EXPECT_EQ(shapes[0].ToString(), "<16,16,1>"); + EXPECT_EQ(shapes[1].ToString(), "<>"); + EXPECT_EQ(ds->GetBatchSize(), 1); + EXPECT_EQ(ds->GetRepeatCount(), 1); + + EXPECT_EQ(ds->GetDatasetSize(), 3); + EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types); + EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes); + + EXPECT_EQ(ds->GetColumnNames(), column_names); + EXPECT_EQ(ds->GetDatasetSize(), 3); + EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types); + EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes); + EXPECT_EQ(ds->GetBatchSize(), 1); + EXPECT_EQ(ds->GetRepeatCount(), 1); + EXPECT_EQ(ds->GetDatasetSize(), 3); +} + +/// Feature: USPSDataset +/// Description: Test failure of USPSDataset with empty string as folder path +/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr +TEST_F(MindDataTestPipeline, TestUSPSDatasetFail) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUSPSDatasetFail."; + + // Create a USPS Dataset + std::shared_ptr ds = USPS("", "train"); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + std::shared_ptr iter = ds->CreateIterator(); + // Expect failure: invalid USPS input + EXPECT_EQ(iter, nullptr); +} + +/// Feature: USPSDataset +/// Description: Test failure of USPSDataset with invalid usage +/// Expectation: Error message is logged, and CreateIterator() for invalid pipeline returns nullptr +TEST_F(MindDataTestPipeline, TestUSPSDatasetWithInvalidUsageFail) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUSPSDatasetWithInvalidUsageFail."; + + // Create a USPS Dataset + std::string folder_path = datasets_root_path_ + "/testUSPSDataset/"; + std::shared_ptr ds = USPS(folder_path, "validation"); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + std::shared_ptr iter = ds->CreateIterator(); + // Expect failure: invalid USPS input, validation is not a valid usage + EXPECT_EQ(iter, nullptr); +} diff --git a/tests/ut/cpp/dataset/c_api_vision_a_to_q_test.cc b/tests/ut/cpp/dataset/c_api_vision_a_to_q_test.cc old mode 100755 new mode 100644 diff --git a/tests/ut/cpp/graph_kernel/common/graph_kernel_common_test_suite.h b/tests/ut/cpp/graph_kernel/common/graph_kernel_common_test_suite.h index 42c82e97dbe..283a66f40a0 100644 --- a/tests/ut/cpp/graph_kernel/common/graph_kernel_common_test_suite.h +++ b/tests/ut/cpp/graph_kernel/common/graph_kernel_common_test_suite.h @@ -1,32 +1,32 @@ -/** - * Copyright 2024 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef TESTS_UT_CPP_GRAPH_KERNEL_COMMON_GRAPH_KERNEL_TEST_SUITE_H_ -#define TESTS_UT_CPP_GRAPH_KERNEL_COMMON_GRAPH_KERNEL_TEST_SUITE_H_ - -#include "common/common_test.h" -#include "common/graph_optimizer_test_framework.h" - -namespace mindspore::graphkernel::test { -using mindspore::test::ConstructGraph; -using mindspore::test::RunPass; - -class GraphKernelCommonTestSuite : public UT::Common { - public: - GraphKernelCommonTestSuite(){}; - virtual ~GraphKernelCommonTestSuite() = default; -}; -} // namespace mindspore::graphkernel::test -#endif // TESTS_UT_CPP_GRAPH_KERNEL_COMMON_GRAPH_KERNEL_TEST_SUITE_H_ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef TESTS_UT_CPP_GRAPH_KERNEL_COMMON_GRAPH_KERNEL_TEST_SUITE_H_ +#define TESTS_UT_CPP_GRAPH_KERNEL_COMMON_GRAPH_KERNEL_TEST_SUITE_H_ + +#include "common/common_test.h" +#include "common/graph_optimizer_test_framework.h" + +namespace mindspore::graphkernel::test { +using mindspore::test::ConstructGraph; +using mindspore::test::RunPass; + +class GraphKernelCommonTestSuite : public UT::Common { + public: + GraphKernelCommonTestSuite(){}; + virtual ~GraphKernelCommonTestSuite() = default; +}; +} // namespace mindspore::graphkernel::test +#endif // TESTS_UT_CPP_GRAPH_KERNEL_COMMON_GRAPH_KERNEL_TEST_SUITE_H_ diff --git a/tests/ut/cpp/parallel/auto_parallel/rec_partition_test.cc b/tests/ut/cpp/parallel/auto_parallel/rec_partition_test.cc index 6c68c98affb..f6acb47f9c6 100644 --- a/tests/ut/cpp/parallel/auto_parallel/rec_partition_test.cc +++ b/tests/ut/cpp/parallel/auto_parallel/rec_partition_test.cc @@ -1,287 +1,287 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "common/common_test.h" -#include "frontend/parallel/auto_parallel/rec_core/rec_tensor.h" -#include "frontend/parallel/auto_parallel/rec_core/rec_graph.h" -#include "frontend/parallel/auto_parallel/rec_core/rec_partition.h" -#include "frontend/parallel/auto_parallel/stage_compute.h" -#include -#include "ir/value.h" - -namespace mindspore { -namespace parallel { -#define ARRAY_A 3000 // also 'I' :height of the first input tensor -#define ARRAY_B 1000 // also 'K' :used by both input tensor -#define ARRAY_C 4000 // also 'J' :width of the first input tensor - -class TestPartition : public UT::Common { - public: - void Create(std::shared_ptr graph, int node_num, std::vector edge_head, - std::vector edge_tail); - void InitEdge(std::shared_ptr graph, int vHead, int vTail); - void InitNode(std::shared_ptr graph, int num_node); - TensorParam *MakeTensor(int n, int c, int h, int w); - std::shared_ptr MakeMatMulData(int numNode); -}; - -// Local function to create test input graph with nodes -void TestPartition::Create(std::shared_ptr graph, int node_num, std::vector edge_head, - std::vector edge_tail) { - TestPartition::InitNode(graph, node_num); - unsigned int edge_num = edge_head.size(); - if (edge_num != edge_tail.size()) { - exit(1); - }; - - for (unsigned int i = 0; i < edge_num; i++) { - TestPartition::InitEdge(graph, edge_head[i], edge_tail[i]); - }; -} - -// Local function for Create() to crate Node -void TestPartition::InitNode(std::shared_ptr graph, int num_node) { - Graph::NodeType NewNode; - for (int i = 0; i < num_node; i++) { - graph->nodes.push_back(NewNode); - std::stringstream ss; - ss << 'N' << i; - graph->nodes[i].name = ss.str(); - graph->nodes[i].info = kConstant; - }; -} - -// Local function for Create() to crate Edge -void TestPartition::InitEdge(std::shared_ptr graph, int vHead, int vTail) { - graph->nodes[vHead].node_out.push_back(vTail); - graph->nodes[vTail].node_in.push_back(vHead); -} - -// Local function for Create() to crate Tensor -TensorParam *TestPartition::MakeTensor(int n, int c, int h, int w) { - TensorParam *p_tensor = new TensorParam; - p_tensor->tensor_type = kFloat32; - p_tensor->tensor_shape.shape_n = n; - p_tensor->tensor_shape.shape_c = c; - p_tensor->tensor_shape.shape_h = h; - p_tensor->tensor_shape.shape_w = w; - - return p_tensor; -}; - -// Local function for Create() to create MatMul Operator -// @numNode include Tensor and Operator, for example 4(1 Input Tensor, 1 Input Tensor, 1 Operator, 1 Output Tensor) -std::shared_ptr TestPartition::MakeMatMulData(int numNode) { - // Build Edges - int edgeNum = 0; - constexpr int INTERVAL = 2; - if (numNode % INTERVAL == 0 && numNode != 0) { - edgeNum = numNode - INTERVAL; - } else if (numNode % INTERVAL == 1) { - edgeNum = numNode - 1; - } else { - edgeNum = 0; - }; - - std::vector edgeHead(edgeNum); // int edgeHead[8] = {0,2,4,6,1,3,5,7}; - std::vector edgeTail(edgeNum); // int edgeTail[8] = {2,4,6,8,2,4,6,8}; - - for (int i = 0; i < edgeNum; i++) { - edgeHead[i] = i; - if (i % INTERVAL == 0) { - edgeTail[i] = i + INTERVAL; - } else { - edgeTail[i] = i + 1; - }; - }; - - // Create graph - std::shared_ptr graph(new Graph); - TestPartition::Create(graph, numNode, edgeHead, edgeTail); - - // Add Node information. - for (int i = 0; i < numNode; i++) { - if (0 == i) { - graph->nodes[i].info = InfoType::kConstant; - TensorParam *p_tensor_out = new TensorParam; - p_tensor_out->tensor_type = kFloat32; - p_tensor_out->tensor_shape.shape_w = ARRAY_B; - p_tensor_out->tensor_shape.shape_h = ARRAY_A; - - graph->nodes[i].tensor_parm = *p_tensor_out; - - } else if (0 == i % 4) { - graph->nodes[i].info = InfoType::kApplication; - graph->nodes[i].apply.op_type = OperatorType::kRecMatMul; - - TensorParam *p_tensor0 = new TensorParam; - p_tensor0->tensor_type = kFloat32; - p_tensor0->tensor_shape.shape_w = ARRAY_C; - p_tensor0->tensor_shape.shape_h = ARRAY_A; - - TensorParam *p_tensor1 = new TensorParam; - p_tensor1->tensor_type = kFloat32; - p_tensor1->tensor_shape.shape_w = ARRAY_B; - p_tensor1->tensor_shape.shape_h = ARRAY_C; - - TensorParam *p_tensor_out = new TensorParam; - p_tensor_out->tensor_type = kFloat32; - p_tensor_out->tensor_shape.shape_w = ARRAY_B; - p_tensor_out->tensor_shape.shape_h = ARRAY_A; - - graph->nodes[i].apply.arguments[0] = *p_tensor0; - graph->nodes[i].apply.arguments[1] = *p_tensor1; - graph->nodes[i].tensor_parm = *p_tensor_out; - - } else if (1 == i % 4) { - graph->nodes[i].info = InfoType::kConstant; - - TensorParam *p_tensor_out = new TensorParam; - p_tensor_out->tensor_type = kFloat32; - p_tensor_out->tensor_shape.shape_w = ARRAY_C; - p_tensor_out->tensor_shape.shape_h = ARRAY_B; - - graph->nodes[i].tensor_parm = *p_tensor_out; - - } else if (2 == i % 4) { - graph->nodes[i].info = InfoType::kApplication; - graph->nodes[i].apply.op_type = OperatorType::kRecMatMul; - - TensorParam *p_tensor0 = new TensorParam; - p_tensor0->tensor_type = kFloat32; - p_tensor0->tensor_shape.shape_w = ARRAY_B; - p_tensor0->tensor_shape.shape_h = ARRAY_A; - - TensorParam *p_tensor1 = new TensorParam; - p_tensor1->tensor_type = kFloat32; - p_tensor1->tensor_shape.shape_w = ARRAY_C; - p_tensor1->tensor_shape.shape_h = ARRAY_B; - - TensorParam *p_tensor_out = new TensorParam; - p_tensor_out->tensor_type = kFloat32; - p_tensor_out->tensor_shape.shape_w = ARRAY_C; - p_tensor_out->tensor_shape.shape_h = ARRAY_A; - - graph->nodes[i].apply.arguments[0] = *p_tensor0; - graph->nodes[i].apply.arguments[1] = *p_tensor1; - graph->nodes[i].tensor_parm = *p_tensor_out; - - } else if (3 == i % 4) { - graph->nodes[i].info = InfoType::kConstant; - - TensorParam *p_tensor_out = new TensorParam; - p_tensor_out->tensor_type = kFloat32; - p_tensor_out->tensor_shape.shape_w = ARRAY_B; - p_tensor_out->tensor_shape.shape_h = ARRAY_C; - - graph->nodes[i].tensor_parm = *p_tensor_out; - }; - }; - return graph; -}; - -TEST_F(TestPartition, test_GetWeights) { - std::shared_ptr graph = MakeMatMulData(9); - double wop1 = GetWeights(graph->nodes[2]); - double wop2 = GetWeights(graph->nodes[4]); - double wop3 = GetWeights(graph->nodes[6]); - double wop4 = GetWeights(graph->nodes[8]); - ASSERT_GE(wop1, wop2); - ASSERT_GE(wop2, wop3); - ASSERT_GE(wop3, wop4); -} - -TEST_F(TestPartition, test_SortByWeight) { - std::shared_ptr graph = MakeMatMulData(9); - std::vector result = SortByWeight(graph); - ASSERT_GE(result.at(0), result.at(1)); - ASSERT_GE(result.at(1), result.at(2)); - ASSERT_GE(result.at(2), result.at(3)); -} - -TEST_F(TestPartition, test_SortByWeight2) { - std::shared_ptr graph = MakeMatMulData(5); - std::vector result = SortByWeight(graph); - ASSERT_GE(result.at(0), result.at(1)); -} - -TEST_F(TestPartition, test_PartitionNode) { - std::shared_ptr graph = MakeMatMulData(9); - // node 2 is the first kRecMatMul Operator - Graph::NodeType node2 = graph->nodes[2]; - std::vector> nameToStrategy; - bool isTraining = true; - StrategyRec str = PartitionNode(node2, nameToStrategy, graph, isTraining); - ASSERT_EQ(str.outputTensor.str_h, 0.5); - ASSERT_EQ(str.outputTensor.str_w, 1); -} - -TEST_F(TestPartition, test_PartitionForAllDevices) { - std::shared_ptr graph = MakeMatMulData(9); - double device_memory = 1024.0 * 1024.0 * 1024.0 * 16.0; - bool isTraining = true; - ASSERT_EQ(PartitionForAllDevices(1024, device_memory, graph, isTraining, nullptr), SUCCESS); -} - -TEST_F(TestPartition, test_PartitionForAllDevices2) { - std::shared_ptr graph = MakeMatMulData(9); - double device_memory = 1024.0 * 1024.0 * 1024.0 * 16.0; - bool isTraining = true; - ASSERT_EQ(PartitionForAllDevices(2, device_memory, graph, isTraining, nullptr), SUCCESS); -} - -// Negative case: partition on 0 device -TEST_F(TestPartition, test_PartitionForAllDevices0) { - std::shared_ptr graph = MakeMatMulData(9); - double device_memory = 1024.0 * 1024.0 * 1024.0 * 16.0; - bool isTraining = true; - // Throw Exception "Number of devices can't be 0" - EXPECT_ANY_THROW(PartitionForAllDevices(0, device_memory, graph, isTraining, nullptr)); -} - -TEST_F(TestPartition, test_ApplyStrToTensor) { - std::shared_ptr graph = MakeMatMulData(9); - std::vector> nameToStrategy; - bool isTraining = true; - graph->nodes[4].apply.str = PartitionNode(graph->nodes[4], nameToStrategy, graph, isTraining); - auto h_str = graph->nodes[4].apply.str.outputTensor.str_h; - auto w_str = graph->nodes[4].apply.str.outputTensor.str_w; - - Graph::NodeType n_node = ApplyStrToTensor(graph->nodes[4]); - auto h_node = n_node.tensor_parm.tensor_str.str_h; - auto w_node = n_node.tensor_parm.tensor_str.str_w; - ASSERT_EQ(h_str, h_node); - ASSERT_EQ(w_str, w_node); -} - -/// Feature: test GetDPAndMP. -/// Description: -/// Expectation: success -TEST_F(TestPartition, test_get_dp_mp) { - size_t dp, mp; - bool isTraining = true; - double device_memory = 1024.0 * 1024.0 * 1024.0 * 16.0; - std::shared_ptr graph = MakeMatMulData(9); - PartitionForAllDevices(8, device_memory, graph, isTraining, nullptr); - std::tie(dp, mp) = GetDPAndMP(graph, 1); - ASSERT_GT(dp, 0); - ASSERT_GT(mp, 0); - ASSERT_LE(dp, GetNumDevices()); - ASSERT_LE(mp, GetNumDevices()); -} - -} // namespace parallel -} // namespace mindspore +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/common_test.h" +#include "frontend/parallel/auto_parallel/rec_core/rec_tensor.h" +#include "frontend/parallel/auto_parallel/rec_core/rec_graph.h" +#include "frontend/parallel/auto_parallel/rec_core/rec_partition.h" +#include "frontend/parallel/auto_parallel/stage_compute.h" +#include +#include "ir/value.h" + +namespace mindspore { +namespace parallel { +#define ARRAY_A 3000 // also 'I' :height of the first input tensor +#define ARRAY_B 1000 // also 'K' :used by both input tensor +#define ARRAY_C 4000 // also 'J' :width of the first input tensor + +class TestPartition : public UT::Common { + public: + void Create(std::shared_ptr graph, int node_num, std::vector edge_head, + std::vector edge_tail); + void InitEdge(std::shared_ptr graph, int vHead, int vTail); + void InitNode(std::shared_ptr graph, int num_node); + TensorParam *MakeTensor(int n, int c, int h, int w); + std::shared_ptr MakeMatMulData(int numNode); +}; + +// Local function to create test input graph with nodes +void TestPartition::Create(std::shared_ptr graph, int node_num, std::vector edge_head, + std::vector edge_tail) { + TestPartition::InitNode(graph, node_num); + unsigned int edge_num = edge_head.size(); + if (edge_num != edge_tail.size()) { + exit(1); + }; + + for (unsigned int i = 0; i < edge_num; i++) { + TestPartition::InitEdge(graph, edge_head[i], edge_tail[i]); + }; +} + +// Local function for Create() to crate Node +void TestPartition::InitNode(std::shared_ptr graph, int num_node) { + Graph::NodeType NewNode; + for (int i = 0; i < num_node; i++) { + graph->nodes.push_back(NewNode); + std::stringstream ss; + ss << 'N' << i; + graph->nodes[i].name = ss.str(); + graph->nodes[i].info = kConstant; + }; +} + +// Local function for Create() to crate Edge +void TestPartition::InitEdge(std::shared_ptr graph, int vHead, int vTail) { + graph->nodes[vHead].node_out.push_back(vTail); + graph->nodes[vTail].node_in.push_back(vHead); +} + +// Local function for Create() to crate Tensor +TensorParam *TestPartition::MakeTensor(int n, int c, int h, int w) { + TensorParam *p_tensor = new TensorParam; + p_tensor->tensor_type = kFloat32; + p_tensor->tensor_shape.shape_n = n; + p_tensor->tensor_shape.shape_c = c; + p_tensor->tensor_shape.shape_h = h; + p_tensor->tensor_shape.shape_w = w; + + return p_tensor; +}; + +// Local function for Create() to create MatMul Operator +// @numNode include Tensor and Operator, for example 4(1 Input Tensor, 1 Input Tensor, 1 Operator, 1 Output Tensor) +std::shared_ptr TestPartition::MakeMatMulData(int numNode) { + // Build Edges + int edgeNum = 0; + constexpr int INTERVAL = 2; + if (numNode % INTERVAL == 0 && numNode != 0) { + edgeNum = numNode - INTERVAL; + } else if (numNode % INTERVAL == 1) { + edgeNum = numNode - 1; + } else { + edgeNum = 0; + }; + + std::vector edgeHead(edgeNum); // int edgeHead[8] = {0,2,4,6,1,3,5,7}; + std::vector edgeTail(edgeNum); // int edgeTail[8] = {2,4,6,8,2,4,6,8}; + + for (int i = 0; i < edgeNum; i++) { + edgeHead[i] = i; + if (i % INTERVAL == 0) { + edgeTail[i] = i + INTERVAL; + } else { + edgeTail[i] = i + 1; + }; + }; + + // Create graph + std::shared_ptr graph(new Graph); + TestPartition::Create(graph, numNode, edgeHead, edgeTail); + + // Add Node information. + for (int i = 0; i < numNode; i++) { + if (0 == i) { + graph->nodes[i].info = InfoType::kConstant; + TensorParam *p_tensor_out = new TensorParam; + p_tensor_out->tensor_type = kFloat32; + p_tensor_out->tensor_shape.shape_w = ARRAY_B; + p_tensor_out->tensor_shape.shape_h = ARRAY_A; + + graph->nodes[i].tensor_parm = *p_tensor_out; + + } else if (0 == i % 4) { + graph->nodes[i].info = InfoType::kApplication; + graph->nodes[i].apply.op_type = OperatorType::kRecMatMul; + + TensorParam *p_tensor0 = new TensorParam; + p_tensor0->tensor_type = kFloat32; + p_tensor0->tensor_shape.shape_w = ARRAY_C; + p_tensor0->tensor_shape.shape_h = ARRAY_A; + + TensorParam *p_tensor1 = new TensorParam; + p_tensor1->tensor_type = kFloat32; + p_tensor1->tensor_shape.shape_w = ARRAY_B; + p_tensor1->tensor_shape.shape_h = ARRAY_C; + + TensorParam *p_tensor_out = new TensorParam; + p_tensor_out->tensor_type = kFloat32; + p_tensor_out->tensor_shape.shape_w = ARRAY_B; + p_tensor_out->tensor_shape.shape_h = ARRAY_A; + + graph->nodes[i].apply.arguments[0] = *p_tensor0; + graph->nodes[i].apply.arguments[1] = *p_tensor1; + graph->nodes[i].tensor_parm = *p_tensor_out; + + } else if (1 == i % 4) { + graph->nodes[i].info = InfoType::kConstant; + + TensorParam *p_tensor_out = new TensorParam; + p_tensor_out->tensor_type = kFloat32; + p_tensor_out->tensor_shape.shape_w = ARRAY_C; + p_tensor_out->tensor_shape.shape_h = ARRAY_B; + + graph->nodes[i].tensor_parm = *p_tensor_out; + + } else if (2 == i % 4) { + graph->nodes[i].info = InfoType::kApplication; + graph->nodes[i].apply.op_type = OperatorType::kRecMatMul; + + TensorParam *p_tensor0 = new TensorParam; + p_tensor0->tensor_type = kFloat32; + p_tensor0->tensor_shape.shape_w = ARRAY_B; + p_tensor0->tensor_shape.shape_h = ARRAY_A; + + TensorParam *p_tensor1 = new TensorParam; + p_tensor1->tensor_type = kFloat32; + p_tensor1->tensor_shape.shape_w = ARRAY_C; + p_tensor1->tensor_shape.shape_h = ARRAY_B; + + TensorParam *p_tensor_out = new TensorParam; + p_tensor_out->tensor_type = kFloat32; + p_tensor_out->tensor_shape.shape_w = ARRAY_C; + p_tensor_out->tensor_shape.shape_h = ARRAY_A; + + graph->nodes[i].apply.arguments[0] = *p_tensor0; + graph->nodes[i].apply.arguments[1] = *p_tensor1; + graph->nodes[i].tensor_parm = *p_tensor_out; + + } else if (3 == i % 4) { + graph->nodes[i].info = InfoType::kConstant; + + TensorParam *p_tensor_out = new TensorParam; + p_tensor_out->tensor_type = kFloat32; + p_tensor_out->tensor_shape.shape_w = ARRAY_B; + p_tensor_out->tensor_shape.shape_h = ARRAY_C; + + graph->nodes[i].tensor_parm = *p_tensor_out; + }; + }; + return graph; +}; + +TEST_F(TestPartition, test_GetWeights) { + std::shared_ptr graph = MakeMatMulData(9); + double wop1 = GetWeights(graph->nodes[2]); + double wop2 = GetWeights(graph->nodes[4]); + double wop3 = GetWeights(graph->nodes[6]); + double wop4 = GetWeights(graph->nodes[8]); + ASSERT_GE(wop1, wop2); + ASSERT_GE(wop2, wop3); + ASSERT_GE(wop3, wop4); +} + +TEST_F(TestPartition, test_SortByWeight) { + std::shared_ptr graph = MakeMatMulData(9); + std::vector result = SortByWeight(graph); + ASSERT_GE(result.at(0), result.at(1)); + ASSERT_GE(result.at(1), result.at(2)); + ASSERT_GE(result.at(2), result.at(3)); +} + +TEST_F(TestPartition, test_SortByWeight2) { + std::shared_ptr graph = MakeMatMulData(5); + std::vector result = SortByWeight(graph); + ASSERT_GE(result.at(0), result.at(1)); +} + +TEST_F(TestPartition, test_PartitionNode) { + std::shared_ptr graph = MakeMatMulData(9); + // node 2 is the first kRecMatMul Operator + Graph::NodeType node2 = graph->nodes[2]; + std::vector> nameToStrategy; + bool isTraining = true; + StrategyRec str = PartitionNode(node2, nameToStrategy, graph, isTraining); + ASSERT_EQ(str.outputTensor.str_h, 0.5); + ASSERT_EQ(str.outputTensor.str_w, 1); +} + +TEST_F(TestPartition, test_PartitionForAllDevices) { + std::shared_ptr graph = MakeMatMulData(9); + double device_memory = 1024.0 * 1024.0 * 1024.0 * 16.0; + bool isTraining = true; + ASSERT_EQ(PartitionForAllDevices(1024, device_memory, graph, isTraining, nullptr), SUCCESS); +} + +TEST_F(TestPartition, test_PartitionForAllDevices2) { + std::shared_ptr graph = MakeMatMulData(9); + double device_memory = 1024.0 * 1024.0 * 1024.0 * 16.0; + bool isTraining = true; + ASSERT_EQ(PartitionForAllDevices(2, device_memory, graph, isTraining, nullptr), SUCCESS); +} + +// Negative case: partition on 0 device +TEST_F(TestPartition, test_PartitionForAllDevices0) { + std::shared_ptr graph = MakeMatMulData(9); + double device_memory = 1024.0 * 1024.0 * 1024.0 * 16.0; + bool isTraining = true; + // Throw Exception "Number of devices can't be 0" + EXPECT_ANY_THROW(PartitionForAllDevices(0, device_memory, graph, isTraining, nullptr)); +} + +TEST_F(TestPartition, test_ApplyStrToTensor) { + std::shared_ptr graph = MakeMatMulData(9); + std::vector> nameToStrategy; + bool isTraining = true; + graph->nodes[4].apply.str = PartitionNode(graph->nodes[4], nameToStrategy, graph, isTraining); + auto h_str = graph->nodes[4].apply.str.outputTensor.str_h; + auto w_str = graph->nodes[4].apply.str.outputTensor.str_w; + + Graph::NodeType n_node = ApplyStrToTensor(graph->nodes[4]); + auto h_node = n_node.tensor_parm.tensor_str.str_h; + auto w_node = n_node.tensor_parm.tensor_str.str_w; + ASSERT_EQ(h_str, h_node); + ASSERT_EQ(w_str, w_node); +} + +/// Feature: test GetDPAndMP. +/// Description: +/// Expectation: success +TEST_F(TestPartition, test_get_dp_mp) { + size_t dp, mp; + bool isTraining = true; + double device_memory = 1024.0 * 1024.0 * 1024.0 * 16.0; + std::shared_ptr graph = MakeMatMulData(9); + PartitionForAllDevices(8, device_memory, graph, isTraining, nullptr); + std::tie(dp, mp) = GetDPAndMP(graph, 1); + ASSERT_GT(dp, 0); + ASSERT_GT(mp, 0); + ASSERT_LE(dp, GetNumDevices()); + ASSERT_LE(mp, GetNumDevices()); +} + +} // namespace parallel +} // namespace mindspore diff --git a/tests/ut/cpp/runtest.sh b/tests/ut/cpp/runtest.sh old mode 100755 new mode 100644 diff --git a/tests/ut/cpp/stub/hccl/collective_stub.cc b/tests/ut/cpp/stub/hccl/collective_stub.cc index 4674635dadd..9d210d694d9 100644 --- a/tests/ut/cpp/stub/hccl/collective_stub.cc +++ b/tests/ut/cpp/stub/hccl/collective_stub.cc @@ -1,36 +1,36 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/ascend/hal/device/distribute/ascend_collective.h" -#include "include/backend/distributed/cluster/cluster_context.h" -namespace mindspore { -namespace device { -namespace ascend { -namespace collective { -HcclCollectiveGroup &HcclCollectiveGroup::instance() { - static HcclCollectiveGroup instance; - return instance; -} -int HcclCollectiveGroup::GetRankSize(const std::string &) const { return 0; } -int HcclCollectiveGroup::GetRankId(const std::string &) const { return 0; } -int HcclCollectiveGroup::GetDeviceId() const { return 0; } -HcclComm HcclCollectiveGroup::GetGroupComm(const std::string &name) { return nullptr; } -void HcclCollectiveGroup::CreateCommGroup(const std::string &, const std::vector &) { return; } -void HcclCollectiveGroup::FinalizeCollective() { return; } -} // namespace collective -} // namespace ascend -} // namespace device -} // namespace mindspore +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/ascend/hal/device/distribute/ascend_collective.h" +#include "include/backend/distributed/cluster/cluster_context.h" +namespace mindspore { +namespace device { +namespace ascend { +namespace collective { +HcclCollectiveGroup &HcclCollectiveGroup::instance() { + static HcclCollectiveGroup instance; + return instance; +} +int HcclCollectiveGroup::GetRankSize(const std::string &) const { return 0; } +int HcclCollectiveGroup::GetRankId(const std::string &) const { return 0; } +int HcclCollectiveGroup::GetDeviceId() const { return 0; } +HcclComm HcclCollectiveGroup::GetGroupComm(const std::string &name) { return nullptr; } +void HcclCollectiveGroup::CreateCommGroup(const std::string &, const std::vector &) { return; } +void HcclCollectiveGroup::FinalizeCollective() { return; } +} // namespace collective +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/tests/ut/cpp/stub/runtime/cuda_runtime_api.cc b/tests/ut/cpp/stub/runtime/cuda_runtime_api.cc index ea06f52ff64..5e0abfab466 100644 --- a/tests/ut/cpp/stub/runtime/cuda_runtime_api.cc +++ b/tests/ut/cpp/stub/runtime/cuda_runtime_api.cc @@ -1,34 +1,34 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include - -cudaError_t cudaMalloc(void **devPtr, size_t size) { return cudaSuccess; } - -cudaError_t cudaFree(void *devPtr) { return cudaSuccess; } - -cudaError_t cudaMemcpy(void *dst, const void *src, size_t count, enum cudaMemcpyKind kind) { return cudaSuccess; } - -cudaError_t cudaMemGetInfo(size_t *free, size_t *total) { return cudaSuccess; } - -cudaError_t cudaStreamCreate(cudaStream_t *pStream) { return cudaSuccess; } - -cudaError_t cudaStreamDestroy(cudaStream_t stream) { return cudaSuccess; } - -cudaError_t cudaStreamSynchronize(cudaStream_t stream) { return cudaSuccess; } - -cudaError_t cudaGetDeviceCount(int *count) { return cudaSuccess; } - -cudaError_t cudaSetDevice(int device) { return cudaSuccess; } +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +cudaError_t cudaMalloc(void **devPtr, size_t size) { return cudaSuccess; } + +cudaError_t cudaFree(void *devPtr) { return cudaSuccess; } + +cudaError_t cudaMemcpy(void *dst, const void *src, size_t count, enum cudaMemcpyKind kind) { return cudaSuccess; } + +cudaError_t cudaMemGetInfo(size_t *free, size_t *total) { return cudaSuccess; } + +cudaError_t cudaStreamCreate(cudaStream_t *pStream) { return cudaSuccess; } + +cudaError_t cudaStreamDestroy(cudaStream_t stream) { return cudaSuccess; } + +cudaError_t cudaStreamSynchronize(cudaStream_t stream) { return cudaSuccess; } + +cudaError_t cudaGetDeviceCount(int *count) { return cudaSuccess; } + +cudaError_t cudaSetDevice(int device) { return cudaSuccess; } diff --git a/tests/ut/cpp/stub/runtime/cuda_runtime_api.h b/tests/ut/cpp/stub/runtime/cuda_runtime_api.h index 7f0f0d0be33..e017258b2e2 100644 --- a/tests/ut/cpp/stub/runtime/cuda_runtime_api.h +++ b/tests/ut/cpp/stub/runtime/cuda_runtime_api.h @@ -1,47 +1,47 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef TESTS_UT_STUB_RUNTIME_INCLUDE_CUDA_RUNTIME_API_H_ -#define TESTS_UT_STUB_RUNTIME_INCLUDE_CUDA_RUNTIME_API_H_ - -#include -typedef enum { cudaSuccess = 0, cudaErrorNotReady = 1 } cudaError_t; - -unsigned int cudaEventDefault = 0; - -enum cudaMemcpyKind { - cudaMemcpyHostToHost = 0, - cudaMemcpyHostToDevice = 1, - cudaMemcpyDeviceToHost = 2, - cudaMemcpyDeviceToDevice = 3 -}; - -struct CUstream_st { - int arch; -}; - -typedef struct CUStream_st *cudaStream_t; - -cudaError_t cudaMalloc(void **devPtr, size_t size); -cudaError_t cudaFree(void *devPtr); -cudaError_t cudaMemcpy(void *dst, const void *src, size_t count, enum cudaMemcpyKind kind); -cudaError_t cudaMemGetInfo(size_t *free, size_t *total); -cudaError_t cudaStreamCreate(cudaStream_t *pStream); -cudaError_t cudaStreamDestroy(cudaStream_t stream); -cudaError_t cudaStreamSynchronize(cudaStream_t stream); -cudaError_t cudaGetDeviceCount(int *count); -cudaError_t cudaSetDevice(int device); - -#endif // TESTS_UT_STUB_RUNTIME_INCLUDE_CUDA_RUNTIME_API_H_ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef TESTS_UT_STUB_RUNTIME_INCLUDE_CUDA_RUNTIME_API_H_ +#define TESTS_UT_STUB_RUNTIME_INCLUDE_CUDA_RUNTIME_API_H_ + +#include +typedef enum { cudaSuccess = 0, cudaErrorNotReady = 1 } cudaError_t; + +unsigned int cudaEventDefault = 0; + +enum cudaMemcpyKind { + cudaMemcpyHostToHost = 0, + cudaMemcpyHostToDevice = 1, + cudaMemcpyDeviceToHost = 2, + cudaMemcpyDeviceToDevice = 3 +}; + +struct CUstream_st { + int arch; +}; + +typedef struct CUStream_st *cudaStream_t; + +cudaError_t cudaMalloc(void **devPtr, size_t size); +cudaError_t cudaFree(void *devPtr); +cudaError_t cudaMemcpy(void *dst, const void *src, size_t count, enum cudaMemcpyKind kind); +cudaError_t cudaMemGetInfo(size_t *free, size_t *total); +cudaError_t cudaStreamCreate(cudaStream_t *pStream); +cudaError_t cudaStreamDestroy(cudaStream_t stream); +cudaError_t cudaStreamSynchronize(cudaStream_t stream); +cudaError_t cudaGetDeviceCount(int *count); +cudaError_t cudaSetDevice(int device); + +#endif // TESTS_UT_STUB_RUNTIME_INCLUDE_CUDA_RUNTIME_API_H_ diff --git a/tests/ut/data/dataset/jiebadict/hmm_model.utf8 b/tests/ut/data/dataset/jiebadict/hmm_model.utf8 old mode 100755 new mode 100644 diff --git a/tests/ut/data/dataset/jiebadict/jieba.dict.utf8 b/tests/ut/data/dataset/jiebadict/jieba.dict.utf8 old mode 100755 new mode 100644 diff --git a/tests/ut/data/dataset/testAGNews/test.csv b/tests/ut/data/dataset/testAGNews/test.csv index 09e12a1c1e3..fc2459ad5bb 100644 --- a/tests/ut/data/dataset/testAGNews/test.csv +++ b/tests/ut/data/dataset/testAGNews/test.csv @@ -1,2 +1,2 @@ -3,Background of the selection,"In this day and age, the internet is growing rapidly, the total number of connected devices is increasing and we are entering the era of big data." -4,Related technologies,"""Leaflet is the leading open source JavaScript library for mobile-friendly interactive maps.""" +3,Background of the selection,"In this day and age, the internet is growing rapidly, the total number of connected devices is increasing and we are entering the era of big data." +4,Related technologies,"""Leaflet is the leading open source JavaScript library for mobile-friendly interactive maps.""" diff --git a/tests/ut/data/dataset/testAGNews/train.csv b/tests/ut/data/dataset/testAGNews/train.csv index 12b9cb973c5..e0bf165ef76 100644 --- a/tests/ut/data/dataset/testAGNews/train.csv +++ b/tests/ut/data/dataset/testAGNews/train.csv @@ -1,3 +1,3 @@ -3,Demand analysis,"""Users simply click on the module they want to view to browse information about that module.""" -3,UML Timing Diagram,"Information is mainly displayed using locally stored data and mapping, which is not timely and does not have the ability to update itself." -3,In summary,"This paper implements a map visualization system for Hangzhou city information, using extensive knowledge of visualization techniques." +3,Demand analysis,"""Users simply click on the module they want to view to browse information about that module.""" +3,UML Timing Diagram,"Information is mainly displayed using locally stored data and mapping, which is not timely and does not have the ability to update itself." +3,In summary,"This paper implements a map visualization system for Hangzhou city information, using extensive knowledge of visualization techniques." diff --git a/tests/ut/data/dataset/testAmazonReview/full/test.csv b/tests/ut/data/dataset/testAmazonReview/full/test.csv old mode 100755 new mode 100644 index 61f71ff4e4f..109e4ea983f --- a/tests/ut/data/dataset/testAmazonReview/full/test.csv +++ b/tests/ut/data/dataset/testAmazonReview/full/test.csv @@ -1,3 +1,3 @@ -1,amazing,unlimited buyback! -4,delightful,a funny book! -3,Small,It is a small ball! +1,amazing,unlimited buyback! +4,delightful,a funny book! +3,Small,It is a small ball! diff --git a/tests/ut/data/dataset/testAmazonReview/full/train.csv b/tests/ut/data/dataset/testAmazonReview/full/train.csv old mode 100755 new mode 100644 index 0b20e7b6c28..ab77d2f40b1 --- a/tests/ut/data/dataset/testAmazonReview/full/train.csv +++ b/tests/ut/data/dataset/testAmazonReview/full/train.csv @@ -1,3 +1,3 @@ -3,Satisfied,good quality. -5,good,This is an very good product. -1,bad,work badly. +3,Satisfied,good quality. +5,good,This is an very good product. +1,bad,work badly. diff --git a/tests/ut/data/dataset/testAmazonReview/polarity/test.csv b/tests/ut/data/dataset/testAmazonReview/polarity/test.csv old mode 100755 new mode 100644 index 8157ff65a35..705864681f5 --- a/tests/ut/data/dataset/testAmazonReview/polarity/test.csv +++ b/tests/ut/data/dataset/testAmazonReview/polarity/test.csv @@ -1,2 +1,2 @@ -1,DVD,It is very good! -2,Book,I would read it again lol. +1,DVD,It is very good! +2,Book,I would read it again lol. diff --git a/tests/ut/data/dataset/testAmazonReview/polarity/train.csv b/tests/ut/data/dataset/testAmazonReview/polarity/train.csv old mode 100755 new mode 100644 index 0d3d13af6d9..9ab809e8413 --- a/tests/ut/data/dataset/testAmazonReview/polarity/train.csv +++ b/tests/ut/data/dataset/testAmazonReview/polarity/train.csv @@ -1,3 +1,3 @@ -2,Great Read,I thought this book was excellent! -1,Oh dear,It is so bad! -2,Delicious,A funny product. +2,Great Read,I thought this book was excellent! +1,Oh dear,It is so bad! +2,Delicious,A funny product. diff --git a/tests/ut/data/dataset/testArgoverse/features_2645.pkl b/tests/ut/data/dataset/testArgoverse/features_2645.pkl deleted file mode 100644 index d8508253b30..00000000000 Binary files a/tests/ut/data/dataset/testArgoverse/features_2645.pkl and /dev/null differ diff --git a/tests/ut/data/dataset/testCLUE/wsc/dev.json b/tests/ut/data/dataset/testCLUE/wsc/dev.json old mode 100755 new mode 100644 diff --git a/tests/ut/data/dataset/testCLUE/wsc/test.json b/tests/ut/data/dataset/testCLUE/wsc/test.json old mode 100755 new mode 100644 diff --git a/tests/ut/data/dataset/testCLUE/wsc/train.json b/tests/ut/data/dataset/testCLUE/wsc/train.json old mode 100755 new mode 100644 diff --git a/tests/ut/data/dataset/testCOCO/annotations/captions.json b/tests/ut/data/dataset/testCOCO/annotations/captions.json old mode 100755 new mode 100644 diff --git a/tests/ut/data/dataset/testCaltech101Data/Annotations/apple/annotation_0001.mat b/tests/ut/data/dataset/testCaltech101Data/Annotations/apple/annotation_0001.mat old mode 100755 new mode 100644 diff --git a/tests/ut/data/dataset/testCaltech101Data/Annotations/apple/annotation_0002.mat b/tests/ut/data/dataset/testCaltech101Data/Annotations/apple/annotation_0002.mat old mode 100755 new mode 100644 diff --git a/tests/ut/data/dataset/testCaltech101Data/Annotations/banana/annotation_0001.mat b/tests/ut/data/dataset/testCaltech101Data/Annotations/banana/annotation_0001.mat old mode 100755 new mode 100644 diff --git a/tests/ut/data/dataset/testCaltech101Data/Annotations/banana/annotation_0002.mat b/tests/ut/data/dataset/testCaltech101Data/Annotations/banana/annotation_0002.mat old mode 100755 new mode 100644 diff --git a/tests/ut/data/dataset/testCoNLL2000Dataset/test.txt b/tests/ut/data/dataset/testCoNLL2000Dataset/test.txt old mode 100755 new mode 100644 diff --git a/tests/ut/data/dataset/testCoNLL2000Dataset/train.txt b/tests/ut/data/dataset/testCoNLL2000Dataset/train.txt old mode 100755 new mode 100644 diff --git a/tests/ut/data/dataset/testDBpedia/test.csv b/tests/ut/data/dataset/testDBpedia/test.csv old mode 100755 new mode 100644 diff --git a/tests/ut/data/dataset/testDBpedia/train.csv b/tests/ut/data/dataset/testDBpedia/train.csv old mode 100755 new mode 100644 diff --git a/tests/ut/data/dataset/testGloVe/glove.6B.dim_different.txt b/tests/ut/data/dataset/testGloVe/glove.6B.dim_different.txt index 65830c6aaf0..6ef293e0716 100644 --- a/tests/ut/data/dataset/testGloVe/glove.6B.dim_different.txt +++ b/tests/ut/data/dataset/testGloVe/glove.6B.dim_different.txt @@ -1,6 +1,6 @@ -ok 0.418 0.24968 -0.41242 0.1217 0.34527 -0.04445718411 -! 0.013441 0.23682 -0.16899 0.40951 0.63812 0.47709 -this 0.15164 0.30177 -0.16763 0.17684 0.31719 -is 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 -my 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 +ok 0.418 0.24968 -0.41242 0.1217 0.34527 -0.04445718411 +! 0.013441 0.23682 -0.16899 0.40951 0.63812 0.47709 +this 0.15164 0.30177 -0.16763 0.17684 0.31719 +is 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 +my 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 home 0.26818 0.14346 -0.27877 0.016257 0.11384 0.69923 \ No newline at end of file diff --git a/tests/ut/data/dataset/testGloVe/glove.6B.test.txt b/tests/ut/data/dataset/testGloVe/glove.6B.test.txt index dc5c942ba1d..b414d57d92f 100644 --- a/tests/ut/data/dataset/testGloVe/glove.6B.test.txt +++ b/tests/ut/data/dataset/testGloVe/glove.6B.test.txt @@ -1,6 +1,6 @@ -ok 0.418 0.24968 -0.41242 0.1217 0.34527 -0.04445718411 -! 0.013441 0.23682 -0.16899 0.40951 0.63812 0.47709 -this 0.15164 0.30177 -0.16763 0.17684 0.31719 0.33973 -is 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 -my 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 -home 0.26818 0.14346 -0.27877 0.016257 0.11384 0.69923 +ok 0.418 0.24968 -0.41242 0.1217 0.34527 -0.04445718411 +! 0.013441 0.23682 -0.16899 0.40951 0.63812 0.47709 +this 0.15164 0.30177 -0.16763 0.17684 0.31719 0.33973 +is 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 +my 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 +home 0.26818 0.14346 -0.27877 0.016257 0.11384 0.69923 diff --git a/tests/ut/data/dataset/testGloVe/glove.6B.test.vec b/tests/ut/data/dataset/testGloVe/glove.6B.test.vec index dc5c942ba1d..b414d57d92f 100644 --- a/tests/ut/data/dataset/testGloVe/glove.6B.test.vec +++ b/tests/ut/data/dataset/testGloVe/glove.6B.test.vec @@ -1,6 +1,6 @@ -ok 0.418 0.24968 -0.41242 0.1217 0.34527 -0.04445718411 -! 0.013441 0.23682 -0.16899 0.40951 0.63812 0.47709 -this 0.15164 0.30177 -0.16763 0.17684 0.31719 0.33973 -is 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 -my 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 -home 0.26818 0.14346 -0.27877 0.016257 0.11384 0.69923 +ok 0.418 0.24968 -0.41242 0.1217 0.34527 -0.04445718411 +! 0.013441 0.23682 -0.16899 0.40951 0.63812 0.47709 +this 0.15164 0.30177 -0.16763 0.17684 0.31719 0.33973 +is 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 +my 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 +home 0.26818 0.14346 -0.27877 0.016257 0.11384 0.69923 diff --git a/tests/ut/data/dataset/testGloVe/glove.6B.with_info.txt b/tests/ut/data/dataset/testGloVe/glove.6B.with_info.txt index 74d030d01e3..9b4e6dd43bd 100644 --- a/tests/ut/data/dataset/testGloVe/glove.6B.with_info.txt +++ b/tests/ut/data/dataset/testGloVe/glove.6B.with_info.txt @@ -1,7 +1,7 @@ -6 6 -ok 0.418 0.24968 -0.41242 0.1217 0.34527 -0.04445718411 -! 0.013441 0.23682 -0.16899 0.40951 0.63812 0.47709 -this 0.15164 0.30177 -0.16763 0.17684 0.31719 0.33973 -is 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 -my 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 -home 0.26818 0.14346 -0.27877 0.016257 0.11384 0.69923 +6 6 +ok 0.418 0.24968 -0.41242 0.1217 0.34527 -0.04445718411 +! 0.013441 0.23682 -0.16899 0.40951 0.63812 0.47709 +this 0.15164 0.30177 -0.16763 0.17684 0.31719 0.33973 +is 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 +my 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 +home 0.26818 0.14346 -0.27877 0.016257 0.11384 0.69923 diff --git a/tests/ut/data/dataset/testGloVe/glove.6B.with_wrong_info.txt b/tests/ut/data/dataset/testGloVe/glove.6B.with_wrong_info.txt index 86d3cc3952f..d9b752e7f66 100644 --- a/tests/ut/data/dataset/testGloVe/glove.6B.with_wrong_info.txt +++ b/tests/ut/data/dataset/testGloVe/glove.6B.with_wrong_info.txt @@ -1,7 +1,7 @@ -the 0.418 0.24968 -0.41242 0.1217 0.34527 -0.04445718411 -, 0.013441 0.23682 -0.16899 0.40951 0.63812 0.47709 -. 0.15164 0.30177 -0.16763 0.17684 0.31719 0.33973 -6 6 -of 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 -to 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 +the 0.418 0.24968 -0.41242 0.1217 0.34527 -0.04445718411 +, 0.013441 0.23682 -0.16899 0.40951 0.63812 0.47709 +. 0.15164 0.30177 -0.16763 0.17684 0.31719 0.33973 +6 6 +of 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 +to 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 and 0.26818 0.14346 -0.27877 0.016257 0.11384 0.69923 \ No newline at end of file diff --git a/tests/ut/data/dataset/testGloVe/words_with_big_letter.txt b/tests/ut/data/dataset/testGloVe/words_with_big_letter.txt index efa25a4b390..8643123fe2a 100644 --- a/tests/ut/data/dataset/testGloVe/words_with_big_letter.txt +++ b/tests/ut/data/dataset/testGloVe/words_with_big_letter.txt @@ -1,7 +1,7 @@ -ok -! -This -iS -my -HOME -. +ok +! +This +iS +my +HOME +. diff --git a/tests/ut/data/dataset/testKITTI/data_object_image_2/training/image_2/000000.png b/tests/ut/data/dataset/testKITTI/data_object_image_2/training/image_2/000000.png old mode 100755 new mode 100644 diff --git a/tests/ut/data/dataset/testKITTI/data_object_label_2/training/label_2/000000.txt b/tests/ut/data/dataset/testKITTI/data_object_label_2/training/label_2/000000.txt old mode 100755 new mode 100644 diff --git a/tests/ut/data/dataset/testLJSpeechData/LJSpeech-1.1/metadata.csv b/tests/ut/data/dataset/testLJSpeechData/LJSpeech-1.1/metadata.csv old mode 100755 new mode 100644 diff --git a/tests/ut/data/dataset/testLJSpeechData/metadata.csv b/tests/ut/data/dataset/testLJSpeechData/metadata.csv old mode 100755 new mode 100644 diff --git a/tests/ut/data/dataset/testNumpySlicesDataset/heart.csv b/tests/ut/data/dataset/testNumpySlicesDataset/heart.csv index 92bc9db643a..b8a5a194ae2 100644 --- a/tests/ut/data/dataset/testNumpySlicesDataset/heart.csv +++ b/tests/ut/data/dataset/testNumpySlicesDataset/heart.csv @@ -1,6 +1,6 @@ -age,sex,height,weight,slope,state,target -65,0,161,45,93,fixed,1 -72,1,164,60,86,good,0 -45,0,174,70,79,bad,1 -73,1,173,65,70,good,1 +age,sex,height,weight,slope,state,target +65,0,161,45,93,fixed,1 +72,1,164,60,86,good,0 +45,0,174,70,79,bad,1 +73,1,173,65,70,good,1 55,1,182,80,104,good,0 \ No newline at end of file diff --git a/tests/ut/data/dataset/testOmniglot/images_background/Alphabet_of_the_Magi/character01/1_1.jpg b/tests/ut/data/dataset/testOmniglot/images_background/Alphabet_of_the_Magi/character01/1_1.jpg old mode 100755 new mode 100644 diff --git a/tests/ut/data/dataset/testOmniglot/images_background/Alphabet_of_the_Magi/character01/1_2.jpg b/tests/ut/data/dataset/testOmniglot/images_background/Alphabet_of_the_Magi/character01/1_2.jpg old mode 100755 new mode 100644 diff --git a/tests/ut/data/dataset/testOmniglot/images_background/Alphabet_of_the_Magi/character02/2_1.jpg b/tests/ut/data/dataset/testOmniglot/images_background/Alphabet_of_the_Magi/character02/2_1.jpg old mode 100755 new mode 100644 diff --git a/tests/ut/data/dataset/testOmniglot/images_background/Alphabet_of_the_Magi/character02/2_2.jpg b/tests/ut/data/dataset/testOmniglot/images_background/Alphabet_of_the_Magi/character02/2_2.jpg old mode 100755 new mode 100644 diff --git a/tests/ut/data/dataset/testOmniglot/images_evaluation/Angelic/character01/1_1.jpg b/tests/ut/data/dataset/testOmniglot/images_evaluation/Angelic/character01/1_1.jpg old mode 100755 new mode 100644 diff --git a/tests/ut/data/dataset/testOmniglot/images_evaluation/Angelic/character01/1_2.png b/tests/ut/data/dataset/testOmniglot/images_evaluation/Angelic/character01/1_2.png old mode 100755 new mode 100644 diff --git a/tests/ut/data/dataset/testOmniglot/images_evaluation/Angelic/character02/1_1.jpg b/tests/ut/data/dataset/testOmniglot/images_evaluation/Angelic/character02/1_1.jpg old mode 100755 new mode 100644 diff --git a/tests/ut/data/dataset/testOmniglot/images_evaluation/Angelic/character02/1_2.jpg b/tests/ut/data/dataset/testOmniglot/images_evaluation/Angelic/character02/1_2.jpg old mode 100755 new mode 100644 diff --git a/tests/ut/data/dataset/testPhotoTourData/liberty/info.txt b/tests/ut/data/dataset/testPhotoTourData/liberty/info.txt old mode 100755 new mode 100644 index 3ca881ad187..2c44f2b5203 --- a/tests/ut/data/dataset/testPhotoTourData/liberty/info.txt +++ b/tests/ut/data/dataset/testPhotoTourData/liberty/info.txt @@ -1,100 +1,100 @@ -0 0 -0 0 -1 0 -1 0 -1 0 -2 0 -2 0 -3 0 -3 0 -4 0 -4 0 -5 0 -5 0 -6 0 -6 0 -7 0 -7 0 -7 0 -8 0 -8 0 -8 0 -9 0 -9 0 -10 0 -10 0 -11 0 -11 0 -12 0 -12 0 -12 0 -13 0 -13 0 -13 0 -14 0 -14 0 -14 0 -15 0 -15 0 -15 0 -16 0 -16 0 -16 0 -17 0 -17 0 -17 0 -18 0 -18 0 -18 0 -19 0 -19 0 -19 0 -20 0 -20 0 -21 0 -21 0 -22 0 -22 0 -23 0 -23 0 -24 0 -24 0 -25 0 -25 0 -26 0 -26 0 -26 0 -27 0 -27 0 -27 0 -28 0 -28 0 -28 0 -29 0 -29 0 -29 0 -30 0 -30 0 -31 0 -31 0 -31 0 -32 0 -32 0 -32 0 -32 0 -33 0 -33 0 -34 0 -34 0 -34 0 -35 0 -35 0 -36 0 -36 0 -37 0 -37 0 -38 0 -38 0 -39 0 -39 0 +0 0 +0 0 +1 0 +1 0 +1 0 +2 0 +2 0 +3 0 +3 0 +4 0 +4 0 +5 0 +5 0 +6 0 +6 0 +7 0 +7 0 +7 0 +8 0 +8 0 +8 0 +9 0 +9 0 +10 0 +10 0 +11 0 +11 0 +12 0 +12 0 +12 0 +13 0 +13 0 +13 0 +14 0 +14 0 +14 0 +15 0 +15 0 +15 0 +16 0 +16 0 +16 0 +17 0 +17 0 +17 0 +18 0 +18 0 +18 0 +19 0 +19 0 +19 0 +20 0 +20 0 +21 0 +21 0 +22 0 +22 0 +23 0 +23 0 +24 0 +24 0 +25 0 +25 0 +26 0 +26 0 +26 0 +27 0 +27 0 +27 0 +28 0 +28 0 +28 0 +29 0 +29 0 +29 0 +30 0 +30 0 +31 0 +31 0 +31 0 +32 0 +32 0 +32 0 +32 0 +33 0 +33 0 +34 0 +34 0 +34 0 +35 0 +35 0 +36 0 +36 0 +37 0 +37 0 +38 0 +38 0 +39 0 +39 0 39 0 \ No newline at end of file diff --git a/tests/ut/data/dataset/testPhotoTourData/liberty/m50_100000_100000_0.txt b/tests/ut/data/dataset/testPhotoTourData/liberty/m50_100000_100000_0.txt old mode 100755 new mode 100644 diff --git a/tests/ut/data/dataset/testSogouNews/test.csv b/tests/ut/data/dataset/testSogouNews/test.csv old mode 100755 new mode 100644 index d54d610a6c0..f7ecc62cb33 --- a/tests/ut/data/dataset/testSogouNews/test.csv +++ b/tests/ut/data/dataset/testSogouNews/test.csv @@ -1,3 +1,3 @@ -"1","Make history","Su Bingtian's 100m breakthrough\n 9.83" -"4","Tesla price","Tesla reduced its price by 70000 yuan" +"1","Make history","Su Bingtian's 100m breakthrough\n 9.83" +"4","Tesla price","Tesla reduced its price by 70000 yuan" "1","Opening ceremony of the 14th National Games","On the evening of September 15, Beijing time, the 14th games of the people's Republic of China opened in Xi'an Olympic Sports Center Stadium, Shaanxi Province. Yang Qian, the first gold medalist of the Chinese delegation in the Tokyo Olympic Games and a Post-00 shooter, lit the main torch platform. From then on, to September 27, the 14th National Games flame will burn here for 12 days." \ No newline at end of file diff --git a/tests/ut/data/dataset/testSogouNews/train.csv b/tests/ut/data/dataset/testSogouNews/train.csv old mode 100755 new mode 100644 index 9492eb27de3..56c84956ac3 --- a/tests/ut/data/dataset/testSogouNews/train.csv +++ b/tests/ut/data/dataset/testSogouNews/train.csv @@ -1,3 +1,3 @@ -"1","Jefferson commented on thick eyebrow: he has the top five talents in the league, but he is not the top five","They say he has the talent of the top five in the league. The talent of the top five in the league is one of the most disrespectful statements. I say he has the talent of the top five in the league, but he is not the top five players because the top five players play every night." -"3","Group pictures: Liu Shishi's temperament in early autumn released a large piece of micro curly long hair, elegant, lazy, gentle and capable","Liu Shishi's latest group of cover magazine blockbusters are released. In the photos, Liu Shishi's long hair is slightly curly, or camel colored belted woolen coat, or plaid suit, which is gentle and elegant and beautiful to a new height." -"3","Ni Ni deduces elegant retro style in different styles","Ni Ni's latest group of magazine cover blockbusters released that wearing gift hats is cool, retro, unique and full of fashion expression." +"1","Jefferson commented on thick eyebrow: he has the top five talents in the league, but he is not the top five","They say he has the talent of the top five in the league. The talent of the top five in the league is one of the most disrespectful statements. I say he has the talent of the top five in the league, but he is not the top five players because the top five players play every night." +"3","Group pictures: Liu Shishi's temperament in early autumn released a large piece of micro curly long hair, elegant, lazy, gentle and capable","Liu Shishi's latest group of cover magazine blockbusters are released. In the photos, Liu Shishi's long hair is slightly curly, or camel colored belted woolen coat, or plaid suit, which is gentle and elegant and beautiful to a new height." +"3","Ni Ni deduces elegant retro style in different styles","Ni Ni's latest group of magazine cover blockbusters released that wearing gift hats is cool, retro, unique and full of fashion expression." diff --git a/tests/ut/data/dataset/testVectors/char_n_gram_20.txt b/tests/ut/data/dataset/testVectors/char_n_gram_20.txt index 51182469981..20197dd8d7c 100644 --- a/tests/ut/data/dataset/testVectors/char_n_gram_20.txt +++ b/tests/ut/data/dataset/testVectors/char_n_gram_20.txt @@ -1,20 +1,20 @@ -1gram-e -0.655379 0.574261 -0.714026 -0.148858 -0.0534275 -1gram-a -0.288984 -0.225616 0.323913 -0.261039 -0.0628034 -1gram-t 0.408448 0.175862 -0.296873 -0.209094 -0.53478 -1gram-i 0.278486 -0.910641 -0.743681 -0.734405 0.519959 -1gram-n -0.0712582 0.0898121 -1.12567 -0.815067 -0.435836 -1gram-o -0.182786 0.535789 -0.391385 0.181972 0.317399 -1gram-r 0.68474 0.103464 0.201631 -0.65319 0.554142 -1gram-s -0.175988 -0.813322 0.465603 -0.0951031 0.193374 -1gram-h -0.39348 -0.678079 0.233101 0.431805 2.04905 -1gram-l -0.451299 -0.268223 -0.787034 -0.991984 0.251244 -1gram-d 0.799629 -0.326191 -0.474959 0.235657 0.796227 -2gram-e#END# -2.26956 0.288491 -0.740001 0.661703 0.147355 -1gram-c -0.0413309 0.436135 -0.835305 -1.64429 -1.08329 -2gram-s#END# 0.657201 2.11761 -1.59276 0.432072 1.21395 -1gram-u -0.25203 -0.176365 -0.263038 -0.995372 -1.24916 -2gram-#BEGIN#t -0.96853 -0.789463 0.515762 2.02107 -1.64635 -1gram-m 0.422293 -0.149725 -0.734202 1.27342 0.232722 -2gram-he -0.785562 0.63378 -1.23667 -0.693956 0.395988 -2gram-th 0.663336 -0.240809 -1.87298 0.364651 0.26296 +1gram-e -0.655379 0.574261 -0.714026 -0.148858 -0.0534275 +1gram-a -0.288984 -0.225616 0.323913 -0.261039 -0.0628034 +1gram-t 0.408448 0.175862 -0.296873 -0.209094 -0.53478 +1gram-i 0.278486 -0.910641 -0.743681 -0.734405 0.519959 +1gram-n -0.0712582 0.0898121 -1.12567 -0.815067 -0.435836 +1gram-o -0.182786 0.535789 -0.391385 0.181972 0.317399 +1gram-r 0.68474 0.103464 0.201631 -0.65319 0.554142 +1gram-s -0.175988 -0.813322 0.465603 -0.0951031 0.193374 +1gram-h -0.39348 -0.678079 0.233101 0.431805 2.04905 +1gram-l -0.451299 -0.268223 -0.787034 -0.991984 0.251244 +1gram-d 0.799629 -0.326191 -0.474959 0.235657 0.796227 +2gram-e#END# -2.26956 0.288491 -0.740001 0.661703 0.147355 +1gram-c -0.0413309 0.436135 -0.835305 -1.64429 -1.08329 +2gram-s#END# 0.657201 2.11761 -1.59276 0.432072 1.21395 +1gram-u -0.25203 -0.176365 -0.263038 -0.995372 -1.24916 +2gram-#BEGIN#t -0.96853 -0.789463 0.515762 2.02107 -1.64635 +1gram-m 0.422293 -0.149725 -0.734202 1.27342 0.232722 +2gram-he -0.785562 0.63378 -1.23667 -0.693956 0.395988 +2gram-th 0.663336 -0.240809 -1.87298 0.364651 0.26296 2gram-n#END# -0.149612 -0.664577 -1.12344 2.23695 0.610406 \ No newline at end of file diff --git a/tests/ut/data/dataset/testVectors/char_n_gram_20_dim_different.txt b/tests/ut/data/dataset/testVectors/char_n_gram_20_dim_different.txt index 9d0ae01e09f..5d29fce4895 100644 --- a/tests/ut/data/dataset/testVectors/char_n_gram_20_dim_different.txt +++ b/tests/ut/data/dataset/testVectors/char_n_gram_20_dim_different.txt @@ -1,20 +1,20 @@ -1gram-e -0.655379 0.574261 -0.714026 -0.148858 -0.0534275 -1gram-a -0.288984 -0.225616 0.323913 -0.261039 -0.0628034 -1gram-t 0.408448 0.175862 -0.296873 -0.209094 -0.53478 -1gram-i 0.278486 -0.910641 -0.743681 -0.734405 0.519959 -1gram-n -0.0712582 0.0898121 -1.12567 -0.815067 -0.435836 -1gram-o -0.182786 0.535789 -0.391385 0.181972 0.317399 -1gram-r 0.68474 0.103464 0.201631 -0.65319 0.554142 -1gram-s -0.175988 -0.813322 0.465603 -0.0951031 0.193374 -1gram-h -0.39348 -0.678079 0.233101 0.431805 2.04905 -1gram-l -0.451299 -0.268223 -0.787034 -0.991984 0.251244 -1gram-d 0.799629 -0.326191 -0.474959 0.235657 0.796227 -2gram-e#END# -2.26956 0.288491 -0.740001 0.661703 0.147355 -1gram-c -0.0413309 0.436135 -0.835305 -1.64429 -1.08329 -2gram-s#END# 0.657201 2.11761 -1.59276 0.432072 1.21395 -1gram-u -0.25203 -0.176365 -0.263038 -0.995372 -1.24916 -2gram-#BEGIN#t -0.96853 -0.789463 0.515762 2.02107 -1gram-m 0.422293 -0.149725 -0.734202 1.27342 0.232722 -2gram-he -0.785562 0.63378 -1.23667 -0.693956 0.395988 -2gram-th 0.663336 -0.240809 -1.87298 0.364651 0.26296 +1gram-e -0.655379 0.574261 -0.714026 -0.148858 -0.0534275 +1gram-a -0.288984 -0.225616 0.323913 -0.261039 -0.0628034 +1gram-t 0.408448 0.175862 -0.296873 -0.209094 -0.53478 +1gram-i 0.278486 -0.910641 -0.743681 -0.734405 0.519959 +1gram-n -0.0712582 0.0898121 -1.12567 -0.815067 -0.435836 +1gram-o -0.182786 0.535789 -0.391385 0.181972 0.317399 +1gram-r 0.68474 0.103464 0.201631 -0.65319 0.554142 +1gram-s -0.175988 -0.813322 0.465603 -0.0951031 0.193374 +1gram-h -0.39348 -0.678079 0.233101 0.431805 2.04905 +1gram-l -0.451299 -0.268223 -0.787034 -0.991984 0.251244 +1gram-d 0.799629 -0.326191 -0.474959 0.235657 0.796227 +2gram-e#END# -2.26956 0.288491 -0.740001 0.661703 0.147355 +1gram-c -0.0413309 0.436135 -0.835305 -1.64429 -1.08329 +2gram-s#END# 0.657201 2.11761 -1.59276 0.432072 1.21395 +1gram-u -0.25203 -0.176365 -0.263038 -0.995372 -1.24916 +2gram-#BEGIN#t -0.96853 -0.789463 0.515762 2.02107 +1gram-m 0.422293 -0.149725 -0.734202 1.27342 0.232722 +2gram-he -0.785562 0.63378 -1.23667 -0.693956 0.395988 +2gram-th 0.663336 -0.240809 -1.87298 0.364651 0.26296 2gram-n#END# -0.149612 -0.664577 -1.12344 2.23695 0.610406 \ No newline at end of file diff --git a/tests/ut/data/dataset/testVectors/vectors.txt b/tests/ut/data/dataset/testVectors/vectors.txt index dc5c942ba1d..b414d57d92f 100644 --- a/tests/ut/data/dataset/testVectors/vectors.txt +++ b/tests/ut/data/dataset/testVectors/vectors.txt @@ -1,6 +1,6 @@ -ok 0.418 0.24968 -0.41242 0.1217 0.34527 -0.04445718411 -! 0.013441 0.23682 -0.16899 0.40951 0.63812 0.47709 -this 0.15164 0.30177 -0.16763 0.17684 0.31719 0.33973 -is 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 -my 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 -home 0.26818 0.14346 -0.27877 0.016257 0.11384 0.69923 +ok 0.418 0.24968 -0.41242 0.1217 0.34527 -0.04445718411 +! 0.013441 0.23682 -0.16899 0.40951 0.63812 0.47709 +this 0.15164 0.30177 -0.16763 0.17684 0.31719 0.33973 +is 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 +my 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 +home 0.26818 0.14346 -0.27877 0.016257 0.11384 0.69923 diff --git a/tests/ut/data/dataset/testVectors/vectors_dim_different.txt b/tests/ut/data/dataset/testVectors/vectors_dim_different.txt index 65830c6aaf0..6ef293e0716 100644 --- a/tests/ut/data/dataset/testVectors/vectors_dim_different.txt +++ b/tests/ut/data/dataset/testVectors/vectors_dim_different.txt @@ -1,6 +1,6 @@ -ok 0.418 0.24968 -0.41242 0.1217 0.34527 -0.04445718411 -! 0.013441 0.23682 -0.16899 0.40951 0.63812 0.47709 -this 0.15164 0.30177 -0.16763 0.17684 0.31719 -is 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 -my 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 +ok 0.418 0.24968 -0.41242 0.1217 0.34527 -0.04445718411 +! 0.013441 0.23682 -0.16899 0.40951 0.63812 0.47709 +this 0.15164 0.30177 -0.16763 0.17684 0.31719 +is 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 +my 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 home 0.26818 0.14346 -0.27877 0.016257 0.11384 0.69923 \ No newline at end of file diff --git a/tests/ut/data/dataset/testVectors/vectors_with_info.txt b/tests/ut/data/dataset/testVectors/vectors_with_info.txt index b708aa25bc4..c273f7c5911 100644 --- a/tests/ut/data/dataset/testVectors/vectors_with_info.txt +++ b/tests/ut/data/dataset/testVectors/vectors_with_info.txt @@ -1,7 +1,7 @@ -6 6 -ok 0.418 0.24968 -0.41242 0.1217 0.34527 -0.04445718411 -! 0.013441 0.23682 -0.16899 0.40951 0.63812 0.47709 -this 0.15164 0.30177 -0.16763 0.17684 0.31719 0.33973 -is 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 -my 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 +6 6 +ok 0.418 0.24968 -0.41242 0.1217 0.34527 -0.04445718411 +! 0.013441 0.23682 -0.16899 0.40951 0.63812 0.47709 +this 0.15164 0.30177 -0.16763 0.17684 0.31719 0.33973 +is 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 +my 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 home 0.26818 0.14346 -0.27877 0.016257 0.11384 0.69923 \ No newline at end of file diff --git a/tests/ut/data/dataset/testVectors/vectors_with_wrong_info.txt b/tests/ut/data/dataset/testVectors/vectors_with_wrong_info.txt index 86d3cc3952f..d9b752e7f66 100644 --- a/tests/ut/data/dataset/testVectors/vectors_with_wrong_info.txt +++ b/tests/ut/data/dataset/testVectors/vectors_with_wrong_info.txt @@ -1,7 +1,7 @@ -the 0.418 0.24968 -0.41242 0.1217 0.34527 -0.04445718411 -, 0.013441 0.23682 -0.16899 0.40951 0.63812 0.47709 -. 0.15164 0.30177 -0.16763 0.17684 0.31719 0.33973 -6 6 -of 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 -to 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 +the 0.418 0.24968 -0.41242 0.1217 0.34527 -0.04445718411 +, 0.013441 0.23682 -0.16899 0.40951 0.63812 0.47709 +. 0.15164 0.30177 -0.16763 0.17684 0.31719 0.33973 +6 6 +of 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 +to 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 and 0.26818 0.14346 -0.27877 0.016257 0.11384 0.69923 \ No newline at end of file diff --git a/tests/ut/data/dataset/testWIDERFace/WIDER_test/images/0--Abs/0_Abs_mypic_1_111.jpg b/tests/ut/data/dataset/testWIDERFace/WIDER_test/images/0--Abs/0_Abs_mypic_1_111.jpg old mode 100755 new mode 100644 diff --git a/tests/ut/data/dataset/testWIDERFace/WIDER_test/images/1--Pushup/1_Pushup_mypic_1_111.jpg b/tests/ut/data/dataset/testWIDERFace/WIDER_test/images/1--Pushup/1_Pushup_mypic_1_111.jpg old mode 100755 new mode 100644 diff --git a/tests/ut/data/dataset/testWIDERFace/WIDER_test/images/1--Pushup/1_Pushup_mypic_7_777.jpg b/tests/ut/data/dataset/testWIDERFace/WIDER_test/images/1--Pushup/1_Pushup_mypic_7_777.jpg old mode 100755 new mode 100644 diff --git a/tests/ut/data/dataset/testWIDERFace/WIDER_train/images/0--Abs/0_Abs_mypic_4_444.jpg b/tests/ut/data/dataset/testWIDERFace/WIDER_train/images/0--Abs/0_Abs_mypic_4_444.jpg old mode 100755 new mode 100644 diff --git a/tests/ut/data/dataset/testWIDERFace/WIDER_train/images/1--Pushup/1_Pushup_mypic_3_333.jpg b/tests/ut/data/dataset/testWIDERFace/WIDER_train/images/1--Pushup/1_Pushup_mypic_3_333.jpg old mode 100755 new mode 100644 diff --git a/tests/ut/data/dataset/testWIDERFace/WIDER_val/images/0--Abs/0_Abs_mypic_5_555.jpg b/tests/ut/data/dataset/testWIDERFace/WIDER_val/images/0--Abs/0_Abs_mypic_5_555.jpg old mode 100755 new mode 100644 diff --git a/tests/ut/data/dataset/testWIDERFace/WIDER_val/images/1--Pushup/1_Pushup_mypic_4_444.jpg b/tests/ut/data/dataset/testWIDERFace/WIDER_val/images/1--Pushup/1_Pushup_mypic_4_444.jpg old mode 100755 new mode 100644 diff --git a/tests/ut/data/dataset/testWIDERFace/wider_face_split/wider_face_test_filelist.txt b/tests/ut/data/dataset/testWIDERFace/wider_face_split/wider_face_test_filelist.txt old mode 100755 new mode 100644 diff --git a/tests/ut/data/dataset/testWIDERFace/wider_face_split/wider_face_train_bbx_gt.txt b/tests/ut/data/dataset/testWIDERFace/wider_face_split/wider_face_train_bbx_gt.txt old mode 100755 new mode 100644 diff --git a/tests/ut/data/dataset/testWIDERFace/wider_face_split/wider_face_val_bbx_gt.txt b/tests/ut/data/dataset/testWIDERFace/wider_face_split/wider_face_val_bbx_gt.txt old mode 100755 new mode 100644 diff --git a/tests/ut/data/dataset/testYahooAnswers/test.csv b/tests/ut/data/dataset/testYahooAnswers/test.csv old mode 100755 new mode 100644 diff --git a/tests/ut/data/dataset/testYahooAnswers/train.csv b/tests/ut/data/dataset/testYahooAnswers/train.csv old mode 100755 new mode 100644 index ba73fe5107d..a4321b7201a --- a/tests/ut/data/dataset/testYahooAnswers/train.csv +++ b/tests/ut/data/dataset/testYahooAnswers/train.csv @@ -1,4 +1,4 @@ -"3","My Chinese teacher","I have a Chinese Teacher.","She is from LanCha." -"5","Last weekend","We played games, we were very happy.","I visited my friends." -"1","A Happy Day","Last Sunday, I visited my grandmother.","I counted the flowers." -"8","My Good Friend","She lives in China.","He likes listening to music." +"3","My Chinese teacher","I have a Chinese Teacher.","She is from LanCha." +"5","Last weekend","We played games, we were very happy.","I visited my friends." +"1","A Happy Day","Last Sunday, I visited my grandmother.","I counted the flowers." +"8","My Good Friend","She lives in China.","He likes listening to music." diff --git a/tests/ut/data/dataset/testYelpReview/full/test.csv b/tests/ut/data/dataset/testYelpReview/full/test.csv index 9b595248c53..128681bcffb 100644 --- a/tests/ut/data/dataset/testYelpReview/full/test.csv +++ b/tests/ut/data/dataset/testYelpReview/full/test.csv @@ -1,2 +1,2 @@ -1,"\""YelpFull\"" service was very good.\n" -1,"\""YelpFull\"" service was very bad.\n" +1,"\""YelpFull\"" service was very good.\n" +1,"\""YelpFull\"" service was very bad.\n" diff --git a/tests/ut/data/dataset/testYelpReview/full/train.csv b/tests/ut/data/dataset/testYelpReview/full/train.csv index fd89c3dc691..8f3f5362bc6 100644 --- a/tests/ut/data/dataset/testYelpReview/full/train.csv +++ b/tests/ut/data/dataset/testYelpReview/full/train.csv @@ -1,3 +1,3 @@ -5,Yelpfull's drink tastes bad.\n -2,Yelpfull's food is terrible.\n -4,Yelpful's service was very good.\n +5,Yelpfull's drink tastes bad.\n +2,Yelpfull's food is terrible.\n +4,Yelpful's service was very good.\n diff --git a/tests/ut/data/dataset/test_fast_text/fast_text.txt b/tests/ut/data/dataset/test_fast_text/fast_text.txt index 74d030d01e3..9b4e6dd43bd 100644 --- a/tests/ut/data/dataset/test_fast_text/fast_text.txt +++ b/tests/ut/data/dataset/test_fast_text/fast_text.txt @@ -1,7 +1,7 @@ -6 6 -ok 0.418 0.24968 -0.41242 0.1217 0.34527 -0.04445718411 -! 0.013441 0.23682 -0.16899 0.40951 0.63812 0.47709 -this 0.15164 0.30177 -0.16763 0.17684 0.31719 0.33973 -is 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 -my 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 -home 0.26818 0.14346 -0.27877 0.016257 0.11384 0.69923 +6 6 +ok 0.418 0.24968 -0.41242 0.1217 0.34527 -0.04445718411 +! 0.013441 0.23682 -0.16899 0.40951 0.63812 0.47709 +this 0.15164 0.30177 -0.16763 0.17684 0.31719 0.33973 +is 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 +my 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 +home 0.26818 0.14346 -0.27877 0.016257 0.11384 0.69923 diff --git a/tests/ut/data/dataset/test_fast_text/fast_text.vec b/tests/ut/data/dataset/test_fast_text/fast_text.vec index 74d030d01e3..9b4e6dd43bd 100644 --- a/tests/ut/data/dataset/test_fast_text/fast_text.vec +++ b/tests/ut/data/dataset/test_fast_text/fast_text.vec @@ -1,7 +1,7 @@ -6 6 -ok 0.418 0.24968 -0.41242 0.1217 0.34527 -0.04445718411 -! 0.013441 0.23682 -0.16899 0.40951 0.63812 0.47709 -this 0.15164 0.30177 -0.16763 0.17684 0.31719 0.33973 -is 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 -my 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 -home 0.26818 0.14346 -0.27877 0.016257 0.11384 0.69923 +6 6 +ok 0.418 0.24968 -0.41242 0.1217 0.34527 -0.04445718411 +! 0.013441 0.23682 -0.16899 0.40951 0.63812 0.47709 +this 0.15164 0.30177 -0.16763 0.17684 0.31719 0.33973 +is 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 +my 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 +home 0.26818 0.14346 -0.27877 0.016257 0.11384 0.69923 diff --git a/tests/ut/data/dataset/test_fast_text/fast_text_dim_different.vec b/tests/ut/data/dataset/test_fast_text/fast_text_dim_different.vec index 54d94cfa01c..a50031f6697 100644 --- a/tests/ut/data/dataset/test_fast_text/fast_text_dim_different.vec +++ b/tests/ut/data/dataset/test_fast_text/fast_text_dim_different.vec @@ -1,7 +1,7 @@ -6 6 -ok 0.418 0.24968 -0.41242 0.1217 0.34527 -0.04445718411 -! 0.013441 0.23682 -0.16899 0.40951 0.63812 0.47709 -this 0.15164 0.30177 -0.16763 0.17684 0.31719 -is 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 -my 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 +6 6 +ok 0.418 0.24968 -0.41242 0.1217 0.34527 -0.04445718411 +! 0.013441 0.23682 -0.16899 0.40951 0.63812 0.47709 +this 0.15164 0.30177 -0.16763 0.17684 0.31719 +is 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 +my 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 home 0.26818 0.14346 -0.27877 0.016257 0.11384 0.69923 \ No newline at end of file diff --git a/tests/ut/data/dataset/test_fast_text/fast_text_with_wrong_info.vec b/tests/ut/data/dataset/test_fast_text/fast_text_with_wrong_info.vec index 86d3cc3952f..d9b752e7f66 100644 --- a/tests/ut/data/dataset/test_fast_text/fast_text_with_wrong_info.vec +++ b/tests/ut/data/dataset/test_fast_text/fast_text_with_wrong_info.vec @@ -1,7 +1,7 @@ -the 0.418 0.24968 -0.41242 0.1217 0.34527 -0.04445718411 -, 0.013441 0.23682 -0.16899 0.40951 0.63812 0.47709 -. 0.15164 0.30177 -0.16763 0.17684 0.31719 0.33973 -6 6 -of 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 -to 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 +the 0.418 0.24968 -0.41242 0.1217 0.34527 -0.04445718411 +, 0.013441 0.23682 -0.16899 0.40951 0.63812 0.47709 +. 0.15164 0.30177 -0.16763 0.17684 0.31719 0.33973 +6 6 +of 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 +to 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 and 0.26818 0.14346 -0.27877 0.016257 0.11384 0.69923 \ No newline at end of file diff --git a/tests/ut/data/dataset/test_fast_text/words_with_big_letter.txt b/tests/ut/data/dataset/test_fast_text/words_with_big_letter.txt index efa25a4b390..8643123fe2a 100644 --- a/tests/ut/data/dataset/test_fast_text/words_with_big_letter.txt +++ b/tests/ut/data/dataset/test_fast_text/words_with_big_letter.txt @@ -1,7 +1,7 @@ -ok -! -This -iS -my -HOME -. +ok +! +This +iS +my +HOME +. diff --git a/tests/ut/data/profiler_data/profiler/aicore_intermediate_1_detail.csv b/tests/ut/data/profiler_data/profiler/aicore_intermediate_1_detail.csv old mode 100755 new mode 100644 index ec117c5346b..e7cf6704971 --- a/tests/ut/data/profiler_data/profiler/aicore_intermediate_1_detail.csv +++ b/tests/ut/data/profiler_data/profiler/aicore_intermediate_1_detail.csv @@ -1,11 +1,11 @@ -full_op_time,execution_time -Default/AtomicAddrClean-op104,0.00133 -Default/AtomicAddrClean-op105,0.000987 -Default/AtomicAddrClean-op106,0.001129 -Default/Cast-op10,0.00466 -Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/Cast-op12,0.002366 -Gradients/Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/gradConv2D/Cast-op53,0.004879 -Default/TransData-op11,0.006366 -Gradients/Default/network-WithLossCell/_backbone-LeNet5/gradReshape/TransData-op44,0.006782 -Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/Conv2D-op13,0.05651 -Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/MatMul-op9,0.370864 +full_op_time,execution_time +Default/AtomicAddrClean-op104,0.00133 +Default/AtomicAddrClean-op105,0.000987 +Default/AtomicAddrClean-op106,0.001129 +Default/Cast-op10,0.00466 +Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/Cast-op12,0.002366 +Gradients/Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/gradConv2D/Cast-op53,0.004879 +Default/TransData-op11,0.006366 +Gradients/Default/network-WithLossCell/_backbone-LeNet5/gradReshape/TransData-op44,0.006782 +Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/Conv2D-op13,0.05651 +Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/MatMul-op9,0.370864 diff --git a/tests/ut/data/profiler_data/profiler/aicore_intermediate_1_type.csv b/tests/ut/data/profiler_data/profiler/aicore_intermediate_1_type.csv old mode 100755 new mode 100644 index 56bf368a6c2..ec710164d19 --- a/tests/ut/data/profiler_data/profiler/aicore_intermediate_1_type.csv +++ b/tests/ut/data/profiler_data/profiler/aicore_intermediate_1_type.csv @@ -1,6 +1,6 @@ -op_type,execution_time,execution_frequency,percent -AtomicAddrClean,0.007283,6,0.49 -Cast,0.053395,13,3.63 -TransData,0.121800,5,8.23 -Conv2D,0.063656,2,4.33 -MatMul,1.085982,9,73.80 +op_type,execution_time,execution_frequency,percent +AtomicAddrClean,0.007283,6,0.49 +Cast,0.053395,13,3.63 +TransData,0.121800,5,8.23 +Conv2D,0.063656,2,4.33 +MatMul,1.085982,9,73.80 diff --git a/tests/ut/data/profiler_data/profiler/framework_raw_0.csv b/tests/ut/data/profiler_data/profiler/framework_raw_0.csv old mode 100755 new mode 100644 index e5286ca5ef9..3cbc1594b13 --- a/tests/ut/data/profiler_data/profiler/framework_raw_0.csv +++ b/tests/ut/data/profiler_data/profiler/framework_raw_0.csv @@ -1,5 +1,5 @@ -task_id,stream_id,block_dim,full_op_name,op_name,op_type,subgraph,op_info -51517,0,32,Default/Cast-op6,Cast-op6,Cast,Default,"{""input_0"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_UINT32"", ""shape"": ""32,3,224,224""}, ""output_0"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_UINT16"", ""shape"": ""32,3,224,224""}}" -51518,0,32,Default/TransData-op7,TransData-op7,TransData,Default,"{""input_0"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_UINT16"", ""shape"": ""32,3,224,224""}, ""output_0"": {""format"": ""NC1HWC0"", ""data_type"": ""NUMBER_TYPE_UINT16"", ""shape"": ""32,1,224,224,16""}}" -51519,0,32,Default/network-WithLossCell/_backbone-ResNet/conv1-Conv2d/Cast-op5,Cast-op5,Cast,Default,"{""input_0"": {""format"": ""FRACTAL_Z"", ""data_type"": ""NUMBER_TYPE_UINT32"", ""shape"": ""49,4,16,16""}, ""output_0"": {""format"": ""FRACTAL_Z"", ""data_type"": ""NUMBER_TYPE_UINT16"", ""shape"": ""49,4,16,16""}}" -51522,0,4,Default/network-WithLossCell/_backbone-ResNet/layer1-SequentialCell/0-ResidualBlock/conv1-Conv2d/Cast-op28,Cast-op28,Cast,Default,"{""input_0"": {""format"": ""FRACTAL_Z"", ""data_type"": ""NUMBER_TYPE_UINT32"", ""shape"": ""4,4,16,16""}, ""output_0"": {""format"": ""FRACTAL_Z"", ""data_type"": ""NUMBER_TYPE_UINT16"", ""shape"": ""4,4,16,16""}}" +task_id,stream_id,block_dim,full_op_name,op_name,op_type,subgraph,op_info +51517,0,32,Default/Cast-op6,Cast-op6,Cast,Default,"{""input_0"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_UINT32"", ""shape"": ""32,3,224,224""}, ""output_0"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_UINT16"", ""shape"": ""32,3,224,224""}}" +51518,0,32,Default/TransData-op7,TransData-op7,TransData,Default,"{""input_0"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_UINT16"", ""shape"": ""32,3,224,224""}, ""output_0"": {""format"": ""NC1HWC0"", ""data_type"": ""NUMBER_TYPE_UINT16"", ""shape"": ""32,1,224,224,16""}}" +51519,0,32,Default/network-WithLossCell/_backbone-ResNet/conv1-Conv2d/Cast-op5,Cast-op5,Cast,Default,"{""input_0"": {""format"": ""FRACTAL_Z"", ""data_type"": ""NUMBER_TYPE_UINT32"", ""shape"": ""49,4,16,16""}, ""output_0"": {""format"": ""FRACTAL_Z"", ""data_type"": ""NUMBER_TYPE_UINT16"", ""shape"": ""49,4,16,16""}}" +51522,0,4,Default/network-WithLossCell/_backbone-ResNet/layer1-SequentialCell/0-ResidualBlock/conv1-Conv2d/Cast-op28,Cast-op28,Cast,Default,"{""input_0"": {""format"": ""FRACTAL_Z"", ""data_type"": ""NUMBER_TYPE_UINT32"", ""shape"": ""4,4,16,16""}, ""output_0"": {""format"": ""FRACTAL_Z"", ""data_type"": ""NUMBER_TYPE_UINT16"", ""shape"": ""4,4,16,16""}}" diff --git a/tests/ut/data/profiler_data/profiler/framework_raw_1.csv b/tests/ut/data/profiler_data/profiler/framework_raw_1.csv old mode 100755 new mode 100644 index ceddd1db5c1..43afd219848 --- a/tests/ut/data/profiler_data/profiler/framework_raw_1.csv +++ b/tests/ut/data/profiler_data/profiler/framework_raw_1.csv @@ -1,11 +1,11 @@ -task_id,stream_id,block_dim,full_op_name,op_name,op_type,subgraph,op_info -30290,0,1,Default/AtomicAddrClean-op104,AtomicAddrClean-op104,AtomicAddrClean,Default,"{""input_0"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_FLOAT32"", ""shape"": """"}}" -30295,0,1,Default/AtomicAddrClean-op105,AtomicAddrClean-op105,AtomicAddrClean,Default,"{""input_0"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_FLOAT32"", ""shape"": ""10""}}" -30300,0,1,Default/AtomicAddrClean-op106,AtomicAddrClean-op106,AtomicAddrClean,Default,"{""input_0"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_FLOAT32"", ""shape"": ""84""}}" -30268,0,32,Default/Cast-op10,Cast-op10,Cast,Default,"{""input_0"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_FLOAT32"", ""shape"": ""32,1,32,32""}, ""output_0"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_FLOAT16"", ""shape"": ""32,1,32,32""}}" -30271,0,9,Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/Cast-op12,Cast-op12,Cast,Default,"{""input_0"": {""format"": ""FRACTAL_Z"", ""data_type"": ""NUMBER_TYPE_FLOAT32"", ""shape"": ""25,1,16,16""}, ""output_0"": {""format"": ""FRACTAL_Z"", ""data_type"": ""NUMBER_TYPE_FLOAT16"", ""shape"": ""25,1,16,16""}}" -30320,0,32,Gradients/Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/gradConv2D/Cast-op53,Cast-op53,Cast,Gradients,"{""input_0"": {""format"": ""NC1HWC0"", ""data_type"": ""NUMBER_TYPE_FLOAT32"", ""shape"": ""32,1,28,28,16""}, ""output_0"": {""format"": ""NC1HWC0"", ""data_type"": ""NUMBER_TYPE_FLOAT16"", ""shape"": ""32,1,28,28,16""}}" -30269,0,32,Default/TransData-op11,TransData-op11,TransData,Default,"{""input_0"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_FLOAT16"", ""shape"": ""32,1,32,32""}, ""output_0"": {""format"": ""NC1HWC0"", ""data_type"": ""NUMBER_TYPE_FLOAT16"", ""shape"": ""32,1,32,32""}}" -30308,0,32,Gradients/Default/network-WithLossCell/_backbone-LeNet5/gradReshape/TransData-op44,TransData-op44,TransData,Gradients,"{""input_0"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_FLOAT16"", ""shape"": ""32,16,5,5""}, ""output_0"": {""format"": ""NC1HWC0"", ""data_type"": ""NUMBER_TYPE_FLOAT16"", ""shape"": ""32,1,5,5,16""}}" -30272,0,32,Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/Conv2D-op13,Conv2D-op13,Conv2D,Default,"{""input_0"": {""format"": ""NC1HWC0"", ""data_type"": ""NUMBER_TYPE_FLOAT16"", ""shape"": ""32,1,32,32,16""}, ""input_1"": {""format"": ""FRACTAL_Z"", ""data_type"": ""NUMBER_TYPE_FLOAT16"", ""shape"": ""25,1,16,16""}, ""output_0"": {""format"": ""NC1HWC0"", ""data_type"": ""NUMBER_TYPE_FLOAT16"", ""shape"": ""32,1,28,28,16""}}" -30286,0,1,Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/MatMul-op9,MatMul-op9,MatMul,Default,"{""input_0"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_FLOAT32"", ""shape"": ""32,120""}, ""input_1"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_FLOAT32"", ""shape"": ""84,120""}, ""input_2"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_FLOAT32"", ""shape"": ""84""}, ""output_0"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_FLOAT32"", ""shape"": ""32,84""}}" +task_id,stream_id,block_dim,full_op_name,op_name,op_type,subgraph,op_info +30290,0,1,Default/AtomicAddrClean-op104,AtomicAddrClean-op104,AtomicAddrClean,Default,"{""input_0"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_FLOAT32"", ""shape"": """"}}" +30295,0,1,Default/AtomicAddrClean-op105,AtomicAddrClean-op105,AtomicAddrClean,Default,"{""input_0"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_FLOAT32"", ""shape"": ""10""}}" +30300,0,1,Default/AtomicAddrClean-op106,AtomicAddrClean-op106,AtomicAddrClean,Default,"{""input_0"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_FLOAT32"", ""shape"": ""84""}}" +30268,0,32,Default/Cast-op10,Cast-op10,Cast,Default,"{""input_0"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_FLOAT32"", ""shape"": ""32,1,32,32""}, ""output_0"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_FLOAT16"", ""shape"": ""32,1,32,32""}}" +30271,0,9,Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/Cast-op12,Cast-op12,Cast,Default,"{""input_0"": {""format"": ""FRACTAL_Z"", ""data_type"": ""NUMBER_TYPE_FLOAT32"", ""shape"": ""25,1,16,16""}, ""output_0"": {""format"": ""FRACTAL_Z"", ""data_type"": ""NUMBER_TYPE_FLOAT16"", ""shape"": ""25,1,16,16""}}" +30320,0,32,Gradients/Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/gradConv2D/Cast-op53,Cast-op53,Cast,Gradients,"{""input_0"": {""format"": ""NC1HWC0"", ""data_type"": ""NUMBER_TYPE_FLOAT32"", ""shape"": ""32,1,28,28,16""}, ""output_0"": {""format"": ""NC1HWC0"", ""data_type"": ""NUMBER_TYPE_FLOAT16"", ""shape"": ""32,1,28,28,16""}}" +30269,0,32,Default/TransData-op11,TransData-op11,TransData,Default,"{""input_0"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_FLOAT16"", ""shape"": ""32,1,32,32""}, ""output_0"": {""format"": ""NC1HWC0"", ""data_type"": ""NUMBER_TYPE_FLOAT16"", ""shape"": ""32,1,32,32""}}" +30308,0,32,Gradients/Default/network-WithLossCell/_backbone-LeNet5/gradReshape/TransData-op44,TransData-op44,TransData,Gradients,"{""input_0"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_FLOAT16"", ""shape"": ""32,16,5,5""}, ""output_0"": {""format"": ""NC1HWC0"", ""data_type"": ""NUMBER_TYPE_FLOAT16"", ""shape"": ""32,1,5,5,16""}}" +30272,0,32,Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/Conv2D-op13,Conv2D-op13,Conv2D,Default,"{""input_0"": {""format"": ""NC1HWC0"", ""data_type"": ""NUMBER_TYPE_FLOAT16"", ""shape"": ""32,1,32,32,16""}, ""input_1"": {""format"": ""FRACTAL_Z"", ""data_type"": ""NUMBER_TYPE_FLOAT16"", ""shape"": ""25,1,16,16""}, ""output_0"": {""format"": ""NC1HWC0"", ""data_type"": ""NUMBER_TYPE_FLOAT16"", ""shape"": ""32,1,28,28,16""}}" +30286,0,1,Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/MatMul-op9,MatMul-op9,MatMul,Default,"{""input_0"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_FLOAT32"", ""shape"": ""32,120""}, ""input_1"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_FLOAT32"", ""shape"": ""84,120""}, ""input_2"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_FLOAT32"", ""shape"": ""84""}, ""output_0"": {""format"": ""DefaultFormat"", ""data_type"": ""NUMBER_TYPE_FLOAT32"", ""shape"": ""32,84""}}" diff --git a/tests/ut/data/profiler_data/profiler/hccl_raw_6.csv b/tests/ut/data/profiler_data/profiler/hccl_raw_6.csv index 80e164ed0e2..d0f7852c2b0 100644 --- a/tests/ut/data/profiler_data/profiler/hccl_raw_6.csv +++ b/tests/ut/data/profiler_data/profiler/hccl_raw_6.csv @@ -1,4 +1,4 @@ -step_num,communication_cost,wait_cost,link_info,communication_operator_cost -1,4.51637,4e-05,"{""1-0"": {""SDMA"": [4.36566, 94356.992, 21613454.09399724]}, ""0-0"": {""SDMA"": [0.15071, 47178.496, 313041576.53772146]}}","{""stream_24_0_AllReduce-op3143"": [""1"", 4.51637, 4e-05, {""1-0"": {""SDMA"": [4.36566, 94356.992, 21613454.09399724]}, ""0-0"": {""SDMA"": [0.15071, 47178.496, 313041576.53772146]}}]}" -2,4.519260000000001,7.000000000000001e-05,"{""1-0"": {""SDMA"": [4.368450000000001, 94356.992, 21599650.21918529]}, ""0-0"": {""SDMA"": [0.15081, 47178.496, 312834003.0501956]}}","{""stream_24_0_AllReduce-op3143"": [""2"", 4.519260000000001, 7.000000000000001e-05, {""1-0"": {""SDMA"": [4.368450000000001, 94356.992, 21599650.21918529]}, ""0-0"": {""SDMA"": [0.15081, 47178.496, 312834003.0501956]}}]}" --,4.517815000000001,5.500000000000001e-05,"{""1-0"": {""SDMA"": [4.367055000000001, 94356.992, 21606552.156591266]}, ""0-0"": {""SDMA"": [0.15076, 47178.496, 312937789.79395854]}}","{""stream_24_0_AllReduce-op3143"": [""-"", 4.517815000000001, 5.500000000000001e-05, {""1-0"": {""SDMA"": [4.367055000000001, 94356.992, 21606552.156591266]}, ""0-0"": {""SDMA"": [0.15076, 47178.496, 312937789.79395854]}}]}" +step_num,communication_cost,wait_cost,link_info,communication_operator_cost +1,4.51637,4e-05,"{""1-0"": {""SDMA"": [4.36566, 94356.992, 21613454.09399724]}, ""0-0"": {""SDMA"": [0.15071, 47178.496, 313041576.53772146]}}","{""stream_24_0_AllReduce-op3143"": [""1"", 4.51637, 4e-05, {""1-0"": {""SDMA"": [4.36566, 94356.992, 21613454.09399724]}, ""0-0"": {""SDMA"": [0.15071, 47178.496, 313041576.53772146]}}]}" +2,4.519260000000001,7.000000000000001e-05,"{""1-0"": {""SDMA"": [4.368450000000001, 94356.992, 21599650.21918529]}, ""0-0"": {""SDMA"": [0.15081, 47178.496, 312834003.0501956]}}","{""stream_24_0_AllReduce-op3143"": [""2"", 4.519260000000001, 7.000000000000001e-05, {""1-0"": {""SDMA"": [4.368450000000001, 94356.992, 21599650.21918529]}, ""0-0"": {""SDMA"": [0.15081, 47178.496, 312834003.0501956]}}]}" +-,4.517815000000001,5.500000000000001e-05,"{""1-0"": {""SDMA"": [4.367055000000001, 94356.992, 21606552.156591266]}, ""0-0"": {""SDMA"": [0.15076, 47178.496, 312937789.79395854]}}","{""stream_24_0_AllReduce-op3143"": [""-"", 4.517815000000001, 5.500000000000001e-05, {""1-0"": {""SDMA"": [4.367055000000001, 94356.992, 21606552.156591266]}, ""0-0"": {""SDMA"": [0.15076, 47178.496, 312937789.79395854]}}]}" diff --git a/tests/ut/data/profiler_data/profiler/minddata_pipeline_raw_0.csv b/tests/ut/data/profiler_data/profiler/minddata_pipeline_raw_0.csv index 57f8ddaf687..7e8ff097ad2 100644 --- a/tests/ut/data/profiler_data/profiler/minddata_pipeline_raw_0.csv +++ b/tests/ut/data/profiler_data/profiler/minddata_pipeline_raw_0.csv @@ -1,5 +1,5 @@ -op_id,op_type,num_workers,output_queue_size,output_queue_average_size,output_queue_length,output_queue_usage_rate,sample_interval,parent_id,children_id -0,Batch,4,,,,,10,,[1] -1,Shuffle,1,"[10, 20, 30]",20.0,64,0.3125,10,0,"[2, 3]" -2,TFReader,4,"[10, 20, 30]",20.0,64,0.3125,10,1, -3,TFReader,4,"[10, 20, 30]",20.0,64,0.3125,10,1, +op_id,op_type,num_workers,output_queue_size,output_queue_average_size,output_queue_length,output_queue_usage_rate,sample_interval,parent_id,children_id +0,Batch,4,,,,,10,,[1] +1,Shuffle,1,"[10, 20, 30]",20.0,64,0.3125,10,0,"[2, 3]" +2,TFReader,4,"[10, 20, 30]",20.0,64,0.3125,10,1, +3,TFReader,4,"[10, 20, 30]",20.0,64,0.3125,10,1, diff --git a/tests/ut/data/profiler_data/profiler/step_trace_raw_0_detail_time.csv b/tests/ut/data/profiler_data/profiler/step_trace_raw_0_detail_time.csv old mode 100755 new mode 100644 diff --git a/tests/ut/data/profiler_data/profiler/step_trace_raw_10_detail_time.csv b/tests/ut/data/profiler_data/profiler/step_trace_raw_10_detail_time.csv old mode 100755 new mode 100644 index 9e97e499b24..ce765fc3864 --- a/tests/ut/data/profiler_data/profiler/step_trace_raw_10_detail_time.csv +++ b/tests/ut/data/profiler_data/profiler/step_trace_raw_10_detail_time.csv @@ -1,42 +1,42 @@ -step_num,start_point,end_point,total,fp_point,bp_point,iteration_interval,fp_and_bp,tail,stream_10_parallel_0_start_point,stream_10_parallel_0_end_point,stream_10_parallel_0,stream_10_parallel_1_start_point,stream_10_parallel_1_end_point,stream_10_parallel_1,stream_10_parallel_2_start_point,stream_10_parallel_2_end_point,stream_10_parallel_2,stream_11_parallel_0_start_point,stream_11_parallel_0_end_point,stream_11_parallel_0 -1,45000025226,45004034753,4009527,45000025226,45001734362,0,1709136,2300391,45000044023,45000060886,16863,45001043581,45001343373,299792,45002254048,45002452830,198782,45000043807,45000065736,21929 -2,45004034753,45017091420,13056667,45013073790,45014789509,9039037,1715719,2301911,45013085205,45013104210,19005,45014086339,45014393261,306922,45015299546,45015501808,202262,45013085040,45013119810,34770 -3,45017091420,45030144372,13052952,45026123867,45027843651,9032447,1719784,2300721,45026138546,45026154524,15978,45027135742,45027437486,301744,45028363120,45028560901,197781,45026136046,45026171363,35317 -4,45030144372,45043184486,13040114,45039173149,45040883087,9028777,1709938,2301399,45039190927,45039209948,19021,45040185915,45040484897,298982,45041399754,45041594775,195021,45039192768,45039221423,28655 -5,45043184486,45056241064,13056578,45052223555,45053940709,9039069,1717154,2300355,45052241736,45052262186,20450,45053239605,45053540866,301261,45054452604,45054654505,201901,45052233932,45052265774,31842 -6,45056241064,45069291346,13050282,45065278144,45066991121,9037080,1712977,2300225,45065293660,45065316136,22476,45066289480,45066589910,300430,45067511002,45067701731,190729,45065293679,45065321296,27617 -7,45069291346,45082344927,13053581,45078335376,45080043268,9044030,1707892,2301659,45078353164,45078365382,12218,45079354748,45079648384,293636,45080557453,45080760374,202921,45078353030,45078384530,31500 -8,45082344927,45095382554,13037627,45091368697,45093080797,9023770,1712100,2301757,45091381244,45091405208,23964,45092382630,45092684285,301655,45093590961,45093796698,205737,45091381199,45091413840,32641 -9,45095382554,45108433947,13051393,45104419947,45106132133,9037393,1712186,2301814,45104432587,45104457476,24889,45105431458,45105735476,304018,45106651213,45106845305,194092,45104435207,45104466677,31470 -10,45108433947,45121486591,13052644,45117469353,45119185969,9035406,1716616,2300622,45117483627,45117504869,21242,45118483411,45118788540,305129,45119696660,45119898575,201915,45117485587,45117510985,25398 -11,45121486591,45134546571,13059980,45130528618,45132244809,9042027,1716191,2301762,45130539730,45130561122,21392,45131538695,45131846715,308020,45132759789,45132960848,201059,45130545378,45130569412,24034 -12,45134546571,45147608222,13061651,45143597023,45145307273,9050452,1710250,2300949,45143615771,45143631460,15689,45144610592,45144910736,300144,45145818642,45146024326,205684,45143613528,45143640223,26695 -13,45147608222,45160663790,13055568,45156648696,45158362923,9040474,1714227,2300867,45156663193,45156685466,22273,45157661576,45157963074,301498,45158881212,45159074431,193219,45156667038,45156694912,27874 -14,45160663790,45173707626,13043836,45169694535,45171407246,9030745,1712711,2300380,45169710667,45169727936,17269,45170705802,45171013806,308004,45171924100,45172120273,196173,45169708524,45169739038,30514 -15,45173707626,45186754860,13047234,45182750254,45184454036,9042628,1703782,2300824,45182765445,45182789799,24354,45183761335,45184065169,303834,45184973312,45185170444,197132,45182769451,45182799598,30147 -16,45186754860,45199798718,13043858,45195792271,45197497908,9037411,1705637,2300810,45195804771,45195827915,23144,45196804016,45197108243,304227,45198013357,45198209858,196501,45195806656,45195841674,35018 -17,45199798718,45212854993,13056275,45208834355,45210553378,9035637,1719023,2301615,45208850179,45208865588,15409,45209851018,45210151436,300418,45211073169,45211271792,198623,45208847052,45208876998,29946 -18,45212854993,45225893712,13038719,45221888939,45223593704,9033946,1704765,2300008,45221901732,45221924983,23251,45222908795,45223203590,294795,45224105803,45224313354,207551,45221899792,45221938802,39010 -19,45225893712,45238941242,13047530,45234926295,45236640454,9032583,1714159,2300788,45234938628,45234957237,18609,45235942710,45236239983,297273,45237159532,45237356140,196608,45234938330,45234976170,37840 -20,45238941242,45251979177,13037935,45247977674,45249678116,9036432,1700442,2301061,45247990919,45248013476,22557,45248991451,45249294742,303291,45250195733,45250395760,200027,45247988950,45248024969,36019 -21,45251979177,45265018752,13039575,45261005416,45262718472,9026239,1713056,2300280,0,0,0,0,0,0,0,0,0,0,0,0 -22,45265018752,45278062782,13044030,45274047185,45275762095,9028433,1714910,2300687,0,0,0,0,0,0,0,0,0,0,0,0 -23,45278062782,45291105708,13042926,45287094000,45288805223,9031218,1711223,2300485,0,0,0,0,0,0,0,0,0,0,0,0 -24,45291105708,45304155918,13050210,45300150844,45301854040,9045136,1703196,2301878,0,0,0,0,0,0,0,0,0,0,0,0 -25,45304155918,45317206695,13050777,45313191948,45314905714,9036030,1713766,2300981,0,0,0,0,0,0,0,0,0,0,0,0 -26,45317206695,45330265105,13058410,45326256021,45327964581,9049326,1708560,2300524,0,0,0,0,0,0,0,0,0,0,0,0 -27,45330265105,45343324012,13058907,45339305124,45341023739,9040019,1718615,2300273,0,0,0,0,0,0,0,0,0,0,0,0 -28,45343324012,45356374571,13050559,45352366211,45354073401,9042199,1707190,2301170,0,0,0,0,0,0,0,0,0,0,0,0 -29,45356374571,45369429514,13054943,45365417827,45367128283,9043256,1710456,2301231,0,0,0,0,0,0,0,0,0,0,0,0 -30,45369429514,45382479199,13049685,45378476397,45380177297,9046883,1700900,2301902,0,0,0,0,0,0,0,0,0,0,0,0 -31,45382479199,45395530376,13051177,45391510137,45393229377,9030938,1719240,2300999,0,0,0,0,0,0,0,0,0,0,0,0 -32,45395530376,45408571765,13041389,45404559082,45406270720,9028706,1711638,2301045,0,0,0,0,0,0,0,0,0,0,0,0 -33,45408571765,45421635175,13063410,45417619223,45419334221,9047458,1714998,2300954,0,0,0,0,0,0,0,0,0,0,0,0 -34,45421635175,45434672219,13037044,45430669445,45432371312,9034270,1701867,2300907,0,0,0,0,0,0,0,0,0,0,0,0 -35,45434672219,45447714036,13041817,45443704548,45445413852,9032329,1709304,2300184,0,0,0,0,0,0,0,0,0,0,0,0 -36,45447714036,45460765153,13051117,45456753675,45458463701,9039639,1710026,2301452,0,0,0,0,0,0,0,0,0,0,0,0 -37,45460765153,45473829105,13063952,45469808281,45471527400,9043128,1719119,2301705,0,0,0,0,0,0,0,0,0,0,0,0 -38,45473829105,45486884190,13055085,45482867237,45484583534,9038132,1716297,2300656,0,0,0,0,0,0,0,0,0,0,0,0 -39,45486884190,45499928571,13044381,45495917628,45497627921,9033438,1710293,2300650,0,0,0,0,0,0,0,0,0,0,0,0 -40,45499928571,45512973815,13045244,45508968990,45510673699,9040419,1704709,2300116,0,0,0,0,0,0,0,0,0,0,0,0 --,45251983006,45265032725,13049720,45261020353,45262731761,9037347,1711408,2300964,21986676455,21986686280,9825,21987163213,21987310272,147058,21987754537,21987851587,97050,21986676441,21986691731,15290 +step_num,start_point,end_point,total,fp_point,bp_point,iteration_interval,fp_and_bp,tail,stream_10_parallel_0_start_point,stream_10_parallel_0_end_point,stream_10_parallel_0,stream_10_parallel_1_start_point,stream_10_parallel_1_end_point,stream_10_parallel_1,stream_10_parallel_2_start_point,stream_10_parallel_2_end_point,stream_10_parallel_2,stream_11_parallel_0_start_point,stream_11_parallel_0_end_point,stream_11_parallel_0 +1,45000025226,45004034753,4009527,45000025226,45001734362,0,1709136,2300391,45000044023,45000060886,16863,45001043581,45001343373,299792,45002254048,45002452830,198782,45000043807,45000065736,21929 +2,45004034753,45017091420,13056667,45013073790,45014789509,9039037,1715719,2301911,45013085205,45013104210,19005,45014086339,45014393261,306922,45015299546,45015501808,202262,45013085040,45013119810,34770 +3,45017091420,45030144372,13052952,45026123867,45027843651,9032447,1719784,2300721,45026138546,45026154524,15978,45027135742,45027437486,301744,45028363120,45028560901,197781,45026136046,45026171363,35317 +4,45030144372,45043184486,13040114,45039173149,45040883087,9028777,1709938,2301399,45039190927,45039209948,19021,45040185915,45040484897,298982,45041399754,45041594775,195021,45039192768,45039221423,28655 +5,45043184486,45056241064,13056578,45052223555,45053940709,9039069,1717154,2300355,45052241736,45052262186,20450,45053239605,45053540866,301261,45054452604,45054654505,201901,45052233932,45052265774,31842 +6,45056241064,45069291346,13050282,45065278144,45066991121,9037080,1712977,2300225,45065293660,45065316136,22476,45066289480,45066589910,300430,45067511002,45067701731,190729,45065293679,45065321296,27617 +7,45069291346,45082344927,13053581,45078335376,45080043268,9044030,1707892,2301659,45078353164,45078365382,12218,45079354748,45079648384,293636,45080557453,45080760374,202921,45078353030,45078384530,31500 +8,45082344927,45095382554,13037627,45091368697,45093080797,9023770,1712100,2301757,45091381244,45091405208,23964,45092382630,45092684285,301655,45093590961,45093796698,205737,45091381199,45091413840,32641 +9,45095382554,45108433947,13051393,45104419947,45106132133,9037393,1712186,2301814,45104432587,45104457476,24889,45105431458,45105735476,304018,45106651213,45106845305,194092,45104435207,45104466677,31470 +10,45108433947,45121486591,13052644,45117469353,45119185969,9035406,1716616,2300622,45117483627,45117504869,21242,45118483411,45118788540,305129,45119696660,45119898575,201915,45117485587,45117510985,25398 +11,45121486591,45134546571,13059980,45130528618,45132244809,9042027,1716191,2301762,45130539730,45130561122,21392,45131538695,45131846715,308020,45132759789,45132960848,201059,45130545378,45130569412,24034 +12,45134546571,45147608222,13061651,45143597023,45145307273,9050452,1710250,2300949,45143615771,45143631460,15689,45144610592,45144910736,300144,45145818642,45146024326,205684,45143613528,45143640223,26695 +13,45147608222,45160663790,13055568,45156648696,45158362923,9040474,1714227,2300867,45156663193,45156685466,22273,45157661576,45157963074,301498,45158881212,45159074431,193219,45156667038,45156694912,27874 +14,45160663790,45173707626,13043836,45169694535,45171407246,9030745,1712711,2300380,45169710667,45169727936,17269,45170705802,45171013806,308004,45171924100,45172120273,196173,45169708524,45169739038,30514 +15,45173707626,45186754860,13047234,45182750254,45184454036,9042628,1703782,2300824,45182765445,45182789799,24354,45183761335,45184065169,303834,45184973312,45185170444,197132,45182769451,45182799598,30147 +16,45186754860,45199798718,13043858,45195792271,45197497908,9037411,1705637,2300810,45195804771,45195827915,23144,45196804016,45197108243,304227,45198013357,45198209858,196501,45195806656,45195841674,35018 +17,45199798718,45212854993,13056275,45208834355,45210553378,9035637,1719023,2301615,45208850179,45208865588,15409,45209851018,45210151436,300418,45211073169,45211271792,198623,45208847052,45208876998,29946 +18,45212854993,45225893712,13038719,45221888939,45223593704,9033946,1704765,2300008,45221901732,45221924983,23251,45222908795,45223203590,294795,45224105803,45224313354,207551,45221899792,45221938802,39010 +19,45225893712,45238941242,13047530,45234926295,45236640454,9032583,1714159,2300788,45234938628,45234957237,18609,45235942710,45236239983,297273,45237159532,45237356140,196608,45234938330,45234976170,37840 +20,45238941242,45251979177,13037935,45247977674,45249678116,9036432,1700442,2301061,45247990919,45248013476,22557,45248991451,45249294742,303291,45250195733,45250395760,200027,45247988950,45248024969,36019 +21,45251979177,45265018752,13039575,45261005416,45262718472,9026239,1713056,2300280,0,0,0,0,0,0,0,0,0,0,0,0 +22,45265018752,45278062782,13044030,45274047185,45275762095,9028433,1714910,2300687,0,0,0,0,0,0,0,0,0,0,0,0 +23,45278062782,45291105708,13042926,45287094000,45288805223,9031218,1711223,2300485,0,0,0,0,0,0,0,0,0,0,0,0 +24,45291105708,45304155918,13050210,45300150844,45301854040,9045136,1703196,2301878,0,0,0,0,0,0,0,0,0,0,0,0 +25,45304155918,45317206695,13050777,45313191948,45314905714,9036030,1713766,2300981,0,0,0,0,0,0,0,0,0,0,0,0 +26,45317206695,45330265105,13058410,45326256021,45327964581,9049326,1708560,2300524,0,0,0,0,0,0,0,0,0,0,0,0 +27,45330265105,45343324012,13058907,45339305124,45341023739,9040019,1718615,2300273,0,0,0,0,0,0,0,0,0,0,0,0 +28,45343324012,45356374571,13050559,45352366211,45354073401,9042199,1707190,2301170,0,0,0,0,0,0,0,0,0,0,0,0 +29,45356374571,45369429514,13054943,45365417827,45367128283,9043256,1710456,2301231,0,0,0,0,0,0,0,0,0,0,0,0 +30,45369429514,45382479199,13049685,45378476397,45380177297,9046883,1700900,2301902,0,0,0,0,0,0,0,0,0,0,0,0 +31,45382479199,45395530376,13051177,45391510137,45393229377,9030938,1719240,2300999,0,0,0,0,0,0,0,0,0,0,0,0 +32,45395530376,45408571765,13041389,45404559082,45406270720,9028706,1711638,2301045,0,0,0,0,0,0,0,0,0,0,0,0 +33,45408571765,45421635175,13063410,45417619223,45419334221,9047458,1714998,2300954,0,0,0,0,0,0,0,0,0,0,0,0 +34,45421635175,45434672219,13037044,45430669445,45432371312,9034270,1701867,2300907,0,0,0,0,0,0,0,0,0,0,0,0 +35,45434672219,45447714036,13041817,45443704548,45445413852,9032329,1709304,2300184,0,0,0,0,0,0,0,0,0,0,0,0 +36,45447714036,45460765153,13051117,45456753675,45458463701,9039639,1710026,2301452,0,0,0,0,0,0,0,0,0,0,0,0 +37,45460765153,45473829105,13063952,45469808281,45471527400,9043128,1719119,2301705,0,0,0,0,0,0,0,0,0,0,0,0 +38,45473829105,45486884190,13055085,45482867237,45484583534,9038132,1716297,2300656,0,0,0,0,0,0,0,0,0,0,0,0 +39,45486884190,45499928571,13044381,45495917628,45497627921,9033438,1710293,2300650,0,0,0,0,0,0,0,0,0,0,0,0 +40,45499928571,45512973815,13045244,45508968990,45510673699,9040419,1704709,2300116,0,0,0,0,0,0,0,0,0,0,0,0 +-,45251983006,45265032725,13049720,45261020353,45262731761,9037347,1711408,2300964,21986676455,21986686280,9825,21987163213,21987310272,147058,21987754537,21987851587,97050,21986676441,21986691731,15290 diff --git a/tests/ut/data/profiler_data/profiler/step_trace_raw_6_detail_time.csv b/tests/ut/data/profiler_data/profiler/step_trace_raw_6_detail_time.csv index 0aadd2509f8..a747c191981 100644 --- a/tests/ut/data/profiler_data/profiler/step_trace_raw_6_detail_time.csv +++ b/tests/ut/data/profiler_data/profiler/step_trace_raw_6_detail_time.csv @@ -1,4 +1,4 @@ -step_num,start_point,end_point,total,fp_point,bp_point,iteration_interval,fp_and_bp,tail,stream_24_0_AllReduce-op3143,stream_24_0_AllReduce-op3143_start_point,stream_24_0_AllReduce-op3143_end_point -1,61688106196804,61688109551051,3354247,61688106196804,61688109001447,0,2804643,549604,454080,61688109006157,61688109460237 -2,61688109551051,61688112916194,3365143,61688109556119,61688112367356,5068,2811237,548838,454080,61688112371861,61688112825941 --,61688109551051,61688112916194,3365143,61688109556119,61688112367356,5068,2811237,548838,454080,61688112371861,61688112825941 +step_num,start_point,end_point,total,fp_point,bp_point,iteration_interval,fp_and_bp,tail,stream_24_0_AllReduce-op3143,stream_24_0_AllReduce-op3143_start_point,stream_24_0_AllReduce-op3143_end_point +1,61688106196804,61688109551051,3354247,61688106196804,61688109001447,0,2804643,549604,454080,61688109006157,61688109460237 +2,61688109551051,61688112916194,3365143,61688109556119,61688112367356,5068,2811237,548838,454080,61688112371861,61688112825941 +-,61688109551051,61688112916194,3365143,61688109556119,61688112367356,5068,2811237,548838,454080,61688112371861,61688112825941 diff --git a/tests/ut/python/cachetests/cachetest.sh b/tests/ut/python/cachetests/cachetest.sh old mode 100755 new mode 100644 diff --git a/tests/ut/python/cachetests/cachetest_args.sh b/tests/ut/python/cachetests/cachetest_args.sh old mode 100755 new mode 100644 diff --git a/tests/ut/python/cachetests/cachetest_cpp.sh b/tests/ut/python/cachetests/cachetest_cpp.sh old mode 100755 new mode 100644 diff --git a/tests/ut/python/cachetests/cachetest_lib.sh b/tests/ut/python/cachetests/cachetest_lib.sh old mode 100755 new mode 100644 diff --git a/tests/ut/python/cachetests/cachetest_py.sh b/tests/ut/python/cachetests/cachetest_py.sh old mode 100755 new mode 100644 diff --git a/tests/ut/python/dataset/test_adjust_brightness.py b/tests/ut/python/dataset/test_adjust_brightness.py index 8c2e65ae23b..22f7758820e 100644 --- a/tests/ut/python/dataset/test_adjust_brightness.py +++ b/tests/ut/python/dataset/test_adjust_brightness.py @@ -1,147 +1,147 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -""" -Testing AdjustBrightness op in DE -""" -import numpy as np -from numpy.testing import assert_allclose - -import mindspore.dataset as ds -import mindspore.dataset.transforms.transforms -import mindspore.dataset.vision as vision -from mindspore import log as logger -from util import diff_mse - -DATA_DIR = "../data/dataset/testImageNetData/train/" -MNIST_DATA_DIR = "../data/dataset/testMnistData" - -DATA_DIR_2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] -SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" - - -def generate_numpy_random_rgb(shape): - """ - Only generate floating points that are fractions like n / 256, since they - are RGB pixels. Some low-precision floating point types in this test can't - handle arbitrary precision floating points well. - """ - return np.random.randint(0, 256, shape) / 255. - - -def test_adjust_brightness_eager(plot=False): - """ - Feature: AdjustBrightness op - Description: Test AdjustBrightness in eager mode - Expectation: Output is the same as expected output - """ - # Eager 3-channel - image_file = "../data/dataset/testImageNetData/train/class1/1_1.jpg" - img = np.fromfile(image_file, dtype=np.uint8) - logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape)) - - img = vision.Decode()(img) - img_adjustbrightness = vision.AdjustBrightness(1)(img) - if plot: - visualize_image(img, img_adjustbrightness) - logger.info("Image.type: {}, Image.shape: {}".format(type(img_adjustbrightness), - img_adjustbrightness.shape)) - mse = diff_mse(img_adjustbrightness, img) - logger.info("MSE= {}".format(str(mse))) - assert mse == 0 - - -def test_adjust_brightness_invalid_brightness_factor_param(): - """ - Feature: AdjustBrightness op - Description: Test improper parameters for AdjustBrightness implementation - Expectation: Throw ValueError exception and TypeError exception - """ - logger.info("Test AdjustBrightness implementation with invalid ignore parameter") - try: - data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False) - trans = mindspore.dataset.transforms.transforms.Compose([ - vision.Decode(True), - vision.Resize((224, 224)), - vision.AdjustBrightness(brightness_factor=-10.0), - vision.ToTensor() - ]) - data_set = data_set.map(operations=[trans], input_columns=["image"]) - except ValueError as error: - logger.info("Got an exception in AdjustBrightness: {}".format(str(error))) - assert "Input brightness_factor is not within the required interval of " in str(error) - try: - data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False) - trans = mindspore.dataset.transforms.transforms.Compose([ - vision.Decode(True), - vision.Resize((224, 224)), - vision.AdjustBrightness(brightness_factor=[1, 2]), - vision.ToTensor() - ]) - data_set = data_set.map(operations=[trans], input_columns=["image"]) - except TypeError as error: - logger.info("Got an exception in AdjustBrightness: {}".format(str(error))) - assert "is not of type [, ], but got" in str(error) - - -def test_adjust_brightness_pipeline(): - """ - Feature: AdjustBrightness op - Description: Test AdjustBrightness in pipeline mode - Expectation: Output is the same as expected output - """ - # First dataset - transforms1 = [vision.Decode(True), vision.Resize([64, 64]), vision.ToTensor()] - transforms1 = mindspore.dataset.transforms.transforms.Compose( - transforms1) - ds1 = ds.TFRecordDataset(DATA_DIR_2, - SCHEMA_DIR, - columns_list=["image"], - shuffle=False) - ds1 = ds1.map(operations=transforms1, input_columns=["image"]) - - # Second dataset - transforms2 = [ - vision.Decode(True), - vision.Resize([64, 64]), - vision.AdjustBrightness(1.0), - vision.ToTensor() - ] - transform2 = mindspore.dataset.transforms.transforms.Compose( - transforms2) - ds2 = ds.TFRecordDataset(DATA_DIR_2, - SCHEMA_DIR, - columns_list=["image"], - shuffle=False) - ds2 = ds2.map(operations=transform2, input_columns=["image"]) - - num_iter = 0 - for data1, data2 in zip(ds1.create_dict_iterator(num_epochs=1), - ds2.create_dict_iterator(num_epochs=1)): - num_iter += 1 - ori_img = data1["image"].asnumpy() - cvt_img = data2["image"].asnumpy() - assert_allclose(ori_img.flatten(), - cvt_img.flatten(), - rtol=1e-5, - atol=0) - mse = diff_mse(ori_img, cvt_img) - logger.info("MSE= {}".format(str(mse))) - assert mse == 0 - - -if __name__ == "__main__": - test_adjust_brightness_eager() - test_adjust_brightness_invalid_brightness_factor_param() - test_adjust_brightness_pipeline() +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Testing AdjustBrightness op in DE +""" +import numpy as np +from numpy.testing import assert_allclose + +import mindspore.dataset as ds +import mindspore.dataset.transforms.transforms +import mindspore.dataset.vision as vision +from mindspore import log as logger +from util import diff_mse + +DATA_DIR = "../data/dataset/testImageNetData/train/" +MNIST_DATA_DIR = "../data/dataset/testMnistData" + +DATA_DIR_2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] +SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" + + +def generate_numpy_random_rgb(shape): + """ + Only generate floating points that are fractions like n / 256, since they + are RGB pixels. Some low-precision floating point types in this test can't + handle arbitrary precision floating points well. + """ + return np.random.randint(0, 256, shape) / 255. + + +def test_adjust_brightness_eager(plot=False): + """ + Feature: AdjustBrightness op + Description: Test AdjustBrightness in eager mode + Expectation: Output is the same as expected output + """ + # Eager 3-channel + image_file = "../data/dataset/testImageNetData/train/class1/1_1.jpg" + img = np.fromfile(image_file, dtype=np.uint8) + logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape)) + + img = vision.Decode()(img) + img_adjustbrightness = vision.AdjustBrightness(1)(img) + if plot: + visualize_image(img, img_adjustbrightness) + logger.info("Image.type: {}, Image.shape: {}".format(type(img_adjustbrightness), + img_adjustbrightness.shape)) + mse = diff_mse(img_adjustbrightness, img) + logger.info("MSE= {}".format(str(mse))) + assert mse == 0 + + +def test_adjust_brightness_invalid_brightness_factor_param(): + """ + Feature: AdjustBrightness op + Description: Test improper parameters for AdjustBrightness implementation + Expectation: Throw ValueError exception and TypeError exception + """ + logger.info("Test AdjustBrightness implementation with invalid ignore parameter") + try: + data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False) + trans = mindspore.dataset.transforms.transforms.Compose([ + vision.Decode(True), + vision.Resize((224, 224)), + vision.AdjustBrightness(brightness_factor=-10.0), + vision.ToTensor() + ]) + data_set = data_set.map(operations=[trans], input_columns=["image"]) + except ValueError as error: + logger.info("Got an exception in AdjustBrightness: {}".format(str(error))) + assert "Input brightness_factor is not within the required interval of " in str(error) + try: + data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False) + trans = mindspore.dataset.transforms.transforms.Compose([ + vision.Decode(True), + vision.Resize((224, 224)), + vision.AdjustBrightness(brightness_factor=[1, 2]), + vision.ToTensor() + ]) + data_set = data_set.map(operations=[trans], input_columns=["image"]) + except TypeError as error: + logger.info("Got an exception in AdjustBrightness: {}".format(str(error))) + assert "is not of type [, ], but got" in str(error) + + +def test_adjust_brightness_pipeline(): + """ + Feature: AdjustBrightness op + Description: Test AdjustBrightness in pipeline mode + Expectation: Output is the same as expected output + """ + # First dataset + transforms1 = [vision.Decode(True), vision.Resize([64, 64]), vision.ToTensor()] + transforms1 = mindspore.dataset.transforms.transforms.Compose( + transforms1) + ds1 = ds.TFRecordDataset(DATA_DIR_2, + SCHEMA_DIR, + columns_list=["image"], + shuffle=False) + ds1 = ds1.map(operations=transforms1, input_columns=["image"]) + + # Second dataset + transforms2 = [ + vision.Decode(True), + vision.Resize([64, 64]), + vision.AdjustBrightness(1.0), + vision.ToTensor() + ] + transform2 = mindspore.dataset.transforms.transforms.Compose( + transforms2) + ds2 = ds.TFRecordDataset(DATA_DIR_2, + SCHEMA_DIR, + columns_list=["image"], + shuffle=False) + ds2 = ds2.map(operations=transform2, input_columns=["image"]) + + num_iter = 0 + for data1, data2 in zip(ds1.create_dict_iterator(num_epochs=1), + ds2.create_dict_iterator(num_epochs=1)): + num_iter += 1 + ori_img = data1["image"].asnumpy() + cvt_img = data2["image"].asnumpy() + assert_allclose(ori_img.flatten(), + cvt_img.flatten(), + rtol=1e-5, + atol=0) + mse = diff_mse(ori_img, cvt_img) + logger.info("MSE= {}".format(str(mse))) + assert mse == 0 + + +if __name__ == "__main__": + test_adjust_brightness_eager() + test_adjust_brightness_invalid_brightness_factor_param() + test_adjust_brightness_pipeline() diff --git a/tests/ut/python/dataset/test_bandpass_biquad.py b/tests/ut/python/dataset/test_bandpass_biquad.py index eb7aa690507..ae66293ac6f 100644 --- a/tests/ut/python/dataset/test_bandpass_biquad.py +++ b/tests/ut/python/dataset/test_bandpass_biquad.py @@ -1,125 +1,125 @@ -# Copyright 2021-2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -import numpy as np -import pytest -import mindspore.dataset as ds -import mindspore.dataset.audio as audio -from mindspore import log as logger - - -def count_unequal_element(data_expected, data_me, rtol, atol): - assert data_expected.shape == data_me.shape - total_count = len(data_expected.flatten()) - error = np.abs(data_expected - data_me) - greater = np.greater(error, atol + np.abs(data_expected) * rtol) - loss_count = np.count_nonzero(greater) - assert (loss_count / total_count) < rtol, "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".format( - data_expected[greater], data_me[greater], error[greater]) - - -def test_func_bandpass_biquad_eager(): - """ - Feature: BandpassBiquad op - Description: Test BandpassBiquad op in eager mode with valid input - Expectation: Output is equal to the expected output - """ - - # Original waveform - waveform = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float64) - # Expect waveform - expect_waveform = np.array([[0.01979545, 0.07838227, 0.17417782], - [0.07918181, 0.25414270, 0.46156447]], dtype=np.float64) - bandpass_biquad_op = audio.BandpassBiquad(44000, 200.0, 0.707, False) - # Filtered waveform by bandpassbiquad - output = bandpass_biquad_op(waveform) - count_unequal_element(expect_waveform, output, 0.0001, 0.0001) - - -def test_func_bandpass_biquad_pipeline(): - """ - Feature: BandpassBiquad op - Description: Test BandpassBiquad op in pipeline mode with valid input - Expectation: Output is equal to the expected output - """ - - # Original waveform - waveform = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float64) - # Expect waveform - expect_waveform = np.array([[0.01979545, 0.07838227, 0.17417782], - [0.07918181, 0.25414270, 0.46156447]], dtype=np.float64) - label = np.random.sample((2, 1)) - data = (waveform, label) - dataset = ds.NumpySlicesDataset(data, ["channel", "sample"], shuffle=False) - bandpass_biquad_op = audio.BandpassBiquad(44000, 200.0) - # Filtered waveform by bandpassbiquad - dataset = dataset.map( - input_columns=["channel"], operations=bandpass_biquad_op, num_parallel_workers=8) - i = 0 - for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True): - count_unequal_element( - expect_waveform[i, :], item['channel'], 0.0001, 0.0001) - i += 1 - - -def test_bandpass_biquad_invalid_input(): - """ - Feature: BandpassBiquad op - Description: Test BandpassBiquad op with invalid input - Expectation: Correct error and message are thrown as expected - """ - def test_invalid_input(test_name, sample_rate, central_freq, Q, const_skirt_gain, error, error_msg): - logger.info( - "Test BandpassBiquad with bad input: {0}".format(test_name)) - with pytest.raises(error) as error_info: - audio.BandpassBiquad( - sample_rate, central_freq, Q, const_skirt_gain) - assert error_msg in str(error_info.value) - - test_invalid_input("invalid sample_rate parameter type as a float", 44100.5, 200, 0.707, True, TypeError, - "Argument sample_rate with value 44100.5 is not of type []," - " but got .") - test_invalid_input("invalid sample_rate parameter type as a String", "44100", 200, 0.707, True, TypeError, - "Argument sample_rate with value 44100 is not of type [], but got .") - test_invalid_input("invalid contral_freq parameter type as a String", 44100, "200", 0.707, True, TypeError, - "Argument central_freq with value 200 is not of type [, ]," - " but got .") - test_invalid_input("invalid sample_rate parameter value", 0, 200, 0.707, True, ValueError, - "Input sample_rate is not within the required interval of [-2147483648, 0) and (0, 2147483647].") - test_invalid_input("invalid contral_freq parameter value", 44100, 32434324324234321, 0.707, True, ValueError, - "Input central_freq is not within the required interval of [-16777216, 16777216].") - test_invalid_input("invalid Q parameter type as a String", 44100, 200, "0.707", True, TypeError, - "Argument Q with value 0.707 is not of type [, ]," - " but got .") - test_invalid_input("invalid Q parameter value", 44100, 200, 1.707, True, ValueError, - "Input Q is not within the required interval of (0, 1].") - test_invalid_input("invalid Q parameter value", 44100, 200, 0, True, ValueError, - "Input Q is not within the required interval of (0, 1].") - test_invalid_input("invalid sample_rate parameter value", 441324343243242342345300, 200, 0.707, True, ValueError, - "Input sample_rate is not within the required interval of [-2147483648, 0) and (0, 2147483647].") - test_invalid_input("invalid sample_rate parameter value", None, 200, 0.707, True, TypeError, - "Argument sample_rate with value None is not of type []," - " but got .") - test_invalid_input("invalid central_rate parameter value", 44100, None, 0.707, True, TypeError, - "Argument central_freq with value None is not of type [, ]," - " but got .") - test_invalid_input("invalid const_skirt_gain parameter type as a String", 44100, 200, 0.707, "False", TypeError, - "Argument const_skirt_gain with value False is not of type [], " + - "but got .") - - -if __name__ == "__main__": - test_func_bandpass_biquad_eager() - test_func_bandpass_biquad_pipeline() - test_bandpass_biquad_invalid_input() +# Copyright 2021-2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import numpy as np +import pytest +import mindspore.dataset as ds +import mindspore.dataset.audio as audio +from mindspore import log as logger + + +def count_unequal_element(data_expected, data_me, rtol, atol): + assert data_expected.shape == data_me.shape + total_count = len(data_expected.flatten()) + error = np.abs(data_expected - data_me) + greater = np.greater(error, atol + np.abs(data_expected) * rtol) + loss_count = np.count_nonzero(greater) + assert (loss_count / total_count) < rtol, "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".format( + data_expected[greater], data_me[greater], error[greater]) + + +def test_func_bandpass_biquad_eager(): + """ + Feature: BandpassBiquad op + Description: Test BandpassBiquad op in eager mode with valid input + Expectation: Output is equal to the expected output + """ + + # Original waveform + waveform = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float64) + # Expect waveform + expect_waveform = np.array([[0.01979545, 0.07838227, 0.17417782], + [0.07918181, 0.25414270, 0.46156447]], dtype=np.float64) + bandpass_biquad_op = audio.BandpassBiquad(44000, 200.0, 0.707, False) + # Filtered waveform by bandpassbiquad + output = bandpass_biquad_op(waveform) + count_unequal_element(expect_waveform, output, 0.0001, 0.0001) + + +def test_func_bandpass_biquad_pipeline(): + """ + Feature: BandpassBiquad op + Description: Test BandpassBiquad op in pipeline mode with valid input + Expectation: Output is equal to the expected output + """ + + # Original waveform + waveform = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float64) + # Expect waveform + expect_waveform = np.array([[0.01979545, 0.07838227, 0.17417782], + [0.07918181, 0.25414270, 0.46156447]], dtype=np.float64) + label = np.random.sample((2, 1)) + data = (waveform, label) + dataset = ds.NumpySlicesDataset(data, ["channel", "sample"], shuffle=False) + bandpass_biquad_op = audio.BandpassBiquad(44000, 200.0) + # Filtered waveform by bandpassbiquad + dataset = dataset.map( + input_columns=["channel"], operations=bandpass_biquad_op, num_parallel_workers=8) + i = 0 + for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True): + count_unequal_element( + expect_waveform[i, :], item['channel'], 0.0001, 0.0001) + i += 1 + + +def test_bandpass_biquad_invalid_input(): + """ + Feature: BandpassBiquad op + Description: Test BandpassBiquad op with invalid input + Expectation: Correct error and message are thrown as expected + """ + def test_invalid_input(test_name, sample_rate, central_freq, Q, const_skirt_gain, error, error_msg): + logger.info( + "Test BandpassBiquad with bad input: {0}".format(test_name)) + with pytest.raises(error) as error_info: + audio.BandpassBiquad( + sample_rate, central_freq, Q, const_skirt_gain) + assert error_msg in str(error_info.value) + + test_invalid_input("invalid sample_rate parameter type as a float", 44100.5, 200, 0.707, True, TypeError, + "Argument sample_rate with value 44100.5 is not of type []," + " but got .") + test_invalid_input("invalid sample_rate parameter type as a String", "44100", 200, 0.707, True, TypeError, + "Argument sample_rate with value 44100 is not of type [], but got .") + test_invalid_input("invalid contral_freq parameter type as a String", 44100, "200", 0.707, True, TypeError, + "Argument central_freq with value 200 is not of type [, ]," + " but got .") + test_invalid_input("invalid sample_rate parameter value", 0, 200, 0.707, True, ValueError, + "Input sample_rate is not within the required interval of [-2147483648, 0) and (0, 2147483647].") + test_invalid_input("invalid contral_freq parameter value", 44100, 32434324324234321, 0.707, True, ValueError, + "Input central_freq is not within the required interval of [-16777216, 16777216].") + test_invalid_input("invalid Q parameter type as a String", 44100, 200, "0.707", True, TypeError, + "Argument Q with value 0.707 is not of type [, ]," + " but got .") + test_invalid_input("invalid Q parameter value", 44100, 200, 1.707, True, ValueError, + "Input Q is not within the required interval of (0, 1].") + test_invalid_input("invalid Q parameter value", 44100, 200, 0, True, ValueError, + "Input Q is not within the required interval of (0, 1].") + test_invalid_input("invalid sample_rate parameter value", 441324343243242342345300, 200, 0.707, True, ValueError, + "Input sample_rate is not within the required interval of [-2147483648, 0) and (0, 2147483647].") + test_invalid_input("invalid sample_rate parameter value", None, 200, 0.707, True, TypeError, + "Argument sample_rate with value None is not of type []," + " but got .") + test_invalid_input("invalid central_rate parameter value", 44100, None, 0.707, True, TypeError, + "Argument central_freq with value None is not of type [, ]," + " but got .") + test_invalid_input("invalid const_skirt_gain parameter type as a String", 44100, 200, 0.707, "False", TypeError, + "Argument const_skirt_gain with value False is not of type [], " + + "but got .") + + +if __name__ == "__main__": + test_func_bandpass_biquad_eager() + test_func_bandpass_biquad_pipeline() + test_bandpass_biquad_invalid_input() diff --git a/tests/ut/python/dataset/test_bass_biquad.py b/tests/ut/python/dataset/test_bass_biquad.py index 45c11beda08..c2bb726e5b1 100644 --- a/tests/ut/python/dataset/test_bass_biquad.py +++ b/tests/ut/python/dataset/test_bass_biquad.py @@ -1,130 +1,130 @@ -# Copyright 2021-2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -import numpy as np -import pytest -import mindspore.dataset as ds -import mindspore.dataset.audio as audio -from mindspore import log as logger - - -def count_unequal_element(data_expected, data_me, rtol, atol): - assert data_expected.shape == data_me.shape - total_count = len(data_expected.flatten()) - error = np.abs(data_expected - data_me) - greater = np.greater(error, atol + np.abs(data_expected) * rtol) - loss_count = np.count_nonzero(greater) - assert (loss_count / total_count) < rtol, "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".format( - data_expected[greater], data_me[greater], error[greater]) - - -def test_func_bass_biquad_eager(): - """ - Feature: BassBiquad op - Description: Test BassBiquad op in eager mode with valid input - Expectation: Output is equal to the expected output - """ - - # Original waveform - waveform = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float64) - # Expect waveform - expect_waveform = np.array([[0.10409035359, 0.21652136269, 0.33761211292], - [0.41636141439, 0.55381438997, 0.70088436361]], dtype=np.float64) - bass_biquad_op = audio.BassBiquad(44100, 50.0, 100.0, 0.707) - # Filtered waveform by bassbiquad - output = bass_biquad_op(waveform) - count_unequal_element(expect_waveform, output, 0.0001, 0.0001) - - -def test_func_bass_biquad_pipeline(): - """ - Feature: BassBiquad op - Description: Test BassBiquad op in pipeline mode with valid input - Expectation: Output is equal to the expected output - """ - - # Original waveform - waveform = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float64) - # Expect waveform - expect_waveform = np.array([[0.10409035359, 0.21652136269, 0.33761211292], - [0.41636141439, 0.55381438997, 0.70088436361]], dtype=np.float64) - label = np.random.sample((2, 1)) - data = (waveform, label) - dataset = ds.NumpySlicesDataset(data, ["channel", "sample"], shuffle=False) - bass_biquad_op = audio.BassBiquad(44100, 50, 100.0, 0.707) - # Filtered waveform by bassbiquad - dataset = dataset.map( - input_columns=["channel"], operations=bass_biquad_op, num_parallel_workers=8) - i = 0 - for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True): - count_unequal_element(expect_waveform[i, :], - item['channel'], 0.0001, 0.0001) - i += 1 - - -def test_invalid_invalid_input(): - """ - Feature: BassBiquad op - Description: Test BassBiquad op with invalid input - Expectation: Correct error and message are thrown as expected - """ - def test_invalid_input(test_name, sample_rate, gain, central_freq, Q, error, error_msg): - logger.info("Test BassBiquad with bad input: {0}".format(test_name)) - with pytest.raises(error) as error_info: - audio.BassBiquad(sample_rate, gain, central_freq, Q) - assert error_msg in str(error_info.value) - - test_invalid_input("invalid sample_rate parameter type as a float", 44100.5, 50.0, 200, 0.707, TypeError, - "Argument sample_rate with value 44100.5 is not of type []," - " but got .") - test_invalid_input("invalid sample_rate parameter type as a String", "44100", 50.0, 200, 0.707, TypeError, - "Argument sample_rate with value 44100 is not of type []," - " but got .") - test_invalid_input("invalid gain parameter type as a String", 44100, "50.0", 200, 0.707, TypeError, - "Argument gain with value 50.0 is not of type [, ]," - " but got .") - test_invalid_input("invalid contral_freq parameter type as a String", 44100, 50.0, "200", 0.707, TypeError, - "Argument central_freq with value 200 is not of type [, ]," - " but got .") - test_invalid_input("invalid Q parameter type as a String", 44100, 50.0, 200, "0.707", TypeError, - "Argument Q with value 0.707 is not of type [, ]," - " but got .") - - test_invalid_input("invalid sample_rate parameter value", 441324343243242342345300, 50.0, 200, 0.707, ValueError, - "Input sample_rate is not within the required interval of [-2147483648, 0) and (0, 2147483647].") - test_invalid_input("invalid gain parameter value", 44100, 32434324324234321, 200, 0.707, ValueError, - "Input gain is not within the required interval of [-16777216, 16777216].") - test_invalid_input("invalid contral_freq parameter value", 44100, 50, 32434324324234321, 0.707, ValueError, - "Input central_freq is not within the required interval of [-16777216, 16777216].") - - test_invalid_input("invalid sample_rate parameter value", None, 50.0, 200, 0.707, TypeError, - "Argument sample_rate with value None is not of type [], " - "but got .") - test_invalid_input("invalid gain parameter value", 44100, None, 200, 0.707, TypeError, - "Argument gain with value None is not of type [, ], " - "but got .") - test_invalid_input("invalid central_rate parameter value", 44100, 50.0, None, 0.707, TypeError, - "Argument central_freq with value None is not of type [, ]," - " but got .") - - test_invalid_input("invalid sample_rate parameter value", 0, 50.0, 200, 0.707, ValueError, - "Input sample_rate is not within the required interval of [-2147483648, 0) and (0, 2147483647].") - test_invalid_input("invalid Q parameter value", 44100, 50.0, 200, 1.707, ValueError, - "Input Q is not within the required interval of (0, 1].") - - -if __name__ == '__main__': - test_func_bass_biquad_eager() - test_func_bass_biquad_pipeline() - test_invalid_invalid_input() +# Copyright 2021-2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import numpy as np +import pytest +import mindspore.dataset as ds +import mindspore.dataset.audio as audio +from mindspore import log as logger + + +def count_unequal_element(data_expected, data_me, rtol, atol): + assert data_expected.shape == data_me.shape + total_count = len(data_expected.flatten()) + error = np.abs(data_expected - data_me) + greater = np.greater(error, atol + np.abs(data_expected) * rtol) + loss_count = np.count_nonzero(greater) + assert (loss_count / total_count) < rtol, "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".format( + data_expected[greater], data_me[greater], error[greater]) + + +def test_func_bass_biquad_eager(): + """ + Feature: BassBiquad op + Description: Test BassBiquad op in eager mode with valid input + Expectation: Output is equal to the expected output + """ + + # Original waveform + waveform = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float64) + # Expect waveform + expect_waveform = np.array([[0.10409035359, 0.21652136269, 0.33761211292], + [0.41636141439, 0.55381438997, 0.70088436361]], dtype=np.float64) + bass_biquad_op = audio.BassBiquad(44100, 50.0, 100.0, 0.707) + # Filtered waveform by bassbiquad + output = bass_biquad_op(waveform) + count_unequal_element(expect_waveform, output, 0.0001, 0.0001) + + +def test_func_bass_biquad_pipeline(): + """ + Feature: BassBiquad op + Description: Test BassBiquad op in pipeline mode with valid input + Expectation: Output is equal to the expected output + """ + + # Original waveform + waveform = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float64) + # Expect waveform + expect_waveform = np.array([[0.10409035359, 0.21652136269, 0.33761211292], + [0.41636141439, 0.55381438997, 0.70088436361]], dtype=np.float64) + label = np.random.sample((2, 1)) + data = (waveform, label) + dataset = ds.NumpySlicesDataset(data, ["channel", "sample"], shuffle=False) + bass_biquad_op = audio.BassBiquad(44100, 50, 100.0, 0.707) + # Filtered waveform by bassbiquad + dataset = dataset.map( + input_columns=["channel"], operations=bass_biquad_op, num_parallel_workers=8) + i = 0 + for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True): + count_unequal_element(expect_waveform[i, :], + item['channel'], 0.0001, 0.0001) + i += 1 + + +def test_invalid_invalid_input(): + """ + Feature: BassBiquad op + Description: Test BassBiquad op with invalid input + Expectation: Correct error and message are thrown as expected + """ + def test_invalid_input(test_name, sample_rate, gain, central_freq, Q, error, error_msg): + logger.info("Test BassBiquad with bad input: {0}".format(test_name)) + with pytest.raises(error) as error_info: + audio.BassBiquad(sample_rate, gain, central_freq, Q) + assert error_msg in str(error_info.value) + + test_invalid_input("invalid sample_rate parameter type as a float", 44100.5, 50.0, 200, 0.707, TypeError, + "Argument sample_rate with value 44100.5 is not of type []," + " but got .") + test_invalid_input("invalid sample_rate parameter type as a String", "44100", 50.0, 200, 0.707, TypeError, + "Argument sample_rate with value 44100 is not of type []," + " but got .") + test_invalid_input("invalid gain parameter type as a String", 44100, "50.0", 200, 0.707, TypeError, + "Argument gain with value 50.0 is not of type [, ]," + " but got .") + test_invalid_input("invalid contral_freq parameter type as a String", 44100, 50.0, "200", 0.707, TypeError, + "Argument central_freq with value 200 is not of type [, ]," + " but got .") + test_invalid_input("invalid Q parameter type as a String", 44100, 50.0, 200, "0.707", TypeError, + "Argument Q with value 0.707 is not of type [, ]," + " but got .") + + test_invalid_input("invalid sample_rate parameter value", 441324343243242342345300, 50.0, 200, 0.707, ValueError, + "Input sample_rate is not within the required interval of [-2147483648, 0) and (0, 2147483647].") + test_invalid_input("invalid gain parameter value", 44100, 32434324324234321, 200, 0.707, ValueError, + "Input gain is not within the required interval of [-16777216, 16777216].") + test_invalid_input("invalid contral_freq parameter value", 44100, 50, 32434324324234321, 0.707, ValueError, + "Input central_freq is not within the required interval of [-16777216, 16777216].") + + test_invalid_input("invalid sample_rate parameter value", None, 50.0, 200, 0.707, TypeError, + "Argument sample_rate with value None is not of type [], " + "but got .") + test_invalid_input("invalid gain parameter value", 44100, None, 200, 0.707, TypeError, + "Argument gain with value None is not of type [, ], " + "but got .") + test_invalid_input("invalid central_rate parameter value", 44100, 50.0, None, 0.707, TypeError, + "Argument central_freq with value None is not of type [, ]," + " but got .") + + test_invalid_input("invalid sample_rate parameter value", 0, 50.0, 200, 0.707, ValueError, + "Input sample_rate is not within the required interval of [-2147483648, 0) and (0, 2147483647].") + test_invalid_input("invalid Q parameter value", 44100, 50.0, 200, 1.707, ValueError, + "Input Q is not within the required interval of (0, 1].") + + +if __name__ == '__main__': + test_func_bass_biquad_eager() + test_func_bass_biquad_pipeline() + test_invalid_invalid_input() diff --git a/tests/ut/python/dataset/test_char_n_gram.py b/tests/ut/python/dataset/test_char_n_gram.py index 2ac3c42e35a..fd6704f8b11 100644 --- a/tests/ut/python/dataset/test_char_n_gram.py +++ b/tests/ut/python/dataset/test_char_n_gram.py @@ -1,217 +1,217 @@ -# Copyright 2021-2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import numpy as np -import pytest - -from mindspore import log -import mindspore.dataset as ds -import mindspore.dataset.text as text -import mindspore.dataset.text.transforms as T - -DATASET_ROOT_PATH = "../data/dataset/testVectors/" - - -def _count_unequal_element(data_expected, data_me, rtol, atol): - assert data_expected.shape == data_me.shape - total_count = len(data_expected.flatten()) - error = np.abs(data_expected - data_me) - greater = np.greater(error, atol + np.abs(data_expected)*rtol) - loss_count = np.count_nonzero(greater) - assert (loss_count/total_count) < rtol,\ - "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".\ - format(data_expected[greater], data_me[greater], error[greater]) - - -def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True): - if np.any(np.isnan(data_expected)): - assert np.allclose(data_me, data_expected, rtol, atol, equal_nan=equal_nan) - elif not np.allclose(data_me, data_expected, rtol, atol, equal_nan=equal_nan): - _count_unequal_element(data_expected, data_me, rtol, atol) - else: - assert True - - -def test_char_n_gram_all_to_vectors_params_eager(): - """ - Feature: CharNGram - Description: Test with all parameters which include `unk_init` - and `lower_case_backup` in function ToVectors in eager mode - Expectation: Output is equal to the expected value - """ - char_n_gram = text.CharNGram.from_file(DATASET_ROOT_PATH + "char_n_gram_20.txt", max_vectors=18) - unk_init = (-np.ones(5)).tolist() - to_vectors = T.ToVectors(char_n_gram, unk_init=unk_init, lower_case_backup=True) - result1 = to_vectors("THE") - result2 = to_vectors(".") - result3 = to_vectors("To") - res = [[-1.34121733e+00, 4.42693333e-02, -4.86969667e-01, 6.62939000e-01, -3.67669000e-01], - [-1.00000000e+00, -1.00000000e+00, -1.00000000e+00, -1.00000000e+00, -1.00000000e+00], - [-9.68530000e-01, -7.89463000e-01, 5.15762000e-01, 2.02107000e+00, -1.64635000e+00]] - res_array = np.array(res, dtype=np.float32) - - allclose_nparray(res_array[0], result1, 0.0001, 0.0001) - allclose_nparray(res_array[1], result2, 0.0001, 0.0001) - allclose_nparray(res_array[2], result3, 0.0001, 0.0001) - - -def test_char_n_gram_build_from_file(): - """ - Feature: CharNGram - Description: Test with only default parameter - Expectation: Output is equal to the expected value - """ - char_n_gram = text.CharNGram.from_file(DATASET_ROOT_PATH + "char_n_gram_20.txt") - to_vectors = text.ToVectors(char_n_gram) - data = ds.TextFileDataset(DATASET_ROOT_PATH + "words.txt", shuffle=False) - data = data.map(operations=to_vectors, input_columns=["text"]) - ind = 0 - res = [[0., 0., 0., 0., 0.], - [0., 0., 0., 0., 0.], - [0.117336, 0.362446, -0.983326, 0.939264, -0.05648], - [0.657201, 2.11761, -1.59276, 0.432072, 1.21395], - [0., 0., 0., 0., 0.], - [-2.26956, 0.288491, -0.740001, 0.661703, 0.147355], - [0., 0., 0., 0., 0.]] - for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): - res_array = np.array(res[ind], dtype=np.float32) - allclose_nparray(res_array, d["text"], 0.0001, 0.0001) - ind += 1 - - -def test_char_n_gram_all_build_from_file_params(): - """ - Feature: CharNGram - Description: Test with all parameters which include `path` and `max_vector` in function BuildFromFile - Expectation: Output is equal to the expected value - """ - char_n_gram = text.CharNGram.from_file(DATASET_ROOT_PATH + "char_n_gram_20.txt", max_vectors=100) - to_vectors = text.ToVectors(char_n_gram) - data = ds.TextFileDataset(DATASET_ROOT_PATH + "words.txt", shuffle=False) - data = data.map(operations=to_vectors, input_columns=["text"]) - ind = 0 - res = [[0., 0., 0., 0., 0.], - [0., 0., 0., 0., 0.], - [0.117336, 0.362446, -0.983326, 0.939264, -0.05648], - [0.657201, 2.11761, -1.59276, 0.432072, 1.21395], - [0., 0., 0., 0., 0.], - [-2.26956, 0.288491, -0.740001, 0.661703, 0.147355], - [0., 0., 0., 0., 0.]] - for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): - res_array = np.array(res[ind], dtype=np.float32) - allclose_nparray(res_array, d["text"], 0.0001, 0.0001) - ind += 1 - - -def test_char_n_gram_all_build_from_file_params_eager(): - """ - Feature: CharNGram - Description: Test with all parameters which include `path` and `max_vector` in function BuildFromFile in eager mode - Expectation: Output is equal to the expected value - """ - char_n_gram = text.CharNGram.from_file(DATASET_ROOT_PATH + "char_n_gram_20.txt", max_vectors=18) - to_vectors = T.ToVectors(char_n_gram) - result1 = to_vectors("the") - result2 = to_vectors(".") - result3 = to_vectors("to") - res = [[-1.34121733e+00, 4.42693333e-02, -4.86969667e-01, 6.62939000e-01, -3.67669000e-01], - [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], - [-9.68530000e-01, -7.89463000e-01, 5.15762000e-01, 2.02107000e+00, -1.64635000e+00]] - res_array = np.array(res, dtype=np.float32) - - allclose_nparray(res_array[0], result1, 0.0001, 0.0001) - allclose_nparray(res_array[1], result2, 0.0001, 0.0001) - allclose_nparray(res_array[2], result3, 0.0001, 0.0001) - - -def test_char_n_gram_build_from_file_eager(): - """ - Feature: CharNGram - Description: Test with only default parameter in eager mode - Expectation: Output is equal to the expected value - """ - char_n_gram = text.CharNGram.from_file(DATASET_ROOT_PATH + "char_n_gram_20.txt") - to_vectors = T.ToVectors(char_n_gram) - result1 = to_vectors("the") - result2 = to_vectors(".") - result3 = to_vectors("to") - res = [[-8.40079000e-01, -2.70002500e-02, -8.33472250e-01, 5.88367000e-01, -2.10011750e-01], - [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], - [-9.68530000e-01, -7.89463000e-01, 5.15762000e-01, 2.02107000e+00, -1.64635000e+00]] - res_array = np.array(res, dtype=np.float32) - - allclose_nparray(res_array[0], result1, 0.0001, 0.0001) - allclose_nparray(res_array[1], result2, 0.0001, 0.0001) - allclose_nparray(res_array[2], result3, 0.0001, 0.0001) - - -def test_char_n_gram_invalid_input(): - """ - Feature: CharNGram - Description: Test the validate function with invalid parameters. - Expectation: Verification of correct error message for invalid input. - """ - def test_invalid_input(test_name, file_path, error, error_msg, max_vectors=None, - unk_init=None, lower_case_backup=False, token="ok"): - log.info("Test CharNGram with wrong input: {0}".format(test_name)) - with pytest.raises(error) as error_info: - char_n_gram = text.CharNGram.from_file(file_path, max_vectors=max_vectors) - to_vectors = T.ToVectors(char_n_gram, unk_init=unk_init, lower_case_backup=lower_case_backup) - to_vectors(token) - assert error_msg in str(error_info.value) - - test_invalid_input("Not all vectors have the same number of dimensions", - DATASET_ROOT_PATH + "char_n_gram_20_dim_different.txt", error=RuntimeError, - error_msg="all vectors must have the same number of dimensions, " + - "but got dim 4 while expecting 5") - test_invalid_input("the file is empty.", DATASET_ROOT_PATH + "vectors_empty.txt", - error=RuntimeError, error_msg="invalid file, file is empty.") - test_invalid_input("the count of `unknown_init`'s element is different with word vector.", - DATASET_ROOT_PATH + "char_n_gram_20.txt", - error=RuntimeError, error_msg="unk_init must be the same length as vectors, " + - "but got unk_init: 6 and vectors: 5", unk_init=np.ones(6).tolist()) - test_invalid_input("The file not exist", DATASET_ROOT_PATH + "not_exist.txt", RuntimeError, - error_msg="get real path failed") - test_invalid_input("max_vectors parameter must be greater than 0", - DATASET_ROOT_PATH + "char_n_gram_20.txt", error=ValueError, - error_msg="Input max_vectors is not within the required interval", max_vectors=-1) - test_invalid_input("invalid max_vectors parameter type as a float", - DATASET_ROOT_PATH + "char_n_gram_20.txt", error=TypeError, - error_msg="Argument max_vectors with value 1.0 is not of type []," - " but got .", max_vectors=1.0) - test_invalid_input("invalid max_vectors parameter type as a string", - DATASET_ROOT_PATH + "char_n_gram_20.txt", error=TypeError, - error_msg="Argument max_vectors with value 1 is not of type []," - " but got .", max_vectors="1") - test_invalid_input("invalid token parameter type as a float", - DATASET_ROOT_PATH + "char_n_gram_20.txt", error=RuntimeError, - error_msg="input tensor type should be string.", token=1.0) - test_invalid_input("invalid lower_case_backup parameter type as a string", DATASET_ROOT_PATH + "char_n_gram_20.txt", - error=TypeError, error_msg="Argument lower_case_backup with " + - "value True is not of type []," - " but got .", lower_case_backup="True") - test_invalid_input("invalid lower_case_backup parameter type as a string", DATASET_ROOT_PATH + "char_n_gram_20.txt", - error=TypeError, error_msg="Argument lower_case_backup with " + - "value True is not of type []," - " but got .", lower_case_backup="True") - - -if __name__ == '__main__': - test_char_n_gram_all_to_vectors_params_eager() - test_char_n_gram_build_from_file() - test_char_n_gram_all_build_from_file_params() - test_char_n_gram_all_build_from_file_params_eager() - test_char_n_gram_build_from_file_eager() - test_char_n_gram_invalid_input() +# Copyright 2021-2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import numpy as np +import pytest + +from mindspore import log +import mindspore.dataset as ds +import mindspore.dataset.text as text +import mindspore.dataset.text.transforms as T + +DATASET_ROOT_PATH = "../data/dataset/testVectors/" + + +def _count_unequal_element(data_expected, data_me, rtol, atol): + assert data_expected.shape == data_me.shape + total_count = len(data_expected.flatten()) + error = np.abs(data_expected - data_me) + greater = np.greater(error, atol + np.abs(data_expected)*rtol) + loss_count = np.count_nonzero(greater) + assert (loss_count/total_count) < rtol,\ + "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".\ + format(data_expected[greater], data_me[greater], error[greater]) + + +def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True): + if np.any(np.isnan(data_expected)): + assert np.allclose(data_me, data_expected, rtol, atol, equal_nan=equal_nan) + elif not np.allclose(data_me, data_expected, rtol, atol, equal_nan=equal_nan): + _count_unequal_element(data_expected, data_me, rtol, atol) + else: + assert True + + +def test_char_n_gram_all_to_vectors_params_eager(): + """ + Feature: CharNGram + Description: Test with all parameters which include `unk_init` + and `lower_case_backup` in function ToVectors in eager mode + Expectation: Output is equal to the expected value + """ + char_n_gram = text.CharNGram.from_file(DATASET_ROOT_PATH + "char_n_gram_20.txt", max_vectors=18) + unk_init = (-np.ones(5)).tolist() + to_vectors = T.ToVectors(char_n_gram, unk_init=unk_init, lower_case_backup=True) + result1 = to_vectors("THE") + result2 = to_vectors(".") + result3 = to_vectors("To") + res = [[-1.34121733e+00, 4.42693333e-02, -4.86969667e-01, 6.62939000e-01, -3.67669000e-01], + [-1.00000000e+00, -1.00000000e+00, -1.00000000e+00, -1.00000000e+00, -1.00000000e+00], + [-9.68530000e-01, -7.89463000e-01, 5.15762000e-01, 2.02107000e+00, -1.64635000e+00]] + res_array = np.array(res, dtype=np.float32) + + allclose_nparray(res_array[0], result1, 0.0001, 0.0001) + allclose_nparray(res_array[1], result2, 0.0001, 0.0001) + allclose_nparray(res_array[2], result3, 0.0001, 0.0001) + + +def test_char_n_gram_build_from_file(): + """ + Feature: CharNGram + Description: Test with only default parameter + Expectation: Output is equal to the expected value + """ + char_n_gram = text.CharNGram.from_file(DATASET_ROOT_PATH + "char_n_gram_20.txt") + to_vectors = text.ToVectors(char_n_gram) + data = ds.TextFileDataset(DATASET_ROOT_PATH + "words.txt", shuffle=False) + data = data.map(operations=to_vectors, input_columns=["text"]) + ind = 0 + res = [[0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0.117336, 0.362446, -0.983326, 0.939264, -0.05648], + [0.657201, 2.11761, -1.59276, 0.432072, 1.21395], + [0., 0., 0., 0., 0.], + [-2.26956, 0.288491, -0.740001, 0.661703, 0.147355], + [0., 0., 0., 0., 0.]] + for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): + res_array = np.array(res[ind], dtype=np.float32) + allclose_nparray(res_array, d["text"], 0.0001, 0.0001) + ind += 1 + + +def test_char_n_gram_all_build_from_file_params(): + """ + Feature: CharNGram + Description: Test with all parameters which include `path` and `max_vector` in function BuildFromFile + Expectation: Output is equal to the expected value + """ + char_n_gram = text.CharNGram.from_file(DATASET_ROOT_PATH + "char_n_gram_20.txt", max_vectors=100) + to_vectors = text.ToVectors(char_n_gram) + data = ds.TextFileDataset(DATASET_ROOT_PATH + "words.txt", shuffle=False) + data = data.map(operations=to_vectors, input_columns=["text"]) + ind = 0 + res = [[0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0.117336, 0.362446, -0.983326, 0.939264, -0.05648], + [0.657201, 2.11761, -1.59276, 0.432072, 1.21395], + [0., 0., 0., 0., 0.], + [-2.26956, 0.288491, -0.740001, 0.661703, 0.147355], + [0., 0., 0., 0., 0.]] + for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): + res_array = np.array(res[ind], dtype=np.float32) + allclose_nparray(res_array, d["text"], 0.0001, 0.0001) + ind += 1 + + +def test_char_n_gram_all_build_from_file_params_eager(): + """ + Feature: CharNGram + Description: Test with all parameters which include `path` and `max_vector` in function BuildFromFile in eager mode + Expectation: Output is equal to the expected value + """ + char_n_gram = text.CharNGram.from_file(DATASET_ROOT_PATH + "char_n_gram_20.txt", max_vectors=18) + to_vectors = T.ToVectors(char_n_gram) + result1 = to_vectors("the") + result2 = to_vectors(".") + result3 = to_vectors("to") + res = [[-1.34121733e+00, 4.42693333e-02, -4.86969667e-01, 6.62939000e-01, -3.67669000e-01], + [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], + [-9.68530000e-01, -7.89463000e-01, 5.15762000e-01, 2.02107000e+00, -1.64635000e+00]] + res_array = np.array(res, dtype=np.float32) + + allclose_nparray(res_array[0], result1, 0.0001, 0.0001) + allclose_nparray(res_array[1], result2, 0.0001, 0.0001) + allclose_nparray(res_array[2], result3, 0.0001, 0.0001) + + +def test_char_n_gram_build_from_file_eager(): + """ + Feature: CharNGram + Description: Test with only default parameter in eager mode + Expectation: Output is equal to the expected value + """ + char_n_gram = text.CharNGram.from_file(DATASET_ROOT_PATH + "char_n_gram_20.txt") + to_vectors = T.ToVectors(char_n_gram) + result1 = to_vectors("the") + result2 = to_vectors(".") + result3 = to_vectors("to") + res = [[-8.40079000e-01, -2.70002500e-02, -8.33472250e-01, 5.88367000e-01, -2.10011750e-01], + [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], + [-9.68530000e-01, -7.89463000e-01, 5.15762000e-01, 2.02107000e+00, -1.64635000e+00]] + res_array = np.array(res, dtype=np.float32) + + allclose_nparray(res_array[0], result1, 0.0001, 0.0001) + allclose_nparray(res_array[1], result2, 0.0001, 0.0001) + allclose_nparray(res_array[2], result3, 0.0001, 0.0001) + + +def test_char_n_gram_invalid_input(): + """ + Feature: CharNGram + Description: Test the validate function with invalid parameters. + Expectation: Verification of correct error message for invalid input. + """ + def test_invalid_input(test_name, file_path, error, error_msg, max_vectors=None, + unk_init=None, lower_case_backup=False, token="ok"): + log.info("Test CharNGram with wrong input: {0}".format(test_name)) + with pytest.raises(error) as error_info: + char_n_gram = text.CharNGram.from_file(file_path, max_vectors=max_vectors) + to_vectors = T.ToVectors(char_n_gram, unk_init=unk_init, lower_case_backup=lower_case_backup) + to_vectors(token) + assert error_msg in str(error_info.value) + + test_invalid_input("Not all vectors have the same number of dimensions", + DATASET_ROOT_PATH + "char_n_gram_20_dim_different.txt", error=RuntimeError, + error_msg="all vectors must have the same number of dimensions, " + + "but got dim 4 while expecting 5") + test_invalid_input("the file is empty.", DATASET_ROOT_PATH + "vectors_empty.txt", + error=RuntimeError, error_msg="invalid file, file is empty.") + test_invalid_input("the count of `unknown_init`'s element is different with word vector.", + DATASET_ROOT_PATH + "char_n_gram_20.txt", + error=RuntimeError, error_msg="unk_init must be the same length as vectors, " + + "but got unk_init: 6 and vectors: 5", unk_init=np.ones(6).tolist()) + test_invalid_input("The file not exist", DATASET_ROOT_PATH + "not_exist.txt", RuntimeError, + error_msg="get real path failed") + test_invalid_input("max_vectors parameter must be greater than 0", + DATASET_ROOT_PATH + "char_n_gram_20.txt", error=ValueError, + error_msg="Input max_vectors is not within the required interval", max_vectors=-1) + test_invalid_input("invalid max_vectors parameter type as a float", + DATASET_ROOT_PATH + "char_n_gram_20.txt", error=TypeError, + error_msg="Argument max_vectors with value 1.0 is not of type []," + " but got .", max_vectors=1.0) + test_invalid_input("invalid max_vectors parameter type as a string", + DATASET_ROOT_PATH + "char_n_gram_20.txt", error=TypeError, + error_msg="Argument max_vectors with value 1 is not of type []," + " but got .", max_vectors="1") + test_invalid_input("invalid token parameter type as a float", + DATASET_ROOT_PATH + "char_n_gram_20.txt", error=RuntimeError, + error_msg="input tensor type should be string.", token=1.0) + test_invalid_input("invalid lower_case_backup parameter type as a string", DATASET_ROOT_PATH + "char_n_gram_20.txt", + error=TypeError, error_msg="Argument lower_case_backup with " + + "value True is not of type []," + " but got .", lower_case_backup="True") + test_invalid_input("invalid lower_case_backup parameter type as a string", DATASET_ROOT_PATH + "char_n_gram_20.txt", + error=TypeError, error_msg="Argument lower_case_backup with " + + "value True is not of type []," + " but got .", lower_case_backup="True") + + +if __name__ == '__main__': + test_char_n_gram_all_to_vectors_params_eager() + test_char_n_gram_build_from_file() + test_char_n_gram_all_build_from_file_params() + test_char_n_gram_all_build_from_file_params_eager() + test_char_n_gram_build_from_file_eager() + test_char_n_gram_invalid_input() diff --git a/tests/ut/python/dataset/test_create_dct.py b/tests/ut/python/dataset/test_create_dct.py index e6ffa39e2dc..6c525df08d2 100644 --- a/tests/ut/python/dataset/test_create_dct.py +++ b/tests/ut/python/dataset/test_create_dct.py @@ -1,95 +1,95 @@ -# Copyright 2021-2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -import numpy as np -import pytest - -from mindspore.dataset.audio import create_dct, NormMode -from mindspore import log as logger - - -def count_unequal_element(data_expected, data_me, rtol, atol): - assert data_expected.shape == data_me.shape - total_count = len(data_expected.flatten()) - error = np.abs(data_expected - data_me) - greater = np.greater(error, atol + np.abs(data_expected) * rtol) - loss_count = np.count_nonzero(greater) - assert (loss_count / total_count) < rtol, \ - "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \ - format(data_expected[greater], data_me[greater], error[greater]) - - -def test_create_dct_none(): - """ - Feature: Create DCT transformation - Description: Test create_dct in eager mode with no normalization - Expectation: The returned result is as expected - """ - expect = np.array([[2.00000000, 1.84775901], - [2.00000000, 0.76536685], - [2.00000000, -0.76536703], - [2.00000000, -1.84775925]], dtype=np.float64) - output = create_dct(2, 4, NormMode.NONE) - count_unequal_element(expect, output, 0.0001, 0.0001) - - -def test_create_dct_ortho(): - """ - Feature: Create DCT transformation - Description: Test create_dct in eager mode with orthogonal normalization - Expectation: The returned result is as expected - """ - output = create_dct(1, 3, NormMode.ORTHO) - expect = np.array([[0.57735026], - [0.57735026], - [0.57735026]], dtype=np.float64) - count_unequal_element(expect, output, 0.0001, 0.0001) - - -def test_createdct_invalid_input(): - """ - Feature: Create DCT transformation - Description: Test create_dct with invalid inputs - Expectation: Error is raised as expected - """ - def test_invalid_input(test_name, n_mfcc, n_mels, norm, error, error_msg): - logger.info("Test CreateDct with bad input: {0}".format(test_name)) - with pytest.raises(error) as error_info: - create_dct(n_mfcc, n_mels, norm) - assert error_msg in str(error_info.value) - - test_invalid_input("invalid n_mfcc parameter type as a float", 100.5, 200, NormMode.NONE, TypeError, - "n_mfcc with value 100.5 is not of type , but got .") - test_invalid_input("invalid n_mfcc parameter type as a String", "100", 200, NormMode.NONE, TypeError, - "n_mfcc with value 100 is not of type , but got .") - test_invalid_input("invalid n_mels parameter type as a String", 100, "200", NormMode.NONE, TypeError, - "n_mels with value 200 is not of type , but got .") - test_invalid_input("invalid n_mels parameter type as a String", 0, 200, NormMode.NONE, ValueError, - "n_mfcc must be greater than 0, but got 0.") - test_invalid_input("invalid n_mels parameter type as a String", 100, 0, NormMode.NONE, ValueError, - "n_mels must be greater than 0, but got 0.") - test_invalid_input("invalid n_mels parameter type as a String", -100, 200, NormMode.NONE, ValueError, - "n_mfcc must be greater than 0, but got -100.") - test_invalid_input("invalid n_mfcc parameter value", None, 100, NormMode.NONE, TypeError, - "n_mfcc with value None is not of type , but got .") - test_invalid_input("invalid n_mels parameter value", 100, None, NormMode.NONE, TypeError, - "n_mels with value None is not of type , but got .") - test_invalid_input("invalid n_mels parameter value", 100, 200, "None", TypeError, - "norm with value None is not of type , but got .") - - -if __name__ == "__main__": - test_create_dct_none() - test_create_dct_ortho() - test_createdct_invalid_input() +# Copyright 2021-2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import numpy as np +import pytest + +from mindspore.dataset.audio import create_dct, NormMode +from mindspore import log as logger + + +def count_unequal_element(data_expected, data_me, rtol, atol): + assert data_expected.shape == data_me.shape + total_count = len(data_expected.flatten()) + error = np.abs(data_expected - data_me) + greater = np.greater(error, atol + np.abs(data_expected) * rtol) + loss_count = np.count_nonzero(greater) + assert (loss_count / total_count) < rtol, \ + "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \ + format(data_expected[greater], data_me[greater], error[greater]) + + +def test_create_dct_none(): + """ + Feature: Create DCT transformation + Description: Test create_dct in eager mode with no normalization + Expectation: The returned result is as expected + """ + expect = np.array([[2.00000000, 1.84775901], + [2.00000000, 0.76536685], + [2.00000000, -0.76536703], + [2.00000000, -1.84775925]], dtype=np.float64) + output = create_dct(2, 4, NormMode.NONE) + count_unequal_element(expect, output, 0.0001, 0.0001) + + +def test_create_dct_ortho(): + """ + Feature: Create DCT transformation + Description: Test create_dct in eager mode with orthogonal normalization + Expectation: The returned result is as expected + """ + output = create_dct(1, 3, NormMode.ORTHO) + expect = np.array([[0.57735026], + [0.57735026], + [0.57735026]], dtype=np.float64) + count_unequal_element(expect, output, 0.0001, 0.0001) + + +def test_createdct_invalid_input(): + """ + Feature: Create DCT transformation + Description: Test create_dct with invalid inputs + Expectation: Error is raised as expected + """ + def test_invalid_input(test_name, n_mfcc, n_mels, norm, error, error_msg): + logger.info("Test CreateDct with bad input: {0}".format(test_name)) + with pytest.raises(error) as error_info: + create_dct(n_mfcc, n_mels, norm) + assert error_msg in str(error_info.value) + + test_invalid_input("invalid n_mfcc parameter type as a float", 100.5, 200, NormMode.NONE, TypeError, + "n_mfcc with value 100.5 is not of type , but got .") + test_invalid_input("invalid n_mfcc parameter type as a String", "100", 200, NormMode.NONE, TypeError, + "n_mfcc with value 100 is not of type , but got .") + test_invalid_input("invalid n_mels parameter type as a String", 100, "200", NormMode.NONE, TypeError, + "n_mels with value 200 is not of type , but got .") + test_invalid_input("invalid n_mels parameter type as a String", 0, 200, NormMode.NONE, ValueError, + "n_mfcc must be greater than 0, but got 0.") + test_invalid_input("invalid n_mels parameter type as a String", 100, 0, NormMode.NONE, ValueError, + "n_mels must be greater than 0, but got 0.") + test_invalid_input("invalid n_mels parameter type as a String", -100, 200, NormMode.NONE, ValueError, + "n_mfcc must be greater than 0, but got -100.") + test_invalid_input("invalid n_mfcc parameter value", None, 100, NormMode.NONE, TypeError, + "n_mfcc with value None is not of type , but got .") + test_invalid_input("invalid n_mels parameter value", 100, None, NormMode.NONE, TypeError, + "n_mels with value None is not of type , but got .") + test_invalid_input("invalid n_mels parameter value", 100, 200, "None", TypeError, + "norm with value None is not of type , but got .") + + +if __name__ == "__main__": + test_create_dct_none() + test_create_dct_ortho() + test_createdct_invalid_input() diff --git a/tests/ut/python/dataset/test_datasets_ag_news.py b/tests/ut/python/dataset/test_datasets_ag_news.py index f6db27b39cf..7ea6260a376 100644 --- a/tests/ut/python/dataset/test_datasets_ag_news.py +++ b/tests/ut/python/dataset/test_datasets_ag_news.py @@ -1,163 +1,163 @@ -# Copyright 2021-2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -import mindspore.dataset as ds - -FILE_DIR = '../data/dataset/testAGNews' - - -def test_ag_news_dataset_basic(): - """ - Feature: Test AG News Dataset. - Description: Read data from a single file. - Expectation: The data is processed successfully. - """ - buffer = [] - data = ds.AGNewsDataset(FILE_DIR, usage='all', shuffle=False) - data = data.repeat(2) - data = data.skip(2) - for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): - buffer.append(d) - assert len(buffer) == 8 - - -def test_ag_news_dataset_one_file(): - """ - Feature: Test AG News Dataset. - Description: Read data from a single file. - Expectation: The data is processed successfully. - """ - data = ds.AGNewsDataset(FILE_DIR, usage='test', shuffle=False) - buffer = [] - for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): - buffer.append(d) - assert len(buffer) == 2 - - -def test_ag_news_dataset_all_file(): - """ - Feature: Test AG News Dataset(usage=all). - Description: Read train data and test data. - Expectation: The data is processed successfully. - """ - buffer = [] - data = ds.AGNewsDataset(FILE_DIR, usage='all', shuffle=False) - for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): - buffer.append(d) - assert len(buffer) == 5 - - -def test_ag_news_dataset_num_samples(): - """ - Feature: Test AG News Dataset. - Description: Read data from a single file. - Expectation: The data is processed successfully. - """ - data = ds.AGNewsDataset(FILE_DIR, usage='all', num_samples=4, shuffle=False) - count = 0 - for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): - count += 1 - assert count == 4 - - -def test_ag_news_dataset_distribution(): - """ - Feature: Test AG News Dataset. - Description: Read data from a single file. - Expectation: The data is processed successfully. - """ - data = ds.AGNewsDataset(FILE_DIR, usage='test', shuffle=False, num_shards=2, shard_id=0) - count = 0 - for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): - count += 1 - assert count == 1 - - -def test_ag_news_dataset_quoted(): - """ - Feature: Test get the AG News Dataset. - Description: Read AGNewsDataset data and get data. - Expectation: The data is processed successfully. - """ - data = ds.AGNewsDataset(FILE_DIR, usage='test', shuffle=False) - buffer = [] - for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): - buffer.extend([d['index'], - d['title'], - d['description']]) - assert buffer == ["3", "Background of the selection", - "In this day and age, the internet is growing rapidly, " - "the total number of connected devices is increasing and " - "we are entering the era of big data.", - "4", "Related technologies", - "\"Leaflet is the leading open source JavaScript library " - "for mobile-friendly interactive maps.\""] - - -def test_ag_news_dataset_size(): - """ - Feature: Test Getters. - Description: Test get_dataset_size of AG News dataset. - Expectation: The data is processed successfully. - """ - data = ds.AGNewsDataset(FILE_DIR, usage='test', shuffle=False) - assert data.get_dataset_size() == 2 - - -def test_ag_news_dataset_exception(): - """ - Feature: Error Test. - Description: Test the wrong input. - Expectation: Unable to read in data. - """ - def exception_func(item): - raise Exception("Error occur!") - - try: - data = ds.AGNewsDataset(FILE_DIR, usage='test', shuffle=False) - data = data.map(operations=exception_func, input_columns=["index"], num_parallel_workers=1) - for _ in data.__iter__(): - pass - assert False - except RuntimeError as e: - assert "map operation: [PyFunc] failed. The corresponding data file is" in str(e) - - try: - data = ds.AGNewsDataset(FILE_DIR, usage='test', shuffle=False) - data = data.map(operations=exception_func, input_columns=["title"], num_parallel_workers=1) - for _ in data.__iter__(): - pass - assert False - except RuntimeError as e: - assert "map operation: [PyFunc] failed. The corresponding data file is" in str(e) - - try: - data = ds.AGNewsDataset(FILE_DIR, usage='test', shuffle=False) - data = data.map(operations=exception_func, input_columns=["description"], num_parallel_workers=1) - for _ in data.__iter__(): - pass - assert False - except RuntimeError as e: - assert "map operation: [PyFunc] failed. The corresponding data file is" in str(e) - - -if __name__ == "__main__": - test_ag_news_dataset_basic() - test_ag_news_dataset_one_file() - test_ag_news_dataset_all_file() - test_ag_news_dataset_num_samples() - test_ag_news_dataset_distribution() - test_ag_news_dataset_quoted() - test_ag_news_dataset_size() - test_ag_news_dataset_exception() +# Copyright 2021-2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import mindspore.dataset as ds + +FILE_DIR = '../data/dataset/testAGNews' + + +def test_ag_news_dataset_basic(): + """ + Feature: Test AG News Dataset. + Description: Read data from a single file. + Expectation: The data is processed successfully. + """ + buffer = [] + data = ds.AGNewsDataset(FILE_DIR, usage='all', shuffle=False) + data = data.repeat(2) + data = data.skip(2) + for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): + buffer.append(d) + assert len(buffer) == 8 + + +def test_ag_news_dataset_one_file(): + """ + Feature: Test AG News Dataset. + Description: Read data from a single file. + Expectation: The data is processed successfully. + """ + data = ds.AGNewsDataset(FILE_DIR, usage='test', shuffle=False) + buffer = [] + for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): + buffer.append(d) + assert len(buffer) == 2 + + +def test_ag_news_dataset_all_file(): + """ + Feature: Test AG News Dataset(usage=all). + Description: Read train data and test data. + Expectation: The data is processed successfully. + """ + buffer = [] + data = ds.AGNewsDataset(FILE_DIR, usage='all', shuffle=False) + for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): + buffer.append(d) + assert len(buffer) == 5 + + +def test_ag_news_dataset_num_samples(): + """ + Feature: Test AG News Dataset. + Description: Read data from a single file. + Expectation: The data is processed successfully. + """ + data = ds.AGNewsDataset(FILE_DIR, usage='all', num_samples=4, shuffle=False) + count = 0 + for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): + count += 1 + assert count == 4 + + +def test_ag_news_dataset_distribution(): + """ + Feature: Test AG News Dataset. + Description: Read data from a single file. + Expectation: The data is processed successfully. + """ + data = ds.AGNewsDataset(FILE_DIR, usage='test', shuffle=False, num_shards=2, shard_id=0) + count = 0 + for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): + count += 1 + assert count == 1 + + +def test_ag_news_dataset_quoted(): + """ + Feature: Test get the AG News Dataset. + Description: Read AGNewsDataset data and get data. + Expectation: The data is processed successfully. + """ + data = ds.AGNewsDataset(FILE_DIR, usage='test', shuffle=False) + buffer = [] + for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): + buffer.extend([d['index'], + d['title'], + d['description']]) + assert buffer == ["3", "Background of the selection", + "In this day and age, the internet is growing rapidly, " + "the total number of connected devices is increasing and " + "we are entering the era of big data.", + "4", "Related technologies", + "\"Leaflet is the leading open source JavaScript library " + "for mobile-friendly interactive maps.\""] + + +def test_ag_news_dataset_size(): + """ + Feature: Test Getters. + Description: Test get_dataset_size of AG News dataset. + Expectation: The data is processed successfully. + """ + data = ds.AGNewsDataset(FILE_DIR, usage='test', shuffle=False) + assert data.get_dataset_size() == 2 + + +def test_ag_news_dataset_exception(): + """ + Feature: Error Test. + Description: Test the wrong input. + Expectation: Unable to read in data. + """ + def exception_func(item): + raise Exception("Error occur!") + + try: + data = ds.AGNewsDataset(FILE_DIR, usage='test', shuffle=False) + data = data.map(operations=exception_func, input_columns=["index"], num_parallel_workers=1) + for _ in data.__iter__(): + pass + assert False + except RuntimeError as e: + assert "map operation: [PyFunc] failed. The corresponding data file is" in str(e) + + try: + data = ds.AGNewsDataset(FILE_DIR, usage='test', shuffle=False) + data = data.map(operations=exception_func, input_columns=["title"], num_parallel_workers=1) + for _ in data.__iter__(): + pass + assert False + except RuntimeError as e: + assert "map operation: [PyFunc] failed. The corresponding data file is" in str(e) + + try: + data = ds.AGNewsDataset(FILE_DIR, usage='test', shuffle=False) + data = data.map(operations=exception_func, input_columns=["description"], num_parallel_workers=1) + for _ in data.__iter__(): + pass + assert False + except RuntimeError as e: + assert "map operation: [PyFunc] failed. The corresponding data file is" in str(e) + + +if __name__ == "__main__": + test_ag_news_dataset_basic() + test_ag_news_dataset_one_file() + test_ag_news_dataset_all_file() + test_ag_news_dataset_num_samples() + test_ag_news_dataset_distribution() + test_ag_news_dataset_quoted() + test_ag_news_dataset_size() + test_ag_news_dataset_exception() diff --git a/tests/ut/python/dataset/test_datasets_amazon_review.py b/tests/ut/python/dataset/test_datasets_amazon_review.py index 7abd45b2791..218e0a06031 100644 --- a/tests/ut/python/dataset/test_datasets_amazon_review.py +++ b/tests/ut/python/dataset/test_datasets_amazon_review.py @@ -1,238 +1,238 @@ -# Copyright 2021-2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -import numpy as np - -import mindspore.dataset as ds -import mindspore.dataset.text.transforms as a_c_trans - -POLARITY_DIR = '../data/dataset/testAmazonReview/polarity' -FULL_DIR = '../data/dataset/testAmazonReview/full' - - -def count_unequal_element(data_expected, data_me): - assert data_expected.shape == data_me.shape - assert data_expected == data_me - - -def test_amazon_review_polarity_dataset_basic(): - """ - Feature: Test AmazonReviewPolarity Dataset. - Description: Read data from a single file. - Expectation: The data is processed successfully. - """ - buffer = [] - data = ds.AmazonReviewDataset(POLARITY_DIR, usage='test', shuffle=False) - data = data.repeat(2) - data = data.skip(2) - for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): - buffer.append(d) - assert len(buffer) == 2 - - -def test_amazon_review_full_dataset_basic(): - """ - Feature: Test AmazonReviewFull Dataset. - Description: Read data from a single file. - Expectation: The data is processed successfully. - """ - buffer = [] - data = ds.AmazonReviewDataset(FULL_DIR, usage='test', shuffle=False) - data = data.repeat(2) - data = data.skip(2) - for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): - buffer.append(d) - assert len(buffer) == 4 - - -def test_amazon_review_dataset_quoted(): - """ - Feature: Test get the AmazonReview Dataset. - Description: Read AmazonReviewPolarityDataset data and get data. - Expectation: The data is processed successfully. - """ - data = ds.AmazonReviewDataset(FULL_DIR, usage='test', shuffle=False) - buffer = [] - for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): - buffer.extend([d['label'], - d['title'], - d['content']]) - assert buffer == ["1", "amazing", "unlimited buyback!", - "4", "delightful", "a funny book!", - "3", "Small", "It is a small ball!"] - - -def test_amazon_review_full_dataset_usage_all(): - """ - Feature: Test AmazonReviewPolarity Dataset(usage=all). - Description: Read train data and test data. - Expectation: The data is processed successfully. - """ - buffer = [] - data = ds.AmazonReviewDataset(FULL_DIR, usage='all', shuffle=False) - for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): - buffer.extend([d['label'], - d['title'], - d['content']]) - assert buffer == ["1", "amazing", "unlimited buyback!", - "3", "Satisfied", "good quality.", - "4", "delightful", "a funny book!", - "5", "good", "This is an very good product.", - "3", "Small", "It is a small ball!", - "1", "bad", "work badly."] - - -def test_amazon_review_polarity_dataset_usage_all(): - """ - Feature: Test AmazonReviewPolarityPolarity Dataset(usage=all). - Description: Read train data and test data. - Expectation: The data is processed successfully. - """ - buffer = [] - data = ds.AmazonReviewDataset(POLARITY_DIR, usage='all', shuffle=False) - for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): - buffer.extend([d['label'], - d['title'], - d['content']]) - assert buffer == ["1", "DVD", "It is very good!", - "2", "Great Read", "I thought this book was excellent!", - "2", "Book", "I would read it again lol.", - "1", "Oh dear", "It is so bad!", - "2", "Delicious", "A funny product."] - - -def test_amazon_review_dataset_get_datasetsize(): - """ - Feature: Test Getters. - Description: Test get_dataset_size of AmazonReview dataset. - Expectation: The data is processed successfully. - """ - data = ds.AmazonReviewDataset(FULL_DIR, usage='test', shuffle=False) - size = data.get_dataset_size() - assert size == 3 - - -def test_amazon_review_dataset_distribution(): - """ - Feature: Test AmazonReviewDataset in distribution. - Description: Test in a distributed state. - Expectation: The data is processed successfully. - """ - data = ds.AmazonReviewDataset(FULL_DIR, usage='test', shuffle=False, num_shards=2, shard_id=0) - count = 0 - for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): - count += 1 - assert count == 2 - - -def test_amazon_review_dataset_num_samples(): - """ - Feature: Test AmazonReview Dataset(num_samples = 2). - Description: Test get num_samples. - Expectation: The data is processed successfully. - """ - data = ds.AmazonReviewDataset(FULL_DIR, usage='test', shuffle=False, num_samples=2) - count = 0 - for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): - count += 1 - assert count == 2 - - -def test_amazon_review_dataset_exception(): - """ - Feature: Error Test. - Description: Test the wrong input. - Expectation: Unable to read in data. - """ - def exception_func(item): - raise Exception("Error occur!") - - try: - data = ds.AmazonReviewDataset(FULL_DIR, usage='test', shuffle=False) - data = data.map(operations=exception_func, input_columns=["label"], num_parallel_workers=1) - for _ in data.create_dict_iterator(): - pass - assert False - except RuntimeError as e: - assert "map operation: [PyFunc] failed. The corresponding data file is" in str(e) - - try: - data = ds.AmazonReviewDataset(FULL_DIR, usage='test', shuffle=False) - data = data.map(operations=exception_func, input_columns=["title"], num_parallel_workers=1) - for _ in data.create_dict_iterator(): - pass - assert False - except RuntimeError as e: - assert "map operation: [PyFunc] failed. The corresponding data file is" in str(e) - - try: - data = ds.AmazonReviewDataset(FULL_DIR, usage='test', shuffle=False) - data = data.map(operations=exception_func, input_columns=["content"], num_parallel_workers=1) - for _ in data.create_dict_iterator(): - pass - assert False - except RuntimeError as e: - assert "map operation: [PyFunc] failed. The corresponding data file is" in str(e) - - -def test_amazon_review_dataset_pipeline(): - """ - Feature: AmazonReviewDataset - Description: Test AmazonReviewDataset in pipeline mode - Expectation: The data is processed successfully - """ - expected_columns1 = np.array(["3", "5", "1"]) - dataset = ds.AmazonReviewDataset(FULL_DIR, 'train', shuffle=False) - filter_wikipedia_xml_op = a_c_trans.CaseFold() - dataset = dataset.map(input_columns=["label"], operations=filter_wikipedia_xml_op, num_parallel_workers=1) - i = 0 - for data in dataset.create_dict_iterator(output_numpy=True): - count_unequal_element(np.array(expected_columns1[i]), data['label']) - i += 1 - assert i == 3 - - expected_columns2 = np.array(["satisfied", "good", "bad"]) - dataset = ds.AmazonReviewDataset(FULL_DIR, 'train', shuffle=False) - filter_wikipedia_xml_op = a_c_trans.CaseFold() - dataset = dataset.map(input_columns=["title"], operations=filter_wikipedia_xml_op, num_parallel_workers=1) - i = 0 - for data in dataset.create_dict_iterator(output_numpy=True): - count_unequal_element(np.array(expected_columns2[i]), data['title']) - i += 1 - assert i == 3 - - expected_columns3 = np.array(["good quality.", - "this is an very good product.", - "work badly."]) - dataset = ds.AmazonReviewDataset(FULL_DIR, 'train', shuffle=False) - filter_wikipedia_xml_op = a_c_trans.CaseFold() - dataset = dataset.map(input_columns=["content"], operations=filter_wikipedia_xml_op, num_parallel_workers=1) - i = 0 - for data in dataset.create_dict_iterator(output_numpy=True): - count_unequal_element(np.array(expected_columns3[i]), data['content']) - i += 1 - assert i == 3 - - -if __name__ == "__main__": - test_amazon_review_polarity_dataset_basic() - test_amazon_review_full_dataset_basic() - test_amazon_review_dataset_quoted() - test_amazon_review_full_dataset_usage_all() - test_amazon_review_polarity_dataset_usage_all() - test_amazon_review_dataset_get_datasetsize() - test_amazon_review_dataset_distribution() - test_amazon_review_dataset_num_samples() - test_amazon_review_dataset_exception() - test_amazon_review_dataset_pipeline() +# Copyright 2021-2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import numpy as np + +import mindspore.dataset as ds +import mindspore.dataset.text.transforms as a_c_trans + +POLARITY_DIR = '../data/dataset/testAmazonReview/polarity' +FULL_DIR = '../data/dataset/testAmazonReview/full' + + +def count_unequal_element(data_expected, data_me): + assert data_expected.shape == data_me.shape + assert data_expected == data_me + + +def test_amazon_review_polarity_dataset_basic(): + """ + Feature: Test AmazonReviewPolarity Dataset. + Description: Read data from a single file. + Expectation: The data is processed successfully. + """ + buffer = [] + data = ds.AmazonReviewDataset(POLARITY_DIR, usage='test', shuffle=False) + data = data.repeat(2) + data = data.skip(2) + for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): + buffer.append(d) + assert len(buffer) == 2 + + +def test_amazon_review_full_dataset_basic(): + """ + Feature: Test AmazonReviewFull Dataset. + Description: Read data from a single file. + Expectation: The data is processed successfully. + """ + buffer = [] + data = ds.AmazonReviewDataset(FULL_DIR, usage='test', shuffle=False) + data = data.repeat(2) + data = data.skip(2) + for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): + buffer.append(d) + assert len(buffer) == 4 + + +def test_amazon_review_dataset_quoted(): + """ + Feature: Test get the AmazonReview Dataset. + Description: Read AmazonReviewPolarityDataset data and get data. + Expectation: The data is processed successfully. + """ + data = ds.AmazonReviewDataset(FULL_DIR, usage='test', shuffle=False) + buffer = [] + for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): + buffer.extend([d['label'], + d['title'], + d['content']]) + assert buffer == ["1", "amazing", "unlimited buyback!", + "4", "delightful", "a funny book!", + "3", "Small", "It is a small ball!"] + + +def test_amazon_review_full_dataset_usage_all(): + """ + Feature: Test AmazonReviewPolarity Dataset(usage=all). + Description: Read train data and test data. + Expectation: The data is processed successfully. + """ + buffer = [] + data = ds.AmazonReviewDataset(FULL_DIR, usage='all', shuffle=False) + for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): + buffer.extend([d['label'], + d['title'], + d['content']]) + assert buffer == ["1", "amazing", "unlimited buyback!", + "3", "Satisfied", "good quality.", + "4", "delightful", "a funny book!", + "5", "good", "This is an very good product.", + "3", "Small", "It is a small ball!", + "1", "bad", "work badly."] + + +def test_amazon_review_polarity_dataset_usage_all(): + """ + Feature: Test AmazonReviewPolarityPolarity Dataset(usage=all). + Description: Read train data and test data. + Expectation: The data is processed successfully. + """ + buffer = [] + data = ds.AmazonReviewDataset(POLARITY_DIR, usage='all', shuffle=False) + for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): + buffer.extend([d['label'], + d['title'], + d['content']]) + assert buffer == ["1", "DVD", "It is very good!", + "2", "Great Read", "I thought this book was excellent!", + "2", "Book", "I would read it again lol.", + "1", "Oh dear", "It is so bad!", + "2", "Delicious", "A funny product."] + + +def test_amazon_review_dataset_get_datasetsize(): + """ + Feature: Test Getters. + Description: Test get_dataset_size of AmazonReview dataset. + Expectation: The data is processed successfully. + """ + data = ds.AmazonReviewDataset(FULL_DIR, usage='test', shuffle=False) + size = data.get_dataset_size() + assert size == 3 + + +def test_amazon_review_dataset_distribution(): + """ + Feature: Test AmazonReviewDataset in distribution. + Description: Test in a distributed state. + Expectation: The data is processed successfully. + """ + data = ds.AmazonReviewDataset(FULL_DIR, usage='test', shuffle=False, num_shards=2, shard_id=0) + count = 0 + for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): + count += 1 + assert count == 2 + + +def test_amazon_review_dataset_num_samples(): + """ + Feature: Test AmazonReview Dataset(num_samples = 2). + Description: Test get num_samples. + Expectation: The data is processed successfully. + """ + data = ds.AmazonReviewDataset(FULL_DIR, usage='test', shuffle=False, num_samples=2) + count = 0 + for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): + count += 1 + assert count == 2 + + +def test_amazon_review_dataset_exception(): + """ + Feature: Error Test. + Description: Test the wrong input. + Expectation: Unable to read in data. + """ + def exception_func(item): + raise Exception("Error occur!") + + try: + data = ds.AmazonReviewDataset(FULL_DIR, usage='test', shuffle=False) + data = data.map(operations=exception_func, input_columns=["label"], num_parallel_workers=1) + for _ in data.create_dict_iterator(): + pass + assert False + except RuntimeError as e: + assert "map operation: [PyFunc] failed. The corresponding data file is" in str(e) + + try: + data = ds.AmazonReviewDataset(FULL_DIR, usage='test', shuffle=False) + data = data.map(operations=exception_func, input_columns=["title"], num_parallel_workers=1) + for _ in data.create_dict_iterator(): + pass + assert False + except RuntimeError as e: + assert "map operation: [PyFunc] failed. The corresponding data file is" in str(e) + + try: + data = ds.AmazonReviewDataset(FULL_DIR, usage='test', shuffle=False) + data = data.map(operations=exception_func, input_columns=["content"], num_parallel_workers=1) + for _ in data.create_dict_iterator(): + pass + assert False + except RuntimeError as e: + assert "map operation: [PyFunc] failed. The corresponding data file is" in str(e) + + +def test_amazon_review_dataset_pipeline(): + """ + Feature: AmazonReviewDataset + Description: Test AmazonReviewDataset in pipeline mode + Expectation: The data is processed successfully + """ + expected_columns1 = np.array(["3", "5", "1"]) + dataset = ds.AmazonReviewDataset(FULL_DIR, 'train', shuffle=False) + filter_wikipedia_xml_op = a_c_trans.CaseFold() + dataset = dataset.map(input_columns=["label"], operations=filter_wikipedia_xml_op, num_parallel_workers=1) + i = 0 + for data in dataset.create_dict_iterator(output_numpy=True): + count_unequal_element(np.array(expected_columns1[i]), data['label']) + i += 1 + assert i == 3 + + expected_columns2 = np.array(["satisfied", "good", "bad"]) + dataset = ds.AmazonReviewDataset(FULL_DIR, 'train', shuffle=False) + filter_wikipedia_xml_op = a_c_trans.CaseFold() + dataset = dataset.map(input_columns=["title"], operations=filter_wikipedia_xml_op, num_parallel_workers=1) + i = 0 + for data in dataset.create_dict_iterator(output_numpy=True): + count_unequal_element(np.array(expected_columns2[i]), data['title']) + i += 1 + assert i == 3 + + expected_columns3 = np.array(["good quality.", + "this is an very good product.", + "work badly."]) + dataset = ds.AmazonReviewDataset(FULL_DIR, 'train', shuffle=False) + filter_wikipedia_xml_op = a_c_trans.CaseFold() + dataset = dataset.map(input_columns=["content"], operations=filter_wikipedia_xml_op, num_parallel_workers=1) + i = 0 + for data in dataset.create_dict_iterator(output_numpy=True): + count_unequal_element(np.array(expected_columns3[i]), data['content']) + i += 1 + assert i == 3 + + +if __name__ == "__main__": + test_amazon_review_polarity_dataset_basic() + test_amazon_review_full_dataset_basic() + test_amazon_review_dataset_quoted() + test_amazon_review_full_dataset_usage_all() + test_amazon_review_polarity_dataset_usage_all() + test_amazon_review_dataset_get_datasetsize() + test_amazon_review_dataset_distribution() + test_amazon_review_dataset_num_samples() + test_amazon_review_dataset_exception() + test_amazon_review_dataset_pipeline() diff --git a/tests/ut/python/dataset/test_datasets_cmu_arctic.py b/tests/ut/python/dataset/test_datasets_cmu_arctic.py index 4877524e8df..fe789ba4bd5 100644 --- a/tests/ut/python/dataset/test_datasets_cmu_arctic.py +++ b/tests/ut/python/dataset/test_datasets_cmu_arctic.py @@ -1,234 +1,234 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License foNtest_resr the specific language governing permissions and -# limitations under the License. -# ============================================================================== -""" -Test CMUArctic dataset operations -""" -import numpy as np -import pytest - -import mindspore.dataset as ds -from mindspore import log as logger - - -DATA_DIR = "../data/dataset/testCMUArcticData" - - -def test_cmu_arctic_basic(): - """ - Feature: CMUArcticDataset - Description: Test basic name of CMUArctic - Expectation: The dataset is as expected - """ - logger.info("Test CMUArcticDataset Op") - - # case 1: test loading fault dataset. - data1 = ds.CMUArcticDataset(DATA_DIR) - num_iter1 = 0 - for _ in data1.create_dict_iterator(output_numpy=True, num_epochs=1): - num_iter1 += 1 - assert num_iter1 == 3 - - # case 2: test num_samples. - data2 = ds.CMUArcticDataset(DATA_DIR, num_samples=1) - num_iter2 = 0 - for _ in data2.create_dict_iterator(output_numpy=True, num_epochs=1): - num_iter2 += 1 - assert num_iter2 == 1 - - # case 3: test repeat. - data3 = ds.CMUArcticDataset(DATA_DIR, name="aew", num_samples=3) - data3 = data3.repeat(3) - num_iter3 = 0 - for _ in data3.create_dict_iterator(output_numpy=True, num_epochs=1): - num_iter3 += 1 - assert num_iter3 == 9 - - # case 4: test batch with drop_remainder=False. - data4 = ds.CMUArcticDataset(DATA_DIR, name="aew", num_samples=3) - assert data4.get_dataset_size() == 3 - assert data4.get_batch_size() == 1 - data4 = data4.batch(batch_size=2) # drop_remainder is default to be False. - assert data4.get_dataset_size() == 2 - assert data4.get_batch_size() == 2 - - # case 5: test batch with drop_remainder=True. - data5 = ds.CMUArcticDataset(DATA_DIR, name="aew", num_samples=3) - assert data5.get_dataset_size() == 3 - assert data5.get_batch_size() == 1 - # the rest of incomplete batch will be dropped. - data5 = data5.batch(batch_size=2, drop_remainder=True) - assert data5.get_dataset_size() == 1 - assert data5.get_batch_size() == 2 - - -def test_cmu_arctic_distribute_sampler(): - """ - Feature: CMUArcticDataset - Description: Test CMUArctic dataset with DistributedSampler - Expectation: The results are as expected - """ - logger.info("Test CMUArctic with sharding") - - num_shards = 3 - shard_id = 0 - - data1 = ds.CMUArcticDataset(DATA_DIR, name="aew", num_shards=num_shards, shard_id=shard_id) - count = 0 - for _ in data1.create_dict_iterator(output_numpy=True, num_epochs=1): - count = count + 1 - assert count == 1 - - num_shards = 3 - shard_id = 0 - sampler = ds.DistributedSampler(num_shards, shard_id) - data2 = ds.CMUArcticDataset(DATA_DIR, name="aew", sampler=sampler) - count = 0 - for _ in data2.create_dict_iterator(output_numpy=True, num_epochs=1): - count = count + 1 - assert count == 1 - - -def test_cmu_arctic_exception(): - """ - Feature: CMUArcticDataset - Description: Test error cases for CMUArcticDataset - Expectation: The results are as expected - """ - logger.info("Test error cases for CMUArcticDataset") - - error_msg_1 = "sampler and shuffle cannot be specified at the same time" - with pytest.raises(RuntimeError, match=error_msg_1): - ds.CMUArcticDataset(DATA_DIR, shuffle=False, sampler=ds.PKSampler(3)) - - error_msg_2 = "sampler and sharding cannot be specified at the same time" - with pytest.raises(RuntimeError, match=error_msg_2): - ds.CMUArcticDataset(DATA_DIR, sampler=ds.PKSampler(3), num_shards=2, shard_id=0) - - error_msg_3 = "num_shards is specified and currently requires shard_id as well" - with pytest.raises(RuntimeError, match=error_msg_3): - ds.CMUArcticDataset(DATA_DIR, num_shards=10) - - error_msg_4 = "shard_id is specified but num_shards is not" - with pytest.raises(RuntimeError, match=error_msg_4): - ds.CMUArcticDataset(DATA_DIR, shard_id=0) - - error_msg_5 = "Input shard_id is not within the required interval" - with pytest.raises(ValueError, match=error_msg_5): - ds.CMUArcticDataset(DATA_DIR, num_shards=5, shard_id=-1) - with pytest.raises(ValueError, match=error_msg_5): - ds.CMUArcticDataset(DATA_DIR, num_shards=5, shard_id=5) - with pytest.raises(ValueError, match=error_msg_5): - ds.CMUArcticDataset(DATA_DIR, num_shards=2, shard_id=5) - - error_msg_6 = "num_parallel_workers exceeds" - with pytest.raises(ValueError, match=error_msg_6): - ds.CMUArcticDataset(DATA_DIR, shuffle=False, num_parallel_workers=0) - with pytest.raises(ValueError, match=error_msg_6): - ds.CMUArcticDataset(DATA_DIR, shuffle=False, num_parallel_workers=256) - with pytest.raises(ValueError, match=error_msg_6): - ds.CMUArcticDataset(DATA_DIR, shuffle=False, num_parallel_workers=-2) - - error_msg_7 = "Argument shard_id" - with pytest.raises(TypeError, match=error_msg_7): - ds.CMUArcticDataset(DATA_DIR, num_shards=2, shard_id="0") - - def exception_func(item): - raise Exception("Error occur!") - - error_msg_8 = "The corresponding data file is" - with pytest.raises(RuntimeError, match=error_msg_8): - data = ds.CMUArcticDataset(DATA_DIR) - data = data.map(operations=exception_func, input_columns=["waveform"], num_parallel_workers=1) - for _ in data.create_dict_iterator(output_numpy=True, num_epochs=1): - pass - - -def test_cmu_arctic_sequential_sampler(): - """ - Feature: CMUArcticDataset - Description: Test CMUArcticDataset with SequentialSampler - Expectation: The results are as expected - """ - logger.info("Test CMUArcticDataset Op with SequentialSampler") - - num_samples = 2 - sampler = ds.SequentialSampler(num_samples=num_samples) - data1 = ds.CMUArcticDataset(DATA_DIR, name="aew", sampler=sampler) - data2 = ds.CMUArcticDataset(DATA_DIR, name="aew", shuffle=False, num_samples=num_samples) - - utterance_id_expected = ['a0001', 'a0002'] - utterance_id_list1, utterance_id_list2 = [], [] - - sample_rate_expected = [16000, 16000] - sample_rate_list1, sample_rate_list2 = [], [] - - transcript_expected = ['Dog.', 'Cat.'] - transcript_list1, transcript_list2 = [], [] - num_iter = 0 - for item1, item2 in zip(data1.create_dict_iterator(output_numpy=True, num_epochs=1), - data2.create_dict_iterator(output_numpy=True, num_epochs=1)): - transcript_list1.append(item1["transcript"]) - transcript_list2.append(item2["transcript"]) - sample_rate_list1.append(item1["sample_rate"]) - sample_rate_list2.append(item2["sample_rate"]) - utterance_id_list1.append(item1["utterance_id"]) - utterance_id_list2.append(item2["utterance_id"]) - num_iter += 1 - - np.testing.assert_array_equal(transcript_list1, transcript_expected) - np.testing.assert_array_equal(transcript_list2, transcript_expected) - np.testing.assert_array_equal(utterance_id_list1, utterance_id_expected) - np.testing.assert_array_equal(utterance_id_list2, utterance_id_expected) - np.testing.assert_array_equal(sample_rate_list1, sample_rate_expected) - np.testing.assert_array_equal(sample_rate_list2, sample_rate_expected) - assert num_iter == num_samples - - -def test_cmu_arctic_name(): - """ - Feature: CMUArcticDataset - Description: Test CMUArcticDataset name - Expectation: The results are as expected - """ - logger.info("Test CMUArcticDataset name") - - def test_config(name, cmu_arctic_path=None): - cmu_arctic_path = DATA_DIR if cmu_arctic_path is None else cmu_arctic_path - try: - data = ds.CMUArcticDataset(cmu_arctic_path, name=name, shuffle=False) - num_rows = 0 - for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): - num_rows += 1 - except (ValueError, TypeError, RuntimeError) as e: - return str(e) - return num_rows - - assert test_config("aew") == 3 - assert "Input name is not within the valid set of ['aew', 'ahw', 'aup', 'awb', 'axb', 'bdl', 'clb', 'eey', "\ - "'fem', 'gka', 'jmk', 'ksp', 'ljm', 'lnh', 'rms', 'rxr', 'slp', 'slt']." in test_config("invalid") - assert "Argument name with value ['list'] is not of type []" in test_config(["list"]) - - all_files_path = None - if all_files_path is not None: - assert test_config("aew", all_files_path) == 3 - assert ds.cmu_arcticDataset(all_files_path, name="aew").get_dataset_size() == 3 - - -if __name__ == '__main__': - test_cmu_arctic_basic() - test_cmu_arctic_distribute_sampler() - test_cmu_arctic_exception() - test_cmu_arctic_sequential_sampler() - test_cmu_arctic_name() +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License foNtest_resr the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Test CMUArctic dataset operations +""" +import numpy as np +import pytest + +import mindspore.dataset as ds +from mindspore import log as logger + + +DATA_DIR = "../data/dataset/testCMUArcticData" + + +def test_cmu_arctic_basic(): + """ + Feature: CMUArcticDataset + Description: Test basic name of CMUArctic + Expectation: The dataset is as expected + """ + logger.info("Test CMUArcticDataset Op") + + # case 1: test loading fault dataset. + data1 = ds.CMUArcticDataset(DATA_DIR) + num_iter1 = 0 + for _ in data1.create_dict_iterator(output_numpy=True, num_epochs=1): + num_iter1 += 1 + assert num_iter1 == 3 + + # case 2: test num_samples. + data2 = ds.CMUArcticDataset(DATA_DIR, num_samples=1) + num_iter2 = 0 + for _ in data2.create_dict_iterator(output_numpy=True, num_epochs=1): + num_iter2 += 1 + assert num_iter2 == 1 + + # case 3: test repeat. + data3 = ds.CMUArcticDataset(DATA_DIR, name="aew", num_samples=3) + data3 = data3.repeat(3) + num_iter3 = 0 + for _ in data3.create_dict_iterator(output_numpy=True, num_epochs=1): + num_iter3 += 1 + assert num_iter3 == 9 + + # case 4: test batch with drop_remainder=False. + data4 = ds.CMUArcticDataset(DATA_DIR, name="aew", num_samples=3) + assert data4.get_dataset_size() == 3 + assert data4.get_batch_size() == 1 + data4 = data4.batch(batch_size=2) # drop_remainder is default to be False. + assert data4.get_dataset_size() == 2 + assert data4.get_batch_size() == 2 + + # case 5: test batch with drop_remainder=True. + data5 = ds.CMUArcticDataset(DATA_DIR, name="aew", num_samples=3) + assert data5.get_dataset_size() == 3 + assert data5.get_batch_size() == 1 + # the rest of incomplete batch will be dropped. + data5 = data5.batch(batch_size=2, drop_remainder=True) + assert data5.get_dataset_size() == 1 + assert data5.get_batch_size() == 2 + + +def test_cmu_arctic_distribute_sampler(): + """ + Feature: CMUArcticDataset + Description: Test CMUArctic dataset with DistributedSampler + Expectation: The results are as expected + """ + logger.info("Test CMUArctic with sharding") + + num_shards = 3 + shard_id = 0 + + data1 = ds.CMUArcticDataset(DATA_DIR, name="aew", num_shards=num_shards, shard_id=shard_id) + count = 0 + for _ in data1.create_dict_iterator(output_numpy=True, num_epochs=1): + count = count + 1 + assert count == 1 + + num_shards = 3 + shard_id = 0 + sampler = ds.DistributedSampler(num_shards, shard_id) + data2 = ds.CMUArcticDataset(DATA_DIR, name="aew", sampler=sampler) + count = 0 + for _ in data2.create_dict_iterator(output_numpy=True, num_epochs=1): + count = count + 1 + assert count == 1 + + +def test_cmu_arctic_exception(): + """ + Feature: CMUArcticDataset + Description: Test error cases for CMUArcticDataset + Expectation: The results are as expected + """ + logger.info("Test error cases for CMUArcticDataset") + + error_msg_1 = "sampler and shuffle cannot be specified at the same time" + with pytest.raises(RuntimeError, match=error_msg_1): + ds.CMUArcticDataset(DATA_DIR, shuffle=False, sampler=ds.PKSampler(3)) + + error_msg_2 = "sampler and sharding cannot be specified at the same time" + with pytest.raises(RuntimeError, match=error_msg_2): + ds.CMUArcticDataset(DATA_DIR, sampler=ds.PKSampler(3), num_shards=2, shard_id=0) + + error_msg_3 = "num_shards is specified and currently requires shard_id as well" + with pytest.raises(RuntimeError, match=error_msg_3): + ds.CMUArcticDataset(DATA_DIR, num_shards=10) + + error_msg_4 = "shard_id is specified but num_shards is not" + with pytest.raises(RuntimeError, match=error_msg_4): + ds.CMUArcticDataset(DATA_DIR, shard_id=0) + + error_msg_5 = "Input shard_id is not within the required interval" + with pytest.raises(ValueError, match=error_msg_5): + ds.CMUArcticDataset(DATA_DIR, num_shards=5, shard_id=-1) + with pytest.raises(ValueError, match=error_msg_5): + ds.CMUArcticDataset(DATA_DIR, num_shards=5, shard_id=5) + with pytest.raises(ValueError, match=error_msg_5): + ds.CMUArcticDataset(DATA_DIR, num_shards=2, shard_id=5) + + error_msg_6 = "num_parallel_workers exceeds" + with pytest.raises(ValueError, match=error_msg_6): + ds.CMUArcticDataset(DATA_DIR, shuffle=False, num_parallel_workers=0) + with pytest.raises(ValueError, match=error_msg_6): + ds.CMUArcticDataset(DATA_DIR, shuffle=False, num_parallel_workers=256) + with pytest.raises(ValueError, match=error_msg_6): + ds.CMUArcticDataset(DATA_DIR, shuffle=False, num_parallel_workers=-2) + + error_msg_7 = "Argument shard_id" + with pytest.raises(TypeError, match=error_msg_7): + ds.CMUArcticDataset(DATA_DIR, num_shards=2, shard_id="0") + + def exception_func(item): + raise Exception("Error occur!") + + error_msg_8 = "The corresponding data file is" + with pytest.raises(RuntimeError, match=error_msg_8): + data = ds.CMUArcticDataset(DATA_DIR) + data = data.map(operations=exception_func, input_columns=["waveform"], num_parallel_workers=1) + for _ in data.create_dict_iterator(output_numpy=True, num_epochs=1): + pass + + +def test_cmu_arctic_sequential_sampler(): + """ + Feature: CMUArcticDataset + Description: Test CMUArcticDataset with SequentialSampler + Expectation: The results are as expected + """ + logger.info("Test CMUArcticDataset Op with SequentialSampler") + + num_samples = 2 + sampler = ds.SequentialSampler(num_samples=num_samples) + data1 = ds.CMUArcticDataset(DATA_DIR, name="aew", sampler=sampler) + data2 = ds.CMUArcticDataset(DATA_DIR, name="aew", shuffle=False, num_samples=num_samples) + + utterance_id_expected = ['a0001', 'a0002'] + utterance_id_list1, utterance_id_list2 = [], [] + + sample_rate_expected = [16000, 16000] + sample_rate_list1, sample_rate_list2 = [], [] + + transcript_expected = ['Dog.', 'Cat.'] + transcript_list1, transcript_list2 = [], [] + num_iter = 0 + for item1, item2 in zip(data1.create_dict_iterator(output_numpy=True, num_epochs=1), + data2.create_dict_iterator(output_numpy=True, num_epochs=1)): + transcript_list1.append(item1["transcript"]) + transcript_list2.append(item2["transcript"]) + sample_rate_list1.append(item1["sample_rate"]) + sample_rate_list2.append(item2["sample_rate"]) + utterance_id_list1.append(item1["utterance_id"]) + utterance_id_list2.append(item2["utterance_id"]) + num_iter += 1 + + np.testing.assert_array_equal(transcript_list1, transcript_expected) + np.testing.assert_array_equal(transcript_list2, transcript_expected) + np.testing.assert_array_equal(utterance_id_list1, utterance_id_expected) + np.testing.assert_array_equal(utterance_id_list2, utterance_id_expected) + np.testing.assert_array_equal(sample_rate_list1, sample_rate_expected) + np.testing.assert_array_equal(sample_rate_list2, sample_rate_expected) + assert num_iter == num_samples + + +def test_cmu_arctic_name(): + """ + Feature: CMUArcticDataset + Description: Test CMUArcticDataset name + Expectation: The results are as expected + """ + logger.info("Test CMUArcticDataset name") + + def test_config(name, cmu_arctic_path=None): + cmu_arctic_path = DATA_DIR if cmu_arctic_path is None else cmu_arctic_path + try: + data = ds.CMUArcticDataset(cmu_arctic_path, name=name, shuffle=False) + num_rows = 0 + for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): + num_rows += 1 + except (ValueError, TypeError, RuntimeError) as e: + return str(e) + return num_rows + + assert test_config("aew") == 3 + assert "Input name is not within the valid set of ['aew', 'ahw', 'aup', 'awb', 'axb', 'bdl', 'clb', 'eey', "\ + "'fem', 'gka', 'jmk', 'ksp', 'ljm', 'lnh', 'rms', 'rxr', 'slp', 'slt']." in test_config("invalid") + assert "Argument name with value ['list'] is not of type []" in test_config(["list"]) + + all_files_path = None + if all_files_path is not None: + assert test_config("aew", all_files_path) == 3 + assert ds.cmu_arcticDataset(all_files_path, name="aew").get_dataset_size() == 3 + + +if __name__ == '__main__': + test_cmu_arctic_basic() + test_cmu_arctic_distribute_sampler() + test_cmu_arctic_exception() + test_cmu_arctic_sequential_sampler() + test_cmu_arctic_name() diff --git a/tests/ut/python/dataset/test_datasets_flowers102.py b/tests/ut/python/dataset/test_datasets_flowers102.py index 40ecb53cdc1..63b8c2f2a96 100644 --- a/tests/ut/python/dataset/test_datasets_flowers102.py +++ b/tests/ut/python/dataset/test_datasets_flowers102.py @@ -1,362 +1,362 @@ -# Copyright 2021-2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -""" -Test Flowers102 dataset operations -""" -import os - -import matplotlib.pyplot as plt -import numpy as np -import pytest -from PIL import Image -from scipy.io import loadmat - -import mindspore.dataset as ds -import mindspore.dataset.vision as c_vision -from mindspore import log as logger - -DATA_DIR = "../data/dataset/testFlowers102Dataset" -WRONG_DIR = "../data/dataset/testMnistData" - - -def load_flowers102(path, usage): - """ - load Flowers102 data - """ - assert usage in ["train", "valid", "test", "all"] - - imagelabels = (loadmat(os.path.join(path, "imagelabels.mat"))["labels"][0] - 1).astype(np.uint32) - split = loadmat(os.path.join(path, "setid.mat")) - if usage == 'train': - indices = split["trnid"][0].tolist() - elif usage == 'test': - indices = split["tstid"][0].tolist() - elif usage == 'valid': - indices = split["valid"][0].tolist() - elif usage == 'all': - indices = split["trnid"][0].tolist() - indices += split["tstid"][0].tolist() - indices += split["valid"][0].tolist() - - image_paths = [os.path.join(path, "jpg", "image_" + str(index).zfill(5) + ".jpg") for index in indices] - segmentation_paths = [os.path.join(path, "segmim", "segmim_" + str(index).zfill(5) + ".jpg") for index in indices] - images = [np.asarray(Image.open(path).convert("RGB")) for path in image_paths] - segmentations = [np.asarray(Image.open(path).convert("RGB")) for path in segmentation_paths] - labels = [imagelabels[index - 1] for index in indices] - - return images, segmentations, labels - - -def visualize_dataset(images, labels): - """ - Helper function to visualize the dataset samples - """ - num_samples = len(images) - for i in range(num_samples): - plt.subplot(1, num_samples, i + 1) - plt.imshow(images[i].squeeze()) - plt.title(labels[i]) - plt.show() - - -def test_flowers102_content_check(): - """ - Feature: Flowers102Dataset - Description: Test Flowers102Dataset image readings with content check - Expectation: The dataset is processed as expected - """ - logger.info("Test Flowers102Dataset Op with content check") - all_data = ds.Flowers102Dataset(DATA_DIR, task="Segmentation", usage="all", - num_samples=6, decode=True, shuffle=False) - images, segmentations, labels = load_flowers102(DATA_DIR, "all") - num_iter = 0 - # in this example, each dictionary has keys "image" and "label" - for i, data in enumerate(all_data.create_dict_iterator(num_epochs=1, output_numpy=True)): - np.testing.assert_array_equal(data["image"], images[i]) - np.testing.assert_array_equal(data["segmentation"], segmentations[i]) - np.testing.assert_array_equal(data["label"], labels[i]) - num_iter += 1 - assert num_iter == 6 - - train_data = ds.Flowers102Dataset(DATA_DIR, task="Segmentation", usage="train", - num_samples=2, decode=True, shuffle=False) - images, segmentations, labels = load_flowers102(DATA_DIR, "train") - num_iter = 0 - # in this example, each dictionary has keys "image" and "label" - for i, data in enumerate(train_data.create_dict_iterator(num_epochs=1, output_numpy=True)): - np.testing.assert_array_equal(data["image"], images[i]) - np.testing.assert_array_equal(data["segmentation"], segmentations[i]) - np.testing.assert_array_equal(data["label"], labels[i]) - num_iter += 1 - assert num_iter == 2 - - test_data = ds.Flowers102Dataset(DATA_DIR, task="Segmentation", usage="test", - num_samples=2, decode=True, shuffle=False) - images, segmentations, labels = load_flowers102(DATA_DIR, "test") - num_iter = 0 - # in this example, each dictionary has keys "image" and "label" - for i, data in enumerate(test_data.create_dict_iterator(num_epochs=1, output_numpy=True)): - np.testing.assert_array_equal(data["image"], images[i]) - np.testing.assert_array_equal(data["segmentation"], segmentations[i]) - np.testing.assert_array_equal(data["label"], labels[i]) - num_iter += 1 - assert num_iter == 2 - - val_data = ds.Flowers102Dataset(DATA_DIR, task="Segmentation", usage="valid", - num_samples=2, decode=True, shuffle=False) - images, segmentations, labels = load_flowers102(DATA_DIR, "valid") - num_iter = 0 - # in this example, each dictionary has keys "image" and "label" - for i, data in enumerate(val_data.create_dict_iterator(num_epochs=1, output_numpy=True)): - np.testing.assert_array_equal(data["image"], images[i]) - np.testing.assert_array_equal(data["segmentation"], segmentations[i]) - np.testing.assert_array_equal(data["label"], labels[i]) - num_iter += 1 - assert num_iter == 2 - - -def test_flowers102_basic(): - """ - Feature: Flowers102Dataset - Description: Test basic read on Flowers102Dataset - Expectation: The dataset is processed as expected - """ - logger.info("Test Flowers102Dataset Op") - - # case 1: test decode - all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=False, shuffle=False) - all_data_1 = all_data.map(operations=[c_vision.Decode()], input_columns=["image"], num_parallel_workers=1) - all_data_2 = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, shuffle=False) - - num_iter = 0 - for item1, item2 in zip(all_data_1.create_dict_iterator(num_epochs=1, output_numpy=True), - all_data_2.create_dict_iterator(num_epochs=1, output_numpy=True)): - np.testing.assert_array_equal(item1["label"], item2["label"]) - num_iter += 1 - assert num_iter == 6 - - # case 2: test num_samples - all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, num_samples=4) - num_iter = 0 - for _ in all_data.create_dict_iterator(num_epochs=1): - num_iter += 1 - assert num_iter == 4 - - # case 3: test repeat - all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, num_samples=4) - all_data = all_data.repeat(5) - num_iter = 0 - for _ in all_data.create_dict_iterator(num_epochs=1): - num_iter += 1 - assert num_iter == 20 - - # case 3: test get_dataset_size, resize and batch - all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=False, num_samples=4) - all_data = all_data.map(operations=[c_vision.Decode(), c_vision.Resize((224, 224))], input_columns=["image"], - num_parallel_workers=1) - - assert all_data.get_dataset_size() == 4 - assert all_data.get_batch_size() == 1 - all_data = all_data.batch(batch_size=3) # drop_remainder is default to be False - assert all_data.get_batch_size() == 3 - assert all_data.get_dataset_size() == 2 - - num_iter = 0 - for _ in all_data.create_dict_iterator(num_epochs=1): - num_iter += 1 - assert num_iter == 2 - - # case 4: test get_class_indexing - all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=False, num_samples=4) - class_indexing = all_data.get_class_indexing() - assert class_indexing["pink primrose"] == 0 - assert class_indexing["blackberry lily"] == 101 - - -def test_flowers102_sequential_sampler(): - """ - Feature: Flowers102Dataset - Description: Test Flowers102Dataset with SequentialSampler - Expectation: The dataset is processed as expected - """ - logger.info("Test Flowers102Dataset Op with SequentialSampler") - num_samples = 4 - sampler = ds.SequentialSampler(num_samples=num_samples) - all_data_1 = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", - decode=True, sampler=sampler) - all_data_2 = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", - decode=True, shuffle=False, num_samples=num_samples) - label_list_1, label_list_2 = [], [] - num_iter = 0 - for item1, item2 in zip(all_data_1.create_dict_iterator(num_epochs=1), - all_data_2.create_dict_iterator(num_epochs=1)): - label_list_1.append(item1["label"].asnumpy()) - label_list_2.append(item2["label"].asnumpy()) - num_iter += 1 - np.testing.assert_array_equal(label_list_1, label_list_2) - assert num_iter == num_samples - - -def test_flowers102_exception(): - """ - Feature: Flowers102Dataset - Description: Test error cases on Flowers102Dataset - Expectation: Correct error is thrown as expected - """ - logger.info("Test error cases for Flowers102Dataset") - error_msg_1 = "sampler and shuffle cannot be specified at the same time" - with pytest.raises(RuntimeError, match=error_msg_1): - ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", shuffle=False, - decode=True, sampler=ds.SequentialSampler(1)) - - error_msg_2 = "sampler and sharding cannot be specified at the same time" - with pytest.raises(RuntimeError, match=error_msg_2): - ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", sampler=ds.SequentialSampler(1), - decode=True, num_shards=2, shard_id=0) - - error_msg_3 = "num_shards is specified and currently requires shard_id as well" - with pytest.raises(RuntimeError, match=error_msg_3): - ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, num_shards=10) - - error_msg_4 = "shard_id is specified but num_shards is not" - with pytest.raises(RuntimeError, match=error_msg_4): - ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, shard_id=0) - - error_msg_5 = "Input shard_id is not within the required interval" - with pytest.raises(ValueError, match=error_msg_5): - ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, num_shards=5, shard_id=-1) - - with pytest.raises(ValueError, match=error_msg_5): - ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, num_shards=5, shard_id=5) - - with pytest.raises(ValueError, match=error_msg_5): - ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, num_shards=2, shard_id=5) - - error_msg_6 = "num_parallel_workers exceeds" - with pytest.raises(ValueError, match=error_msg_6): - ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, - shuffle=False, num_parallel_workers=0) - with pytest.raises(ValueError, match=error_msg_6): - ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, - shuffle=False, num_parallel_workers=256) - with pytest.raises(ValueError, match=error_msg_6): - ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, - shuffle=False, num_parallel_workers=-2) - - error_msg_7 = "Argument shard_id" - with pytest.raises(TypeError, match=error_msg_7): - ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, num_shards=2, shard_id="0") - - - error_msg_8 = "does not exist or is not a directory or permission denied!" - with pytest.raises(ValueError, match=error_msg_8): - all_data = ds.Flowers102Dataset(WRONG_DIR, task="Classification", usage="all", decode=True) - for _ in all_data.create_dict_iterator(num_epochs=1): - pass - - error_msg_9 = "is not of type" - with pytest.raises(TypeError, match=error_msg_9): - all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=123) - for _ in all_data.create_dict_iterator(num_epochs=1): - pass - - -def test_flowers102_visualize(plot=False): - """ - Feature: Flowers102Dataset - Description: Test Flowers102Dataset visualization for results - Expectation: The dataset is processed as expected - """ - logger.info("Test Flowers102Dataset visualization") - - all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", num_samples=4, - decode=True, shuffle=False) - num_iter = 0 - image_list, label_list = [], [] - for item in all_data.create_dict_iterator(num_epochs=1, output_numpy=True): - image = item["image"] - label = item["label"] - image_list.append(image) - label_list.append("label {}".format(label)) - assert isinstance(image, np.ndarray) - assert len(image.shape) == 3 - assert image.shape[-1] == 3 - assert image.dtype == np.uint8 - assert label.dtype == np.uint32 - num_iter += 1 - assert num_iter == 4 - if plot: - visualize_dataset(image_list, label_list) - - -def test_flowers102_usage(): - """ - Feature: Flowers102Dataset - Description: Test Flowers102Dataset usage flag - Expectation: The dataset is processed as expected - """ - logger.info("Test Flowers102Dataset usage flag") - - def test_config(usage): - try: - data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage=usage, decode=True, shuffle=False) - num_rows = 0 - for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): - num_rows += 1 - except (ValueError, TypeError, RuntimeError) as e: - return str(e) - return num_rows - - assert test_config("all") == 6 - assert test_config("train") == 2 - assert test_config("test") == 2 - assert test_config("valid") == 2 - - assert "usage is not within the valid set of ['train', 'valid', 'test', 'all']" in test_config("invalid") - assert "Argument usage with value ['list'] is not of type []" in test_config(["list"]) - - -def test_flowers102_task(): - """ - Feature: Flowers102Dataset - Description: Test Flowers102Dataset task flag - Expectation: The dataset is processed as expected - """ - logger.info("Test Flowers102Dataset task flag") - - def test_config(task): - try: - data = ds.Flowers102Dataset(DATA_DIR, task=task, usage="all", decode=True, shuffle=False) - num_rows = 0 - for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): - num_rows += 1 - except (ValueError, TypeError, RuntimeError) as e: - return str(e) - return num_rows - - assert test_config("Classification") == 6 - assert test_config("Segmentation") == 6 - - assert "Input task is not within the valid set of ['Classification', 'Segmentation']" in test_config("invalid") - assert "Argument task with value ['list'] is not of type []" in test_config(["list"]) - -if __name__ == '__main__': - test_flowers102_content_check() - test_flowers102_basic() - test_flowers102_sequential_sampler() - test_flowers102_exception() - test_flowers102_visualize(plot=True) - test_flowers102_usage() - test_flowers102_task() +# Copyright 2021-2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Test Flowers102 dataset operations +""" +import os + +import matplotlib.pyplot as plt +import numpy as np +import pytest +from PIL import Image +from scipy.io import loadmat + +import mindspore.dataset as ds +import mindspore.dataset.vision as c_vision +from mindspore import log as logger + +DATA_DIR = "../data/dataset/testFlowers102Dataset" +WRONG_DIR = "../data/dataset/testMnistData" + + +def load_flowers102(path, usage): + """ + load Flowers102 data + """ + assert usage in ["train", "valid", "test", "all"] + + imagelabels = (loadmat(os.path.join(path, "imagelabels.mat"))["labels"][0] - 1).astype(np.uint32) + split = loadmat(os.path.join(path, "setid.mat")) + if usage == 'train': + indices = split["trnid"][0].tolist() + elif usage == 'test': + indices = split["tstid"][0].tolist() + elif usage == 'valid': + indices = split["valid"][0].tolist() + elif usage == 'all': + indices = split["trnid"][0].tolist() + indices += split["tstid"][0].tolist() + indices += split["valid"][0].tolist() + + image_paths = [os.path.join(path, "jpg", "image_" + str(index).zfill(5) + ".jpg") for index in indices] + segmentation_paths = [os.path.join(path, "segmim", "segmim_" + str(index).zfill(5) + ".jpg") for index in indices] + images = [np.asarray(Image.open(path).convert("RGB")) for path in image_paths] + segmentations = [np.asarray(Image.open(path).convert("RGB")) for path in segmentation_paths] + labels = [imagelabels[index - 1] for index in indices] + + return images, segmentations, labels + + +def visualize_dataset(images, labels): + """ + Helper function to visualize the dataset samples + """ + num_samples = len(images) + for i in range(num_samples): + plt.subplot(1, num_samples, i + 1) + plt.imshow(images[i].squeeze()) + plt.title(labels[i]) + plt.show() + + +def test_flowers102_content_check(): + """ + Feature: Flowers102Dataset + Description: Test Flowers102Dataset image readings with content check + Expectation: The dataset is processed as expected + """ + logger.info("Test Flowers102Dataset Op with content check") + all_data = ds.Flowers102Dataset(DATA_DIR, task="Segmentation", usage="all", + num_samples=6, decode=True, shuffle=False) + images, segmentations, labels = load_flowers102(DATA_DIR, "all") + num_iter = 0 + # in this example, each dictionary has keys "image" and "label" + for i, data in enumerate(all_data.create_dict_iterator(num_epochs=1, output_numpy=True)): + np.testing.assert_array_equal(data["image"], images[i]) + np.testing.assert_array_equal(data["segmentation"], segmentations[i]) + np.testing.assert_array_equal(data["label"], labels[i]) + num_iter += 1 + assert num_iter == 6 + + train_data = ds.Flowers102Dataset(DATA_DIR, task="Segmentation", usage="train", + num_samples=2, decode=True, shuffle=False) + images, segmentations, labels = load_flowers102(DATA_DIR, "train") + num_iter = 0 + # in this example, each dictionary has keys "image" and "label" + for i, data in enumerate(train_data.create_dict_iterator(num_epochs=1, output_numpy=True)): + np.testing.assert_array_equal(data["image"], images[i]) + np.testing.assert_array_equal(data["segmentation"], segmentations[i]) + np.testing.assert_array_equal(data["label"], labels[i]) + num_iter += 1 + assert num_iter == 2 + + test_data = ds.Flowers102Dataset(DATA_DIR, task="Segmentation", usage="test", + num_samples=2, decode=True, shuffle=False) + images, segmentations, labels = load_flowers102(DATA_DIR, "test") + num_iter = 0 + # in this example, each dictionary has keys "image" and "label" + for i, data in enumerate(test_data.create_dict_iterator(num_epochs=1, output_numpy=True)): + np.testing.assert_array_equal(data["image"], images[i]) + np.testing.assert_array_equal(data["segmentation"], segmentations[i]) + np.testing.assert_array_equal(data["label"], labels[i]) + num_iter += 1 + assert num_iter == 2 + + val_data = ds.Flowers102Dataset(DATA_DIR, task="Segmentation", usage="valid", + num_samples=2, decode=True, shuffle=False) + images, segmentations, labels = load_flowers102(DATA_DIR, "valid") + num_iter = 0 + # in this example, each dictionary has keys "image" and "label" + for i, data in enumerate(val_data.create_dict_iterator(num_epochs=1, output_numpy=True)): + np.testing.assert_array_equal(data["image"], images[i]) + np.testing.assert_array_equal(data["segmentation"], segmentations[i]) + np.testing.assert_array_equal(data["label"], labels[i]) + num_iter += 1 + assert num_iter == 2 + + +def test_flowers102_basic(): + """ + Feature: Flowers102Dataset + Description: Test basic read on Flowers102Dataset + Expectation: The dataset is processed as expected + """ + logger.info("Test Flowers102Dataset Op") + + # case 1: test decode + all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=False, shuffle=False) + all_data_1 = all_data.map(operations=[c_vision.Decode()], input_columns=["image"], num_parallel_workers=1) + all_data_2 = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, shuffle=False) + + num_iter = 0 + for item1, item2 in zip(all_data_1.create_dict_iterator(num_epochs=1, output_numpy=True), + all_data_2.create_dict_iterator(num_epochs=1, output_numpy=True)): + np.testing.assert_array_equal(item1["label"], item2["label"]) + num_iter += 1 + assert num_iter == 6 + + # case 2: test num_samples + all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, num_samples=4) + num_iter = 0 + for _ in all_data.create_dict_iterator(num_epochs=1): + num_iter += 1 + assert num_iter == 4 + + # case 3: test repeat + all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, num_samples=4) + all_data = all_data.repeat(5) + num_iter = 0 + for _ in all_data.create_dict_iterator(num_epochs=1): + num_iter += 1 + assert num_iter == 20 + + # case 3: test get_dataset_size, resize and batch + all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=False, num_samples=4) + all_data = all_data.map(operations=[c_vision.Decode(), c_vision.Resize((224, 224))], input_columns=["image"], + num_parallel_workers=1) + + assert all_data.get_dataset_size() == 4 + assert all_data.get_batch_size() == 1 + all_data = all_data.batch(batch_size=3) # drop_remainder is default to be False + assert all_data.get_batch_size() == 3 + assert all_data.get_dataset_size() == 2 + + num_iter = 0 + for _ in all_data.create_dict_iterator(num_epochs=1): + num_iter += 1 + assert num_iter == 2 + + # case 4: test get_class_indexing + all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=False, num_samples=4) + class_indexing = all_data.get_class_indexing() + assert class_indexing["pink primrose"] == 0 + assert class_indexing["blackberry lily"] == 101 + + +def test_flowers102_sequential_sampler(): + """ + Feature: Flowers102Dataset + Description: Test Flowers102Dataset with SequentialSampler + Expectation: The dataset is processed as expected + """ + logger.info("Test Flowers102Dataset Op with SequentialSampler") + num_samples = 4 + sampler = ds.SequentialSampler(num_samples=num_samples) + all_data_1 = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", + decode=True, sampler=sampler) + all_data_2 = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", + decode=True, shuffle=False, num_samples=num_samples) + label_list_1, label_list_2 = [], [] + num_iter = 0 + for item1, item2 in zip(all_data_1.create_dict_iterator(num_epochs=1), + all_data_2.create_dict_iterator(num_epochs=1)): + label_list_1.append(item1["label"].asnumpy()) + label_list_2.append(item2["label"].asnumpy()) + num_iter += 1 + np.testing.assert_array_equal(label_list_1, label_list_2) + assert num_iter == num_samples + + +def test_flowers102_exception(): + """ + Feature: Flowers102Dataset + Description: Test error cases on Flowers102Dataset + Expectation: Correct error is thrown as expected + """ + logger.info("Test error cases for Flowers102Dataset") + error_msg_1 = "sampler and shuffle cannot be specified at the same time" + with pytest.raises(RuntimeError, match=error_msg_1): + ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", shuffle=False, + decode=True, sampler=ds.SequentialSampler(1)) + + error_msg_2 = "sampler and sharding cannot be specified at the same time" + with pytest.raises(RuntimeError, match=error_msg_2): + ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", sampler=ds.SequentialSampler(1), + decode=True, num_shards=2, shard_id=0) + + error_msg_3 = "num_shards is specified and currently requires shard_id as well" + with pytest.raises(RuntimeError, match=error_msg_3): + ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, num_shards=10) + + error_msg_4 = "shard_id is specified but num_shards is not" + with pytest.raises(RuntimeError, match=error_msg_4): + ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, shard_id=0) + + error_msg_5 = "Input shard_id is not within the required interval" + with pytest.raises(ValueError, match=error_msg_5): + ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, num_shards=5, shard_id=-1) + + with pytest.raises(ValueError, match=error_msg_5): + ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, num_shards=5, shard_id=5) + + with pytest.raises(ValueError, match=error_msg_5): + ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, num_shards=2, shard_id=5) + + error_msg_6 = "num_parallel_workers exceeds" + with pytest.raises(ValueError, match=error_msg_6): + ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, + shuffle=False, num_parallel_workers=0) + with pytest.raises(ValueError, match=error_msg_6): + ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, + shuffle=False, num_parallel_workers=256) + with pytest.raises(ValueError, match=error_msg_6): + ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, + shuffle=False, num_parallel_workers=-2) + + error_msg_7 = "Argument shard_id" + with pytest.raises(TypeError, match=error_msg_7): + ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, num_shards=2, shard_id="0") + + + error_msg_8 = "does not exist or is not a directory or permission denied!" + with pytest.raises(ValueError, match=error_msg_8): + all_data = ds.Flowers102Dataset(WRONG_DIR, task="Classification", usage="all", decode=True) + for _ in all_data.create_dict_iterator(num_epochs=1): + pass + + error_msg_9 = "is not of type" + with pytest.raises(TypeError, match=error_msg_9): + all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=123) + for _ in all_data.create_dict_iterator(num_epochs=1): + pass + + +def test_flowers102_visualize(plot=False): + """ + Feature: Flowers102Dataset + Description: Test Flowers102Dataset visualization for results + Expectation: The dataset is processed as expected + """ + logger.info("Test Flowers102Dataset visualization") + + all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", num_samples=4, + decode=True, shuffle=False) + num_iter = 0 + image_list, label_list = [], [] + for item in all_data.create_dict_iterator(num_epochs=1, output_numpy=True): + image = item["image"] + label = item["label"] + image_list.append(image) + label_list.append("label {}".format(label)) + assert isinstance(image, np.ndarray) + assert len(image.shape) == 3 + assert image.shape[-1] == 3 + assert image.dtype == np.uint8 + assert label.dtype == np.uint32 + num_iter += 1 + assert num_iter == 4 + if plot: + visualize_dataset(image_list, label_list) + + +def test_flowers102_usage(): + """ + Feature: Flowers102Dataset + Description: Test Flowers102Dataset usage flag + Expectation: The dataset is processed as expected + """ + logger.info("Test Flowers102Dataset usage flag") + + def test_config(usage): + try: + data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage=usage, decode=True, shuffle=False) + num_rows = 0 + for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): + num_rows += 1 + except (ValueError, TypeError, RuntimeError) as e: + return str(e) + return num_rows + + assert test_config("all") == 6 + assert test_config("train") == 2 + assert test_config("test") == 2 + assert test_config("valid") == 2 + + assert "usage is not within the valid set of ['train', 'valid', 'test', 'all']" in test_config("invalid") + assert "Argument usage with value ['list'] is not of type []" in test_config(["list"]) + + +def test_flowers102_task(): + """ + Feature: Flowers102Dataset + Description: Test Flowers102Dataset task flag + Expectation: The dataset is processed as expected + """ + logger.info("Test Flowers102Dataset task flag") + + def test_config(task): + try: + data = ds.Flowers102Dataset(DATA_DIR, task=task, usage="all", decode=True, shuffle=False) + num_rows = 0 + for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): + num_rows += 1 + except (ValueError, TypeError, RuntimeError) as e: + return str(e) + return num_rows + + assert test_config("Classification") == 6 + assert test_config("Segmentation") == 6 + + assert "Input task is not within the valid set of ['Classification', 'Segmentation']" in test_config("invalid") + assert "Argument task with value ['list'] is not of type []" in test_config(["list"]) + +if __name__ == '__main__': + test_flowers102_content_check() + test_flowers102_basic() + test_flowers102_sequential_sampler() + test_flowers102_exception() + test_flowers102_visualize(plot=True) + test_flowers102_usage() + test_flowers102_task() diff --git a/tests/ut/python/dataset/test_datasets_gtzan.py b/tests/ut/python/dataset/test_datasets_gtzan.py index f93c1d8a208..b2ecf616bac 100644 --- a/tests/ut/python/dataset/test_datasets_gtzan.py +++ b/tests/ut/python/dataset/test_datasets_gtzan.py @@ -1,223 +1,223 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License foNtest_resr the specific language governing permissions and -# limitations under the License. -# ============================================================================== -""" -Test Gtzan dataset operations. -""" -import numpy as np -import pytest - -import mindspore.dataset as ds -from mindspore import log as logger - -DATA_DIR = "../data/dataset/testGTZANData" - - -def test_gtzan_basic(): - """ - Feature: GTZANDataset - Description: Test basic usage of GTZAN - Expectation: The dataset is as expected - """ - logger.info("Test GTZANDataset Op") - - # case 1: test loading whole dataset. - data1 = ds.GTZANDataset(DATA_DIR) - num_iter1 = 0 - for _ in data1.create_dict_iterator(output_numpy=True, num_epochs=1): - num_iter1 += 1 - assert num_iter1 == 3 - - # case 2: test num_samples. - data2 = ds.GTZANDataset(DATA_DIR, num_samples=2) - num_iter2 = 0 - for _ in data2.create_dict_iterator(output_numpy=True, num_epochs=1): - num_iter2 += 1 - assert num_iter2 == 2 - - # case 3: test repeat. - data3 = ds.GTZANDataset(DATA_DIR, num_samples=2) - data3 = data3.repeat(5) - num_iter3 = 0 - for _ in data3.create_dict_iterator(output_numpy=True, num_epochs=1): - num_iter3 += 1 - assert num_iter3 == 10 - - # case 4: test batch with drop_remainder=False. - data4 = ds.GTZANDataset(DATA_DIR, num_samples=3) - assert data4.get_dataset_size() == 3 - assert data4.get_batch_size() == 1 - data4 = data4.batch(batch_size=2) # drop_remainder is default to be False. - assert data4.get_dataset_size() == 2 - assert data4.get_batch_size() == 2 - - # case 5: test batch with drop_remainder=True. - data5 = ds.GTZANDataset(DATA_DIR, num_samples=3) - assert data5.get_dataset_size() == 3 - assert data5.get_batch_size() == 1 - # the rest of incomplete batch will be dropped. - data5 = data5.batch(batch_size=2, drop_remainder=True) - assert data5.get_dataset_size() == 1 - assert data5.get_batch_size() == 2 - - -def test_gtzan_distribute_sampler(): - """ - Feature: GTZANDataset - Description: Test GTZAN dataset with DistributedSampler - Expectation: The results are as expected - """ - logger.info("Test GTZAN with DistributedSampler") - - label_list1, label_list2 = [], [] - num_shards = 3 - shard_id = 0 - - data1 = ds.GTZANDataset(DATA_DIR, usage="all", num_shards=num_shards, shard_id=shard_id) - count = 0 - for item1 in data1.create_dict_iterator(output_numpy=True, num_epochs=1): - label_list1.append(item1["label"]) - count = count + 1 - assert count == 1 - - num_shards = 3 - shard_id = 0 - sampler = ds.DistributedSampler(num_shards, shard_id) - data2 = ds.GTZANDataset(DATA_DIR, usage="all", sampler=sampler) - count = 0 - for item2 in data2.create_dict_iterator(output_numpy=True, num_epochs=1): - label_list2.append(item2["label"]) - count = count + 1 - np.testing.assert_array_equal(label_list1, label_list2) - assert count == 1 - - -def test_gtzan_exception(): - """ - Feature: GTZANDataset - Description: Test error cases for GTZANDataset - Expectation: The results are as expected - """ - logger.info("Test error cases for GTZANDataset") - error_msg_1 = "sampler and shuffle cannot be specified at the same time" - with pytest.raises(RuntimeError, match=error_msg_1): - ds.GTZANDataset(DATA_DIR, shuffle=False, sampler=ds.PKSampler(3)) - - error_msg_2 = "sampler and sharding cannot be specified at the same time" - with pytest.raises(RuntimeError, match=error_msg_2): - ds.GTZANDataset(DATA_DIR, sampler=ds.PKSampler(3), - num_shards=2, shard_id=0) - - error_msg_3 = "num_shards is specified and currently requires shard_id as well" - with pytest.raises(RuntimeError, match=error_msg_3): - ds.GTZANDataset(DATA_DIR, num_shards=10) - - error_msg_4 = "shard_id is specified but num_shards is not" - with pytest.raises(RuntimeError, match=error_msg_4): - ds.GTZANDataset(DATA_DIR, shard_id=0) - - error_msg_5 = "Input shard_id is not within the required interval" - with pytest.raises(ValueError, match=error_msg_5): - ds.GTZANDataset(DATA_DIR, num_shards=5, shard_id=-1) - with pytest.raises(ValueError, match=error_msg_5): - ds.GTZANDataset(DATA_DIR, num_shards=5, shard_id=5) - with pytest.raises(ValueError, match=error_msg_5): - ds.GTZANDataset(DATA_DIR, num_shards=2, shard_id=5) - - error_msg_6 = "num_parallel_workers exceeds" - with pytest.raises(ValueError, match=error_msg_6): - ds.GTZANDataset(DATA_DIR, shuffle=False, num_parallel_workers=0) - with pytest.raises(ValueError, match=error_msg_6): - ds.GTZANDataset(DATA_DIR, shuffle=False, num_parallel_workers=256) - with pytest.raises(ValueError, match=error_msg_6): - ds.GTZANDataset(DATA_DIR, shuffle=False, num_parallel_workers=-2) - - error_msg_7 = "Argument shard_id" - with pytest.raises(TypeError, match=error_msg_7): - ds.GTZANDataset(DATA_DIR, num_shards=2, shard_id="0") - - def exception_func(item): - raise Exception("Error occur!") - - error_msg_8 = "The corresponding data file is" - - with pytest.raises(RuntimeError, match=error_msg_8): - data = ds.GTZANDataset(DATA_DIR) - data = data.map(operations=exception_func, input_columns=["waveform"], num_parallel_workers=1) - for _ in data.create_dict_iterator(output_numpy=True, num_epochs=1): - pass - - -def test_gtzan_sequential_sampler(): - """ - Feature: GTZANDataset - Description: Test GTZANDataset with SequentialSampler - Expectation: The results are as expected - """ - logger.info("Test GTZANDataset Op with SequentialSampler") - num_samples = 2 - sampler = ds.SequentialSampler(num_samples=num_samples) - data1 = ds.GTZANDataset(DATA_DIR, sampler=sampler) - data2 = ds.GTZANDataset(DATA_DIR, shuffle=False, num_samples=num_samples) - label_list1, label_list2 = [], [] - num_iter = 0 - for item1, item2 in zip(data1.create_dict_iterator(output_numpy=True, num_epochs=1), - data2.create_dict_iterator(output_numpy=True, num_epochs=1)): - label_list1.append(item1["label"]) - label_list2.append(item2["label"]) - num_iter += 1 - np.testing.assert_array_equal(label_list1, label_list2) - assert num_iter == num_samples - - -def test_gtzan_usage(): - """ - Feature: GTZANDataset - Description: Test GTZANDataset usage - Expectation: The results are as expected - """ - logger.info("Test GTZANDataset usage") - - def test_config(usage, gtzan_path=None): - gtzan_path = DATA_DIR if gtzan_path is None else gtzan_path - try: - data = ds.GTZANDataset(gtzan_path, usage=usage, shuffle=False) - num_rows = 0 - for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): - num_rows += 1 - except (ValueError, TypeError, RuntimeError) as e: - return str(e) - return num_rows - - assert test_config("valid") == 3 - assert test_config("all") == 3 - assert "usage is not within the valid set of ['train', 'valid', 'test', 'all']" in test_config("invalid") - assert "Argument usage with value ['list'] is not of type []" in test_config(["list"]) - - # change this directory to the folder that contains all gtzan files. - all_files_path = None - # the following tests on the entire datasets. - if all_files_path is not None: - assert test_config("train", all_files_path) == 3 - assert test_config("valid", all_files_path) == 3 - assert ds.GTZANDataset(all_files_path, usage="train").get_dataset_size() == 3 - assert ds.GTZANDataset(all_files_path, usage="valid").get_dataset_size() == 3 - - -if __name__ == '__main__': - test_gtzan_basic() - test_gtzan_distribute_sampler() - test_gtzan_exception() - test_gtzan_sequential_sampler() - test_gtzan_usage() +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License foNtest_resr the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Test Gtzan dataset operations. +""" +import numpy as np +import pytest + +import mindspore.dataset as ds +from mindspore import log as logger + +DATA_DIR = "../data/dataset/testGTZANData" + + +def test_gtzan_basic(): + """ + Feature: GTZANDataset + Description: Test basic usage of GTZAN + Expectation: The dataset is as expected + """ + logger.info("Test GTZANDataset Op") + + # case 1: test loading whole dataset. + data1 = ds.GTZANDataset(DATA_DIR) + num_iter1 = 0 + for _ in data1.create_dict_iterator(output_numpy=True, num_epochs=1): + num_iter1 += 1 + assert num_iter1 == 3 + + # case 2: test num_samples. + data2 = ds.GTZANDataset(DATA_DIR, num_samples=2) + num_iter2 = 0 + for _ in data2.create_dict_iterator(output_numpy=True, num_epochs=1): + num_iter2 += 1 + assert num_iter2 == 2 + + # case 3: test repeat. + data3 = ds.GTZANDataset(DATA_DIR, num_samples=2) + data3 = data3.repeat(5) + num_iter3 = 0 + for _ in data3.create_dict_iterator(output_numpy=True, num_epochs=1): + num_iter3 += 1 + assert num_iter3 == 10 + + # case 4: test batch with drop_remainder=False. + data4 = ds.GTZANDataset(DATA_DIR, num_samples=3) + assert data4.get_dataset_size() == 3 + assert data4.get_batch_size() == 1 + data4 = data4.batch(batch_size=2) # drop_remainder is default to be False. + assert data4.get_dataset_size() == 2 + assert data4.get_batch_size() == 2 + + # case 5: test batch with drop_remainder=True. + data5 = ds.GTZANDataset(DATA_DIR, num_samples=3) + assert data5.get_dataset_size() == 3 + assert data5.get_batch_size() == 1 + # the rest of incomplete batch will be dropped. + data5 = data5.batch(batch_size=2, drop_remainder=True) + assert data5.get_dataset_size() == 1 + assert data5.get_batch_size() == 2 + + +def test_gtzan_distribute_sampler(): + """ + Feature: GTZANDataset + Description: Test GTZAN dataset with DistributedSampler + Expectation: The results are as expected + """ + logger.info("Test GTZAN with DistributedSampler") + + label_list1, label_list2 = [], [] + num_shards = 3 + shard_id = 0 + + data1 = ds.GTZANDataset(DATA_DIR, usage="all", num_shards=num_shards, shard_id=shard_id) + count = 0 + for item1 in data1.create_dict_iterator(output_numpy=True, num_epochs=1): + label_list1.append(item1["label"]) + count = count + 1 + assert count == 1 + + num_shards = 3 + shard_id = 0 + sampler = ds.DistributedSampler(num_shards, shard_id) + data2 = ds.GTZANDataset(DATA_DIR, usage="all", sampler=sampler) + count = 0 + for item2 in data2.create_dict_iterator(output_numpy=True, num_epochs=1): + label_list2.append(item2["label"]) + count = count + 1 + np.testing.assert_array_equal(label_list1, label_list2) + assert count == 1 + + +def test_gtzan_exception(): + """ + Feature: GTZANDataset + Description: Test error cases for GTZANDataset + Expectation: The results are as expected + """ + logger.info("Test error cases for GTZANDataset") + error_msg_1 = "sampler and shuffle cannot be specified at the same time" + with pytest.raises(RuntimeError, match=error_msg_1): + ds.GTZANDataset(DATA_DIR, shuffle=False, sampler=ds.PKSampler(3)) + + error_msg_2 = "sampler and sharding cannot be specified at the same time" + with pytest.raises(RuntimeError, match=error_msg_2): + ds.GTZANDataset(DATA_DIR, sampler=ds.PKSampler(3), + num_shards=2, shard_id=0) + + error_msg_3 = "num_shards is specified and currently requires shard_id as well" + with pytest.raises(RuntimeError, match=error_msg_3): + ds.GTZANDataset(DATA_DIR, num_shards=10) + + error_msg_4 = "shard_id is specified but num_shards is not" + with pytest.raises(RuntimeError, match=error_msg_4): + ds.GTZANDataset(DATA_DIR, shard_id=0) + + error_msg_5 = "Input shard_id is not within the required interval" + with pytest.raises(ValueError, match=error_msg_5): + ds.GTZANDataset(DATA_DIR, num_shards=5, shard_id=-1) + with pytest.raises(ValueError, match=error_msg_5): + ds.GTZANDataset(DATA_DIR, num_shards=5, shard_id=5) + with pytest.raises(ValueError, match=error_msg_5): + ds.GTZANDataset(DATA_DIR, num_shards=2, shard_id=5) + + error_msg_6 = "num_parallel_workers exceeds" + with pytest.raises(ValueError, match=error_msg_6): + ds.GTZANDataset(DATA_DIR, shuffle=False, num_parallel_workers=0) + with pytest.raises(ValueError, match=error_msg_6): + ds.GTZANDataset(DATA_DIR, shuffle=False, num_parallel_workers=256) + with pytest.raises(ValueError, match=error_msg_6): + ds.GTZANDataset(DATA_DIR, shuffle=False, num_parallel_workers=-2) + + error_msg_7 = "Argument shard_id" + with pytest.raises(TypeError, match=error_msg_7): + ds.GTZANDataset(DATA_DIR, num_shards=2, shard_id="0") + + def exception_func(item): + raise Exception("Error occur!") + + error_msg_8 = "The corresponding data file is" + + with pytest.raises(RuntimeError, match=error_msg_8): + data = ds.GTZANDataset(DATA_DIR) + data = data.map(operations=exception_func, input_columns=["waveform"], num_parallel_workers=1) + for _ in data.create_dict_iterator(output_numpy=True, num_epochs=1): + pass + + +def test_gtzan_sequential_sampler(): + """ + Feature: GTZANDataset + Description: Test GTZANDataset with SequentialSampler + Expectation: The results are as expected + """ + logger.info("Test GTZANDataset Op with SequentialSampler") + num_samples = 2 + sampler = ds.SequentialSampler(num_samples=num_samples) + data1 = ds.GTZANDataset(DATA_DIR, sampler=sampler) + data2 = ds.GTZANDataset(DATA_DIR, shuffle=False, num_samples=num_samples) + label_list1, label_list2 = [], [] + num_iter = 0 + for item1, item2 in zip(data1.create_dict_iterator(output_numpy=True, num_epochs=1), + data2.create_dict_iterator(output_numpy=True, num_epochs=1)): + label_list1.append(item1["label"]) + label_list2.append(item2["label"]) + num_iter += 1 + np.testing.assert_array_equal(label_list1, label_list2) + assert num_iter == num_samples + + +def test_gtzan_usage(): + """ + Feature: GTZANDataset + Description: Test GTZANDataset usage + Expectation: The results are as expected + """ + logger.info("Test GTZANDataset usage") + + def test_config(usage, gtzan_path=None): + gtzan_path = DATA_DIR if gtzan_path is None else gtzan_path + try: + data = ds.GTZANDataset(gtzan_path, usage=usage, shuffle=False) + num_rows = 0 + for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): + num_rows += 1 + except (ValueError, TypeError, RuntimeError) as e: + return str(e) + return num_rows + + assert test_config("valid") == 3 + assert test_config("all") == 3 + assert "usage is not within the valid set of ['train', 'valid', 'test', 'all']" in test_config("invalid") + assert "Argument usage with value ['list'] is not of type []" in test_config(["list"]) + + # change this directory to the folder that contains all gtzan files. + all_files_path = None + # the following tests on the entire datasets. + if all_files_path is not None: + assert test_config("train", all_files_path) == 3 + assert test_config("valid", all_files_path) == 3 + assert ds.GTZANDataset(all_files_path, usage="train").get_dataset_size() == 3 + assert ds.GTZANDataset(all_files_path, usage="valid").get_dataset_size() == 3 + + +if __name__ == '__main__': + test_gtzan_basic() + test_gtzan_distribute_sampler() + test_gtzan_exception() + test_gtzan_sequential_sampler() + test_gtzan_usage() diff --git a/tests/ut/python/dataset/test_datasets_libri_tts.py b/tests/ut/python/dataset/test_datasets_libri_tts.py index 61657d0b38e..9a3dbb73388 100644 --- a/tests/ut/python/dataset/test_datasets_libri_tts.py +++ b/tests/ut/python/dataset/test_datasets_libri_tts.py @@ -1,235 +1,235 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License foNtest_resr the specific language governing permissions and -# limitations under the License. -# ============================================================================== -""" -Test LibriTTS dataset operations -""" -import numpy as np -import pytest - -import mindspore.dataset as ds -from mindspore import log as logger - -DATA_DIR = "../data/dataset/testLibriTTSData" - - -def test_libri_tts_basic(): - """ - Feature: LibriTTSDataset - Description: Test basic usage of LibriTTS - Expectation: The dataset is as expected - """ - logger.info("Test LibriTTSDataset Op") - - # case 1: test loading fault dataset. - data1 = ds.LibriTTSDataset(DATA_DIR) - num_iter1 = 0 - for _ in data1.create_dict_iterator(output_numpy=True, num_epochs=1): - num_iter1 += 1 - assert num_iter1 == 3 - - # case 2: test num_samples. - data2 = ds.LibriTTSDataset(DATA_DIR, num_samples=1) - num_iter2 = 0 - for _ in data2.create_dict_iterator(output_numpy=True, num_epochs=1): - num_iter2 += 1 - assert num_iter2 == 1 - - # case 3: test repeat. - data3 = ds.LibriTTSDataset(DATA_DIR, usage="all", num_samples=3) - data3 = data3.repeat(3) - num_iter3 = 0 - for _ in data3.create_dict_iterator(output_numpy=True, num_epochs=1): - num_iter3 += 1 - assert num_iter3 == 9 - - # case 4: test batch with drop_remainder=False. - data4 = ds.LibriTTSDataset(DATA_DIR, usage="train-clean-100", num_samples=3) - assert data4.get_dataset_size() == 3 - assert data4.get_batch_size() == 1 - data4 = data4.batch(batch_size=2) # drop_remainder is default to be False. - assert data4.get_dataset_size() == 2 - assert data4.get_batch_size() == 2 - - # case 5: test batch with drop_remainder=True. - data5 = ds.LibriTTSDataset(DATA_DIR, usage="train-clean-100", num_samples=3) - assert data5.get_dataset_size() == 3 - assert data5.get_batch_size() == 1 - # the rest of incomplete batch will be dropped. - data5 = data5.batch(batch_size=2, drop_remainder=True) - assert data5.get_dataset_size() == 1 - assert data5.get_batch_size() == 2 - - -def test_libri_tts_distribute_sampler(): - """ - Feature: LibriTTSDataset - Description: Test LibriTTS dataset with DisributeSampler - Expectation: The results are as expected - """ - logger.info("Test LibriTTS with sharding") - - list1, list2 = [], [] - num_shards = 3 - shard_id = 0 - - data1 = ds.LibriTTSDataset(DATA_DIR, usage="all", num_shards=num_shards, shard_id=shard_id) - count = 0 - for item1 in data1.create_dict_iterator(output_numpy=True, num_epochs=1): - list1.append(item1["original_text"]) - count = count + 1 - assert count == 1 - - num_shards = 3 - shard_id = 0 - sampler = ds.DistributedSampler(num_shards, shard_id) - data2 = ds.LibriTTSDataset(DATA_DIR, usage="train-clean-100", sampler=sampler) - count = 0 - for item2 in data2.create_dict_iterator(output_numpy=True, num_epochs=1): - list2.append(item2["original_text"]) - count = count + 1 - assert count == 1 - - -def test_libri_tts_exception(): - """ - Feature: LibriTTSDataset - Description: Test error cases for LibriTTSDataset - Expectation: The results are as expected - """ - logger.info("Test error cases for LibriTTSDataset") - - error_msg_1 = "sampler and shuffle cannot be specified at the same time" - with pytest.raises(RuntimeError, match=error_msg_1): - ds.LibriTTSDataset(DATA_DIR, shuffle=False, sampler=ds.PKSampler(3)) - - error_msg_2 = "sampler and sharding cannot be specified at the same time" - with pytest.raises(RuntimeError, match=error_msg_2): - ds.LibriTTSDataset(DATA_DIR, sampler=ds.PKSampler(3), num_shards=2, shard_id=0) - - error_msg_3 = "num_shards is specified and currently requires shard_id as well" - with pytest.raises(RuntimeError, match=error_msg_3): - ds.LibriTTSDataset(DATA_DIR, num_shards=10) - - error_msg_4 = "shard_id is specified but num_shards is not" - with pytest.raises(RuntimeError, match=error_msg_4): - ds.LibriTTSDataset(DATA_DIR, shard_id=0) - - error_msg_5 = "Input shard_id is not within the required interval" - with pytest.raises(ValueError, match=error_msg_5): - ds.LibriTTSDataset(DATA_DIR, num_shards=5, shard_id=-1) - with pytest.raises(ValueError, match=error_msg_5): - ds.LibriTTSDataset(DATA_DIR, num_shards=5, shard_id=5) - with pytest.raises(ValueError, match=error_msg_5): - ds.LibriTTSDataset(DATA_DIR, num_shards=2, shard_id=5) - - error_msg_6 = "num_parallel_workers exceeds" - with pytest.raises(ValueError, match=error_msg_6): - ds.LibriTTSDataset(DATA_DIR, shuffle=False, num_parallel_workers=0) - with pytest.raises(ValueError, match=error_msg_6): - ds.LibriTTSDataset(DATA_DIR, shuffle=False, num_parallel_workers=256) - with pytest.raises(ValueError, match=error_msg_6): - ds.LibriTTSDataset(DATA_DIR, shuffle=False, num_parallel_workers=-2) - - error_msg_7 = "Argument shard_id" - with pytest.raises(TypeError, match=error_msg_7): - ds.LibriTTSDataset(DATA_DIR, num_shards=2, shard_id="0") - - def exception_func(item): - raise Exception("Error occur!") - - error_msg_8 = "The corresponding data file is" - with pytest.raises(RuntimeError, match=error_msg_8): - data = ds.LibriTTSDataset(DATA_DIR) - data = data.map(operations=exception_func, input_columns=["waveform"], num_parallel_workers=1) - for _ in data.create_dict_iterator(output_numpy=True, num_epochs=1): - pass - - -def test_libri_tts_sequential_sampler(): - """ - Feature: LibriTTSDataset - Description: Test LibriTTSDataset with SequentialSampler - Expectation: The results are as expected - """ - logger.info("Test LibriTTSDataset Op with SequentialSampler") - - num_samples = 2 - sampler = ds.SequentialSampler(num_samples=num_samples) - data1 = ds.LibriTTSDataset(DATA_DIR, usage="train-clean-100", sampler=sampler) - data2 = ds.LibriTTSDataset(DATA_DIR, usage="train-clean-100", shuffle=False, num_samples=num_samples) - list1, list2 = [], [] - list_expected = [24000, 'good morning', 'Good morning', 2506, 11267, '2506_11267_000001_000000', - 24000, 'good afternoon', 'Good afternoon', 2506, 11267, '2506_11267_000002_000000'] - - num_iter = 0 - for item1, item2 in zip(data1.create_dict_iterator(output_numpy=True, num_epochs=1), - data2.create_dict_iterator(output_numpy=True, num_epochs=1)): - list1.append(item1["sample_rate"]) - list2.append(item2["sample_rate"]) - list1.append(item1["original_text"]) - list2.append(item2["original_text"]) - list1.append(item1["normalized_text"]) - list2.append(item2["normalized_text"]) - list1.append(item1["speaker_id"]) - list2.append(item2["speaker_id"]) - list1.append(item1["chapter_id"]) - list2.append(item2["chapter_id"]) - list1.append(item1["utterance_id"]) - list2.append(item2["utterance_id"]) - num_iter += 1 - np.testing.assert_array_equal(list1, list_expected) - np.testing.assert_array_equal(list2, list_expected) - assert num_iter == num_samples - - -def test_libri_tts_usage(): - """ - Feature: LibriTTSDataset - Description: Test LibriTTSDataset usage - Expectation: The results are as expected - """ - logger.info("Test LibriTTSDataset usage") - - def test_config(usage, libri_tts_path=None): - libri_tts_path = DATA_DIR if libri_tts_path is None else libri_tts_path - try: - data = ds.LibriTTSDataset(libri_tts_path, usage=usage, shuffle=False) - num_rows = 0 - for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): - num_rows += 1 - except (ValueError, TypeError, RuntimeError) as e: - return str(e) - return num_rows - - assert test_config("all") == 3 - assert test_config("train-clean-100") == 3 - assert "Input usage is not within the valid set of ['dev-clean', 'dev-other', 'test-clean', 'test-other', " \ - "'train-clean-100', 'train-clean-360', 'train-other-500', 'all']." in test_config("invalid") - assert "Argument usage with value ['list'] is not of type []" in test_config(["list"]) - - all_files_path = None - if all_files_path is not None: - assert test_config("train-clean-100", all_files_path) == 3 - assert ds.LibriTTSDataset(all_files_path, usage="train-clean-100").get_dataset_size() == 3 - assert test_config("all", all_files_path) == 3 - assert ds.LibriTTSDataset(all_files_path, usage="all").get_dataset_size() == 3 - - -if __name__ == '__main__': - test_libri_tts_basic() - test_libri_tts_distribute_sampler() - test_libri_tts_exception() - test_libri_tts_sequential_sampler() - test_libri_tts_usage() +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License foNtest_resr the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Test LibriTTS dataset operations +""" +import numpy as np +import pytest + +import mindspore.dataset as ds +from mindspore import log as logger + +DATA_DIR = "../data/dataset/testLibriTTSData" + + +def test_libri_tts_basic(): + """ + Feature: LibriTTSDataset + Description: Test basic usage of LibriTTS + Expectation: The dataset is as expected + """ + logger.info("Test LibriTTSDataset Op") + + # case 1: test loading fault dataset. + data1 = ds.LibriTTSDataset(DATA_DIR) + num_iter1 = 0 + for _ in data1.create_dict_iterator(output_numpy=True, num_epochs=1): + num_iter1 += 1 + assert num_iter1 == 3 + + # case 2: test num_samples. + data2 = ds.LibriTTSDataset(DATA_DIR, num_samples=1) + num_iter2 = 0 + for _ in data2.create_dict_iterator(output_numpy=True, num_epochs=1): + num_iter2 += 1 + assert num_iter2 == 1 + + # case 3: test repeat. + data3 = ds.LibriTTSDataset(DATA_DIR, usage="all", num_samples=3) + data3 = data3.repeat(3) + num_iter3 = 0 + for _ in data3.create_dict_iterator(output_numpy=True, num_epochs=1): + num_iter3 += 1 + assert num_iter3 == 9 + + # case 4: test batch with drop_remainder=False. + data4 = ds.LibriTTSDataset(DATA_DIR, usage="train-clean-100", num_samples=3) + assert data4.get_dataset_size() == 3 + assert data4.get_batch_size() == 1 + data4 = data4.batch(batch_size=2) # drop_remainder is default to be False. + assert data4.get_dataset_size() == 2 + assert data4.get_batch_size() == 2 + + # case 5: test batch with drop_remainder=True. + data5 = ds.LibriTTSDataset(DATA_DIR, usage="train-clean-100", num_samples=3) + assert data5.get_dataset_size() == 3 + assert data5.get_batch_size() == 1 + # the rest of incomplete batch will be dropped. + data5 = data5.batch(batch_size=2, drop_remainder=True) + assert data5.get_dataset_size() == 1 + assert data5.get_batch_size() == 2 + + +def test_libri_tts_distribute_sampler(): + """ + Feature: LibriTTSDataset + Description: Test LibriTTS dataset with DisributeSampler + Expectation: The results are as expected + """ + logger.info("Test LibriTTS with sharding") + + list1, list2 = [], [] + num_shards = 3 + shard_id = 0 + + data1 = ds.LibriTTSDataset(DATA_DIR, usage="all", num_shards=num_shards, shard_id=shard_id) + count = 0 + for item1 in data1.create_dict_iterator(output_numpy=True, num_epochs=1): + list1.append(item1["original_text"]) + count = count + 1 + assert count == 1 + + num_shards = 3 + shard_id = 0 + sampler = ds.DistributedSampler(num_shards, shard_id) + data2 = ds.LibriTTSDataset(DATA_DIR, usage="train-clean-100", sampler=sampler) + count = 0 + for item2 in data2.create_dict_iterator(output_numpy=True, num_epochs=1): + list2.append(item2["original_text"]) + count = count + 1 + assert count == 1 + + +def test_libri_tts_exception(): + """ + Feature: LibriTTSDataset + Description: Test error cases for LibriTTSDataset + Expectation: The results are as expected + """ + logger.info("Test error cases for LibriTTSDataset") + + error_msg_1 = "sampler and shuffle cannot be specified at the same time" + with pytest.raises(RuntimeError, match=error_msg_1): + ds.LibriTTSDataset(DATA_DIR, shuffle=False, sampler=ds.PKSampler(3)) + + error_msg_2 = "sampler and sharding cannot be specified at the same time" + with pytest.raises(RuntimeError, match=error_msg_2): + ds.LibriTTSDataset(DATA_DIR, sampler=ds.PKSampler(3), num_shards=2, shard_id=0) + + error_msg_3 = "num_shards is specified and currently requires shard_id as well" + with pytest.raises(RuntimeError, match=error_msg_3): + ds.LibriTTSDataset(DATA_DIR, num_shards=10) + + error_msg_4 = "shard_id is specified but num_shards is not" + with pytest.raises(RuntimeError, match=error_msg_4): + ds.LibriTTSDataset(DATA_DIR, shard_id=0) + + error_msg_5 = "Input shard_id is not within the required interval" + with pytest.raises(ValueError, match=error_msg_5): + ds.LibriTTSDataset(DATA_DIR, num_shards=5, shard_id=-1) + with pytest.raises(ValueError, match=error_msg_5): + ds.LibriTTSDataset(DATA_DIR, num_shards=5, shard_id=5) + with pytest.raises(ValueError, match=error_msg_5): + ds.LibriTTSDataset(DATA_DIR, num_shards=2, shard_id=5) + + error_msg_6 = "num_parallel_workers exceeds" + with pytest.raises(ValueError, match=error_msg_6): + ds.LibriTTSDataset(DATA_DIR, shuffle=False, num_parallel_workers=0) + with pytest.raises(ValueError, match=error_msg_6): + ds.LibriTTSDataset(DATA_DIR, shuffle=False, num_parallel_workers=256) + with pytest.raises(ValueError, match=error_msg_6): + ds.LibriTTSDataset(DATA_DIR, shuffle=False, num_parallel_workers=-2) + + error_msg_7 = "Argument shard_id" + with pytest.raises(TypeError, match=error_msg_7): + ds.LibriTTSDataset(DATA_DIR, num_shards=2, shard_id="0") + + def exception_func(item): + raise Exception("Error occur!") + + error_msg_8 = "The corresponding data file is" + with pytest.raises(RuntimeError, match=error_msg_8): + data = ds.LibriTTSDataset(DATA_DIR) + data = data.map(operations=exception_func, input_columns=["waveform"], num_parallel_workers=1) + for _ in data.create_dict_iterator(output_numpy=True, num_epochs=1): + pass + + +def test_libri_tts_sequential_sampler(): + """ + Feature: LibriTTSDataset + Description: Test LibriTTSDataset with SequentialSampler + Expectation: The results are as expected + """ + logger.info("Test LibriTTSDataset Op with SequentialSampler") + + num_samples = 2 + sampler = ds.SequentialSampler(num_samples=num_samples) + data1 = ds.LibriTTSDataset(DATA_DIR, usage="train-clean-100", sampler=sampler) + data2 = ds.LibriTTSDataset(DATA_DIR, usage="train-clean-100", shuffle=False, num_samples=num_samples) + list1, list2 = [], [] + list_expected = [24000, 'good morning', 'Good morning', 2506, 11267, '2506_11267_000001_000000', + 24000, 'good afternoon', 'Good afternoon', 2506, 11267, '2506_11267_000002_000000'] + + num_iter = 0 + for item1, item2 in zip(data1.create_dict_iterator(output_numpy=True, num_epochs=1), + data2.create_dict_iterator(output_numpy=True, num_epochs=1)): + list1.append(item1["sample_rate"]) + list2.append(item2["sample_rate"]) + list1.append(item1["original_text"]) + list2.append(item2["original_text"]) + list1.append(item1["normalized_text"]) + list2.append(item2["normalized_text"]) + list1.append(item1["speaker_id"]) + list2.append(item2["speaker_id"]) + list1.append(item1["chapter_id"]) + list2.append(item2["chapter_id"]) + list1.append(item1["utterance_id"]) + list2.append(item2["utterance_id"]) + num_iter += 1 + np.testing.assert_array_equal(list1, list_expected) + np.testing.assert_array_equal(list2, list_expected) + assert num_iter == num_samples + + +def test_libri_tts_usage(): + """ + Feature: LibriTTSDataset + Description: Test LibriTTSDataset usage + Expectation: The results are as expected + """ + logger.info("Test LibriTTSDataset usage") + + def test_config(usage, libri_tts_path=None): + libri_tts_path = DATA_DIR if libri_tts_path is None else libri_tts_path + try: + data = ds.LibriTTSDataset(libri_tts_path, usage=usage, shuffle=False) + num_rows = 0 + for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): + num_rows += 1 + except (ValueError, TypeError, RuntimeError) as e: + return str(e) + return num_rows + + assert test_config("all") == 3 + assert test_config("train-clean-100") == 3 + assert "Input usage is not within the valid set of ['dev-clean', 'dev-other', 'test-clean', 'test-other', " \ + "'train-clean-100', 'train-clean-360', 'train-other-500', 'all']." in test_config("invalid") + assert "Argument usage with value ['list'] is not of type []" in test_config(["list"]) + + all_files_path = None + if all_files_path is not None: + assert test_config("train-clean-100", all_files_path) == 3 + assert ds.LibriTTSDataset(all_files_path, usage="train-clean-100").get_dataset_size() == 3 + assert test_config("all", all_files_path) == 3 + assert ds.LibriTTSDataset(all_files_path, usage="all").get_dataset_size() == 3 + + +if __name__ == '__main__': + test_libri_tts_basic() + test_libri_tts_distribute_sampler() + test_libri_tts_exception() + test_libri_tts_sequential_sampler() + test_libri_tts_usage() diff --git a/tests/ut/python/dataset/test_datasets_obs_mindrecord.py b/tests/ut/python/dataset/test_datasets_obs_mindrecord.py index 29627619523..453368e262d 100644 --- a/tests/ut/python/dataset/test_datasets_obs_mindrecord.py +++ b/tests/ut/python/dataset/test_datasets_obs_mindrecord.py @@ -1,80 +1,80 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -""" -Test OBSMindDataset operations -""" -import pytest - -from mindspore.dataset.engine.datasets_standard_format import OBSMindDataset -from mindspore import log as logger - -DATA_DIR = ["s3://dataset/imagenet0", "s3://dataset/imagenet1"] - - -def test_obs_mindrecord_exception(): - """ - Feature: Test OBSMindDataset. - Description: Invalid input. - Expectation: Raise exception. - """ - - logger.info("Test error cases for MnistDataset") - error_msg_0 = "Argument dataset_files" - with pytest.raises(TypeError, match=error_msg_0): - OBSMindDataset("err_dataset", "https://dummy_site", "dummy_ak", "dummy_sk", "s3://dummy_sync_dir") - - error_msg_0_1 = "Item of dataset files" - with pytest.raises(TypeError, match=error_msg_0_1): - OBSMindDataset([1, 2], "https://dummy_site", "dummy_ak", "dummy_sk", "s3://dummy_sync_dir") - - error_msg_1 = "Argument server" - with pytest.raises(TypeError, match=error_msg_1): - OBSMindDataset(DATA_DIR, 12, "dummy_ak", "dummy_sk", "s3://dummy_sync_dir") - - error_msg_1_1 = "server should" - with pytest.raises(ValueError, match=error_msg_1_1): - OBSMindDataset(DATA_DIR, "ftp://dummy_site", "dummy_ak", "dummy_sk", "s3://dummy_sync_dir") - - error_msg_2 = "Argument ak" - with pytest.raises(TypeError, match=error_msg_2): - OBSMindDataset(DATA_DIR, "https://dummy_site", 12, "dummy_sk", "s3://dummy_sync_dir") - - error_msg_3 = "Argument sk" - with pytest.raises(TypeError, match=error_msg_3): - OBSMindDataset(DATA_DIR, "https://dummy_site", "dummy_ak", 12, "s3://dummy_sync_dir") - - error_msg_4 = "Argument sync_obs_path" - with pytest.raises(TypeError, match=error_msg_4): - OBSMindDataset(DATA_DIR, "https://dummy_site", "dummy_ak", "dummy_sk", 12) - - error_msg_5 = "Input shard_id is not within the required interval" - with pytest.raises(ValueError, match=error_msg_5): - OBSMindDataset(DATA_DIR, "https://dummy_site", "dummy_ak", - "dummy_sk", "s3://dummy_sync_dir", num_shards=2, shard_id=-1) - with pytest.raises(ValueError, match=error_msg_5): - OBSMindDataset(DATA_DIR, "https://dummy_site", "dummy_ak", - "dummy_sk", "s3://dummy_sync_dir", num_shards=4, shard_id=4) - with pytest.raises(ValueError, match=error_msg_5): - OBSMindDataset(DATA_DIR, "https://dummy_site", "dummy_ak", - "dummy_sk", "s3://dummy_sync_dir", num_shards=2, shard_id=4) - - error_msg_7 = "Argument shard_equal_rows" - with pytest.raises(TypeError, match=error_msg_7): - OBSMindDataset(DATA_DIR, "https://dummy_site", "dummy_ak", - "dummy_sk", "s3://dummy_sync_dir", shard_equal_rows=1) - - -if __name__ == '__main__': - test_obs_mindrecord_exception() +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Test OBSMindDataset operations +""" +import pytest + +from mindspore.dataset.engine.datasets_standard_format import OBSMindDataset +from mindspore import log as logger + +DATA_DIR = ["s3://dataset/imagenet0", "s3://dataset/imagenet1"] + + +def test_obs_mindrecord_exception(): + """ + Feature: Test OBSMindDataset. + Description: Invalid input. + Expectation: Raise exception. + """ + + logger.info("Test error cases for MnistDataset") + error_msg_0 = "Argument dataset_files" + with pytest.raises(TypeError, match=error_msg_0): + OBSMindDataset("err_dataset", "https://dummy_site", "dummy_ak", "dummy_sk", "s3://dummy_sync_dir") + + error_msg_0_1 = "Item of dataset files" + with pytest.raises(TypeError, match=error_msg_0_1): + OBSMindDataset([1, 2], "https://dummy_site", "dummy_ak", "dummy_sk", "s3://dummy_sync_dir") + + error_msg_1 = "Argument server" + with pytest.raises(TypeError, match=error_msg_1): + OBSMindDataset(DATA_DIR, 12, "dummy_ak", "dummy_sk", "s3://dummy_sync_dir") + + error_msg_1_1 = "server should" + with pytest.raises(ValueError, match=error_msg_1_1): + OBSMindDataset(DATA_DIR, "ftp://dummy_site", "dummy_ak", "dummy_sk", "s3://dummy_sync_dir") + + error_msg_2 = "Argument ak" + with pytest.raises(TypeError, match=error_msg_2): + OBSMindDataset(DATA_DIR, "https://dummy_site", 12, "dummy_sk", "s3://dummy_sync_dir") + + error_msg_3 = "Argument sk" + with pytest.raises(TypeError, match=error_msg_3): + OBSMindDataset(DATA_DIR, "https://dummy_site", "dummy_ak", 12, "s3://dummy_sync_dir") + + error_msg_4 = "Argument sync_obs_path" + with pytest.raises(TypeError, match=error_msg_4): + OBSMindDataset(DATA_DIR, "https://dummy_site", "dummy_ak", "dummy_sk", 12) + + error_msg_5 = "Input shard_id is not within the required interval" + with pytest.raises(ValueError, match=error_msg_5): + OBSMindDataset(DATA_DIR, "https://dummy_site", "dummy_ak", + "dummy_sk", "s3://dummy_sync_dir", num_shards=2, shard_id=-1) + with pytest.raises(ValueError, match=error_msg_5): + OBSMindDataset(DATA_DIR, "https://dummy_site", "dummy_ak", + "dummy_sk", "s3://dummy_sync_dir", num_shards=4, shard_id=4) + with pytest.raises(ValueError, match=error_msg_5): + OBSMindDataset(DATA_DIR, "https://dummy_site", "dummy_ak", + "dummy_sk", "s3://dummy_sync_dir", num_shards=2, shard_id=4) + + error_msg_7 = "Argument shard_equal_rows" + with pytest.raises(TypeError, match=error_msg_7): + OBSMindDataset(DATA_DIR, "https://dummy_site", "dummy_ak", + "dummy_sk", "s3://dummy_sync_dir", shard_equal_rows=1) + + +if __name__ == '__main__': + test_obs_mindrecord_exception() diff --git a/tests/ut/python/dataset/test_datasets_penn_treebank.py b/tests/ut/python/dataset/test_datasets_penn_treebank.py index 5358bbf695e..36b4e1b7d39 100644 --- a/tests/ut/python/dataset/test_datasets_penn_treebank.py +++ b/tests/ut/python/dataset/test_datasets_penn_treebank.py @@ -1,384 +1,384 @@ -# Copyright 2021-2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -import pytest - -import mindspore.dataset as ds -from mindspore import log as logger -from util import config_get_set_num_parallel_workers, config_get_set_seed - -FILE_DIR = '../data/dataset/testPennTreebank' - - -def test_penn_treebank_dataset_one_file(): - """ - Feature: Test PennTreebank Dataset. - Description: Read data from a single file. - Expectation: The data is processed successfully. - """ - data = ds.PennTreebankDataset(FILE_DIR, usage='test') - count = 0 - for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): - logger.info("{}".format(i["text"])) - count += 1 - assert count == 3 - - -def test_penn_treebank_dataset_train(): - """ - Feature: Test PennTreebank Dataset. - Description: Read data from a single file. - Expectation: The data is processed successfully. - """ - data = ds.PennTreebankDataset(FILE_DIR, usage='train') - count = 0 - for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): - logger.info("{}".format(i["text"])) - count += 1 - assert count == 3 - - -def test_penn_treebank_dataset_valid(): - """ - Feature: Test PennTreebank Dataset. - Description: Read data from a single file. - Expectation: The data is processed successfully. - """ - data = ds.PennTreebankDataset(FILE_DIR, usage='valid') - count = 0 - for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): - logger.info("{}".format(i["text"])) - count += 1 - assert count == 3 - - -def test_penn_treebank_dataset_all_file(): - """ - Feature: Test PennTreebank Dataset. - Description: Read data from a single file. - Expectation: The data is processed successfully. - """ - data = ds.PennTreebankDataset(FILE_DIR, usage='all') - count = 0 - for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): - logger.info("{}".format(i["text"])) - count += 1 - assert count == 9 - - -def test_penn_treebank_dataset_num_samples_none(): - """ - Feature: Test PennTreebank Dataset. - Description: Read data with no num_samples input. - Expectation: The data is processed successfully. - """ - # Do not provide a num_samples argument, so it would be None by default - data = ds.PennTreebankDataset(FILE_DIR, usage='all') - count = 0 - for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): - logger.info("{}".format(i["text"])) - count += 1 - assert count == 9 - - -def test_penn_treebank_dataset_shuffle_false4(): - """ - Feature: Test PennTreebank Dataset. - Description: Read data from a single file with shulle is false. - Expectation: The data is processed successfully. - """ - original_num_parallel_workers = config_get_set_num_parallel_workers(4) - original_seed = config_get_set_seed(987) - data = ds.PennTreebankDataset(FILE_DIR, usage='all', shuffle=False) - count = 0 - line = [" no it was black friday ", - " does the bank charge a fee for setting up the account ", - " just ahead of them there was a huge fissure ", - " clash twits poetry formulate flip loyalty splash ", - " the wardrobe was very small in our room ", - " the proportion of female workers in this company ", - " you pay less for the supermaket's own brands ", - " black white grapes ", - " everyone in our football team is fuming "] - for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): - strs = i["text"] - assert strs == line[count] - count += 1 - assert count == 9 - # Restore configuration - ds.config.set_num_parallel_workers(original_num_parallel_workers) - ds.config.set_seed(original_seed) - - -def test_penn_treebank_dataset_shuffle_false1(): - """ - Feature: Test PennTreebank Dataset. - Description: Read data from a single file with shulle is false. - Expectation: The data is processed successfully. - """ - original_num_parallel_workers = config_get_set_num_parallel_workers(1) - original_seed = config_get_set_seed(987) - data = ds.PennTreebankDataset(FILE_DIR, usage='all', shuffle=False) - count = 0 - line = [" no it was black friday ", - " clash twits poetry formulate flip loyalty splash ", - " you pay less for the supermaket's own brands ", - " does the bank charge a fee for setting up the account ", - " the wardrobe was very small in our room ", - " black white grapes ", - " just ahead of them there was a huge fissure ", - " the proportion of female workers in this company ", - " everyone in our football team is fuming "] - for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): - strs = i["text"] - assert strs == line[count] - count += 1 - assert count == 9 - # Restore configuration - ds.config.set_num_parallel_workers(original_num_parallel_workers) - ds.config.set_seed(original_seed) - - -def test_penn_treebank_dataset_shuffle_files4(): - """ - Feature: Test PennTreebank Dataset. - Description: Read data from a single file with shulle is files. - Expectation: The data is processed successfully. - """ - original_num_parallel_workers = config_get_set_num_parallel_workers(4) - original_seed = config_get_set_seed(135) - data = ds.PennTreebankDataset(FILE_DIR, usage='all', shuffle=ds.Shuffle.FILES) - count = 0 - line = [" just ahead of them there was a huge fissure ", - " does the bank charge a fee for setting up the account ", - " no it was black friday ", - " the proportion of female workers in this company ", - " the wardrobe was very small in our room ", - " clash twits poetry formulate flip loyalty splash ", - " everyone in our football team is fuming ", - " black white grapes ", - " you pay less for the supermaket's own brands "] - for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): - strs = i["text"] - assert strs == line[count] - count += 1 - assert count == 9 - # Restore configuration - ds.config.set_num_parallel_workers(original_num_parallel_workers) - ds.config.set_seed(original_seed) - - -def test_penn_treebank_dataset_shuffle_files1(): - """ - Feature: Test PennTreebank Dataset. - Description: Read data from a single file with shulle is files. - Expectation: The data is processed successfully. - """ - original_num_parallel_workers = config_get_set_num_parallel_workers(1) - original_seed = config_get_set_seed(135) - data = ds.PennTreebankDataset(FILE_DIR, usage='all', shuffle=ds.Shuffle.FILES) - count = 0 - line = [" just ahead of them there was a huge fissure ", - " the proportion of female workers in this company ", - " everyone in our football team is fuming ", - " does the bank charge a fee for setting up the account ", - " the wardrobe was very small in our room ", - " black white grapes ", - " no it was black friday ", - " clash twits poetry formulate flip loyalty splash ", - " you pay less for the supermaket's own brands "] - for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): - strs = i["text"] - assert strs == line[count] - count += 1 - assert count == 9 - # Restore configuration - ds.config.set_num_parallel_workers(original_num_parallel_workers) - ds.config.set_seed(original_seed) - - -def test_penn_treebank_dataset_shuffle_global4(): - """ - Feature: Test PennTreebank Dataset. - Description: Read data from a single file with shulle is global. - Expectation: The data is processed successfully. - """ - original_num_parallel_workers = config_get_set_num_parallel_workers(4) - original_seed = config_get_set_seed(246) - data = ds.PennTreebankDataset(FILE_DIR, usage='all', shuffle=ds.Shuffle.GLOBAL) - count = 0 - line = [" everyone in our football team is fuming ", - " does the bank charge a fee for setting up the account ", - " clash twits poetry formulate flip loyalty splash ", - " no it was black friday ", - " just ahead of them there was a huge fissure ", - " the proportion of female workers in this company ", - " you pay less for the supermaket's own brands ", - " the wardrobe was very small in our room ", - " black white grapes "] - for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): - strs = i["text"] - assert strs == line[count] - count += 1 - assert count == 9 - # Restore configuration - ds.config.set_num_parallel_workers(original_num_parallel_workers) - ds.config.set_seed(original_seed) - - -def test_penn_treebank_dataset_shuffle_global1(): - """ - Feature: Test PennTreebank Dataset. - Description: Read data from a single file with shulle is global. - Expectation: The data is processed successfully. - """ - original_num_parallel_workers = config_get_set_num_parallel_workers(1) - original_seed = config_get_set_seed(246) - data = ds.PennTreebankDataset(FILE_DIR, usage='all', shuffle=ds.Shuffle.GLOBAL) - count = 0 - line = [" everyone in our football team is fuming ", - " does the bank charge a fee for setting up the account ", - " clash twits poetry formulate flip loyalty splash ", - " the wardrobe was very small in our room ", - " black white grapes ", - " you pay less for the supermaket's own brands ", - " the proportion of female workers in this company ", - " no it was black friday ", - " just ahead of them there was a huge fissure "] - for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): - strs = i["text"] - assert strs == line[count] - count += 1 - assert count == 9 - # Restore configuration - ds.config.set_num_parallel_workers(original_num_parallel_workers) - ds.config.set_seed(original_seed) - - -def test_penn_treebank_dataset_num_samples(): - """ - Feature: Test PennTreebank Dataset. - Description: Test num_samples. - Expectation: The data is processed successfully. - """ - data = ds.PennTreebankDataset(FILE_DIR, usage='all', num_samples=2) - count = 0 - for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): - count += 1 - assert count == 2 - - -def test_penn_treebank_dataset_distribution(): - """ - Feature: Test PennTreebank Dataset. - Description: Read data from a single file. - Expectation: The data is processed successfully. - """ - data = ds.PennTreebankDataset(FILE_DIR, usage='all', num_shards=2, shard_id=1) - count = 0 - for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): - count += 1 - assert count == 5 - - -def test_penn_treebank_dataset_repeat(): - """ - Feature: Test PennTreebank Dataset. - Description: Test repeat. - Expectation: The data is processed successfully. - """ - data = ds.PennTreebankDataset(FILE_DIR, usage='test', shuffle=False) - data = data.repeat(3) - count = 0 - line = [" no it was black friday ", - " clash twits poetry formulate flip loyalty splash ", - " you pay less for the supermaket's own brands ", - " no it was black friday ", - " clash twits poetry formulate flip loyalty splash ", - " you pay less for the supermaket's own brands ", - " no it was black friday ", - " clash twits poetry formulate flip loyalty splash ", - " you pay less for the supermaket's own brands ",] - for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): - strs = i["text"] - assert strs == line[count] - count += 1 - assert count == 9 - - -def test_penn_treebank_dataset_get_datasetsize(): - """ - Feature: Test PennTreebank Dataset. - Description: Test get_datasetsize. - Expectation: The data is processed successfully. - """ - data = ds.PennTreebankDataset(FILE_DIR, usage='test') - size = data.get_dataset_size() - assert size == 3 - - -def test_penn_treebank_dataset_device_que(): - """ - Feature: Test PennTreebank Dataset. - Description: Test device_que. - Expectation: The data is processed successfully. - """ - data = ds.PennTreebankDataset(FILE_DIR, usage='test') - data = data.device_que() - data.send() - - -def test_penn_treebank_dataset_exceptions(): - """ - Feature: Test PennTreebank Dataset. - Description: Test exceptions. - Expectation: Exception thrown to be caught - """ - with pytest.raises(ValueError) as error_info: - _ = ds.PennTreebankDataset(FILE_DIR, usage='test', num_samples=-1) - assert "num_samples exceeds the boundary" in str(error_info.value) - with pytest.raises(ValueError) as error_info: - _ = ds.PennTreebankDataset("does/not/exist/no.txt") - assert str(error_info.value) - with pytest.raises(ValueError) as error_info: - _ = ds.PennTreebankDataset("") - assert str(error_info.value) - def exception_func(item): - raise Exception("Error occur!") - with pytest.raises(RuntimeError) as error_info: - data = ds.PennTreebankDataset(FILE_DIR) - data = data.map(operations=exception_func, input_columns=["text"], num_parallel_workers=1) - for _ in data.__iter__(): - pass - assert "map operation: [PyFunc] failed. The corresponding data file is" in str(error_info.value) - - -if __name__ == "__main__": - test_penn_treebank_dataset_one_file() - test_penn_treebank_dataset_train() - test_penn_treebank_dataset_valid() - test_penn_treebank_dataset_all_file() - test_penn_treebank_dataset_num_samples_none() - test_penn_treebank_dataset_shuffle_false4() - test_penn_treebank_dataset_shuffle_false1() - test_penn_treebank_dataset_shuffle_files4() - test_penn_treebank_dataset_shuffle_files1() - test_penn_treebank_dataset_shuffle_global4() - test_penn_treebank_dataset_shuffle_global1() - test_penn_treebank_dataset_num_samples() - test_penn_treebank_dataset_distribution() - test_penn_treebank_dataset_repeat() - test_penn_treebank_dataset_get_datasetsize() - test_penn_treebank_dataset_device_que() - test_penn_treebank_dataset_exceptions() +# Copyright 2021-2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import pytest + +import mindspore.dataset as ds +from mindspore import log as logger +from util import config_get_set_num_parallel_workers, config_get_set_seed + +FILE_DIR = '../data/dataset/testPennTreebank' + + +def test_penn_treebank_dataset_one_file(): + """ + Feature: Test PennTreebank Dataset. + Description: Read data from a single file. + Expectation: The data is processed successfully. + """ + data = ds.PennTreebankDataset(FILE_DIR, usage='test') + count = 0 + for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): + logger.info("{}".format(i["text"])) + count += 1 + assert count == 3 + + +def test_penn_treebank_dataset_train(): + """ + Feature: Test PennTreebank Dataset. + Description: Read data from a single file. + Expectation: The data is processed successfully. + """ + data = ds.PennTreebankDataset(FILE_DIR, usage='train') + count = 0 + for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): + logger.info("{}".format(i["text"])) + count += 1 + assert count == 3 + + +def test_penn_treebank_dataset_valid(): + """ + Feature: Test PennTreebank Dataset. + Description: Read data from a single file. + Expectation: The data is processed successfully. + """ + data = ds.PennTreebankDataset(FILE_DIR, usage='valid') + count = 0 + for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): + logger.info("{}".format(i["text"])) + count += 1 + assert count == 3 + + +def test_penn_treebank_dataset_all_file(): + """ + Feature: Test PennTreebank Dataset. + Description: Read data from a single file. + Expectation: The data is processed successfully. + """ + data = ds.PennTreebankDataset(FILE_DIR, usage='all') + count = 0 + for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): + logger.info("{}".format(i["text"])) + count += 1 + assert count == 9 + + +def test_penn_treebank_dataset_num_samples_none(): + """ + Feature: Test PennTreebank Dataset. + Description: Read data with no num_samples input. + Expectation: The data is processed successfully. + """ + # Do not provide a num_samples argument, so it would be None by default + data = ds.PennTreebankDataset(FILE_DIR, usage='all') + count = 0 + for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): + logger.info("{}".format(i["text"])) + count += 1 + assert count == 9 + + +def test_penn_treebank_dataset_shuffle_false4(): + """ + Feature: Test PennTreebank Dataset. + Description: Read data from a single file with shulle is false. + Expectation: The data is processed successfully. + """ + original_num_parallel_workers = config_get_set_num_parallel_workers(4) + original_seed = config_get_set_seed(987) + data = ds.PennTreebankDataset(FILE_DIR, usage='all', shuffle=False) + count = 0 + line = [" no it was black friday ", + " does the bank charge a fee for setting up the account ", + " just ahead of them there was a huge fissure ", + " clash twits poetry formulate flip loyalty splash ", + " the wardrobe was very small in our room ", + " the proportion of female workers in this company ", + " you pay less for the supermaket's own brands ", + " black white grapes ", + " everyone in our football team is fuming "] + for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): + strs = i["text"] + assert strs == line[count] + count += 1 + assert count == 9 + # Restore configuration + ds.config.set_num_parallel_workers(original_num_parallel_workers) + ds.config.set_seed(original_seed) + + +def test_penn_treebank_dataset_shuffle_false1(): + """ + Feature: Test PennTreebank Dataset. + Description: Read data from a single file with shulle is false. + Expectation: The data is processed successfully. + """ + original_num_parallel_workers = config_get_set_num_parallel_workers(1) + original_seed = config_get_set_seed(987) + data = ds.PennTreebankDataset(FILE_DIR, usage='all', shuffle=False) + count = 0 + line = [" no it was black friday ", + " clash twits poetry formulate flip loyalty splash ", + " you pay less for the supermaket's own brands ", + " does the bank charge a fee for setting up the account ", + " the wardrobe was very small in our room ", + " black white grapes ", + " just ahead of them there was a huge fissure ", + " the proportion of female workers in this company ", + " everyone in our football team is fuming "] + for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): + strs = i["text"] + assert strs == line[count] + count += 1 + assert count == 9 + # Restore configuration + ds.config.set_num_parallel_workers(original_num_parallel_workers) + ds.config.set_seed(original_seed) + + +def test_penn_treebank_dataset_shuffle_files4(): + """ + Feature: Test PennTreebank Dataset. + Description: Read data from a single file with shulle is files. + Expectation: The data is processed successfully. + """ + original_num_parallel_workers = config_get_set_num_parallel_workers(4) + original_seed = config_get_set_seed(135) + data = ds.PennTreebankDataset(FILE_DIR, usage='all', shuffle=ds.Shuffle.FILES) + count = 0 + line = [" just ahead of them there was a huge fissure ", + " does the bank charge a fee for setting up the account ", + " no it was black friday ", + " the proportion of female workers in this company ", + " the wardrobe was very small in our room ", + " clash twits poetry formulate flip loyalty splash ", + " everyone in our football team is fuming ", + " black white grapes ", + " you pay less for the supermaket's own brands "] + for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): + strs = i["text"] + assert strs == line[count] + count += 1 + assert count == 9 + # Restore configuration + ds.config.set_num_parallel_workers(original_num_parallel_workers) + ds.config.set_seed(original_seed) + + +def test_penn_treebank_dataset_shuffle_files1(): + """ + Feature: Test PennTreebank Dataset. + Description: Read data from a single file with shulle is files. + Expectation: The data is processed successfully. + """ + original_num_parallel_workers = config_get_set_num_parallel_workers(1) + original_seed = config_get_set_seed(135) + data = ds.PennTreebankDataset(FILE_DIR, usage='all', shuffle=ds.Shuffle.FILES) + count = 0 + line = [" just ahead of them there was a huge fissure ", + " the proportion of female workers in this company ", + " everyone in our football team is fuming ", + " does the bank charge a fee for setting up the account ", + " the wardrobe was very small in our room ", + " black white grapes ", + " no it was black friday ", + " clash twits poetry formulate flip loyalty splash ", + " you pay less for the supermaket's own brands "] + for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): + strs = i["text"] + assert strs == line[count] + count += 1 + assert count == 9 + # Restore configuration + ds.config.set_num_parallel_workers(original_num_parallel_workers) + ds.config.set_seed(original_seed) + + +def test_penn_treebank_dataset_shuffle_global4(): + """ + Feature: Test PennTreebank Dataset. + Description: Read data from a single file with shulle is global. + Expectation: The data is processed successfully. + """ + original_num_parallel_workers = config_get_set_num_parallel_workers(4) + original_seed = config_get_set_seed(246) + data = ds.PennTreebankDataset(FILE_DIR, usage='all', shuffle=ds.Shuffle.GLOBAL) + count = 0 + line = [" everyone in our football team is fuming ", + " does the bank charge a fee for setting up the account ", + " clash twits poetry formulate flip loyalty splash ", + " no it was black friday ", + " just ahead of them there was a huge fissure ", + " the proportion of female workers in this company ", + " you pay less for the supermaket's own brands ", + " the wardrobe was very small in our room ", + " black white grapes "] + for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): + strs = i["text"] + assert strs == line[count] + count += 1 + assert count == 9 + # Restore configuration + ds.config.set_num_parallel_workers(original_num_parallel_workers) + ds.config.set_seed(original_seed) + + +def test_penn_treebank_dataset_shuffle_global1(): + """ + Feature: Test PennTreebank Dataset. + Description: Read data from a single file with shulle is global. + Expectation: The data is processed successfully. + """ + original_num_parallel_workers = config_get_set_num_parallel_workers(1) + original_seed = config_get_set_seed(246) + data = ds.PennTreebankDataset(FILE_DIR, usage='all', shuffle=ds.Shuffle.GLOBAL) + count = 0 + line = [" everyone in our football team is fuming ", + " does the bank charge a fee for setting up the account ", + " clash twits poetry formulate flip loyalty splash ", + " the wardrobe was very small in our room ", + " black white grapes ", + " you pay less for the supermaket's own brands ", + " the proportion of female workers in this company ", + " no it was black friday ", + " just ahead of them there was a huge fissure "] + for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): + strs = i["text"] + assert strs == line[count] + count += 1 + assert count == 9 + # Restore configuration + ds.config.set_num_parallel_workers(original_num_parallel_workers) + ds.config.set_seed(original_seed) + + +def test_penn_treebank_dataset_num_samples(): + """ + Feature: Test PennTreebank Dataset. + Description: Test num_samples. + Expectation: The data is processed successfully. + """ + data = ds.PennTreebankDataset(FILE_DIR, usage='all', num_samples=2) + count = 0 + for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): + count += 1 + assert count == 2 + + +def test_penn_treebank_dataset_distribution(): + """ + Feature: Test PennTreebank Dataset. + Description: Read data from a single file. + Expectation: The data is processed successfully. + """ + data = ds.PennTreebankDataset(FILE_DIR, usage='all', num_shards=2, shard_id=1) + count = 0 + for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): + count += 1 + assert count == 5 + + +def test_penn_treebank_dataset_repeat(): + """ + Feature: Test PennTreebank Dataset. + Description: Test repeat. + Expectation: The data is processed successfully. + """ + data = ds.PennTreebankDataset(FILE_DIR, usage='test', shuffle=False) + data = data.repeat(3) + count = 0 + line = [" no it was black friday ", + " clash twits poetry formulate flip loyalty splash ", + " you pay less for the supermaket's own brands ", + " no it was black friday ", + " clash twits poetry formulate flip loyalty splash ", + " you pay less for the supermaket's own brands ", + " no it was black friday ", + " clash twits poetry formulate flip loyalty splash ", + " you pay less for the supermaket's own brands ",] + for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): + strs = i["text"] + assert strs == line[count] + count += 1 + assert count == 9 + + +def test_penn_treebank_dataset_get_datasetsize(): + """ + Feature: Test PennTreebank Dataset. + Description: Test get_datasetsize. + Expectation: The data is processed successfully. + """ + data = ds.PennTreebankDataset(FILE_DIR, usage='test') + size = data.get_dataset_size() + assert size == 3 + + +def test_penn_treebank_dataset_device_que(): + """ + Feature: Test PennTreebank Dataset. + Description: Test device_que. + Expectation: The data is processed successfully. + """ + data = ds.PennTreebankDataset(FILE_DIR, usage='test') + data = data.device_que() + data.send() + + +def test_penn_treebank_dataset_exceptions(): + """ + Feature: Test PennTreebank Dataset. + Description: Test exceptions. + Expectation: Exception thrown to be caught + """ + with pytest.raises(ValueError) as error_info: + _ = ds.PennTreebankDataset(FILE_DIR, usage='test', num_samples=-1) + assert "num_samples exceeds the boundary" in str(error_info.value) + with pytest.raises(ValueError) as error_info: + _ = ds.PennTreebankDataset("does/not/exist/no.txt") + assert str(error_info.value) + with pytest.raises(ValueError) as error_info: + _ = ds.PennTreebankDataset("") + assert str(error_info.value) + def exception_func(item): + raise Exception("Error occur!") + with pytest.raises(RuntimeError) as error_info: + data = ds.PennTreebankDataset(FILE_DIR) + data = data.map(operations=exception_func, input_columns=["text"], num_parallel_workers=1) + for _ in data.__iter__(): + pass + assert "map operation: [PyFunc] failed. The corresponding data file is" in str(error_info.value) + + +if __name__ == "__main__": + test_penn_treebank_dataset_one_file() + test_penn_treebank_dataset_train() + test_penn_treebank_dataset_valid() + test_penn_treebank_dataset_all_file() + test_penn_treebank_dataset_num_samples_none() + test_penn_treebank_dataset_shuffle_false4() + test_penn_treebank_dataset_shuffle_false1() + test_penn_treebank_dataset_shuffle_files4() + test_penn_treebank_dataset_shuffle_files1() + test_penn_treebank_dataset_shuffle_global4() + test_penn_treebank_dataset_shuffle_global1() + test_penn_treebank_dataset_num_samples() + test_penn_treebank_dataset_distribution() + test_penn_treebank_dataset_repeat() + test_penn_treebank_dataset_get_datasetsize() + test_penn_treebank_dataset_device_que() + test_penn_treebank_dataset_exceptions() diff --git a/tests/ut/python/dataset/test_datasets_qmnist.py b/tests/ut/python/dataset/test_datasets_qmnist.py index fb228a252cf..1701c2fab80 100644 --- a/tests/ut/python/dataset/test_datasets_qmnist.py +++ b/tests/ut/python/dataset/test_datasets_qmnist.py @@ -1,356 +1,356 @@ -# Copyright 2021-2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -""" -Test QMnistDataset operations -""" -import os - -import matplotlib.pyplot as plt -import numpy as np -import pytest - -import mindspore.dataset as ds -import mindspore.dataset.vision as vision -from mindspore import log as logger - -DATA_DIR = "../data/dataset/testQMnistData" - - -def load_qmnist(path, usage, compat=True): - """ - load QMNIST data - """ - image_path = [] - label_path = [] - image_ext = "images-idx3-ubyte" - label_ext = "labels-idx2-int" - train_prefix = "qmnist-train" - test_prefix = "qmnist-test" - nist_prefix = "xnist" - assert usage in ["train", "test", "nist", "all"] - if usage == "train": - image_path.append(os.path.realpath(os.path.join(path, train_prefix + "-" + image_ext))) - label_path.append(os.path.realpath(os.path.join(path, train_prefix + "-" + label_ext))) - elif usage == "test": - image_path.append(os.path.realpath(os.path.join(path, test_prefix + "-" + image_ext))) - label_path.append(os.path.realpath(os.path.join(path, test_prefix + "-" + label_ext))) - elif usage == "nist": - image_path.append(os.path.realpath(os.path.join(path, nist_prefix + "-" + image_ext))) - label_path.append(os.path.realpath(os.path.join(path, nist_prefix + "-" + label_ext))) - elif usage == "all": - image_path.append(os.path.realpath(os.path.join(path, train_prefix + "-" + image_ext))) - label_path.append(os.path.realpath(os.path.join(path, train_prefix + "-" + label_ext))) - image_path.append(os.path.realpath(os.path.join(path, test_prefix + "-" + image_ext))) - label_path.append(os.path.realpath(os.path.join(path, test_prefix + "-" + label_ext))) - image_path.append(os.path.realpath(os.path.join(path, nist_prefix + "-" + image_ext))) - label_path.append(os.path.realpath(os.path.join(path, nist_prefix + "-" + label_ext))) - - assert len(image_path) == len(label_path) - - images = [] - labels = [] - for i, _ in enumerate(image_path): - with open(image_path[i], 'rb') as image_file: - image_file.read(16) - image = np.fromfile(image_file, dtype=np.uint8) - image = image.reshape(-1, 28, 28, 1) - images.append(image) - with open(label_path[i], 'rb') as label_file: - label_file.read(12) - label = np.fromfile(label_file, dtype='>u4') - label = label.reshape(-1, 8) - labels.append(label) - - images = np.concatenate(images, 0) - labels = np.concatenate(labels, 0) - if compat: - return images, labels[:, 0] - return images, labels - - -def visualize_dataset(images, labels): - """ - Helper function to visualize the dataset samples - """ - num_samples = len(images) - for i in range(num_samples): - plt.subplot(1, num_samples, i + 1) - plt.imshow(images[i].squeeze(), cmap=plt.cm.gray) - plt.title(labels[i]) - plt.show() - - -def test_qmnist_content_check(): - """ - Feature: QMnistDataset - Description: Test QMnistDataset image readings with content check - Expectation: The dataset is processed as expected - """ - logger.info("Test QMnistDataset Op with content check") - for usage in ["train", "test", "nist", "all"]: - data1 = ds.QMnistDataset(DATA_DIR, usage, True, num_samples=10, shuffle=False) - images, labels = load_qmnist(DATA_DIR, usage, True) - num_iter = 0 - # in this example, each dictionary has keys "image" and "label" - image_list, label_list = [], [] - for i, data in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)): - image_list.append(data["image"]) - label_list.append("label {}".format(data["label"])) - np.testing.assert_array_equal(data["image"], images[i]) - np.testing.assert_array_equal(data["label"], labels[i]) - num_iter += 1 - assert num_iter == 10 - - for usage in ["train", "test", "nist", "all"]: - data1 = ds.QMnistDataset(DATA_DIR, usage, False, num_samples=10, shuffle=False) - images, labels = load_qmnist(DATA_DIR, usage, False) - num_iter = 0 - # in this example, each dictionary has keys "image" and "label" - image_list, label_list = [], [] - for i, data in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)): - image_list.append(data["image"]) - label_list.append("label {}".format(data["label"])) - np.testing.assert_array_equal(data["image"], images[i]) - np.testing.assert_array_equal(data["label"], labels[i]) - num_iter += 1 - assert num_iter == 10 - - -def test_qmnist_basic(): - """ - Feature: QMnistDataset - Description: Test QMnistDataset basic usage - Expectation: The dataset is processed as expected - """ - logger.info("Test QMnistDataset Op") - - # case 1: test loading whole dataset - data1 = ds.QMnistDataset(DATA_DIR, "train", True) - num_iter1 = 0 - for _ in data1.create_dict_iterator(num_epochs=1): - num_iter1 += 1 - assert num_iter1 == 10 - - # case 2: test num_samples - data2 = ds.QMnistDataset(DATA_DIR, "train", True, num_samples=5) - num_iter2 = 0 - for _ in data2.create_dict_iterator(num_epochs=1): - num_iter2 += 1 - assert num_iter2 == 5 - - # case 3: test repeat - data3 = ds.QMnistDataset(DATA_DIR, "train", True) - data3 = data3.repeat(5) - num_iter3 = 0 - for _ in data3.create_dict_iterator(num_epochs=1): - num_iter3 += 1 - assert num_iter3 == 50 - - # case 4: test batch with drop_remainder=False - data4 = ds.QMnistDataset(DATA_DIR, "train", True, num_samples=10) - assert data4.get_dataset_size() == 10 - assert data4.get_batch_size() == 1 - data4 = data4.batch(batch_size=7) # drop_remainder is default to be False - assert data4.get_dataset_size() == 2 - assert data4.get_batch_size() == 7 - num_iter4 = 0 - for _ in data4.create_dict_iterator(num_epochs=1): - num_iter4 += 1 - assert num_iter4 == 2 - - # case 5: test batch with drop_remainder=True - data5 = ds.QMnistDataset(DATA_DIR, "train", True, num_samples=10) - assert data5.get_dataset_size() == 10 - assert data5.get_batch_size() == 1 - data5 = data5.batch(batch_size=3, drop_remainder=True) # the rest of incomplete batch will be dropped - assert data5.get_dataset_size() == 3 - assert data5.get_batch_size() == 3 - num_iter5 = 0 - for _ in data5.create_dict_iterator(num_epochs=1): - num_iter5 += 1 - assert num_iter5 == 3 - - # case 6: test get_col_names - dataset = ds.QMnistDataset(DATA_DIR, "train", True, num_samples=10) - assert dataset.get_col_names() == ["image", "label"] - - -def test_qmnist_pk_sampler(): - """ - Feature: QMnistDataset - Description: Test QMnistDataset with PKSampler - Expectation: The dataset is processed as expected - """ - logger.info("Test QMnistDataset Op with PKSampler") - golden = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - sampler = ds.PKSampler(10) - data = ds.QMnistDataset(DATA_DIR, "nist", True, sampler=sampler) - num_iter = 0 - label_list = [] - for item in data.create_dict_iterator(num_epochs=1, output_numpy=True): - label_list.append(item["label"]) - num_iter += 1 - np.testing.assert_array_equal(golden, label_list) - assert num_iter == 10 - - -def test_qmnist_sequential_sampler(): - """ - Feature: QMnistDataset - Description: Test QMnistDataset with SequentialSampler - Expectation: The dataset is processed as expected - """ - logger.info("Test QMnistDataset Op with SequentialSampler") - num_samples = 10 - sampler = ds.SequentialSampler(num_samples=num_samples) - data1 = ds.QMnistDataset(DATA_DIR, "train", True, sampler=sampler) - data2 = ds.QMnistDataset(DATA_DIR, "train", True, shuffle=False, num_samples=num_samples) - label_list1, label_list2 = [], [] - num_iter = 0 - for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1), data2.create_dict_iterator(num_epochs=1)): - label_list1.append(item1["label"].asnumpy()) - label_list2.append(item2["label"].asnumpy()) - num_iter += 1 - np.testing.assert_array_equal(label_list1, label_list2) - assert num_iter == num_samples - - -def test_qmnist_exception(): - """ - Feature: QMnistDataset - Description: Test error cases for QMnistDataset - Expectation: Correct error is thrown as expected - """ - logger.info("Test error cases for MnistDataset") - error_msg_1 = "sampler and shuffle cannot be specified at the same time" - with pytest.raises(RuntimeError, match=error_msg_1): - ds.QMnistDataset(DATA_DIR, "train", True, shuffle=False, sampler=ds.PKSampler(3)) - - error_msg_2 = "sampler and sharding cannot be specified at the same time" - with pytest.raises(RuntimeError, match=error_msg_2): - ds.QMnistDataset(DATA_DIR, "nist", True, sampler=ds.PKSampler(1), num_shards=2, shard_id=0) - - error_msg_3 = "num_shards is specified and currently requires shard_id as well" - with pytest.raises(RuntimeError, match=error_msg_3): - ds.QMnistDataset(DATA_DIR, "train", True, num_shards=10) - - error_msg_4 = "shard_id is specified but num_shards is not" - with pytest.raises(RuntimeError, match=error_msg_4): - ds.QMnistDataset(DATA_DIR, "train", True, shard_id=0) - - error_msg_5 = "Input shard_id is not within the required interval" - with pytest.raises(ValueError, match=error_msg_5): - ds.QMnistDataset(DATA_DIR, "train", True, num_shards=5, shard_id=-1) - with pytest.raises(ValueError, match=error_msg_5): - ds.QMnistDataset(DATA_DIR, "train", True, num_shards=5, shard_id=5) - with pytest.raises(ValueError, match=error_msg_5): - ds.QMnistDataset(DATA_DIR, "train", True, num_shards=2, shard_id=5) - - error_msg_6 = "num_parallel_workers exceeds" - with pytest.raises(ValueError, match=error_msg_6): - ds.QMnistDataset(DATA_DIR, "train", True, shuffle=False, num_parallel_workers=0) - with pytest.raises(ValueError, match=error_msg_6): - ds.QMnistDataset(DATA_DIR, "train", True, shuffle=False, num_parallel_workers=256) - with pytest.raises(ValueError, match=error_msg_6): - ds.QMnistDataset(DATA_DIR, "train", True, shuffle=False, num_parallel_workers=-2) - - error_msg_7 = "Argument shard_id" - with pytest.raises(TypeError, match=error_msg_7): - ds.QMnistDataset(DATA_DIR, "train", True, num_shards=2, shard_id="0") - - def exception_func(item): - raise Exception("Error occur!") - - error_msg_8 = "The corresponding data file is" - with pytest.raises(RuntimeError, match=error_msg_8): - data = ds.QMnistDataset(DATA_DIR, "train", True) - data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1) - for _ in data.__iter__(): - pass - with pytest.raises(RuntimeError, match=error_msg_8): - data = ds.QMnistDataset(DATA_DIR, "train", True) - data = data.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1) - data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1) - for _ in data.__iter__(): - pass - with pytest.raises(RuntimeError, match=error_msg_8): - data = ds.QMnistDataset(DATA_DIR, "train", True) - data = data.map(operations=exception_func, input_columns=["label"], num_parallel_workers=1) - for _ in data.__iter__(): - pass - - -def test_qmnist_visualize(plot=False): - """ - Feature: QMnistDataset - Description: Test QMnistDataset visualized results - Expectation: The dataset is processed as expected - """ - logger.info("Test QMnistDataset visualization") - - data1 = ds.QMnistDataset(DATA_DIR, "train", True, num_samples=10, shuffle=False) - num_iter = 0 - image_list, label_list = [], [] - for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): - image = item["image"] - label = item["label"] - image_list.append(image) - label_list.append("label {}".format(label)) - assert isinstance(image, np.ndarray) - assert image.shape == (28, 28, 1) - assert image.dtype == np.uint8 - assert label.dtype == np.uint32 - num_iter += 1 - assert num_iter == 10 - if plot: - visualize_dataset(image_list, label_list) - - -def test_qmnist_usage(): - """ - Feature: QMnistDataset - Description: Test QMnistDataset image readings with usage flag - Expectation: The dataset is processed as expected - """ - logger.info("Test QMnistDataset usage flag") - - def test_config(usage, path=None): - path = DATA_DIR if path is None else path - try: - data = ds.QMnistDataset(path, usage=usage, compat=True, shuffle=False) - num_rows = 0 - for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): - num_rows += 1 - except (ValueError, TypeError, RuntimeError) as e: - return str(e) - return num_rows - - assert test_config("train") == 10 - assert test_config("test") == 10 - assert test_config("nist") == 10 - assert test_config("all") == 30 - assert "usage is not within the valid set of ['train', 'test', 'test10k', 'test50k', 'nist', 'all']" in \ - test_config("invalid") - assert "Argument usage with value ['list'] is not of type []" in test_config(["list"]) - - -if __name__ == '__main__': - test_qmnist_content_check() - test_qmnist_basic() - test_qmnist_pk_sampler() - test_qmnist_sequential_sampler() - test_qmnist_exception() - test_qmnist_visualize(plot=True) - test_qmnist_usage() +# Copyright 2021-2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Test QMnistDataset operations +""" +import os + +import matplotlib.pyplot as plt +import numpy as np +import pytest + +import mindspore.dataset as ds +import mindspore.dataset.vision as vision +from mindspore import log as logger + +DATA_DIR = "../data/dataset/testQMnistData" + + +def load_qmnist(path, usage, compat=True): + """ + load QMNIST data + """ + image_path = [] + label_path = [] + image_ext = "images-idx3-ubyte" + label_ext = "labels-idx2-int" + train_prefix = "qmnist-train" + test_prefix = "qmnist-test" + nist_prefix = "xnist" + assert usage in ["train", "test", "nist", "all"] + if usage == "train": + image_path.append(os.path.realpath(os.path.join(path, train_prefix + "-" + image_ext))) + label_path.append(os.path.realpath(os.path.join(path, train_prefix + "-" + label_ext))) + elif usage == "test": + image_path.append(os.path.realpath(os.path.join(path, test_prefix + "-" + image_ext))) + label_path.append(os.path.realpath(os.path.join(path, test_prefix + "-" + label_ext))) + elif usage == "nist": + image_path.append(os.path.realpath(os.path.join(path, nist_prefix + "-" + image_ext))) + label_path.append(os.path.realpath(os.path.join(path, nist_prefix + "-" + label_ext))) + elif usage == "all": + image_path.append(os.path.realpath(os.path.join(path, train_prefix + "-" + image_ext))) + label_path.append(os.path.realpath(os.path.join(path, train_prefix + "-" + label_ext))) + image_path.append(os.path.realpath(os.path.join(path, test_prefix + "-" + image_ext))) + label_path.append(os.path.realpath(os.path.join(path, test_prefix + "-" + label_ext))) + image_path.append(os.path.realpath(os.path.join(path, nist_prefix + "-" + image_ext))) + label_path.append(os.path.realpath(os.path.join(path, nist_prefix + "-" + label_ext))) + + assert len(image_path) == len(label_path) + + images = [] + labels = [] + for i, _ in enumerate(image_path): + with open(image_path[i], 'rb') as image_file: + image_file.read(16) + image = np.fromfile(image_file, dtype=np.uint8) + image = image.reshape(-1, 28, 28, 1) + images.append(image) + with open(label_path[i], 'rb') as label_file: + label_file.read(12) + label = np.fromfile(label_file, dtype='>u4') + label = label.reshape(-1, 8) + labels.append(label) + + images = np.concatenate(images, 0) + labels = np.concatenate(labels, 0) + if compat: + return images, labels[:, 0] + return images, labels + + +def visualize_dataset(images, labels): + """ + Helper function to visualize the dataset samples + """ + num_samples = len(images) + for i in range(num_samples): + plt.subplot(1, num_samples, i + 1) + plt.imshow(images[i].squeeze(), cmap=plt.cm.gray) + plt.title(labels[i]) + plt.show() + + +def test_qmnist_content_check(): + """ + Feature: QMnistDataset + Description: Test QMnistDataset image readings with content check + Expectation: The dataset is processed as expected + """ + logger.info("Test QMnistDataset Op with content check") + for usage in ["train", "test", "nist", "all"]: + data1 = ds.QMnistDataset(DATA_DIR, usage, True, num_samples=10, shuffle=False) + images, labels = load_qmnist(DATA_DIR, usage, True) + num_iter = 0 + # in this example, each dictionary has keys "image" and "label" + image_list, label_list = [], [] + for i, data in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)): + image_list.append(data["image"]) + label_list.append("label {}".format(data["label"])) + np.testing.assert_array_equal(data["image"], images[i]) + np.testing.assert_array_equal(data["label"], labels[i]) + num_iter += 1 + assert num_iter == 10 + + for usage in ["train", "test", "nist", "all"]: + data1 = ds.QMnistDataset(DATA_DIR, usage, False, num_samples=10, shuffle=False) + images, labels = load_qmnist(DATA_DIR, usage, False) + num_iter = 0 + # in this example, each dictionary has keys "image" and "label" + image_list, label_list = [], [] + for i, data in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)): + image_list.append(data["image"]) + label_list.append("label {}".format(data["label"])) + np.testing.assert_array_equal(data["image"], images[i]) + np.testing.assert_array_equal(data["label"], labels[i]) + num_iter += 1 + assert num_iter == 10 + + +def test_qmnist_basic(): + """ + Feature: QMnistDataset + Description: Test QMnistDataset basic usage + Expectation: The dataset is processed as expected + """ + logger.info("Test QMnistDataset Op") + + # case 1: test loading whole dataset + data1 = ds.QMnistDataset(DATA_DIR, "train", True) + num_iter1 = 0 + for _ in data1.create_dict_iterator(num_epochs=1): + num_iter1 += 1 + assert num_iter1 == 10 + + # case 2: test num_samples + data2 = ds.QMnistDataset(DATA_DIR, "train", True, num_samples=5) + num_iter2 = 0 + for _ in data2.create_dict_iterator(num_epochs=1): + num_iter2 += 1 + assert num_iter2 == 5 + + # case 3: test repeat + data3 = ds.QMnistDataset(DATA_DIR, "train", True) + data3 = data3.repeat(5) + num_iter3 = 0 + for _ in data3.create_dict_iterator(num_epochs=1): + num_iter3 += 1 + assert num_iter3 == 50 + + # case 4: test batch with drop_remainder=False + data4 = ds.QMnistDataset(DATA_DIR, "train", True, num_samples=10) + assert data4.get_dataset_size() == 10 + assert data4.get_batch_size() == 1 + data4 = data4.batch(batch_size=7) # drop_remainder is default to be False + assert data4.get_dataset_size() == 2 + assert data4.get_batch_size() == 7 + num_iter4 = 0 + for _ in data4.create_dict_iterator(num_epochs=1): + num_iter4 += 1 + assert num_iter4 == 2 + + # case 5: test batch with drop_remainder=True + data5 = ds.QMnistDataset(DATA_DIR, "train", True, num_samples=10) + assert data5.get_dataset_size() == 10 + assert data5.get_batch_size() == 1 + data5 = data5.batch(batch_size=3, drop_remainder=True) # the rest of incomplete batch will be dropped + assert data5.get_dataset_size() == 3 + assert data5.get_batch_size() == 3 + num_iter5 = 0 + for _ in data5.create_dict_iterator(num_epochs=1): + num_iter5 += 1 + assert num_iter5 == 3 + + # case 6: test get_col_names + dataset = ds.QMnistDataset(DATA_DIR, "train", True, num_samples=10) + assert dataset.get_col_names() == ["image", "label"] + + +def test_qmnist_pk_sampler(): + """ + Feature: QMnistDataset + Description: Test QMnistDataset with PKSampler + Expectation: The dataset is processed as expected + """ + logger.info("Test QMnistDataset Op with PKSampler") + golden = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + sampler = ds.PKSampler(10) + data = ds.QMnistDataset(DATA_DIR, "nist", True, sampler=sampler) + num_iter = 0 + label_list = [] + for item in data.create_dict_iterator(num_epochs=1, output_numpy=True): + label_list.append(item["label"]) + num_iter += 1 + np.testing.assert_array_equal(golden, label_list) + assert num_iter == 10 + + +def test_qmnist_sequential_sampler(): + """ + Feature: QMnistDataset + Description: Test QMnistDataset with SequentialSampler + Expectation: The dataset is processed as expected + """ + logger.info("Test QMnistDataset Op with SequentialSampler") + num_samples = 10 + sampler = ds.SequentialSampler(num_samples=num_samples) + data1 = ds.QMnistDataset(DATA_DIR, "train", True, sampler=sampler) + data2 = ds.QMnistDataset(DATA_DIR, "train", True, shuffle=False, num_samples=num_samples) + label_list1, label_list2 = [], [] + num_iter = 0 + for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1), data2.create_dict_iterator(num_epochs=1)): + label_list1.append(item1["label"].asnumpy()) + label_list2.append(item2["label"].asnumpy()) + num_iter += 1 + np.testing.assert_array_equal(label_list1, label_list2) + assert num_iter == num_samples + + +def test_qmnist_exception(): + """ + Feature: QMnistDataset + Description: Test error cases for QMnistDataset + Expectation: Correct error is thrown as expected + """ + logger.info("Test error cases for MnistDataset") + error_msg_1 = "sampler and shuffle cannot be specified at the same time" + with pytest.raises(RuntimeError, match=error_msg_1): + ds.QMnistDataset(DATA_DIR, "train", True, shuffle=False, sampler=ds.PKSampler(3)) + + error_msg_2 = "sampler and sharding cannot be specified at the same time" + with pytest.raises(RuntimeError, match=error_msg_2): + ds.QMnistDataset(DATA_DIR, "nist", True, sampler=ds.PKSampler(1), num_shards=2, shard_id=0) + + error_msg_3 = "num_shards is specified and currently requires shard_id as well" + with pytest.raises(RuntimeError, match=error_msg_3): + ds.QMnistDataset(DATA_DIR, "train", True, num_shards=10) + + error_msg_4 = "shard_id is specified but num_shards is not" + with pytest.raises(RuntimeError, match=error_msg_4): + ds.QMnistDataset(DATA_DIR, "train", True, shard_id=0) + + error_msg_5 = "Input shard_id is not within the required interval" + with pytest.raises(ValueError, match=error_msg_5): + ds.QMnistDataset(DATA_DIR, "train", True, num_shards=5, shard_id=-1) + with pytest.raises(ValueError, match=error_msg_5): + ds.QMnistDataset(DATA_DIR, "train", True, num_shards=5, shard_id=5) + with pytest.raises(ValueError, match=error_msg_5): + ds.QMnistDataset(DATA_DIR, "train", True, num_shards=2, shard_id=5) + + error_msg_6 = "num_parallel_workers exceeds" + with pytest.raises(ValueError, match=error_msg_6): + ds.QMnistDataset(DATA_DIR, "train", True, shuffle=False, num_parallel_workers=0) + with pytest.raises(ValueError, match=error_msg_6): + ds.QMnistDataset(DATA_DIR, "train", True, shuffle=False, num_parallel_workers=256) + with pytest.raises(ValueError, match=error_msg_6): + ds.QMnistDataset(DATA_DIR, "train", True, shuffle=False, num_parallel_workers=-2) + + error_msg_7 = "Argument shard_id" + with pytest.raises(TypeError, match=error_msg_7): + ds.QMnistDataset(DATA_DIR, "train", True, num_shards=2, shard_id="0") + + def exception_func(item): + raise Exception("Error occur!") + + error_msg_8 = "The corresponding data file is" + with pytest.raises(RuntimeError, match=error_msg_8): + data = ds.QMnistDataset(DATA_DIR, "train", True) + data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1) + for _ in data.__iter__(): + pass + with pytest.raises(RuntimeError, match=error_msg_8): + data = ds.QMnistDataset(DATA_DIR, "train", True) + data = data.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1) + data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1) + for _ in data.__iter__(): + pass + with pytest.raises(RuntimeError, match=error_msg_8): + data = ds.QMnistDataset(DATA_DIR, "train", True) + data = data.map(operations=exception_func, input_columns=["label"], num_parallel_workers=1) + for _ in data.__iter__(): + pass + + +def test_qmnist_visualize(plot=False): + """ + Feature: QMnistDataset + Description: Test QMnistDataset visualized results + Expectation: The dataset is processed as expected + """ + logger.info("Test QMnistDataset visualization") + + data1 = ds.QMnistDataset(DATA_DIR, "train", True, num_samples=10, shuffle=False) + num_iter = 0 + image_list, label_list = [], [] + for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): + image = item["image"] + label = item["label"] + image_list.append(image) + label_list.append("label {}".format(label)) + assert isinstance(image, np.ndarray) + assert image.shape == (28, 28, 1) + assert image.dtype == np.uint8 + assert label.dtype == np.uint32 + num_iter += 1 + assert num_iter == 10 + if plot: + visualize_dataset(image_list, label_list) + + +def test_qmnist_usage(): + """ + Feature: QMnistDataset + Description: Test QMnistDataset image readings with usage flag + Expectation: The dataset is processed as expected + """ + logger.info("Test QMnistDataset usage flag") + + def test_config(usage, path=None): + path = DATA_DIR if path is None else path + try: + data = ds.QMnistDataset(path, usage=usage, compat=True, shuffle=False) + num_rows = 0 + for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): + num_rows += 1 + except (ValueError, TypeError, RuntimeError) as e: + return str(e) + return num_rows + + assert test_config("train") == 10 + assert test_config("test") == 10 + assert test_config("nist") == 10 + assert test_config("all") == 30 + assert "usage is not within the valid set of ['train', 'test', 'test10k', 'test50k', 'nist', 'all']" in \ + test_config("invalid") + assert "Argument usage with value ['list'] is not of type []" in test_config(["list"]) + + +if __name__ == '__main__': + test_qmnist_content_check() + test_qmnist_basic() + test_qmnist_pk_sampler() + test_qmnist_sequential_sampler() + test_qmnist_exception() + test_qmnist_visualize(plot=True) + test_qmnist_usage() diff --git a/tests/ut/python/dataset/test_datasets_sbu.py b/tests/ut/python/dataset/test_datasets_sbu.py index 592079c64ff..215a5b4125d 100644 --- a/tests/ut/python/dataset/test_datasets_sbu.py +++ b/tests/ut/python/dataset/test_datasets_sbu.py @@ -1,317 +1,317 @@ -# Copyright 2021-2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -""" -Test USPS dataset operations -""" -import os - -import matplotlib.pyplot as plt -import numpy as np -import pytest -from PIL import Image - -import mindspore.dataset as ds -import mindspore.dataset.vision as vision -from mindspore import log as logger - -DATA_DIR = "../data/dataset/testSBUDataset" -WRONG_DIR = "../data/dataset/testMnistData" - - -def load_sbu(path): - """ - load SBU data - """ - images = [] - captions = [] - - file1 = os.path.realpath(os.path.join(path, 'SBU_captioned_photo_dataset_urls.txt')) - file2 = os.path.realpath(os.path.join(path, 'SBU_captioned_photo_dataset_captions.txt')) - - for line1, line2 in zip(open(file1), open(file2)): - url = line1.rstrip() - image = url[23:].replace("/", "_") - filename = os.path.join(path, 'sbu_images', image) - if os.path.exists(filename): - caption = line2.rstrip() - images.append(np.asarray(Image.open(filename).convert('RGB')).astype(np.uint8)) - captions.append(caption) - return images, captions - - -def visualize_dataset(images, captions): - """ - Helper function to visualize the dataset samples - """ - num_samples = len(images) - for i in range(num_samples): - plt.subplot(1, num_samples, i + 1) - plt.imshow(images[i].squeeze()) - plt.title(captions[i]) - plt.show() - - -def test_sbu_content_check(): - """ - Feature: SBUDataset - Description: Test SBUDataset image readings with content check - Expectation: The dataset is processed as expected - """ - logger.info("Test SBUDataset Op with content check") - dataset = ds.SBUDataset(DATA_DIR, decode=True, num_samples=50, shuffle=False) - images, captions = load_sbu(DATA_DIR) - num_iter = 0 - # in this example, each dictionary has keys "image" and "caption" - for i, data in enumerate(dataset.create_dict_iterator(num_epochs=1, output_numpy=True)): - assert data["image"].shape == images[i].shape - assert data["caption"] == captions[i] - num_iter += 1 - assert num_iter == 5 - - -def test_sbu_case(): - """ - Feature: SBUDataset - Description: Test SBUDataset cases - Expectation: The dataset is processed as expected - """ - dataset = ds.SBUDataset(DATA_DIR, decode=True) - - dataset = dataset.map(operations=[vision.Resize((224, 224))], input_columns=["image"]) - repeat_num = 4 - dataset = dataset.repeat(repeat_num) - batch_size = 2 - dataset = dataset.batch(batch_size, drop_remainder=True) - - num = 0 - for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True): - num += 1 - # 4 x 5 / 2 - assert num == 10 - - dataset = ds.SBUDataset(DATA_DIR, decode=False) - - dataset = dataset.map(operations=[vision.Decode(), vision.Resize((224, 224))], input_columns=["image"]) - repeat_num = 4 - dataset = dataset.repeat(repeat_num) - batch_size = 2 - dataset = dataset.batch(batch_size, drop_remainder=True) - - num = 0 - for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True): - num += 1 - # 4 x 5 / 2 - assert num == 10 - - -def test_sbu_basic(): - """ - Feature: SBUDataset - Description: Test SBUDataset basic usage - Expectation: The dataset is processed as expected - """ - logger.info("Test SBUDataset Op") - - # case 1: test loading whole dataset - dataset = ds.SBUDataset(DATA_DIR, decode=True) - num_iter = 0 - for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True): - num_iter += 1 - assert num_iter == 5 - - - # case 2: test num_samples - dataset = ds.SBUDataset(DATA_DIR, decode=True, num_samples=5) - num_iter = 0 - for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True): - num_iter += 1 - assert num_iter == 5 - - # case 3: test repeat - dataset = ds.SBUDataset(DATA_DIR, decode=True, num_samples=5) - dataset = dataset.repeat(5) - num_iter = 0 - for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True): - num_iter += 1 - assert num_iter == 25 - - # case 4: test batch - dataset = ds.SBUDataset(DATA_DIR, decode=True, num_samples=5) - assert dataset.get_dataset_size() == 5 - assert dataset.get_batch_size() == 1 - - num_iter = 0 - for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True): - num_iter += 1 - assert num_iter == 5 - - # case 5: test get_class_indexing - dataset = ds.SBUDataset(DATA_DIR, decode=True, num_samples=5) - assert dataset.get_class_indexing() == {} - - # case 6: test get_col_names - dataset = ds.SBUDataset(DATA_DIR, decode=True, num_samples=5) - assert dataset.get_col_names() == ["image", "caption"] - - -def test_sbu_sequential_sampler(): - """ - Feature: SBUDataset - Description: Test SBUDataset wtih SequentialSampler - Expectation: The dataset is processed as expected - """ - logger.info("Test SBUDataset Op with SequentialSampler") - num_samples = 5 - sampler = ds.SequentialSampler(num_samples=num_samples) - dataset_1 = ds.SBUDataset(DATA_DIR, decode=True, sampler=sampler) - dataset_2 = ds.SBUDataset(DATA_DIR, decode=True, shuffle=False, num_samples=num_samples) - - num_iter = 0 - for item1, item2 in zip(dataset_1.create_dict_iterator(num_epochs=1, output_numpy=True), - dataset_2.create_dict_iterator(num_epochs=1, output_numpy=True)): - np.testing.assert_array_equal(item1["caption"], item2["caption"]) - num_iter += 1 - assert num_iter == num_samples - - -def test_sbu_exception(): - """ - Feature: SBUDataset - Description: Test error cases for SBUDataset - Expectation: Correct error is thrown as expected - """ - logger.info("Test error cases for SBUDataset") - error_msg_1 = "sampler and shuffle cannot be specified at the same time" - with pytest.raises(RuntimeError, match=error_msg_1): - ds.SBUDataset(DATA_DIR, decode=True, shuffle=False, sampler=ds.SequentialSampler()) - - error_msg_2 = "sampler and sharding cannot be specified at the same time" - with pytest.raises(RuntimeError, match=error_msg_2): - ds.SBUDataset(DATA_DIR, decode=True, sampler=ds.SequentialSampler(), num_shards=2, shard_id=0) - - error_msg_3 = "num_shards is specified and currently requires shard_id as well" - with pytest.raises(RuntimeError, match=error_msg_3): - ds.SBUDataset(DATA_DIR, decode=True, num_shards=10) - - error_msg_4 = "shard_id is specified but num_shards is not" - with pytest.raises(RuntimeError, match=error_msg_4): - ds.SBUDataset(DATA_DIR, decode=True, shard_id=0) - - error_msg_5 = "Input shard_id is not within the required interval" - with pytest.raises(ValueError, match=error_msg_5): - ds.SBUDataset(DATA_DIR, decode=True, num_shards=5, shard_id=-1) - with pytest.raises(ValueError, match=error_msg_5): - ds.SBUDataset(DATA_DIR, decode=True, num_shards=5, shard_id=5) - with pytest.raises(ValueError, match=error_msg_5): - ds.SBUDataset(DATA_DIR, decode=True, num_shards=2, shard_id=5) - - error_msg_6 = "num_parallel_workers exceeds" - with pytest.raises(ValueError, match=error_msg_6): - ds.SBUDataset(DATA_DIR, decode=True, shuffle=False, num_parallel_workers=0) - with pytest.raises(ValueError, match=error_msg_6): - ds.SBUDataset(DATA_DIR, decode=True, shuffle=False, num_parallel_workers=256) - with pytest.raises(ValueError, match=error_msg_6): - ds.SBUDataset(DATA_DIR, decode=True, shuffle=False, num_parallel_workers=-2) - - error_msg_7 = "Argument shard_id" - with pytest.raises(TypeError, match=error_msg_7): - ds.SBUDataset(DATA_DIR, decode=True, num_shards=2, shard_id="0") - - def exception_func(item): - raise Exception("Error occur!") - - error_msg_8 = "The corresponding data file is" - with pytest.raises(RuntimeError, match=error_msg_8): - dataset = ds.SBUDataset(DATA_DIR, decode=True) - dataset = dataset.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1) - for _ in dataset.__iter__(): - pass - - with pytest.raises(RuntimeError, match=error_msg_8): - dataset = ds.SBUDataset(DATA_DIR, decode=True) - dataset = dataset.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1) - for _ in dataset.__iter__(): - pass - - error_msg_9 = "does not exist or permission denied" - with pytest.raises(ValueError, match=error_msg_9): - dataset = ds.SBUDataset(WRONG_DIR, decode=True) - for _ in dataset.__iter__(): - pass - - error_msg_10 = "Argument decode with value" - with pytest.raises(TypeError, match=error_msg_10): - dataset = ds.SBUDataset(DATA_DIR, decode="not_bool") - for _ in dataset.__iter__(): - pass - - -def test_sbu_visualize(plot=False): - """ - Feature: SBUDataset - Description: Test SBUDataset visualized results - Expectation: The dataset is processed as expected - """ - logger.info("Test SBUDataset visualization") - - dataset = ds.SBUDataset(DATA_DIR, decode=True, num_samples=10, shuffle=False) - num_iter = 0 - image_list, caption_list = [], [] - for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True): - image = item["image"] - caption = item["caption"] - image_list.append(image) - caption_list.append("caption {}".format(caption)) - assert isinstance(image, np.ndarray) - - assert image.dtype == np.uint8 - assert caption.dtype.type == np.str_ - num_iter += 1 - assert num_iter == 5 - if plot: - visualize_dataset(image_list, caption_list) - - -def test_sbu_decode(): - """ - Feature: SBUDataset - Description: Test SBUDataset image readings with decode flag - Expectation: The dataset is processed as expected - """ - logger.info("Test SBUDataset decode flag") - - sampler = ds.SequentialSampler(num_samples=50) - dataset = ds.SBUDataset(dataset_dir=DATA_DIR, decode=False, sampler=sampler) - dataset_1 = dataset.map(operations=[vision.Decode()], input_columns=["image"]) - - dataset_2 = ds.SBUDataset(dataset_dir=DATA_DIR, decode=True, sampler=sampler) - - num_iter = 0 - for item1, item2 in zip(dataset_1.create_dict_iterator(num_epochs=1, output_numpy=True), - dataset_2.create_dict_iterator(num_epochs=1, output_numpy=True)): - np.testing.assert_array_equal(item1["caption"], item2["caption"]) - num_iter += 1 - - assert num_iter == 5 - - -if __name__ == '__main__': - test_sbu_content_check() - test_sbu_basic() - test_sbu_case() - test_sbu_sequential_sampler() - test_sbu_exception() - test_sbu_visualize(plot=True) - test_sbu_decode() +# Copyright 2021-2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Test USPS dataset operations +""" +import os + +import matplotlib.pyplot as plt +import numpy as np +import pytest +from PIL import Image + +import mindspore.dataset as ds +import mindspore.dataset.vision as vision +from mindspore import log as logger + +DATA_DIR = "../data/dataset/testSBUDataset" +WRONG_DIR = "../data/dataset/testMnistData" + + +def load_sbu(path): + """ + load SBU data + """ + images = [] + captions = [] + + file1 = os.path.realpath(os.path.join(path, 'SBU_captioned_photo_dataset_urls.txt')) + file2 = os.path.realpath(os.path.join(path, 'SBU_captioned_photo_dataset_captions.txt')) + + for line1, line2 in zip(open(file1), open(file2)): + url = line1.rstrip() + image = url[23:].replace("/", "_") + filename = os.path.join(path, 'sbu_images', image) + if os.path.exists(filename): + caption = line2.rstrip() + images.append(np.asarray(Image.open(filename).convert('RGB')).astype(np.uint8)) + captions.append(caption) + return images, captions + + +def visualize_dataset(images, captions): + """ + Helper function to visualize the dataset samples + """ + num_samples = len(images) + for i in range(num_samples): + plt.subplot(1, num_samples, i + 1) + plt.imshow(images[i].squeeze()) + plt.title(captions[i]) + plt.show() + + +def test_sbu_content_check(): + """ + Feature: SBUDataset + Description: Test SBUDataset image readings with content check + Expectation: The dataset is processed as expected + """ + logger.info("Test SBUDataset Op with content check") + dataset = ds.SBUDataset(DATA_DIR, decode=True, num_samples=50, shuffle=False) + images, captions = load_sbu(DATA_DIR) + num_iter = 0 + # in this example, each dictionary has keys "image" and "caption" + for i, data in enumerate(dataset.create_dict_iterator(num_epochs=1, output_numpy=True)): + assert data["image"].shape == images[i].shape + assert data["caption"] == captions[i] + num_iter += 1 + assert num_iter == 5 + + +def test_sbu_case(): + """ + Feature: SBUDataset + Description: Test SBUDataset cases + Expectation: The dataset is processed as expected + """ + dataset = ds.SBUDataset(DATA_DIR, decode=True) + + dataset = dataset.map(operations=[vision.Resize((224, 224))], input_columns=["image"]) + repeat_num = 4 + dataset = dataset.repeat(repeat_num) + batch_size = 2 + dataset = dataset.batch(batch_size, drop_remainder=True) + + num = 0 + for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True): + num += 1 + # 4 x 5 / 2 + assert num == 10 + + dataset = ds.SBUDataset(DATA_DIR, decode=False) + + dataset = dataset.map(operations=[vision.Decode(), vision.Resize((224, 224))], input_columns=["image"]) + repeat_num = 4 + dataset = dataset.repeat(repeat_num) + batch_size = 2 + dataset = dataset.batch(batch_size, drop_remainder=True) + + num = 0 + for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True): + num += 1 + # 4 x 5 / 2 + assert num == 10 + + +def test_sbu_basic(): + """ + Feature: SBUDataset + Description: Test SBUDataset basic usage + Expectation: The dataset is processed as expected + """ + logger.info("Test SBUDataset Op") + + # case 1: test loading whole dataset + dataset = ds.SBUDataset(DATA_DIR, decode=True) + num_iter = 0 + for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True): + num_iter += 1 + assert num_iter == 5 + + + # case 2: test num_samples + dataset = ds.SBUDataset(DATA_DIR, decode=True, num_samples=5) + num_iter = 0 + for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True): + num_iter += 1 + assert num_iter == 5 + + # case 3: test repeat + dataset = ds.SBUDataset(DATA_DIR, decode=True, num_samples=5) + dataset = dataset.repeat(5) + num_iter = 0 + for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True): + num_iter += 1 + assert num_iter == 25 + + # case 4: test batch + dataset = ds.SBUDataset(DATA_DIR, decode=True, num_samples=5) + assert dataset.get_dataset_size() == 5 + assert dataset.get_batch_size() == 1 + + num_iter = 0 + for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True): + num_iter += 1 + assert num_iter == 5 + + # case 5: test get_class_indexing + dataset = ds.SBUDataset(DATA_DIR, decode=True, num_samples=5) + assert dataset.get_class_indexing() == {} + + # case 6: test get_col_names + dataset = ds.SBUDataset(DATA_DIR, decode=True, num_samples=5) + assert dataset.get_col_names() == ["image", "caption"] + + +def test_sbu_sequential_sampler(): + """ + Feature: SBUDataset + Description: Test SBUDataset wtih SequentialSampler + Expectation: The dataset is processed as expected + """ + logger.info("Test SBUDataset Op with SequentialSampler") + num_samples = 5 + sampler = ds.SequentialSampler(num_samples=num_samples) + dataset_1 = ds.SBUDataset(DATA_DIR, decode=True, sampler=sampler) + dataset_2 = ds.SBUDataset(DATA_DIR, decode=True, shuffle=False, num_samples=num_samples) + + num_iter = 0 + for item1, item2 in zip(dataset_1.create_dict_iterator(num_epochs=1, output_numpy=True), + dataset_2.create_dict_iterator(num_epochs=1, output_numpy=True)): + np.testing.assert_array_equal(item1["caption"], item2["caption"]) + num_iter += 1 + assert num_iter == num_samples + + +def test_sbu_exception(): + """ + Feature: SBUDataset + Description: Test error cases for SBUDataset + Expectation: Correct error is thrown as expected + """ + logger.info("Test error cases for SBUDataset") + error_msg_1 = "sampler and shuffle cannot be specified at the same time" + with pytest.raises(RuntimeError, match=error_msg_1): + ds.SBUDataset(DATA_DIR, decode=True, shuffle=False, sampler=ds.SequentialSampler()) + + error_msg_2 = "sampler and sharding cannot be specified at the same time" + with pytest.raises(RuntimeError, match=error_msg_2): + ds.SBUDataset(DATA_DIR, decode=True, sampler=ds.SequentialSampler(), num_shards=2, shard_id=0) + + error_msg_3 = "num_shards is specified and currently requires shard_id as well" + with pytest.raises(RuntimeError, match=error_msg_3): + ds.SBUDataset(DATA_DIR, decode=True, num_shards=10) + + error_msg_4 = "shard_id is specified but num_shards is not" + with pytest.raises(RuntimeError, match=error_msg_4): + ds.SBUDataset(DATA_DIR, decode=True, shard_id=0) + + error_msg_5 = "Input shard_id is not within the required interval" + with pytest.raises(ValueError, match=error_msg_5): + ds.SBUDataset(DATA_DIR, decode=True, num_shards=5, shard_id=-1) + with pytest.raises(ValueError, match=error_msg_5): + ds.SBUDataset(DATA_DIR, decode=True, num_shards=5, shard_id=5) + with pytest.raises(ValueError, match=error_msg_5): + ds.SBUDataset(DATA_DIR, decode=True, num_shards=2, shard_id=5) + + error_msg_6 = "num_parallel_workers exceeds" + with pytest.raises(ValueError, match=error_msg_6): + ds.SBUDataset(DATA_DIR, decode=True, shuffle=False, num_parallel_workers=0) + with pytest.raises(ValueError, match=error_msg_6): + ds.SBUDataset(DATA_DIR, decode=True, shuffle=False, num_parallel_workers=256) + with pytest.raises(ValueError, match=error_msg_6): + ds.SBUDataset(DATA_DIR, decode=True, shuffle=False, num_parallel_workers=-2) + + error_msg_7 = "Argument shard_id" + with pytest.raises(TypeError, match=error_msg_7): + ds.SBUDataset(DATA_DIR, decode=True, num_shards=2, shard_id="0") + + def exception_func(item): + raise Exception("Error occur!") + + error_msg_8 = "The corresponding data file is" + with pytest.raises(RuntimeError, match=error_msg_8): + dataset = ds.SBUDataset(DATA_DIR, decode=True) + dataset = dataset.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1) + for _ in dataset.__iter__(): + pass + + with pytest.raises(RuntimeError, match=error_msg_8): + dataset = ds.SBUDataset(DATA_DIR, decode=True) + dataset = dataset.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1) + for _ in dataset.__iter__(): + pass + + error_msg_9 = "does not exist or permission denied" + with pytest.raises(ValueError, match=error_msg_9): + dataset = ds.SBUDataset(WRONG_DIR, decode=True) + for _ in dataset.__iter__(): + pass + + error_msg_10 = "Argument decode with value" + with pytest.raises(TypeError, match=error_msg_10): + dataset = ds.SBUDataset(DATA_DIR, decode="not_bool") + for _ in dataset.__iter__(): + pass + + +def test_sbu_visualize(plot=False): + """ + Feature: SBUDataset + Description: Test SBUDataset visualized results + Expectation: The dataset is processed as expected + """ + logger.info("Test SBUDataset visualization") + + dataset = ds.SBUDataset(DATA_DIR, decode=True, num_samples=10, shuffle=False) + num_iter = 0 + image_list, caption_list = [], [] + for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True): + image = item["image"] + caption = item["caption"] + image_list.append(image) + caption_list.append("caption {}".format(caption)) + assert isinstance(image, np.ndarray) + + assert image.dtype == np.uint8 + assert caption.dtype.type == np.str_ + num_iter += 1 + assert num_iter == 5 + if plot: + visualize_dataset(image_list, caption_list) + + +def test_sbu_decode(): + """ + Feature: SBUDataset + Description: Test SBUDataset image readings with decode flag + Expectation: The dataset is processed as expected + """ + logger.info("Test SBUDataset decode flag") + + sampler = ds.SequentialSampler(num_samples=50) + dataset = ds.SBUDataset(dataset_dir=DATA_DIR, decode=False, sampler=sampler) + dataset_1 = dataset.map(operations=[vision.Decode()], input_columns=["image"]) + + dataset_2 = ds.SBUDataset(dataset_dir=DATA_DIR, decode=True, sampler=sampler) + + num_iter = 0 + for item1, item2 in zip(dataset_1.create_dict_iterator(num_epochs=1, output_numpy=True), + dataset_2.create_dict_iterator(num_epochs=1, output_numpy=True)): + np.testing.assert_array_equal(item1["caption"], item2["caption"]) + num_iter += 1 + + assert num_iter == 5 + + +if __name__ == '__main__': + test_sbu_content_check() + test_sbu_basic() + test_sbu_case() + test_sbu_sequential_sampler() + test_sbu_exception() + test_sbu_visualize(plot=True) + test_sbu_decode() diff --git a/tests/ut/python/dataset/test_datasets_sogou_news.py b/tests/ut/python/dataset/test_datasets_sogou_news.py index 42e6a5080f7..3f923fa8977 100644 --- a/tests/ut/python/dataset/test_datasets_sogou_news.py +++ b/tests/ut/python/dataset/test_datasets_sogou_news.py @@ -1,185 +1,185 @@ -# Copyright 2021-2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -import mindspore.dataset as ds - -DATA_SOGOU_NEWS_DIR = '../data/dataset/testSogouNews/' - - -def test_sogou_news_dataset_basic(): - """ - Feature: Test SogouNews Dataset. - Description: Read data from a test.csv file. - Expectation: The data is processed successfully. - """ - buffer = [] - data = ds.SogouNewsDataset(DATA_SOGOU_NEWS_DIR, usage='test', shuffle=False) - data = data.repeat(2) - data = data.skip(2) - for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): - buffer.append(d) - assert len(buffer) == 4 - - -def test_sogou_news_dataset_all(): - """ - Feature: Test SogouNews Dataset. - Description: Read data from a test.csv and train.csv file. - Expectation: The data is processed successfully. - """ - data = ds.SogouNewsDataset(DATA_SOGOU_NEWS_DIR, usage='all', shuffle=False) - buffer = [] - for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): - buffer.extend([d['index'], - d['title'], - d['content']]) - assert buffer == ["1", "Jefferson commented on thick eyebrow: he has the top five talents in the league, but he " - "is not the top five", "They say he has the talent of the top five in the league. The talent " - "of the top five in the league is one of the most disrespectful statements. I say he has the " - "talent of the top five in the league, but he is not the top five players because the top five " - "players play every night.", - "1", "Make history", "Su Bingtian's 100m breakthrough\\n 9.83", - "3", "Group pictures: Liu Shishi's temperament in early autumn released a large piece of micro " - "curly long hair, elegant, lazy, gentle and capable", "Liu Shishi's latest group of cover " - "magazine blockbusters are released. In the photos, Liu Shishi's long hair is slightly curly, " - "or camel colored belted woolen coat, or plaid suit, which is gentle and elegant and beautiful " - "to a new height.", - "4", "Tesla price", "Tesla reduced its price by 70000 yuan", - "3", "Ni Ni deduces elegant retro style in different styles", "Ni Ni's latest group of magazine " - "cover blockbusters released that wearing gift hats is cool, retro, unique and full of fashion " - "expression.", - "1", "Opening ceremony of the 14th National Games", "On the evening of September 15, Beijing " - "time, the 14th games of the people's Republic of China opened in Xi'an Olympic Sports Center " - "Stadium, Shaanxi Province. Yang Qian, the first gold medalist of the Chinese delegation in " - "the Tokyo Olympic Games and a Post-00 shooter, lit the main torch platform. From then on, " - "to September 27, the 14th National Games flame will burn here for 12 days."] - - -def test_sogou_news_dataset_quoted(): - """ - Feature: Test get the SogouNews Dataset. - Description: Read SogouNewsDataset data and get data. - Expectation: The data is processed successfully. - """ - data = ds.SogouNewsDataset(DATA_SOGOU_NEWS_DIR, usage='test', shuffle=False) - buffer = [] - for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): - buffer.extend([d['index'], - d['title'], - d['content']]) - assert buffer == ["1", "Make history", "Su Bingtian's 100m breakthrough\\n 9.83", - "4", "Tesla price", "Tesla reduced its price by 70000 yuan", - "1", "Opening ceremony of the 14th National Games", "On the evening of September 15, Beijing time" - ", the 14th games of the people's Republic of China opened in Xi'an Olympic Sports Center " - "Stadium, Shaanxi Province. Yang Qian, the first gold medalist of the Chinese delegation in the" - " Tokyo Olympic Games and a Post-00 shooter, lit the main torch platform. From then on, to " - "September 27, the 14th National Games flame will burn here for 12 days."] - - -def test_sogou_news_dataset_usage_all(): - """ - Feature: Test SogouNews Dataset(usage=all). - Description: Read train data and test data. - Expectation: The data is processed successfully. - """ - buffer = [] - data = ds.SogouNewsDataset(DATA_SOGOU_NEWS_DIR, usage='all', shuffle=False) - for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): - buffer.append(d) - assert len(buffer) == 6 - - -def test_sogou_news_dataset_get_datasetsize(): - """ - Feature: Test Getters. - Description: Test get_dataset_size of SogouNews dataset. - Expectation: The data is processed successfully. - """ - data = ds.SogouNewsDataset(DATA_SOGOU_NEWS_DIR, usage='test', shuffle=False) - size = data.get_dataset_size() - assert size == 3 - - -def test_sogou_news_dataset_distribution(): - """ - Feature: Test SogouNewsDataset in distribution. - Description: Test in a distributed state. - Expectation: The data is processed successfully. - """ - data = ds.SogouNewsDataset(DATA_SOGOU_NEWS_DIR, usage='test', shuffle=False, num_shards=2, shard_id=0) - count = 0 - for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): - count += 1 - assert count == 2 - - -def test_sogou_news_dataset_num_samples(): - """ - Feature: Test SogouNews Dataset(num_samples = 2). - Description: Test get num_samples. - Expectation: The data is processed successfully. - """ - data = ds.SogouNewsDataset(DATA_SOGOU_NEWS_DIR, usage='test', shuffle=False, num_samples=2) - count = 0 - for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): - count += 1 - assert count == 2 - - -def test_sogou_news_dataset_exception(): - """ - Feature: Error Test. - Description: Test the wrong input. - Expectation: Unable to read in data. - """ - def exception_func(item): - raise Exception("Error occur!") - - try: - data = ds.SogouNewsDataset(DATA_SOGOU_NEWS_DIR, usage='test', shuffle=False) - data = data.map(operations=exception_func, input_columns=["index"], num_parallel_workers=1) - for _ in data.create_dict_iterator(): - pass - assert False - except RuntimeError as e: - assert "map operation: [PyFunc] failed. The corresponding data file is" in str(e) - - try: - data = ds.SogouNewsDataset(DATA_SOGOU_NEWS_DIR, usage='test', shuffle=False) - data = data.map(operations=exception_func, input_columns=["title"], num_parallel_workers=1) - for _ in data.create_dict_iterator(): - pass - assert False - except RuntimeError as e: - assert "map operation: [PyFunc] failed. The corresponding data file is" in str(e) - - try: - data = ds.SogouNewsDataset(DATA_SOGOU_NEWS_DIR, usage='test', shuffle=False) - data = data.map(operations=exception_func, input_columns=["content"], num_parallel_workers=1) - for _ in data.create_dict_iterator(): - pass - assert False - except RuntimeError as e: - assert "map operation: [PyFunc] failed. The corresponding data file is" in str(e) - - -if __name__ == "__main__": - test_sogou_news_dataset_basic() - test_sogou_news_dataset_all() - test_sogou_news_dataset_quoted() - test_sogou_news_dataset_usage_all() - test_sogou_news_dataset_get_datasetsize() - test_sogou_news_dataset_distribution() - test_sogou_news_dataset_num_samples() - test_sogou_news_dataset_exception() +# Copyright 2021-2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import mindspore.dataset as ds + +DATA_SOGOU_NEWS_DIR = '../data/dataset/testSogouNews/' + + +def test_sogou_news_dataset_basic(): + """ + Feature: Test SogouNews Dataset. + Description: Read data from a test.csv file. + Expectation: The data is processed successfully. + """ + buffer = [] + data = ds.SogouNewsDataset(DATA_SOGOU_NEWS_DIR, usage='test', shuffle=False) + data = data.repeat(2) + data = data.skip(2) + for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): + buffer.append(d) + assert len(buffer) == 4 + + +def test_sogou_news_dataset_all(): + """ + Feature: Test SogouNews Dataset. + Description: Read data from a test.csv and train.csv file. + Expectation: The data is processed successfully. + """ + data = ds.SogouNewsDataset(DATA_SOGOU_NEWS_DIR, usage='all', shuffle=False) + buffer = [] + for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): + buffer.extend([d['index'], + d['title'], + d['content']]) + assert buffer == ["1", "Jefferson commented on thick eyebrow: he has the top five talents in the league, but he " + "is not the top five", "They say he has the talent of the top five in the league. The talent " + "of the top five in the league is one of the most disrespectful statements. I say he has the " + "talent of the top five in the league, but he is not the top five players because the top five " + "players play every night.", + "1", "Make history", "Su Bingtian's 100m breakthrough\\n 9.83", + "3", "Group pictures: Liu Shishi's temperament in early autumn released a large piece of micro " + "curly long hair, elegant, lazy, gentle and capable", "Liu Shishi's latest group of cover " + "magazine blockbusters are released. In the photos, Liu Shishi's long hair is slightly curly, " + "or camel colored belted woolen coat, or plaid suit, which is gentle and elegant and beautiful " + "to a new height.", + "4", "Tesla price", "Tesla reduced its price by 70000 yuan", + "3", "Ni Ni deduces elegant retro style in different styles", "Ni Ni's latest group of magazine " + "cover blockbusters released that wearing gift hats is cool, retro, unique and full of fashion " + "expression.", + "1", "Opening ceremony of the 14th National Games", "On the evening of September 15, Beijing " + "time, the 14th games of the people's Republic of China opened in Xi'an Olympic Sports Center " + "Stadium, Shaanxi Province. Yang Qian, the first gold medalist of the Chinese delegation in " + "the Tokyo Olympic Games and a Post-00 shooter, lit the main torch platform. From then on, " + "to September 27, the 14th National Games flame will burn here for 12 days."] + + +def test_sogou_news_dataset_quoted(): + """ + Feature: Test get the SogouNews Dataset. + Description: Read SogouNewsDataset data and get data. + Expectation: The data is processed successfully. + """ + data = ds.SogouNewsDataset(DATA_SOGOU_NEWS_DIR, usage='test', shuffle=False) + buffer = [] + for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): + buffer.extend([d['index'], + d['title'], + d['content']]) + assert buffer == ["1", "Make history", "Su Bingtian's 100m breakthrough\\n 9.83", + "4", "Tesla price", "Tesla reduced its price by 70000 yuan", + "1", "Opening ceremony of the 14th National Games", "On the evening of September 15, Beijing time" + ", the 14th games of the people's Republic of China opened in Xi'an Olympic Sports Center " + "Stadium, Shaanxi Province. Yang Qian, the first gold medalist of the Chinese delegation in the" + " Tokyo Olympic Games and a Post-00 shooter, lit the main torch platform. From then on, to " + "September 27, the 14th National Games flame will burn here for 12 days."] + + +def test_sogou_news_dataset_usage_all(): + """ + Feature: Test SogouNews Dataset(usage=all). + Description: Read train data and test data. + Expectation: The data is processed successfully. + """ + buffer = [] + data = ds.SogouNewsDataset(DATA_SOGOU_NEWS_DIR, usage='all', shuffle=False) + for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): + buffer.append(d) + assert len(buffer) == 6 + + +def test_sogou_news_dataset_get_datasetsize(): + """ + Feature: Test Getters. + Description: Test get_dataset_size of SogouNews dataset. + Expectation: The data is processed successfully. + """ + data = ds.SogouNewsDataset(DATA_SOGOU_NEWS_DIR, usage='test', shuffle=False) + size = data.get_dataset_size() + assert size == 3 + + +def test_sogou_news_dataset_distribution(): + """ + Feature: Test SogouNewsDataset in distribution. + Description: Test in a distributed state. + Expectation: The data is processed successfully. + """ + data = ds.SogouNewsDataset(DATA_SOGOU_NEWS_DIR, usage='test', shuffle=False, num_shards=2, shard_id=0) + count = 0 + for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): + count += 1 + assert count == 2 + + +def test_sogou_news_dataset_num_samples(): + """ + Feature: Test SogouNews Dataset(num_samples = 2). + Description: Test get num_samples. + Expectation: The data is processed successfully. + """ + data = ds.SogouNewsDataset(DATA_SOGOU_NEWS_DIR, usage='test', shuffle=False, num_samples=2) + count = 0 + for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): + count += 1 + assert count == 2 + + +def test_sogou_news_dataset_exception(): + """ + Feature: Error Test. + Description: Test the wrong input. + Expectation: Unable to read in data. + """ + def exception_func(item): + raise Exception("Error occur!") + + try: + data = ds.SogouNewsDataset(DATA_SOGOU_NEWS_DIR, usage='test', shuffle=False) + data = data.map(operations=exception_func, input_columns=["index"], num_parallel_workers=1) + for _ in data.create_dict_iterator(): + pass + assert False + except RuntimeError as e: + assert "map operation: [PyFunc] failed. The corresponding data file is" in str(e) + + try: + data = ds.SogouNewsDataset(DATA_SOGOU_NEWS_DIR, usage='test', shuffle=False) + data = data.map(operations=exception_func, input_columns=["title"], num_parallel_workers=1) + for _ in data.create_dict_iterator(): + pass + assert False + except RuntimeError as e: + assert "map operation: [PyFunc] failed. The corresponding data file is" in str(e) + + try: + data = ds.SogouNewsDataset(DATA_SOGOU_NEWS_DIR, usage='test', shuffle=False) + data = data.map(operations=exception_func, input_columns=["content"], num_parallel_workers=1) + for _ in data.create_dict_iterator(): + pass + assert False + except RuntimeError as e: + assert "map operation: [PyFunc] failed. The corresponding data file is" in str(e) + + +if __name__ == "__main__": + test_sogou_news_dataset_basic() + test_sogou_news_dataset_all() + test_sogou_news_dataset_quoted() + test_sogou_news_dataset_usage_all() + test_sogou_news_dataset_get_datasetsize() + test_sogou_news_dataset_distribution() + test_sogou_news_dataset_num_samples() + test_sogou_news_dataset_exception() diff --git a/tests/ut/python/dataset/test_datasets_usps.py b/tests/ut/python/dataset/test_datasets_usps.py index 65e7c84638e..0cdd58ffd9a 100644 --- a/tests/ut/python/dataset/test_datasets_usps.py +++ b/tests/ut/python/dataset/test_datasets_usps.py @@ -1,337 +1,337 @@ -# Copyright 2021-2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -""" -Test USPS dataset operations -""" -import os -from typing import cast - -import matplotlib.pyplot as plt -import numpy as np -import pytest - -import mindspore.dataset as ds -import mindspore.dataset.transforms as transforms -import mindspore.dataset.vision as vision -from mindspore import log as logger - -DATA_DIR = "../data/dataset/testUSPSDataset" -WRONG_DIR = "../data/dataset/testMnistData" - - -def load_usps(path, usage): - """ - load USPS data - """ - assert usage in ["train", "test"] - if usage == "train": - data_path = os.path.realpath(os.path.join(path, "usps")) - elif usage == "test": - data_path = os.path.realpath(os.path.join(path, "usps.t")) - - with open(data_path, 'r') as f: - raw_data = [line.split() for line in f.readlines()] - tmp_list = [[x.split(':')[-1] for x in data[1:]] for data in raw_data] - images = np.asarray(tmp_list, dtype=np.float32).reshape((-1, 16, 16, 1)) - images = ((cast(np.ndarray, images) + 1) / 2 * 255).astype(dtype=np.uint8) - labels = [int(d[0]) - 1 for d in raw_data] - return images, labels - - -def visualize_dataset(images, labels): - """ - Helper function to visualize the dataset samples - """ - num_samples = len(images) - for i in range(num_samples): - plt.subplot(1, num_samples, i + 1) - plt.imshow(images[i].squeeze(), cmap=plt.cm.gray) - plt.title(labels[i]) - plt.show() - - -def test_usps_content_check(): - """ - Feature: USPSDataset - Description: Test USPSDataset image readings with content check - Expectation: The dataset is processed as expected - """ - logger.info("Test USPSDataset Op with content check") - train_data = ds.USPSDataset(DATA_DIR, "train", num_samples=10, shuffle=False) - images, labels = load_usps(DATA_DIR, "train") - num_iter = 0 - # in this example, each dictionary has keys "image" and "label" - for i, data in enumerate(train_data.create_dict_iterator(num_epochs=1, output_numpy=True)): - for m in range(16): - for n in range(16): - assert (data["image"][m, n, 0] != 0 or images[i][m, n, 0] != 255) and \ - (data["image"][m, n, 0] != 255 or images[i][m, n, 0] != 0) - assert (data["image"][m, n, 0] == images[i][m, n, 0]) or\ - (data["image"][m, n, 0] == images[i][m, n, 0] + 1) or\ - (data["image"][m, n, 0] + 1 == images[i][m, n, 0]) - np.testing.assert_array_equal(data["label"], labels[i]) - num_iter += 1 - assert num_iter == 3 - - test_data = ds.USPSDataset(DATA_DIR, "test", num_samples=3, shuffle=False) - images, labels = load_usps(DATA_DIR, "test") - num_iter = 0 - # in this example, each dictionary has keys "image" and "label" - for i, data in enumerate(test_data.create_dict_iterator(num_epochs=1, output_numpy=True)): - for m in range(16): - for n in range(16): - if (data["image"][m, n, 0] == 0 and images[i][m, n, 0] == 255) or\ - (data["image"][m, n, 0] == 255 and images[i][m, n, 0] == 0): - assert False - if (data["image"][m, n, 0] != images[i][m, n, 0]) and\ - (data["image"][m, n, 0] != images[i][m, n, 0] + 1) and\ - (data["image"][m, n, 0] + 1 != images[i][m, n, 0]): - assert False - np.testing.assert_array_equal(data["label"], labels[i]) - num_iter += 1 - assert num_iter == 3 - - -def test_usps_basic(): - """ - Feature: USPSDataset - Description: Test USPSDataset basic usage - Expectation: The dataset is processed as expected - """ - logger.info("Test USPSDataset Op") - - # case 1: test loading whole dataset - train_data = ds.USPSDataset(DATA_DIR, "train") - num_iter = 0 - for _ in train_data.create_dict_iterator(num_epochs=1): - num_iter += 1 - assert num_iter == 3 - - test_data = ds.USPSDataset(DATA_DIR, "test") - num_iter = 0 - for _ in test_data.create_dict_iterator(num_epochs=1): - num_iter += 1 - assert num_iter == 3 - - # case 2: test num_samples - train_data = ds.USPSDataset(DATA_DIR, "train", num_samples=2) - num_iter = 0 - for _ in train_data.create_dict_iterator(num_epochs=1): - num_iter += 1 - assert num_iter == 2 - - # case 3: test repeat - train_data = ds.USPSDataset(DATA_DIR, "train", num_samples=2) - train_data = train_data.repeat(5) - num_iter = 0 - for _ in train_data.create_dict_iterator(num_epochs=1): - num_iter += 1 - assert num_iter == 10 - - # case 4: test batch with drop_remainder=False - train_data = ds.USPSDataset(DATA_DIR, "train", num_samples=3) - assert train_data.get_dataset_size() == 3 - assert train_data.get_batch_size() == 1 - train_data = train_data.batch(batch_size=2) # drop_remainder is default to be False - assert train_data.get_batch_size() == 2 - assert train_data.get_dataset_size() == 2 - - num_iter = 0 - for _ in train_data.create_dict_iterator(num_epochs=1): - num_iter += 1 - assert num_iter == 2 - - # case 5: test batch with drop_remainder=True - train_data = ds.USPSDataset(DATA_DIR, "train", num_samples=3) - assert train_data.get_dataset_size() == 3 - assert train_data.get_batch_size() == 1 - train_data = train_data.batch(batch_size=2, drop_remainder=True) # the rest of incomplete batch will be dropped - assert train_data.get_dataset_size() == 1 - assert train_data.get_batch_size() == 2 - num_iter = 0 - for _ in train_data.create_dict_iterator(num_epochs=1): - num_iter += 1 - assert num_iter == 1 - - -def test_usps_exception(): - """ - Feature: USPSDataset - Description: Test error cases for USPSDataset - Expectation: Correct error is thrown as expected - """ - error_msg_3 = "num_shards is specified and currently requires shard_id as well" - with pytest.raises(RuntimeError, match=error_msg_3): - ds.USPSDataset(DATA_DIR, "train", num_shards=10) - ds.USPSDataset(DATA_DIR, "test", num_shards=10) - - error_msg_4 = "shard_id is specified but num_shards is not" - with pytest.raises(RuntimeError, match=error_msg_4): - ds.USPSDataset(DATA_DIR, "train", shard_id=0) - ds.USPSDataset(DATA_DIR, "test", shard_id=0) - - error_msg_5 = "Input shard_id is not within the required interval" - with pytest.raises(ValueError, match=error_msg_5): - ds.USPSDataset(DATA_DIR, "train", num_shards=5, shard_id=-1) - ds.USPSDataset(DATA_DIR, "test", num_shards=5, shard_id=-1) - with pytest.raises(ValueError, match=error_msg_5): - ds.USPSDataset(DATA_DIR, "train", num_shards=5, shard_id=5) - ds.USPSDataset(DATA_DIR, "test", num_shards=5, shard_id=5) - with pytest.raises(ValueError, match=error_msg_5): - ds.USPSDataset(DATA_DIR, "train", num_shards=2, shard_id=5) - ds.USPSDataset(DATA_DIR, "test", num_shards=2, shard_id=5) - - error_msg_6 = "num_parallel_workers exceeds" - with pytest.raises(ValueError, match=error_msg_6): - ds.USPSDataset(DATA_DIR, "train", shuffle=False, num_parallel_workers=0) - ds.USPSDataset(DATA_DIR, "test", shuffle=False, num_parallel_workers=0) - with pytest.raises(ValueError, match=error_msg_6): - ds.USPSDataset(DATA_DIR, "train", shuffle=False, num_parallel_workers=256) - ds.USPSDataset(DATA_DIR, "test", shuffle=False, num_parallel_workers=256) - with pytest.raises(ValueError, match=error_msg_6): - ds.USPSDataset(DATA_DIR, "train", shuffle=False, num_parallel_workers=-2) - ds.USPSDataset(DATA_DIR, "test", shuffle=False, num_parallel_workers=-2) - - error_msg_7 = "Argument shard_id" - with pytest.raises(TypeError, match=error_msg_7): - ds.USPSDataset(DATA_DIR, "train", num_shards=2, shard_id="0") - ds.USPSDataset(DATA_DIR, "test", num_shards=2, shard_id="0") - - error_msg_8 = "invalid input shape" - with pytest.raises(RuntimeError, match=error_msg_8): - train_data = ds.USPSDataset(DATA_DIR, "train") - train_data = train_data.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1) - for _ in train_data.__iter__(): - pass - - test_data = ds.USPSDataset(DATA_DIR, "test") - test_data = test_data.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1) - for _ in test_data.__iter__(): - pass - - error_msg_9 = "usps does not exist or is a directory" - with pytest.raises(RuntimeError, match=error_msg_9): - train_data = ds.USPSDataset(WRONG_DIR, "train") - for _ in train_data.__iter__(): - pass - error_msg_10 = "usps.t does not exist or is a directory" - with pytest.raises(RuntimeError, match=error_msg_10): - test_data = ds.USPSDataset(WRONG_DIR, "test") - for _ in test_data.__iter__(): - pass - - -def test_usps_visualize(plot=False): - """ - Feature: USPSDataset - Description: Test USPSDataset visualized results - Expectation: The dataset is processed as expected - """ - logger.info("Test USPSDataset visualization") - - train_data = ds.USPSDataset(DATA_DIR, "train", num_samples=3, shuffle=False) - num_iter = 0 - image_list, label_list = [], [] - for item in train_data.create_dict_iterator(num_epochs=1, output_numpy=True): - image = item["image"] - label = item["label"] - image_list.append(image) - label_list.append("label {}".format(label)) - assert isinstance(image, np.ndarray) - assert image.shape == (16, 16, 1) - assert image.dtype == np.uint8 - assert label.dtype == np.uint32 - num_iter += 1 - assert num_iter == 3 - if plot: - visualize_dataset(image_list, label_list) - - test_data = ds.USPSDataset(DATA_DIR, "test", num_samples=3, shuffle=False) - num_iter = 0 - image_list, label_list = [], [] - for item in test_data.create_dict_iterator(num_epochs=1, output_numpy=True): - image = item["image"] - label = item["label"] - image_list.append(image) - label_list.append("label {}".format(label)) - assert isinstance(image, np.ndarray) - assert image.shape == (16, 16, 1) - assert image.dtype == np.uint8 - assert label.dtype == np.uint32 - num_iter += 1 - assert num_iter == 3 - if plot: - visualize_dataset(image_list, label_list) - - -def test_usps_usage(): - """ - Feature: USPSDataset - Description: Test USPSDataset image readings with usage flag - Expectation: The dataset is processed as expected - """ - logger.info("Test USPSDataset usage flag") - - def test_config(usage, path=None): - path = DATA_DIR if path is None else path - try: - data = ds.USPSDataset(path, usage=usage, shuffle=False) - num_rows = 0 - for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): - num_rows += 1 - except (ValueError, TypeError, RuntimeError) as e: - return str(e) - return num_rows - - assert test_config("train") == 3 - assert test_config("test") == 3 - - assert "usage is not within the valid set of ['train', 'test', 'all']" in test_config("invalid") - assert "Argument usage with value ['list'] is not of type []" in test_config(["list"]) - - # change this directory to the folder that contains all USPS files - all_files_path = None - # the following tests on the entire datasets - if all_files_path is not None: - assert test_config("train", all_files_path) == 3 - assert test_config("test", all_files_path) == 3 - assert ds.USPSDataset(all_files_path, usage="train").get_dataset_size() == 3 - assert ds.USPSDataset(all_files_path, usage="test").get_dataset_size() == 3 - - -def test_usps_with_map(): - """ - Feature: USPSDataset - Description: Test doing map operation on USPSDataset - Expectation: The dataset is processed as expected - """ - dataset = ds.USPSDataset(DATA_DIR) - random_crop = vision.RandomCrop((10, 10)) - dataset = dataset.map(random_crop, input_columns=["image"]) - type_cast = transforms.TypeCast(np.float32) - dataset = dataset.map(type_cast, input_columns=["label"]) - count = 0 - for _ in dataset.create_dict_iterator(num_epochs=1): - count += 1 - assert count == 6 - - -if __name__ == '__main__': - test_usps_content_check() - test_usps_basic() - test_usps_exception() - test_usps_visualize(plot=True) - test_usps_usage() - test_usps_with_map() +# Copyright 2021-2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Test USPS dataset operations +""" +import os +from typing import cast + +import matplotlib.pyplot as plt +import numpy as np +import pytest + +import mindspore.dataset as ds +import mindspore.dataset.transforms as transforms +import mindspore.dataset.vision as vision +from mindspore import log as logger + +DATA_DIR = "../data/dataset/testUSPSDataset" +WRONG_DIR = "../data/dataset/testMnistData" + + +def load_usps(path, usage): + """ + load USPS data + """ + assert usage in ["train", "test"] + if usage == "train": + data_path = os.path.realpath(os.path.join(path, "usps")) + elif usage == "test": + data_path = os.path.realpath(os.path.join(path, "usps.t")) + + with open(data_path, 'r') as f: + raw_data = [line.split() for line in f.readlines()] + tmp_list = [[x.split(':')[-1] for x in data[1:]] for data in raw_data] + images = np.asarray(tmp_list, dtype=np.float32).reshape((-1, 16, 16, 1)) + images = ((cast(np.ndarray, images) + 1) / 2 * 255).astype(dtype=np.uint8) + labels = [int(d[0]) - 1 for d in raw_data] + return images, labels + + +def visualize_dataset(images, labels): + """ + Helper function to visualize the dataset samples + """ + num_samples = len(images) + for i in range(num_samples): + plt.subplot(1, num_samples, i + 1) + plt.imshow(images[i].squeeze(), cmap=plt.cm.gray) + plt.title(labels[i]) + plt.show() + + +def test_usps_content_check(): + """ + Feature: USPSDataset + Description: Test USPSDataset image readings with content check + Expectation: The dataset is processed as expected + """ + logger.info("Test USPSDataset Op with content check") + train_data = ds.USPSDataset(DATA_DIR, "train", num_samples=10, shuffle=False) + images, labels = load_usps(DATA_DIR, "train") + num_iter = 0 + # in this example, each dictionary has keys "image" and "label" + for i, data in enumerate(train_data.create_dict_iterator(num_epochs=1, output_numpy=True)): + for m in range(16): + for n in range(16): + assert (data["image"][m, n, 0] != 0 or images[i][m, n, 0] != 255) and \ + (data["image"][m, n, 0] != 255 or images[i][m, n, 0] != 0) + assert (data["image"][m, n, 0] == images[i][m, n, 0]) or\ + (data["image"][m, n, 0] == images[i][m, n, 0] + 1) or\ + (data["image"][m, n, 0] + 1 == images[i][m, n, 0]) + np.testing.assert_array_equal(data["label"], labels[i]) + num_iter += 1 + assert num_iter == 3 + + test_data = ds.USPSDataset(DATA_DIR, "test", num_samples=3, shuffle=False) + images, labels = load_usps(DATA_DIR, "test") + num_iter = 0 + # in this example, each dictionary has keys "image" and "label" + for i, data in enumerate(test_data.create_dict_iterator(num_epochs=1, output_numpy=True)): + for m in range(16): + for n in range(16): + if (data["image"][m, n, 0] == 0 and images[i][m, n, 0] == 255) or\ + (data["image"][m, n, 0] == 255 and images[i][m, n, 0] == 0): + assert False + if (data["image"][m, n, 0] != images[i][m, n, 0]) and\ + (data["image"][m, n, 0] != images[i][m, n, 0] + 1) and\ + (data["image"][m, n, 0] + 1 != images[i][m, n, 0]): + assert False + np.testing.assert_array_equal(data["label"], labels[i]) + num_iter += 1 + assert num_iter == 3 + + +def test_usps_basic(): + """ + Feature: USPSDataset + Description: Test USPSDataset basic usage + Expectation: The dataset is processed as expected + """ + logger.info("Test USPSDataset Op") + + # case 1: test loading whole dataset + train_data = ds.USPSDataset(DATA_DIR, "train") + num_iter = 0 + for _ in train_data.create_dict_iterator(num_epochs=1): + num_iter += 1 + assert num_iter == 3 + + test_data = ds.USPSDataset(DATA_DIR, "test") + num_iter = 0 + for _ in test_data.create_dict_iterator(num_epochs=1): + num_iter += 1 + assert num_iter == 3 + + # case 2: test num_samples + train_data = ds.USPSDataset(DATA_DIR, "train", num_samples=2) + num_iter = 0 + for _ in train_data.create_dict_iterator(num_epochs=1): + num_iter += 1 + assert num_iter == 2 + + # case 3: test repeat + train_data = ds.USPSDataset(DATA_DIR, "train", num_samples=2) + train_data = train_data.repeat(5) + num_iter = 0 + for _ in train_data.create_dict_iterator(num_epochs=1): + num_iter += 1 + assert num_iter == 10 + + # case 4: test batch with drop_remainder=False + train_data = ds.USPSDataset(DATA_DIR, "train", num_samples=3) + assert train_data.get_dataset_size() == 3 + assert train_data.get_batch_size() == 1 + train_data = train_data.batch(batch_size=2) # drop_remainder is default to be False + assert train_data.get_batch_size() == 2 + assert train_data.get_dataset_size() == 2 + + num_iter = 0 + for _ in train_data.create_dict_iterator(num_epochs=1): + num_iter += 1 + assert num_iter == 2 + + # case 5: test batch with drop_remainder=True + train_data = ds.USPSDataset(DATA_DIR, "train", num_samples=3) + assert train_data.get_dataset_size() == 3 + assert train_data.get_batch_size() == 1 + train_data = train_data.batch(batch_size=2, drop_remainder=True) # the rest of incomplete batch will be dropped + assert train_data.get_dataset_size() == 1 + assert train_data.get_batch_size() == 2 + num_iter = 0 + for _ in train_data.create_dict_iterator(num_epochs=1): + num_iter += 1 + assert num_iter == 1 + + +def test_usps_exception(): + """ + Feature: USPSDataset + Description: Test error cases for USPSDataset + Expectation: Correct error is thrown as expected + """ + error_msg_3 = "num_shards is specified and currently requires shard_id as well" + with pytest.raises(RuntimeError, match=error_msg_3): + ds.USPSDataset(DATA_DIR, "train", num_shards=10) + ds.USPSDataset(DATA_DIR, "test", num_shards=10) + + error_msg_4 = "shard_id is specified but num_shards is not" + with pytest.raises(RuntimeError, match=error_msg_4): + ds.USPSDataset(DATA_DIR, "train", shard_id=0) + ds.USPSDataset(DATA_DIR, "test", shard_id=0) + + error_msg_5 = "Input shard_id is not within the required interval" + with pytest.raises(ValueError, match=error_msg_5): + ds.USPSDataset(DATA_DIR, "train", num_shards=5, shard_id=-1) + ds.USPSDataset(DATA_DIR, "test", num_shards=5, shard_id=-1) + with pytest.raises(ValueError, match=error_msg_5): + ds.USPSDataset(DATA_DIR, "train", num_shards=5, shard_id=5) + ds.USPSDataset(DATA_DIR, "test", num_shards=5, shard_id=5) + with pytest.raises(ValueError, match=error_msg_5): + ds.USPSDataset(DATA_DIR, "train", num_shards=2, shard_id=5) + ds.USPSDataset(DATA_DIR, "test", num_shards=2, shard_id=5) + + error_msg_6 = "num_parallel_workers exceeds" + with pytest.raises(ValueError, match=error_msg_6): + ds.USPSDataset(DATA_DIR, "train", shuffle=False, num_parallel_workers=0) + ds.USPSDataset(DATA_DIR, "test", shuffle=False, num_parallel_workers=0) + with pytest.raises(ValueError, match=error_msg_6): + ds.USPSDataset(DATA_DIR, "train", shuffle=False, num_parallel_workers=256) + ds.USPSDataset(DATA_DIR, "test", shuffle=False, num_parallel_workers=256) + with pytest.raises(ValueError, match=error_msg_6): + ds.USPSDataset(DATA_DIR, "train", shuffle=False, num_parallel_workers=-2) + ds.USPSDataset(DATA_DIR, "test", shuffle=False, num_parallel_workers=-2) + + error_msg_7 = "Argument shard_id" + with pytest.raises(TypeError, match=error_msg_7): + ds.USPSDataset(DATA_DIR, "train", num_shards=2, shard_id="0") + ds.USPSDataset(DATA_DIR, "test", num_shards=2, shard_id="0") + + error_msg_8 = "invalid input shape" + with pytest.raises(RuntimeError, match=error_msg_8): + train_data = ds.USPSDataset(DATA_DIR, "train") + train_data = train_data.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1) + for _ in train_data.__iter__(): + pass + + test_data = ds.USPSDataset(DATA_DIR, "test") + test_data = test_data.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1) + for _ in test_data.__iter__(): + pass + + error_msg_9 = "usps does not exist or is a directory" + with pytest.raises(RuntimeError, match=error_msg_9): + train_data = ds.USPSDataset(WRONG_DIR, "train") + for _ in train_data.__iter__(): + pass + error_msg_10 = "usps.t does not exist or is a directory" + with pytest.raises(RuntimeError, match=error_msg_10): + test_data = ds.USPSDataset(WRONG_DIR, "test") + for _ in test_data.__iter__(): + pass + + +def test_usps_visualize(plot=False): + """ + Feature: USPSDataset + Description: Test USPSDataset visualized results + Expectation: The dataset is processed as expected + """ + logger.info("Test USPSDataset visualization") + + train_data = ds.USPSDataset(DATA_DIR, "train", num_samples=3, shuffle=False) + num_iter = 0 + image_list, label_list = [], [] + for item in train_data.create_dict_iterator(num_epochs=1, output_numpy=True): + image = item["image"] + label = item["label"] + image_list.append(image) + label_list.append("label {}".format(label)) + assert isinstance(image, np.ndarray) + assert image.shape == (16, 16, 1) + assert image.dtype == np.uint8 + assert label.dtype == np.uint32 + num_iter += 1 + assert num_iter == 3 + if plot: + visualize_dataset(image_list, label_list) + + test_data = ds.USPSDataset(DATA_DIR, "test", num_samples=3, shuffle=False) + num_iter = 0 + image_list, label_list = [], [] + for item in test_data.create_dict_iterator(num_epochs=1, output_numpy=True): + image = item["image"] + label = item["label"] + image_list.append(image) + label_list.append("label {}".format(label)) + assert isinstance(image, np.ndarray) + assert image.shape == (16, 16, 1) + assert image.dtype == np.uint8 + assert label.dtype == np.uint32 + num_iter += 1 + assert num_iter == 3 + if plot: + visualize_dataset(image_list, label_list) + + +def test_usps_usage(): + """ + Feature: USPSDataset + Description: Test USPSDataset image readings with usage flag + Expectation: The dataset is processed as expected + """ + logger.info("Test USPSDataset usage flag") + + def test_config(usage, path=None): + path = DATA_DIR if path is None else path + try: + data = ds.USPSDataset(path, usage=usage, shuffle=False) + num_rows = 0 + for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): + num_rows += 1 + except (ValueError, TypeError, RuntimeError) as e: + return str(e) + return num_rows + + assert test_config("train") == 3 + assert test_config("test") == 3 + + assert "usage is not within the valid set of ['train', 'test', 'all']" in test_config("invalid") + assert "Argument usage with value ['list'] is not of type []" in test_config(["list"]) + + # change this directory to the folder that contains all USPS files + all_files_path = None + # the following tests on the entire datasets + if all_files_path is not None: + assert test_config("train", all_files_path) == 3 + assert test_config("test", all_files_path) == 3 + assert ds.USPSDataset(all_files_path, usage="train").get_dataset_size() == 3 + assert ds.USPSDataset(all_files_path, usage="test").get_dataset_size() == 3 + + +def test_usps_with_map(): + """ + Feature: USPSDataset + Description: Test doing map operation on USPSDataset + Expectation: The dataset is processed as expected + """ + dataset = ds.USPSDataset(DATA_DIR) + random_crop = vision.RandomCrop((10, 10)) + dataset = dataset.map(random_crop, input_columns=["image"]) + type_cast = transforms.TypeCast(np.float32) + dataset = dataset.map(type_cast, input_columns=["label"]) + count = 0 + for _ in dataset.create_dict_iterator(num_epochs=1): + count += 1 + assert count == 6 + + +if __name__ == '__main__': + test_usps_content_check() + test_usps_basic() + test_usps_exception() + test_usps_visualize(plot=True) + test_usps_usage() + test_usps_with_map() diff --git a/tests/ut/python/dataset/test_encode_jpeg.py b/tests/ut/python/dataset/test_encode_jpeg.py old mode 100755 new mode 100644 diff --git a/tests/ut/python/dataset/test_encode_png.py b/tests/ut/python/dataset/test_encode_png.py old mode 100755 new mode 100644 diff --git a/tests/ut/python/dataset/test_fast_text.py b/tests/ut/python/dataset/test_fast_text.py index 7ce365e562e..2564c4ca718 100644 --- a/tests/ut/python/dataset/test_fast_text.py +++ b/tests/ut/python/dataset/test_fast_text.py @@ -1,235 +1,235 @@ -# Copyright 2021-2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import numpy as np -import pytest - -from mindspore import log -import mindspore.dataset as ds -import mindspore.dataset.text as text -import mindspore.dataset.text.transforms as T - -DATASET_ROOT_PATH = "../data/dataset/test_fast_text/" - - -def test_fast_text_all_build_from_file_params(): - """ - Feature: FastText - Description: Test with all parameters which include `path` and `max_vector` in function BuildFromFile - Expectation: Output is equal to the expected value - """ - vectors = text.FastText.from_file(DATASET_ROOT_PATH + "fast_text.vec", max_vectors=100) - to_vectors = text.ToVectors(vectors) - data = ds.TextFileDataset(DATASET_ROOT_PATH + "words.txt", shuffle=False) - data = data.map(operations=to_vectors, input_columns=["text"]) - ind = 0 - res = [[0.418, 0.24968, -0.41242, 0.1217, 0.34527, -0.04445718411], - [0, 0, 0, 0, 0, 0], - [0.15164, 0.30177, -0.16763, 0.17684, 0.31719, 0.33973], - [0.70853, 0.57088, -0.4716, 0.18048, 0.54449, 0.72603], - [0.68047, -0.039263, 0.30186, -0.17792, 0.42962, 0.032246], - [0.26818, 0.14346, -0.27877, 0.016257, 0.11384, 0.69923], - [0, 0, 0, 0, 0, 0]] - for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): - res_array = np.array(res[ind], dtype=np.float32) - assert np.array_equal(res_array, d["text"]), ind - ind += 1 - - -def test_fast_text_all_build_from_file_params_eager(): - """ - Feature: FastText - Description: Test with all parameters which include `path` and `max_vector` in function BuildFromFile in eager mode - Expectation: Output is equal to the expected value - """ - vectors = text.FastText.from_file(DATASET_ROOT_PATH + "fast_text.vec", max_vectors=4) - to_vectors = T.ToVectors(vectors) - result1 = to_vectors("ok") - result2 = to_vectors("!") - result3 = to_vectors("this") - result4 = to_vectors("is") - result5 = to_vectors("my") - result6 = to_vectors("home") - result7 = to_vectors("none") - res = [[0.418, 0.24968, -0.41242, 0.1217, 0.34527, -0.04445718411], - [0.013441, 0.23682, -0.16899, 0.40951, 0.63812, 0.47709], - [0.15164, 0.30177, -0.16763, 0.17684, 0.31719, 0.33973], - [0.70853, 0.57088, -0.4716, 0.18048, 0.54449, 0.72603], - [0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0]] - res_array = np.array(res, dtype=np.float32) - - assert np.array_equal(result1, res_array[0]) - assert np.array_equal(result2, res_array[1]) - assert np.array_equal(result3, res_array[2]) - assert np.array_equal(result4, res_array[3]) - assert np.array_equal(result5, res_array[4]) - assert np.array_equal(result6, res_array[5]) - assert np.array_equal(result7, res_array[6]) - - -def test_fast_text_all_to_vectors_params_eager(): - """ - Feature: FastText - Description: Test with all parameters which include `unk_init` and `lower_case_backup` in function ToVectors - in eager mode - Expectation: Output is equal to the expected value - """ - vectors = text.FastText.from_file(DATASET_ROOT_PATH + "fast_text.vec", max_vectors=4) - my_unk = [-1, -1, -1, -1, -1, -1] - to_vectors = T.ToVectors(vectors, unk_init=my_unk, lower_case_backup=True) - result1 = to_vectors("Ok") - result2 = to_vectors("!") - result3 = to_vectors("This") - result4 = to_vectors("is") - result5 = to_vectors("my") - result6 = to_vectors("home") - result7 = to_vectors("none") - res = [[0.418, 0.24968, -0.41242, 0.1217, 0.34527, -0.04445718411], - [0.013441, 0.23682, -0.16899, 0.40951, 0.63812, 0.47709], - [0.15164, 0.30177, -0.16763, 0.17684, 0.31719, 0.33973], - [0.70853, 0.57088, -0.4716, 0.18048, 0.54449, 0.72603], - [-1, -1, -1, -1, -1, -1], - [-1, -1, -1, -1, -1, -1], - [-1, -1, -1, -1, -1, -1]] - res_array = np.array(res, dtype=np.float32) - - assert np.array_equal(result1, res_array[0]) - assert np.array_equal(result2, res_array[1]) - assert np.array_equal(result3, res_array[2]) - assert np.array_equal(result4, res_array[3]) - assert np.array_equal(result5, res_array[4]) - assert np.array_equal(result6, res_array[5]) - assert np.array_equal(result7, res_array[6]) - - -def test_fast_text_build_from_file(): - """ - Feature: FastText - Description: Test with only default parameter - Expectation: Output is equal to the expected value - """ - vectors = text.FastText.from_file(DATASET_ROOT_PATH + "fast_text.vec") - to_vectors = text.ToVectors(vectors) - data = ds.TextFileDataset(DATASET_ROOT_PATH + "words.txt", shuffle=False) - data = data.map(operations=to_vectors, input_columns=["text"]) - ind = 0 - res = [[0.418, 0.24968, -0.41242, 0.1217, 0.34527, -0.04445718411], - [0, 0, 0, 0, 0, 0], - [0.15164, 0.30177, -0.16763, 0.17684, 0.31719, 0.33973], - [0.70853, 0.57088, -0.4716, 0.18048, 0.54449, 0.72603], - [0.68047, -0.039263, 0.30186, -0.17792, 0.42962, 0.032246], - [0.26818, 0.14346, -0.27877, 0.016257, 0.11384, 0.69923], - [0, 0, 0, 0, 0, 0]] - for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): - res_array = np.array(res[ind], dtype=np.float32) - assert np.array_equal(res_array, d["text"]), ind - ind += 1 - - -def test_fast_text_build_from_file_eager(): - """ - Feature: FastText - Description: Test with only default parameter in eager mode - Expectation: Output is equal to the expected value - """ - vectors = text.FastText.from_file(DATASET_ROOT_PATH + "fast_text.vec") - to_vectors = T.ToVectors(vectors) - result1 = to_vectors("ok") - result2 = to_vectors("!") - result3 = to_vectors("this") - result4 = to_vectors("is") - result5 = to_vectors("my") - result6 = to_vectors("home") - result7 = to_vectors("none") - res = [[0.418, 0.24968, -0.41242, 0.1217, 0.34527, -0.04445718411], - [0.013441, 0.23682, -0.16899, 0.40951, 0.63812, 0.47709], - [0.15164, 0.30177, -0.16763, 0.17684, 0.31719, 0.33973], - [0.70853, 0.57088, -0.4716, 0.18048, 0.54449, 0.72603], - [0.68047, -0.039263, 0.30186, -0.17792, 0.42962, 0.032246], - [0.26818, 0.14346, -0.27877, 0.016257, 0.11384, 0.69923], - [0, 0, 0, 0, 0, 0]] - res_array = np.array(res, dtype=np.float32) - - assert np.array_equal(result1, res_array[0]) - assert np.array_equal(result2, res_array[1]) - assert np.array_equal(result3, res_array[2]) - assert np.array_equal(result4, res_array[3]) - assert np.array_equal(result5, res_array[4]) - assert np.array_equal(result6, res_array[5]) - assert np.array_equal(result7, res_array[6]) - - -def test_fast_text_invalid_input(): - """ - Feature: FastText - Description: Test the validate function with invalid parameters - Expectation: Output is equal to the expected error - """ - def test_invalid_input(test_name, file_path, error, error_msg, max_vectors=None, unk_init=None, - lower_case_backup=False, token="ok"): - log.info("Test Vectors with wrong input: {0}".format(test_name)) - with pytest.raises(error) as error_info: - vectors = text.FastText.from_file(file_path, max_vectors=max_vectors) - to_vectors = T.ToVectors(vectors, unk_init=unk_init, lower_case_backup=lower_case_backup) - to_vectors(token) - assert error_msg in str(error_info.value) - - test_invalid_input("Not all vectors have the same number of dimensions", - DATASET_ROOT_PATH + "fast_text_dim_different.vec", error=RuntimeError, - error_msg="all vectors must have the same number of dimensions, " \ - "but got dim 5 while expecting 6") - test_invalid_input("the file is empty.", DATASET_ROOT_PATH + "fast_text_empty.vec", - error=RuntimeError, error_msg="invalid file, file is empty.") - test_invalid_input("the count of `unknown_init`'s element is different with word vector.", - DATASET_ROOT_PATH + "fast_text.vec", - error=RuntimeError, - error_msg="unk_init must be the same length as vectors, but got unk_init", - unk_init=[-1, -1]) - test_invalid_input("The file not exist", DATASET_ROOT_PATH + "not_exist.vec", RuntimeError, - error_msg="FastText: invalid file") - test_invalid_input("The token is 1-dimensional", DATASET_ROOT_PATH + "fast_text_with_wrong_info.vec", - error=RuntimeError, error_msg="token with 1-dimensional vector.") - test_invalid_input("max_vectors parameter must be greater than 0", DATASET_ROOT_PATH + "fast_text.vec", - error=ValueError, error_msg="Input max_vectors is not within the required interval", - max_vectors=-1) - test_invalid_input("invalid max_vectors parameter type as a float", DATASET_ROOT_PATH + "fast_text.vec", - error=TypeError, error_msg="Argument max_vectors with value 1.0 is not of type []," - " but got .", max_vectors=1.0) - test_invalid_input("invalid max_vectors parameter type as a string", DATASET_ROOT_PATH + "fast_text.vec", - error=TypeError, error_msg="Argument max_vectors with value 1 is not of type []," - " but got .", max_vectors="1") - test_invalid_input("invalid token parameter type as a float", DATASET_ROOT_PATH + "fast_text.vec", - error=RuntimeError, error_msg="input tensor type should be string.", token=1.0) - test_invalid_input("invalid lower_case_backup parameter type as a string", DATASET_ROOT_PATH + "fast_text.vec", - error=TypeError, error_msg="Argument lower_case_backup with value True is " \ - "not of type []," - " but got .", lower_case_backup="True") - test_invalid_input("invalid lower_case_backup parameter type as a string", DATASET_ROOT_PATH + "fast_text.vec", - error=TypeError, error_msg="Argument lower_case_backup with value True is " \ - "not of type []," - " but got .", lower_case_backup="True") - test_invalid_input("the suffix of pre-training set must be `*.vec`", DATASET_ROOT_PATH + "fast_text.txt", - error=RuntimeError, error_msg="FastText: invalid file, can not find file '*.vec'") - - -if __name__ == '__main__': - test_fast_text_all_build_from_file_params() - test_fast_text_all_build_from_file_params_eager() - test_fast_text_all_to_vectors_params_eager() - test_fast_text_build_from_file() - test_fast_text_build_from_file_eager() - test_fast_text_invalid_input() +# Copyright 2021-2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import numpy as np +import pytest + +from mindspore import log +import mindspore.dataset as ds +import mindspore.dataset.text as text +import mindspore.dataset.text.transforms as T + +DATASET_ROOT_PATH = "../data/dataset/test_fast_text/" + + +def test_fast_text_all_build_from_file_params(): + """ + Feature: FastText + Description: Test with all parameters which include `path` and `max_vector` in function BuildFromFile + Expectation: Output is equal to the expected value + """ + vectors = text.FastText.from_file(DATASET_ROOT_PATH + "fast_text.vec", max_vectors=100) + to_vectors = text.ToVectors(vectors) + data = ds.TextFileDataset(DATASET_ROOT_PATH + "words.txt", shuffle=False) + data = data.map(operations=to_vectors, input_columns=["text"]) + ind = 0 + res = [[0.418, 0.24968, -0.41242, 0.1217, 0.34527, -0.04445718411], + [0, 0, 0, 0, 0, 0], + [0.15164, 0.30177, -0.16763, 0.17684, 0.31719, 0.33973], + [0.70853, 0.57088, -0.4716, 0.18048, 0.54449, 0.72603], + [0.68047, -0.039263, 0.30186, -0.17792, 0.42962, 0.032246], + [0.26818, 0.14346, -0.27877, 0.016257, 0.11384, 0.69923], + [0, 0, 0, 0, 0, 0]] + for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): + res_array = np.array(res[ind], dtype=np.float32) + assert np.array_equal(res_array, d["text"]), ind + ind += 1 + + +def test_fast_text_all_build_from_file_params_eager(): + """ + Feature: FastText + Description: Test with all parameters which include `path` and `max_vector` in function BuildFromFile in eager mode + Expectation: Output is equal to the expected value + """ + vectors = text.FastText.from_file(DATASET_ROOT_PATH + "fast_text.vec", max_vectors=4) + to_vectors = T.ToVectors(vectors) + result1 = to_vectors("ok") + result2 = to_vectors("!") + result3 = to_vectors("this") + result4 = to_vectors("is") + result5 = to_vectors("my") + result6 = to_vectors("home") + result7 = to_vectors("none") + res = [[0.418, 0.24968, -0.41242, 0.1217, 0.34527, -0.04445718411], + [0.013441, 0.23682, -0.16899, 0.40951, 0.63812, 0.47709], + [0.15164, 0.30177, -0.16763, 0.17684, 0.31719, 0.33973], + [0.70853, 0.57088, -0.4716, 0.18048, 0.54449, 0.72603], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]] + res_array = np.array(res, dtype=np.float32) + + assert np.array_equal(result1, res_array[0]) + assert np.array_equal(result2, res_array[1]) + assert np.array_equal(result3, res_array[2]) + assert np.array_equal(result4, res_array[3]) + assert np.array_equal(result5, res_array[4]) + assert np.array_equal(result6, res_array[5]) + assert np.array_equal(result7, res_array[6]) + + +def test_fast_text_all_to_vectors_params_eager(): + """ + Feature: FastText + Description: Test with all parameters which include `unk_init` and `lower_case_backup` in function ToVectors + in eager mode + Expectation: Output is equal to the expected value + """ + vectors = text.FastText.from_file(DATASET_ROOT_PATH + "fast_text.vec", max_vectors=4) + my_unk = [-1, -1, -1, -1, -1, -1] + to_vectors = T.ToVectors(vectors, unk_init=my_unk, lower_case_backup=True) + result1 = to_vectors("Ok") + result2 = to_vectors("!") + result3 = to_vectors("This") + result4 = to_vectors("is") + result5 = to_vectors("my") + result6 = to_vectors("home") + result7 = to_vectors("none") + res = [[0.418, 0.24968, -0.41242, 0.1217, 0.34527, -0.04445718411], + [0.013441, 0.23682, -0.16899, 0.40951, 0.63812, 0.47709], + [0.15164, 0.30177, -0.16763, 0.17684, 0.31719, 0.33973], + [0.70853, 0.57088, -0.4716, 0.18048, 0.54449, 0.72603], + [-1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1]] + res_array = np.array(res, dtype=np.float32) + + assert np.array_equal(result1, res_array[0]) + assert np.array_equal(result2, res_array[1]) + assert np.array_equal(result3, res_array[2]) + assert np.array_equal(result4, res_array[3]) + assert np.array_equal(result5, res_array[4]) + assert np.array_equal(result6, res_array[5]) + assert np.array_equal(result7, res_array[6]) + + +def test_fast_text_build_from_file(): + """ + Feature: FastText + Description: Test with only default parameter + Expectation: Output is equal to the expected value + """ + vectors = text.FastText.from_file(DATASET_ROOT_PATH + "fast_text.vec") + to_vectors = text.ToVectors(vectors) + data = ds.TextFileDataset(DATASET_ROOT_PATH + "words.txt", shuffle=False) + data = data.map(operations=to_vectors, input_columns=["text"]) + ind = 0 + res = [[0.418, 0.24968, -0.41242, 0.1217, 0.34527, -0.04445718411], + [0, 0, 0, 0, 0, 0], + [0.15164, 0.30177, -0.16763, 0.17684, 0.31719, 0.33973], + [0.70853, 0.57088, -0.4716, 0.18048, 0.54449, 0.72603], + [0.68047, -0.039263, 0.30186, -0.17792, 0.42962, 0.032246], + [0.26818, 0.14346, -0.27877, 0.016257, 0.11384, 0.69923], + [0, 0, 0, 0, 0, 0]] + for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): + res_array = np.array(res[ind], dtype=np.float32) + assert np.array_equal(res_array, d["text"]), ind + ind += 1 + + +def test_fast_text_build_from_file_eager(): + """ + Feature: FastText + Description: Test with only default parameter in eager mode + Expectation: Output is equal to the expected value + """ + vectors = text.FastText.from_file(DATASET_ROOT_PATH + "fast_text.vec") + to_vectors = T.ToVectors(vectors) + result1 = to_vectors("ok") + result2 = to_vectors("!") + result3 = to_vectors("this") + result4 = to_vectors("is") + result5 = to_vectors("my") + result6 = to_vectors("home") + result7 = to_vectors("none") + res = [[0.418, 0.24968, -0.41242, 0.1217, 0.34527, -0.04445718411], + [0.013441, 0.23682, -0.16899, 0.40951, 0.63812, 0.47709], + [0.15164, 0.30177, -0.16763, 0.17684, 0.31719, 0.33973], + [0.70853, 0.57088, -0.4716, 0.18048, 0.54449, 0.72603], + [0.68047, -0.039263, 0.30186, -0.17792, 0.42962, 0.032246], + [0.26818, 0.14346, -0.27877, 0.016257, 0.11384, 0.69923], + [0, 0, 0, 0, 0, 0]] + res_array = np.array(res, dtype=np.float32) + + assert np.array_equal(result1, res_array[0]) + assert np.array_equal(result2, res_array[1]) + assert np.array_equal(result3, res_array[2]) + assert np.array_equal(result4, res_array[3]) + assert np.array_equal(result5, res_array[4]) + assert np.array_equal(result6, res_array[5]) + assert np.array_equal(result7, res_array[6]) + + +def test_fast_text_invalid_input(): + """ + Feature: FastText + Description: Test the validate function with invalid parameters + Expectation: Output is equal to the expected error + """ + def test_invalid_input(test_name, file_path, error, error_msg, max_vectors=None, unk_init=None, + lower_case_backup=False, token="ok"): + log.info("Test Vectors with wrong input: {0}".format(test_name)) + with pytest.raises(error) as error_info: + vectors = text.FastText.from_file(file_path, max_vectors=max_vectors) + to_vectors = T.ToVectors(vectors, unk_init=unk_init, lower_case_backup=lower_case_backup) + to_vectors(token) + assert error_msg in str(error_info.value) + + test_invalid_input("Not all vectors have the same number of dimensions", + DATASET_ROOT_PATH + "fast_text_dim_different.vec", error=RuntimeError, + error_msg="all vectors must have the same number of dimensions, " \ + "but got dim 5 while expecting 6") + test_invalid_input("the file is empty.", DATASET_ROOT_PATH + "fast_text_empty.vec", + error=RuntimeError, error_msg="invalid file, file is empty.") + test_invalid_input("the count of `unknown_init`'s element is different with word vector.", + DATASET_ROOT_PATH + "fast_text.vec", + error=RuntimeError, + error_msg="unk_init must be the same length as vectors, but got unk_init", + unk_init=[-1, -1]) + test_invalid_input("The file not exist", DATASET_ROOT_PATH + "not_exist.vec", RuntimeError, + error_msg="FastText: invalid file") + test_invalid_input("The token is 1-dimensional", DATASET_ROOT_PATH + "fast_text_with_wrong_info.vec", + error=RuntimeError, error_msg="token with 1-dimensional vector.") + test_invalid_input("max_vectors parameter must be greater than 0", DATASET_ROOT_PATH + "fast_text.vec", + error=ValueError, error_msg="Input max_vectors is not within the required interval", + max_vectors=-1) + test_invalid_input("invalid max_vectors parameter type as a float", DATASET_ROOT_PATH + "fast_text.vec", + error=TypeError, error_msg="Argument max_vectors with value 1.0 is not of type []," + " but got .", max_vectors=1.0) + test_invalid_input("invalid max_vectors parameter type as a string", DATASET_ROOT_PATH + "fast_text.vec", + error=TypeError, error_msg="Argument max_vectors with value 1 is not of type []," + " but got .", max_vectors="1") + test_invalid_input("invalid token parameter type as a float", DATASET_ROOT_PATH + "fast_text.vec", + error=RuntimeError, error_msg="input tensor type should be string.", token=1.0) + test_invalid_input("invalid lower_case_backup parameter type as a string", DATASET_ROOT_PATH + "fast_text.vec", + error=TypeError, error_msg="Argument lower_case_backup with value True is " \ + "not of type []," + " but got .", lower_case_backup="True") + test_invalid_input("invalid lower_case_backup parameter type as a string", DATASET_ROOT_PATH + "fast_text.vec", + error=TypeError, error_msg="Argument lower_case_backup with value True is " \ + "not of type []," + " but got .", lower_case_backup="True") + test_invalid_input("the suffix of pre-training set must be `*.vec`", DATASET_ROOT_PATH + "fast_text.txt", + error=RuntimeError, error_msg="FastText: invalid file, can not find file '*.vec'") + + +if __name__ == '__main__': + test_fast_text_all_build_from_file_params() + test_fast_text_all_build_from_file_params_eager() + test_fast_text_all_to_vectors_params_eager() + test_fast_text_build_from_file() + test_fast_text_build_from_file_eager() + test_fast_text_invalid_input() diff --git a/tests/ut/python/dataset/test_glove.py b/tests/ut/python/dataset/test_glove.py index 4cbeafb9537..9f54e9ce5f9 100644 --- a/tests/ut/python/dataset/test_glove.py +++ b/tests/ut/python/dataset/test_glove.py @@ -1,233 +1,233 @@ -# Copyright 2021-2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import numpy as np -import pytest - -from mindspore import log -import mindspore.dataset as ds -import mindspore.dataset.text as text -import mindspore.dataset.text.transforms as T - -DATASET_ROOT_PATH = "../data/dataset/testGloVe/" - - -def test_glove_all_build_from_file_params(): - """ - Feature: GloVe - Description: Test with all parameters which include `path` and `max_vector` in function BuildFromFile - Expectation: Output is equal to the expected value - """ - vectors = text.GloVe.from_file(DATASET_ROOT_PATH + "glove.6B.test.txt", max_vectors=100) - to_vectors = text.ToVectors(vectors) - data = ds.TextFileDataset(DATASET_ROOT_PATH + "words.txt", shuffle=False) - data = data.map(operations=to_vectors, input_columns=["text"]) - ind = 0 - res = [[0.418, 0.24968, -0.41242, 0.1217, 0.34527, -0.04445718411], - [0, 0, 0, 0, 0, 0], - [0.15164, 0.30177, -0.16763, 0.17684, 0.31719, 0.33973], - [0.70853, 0.57088, -0.4716, 0.18048, 0.54449, 0.72603], - [0.68047, -0.039263, 0.30186, -0.17792, 0.42962, 0.032246], - [0.26818, 0.14346, -0.27877, 0.016257, 0.11384, 0.69923], - [0, 0, 0, 0, 0, 0]] - for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): - res_array = np.array(res[ind], dtype=np.float32) - assert np.array_equal(res_array, d["text"]), ind - ind += 1 - - -def test_glove_all_build_from_file_params_eager(): - """ - Feature: GloVe - Description: Test with all parameters which include `path` and `max_vector` in function BuildFromFile in eager mode - Expectation: Output is equal to the expected value - """ - vectors = text.GloVe.from_file(DATASET_ROOT_PATH + "glove.6B.test.txt", max_vectors=4) - to_vectors = T.ToVectors(vectors) - result1 = to_vectors("ok") - result2 = to_vectors("!") - result3 = to_vectors("this") - result4 = to_vectors("is") - result5 = to_vectors("my") - result6 = to_vectors("home") - result7 = to_vectors("none") - res = [[0.418, 0.24968, -0.41242, 0.1217, 0.34527, -0.04445718411], - [0.013441, 0.23682, -0.16899, 0.40951, 0.63812, 0.47709], - [0.15164, 0.30177, -0.16763, 0.17684, 0.31719, 0.33973], - [0.70853, 0.57088, -0.4716, 0.18048, 0.54449, 0.72603], - [0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0]] - res_array = np.array(res, dtype=np.float32) - - assert np.array_equal(result1, res_array[0]) - assert np.array_equal(result2, res_array[1]) - assert np.array_equal(result3, res_array[2]) - assert np.array_equal(result4, res_array[3]) - assert np.array_equal(result5, res_array[4]) - assert np.array_equal(result6, res_array[5]) - assert np.array_equal(result7, res_array[6]) - - -def test_glove_all_to_vectors_params_eager(): - """ - Feature: GloVe - Description: Test with all parameters which include `unk_init` and `lower_case_backup` in function ToVectors - in eager mode - Expectation: Output is equal to the expected value - """ - vectors = text.GloVe.from_file(DATASET_ROOT_PATH + "glove.6B.test.txt", max_vectors=4) - my_unk = [-1, -1, -1, -1, -1, -1] - to_vectors = T.ToVectors(vectors, unk_init=my_unk, lower_case_backup=True) - result1 = to_vectors("Ok") - result2 = to_vectors("!") - result3 = to_vectors("This") - result4 = to_vectors("is") - result5 = to_vectors("my") - result6 = to_vectors("home") - result7 = to_vectors("none") - res = [[0.418, 0.24968, -0.41242, 0.1217, 0.34527, -0.04445718411], - [0.013441, 0.23682, -0.16899, 0.40951, 0.63812, 0.47709], - [0.15164, 0.30177, -0.16763, 0.17684, 0.31719, 0.33973], - [0.70853, 0.57088, -0.4716, 0.18048, 0.54449, 0.72603], - [-1, -1, -1, -1, -1, -1], - [-1, -1, -1, -1, -1, -1], - [-1, -1, -1, -1, -1, -1]] - res_array = np.array(res, dtype=np.float32) - - assert np.array_equal(result1, res_array[0]) - assert np.array_equal(result2, res_array[1]) - assert np.array_equal(result3, res_array[2]) - assert np.array_equal(result4, res_array[3]) - assert np.array_equal(result5, res_array[4]) - assert np.array_equal(result6, res_array[5]) - assert np.array_equal(result7, res_array[6]) - - -def test_glove_build_from_file(): - """ - Feature: GloVe - Description: Test with only default parameter - Expectation: Output is equal to the expected value - """ - vectors = text.GloVe.from_file(DATASET_ROOT_PATH + "glove.6B.test.txt") - to_vectors = text.ToVectors(vectors) - data = ds.TextFileDataset(DATASET_ROOT_PATH + "words.txt", shuffle=False) - data = data.map(operations=to_vectors, input_columns=["text"]) - ind = 0 - res = [[0.418, 0.24968, -0.41242, 0.1217, 0.34527, -0.04445718411], - [0, 0, 0, 0, 0, 0], - [0.15164, 0.30177, -0.16763, 0.17684, 0.31719, 0.33973], - [0.70853, 0.57088, -0.4716, 0.18048, 0.54449, 0.72603], - [0.68047, -0.039263, 0.30186, -0.17792, 0.42962, 0.032246], - [0.26818, 0.14346, -0.27877, 0.016257, 0.11384, 0.69923], - [0, 0, 0, 0, 0, 0]] - for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): - res_array = np.array(res[ind], dtype=np.float32) - assert np.array_equal(res_array, d["text"]), ind - ind += 1 - assert ind == 7 - - -def test_glove_build_from_file_eager(): - """ - Feature: GloVe - Description: Test with only default parameter in eager mode - Expectation: Output is equal to the expected value - """ - vectors = text.GloVe.from_file(DATASET_ROOT_PATH + "glove.6B.test.txt") - to_vectors = T.ToVectors(vectors) - result1 = to_vectors("ok") - result2 = to_vectors("!") - result3 = to_vectors("this") - result4 = to_vectors("is") - result5 = to_vectors("my") - result6 = to_vectors("home") - result7 = to_vectors("none") - res = [[0.418, 0.24968, -0.41242, 0.1217, 0.34527, -0.04445718411], - [0.013441, 0.23682, -0.16899, 0.40951, 0.63812, 0.47709], - [0.15164, 0.30177, -0.16763, 0.17684, 0.31719, 0.33973], - [0.70853, 0.57088, -0.4716, 0.18048, 0.54449, 0.72603], - [0.68047, -0.039263, 0.30186, -0.17792, 0.42962, 0.032246], - [0.26818, 0.14346, -0.27877, 0.016257, 0.11384, 0.69923], - [0, 0, 0, 0, 0, 0]] - res_array = np.array(res, dtype=np.float32) - - assert np.array_equal(result1, res_array[0]) - assert np.array_equal(result2, res_array[1]) - assert np.array_equal(result3, res_array[2]) - assert np.array_equal(result4, res_array[3]) - assert np.array_equal(result5, res_array[4]) - assert np.array_equal(result6, res_array[5]) - assert np.array_equal(result7, res_array[6]) - - -def test_glove_invalid_input(): - """ - Feature: GloVe - Description: Test the validate function with invalid parameters - Expectation: Output is equal to the expected error - """ - def test_invalid_input(test_name, file_path, error, error_msg, max_vectors=None, unk_init=None, - lower_case_backup=False, token="ok"): - log.info("Test Vectors with wrong input: {0}".format(test_name)) - with pytest.raises(error) as error_info: - vectors = text.GloVe.from_file(file_path, max_vectors=max_vectors) - to_vectors = T.ToVectors(vectors, unk_init=unk_init, lower_case_backup=lower_case_backup) - to_vectors(token) - assert error_msg in str(error_info.value) - - test_invalid_input("Not all vectors have the same number of dimensions", - DATASET_ROOT_PATH + "glove.6B.dim_different.txt", error=RuntimeError, - error_msg="all vectors must have the same number of dimensions, " \ - "but got dim 5 while expecting 6") - test_invalid_input("the file is empty.", DATASET_ROOT_PATH + "glove.6B.empty.txt", - error=RuntimeError, error_msg="invalid file, file is empty.") - test_invalid_input("the count of `unknown_init`'s element is different with word vector.", - DATASET_ROOT_PATH + "glove.6B.test.txt", - error=RuntimeError, - error_msg="unk_init must be the same length as vectors, but got unk_init", - unk_init=[-1, -1]) - test_invalid_input("The file not exist", DATASET_ROOT_PATH + "not_exist.txt", RuntimeError, - error_msg="GloVe: invalid file") - test_invalid_input("The token is 1-dimensional", DATASET_ROOT_PATH + "glove.6B.with_wrong_info.txt", - error=RuntimeError, error_msg="token with 1-dimensional vector.") - test_invalid_input("max_vectors parameter must be greater than 0", DATASET_ROOT_PATH + "glove.6B.test.txt", - error=ValueError, error_msg="Input max_vectors is not within the required interval", - max_vectors=-1) - test_invalid_input("invalid max_vectors parameter type as a float", DATASET_ROOT_PATH + "glove.6B.test.txt", - error=TypeError, error_msg="Argument max_vectors with value 1.0 is not of type []," - " but got .", max_vectors=1.0) - test_invalid_input("invalid max_vectors parameter type as a string", DATASET_ROOT_PATH + "glove.6B.test.txt", - error=TypeError, error_msg="Argument max_vectors with value 1 is not of type []," - " but got .", max_vectors="1") - test_invalid_input("invalid token parameter type as a float", DATASET_ROOT_PATH + "glove.6B.test.txt", - error=RuntimeError, error_msg="input tensor type should be string.", token=1.0) - test_invalid_input("invalid lower_case_backup parameter type as a string", DATASET_ROOT_PATH + "glove.6B.test.txt", - error=TypeError, error_msg="Argument lower_case_backup with value True is " \ - "not of type []," - " but got .", lower_case_backup="True") - test_invalid_input("not right glove dataset. The formal must be `glove.6B.*.txt`", DATASET_ROOT_PATH + - "glove.6B.test.vec", error=RuntimeError, error_msg="GloVe: invalid file, can not " \ - "find file 'glove.6B.*.txt'") - - -if __name__ == '__main__': - test_glove_all_build_from_file_params() - test_glove_all_build_from_file_params_eager() - test_glove_all_to_vectors_params_eager() - test_glove_build_from_file() - test_glove_build_from_file_eager() - test_glove_invalid_input() +# Copyright 2021-2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import numpy as np +import pytest + +from mindspore import log +import mindspore.dataset as ds +import mindspore.dataset.text as text +import mindspore.dataset.text.transforms as T + +DATASET_ROOT_PATH = "../data/dataset/testGloVe/" + + +def test_glove_all_build_from_file_params(): + """ + Feature: GloVe + Description: Test with all parameters which include `path` and `max_vector` in function BuildFromFile + Expectation: Output is equal to the expected value + """ + vectors = text.GloVe.from_file(DATASET_ROOT_PATH + "glove.6B.test.txt", max_vectors=100) + to_vectors = text.ToVectors(vectors) + data = ds.TextFileDataset(DATASET_ROOT_PATH + "words.txt", shuffle=False) + data = data.map(operations=to_vectors, input_columns=["text"]) + ind = 0 + res = [[0.418, 0.24968, -0.41242, 0.1217, 0.34527, -0.04445718411], + [0, 0, 0, 0, 0, 0], + [0.15164, 0.30177, -0.16763, 0.17684, 0.31719, 0.33973], + [0.70853, 0.57088, -0.4716, 0.18048, 0.54449, 0.72603], + [0.68047, -0.039263, 0.30186, -0.17792, 0.42962, 0.032246], + [0.26818, 0.14346, -0.27877, 0.016257, 0.11384, 0.69923], + [0, 0, 0, 0, 0, 0]] + for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): + res_array = np.array(res[ind], dtype=np.float32) + assert np.array_equal(res_array, d["text"]), ind + ind += 1 + + +def test_glove_all_build_from_file_params_eager(): + """ + Feature: GloVe + Description: Test with all parameters which include `path` and `max_vector` in function BuildFromFile in eager mode + Expectation: Output is equal to the expected value + """ + vectors = text.GloVe.from_file(DATASET_ROOT_PATH + "glove.6B.test.txt", max_vectors=4) + to_vectors = T.ToVectors(vectors) + result1 = to_vectors("ok") + result2 = to_vectors("!") + result3 = to_vectors("this") + result4 = to_vectors("is") + result5 = to_vectors("my") + result6 = to_vectors("home") + result7 = to_vectors("none") + res = [[0.418, 0.24968, -0.41242, 0.1217, 0.34527, -0.04445718411], + [0.013441, 0.23682, -0.16899, 0.40951, 0.63812, 0.47709], + [0.15164, 0.30177, -0.16763, 0.17684, 0.31719, 0.33973], + [0.70853, 0.57088, -0.4716, 0.18048, 0.54449, 0.72603], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]] + res_array = np.array(res, dtype=np.float32) + + assert np.array_equal(result1, res_array[0]) + assert np.array_equal(result2, res_array[1]) + assert np.array_equal(result3, res_array[2]) + assert np.array_equal(result4, res_array[3]) + assert np.array_equal(result5, res_array[4]) + assert np.array_equal(result6, res_array[5]) + assert np.array_equal(result7, res_array[6]) + + +def test_glove_all_to_vectors_params_eager(): + """ + Feature: GloVe + Description: Test with all parameters which include `unk_init` and `lower_case_backup` in function ToVectors + in eager mode + Expectation: Output is equal to the expected value + """ + vectors = text.GloVe.from_file(DATASET_ROOT_PATH + "glove.6B.test.txt", max_vectors=4) + my_unk = [-1, -1, -1, -1, -1, -1] + to_vectors = T.ToVectors(vectors, unk_init=my_unk, lower_case_backup=True) + result1 = to_vectors("Ok") + result2 = to_vectors("!") + result3 = to_vectors("This") + result4 = to_vectors("is") + result5 = to_vectors("my") + result6 = to_vectors("home") + result7 = to_vectors("none") + res = [[0.418, 0.24968, -0.41242, 0.1217, 0.34527, -0.04445718411], + [0.013441, 0.23682, -0.16899, 0.40951, 0.63812, 0.47709], + [0.15164, 0.30177, -0.16763, 0.17684, 0.31719, 0.33973], + [0.70853, 0.57088, -0.4716, 0.18048, 0.54449, 0.72603], + [-1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1]] + res_array = np.array(res, dtype=np.float32) + + assert np.array_equal(result1, res_array[0]) + assert np.array_equal(result2, res_array[1]) + assert np.array_equal(result3, res_array[2]) + assert np.array_equal(result4, res_array[3]) + assert np.array_equal(result5, res_array[4]) + assert np.array_equal(result6, res_array[5]) + assert np.array_equal(result7, res_array[6]) + + +def test_glove_build_from_file(): + """ + Feature: GloVe + Description: Test with only default parameter + Expectation: Output is equal to the expected value + """ + vectors = text.GloVe.from_file(DATASET_ROOT_PATH + "glove.6B.test.txt") + to_vectors = text.ToVectors(vectors) + data = ds.TextFileDataset(DATASET_ROOT_PATH + "words.txt", shuffle=False) + data = data.map(operations=to_vectors, input_columns=["text"]) + ind = 0 + res = [[0.418, 0.24968, -0.41242, 0.1217, 0.34527, -0.04445718411], + [0, 0, 0, 0, 0, 0], + [0.15164, 0.30177, -0.16763, 0.17684, 0.31719, 0.33973], + [0.70853, 0.57088, -0.4716, 0.18048, 0.54449, 0.72603], + [0.68047, -0.039263, 0.30186, -0.17792, 0.42962, 0.032246], + [0.26818, 0.14346, -0.27877, 0.016257, 0.11384, 0.69923], + [0, 0, 0, 0, 0, 0]] + for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): + res_array = np.array(res[ind], dtype=np.float32) + assert np.array_equal(res_array, d["text"]), ind + ind += 1 + assert ind == 7 + + +def test_glove_build_from_file_eager(): + """ + Feature: GloVe + Description: Test with only default parameter in eager mode + Expectation: Output is equal to the expected value + """ + vectors = text.GloVe.from_file(DATASET_ROOT_PATH + "glove.6B.test.txt") + to_vectors = T.ToVectors(vectors) + result1 = to_vectors("ok") + result2 = to_vectors("!") + result3 = to_vectors("this") + result4 = to_vectors("is") + result5 = to_vectors("my") + result6 = to_vectors("home") + result7 = to_vectors("none") + res = [[0.418, 0.24968, -0.41242, 0.1217, 0.34527, -0.04445718411], + [0.013441, 0.23682, -0.16899, 0.40951, 0.63812, 0.47709], + [0.15164, 0.30177, -0.16763, 0.17684, 0.31719, 0.33973], + [0.70853, 0.57088, -0.4716, 0.18048, 0.54449, 0.72603], + [0.68047, -0.039263, 0.30186, -0.17792, 0.42962, 0.032246], + [0.26818, 0.14346, -0.27877, 0.016257, 0.11384, 0.69923], + [0, 0, 0, 0, 0, 0]] + res_array = np.array(res, dtype=np.float32) + + assert np.array_equal(result1, res_array[0]) + assert np.array_equal(result2, res_array[1]) + assert np.array_equal(result3, res_array[2]) + assert np.array_equal(result4, res_array[3]) + assert np.array_equal(result5, res_array[4]) + assert np.array_equal(result6, res_array[5]) + assert np.array_equal(result7, res_array[6]) + + +def test_glove_invalid_input(): + """ + Feature: GloVe + Description: Test the validate function with invalid parameters + Expectation: Output is equal to the expected error + """ + def test_invalid_input(test_name, file_path, error, error_msg, max_vectors=None, unk_init=None, + lower_case_backup=False, token="ok"): + log.info("Test Vectors with wrong input: {0}".format(test_name)) + with pytest.raises(error) as error_info: + vectors = text.GloVe.from_file(file_path, max_vectors=max_vectors) + to_vectors = T.ToVectors(vectors, unk_init=unk_init, lower_case_backup=lower_case_backup) + to_vectors(token) + assert error_msg in str(error_info.value) + + test_invalid_input("Not all vectors have the same number of dimensions", + DATASET_ROOT_PATH + "glove.6B.dim_different.txt", error=RuntimeError, + error_msg="all vectors must have the same number of dimensions, " \ + "but got dim 5 while expecting 6") + test_invalid_input("the file is empty.", DATASET_ROOT_PATH + "glove.6B.empty.txt", + error=RuntimeError, error_msg="invalid file, file is empty.") + test_invalid_input("the count of `unknown_init`'s element is different with word vector.", + DATASET_ROOT_PATH + "glove.6B.test.txt", + error=RuntimeError, + error_msg="unk_init must be the same length as vectors, but got unk_init", + unk_init=[-1, -1]) + test_invalid_input("The file not exist", DATASET_ROOT_PATH + "not_exist.txt", RuntimeError, + error_msg="GloVe: invalid file") + test_invalid_input("The token is 1-dimensional", DATASET_ROOT_PATH + "glove.6B.with_wrong_info.txt", + error=RuntimeError, error_msg="token with 1-dimensional vector.") + test_invalid_input("max_vectors parameter must be greater than 0", DATASET_ROOT_PATH + "glove.6B.test.txt", + error=ValueError, error_msg="Input max_vectors is not within the required interval", + max_vectors=-1) + test_invalid_input("invalid max_vectors parameter type as a float", DATASET_ROOT_PATH + "glove.6B.test.txt", + error=TypeError, error_msg="Argument max_vectors with value 1.0 is not of type []," + " but got .", max_vectors=1.0) + test_invalid_input("invalid max_vectors parameter type as a string", DATASET_ROOT_PATH + "glove.6B.test.txt", + error=TypeError, error_msg="Argument max_vectors with value 1 is not of type []," + " but got .", max_vectors="1") + test_invalid_input("invalid token parameter type as a float", DATASET_ROOT_PATH + "glove.6B.test.txt", + error=RuntimeError, error_msg="input tensor type should be string.", token=1.0) + test_invalid_input("invalid lower_case_backup parameter type as a string", DATASET_ROOT_PATH + "glove.6B.test.txt", + error=TypeError, error_msg="Argument lower_case_backup with value True is " \ + "not of type []," + " but got .", lower_case_backup="True") + test_invalid_input("not right glove dataset. The formal must be `glove.6B.*.txt`", DATASET_ROOT_PATH + + "glove.6B.test.vec", error=RuntimeError, error_msg="GloVe: invalid file, can not " \ + "find file 'glove.6B.*.txt'") + + +if __name__ == '__main__': + test_glove_all_build_from_file_params() + test_glove_all_build_from_file_params_eager() + test_glove_all_to_vectors_params_eager() + test_glove_build_from_file() + test_glove_build_from_file_eager() + test_glove_invalid_input() diff --git a/tests/ut/python/dataset/test_magphase.py b/tests/ut/python/dataset/test_magphase.py index bcc53ddddad..7f4741061de 100644 --- a/tests/ut/python/dataset/test_magphase.py +++ b/tests/ut/python/dataset/test_magphase.py @@ -1,110 +1,110 @@ -# Copyright 2021-2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -""" -Testing Magphase Python API -""" -import numpy as np - -import mindspore.dataset as ds -import mindspore.dataset.audio as audio -from mindspore import log as logger - - -def test_magphase_pipeline(): - """ - Feature: Magphase - Description: Test Magphase in pipeline mode - Expectation: Output is equal to the expected output - """ - logger.info("Test Magphase pipeline.") - - data1 = [[[3.0, -4.0], [-5.0, 12.0]]] - expected = [5, 13, -0.927295, 1.965587] - dataset = ds.NumpySlicesDataset(data1, column_names=["col1"], shuffle=False) - magphase_window = audio.Magphase(power=1.0) - dataset = dataset.map(operations=magphase_window, input_columns=["col1"], - output_columns=["mag", "phase"]) - for data1, data2 in dataset.create_tuple_iterator(num_epochs=1, output_numpy=True): - assert abs(data1[0] - expected[0]) < 0.00001 - assert abs(data1[1] - expected[1]) < 0.00001 - assert abs(data2[0] - expected[2]) < 0.00001 - assert abs(data2[1] - expected[3]) < 0.00001 - - logger.info("Finish testing Magphase.") - - -def test_magphase_eager(): - """ - Feature: Magphase - Description: Test Magphase in eager mode - Expectation: Output is equal to the expected output - """ - logger.info("Test Magphase eager.") - - input_number = np.array([41, 67, 34, 0, 69, 24, 78, 58]).reshape((2, 2, 2)).astype("double") - mag = np.array([78.54934755, 34., 73.05477397, 97.20082304]).reshape((2, 2)).astype("double") - phase = np.array([1.02164342, 0, 0.33473684, 0.63938591]).reshape((2, 2)).astype("double") - magphase_window = audio.Magphase() - data1, data2 = magphase_window(input_number) - assert (abs(data1 - mag) < 0.00001).all() - assert (abs(data2 - phase) < 0.00001).all() - - logger.info("Finish testing Magphase.") - - -def test_magphase_exception(): - """ - Feature: Magphase - Description: Test Magphase with invalid input - Expectation: Correct error is raised as expected - """ - logger.info("Test Magphase not callable.") - - try: - input_number = np.array([1, 2, 3, 4]).reshape(4,).astype("double") - magphase_window = audio.Magphase(power=2.0) - _ = magphase_window(input_number) - except RuntimeError as error: - logger.info("Got an exception in Magphase: {}".format(str(error))) - assert "the shape of input tensor does not match the requirement of operator" in str(error) - try: - input_number = np.array([1, 2, 3, 4]).reshape(1, 4).astype("double") - magphase_window = audio.Magphase(power=2.0) - _ = magphase_window(input_number) - except RuntimeError as error: - logger.info("Got an exception in Magphase: {}".format(str(error))) - assert "the shape of input tensor does not match the requirement of operator" in str(error) - try: - input_number = np.array(['test', 'test']).reshape(1, 2) - magphase_window = audio.Magphase(power=2.0) - _ = magphase_window(input_number) - except RuntimeError as error: - logger.info("Got an exception in Magphase: {}".format(str(error))) - assert "the data type of input tensor does not match the requirement of operator" in str(error) - try: - input_number = np.array([1, 2, 3, 4]).reshape(2, 2).astype("double") - magphase_window = audio.Magphase(power=-1.0) - _ = magphase_window(input_number) - except ValueError as error: - logger.info("Got an exception in Magphase: {}".format(str(error))) - assert "Input power is not within the required interval of [0, 16777216]." in str(error) - - logger.info("Finish testing Magphase.") - - -if __name__ == "__main__": - test_magphase_pipeline() - test_magphase_eager() - test_magphase_exception() +# Copyright 2021-2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Testing Magphase Python API +""" +import numpy as np + +import mindspore.dataset as ds +import mindspore.dataset.audio as audio +from mindspore import log as logger + + +def test_magphase_pipeline(): + """ + Feature: Magphase + Description: Test Magphase in pipeline mode + Expectation: Output is equal to the expected output + """ + logger.info("Test Magphase pipeline.") + + data1 = [[[3.0, -4.0], [-5.0, 12.0]]] + expected = [5, 13, -0.927295, 1.965587] + dataset = ds.NumpySlicesDataset(data1, column_names=["col1"], shuffle=False) + magphase_window = audio.Magphase(power=1.0) + dataset = dataset.map(operations=magphase_window, input_columns=["col1"], + output_columns=["mag", "phase"]) + for data1, data2 in dataset.create_tuple_iterator(num_epochs=1, output_numpy=True): + assert abs(data1[0] - expected[0]) < 0.00001 + assert abs(data1[1] - expected[1]) < 0.00001 + assert abs(data2[0] - expected[2]) < 0.00001 + assert abs(data2[1] - expected[3]) < 0.00001 + + logger.info("Finish testing Magphase.") + + +def test_magphase_eager(): + """ + Feature: Magphase + Description: Test Magphase in eager mode + Expectation: Output is equal to the expected output + """ + logger.info("Test Magphase eager.") + + input_number = np.array([41, 67, 34, 0, 69, 24, 78, 58]).reshape((2, 2, 2)).astype("double") + mag = np.array([78.54934755, 34., 73.05477397, 97.20082304]).reshape((2, 2)).astype("double") + phase = np.array([1.02164342, 0, 0.33473684, 0.63938591]).reshape((2, 2)).astype("double") + magphase_window = audio.Magphase() + data1, data2 = magphase_window(input_number) + assert (abs(data1 - mag) < 0.00001).all() + assert (abs(data2 - phase) < 0.00001).all() + + logger.info("Finish testing Magphase.") + + +def test_magphase_exception(): + """ + Feature: Magphase + Description: Test Magphase with invalid input + Expectation: Correct error is raised as expected + """ + logger.info("Test Magphase not callable.") + + try: + input_number = np.array([1, 2, 3, 4]).reshape(4,).astype("double") + magphase_window = audio.Magphase(power=2.0) + _ = magphase_window(input_number) + except RuntimeError as error: + logger.info("Got an exception in Magphase: {}".format(str(error))) + assert "the shape of input tensor does not match the requirement of operator" in str(error) + try: + input_number = np.array([1, 2, 3, 4]).reshape(1, 4).astype("double") + magphase_window = audio.Magphase(power=2.0) + _ = magphase_window(input_number) + except RuntimeError as error: + logger.info("Got an exception in Magphase: {}".format(str(error))) + assert "the shape of input tensor does not match the requirement of operator" in str(error) + try: + input_number = np.array(['test', 'test']).reshape(1, 2) + magphase_window = audio.Magphase(power=2.0) + _ = magphase_window(input_number) + except RuntimeError as error: + logger.info("Got an exception in Magphase: {}".format(str(error))) + assert "the data type of input tensor does not match the requirement of operator" in str(error) + try: + input_number = np.array([1, 2, 3, 4]).reshape(2, 2).astype("double") + magphase_window = audio.Magphase(power=-1.0) + _ = magphase_window(input_number) + except ValueError as error: + logger.info("Got an exception in Magphase: {}".format(str(error))) + assert "Input power is not within the required interval of [0, 16777216]." in str(error) + + logger.info("Finish testing Magphase.") + + +if __name__ == "__main__": + test_magphase_pipeline() + test_magphase_eager() + test_magphase_exception() diff --git a/tests/ut/python/dataset/test_mask_along_axis.py b/tests/ut/python/dataset/test_mask_along_axis.py index 806e587d1bd..b1c6a675859 100644 --- a/tests/ut/python/dataset/test_mask_along_axis.py +++ b/tests/ut/python/dataset/test_mask_along_axis.py @@ -1,167 +1,167 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -import copy -import numpy as np -import pytest - -import mindspore.dataset as ds -import mindspore.dataset.audio as atf -from mindspore import log as logger - -CHANNEL = 1 -FREQ = 5 -TIME = 5 - - -def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True): - """ - Precision calculation formula - """ - if np.any(np.isnan(data_expected)): - assert np.allclose(data_me, data_expected, rtol, atol, equal_nan=equal_nan) - elif not np.allclose(data_me, data_expected, rtol, atol, equal_nan=equal_nan): - count_unequal_element(data_expected, data_me, rtol, atol) - - -def count_unequal_element(data_expected, data_me, rtol, atol): - """ - Precision calculation func - """ - assert data_expected.shape == data_me.shape - total_count = len(data_expected.flatten()) - error = np.abs(data_expected - data_me) - greater = np.greater(error, atol + np.abs(data_expected) * rtol) - loss_count = np.count_nonzero(greater) - assert (loss_count / total_count) < rtol, "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".format( - data_expected[greater], data_me[greater], error[greater]) - - -def gen(shape): - np.random.seed(0) - data = np.random.random(shape) - yield(np.array(data, dtype=np.float32),) - - -def test_mask_along_axis_eager_random_input(): - """ - Feature: MaskAlongAxis - Description: Mindspore eager mode normal testcase with random input tensor - Expectation: The returned result is as expected - """ - logger.info("test Mask_Along_axis op") - spectrogram = next(gen((CHANNEL, FREQ, TIME)))[0] - expect_output = copy.deepcopy(spectrogram) - out_put = atf.MaskAlongAxis(mask_start=0, mask_width=1, mask_value=5.0, axis=2)(spectrogram) - for item in expect_output[0]: - item[0] = 5.0 - assert out_put.shape == (CHANNEL, FREQ, TIME) - allclose_nparray(out_put, expect_output, 0.0001, 0.0001) - - -def test_mask_along_axis_eager_precision(): - """ - Feature: MaskAlongAxis - Description: Mindspore eager mode checking precision - Expectation: The returned result is as expected - """ - logger.info("test MaskAlongAxis op, checking precision") - spectrogram_0 = np.array([[[-0.0635, -0.6903], - [-1.7175, -0.0815], - [0.7981, -0.8297], - [-0.4589, -0.7506]], - [[0.6189, 1.1874], - [0.1856, -0.5536], - [1.0620, 0.2071], - [-0.3874, 0.0664]]]).astype(np.float32) - out_ms_0 = atf.MaskAlongAxis(mask_start=0, mask_width=1, mask_value=2.0, axis=2)(spectrogram_0) - spectrogram_1 = np.array([[[-0.0635, -0.6903], - [-1.7175, -0.0815], - [0.7981, -0.8297], - [-0.4589, -0.7506]], - [[0.6189, 1.1874], - [0.1856, -0.5536], - [1.0620, 0.2071], - [-0.3874, 0.0664]]]).astype(np.float64) - out_ms_1 = atf.MaskAlongAxis(mask_start=0, mask_width=1, mask_value=2.0, axis=2)(spectrogram_1) - out_benchmark = np.array([[[2.0000, -0.6903], - [2.0000, -0.0815], - [2.0000, -0.8297], - [2.0000, -0.7506]], - [[2.0000, 1.1874], - [2.0000, -0.5536], - [2.0000, 0.2071], - [2.0000, 0.0664]]]).astype(np.float32) - allclose_nparray(out_ms_0, out_benchmark, 0.0001, 0.0001) - allclose_nparray(out_ms_1, out_benchmark, 0.0001, 0.0001) - - -def test_mask_along_axis_pipeline(): - """ - Feature: MaskAlongAxis - Description: Mindspore pipeline mode normal testcase - Expectation: The returned result is as expected - """ - logger.info("test MaskAlongAxis op, pipeline") - - generator = gen((CHANNEL, FREQ, TIME)) - expect_output = copy.deepcopy(next(gen((CHANNEL, FREQ, TIME)))[0]) - data1 = ds.GeneratorDataset(source=generator, column_names=["multi_dimensional_data"]) - transforms = [atf.MaskAlongAxis(mask_start=2, mask_width=2, mask_value=2.0, axis=2)] - data1 = data1.map(operations=transforms, input_columns=["multi_dimensional_data"]) - - for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): - out_put = item["multi_dimensional_data"] - - for item in expect_output[0]: - item[2] = 2.0 - item[3] = 2.0 - assert out_put.shape == (CHANNEL, FREQ, TIME) - allclose_nparray(out_put, expect_output, 0.0001, 0.0001) - - -def test_mask_along_axis_invalid_input(): - """ - Feature: MaskAlongAxis - Description: Mindspore eager mode with invalid input tensor - Expectation: Throw correct error and message - """ - def test_invalid_param(test_name, mask_start, mask_width, mask_value, axis, error, error_msg): - """ - a function used for checking correct error and message with various input - """ - logger.info("Test MaskAlongAxis with wrong params: {0}".format(test_name)) - with pytest.raises(error) as error_info: - atf.MaskAlongAxis(mask_start, mask_width, mask_value, axis) - assert error_msg in str(error_info.value) - - test_invalid_param("invalid mask_start", -1, 10, 1.0, 1, ValueError, - "Input mask_start is not within the required interval of [0, 2147483647].") - test_invalid_param("invalid mask_width", 0, -1, 1.0, 1, ValueError, - "Input mask_width is not within the required interval of [1, 2147483647].") - test_invalid_param("invalid axis", 0, 10, 1.0, 1.0, TypeError, - "Argument axis with value 1.0 is not of type [], but got .") - test_invalid_param("invalid axis", 0, 10, 1.0, 0, ValueError, - "Input axis is not within the required interval of [1, 2].") - test_invalid_param("invalid axis", 0, 10, 1.0, 3, ValueError, - "Input axis is not within the required interval of [1, 2].") - test_invalid_param("invalid axis", 0, 10, 1.0, -1, ValueError, - "Input axis is not within the required interval of [1, 2].") - - -if __name__ == "__main__": - test_mask_along_axis_eager_random_input() - test_mask_along_axis_eager_precision() - test_mask_along_axis_pipeline() - test_mask_along_axis_invalid_input() +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import copy +import numpy as np +import pytest + +import mindspore.dataset as ds +import mindspore.dataset.audio as atf +from mindspore import log as logger + +CHANNEL = 1 +FREQ = 5 +TIME = 5 + + +def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True): + """ + Precision calculation formula + """ + if np.any(np.isnan(data_expected)): + assert np.allclose(data_me, data_expected, rtol, atol, equal_nan=equal_nan) + elif not np.allclose(data_me, data_expected, rtol, atol, equal_nan=equal_nan): + count_unequal_element(data_expected, data_me, rtol, atol) + + +def count_unequal_element(data_expected, data_me, rtol, atol): + """ + Precision calculation func + """ + assert data_expected.shape == data_me.shape + total_count = len(data_expected.flatten()) + error = np.abs(data_expected - data_me) + greater = np.greater(error, atol + np.abs(data_expected) * rtol) + loss_count = np.count_nonzero(greater) + assert (loss_count / total_count) < rtol, "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".format( + data_expected[greater], data_me[greater], error[greater]) + + +def gen(shape): + np.random.seed(0) + data = np.random.random(shape) + yield(np.array(data, dtype=np.float32),) + + +def test_mask_along_axis_eager_random_input(): + """ + Feature: MaskAlongAxis + Description: Mindspore eager mode normal testcase with random input tensor + Expectation: The returned result is as expected + """ + logger.info("test Mask_Along_axis op") + spectrogram = next(gen((CHANNEL, FREQ, TIME)))[0] + expect_output = copy.deepcopy(spectrogram) + out_put = atf.MaskAlongAxis(mask_start=0, mask_width=1, mask_value=5.0, axis=2)(spectrogram) + for item in expect_output[0]: + item[0] = 5.0 + assert out_put.shape == (CHANNEL, FREQ, TIME) + allclose_nparray(out_put, expect_output, 0.0001, 0.0001) + + +def test_mask_along_axis_eager_precision(): + """ + Feature: MaskAlongAxis + Description: Mindspore eager mode checking precision + Expectation: The returned result is as expected + """ + logger.info("test MaskAlongAxis op, checking precision") + spectrogram_0 = np.array([[[-0.0635, -0.6903], + [-1.7175, -0.0815], + [0.7981, -0.8297], + [-0.4589, -0.7506]], + [[0.6189, 1.1874], + [0.1856, -0.5536], + [1.0620, 0.2071], + [-0.3874, 0.0664]]]).astype(np.float32) + out_ms_0 = atf.MaskAlongAxis(mask_start=0, mask_width=1, mask_value=2.0, axis=2)(spectrogram_0) + spectrogram_1 = np.array([[[-0.0635, -0.6903], + [-1.7175, -0.0815], + [0.7981, -0.8297], + [-0.4589, -0.7506]], + [[0.6189, 1.1874], + [0.1856, -0.5536], + [1.0620, 0.2071], + [-0.3874, 0.0664]]]).astype(np.float64) + out_ms_1 = atf.MaskAlongAxis(mask_start=0, mask_width=1, mask_value=2.0, axis=2)(spectrogram_1) + out_benchmark = np.array([[[2.0000, -0.6903], + [2.0000, -0.0815], + [2.0000, -0.8297], + [2.0000, -0.7506]], + [[2.0000, 1.1874], + [2.0000, -0.5536], + [2.0000, 0.2071], + [2.0000, 0.0664]]]).astype(np.float32) + allclose_nparray(out_ms_0, out_benchmark, 0.0001, 0.0001) + allclose_nparray(out_ms_1, out_benchmark, 0.0001, 0.0001) + + +def test_mask_along_axis_pipeline(): + """ + Feature: MaskAlongAxis + Description: Mindspore pipeline mode normal testcase + Expectation: The returned result is as expected + """ + logger.info("test MaskAlongAxis op, pipeline") + + generator = gen((CHANNEL, FREQ, TIME)) + expect_output = copy.deepcopy(next(gen((CHANNEL, FREQ, TIME)))[0]) + data1 = ds.GeneratorDataset(source=generator, column_names=["multi_dimensional_data"]) + transforms = [atf.MaskAlongAxis(mask_start=2, mask_width=2, mask_value=2.0, axis=2)] + data1 = data1.map(operations=transforms, input_columns=["multi_dimensional_data"]) + + for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): + out_put = item["multi_dimensional_data"] + + for item in expect_output[0]: + item[2] = 2.0 + item[3] = 2.0 + assert out_put.shape == (CHANNEL, FREQ, TIME) + allclose_nparray(out_put, expect_output, 0.0001, 0.0001) + + +def test_mask_along_axis_invalid_input(): + """ + Feature: MaskAlongAxis + Description: Mindspore eager mode with invalid input tensor + Expectation: Throw correct error and message + """ + def test_invalid_param(test_name, mask_start, mask_width, mask_value, axis, error, error_msg): + """ + a function used for checking correct error and message with various input + """ + logger.info("Test MaskAlongAxis with wrong params: {0}".format(test_name)) + with pytest.raises(error) as error_info: + atf.MaskAlongAxis(mask_start, mask_width, mask_value, axis) + assert error_msg in str(error_info.value) + + test_invalid_param("invalid mask_start", -1, 10, 1.0, 1, ValueError, + "Input mask_start is not within the required interval of [0, 2147483647].") + test_invalid_param("invalid mask_width", 0, -1, 1.0, 1, ValueError, + "Input mask_width is not within the required interval of [1, 2147483647].") + test_invalid_param("invalid axis", 0, 10, 1.0, 1.0, TypeError, + "Argument axis with value 1.0 is not of type [], but got .") + test_invalid_param("invalid axis", 0, 10, 1.0, 0, ValueError, + "Input axis is not within the required interval of [1, 2].") + test_invalid_param("invalid axis", 0, 10, 1.0, 3, ValueError, + "Input axis is not within the required interval of [1, 2].") + test_invalid_param("invalid axis", 0, 10, 1.0, -1, ValueError, + "Input axis is not within the required interval of [1, 2].") + + +if __name__ == "__main__": + test_mask_along_axis_eager_random_input() + test_mask_along_axis_eager_precision() + test_mask_along_axis_pipeline() + test_mask_along_axis_invalid_input() diff --git a/tests/ut/python/dataset/test_mask_along_axis_iid.py b/tests/ut/python/dataset/test_mask_along_axis_iid.py index 3373445319b..131d80bab82 100644 --- a/tests/ut/python/dataset/test_mask_along_axis_iid.py +++ b/tests/ut/python/dataset/test_mask_along_axis_iid.py @@ -1,124 +1,124 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -import copy -import numpy as np -import pytest - -import mindspore.dataset as ds -import mindspore.dataset.audio as audio -from mindspore import log as logger - -BATCH = 2 -CHANNEL = 2 -FREQ = 10 -TIME = 10 - - -def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True): - """ - Precision calculation formula - """ - if np.any(np.isnan(data_expected)): - assert np.allclose(data_me, data_expected, rtol, atol, equal_nan=equal_nan) - elif not np.allclose(data_me, data_expected, rtol, atol, equal_nan=equal_nan): - count_unequal_element(data_expected, data_me, rtol, atol) - - -def count_unequal_element(data_expected, data_me, rtol, atol): - """ - Precision calculation func - """ - assert data_expected.shape == data_me.shape - total_count = len(data_expected.flatten()) - error = np.abs(data_expected - data_me) - greater = np.greater(error, atol + np.abs(data_expected) * rtol) - loss_count = np.count_nonzero(greater) - assert (loss_count / total_count) < rtol, "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".format( - data_expected[greater], data_me[greater], error[greater]) - - -def gen(shape): - np.random.seed(0) - data = np.random.random(shape) - yield (np.array(data, dtype=np.float32),) - - -def test_mask_along_axis_iid_eager(): - """ - Feature: MaskAlongAxisIID - Description: Mindspore eager mode with normal testcase - Expectation: The returned result is as expected - """ - logger.info("test MaskAlongAxisIID op, eager") - spectrogram_01 = next(gen((BATCH, CHANNEL, FREQ, TIME)))[0] - output_01 = audio.MaskAlongAxisIID(mask_param=8, mask_value=5.0, axis=1)(spectrogram_01) - assert output_01.shape == (BATCH, CHANNEL, FREQ, TIME) - - spectrogram_02 = next(gen((BATCH, CHANNEL, FREQ, TIME)))[0] - expect_output = copy.deepcopy(spectrogram_02) - output_02 = audio.MaskAlongAxisIID(mask_param=0, mask_value=5.0, axis=1)(spectrogram_02) - allclose_nparray(output_02, expect_output, 0.0001, 0.0001) - - -def test_mask_along_axis_iid_pipeline(): - """ - Feature: MaskAlongAxisIID - Description: Mindspore pipeline mode with normal testcase - Expectation: The returned result is as expected - """ - logger.info("test MaskAlongAxisIID op, pipeline") - - generator = gen([BATCH, CHANNEL, FREQ, TIME]) - data1 = ds.GeneratorDataset(source=generator, column_names=["multi_dimensional_data"]) - - transforms = [audio.MaskAlongAxisIID(mask_param=8, mask_value=5.0, axis=2)] - data1 = data1.map(operations=transforms, input_columns=["multi_dimensional_data"]) - - for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): - out_put = item["multi_dimensional_data"] - assert out_put.shape == (BATCH, CHANNEL, FREQ, TIME) - - -def test_mask_along_axis_iid_invalid_input(): - """ - Feature: MaskAlongAxisIID - Description: Mindspore eager mode with invalid input - Expectation: The returned result is as expected - """ - def test_invalid_param(test_name, mask_param, mask_value, axis, error, error_msg): - """ - a function used for checking correct error and message - """ - logger.info("Test MaskAlongAxisIID with wrong params: {0}".format(test_name)) - with pytest.raises(error) as error_info: - audio.MaskAlongAxisIID(mask_param, mask_value, axis) - assert error_msg in str(error_info.value) - - test_invalid_param("invalid mask_param", 1.0, 1.0, 1, TypeError, - "Argument mask_param with value 1.0 is not of type [], but got .") - test_invalid_param("invalid mask_param", -1, 1.0, 1, ValueError, - "Input mask_param is not within the required interval of [0, 2147483647].") - test_invalid_param("invalid axis", 5, 1.0, 5.0, TypeError, - "Argument axis with value 5.0 is not of type [], but got .") - test_invalid_param("invalid axis", 5, 1.0, 0, ValueError, - "Input axis is not within the required interval of [1, 2].") - test_invalid_param("invalid axis", 5, 1.0, 3, ValueError, - "Input axis is not within the required interval of [1, 2].") - - -if __name__ == "__main__": - test_mask_along_axis_iid_eager() - test_mask_along_axis_iid_invalid_input() - test_mask_along_axis_iid_pipeline() +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import copy +import numpy as np +import pytest + +import mindspore.dataset as ds +import mindspore.dataset.audio as audio +from mindspore import log as logger + +BATCH = 2 +CHANNEL = 2 +FREQ = 10 +TIME = 10 + + +def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True): + """ + Precision calculation formula + """ + if np.any(np.isnan(data_expected)): + assert np.allclose(data_me, data_expected, rtol, atol, equal_nan=equal_nan) + elif not np.allclose(data_me, data_expected, rtol, atol, equal_nan=equal_nan): + count_unequal_element(data_expected, data_me, rtol, atol) + + +def count_unequal_element(data_expected, data_me, rtol, atol): + """ + Precision calculation func + """ + assert data_expected.shape == data_me.shape + total_count = len(data_expected.flatten()) + error = np.abs(data_expected - data_me) + greater = np.greater(error, atol + np.abs(data_expected) * rtol) + loss_count = np.count_nonzero(greater) + assert (loss_count / total_count) < rtol, "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".format( + data_expected[greater], data_me[greater], error[greater]) + + +def gen(shape): + np.random.seed(0) + data = np.random.random(shape) + yield (np.array(data, dtype=np.float32),) + + +def test_mask_along_axis_iid_eager(): + """ + Feature: MaskAlongAxisIID + Description: Mindspore eager mode with normal testcase + Expectation: The returned result is as expected + """ + logger.info("test MaskAlongAxisIID op, eager") + spectrogram_01 = next(gen((BATCH, CHANNEL, FREQ, TIME)))[0] + output_01 = audio.MaskAlongAxisIID(mask_param=8, mask_value=5.0, axis=1)(spectrogram_01) + assert output_01.shape == (BATCH, CHANNEL, FREQ, TIME) + + spectrogram_02 = next(gen((BATCH, CHANNEL, FREQ, TIME)))[0] + expect_output = copy.deepcopy(spectrogram_02) + output_02 = audio.MaskAlongAxisIID(mask_param=0, mask_value=5.0, axis=1)(spectrogram_02) + allclose_nparray(output_02, expect_output, 0.0001, 0.0001) + + +def test_mask_along_axis_iid_pipeline(): + """ + Feature: MaskAlongAxisIID + Description: Mindspore pipeline mode with normal testcase + Expectation: The returned result is as expected + """ + logger.info("test MaskAlongAxisIID op, pipeline") + + generator = gen([BATCH, CHANNEL, FREQ, TIME]) + data1 = ds.GeneratorDataset(source=generator, column_names=["multi_dimensional_data"]) + + transforms = [audio.MaskAlongAxisIID(mask_param=8, mask_value=5.0, axis=2)] + data1 = data1.map(operations=transforms, input_columns=["multi_dimensional_data"]) + + for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): + out_put = item["multi_dimensional_data"] + assert out_put.shape == (BATCH, CHANNEL, FREQ, TIME) + + +def test_mask_along_axis_iid_invalid_input(): + """ + Feature: MaskAlongAxisIID + Description: Mindspore eager mode with invalid input + Expectation: The returned result is as expected + """ + def test_invalid_param(test_name, mask_param, mask_value, axis, error, error_msg): + """ + a function used for checking correct error and message + """ + logger.info("Test MaskAlongAxisIID with wrong params: {0}".format(test_name)) + with pytest.raises(error) as error_info: + audio.MaskAlongAxisIID(mask_param, mask_value, axis) + assert error_msg in str(error_info.value) + + test_invalid_param("invalid mask_param", 1.0, 1.0, 1, TypeError, + "Argument mask_param with value 1.0 is not of type [], but got .") + test_invalid_param("invalid mask_param", -1, 1.0, 1, ValueError, + "Input mask_param is not within the required interval of [0, 2147483647].") + test_invalid_param("invalid axis", 5, 1.0, 5.0, TypeError, + "Argument axis with value 5.0 is not of type [], but got .") + test_invalid_param("invalid axis", 5, 1.0, 0, ValueError, + "Input axis is not within the required interval of [1, 2].") + test_invalid_param("invalid axis", 5, 1.0, 3, ValueError, + "Input axis is not within the required interval of [1, 2].") + + +if __name__ == "__main__": + test_mask_along_axis_iid_eager() + test_mask_along_axis_iid_invalid_input() + test_mask_along_axis_iid_pipeline() diff --git a/tests/ut/python/dataset/test_phase_vocoder.py b/tests/ut/python/dataset/test_phase_vocoder.py index 2f9785ba5e8..d2b8bf70c7b 100644 --- a/tests/ut/python/dataset/test_phase_vocoder.py +++ b/tests/ut/python/dataset/test_phase_vocoder.py @@ -1,161 +1,161 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import numpy as np -import pytest - -import mindspore.dataset as ds -import mindspore.dataset.audio as audio -from mindspore import log as logger - - -def gen(shape): - np.random.seed(0) - data = np.random.random(shape) - yield (np.array(data, dtype=np.float32),) - - -def count_unequal_element(data_expected, data_me, rtol, atol): - assert data_expected.shape == data_me.shape - total_count = len(data_expected.flatten()) - error = np.abs(data_expected - data_me) - greater = np.greater(error, atol + np.abs(data_expected) * rtol) - loss_count = np.count_nonzero(greater) - assert (loss_count / total_count) < rtol, \ - "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \ - format(data_expected[greater], data_me[greater], error[greater]) - - -def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True): - if np.any(np.isnan(data_expected)): - assert np.allclose(data_me, data_expected, rtol, atol, equal_nan=equal_nan) - elif not np.allclose(data_me, data_expected, rtol, atol, equal_nan=equal_nan): - count_unequal_element(data_expected, data_me, rtol, atol) - - -def test_phase_vocoder_compare(): - """ - Feature: PhaseVocoder - Description: Mindspore eager mode checking precision - Expectation: The returned result is as expected - """ - indata_0 = np.array([[[[0.43189, 2.3049924], - [-0.01202229, 0.9176453], - [-0.6258611, 0.66475236], - [0.13541847, 1.2829605], - [0.9725325, 1.1669061]], - [[-0.35001752, -1.0989336], - [-1.4930767, 0.86829656], - [0.3355314, -0.41216415], - [-1.1828239, 1.0075365], - [-0.19343425, 0.38364533]]]]).astype('float32') - indata_1 = np.array([[[[0.43189, 2.3049924], - [-0.01202229, 0.9176453], - [-0.6258611, 0.66475236], - [0.13541847, 1.2829605], - [0.9725325, 1.1669061]], - [[-0.35001752, -1.0989336], - [-1.4930767, 0.86829656], - [0.3355314, -0.41216415], - [-1.1828239, 1.0075365], - [-0.19343425, 0.38364533]]]]).astype('float64') - rate = 2. - phase_advance_0 = np.array([[0.0000], [3.9270]]).astype('float32') - op_0 = audio.PhaseVocoder(rate, phase_advance_0) - phase_advance_1 = np.array([[0.0000], [3.9270]]).astype('float64') - op_1 = audio.PhaseVocoder(rate, phase_advance_1) - outdata_0 = op_0(indata_0) - outdata_1 = op_1(indata_1) - stand_outdata = np.array([[[[0.43189007, 2.3049924], - [-0.01196056, 0.9129374], - [1.1385509, 1.00558]], - [[-0.35001755, -1.0989336], - [-0.4594292, 0.26718047], - [0.404371, -0.14520557]]]]).astype('float32') - allclose_nparray(outdata_0, stand_outdata, 0.0001, 0.0001) - allclose_nparray(outdata_1, stand_outdata, 0.0001, 0.0001) - - -def test_phase_vocoder_eager(): - """ - Feature: PhaseVocoder - Description: Mindspore eager mode with normal testcase - Expectation: The returned result is as expected - """ - logger.info("test PhaseVocoder op in eager mode") - stft = next(gen([10, 10, 10, 2]))[0] - out_put = audio.PhaseVocoder(1.3, np.random.randn(10, 1).astype('float32'))(stft) - assert out_put.shape == (10, 10, 8, 2) - - -def test_phase_vocoder_pipeline(): - """ - Feature: PhaseVocoder - Description: Mindspore pipeline mode with normal testcase - Expectation: The returned result is as expected - """ - logger.info("test PhaseVocoder op in pipeline mode") - - generator = gen([32, 33, 333, 2]) - data1 = ds.GeneratorDataset(source=generator, column_names=["input"]) - - transforms = [audio.PhaseVocoder(0.8, np.random.randn(33, 1).astype('float32'))] - data1 = data1.map(operations=transforms, input_columns=["input"]) - - for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): - out_put = item["input"] - assert out_put.shape == (32, 33, 417, 2) - - -def test_phase_vocoder_invalid_input(): - """ - Feature: PhaseVocoder - Description: Mindspore eager mode with invalid input - Expectation: The returned result is as expected - """ - def test_invalid_param(test_name, rate, phase_advance, error, error_msg): - logger.info("Test PhaseVocoder with wrong params: {0}".format(test_name)) - with pytest.raises(error) as error_info: - _ = audio.PhaseVocoder(rate, phase_advance) - assert error_msg in str(error_info.value) - - def test_invalid_input(test_name, spec, rate, phase_advance, error, error_msg): - logger.info("Test PhaseVocoder with wrong params: {0}".format(test_name)) - with pytest.raises(error) as error_info: - _ = audio.PhaseVocoder(rate, phase_advance)(spec) - assert error_msg in str(error_info.value) - - test_invalid_param("invalid phase_advance", 2, None, TypeError, - "Argument phase_advance with value None is not of type") - test_invalid_param("invalid phase_advance", 0, np.random.randn(4, 1), ValueError, - "Input rate is not within the required interval of (0, 16777216].") - spec = next(gen([1, 2, 2]))[0] - test_invalid_input("invalid phase_advance", spec, 1.23, np.random.randn(4), RuntimeError, - "PhaseVocoder: invalid parameter, 'phase_advance' should be in shape of .") - test_invalid_input("invalid phase_advance", spec, 1.1, np.random.randn(4, 4, 1), RuntimeError, - "PhaseVocoder: invalid parameter, 'phase_advance' should be in shape of .") - test_invalid_input("invalid input tensor", spec, 2, np.random.randn(3, 1), RuntimeError, - "PhaseVocoder: invalid parameter, 'first dimension of 'phase_advance'' should be equal") - input_tensor = np.random.randn(4, 4, 2).astype('float32') - input_phase_advance = np.random.randn(4, 1).astype('float64') - test_invalid_input("invalid input tensor", input_tensor, 2, input_phase_advance, RuntimeError, - "PhaseVocoder: invalid parameter, data type of phase_advance should be equal to data") - - -if __name__ == "__main__": - test_phase_vocoder_compare() - test_phase_vocoder_eager() - test_phase_vocoder_pipeline() - test_phase_vocoder_invalid_input() +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import numpy as np +import pytest + +import mindspore.dataset as ds +import mindspore.dataset.audio as audio +from mindspore import log as logger + + +def gen(shape): + np.random.seed(0) + data = np.random.random(shape) + yield (np.array(data, dtype=np.float32),) + + +def count_unequal_element(data_expected, data_me, rtol, atol): + assert data_expected.shape == data_me.shape + total_count = len(data_expected.flatten()) + error = np.abs(data_expected - data_me) + greater = np.greater(error, atol + np.abs(data_expected) * rtol) + loss_count = np.count_nonzero(greater) + assert (loss_count / total_count) < rtol, \ + "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \ + format(data_expected[greater], data_me[greater], error[greater]) + + +def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True): + if np.any(np.isnan(data_expected)): + assert np.allclose(data_me, data_expected, rtol, atol, equal_nan=equal_nan) + elif not np.allclose(data_me, data_expected, rtol, atol, equal_nan=equal_nan): + count_unequal_element(data_expected, data_me, rtol, atol) + + +def test_phase_vocoder_compare(): + """ + Feature: PhaseVocoder + Description: Mindspore eager mode checking precision + Expectation: The returned result is as expected + """ + indata_0 = np.array([[[[0.43189, 2.3049924], + [-0.01202229, 0.9176453], + [-0.6258611, 0.66475236], + [0.13541847, 1.2829605], + [0.9725325, 1.1669061]], + [[-0.35001752, -1.0989336], + [-1.4930767, 0.86829656], + [0.3355314, -0.41216415], + [-1.1828239, 1.0075365], + [-0.19343425, 0.38364533]]]]).astype('float32') + indata_1 = np.array([[[[0.43189, 2.3049924], + [-0.01202229, 0.9176453], + [-0.6258611, 0.66475236], + [0.13541847, 1.2829605], + [0.9725325, 1.1669061]], + [[-0.35001752, -1.0989336], + [-1.4930767, 0.86829656], + [0.3355314, -0.41216415], + [-1.1828239, 1.0075365], + [-0.19343425, 0.38364533]]]]).astype('float64') + rate = 2. + phase_advance_0 = np.array([[0.0000], [3.9270]]).astype('float32') + op_0 = audio.PhaseVocoder(rate, phase_advance_0) + phase_advance_1 = np.array([[0.0000], [3.9270]]).astype('float64') + op_1 = audio.PhaseVocoder(rate, phase_advance_1) + outdata_0 = op_0(indata_0) + outdata_1 = op_1(indata_1) + stand_outdata = np.array([[[[0.43189007, 2.3049924], + [-0.01196056, 0.9129374], + [1.1385509, 1.00558]], + [[-0.35001755, -1.0989336], + [-0.4594292, 0.26718047], + [0.404371, -0.14520557]]]]).astype('float32') + allclose_nparray(outdata_0, stand_outdata, 0.0001, 0.0001) + allclose_nparray(outdata_1, stand_outdata, 0.0001, 0.0001) + + +def test_phase_vocoder_eager(): + """ + Feature: PhaseVocoder + Description: Mindspore eager mode with normal testcase + Expectation: The returned result is as expected + """ + logger.info("test PhaseVocoder op in eager mode") + stft = next(gen([10, 10, 10, 2]))[0] + out_put = audio.PhaseVocoder(1.3, np.random.randn(10, 1).astype('float32'))(stft) + assert out_put.shape == (10, 10, 8, 2) + + +def test_phase_vocoder_pipeline(): + """ + Feature: PhaseVocoder + Description: Mindspore pipeline mode with normal testcase + Expectation: The returned result is as expected + """ + logger.info("test PhaseVocoder op in pipeline mode") + + generator = gen([32, 33, 333, 2]) + data1 = ds.GeneratorDataset(source=generator, column_names=["input"]) + + transforms = [audio.PhaseVocoder(0.8, np.random.randn(33, 1).astype('float32'))] + data1 = data1.map(operations=transforms, input_columns=["input"]) + + for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): + out_put = item["input"] + assert out_put.shape == (32, 33, 417, 2) + + +def test_phase_vocoder_invalid_input(): + """ + Feature: PhaseVocoder + Description: Mindspore eager mode with invalid input + Expectation: The returned result is as expected + """ + def test_invalid_param(test_name, rate, phase_advance, error, error_msg): + logger.info("Test PhaseVocoder with wrong params: {0}".format(test_name)) + with pytest.raises(error) as error_info: + _ = audio.PhaseVocoder(rate, phase_advance) + assert error_msg in str(error_info.value) + + def test_invalid_input(test_name, spec, rate, phase_advance, error, error_msg): + logger.info("Test PhaseVocoder with wrong params: {0}".format(test_name)) + with pytest.raises(error) as error_info: + _ = audio.PhaseVocoder(rate, phase_advance)(spec) + assert error_msg in str(error_info.value) + + test_invalid_param("invalid phase_advance", 2, None, TypeError, + "Argument phase_advance with value None is not of type") + test_invalid_param("invalid phase_advance", 0, np.random.randn(4, 1), ValueError, + "Input rate is not within the required interval of (0, 16777216].") + spec = next(gen([1, 2, 2]))[0] + test_invalid_input("invalid phase_advance", spec, 1.23, np.random.randn(4), RuntimeError, + "PhaseVocoder: invalid parameter, 'phase_advance' should be in shape of .") + test_invalid_input("invalid phase_advance", spec, 1.1, np.random.randn(4, 4, 1), RuntimeError, + "PhaseVocoder: invalid parameter, 'phase_advance' should be in shape of .") + test_invalid_input("invalid input tensor", spec, 2, np.random.randn(3, 1), RuntimeError, + "PhaseVocoder: invalid parameter, 'first dimension of 'phase_advance'' should be equal") + input_tensor = np.random.randn(4, 4, 2).astype('float32') + input_phase_advance = np.random.randn(4, 1).astype('float64') + test_invalid_input("invalid input tensor", input_tensor, 2, input_phase_advance, RuntimeError, + "PhaseVocoder: invalid parameter, data type of phase_advance should be equal to data") + + +if __name__ == "__main__": + test_phase_vocoder_compare() + test_phase_vocoder_eager() + test_phase_vocoder_pipeline() + test_phase_vocoder_invalid_input() diff --git a/tests/ut/python/dataset/test_posterize.py b/tests/ut/python/dataset/test_posterize.py index fc948ad5d22..9d573df7785 100644 --- a/tests/ut/python/dataset/test_posterize.py +++ b/tests/ut/python/dataset/test_posterize.py @@ -1,155 +1,155 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========================================================================= -""" -Testing Posterize op in DE -""" -import numpy as np -from numpy.testing import assert_allclose -from PIL import Image, ImageOps - -import mindspore -import mindspore.dataset as ds -import mindspore.dataset.vision as vision -import mindspore.log as logger - - -DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] -SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" - - -def test_posterize_op(): - """ - Feature: Posterize op - Description: Test eager support for Posterize Cpp implementation - Expectation: Receive correct output image from op - """ - logger.info("test_posterize_op_c") - for i in range(1, 9): - posterize_op = vision.Posterize(i) - - img_in = Image.open("../data/dataset/apple.jpg") - img_ms = posterize_op(img_in) - img_cv = np.array(ImageOps.posterize(img_in, i)) - assert_allclose(img_ms.flatten(), - img_cv.flatten(), - rtol=1e-5, - atol=0) - - -def test_posterize_exception_bit(): - """ - Feature: Posterize op - Description: Test Posterize with out of range or invalid type of input bits - Expectation: Errors and logs are as expected - """ - logger.info("test_posterize_exception_bit") - # Test max > 8 - try: - _ = vision.Posterize(9) - except ValueError as e: - logger.info("Got an exception in DE: {}".format(str(e))) - assert str(e) == "Input bits is not within the required interval of [0, 8]." - # Test min < 1 - try: - _ = vision.Posterize(-1) - except ValueError as e: - logger.info("Got an exception in DE: {}".format(str(e))) - assert str(e) == "Input bits is not within the required interval of [0, 8]." - # Test wrong type (not uint8) - try: - _ = vision.Posterize(1.1) - except TypeError as e: - logger.info("Got an exception in DE: {}".format(str(e))) - assert str(e) == "Argument bits with value 1.1 is not of type [], but got ." - # Test wrong number of bits - try: - _ = vision.Posterize((1, 1, 1)) - except TypeError as e: - logger.info("Got an exception in DE: {}".format(str(e))) - assert str(e) == "Argument bits with value (1, 1, 1) is not of type [], but got ." - - -def test_data_type_with_posterize(): - """ - Feature: Posterize op - Description: Test Posterize only support type CV_8S/CV_8U - Expectation: Errors and logs are as expected - """ - logger.info("test_data_type_with_posterize") - - data_dir_10 = "../data/dataset/testCifar10Data" - dataset = ds.Cifar10Dataset(data_dir_10) - - rescale_op = vision.Rescale((1.0 / 255.0), 0.0) - dataset = dataset.map(operations=rescale_op, input_columns=["image"]) - - posterize_op = vision.Posterize(4) - dataset = dataset.map(operations=posterize_op, input_columns=["image"], num_parallel_workers=1) - - try: - _ = dataset.output_shapes() - except RuntimeError as e: - logger.info("Got an exception in DE: {}".format(str(e))) - assert "data type of input image should be int" in str(e) - - -def test_posterize_pipeline(): - """ - Feature: Posterize op - Description: Test Posterize C implementation Pipeline - Expectation: Pass without error - """ - # First dataset - transforms1 = [vision.Decode(), vision.Resize([64, 64])] - transforms1 = mindspore.dataset.transforms.transforms.Compose( - transforms1) - ds1 = ds.TFRecordDataset(DATA_DIR, - SCHEMA_DIR, - columns_list=["image"], - shuffle=False) - ds1 = ds1.map(operations=transforms1, input_columns=["image"]) - - # Second dataset - transforms2 = [ - vision.Decode(), - vision.Resize([64, 64]), - vision.Posterize(8) - ] - transform2 = mindspore.dataset.transforms.transforms.Compose( - transforms2) - ds2 = ds.TFRecordDataset(DATA_DIR, - SCHEMA_DIR, - columns_list=["image"], - shuffle=False) - ds2 = ds2.map(operations=transform2, input_columns=["image"]) - - num_iter = 0 - for data1, data2 in zip(ds1.create_dict_iterator(num_epochs=1), - ds2.create_dict_iterator(num_epochs=1)): - num_iter += 1 - ori_img = data1["image"].asnumpy() - cvt_img = data2["image"].asnumpy() - assert_allclose(ori_img.flatten(), - cvt_img.flatten(), - rtol=1e-5, - atol=0) - assert ori_img.shape == cvt_img.shape - - -if __name__ == "__main__": - test_posterize_op() - test_posterize_exception_bit() - test_data_type_with_posterize() - test_posterize_pipeline() +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= +""" +Testing Posterize op in DE +""" +import numpy as np +from numpy.testing import assert_allclose +from PIL import Image, ImageOps + +import mindspore +import mindspore.dataset as ds +import mindspore.dataset.vision as vision +import mindspore.log as logger + + +DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] +SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" + + +def test_posterize_op(): + """ + Feature: Posterize op + Description: Test eager support for Posterize Cpp implementation + Expectation: Receive correct output image from op + """ + logger.info("test_posterize_op_c") + for i in range(1, 9): + posterize_op = vision.Posterize(i) + + img_in = Image.open("../data/dataset/apple.jpg") + img_ms = posterize_op(img_in) + img_cv = np.array(ImageOps.posterize(img_in, i)) + assert_allclose(img_ms.flatten(), + img_cv.flatten(), + rtol=1e-5, + atol=0) + + +def test_posterize_exception_bit(): + """ + Feature: Posterize op + Description: Test Posterize with out of range or invalid type of input bits + Expectation: Errors and logs are as expected + """ + logger.info("test_posterize_exception_bit") + # Test max > 8 + try: + _ = vision.Posterize(9) + except ValueError as e: + logger.info("Got an exception in DE: {}".format(str(e))) + assert str(e) == "Input bits is not within the required interval of [0, 8]." + # Test min < 1 + try: + _ = vision.Posterize(-1) + except ValueError as e: + logger.info("Got an exception in DE: {}".format(str(e))) + assert str(e) == "Input bits is not within the required interval of [0, 8]." + # Test wrong type (not uint8) + try: + _ = vision.Posterize(1.1) + except TypeError as e: + logger.info("Got an exception in DE: {}".format(str(e))) + assert str(e) == "Argument bits with value 1.1 is not of type [], but got ." + # Test wrong number of bits + try: + _ = vision.Posterize((1, 1, 1)) + except TypeError as e: + logger.info("Got an exception in DE: {}".format(str(e))) + assert str(e) == "Argument bits with value (1, 1, 1) is not of type [], but got ." + + +def test_data_type_with_posterize(): + """ + Feature: Posterize op + Description: Test Posterize only support type CV_8S/CV_8U + Expectation: Errors and logs are as expected + """ + logger.info("test_data_type_with_posterize") + + data_dir_10 = "../data/dataset/testCifar10Data" + dataset = ds.Cifar10Dataset(data_dir_10) + + rescale_op = vision.Rescale((1.0 / 255.0), 0.0) + dataset = dataset.map(operations=rescale_op, input_columns=["image"]) + + posterize_op = vision.Posterize(4) + dataset = dataset.map(operations=posterize_op, input_columns=["image"], num_parallel_workers=1) + + try: + _ = dataset.output_shapes() + except RuntimeError as e: + logger.info("Got an exception in DE: {}".format(str(e))) + assert "data type of input image should be int" in str(e) + + +def test_posterize_pipeline(): + """ + Feature: Posterize op + Description: Test Posterize C implementation Pipeline + Expectation: Pass without error + """ + # First dataset + transforms1 = [vision.Decode(), vision.Resize([64, 64])] + transforms1 = mindspore.dataset.transforms.transforms.Compose( + transforms1) + ds1 = ds.TFRecordDataset(DATA_DIR, + SCHEMA_DIR, + columns_list=["image"], + shuffle=False) + ds1 = ds1.map(operations=transforms1, input_columns=["image"]) + + # Second dataset + transforms2 = [ + vision.Decode(), + vision.Resize([64, 64]), + vision.Posterize(8) + ] + transform2 = mindspore.dataset.transforms.transforms.Compose( + transforms2) + ds2 = ds.TFRecordDataset(DATA_DIR, + SCHEMA_DIR, + columns_list=["image"], + shuffle=False) + ds2 = ds2.map(operations=transform2, input_columns=["image"]) + + num_iter = 0 + for data1, data2 in zip(ds1.create_dict_iterator(num_epochs=1), + ds2.create_dict_iterator(num_epochs=1)): + num_iter += 1 + ori_img = data1["image"].asnumpy() + cvt_img = data2["image"].asnumpy() + assert_allclose(ori_img.flatten(), + cvt_img.flatten(), + rtol=1e-5, + atol=0) + assert ori_img.shape == cvt_img.shape + + +if __name__ == "__main__": + test_posterize_op() + test_posterize_exception_bit() + test_data_type_with_posterize() + test_posterize_pipeline() diff --git a/tests/ut/python/dataset/test_read_file.py b/tests/ut/python/dataset/test_read_file.py old mode 100755 new mode 100644 diff --git a/tests/ut/python/dataset/test_read_image.py b/tests/ut/python/dataset/test_read_image.py old mode 100755 new mode 100644 diff --git a/tests/ut/python/dataset/test_vectors.py b/tests/ut/python/dataset/test_vectors.py index e41439b89ec..9434d6e4f6f 100644 --- a/tests/ut/python/dataset/test_vectors.py +++ b/tests/ut/python/dataset/test_vectors.py @@ -1,235 +1,235 @@ -# Copyright 2021-2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import numpy as np -import pytest - -from mindspore import log -import mindspore.dataset as ds -import mindspore.dataset.text as text -import mindspore.dataset.text.transforms as T - -DATASET_ROOT_PATH = "../data/dataset/testVectors/" - - -def test_vectors_all_tovectors_params_eager(): - """ - Feature: Vectors - Description: Test with all parameters which include `unk_init` - and `lower_case_backup` in function ToVectors in eager mode - Expectation: Output is equal to the expected value - """ - vectors = text.Vectors.from_file(DATASET_ROOT_PATH + "vectors.txt", max_vectors=4) - myUnk = [-1, -1, -1, -1, -1, -1] - to_vectors = T.ToVectors(vectors, unk_init=myUnk, lower_case_backup=True) - result1 = to_vectors("Ok") - result2 = to_vectors("!") - result3 = to_vectors("This") - result4 = to_vectors("is") - result5 = to_vectors("my") - result6 = to_vectors("home") - result7 = to_vectors("none") - res = [[0.418, 0.24968, -0.41242, 0.1217, 0.34527, -0.04445718411], - [0.013441, 0.23682, -0.16899, 0.40951, 0.63812, 0.47709], - [0.15164, 0.30177, -0.16763, 0.17684, 0.31719, 0.33973], - [0.70853, 0.57088, -0.4716, 0.18048, 0.54449, 0.72603], - [-1, -1, -1, -1, -1, -1], - [-1, -1, -1, -1, -1, -1], - [-1, -1, -1, -1, -1, -1]] - res_array = np.array(res, dtype=np.float32) - - assert np.array_equal(result1, res_array[0]) - assert np.array_equal(result2, res_array[1]) - assert np.array_equal(result3, res_array[2]) - assert np.array_equal(result4, res_array[3]) - assert np.array_equal(result5, res_array[4]) - assert np.array_equal(result6, res_array[5]) - assert np.array_equal(result7, res_array[6]) - - -def test_vectors_from_file(): - """ - Feature: Vectors - Description: Test with only default parameter - Expectation: Output is equal to the expected value - """ - vectors = text.Vectors.from_file(DATASET_ROOT_PATH + "vectors.txt") - to_vectors = text.ToVectors(vectors) - data = ds.TextFileDataset(DATASET_ROOT_PATH + "words.txt", shuffle=False) - data = data.map(operations=to_vectors, input_columns=["text"]) - ind = 0 - res = [[0.418, 0.24968, -0.41242, 0.1217, 0.34527, -0.04445718411], - [0, 0, 0, 0, 0, 0], - [0.15164, 0.30177, -0.16763, 0.17684, 0.31719, 0.33973], - [0.70853, 0.57088, -0.4716, 0.18048, 0.54449, 0.72603], - [0.68047, -0.039263, 0.30186, -0.17792, 0.42962, 0.032246], - [0.26818, 0.14346, -0.27877, 0.016257, 0.11384, 0.69923], - [0, 0, 0, 0, 0, 0]] - for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): - res_array = np.array(res[ind], dtype=np.float32) - assert np.array_equal(res_array, d["text"]), ind - ind += 1 - - -def test_vectors_from_file_all_buildfromfile_params(): - """ - Feature: Vectors - Description: Test with all parameters which include `path` and `max_vector` in function BuildFromFile - Expectation: Output is equal to the expected value - """ - vectors = text.Vectors.from_file(DATASET_ROOT_PATH + "vectors.txt", max_vectors=100) - to_vectors = text.ToVectors(vectors) - data = ds.TextFileDataset(DATASET_ROOT_PATH + "words.txt", shuffle=False) - data = data.map(operations=to_vectors, input_columns=["text"]) - ind = 0 - res = [[0.418, 0.24968, -0.41242, 0.1217, 0.34527, -0.04445718411], - [0, 0, 0, 0, 0, 0], - [0.15164, 0.30177, -0.16763, 0.17684, 0.31719, 0.33973], - [0.70853, 0.57088, -0.4716, 0.18048, 0.54449, 0.72603], - [0.68047, -0.039263, 0.30186, -0.17792, 0.42962, 0.032246], - [0.26818, 0.14346, -0.27877, 0.016257, 0.11384, 0.69923], - [0, 0, 0, 0, 0, 0]] - for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): - res_array = np.array(res[ind], dtype=np.float32) - assert np.array_equal(res_array, d["text"]), ind - ind += 1 - - -def test_vectors_from_file_all_buildfromfile_params_eager(): - """ - Feature: Vectors - Description: Test with all parameters which include `path` and `max_vector` in function BuildFromFile in eager mode - Expectation: Output is equal to the expected value - """ - vectors = text.Vectors.from_file(DATASET_ROOT_PATH + "vectors.txt", max_vectors=4) - to_vectors = T.ToVectors(vectors) - result1 = to_vectors("ok") - result2 = to_vectors("!") - result3 = to_vectors("this") - result4 = to_vectors("is") - result5 = to_vectors("my") - result6 = to_vectors("home") - result7 = to_vectors("none") - res = [[0.418, 0.24968, -0.41242, 0.1217, 0.34527, -0.04445718411], - [0.013441, 0.23682, -0.16899, 0.40951, 0.63812, 0.47709], - [0.15164, 0.30177, -0.16763, 0.17684, 0.31719, 0.33973], - [0.70853, 0.57088, -0.4716, 0.18048, 0.54449, 0.72603], - [0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0]] - res_array = np.array(res, dtype=np.float32) - - assert np.array_equal(result1, res_array[0]) - assert np.array_equal(result2, res_array[1]) - assert np.array_equal(result3, res_array[2]) - assert np.array_equal(result4, res_array[3]) - assert np.array_equal(result5, res_array[4]) - assert np.array_equal(result6, res_array[5]) - assert np.array_equal(result7, res_array[6]) - - -def test_vectors_from_file_eager(): - """ - Feature: Vectors - Description: Test with only default parameter in eager mode - Expectation: Output is equal to the expected value - """ - vectors = text.Vectors.from_file(DATASET_ROOT_PATH + "vectors.txt") - to_vectors = T.ToVectors(vectors) - result1 = to_vectors("ok") - result2 = to_vectors("!") - result3 = to_vectors("this") - result4 = to_vectors("is") - result5 = to_vectors("my") - result6 = to_vectors("home") - result7 = to_vectors("none") - res = [[0.418, 0.24968, -0.41242, 0.1217, 0.34527, -0.04445718411], - [0.013441, 0.23682, -0.16899, 0.40951, 0.63812, 0.47709], - [0.15164, 0.30177, -0.16763, 0.17684, 0.31719, 0.33973], - [0.70853, 0.57088, -0.4716, 0.18048, 0.54449, 0.72603], - [0.68047, -0.039263, 0.30186, -0.17792, 0.42962, 0.032246], - [0.26818, 0.14346, -0.27877, 0.016257, 0.11384, 0.69923], - [0, 0, 0, 0, 0, 0]] - res_array = np.array(res, dtype=np.float32) - - assert np.array_equal(result1, res_array[0]) - assert np.array_equal(result2, res_array[1]) - assert np.array_equal(result3, res_array[2]) - assert np.array_equal(result4, res_array[3]) - assert np.array_equal(result5, res_array[4]) - assert np.array_equal(result6, res_array[5]) - assert np.array_equal(result7, res_array[6]) - - -def test_vectors_invalid_input(): - """ - Feature: Vectors - Description: Test the validate function with invalid parameters - Expectation: Correct error is raised as expected - """ - def test_invalid_input(test_name, file_path, error, error_msg, max_vectors=None, - unk_init=None, lower_case_backup=False, token="ok"): - log.info("Test Vectors with wrong input: {0}".format(test_name)) - with pytest.raises(error) as error_info: - vectors = text.Vectors.from_file(file_path, max_vectors=max_vectors) - to_vectors = T.ToVectors(vectors, unk_init=unk_init, lower_case_backup=lower_case_backup) - to_vectors(token) - assert error_msg in str(error_info.value) - - test_invalid_input("Not all vectors have the same number of dimensions", - DATASET_ROOT_PATH + "vectors_dim_different.txt", error=RuntimeError, - error_msg="all vectors must have the same number of dimensions, but got dim 5 while expecting 6") - test_invalid_input("the file is empty.", DATASET_ROOT_PATH + "vectors_empty.txt", - error=RuntimeError, error_msg="invalid file, file is empty.") - test_invalid_input("the count of `unknown_init`'s element is different with word vector.", - DATASET_ROOT_PATH + "vectors.txt", - error=RuntimeError, error_msg="ToVectors: " + - "unk_init must be the same length as vectors, but got unk_init: 2 and vectors: 6", - unk_init=[-1, -1]) - test_invalid_input("The file not exist", DATASET_ROOT_PATH + "not_exist.txt", error=RuntimeError, - error_msg="get real path failed") - test_invalid_input("The token is 1-dimensional", - DATASET_ROOT_PATH + "vectors_with_wrong_info.txt", error=RuntimeError, - error_msg="token with 1-dimensional vector.") - test_invalid_input("max_vectors parameter must be greater than 0", - DATASET_ROOT_PATH + "vectors.txt", error=ValueError, - error_msg="Input max_vectors is not within the required interval", max_vectors=-1) - test_invalid_input("invalid max_vectors parameter type as a float", - DATASET_ROOT_PATH + "vectors.txt", error=TypeError, - error_msg="Argument max_vectors with value 1.0 is not of type []," - " but got .", max_vectors=1.0) - test_invalid_input("invalid max_vectors parameter type as a string", - DATASET_ROOT_PATH + "vectors.txt", error=TypeError, - error_msg="Argument max_vectors with value 1 is not of type []," - " but got .", max_vectors="1") - test_invalid_input("invalid token parameter type as a float", DATASET_ROOT_PATH + "vectors.txt", error=RuntimeError, - error_msg="input tensor type should be string.", token=1.0) - test_invalid_input("invalid lower_case_backup parameter type as a string", DATASET_ROOT_PATH + "vectors.txt", - error=TypeError, error_msg="Argument lower_case_backup with " + - "value True is not of type []," - " but got .", lower_case_backup="True") - test_invalid_input("invalid lower_case_backup parameter type as a string", DATASET_ROOT_PATH + "vectors.txt", - error=TypeError, error_msg="Argument lower_case_backup with " + - "value True is not of type []," - " but got .", lower_case_backup="True") - - -if __name__ == '__main__': - test_vectors_all_tovectors_params_eager() - test_vectors_from_file() - test_vectors_from_file_all_buildfromfile_params() - test_vectors_from_file_all_buildfromfile_params_eager() - test_vectors_from_file_eager() - test_vectors_invalid_input() +# Copyright 2021-2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import numpy as np +import pytest + +from mindspore import log +import mindspore.dataset as ds +import mindspore.dataset.text as text +import mindspore.dataset.text.transforms as T + +DATASET_ROOT_PATH = "../data/dataset/testVectors/" + + +def test_vectors_all_tovectors_params_eager(): + """ + Feature: Vectors + Description: Test with all parameters which include `unk_init` + and `lower_case_backup` in function ToVectors in eager mode + Expectation: Output is equal to the expected value + """ + vectors = text.Vectors.from_file(DATASET_ROOT_PATH + "vectors.txt", max_vectors=4) + myUnk = [-1, -1, -1, -1, -1, -1] + to_vectors = T.ToVectors(vectors, unk_init=myUnk, lower_case_backup=True) + result1 = to_vectors("Ok") + result2 = to_vectors("!") + result3 = to_vectors("This") + result4 = to_vectors("is") + result5 = to_vectors("my") + result6 = to_vectors("home") + result7 = to_vectors("none") + res = [[0.418, 0.24968, -0.41242, 0.1217, 0.34527, -0.04445718411], + [0.013441, 0.23682, -0.16899, 0.40951, 0.63812, 0.47709], + [0.15164, 0.30177, -0.16763, 0.17684, 0.31719, 0.33973], + [0.70853, 0.57088, -0.4716, 0.18048, 0.54449, 0.72603], + [-1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1]] + res_array = np.array(res, dtype=np.float32) + + assert np.array_equal(result1, res_array[0]) + assert np.array_equal(result2, res_array[1]) + assert np.array_equal(result3, res_array[2]) + assert np.array_equal(result4, res_array[3]) + assert np.array_equal(result5, res_array[4]) + assert np.array_equal(result6, res_array[5]) + assert np.array_equal(result7, res_array[6]) + + +def test_vectors_from_file(): + """ + Feature: Vectors + Description: Test with only default parameter + Expectation: Output is equal to the expected value + """ + vectors = text.Vectors.from_file(DATASET_ROOT_PATH + "vectors.txt") + to_vectors = text.ToVectors(vectors) + data = ds.TextFileDataset(DATASET_ROOT_PATH + "words.txt", shuffle=False) + data = data.map(operations=to_vectors, input_columns=["text"]) + ind = 0 + res = [[0.418, 0.24968, -0.41242, 0.1217, 0.34527, -0.04445718411], + [0, 0, 0, 0, 0, 0], + [0.15164, 0.30177, -0.16763, 0.17684, 0.31719, 0.33973], + [0.70853, 0.57088, -0.4716, 0.18048, 0.54449, 0.72603], + [0.68047, -0.039263, 0.30186, -0.17792, 0.42962, 0.032246], + [0.26818, 0.14346, -0.27877, 0.016257, 0.11384, 0.69923], + [0, 0, 0, 0, 0, 0]] + for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): + res_array = np.array(res[ind], dtype=np.float32) + assert np.array_equal(res_array, d["text"]), ind + ind += 1 + + +def test_vectors_from_file_all_buildfromfile_params(): + """ + Feature: Vectors + Description: Test with all parameters which include `path` and `max_vector` in function BuildFromFile + Expectation: Output is equal to the expected value + """ + vectors = text.Vectors.from_file(DATASET_ROOT_PATH + "vectors.txt", max_vectors=100) + to_vectors = text.ToVectors(vectors) + data = ds.TextFileDataset(DATASET_ROOT_PATH + "words.txt", shuffle=False) + data = data.map(operations=to_vectors, input_columns=["text"]) + ind = 0 + res = [[0.418, 0.24968, -0.41242, 0.1217, 0.34527, -0.04445718411], + [0, 0, 0, 0, 0, 0], + [0.15164, 0.30177, -0.16763, 0.17684, 0.31719, 0.33973], + [0.70853, 0.57088, -0.4716, 0.18048, 0.54449, 0.72603], + [0.68047, -0.039263, 0.30186, -0.17792, 0.42962, 0.032246], + [0.26818, 0.14346, -0.27877, 0.016257, 0.11384, 0.69923], + [0, 0, 0, 0, 0, 0]] + for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): + res_array = np.array(res[ind], dtype=np.float32) + assert np.array_equal(res_array, d["text"]), ind + ind += 1 + + +def test_vectors_from_file_all_buildfromfile_params_eager(): + """ + Feature: Vectors + Description: Test with all parameters which include `path` and `max_vector` in function BuildFromFile in eager mode + Expectation: Output is equal to the expected value + """ + vectors = text.Vectors.from_file(DATASET_ROOT_PATH + "vectors.txt", max_vectors=4) + to_vectors = T.ToVectors(vectors) + result1 = to_vectors("ok") + result2 = to_vectors("!") + result3 = to_vectors("this") + result4 = to_vectors("is") + result5 = to_vectors("my") + result6 = to_vectors("home") + result7 = to_vectors("none") + res = [[0.418, 0.24968, -0.41242, 0.1217, 0.34527, -0.04445718411], + [0.013441, 0.23682, -0.16899, 0.40951, 0.63812, 0.47709], + [0.15164, 0.30177, -0.16763, 0.17684, 0.31719, 0.33973], + [0.70853, 0.57088, -0.4716, 0.18048, 0.54449, 0.72603], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]] + res_array = np.array(res, dtype=np.float32) + + assert np.array_equal(result1, res_array[0]) + assert np.array_equal(result2, res_array[1]) + assert np.array_equal(result3, res_array[2]) + assert np.array_equal(result4, res_array[3]) + assert np.array_equal(result5, res_array[4]) + assert np.array_equal(result6, res_array[5]) + assert np.array_equal(result7, res_array[6]) + + +def test_vectors_from_file_eager(): + """ + Feature: Vectors + Description: Test with only default parameter in eager mode + Expectation: Output is equal to the expected value + """ + vectors = text.Vectors.from_file(DATASET_ROOT_PATH + "vectors.txt") + to_vectors = T.ToVectors(vectors) + result1 = to_vectors("ok") + result2 = to_vectors("!") + result3 = to_vectors("this") + result4 = to_vectors("is") + result5 = to_vectors("my") + result6 = to_vectors("home") + result7 = to_vectors("none") + res = [[0.418, 0.24968, -0.41242, 0.1217, 0.34527, -0.04445718411], + [0.013441, 0.23682, -0.16899, 0.40951, 0.63812, 0.47709], + [0.15164, 0.30177, -0.16763, 0.17684, 0.31719, 0.33973], + [0.70853, 0.57088, -0.4716, 0.18048, 0.54449, 0.72603], + [0.68047, -0.039263, 0.30186, -0.17792, 0.42962, 0.032246], + [0.26818, 0.14346, -0.27877, 0.016257, 0.11384, 0.69923], + [0, 0, 0, 0, 0, 0]] + res_array = np.array(res, dtype=np.float32) + + assert np.array_equal(result1, res_array[0]) + assert np.array_equal(result2, res_array[1]) + assert np.array_equal(result3, res_array[2]) + assert np.array_equal(result4, res_array[3]) + assert np.array_equal(result5, res_array[4]) + assert np.array_equal(result6, res_array[5]) + assert np.array_equal(result7, res_array[6]) + + +def test_vectors_invalid_input(): + """ + Feature: Vectors + Description: Test the validate function with invalid parameters + Expectation: Correct error is raised as expected + """ + def test_invalid_input(test_name, file_path, error, error_msg, max_vectors=None, + unk_init=None, lower_case_backup=False, token="ok"): + log.info("Test Vectors with wrong input: {0}".format(test_name)) + with pytest.raises(error) as error_info: + vectors = text.Vectors.from_file(file_path, max_vectors=max_vectors) + to_vectors = T.ToVectors(vectors, unk_init=unk_init, lower_case_backup=lower_case_backup) + to_vectors(token) + assert error_msg in str(error_info.value) + + test_invalid_input("Not all vectors have the same number of dimensions", + DATASET_ROOT_PATH + "vectors_dim_different.txt", error=RuntimeError, + error_msg="all vectors must have the same number of dimensions, but got dim 5 while expecting 6") + test_invalid_input("the file is empty.", DATASET_ROOT_PATH + "vectors_empty.txt", + error=RuntimeError, error_msg="invalid file, file is empty.") + test_invalid_input("the count of `unknown_init`'s element is different with word vector.", + DATASET_ROOT_PATH + "vectors.txt", + error=RuntimeError, error_msg="ToVectors: " + + "unk_init must be the same length as vectors, but got unk_init: 2 and vectors: 6", + unk_init=[-1, -1]) + test_invalid_input("The file not exist", DATASET_ROOT_PATH + "not_exist.txt", error=RuntimeError, + error_msg="get real path failed") + test_invalid_input("The token is 1-dimensional", + DATASET_ROOT_PATH + "vectors_with_wrong_info.txt", error=RuntimeError, + error_msg="token with 1-dimensional vector.") + test_invalid_input("max_vectors parameter must be greater than 0", + DATASET_ROOT_PATH + "vectors.txt", error=ValueError, + error_msg="Input max_vectors is not within the required interval", max_vectors=-1) + test_invalid_input("invalid max_vectors parameter type as a float", + DATASET_ROOT_PATH + "vectors.txt", error=TypeError, + error_msg="Argument max_vectors with value 1.0 is not of type []," + " but got .", max_vectors=1.0) + test_invalid_input("invalid max_vectors parameter type as a string", + DATASET_ROOT_PATH + "vectors.txt", error=TypeError, + error_msg="Argument max_vectors with value 1 is not of type []," + " but got .", max_vectors="1") + test_invalid_input("invalid token parameter type as a float", DATASET_ROOT_PATH + "vectors.txt", error=RuntimeError, + error_msg="input tensor type should be string.", token=1.0) + test_invalid_input("invalid lower_case_backup parameter type as a string", DATASET_ROOT_PATH + "vectors.txt", + error=TypeError, error_msg="Argument lower_case_backup with " + + "value True is not of type []," + " but got .", lower_case_backup="True") + test_invalid_input("invalid lower_case_backup parameter type as a string", DATASET_ROOT_PATH + "vectors.txt", + error=TypeError, error_msg="Argument lower_case_backup with " + + "value True is not of type []," + " but got .", lower_case_backup="True") + + +if __name__ == '__main__': + test_vectors_all_tovectors_params_eager() + test_vectors_from_file() + test_vectors_from_file_all_buildfromfile_params() + test_vectors_from_file_all_buildfromfile_params_eager() + test_vectors_from_file_eager() + test_vectors_invalid_input() diff --git a/tests/ut/python/dataset/test_write_file.py b/tests/ut/python/dataset/test_write_file.py old mode 100755 new mode 100644 diff --git a/tests/ut/python/dataset/test_write_jpeg.py b/tests/ut/python/dataset/test_write_jpeg.py old mode 100755 new mode 100644 diff --git a/tests/ut/python/dataset/test_write_png.py b/tests/ut/python/dataset/test_write_png.py old mode 100755 new mode 100644 diff --git a/tests/ut/python/metrics/test_fbeta.py b/tests/ut/python/metrics/test_fbeta.py old mode 100755 new mode 100644 diff --git a/tests/ut/python/nn/layer/test_pool.py b/tests/ut/python/nn/layer/test_pool.py index 16dd18c52ea..3038c2c16df 100644 --- a/tests/ut/python/nn/layer/test_pool.py +++ b/tests/ut/python/nn/layer/test_pool.py @@ -1,115 +1,115 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -test pooling api -""" -import numpy as np - -import mindspore as ms -import mindspore.nn as nn -from mindspore.common.api import _cell_graph_executor - - -class MaxPoolNet(nn.Cell): - """MaxPool3d""" - - def __init__(self): - super(MaxPoolNet, self).__init__() - self.pool1 = nn.MaxPool3d(kernel_size=3, stride=1, pad_mode='pad', padding=1) - self.pool2 = nn.MaxPool3d(kernel_size=3, stride=1, pad_mode='pad', padding=1, return_indices=True) - - def construct(self, x): - output1 = self.pool1(x) - output2 = self.pool2(x) - return output1, output2 - - -def test_compile_max(): - """ - Feature: Test MaxPool3d - Description: Test the functionality of MaxPool3d - Expectation: Success - """ - net = MaxPoolNet() - x = ms.Tensor(np.random.randint(0, 10, [1, 2, 4, 4, 5]), ms.float32) - _cell_graph_executor.compile(net, x) - - -class AvgPoolNet(nn.Cell): - """AvgPool3d""" - - def __init__(self): - super(AvgPoolNet, self).__init__() - self.pool = nn.AvgPool3d(kernel_size=3, stride=1) - - def construct(self, x): - return self.pool(x) - - -def test_compile_avg(): - """ - Feature: Test AvgPool3d - Description: Test the functionality of AvgPool3d - Expectation: Success - """ - net = AvgPoolNet() - x = ms.Tensor(np.random.randint(0, 10, [1, 2, 4, 4, 5]), ms.float32) - _cell_graph_executor.compile(net, x) - - -class LPPool1d(nn.Cell): - """LPPool1d""" - - def __init__(self): - super(LPPool1d, self).__init__() - self.pool = nn.LPPool1d(norm_type=1, kernel_size=3, stride=1) - - def construct(self, x): - output1 = self.pool(x) - return output1 - - -def test_compile_lpool1d(): - """ - Feature: Test LPPool1d - Description: Test the functionality of LPPool1d - Expectation: Success - """ - net = LPPool1d() - x = ms.Tensor(np.arange(2 * 3 * 4).reshape((2, 3, 4)), dtype=ms.float32) - y = ms.Tensor(np.arange(3 * 4).reshape((3, 4)), dtype=ms.float32) - _cell_graph_executor.compile(net, x) - _cell_graph_executor.compile(net, y) - - -class LPPool2d(nn.Cell): - def __init__(self): - super(LPPool2d, self).__init__() - self.pool = nn.LPPool2d(norm_type=1, kernel_size=3, stride=1) - - def construct(self, x): - out = self.pool(x) - return out - - -def test_compile_lppool2d(): - """ - Feature: Test LPPool2d - Description: Test the functionality of LPPool2d - Expectation: Success - """ - net = LPPool2d() - x = ms.Tensor(np.arange(2 * 3 * 4 * 5).reshape((2, 3, 4, 5)), dtype=ms.float32) - _cell_graph_executor.compile(net, x) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +test pooling api +""" +import numpy as np + +import mindspore as ms +import mindspore.nn as nn +from mindspore.common.api import _cell_graph_executor + + +class MaxPoolNet(nn.Cell): + """MaxPool3d""" + + def __init__(self): + super(MaxPoolNet, self).__init__() + self.pool1 = nn.MaxPool3d(kernel_size=3, stride=1, pad_mode='pad', padding=1) + self.pool2 = nn.MaxPool3d(kernel_size=3, stride=1, pad_mode='pad', padding=1, return_indices=True) + + def construct(self, x): + output1 = self.pool1(x) + output2 = self.pool2(x) + return output1, output2 + + +def test_compile_max(): + """ + Feature: Test MaxPool3d + Description: Test the functionality of MaxPool3d + Expectation: Success + """ + net = MaxPoolNet() + x = ms.Tensor(np.random.randint(0, 10, [1, 2, 4, 4, 5]), ms.float32) + _cell_graph_executor.compile(net, x) + + +class AvgPoolNet(nn.Cell): + """AvgPool3d""" + + def __init__(self): + super(AvgPoolNet, self).__init__() + self.pool = nn.AvgPool3d(kernel_size=3, stride=1) + + def construct(self, x): + return self.pool(x) + + +def test_compile_avg(): + """ + Feature: Test AvgPool3d + Description: Test the functionality of AvgPool3d + Expectation: Success + """ + net = AvgPoolNet() + x = ms.Tensor(np.random.randint(0, 10, [1, 2, 4, 4, 5]), ms.float32) + _cell_graph_executor.compile(net, x) + + +class LPPool1d(nn.Cell): + """LPPool1d""" + + def __init__(self): + super(LPPool1d, self).__init__() + self.pool = nn.LPPool1d(norm_type=1, kernel_size=3, stride=1) + + def construct(self, x): + output1 = self.pool(x) + return output1 + + +def test_compile_lpool1d(): + """ + Feature: Test LPPool1d + Description: Test the functionality of LPPool1d + Expectation: Success + """ + net = LPPool1d() + x = ms.Tensor(np.arange(2 * 3 * 4).reshape((2, 3, 4)), dtype=ms.float32) + y = ms.Tensor(np.arange(3 * 4).reshape((3, 4)), dtype=ms.float32) + _cell_graph_executor.compile(net, x) + _cell_graph_executor.compile(net, y) + + +class LPPool2d(nn.Cell): + def __init__(self): + super(LPPool2d, self).__init__() + self.pool = nn.LPPool2d(norm_type=1, kernel_size=3, stride=1) + + def construct(self, x): + out = self.pool(x) + return out + + +def test_compile_lppool2d(): + """ + Feature: Test LPPool2d + Description: Test the functionality of LPPool2d + Expectation: Success + """ + net = LPPool2d() + x = ms.Tensor(np.arange(2 * 3 * 4 * 5).reshape((2, 3, 4, 5)), dtype=ms.float32) + _cell_graph_executor.compile(net, x) diff --git a/tests/ut/python/nn/test_activation.py b/tests/ut/python/nn/test_activation.py old mode 100755 new mode 100644 diff --git a/tests/ut/python/nn/test_cell_wrapper.py b/tests/ut/python/nn/test_cell_wrapper.py old mode 100755 new mode 100644 diff --git a/tests/ut/python/nn/test_nn_embedding.py b/tests/ut/python/nn/test_nn_embedding.py old mode 100755 new mode 100644 diff --git a/tests/ut/python/nn/test_nn_padding.py b/tests/ut/python/nn/test_nn_padding.py index 9d0dbd36a44..efe48cc3223 100644 --- a/tests/ut/python/nn/test_nn_padding.py +++ b/tests/ut/python/nn/test_nn_padding.py @@ -1,377 +1,377 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" test nn pad """ -import numpy as np -import pytest - -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.nn import ConstantPad1d, ConstantPad2d, ConstantPad3d, ZeroPad2d -from mindspore.ops.composite import GradOperation - - -class ConstantPad1dNet(nn.Cell): - def __init__(self, padding, value): - super(ConstantPad1dNet, self).__init__() - self.pad = ConstantPad1d(padding, value) - self.value = value - - def construct(self, x): - return self.pad(x) - - -class ConstantPad2dNet(nn.Cell): - def __init__(self, padding, value): - super(ConstantPad2dNet, self).__init__() - self.pad = ConstantPad2d(padding, value) - self.value = value - - def construct(self, x): - return self.pad(x) - - -class ConstantPad3dNet(nn.Cell): - def __init__(self, padding, value): - super(ConstantPad3dNet, self).__init__() - self.pad = ConstantPad3d(padding, value) - self.value = value - - def construct(self, x): - return self.pad(x) - - -class ZeroPad2dNet(nn.Cell): - def __init__(self, padding): - super(ZeroPad2dNet, self).__init__() - self.pad = ZeroPad2d(padding) - - def construct(self, x): - return self.pad(x) - - -class Grad(nn.Cell): - def __init__(self, network): - super(Grad, self).__init__() - self.grad = GradOperation(get_all=True, sens_param=False) - self.network = network - - def construct(self, x): - return self.grad(self.network)(x) - - -def test_constant_pad_1d_infer(): - """ - Feature: ConstantPad1d - Description: Infer process of ConstantPad1d with three type parameters. - Expectation: success - """ - x = np.ones(shape=(1, 2, 3, 4)).astype(np.float32) - print("=================case 1====================") - padding = (0, 1) - value = 0.5 - net = ConstantPad1dNet(padding, value) - output = net(Tensor(x)) - print(output) - print(output.shape) - - print("=================case 2====================") - padding = 1 - value = 0.5 - net = ConstantPad1dNet(padding, value) - output = net(Tensor(x)) - print(output) - print(output.shape) - - print("=================case 3====================") - padding = (1, 0) - value = 0.5 - net = ConstantPad1dNet(padding, value) - output = net(Tensor(x)) - print(output) - print(output.shape) - - -def test_constant_pad_1d_train(): - """ - Feature: ConstantPad1d - Description: Train process of ConstantPad1d with three type parameters. - Expectation: success - """ - x = np.ones(shape=(1, 2, 3, 4)).astype(np.float32) - print("=================case 1====================") - padding = (0, 1) - value = 0.5 - grad = Grad(ConstantPad1dNet(padding, value)) - output = grad(Tensor(x)) - print(output) - - print("=================case 2====================") - padding = 1 - value = 0.5 - grad = Grad(ConstantPad1dNet(padding, value)) - output = grad(Tensor(x)) - print(output) - - print("=================case 3====================") - padding = (1, 0) - value = 0.5 - grad = Grad(ConstantPad1dNet(padding, value)) - output = grad(Tensor(x)) - print(output) - - -def test_constant_pad_2d_infer(): - """ - Feature: ConstantPad2d - Description: Infer process of ConstantPad2d with three type parameters. - Expectation: success - """ - x = np.ones(shape=(1, 2, 3, 4)).astype(np.float32) - print("=================case 1====================") - padding = (0, 1) - value = 0.5 - net = ConstantPad2dNet(padding, value) - output = net(Tensor(x)) - print(output) - print(output.shape) - - print("=================case 2====================") - padding = 1 - value = 0.5 - net = ConstantPad2dNet(padding, value) - output = net(Tensor(x)) - print(output) - print(output.shape) - - print("=================case 3====================") - padding = (1, 1, 0, 1) - value = 0.5 - net = ConstantPad2dNet(padding, value) - output = net(Tensor(x)) - print(output) - print(output.shape) - - -def test_constant_pad_2d_train(): - """ - Feature: ConstantPad3d - Description: Train process of ConstantPad2d with three type parameters. - Expectation: success - """ - x = np.ones(shape=(1, 2, 3, 4)).astype(np.float32) - print("=================case 1====================") - padding = (0, 1) - value = 0.5 - grad = Grad(ConstantPad2dNet(padding, value)) - output = grad(Tensor(x)) - print(output) - - print("=================case 2====================") - padding = 1 - value = 0.5 - grad = Grad(ConstantPad2dNet(padding, value)) - output = grad(Tensor(x)) - print(output) - - print("=================case 3====================") - padding = (1, 1, 0, 1) - value = 0.5 - grad = Grad(ConstantPad2dNet(padding, value)) - output = grad(Tensor(x)) - print(output) - - -def test_constant_pad_3d_infer(): - """ - Feature: ConstantPad3d - Description: Infer process of ConstantPad3d with three type parameters. - Expectation: success - """ - x = np.ones(shape=(1, 2, 3, 4)).astype(np.float32) - print("=================case 1====================") - padding = (0, 1) - value = 0.5 - net = ConstantPad3dNet(padding, value) - output = net(Tensor(x)) - print(output) - print(output.shape) - - print("=================case 2====================") - padding = 1 - value = 0.5 - net = ConstantPad3dNet(padding, value) - output = net(Tensor(x)) - print(output) - print(output.shape) - - print("=================case 3====================") - padding = (1, 1, 0, 1, 1, 0) - value = 0.5 - net = ConstantPad3dNet(padding, value) - output = net(Tensor(x)) - print(output) - print(output.shape) - - -def test_constant_pad_3d_train(): - """ - Feature: ConstantPad3d - Description: Train process of ConstantPad3d with three type parameters. - Expectation: success - """ - x = np.ones(shape=(1, 2, 3, 4)).astype(np.float32) - print("=================case 1====================") - padding = (0, 1) - value = 0.5 - grad = Grad(ConstantPad3dNet(padding, value)) - output = grad(Tensor(x)) - print(output) - - print("=================case 2====================") - padding = 1 - value = 0.5 - grad = Grad(ConstantPad3dNet(padding, value)) - output = grad(Tensor(x)) - print(output) - - print("=================case 3====================") - padding = (1, 1, 0, 1, 1, 0) - value = 0.5 - grad = Grad(ConstantPad3dNet(padding, value)) - output = grad(Tensor(x)) - print(output) - - -def test_zero_pad_2d_infer(): - """ - Feature: ZeroPad2d - Description: Infer process of ZeroPad2d with three type parameters. - Expectation: success - """ - x = np.ones(shape=(1, 2, 3, 4)).astype(np.float32) - print("=================case 1====================") - padding = (0, 1) - net = ZeroPad2dNet(padding) - output = net(Tensor(x)) - print(output) - print(output.shape) - - print("=================case 2====================") - padding = 1 - net = ZeroPad2dNet(padding) - output = net(Tensor(x)) - print(output) - print(output.shape) - - print("=================case 3====================") - padding = (1, 1, 0, 1) - net = ZeroPad2dNet(padding) - output = net(Tensor(x)) - print(output) - print(output.shape) - - -def test_zero_pad_2d_train(): - """ - Feature: ZeroPad2d - Description: Train process of ZeroPad2d with three type parameters. - Expectation: success - """ - - x = np.ones(shape=(1, 2, 3, 4)).astype(np.float32) - print("=================case 1====================") - padding = (0, 1) - grad = Grad(ZeroPad2dNet(padding)) - output = grad(Tensor(x)) - print(output) - - print("=================case 2====================") - padding = 1 - grad = Grad(ZeroPad2dNet(padding)) - output = grad(Tensor(x)) - print(output) - - print("=================case 3====================") - padding = (1, 1, 0, 1) - grad = Grad(ZeroPad2dNet(padding)) - output = grad(Tensor(x)) - print(output) - - -def test_invalid_padding_reflection_pad_1d(): - """ - Feature: ReflectionPad1d - Description: test 5 cases of invalid input. - Expectation: success - """ - # case 1: padding is not int or tuple - padding = '-1' - with pytest.raises(TypeError): - nn.ReflectionPad1d(padding) - - # case 2: padding length is not divisible by 2 - padding = (1, 2, 2) - with pytest.raises(ValueError): - nn.ReflectionPad1d(padding) - - # case 3: padding element is not int - padding = ('2', 2) - with pytest.raises(TypeError): - nn.ReflectionPad1d(padding) - - # case 4: negative padding - padding = (-1, 2) - with pytest.raises(ValueError): - nn.ReflectionPad1d(padding) - - # case 5: padding dimension does not match tensor dimension - padding = (1, 1, 1, 1, 1, 1, 1, 1) - x = Tensor([[1, 2, 3], [1, 2, 3]]) - with pytest.raises(ValueError): - nn.ReflectionPad1d(padding)(x) - - - -def test_invalid_padding_reflection_pad_2d(): - """ - Feature: ReflectionPad2d - Description: test 5 cases of invalid input. - Expectation: success - """ - # case 1: padding is not int or tuple - padding = '-1' - with pytest.raises(TypeError): - nn.ReflectionPad2d(padding) - - # case 2: padding length is not divisible by 2 - padding = (1, 2, 2) - with pytest.raises(ValueError): - nn.ReflectionPad2d(padding) - - # case 3: padding element is not int - padding = ('2', 2) - with pytest.raises(TypeError): - nn.ReflectionPad2d(padding) - - # case 4: negative padding - padding = (-1, 2) - with pytest.raises(ValueError): - nn.ReflectionPad2d(padding) - - # case 5: padding dimension does not match tensor dimension - padding = (1, 1, 1, 1, 1, 1, 1, 1) - x = Tensor([[1, 2, 3], [1, 2, 3]]) - with pytest.raises(ValueError): - nn.ReflectionPad2d(padding)(x) +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test nn pad """ +import numpy as np +import pytest + +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.nn import ConstantPad1d, ConstantPad2d, ConstantPad3d, ZeroPad2d +from mindspore.ops.composite import GradOperation + + +class ConstantPad1dNet(nn.Cell): + def __init__(self, padding, value): + super(ConstantPad1dNet, self).__init__() + self.pad = ConstantPad1d(padding, value) + self.value = value + + def construct(self, x): + return self.pad(x) + + +class ConstantPad2dNet(nn.Cell): + def __init__(self, padding, value): + super(ConstantPad2dNet, self).__init__() + self.pad = ConstantPad2d(padding, value) + self.value = value + + def construct(self, x): + return self.pad(x) + + +class ConstantPad3dNet(nn.Cell): + def __init__(self, padding, value): + super(ConstantPad3dNet, self).__init__() + self.pad = ConstantPad3d(padding, value) + self.value = value + + def construct(self, x): + return self.pad(x) + + +class ZeroPad2dNet(nn.Cell): + def __init__(self, padding): + super(ZeroPad2dNet, self).__init__() + self.pad = ZeroPad2d(padding) + + def construct(self, x): + return self.pad(x) + + +class Grad(nn.Cell): + def __init__(self, network): + super(Grad, self).__init__() + self.grad = GradOperation(get_all=True, sens_param=False) + self.network = network + + def construct(self, x): + return self.grad(self.network)(x) + + +def test_constant_pad_1d_infer(): + """ + Feature: ConstantPad1d + Description: Infer process of ConstantPad1d with three type parameters. + Expectation: success + """ + x = np.ones(shape=(1, 2, 3, 4)).astype(np.float32) + print("=================case 1====================") + padding = (0, 1) + value = 0.5 + net = ConstantPad1dNet(padding, value) + output = net(Tensor(x)) + print(output) + print(output.shape) + + print("=================case 2====================") + padding = 1 + value = 0.5 + net = ConstantPad1dNet(padding, value) + output = net(Tensor(x)) + print(output) + print(output.shape) + + print("=================case 3====================") + padding = (1, 0) + value = 0.5 + net = ConstantPad1dNet(padding, value) + output = net(Tensor(x)) + print(output) + print(output.shape) + + +def test_constant_pad_1d_train(): + """ + Feature: ConstantPad1d + Description: Train process of ConstantPad1d with three type parameters. + Expectation: success + """ + x = np.ones(shape=(1, 2, 3, 4)).astype(np.float32) + print("=================case 1====================") + padding = (0, 1) + value = 0.5 + grad = Grad(ConstantPad1dNet(padding, value)) + output = grad(Tensor(x)) + print(output) + + print("=================case 2====================") + padding = 1 + value = 0.5 + grad = Grad(ConstantPad1dNet(padding, value)) + output = grad(Tensor(x)) + print(output) + + print("=================case 3====================") + padding = (1, 0) + value = 0.5 + grad = Grad(ConstantPad1dNet(padding, value)) + output = grad(Tensor(x)) + print(output) + + +def test_constant_pad_2d_infer(): + """ + Feature: ConstantPad2d + Description: Infer process of ConstantPad2d with three type parameters. + Expectation: success + """ + x = np.ones(shape=(1, 2, 3, 4)).astype(np.float32) + print("=================case 1====================") + padding = (0, 1) + value = 0.5 + net = ConstantPad2dNet(padding, value) + output = net(Tensor(x)) + print(output) + print(output.shape) + + print("=================case 2====================") + padding = 1 + value = 0.5 + net = ConstantPad2dNet(padding, value) + output = net(Tensor(x)) + print(output) + print(output.shape) + + print("=================case 3====================") + padding = (1, 1, 0, 1) + value = 0.5 + net = ConstantPad2dNet(padding, value) + output = net(Tensor(x)) + print(output) + print(output.shape) + + +def test_constant_pad_2d_train(): + """ + Feature: ConstantPad3d + Description: Train process of ConstantPad2d with three type parameters. + Expectation: success + """ + x = np.ones(shape=(1, 2, 3, 4)).astype(np.float32) + print("=================case 1====================") + padding = (0, 1) + value = 0.5 + grad = Grad(ConstantPad2dNet(padding, value)) + output = grad(Tensor(x)) + print(output) + + print("=================case 2====================") + padding = 1 + value = 0.5 + grad = Grad(ConstantPad2dNet(padding, value)) + output = grad(Tensor(x)) + print(output) + + print("=================case 3====================") + padding = (1, 1, 0, 1) + value = 0.5 + grad = Grad(ConstantPad2dNet(padding, value)) + output = grad(Tensor(x)) + print(output) + + +def test_constant_pad_3d_infer(): + """ + Feature: ConstantPad3d + Description: Infer process of ConstantPad3d with three type parameters. + Expectation: success + """ + x = np.ones(shape=(1, 2, 3, 4)).astype(np.float32) + print("=================case 1====================") + padding = (0, 1) + value = 0.5 + net = ConstantPad3dNet(padding, value) + output = net(Tensor(x)) + print(output) + print(output.shape) + + print("=================case 2====================") + padding = 1 + value = 0.5 + net = ConstantPad3dNet(padding, value) + output = net(Tensor(x)) + print(output) + print(output.shape) + + print("=================case 3====================") + padding = (1, 1, 0, 1, 1, 0) + value = 0.5 + net = ConstantPad3dNet(padding, value) + output = net(Tensor(x)) + print(output) + print(output.shape) + + +def test_constant_pad_3d_train(): + """ + Feature: ConstantPad3d + Description: Train process of ConstantPad3d with three type parameters. + Expectation: success + """ + x = np.ones(shape=(1, 2, 3, 4)).astype(np.float32) + print("=================case 1====================") + padding = (0, 1) + value = 0.5 + grad = Grad(ConstantPad3dNet(padding, value)) + output = grad(Tensor(x)) + print(output) + + print("=================case 2====================") + padding = 1 + value = 0.5 + grad = Grad(ConstantPad3dNet(padding, value)) + output = grad(Tensor(x)) + print(output) + + print("=================case 3====================") + padding = (1, 1, 0, 1, 1, 0) + value = 0.5 + grad = Grad(ConstantPad3dNet(padding, value)) + output = grad(Tensor(x)) + print(output) + + +def test_zero_pad_2d_infer(): + """ + Feature: ZeroPad2d + Description: Infer process of ZeroPad2d with three type parameters. + Expectation: success + """ + x = np.ones(shape=(1, 2, 3, 4)).astype(np.float32) + print("=================case 1====================") + padding = (0, 1) + net = ZeroPad2dNet(padding) + output = net(Tensor(x)) + print(output) + print(output.shape) + + print("=================case 2====================") + padding = 1 + net = ZeroPad2dNet(padding) + output = net(Tensor(x)) + print(output) + print(output.shape) + + print("=================case 3====================") + padding = (1, 1, 0, 1) + net = ZeroPad2dNet(padding) + output = net(Tensor(x)) + print(output) + print(output.shape) + + +def test_zero_pad_2d_train(): + """ + Feature: ZeroPad2d + Description: Train process of ZeroPad2d with three type parameters. + Expectation: success + """ + + x = np.ones(shape=(1, 2, 3, 4)).astype(np.float32) + print("=================case 1====================") + padding = (0, 1) + grad = Grad(ZeroPad2dNet(padding)) + output = grad(Tensor(x)) + print(output) + + print("=================case 2====================") + padding = 1 + grad = Grad(ZeroPad2dNet(padding)) + output = grad(Tensor(x)) + print(output) + + print("=================case 3====================") + padding = (1, 1, 0, 1) + grad = Grad(ZeroPad2dNet(padding)) + output = grad(Tensor(x)) + print(output) + + +def test_invalid_padding_reflection_pad_1d(): + """ + Feature: ReflectionPad1d + Description: test 5 cases of invalid input. + Expectation: success + """ + # case 1: padding is not int or tuple + padding = '-1' + with pytest.raises(TypeError): + nn.ReflectionPad1d(padding) + + # case 2: padding length is not divisible by 2 + padding = (1, 2, 2) + with pytest.raises(ValueError): + nn.ReflectionPad1d(padding) + + # case 3: padding element is not int + padding = ('2', 2) + with pytest.raises(TypeError): + nn.ReflectionPad1d(padding) + + # case 4: negative padding + padding = (-1, 2) + with pytest.raises(ValueError): + nn.ReflectionPad1d(padding) + + # case 5: padding dimension does not match tensor dimension + padding = (1, 1, 1, 1, 1, 1, 1, 1) + x = Tensor([[1, 2, 3], [1, 2, 3]]) + with pytest.raises(ValueError): + nn.ReflectionPad1d(padding)(x) + + + +def test_invalid_padding_reflection_pad_2d(): + """ + Feature: ReflectionPad2d + Description: test 5 cases of invalid input. + Expectation: success + """ + # case 1: padding is not int or tuple + padding = '-1' + with pytest.raises(TypeError): + nn.ReflectionPad2d(padding) + + # case 2: padding length is not divisible by 2 + padding = (1, 2, 2) + with pytest.raises(ValueError): + nn.ReflectionPad2d(padding) + + # case 3: padding element is not int + padding = ('2', 2) + with pytest.raises(TypeError): + nn.ReflectionPad2d(padding) + + # case 4: negative padding + padding = (-1, 2) + with pytest.raises(ValueError): + nn.ReflectionPad2d(padding) + + # case 5: padding dimension does not match tensor dimension + padding = (1, 1, 1, 1, 1, 1, 1, 1) + x = Tensor([[1, 2, 3], [1, 2, 3]]) + with pytest.raises(ValueError): + nn.ReflectionPad2d(padding)(x) diff --git a/tests/ut/python/nn/test_norm.py b/tests/ut/python/nn/test_norm.py index f1628b017fb..2e21e455468 100644 --- a/tests/ut/python/nn/test_norm.py +++ b/tests/ut/python/nn/test_norm.py @@ -1,37 +1,37 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" test norm """ -import numpy as np - -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.common.api import _cell_graph_executor -from ..ut_filter import non_graph_engine - - -class NormNet(nn.Cell): - def __init__(self): - super(NormNet, self).__init__() - self.norm = nn.Norm() - - def construct(self, x): - return self.norm(x) - - -@non_graph_engine -def test_compile_norm(): - net = NormNet() - x = Tensor(np.array([2.0, 1.0])) - _cell_graph_executor.compile(net, x) +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test norm """ +import numpy as np + +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common.api import _cell_graph_executor +from ..ut_filter import non_graph_engine + + +class NormNet(nn.Cell): + def __init__(self): + super(NormNet, self).__init__() + self.norm = nn.Norm() + + def construct(self, x): + return self.norm(x) + + +@non_graph_engine +def test_compile_norm(): + net = NormNet() + x = Tensor(np.array([2.0, 1.0])) + _cell_graph_executor.compile(net, x) diff --git a/tests/ut/python/onnx/test_onnx_not_equal.py b/tests/ut/python/onnx/test_onnx_not_equal.py index ffa985566e2..f4231c7063d 100644 --- a/tests/ut/python/onnx/test_onnx_not_equal.py +++ b/tests/ut/python/onnx/test_onnx_not_equal.py @@ -1,51 +1,51 @@ -# Copyright 2022-2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import os -import numpy as np -import onnxruntime as ort - -import mindspore as ms -from mindspore import ops, nn, Tensor - - -class Net(nn.Cell): - def __init__(self): - super().__init__() - self.op = ops.NotEqual() - - def construct(self, x, y): - return self.op(x, y) - - -def test_export_not_equal(): - """ - Feature: Export ops.NotEqual to onnx - Description: Export ops.NotEqual to onnx - Expectation: success - """ - arr1 = np.array([1, 2, 3]).astype(np.float32) - arr2 = np.array([1, 0, 3]).astype(np.float32) - a = Tensor(arr1) - b = Tensor(arr2) - net = Net() - ms.export(net, a, b, file_name='ne', file_format='ONNX') - if os.path.isfile("./ne.onnx"): - session = ort.InferenceSession("./ne.onnx") - output = session.run(None, {"x": arr1, "y": arr2})[0] - expected = np.array([False, True, False]) - assert np.array_equal(output, expected) - os.remove("./ne.onnx") - else: - raise RuntimeError(f"Export operator NotEqual to ONNX failed!") +# Copyright 2022-2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import os +import numpy as np +import onnxruntime as ort + +import mindspore as ms +from mindspore import ops, nn, Tensor + + +class Net(nn.Cell): + def __init__(self): + super().__init__() + self.op = ops.NotEqual() + + def construct(self, x, y): + return self.op(x, y) + + +def test_export_not_equal(): + """ + Feature: Export ops.NotEqual to onnx + Description: Export ops.NotEqual to onnx + Expectation: success + """ + arr1 = np.array([1, 2, 3]).astype(np.float32) + arr2 = np.array([1, 0, 3]).astype(np.float32) + a = Tensor(arr1) + b = Tensor(arr2) + net = Net() + ms.export(net, a, b, file_name='ne', file_format='ONNX') + if os.path.isfile("./ne.onnx"): + session = ort.InferenceSession("./ne.onnx") + output = session.run(None, {"x": arr1, "y": arr2})[0] + expected = np.array([False, True, False]) + assert np.array_equal(output, expected) + os.remove("./ne.onnx") + else: + raise RuntimeError(f"Export operator NotEqual to ONNX failed!") diff --git a/tests/ut/python/ops/test_array_ops_check.py b/tests/ut/python/ops/test_array_ops_check.py old mode 100755 new mode 100644 diff --git a/tests/ut/python/ops/test_dynamic_shape.py b/tests/ut/python/ops/test_dynamic_shape.py old mode 100755 new mode 100644 diff --git a/tests/ut/python/ops/test_flip.py b/tests/ut/python/ops/test_flip.py index 1b55628641f..220462ce718 100644 --- a/tests/ut/python/ops/test_flip.py +++ b/tests/ut/python/ops/test_flip.py @@ -1,39 +1,39 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -test flip api -""" -import numpy as np - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops -from mindspore.common.api import _cell_graph_executor - - -class Roll(nn.Cell): - def construct(self, x): - return ops.flip(x, (0, 2)) - - -def test_compile_flip(): - """ - Feature: Test filp - Description: Test the functionality of flip - Expectation: Success - """ - net = Roll() - x = ms.Tensor(np.arange(8).reshape((2, 2, 2))) - _cell_graph_executor.compile(net, x) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +test flip api +""" +import numpy as np + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore.common.api import _cell_graph_executor + + +class Roll(nn.Cell): + def construct(self, x): + return ops.flip(x, (0, 2)) + + +def test_compile_flip(): + """ + Feature: Test filp + Description: Test the functionality of flip + Expectation: Success + """ + net = Roll() + x = ms.Tensor(np.arange(8).reshape((2, 2, 2))) + _cell_graph_executor.compile(net, x) diff --git a/tests/ut/python/ops/test_fliplr.py b/tests/ut/python/ops/test_fliplr.py index 6e4a0ef7bad..0c66d67f262 100644 --- a/tests/ut/python/ops/test_fliplr.py +++ b/tests/ut/python/ops/test_fliplr.py @@ -1,39 +1,39 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -test fliplr api -""" -import numpy as np - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops -from mindspore.common.api import _cell_graph_executor - - -class Roll(nn.Cell): - def construct(self, x): - return ops.fliplr(x) - - -def test_compile_fliplr(): - """ - Feature: Test filplr - Description: Test the functionality of fliplr - Expectation: Success - """ - net = Roll() - x = ms.Tensor(np.arange(8).reshape((2, 2, 2))) - _cell_graph_executor.compile(net, x) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +test fliplr api +""" +import numpy as np + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore.common.api import _cell_graph_executor + + +class Roll(nn.Cell): + def construct(self, x): + return ops.fliplr(x) + + +def test_compile_fliplr(): + """ + Feature: Test filplr + Description: Test the functionality of fliplr + Expectation: Success + """ + net = Roll() + x = ms.Tensor(np.arange(8).reshape((2, 2, 2))) + _cell_graph_executor.compile(net, x) diff --git a/tests/ut/python/ops/test_flipud.py b/tests/ut/python/ops/test_flipud.py index 6a9137a398d..a2952f73efb 100644 --- a/tests/ut/python/ops/test_flipud.py +++ b/tests/ut/python/ops/test_flipud.py @@ -1,39 +1,39 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -test flipud api -""" -import numpy as np - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops -from mindspore.common.api import _cell_graph_executor - - -class Roll(nn.Cell): - def construct(self, x): - return ops.flipud(x) - - -def test_compile_flipud(): - """ - Feature: Test flipud - Description: Test the functionality of flipud - Expectation: Success - """ - net = Roll() - x = ms.Tensor(np.arange(8).reshape((2, 2, 2))) - _cell_graph_executor.compile(net, x) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +test flipud api +""" +import numpy as np + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore.common.api import _cell_graph_executor + + +class Roll(nn.Cell): + def construct(self, x): + return ops.flipud(x) + + +def test_compile_flipud(): + """ + Feature: Test flipud + Description: Test the functionality of flipud + Expectation: Success + """ + net = Roll() + x = ms.Tensor(np.arange(8).reshape((2, 2, 2))) + _cell_graph_executor.compile(net, x) diff --git a/tests/ut/python/ops/test_func_arange_ut.py b/tests/ut/python/ops/test_func_arange_ut.py index 8f93c809293..2c1eb4a35a1 100644 --- a/tests/ut/python/ops/test_func_arange_ut.py +++ b/tests/ut/python/ops/test_func_arange_ut.py @@ -1,35 +1,35 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -test arange api -""" -import mindspore.nn as nn -import mindspore.ops as ops -from mindspore.common.api import _cell_graph_executor - - -class Net(nn.Cell): - def construct(self, start=0, end=None, step=1, dtype=None): - return ops.arange(start, end, step, dtype=dtype) - - -def test_arange_normal(): - """ - Feature: arange - Description: Test the functionality of arange - Expectation: success - """ - net = Net() - _cell_graph_executor.compile(net, 1, 6) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +test arange api +""" +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore.common.api import _cell_graph_executor + + +class Net(nn.Cell): + def construct(self, start=0, end=None, step=1, dtype=None): + return ops.arange(start, end, step, dtype=dtype) + + +def test_arange_normal(): + """ + Feature: arange + Description: Test the functionality of arange + Expectation: success + """ + net = Net() + _cell_graph_executor.compile(net, 1, 6) diff --git a/tests/ut/python/ops/test_func_real.py b/tests/ut/python/ops/test_func_real.py index bac2daeaecb..91e8b238eb0 100644 --- a/tests/ut/python/ops/test_func_real.py +++ b/tests/ut/python/ops/test_func_real.py @@ -1,38 +1,38 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import numpy as np - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops -from mindspore.common.api import _cell_graph_executor - - -class Net(nn.Cell): - def construct(self, x): - output = ops.real(x) - return output - - -def test_real_normal(): - """ - Feature: Test real - Description: Test the functionality of real - Expectation: Success - """ - net = Net() - x = ms.Tensor(np.asarray(np.complex(1.3 + 0.4j)), ms.complex64) - _cell_graph_executor.compile(net, x) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore.common.api import _cell_graph_executor + + +class Net(nn.Cell): + def construct(self, x): + output = ops.real(x) + return output + + +def test_real_normal(): + """ + Feature: Test real + Description: Test the functionality of real + Expectation: Success + """ + net = Net() + x = ms.Tensor(np.asarray(np.complex(1.3 + 0.4j)), ms.complex64) + _cell_graph_executor.compile(net, x) diff --git a/tests/ut/python/ops/test_func_reciprocal.py b/tests/ut/python/ops/test_func_reciprocal.py index a38fce50659..bf6ba1ed016 100644 --- a/tests/ut/python/ops/test_func_reciprocal.py +++ b/tests/ut/python/ops/test_func_reciprocal.py @@ -1,38 +1,38 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import numpy as np - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops -from mindspore.common.api import _cell_graph_executor - - -class Net(nn.Cell): - def construct(self, x): - output = ops.reciprocal(x) - return output - - -def test_reciprocal_normal(): - """ - Feature: Test reciprocal - Description: Test the functionality of reciprocal - Expectation: Success - """ - net = Net() - x = ms.Tensor(np.array([1.0, 2.0, 4.0]), ms.float32) - _cell_graph_executor.compile(net, x) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore.common.api import _cell_graph_executor + + +class Net(nn.Cell): + def construct(self, x): + output = ops.reciprocal(x) + return output + + +def test_reciprocal_normal(): + """ + Feature: Test reciprocal + Description: Test the functionality of reciprocal + Expectation: Success + """ + net = Net() + x = ms.Tensor(np.array([1.0, 2.0, 4.0]), ms.float32) + _cell_graph_executor.compile(net, x) diff --git a/tests/ut/python/ops/test_func_rsqrt.py b/tests/ut/python/ops/test_func_rsqrt.py index a6cd03bc9e3..bde36d0106a 100644 --- a/tests/ut/python/ops/test_func_rsqrt.py +++ b/tests/ut/python/ops/test_func_rsqrt.py @@ -1,36 +1,36 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops -from mindspore.common.api import _cell_graph_executor - - -class Net(nn.Cell): - def construct(self, x): - output = ops.rsqrt(x) - return output - - -def test_rsqrt_normal(): - """ - Feature: Test rsqrt - Description: Test the functionality of rsqrt - Expectation: Success - """ - net = Net() - x = ms.Tensor([-0.0370, 0.2970, 1.5420, -0.9105]) - _cell_graph_executor.compile(net, x) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore.common.api import _cell_graph_executor + + +class Net(nn.Cell): + def construct(self, x): + output = ops.rsqrt(x) + return output + + +def test_rsqrt_normal(): + """ + Feature: Test rsqrt + Description: Test the functionality of rsqrt + Expectation: Success + """ + net = Net() + x = ms.Tensor([-0.0370, 0.2970, 1.5420, -0.9105]) + _cell_graph_executor.compile(net, x) diff --git a/tests/ut/python/ops/test_int64_support.py b/tests/ut/python/ops/test_int64_support.py index e083b815835..d53168c93fa 100644 --- a/tests/ut/python/ops/test_int64_support.py +++ b/tests/ut/python/ops/test_int64_support.py @@ -1,39 +1,39 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" test_int64_support """ -import numpy as np -import mindspore.nn as nn -from mindspore import context -from mindspore.common.tensor import Tensor -import mindspore as ms - - -def test_parser_support_int64_normal_graph(): - """ test tensor index support int64 -index, graph mode""" - class Net(nn.Cell): - def __init__(self): - super().__init__() - - def construct(self, inputs, tensor_in): - result = inputs[tensor_in] - return result - - context.set_context(mode=context.GRAPH_MODE) - input_np_x = np.random.randn(2, 3, 4, 5).astype(np.float32) - input_me_x = Tensor(input_np_x, ms.float32) - input_np_y = np.random.randint(2, size=[1, 2]).astype(np.int64) - tensor = Tensor(input_np_y, ms.int64) - net = Net() - net(input_me_x, tensor).asnumpy() +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test_int64_support """ +import numpy as np +import mindspore.nn as nn +from mindspore import context +from mindspore.common.tensor import Tensor +import mindspore as ms + + +def test_parser_support_int64_normal_graph(): + """ test tensor index support int64 -index, graph mode""" + class Net(nn.Cell): + def __init__(self): + super().__init__() + + def construct(self, inputs, tensor_in): + result = inputs[tensor_in] + return result + + context.set_context(mode=context.GRAPH_MODE) + input_np_x = np.random.randn(2, 3, 4, 5).astype(np.float32) + input_me_x = Tensor(input_np_x, ms.float32) + input_np_y = np.random.randint(2, size=[1, 2]).astype(np.int64) + tensor = Tensor(input_np_y, ms.int64) + net = Net() + net(input_me_x, tensor).asnumpy() diff --git a/tests/ut/python/ops/test_is_floating_point.py b/tests/ut/python/ops/test_is_floating_point.py index 868c5b87e9a..2c1a06f1679 100644 --- a/tests/ut/python/ops/test_is_floating_point.py +++ b/tests/ut/python/ops/test_is_floating_point.py @@ -1,38 +1,38 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -test is floating point api -""" - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops -from mindspore.common.api import _cell_graph_executor - - -class Roll(nn.Cell): - def construct(self, x): - return ops.is_floating_point(x) - - -def test_compile_is_floating_point(): - """ - Feature: Test is floating point - Description: Test the functionality of is floating point - Expectation: Success - """ - net = Roll() - x = ms.Tensor([1, 2, 3], ms.float32) - _cell_graph_executor.compile(net, x) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +test is floating point api +""" + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore.common.api import _cell_graph_executor + + +class Roll(nn.Cell): + def construct(self, x): + return ops.is_floating_point(x) + + +def test_compile_is_floating_point(): + """ + Feature: Test is floating point + Description: Test the functionality of is floating point + Expectation: Success + """ + net = Roll() + x = ms.Tensor([1, 2, 3], ms.float32) + _cell_graph_executor.compile(net, x) diff --git a/tests/ut/python/ops/test_margin_ranking_loss.py b/tests/ut/python/ops/test_margin_ranking_loss.py index ce7ecf3b543..d8469e5a669 100644 --- a/tests/ut/python/ops/test_margin_ranking_loss.py +++ b/tests/ut/python/ops/test_margin_ranking_loss.py @@ -1,45 +1,45 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import numpy as np -import pytest - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops -from mindspore import Tensor - - -class MarginRankingLoss(nn.Cell): - def __init__(self, reduction): - super(MarginRankingLoss, self).__init__() - self.reduction = reduction - - def construct(self, x, y, label, margin): - return ops.margin_ranking_loss(x, y, label, margin, reduction=self.reduction) - - -@pytest.mark.parametrize('reduction', ["none", "mean", "sum"]) -def test_margin_ranking_loss(reduction): - """ - Feature: test MarginRankingLoss op with reduction none. - Description: Verify the result of MarginRankingLoss. - Expectation: expect correct forward result. - """ - loss = MarginRankingLoss(reduction) - input1 = Tensor(np.array([0.3864, -2.4093, -1.4076]), ms.float32) - input2 = Tensor(np.array([-0.6012, -1.6681, 1.2928]), ms.float32) - target = Tensor(np.array([-1, -1, 1]), ms.float32) - loss(input1, input2, target, 0.0) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor + + +class MarginRankingLoss(nn.Cell): + def __init__(self, reduction): + super(MarginRankingLoss, self).__init__() + self.reduction = reduction + + def construct(self, x, y, label, margin): + return ops.margin_ranking_loss(x, y, label, margin, reduction=self.reduction) + + +@pytest.mark.parametrize('reduction', ["none", "mean", "sum"]) +def test_margin_ranking_loss(reduction): + """ + Feature: test MarginRankingLoss op with reduction none. + Description: Verify the result of MarginRankingLoss. + Expectation: expect correct forward result. + """ + loss = MarginRankingLoss(reduction) + input1 = Tensor(np.array([0.3864, -2.4093, -1.4076]), ms.float32) + input2 = Tensor(np.array([-0.6012, -1.6681, 1.2928]), ms.float32) + target = Tensor(np.array([-1, -1, 1]), ms.float32) + loss(input1, input2, target, 0.0) diff --git a/tests/ut/python/ops/test_math_ops.py b/tests/ut/python/ops/test_math_ops.py old mode 100755 new mode 100644 diff --git a/tests/ut/python/ops/test_math_ops_check.py b/tests/ut/python/ops/test_math_ops_check.py old mode 100755 new mode 100644 diff --git a/tests/ut/python/ops/test_nn_ops_check.py b/tests/ut/python/ops/test_nn_ops_check.py old mode 100755 new mode 100644 diff --git a/tests/ut/python/ops/test_pool.py b/tests/ut/python/ops/test_pool.py index 0de9b9b7de5..c6b5f801bea 100644 --- a/tests/ut/python/ops/test_pool.py +++ b/tests/ut/python/ops/test_pool.py @@ -1,63 +1,63 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -test pooling api -""" -import numpy as np - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops -from mindspore.common.api import _cell_graph_executor - - -class LPPool1d(nn.Cell): - """LPPool1d""" - - def construct(self, x): - output = ops.lp_pool1d(x, norm_type=1, kernel_size=3, stride=1) - return output - - -def test_compile_lpool1d(): - """ - Feature: Test LPPool1d - Description: Test the functionality of LPPool1d - Expectation: Success - """ - net = LPPool1d() - x = ms.Tensor(np.arange(2 * 3 * 4).reshape((2, 3, 4)), dtype=ms.float32) - y = ms.Tensor(np.arange(3 * 4).reshape((3, 4)), dtype=ms.float32) - _cell_graph_executor.compile(net, x) - _cell_graph_executor.compile(net, y) - - -class LPPool2d(nn.Cell): - """LPPool2d""" - - def construct(self, x): - out = ops.lp_pool2d(x, norm_type=1, kernel_size=3, stride=1) - return out - - -def test_compile_lppool2d(): - """ - Feature: Test LPPool2d - Description: Test the functionality of LPPool2d - Expectation: Success - """ - net = LPPool2d() - x = ms.Tensor(np.arange(2 * 3 * 4 * 5).reshape((2, 3, 4, 5)), dtype=ms.float32) - _cell_graph_executor.compile(net, x) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +test pooling api +""" +import numpy as np + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore.common.api import _cell_graph_executor + + +class LPPool1d(nn.Cell): + """LPPool1d""" + + def construct(self, x): + output = ops.lp_pool1d(x, norm_type=1, kernel_size=3, stride=1) + return output + + +def test_compile_lpool1d(): + """ + Feature: Test LPPool1d + Description: Test the functionality of LPPool1d + Expectation: Success + """ + net = LPPool1d() + x = ms.Tensor(np.arange(2 * 3 * 4).reshape((2, 3, 4)), dtype=ms.float32) + y = ms.Tensor(np.arange(3 * 4).reshape((3, 4)), dtype=ms.float32) + _cell_graph_executor.compile(net, x) + _cell_graph_executor.compile(net, y) + + +class LPPool2d(nn.Cell): + """LPPool2d""" + + def construct(self, x): + out = ops.lp_pool2d(x, norm_type=1, kernel_size=3, stride=1) + return out + + +def test_compile_lppool2d(): + """ + Feature: Test LPPool2d + Description: Test the functionality of LPPool2d + Expectation: Success + """ + net = LPPool2d() + x = ms.Tensor(np.arange(2 * 3 * 4 * 5).reshape((2, 3, 4, 5)), dtype=ms.float32) + _cell_graph_executor.compile(net, x) diff --git a/tests/ut/python/ops/test_repeat_interleave.py b/tests/ut/python/ops/test_repeat_interleave.py index eb4ff2b6a27..48818355c87 100644 --- a/tests/ut/python/ops/test_repeat_interleave.py +++ b/tests/ut/python/ops/test_repeat_interleave.py @@ -1,36 +1,36 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import numpy as np -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops -from mindspore import Tensor -from mindspore.common.api import _cell_graph_executor - - -class RepeatInterleave(nn.Cell): - def construct(self, x): - return ops.repeat_interleave(x, repeats=2, axis=0) - - -def test_repeat_interleave(): - """ - Feature: tensor.repeat_interleave - Description: Test the functionality of repeat_interleave - Expectation: success - """ - x = Tensor(np.array([[0, 1, 2], [3, 4, 5]]), ms.int32) - net = RepeatInterleave() - _cell_graph_executor.compile(net, x) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor +from mindspore.common.api import _cell_graph_executor + + +class RepeatInterleave(nn.Cell): + def construct(self, x): + return ops.repeat_interleave(x, repeats=2, axis=0) + + +def test_repeat_interleave(): + """ + Feature: tensor.repeat_interleave + Description: Test the functionality of repeat_interleave + Expectation: success + """ + x = Tensor(np.array([[0, 1, 2], [3, 4, 5]]), ms.int32) + net = RepeatInterleave() + _cell_graph_executor.compile(net, x) diff --git a/tests/ut/python/ops/test_roll.py b/tests/ut/python/ops/test_roll.py index 3bff0754969..6f169db74e1 100644 --- a/tests/ut/python/ops/test_roll.py +++ b/tests/ut/python/ops/test_roll.py @@ -1,39 +1,39 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -test roll api -""" -import numpy as np - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops -from mindspore.common.api import _cell_graph_executor - - -class Roll(nn.Cell): - def construct(self, x): - return ops.roll(x, shifts=2, dims=0) - - -def test_compile_roll(): - """ - Feature: Test Roll - Description: Test the functionality of roll - Expectation: Success - """ - net = Roll() - x = ms.Tensor(np.array([0, 1, 2, 3, 4]).astype(np.float32)) - _cell_graph_executor.compile(net, x) +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +test roll api +""" +import numpy as np + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore.common.api import _cell_graph_executor + + +class Roll(nn.Cell): + def construct(self, x): + return ops.roll(x, shifts=2, dims=0) + + +def test_compile_roll(): + """ + Feature: Test Roll + Description: Test the functionality of roll + Expectation: Success + """ + net = Roll() + x = ms.Tensor(np.array([0, 1, 2, 3, 4]).astype(np.float32)) + _cell_graph_executor.compile(net, x) diff --git a/tests/ut/python/parallel/test_layout_extend_activation.py b/tests/ut/python/parallel/test_layout_extend_activation.py index 8fcffb6d80c..dd1884748c6 100644 --- a/tests/ut/python/parallel/test_layout_extend_activation.py +++ b/tests/ut/python/parallel/test_layout_extend_activation.py @@ -1,124 +1,124 @@ -# Copyright 2024 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest -import numpy as np - -import mindspore as ms -import mindspore.nn as nn -from mindspore import Tensor -from mindspore import context -from mindspore.common.api import _cell_graph_executor -from mindspore.ops import operations as P -from mindspore.ops import composite as C -from mindspore.parallel.shard import Layout -from mindspore.ops.operations._inner_ops import SiLU -from tests.ut.python.ops.test_math_ops import VirtualLoss - -activation_ops_map = { - "gelu": P.GeLU(), - "silu": SiLU(), - "relu": P.ReLU(), - "sigmoid": P.Sigmoid(), - "softmax": P.Softmax(), -} - -def setup_function(): - context.set_auto_parallel_context(dataset_strategy="full_batch") - - -grad_all = C.GradOperation(get_all=True) - - -class NetWithLoss(nn.Cell): - def __init__(self, network): - super(NetWithLoss, self).__init__() - self.loss = VirtualLoss() - self.network = network - - def construct(self, y): - predict = self.network(y) - return self.loss(predict) - - -class GradWrap(nn.Cell): - def __init__(self, network): - super(GradWrap, self).__init__() - self.network = network - - def construct(self, y): - return grad_all(self.network)(y) - - -def compile_net(net, input_x): - net.set_auto_parallel() - net.set_train() - phase, _ = _cell_graph_executor.compile(net, input_x) - return phase - - -class Net(nn.Cell): - def __init__(self, in_layout, out_layout=None, ops_name=None): - super().__init__() - self.activation_ops = activation_ops_map[ops_name] - self.activation_ops.shard(in_strategy=in_layout, out_strategy=out_layout) - - def construct(self, y): - out = self.activation_ops(y) - return out - - -x = Tensor(np.ones([1024, 1024]), dtype=ms.float32) - - -@pytest.mark.parametrize('ops_name', ["gelu", "silu", "relu", "sigmoid"]) -def test_layout_extend_base(ops_name): - """ - Feature: test layout extend - Description: dev_num is 4. - Expectation: compile success - """ - context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=4, global_rank=0) - layout = Layout((2, 2), ("dp", "mp")) - layout1 = (layout("dp", "mp"),) - net = Net(layout1, ops_name=ops_name) - compile_net(net, x) - - -@pytest.mark.parametrize('ops_name', ["gelu", "silu", "relu", "sigmoid"]) -def test_layout_extend_batch_multi_shard(ops_name): - """ - Feature: test layout extend - Description: dev_num is 8, batch dim multi shard. - Expectation: compile success - """ - context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) - layout = Layout((2, 2, 2), ("dp", "sp", "mp")) - layout1 = (layout(("dp", "mp"), "sp"),) - net = Net(layout1, ops_name=ops_name) - compile_net(net, x) - - -@pytest.mark.parametrize('ops_name', ["gelu", "silu", "relu", "sigmoid"]) -def test_layout_extend_reduce_axis_multi_shard(ops_name): - """ - Feature: test layout extend - Description: dev_num is 8, reduce dim multi shard. - Expectation: compile success - """ - context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) - layout = Layout((2, 2, 2), ("dp", "sp", "mp")) - layout1 = (layout("dp", ("mp", "sp")),) - net = Net(layout1, ops_name=ops_name) - compile_net(net, x) +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import numpy as np + +import mindspore as ms +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context +from mindspore.common.api import _cell_graph_executor +from mindspore.ops import operations as P +from mindspore.ops import composite as C +from mindspore.parallel.shard import Layout +from mindspore.ops.operations._inner_ops import SiLU +from tests.ut.python.ops.test_math_ops import VirtualLoss + +activation_ops_map = { + "gelu": P.GeLU(), + "silu": SiLU(), + "relu": P.ReLU(), + "sigmoid": P.Sigmoid(), + "softmax": P.Softmax(), +} + +def setup_function(): + context.set_auto_parallel_context(dataset_strategy="full_batch") + + +grad_all = C.GradOperation(get_all=True) + + +class NetWithLoss(nn.Cell): + def __init__(self, network): + super(NetWithLoss, self).__init__() + self.loss = VirtualLoss() + self.network = network + + def construct(self, y): + predict = self.network(y) + return self.loss(predict) + + +class GradWrap(nn.Cell): + def __init__(self, network): + super(GradWrap, self).__init__() + self.network = network + + def construct(self, y): + return grad_all(self.network)(y) + + +def compile_net(net, input_x): + net.set_auto_parallel() + net.set_train() + phase, _ = _cell_graph_executor.compile(net, input_x) + return phase + + +class Net(nn.Cell): + def __init__(self, in_layout, out_layout=None, ops_name=None): + super().__init__() + self.activation_ops = activation_ops_map[ops_name] + self.activation_ops.shard(in_strategy=in_layout, out_strategy=out_layout) + + def construct(self, y): + out = self.activation_ops(y) + return out + + +x = Tensor(np.ones([1024, 1024]), dtype=ms.float32) + + +@pytest.mark.parametrize('ops_name', ["gelu", "silu", "relu", "sigmoid"]) +def test_layout_extend_base(ops_name): + """ + Feature: test layout extend + Description: dev_num is 4. + Expectation: compile success + """ + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=4, global_rank=0) + layout = Layout((2, 2), ("dp", "mp")) + layout1 = (layout("dp", "mp"),) + net = Net(layout1, ops_name=ops_name) + compile_net(net, x) + + +@pytest.mark.parametrize('ops_name', ["gelu", "silu", "relu", "sigmoid"]) +def test_layout_extend_batch_multi_shard(ops_name): + """ + Feature: test layout extend + Description: dev_num is 8, batch dim multi shard. + Expectation: compile success + """ + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + layout = Layout((2, 2, 2), ("dp", "sp", "mp")) + layout1 = (layout(("dp", "mp"), "sp"),) + net = Net(layout1, ops_name=ops_name) + compile_net(net, x) + + +@pytest.mark.parametrize('ops_name', ["gelu", "silu", "relu", "sigmoid"]) +def test_layout_extend_reduce_axis_multi_shard(ops_name): + """ + Feature: test layout extend + Description: dev_num is 8, reduce dim multi shard. + Expectation: compile success + """ + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + layout = Layout((2, 2, 2), ("dp", "sp", "mp")) + layout1 = (layout("dp", ("mp", "sp")),) + net = Net(layout1, ops_name=ops_name) + compile_net(net, x) diff --git a/tests/ut/python/parallel/test_layout_extend_arithmetic.py b/tests/ut/python/parallel/test_layout_extend_arithmetic.py index d5804fbd46c..9ffb9eea055 100644 --- a/tests/ut/python/parallel/test_layout_extend_arithmetic.py +++ b/tests/ut/python/parallel/test_layout_extend_arithmetic.py @@ -1,188 +1,188 @@ -# Copyright 2024 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import pytest - -import mindspore as ms -import mindspore.nn as nn -from mindspore import Tensor, Parameter -from mindspore import context -from mindspore.common.api import _cell_graph_executor -from mindspore.ops import operations as P -from mindspore.ops import composite as C -from mindspore.parallel.shard import Layout -from tests.ut.python.ops.test_math_ops import VirtualLoss -from parallel.utils.utils import ParallelValidator - -arithmetic_ops_map = { - "add": P.Add(), - "sub": P.Sub(), - "mul": P.Mul(), - "div": P.Div(), - "real_div": P.RealDiv(), -} - -def setup_function(): - context.set_auto_parallel_context(dataset_strategy="full_batch") - -grad_all = C.GradOperation(get_all=True) - -class NetWithLoss(nn.Cell): - def __init__(self, network): - super(NetWithLoss, self).__init__() - self.loss = VirtualLoss() - self.network = network - - def construct(self, y): - predict = self.network(y) - return self.loss(predict) - - -class GradWrap(nn.Cell): - def __init__(self, network): - super(GradWrap, self).__init__() - self.network = network - - def construct(self, y): - return grad_all(self.network)(y) - - -def compile_net(net, input_x): - net.set_auto_parallel() - net.set_train() - phase, _ = _cell_graph_executor.compile(net, input_x) - return phase - - -class Net(nn.Cell): - def __init__(self, weight, in_layout, out_layout=None, ops_name=None): - super().__init__() - self.arithmetic_ops = arithmetic_ops_map[ops_name] - self.arithmetic_ops.shard(in_strategy=in_layout, out_strategy=out_layout) - self.relu = P.ReLU() - self.w = Parameter(weight, "w1") - - def construct(self, y): - out1 = self.arithmetic_ops(y, self.w) - out2 = self.relu(out1) - out = out1 + out2 - return out - -x = Tensor(np.ones([1024, 1024]), dtype=ms.float32) -w = Tensor(np.ones([1024, 1024]), dtype=ms.float32) - -input_1024 = Tensor(np.ones([1024]), dtype=ms.float32) -input_1_1024 = Tensor(np.ones([1, 1024]), dtype=ms.float32) -input_1024_1024 = Tensor(np.ones([1024, 1024]), dtype=ms.float32) - - -@pytest.mark.parametrize('ops_name', ['add', 'sub', 'mul', 'div', 'real_div']) -def test_layout_extend_add_same_shape_same_shard(ops_name): - """ - Feature: test layout extend - Description: dev_num is 8. - Expectation: compile success - """ - context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) - layout = Layout((2, 2, 2), ("dp", "sp", "mp")) - layout1 = (layout(("dp", "sp"), "mp"), layout(("dp", "sp"), "mp")) - first, second = input_1024_1024, x - net = Net(second, layout1, ops_name=ops_name) - phase = compile_net(net, first) - validator = ParallelValidator(net, phase) - assert validator.check_parameter_shape('w1', [256, 512]) - - -@pytest.mark.parametrize('ops_name', ['add', 'sub', 'mul', 'div', 'real_div']) -def test_layout_extend_add_same_shape_wrong_shard(ops_name): - """ - Feature: test layout extend - Description: dev_num is 8. - Expectation: compile failed - """ - context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) - layout = Layout((2, 2, 2), ("dp", "sp", "mp")) - layout1 = (layout(("dp", "sp"), "mp"), layout(("dp", "mp"), "sp")) - first, second = input_1024_1024, x - net = Net(second, layout1, ops_name=ops_name) - with pytest.raises(RuntimeError): - compile_net(net, first) - - -@pytest.mark.parametrize('ops_name', ['add', 'sub', 'mul', 'div', 'real_div']) -def test_layout_extend_add_same_dim_broadcast(ops_name): - """ - Feature: test layout extend - Description: dev_num is 8. - Expectation: compile success, second input broadcast - """ - context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) - layout = Layout((2, 2, 2), ("dp", "sp", "mp")) - layout1 = (layout(("dp", "sp"), "mp"), layout("None", "mp")) - first, second = input_1024_1024, input_1_1024 - net = Net(second, layout1, ops_name=ops_name) - phase = compile_net(net, first) - validator = ParallelValidator(net, phase) - assert validator.check_parameter_shape('w1', [1, 512]) - - -@pytest.mark.parametrize('ops_name', ['add', 'sub', 'mul', 'div', 'real_div']) -def test_layout_extend_add_different_dim_broadcast(ops_name): - """ - Feature: test layout extend - Description: dev_num is 8. - Expectation: compile success, second input broadcast - """ - context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) - layout = Layout((2, 2, 2), ("dp", "sp", "mp")) - layout1 = (layout(("dp", "sp"), "mp"), layout("mp",)) - first, second = input_1024_1024, input_1024 - net = Net(second, layout1, ops_name=ops_name) - phase = compile_net(net, first) - validator = ParallelValidator(net, phase) - assert validator.check_parameter_shape('w1', [512]) - - -@pytest.mark.parametrize('ops_name', ['add', 'sub', 'mul', 'div', 'real_div']) -def test_layout_extend_add_different_dim_broadcast_failed(ops_name): - """ - Feature: test layout extend - Description: dev_num is 8. - Expectation: compile success, second input broadcast - """ - context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) - layout = Layout((2, 2, 2), ("dp", "sp", "mp")) - layout1 = (layout(("dp", "sp"), "mp"), layout("None",)) - first, second = input_1024_1024, input_1024 - net = Net(second, layout1, ops_name=ops_name) - with pytest.raises(RuntimeError): - compile_net(net, first) - - -@pytest.mark.parametrize('ops_name', ['add', 'sub', 'mul', 'div', 'real_div']) -def test_layout_extend_add_same_shape_same_shard_outputlayout_not_allowed(ops_name): - """ - Feature: test layout extend - Description: dev_num is 8. - Expectation: compile failed - """ - context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) - layout = Layout((2, 2, 2), ("dp", "sp", "mp")) - layout1 = (layout(("dp", "sp"), "mp"), layout(("dp", "sp"), "mp")) - out_layout = (layout(("dp", "sp"), "mp"),) - first, second = input_1024_1024, x - net = Net(second, layout1, out_layout, ops_name=ops_name) - with pytest.raises(RuntimeError): - compile_net(net, first) +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest + +import mindspore as ms +import mindspore.nn as nn +from mindspore import Tensor, Parameter +from mindspore import context +from mindspore.common.api import _cell_graph_executor +from mindspore.ops import operations as P +from mindspore.ops import composite as C +from mindspore.parallel.shard import Layout +from tests.ut.python.ops.test_math_ops import VirtualLoss +from parallel.utils.utils import ParallelValidator + +arithmetic_ops_map = { + "add": P.Add(), + "sub": P.Sub(), + "mul": P.Mul(), + "div": P.Div(), + "real_div": P.RealDiv(), +} + +def setup_function(): + context.set_auto_parallel_context(dataset_strategy="full_batch") + +grad_all = C.GradOperation(get_all=True) + +class NetWithLoss(nn.Cell): + def __init__(self, network): + super(NetWithLoss, self).__init__() + self.loss = VirtualLoss() + self.network = network + + def construct(self, y): + predict = self.network(y) + return self.loss(predict) + + +class GradWrap(nn.Cell): + def __init__(self, network): + super(GradWrap, self).__init__() + self.network = network + + def construct(self, y): + return grad_all(self.network)(y) + + +def compile_net(net, input_x): + net.set_auto_parallel() + net.set_train() + phase, _ = _cell_graph_executor.compile(net, input_x) + return phase + + +class Net(nn.Cell): + def __init__(self, weight, in_layout, out_layout=None, ops_name=None): + super().__init__() + self.arithmetic_ops = arithmetic_ops_map[ops_name] + self.arithmetic_ops.shard(in_strategy=in_layout, out_strategy=out_layout) + self.relu = P.ReLU() + self.w = Parameter(weight, "w1") + + def construct(self, y): + out1 = self.arithmetic_ops(y, self.w) + out2 = self.relu(out1) + out = out1 + out2 + return out + +x = Tensor(np.ones([1024, 1024]), dtype=ms.float32) +w = Tensor(np.ones([1024, 1024]), dtype=ms.float32) + +input_1024 = Tensor(np.ones([1024]), dtype=ms.float32) +input_1_1024 = Tensor(np.ones([1, 1024]), dtype=ms.float32) +input_1024_1024 = Tensor(np.ones([1024, 1024]), dtype=ms.float32) + + +@pytest.mark.parametrize('ops_name', ['add', 'sub', 'mul', 'div', 'real_div']) +def test_layout_extend_add_same_shape_same_shard(ops_name): + """ + Feature: test layout extend + Description: dev_num is 8. + Expectation: compile success + """ + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + layout = Layout((2, 2, 2), ("dp", "sp", "mp")) + layout1 = (layout(("dp", "sp"), "mp"), layout(("dp", "sp"), "mp")) + first, second = input_1024_1024, x + net = Net(second, layout1, ops_name=ops_name) + phase = compile_net(net, first) + validator = ParallelValidator(net, phase) + assert validator.check_parameter_shape('w1', [256, 512]) + + +@pytest.mark.parametrize('ops_name', ['add', 'sub', 'mul', 'div', 'real_div']) +def test_layout_extend_add_same_shape_wrong_shard(ops_name): + """ + Feature: test layout extend + Description: dev_num is 8. + Expectation: compile failed + """ + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + layout = Layout((2, 2, 2), ("dp", "sp", "mp")) + layout1 = (layout(("dp", "sp"), "mp"), layout(("dp", "mp"), "sp")) + first, second = input_1024_1024, x + net = Net(second, layout1, ops_name=ops_name) + with pytest.raises(RuntimeError): + compile_net(net, first) + + +@pytest.mark.parametrize('ops_name', ['add', 'sub', 'mul', 'div', 'real_div']) +def test_layout_extend_add_same_dim_broadcast(ops_name): + """ + Feature: test layout extend + Description: dev_num is 8. + Expectation: compile success, second input broadcast + """ + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + layout = Layout((2, 2, 2), ("dp", "sp", "mp")) + layout1 = (layout(("dp", "sp"), "mp"), layout("None", "mp")) + first, second = input_1024_1024, input_1_1024 + net = Net(second, layout1, ops_name=ops_name) + phase = compile_net(net, first) + validator = ParallelValidator(net, phase) + assert validator.check_parameter_shape('w1', [1, 512]) + + +@pytest.mark.parametrize('ops_name', ['add', 'sub', 'mul', 'div', 'real_div']) +def test_layout_extend_add_different_dim_broadcast(ops_name): + """ + Feature: test layout extend + Description: dev_num is 8. + Expectation: compile success, second input broadcast + """ + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + layout = Layout((2, 2, 2), ("dp", "sp", "mp")) + layout1 = (layout(("dp", "sp"), "mp"), layout("mp",)) + first, second = input_1024_1024, input_1024 + net = Net(second, layout1, ops_name=ops_name) + phase = compile_net(net, first) + validator = ParallelValidator(net, phase) + assert validator.check_parameter_shape('w1', [512]) + + +@pytest.mark.parametrize('ops_name', ['add', 'sub', 'mul', 'div', 'real_div']) +def test_layout_extend_add_different_dim_broadcast_failed(ops_name): + """ + Feature: test layout extend + Description: dev_num is 8. + Expectation: compile success, second input broadcast + """ + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + layout = Layout((2, 2, 2), ("dp", "sp", "mp")) + layout1 = (layout(("dp", "sp"), "mp"), layout("None",)) + first, second = input_1024_1024, input_1024 + net = Net(second, layout1, ops_name=ops_name) + with pytest.raises(RuntimeError): + compile_net(net, first) + + +@pytest.mark.parametrize('ops_name', ['add', 'sub', 'mul', 'div', 'real_div']) +def test_layout_extend_add_same_shape_same_shard_outputlayout_not_allowed(ops_name): + """ + Feature: test layout extend + Description: dev_num is 8. + Expectation: compile failed + """ + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + layout = Layout((2, 2, 2), ("dp", "sp", "mp")) + layout1 = (layout(("dp", "sp"), "mp"), layout(("dp", "sp"), "mp")) + out_layout = (layout(("dp", "sp"), "mp"),) + first, second = input_1024_1024, x + net = Net(second, layout1, out_layout, ops_name=ops_name) + with pytest.raises(RuntimeError): + compile_net(net, first) diff --git a/tests/ut/python/parallel/test_layout_extend_bias_add.py b/tests/ut/python/parallel/test_layout_extend_bias_add.py index 66a491e92f0..4cea618a3ee 100644 --- a/tests/ut/python/parallel/test_layout_extend_bias_add.py +++ b/tests/ut/python/parallel/test_layout_extend_bias_add.py @@ -1,121 +1,121 @@ -# Copyright 2024 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import pytest - -import mindspore as ms -import mindspore.nn as nn -from mindspore import Tensor, Parameter -from mindspore import context -from mindspore.common.api import _cell_graph_executor -from mindspore.ops import operations as P -from mindspore.ops import composite as C -from mindspore.parallel.shard import Layout -from tests.ut.python.ops.test_math_ops import VirtualLoss -from parallel.utils.utils import ParallelValidator - - -def setup_function(): - context.set_auto_parallel_context(dataset_strategy="full_batch") - -grad_all = C.GradOperation(get_all=True) - -class NetWithLoss(nn.Cell): - def __init__(self, network): - super(NetWithLoss, self).__init__() - self.loss = VirtualLoss() - self.network = network - - def construct(self, y): - predict = self.network(y) - return self.loss(predict) - - -class GradWrap(nn.Cell): - def __init__(self, network): - super(GradWrap, self).__init__() - self.network = network - - def construct(self, y): - return grad_all(self.network)(y) - - -def compile_net(net, input_x): - net.set_auto_parallel() - net.set_train() - phase, _ = _cell_graph_executor.compile(net, input_x) - return phase - - -class Net(nn.Cell): - def __init__(self, weight, in_layout, out_layout=None): - super().__init__() - self.bias_add = P.BiasAdd().shard(in_strategy=in_layout, out_strategy=out_layout) - self.relu = P.ReLU() - self.w = Parameter(weight, "w1") - - def construct(self, y): - out1 = self.bias_add(y, self.w) - out2 = self.relu(out1) - out = out1 + out2 - return out - -x_1_1024 = Tensor(np.ones([1, 1024]), dtype=ms.float32) -x_1024_1024 = Tensor(np.ones([1024, 1024]), dtype=ms.float32) -bias = Tensor(np.ones([1024]), dtype=ms.float32) - -def test_layout_extend_bias_add_shard(): - """ - Feature: test layout extend - Description: dev_num is 8. - Expectation: compile success - """ - context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) - layout = Layout((2, 2, 2), ("dp", "sp", "mp")) - layout1 = (layout("mp", ("dp", "sp")), layout(("dp", "sp"),)) - first, second = x_1024_1024, bias - net = Net(second, layout1) - phase = compile_net(net, first) - validator = ParallelValidator(net, phase) - assert validator.check_parameter_shape('w1', [256]) - -def test_layout_extend_bias_add_channel_shard_different(): - """ - Feature: test layout extend - Description: dev_num is 8. - Expectation: compile failed, channel shard different - """ - context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) - layout = Layout((2, 2, 2), ("dp", "sp", "mp")) - layout1 = (layout(("dp", "sp"), "mp"), layout("sp",)) - first, second = x_1024_1024, bias - net = Net(second, layout1) - with pytest.raises(RuntimeError): - compile_net(net, first) - -def test_layout_extend_bias_add_self_define_outputlayout_not_allowed(): - """ - Feature: test layout extend - Description: dev_num is 8. - Expectation: compile failed - """ - context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) - layout = Layout((2, 2, 2), ("dp", "sp", "mp")) - layout1 = (layout(("dp", "sp"), "mp"), layout(("dp", "sp"), "mp")) - out_layout = (layout(("dp", "sp"), "mp"),) - first, second = x_1024_1024, bias - net = Net(second, layout1, out_layout) - with pytest.raises(RuntimeError): - compile_net(net, first) +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest + +import mindspore as ms +import mindspore.nn as nn +from mindspore import Tensor, Parameter +from mindspore import context +from mindspore.common.api import _cell_graph_executor +from mindspore.ops import operations as P +from mindspore.ops import composite as C +from mindspore.parallel.shard import Layout +from tests.ut.python.ops.test_math_ops import VirtualLoss +from parallel.utils.utils import ParallelValidator + + +def setup_function(): + context.set_auto_parallel_context(dataset_strategy="full_batch") + +grad_all = C.GradOperation(get_all=True) + +class NetWithLoss(nn.Cell): + def __init__(self, network): + super(NetWithLoss, self).__init__() + self.loss = VirtualLoss() + self.network = network + + def construct(self, y): + predict = self.network(y) + return self.loss(predict) + + +class GradWrap(nn.Cell): + def __init__(self, network): + super(GradWrap, self).__init__() + self.network = network + + def construct(self, y): + return grad_all(self.network)(y) + + +def compile_net(net, input_x): + net.set_auto_parallel() + net.set_train() + phase, _ = _cell_graph_executor.compile(net, input_x) + return phase + + +class Net(nn.Cell): + def __init__(self, weight, in_layout, out_layout=None): + super().__init__() + self.bias_add = P.BiasAdd().shard(in_strategy=in_layout, out_strategy=out_layout) + self.relu = P.ReLU() + self.w = Parameter(weight, "w1") + + def construct(self, y): + out1 = self.bias_add(y, self.w) + out2 = self.relu(out1) + out = out1 + out2 + return out + +x_1_1024 = Tensor(np.ones([1, 1024]), dtype=ms.float32) +x_1024_1024 = Tensor(np.ones([1024, 1024]), dtype=ms.float32) +bias = Tensor(np.ones([1024]), dtype=ms.float32) + +def test_layout_extend_bias_add_shard(): + """ + Feature: test layout extend + Description: dev_num is 8. + Expectation: compile success + """ + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + layout = Layout((2, 2, 2), ("dp", "sp", "mp")) + layout1 = (layout("mp", ("dp", "sp")), layout(("dp", "sp"),)) + first, second = x_1024_1024, bias + net = Net(second, layout1) + phase = compile_net(net, first) + validator = ParallelValidator(net, phase) + assert validator.check_parameter_shape('w1', [256]) + +def test_layout_extend_bias_add_channel_shard_different(): + """ + Feature: test layout extend + Description: dev_num is 8. + Expectation: compile failed, channel shard different + """ + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + layout = Layout((2, 2, 2), ("dp", "sp", "mp")) + layout1 = (layout(("dp", "sp"), "mp"), layout("sp",)) + first, second = x_1024_1024, bias + net = Net(second, layout1) + with pytest.raises(RuntimeError): + compile_net(net, first) + +def test_layout_extend_bias_add_self_define_outputlayout_not_allowed(): + """ + Feature: test layout extend + Description: dev_num is 8. + Expectation: compile failed + """ + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + layout = Layout((2, 2, 2), ("dp", "sp", "mp")) + layout1 = (layout(("dp", "sp"), "mp"), layout(("dp", "sp"), "mp")) + out_layout = (layout(("dp", "sp"), "mp"),) + first, second = x_1024_1024, bias + net = Net(second, layout1, out_layout) + with pytest.raises(RuntimeError): + compile_net(net, first) diff --git a/tests/ut/python/parallel/test_loss_and_o2_level.py b/tests/ut/python/parallel/test_loss_and_o2_level.py old mode 100755 new mode 100644 diff --git a/tests/ut/python/rewrite/test_obfuscate/test_create_input.py b/tests/ut/python/rewrite/test_obfuscate/test_create_input.py index e26ee121ddf..ea39d2d6860 100644 --- a/tests/ut/python/rewrite/test_obfuscate/test_create_input.py +++ b/tests/ut/python/rewrite/test_obfuscate/test_create_input.py @@ -1,59 +1,59 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""test create input""" - -import mindspore.nn as nn -from mindspore.rewrite import SymbolTree, NodeType - -class MyNet(nn.Cell): - def __init__(self): - super().__init__() - self.conv = nn.Conv2d(16, 16, 3) - self.relu1 = nn.ReLU() - self.relu2 = nn.ReLU() - self.relu3 = nn.ReLU() - - def construct(self, x, y): - x = self.conv(x) - x = self.relu1(x) - x = self.relu2(x) - x = self.relu3(x) - return x - -def test_create_input(): - """ - Feature: Create an input node. - Description: Call create_input to create an input node. - Expectation: Success. - """ - net = MyNet() - stree = SymbolTree.create(net) - - assert len(stree.get_handler().nodes()) == 7 - node = stree.get_node("input_y") - assert node - assert node.get_node_type() == NodeType.Input - position = stree.after(node) - assert position - new_input_node = node.create_input("z") - assert new_input_node - assert new_input_node.get_node_type() == NodeType.Input - - stree.insert(position, new_input_node) - assert stree.get_handler().get_node('input_z') - assert len(stree.get_handler().nodes()) == 8 - - codes = stree.get_code() - assert codes.count("construct(self, x, y, z=None)") == 1 +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test create input""" + +import mindspore.nn as nn +from mindspore.rewrite import SymbolTree, NodeType + +class MyNet(nn.Cell): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(16, 16, 3) + self.relu1 = nn.ReLU() + self.relu2 = nn.ReLU() + self.relu3 = nn.ReLU() + + def construct(self, x, y): + x = self.conv(x) + x = self.relu1(x) + x = self.relu2(x) + x = self.relu3(x) + return x + +def test_create_input(): + """ + Feature: Create an input node. + Description: Call create_input to create an input node. + Expectation: Success. + """ + net = MyNet() + stree = SymbolTree.create(net) + + assert len(stree.get_handler().nodes()) == 7 + node = stree.get_node("input_y") + assert node + assert node.get_node_type() == NodeType.Input + position = stree.after(node) + assert position + new_input_node = node.create_input("z") + assert new_input_node + assert new_input_node.get_node_type() == NodeType.Input + + stree.insert(position, new_input_node) + assert stree.get_handler().get_node('input_z') + assert len(stree.get_handler().nodes()) == 8 + + codes = stree.get_code() + assert codes.count("construct(self, x, y, z=None)") == 1 diff --git a/tests/ut/python/rewrite/test_obfuscate/test_decorator.py b/tests/ut/python/rewrite/test_obfuscate/test_decorator.py index 85559437893..c78e6f96908 100644 --- a/tests/ut/python/rewrite/test_obfuscate/test_decorator.py +++ b/tests/ut/python/rewrite/test_obfuscate/test_decorator.py @@ -1,106 +1,106 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""test decorator.""" - -import mindspore.nn as nn -import inspect -from mindspore.rewrite import SymbolTree -from functools import wraps - -try: - from mindspore._checkparam import Validator -except ImportError: - import mindspore._checkparam as Validator - - -def my_decorator(func): - @wraps(func) - def wrapper(*args, **kwargs): - func(*args, **kwargs) - - return wrapper - - -def _args_type_validator_check(*type_args, **type_kwargs): - """Check whether input data type is correct.""" - - def type_check(func): - sig = inspect.signature(func) - bound_types = sig.bind_partial(*type_args, **type_kwargs).arguments - - @wraps(func) - def wrapper(*args, **kwargs): - nonlocal bound_types - bound_values = sig.bind(*args, **kwargs) - - argument_dict = bound_values.arguments - if "kwargs" in bound_types: - bound_types = bound_types["kwargs"] - if "kwargs" in argument_dict: - argument_dict = argument_dict["kwargs"] - for name, value in argument_dict.items(): - if name in bound_types: - bound_types[name](value, name) - return func(*args, **kwargs) - - return wrapper - - return type_check - - -def register_denied_func_decorators(fn): - """user deny certain decorators""" - from mindspore.rewrite.parsers.class_def_parser import ClassDefParser - name = "denied_function_decorator_list" - setattr(ClassDefParser, name, fn) - - -class MyNet(nn.Cell): - @my_decorator - @_args_type_validator_check(in_channels=Validator.check_positive_int) - def __init__(self, in_channels): - super(MyNet, self).__init__() - self.conv = nn.Conv2d(16, 16, 3) - self.dense = nn.Dense(in_channels=in_channels, out_channels=32, weight_init="ones") - self.relu1 = nn.ReLU() - self.relu2 = nn.ReLU() - self.relu3 = nn.ReLU() - - def construct(self, x): - x = self.conv(x) - x = self.dense(x) - x = self.relu1(x) - x = self.relu2(x) - x = self.relu3(x) - return x - - -def test_decorator(): - """ - Feature: parse decorators - Description: parse decorators of function which are allowed according to users. - Expectation: Success. - """ - # the decorator "_args_type_validator_check" is denied - register_denied_func_decorators(["_args_type_validator_check"]) - net = MyNet(32) - stree = SymbolTree.create(net) - codes = stree.get_code() - - # @my_decorator is allowed - assert codes.count("@my_decorator") == 1 - - # @_args_type_validator_check is denied - assert codes.count("@_args_type_validator_check") == 0 +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test decorator.""" + +import mindspore.nn as nn +import inspect +from mindspore.rewrite import SymbolTree +from functools import wraps + +try: + from mindspore._checkparam import Validator +except ImportError: + import mindspore._checkparam as Validator + + +def my_decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + func(*args, **kwargs) + + return wrapper + + +def _args_type_validator_check(*type_args, **type_kwargs): + """Check whether input data type is correct.""" + + def type_check(func): + sig = inspect.signature(func) + bound_types = sig.bind_partial(*type_args, **type_kwargs).arguments + + @wraps(func) + def wrapper(*args, **kwargs): + nonlocal bound_types + bound_values = sig.bind(*args, **kwargs) + + argument_dict = bound_values.arguments + if "kwargs" in bound_types: + bound_types = bound_types["kwargs"] + if "kwargs" in argument_dict: + argument_dict = argument_dict["kwargs"] + for name, value in argument_dict.items(): + if name in bound_types: + bound_types[name](value, name) + return func(*args, **kwargs) + + return wrapper + + return type_check + + +def register_denied_func_decorators(fn): + """user deny certain decorators""" + from mindspore.rewrite.parsers.class_def_parser import ClassDefParser + name = "denied_function_decorator_list" + setattr(ClassDefParser, name, fn) + + +class MyNet(nn.Cell): + @my_decorator + @_args_type_validator_check(in_channels=Validator.check_positive_int) + def __init__(self, in_channels): + super(MyNet, self).__init__() + self.conv = nn.Conv2d(16, 16, 3) + self.dense = nn.Dense(in_channels=in_channels, out_channels=32, weight_init="ones") + self.relu1 = nn.ReLU() + self.relu2 = nn.ReLU() + self.relu3 = nn.ReLU() + + def construct(self, x): + x = self.conv(x) + x = self.dense(x) + x = self.relu1(x) + x = self.relu2(x) + x = self.relu3(x) + return x + + +def test_decorator(): + """ + Feature: parse decorators + Description: parse decorators of function which are allowed according to users. + Expectation: Success. + """ + # the decorator "_args_type_validator_check" is denied + register_denied_func_decorators(["_args_type_validator_check"]) + net = MyNet(32) + stree = SymbolTree.create(net) + codes = stree.get_code() + + # @my_decorator is allowed + assert codes.count("@my_decorator") == 1 + + # @_args_type_validator_check is denied + assert codes.count("@_args_type_validator_check") == 0 diff --git a/tests/ut/python/runtest.sh b/tests/ut/python/runtest.sh old mode 100755 new mode 100644 diff --git a/tests/ut/runtest.sh b/tests/ut/runtest.sh old mode 100755 new mode 100644 diff --git a/third_party/patch/c-ares/CVE-2021-3672.patch b/third_party/patch/c-ares/CVE-2021-3672.patch deleted file mode 100644 index 31276e27324..00000000000 --- a/third_party/patch/c-ares/CVE-2021-3672.patch +++ /dev/null @@ -1,105 +0,0 @@ -diff -Npur c-ares-1.15.0/ares_expand_name.c c-ares-1.15.0-new/ares_expand_name.c ---- c-ares-1.15.0/ares_expand_name.c 2017-07-03 17:04:19.000000000 +0800 -+++ c-ares-1.15.0-new/ares_expand_name.c 2021-08-21 22:48:24.650973166 +0800 -@@ -38,6 +38,26 @@ - static int name_length(const unsigned char *encoded, const unsigned char *abuf, - int alen); - -+/* Reserved characters for names that need to be escaped */ -+static int is_reservedch(int ch) -+{ -+ switch (ch) { -+ case '"': -+ case '.': -+ case ';': -+ case '\\': -+ case '(': -+ case ')': -+ case '@': -+ case '$': -+ return 1; -+ default: -+ break; -+ } -+ -+ return 0; -+} -+ - /* Expand an RFC1035-encoded domain name given by encoded. The - * containing message is given by abuf and alen. The result given by - * *s, which is set to a NUL-terminated allocated buffer. *enclen is -@@ -113,18 +133,37 @@ int ares_expand_name(const unsigned char - } - else - { -- len = *p; -+ int name_len = *p; -+ len = name_len; - p++; -+ - while (len--) - { -- if (*p == '.' || *p == '\\') -- *q++ = '\\'; -- *q++ = *p; -+ /* Output as \DDD for consistency with RFC1035 5.1, except -+ * for the special case of a root name response */ -+ if (!isprint(*p) && !(name_len == 1 && *p == 0)) -+ { -+ -+ *q++ = '\\'; -+ *q++ = '0' + *p / 100; -+ *q++ = '0' + (*p % 100) / 10; -+ *q++ = '0' + (*p % 10); -+ } -+ else if (is_reservedch(*p)) -+ { -+ *q++ = '\\'; -+ *q++ = *p; -+ } -+ else -+ { -+ *q++ = *p; -+ } - p++; - } - *q++ = '.'; - } -- } -+ } -+ - if (!indir) - *enclen = aresx_uztosl(p + 1U - encoded); - -@@ -171,15 +210,29 @@ static int name_length(const unsigned ch - } - else if (top == 0x00) - { -- offset = *encoded; -+ int name_len = *encoded; -+ offset = name_len; - if (encoded + offset + 1 >= abuf + alen) - return -1; - encoded++; -+ - while (offset--) - { -- n += (*encoded == '.' || *encoded == '\\') ? 2 : 1; -+ if (!isprint(*encoded) && !(name_len == 1 && *encoded == 0)) -+ { -+ n += 4; -+ } -+ else if (is_reservedch(*encoded)) -+ { -+ n += 2; -+ } -+ else -+ { -+ n += 1; -+ } - encoded++; - } -+ - n++; - } - else diff --git a/third_party/patch/c-ares/CVE-2022-4904.patch b/third_party/patch/c-ares/CVE-2022-4904.patch deleted file mode 100644 index 66800ffa882..00000000000 --- a/third_party/patch/c-ares/CVE-2022-4904.patch +++ /dev/null @@ -1,33 +0,0 @@ -diff -Npur c-ares-1.15.0/ares_init.c c-ares-1.15.0-change/ares_init.c ---- c-ares-1.15.0/ares_init.c 2018-10-11 05:20:12.000000000 +0800 -+++ c-ares-1.15.0-change/ares_init.c 2023-03-13 15:27:50.068885382 +0800 -@@ -2205,6 +2205,8 @@ static int config_sortlist(struct apatte - q = str; - while (*q && *q != '/' && *q != ';' && !ISSPACE(*q)) - q++; -+ if (q-str >= 16) -+ return ARES_EBADSTR; - memcpy(ipbuf, str, q-str); - ipbuf[q-str] = '\0'; - /* Find the prefix */ -@@ -2213,6 +2215,8 @@ static int config_sortlist(struct apatte - const char *str2 = q+1; - while (*q && *q != ';' && !ISSPACE(*q)) - q++; -+ if (q-str >= 32) -+ return ARES_EBADSTR; - memcpy(ipbufpfx, str, q-str); - ipbufpfx[q-str] = '\0'; - str = str2; -diff -Npur c-ares-1.15.0/test/ares-test-init.cc c-ares-1.15.0-change/test/ares-test-init.cc ---- c-ares-1.15.0/test/ares-test-init.cc 2018-10-11 05:20:12.000000000 +0800 -+++ c-ares-1.15.0-change/test/ares-test-init.cc 2023-03-13 15:29:02.560881555 +0800 -@@ -270,6 +270,8 @@ TEST_F(DefaultChannelTest, SetAddresses) - - TEST_F(DefaultChannelTest, SetSortlistFailures) { - EXPECT_EQ(ARES_ENODATA, ares_set_sortlist(nullptr, "1.2.3.4")); -+ EXPECT_EQ(ARES_EBADSTR, ares_set_sortlist(channel_, "111.111.111.111*/16")); -+ EXPECT_EQ(ARES_EBADSTR, ares_set_sortlist(channel_, "111.111.111.111/255.255.255.240*")); - EXPECT_EQ(ARES_SUCCESS, ares_set_sortlist(channel_, "xyzzy ; lwk")); - EXPECT_EQ(ARES_SUCCESS, ares_set_sortlist(channel_, "xyzzy ; 0x123")); - } diff --git a/third_party/patch/cucollections/0001-refine-bitwise-compare.patch b/third_party/patch/cucollections/0001-refine-bitwise-compare.patch deleted file mode 100644 index cafa15b4a88..00000000000 --- a/third_party/patch/cucollections/0001-refine-bitwise-compare.patch +++ /dev/null @@ -1,53 +0,0 @@ ---- - include/cuco/detail/pair.cuh | 3 +-- - include/cuco/traits.hpp | 23 +---------------------- - 2 files changed, 2 insertions(+), 24 deletions(-) - -diff --git a/include/cuco/detail/pair.cuh b/include/cuco/detail/pair.cuh -index 7ea3988..ade6df3 100644 ---- a/include/cuco/detail/pair.cuh -+++ b/include/cuco/detail/pair.cuh -@@ -131,8 +131,7 @@ template - constexpr bool is_packable() - { -- return not std::is_void>::value and -- std::has_unique_object_representations_v; -+ return false; - } - - /** -diff --git a/include/cuco/traits.hpp b/include/cuco/traits.hpp -index 445a40d..948b587 100644 ---- a/include/cuco/traits.hpp -+++ b/include/cuco/traits.hpp -@@ -34,28 +34,7 @@ namespace cuco { - * other `NaN` bit patterns. - * - */ --template --struct is_bitwise_comparable : std::false_type { --}; - --/// By default, only types with unique object representations are allowed - template --struct is_bitwise_comparable>> -- : std::true_type { --}; -- --template --inline constexpr bool is_bitwise_comparable_v = is_bitwise_comparable::value; -- --/** -- * @brief Declares that a type `Type` is bitwise comparable. -- * -- */ --#define CUCO_DECLARE_BITWISE_COMPARABLE(Type) \ -- namespace cuco { \ -- template <> \ -- struct is_bitwise_comparable : std::true_type { \ -- }; \ -- } -- -+inline constexpr bool is_bitwise_comparable_v = true; - } // namespace cuco diff --git a/third_party/patch/cucollections/0002-add-get-api-of-dynamic_map.patch b/third_party/patch/cucollections/0002-add-get-api-of-dynamic_map.patch deleted file mode 100644 index 9fdccccb276..00000000000 --- a/third_party/patch/cucollections/0002-add-get-api-of-dynamic_map.patch +++ /dev/null @@ -1,105 +0,0 @@ ---- - include/cuco/dynamic_map.cuh | 62 ++++++++++++++++++++++++++++++++++-- - include/cuco/traits.hpp | 2 +- - 2 files changed, 60 insertions(+), 4 deletions(-) - -diff --git a/include/cuco/dynamic_map.cuh b/include/cuco/dynamic_map.cuh -index 866f948..af3ea03 100644 ---- a/include/cuco/dynamic_map.cuh -+++ b/include/cuco/dynamic_map.cuh -@@ -103,8 +103,8 @@ class dynamic_map { - using key_type = Key; ///< Key type - using mapped_type = Value; ///< Type of mapped values - using atomic_ctr_type = cuda::atomic; ///< Type of atomic counters -- using view_type = typename static_map::device_view; ///< Device view type -- using mutable_view_type = typename static_map::device_mutable_view; -+ using view_type = typename static_map::device_view; ///< Device view type -+ using mutable_view_type = typename static_map::device_mutable_view; - ///< Device mutable view type - - dynamic_map(dynamic_map const&) = delete; -@@ -248,6 +248,62 @@ class dynamic_map { - */ - float get_load_factor() const noexcept { return static_cast(size_) / capacity_; } - -+ /** -+ * @brief Update the size of the hash map. -+ * -+ * @param size The number of the size to be updated. -+ */ -+ void update_size(std::size_t size) { -+ size_ = size; -+ } -+ -+ /** -+ * @brief Update the size of the submap. -+ * -+ * @param submap_idx The index of submap whose size need to be updated. -+ * @param size The number of the size of submap to be updated. -+ */ -+ void update_submap_size(std::size_t submap_idx, std::size_t size) { -+ submaps_[submap_idx]->size_ = size; -+ } -+ -+ /** -+ * @brief Gets the all submaps of the hash map. -+ * -+ * @return The all submaps of the hash map. -+ */ -+ const std::vector>>& get_submaps() const noexcept { -+ return submaps_; -+ } -+ -+ /** -+ * @brief Gets the all mutable views for all submaps of the hash map. -+ * -+ * @return All mutable views for all submaps of the hash map. -+ */ -+ thrust::device_vector& get_submap_mutable_views() noexcept { return submap_mutable_views_; } -+ -+ /** -+ * @brief Gets the all mutable views for all submaps of the hash map. -+ * -+ * @return All mutable views for all submaps of the hash map. -+ */ -+ thrust::device_vector& get_submap_views() noexcept { return submap_views_; } -+ -+ /** -+ * @brief Gets the max load factor of the hash map. -+ * -+ * @return The max load factor of the hash map. -+ */ -+ float get_max_load_factor() const noexcept { return max_load_factor_; } -+ -+ /** -+ * @brief Gets minimum insert size of the hash map. -+ * -+ * @return The minimum insert size of the hash map. -+ */ -+ std::size_t get_min_insert_size() const noexcept { return min_insert_size_; } -+ - private: - key_type empty_key_sentinel_{}; ///< Key value that represents an empty slot - mapped_type empty_value_sentinel_{}; ///< Initial value of empty slot -@@ -255,7 +311,7 @@ class dynamic_map { - std::size_t capacity_{}; ///< Maximum number of keys that can be inserted - float max_load_factor_{}; ///< Max load factor before capacity growth - -- std::vector>> -+ std::vector>> - submaps_; ///< vector of pointers to each submap - thrust::device_vector submap_views_; ///< vector of device views for each submap - thrust::device_vector -diff --git a/include/cuco/traits.hpp b/include/cuco/traits.hpp -index 948b587..b7fbbc4 100644 ---- a/include/cuco/traits.hpp -+++ b/include/cuco/traits.hpp -@@ -16,7 +16,7 @@ - - #pragma once - --#include -+// #include - - namespace cuco { - diff --git a/third_party/patch/cucollections/0003-add-erase-and-export-api.patch b/third_party/patch/cucollections/0003-add-erase-and-export-api.patch deleted file mode 100644 index 10e925d1534..00000000000 --- a/third_party/patch/cucollections/0003-add-erase-and-export-api.patch +++ /dev/null @@ -1,357 +0,0 @@ ---- - include/cuco/detail/bitwise_compare.cuh | 1 + - include/cuco/detail/dynamic_map.inl | 98 ++++++++++++++++++++- - include/cuco/detail/dynamic_map_kernels.cuh | 83 +++++++++++++++++ - include/cuco/dynamic_map.cuh | 58 +++++++++++- - 4 files changed, 236 insertions(+), 4 deletions(-) - -diff --git a/include/cuco/detail/bitwise_compare.cuh b/include/cuco/detail/bitwise_compare.cuh -index 3038943..4bd58c2 100644 ---- a/include/cuco/detail/bitwise_compare.cuh -+++ b/include/cuco/detail/bitwise_compare.cuh -@@ -18,6 +18,7 @@ - - #include - #include -+#include - - namespace cuco { - namespace detail { -diff --git a/include/cuco/detail/dynamic_map.inl b/include/cuco/detail/dynamic_map.inl -index 0c1d2e3..2425c7d 100644 ---- a/include/cuco/detail/dynamic_map.inl -+++ b/include/cuco/detail/dynamic_map.inl -@@ -21,30 +21,68 @@ dynamic_map::dynamic_map( - std::size_t initial_capacity, - sentinel::empty_key empty_key_sentinel, - sentinel::empty_value empty_value_sentinel, -- Allocator const& alloc) -+ Allocator const& alloc, -+ cudaStream_t stream) - : empty_key_sentinel_(empty_key_sentinel.value), - empty_value_sentinel_(empty_value_sentinel.value), - size_(0), - capacity_(initial_capacity), - min_insert_size_(1E4), - max_load_factor_(0.60), -+ counter_allocator_{alloc}, - alloc_{alloc} - { - submaps_.push_back(std::make_unique>( - initial_capacity, - sentinel::empty_key{empty_key_sentinel}, - sentinel::empty_value{empty_value_sentinel}, -- alloc)); -+ alloc, stream)); - submap_views_.push_back(submaps_[0]->get_device_view()); - submap_mutable_views_.push_back(submaps_[0]->get_device_mutable_view()); - - CUCO_CUDA_TRY(cudaMallocManaged(&num_successes_, sizeof(atomic_ctr_type))); --} // namespace cuco -+ d_submaps_erase_num_successes_ = std::allocator_traits::allocate(counter_allocator_, max_num_submaps_); -+ CUCO_CUDA_TRY(cudaMallocHost(&h_submaps_erase_num_successes_, sizeof(atomic_ctr_type) * (max_num_submaps_))); -+} -+ -+template -+dynamic_map::dynamic_map( -+ std::size_t initial_capacity, -+ sentinel::empty_key empty_key_sentinel, -+ sentinel::empty_value empty_value_sentinel, -+ sentinel::erased_key erased_key_sentinel, -+ Allocator const& alloc, -+ cudaStream_t stream) -+ : empty_key_sentinel_(empty_key_sentinel.value), -+ empty_value_sentinel_(empty_value_sentinel.value), -+ erased_key_sentinel_{erased_key_sentinel.value}, -+ size_(0), -+ capacity_(initial_capacity), -+ min_insert_size_(1E4), -+ max_load_factor_(0.60), -+ counter_allocator_{alloc}, -+ alloc_{alloc} -+{ -+ submaps_.push_back(std::make_unique>( -+ initial_capacity, -+ sentinel::empty_key{empty_key_sentinel}, -+ sentinel::empty_value{empty_value_sentinel}, -+ sentinel::erased_key{erased_key_sentinel}, -+ alloc, stream)); -+ submap_views_.push_back(submaps_[0]->get_device_view()); -+ submap_mutable_views_.push_back(submaps_[0]->get_device_mutable_view()); -+ -+ CUCO_CUDA_TRY(cudaMallocManaged(&num_successes_, sizeof(atomic_ctr_type))); -+ d_submaps_erase_num_successes_ = std::allocator_traits::allocate(counter_allocator_, max_num_submaps_); -+ CUCO_CUDA_TRY(cudaMallocHost(&h_submaps_erase_num_successes_, sizeof(atomic_ctr_type) * (max_num_submaps_))); -+} - - template - dynamic_map::~dynamic_map() - { - CUCO_ASSERT_CUDA_SUCCESS(cudaFree(num_successes_)); -+ std::allocator_traits::deallocate(counter_allocator_, d_submaps_erase_num_successes_ , max_num_submaps_); -+ CUCO_ASSERT_CUDA_SUCCESS(cudaFreeHost(reinterpret_cast(h_submaps_erase_num_successes_))); - } - - template -@@ -75,6 +113,9 @@ void dynamic_map::reserve(std::size_t n) - - num_elements_remaining -= max_load_factor_ * submap_capacity - min_insert_size_; - submap_idx++; -+ if (submap_idx > max_num_submaps_) { -+ throw std::runtime_error("The number of submaps exceeds the maximum[256]"); -+ } - } - } - -@@ -160,4 +201,55 @@ void dynamic_map::contains( - CUCO_CUDA_TRY(cudaDeviceSynchronize()); - } - -+template -+template -+void dynamic_map::erase(InputIt first, InputIt last, -+ cudaStream_t stream, Hash hash, KeyEqual key_equal) { -+ auto num_keys = std::distance(first, last); -+ if (num_keys == 0) { return; } -+ -+ auto constexpr block_size = 128; -+ auto constexpr stride = 1; -+ auto constexpr tile_size = 4; -+ auto const grid_size = (tile_size * num_keys + stride * block_size - 1) / (stride * block_size); -+ -+ static_assert(sizeof(std::size_t) == sizeof(atomic_ctr_type)); -+ for(size_t i = 0; i < max_num_submaps_; i++) { -+ h_submaps_erase_num_successes_[i] = 0; -+ } -+ -+ CUCO_CUDA_TRY(cudaMemcpyAsync( -+ d_submaps_erase_num_successes_, h_submaps_erase_num_successes_, submaps_.size() * sizeof(atomic_ctr_type), -+ cudaMemcpyHostToDevice, stream)); -+ -+ detail::erase<<>>( -+ first, first + num_keys, d_submaps_erase_num_successes_, submap_mutable_views_.data().get(), submaps_.size(), hash, key_equal); -+ -+ CUCO_CUDA_TRY(cudaMemcpyAsync( -+ h_submaps_erase_num_successes_, d_submaps_erase_num_successes_, submaps_.size() * sizeof(atomic_ctr_type), -+ cudaMemcpyDeviceToHost, stream)); -+ -+ CUCO_CUDA_TRY(cudaStreamSynchronize(stream)); -+ for(size_t submap_idx = 0; submap_idx < submaps_.size(); submap_idx++){ -+ submaps_[submap_idx]->size_ -= h_submaps_erase_num_successes_[submap_idx]; -+ size_ -= h_submaps_erase_num_successes_[submap_idx]; -+ } -+} -+ -+template -+bool dynamic_map::get_keys_values(Key *keys, Value *values, cudaStream_t stream) { -+ *num_successes_ = 0; -+ int device_id; -+ CUCO_CUDA_TRY(cudaGetDevice(&device_id)); -+ CUCO_CUDA_TRY(cudaMemPrefetchAsync(num_successes_, sizeof(atomic_ctr_type), device_id)); -+ -+ auto const block_size = 128; -+ auto const stride = 1; -+ auto const grid_size = (size_ + stride * block_size - 1) / (stride * block_size); -+ detail::get_keys_values<<>>(submaps_.size(), submap_views_.data().get(), num_successes_, keys, values); -+ -+ CUCO_CUDA_TRY(cudaStreamSynchronize(stream)); -+ size_t h_num_successes = num_successes_->load(cuda::std::memory_order_relaxed); -+ return h_num_successes == size_; -+} - } // namespace cuco -diff --git a/include/cuco/detail/dynamic_map_kernels.cuh b/include/cuco/detail/dynamic_map_kernels.cuh -index f261b49..75b2c07 100644 ---- a/include/cuco/detail/dynamic_map_kernels.cuh -+++ b/include/cuco/detail/dynamic_map_kernels.cuh -@@ -20,6 +20,7 @@ - #include - - #include -+#include - - namespace cuco { - namespace detail { -@@ -463,5 +464,87 @@ __global__ void contains(InputIt first, - key_idx += (gridDim.x * blockDim.x) / tile_size; - } - } -+ -+template -+__global__ void erase( -+ InputIt first, InputIt last, atomicT* num_successes, viewT* views, std::size_t num_submaps, Hash hash, KeyEqual key_equal) -+{ -+ extern __shared__ atomicT local_num_successes[]; -+ -+ if (threadIdx.x < num_submaps) { -+ local_num_successes[threadIdx.x] = 0; -+ } -+ __syncthreads(); -+ -+ auto tile = cg::tiled_partition(cg::this_thread_block()); -+ auto tid = block_size * blockIdx.x + threadIdx.x; -+ auto it = first + tid / tile_size; -+ -+ while (it < last) { -+ for (auto submap_idx = 0; submap_idx < num_submaps; ++submap_idx) { -+ if (views[submap_idx].erase(tile, *it, hash, key_equal)) { -+ if (tile.thread_rank() == 0) { -+ local_num_successes[submap_idx] += 1; -+ } -+ break; -+ } -+ } -+ it += (gridDim.x * block_size) / tile_size; -+ } -+ -+ __syncthreads(); -+ if (threadIdx.x < num_submaps) { -+ num_successes[threadIdx.x] += local_num_successes[threadIdx.x]; -+ } -+} -+ -+template -+__global__ void get_keys_values(size_t num_submaps, ViewType *submap_views, AtomicType* global_cnt, Key* keys, Value*values) { -+ __shared__ size_t global_offset; -+ extern __shared__ AtomicType local_cnt[]; -+ const int default_offset_sentinel = -1; -+ -+ for (size_t submap_idx = 0; submap_idx < num_submaps; submap_idx++){ -+ auto & submap_view = submap_views[submap_idx]; -+ -+ for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < submap_view.get_capacity(); -+ tid += blockDim.x * gridDim.x) { -+ if (threadIdx.x == 0) { -+ local_cnt[0] = 0; -+ } -+ __syncthreads(); -+ -+ auto current_slot = submap_view.begin_slot() + tid; -+ const Key & current_key = current_slot->first.load(cuda::std::memory_order_relaxed); -+ auto const slot_not_idle = -+ !detail::bitwise_compare(current_key, submap_view.get_empty_key_sentinel()) && -+ !detail::bitwise_compare(current_key, submap_view.get_erased_key_sentinel()); -+ -+ int local_offset = default_offset_sentinel; -+ if(slot_not_idle) { -+ local_offset = local_cnt[0].fetch_add(1, cuda::std::memory_order_relaxed); -+ } -+ __syncthreads(); -+ -+ if (threadIdx.x == 0) { -+ auto local_cnt_value = local_cnt[0].load(cuda::std::memory_order_relaxed); -+ global_offset = global_cnt->fetch_add(local_cnt_value, cuda::std::memory_order_relaxed); -+ } -+ __syncthreads(); -+ -+ if (local_offset > default_offset_sentinel) { -+ auto offset = global_offset + local_offset; -+ keys[offset] = current_key; -+ values[offset] = current_slot->second.load(cuda::std::memory_order_relaxed); -+ } -+ } -+ } -+} - } // namespace detail - } // namespace cuco -diff --git a/include/cuco/dynamic_map.cuh b/include/cuco/dynamic_map.cuh -index af3ea03..9ed2f25 100644 ---- a/include/cuco/dynamic_map.cuh -+++ b/include/cuco/dynamic_map.cuh -@@ -105,6 +105,8 @@ class dynamic_map { - using atomic_ctr_type = cuda::atomic; ///< Type of atomic counters - using view_type = typename static_map::device_view; ///< Device view type - using mutable_view_type = typename static_map::device_mutable_view; -+ using counter_allocator_type = typename static_map::counter_allocator_type; -+ - ///< Device mutable view type - - dynamic_map(dynamic_map const&) = delete; -@@ -135,7 +137,36 @@ class dynamic_map { - dynamic_map(std::size_t initial_capacity, - sentinel::empty_key empty_key_sentinel, - sentinel::empty_value empty_value_sentinel, -- Allocator const& alloc = Allocator{}); -+ Allocator const& alloc = Allocator{}, -+ cudaStream_t stream = 0); -+ -+ /** -+ * @brief Construct a dynamically-sized map with the specified initial capacity, growth factor and -+ * sentinel values. -+ * -+ * The capacity of the map will automatically increase as the user adds key/value pairs using -+ * `insert`. -+ * -+ * Capacity increases by a factor of growth_factor each time the size of the map exceeds a -+ * threshold occupancy. The performance of `find` and `contains` decreases somewhat each time the -+ * map's capacity grows. -+ * -+ * The `empty_key_sentinel` and `empty_value_sentinel` values are reserved and -+ * undefined behavior results from attempting to insert any key/value pair -+ * that contains either. -+ * -+ * @param initial_capacity The initial number of slots in the map -+ * @param empty_key_sentinel The reserved key value for empty slots -+ * @param empty_value_sentinel The reserved mapped value for empty slots -+ * @param erased_key_sentinel The reserved value to denote erased slots -+ * @param alloc Allocator used to allocate submap device storage -+ */ -+dynamic_map(std::size_t initial_capacity, -+ sentinel::empty_key empty_key_sentinel, -+ sentinel::empty_value empty_value_sentinel, -+ sentinel::erased_key erased_key_sentinel, -+ Allocator const& alloc = Allocator(), -+ cudaStream_t stream = 0); - - /** - * @brief Destroy the map and frees its contents -@@ -227,6 +258,25 @@ class dynamic_map { - Hash hash = Hash{}, - KeyEqual key_equal = KeyEqual{}); - -+ template , -+ typename KeyEqual = thrust::equal_to> -+ void erase(InputIt first, -+ InputIt last, -+ cudaStream_t stream = 0, -+ Hash hash = Hash{}, -+ KeyEqual key_equal = KeyEqual{}); -+ -+ /** -+ * @brief Get all keys and values in the hash map. -+ * -+ * @param keys The output parameter, pointing the buffer which will maintain all keys in the hash map. -+ * @param values The output parameter, pointing the buffer which will maintain all values in the hash map. -+ * @param stream The cuda stream. -+ * @return Whether export keys and values successfully. -+ */ -+ bool get_keys_values(Key *keys, Value *values, cudaStream_t stream = 0); -+ - /** - * @brief Gets the current number of elements in the map - * -@@ -307,6 +357,7 @@ class dynamic_map { - private: - key_type empty_key_sentinel_{}; ///< Key value that represents an empty slot - mapped_type empty_value_sentinel_{}; ///< Initial value of empty slot -+ key_type erased_key_sentinel_{}; ///< Key value that represents an erased slot - std::size_t size_{}; ///< Number of keys in the map - std::size_t capacity_{}; ///< Maximum number of keys that can be inserted - float max_load_factor_{}; ///< Max load factor before capacity growth -@@ -319,6 +370,11 @@ class dynamic_map { - std::size_t min_insert_size_{}; ///< min remaining capacity of submap for insert - atomic_ctr_type* num_successes_; ///< number of successfully inserted keys on insert - Allocator alloc_{}; ///< Allocator passed to submaps to allocate their device storage -+ -+ counter_allocator_type counter_allocator_{}; ///< Allocator used to allocate counters -+ atomic_ctr_type* d_submaps_erase_num_successes_; ///< number of successfully erased keys on erase, atomic on device. -+ atomic_ctr_type* h_submaps_erase_num_successes_; ///< number of successfully erased keys on erase, atomic on host. -+ const size_t max_num_submaps_ = 256; ///< The max number of submaps. - }; - } // namespace cuco diff --git a/third_party/patch/cucollections/0004-bugfix-for-reserve-find-insert-and-erase-api.patch b/third_party/patch/cucollections/0004-bugfix-for-reserve-find-insert-and-erase-api.patch deleted file mode 100644 index ce628e74b9b..00000000000 --- a/third_party/patch/cucollections/0004-bugfix-for-reserve-find-insert-and-erase-api.patch +++ /dev/null @@ -1,230 +0,0 @@ -From c858e60ee9839b6d9d3346528b5825b42df4e887 Mon Sep 17 00:00:00 2001 -From: lizhenyu -Date: Fri, 2 Dec 2022 11:42:05 +0800 -Subject: [PATCH] bugfix for reserve find insert and erase api - ---- - include/cuco/detail/dynamic_map.inl | 9 ++++--- - include/cuco/detail/static_map.inl | 39 +++++++++++++++++++++++++++++ - include/cuco/dynamic_map.cuh | 2 ++ - include/cuco/static_map.cuh | 8 ++++++ - 4 files changed, 54 insertions(+), 4 deletions(-) - -diff --git a/include/cuco/detail/dynamic_map.inl b/include/cuco/detail/dynamic_map.inl -index 2425c7d..e64032f 100644 ---- a/include/cuco/detail/dynamic_map.inl -+++ b/include/cuco/detail/dynamic_map.inl -@@ -104,6 +104,7 @@ void dynamic_map::reserve(std::size_t n) - submap_capacity, - sentinel::empty_key{empty_key_sentinel_}, - sentinel::empty_value{empty_value_sentinel_}, -+ sentinel::erased_key{erased_key_sentinel_}, - alloc_)); - submap_views_.push_back(submaps_[submap_idx]->get_device_view()); - submap_mutable_views_.push_back(submaps_[submap_idx]->get_device_mutable_view()); -@@ -172,7 +173,7 @@ void dynamic_map::insert(InputIt first, - template - template - void dynamic_map::find( -- InputIt first, InputIt last, OutputIt output_begin, Hash hash, KeyEqual key_equal) -+ InputIt first, InputIt last, OutputIt output_begin, cudaStream_t stream, Hash hash, KeyEqual key_equal) - { - auto num_keys = std::distance(first, last); - auto const block_size = 128; -@@ -180,9 +181,9 @@ void dynamic_map::find( - auto const tile_size = 4; - auto const grid_size = (tile_size * num_keys + stride * block_size - 1) / (stride * block_size); - -- detail::find<<>>( -+ detail::find<<>>( - first, last, output_begin, submap_views_.data().get(), submaps_.size(), hash, key_equal); -- CUCO_CUDA_TRY(cudaDeviceSynchronize()); -+ CUCO_CUDA_TRY(cudaStreamSynchronize(stream)); - } - - template -@@ -241,7 +242,7 @@ bool dynamic_map::get_keys_values(Key *keys, Value - *num_successes_ = 0; - int device_id; - CUCO_CUDA_TRY(cudaGetDevice(&device_id)); -- CUCO_CUDA_TRY(cudaMemPrefetchAsync(num_successes_, sizeof(atomic_ctr_type), device_id)); -+ CUCO_CUDA_TRY(cudaMemPrefetchAsync(num_successes_, sizeof(atomic_ctr_type), device_id, stream)); - - auto const block_size = 128; - auto const stride = 1; -diff --git a/include/cuco/detail/static_map.inl b/include/cuco/detail/static_map.inl -index a1bb4f0..d7c40bb 100644 ---- a/include/cuco/detail/static_map.inl -+++ b/include/cuco/detail/static_map.inl -@@ -49,6 +49,7 @@ static_map::static_map( - detail::initialize - <<>>( - slots_, empty_key_sentinel_, empty_value_sentinel_, capacity_); -+ CUCO_CUDA_TRY(cudaStreamSynchronize(stream)); - } - - template -@@ -78,6 +79,7 @@ static_map::static_map( - detail::initialize - <<>>( - slots_, empty_key_sentinel_, empty_value_sentinel_, capacity_); -+ CUCO_CUDA_TRY(cudaStreamSynchronize(stream)); - } - - template -@@ -429,6 +431,10 @@ __device__ bool static_map::device_mutable_view::i - { - auto current_slot = initial_slot(g, insert_pair.first, hash); - -+ size_t search_count = 0; -+ // The maximum number of times the target key is searched (that is, traversed through -+ // the entire slot) to prevent an endless loop. -+ size_t max_search_count = static_cast(std::ceil(static_cast(this->get_capacity()) / g.size())); - while (true) { - key_type const existing_key = current_slot->first.load(cuda::std::memory_order_relaxed); - -@@ -481,6 +487,9 @@ __device__ bool static_map::device_mutable_view::i - // if there are no empty slots in the current window, - // we move onto the next window - else { -+ if (++search_count >= max_search_count) { -+ return false; -+ } - current_slot = next_slot(g, current_slot); - } - } -@@ -537,6 +546,10 @@ template - __device__ bool static_map::device_mutable_view::erase( - CG const& g, key_type const& k, Hash hash, KeyEqual key_equal) noexcept - { -+ size_t search_count = 0; -+ // The maximum number of times the target key is searched (that is, traversed through -+ // the entire slot) to prevent an endless loop. -+ size_t max_search_count = static_cast(std::ceil(static_cast(this->get_capacity()) / g.size())); - auto current_slot = initial_slot(g, k, hash); - value_type const insert_pair = - make_pair(this->get_erased_key_sentinel(), this->get_empty_value_sentinel()); -@@ -586,6 +599,10 @@ __device__ bool static_map::device_mutable_view::e - // empty slot found, but key not found, must not be in the map - if (g.ballot(slot_is_empty)) { return false; } - -+ if (++search_count >= max_search_count) { -+ return false; -+ } -+ - current_slot = next_slot(g, current_slot); - } - } -@@ -645,6 +662,10 @@ static_map::device_view::find(CG g, - KeyEqual key_equal) noexcept - { - auto current_slot = initial_slot(g, k, hash); -+ size_t search_count = 0; -+ // The maximum number of times the target key is searched (that is, traversed through -+ // the entire slot) to prevent an endless loop. -+ size_t max_search_count = static_cast(std::ceil(static_cast(this->get_capacity()) / g.size())); - - while (true) { - auto const existing_key = current_slot->first.load(cuda::std::memory_order_relaxed); -@@ -668,6 +689,10 @@ static_map::device_view::find(CG g, - // we found an empty slot, meaning that the key we're searching for isn't present - if (g.ballot(slot_is_empty)) { return this->end(); } - -+ if (++search_count >= max_search_count) { -+ return this->end(); -+ } -+ - // otherwise, all slots in the current window are full with other keys, so we move onto the - // next window - current_slot = next_slot(g, current_slot); -@@ -684,6 +709,10 @@ static_map::device_view::find(CG g, - { - auto current_slot = initial_slot(g, k, hash); - -+ size_t search_count = 0; -+ // The maximum number of times the target key is searched (that is, traversed through -+ // the entire slot) to prevent an endless loop. -+ size_t max_search_count = static_cast(std::ceil(static_cast(this->get_capacity()) / g.size())); - while (true) { - auto const existing_key = current_slot->first.load(cuda::std::memory_order_relaxed); - -@@ -710,6 +739,9 @@ static_map::device_view::find(CG g, - // otherwise, all slots in the current window are full with other keys, - // so we move onto the next window in the current submap - -+ if (++search_count >= max_search_count ){ -+ return this->end(); -+ } - current_slot = next_slot(g, current_slot); - } - } -@@ -742,6 +774,10 @@ static_map::device_view::contains(CG const& g, - { - auto current_slot = initial_slot(g, k, hash); - -+ size_t search_count = 0; -+ // The maximum number of times the target key is searched (that is, traversed through -+ // the entire slot) to prevent an endless loop. -+ size_t max_search_count = static_cast(std::ceil(static_cast(this->get_capacity()) / g.size())); - while (true) { - key_type const existing_key = current_slot->first.load(cuda::std::memory_order_relaxed); - -@@ -757,6 +793,9 @@ static_map::device_view::contains(CG const& g, - // we found an empty slot, meaning that the key we're searching for isn't present - if (g.ballot(slot_is_empty)) { return false; } - -+ if (++search_count >= max_search_count) { -+ return false; -+ } - // otherwise, all slots in the current window are full with other keys, so we move onto the - // next window - current_slot = next_slot(g, current_slot); -diff --git a/include/cuco/dynamic_map.cuh b/include/cuco/dynamic_map.cuh -index 9ed2f25..f82a4c6 100644 ---- a/include/cuco/dynamic_map.cuh -+++ b/include/cuco/dynamic_map.cuh -@@ -218,6 +218,7 @@ dynamic_map(std::size_t initial_capacity, - * @param first Beginning of the sequence of keys - * @param last End of the sequence of keys - * @param output_begin Beginning of the sequence of values retrieved for each key -+ * @param stream The cuda stream to enqueue the find operator. - * @param hash The unary function to apply to hash each key - * @param key_equal The binary function to compare two keys for equality - */ -@@ -228,6 +229,7 @@ dynamic_map(std::size_t initial_capacity, - void find(InputIt first, - InputIt last, - OutputIt output_begin, -+ cudaStream_t stream = 0, - Hash hash = Hash{}, - KeyEqual key_equal = KeyEqual{}); - -diff --git a/include/cuco/static_map.cuh b/include/cuco/static_map.cuh -index 4a329ce..072f81f 100644 ---- a/include/cuco/static_map.cuh -+++ b/include/cuco/static_map.cuh -@@ -459,6 +459,7 @@ class static_map { - erased_key_sentinel_{erased_key_sentinel.value}, - empty_value_sentinel_{empty_value_sentinel.value} - { -+ assert(erased_key_sentinel_ != empty_key_sentinel_); - } - - /** -@@ -1337,6 +1338,13 @@ class static_map { - */ - std::size_t get_size() const noexcept { return size_; } - -+ /** -+ * @brief Gets the slots of elements in the hash map. -+ * -+ * @return The the slots of elements in the hash map. -+ */ -+ pair_atomic_type* get_slots() const noexcept { return slots_; } -+ - /** - * @brief Gets the load factor of the hash map. - * --- -2.17.1 - diff --git a/third_party/patch/eigen/0001-fix-eigen.patch b/third_party/patch/eigen/0001-fix-eigen.patch deleted file mode 100644 index 7b046bd71ba..00000000000 --- a/third_party/patch/eigen/0001-fix-eigen.patch +++ /dev/null @@ -1,38 +0,0 @@ ---- a/Eigen/src/Core/arch/NEON/PacketMath.h -+++ b/Eigen/src/Core/arch/NEON/PacketMath.h -@@ -1668,7 +1668,7 @@ - template<> EIGEN_STRONG_INLINE Packet4c pload(const int8_t* from) - { - Packet4c res; -- memcpy(&res, from, sizeof(Packet4c)); -+ memcpy(static_cast(&res), from, sizeof(Packet4c)); - return res; - } - template<> EIGEN_STRONG_INLINE Packet8c pload(const int8_t* from) -@@ -1678,7 +1678,7 @@ - template<> EIGEN_STRONG_INLINE Packet4uc pload(const uint8_t* from) - { - Packet4uc res; -- memcpy(&res, from, sizeof(Packet4uc)); -+ memcpy(static_cast(&res), from, sizeof(Packet4uc)); - return res; - } - template<> EIGEN_STRONG_INLINE Packet8uc pload(const uint8_t* from) -@@ -1713,7 +1713,7 @@ - template<> EIGEN_STRONG_INLINE Packet4c ploadu(const int8_t* from) - { - Packet4c res; -- memcpy(&res, from, sizeof(Packet4c)); -+ memcpy(static_cast(&res), from, sizeof(Packet4c)); - return res; - } - template<> EIGEN_STRONG_INLINE Packet8c ploadu(const int8_t* from) -@@ -1723,7 +1723,7 @@ - template<> EIGEN_STRONG_INLINE Packet4uc ploadu(const uint8_t* from) - { - Packet4uc res; -- memcpy(&res, from, sizeof(Packet4uc)); -+ memcpy(static_cast(&res), from, sizeof(Packet4uc)); - return res; - } - template<> EIGEN_STRONG_INLINE Packet8uc ploadu(const uint8_t* from) \ No newline at end of file diff --git a/third_party/patch/fast_transformer/001-fast_transformer.patch b/third_party/patch/fast_transformer/001-fast_transformer.patch deleted file mode 100644 index e3a6542d2e1..00000000000 --- a/third_party/patch/fast_transformer/001-fast_transformer.patch +++ /dev/null @@ -1,615816 +0,0 @@ -diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml -new file mode 100644 -index 0000000..18054db ---- /dev/null -+++ b/.github/ISSUE_TEMPLATE/bug_report.yml -@@ -0,0 +1,32 @@ -+name: "Bug Report" -+description: Submit a bug report -+labels: [ "bug" ] -+body: -+ - type: textarea -+ id: description -+ attributes: -+ label: Description -+ description: Please share your system info with us. -+ render: shell -+ placeholder: branch, docker version, GPU type -+ validations: -+ required: true -+ -+ - type: textarea -+ id: reproduced-steps -+ attributes: -+ label: Reproduced Steps -+ description: Please provide the step to reproduce the bugs -+ render: shell -+ placeholder: | -+ Steps to reproduce your bugs: -+ -+ 1. docker run -ti --gpus all nvcr.io/nvidia/pytorch:22.03-py3 bash -+ 2. git clone https://github.com/NVIDIA/FasterTransformer.git -+ 3. cd FasterTransformer mkdir build && cd build -+ 4. cmake -DSM=80 -DCMAKE_BUILD_TYPE=Release .. && make -j12 -+ 5. ./bin/bert_example 32 12 32 12 64 0 0 -+ 6. What error you see. -+ -+ validations: -+ required: true -diff --git a/.gitignore b/.gitignore -index d15caca..0bca65e 100644 ---- a/.gitignore -+++ b/.gitignore -@@ -6,3 +6,4 @@ __pycache__/ - .vscode - ./translation - .cache -+ 001-fast_transformer.patch -diff --git a/.vscode/settings.json b/.vscode/settings.json -deleted file mode 100644 -index 6f535da..0000000 ---- a/.vscode/settings.json -+++ /dev/null -@@ -1,72 +0,0 @@ --{ -- "files.associations": { -- "*.cuh": "cpp", -- "stdexcept": "cpp", -- "chrono": "cpp", -- "cmath": "cpp", -- "type_traits": "cpp", -- "cctype": "cpp", -- "clocale": "cpp", -- "cstdarg": "cpp", -- "cstddef": "cpp", -- "cstdio": "cpp", -- "cstdlib": "cpp", -- "cstring": "cpp", -- "ctime": "cpp", -- "cwchar": "cpp", -- "cwctype": "cpp", -- "array": "cpp", -- "atomic": "cpp", -- "*.tcc": "cpp", -- "condition_variable": "cpp", -- "cstdint": "cpp", -- "deque": "cpp", -- "unordered_map": "cpp", -- "vector": "cpp", -- "exception": "cpp", -- "algorithm": "cpp", -- "functional": "cpp", -- "iterator": "cpp", -- "map": "cpp", -- "memory": "cpp", -- "memory_resource": "cpp", -- "numeric": "cpp", -- "optional": "cpp", -- "random": "cpp", -- "ratio": "cpp", -- "set": "cpp", -- "string": "cpp", -- "string_view": "cpp", -- "system_error": "cpp", -- "tuple": "cpp", -- "utility": "cpp", -- "fstream": "cpp", -- "initializer_list": "cpp", -- "iomanip": "cpp", -- "iosfwd": "cpp", -- "iostream": "cpp", -- "istream": "cpp", -- "limits": "cpp", -- "mutex": "cpp", -- "new": "cpp", -- "ostream": "cpp", -- "sstream": "cpp", -- "streambuf": "cpp", -- "thread": "cpp", -- "cinttypes": "cpp", -- "typeinfo": "cpp", -- "bitset": "cpp", -- "hash_map": "cpp", -- "hash_set": "cpp", -- "slist": "cpp", -- "regex": "cpp", -- "strstream": "cpp", -- "complex": "cpp", -- "forward_list": "cpp", -- "list": "cpp", -- "unordered_set": "cpp", -- "future": "cpp", -- "cfenv": "cpp", -- "typeindex": "cpp" -- } --} -\ No newline at end of file -diff --git a/3rdparty/cutlass/LICENSE.txt b/3rdparty/cutlass/LICENSE.txt -new file mode 100644 -index 0000000..2913ab8 ---- /dev/null -+++ b/3rdparty/cutlass/LICENSE.txt -@@ -0,0 +1,27 @@ -+Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+SPDX-License-Identifier: BSD-3-Clause -+ -+Redistribution and use in source and binary forms, with or without -+modification, are permitted provided that the following conditions are met: -+ -+1. Redistributions of source code must retain the above copyright notice, this -+list of conditions and the following disclaimer. -+ -+2. Redistributions in binary form must reproduce the above copyright notice, -+this list of conditions and the following disclaimer in the documentation -+and/or other materials provided with the distribution. -+ -+3. Neither the name of the copyright holder nor the names of its -+contributors may be used to endorse or promote products derived from -+this software without specific prior written permission. -+ -+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -diff --git a/3rdparty/cutlass/cmake/nop.cu b/3rdparty/cutlass/cmake/nop.cu -new file mode 100644 -index 0000000..efdb035 ---- /dev/null -+++ b/3rdparty/cutlass/cmake/nop.cu -@@ -0,0 +1,49 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Basic CUDA file for testing compiler flags. -+*/ -+ -+__device__ int inner() -+{ -+ return -1; -+} -+ -+__global__ void test() -+{ -+ inner(); -+} -+ -+int main() -+{ -+ test<<<1,1>>>(); -+ return 0; -+} -diff --git a/3rdparty/cutlass/examples/00_basic_gemm/basic_gemm.cu b/3rdparty/cutlass/examples/00_basic_gemm/basic_gemm.cu -new file mode 100644 -index 0000000..57df36b ---- /dev/null -+++ b/3rdparty/cutlass/examples/00_basic_gemm/basic_gemm.cu -@@ -0,0 +1,497 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/* -+ This example demonstrates how to call a CUTLASS GEMM kernel and provides a naive reference -+ matrix multiply kernel to verify its correctness. -+ -+ The CUTLASS Gemm template is instantiated in the function CutlassSgemmNN. This is kernel computes -+ the general matrix product (GEMM) using single-precision floating-point arithmetic and assumes -+ all matrices have column-major layout. -+ -+ The threadblock tile size is chosen as 128x128x8 which offers good performance for large matrices. -+ See the CUTLASS Parallel for All blog post for more exposition on the tunable parameters available -+ in CUTLASS. -+ -+ https://devblogs.nvidia.com/cutlass-linear-algebra-cuda/ -+ -+ Aside from defining and launching the SGEMM kernel, this example does not use any other components -+ or utilities within CUTLASS. Such utilities are demonstrated elsewhere in other examples and are -+ prevalent in the CUTLASS unit tests. -+ -+ This example has delibrately been kept similar to the basic_gemm example from cutass-1.3 to -+ highlight the minimum amount of differences needed to transition to cutlass-2.0. -+ -+ Cutlass-1.3 sgemm: https://github.com/NVIDIA/cutlass/blob/master/examples/00_basic_gemm/basic_gemm.cu -+*/ -+ -+// Standard Library includes -+#include -+#include -+#include -+ -+// Helper methods to check for errors -+#include "helper.h" -+ -+// -+// CUTLASS includes needed for single-precision GEMM kernel -+// -+ -+// Defines cutlass::gemm::device::Gemm, the generic Gemm computation template class. -+#include "cutlass/gemm/device/gemm.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// This function defines a CUTLASS GEMM kernel instantiation, constructs its parameters object, -+// and launches it on the CUDA device. -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Define a CUTLASS GEMM template and launch a GEMM kernel. -+cudaError_t CutlassSgemmNN( -+ int M, -+ int N, -+ int K, -+ float alpha, -+ float const *A, -+ int lda, -+ float const *B, -+ int ldb, -+ float beta, -+ float *C, -+ int ldc) { -+ -+ // Define type definition for single-precision CUTLASS GEMM with column-major -+ // input matrices and 128x128x8 threadblock tile size (chosen by default). -+ // -+ // To keep the interface manageable, several helpers are defined for plausible compositions -+ // including the following example for single-precision GEMM. Typical values are used as -+ // default template arguments. See `cutlass/gemm/device/default_gemm_configuration.h` for more details. -+ // -+ // To view the full gemm device API interface, see `cutlass/gemm/device/gemm.h` -+ -+ using ColumnMajor = cutlass::layout::ColumnMajor; -+ -+ using CutlassGemm = cutlass::gemm::device::Gemm; // Layout of C matrix -+ -+ // Define a CUTLASS GEMM type -+ CutlassGemm gemm_operator; -+ -+ // Construct the CUTLASS GEMM arguments object. -+ // -+ // One of CUTLASS's design patterns is to define gemm argument objects that are constructible -+ // in host code and passed to kernels by value. These may include pointers, strides, scalars, -+ // and other arguments needed by Gemm and its components. -+ // -+ // The benefits of this pattern are (1.) a structured, composable strategy for passing host-constructible -+ // arguments to kernels and (2.) minimized initialization overhead on kernel entry. -+ // -+ CutlassGemm::Arguments args({M , N, K}, // Gemm Problem dimensions -+ {A, lda}, // Tensor-ref for source matrix A -+ {B, ldb}, // Tensor-ref for source matrix B -+ {C, ldc}, // Tensor-ref for source matrix C -+ {C, ldc}, // Tensor-ref for destination matrix D (may be different memory than source C matrix) -+ {alpha, beta}); // Scalars used in the Epilogue -+ -+ // -+ // Launch the CUTLASS GEMM kernel. -+ // -+ -+ cutlass::Status status = gemm_operator(args); -+ -+ // -+ // Return a cudaError_t if the CUTLASS GEMM operator returned an error code. -+ // -+ -+ if (status != cutlass::Status::kSuccess) { -+ return cudaErrorUnknown; -+ } -+ -+ // Return success, if no errors were encountered. -+ return cudaSuccess; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// The source code after this point in the file is generic CUDA using the CUDA Runtime API -+// and simple CUDA kernels to initialize matrices and compute the general matrix product. -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Kernel to initialize a matrix with small integers. -+__global__ void InitializeMatrix_kernel( -+ float *matrix, -+ int rows, -+ int columns, -+ int seed = 0) { -+ -+ int i = threadIdx.x + blockIdx.x * blockDim.x; -+ int j = threadIdx.y + blockIdx.y * blockDim.y; -+ -+ if (i < rows && j < columns) { -+ int offset = i + j * rows; -+ -+ // Generate arbitrary elements. -+ int const k = 16807; -+ int const m = 16; -+ float value = float(((offset + seed) * k % m) - m / 2); -+ -+ matrix[offset] = value; -+ } -+} -+ -+/// Simple function to initialize a matrix to arbitrary small integers. -+cudaError_t InitializeMatrix(float *matrix, int rows, int columns, int seed = 0) { -+ -+ dim3 block(16, 16); -+ dim3 grid( -+ (rows + block.x - 1) / block.x, -+ (columns + block.y - 1) / block.y -+ ); -+ -+ InitializeMatrix_kernel<<< grid, block >>>(matrix, rows, columns, seed); -+ -+ return cudaGetLastError(); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Allocates device memory for a matrix then fills with arbitrary small integers. -+cudaError_t AllocateMatrix(float **matrix, int rows, int columns, int seed = 0) { -+ cudaError_t result; -+ -+ size_t sizeof_matrix = sizeof(float) * rows * columns; -+ -+ // Allocate device memory. -+ result = cudaMalloc(reinterpret_cast(matrix), sizeof_matrix); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to allocate matrix: " -+ << cudaGetErrorString(result) << std::endl; -+ return result; -+ } -+ -+ // Clear the allocation. -+ result = cudaMemset(*matrix, 0, sizeof_matrix); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to clear matrix device memory: " -+ << cudaGetErrorString(result) << std::endl; -+ return result; -+ } -+ -+ // Initialize matrix elements to arbitrary small integers. -+ result = InitializeMatrix(*matrix, rows, columns, seed); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to initialize matrix: " -+ << cudaGetErrorString(result) << std::endl; -+ return result; -+ } -+ -+ return result; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Naive reference GEMM computation. -+__global__ void ReferenceGemm_kernel( -+ int M, -+ int N, -+ int K, -+ float alpha, -+ float const *A, -+ int lda, -+ float const *B, -+ int ldb, -+ float beta, -+ float *C, -+ int ldc) { -+ -+ int i = threadIdx.x + blockIdx.x * blockDim.x; -+ int j = threadIdx.y + blockIdx.y * blockDim.y; -+ -+ if (i < M && j < N) { -+ float accumulator = 0; -+ -+ for (int k = 0; k < K; ++k) { -+ accumulator += A[i + k * lda] * B[k + j * ldb]; -+ } -+ -+ C[i + j * ldc] = alpha * accumulator + beta * C[i + j * ldc]; -+ } -+} -+ -+/// Reference GEMM computation. -+cudaError_t ReferenceGemm( -+ int M, -+ int N, -+ int K, -+ float alpha, -+ float const *A, -+ int lda, -+ float const *B, -+ int ldb, -+ float beta, -+ float *C, -+ int ldc) { -+ -+ dim3 block(16, 16); -+ dim3 grid( -+ (M + block.x - 1) / block.x, -+ (N + block.y - 1) / block.y -+ ); -+ -+ ReferenceGemm_kernel<<< grid, block >>>(M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); -+ -+ return cudaGetLastError(); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Allocate several matrices in GPU device memory and call a single-precision -+/// CUTLASS GEMM kernel. -+cudaError_t TestCutlassGemm(int M, int N, int K, float alpha, float beta) { -+ cudaError_t result; -+ -+ // -+ // Define several matrices to be used as operands to GEMM kernels. -+ // -+ -+ // Compute leading dimensions for each matrix. -+ int lda = M; -+ int ldb = K; -+ int ldc = M; -+ -+ // Compute size in bytes of the C matrix. -+ size_t sizeof_C = sizeof(float) * ldc * N; -+ -+ // Define pointers to matrices in GPU device memory. -+ float *A; -+ float *B; -+ float *C_cutlass; -+ float *C_reference; -+ -+ // -+ // Allocate matrices in GPU device memory with arbitrary seeds. -+ // -+ -+ result = AllocateMatrix(&A, M, K, 0); -+ -+ if (result != cudaSuccess) { -+ return result; -+ } -+ -+ result = AllocateMatrix(&B, K, N, 17); -+ -+ if (result != cudaSuccess) { -+ cudaFree(A); -+ return result; -+ } -+ -+ result = AllocateMatrix(&C_cutlass, M, N, 101); -+ -+ if (result != cudaSuccess) { -+ cudaFree(A); -+ cudaFree(B); -+ return result; -+ } -+ -+ result = AllocateMatrix(&C_reference, M, N, 101); -+ -+ if (result != cudaSuccess) { -+ cudaFree(A); -+ cudaFree(B); -+ cudaFree(C_cutlass); -+ return result; -+ } -+ -+ result = cudaMemcpy(C_reference, C_cutlass, sizeof_C, cudaMemcpyDeviceToDevice); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to copy C_cutlass matrix to C_reference: " -+ << cudaGetErrorString(result) << std::endl; -+ -+ cudaFree(C_reference); -+ cudaFree(C_cutlass); -+ cudaFree(B); -+ cudaFree(A); -+ -+ return result; -+ } -+ -+ // -+ // Launch CUTLASS GEMM. -+ // -+ -+ result = CutlassSgemmNN(M, N, K, alpha, A, lda, B, ldb, beta, C_cutlass, ldc); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "CUTLASS GEMM kernel failed: " -+ << cudaGetErrorString(result) << std::endl; -+ -+ cudaFree(C_reference); -+ cudaFree(C_cutlass); -+ cudaFree(B); -+ cudaFree(A); -+ -+ return result; -+ } -+ -+ // -+ // Verify. -+ // -+ -+ // Launch reference GEMM -+ result = ReferenceGemm(M, N, K, alpha, A, lda, B, ldb, beta, C_reference, ldc); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Reference GEMM kernel failed: " -+ << cudaGetErrorString(result) << std::endl; -+ -+ cudaFree(C_reference); -+ cudaFree(C_cutlass); -+ cudaFree(B); -+ cudaFree(A); -+ -+ return result; -+ } -+ -+ // Copy to host and verify equivalence. -+ std::vector host_cutlass(ldc * N, 0); -+ std::vector host_reference(ldc * N, 0); -+ -+ result = cudaMemcpy(host_cutlass.data(), C_cutlass, sizeof_C, cudaMemcpyDeviceToHost); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to copy CUTLASS GEMM results: " -+ << cudaGetErrorString(result) << std::endl; -+ -+ cudaFree(C_reference); -+ cudaFree(C_cutlass); -+ cudaFree(B); -+ cudaFree(A); -+ -+ return result; -+ } -+ -+ result = cudaMemcpy(host_reference.data(), C_reference, sizeof_C, cudaMemcpyDeviceToHost); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to copy Reference GEMM results: " -+ << cudaGetErrorString(result) << std::endl; -+ -+ cudaFree(C_reference); -+ cudaFree(C_cutlass); -+ cudaFree(B); -+ cudaFree(A); -+ -+ return result; -+ } -+ -+ // -+ // Free device memory allocations. -+ // -+ -+ cudaFree(C_reference); -+ cudaFree(C_cutlass); -+ cudaFree(B); -+ cudaFree(A); -+ -+ // -+ // Test for bit equivalence of results. -+ // -+ -+ if (host_cutlass != host_reference) { -+ std::cerr << "CUTLASS results incorrect." << std::endl; -+ -+ return cudaErrorUnknown; -+ } -+ -+ return cudaSuccess; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Entry point to basic_gemm example. -+// -+// usage: -+// -+// 00_basic_gemm -+// -+int main(int argc, const char *arg[]) { -+ -+ // -+ // Parse the command line to obtain GEMM dimensions and scalar values. -+ // -+ -+ // GEMM problem dimensions. -+ int problem[3] = { 128, 128, 128 }; -+ -+ for (int i = 1; i < argc && i < 4; ++i) { -+ std::stringstream ss(arg[i]); -+ ss >> problem[i - 1]; -+ } -+ -+ // Scalars used for linear scaling the result of the matrix product. -+ float scalars[2] = { 1, 0 }; -+ -+ for (int i = 4; i < argc && i < 6; ++i) { -+ std::stringstream ss(arg[i]); -+ ss >> scalars[i - 4]; -+ } -+ -+ // -+ // Run the CUTLASS GEMM test. -+ // -+ -+ cudaError_t result = TestCutlassGemm( -+ problem[0], // GEMM M dimension -+ problem[1], // GEMM N dimension -+ problem[2], // GEMM K dimension -+ scalars[0], // alpha -+ scalars[1] // beta -+ ); -+ -+ if (result == cudaSuccess) { -+ std::cout << "Passed." << std::endl; -+ } -+ -+ // Exit. -+ return result == cudaSuccess ? 0 : -1; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/01_cutlass_utilities/cutlass_utilities.cu b/3rdparty/cutlass/examples/01_cutlass_utilities/cutlass_utilities.cu -new file mode 100644 -index 0000000..f4cc4d0 ---- /dev/null -+++ b/3rdparty/cutlass/examples/01_cutlass_utilities/cutlass_utilities.cu -@@ -0,0 +1,400 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/* -+ This example demonstrates several CUTLASS utilities in the context of a mixed-precision -+ floating-point matrix product computation. -+ -+ These utilities are intended to be useful supporting components for managing tensor and matrix -+ memory allocations, initializing and comparing results, and computing reference output. -+ -+ CUTLASS utilities are defined in the directory `tools/util`, and definitions appear -+ namespace `cutlass::` or an inner namespace therein. Operations in `cutlass::reference::` have -+ both host-side and device-side implementations, and the choice to use device-side initialization -+ and host-side verification in this example was arbitrary. -+ -+ -+ cutlass::half_t -+ -+ This is a numeric type implementing IEEE half-precision quantities. It is functional in host -+ and device code. In host-side code, CUTLASS_ENABLE_F16C optionally enables harware-accelerated -+ numeric conversion on x86-64 CPUs support F16C extensions. In device code, all available -+ hardware is used to implement conversion and numeric operations. -+ -+ -+ cutlass::HostTensor<> -+ -+ This template class simplifies the creation of tensors for all supported layouts. It simplifies -+ allocation and management of host- and device- memory allocations. -+ -+ This class offers methods device_view() and host_view() to provide TensorView objects for -+ device- and host-side memory allocations. -+ -+ -+ cutlass::reference::device::TensorFillRandomGaussian() -+ -+ This template function initializes elementsof a tensor to a random Gaussian distribution. It -+ uses cuRAND in device code to compute random numbers. -+ -+ -+ cutlass::reference::host::Gemm<> -+ -+ This template function computes the general matrix product. This template supports unique -+ data types for each matrix operand, the internal accumulation type, and the scalar parameters -+ alpha and beta. -+ -+ -+ cutlass::reference::host::TensorEquals() -+ -+ Compares two tensors of identical rank and returns true if values are bit equivalent. -+ -+*/ -+ -+// Standard Library includes -+#include -+#include -+#include -+#include -+ -+// CUTLASS includes needed for half-precision GEMM kernel -+#include "cutlass/cutlass.h" -+#include "cutlass/core_io.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+// -+// CUTLASS utility includes -+// -+ -+// Defines operator<<() to write TensorView objects to std::ostream -+#include "cutlass/util/tensor_view_io.h" -+ -+// Defines cutlass::HostTensor<> -+#include "cutlass/util/host_tensor.h" -+ -+// Defines cutlass::half_t -+#include "cutlass/numeric_types.h" -+ -+// Defines device_memory::copy_device_to_device() -+#include "cutlass/util/device_memory.h" -+ -+// Defines cutlass::reference::device::TensorFillRandomGaussian() -+#include "cutlass/util/reference/device/tensor_fill.h" -+ -+// Defines cutlass::reference::host::TensorEquals() -+#include "cutlass/util/reference/host/tensor_compare.h" -+ -+// Defines cutlass::reference::host::Gemm() -+#include "cutlass/util/reference/host/gemm.h" -+ -+#pragma warning( disable : 4503) -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Define a CUTLASS GEMM template and launch a GEMM kernel. -+cudaError_t cutlass_hgemm_nn( -+ int M, -+ int N, -+ int K, -+ cutlass::half_t alpha, -+ cutlass::half_t const *A, -+ cutlass::layout::ColumnMajor::Stride::Index lda, -+ cutlass::half_t const *B, -+ cutlass::layout::ColumnMajor::Stride::Index ldb, -+ cutlass::half_t beta, -+ cutlass::half_t *C, -+ cutlass::layout::ColumnMajor::Stride::Index ldc) { -+ -+ // Define the GEMM operation -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, // ElementA -+ cutlass::layout::ColumnMajor, // LayoutA -+ cutlass::half_t, // ElementB -+ cutlass::layout::ColumnMajor, // LayoutB -+ cutlass::half_t, // ElementOutput -+ cutlass::layout::ColumnMajor // LayoutOutput -+ >; -+ -+ Gemm gemm_op; -+ -+ cutlass::Status status = gemm_op({ -+ {M, N, K}, -+ {A, lda}, -+ {B, ldb}, -+ {C, ldc}, -+ {C, ldc}, -+ {alpha, beta} -+ }); -+ -+ if (status != cutlass::Status::kSuccess) { -+ return cudaErrorUnknown; -+ } -+ -+ return cudaSuccess; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Allocate several matrices in GPU device memory and call a single-precision -+/// CUTLASS GEMM kernel. -+cudaError_t TestCutlassGemm(int M, int N, int K, cutlass::half_t alpha, cutlass::half_t beta) { -+ cudaError_t result; -+ -+ // -+ // Construct cutlass::HostTensor<> using the half-precision host-side type. -+ // -+ // cutlass::HostTensor<> allocates memory on both the host and device corresponding to rank=2 -+ // tensors in column-major layout. Explicit synchronization methods are offered to copy the -+ // tensor to the device or to the host. -+ // -+ -+ // M-by-K matrix of cutlass::half_t -+ cutlass::HostTensor A(cutlass::MatrixCoord(M, K)); -+ -+ // K-by-N matrix of cutlass::half_t -+ cutlass::HostTensor B(cutlass::MatrixCoord(K, N)); -+ -+ // M-by-N matrix of cutlass::half_t -+ cutlass::HostTensor C_cutlass(cutlass::MatrixCoord(M, N)); -+ -+ // M-by-N matrix of cutlass::half_t -+ cutlass::HostTensor C_reference(cutlass::MatrixCoord(M, N)); -+ -+ // -+ // Initialize matrices with small, random integers. -+ // -+ -+ // Arbitrary RNG seed value. Hard-coded for deterministic results. -+ uint64_t seed = 2080; -+ -+ // Gaussian random distribution -+ cutlass::half_t mean = 0.0_hf; -+ cutlass::half_t stddev = 5.0_hf; -+ -+ // Specify the number of bits right of the binary decimal that are permitted -+ // to be non-zero. A value of "0" here truncates random values to integers -+ int bits_less_than_one = 0; -+ -+ cutlass::reference::device::TensorFillRandomGaussian( -+ A.device_view(), -+ seed, -+ mean, -+ stddev, -+ bits_less_than_one -+ ); -+ -+ cutlass::reference::device::TensorFillRandomGaussian( -+ B.device_view(), -+ seed * 2019, -+ mean, -+ stddev, -+ bits_less_than_one -+ ); -+ -+ cutlass::reference::device::TensorFillRandomGaussian( -+ C_cutlass.device_view(), -+ seed * 1993, -+ mean, -+ stddev, -+ bits_less_than_one -+ ); -+ -+ -+ // Copy C_cutlass into C_reference so the GEMM is correct when beta != 0. -+ cutlass::device_memory::copy_device_to_device( -+ C_reference.device_data(), -+ C_cutlass.device_data(), -+ C_cutlass.capacity()); -+ -+ // Copy the device-side view into host memory -+ C_reference.sync_host(); -+ -+ // -+ // Launch the CUTLASS GEMM kernel -+ // -+ -+ result = cutlass_hgemm_nn( -+ M, -+ N, -+ K, -+ alpha, -+ A.device_data(), -+ A.stride(0), -+ B.device_data(), -+ B.stride(0), -+ beta, -+ C_cutlass.device_data(), -+ C_cutlass.stride(0) -+ ); -+ -+ if (result != cudaSuccess) { -+ return result; -+ } -+ -+ // -+ // Verify the result using a host-side reference -+ // -+ -+ // A and B were initialized using device-side procedures. The intent of this example is to -+ // use the host-side reference GEMM, so we must perform a device-to-host copy. -+ A.sync_host(); -+ B.sync_host(); -+ -+ // Copy CUTLASS's GEMM results into host memory. -+ C_cutlass.sync_host(); -+ -+ // Compute the reference result using the host-side GEMM reference implementation. -+ cutlass::reference::host::Gemm< -+ cutlass::half_t, // ElementA -+ cutlass::layout::ColumnMajor, // LayoutA -+ cutlass::half_t, // ElementB -+ cutlass::layout::ColumnMajor, // LayoutB -+ cutlass::half_t, // ElementOutput -+ cutlass::layout::ColumnMajor, // LayoutOutput -+ cutlass::half_t, -+ cutlass::half_t -+ > gemm_ref; -+ -+ gemm_ref( -+ {M, N, K}, // problem size (type: cutlass::gemm::GemmCoord) -+ alpha, // alpha (type: cutlass::half_t) -+ A.host_ref(), // A (type: TensorRef) -+ B.host_ref(), // B (type: TensorRef) -+ beta, // beta (type: cutlass::half_t) -+ C_reference.host_ref() // C (type: TensorRef) -+ ); -+ -+ // Compare reference to computed results. -+ if (!cutlass::reference::host::TensorEquals( -+ C_reference.host_view(), -+ C_cutlass.host_view())) { -+ -+ char const *filename = "errors_01_cutlass_utilities.csv"; -+ -+ std::cerr << "Error - CUTLASS GEMM kernel differs from reference. Wrote computed and reference results to '" << filename << "'" << std::endl; -+ -+ // -+ // On error, print C_cutlass and C_reference to std::cerr. -+ // -+ // Note, these are matrices of half-precision elements stored in host memory as -+ // arrays of type cutlass::half_t. -+ // -+ -+ std::ofstream file(filename); -+ -+ // Result of CUTLASS GEMM kernel -+ file << "\n\nCUTLASS =\n" << C_cutlass.host_view() << std::endl; -+ -+ // Result of reference computation -+ file << "\n\nReference =\n" << C_reference.host_view() << std::endl; -+ -+ // Return error code. -+ return cudaErrorUnknown; -+ } -+ -+ // Passed error check -+ return cudaSuccess; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Entry point to cutlass_utilities example. -+// -+// usage: -+// -+// 01_cutlass_utilities -+// -+int main(int argc, const char *arg[]) { -+ -+ // -+ // This example uses half-precision and is only suitable for devices with compute capabitliy 5.3 or greater. -+ // -+ -+ cudaDeviceProp prop; -+ cudaError_t result = cudaGetDeviceProperties(&prop, 0); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to query device properties with error " << cudaGetErrorString(result) << std::endl; -+ return -1; -+ } -+ -+ if (!(prop.major > 5 || (prop.major == 5 && prop.minor >= 3))) { -+ std::cerr << "This example uses half precision and is only suitable for devices with compute capability 5.3 or greater.\n"; -+ std::cerr << "You are using a CUDA device with compute capability " << prop.major << "." << prop.minor << std::endl; -+ return -1; -+ } -+ -+ // -+ // Parse the command line to obtain GEMM dimensions and scalar values. -+ // -+ -+ // GEMM problem dimensions: -+ int problem[3] = { 128, 128, 128 }; -+ -+ for (int i = 1; i < argc && i < 4; ++i) { -+ std::stringstream ss(arg[i]); -+ ss >> problem[i - 1]; -+ } -+ -+ // Linear scale factors in GEMM. Note, these are half-precision values stored as -+ // cutlass::half_t. -+ // -+ // Values outside the range of IEEE FP16 will overflow to infinity or underflow to zero. -+ // -+ cutlass::half_t scalars[2] = { 1.0_hf, 0.0_hf }; -+ -+ for (int i = 4; i < argc && i < 6; ++i) { -+ std::stringstream ss(arg[i]); -+ -+ ss >> scalars[i - 4]; // lexical cast to cutlass::half_t -+ } -+ -+ // -+ // Run the CUTLASS GEMM test. -+ // -+ -+ result = TestCutlassGemm( -+ problem[0], // GEMM M dimension -+ problem[1], // GEMM N dimension -+ problem[2], // GEMM K dimension -+ scalars[0], // alpha -+ scalars[1] // beta -+ ); -+ -+ if (result == cudaSuccess) { -+ std::cout << "Passed." << std::endl; -+ } -+ -+ // Exit. -+ return result == cudaSuccess ? 0 : -1; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/examples/02_dump_reg_shmem/dump_reg_shmem.cu b/3rdparty/cutlass/examples/02_dump_reg_shmem/dump_reg_shmem.cu -new file mode 100644 -index 0000000..f70e721 ---- /dev/null -+++ b/3rdparty/cutlass/examples/02_dump_reg_shmem/dump_reg_shmem.cu -@@ -0,0 +1,186 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Demonstrate CUTLASS debugging tool for dumping fragments and shared -+ memory -+ */ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Standard Library includes -+ -+#include -+ -+// -+// CUTLASS includes -+// -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+#include "cutlass/transform/threadblock/regular_tile_iterator_tensor_op.h" -+ -+#include "cutlass/util/debug.h" -+#include "cutlass/util/device_dump.h" -+ -+#define EXAMPLE_MATRIX_ROW 64 -+#define EXAMPLE_MATRIX_COL 32 -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+__global__ void kernel_dump(typename GmemIterator::Params params, -+ typename GmemIterator::TensorRef ref) { -+ extern __shared__ Element shared_storage[]; -+ -+ // Construct the global iterator and load the data to the fragments. -+ int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; -+ -+ GmemIterator gmem_iterator(params, ref.data(), -+ {EXAMPLE_MATRIX_ROW, EXAMPLE_MATRIX_COL}, -+ tb_thread_id); -+ -+ typename GmemIterator::Fragment frag; -+ -+ frag.clear(); -+ gmem_iterator.load(frag); -+ -+ // Call dump_fragment() with different parameters. -+ if (threadIdx.x == 0 && blockIdx.x == 0) -+ printf("\nAll threads dump all the elements:\n"); -+ cutlass::debug::dump_fragment(frag); -+ -+ if (threadIdx.x == 0 && blockIdx.x == 0) -+ printf("\nFirst thread dumps all the elements:\n"); -+ cutlass::debug::dump_fragment(frag, /*N = */ 1); -+ -+ if (threadIdx.x == 0 && blockIdx.x == 0) -+ printf("\nFirst thread dumps first 16 elements:\n"); -+ cutlass::debug::dump_fragment(frag, /*N = */ 1, /*M = */ 16); -+ -+ if (threadIdx.x == 0 && blockIdx.x == 0) -+ printf("\nFirst thread dumps first 16 elements with a stride of 8:\n"); -+ cutlass::debug::dump_fragment(frag, /*N = */ 1, /*M = */ 16, /*S = */ 8); -+ -+ // Construct the shared iterator and store the data to the shared memory. -+ SmemIterator smem_iterator( -+ typename SmemIterator::TensorRef( -+ {shared_storage, SmemIterator::Layout::packed( -+ {EXAMPLE_MATRIX_ROW, EXAMPLE_MATRIX_COL})}), -+ tb_thread_id); -+ -+ smem_iterator.store(frag); -+ -+ // Call dump_shmem() with different parameters. -+ if (threadIdx.x == 0 && blockIdx.x == 0) printf("\nDump all the elements:\n"); -+ cutlass::debug::dump_shmem(shared_storage, -+ EXAMPLE_MATRIX_ROW * EXAMPLE_MATRIX_COL); -+ -+ if (threadIdx.x == 0 && blockIdx.x == 0) -+ printf("\nDump all the elements with a stride of 8:\n"); -+ cutlass::debug::dump_shmem( -+ shared_storage, EXAMPLE_MATRIX_ROW * EXAMPLE_MATRIX_COL, /*S = */ 8); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Entry point for dump_reg_shmem example. -+// -+// usage: -+// -+// 02_dump_reg_shmem -+// -+int main() { -+ // Initialize a 64x32 column major matrix with sequential data (1,2,3...). -+ using Element = cutlass::half_t; -+ using Layout = cutlass::layout::ColumnMajor; -+ -+ cutlass::HostTensor matrix( -+ {EXAMPLE_MATRIX_ROW, EXAMPLE_MATRIX_COL}); -+ cutlass::reference::host::BlockFillSequential(matrix.host_data(), -+ matrix.capacity()); -+ -+ // Dump the matrix. -+ std::cout << "Matrix:\n" << matrix.host_view() << "\n"; -+ -+ // Copy the matrix to the device. -+ matrix.sync_device(); -+ -+ // Define a global iterator, a shared iterator and their thread map. -+ using ThreadMap = cutlass::transform::PitchLinearWarpRakedThreadMap< -+ cutlass::layout::PitchLinearShape, -+ 32, cutlass::layout::PitchLinearShape<8, 4>, 8>; -+ -+ using GmemIterator = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, Element, -+ Layout, 1, ThreadMap>; -+ -+ typename GmemIterator::Params params(matrix.layout()); -+ -+ using SmemIterator = cutlass::transform::threadblock::RegularTileIterator< -+ cutlass::MatrixShape, Element, -+ cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous<16, 64>, 1, -+ ThreadMap>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ int smem_size = -+ int(sizeof(Element) * EXAMPLE_MATRIX_ROW * EXAMPLE_MATRIX_COL); -+ -+ kernel_dump -+ <<>>(params, matrix.device_ref()); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ -+ if (result != cudaSuccess) { -+ std::cout << "Failed" << std::endl; -+ } -+ -+ return (result == cudaSuccess ? 0 : -1); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/03_visualize_layout/options.h b/3rdparty/cutlass/examples/03_visualize_layout/options.h -new file mode 100644 -index 0000000..fd99b1c ---- /dev/null -+++ b/3rdparty/cutlass/examples/03_visualize_layout/options.h -@@ -0,0 +1,121 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include -+#include -+ -+// Cutlass command line parser -+#include "cutlass/util/command_line.h" -+ -+class Options { -+public: -+ -+ bool help; -+ bool good; -+ std::vector extent; ///< extent of tile to fill -+ std::vector stride; ///< stride vector for layout function -+ std::vector output_shape; ///< output shape -+ int vectorize; ///< sequences of consecutive output elements are concatenated into a vector -+ /// if, and only if, they were consecutive in source memory -+ -+public: -+ -+ /// Options -+ Options(): -+ help(false), -+ good(true), -+ extent({32, 8}), -+ stride({32}), -+ output_shape({16, 8}), -+ vectorize(1) { -+ -+ } -+ -+ /// Constructs from command line parser -+ Options(cutlass::CommandLine const & cmd_line): help(false), good(true) { -+ -+ if (cmd_line.check_cmd_line_flag("help") || -+ cmd_line.check_cmd_line_flag("h")) { -+ -+ help = true; -+ } -+ -+ if (cmd_line.check_cmd_line_flag("extent")) { -+ cmd_line.get_cmd_line_arguments("extent", extent); -+ } -+ else { -+ extent = {32, 8}; -+ } -+ -+ if (cmd_line.check_cmd_line_flag("stride")) { -+ cmd_line.get_cmd_line_arguments("stride", stride); -+ } -+ -+ int default_output_shape[] = {16, 8}; -+ -+ if (cmd_line.check_cmd_line_flag("output-shape")) { -+ cmd_line.get_cmd_line_arguments("output-shape", output_shape); -+ } -+ -+ for (int i = int(output_shape.size()); i < 2; ++i) { -+ output_shape.push_back(default_output_shape[i]); -+ } -+ -+ if (cmd_line.check_cmd_line_flag("vectorize")) { -+ cmd_line.get_cmd_line_argument("vectorize", vectorize); -+ } -+ else { -+ vectorize = 1; -+ } -+ -+ if (output_shape.front() % vectorize) { -+ -+ std::cerr << "Error: --vectorize=" << vectorize -+ << " must divide contiguous elements in --output-shape=" -+ << output_shape.at(0) << "," << output_shape.at(1) << std::endl; -+ -+ good = false; -+ } -+ } -+ -+ /// Prints usage statement -+ static void print_usage(std::ostream &out) { -+ out -+ << " Options:\n" -+ << " --help Displays this help message.\n" -+ << " --extent= Specifies the layout-specific extent (as comma-delimited array).\n" -+ << " --stride= Specifies the layout-specific stride vector (comma-delimited array)\n" -+ << " --output-shape= Specifies the dimensions of a row-major output matrix. \n" -+ << " --vectorize= If possible, vectorizes the output into vectors of consecutive elements\n"; -+ } -+}; -diff --git a/3rdparty/cutlass/examples/03_visualize_layout/register_layout.cu b/3rdparty/cutlass/examples/03_visualize_layout/register_layout.cu -new file mode 100644 -index 0000000..423bfcc ---- /dev/null -+++ b/3rdparty/cutlass/examples/03_visualize_layout/register_layout.cu -@@ -0,0 +1,145 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief CUTLASS layout visualization example -+*/ -+ -+#include -+#include -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm70.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm80.h" -+ -+#include "visualize_layout.h" -+#include "register_layout.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+void RegisterLayouts(std::map > &layouts) { -+ -+ struct { -+ char const *name; -+ VisualizeLayoutBase *ptr; -+ } layout_pairs[] = { -+ -+ {"PitchLinear", new VisualizeLayout}, -+ {"ColumnMajor", new VisualizeLayout}, -+ {"RowMajor", new VisualizeLayout}, -+ {"ColumnMajorInterleaved<4>", -+ new VisualizeLayout>}, -+ {"RowMajorInterleaved<4>", -+ new VisualizeLayout>}, -+ // All Ampere/Turing H/Integer matrix multiply tensor core kernels uses the same swizzling -+ // layout implementation with different templates. -+ // -+ // mma.sync.aligned.m8n8k128.s32.b1.b1.s32 Interleaved-256 -+ // mma.sync.aligned.m16n8k256.s32.b1.b1.s32 Interleaved-256 -+ {"TensorOpMultiplicand<1,256>", -+ new VisualizeLayout>}, -+ // mma.sync.aligned.m8n8k128.s32.b1.b1.s32 TN kblock512 -+ // mma.sync.aligned.m16n8k256.s32.b1.b1.s32 TN kblock512 -+ {"TensorOpMultiplicand<1,512>", -+ new VisualizeLayout>}, -+ // mma.sync.aligned.m16n8k256.s32.b1.b1.s32 TN kblock1024 -+ {"TensorOpMultiplicand<1,1024>", -+ new VisualizeLayout>}, -+ // Integer matrix multiply.int4 8832 Interleaved-64 -+ // Integer matrix multiply.int4 16864 Interleaved-64 -+ {"TensorOpMultiplicand<4,64>", -+ new VisualizeLayout>}, -+ // Integer matrix multiply.int4 8832 TN kblock128 -+ // Integer matrix multiply.int4 16864 TN kblock128 -+ {"TensorOpMultiplicand<4,128>", -+ new VisualizeLayout>}, -+ // Integer matrix multiply.int4 16864 TN kblock256 -+ {"TensorOpMultiplicand<4,256>", -+ new VisualizeLayout>}, -+ // Integer matrix multiply 8816 Interleaved-32 -+ // Integer matrix multiply 16832 Interleaved-32 -+ {"TensorOpMultiplicand<8,32>", -+ new VisualizeLayout>}, -+ // Integer matrix multiply 8816 TN kblock64 -+ // Integer matrix multiply 16832 TN kblock64 -+ {"TensorOpMultiplicand<8,64>", -+ new VisualizeLayout>}, -+ // Integer matrix multiply 16832 TN kblock128 -+ {"TensorOpMultiplicand<8,128>", -+ new VisualizeLayout>}, -+ // Matrix Multiply 1688 TN kblock32 -+ // Matrix multiply 16816 TN kblock32 -+ {"TensorOpMultiplicand<16,32>", -+ new VisualizeLayout>}, -+ // Matrix multiply 1688 NT -+ // Matrix multiply 16816 NT -+ // Matrix multiply 16816 TN kblock64 -+ {"TensorOpMultiplicand<16,64>", -+ new VisualizeLayout>}, -+ // Matrix multiply 1688.TF32 TN kblock16 -+ {"TensorOpMultiplicand<32,16>", -+ new VisualizeLayout>}, -+ // Matrix multiply 1688.TF32 TN kblock32 -+ {"TensorOpMultiplicand<32,32>", -+ new VisualizeLayout>}, -+ // Matrix multiply 1688 NT -+ {"TensorOpMultiplicandCongruous<32,32>", -+ new VisualizeLayout< -+ cutlass::layout::TensorOpMultiplicandCongruous<32, 32>>}, -+ // Matrix multiply 884 NT -+ {"TensorOpMultiplicandCongruous<64,16>", -+ new VisualizeLayout< -+ cutlass::layout::TensorOpMultiplicandCongruous<64, 16>>}, -+ // Matrix multiply 884 TN -+ {"TensorOpMultiplicand64bCrosswise", -+ new VisualizeLayout}, -+ {"TensorOpMultiplicandCongruous<128,4>", -+ new VisualizeLayout< -+ cutlass::layout::TensorOpMultiplicandCongruous<128, 4>>}, -+ {"TensorOpMultiplicandCrosswise<128,4>", -+ new VisualizeLayout< -+ cutlass::layout::TensorOpMultiplicandCrosswise<128, 4>>}, -+ {"VoltaTensorOpMultiplicandCongruous<16>", -+ new VisualizeLayout< -+ cutlass::layout::VoltaTensorOpMultiplicandCongruous<16>>}, -+ {"VoltaTensorOpMultiplicandCrosswise<16,32>", -+ new VisualizeLayout< -+ cutlass::layout::VoltaTensorOpMultiplicandCrosswise<16, 32>>} -+ }; -+ -+ for (auto layout : layout_pairs) { -+ layouts.emplace(std::string(layout.name), std::unique_ptr(layout.ptr)); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/03_visualize_layout/register_layout.h b/3rdparty/cutlass/examples/03_visualize_layout/register_layout.h -new file mode 100644 -index 0000000..bb5f893 ---- /dev/null -+++ b/3rdparty/cutlass/examples/03_visualize_layout/register_layout.h -@@ -0,0 +1,59 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief CUTLASS layout visualization example -+*/ -+ -+#pragma once -+ -+#include -+#include -+ -+#include "options.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct VisualizeLayoutBase { -+ virtual bool visualize(Options const &) = 0; -+ virtual bool verify(bool verbose, std::ostream &out) = 0; -+ virtual void print_csv(std::ostream &out, char delim = '|', char new_line = '\n') = 0; -+ virtual std::ostream &print_help(std::ostream &out) { -+ return out; -+ } -+ virtual ~VisualizeLayoutBase() { } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+void RegisterLayouts(std::map > &layouts); -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/03_visualize_layout/visualize_layout.h b/3rdparty/cutlass/examples/03_visualize_layout/visualize_layout.h -new file mode 100644 -index 0000000..cef8579 ---- /dev/null -+++ b/3rdparty/cutlass/examples/03_visualize_layout/visualize_layout.h -@@ -0,0 +1,383 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief CUTLASS layout visualization example -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "cutlass/coord.h" -+#include "cutlass/util/reference/host/tensor_foreach.h" -+ -+#include "register_layout.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Permits copying dynamic vectors into static-length vectors -+template -+struct vector_to_coord { -+ -+ vector_to_coord(TensorCoord &coord, std::vector const &vec) { -+ -+ coord[Rank - 1] = vec.at(Rank - 1); -+ -+ if (Rank > 1) { -+ vector_to_coord(coord, vec); -+ } -+ } -+}; -+ -+/// Permits copying dynamic vectors into static-length vectors -+template -+struct vector_to_coord { -+ -+ vector_to_coord(TensorCoord &coord, std::vector const &vec) { -+ -+ coord[0] = vec.at(0); -+ } -+}; -+ -+/// Permits copying dynamic vectors into static-length vectors -+template -+struct vector_to_coord { -+ -+ vector_to_coord(TensorCoord &coord, std::vector const &vec) { -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+std::ostream &operator<<(std::ostream &out, std::vector const &vec) { -+ auto it = vec.begin(); -+ if (it != vec.end()) { -+ out << *it; -+ for (++it; it != vec.end(); ++it) { -+ out << ", " << *it; -+ } -+ } -+ return out; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Permits copying static-length vectors into dynamic vectors -+template -+struct coord_to_vector { -+ -+ coord_to_vector(std::vector &vec, TensorCoord const &coord) { -+ -+ vec.at(Rank - 1) = coord[Rank - 1]; -+ coord_to_vector(vec, coord); -+ } -+}; -+ -+/// Permits copying static-length vectors into dynamic vectors -+template -+struct coord_to_vector { -+ -+ coord_to_vector(std::vector &vec, TensorCoord const &coord) { -+ -+ vec.at(0) = coord[0]; -+ } -+}; -+ -+/// Permits copying static-length vectors into dynamic vectors -+template -+struct coord_to_vector { -+ -+ coord_to_vector(std::vector &vec, TensorCoord const &coord) { -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure representing an element in source memory -+struct Element { -+ -+ std::vector coord; ///< logical coordinate of element (as vector) -+ int offset; ///< linear offset from source memory -+ int color; ///< enables coloring each element to indicate -+ -+ /// Default ctor -+ inline Element(): offset(-1), color(0) { } -+ -+ /// Construct from logical coordinate and initial offset -+ inline Element( -+ std::vector const &coord_, -+ int offset_, -+ int color_ = 0 -+ ): -+ coord(coord_), offset(offset_), color(color_) { } -+ -+ /// Returns true if element is in a defined state -+ inline bool valid() const { -+ return offset >= 0; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Visualizes memory layouts by constructing a 'shape' -+template -+class VisualizeLayout : public VisualizeLayoutBase { -+public: -+ -+ using Layout = Layout_; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Stride = typename Layout::Stride; -+ -+public: -+ -+ Options options; -+ Layout layout; -+ TensorCoord extent; -+ std::vector elements; -+ -+public: -+ -+ /// Initializes the problem space -+ VisualizeLayout() { -+ -+ } -+ -+ /// visualization method -+ bool visualize(Options const &options_) { -+ -+ options = options_; -+ -+ if (options.extent.size() != TensorCoord::kRank) { -+ -+ std::cerr -+ << "--extent must have rank " << TensorCoord::kRank -+ << " (given: " << options.extent.size() << ")" << std::endl; -+ -+ return false; -+ } -+ -+ vector_to_coord(extent, options.extent); -+ -+ // Construct the layout for a packed tensor -+ if (options.stride.empty()) { -+ -+ layout = Layout::packed(extent); -+ } -+ else if (options.stride.size() != Stride::kRank) { -+ -+ std::cerr -+ << "--stride must have rank " << Stride::kRank -+ << " (given: " << options.stride.size() << ")" << std::endl; -+ -+ return false; -+ } -+ else { -+ // Stride from -+ Stride stride; -+ vector_to_coord(stride, options.stride); -+ -+ layout = Layout(stride); -+ } -+ -+ // Resize elements, setting elements to 'undefined' state -+ elements.resize(layout.capacity(extent)); -+ -+ // enumerate points in tensor space and assign -+ cutlass::reference::host::TensorForEachLambda( -+ extent, -+ [&](TensorCoord coord) { -+ -+ std::vector coord_vec(TensorCoord::kRank, 0); -+ coord_to_vector(coord_vec, coord); -+ -+ int offset = int(layout(coord)); -+ -+ if (offset >= int(elements.size())) { -+ std::cerr -+ << "Layout error - " << coord_vec -+ << " is out of range (computed offset: " << offset -+ << ", capacity: " << elements.size() << std::endl; -+ -+ throw std::out_of_range("(TensorForEach) layout error - coordinate out of range"); -+ } -+ -+ elements.at(offset) = Element(coord_vec, offset); -+ }); -+ -+ return true; -+ } -+ -+ /// Verifies the layout satisfies vectorization requirements -+ bool verify(bool verbose, std::ostream &out) { -+ return true; -+ } -+ -+private: -+ -+ /// returns a pair (is_vectorizable, one_changing_rank) to determine if a -+ /// vector exists (consecutive logical coordinates or uniformly invalid) -+ /// at the given location. -+ std::pair< bool, int > _is_vectorizable(int i) const { -+ // (all elements are invalid) or -+ // (all elements are valid AND -+ // exactly one rank is changing AND -+ // elements are consecutive) -+ -+ // Don't need vectorization. -+ if (options.vectorize <= 2) return std::make_pair(false, -1); -+ -+ // Boundary check. -+ if (i > elements.size() || (i + options.vectorize - 1) > elements.size()) -+ return std::make_pair(false, -1); -+ -+ // Check if either all elements are valid or invalid. -+ bool all_elements_invalid = std::all_of( -+ elements.begin() + i, elements.begin() + i + options.vectorize, -+ [](Element const &e) { return !e.valid(); }); -+ -+ bool all_elements_valid = std::all_of( -+ elements.begin() + i, elements.begin() + i + options.vectorize, -+ [](Element const &e) { return e.valid(); }); -+ -+ if (!all_elements_invalid && !all_elements_valid) -+ return std::make_pair(false, -1); -+ -+ // From here, it is vectorizable. -+ if (all_elements_invalid) return std::make_pair(true, -1); -+ -+ // Check if only exactly one rank is changing. -+ int one_changing_rank = -1; -+ for (int j = 0; j < options.vectorize; ++j) { -+ for (int r = 0; r < TensorCoord::kRank; ++r) { -+ if (elements.at(i + j).coord.at(r) != elements.at(i).coord.at(r)) { -+ if (one_changing_rank == -1) { -+ one_changing_rank = r; -+ } else if (one_changing_rank != r) { -+ return std::make_pair(false, -1); -+ } -+ } -+ } -+ } -+ -+ return std::make_pair(true, one_changing_rank); -+ } -+ -+ /// Prints a vector of elements -+ void _print_vector(std::ostream &out, int i, int one_changing_rank) { -+ Element const &base_element = elements.at(i); -+ if (base_element.valid()) { -+ out << "("; -+ for (int r = 0; r < TensorCoord::kRank; ++r) { -+ if (r) { -+ out << ", "; -+ } -+ -+ if (r == one_changing_rank) { -+ out -+ << base_element.coord.at(r) -+ << ".." -+ << (base_element.coord.at(r) + options.vectorize - 1); -+ } -+ else { -+ out << base_element.coord.at(r); -+ } -+ } -+ out << ")"; -+ } -+ else { -+ out << " "; -+ } -+ } -+ -+ /// Prints a single element -+ void _print_element(std::ostream &out, int k) { -+ Element const &element = elements.at(k); -+ if (element.valid()) { -+ out << "("; -+ for (int v = 0; v < TensorCoord::kRank; ++v) { -+ out << (v ? ", " : "") << element.coord.at(v); -+ } -+ out << ")"; -+ } -+ else { -+ out << " "; -+ } -+ } -+ -+public: -+ -+ /// Pretty-prints the layout to the console -+ void print_csv(std::ostream &out, char delim = '|', char new_line = '\n') { -+ int row = -1; -+ -+ for (int i = 0; i < int(elements.size()); i += options.vectorize) { -+ if (i % options.output_shape.at(0)) { -+ out << delim; -+ } -+ else { -+ if (row >= 0) { -+ out << new_line; -+ } -+ ++row; -+ if (row == options.output_shape.at(1)) { -+ out << new_line; -+ row = 0; -+ } -+ } -+ -+ auto is_vector = _is_vectorizable(i); -+ -+ if (is_vector.first) { -+ _print_vector(out, i, is_vector.second); // print a vector starting at element i -+ } -+ else { -+ for (int j = 0; j < options.vectorize; ++j) { // print individual elements [i..i+j) -+ _print_element(out, i + j); -+ } -+ } -+ } -+ -+ out << new_line << std::flush; -+ } -+ -+ /// Help message -+ virtual std::ostream &print_help(std::ostream &out) { -+ out << "TensorCoord rank " << TensorCoord::kRank << ", Stride rank: " << Stride::kRank; -+ return out; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/04_tile_iterator/tile_iterator.cu b/3rdparty/cutlass/examples/04_tile_iterator/tile_iterator.cu -new file mode 100644 -index 0000000..8146a09 ---- /dev/null -+++ b/3rdparty/cutlass/examples/04_tile_iterator/tile_iterator.cu -@@ -0,0 +1,221 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/* -+ This example demonstrates how to use the PredicatedTileIterator in CUTLASS to load data from -+ addressable memory, and then store it back into addressable memory. -+ -+ TileIterator is a core concept in CUTLASS that enables efficient loading and storing of data to -+ and from addressable memory. The PredicateTileIterator accepts a ThreadMap type, which defines -+ the mapping of threads to a "tile" in memory. This separation of concerns enables user-defined -+ thread mappings to be specified. -+ -+ In this example, a PredicatedTileIterator is used to load elements from a tile in global memory, -+ stored in column-major layout, into a fragment and then back into global memory in the same -+ layout. -+ -+ This example uses CUTLASS utilities to ease the matrix operations. -+ -+*/ -+ -+// Standard Library includes -+#include -+#include -+#include -+ -+// CUTLASS includes -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+ -+// -+// CUTLASS utility includes -+// -+ -+// Defines operator<<() to write TensorView objects to std::ostream -+#include "cutlass/util/tensor_view_io.h" -+ -+// Defines cutlass::HostTensor<> -+#include "cutlass/util/host_tensor.h" -+ -+// Defines cutlass::reference::host::TensorFill() and -+// cutlass::reference::host::TensorFillBlockSequential() -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+#pragma warning( disable : 4503) -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Define PredicatedTileIterators to load and store a M-by-K tile, in column major layout. -+ -+template -+__global__ void copy( -+ typename Iterator::Params dst_params, -+ typename Iterator::Element *dst_pointer, -+ typename Iterator::Params src_params, -+ typename Iterator::Element *src_pointer, -+ cutlass::Coord<2> extent) { -+ -+ -+ Iterator dst_iterator(dst_params, dst_pointer, extent, threadIdx.x); -+ Iterator src_iterator(src_params, src_pointer, extent, threadIdx.x); -+ -+ // PredicatedTileIterator uses PitchLinear layout and therefore takes in a PitchLinearShape. -+ // The contiguous dimension can be accessed via Iterator::Shape::kContiguous and the strided -+ // dimension can be accessed via Iterator::Shape::kStrided -+ int iterations = (extent[1] + Iterator::Shape::kStrided - 1) / Iterator::Shape::kStrided; -+ -+ typename Iterator::Fragment fragment; -+ -+ for(int i = 0; i < fragment.size(); ++i) { -+ fragment[i] = 0; -+ } -+ -+ src_iterator.load(fragment); -+ dst_iterator.store(fragment); -+ -+ -+ ++src_iterator; -+ ++dst_iterator; -+ -+ for(; iterations > 1; --iterations) { -+ -+ src_iterator.load(fragment); -+ dst_iterator.store(fragment); -+ -+ ++src_iterator; -+ ++dst_iterator; -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Initializes the source tile with sequentially increasing values and performs the copy into -+// the destination tile using two PredicatedTileIterators, one to load the data from addressable -+// memory into a fragment (regiser-backed array of elements owned by each thread) and another to -+// store the data from the fragment back into the addressable memory of the destination tile. -+ -+cudaError_t TestTileIterator(int M, int K) { -+ -+ // For this example, we chose a <64, 4> tile shape. The PredicateTileIterator expects -+ // PitchLinearShape and PitchLinear layout. -+ using Shape = cutlass::layout::PitchLinearShape<64, 4>; -+ using Layout = cutlass::layout::PitchLinear; -+ using Element = int; -+ int const kThreads = 32; -+ -+ // ThreadMaps define how threads are mapped to a given tile. The PitchLinearStripminedThreadMap -+ // stripmines a pitch-linear tile among a given number of threads, first along the contiguous -+ // dimension then along the strided dimension. -+ using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap; -+ -+ // Define the PredicateTileIterator, using TileShape, Element, Layout, and ThreadMap types -+ using Iterator = cutlass::transform::threadblock::PredicatedTileIterator< -+ Shape, Element, Layout, 1, ThreadMap>; -+ -+ -+ cutlass::Coord<2> copy_extent = cutlass::make_Coord(M, K); -+ cutlass::Coord<2> alloc_extent = cutlass::make_Coord(M, K); -+ -+ // Allocate source and destination tensors -+ cutlass::HostTensor src_tensor(alloc_extent); -+ cutlass::HostTensor dst_tensor(alloc_extent); -+ -+ Element oob_value = Element(-1); -+ -+ // Initialize destination tensor with all -1s -+ cutlass::reference::host::TensorFill(dst_tensor.host_view(), oob_value); -+ // Initialize source tensor with sequentially increasing values -+ cutlass::reference::host::BlockFillSequential(src_tensor.host_data(), src_tensor.capacity()); -+ -+ dst_tensor.sync_device(); -+ src_tensor.sync_device(); -+ -+ typename Iterator::Params dst_params(dst_tensor.layout()); -+ typename Iterator::Params src_params(src_tensor.layout()); -+ -+ dim3 block(kThreads, 1); -+ dim3 grid(1, 1); -+ -+ // Launch copy kernel to perform the copy -+ copy<<< grid, block >>>( -+ dst_params, -+ dst_tensor.device_data(), -+ src_params, -+ src_tensor.device_data(), -+ copy_extent -+ ); -+ -+ cudaError_t result = cudaGetLastError(); -+ if(result != cudaSuccess) { -+ std::cerr << "Error - kernel failed." << std::endl; -+ return result; -+ } -+ -+ dst_tensor.sync_host(); -+ -+ // Verify results -+ for(int s = 0; s < alloc_extent[1]; ++s) { -+ for(int c = 0; c < alloc_extent[0]; ++c) { -+ -+ Element expected = Element(0); -+ -+ if(c < copy_extent[0] && s < copy_extent[1]) { -+ expected = src_tensor.at({c, s}); -+ } -+ else { -+ expected = oob_value; -+ } -+ -+ Element got = dst_tensor.at({c, s}); -+ bool equal = (expected == got); -+ -+ if(!equal) { -+ std::cerr << "Error - source tile differs from destination tile." << std::endl; -+ return cudaErrorUnknown; -+ } -+ } -+ } -+ -+ return cudaSuccess; -+} -+ -+int main(int argc, const char *arg[]) { -+ -+ cudaError_t result = TestTileIterator(57, 35); -+ -+ if(result == cudaSuccess) { -+ std::cout << "Passed." << std::endl; -+ } -+ -+ // Exit -+ return result == cudaSuccess ? 0 : -1; -+} -+ -diff --git a/3rdparty/cutlass/examples/05_batched_gemm/batched_gemm.cu b/3rdparty/cutlass/examples/05_batched_gemm/batched_gemm.cu -new file mode 100644 -index 0000000..ab85361 ---- /dev/null -+++ b/3rdparty/cutlass/examples/05_batched_gemm/batched_gemm.cu -@@ -0,0 +1,466 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/gemm/device/gemm_array.h" -+#include "cutlass/gemm/device/gemm_batched.h" -+ -+#pragma warning( disable : 4503) -+ -+/* -+This example demonstrates how to use cutlass to compute a batched strided gemm in two different ways: -+ 1. By specifying pointers to the first matrices of the batch and the stride between the consecutive -+ matrices of the batch (this is called a strided batched gemm). -+ 2. By copying pointers to all matrices of the batch to the device memory (this is called an array gemm). -+In this example, both A and B matrix are non-transpose and column major matrix -+batched_C = batched_A x batched_B -+As an example, matrix C can be seen as -+----------------------------------------------------------- -+(0,0,0) | (0,0,1) | (0,0,2) | (1,0,0) | (1,0,1) | (1,0,2) | -+----------------------------------------------------------- -+(0,1,0) | (0,1,1) | (0,1,2) | (1,1,0) | (1,1,1) | (1,1,2) | -+----------------------------------------------------------- -+(0,2,0) | (0,2,1) | (0,2,2) | (1,2,0) | (1,2,1) | (1,2,2) | -+----------------------------------------------------------- -+(0,3,0) | (0,3,1) | (0,3,2) | (1,3,0) | (1,3,1) | (1,3,2) | -+----------------------------------------------------------- -+(0,4,0) | (0,4,1) | (0,4,2) | (1,4,0) | (1,4,1) | (1,4,2) | -+----------------------------------------------------------- -+(0,5,0) | (0,5,1) | (0,5,2) | (1,5,0) | (1,5,1) | (1,5,2) | -+----------------------------------------------------------- -+ batch 0 | batch 1 -+where we denote each element with (batch_idx, row_idx, column_idx) -+In this example, batch size is 2, M is 6 and N is 3 -+The stride (batch_stride_C) between the first element of two batches is ldc * n -+ -+matrix A can be seen as -+--------------------------------------- -+(0,0,0) | (0,0,1) | (1,0,0) | (1,0,1) | -+--------------------------------------- -+(0,1,0) | (0,1,1) | (1,1,0) | (1,1,1) | -+--------------------------------------- -+(0,2,0) | (0,2,1) | (1,2,0) | (1,2,1) | -+--------------------------------------- -+(0,3,0) | (0,3,1) | (1,3,0) | (1,3,1) | -+--------------------------------------- -+(0,4,0) | (0,4,1) | (1,4,0) | (1,4,1) | -+--------------------------------------- -+(0,5,0) | (0,5,1) | (1,5,0) | (1,5,1) | -+--------------------------------------- -+ batch 0 | batch 1 -+, where batch size is 2, M is 6 and K is 2 -+The stride (batch_stride_A) between the first element of two batches is lda * k -+ -+matrix B can be seen as -+----------------------------- -+(0,0,0) | (0,0,1) | (0,0,2) | -+----------------------------- batch 0 -+(0,1,0) | (0,1,1) | (0,1,2) | -+------------------------------------- -+(1,0,0) | (1,0,1) | (1,0,2) | -+----------------------------- batch 1 -+(1,1,0) | (1,1,1) | (1,1,2) | -+----------------------------- -+, where the batch size is 2, N is 3 and K is 2 -+The stride (batch_stride_B) between the first element of two batches is k -+ -+ -+*/ -+ -+cudaError_t cutlass_array_sgemm( -+ int m, -+ int n, -+ int k, -+ float alpha, -+ float const * const *A, -+ int lda, -+ float const * const *B, -+ int ldb, -+ float * const *C, -+ int ldc, -+ float beta, -+ int batch_count) { -+ -+ using Gemm = cutlass::gemm::device::GemmArray< -+ float, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::ColumnMajor -+ >; -+ -+ Gemm gemm_op; -+ -+ cutlass::Status status = gemm_op({ -+ {m, n, k}, -+ A, lda, -+ B, ldb, -+ C, ldc, -+ C, ldc, -+ {alpha, beta}, -+ batch_count -+ }); -+ -+ if (status != cutlass::Status::kSuccess) { -+ return cudaErrorUnknown; -+ } -+ -+ return cudaSuccess; -+} -+ -+cudaError_t cutlass_strided_batched_sgemm( -+ int m, -+ int n, -+ int k, -+ float alpha, -+ float const *A, -+ int lda, -+ long long int batch_stride_A, -+ float const *B, -+ int ldb, -+ long long int batch_stride_B, -+ float *C, -+ int ldc, -+ long long int batch_stride_C, -+ float beta, -+ int batch_count) { -+ -+ using Gemm = cutlass::gemm::device::GemmBatched< -+ float, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::ColumnMajor -+ >; -+ -+ Gemm gemm_op; -+ -+ cutlass::Status status = gemm_op({ -+ {m, n, k}, -+ {A, lda}, -+ batch_stride_A, -+ {B, ldb}, -+ batch_stride_B, -+ {C, ldc}, -+ batch_stride_C, -+ {C, ldc}, -+ batch_stride_C, -+ {alpha, beta}, -+ batch_count -+ }); -+ -+ if (status != cutlass::Status::kSuccess) { -+ return cudaErrorUnknown; -+ } -+ -+ return cudaSuccess; -+} -+ -+template -+cudaError_t strided_batched_gemm_nn_reference( -+ int m, -+ int n, -+ int k, -+ T alpha, -+ std::vector const &A, -+ int lda, -+ long long int batch_stride_A, -+ std::vector const &B, -+ int ldb, -+ long long int batch_stride_B, -+ std::vector &C, -+ int ldc, -+ long long int batch_stride_C, -+ T beta, -+ int batch_count) { -+ /* -+ strided batched gemm NN -+ */ -+ -+ cudaError_t result = cudaSuccess; -+ -+ if (A.size() < lda * k * batch_count) { -+ std::cout << "the size of A is too small" << std::endl; -+ return cudaErrorInvalidValue; -+ } -+ if (B.size() < ldb * n) { -+ std::cout << "the size of B is too small" << std::endl; -+ return cudaErrorInvalidValue; -+ } -+ if (C.size() < ldc * n * batch_count) { -+ std::cout << "the size of C is too small" << std::endl; -+ return cudaErrorInvalidValue; -+ } -+ -+ for (int batch_idx = 0; batch_idx < batch_count; batch_idx++) { -+ for (int n_idx = 0; n_idx < n; n_idx++) { -+ for (int m_idx = 0; m_idx < m; m_idx++) { -+ T accum = beta * C[batch_idx * batch_stride_C + n_idx * ldc + m_idx]; -+ for (int k_idx = 0; k_idx < k; k_idx++) { -+ accum += alpha -+ * A[batch_idx * batch_stride_A + k_idx * lda + m_idx] -+ * B[batch_idx * batch_stride_B + n_idx * ldb + k_idx]; -+ } -+ C[batch_idx * batch_stride_C + n_idx * ldc + m_idx] = accum; -+ } -+ } -+ } -+ -+ return result; -+} -+ -+ -+cudaError_t run_batched_gemm(bool use_array) { -+ -+ const char* gemm_desc = use_array ? "array" : "strided batched"; -+ std::cout << "Running " << gemm_desc << " gemm" << std::endl; -+ -+ // Arbitrary problem size -+ int const m = 520; -+ int const n = 219; -+ int const k = 129; -+ int const batch_count = 17; -+ -+ // A, B are non-transpose, column major -+ int const lda = m; -+ int const ldb = k * batch_count; -+ int const ldc = m; -+ -+ int const count_A = batch_count * lda * k; -+ int const count_B = ldb * n; -+ int const count_C = batch_count * ldc * n; -+ -+ // the memory is batched along K dimension -+ long long int batch_stride_A = static_cast(lda) * static_cast(k); -+ long long int batch_stride_B = static_cast(k); -+ long long int batch_stride_C = static_cast(ldc) * static_cast(n); -+ -+ // alpha and beta -+ float alpha = 1.0f; -+ float beta = 2.0f; -+ -+ cudaError_t result = cudaSuccess; -+ -+ // allocate the host memory -+ std::vector host_A(count_A); -+ std::vector host_B(count_B); -+ std::vector host_C(count_C); -+ std::vector result_C(count_C); -+ -+ // allocate the device memory -+ float *A; -+ float *B; -+ float *C; -+ -+ result = cudaMalloc(&A, count_A * sizeof(float)); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaMalloc result = " << result << std::endl; -+ return result; -+ } -+ result = cudaMalloc(&B, count_B * sizeof(float)); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaMalloc result = " << result << std::endl; -+ return result; -+ } -+ result = cudaMalloc(&C, count_C * sizeof(float)); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaMalloc result = " << result << std::endl; -+ return result; -+ } -+ -+ // Limit range to avoid floating-point errors -+ int const kRange = 8; -+ -+ // fill A -+ for (int b_idx = 0; b_idx < batch_count; b_idx++) { -+ for (int col_idx = 0; col_idx < k; col_idx++) { -+ for (int row_idx = 0; row_idx < m; row_idx++) { -+ host_A[row_idx + col_idx * lda + b_idx * lda * k] = static_cast((row_idx + col_idx * lda + b_idx * lda * k) % kRange); -+ } -+ } -+ } -+ // fill B -+ for (int b_idx = 0; b_idx < batch_count; b_idx++) { -+ for (int col_idx = 0; col_idx < n; col_idx++) { -+ for (int row_idx = 0; row_idx < k; row_idx++) { -+ host_B[row_idx + col_idx * ldb + b_idx * k] = static_cast(((n + k * ldb + batch_count * k) - (row_idx + col_idx * ldb + b_idx * k)) % kRange); -+ } -+ } -+ } -+ // fill C -+ for (int b_idx = 0; b_idx < batch_count; b_idx++) { -+ for (int col_idx = 0; col_idx < n; col_idx++) { -+ for (int row_idx = 0; row_idx < m; row_idx++) { -+ host_C[row_idx + col_idx * ldc + b_idx * ldc * n] = 1.f; -+ } -+ } -+ } -+ -+ // ref memory -+ std::vector ref_A(host_A); -+ std::vector ref_B(host_B); -+ std::vector ref_C(host_C); -+ // copy host memory to device -+ result = cudaMemcpy(A, host_A.data(), count_A * sizeof(float), cudaMemcpyHostToDevice); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaMemcpy result = " << result << std::endl; -+ return result; -+ } -+ result = cudaMemcpy(B, host_B.data(), count_B * sizeof(float), cudaMemcpyHostToDevice); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaMemcpy result = " << result << std::endl; -+ return result; -+ } -+ result = cudaMemcpy(C, host_C.data(), count_C * sizeof(float), cudaMemcpyHostToDevice); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaMemcpy result = " << result << std::endl; -+ return result; -+ } -+ -+ // run cutlass -+ if (use_array) { -+ // allocate the host memory for the pointers to the matrices of the batch -+ std::vector host_ptr_A(batch_count); -+ std::vector host_ptr_B(batch_count); -+ std::vector host_ptr_C(batch_count); -+ -+ // permute the batch elements to emphasize that GemmArray does not depend on matrices being separated by a fixed stride -+ std::vector permutation = {14, 11, 3, 10, 1, 13, 9, 4, 6, 16, 8, 15, 7, 12, 0, 2, 5}; -+ for (size_t b_idx = 0; b_idx < batch_count; b_idx++) { -+ host_ptr_A[b_idx] = A + permutation[b_idx] * batch_stride_A; -+ host_ptr_B[b_idx] = B + permutation[b_idx] * batch_stride_B; -+ host_ptr_C[b_idx] = C + permutation[b_idx] * batch_stride_C; -+ } -+ -+ // allocate the corresponding device memory -+ float const **ptr_A; -+ float const **ptr_B; -+ float **ptr_C; -+ -+ result = cudaMalloc(&ptr_A, batch_count * sizeof(float*)); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaMalloc result = " << result << std::endl; -+ return result; -+ } -+ result = cudaMalloc(&ptr_B, batch_count * sizeof(float*)); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaMalloc result = " << result << std::endl; -+ return result; -+ } -+ result = cudaMalloc(&ptr_C, batch_count * sizeof(float*)); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaMalloc result = " << result << std::endl; -+ return result; -+ } -+ -+ // copy the matrix pointers to the device -+ result = cudaMemcpy(ptr_A, host_ptr_A.data(), batch_count * sizeof(float*), cudaMemcpyHostToDevice); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaMemcpy result = " << result << std::endl; -+ return result; -+ } -+ result = cudaMemcpy(ptr_B, host_ptr_B.data(), batch_count * sizeof(float*), cudaMemcpyHostToDevice); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaMemcpy result = " << result << std::endl; -+ return result; -+ } -+ result = cudaMemcpy(ptr_C, host_ptr_C.data(), batch_count * sizeof(float*), cudaMemcpyHostToDevice); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaMemcpy result = " << result << std::endl; -+ return result; -+ } -+ -+ result = cutlass_array_sgemm(m, n, k, alpha, ptr_A, lda, ptr_B, ldb, ptr_C, ldc, beta, batch_count); -+ -+ if (result != cudaSuccess) -+ return result; -+ } else { -+ result = cutlass_strided_batched_sgemm( -+ m, n, k, alpha, A, lda, batch_stride_A, B, ldb, batch_stride_B, C, ldc, batch_stride_C, -+ beta, batch_count); -+ if (result != cudaSuccess) -+ return result; -+ } -+ -+ // copy device memory to host -+ result = cudaMemcpy(result_C.data(), C, count_C * sizeof(float), cudaMemcpyDeviceToHost); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaMemcpy result = " << result << std::endl; -+ return result; -+ } -+ -+ //compare with reference code -+ result = strided_batched_gemm_nn_reference(m, n, k, alpha, ref_A, lda, batch_stride_A, ref_B, ldb, batch_stride_B, ref_C, ldc, batch_stride_C, -+ beta, batch_count); -+ if (result != 0) -+ return result; -+ -+ // Expect bit-level accuracy for this simple example -+ if (ref_C != result_C) { -+ std::cout << "CUTLASS " << gemm_desc << " gemm does not run correctly" << std::endl; -+ return cudaErrorUnknown; -+ } -+ -+ // free memory -+ result = cudaFree(A); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaFree result = " << result << std::endl; -+ return result; -+ } -+ result = cudaFree(B); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaFree result = " << result << std::endl; -+ return result; -+ } -+ result = cudaFree(C); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaFree result = " << result << std::endl; -+ return result; -+ } -+ -+ return result; -+} -+ -+int main() { -+ -+ cudaError_t result = cudaSuccess; -+ for (bool use_array : {false, true}) { -+ result = run_batched_gemm(use_array); -+ if (result == cudaSuccess) { -+ std::cout << "Passed." << std::endl; -+ } else { -+ break; -+ } -+ } -+ -+ // Exit. -+ return result == cudaSuccess ? 0 : -1; -+} -diff --git a/3rdparty/cutlass/examples/06_splitK_gemm/splitk_gemm.cu b/3rdparty/cutlass/examples/06_splitK_gemm/splitk_gemm.cu -new file mode 100644 -index 0000000..9c88851 ---- /dev/null -+++ b/3rdparty/cutlass/examples/06_splitK_gemm/splitk_gemm.cu -@@ -0,0 +1,340 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+This example shows how to use split-k version of matrix multiplication using functions and data -+structures provided by CUTLASS; which we run on a NVIDIA Volta GPU. -+ -+What is split-k? -+Consider a problem size of M = 128, N = 128, K = 4096. In this case, if my thread-block tile size (a -+tile can be viewed as a 2d matrix) is 128x128x4096, then we launch a singled a thread-block taking -+up a single SM of 84 SMs present on V100. Hence the efficiency of computation is really low. So, how -+to solve it? This is where split-k comes in. It is a way of partitioning K-dimension of matrix -+multiplication and distribute across multiple SMs and get better efficiency than single SM. In the -+above example, we can partition K-dimension with split-k factor of 16 i.e., thread-block tile size -+will be 128x128x256 and will be launching on 16 SMs. Once each thread-block computes their partial -+inner product (1/16th of output), they accumulate to single output matrix. -+ -+Writing a single high performance matrix multiplication kernel is hard but do-able. Whereas writing -+high performance kernels at scale which works for multiple problem sizes with good abstractions is -+really hard. CUTLASS solves this problem by providing simplified abstractions to compose -+multiple sections of gemm kernel. When used properly, the kernels can hit peak performance of GPU -+easily. -+ -+CUTLASS divides a kernel into hierarchical composable sections. Which means, at each thread, warp -+and thread-block level, they compute on their own tile-size with higher level of tile sizes being -+composed from lower level ones. Multiple thread-tiles (tile size each thread computes) can be used -+to form warp-tiles (tile size each warp computes) and multiple warp tiles can be used to compute -+threadblock-tile (tile size computed by a threadblock). -+ -+In this example, we split variable initialization into -+1. Setting up data properties : describes how matrices are laid out in the memory and how the kernel -+can view them (logical to physical mapping) -+2. Setting up computation properties : describes how the above set matrices will be used to compute -+output of matrix multiplication. -+ -+First, we setup the data types of matrices A, B, C and D along with alpha, beta as the equation for -+GEMM is D = alpha * A * B + beta * C. In CUTLASS, the kernels first compute A * B and leaves the -+rest of the computation to end of the kernel as alpha * X + beta * C is a simple element-wise -+operation on X (A * B) and C. We call this as epilogue of kernel. Hence, we setup data types for -+alpha and beta to be equal to ElementComputeEpilogue = float. As we want to MMA instructions on -+Volta and they support only half-precision floating point (fp16 or half), we use data type for -+elements in input matrix A and B as cutlass::half_t. Volta also supports accumulation of partial dot -+product to fp32, which can store wider range of numbers, we use it as data type of output matrix -+elements and accumulation. We convey this to CUTLASS kernel by initializing template variables -+ElementAccumulator (float), ElementComputeEpilogue (float), ElementInputA (cutlass::half_t), -+ElementInputB (cutlass::half_t), ElementOutput (float). Communicating just the data type is not -+enough. As the data is laid out linearly in memory, we have to convey the layout of matrices. We do -+that by initializing template variable LayoutInputA to column major cutlass variable, LayoutInputB -+to row major and LayoutOutput to row major. Next, we setup rules to compute alpha * X + beta * C -+which is called epilogue of the kernel. We initialize template variable EpilogueOp, which takes the -+data type of output ElementOutput (float), the number of elements per vector memory access (16), -+data type of accumulator (float) and data type of computation of linear combination (alpha * X + -+beta * C). -+ -+Now that we setup the properties of data, we have to setup properties of computation. -+ -+Second, we create template variables of tile sizes for thread-block, warp and mma-op to 128x128x32, -+64x64x4, 8x8x4 (MxNxK) respectively. When passed to instantiate CUTLASS GEMM kernel, it internally -+deduce the amount of threads needed per thread-block, amount of shared memory, storing data in -+bank-conflict free manner, and ton of other variables required to compose, initialize and launch a -+high performance GEMM kernel. This is the beauty of CUTLASS, it relieves developer from -+understanding and coding complicated hardware optimizations which can easily go wrong. -+ -+There are few more template variables initialized such as, which threadblock tile of output matrix -+is done which threadblock launched on an SM, CUDA SM architecture of GPU you want to run on. -+ -+These are all put together to create a template variable which describes CUTLASS GEMM kernel using -+cutlass::gemm::device::GemmSplitKParallel template. -+ -+The next step is to initialize physical data, instantiate and initialize CUTLASS kernel and run it. -+We use CUTLASS utilities to initialize, fill, compare matrices as they are simple and doesn't come -+in the way of learning CUTLASS. -+ -+Once all the matrices are initialized and filled with data, create arguments tuple to launch CUTLASS -+kernel which takes problem size (M = 5120, N = 4096 and K = 4096), matrices, alpha, beta and the -+important one, split k-dimension factor. Along with that, we query CUTLASS if any scratch-space -+memory required by the kernel we instantiated. If yes, we create it and pass it along with other -+arguments created to initialize CUTLASS kernel then, the kernel is launched. -+ -+In this example, we later on launch a reference gemm kernel (from CUTLASS utilities) to compare if -+the output from CUTLASS kernel is same as reference GEMM kernel. -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_splitk_parallel.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "helper.h" -+ -+// The code section below describes datatype for input, output matrices and computation between -+// elements in input matrices. -+using ElementAccumulator = float; // <- data type of accumulator -+using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations -+using ElementInputA = cutlass::half_t; // <- data type of elements in input matrix A -+using ElementInputB = cutlass::half_t; // <- data type of elements in input matrix B -+using ElementOutput = float; // <- data type of elements in output matrix D -+ -+// The code section below describes matrix layout of input and output matrices. Column Major for -+// Matrix A, Row Major for Matrix B and Row Major for Matrix C -+using LayoutInputA = cutlass::layout::ColumnMajor; -+using LayoutInputB = cutlass::layout::RowMajor; -+using LayoutOutput = cutlass::layout::RowMajor; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm70; -+ -+// This code section describes the tile size a thread block will compute -+using ShapeMMAThreadBlock = -+ cutlass::gemm::GemmShape<128, 128, 32>; // <- threadblock tile M = 128, N = 128, K = 32 -+// This code section describes tile size a warp will compute -+using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M = 64, N = 64, K = 32 -+// This code section describes the size of MMA op -+using ShapeMMAOp = cutlass::gemm::GemmShape<8, 8, 4>; // <- MMA Op tile M = 8, N = 8, K = 4 -+ -+// This code section describes ? -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // <- data type of output matrix -+ 128 / cutlass::sizeof_bits::value, // <- This is the number of elements per -+ // vectorized memory access. For half -+ // precision, it's 8 elements. This becomes -+ // the vector width of math instructions in -+ // epilogue too -+ ElementAccumulator, // <- data type of accumulator -+ ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function -+ -+// Put all the created template variables to create GemmSplitKParallel template variable -+using Gemm = cutlass::gemm::device::GemmSplitKParallel; -+ -+int run() { -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return -1; -+ } -+ -+ if (props.major != 7) { -+ std::cerr << "Volta Tensor Ops must be run on a machine with compute capability of 70, 72, or 75." -+ << std::endl; -+ -+ // Return 0 so tests pass if run on unsupported architectures or CUDA Toolkits. -+ return 0; -+ } -+ -+ // -+ // Define problem size -+ // -+ -+ const int length_m = 5120; -+ const int length_n = 4096; -+ const int length_k = 4096; -+ -+ // Create a tuple of problem size for matrix multiplication -+ cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k); -+ -+ // Initialize tensors using CUTLASS helper functions -+ cutlass::HostTensor tensor_a( -+ problem_size.mk()); // <- Create matrix A with dimensions M x K -+ cutlass::HostTensor tensor_b( -+ problem_size.kn()); // <- Create matrix B with dimensions K x N -+ cutlass::HostTensor tensor_c( -+ problem_size.mn()); // <- Create matrix C with dimensions M x N -+ cutlass::HostTensor tensor_d( -+ problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from -+ // CUTLASS kernel -+ cutlass::HostTensor tensor_ref_d( -+ problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from -+ // reference kernel -+ -+ // Fill input and output matrices on host using CUTLASS helper functions -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1, -+ ElementInputA(4), -+ ElementInputA(-4), -+ 0); // <- Fill matrix A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 1, -+ ElementInputB(4), -+ ElementInputB(-4), -+ 0); // <- Fill matrix B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c.host_view(), -+ 1, -+ ElementOutput(4), -+ ElementOutput(-4), -+ 0); // <- Fill matrix C on host with uniform-distribution random data -+ cutlass::reference::host::TensorFill( -+ tensor_d.host_view()); // <- fill matrix D on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_c.sync_device(); -+ tensor_d.sync_device(); -+ tensor_ref_d.sync_device(); -+ -+ // Initialize alpha and beta for dot product computation -+ ElementComputeEpilogue alpha = ElementComputeEpilogue(1); -+ ElementComputeEpilogue beta = ElementComputeEpilogue(0); -+ -+ // Split K dimension into 16 partitions -+ int split_k_slices = 16; -+ -+ // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch -+ // instantiated CUTLASS kernel -+ typename Gemm::Arguments arguments{problem_size, // <- problem size of matrix multiplication -+ tensor_a.device_ref(), // <- reference to matrix A on device -+ tensor_b.device_ref(), // <- reference to matrix B on device -+ tensor_c.device_ref(), // <- reference to matrix C on device -+ tensor_d.device_ref(), // <- reference to matrix D on device -+ {alpha, beta}, // <- tuple of alpha and beta -+ split_k_slices}; // <- k-dimension split factor -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ // Instantiate CUTLASS kernel depending on templates -+ Gemm gemm_op; -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(status); -+ -+ // Launch initialized CUTLASS kernel -+ status = gemm_op(); -+ CUTLASS_CHECK(status); -+ -+ // Create instantiation for device reference gemm kernel -+ cutlass::reference::device::Gemm -+ gemm_device; -+ -+ // Launch device reference gemm kernel -+ gemm_device(problem_size, -+ alpha, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ beta, -+ tensor_c.device_ref(), -+ tensor_ref_d.device_ref()); -+ -+ // Wait for kernels to finish -+ cudaDeviceSynchronize(); -+ -+ // Copy output data from CUTLASS and reference kernel to host for comparison -+ tensor_d.sync_host(); -+ tensor_ref_d.sync_host(); -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_d.host_view(), -+ tensor_ref_d.host_view()); -+ -+ std::cout << (passed ? "Passed" : "Failed") << std::endl; -+ -+ return (passed ? 0 : -1); -+} -+ -+int main() { -+ -+ // -+ // Volta Tensor Core operations exposed with mma.sync are first available in CUDA 10.1. -+ // -+ // CUTLASS must be compiled with CUDA 10.1 Toolkit to run these examples. -+ // -+ if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) { -+ std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl; -+ -+ // Returning zero, so this test passes when built with older CUDA Toolkits. Its action are no-op. -+ return 0; -+ } -+ else { -+ return run(); -+ } -+} -+ -diff --git a/3rdparty/cutlass/examples/07_volta_tensorop_gemm/volta_tensorop_gemm.cu b/3rdparty/cutlass/examples/07_volta_tensorop_gemm/volta_tensorop_gemm.cu -new file mode 100644 -index 0000000..c38f040 ---- /dev/null -+++ b/3rdparty/cutlass/examples/07_volta_tensorop_gemm/volta_tensorop_gemm.cu -@@ -0,0 +1,357 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+This example shows how to run matrix multiplication kernels using functions and data structures -+provided by CUTLASS using tensor cores; which we run on a NVIDIA Volta GPU. -+ -+Writing a single high performance matrix multiplication kernel is hard but do-able. Whereas writing -+high performance kernels at scale which works for multiple problem sizes with good abstractions is -+really hard. CUTLASS solves this problem by providing simplified abstractions to compose -+multiple sections of gemm kernel. When used properly, the kernels can hit peak performance of GPU -+easily. -+ -+CUTLASS divides a kernel into hierarchical composable sections. Which means, at each thread, warp -+and thread-block level, they compute on their own tile-size with higher level of tile sizes being -+composed from lower level ones. Multiple thread-tiles (tile size each thread computes) can be used -+to form warp-tiles (tile size each warp computes) and multiple warp tiles can be used to compute -+threadblock-tile (tile size computed by a threadblock). -+ -+In thie example, we split variable initialization into -+1. Setting up data properties : describes how matrices are laid out in the memory and how the kernel -+can view them (logical to physical mapping) -+2. Setting up computation properties : describes how the above set matrices will be used to compute -+output of matrix multiplication. -+ -+First, we setup the data types of matrices A, B, C and D along with alpha, beta as the equation for -+GEMM is D = alpha * A * B + beta * C. In CUTLASS, the kernels first compute A * B and leaves the -+rest of the computation to end of the kernel as alpha * X + beta * C is a simple element-wise -+operation on X (A * B) and C. We call this as epilogue of kernel. Hence, we setup data types for -+alpha and beta to be equal to ElementComputeEpilogue = float. As we want to MMA instructions on -+Volta and they support only half-precision floating point (fp16 or half), we use data type for -+elements in input matrix A and B as cutlass::half_t. Volta also supports accumulation of partial dot -+product to fp32, which can store wider range of numbers, we use it as data type of output matrix -+elements and accumulation. We convey this to CUTLASS kernel by initializing template variables -+ElementAccumulator (float), ElementComputeEpilogue (float), ElementInputA (cutlass::half_t), -+ElementInputB (cutlass::half_t), ElementOutput (float). Communicating just the data type is not -+enough. As the data is laid out linearly in memory, we have to convey the layout of matrices. We do -+that by initializing template variable LayoutInputA to column major cutlass variable, LayoutInputB -+to row major and LayoutOutput to row major. Next, we setup rules to comptue alpha * X + beta * C -+which is called epilogue of the kernel. We initialize template variable EpilogueOp, which takes the -+data type of output ElementOutput (int32_t), the number of elements per vector memory access (16), -+data type of accumulator (int32_t) and data type of computation of linear combination (alpha * X + -+beta * C). -+ -+Now that we setup the properties of data, we have to setup properties of computation. -+ -+Second, we create template variables of tile sizes for thread-block, warp and mma-op to 128x128x32, -+64x64x32, 8x8x4 (MxNxK) respectively. When passed to instantiate CUTLASS GEMM kernel, it internally -+deduce the amount of threads needed per thread-block, amount of shared memory, storing data in -+bank-conflict free manner, and ton of other variables required to compose, intialize and launch a -+high performance GEMM kernel. This is the beauty of CUTLASS, it relieves developer from -+understanding and coding complicated hardware optimizations which can easily go wrong. -+ -+CUTLASS also supports multiple MMA pipelines in a CTA. What are MMA pipelines? MMA pipelines -+constitute the whole process of loading input data from global memory to shared memory, loading data -+from shared memory to registers, doing matrix multiplication, store to global memory. The below flow -+sequence shows a typical mma pipeline. -+ -+matrix in global memory -> registers -> tile in shared memory -> registers -> mma -> registers -> -+output to global memory -+ -+The problem with single pipeline is, each stage is synchronous which means, each stage has to wait -+until the previous finished executing. There are stages in the pipeline which do not have fixed -+latency, for example, the loads from global memory and shared memory. Therefore, we can add one more -+pipeline with a phase shift in mma kernel to hide latency from global and shared memory loads. -+Finally, the pipeline in a kernel looks like -+ -+(1) matrix in global memory -> (2) registers -> (3) tile in shared memory -> (4) registers -> (5) -+mma -> (6) registers -> (7) output to global memory (1) -> (2) -> (3) matrix in global -+memory -> (4) registers -> (5) tile in shared memory -> (6) registers -> (7) mma -> (8) registers -> -+(9) output to global memory -+ -+This way, you can hide the second global memoroy load latency by doing computation on already loaded -+input data. -+ -+There are few more template variables initialized such as, which threadblock tile of output matrix -+is done which threadblock launched on an SM, CUDA SM architecture of GPU you want to run on. -+ -+These are all put together to create a template variable which describes CUTLASS GEMM kernel using -+cutlass::gemm::device::Gemm template. -+ -+The next step is to intialize physical data, instantiate and initialize CUTLASS kernel and run it. -+We use CUTLASS utilities to initialize, fill, compare matrices as they are simple and doesn't come -+in the way of learning CUTLASS. -+ -+Once all the matrices are initialized and filled with data, create arguments tuple to launch CUTLASS -+kernel which takes problem size (M = 5120, N = 4096 and K = 4096), matrices, alpha, beta and the -+important one, split k-dimension factor. Along with that, we query CUTLASS if any scratch-space -+memory required by the kernel we instantiated. If yes, we create it and pass it along with other -+arguments created to intialize CUTLASS kernel then, the kernel is launched. -+ -+In this example, we later on launch a reference gemm kernel (from CUTLASS utilities) to compare if -+the output from CUTLASS kernel is same as reference GEMM kernel. -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "helper.h" -+ -+// The code section below describes datatype for input, output matrices and computation between -+// elements in input matrices. -+using ElementAccumulator = float; // <- data type of accumulator -+using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations -+using ElementInputA = cutlass::half_t; // <- data type of elements in input matrix A -+using ElementInputB = cutlass::half_t; // <- data type of elements in input matrix B -+using ElementOutput = float; // <- data type of elements in output matrix D -+ -+// The code section below describes matrix layout of input and output matrices. Column Major for -+// Matrix A, Row Major for Matrix B and Row Major for Matrix C -+using LayoutInputA = cutlass::layout::ColumnMajor; -+using LayoutInputB = cutlass::layout::RowMajor; -+using LayoutOutput = cutlass::layout::RowMajor; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm70; -+ -+// This code section describes the tile size a thread block will compute -+using ShapeMMAThreadBlock = -+ cutlass::gemm::GemmShape<128, 128, 32>; // <- threadblock tile M = 128, N = 128, K = 32 -+// This code section describes tile size a warp will compute -+using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M = 64, N = 64, K = 32 -+// This code section describes the size of MMA op -+using ShapeMMAOp = cutlass::gemm::GemmShape<8, 8, 4>; // <- MMA Op tile M = 8, N = 8, K = 4 -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? -+ -+// This code section describes ? -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // <- data type of output matrix -+ 128 / cutlass::sizeof_bits::value, // <- this is the number of elements per -+ // vectorized memory access. For half -+ // precision, it's 8 elements. This becomes -+ // the vector width of math instructions in -+ // epilogue too -+ ElementAccumulator, // <- data type of accumulator -+ ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 2; -+ -+using Gemm = cutlass::gemm::device::Gemm; -+ -+int run() { -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return -1; -+ } -+ -+ if (props.major != 7) { -+ std::cerr << "Volta Tensor Ops must be run on a machine with compute capability of 70, 72, or 75." -+ << std::endl; -+ -+ // Return 0 so tests are considered passing if run on unsupported architectures or CUDA Toolkits. -+ return 0; -+ } -+ -+ const int length_m = 5120; -+ const int length_n = 4096; -+ const int length_k = 4096; -+ -+ // Create a tuple of problem size for matrix multiplication -+ cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k); -+ -+ // Initialize tensors using CUTLASS helper functions -+ cutlass::HostTensor tensor_a( -+ problem_size.mk()); // <- Create matrix A with dimensions M x K -+ cutlass::HostTensor tensor_b( -+ problem_size.kn()); // <- Create matrix B with dimensions K x N -+ cutlass::HostTensor tensor_c( -+ problem_size.mn()); // <- Create matrix C with dimensions M x N -+ cutlass::HostTensor tensor_d( -+ problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from -+ // CUTLASS kernel -+ cutlass::HostTensor tensor_ref_d( -+ problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from -+ // reference kernel -+ -+ // Fill input and output matrices on host using CUTLASS helper functions -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1, -+ ElementInputA(4), -+ ElementInputA(-4), -+ 0); // <- Fill matrix A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 1, -+ ElementInputB(4), -+ ElementInputB(-4), -+ 0); // <- Fill matrix B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c.host_view(), -+ 1, -+ ElementOutput(4), -+ ElementOutput(-4), -+ 0); // <- Fill matrix C on host with uniform-distribution random data -+ cutlass::reference::host::TensorFill( -+ tensor_d.host_view()); // <- fill matrix D on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_c.sync_device(); -+ tensor_d.sync_device(); -+ tensor_ref_d.sync_device(); -+ -+ // Initialize alpha and beta for dot product computation -+ ElementComputeEpilogue alpha = ElementComputeEpilogue(1); -+ ElementComputeEpilogue beta = ElementComputeEpilogue(0); -+ -+ // Split K dimension into 1 partitions -+ int split_k_slices = 1; -+ -+ // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch -+ // instantiated CUTLASS kernel -+ typename Gemm::Arguments arguments{problem_size, // <- problem size of matrix multiplication -+ tensor_a.device_ref(), // <- reference to matrix A on device -+ tensor_b.device_ref(), // <- reference to matrix B on device -+ tensor_c.device_ref(), // <- reference to matrix C on device -+ tensor_d.device_ref(), // <- reference to matrix D on device -+ {alpha, beta}, // <- tuple of alpha and beta -+ split_k_slices}; // <- k-dimension split factor -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ // Instantiate CUTLASS kernel depending on templates -+ Gemm gemm_op; -+ -+ // Check the problem size is supported or not -+ cutlass::Status status = gemm_op.can_implement(arguments); -+ CUTLASS_CHECK(status); -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ status = gemm_op.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(status); -+ -+ // Launch initialized CUTLASS kernel -+ status = gemm_op(); -+ CUTLASS_CHECK(status); -+ -+ // Create instantiation for device reference gemm kernel -+ cutlass::reference::device::Gemm -+ gemm_device; -+ -+ // Launch device reference gemm kernel -+ gemm_device(problem_size, -+ alpha, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ beta, -+ tensor_c.device_ref(), -+ tensor_ref_d.device_ref()); -+ -+ // Wait for kernels to finish -+ cudaDeviceSynchronize(); -+ -+ // Copy output data from CUTLASS and reference kernel to host for comparison -+ tensor_d.sync_host(); -+ tensor_ref_d.sync_host(); -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_d.host_view(), -+ tensor_ref_d.host_view()); -+ -+ std::cout << (passed ? "Passed" : "Failed") << std::endl; -+ -+ return (passed ? 0 : -1); -+} -+ -+int main() { -+ -+ // Volta Tensor Core operations exposed with mma.sync are first available in CUDA 10.1. -+ // -+ // CUTLASS must be compiled with CUDA 10.1 Toolkit to run these examples. -+ if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) { -+ std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl; -+ -+ // Returning zero when built on older Toolkits so tests pass. The actions of this SDK example are no-op. -+ return 0; -+ } -+ else { -+ return run(); -+ } -+} -+ -diff --git a/3rdparty/cutlass/examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu b/3rdparty/cutlass/examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu -new file mode 100644 -index 0000000..bcff579 ---- /dev/null -+++ b/3rdparty/cutlass/examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu -@@ -0,0 +1,358 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+This example shows how to run matrix multiplication kernels using functions and data structures -+provided by CUTLASS using tensor cores; which we run on a NVIDIA Turing GPU. -+ -+Writing a single high performance matrix multiplication kernel is hard but do-able. Whereas writing -+high performance kernels at scale which works for multiple problem sizes with good abstractions is -+really hard. CUTLASS solves this problem by providing simplified abstractions to compose -+multiple sections of gemm kernel. When used properly, the kernels can hit peak performance of GPU -+easily. -+ -+CUTLASS divides a kernel into hierarchical composable sections. Which means, at each thread, warp -+and thread-block level, they compute on their own tile-size with higher level of tile sizes being -+composed from lower level ones. Multiple thread-tiles (tile size each thread computes) can be used -+to form warp-tiles (tile size each warp computes) and multiple warp tiles can be used to compute -+threadblock-tile (tile size computed by a threadblock). -+ -+In thie example, we split variable initialization into -+1. Setting up data properties : describes how matrices are laid out in the memory and how the kernel -+can view them (logical to physical mapping) -+2. Setting up computation properties : describes how the above set matrices will be used to compute -+output of matrix multiplication. -+ -+First, we setup the data types of matrices A, B, C and D along with alpha, beta as the equation for -+GEMM is D = alpha * A * B + beta * C. In CUTLASS, the kernels first compute A * B and leaves the -+rest of the computation to end of the kernel as alpha * X + beta * C is a simple element-wise -+operation on X (A * B) and C. We call this as epilogue of kernel. Hence, we setup data types for -+alpha and beta to be equal to ElementComputeEpilogue = int32_t. As we want to use MMA instructions -+on Turing and they support 8-bit signed integer (int8_t), we use data type for elements in input -+matrix A and B as int8_t. Volta also supports accumulation of partial dot product to int32_t, which -+can store wider range of numbers, we use it as data type of output matrix elements and accumulation. -+We convey this to CUTLASS kernel by initializing template variables ElementAccumulator (int32_t), -+ElementComputeEpilogue (int32_t), ElementInputA (int8_t), ElementInputB (int8_t), ElementOutput -+(int32_t). Communicating just the data type is not enough. As the data is laid out linearly in -+memory, we have to convey the layout of matrices. We do that by initializing template variable -+LayoutInputA to column major cutlass variable, LayoutInputB to row major and LayoutOutput to row -+major. Next, we setup rules to comptue alpha * X + beta * C which is called epilogue of the kernel. -+We initialize template variable EpilogueOp, which takes the data type of output ElementOutput -+(int32_t), the number of elements per vector memory access (16), data type of accumulator (int32_t) -+and data type of computation of linear combination (alpha * X + beta * C). -+ -+Now that we setup the properties of data, we have to setup properties of computation. -+ -+Second, we create template variables of tile sizes for thread-block, warp and mma-op to 128x256x64, -+64x64x16, 8x8x16 (MxNxK) respectively. When passed to instantiate CUTLASS GEMM kernel, it internally -+deduce the amount of threads needed per thread-block, amount of shared memory, storing data in -+bank-conflict free manner, and ton of other variables required to compose, intialize and launch a -+high performance GEMM kernel. This is the beauty of CUTLASS, it relieves developer from -+understanding and coding complicated hardware optimizations which can easily go wrong. -+ -+CUTLASS also supports multiple MMA pipelines in a threadblock. What are MMA pipelines? MMA pipelines -+constitute the whole process of loading input data from global memory to shared memory, loading data -+from shared memory to registers, doing matrix multiplication, store to global memory. The below flow -+sequence shows a typical mma pipeline. -+ -+matrix in global memory -> registers -> tile in shared memory -> registers -> mma -> registers -> -+output to global memory -+ -+The problem with single pipeline is, each stage is synchronous which means, each stage has to wait -+until the previous finished executing. There are stages in the pipeline which do not have fixed -+latency, for example, the loads from global memory and shared memory. Therefore, we can add one more -+pipeline with a phase shift in mma kernel to hide latency from global and shared memory loads. -+Finally, the pipeline in a kernel looks like -+ -+(1) matrix in global memory -> (2) registers -> (3) tile in shared memory -> (4) registers -> (5) -+mma -> (6) registers -> (7) output to global memory (1) -> (2) -> (3) matrix in global -+memory -> (4) registers -> (5) tile in shared memory -> (6) registers -> (7) mma -> (8) registers -> -+(9) output to global memory -+ -+This way, you can hide the second global memoroy load latency by doing computation on already loaded -+input data. -+ -+There are few more template variables initialized such as, which threadblock tile of output matrix -+is done which threadblock launched on an SM, CUDA SM architecture of GPU you want to run on. -+ -+These are all put together to create a template variable which describes CUTLASS GEMM kernel using -+cutlass::gemm::device::Gemm template. -+ -+The next step is to intialize physical data, instantiate and initialize CUTLASS kernel and run it. -+We use CUTLASS utilities to initialize, fill, compare matrices as they are simple and doesn't come -+in the way of learning CUTLASS. -+ -+Once all the matrices are initialized and filled with data, create arguments tuple to launch CUTLASS -+kernel which takes problem size (M = 5120, N = 4096 and K = 4096), matrices, alpha, beta and the -+important one, split k-dimension factor. Along with that, we query CUTLASS if any scratch-space -+memory required by the kernel we instantiated. If yes, we create it and pass it along with other -+arguments created to intialize CUTLASS kernel then, the kernel is launched. -+ -+In this example, we later on launch a reference gemm kernel (from CUTLASS utilities) to compare if -+the output from CUTLASS kernel is same as reference GEMM kernel. -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "helper.h" -+ -+// The code section below describes datatype for input, output matrices and computation between -+// elements in input matrices. -+using ElementAccumulator = int32_t; // <- data type of accumulator -+using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations -+using ElementInputA = int8_t; // <- data type of elements in input matrix A -+using ElementInputB = int8_t; // <- data type of elements in input matrix B -+using ElementOutput = int32_t; // <- data type of elements in output matrix D -+ -+// The code section below describes matrix layout of input and output matrices. Column Major for -+// Matrix A, Row Major for Matrix B and Row Major for Matrix C -+using LayoutInputA = cutlass::layout::RowMajor; -+using LayoutInputB = cutlass::layout::ColumnMajor; -+using LayoutOutput = cutlass::layout::RowMajor; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm75; -+ -+// This code section describes the tile size a thread block will compute -+using ShapeMMAThreadBlock = -+ cutlass::gemm::GemmShape<128, 256, 64>; // <- threadblock tile M = 128, N = 256, K = 64 -+// This code section describes tile size a warp will compute -+using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 64>; // <- warp tile M = 64, N = 64, K = 64 -+// This code section describes the size of MMA op -+using ShapeMMAOp = cutlass::gemm::GemmShape<8, 8, 16>; // <- MMA Op tile M = 8, N = 8, K = 16 -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? -+ -+// This code section describes the epilogue part of the kernel -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // <- data type of output matrix -+ 128 / cutlass::sizeof_bits::value, // <- the number of elements per vectorized -+ // memory access. For a byte, it's 16 -+ // elements. This becomes the vector width of -+ // math instructions in the epilogue too -+ ElementAccumulator, // <- data type of accumulator -+ ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 2; -+ -+using Gemm = cutlass::gemm::device::Gemm; -+ -+int run() { -+ -+ const int length_m = 5120; -+ const int length_n = 4096; -+ const int length_k = 4096; -+ -+ // Create a tuple of problem size for matrix multiplication -+ cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k); -+ -+ // Initialize tensors using CUTLASS helper functions -+ cutlass::HostTensor tensor_a( -+ problem_size.mk()); // <- Create matrix A with dimensions M x K -+ cutlass::HostTensor tensor_b( -+ problem_size.kn()); // <- Create matrix B with dimensions K x N -+ cutlass::HostTensor tensor_c( -+ problem_size.mn()); // <- Create matrix C with dimensions M x N -+ cutlass::HostTensor tensor_d( -+ problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from -+ // CUTLASS kernel -+ cutlass::HostTensor tensor_ref_d( -+ problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from -+ // reference kernel -+ -+ // Fill input and output matrices on host using CUTLASS helper functions -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1, -+ ElementInputA(4), -+ ElementInputA(-4), -+ 0); // <- Fill matrix A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 1, -+ ElementInputB(4), -+ ElementInputB(-4), -+ 0); // <- Fill matrix B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c.host_view(), -+ 1, -+ ElementOutput(4), -+ ElementOutput(-4), -+ 0); // <- Fill matrix C on host with uniform-distribution random data -+ cutlass::reference::host::TensorFill( -+ tensor_d.host_view()); // <- fill matrix D on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_c.sync_device(); -+ tensor_d.sync_device(); -+ tensor_ref_d.sync_device(); -+ -+ // Initialize alpha and beta for dot product computation -+ ElementComputeEpilogue alpha = ElementComputeEpilogue(1); -+ ElementComputeEpilogue beta = ElementComputeEpilogue(0); -+ -+ // Split K dimension into 1 partitions -+ int split_k_slices = 1; -+ -+ // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch -+ // instantiated CUTLASS kernel -+ typename Gemm::Arguments arguments{problem_size, // <- problem size of matrix multiplication -+ tensor_a.device_ref(), // <- reference to matrix A on device -+ tensor_b.device_ref(), // <- reference to matrix B on device -+ tensor_c.device_ref(), // <- reference to matrix C on device -+ tensor_d.device_ref(), // <- reference to matrix D on device -+ {alpha, beta}, // <- tuple of alpha and beta -+ split_k_slices}; // <- k-dimension split factor -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ // Instantiate CUTLASS kernel depending on templates -+ Gemm gemm_op; -+ -+ // Check the problem size is supported or not -+ cutlass::Status status = gemm_op.can_implement(arguments); -+ CUTLASS_CHECK(status); -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ status = gemm_op.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(status); -+ -+ // Launch initialized CUTLASS kernel -+ status = gemm_op(); -+ CUTLASS_CHECK(status); -+ -+ // Create instantiation for device reference gemm kernel -+ cutlass::reference::device::Gemm -+ gemm_device; -+ -+ // Launch device reference gemm kernel -+ gemm_device(problem_size, -+ alpha, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ beta, -+ tensor_c.device_ref(), -+ tensor_ref_d.device_ref()); -+ -+ // Wait for kernels to finish -+ cudaDeviceSynchronize(); -+ -+ // Copy output data from CUTLASS and reference kernel to host for comparison -+ tensor_d.sync_host(); -+ tensor_ref_d.sync_host(); -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_d.host_view(), -+ tensor_ref_d.host_view()); -+ -+ std::cout << (passed ? "Passed" : "Failed") << std::endl; -+ -+ return (passed ? 0 : -1); -+} -+ -+int main() { -+ bool notSupported = false; -+ -+ // Turing Tensor Core operations exposed with mma.sync and ldmatrix are first available -+ // in CUDA 10.2. -+ // -+ // CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples. -+ if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) { -+ std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return -1; -+ } -+ -+ if (!((props.major * 10 + props.minor) >= 75)) { -+ std::cerr << "Turing Tensor Core operations must be run on a machine with compute capability at least 75." -+ << std::endl; -+ -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ // Returning zero so this test passes on older Toolkits. Its actions are no-op. -+ return 0; -+ } -+ -+ return run(); -+} -+ -diff --git a/3rdparty/cutlass/examples/09_turing_tensorop_conv2dfprop/turing_tensorop_conv2dfprop.cu b/3rdparty/cutlass/examples/09_turing_tensorop_conv2dfprop/turing_tensorop_conv2dfprop.cu -new file mode 100644 -index 0000000..e39784e ---- /dev/null -+++ b/3rdparty/cutlass/examples/09_turing_tensorop_conv2dfprop/turing_tensorop_conv2dfprop.cu -@@ -0,0 +1,771 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+ -+ -+This example shows how to run convolution kernels using functions and data structures -+provided by CUTLASS using tensor cores; which we run on a NVIDIA Turing GPU. -+ -+Writing a single high performance convolution kernel is hard but do-able. Whereas writing -+high performance kernels at scale which works for multiple problem sizes with good abstractions is -+really hard. CUTLASS solves this problem by providing simplified abstractions to compose -+multiple sections of implicit gemm kernel. When used properly, the kernels can hit peak performance -+of GPU easily. -+ -+CUTLASS divides a kernel into hierarchical composable sections. Which means, at each thread, warp -+and thread-block level, they compute on their own tile-size with higher level of tile sizes being -+composed from lower level ones. Multiple thread-tiles (tile size each thread computes) can be used -+to form warp-tiles (tile size each warp computes) and multiple warp tiles can be used to compute -+threadblock-tile (tile size computed by a threadblock). -+ -+In thie example, we split variable initialization into -+1. Setting up data properties : describes how tensors are laid out in the memory and how the kernel -+can view them (logical to physical mapping) -+2. Setting up computation properties : describes how the above set tensors will be used to compute -+output of convolution. -+ -+First, we setup the data types of the input tensor A, weights' tensor B and output tensor C along -+with alpha, beta as the equation for convolution is C = alpha * Conv(A, B) + beta * C. In CUTLASS, -+the kernels first compute Conv(A, B) and leave the rest of the computation to end of the kernel as -+alpha * X + beta * C is a simple element-wise operation on X (Conv(A, B)) and C. We call this as -+epilogue of kernel. Hence, we setup data types for alpha and beta to be equal to -+ElementComputeEpilogue = float. We want to use MMA instructions on Turing and they support 4-bit -+signed integer. But int4b_t is not fully supported by Nvidia software stack, so CUTLASS introduces -+cutlass::int4b_t. We use the data type for elements in input tensor A and B as cutlass::int4b_t. We -+convey this to CUTLASS kernel by initializing template variables ElementAccumulator (int32_t), -+ElementComputeEpilogue (float), ElementInputA (cutlass::int4b_t), ElementInputB (cutlass::int4b_t), -+ElementOutput (int32_t). Communicating just the data type is not enough. As the data is laid out -+linearly in memory, we have to convey the layout of tensors. We do that by initializing template -+variables LayoutInputA, LayoutInputB and LayoutOutput to TensorNHWC cutlass variable. Next, we setup -+rules to comptue alpha * X + beta * C which is called epilogue of the kernel. We initialize template -+variable EpilogueOp, which takes the data type of output ElementOutput (int32_t), the number of -+elements per vector memory access (32), data type of accumulator (int32_t) and data type of -+computation of linear combination (alpha * X + beta * C). -+ -+Now that we setup the properties of data, we have to setup properties of computation. -+ -+Second, we create template variables of tile sizes for thread-block, warp and mma-op to 128x128x128, -+64x64x128, 8x8x32 (MxNxK) respectively. When passed to instantiate CUTLASS Implicit GEMM kernel, it -+internally deduces the amount of threads needed per thread-block, amount of shared memory, storing -+data in bank-conflict free manner, and ton of other variables required to compose, intialize and -+launch a high performance Implicit GEMM kernel. This is the beauty of CUTLASS, it relieves developer -+from understanding and coding complicated hardware optimizations which can easily go wrong. -+ -+CUTLASS also supports multiple MMA pipelines in a threadblock. What are MMA pipelines? MMA pipelines -+constitute the whole process of loading input data from global memory to shared memory, loading data -+from shared memory to registers, doing matrix multiplication, store to global memory. The below flow -+sequence shows a typical mma pipeline. -+ -+tensor in global memory -> registers -> tile in shared memory -> registers -> mma -> registers -> -+output to global memory -+ -+The problem with single pipeline is, each stage is synchronous which means, each stage has to wait -+until the previous finished executing. There are stages in the pipeline which do not have fixed -+latency, for example, the loads from global memory and shared memory. Therefore, we can add one more -+pipeline with a phase shift in mma kernel to hide latency from global and shared memory loads. -+Finally, the pipeline in a kernel looks like -+ -+(1) tensor in global memory -> (2) registers -> (3) tile in shared memory -> (4) registers -> (5) -+mma -> (6) registers -> (7) output to global memory (1) -> (2) -> (3) tensor in global -+memory -> (4) registers -> (5) tile in shared memory -> (6) registers -> (7) mma -> (8) registers -> -+(9) output to global memory -+ -+This way, you can hide the second global memory load latency by doing computation on already loaded -+input data. -+ -+There are few more template variables initialized such as, which threadblock tile of output matrix -+is done which threadblock launched on an SM, CUDA SM architecture of GPU you want to run on. -+ -+These are all put together to create a template variable which describes CUTLASS Implicit GEMM -+kernel using cutlass::conv::device::ImplicitGemm template. -+ -+The next step is to intialize physical data, instantiate and initialize CUTLASS kernel and run it. -+We use CUTLASS utilities to initialize, fill, compare tensors as they are simple and doesn't come -+in the way of learning CUTLASS. -+ -+Once all the tensors are initialized and filled with data, create arguments tuple to launch CUTLASS -+kernel which takes problem size (N = 1, H = 64, W = 64, C = 128), filter size (K = 64, -+R = 3, S = 3, C = 128 ), padding, strides, dilation, tensors, alpha, beta and the -+important one, split k-dimension factor. Along with that, we query CUTLASS if any scratch-space -+memory required by the kernel we instantiated. If yes, we create it and pass it along with other -+arguments created to intialize CUTLASS kernel then, the kernel is launched. -+ -+In this example, we later on launch a reference convolution kernel (from CUTLASS utilities) to -+compare if the output from CUTLASS kernel is same as the reference implicit GEMM kernel. -+*/ -+ -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/convolution.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "helper.h" -+ -+// The code section below describes datatype for input, output tensors and computation between -+// elements -+using ElementAccumulator = int32_t; // Data type of accumulator -+using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta) -+using ElementInputA = cutlass::int4b_t; // Data type of elements in input tensor -+using ElementInputB = cutlass::int4b_t; // Data type of elements in input tensor -+using ElementOutput = cutlass::int4b_t; // Data type of elements in output tensor -+ -+using LayoutInputA = cutlass::layout::TensorNHWC; -+using LayoutInputB = cutlass::layout::TensorNHWC; -+using LayoutOutput = cutlass::layout::TensorNHWC; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm75; -+ -+// This code section describes the tile size a thread block will compute -+using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; // Threadblock tile shape -+ -+// This code section describes tile size a warp will compute -+using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; // Warp tile shape -+ -+// This code section describes the size of MMA op -+using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; // TensorCore instruction shape -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 2; -+ -+// This code section describes the epilogue part of the kernel, we use default value -+using EpilogueOp = cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, // Data type of output matrix. -+ 8, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. -+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue>; // Data type for alpha/beta in linear combination -+ -+ -+using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementInputA, LayoutInputA, -+ ElementInputB, LayoutInputB, -+ ElementOutput, LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+>::Kernel; -+ -+using ImplicitGemm = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ cutlass::Tensor4DCoord input_size; -+ cutlass::Tensor4DCoord filter_size; -+ cutlass::Tensor4DCoord padding; -+ cutlass::MatrixCoord conv_stride; -+ cutlass::MatrixCoord dilation; -+ bool reference_check; -+ bool measure_performance; -+ int iterations; -+ bool save_workspace; -+ ElementComputeEpilogue alpha; -+ ElementComputeEpilogue beta; -+ bool benchmark; -+ std::string tag; -+ -+ Options(): -+ help(false), -+ input_size(1, 32, 32, 32), -+ filter_size(32, 3, 3, 32), -+ padding(1, 1, 1, 1), -+ conv_stride(1, 1), -+ dilation(1, 1), -+ reference_check(false), -+ measure_performance(true), -+ iterations(20), -+ save_workspace(false), -+ alpha(1), -+ beta(0), -+ benchmark(false) { } -+ -+ // Verify the problem size is compatible with the CUTLASS Convolution implementation. -+ bool valid() { -+ -+ // -+ // CUTLASS attempts to load 128b vectors of int4b_t elements. Consequently, -+ // all pointers, strides, and tensor extents must be divisible by 32 elements. -+ // -+ int const kAlignment = 32; -+ -+ if ((input_size.c() % kAlignment) || -+ (filter_size.n() % kAlignment)) { -+ -+ // misaligned tensors -+ return false; -+ } -+ -+ // Invalid padding -+ if ((padding.h() != filter_size.h() / 2) || -+ (padding.w() != filter_size.w() / 2)) { -+ -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Updates input and filter sizes -+ void update( -+ cutlass::Tensor4DCoord input_size, -+ cutlass::Tensor4DCoord filter_size) { -+ -+ this->input_size = input_size; -+ this->filter_size = filter_size; -+ -+ padding.n() = filter_size.h() / 2; -+ padding.h() = filter_size.h() / 2; -+ padding.w() = filter_size.w() / 2; -+ padding.c() = filter_size.w() / 2; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("ref-check")) { -+ reference_check = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("perf-check")) { -+ measure_performance = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("save-workspace")) { -+ save_workspace = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("benchmark")) { -+ benchmark = true; -+ } -+ -+ cmd.get_cmd_line_argument("n", input_size.n()); -+ cmd.get_cmd_line_argument("h", input_size.h()); -+ cmd.get_cmd_line_argument("w", input_size.w()); -+ cmd.get_cmd_line_argument("c", input_size.c()); -+ -+ cmd.get_cmd_line_argument("k", filter_size.n()); -+ cmd.get_cmd_line_argument("r", filter_size.h()); -+ cmd.get_cmd_line_argument("s", filter_size.w()); -+ filter_size.c() = input_size.c(); -+ -+ cmd.get_cmd_line_argument("alpha", alpha); -+ cmd.get_cmd_line_argument("beta", beta); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ cmd.get_cmd_line_argument("tag", tag); -+ -+ if (filter_size.h() == 3 && filter_size.w() == 3) { -+ padding = {1, 1, 1, 1}; -+ } -+ else { -+ filter_size.h() = 1; -+ filter_size.w() = 1; -+ padding = {0, 0, 0, 0}; -+ } -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "09_turing_tensorop_conv2dfprop example\n\n" -+ << " This example uses Turing's Tensor Core operators on int4 data types to compute\n" -+ << " forward convolution on tensors of layout NHWC.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --n= Input tensor extent N\n" -+ << " --h= Input tensor extent H\n" -+ << " --w= Input tensor extent W\n" -+ << " --c= Input tensor extent C\n" -+ << " --k= Filter extent K\n" -+ << " --r= Filter extent R\n" -+ << " --s= Filter extent S\n\n" -+ << " --alpha= Epilogue scalar alpha\n" -+ << " --beta= Epilogue scalar beta\n\n" -+ << " --ref-check If set (true), reference check on the host is computed\n" -+ << " --perf-check If set (true), performance is measured.\n" -+ << " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n" -+ << " --iterations= Number of profiling iterations to perform.\n" -+ << " --save-workspace If set, workspace is written to a text file.\n" -+ << " --tag= String to replicate across the first column in the results table\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/09_turing_tensorop_conv2dfprop/09_turing_tensorop_conv2dfprop --n=32 --h=224 --w=224 --c=128 --k=256 --r=1 --s=1\n\n" -+ << "$ ./examples/09_turing_tensorop_conv2dfprop/09_turing_tensorop_conv2dfprop --n=1 --h=224 --w=224 --c=32 --k=32 --r=3 --s=3 --ref-check\n\n"; -+ -+ return out; -+ } -+ -+ /// Computes the output tensor size (NPQK) -+ cutlass::Tensor4DCoord output_size() const { -+ return cutlass::Tensor4DCoord( -+ input_size.n(), -+ (input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1, -+ (input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1, -+ filter_size.n()); -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of multiply-adds = NPQK * CRS -+ int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c()); -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct Result { -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cutlass::Status reference_check; -+ cudaError_t error; -+ -+ Result(): -+ runtime_ms(0), -+ gflops(0), -+ status(cutlass::Status::kSuccess), -+ reference_check(cutlass::Status::kInvalid), -+ error(cudaSuccess) { } -+ -+ static std::ostream & print_header(std::ostream &out, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << "Name,"; -+ } -+ -+ out << "Layer,N,H,W,C,K,R,S,Runtime,GFLOPs"; -+ -+ return out; -+ } -+ -+ std::ostream & print(std::ostream &out, int idx, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << options.tag << ","; -+ } -+ -+ out -+ << "conv_" << idx << "," -+ << options.input_size.n() << "," -+ << options.input_size.h() << "," -+ << options.input_size.w() << "," -+ << options.input_size.c() << "," -+ << options.filter_size.n() << "," -+ << options.filter_size.h() << "," -+ << options.filter_size.w() << "," -+ << runtime_ms << "," -+ << gflops; -+ -+ return out; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Runs one benchmark -+Result profile_convolution(Options const &options) { -+ -+ Result result; -+ -+ // -+ // Allocate host-device tensors using the CUTLASS Utilities. -+ // -+ -+ cutlass::HostTensor tensor_a(options.input_size); -+ cutlass::HostTensor tensor_b(options.filter_size); -+ cutlass::HostTensor tensor_c(options.output_size()); -+ cutlass::HostTensor tensor_ref_c(options.output_size()); -+ -+ // -+ // Initialize tensors -+ // -+ -+ // Fill tensor A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1, -+ ElementInputA(7), -+ ElementInputA(-8), -+ 0); -+ -+ // Fill tensor B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 1, -+ ElementInputB(7), -+ ElementInputB(-8), -+ 0); -+ -+ // Fill tensor C on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_c.host_view()); -+ -+ // Fill tensor C for reference on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_ref_c.host_view()); -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_c.sync_device(); -+ tensor_ref_c.sync_device(); -+ -+ // -+ // Define arguments for CUTLASS Convolution -+ // -+ -+ // mode (kCrossCorrelation or kConvolution) -+ cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; -+ -+ // Split K dimension into 1 partitions -+ int split_k_slices = 1; -+ -+ // Construct Conv2dProblemSize with user defined output size -+ cutlass::conv::Conv2dProblemSize problem_size( -+ options.input_size, -+ options.filter_size, -+ options.padding, -+ options.conv_stride, -+ options.dilation, -+ options.output_size(), -+ mode, -+ split_k_slices); -+ -+ // Construct ImplicitGemm::Argument structure with conv2d -+ // problem size, data pointers, and epilogue values -+ typename ImplicitGemm::Arguments arguments{ -+ problem_size, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ tensor_c.device_ref(), -+ tensor_c.device_ref(), -+ {options.alpha, options.beta}, -+ }; -+ -+ // -+ // Initialize CUTLASS Convolution -+ // -+ -+ ImplicitGemm implicit_gemm_op; -+ -+ size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ result.status = implicit_gemm_op.can_implement(arguments); -+ CUTLASS_CHECK(result.status); -+ -+ result.status = implicit_gemm_op.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Launch initialized CUTLASS kernel -+ // -+ result.status = implicit_gemm_op(); -+ -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Optional reference check -+ // -+ -+ if (options.reference_check) { -+ std::cout << "Verification on host...\n"; -+ -+ // Compute with reference implementation -+ cutlass::reference::host::Conv2dFprop< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementComputeEpilogue, -+ ElementAccumulator, -+ cutlass::NumericConverterClamp -+ >( -+ problem_size, -+ tensor_a.host_ref(), -+ tensor_b.host_ref(), -+ tensor_c.host_ref(), -+ tensor_ref_c.host_ref(), -+ options.alpha, -+ options.beta -+ ); -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ tensor_c.sync_host(); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_c.host_view(), -+ tensor_ref_c.host_view()); -+ -+ if (!passed) { -+ result.reference_check = cutlass::Status::kErrorInternal; -+ std::cout << "ERROR - results miscompared.\n"; -+ } -+ else { -+ result.reference_check = cutlass::Status::kSuccess; -+ std::cout << "Passed.\n"; -+ } -+ } -+ else { -+ result.reference_check = cutlass::Status::kInvalid; -+ } -+ -+ if (options.save_workspace) { -+ -+ std::stringstream ss; -+ -+ ss << "09_tensor_conv_workspace_conv2dfprop_" -+ << options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c() -+ << "_" -+ << options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c() -+ << ".dat"; -+ -+ std::ofstream output_workspace(ss.str()); -+ -+ output_workspace -+ << "Input = \n" << tensor_a.host_view() << "\n\n" -+ << "Filters = \n" << tensor_b.host_view() << "\n\n"; -+ -+ if (options.reference_check) { -+ output_workspace << "Reference = \n" << tensor_ref_c.host_view() << "\n\n"; -+ } -+ -+ output_workspace << "Computed = \n" << tensor_c.host_view() << std::endl; -+ -+ std::cout << "Results written to '" << ss.str() << "'." << std::endl; -+ } -+ -+ // -+ // Performance measurement -+ // -+ -+ if (options.measure_performance) { -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ } -+ -+ // Record an event at the start of a series of convolution operations. -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Launch a sequence of implicit GEMM operations on the device -+ for (int iteration = 0; iteration < options.iterations; ++iteration) { -+ result.status = implicit_gemm_op(); -+ CUTLASS_CHECK(result.status); -+ } -+ -+ // Record an event when the convolutions have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Print average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // Cleanup -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ } -+ -+ return result; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ -+ // Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2. -+ // -+ // CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples. -+ if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) { -+ std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl; -+ return 0; -+ } -+ -+ cudaDeviceProp props; -+ CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); -+ -+ if (!(props.major > 7 || (props.major == 7 && props.minor >= 5))) { -+ std::cerr << "Turing Tensor Ops must be run on a machine with compute capability at least 75." -+ << std::endl; -+ return 0; -+ } -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ if (options.benchmark) { -+ // Benchmark several layers -+ -+ int batch_sizes[] = {1, 32, 64, 128, 256, 512}; -+ -+ struct Benchmark { -+ int h, w, c, k, r, s; -+ } layers[] = { -+ {56, 56, 64, 256, 1, 1}, -+ {56, 56, 64, 64, 1, 1}, -+ {56, 56, 64, 64, 3, 3}, -+ {56, 56, 256, 64, 1, 1}, -+ {56, 56, 256, 512, 1, 1}, -+ {56, 56, 256, 128, 1, 1}, -+ {28, 28, 128, 128, 3, 3}, -+ {28, 28, 128, 512, 1, 1}, -+ {28, 28, 512, 128, 1, 1}, -+ {28, 28, 512, 1024, 1, 1}, -+ {28, 28, 512, 256, 1, 1}, -+ {14, 14, 256, 256, 3, 3}, -+ {14, 14, 256, 1024, 1, 1}, -+ {14, 14, 1024, 256, 1, 1}, -+ {14, 14, 1024, 2048, 1, 1}, -+ {14, 14, 1024, 512, 1, 1}, -+ {7, 7, 512, 512, 3, 3}, -+ }; -+ -+ Result::print_header(std::cout, options) << std::endl; -+ -+ int idx = 1; -+ -+ for (auto const &layer : layers) { -+ for (auto N : batch_sizes) { -+ -+ options.update({N, layer.h, layer.w, layer.c}, {layer.k, layer.r, layer.s, layer.c}); -+ -+ Result result = profile_convolution(options); -+ result.print(std::cout, idx, options) << std::endl; -+ } -+ -+ ++idx; -+ } -+ } -+ else { -+ -+ // Execute one problem size -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ Result result = profile_convolution(options); -+ -+ Result::print_header(std::cout, options) << std::endl; -+ result.print(std::cout, 1, options) << std::endl; -+ } -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+ -diff --git a/3rdparty/cutlass/examples/10_planar_complex/planar_complex.cu b/3rdparty/cutlass/examples/10_planar_complex/planar_complex.cu -new file mode 100644 -index 0000000..9e0915d ---- /dev/null -+++ b/3rdparty/cutlass/examples/10_planar_complex/planar_complex.cu -@@ -0,0 +1,567 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Planar Complex GEMM -+ -+ This example demonstrates the CUTLASS Library's exposure of planar complex GEMM kernels supporting -+ the batched strided mode. -+ -+ These kernels represent complex matrices by storing the real and imaginary parts of the matrix in -+ disjoint regions in memory. These real-valued matrices are stored using existing cuBLAS layouts -+ as either column-major or row-major layouts with a single leading dimension indicating the stride -+ between columns or rows. -+ -+ The CUTLASS Library collects multiple template instantiations in a data structure and offers -+ a BLAS-like dispatch API to invoke the appropriate kernel on the Volta or Turing architectures. -+ -+ CUTLASS decouples matrix layout from complex transformation, so four possible transformations -+ are possible on the A and B operands: -+ -+ n: column-major -+ c: column-major complex conjugate -+ t: row-major -+ h: row-major complex conjugate -+ -+ The CUTLASS Library contains many kernel instances specialized for architecture, data type, tile -+ size, and alignment. This can result in long compile times. -+ -+ To build strictly the planar complex kernels needed for general application, execute the following -+ CMake command in an empty build directory. -+ -+ $ cmake .. -DCUTLASS_NVCC_ARCHS="70;75;80" \ -+ -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_*gemm_planar_complex -+ -+ This builds all planar complex GEMM variants for Volta and Turing architectures. -+ -+ To build strictly the kernels needed for this example, an even narrower filter string may be -+ specified as follows. This only builds planar complex GEMMs targeting Tensor Cores for -+ the 'CN' layout configuration (conjugate A operand with both A and B as column-major). -+ -+ $ cmake .. -DCUTLASS_NVCC_ARCHS="70;75;80" \ -+ -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s*gemm_planar_complex_f16*cn -+ -+ $ make 10_planar_complex -+ -+ $ ./examples/10_planar_complex/10_planar_complex --m=2048 --n=1024 --k=512 --batch=10 -+*/ -+ -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/device_memory.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/host_tensor_planar_complex.h" -+ -+#include "cutlass/util/reference/device/tensor_fill.h" -+ -+#include "cutlass/util/reference/device/gemm_planar_complex.h" -+#include "cutlass/util/reference/device/tensor_compare.h" -+ -+#include "cutlass/library/handle.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Result structure -+struct Result { -+ -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cudaError_t error; -+ bool passed; -+ -+ // -+ // Methods -+ // -+ -+ Result( -+ double runtime_ms = 0, -+ double gflops = 0, -+ cutlass::Status status = cutlass::Status::kSuccess, -+ cudaError_t error = cudaSuccess -+ ): -+ runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ -+ cutlass::gemm::GemmCoord problem_size; -+ int batch_count; -+ cutlass::complex alpha; -+ cutlass::complex beta; -+ -+ bool reference_check; -+ int iterations; -+ -+ Options(): -+ help(false), -+ problem_size({1024, 1024, 1024}), -+ batch_count(1), -+ reference_check(true), -+ iterations(20), -+ alpha(1), -+ beta() { } -+ -+ bool valid() { -+ return true; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ cmd.get_cmd_line_argument("m", problem_size.m()); -+ cmd.get_cmd_line_argument("n", problem_size.n()); -+ cmd.get_cmd_line_argument("k", problem_size.k()); -+ cmd.get_cmd_line_argument("batch", batch_count); -+ -+ cmd.get_cmd_line_argument("alpha", alpha.real()); -+ cmd.get_cmd_line_argument("alpha_i", alpha.imag()); -+ cmd.get_cmd_line_argument("beta", beta.real()); -+ cmd.get_cmd_line_argument("beta_i", beta.imag()); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "10_planar_complex example\n\n" -+ << " This example uses the CUTLASS Library to execute Planar Complex GEMM computations.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --m= GEMM M dimension\n" -+ << " --n= GEMM N dimension\n" -+ << " --k= GEMM K dimension\n" -+ << " --batch= Number of GEMM operations executed in one batch\n" -+ << " --alpha= Epilogue scalar alpha (real part)\n" -+ << " --alpha_i= Epilogue scalar alpha (imaginary part)\n" -+ << " --beta= Epilogue scalar beta (real part)\n\n" -+ << " --beta_i= Epilogue scalar beta (imaginary part)\n\n" -+ << " --iterations= Number of profiling iterations to perform.\n\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/10_planar_complex/10_planar_complex --batch=7 --m=1024 --n=512 --k=1024 \\\n" -+ << " --alpha=2 --alpha_i=-2 --beta=0.707 --beta_i=-.707\n\n"; -+ -+ return out; -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of real-valued multiply-adds -+ int64_t fmas = problem_size.product() * batch_count * 4; -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Performance test environment for planar complex -+class TestbedPlanarComplex { -+public: -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementCompute = float; -+ using ElementAccumulator = float; -+ -+ // -+ // Data members -+ // -+ -+ cutlass::library::Handle handle; -+ -+ cutlass::gemm::GemmCoord problem_size; -+ int batch_count; -+ cutlass::DeviceAllocation tensor_A; -+ cutlass::DeviceAllocation tensor_B; -+ cutlass::DeviceAllocation tensor_C; -+ cutlass::DeviceAllocation tensor_D; -+ cutlass::DeviceAllocation tensor_D_ref; -+ -+ // -+ // Methods -+ // -+ -+ TestbedPlanarComplex( -+ Options const &options -+ ): -+ problem_size(options.problem_size), batch_count(options.batch_count) { -+ -+ // Allocate device memory for batched strided GEMM -+ tensor_A.reset(int64_t(problem_size.m()) * problem_size.k() * batch_count * 2); -+ tensor_B.reset(int64_t(problem_size.k()) * problem_size.n() * batch_count * 2); -+ tensor_C.reset(int64_t(problem_size.m()) * problem_size.n() * batch_count * 2); -+ tensor_D.reset(int64_t(problem_size.m()) * problem_size.n() * batch_count * 2); -+ tensor_D_ref.reset(int64_t(problem_size.m()) * problem_size.n() * batch_count * 2); -+ } -+ -+ void initialize() { -+ -+ uint64_t seed = 1073; -+ -+ // Use small integers to simplify correctness checking -+ int scope_max = 6; -+ int scope_min = -6; -+ -+ cutlass::reference::device::BlockFillRandomUniform( -+ tensor_A.get(), tensor_A.size(), seed, ElementA(scope_max), ElementA(scope_min), 0); -+ -+ cutlass::reference::device::BlockFillRandomUniform( -+ tensor_B.get(), tensor_B.size(), seed * 2019, ElementB(scope_max), ElementB(scope_min), 0); -+ -+ cutlass::reference::device::BlockFillRandomUniform( -+ tensor_C.get(), tensor_C.size(), seed * 2020, ElementC(scope_max), ElementC(scope_min), 0); -+ } -+ -+ Result profile(Options const &options) { -+ -+ Result result; -+ -+ initialize(); -+ -+ ElementA *ptr_A = tensor_A.get(); -+ ElementB *ptr_B = tensor_B.get(); -+ ElementC *ptr_C = tensor_C.get(); -+ ElementC *ptr_D = tensor_D.get(); -+ -+ int64_t batch_stride_A = int64_t(problem_size.m()) * problem_size.k() * 2; -+ int64_t batch_stride_B = int64_t(problem_size.k()) * problem_size.n() * 2; -+ int64_t batch_stride_C = int64_t(problem_size.m()) * problem_size.n() * 2; -+ int64_t batch_stride_D = int64_t(problem_size.m()) * problem_size.n() * 2; -+ -+ typename LayoutA::Stride::Index lda = LayoutA::packed({problem_size.m(), problem_size.k()}).stride(0); -+ typename LayoutB::Stride::Index ldb = LayoutB::packed({problem_size.k(), problem_size.n()}).stride(0); -+ typename LayoutC::Stride::Index ldc = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0); -+ typename LayoutC::Stride::Index ldd = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0); -+ -+ int64_t imag_stride_A = int64_t(problem_size.m()) * problem_size.k(); -+ int64_t imag_stride_B = int64_t(problem_size.k()) * problem_size.n(); -+ int64_t imag_stride_C = int64_t(problem_size.m()) * problem_size.n(); -+ int64_t imag_stride_D = int64_t(problem_size.m()) * problem_size.n(); -+ -+ // -+ // Construct events -+ // -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ } -+ -+ // Record an event at the start of a series of GEMMs -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // -+ // Run profiling loop -+ // -+ -+ for (int iter = 0; iter < options.iterations; ++iter) { -+ -+ // -+ // Execute the planar complex GEMM kernel via the CUTLASS Library's -+ // dispatch routines. -+ // -+ // Note, for planar complex GEMM kernels, all numeric type arguments -+ // specify the data type of the base real types. These are understood to -+ // apply to planar complex representations of matrices in memory and to complex -+ // structures for scalars. -+ // -+ // See tools/library/include/cutlass/library/handle.h for more details. -+ // -+ -+ result.status = handle.gemm_planar_complex( -+ problem_size.m(), // GEMM M dimension -+ problem_size.n(), // GEMM N dimension -+ problem_size.k(), // GEMM K dimension -+ -+ cutlass::library::NumericTypeID::kF32, // Base data type of complex-valued accumulation -+ cutlass::library::NumericTypeID::kF32, // Base data type of complex-valued alpha/beta scalars -+ -+ &options.alpha, // Pointer to alpha scalar, of type complex -+ -+ cutlass::library::NumericTypeID::kF16, // Base data type of complex-valued A matrix -+ cutlass::library::LayoutTypeID::kColumnMajor, // Layout of A matrix -+ cutlass::library::ComplexTransform::kConjugate, // Complex transformation on A matrix operand -+ ptr_A, // Pointer to real part of A matrix -+ ptr_A + imag_stride_A, // Pointer to imaginary part of A matrix -+ lda, // Leading dimension of real part of A matrix -+ lda, // Leading dimension of imaginary part of A matrix -+ -+ cutlass::library::NumericTypeID::kF16, // Base data type of complex-valued B matrix -+ cutlass::library::LayoutTypeID::kColumnMajor, // Layout of B matrix -+ cutlass::library::ComplexTransform::kNone, // Complex transformation on B matrix operand -+ ptr_B, // Pointer to real part of B matrix -+ ptr_B + imag_stride_B, // Pointer to imaginary part of B matrix -+ ldb, // Leading dimension of real part of B matrix -+ ldb, // Leading dimension of imaginary part of B matrix -+ -+ &options.beta, // Pointer to beta scalar, of type complex -+ -+ cutlass::library::NumericTypeID::kF16, // Base data type of complex valued C and D matrices -+ -+ ptr_C, // Pointer to real part of C matrix -+ ptr_C + imag_stride_C, // Pointer to imaginary part of C matrix -+ ldc, // Leading dimension of real part of C matrix -+ ldc, // Leading dimension of imaginary part of C matrix -+ -+ ptr_D, // Pointer to real part of D matrix -+ ptr_D + imag_stride_D, // Pointer to imaginary part of D matrix -+ ldd, // Leading dimension of real part of D matrix -+ ldd, // Leading dimension of imaginary part of D matrix -+ -+ batch_count, // Number of batched elements -+ -+ batch_stride_A, // Stride between batches of real parts of A matrix -+ batch_stride_A, // Stride between batches of imaginary parts of A matrix -+ -+ batch_stride_B, // Stride between batches of real parts of B matrix -+ batch_stride_B, // Stride between batches of imaginary parts of B matrix -+ -+ batch_stride_C, // Stride between batches of real parts of C matrix -+ batch_stride_C, // Stride between batches of imaginary parts of C matrix -+ -+ batch_stride_D, // Stride between batches of real parts of D matrix -+ batch_stride_D // Stride between batches of imaginary parts of D matrix -+ ); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "CUTLASS internal error - configuration not supported" << std::endl; -+ return result; -+ } -+ } -+ -+ // -+ // Stop profiling loop -+ // -+ -+ // Record an event when the GEMMs are complete -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Compute average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // Cleanup -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ if (handle.get_last_operation()) { -+ std::cout << "Recently executed '" << handle.get_last_operation()->description().name << "'" << std::endl; -+ } -+ -+ // -+ // Compute reference in device code -+ // -+ -+ if (options.reference_check) { -+ -+ result.passed = true; -+ -+ for (int64_t idx = 0; result.passed && idx < int64_t(batch_count); ++idx) { -+ cutlass::reference::device::GemmPlanarComplex< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementAccumulator -+ >( -+ problem_size, -+ options.alpha, -+ {tensor_A.get() + idx * batch_stride_A, lda, imag_stride_A}, -+ cutlass::ComplexTransform::kConjugate, -+ {tensor_B.get() + idx * batch_stride_B, ldb, imag_stride_B}, -+ cutlass::ComplexTransform::kNone, -+ options.beta, -+ {tensor_C.get() + idx * batch_stride_C, ldc, imag_stride_C}, -+ {tensor_D_ref.get() + idx * batch_stride_D, ldd, imag_stride_D} -+ ); -+ -+ ElementC epsilon = 0.1_hf; -+ ElementC nonzero_floor = 0.1_hf; -+ -+ result.passed = cutlass::reference::device::BlockCompareRelativelyEqual( -+ tensor_D.get() + idx * batch_stride_D, -+ tensor_D_ref.get() + idx * batch_stride_D, -+ batch_stride_D, -+ epsilon, -+ nonzero_floor -+ ); -+ } -+ -+ if (result.passed) { -+ std::cout << "Reference check passed." << std::endl; -+ } -+ else { -+ std::cerr << "Error - reference check failed." << std::endl; -+ } -+ } -+ -+ std::cout << "Runtime: " << result.runtime_ms << " ms" << std::endl; -+ std::cout << " GFLOPs: " << result.gflops << std::endl; -+ -+ return result; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ -+ // -+ // This example uses mma.sync to directly access Tensor Cores to achieve peak performance. -+ // -+ // Volta Tensor Core operations are first available in CUDA 10.1 Toolkit. -+ // -+ // Turing Tensor Core operations are first available in CUDA 10.2 Toolkit. -+ // -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return -1; -+ } -+ -+ if (props.major < 7) { -+ std::cerr << "Volta Tensor Core operations must be run on a machine with compute capability at least 70." -+ << std::endl; -+ -+ // Returning zero so this test passes on older architectures even though its actions are no-op. -+ return 0; -+ } -+ else if (props.major == 7 && props.minor <= 2) { -+ // -+ // If running on the Volta architecture, at least CUDA 10.1 Toolkit is required to run this example. -+ // -+ if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) { -+ std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl; -+ -+ // Returning zero so this test passes on older Toolkits even though its actions are no-op. -+ return 0; -+ } -+ } -+ else if (props.major == 7 && props.minor >= 5) { -+ // -+ // If running on the Turing architecture, at least CUDA 10.2 Toolkit is required to run this example. -+ // -+ if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) { -+ std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl; -+ -+ // Returning zero so this test passes on older Toolkits even though its actions are no-op. -+ return 0; -+ } -+ } -+ else { -+ // NVIDIA Ampere Architecture GPUs (SM80 and later) are fully supported on CUDA 11 Toolkit and beyond. -+ // -+ // fall through -+ } -+ -+ // -+ // Parse options -+ // -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ // Execute one problem size -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ TestbedPlanarComplex testbed(options); -+ -+ Result result = testbed.profile(options); -+ -+ return result.passed ? 0 : -1; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/examples/11_planar_complex_array/planar_complex_array.cu b/3rdparty/cutlass/examples/11_planar_complex_array/planar_complex_array.cu -new file mode 100644 -index 0000000..e317731 ---- /dev/null -+++ b/3rdparty/cutlass/examples/11_planar_complex_array/planar_complex_array.cu -@@ -0,0 +1,628 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Planar Complex Array Example -+ -+ This example demonstrates the CUTLASS Library's exposure of planar complex GEMM kernels which -+ execute a batch of matrix products, loading problem sizes and matrix base pointers from arrays -+ in global memory. -+ -+ These kernels represent complex matrices by storing the real and imaginary parts of the matrix in -+ disjoint regions in memory. These real-valued matrices are stored using existing cuBLAS layouts -+ as either column-major or row-major layouts with a single leading dimension indicating the stride -+ between columns or rows. -+ -+ The CUTLASS Library collects multiple template instantiations in a data structure and offers -+ a BLAS-like dispatch API to invoke the appropriate kernel on the Volta or Turing architectures. -+ -+ CUTLASS decouples matrix layout from complex transformation, so four possible transformations -+ are possible on the A and B operands: -+ -+ n: column-major -+ c: column-major complex conjugate -+ t: row-major -+ h: row-major complex conjugate -+ -+ To build strictly the planar complex kernels needed for general application, execute the following -+ CMake command in an empty build directory. -+ -+ $ cmake .. -DCUTLASS_NVCC_ARCHS="70;75;80" \ -+ -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_*gemm_planar_complex -+ -+ This builds all planar complex GEMM variants for Volta and Turing architectures. -+ -+ To build strictly the kernels needed for this example, an even narrower filter string may be -+ specified as follows. This only builds planar complex GEMMs targeting Tensor Cores for -+ the 'CN' layout configuration (conjugate A operand with both A and B as column-major). -+ -+ $ cmake .. -DCUTLASS_NVCC_ARCHS="70;75;80" \ -+ -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s*gemm_planar_complex_array_f16*cn -+ -+ $ make 11_planar_complex_array -+ -+ $ ./examples/11_planar_complex_array/11_planar_complex_array --m=2048 --n=1024 --k=512 --batch=10 -+*/ -+ -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/device_memory.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/host_tensor_planar_complex.h" -+ -+#include "cutlass/util/reference/device/tensor_fill.h" -+ -+#include "cutlass/util/reference/device/gemm_planar_complex.h" -+#include "cutlass/util/reference/device/tensor_compare.h" -+ -+#include "cutlass/library/handle.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Result structure -+struct Result { -+ -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cudaError_t error; -+ bool passed; -+ -+ // -+ // Methods -+ // -+ -+ Result( -+ double runtime_ms = 0, -+ double gflops = 0, -+ cutlass::Status status = cutlass::Status::kSuccess, -+ cudaError_t error = cudaSuccess -+ ): -+ runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ -+ cutlass::gemm::GemmCoord problem_size; -+ int batch_count; -+ cutlass::complex alpha; -+ cutlass::complex beta; -+ -+ bool reference_check; -+ int iterations; -+ -+ Options(): -+ help(false), -+ problem_size({1024, 1024, 1024}), -+ batch_count(1), -+ reference_check(true), -+ iterations(20), -+ alpha(1), -+ beta() { } -+ -+ bool valid() { -+ return true; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ cmd.get_cmd_line_argument("m", problem_size.m()); -+ cmd.get_cmd_line_argument("n", problem_size.n()); -+ cmd.get_cmd_line_argument("k", problem_size.k()); -+ cmd.get_cmd_line_argument("batch", batch_count); -+ -+ cmd.get_cmd_line_argument("alpha", alpha.real()); -+ cmd.get_cmd_line_argument("alpha_i", alpha.imag()); -+ cmd.get_cmd_line_argument("beta", beta.real()); -+ cmd.get_cmd_line_argument("beta_i", beta.imag()); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "11_planar_complex_array example\n\n" -+ << " This example uses the CUTLASS Library to execute Planar Complex Array GEMM computations.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --m= GEMM M dimension\n" -+ << " --n= GEMM N dimension\n" -+ << " --k= GEMM K dimension\n" -+ << " --batch= Number of GEMM operations executed in one batch\n" -+ << " --alpha= Epilogue scalar alpha (real part)\n" -+ << " --alpha_i= Epilogue scalar alpha (imaginary part)\n" -+ << " --beta= Epilogue scalar beta (real part)\n\n" -+ << " --beta_i= Epilogue scalar beta (imaginary part)\n\n" -+ << " --iterations= Number of profiling iterations to perform.\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/11_planar_complex_array/11_planar_complex_array\n\n"; -+ -+ return out; -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of real-valued multiply-adds -+ int64_t fmas = problem_size.product() * batch_count * 4; -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Performance test environment for planar complex -+class TestbedPlanarComplex { -+public: -+ -+ // Half-precision input and output -+ using Element = cutlass::half_t; -+ -+ // Configurations for layouts and internal computation -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementCompute = float; -+ using ElementAccumulator = float; -+ -+ // -+ // Data members -+ // -+ -+ cutlass::library::Handle handle; -+ -+ cutlass::gemm::GemmCoord problem_size; -+ int batch_count; -+ cutlass::DeviceAllocation tensor_A; -+ cutlass::DeviceAllocation tensor_B; -+ cutlass::DeviceAllocation tensor_C; -+ cutlass::DeviceAllocation tensor_D; -+ cutlass::DeviceAllocation tensor_D_ref; -+ -+ cutlass::DeviceAllocation ptr_A_real; -+ cutlass::DeviceAllocation ptr_A_imag; -+ cutlass::DeviceAllocation ptr_B_real; -+ cutlass::DeviceAllocation ptr_B_imag; -+ cutlass::DeviceAllocation ptr_C_real; -+ cutlass::DeviceAllocation ptr_C_imag; -+ cutlass::DeviceAllocation ptr_D_real; -+ cutlass::DeviceAllocation ptr_D_imag; -+ -+ // -+ // Methods -+ // -+ -+ TestbedPlanarComplex( -+ Options const &options -+ ): -+ problem_size(options.problem_size), batch_count(options.batch_count) { -+ -+ // Allocate device memory for batched planar complex GEMM -+ tensor_A.reset(int64_t(problem_size.m()) * problem_size.k() * batch_count * 2); -+ tensor_B.reset(int64_t(problem_size.k()) * problem_size.n() * batch_count * 2); -+ tensor_C.reset(int64_t(problem_size.m()) * problem_size.n() * batch_count * 2); -+ tensor_D.reset(int64_t(problem_size.m()) * problem_size.n() * batch_count * 2); -+ tensor_D_ref.reset(int64_t(problem_size.m()) * problem_size.n() * batch_count * 2); -+ -+ ptr_A_real.reset(batch_count); -+ ptr_A_imag.reset(batch_count); -+ ptr_B_real.reset(batch_count); -+ ptr_B_imag.reset(batch_count); -+ ptr_C_real.reset(batch_count); -+ ptr_C_imag.reset(batch_count); -+ ptr_D_real.reset(batch_count); -+ ptr_D_imag.reset(batch_count); -+ -+ } -+ -+ void initialize() { -+ -+ uint64_t seed = 1073; -+ -+ // Use small integers to simplify correctness checking -+ int scope_max = 6; -+ int scope_min = -6; -+ -+ cutlass::reference::device::BlockFillRandomUniform( -+ tensor_A.get(), tensor_A.size(), seed, Element(scope_max), Element(scope_min), 0); -+ -+ cutlass::reference::device::BlockFillRandomUniform( -+ tensor_B.get(), tensor_B.size(), seed * 2019, Element(scope_max), Element(scope_min), 0); -+ -+ cutlass::reference::device::BlockFillRandomUniform( -+ tensor_C.get(), tensor_C.size(), seed * 2020, Element(scope_max), Element(scope_min), 0); -+ } -+ -+ Result profile(Options const &options) { -+ -+ Result result; -+ -+ initialize(); -+ -+ Element *ptr_A = tensor_A.get(); -+ Element *ptr_B = tensor_B.get(); -+ Element *ptr_C = tensor_C.get(); -+ Element *ptr_D = tensor_D.get(); -+ -+ int64_t batch_stride_A = int64_t(problem_size.m()) * problem_size.k() * 2; -+ int64_t batch_stride_B = int64_t(problem_size.k()) * problem_size.n() * 2; -+ int64_t batch_stride_C = int64_t(problem_size.m()) * problem_size.n() * 2; -+ int64_t batch_stride_D = int64_t(problem_size.m()) * problem_size.n() * 2; -+ -+ typename LayoutA::Stride::Index lda = LayoutA::packed({problem_size.m(), problem_size.k()}).stride(0); -+ typename LayoutB::Stride::Index ldb = LayoutB::packed({problem_size.k(), problem_size.n()}).stride(0); -+ typename LayoutC::Stride::Index ldc = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0); -+ typename LayoutC::Stride::Index ldd = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0); -+ -+ -+ int64_t imag_stride_A = int64_t(problem_size.m()) * problem_size.k(); -+ int64_t imag_stride_B = int64_t(problem_size.k()) * problem_size.n(); -+ int64_t imag_stride_C = int64_t(problem_size.m()) * problem_size.n(); -+ int64_t imag_stride_D = int64_t(problem_size.m()) * problem_size.n(); -+ -+ // -+ // Configure pointers in global memory -+ // -+ -+ struct { -+ Element *base; -+ void **ptr_real; -+ void **ptr_imag; -+ int64_t batch_stride; -+ int64_t imag_stride; -+ } tensors[] = { -+ { tensor_A.get(), ptr_A_real.get(), ptr_A_imag.get(), batch_stride_A, imag_stride_A}, -+ { tensor_B.get(), ptr_B_real.get(), ptr_B_imag.get(), batch_stride_B, imag_stride_B}, -+ { tensor_C.get(), ptr_C_real.get(), ptr_C_imag.get(), batch_stride_C, imag_stride_C}, -+ { tensor_D.get(), ptr_D_real.get(), ptr_D_imag.get(), batch_stride_D, imag_stride_D} -+ }; -+ -+ for (auto const &tensor : tensors) { -+ for (int idx = 0; idx < batch_count; ++idx) { -+ -+ void *ptr_real = tensor.base + idx * tensor.batch_stride; -+ void *ptr_imag = tensor.base + idx * tensor.batch_stride + tensor.imag_stride; -+ -+ cudaError_t error = cudaMemcpy( -+ tensor.ptr_real + idx, -+ &ptr_real, -+ sizeof(void *), -+ cudaMemcpyHostToDevice); -+ -+ if (error != cudaSuccess) { -+ throw std::runtime_error("Failed to copy pointer to device memory"); -+ } -+ -+ error = cudaMemcpy( -+ tensor.ptr_imag + idx, -+ &ptr_imag, -+ sizeof(void *), -+ cudaMemcpyHostToDevice); -+ -+ if (error != cudaSuccess) { -+ throw std::runtime_error("Failed to copy pointer to device memory"); -+ } -+ } -+ } -+ -+ // -+ // Construct events -+ // -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ } -+ -+ // Record an event at the start of a series of GEMM operations -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // -+ // Run profiling loop -+ // -+ -+ for (int iter = 0; iter < options.iterations; ++iter) { -+ -+ // -+ // Execute the planar complex array GEMM kernel via the CUTLASS Library's -+ // dispatch routines. -+ // -+ // Note, for planar complex array GEMM kernels, all numeric type arguments -+ // specify the data type of the base real types. These are understood to -+ // apply to planar complex representations of matrices in memory and to complex -+ // structures for scalars. -+ // -+ // See tools/library/include/cutlass/library/handle.h for more details. -+ // -+ -+ result.status = handle.gemm_planar_complex_array( -+ -+ problem_size.m(), // expected GEMM M dimension -+ problem_size.n(), // expected GEMM N dimension -+ problem_size.k(), // expected GEMM K dimension -+ batch_count, // Number of batched elements -+ -+ nullptr, -+ nullptr, -+ nullptr, -+ -+ cutlass::library::NumericTypeID::kF32, // Base data type of complex-valued accumulation -+ cutlass::library::NumericTypeID::kF32, // Base data type of complex-valued alpha/beta scalars -+ -+ &options.alpha, // Pointer to alpha scalar, of type complex -+ -+ cutlass::library::NumericTypeID::kF16, // Base data type of complex-valued A matrix -+ cutlass::library::LayoutTypeID::kColumnMajor, // Layout of A matrix -+ cutlass::library::ComplexTransform::kConjugate, // Complex transformation on A matrix operand -+ -+ ptr_A_real.get(), // Pointer to array of pointers to real part of A matrix -+ ptr_A_imag.get(), // Pointer to array of pointers to imaginary part of A matrix -+ -+ lda, // Leading dimension of real part of A matrix -+ lda, // Leading dimension of imaginary part of A matrix -+ -+ cutlass::library::NumericTypeID::kF16, // Base data type of complex-valued B matrix -+ cutlass::library::LayoutTypeID::kColumnMajor, // Layout of B matrix -+ cutlass::library::ComplexTransform::kNone, // Complex transformation on B matrix operand -+ -+ ptr_B_real.get(), // Pointer to array of pointers to real part of B matrix -+ ptr_B_imag.get(), // Pointer to array of pointers to imaginary part of B matrix -+ -+ ldb, // Leading dimension of real part of B matrix -+ ldb, // Leading dimension of imaginary part of B matrix -+ -+ &options.beta, // Pointer to beta scalar, of type complex -+ -+ cutlass::library::NumericTypeID::kF16, // Base data type of complex valued C and D matrices -+ -+ ptr_C_real.get(), // Pointer to array of pointers to real part of C matrix -+ ptr_C_imag.get(), // Pointer to array of pointers to imaginary part of C matrix -+ -+ ldc, // Leading dimension of real part of C matrix -+ ldc, // Leading dimension of imaginary part of C matrix -+ -+ ptr_D_real.get(), // Pointer to array of pointers to real part of D matrix -+ ptr_D_imag.get(), // Pointer to array of pointers to imaginary part of D matrix -+ -+ ldd, // Leading dimension of real part of D matrix -+ ldd // Leading dimension of imaginary part of D matrix -+ ); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "CUTLASS internal error - configuration not supported" << std::endl; -+ return result; -+ } -+ } -+ -+ // -+ // Stop profiling loop -+ // -+ -+ // Record an event when the GEMM operations have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Compute average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // Cleanup -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ if (handle.get_last_operation()) { -+ std::cout << "Recently executed '" << handle.get_last_operation()->description().name << "'" << std::endl; -+ } -+ -+ // -+ // Compute reference in device code -+ // -+ -+ if (options.reference_check) { -+ -+ result.passed = true; -+ -+ for (int64_t idx = 0; result.passed && idx < int64_t(batch_count); ++idx) { -+ cutlass::reference::device::GemmPlanarComplex< -+ Element, LayoutA, -+ Element, LayoutB, -+ Element, LayoutC, -+ ElementAccumulator -+ >( -+ problem_size, -+ options.alpha, -+ {tensor_A.get() + idx * batch_stride_A, lda, imag_stride_A}, -+ cutlass::ComplexTransform::kConjugate, -+ {tensor_B.get() + idx * batch_stride_B, ldb, imag_stride_B}, -+ cutlass::ComplexTransform::kNone, -+ options.beta, -+ {tensor_C.get() + idx * batch_stride_C, ldc, imag_stride_C}, -+ {tensor_D_ref.get() + idx * batch_stride_D, ldd, imag_stride_D} -+ ); -+ -+ Element epsilon = 0.1_hf; -+ Element nonzero_floor = 0.1_hf; -+ -+ result.passed = cutlass::reference::device::BlockCompareRelativelyEqual( -+ tensor_D.get() + idx * batch_stride_D, -+ tensor_D_ref.get() + idx * batch_stride_D, -+ batch_stride_D, -+ epsilon, -+ nonzero_floor -+ ); -+ } -+ -+ if (result.passed) { -+ std::cout << "Reference check passed." << std::endl; -+ } -+ else { -+ std::cerr << "Error - reference check failed." << std::endl; -+ } -+ } -+ -+ std::cout << "Runtime: " << result.runtime_ms << " ms" << std::endl; -+ std::cout << " GFLOPs: " << result.gflops << std::endl; -+ -+ return result; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ -+ // -+ // This example uses mma.sync to directly access Tensor Cores to achieve peak performance. -+ // -+ // Volta Tensor Core operations are first available in CUDA 10.1 Toolkit. -+ // -+ // Turing Tensor Core operations are first available in CUDA 10.2 Toolkit. -+ // -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return -1; -+ } -+ -+ if (props.major < 7) { -+ std::cerr << "Tensor Core operations must be run on a machine with compute capability at least 70." -+ << std::endl; -+ -+ // Returning zero so this passes on older architectures. Its actions are no-op. -+ return 0; -+ } -+ else if (props.major == 7 && props.minor <= 2) { -+ // -+ // If running on the Volta architecture, at least CUDA 10.1 Toolkit is required to run this example. -+ // -+ if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) { -+ std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl; -+ -+ // Returning zero so this passes on older Toolkits. Its actions are no-op. -+ return 0; -+ } -+ } -+ else if (props.major == 7 && props.minor >= 5) { -+ // -+ // If running on the Turing architecture, at least CUDA 10.2 Toolkit is required to run this example. -+ // -+ if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) { -+ std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl; -+ -+ // Returning zero so this passes on older Toolkits. Its actions are no-op. -+ return 0; -+ } -+ } -+ else { -+ // NVIDIA Ampere Architecture GPUs (SM80 and later) are fully supported on CUDA 11 Toolkit and beyond. -+ // -+ // fall through -+ } -+ -+ // -+ // Parse options -+ // -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ // Execute one problem size -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ TestbedPlanarComplex testbed(options); -+ -+ Result result = testbed.profile(options); -+ -+ return result.passed ? 0 : -1; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/examples/12_gemm_bias_relu/gemm_bias_relu.cu b/3rdparty/cutlass/examples/12_gemm_bias_relu/gemm_bias_relu.cu -new file mode 100644 -index 0000000..418540f ---- /dev/null -+++ b/3rdparty/cutlass/examples/12_gemm_bias_relu/gemm_bias_relu.cu -@@ -0,0 +1,303 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+*/ -+ -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/epilogue/thread/linear_combination_relu.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "helper.h" -+ -+// The code section below describes datatype for input, output matrices and computation between -+// elements in input matrices. -+using ElementAccumulator = float; // <- data type of accumulator -+using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations -+using ElementInputA = cutlass::half_t; // <- data type of elements in input matrix A -+using ElementInputB = cutlass::half_t; // <- data type of elements in input matrix B -+using ElementOutput = float; // <- data type of elements in output matrix D -+ -+// Note that if the output is column major, the bias has to be per row. i.e. every row has different bias. -+// If the output is row major, the bias has to be per column, i.e. every column has different bias. -+// Below list some other notices: -+// -+// Note this example only works for ColumnMajor output because -+// 1) we only have row major epilogue. -+// 2) we swap A and B if the output is column major then we can still use the -+// row major epilogue. -+// 3) Mx1 bias vector becomes 1xM after the swapping/transposing. -+// 4) we can use the existing OutputIterator to load 1xM bias vector. -+ -+using LayoutInputA = cutlass::layout::ColumnMajor; -+using LayoutInputB = cutlass::layout::ColumnMajor; -+using LayoutOutput = cutlass::layout::ColumnMajor; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm75; -+ -+// This code section describes the tile size a thread block will compute -+using ShapeMMAThreadBlock = -+ cutlass::gemm::GemmShape<128, 128, 32>; // <- threadblock tile M = 128, N = 128, K = 32 -+// This code section describes tile size a warp will compute -+using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M = 64, N = 64, K = 32 -+// This code section describes the size of MMA op -+using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 8, N = 8, K = 4 -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? -+ -+// Define the epilogue operation as LinearCombinationRelu. This is approximately equal to -+// -+// d_ij = max(0, alpha * sum_k(a_ik * b_kj) + c_ij ) -+// -+using EpilogueOp = cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, // <- data type of output matrix -+ 128 / cutlass::sizeof_bits::value, // <- this is the number of elements per -+ // vectorized memory access. For half -+ // precision, it's 8 elements. This becomes -+ // the vector width of math instructions in -+ // epilogue too -+ ElementAccumulator, // <- data type of accumulator -+ ElementComputeEpilogue, // <- data type for alpha in linear combination function -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling>; // <- alpha x C + bias -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 2; -+ -+using Gemm = cutlass::gemm::device::Gemm; -+ -+int run() { -+ -+ const int length_m = 5120; -+ const int length_n = 4096; -+ const int length_k = 4096; -+ -+ // Create a tuple of problem size for matrix multiplication -+ cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k); -+ -+ // Initialize tensors using CUTLASS helper functions -+ cutlass::HostTensor tensor_a( -+ problem_size.mk()); // <- Create matrix A with dimensions M x K -+ cutlass::HostTensor tensor_b( -+ problem_size.kn()); // <- Create matrix B with dimensions K x N -+ -+ cutlass::HostTensor tensor_c_bias( -+ {problem_size.m(), 1}); // <- Create matrix C with dimensions M x 1 -+ -+ cutlass::HostTensor tensor_d( -+ problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from -+ // CUTLASS kernel -+ cutlass::HostTensor tensor_ref_d( -+ problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from -+ // reference kernel -+ -+ // Fill input and output matrices on host using CUTLASS helper functions -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1, -+ ElementInputA(4), -+ ElementInputA(-4), -+ 0); // <- Fill matrix A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 1, -+ ElementInputB(4), -+ ElementInputB(-4), -+ 0); // <- Fill matrix B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c_bias.host_view(), -+ 1, -+ ElementOutput(4), -+ ElementOutput(-4), -+ 0); // <- Fill matrix C on host with uniform-distribution random data -+ cutlass::reference::host::TensorFill( -+ tensor_d.host_view()); // <- fill matrix D on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_c_bias.sync_device(); -+ tensor_d.sync_device(); -+ tensor_ref_d.sync_device(); -+ -+ // Initialize alpha for dot product computation -+ ElementComputeEpilogue alpha = ElementComputeEpilogue(1); -+ -+ // Split K dimension into 1 partitions -+ int split_k_slices = 1; -+ -+ // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch -+ // instantiated CUTLASS kernel -+ typename Gemm::Arguments arguments{ -+ problem_size, // <- problem size of matrix multiplication -+ tensor_a.device_ref(), // <- reference to matrix A on device -+ tensor_b.device_ref(), // <- reference to matrix B on device -+ -+ {tensor_c_bias.device_data(), 0}, // <- the C matrix is treated as the bias vector. We can enable the GEMM -+ // to project away the N dimension by setting the stride to zero. -+ -+ tensor_d.device_ref(), // <- reference to matrix D on device -+ {alpha}, // <- alpha -+ split_k_slices}; // <- k-dimension split factor -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ // Instantiate CUTLASS kernel depending on templates -+ Gemm gemm_op; -+ -+ // Check the problem size is supported or not -+ cutlass::Status status = gemm_op.can_implement(arguments); -+ CUTLASS_CHECK(status); -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ status = gemm_op.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(status); -+ -+ // Launch initialized CUTLASS kernel -+ status = gemm_op(); -+ CUTLASS_CHECK(status); -+ -+ // -+ // Create instantiation for device reference gemm kernel -+ // -+ -+ cutlass::reference::device::Gemm -+ gemm_device_reference; -+ -+ // Launch device reference to compute strictly the product A * B -+ gemm_device_reference( -+ problem_size, -+ alpha, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ 0, -+ tensor_ref_d.device_ref()); -+ -+ // Wait for kernels to finish -+ cudaDeviceSynchronize(); -+ -+ // Copy output data from CUTLASS and reference kernel to host for comparison -+ tensor_d.sync_host(); -+ tensor_ref_d.sync_host(); -+ -+ // Compute bias + relu in host code -+ for (int i = 0; i < problem_size.m(); ++i) { -+ for (int j = 0; j < problem_size.n(); ++j) { -+ tensor_ref_d.at({i, j}) = std::max( -+ ElementOutput(0), -+ ElementOutput(tensor_ref_d.at({i, j}) + tensor_c_bias.at({i, 0})) -+ ); -+ } -+ } -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ std::cout << (cutlass::reference::host::TensorEquals(tensor_d.host_view(), -+ tensor_ref_d.host_view()) -+ ? "Passed" -+ : "Failed") -+ << std::endl; -+ -+ CUTLASS_CHECK(status); -+ return 0; -+} -+ -+int main() { -+ -+ bool notSupported = false; -+ -+ // Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2. -+ // -+ // CUTLASS must be compiled with CUDA 10.1 Toolkit to run these examples. -+ if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) { -+ std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return -1; -+ } -+ -+ if (!(props.major * 10 + props.minor >= 75)) { -+ std::cerr << "Turing Tensor Ops must be run on a machine with compute capability at least 75." -+ << std::endl; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ // Returning zero so this test passes on older Toolkits. Its actions are no-op. -+ return 0; -+ } -+ -+ return run(); -+} -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/b2b_conv2d_run.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/b2b_conv2d_run.h -new file mode 100644 -index 0000000..9da0e66 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/b2b_conv2d_run.h -@@ -0,0 +1,719 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+#include "cutlass/reduction/device/reduce_split_k.h" -+#include "cutlass/reduction/thread/reduction_operators.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/device/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+ -+#include "cutlass/util/reference/host/convolution.h" -+#include "cutlass/util/reference/device/convolution.h" -+#include "cutlass/util/reference/device/tensor_relu.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "reference/device/tensor_scale_bias.h" -+#include "helper.h" -+ -+#define CHECK_GT(val1, val2) \ -+ if((val1) <= (val2)) \ -+ std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n"; -+#define CHECK_TRUE(val) \ -+ if(!(val)) \ -+ std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n"; -+ -+ -+template -+class B2bNonFusedConv2dRun { -+public: -+ -+ using Conv2d0 = Conv2d0_; -+ using Conv2d1 = Conv2d1_; -+ using ElementAccumulator = typename Conv2d0::ElementAccumulator; -+ using ElementCompute = typename Conv2d0::ElementCompute; -+ -+ static cutlass::conv::Operator const kConvolutionalOperator = Conv2d0::kConvolutionalOperator; -+ static_assert(kConvolutionalOperator == Conv2d1::kConvolutionalOperator, -+ "Fused convolution operators must be the same"); -+ -+public: -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ cutlass::Distribution::Kind init_Bias; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A0; -+ cutlass::HostTensor tensor_B0; -+ cutlass::HostTensor tensor_C0; -+ cutlass::HostTensor tensor_Bias0; -+ cutlass::HostTensor tensor_D0_computed; -+ cutlass::HostTensor tensor_D0_reference; -+ -+ cutlass::HostTensor tensor_B1; -+ cutlass::HostTensor tensor_C1; -+ cutlass::HostTensor tensor_Bias1; -+ cutlass::HostTensor tensor_D1_computed; -+ cutlass::HostTensor tensor_D1_reference; -+ -+ -+public: -+ -+ B2bNonFusedConv2dRun( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { -+ -+ } -+ -+ /// Helper to initialize a tensor view -+ template -+ void initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ int scope; -+ int bits = cutlass::sizeof_bits::value; -+ -+ if (bits <= 16) { -+ scope = 2; -+ } -+ else { -+ scope = 8; -+ } -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope, -scope, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); -+ } -+ else if (dist_kind == cutlass::Distribution::AllZeros) { -+ cutlass::reference::host::TensorFill(view, Element(0)); -+ } -+ else if (dist_kind == cutlass::Distribution::AllOnes) { -+ cutlass::reference::host::TensorFill(view, Element(1)); -+ } -+ else { -+ std::cerr << "Not implemented\n"; -+ } -+ } -+ -+ void initialize( -+ cutlass::conv::Conv2dProblemSize const &problem_size_0, -+ cutlass::conv::Conv2dProblemSize const &problem_size_1, -+ uint64_t seed = 2019) { -+ -+ tensor_A0.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_B0.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_C0.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_Bias0.resize({1, 1, 1, problem_size_0.K}); -+ tensor_D0_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1)); -+ tensor_C1.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); -+ tensor_Bias1.resize({1, 1, 1, problem_size_1.K}); -+ tensor_D1_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); -+ tensor_D1_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); -+ -+ initialize_tensor(tensor_A0.host_view(), init_A, seed); -+ initialize_tensor(tensor_B0.host_view(), init_B, seed * 17); -+ initialize_tensor(tensor_C0.host_view(), init_C, seed * 39); -+ initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed * 83); -+ initialize_tensor(tensor_B1.host_view(), init_B, seed * 18); -+ initialize_tensor(tensor_C1.host_view(), init_C, seed * 40); -+ initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed * 84); -+ -+ tensor_A0.sync_device(); -+ tensor_B0.sync_device(); -+ tensor_C0.sync_device(); -+ tensor_Bias0.sync_device(); -+ tensor_D0_computed.sync_device(); -+ tensor_D0_reference.sync_device(); -+ tensor_B1.sync_device(); -+ tensor_C1.sync_device(); -+ tensor_Bias1.sync_device(); -+ tensor_D1_computed.sync_device(); -+ tensor_D1_reference.sync_device(); -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::conv::Conv2dProblemSize const &problem_size_0, -+ cutlass::conv::Conv2dProblemSize const &problem_size_1, -+ cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, -+ ElementCompute alpha0 = ElementCompute(1), -+ ElementCompute beta0 = ElementCompute(0), -+ ElementCompute alpha1 = ElementCompute(1), -+ ElementCompute beta1 = ElementCompute(0), -+ bool relu = true, -+ int warm_ups = 1, -+ int runs = 100) { -+ -+ initialize(problem_size_0, problem_size_1); -+ -+ // configure the operator -+ Conv2d0 conv2d_op_0; -+ Conv2d1 conv2d_op_1; -+ -+ typename Conv2d0::Arguments conv2d_args_0( -+ problem_size_0, -+ tensor_A0.device_ref(), -+ tensor_B0.device_ref(), -+ {tensor_Bias0.device_data(), typename Conv2d0::LayoutC::Stride(0)}, -+ tensor_D0_computed.device_ref(), -+ {alpha0, beta0}, -+ split_k_mode -+ ); -+ typename Conv2d1::Arguments conv2d_args_1( -+ problem_size_1, -+ tensor_D0_computed.device_ref(), -+ tensor_B1.device_ref(), -+ {tensor_Bias1.device_data(), typename Conv2d1::LayoutC::Stride(0)}, -+ tensor_D1_computed.device_ref(), -+ {alpha1, beta1}, -+ split_k_mode -+ ); -+ -+ -+ cutlass::Status status = conv2d_op_0.initialize(conv2d_args_0); -+ -+ CUTLASS_CHECK(status); -+ -+ status = conv2d_op_1.initialize(conv2d_args_1); -+ -+ CUTLASS_CHECK(status); -+ -+ for(int i = 0; i < warm_ups; i++) { -+ status = conv2d_op_0(); -+ CUTLASS_CHECK(status); -+ status = conv2d_op_1(); -+ CUTLASS_CHECK(status); -+ } -+ -+ // -+ // Run Conv2d -+ // -+ cudaEvent_t start, stop1, stop2; -+ cudaEventCreate(&start); -+ cudaEventCreate(&stop1); -+ cudaEventCreate(&stop2); -+ -+ cudaEventRecord(start); -+ -+ -+ for(int i = 0; i < runs; i++) { -+ // run conv2d operator -+ status = conv2d_op_0(); -+ CUTLASS_CHECK(status); -+ } -+ cudaEventRecord(stop1); -+ -+ for(int i = 0; i < runs; i++) { -+ // run conv2d operator -+ status = conv2d_op_1(); -+ CUTLASS_CHECK(status); -+ } -+ cudaEventRecord(stop2); -+ cudaDeviceSynchronize(); -+ float conv2d0Time, conv2d1Time, totalTime; -+ cudaEventElapsedTime(&conv2d0Time, start, stop1); -+ cudaEventElapsedTime(&conv2d1Time, stop1, stop2); -+ cudaEventElapsedTime(&totalTime, start, stop2); -+ std::cout << "conv2d 0 time " << conv2d0Time / (float)runs << " ms\n"; -+ std::cout << "conv2d 1 time " << conv2d1Time / (float)runs << " ms\n"; -+ std::cout << "Non-fusion time " << totalTime / (float)runs << " ms\n"; -+ -+ tensor_D0_computed.sync_host(); -+ tensor_D1_computed.sync_host(); -+ -+ bool passed = false; -+ -+ cutlass::reference::device::Conv2d< -+ typename Conv2d0::ElementA, -+ typename Conv2d0::LayoutA, -+ typename Conv2d0::ElementB, -+ typename Conv2d0::LayoutB, -+ typename Conv2d0::ElementC, -+ typename Conv2d0::LayoutC, -+ ElementCompute, -+ ElementAccumulator -+ >( -+ kConvolutionalOperator, -+ problem_size_0, -+ tensor_A0.device_ref(), -+ tensor_B0.device_ref(), -+ {tensor_Bias0.device_data(), typename Conv2d0::LayoutC::Stride(0)}, -+ tensor_D0_reference.device_ref(), -+ alpha0, -+ beta0); -+ -+ if(relu) { -+ cutlass::reference::device::TensorReLu(tensor_D0_reference.device_view()); -+ } -+ -+ cutlass::reference::device::Conv2d< -+ typename Conv2d1::ElementA, -+ typename Conv2d1::LayoutA, -+ typename Conv2d1::ElementB, -+ typename Conv2d1::LayoutB, -+ typename Conv2d1::ElementC, -+ typename Conv2d1::LayoutC, -+ ElementCompute, -+ ElementAccumulator -+ >( -+ kConvolutionalOperator, -+ problem_size_1, -+ tensor_D0_reference.device_ref(), -+ tensor_B1.device_ref(), -+ {tensor_Bias1.device_data(), typename Conv2d1::LayoutC::Stride(0)}, -+ tensor_D1_reference.device_ref(), -+ alpha1, -+ beta1); -+ -+ if(relu) { -+ cutlass::reference::device::TensorReLu(tensor_D1_reference.device_view()); -+ } -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ CHECK_TRUE(result == cudaSuccess); -+ -+ // sync host (copy device data to host) for dumping error output in case of mismatches -+ tensor_D0_reference.sync_host(); -+ tensor_D1_reference.sync_host(); -+ -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_computed.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_reference.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_computed.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_reference.host_view()), 0); -+ -+ passed = cutlass::reference::host::TensorEquals( -+ tensor_D1_computed.host_view(), -+ tensor_D1_reference.host_view()); -+ -+ CHECK_TRUE(passed); -+ -+ if (!passed) { -+ std::stringstream fname; -+ -+ fname << "error_B2bImplicitGemm_device_nonfused.txt"; -+ std::cerr << "Dumping results in " << fname.str() << "\n"; -+ -+ std::ofstream results(fname.str()); -+ -+ results << problem_size_0 << std::endl; -+ results << problem_size_1 << std::endl; -+ -+ results -+ << "\nA0:\n" << tensor_A0.host_view() << "\n" -+ << "\nB0:\n" << tensor_B0.host_view() << "\n" -+ << "\nC0:\n" << tensor_C0.host_view() << "\n" -+ << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" -+ << "\nD0 reference:\n" << tensor_D0_reference.host_view() << "\n" -+ << "\nD0 computed:\n" << tensor_D0_computed.host_view() << "\n" -+ << "\nB1:\n" << tensor_B1.host_view() << "\n" -+ << "\nC1:\n" << tensor_C1.host_view() << "\n" -+ << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" -+ << "\nD1 reference:\n" << tensor_D1_reference.host_view() << "\n" -+ << "\nD1 computed:\n" << tensor_D1_computed.host_view(); -+ -+ -+ } -+ -+ return passed; -+ } -+ -+}; -+ -+template -+class B2bFusedConv2dRun { -+public: -+ -+ using B2bConv2d = B2bConv2d_; -+ using ElementAccumulator = typename B2bConv2d::ElementAccumulator; -+ using ElementCompute = typename B2bConv2d::ElementCompute; -+ -+ static cutlass::conv::Operator const kConvolutionalOperator = B2bConv2d::kConvolutionalOperator; -+ -+public: -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ cutlass::Distribution::Kind init_Scale; -+ cutlass::Distribution::Kind init_Bias; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A0; -+ cutlass::HostTensor tensor_B0; -+ cutlass::HostTensor tensor_C0; -+ cutlass::HostTensor tensor_Scale0; -+ cutlass::HostTensor tensor_Bias0; -+ cutlass::HostTensor tensor_Z0_reference; -+ cutlass::HostTensor tensor_D0_reference; -+ -+ cutlass::HostTensor tensor_B1; -+ cutlass::HostTensor tensor_C1; -+ cutlass::HostTensor tensor_Bias1; -+ cutlass::HostTensor tensor_D1_computed; -+ cutlass::HostTensor tensor_D1_reference; -+ -+ -+public: -+ -+ B2bFusedConv2dRun( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), -+ init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { -+ -+ } -+ -+ /// Helper to initialize a tensor view -+ template -+ void initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ int scope; -+ int bits = cutlass::sizeof_bits::value; -+ -+ if (bits <= 16) { -+ scope = 2; -+ } -+ else { -+ scope = 8; -+ } -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope, -scope, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); -+ } -+ else if (dist_kind == cutlass::Distribution::AllZeros) { -+ cutlass::reference::host::TensorFill(view, Element(0)); -+ } -+ else if (dist_kind == cutlass::Distribution::AllOnes) { -+ cutlass::reference::host::TensorFill(view, Element(1)); -+ } -+ else { -+ } -+ } -+ -+ void initialize( -+ cutlass::conv::Conv2dProblemSize const &problem_size_0, -+ cutlass::conv::Conv2dProblemSize const &problem_size_1, -+ ElementCompute alpha0, -+ ElementCompute alpha1, -+ uint64_t seed = 2019) { -+ -+ tensor_A0.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_B0.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_C0.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); -+ if(alpha0 == ElementCompute(0)) //per-channel scale -+ tensor_Scale0.resize({1, problem_size_0.K}); -+ tensor_Bias0.resize({1, problem_size_0.K}); -+ tensor_Z0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1)); -+ tensor_C1.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); -+ tensor_Bias1.resize({1, 1, 1, problem_size_1.K}); -+ tensor_D1_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); -+ tensor_D1_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); -+ -+ initialize_tensor(tensor_A0.host_view(), init_A, seed); -+ initialize_tensor(tensor_B0.host_view(), init_B, seed * 17); -+ initialize_tensor(tensor_C0.host_view(), init_C, seed * 39); -+ if(alpha0 == ElementCompute(0)) //per-channel scale -+ initialize_tensor(tensor_Scale0.host_view(), init_Scale, seed * 61); -+ initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed * 83); -+ initialize_tensor(tensor_B1.host_view(), init_B, seed * 18); -+ initialize_tensor(tensor_C1.host_view(), init_C, seed * 40); -+ initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed * 84); -+ -+ tensor_A0.sync_device(); -+ tensor_B0.sync_device(); -+ tensor_C0.sync_device(); -+ if(alpha0 == ElementCompute(0)) //per-channel scale -+ tensor_Scale0.sync_device(); -+ tensor_Bias0.sync_device(); -+ tensor_D0_reference.sync_device(); -+ tensor_B1.sync_device(); -+ tensor_C1.sync_device(); -+ tensor_Bias1.sync_device(); -+ tensor_D1_computed.sync_device(); -+ tensor_D1_reference.sync_device(); -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::conv::Conv2dProblemSize const &problem_size_0, -+ cutlass::conv::Conv2dProblemSize const &problem_size_1, -+ cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, -+ ElementCompute alpha0 = ElementCompute(1), -+ ElementCompute beta0 = ElementCompute(0), -+ ElementCompute alpha1 = ElementCompute(1), -+ ElementCompute beta1 = ElementCompute(0), -+ bool relu = true, -+ int warm_ups = 1, -+ int runs = 100) { -+ -+ initialize(problem_size_0, problem_size_1, alpha0, alpha1); -+ -+ // configure the operator -+ B2bConv2d b2b_conv2d_op; -+ -+ typename B2bConv2d::Arguments b2b_conv2d_args( -+ problem_size_0, -+ problem_size_1, -+ tensor_A0.device_ref(), -+ tensor_B0.device_ref(), -+ tensor_C0.device_ref(), -+ tensor_Scale0.device_ref(), -+ tensor_Bias0.device_ref(), -+ tensor_B1.device_ref(), -+ {tensor_Bias1.device_data(), typename B2bConv2d::LayoutC::Stride(0)}, -+ tensor_D1_computed.device_ref(), -+ {alpha0, beta0}, -+ {alpha1, beta1}, -+ split_k_mode -+ ); -+ -+ cutlass::Status status = b2b_conv2d_op.can_implement(b2b_conv2d_args); -+ -+ if(status != cutlass::Status::kSuccess) { -+ std::cout << "Problem sizes not supported.\n" -+ << "Requirments:\n" -+ << " problem_size_0.N*P*Q = problem_size_1.N*P*Q\n" -+ << " problem_size_0.K = problem_size_1.C\n" -+ << " problem_size_1.R = problem_size_1.S = 1\n" -+ << " ThreadblockShape0::kN = problem_size_0.K\n" -+ << " ThreadblockShape1::kN = problem_size_1.K" << std::endl; -+ } -+ -+ CUTLASS_CHECK(status); -+ -+ status = b2b_conv2d_op.initialize(b2b_conv2d_args); -+ -+ CUTLASS_CHECK(status); -+ -+ for(int i = 0; i < warm_ups; i++) { -+ status = b2b_conv2d_op(); -+ CUTLASS_CHECK(status); -+ } -+ -+ // -+ // Run the Conv2d -+ // -+ -+ cudaEvent_t start, stop; -+ cudaEventCreate(&start); -+ cudaEventCreate(&stop); -+ -+ cudaEventRecord(start); -+ -+ for(int i = 0; i < runs; i++) { -+ -+ // run conv2d operator -+ status = b2b_conv2d_op(); -+ CUTLASS_CHECK(status); -+ } -+ -+ cudaEventRecord(stop); -+ cudaDeviceSynchronize(); -+ float conv2dTime; -+ cudaEventElapsedTime(&conv2dTime, start, stop); -+ std::cout << "Fusion time " << conv2dTime / (float)runs << " ms\n"; -+ -+ tensor_D1_computed.sync_host(); -+ -+ bool passed = false; -+ -+ cutlass::reference::device::Conv2d< -+ typename B2bConv2d::ElementA, -+ typename B2bConv2d::LayoutA, -+ typename B2bConv2d::ElementB, -+ typename B2bConv2d::LayoutB, -+ ElementAccumulator, -+ typename B2bConv2d::LayoutC, -+ ElementAccumulator, -+ ElementAccumulator -+ >( -+ kConvolutionalOperator, -+ problem_size_0, -+ tensor_A0.device_ref(), -+ tensor_B0.device_ref(), -+ tensor_Z0_reference.device_ref(), -+ tensor_Z0_reference.device_ref(), -+ ElementAccumulator(1), // intermediate alpha = 1 -+ ElementAccumulator(0) // beta = 0 -+ ); -+ -+ cutlass::reference::device::TensorScaleBiasConv2d< -+ ElementAccumulator, -+ typename B2bConv2d::ElementC, -+ typename B2bConv2d::LayoutC, -+ ElementCompute, -+ typename B2bConv2d::LayoutScaleBias -+ >( -+ problem_size_0, -+ tensor_Z0_reference.device_ref(), -+ tensor_D0_reference.device_ref(), -+ alpha0, -+ tensor_Scale0.device_ref(), -+ tensor_Bias0.device_ref() -+ ); -+ -+ if(relu) { -+ cutlass::reference::device::TensorReLu(tensor_D0_reference.device_view()); -+ } -+ -+ cutlass::reference::device::Conv2d< -+ typename B2bConv2d::ElementA, -+ typename B2bConv2d::LayoutA, -+ typename B2bConv2d::ElementB, -+ typename B2bConv2d::LayoutB, -+ typename B2bConv2d::ElementC, -+ typename B2bConv2d::LayoutC, -+ ElementCompute, -+ ElementAccumulator -+ >( -+ kConvolutionalOperator, -+ problem_size_1, -+ tensor_D0_reference.device_ref(), -+ tensor_B1.device_ref(), -+ {tensor_Bias1.device_data(), typename B2bConv2d::LayoutC::Stride(0)}, -+ tensor_D1_reference.device_ref(), -+ alpha1, -+ beta1); -+ -+ if(relu) { -+ cutlass::reference::device::TensorReLu(tensor_D1_reference.device_view()); -+ } -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ CHECK_TRUE(result == cudaSuccess); -+ -+ // sync host (copy device data to host) for dumping error output in case of mismatches -+ tensor_D0_reference.sync_host(); -+ tensor_D1_reference.sync_host(); -+ -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_reference.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_computed.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_reference.host_view()), 0); -+ -+ passed = cutlass::reference::host::TensorEquals( -+ tensor_D1_computed.host_view(), -+ tensor_D1_reference.host_view()); -+ -+ CHECK_TRUE(passed); -+ -+ if (!passed) { -+ std::stringstream fname; -+ -+ fname << "error_B2bImplicitGemm_device_fused.txt"; -+ std::cerr << "Dumping results in " << fname.str() << "\n"; -+ -+ std::ofstream results(fname.str()); -+ -+ results << problem_size_0 << std::endl; -+ results << problem_size_1 << std::endl; -+ -+ results -+ << "\nA0:\n" << tensor_A0.host_view() << "\n" -+ << "\nB0:\n" << tensor_B0.host_view() << "\n" -+ << "\nC0:\n" << tensor_C0.host_view() << "\n" -+ << "\nScale0:\n" << tensor_Scale0.host_view() << "\n" -+ << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" -+ << "\nB1:\n" << tensor_B1.host_view() << "\n" -+ << "\nC1:\n" << tensor_C1.host_view() << "\n" -+ << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" -+ << "\nD1 reference:\n" << tensor_D1_reference.host_view() << "\n" -+ << "\nD1 computed:\n" << tensor_D1_computed.host_view(); -+ -+ -+ } -+ -+ return passed; -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/b2b_gemm_run.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/b2b_gemm_run.h -new file mode 100644 -index 0000000..b8b080c ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/b2b_gemm_run.h -@@ -0,0 +1,714 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/device/tensor_relu.h" -+ -+#include "reference/device/tensor_scale_bias.h" -+#include "helper.h" -+ -+#define CHECK_GT(val1, val2) \ -+ if((val1) <= (val2)) \ -+ std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n"; -+#define CHECK_TRUE(val) \ -+ if(!(val)) \ -+ std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n"; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct B2bNonFusedGemmRun -+{ -+ -+ using Gemm0 = Gemm0_; -+ using Gemm1 = Gemm1_; -+ using ElementAccumulator = typename Gemm0::ElementAccumulator; -+ using ElementCompute = typename Gemm0::GemmKernel::Epilogue::OutputOp::ElementCompute; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ cutlass::Distribution::Kind init_Bias; -+ uint64_t seed; -+ -+ // -+ // Methods -+ // -+ -+ B2bNonFusedGemmRun( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, 2, -2, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else if (dist_kind == cutlass::Distribution::AllZeros) { -+ cutlass::reference::host::TensorFill(view, Element(0)); -+ } -+ else if (dist_kind == cutlass::Distribution::AllOnes) { -+ cutlass::reference::host::TensorFill(view, Element(1)); -+ } -+ else { -+ std::cerr << "Not implemented\n"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ -+ -+ -+ /// Executes one test -+ bool run( -+ cutlass::gemm::GemmCoord problem_size_0, -+ cutlass::gemm::GemmCoord problem_size_1, -+ ElementCompute alpha0 = ElementCompute(1), -+ ElementCompute beta0 = ElementCompute(0), -+ ElementCompute alpha1 = ElementCompute(1), -+ ElementCompute beta1 = ElementCompute(0), -+ bool relu = true, -+ int warm_ups = 1, -+ int runs = 100) { -+ -+ // -+ // Allocate the GEMM workspace -+ // -+ -+ cutlass::HostTensor< -+ typename Gemm0::ElementA, -+ typename Gemm0::LayoutA> tensor_A0(problem_size_0.mk()); -+ -+ cutlass::HostTensor< -+ typename Gemm0::ElementB, -+ typename Gemm0::LayoutB> tensor_B0(problem_size_0.kn()); -+ -+ cutlass::HostTensor< -+ typename Gemm0::ElementC, -+ typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn()); -+ -+ cutlass::HostTensor< -+ ElementCompute, -+ typename Gemm0::LayoutC> tensor_Bias0({1, problem_size_0.n()}); -+ -+ cutlass::HostTensor< -+ typename Gemm0::ElementC, -+ typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn()); -+ -+ cutlass::HostTensor< -+ typename Gemm0::ElementC, -+ typename Gemm0::LayoutC> reference_D0(problem_size_0.mn()); -+ -+ cutlass::HostTensor< -+ typename Gemm1::ElementB, -+ typename Gemm1::LayoutB> tensor_B1(problem_size_1.kn()); -+ -+ cutlass::HostTensor< -+ typename Gemm1::ElementC, -+ typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn()); -+ -+ cutlass::HostTensor< -+ ElementCompute, -+ typename Gemm1::LayoutC> tensor_Bias1({1, problem_size_1.n()}); -+ -+ cutlass::HostTensor< -+ typename Gemm1::ElementC, -+ typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn()); -+ -+ cutlass::HostTensor< -+ typename Gemm1::ElementC, -+ typename Gemm1::LayoutC> reference_D1(problem_size_1.mn()); -+ -+ -+ CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); -+ CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018)); -+ CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); -+ CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2014)); -+ CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016)); -+ CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); -+ CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2013)); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D0.host_view()); -+ cutlass::reference::host::TensorFill( -+ tensor_D1.host_view()); -+ cutlass::reference::host::TensorFill( -+ reference_D0.host_view()); -+ cutlass::reference::host::TensorFill( -+ reference_D1.host_view()); -+ -+ tensor_A0.sync_device(); -+ tensor_B0.sync_device(); -+ tensor_C0.sync_device(); -+ tensor_Bias0.sync_device(); -+ tensor_D0.sync_device(); -+ tensor_B1.sync_device(); -+ tensor_C1.sync_device(); -+ tensor_Bias1.sync_device(); -+ tensor_D1.sync_device(); -+ reference_D0.sync_device(); -+ reference_D1.sync_device(); -+ -+ // -+ // Initialize the GEMM operator -+ // -+ -+ typename Gemm0::Arguments arguments_0{ -+ problem_size_0, -+ tensor_A0.device_ref(), -+ tensor_B0.device_ref(), -+ {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}, -+ tensor_D0.device_ref(), -+ {alpha0, beta0} -+ }; -+ -+ typename Gemm1::Arguments arguments_1{ -+ problem_size_1, -+ tensor_D0.device_ref(), -+ tensor_B1.device_ref(), -+ {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}, -+ tensor_D1.device_ref(), -+ {alpha1, beta1} -+ }; -+ -+ -+ Gemm0 gemm_op_0; -+ Gemm1 gemm_op_1; -+ -+ cutlass::Status status = gemm_op_0.initialize(arguments_0); -+ -+ CUTLASS_CHECK(status); -+ -+ status = gemm_op_1.initialize(arguments_1); -+ -+ CUTLASS_CHECK(status); -+ -+ for(int i = 0; i < warm_ups; i++) { -+ status = gemm_op_0(); -+ CUTLASS_CHECK(status); -+ status = gemm_op_1(); -+ CUTLASS_CHECK(status); -+ } -+ -+ // -+ // Run the GEMM -+ // -+ cudaEvent_t start, stop1, stop2; -+ cudaEventCreate(&start); -+ cudaEventCreate(&stop1); -+ cudaEventCreate(&stop2); -+ -+ cudaEventRecord(start); -+ -+ for(int i = 0; i < runs; i++) { -+ status = gemm_op_0(); -+ -+ CUTLASS_CHECK(status); -+ } -+ cudaEventRecord(stop1); -+ for(int i = 0; i < runs; i++) { -+ status = gemm_op_1(); -+ -+ CUTLASS_CHECK(status); -+ } -+ -+ cudaEventRecord(stop2); -+ cudaDeviceSynchronize(); -+ float gemm0Time, gemm1Time, totalTime; -+ cudaEventElapsedTime(&gemm0Time, start, stop1); -+ cudaEventElapsedTime(&gemm1Time, stop1, stop2); -+ cudaEventElapsedTime(&totalTime, start, stop2); -+ std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n"; -+ std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n"; -+ std::cout << "Non-fusion time " << totalTime / (float)runs << " ms\n"; -+ -+ tensor_D0.sync_host(); -+ tensor_D1.sync_host(); -+ -+ // -+ // Verify -+ // -+ cutlass::reference::device::Gemm< -+ typename Gemm0::ElementA, typename Gemm0::LayoutA, -+ typename Gemm0::ElementB, typename Gemm0::LayoutB, -+ typename Gemm0::ElementC, typename Gemm0::LayoutC, ElementCompute, -+ ElementAccumulator, typename Gemm0::Operator> -+ reference_gemm_0; -+ -+ cutlass::reference::device::Gemm< -+ typename Gemm1::ElementA, typename Gemm1::LayoutA, -+ typename Gemm1::ElementB, typename Gemm1::LayoutB, -+ typename Gemm1::ElementC, typename Gemm1::LayoutC, ElementCompute, -+ ElementAccumulator, typename Gemm1::Operator> -+ reference_gemm_1; -+ -+ reference_gemm_0( -+ problem_size_0, -+ alpha0, -+ tensor_A0.device_ref(), -+ tensor_B0.device_ref(), -+ beta0, -+ {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}, -+ reference_D0.device_ref() -+ ); -+ -+ if(relu) { -+ cutlass::reference::device::TensorReLu(reference_D0.device_view()); -+ } -+ -+ reference_gemm_1( -+ problem_size_1, -+ alpha1, -+ reference_D0.device_ref(), -+ tensor_B1.device_ref(), -+ beta1, -+ {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}, -+ reference_D1.device_ref() -+ ); -+ -+ if(relu) { -+ cutlass::reference::device::TensorReLu(reference_D1.device_view()); -+ } -+ -+ // Wait for kernels to finish -+ cudaDeviceSynchronize(); -+ reference_D0.sync_host(); -+ reference_D1.sync_host(); -+ -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ reference_D1.host_view(), -+ tensor_D1.host_view()); -+ -+ CHECK_TRUE(passed); -+ if (!passed) { -+ -+ std::stringstream fname; -+ -+ fname << "error_B2bGemm_device_nonfused.txt"; -+ std::cerr << "Dumping results in " << fname.str() << "\n"; -+ -+ std::ofstream file(fname.str()); -+ -+ file -+ << "A0 =\n" << tensor_A0.host_view() -+ << "\nB0 =\n" << tensor_B0.host_view() -+ << "\nC0 =\n" << tensor_C0.host_view() -+ << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" -+ << "\nD0 =\n" << tensor_D0.host_view() -+ << "\nB1 =\n" << tensor_B1.host_view() -+ << "\nC1 =\n" << tensor_C1.host_view() -+ << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" -+ << "\n\nReference =\n" << reference_D1.host_view() -+ << "\nComputed =\n" << tensor_D1.host_view(); -+ } -+ return passed; -+ } -+}; -+ -+template -+struct B2bFusedGemmRun -+{ -+ -+ using B2bGemm = B2bGemm_; -+ using ElementAccumulator = typename B2bGemm::ElementAccumulator; -+ using ElementCompute = typename B2bGemm::B2bGemmKernel::Epilogue::OutputOp::ElementCompute; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ cutlass::Distribution::Kind init_Scale; -+ cutlass::Distribution::Kind init_Bias; -+ uint64_t seed; -+ -+ // -+ // Methods -+ // -+ -+ B2bFusedGemmRun( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), -+ init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, 2, -2, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else if (dist_kind == cutlass::Distribution::AllZeros) { -+ cutlass::reference::host::TensorFill(view, Element(0)); -+ } -+ else if (dist_kind == cutlass::Distribution::AllOnes) { -+ cutlass::reference::host::TensorFill(view, Element(1)); -+ } -+ else { -+ std::cerr << "Not implemented\n"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ -+ -+ -+ /// Executes one test -+ bool run( -+ cutlass::gemm::GemmCoord problem_size_0, -+ cutlass::gemm::GemmCoord problem_size_1, -+ ElementCompute alpha0 = ElementCompute(1), -+ ElementCompute beta0 = ElementCompute(0), -+ ElementCompute alpha1 = ElementCompute(1), -+ ElementCompute beta1 = ElementCompute(0), -+ bool relu = true, -+ int warm_ups = 1, -+ int runs = 100) { -+ -+ // -+ // Allocate the GEMM workspace -+ // -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementA, -+ typename B2bGemm::LayoutA> tensor_A0(problem_size_0.mk()); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementB, -+ typename B2bGemm::LayoutB> tensor_B0(problem_size_0.kn()); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementC, -+ typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn()); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementScaleBias, -+ typename B2bGemm::LayoutScaleBias> tensor_Scale0; -+ -+ if(alpha0 == ElementCompute(0)) //per-channel scale -+ tensor_Scale0.resize({1, problem_size_0.n()}); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementScaleBias, -+ typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, problem_size_0.n()}); -+ -+ cutlass::HostTensor< -+ ElementAccumulator, -+ typename B2bGemm::LayoutC> reference_Z0(problem_size_0.mn()); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementC, -+ typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn()); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementB, -+ typename B2bGemm::LayoutB> tensor_B1(problem_size_1.kn()); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementC, -+ typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn()); -+ -+ cutlass::HostTensor< -+ ElementCompute, -+ typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, problem_size_1.n()}); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementC, -+ typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn()); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementC, -+ typename B2bGemm::LayoutC> reference_D1(problem_size_1.mn()); -+ -+ -+ CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); -+ CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018)); -+ CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); -+ if(alpha0 == ElementCompute(0)) //per-channel scale -+ CHECK_TRUE(initialize_tensor(tensor_Scale0.host_view(), init_Scale, seed + 2014)); -+ CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2013)); -+ CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016)); -+ CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); -+ CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2012)); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D1.host_view()); -+ cutlass::reference::host::TensorFill( -+ reference_D0.host_view()); -+ cutlass::reference::host::TensorFill( -+ reference_D1.host_view()); -+ -+ tensor_A0.sync_device(); -+ tensor_B0.sync_device(); -+ tensor_C0.sync_device(); -+ if(alpha0 == ElementCompute(0)) //per-channel scale -+ tensor_Scale0.sync_device(); -+ tensor_Bias0.sync_device(); -+ tensor_B1.sync_device(); -+ tensor_C1.sync_device(); -+ tensor_Bias1.sync_device(); -+ tensor_D1.sync_device(); -+ reference_D0.sync_device(); -+ reference_D1.sync_device(); -+ -+ // -+ // Initialize the GEMM operator -+ // -+ -+ typename B2bGemm::Arguments arguments{ -+ problem_size_0, -+ problem_size_1, -+ tensor_A0.device_ref(), -+ tensor_B0.device_ref(), -+ tensor_C0.device_ref(), -+ tensor_Scale0.device_ref(), -+ tensor_Bias0.device_ref(), -+ tensor_B1.device_ref(), -+ {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)}, -+ tensor_D1.device_ref(), -+ {alpha0, beta0}, -+ {alpha1, beta1}, -+ }; -+ -+ B2bGemm b2b_gemm_op; -+ -+ cutlass::Status status = b2b_gemm_op.can_implement(arguments); -+ -+ if(status != cutlass::Status::kSuccess) { -+ std::cout << "Problem sizes not supported.\n" -+ << "Requirments:\n" -+ << " problem_size_0.M = problem_size_1.M\n" -+ << " problem_size_0.N = problem_size_1.K\n" -+ << " ThreadblockShape0::kN = problem_size_0.N\n" -+ << " ThreadblockShape1::kN = problem_size_1.N" << std::endl; -+ } -+ -+ status = b2b_gemm_op.initialize(arguments); -+ -+ CUTLASS_CHECK(status); -+ -+ for(int i = 0; i < warm_ups; i++) { -+ status = b2b_gemm_op(); -+ CUTLASS_CHECK(status); -+ } -+ -+ // -+ // Run the GEMM -+ // -+ -+ cudaEvent_t start, stop; -+ cudaEventCreate(&start); -+ cudaEventCreate(&stop); -+ -+ cudaEventRecord(start); -+ -+ for(int i = 0; i < runs; i++) { -+ status = b2b_gemm_op(); -+ -+ CUTLASS_CHECK(status); -+ } -+ -+ cudaEventRecord(stop); -+ cudaDeviceSynchronize(); -+ float gemmTime; -+ cudaEventElapsedTime(&gemmTime, start, stop); -+ std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n"; -+ -+ tensor_D1.sync_host(); -+ -+ // -+ // Verify -+ // -+ -+ cutlass::reference::device::Gemm< -+ typename B2bGemm::ElementA, typename B2bGemm::LayoutA, -+ typename B2bGemm::ElementB, typename B2bGemm::LayoutB, -+ ElementAccumulator, typename B2bGemm::LayoutC, -+ ElementAccumulator, ElementAccumulator> -+ reference_gemm_0; -+ -+ cutlass::reference::device::Gemm< -+ typename B2bGemm::ElementA, typename B2bGemm::LayoutA, -+ typename B2bGemm::ElementB, typename B2bGemm::LayoutB, -+ typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute, -+ ElementAccumulator, typename B2bGemm::Operator> -+ reference_gemm_1; -+ -+ reference_gemm_0( -+ problem_size_0, -+ ElementAccumulator(1), //intermediate alpha=1 -+ tensor_A0.device_ref(), -+ tensor_B0.device_ref(), -+ ElementAccumulator(0), //beta = 0 -+ reference_Z0.device_ref(), -+ reference_Z0.device_ref(), -+ ElementAccumulator(0) -+ ); -+ -+ cutlass::reference::device::TensorScaleBiasGemm< -+ ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC, -+ ElementCompute, typename B2bGemm::LayoutScaleBias -+ > ( -+ problem_size_0, -+ reference_Z0.device_ref(), -+ reference_D0.device_ref(), -+ alpha0, -+ tensor_Scale0.device_ref(), -+ tensor_Bias0.device_ref() -+ ); -+ -+ if(relu) { -+ cutlass::reference::device::TensorReLu(reference_D0.device_view()); -+ } -+ -+ reference_gemm_1( -+ problem_size_1, -+ alpha1, -+ reference_D0.device_ref(), -+ tensor_B1.device_ref(), -+ beta1, -+ {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)}, -+ reference_D1.device_ref() -+ ); -+ if(relu) { -+ cutlass::reference::device::TensorReLu(reference_D1.device_view()); -+ } -+ cudaDeviceSynchronize(); -+ reference_D0.sync_host(); -+ reference_D1.sync_host(); -+ -+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ reference_D1.host_view(), -+ tensor_D1.host_view()); -+ -+ CHECK_TRUE(passed); -+ if (!passed) -+ { -+ -+ std::stringstream fname; -+ -+ fname << "error_B2bGemm_device_fused.txt"; -+ std::cerr << "Dumping results in " << fname.str() << "\n"; -+ -+ std::ofstream file(fname.str()); -+ -+ file -+ << "A0 =\n" << tensor_A0.host_view() -+ << "\nB0 =\n" << tensor_B0.host_view() -+ << "\nC0 =\n" << tensor_C0.host_view() -+ << "\nScale0:\n" << tensor_Scale0.host_view() << "\n" -+ << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" -+ << "\nB1 =\n" << tensor_B1.host_view() -+ << "\nC1 =\n" << tensor_C1.host_view() -+ << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" -+ << "\n\nReference =\n" << reference_D1.host_view() -+ << "\nComputed =\n" << tensor_D1.host_view(); -+ } -+ return passed; -+ } -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/b2b_interleaved_conv2d_run.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/b2b_interleaved_conv2d_run.h -new file mode 100644 -index 0000000..a6d0625 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/b2b_interleaved_conv2d_run.h -@@ -0,0 +1,749 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+#include "cutlass/reduction/device/reduce_split_k.h" -+#include "cutlass/reduction/thread/reduction_operators.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/device/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/host_reorder.h" -+ -+#include "cutlass/util/reference/host/convolution.h" -+#include "cutlass/util/reference/device/convolution.h" -+#include "cutlass/util/reference/device/tensor_relu.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "reference/device/tensor_scale_bias.h" -+#include "helper.h" -+ -+#define CHECK_GT(val1, val2) \ -+ if((val1) <= (val2)) \ -+ std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n"; -+#define CHECK_TRUE(val) \ -+ if(!(val)) \ -+ std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n"; -+ -+ -+template -+class B2bInterleavedNonFusedConv2dRun { -+public: -+ -+ using Conv2d0 = Conv2d0_; -+ using Conv2d1 = Conv2d1_; -+ using ElementAccumulator = typename Conv2d0::ElementAccumulator; -+ using ElementCompute = typename Conv2d0::ElementCompute; -+ -+ static cutlass::conv::Operator const kConvolutionalOperator = Conv2d0::kConvolutionalOperator; -+ static_assert(kConvolutionalOperator == Conv2d1::kConvolutionalOperator, -+ "Fused convolution operators must be the same"); -+ -+public: -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ cutlass::Distribution::Kind init_Bias; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A0; -+ cutlass::HostTensor tensor_B0; -+ cutlass::HostTensor tensor_B0_reordered; -+ cutlass::HostTensor tensor_C0; -+ cutlass::HostTensor tensor_Bias0; -+ cutlass::HostTensor tensor_D0_computed; -+ cutlass::HostTensor tensor_D0_reference; -+ -+ cutlass::HostTensor tensor_B1; -+ cutlass::HostTensor tensor_B1_reordered; -+ cutlass::HostTensor tensor_C1; -+ cutlass::HostTensor tensor_Bias1; -+ cutlass::HostTensor tensor_D1_computed; -+ cutlass::HostTensor tensor_D1_reference; -+ -+ -+public: -+ -+ B2bInterleavedNonFusedConv2dRun( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { -+ -+ } -+ -+ /// Helper to initialize a tensor view -+ template -+ void initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ int scope; -+ int bits = cutlass::sizeof_bits::value; -+ -+ if (bits <= 16) { -+ scope = 2; -+ } -+ else { -+ scope = 8; -+ } -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope, -scope, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); -+ } -+ else if (dist_kind == cutlass::Distribution::AllZeros) { -+ cutlass::reference::host::TensorFill(view, Element(0)); -+ } -+ else if (dist_kind == cutlass::Distribution::AllOnes) { -+ cutlass::reference::host::TensorFill(view, Element(1)); -+ } -+ else { -+ } -+ } -+ -+ void initialize( -+ cutlass::conv::Conv2dProblemSize const &problem_size_0, -+ cutlass::conv::Conv2dProblemSize const &problem_size_1, uint64_t seed = 2019) { -+ -+ tensor_A0.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_B0.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_B0_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_C0.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_Bias0.resize({1, 1, 1, problem_size_0.K}); -+ tensor_D0_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1)); -+ tensor_B1_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1)); -+ tensor_C1.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); -+ tensor_Bias1.resize({1, 1, 1, problem_size_1.K}); -+ tensor_D1_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); -+ tensor_D1_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); -+ -+ initialize_tensor(tensor_A0.host_view(), init_A, seed); -+ initialize_tensor(tensor_B0.host_view(), init_B, seed * 17); -+ initialize_tensor(tensor_C0.host_view(), init_C, seed * 39); -+ initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed * 83); -+ initialize_tensor(tensor_B1.host_view(), init_B, seed * 18); -+ initialize_tensor(tensor_C1.host_view(), init_C, seed * 40); -+ -+ //Reorder B0 and B1 -+ cutlass::reorder_convK( -+ tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), implicit_gemm_problem_size(kConvolutionalOperator, problem_size_0)); -+ cutlass::reorder_convK( -+ tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), implicit_gemm_problem_size(kConvolutionalOperator, problem_size_1)); -+ -+ tensor_A0.sync_device(); -+ tensor_B0.sync_device(); -+ tensor_B0_reordered.sync_device(); -+ tensor_C0.sync_device(); -+ tensor_Bias0.sync_device(); -+ tensor_D0_computed.sync_device(); -+ tensor_D0_reference.sync_device(); -+ tensor_B1.sync_device(); -+ tensor_B1_reordered.sync_device(); -+ tensor_C1.sync_device(); -+ tensor_Bias1.sync_device(); -+ tensor_D1_computed.sync_device(); -+ tensor_D1_reference.sync_device(); -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::conv::Conv2dProblemSize const &problem_size_0, -+ cutlass::conv::Conv2dProblemSize const &problem_size_1, -+ cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, -+ ElementCompute alpha0 = ElementCompute(1), -+ ElementCompute beta0 = ElementCompute(0), -+ ElementCompute alpha1 = ElementCompute(1), -+ ElementCompute beta1 = ElementCompute(0), -+ bool relu = true, -+ int warm_ups = 1, -+ int runs = 100) { -+ -+ initialize(problem_size_0, problem_size_1); -+ -+ // configure the operator -+ Conv2d0 conv2d_op_0; -+ Conv2d1 conv2d_op_1; -+ -+ typename Conv2d0::Arguments conv2d_args_0( -+ problem_size_0, -+ tensor_A0.device_ref(), -+ tensor_B0_reordered.device_ref(), -+ tensor_C0.device_ref(), -+ tensor_D0_computed.device_ref(), -+ {alpha0, beta0}, -+ split_k_mode -+ ); -+ typename Conv2d1::Arguments conv2d_args_1( -+ problem_size_1, -+ tensor_D0_computed.device_ref(), -+ tensor_B1_reordered.device_ref(), -+ tensor_C1.device_ref(), -+ tensor_D1_computed.device_ref(), -+ {alpha1, beta1}, -+ split_k_mode -+ ); -+ -+ -+ cutlass::Status status = conv2d_op_0.initialize(conv2d_args_0); -+ -+ CUTLASS_CHECK(status); -+ -+ status = conv2d_op_1.initialize(conv2d_args_1); -+ -+ CUTLASS_CHECK(status); -+ -+ for(int i = 0; i < warm_ups; i++) { -+ status = conv2d_op_0(); -+ CUTLASS_CHECK(status); -+ status = conv2d_op_1(); -+ CUTLASS_CHECK(status); -+ } -+ -+ // -+ // Run Conv2d -+ // -+ cudaEvent_t start, stop1, stop2; -+ cudaEventCreate(&start); -+ cudaEventCreate(&stop1); -+ cudaEventCreate(&stop2); -+ -+ cudaEventRecord(start); -+ -+ -+ for(int i = 0; i < runs; i++) { -+ // run conv2d operator -+ status = conv2d_op_0(); -+ CUTLASS_CHECK(status); -+ } -+ cudaEventRecord(stop1); -+ -+ for(int i = 0; i < runs; i++) { -+ // run conv2d operator -+ status = conv2d_op_1(); -+ CUTLASS_CHECK(status); -+ } -+ cudaEventRecord(stop2); -+ cudaDeviceSynchronize(); -+ float conv2d0Time, conv2d1Time, totalTime; -+ cudaEventElapsedTime(&conv2d0Time, start, stop1); -+ cudaEventElapsedTime(&conv2d1Time, stop1, stop2); -+ cudaEventElapsedTime(&totalTime, start, stop2); -+ std::cout << "conv2d 0 time " << conv2d0Time / (float)runs << " ms\n"; -+ std::cout << "conv2d 1 time " << conv2d1Time / (float)runs << " ms\n"; -+ std::cout << "Non-fusion time " << totalTime / (float)runs << " ms\n"; -+ -+ tensor_D0_computed.sync_host(); -+ tensor_D1_computed.sync_host(); -+ -+ bool passed = false; -+ -+ cutlass::reference::device::Conv2d< -+ typename Conv2d0::ElementA, -+ typename Conv2d0::LayoutA, -+ typename Conv2d0::ElementB, -+ typename Conv2d0::LayoutB, -+ typename Conv2d0::ElementC, -+ typename Conv2d0::LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ cutlass::NumericConverterClamp -+ >( -+ kConvolutionalOperator, -+ problem_size_0, -+ tensor_A0.device_ref(), -+ tensor_B0.device_ref(), -+ tensor_C0.device_ref(), -+ tensor_D0_reference.device_ref(), -+ alpha0, -+ beta0); -+ -+ if(relu) { -+ cutlass::reference::device::TensorReLu(tensor_D0_reference.device_view()); -+ } -+ -+ cutlass::reference::device::Conv2d< -+ typename Conv2d1::ElementA, -+ typename Conv2d1::LayoutA, -+ typename Conv2d1::ElementB, -+ typename Conv2d1::LayoutB, -+ typename Conv2d1::ElementC, -+ typename Conv2d1::LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ cutlass::NumericConverterClamp -+ >( -+ kConvolutionalOperator, -+ problem_size_1, -+ tensor_D0_reference.device_ref(), -+ tensor_B1.device_ref(), -+ tensor_C1.device_ref(), -+ tensor_D1_reference.device_ref(), -+ alpha1, -+ beta1); -+ -+ if(relu) { -+ cutlass::reference::device::TensorReLu(tensor_D1_reference.device_view()); -+ } -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ CHECK_TRUE(result == cudaSuccess); -+ -+ // sync host (copy device data to host) for dumping error output in case of mismatches -+ tensor_D0_reference.sync_host(); -+ tensor_D1_reference.sync_host(); -+ -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_computed.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_reference.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_computed.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_reference.host_view()), 0); -+ -+ passed = cutlass::reference::host::TensorEquals( -+ tensor_D1_computed.host_view(), -+ tensor_D1_reference.host_view()); -+ -+ CHECK_TRUE(passed); -+ -+ if (!passed) { -+ std::stringstream fname; -+ -+ fname << "error_B2bImplicitGemm_device_interleaved_nonfused.txt"; -+ std::cerr << "Dumping results in " << fname.str() << "\n"; -+ -+ std::ofstream results(fname.str()); -+ -+ results << problem_size_0 << std::endl; -+ results << problem_size_1 << std::endl; -+ -+ results -+ << "\nA0:\n" << tensor_A0.host_view() << "\n" -+ << "\nB0:\n" << tensor_B0.host_view() << "\n" -+ << "\nB0_reordered:\n" << tensor_B0_reordered.host_view() << "\n" -+ << "\nC0:\n" << tensor_C0.host_view() << "\n" -+ << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" -+ << "\nD0 reference:\n" << tensor_D0_reference.host_view() << "\n" -+ << "\nD0 computed:\n" << tensor_D0_computed.host_view() << "\n" -+ << "\nB1:\n" << tensor_B1.host_view() << "\n" -+ << "\nB1_reordered:\n" << tensor_B1_reordered.host_view() << "\n" -+ << "\nC1:\n" << tensor_C1.host_view() << "\n" -+ << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" -+ << "\nD1 reference:\n" << tensor_D1_reference.host_view() << "\n" -+ << "\nD1 computed:\n" << tensor_D1_computed.host_view(); -+ -+ -+ } -+ -+ return passed; -+ } -+ -+}; -+ -+template -+class B2bInterleavedFusedConv2dRun { -+public: -+ -+ using B2bConv2d = B2bConv2d_; -+ using ElementAccumulator = typename B2bConv2d::ElementAccumulator; -+ using ElementCompute = typename B2bConv2d::ElementCompute; -+ -+ static cutlass::conv::Operator const kConvolutionalOperator = B2bConv2d::kConvolutionalOperator; -+ -+public: -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ cutlass::Distribution::Kind init_Scale; -+ cutlass::Distribution::Kind init_Bias; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A0; -+ cutlass::HostTensor tensor_B0; -+ cutlass::HostTensor tensor_B0_reordered; -+ cutlass::HostTensor tensor_C0; -+ cutlass::HostTensor tensor_Scale0; -+ cutlass::HostTensor tensor_Bias0; -+ cutlass::HostTensor tensor_Z0_reference; -+ cutlass::HostTensor tensor_D0_reference; -+ -+ cutlass::HostTensor tensor_B1; -+ cutlass::HostTensor tensor_B1_reordered; -+ cutlass::HostTensor tensor_C1; -+ cutlass::HostTensor tensor_Bias1; -+ cutlass::HostTensor tensor_D1_computed; -+ cutlass::HostTensor tensor_D1_reference; -+ -+ -+public: -+ -+ B2bInterleavedFusedConv2dRun( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), -+ init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { -+ -+ } -+ -+ /// Helper to initialize a tensor view -+ template -+ void initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ int scope; -+ int bits = cutlass::sizeof_bits::value; -+ -+ if (bits <= 16) { -+ scope = 2; -+ } -+ else { -+ scope = 8; -+ } -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope, -scope, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); -+ } -+ else if (dist_kind == cutlass::Distribution::AllZeros) { -+ cutlass::reference::host::TensorFill(view, Element(0)); -+ } -+ else if (dist_kind == cutlass::Distribution::AllOnes) { -+ cutlass::reference::host::TensorFill(view, Element(1)); -+ } -+ else { -+ } -+ } -+ -+ void initialize( -+ cutlass::conv::Conv2dProblemSize const &problem_size_0, -+ cutlass::conv::Conv2dProblemSize const &problem_size_1, -+ ElementCompute alpha0, -+ ElementCompute alpha1, -+ uint64_t seed = 2019) { -+ -+ tensor_A0.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_B0.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_B0_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_C0.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); -+ if(alpha0 == ElementCompute(0)) //per-channel scale -+ tensor_Scale0.resize({1, problem_size_0.K}); -+ tensor_Bias0.resize({1, problem_size_0.K}); -+ tensor_Z0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); -+ tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1)); -+ tensor_B1_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1)); -+ tensor_C1.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); -+ tensor_Bias1.resize({1, 1, 1, problem_size_1.K}); -+ tensor_D1_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); -+ tensor_D1_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); -+ -+ initialize_tensor(tensor_A0.host_view(), init_A, seed); -+ initialize_tensor(tensor_B0.host_view(), init_B, seed * 17); -+ initialize_tensor(tensor_C0.host_view(), init_C, seed * 39); -+ if(alpha0 == ElementCompute(0)) //per-channel scale -+ initialize_tensor(tensor_Scale0.host_view(), init_Scale, seed * 61); -+ initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed * 83); -+ initialize_tensor(tensor_B1.host_view(), init_B, seed * 18); -+ initialize_tensor(tensor_C1.host_view(), init_C, seed * 40); -+ initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed * 84); -+ -+ //Reorder B0 and B1 -+ cutlass::reorder_convK<16, InterleavedK>( -+ tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), implicit_gemm_problem_size(kConvolutionalOperator, problem_size_0)); -+ cutlass::reorder_convK( -+ tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), implicit_gemm_problem_size(kConvolutionalOperator, problem_size_1)); -+ -+ tensor_A0.sync_device(); -+ tensor_B0.sync_device(); -+ tensor_B0_reordered.sync_device(); -+ tensor_C0.sync_device(); -+ if(alpha0 == ElementCompute(0)) //per-channel scale -+ tensor_Scale0.sync_device(); -+ tensor_Bias0.sync_device(); -+ tensor_D0_reference.sync_device(); -+ tensor_B1.sync_device(); -+ tensor_B1_reordered.sync_device(); -+ tensor_C1.sync_device(); -+ tensor_Bias1.sync_device(); -+ tensor_D1_computed.sync_device(); -+ tensor_D1_reference.sync_device(); -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::conv::Conv2dProblemSize const &problem_size_0, -+ cutlass::conv::Conv2dProblemSize const &problem_size_1, -+ cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, -+ ElementCompute alpha0 = ElementCompute(1), -+ ElementCompute beta0 = ElementCompute(0), -+ ElementCompute alpha1 = ElementCompute(1), -+ ElementCompute beta1 = ElementCompute(0), -+ bool relu = true, -+ int warm_ups = 1, -+ int runs = 100) { -+ -+ initialize(problem_size_0, problem_size_1, alpha0, alpha1); -+ -+ // configure the operator -+ B2bConv2d b2b_conv2d_op; -+ -+ typename B2bConv2d::Arguments b2b_conv2d_args( -+ problem_size_0, -+ problem_size_1, -+ tensor_A0.device_ref(), -+ tensor_B0_reordered.device_ref(), -+ tensor_C0.device_ref(), -+ tensor_Scale0.device_ref(), -+ tensor_Bias0.device_ref(), -+ tensor_B1_reordered.device_ref(), -+ tensor_C1.device_ref(), -+ tensor_D1_computed.device_ref(), -+ {alpha0, beta0}, -+ {alpha1, beta1}, -+ split_k_mode -+ ); -+ -+ cutlass::Status status = b2b_conv2d_op.can_implement(b2b_conv2d_args); -+ -+ if(status != cutlass::Status::kSuccess) { -+ std::cout << "Problem sizes not supported.\n" -+ << "Requirments:\n" -+ << " problem_size_0.N*P*Q = problem_size_1.N*P*Q\n" -+ << " problem_size_0.K = problem_size_1.C\n" -+ << " problem_size_1.R = problem_size_1.S = 1\n" -+ << " ThreadblockShape0::kN = problem_size_0.K\n" -+ << " ThreadblockShape1::kN = problem_size_1.K" << std::endl; -+ } -+ -+ CUTLASS_CHECK(status); -+ -+ status = b2b_conv2d_op.initialize(b2b_conv2d_args); -+ -+ CUTLASS_CHECK(status); -+ -+ for(int i = 0; i < warm_ups; i++) { -+ status = b2b_conv2d_op(); -+ CUTLASS_CHECK(status); -+ } -+ -+ // -+ // Run the Conv2d -+ // -+ -+ cudaEvent_t start, stop; -+ cudaEventCreate(&start); -+ cudaEventCreate(&stop); -+ -+ cudaEventRecord(start); -+ -+ for(int i = 0; i < runs; i++) { -+ -+ // run conv2d operator -+ status = b2b_conv2d_op(); -+ CUTLASS_CHECK(status); -+ } -+ -+ cudaEventRecord(stop); -+ cudaDeviceSynchronize(); -+ float conv2dTime; -+ cudaEventElapsedTime(&conv2dTime, start, stop); -+ std::cout << "Fusion time " << conv2dTime / (float)runs << " ms\n"; -+ -+ tensor_D1_computed.sync_host(); -+ -+ bool passed = false; -+ -+ cutlass::reference::device::Conv2d< -+ typename B2bConv2d::ElementA, -+ typename B2bConv2d::LayoutA, -+ typename B2bConv2d::ElementB, -+ typename B2bConv2d::LayoutB, -+ ElementAccumulator, -+ typename B2bConv2d::LayoutC, -+ ElementAccumulator, -+ ElementAccumulator -+ >( -+ kConvolutionalOperator, -+ problem_size_0, -+ tensor_A0.device_ref(), -+ tensor_B0.device_ref(), -+ tensor_Z0_reference.device_ref(), -+ tensor_Z0_reference.device_ref(), -+ ElementAccumulator(1), // intermediate alpha = 1 -+ ElementAccumulator(0) // beta = 0 -+ ); -+ -+ cutlass::reference::device::TensorScaleBiasConv2d< -+ ElementAccumulator, -+ typename B2bConv2d::ElementC, -+ typename B2bConv2d::LayoutC, -+ ElementCompute, -+ typename B2bConv2d::LayoutScaleBias, -+ cutlass::NumericConverterClamp -+ >( -+ problem_size_0, -+ tensor_Z0_reference.device_ref(), -+ tensor_D0_reference.device_ref(), -+ alpha0, -+ tensor_Scale0.device_ref(), -+ tensor_Bias0.device_ref() -+ ); -+ -+ if(relu) { -+ cutlass::reference::device::TensorReLu(tensor_D0_reference.device_view()); -+ } -+ -+ cutlass::reference::device::Conv2d< -+ typename B2bConv2d::ElementA, -+ typename B2bConv2d::LayoutA, -+ typename B2bConv2d::ElementB, -+ typename B2bConv2d::LayoutB, -+ typename B2bConv2d::ElementC, -+ typename B2bConv2d::LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ cutlass::NumericConverterClamp -+ >( -+ kConvolutionalOperator, -+ problem_size_1, -+ tensor_D0_reference.device_ref(), -+ tensor_B1.device_ref(), -+ tensor_C1.device_ref(), -+ tensor_D1_reference.device_ref(), -+ alpha1, -+ beta1); -+ -+ if(relu) { -+ cutlass::reference::device::TensorReLu(tensor_D1_reference.device_view()); -+ } -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ CHECK_TRUE(result == cudaSuccess); -+ -+ // sync host (copy device data to host) for dumping error output in case of mismatches -+ tensor_D0_reference.sync_host(); -+ tensor_D1_reference.sync_host(); -+ -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_reference.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_computed.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_reference.host_view()), 0); -+ -+ passed = cutlass::reference::host::TensorEquals( -+ tensor_D1_computed.host_view(), -+ tensor_D1_reference.host_view()); -+ -+ CHECK_TRUE(passed); -+ -+ if (!passed) { -+ std::stringstream fname; -+ -+ fname << "error_B2bImplicitGemm_device_interleaved_fused.txt"; -+ std::cerr << "Dumping results in " << fname.str() << "\n"; -+ -+ std::ofstream results(fname.str()); -+ -+ results << problem_size_0 << std::endl; -+ results << problem_size_1 << std::endl; -+ -+ results -+ << "\nA0:\n" << tensor_A0.host_view() << "\n" -+ << "\nB0:\n" << tensor_B0.host_view() << "\n" -+ << "\nB0_reordered:\n" << tensor_B0_reordered.host_view() << "\n" -+ << "\nC0:\n" << tensor_C0.host_view() << "\n" -+ << "\nScale0:\n" << tensor_Scale0.host_view() << "\n" -+ << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" -+ << "\nB1:\n" << tensor_B1.host_view() << "\n" -+ << "\nB1_reordered:\n" << tensor_B1_reordered.host_view() << "\n" -+ << "\nC1:\n" << tensor_C1.host_view() << "\n" -+ << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" -+ << "\nD1 reference:\n" << tensor_D1_reference.host_view() << "\n" -+ << "\nD1 computed:\n" << tensor_D1_computed.host_view(); -+ -+ -+ } -+ -+ return passed; -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/b2b_interleaved_gemm_run.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/b2b_interleaved_gemm_run.h -new file mode 100644 -index 0000000..51ff1bb ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/b2b_interleaved_gemm_run.h -@@ -0,0 +1,749 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/host_reorder.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/device/tensor_relu.h" -+ -+#include "reference/device/tensor_scale_bias.h" -+#include "helper.h" -+ -+#define CHECK_GT(val1, val2) \ -+ if((val1) <= (val2)) \ -+ std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n"; -+#define CHECK_TRUE(val) \ -+ if(!(val)) \ -+ std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n"; -+ -+template -+struct B2bInterleavedNonFusedGemmRun -+{ -+ -+ using Gemm0 = Gemm0_; -+ using Gemm1 = Gemm1_; -+ using ElementAccumulator = typename Gemm0::ElementAccumulator; -+ using ElementCompute = typename Gemm0::GemmKernel::Epilogue::OutputOp::ElementCompute; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ cutlass::Distribution::Kind init_Bias; -+ uint64_t seed; -+ -+ // -+ // Methods -+ // -+ -+ B2bInterleavedNonFusedGemmRun( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, 2, -2, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else if (dist_kind == cutlass::Distribution::AllZeros) { -+ cutlass::reference::host::TensorFill(view, Element(0)); -+ } -+ else if (dist_kind == cutlass::Distribution::AllOnes) { -+ cutlass::reference::host::TensorFill(view, Element(1)); -+ } -+ else { -+ std::cerr << "Not implemented\n"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ -+ -+ -+ /// Executes one test -+ bool run( -+ cutlass::gemm::GemmCoord problem_size_0, -+ cutlass::gemm::GemmCoord problem_size_1, -+ ElementCompute alpha0 = ElementCompute(1), -+ ElementCompute beta0 = ElementCompute(0), -+ ElementCompute alpha1 = ElementCompute(1), -+ ElementCompute beta1 = ElementCompute(0), -+ bool relu = true, -+ int warm_ups = 1, -+ int runs = 100) { -+ -+ // -+ // Allocate the GEMM workspace -+ // -+ -+ cutlass::HostTensor< -+ typename Gemm0::ElementA, -+ typename Gemm0::LayoutA> tensor_A0(problem_size_0.mk()); -+ -+ cutlass::HostTensor< -+ typename Gemm0::ElementB, -+ typename Gemm0::LayoutB> tensor_B0(problem_size_0.kn()); -+ -+ cutlass::HostTensor< -+ typename Gemm0::ElementB, -+ typename Gemm0::LayoutB> tensor_B0_reordered(problem_size_0.kn()); -+ -+ cutlass::HostTensor< -+ typename Gemm0::ElementC, -+ typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn()); -+ -+ cutlass::HostTensor< -+ typename Gemm0::ElementC, -+ typename Gemm0::LayoutC> tensor_Bias0({1, problem_size_0.n()}); -+ -+ cutlass::HostTensor< -+ typename Gemm0::ElementC, -+ typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn()); -+ -+ cutlass::HostTensor< -+ typename Gemm0::ElementC, -+ typename Gemm0::LayoutC> reference_D0(problem_size_0.mn()); -+ -+ cutlass::HostTensor< -+ typename Gemm1::ElementB, -+ typename Gemm1::LayoutB> tensor_B1(problem_size_1.kn()); -+ -+ cutlass::HostTensor< -+ typename Gemm1::ElementB, -+ typename Gemm1::LayoutB> tensor_B1_reordered(problem_size_1.kn()); -+ -+ cutlass::HostTensor< -+ typename Gemm1::ElementC, -+ typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn()); -+ -+ cutlass::HostTensor< -+ typename Gemm0::ElementC, -+ typename Gemm1::LayoutC> tensor_Bias1({1, problem_size_1.n()}); -+ -+ cutlass::HostTensor< -+ typename Gemm1::ElementC, -+ typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn()); -+ -+ cutlass::HostTensor< -+ typename Gemm1::ElementC, -+ typename Gemm1::LayoutC> reference_D1(problem_size_1.mn()); -+ -+ -+ CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); -+ CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018)); -+ CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); -+ CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2014)); -+ CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016)); -+ CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); -+ CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2013)); -+ -+ //Reorder B0 and B1 -+ cutlass::reorder_column( -+ tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), problem_size_0); -+ cutlass::reorder_column( -+ tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), problem_size_1); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D0.host_view()); -+ cutlass::reference::host::TensorFill( -+ tensor_D1.host_view()); -+ cutlass::reference::host::TensorFill( -+ reference_D0.host_view()); -+ cutlass::reference::host::TensorFill( -+ reference_D1.host_view()); -+ -+ tensor_A0.sync_device(); -+ tensor_B0.sync_device(); -+ tensor_B0_reordered.sync_device(); -+ tensor_C0.sync_device(); -+ tensor_Bias0.sync_device(); -+ tensor_D0.sync_device(); -+ tensor_B1.sync_device(); -+ tensor_B1_reordered.sync_device(); -+ tensor_C1.sync_device(); -+ tensor_Bias1.sync_device(); -+ tensor_D1.sync_device(); -+ reference_D0.sync_device(); -+ reference_D1.sync_device(); -+ -+ // -+ // Initialize the GEMM operator -+ // -+ -+ typename Gemm0::Arguments arguments_0{ -+ problem_size_0, -+ tensor_A0.device_ref(), -+ tensor_B0_reordered.device_ref(), -+ {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}, -+ tensor_D0.device_ref(), -+ {alpha0, beta0} -+ }; -+ -+ typename Gemm1::Arguments arguments_1{ -+ problem_size_1, -+ tensor_D0.device_ref(), -+ tensor_B1_reordered.device_ref(), -+ {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}, -+ tensor_D1.device_ref(), -+ {alpha1, beta1} -+ }; -+ -+ -+ Gemm0 gemm_op_0; -+ Gemm1 gemm_op_1; -+ -+ cutlass::Status status = gemm_op_0.initialize(arguments_0); -+ -+ CUTLASS_CHECK(status); -+ -+ status = gemm_op_1.initialize(arguments_1); -+ -+ CUTLASS_CHECK(status); -+ -+ for(int i = 0; i < warm_ups; i++) { -+ status = gemm_op_0(); -+ CUTLASS_CHECK(status); -+ status = gemm_op_1(); -+ CUTLASS_CHECK(status); -+ } -+ -+ // -+ // Run the GEMM -+ // -+ cudaEvent_t start, stop1, stop2; -+ cudaEventCreate(&start); -+ cudaEventCreate(&stop1); -+ cudaEventCreate(&stop2); -+ -+ cudaEventRecord(start); -+ -+ for(int i = 0; i < runs; i++) { -+ status = gemm_op_0(); -+ -+ CUTLASS_CHECK(status); -+ } -+ cudaEventRecord(stop1); -+ for(int i = 0; i < runs; i++) { -+ status = gemm_op_1(); -+ -+ CUTLASS_CHECK(status); -+ } -+ -+ cudaEventRecord(stop2); -+ cudaDeviceSynchronize(); -+ float gemm0Time, gemm1Time, totalTime; -+ cudaEventElapsedTime(&gemm0Time, start, stop1); -+ cudaEventElapsedTime(&gemm1Time, stop1, stop2); -+ cudaEventElapsedTime(&totalTime, start, stop2); -+ std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n"; -+ std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n"; -+ std::cout << "Non-fusion time " << totalTime / (float)runs << " ms\n"; -+ -+ tensor_D0.sync_host(); -+ tensor_D1.sync_host(); -+ -+ // -+ // Verify -+ // -+ cutlass::reference::device::Gemm< -+ typename Gemm0::ElementA, typename Gemm0::LayoutA, -+ typename Gemm0::ElementB, typename Gemm0::LayoutB, -+ typename Gemm0::ElementC, typename Gemm0::LayoutC, ElementCompute, -+ ElementAccumulator, typename Gemm0::Operator> -+ reference_gemm_0; -+ -+ cutlass::reference::device::Gemm< -+ typename Gemm1::ElementA, typename Gemm1::LayoutA, -+ typename Gemm1::ElementB, typename Gemm1::LayoutB, -+ typename Gemm1::ElementC, typename Gemm1::LayoutC, ElementCompute, -+ ElementAccumulator, typename Gemm1::Operator> -+ reference_gemm_1; -+ -+ reference_gemm_0( -+ problem_size_0, -+ alpha0, -+ tensor_A0.device_ref(), -+ tensor_B0.device_ref(), -+ beta0, -+ {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}, -+ reference_D0.device_ref() -+ ); -+ -+ if(relu) { -+ cutlass::reference::device::TensorReLu(reference_D0.device_view()); -+ } -+ -+ reference_gemm_1( -+ problem_size_1, -+ alpha1, -+ reference_D0.device_ref(), -+ tensor_B1.device_ref(), -+ beta1, -+ {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}, -+ reference_D1.device_ref() -+ ); -+ -+ if(relu) { -+ cutlass::reference::device::TensorReLu(reference_D1.device_view()); -+ } -+ -+ // Wait for kernels to finish -+ cudaDeviceSynchronize(); -+ reference_D0.sync_host(); -+ reference_D1.sync_host(); -+ -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ reference_D1.host_view(), -+ tensor_D1.host_view()); -+ -+ CHECK_TRUE(passed); -+ if (!passed) { -+ -+ std::stringstream fname; -+ -+ fname << "error_B2bGemm_device_interleaved_nonfused.txt"; -+ std::cerr << "Dumping results in " << fname.str() << "\n"; -+ -+ std::ofstream file(fname.str()); -+ -+ file -+ << "A0 =\n" << tensor_A0.host_view() -+ << "\nB0 =\n" << tensor_B0.host_view() -+ << "\nB0_reordered =\n" << tensor_B0_reordered.host_view() -+ << "\nC0 =\n" << tensor_C0.host_view() -+ << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" -+ << "\nD0 =\n" << tensor_D0.host_view() -+ << "\nB1 =\n" << tensor_B1.host_view() -+ << "\nB1_reordered =\n" << tensor_B1_reordered.host_view() -+ << "\nC1 =\n" << tensor_C1.host_view() -+ << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" -+ << "\n\nReference =\n" << reference_D1.host_view() -+ << "\nComputed =\n" << tensor_D1.host_view(); -+ } -+ return passed; -+ } -+}; -+ -+template -+struct B2bInterleavedFusedGemmRun -+{ -+ -+ using B2bGemm = B2bGemm_; -+ using ElementAccumulator = typename B2bGemm::ElementAccumulator; -+ using ElementCompute = typename B2bGemm::B2bGemmKernel::Epilogue::OutputOp::ElementCompute; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ cutlass::Distribution::Kind init_Scale; -+ cutlass::Distribution::Kind init_Bias; -+ uint64_t seed; -+ -+ // -+ // Methods -+ // -+ -+ B2bInterleavedFusedGemmRun( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), -+ init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, 2, -2, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else if (dist_kind == cutlass::Distribution::AllZeros) { -+ cutlass::reference::host::TensorFill(view, Element(0)); -+ } -+ else if (dist_kind == cutlass::Distribution::AllOnes) { -+ cutlass::reference::host::TensorFill(view, Element(1)); -+ } -+ else { -+ std::cerr << "Not implemented\n"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ -+ -+ -+ /// Executes one test -+ bool run( -+ cutlass::gemm::GemmCoord problem_size_0, -+ cutlass::gemm::GemmCoord problem_size_1, -+ ElementCompute alpha0 = ElementCompute(1), -+ ElementCompute beta0 = ElementCompute(0), -+ ElementCompute alpha1 = ElementCompute(1), -+ ElementCompute beta1 = ElementCompute(0), -+ bool relu = true, -+ int warm_ups = 1, -+ int runs = 100) { -+ -+ // -+ // Allocate the GEMM workspace -+ // -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementA, -+ typename B2bGemm::LayoutA> tensor_A0(problem_size_0.mk()); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementB, -+ typename B2bGemm::LayoutB> tensor_B0(problem_size_0.kn()); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementB, -+ typename B2bGemm::LayoutB> tensor_B0_reordered(problem_size_0.kn()); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementC, -+ typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn()); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementScaleBias, -+ typename B2bGemm::LayoutScaleBias> tensor_Scale0; -+ -+ if(alpha0 == ElementCompute(0)) //per-channel scale -+ tensor_Scale0.resize({1, problem_size_0.n()}); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementScaleBias, -+ typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, problem_size_0.n()}); -+ -+ cutlass::HostTensor< -+ ElementAccumulator, -+ typename B2bGemm::LayoutC> reference_Z0(problem_size_0.mn()); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementC, -+ typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn()); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementB, -+ typename B2bGemm::LayoutB> tensor_B1(problem_size_1.kn()); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementB, -+ typename B2bGemm::LayoutB> tensor_B1_reordered(problem_size_1.kn()); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementC, -+ typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn()); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementC, -+ typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, problem_size_1.n()}); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementC, -+ typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn()); -+ -+ cutlass::HostTensor< -+ typename B2bGemm::ElementC, -+ typename B2bGemm::LayoutC> reference_D1(problem_size_1.mn()); -+ -+ -+ CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); -+ CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018)); -+ CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); -+ if(alpha0 == ElementCompute(0)) //per-channel scale -+ CHECK_TRUE(initialize_tensor(tensor_Scale0.host_view(), init_Scale, seed + 2014)); -+ CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2013)); -+ CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016)); -+ CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); -+ CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2012)); -+ -+ //Reorder B0 -+ cutlass::reorder_column<16>( -+ tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), problem_size_0); -+ cutlass::reorder_column( -+ tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), problem_size_1); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D1.host_view()); -+ cutlass::reference::host::TensorFill( -+ reference_D0.host_view()); -+ cutlass::reference::host::TensorFill( -+ reference_D1.host_view()); -+ -+ tensor_A0.sync_device(); -+ tensor_B0.sync_device(); -+ tensor_B0_reordered.sync_device(); -+ tensor_C0.sync_device(); -+ if(alpha0 == ElementCompute(0)) //per-channel scale -+ tensor_Scale0.sync_device(); -+ tensor_Bias0.sync_device(); -+ tensor_B1.sync_device(); -+ tensor_B1_reordered.sync_device(); -+ tensor_C1.sync_device(); -+ tensor_Bias1.sync_device(); -+ tensor_D1.sync_device(); -+ reference_D0.sync_device(); -+ reference_D1.sync_device(); -+ -+ // -+ // Initialize the GEMM operator -+ // -+ -+ typename B2bGemm::Arguments arguments{ -+ problem_size_0, -+ problem_size_1, -+ tensor_A0.device_ref(), -+ tensor_B0_reordered.device_ref(), -+ tensor_C0.device_ref(), -+ tensor_Scale0.device_ref(), -+ tensor_Bias0.device_ref(), -+ tensor_B1_reordered.device_ref(), -+ {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)}, -+ tensor_D1.device_ref(), -+ {alpha0, beta0}, -+ {alpha1, beta1}, -+ }; -+ -+ B2bGemm b2b_gemm_op; -+ -+ cutlass::Status status = b2b_gemm_op.can_implement(arguments); -+ -+ if(status != cutlass::Status::kSuccess) { -+ std::cout << "Problem sizes not supported.\n" -+ << "Requirments:\n" -+ << " problem_size_0.M = problem_size_1.M\n" -+ << " problem_size_0.N = problem_size_1.K\n" -+ << " ThreadblockShape0::kN = problem_size_0.N\n" -+ << " ThreadblockShape1::kN = problem_size_1.N" << std::endl; -+ } -+ -+ status = b2b_gemm_op.initialize(arguments); -+ -+ CUTLASS_CHECK(status); -+ -+ for(int i = 0; i < warm_ups; i++) { -+ status = b2b_gemm_op(); -+ CUTLASS_CHECK(status); -+ } -+ -+ // -+ // Run the GEMM -+ // -+ -+ cudaEvent_t start, stop; -+ cudaEventCreate(&start); -+ cudaEventCreate(&stop); -+ -+ cudaEventRecord(start); -+ -+ for(int i = 0; i < runs; i++) { -+ status = b2b_gemm_op(); -+ -+ CUTLASS_CHECK(status); -+ } -+ -+ cudaEventRecord(stop); -+ cudaDeviceSynchronize(); -+ float gemmTime; -+ cudaEventElapsedTime(&gemmTime, start, stop); -+ std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n"; -+ -+ tensor_D1.sync_host(); -+ -+ // -+ // Verify -+ // -+ -+ cutlass::reference::device::Gemm< -+ typename B2bGemm::ElementA, typename B2bGemm::LayoutA, -+ typename B2bGemm::ElementB, typename B2bGemm::LayoutB, -+ ElementAccumulator, typename B2bGemm::LayoutC, -+ ElementAccumulator, ElementAccumulator> -+ reference_gemm_0; -+ -+ cutlass::reference::device::Gemm< -+ typename B2bGemm::ElementA, typename B2bGemm::LayoutA, -+ typename B2bGemm::ElementB, typename B2bGemm::LayoutB, -+ typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute, -+ ElementAccumulator, typename B2bGemm::Operator> -+ reference_gemm_1; -+ -+ reference_gemm_0( -+ problem_size_0, -+ ElementAccumulator(1), //intermediate alpha=1 -+ tensor_A0.device_ref(), -+ tensor_B0.device_ref(), -+ ElementAccumulator(0), //beta = 0 -+ reference_Z0.device_ref(), -+ reference_Z0.device_ref(), -+ ElementAccumulator(0) -+ ); -+ -+ cutlass::reference::device::TensorScaleBiasGemm< -+ ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC, -+ ElementCompute, typename B2bGemm::LayoutScaleBias -+ > ( -+ problem_size_0, -+ reference_Z0.device_ref(), -+ reference_D0.device_ref(), -+ alpha0, -+ tensor_Scale0.device_ref(), -+ tensor_Bias0.device_ref() -+ ); -+ -+ if(relu) { -+ cutlass::reference::device::TensorReLu(reference_D0.device_view()); -+ } -+ -+ reference_gemm_1( -+ problem_size_1, -+ alpha1, -+ reference_D0.device_ref(), -+ tensor_B1.device_ref(), -+ beta1, -+ {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)}, -+ reference_D1.device_ref() -+ ); -+ if(relu) { -+ cutlass::reference::device::TensorReLu(reference_D1.device_view()); -+ } -+ cudaDeviceSynchronize(); -+ reference_D0.sync_host(); -+ reference_D1.sync_host(); -+ -+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ reference_D1.host_view(), -+ tensor_D1.host_view()); -+ -+ CHECK_TRUE(passed); -+ if (!passed) -+ { -+ -+ std::stringstream fname; -+ -+ fname << "error_B2bGemm_device_interleaved_fused.txt"; -+ std::cerr << "Dumping results in " << fname.str() << "\n"; -+ -+ std::ofstream file(fname.str()); -+ -+ file -+ << "A0 =\n" << tensor_A0.host_view() -+ << "\nB0 =\n" << tensor_B0.host_view() -+ << "\nB0_reordered =\n" << tensor_B0_reordered.host_view() -+ << "\nC0 =\n" << tensor_C0.host_view() -+ << "\nScale0:\n" << tensor_Scale0.host_view() << "\n" -+ << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" -+ << "\nB1 =\n" << tensor_B1.host_view() -+ << "\nB1_reordered =\n" << tensor_B1_reordered.host_view() -+ << "\nC1 =\n" << tensor_C1.host_view() -+ << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" -+ << "\n\nReference =\n" << reference_D1.host_view() -+ << "\nComputed =\n" << tensor_D1.host_view(); -+ } -+ return passed; -+ } -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/device/b2b_gemm.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/device/b2b_gemm.h -new file mode 100644 -index 0000000..f365b23 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/device/b2b_gemm.h -@@ -0,0 +1,451 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+ -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+#include "cutlass/epilogue/thread/linear_combination_relu.h" -+ -+#include "kernel/b2b_gemm.h" -+#include "kernel/default_b2b_gemm.h" -+#include "kernel/default_b2b_gemm_smem_accumulator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_ = ElementC_, -+ /// Operator class tag -+ typename OperatorClass_ = arch::OpClassSimt, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_ = arch::Sm70, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape0_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::ThreadblockShape, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape1_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape0_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape1_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp0_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::EpilogueOutputOp, -+ /// Epilogue output operator -+ typename EpilogueOutputOp1_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, -+ /// Number of stages used in the pipelined mainloop -+ int Stages = -+ DefaultGemmConfiguration::kStages, -+ /// Stage accumulator in shared memory -+ bool SmemAccumulator = false, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = -+ DefaultGemmConfiguration::kAlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = -+ DefaultGemmConfiguration::kAlignmentB, -+ /// If true, kernel supports split-K with serial reduction -+ bool SplitKSerial = false, -+ /// Operation performed by GEMM -+ typename Operator_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::Operator> -+class B2bGemm { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape0 = ThreadblockShape0_; -+ using ThreadblockShape1 = ThreadblockShape1_; -+ using WarpShape0 = WarpShape0_; -+ using WarpShape1 = WarpShape1_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp0 = EpilogueOutputOp0_; -+ using EpilogueOutputOp1 = EpilogueOutputOp1_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static int const kAlignmentC = EpilogueOutputOp1::kCount; -+ static bool const kSplitKSerial = SplitKSerial; -+ static ComplexTransform const kTransformA = ComplexTransform::kNone; -+ static ComplexTransform const kTransformB = ComplexTransform::kNone; -+ -+ /// Derived types -+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; -+ -+ /// Define the kernel -+ using B2bGemmKernel = typename kernel::DefaultB2bGemm< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ ThreadblockSwizzle, -+ kStages, -+ kSplitKSerial, -+ Operator, -+ SmemAccumulator -+ >::B2bGemmKernel; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmCoord problem_size_0; -+ GemmCoord problem_size_1; -+ TensorRef ref_A0; -+ TensorRef ref_B0; -+ TensorRef ref_C0; -+ TensorRef ref_Scale0; -+ TensorRef ref_Bias0; -+ TensorRef ref_B1; -+ TensorRef ref_C1; -+ TensorRef ref_D1; -+ typename EpilogueOutputOp0::Params epilogue0; -+ typename EpilogueOutputOp1::Params epilogue1; -+ int split_k_slices; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments(): problem_size_0(0, 0, 0), problem_size_1(0, 0, 0), split_k_slices(1) { -+ -+ } -+ -+ /// Constructs an Arguments structure -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ GemmCoord problem_size_0_, -+ GemmCoord problem_size_1_, -+ TensorRef ref_A0_, -+ TensorRef ref_B0_, -+ TensorRef ref_C0_, -+ TensorRef ref_Scale0_, -+ TensorRef ref_Bias0_, -+ TensorRef ref_B1_, -+ TensorRef ref_C1_, -+ TensorRef ref_D1_, -+ typename EpilogueOutputOp0::Params epilogue0_ = -+ typename EpilogueOutputOp0::Params(), -+ typename EpilogueOutputOp1::Params epilogue1_ = -+ typename EpilogueOutputOp1::Params(), -+ int split_k_slices_ = 1 -+ ): -+ problem_size_0(problem_size_0_), -+ problem_size_1(problem_size_1_), -+ ref_A0(ref_A0_), -+ ref_B0(ref_B0_), -+ ref_C0(ref_C0_), -+ ref_Scale0(ref_Scale0_), -+ ref_Bias0(ref_Bias0_), -+ ref_B1(ref_B1_), -+ ref_C1(ref_C1_), -+ ref_D1(ref_D1_), -+ epilogue0(epilogue0_), -+ epilogue1(epilogue1_), -+ split_k_slices(split_k_slices_) { -+ -+ } -+ }; -+ -+private: -+ -+ /// Kernel parameters object -+ typename B2bGemmKernel::Params params_; -+ -+public: -+ -+ /// Constructs the GEMM. -+ B2bGemm() { } -+ -+ /// Determines whether the GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ if (!kSplitKSerial && args.split_k_slices > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ Status status = B2bGemmKernel::can_implement( -+ args.problem_size_0, -+ args.problem_size_1, -+ args.ref_A0.non_const_ref(), -+ args.ref_B0.non_const_ref(), -+ args.ref_C0.non_const_ref(), -+ args.ref_B1.non_const_ref(), -+ args.ref_C1.non_const_ref(), -+ args.ref_D1 -+ ); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ size_t bytes = 0; -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size_0, -+ {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK}, -+ args.split_k_slices); -+ -+ if (kSplitKSerial && args.split_k_slices > 1) { -+ -+ -+ bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); -+ } -+ -+ return bytes; -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size_0, -+ {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK}, -+ args.split_k_slices); -+// cutlass::gemm::GemmCoord grid_shape_1 = threadblock_swizzle.get_tiled_shape( -+// args.problem_size_1, -+// {ThreadblockShape1::kM, ThreadblockShape1::kN, ThreadblockShape1::kK}, -+// args.split_k_slices); -+ -+ if (kSplitKSerial) { -+ if (args.split_k_slices > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ size_t bytes = get_workspace_size(args); -+ -+ cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ } -+ else { -+ -+ if (args.split_k_slices > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ // Initialize the Params structure -+ params_ = typename B2bGemmKernel::Params{ -+ args.problem_size_0, -+ args.problem_size_1, -+ grid_shape, -+ args.ref_A0.non_const_ref(), -+ args.ref_B0.non_const_ref(), -+ args.ref_C0.non_const_ref(), -+ args.ref_Scale0.non_const_ref(), -+ args.ref_Bias0.non_const_ref(), -+ args.ref_B1.non_const_ref(), -+ args.ref_C1.non_const_ref(), -+ args.ref_D1, -+ args.epilogue0, -+ args.epilogue1, -+ static_cast(workspace), -+ }; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ if (kSplitKSerial && args.split_k_slices > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ } -+ -+ params_.ref_A0.reset(args.ref_A0.non_const_ref().data()); -+ params_.ref_B0.reset(args.ref_B0.non_const_ref().data()); -+ params_.ref_C0.reset(args.ref_C0.non_const_ref().data()); -+ params_.ref_Scale0.reset(args.ref_Scale0.non_const_ref().data()); -+ params_.ref_Bias0.reset(args.ref_Bias0.non_const_ref().data()); -+ params_.ref_B1.reset(args.ref_B1.non_const_ref().data()); -+ params_.ref_C1.reset(args.ref_C1.non_const_ref().data()); -+ params_.ref_D1.reset(args.ref_D1.data()); -+ params_.output_op_0 = args.epilogue0; -+ params_.output_op_1 = args.epilogue1; -+ params_.semaphore = static_cast(workspace); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); -+ dim3 block(B2bGemmKernel::kThreadCount, 1, 1); -+ -+ cudaError_t result; -+ -+ int smem_size = int(sizeof(typename B2bGemmKernel::SharedStorage)); -+ if (smem_size >= (48 << 10)) { -+ result = cudaFuncSetAttribute(Kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ cutlass::Kernel<<>>(params_); -+ -+ result = cudaGetLastError(); -+ -+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/device/b2b_implicit_gemm_convolution.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/device/b2b_implicit_gemm_convolution.h -new file mode 100644 -index 0000000..b52d058 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/device/b2b_implicit_gemm_convolution.h -@@ -0,0 +1,300 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Template for device-level Implicit GEMM -+*/ -+ -+#pragma once -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/device_kernel.h" -+#include "cutlass/conv/convolution.h" -+ -+#include "kernel/b2b_implicit_gemm_convolution.h" -+#include "kernel/default_b2b_conv2d_fprop.h" -+#include "kernel/default_b2b_conv2d_fprop_sm75.h" -+#include "kernel/default_b2b_conv2d_fprop_sm80.h" -+#include "kernel/default_b2b_conv2d_fprop_smem_accumulator_sm75.h" -+#include "kernel/default_b2b_conv2d_fprop_smem_accumulator_sm80.h" -+ -+namespace cutlass { -+namespace conv { -+namespace device { -+ -+template -+class B2bImplicitGemmConvolution { -+public: -+ -+ using B2bImplicitGemmKernel = B2bImplicitGemmKernel_; -+ -+ using ElementA = typename B2bImplicitGemmKernel::ElementA; -+ using LayoutA = typename B2bImplicitGemmKernel::LayoutA; -+ using ElementB = typename B2bImplicitGemmKernel::ElementB; -+ using LayoutB = typename B2bImplicitGemmKernel::LayoutB; -+ using ElementC = typename B2bImplicitGemmKernel::ElementC; -+ using LayoutC = typename B2bImplicitGemmKernel::LayoutC; -+ using ElementAccumulator = typename B2bImplicitGemmKernel::ElementAccumulator; -+ using ElementCompute = typename B2bImplicitGemmKernel::ElementCompute; -+ using ElementScaleBias = typename B2bImplicitGemmKernel::ElementScaleBias; -+ using LayoutScaleBias = typename B2bImplicitGemmKernel::LayoutScaleBias; -+ using OperatorClass = typename B2bImplicitGemmKernel::OperatorClass; -+ using ArchTag = typename B2bImplicitGemmKernel::ArchTag; -+ using ThreadblockShape0 = typename B2bImplicitGemmKernel::ThreadblockShape0; -+ using ThreadblockShape1 = typename B2bImplicitGemmKernel::ThreadblockShape1; -+ using WarpShape0 = typename B2bImplicitGemmKernel::WarpShape0; -+ using WarpShape1 = typename B2bImplicitGemmKernel::WarpShape1; -+ using InstructionShape = typename B2bImplicitGemmKernel::InstructionShape; -+ using ThreadblockSwizzle = typename B2bImplicitGemmKernel::ThreadblockSwizzle; -+ using EpilogueOutputOp0 = typename B2bImplicitGemmKernel::EpilogueOutputOp0; -+ using EpilogueOutputOp1 = typename B2bImplicitGemmKernel::EpilogueOutputOp1; -+ static int const kStages = B2bImplicitGemmKernel::kStages; -+ static int const kConvDim = B2bImplicitGemmKernel::kConvDim; -+ using WarpMmaOperator0 = typename B2bImplicitGemmKernel::WarpMmaOperator0; -+ using WarpMmaOperator1 = typename B2bImplicitGemmKernel::WarpMmaOperator1; -+ using ArchMmaOperator = typename B2bImplicitGemmKernel::ArchMmaOperator; -+ using MathOperator = typename B2bImplicitGemmKernel::MathOperator; -+ -+ static cutlass::conv::Operator const kConvolutionalOperator = B2bImplicitGemmKernel::kConvolutionalOperator; -+ static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = B2bImplicitGemmKernel::kIteratorAlgorithm; -+ -+ static int const kWarpCount = -+ (ThreadblockShape0::kM / WarpShape0::kM) * -+ (ThreadblockShape0::kN / WarpShape0::kN); -+ -+ /// Argument structure -+ using Arguments = typename B2bImplicitGemmKernel::Arguments; -+ -+private: -+ -+ /// Kernel parameters object -+ typename B2bImplicitGemmKernel::Params params_; -+ -+public: -+ -+ /// Constructs Implicit GEMM -+ B2bImplicitGemmConvolution() { } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ // dispatch to iterators -+ Status status = B2bImplicitGemmKernel::B2bMma::IteratorA0::can_implement(args.problem_size_0); -+ if (Status::kSuccess != status) { -+ return status; -+ } -+ -+ status = B2bImplicitGemmKernel::B2bMma::IteratorB0::can_implement(args.problem_size_0); -+ if (Status::kSuccess != status) { -+ return status; -+ } -+ -+ status = B2bImplicitGemmKernel::B2bMma::IteratorB1::can_implement(args.problem_size_1); -+ if (Status::kSuccess != status) { -+ return status; -+ } -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape( -+ threadblock_swizzle.get_tiled_shape( -+ cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_0), -+ {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK}, -+ args.problem_size_0.split_k_slices)); -+ -+ if (!(grid.y <= std::numeric_limits::max() && -+ grid.z <= std::numeric_limits::max())) { -+ -+ return Status::kErrorInvalidProblem; -+ } -+ -+ // Determine if fusion sizes are valid -+ -+ cutlass::gemm::GemmCoord problem_size_0 = implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_0); -+ cutlass::gemm::GemmCoord problem_size_1 = implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_1); -+ -+ if(problem_size_0.m() != problem_size_1.m()) -+ return Status::kErrorInvalidProblem; -+ -+ if(problem_size_0.n() != problem_size_1.k()) -+ return Status::kErrorInvalidProblem; -+ -+ if(args.problem_size_1.R != 1 || args.problem_size_1.S != 1) -+ return Status::kErrorInvalidProblem; -+ -+ if(problem_size_0.n() > ThreadblockShape0::kN) -+ return Status::kErrorInvalidProblem; -+ -+ if(problem_size_1.n() > ThreadblockShape1::kN) -+ return Status::kErrorInvalidProblem; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ size_t workspace_bytes = 0; -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape( -+ cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_0), -+ {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK}, -+ args.problem_size_0.split_k_slices); -+ -+ if(args.split_k_mode == SplitKMode::kParallel) { -+ -+ // Split-K parallel: CTAs in k-dimension write the partial results in a temporary workspace. -+ // The user needs to call a reduction operator to optain the final output tensor -+ workspace_bytes = -+ sizeof(ElementAccumulator) * -+ size_t(cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, args.problem_size_0)) * -+ size_t(grid_tiled_shape.k()); -+ } -+ -+ else if(args.split_k_mode == SplitKMode::kSerial && args.problem_size_0.split_k_slices > 1) { -+ -+ // Split-K serial: The user workspace is used to store semaphore and serialize writing the -+ // final reduced output to user's output tensor -+ workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n()); -+ } -+ -+ return workspace_bytes; -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ if (args.problem_size_0.split_k_slices > 1) { -+ -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ cudaError_t status = cudaMemsetAsync(workspace, 0, get_workspace_size(args), stream); -+ -+ if (status != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ // initialize the params structure from the arguments -+ params_ = typename B2bImplicitGemmKernel::Params( -+ args, -+ static_cast(workspace) -+ ); -+ -+ int smem_size = int(sizeof(typename B2bImplicitGemmKernel::SharedStorage)); -+ -+ if (smem_size >= (48 << 10)) { -+ cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ // update the params structure from the arguments -+ params_.ptr_A0 = args.ref_A0.data(); -+ params_.ptr_B0 = args.ref_B0.data(); -+ params_.ptr_C0 = args.ref_C0.data(); -+ params_.ptr_Scale0 = args.ref_Scale0.data(); -+ params_.ptr_Bias0 = args.ref_Bias0.data(); -+ params_.ptr_B1 = args.ref_B1.data(); -+ params_.ptr_C1 = args.ref_C1.data(); -+ params_.ptr_D1 = args.ref_D1.data(); -+ params_.output_op_0 = args.output_op_0; -+ params_.output_op_1 = args.output_op_1; -+ params_.semaphore = static_cast(workspace); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); -+ dim3 block(32 * kWarpCount, 1, 1); -+ -+ int smem_size = int(sizeof(typename B2bImplicitGemmKernel::SharedStorage)); -+ -+ cutlass::Kernel<<>>(params_); -+ -+ cudaError_t result = cudaGetLastError(); -+ -+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+} // namespace device -+} // namespace conv -+} // namespace cutlass -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm75_rf.cu b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm75_rf.cu -new file mode 100644 -index 0000000..6f12608 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm75_rf.cu -@@ -0,0 +1,234 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "device/b2b_implicit_gemm_convolution.h" -+#include "b2b_conv2d_run.h" -+#include "test_run.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+cutlass::conv::Conv2dProblemSize conv2d_f16_sm75_problem_size_0 ( -+ {32, 56, 56, 64}, // input size (NHWC) -+ {64, 3, 3, 64}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ {32, 56, 56, 64} // output size (NPQK) -+ ); -+cutlass::conv::Conv2dProblemSize conv2d_f16_sm75_problem_size_1 ( -+ {32, 56, 56, 64}, // input size (NHWC) -+ {128, 1, 1, 64}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ {32, 56, 56, 128} // output size (NPQK) -+ ); -+ -+bool run_nonfused_conv2d_fprop_optimized_f16_sm75() { -+ -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ ElementCompute beta0 = ElementCompute(1); //beta=1 for bias -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape0, -+ WarpShape0, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop0 = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ using Conv2dFpropKernel1 = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape1, -+ WarpShape1, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop1 = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ B2bNonFusedConv2dRun nonFusedConv2d; -+ -+ std::cout << "Running Non-fused back-to-back FP16 Optimized Convolution Fprops...\n"; -+ bool pass = nonFusedConv2d.run(conv2d_f16_sm75_problem_size_0, conv2d_f16_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial, -+ alpha0, beta0, alpha1, beta1); -+ -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+bool run_fused_conv2d_fprop_optimized_f16_sm75_rf_res() { -+ -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ //Fused kernel has built-in bias, setting beta=0 -+ ElementCompute beta0 = ElementCompute(0); -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //use beta for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using EpilogueOutputOp0 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ InstructionShape::kM * InstructionShape::kN / 32, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling -+ >; -+ -+ using EpilogueOutputOp1 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >; -+ -+ -+ const bool SmemAccumulator = false; -+ -+ using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ SmemAccumulator -+ >::Kernel; -+ -+ using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution; -+ -+ B2bFusedConv2dRun fusedConv2d; -+ -+ std::cout << "Running Fused back-to-back FP16 Optimized Convolution Fprops with RF Residency...\n"; -+ bool pass = fusedConv2d.run(conv2d_f16_sm75_problem_size_0, conv2d_f16_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial, -+ alpha0, beta0, alpha1, beta1); -+ -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+int main() { -+ -+ std::vectorfuncs = { -+ &run_nonfused_conv2d_fprop_optimized_f16_sm75, -+ &run_fused_conv2d_fprop_optimized_f16_sm75_rf_res -+ }; -+ -+ return testRun(75, funcs, "conv f16 RF residency"); -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm75_shmem.cu b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm75_shmem.cu -new file mode 100644 -index 0000000..86eb1f7 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm75_shmem.cu -@@ -0,0 +1,234 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "device/b2b_implicit_gemm_convolution.h" -+#include "b2b_conv2d_run.h" -+#include "test_run.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+cutlass::conv::Conv2dProblemSize conv2d_f16_sm75_problem_size_0 ( -+ {32, 56, 56, 64}, // input size (NHWC) -+ {64, 3, 3, 64}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ {32, 56, 56, 64} // output size (NPQK) -+ ); -+cutlass::conv::Conv2dProblemSize conv2d_f16_sm75_problem_size_1 ( -+ {32, 56, 56, 64}, // input size (NHWC) -+ {256, 1, 1, 64}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ {32, 56, 56, 256} // output size (NPQK) -+ ); -+ -+bool run_nonfused_conv2d_fprop_optimized_f16_sm75() { -+ -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ ElementCompute beta0 = ElementCompute(1); //beta=1 for bias -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape0, -+ WarpShape0, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop0 = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ using Conv2dFpropKernel1 = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape1, -+ WarpShape1, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop1 = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ B2bNonFusedConv2dRun nonFusedConv2d; -+ -+ std::cout << "Running Non-fused back-to-back FP16 Optimized Convolution Fprops...\n"; -+ bool pass = nonFusedConv2d.run(conv2d_f16_sm75_problem_size_0, conv2d_f16_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial, -+ alpha0, beta0, alpha1, beta1); -+ -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+bool run_fused_conv2d_fprop_optimized_f16_sm75_shmem() { -+ -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ //Fused kernel has built-in bias, setting beta=0 -+ ElementCompute beta0 = ElementCompute(0); -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using EpilogueOutputOp0 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ InstructionShape::kM * InstructionShape::kN / 32, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling -+ >; -+ -+ using EpilogueOutputOp1 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >; -+ -+ -+ const bool SmemAccumulator = true; -+ -+ using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ SmemAccumulator -+ >::Kernel; -+ -+ using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution; -+ -+ B2bFusedConv2dRun fusedConv2d; -+ -+ std::cout << "Running Fused back-to-back FP16 Optimized Convolution Fprops with shared memory staging...\n"; -+ bool pass = fusedConv2d.run(conv2d_f16_sm75_problem_size_0, conv2d_f16_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial, -+ alpha0, beta0, alpha1, beta1); -+ -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+int main() { -+ -+ std::vectorfuncs = { -+ &run_nonfused_conv2d_fprop_optimized_f16_sm75, -+ &run_fused_conv2d_fprop_optimized_f16_sm75_shmem -+ }; -+ -+ return testRun(75, funcs, "conv f16 shmem staging"); -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm80_rf.cu b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm80_rf.cu -new file mode 100644 -index 0000000..14bef44 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm80_rf.cu -@@ -0,0 +1,233 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "device/b2b_implicit_gemm_convolution.h" -+#include "b2b_conv2d_run.h" -+#include "test_run.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+cutlass::conv::Conv2dProblemSize conv2d_f16_sm80_problem_size_0 ( -+ {32, 56, 56, 64}, // input size (NHWC) -+ {64, 3, 3, 64}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ {32, 56, 56, 64} // output size (NPQK) -+ ); -+cutlass::conv::Conv2dProblemSize conv2d_f16_sm80_problem_size_1 ( -+ {32, 56, 56, 64}, // input size (NHWC) -+ {128, 1, 1, 64}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ {32, 56, 56, 128} // output size (NPQK) -+ ); -+ -+ -+bool run_nonfused_conv2d_fprop_optimized_f16_sm80() { -+ -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ ElementCompute beta0 = ElementCompute(1); //beta=1 for bias -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape0, -+ WarpShape0, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop0 = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ using Conv2dFpropKernel1 = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape1, -+ WarpShape1, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop1 = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ B2bNonFusedConv2dRun nonFusedConv2d; -+ -+ std::cout << "Running Non-fused back-to-back FP16 Optimized Convolution Fprops...\n"; -+ bool pass = nonFusedConv2d.run(conv2d_f16_sm80_problem_size_0, conv2d_f16_sm80_problem_size_1, cutlass::conv::SplitKMode::kSerial, -+ alpha0, beta0, alpha1, beta1); -+ -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+bool run_fused_conv2d_fprop_optimized_f16_sm80_rf_res() { -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ //Fused kernel has built-in bias, setting beta=0 -+ ElementCompute beta0 = ElementCompute(0); -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ using EpilogueOutputOp0 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ InstructionShape::kM * InstructionShape::kN / 32, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling -+ >; -+ -+ using EpilogueOutputOp1 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >; -+ -+ using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution; -+ -+ B2bFusedConv2dRun fusedConv2d; -+ -+ std::cout << "Running Fused back-to-back FP16 Optimized Convolution Fprops with RF Residency...\n"; -+ bool pass = fusedConv2d.run(conv2d_f16_sm80_problem_size_0, conv2d_f16_sm80_problem_size_1, cutlass::conv::SplitKMode::kSerial, -+ alpha0, beta0, alpha1, beta1); -+ -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+ return true; -+} -+ -+int main() { -+ -+ std::vectorfuncs = { -+ &run_nonfused_conv2d_fprop_optimized_f16_sm80, -+ &run_fused_conv2d_fprop_optimized_f16_sm80_rf_res -+ }; -+ -+ return testRun(80, funcs, "conv f16 RF residency"); -+ -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm80_shmem.cu b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm80_shmem.cu -new file mode 100644 -index 0000000..c4df985 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm80_shmem.cu -@@ -0,0 +1,236 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "device/b2b_implicit_gemm_convolution.h" -+#include "b2b_conv2d_run.h" -+#include "test_run.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+cutlass::conv::Conv2dProblemSize conv2d_f16_sm80_problem_size_0 ( -+ {32, 56, 56, 64}, // input size (NHWC) -+ {64, 3, 3, 64}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ {32, 56, 56, 64} // output size (NPQK) -+ ); -+cutlass::conv::Conv2dProblemSize conv2d_f16_sm80_problem_size_1 ( -+ {32, 56, 56, 64}, // input size (NHWC) -+ {256, 1, 1, 64}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ {32, 56, 56, 256} // output size (NPQK) -+ ); -+ -+ -+bool run_nonfused_conv2d_fprop_optimized_f16_sm80() { -+ -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ ElementCompute beta0 = ElementCompute(1); //beta=1 for bias -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape0, -+ WarpShape0, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop0 = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ using Conv2dFpropKernel1 = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape1, -+ WarpShape1, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop1 = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ B2bNonFusedConv2dRun nonFusedConv2d; -+ -+ std::cout << "Running Non-fused back-to-back FP16 Optimized Convolution Fprops...\n"; -+ bool pass = nonFusedConv2d.run(conv2d_f16_sm80_problem_size_0, conv2d_f16_sm80_problem_size_1, cutlass::conv::SplitKMode::kSerial, -+ alpha0, beta0, alpha1, beta1); -+ -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+bool run_fused_conv2d_fprop_optimized_f16_sm80_shmem() { -+ -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ //Fused kernel has built-in bias, setting beta=0 -+ ElementCompute beta0 = ElementCompute(0); -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ using EpilogueOutputOp0 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ InstructionShape::kM * InstructionShape::kN / 32, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling -+ >; -+ -+ using EpilogueOutputOp1 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >; -+ -+ const bool SmemAccumulator = true; -+ using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ SmemAccumulator -+ >::Kernel; -+ -+ using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution; -+ -+ B2bFusedConv2dRun fusedConv2d; -+ -+ std::cout << "Running Fused back-to-back FP16 Optimized Convolution Fprops with shared memory staging...\n"; -+ bool pass = fusedConv2d.run(conv2d_f16_sm80_problem_size_0, conv2d_f16_sm80_problem_size_1, cutlass::conv::SplitKMode::kSerial, -+ alpha0, beta0, alpha1, beta1); -+ -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+ -+int main() { -+ -+ std::vectorfuncs = { -+ &run_nonfused_conv2d_fprop_optimized_f16_sm80, -+ &run_fused_conv2d_fprop_optimized_f16_sm80_shmem -+ }; -+ -+ return testRun(80, funcs, "conv f16 shmem staging"); -+ -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_rf.cu b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_rf.cu -new file mode 100644 -index 0000000..64955f8 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_rf.cu -@@ -0,0 +1,238 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "device/b2b_implicit_gemm_convolution.h" -+#include "b2b_interleaved_conv2d_run.h" -+#include "test_run.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+cutlass::conv::Conv2dProblemSize conv2d_s8_sm75_problem_size_0 ( -+ {32, 56, 56, 64}, // input size (NHWC) -+ {64, 3, 3, 64}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ {32, 56, 56, 64} // output size (NPQK) -+ ); -+cutlass::conv::Conv2dProblemSize conv2d_s8_sm75_problem_size_1 ( -+ {32, 56, 56, 64}, // input size (NHWC) -+ {128, 1, 1, 64}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ {32, 56, 56, 128} // output size (NPQK) -+ ); -+ -+bool run_nonfused_conv2d_fprop_optimized_s8_sm75() { -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ ElementCompute beta0 = ElementCompute(1); //beta=1 for bias -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape0, -+ WarpShape0, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop0 = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ using Conv2dFpropKernel1 = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape1, -+ WarpShape1, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop1 = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ B2bInterleavedNonFusedConv2dRun nonFusedConv2d; -+ -+ std::cout << "Running Non-fused back-to-back INT8 interleaved Optimized Convolution Fprops...\n"; -+ bool pass = nonFusedConv2d.run(conv2d_s8_sm75_problem_size_0, conv2d_s8_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial, -+ alpha0, beta0, alpha1, beta1); -+ -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+ -+bool run_fused_conv2d_fprop_optimized_s8_sm75_rf_res() { -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ //Fused kernel has built-in bias, setting beta=0 -+ ElementCompute beta0 = ElementCompute(0); -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ using EpilogueOutputOp0 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ InstructionShape::kM * InstructionShape::kN / 32, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling -+ >; -+ -+ using EpilogueOutputOp1 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >; -+ -+ -+ const bool SmemAccumulator = false; -+ -+ using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ SmemAccumulator -+ >::Kernel; -+ -+ using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution; -+ -+ B2bInterleavedFusedConv2dRun fusedConv2d; -+ -+ std::cout << "Running Fused back-to-back INT8 interleaved Optimized Convolution Fprops with RF residency...\n"; -+ bool pass = fusedConv2d.run(conv2d_s8_sm75_problem_size_0, conv2d_s8_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial, -+ alpha0, beta0, alpha1, beta1); -+ -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+int main() { -+ -+ std::vectorfuncs = { -+ &run_nonfused_conv2d_fprop_optimized_s8_sm75, -+ &run_fused_conv2d_fprop_optimized_s8_sm75_rf_res -+ }; -+ -+ return testRun(75, funcs, "conv int8 RF residency"); -+ -+} -+ -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_shmem.cu b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_shmem.cu -new file mode 100644 -index 0000000..7f82518 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_shmem.cu -@@ -0,0 +1,238 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "device/b2b_implicit_gemm_convolution.h" -+#include "b2b_interleaved_conv2d_run.h" -+#include "test_run.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+cutlass::conv::Conv2dProblemSize conv2d_s8_sm75_problem_size_0 ( -+ {32, 56, 56, 64}, // input size (NHWC) -+ {64, 3, 3, 64}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ {32, 56, 56, 64} // output size (NPQK) -+ ); -+cutlass::conv::Conv2dProblemSize conv2d_s8_sm75_problem_size_1 ( -+ {32, 56, 56, 64}, // input size (NHWC) -+ {256, 1, 1, 64}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ {32, 56, 56, 256} // output size (NPQK) -+ ); -+ -+bool run_nonfused_conv2d_fprop_optimized_s8_sm75() { -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ ElementCompute beta0 = ElementCompute(1); //beta=1 for bias -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape0, -+ WarpShape0, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop0 = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ using Conv2dFpropKernel1 = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape1, -+ WarpShape1, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop1 = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ B2bInterleavedNonFusedConv2dRun nonFusedConv2d; -+ -+ std::cout << "Running Non-fused back-to-back INT8 interleaved Optimized Convolution Fprops...\n"; -+ bool pass = nonFusedConv2d.run(conv2d_s8_sm75_problem_size_0, conv2d_s8_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial, -+ alpha0, beta0, alpha1, beta1); -+ -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+bool run_fused_conv2d_fprop_optimized_s8_sm75_shmem() { -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ //Fused kernel has built-in bias, setting beta=0 -+ ElementCompute beta0 = ElementCompute(0); -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ using EpilogueOutputOp0 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ InstructionShape::kM * InstructionShape::kN / 32, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling -+ >; -+ -+ using EpilogueOutputOp1 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >; -+ -+ -+ const bool SmemAccumulator = true; -+ -+ using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ SmemAccumulator -+ >::Kernel; -+ -+ using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution; -+ -+ B2bInterleavedFusedConv2dRun fusedConv2d; -+ -+ std::cout << "Running Fused back-to-back INT8 interleaved Optimized Convolution Fprops with shared memory staging...\n"; -+ bool pass = fusedConv2d.run(conv2d_s8_sm75_problem_size_0, conv2d_s8_sm75_problem_size_1, cutlass::conv::SplitKMode::kSerial, -+ alpha0, beta0, alpha1, beta1); -+ -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+ -+int main() { -+ -+ std::vectorfuncs = { -+ &run_nonfused_conv2d_fprop_optimized_s8_sm75, -+ &run_fused_conv2d_fprop_optimized_s8_sm75_shmem -+ }; -+ -+ return testRun(75, funcs, "conv int8 shmem staging"); -+ -+} -+ -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm80_rf.cu b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm80_rf.cu -new file mode 100644 -index 0000000..c4e0e4c ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm80_rf.cu -@@ -0,0 +1,236 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "device/b2b_implicit_gemm_convolution.h" -+#include "b2b_interleaved_conv2d_run.h" -+#include "test_run.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+cutlass::conv::Conv2dProblemSize conv2d_s8_sm80_problem_size_0 ( -+ {32, 56, 56, 64}, // input size (NHWC) -+ {64, 3, 3, 64}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ {32, 56, 56, 64} // output size (NPQK) -+ ); -+cutlass::conv::Conv2dProblemSize conv2d_s8_sm80_problem_size_1 ( -+ {32, 56, 56, 64}, // input size (NHWC) -+ {128, 1, 1, 64}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ {32, 56, 56, 128} // output size (NPQK) -+ ); -+ -+bool run_nonfused_conv2d_fprop_optimized_s8_sm80() { -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ ElementCompute beta0 = ElementCompute(1); //beta=1 for bias -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>; -+ using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 64>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape0, -+ WarpShape0, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop0 = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ using Conv2dFpropKernel1 = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape1, -+ WarpShape1, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop1 = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ B2bInterleavedNonFusedConv2dRun nonFusedConv2d; -+ -+ std::cout << "Running Non-fused back-to-back INT8 interleaved Optimized Convolution Fprops...\n"; -+ bool pass = nonFusedConv2d.run(conv2d_s8_sm80_problem_size_0, conv2d_s8_sm80_problem_size_1, cutlass::conv::SplitKMode::kSerial, -+ alpha0, beta0, alpha1, beta1); -+ -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+ -+bool run_fused_conv2d_fprop_optimized_s8_sm80_rf_res() { -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ //Fused kernel has built-in bias, setting beta=0 -+ ElementCompute beta0 = ElementCompute(0); -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 64>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>; -+ using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ using EpilogueOutputOp0 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 8 * InstructionShape::kN / 32, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling -+ >; -+ -+ using EpilogueOutputOp1 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >; -+ -+ -+ -+ using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution; -+ -+ B2bInterleavedFusedConv2dRun fusedConv2d; -+ -+ std::cout << "Running Fused back-to-back INT8 interleaved Optimized Convolution Fprops with RF residency...\n"; -+ bool pass = fusedConv2d.run(conv2d_s8_sm80_problem_size_0, conv2d_s8_sm80_problem_size_1, cutlass::conv::SplitKMode::kSerial, -+ alpha0, beta0, alpha1, beta1); -+ -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+int main() { -+ -+ std::vectorfuncs = { -+ &run_nonfused_conv2d_fprop_optimized_s8_sm80, -+ &run_fused_conv2d_fprop_optimized_s8_sm80_rf_res -+ }; -+ -+ return testRun(80, funcs, "conv int8 RF residency"); -+ -+ -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm80_shmem.cu b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm80_shmem.cu -new file mode 100644 -index 0000000..de15106 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm80_shmem.cu -@@ -0,0 +1,237 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "device/b2b_implicit_gemm_convolution.h" -+#include "b2b_interleaved_conv2d_run.h" -+#include "test_run.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+cutlass::conv::Conv2dProblemSize conv2d_s8_sm80_problem_size_0 ( -+ {32, 56, 56, 64}, // input size (NHWC) -+ {64, 3, 3, 64}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ {32, 56, 56, 64} // output size (NPQK) -+ ); -+cutlass::conv::Conv2dProblemSize conv2d_s8_sm80_problem_size_1 ( -+ {32, 56, 56, 64}, // input size (NHWC) -+ {256, 1, 1, 64}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ {32, 56, 56, 256} // output size (NPQK) -+ ); -+ -+bool run_nonfused_conv2d_fprop_optimized_s8_sm80() { -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ ElementCompute beta0 = ElementCompute(1); //beta=1 for bias -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>; -+ using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 64>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape0, -+ WarpShape0, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop0 = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ using Conv2dFpropKernel1 = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape1, -+ WarpShape1, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop1 = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ B2bInterleavedNonFusedConv2dRun nonFusedConv2d; -+ -+ std::cout << "Running Non-fused back-to-back INT8 interleaved Optimized Convolution Fprops...\n"; -+ bool pass = nonFusedConv2d.run(conv2d_s8_sm80_problem_size_0, conv2d_s8_sm80_problem_size_1, cutlass::conv::SplitKMode::kSerial, -+ alpha0, beta0, alpha1, beta1); -+ -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+bool run_fused_conv2d_fprop_optimized_s8_sm80_shmem() { -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ //Fused kernel has built-in bias, setting beta=0 -+ ElementCompute beta0 = ElementCompute(0); -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 64>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ using EpilogueOutputOp0 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 8 * InstructionShape::kN / 32, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling -+ >; -+ -+ using EpilogueOutputOp1 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >; -+ -+ const bool SmemAccumulator = true; -+ -+ using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ SmemAccumulator -+ >::Kernel; -+ -+ using B2bConv2dFprop = cutlass::conv::device::B2bImplicitGemmConvolution; -+ -+ B2bInterleavedFusedConv2dRun fusedConv2d; -+ -+ std::cout << "Running Fused back-to-back INT8 interleaved Optimized Convolution Fprops with shared memory staging...\n"; -+ bool pass = fusedConv2d.run(conv2d_s8_sm80_problem_size_0, conv2d_s8_sm80_problem_size_1, cutlass::conv::SplitKMode::kSerial, -+ alpha0, beta0, alpha1, beta1); -+ -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+int main() { -+ -+ std::vectorfuncs = { -+ &run_nonfused_conv2d_fprop_optimized_s8_sm80, -+ &run_fused_conv2d_fprop_optimized_s8_sm80_shmem -+ }; -+ -+ return testRun(80, funcs, "conv int8 shmem staging"); -+ -+ -+} -+ -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm75_rf.cu b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm75_rf.cu -new file mode 100644 -index 0000000..3a02096 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm75_rf.cu -@@ -0,0 +1,210 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "device/b2b_gemm.h" -+#include "b2b_gemm_run.h" -+#include "test_run.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+cutlass::gemm::GemmCoord gemm_f16_sm75_problem_size_0(128*640, 64, 576); -+cutlass::gemm::GemmCoord gemm_f16_sm75_problem_size_1(128*640, 128, 64); -+ -+bool run_nonfused_gemm_f16() { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ ElementCompute beta0 = ElementCompute(1); //beta = 1 for bias -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta = 1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using Gemm0 = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape0, -+ WarpShape0, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2 -+ >; -+ using Gemm1 = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape1, -+ WarpShape1, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2 -+ >; -+ -+ B2bNonFusedGemmRun nonFusedGemm; -+ -+ std::cout << "Running Non-fused back-to-back FP16 TN GEMMs...\n"; -+ bool pass = nonFusedGemm.run(gemm_f16_sm75_problem_size_0, gemm_f16_sm75_problem_size_1, alpha0, beta0, alpha1, beta1); -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+ -+bool run_fused_gemm_f16_rf_res() { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ //Fused kernel has built-in bias, setting beta=0 -+ ElementCompute beta0 = ElementCompute(0); -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using EpilogueOutputOp0 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ InstructionShape::kM * InstructionShape::kN / 32, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling -+ >; -+ -+ using EpilogueOutputOp1 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >; -+ -+ using B2bGemm = cutlass::gemm::device::B2bGemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2 -+ >; -+ -+ B2bFusedGemmRun fusedGemm; -+ -+ std::cout << "Running Fused back-to-back FP16 TN GEMMs with RF Residency...\n"; -+ bool passed = fusedGemm.run(gemm_f16_sm75_problem_size_0, gemm_f16_sm75_problem_size_1, alpha0, beta0, alpha1, beta1); -+ if(passed) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return passed; -+} -+ -+int main() { -+ -+ std::vectorfuncs = { -+ &run_nonfused_gemm_f16, -+ &run_fused_gemm_f16_rf_res -+ }; -+ -+ return testRun(75, funcs, "gemm f16 RF residency"); -+ -+ -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm75_shmem.cu b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm75_shmem.cu -new file mode 100644 -index 0000000..3498b40 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm75_shmem.cu -@@ -0,0 +1,214 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "device/b2b_gemm.h" -+#include "b2b_gemm_run.h" -+#include "test_run.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+cutlass::gemm::GemmCoord gemm_f16_sm75_problem_size_0(128*640, 64, 576); -+cutlass::gemm::GemmCoord gemm_f16_sm75_problem_size_1(128*640, 256, 64); -+ -+bool run_nonfused_gemm_f16() { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ ElementCompute beta0 = ElementCompute(1); //beta = 1 for bias -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta = 1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using Gemm0 = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape0, -+ WarpShape0, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2 -+ >; -+ using Gemm1 = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape1, -+ WarpShape1, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2 -+ >; -+ -+ B2bNonFusedGemmRun nonFusedGemm; -+ -+ std::cout << "Running Non-fused back-to-back FP16 TN GEMMs...\n"; -+ bool pass = nonFusedGemm.run(gemm_f16_sm75_problem_size_0, gemm_f16_sm75_problem_size_1, alpha0, beta0, alpha1, beta1); -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+bool run_fused_gemm_f16_shmem() { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ //Fused kernel has built-in bias, setting beta=0 -+ ElementCompute beta0 = ElementCompute(0); -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using EpilogueOutputOp0 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ InstructionShape::kM * InstructionShape::kN / 32, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling -+ >; -+ -+ using EpilogueOutputOp1 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >; -+ -+ -+ const bool SmemAccumulator = true; -+ -+ using B2bGemm = cutlass::gemm::device::B2bGemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ SmemAccumulator -+ >; -+ -+ B2bFusedGemmRun fusedGemm; -+ -+ std::cout << "Running Fused back-to-back FP16 TN GEMMs with shared memory staging...\n"; -+ bool passed = fusedGemm.run(gemm_f16_sm75_problem_size_0, gemm_f16_sm75_problem_size_1, alpha0, beta0, alpha1, beta1); -+ if(passed) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return passed; -+} -+ -+int main() { -+ -+ std::vectorfuncs = { -+ &run_nonfused_gemm_f16, -+ &run_fused_gemm_f16_shmem -+ }; -+ -+ return testRun(75, funcs, "gemm f16 shmem staging"); -+ -+ -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm80_rf.cu b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm80_rf.cu -new file mode 100644 -index 0000000..feb22fa ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm80_rf.cu -@@ -0,0 +1,213 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "device/b2b_gemm.h" -+#include "b2b_gemm_run.h" -+#include "test_run.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+cutlass::gemm::GemmCoord gemm_f16_sm80_problem_size_0(128*640, 64, 576); -+cutlass::gemm::GemmCoord gemm_f16_sm80_problem_size_1(128*640, 128, 64); -+ -+bool run_nonfused_gemm_f16_sm80() { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ ElementCompute beta0 = ElementCompute(1); //beta=1 for bias -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ using Gemm0 = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape0, -+ WarpShape0, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3 -+ >; -+ using Gemm1 = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape1, -+ WarpShape1, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3 -+ >; -+ -+ B2bNonFusedGemmRun nonFusedGemm; -+ -+ std::cout << "Running Non-fused back-to-back FP16 TN GEMMs...\n"; -+ bool pass = nonFusedGemm.run(gemm_f16_sm80_problem_size_0, gemm_f16_sm80_problem_size_1, alpha0, beta0, alpha1, beta1); -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+bool run_fused_gemm_f16_sm80_rf_res() { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ //Fused kernel has built-in bias, setting beta=0 -+ ElementCompute beta0 = ElementCompute(0); -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ using EpilogueOutputOp0 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ InstructionShape::kM * InstructionShape::kN / 32, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling -+ >; -+ -+ using EpilogueOutputOp1 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >; -+ -+ using B2bGemm = cutlass::gemm::device::B2bGemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3 -+ >; -+ -+ B2bFusedGemmRun fusedGemm; -+ -+ std::cout << "Running Fused back-to-back FP16 TN GEMMs with RF residency...\n"; -+ bool passed = fusedGemm.run(gemm_f16_sm80_problem_size_0, gemm_f16_sm80_problem_size_1, alpha0, beta0, alpha1, beta1); -+ if(passed) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return passed; -+ -+} -+ -+int main() { -+ -+ std::vectorfuncs = { -+ &run_nonfused_gemm_f16_sm80, -+ &run_fused_gemm_f16_sm80_rf_res -+ }; -+ -+ return testRun(80, funcs, "gemm f16 RF residency"); -+ -+ -+} -+ -+ -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm80_shmem.cu b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm80_shmem.cu -new file mode 100644 -index 0000000..36c4819 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm80_shmem.cu -@@ -0,0 +1,217 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "device/b2b_gemm.h" -+#include "b2b_gemm_run.h" -+#include "test_run.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+cutlass::gemm::GemmCoord gemm_f16_sm80_problem_size_0(128*640, 64, 576); -+cutlass::gemm::GemmCoord gemm_f16_sm80_problem_size_1(128*640, 256, 64); -+ -+bool run_nonfused_gemm_f16_sm80() { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ ElementCompute beta0 = ElementCompute(1); //beta=1 for bias -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ using Gemm0 = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape0, -+ WarpShape0, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3 -+ >; -+ using Gemm1 = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape1, -+ WarpShape1, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3 -+ >; -+ -+ B2bNonFusedGemmRun nonFusedGemm; -+ -+ std::cout << "Running Non-fused back-to-back FP16 TN GEMMs...\n"; -+ bool pass = nonFusedGemm.run(gemm_f16_sm80_problem_size_0, gemm_f16_sm80_problem_size_1, alpha0, beta0, alpha1, beta1); -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+bool run_fused_gemm_f16_sm80_shmem() { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ //Fused kernel has built-in bias, setting beta=0 -+ ElementCompute beta0 = ElementCompute(0); -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ using EpilogueOutputOp0 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ InstructionShape::kM * InstructionShape::kN / 32, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling -+ >; -+ -+ using EpilogueOutputOp1 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >; -+ -+ -+ const bool SmemAccumulator = true; -+ -+ using B2bGemm = cutlass::gemm::device::B2bGemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ SmemAccumulator -+ >; -+ -+ B2bFusedGemmRun fusedGemm; -+ -+ std::cout << "Running Fused back-to-back FP16 TN GEMMs with shared memory staging...\n"; -+ bool passed = fusedGemm.run(gemm_f16_sm80_problem_size_0, gemm_f16_sm80_problem_size_1, alpha0, beta0, alpha1, beta1); -+ if(passed) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return passed; -+ -+} -+ -+ -+int main() { -+ -+ std::vectorfuncs = { -+ &run_nonfused_gemm_f16_sm80, -+ &run_fused_gemm_f16_sm80_shmem -+ }; -+ -+ return testRun(80, funcs, "gemm f16 shmem staging"); -+ -+ -+} -+ -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm75_rf.cu b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm75_rf.cu -new file mode 100644 -index 0000000..565cca7 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm75_rf.cu -@@ -0,0 +1,212 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "device/b2b_gemm.h" -+#include "b2b_interleaved_gemm_run.h" -+#include "test_run.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+cutlass::gemm::GemmCoord gemm_s8_sm75_problem_size_0(128*640, 64, 576); -+cutlass::gemm::GemmCoord gemm_s8_sm75_problem_size_1(128*640, 128, 64); -+ -+bool run_nonfused_gemm_s8() { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ ElementCompute beta0 = ElementCompute(1); //beta = 1 for bias -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta = 1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>; -+ using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ using Gemm0 = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape0, -+ WarpShape0, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2 -+ >; -+ using Gemm1 = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape1, -+ WarpShape1, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2 -+ >; -+ -+ B2bInterleavedNonFusedGemmRun nonFusedGemm; -+ -+ std::cout << "Running Non-fused back-to-back INT8 NT interleaved GEMMs...\n"; -+ bool pass = nonFusedGemm.run(gemm_s8_sm75_problem_size_0, gemm_s8_sm75_problem_size_1, alpha0, beta0, alpha1, beta1); -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+ -+bool run_fused_gemm_s8_rf_res() { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ //Fused kernel has built-in bias, setting beta=0 -+ ElementCompute beta0 = ElementCompute(0); -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ using EpilogueOutputOp0 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ InstructionShape::kM * InstructionShape::kN / 32, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling -+ >; -+ -+ using EpilogueOutputOp1 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >; -+ -+ using B2bGemm = cutlass::gemm::device::B2bGemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2 -+ >; -+ -+ B2bInterleavedFusedGemmRun fusedGemm; -+ -+ std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs with RF Residency...\n"; -+ bool passed = fusedGemm.run(gemm_s8_sm75_problem_size_0, gemm_s8_sm75_problem_size_1, alpha0, beta0, alpha1, beta1); -+ if(passed) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return passed; -+ -+} -+ -+int main() { -+ -+ std::vectorfuncs = { -+ &run_nonfused_gemm_s8, -+ &run_fused_gemm_s8_rf_res -+ }; -+ -+ return testRun(75, funcs, "gemm int8 RF residency"); -+ -+ -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm75_shmem.cu b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm75_shmem.cu -new file mode 100644 -index 0000000..8719d74 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm75_shmem.cu -@@ -0,0 +1,214 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "device/b2b_gemm.h" -+#include "b2b_interleaved_gemm_run.h" -+#include "test_run.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+cutlass::gemm::GemmCoord gemm_s8_sm75_problem_size_0(128*640, 64, 576); -+cutlass::gemm::GemmCoord gemm_s8_sm75_problem_size_1(128*640, 256, 64); -+ -+bool run_nonfused_gemm_s8() { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ ElementCompute beta0 = ElementCompute(1); //beta = 1 for bias -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta = 1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ using Gemm0 = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape0, -+ WarpShape0, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2 -+ >; -+ using Gemm1 = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape1, -+ WarpShape1, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2 -+ >; -+ -+ B2bInterleavedNonFusedGemmRun nonFusedGemm; -+ -+ std::cout << "Running Non-fused back-to-back INT8 NT interleaved GEMMs...\n"; -+ bool pass = nonFusedGemm.run(gemm_s8_sm75_problem_size_0, gemm_s8_sm75_problem_size_1, alpha0, beta0, alpha1, beta1); -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+bool run_fused_gemm_s8_shmem() { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ //Fused kernel has built-in bias, setting beta=0 -+ ElementCompute beta0 = ElementCompute(0); -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ using EpilogueOutputOp0 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ InstructionShape::kM * InstructionShape::kN / 32, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling -+ >; -+ -+ using EpilogueOutputOp1 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >; -+ -+ const bool SmemAccumulator = true; -+ -+ using B2bGemm = cutlass::gemm::device::B2bGemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ SmemAccumulator -+ >; -+ -+ B2bInterleavedFusedGemmRun fusedGemm; -+ -+ std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs with shared memory staging...\n"; -+ bool passed = fusedGemm.run(gemm_s8_sm75_problem_size_0, gemm_s8_sm75_problem_size_1, alpha0, beta0, alpha1, beta1); -+ if(passed) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return passed; -+ -+} -+ -+int main() { -+ -+ std::vectorfuncs = { -+ &run_nonfused_gemm_s8, -+ &run_fused_gemm_s8_shmem -+ }; -+ -+ return testRun(75, funcs, "gemm int8 shmem staing"); -+ -+ -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_rf.cu b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_rf.cu -new file mode 100644 -index 0000000..60f9adb ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_rf.cu -@@ -0,0 +1,227 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "device/b2b_gemm.h" -+#include "b2b_interleaved_gemm_run.h" -+#include "test_run.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_0(128*640, 64, 576); -+cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_1(128*640, 128, 64); -+ -+bool run_nonfused_gemm_s8_sm80() { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ ElementCompute beta0 = ElementCompute(1); //beta=1 for bias -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>; -+ using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 64>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ using Gemm0 = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape0, -+ WarpShape0, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 16, -+ 16, -+ false, -+ cutlass::arch::OpMultiplyAddSaturate -+ >; -+ using Gemm1 = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape1, -+ WarpShape1, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 16, -+ 16, -+ false, -+ cutlass::arch::OpMultiplyAddSaturate -+ >; -+ -+ B2bInterleavedNonFusedGemmRun nonFusedGemm; -+ -+ std::cout << "Running Non-fused back-to-back INT8 NT interleaved GEMMs...\n"; -+ bool pass = nonFusedGemm.run(gemm_s8_sm80_problem_size_0, gemm_s8_sm80_problem_size_1, alpha0, beta0, alpha1, beta1); -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+ -+bool run_fused_gemm_s8_sm80_rf_res() { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ //Fused kernel has built-in bias, setting beta=0 -+ ElementCompute beta0 = ElementCompute(0); -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 64>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>; -+ using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ using EpilogueOutputOp0 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 8 * InstructionShape::kN / 32, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling -+ >; -+ -+ using EpilogueOutputOp1 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >; -+ -+ const bool SmemAccumulator = false; -+ -+ using B2bGemm = cutlass::gemm::device::B2bGemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ SmemAccumulator, -+ 16, -+ 16, -+ false, -+ cutlass::arch::OpMultiplyAddSaturate -+ >; -+ -+ B2bInterleavedFusedGemmRun fusedGemm; -+ -+ std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs with RF residency...\n"; -+ bool passed = fusedGemm.run(gemm_s8_sm80_problem_size_0, gemm_s8_sm80_problem_size_1, alpha0, beta0, alpha1, beta1); -+ if(passed) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return passed; -+} -+ -+ -+int main() { -+ -+ std::vectorfuncs = { -+ &run_nonfused_gemm_s8_sm80, -+ &run_fused_gemm_s8_sm80_rf_res -+ }; -+ -+ return testRun(80, funcs, "gemm int8 RF residency"); -+ -+ -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_shmem.cu b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_shmem.cu -new file mode 100644 -index 0000000..64788e0 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_shmem.cu -@@ -0,0 +1,226 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "device/b2b_gemm.h" -+#include "b2b_interleaved_gemm_run.h" -+#include "test_run.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_0(128*640, 64, 576); -+cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_1(128*640, 256, 64); -+ -+bool run_nonfused_gemm_s8_sm80() { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ ElementCompute beta0 = ElementCompute(1); //beta=1 for bias -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>; -+ using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 64>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ using Gemm0 = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape0, -+ WarpShape0, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 16, -+ 16, -+ false, -+ cutlass::arch::OpMultiplyAddSaturate -+ >; -+ using Gemm1 = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape1, -+ WarpShape1, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 16, -+ 16, -+ false, -+ cutlass::arch::OpMultiplyAddSaturate -+ >; -+ -+ B2bInterleavedNonFusedGemmRun nonFusedGemm; -+ -+ std::cout << "Running Non-fused back-to-back INT8 NT interleaved GEMMs...\n"; -+ bool pass = nonFusedGemm.run(gemm_s8_sm80_problem_size_0, gemm_s8_sm80_problem_size_1, alpha0, beta0, alpha1, beta1); -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+bool run_fused_gemm_s8_sm80_shmem() { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ ElementCompute alpha0 = ElementCompute(1); -+ //Fused kernel has built-in bias, setting beta=0 -+ ElementCompute beta0 = ElementCompute(0); -+ ElementCompute alpha1 = ElementCompute(1); -+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias -+ -+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>; -+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 64>; -+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ using EpilogueOutputOp0 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 8 * InstructionShape::kN / 32, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling -+ >; -+ -+ using EpilogueOutputOp1 = -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling -+ >; -+ -+ const bool SmemAccumulator = true; -+ -+ using B2bGemm = cutlass::gemm::device::B2bGemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ SmemAccumulator, -+ 16, -+ 16, -+ false, -+ cutlass::arch::OpMultiplyAddSaturate -+ >; -+ -+ B2bInterleavedFusedGemmRun fusedGemm; -+ -+ std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs with shared memory staging...\n"; -+ bool passed = fusedGemm.run(gemm_s8_sm80_problem_size_0, gemm_s8_sm80_problem_size_1, alpha0, beta0, alpha1, beta1); -+ if(passed) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return passed; -+} -+ -+ -+int main() { -+ -+ std::vectorfuncs = { -+ &run_nonfused_gemm_s8_sm80, -+ &run_fused_gemm_s8_sm80_shmem -+ }; -+ -+ return testRun(80, funcs, "gemm int8 shmem staging"); -+ -+ -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h -new file mode 100644 -index 0000000..1ccf902 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h -@@ -0,0 +1,460 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/semaphore.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled. -+> -+struct B2bGemm { -+ -+ using B2bMma = B2bMma_; -+ using Epilogue = Epilogue_; -+ using OutputOp0 = typename B2bMma::OutputOp; -+ using OutputOp1 = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static bool const kSplitKSerial = SplitKSerial; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount0 = typename B2bMma::WarpCount0; -+ static int const kThreadCount = 32 * WarpCount0::kCount; -+ -+ /// Parameters structure -+ struct Params { -+ cutlass::gemm::GemmCoord problem_size_0; -+ cutlass::gemm::GemmCoord problem_size_1; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ int swizzle_log_tile; -+ typename B2bMma::IteratorA0::Params params_A0; -+ typename B2bMma::IteratorA0::TensorRef ref_A0; -+ typename B2bMma::IteratorB0::Params params_B0; -+ typename B2bMma::IteratorB0::TensorRef ref_B0; -+ typename Epilogue::OutputTileIterator::Params params_C0; -+ typename Epilogue::OutputTileIterator::TensorRef ref_C0; -+ typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0; -+ typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0; -+ typename B2bMma::IteratorB1::Params params_B1; -+ typename B2bMma::IteratorB1::TensorRef ref_B1; -+ typename Epilogue::OutputTileIterator::Params params_C1; -+ typename Epilogue::OutputTileIterator::TensorRef ref_C1; -+ typename Epilogue::OutputTileIterator::Params params_D1; -+ typename Epilogue::OutputTileIterator::TensorRef ref_D1; -+ typename OutputOp0::Params output_op_0; -+ typename OutputOp1::Params output_op_1; -+ int *semaphore; -+ int gemm_k_iterations_0; -+ int gemm_k_size_0; -+ int gemm_k_iterations_1; -+ int gemm_k_size_1; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): swizzle_log_tile(0), semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0), -+ gemm_k_iterations_1(0), gemm_k_size_1(0) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ cutlass::gemm::GemmCoord const & problem_size_0, -+ cutlass::gemm::GemmCoord const & problem_size_1, -+ cutlass::gemm::GemmCoord const & grid_tiled_shape, -+ typename B2bMma::IteratorA0::TensorRef ref_A0, -+ typename B2bMma::IteratorB0::TensorRef ref_B0, -+ typename Epilogue::OutputTileIterator::TensorRef ref_C0, -+ typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0, -+ typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0, -+ typename B2bMma::IteratorB1::TensorRef ref_B1, -+ typename Epilogue::OutputTileIterator::TensorRef ref_C1, -+ typename Epilogue::OutputTileIterator::TensorRef ref_D1, -+ typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(), -+ typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(), -+ int *workspace = nullptr -+ ): -+ problem_size_0(problem_size_0), -+ problem_size_1(problem_size_1), -+ grid_tiled_shape(grid_tiled_shape), -+ swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), -+ params_A0(ref_A0.layout()), -+ ref_A0(ref_A0), -+ params_B0(ref_B0.layout()), -+ ref_B0(ref_B0), -+ params_C0(ref_C0.layout()), -+ ref_C0(ref_C0), -+ ref_Scale0(ref_Scale0), -+ ref_Bias0(ref_Bias0), -+ params_B1(ref_B1.layout()), -+ ref_B1(ref_B1), -+ params_C1(ref_C1.layout()), -+ ref_C1(ref_C1), -+ params_D1(ref_D1.layout()), -+ ref_D1(ref_D1), -+ output_op_0(output_op_0), -+ output_op_1(output_op_1) { -+ -+ int total_gemm_k_iterations_0 = (problem_size_0.k() + B2bMma::Shape0::kK - 1) / B2bMma::Shape0::kK; -+ int gemm_k_iterations_0 = (total_gemm_k_iterations_0 + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); -+ gemm_k_size_0 = gemm_k_iterations_0 * B2bMma::Shape0::kK; -+ int total_gemm_k_iterations_1 = (problem_size_1.k() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK; -+ int gemm_k_iterations_1 = (total_gemm_k_iterations_1 + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); -+ gemm_k_size_1 = gemm_k_iterations_1 * B2bMma::Shape1::kK; -+ -+ semaphore = workspace; -+ } -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename B2bMma::B2bMmaSharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ B2bGemm() { } -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement( -+ cutlass::gemm::GemmCoord const & problem_size_0, -+ cutlass::gemm::GemmCoord const & problem_size_1, -+ typename B2bMma::IteratorA0::TensorRef ref_A0, -+ typename B2bMma::IteratorB0::TensorRef ref_B0, -+ typename Epilogue::OutputTileIterator::TensorRef ref_C0, -+ typename B2bMma::IteratorB1::TensorRef ref_B1, -+ typename Epilogue::OutputTileIterator::TensorRef ref_C1, -+ typename Epilogue::OutputTileIterator::TensorRef ref_D1) { -+ -+ static int const kAlignmentA = B2bMma::IteratorA0::AccessType::kElements; -+ static int const kAlignmentB = B2bMma::IteratorB0::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ if (!TensorRef_aligned(ref_A0, kAlignmentA)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_B0, kAlignmentB)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_C0, kAlignmentC)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_B1, kAlignmentB)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_C1, kAlignmentC)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_D1, kAlignmentC)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if ((problem_size_0.m() % kAlignmentA) || (problem_size_0.k() % kAlignmentA) || -+ (problem_size_0.n() % kAlignmentB) || (problem_size_0.k() % kAlignmentB) || -+ (problem_size_0.m() % kAlignmentC) || (problem_size_0.n() % kAlignmentC) || -+ (problem_size_1.m() % kAlignmentA) || (problem_size_1.k() % kAlignmentA) || -+ (problem_size_1.n() % kAlignmentB) || (problem_size_1.k() % kAlignmentB) || -+ (problem_size_1.m() % kAlignmentC) || (problem_size_1.n() % kAlignmentC)) { -+ -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ // Determine if fusion sizes are valid -+ if(problem_size_0.m() != problem_size_1.m()) -+ return Status::kErrorInvalidProblem; -+ -+ if(problem_size_0.n() != problem_size_1.k()) -+ return Status::kErrorInvalidProblem; -+ -+ if(problem_size_0.n() > B2bMma::Shape0::kN) -+ return Status::kErrorInvalidProblem; -+ -+ if(problem_size_1.n() > B2bMma::Shape1::kN) -+ return Status::kErrorInvalidProblem; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A0{ -+ threadblock_tile_offset.m() * B2bMma::Shape0::kM, -+ threadblock_tile_offset.k() * params.gemm_k_size_0, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B0{ -+ threadblock_tile_offset.k() * params.gemm_k_size_0, -+ threadblock_tile_offset.n() * B2bMma::Shape0::kN -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B1{ -+ threadblock_tile_offset.k() * params.gemm_k_size_1, -+ threadblock_tile_offset.n() * B2bMma::Shape1::kN -+ }; -+ -+ // Problem size is a function of threadblock index in the K dimension -+ int problem_size_k_0 = min( -+ params.problem_size_0.k(), -+ (threadblock_tile_offset.k() + 1) * params.gemm_k_size_0); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations_0 = (problem_size_k_0 - tb_offset_A0.column() + B2bMma::Shape0::kK - 1) / B2bMma::Shape0::kK; -+ -+ // Problem size is a function of threadblock index in the K dimension -+ int problem_size_k_1 = min( -+ params.problem_size_1.k(), -+ (threadblock_tile_offset.k() + 1) * params.gemm_k_size_1); -+ -+ // Compute threadblock-scoped matrix multiply-add -+// int gemm_k_iterations_1 = (problem_size_k_1 - tb_offset_B1.row() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK; -+ -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename B2bMma::IteratorA0 iterator_A0( -+ params.params_A0, -+ params.ref_A0.data(), -+ {params.problem_size_0.m(), problem_size_k_0}, -+ thread_idx, -+ tb_offset_A0); -+ -+ typename B2bMma::IteratorB0 iterator_B0( -+ params.params_B0, -+ params.ref_B0.data(), -+ {problem_size_k_0, params.problem_size_0.n()}, -+ thread_idx, -+ tb_offset_B0); -+ -+ typename B2bMma::IteratorB1 iterator_B1( -+ params.params_B1, -+ params.ref_B1.data(), -+ {problem_size_k_1, params.problem_size_1.n()}, -+ thread_idx, -+ tb_offset_B1); -+ -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0); -+ int lane_idx = threadIdx.x % 32; -+ -+ // Construct iterators to accumulator scale/bias vector -+ typename B2bMma::IteratorAccumulatorScaleBias iterator_Scale0( -+ params.ref_Scale0.data(), -+ {1, params.problem_size_0.n()}, -+ thread_idx, -+ warp_idx, -+ MatrixCoord( -+ 0, threadblock_tile_offset.n() * B2bMma::Shape0::kN -+ ) -+ ); -+ -+ typename B2bMma::IteratorAccumulatorScaleBias iterator_Bias0( -+ params.ref_Bias0.data(), -+ {1, params.problem_size_0.n()}, -+ thread_idx, -+ warp_idx, -+ MatrixCoord( -+ 0, threadblock_tile_offset.n() * B2bMma::Shape0::kN -+ ) -+ ); -+ -+ -+ -+ // -+ // Main loop -+ // -+ -+ OutputOp0 output_op_0(params.output_op_0); -+ -+ // Construct thread-scoped matrix multiply -+ B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx, params.problem_size_0.n()); -+ -+ typename B2bMma::FragmentC0 src_accum; -+ typename B2bMma::FragmentC1 accumulators; -+ -+ src_accum.clear(); -+ accumulators.clear(); -+ -+ if (!kSplitKSerial || gemm_k_iterations_0 > 0) { -+ // Compute threadblock-scoped matrix multiply-add -+ b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0, -+ iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0); -+ } -+ -+ // -+ // Epilogue -+ // -+ -+ OutputOp1 output_op_1(params.output_op_1); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * B2bMma::Shape1::kM, -+ threadblock_tile_offset.n() * B2bMma::Shape1::kN -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ // Construct the semaphore. -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. -+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C1( -+ params.params_C1, -+ params.ref_C1.data(), -+ params.problem_size_1.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D1( -+ params.params_D1, -+ params.ref_D1.data(), -+ params.problem_size_1.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_offset.k()) { -+ iterator_C1 = iterator_D1; -+ } -+ -+ semaphore.wait(threadblock_tile_offset.k()); -+ -+ __threadfence(); -+ } -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue(output_op_1, iterator_D1, accumulators, iterator_C1); -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_offset.k() + 1; -+ } -+ -+ __threadfence(); -+ semaphore.release(lock); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_implicit_gemm_convolution.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_implicit_gemm_convolution.h -new file mode 100644 -index 0000000..6c54087 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_implicit_gemm_convolution.h -@@ -0,0 +1,521 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined Implicit GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/semaphore.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+#include "cutlass/epilogue/threadblock/output_iterator_parameter.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) -+ typename ConvProblemSize_ = Conv2dProblemSize ///! Convolutional operator on 2D or 3D problem -+> -+struct B2bImplicitGemmConvolution { -+ -+ using B2bMma = B2bMma_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp0 = typename B2bMma::OutputOp; -+ using EpilogueOutputOp1 = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static Operator const kConvolutionalOperator = ConvOperator; -+ -+ using ElementA = typename B2bMma::IteratorA0::Element; -+ using LayoutA = typename B2bMma::IteratorA0::Layout; -+ using ElementB = typename B2bMma::IteratorB0::Element; -+ using LayoutB = typename B2bMma::IteratorB0::Layout; -+ using ElementC = typename EpilogueOutputOp1::ElementOutput; -+ -+ /// Set output tensor C layout -+ using LayoutC = LayoutA; -+ -+ using ElementAccumulator = typename EpilogueOutputOp0::ElementAccumulator; -+ using ElementCompute = typename EpilogueOutputOp0::ElementCompute; -+ -+ /// Scale and Bias -+ using ElementScaleBias = typename B2bMma::IteratorAccumulatorScaleBias::Element; -+ using LayoutScaleBias = typename B2bMma::IteratorAccumulatorScaleBias::Layout; -+ -+ using WarpMmaOperator0 = typename B2bMma::Policy0::Operator; -+ using WarpMmaOperator1 = typename B2bMma::Policy1::Operator; -+ -+ using ArchMmaOperator = typename WarpMmaOperator0::ArchMmaOperator; -+ using MathOperator = typename ArchMmaOperator::Operator; -+ -+ using OperatorClass = typename WarpMmaOperator0::OperatorClass; -+ using ArchTag = typename WarpMmaOperator0::ArchTag; -+ -+ using ThreadblockShape0 = typename B2bMma::Shape0; -+ using ThreadblockShape1 = typename B2bMma::Shape1; -+ using WarpShape0 = typename WarpMmaOperator0::Shape; -+ using WarpShape1 = typename WarpMmaOperator1::Shape; -+ using InstructionShape = typename ArchMmaOperator::Shape; -+ -+ static int const kStages = B2bMma::kStages; -+ static IteratorAlgorithm const kIteratorAlgorithm = B2bMma::IteratorA0::kIteratorAlgorithm; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount0 = typename B2bMma::WarpCount0; -+ static int const kThreadCount = 32 * WarpCount0::kCount; -+ -+ using TensorRefA0 = typename B2bMma::IteratorA0::TensorRef; -+ using TensorRefB0 = typename B2bMma::IteratorB0::TensorRef; -+ using TensorRefScaleBias0 = typename B2bMma::IteratorAccumulatorScaleBias::TensorRef; -+ using TensorRefB1 = typename B2bMma::IteratorB1::TensorRef; -+ using TensorRefC = cutlass::TensorRef; -+ -+ /// Check iterator A and B convolution dimension are the same and -+ // set device::B2bImplicitGemmConvolution::kConvDim -+ static_assert(B2bMma::IteratorA0::kConvDim == B2bMma::IteratorB0::kConvDim, -+ "Convolution on different dimensions is not supported"); -+ static int const kConvDim = B2bMma::IteratorA0::kConvDim; -+ -+ /// Conv dimension and problem size structure (Conv2d or Conv3d) -+ using ConvProblemSize = ConvProblemSize_; -+ -+ /// Wgrad C stride idx for implicit gemm algorithm -+ // Conv2d row-major matrix C (KxRSC) -+ // Conv3d row-major matrix C (KxTRSC) -+ static int const kWgradCStrideIdx = -+ cutlass::platform::is_same::value ? 2 : 3; -+ -+ /// This chooses the appropriate stride element of the C tensor. -+ static int const kTensorCStrideIdx = -+ (kConvolutionalOperator == conv::Operator::kWgrad ? kWgradCStrideIdx : 0); -+ -+ // -+ // -+ // -+ using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter< -+ LayoutC, -+ typename Epilogue::OutputTileIterator::Layout, -+ TensorRefC, -+ ConvOperator, -+ ConvProblemSize -+ >; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ ConvProblemSize problem_size_0; -+ ConvProblemSize problem_size_1; -+ TensorRefA0 ref_A0; -+ TensorRefB0 ref_B0; -+ TensorRefC ref_C0; -+ TensorRefScaleBias0 ref_Scale0; -+ TensorRefScaleBias0 ref_Bias0; -+ TensorRefB1 ref_B1; -+ TensorRefC ref_C1; -+ TensorRefC ref_D1; -+ typename EpilogueOutputOp0::Params output_op_0; -+ typename EpilogueOutputOp1::Params output_op_1; -+ SplitKMode split_k_mode; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments() { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ ConvProblemSize const & problem_size_0, -+ ConvProblemSize const & problem_size_1 -+ ): -+ problem_size_0(problem_size_0), -+ problem_size_1(problem_size_1) { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ ConvProblemSize const & problem_size_0, -+ ConvProblemSize const & problem_size_1, -+ TensorRefA0 const & ref_A0, -+ TensorRefB0 const & ref_B0, -+ TensorRefC const & ref_C0, -+ TensorRefScaleBias0 const & ref_Scale0, -+ TensorRefScaleBias0 const & ref_Bias0, -+ TensorRefB1 const & ref_B1, -+ TensorRefC const & ref_C1, -+ TensorRefC const & ref_D1, -+ typename EpilogueOutputOp0::Params const & output_op_0, -+ typename EpilogueOutputOp1::Params const & output_op_1, -+ SplitKMode const & split_k_mode = SplitKMode::kSerial -+ ): -+ problem_size_0(problem_size_0), -+ problem_size_1(problem_size_1), -+ ref_A0(ref_A0), -+ ref_B0(ref_B0), -+ ref_C0(ref_C0), -+ ref_Scale0(ref_Scale0), -+ ref_Bias0(ref_Bias0), -+ ref_B1(ref_B1), -+ ref_C1(ref_C1), -+ ref_D1(ref_D1), -+ output_op_0(output_op_0), -+ output_op_1(output_op_1), -+ split_k_mode(split_k_mode) -+ { -+ -+ } -+ -+ }; -+ -+ /// Parameters structure -+ struct Params { -+ ConvProblemSize problem_size_0; -+ ConvProblemSize problem_size_1; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ gemm::GemmCoord implicit_gemm_problem_size_0; -+ gemm::GemmCoord implicit_gemm_problem_size_1; -+ int swizzle_log_tile; -+ int gemm_k_iterations_0; -+ int gemm_k_iterations_1; -+ typename B2bMma::IteratorA0::Params iterator_A0; -+ typename B2bMma::IteratorA0::Element const *ptr_A0; -+ typename B2bMma::IteratorB0::Params iterator_B0; -+ typename B2bMma::IteratorB0::Element const *ptr_B0; -+ typename Epilogue::OutputTileIterator::Params iterator_C0; -+ typename Epilogue::OutputTileIterator::Element *ptr_C0; -+ typename B2bMma::IteratorAccumulatorScaleBias::Element *ptr_Scale0; -+ typename B2bMma::IteratorAccumulatorScaleBias::Element *ptr_Bias0; -+ typename B2bMma::IteratorB1::Params iterator_B1; -+ typename B2bMma::IteratorB1::Element const *ptr_B1; -+ typename Epilogue::OutputTileIterator::Params iterator_C1; -+ typename Epilogue::OutputTileIterator::Element *ptr_C1; -+ typename Epilogue::OutputTileIterator::Params iterator_D1; -+ typename Epilogue::OutputTileIterator::Element *ptr_D1; -+ typename EpilogueOutputOp0::Params output_op_0; -+ typename EpilogueOutputOp1::Params output_op_1; -+ int *semaphore; -+ SplitKMode split_k_mode; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): swizzle_log_tile(0), gemm_k_iterations_0(0), gemm_k_iterations_1(0) { } -+ -+ /// -+ CUTLASS_HOST_DEVICE -+ Params( -+ Arguments const &args, -+ int *semaphore = nullptr -+ ): -+ problem_size_0(args.problem_size_0), -+ problem_size_1(args.problem_size_1), -+ implicit_gemm_problem_size_0(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_0)), -+ implicit_gemm_problem_size_1(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_1)), -+ iterator_A0(B2bMma::IteratorA0::getParams(args.problem_size_0, args.ref_A0.layout())), -+ ptr_A0(args.ref_A0.data()), -+ iterator_B0(args.problem_size_0, args.ref_B0.layout()), -+ ptr_B0(args.ref_B0.data()), -+ iterator_C0(ConvOutputIteratorParameter::layout(args.ref_C0)), -+ ptr_C0(args.ref_C0.data()), -+ ptr_Scale0(args.ref_Scale0.data()), -+ ptr_Bias0(args.ref_Bias0.data()), -+ iterator_B1(args.problem_size_1, args.ref_B1.layout()), -+ ptr_B1(args.ref_B1.data()), -+ iterator_C1(ConvOutputIteratorParameter::layout(args.ref_C1)), -+ ptr_C1(args.ref_C1.data()), -+ iterator_D1(ConvOutputIteratorParameter::layout(args.ref_D1)), -+ ptr_D1(args.ref_D1.data()), -+ output_op_0(args.output_op_0), -+ output_op_1(args.output_op_1), -+ semaphore(semaphore), -+ split_k_mode(args.split_k_mode) -+ { -+ gemm_k_iterations_0 = implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape0::kK, args.problem_size_0); -+ gemm_k_iterations_1 = implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape1::kK, args.problem_size_1); -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ grid_tiled_shape = threadblock_swizzle.get_tiled_shape( -+ implicit_gemm_problem_size_0, -+ {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK}, -+ args.problem_size_0.split_k_slices); -+ -+ swizzle_log_tile = ThreadblockSwizzle().get_log_tile(grid_tiled_shape); -+ } -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename B2bMma::B2bMmaSharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ B2bImplicitGemmConvolution() { } -+ -+ /// Executes one ImplicitGEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_idx = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) { -+ -+ return; -+ } -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename B2bMma::IteratorA0 iterator_A0( -+ params.iterator_A0, -+ params.problem_size_0, -+ params.ptr_A0, -+ thread_idx, -+ MatrixCoord( -+ threadblock_tile_idx.m() * B2bMma::Shape0::kM, -+ threadblock_tile_idx.k() * B2bMma::Shape0::kK -+ ) -+ ); -+ -+ typename B2bMma::IteratorB0 iterator_B0( -+ params.iterator_B0, -+ params.problem_size_0, -+ params.ptr_B0, -+ thread_idx, -+ MatrixCoord( -+ threadblock_tile_idx.k() * B2bMma::Shape0::kK, -+ threadblock_tile_idx.n() * B2bMma::Shape0::kN -+ ) -+ ); -+ -+ typename B2bMma::IteratorB1 iterator_B1( -+ params.iterator_B1, -+ params.problem_size_1, -+ params.ptr_B1, -+ thread_idx, -+ MatrixCoord( -+ threadblock_tile_idx.k() * B2bMma::Shape1::kK, -+ threadblock_tile_idx.n() * B2bMma::Shape1::kN -+ ) -+ ); -+ -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); -+ int lane_idx = threadIdx.x % 32; -+ -+ // Construct iterators to accumulator scale/bias vector -+ typename B2bMma::IteratorAccumulatorScaleBias iterator_Scale0( -+ params.ptr_Scale0, -+ {1, params.problem_size_0.K}, -+ thread_idx, -+ warp_idx, -+ MatrixCoord( -+ 0, threadblock_tile_idx.n() * B2bMma::Shape0::kN -+ ) -+ ); -+ -+ typename B2bMma::IteratorAccumulatorScaleBias iterator_Bias0( -+ params.ptr_Bias0, -+ {1, params.problem_size_0.K}, -+ thread_idx, -+ warp_idx, -+ MatrixCoord( -+ 0, threadblock_tile_idx.n() * B2bMma::Shape0::kN -+ ) -+ ); -+ -+ -+ // -+ // Main loop -+ // -+ -+ EpilogueOutputOp0 output_op_0(params.output_op_0); -+ -+ // Construct thread-scoped matrix multiply -+ B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename B2bMma::FragmentC0 src_accum; -+ typename B2bMma::FragmentC1 accumulators; -+ -+ src_accum.clear(); -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ b2bMma(params.gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0, -+ iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0); -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp1 output_op_1(params.output_op_1); -+ -+ // Construct the semaphore. -+ int block_idx = threadblock_tile_idx.m() + threadblock_tile_idx.n() * params.grid_tiled_shape.m(); -+ -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ // Compute logical position within grid -+ threadblock_tile_idx = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. -+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op_1.set_k_partition(threadblock_tile_idx.k(), params.grid_tiled_shape.k()); -+ } -+ -+ MatrixCoord threadblock_offset( -+ threadblock_tile_idx.m() * B2bMma::Shape1::kM, -+ threadblock_tile_idx.n() * B2bMma::Shape1::kN -+ ); -+ -+ // Tile iterator writing to destination tensor -+ typename Epilogue::OutputTileIterator iterator_D1( -+ params.iterator_D1, -+ params.ptr_D1, -+ ConvOutputIteratorParameter::extent(params.problem_size_1), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator reading from source accumulator tensor -+ typename Epilogue::OutputTileIterator iterator_C1( -+ params.iterator_C1, -+ params.ptr_C1, -+ ConvOutputIteratorParameter::extent(params.problem_size_1), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ -+ // Construct the epilogue -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_idx.k()) { -+ iterator_C1 = iterator_D1; -+ } -+ -+ semaphore.wait(threadblock_tile_idx.k()); -+ -+ __threadfence(); -+ } -+ // Each split-k-slice writes to a unique tensor location -+ else if (params.split_k_mode == SplitKMode::kParallel) { -+ iterator_D1.add_pointer_offset(threadblock_tile_idx.k() * -+ cutlass::conv::implicit_gemm_tensor_c_size(ConvOperator, params.problem_size_1)); -+ } -+ -+ // Run efficient epilogue -+ epilogue(output_op_1, iterator_D1, accumulators, iterator_C1); -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_idx.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_idx.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop.h -new file mode 100644 -index 0000000..82e808d ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop.h -@@ -0,0 +1,94 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped -+ matrix multiply-add with the appropriate threadblock-scoped epilogue. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/conv/kernel/default_conv2d.h" -+ -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" -+ -+#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h" -+#include "cutlass/transform/threadblock/vector_iterator.h" -+#include "cutlass/transform/warp/vector_fragment_iterator.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" -+ -+#include "kernel/b2b_implicit_gemm_convolution.h" -+#include "threadblock/b2b_implicit_gemm_pipelined.h" -+#include "threadblock/b2b_implicit_gemm_multistage.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines a kernel for Conv2dFprop -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape0, -+ typename ThreadblockShape1, -+ typename WarpShape0, -+ typename WarpShape1, -+ typename InstructionShape, -+ typename EpilogueOutputOp0, -+ typename EpilogueOutputOp1, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, -+ bool SmemAccumulator = false -+> struct DefaultB2bConv2dFprop; -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm75.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm75.h -new file mode 100644 -index 0000000..d5792a8 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm75.h -@@ -0,0 +1,749 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped -+ matrix multiply-add with the appropriate threadblock-scoped epilogue. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/conv/kernel/default_conv2d.h" -+ -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" -+ -+#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h" -+#include "cutlass/transform/threadblock/vector_iterator.h" -+#include "cutlass/transform/warp/vector_fragment_iterator.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" -+ -+#include "kernel/default_b2b_conv2d_fprop.h" -+#include "kernel/b2b_implicit_gemm_convolution.h" -+#include "threadblock/b2b_implicit_gemm_pipelined.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// OpClassTensorOp convolutions -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm -+/// and 2 stage pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape0, -+ typename ThreadblockShape1, -+ typename WarpShape0, -+ typename WarpShape1, -+ typename InstructionShape, -+ typename EpilogueOutputOp0, -+ typename EpilogueOutputOp1, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag -+> -+struct DefaultB2bConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ 2, MathOperatorTag>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; -+ using IteratorA0 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA0 -+ > -+ >; -+ -+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; -+ using IteratorB0 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB0 -+ > -+ >; -+ -+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; -+ -+ // Use fragment iterator for A operand -+ using AccumulatorLayout = cutlass::layout::ColumnMajor; -+ using FragmentIteratorA1 = -+ cutlass::gemm::warp::MmaTensorOpFragmentIterator< -+ cutlass::MatrixShape, //warp shape -+ cutlass::MatrixShape, //accumulator shape -+ MmaCore1::Shape::kK, //kBlocksColumn -+ ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 2; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Warp-level iterators to load scale and bias vectors -+ using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< -+ MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, -+ LayoutScaleBias, InstructionShape, kElementsPerAccess>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; -+ using IteratorB1 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB1 -+ > -+ >; -+ -+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ using MmaPolicy0 = typename MmaCore0::MmaPolicy; -+ using MmaPolicy1 = typename MmaCore1::MmaPolicy; -+ -+ // Define the Mma -+ using B2bMma = threadblock::B2bImplicitGemmPipelined< -+ ThreadblockShape0, -+ IteratorA0, -+ SmemIteratorA0, -+ IteratorB0, -+ SmemIteratorB0, -+ ThreadblockShape1, -+ FragmentIteratorA1, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorA1ScaleBias, -+ IteratorB1, -+ SmemIteratorB1, -+ ElementC, -+ LayoutC, -+ EpilogueOutputOp0, -+ MmaPolicy0, -+ MmaPolicy1 -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename detail::DefaultConvEpilogue< -+ ArchTag, -+ ThreadblockShape1, -+ WarpMmaTensorOp1, -+ 1, -+ EpilogueOutputOp1 -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< -+ B2bMma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and 2 stage -+/// pipeline with interleaved layout. -+template < -+ typename ElementA, -+ typename ElementB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape0, -+ typename ThreadblockShape1, -+ typename WarpShape0, -+ typename WarpShape1, -+ typename InstructionShape, -+ typename EpilogueOutputOp0, -+ typename EpilogueOutputOp1, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ int InterleavedK -+> -+struct DefaultB2bConv2dFprop < -+ ElementA, -+ layout::TensorNCxHWx, -+ ElementB, -+ layout::TensorCxRSKx, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ false -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ 2, MathOperatorTag, true>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ 2, MathOperatorTag, true>; -+ -+ // Define iterators over tiles from the A operand -+ // Note GEMM shared memory threadmap is used here because conv global memory -+ // layout needs to be mapped to fprop which is similar to the crosswise -+ // layout which is used by the interleaved GEMM shared memory threadmap. -+ // The Interleaved GEMM global memory layout is similar to the congruous -+ // layout. -+ using ThreadMapA0 = typename MmaCore0::SmemThreadMapA; -+ using IteratorA0 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, layout::TensorNCxHWx, -+ ThreadMapA0 -+ > -+ >; -+ -+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ // Note GEMM shared memory threadmap is used here because conv global memory -+ // layout needs to be mapped to fprop which is similar to the crosswise -+ // layout which is used by the interleaved GEMM shared memory threadmap. -+ // The Interleaved GEMM global memory layout is similar to the congruous -+ // layout. -+ using ThreadMapB0 = typename MmaCore0::SmemThreadMapB; -+ using IteratorB0 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, layout::TensorCxRSKx, -+ ThreadMapB0 -+ > -+ >; -+ -+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; -+ -+ // Use fragment iterator for A operand -+ using AccumulatorLayout = cutlass::layout::RowMajor; -+ using FragmentIteratorA1 = -+ cutlass::gemm::warp::MmaTensorOpFragmentIterator< -+ cutlass::MatrixShape, //warp shape -+ cutlass::MatrixShape, //accumulator shape -+ MmaCore1::Shape::kK, //kBlocksColumn -+ ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 4; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Warp-level iterators to load scale and bias vectors -+ using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< -+ MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, -+ LayoutScaleBias, InstructionShape, kElementsPerAccess>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB1 = typename MmaCore1::SmemThreadMapB; -+ using IteratorB1 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, layout::TensorCxRSKx, -+ ThreadMapB1 -+ > -+ >; -+ -+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ using MmaPolicy0 = typename MmaCore0::MmaPolicy; -+ using MmaPolicy1 = typename MmaCore1::MmaPolicy; -+ -+ // Define the Mma -+ using B2bMma = threadblock::B2bImplicitGemmPipelined< -+ ThreadblockShape0, -+ IteratorA0, -+ SmemIteratorA0, -+ IteratorB0, -+ SmemIteratorB0, -+ ThreadblockShape1, -+ FragmentIteratorA1, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorA1ScaleBias, -+ IteratorB1, -+ SmemIteratorB1, -+ ElementC, -+ LayoutC, -+ EpilogueOutputOp0, -+ MmaPolicy0, -+ MmaPolicy1 -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< -+ ThreadblockShape1, -+ WarpMmaTensorOp1, -+ 1, -+ EpilogueOutputOp1, -+ EpilogueOutputOp1::kCount, -+ InterleavedK -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< -+ B2bMma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm -+/// and 2 stage pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape0, -+ typename ThreadblockShape1, -+ typename WarpShape0, -+ typename WarpShape1, -+ typename InstructionShape, -+ typename EpilogueOutputOp0, -+ typename EpilogueOutputOp1, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag -+> -+struct DefaultB2bConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ 2, MathOperatorTag>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; -+ using IteratorA0 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA0 -+ > -+ >; -+ -+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; -+ using IteratorB0 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB0 -+ > -+ >; -+ -+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; -+ -+ // Use fragment iterator for A operand -+ using AccumulatorLayout = cutlass::layout::ColumnMajor; -+ using FragmentIteratorA1 = -+ cutlass::gemm::warp::MmaTensorOpFragmentIterator< -+ cutlass::MatrixShape, //warp shape -+ cutlass::MatrixShape, //accumulator shape -+ MmaCore1::Shape::kK, //kBlocksColumn -+ ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 2; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Warp-level iterators to load scale and bias vectors -+ using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< -+ MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, -+ LayoutScaleBias, InstructionShape, kElementsPerAccess>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; -+ using IteratorB1 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB1 -+ > -+ >; -+ -+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ using MmaPolicy0 = typename MmaCore0::MmaPolicy; -+ using MmaPolicy1 = typename MmaCore1::MmaPolicy; -+ -+ // Define the Mma -+ using B2bMma = threadblock::B2bImplicitGemmPipelined< -+ ThreadblockShape0, -+ IteratorA0, -+ SmemIteratorA0, -+ IteratorB0, -+ SmemIteratorB0, -+ ThreadblockShape1, -+ FragmentIteratorA1, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorA1ScaleBias, -+ IteratorB1, -+ SmemIteratorB1, -+ ElementC, -+ LayoutC, -+ EpilogueOutputOp0, -+ MmaPolicy0, -+ MmaPolicy1 -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename detail::DefaultConvEpilogue< -+ ArchTag, -+ ThreadblockShape1, -+ WarpMmaTensorOp1, -+ 1, -+ EpilogueOutputOp1 -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< -+ B2bMma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and 2 stage -+/// pipeline with interleaved layout. -+template < -+ typename ElementA, -+ typename ElementB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape0, -+ typename ThreadblockShape1, -+ typename WarpShape0, -+ typename WarpShape1, -+ typename InstructionShape, -+ typename EpilogueOutputOp0, -+ typename EpilogueOutputOp1, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ int InterleavedK -+> -+struct DefaultB2bConv2dFprop < -+ ElementA, -+ layout::TensorNCxHWx, -+ ElementB, -+ layout::TensorCxRSKx, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ 2, MathOperatorTag, true>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ 2, MathOperatorTag, true>; -+ -+ // Define iterators over tiles from the A operand -+ // Note GEMM shared memory threadmap is used here because conv global memory -+ // layout needs to be mapped to fprop which is similar to the crosswise -+ // layout which is used by the interleaved GEMM shared memory threadmap. -+ // The Interleaved GEMM global memory layout is similar to the congruous -+ // layout. -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA0 = typename MmaCore0::SmemThreadMapA; -+ using IteratorA0 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, layout::TensorNCxHWx, -+ ThreadMapA0 -+ > -+ >; -+ -+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB0 = typename MmaCore0::SmemThreadMapB; -+ using IteratorB0 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, layout::TensorCxRSKx, -+ ThreadMapB0 -+ > -+ >; -+ -+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; -+ -+ // Use fragment iterator for A operand -+ using AccumulatorLayout = cutlass::layout::RowMajor; -+ using FragmentIteratorA1 = -+ cutlass::gemm::warp::MmaTensorOpFragmentIterator< -+ cutlass::MatrixShape, //warp shape -+ cutlass::MatrixShape, //accumulator shape -+ MmaCore1::Shape::kK, //kBlocksColumn -+ ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 4; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Warp-level iterators to load scale and bias vectors -+ using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< -+ MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, -+ LayoutScaleBias, InstructionShape, kElementsPerAccess>; -+ -+ using ThreadMapB1 = typename MmaCore1::SmemThreadMapB; -+ using IteratorB1 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, layout::TensorCxRSKx, -+ ThreadMapB1 -+ > -+ >; -+ -+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ using MmaPolicy0 = typename MmaCore0::MmaPolicy; -+ using MmaPolicy1 = typename MmaCore1::MmaPolicy; -+ -+ // Define the Mma -+ using B2bMma = threadblock::B2bImplicitGemmPipelined< -+ ThreadblockShape0, -+ IteratorA0, -+ SmemIteratorA0, -+ IteratorB0, -+ SmemIteratorB0, -+ ThreadblockShape1, -+ FragmentIteratorA1, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorA1ScaleBias, -+ IteratorB1, -+ SmemIteratorB1, -+ ElementC, -+ LayoutC, -+ EpilogueOutputOp0, -+ MmaPolicy0, -+ MmaPolicy1 -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< -+ ThreadblockShape1, -+ WarpMmaTensorOp1, -+ 1, -+ EpilogueOutputOp1, -+ EpilogueOutputOp1::kCount, -+ InterleavedK -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< -+ B2bMma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm80.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm80.h -new file mode 100644 -index 0000000..7261e7e ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm80.h -@@ -0,0 +1,740 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped -+ matrix multiply-add with the appropriate threadblock-scoped epilogue. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/conv/kernel/default_conv2d.h" -+ -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" -+ -+#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h" -+#include "cutlass/transform/threadblock/vector_iterator.h" -+#include "cutlass/transform/warp/vector_fragment_iterator.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" -+ -+#include "kernel/default_b2b_conv2d_fprop.h" -+#include "kernel/b2b_implicit_gemm_convolution.h" -+#include "threadblock/b2b_implicit_gemm_multistage.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// OpClassTensorOp convolutions -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage -+/// pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape0, -+ typename ThreadblockShape1, -+ typename WarpShape0, -+ typename WarpShape1, -+ typename InstructionShape, -+ typename EpilogueOutputOp0, -+ typename EpilogueOutputOp1, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag -+> -+struct DefaultB2bConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; -+ using IteratorA0 = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA0 -+ >; -+ -+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; -+ using IteratorB0 = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB0 -+ >; -+ -+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; -+ -+ // Use fragment iterator for A operand -+ using AccumulatorLayout = cutlass::layout::ColumnMajor; -+ using FragmentIteratorA1 = -+ cutlass::gemm::warp::MmaTensorOpFragmentIterator< -+ cutlass::MatrixShape, //warp shape -+ cutlass::MatrixShape, //accumulator shape -+ MmaCore1::Shape::kK, //kBlocksColumn -+ ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 2; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Warp-level iterators to load scale and bias vectors -+ using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< -+ MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, -+ LayoutScaleBias, InstructionShape, kElementsPerAccess>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; -+ using IteratorB1 = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB1 -+ >; -+ -+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ using MmaPolicy0 = typename MmaCore0::MmaPolicy; -+ using MmaPolicy1 = typename MmaCore1::MmaPolicy; -+ -+ // Define the Mma -+ using B2bMma = threadblock::B2bImplicitGemmMultistage< -+ ThreadblockShape0, -+ IteratorA0, -+ SmemIteratorA0, -+ arch::CacheOperation::Always, -+ IteratorB0, -+ SmemIteratorB0, -+ arch::CacheOperation::Global, -+ ThreadblockShape1, -+ FragmentIteratorA1, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorA1ScaleBias, -+ IteratorB1, -+ SmemIteratorB1, -+ arch::CacheOperation::Global, -+ EpilogueOutputOp0, -+ MmaPolicy0, -+ MmaPolicy1, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape1, -+ WarpMmaTensorOp1, -+ 1, -+ EpilogueOutputOp1, -+ EpilogueOutputOp1::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< -+ B2bMma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage -+/// pipeline with interleaved layout. -+template < -+ typename ElementA, -+ typename ElementB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape0, -+ typename ThreadblockShape1, -+ typename WarpShape0, -+ typename WarpShape1, -+ typename InstructionShape, -+ typename EpilogueOutputOp0, -+ typename EpilogueOutputOp1, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ int InterleavedK -+> -+struct DefaultB2bConv2dFprop < -+ ElementA, -+ layout::TensorNCxHWx, -+ ElementB, -+ layout::TensorCxRSKx, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ Stages, MathOperatorTag, true>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ Stages, MathOperatorTag, true>; -+ -+ // Define iterators over tiles from the A operand -+ // Note GEMM shared memory threadmap is used here because conv global memory -+ // layout needs to be mapped to fprop which is similar to the crosswise -+ // layout which is used by the interleaved GEMM shared memory threadmap. -+ // The Interleaved GEMM global memory layout is similar to the congruous -+ // layout. -+ using ThreadMapA0 = typename MmaCore0::SmemThreadMapA; -+ using IteratorA0 = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, layout::TensorNCxHWx, -+ ThreadMapA0 -+ >; -+ -+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ // Note GEMM shared memory threadmap is used here because conv global memory -+ // layout needs to be mapped to fprop which is similar to the crosswise -+ // layout which is used by the interleaved GEMM shared memory threadmap. -+ // The Interleaved GEMM global memory layout is similar to the congruous -+ // layout. -+ using ThreadMapB0 = typename MmaCore0::SmemThreadMapB; -+ using IteratorB0 = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, layout::TensorCxRSKx, -+ ThreadMapB0 -+ >; -+ -+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; -+ -+ // Use fragment iterator for A operand -+ using AccumulatorLayout = cutlass::layout::RowMajor; -+ using FragmentIteratorA1 = -+ cutlass::gemm::warp::MmaTensorOpFragmentIterator< -+ cutlass::MatrixShape, //warp shape -+ cutlass::MatrixShape, //accumulator shape -+ MmaCore1::Shape::kK, //kBlocksColumn -+ ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 4; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Warp-level iterators to load scale and bias vectors -+ using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< -+ MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, -+ LayoutScaleBias, InstructionShape, kElementsPerAccess>; -+ -+ using ThreadMapB1 = typename MmaCore1::SmemThreadMapB; -+ using IteratorB1 = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, layout::TensorCxRSKx, -+ ThreadMapB1 -+ >; -+ -+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; -+ -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ using MmaPolicy0 = typename MmaCore0::MmaPolicy; -+ using MmaPolicy1 = typename MmaCore1::MmaPolicy; -+ -+ // Define the Mma -+ using B2bMma = threadblock::B2bImplicitGemmMultistage< -+ ThreadblockShape0, -+ IteratorA0, -+ SmemIteratorA0, -+ arch::CacheOperation::Always, -+ IteratorB0, -+ SmemIteratorB0, -+ arch::CacheOperation::Global, -+ ThreadblockShape1, -+ FragmentIteratorA1, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorA1ScaleBias, -+ IteratorB1, -+ SmemIteratorB1, -+ arch::CacheOperation::Global, -+ EpilogueOutputOp0, -+ MmaPolicy0, -+ MmaPolicy1, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< -+ ThreadblockShape1, -+ WarpMmaTensorOp1, -+ 1, -+ EpilogueOutputOp1, -+ EpilogueOutputOp1::kCount, -+ InterleavedK -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< -+ B2bMma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and -+/// multistage pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape0, -+ typename ThreadblockShape1, -+ typename WarpShape0, -+ typename WarpShape1, -+ typename InstructionShape, -+ typename EpilogueOutputOp0, -+ typename EpilogueOutputOp1, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag -+> -+struct DefaultB2bConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; -+ using IteratorA0 = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA0 -+ >; -+ -+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; -+ using IteratorB0 = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB0 -+ >; -+ -+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; -+ -+ // Use fragment iterator for A operand -+ using AccumulatorLayout = cutlass::layout::ColumnMajor; -+ using FragmentIteratorA1 = -+ cutlass::gemm::warp::MmaTensorOpFragmentIterator< -+ cutlass::MatrixShape, //warp shape -+ cutlass::MatrixShape, //accumulator shape -+ MmaCore1::Shape::kK, //kBlocksColumn -+ ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 2; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Warp-level iterators to load scale and bias vectors -+ using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< -+ MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, -+ LayoutScaleBias, InstructionShape, kElementsPerAccess>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; -+ using IteratorB1 = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB1 -+ >; -+ -+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ using MmaPolicy0 = typename MmaCore0::MmaPolicy; -+ using MmaPolicy1 = typename MmaCore1::MmaPolicy; -+ -+ // Define the Mma -+ using B2bMma = threadblock::B2bImplicitGemmMultistage< -+ ThreadblockShape0, -+ IteratorA0, -+ SmemIteratorA0, -+ arch::CacheOperation::Always, -+ IteratorB0, -+ SmemIteratorB0, -+ arch::CacheOperation::Global, -+ ThreadblockShape1, -+ FragmentIteratorA1, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorA1ScaleBias, -+ IteratorB1, -+ SmemIteratorB1, -+ arch::CacheOperation::Global, -+ EpilogueOutputOp0, -+ MmaPolicy0, -+ MmaPolicy1, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape1, -+ WarpMmaTensorOp1, -+ 1, -+ EpilogueOutputOp1, -+ EpilogueOutputOp1::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< -+ B2bMma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Optimzed IteratorAlgorithm and -+// multistage pipeline with interleaved layout. -+template < -+ typename ElementA, -+ typename ElementB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape0, -+ typename ThreadblockShape1, -+ typename WarpShape0, -+ typename WarpShape1, -+ typename InstructionShape, -+ typename EpilogueOutputOp0, -+ typename EpilogueOutputOp1, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ int InterleavedK -+> -+struct DefaultB2bConv2dFprop < -+ ElementA, -+ layout::TensorNCxHWx, -+ ElementB, -+ layout::TensorCxRSKx, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ Stages, MathOperatorTag, true>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ Stages, MathOperatorTag, true>; -+ -+ // Define iterators over tiles from the A operand -+ // Note GEMM shared memory threadmap is used here because conv global memory -+ // layout needs to be mapped to fprop which is similar to the crosswise -+ // layout which is used by the interleaved GEMM shared memory threadmap. -+ // The Interleaved GEMM global memory layout is similar to the congruous -+ // layout. -+ using ThreadMapA0 = typename MmaCore0::SmemThreadMapA; -+ using IteratorA0 = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, layout::TensorNCxHWx, -+ ThreadMapA0 -+ >; -+ -+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ // Note GEMM shared memory threadmap is used here because conv global memory -+ // layout needs to be mapped to fprop which is similar to the crosswise -+ // layout which is used by the interleaved GEMM shared memory threadmap. -+ // The Interleaved GEMM global memory layout is similar to the congruous -+ // layout. -+ using ThreadMapB0 = typename MmaCore0::SmemThreadMapB; -+ using IteratorB0 = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, layout::TensorCxRSKx, -+ ThreadMapB0 -+ >; -+ -+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; -+ -+ // Use fragment iterator for A operand -+ using AccumulatorLayout = cutlass::layout::RowMajor; -+ using FragmentIteratorA1 = -+ cutlass::gemm::warp::MmaTensorOpFragmentIterator< -+ cutlass::MatrixShape, //warp shape -+ cutlass::MatrixShape, //accumulator shape -+ MmaCore1::Shape::kK, //kBlocksColumn -+ ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 4; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Warp-level iterators to load scale and bias vectors -+ using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< -+ MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, -+ LayoutScaleBias, InstructionShape, kElementsPerAccess>; -+ -+ using ThreadMapB1 = typename MmaCore1::SmemThreadMapB; -+ using IteratorB1 = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, layout::TensorCxRSKx, -+ ThreadMapB1 -+ >; -+ -+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; -+ -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ using MmaPolicy0 = typename MmaCore0::MmaPolicy; -+ using MmaPolicy1 = typename MmaCore1::MmaPolicy; -+ -+ // Define the Mma -+ using B2bMma = threadblock::B2bImplicitGemmMultistage< -+ ThreadblockShape0, -+ IteratorA0, -+ SmemIteratorA0, -+ arch::CacheOperation::Always, -+ IteratorB0, -+ SmemIteratorB0, -+ arch::CacheOperation::Global, -+ ThreadblockShape1, -+ FragmentIteratorA1, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorA1ScaleBias, -+ IteratorB1, -+ SmemIteratorB1, -+ arch::CacheOperation::Global, -+ EpilogueOutputOp0, -+ MmaPolicy0, -+ MmaPolicy1, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< -+ ThreadblockShape1, -+ WarpMmaTensorOp1, -+ 1, -+ EpilogueOutputOp1, -+ EpilogueOutputOp1::kCount, -+ InterleavedK -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< -+ B2bMma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm75.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm75.h -new file mode 100644 -index 0000000..09d094f ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm75.h -@@ -0,0 +1,817 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped -+ matrix multiply-add with the appropriate threadblock-scoped epilogue. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/conv/kernel/default_conv2d.h" -+ -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" -+ -+#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h" -+#include "cutlass/transform/threadblock/vector_iterator.h" -+#include "cutlass/transform/warp/vector_fragment_iterator.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" -+ -+#include "kernel/default_b2b_conv2d_fprop.h" -+#include "kernel/b2b_implicit_gemm_convolution.h" -+#include "threadblock/b2b_implicit_gemm_pipelined_smem_accumulator.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm -+/// and 2 stage pipeline. -+/// Accumulator will be staged in shared memory. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape0, -+ typename ThreadblockShape1, -+ typename WarpShape0, -+ typename WarpShape1, -+ typename InstructionShape, -+ typename EpilogueOutputOp0, -+ typename EpilogueOutputOp1, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag -+> -+struct DefaultB2bConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ true -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ 2, MathOperatorTag>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; -+ using IteratorA0 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA0 -+ > -+ >; -+ -+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; -+ using IteratorB0 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB0 -+ > -+ >; -+ -+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 2; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; -+ using IteratorB1 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB1 -+ > -+ >; -+ -+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ using MmaPolicy0 = typename MmaCore0::MmaPolicy; -+ using MmaPolicy1 = typename MmaCore1::MmaPolicy; -+ -+ // Use fragment iterator for the accumulator -+ using SmemAccumulatorLayout = cutlass::layout::RowMajor; -+ using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ WarpShape0, InstructionShape, -+ ElementAccumulator, -+ typename WarpMmaTensorOp0::Policy::Operator::FragmentC, -+ SmemAccumulatorLayout -+ >; -+ -+ // Store Accumulator tiles to Shared Memory -+ using SmemIteratorD0 = -+ cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape0, -+ InstructionShape, -+ ElementC, -+ SmemAccumulatorLayout -+ >; -+ -+ static int const kThreadCount = 32; -+ // load warp tile from Shared Memory accumulator -+ using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, cutlass::gemm::Operand::kA, -+ ElementA, SmemAccumulatorLayout, -+ MatrixShape, -+ WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>; -+ -+ // Define the Mma -+ using B2bMma = threadblock::B2bImplicitGemmPipelinedSmemAccumulator< -+ ThreadblockShape0, -+ IteratorA0, -+ SmemIteratorA0, -+ IteratorB0, -+ SmemIteratorB0, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorAccumulator, -+ SmemIteratorD0, -+ ThreadblockShape1, -+ WarpIteratorA1, -+ IteratorB1, -+ SmemIteratorB1, -+ ElementC, -+ LayoutC, -+ EpilogueOutputOp0, -+ MmaPolicy0, -+ MmaPolicy1 -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename detail::DefaultConvEpilogue< -+ ArchTag, -+ ThreadblockShape1, -+ WarpMmaTensorOp1, -+ 1, -+ EpilogueOutputOp1 -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< -+ B2bMma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and 2 stage -+/// pipeline with interleaved layout. -+/// Accumulator will be staged in shared memory. -+template < -+ typename ElementA, -+ typename ElementB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape0, -+ typename ThreadblockShape1, -+ typename WarpShape0, -+ typename WarpShape1, -+ typename InstructionShape, -+ typename EpilogueOutputOp0, -+ typename EpilogueOutputOp1, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ int InterleavedK -+> -+struct DefaultB2bConv2dFprop < -+ ElementA, -+ layout::TensorNCxHWx, -+ ElementB, -+ layout::TensorCxRSKx, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ true -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ 2, MathOperatorTag, true>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ 2, MathOperatorTag, true>; -+ -+ // Define iterators over tiles from the A operand -+ // Note GEMM shared memory threadmap is used here because conv global memory -+ // layout needs to be mapped to fprop which is similar to the crosswise -+ // layout which is used by the interleaved GEMM shared memory threadmap. -+ // The Interleaved GEMM global memory layout is similar to the congruous -+ // layout. -+ using ThreadMapA0 = typename MmaCore0::SmemThreadMapA; -+ using IteratorA0 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, layout::TensorNCxHWx, -+ ThreadMapA0 -+ > -+ >; -+ -+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ // Note GEMM shared memory threadmap is used here because conv global memory -+ // layout needs to be mapped to fprop which is similar to the crosswise -+ // layout which is used by the interleaved GEMM shared memory threadmap. -+ // The Interleaved GEMM global memory layout is similar to the congruous -+ // layout. -+ using ThreadMapB0 = typename MmaCore0::SmemThreadMapB; -+ using IteratorB0 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, layout::TensorCxRSKx, -+ ThreadMapB0 -+ > -+ >; -+ -+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 4; //For interleaved layout -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB1 = typename MmaCore1::SmemThreadMapB; -+ using IteratorB1 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, layout::TensorCxRSKx, -+ ThreadMapB1 -+ > -+ >; -+ -+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ using MmaPolicy0 = typename MmaCore0::MmaPolicy; -+ using MmaPolicy1 = typename MmaCore1::MmaPolicy; -+ -+ // Use fragment iterator for the accumulator -+ using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>; -+ using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ WarpShape0, InstructionShape, -+ ElementAccumulator, -+ typename WarpMmaTensorOp0::Policy::Operator::FragmentC, -+ SmemAccumulatorLayout -+ >; -+ -+ -+ // Store Accumulator tiles to Shared Memory -+ using SmemIteratorD0 = -+ cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape0, -+ InstructionShape, -+ ElementC, -+ SmemAccumulatorLayout -+ >; -+ -+ static int const kThreadCount = 32; -+ // load warp tile from Shared Memory accumulator -+ using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical< -+ MatrixShape, cutlass::gemm::Operand::kA, -+ ElementA, SmemAccumulatorLayout, -+ MatrixShape, -+ WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>; -+ -+ // Define the Mma -+ using B2bMma = threadblock::B2bImplicitGemmPipelinedSmemAccumulator< -+ ThreadblockShape0, -+ IteratorA0, -+ SmemIteratorA0, -+ IteratorB0, -+ SmemIteratorB0, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorAccumulator, -+ SmemIteratorD0, -+ ThreadblockShape1, -+ WarpIteratorA1, -+ IteratorB1, -+ SmemIteratorB1, -+ ElementC, -+ LayoutC, -+ EpilogueOutputOp0, -+ MmaPolicy0, -+ MmaPolicy1 -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< -+ ThreadblockShape1, -+ WarpMmaTensorOp1, -+ 1, -+ EpilogueOutputOp1, -+ EpilogueOutputOp1::kCount, -+ InterleavedK -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< -+ B2bMma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm -+/// and 2 stage pipeline. -+/// Accumulator will be staged in shared memory. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape0, -+ typename ThreadblockShape1, -+ typename WarpShape0, -+ typename WarpShape1, -+ typename InstructionShape, -+ typename EpilogueOutputOp0, -+ typename EpilogueOutputOp1, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag -+> -+struct DefaultB2bConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ true -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ 2, MathOperatorTag>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; -+ using IteratorA0 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA0 -+ > -+ >; -+ -+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; -+ using IteratorB0 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB0 -+ > -+ >; -+ -+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 2; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; -+ using IteratorB1 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB1 -+ > -+ >; -+ -+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ using MmaPolicy0 = typename MmaCore0::MmaPolicy; -+ using MmaPolicy1 = typename MmaCore1::MmaPolicy; -+ -+ // Use fragment iterator for the accumulator -+ using SmemAccumulatorLayout = cutlass::layout::RowMajor; -+ using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ WarpShape0, InstructionShape, -+ ElementAccumulator, -+ typename WarpMmaTensorOp0::Policy::Operator::FragmentC, -+ SmemAccumulatorLayout -+ >; -+ -+ // Store Accumulator tiles to Shared Memory -+ using SmemIteratorD0 = -+ cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape0, -+ InstructionShape, -+ ElementC, -+ SmemAccumulatorLayout -+ >; -+ -+ static int const kThreadCount = 32; -+ // load warp tile from Shared Memory accumulator -+ using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, cutlass::gemm::Operand::kA, -+ ElementA, SmemAccumulatorLayout, -+ MatrixShape, -+ WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>; -+ -+ // Define the Mma -+ using B2bMma = threadblock::B2bImplicitGemmPipelinedSmemAccumulator< -+ ThreadblockShape0, -+ IteratorA0, -+ SmemIteratorA0, -+ IteratorB0, -+ SmemIteratorB0, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorAccumulator, -+ SmemIteratorD0, -+ ThreadblockShape1, -+ WarpIteratorA1, -+ IteratorB1, -+ SmemIteratorB1, -+ ElementC, -+ LayoutC, -+ EpilogueOutputOp0, -+ MmaPolicy0, -+ MmaPolicy1 -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename detail::DefaultConvEpilogue< -+ ArchTag, -+ ThreadblockShape1, -+ WarpMmaTensorOp1, -+ 1, -+ EpilogueOutputOp1 -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< -+ B2bMma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and 2 stage -+/// pipeline with interleaved layout. -+/// Accumulator will be staged in shared memory. -+template < -+ typename ElementA, -+ typename ElementB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape0, -+ typename ThreadblockShape1, -+ typename WarpShape0, -+ typename WarpShape1, -+ typename InstructionShape, -+ typename EpilogueOutputOp0, -+ typename EpilogueOutputOp1, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ int InterleavedK -+> -+struct DefaultB2bConv2dFprop < -+ ElementA, -+ layout::TensorNCxHWx, -+ ElementB, -+ layout::TensorCxRSKx, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ true -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ 2, MathOperatorTag, true>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ 2, MathOperatorTag, true>; -+ -+ // Define iterators over tiles from the A operand -+ // Note GEMM shared memory threadmap is used here because conv global memory -+ // layout needs to be mapped to fprop which is similar to the crosswise -+ // layout which is used by the interleaved GEMM shared memory threadmap. -+ // The Interleaved GEMM global memory layout is similar to the congruous -+ // layout. -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA0 = typename MmaCore0::SmemThreadMapA; -+ using IteratorA0 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, layout::TensorNCxHWx, -+ ThreadMapA0 -+ > -+ >; -+ -+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ // Note GEMM shared memory threadmap is used here because conv global memory -+ // layout needs to be mapped to fprop which is similar to the crosswise -+ // layout which is used by the interleaved GEMM shared memory threadmap. -+ // The Interleaved GEMM global memory layout is similar to the congruous -+ // layout. -+ using ThreadMapB0 = typename MmaCore0::SmemThreadMapB; -+ using IteratorB0 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, layout::TensorCxRSKx, -+ ThreadMapB0 -+ > -+ >; -+ -+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 4; //For interleaved layout -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ using ThreadMapB1 = typename MmaCore1::SmemThreadMapB; -+ using IteratorB1 = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, layout::TensorCxRSKx, -+ ThreadMapB1 -+ > -+ >; -+ -+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ using MmaPolicy0 = typename MmaCore0::MmaPolicy; -+ using MmaPolicy1 = typename MmaCore1::MmaPolicy; -+ -+ // Use fragment iterator for the accumulator -+ using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>; -+ using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ WarpShape0, InstructionShape, -+ ElementAccumulator, -+ typename WarpMmaTensorOp0::Policy::Operator::FragmentC, -+ SmemAccumulatorLayout -+ >; -+ -+ -+ // Store Accumulator tiles to Shared Memory -+ using SmemIteratorD0 = -+ cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape0, -+ InstructionShape, -+ ElementC, -+ SmemAccumulatorLayout -+ >; -+ -+ static int const kThreadCount = 32; -+ // load warp tile from Shared Memory accumulator -+ using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical< -+ MatrixShape, cutlass::gemm::Operand::kA, -+ ElementA, SmemAccumulatorLayout, -+ MatrixShape, -+ WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>; -+ -+ // Define the Mma -+ using B2bMma = threadblock::B2bImplicitGemmPipelinedSmemAccumulator< -+ ThreadblockShape0, -+ IteratorA0, -+ SmemIteratorA0, -+ IteratorB0, -+ SmemIteratorB0, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorAccumulator, -+ SmemIteratorD0, -+ ThreadblockShape1, -+ WarpIteratorA1, -+ IteratorB1, -+ SmemIteratorB1, -+ ElementC, -+ LayoutC, -+ EpilogueOutputOp0, -+ MmaPolicy0, -+ MmaPolicy1 -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< -+ ThreadblockShape1, -+ WarpMmaTensorOp1, -+ 1, -+ EpilogueOutputOp1, -+ EpilogueOutputOp1::kCount, -+ InterleavedK -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< -+ B2bMma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm80.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm80.h -new file mode 100644 -index 0000000..7a5b380 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm80.h -@@ -0,0 +1,804 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped -+ matrix multiply-add with the appropriate threadblock-scoped epilogue. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/conv/kernel/default_conv2d.h" -+ -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" -+ -+#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h" -+#include "cutlass/transform/threadblock/vector_iterator.h" -+#include "cutlass/transform/warp/vector_fragment_iterator.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" -+ -+#include "kernel/default_b2b_conv2d_fprop.h" -+#include "kernel/b2b_implicit_gemm_convolution.h" -+#include "threadblock/b2b_implicit_gemm_multistage_smem_accumulator.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage -+/// pipeline. -+/// Accumulator will be staged in shared memory. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape0, -+ typename ThreadblockShape1, -+ typename WarpShape0, -+ typename WarpShape1, -+ typename InstructionShape, -+ typename EpilogueOutputOp0, -+ typename EpilogueOutputOp1, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag -+> -+struct DefaultB2bConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ true -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; -+ using IteratorA0 = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA0 -+ >; -+ -+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; -+ using IteratorB0 = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB0 -+ >; -+ -+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 2; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; -+ using IteratorB1 = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB1 -+ >; -+ -+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ using MmaPolicy0 = typename MmaCore0::MmaPolicy; -+ using MmaPolicy1 = typename MmaCore1::MmaPolicy; -+ -+ // Use fragment iterator for the accumulator -+ using SmemAccumulatorLayout = cutlass::layout::RowMajor; -+ using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ WarpShape0, InstructionShape, -+ ElementAccumulator, -+ typename WarpMmaTensorOp0::Policy::Operator::FragmentC, -+ SmemAccumulatorLayout -+ >; -+ -+ // Store Accumulator tiles to Shared Memory -+ using SmemIteratorD0 = -+ cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape0, -+ InstructionShape, -+ ElementC, -+ SmemAccumulatorLayout -+ >; -+ -+ static int const kThreadCount = 32; -+ // load warp tile from Shared Memory accumulator -+ using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, cutlass::gemm::Operand::kA, -+ ElementA, SmemAccumulatorLayout, -+ MatrixShape, -+ WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>; -+ -+ // Define the Mma -+ using B2bMma = threadblock::B2bImplicitGemmMultistageSmemAccumulator< -+ ThreadblockShape0, -+ IteratorA0, -+ SmemIteratorA0, -+ arch::CacheOperation::Always, -+ IteratorB0, -+ SmemIteratorB0, -+ arch::CacheOperation::Global, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorAccumulator, -+ SmemIteratorD0, -+ ThreadblockShape1, -+ WarpIteratorA1, -+ IteratorB1, -+ SmemIteratorB1, -+ arch::CacheOperation::Global, -+ EpilogueOutputOp0, -+ MmaPolicy0, -+ MmaPolicy1, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape1, -+ WarpMmaTensorOp1, -+ 1, -+ EpilogueOutputOp1, -+ EpilogueOutputOp1::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< -+ B2bMma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage -+/// pipeline with interleaved layout. -+/// Accumulator will be staged in shared memory. -+template < -+ typename ElementA, -+ typename ElementB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape0, -+ typename ThreadblockShape1, -+ typename WarpShape0, -+ typename WarpShape1, -+ typename InstructionShape, -+ typename EpilogueOutputOp0, -+ typename EpilogueOutputOp1, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ int InterleavedK -+> -+struct DefaultB2bConv2dFprop < -+ ElementA, -+ layout::TensorNCxHWx, -+ ElementB, -+ layout::TensorCxRSKx, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ true -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ Stages, MathOperatorTag, true>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ Stages, MathOperatorTag, true>; -+ -+ // Define iterators over tiles from the A operand -+ // Note GEMM shared memory threadmap is used here because conv global memory -+ // layout needs to be mapped to fprop which is similar to the crosswise -+ // layout which is used by the interleaved GEMM shared memory threadmap. -+ // The Interleaved GEMM global memory layout is similar to the congruous -+ // layout. -+ using ThreadMapA0 = typename MmaCore0::SmemThreadMapA; -+ using IteratorA0 = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, layout::TensorNCxHWx, -+ ThreadMapA0 -+ >; -+ -+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ // Note GEMM shared memory threadmap is used here because conv global memory -+ // layout needs to be mapped to fprop which is similar to the crosswise -+ // layout which is used by the interleaved GEMM shared memory threadmap. -+ // The Interleaved GEMM global memory layout is similar to the congruous -+ // layout. -+ using ThreadMapB0 = typename MmaCore0::SmemThreadMapB; -+ using IteratorB0 = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, layout::TensorCxRSKx, -+ ThreadMapB0 -+ >; -+ -+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 4; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ using ThreadMapB1 = typename MmaCore1::SmemThreadMapB; -+ using IteratorB1 = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, layout::TensorCxRSKx, -+ ThreadMapB1 -+ >; -+ -+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ using MmaPolicy0 = typename MmaCore0::MmaPolicy; -+ using MmaPolicy1 = typename MmaCore1::MmaPolicy; -+ -+ // Use fragment iterator for the accumulator -+ using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>; -+ using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ WarpShape0, InstructionShape, -+ ElementAccumulator, -+ typename WarpMmaTensorOp0::Policy::Operator::FragmentC, -+ SmemAccumulatorLayout -+ >; -+ -+ -+ // Store Accumulator tiles to Shared Memory -+ using SmemIteratorD0 = -+ cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape0, -+ InstructionShape, -+ ElementC, -+ SmemAccumulatorLayout -+ >; -+ -+ static int const kThreadCount = 32; -+ // load warp tile from Shared Memory accumulator -+ using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical< -+ MatrixShape, cutlass::gemm::Operand::kA, -+ ElementA, SmemAccumulatorLayout, -+ MatrixShape, -+ WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>; -+ -+ // Define the Mma -+ using B2bMma = threadblock::B2bImplicitGemmMultistageSmemAccumulator< -+ ThreadblockShape0, -+ IteratorA0, -+ SmemIteratorA0, -+ arch::CacheOperation::Always, -+ IteratorB0, -+ SmemIteratorB0, -+ arch::CacheOperation::Global, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorAccumulator, -+ SmemIteratorD0, -+ ThreadblockShape1, -+ WarpIteratorA1, -+ IteratorB1, -+ SmemIteratorB1, -+ arch::CacheOperation::Global, -+ EpilogueOutputOp0, -+ MmaPolicy0, -+ MmaPolicy1, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< -+ ThreadblockShape1, -+ WarpMmaTensorOp1, -+ 1, -+ EpilogueOutputOp1, -+ EpilogueOutputOp1::kCount, -+ InterleavedK -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< -+ B2bMma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and -+/// multistage pipeline. -+/// Accumulator will be staged in shared memory. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape0, -+ typename ThreadblockShape1, -+ typename WarpShape0, -+ typename WarpShape1, -+ typename InstructionShape, -+ typename EpilogueOutputOp0, -+ typename EpilogueOutputOp1, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag -+> -+struct DefaultB2bConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ true -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; -+ using IteratorA0 = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA0 -+ >; -+ -+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; -+ using IteratorB0 = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB0 -+ >; -+ -+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 2; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; -+ using IteratorB1 = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB1 -+ >; -+ -+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ using MmaPolicy0 = typename MmaCore0::MmaPolicy; -+ using MmaPolicy1 = typename MmaCore1::MmaPolicy; -+ -+ // Use fragment iterator for the accumulator -+ using SmemAccumulatorLayout = cutlass::layout::RowMajor; -+ using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ WarpShape0, InstructionShape, -+ ElementAccumulator, -+ typename WarpMmaTensorOp0::Policy::Operator::FragmentC, -+ SmemAccumulatorLayout -+ >; -+ -+ // Store Accumulator tiles to Shared Memory -+ using SmemIteratorD0 = -+ cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape0, -+ InstructionShape, -+ ElementC, -+ SmemAccumulatorLayout -+ >; -+ -+ static int const kThreadCount = 32; -+ // load warp tile from Shared Memory accumulator -+ using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, cutlass::gemm::Operand::kA, -+ ElementA, SmemAccumulatorLayout, -+ MatrixShape, -+ WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>; -+ -+ // Define the Mma -+ using B2bMma = threadblock::B2bImplicitGemmMultistageSmemAccumulator< -+ ThreadblockShape0, -+ IteratorA0, -+ SmemIteratorA0, -+ arch::CacheOperation::Always, -+ IteratorB0, -+ SmemIteratorB0, -+ arch::CacheOperation::Global, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorAccumulator, -+ SmemIteratorD0, -+ ThreadblockShape1, -+ WarpIteratorA1, -+ IteratorB1, -+ SmemIteratorB1, -+ arch::CacheOperation::Global, -+ EpilogueOutputOp0, -+ MmaPolicy0, -+ MmaPolicy1, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape1, -+ WarpMmaTensorOp1, -+ 1, -+ EpilogueOutputOp1, -+ EpilogueOutputOp1::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< -+ B2bMma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Optimzed IteratorAlgorithm and -+// multistage pipeline with interleaved layout. -+/// Accumulator will be staged in shared memory. -+template < -+ typename ElementA, -+ typename ElementB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape0, -+ typename ThreadblockShape1, -+ typename WarpShape0, -+ typename WarpShape1, -+ typename InstructionShape, -+ typename EpilogueOutputOp0, -+ typename EpilogueOutputOp1, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ int InterleavedK -+> -+struct DefaultB2bConv2dFprop < -+ ElementA, -+ layout::TensorNCxHWx, -+ ElementB, -+ layout::TensorCxRSKx, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ true -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ Stages, MathOperatorTag, true>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ Stages, MathOperatorTag, true>; -+ -+ // Define iterators over tiles from the A operand -+ // Note GEMM shared memory threadmap is used here because conv global memory -+ // layout needs to be mapped to fprop which is similar to the crosswise -+ // layout which is used by the interleaved GEMM shared memory threadmap. -+ // The Interleaved GEMM global memory layout is similar to the congruous -+ // layout. -+ using ThreadMapA0 = typename MmaCore0::SmemThreadMapA; -+ using IteratorA0 = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, layout::TensorNCxHWx, -+ ThreadMapA0 -+ >; -+ -+ using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ // Note GEMM shared memory threadmap is used here because conv global memory -+ // layout needs to be mapped to fprop which is similar to the crosswise -+ // layout which is used by the interleaved GEMM shared memory threadmap. -+ // The Interleaved GEMM global memory layout is similar to the congruous -+ // layout. -+ using ThreadMapB0 = typename MmaCore0::SmemThreadMapB; -+ using IteratorB0 = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, layout::TensorCxRSKx, -+ ThreadMapB0 -+ >; -+ -+ using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 4; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ using ThreadMapB1 = typename MmaCore1::SmemThreadMapB; -+ using IteratorB1 = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, layout::TensorCxRSKx, -+ ThreadMapB1 -+ >; -+ -+ using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; -+ -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ using MmaPolicy0 = typename MmaCore0::MmaPolicy; -+ using MmaPolicy1 = typename MmaCore1::MmaPolicy; -+ -+ // Use fragment iterator for the accumulator -+ using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>; -+ using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ WarpShape0, InstructionShape, -+ ElementAccumulator, -+ typename WarpMmaTensorOp0::Policy::Operator::FragmentC, -+ SmemAccumulatorLayout -+ >; -+ -+ -+ // Store Accumulator tiles to Shared Memory -+ using SmemIteratorD0 = -+ cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape0, -+ InstructionShape, -+ ElementC, -+ SmemAccumulatorLayout -+ >; -+ -+ static int const kThreadCount = 32; -+ // load warp tile from Shared Memory accumulator -+ using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical< -+ MatrixShape, cutlass::gemm::Operand::kA, -+ ElementA, SmemAccumulatorLayout, -+ MatrixShape, -+ WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>; -+ -+ // Define the Mma -+ using B2bMma = threadblock::B2bImplicitGemmMultistageSmemAccumulator< -+ ThreadblockShape0, -+ IteratorA0, -+ SmemIteratorA0, -+ arch::CacheOperation::Always, -+ IteratorB0, -+ SmemIteratorB0, -+ arch::CacheOperation::Global, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorAccumulator, -+ SmemIteratorD0, -+ ThreadblockShape1, -+ WarpIteratorA1, -+ IteratorB1, -+ SmemIteratorB1, -+ arch::CacheOperation::Global, -+ EpilogueOutputOp0, -+ MmaPolicy0, -+ MmaPolicy1, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< -+ ThreadblockShape1, -+ WarpMmaTensorOp1, -+ 1, -+ EpilogueOutputOp1, -+ EpilogueOutputOp1::kCount, -+ InterleavedK -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< -+ B2bMma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm.h -new file mode 100644 -index 0000000..05c3f4e ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm.h -@@ -0,0 +1,442 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+ -+ Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are -+ accommodated by exchanging A and B operands and assuming transposed layouts. Partial -+ specializations here choose 'device::GemmTransposed' to implement this functionality. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/gemm_pipelined.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" -+ -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+ -+#include "kernel/b2b_gemm.h" -+#include "threadblock/default_b2b_mma.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape0, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape0, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp0, -+ /// Epilogue output operator -+ typename EpilogueOutputOp1, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Stage accumulator in shared memory -+ bool SmemAccumulator = false -+> -+struct DefaultB2bGemm; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape0, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape0, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp0, -+ /// Epilogue output operator -+ typename EpilogueOutputOp1, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultB2bGemm { -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, -+ InstructionShape, Stages, Operator, EpilogueOutputOp0>::ThreadblockB2bMma; -+ -+ static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1, -+ EpilogueOutputOp1::kCount>::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using B2bGemmKernel = kernel::B2bGemm; -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Turing Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape0, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape0, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp0, -+ /// Epilogue output operator -+ typename EpilogueOutputOp1, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// If true, kernel is configured to support serial reduction in the epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator -+> -+struct DefaultB2bGemm< -+ ElementA, LayoutA, kAlignmentA, -+ ElementB, LayoutB, kAlignmentB, -+ ElementC, layout::RowMajor, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ arch::Sm75, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ ThreadblockSwizzle, -+ 2, -+ SplitKSerial, -+ Operator -+> { -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementAccumulator, -+ layout::RowMajor, -+ arch::OpClassTensorOp, -+ arch::Sm75, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ 2, -+ Operator, -+ EpilogueOutputOp0 -+ >::ThreadblockB2bMma; -+ -+ static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; -+ -+ /// Define the epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape1, -+ typename B2bMma::Operator1, -+ kPartitionsK1, -+ EpilogueOutputOp1, -+ EpilogueOutputOp1::kCount -+ >::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using B2bGemmKernel = kernel::B2bGemm; -+}; -+ -+ -+/// Partial specialization for Ampere Integer Matrix Multiply Interleaved layout -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape0, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape0, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp0, -+ /// Epilogue output operator -+ typename EpilogueOutputOp1, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Number of Interleaved k -+ int InterleavedK, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultB2bGemm< -+ ElementA, layout::ColumnMajorInterleaved, kAlignmentA, -+ ElementB, layout::RowMajorInterleaved, kAlignmentB, -+ ElementC, layout::ColumnMajorInterleaved, int32_t, -+ arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, -+ InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1, -+ ThreadblockSwizzle, Stages, -+ SplitKSerial, Operator> { -+ using LayoutA = layout::ColumnMajorInterleaved; -+ using LayoutB = layout::RowMajorInterleaved; -+ using LayoutC = layout::ColumnMajorInterleaved; -+ -+ using ElementAccumulator = int32_t; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, -+ InstructionShape, Stages, Operator, EpilogueOutputOp0, -+ true>::ThreadblockB2bMma; -+ -+ static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; -+ -+ /// Define the epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock:: -+ DefaultInterleavedEpilogueTensorOp< -+ ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1, -+ 64 / sizeof_bits::value, InterleavedK>::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using B2bGemmKernel = kernel::B2bGemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Partial specialization for Turing Integer Tensor Core Interleaved layout -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape0, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape0, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp0, -+ /// Epilogue output operator -+ typename EpilogueOutputOp1, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of Interleaved k -+ int InterleavedK, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultB2bGemm, -+ kAlignmentA, ElementB, -+ layout::RowMajorInterleaved, kAlignmentB, -+ ElementC, layout::ColumnMajorInterleaved, -+ int32_t, arch::OpClassTensorOp, arch::Sm75, -+ ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, -+ InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1, -+ ThreadblockSwizzle, 2, SplitKSerial, Operator> { -+ using LayoutA = layout::ColumnMajorInterleaved; -+ using LayoutB = layout::RowMajorInterleaved; -+ using LayoutC = layout::ColumnMajorInterleaved; -+ -+ using ElementAccumulator = int32_t; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC, -+ arch::OpClassTensorOp, arch::Sm75, ThreadblockShape0, ThreadblockShape1, -+ WarpShape0, WarpShape1, InstructionShape, 2, Operator, EpilogueOutputOp0, true>::ThreadblockB2bMma; -+ -+ static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; -+ -+ /// Define the epilogue for the 2nd Gemm -+ using Epilogue = typename cutlass::epilogue::threadblock:: -+ DefaultInterleavedEpilogueTensorOp< -+ ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1, -+ 64 / sizeof_bits::value, InterleavedK>::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using B2bGemmKernel = kernel::B2bGemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm_smem_accumulator.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm_smem_accumulator.h -new file mode 100644 -index 0000000..23717c6 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm_smem_accumulator.h -@@ -0,0 +1,397 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+ -+ Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are -+ accommodated by exchanging A and B operands and assuming transposed layouts. Partial -+ specializations here choose 'device::GemmTransposed' to implement this functionality. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/gemm_pipelined.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" -+ -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+#include "cutlass/transform/threadblock/vector_iterator.h" -+#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h" -+ -+#include "kernel/b2b_gemm.h" -+#include "threadblock/default_b2b_mma.h" -+#include "threadblock/default_b2b_mma_smem_accumulator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape0, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape0, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp0, -+ /// Epilogue output operator -+ typename EpilogueOutputOp1, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultB2bGemm { -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, -+ InstructionShape, Stages, Operator, EpilogueOutputOp0, false, true>::ThreadblockB2bMma; -+ -+ static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1, -+ EpilogueOutputOp1::kCount>::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using B2bGemmKernel = kernel::B2bGemm; -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Turing Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape0, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape0, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp0, -+ /// Epilogue output operator -+ typename EpilogueOutputOp1, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// If true, kernel is configured to support serial reduction in the epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator -+> -+struct DefaultB2bGemm< -+ ElementA, LayoutA, kAlignmentA, -+ ElementB, LayoutB, kAlignmentB, -+ ElementC, layout::RowMajor, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ arch::Sm75, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ ThreadblockSwizzle, -+ 2, -+ SplitKSerial, -+ Operator, -+ true -+> { -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementAccumulator, -+ layout::RowMajor, -+ arch::OpClassTensorOp, -+ arch::Sm75, -+ ThreadblockShape0, -+ ThreadblockShape1, -+ WarpShape0, -+ WarpShape1, -+ InstructionShape, -+ 2, -+ Operator, -+ EpilogueOutputOp0, -+ false, -+ true -+ >::ThreadblockB2bMma; -+ -+ static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; -+ -+ /// Define the epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape1, -+ typename B2bMma::Operator1, -+ kPartitionsK1, -+ EpilogueOutputOp1, -+ EpilogueOutputOp1::kCount -+ >::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using B2bGemmKernel = kernel::B2bGemm; -+}; -+ -+ -+/// Partial specialization for Ampere Integer Matrix Multiply Interleaved layout -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape0, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape0, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp0, -+ /// Epilogue output operator -+ typename EpilogueOutputOp1, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Number of Interleaved k -+ int InterleavedK, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultB2bGemm< -+ ElementA, layout::ColumnMajorInterleaved, kAlignmentA, -+ ElementB, layout::RowMajorInterleaved, kAlignmentB, -+ ElementC, layout::ColumnMajorInterleaved, int32_t, -+ arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, -+ InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1, -+ ThreadblockSwizzle, Stages, -+ SplitKSerial, Operator, true> { -+ using LayoutA = layout::ColumnMajorInterleaved; -+ using LayoutB = layout::RowMajorInterleaved; -+ using LayoutC = layout::ColumnMajorInterleaved; -+ -+ using ElementAccumulator = int32_t; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, -+ InstructionShape, Stages, Operator, EpilogueOutputOp0, -+ true, true>::ThreadblockB2bMma; -+ -+ static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; -+ -+ /// Define the epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock:: -+ DefaultInterleavedEpilogueTensorOp< -+ ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1, -+ 64 / sizeof_bits::value, InterleavedK>::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using B2bGemmKernel = kernel::B2bGemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Partial specialization for Turing Integer Tensor Core Interleaved layout -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape0, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape0, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp0, -+ /// Epilogue output operator -+ typename EpilogueOutputOp1, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of Interleaved k -+ int InterleavedK, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultB2bGemm, -+ kAlignmentA, ElementB, -+ layout::RowMajorInterleaved, kAlignmentB, -+ ElementC, layout::ColumnMajorInterleaved, -+ int32_t, arch::OpClassTensorOp, arch::Sm75, -+ ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, -+ InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1, -+ ThreadblockSwizzle, 2, SplitKSerial, Operator, true> { -+ using LayoutA = layout::ColumnMajorInterleaved; -+ using LayoutB = layout::RowMajorInterleaved; -+ using LayoutC = layout::ColumnMajorInterleaved; -+ -+ using ElementAccumulator = int32_t; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm75, -+ ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, -+ InstructionShape, 2, Operator, EpilogueOutputOp0, true, true>::ThreadblockB2bMma; -+ -+ static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; -+ -+ /// Define the epilogue for the 2nd Gemm -+ using Epilogue = typename cutlass::epilogue::threadblock:: -+ DefaultInterleavedEpilogueTensorOp< -+ ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1, -+ 64 / sizeof_bits::value, InterleavedK>::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using B2bGemmKernel = kernel::B2bGemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/reference/device/tensor_scale_bias.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/reference/device/tensor_scale_bias.h -new file mode 100644 -index 0000000..eef9d9a ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/reference/device/tensor_scale_bias.h -@@ -0,0 +1,275 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines device-side elementwise operations on TensorView. Note, the operations defined -+ in this header are not specialized for any particular data layout and are therefore not -+ intended to offer the best possible performance. Rather, they are intended to be generic -+ reference implementations to support the CUTLASS unit tests. -+*/ -+ -+#pragma once -+ -+// Cutlass includes -+#include "cutlass/cutlass.h" -+#include "cutlass/tensor_view.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace reference { -+namespace device { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace kernel { -+ -+template < -+ typename TensorRefIn, ///< Input TensorRef Type -+ typename TensorRefOut, ///< Output TensorRef Type -+ typename ScalarType, ///< alpha Type -+ typename TensorRefScalar, ///< Scale/Bias TensorRef Type -+ typename OutputTile, -+ typename ConvertOp = NumericConverter -+> -+__global__ void TensorScaleBiasGemm( -+ gemm::GemmCoord problem_size, -+ TensorRefIn tensor_in, ///< input tensor -+ TensorRefOut tensor_out, ///< output tensor -+ ScalarType alpha, ///< alpha -+ TensorRefScalar tensor_scale, ///< scale tensor -+ TensorRefScalar tensor_bias ///< bias tensor -+) { -+ -+ ConvertOp convert_op; -+ -+ MatrixCoord output_coord( -+ MatrixCoord::Index((threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kRow), -+ MatrixCoord::Index((threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kColumn) -+ ); -+ -+ // Update the output tensor -+ for (int j = 0; j < OutputTile::kRow; ++j) { -+ for (int i = 0; i < OutputTile::kColumn; ++i) { -+ MatrixCoord coord = output_coord + MatrixCoord(i, j); -+ if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) { -+ -+ ScalarType scale = alpha; -+ if(tensor_scale.good()) -+ scale = tensor_scale.at({0, coord.column()}); -+ -+ ScalarType bias = ScalarType(0); -+ -+ if(tensor_bias.good()) -+ bias = tensor_bias.at({0, coord.column()}); -+ -+ tensor_out.at(coord) = convert_op( -+ scale * ScalarType(tensor_in.at(coord)) + bias); -+ } -+ } -+ } -+} -+ -+template < -+ typename TensorRefIn, ///< Input TensorRef Type -+ typename TensorRefOut, ///< Output TensorRef Type -+ typename ScalarType, ///< alpha Type -+ typename TensorRefScalar, ///< Scale/Bias TensorRef Type -+ typename ConvertOp = NumericConverter, -+ int kThreadM = 4, // shape of a thread's tile in the GEMM M dimension -+ int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension -+ int kCtaShapeM = 16, // shape of a threadblock in units of threads -+ int kCtaShapeN = 8 // shape of a threadblock in units of threads -+> -+__global__ void TensorScaleBiasConv2d( -+ conv::Conv2dProblemSize problem_size, -+ TensorRefIn tensor_in, ///< input tensor -+ TensorRefOut tensor_out, ///< output tensor -+ ScalarType alpha, ///< alpha -+ TensorRefScalar tensor_scale, ///< scale tensor -+ TensorRefScalar tensor_bias ///< bias tensor -+) { -+ -+ ConvertOp convert_op; -+ -+ int64_t npq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; -+ int k_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; -+ -+ int thread_n[kThreadM]; -+ int thread_p[kThreadM]; -+ int thread_q[kThreadM]; -+ -+ // Compute N, P, Q coordinates for each row of a thread's tile -+ int64_t PQ = int64_t(problem_size.P) * problem_size.Q; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ -+ int64_t npq = npq_start + m; -+ -+ thread_n[m] = int(npq / PQ); -+ -+ int64_t residual = npq % PQ; -+ thread_p[m] = int(residual / problem_size.Q); -+ thread_q[m] = int(residual % problem_size.Q); -+ } -+ -+ // Write out the results -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ if (thread_n[m] < problem_size.N && thread_p[m] < problem_size.P && thread_q[m] < problem_size.Q) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ int thread_k = k_start + n; -+ if (thread_k < problem_size.K) { -+ -+ ScalarType scale = alpha; -+ if(tensor_scale.good()) -+ scale = tensor_scale.at({0, thread_k}); -+ -+ ScalarType bias = ScalarType(0); -+ if(tensor_bias.good()) -+ bias = tensor_bias.at({0, thread_k}); -+ -+ tensor_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op( -+ scale * ScalarType( -+ tensor_in.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) -+ ) + bias); -+ } -+ } -+ } -+ } -+ -+} -+ -+} -+ -+/// Apply scale and bias on a tensor -+template < -+ typename ElementIn, ///< Input Type -+ typename ElementOut, ///< Output Type -+ typename Layout, ///< Layout of input/output tensor -+ typename ScalarType, ///< alpha Type -+ typename LayoutScaleBias, ///< Layout of scale and bias -+ typename ConvertOp = NumericConverter -+> -+void TensorScaleBiasGemm( -+ gemm::GemmCoord problem_size, -+ TensorRef tensor_in, ///< input tensor -+ TensorRef tensor_out, ///< output tensor -+ ScalarType alpha, ///< alpha -+ TensorRef tensor_scale, ///< scale tensor -+ TensorRef tensor_bias ///< bias tensor -+) { -+ -+ using OutputTile = MatrixShape<4, 4>; -+ -+ dim3 block(16, 8); -+ -+ dim3 grid( -+ (problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow), -+ (problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn) -+ ); -+ -+ kernel::TensorScaleBiasGemm< -+ TensorRef, -+ TensorRef, -+ ScalarType, -+ TensorRef, -+ OutputTile, -+ ConvertOp -+ ><<< grid, block >>> ( -+ problem_size, -+ tensor_in, -+ tensor_out, -+ alpha, -+ tensor_scale, -+ tensor_bias -+ ); -+} -+ -+/// Apply scale and bias on a tensor -+template < -+ typename ElementIn, ///< Input Type -+ typename ElementOut, ///< Output Type -+ typename Layout, ///< Layout of input/output tensor -+ typename ScalarType, ///< alpha Type -+ typename LayoutScaleBias, ///< Layout of scale and bias -+ typename ConvertOp = NumericConverter -+> -+void TensorScaleBiasConv2d( -+ conv::Conv2dProblemSize problem_size, -+ TensorRef tensor_in, ///< input tensor -+ TensorRef tensor_out, ///< output tensor -+ ScalarType alpha, ///< alpha -+ TensorRef tensor_scale, ///< scale tensor -+ TensorRef tensor_bias ///< bias tensor -+) { -+ -+ int const kThreadM = 4; // shape of a thread's tile in the GEMM M dimension -+ int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension -+ int const kCtaShapeM = 16; // shape of a threadblock in units of threads -+ int const kCtaShapeN = 8; // shape of a threadblock in units of threads -+ -+ int64_t npq = int64_t(problem_size.N) * problem_size.P * problem_size.Q; -+ int64_t blocks_m = (npq + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM); -+ -+ dim3 block(kCtaShapeM, kCtaShapeN); -+ dim3 grid(uint32_t(blocks_m), (problem_size.K + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN)); -+ -+ -+ kernel::TensorScaleBiasConv2d< -+ TensorRef, -+ TensorRef, -+ ScalarType, -+ TensorRef, -+ ConvertOp, -+ kThreadM, -+ kThreadN, -+ kCtaShapeM, -+ kCtaShapeN -+ ><<< grid, block >>> ( -+ problem_size, -+ tensor_in, -+ tensor_out, -+ alpha, -+ tensor_scale, -+ tensor_bias -+ ); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/test_run.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/test_run.h -new file mode 100644 -index 0000000..b64f31f ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/test_run.h -@@ -0,0 +1,95 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+ -+#include -+ -+// Run tests on GPUs -+ -+int testRun(int arch, std::vector & test_funcs, const std::string & test_name) { -+ -+ bool supported = false; -+ -+ int arch_major = arch / 10; -+ int arch_minor = arch - arch / 10 * 10; -+ -+ if(arch_major >= 8) { -+ // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. -+ // -+ // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples. -+ if (__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) { -+ supported = true; -+ } -+ } -+ else if(arch_major >= 7) { -+ // Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2. -+ // -+ // CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples. -+ if (__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)) { -+ supported = true; -+ } -+ } -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return -1; -+ } -+ -+ if (!(props.major == arch_major && props.minor == arch_minor)) { -+ supported = false; -+ } -+ -+ if (!supported) { -+ // Returning zero so this test passes on older Toolkits. Its actions are no-op. -+ std::cout << "This example isn't supported on current architecture" << std::endl; -+ return 0; -+ } -+ -+ bool pass = true; -+ -+ std::cout << "Device: " << props.name << std::endl; -+ std::cout << "Arch: SM" << arch << std::endl; -+ std::cout << "Test: " << test_name << std::endl; -+ for(auto func : test_funcs) { -+ pass &= func(); -+ } -+ -+ -+ if(pass) -+ return 0; -+ else -+ return -1; -+ -+} -+ -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_multistage.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_multistage.h -new file mode 100644 -index 0000000..4e154f5 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_multistage.h -@@ -0,0 +1,831 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a multistage threadblock-scoped Implicit GEMM Convolution kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/cache_operation.h" -+#include "cutlass/gemm/threadblock/mma_base.h" -+#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" -+ -+#include "threadblock/b2b_mma_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape0_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA0_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA0_, -+ /// Cache operation for operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA0, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB0_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB0_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB0, -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape1_, -+ /// Iterates over the intermediate accumulator tile -+ // (concept::MmaTensorOpFragmentIterator) -+ typename FragmentIteratorA1_, -+ /// Iterates over vectors of scale and bias vector in global memory -+ // (concept: VectorIterator) -+ typename IteratorAccumulatorScaleBias_, -+ /// WarpIterator to load Scale or Bias vector from threadblock fragment -+ typename FragmentIteratorA1ScaleBias_, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB1_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB1_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB1, -+ /// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...) -+ typename OutputOp_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy0_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy1_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class B2bImplicitGemmMultistage : -+ public gemm::threadblock::B2bMmaBase { -+public: -+ ///< Base class -+ using Base = gemm::threadblock::B2bMmaBase; -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape0 = Shape0_; -+ ///< Iterates over tiles of A operand in global memory -+ using IteratorA0 = IteratorA0_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB0 = IteratorB0_; -+ ///< Policy describing tuning details -+ using Policy0 = Policy0_; -+ -+ using SmemIteratorA0 = SmemIteratorA0_; -+ using SmemIteratorB0 = SmemIteratorB0_; -+ -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape1 = Shape1_; -+ ///< Iterates over tiles of A operand in global memory -+ using FragmentIteratorA1 = FragmentIteratorA1_; -+ ///< Iterates over tiles of the scale and bias vectors in global memory -+ using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; -+ ///< WarpIterator to load Scale or Bias vector from threadblock fragment -+ using FragmentIteratorA1ScaleBias = FragmentIteratorA1ScaleBias_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB1 = IteratorB1_; -+ ///< Policy describing tuning details -+ using Policy1 = Policy1_; -+ -+ using SmemIteratorB1 = SmemIteratorB1_; -+ -+ ///< Epilogue after 1st Gemm -+ using OutputOp = OutputOp_; -+ -+ static const bool PerChannelScale = (OutputOp::kScale == -+ epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling); -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA0 = CacheOpA0; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB0 = CacheOpB0; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1; -+ -+ // -+ // Dependent types -+ // -+ -+ using ElementC = typename Policy0::Operator::ElementC; -+ -+ /// Fragment of accumulator tile -+ using FragmentC0 = typename Policy0::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator0 = typename Policy0::Operator; -+ -+ /// Fragment of Scale and Bias loaded from global memory -+ using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment; -+ -+ /// Fragment of accumulator tile -+ using FragmentC1 = typename Policy1::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator1 = typename Policy1::Operator; -+ -+ /// Internal structure exposed for introspection. -+ struct Detail { -+ -+ static_assert(Base::kWarpGemmIterations0 > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ static_assert(Base::kWarpGemmIterations1 > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ -+ /// Number of cp.async instructions to load one stage of operand A -+ static int const AsyncCopyIterationsPerStageA0 = -+ IteratorA0::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const AsyncCopyIterationsPerStageB0 = -+ IteratorB0::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const AsyncCopyIterationsPerStageB1 = -+ IteratorB1::ThreadMap::Iterations::kCount; -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Number of cp.async instructions to load on group of operand A -+ static int const kAccessesPerGroupA0 = -+ (AsyncCopyIterationsPerStageA0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB0 = -+ (AsyncCopyIterationsPerStageB0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB1 = -+ (AsyncCopyIterationsPerStageB1 + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1; -+ }; -+ -+ private: -+ -+ using WarpLoadedFragmentA0 = typename Operator0::FragmentA; -+ using WarpLoadedFragmentB0 = typename Operator0::FragmentB; -+ /// Warp Fragment of operand A1 loaded from accmulator tile -+ using WarpLoadedFragmentA1 = typename FragmentIteratorA1::Fragment; -+ using WarpLoadedFragmentA1ScaleBias = -+ typename FragmentIteratorA1ScaleBias::Fragment; -+ using WarpLoadedFragmentB1 = typename Operator1::FragmentB; -+ using WarpTransformedFragmentA0 = typename Operator0::TransformedFragmentA; -+ using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB; -+ using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA; -+ using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA0 smem_iterator_A0_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB0 smem_iterator_B0_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB1 smem_iterator_B1_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ B2bImplicitGemmMultistage( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::B2bMmaSharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A0_(shared_storage.shared_storage0.operand_A_ref(), thread_idx), -+ smem_iterator_B0_(shared_storage.shared_storage0.operand_B_ref(), thread_idx), -+ smem_iterator_B1_(shared_storage.shared_storage1.operand_B_ref(), thread_idx) -+ { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount0::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount0::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A0_.add_tile_offset( -+ {warp_idx_m, Base::kWarpGemmIterations0 * warp_idx_k}); -+ this->warp_tile_iterator_B0_.add_tile_offset( -+ {Base::kWarpGemmIterations0 * warp_idx_k, warp_idx_n}); -+ this->warp_tile_iterator_B1_.add_tile_offset( -+ {Base::kWarpGemmIterations1 * warp_idx_k, warp_idx_n}); -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance_0( -+ IteratorA0 &iterator_A0, IteratorB0 &iterator_B0, -+ int group_start_A0 = 0, int group_start_B0 = 0) { -+ -+ iterator_A0.set_iteration_index(group_start_A0); -+ this->smem_iterator_A0_.set_iteration_index(group_start_A0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupA0; ++j) { -+ -+ if (group_start_A0 + j < Detail::AsyncCopyIterationsPerStageA0) { -+ typename IteratorA0::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A0_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA0::ThreadMap::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr, iterator_A0.get(), iterator_A0.valid()); -+ -+ ++iterator_A0; -+ -+ ++this->smem_iterator_A0_; -+ } -+ } -+ -+ iterator_B0.set_iteration_index(group_start_B0); -+ -+ this->smem_iterator_B0_.set_iteration_index(group_start_B0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB0; ++j) { -+ if (group_start_B0 + j < Detail::AsyncCopyIterationsPerStageB0) { -+ typename IteratorB0::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B0_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB0::ThreadMap::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr, iterator_B0.get(), iterator_B0.valid()); -+ -+ ++iterator_B0; -+ ++this->smem_iterator_B0_; -+ } -+ } -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance_1( -+ IteratorB1 &iterator_B1, -+ int group_start_B1 = 0) { -+ -+ iterator_B1.set_iteration_index(group_start_B1); -+ -+ this->smem_iterator_B1_.set_iteration_index(group_start_B1); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) { -+ if (group_start_B1 + j < Detail::AsyncCopyIterationsPerStageB1) { -+ typename IteratorB1::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B1_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB1::ThreadMap::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr, iterator_B1.get(), iterator_B1.valid()); -+ -+ ++iterator_B1; -+ ++this->smem_iterator_B1_; -+ } -+ } -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations_0, -+ ///< destination accumulator tile -+ FragmentC1 &accum, -+ ///< iterator over A0 operand in global memory -+ IteratorA0 iterator_A0, -+ ///< iterator over B0 operand in global memory -+ IteratorB0 iterator_B0, -+ ///< iterator over A1 operand scale vector in global memory -+ IteratorAccumulatorScaleBias iterator_A1_scale, -+ ///< iterator over A1 operand bias vector in global memory -+ IteratorAccumulatorScaleBias iterator_A1_bias, -+ ///< iterator over B1 operand in global memory -+ IteratorB1 iterator_B1, -+ ///< initial value of accumulator -+ FragmentC0 const &src_accum, -+ ///< epilogue operation after 1st Gemm -+ OutputOp output_op_0, -+ ///< Imaginary strides used for planar-complex only - ignored here -+ int64_t imag_stride_A = 0, -+ int64_t imag_stride_B = 0) { -+ -+ // -+ // Prologue -+ // -+ -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; -+ ++stage, --gemm_k_iterations_0) { -+ -+ iterator_A0.set_iteration_index(0); -+ this->smem_iterator_A0_.set_iteration_index(0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA0; ++j) { -+ typename IteratorA0::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A0_.get()); -+ -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorA0::ThreadMap::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr, iterator_A0.get(), iterator_A0.valid()); -+ -+ ++iterator_A0; -+ ++this->smem_iterator_A0_; -+ } -+ -+ iterator_B0.set_iteration_index(0); -+ this->smem_iterator_B0_.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB0; ++j) { -+ typename IteratorB0::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B0_.get()); -+ -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB0::ThreadMap::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr, iterator_B0.get(), iterator_B0.valid()); -+ -+ ++iterator_B0; -+ ++this->smem_iterator_B0_; -+ } -+ -+ // Move to the next stage -+ iterator_A0.advance(); -+ iterator_B0.advance(); -+ -+ this->smem_iterator_A0_.add_tile_offset({0, 1}); -+ this->smem_iterator_B0_.add_tile_offset({1, 0}); -+ -+ // Inserts a fence to group cp.async instructions into stages. -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // Perform accumulation in the 'd' output operand -+ FragmentC0 accum0 = src_accum; -+ -+ // Waits until kStages-2 stages have committed. -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA0 warp_loaded_frag_A0[2]; -+ WarpLoadedFragmentB0 warp_loaded_frag_B0[2]; -+ WarpTransformedFragmentA0 warp_transformed_frag_A0[2]; -+ WarpTransformedFragmentB0 warp_transformed_frag_B0[2]; -+ -+ Operator0 warp_mma0; -+ -+ this->warp_tile_iterator_A0_.set_kgroup_index(0); -+ this->warp_tile_iterator_B0_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[0]); -+ this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[0]); -+ -+ ++this->warp_tile_iterator_A0_; -+ ++this->warp_tile_iterator_B0_; -+ -+ // Start issuing the first group of the next stage outside of the mainloop -+ copy_tiles_and_advance_0(iterator_A0, iterator_B0); -+ -+ int smem_write_stage_idx = Base::kStages - 1; -+ int smem_read_stage_idx = 0; -+ -+ warp_mma0.transform(warp_transformed_frag_A0[0], warp_transformed_frag_B0[0], -+ warp_loaded_frag_A0[0], warp_loaded_frag_B0[0]); -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations_0 > (-Base::kStages + 1);) { -+ -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; -+ ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. -+ -+ this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); -+ this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); -+ -+ this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A0_; -+ ++this->warp_tile_iterator_B0_; -+ -+ if (warp_mma_k > 0) -+ warp_mma0.transform(warp_transformed_frag_A0[warp_mma_k % 2], -+ warp_transformed_frag_B0[warp_mma_k % 2], -+ warp_loaded_frag_A0[warp_mma_k % 2], -+ warp_loaded_frag_B0[warp_mma_k % 2]); -+ -+ // Issue global->shared copies for the next stage -+ int group_start_iteration_A0, group_start_iteration_B0; -+ -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations0) { -+ group_start_iteration_A0 = 0; -+ group_start_iteration_B0 = 0; -+ } else { -+ group_start_iteration_A0 = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupA0; -+ group_start_iteration_B0 = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB0; -+ } -+ -+ copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0, -+ group_start_iteration_B0); -+ -+ warp_mma0( -+ accum0, -+ warp_transformed_frag_A0[warp_mma_k % 2], -+ warp_transformed_frag_B0[warp_mma_k % 2], -+ accum0 -+ ); -+ -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations0) -+ warp_mma0.transform(warp_transformed_frag_A0[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B0[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A0[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations0) { -+ // Inserts a fence to group cp.async instructions into stages. -+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages of cp.async have committed -+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_A0.advance(); -+ iterator_B0.advance(); -+ -+ this->smem_iterator_A0_.add_tile_offset({0, 1}); -+ this->smem_iterator_B0_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_A0_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_A0_.add_tile_offset( -+ {0, -Base::kStages * Policy0::kPartitionsK * -+ Base::kWarpGemmIterations0}); -+ this->warp_tile_iterator_B0_.add_tile_offset( -+ {-Base::kStages * Policy0::kPartitionsK * -+ Base::kWarpGemmIterations0, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ --gemm_k_iterations_0; -+ } -+ } -+ -+ } -+ -+ // Insert fence and wait for all outstanding cp.async operations to commit. -+ cutlass::arch::cp_async_fence(); -+ cutlass::arch::cp_async_wait<0>(); -+ __syncthreads(); -+ -+ -+ // 2nd Implicit Gemm -+ -+ /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile -+ FragmentIteratorA1 warp_tile_iterator_A1_(accum0); -+ FragmentA1ScaleBias tb_frag_A1_scale; -+ FragmentA1ScaleBias tb_frag_A1_bias; -+ FragmentIteratorA1ScaleBias warp_tile_iterator_A1_scale_(tb_frag_A1_scale); -+ FragmentIteratorA1ScaleBias warp_tile_iterator_A1_bias_(tb_frag_A1_bias); -+ -+ if(PerChannelScale) { -+ tb_frag_A1_scale.clear(); -+ iterator_A1_scale.load(tb_frag_A1_scale); -+ ++iterator_A1_scale; -+ } -+ tb_frag_A1_bias.clear(); -+ iterator_A1_bias.load(tb_frag_A1_bias); -+ ++iterator_A1_bias; -+ -+ -+ // -+ // Prologue -+ // -+ int gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1; -+ -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; -+ ++stage, --gemm_k_iterations_1) { -+ -+ iterator_B1.set_iteration_index(0); -+ this->smem_iterator_B1_.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB1; ++j) { -+ typename IteratorB1::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B1_.get()); -+ -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB1::ThreadMap::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr, iterator_B1.get(), iterator_B1.valid()); -+ -+ ++iterator_B1; -+ ++this->smem_iterator_B1_; -+ } -+ -+ // Move to the next stage -+ iterator_B1.advance(); -+ -+ this->smem_iterator_B1_.add_tile_offset({1, 0}); -+ -+ // Inserts a fence to group cp.async instructions into stages. -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // Waits until kStages-2 stages have committed. -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA1 warp_loaded_frag_A1[2]; -+ WarpLoadedFragmentA1ScaleBias warp_loaded_frag_A1_scale[2]; -+ WarpLoadedFragmentA1ScaleBias warp_loaded_frag_A1_bias[2]; -+ WarpLoadedFragmentB1 warp_loaded_frag_B1[2]; -+ WarpTransformedFragmentA1 warp_transformed_frag_A1[2]; -+ WarpTransformedFragmentB1 warp_transformed_frag_B1[2]; -+ -+ Operator1 warp_mma1; -+ -+ if(PerChannelScale) { -+ warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[0]); -+ ++warp_tile_iterator_A1_scale_; -+ } -+ warp_tile_iterator_A1_bias_.load(warp_loaded_frag_A1_bias[0]); -+ ++warp_tile_iterator_A1_bias_; -+ -+ warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0], -+ warp_loaded_frag_A1_scale[0], -+ warp_loaded_frag_A1_bias[0], -+ output_op_0); -+ ++warp_tile_iterator_A1_; -+ -+ this->warp_tile_iterator_B1_.set_kgroup_index(0); -+ this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]); -+ ++this->warp_tile_iterator_B1_; -+ -+ // Start issuing the first group of the next stage outside of the mainloop -+ copy_tiles_and_advance_1(iterator_B1); -+ -+ smem_write_stage_idx = Base::kStages - 1; -+ smem_read_stage_idx = 0; -+ -+ warp_mma1.transform(warp_transformed_frag_A1[0], warp_transformed_frag_B1[0], -+ warp_loaded_frag_A1[0], warp_loaded_frag_B1[0]); -+ -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1 - (Base::kStages - 1); -+ gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; -+ ++warp_mma_k) { -+ -+ // Load threadblock-level scale/bias vector from global memory -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations1) { -+ if(PerChannelScale) { -+ tb_frag_A1_scale.clear(); -+ iterator_A1_scale.load(tb_frag_A1_scale); -+ ++iterator_A1_scale; -+ } -+ tb_frag_A1_bias.clear(); -+ iterator_A1_bias.load(tb_frag_A1_bias); -+ ++iterator_A1_bias; -+ } -+ -+ // Load warp-level scale bias fragment from threadblock scale/bias vector -+ if(PerChannelScale) { -+ warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]); -+ ++warp_tile_iterator_A1_scale_; -+ } -+ warp_tile_iterator_A1_bias_.load(warp_loaded_frag_A1_bias[(warp_mma_k + 1) % 2]); -+ ++warp_tile_iterator_A1_bias_; -+ -+ // Load warp-level tile from accumulator fragment -+ warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A1_bias[(warp_mma_k + 1) % 2], -+ output_op_0); -+ ++warp_tile_iterator_A1_; -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. -+ this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); -+ this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); -+ ++this->warp_tile_iterator_B1_; -+ -+ if (warp_mma_k > 0) -+ warp_mma1.transform(warp_transformed_frag_A1[warp_mma_k % 2], -+ warp_transformed_frag_B1[warp_mma_k % 2], -+ warp_loaded_frag_A1[warp_mma_k % 2], -+ warp_loaded_frag_B1[warp_mma_k % 2]); -+ -+ // Issue global->shared copies for the next stage -+ int group_start_iteration_B1; -+ -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations1) { -+ group_start_iteration_B1 = 0; -+ } else { -+ group_start_iteration_B1 = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB1; -+ } -+ -+ copy_tiles_and_advance_1(iterator_B1, -+ group_start_iteration_B1); -+ -+ warp_mma1( -+ accum, -+ warp_transformed_frag_A1[warp_mma_k % 2], -+ warp_transformed_frag_B1[warp_mma_k % 2], -+ accum -+ ); -+ -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations1) -+ warp_mma1.transform(warp_transformed_frag_A1[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B1[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A1[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations1) { -+ // Inserts a fence to group cp.async instructions into stages. -+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages of cp.async have committed -+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_B1.advance(); -+ -+ this->smem_iterator_B1_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_B1_.add_tile_offset( -+ {-Base::kStages * Policy1::kPartitionsK * -+ Base::kWarpGemmIterations1, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ } -+ } -+ -+ } -+ -+ // Insert fence and wait for all outstanding cp.async operations to commit. -+ cutlass::arch::cp_async_fence(); -+ cutlass::arch::cp_async_wait<0>(); -+ __syncthreads(); -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_multistage_smem_accumulator.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_multistage_smem_accumulator.h -new file mode 100644 -index 0000000..7c6793a ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_multistage_smem_accumulator.h -@@ -0,0 +1,816 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a multistage threadblock-scoped Implicit GEMM Convolution kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/cache_operation.h" -+#include "cutlass/gemm/threadblock/mma_base.h" -+#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" -+ -+#include "threadblock/b2b_mma_base_smem_accumulator.h" -+#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape0_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA0_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA0_, -+ /// Cache operation for operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA0, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB0_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB0_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB0, -+ /// Iterates over vectors of scale and bias vector in global memory -+ // (concept: VectorIterator) -+ typename IteratorAccumulatorScaleBias_, -+ /// Iterates over accumulator tile -+ typename FragmentIteratorAccumulator_, -+ /// Iterates over accumulator tile in shared memory -+ typename SmemIteratorD0_, -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape1_, -+ /// Iterates over the intermediate accumulator tile -+ // (concept::MmaTensorOpFragmentIterator) -+ typename WarpIteratorA1_, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB1_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB1_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB1, -+ /// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...) -+ typename OutputOp_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy0_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy1_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class B2bImplicitGemmMultistageSmemAccumulator : -+ public gemm::threadblock::B2bMmaBaseSmemAccumulator { -+public: -+ ///< Base class -+ using Base = gemm::threadblock::B2bMmaBaseSmemAccumulator; -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape0 = Shape0_; -+ ///< Iterates over tiles of A operand in global memory -+ using IteratorA0 = IteratorA0_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB0 = IteratorB0_; -+ ///< Iterates over tiles of the scale and bias vectors in global memory -+ using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; -+ ///< Policy describing tuning details -+ using Policy0 = Policy0_; -+ -+ using SmemIteratorA0 = SmemIteratorA0_; -+ using SmemIteratorB0 = SmemIteratorB0_; -+ using SmemIteratorD0 = SmemIteratorD0_; ///< Iterates over accumulator tile in shared memory -+ -+ using FragmentIteratorAccumulator = FragmentIteratorAccumulator_; ///< Iterates over accumulator tile -+ -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape1 = Shape1_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB1 = IteratorB1_; -+ ///< Policy describing tuning details -+ using Policy1 = Policy1_; -+ -+ using SmemIteratorB1 = SmemIteratorB1_; -+ using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate accumulator tile in shared memory -+ -+ ///< Epilogue after 1st Gemm -+ using OutputOp = OutputOp_; -+ -+ static const bool PerChannelScale = (OutputOp::kScale == -+ epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling); -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA0 = CacheOpA0; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB0 = CacheOpB0; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1; -+ -+ // -+ // Dependent types -+ // -+ -+ using ElementC = typename Policy0::Operator::ElementC; -+ -+ /// Fragment of accumulator tile -+ using FragmentC0 = typename Policy0::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator0 = typename Policy0::Operator; -+ -+ /// Fragment of Scale and Bias loaded from global memory -+ using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment; -+ -+ /// Fragment of accumulator tile -+ using FragmentC1 = typename Policy1::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator1 = typename Policy1::Operator; -+ -+ /// Epilog in shared memory -+ using Epilogue0 = epilogue::threadblock::EpilogueSmemAccumulator< -+ SmemIteratorD0, ///< SmemTileIterator -+ FragmentIteratorAccumulator, ///< AccumulatorFragmentIterator -+ IteratorAccumulatorScaleBias, ///< ScaleBiasIterator -+ OutputOp>; ///< Output operator -+ -+ /// Internal structure exposed for introspection. -+ struct Detail { -+ -+ static_assert(Base::kWarpGemmIterations0 > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ static_assert(Base::kWarpGemmIterations1 > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ -+ /// Number of cp.async instructions to load one stage of operand A -+ static int const AsyncCopyIterationsPerStageA0 = -+ IteratorA0::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const AsyncCopyIterationsPerStageB0 = -+ IteratorB0::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const AsyncCopyIterationsPerStageB1 = -+ IteratorB1::ThreadMap::Iterations::kCount; -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Number of cp.async instructions to load on group of operand A -+ static int const kAccessesPerGroupA0 = -+ (AsyncCopyIterationsPerStageA0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB0 = -+ (AsyncCopyIterationsPerStageB0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB1 = -+ (AsyncCopyIterationsPerStageB1 + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1; -+ }; -+ -+ private: -+ -+ using WarpLoadedFragmentA0 = typename Operator0::FragmentA; -+ using WarpLoadedFragmentB0 = typename Operator0::FragmentB; -+ using WarpLoadedFragmentA1 = typename Operator1::FragmentA; -+ using WarpLoadedFragmentB1 = typename Operator1::FragmentB; -+ using WarpTransformedFragmentA0 = typename Operator0::TransformedFragmentA; -+ using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB; -+ using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA; -+ using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA0 smem_iterator_A0_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB0 smem_iterator_B0_; -+ -+ /// Shared Memory Iterator to store accumulator tile -+ SmemIteratorD0 smem_iterator_D0_; -+ -+ /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile -+ WarpIteratorA1 warp_tile_iterator_A1_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB1 smem_iterator_B1_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ B2bImplicitGemmMultistageSmemAccumulator( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::B2bMmaSharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_A_ref(), thread_idx), -+ smem_iterator_B0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_B_ref(), thread_idx), -+ smem_iterator_D0_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx), -+ warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx), -+ smem_iterator_B1_(shared_storage.b2b_mma_shared_storage.shared_storage1.operand_B_ref(), thread_idx) -+ { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn_0 = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN); -+ int warp_idx_k_0 = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN); -+ -+ int warp_idx_m_0 = warp_idx_mn_0 % Base::WarpCount0::kM; -+ int warp_idx_n_0 = warp_idx_mn_0 / Base::WarpCount0::kM; -+ -+ int warp_idx_mn_1 = warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN); -+ int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); -+ -+ int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; -+ int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A0_.add_tile_offset( -+ {warp_idx_m_0, Base::kWarpGemmIterations0 * warp_idx_k_0}); -+ this->warp_tile_iterator_B0_.add_tile_offset( -+ {Base::kWarpGemmIterations0 * warp_idx_k_0, warp_idx_n_0}); -+ warp_tile_iterator_A1_.add_tile_offset( -+ {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); -+ this->warp_tile_iterator_B1_.add_tile_offset( -+ {Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1}); -+ -+ // Add smem accumulator iterator warp offset -+ smem_iterator_D0_.add_tile_offset({ warp_idx_m_0 * SmemIteratorD0::TileIterations::kRow, -+ warp_idx_n_0 * SmemIteratorD0::TileIterations::kColumn}); -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance_0( -+ IteratorA0 &iterator_A0, IteratorB0 &iterator_B0, -+ int group_start_A0 = 0, int group_start_B0 = 0) { -+ -+ iterator_A0.set_iteration_index(group_start_A0); -+ this->smem_iterator_A0_.set_iteration_index(group_start_A0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupA0; ++j) { -+ -+ if (group_start_A0 + j < Detail::AsyncCopyIterationsPerStageA0) { -+ typename IteratorA0::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A0_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA0::ThreadMap::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr, iterator_A0.get(), iterator_A0.valid()); -+ -+ ++iterator_A0; -+ -+ ++this->smem_iterator_A0_; -+ } -+ } -+ -+ iterator_B0.set_iteration_index(group_start_B0); -+ -+ this->smem_iterator_B0_.set_iteration_index(group_start_B0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB0; ++j) { -+ if (group_start_B0 + j < Detail::AsyncCopyIterationsPerStageB0) { -+ typename IteratorB0::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B0_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB0::ThreadMap::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr, iterator_B0.get(), iterator_B0.valid()); -+ -+ ++iterator_B0; -+ ++this->smem_iterator_B0_; -+ } -+ } -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance_1( -+ IteratorB1 &iterator_B1, -+ int group_start_B1 = 0) { -+ -+ iterator_B1.set_iteration_index(group_start_B1); -+ -+ this->smem_iterator_B1_.set_iteration_index(group_start_B1); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) { -+ if (group_start_B1 + j < Detail::AsyncCopyIterationsPerStageB1) { -+ typename IteratorB1::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B1_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB1::ThreadMap::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr, iterator_B1.get(), iterator_B1.valid()); -+ -+ ++iterator_B1; -+ ++this->smem_iterator_B1_; -+ } -+ } -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations_0, -+ ///< destination accumulator tile -+ FragmentC1 &accum, -+ ///< iterator over A0 operand in global memory -+ IteratorA0 iterator_A0, -+ ///< iterator over B0 operand in global memory -+ IteratorB0 iterator_B0, -+ ///< iterator over A1 operand scale vector in global memory -+ IteratorAccumulatorScaleBias iterator_accum0_scale, -+ ///< iterator over A1 operand bias vector in global memory -+ IteratorAccumulatorScaleBias iterator_accum0_bias, -+ ///< iterator over B1 operand in global memory -+ IteratorB1 iterator_B1, -+ ///< initial value of accumulator -+ FragmentC0 const &src_accum, -+ ///< epilogue operation after 1st Gemm -+ OutputOp output_op_0, -+ ///< Imaginary strides used for planar-complex only - ignored here -+ int64_t imag_stride_A = 0, -+ int64_t imag_stride_B = 0) { -+ -+ // -+ // Prologue -+ // -+ -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; -+ ++stage, --gemm_k_iterations_0) { -+ -+ iterator_A0.set_iteration_index(0); -+ this->smem_iterator_A0_.set_iteration_index(0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA0; ++j) { -+ typename IteratorA0::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A0_.get()); -+ -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorA0::ThreadMap::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr, iterator_A0.get(), iterator_A0.valid()); -+ -+ ++iterator_A0; -+ ++this->smem_iterator_A0_; -+ } -+ -+ iterator_B0.set_iteration_index(0); -+ this->smem_iterator_B0_.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB0; ++j) { -+ typename IteratorB0::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B0_.get()); -+ -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB0::ThreadMap::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr, iterator_B0.get(), iterator_B0.valid()); -+ -+ ++iterator_B0; -+ ++this->smem_iterator_B0_; -+ } -+ -+ // Move to the next stage -+ iterator_A0.advance(); -+ iterator_B0.advance(); -+ -+ this->smem_iterator_A0_.add_tile_offset({0, 1}); -+ this->smem_iterator_B0_.add_tile_offset({1, 0}); -+ -+ // Inserts a fence to group cp.async instructions into stages. -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // Perform accumulation in the 'd' output operand -+ FragmentC0 accum0 = src_accum; -+ -+ // Waits until kStages-2 stages have committed. -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA0 warp_loaded_frag_A0[2]; -+ WarpLoadedFragmentB0 warp_loaded_frag_B0[2]; -+ WarpTransformedFragmentA0 warp_transformed_frag_A0[2]; -+ WarpTransformedFragmentB0 warp_transformed_frag_B0[2]; -+ -+ Operator0 warp_mma0; -+ -+ this->warp_tile_iterator_A0_.set_kgroup_index(0); -+ this->warp_tile_iterator_B0_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[0]); -+ this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[0]); -+ -+ ++this->warp_tile_iterator_A0_; -+ ++this->warp_tile_iterator_B0_; -+ -+ // Start issuing the first group of the next stage outside of the mainloop -+ copy_tiles_and_advance_0(iterator_A0, iterator_B0); -+ -+ int smem_write_stage_idx = Base::kStages - 1; -+ int smem_read_stage_idx = 0; -+ -+ warp_mma0.transform(warp_transformed_frag_A0[0], warp_transformed_frag_B0[0], -+ warp_loaded_frag_A0[0], warp_loaded_frag_B0[0]); -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations_0 > (-Base::kStages + 1);) { -+ -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; -+ ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. -+ -+ this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); -+ this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); -+ -+ this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A0_; -+ ++this->warp_tile_iterator_B0_; -+ -+ if (warp_mma_k > 0) -+ warp_mma0.transform(warp_transformed_frag_A0[warp_mma_k % 2], -+ warp_transformed_frag_B0[warp_mma_k % 2], -+ warp_loaded_frag_A0[warp_mma_k % 2], -+ warp_loaded_frag_B0[warp_mma_k % 2]); -+ -+ // Issue global->shared copies for the next stage -+ int group_start_iteration_A0, group_start_iteration_B0; -+ -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations0) { -+ group_start_iteration_A0 = 0; -+ group_start_iteration_B0 = 0; -+ } else { -+ group_start_iteration_A0 = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupA0; -+ group_start_iteration_B0 = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB0; -+ } -+ -+ copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0, -+ group_start_iteration_B0); -+ -+ warp_mma0( -+ accum0, -+ warp_transformed_frag_A0[warp_mma_k % 2], -+ warp_transformed_frag_B0[warp_mma_k % 2], -+ accum0 -+ ); -+ -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations0) -+ warp_mma0.transform(warp_transformed_frag_A0[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B0[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A0[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations0) { -+ // Inserts a fence to group cp.async instructions into stages. -+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages of cp.async have committed -+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_A0.advance(); -+ iterator_B0.advance(); -+ -+ this->smem_iterator_A0_.add_tile_offset({0, 1}); -+ this->smem_iterator_B0_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_A0_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_A0_.add_tile_offset( -+ {0, -Base::kStages * Policy0::kPartitionsK * -+ Base::kWarpGemmIterations0}); -+ this->warp_tile_iterator_B0_.add_tile_offset( -+ {-Base::kStages * Policy0::kPartitionsK * -+ Base::kWarpGemmIterations0, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ --gemm_k_iterations_0; -+ } -+ } -+ -+ } -+ -+ // Insert fence and wait for all outstanding cp.async operations to commit. -+ cutlass::arch::cp_async_fence(); -+ cutlass::arch::cp_async_wait<0>(); -+ __syncthreads(); -+ -+ /// Epilogue for the first Implicit Gemm -+ Epilogue0 epilogue0; -+ -+ epilogue0(output_op_0, smem_iterator_D0_, accum0, iterator_accum0_scale, iterator_accum0_bias); -+ -+ __syncthreads(); -+ -+ // 2nd Implicit Gemm -+ -+ // -+ // Prologue -+ // -+ int gemm_k_iterations_1 = Shape0::kN / Shape1::kK; -+ -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; -+ ++stage, --gemm_k_iterations_1) { -+ -+ iterator_B1.set_iteration_index(0); -+ this->smem_iterator_B1_.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB1; ++j) { -+ typename IteratorB1::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B1_.get()); -+ -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB1::ThreadMap::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr, iterator_B1.get(), iterator_B1.valid()); -+ -+ ++iterator_B1; -+ ++this->smem_iterator_B1_; -+ } -+ -+ // Move to the next stage -+ iterator_B1.advance(); -+ -+ this->smem_iterator_B1_.add_tile_offset({1, 0}); -+ -+ // Inserts a fence to group cp.async instructions into stages. -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // Waits until kStages-2 stages have committed. -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA1 warp_loaded_frag_A1[2]; -+ WarpLoadedFragmentB1 warp_loaded_frag_B1[2]; -+ WarpTransformedFragmentA1 warp_transformed_frag_A1[2]; -+ WarpTransformedFragmentB1 warp_transformed_frag_B1[2]; -+ -+ Operator1 warp_mma1; -+ -+ warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0]); -+ ++warp_tile_iterator_A1_; -+ -+ this->warp_tile_iterator_B1_.set_kgroup_index(0); -+ this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]); -+ ++this->warp_tile_iterator_B1_; -+ -+ // Start issuing the first group of the next stage outside of the mainloop -+ copy_tiles_and_advance_1(iterator_B1); -+ -+ smem_write_stage_idx = Base::kStages - 1; -+ smem_read_stage_idx = 0; -+ -+ warp_mma1.transform(warp_transformed_frag_A1[0], warp_transformed_frag_B1[0], -+ warp_loaded_frag_A1[0], warp_loaded_frag_B1[0]); -+ -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for ( gemm_k_iterations_1 = Shape0::kN / Shape1::kK - (Base::kStages - 1); -+ gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; -+ ++warp_mma_k) { -+ -+ // Load warp-level tile from accumulator fragment -+ // skip warp tile loading for the last kgroup -+ if(gemm_k_iterations_1 > (-Base::kStages + 2) || warp_mma_k < Base::kWarpGemmIterations1 - 1) { -+ warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2]); -+ } -+ ++warp_tile_iterator_A1_; -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. -+ this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); -+ this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); -+ ++this->warp_tile_iterator_B1_; -+ -+ -+ if (warp_mma_k > 0) -+ warp_mma1.transform(warp_transformed_frag_A1[warp_mma_k % 2], -+ warp_transformed_frag_B1[warp_mma_k % 2], -+ warp_loaded_frag_A1[warp_mma_k % 2], -+ warp_loaded_frag_B1[warp_mma_k % 2]); -+ -+ // Issue global->shared copies for the next stage -+ int group_start_iteration_B1; -+ -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations1) { -+ group_start_iteration_B1 = 0; -+ } else { -+ group_start_iteration_B1 = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB1; -+ } -+ -+ copy_tiles_and_advance_1(iterator_B1, -+ group_start_iteration_B1); -+ -+ warp_mma1( -+ accum, -+ warp_transformed_frag_A1[warp_mma_k % 2], -+ warp_transformed_frag_B1[warp_mma_k % 2], -+ accum -+ ); -+ -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations1) -+ warp_mma1.transform(warp_transformed_frag_A1[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B1[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A1[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations1) { -+ // Inserts a fence to group cp.async instructions into stages. -+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages of cp.async have committed -+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_B1.advance(); -+ -+ this->smem_iterator_B1_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_B1_.add_tile_offset( -+ {-Base::kStages * Policy1::kPartitionsK * -+ Base::kWarpGemmIterations1, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ } -+ } -+ -+ } -+ -+ // Insert fence and wait for all outstanding cp.async operations to commit. -+ cutlass::arch::cp_async_fence(); -+ cutlass::arch::cp_async_wait<0>(); -+ __syncthreads(); -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_pipelined.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_pipelined.h -new file mode 100644 -index 0000000..36d4563 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_pipelined.h -@@ -0,0 +1,553 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/numeric_conversion.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" -+ -+#include "threadblock/b2b_mma_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape0_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorA0_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA0_, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorB0_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB0_, -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape1_, -+ /// Iterates over the intermediate accumulator tile -+ // (concept::MmaTensorOpFragmentIterator) -+ typename FragmentIteratorA1_, -+ /// Iterates over vectors of scale and bias vector in global memory -+ // (concept: VectorIterator) -+ typename IteratorAccumulatorScaleBias_, -+ /// FragmentIterator to load Scale or Bias vector from threadblock fragment -+ typename FragmentIteratorA1ScaleBias_, -+ // (concept: VectorFragmentIterator) -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorB1_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB1_, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...) -+ typename OutputOp_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy0_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy1_, -+ /// Transformation applied to A operand -+ typename TransformA0_ = NumericArrayConverter< -+ typename SmemIteratorA0_::Element, -+ typename IteratorA0_::Element, -+ IteratorA0_::Fragment::kElements>, -+ /// -+ /// Transformation applied to B operand -+ typename TransformB0_ = NumericArrayConverter< -+ typename SmemIteratorB0_::Element, -+ typename IteratorB0_::Element, -+ IteratorB0_::Fragment::kElements>, -+ /// -+ /// Transformation applied to B operand -+ typename TransformB1_ = NumericArrayConverter< -+ typename SmemIteratorB1_::Element, -+ typename IteratorB1_::Element, -+ IteratorB1_::Fragment::kElements>, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+class B2bImplicitGemmPipelined : -+ public gemm::threadblock::B2bMmaBase { -+public: -+ -+ ///< Base class -+ using Base = gemm::threadblock::B2bMmaBase; -+ -+ using Shape0 = Shape0_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using IteratorA0 = IteratorA0_; ///< Iterates over tiles of A operand in global memory -+ using IteratorB0 = IteratorB0_; ///< Iterates over tiles of B operand in global memory -+ using Policy0 = Policy0_; ///< Policy0 describing tuning details -+ -+ using SmemIteratorA0 = SmemIteratorA0_; -+ using SmemIteratorB0 = SmemIteratorB0_; -+ -+ using Shape1 = Shape1_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using FragmentIteratorA1 = FragmentIteratorA1_; ///< Iterates over tiles of A1 operand from accumulator tile -+ using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; ///< Iterates over tiles of the scale and bias vectors in global memory -+ using FragmentIteratorA1ScaleBias = -+ FragmentIteratorA1ScaleBias_; ///< WarpIterator to load Scale or Bias vector from the threadblock fragment -+ using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory -+ using Policy1 = Policy1_; ///< Policy1 describing tuning details -+ -+ using SmemIteratorB1 = SmemIteratorB1_; -+ -+ -+ using ElementC = ElementC_; ///< Data type of accumulator matrix -+ using LayoutC = LayoutC_; ///< Layout of accumulator matrix -+ -+ using OutputOp = OutputOp_; ///< Epilogue after 1st Gemm -+ -+ static const bool PerChannelScale = (OutputOp::kScale == -+ epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling); -+ -+ using TransformA0 = TransformA0_; -+ using TransformB0 = TransformB0_; -+ using TransformB1 = TransformB1_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of operand A loaded from global memory -+ using FragmentA0 = typename IteratorA0::Fragment; -+ -+ /// Fragment of operand B loaded from global memory -+ using FragmentB0 = typename IteratorB0::Fragment; -+ -+ /// Fragment of accumulator tile -+ using FragmentC0 = typename Policy0::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator0 = typename Policy0::Operator; -+ -+ /// Fragment of Scale and Bias loaded from global memory -+ using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment; -+ -+ /// Fragment of operand B loaded from global memory -+ using FragmentB1 = typename IteratorB1::Fragment; -+ -+ /// Fragment of accumulator tile -+ using FragmentC1 = typename Policy1::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator1 = typename Policy1::Operator; -+ -+ /// Obtain the arch tag from the warp-level operator -+ using ArchTag = typename Policy0::Operator::ArchTag; -+ -+ /// Complex transform on A0 operand -+ static ComplexTransform const kTransformA0 = Operator0::kTransformA; -+ -+ /// Complex transform on B0 operand -+ static ComplexTransform const kTransformB0 = Operator0::kTransformB; -+ -+ /// Complex transform on B1 operand -+ static ComplexTransform const kTransformB1 = Operator1::kTransformB; -+ -+ // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) -+ static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); -+ -+private: -+ -+ using WarpFragmentA0 = typename Operator0::FragmentA; -+ using WarpFragmentB0 = typename Operator0::FragmentB; -+ /// Warp Fragment of operand A1 loaded from accmulator tile -+ using WarpFragmentA1 = typename FragmentIteratorA1::Fragment; -+ /// Warp Fragment of operand A1 scale and bias loaded from threadblock fragment -+ using WarpFragmentA1ScaleBias = -+ typename FragmentIteratorA1ScaleBias::Fragment; -+ using WarpFragmentB1 = typename Operator1::FragmentB; -+ -+protected: -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA0 smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B0 operand to shared memory -+ SmemIteratorB0 smem_iterator_B0_; -+ -+ /// Iterator to write threadblock-scoped tile of B1 operand to shared memory -+ SmemIteratorB1 smem_iterator_B1_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ B2bImplicitGemmPipelined( -+ typename Base::B2bMmaSharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ int thread_idx, ///< ID within the threadblock -+ int warp_idx, ///< ID of warp -+ int lane_idx ///< ID of each thread within a warp -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.shared_storage0.operand_A_ref(), thread_idx), -+ smem_iterator_B0_(shared_storage.shared_storage0.operand_B_ref(), thread_idx), -+ smem_iterator_B1_(shared_storage.shared_storage1.operand_B_ref(), thread_idx) { -+ -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount0::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount0::kM; -+ -+ //These may change across different GEMM layers -+ int tile_offset_k_0 = Base::kWarpGemmIterations0 * warp_idx_k; -+ int tile_offset_k_1 = Base::kWarpGemmIterations1 * warp_idx_k; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A0_.add_tile_offset({warp_idx_m, tile_offset_k_0}); -+ this->warp_tile_iterator_B0_.add_tile_offset({tile_offset_k_0, warp_idx_n}); -+ this->warp_tile_iterator_B1_.add_tile_offset({tile_offset_k_1, warp_idx_n}); -+ -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ int gemm_k_iterations_0, ///< number of iterations of the mainloop -+ FragmentC1 &accum, ///< destination accumulator tile -+ IteratorA0 iterator_A, ///< iterator over A operand in global memory -+ IteratorB0 iterator_B0, ///< iterator over B0 operand in global memory -+ IteratorAccumulatorScaleBias iterator_A1_scale, ///< iterator over A1 operand scale vectors in global memory -+ IteratorAccumulatorScaleBias iterator_A1_bias, ///< iterator over A1 operand bias vectors in global memory -+ IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory -+ FragmentC0 const &src_accum, ///< source accumulator tile -+ OutputOp output_op_0, ///< epilogue operation after 1st Gemm -+ TransformA0 transform_A0 = TransformA0(), ///< transformation applied to A0 fragment -+ TransformB0 transform_B0 = TransformB0(), ///< transformation applied to B0 fragment -+ TransformB1 transform_B1 = TransformB1()) { ///< transformation applied to B1 fragment -+ -+ // -+ // Prologue -+ // -+ -+ // Perform accumulation in the 'd' output operand -+ FragmentC0 accum0 = src_accum; -+ -+ FragmentA0 tb_frag_A; -+ FragmentB0 tb_frag_B0; -+ -+ tb_frag_A.clear(); -+ tb_frag_B0.clear(); -+ -+ // The last kblock is loaded in the prolog -+ iterator_A.load(tb_frag_A); -+ iterator_B0.load(tb_frag_B0); -+ -+ ++iterator_A; -+ ++iterator_B0; -+ -+ this->smem_iterator_A_.store(transform_A0(tb_frag_A)); -+ this->smem_iterator_B0_.store(transform_B0(tb_frag_B0)); -+ -+ ++this->smem_iterator_A_; -+ ++this->smem_iterator_B0_; -+ -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math instructions -+ WarpFragmentA0 warp_frag_A0[2]; -+ WarpFragmentB0 warp_frag_B0[2]; -+ -+ this->warp_tile_iterator_A0_.set_kgroup_index(0); -+ this->warp_tile_iterator_B0_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A0_.load(warp_frag_A0[0]); -+ this->warp_tile_iterator_B0_.load(warp_frag_B0[0]); -+ -+ ++this->warp_tile_iterator_A0_; -+ ++this->warp_tile_iterator_B0_; -+ -+ Operator0 warp_mma0; -+ -+ int smem_write_stage_idx = 1; -+ -+ // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing -+ // shared memory loads (which have the tighest latency requirement). -+ -+ // -+ // Mainloop -+ // -+ -+ // Note: The main loop does not support Base::kWarpGemmIterations == 2. -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group -+ // as the case may be. -+ -+ if (warp_mma_k == Base::kWarpGemmIterations0 - 1) { -+ -+ // Write fragments to shared memory -+ this->smem_iterator_A_.store(transform_A0(tb_frag_A)); -+ -+ this->smem_iterator_B0_.store(transform_B0(tb_frag_B0)); -+ -+ __syncthreads(); -+ -+ ++this->smem_iterator_A_; -+ ++this->smem_iterator_B0_; -+ -+ // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory -+ if (smem_write_stage_idx == 1) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); -+ } -+ else { -+ this->warp_tile_iterator_A0_.add_tile_offset( -+ {0, -Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0}); -+ this->warp_tile_iterator_B0_.add_tile_offset( -+ {-Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0, -+ 0}); -+ } -+ -+ smem_write_stage_idx ^= 1; -+ } -+ -+ this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); -+ this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); -+ -+ this->warp_tile_iterator_A0_.load(warp_frag_A0[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B0_.load(warp_frag_B0[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A0_; -+ ++this->warp_tile_iterator_B0_; -+ -+ if (warp_mma_k == 0) { -+ -+ iterator_A.load(tb_frag_A); -+ iterator_B0.load(tb_frag_B0); -+ -+ ++iterator_A; -+ ++iterator_B0; -+ } -+ -+ warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2], -+ warp_frag_B0[warp_mma_k % 2], accum0); -+ -+ } -+ } -+ -+ -+ //2nd Implicit Gemm -+ -+ /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile -+ FragmentIteratorA1 warp_tile_iterator_A1_(accum0); -+ -+ -+ -+ // -+ // Prologue -+ // -+ -+ FragmentA1ScaleBias tb_frag_A1_scale; -+ FragmentA1ScaleBias tb_frag_A1_bias; -+ FragmentIteratorA1ScaleBias warp_tile_iterator_A1_scale_(tb_frag_A1_scale); -+ FragmentIteratorA1ScaleBias warp_tile_iterator_A1_bias_(tb_frag_A1_bias); -+ FragmentB1 tb_frag_B1; -+ -+ if(PerChannelScale) -+ tb_frag_A1_scale.clear(); -+ tb_frag_A1_bias.clear(); -+ tb_frag_B1.clear(); -+ -+ // The last kblock is loaded in the prolog -+ if(PerChannelScale) -+ iterator_A1_scale.load(tb_frag_A1_scale); -+ iterator_A1_bias.load(tb_frag_A1_bias); -+ iterator_B1.load(tb_frag_B1); -+ -+ -+ if(PerChannelScale) -+ ++iterator_A1_scale; -+ ++iterator_A1_bias; -+ ++iterator_B1; -+ -+ this->smem_iterator_B1_.store(transform_B1(tb_frag_B1)); -+ -+ ++this->smem_iterator_B1_; -+ -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math instructions -+ WarpFragmentA1ScaleBias warp_frag_A1_scale[2]; -+ WarpFragmentA1ScaleBias warp_frag_A1_bias[2]; -+ WarpFragmentA1 warp_frag_A1[2]; -+ WarpFragmentB1 warp_frag_B1[2]; -+ -+ this->warp_tile_iterator_B1_.set_kgroup_index(0); -+ -+ if(PerChannelScale) -+ warp_tile_iterator_A1_scale_.load(warp_frag_A1_scale[0]); -+ warp_tile_iterator_A1_bias_.load(warp_frag_A1_bias[0]); -+ warp_tile_iterator_A1_.load(warp_frag_A1[0], warp_frag_A1_scale[0], -+ warp_frag_A1_bias[0], output_op_0); -+ this->warp_tile_iterator_B1_.load(warp_frag_B1[0]); -+ -+ ++warp_tile_iterator_A1_; -+ if(PerChannelScale) -+ ++warp_tile_iterator_A1_scale_; -+ ++warp_tile_iterator_A1_bias_; -+ ++this->warp_tile_iterator_B1_; -+ -+ Operator1 warp_mma1; -+ -+ smem_write_stage_idx = 1; -+ -+ int gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1; -+ -+ // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing -+ // shared memory loads (which have the tighest latency requirement). -+ -+ // -+ // Mainloop -+ // -+ -+ // Note: The main loop does not support Base::kWarpGemmIterations == 2. -+ CUTLASS_PRAGMA_UNROLL -+ for (; gemm_k_iterations_1 > 0; --gemm_k_iterations_1) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group -+ // as the case may be. -+ -+ if (warp_mma_k == Base::kWarpGemmIterations1 - 1) { -+ -+ this->smem_iterator_B1_.store(transform_B1(tb_frag_B1)); -+ -+ __syncthreads(); -+ -+ ++this->smem_iterator_B1_; -+ -+ // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory -+ if (smem_write_stage_idx == 1) { -+ this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); -+ } -+ else { -+ this->warp_tile_iterator_B1_.add_tile_offset( -+ {-Base::kStages * Policy1::kPartitionsK * Base::kWarpGemmIterations1, -+ 0}); -+ } -+ -+ smem_write_stage_idx ^= 1; -+ -+ if(PerChannelScale) { -+ tb_frag_A1_scale.clear(); -+ iterator_A1_scale.load(tb_frag_A1_scale); -+ ++iterator_A1_scale; -+ } -+ tb_frag_A1_bias.clear(); -+ iterator_A1_bias.load(tb_frag_A1_bias); -+ ++iterator_A1_bias; -+ } -+ -+ this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); -+ -+ if(PerChannelScale) -+ warp_tile_iterator_A1_scale_.load(warp_frag_A1_scale[(warp_mma_k + 1) % 2]); -+ warp_tile_iterator_A1_bias_.load(warp_frag_A1_bias[(warp_mma_k + 1) % 2]); -+ warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2], -+ warp_frag_A1_scale[(warp_mma_k + 1) % 2], -+ warp_frag_A1_bias[(warp_mma_k + 1) % 2], -+ output_op_0); -+ this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]); -+ -+ if(PerChannelScale) -+ ++warp_tile_iterator_A1_scale_; -+ ++warp_tile_iterator_A1_bias_; -+ ++warp_tile_iterator_A1_; -+ ++this->warp_tile_iterator_B1_; -+ -+ if (warp_mma_k == 0) { -+ -+ iterator_B1.load(tb_frag_B1); -+ -+ ++iterator_B1; -+ } -+ -+ warp_mma1(accum, warp_frag_A1[warp_mma_k % 2], -+ warp_frag_B1[warp_mma_k % 2], accum); -+ -+ } -+ } -+ -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_pipelined_smem_accumulator.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_pipelined_smem_accumulator.h -new file mode 100644 -index 0000000..828426b ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_pipelined_smem_accumulator.h -@@ -0,0 +1,535 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/numeric_conversion.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" -+ -+#include "threadblock/b2b_mma_base_smem_accumulator.h" -+#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape0_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorA0_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA0_, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorB0_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB0_, -+ /// Iterates over vectors of scale and bias vector in global memory -+ // (concept: VectorIterator) -+ typename IteratorAccumulatorScaleBias_, -+ /// Iterates over accumulator tile -+ typename FragmentIteratorAccumulator_, -+ /// Iterates over accumulator tile in shared memory -+ typename SmemIteratorD0_, -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape1_, -+ /// Iterates over the intermediate accumulator tile in shared memory -+ typename WarpIteratorA1_, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorB1_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB1_, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...) -+ typename OutputOp_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy0_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy1_, -+ /// Transformation applied to A0 operand -+ typename TransformA0_ = NumericArrayConverter< -+ typename SmemIteratorA0_::Element, -+ typename IteratorA0_::Element, -+ IteratorA0_::Fragment::kElements>, -+ /// -+ /// Transformation applied to B0 operand -+ typename TransformB0_ = NumericArrayConverter< -+ typename SmemIteratorB0_::Element, -+ typename IteratorB0_::Element, -+ IteratorB0_::Fragment::kElements>, -+ /// -+ /// Transformation applied to B1 operand -+ typename TransformB1_ = NumericArrayConverter< -+ typename SmemIteratorB1_::Element, -+ typename IteratorB1_::Element, -+ IteratorB1_::Fragment::kElements>, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+class B2bImplicitGemmPipelinedSmemAccumulator : -+ public gemm::threadblock::B2bMmaBaseSmemAccumulator { -+public: -+ -+ ///< Base class -+ using Base = gemm::threadblock::B2bMmaBaseSmemAccumulator; -+ -+ using Shape0 = Shape0_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using IteratorA0 = IteratorA0_; ///< Iterates over tiles of A operand in global memory -+ using IteratorB0 = IteratorB0_; ///< Iterates over tiles of B operand in global memory -+ using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; ///< Iterates over tiles of the scale and bias vectors in global memory -+ using Policy0 = Policy0_; ///< Policy0 describing tuning details -+ -+ using SmemIteratorA0 = SmemIteratorA0_; -+ using SmemIteratorB0 = SmemIteratorB0_; -+ using SmemIteratorD0 = SmemIteratorD0_; ///< Iterates over accumulator tile in shared memory -+ -+ using FragmentIteratorAccumulator = FragmentIteratorAccumulator_; ///< Iterates over accumulator tile -+ -+ using Shape1 = Shape1_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory -+ using Policy1 = Policy1_; ///< Policy1 describing tuning details -+ -+ using SmemIteratorB1 = SmemIteratorB1_; -+ using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate accumulator tile in shared memory -+ -+ -+ using ElementC = ElementC_; ///< Data type of accumulator matrix -+ using LayoutC = LayoutC_; ///< Layout of accumulator matrix -+ -+ using OutputOp = OutputOp_; ///< Epilogue after 1st Gemm -+ -+ using TransformA0 = TransformA0_; -+ using TransformB0 = TransformB0_; -+ using TransformB1 = TransformB1_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of operand A loaded from global memory -+ using FragmentA0 = typename IteratorA0::Fragment; -+ -+ /// Fragment of operand B loaded from global memory -+ using FragmentB0 = typename IteratorB0::Fragment; -+ -+ /// Fragment of accumulator tile -+ using FragmentC0 = typename Policy0::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator0 = typename Policy0::Operator; -+ -+ /// Fragment of operand B loaded from global memory -+ using FragmentB1 = typename IteratorB1::Fragment; -+ -+ /// Fragment of accumulator tile -+ using FragmentC1 = typename Policy1::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator1 = typename Policy1::Operator; -+ -+ /// Obtain the arch tag from the warp-level operator -+ using ArchTag = typename Policy0::Operator::ArchTag; -+ -+ /// Complex transform on A0 operand -+ static ComplexTransform const kTransformA0 = Operator0::kTransformA; -+ -+ /// Complex transform on B0 operand -+ static ComplexTransform const kTransformB0 = Operator0::kTransformB; -+ -+ /// Complex transform on B1 operand -+ static ComplexTransform const kTransformB1 = Operator1::kTransformB; -+ -+ /// staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) -+ static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); -+ -+ /// Epilog in shared memory -+ using Epilogue0 = epilogue::threadblock::EpilogueSmemAccumulator< -+ SmemIteratorD0, ///< SmemTileIterator -+ FragmentIteratorAccumulator, ///< AccumulatorFragmentIterator -+ IteratorAccumulatorScaleBias, ///< ScaleBiasIterator -+ OutputOp>; ///< Output operator -+ -+ -+ -+private: -+ -+ using WarpFragmentA0 = typename Operator0::FragmentA; -+ using WarpFragmentB0 = typename Operator0::FragmentB; -+ using WarpFragmentA1 = typename Operator1::FragmentA; -+ using WarpFragmentB1 = typename Operator1::FragmentB; -+ -+protected: -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA0 smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B0 operand to shared memory -+ SmemIteratorB0 smem_iterator_B0_; -+ -+ /// Shared Memory Iterator to store accumulator tile -+ SmemIteratorD0 smem_iterator_D0_; -+ -+ /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile -+ WarpIteratorA1 warp_tile_iterator_A1_; -+ -+ /// Iterator to write threadblock-scoped tile of B1 operand to shared memory -+ SmemIteratorB1 smem_iterator_B1_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ B2bImplicitGemmPipelinedSmemAccumulator( -+ typename Base::B2bMmaSharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ int thread_idx, ///< ID within the threadblock -+ int warp_idx, ///< ID of warp -+ int lane_idx ///< ID of each thread within a warp -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_A_ref(), thread_idx), -+ smem_iterator_B0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_B_ref(), thread_idx), -+ smem_iterator_D0_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx), -+ warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx), -+ smem_iterator_B1_(shared_storage.b2b_mma_shared_storage.shared_storage1.operand_B_ref(), thread_idx) { -+ -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn_0 = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN); -+ int warp_idx_k_0 = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN); -+ -+ int warp_idx_m_0 = warp_idx_mn_0 % Base::WarpCount0::kM; -+ int warp_idx_n_0 = warp_idx_mn_0 / Base::WarpCount0::kM; -+ -+ int tile_offset_k_0 = Base::kWarpGemmIterations0 * warp_idx_k_0; -+ -+ int warp_idx_mn_1 = warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN); -+ int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); -+ -+ int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; -+ int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; -+ -+ int tile_offset_k_1 = Base::kWarpGemmIterations1 * warp_idx_k_1; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A0_.add_tile_offset({warp_idx_m_0, tile_offset_k_0}); -+ this->warp_tile_iterator_B0_.add_tile_offset({tile_offset_k_0, warp_idx_n_0}); -+ warp_tile_iterator_A1_.add_tile_offset({warp_idx_m_1, tile_offset_k_1}); -+ this->warp_tile_iterator_B1_.add_tile_offset({tile_offset_k_1, warp_idx_n_1}); -+ -+ // Add smem accumulator iterator warp offset -+ smem_iterator_D0_.add_tile_offset({ warp_idx_m_0 * SmemIteratorD0::TileIterations::kRow, -+ warp_idx_n_0 * SmemIteratorD0::TileIterations::kColumn}); -+ -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ int gemm_k_iterations_0, ///< number of iterations of the mainloop -+ FragmentC1 &accum, ///< destination accumulator tile -+ IteratorA0 iterator_A, ///< iterator over A operand in global memory -+ IteratorB0 iterator_B0, ///< iterator over B0 operand in global memory -+ IteratorAccumulatorScaleBias iterator_accum0_scale, ///< iterator over D0 scale vector in global memory -+ IteratorAccumulatorScaleBias iterator_accum0_bias, ///< iterator over D0 bias vector in global memory -+ IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory -+ FragmentC0 const &src_accum, ///< source accumulator tile -+ OutputOp output_op_0, ///< epilogue operation after 1st Gemm -+ TransformA0 transform_A0 = TransformA0(), ///< transformation applied to A0 fragment -+ TransformB0 transform_B0 = TransformB0(), ///< transformation applied to B0 fragment -+ TransformB1 transform_B1 = TransformB1()) { ///< transformation applied to B1 fragment -+ -+ // -+ // Prologue -+ // -+ -+ // Perform accumulation in the 'd' output operand -+ FragmentC0 accum0 = src_accum; -+ -+ FragmentA0 tb_frag_A; -+ FragmentB0 tb_frag_B0; -+ -+ tb_frag_A.clear(); -+ tb_frag_B0.clear(); -+ -+ // The last kblock is loaded in the prolog -+ iterator_A.load(tb_frag_A); -+ iterator_B0.load(tb_frag_B0); -+ -+ ++iterator_A; -+ ++iterator_B0; -+ -+ this->smem_iterator_A_.store(transform_A0(tb_frag_A)); -+ this->smem_iterator_B0_.store(transform_B0(tb_frag_B0)); -+ -+ ++this->smem_iterator_A_; -+ ++this->smem_iterator_B0_; -+ -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math instructions -+ WarpFragmentA0 warp_frag_A0[2]; -+ WarpFragmentB0 warp_frag_B0[2]; -+ -+ this->warp_tile_iterator_A0_.set_kgroup_index(0); -+ this->warp_tile_iterator_B0_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A0_.load(warp_frag_A0[0]); -+ this->warp_tile_iterator_B0_.load(warp_frag_B0[0]); -+ -+ ++this->warp_tile_iterator_A0_; -+ ++this->warp_tile_iterator_B0_; -+ -+ Operator0 warp_mma0; -+ -+ int smem_write_stage_idx = 1; -+ -+ // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing -+ // shared memory loads (which have the tighest latency requirement). -+ -+ // -+ // Mainloop -+ // -+ -+ // Note: The main loop does not support Base::kWarpGemmIterations == 2. -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group -+ // as the case may be. -+ -+ if (warp_mma_k == Base::kWarpGemmIterations0 - 1) { -+ -+ // Write fragments to shared memory -+ this->smem_iterator_A_.store(transform_A0(tb_frag_A)); -+ -+ this->smem_iterator_B0_.store(transform_B0(tb_frag_B0)); -+ -+ __syncthreads(); -+ -+ ++this->smem_iterator_A_; -+ ++this->smem_iterator_B0_; -+ -+ // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory -+ if (smem_write_stage_idx == 1) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); -+ } -+ else { -+ this->warp_tile_iterator_A0_.add_tile_offset( -+ {0, -Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0}); -+ this->warp_tile_iterator_B0_.add_tile_offset( -+ {-Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0, -+ 0}); -+ } -+ -+ smem_write_stage_idx ^= 1; -+ } -+ -+ this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); -+ this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); -+ -+ this->warp_tile_iterator_A0_.load(warp_frag_A0[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B0_.load(warp_frag_B0[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A0_; -+ ++this->warp_tile_iterator_B0_; -+ -+ if (warp_mma_k == 0) { -+ -+ iterator_A.load(tb_frag_A); -+ iterator_B0.load(tb_frag_B0); -+ ++iterator_A; -+ ++iterator_B0; -+ } -+ -+ warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2], -+ warp_frag_B0[warp_mma_k % 2], accum0); -+ -+ } -+ } -+ -+ /// Epilogue for the first Implicit Gemm -+ Epilogue0 epilogue0; -+ -+ epilogue0(output_op_0, smem_iterator_D0_, accum0, iterator_accum0_scale, iterator_accum0_bias); -+ -+ __syncthreads(); -+ -+ /// 2nd Implicit Gemm -+ -+ -+ // -+ // Prologue -+ // -+ -+ FragmentB1 tb_frag_B1; -+ -+ tb_frag_B1.clear(); -+ -+ // The last kblock is loaded in the prolog -+ iterator_B1.load(tb_frag_B1); -+ -+ ++iterator_B1; -+ -+ this->smem_iterator_B1_.store(transform_B1(tb_frag_B1)); -+ -+ ++this->smem_iterator_B1_; -+ -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math instructions -+ WarpFragmentA1 warp_frag_A1[2]; -+ WarpFragmentB1 warp_frag_B1[2]; -+ -+ this->warp_tile_iterator_B1_.set_kgroup_index(0); -+ -+ warp_tile_iterator_A1_.load(warp_frag_A1[0]); -+ this->warp_tile_iterator_B1_.load(warp_frag_B1[0]); -+ -+ ++warp_tile_iterator_A1_; -+ ++this->warp_tile_iterator_B1_; -+ -+ Operator1 warp_mma1; -+ -+ smem_write_stage_idx = 1; -+ -+ int gemm_k_iterations_1 = Shape0::kN / Shape1::kK; -+ -+ -+ // -+ // Mainloop -+ // -+ -+ // Note: The main loop does not support Base::kWarpGemmIterations == 2. -+ CUTLASS_PRAGMA_UNROLL -+ for (; gemm_k_iterations_1 > 0; --gemm_k_iterations_1) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group -+ // as the case may be. -+ -+ if (warp_mma_k == Base::kWarpGemmIterations1 - 1) { -+ -+ // Write fragments to shared memory -+ this->smem_iterator_B1_.store(transform_B1(tb_frag_B1)); -+ -+ __syncthreads(); -+ -+ ++this->smem_iterator_B1_; -+ -+ // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory -+ if (smem_write_stage_idx == 1) { -+ this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); -+ } -+ else { -+ this->warp_tile_iterator_B1_.add_tile_offset( -+ {-Base::kStages * Policy1::kPartitionsK * -+ Base::kWarpGemmIterations1, -+ 0}); -+ } -+ -+ smem_write_stage_idx ^= 1; -+ -+ } -+ -+ this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); -+ -+ // skip warp tile loading for the last kgroup -+ if(gemm_k_iterations_1 > 1 || warp_mma_k < Base::kWarpGemmIterations1 - 1) -+ warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]); -+ -+ ++warp_tile_iterator_A1_; -+ ++this->warp_tile_iterator_B1_; -+ -+ if (warp_mma_k == 0) { -+ -+ iterator_B1.load(tb_frag_B1); -+ -+ ++iterator_B1; -+ } -+ -+ warp_mma1(accum, warp_frag_A1[warp_mma_k % 2], -+ warp_frag_B1[warp_mma_k % 2], accum); -+ -+ } -+ } -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base.h -new file mode 100644 -index 0000000..660879c ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base.h -@@ -0,0 +1,236 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape0_, -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape1_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy0_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy1_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class B2bMmaBase { -+ public: -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape0 = Shape0_; -+ using Shape1 = Shape1_; -+ -+ ///< Policy describing tuning details -+ using Policy0 = Policy0_; -+ using Policy1 = Policy1_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Warp-level Mma -+ using Operator0 = typename Policy0::Operator; -+ using Operator1 = typename Policy1::Operator; -+ -+ /// Shape describing the overall GEMM computed from shared memory -+ /// by each warp. -+ using WarpGemm0 = typename Policy0::Operator::Shape; -+ using WarpGemm1 = typename Policy1::Operator::Shape; -+ -+ /// Shape describing the number of warps filling the CTA -+ using WarpCount0 = GemmShape; -+ using WarpCount1 = GemmShape; -+ -+ /// Number of warp-level GEMM oeprations -+ static int const kWarpGemmIterations0 = -+ (WarpGemm0::kK / Operator0::Policy::MmaShape::kK); -+ static int const kWarpGemmIterations1 = -+ (WarpGemm1::kK / Operator1::Policy::MmaShape::kK); -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ // -+ // Nested structs -+ // -+ -+ /// Shared storage object needed by threadblock-scoped GEMM -+ template< -+ typename Shape_, -+ typename Policy_ -+ > -+ class SharedStorage { -+ public: -+ // -+ // Type definitions -+ // -+ using Shape = Shape_; -+ using Policy = Policy_; -+ using Operator = typename Policy::Operator; -+ -+ /// Tensor reference to the A operand -+ using TensorRefA = TensorRef; -+ -+ /// Tensor reference to the B operand -+ using TensorRefB = TensorRef; -+ -+ -+ /// Shape of the A matrix operand in shared memory -+ using ShapeA = MatrixShape; -+ -+ /// Shape of the B matrix operand in shared memory -+ using ShapeB = -+ MatrixShape; -+ -+ public: -+ // -+ // Data members -+ // -+ -+ /// Buffer for A operand -+ AlignedBuffer operand_A; -+ -+ /// Buffer for B operand -+ AlignedBuffer operand_B; -+ -+ public: -+ -+ // -+ // Methods -+ // -+ -+ /// Returns a layout object for the A matrix -+ CUTLASS_DEVICE -+ static typename Operator::LayoutA LayoutA() { -+ return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); -+ } -+ -+ /// Returns a layout object for the B matrix -+ CUTLASS_HOST_DEVICE -+ static typename Operator::LayoutB LayoutB() { -+ return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); -+ } -+ -+ /// Returns a TensorRef to the A operand -+ CUTLASS_HOST_DEVICE -+ TensorRefA operand_A_ref() { -+ return TensorRefA{operand_A.data(), LayoutA()}; -+ } -+ -+ /// Returns a TensorRef to the B operand -+ CUTLASS_HOST_DEVICE -+ TensorRefB operand_B_ref() { -+ return TensorRefB{operand_B.data(), LayoutB()}; -+ } -+ }; -+ -+ using SharedStorage0 = SharedStorage; -+ using SharedStorage1 = SharedStorage; -+ union B2bMmaSharedStorage { -+ SharedStorage0 shared_storage0; -+ SharedStorage1 shared_storage1; -+ }; -+ -+ -+ protected: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to load a warp-scoped tile of A0 operand from shared memory -+ typename Operator0::IteratorA warp_tile_iterator_A0_; -+ -+ /// Iterator to load a warp-scoped tile of B0 operand from shared memory -+ typename Operator0::IteratorB warp_tile_iterator_B0_; -+ -+ /// Iterator to load a warp-scoped tile of B1 operand from shared memory -+ typename Operator1::IteratorB warp_tile_iterator_B1_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ B2bMmaBase( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ B2bMmaSharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ warp_tile_iterator_A0_(shared_storage.shared_storage0.operand_A_ref(), lane_idx), -+ warp_tile_iterator_B0_(shared_storage.shared_storage0.operand_B_ref(), lane_idx), -+ warp_tile_iterator_B1_(shared_storage.shared_storage1.operand_B_ref(), lane_idx) { -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h -new file mode 100644 -index 0000000..fc8058a ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h -@@ -0,0 +1,179 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+#include "threadblock/b2b_mma_base.h" -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape0_, -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape1_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy0_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy1_, -+ /// Shared Memory Accumulator Iterator -+ typename SmemAccumulatorIterator0_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class B2bMmaBaseSmemAccumulator : -+ public B2bMmaBase { -+ -+ public: -+ ///< Base class -+ using Base = B2bMmaBase; -+ -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape0 = Shape0_; -+ using Shape1 = Shape1_; -+ -+ ///< Policy describing tuning details -+ using Policy0 = Policy0_; -+ using Policy1 = Policy1_; -+ -+ -+ using SmemAccumulatorIterator0 = SmemAccumulatorIterator0_; -+ -+ // -+ // Nested structs -+ // -+ /// Shared storage object needed by accumulator -+ template< -+ typename Shape_, -+ typename Element_, -+ typename Layout_, -+ typename Padding_ -+ > -+ class AccumulatorSharedStorage { -+ public: -+ // -+ // Type definitions -+ // -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = Layout_; -+ using Padding = Padding_; -+ -+ /// Tensor reference to the accumulator -+ using TensorRefAccum = TensorRef; -+ -+ /// Shape of the accumulator matrix in shared memory -+ using ShapeAccum = MatrixShape; -+ -+ public: -+ // -+ // Data members -+ // -+ -+ /// Buffer for accumulator -+ AlignedBuffer accum; -+ -+ public: -+ -+ // -+ // Methods -+ // -+ -+ /// Returns a layout object for the Accum matrix -+ CUTLASS_DEVICE -+ static Layout LayoutAccum() { -+ return Layout::packed({ShapeAccum::kRow, ShapeAccum::kColumn}); -+ } -+ -+ /// Returns a TensorRef to the Accumulator -+ CUTLASS_HOST_DEVICE -+ TensorRefAccum accum_ref() { -+ return TensorRefAccum{accum.data(), LayoutAccum()}; -+ } -+ -+ }; -+ -+ using AccumulatorSharedStorage0 = AccumulatorSharedStorage< -+ Shape0, typename SmemAccumulatorIterator0::Element, -+ typename SmemAccumulatorIterator0::TensorLayout, -+ typename SmemAccumulatorIterator0::Padding>; -+ -+ struct B2bMmaSharedStorage { -+ typename Base::B2bMmaSharedStorage b2b_mma_shared_storage; -+ AccumulatorSharedStorage0 accumulator_shared_storage0; -+ }; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ B2bMmaBaseSmemAccumulator( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ B2bMmaSharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ Base(shared_storage.b2b_mma_shared_storage, thread_idx, warp_idx, lane_idx) { -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage.h -new file mode 100644 -index 0000000..4ec718b ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage.h -@@ -0,0 +1,885 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" -+ -+#include "threadblock/b2b_mma_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape0_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA0_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA0_, -+ /// Cache operation for operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA0, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB0_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB0_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB0, -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape1_, -+ /// Iterates over the intermediate accumulator tile -+ // (concept::MmaTensorOpFragmentIterator) -+ typename FragmentIteratorA1_, -+ /// Iterates over vectors of scale and bias vector in global memory -+ // (concept: VectorIterator) -+ typename IteratorAccumulatorScaleBias_, -+ /// WarpIterator to load Scale or Bias vector from threadblock fragment -+ typename FragmentIteratorA1ScaleBias_, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB1_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB1_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB1, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...) -+ typename OutputOp_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy0_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy1_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class B2bMmaMultistage : -+ public B2bMmaBase { -+public: -+ ///< Base class -+ using Base = B2bMmaBase; -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape0 = Shape0_; -+ ///< Iterates over tiles of A operand in global memory -+ using IteratorA0 = IteratorA0_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB0 = IteratorB0_; -+ ///< Policy describing tuning details -+ using Policy0 = Policy0_; -+ -+ using SmemIteratorA0 = SmemIteratorA0_; -+ using SmemIteratorB0 = SmemIteratorB0_; -+ -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape1 = Shape1_; -+ ///< Iterates over intermediate accumulator tile -+ using FragmentIteratorA1 = FragmentIteratorA1_; -+ ///< Iterates over tiles of the scale and bias vectors in global memory -+ using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; -+ ///< WarpIterator to load Scale or Bias vector from threadblock fragment -+ using FragmentIteratorA1ScaleBias = FragmentIteratorA1ScaleBias_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB1 = IteratorB1_; -+ ///< Policy describing tuning details -+ using Policy1 = Policy1_; -+ -+ using SmemIteratorB1 = SmemIteratorB1_; -+ -+ ///< Data type of accumulator matrix -+ using ElementC = ElementC_; -+ ///< Layout of accumulator matrix -+ using LayoutC = LayoutC_; -+ -+ ///< Epilogue after 1st Gemm -+ using OutputOp = OutputOp_; -+ -+ static const bool PerChannelScale = (OutputOp::kScale == -+ epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling); -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA0 = CacheOpA0; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB0 = CacheOpB0; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of accumulator tile -+ using FragmentC0 = typename Policy0::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator0 = typename Policy0::Operator; -+ -+ /// Fragment of Scale and Bias loaded from global memory -+ using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment; -+ -+ /// Fragment of accumulator tile -+ using FragmentC1 = typename Policy1::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator1 = typename Policy1::Operator; -+ -+ /// Minimum architecture is Sm80 to support cp.async -+ using ArchTag = arch::Sm80; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA0 = Operator0::kTransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB0 = Operator0::kTransformB; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB1 = Operator1::kTransformB; -+ -+ /// Internal structure exposed for introspection. -+ struct Detail { -+ -+ static_assert(Base::kWarpGemmIterations0 > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ static_assert(Base::kWarpGemmIterations1 > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ -+ /// Number of cp.async instructions to load one stage of operand A -+ static int const TBLoadIterationsA0 = -+ IteratorA0::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const TBLoadIterationsB0 = -+ IteratorB0::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const TBLoadIterationsB1 = -+ IteratorB1::ThreadMap::Iterations::kCount; -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Number of cp.async instructions to load on group of operand A -+ static int const kAccessesPerGroupA0 = -+ (TBLoadIterationsA0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB0 = -+ (TBLoadIterationsB0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB1 = -+ (TBLoadIterationsB1 + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1; -+ }; -+ -+ private: -+ -+ using WarpLoadedFragmentA0 = typename Operator0::FragmentA; -+ using WarpLoadedFragmentB0 = typename Operator0::FragmentB; -+ /// Warp Fragment of operand A1 loaded from accmulator tile -+ using WarpLoadedFragmentA1 = typename FragmentIteratorA1::Fragment; -+ using WarpLoadedFragmentA1ScaleBias = -+ typename FragmentIteratorA1ScaleBias::Fragment; -+ using WarpLoadedFragmentB1 = typename Operator1::FragmentB; -+ using WarpTransformedFragmentA0 = typename Operator0::TransformedFragmentA; -+ using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB; -+ using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA; -+ using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA0 smem_iterator_A0_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB0 smem_iterator_B0_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB1 smem_iterator_B1_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ B2bMmaMultistage( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::B2bMmaSharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx, -+ ///< GEMM0 N is used for accumulator extent -+ int problem_size_0_n -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A0_(shared_storage.shared_storage0.operand_A_ref(), thread_idx), -+ smem_iterator_B0_(shared_storage.shared_storage0.operand_B_ref(), thread_idx), -+ smem_iterator_B1_(shared_storage.shared_storage1.operand_B_ref(), thread_idx) -+ { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount0::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount0::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A0_.add_tile_offset( -+ {warp_idx_m, Base::kWarpGemmIterations0 * warp_idx_k}); -+ this->warp_tile_iterator_B0_.add_tile_offset( -+ {Base::kWarpGemmIterations0 * warp_idx_k, warp_idx_n}); -+ this->warp_tile_iterator_B1_.add_tile_offset( -+ {Base::kWarpGemmIterations1 * warp_idx_k, warp_idx_n}); -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance_0(IteratorA0 &iterator_A0, IteratorB0 &iterator_B0, -+ int group_start_A0 = 0, int group_start_B0 = 0) { -+ iterator_A0.set_iteration_index(group_start_A0 * -+ IteratorA0::kAccessesPerVector); -+ this->smem_iterator_A0_.set_iteration_index(group_start_A0); -+ -+ // Load for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupA0; ++j) { -+ if (group_start_A0 + j < Detail::TBLoadIterationsA0) { -+ typename IteratorA0::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A0_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA0::ThreadMap::kElementsPerAccess / -+ IteratorA0::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA0::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_A0.get(); -+ -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_A0.valid()); -+ -+ ++iterator_A0; -+ } -+ -+ ++this->smem_iterator_A0_; -+ } -+ } -+ -+ iterator_B0.set_iteration_index(group_start_B0 * -+ IteratorB0::kAccessesPerVector); -+ this->smem_iterator_B0_.set_iteration_index(group_start_B0); -+ -+ // Load for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB0; ++j) { -+ if (group_start_B0 + j < Detail::TBLoadIterationsB0) { -+ typename IteratorB0::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B0_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB0::ThreadMap::kElementsPerAccess / -+ IteratorB0::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_B0.get(); -+ -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_B0.valid()); -+ -+ ++iterator_B0; -+ } -+ ++this->smem_iterator_B0_; -+ } -+ } -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance_1(IteratorB1 &iterator_B1, -+ int group_start_B1 = 0) { -+ iterator_B1.set_iteration_index(group_start_B1 * -+ IteratorB1::kAccessesPerVector); -+ this->smem_iterator_B1_.set_iteration_index(group_start_B1); -+ -+ // Load for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) { -+ if (group_start_B1 + j < Detail::TBLoadIterationsB1) { -+ typename IteratorB1::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B1_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB1::ThreadMap::kElementsPerAccess / -+ IteratorB1::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_B1.get(); -+ -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_B1.valid()); -+ -+ ++iterator_B1; -+ } -+ ++this->smem_iterator_B1_; -+ } -+ } -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations_0, -+ ///< destination accumulator tile -+ FragmentC1 &accum, -+ ///< iterator over A0 operand in global memory -+ IteratorA0 iterator_A0, -+ ///< iterator over B0 operand in global memory -+ IteratorB0 iterator_B0, -+ ///< iterator over A1 operand scale vector in global memory -+ IteratorAccumulatorScaleBias iterator_A1_scale, -+ ///< iterator over A1 operand bias vector in global memory -+ IteratorAccumulatorScaleBias iterator_A1_bias, -+ ///< iterator over B1 operand in global memory -+ IteratorB1 iterator_B1, -+ ///< initial value of accumulator -+ FragmentC0 const &src_accum, -+ ///< epilogue operation after 1st Gemm -+ OutputOp output_op_0) -+ { -+ // -+ // Prologue -+ // -+ -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; -+ ++stage, --gemm_k_iterations_0) { -+ -+ iterator_A0.clear_mask(gemm_k_iterations_0 == 0); -+ iterator_B0.clear_mask(gemm_k_iterations_0 == 0); -+ -+ iterator_A0.set_iteration_index(0); -+ this->smem_iterator_A0_.set_iteration_index(0); -+ -+ // Load for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::TBLoadIterationsA0; ++j) { -+ typename IteratorA0::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A0_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA0::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorA0::ThreadMap::kElementsPerAccess / -+ IteratorA0::kAccessesPerVector / 8; -+ -+ int src_bytes = (iterator_A0.valid() ? kSrcBytes : 0); -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_A0.get(), iterator_A0.valid()); -+ -+ ++iterator_A0; -+ } -+ -+ ++this->smem_iterator_A0_; -+ } -+ -+ iterator_B0.set_iteration_index(0); -+ this->smem_iterator_B0_.set_iteration_index(0); -+ -+ // Load for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::TBLoadIterationsB0; ++j) { -+ typename IteratorB0::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B0_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB0::ThreadMap::kElementsPerAccess / -+ IteratorB0::kAccessesPerVector / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_B0.get(), iterator_B0.valid()); -+ -+ ++iterator_B0; -+ } -+ -+ ++this->smem_iterator_B0_; -+ } -+ -+ // Move to the next stage -+ iterator_A0.add_tile_offset({0, 1}); -+ iterator_B0.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_A0_.add_tile_offset({0, 1}); -+ this->smem_iterator_B0_.add_tile_offset({1, 0}); -+ -+ // Defines the boundary of a stage of cp.async. -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // Perform accumulation in the 'd' output operand -+ FragmentC0 accum0 = src_accum; -+ -+ // DEPBAR+SYNC -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA0 warp_loaded_frag_A0[2]; -+ WarpLoadedFragmentB0 warp_loaded_frag_B0[2]; -+ WarpTransformedFragmentA0 warp_transformed_frag_A0[2]; -+ WarpTransformedFragmentB0 warp_transformed_frag_B0[2]; -+ -+ Operator0 warp_mma0; -+ -+ this->warp_tile_iterator_A0_.set_kgroup_index(0); -+ this->warp_tile_iterator_B0_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[0]); -+ this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[0]); -+ -+ ++this->warp_tile_iterator_A0_; -+ ++this->warp_tile_iterator_B0_; -+ -+ iterator_A0.clear_mask(gemm_k_iterations_0 == 0); -+ iterator_B0.clear_mask(gemm_k_iterations_0 == 0); -+ -+ int smem_write_stage_idx = Base::kStages - 1; -+ int smem_read_stage_idx = 0; -+ -+ warp_mma0.transform(warp_transformed_frag_A0[0], warp_transformed_frag_B0[0], -+ warp_loaded_frag_A0[0], warp_loaded_frag_B0[0]); -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations_0 > (-Base::kStages + 1);) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; -+ ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. -+ -+ this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); -+ this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); -+ -+ this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A0_; -+ ++this->warp_tile_iterator_B0_; -+ -+ if (warp_mma_k > 0) -+ warp_mma0.transform(warp_transformed_frag_A0[warp_mma_k % 2], -+ warp_transformed_frag_B0[warp_mma_k % 2], -+ warp_loaded_frag_A0[warp_mma_k % 2], -+ warp_loaded_frag_B0[warp_mma_k % 2]); -+ -+ warp_mma0( -+ accum0, -+ warp_transformed_frag_A0[warp_mma_k % 2], -+ warp_transformed_frag_B0[warp_mma_k % 2], -+ accum0 -+ ); -+ -+ // Issue global->shared copies for the this stage -+ if (warp_mma_k < Base::kWarpGemmIterations0 - 1) { -+ int group_start_iteration_A0, group_start_iteration_B0; -+ -+ group_start_iteration_A0 = warp_mma_k * Detail::kAccessesPerGroupA0; -+ group_start_iteration_B0 = warp_mma_k * Detail::kAccessesPerGroupB0; -+ -+ copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0, -+ group_start_iteration_B0); -+ } -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations0) { -+ int group_start_iteration_A0, group_start_iteration_B0; -+ group_start_iteration_A0 = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupA0; -+ group_start_iteration_B0 = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB0; -+ -+ copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0, -+ group_start_iteration_B0); -+ -+ // Inserts a memory fence between stages of cp.async instructions. -+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages have committed. -+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_A0.add_tile_offset({0, 1}); -+ iterator_B0.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_A0_.add_tile_offset({0, 1}); -+ this->smem_iterator_B0_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_A0_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_A0_.add_tile_offset( -+ {0, -Base::kStages * Policy0::kPartitionsK * -+ Base::kWarpGemmIterations0}); -+ this->warp_tile_iterator_B0_.add_tile_offset( -+ {-Base::kStages * Policy0::kPartitionsK * -+ Base::kWarpGemmIterations0, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ --gemm_k_iterations_0; -+ iterator_A0.clear_mask(gemm_k_iterations_0 == 0); -+ iterator_B0.clear_mask(gemm_k_iterations_0 == 0); -+ } -+ -+ // Do any conversions feeding the first stage at the end of the loop so -+ // we can start right away on mma instructions -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations0) -+ warp_mma0.transform(warp_transformed_frag_A0[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B0[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A0[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); -+ } -+ -+ } -+ -+ // 2nd Gemm -+ -+ /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile -+ FragmentIteratorA1 warp_tile_iterator_A1_(accum0); -+ FragmentA1ScaleBias tb_frag_A1_scale; -+ FragmentA1ScaleBias tb_frag_A1_bias; -+ FragmentIteratorA1ScaleBias warp_tile_iterator_A1_scale_(tb_frag_A1_scale); -+ FragmentIteratorA1ScaleBias warp_tile_iterator_A1_bias_(tb_frag_A1_bias); -+ -+ if(PerChannelScale) { -+ tb_frag_A1_scale.clear(); -+ iterator_A1_scale.load(tb_frag_A1_scale); -+ ++iterator_A1_scale; -+ } -+ tb_frag_A1_bias.clear(); -+ iterator_A1_bias.load(tb_frag_A1_bias); -+ ++iterator_A1_bias; -+ -+ // -+ // Prologue -+ // -+ int gemm_k_iterations_1 = (FragmentIteratorA1::Policy::kIterations + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1; -+ -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; -+ ++stage, --gemm_k_iterations_1) { -+ -+ iterator_B1.clear_mask(gemm_k_iterations_1 == 0); -+ -+ iterator_B1.set_iteration_index(0); -+ this->smem_iterator_B1_.set_iteration_index(0); -+ -+ // Load for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::TBLoadIterationsB1; ++j) { -+ typename IteratorB1::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B1_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB1::ThreadMap::kElementsPerAccess / -+ IteratorB1::kAccessesPerVector / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_B1.get(), iterator_B1.valid()); -+ -+ ++iterator_B1; -+ } -+ -+ ++this->smem_iterator_B1_; -+ } -+ -+ // Move to the next stage -+ iterator_B1.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_B1_.add_tile_offset({1, 0}); -+ -+ // Defines the boundary of a stage of cp.async. -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // DEPBAR+SYNC -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA1 warp_loaded_frag_A1[2]; -+ WarpLoadedFragmentA1ScaleBias warp_loaded_frag_A1_scale[2]; -+ WarpLoadedFragmentA1ScaleBias warp_loaded_frag_A1_bias[2]; -+ WarpLoadedFragmentB1 warp_loaded_frag_B1[2]; -+ WarpTransformedFragmentA1 warp_transformed_frag_A1[2]; -+ WarpTransformedFragmentB1 warp_transformed_frag_B1[2]; -+ -+ Operator1 warp_mma1; -+ -+ if(PerChannelScale) { -+ warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[0]); -+ ++warp_tile_iterator_A1_scale_; -+ } -+ warp_tile_iterator_A1_bias_.load(warp_loaded_frag_A1_bias[0]); -+ ++warp_tile_iterator_A1_bias_; -+ -+ warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0], -+ warp_loaded_frag_A1_scale[0], -+ warp_loaded_frag_A1_bias[0], -+ output_op_0); -+ ++warp_tile_iterator_A1_; -+ -+ this->warp_tile_iterator_B1_.set_kgroup_index(0); -+ this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]); -+ ++this->warp_tile_iterator_B1_; -+ -+ iterator_B1.clear_mask(gemm_k_iterations_1 == 0); -+ -+ smem_write_stage_idx = Base::kStages - 1; -+ smem_read_stage_idx = 0; -+ -+ warp_mma1.transform(warp_transformed_frag_A1[0], warp_transformed_frag_B1[0], -+ warp_loaded_frag_A1[0], warp_loaded_frag_B1[0]); -+ -+ // -+ // Mainloop -+ // -+ -+ gemm_k_iterations_1 = (FragmentIteratorA1::Policy::kIterations + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1 - (Base::kStages - 1); -+ CUTLASS_PRAGMA_UNROLL -+ for (; gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; -+ ++warp_mma_k) { -+ -+ // Load threadblock-level scale/bias vector from global memory -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations1) { -+ if(PerChannelScale) { -+ tb_frag_A1_scale.clear(); -+ iterator_A1_scale.load(tb_frag_A1_scale); -+ ++iterator_A1_scale; -+ } -+ tb_frag_A1_bias.clear(); -+ iterator_A1_bias.load(tb_frag_A1_bias); -+ ++iterator_A1_bias; -+ } -+ -+ // Load warp-level scale bias fragment from threadblock scale/bias vector -+ if(PerChannelScale) { -+ warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]); -+ ++warp_tile_iterator_A1_scale_; -+ } -+ warp_tile_iterator_A1_bias_.load(warp_loaded_frag_A1_bias[(warp_mma_k + 1) % 2]); -+ ++warp_tile_iterator_A1_bias_; -+ -+ // Load warp-level tile from accumulator fragment -+ warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A1_bias[(warp_mma_k + 1) % 2], -+ output_op_0); -+ ++warp_tile_iterator_A1_; -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. -+ this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); -+ this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); -+ ++this->warp_tile_iterator_B1_; -+ -+ if (warp_mma_k > 0) -+ warp_mma1.transform(warp_transformed_frag_A1[warp_mma_k % 2], -+ warp_transformed_frag_B1[warp_mma_k % 2], -+ warp_loaded_frag_A1[warp_mma_k % 2], -+ warp_loaded_frag_B1[warp_mma_k % 2]); -+ -+ -+ warp_mma1( -+ accum, -+ warp_transformed_frag_A1[warp_mma_k % 2], -+ warp_transformed_frag_B1[warp_mma_k % 2], -+ accum -+ ); -+ -+ // Issue global->shared copies for the this stage -+ if (warp_mma_k < Base::kWarpGemmIterations1 - 1) { -+ int group_start_iteration_B1; -+ -+ group_start_iteration_B1 = warp_mma_k * Detail::kAccessesPerGroupB1; -+ -+ copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); -+ } -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations1) { -+ int group_start_iteration_B1; -+ group_start_iteration_B1 = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB1; -+ -+ copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); -+ -+ // Inserts a memory fence between stages of cp.async instructions. -+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages have committed. -+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_B1.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_B1_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_B1_.add_tile_offset( -+ {-Base::kStages * Policy1::kPartitionsK * -+ Base::kWarpGemmIterations1, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ iterator_B1.clear_mask(gemm_k_iterations_1 == 1); -+ } -+ -+ // Do any conversions feeding the first stage at the end of the loop so -+ // we can start right away on mma instructions -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations1) -+ warp_mma1.transform(warp_transformed_frag_A1[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B1[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A1[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); -+ } -+ -+ } -+ -+ -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h -new file mode 100644 -index 0000000..7f42d52 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h -@@ -0,0 +1,869 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" -+ -+#include "threadblock/b2b_mma_base_smem_accumulator.h" -+#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape0_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA0_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA0_, -+ /// Cache operation for operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA0, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB0_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB0_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB0, -+ /// Iterates over vectors of scale and bias vector in global memory -+ // (concept: VectorIterator) -+ typename IteratorAccumulatorScaleBias_, -+ /// Iterates over accumulator tile -+ typename FragmentIteratorAccumulator_, -+ /// Iterates over accumulator tile in shared memory -+ typename SmemIteratorD0_, -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape1_, -+ /// Iterates over the intermediate accumulator tile in shared memory -+ typename WarpIteratorA1_, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB1_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB1_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB1, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...) -+ typename OutputOp_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy0_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy1_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class B2bMmaMultistageSmemAccumulator : -+ public gemm::threadblock::B2bMmaBaseSmemAccumulator { -+public: -+ ///< Base class -+ using Base = gemm::threadblock::B2bMmaBaseSmemAccumulator; -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape0 = Shape0_; -+ ///< Iterates over tiles of A operand in global memory -+ using IteratorA0 = IteratorA0_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB0 = IteratorB0_; -+ ///< Iterates over tiles of the scale and bias vectors in global memory -+ using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; -+ ///< Policy describing tuning details -+ using Policy0 = Policy0_; -+ -+ using SmemIteratorA0 = SmemIteratorA0_; -+ using SmemIteratorB0 = SmemIteratorB0_; -+ using SmemIteratorD0 = SmemIteratorD0_; ///< Iterates over accumulator tile in shared memory -+ -+ using FragmentIteratorAccumulator = FragmentIteratorAccumulator_; ///< Iterates over accumulator tile -+ -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape1 = Shape1_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB1 = IteratorB1_; -+ ///< Policy describing tuning details -+ using Policy1 = Policy1_; -+ -+ using SmemIteratorB1 = SmemIteratorB1_; -+ using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate accumulator tile in shared memory -+ -+ ///< Data type of accumulator matrix -+ using ElementC = ElementC_; -+ ///< Layout of accumulator matrix -+ using LayoutC = LayoutC_; -+ -+ ///< Epilogue after 1st Gemm -+ using OutputOp = OutputOp_; -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA0 = CacheOpA0; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB0 = CacheOpB0; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of accumulator tile -+ using FragmentC0 = typename Policy0::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator0 = typename Policy0::Operator; -+ -+ /// Fragment of Scale and Bias loaded from global memory -+ using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment; -+ -+ /// Fragment of accumulator tile -+ using FragmentC1 = typename Policy1::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator1 = typename Policy1::Operator; -+ -+ /// Epilog in shared memory -+ using Epilogue0 = epilogue::threadblock::EpilogueSmemAccumulator< -+ SmemIteratorD0, ///< SmemTileIterator -+ FragmentIteratorAccumulator, ///< AccumulatorFragmentIterator -+ IteratorAccumulatorScaleBias, ///< ScaleBiasIterator -+ OutputOp>; ///< Output operator -+ -+ /// Minimum architecture is Sm80 to support cp.async -+ using ArchTag = arch::Sm80; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA0 = Operator0::kTransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB0 = Operator0::kTransformB; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB1 = Operator1::kTransformB; -+ -+ /// Internal structure exposed for introspection. -+ struct Detail { -+ -+ static_assert(Base::kWarpGemmIterations0 > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ static_assert(Base::kWarpGemmIterations1 > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ -+ /// Number of cp.async instructions to load one stage of operand A -+ static int const TBLoadIterationsA0 = -+ IteratorA0::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const TBLoadIterationsB0 = -+ IteratorB0::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const TBLoadIterationsB1 = -+ IteratorB1::ThreadMap::Iterations::kCount; -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Number of cp.async instructions to load on group of operand A -+ static int const kAccessesPerGroupA0 = -+ (TBLoadIterationsA0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB0 = -+ (TBLoadIterationsB0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB1 = -+ (TBLoadIterationsB1 + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1; -+ }; -+ -+ private: -+ -+ using WarpLoadedFragmentA0 = typename Operator0::FragmentA; -+ using WarpLoadedFragmentB0 = typename Operator0::FragmentB; -+ using WarpLoadedFragmentA1 = typename Operator1::FragmentA; -+ using WarpLoadedFragmentB1 = typename Operator1::FragmentB; -+ using WarpTransformedFragmentA0 = typename Operator0::TransformedFragmentA; -+ using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB; -+ using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA; -+ using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA0 smem_iterator_A0_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB0 smem_iterator_B0_; -+ -+ /// Shared Memory Iterator to store accumulator tile -+ SmemIteratorD0 smem_iterator_D0_; -+ -+ /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile -+ WarpIteratorA1 warp_tile_iterator_A1_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB1 smem_iterator_B1_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ B2bMmaMultistageSmemAccumulator( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::B2bMmaSharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx, -+ ///< GEMM0 N is used for accumulator extent -+ int problem_size_0_n -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_A_ref(), thread_idx), -+ smem_iterator_B0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_B_ref(), thread_idx), -+ smem_iterator_D0_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx), -+ warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), {Base::WarpGemm1::kM, problem_size_0_n}, lane_idx ), -+ smem_iterator_B1_(shared_storage.b2b_mma_shared_storage.shared_storage1.operand_B_ref(), thread_idx) -+ { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn_0 = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN); -+ int warp_idx_k_0 = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN); -+ -+ int warp_idx_m_0 = warp_idx_mn_0 % Base::WarpCount0::kM; -+ int warp_idx_n_0 = warp_idx_mn_0 / Base::WarpCount0::kM; -+ -+ int warp_idx_mn_1 = warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN); -+ int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); -+ -+ int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; -+ int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A0_.add_tile_offset( -+ {warp_idx_m_0, Base::kWarpGemmIterations0 * warp_idx_k_0}); -+ this->warp_tile_iterator_B0_.add_tile_offset( -+ {Base::kWarpGemmIterations0 * warp_idx_k_0, warp_idx_n_0}); -+ warp_tile_iterator_A1_.add_tile_offset( -+ {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); -+ this->warp_tile_iterator_B1_.add_tile_offset( -+ {Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1}); -+ -+ // Add smem accumulator iterator warp offset -+ smem_iterator_D0_.add_tile_offset({ warp_idx_m_0 * SmemIteratorD0::TileIterations::kRow, -+ warp_idx_n_0 * SmemIteratorD0::TileIterations::kColumn}); -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance_0(IteratorA0 &iterator_A0, IteratorB0 &iterator_B0, -+ int group_start_A0 = 0, int group_start_B0 = 0) { -+ iterator_A0.set_iteration_index(group_start_A0 * -+ IteratorA0::kAccessesPerVector); -+ this->smem_iterator_A0_.set_iteration_index(group_start_A0); -+ -+ // cp.async for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupA0; ++j) { -+ if (group_start_A0 + j < Detail::TBLoadIterationsA0) { -+ typename IteratorA0::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A0_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA0::ThreadMap::kElementsPerAccess / -+ IteratorA0::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA0::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_A0.get(); -+ -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_A0.valid()); -+ -+ ++iterator_A0; -+ } -+ -+ ++this->smem_iterator_A0_; -+ } -+ } -+ -+ iterator_B0.set_iteration_index(group_start_B0 * -+ IteratorB0::kAccessesPerVector); -+ this->smem_iterator_B0_.set_iteration_index(group_start_B0); -+ -+ // cp.async for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB0; ++j) { -+ if (group_start_B0 + j < Detail::TBLoadIterationsB0) { -+ typename IteratorB0::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B0_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB0::ThreadMap::kElementsPerAccess / -+ IteratorB0::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_B0.get(); -+ -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_B0.valid()); -+ -+ ++iterator_B0; -+ } -+ ++this->smem_iterator_B0_; -+ } -+ } -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance_1(IteratorB1 &iterator_B1, -+ int group_start_B1 = 0) { -+ iterator_B1.set_iteration_index(group_start_B1 * -+ IteratorB1::kAccessesPerVector); -+ this->smem_iterator_B1_.set_iteration_index(group_start_B1); -+ -+ // cp.async for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) { -+ if (group_start_B1 + j < Detail::TBLoadIterationsB1) { -+ typename IteratorB1::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B1_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB1::ThreadMap::kElementsPerAccess / -+ IteratorB1::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_B1.get(); -+ -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_B1.valid()); -+ -+ ++iterator_B1; -+ } -+ ++this->smem_iterator_B1_; -+ } -+ } -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations_0, -+ ///< destination accumulator tile -+ FragmentC1 &accum, -+ ///< iterator over A0 operand in global memory -+ IteratorA0 iterator_A0, -+ ///< iterator over B0 operand in global memory -+ IteratorB0 iterator_B0, -+ ///< iterator over A1 operand scale vector in global memory -+ IteratorAccumulatorScaleBias iterator_accum0_scale, -+ ///< iterator over A1 operand bias vector in global memory -+ IteratorAccumulatorScaleBias iterator_accum0_bias, -+ ///< iterator over B1 operand in global memory -+ IteratorB1 iterator_B1, -+ ///< initial value of accumulator -+ FragmentC0 const &src_accum, -+ ///< epilogue operation after 1st Gemm -+ OutputOp output_op_0) -+ { -+ // -+ // Prologue -+ // -+ -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; -+ ++stage, --gemm_k_iterations_0) { -+ -+ iterator_A0.clear_mask(gemm_k_iterations_0 == 0); -+ iterator_B0.clear_mask(gemm_k_iterations_0 == 0); -+ -+ iterator_A0.set_iteration_index(0); -+ this->smem_iterator_A0_.set_iteration_index(0); -+ -+ // cp.async for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::TBLoadIterationsA0; ++j) { -+ typename IteratorA0::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A0_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA0::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorA0::ThreadMap::kElementsPerAccess / -+ IteratorA0::kAccessesPerVector / 8; -+ -+ int src_bytes = (iterator_A0.valid() ? kSrcBytes : 0); -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_A0.get(), iterator_A0.valid()); -+ -+ ++iterator_A0; -+ } -+ -+ ++this->smem_iterator_A0_; -+ } -+ -+ iterator_B0.set_iteration_index(0); -+ this->smem_iterator_B0_.set_iteration_index(0); -+ -+ // cp.async for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::TBLoadIterationsB0; ++j) { -+ typename IteratorB0::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B0_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB0::ThreadMap::kElementsPerAccess / -+ IteratorB0::kAccessesPerVector / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_B0.get(), iterator_B0.valid()); -+ -+ ++iterator_B0; -+ } -+ -+ ++this->smem_iterator_B0_; -+ } -+ -+ // Move to the next stage -+ iterator_A0.add_tile_offset({0, 1}); -+ iterator_B0.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_A0_.add_tile_offset({0, 1}); -+ this->smem_iterator_B0_.add_tile_offset({1, 0}); -+ -+ // Defines the boundary of a stage of cp.async. -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // Perform accumulation in the 'd' output operand -+ FragmentC0 accum0 = src_accum; -+ -+ // DEPBAR+SYNC -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA0 warp_loaded_frag_A0[2]; -+ WarpLoadedFragmentB0 warp_loaded_frag_B0[2]; -+ WarpTransformedFragmentA0 warp_transformed_frag_A0[2]; -+ WarpTransformedFragmentB0 warp_transformed_frag_B0[2]; -+ -+ Operator0 warp_mma0; -+ -+ this->warp_tile_iterator_A0_.set_kgroup_index(0); -+ this->warp_tile_iterator_B0_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[0]); -+ this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[0]); -+ -+ ++this->warp_tile_iterator_A0_; -+ ++this->warp_tile_iterator_B0_; -+ -+ iterator_A0.clear_mask(gemm_k_iterations_0 == 0); -+ iterator_B0.clear_mask(gemm_k_iterations_0 == 0); -+ -+ int smem_write_stage_idx = Base::kStages - 1; -+ int smem_read_stage_idx = 0; -+ -+ warp_mma0.transform(warp_transformed_frag_A0[0], warp_transformed_frag_B0[0], -+ warp_loaded_frag_A0[0], warp_loaded_frag_B0[0]); -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations_0 > (-Base::kStages + 1);) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; -+ ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. -+ -+ this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); -+ this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); -+ -+ this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A0_; -+ ++this->warp_tile_iterator_B0_; -+ -+ if (warp_mma_k > 0) -+ warp_mma0.transform(warp_transformed_frag_A0[warp_mma_k % 2], -+ warp_transformed_frag_B0[warp_mma_k % 2], -+ warp_loaded_frag_A0[warp_mma_k % 2], -+ warp_loaded_frag_B0[warp_mma_k % 2]); -+ -+ warp_mma0( -+ accum0, -+ warp_transformed_frag_A0[warp_mma_k % 2], -+ warp_transformed_frag_B0[warp_mma_k % 2], -+ accum0 -+ ); -+ -+ // Issue global->shared copies for the this stage -+ if (warp_mma_k < Base::kWarpGemmIterations0 - 1) { -+ int group_start_iteration_A0, group_start_iteration_B0; -+ -+ group_start_iteration_A0 = warp_mma_k * Detail::kAccessesPerGroupA0; -+ group_start_iteration_B0 = warp_mma_k * Detail::kAccessesPerGroupB0; -+ -+ copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0, -+ group_start_iteration_B0); -+ } -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations0) { -+ int group_start_iteration_A0, group_start_iteration_B0; -+ group_start_iteration_A0 = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupA0; -+ group_start_iteration_B0 = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB0; -+ -+ copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0, -+ group_start_iteration_B0); -+ -+ // Inserts a memory fence between stages of cp.async instructions. -+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages have committed. -+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_A0.add_tile_offset({0, 1}); -+ iterator_B0.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_A0_.add_tile_offset({0, 1}); -+ this->smem_iterator_B0_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_A0_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_A0_.add_tile_offset( -+ {0, -Base::kStages * Policy0::kPartitionsK * -+ Base::kWarpGemmIterations0}); -+ this->warp_tile_iterator_B0_.add_tile_offset( -+ {-Base::kStages * Policy0::kPartitionsK * -+ Base::kWarpGemmIterations0, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ --gemm_k_iterations_0; -+ iterator_A0.clear_mask(gemm_k_iterations_0 == 0); -+ iterator_B0.clear_mask(gemm_k_iterations_0 == 0); -+ } -+ -+ // Do any conversions feeding the first stage at the end of the loop so -+ // we can start right away on mma instructions -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations0) -+ warp_mma0.transform(warp_transformed_frag_A0[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B0[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A0[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); -+ } -+ -+ } -+ -+ /// Epilogue for the first Implicit Gemm -+ Epilogue0 epilogue0; -+ -+ epilogue0(output_op_0, smem_iterator_D0_, accum0, iterator_accum0_scale, iterator_accum0_bias); -+ -+ __syncthreads(); -+ -+ -+ // 2nd Gemm -+ -+ // -+ // Prologue -+ // -+ int gemm_k_iterations_1 = Shape0::kN / Shape1::kK; -+ -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; -+ ++stage, --gemm_k_iterations_1) { -+ -+ iterator_B1.clear_mask(gemm_k_iterations_1 == 0); -+ -+ iterator_B1.set_iteration_index(0); -+ this->smem_iterator_B1_.set_iteration_index(0); -+ -+ // cp.async for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::TBLoadIterationsB1; ++j) { -+ typename IteratorB1::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B1_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB1::ThreadMap::kElementsPerAccess / -+ IteratorB1::kAccessesPerVector / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_B1.get(), iterator_B1.valid()); -+ -+ ++iterator_B1; -+ } -+ -+ ++this->smem_iterator_B1_; -+ } -+ -+ // Move to the next stage -+ iterator_B1.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_B1_.add_tile_offset({1, 0}); -+ -+ // Defines the boundary of a stage of cp.async. -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // DEPBAR+SYNC -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA1 warp_loaded_frag_A1[2]; -+ WarpLoadedFragmentB1 warp_loaded_frag_B1[2]; -+ WarpTransformedFragmentA1 warp_transformed_frag_A1[2]; -+ WarpTransformedFragmentB1 warp_transformed_frag_B1[2]; -+ -+ Operator1 warp_mma1; -+ -+ warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0]); -+ ++warp_tile_iterator_A1_; -+ -+ this->warp_tile_iterator_B1_.set_kgroup_index(0); -+ this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]); -+ ++this->warp_tile_iterator_B1_; -+ -+ iterator_B1.clear_mask(gemm_k_iterations_1 == 0); -+ -+ smem_write_stage_idx = Base::kStages - 1; -+ smem_read_stage_idx = 0; -+ -+ warp_mma1.transform(warp_transformed_frag_A1[0], warp_transformed_frag_B1[0], -+ warp_loaded_frag_A1[0], warp_loaded_frag_B1[0]); -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for ( gemm_k_iterations_1 = Shape0::kN / Shape1::kK - (Base::kStages - 1); -+ gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; -+ ++warp_mma_k) { -+ -+ // Load warp-level tile from accumulator fragment -+ // skip warp tile loading for the last kgroup -+ if(gemm_k_iterations_1 > (-Base::kStages + 2) || warp_mma_k < Base::kWarpGemmIterations1 - 1) { -+ warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2]); -+ } -+ ++warp_tile_iterator_A1_; -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. -+ this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); -+ this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); -+ ++this->warp_tile_iterator_B1_; -+ -+ -+ if (warp_mma_k > 0) -+ warp_mma1.transform(warp_transformed_frag_A1[warp_mma_k % 2], -+ warp_transformed_frag_B1[warp_mma_k % 2], -+ warp_loaded_frag_A1[warp_mma_k % 2], -+ warp_loaded_frag_B1[warp_mma_k % 2]); -+ -+ -+ warp_mma1( -+ accum, -+ warp_transformed_frag_A1[warp_mma_k % 2], -+ warp_transformed_frag_B1[warp_mma_k % 2], -+ accum -+ ); -+ -+ // Issue global->shared copies for the this stage -+ if (warp_mma_k < Base::kWarpGemmIterations1 - 1) { -+ int group_start_iteration_B1; -+ -+ group_start_iteration_B1 = warp_mma_k * Detail::kAccessesPerGroupB1; -+ -+ copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); -+ } -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations1) { -+ int group_start_iteration_B1; -+ group_start_iteration_B1 = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB1; -+ -+ copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); -+ -+ // Inserts a memory fence between stages of cp.async instructions. -+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages have committed. -+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_B1.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_B1_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_B1_.add_tile_offset( -+ {-Base::kStages * Policy1::kPartitionsK * -+ Base::kWarpGemmIterations1, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ iterator_B1.clear_mask(gemm_k_iterations_1 == 1); -+ } -+ -+ // Do any conversions feeding the first stage at the end of the loop so -+ // we can start right away on mma instructions -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations1) -+ warp_mma1.transform(warp_transformed_frag_A1[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B1[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A1[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); -+ } -+ -+ } -+ -+ -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined.h -new file mode 100644 -index 0000000..c36d133 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined.h -@@ -0,0 +1,554 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped Back-to-back fused GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/numeric_conversion.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" -+ -+#include "threadblock/b2b_mma_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape0_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorA0_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA0_, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorB0_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB0_, -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape1_, -+ /// Iterates over the intermediate accumulator tile -+ // (concept::MmaTensorOpFragmentIterator) -+ typename FragmentIteratorA1_, -+ /// Iterates over vectors of scale and bias vector in global memory -+ // (concept: VectorIterator) -+ typename IteratorAccumulatorScaleBias_, -+ /// FragmentIterator to load Scale or Bias vector from threadblock fragment -+ typename FragmentIteratorA1ScaleBias_, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorB1_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB1_, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...) -+ typename OutputOp_, -+ /// Policy describing tuning details (concept: MmaPipelinedPolicy) -+ typename Policy0_, -+ /// Policy describing tuning details (concept: MmaPipelinedPolicy) -+ typename Policy1_, -+ /// Transformation applied to A0 operand -+ typename TransformA0_ = NumericArrayConverter< -+ typename SmemIteratorA0_::Element, -+ typename IteratorA0_::Element, -+ IteratorA0_::Fragment::kElements>, -+ /// -+ /// Transformation applied to B0 operand -+ typename TransformB0_ = NumericArrayConverter< -+ typename SmemIteratorB0_::Element, -+ typename IteratorB0_::Element, -+ IteratorB0_::Fragment::kElements>, -+ /// -+ /// Transformation applied to B1 operand -+ typename TransformB1_ = NumericArrayConverter< -+ typename SmemIteratorB1_::Element, -+ typename IteratorB1_::Element, -+ IteratorB1_::Fragment::kElements>, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+class B2bMmaPipelined : -+ public B2bMmaBase { -+public: -+ -+ ///< Base class -+ using Base = B2bMmaBase; -+ -+ using Shape0 = Shape0_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using IteratorA0 = IteratorA0_; ///< Iterates over tiles of A operand in global memory -+ using IteratorB0 = IteratorB0_; ///< Iterates over tiles of B operand in global memory -+ using Policy0 = Policy0_; ///< Policy describing tuning details -+ -+ using SmemIteratorA0 = SmemIteratorA0_; -+ using SmemIteratorB0 = SmemIteratorB0_; -+ -+ using Shape1 = Shape1_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using FragmentIteratorA1 = FragmentIteratorA1_; ///< Iterates over intermediate accumulator tile -+ using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; ///< Iterates over tiles of the scale and bias vectors in global memory -+ using FragmentIteratorA1ScaleBias = -+ FragmentIteratorA1ScaleBias_; ///< WarpIterator to load Scale or Bias vector from the threadblock fragment -+ using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory -+ using Policy1 = Policy1_; ///< Policy describing tuning details -+ -+ using SmemIteratorB1 = SmemIteratorB1_; -+ -+ -+ using ElementC = ElementC_; ///< Data type of accumulator matrix -+ using LayoutC = LayoutC_; ///< Layout of accumulator matrix -+ -+ using OutputOp = OutputOp_; ///< Epilogue after 1st Gemm -+ -+ static const bool PerChannelScale = (OutputOp::kScale == -+ epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling); -+ -+ using TransformA0 = TransformA0_; -+ using TransformB0 = TransformB0_; -+ using TransformB1 = TransformB1_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of operand A loaded from global memory -+ using FragmentA0 = typename IteratorA0::Fragment; -+ -+ /// Fragment of operand B loaded from global memory -+ using FragmentB0 = typename IteratorB0::Fragment; -+ -+ /// Fragment of accumulator tile -+ using FragmentC0 = typename Policy0::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator0 = typename Policy0::Operator; -+ -+ /// Fragment of Scale and Bias loaded from global memory -+ using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment; -+ -+ /// Fragment of operand B loaded from global memory -+ using FragmentB1 = typename IteratorB1::Fragment; -+ -+ /// Fragment of accumulator tile -+ using FragmentC1 = typename Policy1::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator1 = typename Policy1::Operator; -+ -+ /// Obtain the arch tag from the warp-level operator -+ using ArchTag = typename Policy0::Operator::ArchTag; -+ -+ /// Complex transform on A0 operand -+ static ComplexTransform const kTransformA0 = Operator0::kTransformA; -+ -+ /// Complex transform on B0 operand -+ static ComplexTransform const kTransformB0 = Operator0::kTransformB; -+ -+ /// Complex transform on B1 operand -+ static ComplexTransform const kTransformB1 = Operator1::kTransformB; -+ -+ /// staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) -+ static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); -+ -+private: -+ -+ using WarpFragmentA0 = typename Operator0::FragmentA; -+ using WarpFragmentB0 = typename Operator0::FragmentB; -+ /// Warp Fragment of operand A1 loaded from accmulator tile -+ using WarpFragmentA1 = typename FragmentIteratorA1::Fragment; -+ /// Warp Fragment of operand A1 scale and bias loaded from threadblock fragment -+ using WarpFragmentA1ScaleBias = -+ typename FragmentIteratorA1ScaleBias::Fragment; -+ using WarpFragmentB1 = typename Operator1::FragmentB; -+ -+protected: -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA0 smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B0 operand to shared memory -+ SmemIteratorB0 smem_iterator_B0_; -+ -+ /// Iterator to write threadblock-scoped tile of B1 operand to shared memory -+ SmemIteratorB1 smem_iterator_B1_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ B2bMmaPipelined( -+ typename Base::B2bMmaSharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ int thread_idx, ///< ID within the threadblock -+ int warp_idx, ///< ID of warp -+ int lane_idx, ///< ID of each thread within a warp -+ int problem_size_0_n ///< GEMM0 N is used for accumulator extent -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.shared_storage0.operand_A_ref(), thread_idx), -+ smem_iterator_B0_(shared_storage.shared_storage0.operand_B_ref(), thread_idx), -+ smem_iterator_B1_(shared_storage.shared_storage1.operand_B_ref(), thread_idx) { -+ -+ -+ // Compute warp location within threadblock tile by mapping the warp_id to three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ //These should stay the same across different GEMM layers -+ int warp_idx_mn = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount0::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount0::kM; -+ -+ //These may change across different GEMM layers -+ int tile_offset_k_0 = Base::kWarpGemmIterations0 * warp_idx_k; -+ int tile_offset_k_1 = Base::kWarpGemmIterations1 * warp_idx_k; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A0_.add_tile_offset({warp_idx_m, tile_offset_k_0}); -+ this->warp_tile_iterator_B0_.add_tile_offset({tile_offset_k_0, warp_idx_n}); -+ this->warp_tile_iterator_B1_.add_tile_offset({tile_offset_k_1, warp_idx_n}); -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ int gemm_k_iterations_0, ///< number of iterations of the mainloop -+ FragmentC1 &accum, ///< destination accumulator tile -+ IteratorA0 iterator_A, ///< iterator over A operand in global memory -+ IteratorB0 iterator_B0, ///< iterator over B0 operand in global memory -+ IteratorAccumulatorScaleBias iterator_A1_scale, ///< iterator over A1 operand scale vectors in global memory -+ IteratorAccumulatorScaleBias iterator_A1_bias, ///< iterator over A1 operand bias vectors in global memory -+ IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory -+ FragmentC0 const &src_accum, ///< source accumualtor tile -+ OutputOp output_op_0, ///< epilogue operation after 1st Gemm -+ TransformA0 transform_A0 = TransformA0(), ///< transformation applied to A0 fragment -+ TransformB0 transform_B0 = TransformB0(), ///< transformation applied to B0 fragment -+ TransformB1 transform_B1 = TransformB1()) { ///< transformation applied to B1 fragment -+ -+ // -+ // Prologue -+ // -+ -+ // Perform accumulation in the 'd' output operand -+ FragmentC0 accum0 = src_accum; -+ -+ FragmentA0 tb_frag_A; -+ FragmentB0 tb_frag_B0; -+ -+ tb_frag_A.clear(); -+ tb_frag_B0.clear(); -+ -+ // The last kblock is loaded in the prolog -+ iterator_A.load(tb_frag_A); -+ iterator_B0.load(tb_frag_B0); -+ -+ ++iterator_A; -+ ++iterator_B0; -+ -+ this->smem_iterator_A_.store(transform_A0(tb_frag_A)); -+ this->smem_iterator_B0_.store(transform_B0(tb_frag_B0)); -+ -+ ++this->smem_iterator_A_; -+ ++this->smem_iterator_B0_; -+ -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math instructions -+ WarpFragmentA0 warp_frag_A0[2]; -+ WarpFragmentB0 warp_frag_B0[2]; -+ -+ this->warp_tile_iterator_A0_.set_kgroup_index(0); -+ this->warp_tile_iterator_B0_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A0_.load(warp_frag_A0[0]); -+ this->warp_tile_iterator_B0_.load(warp_frag_B0[0]); -+ -+ ++this->warp_tile_iterator_A0_; -+ ++this->warp_tile_iterator_B0_; -+ -+ Operator0 warp_mma0; -+ -+ int smem_write_stage_idx = 1; -+ -+ // Avoid reading out of bounds -+ iterator_A.clear_mask(gemm_k_iterations_0 <= 1); -+ iterator_B0.clear_mask(gemm_k_iterations_0 <= 1); -+ -+ // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing -+ // shared memory loads (which have the tighest latency requirement). -+ -+ // -+ // Mainloop -+ // -+ -+ // Note: The main loop does not support Base::kWarpGemmIterations == 2. -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group -+ // as the case may be. -+ -+ if (warp_mma_k == Base::kWarpGemmIterations0 - 1) { -+ -+ // Write fragments to shared memory -+ this->smem_iterator_A_.store(transform_A0(tb_frag_A)); -+ -+ this->smem_iterator_B0_.store(transform_B0(tb_frag_B0)); -+ -+ __syncthreads(); -+ -+ ++this->smem_iterator_A_; -+ ++this->smem_iterator_B0_; -+ -+ // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory -+ if (smem_write_stage_idx == 1) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); -+ } -+ else { -+ this->warp_tile_iterator_A0_.add_tile_offset( -+ {0, -Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0}); -+ this->warp_tile_iterator_B0_.add_tile_offset( -+ {-Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0, -+ 0}); -+ } -+ -+ smem_write_stage_idx ^= 1; -+ } -+ -+ this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); -+ this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); -+ -+ this->warp_tile_iterator_A0_.load(warp_frag_A0[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B0_.load(warp_frag_B0[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A0_; -+ ++this->warp_tile_iterator_B0_; -+ -+ if (warp_mma_k == 0) { -+ -+ iterator_A.load(tb_frag_A); -+ iterator_B0.load(tb_frag_B0); -+ ++iterator_A; -+ ++iterator_B0; -+ -+ // Avoid reading out of bounds if this was the last loop iteration -+ iterator_A.clear_mask(gemm_k_iterations_0 <= 2); -+ iterator_B0.clear_mask(gemm_k_iterations_0 <= 2); -+ } -+ -+ warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2], -+ warp_frag_B0[warp_mma_k % 2], accum0); -+ } -+ } -+ -+ //2nd Gemm -+ -+ /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile -+ FragmentIteratorA1 warp_tile_iterator_A1_(accum0); -+ -+ // -+ // Prologue -+ // -+ -+ FragmentA1ScaleBias tb_frag_A1_scale; -+ FragmentA1ScaleBias tb_frag_A1_bias; -+ FragmentIteratorA1ScaleBias warp_tile_iterator_A1_scale_(tb_frag_A1_scale); -+ FragmentIteratorA1ScaleBias warp_tile_iterator_A1_bias_(tb_frag_A1_bias); -+ FragmentB1 tb_frag_B1; -+ -+ if(PerChannelScale) -+ tb_frag_A1_scale.clear(); -+ tb_frag_A1_bias.clear(); -+ tb_frag_B1.clear(); -+ -+ // The last kblock is loaded in the prolog -+ if(PerChannelScale) -+ iterator_A1_scale.load(tb_frag_A1_scale); -+ iterator_A1_bias.load(tb_frag_A1_bias); -+ iterator_B1.load(tb_frag_B1); -+ -+ if(PerChannelScale) -+ ++iterator_A1_scale; -+ ++iterator_A1_bias; -+ ++iterator_B1; -+ -+ this->smem_iterator_B1_.store(transform_B1(tb_frag_B1)); -+ -+ ++this->smem_iterator_B1_; -+ -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math instructions -+ WarpFragmentA1ScaleBias warp_frag_A1_scale[2]; -+ WarpFragmentA1ScaleBias warp_frag_A1_bias[2]; -+ WarpFragmentA1 warp_frag_A1[2]; -+ WarpFragmentB1 warp_frag_B1[2]; -+ -+ this->warp_tile_iterator_B1_.set_kgroup_index(0); -+ -+ if(PerChannelScale) -+ warp_tile_iterator_A1_scale_.load(warp_frag_A1_scale[0]); -+ warp_tile_iterator_A1_bias_.load(warp_frag_A1_bias[0]); -+ warp_tile_iterator_A1_.load(warp_frag_A1[0], warp_frag_A1_scale[0], -+ warp_frag_A1_bias[0], output_op_0); -+ this->warp_tile_iterator_B1_.load(warp_frag_B1[0]); -+ -+ ++warp_tile_iterator_A1_; -+ if(PerChannelScale) -+ ++warp_tile_iterator_A1_scale_; -+ ++warp_tile_iterator_A1_bias_; -+ ++this->warp_tile_iterator_B1_; -+ -+ Operator1 warp_mma1; -+ -+ smem_write_stage_idx = 1; -+ -+ int gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1; -+ -+ // Avoid reading out of bounds -+ iterator_B1.clear_mask(gemm_k_iterations_1 <= 1); -+ -+ // -+ // Mainloop -+ // -+ -+ // Note: The main loop does not support Base::WarpGemmIterations == 2. -+ CUTLASS_PRAGMA_UNROLL -+ for (; gemm_k_iterations_1 > 0; --gemm_k_iterations_1) { -+ -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group -+ // as the case may be. -+ -+ if (warp_mma_k == Base::kWarpGemmIterations1 - 1) { -+ -+ // Write fragments to shared memory -+ this->smem_iterator_B1_.store(transform_B1(tb_frag_B1)); -+ -+ __syncthreads(); -+ ++this->smem_iterator_B1_; -+ -+ // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory -+ if (smem_write_stage_idx == 1) { -+ this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); -+ } -+ else { -+ this->warp_tile_iterator_B1_.add_tile_offset( -+ {-Base::kStages * Policy1::kPartitionsK * -+ Base::kWarpGemmIterations1, -+ 0}); -+ } -+ -+ smem_write_stage_idx ^= 1; -+ -+ if(PerChannelScale) { -+ tb_frag_A1_scale.clear(); -+ iterator_A1_scale.load(tb_frag_A1_scale); -+ ++iterator_A1_scale; -+ } -+ tb_frag_A1_bias.clear(); -+ iterator_A1_bias.load(tb_frag_A1_bias); -+ ++iterator_A1_bias; -+ } -+ -+ this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); -+ -+ if(PerChannelScale) -+ warp_tile_iterator_A1_scale_.load(warp_frag_A1_scale[(warp_mma_k + 1) % 2]); -+ warp_tile_iterator_A1_bias_.load(warp_frag_A1_bias[(warp_mma_k + 1) % 2]); -+ warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2], -+ warp_frag_A1_scale[(warp_mma_k + 1) % 2], -+ warp_frag_A1_bias[(warp_mma_k + 1) % 2], -+ output_op_0); -+ this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]); -+ -+ if(PerChannelScale) -+ ++warp_tile_iterator_A1_scale_; -+ ++warp_tile_iterator_A1_bias_; -+ ++warp_tile_iterator_A1_; -+ ++this->warp_tile_iterator_B1_; -+ -+ if (warp_mma_k == 0) { -+ -+ iterator_B1.load(tb_frag_B1); -+ ++iterator_B1; -+ -+ // Avoid reading out of bounds if this was the last loop iteration -+ iterator_B1.clear_mask(gemm_k_iterations_1 <= 2); -+ } -+ -+ warp_mma1(accum, warp_frag_A1[warp_mma_k % 2], -+ warp_frag_B1[warp_mma_k % 2], accum); -+ } -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h -new file mode 100644 -index 0000000..351fae3 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h -@@ -0,0 +1,544 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped Back-to-back fused GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/numeric_conversion.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" -+ -+#include "threadblock/b2b_mma_base_smem_accumulator.h" -+#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape0_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorA0_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA0_, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorB0_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB0_, -+ /// Iterates over vectors of scale and bias vector in global memory -+ // (concept: VectorIterator) -+ typename IteratorAccumulatorScaleBias_, -+ /// Iterates over accumulator tile -+ typename FragmentIteratorAccumulator_, -+ /// Iterates over accumulator tile in shared memory -+ typename SmemIteratorD0_, -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape1_, -+ /// Iterates over the intermediate accumulator tile in shared memory -+ typename WarpIteratorA1_, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorB1_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB1_, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...) -+ typename OutputOp_, -+ /// Policy describing tuning details (concept: MmaPipelinedPolicy) -+ typename Policy0_, -+ /// Policy describing tuning details (concept: MmaPipelinedPolicy) -+ typename Policy1_, -+ /// Transformation applied to A0 operand -+ typename TransformA0_ = NumericArrayConverter< -+ typename SmemIteratorA0_::Element, -+ typename IteratorA0_::Element, -+ IteratorA0_::Fragment::kElements>, -+ /// -+ /// Transformation applied to B0 operand -+ typename TransformB0_ = NumericArrayConverter< -+ typename SmemIteratorB0_::Element, -+ typename IteratorB0_::Element, -+ IteratorB0_::Fragment::kElements>, -+ /// -+ /// Transformation applied to B1 operand -+ typename TransformB1_ = NumericArrayConverter< -+ typename SmemIteratorB1_::Element, -+ typename IteratorB1_::Element, -+ IteratorB1_::Fragment::kElements>, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+class B2bMmaPipelinedSmemAccumulator : -+ public B2bMmaBaseSmemAccumulator { -+public: -+ -+ ///< Base class -+ using Base = B2bMmaBaseSmemAccumulator; -+ -+ using Shape0 = Shape0_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using IteratorA0 = IteratorA0_; ///< Iterates over tiles of A operand in global memory -+ using IteratorB0 = IteratorB0_; ///< Iterates over tiles of B operand in global memory -+ using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; ///< Iterates over tiles of the scale and bias vectors in global memory -+ using Policy0 = Policy0_; ///< Policy0 describing tuning details -+ -+ using SmemIteratorA0 = SmemIteratorA0_; -+ using SmemIteratorB0 = SmemIteratorB0_; -+ using SmemIteratorD0 = SmemIteratorD0_; ///< Iterates over accumulator tile in shared memory -+ -+ using FragmentIteratorAccumulator = FragmentIteratorAccumulator_; ///< Iterates over accumulator tile -+ -+ using Shape1 = Shape1_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory -+ using Policy1 = Policy1_; ///< Policy1 describing tuning details -+ -+ using SmemIteratorB1 = SmemIteratorB1_; -+ using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate accumulator tile in shared memory -+ -+ -+ using ElementC = ElementC_; ///< Data type of accumulator matrix -+ using LayoutC = LayoutC_; ///< Layout of accumulator matrix -+ -+ using OutputOp = OutputOp_; ///< Epilogue after 1st Gemm -+ -+ using TransformA0 = TransformA0_; -+ using TransformB0 = TransformB0_; -+ using TransformB1 = TransformB1_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of operand A loaded from global memory -+ using FragmentA0 = typename IteratorA0::Fragment; -+ -+ /// Fragment of operand B loaded from global memory -+ using FragmentB0 = typename IteratorB0::Fragment; -+ -+ /// Fragment of accumulator tile -+ using FragmentC0 = typename Policy0::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator0 = typename Policy0::Operator; -+ -+ /// Fragment of operand B loaded from global memory -+ using FragmentB1 = typename IteratorB1::Fragment; -+ -+ /// Fragment of accumulator tile -+ using FragmentC1 = typename Policy1::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator1 = typename Policy1::Operator; -+ -+ /// Obtain the arch tag from the warp-level operator -+ using ArchTag = typename Policy0::Operator::ArchTag; -+ -+ /// Complex transform on A0 operand -+ static ComplexTransform const kTransformA0 = Operator0::kTransformA; -+ -+ /// Complex transform on B0 operand -+ static ComplexTransform const kTransformB0 = Operator0::kTransformB; -+ -+ /// Complex transform on B1 operand -+ static ComplexTransform const kTransformB1 = Operator1::kTransformB; -+ -+ /// staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) -+ static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); -+ -+ /// Epilog in shared memory -+ using Epilogue0 = epilogue::threadblock::EpilogueSmemAccumulator< -+ SmemIteratorD0, ///< SmemTileIterator -+ FragmentIteratorAccumulator, ///< AccumulatorFragmentIterator -+ IteratorAccumulatorScaleBias, ///< ScaleBiasIterator -+ OutputOp>; ///< Output operator -+ -+ -+ -+private: -+ -+ using WarpFragmentA0 = typename Operator0::FragmentA; -+ using WarpFragmentB0 = typename Operator0::FragmentB; -+ using WarpFragmentA1 = typename Operator1::FragmentA; -+ using WarpFragmentB1 = typename Operator1::FragmentB; -+ -+protected: -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA0 smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B0 operand to shared memory -+ SmemIteratorB0 smem_iterator_B0_; -+ -+ /// Shared Memory Iterator to store accumulator tile -+ SmemIteratorD0 smem_iterator_D0_; -+ -+ /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile -+ WarpIteratorA1 warp_tile_iterator_A1_; -+ -+ /// Iterator to write threadblock-scoped tile of B1 operand to shared memory -+ SmemIteratorB1 smem_iterator_B1_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ B2bMmaPipelinedSmemAccumulator( -+ typename Base::B2bMmaSharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ int thread_idx, ///< ID within the threadblock -+ int warp_idx, ///< ID of warp -+ int lane_idx, ///< ID of each thread within a warp -+ int problem_size_0_n ///< GEMM0 N is used for accumulator extent -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_A_ref(), thread_idx), -+ smem_iterator_B0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_B_ref(), thread_idx), -+ smem_iterator_D0_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx), -+ warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), {Base::WarpGemm1::kM, problem_size_0_n}, lane_idx), -+ smem_iterator_B1_(shared_storage.b2b_mma_shared_storage.shared_storage1.operand_B_ref(), thread_idx) { -+ -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn_0 = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN); -+ int warp_idx_k_0 = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN); -+ -+ int warp_idx_m_0 = warp_idx_mn_0 % Base::WarpCount0::kM; -+ int warp_idx_n_0 = warp_idx_mn_0 / Base::WarpCount0::kM; -+ -+ int tile_offset_k_0 = Base::kWarpGemmIterations0 * warp_idx_k_0; -+ -+ int warp_idx_mn_1 = warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN); -+ int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); -+ -+ int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; -+ int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; -+ -+ int tile_offset_k_1 = Base::kWarpGemmIterations1 * warp_idx_k_1; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A0_.add_tile_offset({warp_idx_m_0, tile_offset_k_0}); -+ this->warp_tile_iterator_B0_.add_tile_offset({tile_offset_k_0, warp_idx_n_0}); -+ warp_tile_iterator_A1_.add_tile_offset({warp_idx_m_1, tile_offset_k_1}); -+ this->warp_tile_iterator_B1_.add_tile_offset({tile_offset_k_1, warp_idx_n_1}); -+ -+ // Add smem accumulator iterator warp offset -+ smem_iterator_D0_.add_tile_offset({ warp_idx_m_0 * SmemIteratorD0::TileIterations::kRow, -+ warp_idx_n_0 * SmemIteratorD0::TileIterations::kColumn}); -+ -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ int gemm_k_iterations_0, ///< number of iterations of the mainloop -+ FragmentC1 &accum, ///< destination accumulator tile -+ IteratorA0 iterator_A, ///< iterator over A operand in global memory -+ IteratorB0 iterator_B0, ///< iterator over B0 operand in global memory -+ IteratorAccumulatorScaleBias iterator_accum0_scale, ///< iterator over D0 scale vector in global memory -+ IteratorAccumulatorScaleBias iterator_accum0_bias, ///< iterator over D0 bias vector in global memory -+ IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory -+ FragmentC0 const &src_accum, ///< source accumualtor tile -+ OutputOp output_op_0, ///< epilogue operation after 1st Gemm -+ TransformA0 transform_A0 = TransformA0(), ///< transformation applied to A0 fragment -+ TransformB0 transform_B0 = TransformB0(), ///< transformation applied to B0 fragment -+ TransformB1 transform_B1 = TransformB1()) { ///< transformation applied to B1 fragment -+ -+ // -+ // Prologue -+ // -+ -+ // Perform accumulation in the 'd' output operand -+ FragmentC0 accum0 = src_accum; -+ -+ FragmentA0 tb_frag_A; -+ FragmentB0 tb_frag_B0; -+ -+ tb_frag_A.clear(); -+ tb_frag_B0.clear(); -+ -+ // The last kblock is loaded in the prolog -+ iterator_A.load(tb_frag_A); -+ iterator_B0.load(tb_frag_B0); -+ -+ ++iterator_A; -+ ++iterator_B0; -+ -+ this->smem_iterator_A_.store(transform_A0(tb_frag_A)); -+ this->smem_iterator_B0_.store(transform_B0(tb_frag_B0)); -+ -+ ++this->smem_iterator_A_; -+ ++this->smem_iterator_B0_; -+ -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math instructions -+ WarpFragmentA0 warp_frag_A0[2]; -+ WarpFragmentB0 warp_frag_B0[2]; -+ -+ this->warp_tile_iterator_A0_.set_kgroup_index(0); -+ this->warp_tile_iterator_B0_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A0_.load(warp_frag_A0[0]); -+ this->warp_tile_iterator_B0_.load(warp_frag_B0[0]); -+ -+ ++this->warp_tile_iterator_A0_; -+ ++this->warp_tile_iterator_B0_; -+ -+ Operator0 warp_mma0; -+ -+ int smem_write_stage_idx = 1; -+ -+ // Avoid reading out of bounds -+ iterator_A.clear_mask(gemm_k_iterations_0 <= 1); -+ iterator_B0.clear_mask(gemm_k_iterations_0 <= 1); -+ -+ // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing -+ // shared memory loads (which have the tighest latency requirement). -+ -+ // -+ // Mainloop -+ // -+ -+ // Note: The main loop does not support Base::kWarpGemmIterations == 2. -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group -+ // as the case may be. -+ -+ if (warp_mma_k == Base::kWarpGemmIterations0 - 1) { -+ -+ // Write fragments to shared memory -+ this->smem_iterator_A_.store(transform_A0(tb_frag_A)); -+ -+ this->smem_iterator_B0_.store(transform_B0(tb_frag_B0)); -+ -+ __syncthreads(); -+ -+ ++this->smem_iterator_A_; -+ ++this->smem_iterator_B0_; -+ -+ // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory -+ if (smem_write_stage_idx == 1) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); -+ } -+ else { -+ this->warp_tile_iterator_A0_.add_tile_offset( -+ {0, -Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0}); -+ this->warp_tile_iterator_B0_.add_tile_offset( -+ {-Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0, -+ 0}); -+ } -+ -+ smem_write_stage_idx ^= 1; -+ } -+ -+ this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); -+ this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); -+ -+ this->warp_tile_iterator_A0_.load(warp_frag_A0[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B0_.load(warp_frag_B0[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A0_; -+ ++this->warp_tile_iterator_B0_; -+ -+ if (warp_mma_k == 0) { -+ -+ iterator_A.load(tb_frag_A); -+ iterator_B0.load(tb_frag_B0); -+ ++iterator_A; -+ ++iterator_B0; -+ -+ // Avoid reading out of bounds if this was the last loop iteration -+ iterator_A.clear_mask(gemm_k_iterations_0 <= 2); -+ iterator_B0.clear_mask(gemm_k_iterations_0 <= 2); -+ } -+ -+ warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2], -+ warp_frag_B0[warp_mma_k % 2], accum0); -+ } -+ } -+ -+ /// Epilogue for the first Implicit Gemm -+ Epilogue0 epilogue0; -+ -+ epilogue0(output_op_0, smem_iterator_D0_, accum0, iterator_accum0_scale, iterator_accum0_bias); -+ -+ __syncthreads(); -+ -+ //2nd Gemm -+ -+ // -+ // Prologue -+ // -+ -+ FragmentB1 tb_frag_B1; -+ -+ tb_frag_B1.clear(); -+ -+ // The last kblock is loaded in the prolog -+ iterator_B1.load(tb_frag_B1); -+ -+ ++iterator_B1; -+ -+ this->smem_iterator_B1_.store(transform_B1(tb_frag_B1)); -+ -+ ++this->smem_iterator_B1_; -+ -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math instructions -+ WarpFragmentA1 warp_frag_A1[2]; -+ WarpFragmentB1 warp_frag_B1[2]; -+ -+ this->warp_tile_iterator_B1_.set_kgroup_index(0); -+ -+ warp_tile_iterator_A1_.load(warp_frag_A1[0]); -+ this->warp_tile_iterator_B1_.load(warp_frag_B1[0]); -+ -+ ++warp_tile_iterator_A1_; -+ ++this->warp_tile_iterator_B1_; -+ -+ Operator1 warp_mma1; -+ -+ smem_write_stage_idx = 1; -+ -+ int gemm_k_iterations_1 = Shape0::kN / Shape1::kK; -+ -+ // Avoid reading out of bounds -+ iterator_B1.clear_mask(gemm_k_iterations_1 <= 1); -+ -+ // -+ // Mainloop -+ // -+ -+ // Note: The main loop does not support Base::kWarpGemmIterations == 2. -+ CUTLASS_PRAGMA_UNROLL -+ for (; gemm_k_iterations_1 > 0; --gemm_k_iterations_1) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group -+ // as the case may be. -+ -+ if (warp_mma_k == Base::kWarpGemmIterations1 - 1) { -+ -+ // Write fragments to shared memory -+ this->smem_iterator_B1_.store(transform_B1(tb_frag_B1)); -+ -+ __syncthreads(); -+ -+ ++this->smem_iterator_B1_; -+ -+ // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory -+ if (smem_write_stage_idx == 1) { -+ this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); -+ } -+ else { -+ this->warp_tile_iterator_B1_.add_tile_offset( -+ {-Base::kStages * Policy1::kPartitionsK * -+ Base::kWarpGemmIterations1, -+ 0}); -+ } -+ -+ smem_write_stage_idx ^= 1; -+ -+ } -+ -+ this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); -+ -+ // skip warp tile loading for the last kgroup -+ if(gemm_k_iterations_1 > 1 || warp_mma_k < Base::kWarpGemmIterations1 - 1) -+ warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]); -+ -+ ++warp_tile_iterator_A1_; -+ ++this->warp_tile_iterator_B1_; -+ -+ if (warp_mma_k == 0) { -+ -+ iterator_B1.load(tb_frag_B1); -+ -+ ++iterator_B1; -+ -+ // Avoid reading out of bounds if this was the last loop iteration -+ iterator_B1.clear_mask(gemm_k_iterations_1 <= 2); -+ } -+ -+ warp_mma1(accum, warp_frag_A1[warp_mma_k % 2], -+ warp_frag_B1[warp_mma_k % 2], accum); -+ -+ } -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma.h -new file mode 100644 -index 0000000..d1842f6 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma.h -@@ -0,0 +1,584 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+ -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" -+#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h" -+#include "cutlass/transform/threadblock/vector_iterator.h" -+#include "cutlass/transform/warp/vector_fragment_iterator.h" -+ -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" -+ -+#include "threadblock/b2b_mma_pipelined.h" -+#include "threadblock/b2b_mma_multistage.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape0_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape1_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape0_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape1_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation perfomed by GEMM -+ typename Operator, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor = false, -+ /// Staging the accumulators in shared memory. -+ bool SmemAccumulator = false> -+struct DefaultB2bMma; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// Specialization for row-major output with 2-stage pipeline -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape0, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape0, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape1, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Epilogue output operator -+ typename EpilogueOutputOp> -+struct DefaultB2bMma { -+ // Define the MmaCore components -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, -+ arch::OpClassTensorOp, 2, Operator>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, -+ arch::OpClassTensorOp, 2, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA0 = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, typename MmaCore0::IteratorThreadMapA, kAlignmentA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB0 = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore0::IteratorThreadMapB, kAlignmentB>; -+ -+ // Use fragment iterator for A operand -+ using AccumulatorLayout = cutlass::layout::ColumnMajor; -+ using FragmentIteratorA1 = -+ cutlass::gemm::warp::MmaTensorOpFragmentIterator< -+ cutlass::MatrixShape, //warp shape -+ cutlass::MatrixShape, //accumulator shape -+ MmaCore1::Shape::kK, //kBlocksColumn -+ ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp>; -+ -+ using ElementScaleBias = typename EpilogueOutputOp::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 2; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Warp-level iterators to load scale and bias vectors -+ using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< -+ MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, -+ LayoutScaleBias, InstructionShape, kElementsPerAccess>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB1 = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB, kAlignmentB>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelined< -+ typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA, -+ IteratorB0, typename MmaCore0::SmemIteratorB, -+ typename MmaCore1::Shape, FragmentIteratorA1, -+ IteratorAccumulatorScaleBias, FragmentIteratorA1ScaleBias, -+ IteratorB1, typename MmaCore1::SmemIteratorB, -+ ElementAccumulator, layout::RowMajor, -+ EpilogueOutputOp, -+ typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy>; -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// Specialization for row-major output for multi-stage -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape0, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape0, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape1, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Epilogue output operator -+ typename EpilogueOutputOp> -+struct DefaultB2bMma { -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpA = -+ ((sizeof_bits::value * kAlignmentA) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * kAlignmentB) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ -+ // Define the MmaCore components -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, Operator, false, CacheOpA, CacheOpB>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, Operator, false, CacheOpA, CacheOpB>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; -+ using AccessTypeA0 = cutlass::Array; -+ using IteratorA0 = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA0, AccessTypeA0>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; -+ using AccessTypeB0 = cutlass::Array; -+ using IteratorB0 = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB0, AccessTypeB0>; -+ -+ // Use fragment iterator for A operand -+ using AccumulatorLayout = cutlass::layout::ColumnMajor; -+ using FragmentIteratorA1 = -+ cutlass::gemm::warp::MmaTensorOpFragmentIterator< -+ cutlass::MatrixShape, //warp shape -+ cutlass::MatrixShape, //accumulator shape -+ MmaCore1::Shape::kK, //kBlocksColumn -+ ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp>; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 2; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Warp-level iterators to load scale and bias vectors -+ using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< -+ MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, -+ LayoutScaleBias, InstructionShape, kElementsPerAccess>; -+ -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; -+ using AccessTypeB1 = cutlass::Array; -+ using IteratorB1 = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB1, AccessTypeB1>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaMultistage< -+ typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA, -+ MmaCore0::kCacheOpA, -+ IteratorB0, typename MmaCore0::SmemIteratorB, MmaCore0::kCacheOpB, -+ typename MmaCore1::Shape, FragmentIteratorA1, -+ IteratorAccumulatorScaleBias, FragmentIteratorA1ScaleBias, -+ IteratorB1, typename MmaCore1::SmemIteratorB, MmaCore1::kCacheOpB, -+ ElementAccumulator, layout::RowMajor, -+ EpilogueOutputOp, -+ typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy, Stages>; -+ -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for column-major-interleaved output with 2-stage pipeline -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename OperatorClass, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape0, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape0, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape1, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Number of Interleaved K -+ int InterleavedK> -+struct DefaultB2bMma, OperatorClass, arch::Sm75, -+ ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, -+ InstructionShape, 2, Operator, EpilogueOutputOp, true> { -+ // Define the MmaCore components -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, -+ layout::ColumnMajorInterleaved, OperatorClass, 2, Operator, -+ true>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, -+ layout::ColumnMajorInterleaved, OperatorClass, 2, Operator, -+ true>; -+ -+ static_assert(kAlignmentA == 128 / sizeof_bits::value, -+ "Alignment must match thread data map's vector length"); -+ -+ static_assert(kAlignmentB ==128 / sizeof_bits::value, -+ "Alignment must match thread data map's vector length"); -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA0 = cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, ElementA, -+ LayoutA, 1, typename MmaCore0::IteratorThreadMapA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB0 = cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, ElementB, -+ LayoutB, 0, typename MmaCore0::IteratorThreadMapB>; -+ -+ // Use fragment iterator for A1 operand -+ using AccumulatorLayout = cutlass::layout::RowMajor; //AccumulatorsInRowMajor = true -+ using FragmentIteratorA1 = -+ cutlass::gemm::warp::MmaTensorOpFragmentIterator< -+ cutlass::MatrixShape, //warp shape -+ cutlass::MatrixShape, //accumulator shape -+ MmaCore1::Shape::kK, //kBlocksColumn -+ ElementAccumulator, ElementA, AccumulatorLayout, -+ InstructionShape, EpilogueOutputOp>; -+ -+ using ElementScaleBias = typename EpilogueOutputOp::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 4; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Warp-level iterators to load scale and bias vectors -+ using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< -+ MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, -+ LayoutScaleBias, InstructionShape, kElementsPerAccess>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB1 = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB>; -+ -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelined< -+ typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA, -+ IteratorB0, typename MmaCore0::SmemIteratorB, -+ typename MmaCore1::Shape, FragmentIteratorA1, -+ IteratorAccumulatorScaleBias, FragmentIteratorA1ScaleBias, -+ IteratorB1, typename MmaCore1::SmemIteratorB, -+ ElementAccumulator, layout::ColumnMajorInterleaved, -+ EpilogueOutputOp, -+ typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for column-major-interleaved output with multi-stage -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape0, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape0, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape1, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Number of Interleaved K -+ int InterleavedK> -+struct DefaultB2bMma, OperatorClass, ArchTag, -+ ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, -+ InstructionShape, Stages, Operator, EpilogueOutputOp, true> { -+ // Define the MmaCore components -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, -+ layout::ColumnMajorInterleaved, OperatorClass, Stages, -+ Operator, true>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, -+ layout::ColumnMajorInterleaved, OperatorClass, Stages, -+ Operator, true>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ using IteratorA0 = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA0, AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ using IteratorB0 = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB0, AccessTypeB>; -+ -+ // Use fragment iterator for A1 operand -+ using AccumulatorLayout = cutlass::layout::RowMajor; //AccumulatorsInRowMajor = true -+ using FragmentIteratorA1 = -+ cutlass::gemm::warp::MmaTensorOpFragmentIterator< -+ cutlass::MatrixShape, //warp shape -+ cutlass::MatrixShape, //accumulator shape -+ MmaCore1::Shape::kK, //kBlocksColumn -+ ElementAccumulator, ElementA, AccumulatorLayout, -+ InstructionShape, EpilogueOutputOp>; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 4; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Warp-level iterators to load scale and bias vectors -+ using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< -+ MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, -+ LayoutScaleBias, InstructionShape, kElementsPerAccess>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; -+ using IteratorB1 = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB1, AccessTypeB>; -+ -+ -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaMultistage< -+ typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA, -+ MmaCore0::kCacheOpA, -+ IteratorB0, typename MmaCore0::SmemIteratorB, MmaCore0::kCacheOpB, -+ typename MmaCore1::Shape, FragmentIteratorA1, -+ IteratorAccumulatorScaleBias, FragmentIteratorA1ScaleBias, -+ IteratorB1, typename MmaCore1::SmemIteratorB, MmaCore1::kCacheOpB, -+ ElementAccumulator, layout::ColumnMajorInterleaved, -+ EpilogueOutputOp, -+ typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy, Stages>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma_smem_accumulator.h b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma_smem_accumulator.h -new file mode 100644 -index 0000000..1ef7e50 ---- /dev/null -+++ b/3rdparty/cutlass/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma_smem_accumulator.h -@@ -0,0 +1,605 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+ -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h" -+ -+#include "threadblock/b2b_mma_pipelined_smem_accumulator.h" -+#include "threadblock/b2b_mma_multistage_smem_accumulator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// Specialization for row-major output with 2-stage pipeline -+/// Accumulator will be staged in shared memory. -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape0, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape0, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape1, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Epilogue output operator -+ typename EpilogueOutputOp> -+struct DefaultB2bMma { -+ // Define the MmaCore components -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, -+ arch::OpClassTensorOp, 2, Operator>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, -+ arch::OpClassTensorOp, 2, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA0 = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, typename MmaCore0::IteratorThreadMapA, kAlignmentA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB0 = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore0::IteratorThreadMapB, kAlignmentB>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB1 = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB, kAlignmentB>; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ -+ // Use fragment iterator for the accumulator -+ using SmemAccumulatorLayout = cutlass::layout::RowMajor; -+ using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ WarpShape0, InstructionShape, -+ ElementAccumulator, -+ typename WarpMmaTensorOp0::Policy::Operator::FragmentC, -+ SmemAccumulatorLayout -+ >; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 2; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Store Accumulator tiles to Shared Memory -+ using SmemIteratorD0 = -+ cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape0, -+ InstructionShape, -+ typename EpilogueOutputOp::ElementOutput, -+ SmemAccumulatorLayout -+ >; -+ -+ static int const kThreadCount = 32; -+ // load warp tile from Shared Memory accumulator -+ using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator< -+ MatrixShape, cutlass::gemm::Operand::kA, -+ ElementA, SmemAccumulatorLayout, -+ MatrixShape, -+ WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount, true>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelinedSmemAccumulator< -+ typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA, -+ IteratorB0, typename MmaCore0::SmemIteratorB, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorAccumulator, SmemIteratorD0, -+ typename MmaCore1::Shape, WarpIteratorA1, -+ IteratorB1, typename MmaCore1::SmemIteratorB, -+ ElementAccumulator, layout::RowMajor, -+ EpilogueOutputOp, -+ typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy>; -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// Specialization for row-major output for multi-stage -+/// Accumulator will be staged in shared memory. -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape0, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape0, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape1, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Epilogue output operator -+ typename EpilogueOutputOp> -+struct DefaultB2bMma { -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpA = -+ ((sizeof_bits::value * kAlignmentA) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * kAlignmentB) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ -+ // Define the MmaCore components -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, Operator, false, CacheOpA, CacheOpB>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, Operator, false, CacheOpA, CacheOpB>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; -+ using AccessTypeA0 = cutlass::Array; -+ using IteratorA0 = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA0, AccessTypeA0>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; -+ using AccessTypeB0 = cutlass::Array; -+ using IteratorB0 = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB0, AccessTypeB0>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; -+ using AccessTypeB1 = cutlass::Array; -+ using IteratorB1 = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB1, AccessTypeB1>; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ -+ // Use fragment iterator for the accumulator -+ using SmemAccumulatorLayout = cutlass::layout::RowMajor; -+ using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ WarpShape0, InstructionShape, -+ ElementAccumulator, -+ typename WarpMmaTensorOp0::Policy::Operator::FragmentC, -+ SmemAccumulatorLayout -+ >; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 2; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ -+ // Store Accumulator tiles to Shared Memory -+ using SmemIteratorD0 = -+ cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape0, -+ InstructionShape, -+ typename EpilogueOutputOp::ElementOutput, -+ SmemAccumulatorLayout -+ >; -+ -+ static int const kThreadCount = 32; -+ // load warp tile from Shared Memory accumulator -+ using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator< -+ MatrixShape, cutlass::gemm::Operand::kA, -+ ElementA, SmemAccumulatorLayout, -+ MatrixShape, -+ WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount, true>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaMultistageSmemAccumulator< -+ typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA, -+ MmaCore0::kCacheOpA, -+ IteratorB0, typename MmaCore0::SmemIteratorB, MmaCore0::kCacheOpB, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorAccumulator, SmemIteratorD0, -+ typename MmaCore1::Shape, WarpIteratorA1, -+ IteratorB1, typename MmaCore1::SmemIteratorB, MmaCore1::kCacheOpB, -+ ElementAccumulator, layout::RowMajor, -+ EpilogueOutputOp, -+ typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy, Stages>; -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for column-major-interleaved output with 2-stage pipeline -+/// Accumulator will be staged in shared memory. -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename OperatorClass, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape0, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape0, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape1, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Number of Interleaved K -+ int InterleavedK> -+struct DefaultB2bMma, OperatorClass, arch::Sm75, -+ ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, -+ InstructionShape, 2, Operator, EpilogueOutputOp, true, true> { -+ // Define the MmaCore components -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, -+ layout::ColumnMajorInterleaved, OperatorClass, 2, Operator, -+ true>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, -+ layout::ColumnMajorInterleaved, OperatorClass, 2, Operator, -+ true>; -+ -+ static_assert(kAlignmentA == 128 / sizeof_bits::value, -+ "Alignment must match thread data map's vector length"); -+ -+ static_assert(kAlignmentB ==128 / sizeof_bits::value, -+ "Alignment must match thread data map's vector length"); -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA0 = cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, ElementA, -+ LayoutA, 1, typename MmaCore0::IteratorThreadMapA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB0 = cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, ElementB, -+ LayoutB, 0, typename MmaCore0::IteratorThreadMapB>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB1 = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB>; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ -+ // Use fragment iterator for the accumulator -+ using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>; -+ using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ WarpShape0, InstructionShape, -+ ElementAccumulator, -+ typename WarpMmaTensorOp0::Policy::Operator::FragmentC, -+ SmemAccumulatorLayout -+ >; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 4; //For interleaved layout -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Store Accumulator tiles to Shared Memory -+ using SmemIteratorD0 = -+ cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape0, -+ InstructionShape, -+ typename EpilogueOutputOp::ElementOutput, -+ SmemAccumulatorLayout -+ >; -+ -+ static int const kThreadCount = 32; -+ // load warp tile from Shared Memory accumulator -+ using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator< -+ MatrixShape, cutlass::gemm::Operand::kA, -+ ElementA, SmemAccumulatorLayout, -+ MatrixShape, -+ WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount, true>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelinedSmemAccumulator< -+ typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA, -+ IteratorB0, typename MmaCore0::SmemIteratorB, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorAccumulator, SmemIteratorD0, -+ typename MmaCore1::Shape, WarpIteratorA1, -+ IteratorB1, typename MmaCore1::SmemIteratorB, -+ ElementAccumulator, layout::ColumnMajorInterleaved, -+ EpilogueOutputOp, -+ typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy>; -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// Specialization for column-major-interleaved output with multi-stage -+/// Accumulator will be staged in shared memory. -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape0, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape1, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape0, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape1, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Number of Interleaved K -+ int InterleavedK> -+struct DefaultB2bMma, OperatorClass, ArchTag, -+ ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, -+ InstructionShape, Stages, Operator, EpilogueOutputOp, true, true> { -+ // Define the MmaCore components -+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, -+ layout::ColumnMajorInterleaved, OperatorClass, Stages, -+ Operator, true>; -+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, -+ layout::ColumnMajorInterleaved, OperatorClass, Stages, -+ Operator, true>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ using IteratorA0 = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA0, AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ using IteratorB0 = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB0, AccessTypeB>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; -+ using IteratorB1 = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB1, AccessTypeB>; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; -+ using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; -+ using MmaPolicy0 = typename MmaCore0::MmaPolicy; -+ using MmaPolicy1 = typename MmaCore1::MmaPolicy; -+ -+ // Use fragment iterator for the accumulator -+ using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>; -+ using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ WarpShape0, InstructionShape, -+ ElementAccumulator, -+ typename WarpMmaTensorOp0::Policy::Operator::FragmentC, -+ SmemAccumulatorLayout -+ >; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using ElementScaleBias = typename EpilogueOutputOp::ElementCompute; -+ using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter -+ static int const kElementsPerAccess = 4; -+ using IteratorAccumulatorScaleBias = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::MatrixShape, -+ ElementScaleBias, LayoutScaleBias, kElementsPerAccess> -+ >; -+ -+ // Store Accumulator tiles to Shared Memory -+ using SmemIteratorD0 = -+ cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape0, -+ InstructionShape, -+ typename EpilogueOutputOp::ElementOutput, -+ SmemAccumulatorLayout -+ >; -+ -+ static int const kThreadCount = 32; -+ // load warp tile from Shared Memory accumulator -+ using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator< -+ MatrixShape, cutlass::gemm::Operand::kA, -+ ElementA, SmemAccumulatorLayout, -+ MatrixShape, -+ WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount, true >; -+ -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaMultistageSmemAccumulator< -+ typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA, -+ MmaCore0::kCacheOpA, -+ IteratorB0, typename MmaCore0::SmemIteratorB, MmaCore0::kCacheOpB, -+ IteratorAccumulatorScaleBias, -+ FragmentIteratorAccumulator, SmemIteratorD0, -+ typename MmaCore1::Shape, WarpIteratorA1, -+ IteratorB1, typename MmaCore1::SmemIteratorB, MmaCore1::kCacheOpB, -+ ElementAccumulator, layout::ColumnMajorInterleaved, -+ EpilogueOutputOp, -+ typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy, Stages>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu b/3rdparty/cutlass/examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu -new file mode 100644 -index 0000000..bc2185d ---- /dev/null -+++ b/3rdparty/cutlass/examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu -@@ -0,0 +1,472 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+Please check example 07 and 08 for the basics of tensor op gemm kernels. On NVIDIA Ampere -+architecture, most concept still holds. The two main differences are -+ -+1. NVIDIA Ampere architecture introduces a new series of tensor core instructions (see -+ include/cutlass/arch/mma_sm80.h) which are more efficient on Ampere. -+ -+2. NVIDIA Ampere architecture uses cp_async() to build multistage software pipeline to better hide -+ latency (see include/cutlass/gemm/threadblock/mma_multistage.h) -+ -+Moreover, NVIDIA Ampere architecture starts supporting tfloat32 (see include/cutlass/tfloat32.h) -+data types in tensor cores. One big advantage is that we can load in fp32 data and convert them -+implicitly to tf32 inside the GEMM kernel which means no change is needed to accelerate traditional -+fp32 data by using NVIDIA Ampere architecture. -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "helper.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Result structure -+struct Result { -+ -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cudaError_t error; -+ bool passed; -+ -+ // -+ // Methods -+ // -+ -+ Result( -+ double runtime_ms = 0, -+ double gflops = 0, -+ cutlass::Status status = cutlass::Status::kSuccess, -+ cudaError_t error = cudaSuccess -+ ): -+ runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ -+ cutlass::gemm::GemmCoord problem_size; -+ int batch_count; -+ float alpha; -+ float beta; -+ -+ bool reference_check; -+ int iterations; -+ -+ Options(): -+ help(false), -+ problem_size({5120, 4096, 4096}), -+ batch_count(1), -+ reference_check(true), -+ iterations(20), -+ alpha(1), -+ beta() { } -+ -+ bool valid() { -+ return true; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ cmd.get_cmd_line_argument("m", problem_size.m()); -+ cmd.get_cmd_line_argument("n", problem_size.n()); -+ cmd.get_cmd_line_argument("k", problem_size.k()); -+ -+ cmd.get_cmd_line_argument("alpha", alpha); -+ cmd.get_cmd_line_argument("beta", beta); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "14_ampere_tf32_tensorop_gemm example\n\n" -+ << " This example uses the CUTLASS Library to execute TF32 tensorop GEMM computations.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --m= GEMM M dimension\n" -+ << " --n= GEMM N dimension\n" -+ << " --k= GEMM K dimension\n" -+ << " --alpha= Epilogue scalar alpha\n" -+ << " --beta= Epilogue scalar beta\n\n" -+ << " --iterations= Number of profiling iterations to perform.\n\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/14_ampere_tf32_tensorop_gemm/14_ampere_tf32_tensorop_gemm --m=1024 --n=512 --k=1024 \\\n" -+ << " --alpha=2 --beta=0.707 \n\n"; -+ -+ return out; -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of real-valued multiply-adds -+ int64_t fmas = problem_size.product() * batch_count; -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// The code section below describes datatype for input, output matrices and computation between -+// elements in input matrices. -+using ElementAccumulator = float; // <- data type of accumulator -+using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations -+using ElementInputA = float; // <- data type of elements in input matrix A -+using ElementInputB = float; // <- data type of elements in input matrix B -+using ElementOutput = float; // <- data type of elements in output matrix D -+ -+// The code section below describes matrix layout of input and output matrices. Column Major for -+// Matrix A, Row Major for Matrix B and Row Major for Matrix C -+using LayoutInputA = cutlass::layout::RowMajor; -+using LayoutInputB = cutlass::layout::ColumnMajor; -+using LayoutOutput = cutlass::layout::RowMajor; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm80; -+ -+// This code section describes the tile size a thread block will compute -+using ShapeMMAThreadBlock = -+ cutlass::gemm::GemmShape<128, 128, 16>; // <- threadblock tile M = 128, N = 128, K = 16 -+// This code section describes tile size a warp will compute -+using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 16>; // <- warp tile M = 64, N = 64, K = 16 -+// This code section describes the size of MMA op -+using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8 -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? -+ -+// This code section describes the epilogue part of the kernel -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // <- data type of output matrix -+ 128 / cutlass::sizeof_bits::value, // <- the number of elements per vectorized -+ // memory access. For a byte, it's 16 -+ // elements. This becomes the vector width of -+ // math instructions in the epilogue too -+ ElementAccumulator, // <- data type of accumulator -+ ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 4; -+ -+using Gemm = cutlass::gemm::device::Gemm; -+ -+int run(Options &options) { -+ -+ // Create a tuple of problem size for matrix multiplication -+ cutlass::gemm::GemmCoord problem_size = options.problem_size; -+ -+ // Initialize tensors using CUTLASS helper functions -+ cutlass::HostTensor tensor_a( -+ problem_size.mk()); // <- Create matrix A with dimensions M x K -+ cutlass::HostTensor tensor_b( -+ problem_size.kn()); // <- Create matrix B with dimensions K x N -+ cutlass::HostTensor tensor_c( -+ problem_size.mn()); // <- Create matrix C with dimensions M x N -+ cutlass::HostTensor tensor_d( -+ problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from -+ // CUTLASS kernel -+ cutlass::HostTensor tensor_ref_d( -+ problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from -+ // reference kernel -+ -+ // Fill input and output matrices on host using CUTLASS helper functions -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1, -+ ElementInputA(4), -+ ElementInputA(-4), -+ 0); // <- Fill matrix A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 1, -+ ElementInputB(4), -+ ElementInputB(-4), -+ 0); // <- Fill matrix B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c.host_view(), -+ 1, -+ ElementOutput(4), -+ ElementOutput(-4), -+ 0); // <- Fill matrix C on host with uniform-distribution random data -+ cutlass::reference::host::TensorFill( -+ tensor_d.host_view()); // <- fill matrix D on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_c.sync_device(); -+ tensor_d.sync_device(); -+ tensor_ref_d.sync_device(); -+ -+ // Initialize alpha and beta for dot product computation -+ ElementComputeEpilogue alpha = ElementComputeEpilogue(options.alpha); -+ ElementComputeEpilogue beta = ElementComputeEpilogue(options.beta); -+ -+ // Split K dimension into 1 partitions -+ int split_k_slices = 1; -+ -+ // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch -+ // instantiated CUTLASS kernel -+ typename Gemm::Arguments arguments{problem_size, // <- problem size of matrix multiplication -+ tensor_a.device_ref(), // <- reference to matrix A on device -+ tensor_b.device_ref(), // <- reference to matrix B on device -+ tensor_c.device_ref(), // <- reference to matrix C on device -+ tensor_d.device_ref(), // <- reference to matrix D on device -+ {alpha, beta}, // <- tuple of alpha and beta -+ split_k_slices}; // <- k-dimension split factor -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ // Instantiate CUTLASS kernel depending on templates -+ Gemm gemm_op; -+ -+ // Check the problem size is supported or not -+ cutlass::Status status = gemm_op.can_implement(arguments); -+ CUTLASS_CHECK(status); -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ status = gemm_op.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(status); -+ -+ // Result structure -+ Result result; -+ -+ // -+ // Construct events -+ // -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ } -+ -+ // Record an event at the start of a series of GEMMs -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ -+ // -+ // Run profiling loop -+ // -+ -+ for (int iter = 0; iter < options.iterations; ++iter) { -+ // Launch initialized CUTLASS kernel -+ status = gemm_op(); -+ CUTLASS_CHECK(status); -+ } -+ -+ // -+ // Stop profiling loop -+ // -+ -+ // Record an event when the GEMMs are complete -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ -+ // Compute average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // Cleanup -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ // Create instantiation for device reference gemm kernel -+ cutlass::reference::device::Gemm -+ gemm_device; -+ -+ // Launch device reference gemm kernel -+ gemm_device(problem_size, -+ alpha, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ beta, -+ tensor_c.device_ref(), -+ tensor_ref_d.device_ref()); -+ -+ // Wait for kernels to finish -+ cudaDeviceSynchronize(); -+ -+ // Copy output data from CUTLASS and reference kernel to host for comparison -+ tensor_d.sync_host(); -+ tensor_ref_d.sync_host(); -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_d.host_view(), -+ tensor_ref_d.host_view()); -+ -+ if (passed) { -+ std::cout << "Runtime: " << result.runtime_ms << " ms" << std::endl; -+ std::cout << " GFLOPs: " << result.gflops << std::endl; -+ } -+ -+ std::cout << (passed ? "Passed" : "Failed") << std::endl; -+ -+ return (passed ? 0 : -1); -+} -+ -+int main(int argc, const char **argv) { -+ -+ bool notSupported = false; -+ -+ // Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available -+ // in CUDA 11.0. -+ // -+ // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples. -+ if (!(__CUDACC_VER_MAJOR__ >= 11)) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return -1; -+ } -+ -+ if (!((props.major * 10 + props.minor) >= 80)) { -+ std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80." -+ << std::endl; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ // Returning zero so this test passes on older Toolkits. Its actions are no-op. -+ return 0; -+ } -+ -+ Options options; -+ options.parse(argc, argv); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ printf("%d x %d x %d TF32 tensor op Matrix Multiply\n", \ -+ options.problem_size.m(), options.problem_size.n(), options.problem_size.k()); -+ -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ return run(options); -+} -diff --git a/3rdparty/cutlass/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu b/3rdparty/cutlass/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu -new file mode 100644 -index 0000000..dc87fff ---- /dev/null -+++ b/3rdparty/cutlass/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu -@@ -0,0 +1,317 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+Please check example 07, 08 and 17 for the basics of dense tensor op gemm kernels. NVIDIA Ampere -+architecture also supports structured sparse tensor op for tf32, fp16, int8 and int4. -+ -+Sparse GEMM kernels needs to takes an additional E matrix which stores the meta data. The format of -+meta data is different for every data types. CUTLASS templates can automatically infer it based on -+input A and B. Check code below. -+ -+Moreover, matrix E needs to be preprocessed so that it can use ldmatrix to load into the registers -+efficiently. -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_sparse.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/host_reorder.h" -+#include "cutlass/util/host_uncompress.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "helper.h" -+ -+// The code section below describes datatype for input, output matrices and computation between -+// elements in input matrices. -+using ElementAccumulator = int32_t; // <- data type of accumulator -+using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations -+using ElementInputA = cutlass::int4b_t; // <- data type of elements in input matrix A -+using ElementInputB = cutlass::int4b_t; // <- data type of elements in input matrix B -+using ElementOutput = int32_t; // <- data type of elements in output matrix D -+ -+// The code section below describes matrix layout of input and output matrices. Row Major for -+// Matrix A, Column Major for Matrix B and Row Major for Matrix C -+using LayoutInputA = cutlass::layout::RowMajor; -+using LayoutInputB = cutlass::layout::ColumnMajor; -+using LayoutOutput = cutlass::layout::RowMajor; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm80; -+ -+// This code section describes the tile size a thread block will compute -+using ShapeMMAThreadBlock = -+ cutlass::gemm::GemmShape<128, 128, 256>; // <- threadblock tile M = 128, N = 128, K = 256 -+// This code section describes tile size a warp will compute -+using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 256>; // <- warp tile M = 64, N = 64, K = 256 -+// This code section describes the size of MMA op -+using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 128>; // <- MMA Op tile M = 16, N = 8, K = 128 -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? -+ -+// This code section describes the epilogue part of the kernel -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // <- data type of output matrix -+ 128 / cutlass::sizeof_bits::value, // <- the number of elements per vectorized -+ // memory access. For a byte, it's 16 -+ // elements. This becomes the vector width of -+ // math instructions in the epilogue too -+ ElementAccumulator, // <- data type of accumulator -+ ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 3; -+ -+using Gemm = cutlass::gemm::device::SparseGemm; -+ -+// Data type and layout of meta data matrix E can be inferred from template Gemm. -+using ElementInputE = typename Gemm::ElementE; -+using LayoutInputE = cutlass::layout::RowMajor; -+using ReorderedLayoutInputE = typename Gemm::LayoutE; -+ -+// Blow property is defined in include/cutlass/arch/sp_mma_sm80.h -+// 50% Sparsity on Ampere -+constexpr int kSparse = Gemm::kSparse; -+// How many elements of A are covered per ElementE -+constexpr int kElementsPerElementE = Gemm::kElementsPerElementE; -+// The size of individual meta data -+constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits; -+ -+int run() { -+ -+ const int length_m = 512; -+ const int length_n = 512; -+ const int length_k = 1024; -+ -+ // Create a tuple of problem size for matrix multiplication -+ cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k); -+ -+ // Initialize tensors using CUTLASS helper functions -+ cutlass::HostTensor tensor_a( -+ cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse)); // <- Create matrix A with dimensions M x (K / 2) -+ cutlass::HostTensor tensor_a_uncompressed( -+ problem_size.mk()); // <- Create uncompressed matrix A with dimensions M x K for reference computing -+ -+ cutlass::HostTensor tensor_b( -+ problem_size.kn()); // <- Create matrix B with dimensions K x N -+ cutlass::HostTensor tensor_c( -+ problem_size.mn()); // <- Create matrix C with dimensions M x N -+ cutlass::HostTensor tensor_d( -+ problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from -+ // CUTLASS kernel -+ cutlass::HostTensor tensor_ref_d( -+ problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from -+ // reference kernel -+ -+ // Create matrix E with dimensions M x (K / 2 / kElementsPerElementE). This one is used by reference computing. -+ cutlass::HostTensor tensor_e( -+ cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE)); -+ // Same size as the above. The above one needs to be reordered and stored in this one. -+ cutlass::HostTensor tensor_e_reordered( -+ cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE)); -+ -+ // Fill input and output matrices on host using CUTLASS helper functions -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1, -+ ElementInputA(2), -+ ElementInputA(-2), -+ 0); // <- Fill matrix A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 1, -+ ElementInputB(2), -+ ElementInputB(-2), -+ 0); // <- Fill matrix B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c.host_view(), -+ 1, -+ ElementOutput(2), -+ ElementOutput(-2), -+ 0); // <- Fill matrix C on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomSparseMeta( -+ tensor_e.host_view(), -+ 1, -+ kMetaSizeInBits); // <- Fill matrix E on host with uniform-distribution random meta data -+ cutlass::reference::host::TensorFill( -+ tensor_d.host_view()); // <- fill matrix D on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros -+ -+ // Reorder the meta data matrix so that we can use ldmatrix to load them to tensor core -+ // instructions. -+ cutlass::reorder_meta(tensor_e_reordered.host_ref(), tensor_e.host_ref(), -+ {problem_size.m(), problem_size.n(), -+ problem_size.k() / kSparse / kElementsPerElementE}); -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_c.sync_device(); -+ tensor_d.sync_device(); -+ tensor_e_reordered.sync_device(); -+ tensor_ref_d.sync_device(); -+ -+ // Initialize alpha and beta for dot product computation -+ ElementComputeEpilogue alpha = ElementComputeEpilogue(1); -+ ElementComputeEpilogue beta = ElementComputeEpilogue(0); -+ -+ // Split K dimension into 1 partitions -+ int split_k_slices = 1; -+ -+ // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch -+ // instantiated CUTLASS kernel -+ typename Gemm::Arguments arguments{problem_size, // <- problem size of matrix multiplication -+ tensor_a.device_ref(), // <- reference to matrix A on device -+ tensor_b.device_ref(), // <- reference to matrix B on device -+ tensor_c.device_ref(), // <- reference to matrix C on device -+ tensor_d.device_ref(), // <- reference to matrix D on device -+ tensor_e_reordered.device_ref(), // <- reference to matrix E on device -+ {alpha, beta}, // <- tuple of alpha and beta -+ split_k_slices}; // <- k-dimension split factor -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ // Instantiate CUTLASS kernel depending on templates -+ Gemm gemm_op; -+ -+ // Check the problem size is supported or not -+ cutlass::Status status = gemm_op.can_implement(arguments); -+ CUTLASS_CHECK(status); -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ status = gemm_op.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(status); -+ -+ // Launch initialized CUTLASS kernel -+ status = gemm_op(); -+ CUTLASS_CHECK(status); -+ -+ // uncompress tensor_a based on meta data tensor_e. We need it for reference computing. -+ cutlass::uncompress(tensor_a_uncompressed.host_ref(), tensor_a.host_ref(), -+ tensor_e.host_ref(), problem_size.m(), problem_size.k()); -+ -+ // Create instantiation for host reference gemm kernel -+ cutlass::reference::host::Gemm -+ gemm_host; -+ -+ // Launch host reference gemm kernel -+ gemm_host(problem_size, -+ alpha, -+ tensor_a_uncompressed.host_ref(), -+ tensor_b.host_ref(), -+ beta, -+ tensor_c.host_ref(), -+ tensor_ref_d.host_ref()); -+ -+ // Copy output data from CUTLASS host for comparison -+ tensor_d.sync_host(); -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_d.host_view(), -+ tensor_ref_d.host_view()); -+ -+ std::cout << (passed ? "Passed" : "Failed") << std::endl; -+ -+ return (passed ? 0 : -1); -+} -+ -+int main() { -+ -+ bool notSupported = false; -+ -+ // Ampere Sparse Tensor Core operations exposed with mma.sync and ldmatrix are first available -+ // in CUDA 11.1. -+ // -+ // CUTLASS must be compiled with CUDA 11.1 Toolkit to run these examples. -+ -+ if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 1))) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.1 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return -1; -+ } -+ -+ if (props.major * 10 + props.minor < 80) { -+ std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80." -+ << std::endl; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ // Returning zero so this test passes on older Toolkits. Its actions are no-op. -+ return 0; -+ } -+ -+ return run(); -+} -diff --git a/3rdparty/cutlass/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu b/3rdparty/cutlass/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu -new file mode 100644 -index 0000000..378b489 ---- /dev/null -+++ b/3rdparty/cutlass/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu -@@ -0,0 +1,772 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+ -+This example shows how to run convolution kernels using functions and data structures -+provided by CUTLASS using tensor cores; which we run on a NVIDIA Ampere GPU. -+ -+Writing a single high performance convolution kernel is hard but do-able. Whereas writing -+high performance kernels at scale which works for multiple problem sizes with good abstractions is -+really hard. CUTLASS solves this problem by providing simplified abstractions to compose -+multiple sections of implicit gemm kernel. When used properly, the kernels can hit peak performance -+of GPU easily. -+ -+CUTLASS divides a kernel into hierarchical composable sections. Which means, at each thread, warp -+and thread-block level, they compute on their own tile-size with higher level of tile sizes being -+composed from lower level ones. Multiple thread-tiles (tile size each thread computes) can be used -+to form warp-tiles (tile size each warp computes) and multiple warp tiles can be used to compute -+threadblock-tile (tile size computed by a threadblock). -+ -+In thie example, we split variable initialization into -+1. Setting up data properties : describes how tensors are laid out in the memory and how the kernel -+can view them (logical to physical mapping) -+2. Setting up computation properties : describes how the above set tensors will be used to compute -+output of convolution. -+ -+First, we setup the data types of the input tensor A, weights' tensor B and output tensor C along -+with alpha, beta as the equation for convolution is C = alpha * Conv2dFprop(A, B) + beta * C. In CUTLASS, -+the kernels first compute Conv2dFprop(A, B) and leave the rest of the computation to end of the kernel as -+alpha * X + beta * C is a simple element-wise operation on X (Conv2dFprop(A, B)) and C. We call this as -+epilogue of kernel. Hence, we setup data types for alpha and beta to be equal to -+ElementComputeEpilogue = float. We use the data type for elements in input tensor A and B as -+cutlass::half_t. We convey this to CUTLASS kernel by initializing template variables ElementAccumulator (float), -+ElementComputeEpilogue (float), ElementInputA (cutlass::half_t), ElementInputB (cutlass::half_t), -+ElementOutput (float). Communicating just the data type is not enough. As the data is laid out -+linearly in memory, we have to convey the layout of tensors. We do that by initializing template -+variables LayoutInputA, LayoutInputB and LayoutOutput to TensorNHWC cutlass variable. Next, we setup -+rules to comptue alpha * X + beta * C which is called epilogue of the kernel. We initialize template -+variable EpilogueOp, which takes the data type of output ElementOutput (float), the number of -+elements per vector memory access (8), data type of accumulator (float) and data type of -+computation of linear combination (alpha * X + beta * C). -+ -+Now that we setup the properties of data, we have to setup properties of computation. -+ -+Second, we create template variables of tile sizes for thread-block, warp and mma-op to 128x128x64, -+64x64x64, 16x8x16 (MxNxK) respectively. When passed to instantiate CUTLASS Implicit GEMM kernel, it -+internally deduces the amount of threads needed per thread-block, amount of shared memory, storing -+data in bank-conflict free manner, and ton of other variables required to compose, intialize and -+launch a high performance Implicit GEMM kernel. This is the beauty of CUTLASS, it relieves developer -+from understanding and coding complicated hardware optimizations which can easily go wrong. -+ -+CUTLASS also supports multiple MMA pipelines in a threadblock. What are MMA pipelines? MMA pipelines -+constitute the whole process of loading input data from global memory to shared memory, loading data -+from shared memory to registers, doing matrix multiplication, store to global memory. The below flow -+sequence shows a typical mma multistage pipeline. -+(see include/cutlass/conv/threadblock/implicit_gemm_multistage.h) -+ -+tensor in global memory --cp_async--> tile in shared memory --smem loads--> registers -+--mma--> registers --global stores--> output to global memory -+ -+NVIDIA Ampere uses `cp_async` to build multistage software pipeline to better hide latencies. -+ -+ -+There are few more template variables initialized such as, which threadblock tile of output matrix -+is done which threadblock launched on an SM, CUDA SM architecture of GPU you want to run on. -+ -+These are all put together to create a template variable which describes CUTLASS Implicit GEMM -+kernel using cutlass::conv::device::ImplicitGemm template. -+ -+The next step is to intialize physical data, instantiate and initialize CUTLASS kernel and run it. -+We use CUTLASS utilities to initialize, fill, compare tensors as they are simple and doesn't come -+in the way of learning CUTLASS. -+ -+Once all the tensors are initialized and filled with data, create arguments tuple to launch CUTLASS -+kernel which takes problem size (N = 1, H = 64, W = 64, C = 128), filter size (K = 64, -+R = 3, S = 3, C = 128 ), padding, strides, dilation, tensors, alpha, beta and the -+important one, split k-dimension factor. Along with that, we query CUTLASS if any scratch-space -+memory required by the kernel we instantiated. If yes, we create it and pass it along with other -+arguments created to intialize CUTLASS kernel then, the kernel is launched. -+ -+In this example, we later on launch a reference convolution kernel (from CUTLASS utilities) to -+compare if the output from CUTLASS kernel is same as the reference implicit GEMM kernel. -+*/ -+ -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/convolution.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "helper.h" -+ -+// The code section below describes datatype for input, output tensors and computation between -+// elements -+using ElementAccumulator = float; // Data type of accumulator -+using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta) -+using ElementInputA = cutlass::half_t; // Data type of elements in input tensor -+using ElementInputB = cutlass::half_t; // Data type of elements in input tensor -+using ElementOutput = float; // Data type of elements in output tensor -+ -+using LayoutInputA = cutlass::layout::TensorNHWC; -+using LayoutInputB = cutlass::layout::TensorNHWC; -+using LayoutOutput = cutlass::layout::TensorNHWC; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm80; -+ -+// This code section describes the tile size a thread block will compute -+using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; // Threadblock tile shape -+ -+// This code section describes tile size a warp will compute -+using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; // Warp tile shape -+ -+// This code section describes the size of MMA op -+using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // TensorCore instruction shape -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 3; -+ -+// This code section describe iterator algorithm selected is Analytic or Optimized -+static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized; -+ -+// This code section describes the epilogue part of the kernel, we use default value -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // Data type of output matrix. -+ 128 / cutlass::sizeof_bits::value, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. -+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue>; // Data type for alpha/beta in linear combination -+ -+using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementInputA, LayoutInputA, -+ ElementInputB, LayoutInputB, -+ ElementOutput, LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ IteratorAlgorithm -+>::Kernel; -+ -+using ImplicitGemm = cutlass::conv::device::ImplicitGemmConvolution; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ cutlass::Tensor4DCoord input_size; -+ cutlass::Tensor4DCoord filter_size; -+ cutlass::Tensor4DCoord padding; -+ cutlass::MatrixCoord conv_stride; -+ cutlass::MatrixCoord dilation; -+ bool reference_check; -+ bool measure_performance; -+ int iterations; -+ bool save_workspace; -+ ElementComputeEpilogue alpha; -+ ElementComputeEpilogue beta; -+ bool benchmark; -+ std::string tag; -+ -+ Options(): -+ help(false), -+ input_size(1, 32, 32, 32), -+ filter_size(32, 3, 3, 32), -+ padding(1, 1, 1, 1), -+ conv_stride(1, 1), -+ dilation(1, 1), -+ reference_check(false), -+ measure_performance(true), -+ iterations(20), -+ save_workspace(false), -+ alpha(1), -+ beta(0), -+ benchmark(false) { } -+ -+ // Verify the problem size is compatible with the CUTLASS Convolution implementation. -+ bool valid() { -+ -+ // -+ // CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently, -+ // all pointers, strides, and tensor extents must be divisible by 8 elements. -+ // -+ int const kAlignment = 8; -+ -+ if ((input_size.c() % kAlignment) || -+ (filter_size.n() % kAlignment)) { -+ -+ // misaligned tensors -+ return false; -+ } -+ -+ // Invalid padding -+ if ((padding.h() != filter_size.h() / 2) || -+ (padding.w() != filter_size.w() / 2)) { -+ -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Updates input and filter sizes -+ void update( -+ cutlass::Tensor4DCoord input_size, -+ cutlass::Tensor4DCoord filter_size) { -+ -+ this->input_size = input_size; -+ this->filter_size = filter_size; -+ -+ padding.n() = filter_size.h() / 2; -+ padding.h() = filter_size.h() / 2; -+ padding.w() = filter_size.w() / 2; -+ padding.c() = filter_size.w() / 2; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("ref-check")) { -+ reference_check = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("perf-check")) { -+ measure_performance = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("save-workspace")) { -+ save_workspace = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("benchmark")) { -+ benchmark = true; -+ } -+ -+ cmd.get_cmd_line_argument("n", input_size.n()); -+ cmd.get_cmd_line_argument("h", input_size.h()); -+ cmd.get_cmd_line_argument("w", input_size.w()); -+ cmd.get_cmd_line_argument("c", input_size.c()); -+ -+ cmd.get_cmd_line_argument("k", filter_size.n()); -+ cmd.get_cmd_line_argument("r", filter_size.h()); -+ cmd.get_cmd_line_argument("s", filter_size.w()); -+ filter_size.c() = input_size.c(); -+ -+ cmd.get_cmd_line_argument("alpha", alpha); -+ cmd.get_cmd_line_argument("beta", beta); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ cmd.get_cmd_line_argument("tag", tag); -+ -+ if (filter_size.h() == 3 && filter_size.w() == 3) { -+ padding = {1, 1, 1, 1}; -+ } -+ else { -+ filter_size.h() = 1; -+ filter_size.w() = 1; -+ padding = {0, 0, 0, 0}; -+ } -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "16_ampere_tensorop_conv2dfprop example\n\n" -+ << " This example uses Ampere's Tensor Core operators on F16 data types to compute\n" -+ << " forward convolution on tensors of layout NHWC.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --n= Input tensor extent N\n" -+ << " --h= Input tensor extent H\n" -+ << " --w= Input tensor extent W\n" -+ << " --c= Input tensor extent C\n" -+ << " --k= Filter extent K\n" -+ << " --r= Filter extent R\n" -+ << " --s= Filter extent S\n\n" -+ << " --alpha= Epilogue scalar alpha\n" -+ << " --beta= Epilogue scalar beta\n\n" -+ << " --ref-check If set (true), reference check on the host is computed\n" -+ << " --perf-check If set (true), performance is measured.\n" -+ << " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n" -+ << " --iterations= Number of profiling iterations to perform.\n" -+ << " --save-workspace If set, workspace is written to a text file.\n" -+ << " --tag= String to replicate across the first column in the results table\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/16_ampere_tensorop_conv2dfprop/16_ampere_tensorop_conv2dfprop --n=32 --h=224 --w=224 --c=128 --k=256 --r=1 --s=1\n\n" -+ << "$ ./examples/16_ampere_tensorop_conv2dfprop/16_ampere_tensorop_conv2dfprop --n=1 --h=224 --w=224 --c=32 --k=32 --r=3 --s=3 --ref-check\n\n"; -+ -+ return out; -+ } -+ -+ /// Computes the output tensor size (NPQK) -+ cutlass::Tensor4DCoord output_size() const { -+ return cutlass::Tensor4DCoord( -+ input_size.n(), -+ (input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1, -+ (input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1, -+ filter_size.n()); -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of multiply-adds = NPQK * CRS -+ int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c()); -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct Result { -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cutlass::Status reference_check; -+ cudaError_t error; -+ -+ Result(): -+ runtime_ms(0), -+ gflops(0), -+ status(cutlass::Status::kSuccess), -+ reference_check(cutlass::Status::kInvalid), -+ error(cudaSuccess) { } -+ -+ static std::ostream & print_header(std::ostream &out, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << "Name,"; -+ } -+ -+ out << "Layer,N,H,W,C,K,R,S,Runtime,GFLOPs"; -+ -+ return out; -+ } -+ -+ std::ostream & print(std::ostream &out, int idx, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << options.tag << ","; -+ } -+ -+ out -+ << "conv_" << idx << "," -+ << options.input_size.n() << "," -+ << options.input_size.h() << "," -+ << options.input_size.w() << "," -+ << options.input_size.c() << "," -+ << options.filter_size.n() << "," -+ << options.filter_size.h() << "," -+ << options.filter_size.w() << "," -+ << runtime_ms << "," -+ << gflops; -+ -+ return out; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Runs one benchmark -+Result profile_convolution(Options const &options) { -+ -+ Result result; -+ -+ // -+ // Allocate host-device tensors using the CUTLASS Utilities. -+ // -+ -+ cutlass::HostTensor tensor_a(options.input_size); -+ cutlass::HostTensor tensor_b(options.filter_size); -+ cutlass::HostTensor tensor_c(options.output_size()); -+ cutlass::HostTensor tensor_d(options.output_size()); -+ cutlass::HostTensor tensor_ref_d(options.output_size()); -+ -+ // -+ // Initialize tensors -+ // -+ -+ // Fill tensor A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1, -+ ElementInputA(7), -+ ElementInputA(-8), -+ 0); -+ -+ // Fill tensor B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 1, -+ ElementInputB(7), -+ ElementInputB(-8), -+ 0); -+ -+ // Fill tensor C on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c.host_view(), -+ 1, -+ ElementOutput(7), -+ ElementOutput(-8), -+ 0); -+ -+ // Fill tensor D on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_d.host_view()); -+ -+ // Fill tensor D for reference on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_ref_d.host_view()); -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_c.sync_device(); -+ tensor_d.sync_device(); -+ tensor_ref_d.sync_device(); -+ -+ // -+ // Define arguments for CUTLASS Convolution -+ // -+ -+ cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; -+ -+ // Split K dimension into 1 partitions -+ int split_k_slices = 1; -+ -+ // Construct Conv2dProblemSize with user defined output size -+ cutlass::conv::Conv2dProblemSize problem_size( -+ options.input_size, -+ options.filter_size, -+ options.padding, -+ options.conv_stride, -+ options.dilation, -+ options.output_size(), -+ mode, -+ split_k_slices -+ ); -+ -+ // Construct ImplicitGemm::Argument structure with conv2d -+ // problem size, data pointers, and epilogue values -+ typename ImplicitGemm::Arguments arguments{ -+ problem_size, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ tensor_c.device_ref(), -+ tensor_d.device_ref(), -+ {options.alpha, options.beta}, -+ }; -+ -+ // -+ // Initialize CUTLASS Convolution -+ // -+ -+ ImplicitGemm implicit_gemm_op; -+ -+ size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ result.status = implicit_gemm_op.can_implement(arguments); -+ CUTLASS_CHECK(result.status); -+ -+ result.status = implicit_gemm_op.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Launch initialized CUTLASS kernel -+ // -+ result.status = implicit_gemm_op(); -+ -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Optional reference check -+ // -+ -+ if (options.reference_check) { -+ std::cout << "Verification on host...\n"; -+ -+ // Compute with reference implementation -+ cutlass::reference::host::Conv2dFprop< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementComputeEpilogue, -+ ElementAccumulator, -+ cutlass::NumericConverter -+ >( -+ problem_size, -+ tensor_a.host_ref(), -+ tensor_b.host_ref(), -+ tensor_c.host_ref(), -+ tensor_ref_d.host_ref(), -+ options.alpha, -+ options.beta -+ ); -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ tensor_d.sync_host(); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_d.host_view(), -+ tensor_ref_d.host_view()); -+ -+ if (!passed) { -+ result.reference_check = cutlass::Status::kErrorInternal; -+ std::cout << "ERROR - results miscompared.\n"; -+ } -+ else { -+ result.reference_check = cutlass::Status::kSuccess; -+ std::cout << "Passed.\n"; -+ } -+ } -+ else { -+ result.reference_check = cutlass::Status::kInvalid; -+ } -+ -+ if (options.save_workspace) { -+ -+ std::stringstream ss; -+ -+ ss << "16_ampere_workspace_conv2dfprop_" -+ << options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c() -+ << "_" -+ << options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c() -+ << ".dat"; -+ -+ std::ofstream output_workspace(ss.str()); -+ -+ output_workspace -+ << "Input = \n" << tensor_a.host_view() << "\n\n" -+ << "Filters = \n" << tensor_b.host_view() << "\n\n"; -+ -+ if (options.reference_check) { -+ output_workspace << "Reference = \n" << tensor_ref_d.host_view() << "\n\n"; -+ } -+ -+ output_workspace << "Computed = \n" << tensor_d.host_view() << std::endl; -+ -+ std::cout << "Results written to '" << ss.str() << "'." << std::endl; -+ } -+ -+ // -+ // Performance measurement -+ // -+ -+ if (options.measure_performance) { -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ } -+ -+ // Record an event at the start of a series of convolution operations. -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Launch a sequence of implicit GEMM operations on the device -+ for (int iteration = 0; iteration < options.iterations; ++iteration) { -+ result.status = implicit_gemm_op(); -+ CUTLASS_CHECK(result.status); -+ } -+ -+ // Record an event when the convolutions have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Print average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // Cleanup -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ } -+ -+ return result; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ -+ bool notSupported = false; -+ -+ // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. -+ // -+ // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples. -+ if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); -+ -+ if (!(props.major >= 8)) { -+ std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80." -+ << std::endl; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ return 0; -+ } -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ if (options.benchmark) { -+ // Benchmark several layers -+ -+ int batch_sizes[] = {1, 32, 64, 128, 256, 512}; -+ -+ struct Benchmark { -+ int h, w, c, k, r, s; -+ } layers[] = { -+ {56, 56, 64, 256, 1, 1}, -+ {56, 56, 64, 64, 1, 1}, -+ {56, 56, 64, 64, 3, 3}, -+ {56, 56, 256, 64, 1, 1}, -+ {56, 56, 256, 512, 1, 1}, -+ {56, 56, 256, 128, 1, 1}, -+ {28, 28, 128, 128, 3, 3}, -+ {28, 28, 128, 512, 1, 1}, -+ {28, 28, 512, 128, 1, 1}, -+ {28, 28, 512, 1024, 1, 1}, -+ {28, 28, 512, 256, 1, 1}, -+ {14, 14, 256, 256, 3, 3}, -+ {14, 14, 256, 1024, 1, 1}, -+ {14, 14, 1024, 256, 1, 1}, -+ {14, 14, 1024, 2048, 1, 1}, -+ {14, 14, 1024, 512, 1, 1}, -+ {7, 7, 512, 512, 3, 3}, -+ }; -+ -+ Result::print_header(std::cout, options) << std::endl; -+ -+ int idx = 1; -+ -+ for (auto const &layer : layers) { -+ for (auto N : batch_sizes) { -+ -+ options.update({N, layer.h, layer.w, layer.c}, {layer.k, layer.r, layer.s, layer.c}); -+ -+ Result result = profile_convolution(options); -+ result.print(std::cout, idx, options) << std::endl; -+ } -+ -+ ++idx; -+ } -+ } -+ else { -+ -+ // Execute one problem size -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ Result result = profile_convolution(options); -+ -+ Result::print_header(std::cout, options) << std::endl; -+ result.print(std::cout, 1, options) << std::endl; -+ } -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/17_fprop_per_channel_bias/fprop_per_channel_bias.cu b/3rdparty/cutlass/examples/17_fprop_per_channel_bias/fprop_per_channel_bias.cu -new file mode 100644 -index 0000000..a334511 ---- /dev/null -+++ b/3rdparty/cutlass/examples/17_fprop_per_channel_bias/fprop_per_channel_bias.cu -@@ -0,0 +1,306 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+The convolution version of 12_gemm_bias_relu. Similarly, we put bias vector in Operand C and the -+rest is the same as normal convolution. -+*/ -+ -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/host_reorder.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/device/convolution.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "helper.h" -+ -+// The code section below describes datatype for input, output tensors and computation between -+// elements -+using ElementAccumulator = float; // Data type of accumulator -+using ElementComputeEpilogue = ElementAccumulator; // Data type of epilogue computation -+using ElementInputA = cutlass::half_t; // Data type of elements in input tensor -+using ElementInputB = cutlass::half_t; // Data type of elements in input tensor -+using ElementOutput = float; // Data type of elements in output tensor -+ -+using LayoutInputA = cutlass::layout::TensorNHWC; -+using LayoutInputB = cutlass::layout::TensorNHWC; -+using LayoutOutput = cutlass::layout::TensorNHWC; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm80; -+ -+// This code section describes the tile size a thread block will compute -+using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock tile shape -+ -+// This code section describes tile size a warp will compute -+using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; // Warp tile shape -+ -+// This code section describes the size of MMA op -+using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // TensorCore instruction shape -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 4; -+ -+// This code section describe iterator algorithm selected is Analytic or Optimized -+static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized; -+ -+// This code section describes the epilogue part of the kernel, we use default value -+using EpilogueOp = cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, // Data type of output matrix. -+ 128 / cutlass::sizeof_bits::value, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. -+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue, // Data type for alpha in linear combination -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling>; // alpha X C + per channel bias -+ -+ -+using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementInputA, LayoutInputA, -+ ElementInputB, LayoutInputB, -+ ElementOutput, LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ IteratorAlgorithm -+>::Kernel; -+ -+using ImplicitGemm = cutlass::conv::device::ImplicitGemmConvolution; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int run() { -+ -+ // Construct Conv2dProblemSize with user defined output size -+ cutlass::conv::Conv2dProblemSize problem_size( -+ {1, 7, 7, 512}, // activation -+ {512, 3, 3, 512}, // filter -+ {1, 1, 1, 1}, // padding -+ {1, 1}, // striding -+ {1, 1}, // dilation -+ cutlass::conv::Mode::kCrossCorrelation, // mode (convolution or cross-correlation) -+ 1 // split-k slices -+ ); -+ -+ // Initialize tensors using CUTLASS helper functions -+ cutlass::HostTensor tensor_a(problem_size.activation_extent()); -+ cutlass::HostTensor tensor_b(problem_size.filter_extent()); -+ -+ // Create tensor C with dimensions 1x1x1xk which is the bias vector -+ cutlass::HostTensor tensor_c_bias({1, 1, 1, problem_size.K}); -+ -+ // Create tensor D used to store output from CUTLASS kernel -+ cutlass::HostTensor tensor_d(problem_size.output_extent()); -+ // Create matrix D with dimensions M x N used to store output from reference -+ // kernel -+ cutlass::HostTensor tensor_ref_d(problem_size.output_extent()); -+ -+ // Fill input and output matrices on host using CUTLASS helper functions -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1, -+ ElementInputA(4), -+ ElementInputA(-4), -+ 0); // <- Fill tensor A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 1, -+ ElementInputB(4), -+ ElementInputB(-4), -+ 0); // <- Fill tensor B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c_bias.host_view(), -+ 1, -+ ElementOutput(4), -+ ElementOutput(-4), -+ 0); // <- Fill matrix C on host with uniform-distribution random data -+ cutlass::reference::host::TensorFill( -+ tensor_d.host_view()); // <- fill matrix D on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_c_bias.sync_device(); -+ tensor_d.sync_device(); -+ tensor_ref_d.sync_device(); -+ -+ // Initialize alpha for dot product computation -+ ElementComputeEpilogue alpha = ElementComputeEpilogue(1); -+ -+ // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch -+ // instantiated CUTLASS kernel -+ typename ImplicitGemm::Arguments arguments{ -+ problem_size, -+ tensor_a.device_ref(), // <- reference to tensor A on device -+ tensor_b.device_ref(), // <- reference to tensor B on device -+ // tensor C is treated as the bias vector. We can enable the CONV -+ // to project away the N, H, W dimension by setting the stride to zero. -+ {tensor_c_bias.device_data(), LayoutOutput::Stride(0)}, -+ tensor_d.device_ref(), // <- reference to tensor D on device -+ {alpha} }; -+ -+ // Instantiate CUTLASS kernel depending on templates -+ ImplicitGemm implicit_gemm_op; -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ // Check the problem size is supported or not -+ cutlass::Status status = implicit_gemm_op.can_implement(arguments); -+ CUTLASS_CHECK(status); -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ status = implicit_gemm_op.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(status); -+ -+ // Launch initialized CUTLASS kernel -+ status = implicit_gemm_op(); -+ -+ CUTLASS_CHECK(status); -+ -+ // -+ // Create instantiation for device reference conv kernel -+ // -+ -+ // Launch device reference to compute strictly the product A * B -+ cutlass::reference::device::Conv2d< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementComputeEpilogue, -+ ElementAccumulator, -+ cutlass::NumericConverter> -+ ( -+ cutlass::conv::Operator::kFprop, -+ problem_size, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ tensor_c_bias.device_ref(), -+ tensor_ref_d.device_ref(), -+ alpha, ElementComputeEpilogue(0) -+ ); -+ -+ // Wait for kernels to finish -+ cudaDeviceSynchronize(); -+ -+ // Copy output data from CUTLASS and reference kernel to host for comparison -+ tensor_d.sync_host(); -+ tensor_ref_d.sync_host(); -+ -+ // Compute bias + relu in host code -+ for (int n = 0; n < problem_size.N; ++n) { -+ for (int p = 0; p < problem_size.P; ++p) { -+ for (int q = 0; q < problem_size.Q; ++q) { -+ for (int k = 0; k < problem_size.K; ++k) { -+ -+ tensor_ref_d.at({n, p, q, k}) = -+ std::max(ElementOutput(0), -+ ElementOutput(tensor_ref_d.at({n, p, q, k}) + -+ tensor_c_bias.at({0, 0, 0, k}))); -+ } -+ } -+ } -+ } -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ std::cout << (cutlass::reference::host::TensorEquals(tensor_d.host_view(), -+ tensor_ref_d.host_view()) -+ ? "Passed" -+ : "Failed") -+ << std::endl; -+ -+ CUTLASS_CHECK(status); -+ return 0; -+} -+ -+int main(int argc, char const **args) { -+ -+ bool notSupported = false; -+ -+ // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. -+ // -+ // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples. -+ if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); -+ -+ if (!(props.major >= 8)) { -+ std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80." -+ << std::endl; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ return 0; -+ } -+ -+ return run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu b/3rdparty/cutlass/examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu -new file mode 100644 -index 0000000..d1044a2 ---- /dev/null -+++ b/3rdparty/cutlass/examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu -@@ -0,0 +1,342 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+In the normal GEMM, the fast changing dimension of a matrix always has stride -+equals to 1, e.g. ColumnMajor and RowMajor matrix. Affine2 matrix can have -+larger than 1 stride in both dimensions. To support such layout, we need to -+change to method to visit the global memory: -+ -+ 1. We can only visit 1 element a time because elements are not stored -+ consecutively anymore. Vectorized load/store is not possible. -+ 2. One extra multiplication is needed in calculating the global memory -+ address -+ addr = base_pointer + coord1 * stride1 + coord2 * stride2 -+ -+The rest part of GEMM which includes shared memory load/store, mma comutation -+is the same. -+ -+This example uses Ampere fp64 tensore core Affine2 GEMM as an example. SIMT -+(e.g. sgemm, dgemm) has support Affine2 layout. -+*/ -+ -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_universal.h" -+#include "cutlass/reduction/device/reduce_split_k.h" -+#include "cutlass/reduction/kernel/reduce_split_k.h" -+#include "cutlass/reduction/thread/reduction_operators.h" -+#include "cutlass/matrix_coord.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+#include "helper.h" -+ -+// The code section below describes datatype for input, output tensors and computation between -+// elements -+using ElementAccumulator = double; // Data type of accumulator -+using ElementComputeEpilogue = ElementAccumulator; // Data type of epilogue computation -+using ElementInputA = double; // Data type of elements in input tensor -+using ElementInputB = double; // Data type of elements in input tensor -+using ElementOutput = double; // Data type of elements in output tensor -+ -+// Since Affine2 explicitly lists the strides of both dimensions, it does not really matter if -+// it is columnmajor and rowmajor. However, it helps CUTLASS to improve the load locality if -+// CUTLASS can know which dimension of A/B operand has smaller stride or more dense. -+// -+// Affine2 ColumnMajor means the row stride is smaller and Affine2 RowMajor means the column -+// stride is smaller. -+// -+// The Affine2 epilogue reuses AffineN epilogue so it does not need to specify column majore -+// or row major. -+using LayoutInputA = cutlass::layout::AffineRank2ColumnMajor; -+using LayoutInputB = cutlass::layout::AffineRank2RowMajor; -+using LayoutOutput = cutlass::layout::AffineRankN<2>; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm80; -+ -+// This code section describes the tile size a thread block will compute -+using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>; // Threadblock tile shape -+ -+// This code section describes tile size a warp will compute -+using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; // Warp tile shape -+ -+// This code section describes the size of MMA op -+using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; // TensorCore instruction shape -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>; -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 3; -+ -+// This code section describes the epilogue part of the kernel, we use default value -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // Data type of output matrix. -+ 1, // The number of elements per memory -+ // access has. It has to be 1 for -+ // affine2. -+ ElementAccumulator, -+ ElementComputeEpilogue>; -+ -+using Gemm = typename cutlass::gemm::device::GemmUniversal< -+ ElementInputA, LayoutInputA, -+ ElementInputB, LayoutInputB, -+ ElementOutput, LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ 1, -+ 1, -+ cutlass::arch::OpMultiplyAdd -+>; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int run() { -+ -+ // Construct Gemm ProblemSize with user defined output size -+ cutlass::gemm::GemmCoord problem_size = {1024, 512, 1024}; -+ -+ // Stride factor shows the distance between two elements in the differnet dimensions. The -+ // first data is the logical distance between two rows, the second is between two columns. -+ // CUTLASS has a utility tool cutlass::layout::Affine2Layout_Factory::layout_factory -+ // to help to convert stride_factor to the two strides. -+ // -+ // It is also totally fine to compute the strides directly without using the utility to -+ // construct the affine2 layout. -+ typename LayoutInputA::Stride::Index stride_factor_A[] = {3, 4}; -+ typename LayoutInputB::Stride::Index stride_factor_B[] = {5, 6}; -+ typename LayoutOutput::Stride::Index stride_factor_C[] = {7, 8}; -+ -+ // Initialize tensors using CUTLASS helper functions -+ cutlass::HostTensor tensor_a(problem_size.mk(), -+ cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mk(), -+ stride_factor_A)); -+ cutlass::HostTensor tensor_b(problem_size.kn(), -+ cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.kn(), -+ stride_factor_B)); -+ -+ // Create matrix C used to load for bias addition. -+ cutlass::HostTensor tensor_c(problem_size.mn(), -+ cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mn(), -+ stride_factor_C)); -+ -+ // Create matrix D used to store output from CUTLASS kernel -+ cutlass::HostTensor tensor_d(problem_size.mn(), -+ cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mn(), -+ stride_factor_C)); -+ -+ // Create matrix D with dimensions M x N used to store output from reference -+ // kernel -+ cutlass::HostTensor tensor_ref_d(problem_size.mn(), -+ cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mn(), -+ stride_factor_C)); -+ -+ // Fill input and output matrices on host using CUTLASS helper functions -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1, -+ ElementInputA(4), -+ ElementInputA(-4), -+ 0); // <- Fill matrix A on host with uniform-distribution random data -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 1, -+ ElementInputB(4), -+ ElementInputB(-4), -+ 0); // <- Fill matrix B on host with uniform-distribution random data -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c.host_view(), -+ 1, -+ ElementOutput(4), -+ ElementOutput(-4), -+ 0); // <- Fill matrix C on host with uniform-distribution random data -+ -+ cutlass::reference::host::TensorFill( -+ tensor_d.host_view()); // <- fill matrix D on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_c.sync_device(); -+ tensor_d.sync_device(); -+ tensor_ref_d.sync_device(); -+ -+ // Initialize alpha for dot product computation -+ ElementComputeEpilogue alpha = ElementComputeEpilogue(1); -+ ElementComputeEpilogue beta = ElementComputeEpilogue(1); -+ -+ cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm; -+ -+ int batch_count = 1; -+ -+ // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch -+ // instantiated CUTLASS kernel -+ typename Gemm::Arguments arguments{ -+ mode, -+ problem_size, -+ batch_count, -+ {alpha, beta}, -+ tensor_a.device_ref().data(), // <- reference to matrix A on device -+ tensor_b.device_ref().data(), // <- reference to matrix B on device -+ tensor_c.device_ref().data(), // <- reference to matrix C on device -+ tensor_d.device_ref().data(), // <- reference to matrix D on device -+ tensor_a.layout().capacity(problem_size.mk()), -+ tensor_b.layout().capacity(problem_size.kn()), -+ tensor_c.layout().capacity(problem_size.mn()), -+ tensor_d.layout().capacity(problem_size.mn()), -+ tensor_a.layout().stride(), -+ tensor_b.layout().stride(), -+ tensor_c.layout().stride(), -+ tensor_d.layout().stride() -+ }; -+ -+ // Instantiate CUTLASS kernel depending on templates -+ Gemm gemm_op; -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ // Check the problem size is supported or not -+ cutlass::Status status = gemm_op.can_implement(arguments); -+ CUTLASS_CHECK(status); -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ status = gemm_op.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(status); -+ -+ // Launch initialized CUTLASS kernel -+ status = gemm_op(); -+ -+ CUTLASS_CHECK(status); -+ -+ // -+ // Create instantiation for device reference gemm kernel -+ // -+ -+ // Launch device reference to compute strictly the product A * B -+ cutlass::reference::device::Gemm< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementComputeEpilogue, -+ ElementAccumulator> gemm_device; -+ -+ gemm_device -+ ( -+ problem_size, -+ alpha, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ beta, -+ tensor_c.device_ref(), -+ tensor_ref_d.device_ref() -+ ); -+ -+ // Wait for kernels to finish -+ cudaDeviceSynchronize(); -+ -+ // Copy output data from CUTLASS and reference kernel to host for comparison -+ tensor_d.sync_host(); -+ tensor_ref_d.sync_host(); -+ -+ bool pass = cutlass::reference::host::TensorEquals(tensor_d.host_view(), -+ tensor_ref_d.host_view()); -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ std::cout << (pass -+ ? "Passed" -+ : "Failed") -+ << std::endl; -+ -+ CUTLASS_CHECK(status); -+ -+ return (pass ? 0 : -1); -+} -+ -+int main(int argc, char const **args) { -+ -+ bool notSupported = false; -+ -+ // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. -+ // -+ // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples. -+ if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); -+ -+ if (!(props.major >= 8)) { -+ std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80." -+ << std::endl; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ return 0; -+ } -+ -+ return run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/19_tensorop_canonical/tensorop_canonical.cu b/3rdparty/cutlass/examples/19_tensorop_canonical/tensorop_canonical.cu -new file mode 100644 -index 0000000..2a16936 ---- /dev/null -+++ b/3rdparty/cutlass/examples/19_tensorop_canonical/tensorop_canonical.cu -@@ -0,0 +1,438 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/* -+ This example requires NVIDIA Ampere GPU or later. -+*/ -+ -+// Standard Library includes -+#include -+#include -+#include -+ -+// CUTLASS Includes -+#include "cutlass/cutlass.h" -+#include "cutlass/functional.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/gemm/warp/default_mma_tensor_op.h" -+#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h" -+#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" -+ -+// CUTLASS Utility Includes -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/gemm_complex.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Define the overal warp-level problem shape -+int const kM = 27; -+int const kN = 31; -+int const kK = 17; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Define a warp-level GEMM operator. -+// -+// This template could be part of the CUTLASS Template Library or implemented internally. This -+// wraps the matrix multiply operation and epilogue with a GEMM-like interface that can be -+// instantiated in device code. -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+template < -+ typename Shape, -+ typename InstructionShape, -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementScalar -+> -+class GemmTensorOp { -+public: -+ -+ using WarpShape = GemmShape< -+ ((Shape::kM + InstructionShape::kM - 1) / InstructionShape::kM) * InstructionShape::kM, -+ ((Shape::kN + InstructionShape::kN - 1) / InstructionShape::kN) * InstructionShape::kN, -+ InstructionShape::kK -+ >; -+ -+ using MmaWarp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, -+ InstructionShape, -+ double, // Data type of A elements -+ cutlass::layout::RowMajor, // Layout of A matrix -+ double, // Data type of B elements -+ cutlass::layout::ColumnMajor, // Layout of B matrix -+ double, // Data type of C elements -+ cutlass::layout::RowMajor // Layout of C matrix -+ >::Type; -+ -+ // Number of 'K groups' -+ int const kKgroups = (Shape::kK + InstructionShape::kK - 1) / InstructionShape::kK; -+ -+ // Define a 'FragmentIterator' to iterate over slices of accumulators -+ using FragmentIterator = cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ typename MmaWarp::Shape, -+ InstructionShape, -+ double, -+ typename MmaWarp::Policy::Operator::FragmentC, -+ cutlass::layout::RowMajor -+ >; -+ -+ // Define an epilogue 'Tile Iteterator' to iterate over slices of elements in Shared Memory -+ using AccumulatorTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpCanonical< -+ typename MmaWarp::Shape, -+ InstructionShape, -+ double, -+ cutlass::layout::RowMajor -+ >; -+ -+ using TensorRefA = typename MmaWarp::IteratorA::TensorRef; -+ using TensorRefB = typename MmaWarp::IteratorB::TensorRef; -+ using TensorRefC = typename AccumulatorTileIterator::TensorRef; -+ -+public: -+ CUTLASS_HOST_DEVICE -+ GemmTensorOp() { } -+ -+ CUTLASS_DEVICE -+ void operator()( -+ ElementScalar alpha, -+ TensorRefA ref_A, -+ TensorRefB ref_B, -+ ElementScalar beta, -+ TensorRefC ref_C, -+ TensorRefC ref_D, -+ int lane_id) const { -+ -+ // Instantiate iterators pointing to slices of the A and B matrices in shared memory -+ typename MmaWarp::IteratorA iter_A(ref_A, {Shape::kM, Shape::kK}, lane_id); -+ typename MmaWarp::IteratorB iter_B(ref_B, {Shape::kK, Shape::kN}, lane_id); -+ -+ // Instantiate and clear accumulator tile holding the C matrix -+ typename MmaWarp::FragmentC accum; -+ accum.clear(); -+ -+ // Instantiate the warp-level matrix multiply operator -+ MmaWarp mma_op; -+ -+ // Instantiate fragments holding the slice of the matrix held by each warp -+ typename MmaWarp::FragmentA frag_A[2]; -+ typename MmaWarp::FragmentB frag_B[2]; -+ -+ // Load fragments from shared memory -+ iter_A.load(frag_A[0]); -+ iter_B.load(frag_B[0]); -+ -+ ++iter_A; -+ ++iter_B; -+ -+ // Load fragments from shared memory -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < kKgroups; ++k) { -+ -+ // Load fragments from shared memory -+ iter_A.load(frag_A[(k + 1) % 2]); -+ iter_B.load(frag_B[(k + 1) % 2]); -+ -+ ++iter_A; -+ ++iter_B; -+ -+ // Compute the matrix multiply -+ mma_op(accum, frag_A[k % 2], frag_B[k % 2], accum); -+ } -+ -+ // Instantiate iterators -+ FragmentIterator accum_frag_it(accum); -+ AccumulatorTileIterator source_tile_it(ref_C, {Shape::kM, Shape::kN}, lane_id); -+ AccumulatorTileIterator dest_tile_it(ref_D, {Shape::kM, Shape::kN}, lane_id); -+ -+ // Define function objects for linear scaling operation -+ cutlass::multiplies mul_source; -+ cutlass::multiply_add mul_add_accumulator; -+ -+ // Iterate over the epilogue components -+ CUTLASS_PRAGMA_UNROLL -+ for (int idx = 0; idx < FragmentIterator::kIterations; ++idx) { -+ -+ // Define storage for slices of the accumulators -+ typename FragmentIterator::Fragment accum_fragment; -+ typename FragmentIterator::Fragment source_fragment; -+ -+ // Select a slice of accumulators from the accumulator tile -+ accum_frag_it.load(accum_fragment); -+ ++accum_frag_it; -+ -+ // Load a corresponding slice from Shared memory -+ source_tile_it.load(source_fragment); -+ ++source_tile_it; -+ -+ // Compute linear scaling - alpha * AB + beta * C -+ source_fragment = mul_source(beta, source_fragment); -+ accum_fragment = mul_add_accumulator(alpha, accum_fragment, source_fragment); -+ -+ // Store the result to shared memory -+ dest_tile_it.store(accum_fragment); -+ ++dest_tile_it; -+ } -+ } -+}; -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Sample kernel demonstrating a collective GEMM operation by a warp on arbitrary matrices held -+// in Shared Memory. -+__global__ void kernel( -+ double *D_gmem, -+ double alpha, -+ double const *A_gmem, -+ double const *B_gmem, -+ double beta, -+ double const *C_gmem) { -+ -+ // Define several matrices in shared memory -+ __shared__ double A[kM][kK]; -+ __shared__ double B[kN][kK]; -+ __shared__ double C[kM][kN]; -+ -+ // Copy data into SMEM -+ if (threadIdx.x == 0) { -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int m = 0; m < kM; ++m) { -+ for (int k = 0; k < kK; ++k) { -+ A[m][k] = A_gmem[m * kK + k]; -+ } -+ } -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int n = 0; n < kN; ++n) { -+ for (int k = 0; k < kK; ++k) { -+ B[n][k] = B_gmem[n * kK + k]; -+ } -+ } -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int m = 0; m < kM; ++m) { -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int n = 0; n < kN; ++n) { -+ C[m][n] = C_gmem[m * kN + n]; -+ } -+ } -+ } -+ -+ __syncthreads(); -+ -+ // -+ // Instantiate a warp-level matrix multiply operator given the fundamental instruction shape (8x8x4), -+ // overall shape, data type of each operand, and layout of each operand. -+ // -+ -+ using GemmTensorOp = cutlass::gemm::warp::GemmTensorOp< -+ cutlass::gemm::GemmShape, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ double, // Data type of A elements -+ cutlass::layout::RowMajor, // Layout of A matrix -+ double, // Data type of B elements -+ cutlass::layout::ColumnMajor, // Layout of B matrix -+ double, // Data type of C elements -+ cutlass::layout::RowMajor, // Layout of C matrix -+ double // Scalar type of alpha and beta -+ >; -+ -+ // Instantiate the GEMM operator -+ GemmTensorOp gemm; -+ -+ // Execute the warp-level GEMM operation -+ gemm( -+ alpha, -+ {&A[0][0], kK}, -+ {&B[0][0], kK}, -+ beta, -+ {&C[0][0], kN}, -+ {&C[0][0], kN}, -+ threadIdx.x); -+ -+ __syncthreads(); -+ -+ // Copy data into SMEM -+ if (threadIdx.x == 0) { -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int m = 0; m < kM; ++m) { -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int n = 0; n < kN; ++n) { -+ D_gmem[m * kN + n] = C[m][n]; -+ } -+ } -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Entry point to canonical warp-level GEMM operation -+int main(int argc, const char *arg[]) { -+ -+ bool notSupported = false; -+ -+ // CUTLASS must be compiled with CUDA 11 Toolkit to run these examples. -+ if (!(__CUDACC_VER_MAJOR__ >= 11)) { -+ std::cerr << "NVIDIA Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return -1; -+ } -+ -+ if (!((props.major * 10 + props.minor) >= 80)) { -+ std::cerr << "This example requires compute capability at least 80." -+ << std::endl; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ // Return 0 so tests are considered passing if run on unsupported platforms. -+ return 0; -+ } -+ -+ cutlass::HostTensor A({kM, kK}); -+ cutlass::HostTensor B({kK, kN}); -+ cutlass::HostTensor C({kM, kN}); -+ cutlass::HostTensor D({kM, kN}); -+ -+ uint64_t seed = 2020; -+ double max = 8; -+ double min = -8; -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ A.host_view(), -+ seed, -+ max, -+ min, -+ 0 -+ ); -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ B.host_view(), -+ seed + 17, -+ max, -+ min, -+ 0 -+ ); -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ C.host_view(), -+ seed + 31, -+ max, -+ min, -+ 0 -+ ); -+ -+ A.sync_device(); -+ B.sync_device(); -+ C.sync_device(); -+ D.sync_device(); -+ -+ dim3 grid(1,1); -+ dim3 block(32, 1, 1); -+ -+ double alpha = 2.25; -+ double beta = 1.24; -+ -+ kernel<<< grid, block >>>( -+ D.device_data(), -+ alpha, -+ A.device_data(), -+ B.device_data(), -+ beta, -+ C.device_data() -+ ); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to synchronize device after kernel launch." << std::endl; -+ return -1; -+ } -+ -+ D.sync_host(); -+ -+ // Compute reference on host -+ cutlass::HostTensor D_ref({kM, kN}, false); -+ -+ cutlass::reference::host::GemmComplex( -+ {kM, kN, kK}, -+ alpha, -+ A.host_ref(), -+ cutlass::ComplexTransform::kNone, -+ B.host_ref(), -+ cutlass::ComplexTransform::kNone, -+ beta, -+ C.host_ref(), -+ D_ref.host_ref(), -+ double() -+ ); -+ -+ // Verify reference matches computed -+ if (!cutlass::reference::host::TensorEquals( -+ D.host_view(), -+ D_ref.host_view())) { -+ -+ std::cerr -+ << "A =\n" << A.host_view() -+ << "\n\nB = \n" << B.host_view() -+ << "\n\nC = " << C.host_view() -+ << "\n\nRef =\n" << D_ref.host_view() -+ << "\n\nD =\n" << D.host_view() << "\n\n"; -+ -+ std::cerr << "Error - device results mismatch host reference." << std::endl; -+ -+ return -1; -+ } -+ -+ std::cout << "Passed" << std::endl; -+ -+ return 0; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/20_simt_canonical/simt_canonical.cu b/3rdparty/cutlass/examples/20_simt_canonical/simt_canonical.cu -new file mode 100644 -index 0000000..632cd22 ---- /dev/null -+++ b/3rdparty/cutlass/examples/20_simt_canonical/simt_canonical.cu -@@ -0,0 +1,425 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/* -+ This example requires NVIDIA Maxwell GPU or beyond. -+*/ -+ -+// Standard Library includes -+#include -+#include -+#include -+ -+// CUTLASS Includes -+#include "cutlass/cutlass.h" -+#include "cutlass/core_io.h" -+#include "cutlass/functional.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/gemm/warp/mma_simt.h" -+#include "cutlass/epilogue/warp/fragment_iterator_simt.h" -+#include "cutlass/epilogue/warp/tile_iterator_simt.h" -+ -+// CUTLASS Utility Includes -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/gemm_complex.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Define the overal warp-level problem shape -+int const kM = 14; -+int const kN = 27; -+int const kK = 17; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Define a warp-level GEMM operator. -+// -+// This template could be part of the CUTLASS Template Library or implemented internally. This -+// wraps the matrix multiply operation and epilogue with a GEMM-like interface that can be -+// instantiated in device code. -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+template < -+ typename Shape, -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementScalar -+> -+class GemmSimt { -+public: -+ -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ >; -+ -+ using MmaWarp = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<16, 32, 8>, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ // Number of 'K groups' -+ int const kKgroups = Shape::kK; -+ -+ using FragmentIterator = cutlass::epilogue::warp::FragmentIteratorSimt< -+ typename MmaWarp::Shape, -+ typename MmaWarp::ThreadMma, -+ layout::RowMajor, // SMEM layout -+ typename MmaWarp::Policy -+ >; -+ -+ using AccumulatorTileIterator = cutlass::epilogue::warp::TileIteratorSimtCanonical< -+ typename MmaWarp::Shape, -+ typename MmaWarp::ThreadMma, -+ float, // ElementAccumulator -+ layout::RowMajor, // SMEM layout -+ typename MmaWarp::Policy -+ >; -+ -+ using TensorRefA = typename MmaWarp::IteratorA::TensorRef; -+ using TensorRefB = typename MmaWarp::IteratorB::TensorRef; -+ using TensorRefC = typename AccumulatorTileIterator::TensorRef; -+ -+public: -+ CUTLASS_HOST_DEVICE -+ GemmSimt() { } -+ -+ CUTLASS_DEVICE -+ void operator()( -+ ElementScalar alpha, -+ TensorRefA ref_A, -+ TensorRefB ref_B, -+ ElementScalar beta, -+ TensorRefC ref_C, -+ TensorRefC ref_D, -+ int lane_id) const { -+ -+ // Instantiate iterators pointing to slices of the A and B matrices in shared memory -+ typename MmaWarp::IteratorA iter_A(ref_A, {Shape::kM, Shape::kK}, lane_id); -+ typename MmaWarp::IteratorB iter_B(ref_B, {Shape::kK, Shape::kN}, lane_id); -+ -+ // Instantiate and clear accumulator tile holding the C matrix -+ typename MmaWarp::FragmentC accum; -+ accum.clear(); -+ -+ // Instantiate the warp-level matrix multiply operator -+ MmaWarp mma_op; -+ -+ // Instantiate fragments holding the slice of the matrix held by each warp -+ typename MmaWarp::FragmentA frag_A[2]; -+ typename MmaWarp::FragmentB frag_B[2]; -+ -+ // Load fragments from shared memory -+ iter_A.load(frag_A[0]); -+ iter_B.load(frag_B[0]); -+ -+ ++iter_A; -+ ++iter_B; -+ -+ // Load fragments from shared memory -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < kKgroups; ++k) { -+ -+ // Load fragments from shared memory -+ iter_A.load(frag_A[(k + 1) % 2]); -+ iter_B.load(frag_B[(k + 1) % 2]); -+ -+ ++iter_A; -+ ++iter_B; -+ -+ // Compute the matrix multiply -+ mma_op(accum, frag_A[k % 2], frag_B[k % 2], accum); -+ } -+ -+ // Instantiate iterators -+ FragmentIterator accum_frag_it(accum); -+ AccumulatorTileIterator source_tile_it(ref_C, {Shape::kM, Shape::kN}, lane_id); -+ AccumulatorTileIterator dest_tile_it(ref_D, {Shape::kM, Shape::kN}, lane_id); -+ -+ // Define function objects for linear scaling operation -+ cutlass::multiplies mul_source; -+ cutlass::multiply_add mul_add_accumulator; -+ -+ // Iterate over the epilogue components -+ CUTLASS_PRAGMA_UNROLL -+ for (int idx = 0; idx < FragmentIterator::kIterations; ++idx) { -+ -+ // Define storage for slices of the accumulators -+ typename FragmentIterator::Fragment accum_fragment; -+ typename FragmentIterator::Fragment source_fragment; -+ -+ // Select a slice of accumulators from the accumulator tile -+ accum_frag_it.load(accum_fragment); -+ ++accum_frag_it; -+ -+ // Load a corresponding slice from Shared memory -+ source_tile_it.load(source_fragment); -+ ++source_tile_it; -+ -+ // Compute linear scaling - alpha * AB + beta * C -+ source_fragment = mul_source(beta, source_fragment); -+ accum_fragment = mul_add_accumulator(alpha, accum_fragment, source_fragment); -+ -+ // Store the result to shared memory -+ dest_tile_it.store(accum_fragment); -+ ++dest_tile_it; -+ } -+ -+ } -+ -+}; -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Sample kernel demonstrating a collective GEMM operation by a warp on arbitrary matrices held -+// in Shared Memory. -+__global__ void kernel( -+ float *D_gmem, -+ float alpha, -+ float const *A_gmem, -+ float const *B_gmem, -+ float beta, -+ float const *C_gmem) { -+ -+ // Define several matrices in shared memory -+ __shared__ float A[kM][kK]; -+ __shared__ float B[kN][kK]; -+ __shared__ float C[kM][kN]; -+ -+ // Copy data into SMEM -+ if (threadIdx.x == 0) { -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int m = 0; m < kM; ++m) { -+ for (int k = 0; k < kK; ++k) { -+ A[m][k] = A_gmem[m * kK + k]; -+ } -+ } -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int n = 0; n < kN; ++n) { -+ for (int k = 0; k < kK; ++k) { -+ B[n][k] = B_gmem[n * kK + k]; -+ } -+ } -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int m = 0; m < kM; ++m) { -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int n = 0; n < kN; ++n) { -+ C[m][n] = C_gmem[m * kN + n]; -+ } -+ } -+ } -+ -+ __syncthreads(); -+ -+ // -+ // Instantiate a warp-level matrix multiply operator given the fundamental instruction shape (8x8x4), -+ // overall shape, data type of each operand, and layout of each operand. -+ // -+ -+ using GemmSimt = cutlass::gemm::warp::GemmSimt< -+ cutlass::gemm::GemmShape, -+ float, // Data type of A elements -+ cutlass::layout::RowMajor, // Layout of A matrix -+ float, // Data type of B elements -+ cutlass::layout::ColumnMajor, // Layout of B matrix -+ float, // Data type of C elements -+ cutlass::layout::RowMajor, // Layout of C matrix -+ float // Scalar type of alpha and beta -+ >; -+ -+ // Instantiate the GEMM operator -+ GemmSimt gemm; -+ -+ // Execute the warp-level GEMM operation -+ gemm( -+ alpha, -+ {&A[0][0], kK}, -+ {&B[0][0], kK}, -+ beta, -+ {&C[0][0], kN}, -+ {&C[0][0], kN}, -+ threadIdx.x); -+ -+ __syncthreads(); -+ -+ // Copy data into SMEM -+ if (threadIdx.x == 0) { -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int m = 0; m < kM; ++m) { -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int n = 0; n < kN; ++n) { -+ D_gmem[m * kN + n] = C[m][n]; -+ } -+ } -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, const char *arg[]) { -+ -+ cutlass::HostTensor A({kM, kK}); -+ cutlass::HostTensor B({kK, kN}); -+ cutlass::HostTensor C({kM, kN}); -+ cutlass::HostTensor D({kM, kN}); -+ -+ uint64_t seed = 2020; -+ float max = 8; -+ float min = -8; -+ -+ std::cout << "Simt canonical GEMM problem size = (" << cutlass::gemm::GemmShape() <<")" << std::endl; -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ A.host_view(), -+ seed, -+ max, -+ min, -+ 0 -+ ); -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ B.host_view(), -+ seed + 17, -+ max, -+ min, -+ 0 -+ ); -+ -+#if 0 // Debug: fill A sequentially and B as Identity matrix for debugging -+ cutlass::reference::host::BlockFillSequential( -+ A.host_view().data(), A.host_view().capacity()); -+ -+ cutlass::reference::host::TensorFillIdentity(B.host_view()); -+#endif -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ C.host_view(), -+ seed + 31, -+ max, -+ min, -+ 0 -+ ); -+ -+ A.sync_device(); -+ B.sync_device(); -+ C.sync_device(); -+ D.sync_device(); -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ float alpha = 1.0f; -+ float beta = 0.0f; -+ -+ kernel<<< grid, block >>>( -+ D.device_data(), -+ alpha, -+ A.device_data(), -+ B.device_data(), -+ beta, -+ C.device_data() -+ ); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to synchronize device after kernel launch." << std::endl; -+ return -1; -+ } -+ -+ D.sync_host(); -+ -+ // Compute reference on host -+ cutlass::HostTensor D_ref({kM, kN}, false); -+ cutlass::reference::host::TensorCopy(D_ref.host_view(), C.host_view()); -+ -+ cutlass::reference::host::Gemm< -+ float, cutlass::layout::RowMajor, -+ float, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::RowMajor, -+ float, float> reference_gemm; -+ -+ reference_gemm( -+ {kM, kN, kK}, -+ alpha, -+ A.host_ref(), -+ B.host_ref(), -+ beta, -+ D_ref.host_ref(), -+ float() -+ ); -+ -+ // Verify reference matches computed -+ if (!cutlass::reference::host::TensorEquals( -+ D.host_view(), -+ D_ref.host_view())) { -+ -+ std::cerr -+ << "A =\n" << A.host_view() -+ << "\n\nB = \n" << B.host_view() -+ << "\n\nC = " << C.host_view() -+ << "\n\nRef =\n" << D_ref.host_view() -+ << "\n\nD =\n" << D.host_view() << "\n\n"; -+ -+ std::cerr << "Error - device results mismatch host reference." << std::endl; -+ -+ return -1; -+ } -+ -+ std::cout << "Passed" << std::endl; -+ -+ return 0; -+ -+} -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/21_quaternion_gemm/quaternion_gemm.cu b/3rdparty/cutlass/examples/21_quaternion_gemm/quaternion_gemm.cu -new file mode 100644 -index 0000000..02d7b53 ---- /dev/null -+++ b/3rdparty/cutlass/examples/21_quaternion_gemm/quaternion_gemm.cu -@@ -0,0 +1,454 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "helper.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Result structure -+struct Result { -+ -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cudaError_t error; -+ bool passed; -+ -+ // -+ // Methods -+ // -+ -+ Result( -+ double runtime_ms = 0, -+ double gflops = 0, -+ cutlass::Status status = cutlass::Status::kSuccess, -+ cudaError_t error = cudaSuccess -+ ): -+ runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ -+ cutlass::gemm::GemmCoord problem_size; -+ int batch_count; -+ cutlass::Quaternion alpha; -+ cutlass::Quaternion beta; -+ -+ bool reference_check; -+ int iterations; -+ -+ Options(): -+ help(false), -+ problem_size({1024, 1024, 1024}), -+ batch_count(1), -+ reference_check(true), -+ iterations(20), -+ alpha(1), -+ beta() { } -+ -+ bool valid() { -+ return true; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ cmd.get_cmd_line_argument("m", problem_size.m()); -+ cmd.get_cmd_line_argument("n", problem_size.n()); -+ cmd.get_cmd_line_argument("k", problem_size.k()); -+ cmd.get_cmd_line_argument("batch", batch_count); -+ -+ cmd.get_cmd_line_argument("alpha", alpha.w()); -+ cmd.get_cmd_line_argument("alpha_i", alpha.x()); -+ cmd.get_cmd_line_argument("alpha_j", alpha.y()); -+ cmd.get_cmd_line_argument("alpha_k", alpha.z()); -+ -+ cmd.get_cmd_line_argument("beta", beta.w()); -+ cmd.get_cmd_line_argument("beta_i", beta.x()); -+ cmd.get_cmd_line_argument("beta_j", beta.y()); -+ cmd.get_cmd_line_argument("beta_k", beta.z()); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "21_quaternion_gemm example\n\n" -+ << " This example uses the CUTLASS Library to execute Quaternion GEMM computations.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --m= GEMM M dimension\n" -+ << " --n= GEMM N dimension\n" -+ << " --k= GEMM K dimension\n" -+ << " --batch= Number of GEMM operations executed in one batch\n" -+ << " --alpha= Epilogue scalar alpha (real part)\n" -+ << " --alpha_i= Epilogue scalar alpha_i (imaginary part)\n" -+ << " --alpha_j= Epilogue scalar alpha_j (imaginary part)\n" -+ << " --alpha_k= Epilogue scalar alpha_k (imaginary part)\n" -+ << " --beta= Epilogue scalar beta (real part)\n\n" -+ << " --beta_i= Epilogue scalar beta_i (imaginary part)\n\n" -+ << " --beta_j= Epilogue scalar beta_j (imaginary part)\n\n" -+ << " --beta_k= Epilogue scalar beta_k (imaginary part)\n\n" -+ << " --iterations= Number of profiling iterations to perform.\n\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/21_quaternion_gemm/21_quaternion_gemm --batch=7 --m=1024 --n=512 --k=1024 \\\n" -+ << " --alpha=2 --alpha_i=-2 --beta=0.707 --beta_i=-.707\n\n"; -+ -+ return out; -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of real-valued multiply-adds -+ int64_t fmas = problem_size.product() * batch_count * 16; -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// The code section below describes datatype for input, output matrices and computation between -+// elements in input matrices. -+using precision = float; -+using Element = cutlass::Quaternion; -+using ElementComputeEpilogue = Element; // <- data type of epilogue operations -+using ElementAccumulator = Element; // <- data type of accumulator -+using ElementInputA = Element; // <- data type of elements in input matrix A -+using ElementInputB = Element; // <- data type of elements in input matrix B -+using ElementOutput = Element; // <- data type of elements in output matrix D -+ -+// The code section below describes matrix layout of input and output matrices. Column Major for -+// Matrix A, Row Major for Matrix B and Row Major for Matrix C -+using LayoutInputA = cutlass::layout::RowMajor; -+using LayoutInputB = cutlass::layout::ColumnMajor; -+using LayoutOutput = cutlass::layout::RowMajor; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassSimt; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm50; -+ -+// This code section describes the tile size a thread block will compute -+using ShapeMMAThreadBlock = -+ cutlass::gemm::GemmShape<64, 64, 4>; // <- threadblock tile M = 64, N = 64, K = 8 -+// This code section describes tile size a warp will compute -+using ShapeMMAWarp = cutlass::gemm::GemmShape<32, 16, 4>; // <- warp tile M = 32, N = 16, K = 8 -+// This code section describes the size of MMA op -+using ShapeMMAOp = cutlass::gemm::GemmShape<1, 1, 1>; // <- MMA Op tile M = 1, N = 1, K = 1 -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- Defaults -+ -+// This code section describes the epilogue part of the kernel -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // <- data type of output matrix -+ 128 / cutlass::sizeof_bits::value, // <- the number of elements per vectorized -+ // memory access. For a byte, it's 16 -+ // elements. This becomes the vector width of -+ // math instructions in the epilogue too -+ ElementAccumulator, // <- data type of accumulator -+ ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 2; -+ -+using Gemm = cutlass::gemm::device::Gemm; -+ -+int run(Options options) { -+ -+ // PASS/FAIL status -+ bool passed = true; -+ -+ // Create a tuple of problem size for matrix multiplication -+ cutlass::gemm::GemmCoord problem_size = options.problem_size; -+ -+ // Initialize tensors using CUTLASS helper functions -+ cutlass::HostTensor tensor_a( -+ problem_size.mk()); // <- Create matrix A with dimensions M x K -+ cutlass::HostTensor tensor_b( -+ problem_size.kn()); // <- Create matrix B with dimensions K x N -+ cutlass::HostTensor tensor_c( -+ problem_size.mn()); // <- Create matrix C with dimensions M x N -+ cutlass::HostTensor tensor_d( -+ problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from -+ // CUTLASS kernel -+ cutlass::HostTensor tensor_ref_d( -+ problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from -+ // reference kernel -+ -+ // Fill input and output matrices on host using CUTLASS helper functions -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1, -+ 4, -+ -4, -+ 0); // <- Fill matrix A on host with uniform-distribution random data -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 1, -+ 4, -+ -4, -+ 0); // <- Fill matrix B on host with uniform-distribution random data -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c.host_view(), -+ 1, -+ 4, -+ -4, -+ 0); // <- Fill matrix C on host with uniform-distribution random data -+ -+ cutlass::reference::host::TensorFill( -+ tensor_d.host_view()); // <- fill matrix D on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_c.sync_device(); -+ tensor_d.sync_device(); -+ tensor_ref_d.sync_device(); -+ -+ // Initialize alpha and beta for dot product computation -+ ElementComputeEpilogue alpha = ElementComputeEpilogue(1); -+ ElementComputeEpilogue beta = ElementComputeEpilogue(0); -+ -+ // Split K dimension into 1 partitions -+ int split_k_slices = 1; -+ -+ // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch -+ // instantiated CUTLASS kernel -+ typename Gemm::Arguments arguments{problem_size, // <- problem size of matrix multiplication -+ tensor_a.device_ref(), // <- reference to matrix A on device -+ tensor_b.device_ref(), // <- reference to matrix B on device -+ tensor_c.device_ref(), // <- reference to matrix C on device -+ tensor_d.device_ref(), // <- reference to matrix D on device -+ {alpha, beta}, // <- tuple of alpha and beta -+ split_k_slices}; // <- k-dimension split factor -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ // Instantiate CUTLASS kernel depending on templates -+ Gemm gemm_op; -+ -+ // Check the problem size is supported or not -+ cutlass::Status status = gemm_op.can_implement(arguments); -+ CUTLASS_CHECK(status); -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ status = gemm_op.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(status); -+ -+ // Result structure -+ Result result; -+ -+ // -+ // Construct events -+ // -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ } -+ -+ // Record an event at the start of a series of GEMMs -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ -+ // -+ // Run profiling loop -+ // -+ -+ for (int iter = 0; iter < options.iterations; ++iter) { -+ -+ // Launch initialized CUTLASS kernel -+ status = gemm_op(); -+ CUTLASS_CHECK(status); -+ -+ } -+ -+ // -+ // Stop profiling loop -+ // -+ -+ // Record an event when the GEMMs are complete -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ -+ // Compute average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // Cleanup -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ if (options.reference_check) { -+ -+ // Create instantiation for device reference gemm kernel -+ cutlass::reference::device::Gemm gemm_device; -+ -+ // Launch device reference gemm kernel -+ gemm_device(problem_size, -+ alpha, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ beta, -+ tensor_c.device_ref(), -+ tensor_ref_d.device_ref()); -+ -+ // Wait for kernels to finish -+ cudaDeviceSynchronize(); -+ -+ // Copy output data from CUTLASS and reference kernel to host for comparison -+ tensor_d.sync_host(); -+ tensor_ref_d.sync_host(); -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ passed &= cutlass::reference::host::TensorEquals( -+ tensor_d.host_view(), -+ tensor_ref_d.host_view()); -+ -+ } -+ -+ if (passed) { -+ std::cout << "Runtime: " << result.runtime_ms << " ms" << std::endl; -+ std::cout << " GFLOPs: " << result.gflops << std::endl; -+ } -+ -+ std::cout << (passed ? "Passed" : "Failed") << std::endl; -+ return (passed ? 0 : -1); -+} -+ -+int main(int argc, char const** argv) { -+ -+ Options options; -+ options.parse(argc, argv); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ printf("%d x %d x %d Single Precision Quaternion Matrix Multiply\n", \ -+ options.problem_size.m(), options.problem_size.n(), options.problem_size.k()); -+ -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ return run(options); -+} -+ -diff --git a/3rdparty/cutlass/examples/22_quaternion_conv/quaternion_conv.cu b/3rdparty/cutlass/examples/22_quaternion_conv/quaternion_conv.cu -new file mode 100644 -index 0000000..57df73f ---- /dev/null -+++ b/3rdparty/cutlass/examples/22_quaternion_conv/quaternion_conv.cu -@@ -0,0 +1,667 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/convolution.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "helper.h" -+ -+// The code section below describes datatype for input, output tensors and computation between -+// elements -+using Element = cutlass::Quaternion; -+using ElementAccumulator = Element; // Data type of accumulator -+using ElementComputeEpilogue = Element; // Data type of epilogue computation (alpha, beta) -+using ElementInputA = Element; // Data type of elements in input tensor -+using ElementInputB = Element; // Data type of elements in input tensor -+using ElementOutput = Element; // Data type of elements in output tensor -+ -+using LayoutInputA = cutlass::layout::TensorNHWC; -+using LayoutInputB = cutlass::layout::TensorNHWC; -+using LayoutOutput = cutlass::layout::TensorNHWC; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassSimt; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm50; -+ -+// This code section describes the tile size a thread block will compute -+using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; // Threadblock tile shape -+ -+// This code section describes tile size a warp will compute -+using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; // Warp tile shape -+ -+// This code section describes the size of MMA op -+using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; // SIMT instruction shape -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 2; -+ -+// This code section describe iterator algorithm selected is Analytic or Optimized -+static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized; -+ -+// This code section describes the epilogue part of the kernel, we use default value -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // Data type of output matrix. -+ 128 / cutlass::sizeof_bits::value, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. -+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue>; // Data type for alpha/beta in linear combination -+ -+ -+using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementInputA, LayoutInputA, -+ ElementInputB, LayoutInputB, -+ ElementOutput, LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ IteratorAlgorithm -+>::Kernel; -+ -+using ImplicitGemm = cutlass::conv::device::ImplicitGemmConvolution; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ cutlass::Tensor4DCoord input_size; -+ cutlass::Tensor4DCoord filter_size; -+ cutlass::Tensor4DCoord padding; -+ cutlass::MatrixCoord conv_stride; -+ cutlass::MatrixCoord dilation; -+ bool reference_check; -+ bool measure_performance; -+ int iterations; -+ bool save_workspace; -+ ElementComputeEpilogue alpha; -+ ElementComputeEpilogue beta; -+ bool benchmark; -+ std::string tag; -+ -+ Options(): -+ help(false), -+ input_size(1, 32, 32, 32), -+ filter_size(32, 3, 3, 32), -+ padding(1, 1, 1, 1), -+ conv_stride(1, 1), -+ dilation(1, 1), -+ reference_check(false), -+ measure_performance(true), -+ iterations(20), -+ save_workspace(false), -+ alpha(1), -+ beta(0), -+ benchmark(false) { } -+ -+ // Verify the problem size is compatible with the CUTLASS Convolution implementation. -+ bool valid() { -+ -+ // -+ // CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently, -+ // all pointers, strides, and tensor extents must be divisible by 8 elements. -+ // -+ int const kAlignment = 8; -+ -+ if ((input_size.c() % kAlignment) || -+ (filter_size.n() % kAlignment)) { -+ -+ // misaligned tensors -+ return false; -+ } -+ -+ // Invalid padding -+ if ((padding.h() != filter_size.h() / 2) || -+ (padding.w() != filter_size.w() / 2)) { -+ -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Updates input and filter sizes -+ void update( -+ cutlass::Tensor4DCoord input_size, -+ cutlass::Tensor4DCoord filter_size) { -+ -+ this->input_size = input_size; -+ this->filter_size = filter_size; -+ -+ padding.n() = filter_size.h() / 2; -+ padding.h() = filter_size.h() / 2; -+ padding.w() = filter_size.w() / 2; -+ padding.c() = filter_size.w() / 2; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("ref-check")) { -+ reference_check = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("perf-check")) { -+ measure_performance = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("save-workspace")) { -+ save_workspace = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("benchmark")) { -+ benchmark = true; -+ } -+ -+ cmd.get_cmd_line_argument("n", input_size.n()); -+ cmd.get_cmd_line_argument("h", input_size.h()); -+ cmd.get_cmd_line_argument("w", input_size.w()); -+ cmd.get_cmd_line_argument("c", input_size.c()); -+ -+ cmd.get_cmd_line_argument("k", filter_size.n()); -+ cmd.get_cmd_line_argument("r", filter_size.h()); -+ cmd.get_cmd_line_argument("s", filter_size.w()); -+ filter_size.c() = input_size.c(); -+ -+ cmd.get_cmd_line_argument("alpha_w", alpha.w()); -+ cmd.get_cmd_line_argument("alpha_x", alpha.x()); -+ cmd.get_cmd_line_argument("alpha_y", alpha.y()); -+ cmd.get_cmd_line_argument("alpha_z", alpha.z()); -+ -+ cmd.get_cmd_line_argument("beta_w", beta.w()); -+ cmd.get_cmd_line_argument("beta_x", beta.x()); -+ cmd.get_cmd_line_argument("beta_y", beta.y()); -+ cmd.get_cmd_line_argument("beta_z", beta.z()); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ cmd.get_cmd_line_argument("tag", tag); -+ -+ if (filter_size.h() == 3 && filter_size.w() == 3) { -+ padding = {1, 1, 1, 1}; -+ } -+ else { -+ filter_size.h() = 1; -+ filter_size.w() = 1; -+ padding = {0, 0, 0, 0}; -+ } -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "22_quaternion_conv example\n\n" -+ << " This example uses Ampere's Tensor Core operators on F16 data types to compute\n" -+ << " forward convolution on tensors of layout NHWC.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --n= Input tensor extent N\n" -+ << " --h= Input tensor extent H\n" -+ << " --w= Input tensor extent W\n" -+ << " --c= Input tensor extent C\n" -+ << " --k= Filter extent K\n" -+ << " --r= Filter extent R\n" -+ << " --s= Filter extent S\n\n" -+ << " --alpha= Epilogue scalar alpha\n" -+ << " --beta= Epilogue scalar beta\n\n" -+ << " --ref-check If set (true), reference check on the host is computed\n" -+ << " --perf-check If set (true), performance is measured.\n" -+ << " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n" -+ << " --iterations= Number of profiling iterations to perform.\n" -+ << " --save-workspace If set, workspace is written to a text file.\n" -+ << " --tag= String to replicate across the first column in the results table\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/22_quaternion_conv/22_quaternion_conv --n=32 --h=224 --w=224 --c=128 --k=256 --r=1 --s=1\n\n" -+ << "$ ./examples/22_quaternion_conv/22_quaternion_conv --n=1 --h=224 --w=224 --c=32 --k=32 --r=3 --s=3 --ref-check\n\n"; -+ -+ return out; -+ } -+ -+ /// Computes the output tensor size (NPQK) -+ cutlass::Tensor4DCoord output_size() const { -+ return cutlass::Tensor4DCoord( -+ input_size.n(), -+ (input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1, -+ (input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1, -+ filter_size.n()); -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of multiply-adds = NPQK * CRS -+ int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c()) * 16; -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct Result { -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cutlass::Status reference_check; -+ cudaError_t error; -+ -+ Result(): -+ runtime_ms(0), -+ gflops(0), -+ status(cutlass::Status::kSuccess), -+ reference_check(cutlass::Status::kInvalid), -+ error(cudaSuccess) { } -+ -+ static std::ostream & print_header(std::ostream &out, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << "Name,"; -+ } -+ -+ out << "Layer,N,H,W,C,K,R,S,Runtime,GFLOPs"; -+ -+ return out; -+ } -+ -+ std::ostream & print(std::ostream &out, int idx, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << options.tag << ","; -+ } -+ -+ out -+ << "conv_" << idx << "," -+ << options.input_size.n() << "," -+ << options.input_size.h() << "," -+ << options.input_size.w() << "," -+ << options.input_size.c() << "," -+ << options.filter_size.n() << "," -+ << options.filter_size.h() << "," -+ << options.filter_size.w() << "," -+ << runtime_ms << "," -+ << gflops; -+ -+ return out; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Runs one benchmark -+Result profile_convolution(Options const &options) { -+ -+ Result result; -+ -+ // -+ // Allocate host-device tensors using the CUTLASS Utilities. -+ // -+ -+ cutlass::HostTensor tensor_a(options.input_size); -+ cutlass::HostTensor tensor_b(options.filter_size); -+ cutlass::HostTensor tensor_c(options.output_size()); -+ cutlass::HostTensor tensor_ref_c(options.output_size()); -+ -+ // -+ // Initialize tensors -+ // -+ -+ // Fill tensor A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1, -+ 7, -+ -8, -+ 0); -+ -+ // Fill tensor B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 1, -+ 7, -+ -8, -+ 0); -+ -+ // Fill tensor C on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_c.host_view()); -+ -+ // Fill tensor C for reference on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_ref_c.host_view()); -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_c.sync_device(); -+ tensor_ref_c.sync_device(); -+ -+ // -+ // Define arguments for CUTLASS Convolution -+ // -+ -+ cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; -+ -+ // Split K dimension into 1 partitions -+ int split_k_slices = 1; -+ -+ // Construct Conv2dProblemSize with user defined output size -+ cutlass::conv::Conv2dProblemSize problem_size( -+ options.input_size, -+ options.filter_size, -+ options.padding, -+ options.conv_stride, -+ options.dilation, -+ options.output_size(), -+ mode, -+ split_k_slices -+ ); -+ -+ // Construct ImplicitGemm::Argument structure with conv2d -+ // problem size, data pointers, and epilogue values -+ typename ImplicitGemm::Arguments arguments{ -+ problem_size, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ tensor_c.device_ref(), -+ tensor_c.device_ref(), -+ {options.alpha, options.beta}, -+ }; -+ -+ // -+ // Initialize CUTLASS Convolution -+ // -+ -+ ImplicitGemm implicit_gemm_op; -+ -+ size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ result.status = implicit_gemm_op.can_implement(arguments); -+ CUTLASS_CHECK(result.status); -+ -+ result.status = implicit_gemm_op.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Launch initialized CUTLASS kernel -+ // -+ result.status = implicit_gemm_op(); -+ -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Optional reference check -+ // -+ -+ if (options.reference_check) { -+ std::cout << "Verification on host...\n"; -+ -+ // Compute with reference implementation -+ cutlass::reference::host::Conv2dFprop< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementComputeEpilogue, -+ ElementAccumulator, -+ cutlass::NumericConverter -+ >( -+ problem_size, -+ tensor_a.host_ref(), -+ tensor_b.host_ref(), -+ tensor_c.host_ref(), -+ tensor_ref_c.host_ref(), -+ options.alpha, -+ options.beta -+ ); -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ tensor_c.sync_host(); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_c.host_view(), -+ tensor_ref_c.host_view()); -+ -+ if (!passed) { -+ result.reference_check = cutlass::Status::kErrorInternal; -+ std::cout << "ERROR - results miscompared.\n"; -+ } -+ else { -+ result.reference_check = cutlass::Status::kSuccess; -+ std::cout << "Passed.\n"; -+ } -+ } -+ else { -+ result.reference_check = cutlass::Status::kInvalid; -+ } -+ -+ if (options.save_workspace) { -+ -+ std::stringstream ss; -+ -+ ss << "22_quaternion_conv_" -+ << options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c() -+ << "_" -+ << options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c() -+ << ".dat"; -+ -+ std::ofstream output_workspace(ss.str()); -+ -+ output_workspace -+ << "Input = \n" << tensor_a.host_view() << "\n\n" -+ << "Filters = \n" << tensor_b.host_view() << "\n\n"; -+ -+ if (options.reference_check) { -+ output_workspace << "Reference = \n" << tensor_ref_c.host_view() << "\n\n"; -+ } -+ -+ output_workspace << "Computed = \n" << tensor_c.host_view() << std::endl; -+ -+ std::cout << "Results written to '" << ss.str() << "'." << std::endl; -+ } -+ -+ // -+ // Performance measurement -+ // -+ -+ if (options.measure_performance) { -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ } -+ -+ // Record an event at the start of a series of convolution operations. -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Launch a sequence of implicit GEMM operations on the device -+ for (int iteration = 0; iteration < options.iterations; ++iteration) { -+ result.status = implicit_gemm_op(); -+ CUTLASS_CHECK(result.status); -+ } -+ -+ // Record an event when the convolutions have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Print average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // Cleanup -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ } -+ -+ return result; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ if (options.benchmark) { -+ // Benchmark several layers -+ -+ int batch_sizes[] = {1, 32, 64, 128, 256, 512}; -+ -+ struct Benchmark { -+ int h, w, c, k, r, s; -+ } layers[] = { -+ {56, 56, 64, 256, 1, 1}, -+ {56, 56, 64, 64, 1, 1}, -+ {56, 56, 64, 64, 3, 3}, -+ {56, 56, 256, 64, 1, 1}, -+ {56, 56, 256, 512, 1, 1}, -+ {56, 56, 256, 128, 1, 1}, -+ {28, 28, 128, 128, 3, 3}, -+ {28, 28, 128, 512, 1, 1}, -+ {28, 28, 512, 128, 1, 1}, -+ {28, 28, 512, 1024, 1, 1}, -+ {28, 28, 512, 256, 1, 1}, -+ {14, 14, 256, 256, 3, 3}, -+ {14, 14, 256, 1024, 1, 1}, -+ {14, 14, 1024, 256, 1, 1}, -+ {14, 14, 1024, 2048, 1, 1}, -+ {14, 14, 1024, 512, 1, 1}, -+ {7, 7, 512, 512, 3, 3}, -+ }; -+ -+ Result::print_header(std::cout, options) << std::endl; -+ -+ int idx = 1; -+ -+ for (auto const &layer : layers) { -+ for (auto N : batch_sizes) { -+ -+ options.update({N, layer.h, layer.w, layer.c}, {layer.k, layer.r, layer.s, layer.c}); -+ -+ Result result = profile_convolution(options); -+ result.print(std::cout, idx, options) << std::endl; -+ } -+ -+ ++idx; -+ } -+ } -+ else { -+ -+ // Execute one problem size -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ Result result = profile_convolution(options); -+ -+ Result::print_header(std::cout, options) << std::endl; -+ result.print(std::cout, 1, options) << std::endl; -+ } -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu b/3rdparty/cutlass/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu -new file mode 100644 -index 0000000..81a3e15 ---- /dev/null -+++ b/3rdparty/cutlass/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu -@@ -0,0 +1,766 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+The example demenstrates how to reduce one of the operands of the GEMM along the k-dimension when -+computing GEMM. So the output also contains either a Mx1 or 1XN vector. It only works with Ampere -+16x8x16 FP16/BF16 tensor cores, though it is not difficult to apply to other Turing/Ampere tensor -+core instructions. -+ -+Most of the reduction is done in gemm/warp level, see gemm/warp/mma_with_reduction_tensor_op.h -+A few bit of reduction is done in the epilouge before storing the vector, see -+epilogue/threadblock/epilogue_gemm_k_reduction.h -+*/ -+ -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_with_k_reduction.h" -+#include "cutlass/gemm/kernel/default_gemm_with_k_reduction.h" -+#include "cutlass/reduction/device/reduce_split_k.h" -+#include "cutlass/reduction/kernel/reduce_split_k.h" -+#include "cutlass/reduction/thread/reduction_operators.h" -+#include "cutlass/matrix_coord.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/device/convolution.h" -+ -+#include "helper.h" -+ -+// The code section below describes datatype for input, output tensors and computation between -+// elements -+using ElementAccumulator = float; // Data type of accumulator -+using ElementComputeEpilogue = ElementAccumulator; // Data type of epilogue computation -+using ElementInputA = cutlass::bfloat16_t; // Data type of elements in input tensor -+using ElementInputB = cutlass::bfloat16_t; // Data type of elements in input tensor -+using ElementOutput = cutlass::bfloat16_t; // Data type of elements in output tensor -+ -+using LayoutInputA = cutlass::layout::ColumnMajor; -+using LayoutInputB = cutlass::layout::RowMajor; -+using LayoutOutput = cutlass::layout::ColumnMajor; -+// Layout of the output vector -+using LayoutGemmKReduction = cutlass::layout::PitchLinear; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm80; -+ -+// This code section describes the tile size a thread block will compute -+using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock tile shape -+ -+// This code section describes tile size a warp will compute -+using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; // Warp tile shape -+ -+// This code section describes the size of MMA op -+using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // TensorCore instruction shape -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>; -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 4; -+ -+// Reduce A or B operand along the K dimension -+constexpr bool ReduceKForA = true; -+ -+// Alignment of A operand -+constexpr int AlignmentA = 8; -+ -+// Alignment of B operand -+constexpr int AlignmentB = 8; -+ -+// This code section describes the epilogue part of the kernel, we use default value -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // Data type of output matrix. -+ 128 / cutlass::sizeof_bits::value, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. -+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue>; -+ -+using Gemm = typename cutlass::gemm::device::GemmWithKReduction< -+ ElementInputA, LayoutInputA, -+ ElementInputB, LayoutInputB, -+ ElementOutput, LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ ReduceKForA, -+ SmArch, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ AlignmentA, -+ AlignmentB, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone -+>; -+ -+// Below is the reduction kernel used in the case of parallel split-k -+using ReduceGemmSplitKShape = cutlass::MatrixShape<4, 64>;; -+ -+using ReduceOp = cutlass::reduction::thread::ReduceAdd< -+ ElementAccumulator, -+ ElementOutput, -+ EpilogueOp::kCount -+ >; -+ -+using ReduceGemmSplitKKernel = cutlass::reduction::kernel::ReduceSplitK< -+ ReduceGemmSplitKShape, -+ EpilogueOp, -+ ReduceOp -+ >; -+ -+using ReduceGemmSplitK = cutlass::reduction::device::ReduceSplitK; -+ -+using ReduceVectorSplitKShape = cutlass::MatrixShape<1, 256>;; -+ -+// This code section describes the epilogue part of the kernel, we use default value -+using DummyEpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // Data type of output matrix. -+ 128 / cutlass::sizeof_bits::value, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. -+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue, -+ cutlass::epilogue::thread::ScaleType::Nothing>; -+ -+using ReduceVectorSplitKKernel = cutlass::reduction::kernel::ReduceSplitK< -+ ReduceVectorSplitKShape, -+ DummyEpilogueOp, -+ ReduceOp -+ >; -+ -+using ReduceVectorSplitK = cutlass::reduction::device::ReduceSplitK; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ cutlass::gemm::GemmCoord problem_size; -+ int split_k_slices; -+ bool parallel_split_k; -+ bool reference_check; -+ bool measure_performance; -+ int iterations; -+ bool save_workspace; -+ ElementComputeEpilogue alpha; -+ ElementComputeEpilogue beta; -+ bool benchmark; -+ std::string tag; -+ -+ Options(): -+ help(false), -+ problem_size(1024, 1024, 1024), -+ split_k_slices(1), -+ parallel_split_k(false), -+ reference_check(true), -+ measure_performance(false), -+ iterations(20), -+ save_workspace(false), -+ alpha(-1), -+ beta(-1), -+ benchmark(false) { } -+ -+ // Verify the problem size is compatible with the CUTLASS Convolution implementation. -+ bool valid() { -+ -+ // -+ // CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently, -+ // all pointers, strides, and tensor extents must be divisible by 8 elements. -+ // -+ int const kAlignment = 8; -+ -+ if ((problem_size.m() % kAlignment) || -+ (problem_size.n() % kAlignment) || -+ (problem_size.k() % kAlignment)) { -+ -+ // misaligned tensors -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Updates input and filter sizes -+ void update( -+ cutlass::gemm::GemmCoord problem_size, -+ int split_k_slices, -+ bool parallel_split_k) { -+ -+ this->problem_size = problem_size; -+ this->split_k_slices = split_k_slices; -+ this->parallel_split_k = parallel_split_k; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("parallel-split-k")) { -+ parallel_split_k = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("ref-check")) { -+ reference_check = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("perf-check")) { -+ measure_performance = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("save-workspace")) { -+ save_workspace = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("benchmark")) { -+ benchmark = true; -+ } -+ -+ cmd.get_cmd_line_argument("m", problem_size.m()); -+ cmd.get_cmd_line_argument("n", problem_size.n()); -+ cmd.get_cmd_line_argument("k", problem_size.k()); -+ cmd.get_cmd_line_argument("split-k-slices", split_k_slices); -+ -+ cmd.get_cmd_line_argument("alpha", alpha); -+ cmd.get_cmd_line_argument("beta", beta); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ cmd.get_cmd_line_argument("tag", tag); -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "28_ampere_gemm_bias_fusion example\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --m= GEMM M\n" -+ << " --n= GEMM N\n" -+ << " --k= GEMM K\n" -+ << " --split-k-slices= Split K Slices\n" -+ << " --alpha= Epilogue scalar alpha\n" -+ << " --beta= Epilogue scalar beta\n\n" -+ << " --parallel-split-k If set (true), use parallel split K\n" -+ << " --ref-check If set (true), reference check on the host is computed\n" -+ << " --perf-check If set (true), performance is measured.\n" -+ << " --benchmark If set (true), performance benchmarking on several problem sizes.\n" -+ << " --iterations= Number of profiling iterations to perform.\n" -+ << " --save-workspace If set, workspace is written to a text file.\n" -+ << " --tag= String to replicate across the first column in the results table\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/23_ampere_gemm_bias_fusion_example/ampere_gemm_bias_fusion --m=1024 --n=1024 --k=1024 \n\n"; -+ -+ return out; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct Result { -+ double runtime_ms; -+ cutlass::Status status; -+ cutlass::Status reference_check; -+ cudaError_t error; -+ -+ Result(): -+ runtime_ms(0), -+ status(cutlass::Status::kSuccess), -+ reference_check(cutlass::Status::kInvalid), -+ error(cudaSuccess) { } -+ -+ static std::ostream & print_header(std::ostream &out, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << "Name,"; -+ } -+ -+ out << "ID,M,N,K,SplitK-Slices,Parallel-SplitK,Runtime"; -+ -+ return out; -+ } -+ -+ std::ostream & print(std::ostream &out, int idx, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << options.tag << ","; -+ } -+ -+ out -+ << "gemm_" << idx << "," -+ << options.problem_size.m() << "," -+ << options.problem_size.n() << "," -+ << options.problem_size.k() << "," -+ << options.split_k_slices << "," -+ << options.parallel_split_k << "," -+ << runtime_ms ; -+ -+ return out; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Runs one benchmark -+Result profile(Options const &options) { -+ -+ Result result; -+ -+ // Initialize tensors using CUTLASS helper functions -+ cutlass::HostTensor tensor_a(options.problem_size.mk()); -+ cutlass::HostTensor tensor_b(options.problem_size.kn()); -+ -+ -+ // Create tensor C with dimensions 1x1x1xk which is the bias vector -+ cutlass::HostTensor tensor_c(options.problem_size.mn()); -+ -+ // Create tensor D used to store output from CUTLASS kernel -+ cutlass::HostTensor tensor_d(options.problem_size.mn()); -+ // Create matrix D with dimensions M x N used to store output from reference -+ // kernel -+ cutlass::HostTensor tensor_ref_d(options.problem_size.mn()); -+ -+ int reduce_vector_length = ReduceKForA ? options.problem_size.m() : options.problem_size.n(); -+ -+ cutlass::HostTensor tensor_reduction({reduce_vector_length, 1}); -+ cutlass::HostTensor tensor_ref_reduction({reduce_vector_length, 1}); -+ -+ // Fill input and output matrices on host using CUTLASS helper functions -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1997, -+ ElementInputA(2), -+ ElementInputA(-2), -+ 0); // <- Fill tensor A on host with uniform-distribution random data -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 2003, -+ ElementInputB(2), -+ ElementInputB(-2), -+ 0); // <- Fill tensor B on host with uniform-distribution random data -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c.host_view(), -+ 2017, -+ ElementOutput(2), -+ ElementOutput(-2), -+ 0); // <- Fill matrix C on host with uniform-distribution random data -+ cutlass::reference::host::TensorFill( -+ tensor_d.host_view()); // <- fill matrix D on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros -+ -+ cutlass::reference::host::TensorFill( -+ tensor_reduction.host_view()); // <- fill matrix D on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_ref_reduction.host_view()); // <- fill matrix D for reference on host with zeros -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_c.sync_device(); -+ tensor_d.sync_device(); -+ tensor_ref_d.sync_device(); -+ tensor_reduction.sync_device(); -+ -+ // Initialize alpha for dot product computation -+ ElementComputeEpilogue alpha = options.parallel_split_k ? ElementComputeEpilogue(1) -+ : ElementComputeEpilogue(options.alpha); -+ ElementComputeEpilogue beta = options.parallel_split_k ? ElementComputeEpilogue(0) -+ : ElementComputeEpilogue(options.beta); -+ -+ cutlass::gemm::GemmUniversalMode mode = options.parallel_split_k ? -+ cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel : -+ cutlass::gemm::GemmUniversalMode::kGemm; -+ -+ int batch_count = options.split_k_slices; -+ -+ // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch -+ // instantiated CUTLASS kernel -+ typename Gemm::Arguments arguments( -+ mode, -+ options.problem_size, -+ batch_count, -+ {alpha, beta}, -+ tensor_a.device_ref().data(), // <- reference to tensor A on device -+ tensor_b.device_ref().data(), // <- reference to tensor B on device -+ tensor_c.device_ref().data(), // <- reference to matrix C on device -+ tensor_d.device_ref().data(), // <- reference to matrix D on device -+ tensor_reduction.device_ref().data(), // <- reference to reduction tensor on device -+ options.problem_size.m() * options.problem_size.k(), -+ options.problem_size.n() * options.problem_size.k(), -+ options.problem_size.m() * options.problem_size.n(), -+ options.problem_size.m() * options.problem_size.n(), -+ reduce_vector_length, -+ tensor_a.layout().stride(0), -+ tensor_b.layout().stride(0), -+ tensor_c.layout().stride(0), -+ tensor_d.layout().stride(0), -+ tensor_reduction.layout().stride(0)); -+ -+ // Instantiate CUTLASS kernel depending on templates -+ Gemm gemm_op; -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ // Check the problem size is supported or not -+ result.status = gemm_op.can_implement(arguments); -+ CUTLASS_CHECK(result.status); -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ result.status = gemm_op.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(result.status); -+ -+ // Launch initialized CUTLASS kernel -+ result.status = gemm_op(); -+ -+ CUTLASS_CHECK(result.status); -+ -+ if (options.parallel_split_k && batch_count > 1) { -+ // reduce gemm -+ -+ ElementComputeEpilogue alpha = ElementComputeEpilogue(options.alpha); -+ ElementComputeEpilogue beta = ElementComputeEpilogue(options.beta); -+ -+ int splitk_gemm_stride = options.problem_size.m(); -+ -+ cutlass::layout::RowMajor splitk_gemm_layout(splitk_gemm_stride); -+ -+ void * workspace_gemm_ptr = workspace.get(); -+ cutlass::TensorRef workspace_gemm_tensorref(static_cast(workspace_gemm_ptr), splitk_gemm_layout); -+ -+ cutlass::TensorRef tensor_d_tensorref(tensor_d.device_ref().data(), splitk_gemm_layout); -+ -+ cutlass::TensorRef tensor_c_tensorref(tensor_c.device_ref().data(), splitk_gemm_layout); -+ -+ typename ReduceGemmSplitK::Arguments reduce_gemm_splitk_arguments{ -+ cutlass::MatrixCoord(options.problem_size.n(), options.problem_size.m()), -+ batch_count, -+ size_t(options.problem_size.m() * options.problem_size.n()), -+ workspace_gemm_tensorref, -+ tensor_d_tensorref, -+ tensor_c_tensorref, -+ {alpha, beta} -+ }; -+ -+ ReduceGemmSplitK reduce_gemm_splitk_op; -+ -+ result.status = reduce_gemm_splitk_op.initialize(reduce_gemm_splitk_arguments); -+ CUTLASS_CHECK(result.status); -+ -+ result.status = reduce_gemm_splitk_op(); -+ CUTLASS_CHECK(result.status); -+ -+ // reduce k vector -+ cutlass::layout::RowMajor splitk_vector_layout(reduce_vector_length); -+ -+ ElementOutput *workspace_vector_ptr = static_cast(workspace_gemm_ptr) + batch_count * options.problem_size.m() * options.problem_size.n(); -+ cutlass::TensorRef workspace_vector_tensorref(workspace_vector_ptr, splitk_vector_layout); -+ -+ cutlass::TensorRef tensor_reduction_tensorref(tensor_reduction.device_ref().data(), splitk_vector_layout); -+ -+ cutlass::TensorRef tensor_nullptr_tensorref(nullptr, splitk_vector_layout); -+ -+ typename ReduceVectorSplitK::Arguments reduce_vector_splitk_arguments( -+ cutlass::MatrixCoord(1, reduce_vector_length), -+ batch_count, -+ size_t(reduce_vector_length), -+ workspace_vector_tensorref, -+ tensor_reduction_tensorref, -+ tensor_nullptr_tensorref, -+ {1.0f, 0.0f}); -+ -+ ReduceVectorSplitK reduce_vector_splitk_op; -+ -+ result.status = reduce_vector_splitk_op.initialize(reduce_vector_splitk_arguments); -+ CUTLASS_CHECK(result.status); -+ -+ result.status = reduce_vector_splitk_op(); -+ CUTLASS_CHECK(result.status); -+ } -+ -+ // -+ // Create instantiation for device reference conv kernel -+ // -+ if (options.reference_check) { -+ // Launch device reference to compute strictly the product A * B -+ cutlass::reference::device::Gemm< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementComputeEpilogue, -+ ElementAccumulator> gemm_device; -+ -+ gemm_device -+ ( -+ options.problem_size, -+ ElementComputeEpilogue(options.alpha), -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ ElementComputeEpilogue(options.beta), -+ tensor_c.device_ref(), -+ tensor_ref_d.device_ref() -+ ); -+ -+ // Wait for kernels to finish -+ cudaDeviceSynchronize(); -+ -+ // Copy output data from CUTLASS and reference kernel to host for comparison -+ tensor_d.sync_host(); -+ tensor_ref_d.sync_host(); -+ -+ tensor_reduction.sync_host(); -+ -+ // Reduce K in host code -+ if (ReduceKForA) { -+ for (int m = 0; m < options.problem_size.m(); ++m) { -+ for (int k = 0; k < options.problem_size.k(); ++k) { -+ tensor_ref_reduction.at({m, 0}) += -+ tensor_a.at(cutlass::MatrixCoord(m, k)); -+ } -+ } -+ } else { -+ for (int k = 0; k < options.problem_size.k(); ++k) { -+ for (int n = 0; n < options.problem_size.n(); ++n) { -+ tensor_ref_reduction.at({n, 0}) += -+ tensor_b.at(cutlass::MatrixCoord(k, n)); -+ } -+ } -+ } -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ bool pass = cutlass::reference::host::TensorEquals(tensor_d.host_view(), -+ tensor_ref_d.host_view()); -+ -+ pass &= cutlass::reference::host::TensorEquals(tensor_ref_reduction.host_view(), -+ tensor_reduction.host_view()); -+ -+ if (!pass) { -+ result.reference_check = cutlass::Status::kErrorInternal; -+ std::cout << "ERROR - results miscompared.\n"; -+ } else { -+ result.reference_check = cutlass::Status::kSuccess; -+ std::cout << "Passed.\n"; -+ } -+ } else { -+ result.reference_check = cutlass::Status::kInvalid; -+ } -+ -+ if (options.save_workspace) { -+ -+ std::stringstream ss; -+ -+ ss << "23_ampere_gemm_operand_reduction_fusion" -+ << options.problem_size.m() << "x" << options.problem_size.n() << "x" << options.problem_size.k() -+ << ".dat"; -+ -+ std::ofstream output_workspace(ss.str()); -+ -+ output_workspace -+ << "A = \n" << tensor_a.host_view() << "\n\n" -+ << "B = \n" << tensor_b.host_view() << "\n\n"; -+ -+ if (options.reference_check) { -+ output_workspace << "Reference D = \n" << tensor_ref_d.host_view() << "\n\n"; -+ output_workspace << "Reference reduction vector = \n" << tensor_ref_reduction.host_view() << "\n\n"; -+ } -+ -+ output_workspace << "Computed D = \n" << tensor_d.host_view() << std::endl; -+ output_workspace << "Computed reduction vector = \n" << tensor_reduction.host_view() << std::endl; -+ -+ std::cout << "Results written to '" << ss.str() << "'." << std::endl; -+ } -+ -+ // -+ // Performance measurement -+ // -+ -+ if (options.measure_performance) { -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ } -+ -+ // Record an event at the start of a series of convolution operations. -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Launch a sequence of implicit GEMM operations on the device -+ for (int iteration = 0; iteration < options.iterations; ++iteration) { -+ result.status = gemm_op(); -+ CUTLASS_CHECK(result.status); -+ } -+ -+ // Record an event when the convolutions have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Print average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ -+ // Cleanup -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ } -+ -+ return result; -+} -+ -+int main(int argc, char const **args) { -+ -+ bool notSupported = false; -+ -+ // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. -+ // -+ // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples. -+ if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); -+ -+ if (!(props.major >= 8)) { -+ std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80." -+ << std::endl; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ return 0; -+ } -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ if (options.benchmark) { -+ // Benchmark several layers -+ -+ struct Benchmark { -+ int m, n, k, split_k_slices, parallel_split_k; -+ } problem_sizes[] = { -+ {4096, 6144, 4096, 1, false}, -+ }; -+ -+ Result::print_header(std::cout, options) << "\n"; -+ -+ int idx = 1; -+ -+ for (auto const &problem_size : problem_sizes) { -+ options.update({problem_size.m, problem_size.n, problem_size.k}, -+ problem_size.split_k_slices, problem_size.parallel_split_k); -+ -+ Result result = profile(options); -+ result.print(std::cout, idx, options) << "\n"; -+ -+ ++idx; -+ } -+ } else { -+ -+ // Execute one problem size -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << "\n"; -+ return -1; -+ } -+ -+ Result result = profile(options); -+ -+ Result::print_header(std::cout, options) << "\n"; -+ result.print(std::cout, 1, options) << "\n"; -+ } -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/24_gemm_grouped/gemm_grouped.cu b/3rdparty/cutlass/examples/24_gemm_grouped/gemm_grouped.cu -new file mode 100644 -index 0000000..4b080fc ---- /dev/null -+++ b/3rdparty/cutlass/examples/24_gemm_grouped/gemm_grouped.cu -@@ -0,0 +1,1578 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief GEMM Grouped Example. -+ -+ This workload computes a batch of GEMM operations with distinct problem sizes. Pointers to matrices -+ in Global Memory are passed to the kernel in array (also held in Global Memory). Similarly, -+ leading dimensions and problem sizes are stored in arrays in GMEM. -+ -+ This differs from "Batched Array" GEMM because the size of each GEMM problem in the Grouped GEMM -+ concept may be distinct. -+ -+ This benchmark program initializes a workspace with random problem sizes for a given number of -+ groups. Command line options enable overriding M, N, and/or K dimensions with uniform values to -+ model problems more similar to the traditional batched GEMM. -+ -+ Additionally, problem sizes are collected and binned to compute the same problem as a series of -+ conventional batched GEMMs (setup for this problem is not timed). This demonstrates the performance -+ enhancement achieved by implementing a specialized grouped GEMM kernel. -+ -+ Examples: -+ -+ # Runs a grouped GEMM with 100 random problem sizes -+ $ ./examples/24_gemm_grouped/24_gemm_grouped --groups=100 -+ -+ # Runs a grouped GEMM with 100 random problem sizes (with GEMM-K dimension equal to 1024) -+ $ ./examples/24_gemm_grouped/24_gemm_grouped --groups=100 --k=1024 --verbose=true -+ -+ # Runs a grouped GEMM that is equivalent to a batched GEMM -+ $ ./examples/24_gemm_grouped/24_gemm_grouped --groups=100 --m=2048 --n=1024 --k=1024 --verbose=true -+ -+ # Execute Grouped GEMM and profile with NSight -+ $ nv-nsight-cu-cli ./examples/24_gemm_grouped/24_gemm_grouped --m=256 --n=256 --k=256 --verbose=true \ -+ --iterations=1 --reference-check=false -+ -+*/ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/gemm_grouped.h" -+#include "cutlass/gemm/kernel/default_gemm_grouped.h" -+#include "cutlass/gemm/device/gemm_grouped.h" -+#include "cutlass/gemm/device/gemm_universal.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/device_memory.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm_complex.h" -+#include "cutlass/util/reference/device/gemm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/device/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Result structure -+struct Result { -+ -+ double runtime_ms; -+ double initialization_time_ms; -+ double gflops; -+ cutlass::Status status; -+ cudaError_t error; -+ bool passed; -+ -+ // -+ // Methods -+ // -+ -+ Result( -+ double runtime_ms = 0, -+ double initialization_time_ms = 0, -+ double gflops = 0, -+ cutlass::Status status = cutlass::Status::kSuccess, -+ cudaError_t error = cudaSuccess -+ ): -+ runtime_ms(runtime_ms), initialization_time_ms(initialization_time_ms), gflops(gflops), -+ status(status), error(error), passed(true) { } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Hash function for cutlass::gemm::GemmCoord -+struct HashGemmCoord { -+ size_t operator()(cutlass::gemm::GemmCoord const &problem) const { -+ std::hash hasher; -+ return (hasher(problem.m() * 3)) ^ (hasher(1 + problem.n() * 5)) ^ (hasher(2 + problem.k() * 7)); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ bool error; -+ bool reference_check; -+ bool profile_initialization; -+ bool sort_problems; -+ -+ std::vector problem_sizes; -+ -+ // problem size bins -+ std::unordered_map< -+ cutlass::gemm::GemmCoord, -+ std::vector, -+ HashGemmCoord> problem_bins; -+ -+ int alignment; -+ int problem_count; -+ int iterations; -+ int cuda_streams; -+ bool verbose; -+ float alpha; -+ float beta; -+ std::string benchmark_path; -+ -+ std::string output_tag; -+ std::ofstream output_file; -+ -+ using GroupScheduleMode = cutlass::gemm::kernel::GroupScheduleMode; -+ std::vector scheduler_modes; -+ -+ std::unordered_map -+ str_to_scheduler_mode = { -+ {"kDeviceOnly", GroupScheduleMode::kDeviceOnly}, -+ {"kHostPrecompute", GroupScheduleMode::kHostPrecompute} -+ }; -+ -+ struct GroupScheduleModeHash { -+ size_t operator()(GroupScheduleMode m) const { -+ return static_cast(m); -+ } -+ }; -+ -+ std::unordered_map -+ scheduler_mode_to_str = { -+ {GroupScheduleMode::kDeviceOnly, "kDeviceOnly"}, -+ {GroupScheduleMode::kHostPrecompute, "kHostPrecompute"} -+ }; -+ -+ std::vector all_scheduler_modes = {GroupScheduleMode::kDeviceOnly, GroupScheduleMode::kHostPrecompute}; -+ -+ // -+ // Methods -+ // -+ -+ Options(): -+ help(false), -+ error(false), -+ alignment(8), -+ reference_check(true), -+ profile_initialization(false), -+ sort_problems(false), -+ problem_count(15), -+ iterations(20), -+ cuda_streams(0), -+ verbose(false), -+ alpha(1), -+ beta(), -+ scheduler_modes({GroupScheduleMode::kDeviceOnly}) -+ { } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ return; -+ } -+ -+ cmd.get_cmd_line_argument("alignment", alignment, 16); -+ cmd.get_cmd_line_argument("groups", problem_count, 8); -+ cmd.get_cmd_line_argument("alpha", alpha, 1.0f); -+ cmd.get_cmd_line_argument("beta", beta, 0.0f); -+ cmd.get_cmd_line_argument("iterations", iterations, 20); -+ cmd.get_cmd_line_argument("streams", cuda_streams, 0); -+ cmd.get_cmd_line_argument("verbose", verbose, true); -+ cmd.get_cmd_line_argument("reference-check", reference_check, true); -+ cmd.get_cmd_line_argument("profile-initialization", profile_initialization, false); -+ cmd.get_cmd_line_argument("sort-problems", sort_problems, false); -+ cmd.get_cmd_line_argument("benchmark", benchmark_path); -+ -+ std::vector scheduler_mode_strs; -+ cmd.get_cmd_line_arguments("scheduler-modes", scheduler_mode_strs); -+ -+ if (!scheduler_mode_strs.empty()) { -+ scheduler_modes.clear(); -+ if (scheduler_mode_strs.size() == 1 && scheduler_mode_strs[0] == "all") { -+ scheduler_modes = all_scheduler_modes; -+ } else { -+ for (std::string precomp_str : scheduler_mode_strs) { -+ auto it = str_to_scheduler_mode.find(precomp_str); -+ if (it != str_to_scheduler_mode.end()) { -+ scheduler_modes.push_back(it->second); -+ } else if (precomp_str == "all") { -+ std::cerr << "Flag --scheduler-modes=all must not contain other scheduler modes in list." << std::endl; -+ error = true; -+ return; -+ } else { -+ std::cerr << "Unrecognized scheduler mode '" << precomp_str << "'" << std::endl; -+ error = true; -+ return; -+ } -+ } -+ } -+ } -+ -+ std::string output_path; -+ cmd.get_cmd_line_argument("tag", output_tag); -+ cmd.get_cmd_line_argument("output_file", output_path); -+ -+ if (!output_path.empty()) { -+ -+ std::ios_base::openmode open_mode = std::ios_base::out; -+ -+ std::ifstream input_file(output_path.c_str()); -+ -+ if (input_file.good()) { -+ open_mode = std::ios_base::app; -+ input_file.close(); -+ } -+ -+ output_file.open(output_path.c_str(), open_mode); -+ -+ if (output_file.good() && open_mode != std::ios_base::app) { -+ output_file << "Tag,Provider,Kind,Groups,Runtime,GFLOPs\n"; -+ } -+ } -+ -+ // Decide how to initialize the problems -+ if (!benchmark_path.empty()) { -+ if (!benchmark_problems()) { -+ error = true; -+ problem_sizes.clear(); -+ return; -+ } -+ } -+ else { -+ randomize_problems(cmd); -+ } -+ -+ // Post-process the problem sizes -+ bin_problems(); -+ } -+ -+ void randomize_problems(cutlass::CommandLine &cmd) { -+ -+ // -+ // For now, randomly choose the problem sizes. -+ // -+ -+ int cmd_line_m = -1; -+ int cmd_line_n = -1; -+ int cmd_line_k = -1; -+ -+ cmd.get_cmd_line_argument("m", cmd_line_m,128); -+ cmd.get_cmd_line_argument("n", cmd_line_n,128); -+ cmd.get_cmd_line_argument("k", cmd_line_k,64); -+ -+ problem_sizes.reserve(problem_count); -+ -+ for (int i = 0; i < problem_count; ++i) { -+ -+ int m = cmd_line_m; -+ int n = cmd_line_n; -+ int k = cmd_line_k; -+ -+ if (m < 1) { -+ m = alignment * ((rand() % 256) + 1); -+ } -+ -+ if (n < 1) { -+ n = alignment * ((rand() % 256) + 1); -+ } -+ -+ if (k < 1) { -+ k = alignment * ((rand() % 256) + 1); -+ } -+ -+ cutlass::gemm::GemmCoord problem(m, n, k); -+ -+ problem_sizes.push_back(problem); -+ } -+ } -+ -+ /// Load a benchmark -+ bool benchmark_problems() { -+ std::ifstream file(benchmark_path); -+ if (!file.good()) { -+ return false; -+ } -+ -+ while (file.good()) { -+ -+ int idx = -1; -+ std::string extent_str; -+ -+ file >> idx >> extent_str; -+ -+ if (idx < 0 || extent_str.empty()) { -+ break; -+ } -+ -+ cutlass::gemm::GemmCoord extent; -+ std::vector tokens; -+ -+ cutlass::CommandLine::tokenize(tokens, extent_str, 'x'); -+ -+ for (int i = 0; i < int(tokens.size()); ++i) { -+ int x = std::atoi(tokens.at(i).c_str()); -+ -+ // round up -+ if (x % alignment) { -+ x += (alignment - (x % alignment)); -+ } -+ -+ extent.at(i) = x; -+ } -+ -+ if (extent.product()) { -+ problem_sizes.push_back(extent); -+ } -+ } -+ -+ return true; -+ } -+ -+ /// Post processes the problems -+ void bin_problems() { -+ -+ problem_bins.clear(); -+ -+ problem_count = int(problem_sizes.size()); -+ -+ // -+ // Insert the problem sizes into a sorted container class. This is *NOT* necessary -+ // to run the CUTLASS kernel, but it enables the execution of cublas's batched GEMM. -+ // -+ for (int i = 0; i < int(problem_sizes.size()); ++i) { -+ auto it = problem_bins.find(problem_sizes.at(i)); -+ if (it == problem_bins.end()) { -+ problem_bins.insert({problem_sizes.at(i), std::vector({i}) }); -+ } -+ else { -+ it->second.push_back(i); -+ } -+ } -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "24_gemm_grouped\n\n" -+ << " This example profiles the performance of a 'grouped' GEMM kernel. This is similar to batched GEMM\n" -+ << " in that multiple, independent GEMMs are computed by one grid launch. It differs in that each\n" -+ << " 'group' may compute a unique problem size. Problem sizes and pointers to matrices are both stored\n" -+ << " in device Global Memory and loaded by the kernel.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --benchmark= Executes a benchmark problem size.\n" -+ << " --output_file= Path to a CSV file to output results. If it exists already, results are appended.\n" -+ << " --tag= String tag to prepend to the CSV file.\n" -+ << " --groups= Number of individual GEMM problems (default: --groups=15)\n" -+ << " --m= Sets the M dimension for all groups. Otherwise, it is selected randomly\n" -+ << " --n= Sets the N dimension for all groups. Otherwise, it is selected randomly\n" -+ << " --k= Sets the K dimension for all groups. Otherwise, it is selected randomly\n" -+ << " --alpha= Epilogue scalar alpha (real part)\n" -+ << " --beta= Epilogue scalar beta (real part)\n" -+ << " --scheduler-modes= List of scheduler modes to be profile for grouped GEMM scheduler (default: --scheduler_modes=kDeviceOnly)\n" -+ << " --iterations= Number of profiling iterations to perform.\n" -+ << " --reference-check= If true, performs reference check.\n" -+ << " --verbose= If true, prints problem sizes and batching structure.\n" -+ << " --profile-initialization= If true, profiles the device-level kernel's initialization.\n" -+ << " --sort-problems= If true, sorts problem sizes in descending order of GEMM-K dimension.\n"; -+ -+ out << "\n\nExamples:\n\n" -+ -+ << "# Runs a grouped GEMM with 100 random problem sizes\n" -+ << "$ ./examples/24_gemm_grouped/24_gemm_grouped --groups=100\n\n" -+ -+ << "# Runs a grouped GEMM with 100 random problem sizes (with GEMM-K dimension equal to 1024)\n" -+ << "$ ./examples/24_gemm_grouped/24_gemm_grouped --groups=100 --k=1024 --verbose=true\n\n" -+ -+ << "# Runs a grouped GEMM that is equivalent to a batched GEMM\n" -+ << "$ ./examples/24_gemm_grouped/24_gemm_grouped --groups=100 --m=2048 --n=1024 --k=1024 --verbose=true\n\n" -+ -+ << "# Runs a grouped GEMM with each different scheduler mode\n" -+ << "$ ./examples/24_gemm_grouped/24_gemm_grouped --scheduler-modes=all\n\n" -+ -+ << "# Runs a grouped GEMM with each different scheduler mode and profiles host-side initialization time\n" -+ << "$ ./examples/24_gemm_grouped/24_gemm_grouped --scheduler-modes=all --profile-initialization=true\n\n" -+ -+ << "# Runs a grouped GEMM problem given an externally supplied benchmark file. This is a text file in which\n" -+ << "# Each line contains a unique group index and an MxNxK triple indicating problemsize.\n" -+ << "#\n" -+ << "# For example, assume the following are the contents of 'problems.txt'\n" -+ << "#\n" -+ << "# 0 1024x256x520\n" -+ << "# 1 520x264x1024\n" -+ << "# 2 96x48x1024\n" -+ << "#\n" -+ << "$ ./examples/24_gemm_grouped/24_gemm_grouped --benchmark=problems.txt\n\n" -+ -+ << "# Execute Grouped GEMM and profile with NSight\n" -+ << "$ nv-nsight-cu-cli ./examples/24_gemm_grouped/24_gemm_grouped --m=256 --n=256 --k=256 --verbose=true --iterations=1 --reference-check=false\n\n"; -+ -+ return out; -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of real-valued multiply-adds -+ int64_t fmas = int64_t(); -+ -+ for (auto const & problem : problem_sizes) { -+ fmas += problem.product(); -+ } -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class BaseTestbed { -+public: -+ // -+ // Type definitions -+ // -+ -+ using ElementA = typename Gemm::ElementA; -+ using ElementB = typename Gemm::ElementB; -+ using ElementC = typename Gemm::ElementC; -+ using ElementAccumulator = typename Gemm::ElementAccumulator; -+ -+ using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; -+ using ElementCompute = typename EpilogueOutputOp::ElementCompute; -+ -+ using LayoutA = typename Gemm::LayoutA; -+ using LayoutB = typename Gemm::LayoutB; -+ using LayoutC = typename Gemm::LayoutC; -+ -+ using MatrixCoord = typename LayoutC::TensorCoord; -+ -+ // -+ // Data members -+ // -+ -+ Options & options; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint32_t seed; -+ -+ cutlass::DeviceAllocation problem_sizes_device; -+ -+ std::vector offset_A; -+ std::vector offset_B; -+ std::vector offset_C; -+ std::vector offset_D; -+ -+ std::vector lda_host; -+ std::vector ldb_host; -+ std::vector ldc_host; -+ std::vector ldd_host; -+ -+ cutlass::DeviceAllocation lda; -+ cutlass::DeviceAllocation ldb; -+ cutlass::DeviceAllocation ldc; -+ cutlass::DeviceAllocation ldd; -+ -+ cutlass::DeviceAllocation block_A; -+ cutlass::DeviceAllocation block_B; -+ cutlass::DeviceAllocation block_C; -+ cutlass::DeviceAllocation block_D; -+ -+ cutlass::DeviceAllocation ptr_A; -+ cutlass::DeviceAllocation ptr_B; -+ cutlass::DeviceAllocation ptr_C; -+ cutlass::DeviceAllocation ptr_D; -+ -+ BaseTestbed( -+ Options &options_, -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint32_t seed_ = 3080 -+ ): -+ options(options_), init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } -+ -+ int problem_count() const { -+ return options.problem_count; -+ } -+ -+ /// Helper to initialize a tensor view -+ template -+ void initialize_tensor( -+ Element *ptr, -+ size_t capacity, -+ cutlass::Distribution::Kind dist_kind, -+ uint32_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ Element scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ if (cutlass::sizeof_bits::value <= 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } -+ else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::device::BlockFillRandomUniform( -+ ptr, capacity, seed, scope_max, scope_min, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::device::BlockFillRandomGaussian( -+ ptr, capacity, seed, Element(), Element(0.5f)); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ // Fill with increasing elements -+ cutlass::reference::device::BlockFillSequential( -+ ptr, capacity, Element(1), Element()); -+ } -+ else { -+ -+ // Fill with all 1s -+ cutlass::reference::device::BlockFillSequential( -+ ptr, capacity, Element(), Element(1)); -+ } -+ } -+ -+ /// Allocates device-side data -+ void allocate() { -+ int64_t total_elements_A = 0; -+ int64_t total_elements_B = 0; -+ int64_t total_elements_C = 0; -+ int64_t total_elements_D = 0; -+ -+ lda_host.resize(problem_count()); -+ ldb_host.resize(problem_count()); -+ ldc_host.resize(problem_count()); -+ ldd_host.resize(problem_count()); -+ -+ for (int32_t i = 0; i < problem_count(); ++i) { -+ -+ auto problem = options.problem_sizes.at(i); -+ -+ lda_host.at(i) = LayoutA::packed({problem.m(), problem.k()}).stride(0); -+ ldb_host.at(i) = LayoutB::packed({problem.k(), problem.n()}).stride(0); -+ ldc_host.at(i) = LayoutC::packed({problem.m(), problem.n()}).stride(0); -+ ldd_host.at(i) = LayoutC::packed({problem.m(), problem.n()}).stride(0); -+ -+ offset_A.push_back(total_elements_A); -+ offset_B.push_back(total_elements_B); -+ offset_C.push_back(total_elements_C); -+ offset_D.push_back(total_elements_D); -+ -+ int64_t elements_A = problem.m() * problem.k(); -+ int64_t elements_B = problem.k() * problem.n(); -+ int64_t elements_C = problem.m() * problem.n(); -+ int64_t elements_D = problem.m() * problem.n(); -+ -+ total_elements_A += elements_A; -+ total_elements_B += elements_B; -+ total_elements_C += elements_C; -+ total_elements_D += elements_D; -+ } -+ -+ lda.reset(problem_count()); -+ ldb.reset(problem_count()); -+ ldc.reset(problem_count()); -+ ldd.reset(problem_count()); -+ -+ block_A.reset(total_elements_A); -+ block_B.reset(total_elements_B); -+ block_C.reset(total_elements_C); -+ block_D.reset(total_elements_D); -+ } -+ -+ /// Initializes device-side data -+ void initialize() { -+ problem_sizes_device.reset(problem_count()); -+ problem_sizes_device.copy_from_host(options.problem_sizes.data()); -+ -+ lda.copy_from_host(lda_host.data()); -+ ldb.copy_from_host(ldb_host.data()); -+ ldc.copy_from_host(ldc_host.data()); -+ ldd.copy_from_host(ldd_host.data()); -+ -+ // -+ // Assign pointers -+ // -+ -+ std::vector ptr_A_host(problem_count()); -+ std::vector ptr_B_host(problem_count()); -+ std::vector ptr_C_host(problem_count()); -+ std::vector ptr_D_host(problem_count()); -+ -+ for (int32_t i = 0; i < problem_count(); ++i) { -+ ptr_A_host.at(i) = block_A.get() + offset_A.at(i); -+ ptr_B_host.at(i) = block_B.get() + offset_B.at(i); -+ ptr_C_host.at(i) = block_C.get() + offset_C.at(i); -+ ptr_D_host.at(i) = block_D.get() + offset_D.at(i); -+ } -+ -+ ptr_A.reset(problem_count()); -+ ptr_A.copy_from_host(ptr_A_host.data()); -+ -+ ptr_B.reset(problem_count()); -+ ptr_B.copy_from_host(ptr_B_host.data()); -+ -+ ptr_C.reset(problem_count()); -+ ptr_C.copy_from_host(ptr_C_host.data()); -+ -+ ptr_D.reset(problem_count()); -+ ptr_D.copy_from_host(ptr_D_host.data()); -+ -+ // -+ // Initialize the problems of the workspace -+ // -+ -+ initialize_tensor(block_A.get(), block_A.size(), init_A, seed * 2021); -+ initialize_tensor(block_B.get(), block_B.size(), init_B, seed * 2022); -+ initialize_tensor(block_C.get(), block_C.size(), init_C, seed * 2023); -+ -+ cutlass::reference::device::BlockFillSequential( -+ block_D.get(), block_D.size(), ElementC(), ElementC()); -+ } -+ -+ /// Verifies the result is a GEMM -+ bool verify() { -+ -+ bool passed = true; -+ -+ for (int32_t i = 0; i < problem_count(); ++i) { -+ cutlass::gemm::GemmCoord problem = options.problem_sizes.at(i); -+ -+ LayoutA layout_A(lda_host.at(i)); -+ LayoutB layout_B(ldb_host.at(i)); -+ LayoutC layout_C(ldc_host.at(i)); -+ LayoutC layout_D(ldd_host.at(i)); -+ -+ MatrixCoord extent_A{problem.m(), problem.k()}; -+ MatrixCoord extent_B{problem.k(), problem.n()}; -+ MatrixCoord extent_C{problem.m(), problem.n()}; -+ -+ cutlass::TensorView view_A(block_A.get() + offset_A.at(i), layout_A, extent_A); -+ cutlass::TensorView view_B(block_B.get() + offset_B.at(i), layout_B, extent_B); -+ cutlass::TensorView view_C(block_C.get() + offset_C.at(i), layout_C, extent_C); -+ -+ cutlass::DeviceAllocation block_Ref(layout_D.capacity(extent_C)); -+ cutlass::TensorView view_Ref_device(block_Ref.get(), layout_D, extent_C); -+ -+ // Reference GEMM -+ cutlass::reference::device::GemmComplex< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, ElementAccumulator -+ >( -+ problem, -+ options.alpha, -+ view_A, -+ Gemm::kTransformA, -+ view_B, -+ Gemm::kTransformB, -+ options.beta, -+ view_C, -+ view_Ref_device, -+ ElementAccumulator(0) -+ ); -+ -+ // Copy to host memory -+ std::vector matrix_D(layout_D.capacity(extent_C)); -+ std::vector matrix_Ref(layout_D.capacity(extent_C)); -+ -+ cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + offset_D.at(i), matrix_D.size()); -+ cutlass::device_memory::copy_to_host(matrix_Ref.data(), block_Ref.get(), matrix_D.size()); -+ -+ cutlass::TensorView view_D( matrix_D.data(), layout_D, extent_C); -+ cutlass::TensorView view_Ref(matrix_Ref.data(), layout_D, extent_C); -+ -+ // Reference check -+ passed = cutlass::reference::host::TensorEquals(view_D, view_Ref); -+ -+ if (!passed) { -+ std::cerr << "\n***\nError - problem " << i << " failed the QA check\n***\n" << std::endl; -+ return passed; -+ } -+ } -+ -+ return passed; -+ } -+ -+}; -+ -+template -+class TestbedBatched : BaseTestbed { -+public: -+ TestbedBatched( -+ Options &options_, -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint32_t seed_ = 3080 -+ ): BaseTestbed(options_, init_A_, init_B_, init_C_, seed_) {} -+ -+ void print_problem_sizes() { -+ std::cout << std::endl; -+ size_t bin_idx = 0; -+ size_t problem_count_check = 0; -+ std::cout << "Conventionally executed as " << this->options.problem_bins.size() << " batched GEMMs:\n"; -+ for (auto const & bin : this->options.problem_bins) { -+ -+ std::cout << " [" << bin_idx << "]: " -+ << bin.first.m() << "-by-" << bin.first.n() << "-by-" << bin.first.k() -+ << ", batch count: " << bin.second.size() << "\n"; -+ -+ ++bin_idx; -+ problem_count_check += bin.second.size(); -+ } -+ -+ if (problem_count_check != this->problem_count()) { -+ std::cout << "\n***\nERROR in BINNING LOGIC!\n***\n" << std::endl; -+ } -+ -+ std::cout << std::endl; -+ } -+ -+ /// Executes a batched kernel and measures runtime -+ Result profile() { -+ std::cout << "Batched GEMM:\n" -+ << "====================================================" << std::endl; -+ -+ Result result; -+ result.passed = false; -+ -+ // Initialize the problem -+ this->allocate(); -+ this->initialize(); -+ -+ if (this->options.verbose) { -+ print_problem_sizes(); -+ } -+ -+ // -+ // Prepare batched GEMM environment -+ // -+ -+ int32_t effective_streams = (this->options.cuda_streams ? this->options.cuda_streams : 1); -+ -+ // Array of leading dimensions used by batched GEMM calls -+ std::vector bin_problem_sizes; -+ std::vector bin_count; -+ std::vector bin_ldm_A; -+ std::vector bin_ldm_B; -+ std::vector bin_ldm_C; -+ std::vector bin_start; -+ -+ std::vector ptr_A_batched_host; -+ std::vector ptr_B_batched_host; -+ std::vector ptr_C_batched_host; -+ -+ for (auto const & bin : this->options.problem_bins) { -+ int first_idx = bin.second.front(); -+ -+ bin_problem_sizes.push_back(this->options.problem_sizes.at(first_idx)); -+ bin_count.push_back(int32_t(bin.second.size())); -+ -+ bin_ldm_A.push_back(static_cast(this->lda_host.at(first_idx))); -+ bin_ldm_B.push_back(static_cast(this->ldb_host.at(first_idx))); -+ bin_ldm_C.push_back(static_cast(this->ldc_host.at(first_idx))); -+ -+ if (ptr_A_batched_host.size() % 2) { -+ ptr_A_batched_host.push_back(nullptr); -+ ptr_B_batched_host.push_back(nullptr); -+ ptr_C_batched_host.push_back(nullptr); -+ } -+ -+ bin_start.push_back(int32_t(ptr_A_batched_host.size())); -+ -+ for (int idx : bin.second) { -+ -+ if (bin_problem_sizes.back() != this->options.problem_sizes.at(idx)) { -+ std::cerr << "Error - failed to group problems.\n"; -+ return result; -+ } -+ -+ if (bin_ldm_A.back() != this->lda_host.at(idx)) { -+ std::cerr << "Error - failed to group problems.\n"; -+ return result; -+ } -+ -+ if (bin_ldm_B.back() != this->ldb_host.at(idx)) { -+ std::cerr << "Error - failed to group problems.\n"; -+ return result; -+ } -+ -+ if (bin_ldm_C.back() != this->ldc_host.at(idx)) { -+ std::cerr << "Error - failed to group problems.\n"; -+ return result; -+ } -+ -+ ptr_A_batched_host.push_back(this->block_A.get() + this->offset_A.at(idx)); -+ ptr_B_batched_host.push_back(this->block_B.get() + this->offset_B.at(idx)); -+ ptr_C_batched_host.push_back(this->block_D.get() + this->offset_C.at(idx)); -+ } -+ } -+ -+ // Array of GMEM pointers used by batched array GEMM calls -+ cutlass::DeviceAllocation ptr_A_batched; -+ cutlass::DeviceAllocation ptr_B_batched; -+ cutlass::DeviceAllocation ptr_C_batched; -+ -+ ptr_A_batched.reset(ptr_A_batched_host.size()); -+ ptr_B_batched.reset(ptr_A_batched_host.size()); -+ ptr_C_batched.reset(ptr_A_batched_host.size()); -+ -+ ptr_A_batched.copy_from_host(ptr_A_batched_host.data()); -+ ptr_B_batched.copy_from_host(ptr_B_batched_host.data()); -+ ptr_C_batched.copy_from_host(ptr_C_batched_host.data()); -+ -+ // -+ // Create CUDA streams to maximize concurrency of batched-array GEMM kernels -+ // -+ std::vector cuda_streams; -+ -+ // -+ // Warmup run -+ // -+ -+ -+ if (this->options.cuda_streams) { -+ for (int i = 0; i < this->options.cuda_streams; ++i) { -+ cudaStream_t stream; -+ -+ result.error = cudaStreamCreate(&stream); -+ if (result.error != cudaSuccess) { -+ std::cerr << "Failed to create CUDA stream." << std::endl; -+ return result; -+ } -+ cuda_streams.push_back(stream); -+ -+ } -+ } -+ else { -+ cuda_streams.push_back(nullptr); -+ -+ } -+ -+ // Use 'D' for the in/out workspace -+ this->block_D.copy_from_device(this->block_C.get()); -+ -+ for (int bin_idx = 0; bin_idx < int32_t(bin_problem_sizes.size()); ++bin_idx) { -+ -+ cutlass::gemm::GemmCoord const & problem = bin_problem_sizes[bin_idx]; -+ int32_t batch_count = bin_count[bin_idx]; -+ int32_t bin_start_idx = bin_start[bin_idx]; -+ int32_t lda = bin_ldm_A[bin_idx]; -+ int32_t ldb = bin_ldm_B[bin_idx]; -+ int32_t ldc = bin_ldm_C[bin_idx]; -+ -+ void const ** ptr_A_array = ptr_A_batched.get() + bin_start[bin_idx]; -+ void const ** ptr_B_array = ptr_B_batched.get() + bin_start[bin_idx]; -+ void ** ptr_C_array = ptr_C_batched.get() + bin_start[bin_idx]; -+ -+ // -+ // Initialize the CUTLASS GEMM operator -+ // -+ -+ // Configure the GEMM arguments -+ typename Gemm::EpilogueOutputOp::Params epilogue_op(this->options.alpha, this->options.beta); -+ -+ typename Gemm::Arguments arguments{ -+ cutlass::gemm::GemmUniversalMode::kArray, -+ problem, -+ batch_count, -+ epilogue_op, -+ (void const *)ptr_A_array, -+ (void const *)ptr_B_array, -+ (void const *)ptr_C_array, -+ (void *)ptr_C_array, -+ int64_t(), -+ int64_t(), -+ int64_t(), -+ int64_t(), -+ int64_t(lda), -+ int64_t(ldb), -+ int64_t(ldc), -+ int64_t(ldc) -+ }; -+ -+ Gemm gemm_op; -+ -+ cutlass::Status status = gemm_op.initialize(arguments); -+ -+ if (status != cutlass::Status::kSuccess) { -+ std::cerr << "CUTLASS error on line " << __LINE__ << std::endl; -+ return result; -+ } -+ -+ status = gemm_op(); -+ -+ if (status != cutlass::Status::kSuccess) { -+ std::cerr << "CUTLASS error on line " << __LINE__ << std::endl; -+ return result; -+ } -+ -+ } -+ -+ // -+ // Wait for completion -+ // -+ -+ result.error = cudaDeviceSynchronize(); -+ -+ if (result.error != cudaSuccess) { -+ std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); -+ return result; -+ } -+ -+ // -+ // Construct events -+ // -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ } -+ -+ // -+ // Wait for completion -+ // -+ -+ result.error = cudaDeviceSynchronize(); -+ -+ if (result.error != cudaSuccess) { -+ std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); -+ return result; -+ } -+ -+ // Record an event at the start of a series of GEMM operations -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // -+ // Run profiling loop -+ // -+ -+ int last_stream_idx = 0; -+ -+ for (int iter = 0; iter < this->options.iterations; ++iter) { -+ -+ for (int bin_idx = 0; bin_idx < int32_t(bin_problem_sizes.size()); ++bin_idx) { -+ -+ cutlass::gemm::GemmCoord const & problem = bin_problem_sizes[bin_idx]; -+ int32_t batch_count = bin_count[bin_idx]; -+ int32_t bin_start_idx = bin_start[bin_idx]; -+ int32_t lda = bin_ldm_A[bin_idx]; -+ int32_t ldb = bin_ldm_B[bin_idx]; -+ int32_t ldc = bin_ldm_C[bin_idx]; -+ -+ void const ** ptr_A_array = ptr_A_batched.get() + bin_start[bin_idx]; -+ void const ** ptr_B_array = ptr_B_batched.get() + bin_start[bin_idx]; -+ void ** ptr_C_array = ptr_C_batched.get() + bin_start[bin_idx]; -+ -+ last_stream_idx = (bin_idx % effective_streams); -+ -+ // -+ // Initialize the CUTLASS GEMM operator -+ // -+ -+ // Configure the GEMM arguments -+ typename Gemm::EpilogueOutputOp::Params epilogue_op(this->options.alpha, this->options.beta); -+ -+ typename Gemm::Arguments arguments{ -+ cutlass::gemm::GemmUniversalMode::kArray, -+ problem, -+ batch_count, -+ epilogue_op, -+ (void const *)ptr_A_array, -+ (void const *)ptr_B_array, -+ (void const *)ptr_C_array, -+ (void *)ptr_C_array, -+ int64_t(), -+ int64_t(), -+ int64_t(), -+ int64_t(), -+ int64_t(lda), -+ int64_t(ldb), -+ int64_t(ldc), -+ int64_t(ldc) -+ }; -+ -+ Gemm gemm_op; -+ -+ cutlass::Status status = gemm_op.initialize(arguments); -+ -+ if (status != cutlass::Status::kSuccess) { -+ std::cerr << "CUTLASS error on line " << __LINE__ << std::endl; -+ return result; -+ } -+ -+ status = gemm_op(cuda_streams[last_stream_idx]); -+ -+ if (status != cutlass::Status::kSuccess) { -+ std::cerr << "CUTLASS error on line " << __LINE__ << std::endl; -+ return result; -+ } -+ -+ } -+ } -+ -+ // -+ // Stop profiling loop -+ // -+ -+ // Record an event when the GEMM operations have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // -+ // Wait for work to be completed -+ // -+ -+ result.error = cudaDeviceSynchronize(); -+ -+ if (result.error != cudaSuccess) { -+ std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Compute average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(this->options.iterations); -+ result.gflops = this->options.gflops(result.runtime_ms / 1000.0); -+ -+ // -+ // Cleanup -+ // -+ -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ for (auto stream : cuda_streams) { -+ if (stream) { -+ (void)cudaStreamDestroy(stream); -+ } -+ } -+ -+ std::cout << " " << this->options.problem_bins.size() << " batched GEMMs launched" << std::endl; -+ std::cout << std::endl; -+ std::cout << " " << "Batched Runtime: " << result.runtime_ms << " ms" << std::endl; -+ std::cout << " " << "Batched GFLOPs: " << result.gflops << std::endl; -+ -+ std::string provider = "CUTLASS"; -+ -+ if (this->options.output_file.good()) { -+ this->options.output_file << this->options.output_tag << "," << provider << ",batched," -+ << this->options.problem_count << "," << result.runtime_ms << "," << result.gflops << std::endl; -+ } -+ -+ result.passed = true; -+ return result; -+ } -+}; -+ -+template -+class TestbedGrouped : BaseTestbed { -+public: -+ TestbedGrouped( -+ Options &options_, -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint32_t seed_ = 3080 -+ ): BaseTestbed(options_, init_A_, init_B_, init_C_, seed_) {} -+ -+ // Redefine GEMM with different GroupScheduleMode_ -+ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< -+ typename Gemm_::ElementA, -+ typename Gemm_::LayoutA, -+ Gemm_::kTransformA, -+ Gemm_::kAlignmentA, -+ typename Gemm_::ElementB, -+ typename Gemm_::LayoutB, -+ Gemm_::kTransformB, -+ Gemm_::kAlignmentB, -+ typename Gemm_::ElementC, -+ typename Gemm_::LayoutC, -+ typename Gemm_::ElementAccumulator, -+ typename Gemm_::OperatorClass, -+ typename Gemm_::ArchTag, -+ typename Gemm_::ThreadblockShape, -+ typename Gemm_::WarpShape, -+ typename Gemm_::InstructionShape, -+ typename Gemm_::EpilogueOutputOp, -+ typename Gemm_::ThreadblockSwizzle, -+ Gemm_::kStages, -+ GroupScheduleMode_>::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmGrouped; -+ -+ /// Verbose printing of problem sizes -+ void print_problem_sizes() { -+ std::cout << std::endl; -+ -+ // Print groups -+ std::cout << this->problem_count() << " groups:\n"; -+ -+ int32_t idx = 0; -+ int64_t total_tiles = 0; -+ -+ for (auto const & problem : this->options.problem_sizes) { -+ int tiles = Gemm::problem_tile_count(problem); -+ total_tiles += tiles; -+ -+ std::cout << " [" << idx << "]: " -+ << problem.m() << "-by-" << problem.n() << "-by-" << problem.k() -+ << " (" << tiles << " threadblock tiles)" << "\n"; -+ -+ ++idx; -+ } -+ std::cout << std::endl; -+ } -+ -+ /// Sort problems in descending order of problem-K dimension -+ void sort_problems() { -+ Gemm::sort_problems(this->options.problem_count, -+ this->options.problem_sizes.data(), -+ this->lda_host.data(), -+ this->ldb_host.data(), -+ this->ldc_host.data(), -+ this->ldd_host.data(), -+ this->offset_A.data(), -+ this->offset_B.data(), -+ this->offset_C.data(), -+ this->offset_D.data()); -+ } -+ -+ /// Executes a grouped kernel and measures runtime -+ Result profile() { -+ std::string sched_mode = this->options.scheduler_mode_to_str.find(GroupScheduleMode_)->second; -+ -+ std::cout << std::endl; -+ std::cout << "Grouped GEMM (CUTLASS) with mode " << sched_mode << ":\n" -+ << "==================================================== *********" << std::endl; -+ -+ Result result; -+ -+ int threadblock_count = Gemm::sufficient(this->options.problem_sizes.data(), this->options.problem_count); -+ -+ // Early exit -+ if (!threadblock_count) { -+ std::cout << "Active CUDA device lacks hardware resources to run CUTLASS Grouped GEMM kernel." << std::endl; -+ return result; -+ } -+ -+ result.passed = false; -+ -+ // Initialize the problem -+ this->allocate(); -+ if (this->options.sort_problems) { -+ sort_problems(); -+ } -+ this->initialize(); -+ -+ if (this->options.verbose) { -+ print_problem_sizes(); -+ } -+ -+ // Configure the GEMM arguments -+ typename Gemm::EpilogueOutputOp::Params epilogue_op(this->options.alpha, this->options.beta); -+ -+ // Configure GEMM arguments -+ typename Gemm::Arguments args( -+ this->problem_sizes_device.get(), -+ this->problem_count(), -+ threadblock_count, -+ epilogue_op, -+ this->ptr_A.get(), -+ this->ptr_B.get(), -+ this->ptr_C.get(), -+ this->ptr_D.get(), -+ this->lda.get(), -+ this->ldb.get(), -+ this->ldc.get(), -+ this->ldd.get(), -+ this->options.problem_sizes.data() -+ ); -+ -+ // Initialize the GEMM object -+ Gemm gemm; -+ -+ size_t workspace_size = gemm.get_workspace_size(args); -+ cutlass::DeviceAllocation workspace(workspace_size); -+ -+ result.status = gemm.initialize(args, workspace.get()); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to initialize CUTLASS Grouped GEMM kernel." << std::endl; -+ return result; -+ } -+ -+ // Run the grouped GEMM object -+ result.status = gemm.run(); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to run CUTLASS Grouped GEMM kernel." << std::endl; -+ return result; -+ } -+ -+ // Wait for completion -+ result.error = cudaDeviceSynchronize(); -+ -+ if (result.error != cudaSuccess) { -+ std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); -+ return result; -+ } -+ -+ // -+ // Verify correctness -+ // -+ result.passed = true; -+ -+ if (this->options.reference_check) { -+ result.passed = this->verify(); -+ } -+ -+ // -+ // Warm-up run of the grouped GEMM object -+ // -+ result.status = gemm.run(); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to run CUTLASS Grouped GEMM kernel." << std::endl; -+ return result; -+ } -+ -+ // -+ // Construct events -+ // -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ } -+ -+ // Record an event at the start of a series of GEMM operations -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // -+ // Run profiling loop -+ // -+ -+ for (int iter = 0; iter < this->options.iterations; ++iter) { -+ gemm(); -+ } -+ -+ // -+ // Stop profiling loop -+ // -+ -+ // Record an event when the GEMM operations have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Compute average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(this->options.iterations); -+ result.gflops = this->options.gflops(result.runtime_ms / 1000.0); -+ -+ // -+ // Cleanup -+ // -+ -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ // Optionally profile initialization -+ if (this->options.profile_initialization) { -+ // Warm up -+ gemm.initialize(args, workspace.get()); -+ -+ auto start_time = std::chrono::high_resolution_clock::now(); -+ for (int32_t i = 0; i < this->options.iterations; ++i) { -+ gemm.initialize(args, workspace.get()); -+ } -+ auto end_time = std::chrono::high_resolution_clock::now(); -+ -+ std::chrono::duration duration = end_time - start_time; -+ duration /= double(this->options.iterations); -+ result.initialization_time_ms = duration.count(); -+ } -+ -+ int64_t total_tiles = Gemm::group_tile_count(args); -+ std::cout << " " << total_tiles << " total threadblock tiles." << std::endl; -+ -+ std::cout << std::endl; -+ std::cout << " " << "Grouped Runtime: " << result.runtime_ms << " ms" << std::endl; -+ std::cout << " " << "Grouped GFLOPs: " << result.gflops << std::endl; -+ if (this->options.profile_initialization) { -+ std::cout << " " << "Init Runtime: " << result.initialization_time_ms << " ms" << std::endl; -+ } -+ -+ if (this->options.output_file.good()) { -+ this->options.output_file << this->options.output_tag << ",CUTLASS,grouped-" << sched_mode << "," -+ << this->options.problem_count << "," << result.runtime_ms << "," << result.gflops << std::endl; -+ } -+ -+ std::cout << "\nPassed\n"; -+ -+ return result; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return -1; -+ } -+ -+ if (__CUDACC_VER_MAJOR__ < 11 || props.major < 8) { -+ -+ // -+ // This example requires an NVIDIA Ampere-architecture GPU. -+ // -+ -+ std::cout -+ << "CUTLASS's Grouped GEMM example requires a GPU of NVIDIA's Ampere Architecture or " -+ << "later (compute capability 80 or greater).\n"; -+ -+ return 0; -+ } -+ -+ // -+ // Parse options -+ // -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ if (options.error) { -+ std::cerr << "Aborting execution." << std::endl; -+ return -1; -+ } -+ -+ // -+ // Define the Grouped and Batched GEMM types -+ // -+ -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ // Gemm operator cutlass_tensorop_f16_s16816gemm_f16_128x128_32x4_nt_align8 -+ using GemmBatched = cutlass::gemm::device::GemmUniversal< -+ cutlass::half_t, LayoutA, -+ cutlass::half_t, LayoutB, -+ ElementOutput, LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 4 -+ >; -+ -+ // Define a grouped GEMM kernel with all template parameters set except -+ // for scheduling mode. This will be used as the template for all scheduling -+ // modes executed. -+ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< -+ cutlass::half_t, -+ LayoutA, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ LayoutB, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ ElementOutput, LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ // NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels. -+ // This parameter is passed in at present to match the APIs of other kernels. The parameter -+ // is unused within the kernel. -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 4>::GemmKernel; -+ -+ using GemmGrouped = cutlass::gemm::device::GemmGrouped; -+ -+ // -+ // Profile it -+ // -+ -+ TestbedBatched testbed_batched(options); -+ Result result = testbed_batched.profile(); -+ if (result.error) { -+ return 1; -+ } -+ -+ using GroupScheduleMode = cutlass::gemm::kernel::GroupScheduleMode; -+ for (GroupScheduleMode mode : options.scheduler_modes) { -+ Result result; -+ switch (mode) { -+ case GroupScheduleMode::kDeviceOnly: -+ { -+ TestbedGrouped runner(options); -+ result = runner.profile(); -+ break; -+ } -+ case GroupScheduleMode::kHostPrecompute: -+ { -+ TestbedGrouped runner(options); -+ result = runner.profile(); -+ break; -+ } -+ } -+ -+ if (result.error != cudaSuccess) { -+ return 1; -+ } -+ -+ // Override verbose flag to avoid printing duplicate information for each scheduling mode -+ options.verbose = false; -+ } -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/25_ampere_fprop_mainloop_fusion/ampere_3d_fprop_mainloop_fusion.cu b/3rdparty/cutlass/examples/25_ampere_fprop_mainloop_fusion/ampere_3d_fprop_mainloop_fusion.cu -new file mode 100644 -index 0000000..5964028 ---- /dev/null -+++ b/3rdparty/cutlass/examples/25_ampere_fprop_mainloop_fusion/ampere_3d_fprop_mainloop_fusion.cu -@@ -0,0 +1,776 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+ -+This example shows how to fuse per channel scale+bias+relu of the activations -+into the 3D fprop mainloop. -+ -+Compared with original 3D fprop kernel, this example has two more vectors, one for -+the scale and one for the bias. The length of the vectors is the same as the -+activation channel number. This kernel loads the vectors when the associated -+activation channels are loaded in the mainloop. Between reading the -+activations and scale/bias data from the shared memory and calling tensor core -+instructions, scale+bias+relu is computed in the register file. -+ -+This example is customized for Ampere 16816 fp16 tensor core instruction. -+Changing to different data types or different tensor core instruction require -+source code changing. See -+include/cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h for more -+technical details. -+ -+This example is modified based on 25_ampere_fprop_mainloop_fusion. The command -+line is the same. -+*/ -+ -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/conv/kernel/default_conv3d_fprop_fusion.h" -+#include "cutlass/conv/device/implicit_gemm_convolution_fusion.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/device/convolution.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "helper.h" -+ -+// The code section below describes datatype for input, output tensors and computation between -+// elements -+using ElementAccumulator = float; // Data type of accumulator -+using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta) -+using ElementInputA = cutlass::half_t; // Data type of elements in input tensor -+using ElementInputB = cutlass::half_t; // Data type of elements in input tensor -+using ElementInputScaleBias = cutlass::half_t; // Data type of elements in input sclae and bias vectors -+using ElementOutput = float; // Data type of elements in output tensor -+ -+using LayoutInputA = cutlass::layout::TensorNDHWC; -+using LayoutInputB = cutlass::layout::TensorNDHWC; -+using LayoutInputScaleBias = cutlass::layout::RowMajor; -+using LayoutOutput = cutlass::layout::TensorNDHWC; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm80; -+ -+// This code section describes the tile size a thread block will compute -+using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock tile shape -+ -+// This code section describes tile size a warp will compute -+using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; // Warp tile shape -+ -+// This code section describes the size of MMA op -+using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // TensorCore instruction shape -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 4; -+ -+// This code section describe iterator algorithm selected is Analytic or Optimized -+static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized; -+ -+// This code section describes the epilogue part of the kernel, we use default value -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // Data type of output matrix. -+ 128 / cutlass::sizeof_bits::value, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. -+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue>; // Data type for alpha/beta in linear combination -+ -+using Conv3dFpropFusionKernel = typename cutlass::conv::kernel::DefaultConv3dFpropFusion< -+ ElementInputA, LayoutInputA, -+ ElementInputB, LayoutInputB, -+ ElementInputScaleBias, LayoutInputScaleBias, -+ ElementOutput, LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ IteratorAlgorithm -+>::Kernel; -+ -+using ImplicitGemmFusion = cutlass::conv::device::ImplicitGemmConvolutionFusion; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ cutlass::Tensor5DCoord input_size; -+ cutlass::Tensor5DCoord filter_size; -+ cutlass::Coord<3> padding; -+ cutlass::Coord<3> conv_stride; -+ cutlass::Coord<3> dilation; -+ bool reference_check; -+ bool measure_performance; -+ int iterations; -+ bool save_workspace; -+ ElementComputeEpilogue alpha; -+ ElementComputeEpilogue beta; -+ bool benchmark; -+ std::string tag; -+ -+ Options(): -+ help(false), -+ input_size(1, 32, 32, 32, 32), -+ filter_size(32, 3, 3, 3, 32), -+ padding(cutlass::make_Coord(1, 1, 1)), -+ conv_stride(cutlass::make_Coord(1, 1, 1)), -+ dilation(cutlass::make_Coord(1, 1, 1)), -+ reference_check(true), -+ measure_performance(false), -+ iterations(20), -+ save_workspace(false), -+ alpha(1), -+ beta(0), -+ benchmark(false) { } -+ -+ // Verify the problem size is compatible with the CUTLASS Convolution implementation. -+ bool valid() { -+ -+ // -+ // CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently, -+ // all pointers, strides, and tensor extents must be divisible by 8 elements. -+ // -+ int const kAlignment = 8; -+ -+ if ((input_size.c() % kAlignment) || -+ (filter_size.n() % kAlignment)) { -+ -+ // misaligned tensors -+ return false; -+ } -+ -+ // Invalid padding -+ if ((padding[0] != filter_size.d() / 2) || -+ (padding[1] != filter_size.h() / 2) || -+ (padding[2] != filter_size.w() / 2)) { -+ -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Updates input and filter sizes -+ void update( -+ cutlass::Tensor5DCoord input_size, -+ cutlass::Tensor5DCoord filter_size, -+ cutlass::Coord<3> stride) { -+ -+ this->input_size = input_size; -+ this->filter_size = filter_size; -+ conv_stride = stride; -+ -+ padding[0] = filter_size.d() / 2; -+ padding[1] = filter_size.h() / 2; -+ padding[2] = filter_size.w() / 2; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("ref-check")) { -+ reference_check = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("perf-check")) { -+ measure_performance = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("save-workspace")) { -+ save_workspace = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("benchmark")) { -+ benchmark = true; -+ } -+ -+ cmd.get_cmd_line_argument("n", input_size.n()); -+ cmd.get_cmd_line_argument("d", input_size.d()); -+ cmd.get_cmd_line_argument("h", input_size.h()); -+ cmd.get_cmd_line_argument("w", input_size.w()); -+ cmd.get_cmd_line_argument("c", input_size.c()); -+ -+ cmd.get_cmd_line_argument("k", filter_size.n()); -+ cmd.get_cmd_line_argument("t", filter_size.d()); -+ cmd.get_cmd_line_argument("r", filter_size.h()); -+ cmd.get_cmd_line_argument("s", filter_size.w()); -+ filter_size.c() = input_size.c(); -+ -+ cmd.get_cmd_line_argument("alpha", alpha); -+ cmd.get_cmd_line_argument("beta", beta); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ cmd.get_cmd_line_argument("tag", tag); -+ -+ if (filter_size.d() == 3 && filter_size.h() == 3 && filter_size.w() == 3) { -+ padding = cutlass::make_Coord(1, 1, 1); -+ } -+ else { -+ filter_size.d() = 1; -+ filter_size.h() = 1; -+ filter_size.w() = 1; -+ padding = cutlass::make_Coord(0, 0, 0); -+ } -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "25_ampere_3d_fprop_mainloop_fusion example\n\n" -+ << " This example fuses scale+bias+relu of the activations into Ampere's\n" -+ << " Tensor Core operators on F16 data types to compute\n" -+ << " forward convolution on tensors of layout NDHWC.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --n Input tensor extent N\n" -+ << " --d Input tensor extent D\n" -+ << " --h Input tensor extent H\n" -+ << " --w Input tensor extent W\n" -+ << " --c Input tensor extent C\n" -+ << " --k Filter extent K\n" -+ << " --t Filter extent T\n" -+ << " --r Filter extent R\n" -+ << " --s Filter extent S\n\n" -+ << " --alpha Epilogue scalar alpha\n" -+ << " --beta Epilogue scalar beta\n\n" -+ << " --ref-check If set (true), reference check on the host is computed\n" -+ << " --perf-check If set (true), performance is measured.\n" -+ << " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n" -+ << " --iterations Number of profiling iterations to perform.\n" -+ << " --save-workspace If set, workspace is written to a text file.\n" -+ << " --tag String to replicate across the first column in the results table\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./25_ampere_3d_fprop_mainloop_fusion --n=32 --d=96 --h=96 --w=96 --c=64 --k=64 --t=1 --r=1 --s=1\n\n" -+ << "$ ./25_ampere_3d_fprop_mainloop_fusion --n=1 --d=224 --h=224 --w=224 --c=32 --k=32 --t=3 --r=3 --s=3 --ref-check\n\n" -+ << "$ ./25_ampere_3d_fprop_mainloop_fusion --n=19 --d=94 --h=96 --w=96 --c=128 --k=128 --t=1 --r=1 --s=1\n\n"; -+ -+ return out; -+ } -+ -+ /// Computes the output tensor size (NPQK) -+ cutlass::Tensor5DCoord output_size() const { -+ return cutlass::Tensor5DCoord( -+ input_size.n(), -+ (input_size.d() + padding[0] + padding[0] - filter_size.d()) / conv_stride[0] + 1, -+ (input_size.h() + padding[1] + padding[1] - filter_size.h()) / conv_stride[1] + 1, -+ (input_size.w() + padding[2] + padding[2] - filter_size.w()) / conv_stride[2] + 1, -+ filter_size.n()); -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of multiply-adds = NPQK * CRS -+ int64_t fmas = output_size().product() * int64_t(filter_size.d() * filter_size.h() * filter_size.w() * filter_size.c()); -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct Result { -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cutlass::Status reference_check; -+ cudaError_t error; -+ -+ Result(): -+ runtime_ms(0), -+ gflops(0), -+ status(cutlass::Status::kSuccess), -+ reference_check(cutlass::Status::kInvalid), -+ error(cudaSuccess) { } -+ -+ static std::ostream & print_header(std::ostream &out, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << "Name,"; -+ } -+ -+ out << "Layer,N,D,H,W,C,K,T,R,S,Stride_D,Stride_H,Stride_W,Runtime,GFLOPs"; -+ -+ return out; -+ } -+ -+ std::ostream & print(std::ostream &out, int idx, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << options.tag << ","; -+ } -+ -+ out -+ << "conv_" << idx << "," -+ << options.input_size.n() << "," -+ << options.input_size.d() << "," -+ << options.input_size.h() << "," -+ << options.input_size.w() << "," -+ << options.input_size.c() << "," -+ << options.filter_size.n() << "," -+ << options.filter_size.d() << "," -+ << options.filter_size.h() << "," -+ << options.filter_size.w() << "," -+ << options.conv_stride[0] << "," -+ << options.conv_stride[1] << "," -+ << options.conv_stride[2] << "," -+ << runtime_ms << "," -+ << gflops; -+ -+ return out; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Runs one benchmark -+Result profile_convolution(Options const &options) { -+ -+ Result result; -+ -+ // -+ // Allocate host-device tensors using the CUTLASS Utilities. -+ // -+ -+ cutlass::HostTensor tensor_a(options.input_size); -+ cutlass::HostTensor tensor_transformed_a(options.input_size); -+ cutlass::HostTensor tensor_b(options.filter_size); -+ cutlass::HostTensor -+ tensor_a_scale({1, options.input_size.c()}); -+ cutlass::HostTensor -+ tensor_a_bias({1, options.input_size.c()}); -+ cutlass::HostTensor tensor_c(options.output_size()); -+ cutlass::HostTensor tensor_d(options.output_size()); -+ cutlass::HostTensor tensor_ref_d(options.output_size()); -+ -+ // -+ // Initialize tensors -+ // -+ -+ // Fill tensor A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1, -+ ElementInputA(3), -+ ElementInputA(-4), -+ 0); -+ -+ // Fill scale vector for tensor A on host with uniform-distribution random -+ // data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a_scale.host_view(), -+ 1, -+ ElementInputA(3), -+ ElementInputA(-4), -+ 0); -+ -+ // Fill bias vector for tensor A on host with uniform-distribution random -+ // data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a_bias.host_view(), -+ 1, -+ ElementInputA(3), -+ ElementInputA(-4), -+ 0); -+ -+ // Fill tensor B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 1, -+ ElementInputB(7), -+ ElementInputB(-8), -+ 0); -+ -+ // Fill tensor C on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c.host_view(), -+ 1, -+ ElementOutput(7), -+ ElementOutput(-8), -+ 0); -+ -+ // Fill tensor D for reference on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_ref_d.host_view()); -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_a_scale.sync_device(); -+ tensor_a_bias.sync_device(); -+ tensor_b.sync_device(); -+ tensor_c.sync_device(); -+ tensor_d.sync_device(); -+ tensor_ref_d.sync_device(); -+ -+ // -+ // Define arguments for CUTLASS Convolution -+ // -+ -+ cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; -+ -+ // Split K dimension into 1 partitions -+ int split_k_slices = 1; -+ -+ // Construct Conv3dProblemSize with user defined output size -+ cutlass::conv::Conv3dProblemSize problem_size( -+ options.input_size, -+ options.filter_size, -+ options.padding, -+ options.conv_stride, -+ options.dilation, -+ options.output_size(), -+ mode, -+ split_k_slices -+ ); -+ -+ typename ImplicitGemmFusion::Arguments arguments{ -+ problem_size, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ tensor_a_scale.device_ref(), -+ tensor_a_bias.device_ref(), -+ tensor_c.device_ref(), -+ tensor_d.device_ref(), -+ {options.alpha, options.beta}, -+ }; -+ -+ // -+ // Initialize CUTLASS Convolution -+ // -+ -+ ImplicitGemmFusion implicit_gemm_fusion_op; -+ -+ size_t workspace_size = implicit_gemm_fusion_op.get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ result.status = implicit_gemm_fusion_op.can_implement(arguments); -+ CUTLASS_CHECK(result.status); -+ -+ result.status = implicit_gemm_fusion_op.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Launch initialized CUTLASS kernel -+ // -+ result.status = implicit_gemm_fusion_op(); -+ -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Optional reference check -+ // -+ -+ if (options.reference_check) { -+ std::cout << "Verification on device...\n"; -+ -+ // Compute scale + bias + relu in host code -+ for (int n = 0; n < options.input_size.n(); ++n) { -+ for (int d = 0; d < options.input_size.d(); ++d) { -+ for (int h = 0; h < options.input_size.h(); ++h) { -+ for (int w = 0; w < options.input_size.w(); ++w) { -+ for (int c = 0; c < options.input_size.c(); ++c) { -+ tensor_transformed_a.at({n, d, h, w, c}) = std::max( -+ ElementOutput(0), ElementOutput(tensor_a.at({n, d, h, w, c}) * -+ tensor_a_scale.at({0, c}) + -+ tensor_a_bias.at({0, c}))); -+ } -+ } -+ } -+ } -+ } -+ -+ tensor_transformed_a.sync_device(); -+ -+ // Compute with reference implementation -+ cutlass::reference::device::Conv3dFprop< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementComputeEpilogue, -+ ElementAccumulator, -+ cutlass::NumericConverter -+ >( -+ problem_size, -+ tensor_transformed_a.device_ref(), -+ tensor_b.device_ref(), -+ tensor_c.device_ref(), -+ tensor_ref_d.device_ref(), -+ options.alpha, -+ options.beta -+ ); -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ tensor_d.sync_host(); -+ tensor_ref_d.sync_host(); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_d.host_view(), -+ tensor_ref_d.host_view()); -+ -+ if (!passed) { -+ result.reference_check = cutlass::Status::kErrorInternal; -+ std::cout << "ERROR - results miscompared.\n"; -+ } -+ else { -+ result.reference_check = cutlass::Status::kSuccess; -+ std::cout << "Passed.\n"; -+ } -+ } -+ else { -+ result.reference_check = cutlass::Status::kInvalid; -+ } -+ -+ if (options.save_workspace) { -+ -+ std::stringstream ss; -+ -+ ss << "25_ampere_3d_fprop_mainloop_fusion" -+ << options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c() -+ << "_" -+ << options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c() -+ << ".dat"; -+ -+ std::ofstream output_workspace(ss.str()); -+ -+ output_workspace -+ << "Input = \n" << tensor_a.host_view() << "\n\n" -+ << "Filters = \n" << tensor_b.host_view() << "\n\n"; -+ -+ if (options.reference_check) { -+ output_workspace << "Reference = \n" << tensor_ref_d.host_view() << "\n\n"; -+ } -+ -+ output_workspace << "Computed = \n" << tensor_d.host_view() << std::endl; -+ -+ std::cout << "Results written to '" << ss.str() << "'." << std::endl; -+ } -+ -+ // -+ // Performance measurement -+ // -+ -+ if (options.measure_performance) { -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ } -+ -+ // Record an event at the start of a series of convolution operations. -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Launch a sequence of implicit GEMM operations on the device -+ for (int iteration = 0; iteration < options.iterations; ++iteration) { -+ result.status = implicit_gemm_fusion_op(); -+ CUTLASS_CHECK(result.status); -+ } -+ -+ // Record an event when the convolutions have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Print average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // Cleanup -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ } -+ -+ return result; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ -+ bool notSupported = false; -+ -+ // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. -+ // -+ // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv3dFprop examples. -+ if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); -+ -+ if (!(props.major >= 8)) { -+ std::cerr << "This test must run on SM80 or above.\n"; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ return 0; -+ } -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ if (options.benchmark) { -+ // Benchmark several layers -+ -+ int batch_sizes[] = {34, 18}; -+ -+ struct Benchmark { -+ int d, h, w, c, k, t, r, s, stride_d, stride_h, stride_w; -+ } layers[] = { -+ {56, 56, 56, 64, 256, 1, 1, 1, 1, 1, 1}, -+ {56, 56, 56, 64, 64, 1, 1, 1, 1, 1, 1}, -+ {56, 56, 56, 64, 64, 3, 3, 3, 1, 1, 1}, -+ {56, 56, 56, 256, 64, 1, 1, 1, 1, 1, 1}, -+ {56, 56, 56, 256, 512, 1, 1, 1, 2, 2, 2}, -+ {56, 56, 56, 256, 128, 1, 1, 1, 1, 1, 1}, -+ {56, 56, 56, 128, 128, 3, 3, 3, 2, 2, 2}, -+ {28, 28, 28, 128, 512, 1, 1, 1, 1, 1, 1}, -+ {28, 28, 28, 512, 128, 1, 1, 1, 1, 1, 1}, -+ {28, 28, 28, 128, 128, 3, 3, 3, 1, 1, 1}, -+ {28, 28, 28, 512, 1024, 1, 1, 1, 2, 2, 2}, -+ {28, 28, 28, 512, 256, 1, 1, 1, 1, 1, 1}, -+ {28, 28, 28, 256, 256, 3, 3, 3, 2, 2, 2}, -+ {14, 14, 14, 256, 1024, 1, 1, 1, 1, 1, 1}, -+ {14, 14, 14, 1024, 256, 1, 1, 1, 1, 1, 1}, -+ {14, 14, 14, 256, 256, 3, 3, 3, 1, 1, 1}, -+ {14, 14, 14, 1024, 2048, 1, 1, 1, 2, 2, 2}, -+ {14, 14, 14, 1024, 512, 1, 1, 1, 1, 1, 1}, -+ {14, 14, 14, 512, 512, 3, 3, 3, 2, 2, 2}, -+ { 7, 7, 7, 512, 2048, 1, 1, 1, 1, 1, 1}, -+ { 7, 7, 7, 2048, 512, 1, 1, 1, 1, 1, 1}, -+ { 7, 7, 7, 512, 512, 3, 3, 3, 1, 1, 1}, -+ }; -+ -+ Result::print_header(std::cout, options) << std::endl; -+ -+ int idx = 1; -+ -+ for (auto const &layer : layers) { -+ for (auto N : batch_sizes) { -+ options.update({N, layer.d, layer.h, layer.w, layer.c}, -+ {layer.k, layer.t, layer.r, layer.s, layer.c}, -+ cutlass::make_Coord(layer.stride_d, layer.stride_h, layer.stride_w)); -+ -+ Result result = profile_convolution(options); -+ result.print(std::cout, idx, options) << std::endl; -+ } -+ -+ ++idx; -+ } -+ } -+ else { -+ -+ // Execute one problem size -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ Result result = profile_convolution(options); -+ -+ Result::print_header(std::cout, options) << std::endl; -+ result.print(std::cout, 1, options) << std::endl; -+ } -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/25_ampere_fprop_mainloop_fusion/ampere_fprop_mainloop_fusion.cu b/3rdparty/cutlass/examples/25_ampere_fprop_mainloop_fusion/ampere_fprop_mainloop_fusion.cu -new file mode 100644 -index 0000000..71f5040 ---- /dev/null -+++ b/3rdparty/cutlass/examples/25_ampere_fprop_mainloop_fusion/ampere_fprop_mainloop_fusion.cu -@@ -0,0 +1,768 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+ -+This example shows how to fuse per channel scale+bias+relu of the activations -+into the fprop mainloop. -+ -+Compared with original fprop kernel, this example has two more vectors, one for -+the scale and one for the bias. The length of the vectors are the same as the -+activation channel number. This kernels loads the vectors when the associated -+activation channels are loaded in the mainloop. Between reading the -+activations and scale/bias data from the shared memory and calling tensor core -+instructions, scale+bias+relu is computed in the register file. -+ -+This example is customized for Ampere 16816 fp16 tensor core instruction. -+Changing to different data types or different tensor core instruction require -+source code changing. See -+include/cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h for more -+technical details. -+ -+This example is modified based on 16_ampere_tensorop_conv2dfprop. The command -+line is the same. -+*/ -+ -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/conv/kernel/default_conv2d_fprop_fusion.h" -+#include "cutlass/conv/device/implicit_gemm_convolution_fusion.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/device/convolution.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "helper.h" -+ -+// The code section below describes datatype for input, output tensors and computation between -+// elements -+using ElementAccumulator = float; // Data type of accumulator -+using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta) -+using ElementInputA = cutlass::half_t; // Data type of elements in input tensor -+using ElementInputB = cutlass::half_t; // Data type of elements in input tensor -+using ElementInputScaleBias = cutlass::half_t; // Data type of elements in input sclae and bias vectors -+using ElementOutput = float; // Data type of elements in output tensor -+ -+using LayoutInputA = cutlass::layout::TensorNHWC; -+using LayoutInputB = cutlass::layout::TensorNHWC; -+using LayoutInputScaleBias = cutlass::layout::RowMajor; -+using LayoutOutput = cutlass::layout::TensorNHWC; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm80; -+ -+// This code section describes the tile size a thread block will compute -+using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock tile shape -+ -+// This code section describes tile size a warp will compute -+using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; // Warp tile shape -+ -+// This code section describes the size of MMA op -+using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // TensorCore instruction shape -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 4; -+ -+// This code section describe iterator algorithm selected is Analytic or Optimized -+static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized; -+ -+// This code section describes the epilogue part of the kernel, we use default value -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // Data type of output matrix. -+ 128 / cutlass::sizeof_bits::value, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. -+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue>; // Data type for alpha/beta in linear combination -+ -+using Conv2dFpropFusionKernel = typename cutlass::conv::kernel::DefaultConv2dFpropFusion< -+ ElementInputA, LayoutInputA, -+ ElementInputB, LayoutInputB, -+ ElementInputScaleBias, LayoutInputScaleBias, -+ ElementOutput, LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ IteratorAlgorithm -+>::Kernel; -+ -+using ImplicitGemmFusion = cutlass::conv::device::ImplicitGemmConvolutionFusion; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ cutlass::Tensor4DCoord input_size; -+ cutlass::Tensor4DCoord filter_size; -+ cutlass::Tensor4DCoord padding; -+ cutlass::MatrixCoord conv_stride; -+ cutlass::MatrixCoord dilation; -+ bool reference_check; -+ bool measure_performance; -+ int iterations; -+ bool save_workspace; -+ ElementComputeEpilogue alpha; -+ ElementComputeEpilogue beta; -+ bool benchmark; -+ std::string tag; -+ -+ Options(): -+ help(false), -+ input_size(1, 32, 32, 32), -+ filter_size(32, 3, 3, 32), -+ padding(1, 1, 1, 1), -+ conv_stride(1, 1), -+ dilation(1, 1), -+ reference_check(true), -+ measure_performance(false), -+ iterations(20), -+ save_workspace(false), -+ alpha(1), -+ beta(0), -+ benchmark(false) { } -+ -+ // Verify the problem size is compatible with the CUTLASS Convolution implementation. -+ bool valid() { -+ -+ // -+ // CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently, -+ // all pointers, strides, and tensor extents must be divisible by 8 elements. -+ // -+ int const kAlignment = 8; -+ -+ if ((input_size.c() % kAlignment) || -+ (filter_size.n() % kAlignment)) { -+ -+ // misaligned tensors -+ return false; -+ } -+ -+ // Invalid padding -+ if ((padding.h() != filter_size.h() / 2) || -+ (padding.w() != filter_size.w() / 2)) { -+ -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Updates input and filter sizes -+ void update( -+ cutlass::Tensor4DCoord input_size, -+ cutlass::Tensor4DCoord filter_size, -+ cutlass::MatrixCoord stride) { -+ -+ this->input_size = input_size; -+ this->filter_size = filter_size; -+ conv_stride = stride; -+ -+ padding.n() = filter_size.h() / 2; -+ padding.h() = filter_size.h() / 2; -+ padding.w() = filter_size.w() / 2; -+ padding.c() = filter_size.w() / 2; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("ref-check")) { -+ reference_check = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("perf-check")) { -+ measure_performance = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("save-workspace")) { -+ save_workspace = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("benchmark")) { -+ benchmark = true; -+ } -+ -+ cmd.get_cmd_line_argument("n", input_size.n()); -+ cmd.get_cmd_line_argument("h", input_size.h()); -+ cmd.get_cmd_line_argument("w", input_size.w()); -+ cmd.get_cmd_line_argument("c", input_size.c()); -+ -+ cmd.get_cmd_line_argument("k", filter_size.n()); -+ cmd.get_cmd_line_argument("r", filter_size.h()); -+ cmd.get_cmd_line_argument("s", filter_size.w()); -+ filter_size.c() = input_size.c(); -+ -+ cmd.get_cmd_line_argument("alpha", alpha); -+ cmd.get_cmd_line_argument("beta", beta); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ cmd.get_cmd_line_argument("tag", tag); -+ -+ if (filter_size.h() == 3 && filter_size.w() == 3) { -+ padding = {1, 1, 1, 1}; -+ } -+ else { -+ filter_size.h() = 1; -+ filter_size.w() = 1; -+ padding = {0, 0, 0, 0}; -+ } -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "25_ampere_fprop_mainloop_fusion example\n\n" -+ << " This example fuses scale+bias+relu of the activations into Ampere's\n" -+ << " Tensor Core operators on F16 data types to compute\n" -+ << " forward convolution on tensors of layout NHWC.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --n= Input tensor extent N\n" -+ << " --h= Input tensor extent H\n" -+ << " --w= Input tensor extent W\n" -+ << " --c= Input tensor extent C\n" -+ << " --k= Filter extent K\n" -+ << " --r= Filter extent R\n" -+ << " --s= Filter extent S\n\n" -+ << " --alpha= Epilogue scalar alpha\n" -+ << " --beta= Epilogue scalar beta\n\n" -+ << " --ref-check If set (true), reference check on the host is computed\n" -+ << " --perf-check If set (true), performance is measured.\n" -+ << " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n" -+ << " --iterations= Number of profiling iterations to perform.\n" -+ << " --save-workspace If set, workspace is written to a text file.\n" -+ << " --tag= String to replicate across the first column in the results table\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/25_ampere_fprop_mainloop_fusion/25_ampere_fprop_mainloop_fusion --n=32 --h=224 --w=224 --c=128 --k=256 --r=1 --s=1\n\n" -+ << "$ ./examples/25_ampere_fprop_mainloop_fusion/25_ampere_fprop_mainloop_fusion --n=1 --h=224 --w=224 --c=32 --k=32 --r=3 --s=3 --ref-check\n\n"; -+ -+ return out; -+ } -+ -+ /// Computes the output tensor size (NPQK) -+ cutlass::Tensor4DCoord output_size() const { -+ return cutlass::Tensor4DCoord( -+ input_size.n(), -+ (input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1, -+ (input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1, -+ filter_size.n()); -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of multiply-adds = NPQK * CRS -+ int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c()); -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct Result { -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cutlass::Status reference_check; -+ cudaError_t error; -+ -+ Result(): -+ runtime_ms(0), -+ gflops(0), -+ status(cutlass::Status::kSuccess), -+ reference_check(cutlass::Status::kInvalid), -+ error(cudaSuccess) { } -+ -+ static std::ostream & print_header(std::ostream &out, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << "Name,"; -+ } -+ -+ out << "Layer,N,H,W,C,K,R,S,Stride_H,Stride_W,Runtime,GFLOPs"; -+ -+ return out; -+ } -+ -+ std::ostream & print(std::ostream &out, int idx, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << options.tag << ","; -+ } -+ -+ out -+ << "conv_" << idx << "," -+ << options.input_size.n() << "," -+ << options.input_size.h() << "," -+ << options.input_size.w() << "," -+ << options.input_size.c() << "," -+ << options.filter_size.n() << "," -+ << options.filter_size.h() << "," -+ << options.filter_size.w() << "," -+ << options.conv_stride.row() << "," -+ << options.conv_stride.column() << "," -+ << runtime_ms << "," -+ << gflops; -+ -+ return out; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Runs one benchmark -+Result profile_convolution(Options const &options) { -+ -+ Result result; -+ -+ // -+ // Allocate host-device tensors using the CUTLASS Utilities. -+ // -+ -+ cutlass::HostTensor tensor_a(options.input_size); -+ cutlass::HostTensor tensor_transformed_a(options.input_size); -+ cutlass::HostTensor tensor_b(options.filter_size); -+ cutlass::HostTensor -+ tensor_a_scale({1, options.input_size.c()}); -+ cutlass::HostTensor -+ tensor_a_bias({1, options.input_size.c()}); -+ cutlass::HostTensor tensor_c(options.output_size()); -+ cutlass::HostTensor tensor_d(options.output_size()); -+ cutlass::HostTensor tensor_ref_d(options.output_size()); -+ -+ // -+ // Initialize tensors -+ // -+ -+ // Fill tensor A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1, -+ ElementInputA(3), -+ ElementInputA(-4), -+ 0); -+ -+ // Fill scale vector for tensor A on host with uniform-distribution random -+ // data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a_scale.host_view(), -+ 1, -+ ElementInputA(3), -+ ElementInputA(-4), -+ 0); -+ -+ // Fill bias vector for tensor A on host with uniform-distribution random -+ // data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a_bias.host_view(), -+ 1, -+ ElementInputA(3), -+ ElementInputA(-4), -+ 0); -+ -+ // Fill tensor B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 1, -+ ElementInputB(7), -+ ElementInputB(-8), -+ 0); -+ -+ // Fill tensor C on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c.host_view(), -+ 1, -+ ElementOutput(7), -+ ElementOutput(-8), -+ 0); -+ -+ // Fill tensor D on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_d.host_view()); -+ -+ // Fill tensor D for reference on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_ref_d.host_view()); -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_a_scale.sync_device(); -+ tensor_a_bias.sync_device(); -+ tensor_b.sync_device(); -+ tensor_c.sync_device(); -+ tensor_d.sync_device(); -+ tensor_ref_d.sync_device(); -+ -+ // -+ // Define arguments for CUTLASS Convolution -+ // -+ -+ cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; -+ -+ // Split K dimension into 1 partitions -+ int split_k_slices = 1; -+ -+ // Construct Conv2dProblemSize with user defined output size -+ cutlass::conv::Conv2dProblemSize problem_size( -+ options.input_size, -+ options.filter_size, -+ options.padding, -+ options.conv_stride, -+ options.dilation, -+ options.output_size(), -+ mode, -+ split_k_slices -+ ); -+ -+ typename ImplicitGemmFusion::Arguments arguments{ -+ problem_size, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ tensor_a_scale.device_ref(), -+ tensor_a_bias.device_ref(), -+ tensor_c.device_ref(), -+ tensor_d.device_ref(), -+ {options.alpha, options.beta}, -+ }; -+ -+ // -+ // Initialize CUTLASS Convolution -+ // -+ -+ ImplicitGemmFusion implicit_gemm_fusion_op; -+ -+ size_t workspace_size = implicit_gemm_fusion_op.get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ result.status = implicit_gemm_fusion_op.can_implement(arguments); -+ CUTLASS_CHECK(result.status); -+ -+ result.status = implicit_gemm_fusion_op.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Launch initialized CUTLASS kernel -+ // -+ result.status = implicit_gemm_fusion_op(); -+ -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Optional reference check -+ // -+ -+ if (options.reference_check) { -+ std::cout << "Verification on device...\n"; -+ -+ // Compute scale + bias + relu in host code -+ for (int n = 0; n < options.input_size.n(); ++n) { -+ for (int h = 0; h < options.input_size.h(); ++h) { -+ for (int w = 0; w < options.input_size.w(); ++w) { -+ for (int c = 0; c < options.input_size.c(); ++c) { -+ tensor_transformed_a.at({n, h, w, c}) = std::max( -+ ElementOutput(0), ElementOutput(tensor_a.at({n, h, w, c}) * -+ tensor_a_scale.at({0, c}) + -+ tensor_a_bias.at({0, c}))); -+ } -+ } -+ } -+ } -+ -+ tensor_transformed_a.sync_device(); -+ -+ // Compute with reference implementation -+ cutlass::reference::device::Conv2dFprop< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementComputeEpilogue, -+ ElementAccumulator, -+ cutlass::NumericConverter -+ >( -+ problem_size, -+ tensor_transformed_a.device_ref(), -+ tensor_b.device_ref(), -+ tensor_c.device_ref(), -+ tensor_ref_d.device_ref(), -+ options.alpha, -+ options.beta -+ ); -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ tensor_d.sync_host(); -+ tensor_ref_d.sync_host(); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_d.host_view(), -+ tensor_ref_d.host_view()); -+ -+ if (!passed) { -+ result.reference_check = cutlass::Status::kErrorInternal; -+ std::cout << "ERROR - results miscompared.\n"; -+ } -+ else { -+ result.reference_check = cutlass::Status::kSuccess; -+ std::cout << "Passed.\n"; -+ } -+ } -+ else { -+ result.reference_check = cutlass::Status::kInvalid; -+ } -+ -+ if (options.save_workspace) { -+ -+ std::stringstream ss; -+ -+ ss << "25_ampere_fprop_mainloop_fusion" -+ << options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c() -+ << "_" -+ << options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c() -+ << ".dat"; -+ -+ std::ofstream output_workspace(ss.str()); -+ -+ output_workspace -+ << "Input = \n" << tensor_a.host_view() << "\n\n" -+ << "Filters = \n" << tensor_b.host_view() << "\n\n"; -+ -+ if (options.reference_check) { -+ output_workspace << "Reference = \n" << tensor_ref_d.host_view() << "\n\n"; -+ } -+ -+ output_workspace << "Computed = \n" << tensor_d.host_view() << std::endl; -+ -+ std::cout << "Results written to '" << ss.str() << "'." << std::endl; -+ } -+ -+ // -+ // Performance measurement -+ // -+ -+ if (options.measure_performance) { -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ } -+ -+ // Record an event at the start of a series of convolution operations. -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Launch a sequence of implicit GEMM operations on the device -+ for (int iteration = 0; iteration < options.iterations; ++iteration) { -+ result.status = implicit_gemm_fusion_op(); -+ CUTLASS_CHECK(result.status); -+ } -+ -+ // Record an event when the convolutions have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Print average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // Cleanup -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ } -+ -+ return result; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ -+ bool notSupported = false; -+ -+ // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. -+ // -+ // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples. -+ if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); -+ -+ if (!(props.major >= 8)) { -+ std::cerr << "This test must run on SM80 or above.\n"; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ return 0; -+ } -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ if (options.benchmark) { -+ // Benchmark several layers -+ -+ int batch_sizes[] = {34, 408}; -+ -+ struct Benchmark { -+ int h, w, c, k, r, s, stride_h, stride_w; -+ } layers[] = { -+ {56, 56, 64, 256, 1, 1, 1, 1}, -+ {56, 56, 64, 64, 1, 1, 1, 1}, -+ {56, 56, 64, 64, 3, 3, 1, 1}, -+ {56, 56, 256, 64, 1, 1, 1, 1}, -+ {56, 56, 256, 512, 1, 1, 2, 2}, -+ {56, 56, 256, 128, 1, 1, 1, 1}, -+ {56, 56, 128, 128, 3, 3, 2, 2}, -+ {28, 28, 128, 512, 1, 1, 1, 1}, -+ {28, 28, 512, 128, 1, 1, 1, 1}, -+ {28, 28, 128, 128, 3, 3, 1, 1}, -+ {28, 28, 512, 1024, 1, 1, 2, 2}, -+ {28, 28, 512, 256, 1, 1, 1, 1}, -+ {28, 28, 256, 256, 3, 3, 2, 2}, -+ {14, 14, 256, 1024, 1, 1, 1, 1}, -+ {14, 14, 1024, 256, 1, 1, 1, 1}, -+ {14, 14, 256, 256, 3, 3, 1, 1}, -+ {14, 14, 1024, 2048, 1, 1, 2, 2}, -+ {14, 14, 1024, 512, 1, 1, 1, 1}, -+ {14, 14, 512, 512, 3, 3, 2, 2}, -+ { 7, 7, 512, 2048, 1, 1, 1, 1}, -+ { 7, 7, 2048, 512, 1, 1, 1, 1}, -+ { 7, 7, 512, 512, 3, 3, 1, 1}, -+ }; -+ -+ Result::print_header(std::cout, options) << std::endl; -+ -+ int idx = 1; -+ -+ for (auto const &layer : layers) { -+ for (auto N : batch_sizes) { -+ options.update({N, layer.h, layer.w, layer.c}, -+ {layer.k, layer.r, layer.s, layer.c}, -+ {layer.stride_h, layer.stride_w}); -+ -+ Result result = profile_convolution(options); -+ result.print(std::cout, idx, options) << std::endl; -+ } -+ -+ ++idx; -+ } -+ } -+ else { -+ -+ // Execute one problem size -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ Result result = profile_convolution(options); -+ -+ Result::print_header(std::cout, options) << std::endl; -+ result.print(std::cout, 1, options) << std::endl; -+ } -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/26_ampere_wgrad_mainloop_fusion/ampere_wgrad_mainloop_fusion.cu b/3rdparty/cutlass/examples/26_ampere_wgrad_mainloop_fusion/ampere_wgrad_mainloop_fusion.cu -new file mode 100644 -index 0000000..48e2b77 ---- /dev/null -+++ b/3rdparty/cutlass/examples/26_ampere_wgrad_mainloop_fusion/ampere_wgrad_mainloop_fusion.cu -@@ -0,0 +1,766 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+ -+This example shows how to fuse activation's per channel scale+bias+relu -+into the wgrad mainloop. -+ -+Compared with original fprop kernel, this example has two more vectors, one for -+the scale and one for the bias. The length of the vectors are the same as the -+activation channel number. This kernels loads the vectors when the associated -+activation channels are loaded in the mainloop. Between reading the -+activations and scale/bias data from the shared memory and calling tensor core -+instructions, scale+bias+relu is computed in the register file. -+ -+This example is customized for Ampere 16816 fp16 tensor core instruction. -+Changing to different data types or different tensor core instruction require -+source code changing. See -+include/cutlass/conv/threadblock/implicit_gemm_wgrad_fusion_multistage.h for more -+technical details. -+*/ -+ -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/conv/kernel/default_conv2d_wgrad_fusion.h" -+#include "cutlass/conv/device/implicit_gemm_convolution_fusion.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/device/convolution.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "helper.h" -+ -+// The code section below describes datatype for input, output tensors and computation between -+// elements -+using ElementAccumulator = float; // Data type of accumulator -+using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta) -+using ElementInputA = cutlass::half_t; // Data type of elements in input tensor -+using ElementInputB = cutlass::half_t; // Data type of elements in input tensor -+using ElementInputScaleBias = cutlass::half_t; // Data type of elements in input sclae and bias vectors -+using ElementOutput = float; // Data type of elements in output tensor -+ -+using LayoutInputA = cutlass::layout::TensorNHWC; -+using LayoutInputB = cutlass::layout::TensorNHWC; -+using LayoutInputScaleBias = cutlass::layout::RowMajor; -+using LayoutOutput = cutlass::layout::TensorNHWC; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm80; -+ -+// This code section describes the tile size a thread block will compute -+using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock tile shape -+ -+// This code section describes tile size a warp will compute -+using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; // Warp tile shape -+ -+// This code section describes the size of MMA op -+using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // TensorCore instruction shape -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 5; -+ -+// This code section describe iterator algorithm selected is Analytic or Optimized -+static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized; -+ -+// This code section describes the epilogue part of the kernel, we use default value -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // Data type of output matrix. -+ 128 / cutlass::sizeof_bits::value, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. -+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue>; // Data type for alpha/beta in linear combination -+ -+using Conv2dWgradFusionKernel = typename cutlass::conv::kernel::DefaultConv2dWgradFusion< -+ ElementInputA, LayoutInputA, -+ ElementInputB, LayoutInputB, -+ ElementInputScaleBias, LayoutInputScaleBias, -+ ElementOutput, LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ IteratorAlgorithm -+>::Kernel; -+ -+using ImplicitGemmFusion = cutlass::conv::device::ImplicitGemmConvolutionFusion; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ cutlass::Tensor4DCoord input_size; -+ cutlass::Tensor4DCoord filter_size; -+ cutlass::Tensor4DCoord padding; -+ cutlass::MatrixCoord conv_stride; -+ cutlass::MatrixCoord dilation; -+ bool reference_check; -+ bool measure_performance; -+ int iterations; -+ bool save_workspace; -+ ElementComputeEpilogue alpha; -+ ElementComputeEpilogue beta; -+ bool benchmark; -+ std::string tag; -+ -+ Options(): -+ help(false), -+ input_size(1, 32, 32, 32), -+ filter_size(32, 3, 3, 32), -+ padding(1, 1, 1, 1), -+ conv_stride(1, 1), -+ dilation(1, 1), -+ reference_check(true), -+ measure_performance(false), -+ iterations(20), -+ save_workspace(false), -+ alpha(1), -+ beta(0), -+ benchmark(false) { } -+ -+ // Verify the problem size is compatible with the CUTLASS Convolution implementation. -+ bool valid() { -+ -+ // -+ // CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently, -+ // all pointers, strides, and tensor extents must be divisible by 8 elements. -+ // -+ int const kAlignment = 8; -+ -+ if ((input_size.c() % kAlignment) || -+ (filter_size.n() % kAlignment)) { -+ -+ // misaligned tensors -+ return false; -+ } -+ -+ // Invalid padding -+ if ((padding.h() != filter_size.h() / 2) || -+ (padding.w() != filter_size.w() / 2)) { -+ -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Updates input and filter sizes -+ void update( -+ cutlass::Tensor4DCoord input_size, -+ cutlass::Tensor4DCoord filter_size, -+ cutlass::MatrixCoord stride) { -+ -+ this->input_size = input_size; -+ this->filter_size = filter_size; -+ conv_stride = stride; -+ -+ padding.n() = filter_size.h() / 2; -+ padding.h() = filter_size.h() / 2; -+ padding.w() = filter_size.w() / 2; -+ padding.c() = filter_size.w() / 2; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("ref-check")) { -+ reference_check = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("perf-check")) { -+ measure_performance = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("save-workspace")) { -+ save_workspace = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("benchmark")) { -+ benchmark = true; -+ } -+ -+ cmd.get_cmd_line_argument("n", input_size.n()); -+ cmd.get_cmd_line_argument("h", input_size.h()); -+ cmd.get_cmd_line_argument("w", input_size.w()); -+ cmd.get_cmd_line_argument("c", input_size.c()); -+ -+ cmd.get_cmd_line_argument("k", filter_size.n()); -+ cmd.get_cmd_line_argument("r", filter_size.h()); -+ cmd.get_cmd_line_argument("s", filter_size.w()); -+ filter_size.c() = input_size.c(); -+ -+ cmd.get_cmd_line_argument("alpha", alpha); -+ cmd.get_cmd_line_argument("beta", beta); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ cmd.get_cmd_line_argument("tag", tag); -+ -+ if (filter_size.h() == 3 && filter_size.w() == 3) { -+ padding = {1, 1, 1, 1}; -+ } -+ else { -+ filter_size.h() = 1; -+ filter_size.w() = 1; -+ padding = {0, 0, 0, 0}; -+ } -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "26_ampere_wgrad_mainloop_fusion example\n\n" -+ << " This example fuses scale+bias+relu of the activation into Ampere's\n" -+ << " Tensor Core operators on F16 data types to compute\n" -+ << " backward convolution on tensors of layout NHWC.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --n= Input tensor extent N\n" -+ << " --h= Input tensor extent H\n" -+ << " --w= Input tensor extent W\n" -+ << " --c= Input tensor extent C\n" -+ << " --k= Filter extent K\n" -+ << " --r= Filter extent R\n" -+ << " --s= Filter extent S\n\n" -+ << " --alpha= Epilogue scalar alpha\n" -+ << " --beta= Epilogue scalar beta\n\n" -+ << " --ref-check If set (true), reference check on the host is computed\n" -+ << " --perf-check If set (true), performance is measured.\n" -+ << " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n" -+ << " --iterations= Number of profiling iterations to perform.\n" -+ << " --save-workspace If set, workspace is written to a text file.\n" -+ << " --tag= String to replicate across the first column in the results table\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/26_ampere_wgrad_mainloop_fusion/26_ampere_wgrad_mainloop_fusion --n=32 --h=224 --w=224 --c=128 --k=256 --r=1 --s=1\n\n" -+ << "$ ./examples/26_ampere_wgrad_mainloop_fusion/26_ampere_wgrad_mainloop_fusion --n=1 --h=224 --w=224 --c=32 --k=32 --r=3 --s=3 --ref-check\n\n"; -+ -+ return out; -+ } -+ -+ /// Computes the output tensor size (NPQK) -+ cutlass::Tensor4DCoord output_size() const { -+ return cutlass::Tensor4DCoord( -+ input_size.n(), -+ (input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1, -+ (input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1, -+ filter_size.n()); -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of multiply-adds = NPQK * CRS -+ int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c()); -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct Result { -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cutlass::Status reference_check; -+ cudaError_t error; -+ -+ Result(): -+ runtime_ms(0), -+ gflops(0), -+ status(cutlass::Status::kSuccess), -+ reference_check(cutlass::Status::kInvalid), -+ error(cudaSuccess) { } -+ -+ static std::ostream & print_header(std::ostream &out, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << "Name,"; -+ } -+ -+ out << "Layer,N,H,W,C,K,R,S,Stride_H,Stride_W,Runtime,GFLOPs"; -+ -+ return out; -+ } -+ -+ std::ostream & print(std::ostream &out, int idx, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << options.tag << ","; -+ } -+ -+ out -+ << "conv_" << idx << "," -+ << options.input_size.n() << "," -+ << options.input_size.h() << "," -+ << options.input_size.w() << "," -+ << options.input_size.c() << "," -+ << options.filter_size.n() << "," -+ << options.filter_size.h() << "," -+ << options.filter_size.w() << "," -+ << options.conv_stride.row() << "," -+ << options.conv_stride.column() << "," -+ << runtime_ms << "," -+ << gflops; -+ -+ return out; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Runs one benchmark -+Result profile_convolution(Options const &options) { -+ -+ Result result; -+ -+ // -+ // Allocate host-device tensors using the CUTLASS Utilities. -+ // -+ -+ cutlass::HostTensor tensor_a(options.output_size()); -+ cutlass::HostTensor tensor_b(options.input_size); -+ cutlass::HostTensor tensor_transformed_b(options.input_size); -+ cutlass::HostTensor -+ tensor_b_scale({1, options.input_size.c()}); -+ cutlass::HostTensor -+ tensor_b_bias({1, options.input_size.c()}); -+ -+ cutlass::HostTensor tensor_c(options.filter_size); -+ cutlass::HostTensor tensor_d(options.filter_size); -+ cutlass::HostTensor tensor_ref_d(options.filter_size); -+ -+ // -+ // Initialize tensors -+ // -+ -+ // Fill tensor A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1, -+ ElementInputA(3), -+ ElementInputA(-4), -+ 0); -+ -+ // Fill tensor B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 1, -+ ElementInputB(7), -+ ElementInputB(-8), -+ 0); -+ -+ // Fill scale vector for tensor B on host with uniform-distribution random -+ // data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b_scale.host_view(), -+ 1, -+ ElementInputA(3), -+ ElementInputA(-4), -+ 0); -+ -+ // Fill bias vector for tensor B on host with uniform-distribution random -+ // data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b_bias.host_view(), -+ 1, -+ ElementInputA(3), -+ ElementInputA(-4), -+ 0); -+ -+ // Fill tensor C on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c.host_view(), -+ 1, -+ ElementOutput(7), -+ ElementOutput(-8), -+ 0); -+ -+ // Fill tensor D on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_d.host_view()); -+ -+ // Fill tensor D for reference on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_ref_d.host_view()); -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_b_scale.sync_device(); -+ tensor_b_bias.sync_device(); -+ tensor_c.sync_device(); -+ tensor_d.sync_device(); -+ tensor_ref_d.sync_device(); -+ -+ // -+ // Define arguments for CUTLASS Convolution -+ // -+ -+ cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; -+ -+ // Split K dimension into 1 partitions -+ int split_k_slices = 1; -+ -+ // Construct Conv2dProblemSize with user defined output size -+ cutlass::conv::Conv2dProblemSize problem_size( -+ options.input_size, -+ options.filter_size, -+ options.padding, -+ options.conv_stride, -+ options.dilation, -+ options.output_size(), -+ mode, -+ split_k_slices -+ ); -+ -+ typename ImplicitGemmFusion::Arguments arguments{ -+ problem_size, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ tensor_b_scale.device_ref(), -+ tensor_b_bias.device_ref(), -+ tensor_c.device_ref(), -+ tensor_d.device_ref(), -+ {options.alpha, options.beta}, -+ }; -+ -+ // -+ // Initialize CUTLASS Convolution -+ // -+ -+ ImplicitGemmFusion implicit_gemm_fusion_op; -+ -+ size_t workspace_size = implicit_gemm_fusion_op.get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ result.status = implicit_gemm_fusion_op.can_implement(arguments); -+ CUTLASS_CHECK(result.status); -+ -+ result.status = implicit_gemm_fusion_op.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Launch initialized CUTLASS kernel -+ // -+ result.status = implicit_gemm_fusion_op(); -+ -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Optional reference check -+ // -+ -+ if (options.reference_check) { -+ std::cout << "Verification on device...\n"; -+ -+ // Compute scale + bias + relu in host code -+ for (int n = 0; n < options.input_size.n(); ++n) { -+ for (int h = 0; h < options.input_size.h(); ++h) { -+ for (int w = 0; w < options.input_size.w(); ++w) { -+ for (int c = 0; c < options.input_size.c(); ++c) { -+ tensor_transformed_b.at({n, h, w, c}) = std::max( -+ ElementOutput(0), ElementOutput(tensor_b.at({n, h, w, c}) * -+ tensor_b_scale.at({0, c}) + -+ tensor_b_bias.at({0, c}))); -+ } -+ } -+ } -+ } -+ -+ tensor_transformed_b.sync_device(); -+ -+ // Compute with reference implementation -+ cutlass::reference::device::Conv2dWgrad< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementComputeEpilogue, -+ ElementAccumulator, -+ cutlass::NumericConverter -+ >( -+ problem_size, -+ tensor_a.device_ref(), -+ tensor_transformed_b.device_ref(), -+ tensor_c.device_ref(), -+ tensor_ref_d.device_ref(), -+ options.alpha, -+ options.beta -+ ); -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ tensor_d.sync_host(); -+ tensor_ref_d.sync_host(); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_d.host_view(), -+ tensor_ref_d.host_view()); -+ -+ if (!passed) { -+ result.reference_check = cutlass::Status::kErrorInternal; -+ std::cout << "ERROR - results miscompared.\n"; -+ } -+ else { -+ result.reference_check = cutlass::Status::kSuccess; -+ std::cout << "Passed.\n"; -+ } -+ } -+ else { -+ result.reference_check = cutlass::Status::kInvalid; -+ } -+ -+ if (options.save_workspace) { -+ -+ std::stringstream ss; -+ -+ ss << "26_ampere_wgrad_mainloop_fusion_" -+ << options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c() -+ << "_" -+ << options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c() -+ << ".dat"; -+ -+ std::ofstream output_workspace(ss.str()); -+ -+ output_workspace -+ << "Input = \n" << tensor_a.host_view() << "\n\n" -+ << "Filters = \n" << tensor_b.host_view() << "\n\n"; -+ -+ if (options.reference_check) { -+ output_workspace << "Reference = \n" << tensor_ref_d.host_view() << "\n\n"; -+ } -+ -+ output_workspace << "Computed = \n" << tensor_d.host_view() << std::endl; -+ -+ std::cout << "Results written to '" << ss.str() << "'." << std::endl; -+ } -+ -+ // -+ // Performance measurement -+ // -+ -+ if (options.measure_performance) { -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ } -+ -+ // Record an event at the start of a series of convolution operations. -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Launch a sequence of implicit GEMM operations on the device -+ for (int iteration = 0; iteration < options.iterations; ++iteration) { -+ result.status = implicit_gemm_fusion_op(); -+ CUTLASS_CHECK(result.status); -+ } -+ -+ // Record an event when the convolutions have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Print average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // Cleanup -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ } -+ -+ return result; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ -+ bool notSupported = false; -+ -+ // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. -+ // -+ // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples. -+ if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); -+ -+ if (!(props.major == 8 && props.minor == 0)) { -+ std::cerr << "This test must run on SM80 A100.\n"; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ return 0; -+ } -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ if (options.benchmark) { -+ // Benchmark several layers -+ -+ int batch_sizes[] = {34, 408}; -+ -+ struct Benchmark { -+ int h, w, c, k, r, s, stride_h, stride_w; -+ } layers[] = { -+ {56, 56, 64, 256, 1, 1, 1, 1}, -+ {56, 56, 64, 64, 1, 1, 1, 1}, -+ {56, 56, 64, 64, 3, 3, 1, 1}, -+ {56, 56, 256, 64, 1, 1, 1, 1}, -+ {56, 56, 256, 512, 1, 1, 2, 2}, -+ {56, 56, 256, 128, 1, 1, 1, 1}, -+ {56, 56, 128, 128, 3, 3, 2, 2}, -+ {28, 28, 128, 512, 1, 1, 1, 1}, -+ {28, 28, 512, 128, 1, 1, 1, 1}, -+ {28, 28, 128, 128, 3, 3, 1, 1}, -+ {28, 28, 512, 1024, 1, 1, 2, 2}, -+ {28, 28, 512, 256, 1, 1, 1, 1}, -+ {28, 28, 256, 256, 3, 3, 2, 2}, -+ {14, 14, 256, 1024, 1, 1, 1, 1}, -+ {14, 14, 1024, 256, 1, 1, 1, 1}, -+ {14, 14, 256, 256, 3, 3, 1, 1}, -+ {14, 14, 1024, 2048, 1, 1, 2, 2}, -+ {14, 14, 1024, 512, 1, 1, 1, 1}, -+ {14, 14, 512, 512, 3, 3, 2, 2}, -+ { 7, 7, 512, 2048, 1, 1, 1, 1}, -+ { 7, 7, 2048, 512, 1, 1, 1, 1}, -+ { 7, 7, 512, 512, 3, 3, 1, 1}, -+ }; -+ -+ Result::print_header(std::cout, options) << std::endl; -+ -+ int idx = 1; -+ -+ for (auto const &layer : layers) { -+ for (auto N : batch_sizes) { -+ options.update({N, layer.h, layer.w, layer.c}, -+ {layer.k, layer.r, layer.s, layer.c}, -+ {layer.stride_h, layer.stride_w}); -+ -+ Result result = profile_convolution(options); -+ result.print(std::cout, idx, options) << std::endl; -+ } -+ -+ ++idx; -+ } -+ } -+ else { -+ -+ // Execute one problem size -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ Result result = profile_convolution(options); -+ -+ Result::print_header(std::cout, options) << std::endl; -+ result.print(std::cout, 1, options) << std::endl; -+ } -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu b/3rdparty/cutlass/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu -new file mode 100644 -index 0000000..e9d0287 ---- /dev/null -+++ b/3rdparty/cutlass/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu -@@ -0,0 +1,750 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+NVIDIA Ampere architecture starts supporting tfloat32 (see include/cutlass/tfloat32.h) -+data types in tensor cores. One big advantage is that we can load in fp32 data and convert them -+implicitly to tf32 inside the GEMM kernel which means no change is needed to accelerate traditional -+fp32 data by using NVIDIA Ampere architecture. -+ -+We can use the tf32 mode of tensor core to emulate a fast accurate SGEMM kernel which is accelerated -+using Ampere Tensor Cores (see include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h). -+ -+The trick is very simple -+ a x b = (a_big + a_small) x (b_big + b_small) = a_big x b_big + a_big x b_small + a_small x b_big -+ big = convert_to_tf32(fp32) -+ small = convert_to_tf32(fp32 - big) -+ -+a_small x b_small is discarded because they are too small. -+ -+This example demonstrates usage of this kernel, along with accuracy measurements w.r.t. actual FP32 -+results (SGEMM using SIMT) and against FP64 results (DGEMM) -+ -+To enable this feature, the only change needs to make is to change the default OpMultiplyAdd to -+OpMultiplyAddFastF32. -+ -+Now, we have several different flavors of sgemm now in the profiler for Ampere. Here are the difference -+ -+ sgemm // CUDA core SIMT kernel. FP32 in, accumulated in FP32, FP32 out. -+ s1688gemm // Use 3xTF32 to emulate FP32. FP32 in, converted in TF32-big and TF32-small internally, -+ // accumulated in FP32, FP32 out. -+ s1688tf32gemm // Use 1xTF32. FP32 in, converted to one TF32 internally, accumulated in FP32, FP32 out. -+ s1688gemm_tf32 // TF32 in, accumulated in FP32, FP32 out. -+*/ -+ -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+ -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_reduce.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/error_metrics.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "helper.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Result structure -+struct Result { -+ -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cudaError_t error; -+ -+ int m, n, k; -+ double l2_norm_3xtf32_vs_fp64; -+ double l2_norm_1xtf32_vs_fp64; -+ double l2_norm_fp32_vs_fp64; -+ -+ // ctor -+ Result( -+ int m, int n, int k, -+ double runtime_ms, double gflops, -+ double l2_norm_3xtf32_vs_fp64, -+ double l2_norm_1xtf32_vs_fp64, -+ double l2_norm_fp32_vs_fp64) : -+ m(m), n(n), k(k), -+ runtime_ms(runtime_ms), gflops(gflops), -+ l2_norm_3xtf32_vs_fp64(l2_norm_3xtf32_vs_fp64), -+ l2_norm_1xtf32_vs_fp64(l2_norm_1xtf32_vs_fp64), -+ l2_norm_fp32_vs_fp64(l2_norm_fp32_vs_fp64) {} -+ -+ Result() {} -+ -+ // -+ // Methods -+ // -+ static void print_csv_header() { -+ std::cout << "M,N,K,Runtime(ms),GFLOPS,3xTF32_vs_FP64,1xTF32_vs_FP64,FP32_vs_FP64" << std::endl; -+ } -+ -+ void print_csv_row() { -+ std::cout << m << "," -+ << n << "," -+ << k << "," -+ << runtime_ms << "," -+ << gflops << "," -+ << l2_norm_3xtf32_vs_fp64 << "," -+ << l2_norm_1xtf32_vs_fp64 << "," -+ << l2_norm_fp32_vs_fp64 << std::endl; -+ } -+}; -+ -+std::vector results; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ -+ cutlass::gemm::GemmCoord problem_size; -+ float alpha; -+ float beta; -+ std::string rand_mode; -+ -+ int iterations; -+ int seed; -+ bool benchmark; -+ -+ Options(): -+ help(false), -+ problem_size({3456, 4096, 4096}), -+ iterations(20), -+ seed(1), -+ alpha(1), -+ beta(), -+ rand_mode("uniform"), -+ benchmark(false) { } -+ -+ bool valid() { -+ // -+ // CUTLASS attempts to load 128b vectors of F32 elements. Consequently, -+ // all pointers, strides, and tensor extents must be divisible by 4 elements. -+ // -+ int const kAlignment = 4; -+ -+ if ((problem_size.m() % kAlignment) || -+ (problem_size.n() % kAlignment) || -+ (problem_size.k() % kAlignment)) { -+ -+ // misaligned tensors -+ return false; -+ } -+ -+ return true; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ cmd.get_cmd_line_argument("m", problem_size.m()); -+ cmd.get_cmd_line_argument("n", problem_size.n()); -+ cmd.get_cmd_line_argument("k", problem_size.k()); -+ -+ cmd.get_cmd_line_argument("alpha", alpha); -+ cmd.get_cmd_line_argument("beta", beta); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ cmd.get_cmd_line_argument("seed", seed); -+ cmd.get_cmd_line_argument("rand_mode", rand_mode); -+ -+ if (cmd.check_cmd_line_flag("benchmark")) { -+ benchmark = true; -+ } -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "27_ampere_3xtf32_fast_accurate_tensorop_gemm example\n\n" -+ << " This example uses the CUTLASS Library to emulate FP32 with TF32 tensorop GEMM computations.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --m= GEMM M dimension\n" -+ << " --n= GEMM N dimension\n" -+ << " --k= GEMM K dimension\n" -+ << " --alpha= Epilogue scalar alpha\n" -+ << " --beta= Epilogue scalar beta\n\n" -+ << " --rand_mode= gauss / uniform*\n\n" -+ << " --seed= Random number seed (1*)\n\n" -+ << " --iterations= Number of profiling iterations to perform.\n\n" -+ << " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm --m=1024 --n=512 \\\n" -+ << " --alpha=2 --beta=0.707 \n\n"; -+ -+ return out; -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of real-valued multiply-adds -+ int64_t fmas = problem_size.product(); -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// The code section below describes matrix layout of input and output matrices. Column Major for -+// Matrix A, Row Major for Matrix B and Row Major for Matrix C -+using LayoutInputA = cutlass::layout::RowMajor; -+using LayoutInputB = cutlass::layout::ColumnMajor; -+using LayoutOutput = cutlass::layout::RowMajor; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm80; -+ -+// This code section describes the tile size a thread block will compute -+using ShapeMMAThreadBlock = -+ cutlass::gemm::GemmShape<128, 64, 16>; // <- threadblock tile M = 128, N = 128, K = 16 -+// This code section describes tile size a warp will compute -+using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 32, 16>; // <- warp tile M = 64, N = 64, K = 16 -+// This code section describes the size of MMA op -+using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8 -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? -+ -+// This code section describes the epilogue part of the kernel -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ float, // <- data type of output matrix -+ 128 / cutlass::sizeof_bits::value, // <- the number of elements per vectorized -+ // memory access. For a byte, it's 16 -+ // elements. This becomes the vector width of -+ // math instructions in the epilogue too -+ float, // <- data type of accumulator -+ float>; // <- data type for alpha/beta in linear combination function -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 3; -+// Alignment -+constexpr int Alignment = 4; -+ -+// -+// Gemm Operators (Gemm_3xTF32, Gemm_1xTF32, GEMM_F32, GEMM_F64) -+// -+ -+// Gemm_3xTF32 -+using Gemm_3xTF32 = cutlass::gemm::device::Gemm< -+ float, -+ LayoutInputA, -+ float, -+ LayoutInputB, -+ float, -+ LayoutOutput, -+ float, -+ MMAOp, -+ SmArch, -+ ShapeMMAThreadBlock, -+ ShapeMMAWarp, -+ ShapeMMAOp, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ Alignment, -+ Alignment, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32>; -+ -+// Gemm_1xTF32 -+using Gemm_1xTF32 = cutlass::gemm::device::Gemm< -+ float, -+ LayoutInputA, -+ float, -+ LayoutInputB, -+ float, -+ LayoutOutput, -+ float, -+ MMAOp, -+ SmArch, -+ ShapeMMAThreadBlock, -+ ShapeMMAWarp, -+ ShapeMMAOp, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ Alignment, -+ Alignment, -+ false, -+ cutlass::arch::OpMultiplyAdd>; -+ -+// Gemm_F64 -+using Gemm_F64 = cutlass::reference::device::Gemm< -+ double, -+ LayoutInputA, -+ double, -+ LayoutInputB, -+ double, -+ LayoutOutput, -+ double, -+ double>; -+ -+// Gemm_F32 -+using Gemm_F32 = cutlass::reference::device::Gemm< -+ float, -+ LayoutInputA, -+ float, -+ LayoutInputB, -+ float, -+ LayoutOutput, -+ float, -+ float>; -+ -+bool run(Options &options) { -+ -+ // Create a tuple of problem size for matrix multiplication -+ cutlass::gemm::GemmCoord problem_size = options.problem_size; -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /// 1. Initialize F32 Precision input tensors using CUTLASS helper functions -+ //////////////////////////////////////////////////////////////////////////////// -+ cutlass::HostTensor tensor_a_F32(problem_size.mk()); // <- Create matrix A with dimensions M x K -+ cutlass::HostTensor tensor_b_F32(problem_size.kn()); // <- Create matrix B with dimensions K x N -+ cutlass::HostTensor tensor_c_F32(problem_size.mn()); // <- Create matrix C with dimensions M x N -+ cutlass::HostTensor tensor_d_F32(problem_size.mn()); // <- Create matrix D with dimensions M x N -+ -+ if (options.rand_mode == "uniform") { -+ const float min = -1; -+ const float max = 1; -+ // Fill input and output matrices on host using CUTLASS helper functions -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a_F32.host_view(), -+ options.seed, -+ double(max), -+ double(min)); // <- Fill matrix A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b_F32.host_view(), -+ options.seed, -+ double(max), -+ double(min)); // <- Fill matrix B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c_F32.host_view(), -+ options.seed, -+ double(max), -+ double(min)); // <- Fill matrix C on host with uniform-distribution random data -+ } else if (options.rand_mode == "gauss") { -+ // Fill input and output matrices on host using CUTLASS helper functions -+ cutlass::reference::host::TensorFillRandomGaussian( -+ tensor_a_F32.host_view(), -+ options.seed, -+ double(0), -+ double(5)); // <- Fill matrix A on host with gaussian-distribution random data -+ cutlass::reference::host::TensorFillRandomGaussian( -+ tensor_b_F32.host_view(), -+ options.seed, -+ double(0), -+ double(5)); // <- Fill matrix B on host with gaussian-distribution random data -+ cutlass::reference::host::TensorFillRandomGaussian( -+ tensor_c_F32.host_view(), -+ options.seed, -+ double(0), -+ double(5)); // <- Fill matrix C on host with gaussian-distribution random data -+ } -+ cutlass::reference::host::TensorFill( -+ tensor_d_F32.host_view()); // <- fill matrix D on host with zeros -+ -+ // Copy data from host to GPU -+ tensor_a_F32.sync_device(); -+ tensor_b_F32.sync_device(); -+ tensor_c_F32.sync_device(); -+ tensor_d_F32.sync_device(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /// 2. Initialize F64 tensors using the same values used for F32 -+ //////////////////////////////////////////////////////////////////////////////// -+ // Gemm input operands (A, B, C) -+ cutlass::HostTensor tensor_a_F64(problem_size.mk()); // <- Create matrix A with dimensions M x K -+ cutlass::HostTensor tensor_b_F64(problem_size.kn()); // <- Create matrix B with dimensions K x N -+ cutlass::HostTensor tensor_c_F64(problem_size.mn()); // <- Create matrix C with dimensions M x N -+ -+ // Gemm output (D) for GEMM_F64 -+ cutlass::HostTensor tensor_d_F64(problem_size.mn()); // <- Create matrix D with dimensions M x N -+ // Gemm output (D) for GEMM_3xTF32 -+ cutlass::HostTensor tensor_d_3xTF32(problem_size.mn()); // <- Create matrix D with dimensions M x N -+ // Gemm output (D) for GEMM_1xTF32 -+ cutlass::HostTensor tensor_d_1xTF32(problem_size.mn()); // <- Create matrix D with dimensions M x N -+ -+ // Copy values from the DP tensors -+ cutlass::reference::host::TensorCopy(tensor_a_F64.host_view(), tensor_a_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_b_F64.host_view(), tensor_b_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_c_F64.host_view(), tensor_c_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_d_F64.host_view(), tensor_d_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_d_3xTF32.host_view(), tensor_d_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_d_1xTF32.host_view(), tensor_d_F32.host_view()); -+ -+ // Copy data from host to GPU -+ tensor_a_F64.sync_device(); -+ tensor_b_F64.sync_device(); -+ tensor_c_F64.sync_device(); -+ tensor_d_F64.sync_device(); -+ tensor_d_3xTF32.sync_device(); -+ tensor_d_1xTF32.sync_device(); -+ -+ // Initialize alpha and beta for dot product computation -+ float alpha = float(options.alpha); -+ float beta = float(options.beta); -+ -+ // Split K dimension into 1 partitions -+ int split_k_slices = 1; -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /// 3. Run 3xTF32 kernel within a profiling loop -+ //////////////////////////////////////////////////////////////////////////////// -+ // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch -+ // instantiated CUTLASS kernel -+ typename Gemm_3xTF32::Arguments arguments_3xtf32{problem_size, // <- problem size of matrix multiplication -+ tensor_a_F32.device_ref(), // <- reference to matrix A on device -+ tensor_b_F32.device_ref(), // <- reference to matrix B on device -+ tensor_c_F32.device_ref(), // <- reference to matrix C on device -+ tensor_d_3xTF32.device_ref(), // <- reference to matrix D on device -+ {alpha, beta}, // <- tuple of alpha and beta -+ split_k_slices}; // <- k-dimension split factor -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size_3xtf32 = Gemm_3xTF32::get_workspace_size(arguments_3xtf32); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace_3xtf32(workspace_size_3xtf32); -+ -+ // Instantiate CUTLASS kernel depending on templates -+ Gemm_3xTF32 gemm_op_3xTF32; -+ -+ // Check the problem size is supported or not -+ cutlass::Status status_3xtf32 = gemm_op_3xTF32.can_implement(arguments_3xtf32); -+ CUTLASS_CHECK(status_3xtf32); -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ status_3xtf32 = gemm_op_3xTF32.initialize(arguments_3xtf32, workspace_3xtf32.get()); -+ CUTLASS_CHECK(status_3xtf32); -+ -+ // Result structure -+ Result result; -+ -+ // -+ // Construct events -+ // -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return false; -+ } -+ } -+ -+ // Record an event at the start of a series of GEMMs -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return false; -+ } -+ -+ // -+ // Run profiling loop -+ // -+ -+ for (int iter = 0; iter < options.iterations; ++iter) { -+ // Launch initialized CUTLASS kernel -+ status_3xtf32 = gemm_op_3xTF32(); -+ CUTLASS_CHECK(status_3xtf32); -+ } -+ -+ // -+ // Stop profiling loop -+ // -+ -+ // Record an event when the GEMMs are complete -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return false; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return false; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return false; -+ } -+ -+ // Compute average runtime and GFLOPs. -+ result.m = problem_size.m(); -+ result.n = problem_size.n(); -+ result.k = problem_size.k(); -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // Cleanup -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ tensor_d_3xTF32.sync_host(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /// 4. Run TF32 kernel without profiling loop -+ //////////////////////////////////////////////////////////////////////////////// -+ // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch -+ // instantiated CUTLASS kernel -+ typename Gemm_1xTF32::Arguments arguments_1xtf32{problem_size, // <- problem size of matrix multiplication -+ tensor_a_F32.device_ref(), // <- reference to matrix A on device -+ tensor_b_F32.device_ref(), // <- reference to matrix B on device -+ tensor_c_F32.device_ref(), // <- reference to matrix C on device -+ tensor_d_1xTF32.device_ref(), // <- reference to matrix D on device -+ {alpha, beta}, // <- tuple of alpha and beta -+ split_k_slices}; // <- k-dimension split factor -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size_1xtf32 = Gemm_1xTF32::get_workspace_size(arguments_1xtf32); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace_1xtf32(workspace_size_1xtf32); -+ -+ // Instantiate CUTLASS kernel depending on templates -+ Gemm_1xTF32 gemm_op_1xtf32; -+ -+ // Check the problem size is supported or not -+ cutlass::Status status_1xtf32 = gemm_op_1xtf32.can_implement(arguments_1xtf32); -+ CUTLASS_CHECK(status_1xtf32); -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ status_1xtf32 = gemm_op_1xtf32.initialize(arguments_1xtf32, workspace_1xtf32.get()); -+ CUTLASS_CHECK(status_1xtf32); -+ -+ // Launch initialized CUTLASS kernel -+ status_1xtf32 = gemm_op_1xtf32(); -+ CUTLASS_CHECK(status_1xtf32); -+ -+ tensor_d_1xTF32.sync_host(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ // Run reference kernel (F64) -+ //////////////////////////////////////////////////////////////////////////////// -+ -+ // Create instantiation for device reference gemm kernel -+ Gemm_F64 gemm_f64; -+ -+ // Launch device reference gemm kernel -+ gemm_f64(problem_size, -+ alpha, -+ tensor_a_F64.device_ref(), -+ tensor_b_F64.device_ref(), -+ beta, -+ tensor_c_F64.device_ref(), -+ tensor_d_F64.device_ref()); -+ -+ // Wait for kernels to finish -+ cudaDeviceSynchronize(); -+ -+ // Copy output data from CUTLASS and reference kernel to host for comparison -+ tensor_d_F64.sync_host(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ // Run reference kernel (F32) -+ //////////////////////////////////////////////////////////////////////////////// -+ -+ // Create instantiation for device reference gemm kernel -+ Gemm_F32 gemm_f32; -+ -+ // Launch device reference gemm kernel -+ gemm_f32(problem_size, -+ alpha, -+ tensor_a_F32.device_ref(), -+ tensor_b_F32.device_ref(), -+ beta, -+ tensor_c_F32.device_ref(), -+ tensor_d_F32.device_ref()); -+ -+ // Wait for kernels to finish -+ cudaDeviceSynchronize(); -+ -+ // Copy output data from CUTLASS and reference kernel to host for comparison -+ tensor_d_F32.sync_host(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /////// Compute l2 norms -+ //////////////////////////////////////////////////////////////////////////////// -+ -+ // l2 norm 3xTF32 vs F64 -+ cutlass::HostTensor tensor_d_3xTF32_in_F64(problem_size.mn()); -+ cutlass::reference::host::TensorCopy(tensor_d_3xTF32_in_F64.host_view(), tensor_d_3xTF32.host_view()); -+ -+ result.l2_norm_3xtf32_vs_fp64 = cutlass::reference::host::TensorRelativeErrorMetric( -+ tensor_d_3xTF32_in_F64.host_view(), tensor_d_F64.host_view()); -+ -+ // l2 norm 1xTF32 vs F64 -+ cutlass::HostTensor tensor_d_1xTF32_in_F64(problem_size.mn()); -+ cutlass::reference::host::TensorCopy(tensor_d_1xTF32_in_F64.host_view(), tensor_d_1xTF32.host_view()); -+ -+ result.l2_norm_1xtf32_vs_fp64 = cutlass::reference::host::TensorRelativeErrorMetric( -+ tensor_d_1xTF32_in_F64.host_view(), tensor_d_F64.host_view()); -+ -+ // l2 norm F32 vs F64 -+ cutlass::HostTensor tensor_d_F32_in_F64(problem_size.mn()); -+ cutlass::reference::host::TensorCopy(tensor_d_F32_in_F64.host_view(), tensor_d_F32.host_view()); -+ -+ result.l2_norm_fp32_vs_fp64 = cutlass::reference::host::TensorRelativeErrorMetric( -+ tensor_d_F32_in_F64.host_view(), tensor_d_F64.host_view()); -+ -+ results.push_back(result); -+ -+ /////////////////////////////////////////////////////////////////////////////// -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ -+ std::cout << std::fixed; -+ std::cout.precision(4); -+ std::cout << "Runtime: " << result.runtime_ms << " ms" << std::endl; -+ std::cout.precision(2); -+ std::cout << "GFLOPs: " << result.gflops << std::endl; -+ std::cout << "Normalized L2 norm of" << std::endl; -+ std::cout.precision(8); -+ std::cout << std::scientific -+ << " - 3xTF32 error with FP64 reference : " << result.l2_norm_3xtf32_vs_fp64 << std::endl -+ << " - 1xTF32 error with FP64 reference : " << result.l2_norm_1xtf32_vs_fp64 << std::endl -+ << " - FP32 error with FP64 reference : " << result.l2_norm_fp32_vs_fp64 << std::endl; -+ -+ return true; -+} -+ -+int main(int argc, const char **argv) { -+ -+ bool notSupported = false; -+ -+ // Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available -+ // in CUDA 11.0. -+ // -+ // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples. -+ if (!(__CUDACC_VER_MAJOR__ >= 11)) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return false; -+ } -+ -+ if (!((props.major * 10 + props.minor) >= 80)) { -+ std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80." -+ << std::endl; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ // Returning zero so this test passes on older Toolkits. Its actions are no-op. -+ return 0; -+ } -+ -+ Options options; -+ options.parse(argc, argv); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ bool result = true; -+ -+ if (options.benchmark) { -+ for (int k = 4; k <= 65536; k *= 2) { -+ -+ options.problem_size[2] = k; -+ -+ printf("Gemm problem size: %d x %d x %d\n", \ -+ options.problem_size.m(), options.problem_size.n(), options.problem_size.k()); -+ -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ result &= run(options); -+ } -+ } else { -+ // Execute one problem size -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ result = run(options); -+ } -+ -+ if (!result) return -1; -+ -+ std::cout << std::endl << "CSV results" << std::endl; -+ Result::print_csv_header(); -+ for(auto &r : results) -+ r.print_csv_row(); -+ -+ return 0; -+} -diff --git a/3rdparty/cutlass/examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/ampere_3xtf32_fast_accurate_tensorop_fprop.cu b/3rdparty/cutlass/examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/ampere_3xtf32_fast_accurate_tensorop_fprop.cu -new file mode 100644 -index 0000000..27286f9 ---- /dev/null -+++ b/3rdparty/cutlass/examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/ampere_3xtf32_fast_accurate_tensorop_fprop.cu -@@ -0,0 +1,822 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+ -+This example adopts example 16 to use 3xTF32 to bring FP32 accuracy with 2x performance -+compared with CUDA Cores. See example 27 for the trick of 3xTF32. -+*/ -+ -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/device/convolution.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/convolution.h" -+#include "cutlass/util/reference/host/error_metrics.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "helper.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// The code section below describes datatype for input, output tensors and computation between -+// elements -+using ElementAccumulator = float; // Data type of accumulator -+using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta) -+using ElementInputA = float; // Data type of elements in input tensor -+using ElementInputB = float; // Data type of elements in input tensor -+using ElementOutput = float; // Data type of elements in output tensor -+ -+using LayoutInputA = cutlass::layout::TensorNHWC; -+using LayoutInputB = cutlass::layout::TensorNHWC; -+using LayoutOutput = cutlass::layout::TensorNHWC; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm80; -+ -+// This code section describes the tile size a thread block will compute -+using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 16>; // Threadblock tile shape -+ -+// This code section describes tile size a warp will compute -+using WarpShape = cutlass::gemm::GemmShape<64, 32, 16>; // Warp tile shape -+ -+// This code section describes the size of MMA op -+using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; // TensorCore instruction shape -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 3; -+ -+// This code section describe iterator algorithm selected is Analytic or Optimized -+static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized; -+ -+// This code section describes the epilogue part of the kernel, we use default value -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // Data type of output matrix. -+ 128 / cutlass::sizeof_bits::value, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. -+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue>; // Data type for alpha/beta in linear combination -+ -+// 3xTF32 Fprop -+using Conv2dFpropKernel_3xTF32 = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementInputA, LayoutInputA, -+ ElementInputB, LayoutInputB, -+ ElementOutput, LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ // Only thing needs to be changed from normal Fprop -+ cutlass::arch::OpMultiplyAddFastF32, -+ IteratorAlgorithm -+>::Kernel; -+ -+// 1xTF32 Fprop -+using Conv2dFpropKernel_1xTF32 = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementInputA, LayoutInputA, -+ ElementInputB, LayoutInputB, -+ ElementOutput, LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ IteratorAlgorithm -+>::Kernel; -+ -+using ImplicitGemm_3xTF32 = cutlass::conv::device::ImplicitGemmConvolution; -+using ImplicitGemm_1xTF32 = cutlass::conv::device::ImplicitGemmConvolution; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ cutlass::Tensor4DCoord input_size; -+ cutlass::Tensor4DCoord filter_size; -+ cutlass::Tensor4DCoord padding; -+ cutlass::MatrixCoord conv_stride; -+ cutlass::MatrixCoord dilation; -+ int iterations; -+ bool save_workspace; -+ ElementComputeEpilogue alpha; -+ ElementComputeEpilogue beta; -+ bool benchmark; -+ std::string tag; -+ -+ Options(): -+ help(false), -+ input_size(1, 32, 32, 32), -+ filter_size(32, 3, 3, 32), -+ padding(1, 1, 1, 1), -+ conv_stride(1, 1), -+ dilation(1, 1), -+ iterations(20), -+ save_workspace(false), -+ alpha(1), -+ beta(0), -+ benchmark(false) { } -+ -+ // Verify the problem size is compatible with the CUTLASS Convolution implementation. -+ bool valid() { -+ -+ // -+ // CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently, -+ // all pointers, strides, and tensor extents must be divisible by 8 elements. -+ // -+ int const kAlignment = 4; -+ -+ if ((input_size.c() % kAlignment) || -+ (filter_size.n() % kAlignment)) { -+ -+ // misaligned tensors -+ return false; -+ } -+ -+ // Invalid padding -+ if ((padding.h() != filter_size.h() / 2) || -+ (padding.w() != filter_size.w() / 2)) { -+ -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Updates input and filter sizes -+ void update( -+ cutlass::Tensor4DCoord input_size, -+ cutlass::Tensor4DCoord filter_size) { -+ -+ this->input_size = input_size; -+ this->filter_size = filter_size; -+ -+ padding.n() = filter_size.h() / 2; -+ padding.h() = filter_size.h() / 2; -+ padding.w() = filter_size.w() / 2; -+ padding.c() = filter_size.w() / 2; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("save-workspace")) { -+ save_workspace = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("benchmark")) { -+ benchmark = true; -+ } -+ -+ cmd.get_cmd_line_argument("n", input_size.n()); -+ cmd.get_cmd_line_argument("h", input_size.h()); -+ cmd.get_cmd_line_argument("w", input_size.w()); -+ cmd.get_cmd_line_argument("c", input_size.c()); -+ -+ cmd.get_cmd_line_argument("k", filter_size.n()); -+ cmd.get_cmd_line_argument("r", filter_size.h()); -+ cmd.get_cmd_line_argument("s", filter_size.w()); -+ filter_size.c() = input_size.c(); -+ -+ cmd.get_cmd_line_argument("alpha", alpha); -+ cmd.get_cmd_line_argument("beta", beta); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ cmd.get_cmd_line_argument("tag", tag); -+ -+ if (filter_size.h() == 3 && filter_size.w() == 3) { -+ padding = {1, 1, 1, 1}; -+ } -+ else { -+ filter_size.h() = 1; -+ filter_size.w() = 1; -+ padding = {0, 0, 0, 0}; -+ } -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "28_ampere_3xtf32_fast_accurate_tensorop_fprop example\n\n" -+ << " This example uses Ampere's Tensor Core operators on F16 data types to compute\n" -+ << " forward convolution on tensors of layout NHWC.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --n= Input tensor extent N\n" -+ << " --h= Input tensor extent H\n" -+ << " --w= Input tensor extent W\n" -+ << " --c= Input tensor extent C\n" -+ << " --k= Filter extent K\n" -+ << " --r= Filter extent R\n" -+ << " --s= Filter extent S\n\n" -+ << " --alpha= Epilogue scalar alpha\n" -+ << " --beta= Epilogue scalar beta\n\n" -+ << " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n" -+ << " --iterations= Number of profiling iterations to perform.\n" -+ << " --save-workspace If set, workspace is written to a text file.\n" -+ << " --tag= String to replicate across the first column in the results table\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/28_ampere_3xtf32_fast_accurate_tensorop_fprop --n=32 --h=224 --w=224 --c=128 --k=256 --r=1 --s=1\n\n" -+ << "$ ./examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/28_ampere_3xtf32_fast_accurate_tensorop_fprop --n=1 --h=224 --w=224 --c=32 --k=32 --r=3 --s=3 --ref-check\n\n"; -+ -+ return out; -+ } -+ -+ /// Computes the output tensor size (NPQK) -+ cutlass::Tensor4DCoord output_size() const { -+ return cutlass::Tensor4DCoord( -+ input_size.n(), -+ (input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1, -+ (input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1, -+ filter_size.n()); -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of multiply-adds = NPQK * CRS -+ int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c()); -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct Result { -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cudaError_t error; -+ -+ double l2_norm_3xtf32_vs_fp64; -+ double l2_norm_1xtf32_vs_fp64; -+ double l2_norm_fp32_vs_fp64; -+ -+ Result(): -+ runtime_ms(0), -+ gflops(0), -+ status(cutlass::Status::kSuccess), -+ error(cudaSuccess), -+ l2_norm_3xtf32_vs_fp64(0), -+ l2_norm_1xtf32_vs_fp64(0), -+ l2_norm_fp32_vs_fp64(0) { } -+ -+ static std::ostream & print_header(std::ostream &out, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << "Name,"; -+ } -+ -+ out << "Layer,N,H,W,C,K,R,S,Runtime,GFLOPs,3xTF32_vs_FP64,1xTF32_vs_FP64,FP32_vs_FP64"; -+ -+ return out; -+ } -+ -+ std::ostream & print(std::ostream &out, int idx, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << options.tag << ","; -+ } -+ -+ out -+ << "conv_" << idx << "," -+ << options.input_size.n() << "," -+ << options.input_size.h() << "," -+ << options.input_size.w() << "," -+ << options.input_size.c() << "," -+ << options.filter_size.n() << "," -+ << options.filter_size.h() << "," -+ << options.filter_size.w() << "," -+ << runtime_ms << "," -+ << gflops << "," -+ << l2_norm_3xtf32_vs_fp64 << "," -+ << l2_norm_1xtf32_vs_fp64 << "," -+ << l2_norm_fp32_vs_fp64; -+ -+ return out; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Runs one benchmark -+Result profile_convolution(Options const &options) { -+ -+ Result result; -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /// 1. Initialize F32 Precision input tensors using CUTLASS helper functions -+ //////////////////////////////////////////////////////////////////////////////// -+ -+ // -+ // Allocate host-device tensors using the CUTLASS Utilities. -+ // -+ -+ cutlass::HostTensor tensor_a_F32(options.input_size); -+ cutlass::HostTensor tensor_b_F32(options.filter_size); -+ cutlass::HostTensor tensor_c_F32(options.output_size()); -+ cutlass::HostTensor tensor_d_F32(options.output_size()); -+ -+ // -+ // Initialize tensors -+ // -+ -+ // Fill tensor A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a_F32.host_view(), -+ 1, -+ ElementInputA(7), -+ ElementInputA(-8)); -+ -+ // Fill tensor B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b_F32.host_view(), -+ 1, -+ ElementInputB(7), -+ ElementInputB(-8)); -+ -+ // Fill tensor C on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c_F32.host_view(), -+ 1, -+ ElementInputB(7), -+ ElementInputB(-8)); -+ -+ // Fill tensor D on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_d_F32.host_view()); -+ -+ // Copy data from host to GPU -+ tensor_a_F32.sync_device(); -+ tensor_b_F32.sync_device(); -+ tensor_c_F32.sync_device(); -+ tensor_d_F32.sync_device(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /// 2. Initialize F32 Precision input tensors using CUTLASS helper functions -+ //////////////////////////////////////////////////////////////////////////////// -+ -+ // -+ // Allocate host-device tensors using the CUTLASS Utilities. -+ // -+ -+ cutlass::HostTensor tensor_a_F64(options.input_size); -+ cutlass::HostTensor tensor_b_F64(options.filter_size); -+ cutlass::HostTensor tensor_c_F64(options.output_size()); -+ -+ cutlass::HostTensor tensor_d_F64(options.output_size()); -+ cutlass::HostTensor tensor_d_3xTF32(options.output_size()); -+ cutlass::HostTensor tensor_d_1xTF32(options.output_size()); -+ -+ // Copy values from the DP tensors -+ cutlass::reference::host::TensorCopy(tensor_a_F64.host_view(), tensor_a_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_b_F64.host_view(), tensor_b_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_c_F64.host_view(), tensor_c_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_d_F64.host_view(), tensor_d_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_d_3xTF32.host_view(), tensor_d_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_d_1xTF32.host_view(), tensor_d_F32.host_view()); -+ -+ // Copy data from host to GPU -+ tensor_a_F64.sync_device(); -+ tensor_b_F64.sync_device(); -+ tensor_c_F64.sync_device(); -+ tensor_d_F64.sync_device(); -+ tensor_d_3xTF32.sync_device(); -+ tensor_d_1xTF32.sync_device(); -+ -+ // -+ // Define arguments for CUTLASS Convolution -+ // -+ -+ cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; -+ -+ // Split K dimension into 1 partitions -+ int split_k_slices = 1; -+ -+ // Construct Conv2dProblemSize with user defined output size -+ cutlass::conv::Conv2dProblemSize problem_size( -+ options.input_size, -+ options.filter_size, -+ options.padding, -+ options.conv_stride, -+ options.dilation, -+ options.output_size(), -+ mode, -+ split_k_slices -+ ); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /// 3. Run 3xTF32 kernel within a profiling loop -+ //////////////////////////////////////////////////////////////////////////////// -+ -+ // Construct ImplicitGemm::Argument structure with conv2d -+ // problem size, data pointers, and epilogue values -+ typename ImplicitGemm_3xTF32::Arguments arguments_3xTF32{ -+ problem_size, -+ tensor_a_F32.device_ref(), -+ tensor_b_F32.device_ref(), -+ tensor_c_F32.device_ref(), -+ tensor_d_3xTF32.device_ref(), -+ {options.alpha, options.beta}, -+ }; -+ -+ // -+ // Initialize CUTLASS Convolution -+ // -+ -+ ImplicitGemm_3xTF32 implicit_gemm_op_3xTF32; -+ -+ size_t workspace_size_3xTF32 = implicit_gemm_op_3xTF32.get_workspace_size(arguments_3xTF32); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace_3xTF32(workspace_size_3xTF32); -+ -+ result.status = implicit_gemm_op_3xTF32.can_implement(arguments_3xTF32); -+ CUTLASS_CHECK(result.status); -+ -+ result.status = implicit_gemm_op_3xTF32.initialize(arguments_3xTF32, workspace_3xTF32.get()); -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Launch initialized CUTLASS kernel -+ // -+ result.status = implicit_gemm_op_3xTF32(); -+ -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Performance measurement -+ // -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ } -+ -+ // Record an event at the start of a series of convolution operations. -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Launch a sequence of implicit GEMM operations on the device -+ for (int iteration = 0; iteration < options.iterations; ++iteration) { -+ result.status = implicit_gemm_op_3xTF32(); -+ CUTLASS_CHECK(result.status); -+ } -+ -+ // Record an event when the convolutions have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Print average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // Cleanup -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ tensor_d_3xTF32.sync_host(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /// 4. Run 1xTF32 kernel within a profiling loop -+ //////////////////////////////////////////////////////////////////////////////// -+ -+ // Construct ImplicitGemm::Argument structure with conv2d -+ // problem size, data pointers, and epilogue values -+ typename ImplicitGemm_1xTF32::Arguments arguments_1xTF32{ -+ problem_size, -+ tensor_a_F32.device_ref(), -+ tensor_b_F32.device_ref(), -+ tensor_c_F32.device_ref(), -+ tensor_d_1xTF32.device_ref(), -+ {options.alpha, options.beta}, -+ }; -+ -+ // -+ // Initialize CUTLASS Convolution -+ // -+ -+ ImplicitGemm_1xTF32 implicit_gemm_op_1xTF32; -+ -+ size_t workspace_size_1xTF32 = implicit_gemm_op_1xTF32.get_workspace_size(arguments_1xTF32); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace_1xTF32(workspace_size_1xTF32); -+ -+ result.status = implicit_gemm_op_1xTF32.can_implement(arguments_1xTF32); -+ CUTLASS_CHECK(result.status); -+ -+ result.status = implicit_gemm_op_1xTF32.initialize(arguments_1xTF32, workspace_1xTF32.get()); -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Launch initialized CUTLASS kernel -+ // -+ result.status = implicit_gemm_op_1xTF32(); -+ -+ CUTLASS_CHECK(result.status); -+ -+ tensor_d_1xTF32.sync_host(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ // Run reference kernel (F64) -+ //////////////////////////////////////////////////////////////////////////////// -+ -+ cutlass::reference::device::Conv2d< -+ double, -+ LayoutInputA, -+ double, -+ LayoutInputB, -+ double, -+ LayoutOutput, -+ double, -+ double -+ >( -+ cutlass::conv::Operator::kFprop, -+ problem_size, -+ tensor_a_F64.device_ref(), -+ tensor_b_F64.device_ref(), -+ tensor_c_F64.device_ref(), -+ tensor_d_F64.device_ref(), -+ options.alpha, -+ options.beta); -+ -+ // Wait for kernels to finish -+ cudaDeviceSynchronize(); -+ -+ // Copy output data from CUTLASS and reference kernel to host for comparison -+ tensor_d_F64.sync_host(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ // Run reference kernel (F32) -+ //////////////////////////////////////////////////////////////////////////////// -+ -+ cutlass::reference::device::Conv2d< -+ float, -+ LayoutInputA, -+ float, -+ LayoutInputB, -+ float, -+ LayoutOutput, -+ float, -+ float -+ >( -+ cutlass::conv::Operator::kFprop, -+ problem_size, -+ tensor_a_F32.device_ref(), -+ tensor_b_F32.device_ref(), -+ tensor_c_F32.device_ref(), -+ tensor_d_F32.device_ref(), -+ options.alpha, -+ options.beta); -+ -+ // Wait for kernels to finish -+ cudaDeviceSynchronize(); -+ -+ // Copy output data from CUTLASS and reference kernel to host for comparison -+ tensor_d_F32.sync_host(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /////// Compute l2 norms -+ //////////////////////////////////////////////////////////////////////////////// -+ -+ // l2 norm 3xTF32 vs F64 -+ cutlass::HostTensor tensor_d_3xTF32_in_F64(options.output_size()); -+ cutlass::reference::host::TensorCopy(tensor_d_3xTF32_in_F64.host_view(), tensor_d_3xTF32.host_view()); -+ -+ result.l2_norm_3xtf32_vs_fp64 = cutlass::reference::host::TensorRelativeErrorMetric( -+ tensor_d_3xTF32_in_F64.host_view(), tensor_d_F64.host_view()); -+ -+ // l2 norm 1xTF32 vs F64 -+ cutlass::HostTensor tensor_d_1xTF32_in_F64(options.output_size()); -+ cutlass::reference::host::TensorCopy(tensor_d_1xTF32_in_F64.host_view(), tensor_d_1xTF32.host_view()); -+ -+ result.l2_norm_1xtf32_vs_fp64 = cutlass::reference::host::TensorRelativeErrorMetric( -+ tensor_d_1xTF32_in_F64.host_view(), tensor_d_F64.host_view()); -+ -+ // l2 norm F32 vs F64 -+ cutlass::HostTensor tensor_d_F32_in_F64(options.output_size()); -+ cutlass::reference::host::TensorCopy(tensor_d_F32_in_F64.host_view(), tensor_d_F32.host_view()); -+ -+ result.l2_norm_fp32_vs_fp64 = cutlass::reference::host::TensorRelativeErrorMetric( -+ tensor_d_F32_in_F64.host_view(), tensor_d_F64.host_view()); -+ -+ /////////////////////////////////////////////////////////////////////////////// -+ -+ if (options.save_workspace) { -+ -+ std::stringstream ss; -+ -+ ss << "28_ampere_3xtf32_fast_accurate_tensorop_fprop_" -+ << options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c() -+ << "_" -+ << options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c() -+ << ".dat"; -+ -+ std::ofstream output_workspace(ss.str()); -+ -+ output_workspace -+ << "Input = \n" << tensor_a_F32.host_view() << "\n\n" -+ << "Filters = \n" << tensor_b_F32.host_view() << "\n\n"; -+ -+ output_workspace << "TF32x3 = \n" << tensor_d_3xTF32.host_view() << std::endl; -+ output_workspace << "TF32x1 = \n" << tensor_d_1xTF32.host_view() << std::endl; -+ output_workspace << "FP32 = \n" << tensor_d_F32.host_view() << std::endl; -+ output_workspace << "FP64 = \n" << tensor_d_F64.host_view() << "\n\n"; -+ -+ std::cout << "Results written to '" << ss.str() << "'." << std::endl; -+ } -+ -+ return result; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ -+ bool notSupported = false; -+ -+ // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. -+ // -+ // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples. -+ if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); -+ -+ if (!(props.major >= 8)) { -+ std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80." -+ << std::endl; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ return 0; -+ } -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ if (options.benchmark) { -+ // Benchmark several layers -+ -+ int batch_sizes[] = {1, 32, 64, 128, 256}; -+ -+ struct Benchmark { -+ int h, w, c, k, r, s; -+ } layers[] = { -+ {56, 56, 64, 256, 1, 1}, -+ {56, 56, 64, 64, 1, 1}, -+ {56, 56, 64, 64, 3, 3}, -+ {56, 56, 256, 64, 1, 1}, -+ {56, 56, 256, 512, 1, 1}, -+ {56, 56, 256, 128, 1, 1}, -+ {28, 28, 128, 128, 3, 3}, -+ {28, 28, 128, 512, 1, 1}, -+ {28, 28, 512, 128, 1, 1}, -+ {28, 28, 512, 1024, 1, 1}, -+ {28, 28, 512, 256, 1, 1}, -+ {14, 14, 256, 256, 3, 3}, -+ {14, 14, 256, 1024, 1, 1}, -+ {14, 14, 1024, 256, 1, 1}, -+ {14, 14, 1024, 2048, 1, 1}, -+ {14, 14, 1024, 512, 1, 1}, -+ {7, 7, 512, 512, 3, 3}, -+ }; -+ -+ Result::print_header(std::cout, options) << std::endl; -+ -+ int idx = 1; -+ -+ for (auto const &layer : layers) { -+ for (auto N : batch_sizes) { -+ -+ options.update({N, layer.h, layer.w, layer.c}, {layer.k, layer.r, layer.s, layer.c}); -+ -+ Result result = profile_convolution(options); -+ result.print(std::cout, idx, options) << std::endl; -+ } -+ -+ ++idx; -+ } -+ } -+ else { -+ -+ // Execute one problem size -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ Result result = profile_convolution(options); -+ -+ Result::print_header(std::cout, options) << std::endl; -+ result.print(std::cout, 1, options) << std::endl; -+ } -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm.cu b/3rdparty/cutlass/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm.cu -new file mode 100644 -index 0000000..fc6f6af ---- /dev/null -+++ b/3rdparty/cutlass/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm.cu -@@ -0,0 +1,692 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+ This example is almost the same as example 27 which uses 3xTF32 to run GEMM. The only -+ difference is that this example uses 3xtf32 on complex gemm. -+ -+ To enable this feature, the only change needs to make is to change OpMultiplyAddComplex -+ to OpMultiplyAddComplexFastF32. -+*/ -+ -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_complex.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+ -+#include "cutlass/util/reference/device/gemm_complex.h" -+#include "cutlass/util/reference/host/tensor_reduce.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/error_metrics.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "helper.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Result structure -+struct Result { -+ -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cudaError_t error; -+ -+ int m, n, k; -+ double l2_norm_3xtf32_vs_fp64; -+ double l2_norm_1xtf32_vs_fp64; -+ double l2_norm_fp32_vs_fp64; -+ -+ // ctor -+ Result( -+ int m, int n, int k, -+ double runtime_ms, double gflops, -+ double l2_norm_3xtf32_vs_fp64, -+ double l2_norm_1xtf32_vs_fp64, -+ double l2_norm_fp32_vs_fp64) : -+ m(m), n(n), k(k), -+ runtime_ms(runtime_ms), gflops(gflops), -+ l2_norm_3xtf32_vs_fp64(l2_norm_3xtf32_vs_fp64), -+ l2_norm_1xtf32_vs_fp64(l2_norm_1xtf32_vs_fp64), -+ l2_norm_fp32_vs_fp64(l2_norm_fp32_vs_fp64) {} -+ -+ Result() {} -+ -+ // -+ // Methods -+ // -+ static void print_csv_header() { -+ std::cout << "M,N,K,Runtime(ms),GFLOPS,3xTF32_vs_FP64,1xTF32_vs_FP64,FP32_vs_FP64" << std::endl; -+ } -+ -+ void print_csv_row() { -+ std::cout << m << "," -+ << n << "," -+ << k << "," -+ << runtime_ms << "," -+ << gflops << "," -+ << l2_norm_3xtf32_vs_fp64 << "," -+ << l2_norm_1xtf32_vs_fp64 << "," -+ << l2_norm_fp32_vs_fp64 << std::endl; -+ } -+}; -+ -+std::vector results; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ -+ cutlass::gemm::GemmCoord problem_size; -+ float alpha; -+ float beta; -+ std::string rand_mode; -+ -+ int iterations; -+ int seed; -+ bool benchmark; -+ -+ Options(): -+ help(false), -+ problem_size({3456, 4096, 4096}), -+ iterations(20), -+ seed(1), -+ alpha(1), -+ beta(), -+ rand_mode("uniform"), -+ benchmark(false) { } -+ -+ bool valid() { -+ return true; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ cmd.get_cmd_line_argument("m", problem_size.m()); -+ cmd.get_cmd_line_argument("n", problem_size.n()); -+ cmd.get_cmd_line_argument("k", problem_size.k()); -+ -+ cmd.get_cmd_line_argument("alpha", alpha); -+ cmd.get_cmd_line_argument("beta", beta); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ cmd.get_cmd_line_argument("seed", seed); -+ cmd.get_cmd_line_argument("rand_mode", rand_mode); -+ -+ if (cmd.check_cmd_line_flag("benchmark")) { -+ benchmark = true; -+ } -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm example\n\n" -+ << " This example uses the CUTLASS Library to emulate FP32 complex GEMM computations with TF32 tensor cores.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --m= GEMM M dimension\n" -+ << " --n= GEMM N dimension\n" -+ << " --k= GEMM K dimension\n" -+ << " --alpha= Epilogue scalar alpha\n" -+ << " --beta= Epilogue scalar beta\n\n" -+ << " --rand_mode= gauss / uniform*\n\n" -+ << " --seed= Random number seed (1*)\n\n" -+ << " --iterations= Number of profiling iterations to perform.\n\n" -+ << " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_ampere_3xtf32_fast_accurate_complex_gemm --m=1024 --n=512 \\\n" -+ << " --alpha=2 --beta=0.707 \n\n"; -+ -+ return out; -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of real-valued multiply-adds -+ int64_t fmas = problem_size.product(); -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// The code section below describes matrix layout of input and output matrices. Column Major for -+// Matrix A, Row Major for Matrix B and Row Major for Matrix C -+using LayoutInputA = cutlass::layout::ColumnMajor; -+using LayoutInputB = cutlass::layout::RowMajor; -+using LayoutOutput = cutlass::layout::RowMajor; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm80; -+ -+// This code section describes the tile size a thread block will compute -+using ShapeMMAThreadBlock = -+ cutlass::gemm::GemmShape<64, 64, 16>; // <- threadblock tile M = 128, N = 128, K = 16 -+// This code section describes tile size a warp will compute -+using ShapeMMAWarp = cutlass::gemm::GemmShape<32, 32, 16>; // <- warp tile M = 64, N = 64, K = 16 -+// This code section describes the size of MMA op -+using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8 -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? -+ -+// This code section describes the epilogue part of the kernel -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ cutlass::complex, // <- data type of output matrix -+ 1, // <- the number of elements per vectorized -+ // memory access. For a byte, it's 16 -+ // elements. This becomes the vector width of -+ // math instructions in the epilogue too -+ cutlass::complex, // <- data type of accumulator -+ cutlass::complex>; // <- data type for alpha/beta in linear combination function -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 3; -+// Transform -+constexpr cutlass::ComplexTransform TransformA = cutlass::ComplexTransform::kNone; -+constexpr cutlass::ComplexTransform TransformB = cutlass::ComplexTransform::kNone; -+ -+// -+// Gemm Operators (Gemm_3xTF32, Gemm_1xTF32, GEMM_F32, GEMM_F64) -+// -+ -+// Gemm_3xTF32 -+using Gemm_3xTF32 = cutlass::gemm::device::GemmComplex< -+ cutlass::complex, -+ LayoutInputA, -+ cutlass::complex, -+ LayoutInputB, -+ cutlass::complex, -+ LayoutOutput, -+ cutlass::complex, -+ MMAOp, -+ SmArch, -+ ShapeMMAThreadBlock, -+ ShapeMMAWarp, -+ ShapeMMAOp, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ TransformA, -+ TransformB, -+ cutlass::arch::OpMultiplyAddComplexFastF32>; -+ -+// Gemm_1xTF32 -+using Gemm_1xTF32 = cutlass::gemm::device::GemmComplex< -+ cutlass::complex, -+ LayoutInputA, -+ cutlass::complex, -+ LayoutInputB, -+ cutlass::complex, -+ LayoutOutput, -+ cutlass::complex, -+ MMAOp, -+ SmArch, -+ ShapeMMAThreadBlock, -+ ShapeMMAWarp, -+ ShapeMMAOp, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ TransformA, -+ TransformB, -+ cutlass::arch::OpMultiplyAddComplex>; -+ -+bool run(Options &options) { -+ -+ // Create a tuple of problem size for matrix multiplication -+ cutlass::gemm::GemmCoord problem_size = options.problem_size; -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /// 1. Initialize F32 Precision input tensors using CUTLASS helper functions -+ //////////////////////////////////////////////////////////////////////////////// -+ cutlass::HostTensor, LayoutInputA> tensor_a_F32(problem_size.mk()); // <- Create matrix A with dimensions M x K -+ cutlass::HostTensor, LayoutInputB> tensor_b_F32(problem_size.kn()); // <- Create matrix B with dimensions K x N -+ cutlass::HostTensor, LayoutOutput> tensor_c_F32(problem_size.mn()); // <- Create matrix C with dimensions M x N -+ cutlass::HostTensor, LayoutOutput> tensor_d_F32(problem_size.mn()); // <- Create matrix D with dimensions M x N -+ -+ if (options.rand_mode == "uniform") { -+ const float min = -1; -+ const float max = 1; -+ // Fill input and output matrices on host using CUTLASS helper functions -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a_F32.host_view(), -+ options.seed, -+ double(max), -+ double(min)); // <- Fill matrix A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b_F32.host_view(), -+ options.seed, -+ double(max), -+ double(min)); // <- Fill matrix B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c_F32.host_view(), -+ options.seed, -+ double(max), -+ double(min)); // <- Fill matrix C on host with uniform-distribution random data -+ } else if (options.rand_mode == "gauss") { -+ // Fill input and output matrices on host using CUTLASS helper functions -+ cutlass::reference::host::TensorFillRandomGaussian( -+ tensor_a_F32.host_view(), -+ options.seed, -+ double(0), -+ double(5)); // <- Fill matrix A on host with gaussian-distribution random data -+ cutlass::reference::host::TensorFillRandomGaussian( -+ tensor_b_F32.host_view(), -+ options.seed, -+ double(0), -+ double(5)); // <- Fill matrix B on host with gaussian-distribution random data -+ cutlass::reference::host::TensorFillRandomGaussian( -+ tensor_c_F32.host_view(), -+ options.seed, -+ double(0), -+ double(5)); // <- Fill matrix C on host with gaussian-distribution random data -+ } -+ cutlass::reference::host::TensorFill( -+ tensor_d_F32.host_view()); // <- fill matrix D on host with zeros -+ -+ // Copy data from host to GPU -+ tensor_a_F32.sync_device(); -+ tensor_b_F32.sync_device(); -+ tensor_c_F32.sync_device(); -+ tensor_d_F32.sync_device(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /// 2. Initialize F64 tensors using the same values used for F32 -+ //////////////////////////////////////////////////////////////////////////////// -+ // Gemm input operands (A, B, C) -+ cutlass::HostTensor, LayoutInputA> tensor_a_F64(problem_size.mk()); // <- Create matrix A with dimensions M x K -+ cutlass::HostTensor, LayoutInputB> tensor_b_F64(problem_size.kn()); // <- Create matrix B with dimensions K x N -+ cutlass::HostTensor, LayoutOutput> tensor_c_F64(problem_size.mn()); // <- Create matrix C with dimensions M x N -+ -+ // Gemm output (D) for GEMM_F64 -+ cutlass::HostTensor, LayoutOutput> tensor_d_F64(problem_size.mn()); // <- Create matrix D with dimensions M x N -+ // Gemm output (D) for GEMM_3xTF32 -+ cutlass::HostTensor, LayoutOutput> tensor_d_3xTF32(problem_size.mn()); // <- Create matrix D with dimensions M x N -+ // Gemm output (D) for GEMM_1xTF32 -+ cutlass::HostTensor, LayoutOutput> tensor_d_1xTF32(problem_size.mn()); // <- Create matrix D with dimensions M x N -+ -+ // Copy values from the DP tensors -+ cutlass::reference::host::TensorCopy(tensor_a_F64.host_view(), tensor_a_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_b_F64.host_view(), tensor_b_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_c_F64.host_view(), tensor_c_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_d_F64.host_view(), tensor_d_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_d_3xTF32.host_view(), tensor_d_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_d_1xTF32.host_view(), tensor_d_F32.host_view()); -+ -+ // Copy data from host to GPU -+ tensor_a_F64.sync_device(); -+ tensor_b_F64.sync_device(); -+ tensor_c_F64.sync_device(); -+ tensor_d_F64.sync_device(); -+ tensor_d_3xTF32.sync_device(); -+ tensor_d_1xTF32.sync_device(); -+ -+ // Initialize alpha and beta for dot product computation -+ cutlass::complex alpha = cutlass::complex(options.alpha); -+ cutlass::complex beta = cutlass::complex(options.beta); -+ -+ // Split K dimension into 1 partitions -+ int split_k_slices = 1; -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /// 3. Run 3xTF32 kernel within a profiling loop -+ //////////////////////////////////////////////////////////////////////////////// -+ // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch -+ // instantiated CUTLASS kernel -+ typename Gemm_3xTF32::Arguments arguments_3xtf32{problem_size, // <- problem size of matrix multiplication -+ tensor_a_F32.device_ref(), // <- reference to matrix A on device -+ tensor_b_F32.device_ref(), // <- reference to matrix B on device -+ tensor_c_F32.device_ref(), // <- reference to matrix C on device -+ tensor_d_3xTF32.device_ref(), // <- reference to matrix D on device -+ {alpha, beta}, // <- tuple of alpha and beta -+ split_k_slices}; // <- k-dimension split factor -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size_3xtf32 = Gemm_3xTF32::get_workspace_size(arguments_3xtf32); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace_3xtf32(workspace_size_3xtf32); -+ -+ // Instantiate CUTLASS kernel depending on templates -+ Gemm_3xTF32 gemm_op; -+ -+ // Check the problem size is supported or not -+ cutlass::Status status_3xtf32 = gemm_op.can_implement(arguments_3xtf32); -+ CUTLASS_CHECK(status_3xtf32); -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ status_3xtf32 = gemm_op.initialize(arguments_3xtf32, workspace_3xtf32.get()); -+ CUTLASS_CHECK(status_3xtf32); -+ -+ // Result structure -+ Result result; -+ -+ // -+ // Construct events -+ // -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return false; -+ } -+ } -+ -+ // Record an event at the start of a series of GEMMs -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return false; -+ } -+ -+ // -+ // Run profiling loop -+ // -+ -+ for (int iter = 0; iter < options.iterations; ++iter) { -+ // Launch initialized CUTLASS kernel -+ status_3xtf32 = gemm_op(); -+ CUTLASS_CHECK(status_3xtf32); -+ } -+ -+ // -+ // Stop profiling loop -+ // -+ -+ // Record an event when the GEMMs are complete -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return false; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return false; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return false; -+ } -+ -+ // Compute average runtime and GFLOPs. -+ result.m = problem_size.m(); -+ result.n = problem_size.n(); -+ result.k = problem_size.k(); -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // Cleanup -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ tensor_d_3xTF32.sync_host(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /// 4. Run TF32 kernel without profiling loop -+ //////////////////////////////////////////////////////////////////////////////// -+ // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch -+ // instantiated CUTLASS kernel -+ typename Gemm_1xTF32::Arguments arguments_1xtf32{problem_size, // <- problem size of matrix multiplication -+ tensor_a_F32.device_ref(), // <- reference to matrix A on device -+ tensor_b_F32.device_ref(), // <- reference to matrix B on device -+ tensor_c_F32.device_ref(), // <- reference to matrix C on device -+ tensor_d_1xTF32.device_ref(), // <- reference to matrix D on device -+ {alpha, beta}, // <- tuple of alpha and beta -+ split_k_slices}; // <- k-dimension split factor -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size_1xtf32 = Gemm_1xTF32::get_workspace_size(arguments_1xtf32); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace_1xtf32(workspace_size_1xtf32); -+ -+ // Instantiate CUTLASS kernel depending on templates -+ Gemm_1xTF32 gemm_op_1xtf32; -+ -+ // Check the problem size is supported or not -+ cutlass::Status status_1xtf32 = gemm_op_1xtf32.can_implement(arguments_1xtf32); -+ CUTLASS_CHECK(status_1xtf32); -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ status_1xtf32 = gemm_op_1xtf32.initialize(arguments_1xtf32, workspace_1xtf32.get()); -+ CUTLASS_CHECK(status_1xtf32); -+ -+ // Launch initialized CUTLASS kernel -+ status_1xtf32 = gemm_op_1xtf32(); -+ CUTLASS_CHECK(status_1xtf32); -+ -+ tensor_d_1xTF32.sync_host(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ // Run reference kernel (F64) -+ //////////////////////////////////////////////////////////////////////////////// -+ -+ // Launch device reference gemm kernel -+ cutlass::reference::device::GemmComplex( -+ problem_size, -+ alpha, -+ tensor_a_F64.device_ref(), -+ TransformA, -+ tensor_b_F64.device_ref(), -+ TransformB, -+ beta, -+ tensor_c_F64.device_ref(), -+ tensor_d_F64.device_ref(), -+ cutlass::complex(0.f)); -+ -+ // Wait for kernels to finish -+ cudaDeviceSynchronize(); -+ -+ // Copy output data from CUTLASS and reference kernel to host for comparison -+ tensor_d_F64.sync_host(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ // Run reference kernel (F32) -+ //////////////////////////////////////////////////////////////////////////////// -+ -+ // Launch device reference gemm kernel -+ cutlass::reference::device::GemmComplex( -+ problem_size, -+ alpha, -+ tensor_a_F32.device_ref(), -+ TransformA, -+ tensor_b_F32.device_ref(), -+ TransformB, -+ beta, -+ tensor_c_F32.device_ref(), -+ tensor_d_F32.device_ref(), -+ cutlass::complex(0.f)); -+ -+ // Wait for kernels to finish -+ cudaDeviceSynchronize(); -+ -+ // Copy output data from CUTLASS and reference kernel to host for comparison -+ tensor_d_F32.sync_host(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /////// Compute l2 norms -+ //////////////////////////////////////////////////////////////////////////////// -+ -+ // l2 norm 3xTF32 vs F64 -+ cutlass::HostTensor, LayoutOutput> tensor_d_3xTF32_in_F64(problem_size.mn()); -+ cutlass::reference::host::TensorCopy(tensor_d_3xTF32_in_F64.host_view(), tensor_d_3xTF32.host_view()); -+ -+ result.l2_norm_3xtf32_vs_fp64 = cutlass::reference::host::TensorRelativeErrorMetric( -+ tensor_d_3xTF32_in_F64.host_view(), tensor_d_F64.host_view()); -+ -+ // l2 norm 1xTF32 vs F64 -+ cutlass::HostTensor, LayoutOutput> tensor_d_1xTF32_in_F64(problem_size.mn()); -+ cutlass::reference::host::TensorCopy(tensor_d_1xTF32_in_F64.host_view(), tensor_d_1xTF32.host_view()); -+ -+ result.l2_norm_1xtf32_vs_fp64 = cutlass::reference::host::TensorRelativeErrorMetric( -+ tensor_d_1xTF32_in_F64.host_view(), tensor_d_F64.host_view()); -+ -+ // l2 norm F32 vs F64 -+ cutlass::HostTensor, LayoutOutput> tensor_d_F32_in_F64(problem_size.mn()); -+ cutlass::reference::host::TensorCopy(tensor_d_F32_in_F64.host_view(), tensor_d_F32.host_view()); -+ -+ result.l2_norm_fp32_vs_fp64 = cutlass::reference::host::TensorRelativeErrorMetric( -+ tensor_d_F32_in_F64.host_view(), tensor_d_F64.host_view()); -+ -+ results.push_back(result); -+ -+ /////////////////////////////////////////////////////////////////////////////// -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ -+ std::cout << std::fixed; -+ std::cout.precision(4); -+ std::cout << "Runtime: " << result.runtime_ms << " ms" << std::endl; -+ std::cout.precision(2); -+ std::cout << "GFLOPs: " << result.gflops << std::endl; -+ std::cout << "Normalized L2 norm of" << std::endl; -+ std::cout.precision(8); -+ std::cout << std::scientific -+ << " - 3xTF32 error with FP64 reference : " << result.l2_norm_3xtf32_vs_fp64 << std::endl -+ << " - 1xTF32 error with FP64 reference : " << result.l2_norm_1xtf32_vs_fp64 << std::endl -+ << " - FP32 error with FP64 reference : " << result.l2_norm_fp32_vs_fp64 << std::endl; -+ -+ return true; -+} -+ -+int main(int argc, const char **argv) { -+ -+ bool notSupported = false; -+ -+ // Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available -+ // in CUDA 11.0. -+ // -+ // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples. -+ if (!(__CUDACC_VER_MAJOR__ >= 11)) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return false; -+ } -+ -+ if (!((props.major * 10 + props.minor) >= 80)) { -+ std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80." -+ << std::endl; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ // Returning zero so this test passes on older Toolkits. Its actions are no-op. -+ return 0; -+ } -+ -+ Options options; -+ options.parse(argc, argv); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ bool result = true; -+ -+ if (options.benchmark) { -+ for (int k = 4; k <= 65536; k *= 2) { -+ -+ options.problem_size[2] = k; -+ -+ printf("Gemm problem size: %d x %d x %d\n", \ -+ options.problem_size.m(), options.problem_size.n(), options.problem_size.k()); -+ -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ result &= run(options); -+ } -+ } else { -+ // Execute one problem size -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ result = run(options); -+ } -+ -+ if (!result) return -1; -+ -+ std::cout << std::endl << "CSV results" << std::endl; -+ Result::print_csv_header(); -+ for(auto &r : results) -+ r.print_csv_row(); -+ -+ return 0; -+} -diff --git a/3rdparty/cutlass/examples/30_wgrad_split_k/30_wgrad_split_k.cu b/3rdparty/cutlass/examples/30_wgrad_split_k/30_wgrad_split_k.cu -new file mode 100644 -index 0000000..e512242 ---- /dev/null -+++ b/3rdparty/cutlass/examples/30_wgrad_split_k/30_wgrad_split_k.cu -@@ -0,0 +1,791 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/* -+This example shows how to compute conv2d gradient with respect to weight (wgrad). In wgrad, the K dimension of -+impligit GEMM, corresponding to the sequential reduction loop, is very large (N * P * Q). Split-k with parallel -+reduction is highly effective for such cases. Given split_k_slices parameter, it partitions the K loop into -+split_k_slices chunks and computes partial reductions in parallel across different blocks. After that, -+a parallel reduction kernel is launched to accumulate partial reductions. -+In practice, wgrad requires fp32 accumulation to avoid overflow. When the input is fp16, some care is needed -+to correctly instantiate the GEMM template. -+*/ -+ -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/conv/kernel/default_conv2d_wgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/device/convolution.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/reduction/device/reduce_split_k.h" -+#include "cutlass/reduction/thread/reduction_operators.h" -+ -+#include "helper.h" -+ -+// The code section below describes datatype for input, output tensors and computation between -+// elements -+// In Wgrad, fp32 accumulation is necessary in practice. -+using ElementAccumulator = float; // Data type of accumulator -+using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta) -+using ElementInputA = cutlass::half_t; // Data type of elements in input tensor -+using ElementInputB = cutlass::half_t; // Data type of elements in input tensor -+using ElementOutput = cutlass::half_t; // Data type of elements in output tensor -+using ElementC = ElementOutput; -+using ElementCompute = ElementComputeEpilogue; -+using LayoutInputA = cutlass::layout::TensorNHWC; -+using LayoutInputB = cutlass::layout::TensorNHWC; -+using LayoutOutput = cutlass::layout::TensorNHWC; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm80; -+ -+// This code section describes the tile size a thread block will compute -+using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock tile shape -+ -+// This code section describes tile size a warp will compute -+using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; // Warp tile shape -+ -+// This code section describes the size of MMA op -+using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // TensorCore instruction shape -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 3; -+ -+// This code section describe iterator algorithm selected is Analytic or Optimized -+static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized; -+ -+// We need two epilogue functors - one for GEMM and another for the final reduction. -+// The epilogue for GEMM is not used, but needed to instantiate the CUTLASS kernel template. -+// Note that, when the input is fp16 and accumulation is fp32, the output of GEMM needs to be fp32, -+// the final reduction is done in fp32, and the reduction epilogue converts fp32 outputs to fp16. -+// Therefore, the output type of the GEMM epilogue is ElementCompute, not ElementOutput. -+ -+// This code section describes the epilogue part of the kernel, we use default value -+using EpilogueOpGEMM = cutlass::epilogue::thread::LinearCombination< -+ ElementCompute, // Data type of output matrix. -+ 128 / cutlass::sizeof_bits::value, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. -+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue>; // Data type for alpha/beta in linear combination -+ -+// The epilogue functor for reduction. This is the one that is actually used. -+using EpilogueOpReduction = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // Data type of output matrix. -+ 128 / cutlass::sizeof_bits::value, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. -+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue>; // Data type for alpha/beta in lin -+ -+using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementInputA, LayoutInputA, -+ ElementInputB, LayoutInputB, -+ ElementAccumulator, LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOpGEMM, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ IteratorAlgorithm -+ >::Kernel; -+ -+using ImplicitGemm = cutlass::conv::device::ImplicitGemmConvolution; -+ -+using EpilogueOutputOp = EpilogueOpReduction; -+ -+/// Reduction kernel -+using ReductionOp = cutlass::reduction::thread::ReduceAdd< -+ ElementAccumulator, -+ typename EpilogueOutputOp::ElementAccumulator, -+ EpilogueOutputOp::kCount -+ >; -+ -+using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< -+ cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, -+ EpilogueOutputOp, -+ ReductionOp -+ >; -+ -+using ReductionDevice = cutlass::reduction::device::ReduceSplitK; -+using ReductionStrideIndex = typename ReductionDevice::StrideIndex; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ cutlass::Tensor4DCoord input_size; -+ cutlass::Tensor4DCoord filter_size; -+ cutlass::Tensor4DCoord padding; -+ cutlass::MatrixCoord conv_stride; -+ cutlass::MatrixCoord dilation; -+ bool reference_check; -+ bool measure_performance; -+ int iterations; -+ bool save_workspace; -+ ElementComputeEpilogue alpha; -+ ElementComputeEpilogue beta; -+ int split_k_slices; -+ bool benchmark; -+ std::string tag; -+ -+ Options(): -+ help(false), -+ input_size(1, 32, 32, 32), -+ filter_size(32, 3, 3, 32), -+ padding(1, 1, 1, 1), -+ conv_stride(1, 1), -+ dilation(1, 1), -+ reference_check(true), -+ measure_performance(false), -+ iterations(20), -+ save_workspace(false), -+ alpha(1), -+ beta(0), -+ split_k_slices(8), -+ benchmark(false) { } -+ -+ // Verify the problem size is compatible with the CUTLASS Convolution implementation. -+ bool valid() { -+ -+ // -+ // CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently, -+ // all pointers, strides, and tensor extents must be divisible by 8 elements. -+ // -+ int const kAlignment = 8; -+ -+ if ((input_size.c() % kAlignment) || -+ (filter_size.n() % kAlignment)) { -+ -+ // misaligned tensors -+ return false; -+ } -+ -+ // Invalid padding -+ if ((padding.h() != filter_size.h() / 2) || -+ (padding.w() != filter_size.w() / 2)) { -+ -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Updates input and filter sizes -+ void update( -+ cutlass::Tensor4DCoord input_size, -+ cutlass::Tensor4DCoord filter_size, -+ cutlass::MatrixCoord stride) { -+ -+ this->input_size = input_size; -+ this->filter_size = filter_size; -+ conv_stride = stride; -+ -+ padding.n() = filter_size.h() / 2; -+ padding.h() = filter_size.h() / 2; -+ padding.w() = filter_size.w() / 2; -+ padding.c() = filter_size.w() / 2; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("ref-check")) { -+ reference_check = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("perf-check")) { -+ measure_performance = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("save-workspace")) { -+ save_workspace = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("benchmark")) { -+ benchmark = true; -+ } -+ -+ cmd.get_cmd_line_argument("n", input_size.n()); -+ cmd.get_cmd_line_argument("h", input_size.h()); -+ cmd.get_cmd_line_argument("w", input_size.w()); -+ cmd.get_cmd_line_argument("c", input_size.c()); -+ -+ cmd.get_cmd_line_argument("k", filter_size.n()); -+ cmd.get_cmd_line_argument("r", filter_size.h()); -+ cmd.get_cmd_line_argument("s", filter_size.w()); -+ filter_size.c() = input_size.c(); -+ -+ cmd.get_cmd_line_argument("alpha", alpha); -+ cmd.get_cmd_line_argument("beta", beta); -+ cmd.get_cmd_line_argument("split-k-slices", split_k_slices); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ cmd.get_cmd_line_argument("tag", tag); -+ -+ if (filter_size.h() == 3 && filter_size.w() == 3) { -+ padding = {1, 1, 1, 1}; -+ } -+ else { -+ filter_size.h() = 1; -+ filter_size.w() = 1; -+ padding = {0, 0, 0, 0}; -+ } -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "30_wgrad_split_k example\n\n" -+ << " This example shows how to compute conv2d gradient with respect to weight (wgrad).\n" -+ << " In wgrad, the K dimension of impligit GEMM, corresponding to the sequential reduction loop, is very large (N * P * Q).\n" -+ << " Split-k with parallel reduction is highly effective for such cases.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --n= Input tensor extent N\n" -+ << " --h= Input tensor extent H\n" -+ << " --w= Input tensor extent W\n" -+ << " --c= Input tensor extent C\n" -+ << " --k= Filter extent K\n" -+ << " --r= Filter extent R\n" -+ << " --s= Filter extent S\n\n" -+ << " --alpha= Epilogue scalar alpha\n" -+ << " --beta= Epilogue scalar beta\n\n" -+ << " --split-k-slices= Split-k factor \n\n" -+ << " --ref-check If set (true), reference check on the host is computed\n" -+ << " --perf-check If set (true), performance is measured.\n" -+ << " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n" -+ << " --iterations= Number of profiling iterations to perform.\n" -+ << " --save-workspace If set, workspace is written to a text file.\n" -+ << " --tag= String to replicate across the first column in the results table\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/30_wgrad_split_k/30_wgrad_split_k --n=32 --h=224 --w=224 --c=128 --k=256 --r=3 --s=3 --split-k-slices=8\n\n"; -+ -+ return out; -+ } -+ -+ /// Computes the output tensor size (NPQK) -+ cutlass::Tensor4DCoord output_size() const { -+ return cutlass::Tensor4DCoord(input_size.n(), -+ (input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1, -+ (input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1, -+ filter_size.n()); -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of multiply-adds = NPQK * CRS -+ int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c()); -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct Result { -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cutlass::Status reference_check; -+ cudaError_t error; -+ -+ Result(): -+ runtime_ms(0), -+ gflops(0), -+ status(cutlass::Status::kSuccess), -+ reference_check(cutlass::Status::kInvalid), -+ error(cudaSuccess) { } -+ -+ static std::ostream & print_header(std::ostream &out, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << "Name,"; -+ } -+ -+ out << "Layer,N,H,W,C,K,R,S,Stride_H,Stride_W,Runtime,GFLOPs"; -+ -+ return out; -+ } -+ -+ std::ostream & print(std::ostream &out, int idx, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << options.tag << ","; -+ } -+ -+ out -+ << "conv_" << idx << "," -+ << options.input_size.n() << "," -+ << options.input_size.h() << "," -+ << options.input_size.w() << "," -+ << options.input_size.c() << "," -+ << options.filter_size.n() << "," -+ << options.filter_size.h() << "," -+ << options.filter_size.w() << "," -+ << options.conv_stride.row() << "," -+ << options.conv_stride.column() << "," -+ << runtime_ms << "," -+ << gflops; -+ -+ return out; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Runs one benchmark -+Result profile_convolution(Options const &options) { -+ -+ Result result; -+ -+ // -+ // Allocate host-device tensors using the CUTLASS Utilities. -+ // -+ -+ // Inputs are the output gradient and the original activation. -+ cutlass::HostTensor tensor_a(options.output_size()); -+ cutlass::HostTensor tensor_b(options.input_size); -+ cutlass::HostTensor tensor_c(options.filter_size); -+ cutlass::HostTensor tensor_d(options.filter_size); -+ cutlass::HostTensor tensor_ref_d(options.filter_size); -+ -+ // -+ // Initialize tensors -+ // -+ -+ // Fill tensor A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1, -+ ElementInputA(7), -+ ElementInputA(-8), -+ 0); -+ -+ // Fill tensor B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 1, -+ ElementInputB(7), -+ ElementInputB(-8), -+ 0); -+ -+ // Fill tensor C, D on host with zeros -+ cutlass::reference::host::TensorFill(tensor_c.host_view()); -+ -+ cutlass::reference::host::TensorFill(tensor_d.host_view()); -+ -+ // Fill tensor D for reference on host with zeros -+ cutlass::reference::host::TensorFill(tensor_ref_d.host_view()); -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_c.sync_device(); -+ tensor_d.sync_device(); -+ tensor_ref_d.sync_device(); -+ -+ // -+ // Define arguments for CUTLASS Convolution -+ // -+ -+ cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; -+ -+ // Partition the GEMM K loop into split_k_slices chunks -+ int split_k_slices = options.split_k_slices; -+ -+ // Construct Conv2dProblemSize with user defined output size -+ // Do not forget to pass the last argument. -+ cutlass::conv::Conv2dProblemSize problem_size( -+ options.input_size, -+ options.filter_size, -+ options.padding, -+ options.conv_stride, -+ options.dilation, -+ options.output_size(), -+ mode, -+ split_k_slices -+ ); -+ -+ using cutlass::layout::TensorNHWC; -+ -+ cutlass::conv::SplitKMode const split_k_mode = cutlass::conv::SplitKMode::kParallel; -+ -+ // Since the epilogue is not computed after GEMM, there is no need to pass the C tensor and -+ // alpha and beta can be set to 1 and 0 respectively. -+ // Moreover, since the output will be written to the workspace, there is no need to pass -+ // the D tensor as well. -+ // Do not forget to pass the last argument. -+ typename ImplicitGemm::Arguments arguments{ -+ problem_size, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ {nullptr, TensorNHWC()}, -+ {nullptr, TensorNHWC()}, -+ {ElementCompute(1), ElementCompute(0)}, -+ split_k_mode -+ }; -+ -+ // -+ // Initialize CUTLASS Convolution -+ // -+ -+ ImplicitGemm implicit_gemm; -+ -+ size_t workspace_size = implicit_gemm.get_workspace_size(arguments); -+ -+ // Split-K requires non-zero workspace size. The workspace size grows linearly with split_k_slices. -+ std::cout << "split-k-slices: " << split_k_slices << std::endl; -+ std::cout << "workspace size: " << workspace_size << std::endl; -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ result.status = implicit_gemm.can_implement(arguments); -+ CUTLASS_CHECK(result.status); -+ -+ // After the workspace is allocated, we point the GEMM destination pointer to the workspace. -+ TensorNHWC layout_D{TensorNHWC::packed(options.filter_size)}; -+ arguments.ref_D.reset(reinterpret_cast(workspace.get()), layout_D); -+ -+ result.status = implicit_gemm.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Launch initialized CUTLASS kernel -+ // -+ result.status = implicit_gemm(); -+ -+ CUTLASS_CHECK(result.status); -+ -+ if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { -+ // Do reduction -+ ReductionDevice reduction_op; -+ auto& status = result.status; -+ static cutlass::conv::Operator const kConvolutionalOperator = ImplicitGemm::kConvolutionalOperator; -+ typename ReductionDevice::Arguments reduction_args( -+ cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(), -+ problem_size.split_k_slices, -+ cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size), -+ // Reduction input -+ { -+ reinterpret_cast (workspace.get()), -+ ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::UnderlyingKernel::kTensorCStrideIdx]) -+ }, -+ // Destination -+ { -+ tensor_d.device_data(), -+ ReductionStrideIndex(tensor_d.stride()[ImplicitGemm::UnderlyingKernel::kTensorCStrideIdx]) -+ }, -+ // Source -+ { -+ tensor_c.device_data(), -+ ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::UnderlyingKernel::kTensorCStrideIdx]) -+ }, -+ {options.alpha, options.beta} -+ ); -+ -+ status = reduction_op.initialize(reduction_args, nullptr); -+ status = reduction_op(); -+ } -+ -+ // -+ // Optional reference check -+ // -+ -+ if (options.reference_check) { -+ std::cout << "Verification on device...\n"; -+ -+ // Compute with reference implementation -+ cutlass::reference::device::Conv2dWgrad< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementComputeEpilogue, -+ ElementAccumulator, -+ cutlass::NumericConverter -+ >( -+ problem_size, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ tensor_c.device_ref(), -+ tensor_ref_d.device_ref(), -+ options.alpha, -+ options.beta -+ ); -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ tensor_c.sync_host(); -+ tensor_d.sync_host(); -+ tensor_ref_d.sync_host(); -+ -+ bool passed = cutlass::reference::host::TensorEquals(tensor_d.host_view(), tensor_ref_d.host_view()); -+ -+ if (!passed) { -+ result.reference_check = cutlass::Status::kErrorInternal; -+ std::cout << "ERROR - results miscompared.\n"; -+ } -+ else { -+ result.reference_check = cutlass::Status::kSuccess; -+ std::cout << "Passed.\n"; -+ } -+ } -+ else { -+ result.reference_check = cutlass::Status::kInvalid; -+ } -+ -+ if (options.save_workspace) { -+ -+ std::stringstream ss; -+ -+ ss << "26_ampere_fused_wgrad_batch_normalization_" -+ << options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c() -+ << "_" -+ << options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c() -+ << ".dat"; -+ -+ std::ofstream output_workspace(ss.str()); -+ -+ output_workspace -+ << "Input = \n" << tensor_a.host_view() << "\n\n" -+ << "Filters = \n" << tensor_b.host_view() << "\n\n"; -+ -+ if (options.reference_check) { -+ output_workspace << "Reference = \n" << tensor_ref_d.host_view() << "\n\n"; -+ } -+ -+ output_workspace << "Computed = \n" << tensor_c.host_view() << std::endl; -+ -+ std::cout << "Results written to '" << ss.str() << "'." << std::endl; -+ } -+ -+ // -+ // Performance measurement -+ // -+ -+ if (options.measure_performance) { -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ } -+ -+ // Record an event at the start of a series of convolution operations. -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Launch a sequence of implicit GEMM operations on the device -+ for (int iteration = 0; iteration < options.iterations; ++iteration) { -+ result.status = implicit_gemm(); -+ CUTLASS_CHECK(result.status); -+ } -+ -+ // Record an event when the convolutions have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Print average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // Cleanup -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ } -+ -+ return result; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ bool notSupported = false; -+ -+ // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. -+ // -+ // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples. -+ if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); -+ -+ if (!(props.major >= 8)) { -+ std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80." -+ << std::endl; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ return 0; -+ } -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ if (options.benchmark) { -+ // Benchmark several layers -+ -+ int batch_sizes[] = {34, 408}; -+ -+ struct Benchmark { -+ int h, w, c, k, r, s, stride_h, stride_w; -+ } layers[] = { -+ {56, 56, 64, 256, 1, 1, 1, 1}, -+ {56, 56, 64, 64, 1, 1, 1, 1}, -+ {56, 56, 64, 64, 3, 3, 1, 1}, -+ {56, 56, 256, 64, 1, 1, 1, 1}, -+ {56, 56, 256, 512, 1, 1, 2, 2}, -+ {56, 56, 256, 128, 1, 1, 1, 1}, -+ {56, 56, 128, 128, 3, 3, 2, 2}, -+ {28, 28, 128, 512, 1, 1, 1, 1}, -+ {28, 28, 512, 128, 1, 1, 1, 1}, -+ {28, 28, 128, 128, 3, 3, 1, 1}, -+ {28, 28, 512, 1024, 1, 1, 2, 2}, -+ {28, 28, 512, 256, 1, 1, 1, 1}, -+ {28, 28, 256, 256, 3, 3, 2, 2}, -+ {14, 14, 256, 1024, 1, 1, 1, 1}, -+ {14, 14, 1024, 256, 1, 1, 1, 1}, -+ {14, 14, 256, 256, 3, 3, 1, 1}, -+ {14, 14, 1024, 2048, 1, 1, 2, 2}, -+ {14, 14, 1024, 512, 1, 1, 1, 1}, -+ {14, 14, 512, 512, 3, 3, 2, 2}, -+ { 7, 7, 512, 2048, 1, 1, 1, 1}, -+ { 7, 7, 2048, 512, 1, 1, 1, 1}, -+ { 7, 7, 512, 512, 3, 3, 1, 1}, -+ }; -+ -+ Result::print_header(std::cout, options) << std::endl; -+ -+ int idx = 1; -+ -+ for (auto const &layer : layers) { -+ for (auto N : batch_sizes) { -+ options.update({N, layer.h, layer.w, layer.c}, -+ {layer.k, layer.r, layer.s, layer.c}, -+ {layer.stride_h, layer.stride_w}); -+ -+ Result result = profile_convolution(options); -+ result.print(std::cout, idx, options) << std::endl; -+ } -+ -+ ++idx; -+ } -+ } -+ else { -+ -+ // Execute one problem size -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ Result result = profile_convolution(options); -+ -+ Result::print_header(std::cout, options) << std::endl; -+ result.print(std::cout, 1, options) << std::endl; -+ } -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/31_basic_syrk/basic_syrk.cu b/3rdparty/cutlass/examples/31_basic_syrk/basic_syrk.cu -new file mode 100644 -index 0000000..82f4a6a ---- /dev/null -+++ b/3rdparty/cutlass/examples/31_basic_syrk/basic_syrk.cu -@@ -0,0 +1,522 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/* -+ This example demonstrates how to call a CUTLASS SYRK kernel and provides a naive reference -+ matrix multiply kernel to verify its correctness. -+ -+ The CUTLASS Syrk template is instantiated in the function CutlassSsyrkNN. This is kernel computes -+ the symmetric rank-k update (SYRK) using double-precision doubleing-point arithmetic and assumes -+ all matrices have column-major layout. -+ -+ The threadblock tile size is chosen as 16x32x16 which offers good performance for large matrices. -+ See the CUTLASS Parallel for All blog post for more exposition on the tunable parameters available -+ in CUTLASS. -+ -+ https://devblogs.nvidia.com/cutlass-linear-algebra-cuda/ -+ -+ Aside from defining and launching the SSYRK kernel, this example does not use any other components -+ or utilities within CUTLASS. Such utilities are demonstrated elsewhere in other examples and are -+ prevalent in the CUTLASS unit tests. -+ -+*/ -+ -+// Standard Library includes -+#include -+#include -+#include -+ -+// Helper methods to check for errors -+#include "helper.h" -+ -+// -+// CUTLASS includes needed for double-precision SYRK kernel -+// -+ -+// Defines cutlass::gemm::device::Syrk, the generic Syrk computation template class. -+#include "cutlass/gemm/device/rank_k.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// This function defines a CUTLASS SYRK kernel instantiation, constructs its parameters object, -+// and launches it on the CUDA device. -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Define a CUTLASS SYRK template and launch a SYRK kernel. -+cudaError_t CutlassSsyrkNN( -+ int N, -+ int K, -+ double alpha, -+ double const *A, -+ int lda, -+ double beta, -+ double *C, -+ int ldc) { -+ -+ // Define type definition for double-precision CUTLASS SYRK with column-major -+ // input matrices and 16x32x16 threadblock tile size (chosen by default). -+ // -+ // To keep the interface manageable, several helpers are defined for plausible compositions -+ // including the following example for double-precision SYRK. Typical values are used as -+ // default template arguments. -+ // -+ // To view the full syrk device API interface, see `cutlass/gemm/device/syrk.h` -+ -+ using ColumnMajor = cutlass::layout::ColumnMajor; -+ -+ using CutlassSyrk = cutlass::gemm::device::RankK< -+ double, -+ ColumnMajor, -+ double, -+ ColumnMajor, -+ cutlass::FillMode::kLower, -+ double, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<16, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ double, -+ 1, -+ double, -+ double -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 5, // Stages -+ 1, // AligmentA -+ false, // SplitKSerail -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ // Define a CUTLASS SYRK type -+ CutlassSyrk syrk_operator; -+ -+ // Construct the CUTLASS SYRK arguments object. -+ // -+ // One of CUTLASS's design patterns is to define syrk argument objects that are constructible -+ // in host code and passed to kernels by value. These may include pointers, strides, scalars, -+ // and other arguments needed by Syrk and its components. -+ // -+ // The benefits of this pattern are (1.) a structured, composable strategy for passing host-constructible -+ // arguments to kernels and (2.) minimized initialization overhead on kernel entry. -+ // -+ CutlassSyrk::Arguments args(cutlass::gemm::GemmUniversalMode::kGemm, -+ {N, N, K}, // Syrk Problem dimensions -+ 1, // batch_count, -+ {alpha, beta}, // Scalars used in the Epilogue -+ reinterpret_cast(A), -+ const_cast(reinterpret_cast(C)), -+ reinterpret_cast(C), // destination matrix D (may be different memory than source C matrix) -+ (int64_t)N*K, // Batch strides -+ (int64_t)N*N, -+ (int64_t)N*N, -+ lda, -+ ldc, -+ ldc); -+ -+ // -+ // Launch the CUTLASS SYRK kernel. -+ // -+ -+ cutlass::Status status = syrk_operator(args); -+ -+ // -+ // Return a cudaError_t if the CUTLASS SYRK operator returned an error code. -+ // -+ -+ if (status != cutlass::Status::kSuccess) { -+ return cudaErrorUnknown; -+ } -+ -+ // Return success, if no errors were encountered. -+ return cudaSuccess; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// The source code after this point in the file is generic CUDA using the CUDA Runtime API -+// and simple CUDA kernels to initialize matrices and compute the general matrix product. -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Kernel to initialize a matrix with small integers. -+__global__ void InitializeMatrix_kernel( -+ double *matrix, -+ int ldm, -+ int rows, -+ int columns, -+ int seed = 0) { -+ -+ int i = threadIdx.x + blockIdx.x * blockDim.x; -+ int j = threadIdx.y + blockIdx.y * blockDim.y; -+ -+ if (i < rows && j < columns) { -+ int offset = i + j * ldm; -+ -+ // Generate arbitrary elements. -+ int const k = 16807; -+ int const m = 16; -+ double value = double(((offset + seed) * k % m) - m / 2); -+ -+ matrix[offset] = value; -+ } -+} -+ -+/// Simple function to initialize a matrix to arbitrary small integers. -+cudaError_t InitializeMatrix(double *matrix, int ldm, int rows, int columns, int seed = 0) { -+ -+ dim3 block(16, 16); -+ dim3 grid( -+ (rows + block.x - 1) / block.x, -+ (columns + block.y - 1) / block.y -+ ); -+ -+ InitializeMatrix_kernel<<< grid, block >>>(matrix, ldm, rows, columns, seed); -+ -+ return cudaGetLastError(); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Allocates device memory for a matrix then fills with arbitrary small integers. -+cudaError_t AllocateMatrix(double **matrix, int ldm, int rows, int columns, int seed = 0) { -+ cudaError_t result; -+ -+ size_t sizeof_matrix = sizeof(double) * ldm * columns; -+ -+ // Allocate device memory. -+ result = cudaMalloc(reinterpret_cast(matrix), sizeof_matrix); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to allocate matrix: " -+ << cudaGetErrorString(result) << std::endl; -+ return result; -+ } -+ -+ // Clear the allocation. -+ result = cudaMemset(*matrix, 0, sizeof_matrix); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to clear matrix device memory: " -+ << cudaGetErrorString(result) << std::endl; -+ return result; -+ } -+ -+ // Initialize matrix elements to arbitrary small integers. -+ result = InitializeMatrix(*matrix, ldm, rows, columns, seed); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to initialize matrix: " -+ << cudaGetErrorString(result) << std::endl; -+ return result; -+ } -+ -+ return result; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Naive reference SYRK computation. -+__global__ void ReferenceSyrk_kernel( -+ int N, -+ int K, -+ double alpha, -+ double const *A, -+ int lda, -+ double beta, -+ double *C, -+ int ldc) { -+ -+ int i = threadIdx.x + blockIdx.x * blockDim.x; -+ int j = threadIdx.y + blockIdx.y * blockDim.y; -+ -+ if (i < N && j < N && i >= j ) { // Since C is in Lower Fill Mode -+ double accumulator = 0; -+ -+ for (int k = 0; k < K; ++k) { -+ accumulator += A[i + k * lda] * A[j + k * lda]; -+ } -+ -+ C[i + j * ldc] = alpha * accumulator + beta * C[i + j * ldc]; -+ } -+} -+ -+/// Reference SYRK computation. -+cudaError_t ReferenceSyrk( -+ int N, -+ int K, -+ double alpha, -+ double const *A, -+ int lda, -+ double beta, -+ double *C, -+ int ldc) { -+ -+ dim3 block(16, 16); -+ dim3 grid( -+ (N + block.x - 1) / block.x, -+ (N + block.y - 1) / block.y -+ ); -+ -+ ReferenceSyrk_kernel<<< grid, block >>>(N, K, alpha, A, lda, beta, C, ldc); -+ -+ return cudaGetLastError(); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Allocate several matrices in GPU device memory and call a double-precision -+/// CUTLASS SYRK kernel. -+cudaError_t TestCutlassSyrk(int N, int K, double alpha, double beta) { -+ cudaError_t result; -+ -+ // -+ // Define several matrices to be used as operands to SYRK kernels. -+ // -+ -+ // Compute leading dimensions for each matrix. -+ int lda = N; -+ int ldc = N; -+ -+ // Compute size in bytes of the C matrix. -+ size_t sizeof_C = sizeof(double) * ldc * N; -+ -+ // Define pointers to matrices in GPU device memory. -+ double *A; -+ double *C_cutlass; -+ double *C_reference; -+ -+ // -+ // Allocate matrices in GPU device memory with arbitrary seeds. -+ // -+ -+ result = AllocateMatrix(&A, lda, N, K, 0); -+ -+ if (result != cudaSuccess) { -+ return result; -+ } -+ -+ result = AllocateMatrix(&C_cutlass, ldc, N, N, 101); -+ -+ if (result != cudaSuccess) { -+ cudaFree(A); -+ return result; -+ } -+ -+ result = AllocateMatrix(&C_reference, ldc, N, N, 101); -+ -+ if (result != cudaSuccess) { -+ cudaFree(A); -+ cudaFree(C_cutlass); -+ return result; -+ } -+ -+ result = cudaMemcpy(C_reference, C_cutlass, sizeof_C, cudaMemcpyDeviceToDevice); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to copy C_cutlass matrix to C_reference: " -+ << cudaGetErrorString(result) << std::endl; -+ -+ cudaFree(C_reference); -+ cudaFree(C_cutlass); -+ cudaFree(A); -+ -+ return result; -+ } -+ -+ // -+ // Launch CUTLASS SYRK. -+ // -+ -+ result = CutlassSsyrkNN(N, K, alpha, A, lda, beta, C_cutlass, ldc); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "CUTLASS SYRK kernel failed: " -+ << cudaGetErrorString(result) << std::endl; -+ -+ cudaFree(C_reference); -+ cudaFree(C_cutlass); -+ cudaFree(A); -+ -+ return result; -+ } -+ -+ // -+ // Verify. -+ // -+ -+ // Launch reference SYRK -+ result = ReferenceSyrk(N, K, alpha, A, lda, beta, C_reference, ldc); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Reference SYRK kernel failed: " -+ << cudaGetErrorString(result) << std::endl; -+ -+ cudaFree(C_reference); -+ cudaFree(C_cutlass); -+ cudaFree(A); -+ -+ return result; -+ } -+ -+ // Copy to host and verify equivalence. -+ std::vector host_cutlass(ldc * N, 0); -+ std::vector host_reference(ldc * N, 0); -+ -+ result = cudaMemcpy(host_cutlass.data(), C_cutlass, sizeof_C, cudaMemcpyDeviceToHost); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to copy CUTLASS SYRK results: " -+ << cudaGetErrorString(result) << std::endl; -+ -+ cudaFree(C_reference); -+ cudaFree(C_cutlass); -+ cudaFree(A); -+ -+ return result; -+ } -+ -+ result = cudaMemcpy(host_reference.data(), C_reference, sizeof_C, cudaMemcpyDeviceToHost); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to copy Reference SYRK results: " -+ << cudaGetErrorString(result) << std::endl; -+ -+ cudaFree(C_reference); -+ cudaFree(C_cutlass); -+ cudaFree(A); -+ -+ return result; -+ } -+ -+ // -+ // Free device memory allocations. -+ // -+ -+ cudaFree(C_reference); -+ cudaFree(C_cutlass); -+ cudaFree(A); -+ -+ // -+ // Test for bit equivalence of results. -+ // -+ -+ if (host_cutlass != host_reference) { -+ std::cerr << "CUTLASS results incorrect." << std::endl; -+ -+ return cudaErrorUnknown; -+ } -+ -+ return cudaSuccess; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Entry point to basic_syrk example. -+// -+// usage: -+// -+// 00_basic_syrk -+// -+int main(int argc, const char *arg[]) { -+ -+ bool notSupported = false; -+ -+ // CUTLASS must be compiled with CUDA 11 Toolkit to run these examples. -+ if (!(__CUDACC_VER_MAJOR__ >= 11)) { -+ std::cerr << "NVIDIA Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ -+ return -1; -+ } -+ -+ if (!((props.major * 10 + props.minor) >= 80)) { -+ -+ std::cerr << "This example requires compute capability at least 80." -+ << std::endl; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ return 0; -+ } -+ -+ // -+ // Parse the command line to obtain SYRK dimensions and scalar values. -+ // -+ -+ // SYRK problem dimensions. -+ int problem[2] = { 128, 128 }; -+ -+ for (int i = 1; i < argc && i < 3; ++i) { -+ std::stringstream ss(arg[i]); -+ ss >> problem[i - 1]; -+ } -+ -+ // Scalars used for linear scaling the result of the matrix product. -+ double scalars[2] = { 1, 0 }; -+ -+ for (int i = 3; i < argc && i < 5; ++i) { -+ std::stringstream ss(arg[i]); -+ ss >> scalars[i - 3]; -+ } -+ -+ // -+ // Run the CUTLASS SYRK test. -+ // -+ -+ cudaError_t result = TestCutlassSyrk( -+ problem[0], // SYRK N dimension -+ problem[1], // SYRK K dimension -+ scalars[0], // alpha -+ scalars[1] // beta -+ ); -+ -+ if (result == cudaSuccess) { -+ std::cout << "Passed." << std::endl; -+ } -+ -+ // Exit. -+ return result == cudaSuccess ? 0 : -1; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/32_basic_trmm/basic_trmm.cu b/3rdparty/cutlass/examples/32_basic_trmm/basic_trmm.cu -new file mode 100644 -index 0000000..74f5cb9 ---- /dev/null -+++ b/3rdparty/cutlass/examples/32_basic_trmm/basic_trmm.cu -@@ -0,0 +1,550 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/* -+ This example demonstrates how to call a CUTLASS TRMM kernel and provides a naive reference -+ matrix multiply kernel to verify its correctness. -+ -+ The CUTLASS Trmm template is instantiated in the function CutlassStrmmNN. This is kernel computes -+ the triangular matrix product (TRMM) using double-precision doubleing-point arithmetic and assumes -+ all matrices have column-major layout. -+ -+ The threadblock tile size is chosen as 64x64x16 which offers good performance for large matrices. -+ See the CUTLASS Parallel for All blog post for more exposition on the tunable parameters available -+ in CUTLASS. -+ -+ https://devblogs.nvidia.com/cutlass-linear-algebra-cuda/ -+ -+ Aside from defining and launching the STRMM kernel, this example does not use any other components -+ or utilities within CUTLASS. Such utilities are demonstrated elsewhere in other examples and are -+ prevalent in the CUTLASS unit tests. -+ -+*/ -+ -+// Standard Library includes -+#include -+#include -+#include -+ -+// Helper methods to check for errors -+#include "helper.h" -+ -+// -+// CUTLASS includes needed for double-precision TRMM kernel -+// -+ -+// Defines cutlass::gemm::device::Trmm, the generic Trmm computation template class. -+#include "cutlass/gemm/device/trmm.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// This function defines a CUTLASS TRMM kernel instantiation, constructs its parameters object, -+// and launches it on the CUDA device. -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Define a CUTLASS TRMM template and launch a TRMM kernel. -+cudaError_t CutlassStrmmNN( -+ int M, -+ int N, -+ double alpha, -+ double const *A, -+ int lda, -+ double const *B, -+ int ldb, -+ double *C, -+ int ldc) { -+ -+ // Define type definition for double-precision CUTLASS TRMM with column-major -+ // input matrices and 64x64x16 threadblock tile size (chosen by default). -+ // -+ // To keep the interface manageable, several helpers are defined for plausible compositions -+ // including the following example for double-precision TRMM. Typical values are used as -+ // default template arguments. -+ // -+ // To view the full trmm device API interface, see `cutlass/gemm/device/trmm.h` -+ -+ using ColumnMajor = cutlass::layout::ColumnMajor; -+ -+ using CutlassTrmm = cutlass::gemm::device::Trmm< -+ double, -+ ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ ColumnMajor, -+ double, -+ ColumnMajor, -+ double, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ double, -+ 1, -+ double, -+ double, -+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 5, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ // Define a CUTLASS TRMM type -+ CutlassTrmm trmm_operator; -+ -+ // Construct the CUTLASS TRMM arguments object. -+ // -+ // One of CUTLASS's design patterns is to define trmm argument objects that are constructible -+ // in host code and passed to kernels by value. These may include pointers, strides, scalars, -+ // and other arguments needed by Trmm and its components. -+ // -+ // The benefits of this pattern are (1.) a structured, composable strategy for passing host-constructible -+ // arguments to kernels and (2.) minimized initialization overhead on kernel entry. -+ // -+ CutlassTrmm::Arguments args(cutlass::gemm::GemmUniversalMode::kGemm, -+ {M, N, M}, // Trmm Problem dimensions in Left-Side Mode -+ 1, // batch_count, -+ {alpha}, // Scalars used in the Epilogue -+ reinterpret_cast(A), -+ reinterpret_cast(B), -+ reinterpret_cast(C), // destination matrix D (may be different memory than source C matrix) -+ (int64_t)M*M, // Batch strides -+ (int64_t)M*N, -+ (int64_t)M*N, -+ lda, -+ ldb, -+ ldc); -+ -+ // -+ // Launch the CUTLASS TRMM kernel. -+ // -+ -+ cutlass::Status status = trmm_operator(args); -+ -+ // -+ // Return a cudaError_t if the CUTLASS TRMM operator returned an error code. -+ // -+ -+ if (status != cutlass::Status::kSuccess) { -+ return cudaErrorUnknown; -+ } -+ -+ // Return success, if no errors were encountered. -+ return cudaSuccess; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// The source code after this point in the file is generic CUDA using the CUDA Runtime API -+// and simple CUDA kernels to initialize matrices and compute the general matrix product. -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Kernel to initialize a matrix with small integers. -+__global__ void InitializeMatrix_kernel( -+ double *matrix, -+ int ldm, -+ int rows, -+ int columns, -+ int seed = 0, -+ cutlass::FillMode fill_mode = cutlass::FillMode::kInvalid) { -+ -+ int i = threadIdx.x + blockIdx.x * blockDim.x; -+ int j = threadIdx.y + blockIdx.y * blockDim.y; -+ -+ if (i < rows && j < columns) { -+ if (fill_mode == cutlass::FillMode::kLower && i < j) return; -+ else if (fill_mode == cutlass::FillMode::kUpper && i > j) return; -+ int offset = i + j * ldm; -+ -+ // Generate arbitrary elements. -+ int const k = 16807; -+ int const m = 16; -+ double value = double(((offset + seed) * k % m) - m / 2); -+ -+ matrix[offset] = value; -+ } -+} -+ -+/// Simple function to initialize a matrix to arbitrary small integers. -+cudaError_t InitializeMatrix(double *matrix, int ldm, int rows, int columns, int seed = 0, -+ cutlass::FillMode fill_mode = cutlass::FillMode::kInvalid) { -+ -+ dim3 block(16, 16); -+ dim3 grid( -+ (rows + block.x - 1) / block.x, -+ (columns + block.y - 1) / block.y -+ ); -+ -+ InitializeMatrix_kernel<<< grid, block >>>(matrix, ldm, rows, columns, seed, fill_mode); -+ -+ return cudaGetLastError(); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Allocates device memory for a matrix then fills with arbitrary small integers. -+cudaError_t AllocateMatrix(double **matrix, int ldm, int rows, int columns, int seed = 0, -+ cutlass::FillMode fill_mode = cutlass::FillMode::kInvalid) { -+ cudaError_t result; -+ -+ size_t sizeof_matrix = sizeof(double) * ldm * columns; -+ -+ // Allocate device memory. -+ result = cudaMalloc(reinterpret_cast(matrix), sizeof_matrix); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to allocate matrix: " -+ << cudaGetErrorString(result) << std::endl; -+ return result; -+ } -+ -+ // Clear the allocation. -+ result = cudaMemset(*matrix, 0, sizeof_matrix); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to clear matrix device memory: " -+ << cudaGetErrorString(result) << std::endl; -+ return result; -+ } -+ -+ // Initialize matrix elements to arbitrary small integers. -+ result = InitializeMatrix(*matrix, ldm, rows, columns, seed, fill_mode); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to initialize matrix: " -+ << cudaGetErrorString(result) << std::endl; -+ return result; -+ } -+ -+ return result; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Naive reference TRMM computation. -+__global__ void ReferenceTrmm_kernel( -+ int M, -+ int N, -+ double alpha, -+ double const *A, -+ int lda, -+ double const *B, -+ int ldb, -+ double *C, -+ int ldc) { -+ -+ int i = threadIdx.x + blockIdx.x * blockDim.x; -+ int j = threadIdx.y + blockIdx.y * blockDim.y; -+ -+ if (i < M && j < N) { -+ double accumulator = 0; -+ -+ for (int k = 0; k < M; ++k) { -+ accumulator += A[i + k * lda] * B[k + j * ldb]; // Since A is in Left-Side Mode -+ } -+ -+ C[i + j * ldc] = alpha * accumulator; -+ } -+} -+ -+/// Reference TRMM computation. -+cudaError_t ReferenceTrmm( -+ int M, -+ int N, -+ double alpha, -+ double const *A, -+ int lda, -+ double const *B, -+ int ldb, -+ double *C, -+ int ldc) { -+ -+ dim3 block(16, 16); -+ dim3 grid( -+ (M + block.x - 1) / block.x, -+ (N + block.y - 1) / block.y -+ ); -+ -+ ReferenceTrmm_kernel<<< grid, block >>>(M, N, alpha, A, lda, B, ldb, C, ldc); -+ -+ return cudaGetLastError(); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Allocate several matrices in GPU device memory and call a double-precision -+/// CUTLASS TRMM kernel. -+cudaError_t TestCutlassTrmm(int M, int N, double alpha) { -+ cudaError_t result; -+ -+ // -+ // Define several matrices to be used as operands to TRMM kernels. -+ // -+ -+ // Compute leading dimensions for each matrix. -+ int lda = M; -+ int ldb = M; -+ int ldc = M; -+ -+ // Compute size in bytes of the C matrix. -+ size_t sizeof_C = sizeof(double) * ldc * N; -+ -+ // Define pointers to matrices in GPU device memory. -+ double *A; -+ double *B; -+ double *C_cutlass; -+ double *C_reference; -+ -+ // -+ // Allocate matrices in GPU device memory with arbitrary seeds. -+ // -+ -+ result = AllocateMatrix(&A, lda, M, M, 0, cutlass::FillMode::kLower); -+ -+ if (result != cudaSuccess) { -+ return result; -+ } -+ -+ result = AllocateMatrix(&B, ldb, M, N, 17); -+ -+ if (result != cudaSuccess) { -+ cudaFree(A); -+ return result; -+ } -+ -+ result = AllocateMatrix(&C_cutlass, ldc, M, N, 101); -+ -+ if (result != cudaSuccess) { -+ cudaFree(A); -+ cudaFree(B); -+ return result; -+ } -+ -+ result = AllocateMatrix(&C_reference, ldc, M, N, 101); -+ -+ if (result != cudaSuccess) { -+ cudaFree(A); -+ cudaFree(B); -+ cudaFree(C_cutlass); -+ return result; -+ } -+ -+ result = cudaMemcpy(C_reference, C_cutlass, sizeof_C, cudaMemcpyDeviceToDevice); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to copy C_cutlass matrix to C_reference: " -+ << cudaGetErrorString(result) << std::endl; -+ -+ cudaFree(C_reference); -+ cudaFree(C_cutlass); -+ cudaFree(B); -+ cudaFree(A); -+ -+ return result; -+ } -+ -+ // -+ // Launch CUTLASS TRMM. -+ // -+ -+ result = CutlassStrmmNN(M, N, alpha, A, lda, B, ldb, C_cutlass, ldc); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "CUTLASS TRMM kernel failed: " -+ << cudaGetErrorString(result) << std::endl; -+ -+ cudaFree(C_reference); -+ cudaFree(C_cutlass); -+ cudaFree(B); -+ cudaFree(A); -+ -+ return result; -+ } -+ -+ // -+ // Verify. -+ // -+ -+ // Launch reference TRMM -+ result = ReferenceTrmm(M, N, alpha, A, lda, B, ldb, C_reference, ldc); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Reference TRMM kernel failed: " -+ << cudaGetErrorString(result) << std::endl; -+ -+ cudaFree(C_reference); -+ cudaFree(C_cutlass); -+ cudaFree(B); -+ cudaFree(A); -+ -+ return result; -+ } -+ -+ // Copy to host and verify equivalence. -+ std::vector host_cutlass(ldc * N, 0); -+ std::vector host_reference(ldc * N, 0); -+ -+ result = cudaMemcpy(host_cutlass.data(), C_cutlass, sizeof_C, cudaMemcpyDeviceToHost); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to copy CUTLASS TRMM results: " -+ << cudaGetErrorString(result) << std::endl; -+ -+ cudaFree(C_reference); -+ cudaFree(C_cutlass); -+ cudaFree(B); -+ cudaFree(A); -+ -+ return result; -+ } -+ -+ result = cudaMemcpy(host_reference.data(), C_reference, sizeof_C, cudaMemcpyDeviceToHost); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to copy Reference TRMM results: " -+ << cudaGetErrorString(result) << std::endl; -+ -+ cudaFree(C_reference); -+ cudaFree(C_cutlass); -+ cudaFree(B); -+ cudaFree(A); -+ -+ return result; -+ } -+ -+ // -+ // Free device memory allocations. -+ // -+ -+ cudaFree(C_reference); -+ cudaFree(C_cutlass); -+ cudaFree(B); -+ cudaFree(A); -+ -+ // -+ // Test for bit equivalence of results. -+ // -+ -+ if (host_cutlass != host_reference) { -+ std::cerr << "CUTLASS results incorrect." << std::endl; -+ -+ return cudaErrorUnknown; -+ } -+ -+ return cudaSuccess; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Entry point to basic_trmm example. -+// -+// usage: -+// -+// 00_basic_trmm -+// -+int main(int argc, const char *arg[]) { -+ -+ bool notSupported = false; -+ -+ // CUTLASS must be compiled with CUDA 11 Toolkit to run these examples. -+ if (!(__CUDACC_VER_MAJOR__ >= 11)) { -+ std::cerr << "NVIDIA Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ -+ return -1; -+ } -+ -+ if (!((props.major * 10 + props.minor) >= 80)) { -+ -+ std::cerr << "This example requires compute capability at least 80." -+ << std::endl; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ return 0; -+ } -+ -+ // -+ // Parse the command line to obtain TRMM dimensions and scalar values. -+ // -+ -+ // TRMM problem dimensions. -+ int problem[2] = { 128, 128 }; -+ -+ for (int i = 1; i < argc && i < 3; ++i) { -+ std::stringstream ss(arg[i]); -+ ss >> problem[i - 1]; -+ } -+ -+ // Scalars used for linear scaling the result of the matrix product. -+ double scalars[1] = { 1 }; -+ -+ for (int i = 3; i < argc && i < 4; ++i) { -+ std::stringstream ss(arg[i]); -+ ss >> scalars[i - 3]; -+ } -+ -+ // -+ // Run the CUTLASS TRMM test. -+ // -+ -+ cudaError_t result = TestCutlassTrmm( -+ problem[0], // TRMM M dimension -+ problem[1], // TRMM N dimension -+ scalars[0] // alpha -+ ); -+ -+ if (result == cudaSuccess) { -+ std::cout << "Passed." << std::endl; -+ } -+ -+ // Exit. -+ return result == cudaSuccess ? 0 : -1; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/33_ampere_3xtf32_tensorop_symm/ampere_3xtf32_tensorop_symm.cu b/3rdparty/cutlass/examples/33_ampere_3xtf32_tensorop_symm/ampere_3xtf32_tensorop_symm.cu -new file mode 100644 -index 0000000..c938e23 ---- /dev/null -+++ b/3rdparty/cutlass/examples/33_ampere_3xtf32_tensorop_symm/ampere_3xtf32_tensorop_symm.cu -@@ -0,0 +1,687 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+NVIDIA Ampere architecture starts supporting tfloat32 (see include/cutlass/tfloat32.h) -+data types in tensor cores. One big advantage is that we can load in F32 data and convert them -+implicitly to tf32 inside the SYMM kernel which means no change is needed to accelerate traditional -+F32 data by using NVIDIA Ampere architecture. -+ -+We can use the tf32 mode of tensor core to emulate a fast accurate SYMM kernel which is accelerated -+using Ampere Tensor Cores (see include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h). -+ -+The trick is very simple -+ a x b = (a_big + a_small) x (b_big + b_small) = a_big x b_big + a_big x b_small + a_small x b_big -+ big = convert_to_tf32(F32) -+ small = convert_to_tf32(F32 - big) -+ -+a_small x b_small is discarded because they are too small. -+ -+This example demonstrates usage of this kernel, along with accuracy measurements w.r.t. actual F32 -+results (SSYMM from cuBLAS) and against F64 results (DSYMM from CUTLASS) -+ -+To enable this feature, the only change needs to make is to change the default OpMultiplyAdd to -+OpMultiplyAddFastF32. -+ -+Now, we have two different flavors of SSYMM in the profiler for Ampere: -+ -+ s1688symm // Use 3xTF32 to emulate F32. F32 in, converted in TF32-big and TF32-small internally, -+ // accumulated in F32, F32 out. -+ s1688tf32symm // Use 1xTF32. F32 in, converted to one TF32 internally, accumulated in F32, F32 out. -+*/ -+ -+#include -+#include -+#include -+ -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+ -+#include "cutlass/util/reference/host/symm.h" -+#include "cutlass/util/reference/host/tensor_reduce.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/error_metrics.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "helper.h" -+ -+#if CUTLASS_ENABLE_CUBLAS -+#include -+#endif -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ -+ cutlass::gemm::GemmCoord problem_size; -+ float alpha; -+ float beta; -+ std::string rand_mode; -+ int seed; -+ -+ Options(): -+ help(false), -+ problem_size({4096, 4096, 4096}), -+ seed(1), -+ alpha(1), -+ beta(), -+ rand_mode("uniform") { } -+ -+ bool valid() { -+ // -+ // CUTLASS attempts to load 128b vectors of F32 elements. Consequently, -+ // all pointers, strides, and tensor extents must be divisible by 4 elements. -+ // -+ int const kAlignment = 4; -+ -+ if ((problem_size.m() % kAlignment) || -+ (problem_size.n() % kAlignment) || -+ (problem_size.k() % kAlignment)) { -+ -+ // misaligned tensors -+ return false; -+ } -+ -+ return true; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ cmd.get_cmd_line_argument("m", problem_size.m()); -+ cmd.get_cmd_line_argument("n", problem_size.n()); -+ // Since the kernels in this example are in Left Side Mode -+ cmd.get_cmd_line_argument("m", problem_size.k()); -+ -+ cmd.get_cmd_line_argument("alpha", alpha); -+ cmd.get_cmd_line_argument("beta", beta); -+ -+ cmd.get_cmd_line_argument("seed", seed); -+ cmd.get_cmd_line_argument("rand_mode", rand_mode); -+ -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "33_ampere_3xtf32_tensorop_symm example\n\n" -+ << " This example uses the CUTLASS Library to execute 3xTF32 tensorop SYMM computations.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --m= SYMM M dimension\n" -+ << " --n= SYMM N dimension\n" -+ << " --alpha= Epilogue scalar alpha\n" -+ << " --beta= Epilogue scalar beta\n\n" -+ << " --rand_mode= gauss / uniform*\n\n" -+ << " --seed= Random number seed (1*)\n\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/33_ampere_3xtf32_tensorop_symm/33_ampere_3xtf32_tensorop_symm --m=1024 --n=512 \\\n" -+ << " --alpha=2 --beta=1 \n\n"; -+ -+ return out; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// The code section below describes matrix layout of input and output matrices. Column Major for -+// Matrix A, Matrix B and Matrix C (since that's what cuBLAS supports, CUTLASS supports Row Major too) -+using LayoutInputA = cutlass::layout::ColumnMajor; -+using LayoutInputB = cutlass::layout::ColumnMajor; -+using LayoutOutput = cutlass::layout::ColumnMajor; -+ -+// Symmetric Matrix A is in Left Side mode -+constexpr cutlass::SideMode SideModeA = cutlass::SideMode::kLeft; -+// Symmetric Matrix A is in Lower Filled mode -+constexpr cutlass::FillMode FillModeA = cutlass::FillMode::kLower; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm80; -+ -+// This code section describes the tile size a thread block will compute -+using ShapeMMAThreadBlock = -+ cutlass::gemm::GemmShape<128, 64, 16>; // <- threadblock tile M = 128, N = 128, K = 16 -+// This code section describes tile size a warp will compute -+using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 32, 16>; // <- warp tile M = 64, N = 64, K = 16 -+// This code section describes the size of MMA op -+using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8 -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? -+ -+// This code section describes the epilogue part of the kernel -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ float, // <- data type of output matrix -+ 128 / cutlass::sizeof_bits::value, // <- the number of elements per vectorized -+ // memory access. For a byte, it's 16 -+ // elements. This becomes the vector width of -+ // math instructions in the epilogue too -+ float, // <- data type of accumulator -+ float>; // <- data type for alpha/beta in linear combination function -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 3; -+// Alignment -+constexpr int Alignment = 4; -+ -+// -+// CUTLASS Symm Operators (SSYM: Symm_3xTF32, Symm_1xTF32, DSYMM: Symm_F64) -+// -+ -+// Symm_3xTF32 -+using Symm_3xTF32 = cutlass::gemm::device::Symm< -+ float, -+ LayoutInputA, -+ SideModeA, -+ FillModeA, -+ float, -+ LayoutInputB, -+ float, -+ LayoutOutput, -+ float, -+ MMAOp, -+ SmArch, -+ ShapeMMAThreadBlock, -+ ShapeMMAWarp, -+ ShapeMMAOp, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ 1, // Symmetric matrix is always align 1 -+ Alignment, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32>; -+ -+// Symm_1xTF32 -+using Symm_1xTF32 = cutlass::gemm::device::Symm< -+ float, -+ LayoutInputA, -+ SideModeA, -+ FillModeA, -+ float, -+ LayoutInputB, -+ float, -+ LayoutOutput, -+ float, -+ MMAOp, -+ SmArch, -+ ShapeMMAThreadBlock, -+ ShapeMMAWarp, -+ ShapeMMAOp, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ 1, // Symmetric matrix is always align 1 -+ Alignment, -+ false, -+ cutlass::arch::OpMultiplyAdd>; -+ -+// Symm_F64 -+using Symm_F64 = cutlass::gemm::device::Symm< -+ double, -+ LayoutInputA, -+ SideModeA, -+ FillModeA, -+ double, -+ LayoutInputB, -+ double, -+ LayoutOutput, -+ double, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ double, -+ 1, -+ double, -+ double -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4>; -+ -+bool run(Options &options) { -+ -+ // Create a tuple of problem size for matrix multiplication -+ cutlass::gemm::GemmCoord problem_size = options.problem_size; -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /// 1. Initialize F32 Precision input tensors using CUTLASS helper functions -+ //////////////////////////////////////////////////////////////////////////////// -+ cutlass::HostTensor tensor_a_F32(problem_size.mk()); // <- Create matrix A with dimensions M x K -+ cutlass::HostTensor tensor_b_F32(problem_size.kn()); // <- Create matrix B with dimensions K x N -+ cutlass::HostTensor tensor_c_F32(problem_size.mn()); // <- Create matrix C with dimensions M x N -+ cutlass::HostTensor tensor_d_F32(problem_size.mn()); // <- Create matrix D with dimensions M x N -+ -+ if (options.rand_mode == "uniform") { -+ const float min = -1; -+ const float max = 1; -+ // Fill input and output matrices on host using CUTLASS helper functions -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a_F32.host_view(), -+ options.seed, -+ double(max), -+ double(min)); // <- Fill matrix A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b_F32.host_view(), -+ options.seed, -+ double(max), -+ double(min)); // <- Fill matrix B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c_F32.host_view(), -+ options.seed, -+ double(max), -+ double(min)); // <- Fill matrix C on host with uniform-distribution random data -+ } else if (options.rand_mode == "gauss") { -+ // Fill input and output matrices on host using CUTLASS helper functions -+ cutlass::reference::host::TensorFillRandomGaussian( -+ tensor_a_F32.host_view(), -+ options.seed, -+ double(0), -+ double(5)); // <- Fill matrix A on host with gaussian-distribution random data -+ cutlass::reference::host::TensorFillRandomGaussian( -+ tensor_b_F32.host_view(), -+ options.seed, -+ double(0), -+ double(5)); // <- Fill matrix B on host with gaussian-distribution random data -+ cutlass::reference::host::TensorFillRandomGaussian( -+ tensor_c_F32.host_view(), -+ options.seed, -+ double(0), -+ double(5)); // <- Fill matrix C on host with gaussian-distribution random data -+ } -+ cutlass::reference::host::TensorFill( -+ tensor_d_F32.host_view()); // <- fill matrix D on host with zeros -+ -+ // Copy data from host to GPU -+ tensor_a_F32.sync_device(); -+ tensor_b_F32.sync_device(); -+ tensor_c_F32.sync_device(); -+ tensor_d_F32.sync_device(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /// 2. Initialize F64 tensors, Output tensors and setup arguments -+ //////////////////////////////////////////////////////////////////////////////// -+ // Symm F64 input operands (A, B, C) -+ cutlass::HostTensor tensor_a_F64(problem_size.mk()); // <- Create matrix A with dimensions M x K -+ cutlass::HostTensor tensor_b_F64(problem_size.kn()); // <- Create matrix B with dimensions K x N -+ cutlass::HostTensor tensor_c_F64(problem_size.mn()); // <- Create matrix C with dimensions M x N -+ -+ // Symm output (D) for SYMM_3xTF32 -+ cutlass::HostTensor tensor_d_3xTF32(problem_size.mn()); // <- Create matrix D with dimensions M x N -+ // Symm output (D) for SYMM_1xTF32 -+ cutlass::HostTensor tensor_d_1xTF32(problem_size.mn()); // <- Create matrix D with dimensions M x N -+ // Symm output (D) for SYMM_F64 -+ cutlass::HostTensor tensor_d_F64(problem_size.mn()); // <- Create matrix D with dimensions M x N -+#if CUTLASS_ENABLE_CUBLAS -+ // Symm output (D) for SYMM_cublasF32 -+ cutlass::HostTensor tensor_d_cublasF32(problem_size.mn()); // <- Create matrix D with dimensions M x N -+#endif -+ -+ // Copy values from the DP tensors -+ cutlass::reference::host::TensorCopy(tensor_a_F64.host_view(), tensor_a_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_b_F64.host_view(), tensor_b_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_c_F64.host_view(), tensor_c_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_d_F64.host_view(), tensor_d_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_d_3xTF32.host_view(), tensor_d_F32.host_view()); -+ cutlass::reference::host::TensorCopy(tensor_d_1xTF32.host_view(), tensor_d_F32.host_view()); -+#if CUTLASS_ENABLE_CUBLAS -+ cutlass::reference::host::TensorCopy(tensor_d_cublasF32.host_view(), tensor_d_F32.host_view()); -+#endif -+ -+ // Copy data from host to GPU -+ tensor_a_F64.sync_device(); -+ tensor_b_F64.sync_device(); -+ tensor_c_F64.sync_device(); -+ tensor_d_F64.sync_device(); -+ tensor_d_3xTF32.sync_device(); -+ tensor_d_1xTF32.sync_device(); -+#if CUTLASS_ENABLE_CUBLAS -+ tensor_d_cublasF32.sync_device(); -+#endif -+ -+ // Initialize alpha and beta for dot product computation -+ float alpha = float(options.alpha); -+ float beta = float(options.beta); -+ -+ // Batch count as 1 -+ int batch_count = 1; -+ -+ // Batch stride for A, when matrix A is in Left Side mode -+ int batch_stride_A = problem_size.m()*problem_size.m(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /// 3. Run 3xTF32 kernel -+ //////////////////////////////////////////////////////////////////////////////// -+ // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch -+ // instantiated CUTLASS kernel -+ typename Symm_3xTF32::Arguments arguments_3xtf32{ -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ problem_size, // <- problem size of matrix multiplication -+ batch_count, // <- batch count -+ {alpha, beta}, // <- tuple of alpha and beta -+ tensor_a_F32.device_data(), // <- reference to matrix A on device -+ tensor_b_F32.device_data(), // <- reference to matrix B on device -+ tensor_c_F32.device_data(), // <- reference to matrix C on device -+ tensor_d_3xTF32.device_data(), // <- reference to matrix D on device -+ batch_stride_A, // <- batch stride and ld for matrices -+ problem_size.m() * problem_size.n(), -+ problem_size.m() * problem_size.n(), -+ problem_size.m() * problem_size.n(), -+ tensor_a_F32.layout().stride(0), -+ tensor_b_F32.layout().stride(0), -+ tensor_c_F32.layout().stride(0), -+ tensor_d_3xTF32.layout().stride(0) -+ }; -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size_3xtf32 = Symm_3xTF32::get_workspace_size(arguments_3xtf32); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace_3xtf32(workspace_size_3xtf32); -+ -+ // Instantiate CUTLASS kernel depending on templates -+ Symm_3xTF32 symm_op_3xtf32; -+ -+ // Check the problem size is supported or not -+ cutlass::Status status_3xtf32 = symm_op_3xtf32.can_implement(arguments_3xtf32); -+ CUTLASS_CHECK(status_3xtf32); -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ status_3xtf32 = symm_op_3xtf32.initialize(arguments_3xtf32, workspace_3xtf32.get()); -+ CUTLASS_CHECK(status_3xtf32); -+ -+ // Launch initialized CUTLASS kernel -+ status_3xtf32 = symm_op_3xtf32(); -+ CUTLASS_CHECK(status_3xtf32); -+ -+ tensor_d_3xTF32.sync_host(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /// 4. Run 1xTF32 kernel -+ //////////////////////////////////////////////////////////////////////////////// -+ // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch -+ // instantiated CUTLASS kernel -+ typename Symm_1xTF32::Arguments arguments_1xtf32{ -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ problem_size, // <- problem size of matrix multiplication -+ batch_count, // <- batch count -+ {alpha, beta}, // <- tuple of alpha and beta -+ tensor_a_F32.device_data(), // <- reference to matrix A on device -+ tensor_b_F32.device_data(), // <- reference to matrix B on device -+ tensor_c_F32.device_data(), // <- reference to matrix C on device -+ tensor_d_1xTF32.device_data(), // <- reference to matrix D on device -+ batch_stride_A, // <- batch stride and ld for matrices -+ problem_size.m() * problem_size.n(), -+ problem_size.m() * problem_size.n(), -+ problem_size.m() * problem_size.n(), -+ tensor_a_F32.layout().stride(0), -+ tensor_b_F32.layout().stride(0), -+ tensor_c_F32.layout().stride(0), -+ tensor_d_1xTF32.layout().stride(0) -+ }; -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size_1xtf32 = Symm_1xTF32::get_workspace_size(arguments_1xtf32); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace_1xtf32(workspace_size_1xtf32); -+ -+ // Instantiate CUTLASS kernel depending on templates -+ Symm_1xTF32 symm_op_1xtf32; -+ -+ // Check the problem size is supported or not -+ cutlass::Status status_1xtf32 = symm_op_1xtf32.can_implement(arguments_1xtf32); -+ CUTLASS_CHECK(status_1xtf32); -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ status_1xtf32 = symm_op_1xtf32.initialize(arguments_1xtf32, workspace_1xtf32.get()); -+ CUTLASS_CHECK(status_1xtf32); -+ -+ // Launch initialized CUTLASS kernel -+ status_1xtf32 = symm_op_1xtf32(); -+ CUTLASS_CHECK(status_1xtf32); -+ -+ tensor_d_1xTF32.sync_host(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /// 5. Run F64 kernel -+ //////////////////////////////////////////////////////////////////////////////// -+ // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch -+ // instantiated CUTLASS kernel -+ typename Symm_F64::Arguments arguments_f64{ -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ problem_size, // <- problem size of matrix multiplication -+ batch_count, // <- batch count -+ {double(options.alpha), double(options.alpha)}, // <- tuple of alpha and beta -+ tensor_a_F64.device_data(), // <- reference to matrix A on device -+ tensor_b_F64.device_data(), // <- reference to matrix B on device -+ tensor_c_F64.device_data(), // <- reference to matrix C on device -+ tensor_d_F64.device_data(), // <- reference to matrix D on device -+ batch_stride_A, // <- batch stride and ld for matrices -+ problem_size.m() * problem_size.n(), -+ problem_size.m() * problem_size.n(), -+ problem_size.m() * problem_size.n(), -+ tensor_a_F64.layout().stride(0), -+ tensor_b_F64.layout().stride(0), -+ tensor_c_F64.layout().stride(0), -+ tensor_d_F64.layout().stride(0) -+ }; -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size_f64 = Symm_F64::get_workspace_size(arguments_f64); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace_f64(workspace_size_f64); -+ -+ // Instantiate CUTLASS kernel depending on templates -+ Symm_F64 symm_op_f64; -+ -+ // Check the problem size is supported or not -+ cutlass::Status status_f64 = symm_op_f64.can_implement(arguments_f64); -+ CUTLASS_CHECK(status_f64); -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ status_f64 = symm_op_f64.initialize(arguments_f64, workspace_f64.get()); -+ CUTLASS_CHECK(status_f64); -+ -+ // Launch initialized CUTLASS kernel -+ status_f64 = symm_op_f64(); -+ CUTLASS_CHECK(status_f64); -+ -+ cudaDeviceSynchronize(); -+ -+ tensor_d_F64.sync_host(); -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /// 6. Run cuBLAS SSYMM kernel -+ //////////////////////////////////////////////////////////////////////////////// -+ -+#if CUTLASS_ENABLE_CUBLAS -+ cublasStatus_t cublas_status; -+ cublasHandle_t handle; -+ -+ cublas_status = cublasCreate(&handle); -+ if (cublas_status != CUBLAS_STATUS_SUCCESS) { -+ std::cerr << "Failed to create cuBLAS handle." << std::endl; -+ return false; -+ } -+ -+ cublas_status = cublasSsymm( -+ handle, -+ CUBLAS_SIDE_LEFT, -+ CUBLAS_FILL_MODE_LOWER, -+ problem_size.m(), -+ problem_size.n(), -+ static_cast(&alpha), -+ static_cast(tensor_a_F32.device_data()), -+ int(tensor_a_F32.layout().stride(0)), -+ static_cast(tensor_b_F32.device_data()), -+ int(tensor_b_F32.layout().stride(0)), -+ static_cast(&beta), -+ static_cast(tensor_d_cublasF32.device_data()), -+ int(tensor_d_cublasF32.layout().stride(0)) -+ ); -+ -+ cudaDeviceSynchronize(); -+ -+ tensor_d_cublasF32.sync_host(); -+#endif -+ -+ //////////////////////////////////////////////////////////////////////////////// -+ /// 7. Compute l2 norms -+ //////////////////////////////////////////////////////////////////////////////// -+ -+#if CUTLASS_ENABLE_CUBLAS -+ // l2 norm cuBLAS F32 vs F64 -+ cutlass::HostTensor tensor_d_cublasF32_in_F64(problem_size.mn()); -+ cutlass::reference::host::TensorCopy(tensor_d_cublasF32_in_F64.host_view(), tensor_d_cublasF32.host_view()); -+ -+ double l2_norm_cublasf32_vs_f64 = cutlass::reference::host::TensorRelativeErrorMetric( -+ tensor_d_cublasF32_in_F64.host_view(), tensor_d_F64.host_view()); -+#endif -+ -+ // l2 norm 3xTF32 vs F64 -+ cutlass::HostTensor tensor_d_3xTF32_in_F64(problem_size.mn()); -+ cutlass::reference::host::TensorCopy(tensor_d_3xTF32_in_F64.host_view(), tensor_d_3xTF32.host_view()); -+ double l2_norm_3xtf32_vs_f64 = cutlass::reference::host::TensorRelativeErrorMetric( -+ tensor_d_3xTF32_in_F64.host_view(), tensor_d_F64.host_view()); -+ -+ // l2 norm 1xTF32 vs F64 -+ cutlass::HostTensor tensor_d_1xTF32_in_F64(problem_size.mn()); -+ cutlass::reference::host::TensorCopy(tensor_d_1xTF32_in_F64.host_view(), tensor_d_1xTF32.host_view()); -+ double l2_norm_1xtf32_vs_f64 = cutlass::reference::host::TensorRelativeErrorMetric( -+ tensor_d_1xTF32_in_F64.host_view(), tensor_d_F64.host_view()); -+ -+#if CUTLASS_ENABLE_CUBLAS -+ // l2 norm 3xTF32 vs cuBLAS F32 -+ double l2_norm_3xtf32_vs_cublasf32 = cutlass::reference::host::TensorRelativeErrorMetric( -+ tensor_d_3xTF32.host_view(), tensor_d_cublasF32.host_view()); -+#endif -+ -+ // l2 norm 3xTF32 vs 1xTF32 -+ double l2_norm_3xtf32_vs_1xtf32 = cutlass::reference::host::TensorRelativeErrorMetric( -+ tensor_d_3xTF32.host_view(), tensor_d_1xTF32.host_view()); -+ -+ /////////////////////////////////////////////////////////////////////////////// -+ -+ // Print kernel info and L2 norms -+ std::cout << "Problem Size: (" << problem_size.m() << "," << problem_size.n() << "," << problem_size.k() << ") " -+ << "Alpha: " << alpha << "," << " Beta: " << beta << std::endl; -+ std::cout << std::fixed; -+ std::cout << "Normalized L2 norm of" << std::endl; -+ std::cout.precision(8); -+ std::cout << std::scientific -+#if CUTLASS_ENABLE_CUBLAS -+ << " - cuBLAS F32 error with F64 reference : " << l2_norm_cublasf32_vs_f64 << std::endl -+#endif -+ << " - 3xTF32 error with F64 reference : " << l2_norm_3xtf32_vs_f64 << std::endl -+ << " - 1xTF32 error with F64 reference : " << l2_norm_1xtf32_vs_f64 << std::endl -+#if CUTLASS_ENABLE_CUBLAS -+ << " - 3xTF32 error with cuBLAS F32 reference : " << l2_norm_3xtf32_vs_cublasf32 << std::endl -+#endif -+ << " - 3xTF32 error with 1xTF32 reference : " << l2_norm_3xtf32_vs_1xtf32 << std::endl; -+ -+ return true; -+} -+ -+int main(int argc, const char **argv) { -+ -+ bool notSupported = false; -+ -+ // Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available -+ // in CUDA 11.0. -+ // -+ // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples. -+ if (!(__CUDACC_VER_MAJOR__ >= 11)) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return false; -+ } -+ -+ if (!((props.major * 10 + props.minor) >= 80)) { -+ std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80." -+ << std::endl; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ // Returning zero so this test passes on older Toolkits. Its actions are no-op. -+ return 0; -+ } -+ -+ Options options; -+ options.parse(argc, argv); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ bool result = true; -+ -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ result = run(options); -+ -+ if (!result) return -1; -+ -+ return 0; -+} -diff --git a/3rdparty/cutlass/examples/34_transposed_conv2d/34_transposed_conv2d.cu b/3rdparty/cutlass/examples/34_transposed_conv2d/34_transposed_conv2d.cu -new file mode 100644 -index 0000000..2e4ce3c ---- /dev/null -+++ b/3rdparty/cutlass/examples/34_transposed_conv2d/34_transposed_conv2d.cu -@@ -0,0 +1,639 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/* -+This example shows how to compute 2d transposed convolution, also known as deconvolution, using CUTLASS -+conv2d Dgrad kernels. Although two operations are computationaly equivalent, some care is needed to correctly -+set up a problem size for CUTLASS. -+In deep learning, transposed convolution is sometimes used for upscaling feature maps. This example -+demonstrates the 2x upscaling case using the strided Dgrad kernel. -+*/ -+ -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/conv/kernel/default_conv2d_dgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/device/convolution.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "helper.h" -+ -+// The code section below describes datatype for input, output tensors and computation between -+// elements -+using cutlass::layout::TensorNHWC; -+using cutlass::TensorRef; -+ -+using ElementAccumulator = cutlass::half_t; // Data type of accumulator -+using ElementComputeEpilogue = cutlass::half_t; // Data type of epilogue computation (alpha, beta) -+using ElementInputA = cutlass::half_t; // Data type of elements in input tensor -+using ElementInputB = cutlass::half_t; // Data type of elements in input tensor -+using ElementOutput = cutlass::half_t; // Data type of elements in output tensor -+using ElementC = ElementOutput; -+using ElementCompute = ElementComputeEpilogue; -+using LayoutInputA = TensorNHWC; -+using LayoutInputB = TensorNHWC; -+using LayoutOutput = TensorNHWC; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm80; -+ -+// This code section describes the tile size a thread block will compute -+using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock tile shape -+ -+// This code section describes tile size a warp will compute -+using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; // Warp tile shape -+ -+// This code section describes the size of MMA op -+using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // TensorCore instruction shape -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>; -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 3; -+ -+// This code section describe iterator algorithm selected is Analytic or Optimized -+static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized; -+ -+// This code section describes the epilogue part of the kernel, we use default value -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementCompute, // Data type of output matrix. -+ 128 / cutlass::sizeof_bits::value, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. -+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue>; // Data type for alpha/beta in linear combination -+ -+using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementInputA, LayoutInputA, -+ ElementInputB, LayoutInputB, -+ ElementAccumulator, LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ IteratorAlgorithm, -+ cutlass::conv::StrideSupport::kStrided // Use the strided Dgrad specialization -+ >::Kernel; -+ -+using ImplicitGemm = cutlass::conv::device::ImplicitGemmConvolution; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ cutlass::Tensor4DCoord input_size; -+ cutlass::Tensor4DCoord filter_size; -+ cutlass::Tensor4DCoord padding; -+ cutlass::MatrixCoord conv_stride; -+ cutlass::MatrixCoord dilation; -+ bool reference_check; -+ bool measure_performance; -+ int iterations; -+ ElementComputeEpilogue alpha; -+ ElementComputeEpilogue beta; -+ std::string tag; -+ -+ Options(): -+ help(false), -+ input_size(1, 32, 32, 32), -+ filter_size(32, 3, 3, 16), -+ padding(1, 1, 1, 1), -+ conv_stride(2, 2), -+ dilation(1, 1), -+ reference_check(true), -+ measure_performance(false), -+ iterations(20), -+ alpha(1), -+ beta(0) {} -+ -+ // Verify the problem size is compatible with the CUTLASS Convolution implementation. -+ bool valid() { -+ -+ // -+ // CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently, -+ // all pointers, strides, and tensor extents must be divisible by 8 elements. -+ // -+ int const kAlignment = 8; -+ -+ if ((input_size.c() % kAlignment) || -+ (filter_size.n() % kAlignment)) { -+ -+ // misaligned tensors -+ return false; -+ } -+ -+ // Invalid padding -+ if ((padding.h() != filter_size.h() / 2) || -+ (padding.w() != filter_size.w() / 2)) { -+ -+ return false; -+ } -+ -+ return true; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("skip-ref-check")) { -+ reference_check = false; -+ } -+ -+ if (cmd.check_cmd_line_flag("perf-check")) { -+ measure_performance = true; -+ } -+ -+ cmd.get_cmd_line_argument("n", input_size.n()); -+ cmd.get_cmd_line_argument("h", input_size.h()); -+ cmd.get_cmd_line_argument("w", input_size.w()); -+ cmd.get_cmd_line_argument("c", input_size.c()); -+ -+ // Filter layout is CRSK -+ cmd.get_cmd_line_argument("k", filter_size.c()); -+ cmd.get_cmd_line_argument("r", filter_size.h()); -+ cmd.get_cmd_line_argument("s", filter_size.w()); -+ filter_size.n() = input_size.c(); -+ -+ cmd.get_cmd_line_argument("alpha", alpha); -+ cmd.get_cmd_line_argument("beta", beta); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ cmd.get_cmd_line_argument("tag", tag); -+ -+ if (filter_size.h() == 3 && filter_size.w() == 3) { -+ padding = {1, 1, 1, 1}; -+ } -+ else { -+ filter_size.h() = 1; -+ filter_size.w() = 1; -+ padding = {0, 0, 0, 0}; -+ } -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "34_transposed_conv2d example\n\n" -+ << " This example shows how to compute 2d transposed convolution, also known as\n" -+ << " deconvolution, using CUTLASS conv2d Dgrad kernels. Although two operations are\n" -+ << " computationaly equivalent, some care is needed to correctly set up a problem size.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --n= Input tensor extent N\n" -+ << " --h= Input tensor extent H\n" -+ << " --w= Input tensor extent W\n" -+ << " --c= Input tensor extent C\n" -+ << " --k= Filter extent K\n" -+ << " --r= Filter extent R\n" -+ << " --s= Filter extent S\n\n" -+ << " --alpha= Epilogue scalar alpha\n" -+ << " --beta= Epilogue scalar beta\n\n" -+ << " --skip-ref-check If set (true), skip reference check on the host\n" -+ << " --perf-check If set (true), performance is measured.\n" -+ << " --iterations= Number of profiling iterations to perform.\n" -+ << " --tag= String to replicate across the first column in the results table\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/31_transposed_conv2d/31_transposed_conv2d --n=8 --h=32 --w=32 --c=16 --k=32 --r=3 --s=3\n\n"; -+ -+ return out; -+ } -+ -+ /// Computes the output tensor size (NPQK) -+ cutlass::Tensor4DCoord output_size() const { -+ // Here, out_pad corresponds to "output_padding" of conv2d_transpose op in deep learning frameworks. -+ // See for example https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html -+ int out_pad_h = conv_stride.row() > 1 ? 1 : 0; -+ int out_pad_w = conv_stride.column() > 1 ? 1 : 0; -+ int out_h = (input_size.h() - 1) * conv_stride.row() - 2 * padding.n() + (((filter_size.h() - 1) * dilation.row() + 1)) + out_pad_h; -+ int out_w = (input_size.w() - 1) * conv_stride.column() - 2 * padding.w() + (((filter_size.w() - 1) * dilation.column() + 1)) + out_pad_w; -+ return cutlass::Tensor4DCoord(input_size.n(), out_h, out_w, filter_size.c()); -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of multiply-adds = NHWC * KRS -+ // Note that the input with the layout NHWC corresponds to the output from the perspective of dgrad, -+ // and that the filter layout is CRSK. -+ int64_t fmas = input_size.product() * int64_t(filter_size.h() * filter_size.w() * filter_size.n()); -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct Result { -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cutlass::Status reference_check; -+ cudaError_t error; -+ -+ Result(): -+ runtime_ms(0), -+ gflops(0), -+ status(cutlass::Status::kSuccess), -+ reference_check(cutlass::Status::kInvalid), -+ error(cudaSuccess) { } -+ -+ static std::ostream & print_header(std::ostream &out, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << "Name,"; -+ } -+ -+ out << "Layer,N,H,W,C,K,R,S,Stride_H,Stride_W,Runtime,GFLOPs"; -+ -+ return out; -+ } -+ -+ std::ostream & print(std::ostream &out, int idx, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << options.tag << ","; -+ } -+ -+ out -+ << "conv_" << idx << "," -+ << options.input_size.n() << "," -+ << options.input_size.h() << "," -+ << options.input_size.w() << "," -+ << options.input_size.c() << "," -+ << options.filter_size.c() << "," -+ << options.filter_size.h() << "," -+ << options.filter_size.w() << "," -+ << options.conv_stride.row() << "," -+ << options.conv_stride.column() << "," -+ << runtime_ms << "," -+ << gflops; -+ -+ return out; -+ } -+}; -+ -+// This is the same as Conv2dDgrad in tools/util/include/cutlass/util/reference/host/convolution.h, -+// only variable names have been adapted for transposed conv2d. -+void Conv2dTransposeReference( -+ cutlass::conv::Conv2dProblemSize problem_size, -+ TensorRef tensor_a, -+ TensorRef tensor_b, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ int H = problem_size.P; -+ int W = problem_size.Q; -+ int P = problem_size.H; -+ int Q = problem_size.W; -+ int K = problem_size.C; -+ int C = problem_size.K; -+ -+ for (int n = 0; n < problem_size.N; ++n) { -+ for (int p = 0; p < P; ++p) { -+ for (int q = 0; q < Q; ++q) { -+ for (int k = 0; k < K; ++k) { -+ -+ ElementAccumulator acc = ElementAccumulator(); -+ -+ for (int r = 0; r < problem_size.R; ++r) { -+ for (int s = 0; s < problem_size.S; ++s) { -+ for (int c = 0; c < C; ++c) { -+ -+ int filter_r = r; -+ int filter_s = s; -+ -+ int h = p + problem_size.pad_h - filter_r * problem_size.dilation_h; -+ int w = q + problem_size.pad_w - filter_s * problem_size.dilation_w; -+ -+ if (h >= 0 && (h % problem_size.stride_h) == 0 && -+ w >= 0 && (w % problem_size.stride_w) == 0) { -+ -+ h = h / problem_size.stride_h; -+ w = w / problem_size.stride_w; -+ -+ if (h < H && w < W) { -+ -+ ElementInputA a = tensor_a.at(cutlass::make_Coord(n, h, w, c)); -+ ElementInputB b = tensor_b.at(cutlass::make_Coord(c, r, s, k)); -+ -+ acc += ElementAccumulator(a) * ElementAccumulator(b); -+ } -+ } -+ -+ } // for (C) -+ } // for (S) -+ } // for (R) -+ -+ // Apply Epilogue, compute ElementCompute, convert and store ElementC -+ ElementC c_ref = ElementC(); -+ -+ if (beta != ElementCompute()) { -+ c_ref = tensor_c.at(cutlass::make_Coord(n, p, q, k)); -+ } -+ -+ tensor_d.at(cutlass::make_Coord(n, p, q, k)) = alpha * ElementCompute(acc) + beta * ElementCompute(c_ref); -+ -+ } // for (K) -+ } // for (W) -+ } // for (H) -+ } // for (N) -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Runs one benchmark -+Result profile_convolution(Options const &options) { -+ -+ std::cout << "Output shape: " << options.output_size() << std::endl; -+ -+ Result result; -+ -+ // -+ // Allocate host-device tensors using the CUTLASS Utilities. -+ // -+ -+ cutlass::HostTensor tensor_a(options.input_size); -+ cutlass::HostTensor tensor_b(options.filter_size); -+ cutlass::HostTensor tensor_c(options.output_size()); -+ cutlass::HostTensor tensor_d(options.output_size()); -+ cutlass::HostTensor tensor_ref_d(options.output_size()); -+ -+ // -+ // Initialize tensors -+ // -+ -+ // Fill tensor A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1, -+ ElementInputA(7), -+ ElementInputA(-8), -+ 0); -+ -+ // Fill tensor B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 1, -+ ElementInputB(7), -+ ElementInputB(-8), -+ 0); -+ -+ // Fill tensor C and D on host with zeros -+ cutlass::reference::host::TensorFill(tensor_c.host_view()); -+ -+ cutlass::reference::host::TensorFill(tensor_d.host_view()); -+ -+ // Fill tensor D for reference on host with zeros -+ cutlass::reference::host::TensorFill(tensor_ref_d.host_view()); -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_c.sync_device(); -+ tensor_d.sync_device(); -+ -+ // -+ // Define arguments for CUTLASS Convolution -+ // -+ -+ cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; -+ -+ // Construct Conv2dProblemSize with user defined output size -+ // The input in transposed conv2d corresponds to the output in the equivalent dgrad. -+ // Similarly for the output. -+ // Although the filter layout is CRSK from the perspective of conv2d transpose, -+ // the filter size does not need to change for setting up the problem size. -+ // There is no need to transpose the filter tensor either. -+ -+ cutlass::conv::Conv2dProblemSize problem_size( -+ options.output_size(), -+ options.filter_size, -+ options.padding, -+ options.conv_stride, -+ options.dilation, -+ options.input_size, -+ mode -+ ); -+ -+ typename ImplicitGemm::Arguments arguments{ -+ problem_size, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ tensor_c.device_ref(), -+ tensor_d.device_ref(), -+ {options.alpha, options.beta} -+ }; -+ -+ // -+ // Initialize CUTLASS Convolution -+ // -+ -+ ImplicitGemm implicit_gemm; -+ -+ size_t workspace_size = implicit_gemm.get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ result.status = implicit_gemm.can_implement(arguments); -+ CUTLASS_CHECK(result.status); -+ -+ result.status = implicit_gemm.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(result.status); -+ -+ result.status = implicit_gemm(); -+ CUTLASS_CHECK(result.status); -+ -+ // // Skip reference check since there is no reference code for conv2d transpose in cutlass. -+ if (options.reference_check) { -+ tensor_d.sync_host(); -+ std::cout << "Verification on host...\n"; -+ Conv2dTransposeReference(problem_size, -+ tensor_a.host_ref(), -+ tensor_b.host_ref(), -+ tensor_c.host_ref(), -+ tensor_ref_d.host_ref(), -+ options.alpha, options.beta); -+ -+ bool passed = cutlass::reference::host::TensorEquals(tensor_d.host_view(), tensor_ref_d.host_view()); -+ -+ if (!passed) { -+ result.reference_check = cutlass::Status::kErrorInternal; -+ std::cout << "ERROR - results miscompared.\n"; -+ } -+ else { -+ result.reference_check = cutlass::Status::kSuccess; -+ std::cout << "Passed.\n"; -+ } -+ } -+ -+ if (options.measure_performance) { -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ } -+ -+ // Record an event at the start of a series of convolution operations. -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Launch a sequence of implicit GEMM operations on the device -+ for (int iteration = 0; iteration < options.iterations; ++iteration) { -+ result.status = implicit_gemm(); -+ CUTLASS_CHECK(result.status); -+ } -+ -+ // Record an event when the convolutions have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Print average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // Cleanup -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ } -+ -+ return result; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ bool notSupported = false; -+ -+ // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. -+ // -+ // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples. -+ if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); -+ -+ if (!(props.major >= 8)) { -+ std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80." -+ << std::endl; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ return 0; -+ } -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ // Execute one problem size -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ Result result = profile_convolution(options); -+ -+ Result::print_header(std::cout, options) << std::endl; -+ result.print(std::cout, 1, options) << std::endl; -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/35_gemm_softmax/gemm_softmax.cu b/3rdparty/cutlass/examples/35_gemm_softmax/gemm_softmax.cu -new file mode 100644 -index 0000000..163a634 ---- /dev/null -+++ b/3rdparty/cutlass/examples/35_gemm_softmax/gemm_softmax.cu -@@ -0,0 +1,720 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+ -+*/ -+ -+#include -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/gemm/device/gemm_complex.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+ -+#include "cutlass/util/reference/host/gemm_complex.h" -+#include "cutlass/util/reference/device/gemm_complex.h" -+#include "cutlass/util/reference/host/tensor_reduce.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/device/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/error_metrics.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include "gemm_with_softmax.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#define TRACE(x) { std::cout << "gemm_softmax.cu:" << __LINE__ << " " << x << std::endl; } -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+enum class Disposition { -+ kPassed, -+ kIncorrect, -+ kNotVerified -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ cutlass::gemm::GemmCoord problem_size; -+ int batch_count; -+ int iterations; -+ unsigned seed; -+ float alpha; -+ float beta; -+ bool verification_enabled; -+ float tolerance; -+ -+ Options(): -+ help(false), -+ problem_size({16, 24, 64}), -+ batch_count(16), -+ iterations(20), -+ seed(2022), -+ alpha(1), -+ beta(0), -+ verification_enabled(true), -+ tolerance(1e-5f) -+ { } -+ -+ bool valid() { -+ -+ return true; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ cmd.get_cmd_line_argument("m", problem_size.m()); -+ cmd.get_cmd_line_argument("n", problem_size.n()); -+ cmd.get_cmd_line_argument("k", problem_size.k()); -+ -+ cmd.get_cmd_line_argument("batch_count", batch_count); -+ -+ cmd.get_cmd_line_argument("alpha", alpha); -+ cmd.get_cmd_line_argument("beta", beta); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ cmd.get_cmd_line_argument("verify", verification_enabled); -+ cmd.get_cmd_line_argument("seed", seed); -+ cmd.get_cmd_line_argument("tolerance", tolerance); -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "35_gemm_softmax example\n\n" -+ << " This example uses the CUTLASS Library to compute GEMM + Softmax for arbitrary problem sizes.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --m= GEMM M dimension\n" -+ << " --n= GEMM N dimension\n" -+ << " --k= GEMM K dimension\n" -+ << " --batch_count= Batch number\n" -+ << " --alpha= Epilogue scalar alpha\n" -+ << " --beta= Epilogue scalar beta\n\n" -+ << " --seed= Random number seed (1*)\n\n" -+ << " --iterations= Number of profiling iterations to perform (0 to disable profiling).\n\n" -+ << " --verify= If true, performs reference calculation.\n\n" -+ << " --tolerance Error tolerance\n" -+ ; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/35_gemm_softmax/35_gemm_softmax --m=1024 --n=512 \\\n" -+ << " --alpha=2 --beta=0.707 \n\n"; -+ -+ return out; -+ } -+ -+ /// Returns true if the environment and Toolkit support this -+ bool supported(bool verbose = true) const { -+ -+ // Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available -+ // in CUDA 11.0. -+ // -+ // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples. -+ if (!(__CUDACC_VER_MAJOR__ >= 11)) { -+ if (verbose) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ } -+ return false; -+ } -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ if (verbose) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ } -+ return false; -+ } -+ -+ if (!((props.major * 10 + props.minor) >= 80)) { -+ if (verbose) { -+ std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80." -+ << std::endl; -+ } -+ return false; -+ } -+ -+ return true; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct Testbed { -+ -+ // -+ // Type definitions -+ // -+ -+ -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementCompute = float; -+ using ElementD = ElementC; -+ using ElementSoftmax = ElementC; -+ -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ using OperatorClass = cutlass::arch::OpClassTensorOp; -+ using ArchTag = cutlass::arch::Sm80; -+ -+ // ApplyShape impacts the final Softmax performance a lot. -+ // Set ApplyShape::kColumn to be the next multiple of 32 number that is after -+ // (gemm_N / alignment). -+ // Set ApplyShape::kRow to max(1, 128 / ApplyShape::kColumn). -+ using ApplyShape = cutlass::MatrixShape<1, 1024>; -+ -+ static int const kStages = 3; -+ -+ /// Linear scaling operator -+ using EpilogueFunctorOp = cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementCompute, -+ ElementCompute -+ >; -+ -+ using GemmSoftmax = cutlass::GemmSoftmax< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ ElementCompute, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueFunctorOp, -+ kStages, -+ ApplyShape -+ >; -+ -+ using ElementNorm = typename GemmSoftmax::ElementNorm; -+ using ElementSum = typename GemmSoftmax::ElementSum; -+ using LayoutC = typename GemmSoftmax::LayoutC; -+ using LayoutN = typename GemmSoftmax::LayoutN; -+ using LayoutS = typename GemmSoftmax::LayoutS; -+ using MatrixCoord = typename LayoutC::TensorCoord; -+ -+ // -+ // Data members -+ // -+ -+ Options const &options; -+ -+ -+ cutlass::HostTensor reference_N; -+ -+ cutlass::DeviceAllocation block_A; -+ cutlass::DeviceAllocation block_B; -+ cutlass::DeviceAllocation block_C; -+ cutlass::DeviceAllocation block_D; -+ cutlass::DeviceAllocation block_Ref; -+ cutlass::DeviceAllocation block_Softmax; -+ cutlass::DeviceAllocation block_Norm; -+ cutlass::DeviceAllocation block_Sum; -+ -+ int block_num = (options.problem_size.n() + GemmSoftmax::ThreadblockShape::kN - 1) / GemmSoftmax::ThreadblockShape::kN; -+ -+ cutlass::gemm::GemmCoord problem = options.problem_size; -+ -+ int64_t lda = LayoutA::packed({problem.m(), problem.k()}).stride(0); -+ int64_t ldb = LayoutB::packed({problem.k(), problem.n()}).stride(0); -+ int64_t ldc = LayoutC::packed({problem.m(), problem.n()}).stride(0); -+ -+ // fixed rowmajor for norm and sum -+ int64_t ldn = problem.m(); -+ int64_t lds = ldn; -+ -+ int64_t total_elements_A_per_batch = problem.m() * problem.k(); -+ int64_t total_elements_B_per_batch = problem.k() * problem.n(); -+ int64_t total_elements_C_per_batch = problem.m() * problem.n(); -+ int64_t total_elements_D_per_batch = problem.m() * problem.n(); -+ int64_t total_elements_partial_norm_per_batch = block_num * problem.m(); -+ -+ int64_t total_elements_A = total_elements_A_per_batch * options.batch_count; -+ int64_t total_elements_B = total_elements_B_per_batch * options.batch_count; -+ int64_t total_elements_C = total_elements_C_per_batch * options.batch_count; -+ int64_t total_elements_D = total_elements_D_per_batch * options.batch_count; -+ int64_t total_elements_partial_norm = total_elements_partial_norm_per_batch * options.batch_count; -+ -+ // -+ // Methods -+ // -+ -+ Testbed( -+ Options const &options_ -+ ): -+ options(options_) -+ { -+ reference_N.reset({options.problem_size.m(), 1}, false); -+ } -+ -+ /// Run -+ Disposition run() { -+ -+ Disposition disposition = Disposition::kNotVerified; -+ -+ // -+ // Initialize the workspace -+ // -+ -+ initialize(); -+ -+ // -+ // Launch device kernel -+ // -+ cutlass::Status status = cutlass::Status::kSuccess; -+ -+ status = execute_device_kernel(); -+ -+ if (status != cutlass::Status::kSuccess) { -+ std::cerr << "Device execution failed." << std::endl; -+ return disposition; -+ } -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ if (result != cudaSuccess) { -+ std::cerr << "Device synchronize failed with error " -+ << cudaGetErrorString(result) << std::endl; -+ return disposition; -+ } -+ -+ // -+ // Verify -+ // -+ -+ if (options.verification_enabled) { -+ -+ bool passed = verify(); -+ -+ if (passed) { -+ disposition = Disposition::kPassed; -+ } -+ else { -+ disposition = Disposition::kIncorrect; -+ } -+ } -+ -+ // -+ // Profiling -+ // -+ if (options.iterations) { -+ profile(); -+ } -+ -+ return disposition; -+ } -+ -+ /// Random initialization -+ void initialize() { -+ -+ block_A.reset(total_elements_A); -+ block_B.reset(total_elements_B); -+ block_C.reset(total_elements_C); -+ block_D.reset(total_elements_D); -+ block_Softmax.reset(total_elements_D); -+ block_Ref.reset(total_elements_D_per_batch); -+ block_Norm.reset(total_elements_partial_norm); -+ block_Sum.reset(total_elements_partial_norm); -+ -+ cutlass::reference::device::BlockFillRandomUniform( -+ block_A.get(), total_elements_A, options.seed, ElementA(5), ElementA(-5), 0); -+ -+ cutlass::reference::device::BlockFillRandomUniform( -+ block_B.get(), total_elements_B, options.seed + 1, ElementB(5), ElementB(-5), 0); -+ -+ cutlass::reference::device::BlockFillRandomUniform( -+ block_C.get(), total_elements_C, options.seed + 2, ElementC(5), ElementC(-5), 0); -+ -+ cutlass::reference::device::BlockFillRandomUniform( -+ block_D.get(), total_elements_D, options.seed + 3, ElementD(5), ElementD(-5), 0); -+ -+ cutlass::reference::device::BlockFillRandomUniform( -+ block_Ref.get(), total_elements_D_per_batch, options.seed + 3, ElementD(5), ElementD(-5), 0); -+ -+ cutlass::reference::device::BlockFillRandomUniform( -+ block_Softmax.get(), total_elements_D, options.seed + 3, ElementSoftmax(5), ElementSoftmax(-5), 0); -+ -+ cutlass::reference::host::TensorFill( -+ reference_N.host_view(), -+ ElementNorm() -+ ); -+ -+ } -+ -+ cutlass::Status execute_device_kernel() { -+ -+ cutlass::Status status = cutlass::Status::kSuccess; -+ -+ // -+ // Setup arguments -+ // -+ -+ GemmSoftmax::Arguments args( -+ options.problem_size, -+ options.batch_count, -+ {block_A.get(), lda}, -+ {block_B.get(), ldb}, -+ {block_C.get(), ldc}, -+ {block_D.get(), ldc}, -+ { -+ ElementCompute(options.alpha), -+ ElementCompute(options.beta) -+ }, -+ {block_Norm.get(), ldn}, -+ {block_Sum.get(), lds}, -+ {block_Softmax.get(), ldc}, -+ total_elements_A_per_batch, -+ total_elements_B_per_batch, -+ total_elements_C_per_batch, -+ total_elements_D_per_batch, -+ total_elements_partial_norm_per_batch, -+ total_elements_partial_norm_per_batch, -+ total_elements_D_per_batch -+ ); -+ -+ // -+ // Launch -+ // -+ -+ GemmSoftmax gemm_softmax; -+ -+ // Initialize -+ status = gemm_softmax.initialize(args); -+ if (status != cutlass::Status::kSuccess) { -+ return status; -+ } -+ -+ // Run -+ status = gemm_softmax(); -+ -+ return status; -+ } -+ -+ template -+ bool verify_tensor(std::vector vector_Input, \ -+ std::vector vector_Input_Ref) { -+ -+ int64_t size = (vector_Input.size() < vector_Input_Ref.size()) ? vector_Input.size() : vector_Input_Ref.size(); -+ float abs_tol = options.tolerance; -+ float rel_tol = options.tolerance; -+ -+ for (int64_t i = 0; i < size; ++i) { -+ float diff = (float)(vector_Input.at(i) - vector_Input_Ref.at(i)); -+ float abs_diff = fabs(diff); -+ float abs_ref = fabs((float)vector_Input_Ref.at(i)); -+ float relative_diff = abs_ref > abs_tol ? abs_diff / abs_ref : 0; -+ if ( (isnan(abs_diff) || isinf(abs_diff)) || (abs_diff > rel_tol && relative_diff > rel_tol)) { -+ printf("diff = %f, {%f, %f}.\n", abs_diff, (float)(vector_Input.at(i)), (float)(vector_Input_Ref.at(i))); -+ return false; -+ } -+ -+ } -+ -+ return true; -+ } -+ -+ /// Verifies the reference matches -+ bool verify() { -+ -+ LayoutA layout_A(lda); -+ LayoutB layout_B(ldb); -+ LayoutC layout_C(ldc); -+ LayoutN Layout_N(ldn); -+ LayoutS Layout_S(lds); -+ -+ MatrixCoord extent_A{problem.m(), problem.k()}; -+ MatrixCoord extent_B{problem.k(), problem.n()}; -+ MatrixCoord extent_C{problem.m(), problem.n()}; -+ -+ for (int batch_idx = 0; batch_idx < options.batch_count; batch_idx++) { -+ -+ cutlass::TensorView view_A(block_A.get() + total_elements_A_per_batch * batch_idx, layout_A, extent_A); -+ cutlass::TensorView view_B(block_B.get() + total_elements_B_per_batch * batch_idx, layout_B, extent_B); -+ cutlass::TensorView view_C(block_C.get() + total_elements_C_per_batch * batch_idx, layout_C, extent_C); -+ cutlass::TensorView view_Ref_device(block_Ref.get(), layout_C, extent_C); -+ -+ cutlass::reference::device::GemmComplex< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, ElementCompute -+ >( -+ problem, -+ options.alpha, -+ view_A, -+ cutlass::ComplexTransform::kNone, -+ view_B, -+ cutlass::ComplexTransform::kNone, -+ options.beta, -+ view_C, -+ view_Ref_device, -+ ElementCompute(0) -+ ); -+ -+ // Copy reference results to host memory for verification -+ std::vector matrix_D_Ref(layout_C.capacity(extent_C)); -+ cutlass::device_memory::copy_to_host(matrix_D_Ref.data(), block_Ref.get(), matrix_D_Ref.size()); -+ cutlass::TensorView view_Ref(matrix_D_Ref.data(), layout_C, extent_C); -+ -+ std::vector matrix_Softmax_Ref(layout_C.capacity(extent_C)); -+ cutlass::TensorView view_Softmax_Ref(matrix_Softmax_Ref.data(), layout_C, extent_C); -+ -+ // Copy computed results to host memory -+ std::vector matrix_D(layout_C.capacity(extent_C)); -+ cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + total_elements_D_per_batch * batch_idx, matrix_D.size()); -+ -+ std::vector matrix_Softmax(layout_C.capacity(extent_C)); -+ cutlass::device_memory::copy_to_host(matrix_Softmax.data(), block_Softmax.get() + total_elements_D_per_batch * batch_idx, matrix_Softmax.size()); -+ -+ // Compute the norm -+ for (int m = 0; m < options.problem_size.m(); ++m) { -+ reference_N.at({m, 0}) = view_Ref.ref().at({m, 0}); -+ for (int n = 1; n < options.problem_size.n(); ++n) { -+ reference_N.at({m, 0}) = std::max(reference_N.at({m, 0}), ElementNorm(view_Ref.ref().at({m, n}))); -+ } -+ } -+ -+ // Compute softmax -+ for (int m = 0; m < options.problem_size.m(); ++m) { -+ -+ float sum = float(); -+ -+ for (int n = 0; n < options.problem_size.n(); ++n) { -+ sum += std::exp( float(view_Ref.ref().at({m, n})) - float(reference_N.at({m, 0})) ); -+ } -+ -+ float inv_sum = float(1.0f / sum); -+ -+ for (int n = 0; n < options.problem_size.n(); ++n) { -+ -+ view_Softmax_Ref.ref().at({m, n}) = ElementSoftmax( -+ std::exp( float(view_Ref.ref().at({m, n})) - float(reference_N.at({m, 0})) ) * inv_sum -+ ); -+ } -+ } -+ -+ // Verification checks - set any of these to 'true' to override the verification checks. -+ bool verified_D = false; -+ bool verified_Softmax = false; -+ -+ // Verify softmax output -+ if (!verified_D) { -+ verified_D = verify_tensor(matrix_D, matrix_D_Ref); -+ } -+ -+ if (!verified_Softmax) { -+ verified_Softmax = verify_tensor(matrix_Softmax, matrix_Softmax_Ref); -+ } -+ -+ if (!verified_D || !verified_Softmax) { -+ -+ std::cerr << "Verification check failed for tensor Softmax at batch " << batch_idx << "\n"; -+ -+ // Summarize which checks failed -+ if (!verified_D) { -+ std::cerr << "Verification of D tensor failed\n"; -+ } -+ -+ if (!verified_Softmax) { -+ std::cerr << "Verification of Softmax tensor failed\n"; -+ } -+ -+ return false; -+ } -+ -+ } -+ -+ return true; -+ } -+ -+ /// Profiles -+ bool profile() { -+ -+ // -+ // Profile -+ // -+ -+ cutlass::Status status = cutlass::Status::kSuccess; -+ cudaError_t result; -+ cudaEvent_t events[2]; -+ int const kIterations = options.iterations; -+ -+ for (cudaEvent_t &evt : events) { -+ result = cudaEventCreate(&evt); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaEventCreate failed with error " << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ } -+ -+ result = cudaEventRecord(events[0]); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed with error " << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ -+ for (int iter = 0; iter < kIterations; ++iter) { -+ -+ status = execute_device_kernel(); -+ -+ if (status != cutlass::Status::kSuccess) { -+ std::cerr << "Device execution failed." << std::endl; -+ return false; -+ } -+ } -+ -+ result = cudaEventRecord(events[1]); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed with error " << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ -+ result = cudaDeviceSynchronize(); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "cudaDeviceSynchronize() failed with error " << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ -+ float elapsed_ms = 0; -+ result = cudaEventElapsedTime(&elapsed_ms, events[0], events[1]); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "cudaEventElapsedTime() failed with error " << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ -+ for (cudaEvent_t &evt : events) { -+ result = cudaEventDestroy(evt); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaEventDestroy() failed with error " << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ } -+ -+ int64_t flops = int64_t(options.problem_size.m()) * options.problem_size.n() * options.problem_size.k() * 2; -+ int64_t bytes = (sizeof(ElementD) * 2 + sizeof(ElementSoftmax)) * options.problem_size.m() * options.problem_size.n(); -+ -+ double gflops_per_second = double(flops) * kIterations * options.batch_count / double(elapsed_ms / 1000.0f) / double(1.0e9); -+ double gbytes_per_second = double(bytes) * kIterations * options.batch_count / double(elapsed_ms / 1000.0f) / double(1 << 30); -+ -+ double elapsed_ms_per_iter = double(elapsed_ms) / kIterations; -+ -+ std::cout << " Problem: " -+ << options.problem_size.m() << "-by-" << options.problem_size.n() << "-by-" << options.problem_size.k() -+ << ", batch size: " << options.batch_count -+ << std::endl; -+ -+ std::cout << " Runtime: " << elapsed_ms_per_iter << " ms\n" << std::endl; -+ -+ std::cout << " GFLOPs: " << gflops_per_second << " GFLOPs" << std::endl; -+ std::cout << "Memory bandwidth: " << gbytes_per_second << " GiB/s" << std::endl; -+ -+ return true; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, const char **argv) { -+ -+ // Options parsing -+ Options options; -+ options.parse(argc, argv); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ if (!options.supported()) { -+ return 0; -+ } -+ -+ // Run -+ Testbed testbed(options); -+ -+ Disposition disposition = testbed.run(); -+ -+ std::cout << std::endl; -+ -+ switch (disposition) { -+ case Disposition::kPassed: -+ std::cout << "Passed" << std::endl; -+ break; -+ case Disposition::kIncorrect: -+ std::cout << "Incorrect" << std::endl; -+ break; -+ case Disposition::kNotVerified: -+ std::cout << "Not verified" << std::endl; -+ break; -+ } -+ -+ return (disposition == Disposition::kPassed ? 0 : -1); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h b/3rdparty/cutlass/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h -new file mode 100644 -index 0000000..586c912 ---- /dev/null -+++ b/3rdparty/cutlass/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h -@@ -0,0 +1,536 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief GEMM kernel to support the epilogue visitor model -+ for customized softmax partial reduction epilogue fusion. -+ -+ This source file will likely be moved to `include/cutlass/gemm/kernel/` in the future once -+ its usage has been stabilized. For now, it is included in this example to demonstrate -+ some basic output fusion options. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/semaphore.h" -+ -+#include "cutlass/trace.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_ ///! Threadblock swizzling function -+> -+struct GemmWithEpilogueVisitor { -+public: -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueVisitor = typename Epilogue::Visitor; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using TensorRefA = TensorRef; -+ -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using TensorRefB = TensorRef; -+ -+ using ElementC = typename EpilogueVisitor::ElementOutput; -+ using LayoutC = typename Epilogue::Layout; -+ using TensorRefC = TensorRef; -+ -+ static ComplexTransform const kTransformA = Mma::kTransformA; -+ static ComplexTransform const kTransformB = Mma::kTransformB; -+ using Operator = typename Mma::Operator; -+ -+ using OperatorClass = typename Mma::Operator::OperatorClass; -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename Mma::Operator::Shape; -+ using InstructionShape = typename Mma::Policy::Operator::InstructionShape; -+ using ArchTag = typename Mma::ArchTag; -+ -+ using ElementNorm = typename EpilogueVisitor::ElementNorm; -+ using ElementSum = typename EpilogueVisitor::ElementSum; -+ -+ static int const kStages = Mma::kStages; -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ /// Split-K preserves splits that are 128b aligned -+ static int const kSplitKAlignment = const_max( -+ 128 / sizeof_bits::value, -+ 128 / sizeof_bits::value -+ ); -+ -+ // -+ // Structures -+ // -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmUniversalMode mode; -+ GemmCoord problem_size; -+ int batch_count; -+ -+ TensorRefA ref_A; -+ TensorRefB ref_B; -+ TensorRefC ref_C; -+ TensorRefC ref_D; -+ -+ ElementNorm *ptr_Max; -+ ElementSum *ptr_Sum; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ -+ typename EpilogueVisitor::Arguments epilogue_visitor; -+ -+ // -+ // Methods -+ // -+ -+ Arguments(): -+ mode(GemmUniversalMode::kGemm), -+ batch_count(1) -+ { } -+ -+ -+ /// constructs an arguments structure -+ Arguments( -+ GemmUniversalMode mode_, -+ GemmCoord problem_size_, -+ int batch_count_, -+ TensorRefA ref_A_, -+ TensorRefB ref_B_, -+ TensorRefC ref_C_, -+ TensorRefC ref_D_, -+ ElementNorm *ptr_Max_, -+ ElementSum *ptr_Sum_, -+ int64_t batch_stride_A_, -+ int64_t batch_stride_B_, -+ typename EpilogueVisitor::Arguments epilogue_visitor_ -+ ): -+ mode(mode_), -+ problem_size(problem_size_), -+ batch_count(batch_count_), -+ ref_A(ref_A_), -+ ref_B(ref_B_), -+ ref_C(ref_C_), -+ ref_D(ref_D_), -+ ptr_Max(ptr_Max_), -+ ptr_Sum(ptr_Sum_), -+ batch_stride_A(batch_stride_A_), -+ batch_stride_B(batch_stride_B_), -+ epilogue_visitor(epilogue_visitor_) -+ { -+ -+ } -+ }; -+ -+ // -+ // Structure for precomputing values in host memory and passing to kernels -+ // -+ -+ /// Parameters structure -+ struct Params { -+ -+ cutlass::gemm::GemmCoord problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ int swizzle_log_tile; -+ -+ typename Mma::IteratorA::Params params_A; -+ typename Mma::IteratorB::Params params_B; -+ typename EpilogueVisitor::OutputTileIterator::Params params_C; -+ typename EpilogueVisitor::OutputTileIterator::Params params_D; -+ -+ GemmUniversalMode mode; -+ int batch_count; -+ int gemm_k_size; -+ -+ void * ptr_A; -+ void * ptr_B; -+ ElementC * ptr_C; -+ ElementC * ptr_D; -+ -+ ElementNorm * ptr_Max; -+ ElementSum * ptr_Sum; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ -+ typename EpilogueVisitor::Params epilogue_visitor; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ swizzle_log_tile(0), -+ params_A(0), -+ params_B(0), -+ params_C(0), -+ params_D(0), -+ batch_count(0), -+ gemm_k_size(0), -+ mode(cutlass::gemm::GemmUniversalMode::kGemm), -+ ptr_A(nullptr), -+ ptr_B(nullptr), -+ ptr_C(nullptr), -+ ptr_D(nullptr), -+ ptr_Max(nullptr), -+ ptr_Sum(nullptr), -+ batch_stride_A(0), -+ batch_stride_B(0) -+ { } -+ -+ -+ Params( -+ Arguments const &args -+ ): -+ problem_size(args.problem_size), -+ swizzle_log_tile(0), -+ params_A(args.ref_A.layout()), -+ params_B(args.ref_B.layout()), -+ params_C(args.ref_C.layout()), -+ params_D(args.ref_D.layout()), -+ mode(args.mode), -+ batch_count(args.batch_count), -+ gemm_k_size(args.problem_size.k()), -+ ptr_A(args.ref_A.data()), -+ ptr_B(args.ref_B.data()), -+ ptr_C(args.ref_C.data()), -+ ptr_D(args.ref_D.data()), -+ ptr_Max(args.ptr_Max), -+ ptr_Sum(args.ptr_Sum), -+ batch_stride_A(args.batch_stride_A), -+ batch_stride_B(args.batch_stride_B), -+ epilogue_visitor(args.epilogue_visitor) -+ { -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ grid_tiled_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.batch_count); -+ -+ if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ -+ int const kAlignK = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); -+ -+ gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); -+ -+ if (gemm_k_size) { -+ grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); -+ } -+ } -+ -+ swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); -+ } -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ -+ typename Mma::SharedStorage main_loop; -+ -+ struct { -+ typename Epilogue::SharedStorage epilogue; -+ typename EpilogueVisitor::SharedStorage visitor; -+ } epilogue; -+ }; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_DEVICE -+ GemmWithEpilogueVisitor() { } -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement( -+ cutlass::gemm::GemmCoord const & problem_size) { -+ -+ CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()"); -+ -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ bool isAMisaligned = false; -+ bool isBMisaligned = false; -+ bool isCMisaligned = false; -+ -+ if (platform::is_same::value) { -+ isAMisaligned = problem_size.k() % kAlignmentA; -+ } else if (platform::is_same::value) { -+ isAMisaligned = problem_size.m() % kAlignmentA; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isAMisaligned = problem_size.k() % kAlignmentA; -+ } -+ -+ if (platform::is_same::value) { -+ isBMisaligned = problem_size.n() % kAlignmentB; -+ } else if (platform::is_same::value) { -+ isBMisaligned = problem_size.k() % kAlignmentB; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isBMisaligned = problem_size.k() % kAlignmentB; -+ } -+ -+ if (platform::is_same::value) { -+ isCMisaligned = problem_size.n() % kAlignmentC; -+ } else if (platform::is_same::value) { -+ isCMisaligned = problem_size.m() % kAlignmentC; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isCMisaligned = problem_size.n() % kAlignmentC; -+ } -+ -+ if (isAMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (isBMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (isCMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ CUTLASS_TRACE_HOST(" returning kSuccess"); -+ -+ return Status::kSuccess; -+ } -+ -+ static Status can_implement(Arguments const &args) { -+ return can_implement(args.problem_size); -+ } -+ -+ #define SPLIT_K_ENABLED 1 -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ int offset_k = 0; -+ int problem_size_k = params.problem_size.k(); -+ -+ ElementA *ptr_A = static_cast(params.ptr_A); -+ ElementB *ptr_B = static_cast(params.ptr_B); -+ -+ -+ #if SPLIT_K_ENABLED -+ // -+ // Fetch pointers based on mode. -+ // -+ if (params.mode == GemmUniversalMode::kGemm || -+ params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ -+ if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { -+ -+ problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; -+ } -+ -+ offset_k = threadblock_tile_offset.k() * params.gemm_k_size; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; -+ ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; -+ ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; -+ } -+ #endif -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ offset_k, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ offset_k, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ }; -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.params_A, -+ ptr_A, -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B( -+ params.params_B, -+ ptr_B, -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_B); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); -+ -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma( -+ gemm_k_iterations, -+ accumulators, -+ iterator_A, -+ iterator_B, -+ accumulators); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ // -+ // Construct the epilogue visitor -+ // -+ -+ EpilogueVisitor epilogue_visitor( -+ params.epilogue_visitor, -+ shared_storage.epilogue.visitor, -+ params.problem_size.mn(), -+ thread_idx, -+ warp_idx, -+ lane_idx, -+ params.params_C, -+ params.params_D, -+ params.ptr_C, -+ params.ptr_D, -+ params.ptr_Max, -+ params.ptr_Sum, -+ threadblock_offset, -+ blockIdx.y *params.problem_size.m() ); -+ -+ if (params.mode == GemmUniversalMode::kGemm) { -+ // Indicate which position in a serial reduction the output operator is currently updating -+ epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } -+ else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) { -+ epilogue_visitor.set_batch_index(threadblock_tile_offset.k()); -+ } -+ -+ // Construct the epilogue -+ Epilogue epilogue( -+ shared_storage.epilogue.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue(epilogue_visitor, accumulators); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/35_gemm_softmax/gemm_with_softmax.h b/3rdparty/cutlass/examples/35_gemm_softmax/gemm_with_softmax.h -new file mode 100644 -index 0000000..6b2fa99 ---- /dev/null -+++ b/3rdparty/cutlass/examples/35_gemm_softmax/gemm_with_softmax.h -@@ -0,0 +1,651 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+ -+*/ -+ -+#pragma once -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/arch/memory_sm75.h" -+ -+#include "cutlass/gemm/kernel/default_gemm.h" -+#include "cutlass/gemm/kernel/default_gemm_complex.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+#include "cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h" -+#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h" -+#include "cutlass/reduction/kernel/reduce_softmax_final.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include "gemm_with_epilogue_visitor.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Kernel computes partial reduction -+// -+// -+// 2. Sum[m, n'] = sum_n(exp(D[m, n] - N[m, 0])) -+// -+template < -+ typename ElementD_, -+ typename ElementNorm_, -+ typename ElementSum_, -+ typename ElementSoft_, -+ typename ElementSoftmaxCompute_, -+ int Alignment, -+ typename ApplyShape_ = MatrixShape<1, 1024> -+> -+class ApplySoftmax { -+public: -+ -+ using ElementD = ElementD_; -+ using ElementNorm = ElementNorm_; -+ using ElementSum = ElementSum_; -+ using ElementSoft = ElementSoft_; -+ using ElementSoftmaxCompute = ElementSoftmaxCompute_; -+ -+ static int const kAlignment = Alignment; -+ using ApplyShape = ApplyShape_; -+ -+ using Layout = cutlass::layout::RowMajor; -+ -+ using TensorRefD = TensorRef; -+ using TensorRefN = TensorRef; -+ using TensorRefSum = TensorRef; -+ using TensorRefSoft = TensorRef; -+ -+ using FragmentSoftmax = Array; -+ -+ // -+ // Arguments -+ // -+ -+ struct Arguments { -+ -+ MatrixCoord extent; ///< Extent of D and Softmax matrices -+ int batch_count; ///< Batch count -+ TensorRefD ref_D; ///< D matrix computed by GEMM+Max (input) -+ TensorRefN ref_N; ///< Norm tensor (input) -+ TensorRefSum ref_S; ///< Sum tensor (input) -+ TensorRefSoft ref_Soft; ///< Softmax tensor (output) -+ int64_t batch_stride_D; ///< Batch stride for D tensor -+ int64_t batch_stride_N; ///< Batch stride for N tensor -+ int64_t batch_stride_S; ///< Batch stride for S tensor -+ int64_t batch_stride_Soft; ///< Batch stride for softmax tensor -+ -+ // -+ // Methods -+ // -+ Arguments(): -+ batch_count(1), -+ batch_stride_D(0), -+ batch_stride_N(0), -+ batch_stride_S(0), -+ batch_stride_Soft(0) -+ { } -+ -+ Arguments( -+ MatrixCoord extent_, ///< Extent of D and Softmax matrices -+ int batch_count_, ///< Batch count -+ TensorRefD ref_D_, ///< D matrix computed by GEMM+PartialReduce -+ TensorRefN ref_N_, ///< Output parameter for N -+ TensorRefSum ref_S_, ///< Output parameter for N -+ TensorRefSoft ref_Soft_, ///< Softmax -+ int64_t batch_stride_D_ = 0, -+ int64_t batch_stride_N_ = 0, -+ int64_t batch_stride_S_ = 0, -+ int64_t batch_stride_Soft_ = 0 -+ ): -+ extent(extent_), -+ batch_count(batch_count_), -+ ref_D(ref_D_), -+ ref_N(ref_N_), -+ ref_S(ref_S_), -+ ref_Soft(ref_Soft_), -+ batch_stride_D(batch_stride_D_), -+ batch_stride_N(batch_stride_N_), -+ batch_stride_S(batch_stride_S_), -+ batch_stride_Soft(batch_stride_Soft_) -+ { -+ -+ } -+ }; -+ -+ // -+ // Params struct -+ // -+ -+ struct Params { -+ Arguments args; -+ -+ // -+ // Methods -+ // -+ Params() { } -+ -+ Params(Arguments const &args_): args(args_) { } -+ }; -+ -+ // -+ // SharedStorage -+ // -+ -+ struct SharedStorage { -+ -+ }; -+ -+private: -+ -+public: -+ -+ CUTLASS_DEVICE -+ ApplySoftmax() { } -+ -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ apply(params, shared_storage); -+ } -+ -+private: -+ -+ -+ /// Compute Softmax -+ CUTLASS_DEVICE -+ void apply(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ using AccessTypeD = AlignedArray; -+ -+ int block_batch = blockIdx.z; -+ int block_m = blockIdx.x * ApplyShape::kRow; -+ int block_n = 0; -+ -+ int thread_m = threadIdx.y; -+ int thread_n = threadIdx.x * kAlignment; -+ -+ int idx_m = block_m + thread_m; -+ int idx_n = block_n + thread_n; -+ -+ int batch_offset_norm = block_batch * params.args.batch_stride_N; -+ int batch_offset_sum = block_batch * params.args.batch_stride_S; -+ -+ // Kill off thread if it is outside the row boundary -+ if (params.args.extent.row() <= idx_m) { -+ return; -+ } -+ -+ // -+ // Setup pointers to load D again -+ // -+ -+ using AccessTypeD = AlignedArray; -+ using AccessTypeSoft = AlignedArray; -+ using FragmentSoft = Array; -+ using ConvertSoftCompute = cutlass::NumericArrayConverter; -+ using ConvertSoftOutput = cutlass::NumericArrayConverter; -+ -+ using Mul = cutlass::multiplies; -+ using Minus = cutlass::minus; -+ using Exp = cutlass::fast_exp_op; -+ -+ ConvertSoftCompute convert_soft_compute; -+ ConvertSoftOutput convert_soft_output; -+ -+ Minus minus; -+ Mul mul; -+ Exp exponential; -+ -+ using ConvertSum = cutlass::NumericConverter; -+ using ConvertNorm = cutlass::NumericConverter; -+ -+ ConvertSum convert_sum; -+ ConvertNorm convert_norm; -+ -+ AccessTypeD *access_d = reinterpret_cast( -+ params.args.ref_D.data() + -+ params.args.batch_stride_D * block_batch + -+ params.args.ref_D.layout()({idx_m, idx_n})); -+ -+ AccessTypeSoft *access_soft = reinterpret_cast( -+ params.args.ref_Soft.data() + -+ params.args.batch_stride_Soft * block_batch + -+ params.args.ref_Soft.layout()({idx_m, idx_n})); -+ -+ ElementSum inv_sum = (params.args.ref_S.data())[idx_m + batch_offset_sum]; -+ ElementNorm norm = (params.args.ref_N.data())[idx_m + batch_offset_norm]; -+ -+ // -+ // Loop -+ // -+ CUTLASS_PRAGMA_UNROLL -+ for ( -+ int idx = 0; -+ idx < params.args.extent.column(); -+ idx += ApplyShape::kColumn * kAlignment) { -+ -+ if (idx_n < params.args.extent.column()) { -+ AccessTypeD fetch; -+ arch::global_load(fetch, access_d, true); -+ -+ FragmentSoftmax result = mul(exponential(minus(convert_soft_compute(fetch), convert_norm(norm))), convert_sum(inv_sum)); -+ FragmentSoft soft = convert_soft_output(result); -+ -+ arch::global_store(soft, access_soft, true); -+ } -+ -+ access_d += ApplyShape::kColumn; -+ access_soft += ApplyShape::kColumn; -+ idx_n += ApplyShape::kColumn * kAlignment; -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// -+template < -+ typename ElementA_, -+ typename LayoutA_, -+ typename ElementB_, -+ typename LayoutB_, -+ typename ElementC_, -+ typename ElementCompute_, -+ typename OperatorClass_, -+ typename ArchTag_, -+ typename ThreadblockShape_, -+ typename WarpShape_, -+ typename InstructionShape_, -+ typename EpilogueFunctorOp_, -+ int kStages_, -+ typename ApplyShape_ = MatrixShape<1, 1024>, -+ int AlignmentA_ = 128 / cutlass::sizeof_bits::value, -+ int AlignmentB_ = 128 / cutlass::sizeof_bits::value, -+ int AlignmentSoftmax_ = 128 / cutlass::sizeof_bits::value, -+ typename ElementNorm_ = float, -+ typename ElementSum_ = float, -+ typename ElementSoftmax_ = ElementC_ -+> -+class GemmSoftmax { -+public: -+ -+ /////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ // -+ // Type definitions -+ // -+ -+ using ElementA = ElementA_; -+ using ElementB = ElementB_; -+ using ElementC = ElementC_; -+ using ElementCompute = ElementCompute_; -+ using ElementSum = ElementSum_; -+ using ElementSoft = ElementSoftmax_; -+ using ElementSoftmaxCompute = float; -+ -+ using LayoutA = LayoutA_; -+ using LayoutB = LayoutB_; -+ -+ using EpilogueFunctorOp = EpilogueFunctorOp_; -+ using ElementNorm = ElementNorm_; -+ -+ using ApplyShape = ApplyShape_; -+ -+ // These are mandatory layouts. -+ using LayoutC = cutlass::layout::RowMajor; -+ using LayoutN = cutlass::layout::RowMajor; -+ using LayoutS = cutlass::layout::RowMajor; -+ using LayoutSoft = cutlass::layout::RowMajor; -+ -+ using TensorRefA = TensorRef; -+ using TensorRefB = TensorRef; -+ using TensorRefC = TensorRef; -+ using TensorRefN = TensorRef; -+ using TensorRefSum = TensorRef; -+ using TensorRefSoft = TensorRef; -+ -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ -+ static int const kStages = kStages_; -+ static int const AlignmentA = AlignmentA_; -+ static int const AlignmentB = AlignmentB_; -+ static int const AlignmentSoftmax = AlignmentSoftmax_; -+ -+ using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle; -+ -+ /////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ // basic GEMM kernel -+ using DefaultGemmKernel = typename cutlass::gemm::kernel::DefaultGemm< -+ ElementA, -+ LayoutA, -+ AlignmentA, -+ ElementB, -+ LayoutB, -+ AlignmentB, -+ ElementC, -+ LayoutC, -+ ElementCompute, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueFunctorOp, -+ ThreadblockSwizzle, -+ kStages, -+ true, -+ typename cutlass::gemm::device::DefaultGemmConfiguration< -+ OperatorClass, ArchTag, ElementA, ElementB, ElementC, ElementCompute>::Operator, -+ cutlass::gemm::SharedMemoryClearOption::kNone -+ >::GemmKernel; -+ -+ /////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ // Epilogue visitor -+ using EpilogueVisitor = typename cutlass::epilogue::threadblock::EpilogueVisitorSoftmax< -+ ThreadblockShape, -+ DefaultGemmKernel::kThreadCount, -+ typename DefaultGemmKernel::Epilogue::OutputTileIterator, -+ ElementCompute, -+ ElementNorm, -+ ElementSum, -+ ElementSoftmaxCompute, -+ EpilogueFunctorOp -+ >; -+ -+ /// Epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock::EpilogueWithVisitorFromExistingEpilogue< -+ EpilogueVisitor, -+ typename DefaultGemmKernel::Epilogue -+ >::Epilogue; -+ -+ // GEMM -+ using GemmKernel = gemm::kernel::GemmWithEpilogueVisitor< -+ typename DefaultGemmKernel::Mma, -+ Epilogue, -+ ThreadblockSwizzle -+ >; -+ -+ // Softmax kernel -+ using SoftmaxApplyKernel = kernel::ApplySoftmax< -+ ElementC, -+ ElementNorm, -+ ElementSum, -+ ElementSoft, -+ ElementSoftmaxCompute, -+ AlignmentSoftmax, -+ ApplyShape -+ >; -+ -+ using ApplyFinalReductionKernel = cutlass::reduction::kernel::ApplySoftmaxFinalReduction< -+ ElementNorm, -+ ElementSum, -+ ElementSoftmaxCompute, -+ ThreadblockShape -+ >; -+ -+public: -+ -+ /// Arguments class -+ struct Arguments { -+ -+ typename GemmKernel::Arguments gemm; -+ typename SoftmaxApplyKernel::Arguments softmax; -+ typename ApplyFinalReductionKernel::Arguments reduction; -+ cutlass::gemm::GemmCoord extend; -+ -+ // -+ // Methods -+ // -+ Arguments() { } -+ -+ Arguments( -+ cutlass::gemm::GemmCoord problem_size, -+ int32_t batch_count_, -+ TensorRefA ref_A_, -+ TensorRefB ref_B_, -+ TensorRefC ref_C_, -+ TensorRefC ref_D_, -+ typename EpilogueFunctorOp::Params linear_scaling, -+ TensorRefN ref_N_, -+ TensorRefSum ref_S_, -+ TensorRefSoft ref_Softmax_, -+ int64_t batch_stride_A_ = 0, -+ int64_t batch_stride_B_ = 0, -+ int64_t batch_stride_C_ = 0, -+ int64_t batch_stride_D_ = 0, -+ int64_t batch_stride_Max_ = 0, -+ int64_t batch_stride_Sum_ = 0, -+ int64_t batch_stride_Softmax_ = 0 -+ ): -+ gemm( -+ cutlass::gemm::GemmUniversalMode::kBatched, -+ problem_size, -+ batch_count_, -+ ref_A_, -+ ref_B_, -+ ref_C_, -+ ref_D_, -+ ref_N_.data(), -+ ref_S_.data(), -+ batch_stride_A_, -+ batch_stride_B_, -+ typename EpilogueVisitor::Arguments( -+ linear_scaling, -+ batch_stride_C_, -+ batch_stride_D_, -+ batch_stride_Max_, -+ batch_stride_Sum_ -+ ) -+ ), -+ reduction( -+ problem_size, -+ ref_N_.data(), -+ ref_S_.data(), -+ batch_stride_Max_, -+ batch_stride_Sum_ -+ ), -+ softmax( -+ MatrixCoord(problem_size.m(), problem_size.n()), -+ batch_count_, -+ ref_D_, -+ ref_N_, -+ ref_S_, -+ ref_Softmax_, -+ batch_stride_D_, -+ batch_stride_Max_, -+ batch_stride_Sum_, -+ batch_stride_Softmax_ -+ ), -+ extend(problem_size) -+ { -+ -+ } -+ }; -+ -+ struct Params { -+ -+ typename GemmKernel::Params gemm; -+ typename SoftmaxApplyKernel::Params softmax; -+ typename ApplyFinalReductionKernel::Params reduction; -+ MatrixCoord extend; -+ // -+ // Methods -+ // -+ Params() { } -+ -+ Params(Arguments const &args): -+ gemm(args.gemm), -+ reduction(args.reduction), -+ softmax(args.softmax), -+ extend(MatrixCoord(args.extend.m(), args.extend.n())) -+ { -+ -+ } -+ }; -+ -+public: -+ -+ // Gemm -+ -+ -+ // -+ // Methods -+ // -+ -+private: -+ -+ Params params_; -+ -+public: -+ -+ /// Ctor -+ GemmSoftmax() { -+ -+ } -+ -+ /// Initialize -+ Status initialize(Arguments const &args) { -+ -+ params_ = Params(args); -+ -+ return cutlass::Status::kSuccess; -+ } -+ -+ /// Run -+ Status run(cudaStream_t stream) { -+ -+ // -+ // Launch the GEMM + max kernel -+ // -+ -+ dim3 gemm_grid = ThreadblockSwizzle().get_grid_shape(params_.gemm.grid_tiled_shape); -+ dim3 gemm_block(GemmKernel::kThreadCount, 1, 1); -+ -+ int gemm_smem_size = int(sizeof(typename GemmKernel::SharedStorage)); -+ -+ cutlass::Kernel<<>>(params_.gemm); -+ -+ cudaError_t result = cudaGetLastError(); -+ -+ if (result != cudaSuccess) { -+ return cutlass::Status::kErrorInternal; -+ } -+ -+ -+ // -+ // Launch the ApplyFinalReductionKernel -+ // -+ -+ int thread_per_block = 128; -+ int block_per_row = (params_.extend.row() + thread_per_block - 1) / thread_per_block; -+ if (block_per_row < 4) { -+ thread_per_block = 32; -+ block_per_row = (params_.extend.row() + thread_per_block - 1) / thread_per_block; -+ } -+ -+ dim3 final_reduction_grid(block_per_row, 1, params_.softmax.args.batch_count); -+ dim3 final_reduction_block(thread_per_block); -+ -+ Kernel<<< -+ final_reduction_grid, final_reduction_block, sizeof(typename ApplyFinalReductionKernel::SharedStorage), stream -+ >>>(params_.reduction); -+ -+ result = cudaGetLastError(); -+ -+ if (result != cudaSuccess) { -+ return cutlass::Status::kErrorInternal; -+ } -+ -+ // -+ // Launch the SoftmaxApplyKernel -+ // -+ -+ dim3 apply_block(SoftmaxApplyKernel::ApplyShape::kColumn, SoftmaxApplyKernel::ApplyShape::kRow); -+ -+ int threadblock_rows = SoftmaxApplyKernel::ApplyShape::kRow; -+ int threadblock_columns = SoftmaxApplyKernel::ApplyShape::kColumn * SoftmaxApplyKernel::kAlignment; -+ -+ dim3 apply_grid( -+ (params_.softmax.args.extent.row() + threadblock_rows - 1) / threadblock_rows, -+ (params_.softmax.args.extent.column() + threadblock_columns - 1) / threadblock_columns, -+ params_.softmax.args.batch_count); -+ -+ Kernel<<< -+ apply_grid, apply_block, sizeof(typename SoftmaxApplyKernel::SharedStorage), stream -+ >>>(params_.softmax); -+ -+ result = cudaGetLastError(); -+ -+ if (result != cudaSuccess) { -+ return cutlass::Status::kErrorInternal; -+ } -+ -+ return cutlass::Status::kSuccess; -+ } -+ -+ /// Function call operator -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu b/3rdparty/cutlass/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu -new file mode 100644 -index 0000000..3ae92c3 ---- /dev/null -+++ b/3rdparty/cutlass/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu -@@ -0,0 +1,541 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+// This example fuses gather before GEMM and scatter after GEMM into the same -+// GEMM kernel. Gather and scatter operation is controled by an index vector -+// to select rows or columns from A, B, C or D matrices. -+// -+// Suppose, all matrices are column major. The pseudo code of the fused kernel -+// in this example is essentially -+// -+// for (int i = 0; i < problem_size.m(); ++i) { -+// for (int j = 0; j < options.index_size; ++j) { -+// int b_c_d_col = tensor_indices.at({j, 0}); -+// -+// for (int k = 0; k < options.index_size; ++k) { -+// tensor_d_ref.at({i, b_c_d_col}) += -+// alpha * tensor_a.at({i, k}) * tensor_b.at({k, b_c_d_col}); -+// } -+// } -+// -+// Note that the index vector contains unique random integers with max to be N - 1 -+// -+// The gather/scatter operation works best when we can still keep the biggest -+// alignment. For example, when the matrix is row major, we select rows. When -+// the matrix is column major, we select columns. -+// -+// Not all the combination of gather and scatter are legal. For example, if A is -+// row major and C/D is column major, we cannot gather A and scatter C/D at the -+// same time. -+// -+// Also, we don't check the index value is legal and index array point is valid -+// for the sake of the performance. -+ -+#include -+#include -+#include -+#include -+#include -+#include -+ -+#include -+#include -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_universal.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "helper.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Result structure -+struct Result { -+ -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cudaError_t error; -+ bool passed; -+ -+ // -+ // Methods -+ // -+ -+ Result( -+ double runtime_ms = 0, -+ double gflops = 0, -+ cutlass::Status status = cutlass::Status::kSuccess, -+ cudaError_t error = cudaSuccess -+ ): -+ runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ -+ cutlass::gemm::GemmCoord problem_size; -+ int index_size; -+ -+ bool reference_check; -+ int iterations; -+ -+ Options(): -+ help(false), -+ problem_size({248, 1024, 1024}), -+ index_size(240), -+ reference_check(true), -+ iterations(20) { } -+ -+ bool valid() { -+ return true; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ cmd.get_cmd_line_argument("m", problem_size.m()); -+ cmd.get_cmd_line_argument("n", problem_size.n()); -+ cmd.get_cmd_line_argument("k", problem_size.k()); -+ -+ cmd.get_cmd_line_argument("index_size", index_size); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "36_gather_scatter_fusion example\n\n" -+ << " This example uses the CUTLASS Library to fuse gather/scatter into GEMM\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --m= GEMM M dimension\n" -+ << " --n= GEMM N dimension\n" -+ << " --k= GEMM K dimension\n" -+ << " --index_size= size of N dimension index\n" -+ << " --iterations= Number of profiling iterations to perform.\n\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/36_gather_scatter_fusion/36_gather_scatter_fusion --m=1024 --n=512 --k=1024 \\\n" -+ << " --index_size=128\n\n"; -+ -+ return out; -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of real-valued multiply-adds -+ int64_t fmas = problem_size.product(); -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// The code section below describes datatype for input, output matrices and computation between -+// elements in input matrices. -+using ElementAccumulator = float; // <- data type of accumulator -+using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations -+using ElementInputA = cutlass::half_t; // <- data type of elements in input matrix A -+using ElementInputB = cutlass::half_t; // <- data type of elements in input matrix B -+using ElementOutput = float; // <- data type of elements in output matrix D -+ -+// The code section below describes matrix layout of input and output matrices. -+// Column Major for Matrix A, B and C. -+// -+using LayoutInputA = cutlass::layout::ColumnMajor; -+using LayoutInputB = cutlass::layout::ColumnMajor; -+using LayoutOutput = cutlass::layout::ColumnMajor; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm80; -+ -+// This code section describes the tile size a thread block will compute -+using ShapeMMAThreadBlock = -+ cutlass::gemm::GemmShape<128, 128, 32>; // <- threadblock tile M = 128, N = 128, K = 32 -+// This code section describes tile size a warp will compute -+using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M = 64, N = 64, K = 32 -+// This code section describes the size of MMA op -+using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 16>; // <- MMA Op tile M = 8, N = 8, K = 4 -+// 16, 8, 8 -> Turing -+// 16, 8, 16 -> Ampere -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? -+ -+// Define the epilogue operation as LinearCombination. This is approximately equal to -+// -+// d_ij = alpha * sum_k(a_ik * b_kj) + c_ij -+// -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // <- data type of output matrix -+ 128 / cutlass::sizeof_bits::value, // <- this is the number of elements per -+ // vectorized memory access. For half -+ // precision, it's 8 elements. This becomes -+ // the vector width of math instructions in -+ // epilogue too -+ ElementAccumulator, // <- data type of accumulator -+ ElementComputeEpilogue>; // <- data type for alpha in linear combination function -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 5; -+// Ampere -> 4/5 -+// Turing -> 2 -+ -+using Gemm = cutlass::gemm::device::GemmUniversal; -+ -+int run(Options &options) { -+ -+ // ================================================================================ -+ // Initialization setup -+ -+ // Create a tuple of problem size for matrix multiplication -+ cutlass::gemm::GemmCoord problem_size = options.problem_size; -+ -+ // Create a tuple of problem size for matrix multiplication -+ cutlass::gemm::GemmCoord problem_size_real(problem_size.m(), -+ options.index_size, -+ problem_size.k()); -+ -+ // Initialize tensors using CUTLASS helper functions -+ cutlass::HostTensor tensor_a( -+ problem_size.mk()); // <- Create matrix A with dimensions M x K -+ cutlass::HostTensor tensor_b( -+ problem_size.kn()); // <- Create matrix B with dimensions K x N -+ cutlass::HostTensor tensor_c( -+ problem_size.mn()); // <- Create matrix C with dimensions M x N -+ cutlass::HostTensor tensor_d_scattered( -+ problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from -+ // CUTLASS kernel -+ -+ // Fill input and output matrices on host using CUTLASS helper functions -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1, -+ ElementInputA(7), -+ ElementInputA(-8), -+ 0); // <- Fill matrix A on host with uniform-distribution random data -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 1, -+ ElementInputA(7), -+ ElementInputA(-8), -+ 0); // <- Fill matrix B on host with uniform-distribution random data -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c.host_view(), -+ 1, -+ ElementOutput(7), -+ ElementOutput(-8), -+ 0); // <- Fill matrix C on host with uniform-distribution random data -+ -+ cutlass::reference::host::TensorFill( -+ tensor_d_scattered.host_view()); // <- fill matrix D on host with zeros -+ -+ cutlass::HostTensor tensor_indices( -+ {options.index_size, 1}); // <- Create scatter indices with dimensions val_len x 1 -+ -+ // <- Fill tensor_b_indices on host with unique random integers -+ std::vector to_fill(problem_size.n()) ; // vector with ints. -+ std::iota (std::begin(to_fill), std::end(to_fill), 0); // Fill with 0, 1, ...., problem_size.n() -+ std::random_shuffle(to_fill.begin(), to_fill.end()); -+ memcpy(tensor_indices.host_data(), to_fill.data(), options.index_size * sizeof(int)); -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_indices.sync_device(); -+ tensor_c.sync_device(); -+ tensor_d_scattered.sync_device(); -+ -+ // Initialize alpha/beta for dot product computation -+ ElementComputeEpilogue alpha = ElementComputeEpilogue(1); -+ ElementComputeEpilogue beta = ElementComputeEpilogue(1); -+ -+ // Split K dimension into 1 partitions -+ int split_k_slices = 1; -+ -+ // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch -+ // instantiated CUTLASS kernel -+ typename Gemm::Arguments arguments{ -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ problem_size_real, // <- problem size of matrix multiplication -+ split_k_slices, // <- k-dimension split factor -+ {alpha, beta}, // <- alpha, beta -+ tensor_a.device_data(), // <- reference to matrix A on device -+ tensor_b.device_data(), // <- reference to matrix B on device -+ tensor_c.device_data(), // <- reference to matrix C on device -+ tensor_d_scattered.device_data(), // <- reference to matrix D on device -+ tensor_a.layout().capacity(problem_size.mk()), -+ tensor_b.layout().capacity(cutlass::make_Coord(options.index_size, problem_size.n())), -+ tensor_c.layout().capacity(problem_size.mn()), -+ tensor_d_scattered.layout().capacity(problem_size.mn()), -+ tensor_a.layout().stride(), -+ tensor_b.layout().stride(), -+ tensor_c.layout().stride(), -+ tensor_d_scattered.layout().stride(), -+ nullptr, // <- pointer to index vector to gather A on device -+ tensor_indices.device_data(), // <- pointer to index vector to gather B on device -+ tensor_indices.device_data()}; // <- pointer to index vector to scatter D on device -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ // Instantiate CUTLASS kernel depending on templates -+ Gemm gemm_op; -+ -+ // Check the problem size is supported or not -+ cutlass::Status status = gemm_op.can_implement(arguments); -+ CUTLASS_CHECK(status); -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ status = gemm_op.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(status); -+ -+ // CPU reference calculation -+ cutlass::HostTensor tensor_d_ref(problem_size.mn()); -+ cutlass::reference::host::TensorFill( -+ tensor_d_ref.host_view()); // <- Fill matrix D on host with zeros -+ -+ status = gemm_op(); -+ cudaDeviceSynchronize(); -+ CUTLASS_CHECK(status); -+ -+ if (options.reference_check) { -+ for (int i = 0; i < problem_size.m(); ++i) { -+ for (int j = 0; j < options.index_size; ++j) { -+ int b_c_d_col = tensor_indices.at({j, 0}); -+ -+ for (int k = 0; k < problem_size.k(); ++k) { -+ tensor_d_ref.at({i, b_c_d_col}) += -+ alpha * tensor_a.at({i, k}) * tensor_b.at({k, b_c_d_col}); -+ } -+ -+ tensor_d_ref.at({i, b_c_d_col}) += (beta * tensor_c.at({i, b_c_d_col})); -+ } -+ } -+ -+ // Copy output data from CUTLASS and reference kernel to host for comparison -+ tensor_d_scattered.sync_host(); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_d_scattered.host_view(), -+ tensor_d_ref.host_view()); -+ -+ if (!passed) { -+ std::cout << "Failed!\n"; -+ -+ std::stringstream fname; -+ fname << "error_gather_GEMM_scatter_fusion.txt"; -+ std::cerr << "Dumping results in " << fname.str() << "\n"; -+ -+ std::ofstream file(fname.str()); -+ -+ file -+ << "A =\n" << tensor_a.host_view() -+ << "\nB =\n" << tensor_b.host_view() -+ << "\nindices =\n" << tensor_indices.host_view() -+ << "\nC =\n" << tensor_c.host_view() -+ << "\n\nReference =\n" << tensor_d_ref.host_view() -+ << "\nComputed =\n" << tensor_d_scattered.host_view(); -+ return -1; -+ } else { -+ std::cout << "Passed!\n"; -+ } -+ } -+ -+ // Result structure -+ Result result; -+ -+ // -+ // Construct events -+ // -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ } -+ -+ // Record an event at the start of a series of GEMMs -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ -+ // -+ // Run profiling loop -+ // -+ -+ for (int iter = 0; iter < options.iterations; ++iter) { -+ // Launch initialized CUTLASS kernel -+ status = gemm_op(); -+ CUTLASS_CHECK(status); -+ } -+ -+ // -+ // Stop profiling loop -+ // -+ -+ // Record an event when the GEMMs are complete -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ -+ // Compute average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // Cleanup -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ std::cout << "Runtime: " << result.runtime_ms << " ms\n"; -+ std::cout << " GFLOPs: " << result.gflops << "\n"; -+ -+ return 0; -+} -+ -+int main(int argc, const char ** argv) { -+ bool notSupported = false; -+ -+ // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. -+ // -+ // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples. -+ if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); -+ -+ if (!(props.major >= 8)) { -+ std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80." -+ << std::endl; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ return 0; -+ } -+ -+ Options options; -+ options.parse(argc, argv); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << "\n"; -+ return 0; -+ } -+ -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << "\n"; -+ return -1; -+ } -+ -+ return run(options); -+} -diff --git a/3rdparty/cutlass/examples/37_gemm_layernorm_gemm_fusion/gemm_layernorm.cu b/3rdparty/cutlass/examples/37_gemm_layernorm_gemm_fusion/gemm_layernorm.cu -new file mode 100644 -index 0000000..ffe378b ---- /dev/null -+++ b/3rdparty/cutlass/examples/37_gemm_layernorm_gemm_fusion/gemm_layernorm.cu -@@ -0,0 +1,937 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief CUTLASS Layernorm Example. -+ -+ This workload provides a layer normalization example using a one-pass, square-sum-based -+ variance calculation. Specifically, we fuse the reduction operation to find -+ local mean and local square sum mean in the epilogue of 1st GEMM. After a light -+ full reduction kernel, the mean / variance values are readily calculated for element-wise -+ operations which are fused into the 2nd GEMM. -+ -+ As stated in https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data, -+ the square-sum based one-pass implementation may raise concerns on numerical stability issues. -+ That being said, though this fully fused layernorm example almost perfectly hides all the memory cost to -+ access the intermediate matrix for layernorm computation, the numerical issue might hinder a persuasive -+ usage in real-world scenarios. If that is the case, a user may turn to the stand-alone CUTLASS layernorm -+ example in tools/util/include/cutlass/util/device_layernorm.h -+ -+ Examples: -+ -+ # Run a CUTLASS layernorm example with default setup , -+ # using the language of the transformer model as an example, -+ (Column Major output matrix, hidden dimension = 768, valid word number = 4096, intermediate_scale = 4) -+ $ ./examples/37_gemm_layernorm_gemm_fusion/37_gemm_layernorm_gemm_fusion -+ -+ # Run an attention example with hidden dimension = 512 -+ $ ./examples/37_gemm_layernorm_gemm_fusion/37_gemm_layernorm_gemm_fusion --hidden_dim=512 -+ -+*/ -+ -+#include -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/gemm/device/gemm_complex.h" -+#include "cutlass/epilogue/thread/scale_type.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/gemm_complex.h" -+#include "cutlass/util/reference/host/tensor_reduce.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/error_metrics.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/fast_math.h" -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include "gemm_with_layernorm.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+enum class Disposition { -+ kPassed, -+ kIncorrect, -+ kNotVerified -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+template -+struct Options { -+ -+ using LayoutOutput = LayoutOutput_; -+ -+ static bool const kIsColumnMajorOutput = cutlass::platform::is_same::value; -+ -+ bool help; -+ cutlass::gemm::GemmCoord problem_size0; -+ cutlass::gemm::GemmCoord problem_size1; -+ int hidden_dim; -+ int valid_word_num; -+ int intermediate_scale; -+ int iterations; -+ unsigned seed; -+ float alpha; -+ float beta; -+ bool verification_enabled; -+ double tolerance; -+ -+ Options(): -+ help(false), -+ iterations(20), -+ seed(2022), -+ hidden_dim(768), -+ valid_word_num(4096), -+ intermediate_scale(4), -+ alpha(1), -+ beta(0), -+ verification_enabled(true), -+ tolerance(0.01), -+ problem_size1(problem_size0.m() * 4, problem_size0.n(), problem_size0.m()) -+ { } -+ -+ bool valid() { -+ -+ return true; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ cmd.get_cmd_line_argument("hidden_dim", hidden_dim, 768); -+ cmd.get_cmd_line_argument("valid_word_num", valid_word_num, 4096); -+ cmd.get_cmd_line_argument("iterations", iterations); -+ cmd.get_cmd_line_argument("verify", verification_enabled); -+ cmd.get_cmd_line_argument("seed", seed); -+ cmd.get_cmd_line_argument("tolerance", tolerance); -+ -+ if (kIsColumnMajorOutput) { -+ // column major output setup -+ problem_size0.m() = hidden_dim; -+ problem_size0.n() = valid_word_num; -+ problem_size0.k() = hidden_dim; -+ -+ problem_size1.m() = hidden_dim * intermediate_scale; -+ problem_size1.n() = valid_word_num; -+ problem_size1.k() = hidden_dim; -+ }else{ -+ // row major output setup -+ problem_size0.m() = valid_word_num; -+ problem_size0.n() = hidden_dim; -+ problem_size0.k() = hidden_dim; -+ -+ problem_size1.m() = valid_word_num; -+ problem_size1.n() = hidden_dim * intermediate_scale; -+ problem_size1.k() = hidden_dim; -+ } -+ -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "37_gemm_layernorm_gemm_fusion example\n\n" -+ << " This example uses the CUTLASS Library to compute GEMM + Layernorm for arbitrary problem sizes.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --hidden_dim= Hidden dimension\n" -+ << " --valid_word_num= Valid word number\n" -+ << " --seed= Random number seed (1*)\n\n" -+ << " --iterations= Number of profiling iterations to perform (0 to disable profiling).\n\n" -+ << " --verify= If true, performs reference calculation.\n\n" -+ << " --tolerance Error tolerance\n" -+ ; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/37_gemm_layernorm_gemm_fusion/37_gemm_layernorm_gemm_fusion \\\n" -+ << " --hidden_dim=768 --valid_word_num=1024 \n\n"; -+ -+ return out; -+ } -+ -+ /// Returns true if the environment and Toolkit support this -+ bool supported(bool verbose = true) const { -+ -+ // Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available -+ // in CUDA 11.0. -+ // -+ // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples. -+ if (!(__CUDACC_VER_MAJOR__ >= 11)) { -+ if (verbose) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ } -+ return false; -+ } -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ if (verbose) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ } -+ return false; -+ } -+ -+ if (!((props.major * 10 + props.minor) >= 80)) { -+ if (verbose) { -+ std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80." -+ << std::endl; -+ } -+ return false; -+ } -+ -+ // -+ // CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently, -+ // all pointers, strides, and tensor extents must be divisible by 8 elements. -+ // -+ int const kAlignment = 8; -+ -+ if ((problem_size0.m() % kAlignment) || -+ (problem_size0.n() % kAlignment) || -+ (problem_size0.k() % kAlignment)) { -+ if (verbose) { -+ std::cerr << "Misaligned input in 1st GEMM." << std::endl; -+ } -+ // misaligned tensors for Gemm1 -+ return false; -+ } -+ -+ if ((problem_size1.m() % kAlignment) || -+ (problem_size1.n() % kAlignment) || -+ (problem_size1.k() % kAlignment)) { -+ if (verbose) { -+ std::cerr << "Misaligned input in 2nd GEMM." << std::endl; -+ } -+ // misaligned tensors for Gemm2 -+ return false; -+ } -+ -+ return true; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template< -+ typename LayoutOutput_> -+struct Testbed { -+ -+ // -+ // Type definitions -+ // -+ -+ // User-defined data types -+ using ElementInputA0 = cutlass::half_t; -+ using ElementInputB0 = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ using LayoutInputA0 = cutlass::layout::RowMajor; -+ using LayoutInputB0 = cutlass::layout::ColumnMajor; -+ using LayoutOutput = LayoutOutput_; -+ -+ static bool const kIsColumnMajorOutput = cutlass::platform::is_same::value; -+ // turn of shifted K by default -+ static bool const kIsShiftedVariance = false; -+ -+ /// Linear scaling operator -+ using EpilogueFunctorOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementCompute, -+ ElementCompute -+ >; -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kStages0 = 3; -+ static int const kStages1 = 4; -+ -+ using GemmLayernorm = cutlass::GemmLayernorm< -+ ElementInputA0, -+ LayoutInputA0, -+ ElementInputB0, -+ LayoutInputB0, -+ ElementOutput, -+ LayoutOutput, -+ ElementCompute, -+ EpilogueFunctorOp, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ kStages0, -+ kStages1, -+ kIsShiftedVariance -+ >; -+ -+ using ElementInputA1 = typename GemmLayernorm::ElementInputA1; -+ using ElementOutputC1 = typename GemmLayernorm::ElementOutputC1; -+ using ElementInputScaleBias = typename GemmLayernorm::ElementInputScaleBias; -+ using ElementLayernormCompute = typename GemmLayernorm::ElementLayernormCompute; -+ -+ using LayoutInputA1 = typename GemmLayernorm::LayoutInputA1; -+ using LayoutOutputC0 = typename GemmLayernorm::LayoutOutputC0; -+ using LayoutOutputC1 = typename GemmLayernorm::LayoutOutputC1; -+ using LayoutInputScaleBias = typename GemmLayernorm::LayoutInputScaleBias; -+ -+ // -+ // Data members -+ // -+ -+ Options const &options; -+ -+ cutlass::HostTensor tensor_A0; -+ cutlass::HostTensor tensor_B0; -+ cutlass::HostTensor tensor_C0; -+ cutlass::HostTensor tensor_A1; -+ cutlass::HostTensor tensor_C1; -+ -+ cutlass::HostTensor reference_C0; -+ cutlass::HostTensor reference_C1; -+ -+ cutlass::HostTensor tensor_Variance; -+ cutlass::HostTensor tensor_Mean; -+ cutlass::HostTensor tensor_Beta; -+ cutlass::HostTensor tensor_Gamma; -+ -+ cutlass::HostTensor reference_Mean; -+ cutlass::HostTensor reference_Variance; -+ -+ // shifted K tensor to better ensure the numerical stability -+ // According to https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance -+ // the closer shifted K to the actual mean, the better numerical stability we'll observe -+ cutlass::HostTensor tensor_Shifted_K; -+ -+ // -+ // Methods -+ // -+ -+ Testbed( -+ Options const &options_ -+ ): -+ options(options_) -+ { -+ -+ tensor_A0.reset({options.problem_size0.m(), options.problem_size0.k()}); -+ tensor_B0.reset({options.problem_size0.k(), options.problem_size0.n()}); -+ -+ tensor_C0.reset({options.problem_size0.m(), options.problem_size0.n()}); -+ -+ tensor_A1.reset({options.problem_size1.m(), options.problem_size1.k()}); -+ tensor_C1.reset({options.problem_size1.m(), options.problem_size1.n()}); -+ -+ reference_C0.reset({options.problem_size0.m(), options.problem_size0.n()}); -+ reference_C1.reset({options.problem_size1.m(), options.problem_size1.n()}); -+ -+ int leading_dim_0 = kIsColumnMajorOutput ? options.problem_size0.n() : options.problem_size0.m(); -+ int leading_dim_1 = kIsColumnMajorOutput ? options.problem_size0.m() : options.problem_size0.n(); -+ -+ int block_num = (leading_dim_1 + GemmLayernorm::ThreadblockShape::kM - 1) / GemmLayernorm::ThreadblockShape::kM; -+ -+ tensor_Variance.reset({block_num, leading_dim_0}); -+ tensor_Mean.reset({block_num, leading_dim_0}); -+ tensor_Shifted_K.reset({1, leading_dim_0}); -+ -+ tensor_Beta.reset({1, leading_dim_1}); -+ tensor_Gamma.reset({1, leading_dim_1}); -+ -+ reference_Mean.reset({1, leading_dim_0}, false); -+ reference_Variance.reset({1, leading_dim_0}, false); -+ -+ } -+ -+ /// Run -+ Disposition run() { -+ -+ Disposition disposition = Disposition::kNotVerified; -+ -+ // -+ // Initialize the workspace -+ // -+ -+ initialize(); -+ -+ // -+ // Launch device kernel -+ // -+ cutlass::Status status = cutlass::Status::kSuccess; -+ -+ status = execute_device_kernel(); -+ -+ if (status != cutlass::Status::kSuccess) { -+ std::cerr << "Device execution failed." << std::endl; -+ return disposition; -+ } -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ if (result != cudaSuccess) { -+ std::cerr << "Device synchronize failed with error " -+ << cudaGetErrorString(result) << std::endl; -+ return disposition; -+ } -+ -+ // -+ // Compute the reference -+ // -+ compute_reference(); -+ -+ // -+ // Verify -+ // -+ -+ if (options.verification_enabled) { -+ -+ bool passed = verify(); -+ -+ if (passed) { -+ disposition = Disposition::kPassed; -+ } -+ else { -+ disposition = Disposition::kIncorrect; -+ } -+ } -+ -+ // -+ // Profiling -+ // -+ if (options.iterations) { -+ profile(); -+ } -+ -+ return disposition; -+ } -+ -+ /// Random initialization -+ void initialize() { -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_A0.host_view(), -+ options.seed, -+ ElementInputA0(5), -+ ElementInputA0(-5), -+ 0 -+ ); -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_B0.host_view(), -+ options.seed + 1, -+ ElementInputB0(5), -+ ElementInputB0(-5), -+ 0 -+ ); -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_A1.host_view(), -+ options.seed + 2, -+ ElementInputA1(5), -+ ElementInputA1(-5), -+ 0 -+ ); -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_Beta.host_view(), -+ options.seed + 3, -+ ElementInputScaleBias(5), -+ ElementInputScaleBias(-5), -+ 0 -+ ); -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_Gamma.host_view(), -+ options.seed + 4, -+ ElementInputScaleBias(5), -+ ElementInputScaleBias(-5), -+ 0 -+ ); -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_Shifted_K.host_view(), -+ options.seed + 5, -+ ElementOutput(5), -+ ElementOutput(-6), -+ 0 -+ ); -+ -+ tensor_A0.sync_device(); -+ tensor_B0.sync_device(); -+ tensor_A1.sync_device(); -+ tensor_Beta.sync_device(); -+ tensor_Gamma.sync_device(); -+ -+ } -+ -+ -+ -+ cutlass::Status execute_device_kernel() { -+ -+ cutlass::Status status = cutlass::Status::kSuccess; -+ -+ // -+ // Setup arguments -+ // -+ -+ typename GemmLayernorm::Arguments args( -+ options.problem_size0, -+ options.problem_size1, -+ tensor_A0.device_ref().data(), -+ tensor_B0.device_ref().data(), -+ tensor_C0.device_ref().data(), -+ tensor_C0.device_ref().data(), -+ tensor_A1.device_ref().data(), -+ tensor_C1.device_ref().data(), -+ tensor_A0.device_ref().stride(0), -+ tensor_B0.device_ref().stride(0), -+ tensor_C0.device_ref().stride(0), -+ tensor_C0.device_ref().stride(0), -+ tensor_A1.device_ref().stride(0), -+ tensor_C1.device_ref().stride(0), -+ { -+ ElementCompute(options.alpha), -+ ElementCompute(options.beta) -+ }, -+ tensor_Variance.device_ref(), -+ tensor_Mean.device_ref(), -+ tensor_Gamma.device_ref(), -+ tensor_Beta.device_ref(), -+ tensor_Shifted_K.device_ref().data() -+ ); -+ -+ // -+ // Launch -+ // -+ -+ GemmLayernorm gemm_layernorm; -+ -+ // Initialize -+ status = gemm_layernorm.initialize(args); -+ if (status != cutlass::Status::kSuccess) { -+ return status; -+ } -+ -+ // Run -+ status = gemm_layernorm(); -+ -+ return status; -+ } -+ -+ /// Reference calculation -+ void compute_reference() { -+ -+ cutlass::reference::device::Gemm< -+ ElementInputA0, -+ LayoutInputA0, -+ ElementInputB0, -+ LayoutInputB0, -+ ElementOutput, -+ LayoutOutputC0, -+ ElementCompute, -+ ElementCompute -+ > gemm_device0; -+ -+ cutlass::reference::device::Gemm< -+ ElementInputA1, -+ LayoutInputA1, -+ ElementOutput, -+ LayoutOutputC0, -+ ElementOutputC1, -+ LayoutOutputC1, -+ ElementCompute, -+ ElementCompute -+ > gemm_device1; -+ -+ // Compute 1st GEMM -+ gemm_device0( -+ options.problem_size0, -+ ElementCompute(options.alpha), -+ tensor_A0.device_ref(), -+ tensor_B0.device_ref(), -+ ElementCompute(options.beta), -+ tensor_C0.device_ref(), -+ reference_C0.device_ref() -+ ); -+ -+ reference_C0.sync_host(); -+ -+ tensor_Mean.sync_host(); -+ tensor_Variance.sync_host(); -+ tensor_Gamma.sync_host(); -+ tensor_Beta.sync_host(); -+ tensor_Shifted_K.sync_host(); -+ -+ // Compute the sum and square sum for verification purpose -+ if (kIsColumnMajorOutput) { -+ for (int n = 0; n < options.problem_size0.n(); ++n) { -+ -+ ElementLayernormCompute sum = ElementLayernormCompute(0); -+ ElementLayernormCompute square_sum = ElementLayernormCompute(0); -+ for (int m = 0; m < options.problem_size0.m(); ++m) { -+ sum += ElementLayernormCompute(reference_C0.at({m, n})); -+ square_sum += ElementLayernormCompute(reference_C0.at({m, n})) * ElementLayernormCompute(reference_C0.at({m, n})); -+ } -+ -+ ElementLayernormCompute mean = sum / ElementLayernormCompute(options.problem_size0.m()); -+ ElementLayernormCompute square_mean = square_sum / ElementLayernormCompute(options.problem_size0.m()); -+ ElementLayernormCompute variance = cutlass::constants::one() / cutlass::fast_sqrt(square_mean - mean * mean + ElementLayernormCompute(1e-6) ) ; -+ -+ mean = -mean * variance; -+ -+ reference_Mean.at({0, n}) = ElementInputScaleBias(mean); -+ reference_Variance.at({0, n}) = ElementInputScaleBias(variance); -+ } -+ }else{ -+ for (int m = 0; m < options.problem_size0.m(); ++m) { -+ -+ ElementLayernormCompute sum = ElementLayernormCompute(0); -+ ElementLayernormCompute square_sum = ElementLayernormCompute(0); -+ for (int n = 0; n < options.problem_size0.n(); ++n) { -+ sum += ElementLayernormCompute(reference_C0.at({m, n})) ; -+ square_sum += ElementLayernormCompute(reference_C0.at({m, n})) * ElementLayernormCompute(reference_C0.at({m, n})) ; -+ } -+ -+ ElementLayernormCompute mean = sum / ElementLayernormCompute(options.problem_size0.n()); -+ ElementLayernormCompute square_mean = square_sum / ElementLayernormCompute(options.problem_size0.n()); -+ ElementLayernormCompute variance = cutlass::constants::one() / cutlass::fast_sqrt(square_mean - mean * mean + ElementLayernormCompute(1e-6)) ; -+ -+ mean = -mean * variance; -+ -+ reference_Mean.at({0, m}) = ElementInputScaleBias(mean); -+ reference_Variance.at({0, m}) = ElementInputScaleBias(variance); -+ } -+ } -+ -+ // Element-wise transform for OutputC0 using 1-pass layernorm algo -+ if (kIsColumnMajorOutput) { -+ for (int n = 0; n < options.problem_size0.n(); ++n) { -+ -+ ElementLayernormCompute sum = ElementLayernormCompute(0); -+ for (int m = 0; m < options.problem_size0.m(); ++m) { -+ sum += ElementLayernormCompute(reference_C0.at({m, n})) ; -+ } -+ -+ ElementInputScaleBias mean = ElementInputScaleBias(sum / ElementLayernormCompute(options.problem_size0.m())); -+ sum = ElementLayernormCompute(0); -+ for (int m = 0; m < options.problem_size0.m(); ++m) { -+ sum += ElementLayernormCompute(reference_C0.at({m, n}) - ElementLayernormCompute(mean)) * ElementLayernormCompute(reference_C0.at({m, n}) - ElementLayernormCompute(mean)) ; -+ } -+ -+ ElementLayernormCompute square_mean = sum / ElementLayernormCompute(options.problem_size0.m()); -+ ElementInputScaleBias variance = ElementInputScaleBias(cutlass::constants::one() -+ / cutlass::fast_sqrt(square_mean + ElementLayernormCompute(1e-6))) ; -+ -+ for (int m = 0; m < options.problem_size0.m(); ++m) { -+ reference_C0.at({m, n}) = -+ ElementOutput( ( (ElementInputScaleBias(reference_C0.at({m, n})) - mean) * variance ) -+ * tensor_Gamma.at({0, m}) + tensor_Beta.at({0, m})); -+ -+ } -+ -+ } -+ }else{ -+ -+ for (int m = 0; m < options.problem_size0.m(); ++m) { -+ -+ float sum = float(0); -+ for (int n = 0; n < options.problem_size0.n(); ++n) { -+ sum += float(reference_C0.at({m, n})) ; -+ } -+ -+ float mean = sum / float(options.problem_size0.n()); -+ sum = float(0); -+ for (int n = 0; n < options.problem_size0.n(); ++n) { -+ sum += float(reference_C0.at({m, n}) - mean) * float(reference_C0.at({m, n}) - mean) ; -+ } -+ -+ float square_mean = sum / float(options.problem_size0.n()); -+ float variance = cutlass::constants::one() / cutlass::fast_sqrt(square_mean + ElementLayernormCompute(1e-6)) ; -+ -+ for (int n = 0; n < options.problem_size0.n(); ++n) { -+ reference_C0.at({m, n}) = -+ ElementOutput( ( (float(reference_C0.at({m, n})) - mean) * variance ) -+ * float(tensor_Gamma.at({0, n})) + float(tensor_Beta.at({0, n}))); -+ -+ } -+ -+ } -+ -+ } -+ -+ -+ // Sync host data with device after element-wise transform -+ reference_C0.sync_device(); -+ -+ // Compute 2nd GEMM -+ gemm_device1( -+ options.problem_size1, -+ ElementCompute(options.alpha), -+ kIsColumnMajorOutput ? tensor_A1.device_ref() : reference_C0.device_ref(), -+ kIsColumnMajorOutput ? reference_C0.device_ref() :tensor_A1.device_ref(), -+ ElementCompute(options.beta), -+ reference_C1.device_ref(), -+ reference_C1.device_ref() -+ ); -+ -+ } -+ -+ /// Emits all tensor values -+ void emit_results() { -+ std::cout << "tensor_C1 = \n" << tensor_C1.host_view() << "\n\n"; -+ std::cout << "Reference C1 = \n" << reference_C1.host_view() << "\n\n"; -+ std::cout << "Mean = \n" << tensor_Mean.host_view() << "\n\n"; -+ std::cout << "rsqrt(Variance) = \n" << tensor_Variance.host_view() << "\n\n"; -+ std::cout << "Reference Mean = \n" << reference_Mean.host_view() << "\n\n"; -+ std::cout << "Reference rsqrt(Variance) = \n" << reference_Variance.host_view() << "\n\n"; -+ } -+ -+ template -+ bool verify_tensor(cutlass::HostTensor tensor, \ -+ cutlass::HostTensor reference, -+ int leading_dim0, int leading_dim1, bool is_print = false) { -+ float const kThreshold = float(options.tolerance); -+ float const kAbsThreshold = 0.5f; -+ float const kRelativeThreshold = 0.1f; -+ // Adds a constant bias to avoid being divided by '0' -+ float const kBias = 1e-5f; -+ int counter = 0; -+ for (int m = 0; m < leading_dim0; m++) { -+ for (int n = 0; n < leading_dim1; ++n) { -+ float diff = (float)(tensor.at({m, n}) - reference.at({m, n})); -+ float rel_diff = fabs(diff) / fabs(reference.at({m, n}) + kBias); -+ if (fabs(diff) > kAbsThreshold && rel_diff > kRelativeThreshold) { -+ counter++; -+ } -+ } -+ } -+ -+ float err_rate = float(counter) / (float(leading_dim0) * float(leading_dim1)); -+ return (err_rate < kThreshold); -+ } -+ -+ /// Verifies the reference matches -+ bool verify() { -+ -+ tensor_Variance.sync_host(); -+ tensor_Mean.sync_host(); -+ tensor_C1.sync_host(); -+ reference_C1.sync_host(); -+ -+ // Verification checks - set any of these to 'true' to override the verification checks. -+ bool verified_C1 = false; -+ bool verified_Mean = false; -+ bool verified_Variance = false; -+ -+ // Verify layernorm output -+ if (!verified_C1) { -+ verified_C1 = verify_tensor(tensor_C1, reference_C1, options.problem_size1.m(), options.problem_size1.n()); -+ } -+ -+ if (!verified_Variance) { -+ verified_Variance = verify_tensor(tensor_Variance, reference_Variance, 1, options.problem_size0.n()); -+ } -+ -+ if (!verified_Mean) { -+ verified_Mean = verify_tensor(tensor_Mean, reference_Mean, 1, options.problem_size0.n()); -+ } -+ -+ if (!verified_C1 || !verified_Mean || !verified_Variance) { -+ -+ // emit_results(); -+ -+ std::cerr << "Verification check failed for tensor Layernorm" << std::endl; -+ -+ // Summarize which checks failed -+ if (!verified_C1) { -+ std::cerr << "Verification of O tensor failed\n"; -+ } -+ -+ if (!verified_Mean) { -+ std::cerr << "Verification of Mean tensor failed\n"; -+ } -+ -+ if (!verified_Variance) { -+ std::cerr << "Verification of Variance tensor failed\n"; -+ } -+ -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Profiles -+ bool profile() { -+ -+ // -+ // Profile -+ // -+ -+ cutlass::Status status = cutlass::Status::kSuccess; -+ cudaError_t result; -+ cudaEvent_t events[2]; -+ int const kIterations = options.iterations; -+ -+ for (cudaEvent_t &evt : events) { -+ result = cudaEventCreate(&evt); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaEventCreate failed with error " << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ } -+ -+ result = cudaEventRecord(events[0]); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed with error " << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ -+ for (int iter = 0; iter < kIterations; ++iter) { -+ -+ status = execute_device_kernel(); -+ -+ if (status != cutlass::Status::kSuccess) { -+ std::cerr << "Device execution failed." << std::endl; -+ return false; -+ } -+ } -+ -+ result = cudaEventRecord(events[1]); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed with error " << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ -+ result = cudaDeviceSynchronize(); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "cudaDeviceSynchronize() failed with error " << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ -+ float elapsed_ms = 0; -+ result = cudaEventElapsedTime(&elapsed_ms, events[0], events[1]); -+ -+ float elapsed_ms_per_iter = elapsed_ms / float(kIterations); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "cudaEventElapsedTime() failed with error " << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ -+ for (cudaEvent_t &evt : events) { -+ result = cudaEventDestroy(evt); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaEventDestroy() failed with error " << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ } -+ -+ int64_t flops = int64_t(options.problem_size0.m()) * options.problem_size0.n() * options.problem_size0.k() * 2 \ -+ + int64_t(options.problem_size1.m()) * options.problem_size1.n() * options.problem_size1.k() * 2; -+ -+ double gflops_per_second = double(flops) * kIterations / double(elapsed_ms / 1000.0f) / double(1.0e9); -+ -+ std::cout << " 1st GEMM: " -+ << options.problem_size0.m() << "-by-" << options.problem_size0.n() << "-by-" << options.problem_size0.k() << "\n" -+ << " 2nd GEMM: " -+ << options.problem_size1.m() << "-by-" << options.problem_size1.n() << "-by-" << options.problem_size1.k() -+ << std::endl; -+ -+ std::cout << " Runtime / iteration: " << elapsed_ms_per_iter << " ms\n" << std::endl; -+ std::cout << " GFLOPs: " << gflops_per_second << " GFLOPs" << std::endl; -+ -+ return true; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, const char **argv) { -+ -+ // Define final layout -+ using LayoutOutput = cutlass::layout::ColumnMajor; -+ -+ // Options parsing -+ Options options; -+ options.parse(argc, argv); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ if (!options.supported()) { -+ return 0; -+ } -+ -+ // Run -+ Testbed testbed(options); -+ -+ Disposition disposition = testbed.run(); -+ -+ std::cout << std::endl; -+ -+ switch (disposition) { -+ case Disposition::kPassed: -+ std::cout << "Passed" << std::endl; -+ break; -+ case Disposition::kIncorrect: -+ std::cout << "Incorrect" << std::endl; -+ break; -+ case Disposition::kNotVerified: -+ std::cout << "Not verified" << std::endl; -+ break; -+ } -+ -+ return (disposition == Disposition::kPassed ? 0 : -1); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/37_gemm_layernorm_gemm_fusion/gemm_with_epilogue_visitor.h b/3rdparty/cutlass/examples/37_gemm_layernorm_gemm_fusion/gemm_with_epilogue_visitor.h -new file mode 100644 -index 0000000..143bca3 ---- /dev/null -+++ b/3rdparty/cutlass/examples/37_gemm_layernorm_gemm_fusion/gemm_with_epilogue_visitor.h -@@ -0,0 +1,444 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief GEMM kernel to support the epilogue visitor model -+ for customized layernorm partial reduction epilogue fusion. -+ -+ This source file will likely be moved to `include/cutlass/gemm/kernel/` in the future once -+ its usage has been stabilized. For now, it is included in this example to demonstrate -+ some basic output fusion options. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/semaphore.h" -+ -+#include "cutlass/trace.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_ ///! Threadblock swizzling function -+> -+struct GemmWithEpilogueVisitor { -+public: -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueVisitor = typename Epilogue::Visitor; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using TensorRefA = TensorRef; -+ -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using TensorRefB = TensorRef; -+ -+ using ElementC = typename EpilogueVisitor::ElementOutput; -+ using LayoutC = typename Epilogue::Layout; -+ -+ static ComplexTransform const kTransformA = Mma::kTransformA; -+ static ComplexTransform const kTransformB = Mma::kTransformB; -+ using Operator = typename Mma::Operator; -+ -+ using OperatorClass = typename Mma::Operator::OperatorClass; -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename Mma::Operator::Shape; -+ using InstructionShape = typename Mma::Policy::Operator::InstructionShape; -+ using ArchTag = typename Mma::ArchTag; -+ -+ static int const kStages = Mma::kStages; -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ /// Split-K preserves splits that are 128b aligned -+ static int const kSplitKAlignment = const_max( -+ 128 / sizeof_bits::value, -+ 128 / sizeof_bits::value -+ ); -+ -+ // -+ // Structures -+ // -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmUniversalMode mode; -+ GemmCoord problem_size; -+ -+ TensorRefA ref_A; -+ TensorRefB ref_B; -+ -+ typename EpilogueVisitor::Arguments epilogue_visitor; -+ -+ // -+ // Methods -+ // -+ -+ Arguments(): -+ mode(GemmUniversalMode::kGemm) -+ { } -+ -+ -+ /// constructs an arguments structure -+ Arguments( -+ GemmUniversalMode mode_, -+ GemmCoord problem_size_, -+ TensorRefA ref_A_, -+ TensorRefB ref_B_, -+ typename EpilogueVisitor::Arguments epilogue_visitor_ -+ ): -+ mode(mode_), -+ problem_size(problem_size_), -+ ref_A(ref_A_), -+ ref_B(ref_B_), -+ epilogue_visitor(epilogue_visitor_) -+ { -+ -+ } -+ }; -+ -+ // -+ // Structure for precomputing values in host memory and passing to kernels -+ // -+ -+ /// Parameters structure -+ struct Params { -+ -+ cutlass::gemm::GemmCoord problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ int swizzle_log_tile; -+ -+ typename Mma::IteratorA::Params params_A; -+ typename Mma::IteratorB::Params params_B; -+ -+ GemmUniversalMode mode; -+ int gemm_k_size; -+ -+ void * ptr_A; -+ void * ptr_B; -+ -+ typename EpilogueVisitor::Params epilogue_visitor; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ swizzle_log_tile(0), -+ params_A(0), -+ params_B(0), -+ gemm_k_size(0), -+ mode(cutlass::gemm::GemmUniversalMode::kGemm), -+ ptr_A(nullptr), -+ ptr_B(nullptr) -+ { } -+ -+ -+ Params( -+ Arguments const &args -+ ): -+ problem_size(args.problem_size), -+ swizzle_log_tile(0), -+ params_A(args.ref_A.layout()), -+ params_B(args.ref_B.layout()), -+ mode(args.mode), -+ gemm_k_size(args.problem_size.k()), -+ ptr_A(args.ref_A.data()), -+ ptr_B(args.ref_B.data()), -+ epilogue_visitor(args.epilogue_visitor) -+ { -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ grid_tiled_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, 1); -+ -+ if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ -+ int const kAlignK = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); -+ -+ gemm_k_size = round_up(args.problem_size.k(), kAlignK); -+ -+ if (gemm_k_size) { -+ grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); -+ } -+ } -+ -+ swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); -+ } -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ -+ typename Mma::SharedStorage main_loop; -+ -+ struct { -+ typename Epilogue::SharedStorage epilogue; -+ typename EpilogueVisitor::SharedStorage visitor; -+ } epilogue; -+ }; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_DEVICE -+ GemmWithEpilogueVisitor() { } -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement( -+ cutlass::gemm::GemmCoord const & problem_size) { -+ -+ CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()"); -+ -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ bool isAMisaligned = false; -+ bool isBMisaligned = false; -+ bool isCMisaligned = false; -+ -+ if (platform::is_same::value) { -+ isAMisaligned = problem_size.k() % kAlignmentA; -+ } else if (platform::is_same::value) { -+ isAMisaligned = problem_size.m() % kAlignmentA; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isAMisaligned = problem_size.k() % kAlignmentA; -+ } -+ -+ if (platform::is_same::value) { -+ isBMisaligned = problem_size.n() % kAlignmentB; -+ } else if (platform::is_same::value) { -+ isBMisaligned = problem_size.k() % kAlignmentB; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isBMisaligned = problem_size.k() % kAlignmentB; -+ } -+ -+ if (platform::is_same::value) { -+ isCMisaligned = problem_size.n() % kAlignmentC; -+ } else if (platform::is_same::value) { -+ isCMisaligned = problem_size.m() % kAlignmentC; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isCMisaligned = problem_size.n() % kAlignmentC; -+ } -+ -+ if (isAMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (isBMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (isCMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ CUTLASS_TRACE_HOST(" returning kSuccess"); -+ -+ return Status::kSuccess; -+ } -+ -+ static Status can_implement(Arguments const &args) { -+ return can_implement(args.problem_size); -+ } -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ int offset_k = 0; -+ int problem_size_k = params.problem_size.k(); -+ -+ ElementA *ptr_A = static_cast(params.ptr_A); -+ ElementB *ptr_B = static_cast(params.ptr_B); -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ offset_k, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ offset_k, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ }; -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.params_A, -+ ptr_A, -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B( -+ params.params_B, -+ ptr_B, -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_B); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); -+ -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma( -+ gemm_k_iterations, -+ accumulators, -+ iterator_A, -+ iterator_B, -+ accumulators); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ // -+ // Construct the epilogue visitor -+ // -+ -+ EpilogueVisitor epilogue_visitor( -+ params.epilogue_visitor, -+ shared_storage.epilogue.visitor, -+ params.problem_size.mn(), -+ thread_idx, -+ warp_idx, -+ lane_idx, -+ threadblock_offset); -+ -+ if (params.mode == GemmUniversalMode::kGemm) { -+ // Indicate which position in a serial reduction the output operator is currently updating -+ epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } -+ else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) { -+ epilogue_visitor.set_batch_index(threadblock_tile_offset.k()); -+ } -+ -+ // Construct the epilogue -+ Epilogue epilogue( -+ shared_storage.epilogue.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue(epilogue_visitor, accumulators); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/37_gemm_layernorm_gemm_fusion/gemm_with_layernorm.h b/3rdparty/cutlass/examples/37_gemm_layernorm_gemm_fusion/gemm_with_layernorm.h -new file mode 100644 -index 0000000..dde3c07 ---- /dev/null -+++ b/3rdparty/cutlass/examples/37_gemm_layernorm_gemm_fusion/gemm_with_layernorm.h -@@ -0,0 +1,1066 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this layernormware without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief A file contains all functioning classes needed by GemmLayernorm. -+ -+ GemmLayernorm example = GEMM0 with partial reduction fused in epilogue (EpilogueVisitorLayerNorm) -+ + lightweight full reduction kernel (ApplyFinalReduction) -+ + GEMM1 with elemenwise operations fused in mainloop (GemmLayernormMainloopFusion) -+ -+*/ -+ -+#pragma once -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/gemm/device/gemm_layernorm_mainloop_fusion.h" -+#include "cutlass/gemm/kernel/gemm_transpose_operands.h" -+#include "cutlass/gemm/kernel/default_gemm.h" -+#include "cutlass/gemm/kernel/default_gemm_complex.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include "gemm_with_epilogue_visitor.h" -+#include "helper.h" -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementVariance_, -+ typename ElementMean_, -+ typename ElementLayernormCompute_, -+ typename ElementOutput, -+ typename ThreadblockShape_, -+ bool IsShiftedVariance_ = false -+> -+class ApplyFinalReduction { -+public: -+ -+ using ElementVariance = ElementVariance_; -+ using ElementMean = ElementMean_; -+ using ElementLayernormCompute = ElementLayernormCompute_; -+ using ThreadblockShape = ThreadblockShape_; -+ -+ // Pre-processing has ensured the layout equivelent to RowMajor -+ using Layout = cutlass::layout::RowMajor; -+ -+ using TensorVariance = TensorRef; -+ using TensorMean = TensorRef; -+ -+ static bool const kIsShiftedVariance = IsShiftedVariance_; -+ -+ // -+ // Arguments -+ // -+ -+ struct Arguments { -+ -+ MatrixCoord extent; ///< Extent of D and Layernorm matrices -+ TensorVariance ref_Variance; ///< Sum Square or Variance tensor (input / output) -+ TensorMean ref_Mean; ///< Sum or Mean tensor (input / output) -+ ElementOutput *ptr_Shifted_K; ///< Shifted K tensor pointer -+ -+ // -+ // Methods -+ // -+ Arguments(){ } -+ -+ Arguments( -+ MatrixCoord extent_, -+ TensorVariance ref_Variance_, -+ TensorMean ref_Mean_, -+ ElementOutput *ptr_Shifted_K_ -+ ): -+ extent(extent_), -+ ref_Variance(ref_Variance_), -+ ref_Mean(ref_Mean_), -+ ptr_Shifted_K(ptr_Shifted_K_) -+ { -+ -+ } -+ }; -+ -+ struct SharedStorage { -+ -+ -+ }; -+ -+ // -+ // Params struct -+ // -+ -+ struct Params { -+ Arguments args; -+ -+ // -+ // Methods -+ // -+ Params() { } -+ -+ Params(Arguments const &args_): args(args_) { } -+ }; -+ -+private: -+ -+public: -+ -+ CUTLASS_DEVICE -+ ApplyFinalReduction() { } -+ -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ apply(params, shared_storage); -+ } -+ -+private: -+ -+ /// Partial reduction -+ CUTLASS_DEVICE -+ void apply(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ int threadblock_num = (params.args.extent.column() + ThreadblockShape::kM - 1) / ThreadblockShape::kM; -+ -+ int block_n = blockIdx.x * blockDim.x; -+ -+ int thread_n = threadIdx.x; -+ -+ int idx_n = block_n + thread_n; -+ -+ if (idx_n >= params.args.extent.row()) { -+ return; -+ } -+ -+ using ConvertVarianceOutput = cutlass::NumericConverter; -+ using ConvertMeanOutput = cutlass::NumericConverter; -+ -+ using ConvertVariance = cutlass::NumericConverter; -+ using ConvertMean = cutlass::NumericConverter; -+ -+ using ConvertShiftK = cutlass::NumericConverter; -+ -+ ConvertVariance convert_variance; -+ ConvertMean convert_mean; -+ -+ ConvertVarianceOutput convert_variance_output; -+ ConvertMeanOutput convert_mean_output; -+ -+ ElementVariance *access_square = params.args.ref_Variance.data() + idx_n; -+ ElementMean *access_mean = params.args.ref_Mean.data() + idx_n; -+ -+ ElementVariance *access_square_bak = access_square; -+ ElementMean *access_mean_bak = access_mean; -+ -+ ElementLayernormCompute frag_square_sum = ElementLayernormCompute(0); -+ ElementLayernormCompute frag_element_sum = ElementLayernormCompute(0); -+ ElementVariance fetch_square; -+ ElementMean fetch_mean; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int idx_m = 0; idx_m < threadblock_num; idx_m++) { -+ arch::global_load(fetch_square, access_square, true); -+ arch::global_load(fetch_mean, access_mean, true); -+ frag_element_sum += convert_mean(fetch_mean); -+ frag_square_sum += convert_variance(fetch_square); -+ access_square += params.args.extent.row(); -+ access_mean += params.args.extent.row(); -+ } -+ -+ ElementLayernormCompute mean = frag_element_sum; -+ ElementLayernormCompute square_mean = frag_square_sum; -+ -+ ElementLayernormCompute variance; -+ -+ if (kIsShiftedVariance && params.args.ptr_Shifted_K != nullptr) { -+ ElementOutput *access_shift_k = params.args.ptr_Shifted_K + idx_n; -+ ElementOutput fetch_shift_k; -+ ConvertShiftK convert_shift_k; -+ arch::global_load(fetch_shift_k, access_shift_k, true); -+ ElementLayernormCompute shifted_mean = mean - convert_shift_k(fetch_shift_k); -+ variance = cutlass::constants::one() / cutlass::fast_sqrt(square_mean - shifted_mean * shifted_mean + ElementLayernormCompute(1e-6)); -+ }else{ -+ variance = cutlass::constants::one() / cutlass::fast_sqrt(square_mean - mean * mean + ElementLayernormCompute(1e-6)); -+ } -+ -+ mean = -mean * variance; -+ -+ access_square = access_square_bak; -+ access_mean = access_mean_bak; -+ -+ access_square[0] = convert_variance_output(variance); -+ access_mean[0] = convert_mean_output(mean); -+ -+ } -+ -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ThreadblockShape_, -+ int ThreadCount, -+ typename OutputTileIterator_, -+ typename AccumulatorTile_, -+ typename ElementAccumulator_, -+ typename ElementVariance_, -+ typename ElementMean_, -+ typename ElementLayernormCompute_, -+ typename ElementwiseFunctor_, -+ bool IsShiftedVariance_ = false -+> -+class EpilogueVisitorLayerNorm { -+public: -+ -+ using ElementVariance = ElementVariance_; -+ using ElementMean = ElementMean_; -+ using ElementLayernormCompute = ElementLayernormCompute_; -+ -+ using AccumulatorTile = AccumulatorTile_; -+ -+ using ThreadblockShape = ThreadblockShape_; -+ static int const kThreadCount = ThreadCount; -+ -+ using OutputTileIterator = OutputTileIterator_; -+ using ElementwiseFunctor = ElementwiseFunctor_; -+ -+ static int const kIterations = OutputTileIterator::kIterations; -+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; -+ static int const kRowIterations = OutputTileIterator::ThreadMap::Iterations::kRow; -+ -+ static int const kThreads = OutputTileIterator::ThreadMap::kThreads; -+ -+ static bool const kIsShiftedVariance = IsShiftedVariance_; -+ -+ using ElementOutput = typename OutputTileIterator::Element; -+ -+ static int const kDeltaRow = OutputTileIterator::ThreadMap::Delta::kRow; -+ -+ /// Array type used in Shift-K Layernorm -+ static int const kRowAccessCount = kIterations * kRowIterations; -+ -+ using ConvertedShiftFragment = Array; -+ -+ // Conducts manual transpose externally (already supported) for column major -+ using LayoutOutput = cutlass::layout::RowMajor; -+ -+ using ElementAccumulator = ElementAccumulator_; -+ -+ using AccumulatorFragment = Array; -+ using LayernormFragment = Array; -+ using OutputVector = Array; -+ using TensorRefD = TensorRef; -+ -+ static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::RowArrangement::Detail::kShapeWidth; -+ static int const kThreadsInColumn = kThreads / kThreadsPerRow; -+ static int const kHalfThreadsPerRow = (kThreadsPerRow >> 1); -+ -+ /// Argument structure -+ struct Arguments { -+ -+ typename ElementwiseFunctor::Params elementwise; -+ TensorRefD ref_C; -+ TensorRefD ref_D; -+ ElementVariance *ptr_Variance; -+ ElementMean *ptr_Mean; -+ ElementOutput *ptr_Shifted_K; -+ -+ // -+ // Methods -+ // -+ Arguments(): -+ ptr_Variance(nullptr), -+ ptr_Mean(nullptr), -+ ptr_Shifted_K(nullptr) -+ { -+ -+ } -+ -+ Arguments( -+ typename ElementwiseFunctor::Params elementwise_, -+ TensorRefD ref_C_, -+ TensorRefD ref_D_, -+ ElementVariance *ptr_Variance, -+ ElementMean *ptr_Mean_, -+ ElementOutput *ptr_Shifted_K_ = nullptr -+ ): -+ elementwise(elementwise_), -+ ref_C(ref_C_), -+ ref_D(ref_D_), -+ ptr_Variance(ptr_Variance), -+ ptr_Mean(ptr_Mean_), -+ ptr_Shifted_K(ptr_Shifted_K_) -+ { -+ -+ } -+ }; -+ -+ struct Params { -+ -+ typename ElementwiseFunctor::Params elementwise; -+ typename OutputTileIterator::Params params_C; -+ typename OutputTileIterator::Params params_D; -+ typename OutputTileIterator::Element *ptr_C; -+ typename OutputTileIterator::Element *ptr_D; -+ ElementVariance *ptr_Variance; -+ ElementMean *ptr_Mean; -+ ElementOutput *ptr_Shifted_K; -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Params(): -+ ptr_D(nullptr), -+ ptr_Variance(nullptr), -+ ptr_Mean(nullptr) -+ { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args): -+ elementwise(args.elementwise), -+ params_C(args.ref_C.layout()), -+ params_D(args.ref_D.layout()), -+ ptr_C(args.ref_C.data()), -+ ptr_D(args.ref_D.data()), -+ ptr_Variance(args.ptr_Variance), -+ ptr_Mean(args.ptr_Mean), -+ ptr_Shifted_K(args.ptr_Shifted_K) -+ { -+ -+ } -+ }; -+ -+ /// Shared storage -+ struct SharedStorage { -+ -+ }; -+ -+private: -+ -+ Params const & params_; -+ SharedStorage & shared_storage_; -+ MatrixCoord extent_; -+ ElementwiseFunctor elementwise_; -+ -+ OutputTileIterator iterator_C_; -+ OutputTileIterator iterator_D_; -+ typename OutputTileIterator::Fragment fragment_C_; -+ typename OutputTileIterator::Fragment fragment_D_; -+ -+ ElementAccumulator alpha_; -+ ElementAccumulator beta_; -+ ConvertedShiftFragment shift_k_frag_; -+ -+ ElementLayernormCompute accum_sum_square_; -+ ElementLayernormCompute accum_sum_element_; -+ -+ MatrixCoord thread_offset_; -+ -+public: -+ -+ CUTLASS_DEVICE -+ EpilogueVisitorLayerNorm( -+ Params const ¶ms, ///< Parameters routed to the epilogue -+ SharedStorage &shared_storage, ///< Shared storage needed by the functors here -+ MatrixCoord const &problem_size0, ///< Problem size of the output -+ int thread_idx, ///< Thread index within the threadblock -+ int warp_idx, ///< Warp index within the threadblock -+ int lane_idx, ///< Lane index within the warp -+ MatrixCoord const &threadblock_offset = MatrixCoord(0, 0) -+ ): -+ params_(params), -+ shared_storage_(shared_storage), -+ extent_(problem_size0), -+ elementwise_(params.elementwise), -+ iterator_C_(params.params_C, params.ptr_C, problem_size0, thread_idx, threadblock_offset), -+ iterator_D_(params.params_D, params.ptr_D, problem_size0, thread_idx, threadblock_offset) -+ { -+ alpha_ = (params.elementwise.alpha_ptr ? *params.elementwise.alpha_ptr : params.elementwise.alpha); -+ beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta); -+ -+ if (beta_ == ElementAccumulator()) { -+ iterator_C_.clear_mask(); -+ } -+ } -+ -+ /// Helper to indicate split-K behavior -+ CUTLASS_DEVICE -+ void set_k_partition( -+ int split_k_index, ///< Index of this threadblock within split-K partitioned scheme -+ int split_k_slices) { ///< Total number of split-K slices -+ -+ } -+ -+ /// Called to set the batch index -+ CUTLASS_DEVICE -+ void set_batch_index(int batch_idx) { -+ -+ } -+ -+ /// Called at the start of the epilogue just before iterating over accumulator slices -+ CUTLASS_DEVICE -+ void begin_epilogue() { -+ -+ // If shift-K feature is enabled, we load shift-k fragment -+ // at the very beginning of an epilogue -+ if (kIsShiftedVariance && params_.ptr_Shifted_K != nullptr) { -+ shift_k_frag_.clear(); -+ int thread_offset_row_base = iterator_D_.thread_start_row(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int iter_idx = 0; iter_idx < kIterations; ++iter_idx) { -+ int step_offset = iter_idx * OutputTileIterator::Shape::kRow; -+ CUTLASS_PRAGMA_UNROLL -+ for (int rid = 0; rid < kRowIterations; ++rid) { -+ int row_step_offset = rid * kDeltaRow; -+ int row_offset = thread_offset_row_base + step_offset + row_step_offset; -+ bool is_load = (row_offset < extent_.row()); -+ shift_k_frag_[iter_idx * kRowIterations + rid] = load_shift_k_(row_offset, is_load); -+ } -+ -+ } -+ -+ } -+ -+ } -+ -+ /// Called at the start of one step before starting accumulator exchange -+ CUTLASS_DEVICE -+ void begin_step(int step_idx) { -+ fragment_D_.clear(); -+ -+ if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { -+ fragment_C_.clear(); -+ iterator_C_.load(fragment_C_); -+ ++iterator_C_; -+ } -+ } -+ -+ /// Called at the start of a row -+ CUTLASS_DEVICE -+ void begin_row(int row_idx) { -+ -+ } -+ -+ /// Called after accumulators have been exchanged for each accumulator vector -+ CUTLASS_DEVICE -+ void visit( -+ int iter_idx, -+ int row_idx, -+ int column_idx, -+ int frag_idx, -+ AccumulatorFragment const &accum) { -+ -+ using Mul = cutlass::multiplies; -+ using Minus = cutlass::minus; -+ using Exp = cutlass::fast_exp_op; -+ -+ [[maybe_unused]] Minus minus; -+ [[maybe_unused]] Mul mul; -+ [[maybe_unused]] Exp exponential; -+ -+ LayernormFragment result; -+ -+ thread_offset_ = -+ iterator_D_.thread_start() + -+ OutputTileIterator::ThreadMap::iteration_offset(frag_idx); -+ -+ NumericArrayConverter source_converter; -+ OutputVector &source_vector = reinterpret_cast(&fragment_C_)[frag_idx]; -+ -+ bool column_guard = (thread_offset_.column() < extent_.column()); -+ -+ if (elementwise_.kScale == cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { -+ result = source_converter(elementwise_(accum)); -+ }else{ -+ result = source_converter(elementwise_(accum, source_vector)); -+ } -+ -+ -+ ElementLayernormCompute inv_scalar = cutlass::constants::one() / ElementLayernormCompute(extent_.column()); -+ -+ // Fragment is cleared for non-reachable columns so no need to check against column guard -+ accum_sum_element_ = element_sum_accumulator_(result); -+ -+ // Square sum is different. Non-reachable columns should've been computed for shift-k -+ // Otherwise we will incorrectly have some extra k^2 added into square sum. -+ if (column_guard) { -+ accum_sum_square_ = (kIsShiftedVariance) ? \ -+ square_sum_accumulator_(result, shift_k_frag_[iter_idx * kRowIterations + row_idx]) : \ -+ square_sum_accumulator_(result); -+ } -+ else { -+ accum_sum_square_ = ElementLayernormCompute(0); -+ } -+ -+ accum_sum_element_ *= inv_scalar; -+ accum_sum_square_ *= inv_scalar; -+ -+ // After performing the in-thread reduction, we then perform cross-thread / in-warp reduction -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = kHalfThreadsPerRow; i > 0; i >>= 1) { -+ accum_sum_element_ += __shfl_xor_sync(0xFFFFFFFF, accum_sum_element_, i); -+ accum_sum_square_ += __shfl_xor_sync(0xFFFFFFFF, accum_sum_square_, i); -+ } -+ -+ // Convert to the output -+ NumericArrayConverter output_converter; -+ OutputVector &output = reinterpret_cast(&fragment_D_)[frag_idx]; -+ output = output_converter(result); -+ } -+ -+ /// Called at the start of a row -+ CUTLASS_DEVICE -+ void end_row(int row_idx) { -+ -+ using ConvertVarianceOutput = cutlass::NumericConverter; -+ using ConvertMeanOutput = cutlass::NumericConverter; -+ -+ ConvertVarianceOutput convert_variance_output; -+ ConvertMeanOutput convert_mean_output; -+ -+ bool is_write_thread = (thread_offset_.row() < extent_.row() && (threadIdx.x % kThreadsPerRow) == 0); -+ int row_offset = thread_offset_.row() + blockIdx.y * extent_.row(); -+ -+ ElementVariance *curr_ptr_sum_square = params_.ptr_Variance + row_offset; -+ ElementMean *curr_ptr_element_sum = params_.ptr_Mean + row_offset; -+ -+ arch::global_store( -+ convert_variance_output(accum_sum_square_), -+ (void *)curr_ptr_sum_square, -+ is_write_thread); -+ -+ arch::global_store( -+ convert_mean_output(accum_sum_element_), -+ (void *)curr_ptr_element_sum, -+ is_write_thread); -+ -+ } -+ -+ /// Called after all accumulator elements have been visited -+ CUTLASS_DEVICE -+ void end_step(int step_idx) { -+ -+ iterator_D_.store(fragment_D_); -+ ++iterator_D_; -+ } -+ -+ /// Called after all steps have been completed -+ CUTLASS_DEVICE -+ void end_epilogue() { -+ -+ } -+ -+private: -+ -+ CUTLASS_DEVICE -+ ElementLayernormCompute load_shift_k_(int row_offset, bool is_load) { -+ using ConvertShiftK = cutlass::NumericConverter; -+ ConvertShiftK convert_shift_k; -+ ElementOutput shift_k_val; -+ -+ // Computes the address to load shift_k element -+ ElementOutput *curr_ptr_shift_k = params_.ptr_Shifted_K + row_offset; -+ // Conditionally loads from global memory -+ arch::global_load(shift_k_val, (void *)curr_ptr_shift_k, is_load); -+ // Converts data type to return -+ ElementLayernormCompute converted_shift_k_val = convert_shift_k(shift_k_val); -+ -+ return converted_shift_k_val; -+ } -+ -+ CUTLASS_DEVICE -+ ElementLayernormCompute square_sum_accumulator_(LayernormFragment const &accum) { -+ ElementLayernormCompute sum_ = ElementLayernormCompute(0); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < LayernormFragment::kElements; ++i) { -+ auto accum_ = accum[i]; -+ sum_ += accum_ * accum_; -+ } -+ -+ return sum_; -+ } -+ -+ CUTLASS_DEVICE -+ ElementLayernormCompute square_sum_accumulator_(LayernormFragment const &accum, ElementLayernormCompute shift_k_val) { -+ ElementLayernormCompute sum_ = ElementLayernormCompute(0); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < LayernormFragment::kElements; ++i) { -+ auto accum_ = accum[i] - shift_k_val; -+ sum_ += accum_ * accum_; -+ } -+ -+ return sum_; -+ } -+ -+ CUTLASS_DEVICE -+ ElementLayernormCompute element_sum_accumulator_(LayernormFragment const &accum) { -+ ElementLayernormCompute sum_ = ElementLayernormCompute(0); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < LayernormFragment::kElements; ++i) { -+ sum_ += accum[i]; -+ } -+ -+ return sum_; -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// -+template < -+ typename ElementInputA0_, -+ typename LayoutInputA0_, -+ typename ElementInputB0_, -+ typename LayoutInputB0_, -+ typename ElementOutput_, -+ typename LayoutOutput_, -+ typename ElementCompute_, -+ typename EpilogueFunctorOp_, -+ typename ThreadblockShape_, -+ typename WarpShape_, -+ typename InstructionShape_, -+ int Stages0, -+ int Stages1, -+ bool IsShiftedVariance_ = false -+> -+class GemmLayernorm { -+public: -+ -+ /////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ // -+ // Type definitions -+ // -+ -+ static bool const kInternalTranspose = cutlass::platform::is_same::value; -+ static bool const kIsShiftedVariance = IsShiftedVariance_; -+ -+ // These is mandatory layout. -+ using LayoutInputScaleBias = cutlass::layout::RowMajor; -+ -+ // These are mandatory data types. -+ using ElementLayernormCompute = float; -+ using ElementInputScaleBias = cutlass::half_t; -+ -+ // These are mandatory params required by mainloop fusion -+ using OperatorClass = cutlass::arch::OpClassTensorOp; -+ using ArchTag = cutlass::arch::Sm80; -+ -+ // These are mandatory layouts and data types -+ // that are inheritated from pre-defined params -+ -+ using LayoutSumSqr = LayoutInputScaleBias; -+ using LayoutSum = LayoutInputScaleBias; -+ -+ using ElementMean = ElementInputScaleBias; -+ using ElementVariance = ElementInputScaleBias; -+ -+ /////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ using LayoutInputA0 = LayoutInputA0_; -+ using LayoutInputB0 = LayoutInputB0_; -+ using LayoutInputA1 = LayoutOutput_; -+ using LayoutInputB1 = LayoutOutput_; -+ using LayoutOutputC0 = LayoutOutput_; -+ using LayoutOutputC1 = LayoutOutput_; -+ -+ using ElementInputA0 = ElementInputA0_; -+ using ElementInputB0 = ElementInputB0_; -+ using ElementOutputC0 = ElementOutput_; -+ using ElementCompute = ElementCompute_; -+ using ElementInputB1 = ElementInputB0_; -+ -+ using ElementInputA1 = ElementOutputC0; -+ using ElementOutputC1 = ElementOutputC0; -+ -+ using EpilogueFunctorOp = EpilogueFunctorOp_; -+ -+ using TensorRefA = TensorRef; -+ using TensorRefB = TensorRef; -+ using TensorRefC = TensorRef; -+ using TensorVariance = TensorRef; -+ using TensorMean = TensorRef; -+ -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ -+ static int const kStages0 = Stages0; -+ static int const kStages1 = Stages1; -+ -+ using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; -+ -+ /////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ using MapArguments = cutlass::gemm::kernel::detail::MapArguments< -+ ElementInputA0, -+ LayoutInputA0, -+ cutlass::ComplexTransform::kNone, -+ 128 / cutlass::sizeof_bits::value, -+ ElementInputB0, -+ LayoutInputB0, -+ cutlass::ComplexTransform::kNone, -+ 128 / cutlass::sizeof_bits::value, -+ LayoutOutputC0, -+ kInternalTranspose -+ >; -+ -+ using DefaultGemmKernel = typename cutlass::gemm::kernel::DefaultGemm< -+ typename MapArguments::ElementA, -+ typename MapArguments::LayoutA, -+ MapArguments::kAlignmentA, -+ typename MapArguments::ElementB, -+ typename MapArguments::LayoutB, -+ MapArguments::kAlignmentB, -+ ElementOutputC0, -+ typename MapArguments::LayoutC, -+ ElementCompute, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueFunctorOp, -+ SwizzleThreadBlock, -+ kStages0, -+ true, -+ typename cutlass::gemm::device::DefaultGemmConfiguration< -+ OperatorClass, ArchTag, ElementInputA0, ElementInputB0, ElementOutputC0, ElementCompute>::Operator, -+ cutlass::gemm::SharedMemoryClearOption::kNone -+ >::GemmKernel; -+ -+ /////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ // Epilogue visitor -+ using EpilogueVisitor = kernel::EpilogueVisitorLayerNorm< -+ ThreadblockShape, -+ DefaultGemmKernel::kThreadCount, -+ typename DefaultGemmKernel::Epilogue::OutputTileIterator, -+ typename DefaultGemmKernel::Epilogue::AccumulatorFragmentIterator::AccumulatorTile, -+ ElementCompute, -+ ElementVariance, -+ ElementMean, -+ ElementLayernormCompute, -+ EpilogueFunctorOp, -+ kIsShiftedVariance -+ >; -+ -+ /// Epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock::EpilogueWithVisitorFromExistingEpilogue< -+ EpilogueVisitor, -+ typename DefaultGemmKernel::Epilogue -+ >::Epilogue; -+ -+ // GEMM -+ using GemmEpilogueFusion = gemm::kernel::GemmWithEpilogueVisitor< -+ typename DefaultGemmKernel::Mma, -+ Epilogue, -+ SwizzleThreadBlock -+ >; -+ -+ using ApplyFinalReductionKernel = kernel::ApplyFinalReduction< -+ ElementVariance, -+ ElementMean, -+ ElementLayernormCompute, -+ ElementOutputC0, -+ ThreadblockShape, -+ kIsShiftedVariance -+ >; -+ -+using GemmMainloopFusion = typename cutlass::gemm::device::GemmLayernormMainloopFusion< -+ ElementInputA1, LayoutInputA1, -+ ElementInputB1, LayoutInputB1, -+ ElementInputScaleBias, LayoutInputScaleBias, -+ ElementOutputC1, LayoutOutputC1, -+ ElementCompute, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueFunctorOp, -+ SwizzleThreadBlock, -+ kStages1 -+>; -+ -+public: -+ -+ /// Arguments class -+ struct Arguments { -+ -+ typename GemmEpilogueFusion::Arguments gemm0; -+ typename GemmMainloopFusion::Arguments gemm1; -+ typename ApplyFinalReductionKernel::Arguments reduction; -+ cutlass::gemm::GemmCoord extend; -+ -+ // -+ // Methods -+ // -+ Arguments() { } -+ -+ Arguments( -+ cutlass::gemm::GemmCoord problem_size0, -+ cutlass::gemm::GemmCoord problem_size1, -+ ElementInputA0 * ptr_A, -+ ElementInputB0 * ptr_B, -+ ElementOutputC0 * ptr_C, -+ ElementOutputC0 * ptr_D, -+ ElementOutputC0 * ptr_E, -+ ElementOutputC0 * ptr_O, -+ int64_t ldm_A, -+ int64_t ldm_B, -+ int64_t ldm_C, -+ int64_t ldm_D, -+ int64_t ldm_E, -+ int64_t ldm_O, -+ typename EpilogueFunctorOp::Params linear_scaling, -+ TensorVariance ref_Variance_, -+ TensorMean ref_Mean_, -+ TensorVariance ref_Gamma_, -+ TensorMean ref_Beta_, -+ ElementOutputC0 *ptr_Shifted_K = nullptr -+ ): -+ gemm0( -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ {kInternalTranspose ? problem_size0.n() : problem_size0.m(),\ -+ kInternalTranspose ? problem_size0.m() : problem_size0.n(),\ -+ problem_size0.k()}, -+ {kInternalTranspose ? ptr_B : ptr_A, \ -+ kInternalTranspose ? ldm_B : ldm_A}, -+ {kInternalTranspose ? ptr_A : ptr_B, \ -+ kInternalTranspose ? ldm_A : ldm_B}, -+ typename EpilogueVisitor::Arguments( -+ linear_scaling, -+ {ptr_C, ldm_C}, -+ {ptr_D, ldm_D}, -+ ref_Variance_.data(), -+ ref_Mean_.data(), -+ ptr_Shifted_K -+ ) -+ ), -+ reduction( -+ MatrixCoord(kInternalTranspose ? problem_size0.n() : problem_size0.m(),\ -+ kInternalTranspose ? problem_size0.m() : problem_size0.n()), -+ ref_Variance_, -+ ref_Mean_, -+ ptr_Shifted_K -+ ), -+ gemm1( -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ problem_size1, -+ 1, -+ linear_scaling, -+ kInternalTranspose ? ptr_E : ptr_D, -+ kInternalTranspose ? ptr_D : ptr_E, -+ ref_Variance_.data(), -+ ref_Mean_.data(), -+ ref_Gamma_.data(), -+ ref_Beta_.data(), -+ ptr_O, -+ ptr_O, -+ problem_size1.m() * problem_size1.k(), -+ problem_size1.n() * problem_size1.k(), -+ problem_size1.n(), -+ problem_size1.n(), -+ problem_size1.k(), -+ problem_size1.k(), -+ problem_size1.m() * problem_size1.n(), -+ problem_size1.m() * problem_size1.n(), -+ kInternalTranspose ? ldm_E : ldm_D, -+ kInternalTranspose ? ldm_D : ldm_D, -+ ref_Variance_.layout().stride(0), -+ ref_Mean_.layout().stride(0), -+ ref_Gamma_.layout().stride(0), -+ ref_Beta_.layout().stride(0), -+ ldm_O, -+ ldm_O -+ ), -+ extend(problem_size0) -+ { -+ -+ } -+ }; -+ -+ struct Params { -+ -+ typename GemmEpilogueFusion::Params gemm0; -+ typename ApplyFinalReductionKernel::Params reduction; -+ MatrixCoord extend; -+ // -+ // Methods -+ // -+ Params() { } -+ -+ Params(Arguments const &args): -+ gemm0(args.gemm0), -+ reduction(args.reduction), -+ extend(MatrixCoord(args.extend.m(), args.extend.n())) -+ { -+ -+ } -+ }; -+ -+public: -+ -+ // Gemm -+ -+ -+ // -+ // Methods -+ // -+ -+private: -+ -+ Params params_; -+ GemmMainloopFusion gemm_fusion_op; -+ -+public: -+ -+ /// Ctor -+ GemmLayernorm() { -+ -+ } -+ -+ /// Initialize -+ Status initialize(Arguments const &args) { -+ -+ params_ = Params(args); -+ cutlass::Status status; -+ size_t workspace_size = gemm_fusion_op.get_workspace_size(args.gemm1); -+ cutlass::device_memory::allocation workspace(workspace_size); -+ status = gemm_fusion_op.can_implement(args.gemm1); -+ CUTLASS_CHECK(status); -+ -+ status = gemm_fusion_op.initialize(args.gemm1, workspace.get()); -+ CUTLASS_CHECK(status); -+ -+ return cutlass::Status::kSuccess; -+ } -+ -+ /// Run -+ Status run(cudaStream_t stream) { -+ -+ // -+ // Launch the GEMM + layernorm kernel -+ // -+ -+ dim3 gemm_grid = SwizzleThreadBlock().get_grid_shape(params_.gemm0.grid_tiled_shape); -+ dim3 gemm_block(GemmEpilogueFusion::kThreadCount, 1, 1); -+ -+ int gemm_smem_size = int(sizeof(typename GemmEpilogueFusion::SharedStorage)); -+ -+ cutlass::Kernel<<>>(params_.gemm0); -+ -+ cudaError_t result = cudaGetLastError(); -+ -+ if (result != cudaSuccess) { -+ return cutlass::Status::kErrorInternal; -+ } -+ -+ // -+ // Launch the ApplyFinalReductionKernel -+ // -+ -+ // always performs reduction from leading dimension -+ int leading_dim_0 = kInternalTranspose ? params_.extend.row() : params_.extend.column(); -+ int leading_dim_1 = kInternalTranspose ? params_.extend.column() : params_.extend.row(); -+ -+ int thread_per_block = 128; -+ int block_per_row = (leading_dim_1 + thread_per_block - 1) / thread_per_block; -+ if (block_per_row < 4) { -+ thread_per_block = 32; -+ block_per_row = (leading_dim_1 + thread_per_block - 1) / thread_per_block; -+ } -+ -+ dim3 final_reduction_block(thread_per_block); -+ dim3 final_reduction_grid(block_per_row); -+ -+ Kernel<<< -+ final_reduction_grid, final_reduction_block, sizeof(typename ApplyFinalReductionKernel::SharedStorage), stream -+ >>>(params_.reduction); -+ -+ result = cudaGetLastError(); -+ -+ if (result != cudaSuccess) { -+ return cutlass::Status::kErrorInternal; -+ } -+ -+ // -+ // Launch the GEMM + mainloop fusion kernel -+ // -+ -+ cutlass::Status status = gemm_fusion_op(); -+ CUTLASS_CHECK(status); -+ -+ return cutlass::Status::kSuccess; -+ } -+ -+ /// Function call operator -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/38_syr2k_grouped/syr2k_grouped.cu b/3rdparty/cutlass/examples/38_syr2k_grouped/syr2k_grouped.cu -new file mode 100644 -index 0000000..d8adb9c ---- /dev/null -+++ b/3rdparty/cutlass/examples/38_syr2k_grouped/syr2k_grouped.cu -@@ -0,0 +1,1466 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief SYR2K Grouped Example. -+ -+ This workload computes a batch of SYR2K operations with distinct problem sizes. This example closely -+ follows 24_gemm_grouped. -+ -+ Examples: -+ -+ # Runs a grouped SYR2K with 100 random problem sizes -+ $ ./examples/38_syr2k_grouped/38_syr2k_grouped --groups=100 -+ -+ # Runs a grouped SYR2K with 100 random problem sizes (with SYR2K-K dimension equal to 1024) -+ $ ./examples/38_syr2k_grouped/24_gemm_grouped --groups=100 --k=1024 --verbose=true -+ -+ # Runs a grouped SYR2K that is equivalent to a batched SYR2K -+ $ ./examples/38_syr2k_grouped/38_syr2k_grouped --groups=100 --n=1024 --k=1024 --verbose=true -+ -+ # Execute grouped SYR2K and profile with NSight -+ $ nv-nsight-cu-cli ./examples/38_syr2k_grouped/38_syr2k_grouped --n=256 --k=256 --verbose=true \ -+ --iterations=1 --reference-check=false -+ -+*/ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include -+#include -+#include -+#include -+#include -+#include -+ -+#include "cutlass/blas3.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/device_kernel.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+#include "cutlass/gemm/kernel/rank_2k_grouped.h" -+#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" -+#include "cutlass/gemm/device/rank_2k_grouped.h" -+#include "cutlass/gemm/device/rank_2k.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/device_memory.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/device/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Result structure -+struct Result { -+ -+ double runtime_ms; -+ double initialization_time_ms; -+ double gflops; -+ cutlass::Status status; -+ cudaError_t error; -+ bool passed; -+ -+ // -+ // Methods -+ // -+ -+ Result( -+ double runtime_ms = 0, -+ double initialization_time_ms = 0, -+ double gflops = 0, -+ cutlass::Status status = cutlass::Status::kSuccess, -+ cudaError_t error = cudaSuccess -+ ): -+ runtime_ms(runtime_ms), initialization_time_ms(initialization_time_ms), gflops(gflops), -+ status(status), error(error), passed(true) { } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ bool error; -+ bool reference_check; -+ bool profile_initialization; -+ bool sort_problems; -+ -+ std::vector problem_sizes; -+ -+ int alignment; -+ int problem_count; -+ int iterations; -+ int cuda_streams; -+ bool verbose; -+ float alpha; -+ float beta; -+ std::string benchmark_path; -+ -+ std::string output_tag; -+ std::ofstream output_file; -+ -+ using GroupScheduleMode = cutlass::gemm::kernel::GroupScheduleMode; -+ std::vector scheduler_modes; -+ -+ std::unordered_map -+ str_to_scheduler_mode = { -+ {"kDeviceOnly", GroupScheduleMode::kDeviceOnly}, -+ {"kHostPrecompute", GroupScheduleMode::kHostPrecompute} -+ }; -+ -+ struct GroupScheduleModeHash { -+ size_t operator()(GroupScheduleMode m) const { -+ return static_cast(m); -+ } -+ }; -+ -+ std::unordered_map -+ scheduler_mode_to_str = { -+ {GroupScheduleMode::kDeviceOnly, "kDeviceOnly"}, -+ {GroupScheduleMode::kHostPrecompute, "kHostPrecompute"} -+ }; -+ -+ std::vector all_scheduler_modes = {GroupScheduleMode::kDeviceOnly, GroupScheduleMode::kHostPrecompute}; -+ -+ // -+ // Methods -+ // -+ -+ Options(): -+ help(false), -+ error(false), -+ alignment(8), -+ reference_check(true), -+ profile_initialization(false), -+ sort_problems(false), -+ problem_count(5), -+ iterations(20), -+ cuda_streams(0), -+ verbose(false), -+ alpha(1), -+ beta(), -+ scheduler_modes({GroupScheduleMode::kDeviceOnly}) -+ { } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ return; -+ } -+ -+ cmd.get_cmd_line_argument("alignment", alignment, 8); -+ cmd.get_cmd_line_argument("groups", problem_count, 5); -+ cmd.get_cmd_line_argument("alpha", alpha, 1.0f); -+ cmd.get_cmd_line_argument("beta", beta, 0.0f); -+ cmd.get_cmd_line_argument("iterations", iterations, 20); -+ cmd.get_cmd_line_argument("streams", cuda_streams, 0); -+ cmd.get_cmd_line_argument("verbose", verbose, false); -+ cmd.get_cmd_line_argument("reference-check", reference_check, true); -+ cmd.get_cmd_line_argument("profile-initialization", profile_initialization, false); -+ cmd.get_cmd_line_argument("sort-problems", sort_problems, false); -+ cmd.get_cmd_line_argument("benchmark", benchmark_path); -+ -+ std::vector scheduler_mode_strs; -+ cmd.get_cmd_line_arguments("scheduler-modes", scheduler_mode_strs); -+ -+ if (!scheduler_mode_strs.empty()) { -+ scheduler_modes.clear(); -+ if (scheduler_mode_strs.size() == 1 && scheduler_mode_strs[0] == "all") { -+ scheduler_modes = all_scheduler_modes; -+ } else { -+ for (std::string precomp_str : scheduler_mode_strs) { -+ auto it = str_to_scheduler_mode.find(precomp_str); -+ if (it != str_to_scheduler_mode.end()) { -+ scheduler_modes.push_back(it->second); -+ } else if (precomp_str == "all") { -+ std::cerr << "Flag --scheduler-modes=all must not contain other scheduler modes in list." << std::endl; -+ error = true; -+ return; -+ } else { -+ std::cerr << "Unrecognized scheduler mode '" << precomp_str << "'" << std::endl; -+ error = true; -+ return; -+ } -+ } -+ } -+ } -+ -+ std::string output_path; -+ cmd.get_cmd_line_argument("tag", output_tag); -+ cmd.get_cmd_line_argument("output_file", output_path); -+ -+ if (!output_path.empty()) { -+ -+ std::ios_base::openmode open_mode = std::ios_base::out; -+ -+ std::ifstream input_file(output_path.c_str()); -+ -+ if (input_file.good()) { -+ open_mode = std::ios_base::app; -+ input_file.close(); -+ } -+ -+ output_file.open(output_path.c_str(), open_mode); -+ -+ if (output_file.good() && open_mode != std::ios_base::app) { -+ output_file << "Tag,Provider,Kind,Groups,Runtime,GFLOPs\n"; -+ } -+ } -+ -+ // Decide how to initialize the problems -+ if (!benchmark_path.empty()) { -+ if (!benchmark_problems()) { -+ error = true; -+ problem_sizes.clear(); -+ return; -+ } -+ } -+ else { -+ randomize_problems(cmd); -+ } -+ } -+ -+ void randomize_problems(cutlass::CommandLine &cmd) { -+ -+ // -+ // For now, randomly choose the problem sizes. -+ // -+ -+ int cmd_line_m = -1; -+ int cmd_line_n = -1; -+ int cmd_line_k = -1; -+ -+ cmd.get_cmd_line_argument("m", cmd_line_m); -+ cmd.get_cmd_line_argument("n", cmd_line_n); -+ cmd.get_cmd_line_argument("k", cmd_line_k); -+ -+ // SYR2K is defined via only N and K. -+ if (cmd_line_m != -1) { -+ std::cerr << "Parameter M is ignored for SYR2K\n"; -+ error = true; -+ return; -+ } -+ -+ problem_sizes.reserve(problem_count); -+ -+ for (int i = 0; i < problem_count; ++i) { -+ int n = cmd_line_n; -+ int k = cmd_line_k; -+ -+ if (n < 1) { -+ n = alignment * ((rand() % 256) + 1); -+ } -+ -+ if (k < 1) { -+ k = alignment * ((rand() % 256) + 1); -+ } -+ -+ // SYR2K is defined only in terms of N and K. Replicate N into -+ // the SYR2K-N dimension. -+ cutlass::gemm::GemmCoord problem(n, n, k); -+ -+ problem_sizes.push_back(problem); -+ } -+ } -+ -+ /// Load a benchmark -+ bool benchmark_problems() { -+ std::ifstream file(benchmark_path); -+ if (!file.good()) { -+ return false; -+ } -+ -+ while (file.good()) { -+ -+ int idx = -1; -+ std::string extent_str; -+ -+ file >> idx >> extent_str; -+ -+ if (idx < 0 || extent_str.empty()) { -+ break; -+ } -+ -+ cutlass::gemm::GemmCoord extent; -+ std::vector tokens; -+ -+ cutlass::CommandLine::tokenize(tokens, extent_str, 'x'); -+ -+ for (int i = 0; i < int(tokens.size()); ++i) { -+ int x = std::atoi(tokens.at(i).c_str()); -+ -+ // round up -+ if (x % alignment) { -+ x += (alignment - (x % alignment)); -+ } -+ -+ extent.at(i) = x; -+ } -+ -+ if (extent.product()) { -+ problem_sizes.push_back(extent); -+ } -+ } -+ -+ problem_count = int(problem_sizes.size()); -+ return true; -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "38_syr2k_grouped\n\n" -+ << " This example profiles the performance of a 'grouped' SYR2K kernel. This example closely follows 24_gemm_grouped\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --benchmark= Executes a benchmark problem size.\n" -+ << " --output_file= Path to a CSV file to output results. If it exists already, results are appended.\n" -+ << " --tag= String tag to prepend to the CSV file.\n" -+ << " --groups= Number of individual SYR2K problems (default: --groups=15)\n" -+ << " --m= Sets the M dimension for all groups. Otherwise, it is selected randomly\n" -+ << " --n= Sets the N dimension for all groups. Otherwise, it is selected randomly\n" -+ << " --k= Sets the K dimension for all groups. Otherwise, it is selected randomly\n" -+ << " --alpha= Epilogue scalar alpha (real part)\n" -+ << " --beta= Epilogue scalar beta (real part)\n" -+ << " --scheduler-modes= List of scheduler modes to be profile for grouped GEMM scheduler (default: --scheduler_modes=kDeviceOnly)\n" -+ << " --iterations= Number of profiling iterations to perform.\n" -+ << " --reference-check= If true, performs reference check.\n" -+ << " --verbose= If true, prints problem sizes and batching structure.\n" -+ << " --profile-initialization= If true, profiles the device-level kernel's initialization.\n" -+ << " --sort-problems= If true, sorts problem sizes in descending order of SYR2K-K dimension.\n"; -+ -+ out << "\n\nExamples:\n\n" -+ -+ << "# Runs a grouped SYR2K with 100 random problem sizes\n" -+ << "$ ./examples/38_syr2k_grouped/38_syr2k_grouped --groups=100\n\n" -+ -+ << "# Runs a grouped SYR2K with 100 random problem sizes (with K dimension equal to 1024)\n" -+ << "$ ./examples/38_syr2k_grouped/38_syr2k_grouped --groups=100 --k=1024 --verbose=true\n\n" -+ -+ << "# Runs a grouped SYR2K that is equivalent to a batched SYR2K\n" -+ << "$ ./examples/38_syr2k_grouped/38_syr2k_grouped --groups=100 --n=1024 --k=1024 --verbose=true\n\n" -+ -+ << "# Runs a grouped SYR2K with each different scheduler mode\n" -+ << "$ ./examples/38_syr2k_grouped/38_syr2k_grouped --scheduler-modes=all\n\n" -+ -+ << "# Runs a grouped SYR2K with each different scheduler mode and profiles host-side initialization time\n" -+ << "$ ./examples/38_syr2k_grouped/38_syr2k_grouped --scheduler-modes=all --profile-initialization=true\n\n" -+ -+ << "# Runs a grouped SYR2K problem given an externally supplied benchmark file. This is a text file in which\n" -+ << "# Each line contains a unique group index and an MxNxK triple indicating problemsize. NOTE that the\n" -+ << "# GEMM-M and GEMM-N dimensions must match.\n" -+ << "#\n" -+ << "# For example, assume the following are the contents of 'problems.txt'\n" -+ << "#\n" -+ << "# 0 256x256x520\n" -+ << "# 1 264x264x1024\n" -+ << "# 2 48x48x1024\n" -+ << "#\n" -+ << "$ ./examples/38_syr2k_grouped/38_syr2k_grouped --benchmark=problems.txt\n\n" -+ -+ << "# Execute Grouped SYR2K and profile with NSight\n" -+ << "$ nv-nsight-cu-cli ./examples/24_gemm_grouped/24_gemm_grouped --n=256 --k=256 --verbose=true --iterations=1 --reference-check=false\n\n"; -+ -+ return out; -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of real-valued multiply-adds -+ int64_t fmas = int64_t(); -+ -+ for (auto const & problem : problem_sizes) { -+ fmas += problem.product(); -+ } -+ -+ // SYR2K is defined as (A x BT) + (B x AT), so the number of FMAs is twice that in a GEMM -+ fmas *= 2; -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class BaseTestbed { -+public: -+ // -+ // Type definitions -+ // -+ -+ using ElementA = typename Rank2K::ElementA; -+ using ElementB = typename Rank2K::ElementB; -+ using ElementC = typename Rank2K::ElementC; -+ using ElementAccumulator = typename Rank2K::ElementAccumulator; -+ -+ using EpilogueOutputOp = typename Rank2K::Rank2Kkernel::Epilogue::OutputOp; -+ using ElementCompute = typename EpilogueOutputOp::ElementCompute; -+ -+ using LayoutA = typename Rank2K::LayoutA; -+ using LayoutB = typename Rank2K::LayoutB; -+ using LayoutC = typename Rank2K::LayoutC; -+ -+ using MatrixCoord = typename LayoutC::TensorCoord; -+ -+ // -+ // Data members -+ // -+ -+ Options & options; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint32_t seed; -+ -+ cutlass::DeviceAllocation problem_sizes_device; -+ -+ std::vector offset_A; -+ std::vector offset_B; -+ std::vector offset_C; -+ std::vector offset_D; -+ -+ std::vector lda_host; -+ std::vector ldb_host; -+ std::vector ldc_host; -+ std::vector ldd_host; -+ -+ cutlass::DeviceAllocation lda; -+ cutlass::DeviceAllocation ldb; -+ cutlass::DeviceAllocation ldc; -+ cutlass::DeviceAllocation ldd; -+ -+ cutlass::DeviceAllocation block_A; -+ cutlass::DeviceAllocation block_B; -+ cutlass::DeviceAllocation block_C; -+ cutlass::DeviceAllocation block_D; -+ -+ cutlass::DeviceAllocation ptr_A; -+ cutlass::DeviceAllocation ptr_B; -+ cutlass::DeviceAllocation ptr_C; -+ cutlass::DeviceAllocation ptr_D; -+ -+ BaseTestbed( -+ Options &options_, -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint32_t seed_ = 3080 -+ ): -+ options(options_), init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } -+ -+ int problem_count() const { -+ return options.problem_count; -+ } -+ -+ /// Helper to initialize a tensor view -+ template -+ void initialize_tensor( -+ Element *ptr, -+ size_t capacity, -+ cutlass::Distribution::Kind dist_kind, -+ uint32_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ Element scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ if (cutlass::sizeof_bits::value <= 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } -+ else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::device::BlockFillRandomUniform( -+ ptr, capacity, seed, scope_max, scope_min, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::device::BlockFillRandomGaussian( -+ ptr, capacity, seed, Element(), Element(0.5f)); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ // Fill with increasing elements -+ cutlass::reference::device::BlockFillSequential( -+ ptr, capacity, Element(1), Element()); -+ } -+ else { -+ -+ // Fill with all 1s -+ cutlass::reference::device::BlockFillSequential( -+ ptr, capacity, Element(), Element(1)); -+ } -+ } -+ -+ /// Allocates device-side data -+ void allocate() { -+ int64_t total_elements_A = 0; -+ int64_t total_elements_B = 0; -+ int64_t total_elements_C = 0; -+ int64_t total_elements_D = 0; -+ -+ lda_host.resize(problem_count()); -+ ldb_host.resize(problem_count()); -+ ldc_host.resize(problem_count()); -+ ldd_host.resize(problem_count()); -+ -+ for (int32_t i = 0; i < problem_count(); ++i) { -+ -+ auto problem = options.problem_sizes.at(i); -+ -+ lda_host.at(i) = LayoutA::packed({problem.n(), problem.k()}).stride(0); -+ ldb_host.at(i) = LayoutB::packed({problem.n(), problem.k()}).stride(0); -+ ldc_host.at(i) = LayoutC::packed({problem.n(), problem.n()}).stride(0); -+ ldd_host.at(i) = LayoutC::packed({problem.n(), problem.n()}).stride(0); -+ -+ offset_A.push_back(total_elements_A); -+ offset_B.push_back(total_elements_B); -+ offset_C.push_back(total_elements_C); -+ offset_D.push_back(total_elements_D); -+ -+ int64_t elements_A = problem.n() * problem.k(); -+ int64_t elements_B = problem.n() * problem.k(); -+ int64_t elements_C = problem.n() * problem.n(); -+ int64_t elements_D = problem.n() * problem.n(); -+ -+ total_elements_A += elements_A; -+ total_elements_B += elements_B; -+ total_elements_C += elements_C; -+ total_elements_D += elements_D; -+ } -+ -+ lda.reset(problem_count()); -+ ldb.reset(problem_count()); -+ ldc.reset(problem_count()); -+ ldd.reset(problem_count()); -+ -+ block_A.reset(total_elements_A); -+ block_B.reset(total_elements_B); -+ block_C.reset(total_elements_C); -+ block_D.reset(total_elements_D); -+ } -+ -+ /// Initializes device-side data -+ void initialize() { -+ problem_sizes_device.reset(problem_count()); -+ problem_sizes_device.copy_from_host(options.problem_sizes.data()); -+ -+ lda.copy_from_host(lda_host.data()); -+ ldb.copy_from_host(ldb_host.data()); -+ ldc.copy_from_host(ldc_host.data()); -+ ldd.copy_from_host(ldd_host.data()); -+ -+ // -+ // Assign pointers -+ // -+ -+ std::vector ptr_A_host(problem_count()); -+ std::vector ptr_B_host(problem_count()); -+ std::vector ptr_C_host(problem_count()); -+ std::vector ptr_D_host(problem_count()); -+ -+ for (int32_t i = 0; i < problem_count(); ++i) { -+ ptr_A_host.at(i) = block_A.get() + offset_A.at(i); -+ ptr_B_host.at(i) = block_B.get() + offset_B.at(i); -+ ptr_C_host.at(i) = block_C.get() + offset_C.at(i); -+ ptr_D_host.at(i) = block_D.get() + offset_D.at(i); -+ } -+ -+ ptr_A.reset(problem_count()); -+ ptr_A.copy_from_host(ptr_A_host.data()); -+ -+ ptr_B.reset(problem_count()); -+ ptr_B.copy_from_host(ptr_B_host.data()); -+ -+ ptr_C.reset(problem_count()); -+ ptr_C.copy_from_host(ptr_C_host.data()); -+ -+ ptr_D.reset(problem_count()); -+ ptr_D.copy_from_host(ptr_D_host.data()); -+ -+ // -+ // Initialize the problems of the workspace -+ // -+ -+ initialize_tensor(block_A.get(), block_A.size(), init_A, seed * 2021); -+ initialize_tensor(block_B.get(), block_B.size(), init_B, seed * 2022); -+ initialize_tensor(block_C.get(), block_C.size(), init_C, seed * 2023); -+ -+ cutlass::reference::device::BlockFillSequential( -+ block_D.get(), block_D.size(), ElementC(), ElementC()); -+ } -+ -+ /// Verifies the result is a SYR2K -+ bool verify() { -+ -+ bool passed = true; -+ -+ for (int32_t i = 0; i < problem_count(); ++i) { -+ cutlass::gemm::GemmCoord problem = options.problem_sizes.at(i); -+ -+ LayoutA layout_A(lda_host.at(i)); -+ LayoutB layout_B(ldb_host.at(i)); -+ LayoutC layout_C(ldc_host.at(i)); -+ LayoutC layout_D(ldd_host.at(i)); -+ -+ cutlass::HostTensor host_A( -+ typename LayoutA::TensorCoord(problem.n(), problem.k()), /*device_backed=*/false); -+ cutlass::HostTensor host_B( -+ typename LayoutB::TensorCoord(problem.n(), problem.k()), /*device_backed=*/false); -+ cutlass::HostTensor host_C( -+ typename LayoutC::TensorCoord(problem.n(), problem.n()), /*device_backed=*/false); -+ cutlass::HostTensor host_D( -+ typename LayoutC::TensorCoord(problem.n(), problem.n()), /*device_backed=*/false); -+ -+ cutlass::device_memory::copy_to_host(host_A.host_data(), block_A.get() + offset_A.at(i), problem.n() * problem.k()); -+ cutlass::device_memory::copy_to_host(host_B.host_data(), block_B.get() + offset_B.at(i), problem.n() * problem.k()); -+ cutlass::device_memory::copy_to_host(host_C.host_data(), block_C.get() + offset_C.at(i), problem.n() * problem.n()); -+ cutlass::reference::host::BlockFillSequential( -+ host_D.host_data(), problem.n() * problem.n(), ElementC(), ElementC()); -+ -+ MatrixCoord extent_C{problem.n(), problem.n()}; -+ -+ // Reference Rank2K -+ cutlass::reference::host::Rank2KComplex< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementC, ElementAccumulator -+ >( -+ problem, -+ (double)options.alpha, -+ host_A.host_view(), -+ Rank2K::kTransformA, -+ host_B.host_view(), -+ Rank2K::kTransformB, -+ (double)options.beta, -+ host_C.host_view(), -+ host_D.host_view(), -+ ElementAccumulator(0), -+ Rank2K::kFillModeC, -+ Rank2K::kBlasMode -+ ); -+ -+ // Copy to host memory -+ std::vector matrix_D(layout_D.capacity(extent_C)); -+ cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + offset_D.at(i), matrix_D.size()); -+ -+ cutlass::TensorView view_D(matrix_D.data(), layout_D, extent_C); -+ cutlass::TensorView view_Ref = host_D.host_view(); -+ -+ // Reference check -+ passed = cutlass::reference::host::TensorEquals(view_D, view_Ref); -+ -+ if (!passed) { -+ std::cerr << "\n***\nError - problem " << i << " failed the QA check\n***\n" << std::endl; -+ return passed; -+ } -+ } -+ -+ return passed; -+ } -+}; -+ -+template -+class TestbedConventional : BaseTestbed { -+public: -+ TestbedConventional( -+ Options &options_, -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint32_t seed_ = 3080 -+ ): BaseTestbed(options_, init_A_, init_B_, init_C_, seed_) {} -+ -+ /// Verbose printing of problem sizes -+ void print_problem_sizes() { -+ -+ // Print groups -+ std::cout << this->problem_count() << " groups:\n"; -+ -+ int32_t idx = 0; -+ int64_t total_tiles = 0; -+ -+ for (auto const & problem : this->options.problem_sizes) { -+ int tiles = -+ ((problem.m() + Rank2K::ThreadblockShape::kM - 1) / Rank2K::ThreadblockShape::kM) * -+ ((problem.n() + Rank2K::ThreadblockShape::kN - 1) / Rank2K::ThreadblockShape::kN); -+ -+ total_tiles += tiles; -+ -+ std::cout << " [" << idx << "]: " -+ << problem.m() << "-by-" << problem.n() << "-by-" << problem.k() -+ << " (" << tiles << " threadblock tiles)" << "\n"; -+ -+ ++idx; -+ } -+ std::cout << std::endl; -+ } -+ -+ /// Executes a conventional SYR2K kernel. -+ Result profile() { -+ std::cout << "Conventional Rank2K:\n" -+ << "====================================================" << std::endl; -+ -+ Result result; -+ result.passed = false; -+ -+ // Initialize the problem -+ this->allocate(); -+ this->initialize(); -+ -+ if (this->options.verbose) { -+ print_problem_sizes(); -+ } -+ -+ // -+ // Create CUDA streams to maximize concurrency of SYR2K kernels -+ // -+ int32_t effective_streams = (this->options.cuda_streams ? this->options.cuda_streams : 1); -+ std::vector cuda_streams; -+ char const *provider = "CUTLASS"; -+ -+ // -+ // Warmup run -+ // -+ -+ if (this->options.cuda_streams) { -+ for (int i = 0; i < this->options.cuda_streams; ++i) { -+ cudaStream_t stream; -+ -+ result.error = cudaStreamCreate(&stream); -+ if (result.error != cudaSuccess) { -+ std::cerr << "Failed to create CUDA stream." << std::endl; -+ return result; -+ } -+ cuda_streams.push_back(stream); -+ } -+ } -+ else { -+ cuda_streams.push_back(nullptr); -+ } -+ -+ // Use 'D' for the in/out workspace -+ this->block_D.copy_from_device(this->block_C.get()); -+ -+ for (int i = 0; i < this->options.problem_sizes.size(); ++i) { -+ cutlass::gemm::GemmCoord const & problem = this->options.problem_sizes[i]; -+ int32_t batch_count = 1; -+ int64_t lda = this->lda_host.at(i); -+ int64_t ldb = this->ldb_host.at(i); -+ int64_t ldc = this->ldc_host.at(i); -+ typename Rank2K::ElementA* ptrA = this->block_A.get() + this->offset_A.at(i); -+ typename Rank2K::ElementB* ptrB = this->block_B.get() + this->offset_B.at(i); -+ typename Rank2K::ElementC* ptrC = this->block_C.get() + this->offset_C.at(i); -+ typename Rank2K::ElementC* ptrD = this->block_D.get() + this->offset_D.at(i); -+ -+ // -+ // Initialize the CUTLASS SYR2K operator -+ // -+ -+ // Configure the SYR2K arguments -+ typename Rank2K::EpilogueOutputOp::Params epilogue_op(this->options.alpha, this->options.beta); -+ -+ typename Rank2K::Arguments arguments{ -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ problem, -+ batch_count, -+ epilogue_op, -+ (void const *)ptrA, -+ (void const *)ptrB, -+ (void const *)ptrC, -+ (void *)ptrD, -+ int64_t(), -+ int64_t(), -+ int64_t(), -+ int64_t(), -+ int64_t(lda), -+ int64_t(ldb), -+ int64_t(ldc), -+ int64_t(ldc) -+ }; -+ -+ Rank2K rank2k_op; -+ -+ cutlass::Status status = rank2k_op.initialize(arguments); -+ -+ if (status != cutlass::Status::kSuccess) { -+ std::cerr << "CUTLASS error on line " << __LINE__ << std::endl; -+ return result; -+ } -+ -+ status = rank2k_op(); -+ -+ if (status != cutlass::Status::kSuccess) { -+ std::cerr << "CUTLASS error on line " << __LINE__ << std::endl; -+ return result; -+ } -+ } -+ -+ // -+ // Wait for completion -+ // -+ -+ result.error = cudaDeviceSynchronize(); -+ -+ if (result.error != cudaSuccess) { -+ std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); -+ return result; -+ } -+ -+ // -+ // Construct events -+ // -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ } -+ -+ // -+ // Wait for completion -+ // -+ -+ result.error = cudaDeviceSynchronize(); -+ -+ if (result.error != cudaSuccess) { -+ std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); -+ return result; -+ } -+ -+ // Record an event at the start of a series of SYR2K operations -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // -+ // Run profiling loop -+ // -+ -+ int last_stream_idx = 0; -+ -+ for (int iter = 0; iter < this->options.iterations; ++iter) { -+ for (int i = 0; i < this->options.problem_sizes.size(); ++i) { -+ cutlass::gemm::GemmCoord const & problem = this->options.problem_sizes[i]; -+ int32_t batch_count = 1; -+ int64_t lda = this->lda_host.at(i); -+ int64_t ldb = this->ldb_host.at(i); -+ int64_t ldc = this->ldc_host.at(i); -+ typename Rank2K::ElementA* ptrA = this->block_A.get() + this->offset_A.at(i); -+ typename Rank2K::ElementB* ptrB = this->block_B.get() + this->offset_B.at(i); -+ typename Rank2K::ElementC* ptrC = this->block_C.get() + this->offset_C.at(i); -+ typename Rank2K::ElementC* ptrD = this->block_D.get() + this->offset_D.at(i); -+ -+ last_stream_idx = (i % effective_streams); -+ -+ // -+ // Initialize the CUTLASS SYR2K operator -+ // -+ -+ // Configure the SYR2K arguments -+ typename Rank2K::EpilogueOutputOp::Params epilogue_op(this->options.alpha, this->options.beta); -+ -+ typename Rank2K::Arguments arguments{ -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ problem, -+ batch_count, -+ epilogue_op, -+ (void const *)ptrA, -+ (void const *)ptrB, -+ (void const *)ptrC, -+ (void *)ptrD, -+ int64_t(), -+ int64_t(), -+ int64_t(), -+ int64_t(), -+ int64_t(lda), -+ int64_t(ldb), -+ int64_t(ldc), -+ int64_t(ldc) -+ }; -+ -+ Rank2K rank2k_op; -+ -+ cutlass::Status status = rank2k_op.initialize(arguments); -+ -+ if (status != cutlass::Status::kSuccess) { -+ std::cerr << "CUTLASS error on line " << __LINE__ << std::endl; -+ return result; -+ } -+ -+ status = rank2k_op(cuda_streams[last_stream_idx]); -+ -+ if (status != cutlass::Status::kSuccess) { -+ std::cerr << "CUTLASS error on line " << __LINE__ << std::endl; -+ return result; -+ } -+ } -+ } -+ -+ // -+ // Stop profiling loop -+ // -+ -+ // Record an event when the SYR2K operations have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // -+ // Wait for work to be completed -+ // -+ -+ result.error = cudaDeviceSynchronize(); -+ -+ if (result.error != cudaSuccess) { -+ std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Compute average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(this->options.iterations); -+ result.gflops = this->options.gflops(result.runtime_ms / 1000.0); -+ -+ // -+ // Cleanup -+ // -+ -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ for (auto stream : cuda_streams) { -+ if (stream) { -+ (void)cudaStreamDestroy(stream); -+ } -+ } -+ -+ std::cout << " " << this->options.problem_sizes.size() << " conventional Rank2Ks launched" << std::endl; -+ std::cout << std::endl; -+ std::cout << " " << "Conventional Runtime: " << result.runtime_ms << " ms" << std::endl; -+ std::cout << " " << "Conventional GFLOPS: " << result.gflops << std::endl; -+ -+ if (this->options.output_file.good()) { -+ this->options.output_file << this->options.output_tag << "," << provider << ",conventional," -+ << this->problem_count() << "," << result.runtime_ms << "," << result.gflops << std::endl; -+ } -+ -+ result.passed = true; -+ return result; -+ } -+}; -+ -+template -+class TestbedGrouped : BaseTestbed { -+public: -+ TestbedGrouped( -+ Options &options_, -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint32_t seed_ = 3080 -+ ) : BaseTestbed(options_, init_A_, init_B_, init_C_, seed_) {} -+ -+ // Redefine Rank2K with different GroupScheduleMode_ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ typename Rank2K_::ElementA, typename Rank2K_::LayoutA, Rank2K_::kTransformA, Rank2K_::kAlignmentA, -+ typename Rank2K_::ElementB, typename Rank2K_::LayoutB, Rank2K_::kTransformB, Rank2K_::kAlignmentB, -+ typename Rank2K_::ElementC, typename Rank2K_::LayoutC, Rank2K_::kFillModeC, -+ typename Rank2K_::ElementAccumulator, -+ typename Rank2K_::OperatorClass, -+ typename Rank2K_::ArchTag, -+ typename Rank2K_::ThreadblockShape, -+ typename Rank2K_::WarpShape, -+ typename Rank2K_::InstructionShape, -+ typename Rank2K_::EpilogueOutputOp, -+ typename Rank2K_::ThreadblockSwizzle, -+ Rank2K_::kStages, -+ typename Rank2K_::Operator::ArchMmaOperator::Operator, -+ Rank2K_::kBlasMode, -+ GroupScheduleMode_>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ /// Verbose printing of problem sizes -+ void print_problem_sizes() { -+ -+ // Print groups -+ std::cout << this->problem_count() << " groups:\n"; -+ -+ int32_t idx = 0; -+ int64_t total_tiles = 0; -+ -+ for (auto const & problem : this->options.problem_sizes) { -+ int tiles = Rank2K::problem_tile_count(problem); -+ total_tiles += tiles; -+ -+ std::cout << " [" << idx << "]: " -+ << problem.m() << "-by-" << problem.n() << "-by-" << problem.k() -+ << " (" << tiles << " threadblock tiles)" << "\n"; -+ -+ ++idx; -+ } -+ std::cout << std::endl; -+ } -+ -+ /// Sort problems in descending order of problem-K dimension -+ void sort_problems() { -+ Rank2K::sort_problems(this->options.problem_count, -+ this->options.problem_sizes.data(), -+ this->lda_host.data(), -+ this->ldb_host.data(), -+ this->ldc_host.data(), -+ this->ldd_host.data(), -+ this->offset_A.data(), -+ this->offset_B.data(), -+ this->offset_C.data(), -+ this->offset_D.data()); -+ } -+ -+ /// Executes a grouped kernel and measures runtime. -+ Result profile() { -+ std::string sched_mode = this->options.scheduler_mode_to_str.find(GroupScheduleMode_)->second; -+ std::cout << std::endl; -+ std::cout << "Grouped Rank2K (CUTLASS) with mode " << sched_mode << ":\n" -+ << "====================================================" << std::endl; -+ -+ Result result; -+ -+ int threadblock_count = Rank2K::sufficient(this->options.problem_sizes.data(), this->options.problem_count); -+ -+ // Early exit -+ if (!threadblock_count) { -+ std::cout << "Active CUDA device lacks hardware resources to run CUTLASS Grouped SYR2K kernel." << std::endl; -+ return result; -+ } -+ -+ result.passed = false; -+ -+ // Initialize the problem -+ this->allocate(); -+ if (this->options.sort_problems) { -+ sort_problems(); -+ } -+ this->initialize(); -+ -+ if (this->options.verbose) { -+ print_problem_sizes(); -+ } -+ -+ // Configure the Rank2K arguments -+ typename Rank2K::EpilogueOutputOp::Params epilogue_op(this->options.alpha, this->options.beta); -+ -+ // Configure Rank2K arguments -+ typename Rank2K::Arguments args( -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ this->problem_sizes_device.get(), -+ this->problem_count(), -+ threadblock_count, -+ epilogue_op, -+ this->ptr_A.get(), -+ this->ptr_B.get(), -+ this->ptr_C.get(), -+ this->ptr_D.get(), -+ this->lda.get(), -+ this->ldb.get(), -+ this->ldc.get(), -+ this->ldd.get(), -+ this->options.problem_sizes.data() -+ ); -+ -+ // Initialize the Rank2K object -+ Rank2K rank2k; -+ size_t workspace_size = rank2k.get_workspace_size(args); -+ cutlass::DeviceAllocation workspace(workspace_size); -+ -+ result.status = rank2k.initialize(args, workspace.get()); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to initialize CUTLASS Grouped Rank2K kernel." << std::endl; -+ return result; -+ } -+ -+ // Run the grouped Rank2K object -+ result.status = rank2k.run(); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to run CUTLASS Grouped Rank2K kernel." << std::endl; -+ return result; -+ } -+ -+ // Wait for completion -+ result.error = cudaDeviceSynchronize(); -+ -+ if (result.error != cudaSuccess) { -+ std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); -+ return result; -+ } -+ -+ // -+ // Verify correctness -+ // -+ result.passed = true; -+ if (this->options.reference_check) { -+ result.passed = this->verify(); -+ } -+ -+ // -+ // Warm-up run of the grouped Rank2K object -+ // -+ result.status = rank2k.run(); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to run CUTLASS Grouped Rank2K kernel." << std::endl; -+ return result; -+ } -+ -+ // -+ // Construct events -+ // -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ } -+ -+ // Record an event at the start of a series of SYR2K operations -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // -+ // Run profiling loop -+ // -+ -+ for (int iter = 0; iter < this->options.iterations; ++iter) { -+ rank2k(); -+ } -+ -+ // -+ // Stop profiling loop -+ // -+ -+ // Record an event when the Rank2K operations have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Compute average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(this->options.iterations); -+ result.gflops = this->options.gflops(result.runtime_ms / 1000.0); -+ -+ // -+ // Cleanup -+ // -+ -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ // Optionally profile initialization -+ if (this->options.profile_initialization) { -+ // Warm up -+ rank2k.initialize(args, workspace.get()); -+ -+ auto start_time = std::chrono::high_resolution_clock::now(); -+ for (int32_t i = 0; i < this->options.iterations; ++i) { -+ rank2k.initialize(args, workspace.get()); -+ } -+ auto end_time = std::chrono::high_resolution_clock::now(); -+ -+ std::chrono::duration duration = end_time - start_time; -+ duration /= double(this->options.iterations); -+ result.initialization_time_ms = duration.count(); -+ } -+ -+ int64_t total_tiles = Rank2K::group_tile_count(args); -+ std::cout << " " << total_tiles << " total threadblock tiles." << std::endl; -+ -+ std::cout << std::endl; -+ std::cout << " " << "Grouped Runtime: " << result.runtime_ms << " ms" << std::endl; -+ std::cout << " " << "Grouped GFLOPs: " << result.gflops << std::endl; -+ if (this->options.profile_initialization) { -+ std::cout << " " << "Init Runtime: " << result.initialization_time_ms << " ms" << std::endl; -+ } -+ -+ if (this->options.output_file.good()) { -+ this->options.output_file << this->options.output_tag << ",CUTLASS,grouped-" << sched_mode << "," -+ << this->problem_count() << "," << result.runtime_ms << "," << result.gflops << std::endl; -+ } -+ -+ std::cout << "\nPassed\n"; -+ -+ return result; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return -1; -+ } -+ -+ if (__CUDACC_VER_MAJOR__ < 11 || props.major < 8) { -+ -+ // -+ // This example requires an NVIDIA Ampere-architecture GPU. -+ // -+ -+ std::cout -+ << "CUTLASS's Grouped Rank2K example requires a GPU of NVIDIA's Ampere Architecture or " -+ << "later (compute capability 80 or greater).\n"; -+ -+ return 0; -+ } -+ -+ // -+ // Parse options -+ // -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ if (options.error) { -+ std::cerr << "Aborting execution." << std::endl; -+ return -1; -+ } -+ -+ // -+ // Define the Grouped and Conventional Rank2K types -+ // -+ -+ using ElementA = double; -+ using ElementB = double; -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ const cutlass::FillMode kFillModeC = cutlass::FillMode::kLower; -+ const int kAlignmentA = 1; -+ const int kAlignmentB = 1; -+ const cutlass::ComplexTransform kTransformA = cutlass::ComplexTransform::kNone; -+ const cutlass::ComplexTransform kTransformB = cutlass::ComplexTransform::kNone; -+ -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using OperatorClass = cutlass::arch::OpClassTensorOp; -+ using ArchTag = cutlass::arch::Sm80; -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 1, -+ ElementAccumulator, ElementAccumulator>; -+ -+ // NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels. -+ // This parameter is passed in at present to match the APIs of other kernels. The parameter -+ // is unused within the kernel. -+ using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; -+ -+ const int kStages = 4; -+ const bool kSplitKSerial = false; -+ using Operator = cutlass::arch::OpMultiplyAdd; -+ const cutlass::BlasMode kBlasMode = cutlass::BlasMode::kSymmetric; -+ -+ // Define a grouped Rank2K kernel with all template parameters set except -+ // for scheduling mode. This will be used as the template for all scheduling -+ // modes executed. -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, kTransformA, kAlignmentA, -+ ElementB, LayoutB, kTransformB, kAlignmentB, -+ ElementOutput, LayoutC, kFillModeC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ kStages, -+ Operator, -+ kBlasMode>::Rank2Kkernel; -+ -+ using Rank2KGrouped = cutlass::gemm::device::Rank2KGrouped; -+ -+ // Rank2k operator -+ using Rank2KConventional = cutlass::gemm::device::Rank2K< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementOutput, LayoutC, kFillModeC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ kStages, -+ kAlignmentA, -+ kAlignmentB, -+ kSplitKSerial, -+ Operator, -+ kTransformA, -+ kTransformB, -+ kBlasMode -+ >; -+ -+ // -+ // Profile it -+ // -+ -+ TestbedConventional testbed(options); -+ -+ Result result = testbed.profile(); -+ if (!result.passed) { -+ std::cout << "Profiling CUTLASS conventional Rank2K has failed.\n"; -+ std::cout << "\nFailed\n"; -+ return -1; -+ } -+ -+ using GroupScheduleMode = cutlass::gemm::kernel::GroupScheduleMode; -+ for (GroupScheduleMode mode : options.scheduler_modes) { -+ Result result; -+ switch (mode) { -+ case GroupScheduleMode::kDeviceOnly: -+ { -+ TestbedGrouped runner(options); -+ result = runner.profile(); -+ break; -+ } -+ case GroupScheduleMode::kHostPrecompute: -+ { -+ TestbedGrouped runner(options); -+ result = runner.profile(); -+ break; -+ } -+ } -+ -+ if (result.error != cudaSuccess) { -+ return 1; -+ } -+ -+ // Override verbose flag to avoid printing duplicate information for each scheduling mode -+ options.verbose = false; -+ } -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/39_gemm_permute/gemm_permute.cu b/3rdparty/cutlass/examples/39_gemm_permute/gemm_permute.cu -new file mode 100644 -index 0000000..ed3e399 ---- /dev/null -+++ b/3rdparty/cutlass/examples/39_gemm_permute/gemm_permute.cu -@@ -0,0 +1,1126 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief GEMM Permute Example. -+ -+ This example computes batched GEMM operations with output results permuted as reshaped tensors. -+ -+ We provide layout plugin as a flexible tool for users to add any customized output tensor permute operation, -+ or any other generalized global memory writeout address computation. To add a customized layout, add new class -+ in include/cutlass/layout/permute.h -+ -+ In this example, we used Tensor4DPermuteBMM0213 layout to perform Batched GEMM with permute([0, 2, 1, 3]) on BMM -+ whole output tensor, and used Tensor5DPermute20314 layout to perform Normal GEMM with permute([2, 0, 3, 1, 4]) on -+ output matrix. The address computations are performed in compute(col_init, row_init, stride_init, -+ BMM_batch_idx) with {col_permute, row_permute and stride_permute} as new addresses after permute op. -+ (check include/cutlass/layout/permute.h) -+ -+ Tips: -+ -+ 1) Make sure to set batch_stride_D to zero for BMM permute; Also the BMM GEMM should be in mode -+ cutlass::gemm::GemmUniversalMode::kBatched instead of kArray -+ -+ 2) When the last dimension is touched in permute op (for example permute([0, 2, 3, 1])), AlignmentC should -+ be set to 1. If the last dimension is untouched, one can set AlignmentC to be larger like 8 in our example. -+ As a result, permute op without touching the last dimension is recommended to obtain the best performance gain. -+ -+ Examples: -+ -+ # Runs a batched GEMM with 96 batches -+ $ ./examples/39_gemm_permute/39_gemm_permute --problem-count=96 -+ -+ # Runs a batched GEMM with 96 batches (with GEMM-K dimension equal to 1024) -+ $ ./examples/39_gemm_permute/39_gemm_permute --problem-count=96 --k=1024 --verbose=true -+ -+ # Execute batched GEMM and profile with NSight -+ $ nv-nsight-cu-cli ./examples/39_gemm_permute/39_gemm_permute --m=256 --n=192 --k=256 --verbose=true --iterations=1 --reference-check=false -+ -+*/ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include -+#include -+#include -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/device/gemm_universal.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/device_memory.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm_complex.h" -+#include "cutlass/util/reference/device/gemm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/device/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+ -+#include "cutlass/layout/permute.h" -+ -+/// Tensor4DPermuteBMM0213 ---> -+/// Permute layout function for 4-D permuted tensors for BMM with BMM output tensor (dimension as [B, M, N]) reshaped -+/// as [B/D1, D1, M, N]. Then perform permute([0, 2, 1, 3]) on the corresponding whole BMM output tensor. -+const int D1 = 12; -+ -+/// Tensor5DPermute20314 ---> -+/// Permute layout function for 5-D permuted tensors with output matrix (dimension as [M, N]) reshaped -+/// as [M/T1, T1, T2, T3, N/T2/T3]. Then perform permute([2, 0, 3, 1, 4]) on the corresponding output tensor. -+const int T1 = 16; -+const int T2 = 3; -+const int T3 = 8; -+ -+// Alignment C -+const int AlignmentC = 8; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Result structure -+struct Result { -+ -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cudaError_t error; -+ bool passed; -+ -+ // -+ // Methods -+ // -+ -+ Result( -+ double runtime_ms = 0, -+ double gflops = 0, -+ cutlass::Status status = cutlass::Status::kSuccess, -+ cudaError_t error = cudaSuccess -+ ): -+ runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ bool error; -+ bool reference_check; -+ -+ cutlass::gemm::GemmCoord problem_each; -+ -+ int batch_count; -+ int iterations; -+ int cuda_streams; -+ bool verbose; -+ float alpha; -+ float beta; -+ -+ // -+ // Methods -+ // -+ -+ Options(): -+ help(false), -+ error(false), -+ reference_check(true), -+ batch_count(-1), -+ iterations(20), -+ cuda_streams(0), -+ verbose(false), -+ alpha(1), -+ beta() -+ { } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ return; -+ } -+ -+ cmd.get_cmd_line_argument("alpha", alpha, 1.0f); -+ cmd.get_cmd_line_argument("beta", beta, 0.0f); -+ cmd.get_cmd_line_argument("iterations", iterations, 20); -+ cmd.get_cmd_line_argument("streams", cuda_streams, 0); -+ cmd.get_cmd_line_argument("verbose", verbose, false); -+ cmd.get_cmd_line_argument("reference-check", reference_check, true); -+ -+ int m, n, k; -+ -+ cmd.get_cmd_line_argument("m", m, 128); -+ cmd.get_cmd_line_argument("n", n, 192); -+ cmd.get_cmd_line_argument("k", k, 128); -+ cmd.get_cmd_line_argument("batch-count", batch_count, 768); -+ -+ cutlass::gemm::GemmCoord problem(m, n, k); -+ problem_each = problem; -+ -+ if (batch_count % D1 != 0){ -+ std::cerr << "\nProblem count error (problem-count = " << batch_count << "). " -+ << "problem-count needs to be divided with no remain by " << D1 << " (D1)." -+ << " (Required by the Batched GEMM permute Tensor4DPermuteBMM0213)\n\n"; -+ error = true; -+ } -+ -+ if (m % (AlignmentC * T1) != 0){ -+ std::cerr << "\nProblem m size error (m = " << m << "). " -+ << "m needs to be divided with no remain by " << (AlignmentC * T1) << " (AlignmentC * T1)." -+ << " (Required by the normal GEMM permute Tensor5DPermute20314)\n\n"; -+ error = true; -+ } -+ -+ if (n % (AlignmentC * (T2 * T3)) != 0){ -+ std::cerr << "\nProblem n size error (n = " << n << "). " -+ << "n needs to be divided with no remain by " << (AlignmentC * (T2 * T3)) << " (AlignmentC * T2 * T3)." -+ << " (Required by the normal GEMM permute Tensor5DPermute20314)\n\n"; -+ error = true; -+ } -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "39_gemm_permute\n\n" -+ << " 1) This example firstly profiles the performance of a batched GEMM kernel with BMM whole output" -+ << " (including output matrices for each batch) as permuted 4D Tensor." -+ << " The BMM tensor output in shape of [B, M, N] is reshaped as [B/D1, D1, M, N] and then permuted with" -+ << " permute([0, 2, 1, 3]) to be in shape of [B/D1, M, D1, N].\n\n" -+ << " 2) This example also profiles the performance of a normal GEMM kernel with output as permuted 5D Tensor." -+ << " The GEMM matrix output in shape of [M, N] is reshaped as [M/T1, T1, T2, T3, N/T2/T3] and then permuted" -+ << " with permute([2, 0, 3, 1, 4]) to be in shape of [T2, M/T1, T3, T1, N/T2/T3].\n\n" -+ << " Note: D1, T1, T2, T3 are compile-time constants defined in gemm_permute.cu\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --batch-count= Sets the number of batches in batched GEMM (batch number for BMM). (default: --batch-count=768)\n" -+ << " --m= Sets the M dimension for both batched GEMM and normal GEMM problems. (default: --m=128)\n" -+ << " --n= Sets the N dimension for both batched GEMM and normal GEMM problems. (default: --n=192)\n" -+ << " --k= Sets the K dimension for both batched GEMM and normal GEMM problems. (default: --k=128)\n" -+ << " --alpha= Epilogue scalar alpha (real part)\n" -+ << " --beta= Epilogue scalar beta (real part)\n\n" -+ << " --iterations= Number of profiling iterations to perform.\n" -+ << " --reference-check= If true, performs reference check.\n" -+ << " --verbose= If true, prints problem sizes and batching structure.\n"; -+ -+ out << "\n\nExamples:\n\n" -+ -+ << "# Runs a batched GEMM with 96 batches\n" -+ << "$ ./examples/39_gemm_permute/39_gemm_permute --problem-count=96\n\n" -+ -+ << "# Runs a batched GEMM with 96 batches (with GEMM-K dimension equal to 1024)\n" -+ << "$ ./examples/39_gemm_permute/39_gemm_permute --problem-count=96 --k=1024 --verbose=true\n\n" -+ -+ << "# Execute batched GEMM and profile with NSight\n" -+ << "$ nv-nsight-cu-cli ./examples/39_gemm_permute/39_gemm_permute --m=256 --n=192 --k=256 --verbose=true --iterations=1 --reference-check=false\n\n"; -+ -+ return out; -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of real-valued multiply-adds -+ int64_t fmas = int64_t(); -+ -+ fmas += problem_each.product() * batch_count; -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class Testbed { -+public: -+ -+ // -+ // Type definitions -+ // -+ -+ using ElementA = typename GemmBatched::ElementA; -+ using ElementB = typename GemmBatched::ElementB; -+ using ElementC = typename GemmBatched::ElementC; -+ using ElementAccumulator = typename GemmBatched::ElementAccumulator; -+ -+ using EpilogueOutputOp = typename GemmBatched::GemmKernel::Epilogue::OutputOp; -+ using ElementCompute = typename EpilogueOutputOp::ElementCompute; -+ -+ using LayoutA = typename GemmBatched::LayoutA; -+ using LayoutB = typename GemmBatched::LayoutB; -+ using LayoutC = typename GemmBatched::LayoutC; -+ -+ using MatrixCoord = typename LayoutC::TensorCoord; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ Options & options; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint32_t seed; -+ -+ cutlass::DeviceAllocation block_A; -+ cutlass::DeviceAllocation block_B; -+ cutlass::DeviceAllocation block_C; -+ cutlass::DeviceAllocation block_D; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ Testbed( -+ Options &options_, -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint32_t seed_ = 3090 -+ ): -+ options(options_), init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } -+ -+ /// Verbose BMM info -+ void print_BMM_info_() { -+ -+ // Print batched GEMM -+ std::cout << "Batched GEMM with permute([0, 2, 1, 3]) on BMM whole output tensor:\n"; -+ -+ auto problem = options.problem_each; -+ std::cout -+ << problem.m() << "-by-" << problem.n() << "-by-" << problem.k() -+ << ", batch count: " << options.batch_count << "\n"; -+ -+ std::cout << "output tensor shape: [" << options.batch_count << ", " << problem.m() << ", " -+ << problem.n() <<"]\n"; -+ std::cout << "reshaped as: [" << options.batch_count / D1 << ", " << D1 << ", " -+ << problem.m() << ", " << problem.n() <<"]\n"; -+ std::cout << "finally permuted as: [" << options.batch_count / D1 << ", " << problem.m() << ", " -+ << D1 << ", " << problem.n() <<"]\n"; -+ -+ std::cout << "----------------------------------------------------\n"; -+ -+ } -+ -+ /// Verbose normal GEMM info -+ void print_GEMM_info_() { -+ -+ // Print batched GEMM -+ std::cout << "Normal GEMM with permute([2, 0, 3, 1, 4]):\n"; -+ -+ auto problem = options.problem_each; -+ std::cout -+ << problem.m() << "-by-" << problem.n() << "-by-" << problem.k() << "\n"; -+ -+ std::cout << "output tensor shape: [" << problem.m() << ", " << problem.n() <<"]" << std::endl; -+ std::cout << "reshaped as: [" << problem.m() / T1 << ", " << T1 << ", " -+ << T2 << ", " << T3 << ", " << problem.n() / T2 / T3 <<"]" << std::endl; -+ std::cout << "finally permuted as: [" << T2 << ", " << problem.m() / T1 << ", " -+ << T3 << ", " << T1 << ", " << problem.n() / T2 / T3 <<"]" << std::endl; -+ -+ std::cout << "----------------------------------------------------\n"; -+ -+ } -+ -+private: -+ -+ /// Helper to initialize a tensor view -+ template -+ void initialize_tensor_( -+ Element *ptr, -+ size_t capacity, -+ cutlass::Distribution::Kind dist_kind, -+ uint32_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ Element scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ if (cutlass::sizeof_bits::value <= 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } -+ else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::device::BlockFillRandomUniform( -+ ptr, capacity, seed, scope_max, scope_min, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::device::BlockFillRandomGaussian( -+ ptr, capacity, seed, Element(), Element(0.5f)); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ // Fill with increasing elements -+ cutlass::reference::device::BlockFillSequential( -+ ptr, capacity, Element(1), Element()); -+ } -+ else { -+ -+ // Fill with all 1s -+ cutlass::reference::device::BlockFillSequential( -+ ptr, capacity, Element(), Element(1)); -+ } -+ } -+ -+ /// Initializes data structures -+ void initialize_(int batch_count) { -+ -+ // -+ // Choose random problem sizes -+ // -+ -+ // construct a few problems of random sizes -+ srand(seed); -+ -+ int64_t total_elements_A = options.problem_each.m() * options.problem_each.k() * batch_count; -+ int64_t total_elements_B = options.problem_each.n() * options.problem_each.k() * batch_count; -+ int64_t total_elements_C = options.problem_each.m() * options.problem_each.n() * batch_count; -+ int64_t total_elements_D = options.problem_each.m() * options.problem_each.n() * batch_count; -+ -+ // -+ // Assign space -+ // -+ -+ block_A.reset(total_elements_A); -+ block_B.reset(total_elements_B); -+ block_C.reset(total_elements_C); -+ block_D.reset(total_elements_D); -+ -+ // -+ // Initialize the problems of the workspace -+ // -+ -+ initialize_tensor_(block_A.get(), total_elements_A, init_A, seed * 2021); -+ initialize_tensor_(block_B.get(), total_elements_B, init_B, seed * 2022); -+ initialize_tensor_(block_C.get(), total_elements_C, init_C, seed * 2023); -+ -+ cutlass::reference::device::BlockFillSequential( -+ block_D.get(), total_elements_D, ElementC(), ElementC()); -+ } -+ -+ /// Verifies the BMM GEMM result -+ bool verify_BMM_() { -+ -+ bool passed = true; -+ -+ cutlass::gemm::GemmCoord problem = options.problem_each; -+ -+ LayoutA layout_A(LayoutA::packed({problem.m(), problem.k()}).stride(0)); -+ LayoutB layout_B(LayoutB::packed({problem.k(), problem.n()}).stride(0)); -+ LayoutC layout_C(LayoutC::packed({problem.m(), problem.n()}).stride(0)); -+ LayoutC layout_D(LayoutC::packed({problem.m(), problem.n()}).stride(0)); -+ -+ MatrixCoord extent_A{problem.m(), problem.k()}; -+ MatrixCoord extent_B{problem.k(), problem.n()}; -+ MatrixCoord extent_C{problem.m(), problem.n()}; -+ -+ cutlass::TensorView view_A(block_A.get(), layout_A, extent_A); -+ cutlass::TensorView view_B(block_B.get(), layout_B, extent_B); -+ cutlass::TensorView view_C(block_C.get(), layout_C, extent_C); -+ -+ cutlass::DeviceAllocation block_Ref(layout_D.capacity(extent_C) * options.batch_count); -+ cutlass::TensorView view_Ref_device(block_Ref.get(), layout_D, extent_C); -+ -+ // Reference GEMM -+ cutlass::reference::device::GemmComplex< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, ElementAccumulator -+ >( -+ problem, -+ options.alpha, -+ view_A, -+ GemmBatched::kTransformA, -+ view_B, -+ GemmBatched::kTransformB, -+ options.beta, -+ view_C, -+ view_Ref_device, -+ ElementAccumulator(0), -+ options.batch_count, -+ options.problem_each.m() * options.problem_each.k(), -+ options.problem_each.n() * options.problem_each.k(), -+ options.problem_each.m() * options.problem_each.n(), -+ options.problem_each.m() * options.problem_each.n() -+ ); -+ -+ // Copy to host memory -+ std::vector matrix_D(layout_D.capacity(extent_C) * options.batch_count); -+ std::vector matrix_Ref(layout_D.capacity(extent_C) * options.batch_count); -+ -+ cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get(), matrix_D.size()); -+ cutlass::device_memory::copy_to_host(matrix_Ref.data(), block_Ref.get(), matrix_D.size()); -+ -+ // Print out the results and reference in 4D Tensor -+ // [options.batch_count, options.problem_each.m() * options.problem_each.n()] -> [D0, D1, D2, D3]. -+ // After permute Op, -> [D0, D2, D1, D3]. -+ int D0 = options.batch_count / D1; -+ int D2 = options.problem_each.m(); -+ int D3 = options.problem_each.n(); -+ -+ cutlass::TensorView view_D_Tensor(matrix_D.data(), // if LayoutC = cutlass::layout::ColumnMajor, view_D_Tensor should be constructed differently -+ cutlass::layout::TensorNHWC().packed(cutlass::Tensor4DCoord({D0, D2, D1, D3})), cutlass::Tensor4DCoord({D0, D2, D1, D3})); -+ -+ cutlass::TensorView view_Ref_Tensor(matrix_Ref.data(), -+ cutlass::layout::TensorNHWC().packed(cutlass::Tensor4DCoord({D0, D1, D2, D3})), cutlass::Tensor4DCoord({D0, D1, D2, D3})); -+ -+ // Tensor Permute Op on reference tensor -+ cutlass::HostTensor view_Ref_Permute_Tensor(cutlass::Tensor4DCoord({D0, D2, D1, D3})); -+ for (int n = 0; n < D0; ++n) { -+ for (int h = 0; h < D1; ++h) { -+ for (int w = 0; w < D2; ++w) { -+ for (int c = 0; c < D3; ++c) { -+ view_Ref_Permute_Tensor.at({n, w, h, c}) = view_Ref_Tensor.at({n, h, w, c}); -+ } -+ } -+ } -+ } -+ -+ // Reference check -+ passed = cutlass::reference::host::TensorEquals(view_Ref_Permute_Tensor.host_view(), view_D_Tensor); -+ -+ if (!passed) { -+ std::cerr << "\n***\nError - problem failed the QA check\n***\n" << std::endl; -+ return passed; -+ } -+ -+ std::cout << "Passed verification" << std::endl; -+ return passed; -+ } -+ -+ bool verify_GEMM_normal_() { -+ -+ bool passed = true; -+ -+ cutlass::gemm::GemmCoord problem = options.problem_each; -+ -+ LayoutA layout_A(LayoutA::packed({problem.m(), problem.k()}).stride(0)); -+ LayoutB layout_B(LayoutB::packed({problem.k(), problem.n()}).stride(0)); -+ LayoutC layout_C(LayoutC::packed({problem.m(), problem.n()}).stride(0)); -+ LayoutC layout_D(LayoutC::packed({problem.m(), problem.n()}).stride(0)); -+ -+ MatrixCoord extent_A{problem.m(), problem.k()}; -+ MatrixCoord extent_B{problem.k(), problem.n()}; -+ MatrixCoord extent_C{problem.m(), problem.n()}; -+ -+ cutlass::TensorView view_A(block_A.get(), layout_A, extent_A); -+ cutlass::TensorView view_B(block_B.get(), layout_B, extent_B); -+ cutlass::TensorView view_C(block_C.get(), layout_C, extent_C); -+ -+ cutlass::DeviceAllocation block_Ref(layout_D.capacity(extent_C)); -+ cutlass::TensorView view_Ref_device(block_Ref.get(), layout_D, extent_C); -+ -+ // Reference GEMM -+ cutlass::reference::device::GemmComplex< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, ElementAccumulator -+ >( -+ problem, -+ options.alpha, -+ view_A, -+ GemmBatched::kTransformA, -+ view_B, -+ GemmBatched::kTransformB, -+ options.beta, -+ view_C, -+ view_Ref_device, -+ ElementAccumulator(0) -+ ); -+ -+ // Copy to host memory -+ std::vector matrix_D(layout_D.capacity(extent_C)); -+ std::vector matrix_Ref(layout_D.capacity(extent_C)); -+ -+ cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get(), matrix_D.size()); -+ cutlass::device_memory::copy_to_host(matrix_Ref.data(), block_Ref.get(), matrix_D.size()); -+ -+ // Print out the results and reference in 5D Tensor -+ // [options.problem_each.m(), options.problem_each.n()] -> [T0, T1, T2, T3, T4]. -+ // options.problem_each.m() == T0 * T1 -+ // options.problem_each.n() == T2 * T3 * T4 -+ // After permute Op, -> [T2, T0, T3, T1, T4]. -+ int T0 = options.problem_each.m() / T1; -+ int T4 = options.problem_each.n() / T2 / T3; -+ -+ cutlass::TensorView view_D_Tensor(matrix_D.data(), // if LayoutC = cutlass::layout::ColumnMajor, view_D_Tensor should be constructed differently -+ cutlass::layout::TensorNDHWC().packed(cutlass::Tensor5DCoord({T2, T0, T3, T1, T4})), cutlass::Tensor5DCoord({T2, T0, T3, T1, T4})); -+ cutlass::TensorView view_Ref_Tensor(matrix_Ref.data(), -+ cutlass::layout::TensorNDHWC().packed(cutlass::Tensor5DCoord({T0, T1, T2, T3, T4})), cutlass::Tensor5DCoord({T0, T1, T2, T3, T4})); -+ -+ // Tensor Permute Op on reference tensor -+ cutlass::HostTensor view_Ref_Permute_Tensor(cutlass::Tensor5DCoord({T2, T0, T3, T1, T4})); -+ for (int n = 0; n < T0; ++n) { -+ for (int d = 0; d < T1; ++d) { -+ for (int h = 0; h < T2; ++h) { -+ for (int w = 0; w < T3; ++w) { -+ for (int c = 0; c < T4; ++c) { -+ view_Ref_Permute_Tensor.at({h, n, w, d, c}) = view_Ref_Tensor.at({n, d, h, w, c}); // permute([2,0,3,1,4]) -+ } -+ } -+ } -+ } -+ } -+ -+ // Reference check -+ passed = cutlass::reference::host::TensorEquals(view_Ref_Permute_Tensor.host_view(), view_D_Tensor); -+ -+ if (!passed) { -+ std::cerr << "\n***\nError - problem failed the QA check\n***\n" << std::endl; -+ return passed; -+ } -+ -+ std::cout << "Passed verification" << std::endl; -+ return passed; -+} -+ -+public: -+ /// Executes a conventional batched GEMM kernel. -+ Result profile_batched_kBatched() { -+ -+ std::cout << "\n====================================================" << std::endl; -+ std::cout << "Batched GEMM (CUTLASS):\n" -+ << "====================================================" << std::endl; -+ -+ if (options.verbose) { -+ print_BMM_info_(); -+ } -+ -+ Result result; -+ -+ result.passed = false; -+ -+ // Initialize the problem -+ initialize_(options.batch_count); -+ -+ // Configure the GEMM arguments -+ typename EpilogueOutputOp::Params epilogue_op(options.alpha, options.beta); -+ -+ // Please make sure all problem_sizes are the same for kBatched mode -+ auto problem = options.problem_each; -+ -+ // For regular BMM -+ int64_t batch_stride_C = problem.m() * problem.n(); -+ // For BMM permute output ---> make sure to set batch_stride_D to zero for BMM permute op -+ int64_t batch_stride_D = 0; -+ -+ // Configure GEMM arguments -+ typename GemmBatched::Arguments arguments{ -+ cutlass::gemm::GemmUniversalMode::kBatched, -+ options.problem_each, -+ options.batch_count, -+ epilogue_op, -+ (void*)block_A.get(), -+ (void*)block_B.get(), -+ (void*)block_C.get(), -+ (void*)block_D.get(), -+ problem.m() * problem.k(), -+ problem.n() * problem.k(), -+ batch_stride_C, -+ batch_stride_D, -+ problem.k(), -+ problem.n(), -+ problem.n(), -+ problem.n() -+ }; -+ -+ // Initialize the GEMM object -+ GemmBatched gemm; -+ -+ result.status = gemm.initialize(arguments, nullptr); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to initialize CUTLASS Batched GEMM kernel." << std::endl; -+ return result; -+ } -+ -+ // Run the batched GEMM object -+ result.status = gemm.run(); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to run CUTLASS Batched GEMM kernel." << std::endl; -+ return result; -+ } -+ -+ // Wait for completion -+ result.error = cudaDeviceSynchronize(); -+ -+ if (result.error != cudaSuccess) { -+ std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); -+ return result; -+ } -+ -+ // -+ // Verify correctness -+ // -+ result.passed = true; -+ -+ if (options.reference_check) { -+ result.passed = verify_BMM_(); -+ } -+ -+ // -+ // Warm-up run of the batched GEMM object -+ // -+ result.status = gemm.run(); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to run CUTLASS Batched GEMM kernel." << std::endl; -+ return result; -+ } -+ -+ // -+ // Construct events -+ // -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ } -+ -+ // Record an event at the start of a series of GEMM operations -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // -+ // Run profiling loop -+ // -+ -+ for (int iter = 0; iter < options.iterations; ++iter) { -+ gemm(); -+ } -+ -+ // -+ // Stop profiling loop -+ // -+ -+ // Record an event when the GEMM operations have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Compute average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // -+ // Cleanup -+ // -+ -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ std::cout << " " << 1 << " batched GEMMs launched\n"; -+ -+ std::cout << std::endl; -+ std::cout << " " << "Batched Runtime: " << result.runtime_ms << " ms\n"; -+ std::cout << " " << "Batched GFLOPs: " << result.gflops << "\n"; -+ -+ return result; -+ } -+ -+ Result profile_GEMM_permute() { -+ -+ std::cout << "\n====================================================" << std::endl; -+ std::cout << "Normal GEMM (CUTLASS):\n" -+ << "====================================================" << std::endl; -+ -+ if (options.verbose) { -+ print_GEMM_info_(); -+ } -+ -+ Result result; -+ -+ result.passed = false; -+ -+ // Initialize the problem -+ initialize_(1); -+ -+ // Configure the GEMM arguments -+ typename EpilogueOutputOp::Params epilogue_op(options.alpha, options.beta); -+ -+ // Please make sure all problem_sizes are the same for kBatched mode -+ auto problem = options.problem_each; -+ -+ // Configure GEMM arguments -+ typename GemmPermute::Arguments arguments{ -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ options.problem_each, -+ 1, -+ epilogue_op, -+ (void*)block_A.get(), -+ (void*)block_B.get(), -+ (void*)block_C.get(), -+ (void*)block_D.get(), -+ 0, -+ 0, -+ 0, -+ 0, -+ problem.k(), -+ problem.n(), -+ problem.n(), -+ problem.n() -+ }; -+ -+ // Initialize the GEMM object -+ GemmPermute gemm_normal; -+ -+ result.status = gemm_normal.initialize(arguments, nullptr); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to initialize CUTLASS Batched GEMM kernel." << std::endl; -+ return result; -+ } -+ -+ // Run the normal GEMM object -+ result.status = gemm_normal.run(); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to run CUTLASS Batched GEMM kernel." << std::endl; -+ return result; -+ } -+ -+ // Wait for completion -+ result.error = cudaDeviceSynchronize(); -+ -+ if (result.error != cudaSuccess) { -+ std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); -+ return result; -+ } -+ -+ // -+ // Verify correctness -+ // -+ result.passed = true; -+ -+ if (options.reference_check) { -+ result.passed = verify_GEMM_normal_(); -+ } -+ -+ // -+ // Warm-up run of the normal GEMM object -+ // -+ result.status = gemm_normal.run(); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to run CUTLASS Batched GEMM kernel." << std::endl; -+ return result; -+ } -+ -+ // -+ // Construct events -+ // -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ } -+ -+ // Record an event at the start of a series of GEMM operations -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // -+ // Run profiling loop -+ // -+ -+ for (int iter = 0; iter < options.iterations; ++iter) { -+ gemm_normal(); -+ } -+ -+ // -+ // Stop profiling loop -+ // -+ -+ // Record an event when the GEMM operations have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Compute average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // -+ // Cleanup -+ // -+ -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ std::cout << std::endl; -+ std::cout << " " << "Normal Runtime: " << result.runtime_ms << " ms" << std::endl; -+ std::cout << " " << "Normal GFLOPs: " << result.gflops << "\n"; -+ -+ return result; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ -+ // -+ // This example uses mma.sync to directly access Tensor Cores to achieve peak performance. -+ // -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return -1; -+ } -+ -+ if (__CUDACC_VER_MAJOR__ < 11 || props.major < 8) { -+ -+ // -+ // This example requires an NVIDIA Ampere-architecture GPU. -+ // -+ -+ std::cout -+ << "CUTLASS's Grouped GEMM example requires a GPU of NVIDIA's Ampere Architecture or " -+ << "later (compute capability 80 or greater).\n"; -+ -+ return 0; -+ } -+ -+ // -+ // Parse options -+ // -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ if (options.error) { -+ std::cerr << "Aborting execution." << std::endl; -+ return -1; -+ } -+ -+ // -+ // Define the GEMM types -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ // -+ // Define a conventional batched GEMM type -+ // -+ -+ // Gemm operator cutlass_tensorop_f16_s16816gemm_f16_128x128_32x4_nt_align8 -+ using GemmBatched = cutlass::gemm::device::GemmUniversal< -+ cutlass::half_t, LayoutA, -+ cutlass::half_t, LayoutB, -+ ElementOutput, LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ AlignmentC, //128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 4, -+ 8, /*alignmentA*/ -+ 8, /*alignmengB*/ -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ false, /*GatherA*/ -+ false, /*GatherB*/ -+ false, /*ScatterD*/ -+ cutlass::layout::Tensor4DPermuteBMM0213 /*PermuteDLayout*/ -+ >; -+ -+ // Gemm operator cutlass_tensorop_f16_s16816gemm_f16_128x128_32x4_nt_align8 -+ using GemmPermute = cutlass::gemm::device::GemmUniversal< -+ cutlass::half_t, LayoutA, -+ cutlass::half_t, LayoutB, -+ ElementOutput, LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ AlignmentC, //128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 4, -+ 8, /*alignmentA*/ -+ 8, /*alignmengB*/ -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ false, /*GatherA*/ -+ false, /*GatherB*/ -+ false, /*ScatterD*/ -+ cutlass::layout::Tensor5DPermute20314 /*PermuteDLayout*/ -+ >; -+ -+ // -+ // Profile it -+ // -+ -+ Testbed testbed(options); -+ -+ Result result; -+ result = testbed.profile_batched_kBatched(); -+ if (!result.passed) { -+ std::cout << "Profiling batched GEMM has failed.\n"; -+ std::cout << "\nFailed\n"; -+ } else { -+ std::cout << "\nPassed CUTLASS batched GEMM\n"; -+ } -+ -+ result = testbed.profile_GEMM_permute(); -+ if (!result.passed) { -+ std::cout << "Profiling normal GEMM has failed.\n"; -+ std::cout << "\nFailed\n"; -+ } else { -+ std::cout << "\nPassed CUTLASS normal GEMM\n"; -+ } -+ -+ std::cout << "\n====================================================" << std::endl; -+ std::cout << "Finished\n"; -+ std::cout << "====================================================" << std::endl; -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/attention_scaling_coefs_updater.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/attention_scaling_coefs_updater.h -new file mode 100644 -index 0000000..4de04ef ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/attention_scaling_coefs_updater.h -@@ -0,0 +1,513 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holdvr nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include "cutlass/functional.h" -+#include "cutlass/gemm/warp/mma_simt_tile_iterator.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" -+#include "cutlass/matrix_shape.h" -+#include "gemm_kernel_utils.h" -+ -+namespace { -+ -+static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) { -+ // source: https://stackoverflow.com/a/51549250 -+ return (value >= 0) -+ ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) -+ : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value))); -+} -+} // namespace -+ -+/* Iterates on the accumulator and corresponding position on result matrix -+ -+(1) Update `mi[r]` to the max value of the row `r` -+(2) In a second iteration do the following: -+ (a) accum <- exp(accum - mi) -+ (b) m_prime <- exp(m_prime - mi) -+ (c) s_prime <- s_prime * m_prime + sum(accum) -+ -+All of this is done on registers, before we store all of this -+on shared memory for the next matmul with Value. -+ -+We have multiple implementations, because each configuration has a different way -+of iterating in the accumulators. -+*/ -+ -+template -+struct RegisterOps { -+ template < -+ int kQueriesPerBlock, -+ bool kFullColumns, -+ bool kIsFirst, -+ bool kKeepOutputInRF> -+ CUTLASS_DEVICE static void update( -+ typename T::Fragment& frag_o, // output so far -+ typename T::Fragment& frag, -+ cutlass::Array& mi, -+ cutlass::Array& m_prime, -+ cutlass::Array& s_prime, -+ int8_t lane_id, -+ int8_t thread_id, -+ int8_t warp_id, -+ int16_t max_col, -+ typename T::TensorCoord const& tile_offset, -+ float scaling) { -+ // Convert to `accum_t` (rather than double) -+ constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E -+ if (!kIsFirst) { -+ if (thread_id < kQueriesPerBlock) { -+ m_prime[thread_id] = mi[thread_id]; -+ } -+ __syncthreads(); -+ } -+ -+ auto lane_offset = BASE::get_lane_offset(lane_id, warp_id, tile_offset); -+ -+ // First update `mi` to the max per-row -+ { -+ accum_t max; -+ BASE::iterateRows( -+ lane_offset, -+ [&](int accum_m) { -+ max = -cutlass::platform::numeric_limits::infinity(); -+ }, -+ [&](int accum_m, int accum_n, int idx) { -+ if (kFullColumns || accum_n < max_col) { -+ max = cutlass::fast_max(max, frag[idx]); -+ } -+ }, -+ [&](int accum_m) { -+ // Having 4x atomicMax seems faster than reduce within warp -+ // first... -+ atomicMaxFloat(&mi[accum_m], max * scaling); -+ }); -+ } -+ frag = cutlass::multiplies()(scaling * kLog2e, frag); -+ -+ // Make sure we all share the update values for `mi` -+ __syncthreads(); -+ -+ if (thread_id < kQueriesPerBlock) { -+ auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id])); -+ m_prime[thread_id] = m_prime_exp; -+ s_prime[thread_id] *= m_prime_exp; -+ } -+ __syncthreads(); // Update output fragments -+ if (kKeepOutputInRF && !kIsFirst) { -+ accum_t mp; -+ BASE::iterateRows( -+ lane_offset, -+ [&](int accum_m) { mp = m_prime[accum_m]; }, -+ [&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; }, -+ [&](int accum_m) {}); -+ __syncthreads(); -+ } -+ // Update accum_m, accum_n, ... -+ { -+ accum_t mi_row, total_row; -+ BASE::iterateRows( -+ lane_offset, -+ [&](int accum_m) { mi_row = kLog2e * mi[accum_m]; }, -+ [&](int accum_m, int accum_n, int idx) { -+ frag[idx] = (kFullColumns || accum_n < max_col) -+ ? exp2f(frag[idx] - mi_row) -+ : accum_t(0.0); -+ }, -+ [&](int accum_m) {}); -+ BASE::iterateRows( -+ lane_offset, -+ [&](int accum_m) { total_row = 0.0; }, -+ [&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; }, -+ [&](int accum_m) { -+ if (BASE::reduceSameRow( -+ lane_id, total_row, [](accum_t a, accum_t b) { -+ return a + b; -+ })) { -+ atomicAdd(&s_prime[accum_m], total_row); -+ } -+ }); -+ } -+ } -+}; -+ -+template -+struct AttentionScalingCoefsUpdaterSm80 -+ : RegisterOps< -+ AttentionScalingCoefsUpdaterSm80, -+ T, -+ accum_t, -+ kWarpSize> { -+ static_assert( -+ cutlass::platform:: -+ is_same::value, -+ "only RowMajor is supported"); -+ -+ using Policy = typename T::Policy; -+ using InstructionShape = typename T::InstructionShape; -+ using OpDelta = typename T::OpDelta; -+ using Shape = typename T::Shape; -+ static int const kElementsPerAccess = InstructionShape::kN / 4; -+ static int const kRowsPerTile = 8; -+ static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile; -+ -+ static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset( -+ int8_t lane_id, -+ int8_t warp_id, -+ typename T::TensorCoord const& tile_offset) { -+ int quad = (lane_id >> 2); -+ int lane_in_quad = (lane_id & 3); -+ return cutlass::MatrixCoord( -+ quad + tile_offset.row() * Shape::kRow, -+ lane_in_quad * kElementsPerAccess + -+ tile_offset.column() * Shape::kColumn); -+ } -+ -+ template -+ CUTLASS_DEVICE static void iterateRows( -+ cutlass::MatrixCoord& lane_offset, -+ FA beginRow, -+ FB op, -+ FC endRow) { -+ // See cutlass/gemm/warp/mma_tensor_op_tile_iterator.h -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < kAccumulatorRows; ++row) { -+ int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + -+ row * kRowsPerTile + lane_offset.row(); -+ beginRow(accum_m); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { -+ int mma_accum_start = kAccumulatorRows * kElementsPerAccess * -+ (mma_n * Policy::MmaIterations::kRow + mma_m); -+ CUTLASS_PRAGMA_UNROLL -+ for (int col = 0; col < kElementsPerAccess; ++col) { -+ int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + -+ col + lane_offset.column(); -+ int idx = mma_accum_start + row * kElementsPerAccess + col; -+ op(accum_m, accum_n, idx); -+ } -+ } -+ -+ endRow(accum_m); -+ } -+ } -+ } -+ -+ template -+ CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) { -+ // In each warp, 4 threads will work on the same row -+ // - the ones with the same `quad` -+ auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1); -+ myValue = fn(myValue, otherV); -+ otherV = __shfl_xor_sync(0xffffffff, myValue, 2); -+ myValue = fn(myValue, otherV); -+ int lane_in_quad = (lane_id & 3); -+ return lane_in_quad == 0; -+ } -+}; -+ -+template -+struct AttentionScalingCoefsUpdaterVolta -+ : RegisterOps< -+ AttentionScalingCoefsUpdaterVolta, -+ T, -+ accum_t, -+ kWarpSize> { -+ static_assert( -+ cutlass::platform:: -+ is_same::value, -+ "only RowMajor is supported"); -+ -+ using Policy = typename T::Policy; -+ using InstructionShape = typename T::InstructionShape; -+ using OpDelta = typename T::OpDelta; -+ using Shape = typename T::Shape; -+ using Element = accum_t; -+ -+ static int const kElementsPerPartial = 4; -+ using EleShapePerPatial = typename cutlass::platform::conditional< -+ cutlass::platform::is_same::value, -+ cutlass::MatrixShape<2, 2>, -+ cutlass::MatrixShape<1, 4>>::type; -+ static int const kElementsPerMma = 8; -+ static int const kAccumulatorPatials = 2; -+ using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>; -+ -+ static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset( -+ int8_t lane_id, -+ int8_t warp_id, -+ typename T::TensorCoord const& tile_offset) { -+ int quad = (lane_id >> 2); -+ int lane_in_quad = (lane_id & 3); -+ int accum_m, accum_n; -+ -+ if (cutlass::platform::is_same::value) { -+ // (quad[2],quad[0])+lane_in_quad[0] -+ accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1); -+ // (quad[1])+lane_in_quad[1] -+ accum_n = -+ ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials + -+ (lane_in_quad & 2); -+ } else { -+ accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + -+ lane_in_quad; // (quad[2],quad[0]) -+ accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials; -+ } -+ return cutlass::MatrixCoord( -+ accum_m + tile_offset.row() * Shape::kRow, -+ accum_n + tile_offset.column() * Shape::kColumn); -+ } -+ -+ template -+ CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) { -+ static_assert( -+ cutlass::platform::is_same::value, -+ "update to support non-float accum"); -+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16 -+ // T0 & T2 share same line within a quad -+ auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 1); -+ myValue = fn(myValue, otherV); -+ // quad 0 and quad 2 are on the same lines -+ otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 3); -+ myValue = fn(myValue, otherV); -+ return (lane_id & ((1 << 1) | (1 << 3))) == 0; -+ } -+ -+ template -+ CUTLASS_DEVICE static void iterateRows( -+ cutlass::MatrixCoord& lane_offset, -+ FA beginRow, -+ FB op, -+ FC endRow) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < EleShapePerPatial::kRow; ++m) { -+ int accum_m = tile_m * Policy::InterleavedTile::kRow + -+ mma_m * QuadShapePerPatialMma::kRow + m * 2 + lane_offset.row(); -+ beginRow(accum_m); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; -+ ++tile_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; -+ ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int p = 0; p < kAccumulatorPatials; ++p) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < EleShapePerPatial::kColumn; ++n) { -+ int mma_accum_start = -+ (((tile_n * Policy::TileIterations::kRow + tile_m) * -+ Policy::MmaIterations::kColumn + -+ mma_n) * -+ Policy::MmaIterations::kRow + -+ mma_m) * -+ kElementsPerMma; -+ int accum_n = tile_n * Policy::InterleavedTile::kColumn + -+ mma_n * QuadShapePerPatialMma::kColumn + -+ p * Policy::InterleavedTile::kColumn / 2 + n + -+ lane_offset.column(); -+ int idx = mma_accum_start + p * kElementsPerPartial + -+ m * EleShapePerPatial::kColumn + n; -+ op(accum_m, accum_n, idx); -+ } -+ } -+ } -+ } -+ endRow(accum_m); -+ } -+ } -+ } -+ } -+}; -+ -+template -+struct AttentionScalingCoefsUpdaterSimt -+ : RegisterOps< -+ AttentionScalingCoefsUpdaterSimt, -+ T, -+ accum_t, -+ kWarpSize> { -+ using Policy = typename T::Policy; -+ using Iterations = typename T::Iterations; -+ using Element = typename T::Element; -+ using Delta = typename T::Delta; -+ using Shape = typename T::Shape; -+ static_assert( -+ cutlass::platform:: -+ is_same::value, -+ "only RowMajor is supported"); -+ -+ template -+ CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int bit = 1; bit < Policy::WarpShape::kColumn; bit *= 2) { -+ auto otherV = __shfl_xor_sync(0xffffffff, myValue, bit); -+ myValue = fn(myValue, otherV); -+ } -+ return (lane_id & (Policy::WarpShape::kColumn - 1)) == 0; -+ } -+ -+ template -+ CUTLASS_DEVICE static void iterateRows( -+ cutlass::MatrixCoord& lane_offset, -+ FA beginRow, -+ FB op, -+ FC endRow) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) { -+ int accum_m = mma_m * Delta::kRow + m + lane_offset.row(); -+ beginRow(accum_m); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) { -+ int accum_n = -+ mma_n * Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN + -+ lane_offset.column(); -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) { -+ int idx = n + -+ Policy::LaneMmaShape::kN * -+ (mma_n + -+ Iterations::kColumn * -+ (m + mma_m * Policy::LaneMmaShape::kM)); -+ op(accum_m, accum_n + n, idx); -+ } -+ } -+ endRow(accum_m); -+ } -+ } -+ } -+ -+ static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset( -+ int8_t lane_id, -+ int8_t warp_id, -+ typename T::TensorCoord const& tile_offset) { -+ static_assert( -+ cutlass::platform::is_same< -+ typename Policy::LaneLayout, -+ cutlass::layout::RowMajorInterleaved<1>>::value, -+ ""); -+ typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); -+ -+ cutlass::MatrixCoord lane_offset = lane_layout.inverse(lane_id) * -+ cutlass::MatrixCoord(Policy::LaneMmaShape::kM, -+ Policy::LaneMmaShape::kN); -+ return lane_offset + -+ tile_offset * cutlass::MatrixCoord(Shape::kRow, Shape::kColumn); -+ } -+}; -+ -+template -+struct DefaultAttentionScalingCoefsUpdater; -+ -+// Simt -+template -+struct DefaultAttentionScalingCoefsUpdater< -+ cutlass::gemm::warp::MmaSimtTileIterator< -+ S, -+ cutlass::gemm::Operand::kC, -+ accum_t, -+ cutlass::layout::RowMajor, -+ P, -+ 1, -+ 1>, -+ accum_t, -+ kWarpSize> { -+ using Iterator = typename cutlass::gemm::warp::MmaSimtTileIterator< -+ S, -+ cutlass::gemm::Operand::kC, -+ accum_t, -+ cutlass::layout::RowMajor, -+ P, -+ 1, -+ 1>; -+ using Updater = -+ AttentionScalingCoefsUpdaterSimt; -+}; -+ -+// TensorOp - Volta -+template -+struct DefaultAttentionScalingCoefsUpdater< -+ cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< -+ S1, -+ accum_t, -+ cutlass::layout::RowMajor, -+ S2, -+ cutlass::MatrixShape<1, 1>>, -+ accum_t, -+ kWarpSize> { -+ using Iterator = -+ typename cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< -+ S1, -+ accum_t, -+ cutlass::layout::RowMajor, -+ S2, -+ cutlass::MatrixShape<1, 1>>; -+ using Updater = -+ AttentionScalingCoefsUpdaterVolta; -+}; -+ -+// TensorOp - Sm75+ -+template < -+ typename S1, -+ typename S2, -+ typename S3, -+ typename accum_t, -+ int kWarpSize> -+struct DefaultAttentionScalingCoefsUpdater< -+ cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< -+ S1, -+ accum_t, -+ cutlass::layout::RowMajor, -+ S2, -+ S3>, -+ accum_t, -+ kWarpSize> { -+ using Iterator = -+ typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< -+ S1, -+ accum_t, -+ cutlass::layout::RowMajor, -+ S2, -+ S3>; -+ using Updater = -+ AttentionScalingCoefsUpdaterSm80; -+}; -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/debug_utils.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/debug_utils.h -new file mode 100644 -index 0000000..73a258e ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/debug_utils.h -@@ -0,0 +1,159 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holdvr nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+#include -+#include -+#include -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Debugging functions -+//////////////////////////////////////////////////////////////////////////////// -+// Nans & inf detection -+#define NANCHECK(frag) \ -+ { \ -+ for (int _i = 0; _i < frag.size(); ++_i) { \ -+ assert(std::isfinite(float(frag[_i]))); \ -+ assert(!std::isnan(float(frag[_i]))); \ -+ } \ -+ } -+ -+// Print on the first thread of the first block -+#if 0 -+#define PRINT_WARP_ID 0 -+#define PRINT_LANE_ID 0 -+#define PRINT_T0_L0(msg, ...) \ -+ if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && \ -+ threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \ -+ threadIdx.z == 0) { \ -+ printf(msg "\n", __VA_ARGS__); \ -+ } -+struct __string_view { -+ char const* data; -+ std::size_t size; -+}; -+template -+constexpr __string_view __get_type_name() { -+ char const* p = __PRETTY_FUNCTION__; -+ while (*p++ != '=') -+ ; -+ for (; *p == ' '; ++p) -+ ; -+ char const* p2 = p; -+ int count = 1; -+ for (;; ++p2) { -+ switch (*p2) { -+ case '[': -+ ++count; -+ break; -+ case ']': -+ --count; -+ if (!count) -+ return {p, std::size_t(p2 - p)}; -+ } -+ } -+ return {}; -+} -+#else -+#define PRINT_T0_L0 -+#endif -+ -+// Print a given array -+#define PRINT_ACCUM8_T0_L0_START(name, accum, start) \ -+ PRINT_T0_L0( \ -+ "%s[%d:%d] - {%f, %f, %f, %f, %f, %f, %f, %f}", \ -+ name, \ -+ int(start), \ -+ int(start + 8), \ -+ float(accum[start + 0]), \ -+ float(accum[start + 1]), \ -+ float(accum[start + 2]), \ -+ float(accum[start + 3]), \ -+ float(accum[start + 4]), \ -+ float(accum[start + 5]), \ -+ float(accum[start + 6]), \ -+ float(accum[start + 7])); -+#define PRINT_ACCUM8_T0_L0(name, accum) PRINT_ACCUM8_T0_L0_START(name, accum, 0) -+#define PRINT_FRAG_T0_L0(name, frag) \ -+ { \ -+ auto typeStr = __get_type_name(); \ -+ PRINT_T0_L0("printing %s (%s)", name, typeStr.data); \ -+ for (int _start = 0; _start < frag.size(); _start += 8) { \ -+ PRINT_ACCUM8_T0_L0_START(" ", frag, _start); \ -+ } \ -+ /*__syncthreads(); \ -+ NANCHECK(frag); */ \ -+ } -+#define PRINT_ARRAY_T0_L0_INCR(name, array, length, incr) \ -+ { \ -+ PRINT_T0_L0("printing %s (len=%d)", name, int(length)); \ -+ for (int _start = 0; _start < length; _start += incr) { \ -+ PRINT_ACCUM8_T0_L0_START(" ", array, _start); \ -+ } \ -+ } -+#define PRINT_ARRAY_T0_L0(name, array, length) \ -+ PRINT_ARRAY_T0_L0_INCR(name, array, length, 8) -+ -+// Print a 4x4 matrix -+#define PRINT_TENSOR4x4_T0_L0_START(name, ref, start_x, start_y) \ -+ PRINT_T0_L0( \ -+ "%s[%d:%d, %d:%d]:\n %f, %f, %f, %f\n %f, %f, %f, %f\n %f, %f, %f, %f\n %f, %f, %f, %f", \ -+ name, \ -+ int(start_x), \ -+ int(start_x + 4), \ -+ int(start_y), \ -+ int(start_y + 4), \ -+ float(ref.at({start_x + 0, start_y + 0})), \ -+ float(ref.at({start_x + 0, start_y + 1})), \ -+ float(ref.at({start_x + 0, start_y + 2})), \ -+ float(ref.at({start_x + 0, start_y + 3})), \ -+ float(ref.at({start_x + 1, start_y + 0})), \ -+ float(ref.at({start_x + 1, start_y + 1})), \ -+ float(ref.at({start_x + 1, start_y + 2})), \ -+ float(ref.at({start_x + 1, start_y + 3})), \ -+ float(ref.at({start_x + 2, start_y + 0})), \ -+ float(ref.at({start_x + 2, start_y + 1})), \ -+ float(ref.at({start_x + 2, start_y + 2})), \ -+ float(ref.at({start_x + 2, start_y + 3})), \ -+ float(ref.at({start_x + 3, start_y + 0})), \ -+ float(ref.at({start_x + 3, start_y + 1})), \ -+ float(ref.at({start_x + 3, start_y + 2})), \ -+ float(ref.at({start_x + 3, start_y + 3}))); -+#define PRINT_TENSOR4x4_T0_L0(name, ref) \ -+ PRINT_TENSOR4x4_T0_L0_START(name, ref, 0, 0) -+ -+#define PRINT_PROBLEM_SIZE(name, ps) \ -+ PRINT_T0_L0( \ -+ "%s.problem_size: {.m=%d, .n=%d, .k=%d}", \ -+ name, \ -+ int(ps.m()), \ -+ int(ps.n()), \ -+ int(ps.k())) -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/default_fmha_grouped.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/default_fmha_grouped.h -new file mode 100644 -index 0000000..5a1ed5c ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/default_fmha_grouped.h -@@ -0,0 +1,284 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+ -+ Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are -+ accommodated by exchanging A and B operands and assuming transposed layouts. Partial -+ specializations here choose 'device::GemmTransposed' to implement this functionality. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/complex.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+ -+#include "fmha_grouped.h" -+#include "gemm_kernel_utils.h" -+#include "find_default_mma.h" -+#include "attention_scaling_coefs_updater.h" -+#include "mma_from_smem.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ // The datatype of Q/K/V -+ typename scalar_t_, -+ // Architecture we are targeting (eg `cutlass::arch::Sm80`) -+ typename ArchTag_, -+ // If Q/K/V are correctly aligned in memory and we can run a fast kernel -+ bool isAligned_, -+ int kQueriesPerBlock, -+ int kKeysPerBlock, -+ bool kSingleValueIteration, -+ GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly -+ > -+struct DefaultFMHAGrouped { -+ using scalar_t = scalar_t_; -+ using accum_t = float; -+ using output_t = scalar_t; -+ -+ // Accumulator between 2 iterations -+ // Using `accum_t` improves perf on f16 at the cost of -+ // numerical errors -+ using output_accum_t = accum_t; -+ -+ using ArchTag = ArchTag_; -+ static bool const kIsAligned = isAligned_; -+ static int const kWarpSize = 32; -+ static int const kNumWarpsPerBlock = kQueriesPerBlock * kKeysPerBlock / (kWarpSize * kWarpSize); -+ -+ struct MM0 { -+ /* -+ In this first matmul, we compute a block of `Q @ K.T`. -+ While the calculation result is still hot in registers, we update -+ `mi`, `m_prime`, `s_prime` in shared-memory, and then store this value -+ into a shared-memory ("AccumulatorSharedStorage") that is used later as -+ operand A for the second matmul (see MM1) -+ */ -+ -+ using GemmType = gemm_kernel_utils::DefaultGemmType; -+ using OpClass = typename GemmType::OpClass; -+ -+ using ElementA = scalar_t; -+ using ElementB = scalar_t; -+ using ElementC = scalar_t; -+ using ElementAccumulator = accum_t; -+ -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using DefaultConfig = -+ typename cutlass::gemm::device::DefaultGemmConfiguration< -+ OpClass, -+ ArchTag, -+ ElementA, -+ ElementB, -+ ElementC, -+ ElementAccumulator -+ >; -+ -+ static int const kAlignmentA = -+ kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment; -+ static int const kAlignmentB = -+ kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; -+ using InstructionShape = typename GemmType::InstructionShape; -+ -+ static int const kStages = DefaultConfig::kStages; -+ using Operator = typename GemmType::Operator; -+ -+ using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementAccumulator, -+ LayoutC, -+ OpClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ kStages, -+ Operator -+ >::DefaultMma; -+ -+ using MmaCore = typename DefaultMma::MmaCore; -+ using IteratorA = typename DefaultMma::IteratorA; -+ using IteratorB = typename DefaultMma::IteratorB; -+ using Mma = typename DefaultMma::ThreadblockMma; -+ using ScalingCoefsUpdater = typename DefaultAttentionScalingCoefsUpdater< -+ typename Mma::Operator::IteratorC, -+ ElementAccumulator, -+ kWarpSize>::Updater; -+ -+ static_assert(MmaCore::WarpCount::kCount == kNumWarpsPerBlock, ""); -+ -+ // Epilogue to store to shared-memory in a format that we can use later for -+ // the second matmul -+ using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm< -+ typename Mma::Operator::IteratorC, -+ typename Mma::Operator, -+ scalar_t, -+ WarpShape, -+ ThreadblockShape>; -+ using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; -+ }; -+ -+ struct MM1 { -+ /* -+ Second matmul: perform `attn @ V` where `attn` is the attention (not -+ normalized) and stored in shared memory -+ */ -+ -+ using GemmType = typename MM0::GemmType; -+ using OpClass = typename GemmType::OpClass; -+ -+ using ElementA = scalar_t; -+ using ElementB = scalar_t; -+ using ElementC = output_accum_t; -+ using ElementAccumulator = accum_t; -+ -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using DefaultConfig = -+ typename cutlass::gemm::device::DefaultGemmConfiguration< -+ OpClass, -+ ArchTag, -+ ElementA, -+ ElementB, -+ ElementC, -+ ElementAccumulator -+ >; -+ -+ static int const kAlignmentA = DefaultConfig::kAlignmentA; -+ static int const kAlignmentB = -+ kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; -+ -+ using ThreadblockShape = typename MM0::ThreadblockShape; -+ using WarpShape = typename MM0::WarpShape; -+ using InstructionShape = typename MM0::InstructionShape; -+ -+ using EpilogueOutputOp = typename DefaultConfig::EpilogueOutputOp; -+ -+ static int const kStages = DefaultConfig::kStages; -+ using Operator = typename GemmType::Operator; -+ -+ using ThreadblockSwizzle = void; // Swizzling is unused -+ static bool const kSplitKSerial = false; -+ -+ using DefaultGemm = cutlass::gemm::kernel::DefaultGemm< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OpClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ kStages, -+ kSplitKSerial, -+ Operator>; -+ -+ using DefaultMmaFromSmem = -+ typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< -+ typename DefaultGemm::Mma, -+ typename MM0::AccumulatorSharedStorage>; -+ -+ using Mma = typename DefaultMmaFromSmem::Mma; -+ using IteratorB = typename Mma::IteratorB; -+ using WarpCount = typename Mma::WarpCount; -+ static_assert(WarpCount::kCount == kNumWarpsPerBlock, ""); -+ -+ using DefaultEpilogue = typename DefaultGemm::Epilogue; -+ using OutputTileIterator = -+ typename cutlass::epilogue::threadblock::PredicatedTileIterator< -+ typename DefaultEpilogue::OutputTileIterator::ThreadMap, -+ output_t>; -+ using OutputTileIteratorAccum = -+ typename cutlass::epilogue::threadblock::PredicatedTileIterator< -+ typename DefaultEpilogue::OutputTileIterator::ThreadMap, -+ output_accum_t>; -+ -+ struct SharedStorageMM1 { -+ typename Mma::SharedStorage mm; -+ }; -+ }; -+ -+/// Define the kernel in terms of the default kernel -+ using FMHAKernel = kernel::FMHAGrouped< -+ MM0, -+ MM1, -+ scalar_t, -+ accum_t, -+ output_t, -+ output_accum_t, -+ kSingleValueIteration, -+ GroupScheduleMode_ -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/epilogue_pipelined.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/epilogue_pipelined.h -new file mode 100644 -index 0000000..2a574e7 ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/epilogue_pipelined.h -@@ -0,0 +1,632 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights -+ *reserved. SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, -+ *this list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -+ *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE -+ *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -+ *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -+ *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -+ *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -+ *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -+ *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -+ *POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ File copied from "cutlass/epilogue/threadblock/epilogue.h" -+ then modified to: -+ (1) load 2 source fragments at the same time (pipelining) -+ (2) support reading from a different dtype -+ (3) pass the row id to the OutputOp if it takes it -+ (see MemoryEfficientAttentionNormalize) -+ Note that in general the fragment passed to the OutputOp could -+ span multiple rows but it does not happen with the configurations we have -+*/ -+ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/functional.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/vector.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_coord.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/regular_tile_iterator.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue_base.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+#include "cutlass/numeric_types.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+template -+struct ApplyEpilogueOp { -+ static CUTLASS_DEVICE typename Op::FragmentOutput apply( -+ Op const& output_op, -+ int row_id, -+ typename Op::FragmentAccumulator const& accum, -+ typename Op::FragmentOutput const& source) { -+ return output_op(accum, source); -+ } -+ static CUTLASS_DEVICE typename Op::FragmentOutput apply( -+ Op const& output_op, -+ int row_id, -+ typename Op::FragmentAccumulator const& accum) { -+ return output_op(accum); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Epilogue operator -+template < -+ typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) -+ typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: -+ ///< gemm::warp::MmaTensorOp) -+ int PartitionsK, ///< Number of partitions of the K dimension -+ typename OutputTileIterator_, ///< Tile iterator writing output tensors -+ typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting -+ ///< accumulators -+ typename WarpTileIterator_, ///< Warp-scoped tile iterator writing -+ ///< accumulators to SMEM -+ typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading -+ ///< from SMEM -+ typename OutputOp_, ///< Output operator -+ typename Padding_, ///< Padding added to SMEM allocation to avoid bank -+ ///< conflicts (concept: MatrixShape) -+ int FragmentsPerPartition = -+ 1, ///< Used to coarsten the epilogue granularity -+ int IterationsUnroll = ///< Used to reduce binary size when epilogue op is -+ ///< large -+ (!IsEpilogueFunctorHeavy::value), -+ typename OutputTileSourceIterator_ = -+ OutputTileIterator_ ///< Tile iterator reading tensors -+ > -+class EpiloguePipelined : public EpilogueBase< -+ Shape_, -+ typename WarpMmaOperator_::Shape, -+ PartitionsK, -+ AccumulatorFragmentIterator_, -+ WarpTileIterator_, -+ Padding_, -+ FragmentsPerPartition> { -+ public: -+ using Base = EpilogueBase< -+ Shape_, -+ typename WarpMmaOperator_::Shape, -+ PartitionsK, -+ AccumulatorFragmentIterator_, -+ WarpTileIterator_, -+ Padding_, -+ FragmentsPerPartition>; -+ -+ using Shape = Shape_; -+ using WarpMmaOperator = WarpMmaOperator_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputTileIterator = OutputTileIterator_; -+ using OutputTileSourceIterator = OutputTileSourceIterator_; -+ using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; -+ using WarpTileIterator = WarpTileIterator_; -+ using SharedLoadIterator = SharedLoadIterator_; -+ using OutputOp = OutputOp_; -+ using Padding = Padding_; -+ -+ using Layout = layout::RowMajor; -+ using LongIndex = typename Layout::LongIndex; -+ -+ /// The complete warp-level accumulator tile -+ using AccumulatorTile = typename Base::AccumulatorTile; -+ -+ /// Accumulator element -+ using ElementAccumulator = typename WarpTileIterator::Element; -+ -+ /// Output element -+ using ElementOutput = typename OutputTileIterator::Element; -+ using ElementSource = typename OutputTileSourceIterator::Element; -+ -+ /// Output access size -+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; -+ -+ /// Tensor reference to destination tensor -+ using TensorRef = typename OutputTileIterator::TensorRef; -+ -+ /// Tensor reference to sync tensor -+ using SyncTensorRef = -+ typename cutlass::TensorRef; -+ -+ /// Const tensor reference to source tensor -+ using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; -+ -+ /// Array type used to output -+ using OutputAccessType = Array< -+ typename OutputTileIterator::Element, -+ OutputTileIterator::kElementsPerAccess>; -+ using SourceAccessType = Array< -+ typename OutputTileSourceIterator::Element, -+ OutputTileSourceIterator::kElementsPerAccess>; -+ -+ /// Array type used by output functor -+ using AccumulatorAccessType = Array< -+ typename WarpTileIterator::Element, -+ OutputTileIterator::kElementsPerAccess>; -+ -+ /// Number of warps -+ using WarpCount = typename Base::WarpCount; -+ -+ static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 -+ ? Base::kFragmentsPerIteration -+ : kPartitionsK; -+ static int constexpr kSmemPointerOffset = -+ Base::SharedStorage::StorageShape::kCount / kSmemTiles; -+ -+ public: -+ static_assert( -+ OutputTileSourceIterator::Fragment::kElements == -+ OutputTileIterator::Fragment::kElements, -+ "Mismatch between input tile and output tile iterator (kElements)"); -+ static_assert( -+ OutputTileSourceIterator::kIterations == OutputTileIterator::kIterations, -+ "Mismatch between input tile and output tile iterator (kIterations)"); -+ static_assert( -+ SharedLoadIterator::Fragment::kElements == -+ OutputTileIterator::Fragment::kElements, -+ "Mismatch between shared load iterator and output tile iterator."); -+ -+ static_assert( -+ OutputTileIterator::kElementsPerAccess, -+ "OutputTileIterator::kElementsPerAccess must not be zero."); -+ -+ static_assert( -+ !(OutputTileIterator::Fragment::kElements % -+ OutputTileIterator::kElementsPerAccess), -+ "Divisibility"); -+ -+ private: -+ /// Loads fragment from shared memory aligned with output tensor -+ SharedLoadIterator shared_load_iterator_; -+ -+ public: -+ /// Constructor -+ CUTLASS_DEVICE -+ EpiloguePipelined( -+ typename Base::SharedStorage& shared_storage, ///< Shared storage object -+ int thread_idx, ///< ID of a thread within the threadblock -+ int warp_idx, ///< ID of warp within threadblock -+ int lane_idx ///< Id of thread within warp -+ ) -+ : Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ shared_load_iterator_(shared_storage.reference(), thread_idx) {} -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void operator()( -+ OutputOp const& output_op, ///< Output operator -+ OutputTileIterator -+ destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const& -+ accumulators, ///< Complete warp-level accumulator tile -+ OutputTileSourceIterator -+ source_iterator) { ///< Threadblock tile coordinate in GEMM (in units -+ ///< of threadblock tiles) -+ -+ if (!output_op.is_source_needed()) { -+ compute_source_not_needed_(output_op, destination_iterator, accumulators); -+ } else { -+ compute_source_needed_( -+ output_op, destination_iterator, accumulators, source_iterator); -+ } -+ } -+ CUTLASS_DEVICE -+ void operator()( -+ OutputOp const& output_op, ///< Output operator -+ OutputTileIterator -+ destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const& -+ accumulators) { ///< Complete warp-level accumulator tile -+ compute_source_not_needed_(output_op, destination_iterator, accumulators); -+ } -+ -+ private: -+ template -+ struct acc2smem_source_not_needed; -+ -+ template -+ struct acc2smem_source_not_needed> { -+ template -+ CUTLASS_DEVICE static void helper( -+ AccumulatorFragmentIterator accum_fragment_iterator, -+ WarpTileIterator& warp_tile_iterator) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Advance; i++) { -+ ++accum_fragment_iterator; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { -+ typename AccumulatorFragmentIterator::Fragment accum_fragment; -+ -+ accum_fragment_iterator.load(accum_fragment); -+ ++accum_fragment_iterator; -+ -+ warp_tile_iterator.store(accum_fragment); -+ if (p < Base::kFragmentsPerIteration - 1) { -+ warp_tile_iterator.add_pointer_offset(kSmemPointerOffset); -+ } -+ } -+ -+ if (Base::kFragmentsPerIteration > 1) { -+ warp_tile_iterator.add_pointer_offset( -+ kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); -+ } -+ } -+ -+ CUTLASS_DEVICE -+ static void push( -+ size_t pos, -+ AccumulatorFragmentIterator const& iterator_begin, -+ WarpTileIterator& warp_tile_iterator) { -+ int dummy[] = { -+ (pos == (Seq * Base::kFragmentsPerIteration)) && -+ (helper( -+ iterator_begin, warp_tile_iterator), -+ 0)...}; -+ -+ CUTLASS_UNUSED(dummy[0]); -+ } -+ }; -+ -+ static_assert( -+ kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, -+ "One of these must be exactly 1."); -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void compute_source_not_needed_( -+ OutputOp const& output_op, ///< Output operator -+ OutputTileIterator -+ destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const& -+ accumulators ///< Complete warp-level accumulator tile -+ ) { -+ // -+ // Iterator over warp-level accumulator fragment -+ // -+ -+ AccumulatorFragmentIterator accum_fragment_iterator(accumulators); -+ -+ // -+ // Iterate over accumulator tile -+ // -+ -+#pragma unroll( \ -+ IterationsUnroll \ -+ ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration \ -+ : 1) -+ for (int iter = 0; iter < OutputTileIterator::kIterations; -+ iter += Base::kFragmentsPerIteration) { -+ // -+ // Convert and store fragment -+ // -+ -+ __syncthreads(); -+ -+ acc2smem_source_not_needed>:: -+ push(iter, accum_fragment_iterator, this->warp_tile_iterator_); -+ -+ __syncthreads(); -+ -+ // -+ // Load fragments from shared memory -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { -+ typename SharedLoadIterator::Fragment -+ aligned_accum_fragment[kPartitionsK]; -+ -+ shared_load_iterator_.load(aligned_accum_fragment[0]); -+ -+ if (p < Base::kFragmentsPerIteration - 1) { -+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); -+ } else if (kPartitionsK > 1) { -+ plus add_fragments; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 1; i < kPartitionsK; ++i) { -+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); -+ shared_load_iterator_.load(aligned_accum_fragment[i]); -+ aligned_accum_fragment[0] = add_fragments( -+ aligned_accum_fragment[0], aligned_accum_fragment[i]); -+ } -+ -+ shared_load_iterator_.add_pointer_offset( -+ (1 - kPartitionsK) * kSmemPointerOffset); -+ } -+ -+ // -+ // Compute the output result -+ // -+ -+ typename OutputTileIterator::Fragment output_fragment; -+ -+ apply_output_operator_source_not_needed_( -+ destination_iterator.thread_start_row(), -+ output_fragment, -+ output_op, -+ aligned_accum_fragment[0]); -+ -+ // -+ // Store the final result -+ // -+ -+ destination_iterator.store(output_fragment); -+ ++destination_iterator; -+ } -+ -+ if (Base::kFragmentsPerIteration > 1) { -+ shared_load_iterator_.add_pointer_offset( -+ kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); -+ } -+ } -+ } -+ -+ template -+ struct acc2smem_source_needed; -+ -+ template -+ struct acc2smem_source_needed> { -+ template -+ CUTLASS_DEVICE static void helper( -+ AccumulatorFragmentIterator accum_fragment_iterator, -+ WarpTileIterator& warp_tile_iterator) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Advance; i++) { -+ ++accum_fragment_iterator; -+ } -+ -+ typename AccumulatorFragmentIterator::Fragment accum_fragment; -+ accum_fragment_iterator.load(accum_fragment); -+ warp_tile_iterator.store(accum_fragment); -+ } -+ -+ CUTLASS_DEVICE -+ static void push( -+ size_t pos, -+ AccumulatorFragmentIterator const& iterator_begin, -+ WarpTileIterator& warp_tile_iterator) { -+ int dummy[] = { -+ (pos == Seq) && -+ (helper(iterator_begin, warp_tile_iterator), 0)...}; -+ } -+ }; -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void compute_source_needed_( -+ OutputOp const& output_op, ///< Output operator -+ OutputTileIterator -+ destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const& -+ accumulators, ///< Complete warp-level accumulator tile -+ OutputTileSourceIterator -+ source_iterator ///< Threadblock tile coordinate in GEMM (in units of -+ ///< threadblock tiles) -+ ) { -+ typename OutputTileSourceIterator::Fragment source_fragment[2]; -+ -+ source_fragment[0].clear(); -+ source_iterator.load(source_fragment[0]); -+ ++source_iterator; -+ source_fragment[1].clear(); -+ -+ // -+ // Iterator over warp-level accumulator fragment -+ // -+ -+ AccumulatorFragmentIterator accum_fragment_iterator(accumulators); -+ -+ // -+ // Iterate over accumulator tile -+ // -+ -+#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1) -+ for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { -+ if (iter > 0) { -+ __syncthreads(); -+ } -+ // -+ // Load the source for next iteration (pipelining) -+ // -+ -+ if (iter + 1 < OutputTileIterator::kIterations) { -+ source_iterator.load(source_fragment[(iter + 1) % 2]); -+ } -+ ++source_iterator; -+ acc2smem_source_needed< -+ cutlass::make_index_sequence>:: -+ push(iter, accum_fragment_iterator, this->warp_tile_iterator_); -+ -+ __syncthreads(); -+ -+ // -+ // Load fragments from shared memory -+ // -+ -+ typename SharedLoadIterator::Fragment -+ aligned_accum_fragment[kPartitionsK]; -+ -+ shared_load_iterator_.load(aligned_accum_fragment[0]); -+ -+ // If the number of k-slices is > 1 - perform a reduction amongst the -+ // k-slices -+ if (kPartitionsK > 1) { -+ plus add_fragments; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 1; i < kPartitionsK; ++i) { -+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); -+ shared_load_iterator_.load(aligned_accum_fragment[i]); -+ aligned_accum_fragment[0] = add_fragments( -+ aligned_accum_fragment[0], aligned_accum_fragment[i]); -+ } -+ -+ shared_load_iterator_.add_pointer_offset( -+ (1 - kPartitionsK) * kSmemPointerOffset); -+ } -+ -+ // -+ // Compute the output result -+ // -+ -+ typename OutputTileIterator::Fragment output_fragment; -+ -+ apply_output_operator_( -+ destination_iterator.thread_start_row(), -+ output_fragment, -+ output_op, -+ aligned_accum_fragment[0], -+ source_fragment[iter % 2]); -+ -+ // -+ // Store the final result -+ // -+ -+ destination_iterator.store(output_fragment); -+ ++destination_iterator; -+ } -+ } -+ -+ /// Helper to invoke the output functor over each vector of output -+ CUTLASS_DEVICE -+ void apply_output_operator_( -+ int begin_row, -+ typename OutputTileIterator::Fragment& output_fragment, -+ OutputOp const& output_op, ///< Output operator -+ typename SharedLoadIterator::Fragment const& aligned_accum_fragment, -+ typename OutputTileSourceIterator::Fragment const& source_fragment) { -+ OutputAccessType* output_frag_ptr = -+ reinterpret_cast(&output_fragment); -+ -+ AccumulatorAccessType const* compute_frag_ptr = -+ reinterpret_cast(&aligned_accum_fragment); -+ -+ SourceAccessType const* source_frag_ptr = -+ reinterpret_cast(&source_fragment); -+ -+ int const kOutputOpIterations = OutputTileIterator::Fragment::kElements / -+ OutputTileIterator::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kOutputOpIterations; ++i) { -+ // Call the output operator -+ output_frag_ptr[i] = ApplyEpilogueOp::apply( -+ output_op, -+ begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess), -+ compute_frag_ptr[i], -+ source_frag_ptr[i]); -+ } -+ } -+ -+ /// Helper to invoke the output functor over each vector of output -+ CUTLASS_DEVICE -+ void apply_output_operator_source_not_needed_( -+ int begin_row, -+ typename OutputTileIterator::Fragment& output_fragment, -+ OutputOp const& output_op, ///< Output operator -+ typename SharedLoadIterator::Fragment const& aligned_accum_fragment) { -+ OutputAccessType* output_frag_ptr = -+ reinterpret_cast(&output_fragment); -+ -+ AccumulatorAccessType const* compute_frag_ptr = -+ reinterpret_cast(&aligned_accum_fragment); -+ -+ int const kOutputOpIterations = OutputTileIterator::Fragment::kElements / -+ OutputTileIterator::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kOutputOpIterations; ++i) { -+ // Call the output operator -+ output_frag_ptr[i] = ApplyEpilogueOp::apply( -+ output_op, -+ begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess), -+ compute_frag_ptr[i]); -+ } -+ } -+ -+ // This should be constexpr, but it's only supported on c++14 -+ static int CUTLASS_HOST_DEVICE getRowOffset(int i) { -+ using ThreadMap = typename OutputTileIterator::ThreadMap; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; -+ ++cluster) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ int row_offset = row * ThreadMap::Delta::kRow + -+ group * ThreadMap::Delta::kGroup + -+ cluster * ThreadMap::Delta::kCluster; -+ int frag_row_idx = -+ (row + -+ ThreadMap::Iterations::kRow * -+ (group + ThreadMap::Iterations::kGroup * cluster)); -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; -+ ++column) { -+ int frag_idx = ThreadMap::kElementsPerAccess * -+ (frag_row_idx * ThreadMap::Iterations::kColumn + column); -+ if (i < frag_idx + ThreadMap::kElementsPerAccess) { -+ return row_offset; -+ } -+ } -+ } -+ } -+ } -+ return -1; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/epilogue_rescale_output.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/epilogue_rescale_output.h -new file mode 100644 -index 0000000..a5d8f8d ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/epilogue_rescale_output.h -@@ -0,0 +1,262 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holdvr nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory -+ to match canonical tensor layouts in global memory. Epilogues support -+ conversion and reduction operations. -+ -+ This is a copy of cutlass/epilogue/threadblock/epilogue.h that can -+ handle "row_id" as a first argument, as uses it to get the corresponding -+ `m_prime` / `s_prime` to rescale the output. -+*/ -+ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/functional.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/vector.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_coord.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/regular_tile_iterator.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue_base.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/epilogue/thread/scale_type.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/numeric_types.h" -+#include "epilogue_pipelined.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies a linear combination operator to an array of elements. -+// output <- alpha * accumulator + beta * source -+// with: -+// alpha = 1 / s_prime (to normalize when isLast=True, 1 otherwise) -+// beta = alpha / m_prime (renormalize the output when the max changes) -+// source is the current output -+template < -+ typename ElementOutput_, ///< Data type used to store tensors -+ typename ElementSource_, //< Data type for source (usually matches -+ //`ElementOutput`) -+ int Count, ///< Number of elements computed per operation. -+ ///< Usually it is 128/sizeof_bits, -+ ///< but we use 64 or 32 sometimes when there are not enough data -+ ///< to store -+ typename ElementAccumulator_, ///< Accumulator data type -+ typename ElementCompute_, ///< Data type used to compute linear combination -+ bool isFirst, -+ bool isLast, -+ typename FragmentAlphaBeta_, -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest> -+class MemoryEfficientAttentionNormalize { -+ public: -+ using ElementOutput = ElementOutput_; -+ using ElementSource = ElementSource_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementCompute_; -+ -+ static int const kCount = Count; -+ -+ using FragmentOutput = Array; -+ using FragmentSource = Array; -+ using FragmentAccumulator = Array; -+ using ComputeFragment = Array; -+ using FragmentAlphaBeta = FragmentAlphaBeta_; -+ -+ static FloatRoundStyle const kRound = Round; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ FragmentAlphaBeta const& s_prime_; -+ FragmentAlphaBeta const& m_prime_; -+ -+ public: -+ /// Constructs the function object, possibly loading from pointers in host -+ /// memory -+ CUTLASS_HOST_DEVICE -+ MemoryEfficientAttentionNormalize( -+ FragmentAlphaBeta const& s_prime, -+ FragmentAlphaBeta const& m_prime) -+ : s_prime_(s_prime), m_prime_(m_prime) {} -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ return !isFirst; -+ } -+ -+ /// Functionally required for serial reduction in the epilogue -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) {} -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ int row, -+ FragmentAccumulator const& accumulator, -+ FragmentSource const& source) const { -+ assert(!isFirst); -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter -+ source_converter; -+ NumericArrayConverter -+ accumulator_converter; -+ -+ // Convert to destination numeric type -+ NumericArrayConverter -+ destination_converter; -+ -+ ComputeFragment converted_source = source_converter(source); -+ ComputeFragment converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ ComputeFragment intermediate; -+ -+ multiplies mul_add_source; -+ multiply_add mul_add_accumulator; -+ -+ ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1; -+ ElementCompute beta = alpha * m_prime_[row]; -+ -+ intermediate = mul_add_source(beta, converted_source); // X = beta * C -+ -+ intermediate = mul_add_accumulator( -+ alpha, converted_accumulator, intermediate); // D = alpha * Accum + X -+ -+ return destination_converter(intermediate); -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()(int row, FragmentAccumulator const& accumulator) -+ const { -+ assert(isFirst); -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter -+ accumulator_converter; -+ -+ // Convert to destination numeric type -+ NumericArrayConverter -+ destination_converter; -+ -+ ComputeFragment converted_accumulator = accumulator_converter(accumulator); -+ -+ ComputeFragment intermediate; -+ multiplies mul_accumulator; -+ -+ ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1; -+ -+ intermediate = mul_accumulator( -+ alpha, converted_accumulator); // X = alpha * C + uniform -+ -+ return destination_converter(intermediate); -+ } -+}; -+ -+} // namespace thread -+ -+namespace threadblock { -+template < -+ typename EO, -+ typename ES, -+ int Count, -+ typename EA, -+ typename EC, -+ bool F, -+ bool L, -+ typename FAB, -+ FloatRoundStyle R> -+struct ApplyEpilogueOp> { -+ using Op = thread:: -+ MemoryEfficientAttentionNormalize; -+ static CUTLASS_DEVICE typename Op::FragmentOutput apply( -+ Op const& output_op, -+ int row_id, -+ typename Op::FragmentAccumulator const& accum, -+ typename Op::FragmentSource const& source) { -+ return output_op(row_id, accum, source); -+ } -+ static CUTLASS_DEVICE typename Op::FragmentOutput apply( -+ Op const& output_op, -+ int row_id, -+ typename Op::FragmentAccumulator const& accum) { -+ return output_op(row_id, accum); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/epilogue_thread_apply_logsumexp.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/epilogue_thread_apply_logsumexp.h -new file mode 100644 -index 0000000..2e286d3 ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/epilogue_thread_apply_logsumexp.h -@@ -0,0 +1,175 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights -+ *reserved. SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, -+ *this list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -+ *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE -+ *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -+ *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -+ *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -+ *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -+ *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -+ *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -+ *POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Functor performing linear combination operations used by epilogues. -+*/ -+ -+#pragma once -+ -+#include -+ -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/epilogue/thread/activation.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/numeric_types.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template -+struct ArrayExponential { -+ CUTLASS_HOST_DEVICE -+ Array operator()( -+ Array const& input) const { -+ Array result; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < ElementsPerAccess; ++i) { -+ result[i] = expf(input[i]); -+ } -+ -+ return result; -+ } -+}; -+ -+template -+struct ArrayExponential { -+ CUTLASS_DEVICE -+ Array operator()( -+ Array const& input) const { -+ Array result; -+ -+ int const kVectorCount = ElementsPerAccess / 2; -+ -+ __half2 const* input_ptr = -+ reinterpret_cast<__half2 const*>(input.raw_data()); -+ __half2* res_ptr = reinterpret_cast<__half2*>(result.raw_data()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kVectorCount; ++i) { -+ res_ptr[i] = h2exp(input_ptr[i]); -+ } -+ -+ return result; -+ } -+}; -+} // namespace detail -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies: -+/// output <- (input - lse).exp() -+template < -+ typename ElementOutput_, // output -+ typename ElementLSE_, // accumulator from LSE -+ typename ElementAccumulator_, // accumulator from matmul -+ typename ElementCompute_, // intermediate compute (and exp calculation) -+ int ElementsPerAccess> -+class ApplyLogSumExp { -+ public: -+ using ElementOutput = ElementOutput_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementCompute_; -+ using ElementLSE = ElementLSE_; -+ -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static int const kCount = kElementsPerAccess; -+ static const ScaleType::Kind kScale = -+ cutlass::epilogue::thread::ScaleType::NoBetaScaling; -+ -+ using FragmentOutput = Array; -+ using FragmentAccumulator = Array; -+ using FragmentCompute = Array; -+ using FragmentLSE = Array; -+ using FragmentScaleBias = FragmentLSE; // Used by epilogue_smem_accumulator.h -+ -+ public: -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ ApplyLogSumExp() {} -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ return true; -+ } -+ -+ /// Functionally required for serial reduction in the epilogue -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) {} -+ -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const& AB, -+ FragmentLSE const& scale_unused, -+ // bias used as LSE -+ FragmentLSE const& bias) const { -+ FragmentCompute frag_AB = NumericArrayConverter< -+ ElementCompute, -+ ElementAccumulator, -+ kElementsPerAccess>()(AB); -+ FragmentCompute frag_lse_compute = -+ NumericArrayConverter()( -+ bias); -+ FragmentCompute frag_compute; -+ -+ minus minus_lse; -+ detail::ArrayExponential apply_exp; -+ frag_compute = minus_lse(frag_AB, frag_lse_compute); -+ frag_compute = apply_exp(frag_compute); -+ -+ return NumericArrayConverter< -+ ElementOutput, -+ ElementCompute, -+ kElementsPerAccess>()(frag_compute); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/find_default_mma.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/find_default_mma.h -new file mode 100644 -index 0000000..9c62c8c ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/find_default_mma.h -@@ -0,0 +1,189 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holdvr nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Cutlass provides helper template functions to figure out the right -+ datastructures to instanciate to run a GEMM with various parameters (see -+ `cutlass/gemm/threadblock/default_mma.h`). However, due to template -+ instantiation priority rules, it will only create an MmaMultiStage with -+ kStages=3 (otherwise creates an MmePipelined - which is not compatible with -+ FastF32). kStages=3 uses too much shared memory and we want to use kStages=2, -+ so we just copy-pasted some code from `default_mma.h` and -+ `default_mma_core.h` files and wrapped this template to allow our usecase. -+ -+ This is really only for the FastF32 case - aka using TensorCores with fp32. -+*/ -+ -+#include "cutlass/gemm/threadblock/default_mma.h" -+#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Layout type for C and D matrix operand -+ typename LayoutC, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation perfomed by GEMM -+ typename Operator, -+ typename Enable_ = void> -+struct FindDefaultMma { -+ static constexpr bool AccumulatorsInRowMajor = false; -+ static constexpr SharedMemoryClearOption SharedMemoryClear = -+ SharedMemoryClearOption::kNone; -+ using DefaultMma = cutlass::gemm::threadblock::DefaultMma< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementAccumulator, -+ LayoutC, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ Stages, -+ Operator, -+ AccumulatorsInRowMajor, -+ SharedMemoryClear>; -+}; -+ -+/// Specialization for sm80 / FastF32 / multistage with kStages=2 -+template < -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ int kStages, -+ typename Operator> -+struct FindDefaultMma< -+ ElementA_, -+ LayoutA_, -+ kAlignmentA, -+ ElementB_, -+ LayoutB_, -+ kAlignmentB, -+ ElementAccumulator, -+ layout::RowMajor, -+ arch::OpClassTensorOp, -+ arch::Sm80, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ kStages, -+ Operator, -+ typename cutlass::platform::enable_if<(kAlignmentA > 1)>::type> { -+ using LayoutC = layout::RowMajor; -+ using OperatorClass = arch::OpClassTensorOp; -+ using ArchTag = arch::Sm80; -+ -+ using DefaultMma_ = cutlass::gemm::threadblock::DefaultMma< -+ ElementA_, -+ LayoutA_, -+ kAlignmentA, -+ ElementB_, -+ LayoutB_, -+ kAlignmentB, -+ ElementAccumulator, -+ LayoutC, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ 3, -+ Operator>; -+ struct DefaultMma : DefaultMma_ { -+ using MmaCore_ = typename DefaultMma_::MmaCore; -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< -+ typename MmaCore_::Shape, -+ typename DefaultMma_::IteratorA, -+ typename MmaCore_::SmemIteratorA, -+ MmaCore_::kCacheOpA, -+ typename DefaultMma_::IteratorB, -+ typename MmaCore_::SmemIteratorB, -+ MmaCore_::kCacheOpB, -+ ElementAccumulator, -+ LayoutC, -+ typename MmaCore_::MmaPolicy, -+ kStages>; -+ }; -+}; -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/fmha_grouped.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/fmha_grouped.h -new file mode 100644 -index 0000000..7201599 ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/fmha_grouped.h -@@ -0,0 +1,839 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Grouped FMHA kernel -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/semaphore.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/trace.h" -+#include "cutlass/gemm/kernel/gemm_transpose_operands.h" -+ -+#include "fmha_grouped_problem_visitor.h" -+#include "gemm_kernel_utils.h" -+#include "epilogue_rescale_output.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename MM0_, ///! Structure for computing P = Q @ K -+ typename MM1_, ///! Structure for computing O = P @ V -+ typename scalar_t_, -+ typename accum_t_, -+ typename output_t_, -+ typename output_accum_t_, -+ bool kKeepOutputInRF, ///! Whether the intermediate output from MM0_ should be kept in the register file -+ GroupScheduleMode GroupScheduleMode_ ///! Type of scheduling to perform -+> -+struct FMHAGrouped { -+public: -+ using MM0 = MM0_; -+ using MM1 = MM1_; -+ -+ using scalar_t = scalar_t_; -+ using accum_t = accum_t_; -+ using output_t = output_t_; -+ using output_accum_t = output_accum_t_; -+ -+ static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; -+ -+ static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF && -+ !cutlass::platform::is_same::value; -+ -+ // Parameters to satisfy BaseGrouped -+ using ElementA = scalar_t; -+ using ElementB = scalar_t; -+ using ElementC = accum_t; -+ using LayoutA = typename MM0::LayoutA; -+ using LayoutB = typename MM0::ElementB; -+ using LayoutC = typename MM1::ElementC; -+ static ComplexTransform const kTransformA = ComplexTransform::kNone; -+ static ComplexTransform const kTransformB = ComplexTransform::kNone; -+ static int const kAlignmentA = MM0::kAlignmentA; -+ static int const kAlignmentB = MM0::kAlignmentB; -+ static int const kAlignmentC = 1; -+ using Mma = typename MM1::Mma; -+ using EpilogueOutputOp = typename MM1::EpilogueOutputOp; -+ using ThreadblockSwizzle = void; -+ using Operator = typename MM1::Operator; -+ using WarpShape = typename MM1::WarpShape; -+ using InstructionShape = typename MM1::InstructionShape; -+ -+ using ElementQ = scalar_t; -+ using ElementK = scalar_t; -+ using ElementP = accum_t; -+ using ElementV = scalar_t; -+ using ElementO = output_t; -+ using ElementOAccum = output_accum_t; -+ using ElementAccumulator = accum_t; -+ -+ using LayoutQ = typename MM0::LayoutA; -+ using LayoutK = typename MM0::LayoutB; -+ using LayoutP = typename MM0::LayoutC; -+ using LayoutV = typename MM1::LayoutB; -+ using LayoutO = typename MM1::LayoutC; -+ -+ static bool const kPreloadV = (MM1::Mma::ArchTag::kMinComputeCapability >= 80 && -+ cutlass::sizeof_bits::value == 16); -+ -+ static int const kAlignmentQ = MM0::kAlignmentA; -+ static int const kAlignmentK = MM0::kAlignmentB; -+ static int const kAlignmentV = 1; -+ -+ using ThreadblockShape = typename MM0::ThreadblockShape; -+ -+ static int const kQueriesPerBlock = ThreadblockShape::kM; -+ static int const kKeysPerBlock = ThreadblockShape::kN; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename MM1::WarpCount; -+ static int const kThreadsPerWarp = 32; -+ static int const kThreadCount = kThreadsPerWarp * WarpCount::kCount; -+ -+ using ProblemVisitor = FMHAGroupedProblemVisitor< -+ ThreadblockShape, -+ kGroupScheduleMode, -+ kThreadCount, -+ kThreadCount>; -+ -+ // -+ // Structures -+ // -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmCoord *problem_sizes0; -+ GemmCoord *problem_sizes1; -+ -+ int problem_count; -+ int threadblock_count; -+ -+ ElementQ ** ptr_Q; -+ ElementK ** ptr_K; -+ ElementP ** ptr_P; -+ ElementV ** ptr_V; -+ ElementO ** ptr_O; -+ ElementOAccum ** ptr_O_accum; -+ -+ typename LayoutQ::Stride::LongIndex *ldq; -+ typename LayoutK::Stride::LongIndex *ldk; -+ typename LayoutP::Stride::LongIndex *ldv; -+ typename LayoutO::Stride::LongIndex *ldo; -+ -+ // Whether causal masking is to be performed -+ bool causal; -+ -+ // Only used by device-level operator -+ GemmCoord *host_problem_sizes; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments(): -+ problem_count(0), -+ threadblock_count(0), -+ ptr_Q(nullptr), -+ ptr_K(nullptr), -+ ptr_P(nullptr), -+ ptr_V(nullptr), -+ ptr_O(nullptr), -+ ptr_O_accum(nullptr), -+ ldq(nullptr), -+ ldk(nullptr), -+ ldv(nullptr), -+ ldo(nullptr), -+ causal(false), -+ host_problem_sizes(nullptr) -+ { -+ -+ } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ GemmCoord *problem_sizes0, -+ GemmCoord *problem_sizes1, -+ int problem_count, -+ int threadblock_count, -+ ElementQ ** ptr_Q, -+ ElementK ** ptr_K, -+ ElementP ** ptr_P, -+ ElementV ** ptr_V, -+ ElementO ** ptr_O, -+ ElementOAccum ** ptr_O_accum, -+ typename LayoutQ::Stride::LongIndex *ldq, -+ typename LayoutK::Stride::LongIndex *ldk, -+ typename LayoutP::Stride::LongIndex *ldp, -+ typename LayoutV::Stride::LongIndex *ldv, -+ typename LayoutO::Stride::LongIndex *ldo, -+ bool causal, -+ GemmCoord *host_problem_sizes=nullptr -+ ): -+ problem_sizes0(problem_sizes0), -+ problem_sizes1(problem_sizes1), -+ problem_count(problem_count), -+ threadblock_count(threadblock_count), -+ ptr_Q(ptr_Q), -+ ptr_K(ptr_K), -+ ptr_P(ptr_P), -+ ptr_V(ptr_V), -+ ptr_O(ptr_O), -+ ptr_O_accum(kNeedsOutputAccumulatorBuffer ? ptr_O_accum : (accum_t**)ptr_O), -+ ldq(ldq), -+ ldk(ldk), -+ ldv(ldv), -+ ldo(ldo), -+ causal(causal), -+ host_problem_sizes(host_problem_sizes) -+ { -+ -+ } -+ -+ bool __host__ check_supported() { -+ CHECK_ALIGNED_PTR(ptr_Q, kAlignmentQ); -+ CHECK_ALIGNED_PTR(ptr_K, kAlignmentK); -+ CHECK_ALIGNED_PTR(ptr_V, kAlignmentV); -+ XFORMERS_CHECK(ldq % kAlignmentQ == 0, "query is not correctly aligned"); -+ XFORMERS_CHECK(ldk % kAlignmentK == 0, "key is not correctly aligned"); -+ XFORMERS_CHECK(ldv % kAlignmentV == 0, "value is not correctly aligned"); -+ return true; -+ } -+ }; -+ -+ // -+ // Structure for precomputing values in host memory and passing to kernels -+ // -+ -+ /// Parameters structure -+ struct Params { -+ -+ typename ProblemVisitor::Params problem_visitor; -+ int threadblock_count; -+ -+ ElementQ ** ptr_Q; -+ ElementK ** ptr_K; -+ ElementP ** ptr_P; -+ ElementV ** ptr_V; -+ ElementO ** ptr_O; -+ ElementOAccum ** ptr_O_accum; -+ -+ typename LayoutQ::Stride::LongIndex *ldq; -+ typename LayoutK::Stride::LongIndex *ldk; -+ typename LayoutP::Stride::LongIndex *ldv; -+ typename LayoutO::Stride::LongIndex *ldo; -+ -+ bool causal; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ ptr_Q(nullptr), -+ ptr_K(nullptr), -+ ptr_P(nullptr), -+ ptr_V(nullptr), -+ ptr_O(nullptr), -+ ptr_O_accum(nullptr), -+ ldq(nullptr), -+ ldk(nullptr), -+ ldv(nullptr), -+ ldo(nullptr), -+ causal(false) -+ { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args, -+ void *workspace = nullptr, -+ int tile_count = 0): -+ problem_visitor(args.problem_sizes0, args.problem_sizes1, args.problem_count, workspace, tile_count), -+ threadblock_count(args.threadblock_count), -+ ptr_Q(args.ptr_Q), -+ ptr_K(args.ptr_K), -+ ptr_P(args.ptr_P), -+ ptr_V(args.ptr_V), -+ ptr_O(args.ptr_O), -+ ptr_O_accum(kNeedsOutputAccumulatorBuffer ? args.ptr_O_accum : (accum_t**)args.ptr_O), -+ ldq(args.ldq), -+ ldk(args.ldk), -+ ldv(args.ldv), -+ ldo(args.ldo), -+ causal(args.causal) -+ { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void update( -+ Arguments const &args, -+ void *workspace = nullptr, -+ int tile_count = 0) { -+ -+ problem_visitor = typename ProblemVisitor::Params(args.problem_sizes0, -+ args.problem_sizes1, -+ args.problem_count, -+ workspace, tile_count); -+ threadblock_count = args.threadblock_count; -+ ptr_Q = args.ptr_Q; -+ ptr_K = args.ptr_K; -+ ptr_P = args.ptr_P; -+ ptr_V = args.ptr_V; -+ ptr_O = args.ptr_O; -+ ptr_O_accum = kNeedsOutputAccumulatorBuffer ? args.ptr_O_accum : (accum_t**)args.ptr_O; -+ ldq = args.ldq; -+ ldk = args.ldk; -+ ldv = args.ldv; -+ ldo = args.ldo; -+ causal = args.causal; -+ } -+ }; -+ -+ // Shared storage - depends on kernel params -+ struct ScalingCoefs { -+ cutlass::Array m_prime; -+ cutlass::Array s_prime; -+ cutlass::Array mi; -+ }; -+ -+ struct SharedStorageEpilogueAtEnd : ScalingCoefs { -+ struct SharedStorageAfterMM0 { -+ // Everything here might be overwritten during MM0 -+ typename MM0::AccumulatorSharedStorage si; -+ typename MM1::SharedStorageMM1 mm1; -+ }; -+ -+ union { -+ typename MM0::Mma::SharedStorage mm0; -+ SharedStorageAfterMM0 after_mm0; -+ typename MM1::DefaultEpilogue::SharedStorage epilogue; -+ }; -+ -+ CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& -+ epilogue_shared_storage() { -+ return epilogue; -+ } -+ -+ // ProblemVisitor shared storage can't be overlapped with others -+ typename ProblemVisitor::SharedStorage problem_visitor; -+ }; -+ -+ struct SharedStorageEpilogueInLoop : ScalingCoefs { -+ struct SharedStorageAfterMM0 { -+ // Everything here might be overwritten during MM0 -+ typename MM0::AccumulatorSharedStorage si; -+ typename MM1::SharedStorageMM1 mm1; -+ typename MM1::DefaultEpilogue::SharedStorage epilogue; -+ }; -+ -+ union { -+ typename MM0::Mma::SharedStorage mm0; -+ SharedStorageAfterMM0 after_mm0; -+ }; -+ -+ CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& -+ epilogue_shared_storage() { -+ return after_mm0.epilogue; -+ } -+ -+ // ProblemVisitor shared storage can't be overlapped with others -+ typename ProblemVisitor::SharedStorage problem_visitor; -+ }; -+ -+ using SharedStorage = typename cutlass::platform::conditional< -+ kKeepOutputInRF, -+ SharedStorageEpilogueAtEnd, -+ SharedStorageEpilogueInLoop>::type; -+ -+private: -+ -+ // Parameters to be used by an individual tile -+ struct TileParams { -+ -+ CUTLASS_HOST_DEVICE -+ static int query_start(int threadblock_idx) { -+ return threadblock_idx * kQueriesPerBlock; -+ } -+ -+ // Returns whether this threadblock computes within the number of queries, -+ // which is determined by the M dimension of problem 0 -+ CUTLASS_HOST_DEVICE -+ static bool can_compute(int threadblock_idx, const GemmCoord& problem_size0) { -+ return query_start(threadblock_idx) < problem_size0.m(); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static int num_queries(int threadblock_idx, const GemmCoord& problem_size0) { -+ return problem_size0.m() - query_start(threadblock_idx); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static int num_keys(int threadblock_idx, const GemmCoord& problem_size0, bool causal) { -+ int nk = problem_size0.n(); -+ if (causal) { -+ nk = cutlass::fast_min(int32_t(query_start(threadblock_idx) + kQueriesPerBlock), nk); -+ } -+ return nk; -+ } -+ -+ }; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_DEVICE -+ FMHAGrouped() { } -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement(cutlass::gemm::GemmCoord const & problem_size) { -+ return Status::kSuccess; -+ } -+ -+ static Status can_implement(Arguments const &args) { -+ return Status::kSuccess; -+ } -+ -+ static CUTLASS_DEVICE int16_t thread_id() { -+ return threadIdx.x; -+ } -+ -+ static CUTLASS_DEVICE int8_t warp_id() { -+ return threadIdx.x / kThreadsPerWarp; -+ } -+ -+ static CUTLASS_DEVICE int8_t lane_id() { -+ return threadIdx.x % kThreadsPerWarp; -+ } -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ auto& m_prime = shared_storage.m_prime; -+ auto& s_prime = shared_storage.s_prime; -+ [[maybe_unused]] auto& si = shared_storage.after_mm0.si; -+ auto& mi = shared_storage.mi; -+ -+ ProblemVisitor problem_visitor( -+ params.problem_visitor, -+ shared_storage.problem_visitor, -+ blockIdx.x); -+ -+ // Outer 'persistent' loop to iterate over tiles -+ while (problem_visitor.next_tile()) { -+ -+ GemmCoord problem_size0 = problem_visitor.problem_size0(); -+ GemmCoord problem_size1 = problem_visitor.problem_size1(); -+ const int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); -+ -+ if (!TileParams::can_compute(threadblock_idx, problem_size0)) { -+ problem_visitor.advance(gridDim.x); -+ continue; -+ } -+ -+ const int32_t problem_idx = problem_visitor.problem_index(); -+ -+ if (thread_id() < kQueriesPerBlock) { -+ s_prime[thread_id()] = ElementAccumulator(0); -+ m_prime[thread_id()] = -+ -cutlass::platform::numeric_limits::infinity(); -+ mi[thread_id()] = -cutlass::platform::numeric_limits::infinity(); -+ } -+ -+ ElementO *ptr_O = params.ptr_O[problem_idx] + TileParams::query_start(threadblock_idx) * params.ldo[problem_idx]; -+ ElementOAccum *ptr_O_accum = params.ptr_O_accum[problem_idx] + TileParams::query_start(threadblock_idx) * params.ldo[problem_idx]; -+ const int num_queries = TileParams::num_queries(threadblock_idx, problem_size0); -+ -+ auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator { -+ using OutputTileIterator = typename MM1::OutputTileIterator; -+ return OutputTileIterator( -+ typename OutputTileIterator::Params{(int32_t)params.ldo[problem_idx]}, -+ ptr_O, -+ typename OutputTileIterator::TensorCoord{ -+ num_queries, problem_size1.n()}, -+ thread_id(), -+ {0, col}); -+ }; -+ -+ auto createOutputAccumIter = [&](int col) -> -+ typename MM1::OutputTileIteratorAccum { -+ using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum; -+ return OutputTileIteratorAccum( -+ typename OutputTileIteratorAccum::Params{(int32_t)params.ldo[problem_idx]}, -+ ptr_O_accum, -+ typename OutputTileIteratorAccum::TensorCoord{ -+ num_queries, problem_size1.n()}, -+ thread_id(), -+ {0, col}); -+ }; -+ -+ typename MM1::Mma::FragmentC accum_o; -+ accum_o.clear(); -+ -+ const int num_keys = TileParams::num_keys(threadblock_idx, problem_size0, params.causal); -+ -+ for (int32_t iter_key_start = 0; iter_key_start < num_keys; -+ iter_key_start += kKeysPerBlock) { -+ int32_t problem_size_0_m = -+ cutlass::fast_min((int32_t)kQueriesPerBlock, num_queries); -+ int32_t problem_size_0_n = cutlass::fast_min( -+ (int32_t)kKeysPerBlock, num_keys - iter_key_start); -+ int32_t const& problem_size_0_k = problem_size0.k(); -+ int32_t const& problem_size_1_n = problem_size1.n(); -+ int32_t const& problem_size_1_k = problem_size_0_n; -+ -+ auto prologueV = [&](int blockN) { -+ typename MM1::Mma::IteratorB iterator_V( -+ typename MM1::IteratorB::Params{MM1::LayoutB(params.ldv[problem_idx])}, -+ params.ptr_V[problem_idx] + iter_key_start * params.ldv[problem_idx], -+ {problem_size_1_k, problem_size_1_n}, -+ thread_id(), -+ cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); -+ -+ MM1::Mma::prologue( -+ shared_storage.after_mm0.mm1.mm, -+ iterator_V, -+ thread_id(), -+ problem_size_1_k); -+ }; -+ -+ __syncthreads(); // Need to have shared memory initialized, and `m_prime` -+ // updated from end of prev iter -+ -+ // -+ // MATMUL: Q.K_t -+ // -+ // Computes the block-matrix product of: -+ // (a) query[query_start:query_end, :] -+ // with -+ // (b) key[iter_key_start:iter_key_start + kKeysPerBlock] -+ // and stores that into `shared_storage.si` -+ // -+ -+ ElementQ *ptr_Q = params.ptr_Q[problem_idx] + TileParams::query_start(threadblock_idx) * params.ldq[problem_idx]; -+ -+ // Construct iterators to A and B operands -+ typename MM0::IteratorA iterator_A( -+ typename MM0::IteratorA::Params( -+ typename MM0::MmaCore::LayoutA(params.ldq[problem_idx])), -+ ptr_Q, -+ {problem_size_0_m, problem_size_0_k}, -+ thread_id(), -+ {0, 0}); -+ -+ typename MM0::IteratorB iterator_B( -+ typename MM0::IteratorB::Params( -+ typename MM0::MmaCore::LayoutB(params.ldk[problem_idx])), -+ params.ptr_K[problem_idx] + iter_key_start * params.ldk[problem_idx], -+ {problem_size_0_k, problem_size_0_n}, -+ thread_id(), -+ {0, 0}); -+ -+ // Construct thread-scoped matrix multiply -+ typename MM0::Mma mma( -+ shared_storage.mm0, thread_id(), warp_id(), lane_id()); -+ -+ typename MM0::Mma::FragmentC accum; -+ -+ accum.clear(); -+ -+ auto gemm_k_iterations = -+ (problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); -+ __syncthreads(); -+ -+ if (kPreloadV) { -+ prologueV(0); -+ } -+ -+ typename MM0::Mma::Operator::IteratorC::TensorCoord -+ iteratorC_tile_offset = { -+ (warp_id() % MM0::Mma::WarpCount::kM), -+ (warp_id() / MM0::Mma::WarpCount::kM) -+ }; -+ -+ // Mask out last if causal -+ if (params.causal && num_keys - iter_key_start <= kKeysPerBlock) { -+ auto lane_offset = MM0::ScalingCoefsUpdater::get_lane_offset( -+ lane_id(), warp_id(), iteratorC_tile_offset); -+ int32_t last_col; -+ MM0::ScalingCoefsUpdater::iterateRows( -+ lane_offset, -+ [&](int accum_m) { -+ last_col = TileParams::query_start(threadblock_idx) + accum_m - iter_key_start; -+ }, -+ [&](int accum_m, int accum_n, int idx) { -+ if (accum_n > last_col) { -+ accum[idx] = -+ -cutlass::platform::numeric_limits::infinity(); -+ } -+ }, -+ [&](int accum_m) {}); -+ } -+ DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] { -+ DISPATCH_BOOL( -+ num_keys - iter_key_start >= kKeysPerBlock, -+ kFullColumns, -+ ([&] { -+ // Update `mi` from accum stored in registers -+ // Also updates `accum` with accum[i] <- -+ // exp(accum[i] * scale -+ // - mi) -+ MM0::ScalingCoefsUpdater::update< -+ kQueriesPerBlock, -+ kFullColumns, -+ kIsFirst, -+ kKeepOutputInRF>( -+ accum_o, -+ accum, -+ mi, -+ m_prime, -+ s_prime, -+ lane_id(), -+ thread_id(), -+ warp_id(), -+ num_keys - iter_key_start, -+ iteratorC_tile_offset, -+ 1.0f / cutlass::fast_sqrt(float(problem_size0.k()))); -+ })); -+ })); -+ -+ // Output results to shared-memory -+ int warp_idx_mn_0 = warp_id() % -+ (MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN); -+ auto output_tile_coords = cutlass::MatrixCoord{ -+ warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM, -+ warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM}; -+ -+ MM0::B2bGemm::accumToSmem( -+ shared_storage.after_mm0.si, accum, lane_id(), output_tile_coords); -+ -+ __syncthreads(); -+ -+ // -+ // MATMUL: Attn . V -+ // Run the matmul `attn @ V` for a block of attn and V. -+ // `attn` is read from shared memory (in `shared_storage_si`) -+ // `V` is read from global memory (with iterator_B) -+ // -+ -+ const int64_t nBlockN = kKeepOutputInRF ? 1 -+ : ceil_div( -+ (int64_t)problem_size_1_n, -+ int64_t(MM1::ThreadblockShape::kN)); -+ -+ // Iterate over the N dimension of GEMM1 -+ for (int blockN = 0; blockN < nBlockN; ++blockN) { -+ int gemm_k_iterations = -+ (problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add and store it in accum -+ // (in registers) -+ if (!kPreloadV) { -+ __syncthreads(); // we share shmem between mma and epilogue -+ } -+ -+ typename MM1::Mma::IteratorB iterator_V( -+ typename MM1::IteratorB::Params{MM1::LayoutB(params.ldv[problem_idx])}, -+ params.ptr_V[problem_idx] + iter_key_start * params.ldv[problem_idx], -+ {problem_size_1_k, problem_size_1_n}, -+ thread_id(), -+ cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); -+ -+ typename MM1::Mma mma_pv( -+ shared_storage.after_mm0.mm1.mm, -+ shared_storage.after_mm0.si, -+ (int)thread_id(), -+ (int)warp_id(), -+ (int)lane_id(), -+ (int)problem_size_1_k); -+ -+ mma_pv.set_prologue_done(kPreloadV); -+ if (!kKeepOutputInRF) { -+ accum_o.clear(); -+ } -+ -+ mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o); -+ __syncthreads(); -+ -+ if (kPreloadV && !kKeepOutputInRF && blockN + 1 < nBlockN) { -+ prologueV(blockN + 1); -+ } -+ -+ if (!kKeepOutputInRF) { -+ DISPATCH_BOOL( -+ iter_key_start == 0, kIsFirst, ([&] { -+ DISPATCH_BOOL( -+ (iter_key_start + kKeysPerBlock) >= num_keys, -+ kIsLast, -+ ([&] { -+ using DefaultEpilogue = typename MM1::DefaultEpilogue; -+ using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp; -+ using ElementCompute = typename DefaultOp::ElementCompute; -+ using EpilogueOutputOp = typename cutlass::epilogue:: -+ thread::MemoryEfficientAttentionNormalize< -+ typename cutlass::platform::conditional< -+ kIsLast, -+ output_t, -+ output_accum_t>::type, -+ output_accum_t, -+ DefaultOp::kCount, -+ typename DefaultOp::ElementAccumulator, -+ output_accum_t, -+ kIsFirst, -+ kIsLast, -+ cutlass::Array>; -+ using Epilogue = typename cutlass::epilogue::threadblock:: -+ EpiloguePipelined< -+ typename DefaultEpilogue::Shape, -+ typename MM1::Mma::Operator, -+ DefaultEpilogue::kPartitionsK, -+ typename cutlass::platform::conditional< -+ kIsLast, -+ typename MM1::OutputTileIterator, -+ typename MM1::OutputTileIteratorAccum>::type, -+ typename DefaultEpilogue:: -+ AccumulatorFragmentIterator, -+ typename DefaultEpilogue::WarpTileIterator, -+ typename DefaultEpilogue::SharedLoadIterator, -+ EpilogueOutputOp, -+ typename DefaultEpilogue::Padding, -+ DefaultEpilogue::kFragmentsPerIteration, -+ true, // IterationsUnroll -+ typename MM1::OutputTileIteratorAccum // Read -+ // iterator -+ >; -+ -+ int col = blockN * MM1::Mma::Shape::kN; -+ auto source_iter = createOutputAccumIter(col); -+ auto dest_iter = gemm_kernel_utils::call_conditional< -+ kIsLast, -+ decltype(createOutputIter), -+ decltype(createOutputAccumIter)>:: -+ apply(createOutputIter, createOutputAccumIter, col); -+ EpilogueOutputOp rescale(s_prime, m_prime); -+ Epilogue epilogue( -+ shared_storage.epilogue_shared_storage(), -+ thread_id(), -+ warp_id(), -+ lane_id()); -+ epilogue(rescale, dest_iter, accum_o, source_iter); -+ })); -+ })); -+ if (!kKeepOutputInRF) { -+ __syncthreads(); -+ } -+ } -+ } -+ __syncthreads(); // we modify `m_prime` after -+ } -+ -+ if (kKeepOutputInRF) { -+ const bool kIsFirst = true; -+ const bool kIsLast = true; -+ using DefaultEpilogue = typename MM1::DefaultEpilogue; -+ using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp; -+ using ElementCompute = typename DefaultOp::ElementCompute; -+ using EpilogueOutputOp = -+ typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize< -+ output_t, // output -+ output_accum_t, // source -+ DefaultOp::kCount, -+ typename DefaultOp::ElementAccumulator, // accum -+ output_accum_t, // compute -+ kIsFirst, -+ kIsLast, -+ cutlass::Array>; -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::EpiloguePipelined< -+ typename DefaultEpilogue::Shape, -+ typename MM1::Mma::Operator, -+ DefaultEpilogue::kPartitionsK, -+ typename MM1::OutputTileIterator, // destination -+ typename DefaultEpilogue::AccumulatorFragmentIterator, -+ typename DefaultEpilogue::WarpTileIterator, -+ typename DefaultEpilogue::SharedLoadIterator, -+ EpilogueOutputOp, -+ typename DefaultEpilogue::Padding, -+ DefaultEpilogue::kFragmentsPerIteration, -+ true, // IterationsUnroll -+ typename MM1::OutputTileIteratorAccum // source tile -+ >; -+ auto dest_iter = createOutputIter(0); -+ EpilogueOutputOp rescale(s_prime, m_prime); -+ Epilogue epilogue( -+ shared_storage.epilogue_shared_storage(), -+ thread_id(), -+ warp_id(), -+ lane_id()); -+ epilogue(rescale, dest_iter, accum_o); -+ } -+ -+ // Next tile -+ problem_visitor.advance(gridDim.x); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/fmha_grouped_problem_visitor.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/fmha_grouped_problem_visitor.h -new file mode 100644 -index 0000000..2b31319 ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/fmha_grouped_problem_visitor.h -@@ -0,0 +1,178 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Scheduler for grouped FMHA -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/gemm/kernel/grouped_problem_visitor.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+// Helper for correctly representing problem sizes in grouped kernels -+template -+struct FMHAGroupedProblemSizeHelper { -+ -+ CUTLASS_HOST_DEVICE -+ static cutlass::gemm::GemmCoord grid_shape(const cutlass::gemm::GemmCoord& problem) { -+ // FMHA only partitions tiles across the M dimension. -+ return cutlass::gemm::GemmCoord( -+ ((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), 1, 1); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem) {} -+ -+ CUTLASS_HOST_DEVICE -+ static int32_t tile_count(const cutlass::gemm::GemmCoord& grid) { -+ return grid.m() * grid.n(); -+ } -+}; -+ -+} // namespace detail -+ -+/// Visitor class to abstract away the algorithm for iterating over tiles -+template -+struct FMHAGroupedProblemVisitor : public GroupedProblemVisitor< -+ detail::FMHAGroupedProblemSizeHelper, -+ ThreadblockShape, -+ GroupScheduleMode_, -+ PrefetchTileCount, -+ ThreadCount> { -+ -+ using ProblemSizeHelper = detail::FMHAGroupedProblemSizeHelper; -+ using Base = GroupedProblemVisitor; -+ using BaseParams = typename Base::Params; -+ using SharedStorage = typename Base::SharedStorage; -+ -+ cutlass::gemm::GemmCoord const *problem_sizes0; -+ cutlass::gemm::GemmCoord const *problem_sizes1; -+ -+ struct Params { -+ cutlass::gemm::GemmCoord const *problem_sizes0; -+ cutlass::gemm::GemmCoord const *problem_sizes1; -+ int32_t problem_count; -+ void const *workspace; -+ int32_t tile_count; -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ Params(): problem_sizes0(nullptr), problem_sizes1(nullptr), -+ problem_count(0), workspace(nullptr), tile_count(0) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ Params( -+ cutlass::gemm::GemmCoord const *problem_sizes0, -+ cutlass::gemm::GemmCoord const *problem_sizes1, -+ int32_t problem_count, -+ void const *workspace = nullptr, -+ int32_t tile_count = 0 -+ ): -+ problem_sizes0(problem_sizes0), -+ problem_sizes1(problem_sizes1), -+ problem_count(problem_count), -+ workspace(workspace), -+ tile_count(tile_count) -+ {} -+ -+ /// Convert the FMHA-specific parameters to those used by the base class -+ CUTLASS_HOST_DEVICE -+ BaseParams to_base() const { -+ return BaseParams(// Set problem_sizes as problem_sizes1 because these determine -+ // shape of the final output of FMHA -+ problem_sizes1, -+ problem_count, -+ workspace, -+ tile_count); -+ } -+ -+ }; -+ -+ // -+ // Methods -+ // -+ CUTLASS_DEVICE -+ FMHAGroupedProblemVisitor( -+ Params const ¶ms_, -+ SharedStorage &shared_storage_, -+ int32_t block_idx -+ ): Base ( -+ params_.to_base(), -+ shared_storage_, block_idx), -+ problem_sizes0(params_.problem_sizes0), -+ problem_sizes1(params_.problem_sizes1) -+ {} -+ -+ /// Returns the problem size 0 for the current problem -+ CUTLASS_HOST_DEVICE -+ cutlass::gemm::GemmCoord problem_size0() const { -+ GemmCoord problem = problem_sizes0[this->problem_idx]; -+ ProblemSizeHelper::possibly_transpose_problem(problem); -+ return problem; -+ } -+ -+ /// Returns the problem size 1 for the current problem -+ CUTLASS_HOST_DEVICE -+ cutlass::gemm::GemmCoord problem_size1() const { -+ GemmCoord problem = problem_sizes1[this->problem_idx]; -+ ProblemSizeHelper::possibly_transpose_problem(problem); -+ return problem; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/fused_multihead_attention_fixed_seqlen.cu b/3rdparty/cutlass/examples/41_fused_multi_head_attention/fused_multihead_attention_fixed_seqlen.cu -new file mode 100644 -index 0000000..53af4ac ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/fused_multihead_attention_fixed_seqlen.cu -@@ -0,0 +1,1087 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holdvr nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief CUTLASS Attention Example. -+ -+ This workload computes a fused multi head attention. -+ Because it keeps the attention matrix in shared memory, it's both faster and -+ uses less global memory. -+ -+ This is based on `"Self-Attention Does Not Need O(n^2) Memory" `_, -+ and very similar to `"FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" `_. -+ -+ Algorithm: -+ In short, we can compute the output incrementally in blocks of size B, -+ we just need to divide the final result by the sum of all coefficients in -+ the softmax (which we compute incrementally) with the following pseudo-code: -+ -+ ``` -+ s_prime = torch.zeros([num_queries, B]) -+ O = torch.zeros([num_queries, head_size_v]) -+ for i in range(0, K.shape[0], B): -+ si = exp((Q . K[i * B:(i+1) * B].t) * scale) -+ sum_coefs += attn_unscaled.sum(-1) -+ O += si . V[i * B:(i+1) * B] -+ O = O / s_prime -+ ``` -+ -+ In practice, and for numerical stability reasons, -+ we also substract the maximum so far (`mi`) before doing -+ the exponential. When we encounter new keys, the maximum -+ used to compute O so far (`m_prime`) can differ from the -+ current maximum, so we update O before accumulating with -+ -+ ``` -+ O = O * exp(m_prime - mi) -+ m_prime = mi -+ ``` -+ -+ Implementation details: -+ - `si` is stored in shared memory between the 2 back to back gemms -+ - we keep and accumulate the output -+ directly in registers if we can (`head_size_v <= 128`). -+ Otherwise, we store it & accumulate in global memory (slower) -+ - blocks are parallelized across the batch dimension, the number -+ of heads, and the query sequence size -+ -+ -+ Examples: -+ -+ # Run an attention example with default setup -+ $ ./examples/41_fused_multi_head_attention/41_fused_multi_head_attention_fixed_seqlen -+ -+ # Run an attention example with custom setup -+ $ ./examples/41_fused_multi_head_attention/41_fused_multi_head_attention_fixed_seqlen --head_number=2 --batch_size=3 --head_size=32 --head_size_v=64 --seq_length=512 --seq_length_kv=1024 --causal=true -+ -+ Acknowledgement: Fixed-sequence-length FMHA code was upstreamed by Meta xFormers (https://github.com/facebookresearch/xformers). -+*/ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/gemm_grouped.h" -+#include "cutlass/gemm/kernel/default_gemm_grouped.h" -+#include "cutlass/gemm/device/gemm_grouped.h" -+#include "cutlass/gemm/device/gemm_universal.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/device_memory.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm_complex.h" -+#include "cutlass/util/reference/device/gemm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/device/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/gemm/kernel/gemm_grouped.h" -+#include "cutlass/gemm/kernel/gemm_transpose_operands.h" -+#include "cutlass/gemm/kernel/default_gemm.h" -+#include "cutlass/gemm/kernel/default_gemm_complex.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h" -+#include "cutlass/fast_math.h" -+#include "kernel_forward.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Result structure -+struct Result { -+ -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cudaError_t error; -+ bool passed; -+ -+ // -+ // Methods -+ // -+ -+ Result( -+ double runtime_ms = 0, -+ double gflops = 0, -+ cutlass::Status status = cutlass::Status::kSuccess, -+ cudaError_t error = cudaSuccess -+ ): -+ runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ bool error; -+ bool reference_check; -+ bool use_mask; -+ bool causal; -+ -+ std::vector problem_sizes0; -+ std::vector problem_sizes1; -+ -+ std::vector problem_sizes0_real; -+ std::vector problem_sizes1_real; -+ -+ int alignment; -+ int head_number; -+ int batch_size; -+ int head_size; -+ int head_size_v; -+ int seq_length; -+ int seq_length_kv; -+ int iterations; -+ -+ // alpha0, alpha1 and beta are fixed -+ // in this multi-head attention example -+ float alpha0; -+ float alpha1; -+ float beta; -+ -+ // -+ // Methods -+ // -+ -+ Options(): -+ help(false), -+ error(false), -+ alignment(1), -+ reference_check(true), -+ head_number(12), -+ batch_size(16), -+ head_size(64), -+ head_size_v(64), -+ seq_length(1024), -+ seq_length_kv(1024), -+ use_mask(false), -+ iterations(20), -+ causal(false) -+ { } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ return; -+ } -+ -+ cmd.get_cmd_line_argument("alignment", alignment, 1); -+ cmd.get_cmd_line_argument("head_number", head_number, 12); -+ cmd.get_cmd_line_argument("batch_size", batch_size, 16); -+ cmd.get_cmd_line_argument("head_size", head_size, 64); -+ cmd.get_cmd_line_argument("head_size_v", head_size_v, head_size); -+ cmd.get_cmd_line_argument("seq_length", seq_length, 1024); -+ cmd.get_cmd_line_argument("seq_length_kv", seq_length_kv, seq_length); -+ cmd.get_cmd_line_argument("use_mask", use_mask, false); -+ cmd.get_cmd_line_argument("iterations", iterations, 20); -+ cmd.get_cmd_line_argument("reference-check", reference_check, true); -+ cmd.get_cmd_line_argument("causal", causal, true); -+ -+ randomize_problems(); -+ -+ } -+ -+ void randomize_problems() { -+ -+ int problem_count = head_number * batch_size; -+ -+ problem_sizes0.reserve(problem_count); -+ problem_sizes1.reserve(problem_count); -+ -+ // When using mask, the original inputs are not padded -+ // and we need to save these info. -+ if (use_mask) { -+ problem_sizes0_real.reserve(problem_count); -+ problem_sizes1_real.reserve(problem_count); -+ } -+ -+ for (int i = 0; i < batch_size; ++i) { -+ // problems belonging to the same batch share the same seq len -+ int m_real = seq_length; -+ int mkv_real = seq_length_kv; -+ int m = (m_real + alignment - 1) / alignment * alignment; -+ int mkv = (mkv_real + alignment - 1) / alignment * alignment; -+ int k0 = head_size; -+ int k1 = head_size_v; -+ -+ for (int j = 0; j < head_number; ++j) { -+ cutlass::gemm::GemmCoord problem0(m, mkv, k0); -+ cutlass::gemm::GemmCoord problem1(m, k1, mkv); -+ problem_sizes0.push_back(problem0); -+ problem_sizes1.push_back(problem1); -+ -+ if (use_mask) { -+ cutlass::gemm::GemmCoord problem0_real(m_real, mkv_real, k0); -+ cutlass::gemm::GemmCoord problem1_real(m_real, k1, mkv_real); -+ problem_sizes0_real.push_back(problem0_real); -+ problem_sizes1_real.push_back(problem1_real); -+ } -+ } -+ } -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "41_fused_multi_head_attention_fixed_seqlen\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --head_number= Head number in multi-head attention (default: --head_number=12)\n" -+ << " --batch_size= Batch size in multi-head attention (default: --batch_size=16)\n" -+ << " --head_size= Head size in multi-head attention (default: --head_size=64)\n" -+ << " --head_size_v= Head size in multi-head attention for V (default: --head_size_v=head_size)\n" -+ << " --seq_length= Sequence length in multi-head attention for Q (default: --seq_length=1024)\n" -+ << " --seq_length_kv= Sequence length in multi-head attention for K/V (default: --seq_length_kv=seq_length)\n" -+ << " --use_mask= If true, performs padding-like masking in softmax.\n" -+ << " --iterations= Number of profiling iterations to perform.\n" -+ << " --reference-check= If true, performs reference check.\n" -+ << " --causal= If true, uses causal masking.\n"; -+ -+ return out; -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of real-valued multiply-adds -+ int64_t fops = int64_t(); -+ -+ for (int i = 0; i < problem_sizes0.size(); ++i) { -+ auto const& problem0 = problem_sizes0[i]; -+ auto const& problem1 = problem_sizes1[i]; -+ for (int row = 0; row < problem0.m(); ++row) { -+ int num_cols0 = problem0.n(); -+ if (causal) { -+ num_cols0 = std::min(row + 1, num_cols0); -+ } -+ // P <- Q . K_t -+ fops += 2 * num_cols0 * problem0.k(); -+ // P <- exp(P - max(P)) -+ fops += 2 * num_cols0; -+ // S <- sum(P) -+ fops += num_cols0 - 1; -+ // O <- P . V -+ fops += 2 * num_cols0 * problem1.n(); -+ // O <- O / S -+ fops += num_cols0 * problem1.n(); -+ } -+ } -+ -+ return double(fops) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class TestbedAttention { -+public: -+ -+ // -+ // Type definitions -+ // -+ -+ using ElementQ = typename Attention::scalar_t; -+ using ElementK = typename Attention::scalar_t; -+ using ElementP = typename Attention::accum_t; -+ using ElementAccumulator = typename Attention::accum_t; -+ using ElementV = typename Attention::scalar_t; -+ using ElementO = typename Attention::output_t; -+ -+ using ElementCompute = typename Attention::accum_t; -+ -+ using ElementNorm = typename Attention::accum_t; -+ using ElementSum = typename Attention::accum_t; -+ using ElementSoftmaxCompute = typename Attention::accum_t; -+ -+ using LayoutQ = cutlass::layout::RowMajor; -+ using LayoutK = cutlass::layout::ColumnMajor; -+ using LayoutP = cutlass::layout::RowMajor; -+ using LayoutV = cutlass::layout::RowMajor; -+ using LayoutO = cutlass::layout::RowMajor; -+ -+ using MatrixCoord = typename LayoutP::TensorCoord; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ Options & options; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_Q; -+ cutlass::Distribution::Kind init_K; -+ cutlass::Distribution::Kind init_P; -+ cutlass::Distribution::Kind init_V; -+ cutlass::Distribution::Kind init_O; -+ uint32_t seed; -+ -+ cutlass::DeviceAllocation problem_sizes_device0; -+ cutlass::DeviceAllocation problem_sizes_device1; -+ cutlass::DeviceAllocation problem_sizes_device0_real; -+ -+ std::vector offset_Q; -+ std::vector offset_K; -+ std::vector offset_P; -+ std::vector offset_V; -+ std::vector offset_O; -+ -+ std::vector ldq_host; -+ std::vector ldk_host; -+ std::vector ldp_host; -+ std::vector ldv_host; -+ std::vector ldo_host; -+ std::vector seqlen_host; -+ -+ cutlass::DeviceAllocation ldq; -+ cutlass::DeviceAllocation ldk; -+ cutlass::DeviceAllocation ldp; -+ cutlass::DeviceAllocation ldv; -+ cutlass::DeviceAllocation ldo; -+ cutlass::DeviceAllocation seqlen; -+ -+ cutlass::DeviceAllocation block_Q; -+ cutlass::DeviceAllocation block_K; -+ cutlass::DeviceAllocation block_P; -+ cutlass::DeviceAllocation block_V; -+ cutlass::DeviceAllocation block_O; -+ cutlass::DeviceAllocation block_Norm; -+ cutlass::DeviceAllocation block_Sum; -+ -+ cutlass::DeviceAllocation offset_P_Device; -+ -+ cutlass::DeviceAllocation ptr_Q; -+ cutlass::DeviceAllocation ptr_K; -+ cutlass::DeviceAllocation ptr_P; -+ cutlass::DeviceAllocation ptr_V; -+ cutlass::DeviceAllocation ptr_O; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ TestbedAttention( -+ Options &options_, -+ cutlass::Distribution::Kind init_Q_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_K_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_P_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_V_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_O_ = cutlass::Distribution::Uniform, -+ uint32_t seed_ = 3080 -+ ): -+ options(options_), init_Q(init_Q_), init_K(init_K_), init_P(init_P_), init_V(init_V_), init_O(init_O_), seed(seed_) { } -+ -+ int problem_count() const { -+ return (options.head_number * options.batch_size); -+ } -+ -+private: -+ -+ /// Helper to initialize a tensor view -+ template -+ void initialize_tensor_( -+ Element *ptr, -+ size_t capacity, -+ cutlass::Distribution::Kind dist_kind, -+ uint32_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ Element scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ scope_max = 8; -+ scope_min = -8; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::device::BlockFillRandomUniform( -+ ptr, capacity, seed, scope_max, scope_min, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::device::BlockFillRandomGaussian( -+ ptr, capacity, seed, Element(), Element(0.5f)); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ // Fill with increasing elements -+ cutlass::reference::device::BlockFillSequential( -+ ptr, capacity, Element(1), Element()); -+ } -+ else { -+ -+ // Fill with all 1s -+ cutlass::reference::device::BlockFillSequential( -+ ptr, capacity, Element(), Element(1)); -+ } -+ } -+ -+ /// Initializes data structures -+ void initialize_() { -+ -+ // -+ // Set scalors for the mha example -+ // -+ -+ options.alpha0 = 1.0f / sqrt(float(options.head_size)); -+ options.alpha1 = 1.0f; -+ options.beta = 0; -+ -+ // -+ // Choose random problem sizes -+ // -+ -+ // construct a few problems of random sizes -+ srand(seed); -+ -+ int64_t total_elements_Q = 0; -+ int64_t total_elements_K = 0; -+ int64_t total_elements_P = 0; -+ int64_t total_elements_V = 0; -+ int64_t total_elements_O = 0; -+ -+ ldq_host.resize(problem_count()); -+ ldk_host.resize(problem_count()); -+ ldp_host.resize(problem_count()); -+ ldv_host.resize(problem_count()); -+ ldo_host.resize(problem_count()); -+ seqlen_host.resize(problem_count()); -+ -+ for (int32_t i = 0; i < problem_count(); ++i) { -+ -+ auto problem0 = options.problem_sizes0.at(i); -+ auto problem1 = options.problem_sizes1.at(i); -+ -+ ldq_host.at(i) = LayoutQ::packed({problem0.m(), problem0.k()}).stride(0); -+ ldk_host.at(i) = LayoutK::packed({problem0.k(), problem0.n()}).stride(0); -+ ldp_host.at(i) = LayoutP::packed({problem0.m(), problem0.n()}).stride(0); -+ ldv_host.at(i) = LayoutV::packed({problem1.k(), problem1.n()}).stride(0); -+ ldo_host.at(i) = LayoutO::packed({problem1.m(), problem1.n()}).stride(0); -+ -+ // m = n for attention problems. -+ seqlen_host.at(i) = problem0.m(); -+ -+ offset_Q.push_back(total_elements_Q); -+ offset_K.push_back(total_elements_K); -+ offset_P.push_back(total_elements_P); -+ offset_V.push_back(total_elements_V); -+ offset_O.push_back(total_elements_O); -+ -+ int64_t elements_Q = problem0.m() * problem0.k(); -+ int64_t elements_K = problem0.k() * problem0.n(); -+ int64_t elements_P = problem0.m() * problem0.n(); -+ int64_t elements_V = problem1.k() * problem1.n(); -+ int64_t elements_O = problem1.m() * problem1.n(); -+ -+ total_elements_Q += elements_Q; -+ total_elements_K += elements_K; -+ total_elements_P += elements_P; -+ total_elements_V += elements_V; -+ total_elements_O += elements_O; -+ } -+ -+ problem_sizes_device0.reset(problem_count()); -+ problem_sizes_device1.reset(problem_count()); -+ problem_sizes_device0.copy_from_host(options.problem_sizes0.data()); -+ problem_sizes_device1.copy_from_host(options.problem_sizes1.data()); -+ -+ if (options.use_mask) { -+ problem_sizes_device0_real.reset(problem_count()); -+ problem_sizes_device0_real.copy_from_host(options.problem_sizes0_real.data()); -+ } -+ -+ ldq.reset(problem_count()); -+ ldk.reset(problem_count()); -+ ldp.reset(problem_count()); -+ ldv.reset(problem_count()); -+ ldo.reset(problem_count()); -+ seqlen.reset(problem_count()); -+ -+ ldq.copy_from_host(ldq_host.data()); -+ ldk.copy_from_host(ldk_host.data()); -+ ldp.copy_from_host(ldp_host.data()); -+ ldv.copy_from_host(ldv_host.data()); -+ ldo.copy_from_host(ldo_host.data()); -+ seqlen.copy_from_host(seqlen_host.data()); -+ -+ // -+ // Assign pointers -+ // -+ -+ block_Q.reset(total_elements_Q); -+ block_K.reset(total_elements_K); -+ block_P.reset(total_elements_P); -+ block_V.reset(total_elements_V); -+ block_O.reset(total_elements_O); -+ -+ offset_P_Device.reset(problem_count()); -+ -+ // sync offset with device -+ cutlass::device_memory::copy_to_device(offset_P_Device.get(), offset_P.data(), offset_P.size()); -+ -+ std::vector ptr_Q_host(problem_count()); -+ std::vector ptr_K_host(problem_count()); -+ std::vector ptr_P_host(problem_count()); -+ std::vector ptr_V_host(problem_count()); -+ std::vector ptr_O_host(problem_count()); -+ std::vector ptr_norm_host(problem_count()); -+ std::vector ptr_sum_host(problem_count()); -+ -+ for (int32_t i = 0; i < problem_count(); ++i) { -+ ptr_Q_host.at(i) = block_Q.get() + offset_Q.at(i); -+ ptr_K_host.at(i) = block_K.get() + offset_K.at(i); -+ ptr_P_host.at(i) = block_P.get() + offset_P.at(i); -+ ptr_V_host.at(i) = block_V.get() + offset_V.at(i); -+ ptr_O_host.at(i) = block_O.get() + offset_O.at(i); -+ } -+ -+ ptr_Q.reset(problem_count()); -+ ptr_Q.copy_from_host(ptr_Q_host.data()); -+ -+ ptr_K.reset(problem_count()); -+ ptr_K.copy_from_host(ptr_K_host.data()); -+ -+ ptr_P.reset(problem_count()); -+ ptr_P.copy_from_host(ptr_P_host.data()); -+ -+ ptr_V.reset(problem_count()); -+ ptr_V.copy_from_host(ptr_V_host.data()); -+ -+ ptr_O.reset(problem_count()); -+ ptr_O.copy_from_host(ptr_O_host.data()); -+ -+ // -+ // Initialize the problems of the workspace -+ // -+ -+ initialize_tensor_(block_Q.get(), total_elements_Q, init_Q, seed + 1); -+ initialize_tensor_(block_K.get(), total_elements_K, init_K, seed + 2); -+ initialize_tensor_(block_V.get(), total_elements_V, init_V, seed + 3); -+ -+ } -+ -+ template -+ bool verify_tensor_(std::vector vector_Input, \ -+ std::vector vector_Input_Ref, -+ int64_t verify_length = -1) { -+ -+ int64_t size = (vector_Input.size() < vector_Input_Ref.size()) ? vector_Input.size() : vector_Input_Ref.size(); -+ size = (verify_length == -1) ? size : verify_length; -+ -+ // 0.05 for absolute error -+ float abs_tol = 5e-2f; -+ // 10% for relative error -+ float rel_tol = 1e-1f; -+ for (int64_t i = 0; i < size; ++i) { -+ float diff = (float)(vector_Input.at(i) - vector_Input_Ref.at(i)); -+ float abs_diff = fabs(diff); -+ float abs_ref = fabs((float)vector_Input_Ref.at(i) + 1e-5f); -+ float relative_diff = abs_diff / abs_ref; -+ if ( (isnan(vector_Input_Ref.at(i)) || isnan(abs_diff) || isinf(abs_diff)) || (abs_diff > abs_tol && relative_diff > rel_tol)) { -+ printf("[%d/%d] diff = %f, rel_diff = %f, {computed=%f, ref=%f}.\n", int(i), int(size), abs_diff, relative_diff, (float)(vector_Input.at(i)), (float)(vector_Input_Ref.at(i))); -+ return false; -+ } -+ -+ } -+ -+ return true; -+ } -+ -+ /// Verifies the result is a GEMM -+ bool verify_() { -+ -+ bool passed = true; -+ -+ for (int32_t i = 0; i < problem_count(); ++i) { -+ cutlass::gemm::GemmCoord problem0 = options.problem_sizes0.at(i); -+ cutlass::gemm::GemmCoord problem1 = options.problem_sizes1.at(i); -+ -+ LayoutQ layout_Q(ldq_host.at(i)); -+ LayoutK layout_K(ldk_host.at(i)); -+ LayoutP layout_P(ldp_host.at(i)); -+ LayoutV layout_V(ldv_host.at(i)); -+ LayoutO layout_O(ldo_host.at(i)); -+ -+ MatrixCoord extent_Q{problem0.m(), problem0.k()}; -+ MatrixCoord extent_K{problem0.k(), problem0.n()}; -+ MatrixCoord extent_P{problem0.m(), problem0.n()}; -+ MatrixCoord extent_V{problem1.k(), problem1.n()}; -+ MatrixCoord extent_O{problem1.m(), problem1.n()}; -+ -+ cutlass::TensorView view_Q(block_Q.get() + offset_Q.at(i), layout_Q, extent_Q); -+ cutlass::TensorView view_K(block_K.get() + offset_K.at(i), layout_K, extent_K); -+ cutlass::TensorView view_P(block_P.get() + offset_P.at(i), layout_P, extent_P); -+ cutlass::TensorView view_V(block_V.get() + offset_V.at(i), layout_V, extent_V); -+ -+ cutlass::DeviceAllocation block_Ref(layout_P.capacity(extent_P)); -+ cutlass::TensorView view_Ref_device(block_Ref.get(), layout_P, extent_P); -+ -+ cutlass::DeviceAllocation block_Ref_O(layout_O.capacity(extent_O)); -+ cutlass::TensorView view_Ref_O_device(block_Ref_O.get(), layout_O, extent_O); -+ -+ // Reference GEMM -+ cutlass::reference::device::GemmComplex< -+ ElementQ, LayoutQ, -+ ElementK, LayoutK, -+ ElementP, LayoutP, -+ ElementCompute, ElementAccumulator -+ >( -+ problem0, -+ ElementAccumulator(options.alpha0), -+ view_Q, -+ Attention::MM0::Mma::kTransformA, -+ view_K, -+ Attention::MM0::Mma::kTransformB, -+ ElementAccumulator(options.beta), -+ view_P, -+ view_Ref_device, -+ ElementAccumulator(0) -+ ); -+ -+ // Compute softmax for P. We need to explicitly compute softmax -+ // over P because softmax is fused to the second GEMM in the -+ // profiled implementation. -+ std::vector matrix_Ref(layout_P.capacity(extent_P)); -+ cutlass::device_memory::copy_to_host(matrix_Ref.data(), block_Ref.get(), matrix_Ref.size()); -+ cutlass::TensorView view_Ref_host(matrix_Ref.data(), layout_P, extent_P); -+ std::vector vector_Norm_Ref(problem0.m()); -+ std::vector vector_Sum_Ref(problem0.m()); -+ -+ int n_dim = options.use_mask ? options.problem_sizes0_real.at(i).n() : problem0.n(); -+ -+ // Compute softmax for referece matrix -+ for (int m = 0; m < problem0.m(); m++) { -+ int n_dim_row = n_dim; -+ if (options.causal) { -+ n_dim_row = std::min(m + 1, n_dim); -+ } -+ ElementSoftmaxCompute max = ElementSoftmaxCompute(view_Ref_host.ref().at({m, 0})); -+ for (int n = 1; n < n_dim_row; n++) { -+ max = std::max(max, ElementSoftmaxCompute(view_Ref_host.ref().at({m, n}))); -+ } -+ -+ vector_Norm_Ref.at(m) = ElementNorm(max); -+ -+ ElementSoftmaxCompute sum = ElementSoftmaxCompute(); -+ for (int n = 0; n < n_dim_row; n++) { -+ sum += std::exp( ElementSoftmaxCompute(view_Ref_host.ref().at({m, n})) - max ); -+ } -+ ElementSoftmaxCompute inv_sum = ElementSoftmaxCompute(1.0f / sum); -+ -+ vector_Sum_Ref.at(m) = ElementSum(inv_sum); -+ -+ for (int n = 0; n < n_dim_row; n++) { -+ view_Ref_host.ref().at({m, n}) = ElementP( -+ std::exp( ElementSoftmaxCompute(view_Ref_host.ref().at({m, n})) - max ) * inv_sum -+ ); -+ } -+ // Mask out the rest of the attention matrix -+ for (int n = n_dim_row; n < n_dim; ++n) { -+ view_Ref_host.ref().at({m, n}) = ElementP(0); -+ } -+ } -+ -+ // when not using mask, problem_real and problem share the same sizes -+ if (options.use_mask) { -+ for (int m = 0; m < problem0.m(); m++) { -+ for (int n = n_dim; n < problem0.n(); n++) { -+ view_Ref_host.ref().at({m, n}) = ElementP(0); -+ } -+ } -+ } -+ -+ cutlass::device_memory::copy_to_device(block_P.get() + offset_P.at(i), matrix_Ref.data(), matrix_Ref.size()); -+ -+ // Reference GEMM -+ cutlass::reference::device::GemmComplex< -+ ElementP, LayoutP, -+ ElementV, LayoutV, -+ ElementO, LayoutO, -+ ElementCompute, ElementAccumulator -+ >( -+ problem1, -+ ElementAccumulator(options.alpha1), -+ view_P, -+ Attention::MM0::Mma::kTransformA, -+ view_V, -+ Attention::MM0::Mma::kTransformB, -+ ElementAccumulator(options.beta), -+ view_Ref_O_device, -+ view_Ref_O_device, -+ ElementAccumulator(0) -+ ); -+ -+ // Copy to host memory -+ cutlass::TensorView view_Ref(matrix_Ref.data(), layout_P, extent_P); -+ -+ std::vector matrix_O(layout_O.capacity(extent_O)); -+ cutlass::device_memory::copy_to_host(matrix_O.data(), block_O.get() + offset_O.at(i), matrix_O.size()); -+ std::vector matrix_Ref_O(layout_O.capacity(extent_O)); -+ cutlass::device_memory::copy_to_host(matrix_Ref_O.data(), block_Ref_O.get(), matrix_Ref_O.size()); -+ -+ // printf("Pb %d: \n Q=(offset=%d, ldq=%d)\n K=(offset=%d, ldk=%d)\n O=(offset=%d, ldo=%d)\n", -+ // int(i), int(offset_Q[i]), int(ldq_host[i]), int(offset_K[i]), int(ldk_host[i]), int(offset_O[i]), int(ldo_host[i])); -+ -+ bool verified_O = false; -+ -+ if (!verified_O) { -+ verified_O = verify_tensor_(matrix_O, matrix_Ref_O); -+ } -+ -+ passed = passed && verified_O; -+ -+ if (!passed) { -+ std::cerr << "\n***\nError - problem " << i << " failed the QA check\n***\n" << std::endl; -+ -+ if (!verified_O) { -+ std::cout << "Final matrix output is incorrect" << std::endl; -+ } -+ -+ return passed; -+ } -+ } -+ -+ return passed; -+ } -+ -+public: -+ -+ -+ /// Executes a CUTLASS Attention kernel and measures runtime. -+ Result profile() { -+ -+ Result result; -+ result.passed = false; -+ -+ // Initialize the problem -+ initialize_(); -+ -+ typename Attention::Params p; -+ { // set parameters -+ p.query_ptr = block_Q.get(); -+ p.key_ptr = block_K.get(); -+ p.value_ptr = block_V.get(); -+ p.logsumexp_ptr = nullptr; // Only needed for bw -+ p.output_accum_ptr = nullptr; -+ if (Attention::kNeedsOutputAccumulatorBuffer) { -+ cudaMalloc(&p.output_accum_ptr, block_O.size() * sizeof(typename Attention::output_accum_t)); -+ } -+ p.output_ptr = block_O.get(); -+ -+ // TODO: support arbitrary seq lengths -+ // if (cu_seqlens_q.has_value()) { -+ // p.cu_seqlens_q_ptr = (int32_t*)cu_seqlens_q->data_ptr(); -+ // p.cu_seqlens_k_ptr = (int32_t*)cu_seqlens_k->data_ptr(); -+ // } -+ -+ p.num_heads = options.head_number; -+ p.num_batches = options.batch_size; -+ p.head_dim = options.head_size; -+ p.head_dim_value = options.head_size_v; -+ p.num_queries = options.seq_length; -+ p.num_keys = options.seq_length_kv; -+ p.causal = options.causal; -+ -+ // TODO: This might overflow for big tensors -+ p.q_strideM = int32_t(ldq_host[0]); -+ p.k_strideM = int32_t(ldk_host[0]); -+ p.v_strideM = int32_t(ldv_host[0]); -+ p.q_strideH = p.q_strideM * options.seq_length; -+ p.k_strideH = p.k_strideM * options.seq_length_kv; -+ p.v_strideH = p.v_strideM * options.seq_length_kv; -+ p.o_strideH = options.head_size_v * options.seq_length; -+ p.q_strideB = p.q_strideH * options.head_number; -+ p.k_strideB = p.k_strideH * options.head_number; -+ p.v_strideB = p.v_strideH * options.head_number; -+ p.o_strideB = options.head_size_v * options.seq_length * options.head_number; -+ } -+ -+ // launch kernel :) -+ constexpr auto kernel_fn = attention_kernel_batched_impl; -+ int smem_bytes = sizeof(typename Attention::SharedStorage); -+ if (smem_bytes > 0xc000) { -+ cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); -+ } -+ if (!Attention::check_supported(p)) { -+ std::cerr << "Kernel does not support these inputs" << std::endl; -+ return result; -+ } -+ kernel_fn<<>>(p); -+ -+ // Wait for completion -+ result.error = cudaDeviceSynchronize(); -+ -+ if (result.error != cudaSuccess) { -+ std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); -+ return result; -+ } -+ -+ // -+ // Verify correctness -+ // -+ result.passed = true; -+ -+ if (options.reference_check) { -+ result.passed = verify_(); -+ } -+ -+ // -+ // Warm-up run -+ // -+ -+ kernel_fn<<>>(p); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to run CUTLASS Attention kernel." << std::endl; -+ return result; -+ } -+ -+ // -+ // Construct events -+ // -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ } -+ -+ // Record an event at the start of a series of GEMM operations -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // -+ // Run profiling loop -+ // -+ -+ for (int iter = 0; iter < options.iterations; ++iter) { -+ kernel_fn<<>>(p); -+ } -+ -+ // -+ // Stop profiling loop -+ // -+ -+ // Record an event when the GEMM operations have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Compute average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // -+ // Cleanup -+ // -+ -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ std::cout << std::endl; -+ std::cout << "CUTLASS Attention:\n" -+ << "====================================================" << std::endl; -+ std::cout << " " << " {seq length Q, seq length KV, head size, head size V, head number, batch size} = {" << options.seq_length \ -+ << ", " << options.seq_length_kv << ", " << options.head_size << ", " << options.head_size_v << ", " << options.head_number\ -+ << ", " << options.batch_size << "}." << std::endl; -+ std::cout << std::endl; -+ std::cout << " " << "Runtime: " << result.runtime_ms << " ms" << std::endl; -+ std::cout << " " << "GFLOPs: " << result.gflops << std::endl; -+ -+ return result; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ int kQueriesPerBlock, -+ int kKeysPerBlock, -+ bool kSingleValueIteration -+> -+int run_attention(Options& options) { -+ using Attention = AttentionKernel< -+ cutlass::half_t, // scalar_t -+ cutlass::arch::Sm80, // ArchTag -+ true, // Memory is aligned -+ kQueriesPerBlock, -+ kKeysPerBlock, -+ kSingleValueIteration -+ >; -+ -+ // -+ // Test and profile -+ // -+ -+ TestbedAttention testbed(options); -+ -+ Result result = testbed.profile(); -+ if (!result.passed) { -+ std::cout << "Profiling CUTLASS attention has failed.\n"; -+ std::cout << "\nFailed\n"; -+ return -1; -+ } -+ -+ std::cout << "\nPassed\n"; -+ return 0; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ -+ // -+ // This example uses mma.sync to directly access Tensor Cores to achieve peak performance. -+ // -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return -1; -+ } -+ -+ if (__CUDACC_VER_MAJOR__ < 11 || props.major < 8) { -+ -+ // -+ // This example requires an NVIDIA Ampere-architecture GPU. -+ // -+ -+ std::cout -+ << "CUTLASS's CUTLASS Attention example requires a GPU of NVIDIA's Ampere Architecture or " -+ << "later (compute capability 80 or greater).\n"; -+ -+ return 0; -+ } -+ -+ // -+ // Parse options -+ // -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ if (options.error) { -+ std::cerr << "Aborting execution." << std::endl; -+ return -1; -+ } -+ -+ if (options.use_mask) { -+ std::cerr << "--use_mask is not supported at the moment\n"; -+ return -2; -+ } -+ if (options.alignment != 1) { -+ std::cerr << "--alignment=1 is the only supported value\n"; -+ return -2; -+ } -+ -+ // Determine kernel configuration based on head size. -+ // If head size is less than or equal to 64, each block operates over 64 queries and -+ // 64 keys, and parital results can be stored in the register file. -+ // If head size is greater than 64, each block operates over 32 queries and 128 keys, -+ // and partial results are stored in shared memory. -+ if (options.head_size_v > 64) { -+ static int const kQueriesPerBlock = 32; -+ static int const kKeysPerBlock = 128; -+ if (options.head_size_v <= kKeysPerBlock) { -+ return run_attention(options); -+ } else { -+ return run_attention(options); -+ } -+ } else { -+ static int const kQueriesPerBlock = 64; -+ static int const kKeysPerBlock = 64; -+ return run_attention(options); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/fused_multihead_attention_variable_seqlen.cu b/3rdparty/cutlass/examples/41_fused_multi_head_attention/fused_multihead_attention_variable_seqlen.cu -new file mode 100644 -index 0000000..35b5c32 ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/fused_multihead_attention_variable_seqlen.cu -@@ -0,0 +1,1193 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holdvr nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief CUTLASS Attention Example. -+ -+ This workload computes a fused multi head attention that supports variable sequence lengths. -+ Because it keeps the attention matrix in shared memory, it's both faster and -+ uses less global memory. -+ -+ This is based on `"Self-Attention Does Not Need O(n^2) Memory" `_, -+ and very similar to `"FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" `_. -+ -+ Algorithm: -+ In short, we can compute the output incrementally in blocks of size B, -+ we just need to divide the final result by the sum of all coefficients in -+ the softmax (which we compute incrementally) with the following pseudo-code: -+ -+ ``` -+ s_prime = torch.zeros([num_queries, B]) -+ O = torch.zeros([num_queries, head_size_v]) -+ for i in range(0, K.shape[0], B): -+ si = exp((Q . K[i * B:(i+1) * B].t) * scale) -+ sum_coefs += attn_unscaled.sum(-1) -+ O += si . V[i * B:(i+1) * B] -+ O = O / s_prime -+ ``` -+ -+ In practice, and for numerical stability reasons, -+ we also substract the maximum so far (`mi`) before doing -+ the exponential. When we encounter new keys, the maximum -+ used to compute O so far (`m_prime`) can differ from the -+ current maximum, so we update O before accumulating with -+ -+ ``` -+ O = O * exp(m_prime - mi) -+ m_prime = mi -+ ``` -+ -+ Implementation details: -+ - `si` is stored in shared memory between the 2 back to back gemms -+ - we keep and accumulate the output -+ directly in registers if we can (`head_size_v <= 128`). -+ Otherwise, we store it & accumulate in global memory (slower) -+ - blocks are parallelized across the batch dimension, the number -+ of heads, and the query sequence size -+ -+ -+ Examples: -+ -+ # Run an attention example with default setup -+ $ ./examples/41_fused_multi_head_attention/41_fused_multi_head_attention_variable_seqlen -+ -+ # Run an attention example with custom setup -+ $ ./examples/41_fused_multi_head_attention/41_fused_multi_head_attention_variable_seqlen --head_number=2 --batch_size=3 --head_size=32 --head_size_v=64 --seq_length=512 --seq_length_kv=1024 --causal=true -+ -+ Acknowledgement: Fixed-sequence-length FMHA code was upstreamed by Meta xFormers (https://github.com/facebookresearch/xformers). -+ Using grouped GEMM to handle variable sequence lengths is inspired by an idea originally prototyped by ByteDance Inc. -+*/ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/device/gemm_grouped.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/device_memory.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm_complex.h" -+#include "cutlass/util/reference/device/gemm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/device/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/gemm/kernel/default_gemm.h" -+#include "cutlass/gemm/kernel/default_gemm_complex.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/fast_math.h" -+ -+#include "default_fmha_grouped.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Result structure -+struct Result { -+ -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cudaError_t error; -+ bool passed; -+ -+ // -+ // Methods -+ // -+ -+ Result( -+ double runtime_ms = 0, -+ double gflops = 0, -+ cutlass::Status status = cutlass::Status::kSuccess, -+ cudaError_t error = cudaSuccess -+ ): -+ runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ bool error; -+ bool reference_check; -+ bool use_mask; -+ bool causal; -+ bool fixed_seq_length; -+ -+ std::vector problem_sizes0; -+ std::vector problem_sizes1; -+ -+ std::vector problem_sizes0_real; -+ std::vector problem_sizes1_real; -+ -+ int alignment; -+ int head_number; -+ int batch_size; -+ int head_size; -+ int head_size_v; -+ int seq_length; -+ int seq_length_kv; -+ int iterations; -+ int problem_count; -+ -+ // alpha0, alpha1 and beta are fixed -+ // in this multi-head attention example -+ float alpha0; -+ float alpha1; -+ float beta; -+ -+ cutlass::gemm::kernel::GroupScheduleMode scheduler_mode; -+ -+ // -+ // Methods -+ // -+ -+ Options(): -+ help(false), -+ error(false), -+ alignment(1), -+ reference_check(true), -+ head_number(12), -+ batch_size(16), -+ head_size(64), -+ head_size_v(64), -+ seq_length(1024), -+ seq_length_kv(1024), -+ use_mask(false), -+ iterations(20), -+ causal(false), -+ fixed_seq_length(false), -+ problem_count(batch_size * head_number), -+ scheduler_mode(cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly) -+ { } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ return; -+ } -+ -+ cmd.get_cmd_line_argument("alignment", alignment, 1); -+ cmd.get_cmd_line_argument("head_number", head_number, 12); -+ cmd.get_cmd_line_argument("batch_size", batch_size, 16); -+ cmd.get_cmd_line_argument("head_size", head_size, 64); -+ cmd.get_cmd_line_argument("head_size_v", head_size_v, head_size); -+ cmd.get_cmd_line_argument("seq_length", seq_length, 1024); -+ cmd.get_cmd_line_argument("seq_length_kv", seq_length_kv, seq_length); -+ cmd.get_cmd_line_argument("use_mask", use_mask, false); -+ cmd.get_cmd_line_argument("iterations", iterations, 20); -+ cmd.get_cmd_line_argument("reference-check", reference_check, true); -+ cmd.get_cmd_line_argument("causal", causal, true); -+ cmd.get_cmd_line_argument("fixed_seq_length", fixed_seq_length, false); -+ -+ std::vector scheduler_mode_strs; -+ cmd.get_cmd_line_arguments("scheduler-mode", scheduler_mode_strs); -+ -+ if (!scheduler_mode_strs.empty()) { -+ if (scheduler_mode_strs.size() > 1) { -+ std::cerr << "Only one scheduler mode may be passed in" << std::endl; -+ error = true; -+ return; -+ } -+ std::string scheduler_mode_str = scheduler_mode_strs[0]; -+ if (scheduler_mode_str == "kDeviceOnly") { -+ scheduler_mode = cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly; -+ } else if (scheduler_mode_str == "kHostPrecompute") { -+ scheduler_mode = cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute; -+ } else { -+ std::cerr << "Unrecognized scheduler mode '" << scheduler_mode_str << "'" << std::endl; -+ error = true; -+ return; -+ } -+ } -+ -+ if (fixed_seq_length) { -+ std::cout << "NOTE: Better performance is expected for fixed-sized sequence length from 41_fused_multi_head_attention_fixed_seqlen." << std::endl; -+ } -+ -+ randomize_problems(); -+ } -+ -+ void randomize_problems() { -+ -+ problem_count = head_number * batch_size; -+ -+ problem_sizes0.reserve(problem_count); -+ problem_sizes1.reserve(problem_count); -+ -+ // When using mask, the original inputs are not padded -+ // and we need to save these info. -+ if (use_mask) { -+ problem_sizes0_real.reserve(problem_count); -+ problem_sizes1_real.reserve(problem_count); -+ } -+ -+ for (int i = 0; i < batch_size; ++i) { -+ // problems belonging to the same batch share the same seq len -+ -+ int m_real, mkv_real; -+ if (fixed_seq_length) { -+ m_real = seq_length; -+ mkv_real = seq_length_kv; -+ } else { -+ m_real = (rand() % seq_length) + 1; -+ -+ // Only randomize seq_length_kv if it was set to a different value than -+ // seq_length originally. -+ if (seq_length != seq_length_kv) { -+ mkv_real = (rand() % seq_length_kv) + 1; -+ } else { -+ mkv_real = m_real; -+ } -+ } -+ -+ int m = (m_real + alignment - 1) / alignment * alignment; -+ int mkv = (mkv_real + alignment - 1) / alignment * alignment; -+ int k0 = head_size; -+ int k1 = head_size_v; -+ -+ for (int j = 0; j < head_number; ++j) { -+ cutlass::gemm::GemmCoord problem0(m, mkv, k0); -+ cutlass::gemm::GemmCoord problem1(m, k1, mkv); -+ -+ problem_sizes0.push_back(problem0); -+ problem_sizes1.push_back(problem1); -+ -+ if (use_mask) { -+ cutlass::gemm::GemmCoord problem0_real(m_real, mkv_real, k0); -+ cutlass::gemm::GemmCoord problem1_real(m_real, k1, mkv_real); -+ problem_sizes0_real.push_back(problem0_real); -+ problem_sizes1_real.push_back(problem1_real); -+ } -+ -+ } -+ } -+ } -+ -+ void print_problems() { -+ std::cout << " Running " << batch_size << " batches, each with " << head_number << " heads of size " << head_size << ":" << std::endl; -+ for (int i = 0; i < batch_size; ++i) { -+ int idx = i * head_number; -+ std::cout << " [" << i << "] seq_length = " << problem_sizes0[idx].m() << " seq_length_kv = " << problem_sizes0[idx].n() << std::endl; -+ } -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "41_fused_multi_head_attention_variable_seqlen\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --head_number= Head number in multi-head attention (default: --head_number=12)\n" -+ << " --batch_size= Batch size in multi-head attention (default: --batch_size=16)\n" -+ << " --head_size= Head size in multi-head attention (default: --head_size=64)\n" -+ << " --head_size_v= Head size in multi-head attention for V (default: --head_size_v=head_size)\n" -+ << " --seq_length= Sequence length in multi-head attention for Q (default: --seq_length=1024)\n" -+ << " --seq_length_kv= Sequence length in multi-head attention for K/V (default: --seq_length_kv=seq_length)\n" -+ << " --use_mask= If true, performs padding-like masking in softmax.\n" -+ << " --iterations= Number of profiling iterations to perform.\n" -+ << " --reference-check= If true, performs reference check.\n" -+ << " --causal= If true, uses causal masking.\n" -+ << " --fixed_seq_length= If true, uses the same sequence length for each item in the batch.\n"; -+ -+ return out; -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of real-valued multiply-adds -+ int64_t fops = int64_t(); -+ -+ for (int i = 0; i < problem_sizes0.size(); ++i) { -+ auto const& problem0 = problem_sizes0[i]; -+ auto const& problem1 = problem_sizes1[i]; -+ -+ for (int row = 0; row < problem0.m(); ++row) { -+ int num_cols0 = problem0.n(); -+ if (causal) { -+ num_cols0 = std::min(row + 1, num_cols0); -+ } -+ // P <- Q . K_t -+ fops += 2 * num_cols0 * problem0.k(); -+ // P <- exp(P - max(P)) -+ fops += 2 * num_cols0; -+ // S <- sum(P) -+ fops += num_cols0 - 1; -+ // O <- P . V -+ fops += 2 * num_cols0 * problem1.n(); -+ // O <- O / S -+ fops += num_cols0 * problem1.n(); -+ } -+ } -+ -+ return double(fops) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class TestbedAttention { -+public: -+ -+ // -+ // Type definitions -+ // -+ -+ using scalar_t = typename Attention::GemmKernel::scalar_t; -+ using accum_t = typename Attention::GemmKernel::accum_t; -+ using output_t = typename Attention::GemmKernel::output_t; -+ using output_accum_t = typename Attention::GemmKernel::output_accum_t; -+ -+ using ElementQ = scalar_t; -+ using ElementK = scalar_t; -+ using ElementP = accum_t; -+ using ElementAccumulator = accum_t; -+ using ElementV = scalar_t; -+ using ElementO = output_t; -+ using ElementOAccum = output_accum_t; -+ -+ using ElementCompute = accum_t; -+ -+ using ElementNorm = accum_t; -+ using ElementSum = accum_t; -+ using ElementSoftmaxCompute = accum_t; -+ -+ using LayoutQ = cutlass::layout::RowMajor; -+ using LayoutK = cutlass::layout::ColumnMajor; -+ using LayoutP = cutlass::layout::RowMajor; -+ using LayoutV = cutlass::layout::RowMajor; -+ using LayoutO = cutlass::layout::RowMajor; -+ -+ using MatrixCoord = typename LayoutP::TensorCoord; -+ -+ static bool const kNeedsOutputAccumulatorBuffer = Attention::GemmKernel::kNeedsOutputAccumulatorBuffer; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ Options & options; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_Q; -+ cutlass::Distribution::Kind init_K; -+ cutlass::Distribution::Kind init_P; -+ cutlass::Distribution::Kind init_V; -+ cutlass::Distribution::Kind init_O; -+ uint32_t seed; -+ -+ cutlass::DeviceAllocation problem_sizes_device0; -+ cutlass::DeviceAllocation problem_sizes_device1; -+ cutlass::DeviceAllocation problem_sizes_device0_real; -+ -+ std::vector offset_Q; -+ std::vector offset_K; -+ std::vector offset_P; -+ std::vector offset_V; -+ std::vector offset_O; -+ -+ std::vector ldq_host; -+ std::vector ldk_host; -+ std::vector ldp_host; -+ std::vector ldv_host; -+ std::vector ldo_host; -+ std::vector seqlen_host; -+ -+ cutlass::DeviceAllocation ldq; -+ cutlass::DeviceAllocation ldk; -+ cutlass::DeviceAllocation ldp; -+ cutlass::DeviceAllocation ldv; -+ cutlass::DeviceAllocation ldo; -+ cutlass::DeviceAllocation seqlen; -+ -+ cutlass::DeviceAllocation block_Q; -+ cutlass::DeviceAllocation block_K; -+ cutlass::DeviceAllocation block_P; -+ cutlass::DeviceAllocation block_V; -+ cutlass::DeviceAllocation block_O; -+ cutlass::DeviceAllocation block_O_accumulate; -+ cutlass::DeviceAllocation block_Norm; -+ cutlass::DeviceAllocation block_Sum; -+ -+ cutlass::DeviceAllocation offset_P_Device; -+ -+ cutlass::DeviceAllocation ptr_Q; -+ cutlass::DeviceAllocation ptr_K; -+ cutlass::DeviceAllocation ptr_P; -+ cutlass::DeviceAllocation ptr_V; -+ cutlass::DeviceAllocation ptr_O; -+ cutlass::DeviceAllocation ptr_O_accumulate; -+ -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ TestbedAttention( -+ Options &options_, -+ cutlass::Distribution::Kind init_Q_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_K_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_P_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_V_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_O_ = cutlass::Distribution::Uniform, -+ uint32_t seed_ = 3080 -+ ): -+ options(options_), init_Q(init_Q_), init_K(init_K_), init_P(init_P_), init_V(init_V_), init_O(init_O_), seed(seed_) { } -+ -+ int problem_count() const { -+ return (options.head_number * options.batch_size); -+ } -+ -+private: -+ -+ /// Helper to initialize a tensor view -+ template -+ void initialize_tensor_( -+ Element *ptr, -+ size_t capacity, -+ cutlass::Distribution::Kind dist_kind, -+ uint32_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ Element scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ scope_max = 8; -+ scope_min = -8; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::device::BlockFillRandomUniform( -+ ptr, capacity, seed, scope_max, scope_min, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::device::BlockFillRandomGaussian( -+ ptr, capacity, seed, Element(), Element(0.5f)); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ // Fill with increasing elements -+ cutlass::reference::device::BlockFillSequential( -+ ptr, capacity, Element(1), Element()); -+ } -+ else { -+ -+ // Fill with all 1s -+ cutlass::reference::device::BlockFillSequential( -+ ptr, capacity, Element(), Element(1)); -+ } -+ } -+ -+ /// Initializes data structures -+ void initialize_() { -+ -+ // -+ // Set scalors for the mha example -+ // -+ -+ options.alpha0 = 1.0f / sqrt(float(options.head_size)); -+ options.alpha1 = 1.0f; -+ options.beta = 0; -+ -+ // -+ // Choose random problem sizes -+ // -+ -+ // construct a few problems of random sizes -+ srand(seed); -+ -+ int64_t total_elements_Q = 0; -+ int64_t total_elements_K = 0; -+ int64_t total_elements_P = 0; -+ int64_t total_elements_V = 0; -+ int64_t total_elements_O = 0; -+ -+ ldq_host.resize(problem_count()); -+ ldk_host.resize(problem_count()); -+ ldp_host.resize(problem_count()); -+ ldv_host.resize(problem_count()); -+ ldo_host.resize(problem_count()); -+ seqlen_host.resize(problem_count()); -+ -+ for (int32_t i = 0; i < problem_count(); ++i) { -+ -+ auto problem0 = options.problem_sizes0.at(i); -+ auto problem1 = options.problem_sizes1.at(i); -+ -+ ldq_host.at(i) = LayoutQ::packed({problem0.m(), problem0.k()}).stride(0); -+ ldk_host.at(i) = LayoutK::packed({problem0.k(), problem0.n()}).stride(0); -+ ldp_host.at(i) = LayoutP::packed({problem0.m(), problem0.n()}).stride(0); -+ ldv_host.at(i) = LayoutV::packed({problem1.k(), problem1.n()}).stride(0); -+ ldo_host.at(i) = LayoutO::packed({problem1.m(), problem1.n()}).stride(0); -+ -+ // m = n for attention problems. -+ seqlen_host.at(i) = problem0.m(); -+ -+ offset_Q.push_back(total_elements_Q); -+ offset_K.push_back(total_elements_K); -+ offset_P.push_back(total_elements_P); -+ offset_V.push_back(total_elements_V); -+ offset_O.push_back(total_elements_O); -+ -+ int64_t elements_Q = problem0.m() * problem0.k(); -+ int64_t elements_K = problem0.k() * problem0.n(); -+ int64_t elements_P = problem0.m() * problem0.n(); -+ int64_t elements_V = problem1.k() * problem1.n(); -+ int64_t elements_O = problem1.m() * problem1.n(); -+ -+ total_elements_Q += elements_Q; -+ total_elements_K += elements_K; -+ total_elements_P += elements_P; -+ total_elements_V += elements_V; -+ total_elements_O += elements_O; -+ -+ } -+ -+ problem_sizes_device0.reset(problem_count()); -+ problem_sizes_device1.reset(problem_count()); -+ problem_sizes_device0.copy_from_host(options.problem_sizes0.data()); -+ problem_sizes_device1.copy_from_host(options.problem_sizes1.data()); -+ -+ if (options.use_mask) { -+ problem_sizes_device0_real.reset(problem_count()); -+ problem_sizes_device0_real.copy_from_host(options.problem_sizes0_real.data()); -+ } -+ -+ ldq.reset(problem_count()); -+ ldk.reset(problem_count()); -+ ldp.reset(problem_count()); -+ ldv.reset(problem_count()); -+ ldo.reset(problem_count()); -+ seqlen.reset(problem_count()); -+ -+ ldq.copy_from_host(ldq_host.data()); -+ ldk.copy_from_host(ldk_host.data()); -+ ldp.copy_from_host(ldp_host.data()); -+ ldv.copy_from_host(ldv_host.data()); -+ ldo.copy_from_host(ldo_host.data()); -+ seqlen.copy_from_host(seqlen_host.data()); -+ -+ // -+ // Assign pointers -+ // -+ -+ block_Q.reset(total_elements_Q); -+ block_K.reset(total_elements_K); -+ block_P.reset(total_elements_P); -+ block_V.reset(total_elements_V); -+ block_O.reset(total_elements_O); -+ -+ if (kNeedsOutputAccumulatorBuffer) { -+ block_O_accumulate.reset(total_elements_O); -+ } -+ -+ offset_P_Device.reset(problem_count()); -+ -+ // sync offset with device -+ cutlass::device_memory::copy_to_device(offset_P_Device.get(), offset_P.data(), offset_P.size()); -+ -+ std::vector ptr_Q_host(problem_count()); -+ std::vector ptr_K_host(problem_count()); -+ std::vector ptr_P_host(problem_count()); -+ std::vector ptr_V_host(problem_count()); -+ std::vector ptr_O_host(problem_count()); -+ std::vector ptr_O_accumulate_host(problem_count()); -+ std::vector ptr_norm_host(problem_count()); -+ std::vector ptr_sum_host(problem_count()); -+ -+ for (int32_t i = 0; i < problem_count(); ++i) { -+ ptr_Q_host.at(i) = block_Q.get() + offset_Q.at(i); -+ ptr_K_host.at(i) = block_K.get() + offset_K.at(i); -+ ptr_P_host.at(i) = block_P.get() + offset_P.at(i); -+ ptr_V_host.at(i) = block_V.get() + offset_V.at(i); -+ ptr_O_host.at(i) = block_O.get() + offset_O.at(i); -+ -+ if (kNeedsOutputAccumulatorBuffer) { -+ ptr_O_accumulate_host.at(i) = block_O_accumulate.get() + offset_O.at(i); -+ } -+ } -+ -+ ptr_Q.reset(problem_count()); -+ ptr_Q.copy_from_host(ptr_Q_host.data()); -+ -+ ptr_K.reset(problem_count()); -+ ptr_K.copy_from_host(ptr_K_host.data()); -+ -+ ptr_P.reset(problem_count()); -+ ptr_P.copy_from_host(ptr_P_host.data()); -+ -+ ptr_V.reset(problem_count()); -+ ptr_V.copy_from_host(ptr_V_host.data()); -+ -+ ptr_O.reset(problem_count()); -+ ptr_O.copy_from_host(ptr_O_host.data()); -+ -+ if (kNeedsOutputAccumulatorBuffer) { -+ ptr_O_accumulate.reset(problem_count()); -+ ptr_O_accumulate.copy_from_host(ptr_O_accumulate_host.data()); -+ } -+ -+ // -+ // Initialize the problems of the workspace -+ // -+ -+ initialize_tensor_(block_Q.get(), total_elements_Q, init_Q, seed + 1); -+ initialize_tensor_(block_K.get(), total_elements_K, init_K, seed + 2); -+ initialize_tensor_(block_V.get(), total_elements_V, init_V, seed + 3); -+ -+ } -+ -+ template -+ bool verify_tensor_(std::vector vector_Input, \ -+ std::vector vector_Input_Ref, -+ int64_t verify_length = -1) { -+ -+ int64_t size = (vector_Input.size() < vector_Input_Ref.size()) ? vector_Input.size() : vector_Input_Ref.size(); -+ size = (verify_length == -1) ? size : verify_length; -+ -+ // 0.05 for absolute error -+ float abs_tol = 5e-2f; -+ // 10% for relative error -+ float rel_tol = 1e-1f; -+ for (int64_t i = 0; i < size; ++i) { -+ float diff = (float)(vector_Input.at(i) - vector_Input_Ref.at(i)); -+ float abs_diff = fabs(diff); -+ float abs_ref = fabs((float)vector_Input_Ref.at(i) + 1e-5f); -+ float relative_diff = abs_diff / abs_ref; -+ if ( (isnan(abs_diff) || isinf(abs_diff)) || (abs_diff > abs_tol && relative_diff > rel_tol)) { -+ printf("[%d/%d] diff = %f, rel_diff = %f, {computed=%f, ref=%f}.\n", int(i), int(size), abs_diff, relative_diff, (float)(vector_Input.at(i)), (float)(vector_Input_Ref.at(i))); -+ return false; -+ } -+ -+ } -+ -+ return true; -+ } -+ -+ /// Verifies the result is a GEMM -+ bool verify_() { -+ -+ bool passed = true; -+ -+ for (int32_t i = 0; i < problem_count(); ++i) { -+ cutlass::gemm::GemmCoord problem0 = options.problem_sizes0.at(i); -+ cutlass::gemm::GemmCoord problem1 = options.problem_sizes1.at(i); -+ -+ LayoutQ layout_Q(ldq_host.at(i)); -+ LayoutK layout_K(ldk_host.at(i)); -+ LayoutP layout_P(ldp_host.at(i)); -+ LayoutV layout_V(ldv_host.at(i)); -+ LayoutO layout_O(ldo_host.at(i)); -+ -+ MatrixCoord extent_Q{problem0.m(), problem0.k()}; -+ MatrixCoord extent_K{problem0.k(), problem0.n()}; -+ MatrixCoord extent_P{problem0.m(), problem0.n()}; -+ MatrixCoord extent_V{problem1.k(), problem1.n()}; -+ MatrixCoord extent_O{problem1.m(), problem1.n()}; -+ -+ cutlass::TensorView view_Q(block_Q.get() + offset_Q.at(i), layout_Q, extent_Q); -+ cutlass::TensorView view_K(block_K.get() + offset_K.at(i), layout_K, extent_K); -+ cutlass::TensorView view_P(block_P.get() + offset_P.at(i), layout_P, extent_P); -+ cutlass::TensorView view_V(block_V.get() + offset_V.at(i), layout_V, extent_V); -+ -+ cutlass::DeviceAllocation block_Ref(layout_P.capacity(extent_P)); -+ cutlass::TensorView view_Ref_device(block_Ref.get(), layout_P, extent_P); -+ -+ cutlass::DeviceAllocation block_Ref_O(layout_O.capacity(extent_O)); -+ cutlass::TensorView view_Ref_O_device(block_Ref_O.get(), layout_O, extent_O); -+ cutlass::reference::device::TensorFill(view_Ref_O_device, ElementO(0)); -+ -+ // Reference GEMM -+ cutlass::reference::device::GemmComplex< -+ ElementQ, LayoutQ, -+ ElementK, LayoutK, -+ ElementP, LayoutP, -+ ElementCompute, ElementAccumulator -+ >( -+ problem0, -+ ElementAccumulator(options.alpha0), -+ view_Q, -+ Attention::GemmKernel::MM0::Mma::kTransformA, -+ view_K, -+ Attention::GemmKernel::MM0::Mma::kTransformB, -+ ElementAccumulator(options.beta), -+ view_P, -+ view_Ref_device, -+ ElementAccumulator(0) -+ ); -+ -+ // Compute softmax for P. We need to explicitly compute softmax -+ // over P because softmax is fused to the second GEMM in the -+ // profiled implementation. -+ std::vector matrix_Ref(layout_P.capacity(extent_P)); -+ cutlass::device_memory::copy_to_host(matrix_Ref.data(), block_Ref.get(), matrix_Ref.size()); -+ cutlass::TensorView view_Ref_host(matrix_Ref.data(), layout_P, extent_P); -+ std::vector vector_Norm_Ref(problem0.m()); -+ std::vector vector_Sum_Ref(problem0.m()); -+ -+ int n_dim = options.use_mask ? options.problem_sizes0_real.at(i).n() : problem0.n(); -+ -+ // Compute softmax for reference matrix -+ for (int m = 0; m < problem0.m(); m++) { -+ int n_dim_row = n_dim; -+ if (options.causal) { -+ n_dim_row = std::min(m + 1, n_dim); -+ } -+ ElementSoftmaxCompute max = ElementSoftmaxCompute(view_Ref_host.ref().at({m, 0})); -+ for (int n = 1; n < n_dim_row; n++) { -+ max = std::max(max, ElementSoftmaxCompute(view_Ref_host.ref().at({m, n}))); -+ } -+ -+ vector_Norm_Ref.at(m) = ElementNorm(max); -+ -+ ElementSoftmaxCompute sum = ElementSoftmaxCompute(); -+ for (int n = 0; n < n_dim_row; n++) { -+ sum += std::exp( ElementSoftmaxCompute(view_Ref_host.ref().at({m, n})) - max ); -+ } -+ ElementSoftmaxCompute inv_sum = ElementSoftmaxCompute(1.0f / sum); -+ -+ vector_Sum_Ref.at(m) = ElementSum(inv_sum); -+ -+ for (int n = 0; n < n_dim_row; n++) { -+ view_Ref_host.ref().at({m, n}) = ElementP( -+ std::exp( ElementSoftmaxCompute(view_Ref_host.ref().at({m, n})) - max ) * inv_sum -+ ); -+ } -+ // Mask out the rest of the attention matrix -+ for (int n = n_dim_row; n < n_dim; ++n) { -+ view_Ref_host.ref().at({m, n}) = ElementP(0); -+ } -+ -+ } -+ -+ // when not using mask, problem_real and problem share the same sizes -+ if (options.use_mask) { -+ for (int m = 0; m < problem0.m(); m++) { -+ for (int n = n_dim; n < problem0.n(); n++) { -+ view_Ref_host.ref().at({m, n}) = ElementP(0); -+ } -+ } -+ } -+ -+ cutlass::device_memory::copy_to_device(block_P.get() + offset_P.at(i), matrix_Ref.data(), matrix_Ref.size()); -+ -+ // Reference GEMM -+ cutlass::reference::device::GemmComplex< -+ ElementP, LayoutP, -+ ElementV, LayoutV, -+ ElementO, LayoutO, -+ ElementCompute, ElementAccumulator -+ >( -+ problem1, -+ ElementAccumulator(options.alpha1), -+ view_P, -+ Attention::GemmKernel::MM0::Mma::kTransformA, -+ view_V, -+ Attention::GemmKernel::MM0::Mma::kTransformB, -+ ElementAccumulator(options.beta), -+ view_Ref_O_device, -+ view_Ref_O_device, -+ ElementAccumulator(0) -+ ); -+ -+ // Copy to host memory -+ cutlass::TensorView view_Ref(matrix_Ref.data(), layout_P, extent_P); -+ -+ std::vector matrix_O(layout_O.capacity(extent_O)); -+ cutlass::device_memory::copy_to_host(matrix_O.data(), block_O.get() + offset_O.at(i), matrix_O.size()); -+ std::vector matrix_Ref_O(layout_O.capacity(extent_O)); -+ cutlass::device_memory::copy_to_host(matrix_Ref_O.data(), block_Ref_O.get(), matrix_Ref_O.size()); -+ -+ -+ bool verified_O = false; -+ if (!verified_O) { -+ verified_O = verify_tensor_(matrix_O, matrix_Ref_O); -+ } -+ -+ passed = passed && verified_O; -+ -+ if (!passed) { -+ std::cerr << "\n***\nError - problem " << i << " failed the QA check\n***\n" << std::endl; -+ -+ if (!verified_O) { -+ std::cout << "Final matrix output is incorrect" << std::endl; -+ } -+ -+ return passed; -+ } -+ -+ } -+ -+ return passed; -+ } -+ -+public: -+ -+ -+ /// Executes a CUTLASS Attention kernel and measures runtime. -+ Result profile() { -+ -+ Result result; -+ result.passed = false; -+ -+ int threadblock_count = Attention::sufficient(options.problem_sizes1.data(), options.problem_count); -+ -+ // Early exit -+ if (!threadblock_count) { -+ std::cout << "Active CUDA device lacks hardware resources to run CUTLASS Grouped FMHA kernel." << std::endl; -+ return result; -+ } -+ -+ result.passed = false; -+ -+ // Initialize the problem -+ initialize_(); -+ -+ typename Attention::Arguments args( -+ problem_sizes_device0.get(), -+ problem_sizes_device1.get(), -+ options.problem_count, -+ threadblock_count, -+ ptr_Q.get(), -+ ptr_K.get(), -+ ptr_P.get(), -+ ptr_V.get(), -+ ptr_O.get(), -+ ptr_O_accumulate.get(), -+ ldq.get(), -+ ldk.get(), -+ ldp.get(), -+ ldv.get(), -+ ldo.get(), -+ options.causal, -+ options.problem_sizes1.data() -+ ); -+ -+ Attention fmha; -+ -+ size_t workspace_size = fmha.get_workspace_size(args); -+ cutlass::DeviceAllocation workspace(workspace_size); -+ -+ result.status = fmha.initialize(args, workspace.get()); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to initialize CUTLASS Grouped FMHA kernel." << std::endl; -+ return result; -+ } -+ -+ // Run the grouped FMHA object -+ result.status = fmha.run(); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to run CUTLASS Grouped FMHA kernel." << std::endl; -+ return result; -+ } -+ -+ // Wait for completion -+ result.error = cudaDeviceSynchronize(); -+ -+ if (result.error != cudaSuccess) { -+ std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // -+ // Verify correctness -+ // -+ result.passed = true; -+ -+ if (options.reference_check) { -+ result.passed = verify_(); -+ } -+ -+ // -+ // Warm-up run of the grouped FMHA object -+ // -+ result.status = fmha.run(); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to run CUTLASS Grouped FMHA kernel." << std::endl; -+ return result; -+ } -+ -+ // -+ // Construct events -+ // -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ } -+ -+ // Record an event at the start of a series of FMHA operations -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // -+ // Run profiling loop -+ // -+ -+ for (int iter = 0; iter < this->options.iterations; ++iter) { -+ fmha(); -+ } -+ -+ // -+ // Stop profiling loop -+ // -+ -+ // Record an event when the GEMM operations have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Compute average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(this->options.iterations); -+ result.gflops = this->options.gflops(result.runtime_ms / 1000.0); -+ -+ // -+ // Cleanup -+ // -+ -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ std::cout << std::endl; -+ std::cout << "CUTLASS Attention:\n" -+ << "====================================================" << std::endl; -+ std::cout << " " << " {seq length Q, seq length KV, head size, head size V, head number, batch size} = {" << options.seq_length \ -+ << ", " << options.seq_length_kv << ", " << options.head_size << ", " << options.head_size_v << ", " << options.head_number\ -+ << ", " << options.batch_size << "}." << std::endl; -+ options.print_problems(); -+ std::cout << std::endl; -+ std::cout << " " << "Runtime: " << result.runtime_ms << " ms" << std::endl; -+ std::cout << " " << "GFLOPs: " << result.gflops << std::endl; -+ -+ return result; -+ } -+ -+ -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ int kQueriesPerBlock, -+ int kKeysPerBlock, -+ bool kSingleValueIteration, -+ cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode_ -+> -+int run_grouped(Options& options) { -+ using AttentionKernel = typename cutlass::gemm::kernel::DefaultFMHAGrouped< -+ cutlass::half_t, // scalar_t -+ cutlass::arch::Sm80, // ArchTag -+ true, // Memory is aligned -+ kQueriesPerBlock, -+ kKeysPerBlock, -+ kSingleValueIteration, -+ GroupScheduleMode_ -+ >::FMHAKernel; -+ -+ using FMHA = cutlass::gemm::device::GemmGrouped; -+ -+ // -+ // Test and profile -+ // -+ -+ TestbedAttention testbed(options); -+ -+ Result result = testbed.profile(); -+ if (!result.passed) { -+ std::cout << "Profiling CUTLASS attention has failed.\n"; -+ std::cout << "\nFailed\n"; -+ return -1; -+ } -+ -+ std::cout << "\nPassed\n"; -+ return 0; -+} -+ -+ -+template < -+ int kQueriesPerBlock, -+ int kKeysPerBlock, -+ bool kSingleValueIteration -+> -+int run_attention(Options& options) { -+ if (options.scheduler_mode == cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly) { -+ return run_grouped(options); -+ } else { -+ return run_grouped(options); -+ } -+} -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ -+ // -+ // This example uses mma.sync to directly access Tensor Cores to achieve peak performance. -+ // -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return -1; -+ } -+ -+ if (__CUDACC_VER_MAJOR__ < 11 || props.major < 8) { -+ -+ // -+ // This example requires an NVIDIA Ampere-architecture GPU. -+ // -+ -+ std::cout -+ << "CUTLASS's CUTLASS Attention example requires a GPU of NVIDIA's Ampere Architecture or " -+ << "later (compute capability 80 or greater).\n"; -+ -+ return 0; -+ } -+ -+ // -+ // Parse options -+ // -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ if (options.error) { -+ std::cerr << "Aborting execution." << std::endl; -+ return -1; -+ } -+ -+ if (options.use_mask) { -+ std::cerr << "--use_mask is not supported at the moment\n"; -+ return -2; -+ } -+ if (options.alignment != 1) { -+ std::cerr << "--alignment=1 is the only supported value\n"; -+ return -2; -+ } -+ -+ // Determine kernel configuration based on head size. -+ // If head size is less than or equal to 64, each block operates over 64 queries and -+ // 64 keys, and parital results can be stored in the register file. -+ // If head size is greater than 64, each block operates over 32 queries and 128 keys, -+ // and partial results are stored in shared memory. -+ if (options.head_size_v > 64) { -+ static int const kQueriesPerBlock = 32; -+ static int const kKeysPerBlock = 128; -+ if (options.head_size_v <= kKeysPerBlock) { -+ return run_attention(options); -+ } else { -+ return run_attention(options); -+ } -+ } else { -+ static int const kQueriesPerBlock = 64; -+ static int const kKeysPerBlock = 64; -+ return run_attention(options); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma.h -new file mode 100644 -index 0000000..7326bad ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma.h -@@ -0,0 +1,124 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holdvr nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include "custom_mma_multistage.h" -+#include "custom_mma_pipelined.h" -+#include "cutlass/gemm/threadblock/mma_multistage.h" -+#include "cutlass/gemm/threadblock/mma_pipelined.h" -+ -+template -+struct MakeCustomMma; -+ -+template < -+ typename Shape, -+ typename IteratorA, -+ typename SmemIteratorA, -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ typename IteratorB, -+ typename SmemIteratorB, -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ typename ElementC, -+ typename LayoutC, -+ typename Policy, -+ int Stages, -+ cutlass::gemm::SharedMemoryClearOption SharedMemoryClear, -+ int kMaxK> -+struct MakeCustomMma< -+ cutlass::gemm::threadblock::MmaMultistage< -+ Shape, -+ IteratorA, -+ SmemIteratorA, -+ CacheOpA, -+ IteratorB, -+ SmemIteratorB, -+ CacheOpB, -+ ElementC, -+ LayoutC, -+ Policy, -+ Stages, -+ SharedMemoryClear>, -+ kMaxK> { -+ // Reduce the number of stages if we don't need that many -+ static int constexpr kStages = -+ kMaxK == cutlass::platform::numeric_limits::max() -+ ? Stages -+ : cutlass::const_min( -+ Stages, -+ (kMaxK + int(Shape::kK) - 1) / int(Shape::kK)); -+ using Mma = cutlass::gemm::threadblock::CustomMmaMultistage< -+ Shape, -+ IteratorA, -+ SmemIteratorA, -+ CacheOpA, -+ IteratorB, -+ SmemIteratorB, -+ CacheOpB, -+ ElementC, -+ LayoutC, -+ Policy, -+ kStages, -+ SharedMemoryClear, -+ kMaxK>; -+}; -+ -+template < -+ typename Shape, -+ typename IteratorA, -+ typename SmemIteratorA, -+ typename IteratorB, -+ typename SmemIteratorB, -+ typename ElementC, -+ typename LayoutC, -+ typename Policy, -+ int kMaxK> -+struct MakeCustomMma< -+ cutlass::gemm::threadblock::MmaPipelined< -+ Shape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ Policy>, -+ kMaxK> { -+ using Mma = cutlass::gemm::threadblock::CustomMmaPipelined< -+ Shape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ Policy>; -+}; -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma_base.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma_base.h -new file mode 100644 -index 0000000..6c6d078 ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma_base.h -@@ -0,0 +1,183 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights -+ *reserved. SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, -+ *this list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -+ *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE -+ *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -+ *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -+ *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -+ *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -+ *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -+ *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -+ *POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/threadblock/mma_base.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class CustomMmaBase { -+ public: -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Shape describing the overall GEMM computed from shared memory -+ /// by each warp. -+ using WarpGemm = typename Policy::Operator::Shape; -+ -+ /// Shape describing the number of warps filling the CTA -+ using WarpCount = GemmShape< -+ Shape::kM / WarpGemm::kM, -+ Shape::kN / WarpGemm::kN, -+ Shape::kK / WarpGemm::kK>; -+ -+ /// Number of warp-level GEMM oeprations -+ static int const kWarpGemmIterations = -+ (WarpGemm::kK / Operator::Policy::MmaShape::kK); -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ // -+ // Nested structs -+ // -+ -+ /// Shared storage object needed by threadblock-scoped GEMM -+ template -+ struct OperandSharedStorage { -+ AlignedBuffer buffer; -+ using TensorRef = TensorRef; -+ -+ CUTLASS_DEVICE -+ static OperandLayout Layout() { -+ return OperandLayout::packed({OperandShape::kRow, OperandShape::kColumn}); -+ } -+ -+ /// Returns a TensorRef to the operand -+ CUTLASS_HOST_DEVICE -+ TensorRef ref() { -+ return TensorRef{buffer.data(), Layout()}; -+ } -+ }; -+ -+ /// Shape of the A matrix operand in shared memory -+ using ShapeA = MatrixShape< -+ Shape::kM + Policy::SmemPaddingA::kRow, -+ Shape::kK * kStages + Policy::SmemPaddingA::kColumn>; -+ -+ /// Shape of the B matrix operand in shared memory -+ using ShapeB = MatrixShape< -+ Shape::kK * kStages + Policy::SmemPaddingB::kRow, -+ Shape::kN + Policy::SmemPaddingB::kColumn>; -+ -+ using SharedStorageA = OperandSharedStorage< -+ typename Operator::ElementA, -+ ShapeA, -+ typename Operator::LayoutA>; -+ using SharedStorageB = OperandSharedStorage< -+ typename Operator::ElementB, -+ ShapeB, -+ typename Operator::LayoutB>; -+ using TensorRefA = typename SharedStorageA::TensorRef; -+ using TensorRefB = typename SharedStorageB::TensorRef; -+ -+ struct SharedStorage { -+ /// Buffer for A operand -+ SharedStorageA operand_A; -+ -+ /// Buffer for B operand -+ SharedStorageB operand_B; -+ }; -+ -+ protected: -+ // -+ // Data members -+ // -+ -+ /// Iterator to load a warp-scoped tile of A operand from shared memory -+ typename Operator::IteratorA warp_tile_iterator_A_; -+ -+ /// Iterator to load a warp-scoped tile of B operand from shared memory -+ typename Operator::IteratorB warp_tile_iterator_B_; -+ -+ public: -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ CustomMmaBase( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ SharedStorageA& shared_storageA, -+ SharedStorageB& shared_storageB, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx) -+ : warp_tile_iterator_A_(shared_storageA.ref(), lane_idx), -+ warp_tile_iterator_B_(shared_storageB.ref(), lane_idx) {} -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma_multistage.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma_multistage.h -new file mode 100644 -index 0000000..e5cdc88 ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma_multistage.h -@@ -0,0 +1,767 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights -+ *reserved. SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, -+ *this list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -+ *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE -+ *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -+ *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -+ *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -+ *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -+ *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -+ *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -+ *POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/cache_operation.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+ -+#include "custom_mma_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Cache operation for operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, -+ /// Upper boundon the K dimension -+ int kMaxK = cutlass::platform::numeric_limits::max(), -+ /// Used for partial specialization -+ typename Enable = bool> -+class CustomMmaMultistage : public CustomMmaBase { -+ public: -+ ///< Base class -+ using Base = CustomMmaBase; -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ ///< Iterates over tiles of A operand in global memory -+ using IteratorA = IteratorA_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB = IteratorB_; -+ ///< Data type of accumulator matrix -+ using ElementC = ElementC_; -+ ///< Layout of accumulator matrix -+ using LayoutC = LayoutC_; -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of accumulator tile -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Minimum architecture is Sm80 to support cp.async -+ using ArchTag = arch::Sm80; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = Operator::kTransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = Operator::kTransformB; -+ -+ /// Internal structure exposed for introspection. -+ struct Detail { -+ static_assert( -+ Base::kWarpGemmIterations > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ -+ /// Number of cp.async instructions to load one stage of operand A -+ static int const AsyncCopyIterationsPerStageA = -+ IteratorA::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const AsyncCopyIterationsPerStageB = -+ IteratorB::ThreadMap::Iterations::kCount; -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Number of cp.async instructions to load on group of operand A -+ static int const kAccessesPerGroupA = -+ (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / -+ Base::kWarpGemmIterations; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB = -+ (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / -+ Base::kWarpGemmIterations; -+ }; -+ -+ static bool const kSmemContainsEntireMat = kMaxK <= Shape::kK * Stages; -+ static constexpr int kNumStagesConcurrentLoad = -+ kSmemContainsEntireMat ? Stages : Stages - 1; -+ -+ private: -+ using WarpLoadedFragmentA = typename Operator::FragmentA; -+ using WarpLoadedFragmentB = typename Operator::FragmentB; -+ using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; -+ using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+ bool prologue_done_; -+ -+ // Set to `True` to ensure the accumulator will be zero outside the GEMM -+ // footprint -+ bool zero_outside_bounds_; -+ -+ public: -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ CustomMmaMultistage( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::SharedStorageA& shared_storageA, -+ typename Base::SharedStorageB& shared_storageB, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx) -+ : Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storageA.ref(), thread_idx), -+ smem_iterator_B_(shared_storageB.ref(), thread_idx), -+ prologue_done_(false), -+ zero_outside_bounds_(false) { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ } -+ CUTLASS_DEVICE -+ CustomMmaMultistage( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::SharedStorage& st, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx) -+ : CustomMmaMultistage( -+ st.operand_A, -+ st.operand_B, -+ thread_idx, -+ warp_idx, -+ lane_idx) {} -+ -+ CUTLASS_DEVICE -+ bool set_prologue_done(bool value) { -+ prologue_done_ = value; -+ } -+ -+ CUTLASS_DEVICE -+ bool set_zero_outside_bounds(bool value) { -+ zero_outside_bounds_ = value; -+ } -+ -+ template -+ CUTLASS_DEVICE static void prologue( -+ typename Base::SharedStorage& shared_storage, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B, -+ int thread_idx, -+ int problem_size_k) { -+ prologue( -+ shared_storage.operand_A, -+ shared_storage.operand_B, -+ iterator_A, -+ iterator_B, -+ thread_idx, -+ problem_size_k); -+ } -+ -+ template -+ CUTLASS_DEVICE static void prologue( -+ typename Base::SharedStorageA& shared_storageA, -+ typename Base::SharedStorageB& shared_storageB, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B, -+ int thread_idx, -+ int problem_size_k) { -+ SmemIteratorA smem_iterator_A(shared_storageA.ref(), thread_idx); -+ SmemIteratorB smem_iterator_B(shared_storageB.ref(), thread_idx); -+ int32_t iter = (problem_size_k + Base::Shape::kK - 1) / Base::Shape::kK; -+ _prologue( -+ iterator_A, iterator_B, iter, smem_iterator_A, smem_iterator_B); -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance( -+ IteratorA& iterator_A, -+ IteratorB& iterator_B, -+ int group_start_A = 0, -+ int group_start_B = 0) { -+ iterator_A.set_iteration_index( -+ group_start_A * IteratorA::kAccessesPerVector); -+ this->smem_iterator_A_.set_iteration_index(group_start_A); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { -+ if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { -+ typename IteratorA::AccessType* dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_A.get(); -+ -+ if (zero_outside_bounds_ || -+ SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, iterator_A.valid()); -+ } else { -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_A.valid()); -+ } -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ } -+ -+ iterator_B.set_iteration_index( -+ group_start_B * IteratorB::kAccessesPerVector); -+ this->smem_iterator_B_.set_iteration_index(group_start_B); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { -+ if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { -+ typename IteratorB::AccessType* dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_B.get(); -+ -+ if (zero_outside_bounds_ || -+ SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, iterator_B.valid()); -+ } else { -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_B.valid()); -+ } -+ -+ ++iterator_B; -+ } -+ ++this->smem_iterator_B_; -+ } -+ } -+ } -+ -+ template -+ CUTLASS_DEVICE static void _prologue( -+ IteratorA& iterator_A, -+ IteratorB& iterator_B, -+ int32_t& gemm_k_iterations, -+ SmemIteratorA& smem_iterator_A_, -+ SmemIteratorB& smem_iterator_B_) { -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < kNumStagesConcurrentLoad; -+ ++stage, --gemm_k_iterations) { -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ -+ iterator_A.set_iteration_index(0); -+ smem_iterator_A_.set_iteration_index(0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { -+ typename IteratorA::AccessType* dst_ptr = -+ reinterpret_cast( -+ smem_iterator_A_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); -+ -+ if (kLoadA) { -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_A.get(), iterator_A.valid()); -+ } -+ -+ ++iterator_A; -+ } -+ -+ ++smem_iterator_A_; -+ } -+ -+ iterator_B.set_iteration_index(0); -+ smem_iterator_B_.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { -+ typename IteratorB::AccessType* dst_ptr = -+ reinterpret_cast( -+ smem_iterator_B_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ if (kLoadB) { -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_B.get(), iterator_B.valid()); -+ } -+ -+ ++iterator_B; -+ } -+ -+ ++smem_iterator_B_; -+ } -+ -+ // Move to the next stage -+ iterator_A.add_tile_offset({0, 1}); -+ iterator_B.add_tile_offset({1, 0}); -+ -+ smem_iterator_A_.add_tile_offset({0, 1}); -+ smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Defines the boundary of a stage of cp.async. -+ cutlass::arch::cp_async_fence(); -+ } -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations, -+ ///< destination accumulator tile -+ FragmentC& accum, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B, -+ ///< initial value of accumulator -+ FragmentC const& src_accum) { -+ // -+ // Prologue -+ // -+ -+ if (!prologue_done_) { -+ _prologue( -+ iterator_A, -+ iterator_B, -+ gemm_k_iterations, -+ smem_iterator_A_, -+ smem_iterator_B_); -+ } else if (!kSmemContainsEntireMat) { -+ _prologue( -+ iterator_A, -+ iterator_B, -+ gemm_k_iterations, -+ smem_iterator_A_, -+ smem_iterator_B_); -+ } else { -+ gemm_k_iterations -= kNumStagesConcurrentLoad; -+ } -+ -+ // Perform accumulation in the 'd' output operand -+ accum = src_accum; -+ -+ // -+ // Clear the remaining tiles of SMEM. This is a functional requirement for -+ // some kernels so that all accumulator elements outside the GEMM footprint -+ // are zero. -+ // -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { -+ /// Iterator to write threadblock-scoped tile of A operand to shared -+ /// memory -+ SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); -+ -+ typename IteratorA::AccessType zero_A; -+ zero_A.clear(); -+ -+ last_smem_iterator_A.set_iteration_index(0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { -+ typename IteratorA::AccessType* dst_ptr = -+ reinterpret_cast( -+ last_smem_iterator_A.get()); -+ -+ *dst_ptr = zero_A; -+ -+ ++last_smem_iterator_A; -+ } -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared -+ /// memory -+ SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); -+ typename IteratorB::AccessType zero_B; -+ -+ zero_B.clear(); -+ last_smem_iterator_B.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { -+ typename IteratorB::AccessType* dst_ptr = -+ reinterpret_cast( -+ last_smem_iterator_B.get()); -+ -+ *dst_ptr = zero_B; -+ -+ ++last_smem_iterator_B; -+ } -+ } -+ -+ // Waits until kStages-2 stages have committed. -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA warp_loaded_frag_A[2]; -+ WarpLoadedFragmentB warp_loaded_frag_B[2]; -+ WarpTransformedFragmentA warp_transformed_frag_A[2]; -+ WarpTransformedFragmentB warp_transformed_frag_B[2]; -+ -+ Operator warp_mma; -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ -+ int smem_write_stage_idx = Base::kStages - 1; -+ int smem_read_stage_idx = 0; -+ -+ warp_mma.transform( -+ warp_transformed_frag_A[0], -+ warp_transformed_frag_B[0], -+ warp_loaded_frag_A[0], -+ warp_loaded_frag_B[0]); -+ -+ // tf32x3 kernels use staging accumulation. warp_mma uses a temporary -+ // accumulator and this temporary accumulator is added to the final -+ // accumulator once in every mainloop iteration. -+ plus plus_accum; -+ -+ FragmentC tmp_accum; -+ -+ if (platform::is_same< -+ typename Operator::MathOperator, -+ arch::OpMultiplyAddFastF32>::value || -+ platform::is_same< -+ typename Operator::MathOperator, -+ arch::OpMultiplyAddComplexFastF32>::value) { -+ tmp_accum.clear(); -+ } -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > (-kNumStagesConcurrentLoad);) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; -+ ++warp_mma_k) { -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. -+ -+ this->warp_tile_iterator_A_.set_kgroup_index( -+ (warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.set_kgroup_index( -+ (warp_mma_k + 1) % Base::kWarpGemmIterations); -+ -+ // In case of a non-circular buffer ("kSmemContainsEntireMat") -+ // make sure we don't load out of bounds data. -+ if (!kSmemContainsEntireMat || -+ gemm_k_iterations > (-kNumStagesConcurrentLoad) || -+ warp_mma_k < Base::kWarpGemmIterations - 1) { -+ this->warp_tile_iterator_A_.load( -+ warp_loaded_frag_A[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B_.load( -+ warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ } -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ if (warp_mma_k > 0) -+ warp_mma.transform( -+ warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ warp_loaded_frag_A[warp_mma_k % 2], -+ warp_loaded_frag_B[warp_mma_k % 2]); -+ -+ if (platform::is_same< -+ typename Operator::MathOperator, -+ arch::OpMultiplyAddFastF32>::value || -+ platform::is_same< -+ typename Operator::MathOperator, -+ arch::OpMultiplyAddComplexFastF32>::value) { -+ warp_mma( -+ tmp_accum, -+ warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ tmp_accum); -+ -+ if (warp_mma_k == 0) { -+ accum = plus_accum(accum, tmp_accum); -+ tmp_accum.clear(); -+ } -+ } else { -+ warp_mma( -+ accum, -+ warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ accum); -+ } -+ -+ // Issue global->shared copies for the this stage -+ if (!kSmemContainsEntireMat && -+ warp_mma_k < Base::kWarpGemmIterations - 1) { -+ int group_start_iteration_A, group_start_iteration_B; -+ -+ group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; -+ -+ copy_tiles_and_advance( -+ iterator_A, -+ iterator_B, -+ group_start_iteration_A, -+ group_start_iteration_B); -+ } -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations) { -+ if (!kSmemContainsEntireMat) { -+ int group_start_iteration_A, group_start_iteration_B; -+ group_start_iteration_A = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB; -+ -+ copy_tiles_and_advance( -+ iterator_A, -+ iterator_B, -+ group_start_iteration_A, -+ group_start_iteration_B); -+ } -+ -+ // Inserts a memory fence between stages of cp.async instructions. -+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages have committed. -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_A.add_tile_offset({0, 1}); -+ iterator_B.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (!kSmemContainsEntireMat && -+ smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {0, -+ -Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ --gemm_k_iterations; -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ } -+ -+ // Do any conversions feeding the first stage at the end of the loop so -+ // we can start right away on mma instructions -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations) -+ warp_mma.transform( -+ warp_transformed_frag_A[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ } -+ } -+ -+ if (platform::is_same< -+ typename Operator::MathOperator, -+ arch::OpMultiplyAddFastF32>::value || -+ platform::is_same< -+ typename Operator::MathOperator, -+ arch::OpMultiplyAddComplexFastF32>::value) { -+ accum = plus_accum(accum, tmp_accum); -+ } -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ // commit and drain all pending and predicated cp.async pnz from the GEMM -+ // mainloop -+ cutlass::arch::cp_async_fence(); -+ cutlass::arch::cp_async_wait<0>(); -+ __syncthreads(); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma_pipelined.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma_pipelined.h -new file mode 100644 -index 0000000..73112e9 ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma_pipelined.h -@@ -0,0 +1,401 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights -+ *reserved. SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, -+ *this list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -+ *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE -+ *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -+ *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -+ *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -+ *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -+ *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -+ *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -+ *POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_conversion.h" -+ -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+ -+#include "custom_mma_base.h" -+#include "cutlass/gemm/gemm.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Transformation applied to A operand -+ typename TransformA_ = NumericArrayConverter< -+ typename SmemIteratorA_::Element, -+ typename IteratorA_::Element, -+ IteratorA_::Fragment::kElements>, -+ /// -+ /// Transformation applied to B operand -+ typename TransformB_ = NumericArrayConverter< -+ typename SmemIteratorB_::Element, -+ typename IteratorB_::Element, -+ IteratorB_::Fragment::kElements>, -+ /// Used for partial specialization -+ typename Enable = bool> -+class CustomMmaPipelined : public CustomMmaBase { -+ public: -+ ///< Base class -+ using Base = CustomMmaBase; -+ -+ using Shape = -+ Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using IteratorA = -+ IteratorA_; ///< Iterates over tiles of A operand in global memory -+ using IteratorB = -+ IteratorB_; ///< Iterates over tiles of B operand in global memory -+ using ElementC = ElementC_; ///< Data type of accumulator matrix -+ using LayoutC = LayoutC_; ///< Layout of accumulator matrix -+ using Policy = Policy_; ///< Policy describing tuning details -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ -+ using TransformA = TransformA_; -+ using TransformB = TransformB_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of operand A loaded from global memory -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Fragment of operand B loaded from global memory -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Fragment of accumulator tile -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Obtain the arch tag from the warp-level operator -+ using ArchTag = typename Policy::Operator::ArchTag; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = Operator::kTransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = Operator::kTransformB; -+ -+ // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) -+ static_assert( -+ (Base::kStages == 2), -+ "MmaPipelined requires kStages set to value 2"); -+ -+ static bool const kSmemContainsEntireMat = false; -+ -+ private: -+ using WarpFragmentA = typename Operator::FragmentA; -+ using WarpFragmentB = typename Operator::FragmentB; -+ -+ protected: -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+ public: -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ CustomMmaPipelined( -+ typename Base::SharedStorageA& shared_storageA, -+ typename Base::SharedStorageB& shared_storageB, -+ int thread_idx, ///< ID within the threadblock -+ int warp_idx, ///< ID of warp -+ int lane_idx ///< ID of each thread within a warp -+ ) -+ : Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storageA.ref(), thread_idx), -+ smem_iterator_B_(shared_storageB.ref(), thread_idx) { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ } -+ CUTLASS_DEVICE -+ CustomMmaPipelined( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::SharedStorage& st, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx) -+ : CustomMmaPipelined( -+ st.operand_A, -+ st.operand_B, -+ thread_idx, -+ warp_idx, -+ lane_idx) {} -+ -+ CUTLASS_DEVICE -+ bool set_prologue_done(bool value) { -+ // NOT IMPLEMENTED FOR PIPELINED -+ } -+ -+ CUTLASS_DEVICE -+ bool set_zero_outside_bounds(bool value) { -+ // NOT NEEDED FOR PIPELINED -+ // shared memory will always be zero-filled -+ } -+ -+ template -+ CUTLASS_DEVICE static void prologue( -+ typename Base::SharedStorage& shared_storage, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B, -+ int thread_idx, -+ int problem_size_k) { -+ prologue( -+ shared_storage.operand_A, -+ shared_storage.operand_B, -+ iterator_A, -+ iterator_B, -+ thread_idx, -+ problem_size_k); -+ } -+ -+ template -+ CUTLASS_DEVICE static void prologue( -+ typename Base::SharedStorageA& shared_storageA, -+ typename Base::SharedStorageB& shared_storageB, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B, -+ int thread_idx, -+ int problem_size_k) { -+ // NOT IMPLEMENTED FOR PIPELINED -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ int gemm_k_iterations, ///< number of iterations of the mainloop -+ FragmentC& accum, ///< destination accumulator tile -+ IteratorA iterator_A, ///< iterator over A operand in global memory -+ IteratorB iterator_B, ///< iterator over B operand in global memory -+ FragmentC const& src_accum, ///< source accumulator tile -+ TransformA transform_A = -+ TransformA(), ///< transformation applied to A fragment -+ TransformB transform_B = -+ TransformB()) { ///< transformation applied to B fragment -+ -+ // -+ // Prologue -+ // -+ -+ // Perform accumulation in the 'd' output operand -+ accum = src_accum; -+ -+ FragmentA tb_frag_A; -+ FragmentB tb_frag_B; -+ -+ tb_frag_A.clear(); -+ tb_frag_B.clear(); -+ -+ // The last kblock is loaded in the prolog -+ iterator_A.load(tb_frag_A); -+ iterator_B.load(tb_frag_B); -+ -+ ++iterator_A; -+ ++iterator_B; -+ -+ this->smem_iterator_A_.store(transform_A(tb_frag_A)); -+ this->smem_iterator_B_.store(transform_B(tb_frag_B)); -+ -+ ++this->smem_iterator_A_; -+ ++this->smem_iterator_B_; -+ -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpFragmentA warp_frag_A[2]; -+ WarpFragmentB warp_frag_B[2]; -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A_.load(warp_frag_A[0]); -+ this->warp_tile_iterator_B_.load(warp_frag_B[0]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ Operator warp_mma; -+ -+ int smem_write_stage_idx = 1; -+ -+ // Avoid reading out of bounds -+ iterator_A.clear_mask(gemm_k_iterations <= 1); -+ iterator_B.clear_mask(gemm_k_iterations <= 1); -+ -+ // Issue loads during the first warp-level matrix multiply-add *AFTER* -+ // issuing shared memory loads (which have the tighest latency requirement). -+ -+ // -+ // Mainloop -+ // -+ -+ // Note: The main loop does not support Base::kWarpGemmIterations == 2. -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > 0; --gemm_k_iterations) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; -+ ++warp_mma_k) { -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. -+ -+ if (warp_mma_k == Base::kWarpGemmIterations - 1) { -+ // Write fragments to shared memory -+ this->smem_iterator_A_.store(transform_A(tb_frag_A)); -+ -+ this->smem_iterator_B_.store(transform_B(tb_frag_B)); -+ -+ __syncthreads(); -+ -+ ++this->smem_iterator_A_; -+ ++this->smem_iterator_B_; -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == 1) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ } else { -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {0, -+ -Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations, -+ 0}); -+ } -+ -+ smem_write_stage_idx ^= 1; -+ } -+ -+ this->warp_tile_iterator_A_.set_kgroup_index( -+ (warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.set_kgroup_index( -+ (warp_mma_k + 1) % Base::kWarpGemmIterations); -+ -+ this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ if (warp_mma_k == 0) { -+ iterator_A.load(tb_frag_A); -+ iterator_B.load(tb_frag_B); -+ -+ ++iterator_A; -+ ++iterator_B; -+ -+ // Avoid reading out of bounds if this was the last loop iteration -+ iterator_A.clear_mask(gemm_k_iterations <= 2); -+ iterator_B.clear_mask(gemm_k_iterations <= 2); -+ } -+ -+ warp_mma( -+ accum, -+ warp_frag_A[warp_mma_k % 2], -+ warp_frag_B[warp_mma_k % 2], -+ accum); -+ } -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/gemm_kernel_utils.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/gemm_kernel_utils.h -new file mode 100644 -index 0000000..1930717 ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/gemm_kernel_utils.h -@@ -0,0 +1,295 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holdvr nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include "cutlass/arch/mma.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Some helper functions -+//////////////////////////////////////////////////////////////////////////////// -+#define DISPATCH_TYPES(tensor, func) \ -+ { \ -+ if (query.scalar_type() == at::ScalarType::Float) { \ -+ using scalar_t = float; \ -+ func(); \ -+ } else if (query.scalar_type() == at::ScalarType::Half) { \ -+ using scalar_t = cutlass::half_t; \ -+ func(); \ -+ } else if (query.scalar_type() == at::ScalarType::BFloat16) { \ -+ using scalar_t = cutlass::bfloat16_t; \ -+ func(); \ -+ } else { \ -+ TORCH_CHECK(false, "Only fp32, half & bf16 supported at the moment"); \ -+ } \ -+ } -+ -+#define DISPATCH_BOOL(BOOL_V, BOOL_NAME, F) \ -+ { \ -+ if (BOOL_V) { \ -+ constexpr bool BOOL_NAME = true; \ -+ F(); \ -+ } else { \ -+ constexpr bool BOOL_NAME = false; \ -+ F(); \ -+ } \ -+ } -+#define DISPATCH_ARCHTAG(CC, func) \ -+ { \ -+ if (CC >= 80) { \ -+ using ArchTag = cutlass::arch::Sm80; \ -+ func(); \ -+ } else if (CC >= 75) { \ -+ using ArchTag = cutlass::arch::Sm75; \ -+ func(); \ -+ } else if (CC >= 70) { \ -+ using ArchTag = cutlass::arch::Sm70; \ -+ func(); \ -+ } else if (CC >= 50) { \ -+ using ArchTag = cutlass::arch::Sm50; \ -+ func(); \ -+ } else { \ -+ TORCH_CHECK( \ -+ false, \ -+ "Your device is too old. We require compute capability >= 50"); \ -+ } \ -+ } -+ -+#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR) \ -+ TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ -+ TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ -+ TORCH_CHECK(TENSOR.is_contiguous()); -+ -+#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR) \ -+ TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ -+ TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ -+ TORCH_CHECK( \ -+ TENSOR.stride(-1) == 1, #TENSOR ": last dimension must be contiguous"); -+ -+#ifdef HAS_PYTORCH -+#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ -+ TORCH_CHECK(uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned") -+#define XFORMERS_CHECK TORCH_CHECK -+#elif defined(__CUDACC_RTC__) -+#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ -+ if (!(uint64_t(PTR) % ALIGNMENT == 0)) { \ -+ return false; \ -+ } -+#define XFORMERS_CHECK(COND, ERR) \ -+ if (!(COND)) { \ -+ return false; \ -+ } -+#else -+#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ -+ if (!(uint64_t(PTR) % ALIGNMENT == 0)) { \ -+ std::cerr << #PTR " is not correctly aligned\n"; \ -+ return false; \ -+ } -+#define XFORMERS_CHECK(COND, ERR) \ -+ if (!(COND)) { \ -+ std::cerr << #COND " failed\n"; \ -+ return false; \ -+ } -+#endif -+ -+#define ASSIGN_CHECK_OVERFLOW(A, B) \ -+ { \ -+ A = B; \ -+ TORCH_CHECK( \ -+ B < cutlass::platform::numeric_limits::max(), \ -+ #B " overflows"); \ -+ } -+ -+namespace gemm_kernel_utils { -+ -+#ifdef HAS_PYTORCH -+template -+struct TypeTraits; -+ -+template <> -+struct TypeTraits { -+ using scalar_t = cutlass::half_t; -+ -+ static constexpr __host__ at::ScalarType atScalarType() { -+ return at::ScalarType::Half; -+ } -+ template -+ static __host__ at::PackedTensorAccessor32 packed_accessor( -+ at::Tensor const& tensor) { -+ return at::PackedTensorAccessor32( -+ (scalar_t*)(tensor.data_ptr()), -+ tensor.sizes().data(), -+ tensor.strides().data()); -+ } -+}; -+ -+template <> -+struct TypeTraits { -+ using scalar_t = cutlass::bfloat16_t; -+ -+ static constexpr __host__ at::ScalarType atScalarType() { -+ return at::ScalarType::BFloat16; -+ } -+ template -+ static __host__ at::PackedTensorAccessor32 packed_accessor( -+ at::Tensor const& tensor) { -+ return at::PackedTensorAccessor32( -+ (scalar_t*)(tensor.data_ptr()), -+ tensor.sizes().data(), -+ tensor.strides().data()); -+ } -+}; -+ -+template <> -+struct TypeTraits { -+ using scalar_t = float; -+ -+ static constexpr __host__ at::ScalarType atScalarType() { -+ return at::ScalarType::Float; -+ } -+ template -+ static __host__ at::PackedTensorAccessor32 packed_accessor( -+ at::Tensor const& tensor) { -+ return tensor.packed_accessor32(); -+ } -+}; -+#endif -+ -+template -+constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m) { -+ return (n + m - 1) / m; -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Determine the type of GEMM we do (TensorCores or not, Shapes ...) -+// TODO: Maybe we could rely on Cutlass's DefaultGemm templates -+//////////////////////////////////////////////////////////////////////////////// -+ -+// Fallback to Simt (FMA on cuda cores) if not in a special case below -+template -+struct DefaultGemmType { -+ static constexpr int ThreadK = 8; -+ static constexpr int WarpK = 8; -+ static constexpr int kMinimumAlignment = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using OpClass = cutlass::arch::OpClassSimt; -+ using Operator = cutlass::arch::OpMultiplyAdd; -+}; -+ -+// Specialization for tensorcores with f32 -+template -+struct DefaultGemmType< -+ ArchTag, -+ float, -+ typename cutlass::platform::enable_if< -+ ArchTag::kMinComputeCapability >= 80>::type> { -+ static constexpr int ThreadK = 32; -+ static constexpr int WarpK = 32; -+ static constexpr int kMinimumAlignment = 4; -+ using OpClass = cutlass::arch::OpClassTensorOp; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Operator = cutlass::arch::OpMultiplyAdd; // FastF32; -+}; -+ -+// Specialization for tensorcores with f16/bf16 - Sm75+ -+template -+struct DefaultGemmType< -+ ArchTag, -+ scalar_t, -+ typename cutlass::platform::enable_if< -+ ArchTag::kMinComputeCapability >= 75 && -+ cutlass::sizeof_bits::value == 16>::type> { -+ static constexpr int ThreadK = 32; -+ static constexpr int WarpK = 32; -+ static constexpr int kMinimumAlignment = 4; -+ using OpClass = cutlass::arch::OpClassTensorOp; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Operator = cutlass::arch::OpMultiplyAdd; -+}; -+ -+// Specialization for tensorcores with f16 - Volta -+template <> -+struct DefaultGemmType { -+ static constexpr int ThreadK = 32; -+ static constexpr int WarpK = 32; -+ static constexpr int kMinimumAlignment = 2; -+ using OpClass = cutlass::arch::OpClassTensorOp; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ using Operator = cutlass::arch::OpMultiplyAdd; -+}; -+ -+// Enables to do -+// `auto x = kCondition ? fa(arg) : fb(arg)` -+// when `fa` and `fb` have different types -+template -+struct call_conditional; -+ -+template -+struct call_conditional { -+ template -+ static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg) -+ -> decltype(ta(arg)) { -+ return ta(arg); -+ } -+}; -+ -+template -+struct call_conditional { -+ template -+ static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg) -+ -> decltype(tb(arg)) { -+ return tb(arg); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Mark a variable as warp-uniform - enables some compiler optimizations -+// The cheapest way to do it is just to broadcast it from lane 0 -+//////////////////////////////////////////////////////////////////////////////// -+ -+CUTLASS_DEVICE int32_t warp_uniform(int32_t value) { -+ return (int32_t)__shfl_sync(0xffffffff, (unsigned)value, 0); -+} -+ -+template -+CUTLASS_DEVICE T* warp_uniform(T* ptr) { -+ struct { -+ union { -+ T* ptr; -+ uint32_t asInt[2]; -+ }; -+ } p; -+ p.ptr = ptr; -+ p.asInt[0] = warp_uniform(p.asInt[0]); -+ p.asInt[1] = warp_uniform(p.asInt[1]); -+ return p.ptr; -+} -+} // namespace gemm_kernel_utils -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/iterators/epilogue_predicated_tile_iterator.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/iterators/epilogue_predicated_tile_iterator.h -new file mode 100644 -index 0000000..298876e ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/iterators/epilogue_predicated_tile_iterator.h -@@ -0,0 +1,752 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights -+ *reserved. SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, -+ *this list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -+ *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE -+ *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -+ *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -+ *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -+ *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -+ *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -+ *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -+ *POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue iterator that supports prefetching -+ -+ Mostly copied from "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+*/ -+ -+#pragma once -+ -+#include "cutlass/arch/arch.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator used to load and store output tile from global memory in -+/// epilogue. -+/// -+/// Satisfies: ReadableTileIterator | PredicatedTileIterator | -+/// ForwardTileIterator -+/// -+template < -+ typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) -+ typename Element_, ///< Element data type -+ bool ScatterD = false, ///< Scatter D operand or not -+ bool UseCUDAStore = false> -+class PredicatedTileIteratorPrefetch { -+ public: -+ using ThreadMap = ThreadMap_; -+ using Shape = typename ThreadMap::Shape; -+ -+ using Element = Element_; -+ -+ using Layout = layout::RowMajor; -+ using TensorRef = TensorRef; -+ using ConstTensorRef = typename TensorRef::ConstTensorRef; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using TensorCoord = MatrixCoord; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ static int const kThreads = ThreadMap::kThreads; -+ static int const kIterations = ThreadMap::Count::kTile; -+ -+ static_assert( -+ ThreadMap::Iterations::kRow > 0, -+ "ThreadMap::Iterations::kRow must be > 0"); -+ static_assert( -+ ThreadMap::Iterations::kGroup > 0, -+ "ThreadMap::Iterations::kGroup must be > 0"); -+ static_assert( -+ ThreadMap::Iterations::kCluster > 0, -+ "ThreadMap::Iterations::kCluster must be > 0"); -+ static_assert( -+ ThreadMap::Iterations::kColumn > 0, -+ "ThreadMap::Iterations::kColumn must be > 0"); -+ -+ /// Fragment object -+ using Fragment = Array< -+ Element, -+ ThreadMap::Iterations::kColumn * ThreadMap::Iterations::kRow * -+ ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster * -+ ThreadMap::kElementsPerAccess>; -+ -+ /// Memory access size -+ using AccessType = AlignedArray; -+ -+ // -+ // Parameters struct -+ // -+ -+ /// Uses a non-template class -+ struct Params : PredicatedTileIteratorParams { -+ using Base = PredicatedTileIteratorParams; -+ -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(Layout const& layout) -+ : PredicatedTileIteratorParams( -+ layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, -+ make_OutputTileThreadMapDesc()) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(Base const& base) : Base(base) {} -+ }; -+ -+ /// Mask object -+ struct Mask { -+ static int const kCount = ThreadMap::Iterations::kColumn; -+ -+ /// Predicate state -+ bool predicates[kCount]; -+ -+ // -+ // Mask -+ // -+ CUTLASS_HOST_DEVICE -+ Mask() { -+ enable(); -+ } -+ -+ ///< Efficiently disables all accesses guarded by mask -+ CUTLASS_HOST_DEVICE void clear() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ predicates[i] = false; -+ } -+ } -+ -+ ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask -+ CUTLASS_DEVICE void enable() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ predicates[i] = true; -+ } -+ } -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Parameters structure containing reference and precomputed state. -+ PredicatedTileIteratorParams params_; -+ -+ /// Byte-level pointer -+ uint8_t* byte_pointer_; -+ -+ /// Array of boolean values to contain steady-state predicates -+ Mask mask_; -+ -+ /// Extent of the matrix tile in rows -+ Index extent_row_; -+ -+ /// Extent of the matrix tile in rows -+ Index extent_column_; -+ -+ /// A thread's starting row position (assuming steady-state predicates have -+ /// been computed) -+ Index thread_start_row_; -+ -+ /// A thread's starting column -+ Index thread_start_column_; -+ -+ /// Internal state counter -+ int state_[3]; -+ -+ /// Scatter indices -+ int const* indices_; -+ -+ // -+ // Static asserts about internal strides -+ // -+ -+ static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); -+ static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); -+ static_assert( -+ sizeof(PredicatedTileIteratorParams::stride) == 8, -+ "Expected 64b strides"); -+ -+ private: -+ // -+ // Methods -+ // -+ -+ public: -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ PredicatedTileIteratorPrefetch( -+ PredicatedTileIteratorParams const& params, -+ Element* pointer, -+ TensorCoord extent, -+ int thread_idx, -+ TensorCoord threadblock_offset = TensorCoord(), -+ int const* indices = nullptr) -+ : params_(params), indices_(indices) { -+ TensorCoord thread_offset = -+ ThreadMap::initial_offset(thread_idx) + threadblock_offset; -+ -+ extent_row_ = extent.row(); -+ extent_column_ = extent.column(); -+ -+ thread_start_row_ = thread_offset.row(); -+ thread_start_column_ = thread_offset.column(); -+ -+ // Initialize predicates -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { -+ mask_.predicates[c] = -+ ((thread_offset.column() + ThreadMap::Delta::kColumn * c) < -+ extent.column()); -+ } -+ -+ // Null pointer performs no accesses -+ if (!pointer) { -+ mask_.clear(); -+ } -+ -+ if (ScatterD && !indices) { -+ mask_.clear(); -+ } -+ -+ // Initialize pointer -+ byte_pointer_ = reinterpret_cast(pointer) + -+ LongIndex(thread_offset.row()) * LongIndex(params_.stride) + -+ LongIndex(thread_offset.column()) * sizeof(AccessType) / -+ kElementsPerAccess; -+ -+ if (ScatterD) { -+ byte_pointer_ = reinterpret_cast(pointer) + -+ LongIndex(thread_offset.column()) * sizeof(AccessType) / -+ kElementsPerAccess; -+ } -+ -+ // Initialize internal state counter -+ state_[0] = state_[1] = state_[2] = 0; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ byte_pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_DEVICE -+ void prefetch_all() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int iter = 0; iter < kIterations; ++iter) { -+ prefetch(); -+ ++(*this); -+ } -+ } -+ -+ CUTLASS_DEVICE -+ void prefetch() { -+ uint8_t* byte_pointer = byte_pointer_; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; -+ ++cluster) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ int row_offset = row * ThreadMap::Delta::kRow + -+ group * ThreadMap::Delta::kGroup + -+ cluster * ThreadMap::Delta::kCluster; -+ -+ AccessType* memory_pointer = -+ reinterpret_cast(byte_pointer); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; -+ ++column) { -+ // on windows using unsigned long here gives the error -+ // error: asm operand type size(4) does not match -+ // type/size implied by constraint 'l' -+ uint64_t addr = (uint64_t)( -+ (void*)&memory_pointer -+ [column * ThreadMap::Delta::kColumn / kElementsPerAccess]); -+ asm volatile("prefetch.global.L1 [ %1 ];" : "=l"(addr) : "l"(addr)); -+ } -+ -+ if (row + 1 < ThreadMap::Iterations::kRow) { -+ if (!ScatterD) { -+ byte_pointer += params_.increment_row; -+ } -+ } -+ } -+ -+ if (group + 1 < ThreadMap::Iterations::kGroup) { -+ byte_pointer += params_.increment_group; -+ } -+ } -+ -+ if (cluster + 1 < ThreadMap::Iterations::kCluster) { -+ byte_pointer += params_.increment_cluster; -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment& frag, int64_t byte_offset) const { -+ uint8_t* byte_pointer = byte_pointer_; -+ AccessType* frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; -+ ++cluster) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ int frag_row_idx = -+ (row + -+ ThreadMap::Iterations::kRow * -+ (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ int row_offset = row * ThreadMap::Delta::kRow + -+ group * ThreadMap::Delta::kGroup + -+ cluster * ThreadMap::Delta::kCluster; -+ -+ bool row_guard = ((row_offset + thread_start_row_) < extent_row_); -+ -+ AccessType* memory_pointer = -+ reinterpret_cast(byte_pointer + byte_offset); -+ -+ if (ScatterD && row_guard) { -+ assert(indices_); -+ -+ memory_pointer = reinterpret_cast( -+ byte_pointer + byte_offset + -+ LongIndex(indices_[row_offset + thread_start_row_]) * -+ LongIndex(params_.stride)); -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; -+ ++column) { -+ bool guard = row_guard && mask_.predicates[column]; -+ -+ cutlass::arch::global_load( -+ frag_ptr -+ [frag_row_idx * ThreadMap::Iterations::kColumn + column], -+ (void*)&memory_pointer -+ [column * ThreadMap::Delta::kColumn / kElementsPerAccess], -+ guard); -+ } -+ -+ if (row + 1 < ThreadMap::Iterations::kRow) { -+ if (!ScatterD) { -+ byte_pointer += params_.increment_row; -+ } -+ } -+ } -+ -+ if (group + 1 < ThreadMap::Iterations::kGroup) { -+ byte_pointer += params_.increment_group; -+ } -+ } -+ -+ if (cluster + 1 < ThreadMap::Iterations::kCluster) { -+ byte_pointer += params_.increment_cluster; -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment& frag) const { -+ load_with_byte_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const& frag, int64_t byte_offset) const { -+ uint8_t* byte_pointer = byte_pointer_; -+ AccessType const* frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; -+ ++cluster) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ int frag_row_idx = -+ (row + -+ ThreadMap::Iterations::kRow * -+ (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ int row_offset = row * ThreadMap::Delta::kRow + -+ group * ThreadMap::Delta::kGroup + -+ cluster * ThreadMap::Delta::kCluster; -+ -+ bool row_guard = ((row_offset + thread_start_row_) < extent_row_); -+ -+ AccessType* memory_pointer = -+ reinterpret_cast(byte_pointer + byte_offset); -+ -+ if (ScatterD && row_guard) { -+ assert(indices_); -+ -+ memory_pointer = reinterpret_cast( -+ byte_pointer + byte_offset + -+ LongIndex(indices_[row_offset + thread_start_row_]) * -+ LongIndex(params_.stride)); -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; -+ ++column) { -+ bool guard = row_guard && mask_.predicates[column]; -+ -+ if (UseCUDAStore) { -+ if (guard) { -+ memory_pointer -+ [column * ThreadMap::Delta::kColumn / kElementsPerAccess] = -+ frag_ptr -+ [frag_row_idx * ThreadMap::Iterations::kColumn + -+ column]; -+ } -+ } else { -+ cutlass::arch::global_store( -+ frag_ptr -+ [frag_row_idx * ThreadMap::Iterations::kColumn + column], -+ (void*)&memory_pointer -+ [column * ThreadMap::Delta::kColumn / kElementsPerAccess], -+ guard); -+ } -+ } -+ -+ if (row + 1 < ThreadMap::Iterations::kRow) { -+ if (!ScatterD) { -+ byte_pointer += params_.increment_row; -+ } -+ } -+ } -+ -+ if (group + 1 < ThreadMap::Iterations::kGroup) { -+ byte_pointer += params_.increment_group; -+ } -+ } -+ -+ if (cluster + 1 < ThreadMap::Iterations::kCluster) { -+ byte_pointer += params_.increment_cluster; -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const& frag) const { -+ store_with_byte_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void downsample_load_with_byte_offset( -+ Fragment& frag, -+ int64_t byte_offset, -+ int convolution_P, -+ int convolution_Q, -+ int add_P, -+ int add_Q, -+ int problem_N) const { -+ uint8_t* byte_pointer = byte_pointer_; -+ AccessType* frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; -+ ++cluster) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ int frag_row_idx = -+ (row + -+ ThreadMap::Iterations::kRow * -+ (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ int row_offset = row * ThreadMap::Delta::kRow + -+ group * ThreadMap::Delta::kGroup + -+ cluster * ThreadMap::Delta::kCluster; -+ -+ bool row_guard = ((row_offset + thread_start_row_) < extent_row_); -+ -+ int output_row = row_offset + thread_start_row_; -+ int output_N = output_row / (convolution_P * convolution_Q); -+ int output_PQ = output_row % (convolution_P * convolution_Q); -+ int output_P = output_PQ / convolution_Q; -+ int output_Q = output_PQ % convolution_Q; -+ -+ int input_row = output_N * 2 * convolution_P * 2 * convolution_Q + -+ (2 * output_P + add_P) * 2 * convolution_Q + 2 * output_Q + add_Q; -+ -+ int64_t byte_offset = -+ (input_row - output_row) * problem_N * sizeof(float); -+ -+ AccessType* memory_pointer = -+ reinterpret_cast(byte_pointer + byte_offset); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; -+ ++column) { -+ bool guard = row_guard && mask_.predicates[column]; -+ -+ cutlass::arch::global_load( -+ frag_ptr -+ [frag_row_idx * ThreadMap::Iterations::kColumn + column], -+ (void*)&memory_pointer -+ [column * ThreadMap::Delta::kColumn / kElementsPerAccess], -+ guard); -+ } -+ -+ if (row + 1 < ThreadMap::Iterations::kRow) { -+ byte_pointer += params_.increment_row; -+ } -+ } -+ -+ if (group + 1 < ThreadMap::Iterations::kGroup) { -+ byte_pointer += params_.increment_group; -+ } -+ } -+ -+ if (cluster + 1 < ThreadMap::Iterations::kCluster) { -+ byte_pointer += params_.increment_cluster; -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void upsample_load_with_byte_offset( -+ Fragment& frag, -+ int64_t byte_offset, -+ int convolution_P, -+ int convolution_Q, -+ int add_P, -+ int add_Q, -+ int problem_N) const { -+ uint8_t* byte_pointer = byte_pointer_; -+ AccessType* frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; -+ ++cluster) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ int frag_row_idx = -+ (row + -+ ThreadMap::Iterations::kRow * -+ (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ int row_offset = row * ThreadMap::Delta::kRow + -+ group * ThreadMap::Delta::kGroup + -+ cluster * ThreadMap::Delta::kCluster; -+ -+ bool row_guard = ((row_offset + thread_start_row_) < extent_row_); -+ -+ int output_row = row_offset + thread_start_row_; -+ int output_N = output_row / (convolution_P * convolution_Q); -+ int output_PQ = output_row % (convolution_P * convolution_Q); -+ int output_P = output_PQ / convolution_Q; -+ int output_Q = output_PQ % convolution_Q; -+ int row_add_P = add_P; -+ int row_add_Q = add_Q; -+ if (output_P > convolution_P - 2) -+ row_add_P = 0; -+ if (output_Q > convolution_Q - 2) -+ row_add_Q = 0; -+ -+ int input_row = output_N * (convolution_P / 2) * (convolution_Q / 2) + -+ ((output_P + row_add_P) / 2) * (convolution_Q / 2) + -+ (output_Q + row_add_Q) / 2; -+ -+ int64_t byte_offset = -+ (input_row - output_row) * problem_N * sizeof(float); -+ -+ AccessType* memory_pointer = -+ reinterpret_cast(byte_pointer + byte_offset); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; -+ ++column) { -+ bool guard = row_guard && mask_.predicates[column]; -+ -+ cutlass::arch::global_load( -+ frag_ptr -+ [frag_row_idx * ThreadMap::Iterations::kColumn + column], -+ (void*)&memory_pointer -+ [column * ThreadMap::Delta::kColumn / kElementsPerAccess], -+ guard); -+ } -+ -+ if (row + 1 < ThreadMap::Iterations::kRow) { -+ byte_pointer += params_.increment_row; -+ } -+ } -+ -+ if (group + 1 < ThreadMap::Iterations::kGroup) { -+ byte_pointer += params_.increment_group; -+ } -+ } -+ -+ if (cluster + 1 < ThreadMap::Iterations::kCluster) { -+ byte_pointer += params_.increment_cluster; -+ } -+ } -+ } -+ -+ CUTLASS_DEVICE -+ MatrixCoord thread_start() const { -+ return MatrixCoord(thread_start_row_, thread_start_column_); -+ } -+ -+ /// Need to get the thread start row from the tile iterator -+ CUTLASS_DEVICE -+ int32_t thread_start_row() const { -+ return thread_start_row_; -+ } -+ -+ /// Need to get the thread start row from the tile iterator -+ CUTLASS_DEVICE -+ int32_t thread_start_column() const { -+ return thread_start_column_; -+ } -+ -+ /// Extent of the matrix in rows -+ CUTLASS_DEVICE -+ Index extent_row() const { -+ return extent_row_; -+ } -+ -+ /// Extent of the matrix in columns -+ CUTLASS_DEVICE -+ Index extent_column() const { -+ return extent_column_; -+ } -+ -+ /// Advances to the next position to load or store -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorPrefetch& operator++() { -+ ++state_[0]; -+ -+ if (!ScatterD) { -+ byte_pointer_ += params_.advance_row; -+ } -+ -+ thread_start_row_ += ThreadMap::Shape::kRow; -+ -+ if (state_[0] == ThreadMap::Count::kRow) { -+ state_[0] = 0; -+ ++state_[1]; -+ byte_pointer_ += params_.advance_group; -+ -+ thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * -+ ThreadMap::Shape::kRow * ThreadMap::Count::kRow; -+ -+ if (state_[1] == ThreadMap::Count::kGroup) { -+ state_[1] = 0; -+ ++state_[2]; -+ byte_pointer_ += params_.advance_cluster; -+ -+ thread_start_row_ += ThreadMap::Count::kGroup * -+ ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * -+ ThreadMap::Shape::kRow; -+ -+ if (state_[2] == ThreadMap::Count::kCluster) { -+ state_[2] = 0; -+ byte_pointer_ += params_.advance_tile; -+ } -+ } -+ } -+ -+ return *this; -+ } -+ -+ ///< Efficiently disables all accesses guarded by mask -+ CUTLASS_DEVICE void clear_mask() { -+ mask_.clear(); -+ } -+ -+ ///< Efficiently enables all accesses guarded by mask -+ CUTLASS_DEVICE void enable_mask() { -+ mask_.enable(); -+ } -+ -+ ///< Sets the mask -+ CUTLASS_DEVICE void get_mask(Mask& mask) const { -+ mask = mask_; -+ } -+ -+ ///< Sets the mask -+ CUTLASS_DEVICE void set_mask(Mask const& mask) { -+ mask_ = mask; -+ } -+}; -+ -+template -+struct MakePrefetchableIterator { -+ using Iterator = PredicatedTileIteratorPrefetch< -+ typename IT::ThreadMap, -+ typename IT::Element>; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/iterators/make_residual_last.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/iterators/make_residual_last.h -new file mode 100644 -index 0000000..e6b5d58 ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/iterators/make_residual_last.h -@@ -0,0 +1,97 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holdvr nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include "predicated_tile_access_iterator_residual_last.h" -+#include "predicated_tile_iterator_residual_last.h" -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+template -+struct MakeIteratorResidualLast; -+ -+template < -+ typename Shape, -+ typename Element, -+ typename Layout, -+ int AdvanceRank, -+ typename ThreadMap, -+ int AccessSize, -+ bool Gather> -+struct MakeIteratorResidualLast> { -+ using Iterator = PredicatedTileIteratorResidualLast< -+ Shape, -+ Element, -+ Layout, -+ AdvanceRank, -+ ThreadMap, -+ AccessSize, -+ Gather>; -+}; -+ -+template < -+ typename Shape, -+ typename Element, -+ typename Layout, -+ int AdvanceRank, -+ typename ThreadMap, -+ typename AccessType, -+ bool Gather> -+struct MakeIteratorResidualLast> { -+ using Iterator = PredicatedTileAccessIteratorResidualLast< -+ Shape, -+ Element, -+ Layout, -+ AdvanceRank, -+ ThreadMap, -+ AccessType, -+ Gather>; -+}; -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/iterators/predicated_tile_access_iterator_residual_last.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/iterators/predicated_tile_access_iterator_residual_last.h -new file mode 100644 -index 0000000..b9c38cc ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/iterators/predicated_tile_access_iterator_residual_last.h -@@ -0,0 +1,2115 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights -+ *reserved. SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, -+ *this list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -+ *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE -+ *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -+ *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -+ *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -+ *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -+ *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -+ *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -+ *POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates calculating the address and predicates to the load of tiles -+ from pitch-linear rank=2 tensors. -+ -+ This iterator uses masks to guard out-of-bounds accesses. The first tile -+ this iterator visits maybe partial, then the remaining tiles are complete. -+ So, we only need to compute the predicates twice, once before the first tile -+ and once for the remaining full tiles which can share the same predicates. -+ -+ A precomputed "Params" object minimizes the amount of state that must be -+ stored in registers, and integer addition is used to advance the pointer -+ through memory. -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// PredicatedTileAccessIteratorResidualLast -+/// -+template < -+ typename Shape, -+ typename Element, -+ typename Layout, -+ int AdvanceRank, -+ typename ThreadMap, -+ typename AccessType, -+ bool Gather = false> -+class PredicatedTileAccessIteratorResidualLast; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIteratorResidualLast for pitch-linear -+/// data. -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ typename AccessType_, -+ bool Gather> -+class PredicatedTileAccessIteratorResidualLast< -+ Shape_, -+ Element_, -+ layout::PitchLinear, -+ AdvanceRank, -+ ThreadMap_, -+ AccessType_, -+ Gather> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element*; -+ using NonConstPointer = typename platform::remove_const::type*; -+ -+ using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates< -+ Shape, -+ Element, -+ Layout, -+ AdvanceRank, -+ ThreadMap, -+ AccessType>; -+ -+ static int const kAccessesPerVector = -+ ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert( -+ !(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ using Mask = typename UnderlyingPredicates::Mask; -+ -+ /// Uses a non-template class -+ struct Params : PredicatedTileAccessIteratorParams { -+ using Base = PredicatedTileAccessIteratorParams; -+ -+ // Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const& layout) -+ : Base( -+ layout.stride(0), -+ MakePredicatedTileAccessIteratorDesc< -+ Shape, -+ Element, -+ Layout, -+ kAdvanceRank, -+ ThreadMap>()()) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(Base const& base) : Base(base) {} -+ }; -+ -+ private: -+ /// Internal pointer type permits fast address arithmetic -+ using BytePointer = char*; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ UnderlyingPredicates the_predicates; -+ Mask residual_tile_mask; -+ -+ /// Parameters object with precomputed internal state -+ Params const& params_; -+ -+ /// Internal pointer to first access of tile -+ BytePointer pointer_; -+ -+ /// Below is used when Gather is turned on. We need to record strided_offset -+ /// and contiguous_offset seperated to compute the offset by using -+ /// -+ /// offset = contiguous_offset + indices[strided_offset] -+ /// -+ -+ /// Gather indices -+ int const* indices_; -+ -+ Index gather_offset_strided; -+ -+ private: -+ /// Computes predicates based on internally tracked per-thread offset. -+ CUTLASS_DEVICE -+ void compute_predicates_( -+ /// Extent of the matrix window -+ TensorCoord extent, -+ /// optionally, simplify predicate calculation during 'steady state' phase -+ bool is_steady_state = false) { -+ the_predicates.compute_predicates_(extent, is_steady_state); -+ } -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast( -+ /// Precomputed parameters object -+ Params const& params, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const& threadblock_offset, -+ /// Gather indices -+ int const* indices = nullptr) -+ : params_(params), -+ pointer_(reinterpret_cast( -+ const_cast(pointer))), -+ the_predicates(extent), -+ indices_(indices) { -+ the_predicates.set_predicates(thread_id, threadblock_offset); -+ the_predicates.get_mask(residual_tile_mask); -+ -+ // Working around a weird compiler bug happening on P100 for the backward. -+ // I've seen together: the_predicates.predicates_[0] = 14 (instead of 15) -+ // residual_tile_mask[0] = 15 (correct) -+ // -+ // Adding prints when the value is calculated (in `compute_predicates_`) -+ // sometimes removes the bug. The consequence is that we skip some -+ // element of a tensor, leading to wrong results -+ // Setting `compute_predicates_`'s second argument (`is_steady_state`) to -+ // true also seems to get rid of the bug - at the cost of twice as many -+ // comparisons. -+#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 700) -+ constexpr bool kWorkAroundCompilerBug = false; -+#else -+ constexpr bool kWorkAroundCompilerBug = true; -+#endif -+ the_predicates.compute_predicates_(extent, true && !kWorkAroundCompilerBug); -+ -+ // update internal pointers -+ Layout layout(params_.stride_); -+ -+ if (!Gather) { -+ add_pointer_offset(layout(the_predicates.thread_offset_)); -+ } else { -+ gather_offset_strided = the_predicates.thread_offset_.strided(); -+ add_pointer_offset( -+ layout(make_Coord(the_predicates.thread_offset_.contiguous(), 0))); -+ } -+ } -+ -+ /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock -+ /// offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast( -+ /// Precomputed parameters object -+ Params const& params, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id) -+ : PredicatedTileAccessIteratorResidualLast( -+ params, -+ pointer, -+ extent, -+ thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ the_predicates.set_iteration_index(index); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_residual_tile(bool is_residual_tile) { -+ if (is_residual_tile) { -+ the_predicates.set_mask(residual_tile_mask); -+ } -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += sizeof_bits::value * pointer_offset / 8; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const& tile_offset) { -+ if (!Gather) { -+ if (kAdvanceRank) { -+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided()); -+ pointer_ += Shape::kContiguous * tile_offset.contiguous(); -+ } else { -+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous()); -+ pointer_ += Shape::kStrided * tile_offset.strided(); -+ } -+ } else { -+ add_pointer_offset(Shape::kContiguous * tile_offset.contiguous()); -+ gather_offset_strided += Shape::kStrided * tile_offset.strided(); -+ } -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType* get() const { -+ if (Gather) { -+ assert(indices_); -+ -+ if (!valid()) { -+ return nullptr; -+ } -+ -+ LongIndex contiguous_offset = the_predicates.iteration_contiguous_ * -+ (ThreadMap::Delta::kContiguous * sizeof_bits::value / -+ 8) + -+ the_predicates.iteration_vector_; -+ int strided_index = gather_offset_strided + -+ the_predicates.iteration_strided_ * ThreadMap::Delta::kStrided; -+ -+ LongIndex strided_offset = indices_[strided_index] * -+ LongIndex(params_.stride_) * sizeof_bits::value / 8; -+ -+ return reinterpret_cast( -+ pointer_ + contiguous_offset + strided_offset); -+ } -+ -+ return reinterpret_cast( -+ pointer_ + -+ the_predicates.iteration_contiguous_ * -+ (ThreadMap::Delta::kContiguous * -+ sizeof_bits::value) / -+ 8) + -+ the_predicates.iteration_vector_; -+ } -+ -+ /// Increment and return an instance to self. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast& operator++() { -+ the_predicates.operator++(); -+ -+ ++the_predicates.iteration_vector_; -+ if (the_predicates.iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ -+ the_predicates.iteration_vector_ = 0; -+ ++the_predicates.iteration_contiguous_; -+ -+ if (the_predicates.iteration_contiguous_ < -+ ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ -+ // Enter here only if (iteration_contiguous_ == -+ // ThreadMap::Iteration::kContiguous) -+ the_predicates.iteration_contiguous_ = 0; -+ ++the_predicates.iteration_strided_; -+ -+ if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ if (!Gather) { -+ pointer_ += params_.inc_strided_; -+ } -+ -+ return *this; -+ } -+ -+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) -+ // which means we enter the next tile. -+ the_predicates.iteration_strided_ = 0; -+ -+ if (!Gather) { -+ // advance to next tile -+ pointer_ += params_.inc_next_; -+ -+ // now return to start tile - if the iterator is subsequently advanced, -+ // this subtraction as well as the subsequent integer addition are both -+ // elided by the compiler. -+ pointer_ -= params_.inc_advance_; -+ } -+ -+ return *this; -+ } -+ -+ /// Increment and return an instance to self. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast operator++(int) { -+ PredicatedTileAccessIteratorResidualLast self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ the_predicates.clear_mask(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ the_predicates.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const& mask) { -+ the_predicates.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask& mask) { -+ the_predicates.get_mask(mask); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ return the_predicates.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIteratorResidualLast for column-major -+/// data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ typename AccessType_, -+ bool Gather> -+class PredicatedTileAccessIteratorResidualLast< -+ Shape_, -+ Element_, -+ layout::ColumnMajor, -+ AdvanceRank, -+ ThreadMap_, -+ AccessType_, -+ Gather> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element*; -+ using NonConstPointer = typename platform::remove_const::type*; -+ -+ using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 0 : 1), -+ ThreadMap, -+ AccessType, -+ Gather>; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileAccessIteratorResidualLast; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const& layout) -+ : params_(layout::PitchLinear(layout.stride(0))){}; -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const& base) -+ : params_(base) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast( -+ ///< Precomputed parameters object -+ Params const& params, -+ ///< Pointer to start of tensor -+ Pointer pointer, -+ ///< Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const& threadblock_offset, -+ int const* indices = -+ nullptr ///< gather/scatter indices, note no support for -+ ///< gather/scatter at this specialization -+ ) -+ : iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord(extent.row(), extent.column()), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.row(), -+ threadblock_offset.column()), -+ indices) {} -+ -+ /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock -+ /// offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast( -+ Params const& params, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileAccessIteratorResidualLast( -+ params, -+ pointer, -+ extent, -+ thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ iterator_.set_iteration_index(index); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_residual_tile(bool enable) { -+ iterator_.set_residual_tile(enable); -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const& tile_offset) { -+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType* get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast& operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast operator++(int) { -+ PredicatedTileAccessIteratorResidualLast self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const& mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask& mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return iterator_.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIteratorResidualLast for row-major -+/// data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ typename AccessType_, -+ bool Gather> -+class PredicatedTileAccessIteratorResidualLast< -+ Shape_, -+ Element_, -+ layout::RowMajor, -+ AdvanceRank, -+ ThreadMap_, -+ AccessType_, -+ Gather> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element*; -+ using NonConstPointer = typename platform::remove_const::type*; -+ -+ using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 1 : 0), -+ ThreadMap, -+ AccessType, -+ Gather>; -+ -+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileAccessIteratorResidualLast; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const& layout) -+ : params_(layout::PitchLinear(layout.stride(0))){}; -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const& base) -+ : params_(base) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast( -+ ///< Precomputed parameters object -+ Params const& params, -+ ///< Pointer to start of tensor -+ Pointer pointer, -+ ///< Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const& threadblock_offset, -+ /// Gather indices -+ int const* indices = nullptr) -+ : iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord(extent.column(), extent.row()), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.column(), -+ threadblock_offset.row()), -+ indices) {} -+ -+ /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock -+ /// offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast( -+ Params const& params, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileAccessIteratorResidualLast( -+ params, -+ pointer, -+ extent, -+ thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ iterator_.set_iteration_index(index); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_residual_tile(bool enable) { -+ iterator_.set_residual_tile(enable); -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const& tile_offset) { -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType* get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast& operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast operator++(int) { -+ PredicatedTileAccessIteratorResidualLast self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const& mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask& mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return iterator_.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank 2 -+/// data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ typename AccessType_> -+class PredicatedTileAccessIteratorResidualLast< -+ Shape_, -+ Element_, -+ layout::AffineRankN<2>, -+ AdvanceRank, -+ ThreadMap_, -+ AccessType_, -+ false> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::AffineRankN<2>; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element*; -+ using NonConstPointer = typename platform::remove_const::type*; -+ -+ using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates< -+ Shape, -+ Element, -+ layout::PitchLinear, -+ AdvanceRank, -+ ThreadMap, -+ AccessType>; -+ -+ static int const kAccessesPerVector = -+ ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert( -+ !(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingPredicates::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ public: -+ friend PredicatedTileAccessIteratorResidualLast; -+ -+ private: -+ /// stride of pitch-linear layout (units of Element) -+ Coord stride_; -+ /// amount (in byte) to increment pointer to move to next access along -+ /// contiguous dimension -+ LongIndex inc_contiguous_; -+ /// amount (in byte) to increment pointer from first access of current -+ /// contiguous dimension to first access of next one. -+ LongIndex inc_strided_; -+ /// amount (in byte) to increment pointer from last access of current -+ /// contiguous dimension to first access of next one. -+ LongIndex inc_next_strided_; -+ /// amount (in byte) to increment pointer from last access to first access -+ /// of next tile -+ LongIndex inc_next_; -+ /// amount (in byte) to increment pointer from first access of current tile -+ /// to first access of next tile -+ LongIndex inc_advance_; -+ -+ public: -+ // Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() -+ : stride_(0), -+ inc_contiguous_(0), -+ inc_strided_(0), -+ inc_next_(0), -+ inc_advance_(0) {} -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const& layout) -+ : stride_({layout.stride(0), layout.stride(1)}) { -+ inc_contiguous_ = -+ (LongIndex(stride_[0]) * ThreadMap::Delta::kContiguous) * -+ sizeof_bits::value / 8; -+ -+ inc_strided_ = (LongIndex(stride_[1]) * ThreadMap::Delta::kStrided) * -+ sizeof_bits::value / 8; -+ -+ inc_next_strided_ = inc_strided_ - -+ LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_; -+ -+ if (kAdvanceRank) { -+ // advance along strided dimension -+ inc_advance_ = Shape::kStrided * LongIndex(stride_[1]) * -+ sizeof_bits::value / 8; -+ } else { -+ // advance along contiguous dimension -+ inc_advance_ = -+ Shape::kContiguous * stride_[0] * sizeof_bits::value / 8; -+ } -+ -+ inc_next_ = inc_advance_ - -+ LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_ - -+ LongIndex(ThreadMap::Iterations::kStrided - 1) * inc_strided_; -+ }; -+ }; -+ -+ private: -+ /// Internal pointer type permits fast address arithmetic -+ using BytePointer = char*; -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters object with precomputed internal state -+ Params const& params_; -+ -+ /// Internal pointer to first access of tile -+ BytePointer pointer_; -+ -+ UnderlyingPredicates the_predicates; -+ Mask residual_tile_mask; -+ -+ private: -+ /// Computes predicates based on internally tracked per-thread offset. -+ CUTLASS_DEVICE -+ void compute_predicates_( -+ /// Extent of the matrix window -+ TensorCoord extent, -+ /// optionally, simplify predicate calculation during 'steady state' phase -+ bool is_steady_state = false) { -+ the_predicates.compute_predicates_(extent, is_steady_state); -+ } -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast( -+ ///< Precomputed parameters object -+ Params const& params, -+ ///< Pointer to start of tensor -+ Pointer pointer, -+ ///< Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const& threadblock_offset, -+ int const* indices = -+ nullptr ///< gather/scatter indices, note no support for -+ ///< gather/scatter at this specialization -+ ) -+ : params_(params), -+ pointer_(reinterpret_cast( -+ const_cast(pointer))), -+ the_predicates(extent) { -+ the_predicates.set_predicates(thread_id, threadblock_offset); -+ -+ // update internal pointers -+ Layout layout(params_.stride_); -+ add_pointer_offset(layout(the_predicates.thread_offset_)); -+ } -+ -+ /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock -+ /// offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast( -+ Params const& params, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileAccessIteratorResidualLast( -+ params, -+ pointer, -+ extent, -+ thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ the_predicates.set_iteration_index(index); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_residual_tile(bool is_residual_tile) { -+ if (is_residual_tile) { -+ the_predicates.set_mask(residual_tile_mask); -+ } -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += sizeof_bits::value * pointer_offset / 8; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const& tile_offset) { -+ if (kAdvanceRank) { -+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset[1]); -+ pointer_ += Shape::kContiguous * tile_offset[0]; -+ } else { -+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset[0]); -+ pointer_ += Shape::kStrided * tile_offset[1]; -+ } -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType* get() const { -+ return reinterpret_cast(pointer_) + -+ the_predicates.iteration_vector_; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast& operator++() { -+ the_predicates.operator++(); -+ ++the_predicates.iteration_vector_; -+ if (the_predicates.iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ -+ the_predicates.iteration_vector_ = 0; -+ ++the_predicates.iteration_contiguous_; -+ -+ if (the_predicates.iteration_contiguous_ < -+ ThreadMap::Iterations::kContiguous) { -+ pointer_ += params_.inc_contiguous_; -+ return *this; -+ } -+ -+ // Enter here only if (iteration_contiguous_ == -+ // ThreadMap::Iteration::kContiguous) -+ the_predicates.iteration_contiguous_ = 0; -+ ++the_predicates.iteration_strided_; -+ -+ if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ pointer_ += params_.inc_next_strided_; -+ return *this; -+ } -+ -+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) -+ // which means we enter the next tile. -+ the_predicates.iteration_strided_ = 0; -+ -+ // advance to next tile -+ pointer_ += params_.inc_next_; -+ -+ // now return to start tile - if the iterator is subsequently advanced, this -+ // subtraction as well as the subsequent integer addition are both elided by -+ // the compiler. -+ pointer_ -= params_.inc_advance_; -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast operator++(int) { -+ PredicatedTileAccessIteratorResidualLast self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ the_predicates.clear_mask(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ the_predicates.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const& mask) { -+ the_predicates.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask& mask) { -+ the_predicates.get_mask(mask); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return the_predicates.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank 2 -+/// column-major data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ typename AccessType_> -+class PredicatedTileAccessIteratorResidualLast< -+ Shape_, -+ Element_, -+ layout::AffineRank2ColumnMajor, -+ AdvanceRank, -+ ThreadMap_, -+ AccessType_, -+ false> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::AffineRank2ColumnMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element*; -+ using NonConstPointer = typename platform::remove_const::type*; -+ -+ // Map to the underlying AffineRankN<2> layout -+ using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< -+ layout::PitchLinearShape, -+ Element, -+ layout::AffineRankN<2>, -+ (kAdvanceRank == 0 ? 0 : 1), -+ ThreadMap, -+ AccessType>; -+ -+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileAccessIteratorResidualLast; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ /// Construct the Params object given an AffineRankN<2> tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const& layout) -+ : params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))){}; -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying AffineRankN<2> tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast( -+ ///< Precomputed parameters object -+ Params const& params, -+ ///< Pointer to start of tensor -+ Pointer pointer, -+ ///< Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const& threadblock_offset, -+ int const* indices = -+ nullptr ///< gather/scatter indices, note no support for -+ ///< gather/scatter at this specialization -+ ) -+ : iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord(extent.row(), extent.column()), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.row(), -+ threadblock_offset.column())) {} -+ -+ /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock -+ /// offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast( -+ Params const& params, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileAccessIteratorResidualLast( -+ params, -+ pointer, -+ extent, -+ thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ iterator_.set_iteration_index(index); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_residual_tile(bool enable) { -+ iterator_.set_residual_tile(enable); -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const& tile_offset) { -+ iterator_.add_tile_offset( -+ make_Coord(tile_offset.row(), tile_offset.column())); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType* get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast& operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast operator++(int) { -+ PredicatedTileAccessIteratorResidualLast self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const& mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask& mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return iterator_.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank-2 -+/// row-major data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ typename AccessType_> -+class PredicatedTileAccessIteratorResidualLast< -+ Shape_, -+ Element_, -+ layout::AffineRank2RowMajor, -+ AdvanceRank, -+ ThreadMap_, -+ AccessType_, -+ false> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::AffineRank2RowMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element*; -+ using NonConstPointer = typename platform::remove_const::type*; -+ -+ // Map to the underlying AffineRankN<2> layout -+ using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< -+ layout::PitchLinearShape, -+ Element, -+ layout::AffineRankN<2>, -+ (kAdvanceRank == 0 ? 1 : 0), -+ ThreadMap, -+ AccessType>; -+ -+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileAccessIteratorResidualLast; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ /// Construct the Params object given an AffineRankN<2> tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const& layout) -+ : params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))){}; -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying AffineRankN<2> tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast( -+ ///< Precomputed parameters object -+ Params const& params, -+ ///< Pointer to start of tensor -+ Pointer pointer, -+ ///< Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const& threadblock_offset, -+ int const* indices = -+ nullptr ///< gather/scatter indices, note no support for -+ ///< gather/scatter at this specialization -+ ) -+ : iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord(extent.column(), extent.row()), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.column(), -+ threadblock_offset.row())) {} -+ -+ /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock -+ /// offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast( -+ Params const& params, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileAccessIteratorResidualLast( -+ params, -+ pointer, -+ extent, -+ thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ iterator_.set_iteration_index(index); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_residual_tile(bool enable) { -+ iterator_.set_residual_tile(enable); -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const& tile_offset) { -+ iterator_.add_tile_offset( -+ make_Coord(tile_offset.column(), tile_offset.row())); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType* get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast& operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast operator++(int) { -+ PredicatedTileAccessIteratorResidualLast self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const& mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask& mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return iterator_.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIteratorResidualLast for column-major -+/// interleaved data. It is mapped to the congruous layout. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ typename AccessType_, -+ int InterleavedK> -+class PredicatedTileAccessIteratorResidualLast< -+ Shape_, -+ Element_, -+ layout::ColumnMajorInterleaved, -+ AdvanceRank, -+ ThreadMap_, -+ AccessType_, -+ false> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ static int const kInterleavedK = InterleavedK; -+ using Layout = layout::ColumnMajorInterleaved; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element*; -+ using NonConstPointer = typename platform::remove_const::type*; -+ -+ using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< -+ layout::PitchLinearShape< -+ Shape::kRow * kInterleavedK, -+ Shape::kColumn / kInterleavedK>, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 0 : 1), -+ ThreadMap, -+ AccessType>; -+ -+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileAccessIteratorResidualLast; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const& layout) -+ : params_(layout::PitchLinear(layout.stride(0))) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const& base) -+ : params_(base) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast( -+ /// Precomputed parameters object -+ Params const& params, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const& threadblock_offset, -+ int const* indices = -+ nullptr ///< gather/scatter indices, note no support for -+ ///< gather/scatter at this specialization -+ ) -+ : iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord( -+ extent.row() * kInterleavedK, -+ extent.column() / kInterleavedK), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.row() * kInterleavedK, -+ threadblock_offset.column() / kInterleavedK)) {} -+ -+ /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock -+ /// offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast( -+ Params const& params, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileAccessIteratorResidualLast( -+ params, -+ pointer, -+ extent, -+ thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ iterator_.set_iteration_index(index); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_residual_tile(bool enable) { -+ iterator_.set_residual_tile(enable); -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const& tile_offset) { -+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType* get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast& operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast operator++(int) { -+ PredicatedTileAccessIteratorResidualLast self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const& mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask& mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return iterator_.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIteratorResidualLast for row-major -+/// interleaved data. -+// It is mapped to the congruous layout. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ typename AccessType_, -+ int InterleavedK> -+class PredicatedTileAccessIteratorResidualLast< -+ Shape_, -+ Element_, -+ layout::RowMajorInterleaved, -+ AdvanceRank, -+ ThreadMap_, -+ AccessType_, -+ false> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ static int const kInterleavedK = InterleavedK; -+ using Layout = layout::RowMajorInterleaved; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element*; -+ using NonConstPointer = typename platform::remove_const::type*; -+ -+ using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< -+ layout::PitchLinearShape< -+ Shape::kColumn * kInterleavedK, -+ Shape::kRow / kInterleavedK>, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 1 : 0), -+ ThreadMap, -+ AccessType>; -+ -+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileAccessIteratorResidualLast; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const& layout) -+ : params_(layout::PitchLinear(layout.stride(0))) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const& base) -+ : params_(base) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast( -+ /// Precomputed parameters object -+ Params const& params, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const& threadblock_offset, -+ int const* indices = -+ nullptr ///< gather/scatter indices, note no support for -+ ///< gather/scatter at this specialization -+ ) -+ : iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord( -+ extent.column() * kInterleavedK, -+ extent.row() / kInterleavedK), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.column() * kInterleavedK, -+ threadblock_offset.row() / kInterleavedK)) {} -+ -+ /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock -+ /// offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast( -+ Params const& params, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileAccessIteratorResidualLast( -+ params, -+ pointer, -+ extent, -+ thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ iterator_.set_iteration_index(index); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_residual_tile(bool enable) { -+ iterator_.set_residual_tile(enable); -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const& tile_offset) { -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType* get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast& operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorResidualLast operator++(int) { -+ PredicatedTileAccessIteratorResidualLast self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const& mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask& mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return iterator_.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/iterators/predicated_tile_iterator_residual_last.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/iterators/predicated_tile_iterator_residual_last.h -new file mode 100644 -index 0000000..4bb96a1 ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/iterators/predicated_tile_iterator_residual_last.h -@@ -0,0 +1,2120 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights -+ *reserved. SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, -+ *this list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -+ *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE -+ *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -+ *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -+ *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -+ *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -+ *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -+ *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -+ *POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of tiles from pitch-linear rank=2 -+ tensors. -+ -+ This iterator uses masks to guard out-of-bounds accesses. The first tile -+ this iterator visits maybe partial, then the remaining tiles are complete. -+ So, we only need to compute the predicates twice, once before the first tile -+ and once for the remaining full tiles which can share the same predicates. -+ -+ A precomputed "Params" object minimizes the amount of state that must be -+ stored in registers, and integer addition is used to advance the pointer -+ through memory. -+*/ -+ -+#pragma once -+ -+#include "cutlass/arch/memory.h" -+#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// PredicatedTileIteratorResidualLast -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+/// Regular tile iterator using a precomputed control structure to minimize -+/// register liveness and integer arithmetic. -+/// -+/// Layout is assumed to be invariant at the time the precomputed "Params" -+/// object is constructed. -+/// -+/// Base pointer and tensor extents may be specified at the time the iterator is -+/// constructed. Subsequently, they are assumed to be immutable. -+/// -+/// Adding a logical coordinate offset may be performed at the time the iterator -+/// is constructed. Subsequent additions to logical coordinate offset may be -+/// performed but are relatively expensive. -+/// -+/// Visitation order is intended to first visit a "residual" tile that may be -+/// partially full in both the advance dimension and the steady-state dimension. -+/// This is assumed to be the last tile in the iteration sequence. Advancing an -+/// iterator that has just been constructed moves to the first tile that is full -+/// in the advance dimension and recomputes predicates. Subsequent accesses may -+/// be performed without updating internal predicates and are efficient in terms -+/// of live register state and pointer arithmetic instructions. -+/// -+/// To be efficient, this assumes the iterator will be dereferenced and advanced -+/// at least once outside any looping structure to minimize integer arithmetic. -+/// -+/// Acceses out of bounds are safe so long as `clear_mask()` is called prior to -+/// dereferencing the iterator. -+/// -+/// -+/// Example: -+/// -+/// An efficient pipeline structure may be constructed as follows: -+/// -+// template -+// __global__ void kernel( -+// typename Iterator::Params params, -+// typename Iterator::Element *ptr, -+// TensorCoord extent) { -+// -+// typename Iterator::Fragment fragment; -+// -+// TensorCoord threadblock_offset(0, 0); -+// -+// Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offsets); -+// -+// -+// fragment = *iter; // load "residue" tile first -+// ++iter; // advance to first "steady state" tile and update -+// internal masks -+// -+// -+// #pragma unroll -+// for (int i = Remaining - 1; i >= 0; --i) { -+// -+// f(fragment); -+// -+// if (!i) { -+// iter.clear_mask(); // light-weight operation to clear masks - -+// subsequent loads become NO-OPs. -+// } -+// -+// fragment = *iter; // load tile during "steady state" phase -+// ++iter; // advance to next tile - lightweight due to -+// steady-state masks -+// } -+// } -+// -+// void host(TensorView view) { -+// -+// using Iterator = -+// transform::threadblock::PredicatedTileIteratorResidualLast; -+// -+// typename Iterator::Params params(view.layout()); -+// -+// kernel(params, view.data()); -+// } -+/// -+/// -+template < -+ typename Shape, -+ typename Element, -+ typename Layout, -+ int AdvanceRank, -+ typename ThreadMap, -+ int AccessSize = ThreadMap::kElementsPerAccess, -+ bool Gather = false> -+class PredicatedTileIteratorResidualLast; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int AccessSize, -+ bool Gather> -+class PredicatedTileIteratorResidualLast< -+ Shape_, -+ Element_, -+ layout::PitchLinear, -+ AdvanceRank, -+ ThreadMap_, -+ AccessSize, -+ Gather> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element*; -+ using NonConstPointer = typename platform::remove_const::type*; -+ -+ /// Type used for internal memory accesses -+ using AccessType = AlignedArray< -+ Element, -+ AccessSize, -+ (AccessSize * sizeof_bits::value / 8)>; -+ -+ /// Underlying iterator to compute the addresses -+ using TileAccessIterator = PredicatedTileAccessIteratorResidualLast< -+ Shape, -+ Element, -+ Layout, -+ kAdvanceRank, -+ ThreadMap, -+ AccessType, -+ Gather>; -+ -+ static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array< -+ Element, -+ ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename TileAccessIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ public: -+ using Base = typename TileAccessIterator::Params::Base; -+ -+ friend PredicatedTileIteratorResidualLast; -+ -+ private: -+ /// Parameters object -+ typename TileAccessIterator::Params params_; -+ -+ public: -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const& layout) : params_(layout) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(Base const& base) : params_(base) {} -+ }; -+ -+ private: -+ /// Internal pointer type permits fast address arithmetic -+ using BytePointer = char*; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Data member to the tile access iterator -+ TileAccessIterator address_iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast( -+ /// Precomputed parameters object -+ Params const& params, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const& threadblock_offset, -+ /// Gather indices -+ int const* indices = nullptr) -+ : address_iterator_( -+ params.params_, -+ pointer, -+ extent, -+ thread_id, -+ threadblock_offset, -+ indices) {} -+ -+ /// Construct a PredicatedTileIteratorResidualLast with zero threadblock -+ /// offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast( -+ Params const& params, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileIteratorResidualLast( -+ params, -+ pointer, -+ extent, -+ thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ address_iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast& operator++() { -+ if (kAdvanceRank) -+ address_iterator_.add_tile_offset({0, 1}); -+ else -+ address_iterator_.add_tile_offset({1, 0}); -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast operator++(int) { -+ PredicatedTileIteratorResidualLast self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ address_iterator_.clear_mask(enable); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_residual_tile(bool enable) { -+ address_iterator_.set_residual_tile(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ address_iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const& mask) { -+ address_iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask& mask) { -+ address_iterator_.get_mask(mask); -+ } -+ -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { -+ load_with_byte_offset( -+ frag, pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { -+ AccessType* frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ int idx = v + -+ kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); -+ -+ address_iterator_.set_iteration_index(idx); -+ char const* byte_ptr = -+ reinterpret_cast(address_iterator_.get()) + -+ byte_offset; -+ -+ AccessType const* access_ptr = -+ reinterpret_cast(byte_ptr); -+ -+ cutlass::arch::global_load( -+ frag_ptr[idx], access_ptr, address_iterator_.valid()); -+ -+ ++address_iterator_; -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment& frag) { -+ load_with_byte_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { -+ store_with_byte_offset( -+ frag, pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { -+ address_iterator_.set_iteration_index(0); -+ AccessType const* frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ int idx = v + -+ kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); -+ -+ char* byte_ptr = -+ reinterpret_cast(address_iterator_.get()) + byte_offset; -+ AccessType* access_ptr = reinterpret_cast(byte_ptr); -+ -+ if (address_iterator_.valid()) { -+ *access_ptr = frag_ptr[idx]; -+ } -+ ++address_iterator_; -+ } -+ } -+ } -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const& frag) { -+ store_with_byte_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int AccessSize, -+ bool Gather> -+class PredicatedTileIteratorResidualLast< -+ Shape_, -+ Element_, -+ layout::ColumnMajor, -+ AdvanceRank, -+ ThreadMap_, -+ AccessSize, -+ Gather> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element*; -+ using NonConstPointer = typename platform::remove_const::type*; -+ -+ using UnderlyingIterator = PredicatedTileIteratorResidualLast< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 0 : 1), -+ ThreadMap, -+ AccessSize, -+ Gather>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array< -+ Element, -+ ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileIteratorResidualLast; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const& layout) -+ : params_(layout::PitchLinear(layout.stride(0))) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const& base) -+ : params_(base) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast( -+ Params const& params, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id, ///< ID of each participating thread -+ TensorCoord const& threadblock_offset, ///< Initial offset of threadblock -+ int const* indices = -+ nullptr ///< gather/scatter indices, note no support for -+ ///< gather/scatter at this specialization -+ ) -+ : iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord(extent.row(), extent.column()), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.row(), -+ threadblock_offset.column()), -+ indices) {} -+ -+ /// Construct a PredicatedTileIteratorResidualLast with zero threadblock -+ /// offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast( -+ Params const& params, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileIteratorResidualLast( -+ params, -+ pointer, -+ extent, -+ thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast& operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast operator++(int) { -+ PredicatedTileIteratorResidualLast self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_residual_tile(bool enable) { -+ iterator_.set_residual_tile(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const& mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask& mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment& frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { -+ iterator_.store_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const& frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int AccessSize, -+ bool Gather> -+class PredicatedTileIteratorResidualLast< -+ Shape_, -+ Element_, -+ layout::RowMajor, -+ AdvanceRank, -+ ThreadMap_, -+ AccessSize, -+ Gather> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element*; -+ using NonConstPointer = typename platform::remove_const::type*; -+ -+ using UnderlyingIterator = PredicatedTileIteratorResidualLast< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 1 : 0), -+ ThreadMap, -+ AccessSize, -+ Gather>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array< -+ Element, -+ ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileIteratorResidualLast; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const& layout) -+ : params_(layout::PitchLinear(layout.stride(0))) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const& base) -+ : params_(base) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast( -+ Params const& params, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id, ///< ID of each participating thread -+ TensorCoord const& threadblock_offset, ///< Initial offset of threadblock -+ int const* indices = nullptr ///< Gather indices -+ ) -+ : iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord(extent.column(), extent.row()), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.column(), -+ threadblock_offset.row()), -+ indices) {} -+ -+ /// Construct a PredicatedTileIteratorResidualLast with zero threadblock -+ /// offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast( -+ Params const& params, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileIteratorResidualLast( -+ params, -+ pointer, -+ extent, -+ thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast& operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast operator++(int) { -+ PredicatedTileIteratorResidualLast self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_residual_tile(bool enable) { -+ iterator_.set_residual_tile(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const& mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask& mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment& frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { -+ iterator_.store_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const& frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIteratorResidualLast for affine rank-2 data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int AccessSize> -+class PredicatedTileIteratorResidualLast< -+ Shape_, -+ Element_, -+ layout::AffineRankN<2>, -+ AdvanceRank, -+ ThreadMap_, -+ AccessSize, -+ false> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::AffineRankN<2>; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element*; -+ using NonConstPointer = typename platform::remove_const::type*; -+ -+ /// Type used for internal memory accesses -+ using AccessType = AlignedArray< -+ Element, -+ AccessSize, -+ (AccessSize * sizeof_bits::value / 8)>; -+ -+ /// Underlying iterator to compute the addresses -+ using TileAccessIterator = PredicatedTileAccessIteratorResidualLast< -+ Shape, -+ Element, -+ Layout, -+ kAdvanceRank, -+ ThreadMap, -+ AccessType>; -+ -+ static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array< -+ Element, -+ ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename TileAccessIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ public: -+ friend PredicatedTileIteratorResidualLast; -+ -+ private: -+ /// Parameters object -+ typename TileAccessIterator::Params params_; -+ -+ public: -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const& layout) : params_(layout) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ }; -+ -+ private: -+ /// Internal pointer type permits fast address arithmetic -+ using BytePointer = char*; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Data member to the tile access iterator -+ TileAccessIterator address_iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast( -+ /// Precomputed parameters object -+ Params const& params, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const& threadblock_offset, -+ int const* indices = -+ nullptr ///< gather/scatter indices, note no support for -+ ///< gather/scatter at this specialization -+ ) -+ : address_iterator_( -+ params.params_, -+ pointer, -+ extent, -+ thread_id, -+ threadblock_offset) {} -+ -+ /// Construct a PredicatedTileIteratorResidualLast with zero threadblock -+ /// offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast( -+ Params const& params, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileIteratorResidualLast( -+ params, -+ pointer, -+ extent, -+ thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ address_iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast& operator++() { -+ if (kAdvanceRank) -+ address_iterator_.add_tile_offset(make_Coord(0, 1)); -+ else -+ address_iterator_.add_tile_offset(make_Coord(1, 0)); -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast operator++(int) { -+ PredicatedTileIteratorResidualLast self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ address_iterator_.clear_mask(enable); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_residual_tile(bool enable) { -+ address_iterator_.set_residual_tile(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ address_iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const& mask) { -+ address_iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask& mask) { -+ address_iterator_.get_mask(mask); -+ } -+ -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { -+ load_with_byte_offset( -+ frag, pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { -+ AccessType* frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ int idx = v + -+ kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); -+ -+ address_iterator_.set_iteration_index(idx); -+ char const* byte_ptr = -+ reinterpret_cast(address_iterator_.get()) + -+ byte_offset; -+ -+ AccessType const* access_ptr = -+ reinterpret_cast(byte_ptr); -+ -+ cutlass::arch::global_load( -+ frag_ptr[idx], access_ptr, address_iterator_.valid()); -+ -+ ++address_iterator_; -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment& frag) { -+ load_with_byte_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { -+ store_with_byte_offset( -+ frag, pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { -+ address_iterator_.set_iteration_index(0); -+ AccessType const* frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ int idx = v + -+ kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); -+ -+ char* byte_ptr = -+ reinterpret_cast(address_iterator_.get()) + byte_offset; -+ AccessType* access_ptr = reinterpret_cast(byte_ptr); -+ -+ if (address_iterator_.valid()) { -+ *access_ptr = frag_ptr[idx]; -+ } -+ ++address_iterator_; -+ } -+ } -+ } -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const& frag) { -+ store_with_byte_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIteratorResidualLast for affine rank 2 -+/// column-major data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int AccessSize> -+class PredicatedTileIteratorResidualLast< -+ Shape_, -+ Element_, -+ layout::AffineRank2ColumnMajor, -+ AdvanceRank, -+ ThreadMap_, -+ AccessSize, -+ false> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::AffineRank2ColumnMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element*; -+ using NonConstPointer = typename platform::remove_const::type*; -+ -+ // Map to the underlying AffineRankN<2> layout -+ using UnderlyingIterator = PredicatedTileIteratorResidualLast< -+ layout::PitchLinearShape, -+ Element, -+ layout::AffineRankN<2>, -+ (kAdvanceRank == 0 ? 0 : 1), -+ ThreadMap, -+ AccessSize>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array< -+ Element, -+ ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileIteratorResidualLast; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ /// Construct the Params object given an AffineRankN<2> tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const& layout) -+ : params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying AffineRankN<2> tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast( -+ Params const& params, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id, ///< ID of each participating thread -+ TensorCoord const& threadblock_offset, ///< Initial offset of threadblock -+ int const* indices = -+ nullptr ///< gather/scatter indices, note no support for -+ ///< gather/scatter at this specialization -+ ) -+ : iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord(extent.row(), extent.column()), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.row(), -+ threadblock_offset.column())) {} -+ -+ /// Construct a PredicatedTileIteratorResidualLast with zero threadblock -+ /// offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast( -+ Params const& params, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileIteratorResidualLast( -+ params, -+ pointer, -+ extent, -+ thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast& operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast operator++(int) { -+ PredicatedTileIteratorResidualLast self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_residual_tile(bool enable) { -+ iterator_.set_residual_tile(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const& mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask& mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment& frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { -+ iterator_.store_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const& frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIteratorResidualLast for affine rank 2 -+/// row-major data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int AccessSize> -+class PredicatedTileIteratorResidualLast< -+ Shape_, -+ Element_, -+ layout::AffineRank2RowMajor, -+ AdvanceRank, -+ ThreadMap_, -+ AccessSize, -+ false> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::AffineRank2RowMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element*; -+ using NonConstPointer = typename platform::remove_const::type*; -+ -+ // Map to the underlying AffineRankN<2> layout -+ using UnderlyingIterator = PredicatedTileIteratorResidualLast< -+ layout::PitchLinearShape, -+ Element, -+ layout::AffineRankN<2>, -+ (kAdvanceRank == 0 ? 1 : 0), -+ ThreadMap, -+ AccessSize>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array< -+ Element, -+ ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileIteratorResidualLast; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ /// Construct the Params object given an AffineRankN<2> tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const& layout) -+ : params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying AffineRankN<2> tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast( -+ Params const& params, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id, ///< ID of each participating thread -+ TensorCoord const& threadblock_offset, ///< Initial offset of threadblock -+ int const* indices = -+ nullptr ///< gather/scatter indices, note no support for -+ ///< gather/scatter at this specialization -+ ) -+ : iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord(extent.column(), extent.row()), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.column(), -+ threadblock_offset.row())) {} -+ -+ /// Construct a PredicatedTileIteratorResidualLast with zero threadblock -+ /// offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast( -+ Params const& params, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileIteratorResidualLast( -+ params, -+ pointer, -+ extent, -+ thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast& operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast operator++(int) { -+ PredicatedTileIteratorResidualLast self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_residual_tile(bool enable) { -+ iterator_.set_residual_tile(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const& mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask& mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment& frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { -+ iterator_.store_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const& frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIteratorResidualLast for interleaved data. -+/// It is mapped to the congruous layout. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int AccessSize, -+ int InterleavedK> -+class PredicatedTileIteratorResidualLast< -+ Shape_, -+ Element_, -+ layout::ColumnMajorInterleaved, -+ AdvanceRank, -+ ThreadMap_, -+ AccessSize, -+ false> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ static int const kInterleavedK = InterleavedK; -+ using Layout = layout::ColumnMajorInterleaved; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element*; -+ using NonConstPointer = typename platform::remove_const::type*; -+ -+ using UnderlyingIterator = PredicatedTileIteratorResidualLast< -+ layout::PitchLinearShape< -+ Shape::kRow * kInterleavedK, -+ Shape::kColumn / kInterleavedK>, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 0 : 1), -+ ThreadMap, -+ AccessSize>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array< -+ Element, -+ ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileIteratorResidualLast; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const& layout) -+ : params_(layout::PitchLinear(layout.stride(0))) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const& base) -+ : params_(base) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast( -+ /// Precomputed parameters object -+ Params const& params, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const& threadblock_offset, -+ int const* indices = -+ nullptr ///< gather/scatter indices, note no support for -+ ///< gather/scatter at this specialization -+ ) -+ : iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord( -+ extent.row() * kInterleavedK, -+ extent.column() / kInterleavedK), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.row() * kInterleavedK, -+ threadblock_offset.column() / kInterleavedK)) {} -+ -+ /// Construct a PredicatedTileIteratorResidualLast with zero threadblock -+ /// offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast( -+ Params const& params, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileIteratorResidualLast( -+ params, -+ pointer, -+ extent, -+ thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast& operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast operator++(int) { -+ PredicatedTileIteratorResidualLast self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_residual_tile(bool enable) { -+ iterator_.set_residual_tile(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const& mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask& mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment& frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const& frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIteratorResidualLast for interleaved-32 -+/// data. It is mapped to the congruous layout. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int AccessSize, -+ int InterleavedK> -+class PredicatedTileIteratorResidualLast< -+ Shape_, -+ Element_, -+ layout::RowMajorInterleaved, -+ AdvanceRank, -+ ThreadMap_, -+ AccessSize, -+ false> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ static int const kInterleavedK = InterleavedK; -+ using Layout = layout::RowMajorInterleaved; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element*; -+ using NonConstPointer = typename platform::remove_const::type*; -+ -+ using UnderlyingIterator = PredicatedTileIteratorResidualLast< -+ layout::PitchLinearShape< -+ Shape::kColumn * kInterleavedK, -+ Shape::kRow / kInterleavedK>, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 1 : 0), -+ ThreadMap, -+ AccessSize>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array< -+ Element, -+ ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileIteratorResidualLast; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const& layout) -+ : params_(layout::PitchLinear(layout.stride(0))) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const& base) -+ : params_(base) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast( -+ /// Precomputed parameters object -+ Params const& params, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const& threadblock_offset, -+ int const* indices = -+ nullptr ///< gather/scatter indices, note no support for -+ ///< gather/scatter at this specialization -+ ) -+ : iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord( -+ extent.column() * kInterleavedK, -+ extent.row() / kInterleavedK), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.column() * kInterleavedK, -+ threadblock_offset.row() / kInterleavedK)) {} -+ -+ /// Construct a PredicatedTileIteratorResidualLast with zero threadblock -+ /// offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast( -+ Params const& params, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileIteratorResidualLast( -+ params, -+ pointer, -+ extent, -+ thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast& operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorResidualLast operator++(int) { -+ PredicatedTileIteratorResidualLast self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_residual_tile(bool enable) { -+ iterator_.set_residual_tile(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const& mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask& mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment& frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const& frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/kernel_forward.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/kernel_forward.h -new file mode 100644 -index 0000000..6f5eb3f ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/kernel_forward.h -@@ -0,0 +1,1108 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holdvr nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#ifdef HAS_PYTORCH -+#include -+#include -+#include -+#include -+#endif -+ -+#include -+#include -+ -+#include "cutlass/bfloat16.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/vector.h" -+#include "cutlass/numeric_types.h" -+ -+#include "attention_scaling_coefs_updater.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+#include "cutlass/gemm/kernel/default_gemm.h" -+#include "cutlass/gemm/threadblock/default_mma.h" -+#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/platform/platform.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+#include "debug_utils.h" -+#include "epilogue_pipelined.h" -+#include "epilogue_rescale_output.h" -+#include "find_default_mma.h" -+#include "gemm_kernel_utils.h" -+#include "mma_from_smem.h" -+#include "transform/tile_smem_loader.h" -+ -+#include -+ -+using namespace gemm_kernel_utils; -+ -+namespace { -+template -+constexpr int getWarpsPerSm() { -+ return ( -+ Arch::kMinComputeCapability >= 80 && -+ !cutlass::platform::is_same::value -+ ? 16 -+ : 12); -+} -+} // namespace -+ -+template < -+ // The datatype of Q/K/V -+ typename scalar_t_, -+ // Architecture we are targeting (eg `cutlass::arch::Sm80`) -+ typename ArchTag, -+ // If Q/K/V are correctly aligned in memory and we can run a fast kernel -+ bool isAligned_, -+ int kQueriesPerBlock, -+ int kKeysPerBlock, -+ bool kSingleValueIteration // = `value.shape[-1] <= kKeysPerBlock` -+ > -+struct AttentionKernel { -+ using scalar_t = scalar_t_; -+ using accum_t = float; -+ using lse_scalar_t = float; -+ using output_t = scalar_t; -+ // Accumulator between 2 iterations -+ // Using `accum_t` improves perf on f16 at the cost of -+ // numerical errors -+ using output_accum_t = accum_t; -+ static constexpr bool kIsAligned = isAligned_; -+ static constexpr int32_t kAlignLSE = 32; // block size of backward -+ static constexpr bool kPreloadV = ArchTag::kMinComputeCapability >= 80 && -+ cutlass::sizeof_bits::value == 16; -+ static constexpr bool kKeepOutputInRF = kSingleValueIteration; -+ static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF && -+ !cutlass::platform::is_same::value; -+ -+ static_assert(kQueriesPerBlock % 32 == 0, ""); -+ static_assert(kKeysPerBlock % 32 == 0, ""); -+ static constexpr int kNumWarpsPerBlock = -+ kQueriesPerBlock * kKeysPerBlock / (32 * 32); -+ static constexpr int kWarpSize = 32; -+ -+ // Launch bounds -+ static constexpr int kNumThreads = kWarpSize * kNumWarpsPerBlock; -+ static constexpr int kMinBlocksPerSm = -+ getWarpsPerSm() / kNumWarpsPerBlock; -+ -+ struct Params { -+ // Input tensors -+ scalar_t* query_ptr; // [num_queries, num_heads, head_dim] -+ scalar_t* key_ptr; // [num_keys, num_heads, head_dim] -+ scalar_t* value_ptr; // [num_keys, num_heads, head_dim_value] -+ int32_t* cu_seqlens_q_ptr = nullptr; -+ int32_t* cu_seqlens_k_ptr = nullptr; -+ scalar_t* attn_mask_ptr = nullptr; // [num_queries, num_keys] -+ scalar_t* attn_bias_ptr = nullptr; // [num_heads, num_queries, num_keys] -+ -+ // Output tensors -+ output_t* output_ptr; // [num_queries, num_heads, head_dim_value] -+ output_accum_t* -+ output_accum_ptr; // [num_queries, num_heads, head_dim_value] -+ lse_scalar_t* logsumexp_ptr; // [num_heads, num_queries] - can be null -+ float scale; -+ // Dimensions/strides -+ int32_t head_dim; -+ int32_t head_dim_value; -+ int32_t num_queries; -+ int32_t num_keys; -+ -+ bool causal; -+ bool no_bias_head_dim; -+ bool use_past; -+ -+ int32_t q_strideM; -+ int32_t k_strideM; -+ int32_t v_strideM; -+ int32_t attn_mask_strideM; -+ int32_t attn_bias_strideM; -+ -+ // Everything below is only used in `advance_to_block` -+ // and shouldn't use registers -+ int32_t q_strideH; -+ int32_t k_strideH; -+ int32_t v_strideH; -+ int32_t o_strideH; -+ int32_t attn_mask_strideH; -+ int32_t attn_bias_strideH; -+ int64_t q_strideB; -+ int64_t k_strideB; -+ int64_t v_strideB; -+ int64_t o_strideB; -+ int64_t attn_mask_strideB; -+ int64_t attn_bias_strideB; -+ int32_t num_batches; -+ int32_t num_heads; -+ -+ CUTLASS_HOST_DEVICE int32_t o_strideM() const { -+ return head_dim_value * num_heads; -+ } -+ -+ // Moves pointers to what we should process -+ // Returns "false" if there is no work to do -+ CUTLASS_DEVICE bool advance_to_block() { -+ auto batch_id = blockIdx.z; -+ auto head_id = blockIdx.y; -+ auto query_start = blockIdx.x * kQueriesPerBlock; -+ -+ auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE; -+ -+ int64_t q_start, k_start; -+ // Advance to current batch - in case of different sequence lengths -+ if (cu_seqlens_q_ptr != nullptr) { -+ assert(cu_seqlens_k_ptr != nullptr); -+ if (cu_seqlens_q_ptr[batch_id] == -1) return false; -+ q_strideH = q_strideM * cu_seqlens_q_ptr[batch_id]; -+ if (!use_past) { -+ k_strideH = k_strideM * cu_seqlens_k_ptr[batch_id]; -+ v_strideH = v_strideM * cu_seqlens_k_ptr[batch_id]; -+ } -+ num_queries = cu_seqlens_q_ptr[batch_id]; -+ num_keys = cu_seqlens_k_ptr[batch_id]; -+ for (int i = 0; i < batch_id; i++) -+ { -+ if (cu_seqlens_q_ptr[i] == -1) continue; -+ query_ptr += cu_seqlens_q_ptr[i] * head_dim * num_heads; -+ output_ptr += cu_seqlens_q_ptr[i] * head_dim * num_heads; -+ if (!use_past) { -+ key_ptr += cu_seqlens_k_ptr[i] * head_dim * num_heads; -+ value_ptr += cu_seqlens_k_ptr[i] * head_dim * num_heads; -+ } -+ } -+ if (use_past) { -+ key_ptr += batch_id * k_strideB; -+ value_ptr += batch_id * v_strideB; -+ } -+ if (query_start >= num_queries) { -+ return false; -+ } -+ q_start = 0; -+ k_start = 0; -+ } else { -+ query_ptr += batch_id * q_strideB; -+ key_ptr += batch_id * k_strideB; -+ value_ptr += batch_id * v_strideB; -+ output_ptr += batch_id * o_strideB; -+ if (output_accum_ptr != nullptr) { -+ output_accum_ptr += batch_id * o_strideB; -+ } -+ q_start = 0; -+ k_start = 0; -+ } -+ if (attn_mask_ptr) { -+ attn_mask_ptr += batch_id * (attn_mask_strideB); -+ } -+ if (attn_bias_ptr) { -+ attn_bias_ptr += head_id * attn_bias_strideH; -+ } -+ // Advance to the current batch / head / query_start -+ query_ptr += (q_start + query_start) * q_strideM + head_id * q_strideH; -+ key_ptr += k_start * k_strideM + head_id * k_strideH; -+ value_ptr += k_start * v_strideM + head_id * v_strideH; -+ output_ptr += int64_t(q_start + query_start) * o_strideM() + -+ head_id * o_strideH; -+ -+ if (output_accum_ptr != nullptr) { -+ output_accum_ptr += int64_t(q_start + query_start) * o_strideM() + -+ head_id * o_strideH; -+ } else { -+ // Accumulate directly in the destination buffer (eg for f32) -+ output_accum_ptr = (accum_t*)output_ptr; -+ } -+ if (logsumexp_ptr != nullptr) { -+ // lse[batch_id, head_id, query_start] -+ logsumexp_ptr += -+ batch_id * lse_dim * num_heads + head_id * lse_dim + query_start; -+ } -+ -+ num_queries -= query_start; -+ if (causal) { -+ num_keys = cutlass::fast_min( -+ int32_t(query_start + kQueriesPerBlock), num_keys); -+ } -+ num_batches = 0; // no longer used after -+ -+ // Make sure the compiler knows these variables are the same on all -+ // the threads of the warp. -+ if (attn_mask_ptr) { -+ attn_mask_ptr = warp_uniform(attn_mask_ptr); -+ } -+ if (attn_bias_ptr) { -+ attn_bias_ptr = warp_uniform(attn_bias_ptr); -+ } -+ query_ptr = warp_uniform(query_ptr); -+ key_ptr = warp_uniform(key_ptr); -+ value_ptr = warp_uniform(value_ptr); -+ output_ptr = warp_uniform(output_ptr); -+ output_accum_ptr = warp_uniform(output_accum_ptr); -+ logsumexp_ptr = warp_uniform(logsumexp_ptr); -+ num_queries = warp_uniform(num_queries); -+ num_keys = warp_uniform(num_keys); -+ head_dim = warp_uniform(head_dim); -+ head_dim_value = warp_uniform(head_dim_value); -+ return true; -+ } -+ -+ __host__ dim3 getBlocksGrid() const { -+ return dim3( -+ ceil_div(num_queries, (int32_t)kQueriesPerBlock), -+ num_heads, -+ num_batches); -+ } -+ __host__ dim3 getThreadsGrid() const { -+ return dim3(kWarpSize, kNumWarpsPerBlock, 1); -+ } -+ }; -+ -+ struct MM0 { -+ /* -+ In this first matmul, we compute a block of `Q @ K.T`. -+ While the calculation result is still hot in registers, we update -+ `mi`, `m_prime`, `s_prime` in shared-memory, and then store this value -+ into a shared-memory ("AccumulatorSharedStorage") that is used later as -+ operand A for the second matmul (see MM1) -+ */ -+ using GemmType = DefaultGemmType; -+ -+ using OpClass = typename GemmType::OpClass; -+ using DefaultConfig = -+ typename cutlass::gemm::device::DefaultGemmConfiguration< -+ OpClass, -+ ArchTag, -+ scalar_t, -+ scalar_t, -+ scalar_t, // ElementC -+ accum_t // ElementAccumulator -+ >; -+ static constexpr int kAlignmentA = -+ kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment; -+ static constexpr int kAlignmentB = -+ kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; -+ using ThreadblockShape = cutlass::gemm:: -+ GemmShape; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; -+ using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma< -+ scalar_t, // ElementA, -+ cutlass::layout::RowMajor, // LayoutA, -+ kAlignmentA, -+ scalar_t, // ElementB, -+ cutlass::layout::ColumnMajor, // LayoutB, -+ kAlignmentB, -+ accum_t, -+ cutlass::layout::RowMajor, // LayoutC, -+ OpClass, -+ ArchTag, // ArchTag -+ ThreadblockShape, // ThreadblockShape -+ WarpShape, // WarpShape -+ typename GemmType::InstructionShape, // InstructionShape -+ DefaultConfig::kStages, // Should use `DefaultConfig::kStages`, but that -+ // uses too much smem -+ typename GemmType::Operator // Operator -+ >::DefaultMma; -+ using MmaCore = typename DefaultMma::MmaCore; -+ using IteratorA = typename DefaultMma::IteratorA; -+ using IteratorB = typename DefaultMma::IteratorB; -+ using Mma = typename DefaultMma::ThreadblockMma; -+ using ScalingCoefsUpdater = typename DefaultAttentionScalingCoefsUpdater< -+ typename Mma::Operator::IteratorC, -+ accum_t, -+ kWarpSize>::Updater; -+ static_assert( -+ MmaCore::WarpCount::kM * MmaCore::WarpCount::kN * -+ MmaCore::WarpCount::kK == -+ kNumWarpsPerBlock, -+ ""); -+ -+ // used for efficient load of mask tile Mij from global to shared memory -+ using MaskLoader = TileSmemLoader< -+ scalar_t, -+ cutlass::MatrixShape, -+ MmaCore::kThreads, -+ // input restriction: kv_len has to be a multiple of this value -+ 128 / cutlass::sizeof_bits::value>; -+ -+ // used for efficient load of mask tile Mij from global to shared memory -+ using BiasLoader = TileSmemLoader< -+ scalar_t, -+ cutlass::MatrixShape, -+ MmaCore::kThreads, -+ // input restriction: kv_len has to be a multiple of this value -+ 128 / cutlass::sizeof_bits::value>; -+ -+ // Epilogue to store to shared-memory in a format that we can use later for -+ // the second matmul -+ using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm< -+ typename Mma::Operator::IteratorC, -+ typename Mma::Operator, -+ scalar_t, -+ WarpShape, -+ ThreadblockShape>; -+ using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; -+ }; -+ -+ struct MM1 { -+ /** -+ Second matmul: perform `attn @ V` where `attn` is the attention (not -+ normalized) and stored in shared memory -+ */ -+ using GemmType = DefaultGemmType; -+ -+ using OpClass = typename GemmType::OpClass; -+ using DefaultConfig = -+ typename cutlass::gemm::device::DefaultGemmConfiguration< -+ OpClass, -+ ArchTag, -+ scalar_t, -+ scalar_t, -+ output_accum_t, // ElementC -+ accum_t // ElementAccumulator -+ >; -+ static constexpr int kAlignmentA = DefaultConfig::kAlignmentA; // from smem -+ static constexpr int kAlignmentB = -+ kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; -+ using ThreadblockShape = cutlass::gemm:: -+ GemmShape; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; -+ using InstructionShape = typename GemmType::InstructionShape; -+ -+ using LayoutB = cutlass::layout::RowMajor; -+ using DefaultGemm = cutlass::gemm::kernel::DefaultGemm< -+ scalar_t, // ElementA, -+ cutlass::layout::RowMajor, // LayoutA, -+ kAlignmentA, -+ scalar_t, // ElementB, -+ LayoutB, // LayoutB, -+ kAlignmentB, -+ output_accum_t, -+ cutlass::layout::RowMajor, // LayoutC, -+ accum_t, -+ OpClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ typename GemmType::InstructionShape, -+ typename DefaultConfig::EpilogueOutputOp, -+ void, // ThreadblockSwizzle - not used -+ DefaultConfig::kStages, -+ false, // SplitKSerial -+ typename GemmType::Operator>; -+ -+ using DefaultMmaFromSmem = -+ typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< -+ typename DefaultGemm::Mma, -+ typename MM0::AccumulatorSharedStorage>; -+ using Mma = typename DefaultMmaFromSmem::Mma; -+ using IteratorB = typename Mma::IteratorB; -+ using WarpCount = typename Mma::WarpCount; -+ static_assert( -+ WarpCount::kM * WarpCount::kN * WarpCount::kK == kNumWarpsPerBlock, -+ ""); -+ -+ using DefaultEpilogue = typename DefaultGemm::Epilogue; -+ using OutputTileIterator = -+ typename cutlass::epilogue::threadblock::PredicatedTileIterator< -+ typename DefaultEpilogue::OutputTileIterator::ThreadMap, -+ output_t>; -+ using OutputTileIteratorAccum = -+ typename cutlass::epilogue::threadblock::PredicatedTileIterator< -+ typename DefaultEpilogue::OutputTileIterator::ThreadMap, -+ output_accum_t>; -+ -+ struct SharedStorageMM1 { -+ typename Mma::SharedStorage mm; -+ }; -+ }; -+ -+ static constexpr int64_t kAlignmentQ = MM0::kAlignmentA; -+ static constexpr int64_t kAlignmentK = MM0::kAlignmentB; -+ static constexpr int64_t kAlignmentV = 1; -+ -+ // Shared storage - depends on kernel params -+ struct ScalingCoefs { -+ cutlass::Array m_prime; -+ cutlass::Array s_prime; -+ cutlass::Array mi; -+ }; -+ -+ struct SharedStorageEpilogueAtEnd : ScalingCoefs { -+ struct SharedStorageAfterMM0 { -+ // Everything here might be overwritten during MM0 -+ union { -+ typename MM0::MaskLoader::SmemTile mask; -+ typename MM0::BiasLoader::SmemTile bias; -+ typename MM0::AccumulatorSharedStorage si; -+ }; -+ typename MM1::SharedStorageMM1 mm1; -+ }; -+ -+ union { -+ typename MM0::Mma::SharedStorage mm0; -+ SharedStorageAfterMM0 after_mm0; -+ typename MM1::DefaultEpilogue::SharedStorage epilogue; -+ }; -+ -+ CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& -+ epilogue_shared_storage() { -+ return epilogue; -+ } -+ }; -+ -+ struct SharedStorageEpilogueInLoop : ScalingCoefs { -+ struct SharedStorageAfterMM0 { -+ // Everything here might be overwritten during MM0 -+ union { -+ typename MM0::MaskLoader::SmemTile mask; -+ typename MM0::BiasLoader::SmemTile bias; -+ typename MM0::AccumulatorSharedStorage si; -+ }; -+ typename MM1::SharedStorageMM1 mm1; -+ typename MM1::DefaultEpilogue::SharedStorage epilogue; -+ }; -+ -+ union { -+ typename MM0::Mma::SharedStorage mm0; -+ SharedStorageAfterMM0 after_mm0; -+ }; -+ -+ CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& -+ epilogue_shared_storage() { -+ return after_mm0.epilogue; -+ } -+ }; -+ -+ using SharedStorage = typename cutlass::platform::conditional< -+ kSingleValueIteration || kKeepOutputInRF, -+ SharedStorageEpilogueAtEnd, -+ SharedStorageEpilogueInLoop>::type; -+ -+ static bool __host__ check_supported(Params const& p) { -+ CHECK_ALIGNED_PTR(p.query_ptr, kAlignmentQ); -+ CHECK_ALIGNED_PTR(p.key_ptr, kAlignmentK); -+ CHECK_ALIGNED_PTR(p.value_ptr, kAlignmentV); -+ XFORMERS_CHECK( -+ p.q_strideM % kAlignmentQ == 0, "query is not correctly aligned"); -+ XFORMERS_CHECK( -+ p.k_strideM % kAlignmentK == 0, "key is not correctly aligned"); -+ XFORMERS_CHECK( -+ p.v_strideM % kAlignmentV == 0, "value is not correctly aligned"); -+ XFORMERS_CHECK( -+ p.q_strideH % kAlignmentQ == 0, "query is not correctly aligned"); -+ XFORMERS_CHECK( -+ p.k_strideH % kAlignmentK == 0, "key is not correctly aligned"); -+ XFORMERS_CHECK( -+ p.v_strideH % kAlignmentV == 0, "value is not correctly aligned"); -+ -+ if (p.attn_mask_ptr) { -+ CHECK_ALIGNED_PTR(p.attn_mask_ptr, kAlignmentQ); -+ XFORMERS_CHECK( -+ p.attn_mask_strideB % kAlignmentQ == 0, -+ "attn_mask is not correctly aligned"); -+ XFORMERS_CHECK( -+ p.attn_mask_strideH % kAlignmentQ == 0, -+ "attn_mask is not correctly aligned"); -+ XFORMERS_CHECK( -+ p.attn_mask_strideM % kAlignmentQ == 0, -+ "attn_mask is not correctly aligned"); -+ } -+ if (p.attn_bias_ptr) { -+ CHECK_ALIGNED_PTR(p.attn_bias_ptr, kAlignmentQ); -+ XFORMERS_CHECK( -+ p.attn_bias_strideB % kAlignmentQ == 0, -+ "attn_bias is not correctly aligned"); -+ XFORMERS_CHECK( -+ p.attn_bias_strideH % kAlignmentQ == 0, -+ "attn_bias is not correctly aligned"); -+ XFORMERS_CHECK( -+ p.attn_bias_strideM % kAlignmentQ == 0, -+ "attn_bias is not correctly aligned"); -+ } -+ return true; -+ } -+ -+ static void CUTLASS_DEVICE attention_kernel(Params& p) { -+ // In this block, we will only ever: -+ // - read query[query_start:query_end, :] -+ // - write to output[query_start:query_end, :] -+ -+ extern __shared__ char smem_buffer[]; -+ SharedStorage& shared_storage = *((SharedStorage*)smem_buffer); -+ auto& m_prime = shared_storage.m_prime; -+ auto& s_prime = shared_storage.s_prime; -+ [[maybe_unused]] auto& si = shared_storage.after_mm0.si; -+ auto& mi = shared_storage.mi; -+ -+ static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, ""); -+ if (thread_id() < kQueriesPerBlock) { -+ s_prime[thread_id()] = accum_t(0); -+ m_prime[thread_id()] = -+ -cutlass::platform::numeric_limits::infinity(); -+ mi[thread_id()] = -cutlass::platform::numeric_limits::infinity(); -+ } -+ typename MM1::Mma::FragmentC accum_o; -+ accum_o.clear(); -+ -+ auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator { -+ using OutputTileIterator = typename MM1::OutputTileIterator; -+ return OutputTileIterator( -+ typename OutputTileIterator::Params{(int32_t)p.o_strideM()}, -+ p.output_ptr, -+ typename OutputTileIterator::TensorCoord{ -+ p.num_queries, p.head_dim_value}, -+ thread_id(), -+ {0, col}); -+ }; -+ -+ auto createOutputAccumIter = [&](int col) -> -+ typename MM1::OutputTileIteratorAccum { -+ using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum; -+ return OutputTileIteratorAccum( -+ typename OutputTileIteratorAccum::Params{(int32_t)p.o_strideM()}, -+ p.output_accum_ptr, -+ typename OutputTileIteratorAccum::TensorCoord{ -+ p.num_queries, p.head_dim_value}, -+ thread_id(), -+ {0, col}); -+ }; -+ -+ // Iterate through keys -+ for (int32_t iter_key_start = 0; iter_key_start < p.num_keys; -+ iter_key_start += kKeysPerBlock) { -+ int32_t problem_size_0_m = -+ cutlass::fast_min((int32_t)kQueriesPerBlock, p.num_queries); -+ int32_t problem_size_0_n = cutlass::fast_min( -+ int32_t(kKeysPerBlock), p.num_keys - iter_key_start); -+ int32_t const& problem_size_0_k = p.head_dim; -+ int32_t const& problem_size_1_n = p.head_dim_value; -+ int32_t const& problem_size_1_k = problem_size_0_n; -+ -+ auto prologueV = [&](int blockN) { -+ typename MM1::Mma::IteratorB iterator_V( -+ typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)}, -+ p.value_ptr + iter_key_start * p.v_strideM, -+ {problem_size_1_k, problem_size_1_n}, -+ thread_id(), -+ cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); -+ MM1::Mma::prologue( -+ shared_storage.after_mm0.mm1.mm, -+ iterator_V, -+ thread_id(), -+ problem_size_1_k); -+ }; -+ -+ __syncthreads(); // Need to have shared memory initialized, and `m_prime` -+ // updated from end of prev iter -+ // -+ // MATMUL: Q.K_t -+ // -+ // Computes the block-matrix product of: -+ // (a) query[query_start:query_end, :] -+ // with -+ // (b) key[iter_key_start:iter_key_start + kKeysPerBlock] -+ // and stores that into `shared_storage.si` -+ // -+ -+ // Compute threadblock location -+ cutlass::gemm::GemmCoord tb_tile_offset = {0, 0, 0}; -+ -+ cutlass::MatrixCoord tb_offset_A{ -+ tb_tile_offset.m() * MM0::Mma::Shape::kM, tb_tile_offset.k()}; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ tb_tile_offset.k(), tb_tile_offset.n() * MM0::Mma::Shape::kN}; -+ -+ // Construct iterators to A and B operands -+ typename MM0::IteratorA iterator_A( -+ typename MM0::IteratorA::Params( -+ typename MM0::MmaCore::LayoutA(p.q_strideM)), -+ p.query_ptr, -+ {problem_size_0_m, problem_size_0_k}, -+ thread_id(), -+ tb_offset_A); -+ -+ typename MM0::IteratorB iterator_B( -+ typename MM0::IteratorB::Params( -+ typename MM0::MmaCore::LayoutB(p.k_strideM)), -+ p.key_ptr + iter_key_start * p.k_strideM, -+ {problem_size_0_k, problem_size_0_n}, -+ thread_id(), -+ tb_offset_B); -+ -+ auto my_warp_id = warp_id(); -+ auto my_lane_id = lane_id(); -+ -+ // Construct thread-scoped matrix multiply -+ typename MM0::Mma mma( -+ shared_storage.mm0, thread_id(), my_warp_id, my_lane_id); -+ -+ typename MM0::Mma::FragmentC accum; -+ -+ accum.clear(); -+ -+ auto gemm_k_iterations = -+ (problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); -+ __syncthreads(); -+ -+ if (kPreloadV) { -+ prologueV(0); -+ } -+ -+ typename MM0::Mma::Operator::IteratorC::TensorCoord -+ iteratorC_tile_offset = { -+ (tb_tile_offset.m() * MM0::Mma::WarpCount::kM) + -+ (my_warp_id % MM0::Mma::WarpCount::kM), -+ (tb_tile_offset.n() * MM0::Mma::WarpCount::kN) + -+ (my_warp_id / MM0::Mma::WarpCount::kM)}; -+ float scale = p.scale; -+ if (p.attn_bias_ptr) { -+ if (scale != 1.0f) { -+ accum = cutlass::multiplies()(scale, accum); -+ scale = 1.0f; -+ } -+ auto query_start = blockIdx.x * kQueriesPerBlock; -+ // load bias tile Bij into shared memory -+ typename MM0::BiasLoader::GmemTileIterator bias_iter( -+ {cutlass::layout::RowMajor(p.attn_bias_strideM)}, -+ // attn_bias_pointer points to matrix of size (n_queries, n_keys) -+ // for the relevant batch_id -+ p.attn_bias_ptr + query_start * p.attn_bias_strideM + iter_key_start, -+ {problem_size_0_m, problem_size_0_n}, -+ thread_id()); -+ cutlass::TensorRef bias_tensor_ref( -+ shared_storage.after_mm0.bias.data(), -+ cutlass::layout::RowMajor(MM0::ThreadblockShape::kN)); -+ typename MM0::BiasLoader::SmemTileIterator smem_tile_iter( -+ bias_tensor_ref, thread_id()); -+ -+ MM0::BiasLoader::load(bias_iter, smem_tile_iter); -+ // Pij += Bij, Pij is in register fragment and Bij is in shared memory -+ auto lane_offset = MM0::ScalingCoefsUpdater::get_lane_offset( -+ lane_id(), warp_id(), iteratorC_tile_offset); -+ MM0::ScalingCoefsUpdater::iterateRows( -+ lane_offset, -+ [&](int accum_m) {}, -+ [&](int accum_m, int accum_n, int idx) { -+ if (accum_m < problem_size_0_m && accum_n < problem_size_0_n) { -+ accum[idx] += bias_tensor_ref.at({accum_m, accum_n}); -+ } -+ }, -+ [&](int accum_m) {}); -+ } -+ -+ if (p.attn_mask_ptr) { -+ // scale*Q.K_t prio to mask apply -+ if (scale != 1.0f) { -+ accum = cutlass::multiplies()(scale, accum); -+ scale = 1.0f; -+ } -+ auto query_start = blockIdx.x * kQueriesPerBlock; -+ // load mask tile Mij into shared memory -+ typename MM0::MaskLoader::GmemTileIterator mask_iter( -+ {cutlass::layout::RowMajor(p.attn_mask_strideM)}, -+ // attn_mask_pointer points to matrix of size (n_queries, n_keys) -+ // for the relevant batch_id -+ p.attn_mask_ptr + query_start * p.attn_mask_strideM + iter_key_start, -+ {problem_size_0_m, problem_size_0_n}, -+ thread_id()); -+ cutlass::TensorRef mask_tensor_ref( -+ shared_storage.after_mm0.mask.data(), -+ cutlass::layout::RowMajor(MM0::ThreadblockShape::kN)); -+ typename MM0::MaskLoader::SmemTileIterator smem_tile_iter( -+ mask_tensor_ref, thread_id()); -+ -+ MM0::MaskLoader::load(mask_iter, smem_tile_iter); -+ // Pij += (Mij-1)*(-10000), Pij is in register fragment and Mij is in shared memory -+ auto lane_offset = MM0::ScalingCoefsUpdater::get_lane_offset( -+ lane_id(), warp_id(), iteratorC_tile_offset); -+ MM0::ScalingCoefsUpdater::iterateRows( -+ lane_offset, -+ [&](int accum_m) {}, -+ [&](int accum_m, int accum_n, int idx) { -+ if (accum_m < problem_size_0_m && accum_n < problem_size_0_n) { -+ scalar_t tmp = scalar_t(1.0) - mask_tensor_ref.at({accum_m, accum_n}); -+ accum[idx] += tmp*scalar_t(-10000.0f); -+ } -+ }, -+ [&](int accum_m) {}); -+ } -+ -+ // Mask out last if causal -+ if (p.causal && p.num_keys - iter_key_start <= kKeysPerBlock) { -+ auto query_start = blockIdx.x * kQueriesPerBlock; -+ auto lane_offset = MM0::ScalingCoefsUpdater::get_lane_offset( -+ lane_id(), warp_id(), iteratorC_tile_offset); -+ int32_t last_col; -+ MM0::ScalingCoefsUpdater::iterateRows( -+ lane_offset, -+ [&](int accum_m) { -+ last_col = query_start + accum_m - iter_key_start; -+ }, -+ [&](int accum_m, int accum_n, int idx) { -+ if (accum_n > last_col) { -+ accum[idx] = -+ -cutlass::platform::numeric_limits::infinity(); -+ } -+ }, -+ [&](int accum_m) {}); -+ } -+ DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] { -+ DISPATCH_BOOL( -+ p.num_keys - iter_key_start >= kKeysPerBlock, -+ kFullColumns, -+ ([&] { -+ // Update `mi` from accum stored in registers -+ // Also updates `accum` with accum[i] <- -+ // exp(accum[i] * scale -+ // - mi) -+ MM0::ScalingCoefsUpdater::update< -+ kQueriesPerBlock, -+ kFullColumns, -+ kIsFirst, -+ kKeepOutputInRF>( -+ accum_o, -+ accum, -+ mi, -+ m_prime, -+ s_prime, -+ lane_id(), -+ thread_id(), -+ warp_id(), -+ p.num_keys - iter_key_start, -+ iteratorC_tile_offset, -+ scale); -+ })); -+ })); -+ -+ // Output results to shared-memory -+ int warp_idx_mn_0 = my_warp_id % -+ (MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN); -+ auto output_tile_coords = cutlass::MatrixCoord{ -+ warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM, -+ warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM}; -+ -+ MM0::B2bGemm::accumToSmem( -+ shared_storage.after_mm0.si, accum, my_lane_id, output_tile_coords); -+ -+ __syncthreads(); -+ -+ // -+ // MATMUL: Attn . V -+ // Run the matmul `attn @ V` for a block of attn and V. -+ // `attn` is read from shared memory (in `shared_storage_si`) -+ // `V` is read from global memory (with iterator_B) -+ // -+ -+ const int64_t nBlockN = kSingleValueIteration -+ ? 1 -+ : ceil_div( -+ (int64_t)problem_size_1_n, int64_t(MM1::ThreadblockShape::kN)); -+ for (int blockN = 0; blockN < nBlockN; ++blockN) { -+ int gemm_k_iterations = -+ (problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add and store it in accum -+ // (in registers) -+ if (!kPreloadV) { -+ __syncthreads(); // we share shmem between mma and epilogue -+ } -+ -+ typename MM1::Mma::IteratorB iterator_V( -+ typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)}, -+ p.value_ptr + iter_key_start * p.v_strideM, -+ {problem_size_1_k, problem_size_1_n}, -+ thread_id(), -+ cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); -+ typename MM1::Mma mma_pv( -+ shared_storage.after_mm0.mm1.mm, -+ shared_storage.after_mm0.si, -+ (int)thread_id(), -+ (int)warp_id(), -+ (int)lane_id(), -+ (int)problem_size_1_k); -+ mma_pv.set_prologue_done(kPreloadV); -+ if (!kKeepOutputInRF) { -+ accum_o.clear(); -+ } -+ mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o); -+ __syncthreads(); -+ -+ if (kPreloadV && !kSingleValueIteration && blockN + 1 < nBlockN) { -+ prologueV(blockN + 1); -+ } -+ -+ if (!kKeepOutputInRF) { -+ DISPATCH_BOOL( -+ iter_key_start == 0, kIsFirst, ([&] { -+ DISPATCH_BOOL( -+ (iter_key_start + kKeysPerBlock) >= p.num_keys, -+ kIsLast, -+ ([&] { -+ using DefaultEpilogue = typename MM1::DefaultEpilogue; -+ using DefaultOp = -+ typename MM1::DefaultConfig::EpilogueOutputOp; -+ using ElementCompute = typename DefaultOp::ElementCompute; -+ using EpilogueOutputOp = typename cutlass::epilogue:: -+ thread::MemoryEfficientAttentionNormalize< -+ typename cutlass::platform::conditional< -+ kIsLast, -+ output_t, -+ output_accum_t>::type, -+ output_accum_t, -+ DefaultOp::kCount, -+ typename DefaultOp::ElementAccumulator, -+ ElementCompute, -+ kIsFirst, -+ kIsLast, -+ cutlass::Array>; -+ using Epilogue = typename cutlass::epilogue::threadblock:: -+ EpiloguePipelined< -+ typename DefaultEpilogue::Shape, -+ typename MM1::Mma::Operator, -+ DefaultEpilogue::kPartitionsK, -+ typename cutlass::platform::conditional< -+ kIsLast, -+ typename MM1::OutputTileIterator, -+ typename MM1::OutputTileIteratorAccum>::type, -+ typename DefaultEpilogue:: -+ AccumulatorFragmentIterator, -+ typename DefaultEpilogue::WarpTileIterator, -+ typename DefaultEpilogue::SharedLoadIterator, -+ EpilogueOutputOp, -+ typename DefaultEpilogue::Padding, -+ DefaultEpilogue::kFragmentsPerIteration, -+ true, // IterationsUnroll -+ typename MM1::OutputTileIteratorAccum // Read -+ // iterator -+ >; -+ -+ int col = blockN * MM1::Mma::Shape::kN; -+ auto source_iter = createOutputAccumIter(col); -+ auto dest_iter = call_conditional< -+ kIsLast, -+ decltype(createOutputIter), -+ decltype(createOutputAccumIter)>:: -+ apply(createOutputIter, createOutputAccumIter, col); -+ EpilogueOutputOp rescale(s_prime, m_prime); -+ Epilogue epilogue( -+ shared_storage.epilogue_shared_storage(), -+ thread_id(), -+ warp_id(), -+ lane_id()); -+ epilogue(rescale, dest_iter, accum_o, source_iter); -+ })); -+ })); -+ if (!kSingleValueIteration) { -+ __syncthreads(); -+ } -+ } -+ } -+ __syncthreads(); // we modify `m_prime` after -+ } -+ -+ if (kKeepOutputInRF) { -+ constexpr bool kIsFirst = true; -+ constexpr bool kIsLast = true; -+ using DefaultEpilogue = typename MM1::DefaultEpilogue; -+ using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp; -+ using ElementCompute = typename DefaultOp::ElementCompute; -+ using EpilogueOutputOp = -+ typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize< -+ output_t, // output -+ output_accum_t, // source -+ DefaultOp::kCount, -+ typename DefaultOp::ElementAccumulator, // accum -+ output_accum_t, // compute -+ kIsFirst, -+ kIsLast, -+ cutlass::Array>; -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::EpiloguePipelined< -+ typename DefaultEpilogue::Shape, -+ typename MM1::Mma::Operator, -+ DefaultEpilogue::kPartitionsK, -+ typename MM1::OutputTileIterator, // destination -+ typename DefaultEpilogue::AccumulatorFragmentIterator, -+ typename DefaultEpilogue::WarpTileIterator, -+ typename DefaultEpilogue::SharedLoadIterator, -+ EpilogueOutputOp, -+ typename DefaultEpilogue::Padding, -+ DefaultEpilogue::kFragmentsPerIteration, -+ true, // IterationsUnroll -+ typename MM1::OutputTileIteratorAccum // source tile -+ >; -+ auto dest_iter = createOutputIter(0); -+ EpilogueOutputOp rescale(s_prime, m_prime); -+ Epilogue epilogue( -+ shared_storage.epilogue_shared_storage(), -+ thread_id(), -+ warp_id(), -+ lane_id()); -+ epilogue(rescale, dest_iter, accum_o); -+ } -+ -+ // 7. Calculate logsumexp -+ // To make the backward easier, we pad logsumexp with `inf` -+ // this avoids a few bound checks, and is not more expensive during fwd -+ static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, ""); -+ if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) { -+ auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE; -+ if (thread_id() < p.num_queries) { -+ p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()]) + -+ cutlass::fast_log(accum_t(s_prime[thread_id()])); -+ } else if (thread_id() < lse_dim) { -+ p.logsumexp_ptr[thread_id()] = -+ cutlass::platform::numeric_limits::infinity(); -+ } -+ } -+ } -+ -+ static CUTLASS_DEVICE int8_t lane_id() { -+ return threadIdx.x; -+ } -+ static CUTLASS_DEVICE int8_t warp_id() { -+ return threadIdx.y; -+ } -+ static CUTLASS_DEVICE int16_t thread_id() { -+ return threadIdx.x + threadIdx.y * blockDim.x; -+ } -+}; -+ -+template -+__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) -+ attention_kernel_batched_impl(typename AK::Params p) { -+ if (!p.advance_to_block()) { -+ return; -+ } -+ AK::attention_kernel(p); -+} -+ -+template -+__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) -+ attention_kernel_batched(typename AK::Params params); -+ -+#define _ATTENTION_KERNEL_FORWARD_BEGIN(...) \ -+ template <> \ -+ __global__ void __launch_bounds__( \ -+ __VA_ARGS__::kNumThreads, __VA_ARGS__::kMinBlocksPerSm) \ -+ attention_kernel_batched<__VA_ARGS__>(typename __VA_ARGS__::Params p) { \ -+ using Kernel = __VA_ARGS__; -+#define _ATTENTION_KERNEL_FORWARD_END() } -+ -+#ifdef __CUDA_ARCH__ -+#define __CUDA_ARCH_OR_ZERO__ __CUDA_ARCH__ -+#else -+#define __CUDA_ARCH_OR_ZERO__ 0 -+#endif -+ -+#define INSTANTIATE_ATTENTION_KERNEL_FORWARD( \ -+ ARCH, \ -+ SCALAR_T, \ -+ IS_ALIGNED, \ -+ QUERIES_PER_BLOCK, \ -+ KEYS_PER_BLOCK, \ -+ SINGLE_VALUE_ITER) \ -+ _ATTENTION_KERNEL_FORWARD_BEGIN(AttentionKernel< \ -+ SCALAR_T, \ -+ cutlass::arch::Sm##ARCH, \ -+ IS_ALIGNED, \ -+ QUERIES_PER_BLOCK, \ -+ KEYS_PER_BLOCK, \ -+ SINGLE_VALUE_ITER>) \ -+ if (!p.advance_to_block()) { \ -+ return; \ -+ } \ -+ Kernel::attention_kernel(p); \ -+ _ATTENTION_KERNEL_FORWARD_END(); -+ -+#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED( \ -+ ARCH, \ -+ SCALAR_T, \ -+ IS_ALIGNED, \ -+ QUERIES_PER_BLOCK, \ -+ KEYS_PER_BLOCK, \ -+ SINGLE_VALUE_ITER) \ -+ _ATTENTION_KERNEL_FORWARD_BEGIN(AttentionKernel< \ -+ SCALAR_T, \ -+ cutlass::arch::Sm##ARCH, \ -+ IS_ALIGNED, \ -+ QUERIES_PER_BLOCK, \ -+ KEYS_PER_BLOCK, \ -+ SINGLE_VALUE_ITER>) \ -+ printf( \ -+ "FATAL: this function is for sm%d, but was built for sm%d\n", \ -+ int(ARCH), \ -+ int(__CUDA_ARCH_OR_ZERO__)); \ -+ _ATTENTION_KERNEL_FORWARD_END(); -+ -+// All kernels are disabled by default -+#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(...) \ -+ INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(50, __VA_ARGS__) -+#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70(...) \ -+ INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(70, __VA_ARGS__) -+#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75(...) \ -+ INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(75, __VA_ARGS__) -+#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(...) \ -+ INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(80, __VA_ARGS__) -+ -+// Enable the right one based on __CUDA_ARCH__ -+#ifndef __CUDA_ARCH__ -+#elif __CUDA_ARCH__ < 500 -+#error "Need cuda arch at least 5.0" -+#elif __CUDA_ARCH__ < 700 -+#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50 -+#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(...) \ -+ INSTANTIATE_ATTENTION_KERNEL_FORWARD(50, __VA_ARGS__) -+#elif __CUDA_ARCH__ < 750 -+#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70 -+#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70(...) \ -+ INSTANTIATE_ATTENTION_KERNEL_FORWARD(70, __VA_ARGS__) -+#elif __CUDA_ARCH__ < 800 -+#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75 -+#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75(...) \ -+ INSTANTIATE_ATTENTION_KERNEL_FORWARD(75, __VA_ARGS__) -+#elif __CUDA_ARCH__ >= 800 -+#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80 -+#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(...) \ -+ INSTANTIATE_ATTENTION_KERNEL_FORWARD(80, __VA_ARGS__) -+#endif -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/mma_from_smem.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/mma_from_smem.h -new file mode 100644 -index 0000000..21ac4d1 ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/mma_from_smem.h -@@ -0,0 +1,1780 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights -+ *reserved. SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, -+ *this list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -+ *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE -+ *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -+ *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -+ *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -+ *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -+ *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -+ *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -+ *POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/transform/threadblock/vector_iterator.h" -+ -+#include "attention_scaling_coefs_updater.h" -+#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h" -+#include "cutlass/gemm/threadblock/mma_base.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h" -+#include "epilogue_thread_apply_logsumexp.h" -+#include "gemm_kernel_utils.h" -+#include "iterators/make_residual_last.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+/// Shared storage object needed by accumulator -+/// From 13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h -+template < -+ typename Shape_, -+ typename Element_, -+ typename Layout_, -+ typename Padding_> -+class AccumulatorSharedStorage { -+ public: -+ // -+ // Type definitions -+ // -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = Layout_; -+ using Padding = Padding_; -+ -+ /// Tensor reference to the accumulator -+ using TensorRefAccum = cutlass::TensorRef; -+ -+ /// Shape of the accumulator matrix in shared memory -+ using ShapeAccum = cutlass:: -+ MatrixShape; -+ -+ public: -+ // -+ // Data members -+ // -+ -+ /// Buffer for accumulator -+ cutlass::AlignedBuffer accum; -+ -+ public: -+ // -+ // Methods -+ // -+ -+ /// Returns a layout object for the Accum matrix -+ CUTLASS_DEVICE -+ static Layout LayoutAccum() { -+ return Layout::packed({ShapeAccum::kRow, ShapeAccum::kColumn}); -+ } -+ -+ /// Returns a TensorRef to the Accumulator -+ CUTLASS_HOST_DEVICE -+ TensorRefAccum accum_ref() { -+ return TensorRefAccum{accum.data(), LayoutAccum()}; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Taken from -+// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ // Maximum value for K -+ int kMaxK, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class MmaBaseFromSharedMemory { -+ public: -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Shape describing the overall GEMM computed from shared memory -+ /// by each warp. -+ using WarpGemm = typename Policy::Operator::Shape; -+ -+ /// Shape describing the number of warps filling the CTA -+ using WarpCount = GemmShape< -+ Shape::kM / WarpGemm::kM, -+ Shape::kN / WarpGemm::kN, -+ Shape::kK / WarpGemm::kK>; -+ using WarpCount1 = WarpCount; -+ -+ /// Number of warp-level GEMM oeprations -+ static int const kWarpGemmIterations = -+ (WarpGemm::kK / Operator::Policy::MmaShape::kK); -+ static int const kWarpGemmIterations1 = kWarpGemmIterations; -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// If this is true, we fill the entire shmem buffer at start -+ /// and don't need to iterate through it in a circular fashion -+ static bool const kSmemContainsEntireB = kMaxK <= Shape::kK * kStages; -+ -+ /// Tensor reference to the A operand -+ using TensorRefA = -+ TensorRef; -+ -+ /// Tensor reference to the B operand -+ using TensorRefB = -+ TensorRef; -+ -+ // -+ // Nested structs -+ // -+ -+ /// Shared storage object needed by threadblock-scoped GEMM -+ class SharedStorage { -+ public: -+ // -+ // Type definitions -+ // -+ -+ /// Shape of the B matrix operand in shared memory -+ using ShapeB = MatrixShape< -+ Shape::kK * kStages + Policy::SmemPaddingB::kRow, -+ Shape::kN + Policy::SmemPaddingB::kColumn>; -+ -+ public: -+ // -+ // Data members -+ // -+ -+ /// Buffer for B operand -+ AlignedBuffer operand_B; -+ -+ public: -+ // -+ // Methods -+ // -+ -+ /// Returns a layout object for the B matrix -+ CUTLASS_HOST_DEVICE -+ static typename Operator::LayoutB LayoutB() { -+ return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); -+ } -+ -+ /// Returns a TensorRef to the B operand -+ CUTLASS_HOST_DEVICE -+ TensorRefB operand_B_ref() { -+ return TensorRefB{operand_B.data(), LayoutB()}; -+ } -+ }; -+ -+ protected: -+ // -+ // Data members -+ // -+ -+ // /// Iterator to load a warp-scoped tile of A operand from shared memory -+ // typename Operator::IteratorA warp_tile_iterator_A_; -+ -+ /// Iterator to load a warp-scoped tile of B operand from shared memory -+ typename Operator::IteratorB warp_tile_iterator_B_; -+ -+ public: -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ MmaBaseFromSharedMemory( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ SharedStorage& shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx) -+ : warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Taken from -+// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ // BEGIN smem -+ /// Iterates over the intermediate accumulator tile in shared memory -+ typename WarpIteratorA, -+ // Accumulator type -+ typename AccumulatorSharedStorage, -+ // END smem -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Transformation applied to B operand -+ typename TransformB_ = NumericArrayConverter< -+ typename SmemIteratorB_::Element, -+ typename IteratorB_::Element, -+ IteratorB_::Fragment::kElements>, -+ /// Used for partial specialization -+ typename Enable = bool> -+class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory< -+ Shape_, -+ AccumulatorSharedStorage::Shape::kN, -+ Policy_, -+ 2> { -+ public: -+ ///< Base class -+ using Base = MmaBaseFromSharedMemory< -+ Shape_, -+ AccumulatorSharedStorage::Shape::kN, -+ Policy_, -+ 2>; -+ -+ using Shape = -+ Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using IteratorB = -+ IteratorB_; ///< Iterates over tiles of B operand in global memory -+ using ElementC = ElementC_; ///< Data type of accumulator matrix -+ using LayoutC = LayoutC_; ///< Layout of accumulator matrix -+ using Policy = Policy_; ///< Policy describing tuning details -+ -+ using SmemIteratorB = SmemIteratorB_; -+ -+ using TransformB = TransformB_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of operand B loaded from global memory -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Fragment of accumulator tile -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Obtain the arch tag from the warp-level operator -+ using ArchTag = typename Policy::Operator::ArchTag; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = Operator::kTransformB; -+ -+ // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) -+ static_assert( -+ (Base::kStages == 2), -+ "MmaPipelined requires kStages set to value 2"); -+ -+ private: -+ using WarpFragmentA = typename Operator::FragmentA; -+ using WarpFragmentB = typename Operator::FragmentB; -+ -+ protected: -+ // /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ // SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+ /// Iterator to load a warp-scoped tile of A operand from intermediate -+ /// accumulator tile -+ WarpIteratorA warp_tile_iterator_A_; -+ -+ public: -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ MmaPipelinedFromSharedMemory( -+ typename Base::SharedStorage& -+ shared_storage, ///< Shared storage needed for internal use by -+ ///< threadblock-scoped GEMM -+ AccumulatorSharedStorage& accumulator_shared_storage, -+ int thread_idx, ///< ID within the threadblock -+ int warp_idx, ///< ID of warp -+ int lane_idx, ///< ID of each thread within a warp -+ int problem_size_0_n) -+ : Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ warp_tile_iterator_A_(accumulator_shared_storage.accum_ref(), lane_idx), -+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ } -+ -+ // For API compatibility with MmaMultistageFromSharedMemory -+ // but not supported as it worsens perf: older gpus < sm80 don't -+ // support async tranfers and have to waste registers -+ CUTLASS_DEVICE -+ void set_prologue_done(bool value) {} -+ CUTLASS_DEVICE -+ static void prologue( -+ typename Base::SharedStorage& shared_storage, -+ IteratorB iterator_B1, -+ int thread_idx, -+ int problem_size_0_n) {} -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ int gemm_k_iterations, ///< number of iterations of the mainloop -+ FragmentC& accum, ///< destination accumulator tile -+ // IteratorA iterator_A, ///< iterator over A -+ // operand in global memory -+ IteratorB iterator_B, ///< iterator over B operand in global memory -+ FragmentC const& src_accum, ///< source accumulator tile -+ // TransformA transform_A = TransformA(), ///< transformation -+ // applied to A fragment -+ TransformB transform_B = -+ TransformB()) { ///< transformation applied to B fragment -+ -+ // -+ // Prologue -+ // -+ -+ // Perform accumulation in the 'd' output operand -+ accum = src_accum; -+ -+ FragmentB tb_frag_B; -+ -+ tb_frag_B.clear(); -+ -+ // The last kblock is loaded in the prolog -+ iterator_B.set_residual_tile(gemm_k_iterations == 1); -+ iterator_B.load(tb_frag_B); -+ -+ ++iterator_B; -+ -+ this->smem_iterator_B_.store(transform_B(tb_frag_B)); -+ -+ ++this->smem_iterator_B_; -+ -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpFragmentA warp_frag_A[2]; -+ WarpFragmentB warp_frag_B[2]; -+ warp_frag_A[0].clear(); -+ warp_frag_B[0].clear(); -+ -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A_.load(warp_frag_A[0]); -+ this->warp_tile_iterator_B_.load(warp_frag_B[0]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ Operator warp_mma; -+ -+ int smem_write_stage_idx = 1; -+ -+ // Avoid reading out of bounds -+ iterator_B.set_residual_tile(gemm_k_iterations == 2); -+ iterator_B.clear_mask(gemm_k_iterations <= 1); -+ -+ // Issue loads during the first warp-level matrix multiply-add *AFTER* -+ // issuing shared memory loads (which have the tighest latency requirement). -+ -+ // -+ // Mainloop -+ // -+ -+ // Note: The main loop does not support Base::kWarpGemmIterations == 2. -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > 0; --gemm_k_iterations) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; -+ ++warp_mma_k) { -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. -+ bool hasNext = true; -+ -+ if (warp_mma_k == Base::kWarpGemmIterations - 1) { -+ // Write fragments to shared memory -+ this->smem_iterator_B_.store(transform_B(tb_frag_B)); -+ -+ __syncthreads(); -+ -+ ++this->smem_iterator_B_; -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory SMEM: Don't reset iterator A, as -+ // we are continuing our iteration at this point -+ if (smem_write_stage_idx == 1) { -+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ } else { -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations, -+ 0}); -+ } -+ -+ smem_write_stage_idx ^= 1; -+ hasNext = gemm_k_iterations > 1; -+ } -+ -+ // Only read the next if we need to -+ if (hasNext) { -+ this->warp_tile_iterator_B_.set_kgroup_index( -+ (warp_mma_k + 1) % Base::kWarpGemmIterations); -+ -+ this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ if (warp_mma_k == 0) { -+ iterator_B.load(tb_frag_B); -+ -+ ++iterator_B; -+ -+ // Avoid reading out of bounds if this was the last loop iteration -+ iterator_B.set_residual_tile(gemm_k_iterations == 3); -+ iterator_B.clear_mask(gemm_k_iterations <= 2); -+ } -+ } -+ -+ warp_mma( -+ accum, -+ warp_frag_A[warp_mma_k % 2], -+ warp_frag_B[warp_mma_k % 2], -+ accum); -+ } -+ } -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Taken from -+// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape1_, -+ /// Iterates over the intermediate accumulator tile in shared memory -+ typename WarpIteratorA1_, -+ // Accumulator type -+ typename AccumulatorSharedStorage, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB1_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB1_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB1, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy1_, -+ /// Number of stages, -+ int Stages_, -+ /// Used for partial specialization -+ typename Enable = bool> -+class MmaMultistageFromSharedMemory : public MmaBaseFromSharedMemory< -+ Shape1_, -+ AccumulatorSharedStorage::Shape::kN, -+ Policy1_, -+ Stages_> { -+ public: -+ ///< Base class -+ using Base = MmaBaseFromSharedMemory< -+ Shape1_, -+ AccumulatorSharedStorage::Shape::kN, -+ Policy1_, -+ Stages_>; -+ -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape1 = Shape1_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB1 = IteratorB1_; -+ using IteratorB = IteratorB1; -+ ///< Policy describing tuning details -+ using Policy1 = Policy1_; -+ -+ using SmemIteratorB1 = SmemIteratorB1_; -+ using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate -+ ///< accumulator tile in shared memory -+ -+ ///< Data type of accumulator matrix -+ using ElementC = ElementC_; -+ ///< Layout of accumulator matrix -+ using LayoutC = LayoutC_; -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1; -+ static constexpr bool kSmemContainsEntireB = Base::kSmemContainsEntireB; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of accumulator tile -+ using FragmentC1 = typename Policy1::Operator::FragmentC; -+ using FragmentC = FragmentC1; -+ -+ /// Warp-level Mma -+ using Operator1 = typename Policy1::Operator; -+ -+ /// Minimum architecture is Sm80 to support cp.async -+ using ArchTag = arch::Sm80; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB1 = Operator1::kTransformB; -+ -+ /// Internal structure exposed for introspection. -+ struct Detail { -+ static_assert( -+ Base::kWarpGemmIterations1 > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const TBLoadIterationsB1 = -+ IteratorB1::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB1 = -+ (TBLoadIterationsB1 + Base::kWarpGemmIterations1 - 1) / -+ Base::kWarpGemmIterations1; -+ }; -+ -+ static constexpr int kNumStagesConcurrentLoad = -+ kSmemContainsEntireB ? Base::kStages : Base::kStages - 1; -+ -+ private: -+ using WarpLoadedFragmentA1 = typename Operator1::FragmentA; -+ using WarpLoadedFragmentB1 = typename Operator1::FragmentB; -+ using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA; -+ using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Iterator to load a warp-scoped tile of A1 operand from intermediate -+ /// accumulator tile -+ WarpIteratorA1 warp_tile_iterator_A1_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB1 smem_iterator_B1_; -+ -+ bool prologue_done_; -+ -+ public: -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ MmaMultistageFromSharedMemory( -+ typename Base::SharedStorage& -+ shared_storage, ///< Shared storage needed for internal use by -+ ///< threadblock-scoped GEMM -+ AccumulatorSharedStorage& accumulator_shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx, -+ ///< GEMM0 N is used for accumulator extent -+ int problem_size_0_n) -+ : Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ warp_tile_iterator_A1_( -+ accumulator_shared_storage.accum_ref(), -+ lane_idx), -+ smem_iterator_B1_(shared_storage.operand_B_ref(), thread_idx), -+ prologue_done_(false) { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn_1 = -+ warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN); -+ int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); -+ -+ int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; -+ int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ warp_tile_iterator_A1_.add_tile_offset( -+ {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1}); -+ } -+ -+ CUTLASS_DEVICE -+ void set_prologue_done(bool value) { -+ prologue_done_ = value; -+ } -+ -+ CUTLASS_DEVICE -+ static void prologue( -+ typename Base::SharedStorage& shared_storage, -+ IteratorB iterator_B1, -+ int thread_idx, -+ int problem_size_0_n) { -+ SmemIteratorB1 smem_iterator_B1(shared_storage.operand_B_ref(), thread_idx); -+ _prologue( -+ iterator_B1, -+ (problem_size_0_n + Base::Shape::kK - 1) / Base::Shape::kK, -+ smem_iterator_B1); -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance_1( -+ IteratorB1& iterator_B1, -+ int group_start_B1 = 0) { -+ iterator_B1.set_iteration_index( -+ group_start_B1 * IteratorB1::kAccessesPerVector); -+ this->smem_iterator_B1_.set_iteration_index(group_start_B1); -+ -+ // Load for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) { -+ if (group_start_B1 + j < Detail::TBLoadIterationsB1) { -+ typename IteratorB1::AccessType* dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B1_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB1::ThreadMap::kElementsPerAccess / -+ IteratorB1::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_B1.get(); -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, iterator_B1.valid()); -+ -+ ++iterator_B1; -+ } -+ ++this->smem_iterator_B1_; -+ } -+ } -+ } -+ -+ CUTLASS_DEVICE -+ static void _prologue( -+ IteratorB& iterator_B1, -+ int32_t gemm_k_iterations_1, -+ SmemIteratorB1& smem_iterator_B1_) { -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < kNumStagesConcurrentLoad; -+ ++stage, --gemm_k_iterations_1) { -+ iterator_B1.set_residual_tile(gemm_k_iterations_1 == 1); -+ iterator_B1.clear_mask(gemm_k_iterations_1 == 0); -+ -+ iterator_B1.set_iteration_index(0); -+ smem_iterator_B1_.set_iteration_index(0); -+ -+ // Load for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::TBLoadIterationsB1; ++j) { -+ typename IteratorB1::AccessType* dst_ptr = -+ reinterpret_cast( -+ smem_iterator_B1_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB1::ThreadMap::kElementsPerAccess / -+ IteratorB1::kAccessesPerVector / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_B1.get(), iterator_B1.valid()); -+ -+ ++iterator_B1; -+ } -+ -+ ++smem_iterator_B1_; -+ } -+ -+ // Move to the next stage -+ iterator_B1.add_tile_offset({1, 0}); -+ -+ smem_iterator_B1_.add_tile_offset({1, 0}); -+ -+ // Defines the boundary of a stage of cp.async. -+ cutlass::arch::cp_async_fence(); -+ } -+ iterator_B1.set_residual_tile(gemm_k_iterations_1 == 1); -+ iterator_B1.clear_mask(gemm_k_iterations_1 == 0); -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations_1_, -+ ///< destination accumulator tile -+ FragmentC1& accum, -+ ///< iterator over B1 operand in global memory -+ IteratorB1 iterator_B1, -+ ///< initial value of accumulator -+ FragmentC1 const& src_accum) { -+ // 2nd Gemm -+ -+ // -+ // Prologue -+ // -+ // Perform accumulation in the 'd' output operand -+ accum = src_accum; -+ -+ if (!prologue_done_) { -+ _prologue(iterator_B1, gemm_k_iterations_1_, smem_iterator_B1_); -+ } else if (!kSmemContainsEntireB) { -+ // Restore the iterators increments -+ -+ int gemm_k_iterations_1 = gemm_k_iterations_1_; -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < kNumStagesConcurrentLoad; -+ ++stage, --gemm_k_iterations_1) { -+ iterator_B1.set_iteration_index(0); -+ this->smem_iterator_B1_.set_iteration_index(0); -+ -+ // Load for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::TBLoadIterationsB1; ++j) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { -+ ++iterator_B1; -+ } -+ ++this->smem_iterator_B1_; -+ } -+ iterator_B1.add_tile_offset({1, 0}); -+ this->smem_iterator_B1_.add_tile_offset({1, 0}); -+ } -+ iterator_B1.set_residual_tile(gemm_k_iterations_1 <= 1); -+ iterator_B1.clear_mask(gemm_k_iterations_1 <= 0); -+ } -+ -+ // DEPBAR+SYNC -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA1 warp_loaded_frag_A1[2]; -+ WarpLoadedFragmentB1 warp_loaded_frag_B1[2]; -+ WarpTransformedFragmentA1 warp_transformed_frag_A1[2]; -+ WarpTransformedFragmentB1 warp_transformed_frag_B1[2]; -+ -+ Operator1 warp_mma1; -+ -+ warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0]); -+ ++warp_tile_iterator_A1_; -+ -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B1[0]); -+ ++this->warp_tile_iterator_B_; -+ -+ int smem_write_stage_idx = Base::kStages - 1; -+ int smem_read_stage_idx = 0; -+ -+ warp_mma1.transform( -+ warp_transformed_frag_A1[0], -+ warp_transformed_frag_B1[0], -+ warp_loaded_frag_A1[0], -+ warp_loaded_frag_B1[0]); -+ -+ // tf32x3 kernels use staging accumulation. warp_mma uses a temporary -+ // accumulator and this temporary accumulator is added to the final -+ // accumulator once in every mainloop iteration. -+ plus plus_accum; -+ -+ FragmentC1 tmp_accum; -+ -+ if (platform::is_same< -+ typename Operator1::MathOperator, -+ arch::OpMultiplyAddFastF32>::value || -+ platform::is_same< -+ typename Operator1::MathOperator, -+ arch::OpMultiplyAddComplexFastF32>::value) { -+ tmp_accum.clear(); -+ } -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int gemm_k_iterations_1 = gemm_k_iterations_1_ - (Base::kStages - 1); -+ gemm_k_iterations_1 > (-Base::kStages + 1); -+ gemm_k_iterations_1--) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; -+ ++warp_mma_k) { -+ // Load warp-level tile from accumulator fragment (A) -+ // or shared memory (operand B) -+ this->warp_tile_iterator_B_.set_kgroup_index( -+ (warp_mma_k + 1) % Base::kWarpGemmIterations1); -+ // skip warp tile loading for the last kgroup (we are out of the buf) -+ if (gemm_k_iterations_1 > (-Base::kStages + 2) || -+ warp_mma_k < Base::kWarpGemmIterations1 - 1) { -+ warp_tile_iterator_A1_.load( -+ warp_loaded_frag_A1[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B_.load( -+ warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); -+ } -+ ++warp_tile_iterator_A1_; -+ ++this->warp_tile_iterator_B_; -+ -+ if (warp_mma_k > 0) -+ warp_mma1.transform( -+ warp_transformed_frag_A1[warp_mma_k % 2], -+ warp_transformed_frag_B1[warp_mma_k % 2], -+ warp_loaded_frag_A1[warp_mma_k % 2], -+ warp_loaded_frag_B1[warp_mma_k % 2]); -+ -+ if (platform::is_same< -+ typename Operator1::MathOperator, -+ arch::OpMultiplyAddFastF32>::value || -+ platform::is_same< -+ typename Operator1::MathOperator, -+ arch::OpMultiplyAddComplexFastF32>::value) { -+ warp_mma1( -+ tmp_accum, -+ warp_transformed_frag_A1[warp_mma_k % 2], -+ warp_transformed_frag_B1[warp_mma_k % 2], -+ tmp_accum); -+ -+ if (warp_mma_k == 0) { -+ accum = plus_accum(accum, tmp_accum); -+ tmp_accum.clear(); -+ } -+ } else { -+ warp_mma1( -+ accum, -+ warp_transformed_frag_A1[warp_mma_k % 2], -+ warp_transformed_frag_B1[warp_mma_k % 2], -+ accum); -+ } -+ -+ // Issue global->shared copies for the this stage -+ if (warp_mma_k < Base::kWarpGemmIterations1 - 1) { -+ int group_start_iteration_B1; -+ -+ group_start_iteration_B1 = warp_mma_k * Detail::kAccessesPerGroupB1; -+ -+ if (!kSmemContainsEntireB) { -+ copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); -+ } -+ } -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations1) { -+ int group_start_iteration_B1; -+ group_start_iteration_B1 = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB1; -+ -+ if (!kSmemContainsEntireB) { -+ copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); -+ } -+ -+ // Inserts a memory fence between stages of cp.async instructions. -+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages have committed. -+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_B1.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_B1_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (!kSmemContainsEntireB) { -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy1::kPartitionsK * -+ Base::kWarpGemmIterations1, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ } -+ -+ iterator_B1.set_residual_tile(gemm_k_iterations_1 == 2); -+ iterator_B1.clear_mask(gemm_k_iterations_1 == 1); -+ } -+ -+ // Do any conversions feeding the first stage at the end of the loop so -+ // we can start right away on mma instructions -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations1) -+ warp_mma1.transform( -+ warp_transformed_frag_A1[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B1[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A1[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); -+ } -+ } -+ -+ if (platform::is_same< -+ typename Operator1::MathOperator, -+ arch::OpMultiplyAddFastF32>::value || -+ platform::is_same< -+ typename Operator1::MathOperator, -+ arch::OpMultiplyAddComplexFastF32>::value) { -+ accum = plus_accum(accum, tmp_accum); -+ } -+ } -+}; -+ -+template < -+ typename WarpShape, -+ typename InstructionShape, -+ typename RegularWarpIterator, -+ typename Policy> -+struct DefaultWarpIteratorAFromSharedMemory {}; -+ -+// TensorOp - Ampere -+template -+struct DefaultWarpIteratorAFromSharedMemory< -+ WarpShape, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ RegularWarpIterator, -+ Policy> { -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ static constexpr auto kWarpSize = 32; -+ using OpDelta = typename Policy::Operator::Policy::OpDelta; -+ -+ using WarpIterator = -+ cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator< -+ cutlass::MatrixShape, -+ cutlass::gemm::Operand::kA, -+ typename RegularWarpIterator::Element, -+ cutlass::layout::RowMajor, -+ cutlass::MatrixShape, -+ OpDelta::kRow, -+ kWarpSize>; -+}; -+ -+// TensorOp - Volta -+template -+struct DefaultWarpIteratorAFromSharedMemory< -+ WarpShape, -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ RegularWarpIterator, -+ Policy> { -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 4>; -+ static constexpr auto kWarpSize = 32; -+ using OpDelta = typename Policy::Operator::Policy::OpDelta; -+ -+ using WarpIterator = -+ cutlass::gemm::warp::MmaVoltaTensorOpMultiplicandTileIterator< -+ cutlass::MatrixShape<32, 32>, // MatrixShape, -+ cutlass::gemm::Operand::kA, -+ typename RegularWarpIterator::Element, -+ cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>, -+ cutlass::MatrixShape<16, 4>, -+ OpDelta::kRow, -+ kWarpSize>; -+}; -+ -+// Simt -+template -+struct DefaultWarpIteratorAFromSharedMemory< -+ WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ RegularWarpIterator, -+ Policy> { -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ static constexpr auto kWarpSize = 32; -+ -+ // We just use the same iterator, as we reproduced the same shared-memory -+ // schema. Just modify it to handle non-complete tiles. -+ using WarpIterator = RegularWarpIterator; -+}; -+ -+// Converts a "regular" Mma into their counterpart from shared memory -+template -+struct DefaultMmaFromSharedMemory; -+ -+// Mma pipelined -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Transformation applied to A operand -+ typename TransformA_, -+ /// Transformation applied to B operand -+ typename TransformB_, -+ typename AccumulatorSharedStorage_> -+struct DefaultMmaFromSharedMemory< -+ MmaPipelined< -+ Shape_, -+ IteratorA_, -+ SmemIteratorA_, -+ IteratorB_, -+ SmemIteratorB_, -+ ElementC_, -+ LayoutC_, -+ Policy_, -+ TransformA_, -+ TransformB_>, -+ AccumulatorSharedStorage_> { -+ static constexpr int kWarpSize = 32; -+ using SmemAccumulatorLayout = cutlass::layout::RowMajor; -+ -+ using RegularMma = MmaPipelined< -+ Shape_, -+ IteratorA_, -+ SmemIteratorA_, -+ IteratorB_, -+ SmemIteratorB_, -+ ElementC_, -+ LayoutC_, -+ Policy_, -+ TransformA_, -+ TransformB_>; -+ -+ using WarpShape = typename Policy_::Operator::Shape; -+ using InstructionShape = typename Policy_::Operator::InstructionShape; -+ using ArchMmaOperator = typename Policy_::Operator; -+ -+ using WarpIteratorA = typename DefaultWarpIteratorAFromSharedMemory< -+ WarpShape, -+ InstructionShape, -+ typename RegularMma::Operator::IteratorA, -+ Policy_>::WarpIterator; -+ using IteratorB = -+ typename cutlass::transform::threadblock::MakeIteratorResidualLast< -+ IteratorB_>::Iterator; -+ -+ using Mma = typename cutlass::gemm::threadblock::MmaPipelinedFromSharedMemory< -+ Shape_, -+ WarpIteratorA, -+ AccumulatorSharedStorage_, -+ IteratorB, -+ SmemIteratorB_, -+ ElementC_, -+ LayoutC_, -+ Policy_>; -+}; -+ -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Cache operation for operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear, -+ typename AccumulatorSharedStorage_> -+struct DefaultMmaFromSharedMemory< -+ MmaMultistage< -+ Shape_, -+ IteratorA_, -+ SmemIteratorA_, -+ CacheOpA, -+ IteratorB_, -+ SmemIteratorB_, -+ CacheOpB, -+ ElementC_, -+ LayoutC_, -+ Policy_, -+ Stages, -+ SharedMemoryClear>, -+ AccumulatorSharedStorage_> { -+ static constexpr int kWarpSize = 32; -+ -+ using RegularMma = MmaMultistage< -+ Shape_, -+ IteratorA_, -+ SmemIteratorA_, -+ CacheOpA, -+ IteratorB_, -+ SmemIteratorB_, -+ CacheOpB, -+ ElementC_, -+ LayoutC_, -+ Policy_, -+ Stages, -+ SharedMemoryClear>; -+ -+ using WarpShape = typename Policy_::Operator::Shape; -+ using InstructionShape = typename Policy_::Operator::InstructionShape; -+ using WarpIteratorA = typename DefaultWarpIteratorAFromSharedMemory< -+ WarpShape, -+ InstructionShape, -+ typename RegularMma::Operator::IteratorA, -+ Policy_>::WarpIterator; -+ -+ static int constexpr kMaxK = AccumulatorSharedStorage_::Shape::kN; -+ // Reduce the number of stages if we don't need that many -+ static int constexpr kStagesMax = -+ (kMaxK + int(Shape_::kK) - 1) / int(Shape_::kK); -+ static int constexpr kStages = cutlass::const_min(Stages, kStagesMax); -+ -+ using IteratorB = -+ typename cutlass::transform::threadblock::MakeIteratorResidualLast< -+ IteratorB_>::Iterator; -+ using Mma = -+ typename cutlass::gemm::threadblock::MmaMultistageFromSharedMemory< -+ Shape_, -+ WarpIteratorA, -+ AccumulatorSharedStorage_, -+ IteratorB, -+ SmemIteratorB_, -+ RegularMma::kCacheOpB, -+ ElementC_, -+ LayoutC_, -+ Policy_, -+ kStages>; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename IteratorC, -+ typename Operator, -+ typename scalar_t, -+ typename WarpShape_, -+ typename ThreadblockShape_> -+struct B2bGemm; -+ -+// Tensor Cores >= Sm75 specialization (Ampere ...) -+template < /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Element type -+ typename Element_, -+ /// Layout of operand in memory -+ typename Layout_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions, concept: MatrixShape) -+ typename OpDelta_, -+ typename Operator, -+ typename scalar_t, -+ typename WarpShape_, -+ typename ThreadblockShape_> -+struct B2bGemm< -+ cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< -+ Shape_, -+ Element_, -+ Layout_, -+ InstructionShape_, -+ OpDelta_>, -+ Operator, -+ scalar_t, -+ WarpShape_, -+ ThreadblockShape_> { -+ using IteratorC = -+ typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< -+ Shape_, -+ Element_, -+ Layout_, -+ InstructionShape_, -+ OpDelta_>; -+ using FragmentC = typename IteratorC::Fragment; -+ using InstructionShape = InstructionShape_; -+ using WarpShape = WarpShape_; -+ using ThreadblockShape = ThreadblockShape_; -+ using accum_t = Element_; -+ using lse_scalar_t = float; -+ -+ using SmemAccumulatorLayout = cutlass::layout::RowMajor; -+ -+ // Iterator to load accumulators (results of matmul in registers) -+ using FragmentIteratorAccumulator = -+ cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ WarpShape, -+ InstructionShape, -+ accum_t, -+ typename Operator::Policy::Operator::FragmentC, -+ cutlass::layout::RowMajor>; -+ -+ // Iterator to store to shared-memory -+ using SmemIteratorD0 = typename cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape, -+ InstructionShape, -+ scalar_t, // accum_t, -+ SmemAccumulatorLayout>; -+ using AccumulatorSharedStorage = -+ cutlass::gemm::threadblock::AccumulatorSharedStorage< -+ ThreadblockShape, -+ typename SmemIteratorD0::Element, -+ typename SmemIteratorD0::TensorLayout, -+ typename SmemIteratorD0::Padding>; -+ // We need to provide an operation for the epilogue. Let's create an -+ // operation that does nothing (ScaleType::Nothing), just converts -+ // from accum_t (float) -> scalar_t (can be half) -+ using OutputOpNoOp = cutlass::epilogue::thread::LinearCombination< -+ typename SmemIteratorD0::Element, // ElementOutput -+ FragmentIteratorAccumulator::Fragment::kElements, -+ accum_t, // ElementAccumulator -+ typename SmemIteratorD0::Element, // ElementCompute -+ cutlass::epilogue::thread::ScaleType::Nothing>; -+ using Epilogue = cutlass::epilogue::threadblock::EpilogueSmemAccumulator< -+ SmemIteratorD0, -+ FragmentIteratorAccumulator, -+ SmemIteratorD0, // ScaleBiasIterator - not used -+ OutputOpNoOp>; -+ -+ // Epilogue 2: with LSE (for backwards pass) -+ static int const kElementsPerAccess = 2; // TODO: Why 2? -+ using IteratorAccumulatorLSE = -+ cutlass::transform::threadblock::VectorIterator< -+ cutlass::transform::threadblock::PredicatedVectorAccessIterator< -+ // Shape -+ cutlass::MatrixShape, -+ // WarpShape -+ cutlass::MatrixShape, -+ lse_scalar_t, -+ cutlass::layout::RowMajor, -+ kElementsPerAccess>>; -+ using EpilogueOpApplyLSE = cutlass::epilogue::thread::ApplyLogSumExp< -+ scalar_t, // ElementOutput_ -+ lse_scalar_t, // ElementLSE_ -+ accum_t, // ElementAccumulator_ -+ accum_t, // ElementCompute_ -+ 128 / cutlass::sizeof_bits::value -+ // FragmentIteratorAccumulator::Fragment::kElements -+ // InstructionShape::kM * InstructionShape::kN / 32 -+ >; -+ using EpilogueWithLSE = -+ cutlass::epilogue::threadblock::EpilogueSmemAccumulator< -+ SmemIteratorD0, -+ FragmentIteratorAccumulator, -+ IteratorAccumulatorLSE, -+ EpilogueOpApplyLSE>; -+ -+ static void CUTLASS_DEVICE accumToSmem( -+ AccumulatorSharedStorage& shared_storage, -+ FragmentC const& accum, -+ int lane_id, -+ cutlass::MatrixCoord const& tile_coords) { -+ SmemIteratorD0 smem_iterator_attn(shared_storage.accum_ref(), lane_id); -+ smem_iterator_attn.add_tile_offset( -+ tile_coords * -+ cutlass::MatrixCoord{ -+ SmemIteratorD0::TileIterations::kRow, -+ SmemIteratorD0::TileIterations::kColumn}); -+ Epilogue epilogue; -+ epilogue(OutputOpNoOp({}), smem_iterator_attn, accum); -+ } -+ -+ static void CUTLASS_DEVICE accumApplyLSEToSmem( -+ AccumulatorSharedStorage& shared_storage, -+ FragmentC& accum, -+ lse_scalar_t const* lse, -+ int32_t lse_extents, -+ int thread_id, -+ int warp_id, -+ int lane_id, -+ cutlass::MatrixCoord const& tile_coords) { -+ constexpr int32_t kAlignLSE = 32; -+ IteratorAccumulatorLSE iterator_lse( -+ lse, -+ {(int32_t)0, (int32_t)ceil_div(lse_extents, kAlignLSE) * kAlignLSE}, -+ thread_id, -+ warp_id, -+ cutlass::MatrixCoord{0, 0} // offset -+ ); -+ -+ SmemIteratorD0 smem_iterator_attn(shared_storage.accum_ref(), lane_id); -+ smem_iterator_attn.add_tile_offset( -+ tile_coords * -+ cutlass::MatrixCoord{ -+ SmemIteratorD0::TileIterations::kRow, -+ SmemIteratorD0::TileIterations::kColumn}); -+ EpilogueWithLSE epilogue; -+ EpilogueOpApplyLSE minus_lse_exp({}); -+ epilogue( -+ minus_lse_exp, -+ smem_iterator_attn, -+ accum, -+ // scale - unused -+ iterator_lse, -+ // bias -+ iterator_lse); -+ } -+}; -+ -+// Volta Specialization -+// only supported for f16 -+template -+struct B2bGemm< -+ cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< -+ cutlass::MatrixShape<32, 32>, -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ cutlass::MatrixShape<1, 1>>, -+ Operator, -+ cutlass::half_t, -+ WarpShape_, -+ ThreadblockShape_> { -+ using IteratorC = -+ cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< -+ cutlass::MatrixShape<32, 32>, -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ cutlass::MatrixShape<1, 1>>; -+ using scalar_t = cutlass::half_t; -+ using accum_t = IteratorC::Element; -+ using WarpShape = WarpShape_; -+ using ThreadblockShape = ThreadblockShape_; -+ using FragmentC = IteratorC::Fragment; -+ using lse_scalar_t = float; -+ -+ using SmemAccumulatorLayout = cutlass::layout::RowMajor; -+ using SmemIteratorD0 = cutlass::epilogue::warp::TileIteratorVoltaTensorOp< -+ WarpShape, -+ cutlass::gemm::GemmShape<32, 32, 4>, -+ scalar_t, -+ SmemAccumulatorLayout>; -+ -+ // // Storage in shared-memory for Q.Kt -+ using AccumulatorSharedStorage = -+ cutlass::gemm::threadblock::AccumulatorSharedStorage< -+ ThreadblockShape, -+ scalar_t, -+ cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise< -+ 16, -+ 32>, // typename SmemIteratorD0::TensorLayout, -+ cutlass::MatrixShape<0, 0> // Padding -+ >; -+ -+ using OutputLayout = -+ cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>; -+ using TensorRef = cutlass::TensorRef; -+ using Policy = typename IteratorC::Policy; -+ using Element = accum_t; -+ // Those are MmaVoltaTensorOpAccumulatorTileIterator private fields -+ // Let's copy their values -+ static int const kElementsPerPartial = 4; -+ using EleShapePerPatial = typename cutlass::platform::conditional< -+ cutlass::platform::is_same::value, -+ cutlass::MatrixShape<2, 2>, -+ cutlass::MatrixShape<1, 4>>::type; -+ static int const kElementsPerMma = 8; -+ static int const kAccumulatorPatials = 2; -+ using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>; -+ -+ static void CUTLASS_DEVICE accumToSmem( -+ AccumulatorSharedStorage& shared_storage, -+ FragmentC const& accum, -+ int lane_id, -+ cutlass::MatrixCoord const& tile_coords) { -+ // ctor - from MmaVoltaTensorOpAccumulatorTileIterator -+ TensorRef ref_(shared_storage.accum_ref()); -+ int quad = (lane_id >> 2); -+ int lane_in_quad = (lane_id & 3); -+ int accum_m, accum_n; -+ -+ if (cutlass::platform::is_same::value) { -+ // (quad[2],quad[0])+lane_in_quad[0] -+ accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1); -+ // (quad[1])+lane_in_quad[1] -+ accum_n = -+ ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials + -+ (lane_in_quad & 2); -+ } else { -+ accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + -+ lane_in_quad; // (quad[2],quad[0]) -+ accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials; -+ } -+ cutlass::MatrixCoord lane_offset(accum_m, accum_n); -+ -+ // Tile offset -+ ref_.add_coord_offset( -+ tile_coords * -+ cutlass::MatrixCoord( -+ {IteratorC::Shape::kRow, IteratorC::Shape::kColumn})); -+ -+ using AccessType = cutlass::Array; -+ -+ // store - from MmaVoltaTensorOpAccumulatorTileIterator -+ CUTLASS_PRAGMA_UNROLL -+ for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; ++tile_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { -+ int mma_accum_start = -+ (((tile_n * Policy::TileIterations::kRow + tile_m) * -+ Policy::MmaIterations::kColumn + -+ mma_n) * -+ Policy::MmaIterations::kRow + -+ mma_m) * -+ kElementsPerMma; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int p = 0; p < kAccumulatorPatials; ++p) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < EleShapePerPatial::kRow; ++m) { -+ int accum_m = tile_m * Policy::InterleavedTile::kRow + -+ mma_m * QuadShapePerPatialMma::kRow + m * 2; -+ int accum_n = tile_n * Policy::InterleavedTile::kColumn + -+ mma_n * QuadShapePerPatialMma::kColumn + -+ p * Policy::InterleavedTile::kColumn / 2; -+ int r = (accum_m + lane_offset.row()); -+ AccessType to_store; -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < EleShapePerPatial::kColumn; ++n) { -+ int idx = mma_accum_start + p * kElementsPerPartial + -+ m * EleShapePerPatial::kColumn + n; -+ int c = (accum_n + n + lane_offset.column()); -+ to_store[n] = scalar_t(accum[idx]); -+ } -+ int c = (accum_n + lane_offset.column()); -+ assert(r < 32); -+ assert(c < 32); -+ *reinterpret_cast( -+ ref_.data() + ref_.offset({r, c})) = to_store; -+ } -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ static void CUTLASS_DEVICE accumApplyLSEToSmem( -+ AccumulatorSharedStorage& shared_storage, -+ typename IteratorC::Fragment& accum, -+ lse_scalar_t const* lse, -+ int lse_extent, -+ int thread_id, -+ int warp_id, -+ int lane_id, -+ cutlass::MatrixCoord const& tile_coords) { -+ // Non-optimized way to apply LSE to registers -+ // NOTE: accum is attn.T -+ // TODO: Optimize for each architecture -+ static constexpr int WarpSize = 32; -+ using RegistersIter = typename DefaultAttentionScalingCoefsUpdater< -+ IteratorC, -+ accum_t, -+ WarpSize>::Updater; -+ auto lane_offset = -+ RegistersIter::get_lane_offset(lane_id, warp_id, tile_coords); -+ -+ cutlass::Array lse_prefetched; -+ lse_prefetched.clear(); -+ int rowIdx = 0; -+ int colIdx = 0; -+ RegistersIter::iterateRows( -+ lane_offset, -+ [&](int accum_m) { -+ ++rowIdx; -+ colIdx = 0; -+ }, -+ [&](int accum_m, int accum_n, int idx) { -+ if (rowIdx == 1) { -+ lse_prefetched[colIdx] = accum_n < lse_extent -+ ? lse[accum_n] -+ : platform::numeric_limits::infinity(); -+ } -+ accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]); -+ ++colIdx; -+ }, -+ [&](int accum_m) {}); -+ accumToSmem(shared_storage, accum, lane_id, tile_coords); -+ } -+}; -+ -+// Simt Specialization -+// for f32 on Sm70-Sm75 and f16/f32 below -+ -+template < -+ typename Operator, -+ typename OperatorPolicy, -+ typename scalar_t, -+ typename WarpShape_, -+ typename ThreadblockShape_> -+struct B2bGemm< -+ cutlass::gemm::warp::MmaSimtTileIterator< -+ cutlass::MatrixShape<32, 32>, -+ cutlass::gemm::Operand::kC, -+ float, -+ cutlass::layout::RowMajor, -+ OperatorPolicy, -+ 1, -+ 1>, -+ Operator, -+ scalar_t, -+ WarpShape_, -+ ThreadblockShape_> { -+ using IteratorC = cutlass::gemm::warp::MmaSimtTileIterator< -+ cutlass::MatrixShape<32, 32>, -+ cutlass::gemm::Operand::kC, -+ float, -+ cutlass::layout::RowMajor, -+ OperatorPolicy, -+ 1, -+ 1>; -+ using accum_t = typename IteratorC::Element; -+ using WarpShape = WarpShape_; -+ using ThreadblockShape = ThreadblockShape_; -+ using FragmentC = typename IteratorC::Fragment; -+ using lse_scalar_t = float; -+ -+ // Storage in shared-memory for Q.Kt -+ using AccumulatorSharedStorage = -+ cutlass::gemm::threadblock::AccumulatorSharedStorage< -+ ThreadblockShape, -+ scalar_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::MatrixShape<0, 0> // Padding -+ >; -+ -+ static void CUTLASS_DEVICE accumToSmem( -+ AccumulatorSharedStorage& shared_storage, -+ FragmentC const& accum, -+ int lane_id, -+ cutlass::MatrixCoord const& tile_coords) { -+ using Policy = typename IteratorC::Policy; -+ using Element = typename IteratorC::Element; -+ using Iterations = typename IteratorC::Iterations; -+ using Delta = typename IteratorC::Delta; -+ -+ auto ref_ = shared_storage.accum_ref(); -+ // ctor - MmaSimtTileIterator -+ // compute offset based on thread ID and lane layout -+ typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); -+ -+ MatrixCoord lane_offset = lane_layout.inverse(lane_id) * -+ MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN); -+ -+ ref_.add_coord_offset(lane_offset); -+ -+ // Tile offset -+ ref_.add_coord_offset( -+ tile_coords * -+ cutlass::MatrixCoord( -+ {IteratorC::Shape::kRow, IteratorC::Shape::kColumn})); -+ -+ // store - MmaSimtTileIterator -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) { -+ int r = -+ Policy::LaneMmaShape::kM * (mma_m * Policy::WarpShape::kRow) + -+ m; -+ int c = mma_n * Delta::kColumn + n; -+ int idx = n + -+ Policy::LaneMmaShape::kN * -+ (mma_n + -+ Iterations::kColumn * -+ (m + mma_m * Policy::LaneMmaShape::kM)); -+ ref_.at({r, c}) = scalar_t(accum[idx]); -+ } -+ } -+ } -+ } -+ } -+ -+ static void CUTLASS_DEVICE accumApplyLSEToSmem( -+ AccumulatorSharedStorage& shared_storage, -+ typename IteratorC::Fragment& accum, -+ lse_scalar_t const* lse, -+ int lse_extent, -+ int thread_id, -+ int warp_id, -+ int lane_id, -+ cutlass::MatrixCoord const& tile_coords) { -+ // Non-optimized way to apply LSE to registers -+ // NOTE: accum is attn.T -+ // TODO: Optimize for each architecture -+ static constexpr int WarpSize = 32; -+ using RegistersIter = typename DefaultAttentionScalingCoefsUpdater< -+ IteratorC, -+ accum_t, -+ WarpSize>::Updater; -+ auto lane_offset = -+ RegistersIter::get_lane_offset(lane_id, warp_id, tile_coords); -+ -+ cutlass::Array lse_prefetched; -+ lse_prefetched.clear(); -+ int rowIdx = 0; -+ int colIdx = 0; -+ RegistersIter::iterateRows( -+ lane_offset, -+ [&](int accum_m) { -+ ++rowIdx; -+ colIdx = 0; -+ }, -+ [&](int accum_m, int accum_n, int idx) { -+ if (rowIdx == 1) { -+ lse_prefetched[colIdx] = accum_n < lse_extent -+ ? lse[accum_n] -+ : platform::numeric_limits::infinity(); -+ } -+ accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]); -+ ++colIdx; -+ }, -+ [&](int accum_m) {}); -+ accumToSmem(shared_storage, accum, lane_id, tile_coords); -+ } -+}; -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/41_fused_multi_head_attention/transform/tile_smem_loader.h b/3rdparty/cutlass/examples/41_fused_multi_head_attention/transform/tile_smem_loader.h -new file mode 100644 -index 0000000..c3a2d9b ---- /dev/null -+++ b/3rdparty/cutlass/examples/41_fused_multi_head_attention/transform/tile_smem_loader.h -@@ -0,0 +1,57 @@ -+#include -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+#include "cutlass/transform/threadblock/regular_tile_iterator.h" -+ -+template < -+ typename scalar_t, // scalar type -+ typename ThreadblockTileShape, // size of tile to load -+ int Threads, // number of participating threads -+ int ElementsPerAccess> // thread access width in elements -+class TileSmemLoader { -+ public: -+ using SmemTile = -+ cutlass::AlignedBuffer; -+ -+ using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap< -+ cutlass::layout::PitchLinearShape< -+ ThreadblockTileShape::kColumn, // contiguous -+ ThreadblockTileShape::kRow>, // strided -+ Threads, // Threads -+ ElementsPerAccess>; // ElementsPerAccess -+ -+ using GmemTileIterator = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ ThreadblockTileShape, // Shape -+ scalar_t, // Element -+ cutlass::layout::RowMajor, // Layout -+ 0, // AdvanceRank -+ ThreadMap>; // ThreadMap -+ -+ using SmemTileIterator = cutlass::transform::threadblock::RegularTileIterator< -+ ThreadblockTileShape, // Shape -+ scalar_t, // Element -+ cutlass::layout::RowMajor, // Layout -+ 0, // AdvanceRank -+ ThreadMap>; // ThreadMap -+ -+ using Fragment = typename GmemTileIterator::Fragment; -+ -+ /// load a tile from global memory into shared memory -+ CUTLASS_DEVICE -+ static void load( -+ GmemTileIterator tile_load_iter, -+ SmemTileIterator tile_store_iter) { -+ Fragment tb_frag; -+ tb_frag.clear(); -+ tile_load_iter.load(tb_frag); -+ tile_store_iter.store(tb_frag); -+ -+ __syncthreads(); -+ } -+}; -\ No newline at end of file -diff --git a/3rdparty/cutlass/examples/42_ampere_tensorop_group_conv/ampere_tensorop_group_conv.cu b/3rdparty/cutlass/examples/42_ampere_tensorop_group_conv/ampere_tensorop_group_conv.cu -new file mode 100644 -index 0000000..5f16ff1 ---- /dev/null -+++ b/3rdparty/cutlass/examples/42_ampere_tensorop_group_conv/ampere_tensorop_group_conv.cu -@@ -0,0 +1,706 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+This example shows how to run group convolution kernels using functions and data structures -+provided by CUTLASS using tensor cores; which we run on a NVIDIA Ampere GPU. -+ -+There are 2 group conv mode: -+ 1. cutlass::conv::GroupMode::kSingleGroup -+ This mode is for large K problem size: k_per_group (K/groups) equals or larger than -+ threadblock_tile_N. One or multiple threadblocks calculate data of one group. -+ 2. cutlass::conv::GroupMode::kMultipleGroup -+ This mode is for small K problem size: k_per_group (K/groups) is smaller than threadblock_tile_N. -+ One threadblock will calculate data from more than one group. -+ -+Function profile_convolution_selecter() shows how to choose kernel with different group mode according -+to problem size and threadblock_tile size. -+*/ -+ -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/conv/kernel/default_conv2d_group_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/convolution.h" -+#include "cutlass/util/reference/device/convolution.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "helper.h" -+ -+// The code section below describes datatype for input, output tensors and computation between -+// elements -+using ElementAccumulator = float; // Data type of accumulator -+using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta) -+using ElementInputA = cutlass::half_t; // Data type of elements in input tensor -+using ElementInputB = cutlass::half_t; // Data type of elements in input tensor -+using ElementOutput = float; // Data type of elements in output tensor -+ -+using LayoutInputA = cutlass::layout::TensorNHWC; -+using LayoutInputB = cutlass::layout::TensorNHWC; -+using LayoutOutput = cutlass::layout::TensorNHWC; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassTensorOp; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm80; -+ -+// This code section describes the tile size a thread block will compute -+using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; // Threadblock tile shape -+ -+// This code section describes tile size a warp will compute -+using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; // Warp tile shape -+ -+// This code section describes the size of MMA op -+using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // TensorCore instruction shape -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 3; -+ -+// This code section describes the epilogue part of the kernel, we use default value -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // Data type of output matrix. -+ 128 / cutlass::sizeof_bits::value, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. -+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue>; // Data type for alpha/beta in linear combination -+ -+// Analytic kernel and operation for single group problem size -+using AnalyticSingleGroupKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop< -+ ElementInputA, LayoutInputA, -+ ElementInputB, LayoutInputB, -+ ElementOutput, LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::GroupMode::kSingleGroup, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+>::Kernel; -+using AnalyticSingleGroupOperation = cutlass::conv::device::ImplicitGemmConvolution; -+ -+// Analytic kernel and operation for multiple group problem size -+using AnalyticMultipleGroupKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop< -+ ElementInputA, LayoutInputA, -+ ElementInputB, LayoutInputB, -+ ElementOutput, LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::GroupMode::kMultipleGroup, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+>::Kernel; -+using AnalyticMultipleGroupOperation = cutlass::conv::device::ImplicitGemmConvolution; -+ -+// Optimized kernel and operation for single group problem size -+using OptimizedSingleGroupKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop< -+ ElementInputA, LayoutInputA, -+ ElementInputB, LayoutInputB, -+ ElementOutput, LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::GroupMode::kSingleGroup, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+>::Kernel; -+using OptimizedSingleGroupOperation = cutlass::conv::device::ImplicitGemmConvolution; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ cutlass::Tensor4DCoord input_size; -+ cutlass::Tensor4DCoord filter_size; -+ cutlass::Tensor4DCoord padding; -+ cutlass::MatrixCoord conv_stride; -+ cutlass::MatrixCoord dilation; -+ int groups; -+ bool reference_check; -+ bool measure_performance; -+ int iterations; -+ ElementComputeEpilogue alpha; -+ ElementComputeEpilogue beta; -+ bool optimized; -+ std::string tag; -+ -+ Options(): -+ help(false), -+ input_size(1, 32, 32, 32), -+ filter_size(32, 3, 3, 32), -+ padding(1, 1, 1, 1), -+ conv_stride(1, 1), -+ dilation(1, 1), -+ groups(1), -+ reference_check(false), -+ measure_performance(false), -+ iterations(20), -+ alpha(1), -+ beta(0), -+ optimized(false) { } -+ -+ // Verify the problem size is compatible with the CUTLASS Convolution implementation. -+ bool valid() { -+ -+ // -+ // CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently, -+ // all pointers, strides, and tensor extents must be divisible by 8 elements. -+ // -+ int const kAlignment = 8; -+ -+ if ((input_size.c() % kAlignment) || -+ (filter_size.n() % kAlignment)) { -+ -+ // misaligned tensors -+ return false; -+ } -+ -+ // Invalid padding -+ if ((padding.h() != filter_size.h() / 2) || -+ (padding.w() != filter_size.w() / 2)) { -+ -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Updates input and filter sizes -+ void update( -+ cutlass::Tensor4DCoord input_size, -+ cutlass::Tensor4DCoord filter_size) { -+ -+ this->input_size = input_size; -+ this->filter_size = filter_size; -+ -+ padding.n() = filter_size.h() / 2; -+ padding.h() = filter_size.h() / 2; -+ padding.w() = filter_size.w() / 2; -+ padding.c() = filter_size.w() / 2; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("ref-check")) { -+ reference_check = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("perf-check")) { -+ measure_performance = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("optimized")) { -+ optimized = true; -+ } -+ -+ cmd.get_cmd_line_argument("n", input_size.n()); -+ cmd.get_cmd_line_argument("h", input_size.h()); -+ cmd.get_cmd_line_argument("w", input_size.w()); -+ cmd.get_cmd_line_argument("c", input_size.c()); -+ -+ cmd.get_cmd_line_argument("k", filter_size.n()); -+ cmd.get_cmd_line_argument("r", filter_size.h()); -+ cmd.get_cmd_line_argument("s", filter_size.w()); -+ -+ cmd.get_cmd_line_argument("g", groups); -+ filter_size.c() = input_size.c() / groups; -+ -+ cmd.get_cmd_line_argument("u", conv_stride.row()); -+ cmd.get_cmd_line_argument("v", conv_stride.column()); -+ -+ cmd.get_cmd_line_argument("alpha", alpha); -+ cmd.get_cmd_line_argument("beta", beta); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ cmd.get_cmd_line_argument("tag", tag); -+ -+ if (filter_size.h() == 3 && filter_size.w() == 3) { -+ padding = {1, 1, 1, 1}; -+ } -+ else { -+ filter_size.h() = 1; -+ filter_size.w() = 1; -+ padding = {0, 0, 0, 0}; -+ } -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "42_ampere_tensorop_group_conv example\n\n" -+ << " This example uses Ampere's Tensor Core operators on F16 data types to compute\n" -+ << " forward grouped convolution on tensors of layout NHWC.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --n= Input tensor extent N\n" -+ << " --h= Input tensor extent H\n" -+ << " --w= Input tensor extent W\n" -+ << " --c= Input tensor extent C\n" -+ << " --k= Filter extent K\n" -+ << " --r= Filter extent R\n" -+ << " --s= Filter extent S\n\n" -+ << " --g= Conv groups G\n\n" -+ << " --u= Conv stride_h\n\n" -+ << " --v= Conv stride_w\n\n" -+ << " --alpha= Epilogue scalar alpha\n" -+ << " --beta= Epilogue scalar beta\n\n" -+ << " --ref-check If set (true), reference check is computed\n" -+ << " --perf-check If set (true), performance is measured.\n" -+ << " --optimized If set (true), use optimized kernel, otherwise use analytic kernel.\n" -+ << " --iterations= Number of profiling iterations to perform.\n" -+ << " --tag= String to replicate across the first column in the results table\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/42_ampere_tensorop_group_conv/42_ampere_tensorop_group_conv --n=4 --h=16 --w=16 --c=256 --k=128 --r=3 --s=3 --g=8 --ref-check\n\n" -+ << "$ ./examples/42_ampere_tensorop_group_conv/42_ampere_tensorop_group_conv --n=4 --h=16 --w=16 --c=256 --k=128 --r=3 --s=3 --g=2 --ref-check\n\n" -+ << "$ ./examples/42_ampere_tensorop_group_conv/42_ampere_tensorop_group_conv --n=4 --h=16 --w=16 --c=256 --k=128 --r=3 --s=3 --g=2 --ref-check --optimized\n\n"; -+ -+ return out; -+ } -+ -+ /// Computes the output tensor size (NPQK) -+ cutlass::Tensor4DCoord output_size() const { -+ return cutlass::Tensor4DCoord( -+ input_size.n(), -+ (input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1, -+ (input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1, -+ filter_size.n()); -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of multiply-adds = NPQK * CRS -+ int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c()); -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct Result { -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cutlass::Status reference_check; -+ cudaError_t error; -+ -+ Result(): -+ runtime_ms(0), -+ gflops(0), -+ status(cutlass::Status::kSuccess), -+ reference_check(cutlass::Status::kInvalid), -+ error(cudaSuccess) { } -+ -+ static std::ostream & print_header(std::ostream &out, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << "Name,"; -+ } -+ -+ out << "Layer,N,H,W,C,K,R,S,G,Runtime,GFLOPs"; -+ -+ return out; -+ } -+ -+ std::ostream & print(std::ostream &out, int idx, Options const &options) { -+ -+ if (!options.tag.empty()) { -+ out << options.tag << ","; -+ } -+ -+ out -+ << "conv_" << idx << "," -+ << options.input_size.n() << "," -+ << options.input_size.h() << "," -+ << options.input_size.w() << "," -+ << options.input_size.c() << "," -+ << options.filter_size.n() << "," -+ << options.filter_size.h() << "," -+ << options.filter_size.w() << "," -+ << options.groups << "," -+ << runtime_ms << "," -+ << gflops; -+ -+ return out; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Runs one benchmark -+template -+Result profile_convolution(Options const &options) { -+ -+ Result result; -+ -+ // -+ // Allocate host-device tensors using the CUTLASS Utilities. -+ // -+ -+ cutlass::HostTensor tensor_a(options.input_size); -+ cutlass::HostTensor tensor_b(options.filter_size); -+ cutlass::HostTensor tensor_c(options.output_size()); -+ cutlass::HostTensor tensor_d(options.output_size()); -+ cutlass::HostTensor tensor_ref_d(options.output_size()); -+ -+ // -+ // Initialize tensors -+ // -+ -+ // Fill tensor A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), -+ 1, -+ ElementInputA(7), -+ ElementInputA(-8), -+ 0); -+ -+ // Fill tensor B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), -+ 1, -+ ElementInputB(7), -+ ElementInputB(-8), -+ 0); -+ -+ // Fill tensor C on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c.host_view(), -+ 1, -+ ElementOutput(7), -+ ElementOutput(-8), -+ 0); -+ -+ // Fill tensor D on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_d.host_view()); -+ -+ // Fill tensor D for reference on host with zeros -+ cutlass::reference::host::TensorFill( -+ tensor_ref_d.host_view()); -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_c.sync_device(); -+ tensor_d.sync_device(); -+ tensor_ref_d.sync_device(); -+ -+ // -+ // Define arguments for CUTLASS Convolution -+ // -+ -+ cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; -+ -+ // Split K dimension into 1 partitions -+ int split_k_slices = 1; -+ -+ // Construct Conv2dProblemSize with user defined output size -+ cutlass::conv::Conv2dProblemSize problem_size( -+ options.input_size, -+ options.filter_size, -+ options.padding, -+ options.conv_stride, -+ options.dilation, -+ options.output_size(), -+ mode, -+ split_k_slices, -+ options.groups -+ ); -+ -+ // Construct Conv2dOperation::Argument structure with conv2d -+ // problem size, data pointers, and epilogue values -+ typename Conv2dOperation::Arguments arguments{ -+ problem_size, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ tensor_c.device_ref(), -+ tensor_d.device_ref(), -+ {options.alpha, options.beta}, -+ }; -+ -+ // -+ // Initialize CUTLASS Convolution -+ // -+ -+ Conv2dOperation implicit_gemm_op; -+ -+ size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ result.status = implicit_gemm_op.can_implement(arguments); -+ CUTLASS_CHECK(result.status); -+ -+ result.status = implicit_gemm_op.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Launch initialized CUTLASS kernel -+ // -+ result.status = implicit_gemm_op(); -+ -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Optional reference check -+ // -+ -+ if (options.reference_check) { -+ std::cout << "Verification on device...\n"; -+ -+ // Compute with reference implementation -+ cutlass::reference::device::Conv2dFprop< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementComputeEpilogue, -+ ElementAccumulator, -+ cutlass::NumericConverter -+ >( -+ problem_size, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ tensor_c.device_ref(), -+ tensor_ref_d.device_ref(), -+ options.alpha, -+ options.beta -+ ); -+ -+ tensor_ref_d.sync_host(); -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ tensor_d.sync_host(); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_d.host_view(), -+ tensor_ref_d.host_view()); -+ -+ if (!passed) { -+ result.reference_check = cutlass::Status::kErrorInternal; -+ std::cout << "ERROR - results miscompared.\n"; -+ } else { -+ result.reference_check = cutlass::Status::kSuccess; -+ std::cout << "Passed.\n"; -+ } -+ } else { -+ result.reference_check = cutlass::Status::kInvalid; -+ } -+ -+ // -+ // Performance measurement -+ // -+ -+ if (options.measure_performance) { -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ } -+ -+ // Record an event at the start of a series of convolution operations. -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Launch a sequence of implicit GEMM operations on the device -+ for (int iteration = 0; iteration < options.iterations; ++iteration) { -+ result.status = implicit_gemm_op(); -+ CUTLASS_CHECK(result.status); -+ } -+ -+ // Record an event when the convolutions have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Print average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // Cleanup -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ } -+ -+ return result; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+Result profile_convolution_selecter(Options const &options) { -+ int k_per_group = options.filter_size.n() / options.groups; -+ -+ // In group conv, if k_per_group < threadblock_N, one Threadblock will calculate multiple groups -+ if (k_per_group < ThreadblockShape::kN) { // MultipleGroup mode -+ if (options.optimized) { -+ std::cerr << "Invalid problem: optimized group conv kernel doesn't support MultipleGroup (one CTA calculate multiple groups) mode" << std::endl; -+ exit(-1); -+ } else { -+ std::cout << "Select AnalyticMultipleGroupOperation\n"; -+ return profile_convolution(options); -+ } -+ } else { // SingleGroup mode -+ if (options.optimized) { -+ std::cout << "Select OptimizedSingleGroupOperation\n"; -+ return profile_convolution(options); -+ } else { -+ std::cout << "Select AnalyticSingleGroupOperation\n"; -+ return profile_convolution(options); -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ -+ bool notSupported = false; -+ -+ // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. -+ // -+ // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples. -+ if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ notSupported = true; -+ } -+ -+ cudaDeviceProp props; -+ CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); -+ -+ if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) { -+ std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80." -+ << std::endl; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ return 0; -+ } -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ // Execute one problem size -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ Result result = profile_convolution_selecter(options); -+ -+ Result::print_header(std::cout, options) << std::endl; -+ result.print(std::cout, 1, options) << std::endl; -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/43_ell_block_sparse_gemm/ell_block_sparse_gemm.cu b/3rdparty/cutlass/examples/43_ell_block_sparse_gemm/ell_block_sparse_gemm.cu -new file mode 100644 -index 0000000..90711df ---- /dev/null -+++ b/3rdparty/cutlass/examples/43_ell_block_sparse_gemm/ell_block_sparse_gemm.cu -@@ -0,0 +1,740 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Block-Ell sparse gemm example. -+ -+ This example performs a Sparse-matrix dense-matrix multiplication (SpMM) operation. -+ Matrix A is stored in the Blocked-Ellpack (Blocked-ELL) storage format. -+ Details about the Blocked-Ellpack (Blocked-ELL) storage format can be found here: -+ https://docs.nvidia.com/cuda/cusparse/index.html#cusparse-generic-spmat-create-blockedell -+ Whereas matrix B is a dense matrix. -+ -+ Blocked-Ellpack or Blocked-ELL storage format comprises of two matrices. -+ First is a packed matrix (ellValue matrix) that stores non-zero values in consecutive blocks, -+ represented by tensor_a in this example. Second is a matrix of indices (ellColInd matrix), -+ represented by tensor_ell_idx in this example, that represent the column indices of the -+ corresponding non-zero blocks. All rows in the matrices must have the same number of blocks. -+ ellColInd can contain -1 values for indicating empty blocks. These matrices store elements in -+ row-major order. -+ -+ Description of parameters and tensors used to represent the Blocked-Ellpack (ELL) format -+ for this example: -+ a_rows - Rows in the sparse matrix. -+ a_cols - Colums in the sparse matrix. -+ a_ell_blocksize - Size of the ELL-Blocks. -+ a_ell_num_columns - Number of columns in the Blocked-Ellpack format (ellValue columns) -+ tensor_a - ellValue matrix, whose size is (a_rows * a_ell_num_columns) -+ tensor_ell_idx - Blocked-ELL Column indices (ellColInd), whose size is -+ (a_rows / a_ell_blocksize) * (a_ell_num_columns / a_ell_blocksize) -+ tensor_b - Input dense matrix whose size is (a_cols * n) -+ tensor_c/tensor_d - Output dense matrix whose size is (a_rows * n) -+ {a_rows, n, a_cols} - Problem size -+ -+*/ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include -+#include -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/gemm_grouped.h" -+#include "cutlass/gemm/kernel/default_gemm_grouped.h" -+#include "cutlass/gemm/device/ell_gemm.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/device_memory.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/host_uncompress.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Result structure -+struct Result { -+ -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cudaError_t error; -+ bool passed; -+ -+ // -+ // Methods -+ // -+ -+ Result( -+ double runtime_ms = 0, -+ double gflops = 0, -+ cutlass::Status status = cutlass::Status::kSuccess, -+ cudaError_t error = cudaSuccess -+ ): -+ runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ bool reference_check; -+ int iterations; -+ int cuda_streams; -+ int a_rows, n, a_cols; -+ int a_ell_num_columns; -+ int a_ell_blocksize; -+ int a_base; -+ float alpha; -+ float beta; -+ -+ // -+ // Methods -+ // -+ -+ Options(): -+ help(false), -+ reference_check(true), -+ iterations(20), -+ cuda_streams(0), -+ a_rows(1024), -+ n(1024), -+ a_cols(1024), -+ a_ell_num_columns(512), -+ a_ell_blocksize(16), -+ a_base(0), -+ alpha(1), -+ beta() -+ { } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ cmd.get_cmd_line_argument("alpha", alpha, 1.0f); -+ cmd.get_cmd_line_argument("beta", beta, 0.0f); -+ cmd.get_cmd_line_argument("iterations", iterations, 20); -+ cmd.get_cmd_line_argument("streams", cuda_streams, 0); -+ cmd.get_cmd_line_argument("reference-check", reference_check, true); -+ -+ cmd.get_cmd_line_argument("a_rows", a_rows, 1024); -+ cmd.get_cmd_line_argument("n", n, 1024); -+ cmd.get_cmd_line_argument("a_cols", a_cols, 1024); -+ -+ cmd.get_cmd_line_argument("a_ell_num_columns", a_ell_num_columns, 512); -+ cmd.get_cmd_line_argument("a_ell_blocksize", a_ell_blocksize, 16); -+ cmd.get_cmd_line_argument("a_base", a_base, 0); -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "43_ell_block_sparse_gemm\n\n" -+ << " This example profiles the performance of a ELL block sparse GEMM kernel.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --a_rows= Sets the number of the rows of the sparse matrix.\n" -+ << " --n= Sets the N dimension.\n" -+ << " --a_cols= Sets the number of columns of the sparse matrix.\n" -+ << " --a_ell_num_columns= Sets the actual number of columns of the Blocked-Ellpack format.\n" -+ << " --a_ell_blocksize= Sets the size of the ELL-Block.\n" -+ << " --a_base= Sets the base index.\n" -+ << " --alpha= Epilogue scalar alpha (real part)\n" -+ << " --beta= Epilogue scalar beta (real part)\n\n" -+ << " --iterations= Number of profiling iterations to perform.\n" -+ << " --reference-check= If true, performs reference check.\n"; -+ -+ out << "\n\nExamples:\n\n" -+ -+ << "# Runs a 1024x1024x1024 ELL block sparse GEMM with 16x16 block size and actual 512 non-zero columns in A operand\n" -+ << "$ ./examples/43_ell_block_sparse_gemm/43_ell_block_sparse_gemm --a_rows=1024 --n=1024 --a_cols=1024 --a_ell_num_columns=512 --a_ell_blocksize=16\n\n"; -+ -+ return out; -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ -+ // Number of real-valued multiply-adds -+ int64_t fmas = (int64_t)a_rows * (int64_t)a_cols * (int64_t)n; -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class Testbed { -+public: -+ -+ // -+ // Type definitions -+ // -+ -+ using ElementA = typename Gemm::ElementA; -+ using ElementB = typename Gemm::ElementB; -+ using ElementC = typename Gemm::ElementC; -+ using ElementAccumulator = typename Gemm::ElementAccumulator; -+ -+ using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; -+ using ElementCompute = typename EpilogueOutputOp::ElementCompute; -+ -+ using LayoutA = typename Gemm::LayoutA; -+ using LayoutB = typename Gemm::LayoutB; -+ using LayoutC = typename Gemm::LayoutC; -+ -+ using MatrixCoord = typename LayoutC::TensorCoord; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ Options options; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ cutlass::Distribution::Kind init_ELL; -+ uint32_t seed; -+ -+ cutlass::HostTensor tensor_a; -+ cutlass::HostTensor tensor_b; -+ cutlass::HostTensor tensor_c; -+ cutlass::HostTensor tensor_d; -+ -+ cutlass::HostTensor tensor_a_uncompressed; -+ cutlass::HostTensor reference_d; -+ -+ cutlass::HostTensor tensor_ell_idx; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ Testbed( -+ Options const &options_, -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_ELL_ = cutlass::Distribution::Uniform, -+ uint32_t seed_ = 3080 -+ ): -+ options(options_), init_A(init_A_), init_B(init_B_), init_C(init_C_), init_ELL(init_ELL_), seed(seed_) { } -+ -+private: -+ -+ /// Helper to initialize a tensor view -+ template -+ void initialize_tensor_( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint32_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ Element scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ if (cutlass::sizeof_bits::value <= 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } -+ else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope_max, scope_min, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian( -+ view, seed, Element(), Element(0.5f)); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ // Fill with increasing elements -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity(), Element(1), Element()); -+ } else { -+ -+ // Fill with all 1s -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity(), Element(), Element(1)); -+ } -+ } -+ -+ /// Initializes data structures -+ void initialize_() { -+ tensor_a.resize(cutlass::make_Coord(options.a_rows, options.a_ell_num_columns)); -+ tensor_b.resize(cutlass::make_Coord(options.a_cols, options.n)); -+ tensor_c.resize(cutlass::make_Coord(options.a_rows, options.n)); -+ tensor_d.resize(cutlass::make_Coord(options.a_rows, options.n)); -+ -+ tensor_a_uncompressed.resize(cutlass::make_Coord(options.a_rows, options.a_cols)); -+ reference_d.resize(cutlass::make_Coord(options.a_rows, options.n)); -+ -+ tensor_ell_idx.resize(cutlass::make_Coord(options.a_rows / options.a_ell_blocksize, -+ options.a_ell_num_columns / options.a_ell_blocksize)); -+ -+ // -+ // Initialize the problems of the workspace -+ // -+ -+ initialize_tensor_(tensor_a.host_view(), init_A, seed * 2021); -+ initialize_tensor_(tensor_b.host_view(), init_B, seed * 2022); -+ initialize_tensor_(tensor_c.host_view(), init_C, seed * 2023); -+ -+ if (init_ELL == cutlass::Distribution::Uniform) { -+ cutlass::reference::host::TensorFillRandomEllIdx( -+ tensor_ell_idx.host_view(), seed, -+ options.a_rows / options.a_ell_blocksize, -+ options.a_ell_num_columns / options.a_ell_blocksize, -+ options.a_cols / options.a_ell_blocksize); -+ -+ } else { -+ for(int i = 0; i < options.a_rows / options.a_ell_blocksize; ++i) { -+ for(int j = 0; j < options.a_ell_num_columns / options.a_ell_blocksize; ++j) { -+ tensor_ell_idx.at({i, j}) = j+3; -+ } -+ } -+ } -+ -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_c.sync_device(); -+ tensor_d.sync_device(); -+ tensor_ell_idx.sync_device(); -+ } -+ -+ /// Verifies the result is a GEMM -+ bool verify_() { -+ -+ bool passed = true; -+ -+ tensor_d.sync_host(); -+ -+ cutlass::uncompress_ell_block_sparse( -+ tensor_a_uncompressed.host_ref(), -+ tensor_a.host_ref(), -+ tensor_ell_idx.host_ref(), -+ options.a_rows, -+ options.a_cols, -+ options.a_ell_num_columns, -+ options.a_ell_blocksize -+ ); -+ -+ cutlass::reference::host::Gemm< -+ typename Gemm::ElementA, typename Gemm::LayoutA, -+ typename Gemm::ElementB, typename Gemm::LayoutB, -+ typename Gemm::ElementC, typename Gemm::LayoutC, -+ ElementCompute, -+ ElementAccumulator, typename Gemm::Operator> -+ reference_gemm; -+ -+ reference_gemm( -+ {options.a_rows, options.n, options.a_cols}, -+ options.alpha, -+ tensor_a_uncompressed.host_ref(), -+ tensor_b.host_ref(), -+ options.beta, -+ reference_d.host_ref(), -+ ElementAccumulator(0) -+ ); -+ -+ // Reference check -+ passed = cutlass::reference::host::TensorEquals(tensor_d.host_view(), reference_d.host_view()); -+ -+ if (!passed) { -+ std::cerr << "\n***\nError - problem failed the QA check\n***\n" << std::endl; -+ -+ std::stringstream fname; -+ -+ fname << "error_43_ell_block_sparse_gemm" -+ << "mnk_" -+ << options.a_rows << "x" -+ << options.n << "x" -+ << options.a_cols << "_" -+ << options.a_ell_num_columns << "_" -+ << options.a_ell_blocksize << ".txt"; -+ -+ std::cout << fname.str() << std::endl; -+ -+ std::ofstream results(fname.str()); -+ -+ results -+ << "alpha: " << ElementCompute(options.alpha) << "\n" -+ << "beta: " << ElementCompute(options.beta) << "\n" -+ << "block size: " << options.a_ell_blocksize << "\n" -+ << "\nA:\n" << tensor_a.host_view() << "\n" -+ << "\nA Ell Index:\n" << tensor_ell_idx.host_view() << "\n" -+ << "\nB:\n" << tensor_b.host_view() << "\n" -+ << "\nC:\n" << tensor_c.host_view() << "\n" -+ << "\nD reference:\n" << reference_d.host_view() << "\n" -+ << "\nD computed:\n" << tensor_d.host_view() << "\n"; -+ -+ -+ return passed; -+ } -+ -+ return passed; -+ } -+ -+public: -+ -+ /// Returns the number of threadblocks to launch if the kernel can run on the target -+ /// device. Otherwise, returns zero. -+ bool sufficient() const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename Gemm::GemmKernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Executes a BlockedEll SpMM kernel and measures runtime. -+ Result profile() { -+ -+ Result result; -+ -+ // Early exit -+ if (!sufficient()) { -+ std::cout << "Active CUDA device lacks hardware resources to run CUTLASS BlockedEll SpMM kernel." << std::endl; -+ return result; -+ } -+ -+ result.passed = false; -+ -+ // Initialize the problem -+ initialize_(); -+ -+ // Configure the GEMM arguments -+ typename EpilogueOutputOp::Params epilogue_op(options.alpha, options.beta); -+ -+ // Configure GEMM arguments -+ typename Gemm::Arguments args( -+ {options.a_rows, options.n, options.a_cols}, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ tensor_c.device_ref(), -+ tensor_d.device_ref(), -+ tensor_ell_idx.device_data(), -+ options.a_ell_num_columns, -+ options.a_ell_blocksize, -+ options.a_base, -+ epilogue_op -+ ); -+ -+ // Initialize the GEMM object -+ Gemm gemm; -+ -+ result.status = gemm.initialize(args); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to initialize CUTLASS BlockedEll SpMM kernel." << std::endl; -+ return result; -+ } -+ -+ // Run the BlockedEll SpMM object -+ result.status = gemm.run(); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to run CUTLASS BlockedEll SpMM kernel." << std::endl; -+ return result; -+ } -+ -+ // Wait for completion -+ result.error = cudaDeviceSynchronize(); -+ -+ if (result.error != cudaSuccess) { -+ std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); -+ return result; -+ } -+ -+ // -+ // Verify correctness -+ // -+ result.passed = true; -+ -+ if (options.reference_check) { -+ result.passed = verify_(); -+ } -+ -+ // -+ // Warm-up run -+ // -+ result.status = gemm.run(); -+ -+ if (result.status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to run CUTLASS BlockedEll SpMM kernel." << std::endl; -+ return result; -+ } -+ -+ // -+ // Construct events -+ // -+ -+ cudaEvent_t events[2]; -+ -+ for (auto & event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return -1; -+ } -+ } -+ -+ // Record an event at the start of a series of GEMM operations -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // -+ // Run profiling loop -+ // -+ -+ for (int iter = 0; iter < options.iterations; ++iter) { -+ gemm(); -+ } -+ -+ // -+ // Stop profiling loop -+ // -+ -+ // Record an event when the GEMM operations have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Compute average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // -+ // Cleanup -+ // -+ -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ std::cout << std::endl; -+ std::cout << "ELL Block Sparse GEMM (CUTLASS):\n" -+ << "====================================================" << std::endl; -+ -+ std::cout << std::endl; -+ std::cout << " " << "Runtime: " << result.runtime_ms << " ms" << std::endl; -+ std::cout << " " << " GFLOPs: " << result.gflops << std::endl; -+ -+ return result; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ -+ // -+ // This example uses mma.sync to directly access Tensor Cores to achieve peak performance. -+ // -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return -1; -+ } -+ -+ if (__CUDACC_VER_MAJOR__ < 11 || props.major < 8) { -+ -+ // -+ // This example requires an NVIDIA Ampere-architecture GPU. -+ // -+ -+ std::cout -+ << "CUTLASS's BlockedEll SpMM example requires a GPU of NVIDIA's Ampere Architecture or " -+ << "later (compute capability 80 or greater).\n"; -+ -+ return 0; -+ } -+ -+ // -+ // Parse options -+ // -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ // -+ // Define the BlockedEll type -+ // -+ -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ constexpr int32_t kAlignmentA = 128 / cutlass::sizeof_bits::value; -+ constexpr int32_t kAlignmentB = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ constexpr int32_t kStages = 4; -+ using Gemm = typename cutlass::gemm::device::EllGemm< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementOutput, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ kStages, kAlignmentA, kAlignmentB>; -+ -+ // -+ // Profile it -+ // -+ -+ Testbed testbed(options); -+ -+ if (!testbed.sufficient()) { -+ std::cout << "The active CUDA device lacks sufficient hardware resources to execute this kernel.\n"; -+ return 0; -+ } -+ -+ Result result = testbed.profile(); -+ if (!result.passed) { -+ std::cout << "Profiling CUTLASS ELL block sparse GEMM has failed.\n"; -+ std::cout << "\nFailed\n"; -+ return -1; -+ } -+ -+ std::cout << "\nPassed\n"; -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_bias_act_epilogue_tensor_op.h b/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_bias_act_epilogue_tensor_op.h -new file mode 100644 -index 0000000..fead537 ---- /dev/null -+++ b/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_bias_act_epilogue_tensor_op.h -@@ -0,0 +1,154 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/epilogue/thread/linear_combination_clamp.h" -+#include "cutlass/epilogue/thread/conversion_op.h" -+#include "cutlass/epilogue/thread/reduction_op.h" -+ -+#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" -+ -+#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h" -+#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h" -+#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" -+#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h" -+#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+#include "cutlass/epilogue/threadblock/shared_load_iterator.h" -+#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h" -+ -+// #include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/threadblock/interleaved_epilogue.h" -+ -+#include "fused_bias_act_epilogue.h" -+#include "../warp/fused_bias_act_fragment_iterator_tensor_op.h" -+#include "output_tile_thread_map_for_fused_bias.h" -+#include "default_thread_map_tensor_op_for_fused_bias.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues for TensorOps. -+template < -+ typename Shape_, -+ typename WarpMmaTensorOp_, -+ int PartitionsK, -+ typename OutputOp_, -+ int ElementsPerAccess -+> -+struct DefaultFusedBiasActEpilogueTensorOp { -+ -+ using Shape = Shape_; -+ using WarpMmaTensorOp = WarpMmaTensorOp_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputOp = OutputOp_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using LayoutC = typename WarpMmaTensorOp::LayoutC; -+ using ElementAccumulator = typename WarpMmaTensorOp::ElementC; -+ -+ // -+ // Thread map -+ // -+ -+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOpForFusedBias< -+ Shape, -+ typename WarpMmaTensorOp::Shape, -+ kPartitionsK, -+ ElementOutput, -+ kElementsPerAccess -+ >::Type; -+ -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ OutputTileThreadMap, -+ ElementOutput -+ >; -+ -+ using AccumulatorFragmentIterator = typename std::conditional::value, -+ cutlass::epilogue::warp::FragmentIteratorComplexTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::ElementC, -+ typename WarpMmaTensorOp::Policy::Operator::FragmentC, -+ LayoutC>, -+ cutlass::epilogue::warp::FusedBiasActFragmentIteratorTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::ElementC, -+ typename WarpMmaTensorOp::Policy::Operator::FragmentC, -+ LayoutC> >::type; -+ -+ // -+ // Define the epilogue -+ // -+ using Epilogue = cutlass::epilogue::threadblock::FusedBiasActEpilogue< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputTileIterator, -+ AccumulatorFragmentIterator, -+ OutputOp -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_thread_map_tensor_op_for_fused_bias.h b/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_thread_map_tensor_op_for_fused_bias.h -new file mode 100644 -index 0000000..d9ce0f8 ---- /dev/null -+++ b/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_thread_map_tensor_op_for_fused_bias.h -@@ -0,0 +1,113 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/layout/pitch_linear.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines the optimal thread map for TensorOp accumulator layouts -+template < -+ typename ThreadblockShape_, -+ typename WarpShape_, -+ int PartitionsK, -+ typename Element_, -+ int ElementsPerAccess -+> -+struct DefaultThreadMapTensorOpForFusedBias { -+ -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ static int const kPartitionsK = PartitionsK; -+ using Element = Element_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ // -+ // Definitions -+ // -+ -+ struct Detail { -+ -+ /// Tensor Operations fundamentally perform operations on 8 rows -+ static int const kTensorOpRows = 8; -+ static int const kWarpSize = 32; -+ -+ static_assert( -+ !(ThreadblockShape::kM % WarpShape::kM) && -+ !(ThreadblockShape::kM % WarpShape::kM), "Divisibility"); -+ -+ /// Number of warps -+ using WarpCount = gemm::GemmShape< -+ ThreadblockShape::kM / WarpShape::kM, -+ ThreadblockShape::kN / WarpShape::kN, -+ kPartitionsK -+ >; -+ -+ /// Number of participating threads -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ }; -+ -+ // -+ // ThreadMap -+ // -+ -+ /// ThreadMap to be used by epilogue::PredicatedTileIterator satisfying concept OutputTileThreadMap -+ using Type = OutputTileOptimalThreadMapBiasAct < -+ OutputTileShape, -+ OutputTileShape<1, WarpShape::kM / Detail::kTensorOpRows, 1, 1, WarpShape::kM / Detail::kTensorOpRows>, -+ Detail::kThreads, -+ kElementsPerAccess, -+ sizeof_bits::value -+ >; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h b/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h -new file mode 100644 -index 0000000..8b9c24c ---- /dev/null -+++ b/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h -@@ -0,0 +1,222 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/vector.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/tensor_coord.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/functional.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/regular_tile_iterator.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue_base.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Epilogue operator without splitk -+template < -+ typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) -+ typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) -+ int PartitionsK, ///< Number of partitions of the K dimension -+ typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors -+ typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators -+ typename OutputOp_ ///< Output operator -+> -+class FusedBiasActEpilogue { -+ -+public: -+ -+ using Shape = Shape_; -+ using WarpMmaOperator = WarpMmaOperator_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputTileIterator = OutputTileIterator_; -+ using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; -+ using OutputOp = OutputOp_; -+ -+ /// Output layout is always row-major -+ using Layout = layout::RowMajor; -+ using LongIndex = typename Layout::LongIndex; -+ -+ /// The complete warp-level accumulator tile -+ using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile; -+ -+ /// Output element -+ using ElementOutput = typename OutputTileIterator::Element; -+ -+ /// Output access size -+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; -+ -+ -+public: -+ -+ -+ static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero."); -+ -+ static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess), -+ "Divisibility"); -+ -+public: -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ FusedBiasActEpilogue( -+ ){ } -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void operator()( -+ OutputOp const &output_op, ///< Output operator -+ AccumulatorTile &accumulators, ///< Complete warp-level accumulator tile -+ AccumulatorTile & fused_bias_act_accumlators, -+ OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) -+ -+ bool need_bias = output_op.is_source_needed(); -+ -+ if (need_bias) -+ compute_source_needed_(output_op, accumulators, fused_bias_act_accumlators, source_iterator); -+ else -+ compute_source_no_needed_(output_op, accumulators, fused_bias_act_accumlators); -+ -+ -+ } -+ -+ CUTLASS_DEVICE -+ void operator()( -+ OutputOp const &output_op, ///< Output operator -+ AccumulatorTile &accumulators, ///< Complete warp-level accumulator tile -+ AccumulatorTile & fused_bias_act_accumlators) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) -+ -+ compute_source_no_needed_(output_op, accumulators, fused_bias_act_accumlators); -+ } -+ -+ CUTLASS_DEVICE -+ void compute_source_needed_( -+ OutputOp const &output_op, ///< Output operator -+ AccumulatorTile &accumulators, ///< Complete warp-level accumulator tile -+ AccumulatorTile & fused_bias_act_accumlators, -+ OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) -+ -+ typename OutputTileIterator::Fragment source_fragment; -+ -+ -+ source_fragment.clear(); -+ -+ AccumulatorFragmentIterator accum_fragment_iterator(accumulators); -+ AccumulatorFragmentIterator fused_bias_act_fragment_iterator(fused_bias_act_accumlators); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { -+ -+ source_iterator.load(source_fragment); -+ ++source_iterator; -+ -+ typename AccumulatorFragmentIterator::Fragment accum_fragment; -+ -+ accum_fragment_iterator.load(accum_fragment); -+ ++accum_fragment_iterator; -+ -+ typename AccumulatorFragmentIterator::Fragment fused_bias_act_fragment; -+ fused_bias_act_fragment = output_op(accum_fragment, source_fragment); -+ -+ fused_bias_act_fragment_iterator.store(fused_bias_act_fragment); -+ ++fused_bias_act_fragment_iterator; -+ } -+ } -+ -+ CUTLASS_DEVICE -+ void compute_source_no_needed_( -+ OutputOp const &output_op, ///< Output operator -+ AccumulatorTile &accumulators, ///< Complete warp-level accumulator tile -+ AccumulatorTile & fused_bias_act_accumlators) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) -+ -+ -+ AccumulatorFragmentIterator accum_fragment_iterator(accumulators); -+ AccumulatorFragmentIterator fused_bias_act_fragment_iterator(fused_bias_act_accumlators); -+ -+ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int iter = 0; iter < AccumulatorFragmentIterator::kIterations; ++iter) { -+ -+ typename AccumulatorFragmentIterator::Fragment accum_fragment; -+ -+ accum_fragment_iterator.load(accum_fragment); -+ ++accum_fragment_iterator; -+ -+ typename AccumulatorFragmentIterator::Fragment fused_bias_act_fragment; -+ fused_bias_act_fragment = output_op(accum_fragment); -+ -+ fused_bias_act_fragment_iterator.store(fused_bias_act_fragment); -+ ++fused_bias_act_fragment_iterator; -+ } -+ } -+ -+}; -+ -+ -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/output_tile_thread_map_for_fused_bias.h b/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/output_tile_thread_map_for_fused_bias.h -new file mode 100644 -index 0000000..66a6a34 ---- /dev/null -+++ b/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/output_tile_thread_map_for_fused_bias.h -@@ -0,0 +1,311 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Metaprogram for determining the mapping of output elements to threads for epilogue tiles. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/fast_math.h" -+ -+#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+/// RowArrangement determines how one or more warps cover a region of consecutive rows. -+template < -+ typename Shape, -+ int WarpsRemaining, -+ int ElementsPerAccess, -+ int ElementSize, -+ bool Is2dTile -+> -+struct RowArrangementBiasAct; -+ -+/// RowArrangement in which each warp's access is a 1D tiled arrangement. -+template < -+ typename Shape, -+ int WarpsRemaining, -+ int ElementsPerAccess, -+ int ElementSize -+> -+struct RowArrangementBiasAct { -+ static int const kWarpSize = 32; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static int const kElementSize = ElementSize; -+ -+ static int const kIterationsRow = 1; -+ static int const kDeltaRow = 1; -+ static int const kIterationsColumn = Shape::kColumn / kElementsPerAccess / kWarpSize; -+ static int const kDeltaColumn = kWarpSize * kElementsPerAccess; -+ -+ static int const kAccessWidth = kWarpSize; -+ static int const kAccessRows = 1; -+ static int const kWarpPartitionsRow = 1; -+ static int const kWarpPartitionsColumn = WarpsRemaining; -+}; -+ -+/// RowArrangement in which each warp's access is a 2D tiled arrangement. -+template < -+ typename Shape, -+ int WarpsRemaining, -+ int ElementsPerAccess, -+ int ElementSize -+> -+struct RowArrangementBiasAct { -+ -+ static int const kMemoryAccessSize = 4;//128; -+ static int const kWarpSize = 32; -+ -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static int const kElementSize = ElementSize; -+ -+ struct Detail { -+ static int const kShapeRow = Shape::kRow / WarpsRemaining; -+ static int const kShapeWidth = Shape::kColumn / kElementsPerAccess; -+ -+ static int const kTargetMemoryAccessWidth = -+ kMemoryAccessSize / (kElementsPerAccess * kElementSize / 8); -+ -+ static int const kTargetAccessRows = kWarpSize / kTargetMemoryAccessWidth; -+ }; -+ -+ static int const kAccessWidth = -+ (Detail::kTargetAccessRows > Detail::kShapeRow ? -+ kWarpSize / Detail::kShapeRow -+ : const_min( -+ Detail::kShapeWidth, -+ const_min(kWarpSize, kMemoryAccessSize / (kElementsPerAccess * kElementSize / 8)) -+ )); -+ -+ static int const kAccessRows = -+ (Detail::kTargetAccessRows > Detail::kShapeRow ? -+ Detail::kShapeRow -+ : const_min(Shape::kRow, kWarpSize / kAccessWidth)); -+ -+ static int const kIterationsRow = Detail::kShapeRow / kAccessRows; -+ static int const kDeltaRow = kAccessRows; -+ -+ static int const kIterationsColumn = Detail::kShapeWidth / kAccessWidth; -+ static int const kDeltaColumn = kAccessWidth * kElementsPerAccess; -+ -+ static_assert( kAccessWidth * kElementsPerAccess <= Shape::kColumn, "Accessing too many elements per access"); -+ static_assert( kIterationsColumn > 0, "Iteration Count Column must be > 0" ); -+ static_assert( kIterationsRow > 0, "Iteration Count Row must be > 0" ); -+ -+ static int const kWarpPartitionsRow = 1; -+ static int const kWarpPartitionsColumn = 1; -+}; -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template metaprogram for partitioning a 4D space across warps to achieve several performance -+/// objectives: -+/// -+/// - coalesced memory accesses in units of 16 Byte lines -+/// - minimal address arithmetic -+/// - minimal predicate calculations -+/// -+template < -+ typename Shape_, -+ typename Count_, -+ int Threads, -+ int ElementsPerAccess, -+ int ElementSize -+> -+struct OutputTileOptimalThreadMapBiasAct { -+ -+ using Shape = Shape_; -+ using Count = Count_; -+ -+ static int const kWarpSize = 32; -+ static int const kThreads = Threads; -+ static int const kWarpCount = kThreads / kWarpSize; -+ -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static int const kElementSize = ElementSize; -+ -+ // -+ // Metaprogram computation -+ // -+ -+ struct Detail { -+ -+ // Clusters -+ static int const kIterationsCluster = -+ ((Shape::kCluster > kWarpCount) ? -+ Shape::kCluster / kWarpCount -+ : 1); -+ -+ static int const kDeltaCluster = -+ ((Shape::kCluster > kWarpCount) ? -+ Shape::kRow * Count::kRow * Shape::kGroup * Count::kGroup * Shape::kCluster / kIterationsCluster -+ : 1); -+ -+ static int const kCompactedDeltaCluster = -+ ((Shape::kCluster > kWarpCount) ? -+ Shape::kRow * Shape::kGroup * Shape::kCluster / kIterationsCluster -+ : 1); -+ -+ static int const kWarpPartitionsCluster = -+ ((Shape::kCluster > kWarpCount) ? -+ kWarpCount -+ : kWarpCount / Shape::kCluster); -+ -+ static int const kWarpsRemainingForGroups = -+ ((Shape::kCluster > kWarpCount) ? 1 : kWarpCount / Shape::kCluster); -+ -+ // Groups -+ static int const kIterationsGroup = -+ ((Shape::kGroup > kWarpsRemainingForGroups) ? -+ Shape::kGroup / kWarpsRemainingForGroups -+ : 1); -+ -+ static int const kDeltaGroup = -+ ((Shape::kGroup > kWarpsRemainingForGroups) ? -+ Shape::kRow * Count::kRow * Shape::kGroup / kIterationsGroup -+ : 1); -+ -+ static int const kCompactedDeltaGroup = -+ ((Shape::kGroup > kWarpsRemainingForGroups) ? -+ Shape::kRow * Shape::kGroup / kIterationsGroup -+ : 1); -+ -+ static int const kWarpPartitionsGroup = -+ ((Shape::kGroup > kWarpsRemainingForGroups) ? -+ 1 -+ : kWarpsRemainingForGroups / Shape::kGroup); -+ -+ static int const kWarpsRemainingForRows = -+ ((Shape::kGroup > kWarpsRemainingForGroups) ? -+ 1 -+ : kWarpsRemainingForGroups / Shape::kGroup); -+ -+ // Rows -+ using RowArrangement = detail::RowArrangementBiasAct< -+ Shape, -+ kWarpsRemainingForRows, -+ kElementsPerAccess, -+ kElementSize, -+ (Shape::kRow > kWarpsRemainingForRows) -+ >; -+ -+ // Warp partitions -+ using WarpPartitions = OutputTileShape< -+ RowArrangement::kWarpPartitionsColumn, -+ RowArrangement::kWarpPartitionsRow, -+ kWarpPartitionsGroup, -+ kWarpPartitionsCluster, -+ 1>; -+ -+ static int const kAccessWidth = RowArrangement::kAccessWidth; -+ static int const kAccessRows = RowArrangement::kAccessRows; -+ }; -+ -+ // -+ // Output -+ // -+ -+ using Iterations = OutputTileShape< -+ Detail::RowArrangement::kIterationsColumn, -+ Detail::RowArrangement::kIterationsRow, -+ Detail::kIterationsGroup, -+ Detail::kIterationsCluster, -+ 1>; -+ -+ using Delta = OutputTileShape< -+ Detail::RowArrangement::kDeltaColumn, -+ Detail::RowArrangement::kDeltaRow, -+ Detail::kDeltaGroup, -+ Detail::kDeltaCluster, -+ 1>; -+ -+ /// Initial offset function -+ CUTLASS_HOST_DEVICE -+ static MatrixCoord initial_offset(int thread_idx) { -+ -+ int warp_idx = thread_idx / kWarpSize; -+ int lane_idx = thread_idx % kWarpSize; -+ -+ // Compute warp location -+ int cluster_idx = warp_idx / Detail::WarpPartitions::kCluster; -+ int residual_cluster = warp_idx % Detail::WarpPartitions::kCluster; -+ -+ int group_idx = residual_cluster / Detail::WarpPartitions::kGroup; -+ int residual_group = residual_cluster % Detail::WarpPartitions::kGroup; -+ -+ int row_idx = residual_group / Detail::WarpPartitions::kRow; -+ int col_idx = residual_group % Detail::WarpPartitions::kRow; -+ -+ // Compute per-lane offset -+ int lane_row_offset = lane_idx / Detail::kAccessWidth; -+ int lane_col_offset = lane_idx % Detail::kAccessWidth; -+ -+ // Compute coordinate in output space -+ int cluster_offset = cluster_idx * Shape::kRow * Count::kRow * Shape::kGroup * Count::kGroup; -+ int group_offset = group_idx * Shape::kRow * Count::kRow; -+ int row_offset = row_idx * Iterations::kRow * Detail::kAccessRows; -+ int column_offset = col_idx * Iterations::kColumn * Detail::kAccessWidth * kElementsPerAccess; -+ -+ return MatrixCoord( -+ cluster_offset + group_offset + row_offset + lane_row_offset, -+ (column_offset + lane_col_offset) * kElementsPerAccess -+ ); -+ } -+ -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -diff --git a/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/warp/fused_bias_act_fragment_iterator_tensor_op.h b/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/warp/fused_bias_act_fragment_iterator_tensor_op.h -new file mode 100644 -index 0000000..9d7a6c7 ---- /dev/null -+++ b/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/warp/fused_bias_act_fragment_iterator_tensor_op.h -@@ -0,0 +1,189 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief This defines a "fragment" iterator for visiting the fragments of an accumulator tile -+ that participate in one warp-level store operation. -+ -+ Typically, the accumulator tile is the largest single block of register-backed storage -+ within the kernel. Storing it to memory is best accomplished by partitioning it into -+ smaller tiles and storing these sequentially. -+ -+ Round trips through shared memory during the Epilogue phase require partitioning, as -+ shared memory capacity is typically insufficient for a threadblock's total accumulator -+ size. -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/epilogue/warp/tensor_op_policy.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace warp { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// -+template < -+ typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) -+ typename OperatorShape, ///< matrix multiply operation shape (concept: gemm::GemmShape) -+ typename OperatorElementC, ///< matrix multiply operation data type (concept: data type) -+ typename OperatorFragmentC, ///< matrix multiply operation fragment (concept: Array) -+ typename Layout ///< target shared memory layout -+> -+class FusedBiasActFragmentIteratorTensorOp; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for row-major shared memory -+template < -+ typename WarpShape_, ///< shape of the warp-level GEMM tile -+ typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape) -+ typename OperatorElementC_, ///< matrix multiply operation data type (concept: data type) -+ typename OperatorFragmentC_ ///< matrix multiply operation fragment (concept: Array) -+> -+class FusedBiasActFragmentIteratorTensorOp { -+public: -+ -+ using WarpShape = WarpShape_; -+ using OperatorShape = OperatorShape_; -+ using OperatorElementC = OperatorElementC_; -+ using OperatorFragmentC = OperatorFragmentC_; -+ using Layout = layout::RowMajor; -+ -+ using Policy = TensorOpPolicy; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = Array< -+ OperatorElementC, -+ Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>; -+ -+ /// This is the complete warp-level accumulator tile. -+ using AccumulatorTile = Array< -+ OperatorElementC, -+ OperatorFragmentC::kElements * Policy::OperatorCount::kRow * Policy::OperatorCount::kColumn>; -+ -+ using OutputAccumulatorTile = AccumulatorTile; -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = Policy::kIterations; -+ -+private: -+ -+ /// Internal access type -+ using AccessType = Array; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Accumulator tile -+ AccessType *accumulators_; -+ -+ /// Internal index -+ int index_; -+ -+public: -+ -+ /// Constructs an iterator -+ CUTLASS_HOST_DEVICE -+ FusedBiasActFragmentIteratorTensorOp(AccumulatorTile &accum): -+ accumulators_(reinterpret_cast(&accum)), -+ index_(0) { -+ } -+ -+ /// Increments -+ CUTLASS_HOST_DEVICE -+ FusedBiasActFragmentIteratorTensorOp &operator++() { -+ ++index_; -+ return *this; -+ } -+ -+ /// Decrements -+ CUTLASS_HOST_DEVICE -+ FusedBiasActFragmentIteratorTensorOp &operator--() { -+ --index_; -+ return *this; -+ } -+ -+ /// Loads a fragment from the referenced part of the accumulator tile -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag, int index_offset = 0) const { -+ -+ int index = index_ + index_offset; -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { -+ -+ int accumulator_access_offset = -+ index + n * Policy::kAccumulatorColumnStride / Policy::kElementsPerAccess; -+ -+ frag_ptr[n] = accumulators_[accumulator_access_offset]; -+ } -+ } -+ /// Stores a fragment from the referenced part of the accumulator tile -+ CUTLASS_HOST_DEVICE -+ void store(Fragment &frag, int index_offset = 0) const { -+ -+ int index = index_ + index_offset; -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { -+ -+ int accumulator_access_offset = -+ index + n * Policy::kAccumulatorColumnStride / Policy::kElementsPerAccess; -+ -+ accumulators_[accumulator_access_offset] = frag_ptr[n]; -+ } -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/gemm/warp/mma_tensor_op_fragment_iterator_without_output_op.h b/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/gemm/warp/mma_tensor_op_fragment_iterator_without_output_op.h -new file mode 100644 -index 0000000..05a4c90 ---- /dev/null -+++ b/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/gemm/warp/mma_tensor_op_fragment_iterator_without_output_op.h -@@ -0,0 +1,427 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/numeric_conversion.h" -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Size of the accumulation tile shape (concept: MatrixShape) -+ typename AccumulatorShape_, -+ /// KBlocks columns to compute residual -+ int KBlocksColumn_, -+ /// Accumulator Element type -+ typename ElementAccumulator_, -+ /// Element type -+ typename Element_, -+ /// Layout of operand in memory -+ typename Layout_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Whether beta is zero -+ bool IsBetaZero_ > -+class MmaTensorOpPureFragmentIterator; -+ -+ -+// Partial specialization for col-major accumulator tile -+// And Element type is the same as Accumulator Element type -+ -+template < -+ /// Shape of warp tile to load (concept: MatrixShape) -+ typename Shape_, -+ /// Shape of the warp accumulation tile (concept: MatrixShape) -+ typename AccumulatorShape_, -+ /// KBlocks columns to compute residual -+ int KBlocksColumn_, -+ /// Element type -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_> -+class MmaTensorOpPureFragmentIterator { -+ public: -+ -+ /// Shape of warp tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Shape of the warp accumulation tile (concept: MatrixShape) -+ using AccumulatorShape = AccumulatorShape_; -+ -+ /// KBlocks columns to compute residual -+ static int const kKBlockColumn = KBlocksColumn_; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::ColumnMajor; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Whether beta is zero -+ static bool const IsBetaZero = true; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ struct Policy { -+ static_assert( -+ !(Shape::kRow % InstructionShape::kM) && -+ !(Shape::kColumn % InstructionShape::kN), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ static_assert( -+ !(AccumulatorShape::kRow % Shape::kRow) && -+ !(AccumulatorShape::kColumn % Shape::kColumn), -+ "Shape of Warp Accumulator must be divisible by warp shape."); -+ static_assert( -+ !(kKBlockColumn % Shape::kColumn), -+ "KBlock size must be divisible by warp shape."); -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = AccumulatorShape::kCount / Shape::kCount; -+ }; -+ -+private: -+ -+ static int const kElementsPerAccess = InstructionShape::kM * InstructionShape::kN / kThreads; -+ -+ /// Number of mma operations performed by a warp -+ using MmaIterations = MatrixShape; -+ /// Number of mma operations performed by the entire accumulator -+ using AccumulatorIterations = MatrixShape; -+ -+ /// Number of K iterations -+ static int const kKBlockIterations = (AccumulatorShape::kColumn + kKBlockColumn - 1) / kKBlockColumn; -+ static int const kResidualColumn = AccumulatorShape::kColumn - (kKBlockIterations - 1) * kKBlockColumn; -+ static int const kKBlockColumnIterations = kKBlockColumn / Shape::kColumn -+ * (AccumulatorShape::kRow / Shape::kRow); -+ static int const kResidualIndex = kResidualColumn / Shape::kColumn -+ * (AccumulatorShape::kRow / Shape::kRow); -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = Array; -+ -+ /// Accumulator Fragment object -+ using AccumulatorFragment = Array; -+ -+ -+private: -+ -+ /// Internal access type -+ using AccessType = Array; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Accumulator tile -+ AccessType const *accumulators_; -+ -+ /// Internal index -+ int index_; -+ -+ /// Used to access residual tile first -+ bool is_residual_tile_; -+ -+public: -+ /// Constructs an iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpPureFragmentIterator(AccumulatorFragment const &accum) -+ : accumulators_(reinterpret_cast(&accum)), -+ index_(0), is_residual_tile_(true) {} -+ -+ /// Add offset -+ CUTLASS_HOST_DEVICE -+ void add_offset(int index_offset) { -+ index_ += index_offset; -+ if(is_residual_tile_ && index_ >= kKBlockColumnIterations) { -+ index_ = index_ - kKBlockColumnIterations + kResidualIndex; -+ is_residual_tile_ = false; -+ } -+ } -+ -+ /// Increments -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpPureFragmentIterator &operator++() { -+ add_offset(1); -+ return *this; -+ } -+ -+ /// Decrements -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpPureFragmentIterator &operator--() { -+ add_offset(-1); -+ return *this; -+ } -+ -+ /// Loads a fragment from the referenced part of the accumulator tile -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ AccessType src_fragment; -+ src_fragment.clear(); -+ -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ int index_m = (index_ * MmaIterations::kRow) % AccumulatorIterations::kRow; -+ int index_n = (index_ * MmaIterations::kRow) / AccumulatorIterations::kRow -+ * MmaIterations::kColumn; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; n++) { -+ for (int m = 0; m < MmaIterations::kRow; m++) { -+ int accumulator_access_offset = -+ (n + index_n) * AccumulatorIterations::kRow + m + index_m; -+ -+ frag_ptr[n * MmaIterations::kRow + m].clear(); -+ if(!(is_residual_tile_ && index_ >= kResidualIndex)) -+ frag_ptr[n * MmaIterations::kRow + m] = accumulators_[accumulator_access_offset]; -+ // frag_ptr[n * MmaIterations::kRow + m] = output_op(accumulators_[accumulator_access_offset], src_fragment); -+ } -+ } -+ } -+ -+}; -+ -+// Partial specialization for row-major accumulator tile -+ -+template < -+ /// Shape of warp tile to load (concept: MatrixShape) -+ typename Shape_, -+ /// Shape of the warp accumulation tile (concept: MatrixShape) -+ typename AccumulatorShape_, -+ /// KBlocks columns to compute residual -+ int KBlocksColumn_, -+ /// Accumulator Element type -+ typename ElementAccumulator_, -+ /// Element type -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_> -+class MmaTensorOpPureFragmentIterator { -+ public: -+ -+ /// Shape of warp tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Shape of the warp accumulation tile (concept: MatrixShape) -+ using AccumulatorShape = AccumulatorShape_; -+ -+ /// KBlocks columns to compute residual -+ static int const kKBlockColumn = KBlocksColumn_; -+ -+ /// Accumulator Element type -+ using ElementAccumulator = ElementAccumulator_; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::RowMajor; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Whether beta is zero -+ static bool const IsBetaZero = true; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ struct Policy { -+ static_assert( -+ !(Shape::kRow % InstructionShape::kM) && -+ !(Shape::kColumn % InstructionShape::kN), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ static_assert( -+ !(AccumulatorShape::kRow % Shape::kRow) && -+ !(AccumulatorShape::kColumn % Shape::kColumn), -+ "Shape of Warp Accumulator must be divisible by warp shape."); -+ static_assert( -+ !(kKBlockColumn % Shape::kColumn), -+ "KBlock size must be divisible by warp shape."); -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = AccumulatorShape::kCount / Shape::kCount; -+ }; -+ -+private: -+ -+ static int const kElementsPerAccess = InstructionShape::kM * InstructionShape::kN / kThreads; -+ -+ /// Number of mma operations performed by a warp -+ using MmaIterations = MatrixShape; -+ /// Number of mma operations performed by the entire accumulator -+ using AccumulatorIterations = MatrixShape; -+ -+ /// Number of K iterations -+ static int const kKBlockIterations = (AccumulatorShape::kColumn + kKBlockColumn - 1) / kKBlockColumn; -+ static int const kResidualColumn = AccumulatorShape::kColumn - (kKBlockIterations - 1) * kKBlockColumn; -+ static int const kKBlockColumnIterations = kKBlockColumn / Shape::kColumn -+ * (AccumulatorShape::kRow / Shape::kRow); -+ static int const kResidualIndex = kResidualColumn / Shape::kColumn -+ * (AccumulatorShape::kRow / Shape::kRow); -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = Array; -+ -+ /// Accumulator Fragment object -+ using AccumulatorFragment = Array; -+ -+ -+private: -+ -+ /// Internal access type -+ using AccessType = Array; -+ using FragmentAccessType = Array; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Accumulator tile -+ AccessType const *accumulators_; -+ -+ /// Internal index -+ int index_; -+ -+ /// Used to access residual tile first -+ bool is_residual_tile_; -+ -+public: -+ /// Constructs an iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpPureFragmentIterator(AccumulatorFragment const &accum) -+ : accumulators_(reinterpret_cast(&accum)), -+ index_(0), is_residual_tile_(true) {} -+ -+ /// Add offset -+ CUTLASS_HOST_DEVICE -+ void add_offset(int index_offset) { -+ index_ += index_offset; -+ if(is_residual_tile_ && index_ >= kKBlockColumnIterations) { -+ index_ = index_ - kKBlockColumnIterations + kResidualIndex; -+ is_residual_tile_ = false; -+ } -+ } -+ -+ /// Increments -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpPureFragmentIterator &operator++() { -+ add_offset(1); -+ return *this; -+ } -+ -+ /// Decrements -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpPureFragmentIterator &operator--() { -+ add_offset(-1); -+ return *this; -+ } -+ -+ /// Loads a fragment from the referenced part of the accumulator tile -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ -+ FragmentAccessType src_fragment; -+ src_fragment.clear(); -+ -+ FragmentAccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ int index_m = (index_ * MmaIterations::kRow) % AccumulatorIterations::kRow; -+ int index_n = (index_ * MmaIterations::kRow) / AccumulatorIterations::kRow -+ * MmaIterations::kColumn; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < MmaIterations::kRow; m++) { -+ for (int n = 0; n < MmaIterations::kColumn; n++) { -+ int accumulator_access_offset = -+ (m + index_m) * AccumulatorIterations::kColumn + n + index_n; -+ -+ frag_ptr[m * MmaIterations::kColumn + n].clear(); -+ if(!(is_residual_tile_ && index_ >= kResidualIndex)) -+ frag_ptr[m * MmaIterations::kColumn + n] = (accumulators_[accumulator_access_offset]); -+ } -+ } -+ } -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/leaky_bias.h b/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/leaky_bias.h -new file mode 100644 -index 0000000..5b46a5a ---- /dev/null -+++ b/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/leaky_bias.h -@@ -0,0 +1,292 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+#include -+ -+template -+__device__ -+T add(T const & a, T const &b){ -+ return (a + b); -+} -+ -+template <> -+__device__ -+half2 add(half2 const & a, half2 const &b){ -+ return (__hadd2(a,b)); -+} -+ -+template -+struct RELU{ -+ __device__ -+ T operator()(T const & a){ -+ return a > T(0) ? a : T(0); -+ } -+ __device__ -+ half2 operator()(half2 const & a){ -+ float2 a_fp32x2 = __half22float2(a); -+ a_fp32x2.x = a_fp32x2.x > 0.f ? a_fp32x2.x : 0.f; -+ a_fp32x2.y = a_fp32x2.y > 0.f ? a_fp32x2.y : 0.f; -+ if(a_fp32x2.x < 0.f || a_fp32x2.y < 0.f) -+ printf(" %f %f\n", a_fp32x2.x ,a_fp32x2.y); -+ return __float22half2_rn(a_fp32x2); -+ } -+}; -+ -+template -+struct LEAKY_RELU{ -+ __device__ -+ T operator()(T const & a, T const & scale = half(1)){ -+ return a > T(0) ? a : scale * a; -+ } -+ __device__ -+ half2 operator()(half2 const & a, half const & scale = half(1)){ -+ half2 zero = __half2half2(half(0)); -+ half2 gt_zero = __hge2(a, zero); -+ half2 le_zero = __hle2(a, zero); -+ -+ -+ half2 scale_f16x2 = __half2half2(scale); -+ half2 mask_scale_f16x2 = __hfma2(le_zero, scale_f16x2, gt_zero); -+ return __hmul2(a, mask_scale_f16x2); -+ } -+}; -+ -+template -+__global__ void leaky_and_activation(half* inout, half* bias, half scale, bool mat_bias){ -+ -+ constexpr bool N_MOD_2 = N & 1 ? false : true; -+ -+ using Access_tp = typename std::conditional::type; -+ -+ constexpr int Access_elements = sizeof(Access_tp) / sizeof(half); -+ -+ constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements); -+ -+ LEAKY_RELU Act; -+ Access_tp src_v[iter]; -+ Access_tp bias_v[iter]; -+ -+ int batch_id = blockIdx.y; -+ int batch_offset = batch_id * gridDim.x * N; -+ -+ for(int i = 0; i < iter; i++){ -+ int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements; -+ if (idx < N){ -+ src_v[i] = *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset); -+ if (mat_bias) -+ bias_v[i] = *reinterpret_cast(bias + blockIdx.x * N + idx + batch_offset); -+ else -+ bias_v[i] = *reinterpret_cast(bias + idx + batch_id * N); -+ *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset) = Act(add(src_v[i],bias_v[i]),scale); -+ } -+ -+ } -+} -+ -+ -+ -+template -+__global__ void leaky_and_activation(half* inout, half scale){ -+ -+ constexpr bool N_MOD_2 = N & 1 ? false : true; -+ -+ using Access_tp = typename std::conditional::type; -+ -+ constexpr int Access_elements = sizeof(Access_tp) / sizeof(half); -+ -+ constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements); -+ -+ int batch_id = blockIdx.y; -+ int batch_offset = batch_id * gridDim.x * N; -+ -+ LEAKY_RELU Act; -+ Access_tp src_v[iter]; -+ -+ for(int i = 0; i < iter; i++){ -+ int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements; -+ if (idx < N){ -+ src_v[i] = *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset); -+ *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset) = Act(src_v[i], scale); -+ } -+ -+ } -+} -+ -+ -+ -+template -+void leaky_and_activation(half* inout, half* bias, int m, int b, half scale, bool mat_bias){ -+ -+ dim3 grid(m, b); -+ if (bias == nullptr) -+ leaky_and_activation<<>>(inout, scale); -+ else -+ leaky_and_activation<<>>(inout, bias, scale, mat_bias); -+} -+ -+template -+__global__ void relu_and_activation(half* inout, half* bias, bool mat_bias){ -+ -+ constexpr bool N_MOD_2 = N & 1 ? false : true; -+ -+ using Access_tp = typename std::conditional::type; -+ -+ constexpr int Access_elements = sizeof(Access_tp) / sizeof(half); -+ -+ constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements); -+ -+ RELU Act; -+ Access_tp src_v[iter]; -+ Access_tp bias_v[iter]; -+ -+ int batch_id = blockIdx.y; -+ int batch_offset = batch_id * gridDim.x * N; -+ -+ for(int i = 0; i < iter; i++){ -+ int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements; -+ if (idx < N){ -+ src_v[i] = *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset); -+ if (mat_bias) -+ bias_v[i] = *reinterpret_cast(bias + blockIdx.x * N + idx + batch_offset); -+ else -+ bias_v[i] = *reinterpret_cast(bias + idx + batch_id * N); -+ *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset) = Act(add(src_v[i],bias_v[i])); -+ } -+ -+ } -+} -+ -+ -+ -+template -+__global__ void relu_and_activation(half* inout){ -+ -+ constexpr bool N_MOD_2 = N & 1 ? false : true; -+ -+ using Access_tp = typename std::conditional::type; -+ -+ constexpr int Access_elements = sizeof(Access_tp) / sizeof(half); -+ -+ constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements); -+ -+ int batch_id = blockIdx.y; -+ int batch_offset = batch_id * gridDim.x * N; -+ -+ RELU Act; -+ Access_tp src_v[iter]; -+ -+ for(int i = 0; i < iter; i++){ -+ int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements; -+ if (idx < N){ -+ src_v[i] = *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset); -+ *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset) = Act(src_v[i]); -+ } -+ -+ } -+} -+ -+ -+ -+template -+void relu_and_activation(half* inout, half* bias, int m, int b, bool mat_bias){ -+ dim3 grid(m, b); -+ if (bias == nullptr) -+ relu_and_activation<<>>(inout); -+ else -+ relu_and_activation<<>>(inout, bias, mat_bias); -+} -+ -+ -+template -+__global__ void identity_and_activation(half* inout, half* bias, bool mat_bias){ -+ -+ constexpr bool N_MOD_2 = N & 1 ? false : true; -+ -+ using Access_tp = typename std::conditional::type; -+ -+ constexpr int Access_elements = sizeof(Access_tp) / sizeof(half); -+ -+ constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements); -+ -+ int batch_id = blockIdx.y; -+ int batch_offset = batch_id * gridDim.x * N; -+ -+ Access_tp src_v[iter]; -+ Access_tp bias_v[iter]; -+ -+ for(int i = 0; i < iter; i++){ -+ int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements; -+ if (idx < N){ -+ src_v[i] = *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset); -+ if (mat_bias) -+ bias_v[i] = *reinterpret_cast(bias + blockIdx.x * N + idx + batch_offset); -+ else -+ bias_v[i] = *reinterpret_cast(bias + idx + batch_id * N); -+ *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset) = (add(src_v[i],bias_v[i])); -+ } -+ -+ } -+} -+ -+template -+__global__ void identity_and_activation(half* inout){ -+ -+ constexpr bool N_MOD_2 = N & 1 ? false : true; -+ -+ using Access_tp = typename std::conditional::type; -+ -+ constexpr int Access_elements = sizeof(Access_tp) / sizeof(half); -+ -+ constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements); -+ -+ int batch_id = blockIdx.y; -+ int batch_offset = batch_id * gridDim.x * N; -+ Access_tp src_v[iter]; -+ -+ for(int i = 0; i < iter; i++){ -+ int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements; -+ if (idx < N){ -+ src_v[i] = *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset); -+ *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset) = (src_v[i]); -+ } -+ -+ } -+} -+ -+template -+void identity_and_activation(half* inout, half* bias, int m, int b, bool mat_bias){ -+ dim3 grid(m, b); -+ if (bias == nullptr) -+ identity_and_activation<<>>(inout); -+ else -+ identity_and_activation<<>>(inout, bias, mat_bias); -+} -diff --git a/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/utils.h b/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/utils.h -new file mode 100644 -index 0000000..9e1a732 ---- /dev/null -+++ b/3rdparty/cutlass/examples/44_multi_gemm_ir_and_codegen/utils.h -@@ -0,0 +1,94 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+#define TI(tag) \ -+ cudaEvent_t _event_start_ ##tag; \ -+ cudaEvent_t _event_end_ ##tag; \ -+ float _event_time_ ##tag; \ -+ cudaEventCreate(& _event_start_ ##tag); \ -+ cudaEventCreate(& _event_end_ ##tag); \ -+ cudaEventRecord(_event_start_ ##tag); -+ -+#define TO(tag, str, times) \ -+ cudaEventRecord(_event_end_ ##tag); \ -+ cudaEventSynchronize(_event_end_ ##tag); \ -+ cudaEventElapsedTime(&_event_time_ ##tag, _event_start_ ##tag, _event_end_ ##tag); \ -+ float _event_time_once_ ##tag = _event_time_ ##tag / times; \ -+ printf("%20s:\t %10.3fus\t", str, _event_time_once_ ##tag * 1000); \ -+ cudaDeviceSynchronize(); \ -+ printf("%20s string: %s\n",str, cudaGetErrorString(cudaGetLastError())); -+ -+template -+struct memory_unit{ -+ T* host_ptr; -+ T* device_ptr; -+ int size_bytes; -+ int elements; -+ void h2d(){ -+ cudaMemcpy(device_ptr, host_ptr, size_bytes, cudaMemcpyHostToDevice); -+ } -+ void d2h(){ -+ cudaMemcpy(host_ptr, device_ptr, size_bytes, cudaMemcpyDeviceToHost); -+ } -+ void free_all(){ -+ free(host_ptr); -+ cudaFree(device_ptr); -+ } -+ memory_unit(int elements_): size_bytes(elements_ * sizeof(T)), elements(elements_){ -+ host_ptr = (T*) malloc(elements_ * sizeof(T)); -+ cudaMalloc((void**)&device_ptr, elements_ * sizeof(T)); -+ } -+ void init(int abs_range = 1){ -+ for(int i = 0; i < elements; i++){ -+ host_ptr[i] = T(rand() % 100 / float(100) * 2 * abs_range - abs_range); -+ } -+ h2d(); -+ } -+}; -+ -+template -+int check_result(T * a, T * b, int N){ -+ int cnt = 0; -+ for(int i = 0; i < N; i ++){ -+ float std = float(a[i]); -+ float my = float(b[i]); -+ -+ if(abs(std - my) / abs(std) > 1e-2) -+ { -+ // printf("my: %f , std: %f\n", my, std); -+ cnt++; -+ } -+ -+ } -+ printf("total err: %d / %d\n", cnt, N); -+ return cnt; -+} -diff --git a/3rdparty/cutlass/examples/45_dual_gemm/device/dual_gemm.h b/3rdparty/cutlass/examples/45_dual_gemm/device/dual_gemm.h -new file mode 100644 -index 0000000..491888b ---- /dev/null -+++ b/3rdparty/cutlass/examples/45_dual_gemm/device/dual_gemm.h -@@ -0,0 +1,457 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Performs a dual gemm in one fused kernel: -+``` -+D0 = epilogue0(X @ B0, C0) -+D1 = epilogue1(X @ B1, C1) -+D2 = element_wise(D0, D1) -+``` -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+ -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+#include "cutlass/gemm/threadblock/default_mma.h" -+#include "cutlass/epilogue/thread/linear_combination_relu.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -+ -+#include "../kernel/dual_gemm.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Epilogue output operator -+ typename EpilogueOutputOp0_, -+ typename EpilogueOutputOp1_, -+ typename EpilogueOutputOp2_, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, -+ /// Number of stages used in the pipelined mainloop -+ int Stages = -+ DefaultGemmConfiguration::kStages, -+ bool StoreD0 = true, -+ bool StoreD1 = true, -+ /// If true, kernel supports split-K with serial reduction -+ bool SplitKSerial = false, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = -+ DefaultGemmConfiguration::kAlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = -+ DefaultGemmConfiguration::kAlignmentB, -+ /// Operation performed by GEMM -+ typename Operator_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::Operator> -+class DualGemm { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp0 = EpilogueOutputOp0_; -+ using EpilogueOutputOp1 = EpilogueOutputOp1_; -+ using EpilogueOutputOp2 = EpilogueOutputOp2_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static int const kAlignmentC = EpilogueOutputOp1::kCount; -+ static bool const kSplitKSerial = SplitKSerial; -+ static bool constexpr kStoreD0 = StoreD0; -+ static bool constexpr kStoreD1 = StoreD1; -+ static ComplexTransform const kTransformA = ComplexTransform::kNone; -+ static ComplexTransform const kTransformB = ComplexTransform::kNone; -+ -+ using LayoutScaleBias = layout::RowMajor; -+ /// Define the kernel -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ static_assert(ArchTag::kMinComputeCapability >= 80, "Only multistage is implemented"); -+ static_assert(kStages >= 3, "Only multistage is implemented"); -+ using Mma = typename cutlass::gemm::threadblock::DefaultMma< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, -+ ThreadblockShape, WarpShape, -+ InstructionShape, Stages, Operator>::ThreadblockMma; -+ using DualMma = threadblock::DualMmaMultistage< -+ typename Mma::Shape, -+ typename Mma::IteratorA, -+ typename Mma::SmemIteratorA, -+ Mma::kCacheOpA, -+ typename Mma::IteratorB, -+ typename Mma::SmemIteratorB, -+ Mma::kCacheOpB, -+ typename Mma::ElementC, -+ typename Mma::LayoutC, -+ typename Mma::Policy, -+ Mma::kStages, -+ SharedMemoryClearOption::kNone -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue0 = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, typename DualMma::Operator, kPartitionsK, EpilogueOutputOp0, -+ EpilogueOutputOp0::kCount>::Epilogue; -+ using Epilogue1 = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, typename DualMma::Operator, kPartitionsK, EpilogueOutputOp1, -+ EpilogueOutputOp1::kCount>::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using DualGemmKernel = kernel::DualGemm< -+ DualMma, -+ Epilogue0, Epilogue1, EpilogueOutputOp2, -+ ThreadblockSwizzle, kSplitKSerial, -+ kStoreD0, kStoreD1>; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmCoord problem_size; -+ TensorRef ref_A0; -+ TensorRef ref_B0; -+ TensorRef ref_C0; -+ TensorRef ref_D0; -+ TensorRef ref_B1; -+ TensorRef ref_C1; -+ TensorRef ref_D1; -+ TensorRef ref_D2; -+ typename EpilogueOutputOp0::Params epilogue0; -+ typename EpilogueOutputOp1::Params epilogue1; -+ typename EpilogueOutputOp2::Params epilogue2; -+ int split_k_slices; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments(): problem_size(0, 0, 0), split_k_slices(1) { -+ -+ } -+ -+ /// Constructs an Arguments structure -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ GemmCoord problem_size_, -+ TensorRef ref_A0_, -+ TensorRef ref_B0_, -+ TensorRef ref_C0_, -+ TensorRef ref_D0_, -+ TensorRef ref_B1_, -+ TensorRef ref_C1_, -+ TensorRef ref_D1_, -+ TensorRef ref_D2_, -+ typename EpilogueOutputOp0::Params epilogue0_ = -+ typename EpilogueOutputOp0::Params(), -+ typename EpilogueOutputOp1::Params epilogue1_ = -+ typename EpilogueOutputOp1::Params(), -+ typename EpilogueOutputOp2::Params epilogue2_ = -+ typename EpilogueOutputOp2::Params(), -+ int split_k_slices_ = 1 -+ ): -+ problem_size(problem_size_), -+ ref_A0(ref_A0_), -+ ref_B0(ref_B0_), -+ ref_C0(ref_C0_), -+ ref_D0(ref_D0_), -+ ref_B1(ref_B1_), -+ ref_C1(ref_C1_), -+ ref_D1(ref_D1_), -+ ref_D2(ref_D2_), -+ epilogue0(epilogue0_), -+ epilogue1(epilogue1_), -+ epilogue2(epilogue2_), -+ split_k_slices(split_k_slices_) { -+ -+ } -+ }; -+ -+private: -+ -+ /// Kernel parameters object -+ typename DualGemmKernel::Params params_; -+ -+public: -+ -+ /// Constructs the GEMM. -+ DualGemm() = default; -+ -+ /// Determines whether the GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ if (!kSplitKSerial && args.split_k_slices > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ if (kStoreD0 != (args.ref_D0.data() != nullptr)) { -+ return Status::kErrorInternal; -+ } -+ if (kStoreD1 != (args.ref_D1.data() != nullptr)) { -+ return Status::kErrorInternal; -+ } -+ -+ Status status = DualGemmKernel::can_implement( -+ args.problem_size, -+ args.ref_A0.non_const_ref(), -+ args.ref_B0.non_const_ref(), -+ args.ref_C0.non_const_ref(), -+ args.ref_D0, -+ args.ref_B1.non_const_ref(), -+ args.ref_C1.non_const_ref(), -+ args.ref_D1, -+ args.ref_D2 -+ ); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ size_t bytes = 0; -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.split_k_slices); -+ -+ if (kSplitKSerial && args.split_k_slices > 1) { -+ -+ -+ bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); -+ } -+ -+ return bytes; -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.split_k_slices); -+ -+ if (kSplitKSerial) { -+ if (args.split_k_slices > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ size_t bytes = get_workspace_size(args); -+ -+ cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ } -+ else { -+ -+ if (args.split_k_slices > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ // Initialize the Params structure -+ params_ = typename DualGemmKernel::Params{ -+ args.problem_size, -+ grid_shape, -+ args.ref_A0.non_const_ref(), -+ args.ref_B0.non_const_ref(), -+ args.ref_C0.non_const_ref(), -+ args.ref_D0, -+ args.ref_B1.non_const_ref(), -+ args.ref_C1.non_const_ref(), -+ args.ref_D1, -+ args.ref_D2, -+ args.epilogue0, -+ args.epilogue1, -+ args.epilogue2, -+ reinterpret_cast(workspace), -+ }; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ if (kSplitKSerial && args.split_k_slices > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ } -+ -+ params_.ref_A0.reset(args.ref_A0.non_const_ref().data()); -+ params_.ref_B0.reset(args.ref_B0.non_const_ref().data()); -+ params_.ref_C0.reset(args.ref_C0.non_const_ref().data()); -+ params_.ref_D0.reset(args.ref_D0.data()); -+ params_.ref_B1.reset(args.ref_B1.non_const_ref().data()); -+ params_.ref_C1.reset(args.ref_C1.non_const_ref().data()); -+ params_.ref_D1.reset(args.ref_D1.data()); -+ params_.ref_D2.reset(args.ref_D2.data()); -+ params_.output_op_0 = args.epilogue0; -+ params_.output_op_1 = args.epilogue1; -+ params_.output_op_2 = args.epilogue2; -+ params_.semaphore = reinterpret_cast(workspace); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); -+ dim3 block(DualGemmKernel::kThreadCount, 1, 1); -+ -+ cudaError_t result; -+ -+ int smem_size = int(sizeof(typename DualGemmKernel::SharedStorage)); -+ if (smem_size >= (48 << 10)) { -+ result = cudaFuncSetAttribute(Kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ cutlass::Kernel<<>>(params_); -+ -+ result = cudaGetLastError(); -+ -+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/45_dual_gemm/dual_gemm.cu b/3rdparty/cutlass/examples/45_dual_gemm/dual_gemm.cu -new file mode 100644 -index 0000000..15974e0 ---- /dev/null -+++ b/3rdparty/cutlass/examples/45_dual_gemm/dual_gemm.cu -@@ -0,0 +1,262 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief CUTLASS Dual-GEMM Example. -+ -+ Fused kernel that outputs `D0` and `D1`. -+ We assume that B0/B1 have the same shape/layout -+ -+``` -+D0 = epilogue0(X @ B0, C0) -+D1 = epilogue1(X @ B1, C1) -+D2 = element_wise(D0, D1) -+``` -+ D0 and D1 will be optionally stored in gmem (`kStoreD0` / `kStoreD1`) -+*/ -+ -+// #define IS_PROFILING -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "device/dual_gemm.h" -+#include "thread/left_silu_and_mul.h" -+#include "dual_gemm_run.h" -+#include "test_run.h" -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+cutlass::gemm::GemmCoord problem_size(4096, 4096, 8192); -+ -+constexpr int kStages = 3; -+constexpr bool kSplitKSerial = false; -+constexpr bool kUseBias = true; -+ -+ -+#if 0 -+using ElementOperandA = cutlass::bfloat16_t; -+using ElementOperandB = cutlass::bfloat16_t; -+using ElementOutput = cutlass::bfloat16_t; -+using ElementAccumulator = float; -+using ElementCompute = float; -+#else -+using ElementOperandA = cutlass::half_t; -+using ElementOperandB = cutlass::half_t; -+using ElementOutput = cutlass::half_t; -+using ElementAccumulator = cutlass::half_t; -+using ElementCompute = cutlass::half_t; -+#endif -+ -+constexpr auto kScaleType = kUseBias ? cutlass::epilogue::thread::ScaleType::NoBetaScaling : ( -+ // No bias -+ kSplitKSerial ? cutlass::epilogue::thread::ScaleType::Default : cutlass::epilogue::thread::ScaleType::Nothing -+); -+using EpilogueOutputOp0 = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ kScaleType -+>; -+using EpilogueOutputOp1 = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute, -+ kScaleType -+>; -+using EpilogueOutputOp2 = cutlass::epilogue::thread::LeftSiLUAndMul< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementOutput, -+ ElementCompute -+>; -+ -+const ElementCompute alpha0 = ElementCompute(1); -+const ElementCompute beta0 = ElementCompute(kUseBias ? 1 : 0); -+const ElementCompute alpha1 = ElementCompute(1); -+const ElementCompute beta1 = ElementCompute(kUseBias ? 1 : 0); -+ -+bool run_nonfused_gemm_f16_sm80() { -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ using Gemm0 = cutlass::gemm::device::Gemm< -+ ElementOperandA, -+ cutlass::layout::RowMajor, -+ ElementOperandB, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp0, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ kStages, -+ 8, -+ 8, -+ kSplitKSerial -+ >; -+ using Gemm1 = cutlass::gemm::device::Gemm< -+ ElementOperandA, -+ cutlass::layout::RowMajor, -+ ElementOperandB, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp1, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ kStages, -+ 8, -+ 8, -+ kSplitKSerial -+ >; -+ -+ NonFusedDualGemmRun nonFusedGemm; -+ -+ std::cout << "Running Non-fused GEMMs FP16 TN GEMMs...\n"; -+ bool pass = nonFusedGemm.run(problem_size, alpha0, beta0, alpha1, beta1); -+ if(pass) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return pass; -+} -+ -+template -+struct LeftSiLUAndMul { -+ struct Params{}; -+ CUTLASS_HOST_DEVICE LeftSiLUAndMul(Params p) {} -+ -+ CUTLASS_HOST_DEVICE void set_k_partition(int, int) {} -+ -+ CUTLASS_HOST_DEVICE T operator() ( -+ T const &lhs, -+ T const &rhs) const { -+ cutlass::epilogue::thread::SiLu silu; -+ cutlass::multiplies mul; -+ auto silu_lhs = silu(lhs); -+ return mul(silu_lhs, rhs); -+ } -+ -+ template -+ CUTLASS_HOST_DEVICE cutlass::Array operator() ( -+ cutlass::Array const &lhs, -+ cutlass::Array const &rhs) const { -+ cutlass::epilogue::thread::SiLu silu; -+ cutlass::multiplies mul; -+ auto silu_lhs = silu(lhs); -+ return mul(silu_lhs, rhs); -+ } -+}; -+ -+bool run_fused_gemm_f16_sm80_shmem() { -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ // Optionally, we might not need intermediate GEMM outputs -+ constexpr bool kStoreD0 = true; -+ constexpr bool kStoreD1 = true; -+ -+ using DualGemm = cutlass::gemm::device::DualGemm< -+ ElementOperandA, -+ cutlass::layout::RowMajor, -+ ElementOperandB, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp0, -+ EpilogueOutputOp1, -+ EpilogueOutputOp2, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ kStages, -+ kStoreD0, -+ kStoreD1, -+ kSplitKSerial -+ >; -+ -+ DualFusedGemmRun fusedGemm; -+ -+ std::cout << "Running Fused FP16 TN GEMMs + Epilogue2...\n"; -+ bool passed = fusedGemm.run(problem_size, alpha0, beta0, alpha1, beta1); -+ if(passed) -+ std::cout << "Pass\n"; -+ else -+ std::cout << "Fail\n"; -+ -+ return passed; -+ -+} -+ -+int main() { -+ -+ std::vectorfuncs = { -+ &run_nonfused_gemm_f16_sm80, -+ &run_fused_gemm_f16_sm80_shmem -+ }; -+ -+ std::string test_name = "dual-gemm f16 bias=" + std::to_string(kUseBias) + " split_k_serial=" + std::to_string(kSplitKSerial); -+ return testRun(80, funcs, test_name); -+} -+ -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/45_dual_gemm/dual_gemm_run.h b/3rdparty/cutlass/examples/45_dual_gemm/dual_gemm_run.h -new file mode 100644 -index 0000000..63ca2ac ---- /dev/null -+++ b/3rdparty/cutlass/examples/45_dual_gemm/dual_gemm_run.h -@@ -0,0 +1,829 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/device/tensor_relu.h" -+ -+#include "helper.h" -+ -+#define CHECK_GT(val1, val2) \ -+ if((val1) <= (val2)) \ -+ std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n"; -+#define CHECK_TRUE(val) \ -+ if(!(val)) \ -+ std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n"; -+ -+template < -+ typename OutputOp, -+ typename Element, -+ typename Layout> -+struct TensorEpilogueForEachFunc { -+ /// View type -+ using TensorView = cutlass::TensorView; -+ -+ /// Coordinate in tensor's index space -+ using TensorCoord = typename TensorView::TensorCoord; -+ -+ /// Parameters structure -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ TensorView view_x0; -+ TensorView view_x1; -+ TensorView view_y; -+ OutputOp output_op; -+ -+ -+ // -+ // Methods -+ // -+ -+ Params( -+ TensorView view_x0_ = TensorView(), -+ TensorView view_x1_ = TensorView(), -+ TensorView view_y_ = TensorView(), -+ OutputOp output_op_ = OutputOp(typename OutputOp::Params{}) -+ ): -+ view_x0(view_x0_), view_x1(view_x1_), view_y(view_y_), output_op(output_op_) { -+ } -+ }; -+ -+ Params params; -+ -+ CUTLASS_DEVICE -+ TensorEpilogueForEachFunc(Params const ¶ms): params(params) { -+ -+ } -+ -+ CUTLASS_DEVICE -+ void operator()(TensorCoord const &coord) { -+ Element const & x0 = params.view_x0.at(coord); -+ Element const & x1 = params.view_x1.at(coord); -+ Element& y = params.view_y.at(coord); -+ y = params.output_op(x0, x1); -+ } -+}; -+ -+template < -+ typename OutputOp, -+ typename Element, -+ typename Layout> -+void TensorEpilogueForEach( -+ cutlass::TensorView x0, -+ cutlass::TensorView x1, -+ cutlass::TensorView y) { -+ -+ using Func = TensorEpilogueForEachFunc; -+ using Params = typename Func::Params; -+ -+ cutlass::reference::device::TensorForEach( -+ y.extent(), -+ Params(x0, x1, y) -+ ); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct NonFusedDualGemmRun -+{ -+ -+ using Gemm0 = Gemm0_; -+ using Gemm1 = Gemm1_; -+ using ElementAccumulator = typename Gemm0::ElementAccumulator; -+ using ElementCompute = typename Gemm0::GemmKernel::Epilogue::OutputOp::ElementCompute; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ cutlass::Distribution::Kind init_Bias; -+ uint64_t seed; -+ -+ // -+ // Methods -+ // -+ -+ NonFusedDualGemmRun( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, 2, -2, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else if (dist_kind == cutlass::Distribution::AllZeros) { -+ cutlass::reference::host::TensorFill(view, Element(0)); -+ } -+ else if (dist_kind == cutlass::Distribution::AllOnes) { -+ cutlass::reference::host::TensorFill(view, Element(1)); -+ } -+ else { -+ std::cerr << "Not implemented\n"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ -+ -+ -+ /// Executes one test -+ bool run( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha0 = ElementCompute(1), -+ ElementCompute beta0 = ElementCompute(0), -+ ElementCompute alpha1 = ElementCompute(1), -+ ElementCompute beta1 = ElementCompute(0), -+ bool relu = false, -+ int warm_ups = 1, -+ int runs = 100) { -+ -+ // -+ // Allocate the GEMM workspace -+ // -+ -+ cutlass::HostTensor< -+ typename Gemm0::ElementA, -+ typename Gemm0::LayoutA> tensor_A0(problem_size.mk()); -+ -+ cutlass::HostTensor< -+ typename Gemm0::ElementB, -+ typename Gemm0::LayoutB> tensor_B0(problem_size.kn()); -+ -+ cutlass::HostTensor< -+ typename Gemm0::ElementC, -+ typename Gemm0::LayoutC> tensor_C0(problem_size.mn()); -+ -+ cutlass::HostTensor< -+ typename Gemm1::ElementC, -+ typename Gemm0::LayoutC> tensor_Bias0({1, problem_size.n()}); -+ -+ cutlass::HostTensor< -+ typename Gemm0::ElementC, -+ typename Gemm0::LayoutC> tensor_D0(problem_size.mn()); -+ -+ cutlass::HostTensor< -+ typename Gemm0::ElementC, -+ typename Gemm0::LayoutC> reference_D0(problem_size.mn()); -+ -+ cutlass::HostTensor< -+ typename Gemm1::ElementB, -+ typename Gemm1::LayoutB> tensor_B1(problem_size.kn()); -+ -+ cutlass::HostTensor< -+ typename Gemm1::ElementC, -+ typename Gemm1::LayoutC> tensor_C1(problem_size.mn()); -+ -+ cutlass::HostTensor< -+ typename Gemm1::ElementC, -+ typename Gemm1::LayoutC> tensor_Bias1({1, problem_size.n()}); -+ -+ cutlass::HostTensor< -+ typename Gemm1::ElementC, -+ typename Gemm1::LayoutC> tensor_D1(problem_size.mn()); -+ -+ cutlass::HostTensor< -+ typename Gemm1::ElementC, -+ typename Gemm1::LayoutC> reference_D1(problem_size.mn()); -+ -+ -+ CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); -+ CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018)); -+ CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); -+ CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2014)); -+ CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016)); -+ CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); -+ CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2013)); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D0.host_view()); -+ cutlass::reference::host::TensorFill( -+ tensor_D1.host_view()); -+ cutlass::reference::host::TensorFill( -+ reference_D0.host_view()); -+ cutlass::reference::host::TensorFill( -+ reference_D1.host_view()); -+ -+ tensor_A0.sync_device(); -+ tensor_B0.sync_device(); -+ tensor_C0.sync_device(); -+ tensor_Bias0.sync_device(); -+ tensor_D0.sync_device(); -+ reference_D0.sync_device(); -+ tensor_B1.sync_device(); -+ tensor_C1.sync_device(); -+ tensor_Bias1.sync_device(); -+ tensor_D1.sync_device(); -+ reference_D1.sync_device(); -+ -+ // -+ // Initialize the GEMM operator -+ // -+ -+ int split_k_slices = Gemm0::kSplitKSerial ? 2 : 1; -+ typename Gemm0::Arguments arguments_0{ -+ problem_size, -+ tensor_A0.device_ref(), -+ tensor_B0.device_ref(), -+ {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}, -+ tensor_D0.device_ref(), -+ {alpha0, beta0}, -+ split_k_slices -+ }; -+ -+ split_k_slices = Gemm1::kSplitKSerial ? 2 : 1; -+ typename Gemm1::Arguments arguments_1{ -+ problem_size, -+ tensor_A0.device_ref(), -+ tensor_B1.device_ref(), -+ {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}, -+ tensor_D1.device_ref(), -+ {alpha1, beta1}, -+ split_k_slices -+ }; -+ -+ -+ Gemm0 gemm_op_0; -+ Gemm1 gemm_op_1; -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace0(gemm_op_0.get_workspace_size(arguments_0)); -+ cutlass::device_memory::allocation workspace1(gemm_op_1.get_workspace_size(arguments_1)); -+ -+ cutlass::Status status = gemm_op_0.initialize(arguments_0, workspace0.get()); -+ -+ CUTLASS_CHECK(status); -+ -+ status = gemm_op_1.initialize(arguments_1, workspace1.get()); -+ -+ CUTLASS_CHECK(status); -+ -+ for(int i = 0; i < warm_ups; i++) { -+ status = gemm_op_0(); -+ CUTLASS_CHECK(status); -+ status = gemm_op_1(); -+ CUTLASS_CHECK(status); -+ } -+#ifdef IS_PROFILING -+ return true; -+#endif -+ // -+ // Run the GEMM -+ // -+ cudaEvent_t start, stop1, stop2; -+ cudaEventCreate(&start); -+ cudaEventCreate(&stop1); -+ cudaEventCreate(&stop2); -+ -+ cudaEventRecord(start); -+ -+ for(int i = 0; i < runs; i++) { -+ status = gemm_op_0(); -+ -+ CUTLASS_CHECK(status); -+ } -+ cudaEventRecord(stop1); -+ for(int i = 0; i < runs; i++) { -+ status = gemm_op_1(); -+ -+ CUTLASS_CHECK(status); -+ } -+ -+ cudaEventRecord(stop2); -+ cudaDeviceSynchronize(); -+ float gemm0Time, gemm1Time, totalTime; -+ cudaEventElapsedTime(&gemm0Time, start, stop1); -+ cudaEventElapsedTime(&gemm1Time, stop1, stop2); -+ cudaEventElapsedTime(&totalTime, start, stop2); -+ std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n"; -+ std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n"; -+ std::cout << "Non-fusion GEMM only time " << totalTime / (float)runs << " ms\n"; -+ -+ tensor_D0.sync_host(); -+ tensor_D1.sync_host(); -+ -+ // -+ // Verify -+ // -+ cutlass::reference::device::Gemm< -+ typename Gemm0::ElementA, typename Gemm0::LayoutA, -+ typename Gemm0::ElementB, typename Gemm0::LayoutB, -+ typename Gemm0::ElementC, typename Gemm0::LayoutC, ElementCompute, -+ ElementAccumulator, typename Gemm0::Operator> -+ reference_gemm_0; -+ -+ cutlass::reference::device::Gemm< -+ typename Gemm1::ElementA, typename Gemm1::LayoutA, -+ typename Gemm1::ElementB, typename Gemm1::LayoutB, -+ typename Gemm1::ElementC, typename Gemm1::LayoutC, ElementCompute, -+ ElementAccumulator, typename Gemm1::Operator> -+ reference_gemm_1; -+ -+ reference_gemm_0( -+ problem_size, -+ alpha0, -+ tensor_A0.device_ref(), -+ tensor_B0.device_ref(), -+ beta0, -+ {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}, -+ reference_D0.device_ref() -+ ); -+ -+ if(relu) { -+ cutlass::reference::device::TensorReLu(reference_D0.device_view()); -+ } -+ -+ reference_gemm_1( -+ problem_size, -+ alpha1, -+ tensor_A0.device_ref(), -+ tensor_B1.device_ref(), -+ beta1, -+ {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}, -+ reference_D1.device_ref() -+ ); -+ -+ if(relu) { -+ cutlass::reference::device::TensorReLu(reference_D1.device_view()); -+ } -+ -+ // Wait for kernels to finish -+ cudaDeviceSynchronize(); -+ reference_D0.sync_host(); -+ reference_D1.sync_host(); -+ -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); -+ -+ bool passed0 = cutlass::reference::host::TensorEquals( -+ reference_D1.host_view(), -+ tensor_D1.host_view()); -+ CHECK_TRUE(passed0); -+ -+ bool passed1 = cutlass::reference::host::TensorEquals( -+ reference_D1.host_view(), -+ tensor_D1.host_view()); -+ CHECK_TRUE(passed1); -+ if (!passed0 || !passed1) { -+ -+ std::stringstream fname; -+ -+ fname << "error_DualGemm_device_nonfused.txt"; -+ std::cerr << "Dumping results in " << fname.str() << "\n"; -+ -+ std::ofstream file(fname.str()); -+ -+ file -+ << "A0 =\n" << tensor_A0.host_view() -+ << "\nB0 =\n" << tensor_B0.host_view() -+ << "\nC0 =\n" << tensor_C0.host_view() -+ << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" -+ << "\nD0 =\n" << tensor_D0.host_view() -+ << "\nB1 =\n" << tensor_B1.host_view() -+ << "\nC1 =\n" << tensor_C1.host_view() -+ << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" -+ << "\n\nReference =\n" << reference_D1.host_view() -+ << "\nComputed =\n" << tensor_D1.host_view(); -+ } -+ return passed0 && passed1; -+ } -+}; -+ -+template -+struct DualFusedGemmRun -+{ -+ -+ using DualGemm = DualGemm_; -+ using ElementAccumulator = typename DualGemm::ElementAccumulator; -+ using ElementCompute = typename DualGemm::DualGemmKernel::Epilogue0::OutputOp::ElementCompute; -+ using EpilogueOutputOp2 = typename DualGemm::EpilogueOutputOp2; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ cutlass::Distribution::Kind init_Scale; -+ cutlass::Distribution::Kind init_Bias; -+ uint64_t seed; -+ -+ // -+ // Methods -+ // -+ -+ DualFusedGemmRun( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), -+ init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, 2, -2, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else if (dist_kind == cutlass::Distribution::AllZeros) { -+ cutlass::reference::host::TensorFill(view, Element(0)); -+ } -+ else if (dist_kind == cutlass::Distribution::AllOnes) { -+ cutlass::reference::host::TensorFill(view, Element(1)); -+ } -+ else { -+ std::cerr << "Not implemented\n"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ -+ -+ -+ /// Executes one test -+ bool run( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha0 = ElementCompute(1), -+ ElementCompute beta0 = ElementCompute(1), -+ ElementCompute alpha1 = ElementCompute(1), -+ ElementCompute beta1 = ElementCompute(1), -+ bool relu = false, -+ int warm_ups = 1, -+ int runs = 100) { -+ -+ // -+ // Allocate the GEMM workspace -+ // -+ -+ cutlass::HostTensor< -+ typename DualGemm::ElementA, -+ typename DualGemm::LayoutA> tensor_A0(problem_size.mk()); -+ -+ cutlass::HostTensor< -+ typename DualGemm::ElementB, -+ typename DualGemm::LayoutB> tensor_B0(problem_size.kn()); -+ -+ cutlass::HostTensor< -+ typename DualGemm::ElementC, -+ typename DualGemm::LayoutC> tensor_C0(problem_size.mn()); -+ -+ cutlass::HostTensor< -+ typename DualGemm::ElementC, -+ typename DualGemm::LayoutScaleBias> tensor_Bias0({1, problem_size.n()}); -+ -+ cutlass::HostTensor< -+ typename DualGemm::ElementC, -+ typename DualGemm::LayoutC> tensor_D0(problem_size.mn()); -+ -+ cutlass::HostTensor< -+ typename DualGemm::ElementC, -+ typename DualGemm::LayoutC> reference_D0(problem_size.mn()); -+ -+ cutlass::HostTensor< -+ typename DualGemm::ElementB, -+ typename DualGemm::LayoutB> tensor_B1(problem_size.kn()); -+ -+ cutlass::HostTensor< -+ typename DualGemm::ElementC, -+ typename DualGemm::LayoutC> tensor_C1(problem_size.mn()); -+ -+ cutlass::HostTensor< -+ typename DualGemm::ElementC, -+ typename DualGemm::LayoutScaleBias> tensor_Bias1({1, problem_size.n()}); -+ -+ cutlass::HostTensor< -+ typename DualGemm::ElementC, -+ typename DualGemm::LayoutC> tensor_D1(problem_size.mn()); -+ -+ cutlass::HostTensor< -+ typename DualGemm::ElementC, -+ typename DualGemm::LayoutC> tensor_D2(problem_size.mn()); -+ -+ cutlass::HostTensor< -+ typename DualGemm::ElementC, -+ typename DualGemm::LayoutC> reference_D1(problem_size.mn()); -+ -+ cutlass::HostTensor< -+ typename DualGemm::ElementC, -+ typename DualGemm::LayoutC> reference_D2(problem_size.mn()); -+ -+ CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); -+ CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2118)); -+ CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); -+ CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2011)); -+ CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2113)); -+ CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); -+ CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2012)); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D0.host_view()); -+ cutlass::reference::host::TensorFill( -+ tensor_D1.host_view()); -+ cutlass::reference::host::TensorFill( -+ tensor_D2.host_view()); -+ cutlass::reference::host::TensorFill( -+ reference_D0.host_view()); -+ cutlass::reference::host::TensorFill( -+ reference_D1.host_view()); -+ cutlass::reference::host::TensorFill( -+ reference_D2.host_view()); -+ -+ tensor_A0.sync_device(); -+ tensor_B0.sync_device(); -+ tensor_C0.sync_device(); -+ tensor_Bias0.sync_device(); -+ tensor_B1.sync_device(); -+ tensor_C1.sync_device(); -+ tensor_Bias1.sync_device(); -+ tensor_D0.sync_device(); -+ tensor_D1.sync_device(); -+ tensor_D2.sync_device(); -+ reference_D0.sync_device(); -+ reference_D1.sync_device(); -+ reference_D2.sync_device(); -+ -+ // -+ // Initialize the GEMM operator -+ // -+ -+ int split_k_slices = DualGemm::kSplitKSerial ? 2 : 1; -+ typename cutlass::TensorRef nullptr_ref{}; -+ decltype(nullptr_ref) ref_B0, ref_B1; -+ if (beta0 != ElementCompute(0)) { -+ ref_B0 = {tensor_Bias0.device_data(), typename DualGemm::LayoutC::Stride(0)}; -+ } -+ if (beta1 != ElementCompute(0)) { -+ ref_B1 = {tensor_Bias1.device_data(), typename DualGemm::LayoutC::Stride(0)}; -+ } -+ typename DualGemm::Arguments arguments{ -+ problem_size, -+ tensor_A0.device_ref(), -+ tensor_B0.device_ref(), -+ ref_B0, -+ DualGemm::kStoreD0 ? tensor_D0.device_ref() : nullptr_ref, -+ tensor_B1.device_ref(), -+ ref_B1, -+ DualGemm::kStoreD1 ? tensor_D1.device_ref() : nullptr_ref, -+ tensor_D2.device_ref(), -+ {alpha0, beta0}, -+ {alpha1, beta1}, -+ {}, -+ split_k_slices -+ }; -+ -+ DualGemm b2b_gemm_op; -+ -+ cutlass::device_memory::allocation workspace(b2b_gemm_op.get_workspace_size(arguments)); -+ -+ cutlass::Status status = b2b_gemm_op.can_implement(arguments); -+ -+ CUTLASS_CHECK(status); -+ -+ status = b2b_gemm_op.initialize(arguments, workspace.get()); -+ -+ CUTLASS_CHECK(status); -+ -+ for(int i = 0; i < warm_ups; i++) { -+ status = b2b_gemm_op(); -+ CUTLASS_CHECK(status); -+ } -+ -+#ifdef IS_PROFILING -+ return true; -+#endif -+ // -+ // Run the GEMM -+ // -+ -+ cudaEvent_t start, stop; -+ cudaEventCreate(&start); -+ cudaEventCreate(&stop); -+ -+ cudaEventRecord(start); -+ -+ for(int i = 0; i < runs; i++) { -+ status = b2b_gemm_op(); -+ -+ CUTLASS_CHECK(status); -+ } -+ -+ cudaEventRecord(stop); -+ cudaDeviceSynchronize(); -+ float gemmTime; -+ cudaEventElapsedTime(&gemmTime, start, stop); -+ std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n"; -+ -+ tensor_D0.sync_host(); -+ tensor_D1.sync_host(); -+ tensor_D2.sync_host(); -+ -+ // -+ // Verify -+ // -+ -+ cutlass::reference::device::Gemm< -+ typename DualGemm::ElementA, typename DualGemm::LayoutA, -+ typename DualGemm::ElementB, typename DualGemm::LayoutB, -+ typename DualGemm::ElementC, typename DualGemm::LayoutC, -+ ElementAccumulator, ElementAccumulator> -+ reference_gemm_0; -+ -+ cutlass::reference::device::Gemm< -+ typename DualGemm::ElementA, typename DualGemm::LayoutA, -+ typename DualGemm::ElementB, typename DualGemm::LayoutB, -+ typename DualGemm::ElementC, typename DualGemm::LayoutC, ElementCompute, -+ ElementAccumulator, typename DualGemm::Operator> -+ reference_gemm_1; -+ -+ reference_gemm_0( -+ problem_size, -+ alpha0, -+ tensor_A0.device_ref(), -+ tensor_B0.device_ref(), -+ beta0, -+ {tensor_Bias0.device_data(), typename DualGemm::LayoutC::Stride(0)}, -+ reference_D0.device_ref() -+ ); -+ if(relu) { -+ cutlass::reference::device::TensorReLu(reference_D0.device_view()); -+ } -+ -+ reference_gemm_1( -+ problem_size, -+ alpha1, -+ tensor_A0.device_ref(), -+ tensor_B1.device_ref(), -+ beta1, -+ {tensor_Bias1.device_data(), typename DualGemm::LayoutC::Stride(0)}, -+ reference_D1.device_ref() -+ ); -+ if(relu) { -+ cutlass::reference::device::TensorReLu(reference_D1.device_view()); -+ } -+ TensorEpilogueForEach(reference_D0.device_view(), reference_D1.device_view(), reference_D2.device_view()); -+ cudaDeviceSynchronize(); -+ reference_D0.sync_host(); -+ reference_D1.sync_host(); -+ reference_D2.sync_host(); -+ -+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D2.host_view()), 0); -+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D2.host_view()), 0); -+ -+ bool passed_out0 = true; -+ if (DualGemm::kStoreD0) { -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0); -+ passed_out0 = cutlass::reference::host::TensorEquals( -+ reference_D0.host_view(), -+ tensor_D0.host_view()); -+ } -+ CHECK_TRUE(passed_out0); -+ -+ bool passed_out1 = true; -+ if (DualGemm::kStoreD1) { -+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); -+ passed_out1 = cutlass::reference::host::TensorEquals( -+ reference_D1.host_view(), -+ tensor_D1.host_view()); -+ } -+ CHECK_TRUE(passed_out1); -+ -+ bool passed_out2 = cutlass::reference::host::TensorEquals( -+ reference_D2.host_view(), -+ tensor_D2.host_view()); -+ CHECK_TRUE(passed_out2); -+ -+ bool passed = passed_out0 && passed_out1 && passed_out2; -+ if (!passed) -+ { -+ -+ std::stringstream fname; -+ -+ fname << "error_DualGemm_device_fused.txt"; -+ std::cerr << "Dumping results in " << fname.str() << "\n"; -+ -+ std::ofstream file(fname.str()); -+ -+ file -+ << "A0 =\n" << tensor_A0.host_view() -+ << "\nB0 =\n" << tensor_B0.host_view() -+ << "\nC0 =\n" << tensor_C0.host_view() -+ << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" -+ << "\nB1 =\n" << tensor_B1.host_view() -+ << "\nC1 =\n" << tensor_C1.host_view() -+ << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" -+ << "\n\nReference0 =\n" << reference_D0.host_view() -+ << "\nComputed0 =\n" << tensor_D0.host_view() -+ << "\n\nReference1 =\n" << reference_D1.host_view() -+ << "\nComputed1 =\n" << tensor_D1.host_view() -+ << "\n\nReference2 =\n" << reference_D2.host_view() -+ << "\nComputed2 =\n" << tensor_D2.host_view(); -+ } -+ //std::cout << "A0 " << tensor_A0.host_view() << std::endl; -+ // std::cout << "reference_D0 " << reference_D0.host_view() << std::endl; -+ // std::cout << "reference_D1 " << reference_D1.host_view() << std::endl; -+ // std::cout << "reference_D2 " << reference_D2.host_view() << std::endl; -+ //std::cout << "reference_D0 " << reference_D0.host_view() << std::endl; -+ return passed; -+ } -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/45_dual_gemm/kernel/dual_gemm.h b/3rdparty/cutlass/examples/45_dual_gemm/kernel/dual_gemm.h -new file mode 100644 -index 0000000..4cbddaa ---- /dev/null -+++ b/3rdparty/cutlass/examples/45_dual_gemm/kernel/dual_gemm.h -@@ -0,0 +1,489 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/semaphore.h" -+ -+#include "../threadblock/dual_mma_multistage.h" -+#include "../threadblock/dual_epilogue.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename DualMma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue0_, ///! Epilogue -+ typename Epilogue1_, ///! Epilogue -+ typename OutputOp2_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ bool SplitKSerial, ///! If true, code supporting split-K via serial reduction is enabled. -+ bool StoreD0, -+ bool StoreD1 -+> -+struct DualGemm { -+ -+ using DualMma = DualMma_; -+ -+ using Epilogue0 = Epilogue0_; -+ using Epilogue1 = Epilogue1_; -+ using OutputOp0 = typename Epilogue0::OutputOp; -+ using OutputOp1 = typename Epilogue1::OutputOp; -+ using OutputOp2 = OutputOp2_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static constexpr bool kStoreD0 = StoreD0; -+ static constexpr bool kStoreD1 = StoreD1; -+ -+ using DualEpilogue = cutlass::epilogue::threadblock::DualEpilogue< -+ typename Epilogue0::Shape, -+ typename Epilogue0::WarpMmaOperator, -+ Epilogue0::kPartitionsK, -+ typename Epilogue0::OutputTileIterator, -+ typename Epilogue0::AccumulatorFragmentIterator, -+ typename Epilogue0::WarpTileIterator, -+ typename Epilogue0::SharedLoadIterator, -+ OutputOp0, -+ OutputOp1, -+ OutputOp2, -+ typename Epilogue0::Padding, -+ kStoreD0, -+ kStoreD1, -+ Epilogue0::kFragmentsPerIteration, -+ true // IterationsUnroll -+ >; -+ -+ static bool const kSplitKSerial = SplitKSerial; -+ static_assert(!kSplitKSerial || (kStoreD0 && kStoreD1), -+ "Split-K serial requires buffers for D0/D1 for reduction"); -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount0 = typename DualMma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount0::kCount; -+ -+ /// Parameters structure -+ struct Params { -+ cutlass::gemm::GemmCoord problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ int swizzle_log_tile; -+ -+ // Mma0 -+ typename DualMma::IteratorA::Params params_A0; -+ typename DualMma::IteratorA::TensorRef ref_A0; -+ typename DualMma::IteratorB::Params params_B0; -+ typename DualMma::IteratorB::TensorRef ref_B0; -+ typename Epilogue0::OutputTileIterator::Params params_C0; -+ typename Epilogue0::OutputTileIterator::TensorRef ref_C0; -+ typename Epilogue0::OutputTileIterator::Params params_D0; -+ typename Epilogue0::OutputTileIterator::TensorRef ref_D0; -+ typename OutputOp0::Params output_op_0; -+ -+ // Mma1 -+ typename DualMma::IteratorB::Params params_B1; -+ typename DualMma::IteratorB::TensorRef ref_B1; -+ typename Epilogue1::OutputTileIterator::Params params_C1; -+ typename Epilogue1::OutputTileIterator::TensorRef ref_C1; -+ typename Epilogue1::OutputTileIterator::Params params_D1; -+ typename Epilogue1::OutputTileIterator::TensorRef ref_D1; -+ typename OutputOp1::Params output_op_1; -+ -+ typename Epilogue1::OutputTileIterator::Params params_D2; -+ typename Epilogue1::OutputTileIterator::TensorRef ref_D2; -+ typename OutputOp2::Params output_op_2; -+ -+ int *semaphore; -+ int gemm_k_size; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): swizzle_log_tile(0), semaphore(0), gemm_k_size(0) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ cutlass::gemm::GemmCoord const & problem_size, -+ cutlass::gemm::GemmCoord const & grid_tiled_shape, -+ // Mma0: D0 = A @ B0 + C0 -+ typename DualMma::IteratorA::TensorRef ref_A0, -+ typename DualMma::IteratorB::TensorRef ref_B0, -+ typename Epilogue0::OutputTileIterator::TensorRef ref_C0, -+ typename Epilogue0::OutputTileIterator::TensorRef ref_D0, -+ // Mma1: D1 = A @ B1 + C1 -+ typename DualMma::IteratorB::TensorRef ref_B1, -+ typename Epilogue1::OutputTileIterator::TensorRef ref_C1, -+ typename Epilogue1::OutputTileIterator::TensorRef ref_D1, -+ -+ typename Epilogue1::OutputTileIterator::TensorRef ref_D2, -+ typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(), -+ typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(), -+ typename OutputOp2::Params output_op_2 = typename OutputOp2::Params(), -+ int *workspace = nullptr -+ ): -+ problem_size(problem_size), -+ grid_tiled_shape(grid_tiled_shape), -+ swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), -+ // Mma0 -+ params_A0(ref_A0.layout()), -+ ref_A0(ref_A0), -+ params_B0(ref_B0.layout()), -+ ref_B0(ref_B0), -+ params_C0(ref_C0.layout()), -+ ref_C0(ref_C0), -+ params_D0(ref_D0.layout()), -+ ref_D0(ref_D0), -+ // Mma1 -+ params_B1(ref_B1.layout()), -+ ref_B1(ref_B1), -+ params_C1(ref_C1.layout()), -+ ref_C1(ref_C1), -+ params_D1(ref_D1.layout()), -+ ref_D1(ref_D1), -+ params_D2(ref_D2.layout()), -+ ref_D2(ref_D2), -+ output_op_0(output_op_0), -+ output_op_1(output_op_1), -+ output_op_2(output_op_2) { -+ -+ int total_gemm_k_iterations = (problem_size.k() + DualMma::Shape::kK - 1) / DualMma::Shape::kK; -+ int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); -+ gemm_k_size = gemm_k_iterations * DualMma::Shape::kK; -+ -+ semaphore = workspace; -+ } -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename DualMma::SharedStorage main_loop; -+ typename DualEpilogue::SharedStorage epilogue; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ DualGemm() { } -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement( -+ cutlass::gemm::GemmCoord const & problem_size, -+ typename DualMma::IteratorA::TensorRef ref_A0, -+ typename DualMma::IteratorB::TensorRef ref_B0, -+ typename Epilogue0::OutputTileIterator::TensorRef ref_C0, -+ typename Epilogue0::OutputTileIterator::TensorRef ref_D0, -+ typename DualMma::IteratorB::TensorRef ref_B1, -+ typename Epilogue1::OutputTileIterator::TensorRef ref_C1, -+ typename Epilogue1::OutputTileIterator::TensorRef ref_D1, -+ typename Epilogue1::OutputTileIterator::TensorRef ref_D2) { -+ -+ static int const kAlignmentA = DualMma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = DualMma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue0::OutputTileIterator::kElementsPerAccess; -+ -+ if (!TensorRef_aligned(ref_A0, kAlignmentA)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_B0, kAlignmentB)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_C0, kAlignmentC)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_D0, kAlignmentC)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_B1, kAlignmentB)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_C1, kAlignmentC)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_D1, kAlignmentC)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_D2, kAlignmentC)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A0{ -+ threadblock_tile_offset.m() * DualMma::Shape::kM, -+ threadblock_tile_offset.k() * params.gemm_k_size, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B0{ -+ threadblock_tile_offset.k() * params.gemm_k_size, -+ threadblock_tile_offset.n() * DualMma::Shape::kN -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B1{ -+ threadblock_tile_offset.k() * params.gemm_k_size, -+ threadblock_tile_offset.n() * DualMma::Shape::kN -+ }; -+ -+ // Problem size is a function of threadblock index in the K dimension -+ int problem_size_k = -+ (params.problem_size.k() < (threadblock_tile_offset.k() + 1) * params.gemm_k_size) ? -+ params.problem_size.k() : -+ (threadblock_tile_offset.k() + 1) * params.gemm_k_size; -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size_k - tb_offset_A0.column() + DualMma::Shape::kK - 1) / DualMma::Shape::kK; -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename DualMma::IteratorA iterator_A0( -+ params.params_A0, -+ params.ref_A0.data(), -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_A0); -+ -+ typename DualMma::IteratorB iterator_B0( -+ params.params_B0, -+ params.ref_B0.data(), -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_B0); -+ -+ typename DualMma::IteratorB iterator_B1( -+ params.params_B1, -+ params.ref_B1.data(), -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_B1); -+ -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0); -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ -+ // Construct thread-scoped matrix multiply -+ typename DualMma::FragmentC accum0; -+ typename DualMma::FragmentC accum1; -+ accum0.clear(); -+ accum1.clear(); -+ -+ DualMma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ if (!kSplitKSerial || gemm_k_iterations > 0) { -+ // Compute threadblock-scoped matrix multiply-add -+ mma(gemm_k_iterations, -+ accum0, accum1, -+ iterator_A0, iterator_B0, iterator_B1, -+ accum0, accum1); -+ } -+ -+ // -+ // Epilogue -+ // -+ -+ OutputOp0 output_op_0(params.output_op_0); -+ OutputOp1 output_op_1(params.output_op_1); -+ OutputOp2 output_op_2(params.output_op_2); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * DualMma::Shape::kM, -+ threadblock_tile_offset.n() * DualMma::Shape::kN -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ // Construct the semaphore. -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. -+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op_0.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue0::OutputTileIterator iterator_C0( -+ params.params_C0, -+ params.ref_C0.data(), -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ typename Epilogue1::OutputTileIterator iterator_C1( -+ params.params_C1, -+ params.ref_C1.data(), -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue0::OutputTileIterator iterator_D0( -+ params.params_D0, -+ params.ref_D0.data(), -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ typename Epilogue1::OutputTileIterator iterator_D1( -+ params.params_D1, -+ params.ref_D1.data(), -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ typename Epilogue1::OutputTileIterator iterator_D2( -+ params.params_D2, -+ params.ref_D2.data(), -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ DualEpilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_offset.k()) { -+ iterator_C0 = iterator_D0; -+ iterator_C1 = iterator_D1; -+ } -+ -+ semaphore.wait(threadblock_tile_offset.k()); -+ -+ __threadfence(); -+ } -+ -+ // Execute the epilogue operator to update the destination tensor. -+ typename Epilogue0::OutputTileIterator source_iters[] = { -+ iterator_C0, iterator_C1 -+ }; -+ const bool writeToD2 = (!kSplitKSerial || params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1); -+ epilogue( -+ output_op_0, output_op_1, output_op_2, -+ iterator_D0, iterator_D1, iterator_D2, -+ accum0, accum1, -+ source_iters, -+ writeToD2 -+ ); -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_offset.k() + 1; -+ } -+ -+ __threadfence(); -+ semaphore.release(lock); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/examples/45_dual_gemm/test_run.h b/3rdparty/cutlass/examples/45_dual_gemm/test_run.h -new file mode 100644 -index 0000000..b64f31f ---- /dev/null -+++ b/3rdparty/cutlass/examples/45_dual_gemm/test_run.h -@@ -0,0 +1,95 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+ -+#include -+ -+// Run tests on GPUs -+ -+int testRun(int arch, std::vector & test_funcs, const std::string & test_name) { -+ -+ bool supported = false; -+ -+ int arch_major = arch / 10; -+ int arch_minor = arch - arch / 10 * 10; -+ -+ if(arch_major >= 8) { -+ // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. -+ // -+ // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples. -+ if (__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) { -+ supported = true; -+ } -+ } -+ else if(arch_major >= 7) { -+ // Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2. -+ // -+ // CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples. -+ if (__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)) { -+ supported = true; -+ } -+ } -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return -1; -+ } -+ -+ if (!(props.major == arch_major && props.minor == arch_minor)) { -+ supported = false; -+ } -+ -+ if (!supported) { -+ // Returning zero so this test passes on older Toolkits. Its actions are no-op. -+ std::cout << "This example isn't supported on current architecture" << std::endl; -+ return 0; -+ } -+ -+ bool pass = true; -+ -+ std::cout << "Device: " << props.name << std::endl; -+ std::cout << "Arch: SM" << arch << std::endl; -+ std::cout << "Test: " << test_name << std::endl; -+ for(auto func : test_funcs) { -+ pass &= func(); -+ } -+ -+ -+ if(pass) -+ return 0; -+ else -+ return -1; -+ -+} -+ -diff --git a/3rdparty/cutlass/examples/45_dual_gemm/thread/left_silu_and_mul.h b/3rdparty/cutlass/examples/45_dual_gemm/thread/left_silu_and_mul.h -new file mode 100644 -index 0000000..0ba9bb9 ---- /dev/null -+++ b/3rdparty/cutlass/examples/45_dual_gemm/thread/left_silu_and_mul.h -@@ -0,0 +1,150 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Functor performing linear combination operations used by epilogues. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/epilogue/thread/scale_type.h" -+#include "cutlass/epilogue/thread/linear_combination_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies a linear combination operator to an array of elements. -+/// -+/// D = alpha * accumulator + beta * source + uniform -+/// -+template < -+ typename ElementOutput_, ///< Data type used to load and store tensors -+ int Count, ///< Number of elements computed per operation. -+ ///< Usually it is 128/sizeof_bits, -+ ///< but we use 64 or 32 sometimes when there are not enough data to store -+ typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type -+ typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest -+> -+class LeftSiLUAndMul { -+public: -+ -+ using ElementOutput = ElementOutput_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementCompute_; -+ -+ static int const kCount = Count; -+ using FragmentOutput = Array; -+ using FragmentAccumulator = Array; -+ using ComputeFragment = Array; -+ -+ static FloatRoundStyle const kRound = Round; -+ -+ struct Params{}; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ ElementCompute alpha_; -+ ElementCompute beta_; -+ -+public: -+ -+ /// Constructs the function object, possibly loading from pointers in host memory -+ CUTLASS_HOST_DEVICE -+ LeftSiLUAndMul(Params const &/*params*/) {} -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ return true; -+ } -+ -+ /// Functionally required for serial reduction in the epilogue -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ assert(false); -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &lhs, -+ FragmentAccumulator const &rhs) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter accumulator_to_compute; -+ -+ // Convert to destination numeric type -+ NumericArrayConverter compute_to_output; -+ -+ ComputeFragment converted_lhs = accumulator_to_compute(lhs); -+ ComputeFragment converted_rhs = accumulator_to_compute(rhs); -+ -+ cutlass::epilogue::thread::SiLu silu; -+ cutlass::multiplies mul; -+ auto silu_lhs = silu(converted_lhs); -+ return compute_to_output(mul(silu_lhs, converted_rhs)); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ ElementOutput operator()( -+ ElementAccumulator const& lhs, -+ ElementAccumulator const& rhs -+ ) const { -+ ElementCompute convert_lhs(lhs); -+ ElementCompute convert_rhs(rhs); -+ cutlass::epilogue::thread::SiLu silu; -+ cutlass::multiplies mul; -+ auto silu_lhs = silu(convert_lhs); -+ return ElementOutput(mul(silu_lhs, convert_rhs)); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/45_dual_gemm/threadblock/dual_epilogue.h b/3rdparty/cutlass/examples/45_dual_gemm/threadblock/dual_epilogue.h -new file mode 100644 -index 0000000..d9492ab ---- /dev/null -+++ b/3rdparty/cutlass/examples/45_dual_gemm/threadblock/dual_epilogue.h -@@ -0,0 +1,430 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/vector.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/tensor_coord.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/functional.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/regular_tile_iterator.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue_base.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+#include "cutlass/numeric_types.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Epilogue operator -+template < -+ typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) -+ typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) -+ int PartitionsK, ///< Number of partitions of the K dimension -+ typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors -+ typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators -+ typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM -+ typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM -+ ///< Output operator -+ typename OutputOp0_, -+ typename OutputOp1_, -+ typename OutputOp2_, -+ typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape) -+ bool StoreD0 = true, -+ bool StoreD1 = true, -+ int FragmentsPerPartition = 1, ///< Used to coarsten the epilogue granularity -+ int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large -+ (!IsEpilogueFunctorHeavy::value) -+> -+class DualEpilogue { -+ -+public: -+ -+ using Base = EpilogueBase< -+ Shape_, -+ typename WarpMmaOperator_::Shape, -+ PartitionsK, -+ AccumulatorFragmentIterator_, -+ WarpTileIterator_, -+ Padding_, -+ FragmentsPerPartition>; -+ -+ using Shape = Shape_; -+ using WarpMmaOperator = WarpMmaOperator_; -+ static int const kPartitionsK = PartitionsK; -+ static bool constexpr kStoreD0 = StoreD0; -+ static bool constexpr kStoreD1 = StoreD1; -+ using OutputTileIterator = OutputTileIterator_; -+ using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; -+ using WarpTileIterator = WarpTileIterator_; -+ using SharedLoadIterator = SharedLoadIterator_; -+ using OutputOp0 = OutputOp0_; -+ using OutputOp1 = OutputOp1_; -+ using OutputOp2 = OutputOp2_; -+ using Padding = Padding_; -+ -+ using Layout = layout::RowMajor; -+ using LongIndex = typename Layout::LongIndex; -+ -+ /// The complete warp-level accumulator tile -+ using AccumulatorTile = typename Base::AccumulatorTile; -+ -+ /// Accumulator element -+ using ElementAccumulator = typename WarpTileIterator::Element; -+ -+ /// Output element -+ using ElementOutput = typename OutputTileIterator::Element; -+ -+ /// Output access size -+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; -+ -+ /// Tensor reference to destination tensor -+ using TensorRef = typename OutputTileIterator::TensorRef; -+ -+ /// Tensor reference to sync tensor -+ using SyncTensorRef = typename cutlass::TensorRef; -+ -+ /// Const tensor reference to source tensor -+ using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; -+ -+ /// Array type used to output -+ using OutputAccessType = Array< -+ typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>; -+ -+ /// Array type used by output functor -+ using AccumulatorAccessType = Array; -+ -+ /// Number of warps -+ using WarpCount = typename Base::WarpCount; -+ -+ struct SharedStorage { -+ using Element = typename WarpTileIterator::Element; -+ -+ /// Tensor reference to shared memory allocation -+ using TensorRef = typename WarpTileIterator::TensorRef; -+ -+ /// Logical shape of the shared memory tile written to by all warps. -+ using Shape = typename Base::Shape; -+ -+ /// Shape of the shared memory allocation for the epilogue -+ using StorageShape = typename Base::SharedStorage::StorageShape; -+ -+ // -+ // Data members -+ // -+ -+ AlignedBuffer storage[2]; -+ -+ // -+ // Methods -+ // -+ -+ /// Returns a tensor reference to the shared memory buffer -+ CUTLASS_DEVICE -+ TensorRef reference(int i) { -+ return TensorRef( -+ storage[i].data(), -+ Layout::packed({StorageShape::kRow, StorageShape::kColumn})); -+ } -+ }; -+ -+ static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK; -+ static int constexpr kSmemPointerOffset = SharedStorage::StorageShape::kCount / kSmemTiles; -+ -+public: -+ -+ static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements, -+ "Mismatch between shared load iterator and output tile iterator."); -+ -+ static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero."); -+ -+ static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess), -+ "Divisibility"); -+ -+private: -+ -+ /// Loads fragment from shared memory aligned with output tensor -+ SharedLoadIterator shared_load_iterator0_; -+ SharedLoadIterator shared_load_iterator1_; -+ -+ /// Stores a warp's fragment of accumulators to SMEM -+ WarpTileIterator warp_tile_iterator0_; -+ WarpTileIterator warp_tile_iterator1_; -+ -+public: -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ DualEpilogue( -+ SharedStorage &shared_storage, ///< Shared storage object -+ int thread_idx, ///< ID of a thread within the threadblock -+ int warp_idx, ///< ID of warp within threadblock -+ int lane_idx ///< Id of thread within warp -+ ): -+ shared_load_iterator0_(shared_storage.reference(0), thread_idx), -+ shared_load_iterator1_(shared_storage.reference(1), thread_idx), -+ warp_tile_iterator0_(shared_storage.reference(0), lane_idx), -+ warp_tile_iterator1_(shared_storage.reference(1), lane_idx) -+ { -+ int warp_k = warp_idx / (WarpCount::kM * WarpCount::kN); -+ int warp_mn = warp_idx % (WarpCount::kM * WarpCount::kN); -+ int warp_m = warp_mn % WarpCount::kM; -+ int warp_n = warp_mn / WarpCount::kM; -+ -+ MatrixCoord warp_offset{warp_k * WarpCount::kM + warp_m, warp_n}; -+ -+ warp_tile_iterator0_.add_tile_offset(warp_offset); -+ warp_tile_iterator1_.add_tile_offset(warp_offset); -+ } -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void operator()( -+ OutputOp0 const &output_op0, -+ OutputOp1 const &output_op1, -+ OutputOp2 const &output_op2, -+ OutputTileIterator dest0, -+ OutputTileIterator dest1, -+ OutputTileIterator dest2, -+ AccumulatorTile const &accumulator0, -+ AccumulatorTile const &accumulator1, -+ OutputTileIterator source_iterator[2], -+ bool writeToD2 // true if it's the final split-k -+ ) { -+ // TODO: Implement when no source is needed -+ -+ typename OutputTileIterator::Fragment source_fragment[2]; -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 2; ++i) { -+ source_fragment[i].clear(); -+ } -+ -+ // -+ // Iterator over warp-level accumulator fragment -+ // -+ -+ AccumulatorFragmentIterator accum_fragment_iterator[2] = {accumulator0, accumulator1}; -+ -+ // -+ // Iterate over accumulator tile -+ // -+ -+ #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1) -+ for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { -+ -+ // -+ // Load the source -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 2; ++i) { -+ source_iterator[i].load(source_fragment[i]); -+ ++source_iterator[i]; -+ } -+ -+ // -+ // Convert and store fragment -+ // -+ -+ __syncthreads(); -+ -+ acc2smem_source_needed>::push( -+ iter, accum_fragment_iterator[0], this->warp_tile_iterator0_); -+ acc2smem_source_needed>::push( -+ iter, accum_fragment_iterator[1], this->warp_tile_iterator1_); -+ -+ __syncthreads(); -+ -+ // -+ // Load fragments from shared memory -+ // -+ -+ typename SharedLoadIterator::Fragment aligned_accum_fragment0[kPartitionsK]; -+ typename SharedLoadIterator::Fragment aligned_accum_fragment1[kPartitionsK]; -+ -+ shared_load_iterator0_.load(aligned_accum_fragment0[0]); -+ shared_load_iterator1_.load(aligned_accum_fragment1[0]); -+ -+ // If the number of k-slices is > 1 - perform a reduction amongst the k-slices -+ if (kPartitionsK > 1) { -+ -+ plus add_fragments; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for ( int i = 1; i < kPartitionsK; ++i) { -+ shared_load_iterator0_.add_pointer_offset(kSmemPointerOffset); -+ shared_load_iterator1_.add_pointer_offset(kSmemPointerOffset); -+ shared_load_iterator0_.load(aligned_accum_fragment0[i]); -+ shared_load_iterator1_.load(aligned_accum_fragment1[i]); -+ aligned_accum_fragment0[0] = add_fragments(aligned_accum_fragment0[0], aligned_accum_fragment0[i]); -+ aligned_accum_fragment1[0] = add_fragments(aligned_accum_fragment1[0], aligned_accum_fragment1[i]); -+ } -+ -+ shared_load_iterator0_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset); -+ shared_load_iterator1_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset); -+ } -+ -+ // -+ // Compute the output result -+ // -+ -+ typename OutputTileIterator::Fragment output_fragment[3]; -+ -+ apply_output_operator_(output_fragment, -+ output_op0, output_op1, output_op2, -+ aligned_accum_fragment0[0], aligned_accum_fragment1[0], -+ source_fragment); -+ -+ -+ // -+ // Store the final result -+ // -+ -+ if (kStoreD0) { -+ dest0.store(output_fragment[0]); -+ ++dest0; -+ } -+ if (kStoreD1) { -+ dest1.store(output_fragment[1]); -+ ++dest1; -+ } -+ if (writeToD2) { -+ dest2.store(output_fragment[2]); -+ ++dest2; -+ } -+ } -+ } -+ -+private: -+ -+ static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, "One of these must be exactly 1."); -+ -+ template -+ struct acc2smem_source_needed; -+ -+ template -+ struct acc2smem_source_needed> { -+ template -+ CUTLASS_DEVICE -+ static void helper(AccumulatorFragmentIterator accum_fragment_iterator, -+ WarpTileIterator &warp_tile_iterator) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Advance; i++) { -+ ++accum_fragment_iterator; -+ } -+ -+ typename AccumulatorFragmentIterator::Fragment accum_fragment; -+ accum_fragment_iterator.load(accum_fragment); -+ warp_tile_iterator.store(accum_fragment); -+ } -+ -+ CUTLASS_DEVICE -+ static void push(size_t pos, -+ AccumulatorFragmentIterator const &iterator_begin, -+ WarpTileIterator &warp_tile_iterator) { -+ int dummy[] = {(pos == Seq) && (helper(iterator_begin, warp_tile_iterator), 0)...}; -+ } -+ }; -+ -+ /// Helper to invoke the output functor over each vector of output -+ CUTLASS_DEVICE -+ void apply_output_operator_( -+ typename OutputTileIterator::Fragment (&output_fragment)[3], -+ OutputOp0 const &output_op0, -+ OutputOp1 const &output_op1, -+ OutputOp2 const &output_op2, -+ typename SharedLoadIterator::Fragment const& aligned_accum_fragment0, -+ typename SharedLoadIterator::Fragment const& aligned_accum_fragment1, -+ typename OutputTileIterator::Fragment const (&source_fragment)[2]) { -+ -+ OutputAccessType* output_frag_ptr[3] = { -+ reinterpret_cast(&output_fragment[0]), -+ reinterpret_cast(&output_fragment[1]), -+ reinterpret_cast(&output_fragment[2]) -+ }; -+ -+ AccumulatorAccessType const *compute_frag_ptr[2] = { -+ reinterpret_cast(&aligned_accum_fragment0), -+ reinterpret_cast(&aligned_accum_fragment1) -+ }; -+ -+ OutputAccessType const *source_frag_ptr[2] = { -+ reinterpret_cast(&source_fragment[0]), -+ reinterpret_cast(&source_fragment[1]) -+ }; -+ -+ int const kOutputOpIterations = -+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kOutputOpIterations; ++i) { -+ -+ // Call the output operators -+ output_frag_ptr[0][i] = output_op0(compute_frag_ptr[0][i], source_frag_ptr[0][i]); -+ output_frag_ptr[1][i] = output_op1(compute_frag_ptr[1][i], source_frag_ptr[1][i]); -+ output_frag_ptr[2][i] = output_op2(output_frag_ptr[0][i], output_frag_ptr[1][i]); -+ } -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/45_dual_gemm/threadblock/dual_mma_base.h b/3rdparty/cutlass/examples/45_dual_gemm/threadblock/dual_mma_base.h -new file mode 100644 -index 0000000..10563e7 ---- /dev/null -+++ b/3rdparty/cutlass/examples/45_dual_gemm/threadblock/dual_mma_base.h -@@ -0,0 +1,218 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/threadblock/mma_base.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class DualMmaBase { -+ public: -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Shape describing the overall GEMM computed from shared memory -+ /// by each warp. -+ using WarpGemm = typename Policy::Operator::Shape; -+ -+ /// Shape describing the number of warps filling the CTA -+ using WarpCount = GemmShape; -+ -+ /// Number of warp-level GEMM oeprations -+ static int const kWarpGemmIterations = -+ (WarpGemm::kK / Operator::Policy::MmaShape::kK); -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Tensor reference to the A operand -+ using TensorRefA = TensorRef; -+ -+ /// Tensor reference to the B operand -+ using TensorRefB = TensorRef; -+ -+ static_assert(kWarpGemmIterations > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ -+ static_assert((kWarpGemmIterations % 2) == 0, -+ "Inner loop iteration must be an even number."); -+ -+ // -+ // Nested structs -+ // -+ -+ /// Shared storage object needed by threadblock-scoped GEMM -+ class SharedStorage { -+ public: -+ // -+ // Type definitions -+ // -+ -+ /// Shape of the A matrix operand in shared memory -+ using ShapeA = MatrixShape; -+ -+ /// Shape of the B matrix operand in shared memory -+ using ShapeB = -+ MatrixShape; -+ -+ public: -+ // -+ // Data members -+ // -+ -+ /// Buffer for A operand -+ AlignedBuffer operand_A; -+ -+ /// Buffer for B operand -+ AlignedBuffer operand_B0; -+ AlignedBuffer operand_B1; -+ -+ public: -+ -+ // -+ // Methods -+ // -+ -+ /// Returns a layout object for the A matrix -+ CUTLASS_DEVICE -+ static typename Operator::LayoutA LayoutA() { -+ return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); -+ } -+ -+ /// Returns a layout object for the B matrix -+ CUTLASS_HOST_DEVICE -+ static typename Operator::LayoutB LayoutB() { -+ return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); -+ } -+ -+ /// Returns a TensorRef to the A operand -+ CUTLASS_HOST_DEVICE -+ TensorRefA operand_A_ref() { -+ return TensorRefA{operand_A.data(), LayoutA()}; -+ } -+ -+ /// Returns a TensorRef to the B operand -+ CUTLASS_HOST_DEVICE -+ TensorRefB operand_B0_ref() { -+ return TensorRefB{operand_B0.data(), LayoutB()}; -+ } -+ CUTLASS_HOST_DEVICE -+ TensorRefB operand_B1_ref() { -+ return TensorRefB{operand_B1.data(), LayoutB()}; -+ } -+ }; -+ -+ protected: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to load a warp-scoped tile of A operand from shared memory -+ typename Operator::IteratorA warp_tile_iterator_A_; -+ -+ /// Iterator to load a warp-scoped tile of B operand from shared memory -+ typename Operator::IteratorB warp_tile_iterator_B0_; -+ typename Operator::IteratorB warp_tile_iterator_B1_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ DualMmaBase( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), -+ warp_tile_iterator_B0_(shared_storage.operand_B0_ref(), lane_idx), -+ warp_tile_iterator_B1_(shared_storage.operand_B1_ref(), lane_idx) { -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/45_dual_gemm/threadblock/dual_mma_multistage.h b/3rdparty/cutlass/examples/45_dual_gemm/threadblock/dual_mma_multistage.h -new file mode 100644 -index 0000000..7843f2b ---- /dev/null -+++ b/3rdparty/cutlass/examples/45_dual_gemm/threadblock/dual_mma_multistage.h -@@ -0,0 +1,760 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/threadblock/mma_base.h" -+#include "dual_mma_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Cache operation for operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, -+ /// Used for partial specialization -+ typename Enable = bool> -+class DualMmaMultistage : -+ public DualMmaBase { -+public: -+ ///< Base class -+ using Base = DualMmaBase; -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ ///< Iterates over tiles of A operand in global memory -+ using IteratorA = IteratorA_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB = IteratorB_; -+ ///< Data type of accumulator matrix -+ using ElementC = ElementC_; -+ ///< Layout of accumulator matrix -+ using LayoutC = LayoutC_; -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of accumulator tile -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Minimum architecture is Sm80 to support cp.async -+ using ArchTag = arch::Sm80; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = Operator::kTransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = Operator::kTransformB; -+ -+ /// Internal structure exposed for introspection. -+ struct Detail { -+ -+ /// Number of cp.async instructions to load one stage of operand A -+ static int const AsyncCopyIterationsPerStageA = -+ IteratorA::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const AsyncCopyIterationsPerStageB = -+ IteratorB::ThreadMap::Iterations::kCount; -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Number of cp.async instructions to load on group of operand A -+ static int const kAccessesPerGroupA = -+ (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB = -+ (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ }; -+ -+ private: -+ -+ using WarpLoadedFragmentA = typename Operator::FragmentA; -+ using WarpLoadedFragmentB = typename Operator::FragmentB; -+ using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; -+ using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B0_; -+ SmemIteratorB smem_iterator_B1_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ DualMmaMultistage( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), -+ smem_iterator_B0_(shared_storage.operand_B0_ref(), thread_idx), -+ smem_iterator_B1_(shared_storage.operand_B1_ref(), thread_idx) -+ { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B0_.add_tile_offset( -+ {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ this->warp_tile_iterator_B1_.add_tile_offset( -+ {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B0, IteratorB &iterator_B1, -+ int group_start_A = 0, int group_start_B = 0) { -+ iterator_A.set_iteration_index(group_start_A * -+ IteratorA::kAccessesPerVector); -+ this->smem_iterator_A_.set_iteration_index(group_start_A); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { -+ if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_A.get(); -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, iterator_A.valid()); -+ } else { -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_A.valid()); -+ } -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ } -+ -+ iterator_B0.set_iteration_index(group_start_B * -+ IteratorB::kAccessesPerVector); -+ iterator_B1.set_iteration_index(group_start_B * -+ IteratorB::kAccessesPerVector); -+ this->smem_iterator_B0_.set_iteration_index(group_start_B); -+ this->smem_iterator_B1_.set_iteration_index(group_start_B); -+ -+ // Async Copy for operand B0 -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { -+ if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B0_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_B0.get(); -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, iterator_B0.valid()); -+ } else { -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_B0.valid()); -+ } -+ -+ ++iterator_B0; -+ } -+ ++this->smem_iterator_B0_; -+ } -+ } -+ // Async Copy for operand B1 -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { -+ if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B1_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_B1.get(); -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, iterator_B1.valid()); -+ } else { -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_B1.valid()); -+ } -+ -+ ++iterator_B1; -+ } -+ ++this->smem_iterator_B1_; -+ } -+ } -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations, -+ ///< destination accumulator tile -+ FragmentC &accum0, -+ FragmentC &accum1, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B0, -+ IteratorB iterator_B1, -+ ///< initial value of accumulator -+ FragmentC const &src_accum0, -+ FragmentC const &src_accum1 -+ ) { -+ -+ // -+ // Prologue -+ // -+ -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; -+ ++stage, --gemm_k_iterations) { -+ -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B0.clear_mask(gemm_k_iterations == 0); -+ iterator_B1.clear_mask(gemm_k_iterations == 0); -+ -+ iterator_A.set_iteration_index(0); -+ this->smem_iterator_A_.set_iteration_index(0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_A.get(), iterator_A.valid()); -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ -+ iterator_B0.set_iteration_index(0); -+ iterator_B1.set_iteration_index(0); -+ this->smem_iterator_B0_.set_iteration_index(0); -+ this->smem_iterator_B1_.set_iteration_index(0); -+ -+ // Async Copy for operand B0 -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B0_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_B0.get(), iterator_B0.valid()); -+ -+ ++iterator_B0; -+ } -+ -+ ++this->smem_iterator_B0_; -+ } -+ // Async Copy for operand B1 -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B1_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_B1.get(), iterator_B1.valid()); -+ -+ ++iterator_B1; -+ } -+ -+ ++this->smem_iterator_B1_; -+ } -+ -+ // Move to the next stage -+ iterator_A.add_tile_offset({0, 1}); -+ iterator_B0.add_tile_offset({1, 0}); -+ iterator_B1.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_B0_.add_tile_offset({1, 0}); -+ this->smem_iterator_B1_.add_tile_offset({1, 0}); -+ -+ // Defines the boundary of a stage of cp.async. -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // Perform accumulation in the 'd' output operand -+ accum0 = src_accum0; -+ accum1 = src_accum1; -+ -+ // -+ // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels -+ // so that all accumulator elements outside the GEMM footprint are zero. -+ // -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); -+ -+ typename IteratorA::AccessType zero_A; -+ zero_A.clear(); -+ -+ last_smem_iterator_A.set_iteration_index(0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { -+ -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ last_smem_iterator_A.get()); -+ -+ *dst_ptr = zero_A; -+ -+ ++last_smem_iterator_A; -+ } -+ -+ typename IteratorB::AccessType zero_B; -+ zero_B.clear(); -+ -+ /// Iterator to write threadblock-scoped tile of B0 operand to shared memory -+ SmemIteratorB last_smem_iterator_B0(this->smem_iterator_B0_); -+ last_smem_iterator_B0.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { -+ -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ last_smem_iterator_B0.get()); -+ -+ *dst_ptr = zero_B; -+ -+ ++last_smem_iterator_B0; -+ } -+ /// Iterator to write threadblock-scoped tile of B1 operand to shared memory -+ SmemIteratorB last_smem_iterator_B1(this->smem_iterator_B1_); -+ last_smem_iterator_B1.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { -+ -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ last_smem_iterator_B1.get()); -+ -+ *dst_ptr = zero_B; -+ -+ ++last_smem_iterator_B1; -+ } -+ } -+ -+ // Waits until stages up to the previous (kStages-2)th stage have committed. -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA warp_loaded_frag_A[2]; -+ WarpLoadedFragmentB warp_loaded_frag_B0[2]; -+ WarpLoadedFragmentB warp_loaded_frag_B1[2]; -+ WarpTransformedFragmentA warp_transformed_frag_A[2]; -+ WarpTransformedFragmentB warp_transformed_frag_B0[2]; -+ WarpTransformedFragmentB warp_transformed_frag_B1[2]; -+ -+ Operator warp_mma; -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_B0_.set_kgroup_index(0); -+ this->warp_tile_iterator_B1_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); -+ this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[0]); -+ this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B0_; -+ ++this->warp_tile_iterator_B1_; -+ -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B0.clear_mask(gemm_k_iterations == 0); -+ iterator_B1.clear_mask(gemm_k_iterations == 0); -+ -+ int smem_write_stage_idx = Base::kStages - 1; -+ int smem_read_stage_idx = 0; -+ -+ warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B0[0], -+ warp_loaded_frag_A[0], warp_loaded_frag_B0[0]); -+ warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B1[0], -+ warp_loaded_frag_A[0], warp_loaded_frag_B1[0]); -+ -+ // tf32x3 kernels use staging accumulation. warp_mma uses a temporary -+ // accumulator and this temporary accumulator is added to the final -+ // accumulator once in every mainloop iteration. -+ plus plus_accum; -+ -+ FragmentC tmp_accum0, tmp_accum1; -+ -+ if (platform::is_same::value -+ || platform::is_same::value) { -+ -+ tmp_accum0.clear(); -+ tmp_accum1.clear(); -+ } -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > (-Base::kStages + 1);) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; -+ ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. -+ -+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B0_; -+ ++this->warp_tile_iterator_B1_; -+ -+ if (warp_mma_k > 0) { -+ warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B0[warp_mma_k % 2], -+ warp_loaded_frag_A[warp_mma_k % 2], -+ warp_loaded_frag_B0[warp_mma_k % 2]); -+ warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B1[warp_mma_k % 2], -+ warp_loaded_frag_A[warp_mma_k % 2], -+ warp_loaded_frag_B1[warp_mma_k % 2]); -+ } -+ -+ if (platform::is_same::value -+ || platform::is_same::value) { -+ -+ warp_mma( -+ tmp_accum0, -+ warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B0[warp_mma_k % 2], -+ tmp_accum0 -+ ); -+ warp_mma( -+ tmp_accum1, -+ warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B1[warp_mma_k % 2], -+ tmp_accum1 -+ ); -+ -+ if (warp_mma_k == 0) { -+ accum0 = plus_accum(accum0, tmp_accum0); -+ accum1 = plus_accum(accum1, tmp_accum1); -+ tmp_accum0.clear(); -+ tmp_accum1.clear(); -+ } -+ } else { -+ warp_mma( -+ accum0, -+ warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B0[warp_mma_k % 2], -+ accum0 -+ ); -+ warp_mma( -+ accum1, -+ warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B1[warp_mma_k % 2], -+ accum1 -+ ); -+ } -+ -+ // Issue global->shared copies for the this stage -+ if (warp_mma_k < Base::kWarpGemmIterations - 1) { -+ int group_start_iteration_A, group_start_iteration_B; -+ -+ group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; -+ -+ copy_tiles_and_advance(iterator_A, iterator_B0, iterator_B1, group_start_iteration_A, -+ group_start_iteration_B); -+ } -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations) { -+ int group_start_iteration_A, group_start_iteration_B; -+ group_start_iteration_A = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB; -+ -+ copy_tiles_and_advance(iterator_A, iterator_B0, iterator_B1, group_start_iteration_A, -+ group_start_iteration_B); -+ -+ // Inserts a memory fence between stages of cp.async instructions. -+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until stages up to the previous (kStages-2)th stage have committed. -+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_A.add_tile_offset({0, 1}); -+ iterator_B0.add_tile_offset({1, 0}); -+ iterator_B1.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_B0_.add_tile_offset({1, 0}); -+ this->smem_iterator_B1_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); -+ this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B0_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations, -+ 0}); -+ this->warp_tile_iterator_B1_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ --gemm_k_iterations; -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B0.clear_mask(gemm_k_iterations == 0); -+ iterator_B1.clear_mask(gemm_k_iterations == 0); -+ } -+ -+ // Do any conversions feeding the first stage at the end of the loop so -+ // we can start right away on mma instructions -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations) { -+ warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B0[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); -+ warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B1[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); -+ } -+ } -+ -+ } -+ -+ if (platform::is_same::value -+ || platform::is_same::value) { -+ accum0 = plus_accum(accum0, tmp_accum0); -+ accum1 = plus_accum(accum1, tmp_accum1); -+ } -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ // commit and drain all pending and predicated cp.async pnz from the GEMM mainloop -+ cutlass::arch::cp_async_fence(); -+ cutlass::arch::cp_async_wait<0>(); -+ __syncthreads(); -+ } -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu b/3rdparty/cutlass/examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu -new file mode 100644 -index 0000000..9a26e89 ---- /dev/null -+++ b/3rdparty/cutlass/examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu -@@ -0,0 +1,672 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/** -+This example shows how to run depthwise 2d convolution kernels using functions and data structures -+provided by CUTLASS using SIMT instruction; -+ -+There are 3 types of implementations of depthwise 2d convoltion -+ 1. kAnalytic -+ Implicit gemm 2d convoltion algorithm. -+ 2. kOptimized -+ An optimized algorithm and supports arbitrary stride and dilation. -+ 3. kFixedStrideDilation -+ An optimized algorithm with fixed stride and dilation to reduce the runtime computation and do -+more optimizations. -+ -+In general, the perf of kFixedStrideDilation would be better than kOptimized. However, if the filter -+size, stride or dilation is large, it would encounter register spilling and may hurt the perf. If -+in this case, please use kOptimized. -+ -+For kOptimized and kFixedStrideDilation, in order to fully utilize GPU hardware resources and achieve -+better perf, when the output tensor size is large, splitk should be enabled to achieve better perf. -+ -+In this example, it demonstrates how to construct and run a FixedStrideDilation depthwise 2d -+convolution kernel. -+*/ -+ -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/conv/kernel/default_depthwise_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+#include "cutlass/conv/device/direct_convolution.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/convolution.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "helper.h" -+ -+// The code section below describes datatype for input, output tensors and computation between -+// elements -+using ElementAccumulator = cutlass::half_t; // Data type of accumulator -+using ElementComputeEpilogue = cutlass::half_t; // Data type of epilogue computation (alpha, beta) -+using ElementInputA = cutlass::half_t; // Data type of elements in input tensor -+using ElementInputB = cutlass::half_t; // Data type of elements in input tensor -+using ElementOutput = cutlass::half_t; // Data type of elements in output tensor -+ -+using LayoutInputA = cutlass::layout::TensorNHWC; -+using LayoutInputB = cutlass::layout::TensorNHWC; -+using LayoutOutput = cutlass::layout::TensorNHWC; -+ -+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM -+using MMAOp = cutlass::arch::OpClassSimt; -+ -+// This code section describes CUDA SM architecture number -+using SmArch = cutlass::arch::Sm60; -+ -+// This code section describes the groups a thread block will compute -+constexpr int groups_per_cta = 64; -+ -+// This code section describes the output tile a thread block will compute -+using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, groups_per_cta>; -+ -+// This code section describes the filter shape -+using FilterShape = cutlass::MatrixShape<3, 3>; -+ -+// Threadblock tile shape -+using ThreadblockShape = -+ cutlass::gemm::GemmShape; -+ -+// This code section describes tile size a warp will computes -+// WarpShape::kM = P * Q the warps would process -+// WarpShape::kN = groups_per_cta that the warps would process -+// WarpShape::kK = filter_size that the warps would process -+using WarpShape = cutlass::gemm::GemmShape<16, groups_per_cta, FilterShape::kCount>; -+ -+// This code section describes the size of MMA op -+using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ -+// This code section describes how threadblocks are scheduled on GPU -+using SwizzleThreadBlock = -+ cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle< -+ 1, -+ ThreadBlockOutputShape::kN, -+ ThreadBlockOutputShape::kH, -+ ThreadBlockOutputShape::kW>; -+ -+// Number of pipelines you want to use -+constexpr int NumStages = 4; -+ -+// This code section describe iterator algorithm selected is kFixedStrideDilation -+static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = -+ cutlass::conv::IteratorAlgorithm::kFixedStrideDilation; -+using StrideShape = cutlass::MatrixShape<1, 1>; -+using DilationShape = cutlass::MatrixShape<1, 1>; -+ -+constexpr int kEpilogueElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+// This code section describes the epilogue part of the kernel, we use default value -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // Data type of output matrix. -+ kEpilogueElementsPerAccess, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. -+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue, // Data type for alpha/beta in linear combination -+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>; // Epilogue scaling operation. -+ -+using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ ThreadBlockOutputShape, -+ FilterShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ IteratorAlgorithm, -+ cutlass::conv::StrideSupport::kFixed, -+ StrideShape, -+ DilationShape>::Kernel; -+ -+using Direct2dConv = cutlass::conv::device::DirectConvolution; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ bool help; -+ cutlass::Tensor4DCoord input_size; -+ cutlass::Tensor4DCoord filter_size; -+ cutlass::Tensor4DCoord padding; -+ cutlass::MatrixCoord conv_stride; -+ cutlass::MatrixCoord dilation; -+ int groups; -+ int splitk; -+ bool reference_check; -+ bool measure_performance; -+ int iterations; -+ bool save_workspace; -+ ElementComputeEpilogue alpha; -+ ElementComputeEpilogue beta; -+ std::string tag; -+ -+ Options() -+ : help(false), -+ input_size(1, 128, 128, 32), -+ filter_size(32, 3, 3, 1), -+ groups(32), -+ padding(1, 1, 1, 1), -+ conv_stride(1, 1), -+ dilation(1, 1), -+ reference_check(false), -+ measure_performance(true), -+ iterations(20), -+ save_workspace(false), -+ alpha(1), -+ beta(0), -+ splitk(1) {} -+ -+ // Verify the problem size is compatible with the CUTLASS Convolution implementation. -+ bool valid() { -+ // -+ // CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently, -+ // all pointers, strides, and tensor extents must be divisible by 8 elements. -+ // -+ int const kAlignment = 8; -+ -+ if ((input_size.c() % kAlignment) || (filter_size.n() % kAlignment)) { -+ // misaligned tensors -+ return false; -+ } -+ -+ // depthwise conv -+ if (groups != input_size.c()) { -+ return false; -+ } -+ -+ if (filter_size.n() != groups) { -+ return false; -+ } -+ -+ // Invalid padding -+ if ((padding.h() != filter_size.h() / 2) || (padding.w() != filter_size.w() / 2)) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Updates input and filter sizes -+ void update(cutlass::Tensor4DCoord input_size, cutlass::Tensor4DCoord filter_size) { -+ this->input_size = input_size; -+ this->filter_size = filter_size; -+ -+ padding.n() = filter_size.h() / 2; -+ padding.h() = filter_size.h() / 2; -+ padding.w() = filter_size.w() / 2; -+ padding.c() = filter_size.w() / 2; -+ } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("ref-check")) { -+ reference_check = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("perf-check")) { -+ measure_performance = true; -+ } -+ -+ if (cmd.check_cmd_line_flag("save-workspace")) { -+ save_workspace = true; -+ } -+ -+ cmd.get_cmd_line_argument("n", input_size.n()); -+ cmd.get_cmd_line_argument("h", input_size.h()); -+ cmd.get_cmd_line_argument("w", input_size.w()); -+ cmd.get_cmd_line_argument("c", input_size.c()); -+ -+ cmd.get_cmd_line_argument("k", filter_size.n()); -+ cmd.get_cmd_line_argument("r", filter_size.h()); -+ cmd.get_cmd_line_argument("s", filter_size.w()); -+ -+ cmd.get_cmd_line_argument("g", groups); -+ -+ filter_size.c() = 1; -+ filter_size.n() = input_size.c(); -+ -+ cmd.get_cmd_line_argument("alpha", alpha); -+ cmd.get_cmd_line_argument("beta", beta); -+ cmd.get_cmd_line_argument("splitk", splitk); -+ -+ cmd.get_cmd_line_argument("iterations", iterations); -+ cmd.get_cmd_line_argument("tag", tag); -+ -+ int32_t padding_h = filter_size.h() / 2; -+ int32_t padding_w = filter_size.w() / 2; -+ padding = {padding_h, padding_h, padding_w, padding_w}; -+ } -+ -+ /// Prints the usage statement. -+ std::ostream &print_usage(std::ostream &out) const { -+ out << "41_depthwise_gemm_fprop example\n\n" -+ << " This example uses Ampere's Tensor Core operators on F16 data types to compute\n" -+ << " forward convolution on tensors of layout NHWC.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --n= Input tensor extent N\n" -+ << " --h= Input tensor extent H\n" -+ << " --w= Input tensor extent W\n" -+ << " --c= Input tensor extent C\n" -+ << " --k= Filter extent K\n" -+ << " --r= Filter extent R\n" -+ << " --s= Filter extent S\n\n" -+ << " --g= Groups\n\n" -+ << " --alpha= Epilogue scalar alpha\n" -+ << " --beta= Epilogue scalar beta\n\n" -+ << " --splitk= Enable splitK\n\n" -+ << " --ref-check If set (true), reference check on the host is computed\n" -+ << " --perf-check If set (true), performance is measured.\n" -+ << " --iterations= Number of profiling iterations to perform.\n" -+ << " --save-workspace If set, workspace is written to a text file.\n" -+ << " --tag= String to replicate across the first column in the results " -+ "table\n"; -+ -+ out << "\n\nExamples:\n\n" -+ << "$ ./examples/45_depthwise_simt_conv2dfprop/45_depthwise_simt_conv2dfprop --n=32 " -+ "--h=224 --w=224 --c=128 --k=128 --g=128 --r=3 --s=3\n\n" -+ << "$ ./examples/45_depthwise_simt_conv2dfprop/45_depthwise_simt_conv2dfprop --n=1 " -+ "--h=224 --w=224 --c=32 --k=32 --g=32 --r=3 --s=3 --splitk=10 --ref-check\n\n"; -+ -+ return out; -+ } -+ -+ /// Computes the output tensor size (NPQK) -+ cutlass::Tensor4DCoord output_size() const { -+ return cutlass::Tensor4DCoord( -+ input_size.n(), -+ (input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1, -+ (input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1, -+ filter_size.n()); -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const { -+ // Number of multiply-adds = NPQK * CRS -+ int64_t fmas = -+ output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c()); -+ -+ // Two flops per multiply-add -+ return 2.0 * double(fmas) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct Result { -+ double runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cutlass::Status reference_check; -+ cudaError_t error; -+ -+ Result() -+ : runtime_ms(0), -+ gflops(0), -+ status(cutlass::Status::kSuccess), -+ reference_check(cutlass::Status::kInvalid), -+ error(cudaSuccess) {} -+ -+ static std::ostream &print_header(std::ostream &out, Options const &options) { -+ if (!options.tag.empty()) { -+ out << "Name,"; -+ } -+ -+ out << "Layer,N,H,W,C,K,R,S,G,stride_h,stride_w,dilation_h,dilation_w,splitK,Runtime,GFLOPs"; -+ -+ return out; -+ } -+ -+ std::ostream &print(std::ostream &out, int idx, Options const &options) { -+ if (!options.tag.empty()) { -+ out << options.tag << ","; -+ } -+ -+ cutlass::Tensor4DCoord output_size = options.output_size(); -+ out << "conv_" << idx << "," << options.input_size.n() << "," << options.input_size.h() << "," -+ << options.input_size.w() << "," << options.input_size.c() << "," -+ -+ << options.filter_size.n() << "," << options.filter_size.h() << "," -+ << options.filter_size.w() << "," -+ -+ << options.groups << "," << options.conv_stride.row() << "," << options.conv_stride.column() -+ << "," -+ -+ << options.dilation.row() << "," << options.dilation.column() << "," -+ -+ << options.splitk << "," -+ -+ << runtime_ms << "," << gflops; -+ -+ return out; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Runs one testcase -+Result profile_convolution(Options const &options) { -+ Result result; -+ -+ // -+ // Allocate host-device tensors using the CUTLASS Utilities. -+ // -+ -+ cutlass::HostTensor tensor_a(options.input_size); -+ cutlass::HostTensor tensor_b(options.filter_size); -+ cutlass::HostTensor tensor_b_transpose(options.filter_size); -+ cutlass::HostTensor tensor_c(options.output_size()); -+ cutlass::HostTensor tensor_d(options.output_size()); -+ cutlass::HostTensor tensor_ref_d(options.output_size()); -+ -+ // -+ // Initialize tensors -+ // -+ -+ // Fill tensor A on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_a.host_view(), 1, ElementInputA(5), ElementInputA(-6), 0); -+ -+ // Fill tensor B on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_b.host_view(), 1, ElementInputB(3), ElementInputB(-6), 0); -+ -+ // Fill tensor C on host with uniform-distribution random data -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_c.host_view(), 1, ElementOutput(5), ElementOutput(-6), 0); -+ -+ // Fill tensor D on host with zeros -+ cutlass::reference::host::TensorFill(tensor_d.host_view()); -+ -+ // Fill tensor D for reference on host with zeros -+ cutlass::reference::host::TensorFill(tensor_ref_d.host_view()); -+ -+ // Copy data from host to GPU -+ tensor_a.sync_device(); -+ tensor_b.sync_device(); -+ tensor_b_transpose.sync_device(); -+ tensor_c.sync_device(); -+ tensor_d.sync_device(); -+ tensor_ref_d.sync_device(); -+ -+ // -+ // Define arguments for CUTLASS Convolution -+ // -+ -+ cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; -+ -+ // Split P*Q into multiple CTA -+ int split_k_slices = options.splitk; -+ -+ // Construct Conv2dProblemSize with user defined output size -+ cutlass::conv::Conv2dProblemSize problem_size(options.input_size, -+ options.filter_size, -+ options.padding, -+ options.conv_stride, -+ options.dilation, -+ options.output_size(), -+ mode, -+ split_k_slices, -+ options.groups); -+ -+ // Construct Direc2dConv::Argument structure with conv2d -+ // problem size, data pointers, and epilogue values -+ typename Direct2dConv::Arguments arguments{problem_size, -+ tensor_a.device_ref(), -+ tensor_b.device_ref(), -+ tensor_c.device_ref(), -+ tensor_d.device_ref(), -+ {options.alpha, options.beta}, -+ tensor_b_transpose.device_ref()}; -+ -+ // -+ // Initialize CUTLASS Convolution -+ // -+ -+ Direct2dConv implicit_gemm_op; -+ -+ size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ result.status = implicit_gemm_op.can_implement(arguments); -+ CUTLASS_CHECK(result.status); -+ -+ result.status = implicit_gemm_op.initialize(arguments, workspace.get()); -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Launch initialized CUTLASS kernel -+ // -+ result.status = implicit_gemm_op(); -+ -+ CUTLASS_CHECK(result.status); -+ -+ // -+ // Optional reference check -+ // -+ -+ if (options.reference_check) { -+ std::cout << "Verification on host...\n"; -+ -+ // Compute with reference implementation -+ cutlass::reference::host::Conv2dFprop< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementComputeEpilogue, -+ ElementAccumulator, -+ cutlass::NumericConverter >(problem_size, -+ tensor_a.host_ref(), -+ tensor_b.host_ref(), -+ tensor_c.host_ref(), -+ tensor_ref_d.host_ref(), -+ options.alpha, -+ options.beta); -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ tensor_d.sync_host(); -+ -+ bool passed = -+ cutlass::reference::host::TensorEquals(tensor_d.host_view(), tensor_ref_d.host_view()); -+ -+ if (!passed) { -+ result.reference_check = cutlass::Status::kErrorInternal; -+ std::cout << "ERROR - results miscompared.\n"; -+ } else { -+ result.reference_check = cutlass::Status::kSuccess; -+ std::cout << "Passed.\n"; -+ } -+ } else { -+ result.reference_check = cutlass::Status::kInvalid; -+ } -+ -+ if (options.save_workspace) { -+ std::stringstream ss; -+ -+ ss << "45_depthwise_simt_conv2dfprop" << options.input_size.n() << "x" << options.input_size.h() -+ << "x" << options.input_size.w() << "x" << options.input_size.c() << "_" -+ << options.filter_size.n() << "x" << options.filter_size.h() << "x" -+ << options.filter_size.w() << "x" << options.filter_size.c() << ".dat"; -+ -+ std::ofstream output_workspace(ss.str()); -+ -+ output_workspace << "Input = \n" -+ << tensor_a.host_view() << "\n\n" -+ << "Filters = \n" -+ << tensor_b.host_view() << "\n\n"; -+ -+ if (options.reference_check) { -+ output_workspace << "Reference = \n" << tensor_ref_d.host_view() << "\n\n"; -+ } -+ -+ output_workspace << "Computed = \n" << tensor_d.host_view() << std::endl; -+ -+ std::cout << "Results written to '" << ss.str() << "'." << std::endl; -+ } -+ -+ // -+ // Performance measurement -+ // -+ -+ if (options.measure_performance) { -+ cudaEvent_t events[2]; -+ -+ for (auto &event : events) { -+ result.error = cudaEventCreate(&event); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ } -+ -+ // Record an event at the start of a series of convolution operations. -+ result.error = cudaEventRecord(events[0]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Launch a sequence of implicit GEMM operations on the device -+ for (int iteration = 0; iteration < options.iterations; ++iteration) { -+ result.status = implicit_gemm_op(); -+ CUTLASS_CHECK(result.status); -+ } -+ -+ // Record an event when the convolutions have been launched. -+ result.error = cudaEventRecord(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Wait for work on the device to complete. -+ result.error = cudaEventSynchronize(events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) -+ << std::endl; -+ return result; -+ } -+ -+ // Measure elapsed runtime -+ float runtime_ms = 0; -+ result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); -+ if (result.error != cudaSuccess) { -+ std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; -+ return result; -+ } -+ -+ // Print average runtime and GFLOPs. -+ result.runtime_ms = double(runtime_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.runtime_ms / 1000.0); -+ -+ // Cleanup -+ for (auto event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ } -+ -+ return result; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ bool notSupported = false; -+ -+ cudaDeviceProp props; -+ CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); -+ -+ if (!(props.major >= 6)) { -+ std::cerr << "Run on a machine with compute capability at least 60." << std::endl; -+ notSupported = true; -+ } -+ -+ if (notSupported) { -+ return 0; -+ } -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ // Execute one problem size -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ Result result = profile_convolution(options); -+ -+ Result::print_header(std::cout, options) << std::endl; -+ result.print(std::cout, 1, options) << std::endl; -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk.cu b/3rdparty/cutlass/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk.cu -new file mode 100644 -index 0000000..12739a0 ---- /dev/null -+++ b/3rdparty/cutlass/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk.cu -@@ -0,0 +1,592 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*************************************************************************************************** -+ Example contrasting the Stream-K parallel decomposition for GEMM threadblocks versus the -+ "classic data-parallel" and "Split-K" decompositions. -+ -+ For more details regarding the Stream-K method, see "Stream-K: Work-centric Parallel Decomposition -+ for Dense Matrix-Matrix Multiplication on the GPU" (https://arxiv.org/abs/2301.03598) -+ -+ Requires NVIDIA Ampere or newer device (SM80+). -+ -+ - To lock persistence mode, power (400W), clocks (1005MHz) for evaluation (assumes device 0 and A100) -+ -+ cutlass$ sudo nvidia-smi -pm 1 -i 0 -+ -+ cutlass$ sudo nvidia-smi -i 0 -pl 400 -+ -+ cutlass$ sudo nvidia-smi -i 0 -lgc 1005 -+ -+ - Build and run: -+ -+ cutlass$ mkdir build -+ -+ cutlass$ cd build -+ -+ cutlass/build$ cmake .. -DCUTLASS_NVCC_ARCHS=80 -+ -+ cutlass/build$ make 47_ampere_gemm_universal_streamk -+ -+ cutlass/build$ ./examples/47_ampere_gemm_universal_streamk/47_ampere_gemm_universal_streamk -+ -+ 10000 timing iterations of 2048 x 2048 x 2048 matrix-matrix multiply -+ -+ Basic data-parallel GEMM -+ Disposition: Passed -+ Avg runtime: 0.112633 ms -+ GFLOPs: 152530 -+ -+ StreamK GEMM with default load-balancing -+ Disposition: Passed -+ Avg runtime: 0.0941929 ms -+ GFLOPs: 182390 -+ Speedup vs Basic-DP: 1.196 -+ -+ StreamK emulating basic data-parallel GEMM -+ Disposition: Passed -+ Avg runtime: 0.113119 ms -+ GFLOPs: 151875 -+ Speedup vs Basic-DP: 0.996 -+ -+ Basic split-K GEMM with tile-splitting factor 2 -+ Disposition: Passed -+ Avg runtime: 0.104772 ms -+ GFLOPs: 163973 -+ -+ StreamK emulating Split-K GEMM with tile-splitting factor 2 -+ Disposition: Passed -+ Avg runtime: 0.105379 ms -+ GFLOPs: 163029 -+ Speedup vs Basic-SplitK: 0.994 -+ -+ **************************************************************************************************/ -+ -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_universal.h" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "helper.h" -+ -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// GEMM kernel configurations (cutlass_tensorop_h16816gemm_128x128_32x4_nn_align8) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// A matrix configuration -+using ElementA = cutlass::half_t; // Element type for A matrix operand -+using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand -+constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) -+ -+// B matrix configuration -+using ElementB = cutlass::half_t; // Element type for B matrix operand -+using LayoutB = cutlass::layout::RowMajor; // Layout type for B matrix operand -+constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) -+ -+// C/D matrix configuration -+using ElementC = cutlass::half_t; // Element type for C and D matrix operands -+using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands -+constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C/D matrices in units of elements (up to 16 bytes) -+ -+// Multiply-accumulate blocking/pipelining details -+using ElementAccumulator = cutlass::half_t; // Element type for internal accumulation -+using ArchTag = cutlass::arch::Sm80; // Tag indicating the minimum SM that supports the intended feature -+using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag -+using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock-level tile size (concept: GemmShape) -+using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; // Warp-level tile size (concept: GemmShape) -+using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // Instruction-level tile size (concept: GemmShape) -+constexpr int NumStages = 4; // Number of global->shared pipeline stages used in the GEMM mainloop -+ -+// Epilogue output operator -+using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementC, // Element type for C and D matrix operands -+ AlignmentC, // Memory access granularity of C and D matrix in units of elements -+ ElementAccumulator, // Element type from internal accumaccumulation -+ ElementAccumulator>; // Data type used to compute linear combination -+ -+// Reference device GEMM implementation type -+using DeviceGemmReference = cutlass::reference::device::Gemm< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ ElementAccumulator>; -+ -+// Classic data-parallel device GEMM implementation type -+using DeviceGemmBasic = cutlass::gemm::device::GemmUniversal< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ NumStages, -+ AlignmentA, -+ AlignmentB>; -+ -+// StreamK device GEMM implementation type -+using DeviceGemmStreamK = cutlass::gemm::device::GemmUniversal< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, // <-- Only difference -+ NumStages, -+ AlignmentA, -+ AlignmentB>; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Testbed utility types -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Result structure -+struct Result -+{ -+ double avg_runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cudaError_t error; -+ bool passed; -+ -+ Result( -+ double avg_runtime_ms = 0, -+ double gflops = 0, -+ cutlass::Status status = cutlass::Status::kSuccess, -+ cudaError_t error = cudaSuccess) -+ : -+ avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(true) -+ {} -+ -+}; -+ -+ -+/// Command line options parsing -+struct Options -+{ -+ std::string command_name; -+ bool help; -+ cutlass::gemm::GemmCoord problem_size; -+ float alpha; -+ float beta; -+ int split_k_factor; -+ int avail_sms; -+ bool reference_check; -+ int iterations; -+ -+ cutlass::HostTensor tensor_a; -+ cutlass::HostTensor tensor_b; -+ cutlass::HostTensor tensor_c; -+ cutlass::HostTensor tensor_d; -+ cutlass::HostTensor tensor_ref_d; -+ -+ Options(std::string command_name) : -+ command_name(command_name), -+ help(false), -+ problem_size({2048, 2048, 2048}), -+ alpha(1.0f), -+ beta(0.0f), -+ split_k_factor(1), -+ avail_sms(-1), // Number of device SMs to use is unlimited -+ reference_check(true), -+ iterations(10000) -+ {} -+ -+ bool valid() const -+ { -+ return true; -+ } -+ -+ void parse(int argc, char const **args) -+ { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ cmd.get_cmd_line_argument("m", problem_size.m()); -+ cmd.get_cmd_line_argument("n", problem_size.n()); -+ cmd.get_cmd_line_argument("k", problem_size.k()); -+ cmd.get_cmd_line_argument("alpha", alpha); -+ cmd.get_cmd_line_argument("beta", beta); -+ cmd.get_cmd_line_argument("split", split_k_factor); -+ cmd.get_cmd_line_argument("iterations", iterations); -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const -+ { -+ out -+ << "Performs a GEMM computation.\n" -+ << "\n" -+ << "Options:\n" -+ << "\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --m= GEMM M dimension\n" -+ << " --n= GEMM N dimension\n" -+ << " --k= GEMM K dimension\n" -+ << " --alpha= Epilogue scalar alpha\n" -+ << " --beta= Epilogue scalar beta\n\n" -+ << " --split= Split-K factor to emulate\n\n" -+ << " --iterations= Number of profiling iterations to perform.\n\n"; -+ -+ out -+ << "\n\nExamples:\n\n" -+ << "$ " << command_name << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; -+ -+ return out; -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const -+ { -+ // Two flops per multiply-add -+ return 2.0 * double(problem_size.product()) / double(1.0e9) / runtime_s; -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// GEMM evaluation -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Populates a DeviceGemmBasic::Arguments structure from the given commandline options -+typename DeviceGemmBasic::Arguments args_from_options( -+ const DeviceGemmBasic &device_gemm, -+ const Options &options, -+ cutlass::HostTensor &tensor_a, -+ cutlass::HostTensor &tensor_b, -+ cutlass::HostTensor &tensor_c, -+ cutlass::HostTensor &tensor_d) -+{ -+ return typename DeviceGemmBasic::Arguments( -+ cutlass::gemm::GemmUniversalMode::kGemm, // universal mode -+ options.problem_size, // problem_size -+ options.split_k_factor, // batch count / splitk slices -+ { // epilogue parameters -+ ElementAccumulator(options.alpha), -+ ElementAccumulator(options.beta) -+ }, -+ tensor_a.device_data(), // ptr_A -+ tensor_b.device_data(), // ptr_B -+ tensor_c.device_data(), // ptr_C -+ tensor_d.device_data(), // ptr_D -+ options.problem_size.mk().product(), // batch_stride_A -+ options.problem_size.nk().product(), // batch_stride_B -+ options.problem_size.mn().product(), // batch_stride_C -+ options.problem_size.mn().product(), // batch_stride_D -+ tensor_a.layout().stride(0), // stride_a -+ tensor_b.layout().stride(0), // stride_b -+ tensor_c.layout().stride(0), // stride_c -+ tensor_d.layout().stride(0)); // stride_d -+} -+ -+/// Populates a DeviceGemmStreamK::Arguments structure from the given commandline options -+typename DeviceGemmStreamK::Arguments args_from_options( -+ const DeviceGemmStreamK &device_gemm, -+ const Options &options, -+ cutlass::HostTensor &tensor_a, -+ cutlass::HostTensor &tensor_b, -+ cutlass::HostTensor &tensor_c, -+ cutlass::HostTensor &tensor_d) -+{ -+ return typename DeviceGemmStreamK::Arguments( -+ cutlass::gemm::GemmUniversalMode::kGemm, // universal mode -+ options.problem_size, // problem_size -+ options.split_k_factor, // batch count / splitk slices -+ { // epilogue parameters -+ ElementAccumulator(options.alpha), -+ ElementAccumulator(options.beta) -+ }, -+ tensor_a.device_data(), // ptr_A -+ tensor_b.device_data(), // ptr_B -+ tensor_c.device_data(), // ptr_C -+ tensor_d.device_data(), // ptr_D -+ options.problem_size.mk().product(), // batch_stride_A -+ options.problem_size.nk().product(), // batch_stride_B -+ options.problem_size.mn().product(), // batch_stride_C -+ options.problem_size.mn().product(), // batch_stride_D -+ tensor_a.layout().stride(0), // stride_a -+ tensor_b.layout().stride(0), // stride_b -+ tensor_c.layout().stride(0), // stride_c -+ tensor_d.layout().stride(0), // stride_d -+ options.avail_sms); // avail_sms -+} -+ -+ -+/// Execute a given example GEMM computation -+template -+Result run(std::string description, Options &options) -+{ -+ // Display test description -+ std::cout << std::endl << description << std::endl; -+ -+ // Zero-initialize test output matrix D -+ cutlass::reference::host::TensorFill(options.tensor_d.host_view()); -+ options.tensor_d.sync_device(); -+ -+ // Instantiate CUTLASS kernel depending on templates -+ DeviceGemmT device_gemm; -+ -+ // Create a structure of gemm kernel arguments suitable for invoking an instance of DeviceGemmT -+ auto arguments = args_from_options(device_gemm, options, options.tensor_a, options.tensor_b, options.tensor_c, options.tensor_d); -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size = DeviceGemmT::get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ // Check the problem size is supported or not -+ CUTLASS_CHECK(device_gemm.can_implement(arguments)); -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ CUTLASS_CHECK(device_gemm.initialize(arguments, workspace.get())); -+ -+ // Correctness / Warmup iteration -+ CUTLASS_CHECK(device_gemm()); -+ -+ // Copy output data from CUTLASS and reference kernel to host for comparison -+ options.tensor_d.sync_host(); -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ Result result; -+ result.passed = cutlass::reference::host::TensorEquals( -+ options.tensor_d.host_view(), -+ options.tensor_ref_d.host_view()); -+ -+ std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; -+ -+ // Run profiling loop -+ if (options.iterations > 0) -+ { -+ GpuTimer timer; -+ timer.start(); -+ for (int iter = 0; iter < options.iterations; ++iter) { -+ CUTLASS_CHECK(device_gemm()); -+ } -+ timer.stop(); -+ -+ // Compute average runtime and GFLOPs. -+ float elapsed_ms = timer.elapsed_millis(); -+ result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); -+ -+ std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; -+ std::cout << " GFLOPs: " << result.gflops << std::endl; -+ } -+ -+ if (!result.passed) { -+ exit(-1); -+ } -+ -+ return result; -+} -+ -+ -+/// Program entrypoint -+int main(int argc, const char **argv) -+{ -+ // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples. -+ if (!(__CUDACC_VER_MAJOR__ >= 11)) { -+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; -+ -+ // Returning zero so this test passes on older Toolkits. Its actions are no-op. -+ return 0; -+ } -+ -+ // Current device must must have compute capability at least 80 -+ cudaDeviceProp props; -+ int current_device_id; -+ CUDA_CHECK(cudaGetDevice(¤t_device_id)); -+ CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); -+ if (!((props.major * 10 + props.minor) >= 80)) -+ { -+ std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80." -+ << std::endl; -+ -+ // Returning zero so this test passes on older Toolkits. Its actions are no-op. -+ return 0; -+ } -+ -+ // Parse commandline options -+ Options options("ampere_streamk_gemm"); -+ options.parse(argc, argv); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ std::cout << -+ options.iterations << " timing iterations of " << -+ options.problem_size.m() << " x " << -+ options.problem_size.n() << " x " << -+ options.problem_size.k() << " matrix-matrix multiply" << std::endl; -+ -+ if (!options.valid()) { -+ std::cerr << "Invalid problem." << std::endl; -+ return -1; -+ } -+ -+ -+ // -+ // Initialize GEMM datasets -+ // -+ -+ // Initialize tensors using CUTLASS helper functions -+ options.tensor_a.resize(options.problem_size.mk()); // <- Create matrix A with dimensions M x K -+ options.tensor_b.resize(options.problem_size.kn()); // <- Create matrix B with dimensions K x N -+ options.tensor_c.resize(options.problem_size.mn()); // <- Create matrix C with dimensions M x N -+ options.tensor_d.resize(options.problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from CUTLASS kernel -+ options.tensor_ref_d.resize(options.problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from reference kernel -+ -+ // Fill matrix A on host with uniform-random data [4, -4] -+ cutlass::reference::host::TensorFillRandomUniform( -+ options.tensor_a.host_view(), -+ 1, -+ ElementA(2), -+ ElementA(-2), -+ 0); -+ -+ // Fill matrix B on host with uniform-random data [4, -4] -+ cutlass::reference::host::TensorFillRandomUniform( -+ options.tensor_b.host_view(), -+ 1, -+ ElementB(2), -+ ElementB(-2), -+ 0); -+ -+ // Fill matrix C on host with uniform-random data [4, -4] -+ cutlass::reference::host::TensorFillRandomUniform( -+ options.tensor_c.host_view(), -+ 1, -+ ElementC(2), -+ ElementC(-2), -+ 0); -+ -+ -+ // -+ // Compute reference output -+ // -+ -+ // Copy data from host to GPU -+ options.tensor_a.sync_device(); -+ options.tensor_b.sync_device(); -+ options.tensor_c.sync_device(); -+ -+ // Zero-initialize reference output matrix D -+ cutlass::reference::host::TensorFill(options.tensor_ref_d.host_view()); -+ options.tensor_ref_d.sync_device(); -+ -+ // Create instantiation for device reference gemm kernel -+ DeviceGemmReference gemm_reference; -+ -+ // Launch device reference gemm kernel -+ gemm_reference( -+ options.problem_size, -+ ElementAccumulator(options.alpha), -+ options.tensor_a.device_ref(), -+ options.tensor_b.device_ref(), -+ ElementAccumulator(options.beta), -+ options.tensor_c.device_ref(), -+ options.tensor_ref_d.device_ref()); -+ -+ // Wait for kernels to finish -+ CUDA_CHECK(cudaDeviceSynchronize()); -+ -+ // Copy output data from reference kernel to host for comparison -+ options.tensor_ref_d.sync_host(); -+ -+ -+ // -+ // Evaluate CUTLASS kernels -+ // -+ -+ // Test default operation -+ if (options.split_k_factor == 1) -+ { -+ // Compare basic data-parallel version versus StreamK version using default load-balancing heuristics -+ Result basic_dp = run("Basic data-parallel GEMM", options); -+ Result streamk_default = run("StreamK GEMM with default load-balancing", options); -+ -+ printf(" Speedup vs Basic-DP: %.3f\n", (basic_dp.avg_runtime_ms / streamk_default.avg_runtime_ms)); -+ -+ // Show that StreamK can emulate basic data-parallel GEMM when we set the number of SMs to load-balance across = 1 -+ options.avail_sms = 1; // Set loadbalancing width to 1 SM (no load balancing) -+ Result streamk_dp = run("StreamK emulating basic data-parallel GEMM", options); -+ options.avail_sms = -1; // Reset loadbalancing width to unspecified SMs (i.e., the number of device SMs) -+ -+ printf(" Speedup vs Basic-DP: %.3f\n", (basic_dp.avg_runtime_ms / streamk_dp.avg_runtime_ms)); -+ -+ options.split_k_factor++; // Increment splitting factor for next evaluation -+ -+ } -+ -+ // Show that StreamK can emulate "Split-K" with a tile-splitting factor -+ Result basic_splitk = run( -+ std::string("Basic split-K GEMM with tile-splitting factor ") + std::to_string(options.split_k_factor), -+ options); -+ -+ Result streamk_splitk = run( -+ std::string("StreamK emulating Split-K GEMM with tile-splitting factor ") + std::to_string(options.split_k_factor), -+ options); -+ -+ printf(" Speedup vs Basic-SplitK: %.3f\n", (basic_splitk.avg_runtime_ms / streamk_splitk.avg_runtime_ms)); -+ -+ return 0; -+} -diff --git a/3rdparty/cutlass/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu b/3rdparty/cutlass/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu -new file mode 100644 -index 0000000..599d1d5 ---- /dev/null -+++ b/3rdparty/cutlass/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu -@@ -0,0 +1,463 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Simple Hopper GEMM example using CUTLASS 3.0 APIs for NVIDIA Hopper architecture -+ -+ This example demonstrate a simple way to instantiate and run a TF32 GEMM using the new CUTLASS 3.0 -+ APIs on NVIDIA Hopper architecture. New features that will be showcased in this example are as follows: -+ -+ 1. NVIDIA Hopper architecture introduces a new series of tensor core instructions (GMMA) -+ which are more efficient than the Ampere tensor core instructions. -+ -+ 2. NVIDIA Hopper architecture includes new Tensor Memory Accelerator (TMA) unit to transfer large -+ blocks of data efficiently between global memory and shared memory. TMA also supports asynchronous -+ copies between thread blocks in a cluster. Another advantage is that TMA can load in FP32 data and -+ convert them implicitly to TF32. -+ -+ 3. This example uses the Warp Specialized kernel design (see /media/docs/efficient_gemm.md for details). -+ -+ Examples: -+ -+ $ ./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm --m=2048 --n=2048 --k=2048 -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+ -+#include "cute/tensor.hpp" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/epilogue/collective/default_epilogue.hpp" -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/gemm/dispatch_policy.hpp" -+#include "cutlass/gemm/collective/collective_builder.hpp" -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "cutlass/gemm/kernel/gemm_universal.hpp" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/packed_stride.hpp" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/device/gemm.h" -+#include "cutlass/util/reference/device/tensor_compare.h" -+#include "cutlass/util/reference/device/tensor_fill.h" -+ -+#include "helper.h" -+ -+using namespace cute; -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// GEMM kernel configurations -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// A matrix configuration -+using ElementA = float; // Element type for A matrix operand -+using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand -+constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) -+ -+// B matrix configuration -+using ElementB = float; // Element type for B matrix operand -+using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand -+constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) -+ -+// C/D matrix configuration -+using ElementC = float; // Element type for C and D matrix operands -+using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands -+ -+// Core kernel configurations -+using ElementAccumulator = float; // Element type for internal accumulation -+using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature -+using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag -+using TilesShape = Shape<_128,_128,_32>; // Threadblock-level tile size -+using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster -+using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size -+using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default setting in the Collective Builder -+ -+using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ ArchTag, OperatorClass, -+ ElementA, LayoutA, AlignmentA, -+ ElementB, LayoutB, AlignmentB, -+ ElementAccumulator, -+ TilesShape, ClusterShape, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, // Indicates ProblemShape -+ CollectiveMainloop, -+ CollectiveEpilogue -+>; -+ -+using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+// Reference device GEMM implementation type -+using DeviceGemmReference = cutlass::reference::device::Gemm< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ ElementAccumulator>; -+ -+using StrideA = typename Gemm::GemmKernel::StrideA; -+using StrideB = typename Gemm::GemmKernel::StrideB; -+using StrideC = typename Gemm::GemmKernel::StrideC; -+using StrideD = typename Gemm::GemmKernel::StrideD; -+ -+// -+// Data members -+// -+ -+/// Initialization -+StrideA stride_A; -+StrideB stride_B; -+StrideC stride_C; -+StrideD stride_D; -+uint64_t seed; -+ -+cutlass::DeviceAllocation block_A; -+cutlass::DeviceAllocation block_B; -+cutlass::DeviceAllocation block_C; -+cutlass::DeviceAllocation block_D; -+cutlass::DeviceAllocation block_ref_D; -+ -+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Testbed utility types -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ -+ float alpha, beta; -+ int iterations; -+ int m, n, k; -+ -+ Options(): -+ help(false), -+ m(5120), n(4096), k(4096), -+ alpha(1.f), beta(0.f), -+ iterations(1000) -+ { } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ return; -+ } -+ -+ cmd.get_cmd_line_argument("m", m); -+ cmd.get_cmd_line_argument("n", n); -+ cmd.get_cmd_line_argument("k", k); -+ cmd.get_cmd_line_argument("alpha", alpha, 1.f); -+ cmd.get_cmd_line_argument("beta", beta, 0.f); -+ cmd.get_cmd_line_argument("iterations", iterations); -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "48_hopper_warp_specialized_gemm\n\n" -+ << " Hopper FP32 GEMM using a Warp Specialized kernel.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement\n\n" -+ << " --m= Sets the M extent of the GEMM\n" -+ << " --n= Sets the N extent of the GEMM\n" -+ << " --k= Sets the K extent of the GEMM\n" -+ << " --alpha= Epilogue scalar alpha\n" -+ << " --beta= Epilogue scalar beta\n\n" -+ << " --iterations= Number of profiling iterations to perform.\n\n"; -+ -+ out -+ << "\n\nExamples:\n\n" -+ << "$ " << "48_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; -+ -+ return out; -+ } -+ -+ /// Compute performance in GFLOP/s -+ double gflops(double runtime_s) const -+ { -+ // Two flops per multiply-add -+ uint64_t flop = uint64_t(2) * m * n * k; -+ double gflop = double(flop) / double(1.0e9); -+ return gflop / runtime_s; -+ } -+}; -+ -+/// Result structure -+struct Result -+{ -+ double avg_runtime_ms; -+ double gflops; -+ cutlass::Status status; -+ cudaError_t error; -+ bool passed; -+ -+ Result( -+ double avg_runtime_ms = 0, -+ double gflops = 0, -+ cutlass::Status status = cutlass::Status::kSuccess, -+ cudaError_t error = cudaSuccess) -+ : -+ avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) -+ {} -+ -+}; -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// GEMM setup and evaluation -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Helper to initialize a block of device data -+template -+bool initialize_block( -+ cutlass::DeviceAllocation& block, -+ uint64_t seed=2023) { -+ -+ Element scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::device::BlockFillRandomUniform( -+ block.get(), block.size(), seed, scope_max, scope_min, 0); -+ -+ return true; -+} -+ -+/// Initialize operands to be used in the GEMM and reference GEMM -+void initialize(const Options &options) { -+ -+ stride_A = make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, Int<1>{})); -+ stride_B = make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, Int<1>{})); -+ stride_C = make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, Int<1>{})); -+ stride_D = make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, Int<1>{})); -+ -+ block_A.reset(options.m * options.k); -+ block_B.reset(options.k * options.n); -+ block_C.reset(options.m * options.n); -+ block_D.reset(options.m * options.n); -+ block_ref_D.reset(options.m * options.n); -+ -+ initialize_block(block_A, seed + 2023); -+ initialize_block(block_B, seed + 2022); -+ initialize_block(block_C, seed + 2021); -+} -+ -+/// Populates a Gemm::Arguments structure from the given commandline options -+typename Gemm::Arguments args_from_options(const Options &options) -+{ -+ typename Gemm::Arguments arguments{ -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ {options.m, options.n, options.k}, -+ block_A.get(), -+ stride_A, -+ block_B.get(), -+ stride_B, -+ {block_C.get(), stride_C, block_D.get(), stride_D, {options.alpha, options.beta}} -+ }; -+ -+ return arguments; -+} -+ -+bool verify(const Options &options) { -+ cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({options.m, options.k})); -+ cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({options.n, options.k})); -+ cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({options.m, options.n})); -+ cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({options.m, options.n})); -+ -+ // -+ // Compute reference output -+ // -+ -+ // Create instantiation for device reference gemm kernel -+ DeviceGemmReference gemm_reference; -+ -+ // Launch device reference gemm kernel -+ gemm_reference( -+ {options.m, options.n, options.k}, -+ ElementAccumulator(options.alpha), -+ ref_A, -+ ref_B, -+ ElementAccumulator(options.beta), -+ ref_C, -+ ref_D); -+ -+ // Wait for kernel to finish -+ CUDA_CHECK(cudaDeviceSynchronize()); -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size()); -+ -+ return passed; -+} -+ -+/// Execute a given example GEMM computation -+template -+int run(Options &options) -+{ -+ initialize(options); -+ -+ // Instantiate CUTLASS kernel depending on templates -+ Gemm gemm; -+ -+ // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm -+ auto arguments = args_from_options(options); -+ -+ // Using the arguments, query for extra workspace required for matrix multiplication computation -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ -+ // Allocate workspace memory -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ // Check if the problem size is supported or not -+ CUTLASS_CHECK(gemm.can_implement(arguments)); -+ -+ // Initialize CUTLASS kernel with arguments and workspace pointer -+ CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); -+ -+ // Correctness / Warmup iteration -+ CUTLASS_CHECK(gemm.run()); -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ Result result; -+ result.passed = verify(options); -+ -+ std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; -+ -+ if (!result.passed) { -+ exit(-1); -+ } -+ -+ // Run profiling loop -+ if (options.iterations > 0) -+ { -+ GpuTimer timer; -+ timer.start(); -+ for (int iter = 0; iter < options.iterations; ++iter) { -+ CUTLASS_CHECK(gemm.run()); -+ } -+ timer.stop(); -+ -+ // Compute average runtime and GFLOPs. -+ float elapsed_ms = timer.elapsed_millis(); -+ result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); -+ result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); -+ -+ std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl; -+ std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; -+ std::cout << " GFLOPS: " << result.gflops << std::endl; -+ } -+ -+ return 0; -+} -+ -+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ -+ // CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example -+ // and must have compute capability at least 90. -+ if (__CUDACC_VER_MAJOR__ < 12) { -+ std::cerr << "This example requires CUDA 12 or newer.\n"; -+ // Returning zero so this test passes on older Toolkits. Its actions are no-op. -+ return 0; -+ } -+ -+ cudaDeviceProp props; -+ int current_device_id; -+ CUDA_CHECK(cudaGetDevice(¤t_device_id)); -+ CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (props.major < 9) { -+ std::cerr -+ << "This example requires a GPU of NVIDIA's Hopper Architecture or " -+ << "later (compute capability 90 or greater).\n"; -+ return 0; -+ } -+ -+ // -+ // Parse options -+ // -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ // -+ // Evaluate CUTLASS kernels -+ // -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ run(options); -+#endif -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle.cu b/3rdparty/cutlass/examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle.cu -new file mode 100644 -index 0000000..7323cc3 ---- /dev/null -+++ b/3rdparty/cutlass/examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle.cu -@@ -0,0 +1,529 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Hopper GEMM example to create a GEMM kernel with custom Collectives -+ -+ The following example shows how to assemble a custom GEMM kernel that spells out the Collectives -+ directly instead of using a builder and, in the process, instance a more efficient Epilogue -+ (from `cutlass/epilogue/collective/epilogue.hpp`) instead of using the default epilogue. -+ -+ The GemmUniversal API takes 3 main template arguments: -+ (1) the problem shape / extents -+ (2) the collective mainloop type -+ (3) the collective epilogue type -+ -+ While the collecive mainloop can be stamped out using a CollectiveBuilder interface, it is -+ possible to build a custom collective mainloop directly as well. Furthermore, since epilogues -+ do not yet have a builder interface, this example shows how to instantiate a more-efficient -+ epilogue alongside the collective mainloop. -+ -+ Note: there are several ways to implement the GEMM epilogue in Hopper - each with its own set -+ of trade-offs. So it is recommended that users look at the options available under -+ cutlass/epilogue/collective and evaluate for their particular scenario. -+ -+ Please refer to examples 48, 49 to learn more about kernel schedules and other CuTe examples -+ present in `test/unit/cute` to famialiarize with the basics of CuTe. -+ -+ Examples: -+ -+ $ ./examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+ -+#include "cute/tensor.hpp" -+#include "cutlass/util/command_line.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/epilogue/collective/epilogue.hpp" -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/gemm/dispatch_policy.hpp" -+#include "cutlass/gemm/collective/collective_builder.hpp" -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "cutlass/gemm/kernel/gemm_universal.hpp" -+#include "cutlass/gemm/dispatch_policy.hpp" -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/packed_stride.hpp" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/device/gemm_complex.h" -+#include "cutlass/util/reference/device/tensor_compare.h" -+#include "cutlass/util/reference/device/tensor_fill.h" -+ -+using namespace cute; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Command line options parsing -+struct Options { -+ -+ bool help; -+ bool error; -+ -+ int m, n, k, l; -+ int alpha, beta; -+ -+ Options(): -+ help(false), -+ error(false), -+ m(2048), n(2048), k(2048), l(1), -+ alpha(1), beta(0) -+ { } -+ -+ // Parses the command line -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ return; -+ } -+ -+ cmd.get_cmd_line_argument("m", m, 2048); -+ cmd.get_cmd_line_argument("n", n, 2048); -+ cmd.get_cmd_line_argument("k", k, 2048); -+ cmd.get_cmd_line_argument("l", l, 1); -+ cmd.get_cmd_line_argument("alpha", alpha, 1); -+ cmd.get_cmd_line_argument("beta", beta, 0); -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "50_hopper_gemm_with_vectorized_epilogue\n\n" -+ << "Hopper GEMM Example with Epilogue Swizzle.\n\n" -+ << "Options:\n\n" -+ << " --help If specified, displays this usage statement\n\n" -+ << " --m= Sets the M extent of the GEMM\n" -+ << " --n= Sets the N extent of the GEMM\n" -+ << " --k= Sets the K extent of the GEMM\n" -+ << " --l= Sets the L extent (batch count) of the GEMM\n" -+ << " --alpha= Epilogue scalar alpha\n" -+ << " --beta= Epilogue scalar beta\n\n"; -+ -+ return out; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Helper to initialize a block of device data -+template -+bool initialize_block( -+ cutlass::DeviceAllocation& block, -+ uint64_t seed=2023) { -+ -+ Element scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::device::BlockFillRandomUniform( -+ block.get(), block.size(), seed, scope_max, scope_min, 0); -+ -+ return true; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+// Wrapper to run and verify a GEMM. -+template < -+ class Gemm -+> -+struct ExampleRunner { -+ -+ using StrideA = typename Gemm::GemmKernel::StrideA; -+ using StrideB = typename Gemm::GemmKernel::StrideB; -+ using StrideC = typename Gemm::GemmKernel::StrideC; -+ using StrideD = typename Gemm::GemmKernel::StrideD; -+ -+ using LayoutA = typename Gemm::LayoutA; -+ using LayoutB = typename Gemm::LayoutB; -+ using LayoutC = typename Gemm::LayoutC; -+ using LayoutD = typename Gemm::LayoutD; -+ -+ using ElementA = typename Gemm::ElementA; -+ using ElementB = typename Gemm::ElementB; -+ using ElementAcc = typename Gemm::ElementAccumulator; -+ -+ using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; -+ using ElementC = typename Gemm::ElementC; -+ using ElementOutput = typename CollectiveEpilogue::ElementOutput; -+ using ElementCompute = typename CollectiveEpilogue::ElementCompute; -+ using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; -+ -+ using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; -+ -+ // -+ // Data members -+ // -+ -+ /// Initialization -+ StrideA stride_A; -+ StrideB stride_B; -+ StrideC stride_C; -+ StrideD stride_D; -+ uint64_t seed = 0; -+ -+ cutlass::DeviceAllocation block_A; -+ cutlass::DeviceAllocation block_B; -+ cutlass::DeviceAllocation block_C; -+ cutlass::DeviceAllocation block_D; -+ cutlass::DeviceAllocation block_ref_D; -+ -+ // -+ // Methods -+ // -+ -+ bool verify(const ProblemShapeType& problem_size, int32_t alpha, int32_t beta) { -+ auto [M, N, K, L] = problem_size; -+ -+ cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K})); -+ cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N})); -+ cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N})); -+ cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); -+ -+ cutlass::reference::device::GemmComplex( -+ {M, N, K}, -+ ElementCompute(alpha), -+ ref_A, -+ cutlass::ComplexTransform::kNone, -+ ref_B, -+ cutlass::ComplexTransform::kNone, -+ ElementCompute(beta), -+ ref_C, -+ ref_D, -+ ElementAccumulator(0), -+ L, // batch_count -+ M * K, // batch_stride_A -+ K * N, // batch_stride_B -+ M * N, // batch_stride_C -+ M * N // batch_stride_D -+ ); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ if (result != cudaSuccess) { -+ std::cerr << "Reference kernel failed. Last CUDA error: " -+ << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ -+ // Check if output from CUTLASS kernel and reference kernel are equal or not -+ bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size()); -+ -+ return passed; -+ } -+ -+ /// Initialize operands to be used in the GEMM and reference GEMM -+ void initialize(const ProblemShapeType& problem_size) { -+ auto problem_shape_MNKL = cute::append<4>(problem_size, 1); -+ auto [M, N, K, L] = problem_shape_MNKL; -+ -+ stride_A = make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); -+ stride_B = make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); -+ stride_C = make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); -+ stride_D = make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); -+ -+ block_A.reset(M * K * L); -+ block_B.reset(K * N * L); -+ block_C.reset(M * N * L); -+ block_D.reset(M * N * L); -+ block_ref_D.reset(M * N * L); -+ -+ initialize_block(block_A, seed + 2023); -+ initialize_block(block_B, seed + 2022); -+ initialize_block(block_C, seed + 2021); -+ } -+ -+ bool run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { -+ ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; -+ -+ initialize(problem_size); -+ -+ typename Gemm::GemmKernel::Arguments arguments{ -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ problem_size, -+ block_A.get(), -+ stride_A, -+ block_B.get(), -+ stride_B, -+ {block_C.get(), stride_C, block_D.get(), stride_D, {options.alpha, options.beta}}, -+ hw_info -+ }; -+ -+ Gemm gemm_op; -+ -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = gemm_op.can_implement(arguments); -+ if (status != cutlass::Status::kSuccess) { -+ std::cerr << "This kernel is not supported. Last CUDA error is: " -+ << cudaGetErrorString(cudaGetLastError()) << std::endl; -+ return false; -+ } -+ -+ status = gemm_op.initialize(arguments, workspace.get()); -+ if (status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: " -+ << cudaGetErrorString(cudaGetLastError()) << std::endl; -+ return false; -+ } -+ -+ // Run the GEMM -+ status = gemm_op.run(); -+ if (status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: " -+ << cudaGetErrorString(cudaGetLastError()) << std::endl; -+ return false; -+ } -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ if (result != cudaSuccess) { -+ std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: " -+ << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ -+ // Verify that the result is correct -+ bool passed = verify(problem_size, options.alpha, options.beta); -+ if (!passed) { -+ std::cerr << "Reference check failed" << std::endl; -+ } -+ -+ return passed; -+ } -+ -+}; -+ -+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+int main(int argc, char const **args) { -+ -+ cudaDeviceProp props; -+ -+ cudaError_t error = cudaGetDeviceProperties(&props, 0); -+ if (error != cudaSuccess) { -+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; -+ return -1; -+ } -+ -+ if (__CUDACC_VER_MAJOR__ < 12 || props.major < 9) { -+ std::cout -+ << "This example requires a GPU of NVIDIA's Hopper Architecture or " -+ << "later (compute capability 90 or greater) and CUDA 12.0 or greater.\n"; -+ return 0; -+ } -+ -+ // -+ // Parse options -+ // -+ -+ Options options; -+ -+ options.parse(argc, args); -+ -+ if (options.help) { -+ options.print_usage(std::cout) << std::endl; -+ return 0; -+ } -+ -+ if (options.error) { -+ std::cerr << "Aborting execution." << std::endl; -+ return -1; -+ } -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+ // -+ // Run examples -+ // -+ -+ // The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This -+ // information is used by the underlying kernel. -+ cutlass::KernelHardwareInfo hw_info; -+ -+ // Change device_id to another value if you are running on a machine with multiple GPUs and wish -+ // to use a GPU other than that with device ID 0. -+ hw_info.device_id = 0; -+ hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); -+ -+ bool passed; -+ -+ // Problem configuration -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementAcc = int32_t; -+ using ElementOutput = int8_t; -+ -+ // Note : Only TN WGMMA Gemm is supported currently in 3.0 -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using LayoutD = cutlass::layout::ColumnMajor; -+ -+ // Tiling configuration selection -+ using TileShape = Shape<_128,_64,_128>; -+ -+ // Choosing a thread block cluster larger than 1 allows us to Multicast data across thread blocks -+ using ClusterShape = Shape<_1,_2,_1>; -+ -+ // -+ // Assembling the CollectiveMainloop type -+ // -+ -+ // Pipeline Depth to be used i.e number of A, B buffers in shared memory -+ constexpr int PipelineStages = 8; -+ -+ // Let's choose a Warp-Specialized Mainloop implemention which uses TMA -+ // Note : This requires / assumes the tensors to be 16B aligned -+ using DispatchPolicy = cutlass::gemm::MainloopSm90TmaGmmaWarpSpecialized; -+ -+ // TN => K Major for both A & B -+ static constexpr cute::GMMA::Major GmmaMajorA = cute::GMMA::Major::K; -+ static constexpr cute::GMMA::Major GmmaMajorB = cute::GMMA::Major::K; -+ -+ // We use the SS op selector as both A, B operands are read directly from SMEM (for TN WGMMA) -+ using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< -+ ElementA, ElementB, ElementAcc, TileShape, GmmaMajorA, GmmaMajorB>())); -+ -+ // A loads can be optimized with multicast if cluster-n > 1 -+ using GmemTiledCopyA = std::conditional< cute::size(shape<1>(ClusterShape{})) == 1, -+ cute::SM90_TMA_LOAD, -+ cute::SM90_TMA_LOAD_MULTICAST>::type; -+ -+ // B loads can be optimized with multicast if cluster-m > 1 -+ using GmemTiledCopyB = std::conditional< cute::size(shape<0>(ClusterShape{})) == 1, -+ cute::SM90_TMA_LOAD, -+ cute::SM90_TMA_LOAD_MULTICAST>::type; -+ -+ using SmemLayoutAtomA = decltype(cute::GMMA::smem_selector< -+ GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape{})), decltype(cute::get<2>(TileShape{})) -+ >()); -+ -+ using SmemLayoutAtomB = decltype(cute::GMMA::smem_selector< -+ GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape{})), decltype(cute::get<2>(TileShape{})) -+ >()); -+ -+ using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< -+ DispatchPolicy, -+ TileShape, -+ ElementA, -+ cutlass::gemm::TagToStrideA_t, -+ ElementB, -+ cutlass::gemm::TagToStrideB_t, -+ TiledMma, -+ GmemTiledCopyA, -+ SmemLayoutAtomA, -+ void, // Does not need a SmemCopyAtom, since A is read directly from SMEM -+ cute::identity, -+ GmemTiledCopyB, -+ SmemLayoutAtomB, -+ void, // Does not need a SmemCopyAtom, since B is read directly from SMEM -+ cute::identity -+ >; -+ -+ // -+ // Assembling the Collective Epilogue Type -+ // -+ -+ // Break the 128 along TILE_M into chunks of 32, to get a 128B leading dimension -+ using PreSwizzleLayout = Layout< Shape< Shape <_32,_4 >,_64>, -+ Stride,_32>>; -+ -+ // 128 threads loading 16 elements each (to get vectorized global stores) -+ using TileShapeS2R = Shape<_128,_16>; -+ -+ // Layout to ensure bank-conflict free loads & stores -+ using SmemLayout = ComposedLayout< -+ Swizzle<3,4,3>, -+ smem_ptr_flag_bits::value>, -+ PreSwizzleLayout>; -+ -+ // Tiled copy from Smem to Registers -+ // Note : CuTe will vectorize this copy if the tiling + swizzling above were right -+ using TiledCopyS2R = TiledCopy< -+ Copy_Atom, -+ Layout< Shape<_128,_16>, -+ Stride<_16,_1>>, -+ TileShapeS2R>; -+ -+ using Epilogue = cutlass::epilogue::collective::Epilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination, -+ SmemLayout, -+ Copy_Atom, -+ TiledCopyS2R, -+ Copy_Atom>; -+ -+ // -+ // Assembling the GemmKernel -+ // -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ Epilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ ExampleRunner runner; -+ -+ passed = runner.run(options, hw_info); -+ -+ std::cout << "WGMMA GEMM with Epilogue Swizzle : " << (passed ? "Passed" : "Failed") << std::endl; -+ -+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/examples/common/helper.h b/3rdparty/cutlass/examples/common/helper.h -new file mode 100644 -index 0000000..ba04113 ---- /dev/null -+++ b/3rdparty/cutlass/examples/common/helper.h -@@ -0,0 +1,77 @@ -+#pragma once -+ -+#include "cuda_runtime.h" -+ -+/** -+ * Panic wrapper for unwinding CUTLASS errors -+ */ -+#define CUTLASS_CHECK(status) \ -+ { \ -+ cutlass::Status error = status; \ -+ if (error != cutlass::Status::kSuccess) { \ -+ std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \ -+ << std::endl; \ -+ exit(EXIT_FAILURE); \ -+ } \ -+ } -+ -+ -+/** -+ * Panic wrapper for unwinding CUDA runtime errors -+ */ -+#define CUDA_CHECK(status) \ -+ { \ -+ cudaError_t error = status; \ -+ if (error != cudaSuccess) { \ -+ std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ -+ << " at line: " << __LINE__ << std::endl; \ -+ exit(EXIT_FAILURE); \ -+ } \ -+ } -+ -+ -+/** -+ * GPU timer for recording the elapsed time across kernel(s) launched in GPU stream -+ */ -+struct GpuTimer -+{ -+ cudaStream_t _stream_id; -+ cudaEvent_t _start; -+ cudaEvent_t _stop; -+ -+ /// Constructor -+ GpuTimer() : _stream_id(0) -+ { -+ CUDA_CHECK(cudaEventCreate(&_start)); -+ CUDA_CHECK(cudaEventCreate(&_stop)); -+ } -+ -+ /// Destructor -+ ~GpuTimer() -+ { -+ CUDA_CHECK(cudaEventDestroy(_start)); -+ CUDA_CHECK(cudaEventDestroy(_stop)); -+ } -+ -+ /// Start the timer for a given stream (defaults to the default stream) -+ void start(cudaStream_t stream_id = 0) -+ { -+ _stream_id = stream_id; -+ CUDA_CHECK(cudaEventRecord(_start, _stream_id)); -+ } -+ -+ /// Stop the timer -+ void stop() -+ { -+ CUDA_CHECK(cudaEventRecord(_stop, _stream_id)); -+ } -+ -+ /// Return the elapsed time (in milliseconds) -+ float elapsed_millis() -+ { -+ float elapsed = 0.0; -+ CUDA_CHECK(cudaEventSynchronize(_stop)); -+ CUDA_CHECK(cudaEventElapsedTime(&elapsed, _start, _stop)); -+ return elapsed; -+ } -+}; -diff --git a/3rdparty/cutlass/examples/cute/tutorial/sgemm_nt_1.cu b/3rdparty/cutlass/examples/cute/tutorial/sgemm_nt_1.cu -new file mode 100644 -index 0000000..fc4839a ---- /dev/null -+++ b/3rdparty/cutlass/examples/cute/tutorial/sgemm_nt_1.cu -@@ -0,0 +1,426 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#include -+#include -+ -+#include -+ -+#include "cutlass/util/print_error.hpp" -+#include "cutlass/util/GPU_Clock.hpp" -+#if defined(CUTLASS_ENABLE_CUBLAS) && CUTLASS_ENABLE_CUBLAS != 0 -+# include "cutlass/util/cublas_wrappers.hpp" -+#endif -+#include "cutlass/util/helper_cuda.hpp" -+ -+template -+__global__ static -+__launch_bounds__(decltype(size(CThreadLayout{}))::value) -+void -+gemm_device(MShape M, NShape N, KShape K, -+ TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA, -+ TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB, -+ TC * C, CStride dC, CBlockLayout , CThreadLayout tC, -+ Alpha alpha, Beta beta) -+{ -+ using namespace cute; -+ using X = Underscore; -+ -+ // Preconditions -+ CUTE_STATIC_ASSERT(is_static::value); -+ CUTE_STATIC_ASSERT(is_static::value); -+ CUTE_STATIC_ASSERT(is_static::value); -+ -+ CUTE_STATIC_ASSERT(is_static::value); -+ CUTE_STATIC_ASSERT(is_static::value); -+ CUTE_STATIC_ASSERT(is_static::value); -+ -+ CUTE_STATIC_ASSERT_V(size(tA) == size(tC)); -+ CUTE_STATIC_ASSERT_V(size(tB) == size(tC)); -+ -+ //CUTE_STATIC_ASSERT_V(shape<0>(blockA) == shape<0>(blockC)); // BLK_M -+ //CUTE_STATIC_ASSERT_V(shape<0>(blockB) == shape<1>(blockC)); // BLK_N -+ CUTE_STATIC_ASSERT_V(shape<1>(blockA) == shape<1>(blockB)); // BLK_K -+ -+ // Shared memory buffers -+ __shared__ TA smemA[cosize_v]; -+ __shared__ TB smemB[cosize_v]; -+ auto sA = make_tensor(make_smem_ptr(smemA), blockA); // (BLK_M,BLK_K) -+ auto sB = make_tensor(make_smem_ptr(smemB), blockB); // (BLK_N,BLK_K) -+ -+ // Represent the full tensors -+ auto mA = make_tensor(make_gmem_ptr(A), make_shape(M,K), dA); // (M,K) -+ auto mB = make_tensor(make_gmem_ptr(B), make_shape(N,K), dB); // (N,K) -+ auto mC = make_tensor(make_gmem_ptr(C), make_shape(M,N), dC); // (M,N) -+ -+ // Get the appropriate blocks for this thread block -- -+ // potential for thread block locality -+ auto blk_shape = make_shape(size<0>(sA), size<0>(sB), size<1>(sB));// (BLK_M,BLK_N,BLK_K) -+ auto blk_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k) -+ -+ auto gA = local_tile(mA, blk_shape, blk_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k) -+ auto gB = local_tile(mB, blk_shape, blk_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) -+ auto gC = local_tile(mC, blk_shape, blk_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N) -+ -+ // -+ // Partition the copying of A and B tiles across the threads -+ // -+ -+ // TUTORIAL: Example of simple partitioning of A|B tiles over tA|tB -+ // Default is a raked partition, but can be changed with Step parameter -+ -+ auto tAgA = local_partition(gA, tA, threadIdx.x); // (THR_M,THR_K,k) -+ auto tAsA = local_partition(sA, tA, threadIdx.x); // (THR_M,THR_K) -+ -+ auto tBgB = local_partition(gB, tB, threadIdx.x); // (THR_N,THR_K,k) -+ auto tBsB = local_partition(sB, tB, threadIdx.x); // (THR_N,THR_K) -+ -+ // -+ // Define C accumulators and A/B partitioning -+ // -+ -+ // TUTORIAL: Example of partitioning via projections of tC -+ -+ // Partition sA (M,K) by the rows of tC -+ auto tCsA = local_partition(sA, tC, threadIdx.x, Step<_1, X>{}); // (THR_M,BLK_K) -+ // Partition sB (N,K) by the cols of tC -+ auto tCsB = local_partition(sB, tC, threadIdx.x, Step< X,_1>{}); // (THR_N,BLK_K) -+ // Partition gC (M,N) by the tile of tC -+ auto tCgC = local_partition(gC, tC, threadIdx.x, Step<_1,_1>{}); // (THR_M,THR_N) -+ -+ // Allocate the accumulators -- same size as the projected data -+ auto tCrC = make_fragment_like(tCgC); // (THR_M,THR_N) -+ -+ // Clear the accumulators -+ clear(tCrC); -+ -+#if 0 -+ if(thread0()) { -+ print("mA\n"); -+ print(mA.shape()); print("\n"); print(mA.stride()); -+ print("\n\ngA\n"); -+ print(gA.shape()); print("\n"); print(gA.stride()); -+ print("\n\ntAgA\n"); -+ print(tAgA.shape()); print("\n"); print(tAgA.stride()); -+ print("\n\nsA\n"); -+ print(sA.shape()); print("\n"); print(sA.stride()); -+ print("\n\ntAsA\n"); -+ print(tAsA.shape()); print("\n"); print(tAsA.stride()); -+ print("\n\n"); -+ } -+#endif -+ -+#if 0 -+ if(thread0()) { -+ print("mB\n"); -+ print(mB.shape()); print("\n"); print(mB.stride()); -+ print("\n\ngB\n"); -+ print(gB.shape()); print("\n"); print(gB.stride()); -+ print("\n\ntBgB\n"); -+ print(tBgB.shape()); print("\n"); print(tBgB.stride()); -+ print("\n\nsB\n"); -+ print(sB.shape()); print("\n"); print(sB.stride()); -+ print("\n\ntBsB\n"); -+ print(tBsB.shape()); print("\n"); print(tBsB.stride()); -+ print("\n\n"); -+ } -+#endif -+ -+#if 0 -+ if(thread0()) { -+ print("mC\n"); -+ print(mC.shape()); print("\n"); print(mC.stride()); -+ print("\n\ngC\n"); -+ print(gC.shape()); print("\n"); print(gC.stride()); -+ print("\n\ntCsA\n"); -+ print(tCsA.shape()); print("\n"); print(tCsA.stride()); -+ print("\n\ntCsB\n"); -+ print(tCsB.shape()); print("\n"); print(tCsB.stride()); -+ print("\n\ntCgC\n"); -+ print(tCgC.shape()); print("\n"); print(tCgC.stride()); -+ print("\n\ntCrC\n"); -+ print(tCrC.shape()); print("\n"); print(tCrC.stride()); -+ print("\n\n"); -+ } -+#endif -+ -+#if 1 -+ -+ // TUTORIAL: Example of a very simple compute loop -+ // Data is read from global to shared memory via the tA|tB partitioning -+ // gemm(.) operates on the shared memory directly via the tC partitioning -+ -+ auto k_max = size<2>(tAgA); -+ -+ for (int k = 0; k < k_max; ++k) -+ { -+ // Copy gmem to smem -+ copy(tAgA(_,_,k), tAsA); -+ copy(tBgB(_,_,k), tBsB); -+ -+ // In case copy uses cp.async, make sure that the cp.async -+ // instructions are ordered with respect to other cp.async -+ // instructions (fence), then wait on all the outstanding copy -+ // operations (wait<0>()). __syncthreads() alone does not do -+ // this. -+ // -+ // NOTE: cp_async_wait<0>() currently issues cp.async.wait_all. -+ // This is equivalent to cp.async.commit_group followed by -+ // cp.async_wait_group 0. This should make the first -+ // cp_async_fence() (which also issues cp.async.commit_group) -+ // redundant. The tutorial works as-is, so we'll leave the -+ // redundant fence in for now and study its removal later. -+ cp_async_fence(); -+ cp_async_wait<0>(); -+ -+ __syncthreads(); -+ -+ // Compute gemm on smem -+ gemm(tCsA, tCsB, tCrC); -+ -+ __syncthreads(); -+ } -+ -+#endif -+ -+ // -+ // Epilogue -+ // -+ -+ axpby(alpha, tCrC, beta, tCgC); -+} -+ -+ -+template -+void -+gemm(int m, int n, int k, -+ Alpha alpha, -+ TA const* A, int ldA, -+ TB const* B, int ldB, -+ Beta beta, -+ TC * C, int ldC, -+ cudaStream_t stream = 0) -+{ -+ using namespace cute; -+ -+ // Define shapes (dynamic) -+ auto M = int(m); -+ auto N = int(n); -+ auto K = int(k); -+ -+ // Define strides (mixed) -+ auto dA = make_stride(Int<1>{}, ldA); -+ auto dB = make_stride(Int<1>{}, ldB); -+ auto dC = make_stride(Int<1>{}, ldC); -+ -+ // Define block sizes (static) -+ auto bM = Int<128>{}; -+ auto bN = Int<128>{}; -+ auto bK = Int< 8>{}; -+ -+ // Define the block layouts (static) -+ auto sA = make_layout(make_shape(bM,bK)); -+ auto sB = make_layout(make_shape(bN,bK)); -+ auto sC = make_layout(make_shape(bM,bN)); -+ -+ // Define the thread layouts (static) -+ auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{})); -+ auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{})); -+ auto tC = make_layout(make_shape(Int<16>{}, Int<16>{})); -+ -+ dim3 dimBlock(size(tC)); -+ dim3 dimGrid(ceil_div(size(M), size(bM)), -+ ceil_div(size(N), size(bN))); -+ gemm_device -+ <<< dimGrid, dimBlock, 0, stream >>> -+ (M, N, K, -+ A, dA, sA, tA, -+ B, dB, sB, tB, -+ C, dC, sC, tC, -+ alpha, beta); -+} -+ -+#include -+#include -+#include -+ -+void test_gemm(int m, int n, int k) -+{ -+ cute::device_init(0); -+ -+ std::cout << "M = " << m << std::endl; -+ std::cout << "N = " << n << std::endl; -+ std::cout << "K = " << k << std::endl; -+ -+ using TA = float; -+ using TB = float; -+ using TC = float; -+ using TI = float; -+ -+ thrust::host_vector h_A(m*k); -+ thrust::host_vector h_B(n*k); -+ thrust::host_vector h_C(m*n); -+ -+ for (int j = 0; j < m*k; ++j) h_A[j] = static_cast( 2*(rand() / double(RAND_MAX)) - 1 ); -+ for (int j = 0; j < n*k; ++j) h_B[j] = static_cast( 2*(rand() / double(RAND_MAX)) - 1 ); -+ for (int j = 0; j < m*n; ++j) h_C[j] = static_cast(-1); -+ -+ thrust::device_vector d_A = h_A; -+ thrust::device_vector d_B = h_B; -+ thrust::device_vector d_C = h_C; -+ -+ TI alpha = 1.0; -+ TI beta = 0.0; -+ -+ double gflops = (2.0*m*n*k) * 1e-9; -+ -+ const int timing_iterations = 100; -+ GPU_Clock timer; -+ -+#if defined(CUTLASS_ENABLE_CUBLAS) && CUTLASS_ENABLE_CUBLAS != 0 -+ // -+ // cuBLas -+ // -+ -+ cublasHandle_t handle; -+ cublasCreate(&handle); -+ -+ // Run once -+ d_C = h_C; -+ blam::cublas::gemm(handle, CUBLAS_OP_N, CUBLAS_OP_T, -+ m, n, k, -+ &alpha, -+ d_A.data().get(), m, -+ d_B.data().get(), n, -+ &beta, -+ d_C.data().get(), m); -+ CUTE_CHECK_LAST(); -+ -+ thrust::host_vector cublas_result = d_C; -+ -+ // Timing iterations -+ timer.start(); -+ for (int i = 0; i < timing_iterations; ++i) { -+ blam::cublas::gemm(handle, CUBLAS_OP_N, CUBLAS_OP_T, -+ m, n, k, -+ &alpha, -+ d_A.data().get(), m, -+ d_B.data().get(), n, -+ &beta, -+ d_C.data().get(), m); -+ } -+ double cublas_time = timer.seconds() / timing_iterations; -+ CUTE_CHECK_LAST(); -+ printf("CUBLAS_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / cublas_time, cublas_time*1000); -+ -+#else -+ -+ std::cout << "Verification by comparison with cuBLAS is disabled, " -+ "either because the CMake option CUTLASS_ENABLE_CUBLAS " -+ "was explicitly set to OFF, or because CMake could not find cuBLAS. " -+ "If you would like to enable verification with cuBLAS, " -+ "please set the CMake option CUTLASS_ENABLE_CUBLAS to ON, " -+ "rerun CMake, and recompile this example.\n"; -+ -+#endif // CUTLASS_ENABLE_CUBLAS -+ -+ // -+ // CuTe -+ // -+ -+ // Run once (and check) -+ d_C = h_C; -+ gemm(m, n, k, -+ alpha, -+ d_A.data().get(), m, -+ d_B.data().get(), n, -+ beta, -+ d_C.data().get(), m); -+ CUTE_CHECK_LAST(); -+ thrust::host_vector cute_result = d_C; -+ -+ // Timing iterations -+ timer.start(); -+ for (int i = 0; i < timing_iterations; ++i) { -+ gemm(m, n, k, -+ alpha, -+ d_A.data().get(), m, -+ d_B.data().get(), n, -+ beta, -+ d_C.data().get(), m); -+ } -+ double cute_time = timer.seconds() / timing_iterations; -+ CUTE_CHECK_LAST(); -+ printf("CUTE_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / cute_time, cute_time*1000); -+ -+#if defined(CUTLASS_ENABLE_CUBLAS) && CUTLASS_ENABLE_CUBLAS != 0 -+ printf("Empirical Perf: %.1f%%\n", (cublas_time / cute_time) * 100); -+ -+ auto host_matrix_to_const_column_major_cute_tensor = -+ [](const auto& X, int num_rows, int num_cols, int LDX) { -+ const auto shape = cute::Shape{num_rows, num_cols}; -+ const auto strides = cute::Stride{1, LDX}; -+ return cute::make_tensor(X.data(), cute::make_layout(shape, strides)); -+ }; -+ -+ const auto A_view = host_matrix_to_const_column_major_cute_tensor(h_A, m, k, m); -+ // B^T is k x n, so B is n x k. -+ const auto B_view = host_matrix_to_const_column_major_cute_tensor(h_B, n, k, n); -+ const auto C_computed_view = host_matrix_to_const_column_major_cute_tensor(cute_result, m, n, m); -+ const auto C_expected_view = host_matrix_to_const_column_major_cute_tensor(cublas_result, m, n, m); -+ print_matrix_multiply_mollified_relative_error("float", A_view, B_view, C_computed_view, C_expected_view); -+ -+#endif // CUTLASS_ENABLE_CUBLAS -+} -+ -+ -+int main(int argc, char** argv) -+{ -+ int m = 5120; -+ if (argc >= 2) -+ sscanf(argv[1], "%d", &m); -+ -+ int n = 5120; -+ if (argc >= 3) -+ sscanf(argv[2], "%d", &n); -+ -+ int k = 4096; -+ if (argc >= 4) -+ sscanf(argv[3], "%d", &k); -+ -+ test_gemm(m, n, k); -+ -+ return 0; -+} -diff --git a/3rdparty/cutlass/include/cute/algorithm/axpby.hpp b/3rdparty/cutlass/include/cute/algorithm/axpby.hpp -new file mode 100644 -index 0000000..a613417 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/algorithm/axpby.hpp -@@ -0,0 +1,79 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+namespace cute -+{ -+ -+// -+// Accept mutable temporaries -+// -+template -+CUTE_HOST_DEVICE -+void -+axpby(Alpha const& alpha, -+ Tensor const& x, -+ Beta const& beta, -+ Tensor && y) -+{ -+ return axpby(alpha, x, beta, y); -+} -+ -+// -+// AXPBY -+// -+template -+CUTE_HOST_DEVICE -+void -+axpby(Alpha const& alpha, -+ Tensor const& x, -+ Beta const& beta, -+ Tensor & y) -+{ -+ auto isBetaZero = (beta == Int<0>{}); -+ -+ CUTE_UNROLL -+ for (int i = 0; i < size(x); ++i) { -+ y(i) = (isBetaZero ? alpha * x(i) : alpha * x(i) + beta * y(i)); -+ } -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/algorithm/clear.hpp b/3rdparty/cutlass/include/cute/algorithm/clear.hpp -new file mode 100644 -index 0000000..ce7b510 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/algorithm/clear.hpp -@@ -0,0 +1,66 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+#include -+ -+namespace cute -+{ -+ -+// -+// Accept mutable temporaries -+// -+template -+CUTE_HOST_DEVICE -+void -+clear(Tensor&& tensor) -+{ -+ return clear(tensor); -+} -+ -+// -+// Set elements to zero -+// -+template -+CUTE_HOST_DEVICE -+void -+clear(Tensor& tensor) -+{ -+ using T = typename Tensor::value_type; -+ -+ fill(tensor, T{}); -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/algorithm/copy.hpp b/3rdparty/cutlass/include/cute/algorithm/copy.hpp -new file mode 100644 -index 0000000..04ceb05 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/algorithm/copy.hpp -@@ -0,0 +1,262 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+ -+#include -+ -+namespace cute -+{ -+ -+// -+// Accept mutable temporaries -+// -+ -+template -+CUTE_HOST_DEVICE -+void -+copy_if(PrdTensor const& pred, -+ Tensor const& src, -+ Tensor && dst) -+{ -+ return copy_if(pred, src, dst); -+} -+ -+template -+CUTE_HOST_DEVICE -+void -+copy_if(Copy_Atom const& copy_atom, -+ PrdTensor const& pred, -+ Tensor const& src, -+ Tensor && dst) -+{ -+ return copy_if(copy_atom, pred, src, dst); -+} -+ -+template -+CUTE_HOST_DEVICE -+void -+copy_vec(Tensor const& src, -+ Tensor && dst) -+{ -+ return copy_vec(src, dst); -+} -+ -+template -+CUTE_HOST_DEVICE -+void -+copy(Tensor const& src, -+ Tensor && dst) -+{ -+ return copy(src, dst); -+} -+ -+template -+CUTE_HOST_DEVICE -+void -+copy(Copy_Atom const& copy_atom, -+ Tensor const& src, -+ Tensor && dst) -+{ -+ return copy(copy_atom, src, dst); -+} -+ -+// -+// copy_if -- Predicated Copy -+// -+ -+template -+CUTE_HOST_DEVICE -+void -+copy_if(PrdTensor const& pred, -+ Tensor const& src, -+ Tensor & dst) -+{ -+ auto copy_op = select_elementwise_copy(src, dst); -+ -+ CUTE_UNROLL -+ for (int i = 0; i < size(src); ++i) { -+ if (pred(i)) { -+ copy_op.copy(src(i), dst(i)); -+ } -+ } -+} -+ -+// -+// copy_if -- Predicated CopyAtom -+// -+ -+template -+CUTE_HOST_DEVICE -+void -+copy_if(Copy_Atom const& copy_atom, -+ PredTensor const& pred, // (Rest...) -+ Tensor const& src, // (V,Rest...) -+ Tensor & dst) // (V,Rest...) -+{ -+ static_assert(SrcLayout::rank == DstLayout::rank, "CopyAtom rank-mismatch."); -+ if constexpr (SrcLayout::rank == 1) { // Dispatch the copy -+ copy_atom.call(src, dst); -+ } else { // Loop over all but the first mode -+ constexpr int R = SrcLayout::rank; -+ auto src_v = group_modes<1,R>(src); -+ auto dst_v = group_modes<1,R>(dst); -+ CUTE_UNROLL -+ for (int i = 0; i < size<1>(src_v); ++i) { -+ if (pred(i)) { -+ copy_atom.call(src_v(_,i), dst_v(_,i)); -+ } -+ } -+ } -+} -+ -+// -+// copy_vec -- attempt vectorized copy with VecType -+// -+ -+template -+CUTE_HOST_DEVICE -+void -+copy_vec(Tensor const& src, -+ Tensor & dst) -+{ -+ using SrcType = typename SrcEngine::value_type; -+ using DstType = typename DstEngine::value_type; -+ if constexpr (sizeof(SrcType) == sizeof(DstType) && sizeof(VecType) > sizeof(DstType)) -+ { -+ /* @pre is_aligned(src.data()) && -+ * is_aligned(dst.data()) -+ */ -+ auto src_v = recast(src); -+ auto dst_v = recast(dst); -+ -+#if 0 -+ if (thread0()) { -+ print("copy_vec -- vectorizing copy from %3db to %3db\n", int(8*sizeof(SrcType)), int(8*sizeof(VecType))); -+ print(" "); print(layout(src)); print(" => "); print(layout(src_v)); print("\n"); -+ print(" "); print(layout(dst)); print(" => "); print(layout(dst_v)); print("\n"); -+ } -+#endif -+ -+ return copy_if(TrivialPredTensor{}, src_v, dst_v); -+ } else { -+#if 0 -+ if (thread0()) { -+ print("copy_vec -- not vectorizing, copy with %3db and %3db\n", int(8*sizeof(SrcType)), int(8*sizeof(DstType))); -+ print(" "); print(layout(src)); print("\n"); -+ print(" "); print(layout(dst)); print("\n"); -+ } -+#endif -+ -+ return copy_if(TrivialPredTensor{}, src, dst); -+ } -+} -+ -+// -+// copy -- auto-vectorizing copy -+// -+ -+template -+CUTE_HOST_DEVICE -+void -+copy(Tensor const& src, -+ Tensor & dst) -+{ -+ constexpr int N = decltype(max_common_vector(src, dst))::value; -+ -+#if 0 -+ if (thread0()) { -+ print("copy -- found a max_common_vector of %d\n", N); -+ print(" "); print(src.data()); print(" o "); print(layout(src)); print("\n"); -+ print(" "); print(dst.data()); print(" o "); print(layout(dst)); print("\n"); -+ } -+#endif -+ -+ if constexpr (N <= 1) { -+ return copy_if(TrivialPredTensor{}, src, dst); -+ } else { -+ constexpr int vec_bits = N * sizeof_bits::value; -+ using VecType = uint_bit_t; -+ return copy_vec(src, dst); -+ } -+} -+ -+// -+// copy -- CopyAtom -+// -+ -+template -+CUTE_HOST_DEVICE -+void -+copy(Copy_Atom const& copy_atom, -+ Tensor const& src, -+ Tensor & dst) -+{ -+ return copy_if(copy_atom, TrivialPredTensor{}, src, dst); -+} -+ -+template -+CUTE_HOST_DEVICE -+void -+copy(Copy_Atom const&, -+ Tensor const& src, -+ Tensor & dst) -+{ -+ return copy(src, dst); -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/algorithm/fill.hpp b/3rdparty/cutlass/include/cute/algorithm/fill.hpp -new file mode 100644 -index 0000000..bc0c4ad ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/algorithm/fill.hpp -@@ -0,0 +1,87 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+ -+namespace cute -+{ -+ -+// -+// Accept mutable temporaries -+// -+template -+CUTE_HOST_DEVICE -+void -+fill(Tensor&& tensor, T const& value) -+{ -+ return fill(tensor, value); -+} -+ -+namespace detail -+{ -+ -+// Prefer fill(tensor.data(), value), if possible -+template -+CUTE_HOST_DEVICE -+auto -+fill(Tensor& tensor, T const& value, prefer<1>) -+ -> decltype(fill(tensor.data(), value)) -+{ -+ fill(tensor.data(), value); -+} -+ -+// Default implementation -+template -+CUTE_HOST_DEVICE -+void -+fill(Tensor& tensor, T const& value, prefer<0>) -+{ -+ CUTE_UNROLL -+ for (int i = 0; i < size(tensor); ++i) { -+ tensor(i) = value; -+ } -+} -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE -+void -+fill(Tensor& tensor, T const& value) -+{ -+ return detail::fill(tensor, value, prefer<1>{}); -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/algorithm/functional.hpp b/3rdparty/cutlass/include/cute/algorithm/functional.hpp -new file mode 100644 -index 0000000..e66cd97 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/algorithm/functional.hpp -@@ -0,0 +1,198 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+/** C++14 extensions */ -+ -+namespace cute { -+ -+/**************/ -+/** Identity **/ -+/**************/ -+ -+struct identity { -+ template -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) operator()(T&& arg) const { -+ return std::forward(arg); -+ } -+}; -+ -+template -+struct constant_fn { -+ template -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) operator()(T&&...) const { -+ return r_; -+ } -+ R r_; -+}; -+ -+/***********/ -+/** Unary **/ -+/***********/ -+ -+#define CUTE_LEFT_UNARY_OP(NAME,OP) \ -+ struct NAME { \ -+ template \ -+ CUTE_HOST_DEVICE constexpr \ -+ decltype(auto) operator()(T&& arg) const { \ -+ return OP std::forward(arg); \ -+ } \ -+ } -+#define CUTE_RIGHT_UNARY_OP(NAME,OP) \ -+ struct NAME { \ -+ template \ -+ CUTE_HOST_DEVICE constexpr \ -+ decltype(auto) operator()(T&& arg) const { \ -+ return std::forward(arg) OP ; \ -+ } \ -+ } -+#define CUTE_NAMED_UNARY_OP(NAME,OP) \ -+ struct NAME { \ -+ template \ -+ CUTE_HOST_DEVICE constexpr \ -+ decltype(auto) operator()(T&& arg) const { \ -+ return OP (std::forward(arg)); \ -+ } \ -+ } -+ -+CUTE_LEFT_UNARY_OP(unary_plus, +); -+CUTE_LEFT_UNARY_OP(negate, -); -+CUTE_LEFT_UNARY_OP(bit_not, ~); -+CUTE_LEFT_UNARY_OP(logical_not, !); -+CUTE_LEFT_UNARY_OP(dereference, *); -+CUTE_LEFT_UNARY_OP(address_of, &); -+CUTE_LEFT_UNARY_OP(pre_increment, ++); -+CUTE_LEFT_UNARY_OP(pre_decrement, --); -+ -+CUTE_RIGHT_UNARY_OP(post_increment, ++); -+CUTE_RIGHT_UNARY_OP(post_decrement, --); -+ -+CUTE_NAMED_UNARY_OP(abs_fn, abs); -+CUTE_NAMED_UNARY_OP(conjugate, cute::conj); -+ -+#undef CUTE_LEFT_UNARY_OP -+#undef CUTE_RIGHT_UNARY_OP -+#undef CUTE_NAMED_UNARY_OP -+ -+/************/ -+/** Binary **/ -+/************/ -+ -+#define CUTE_BINARY_OP(NAME,OP) \ -+ struct NAME { \ -+ template \ -+ CUTE_HOST_DEVICE constexpr \ -+ decltype(auto) operator()(T&& lhs, U&& rhs) const { \ -+ return std::forward(lhs) OP std::forward(rhs); \ -+ } \ -+ } -+#define CUTE_NAMED_BINARY_OP(NAME,OP) \ -+ struct NAME { \ -+ template \ -+ CUTE_HOST_DEVICE constexpr \ -+ decltype(auto) operator()(T&& lhs, U&& rhs) const { \ -+ return OP (std::forward(lhs), std::forward(rhs)); \ -+ } \ -+ } -+ -+ -+CUTE_BINARY_OP(plus, +); -+CUTE_BINARY_OP(minus, -); -+CUTE_BINARY_OP(multiplies, *); -+CUTE_BINARY_OP(divides, /); -+CUTE_BINARY_OP(modulus, %); -+ -+CUTE_BINARY_OP(plus_assign, +=); -+CUTE_BINARY_OP(minus_assign, -=); -+CUTE_BINARY_OP(multiplies_assign, *=); -+CUTE_BINARY_OP(divides_assign, /=); -+CUTE_BINARY_OP(modulus_assign, %=); -+ -+CUTE_BINARY_OP(bit_and, &); -+CUTE_BINARY_OP(bit_or, |); -+CUTE_BINARY_OP(bit_xor, ^); -+CUTE_BINARY_OP(left_shift, <<); -+CUTE_BINARY_OP(right_shift, >>); -+ -+CUTE_BINARY_OP(bit_and_assign, &=); -+CUTE_BINARY_OP(bit_or_assign, |=); -+CUTE_BINARY_OP(bit_xor_assign, ^=); -+CUTE_BINARY_OP(left_shift_assign, <<=); -+CUTE_BINARY_OP(right_shift_assign, >>=); -+ -+CUTE_BINARY_OP(logical_and, &&); -+CUTE_BINARY_OP(logical_or, ||); -+ -+CUTE_BINARY_OP(equal_to, ==); -+CUTE_BINARY_OP(not_equal_to, !=); -+CUTE_BINARY_OP(greater, >); -+CUTE_BINARY_OP(less, <); -+CUTE_BINARY_OP(greater_equal, >=); -+CUTE_BINARY_OP(less_equal, <=); -+ -+CUTE_NAMED_BINARY_OP(max_fn, cute::max); -+CUTE_NAMED_BINARY_OP(min_fn, cute::min); -+ -+#undef CUTE_BINARY_OP -+#undef CUTE_NAMED_BINARY_OP -+ -+/**********/ -+/** Meta **/ -+/**********/ -+ -+template -+struct bound_fn { -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ operator()(T&& arg) { -+ return fn_(arg_, std::forward(arg)); -+ } -+ -+ Fn fn_; -+ Arg arg_; -+}; -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+bind(Fn const& fn, Arg const& arg) { -+ return bound_fn{fn, arg}; -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/algorithm/gemm.hpp b/3rdparty/cutlass/include/cute/algorithm/gemm.hpp -new file mode 100644 -index 0000000..6e2ce61 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/algorithm/gemm.hpp -@@ -0,0 +1,718 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+#include -+#include -+ -+/** The gemm algorithm takes four (or three) tensors and computes -+ * D += A * B + C -+ * It dispatches based on the number of modes each tensor has: -+ * -+ * 1. `(V) x (V) => (V)`. -+ * The element-wise product of vectors. Dispatches to FMA or MMA. -+ * 2. `(M) x (N) => (M,N)`. -+ * The outer product of vectors. Dispatches to [3] with new mode K=(1). -+ * 3. `(M,K) x (N,K) => (M,N)`. -+ * The product of matrices. Dispatches to [5] with MMA vector-mode V. -+ * 4. `(V,M) x (V,N) => (V,M,N)`. -+ * The batched outer product of vectors. Accounts for register reuse and dispatches to [1] for each (m,n). -+ * 5. `(V,M,K) x (V,N,K) => (V,M,N)`. -+ * The batched product of matrices. Dispatches to [4] for each (k). -+ */ -+ -+namespace cute -+{ -+ -+// -+// Three arguments to four -+// -+ -+template -+CUTE_HOST_DEVICE -+void -+gemm(Tensor const& A, -+ Tensor const& B, -+ Tensor & C) -+{ -+ return gemm(C, A, B, C); -+} -+ -+template -+CUTE_HOST_DEVICE -+void -+gemm(MMA_Atom const& mma, -+ Tensor const& A, -+ Tensor const& B, -+ Tensor & C) -+{ -+ return gemm(mma, C, A, B, C); -+} -+ -+// -+// Accept mutable temporaries -+// -+ -+template -+CUTE_HOST_DEVICE -+void -+gemm(Tensor const& A, -+ Tensor const& B, -+ Tensor && C) -+{ -+ return gemm(C, A, B, C); -+} -+ -+template -+CUTE_HOST_DEVICE -+void -+gemm(Tensor && D, -+ Tensor const& A, -+ Tensor const& B, -+ Tensor const& C) -+{ -+ return gemm(D, A, B, C); -+} -+ -+template -+CUTE_HOST_DEVICE -+void -+gemm(MMA_Atom const& mma, -+ Tensor const& A, -+ Tensor const& B, -+ Tensor && C) -+{ -+ return gemm(mma, C, A, B, C); -+} -+ -+template -+CUTE_HOST_DEVICE -+void -+gemm(MMA_Atom const& mma, -+ Tensor && D, -+ Tensor const& A, -+ Tensor const& B, -+ Tensor const& C) -+{ -+ return gemm(mma, D, A, B, C); -+} -+ -+// -+// Default MMA is UniversalFMA -+// -+ -+template -+CUTE_HOST_DEVICE -+void -+gemm(Tensor & D, -+ Tensor const& A, -+ Tensor const& B, -+ Tensor const& C) -+{ -+ using MMA = MMA_Atom::value_type, -+ typename Tensor::value_type, -+ typename Tensor::value_type, -+ typename Tensor::value_type>>; -+ -+ return gemm(MMA{}, D, A, B, C); -+} -+ -+// -+// Thread-Local Register-Memory GEMMs -+// -+ -+// Dispatch [1]: (V) x (V) => (V) -+template ::value && -+ ALayout::rank == 1 && is_rmem::value && -+ BLayout::rank == 1 && is_rmem::value && -+ CLayout::rank == 1 && is_rmem::value)> -+CUTE_HOST_DEVICE -+void -+gemm(MMA_Atom const& mma, -+ Tensor & D, // (V) Logical data -+ Tensor const& A, // (V) Logical data -+ Tensor const& B, // (V) Logical data -+ Tensor const& C) // (V) Logical data -+{ -+ // No static assertions on (V), MMA checks compatibility -+ mma.call(D, A, B, C); -+} -+ -+// Dispatch [2]: (M) x (N) => (M,N) -+template ::value && -+ ALayout::rank == 1 && is_rmem::value && -+ BLayout::rank == 1 && is_rmem::value && -+ CLayout::rank == 2 && is_rmem::value)> -+CUTE_HOST_DEVICE -+void -+gemm(MMA_Atom const& mma, -+ Tensor & D, // (M,N) Logical data -+ Tensor const& A, // (M) Logical data -+ Tensor const& B, // (N) Logical data -+ Tensor const& C) // (M,N) Logical data -+{ -+ CUTE_STATIC_ASSERT_V(size<0>(A) == size<0>(C)); // AM == CM -+ CUTE_STATIC_ASSERT_V(size<0>(B) == size<1>(C)); // BN == CN -+ CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D)); -+ -+ gemm(mma, -+ D, // (M,N) -+ make_tensor(A.data(), append<2>(A.layout())), // (M,1) -+ make_tensor(B.data(), append<2>(B.layout())), // (N,1) -+ C); // (M,N) -+} -+ -+// Dispatch [3]: (M,K) x (N,K) => (M,N) -+template ::value && -+ ALayout::rank == 2 && is_rmem::value && -+ BLayout::rank == 2 && is_rmem::value && -+ CLayout::rank == 2 && is_rmem::value)> -+CUTE_HOST_DEVICE -+void -+gemm(MMA_Atom const& mma, -+ Tensor & D, // (M,N) Logical data -+ Tensor const& A, // (M,K) Logical data -+ Tensor const& B, // (N,K) Logical data -+ Tensor const& C) // (M,N) Logical data -+{ -+ CUTE_STATIC_ASSERT_V(size<0>(A) == size<0>(C)); // AM == CM -+ CUTE_STATIC_ASSERT_V(size<0>(B) == size<1>(C)); // BN == CN -+ CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(B)); // AK == BK -+ CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D)); -+ -+ // Assert this is a 1-value MMA -+ CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutC_TV{}) == Int<1>{}); -+ CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutA_TV{}) == Int<1>{}); -+ CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutB_TV{}) == Int<1>{}); -+ -+ gemm(mma, -+ make_tensor(D.data(), prepend<3>(D.layout())), // (1,M,N) -+ make_tensor(A.data(), prepend<3>(A.layout())), // (1,M,K) -+ make_tensor(B.data(), prepend<3>(B.layout())), // (1,N,K) -+ make_tensor(C.data(), prepend<3>(C.layout()))); // (1,M,N) -+} -+ -+// Dispatch [4]: (V,M) x (V,N) => (V,M,N) -+template ::value && -+ ALayout::rank == 2 && is_rmem::value && -+ BLayout::rank == 2 && is_rmem::value && -+ CLayout::rank == 3 && is_rmem::value)> -+CUTE_HOST_DEVICE -+void -+gemm(MMA_Atom const& mma, -+ Tensor & D, // (V,M,N) Logical data -+ Tensor const& A, // (V,M) Logical data -+ Tensor const& B, // (V,N) Logical data -+ Tensor const& C) // (V,M,N) Logical data -+{ -+ CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(C)); // AM == CM -+ CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C)); // BN == CN -+ CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D)); -+ -+ // REGISTER .reuse OPTIMIZATIONS -+ -+ auto M = size<1>(A); -+ auto N = size<1>(B); -+ -+ // 64-bit traversal specialization -- serpentine path -+ if (size<0>(A) * sizeof(typename Tensor::value_type) == 8 && -+ size<0>(B) * sizeof(typename Tensor::value_type) == 8) -+ { -+#if 1 // NOTE: Must depend on the C-matrix order... (which we can test) -+ // Row-major iteration -+ CUTE_UNROLL -+ for (int m = 0; m < M; ++m) { -+ CUTE_UNROLL -+ for (int n = 0; n < N; ++n) { -+ int ns = (m & 1) ? N-1-n : n; // Serpentine coordinate -+ gemm(mma, D(_,m,ns), A(_,m), B(_,ns), C(_,m,ns)); -+ } -+ } -+#else -+ // Col-major iteration -+ CUTE_UNROLL -+ for (int n = 0; n < N; ++n) { -+ CUTE_UNROLL -+ for (int m = 0; m < M; ++m) { -+ int ms = (n & 1) ? M-1-m : m; // Serpentine coordinate -+ gemm(mma, D(_,ms,n), A(_,ms), B(_,n), C(_,ms,n)); -+ } -+ } -+#endif -+ } else -+ -+ // 32-bit traversal specialization -- kinked serpentine path -+ if (size<0>(A) * sizeof(typename Tensor::value_type) == 4 && -+ size<0>(B) * sizeof(typename Tensor::value_type) == 4) -+ { -+#if 1 // NOTE: Must depend on the C-matrix order... (which we can test) -+ // Row-major iteration -+ CUTE_UNROLL -+ for (int m = 0; m < M; m += 2) { -+ CUTE_UNROLL -+ for (int n = 0; n < N; ++n) { -+ int ns = (m & 2) ? N-1-n : n; -+ gemm(mma, D(_,m+0,ns), A(_,m+0), B(_,ns), C(_,m+0,ns)); -+ -+ if (m+1 < M) { -+ gemm(mma, D(_,m+1,ns), A(_,m+1), B(_,ns), C(_,m+1,ns)); -+ } -+ } -+ } -+#else -+ // Col-major iteration -+ CUTE_UNROLL -+ for (int n = 0; n < N; n += 2) { -+ CUTE_UNROLL -+ for (int m = 0; m < M; ++m) { -+ // Kinked serpentine traversal for maximum register reuse -+ int ms = (n & 2) ? M-1-m : m; -+ gemm(mma, D(_,ms,n+0), A(_,ms), B(_,n+0), C(_,ms,n+0)); -+ -+ if (n+1 < N) { -+ gemm(mma, D(_,ms,n+1), A(_,ms), B(_,n+1), C(_,ms,n+1)); -+ } -+ } -+ } -+#endif -+ } else { -+ // Fallback to serpentine loop -+ // Col-major iteration -+ CUTE_UNROLL -+ for (int n = 0; n < N; ++n) { -+ CUTE_UNROLL -+ for (int m = 0; m < M; ++m) { -+ int ms = (n & 1) ? M-1-m : m; // Serpentine coordinate -+ gemm(mma, D(_,ms,n), A(_,ms), B(_,n), C(_,ms,n)); -+ } -+ } -+ } -+} -+ -+// Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N) -+template ::value && -+ ALayout::rank == 3 && is_rmem::value && -+ BLayout::rank == 3 && is_rmem::value && -+ CLayout::rank == 3 && is_rmem::value)> -+CUTE_HOST_DEVICE -+void -+gemm(MMA_Atom const& mma, -+ Tensor & D, // (V,M,N) Logical data -+ Tensor const& A, // (V,M,K) Logical data -+ Tensor const& B, // (V,N,K) Logical data -+ Tensor const& C) // (V,M,N) Logical data -+{ -+ CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(C)); // AM == CM -+ CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C)); // BN == CN -+ CUTE_STATIC_ASSERT_V(size<2>(A) == size<2>(B)); // AK == BK -+ CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D)); -+ -+ auto K = size<2>(A); -+ -+ CUTE_UNROLL -+ for (int k = 0; k < K; ++k) { -+ gemm(mma, D, A(_,_,k), B(_,_,k), C); -+ } -+} -+ -+// -+// Thread-Local Shared-Memory GEMMs -+// -+ -+// Dispatch [1]: (V) x (V) => (V) -+// Dispatch [2]: (M) x (N) => (M,N) -+// Dispatch [3]: (M,K) x (N,K) => (M,N) -+// Dispatch [4]: (V,M) x (V,N) => (V,M,N) -+// Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N) -+// Dispatch [3]: (M,K) x (N,K) => (M,N) -+template ::value && -+ ALayout::rank == 2 && is_smem::value && -+ BLayout::rank == 2 && is_smem::value && -+ CLayout::rank == 2 && is_rmem::value)> -+CUTE_HOST_DEVICE -+void -+gemm(MMA_Atom const& mma, -+ Tensor & D, // (M,N) Logical data -+ Tensor const& A, // (M,K) Logical data -+ Tensor const& B, // (N,K) Logical data -+ Tensor const& C) // (M,N) Logical data -+{ -+ CUTE_STATIC_ASSERT_V(size<0>(A) == size<0>(C)); // AM == CM -+ CUTE_STATIC_ASSERT_V(size<0>(B) == size<1>(C)); // BN == CN -+ CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(B)); // AK == BK -+ CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D)); -+ -+ // Assert this is a 1-value MMA -+ CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutC_TV{}) == Int<1>{}); -+ CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutA_TV{}) == Int<1>{}); -+ CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutB_TV{}) == Int<1>{}); -+ -+ gemm(mma, -+ make_tensor(D.data(), prepend<3>(D.layout())), // (1,M,N) -+ make_tensor(A.data(), prepend<3>(A.layout())), // (1,M,K) -+ make_tensor(B.data(), prepend<3>(B.layout())), // (1,N,K) -+ make_tensor(C.data(), prepend<3>(C.layout()))); // (1,M,N) -+} -+ -+// Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N) -+template ::value && -+ ALayout::rank == 3 && is_smem::value && -+ BLayout::rank == 3 && is_smem::value && -+ CLayout::rank == 3 && is_rmem::value)> -+CUTE_HOST_DEVICE -+void -+gemm(MMA_Atom const& mma, -+ Tensor & D, // (V,M,N) Logical data -+ Tensor const& A, // (V,M,K) Logical data -+ Tensor const& B, // (V,N,K) Logical data -+ Tensor const& C) // (V,M,N) Logical data -+{ -+ CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(C)); // AM == CM -+ CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C)); // BN == CN -+ CUTE_STATIC_ASSERT_V(size<2>(A) == size<2>(B)); // AK == BK -+ CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D)); -+ -+ auto rA = MMA_Atom::make_fragment_A(A); -+ auto rB = MMA_Atom::make_fragment_B(B); -+ -+ auto K = size<2>(A); -+ -+ CUTE_UNROLL -+ for (int k = 0; k < K; ++k) -+ { -+ copy(A(_,_,k), rA(_,_,k)); -+ copy(B(_,_,k), rB(_,_,k)); -+ // Thread-level register gemm for k -+ gemm(mma, D, rA(_,_,k), rB(_,_,k), C); -+ } -+} -+ -+// -+// Collective Shared-Memory GEMMs -+// -+ -+template ::value && -+ BLayout::rank == 2 && is_smem::value && -+ CLayout::rank == 2 && is_smem::value)> -+CUTE_HOST_DEVICE -+void -+gemm(ThrMMA const& thr_mma, -+ Alpha const& alpha, -+ Tensor sA, -+ Tensor sB, -+ Beta const& beta, -+ Tensor sC, -+ ALoadTransformOp const& sA_load_op /* transforms A values before used in GEMM */, -+ BLoadTransformOp const& sB_load_op /* transforms B values before used in GEMM */) -+{ -+ CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC)); // AM == CM -+ CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC)); // BN == CN -+ CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK -+ -+ using TypeA = typename TA::value_type; -+ using TypeB = typename TB::value_type; -+ using TypeC = typename TC::value_type; -+ -+ static_assert(std::is_same_v>, TypeA>, -+ "ALoadTransformOp functor must accept and return value of type TA::value_type"); -+ static_assert(std::is_same_v>, TypeB>, -+ "BLoadTransformOp functor must accept and return value of type TB::value_type"); -+ -+ // Original, static size of the problem -+ auto M = size<0>(sC); -+ auto N = size<1>(sC); -+ auto K = size<1>(sA); -+ -+ // Block size of the compute tile -+ auto BLK_M = tile_size<0>(thr_mma); -+ auto BLK_N = tile_size<1>(thr_mma); -+ auto BLK_K = tile_size<2>(thr_mma); -+ -+ // Compute the "residues" -+ auto m_residue = M - BLK_M * (ceil_div(M, BLK_M) - Int<1>{}); // (0,BLK_M] -+ auto n_residue = N - BLK_N * (ceil_div(N, BLK_N) - Int<1>{}); // (0,BLK_N] -+ auto k_residue = K - BLK_K * (ceil_div(K, BLK_K) ); // (-BLK_K,0] -+ -+ // Shift the origin so k_residue is zeroth tile -+ sA.data() = &sA(0,k_residue); -+ sB.data() = &sB(0,k_residue); -+ -+#if 0 -+ if (thread0()) { -+ printf("%d in BLK_M (%d)\n", int(m_residue), int(BLK_M)); -+ printf("%d in BLK_N (%d)\n", int(n_residue), int(BLK_N)); -+ printf("%d in BLK_K (%d)\n", int(k_residue), int(BLK_K)); -+ } -+#endif -+ -+ // -+ // MMA Partitioning -+ // -+ -+ // Round the layout extents up to BLK_X -+ Tensor rounded_sA = sA.compose(make_shape(ceil_div(M, BLK_M) * BLK_M, ceil_div(K, BLK_K) * BLK_K)); -+ Tensor rounded_sB = sB.compose(make_shape(ceil_div(N, BLK_N) * BLK_N, ceil_div(K, BLK_K) * BLK_K)); -+ Tensor rounded_sC = sC.compose(make_shape(ceil_div(M, BLK_M) * BLK_M, ceil_div(N, BLK_N) * BLK_N)); -+ -+#if 0 -+ if (thread0()) { -+ print(rounded_sA.layout()); print("\n"); -+ print(rounded_sB.layout()); print("\n"); -+ print(rounded_sC.layout()); print("\n"); -+ } -+#endif -+ -+ // Partition the sA and sB tiles across the threads for the MMA -+ Tensor tCsA = thr_mma.partition_A(rounded_sA); // (MMA,MMA_M,MMA_K) -+ Tensor tCsB = thr_mma.partition_B(rounded_sB); // (MMA,MMA_N,MMA_K) -+ Tensor tCsC = thr_mma.partition_C(rounded_sC); // (MMA,MMA_M,MMA_N) -+ // Create register tensors for the MMA to operate on -+ Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K) -+ Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K) -+ Tensor tCrC = thr_mma.make_fragment_C(tCsC); // (MMA,MMA_M,MMA_N) -+ -+#if 0 -+ if (thread0()) { -+ print(tCsA.layout()); print("\n"); -+ print(tCsB.layout()); print("\n"); -+ print(tCsC.layout()); print("\n"); -+ print(tCrA.layout()); print("\n"); -+ print(tCrB.layout()); print("\n"); -+ print(tCrC.layout()); print("\n"); -+ } -+#endif -+ -+ // -+ // PREDICATION -+ // -+ -+ // Allocate the preds for only the MMA-mode of tCsA and tCsB -+ Tensor tCpA = make_tensor(size<0>(tCsA)); -+ Tensor tCpB = make_tensor(size<0>(tCsB)); -+ -+ // Create coordinate tensors on a single compute block for predication -+ Tensor cA = make_identity_tensor(make_shape(BLK_M, BLK_K)); // (BLK_M,BLK_K) -> (blk_m,blk_k) -+ Tensor cB = make_identity_tensor(make_shape(BLK_N, BLK_K)); // (BLK_M,BLK_K) -> (blk_n,blk_k) -+ -+ // Repeat partitioning with thr_mma -+ Tensor tCcA = thr_mma.partition_A(cA); // (MMA,1,1) -> (blk_m,blk_k) -+ Tensor tCcB = thr_mma.partition_B(cB); // (MMA,1,1) -> (blk_n,blk_k) -+ -+ // Populate the m and n predicates -+ CUTE_UNROLL -+ for (int i = 0; i < size(tCpA); ++i) { -+ tCpA(i) = elem_less(get<0>(tCcA(i)), m_residue); -+ } -+ CUTE_UNROLL -+ for (int i = 0; i < size(tCpB); ++i) { -+ tCpB(i) = elem_less(get<0>(tCcB(i)), n_residue); -+ } -+ -+#if 0 -+ printf("Thr %d: A(%d,%d):%d B(%d,%d):%d\n", -+ threadIdx.x, -+ int(get<0>(tCcA(0))), int(get<1>(tCcA(0))), int(tCpA(0)), -+ int(get<0>(tCcB(0))), int(get<1>(tCcB(0))), int(tCpB(0))); -+#endif -+ -+ // -+ // PREFETCH k_block = 0 (with k-predication) -+ // -+ -+ CUTE_UNROLL -+ for (int i = 0; i < size<0>(tCsA); ++i) { // Copy MMA_I -+ if (k_residue == 0 || get<1>(tCcA(i)) >= -k_residue) { // k_block = 0, predicated on k -+ CUTE_UNROLL -+ for (int m = 0; m < size<1>(tCsA); ++m) { // Copy MMA_M, predicated on m -+ tCrA(i,m,0) = (m_residue == BLK_M || m < size<1>(tCsA)-1 || tCpA(i)) ? sA_load_op(tCsA(i,m,0)) : TypeA{}; -+ } -+ } -+ } -+ -+ CUTE_UNROLL -+ for (int i = 0; i < size<0>(tCsB); ++i) { // Copy MMA_I -+ if (k_residue == 0 || get<1>(tCcB(i)) >= -k_residue) { // k_block = 0, predicated on k -+ CUTE_UNROLL -+ for (int n = 0; n < size<1>(tCsB); ++n) { // Copy MMA_N, predicated on n -+ tCrB(i,n,0) = (n_residue == BLK_N || n < size<1>(tCsB)-1 || tCpB(i)) ? sB_load_op(tCsB(i,n,0)) : TypeB{}; -+ } -+ } -+ } -+ // -+ // MAINLOOP -+ // -+ -+ // Clear accumulators -+ clear(tCrC); -+ -+ constexpr int K_BLOCK_MAX = size<2>(tCrA); -+ -+ CUTE_UNROLL -+ for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) -+ { -+ // static-if load the next k_block. No k-predication required on these loads. -+ if (k_block < K_BLOCK_MAX-1) -+ { -+ // Load the next k_block -+ int k_next = k_block + 1; -+ -+ CUTE_UNROLL -+ for (int m = 0; m < size<1>(tCsA); ++m) { // Copy MMA_M -+ CUTE_UNROLL -+ for (int i = 0; i < size<0>(tCsA); ++i) { // Copy_if MMA_I predicated on m -+ tCrA(i,m,k_next) = (m_residue == BLK_M || m < size<1>(tCsA)-1 || tCpA(i)) ? sA_load_op(tCsA(i,m,k_next)) : TypeA{}; -+ } -+ } -+ -+ CUTE_UNROLL -+ for (int n = 0; n < size<1>(tCsB); ++n) { // Copy MMA_N -+ CUTE_UNROLL -+ for (int i = 0; i < size<0>(tCsB); ++i) { // Copy MMA_I predicated on n -+ tCrB(i,n,k_next) = (n_residue == BLK_N || n < size<1>(tCsB)-1 || tCpB(i)) ? sB_load_op(tCsB(i,n,k_next)) : TypeB{}; -+ } -+ } -+ } -+ -+ // GEMM on k_block in registers -+ gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); -+ } -+ -+ // -+ // Epilogue -+ // -+ -+ Tensor cC = make_identity_tensor(make_shape(BLK_M, BLK_N)); // (BLK_M,BLK_N) -> (blk_m,blk_n) -+ Tensor tCcC = thr_mma.partition_C(cC); // (MMA, 1, 1) -> (blk_m,blk_n) -+ -+ const bool isBetaZero = (beta == Beta{}); -+ -+ // Custom axpby_if for now -+ CUTE_UNROLL -+ for (int m = 0; m < size<1>(tCsC); ++m) -+ { -+ CUTE_UNROLL -+ for (int n = 0; n < size<2>(tCsC); ++n) -+ { -+ CUTE_UNROLL -+ for (int i = 0; i < size<0>(tCsC); ++i) -+ { -+ if ((m_residue == BLK_M || m < size<1>(tCrC)-1 || get<0>(tCcC(i)) < m_residue) && -+ (n_residue == BLK_N || n < size<2>(tCrC)-1 || get<1>(tCcC(i)) < n_residue)) -+ { -+ tCsC(i,m,n) = isBetaZero ? alpha * tCrC(i,m,n) : alpha * tCrC(i,m,n) + beta * tCsC(i,m,n); -+ } -+ } -+ } -+ } -+} -+ -+template ::value && -+ BLayout::rank == 2 && is_smem::value && -+ CLayout::rank == 2 && is_smem::value)> -+CUTE_HOST_DEVICE -+void -+gemm(ThrMMA const& thr_mma, -+ Alpha const& alpha, -+ Tensor sA, -+ Tensor sB, -+ Beta const& beta, -+ Tensor sC) -+{ -+ gemm(thr_mma, alpha, sA, sB, beta, sC, identity() /* sA_load_op */, identity() /* sB_load_op */); -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/algorithm/prefer.hpp b/3rdparty/cutlass/include/cute/algorithm/prefer.hpp -new file mode 100644 -index 0000000..700edff ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/algorithm/prefer.hpp -@@ -0,0 +1,46 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+namespace cute -+{ -+ -+// Infinite types that inherit from each other -+template -+struct prefer : prefer {}; -+ -+template <> -+struct prefer<0> {}; -+ -+// Can be used to preferencially overload implementations -+// Higher N in prefer have higher priority. -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/algorithm/tensor_algorithms.hpp b/3rdparty/cutlass/include/cute/algorithm/tensor_algorithms.hpp -new file mode 100644 -index 0000000..258ddec ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/algorithm/tensor_algorithms.hpp -@@ -0,0 +1,102 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/** Common algorithms on (hierarchical) tensors */ -+ -+#pragma once -+ -+#include -+ -+#include -+ -+namespace cute -+{ -+ -+// -+// for_each -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+void -+for_each(Tensor const& tensor, UnaryOp&& op) -+{ -+ CUTE_UNROLL -+ for (int i = 0; i < size(tensor); ++i) { -+ static_cast(op)(tensor(i)); -+ } -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+void -+for_each(Tensor& tensor, UnaryOp&& op) -+{ -+ CUTE_UNROLL -+ for (int i = 0; i < size(tensor); ++i) { -+ static_cast(op)(tensor(i)); -+ } -+} -+ -+// Accept mutable temporaries -+template -+CUTE_HOST_DEVICE constexpr -+void -+for_each(Tensor&& tensor, UnaryOp&& op) -+{ -+ return for_each(tensor, static_cast(op)); -+} -+ -+// -+// transform -+// -+ -+// Similar to std::transform but does not return number of elements affected -+template -+CUTE_HOST_DEVICE constexpr -+void -+transform(Tensor& tensor, UnaryOp&& op) -+{ -+ CUTE_UNROLL -+ for (int i = 0; i < size(tensor); ++i) { -+ tensor(i) = static_cast(op)(tensor(i)); -+ } -+} -+ -+// Accept mutable temporaries -+template -+CUTE_HOST_DEVICE constexpr -+void -+transform(Tensor&& tensor, UnaryOp&& op) -+{ -+ return transform(tensor, std::forward(op)); -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/algorithm/tuple_algorithms.hpp b/3rdparty/cutlass/include/cute/algorithm/tuple_algorithms.hpp -new file mode 100644 -index 0000000..35b19f9 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/algorithm/tuple_algorithms.hpp -@@ -0,0 +1,846 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+#include -+#include -+#include -+ -+/** Common algorithms on (hierarchical) tuples */ -+/** Style choice: -+ * Forward params [using static_cast(.)] for const/non-const/ref/non-ref args -+ * but don't bother forwarding functions as ref-qualified member fns are extremely rare -+ */ -+ -+namespace cute -+{ -+ -+// -+// Apply (Unpack) -+// (t, f) => f(t_0,t_1,...,t_n) -+// -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+apply(T&& t, F&& f, seq) -+{ -+ return f(get(static_cast(t))...); -+} -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+apply(T&& t, F&& f) -+{ -+ return detail::apply(static_cast(t), f, tuple_seq{}); -+} -+ -+// -+// Transform Apply -+// (t, f, g) => g(f(t_0),f(t_1),...) -+// -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tapply(T&& t, F&& f, G&& g, seq) -+{ -+ return g(f(get(static_cast(t)))...); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tapply(T0&& t0, T1&& t1, F&& f, G&& g, seq) -+{ -+ return g(f(get(static_cast(t0)), -+ get(static_cast(t1)))...); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tapply(T0&& t0, T1&& t1, T2&& t2, F&& f, G&& g, seq) -+{ -+ return g(f(get(static_cast(t0)), -+ get(static_cast(t1)), -+ get(static_cast(t2)))...); -+} -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+transform_apply(T&& t, F&& f, G&& g) -+{ -+ return detail::tapply(static_cast(t), f, g, tuple_seq{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+transform_apply(T0&& t0, T1&& t1, F&& f, G&& g) -+{ -+ return detail::tapply(static_cast(t0), static_cast(t1), f, g, tuple_seq{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+transform_apply(T0&& t0, T1&& t1, T2&& t2, F&& f, G&& g) -+{ -+ return detail::tapply(static_cast(t0), static_cast(t1), static_cast(t2), f, g, tuple_seq{}); -+} -+ -+// -+// For Each -+// (t, f) => f(t_0),f(t_1),...,f(t_n) -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+void -+for_each(T&& t, F&& f) -+{ -+ detail::apply(t, [&](auto&&... a) { (f(static_cast(a)), ...); }, tuple_seq{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+for_each_leaf(T&& t, F&& f) -+{ -+ if constexpr (is_tuple>::value) { -+ return detail::apply(static_cast(t), [&](auto&&... a){ return (for_each_leaf(static_cast(a), f), ...); }, tuple_seq{}); -+ } else { -+ return f(static_cast(t)); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// Transform -+// (t, f) => (f(t_0),f(t_1),...,f(t_n)) -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+transform(T const& t, F&& f) -+{ -+ return detail::tapply(t, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+transform(T0 const& t0, T1 const& t1, F&& f) -+{ -+ static_assert(tuple_size::value == tuple_size::value, "Mismatched tuple_size"); -+ return detail::tapply(t0, t1, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+transform(T0 const& t0, T1 const& t1, T2 const& t2, F&& f) -+{ -+ static_assert(tuple_size::value == tuple_size::value, "Mismatched tuple_size"); -+ static_assert(tuple_size::value == tuple_size::value, "Mismatched tuple_size"); -+ return detail::tapply(t0, t1, t2, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+transform_leaf(T const& t, F&& f) -+{ -+ if constexpr (is_tuple::value) { -+ return transform(t, [&](auto const& a) { return transform_leaf(a, f); }); -+ } else { -+ return f(t); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// find and find_if -+// -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+find_if(T const& t, F&& f, seq<>) -+{ -+ return cute::integral_constant::value>{}; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+find_if(T const& t, F&& f, seq) -+{ -+ if constexpr (decltype(f(get(t)))::value) { -+ return cute::integral_constant{}; -+ } else { -+ return find_if(t, f, seq{}); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+find_if(T const& t, F&& f) -+{ -+ if constexpr (is_tuple::value) { -+ return detail::find_if(t, f, tuple_seq{}); -+ } else { -+ return cute::integral_constant{}; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+find(T const& t, X const& x) -+{ -+ return find_if(t, [&](auto const& v) { return v == x; }); // This should always return a static true/false -+} -+ -+template -+auto -+none_of(T const& t, F&& f) -+{ -+ return cute::integral_constant::value>{}; -+} -+ -+template -+auto -+all_of(T const& t, F&& f) -+{ -+ auto not_f = [&](auto const& a) { return !f(a); }; -+ return cute::integral_constant::value>{}; -+} -+ -+template -+auto -+any_of(T const& t, F&& f) -+{ -+ return cute::integral_constant{}; -+} -+ -+// -+// Filter -+// (t, f) => -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+filter_tuple(T const& t, F&& f) -+{ -+ return transform_apply(t, f, [](auto const&... a) { return cute::tuple_cat(a...); }); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+filter_tuple(T0 const& t0, T1 const& t1, F&& f) -+{ -+ return transform_apply(t0, t1, f, [](auto const&... a) { return cute::tuple_cat(a...); }); -+} -+ -+// -+// Fold (Reduce, Accumulate) -+// (t, v, f) => f(...f(f(v,t_0),t_1),...,t_n) -+// -+ -+namespace detail { -+ -+// This impl compiles much faster than cute::apply and variadic args -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+fold(T&& t, V&& v, F&& f, seq<>) -+{ -+ return static_cast(v); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+fold(T&& t, V&& v, F&& f, seq) -+{ -+ if constexpr (sizeof...(Is) == 0) { -+ return f(static_cast(v), get(static_cast(t))); -+ } else { -+ return fold(static_cast(t), -+ f(static_cast(v), get(static_cast(t))), -+ f, -+ seq{}); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+fold(T&& t, V&& v, F&& f) -+{ -+ if constexpr (is_tuple>::value) { -+ return detail::fold(static_cast(t), -+ static_cast(v), -+ f, -+ tuple_seq{}); -+ } else { -+ return f(static_cast(v), static_cast(t)); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+fold_first(T&& t, F&& f) -+{ -+ if constexpr (is_tuple>::value) { -+ return detail::fold(static_cast(t), -+ get<0>(static_cast(t)), -+ f, -+ make_range<1,std::tuple_size>::value>{}); -+ } else { -+ return static_cast(t); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// front, back, take, unwrap -+// -+ -+// Get the first non-tuple element in a hierarchical tuple -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+front(T&& t) -+{ -+ if constexpr (is_tuple>::value) { -+ return front(get<0>(static_cast(t))); -+ } else { -+ return static_cast(t); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// Get the last non-tuple element in a hierarchical tuple -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+back(T&& t) -+{ -+ if constexpr (is_tuple>::value) { -+ constexpr int N = tuple_size>::value; -+ return back(get(static_cast(t))); -+ } else { -+ return static_cast(t); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// Takes the elements in the range [B,E) -+template -+CUTE_HOST_DEVICE constexpr -+auto -+take(T const& t) -+{ -+ return detail::apply(t, [](auto const&... a) { return cute::make_tuple(a...); }, make_range{}); -+} -+ -+// Unwrap rank-1 tuples until we're left with a rank>1 tuple or a non-tuple -+template -+CUTE_HOST_DEVICE constexpr -+auto -+unwrap(T const& t) -+{ -+ if constexpr (is_tuple::value) { -+ if constexpr (tuple_size::value == 1) { -+ return unwrap(get<0>(t)); -+ } else { -+ return t; -+ } -+ } else { -+ return t; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// Flatten a hierarchical tuple to a tuple of depth one. -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+flatten_to_tuple(T const& t) -+{ -+ if constexpr (is_tuple::value) { -+ return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); }); -+ } else { -+ return cute::make_tuple(t); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+flatten(T const& t) -+{ -+ if constexpr (is_tuple::value) { -+ return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); }); -+ } else { -+ return t; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// insert and remove and replace -+// -+ -+namespace detail { -+ -+// Shortcut around tuple_cat for common insert/remove/repeat cases -+template -+CUTE_HOST_DEVICE constexpr -+auto -+construct(T const& t, X const& x, seq, seq, seq) -+{ -+ return cute::make_tuple(get(t)..., (void(J),x)..., get(t)...); -+} -+ -+} // end namespace detail -+ -+// Insert x into the Nth position of the tuple -+template -+CUTE_HOST_DEVICE constexpr -+auto -+insert(T const& t, X const& x) -+{ -+ return detail::construct(t, x, make_seq{}, seq<0>{}, make_range::value>{}); -+} -+ -+// Remove the Nth element of the tuple -+template -+CUTE_HOST_DEVICE constexpr -+auto -+remove(T const& t) -+{ -+ return detail::construct(t, 0, make_seq{}, seq<>{}, make_range::value>{}); -+} -+ -+// Replace the Nth element of the tuple with x -+template -+CUTE_HOST_DEVICE constexpr -+auto -+replace(T const& t, X const& x) -+{ -+ return detail::construct(t, x, make_seq{}, seq<0>{}, make_range::value>{}); -+} -+ -+// Replace the first element of the tuple with x -+template -+CUTE_HOST_DEVICE constexpr -+auto -+replace_front(T const& t, X const& x) -+{ -+ if constexpr (is_tuple::value) { -+ return detail::construct(t, x, seq<>{}, seq<0>{}, make_range<1,tuple_size::value>{}); -+ } else { -+ return x; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// Replace the last element of the tuple with x -+template -+CUTE_HOST_DEVICE constexpr -+auto -+replace_back(T const& t, X const& x) -+{ -+ if constexpr (is_tuple::value) { -+ return detail::construct(t, x, make_seq::value-1>{}, seq<0>{}, seq<>{}); -+ } else { -+ return x; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// Make a tuple of Xs of tuple_size N -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+repeat(X const& x) -+{ -+ return detail::construct(0, x, seq<>{}, make_seq{}, seq<>{}); -+} -+ -+// -+// Make a tuple of Xs the same profile as tuple -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+repeat_like(T const& t, X const& x) -+{ -+ if constexpr (is_tuple::value) { -+ return transform(t, [&](auto const& a) { return repeat_like(a,x); }); -+ } else { -+ return x; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// Group the elements [B,E) of a T into a single element -+// e.g. group<2,4>(T<_1,_2,_3,_4,_5,_6>{}) -+// => T<_1,_2,T<_3,_4>,_5,_6>{} -+template -+CUTE_HOST_DEVICE constexpr -+auto -+group(T const& t) -+{ -+ return detail::construct(t, take(t), make_seq{}, seq<0>{}, make_range::value>{}); -+} -+ -+// -+// Extend a T to rank N by appending/prepending an element -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+append(T const& a, X const& x) -+{ -+ if constexpr (is_tuple::value) { -+ if constexpr (N == tuple_size::value) { -+ return a; -+ } else { -+ static_assert(N > tuple_size::value); -+ return detail::construct(a, x, make_seq::value>{}, make_seq::value>{}, seq<>{}); -+ } -+ } else { -+ if constexpr (N == 1) { -+ return a; -+ } else { -+ return detail::construct(cute::make_tuple(a), x, seq<0>{}, make_seq{}, seq<>{}); -+ } -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+template -+CUTE_HOST_DEVICE constexpr -+auto -+append(T const& a, X const& x) -+{ -+ if constexpr (is_tuple::value) { -+ return detail::construct(a, x, make_seq::value>{}, seq<0>{}, seq<>{}); -+ } else { -+ return cute::make_tuple(a, x); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+prepend(T const& a, X const& x) -+{ -+ if constexpr (is_tuple::value) { -+ if constexpr (N == tuple_size::value) { -+ return a; -+ } else { -+ static_assert(N > tuple_size::value); -+ return detail::construct(a, x, seq<>{}, make_seq::value>{}, make_seq::value>{}); -+ } -+ } else { -+ if constexpr (N == 1) { -+ return a; -+ } else { -+ static_assert(N > 1); -+ return detail::construct(cute::make_tuple(a), x, seq<>{}, make_seq{}, seq<0>{}); -+ } -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+template -+CUTE_HOST_DEVICE constexpr -+auto -+prepend(T const& a, X const& x) -+{ -+ if constexpr (is_tuple::value) { -+ return detail::construct(a, x, seq<>{}, seq<0>{}, make_seq::value>{}); -+ } else { -+ return cute::make_tuple(x, a); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// Inclusive scan (prefix sum) -+// -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+iscan(T const& t, V const& v, F&& f, seq) -+{ -+ // Apply the function to v and the element at I -+ auto v_next = f(v, get(t)); -+ // Replace I with v_next -+ auto t_next = replace(t, v_next); -+ -+#if 0 -+ std::cout << "ISCAN i" << I << std::endl; -+ std::cout << " t " << t << std::endl; -+ std::cout << " i " << v << std::endl; -+ std::cout << " f(i,t) " << v_next << std::endl; -+ std::cout << " t_n " << t_next << std::endl; -+#endif -+ -+ if constexpr (sizeof...(Is) == 0) { -+ return t_next; -+ } else { -+ return iscan(t_next, v_next, f, seq{}); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+iscan(T const& t, V const& v, F&& f) -+{ -+ return detail::iscan(t, v, f, tuple_seq{}); -+} -+ -+// -+// Exclusive scan (prefix sum) -+// -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+escan(T const& t, V const& v, F&& f, seq) -+{ -+ if constexpr (sizeof...(Is) == 0) { -+ // Replace I with v -+ return replace(t, v); -+ } else { -+ // Apply the function to v and the element at I -+ auto v_next = f(v, get(t)); -+ // Replace I with v -+ auto t_next = replace(t, v); -+ -+#if 0 -+ std::cout << "ESCAN i" << I << std::endl; -+ std::cout << " t " << t << std::endl; -+ std::cout << " i " << v << std::endl; -+ std::cout << " f(i,t) " << v_next << std::endl; -+ std::cout << " t_n " << t_next << std::endl; -+#endif -+ -+ // Recurse -+ return escan(t_next, v_next, f, seq{}); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+escan(T const& t, V const& v, F&& f) -+{ -+ return detail::escan(t, v, f, tuple_seq{}); -+} -+ -+// -+// Zip (Transpose) -+// -+ -+// Take ((a,b,c,...),(x,y,z,...),...) rank-R0 x rank-R1 input -+// to produce ((a,x,...),(b,y,...),(c,z,...),...) rank-R1 x rank-R0 output -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+zip_(T const& t, seq) -+{ -+ return cute::make_tuple(get(get(t))...); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+zip(T const& t, seq, seq) -+{ -+ static_assert(conjunction>::value == tuple_size>::value>...>::value, "Mismatched Ranks"); -+ return cute::make_tuple(detail::zip_(t, seq{})...); -+} -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+zip(T const& t) -+{ -+ if constexpr (is_tuple::value) { -+ if constexpr (is_tuple>::value) { -+ return detail::zip(t, tuple_seq{}, tuple_seq>{}); -+ } else { -+ return cute::make_tuple(t); -+ } -+ } else { -+ return t; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// Convenient to pass them in separately -+template -+CUTE_HOST_DEVICE constexpr -+auto -+zip(T0 const& t0, T1 const& t1, Ts const&... ts) -+{ -+ return zip(cute::make_tuple(t0, t1, ts...)); -+} -+ -+// -+// zip2_by -- A guided zip for rank-2 tuples -+// Take a tuple like ((A,a),((B,b),(C,c)),d) -+// and produce a tuple ((A,(B,C)),(a,(b,c),d)) -+// where the rank-2 modes are selected by the terminals of the guide (X,(X,X)) -+// -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+zip2_by(T const& t, TG const& guide, seq, seq) -+{ -+ // zip2_by produces the modes like ((A,a),(B,b),...) -+ auto split = cute::make_tuple(zip2_by(get(t), get(guide))...); -+ -+ // Rearrange and append missing modes from t to make ((A,B,...),(a,b,...,x,y)) -+ return cute::make_tuple(cute::make_tuple(get(split)...), -+ cute::make_tuple(get(split)..., get(t)...)); -+} -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+zip2_by(T const& t, TG const& guide) -+{ -+ if constexpr (is_tuple::value) { -+ constexpr int TR = tuple_size::value; -+ constexpr int GR = tuple_size::value; -+ static_assert(TR >= GR, "Mismatched ranks"); -+ return detail::zip2_by(t, guide, -+ make_range< 0, GR>{}, -+ make_range{}); -+ } else { -+ static_assert(tuple_size::value == 2, "Mismatched ranks"); -+ return t; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/arch/cluster_sm90.hpp b/3rdparty/cutlass/include/cute/arch/cluster_sm90.hpp -new file mode 100644 -index 0000000..6fd9edd ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/arch/cluster_sm90.hpp -@@ -0,0 +1,190 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+// Config -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && \ -+ ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8)))) -+# define CUTE_ARCH_CLUSTER_SM90_ENABLED -+#endif -+ -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)) -+# define CUTE_ARCH_ELECT_ONE_SM90_ENABLED -+#endif -+ -+namespace cute { -+ -+CUTE_DEVICE void cluster_arrive_relaxed() -+{ -+#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) -+ asm volatile("barrier.cluster.arrive.relaxed.aligned;\n" : : ); -+#else -+ asm volatile ("brkpt;\n" ::); -+#endif -+} -+ -+CUTE_DEVICE void cluster_arrive() -+{ -+#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) -+ asm volatile("barrier.cluster.arrive.aligned;\n" : : ); -+#else -+ asm volatile ("brkpt;\n" ::); -+#endif -+} -+ -+CUTE_DEVICE void cluster_wait() -+{ -+#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) -+ asm volatile("barrier.cluster.wait.aligned;\n" : : ); -+#else -+ asm volatile ("brkpt;\n" ::); -+#endif -+} -+ -+CUTE_DEVICE void cluster_sync() -+{ -+#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) -+ cluster_arrive(); -+ cluster_wait(); -+#else -+ asm volatile ("brkpt;\n" ::); -+#endif -+} -+ -+// Returns the dim3 grid size in terms of number of clusters. -+CUTE_DEVICE dim3 cluster_grid_dims() -+{ -+#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) -+ uint32_t x, y, z; -+ asm volatile("mov.u32 %0, %nclusterid.x;\n" : "=r"(x) : ); -+ asm volatile("mov.u32 %0, %nclusterid.y;\n" : "=r"(y) : ); -+ asm volatile("mov.u32 %0, %nclusterid.z;\n" : "=r"(z) : ); -+ return {x, y, z}; -+#else -+ return gridDim; -+#endif -+} -+ -+// Returns the dim3 cluster rank in the grid. -+CUTE_DEVICE dim3 cluster_id_in_grid() -+{ -+#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) -+ uint32_t x, y, z; -+ asm volatile("mov.u32 %0, %clusterid.x;\n" : "=r"(x) : ); -+ asm volatile("mov.u32 %0, %clusterid.y;\n" : "=r"(y) : ); -+ asm volatile("mov.u32 %0, %clusterid.z;\n" : "=r"(z) : ); -+ return {x, y, z}; -+#else -+ return blockIdx; -+#endif -+} -+ -+// Returns the relative dim3 block rank local to the cluster. -+CUTE_DEVICE dim3 block_id_in_cluster() -+{ -+#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) -+ uint32_t x, y, z; -+ asm volatile("mov.u32 %0, %cluster_ctaid.x;\n" : "=r"(x) : ); -+ asm volatile("mov.u32 %0, %cluster_ctaid.y;\n" : "=r"(y) : ); -+ asm volatile("mov.u32 %0, %cluster_ctaid.z;\n" : "=r"(z) : ); -+ return {x, y, z}; -+#else -+ return {0,0,0}; -+#endif -+} -+ -+// Returns the dim3 cluster shape. -+CUTE_DEVICE dim3 cluster_shape() -+{ -+#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) -+ uint32_t x, y, z; -+ asm volatile("mov.u32 %0, %cluster_nctaid.x;\n" : "=r"(x) : ); -+ asm volatile("mov.u32 %0, %cluster_nctaid.y;\n" : "=r"(y) : ); -+ asm volatile("mov.u32 %0, %cluster_nctaid.z;\n" : "=r"(z) : ); -+ return {x, y, z}; -+#else -+ return {1,1,1}; -+#endif -+} -+ -+// Get 1D ctaid in a cluster. -+CUTLASS_DEVICE uint32_t block_rank_in_cluster() -+{ -+#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) -+ uint32_t rank; -+ asm volatile("mov.u32 %0, %cluster_ctarank;\n" : "=r"(rank) :); -+ return rank; -+#else -+ return 0; -+#endif -+} -+ -+// Set the destination block-ID in cluster for a given SMEM Address -+CUTLASS_DEVICE uint32_t set_block_rank(uint32_t smemAddr, uint32_t rank) -+{ -+#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) -+ uint32_t result; -+ asm volatile("mapa.shared::cluster.u32 %0, %1, %2;\n" -+ : "=r"(result) -+ : "r"(smemAddr), "r"(rank)); -+ return result; -+#else -+ return smemAddr; -+#endif -+} -+ -+// Elect one thread in the warp. The elected thread gets its predicate set to true, all others obtain false. -+CUTE_HOST_DEVICE uint32_t elect_one_sync() -+{ -+#if defined(CUTE_ARCH_ELECT_ONE_SM90_ENABLED) -+ uint32_t pred = 0; -+ uint32_t laneid = 0; -+ asm volatile( -+ "{\n" -+ ".reg .b32 %rx;\n" -+ ".reg .pred %px;\n" -+ " elect.sync %rx|%px, %2;\n" -+ "@%px mov.s32 %1, 1;\n" -+ " mov.s32 %0, %rx;\n" -+ "}\n" -+ : "+r"(laneid), "+r"(pred) -+ : "r"(0xFFFFFFFF)); -+ return pred; -+#elif defined(__CUDA_ARCH__) -+ return (threadIdx.x % 32) == 0; -+#else -+ return true; -+#endif -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/arch/copy.hpp b/3rdparty/cutlass/include/cute/arch/copy.hpp -new file mode 100644 -index 0000000..aa7bb33 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/arch/copy.hpp -@@ -0,0 +1,71 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+ -+namespace cute -+{ -+ -+// -+// Direct Copy for any type -+// -+ -+template -+struct UniversalCopy -+{ -+ using SRegisters = S[1]; -+ using DRegisters = D[1]; -+ -+ CUTE_HOST_DEVICE static constexpr void -+ copy(S const& src, -+ D & dst) -+ { -+ dst = src; -+ } -+}; -+ -+// -+// Placeholder for the copy algorithm's default, auto-vectorizing behavior -+// -+ -+struct DefaultCopy -+{ -+ using SRegisters = uint128_t[1]; -+ using DRegisters = uint128_t[1]; -+}; -+ -+using AutoVectorizingCopy = DefaultCopy; -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/arch/copy_sm75.hpp b/3rdparty/cutlass/include/cute/arch/copy_sm75.hpp -new file mode 100644 -index 0000000..fda6340 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/arch/copy_sm75.hpp -@@ -0,0 +1,215 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+// Config -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)) -+# define CUTE_ARCH_LDSM_SM75_ENABLED -+#endif -+ -+namespace cute -+{ -+ -+struct SM75_U32x1_LDSM_N -+{ -+ using SRegisters = uint128_t[1]; -+ using DRegisters = uint32_t[1]; -+ -+ CUTE_HOST_DEVICE static void -+ copy(uint128_t const& smem_src, -+ uint32_t& dst) -+ { -+#if defined(CUTE_ARCH_LDSM_SM75_ENABLED) -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); -+ asm volatile ("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n" -+ : "=r"(dst) -+ : "r"(smem_int_ptr)); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM75_U32x2_LDSM_N -+{ -+ using SRegisters = uint128_t[1]; -+ using DRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ copy(uint128_t const& smem_src, -+ uint32_t& dst0, uint32_t& dst1) -+ { -+#if defined(CUTE_ARCH_LDSM_SM75_ENABLED) -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); -+ asm volatile ("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" -+ : "=r"(dst0), "=r"(dst1) -+ : "r"(smem_int_ptr)); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM75_U32x4_LDSM_N -+{ -+ using SRegisters = uint128_t[1]; -+ using DRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ copy(uint128_t const& smem_src, -+ uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) -+ { -+#if defined(CUTE_ARCH_LDSM_SM75_ENABLED) -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); -+ asm volatile ("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" -+ : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) -+ : "r"(smem_int_ptr)); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM75_U16x2_LDSM_T -+{ -+ using SRegisters = uint128_t[1]; -+ using DRegisters = uint32_t[1]; -+ -+ CUTE_HOST_DEVICE static void -+ copy(uint128_t const& smem_src, -+ uint32_t& dst) -+ { -+#if defined(CUTE_ARCH_LDSM_SM75_ENABLED) -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); -+ asm volatile ("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n" -+ : "=r"(dst) -+ : "r"(smem_int_ptr)); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM75_U16x4_LDSM_T -+{ -+ using SRegisters = uint128_t[1]; -+ using DRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ copy(uint128_t const& smem_src, -+ uint32_t& dst0, uint32_t& dst1) -+ { -+#if defined(CUTE_ARCH_LDSM_SM75_ENABLED) -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); -+ asm volatile ("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n" -+ : "=r"(dst0), "=r"(dst1) -+ : "r"(smem_int_ptr)); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM75_U16x8_LDSM_T -+{ -+ using SRegisters = uint128_t[1]; -+ using DRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ copy(uint128_t const& smem_src, -+ uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) -+ { -+#if defined(CUTE_ARCH_LDSM_SM75_ENABLED) -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); -+ asm volatile ("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" -+ : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) -+ : "r"(smem_int_ptr)); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED."); -+#endif -+ } -+}; -+ -+// -+// Legacy LDSM interfaces that aren't very useful -+// -+ -+template -+CUTE_HOST_DEVICE -+void -+copy_ldsm(uint128_t const* const smem_ptr, -+ T* rmem_ptr) -+{ -+ uint32_t* reg_ptr = reinterpret_cast(rmem_ptr); -+ -+ // if constexpr -+ if (sizeof(T) == 4) { -+ SM75_U32x1_LDSM_N::copy(smem_ptr[0], reg_ptr[0]); -+ } -+ else if (sizeof(T) == 8) { -+ SM75_U32x2_LDSM_N::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1]); -+ } -+ else if (sizeof(T) == 16) { -+ SM75_U32x4_LDSM_N::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3]); -+ } -+ else { -+ static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported"); -+ } -+} -+ -+template -+CUTE_HOST_DEVICE -+void -+copy_ldsm_trans(uint128_t const* const smem_ptr, -+ T* rmem_ptr) -+{ -+ uint32_t* reg_ptr = reinterpret_cast(rmem_ptr); -+ -+ // if constexpr -+ if (sizeof(T) == 4) { -+ SM75_U16x2_LDSM_T::copy(smem_ptr[0], reg_ptr[0]); -+ } -+ else if (sizeof(T) == 8) { -+ SM75_U16x4_LDSM_T::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1]); -+ } -+ else if (sizeof(T) == 16) { -+ SM75_U16x8_LDSM_T::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3]); -+ } -+ else { -+ static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported"); -+ } -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/arch/copy_sm80.hpp b/3rdparty/cutlass/include/cute/arch/copy_sm80.hpp -new file mode 100644 -index 0000000..c6c4412 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/arch/copy_sm80.hpp -@@ -0,0 +1,138 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+// Config -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) -+# define CUTE_ARCH_CP_ASYNC_SM80_ENABLED -+#endif -+ -+namespace cute -+{ -+ -+/// Copy via cp.async with caching at all levels -+template -+struct SM80_CP_ASYNC_CACHEALWAYS -+{ -+ using SRegisters = TS[1]; -+ using DRegisters = TD[1]; -+ -+ static_assert(sizeof(TS) == sizeof(TD), "cp.async requires sizeof(src_value_type) == sizeof(dst_value_type)"); -+ static_assert(sizeof(TS) == 4 || sizeof(TS) == 8 || sizeof(TS) == 16, "cp.async sizeof(TS) is not supported"); -+ -+ CUTE_HOST_DEVICE static void -+ copy(TS const& gmem_src, -+ TD & smem_dst) -+ { -+#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) -+ TS const* gmem_ptr = &gmem_src; -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); -+ asm volatile("cp.async.ca.shared.global [%0], [%1], %2;\n" -+ :: "r"(smem_int_ptr), -+ "l"(gmem_ptr), -+ "n"(sizeof(TS))); -+#else -+ CUTE_RUNTIME_ASSERT("Support for cp.async instructions has not been enabled"); -+#endif -+ } -+}; -+ -+/// Copy via cp.async with caching at global level -+template -+struct SM80_CP_ASYNC_CACHEGLOBAL -+{ -+ using SRegisters = TS[1]; -+ using DRegisters = TD[1]; -+ -+ static_assert(sizeof(TS) == sizeof(TD), "cp.async requires sizeof(src_value_type) == sizeof(dst_value_type)"); -+ static_assert(sizeof(TS) == 4 || sizeof(TS) == 8 || sizeof(TS) == 16, "cp.async sizeof(TS) is not supported"); -+ -+ CUTE_HOST_DEVICE static void -+ copy(TS const& gmem_src, -+ TD & smem_dst) -+ { -+#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) -+ TS const* gmem_ptr = &gmem_src; -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); -+ asm volatile("cp.async.cg.shared.global [%0], [%1], %2;\n" -+ :: "r"(smem_int_ptr), -+ "l"(gmem_ptr), -+ "n"(sizeof(TS))); -+#else -+ CUTE_RUNTIME_ASSERT("Support for cp.async instructions has not been enabled"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block. -+CUTE_HOST_DEVICE -+void -+cp_async_fence() -+{ -+#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) -+ asm volatile("cp.async.commit_group;\n" ::); -+#endif -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Blocks until all but N previous cp.async.commit_group operations have committed. -+template -+CUTE_HOST_DEVICE -+void -+cp_async_wait() -+{ -+#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) -+ if constexpr (N == 0) { -+ asm volatile("cp.async.wait_all;\n" ::); -+ } else { -+ asm volatile("cp.async.wait_group %0;\n" :: "n"(N)); -+ } -+#endif -+} -+ -+template -+CUTE_HOST_DEVICE -+void -+cp_async_wait(Int) -+{ -+ return cp_async_wait(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/arch/copy_sm90.hpp b/3rdparty/cutlass/include/cute/arch/copy_sm90.hpp -new file mode 100644 -index 0000000..6ac9643 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/arch/copy_sm90.hpp -@@ -0,0 +1,225 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+// Config -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)) -+# define CUTE_ARCH_STSM_SM90_ENABLED -+# define CUTE_ARCH_TMA_SM90_ENABLED -+#endif -+ -+namespace cute -+{ -+ -+struct SM90_U32x1_STSM_N -+{ -+ using SRegisters = uint32_t[1]; -+ using DRegisters = uint128_t[1]; -+ -+ CUTE_HOST_DEVICE static void -+ copy(uint32_t const& src, -+ uint128_t & smem_dst) -+ { -+#if defined(CUTE_ARCH_STSM_SM90_ENABLED) -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); -+ asm volatile ("stmatrix.sync.aligned.x1.m8n8.shared.b16 [%0], {%1};\n" -+ :: "r"(smem_int_ptr), -+ "r"(src)); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_U32x2_STSM_N -+{ -+ using SRegisters = uint32_t[2]; -+ using DRegisters = uint128_t[1]; -+ -+ CUTE_HOST_DEVICE static void -+ copy(uint32_t const& src0, uint32_t const& src1, -+ uint128_t& smem_dst) -+ { -+#if defined(CUTE_ARCH_STSM_SM90_ENABLED) -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); -+ asm volatile ("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" -+ :: "r"(smem_int_ptr), -+ "r"(src0), "r"(src1)); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_U32x4_STSM_N -+{ -+ using SRegisters = uint32_t[4]; -+ using DRegisters = uint128_t[1]; -+ -+ CUTE_HOST_DEVICE static void -+ copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, -+ uint128_t& smem_dst) -+ { -+#if defined(CUTE_ARCH_STSM_SM90_ENABLED) -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); -+ asm volatile ("stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" -+ :: "r"(smem_int_ptr), -+ "r"(src0), "r"(src1), "r"(src2), "r"(src3)); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_U16x2_STSM_T -+{ -+ using SRegisters = uint32_t[1]; -+ using DRegisters = uint128_t[1]; -+ -+ CUTE_HOST_DEVICE static void -+ copy(uint32_t const& src, -+ uint128_t& smem_dst) -+ { -+#if defined(CUTE_ARCH_STSM_SM90_ENABLED) -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); -+ asm volatile ("stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [%0], {%1};\n" -+ :: "r"(smem_int_ptr), -+ "r"(src)); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_U16x4_STSM_T -+{ -+ using SRegisters = uint32_t[2]; -+ using DRegisters = uint128_t[1]; -+ -+ CUTE_HOST_DEVICE static void -+ copy(uint32_t const& src0, uint32_t const& src1, -+ uint128_t& smem_dst) -+ { -+#if defined(CUTE_ARCH_STSM_SM90_ENABLED) -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); -+ asm volatile ("stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [%0], {%1, %2};\n" -+ :: "r"(smem_int_ptr), -+ "r"(src0), "r"(src1)); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_U16x8_STSM_T -+{ -+ using SRegisters = uint32_t[4]; -+ using DRegisters = uint128_t[1]; -+ -+ CUTE_HOST_DEVICE static void -+ copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, -+ uint128_t& smem_dst) -+ { -+#if defined(CUTE_ARCH_STSM_SM90_ENABLED) -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); -+ asm volatile ("stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" -+ :: "r"(smem_int_ptr), -+ "r"(src0), "r"(src1), "r"(src2), "r"(src3)); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+// -+// Legacy STSM interfaces that aren't very useful -+// -+ -+template -+CUTE_HOST_DEVICE -+void -+copy_stsm(T const* const rmem_ptr, -+ uint128_t* const smem_ptr) -+{ -+ uint32_t const* reg_ptr = reinterpret_cast(rmem_ptr); -+ -+ // if constexpr -+ if (sizeof(T) == 4) { -+ SM90_U32x1_STSM_N::copy(reg_ptr[0], smem_ptr[0]); -+ } -+ else if (sizeof(T) == 8) { -+ SM90_U32x2_STSM_N::copy(reg_ptr[0], reg_ptr[1], smem_ptr[0]); -+ } -+ else if (sizeof(T) == 16) { -+ SM90_U32x4_STSM_N::copy(reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3], smem_ptr[0]); -+ } -+ else { -+ static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported"); -+ } -+} -+ -+template -+CUTE_HOST_DEVICE -+void -+copy_stsm_trans(T const* const rmem_ptr, -+ uint128_t* const smem_ptr) -+{ -+ uint32_t const* reg_ptr = reinterpret_cast(rmem_ptr); -+ -+ // if constexpr -+ if (sizeof(T) == 4) { -+ SM90_U16x2_STSM_T::copy(reg_ptr[0], smem_ptr[0]); -+ } -+ else if (sizeof(T) == 8) { -+ SM90_U16x4_STSM_T::copy(reg_ptr[0], reg_ptr[1], smem_ptr[0]); -+ } -+ else if (sizeof(T) == 16) { -+ SM90_U16x8_STSM_T::copy(reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3], smem_ptr[0]); -+ } -+ else { -+ static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported"); -+ } -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // end namespace cute -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include -+#include -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cute/arch/copy_sm90_desc.hpp b/3rdparty/cutlass/include/cute/arch/copy_sm90_desc.hpp -new file mode 100644 -index 0000000..ca8320f ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/arch/copy_sm90_desc.hpp -@@ -0,0 +1,194 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+#include -+#include -+ -+#include -+#include -+#include // to_Format<[u]intX> -+#include // to_Format -+ -+namespace cute -+{ -+ -+////////////////////////////////////////////////////////////////////////////////////////////////////// -+/// Barriers are 64-bit of user-managed information used in broadly two types syncronization patterns -+/// 1) arrive/wait on threads (usage: cp.async and warp-specialized kernels) -+/// 2) transaction-based (usage: TMA transaction where a CTA issues one transaction) -+////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Initialize barrier present in shared memory -+CUTE_HOST_DEVICE -+void -+initialize_barrier(uint64_t& smem_barrier, // 64 bits user-manged barrier in smem -+ int thread_count = 1) // Thread count expected to arrive/wait on this barrier -+{ -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier); -+ asm volatile ("mbarrier.init.shared.b64 [%0], %1;\n" -+ :: "r"(smem_int_ptr), -+ "r"(thread_count)); -+#endif -+} -+ -+// Set the number of bytes transfered per transaction -+CUTE_HOST_DEVICE -+void -+set_barrier_transaction_bytes(uint64_t& smem_barrier, // 64 bits user-manged barrier in smem -+ uint32_t bytes) // Number of bytes transfered by per TMA transaction -+{ -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier); -+ asm volatile ("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;\n" -+ :: "r"(smem_int_ptr), -+ "r"(bytes)); -+#endif -+} -+ -+// Barrier wait -+CUTE_HOST_DEVICE -+void -+wait_barrier(uint64_t& smem_barrier, // 64 bits user-manged barrier in smem -+ int phase_bit) // Current phase bit the barrier waiting to flip -+{ -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier); -+ asm volatile( -+ "{\n" -+ ".reg .pred P1;\n" -+ "LAB_WAIT:\n" -+ "mbarrier.try_wait.parity.shared.b64 P1, [%0], %1;\n" -+ "@P1 bra.uni DONE;\n" -+ "bra.uni LAB_WAIT;\n" -+ "DONE:\n" -+ "}\n" -+ :: "r"(smem_int_ptr), -+ "r"(phase_bit)); -+ -+#endif -+} -+ -+// Barrier arrive -+CUTE_HOST_DEVICE -+void -+arrive_barrier(uint64_t& smem_barrier) // 64 bits user-manged barrier in smem -+{ -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier); -+ asm volatile( -+ "{\n" -+ ".reg .b64 state; \n" -+ "mbarrier.arrive.shared.b64 state, [%0];\n" -+ "}\n" -+ :: "r"(smem_int_ptr)); -+#endif -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+// TMA Descriptor and utilities -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace TMA { -+ -+enum class SmemSwizzleBits : uint8_t { -+ DISABLE = 0, -+ B32 = 1, -+ B64 = 2, -+ B128 = 3, -+}; -+ -+#if (__CUDACC_VER_MAJOR__ >= 12) -+ -+template -+inline CUtensorMapDataType to_CUtensorMapDataType() { -+ if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else -+ if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else -+ if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT16; } else -+ if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT32; } else -+ if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT64; } else -+ if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_INT32; } else -+ if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_INT64; } else -+ if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; } else -+ if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT32; } else -+ if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT64; } else -+ if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; } else -+ if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_TFLOAT32; } else -+ { static_assert(sizeof(T) < 0, "Unknown TMA Format!"); } -+} -+ -+inline CUtensorMapSwizzle to_CUtensorMapSwizzle(SmemSwizzleBits const& t) { -+ switch (t) { -+ default: assert(false && "Unknown SmemSwizzleBits!"); -+ case SmemSwizzleBits::DISABLE: return CU_TENSOR_MAP_SWIZZLE_NONE; -+ case SmemSwizzleBits::B32: return CU_TENSOR_MAP_SWIZZLE_32B; -+ case SmemSwizzleBits::B64: return CU_TENSOR_MAP_SWIZZLE_64B; -+ case SmemSwizzleBits::B128: return CU_TENSOR_MAP_SWIZZLE_128B; -+ } -+} -+ -+#endif // (__CUDACC_VER_MAJOR__ >= 12) -+} // end namespace TMA -+ -+#if (__CUDACC_VER_MAJOR__ >= 12) -+using TmaDescriptor = CUtensorMap; -+#else -+using TmaDescriptor = struct { char bytes[128]; }; -+#endif -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+/// Initiates a TensorMap Prefetch -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+CUTE_HOST_DEVICE -+void -+prefetch_tma_descriptor(TmaDescriptor const* desc_ptr) -+{ -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); -+ // Prefetch TMA Descriptor using generic addressing (i.e. no specific state space: const or param) -+ asm volatile ( -+ "prefetch.tensormap [%0];" -+ : -+ : "l"(gmem_int_desc) -+ : "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use TMA Descriptor Prefetch without CUTE_ARCH_TMA_SM90_ENABLED."); -+#endif -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/arch/copy_sm90_tma.hpp b/3rdparty/cutlass/include/cute/arch/copy_sm90_tma.hpp -new file mode 100644 -index 0000000..d6025e4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/arch/copy_sm90_tma.hpp -@@ -0,0 +1,552 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+ -+namespace cute -+{ -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+/// TMA_LOAD : Initiates a TMA copy from global memory to shared memory -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct SM90_TMA_LOAD_1D -+{ -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, -+ void const* const smem_ptr, -+ int32_t const& crd0) -+ { -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); -+ uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar); -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile ( -+ "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes" -+ " [%0], [%1, {%3}], [%2];" -+ : -+ : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), -+ "r"(crd0) -+ : "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_TMA_LOAD_2D -+{ -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1) -+ { -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); -+ uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar); -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile ( -+ "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes" -+ " [%0], [%1, {%3, %4}], [%2];" -+ : -+ : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), -+ "r"(crd0), "r"(crd1) -+ : "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_TMA_LOAD_3D -+{ -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) -+ { -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); -+ uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar); -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile ( -+ "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes" -+ " [%0], [%1, {%3, %4, %5}], [%2];" -+ : -+ : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), -+ "r"(crd0), "r"(crd1), "r"(crd2) -+ : "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_TMA_LOAD_4D -+{ -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) -+ { -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); -+ uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar); -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile ( -+ "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes" -+ " [%0], [%1, {%3, %4, %5, %6}], [%2];" -+ : -+ : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), -+ "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3) -+ : "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_TMA_LOAD_5D -+{ -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) -+ { -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); -+ uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar); -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile ( -+ "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes" -+ " [%0], [%1, {%3, %4, %5, %6, %7}], [%2];" -+ : -+ : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), -+ "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4) -+ : "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_TMA_LOAD -+{ -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, -+ void const* const smem_ptr, -+ int32_t const& crd0) -+ { -+ return SM90_TMA_LOAD_1D::copy(desc_ptr, smem_mbar, smem_ptr, crd0); -+ } -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1) -+ { -+ return SM90_TMA_LOAD_2D::copy(desc_ptr, smem_mbar, smem_ptr, crd0, crd1); -+ } -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) -+ { -+ return SM90_TMA_LOAD_3D::copy(desc_ptr, smem_mbar, smem_ptr, crd0, crd1, crd2); -+ } -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) -+ { -+ return SM90_TMA_LOAD_4D::copy(desc_ptr, smem_mbar, smem_ptr, crd0, crd1, crd2, crd3); -+ } -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) -+ { -+ return SM90_TMA_LOAD_5D::copy(desc_ptr, smem_mbar, smem_ptr, crd0, crd1, crd2, crd3, crd4); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+/// TMA_LOAD_MULTICAST: Initiates a TMA copy from global memory to shared memory -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct SM90_TMA_LOAD_1D_MULTICAST -+{ -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, -+ void const* const smem_ptr, -+ int32_t const& crd0) -+ { -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); -+ uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar); -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile ( -+ "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster" -+ " [%0], [%1, {%4}], [%2], %3;" -+ : -+ : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), -+ "h"(multicast_mask), -+ "r"(crd0) -+ : "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_TMA_LOAD_2D_MULTICAST -+{ -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1) -+ { -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); -+ uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar); -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile ( -+ "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster" -+ " [%0], [%1, {%4, %5}], [%2], %3;" -+ : -+ : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), -+ "h"(multicast_mask), -+ "r"(crd0), "r"(crd1) -+ : "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_TMA_LOAD_3D_MULTICAST -+{ -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) -+ { -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); -+ uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar); -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile ( -+ "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster" -+ " [%0], [%1, {%4, %5, %6}], [%2], %3;" -+ : -+ : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), -+ "h"(multicast_mask), -+ "r"(crd0), "r"(crd1), "r"(crd2) -+ : "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_TMA_LOAD_4D_MULTICAST -+{ -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) -+ { -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); -+ uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar); -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile ( -+ "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster" -+ " [%0], [%1, {%4, %5, %6, %7}], [%2], %3;" -+ : -+ : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), -+ "h"(multicast_mask), -+ "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3) -+ : "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_TMA_LOAD_5D_MULTICAST -+{ -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) -+ { -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); -+ uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar); -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile ( -+ "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster" -+ " [%0], [%1, {%4, %5, %6, %7, %8}], [%2], %3;" -+ : -+ : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), -+ "h"(multicast_mask), -+ "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4) -+ : "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_TMA_LOAD_MULTICAST -+{ -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, -+ void const* const smem_ptr, -+ int32_t const& crd0) -+ { -+ return SM90_TMA_LOAD_1D_MULTICAST::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0); -+ } -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1) -+ { -+ return SM90_TMA_LOAD_2D_MULTICAST::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0, crd1); -+ } -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) -+ { -+ return SM90_TMA_LOAD_3D_MULTICAST::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0, crd1, crd2); -+ } -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) -+ { -+ return SM90_TMA_LOAD_4D_MULTICAST::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0, crd1, crd2, crd3); -+ } -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) -+ { -+ return SM90_TMA_LOAD_5D_MULTICAST::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0, crd1, crd2, crd3, crd4); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+/// TMA_STORE : Initiates a TMA copy from shared memory to global memory -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct SM90_TMA_STORE_1D -+{ -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, -+ void const* const smem_ptr, -+ int32_t const& crd0) -+ { -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile ( -+ "cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [%0, {%2}], [%1];" -+ : -+ : "l"(gmem_int_desc), "r"(smem_int_ptr), -+ "r"(crd0) -+ : "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_TMA_STORE_2D -+{ -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1) -+ { -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile ( -+ "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, {%2, %3}], [%1];" -+ : -+ : "l"(gmem_int_desc), "r"(smem_int_ptr), -+ "r"(crd0), "r"(crd1) -+ : "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_TMA_STORE_3D -+{ -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) -+ { -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile ( -+ "cp.async.bulk.tensor.3d.global.shared::cta.bulk_group [%0, {%2, %3, %4}], [%1];" -+ : -+ : "l"(gmem_int_desc), "r"(smem_int_ptr), -+ "r"(crd0), "r"(crd1), "r"(crd2) -+ : "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_TMA_STORE_4D -+{ -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) -+ { -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile ( -+ "cp.async.bulk.tensor.4d.global.shared::cta.bulk_group [%0, {%2, %3, %4, %5}], [%1];" -+ : -+ : "l"(gmem_int_desc), "r"(smem_int_ptr), -+ "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3) -+ : "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_TMA_STORE_5D -+{ -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) -+ { -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); -+ uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile ( -+ "cp.async.bulk.tensor.5d.global.shared::cta.bulk_group [%0, {%2, %3, %4, %5, %6}], [%1];" -+ : -+ : "l"(gmem_int_desc), "r"(smem_int_ptr), -+ "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4) -+ : "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); -+#endif -+ } -+}; -+ -+struct SM90_TMA_STORE -+{ -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, -+ void const* const smem_ptr, -+ int32_t const& crd0) -+ { -+ return SM90_TMA_STORE_1D::copy(desc_ptr, smem_ptr, crd0); -+ } -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1) -+ { -+ return SM90_TMA_STORE_2D::copy(desc_ptr, smem_ptr, crd0, crd1); -+ } -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) -+ { -+ return SM90_TMA_STORE_3D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2); -+ } -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) -+ { -+ return SM90_TMA_STORE_4D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2, crd3); -+ } -+ CUTE_HOST_DEVICE static void -+ copy(void const* const desc_ptr, -+ void const* const smem_ptr, -+ int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) -+ { -+ return SM90_TMA_STORE_5D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2, crd3, crd4); -+ } -+}; -+ -+// Indicate arrival of warp issuing TMA_STORE -+CUTE_HOST_DEVICE static void -+tma_store_arrive() { -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ asm volatile("cp.async.bulk.commit_group;"); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); -+#endif -+} -+ -+// Wait on prior N (Count) TMA_STORE instructions to complete -+template -+CUTE_HOST_DEVICE static void -+tma_store_wait() { -+#if defined(CUTE_ARCH_TMA_SM90_ENABLED) -+ asm volatile( -+ "cp.async.bulk.wait_group.read %0;" -+ : -+ : "n"(Count) -+ : "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); -+#endif -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/arch/mma.hpp b/3rdparty/cutlass/include/cute/arch/mma.hpp -new file mode 100644 -index 0000000..1c1058f ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/arch/mma.hpp -@@ -0,0 +1,64 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+namespace cute -+{ -+ -+// -+// Direct FMA for any type -+// -+ -+template -+struct UniversalFMA -+{ -+ using DRegisters = D[1]; -+ using ARegisters = A[1]; -+ using BRegisters = B[1]; -+ using CRegisters = C[1]; -+ -+ CUTE_HOST_DEVICE static constexpr void -+ fma(D & d, -+ A const& a, -+ B const& b, -+ C const& c) -+ { -+ // Forward to an ADL/cute free function for these types -+ using cute::fma; -+ fma(d, a, b, c); -+ } -+}; -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/arch/mma_sm61.hpp b/3rdparty/cutlass/include/cute/arch/mma_sm61.hpp -new file mode 100644 -index 0000000..32a9fbb ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/arch/mma_sm61.hpp -@@ -0,0 +1,87 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include -+#include -+ -+// Config -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)) -+# define CUTE_ARCH_MMA_SM61_ENABLED -+#endif -+ -+namespace cute -+{ -+ -+struct SM61_DP4A -+{ -+ using DRegisters = int32_t[1]; -+ using ARegisters = uint32_t[1]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = int32_t[1]; -+ -+ // Register asm fma -+ CUTE_HOST_DEVICE static void -+ fma(int32_t& d, uint32_t const& a, uint32_t const& b, int32_t const& c) -+ { -+#if defined(CUTE_ARCH_MMA_SM61_ENABLED) -+ asm volatile("dp4a.s32.s32 %0, %1, %2, %3;" -+ : "=r"(d) -+ : "r"(a), "r"(b), "r"(c)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM61_DP4A without CUTE_ARCH_MMA_SM61_ENABLED"); -+#endif -+ } -+}; -+ -+struct SM61_DP2A -+{ -+ using DRegisters = int32_t[1]; -+ using ARegisters = uint32_t[1]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = int32_t[1]; -+ -+ // Register asm fma -+ CUTE_HOST_DEVICE static void -+ fma(int32_t& d, uint32_t const& a, uint32_t const& b, int32_t const& c) -+ { -+#if defined(CUTE_ARCH_MMA_SM61_ENABLED) -+ asm volatile("dp2a.s32.s32 %0, %1, %2, %3;" -+ : "=r"(d) -+ : "r"(a), "r"(b), "r"(c)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM61_DP2A without CUTE_ARCH_MMA_SM61_ENABLED"); -+#endif -+ } -+}; -+ -+} // namespace cute -diff --git a/3rdparty/cutlass/include/cute/arch/mma_sm70.hpp b/3rdparty/cutlass/include/cute/arch/mma_sm70.hpp -new file mode 100644 -index 0000000..139e600 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/arch/mma_sm70.hpp -@@ -0,0 +1,329 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+// Config -+#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1)) -+# define CUTE_ARCH_MMA_SM70_SUPPORTED -+# if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)) -+# define CUTE_ARCH_MMA_SM70_ENABLED -+# endif -+#endif -+ -+namespace cute -+{ -+ -+// -+// SM70 MMA 884 F16F16F16 -+// -+ -+struct SM70_8x8x4_F16F16F16F16_TN -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ // Register asm fma -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM70_ENABLED) -+ asm volatile("mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16" -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6, %7}," -+ "{%8, %9, %10, %11};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F16F16F16F16_TN without CUTE_ARCH_MMA_SM70_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct SM70_8x8x4_F16F16F16F16_NT -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ // Register asm fma -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM70_ENABLED) -+ asm volatile("mma.sync.aligned.m8n8k4.col.row.f16.f16.f16.f16" -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6, %7}," -+ "{%8, %9, %10, %11};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F16F16F16F16_NT without CUTE_ARCH_MMA_SM70_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct SM70_8x8x4_F16F16F16F16_NN -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ // Register asm fma -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM70_ENABLED) -+ asm volatile("mma.sync.aligned.m8n8k4.col.col.f16.f16.f16.f16" -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6, %7}," -+ "{%8, %9, %10, %11};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F16F16F16F16_NN without CUTE_ARCH_MMA_SM70_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct SM70_8x8x4_F16F16F16F16_TT -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ // Register asm fma -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM70_ENABLED) -+ asm volatile("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16" -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6, %7}," -+ "{%8, %9, %10, %11};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F16F16F16F16_TT without CUTE_ARCH_MMA_SM70_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// SM70 MMA 884 F16F16F32 -+// -+ -+struct SM70_8x8x4_F32F16F16F32_TN -+{ -+ using DRegisters = float[8]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = float[8]; -+ -+ // Register asm fma -+ CUTE_HOST_DEVICE static void -+ fma(float & d0, float & d1, float & d2, float & d3, -+ float & d4, float & d5, float & d6, float & d7, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, uint32_t const& b1, -+ float const& c0, float const& c1, float const& c2, float const& c3, -+ float const& c4, float const& c5, float const& c6, float const& c7) -+ { -+#if defined(CUTE_ARCH_MMA_SM70_ENABLED) -+ asm volatile("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32" -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11}," -+ "{%12, %13, %14, %15, %16, %17, %18, %19};\n" -+ : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), -+ "=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7) -+ : "r"(a0), "r"(a1), -+ "r"(b0), "r"(b1), -+ "f"(c0), "f"(c1), "f"(c2), "f"(c3), -+ "f"(c4), "f"(c5), "f"(c6), "f"(c7)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F32F16F16F32_TN without CUTE_ARCH_MMA_SM70_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct SM70_8x8x4_F32F16F16F32_NT -+{ -+ using DRegisters = float[8]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = float[8]; -+ -+ // Register asm fma -+ CUTE_HOST_DEVICE static void -+ fma(float & d0, float & d1, float & d2, float & d3, -+ float & d4, float & d5, float & d6, float & d7, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, uint32_t const& b1, -+ float const& c0, float const& c1, float const& c2, float const& c3, -+ float const& c4, float const& c5, float const& c6, float const& c7) -+ { -+#if defined(CUTE_ARCH_MMA_SM70_ENABLED) -+ asm volatile("mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32" -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11}," -+ "{%12, %13, %14, %15, %16, %17, %18, %19};" -+ : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), -+ "=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7) -+ : "r"(a0), "r"(a1), -+ "r"(b0), "r"(b1), -+ "f"(c0), "f"(c1), "f"(c2), "f"(c3), -+ "f"(c4), "f"(c5), "f"(c6), "f"(c7)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F32F16F16F32_NT without CUTE_ARCH_MMA_SM70_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct SM70_8x8x4_F32F16F16F32_NN -+{ -+ using DRegisters = float[8]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = float[8]; -+ -+ // Register asm fma -+ CUTE_HOST_DEVICE static void -+ fma(float & d0, float & d1, float & d2, float & d3, -+ float & d4, float & d5, float & d6, float & d7, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, uint32_t const& b1, -+ float const& c0, float const& c1, float const& c2, float const& c3, -+ float const& c4, float const& c5, float const& c6, float const& c7) -+ { -+#if defined(CUTE_ARCH_MMA_SM70_ENABLED) -+ asm volatile("mma.sync.aligned.m8n8k4.col.col.f32.f16.f16.f32" -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11}," -+ "{%12, %13, %14, %15, %16, %17, %18, %19};" -+ : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), -+ "=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7) -+ : "r"(a0), "r"(a1), -+ "r"(b0), "r"(b1), -+ "f"(c0), "f"(c1), "f"(c2), "f"(c3), -+ "f"(c4), "f"(c5), "f"(c6), "f"(c7)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F32F16F16F32_NN without CUTE_ARCH_MMA_SM70_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct SM70_8x8x4_F32F16F16F32_TT -+{ -+ using DRegisters = float[8]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = float[8]; -+ -+ // Register asm fma -+ CUTE_HOST_DEVICE static void -+ fma(float & d0, float & d1, float & d2, float & d3, -+ float & d4, float & d5, float & d6, float & d7, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, uint32_t const& b1, -+ float const& c0, float const& c1, float const& c2, float const& c3, -+ float const& c4, float const& c5, float const& c6, float const& c7) -+ { -+#if defined(CUTE_ARCH_MMA_SM70_ENABLED) -+ asm volatile("mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32" -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11}," -+ "{%12, %13, %14, %15, %16, %17, %18, %19};" -+ : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), -+ "=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7) -+ : "r"(a0), "r"(a1), -+ "r"(b0), "r"(b1), -+ "f"(c0), "f"(c1), "f"(c2), "f"(c3), -+ "f"(c4), "f"(c5), "f"(c6), "f"(c7)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM70_8x8x4_F32F16F16F32_TT without CUTE_ARCH_MMA_SM70_ENABLED"); -+#endif -+ } -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/arch/mma_sm75.hpp b/3rdparty/cutlass/include/cute/arch/mma_sm75.hpp -new file mode 100644 -index 0000000..20d2b56 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/arch/mma_sm75.hpp -@@ -0,0 +1,120 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+// Config -+#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)) -+# define CUTE_ARCH_MMA_SM75_SUPPORTED -+# if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)) -+# define CUTE_ARCH_MMA_SM75_ENABLED -+# endif -+#endif -+ -+namespace cute -+{ -+ -+// -+// SM75 MMA 1688 F16F16F32 -+// -+ -+struct SM75_16x8x8_F32F16F16F32_TN -+{ -+ using DRegisters = float[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = float[4]; -+ -+ // Register asm fma -+ CUTE_HOST_DEVICE static void -+ fma(float & d0, float & d1, float & d2, float & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ float const& c0, float const& c1, float const& c2, float const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM75_ENABLED) -+ asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "f"(c0), "f"(c1), "f"(c2), "f"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM75_16x8x8_F32F16F16F32_TN without CUTE_ARCH_MMA_SM75_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// SM75 MMA 8816 S8S8S32 -+// -+ -+struct SM75_8x8x16_S32S8S8S32_TN -+{ -+ using DRegisters = uint32_t[2]; -+ using ARegisters = uint32_t[1]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ // Register asm fma -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, -+ uint32_t const& a0, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM75_ENABLED) -+ asm volatile("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32" -+ "{%0, %1}," -+ "{%2}," -+ "{%3}," -+ "{%4, %5};\n" -+ : "=r"(d0), "=r"(d1) -+ : "r"(a0), -+ "r"(b0), -+ "r"(c0), "r"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM75_8x8x16_S32S8S8S32_TN without CUTE_ARCH_MMA_SM75_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/arch/mma_sm80.hpp b/3rdparty/cutlass/include/cute/arch/mma_sm80.hpp -new file mode 100644 -index 0000000..6050500 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/arch/mma_sm80.hpp -@@ -0,0 +1,2132 @@ -+ /************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include -+#include -+ -+// Config -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) -+# define CUTE_ARCH_MMA_SM80_ENABLED -+#endif -+ -+namespace cute { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x8 TN -+struct SM80_16x8x8_F16F16F16F16_TN -+{ -+ using DRegisters = uint32_t[2]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " -+ "{%0, %1}," -+ "{%2, %3}," -+ "{%4}," -+ "{%5, %6};\n" -+ : "=r"(d0), "=r"(d1) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "r"(c0), "r"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x8_F16F16F16F16_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x16 TN -+struct SM80_16x8x16_F16F16F16F16_TN -+{ -+ using DRegisters = uint32_t[2]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " -+ "{%0, %1}," -+ "{%2, %3, %4, %5}," -+ "{%6, %7}," -+ "{%8, %9};\n" -+ : "=r"(d0), "=r"(d1) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_F16F16F16F16_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x8 TN -+struct SM80_16x8x8_F32F16F16F32_TN -+{ -+ using DRegisters = float[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = float[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(float & d0, float & d1, float & d2, float & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ float const & c0, float const & c1, float const & c2, float const & c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "f"(c0), "f"(c1), "f"(c2), "f"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x8_F32F16F16F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x16 TN -+struct SM80_16x8x16_F32F16F16F32_TN -+{ -+ using DRegisters = float[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = float[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(float & d0, float & d1, float & d2, float & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ float const & c0, float const & c1, float const & c2, float const & c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "f"(c0), "f"(c1), "f"(c2), "f"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_F32F16F16F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x8 TN -+struct SM80_16x8x8_F32BF16BF16F32_TN -+{ -+ using DRegisters = float[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = float[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(float & d0, float & d1, float & d2, float & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ float const & c0, float const & c1, float const & c2, float const & c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "f"(c0), "f"(c1), "f"(c2), "f"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x8_F32BF16BF16F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x16 TN -+struct SM80_16x8x16_F32BF16BF16F32_TN -+{ -+ using DRegisters = float[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = float[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(float & d0, float & d1, float & d2, float & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ float const & c0, float const & c1, float const & c2, float const & c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "f"(c0), "f"(c1), "f"(c2), "f"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_F32BF16BF16F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x4 TN -+struct SM80_16x8x4_F32TF32TF32F32_TN -+{ -+ using DRegisters = float[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = float[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(float & d0, float & d1, float & d2, float & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ float const & c0, float const & c1, float const & c2, float const & c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "f"(c0), "f"(c1), "f"(c2), "f"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x4_F32TF32TF32F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x8 TN -+struct SM80_16x8x8_F32TF32TF32F32_TN -+{ -+ using DRegisters = float[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = float[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(float & d0, float & d1, float & d2, float & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ float const & c0, float const & c1, float const & c2, float const & c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "f"(c0), "f"(c1), "f"(c2), "f"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x8_F32TF32TF32F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 8x8x4 TN -+struct SM80_8x8x4_F64F64F64F64_TN -+{ -+ using DRegisters = double[2]; -+ using ARegisters = double[1]; -+ using BRegisters = double[1]; -+ using CRegisters = double[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(double & d0, double & d1, -+ double const& a0, -+ double const& b0, -+ double const& c0, double const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 " -+ "{%0, %1}," -+ "{%2}," -+ "{%3}," -+ "{%4, %5};\n" -+ : "=d"(d0), "=d"(d1) -+ : "d"(a0), -+ "d"(b0), -+ "d"(c0), "d"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x4_F64F64F64F64_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+// MMA 8x8x4 TN with Planar Complex multiplication -+struct SM80_8x8x4_C64C64C64C64_TN -+{ -+ using DRegisters = complex[2]; -+ using ARegisters = complex[1]; -+ using BRegisters = complex[1]; -+ using CRegisters = complex[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(complex & d0, complex & d1, -+ complex const& a0, -+ complex const& b0, -+ complex const& c0, complex const& c1) -+ { -+ // Because thrust::complex does not provide a mutable ref -+ double& rd0 = reinterpret_cast(d0)[0]; -+ double& id0 = reinterpret_cast(d0)[1]; -+ double& rd1 = reinterpret_cast(d1)[0]; -+ double& id1 = reinterpret_cast(d1)[1]; -+ -+ // d.real() = a.real() * b.real() + c.real(); -+ SM80_8x8x4_F64F64F64F64_TN::fma( -+ rd0, rd1, -+ a0.real(), -+ b0.real(), -+ c0.real(), c1.real()); -+ -+ // d.imag() = a.imag() * b.real() + c.imag(); -+ SM80_8x8x4_F64F64F64F64_TN::fma( -+ id0, id1, -+ a0.imag(), -+ b0.real(), -+ c0.imag(), c1.imag()); -+ -+ // d.real() = -a.imag() * b.imag() + d.real(); -+ SM80_8x8x4_F64F64F64F64_TN::fma( -+ rd0, rd1, -+ -a0.imag(), -+ b0.imag(), -+ d0.real(), d1.real()); -+ -+ // d.imag() = a.real() * b.imag() + d.imag(); -+ SM80_8x8x4_F64F64F64F64_TN::fma( -+ id0, id1, -+ a0.real(), -+ b0.imag(), -+ d0.imag(), d1.imag()); -+ } -+}; -+ -+// MMA 8x8x4 TN with Gaussian Complex multiplication: -+// (a + bi)*(c + di) -+// yields -+// t0 += a*c -+// t1 += b*d -+// t2 += (a+b)*(c+d) -+// then -+// re = t0 - t1 -+// im = t2 - t0 - t1 -+struct SM80_8x8x4_GC64C64C64GC64_TN -+{ -+ struct GaussComplex { -+ double t0, t1, t2; -+ -+ CUTE_HOST_DEVICE //constexpr -+ operator complex() const { return complex(t0 - t1, t2 - t0 - t1); } -+ -+ CUTE_HOST_DEVICE friend //constexpr -+ complex operator*(GaussComplex const& a, complex const& b) { return static_cast>(a) * b; } -+ CUTE_HOST_DEVICE friend //constexpr -+ complex operator*(complex const& a, GaussComplex const& b) { return b * a; } -+ -+ CUTE_HOST_DEVICE friend //constexpr -+ complex operator+(GaussComplex const& a, complex const& b) { return static_cast>(a) + b; } -+ CUTE_HOST_DEVICE friend //constexpr -+ complex operator+(complex const& a, GaussComplex const& b) { return b + a; } -+ }; -+ -+ using DRegisters = GaussComplex[2]; -+ using ARegisters = complex[1]; -+ using BRegisters = complex[1]; -+ using CRegisters = GaussComplex[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(GaussComplex & d0, GaussComplex & d1, -+ complex const& a0, -+ complex const& b0, -+ GaussComplex const& c0, GaussComplex const& c1) -+ { -+ SM80_8x8x4_F64F64F64F64_TN::fma(d0.t0, d1.t0, -+ a0.real(), -+ b0.real(), -+ c0.t0, c1.t0); -+ SM80_8x8x4_F64F64F64F64_TN::fma(d0.t1, d1.t1, -+ a0.imag(), -+ b0.imag(), -+ c0.t1, c1.t1); -+ SM80_8x8x4_F64F64F64F64_TN::fma(d0.t2, d1.t2, -+ a0.real() + a0.imag(), -+ b0.real() + b0.imag(), -+ c0.t2, c1.t2); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 8x8x16 TN -+struct SM80_8x8x16_S32S8S8S32_TN -+{ -+ using DRegisters = uint32_t[2]; -+ using ARegisters = uint32_t[1]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, -+ uint32_t const& a0, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 " -+ "{%0, %1}," -+ "{%2}," -+ "{%3}," -+ "{%4, %5};\n" -+ : "=r"(d0), "=r"(d1) -+ : "r"(a0), -+ "r"(b0), -+ "r"(c0), "r"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x16_S32S8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 8x8x16 TN -+struct SM80_8x8x16_S32S8S8S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[2]; -+ using ARegisters = uint32_t[1]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, -+ uint32_t const& a0, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite " -+ "{%0, %1}," -+ "{%2}," -+ "{%3}," -+ "{%4, %5};\n" -+ : "=r"(d0), "=r"(d1) -+ : "r"(a0), -+ "r"(b0), -+ "r"(c0), "r"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x16_S32S8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x16 TN -+struct SM80_16x8x16_S32S8S8S32_TN -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_S32S8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x16 TN -+struct SM80_16x8x16_S32S8S8S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_S32S8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x32 TN -+struct SM80_16x8x32_S32S8S8S32_TN -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32S8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x32 TN -+struct SM80_16x8x32_S32S8S8S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32S8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 8x8x16 TN -+struct SM80_8x8x16_S32S8U8S32_TN -+{ -+ using DRegisters = uint32_t[2]; -+ using ARegisters = uint32_t[1]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, -+ uint32_t const& a0, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m8n8k16.row.col.s32.s8.u8.s32 " -+ "{%0, %1}," -+ "{%2}," -+ "{%3}," -+ "{%4, %5};\n" -+ : "=r"(d0), "=r"(d1) -+ : "r"(a0), -+ "r"(b0), -+ "r"(c0), "r"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x16_S32S8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 8x8x16 TN -+struct SM80_8x8x16_S32S8U8S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[2]; -+ using ARegisters = uint32_t[1]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, -+ uint32_t const& a0, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m8n8k16.row.col.s32.s8.u8.s32.satfinite " -+ "{%0, %1}," -+ "{%2}," -+ "{%3}," -+ "{%4, %5};\n" -+ : "=r"(d0), "=r"(d1) -+ : "r"(a0), -+ "r"(b0), -+ "r"(c0), "r"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x16_S32S8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x16 TN -+struct SM80_16x8x16_S32S8U8S32_TN -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_S32S8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x16 TN -+struct SM80_16x8x16_S32S8U8S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_S32S8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x32 TN -+struct SM80_16x8x32_S32S8U8S32_TN -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32S8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x32 TN -+struct SM80_16x8x32_S32S8U8S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32S8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 8x8x16 TN -+struct SM80_8x8x16_S32U8S8S32_TN -+{ -+ using DRegisters = uint32_t[2]; -+ using ARegisters = uint32_t[1]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, -+ uint32_t const& a0, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m8n8k16.row.col.s32.u8.s8.s32 " -+ "{%0, %1}," -+ "{%2}," -+ "{%3}," -+ "{%4, %5};\n" -+ : "=r"(d0), "=r"(d1) -+ : "r"(a0), -+ "r"(b0), -+ "r"(c0), "r"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x16_S32U8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 8x8x16 TN -+struct SM80_8x8x16_S32U8S8S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[2]; -+ using ARegisters = uint32_t[1]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, -+ uint32_t const& a0, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m8n8k16.row.col.s32.u8.s8.s32.satfinite " -+ "{%0, %1}," -+ "{%2}," -+ "{%3}," -+ "{%4, %5};\n" -+ : "=r"(d0), "=r"(d1) -+ : "r"(a0), -+ "r"(b0), -+ "r"(c0), "r"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x16_S32U8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x16 TN -+struct SM80_16x8x16_S32U8S8S32_TN -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_S32U8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x16 TN -+struct SM80_16x8x16_S32U8S8S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_S32U8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x32 TN -+struct SM80_16x8x32_S32U8S8S32_TN -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32U8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x32 TN -+struct SM80_16x8x32_S32U8S8S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32U8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 8x8x16 TN -+struct SM80_8x8x16_S32U8U8S32_TN -+{ -+ using DRegisters = uint32_t[2]; -+ using ARegisters = uint32_t[1]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, -+ uint32_t const& a0, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m8n8k16.row.col.s32.u8.u8.s32 " -+ "{%0, %1}," -+ "{%2}," -+ "{%3}," -+ "{%4, %5};\n" -+ : "=r"(d0), "=r"(d1) -+ : "r"(a0), -+ "r"(b0), -+ "r"(c0), "r"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x16_S32U8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 8x8x16 TN -+struct SM80_8x8x16_S32U8U8S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[2]; -+ using ARegisters = uint32_t[1]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, -+ uint32_t const& a0, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m8n8k16.row.col.s32.u8.u8.s32.satfinite " -+ "{%0, %1}," -+ "{%2}," -+ "{%3}," -+ "{%4, %5};\n" -+ : "=r"(d0), "=r"(d1) -+ : "r"(a0), -+ "r"(b0), -+ "r"(c0), "r"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x16_S32U8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x16 TN -+struct SM80_16x8x16_S32U8U8S32_TN -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_S32U8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x16 TN -+struct SM80_16x8x16_S32U8U8S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x16_S32U8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x32 TN -+struct SM80_16x8x32_S32U8U8S32_TN -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32U8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x32 TN -+struct SM80_16x8x32_S32U8U8S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32U8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 8x8x32 TN -+struct SM80_8x8x32_S32S4S4S32_TN -+{ -+ using DRegisters = uint32_t[2]; -+ using ARegisters = uint32_t[1]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, -+ uint32_t const& a0, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m8n8k32.row.col.s32.s4.s4.s32 " -+ "{%0, %1}," -+ "{%2}," -+ "{%3}," -+ "{%4, %5};\n" -+ : "=r"(d0), "=r"(d1) -+ : "r"(a0), -+ "r"(b0), -+ "r"(c0), "r"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x32_S32S4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 8x8x32 TN -+struct SM80_8x8x32_S32S4S4S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[2]; -+ using ARegisters = uint32_t[1]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, -+ uint32_t const& a0, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m8n8k32.row.col.s32.s4.s4.s32.satfinite " -+ "{%0, %1}," -+ "{%2}," -+ "{%3}," -+ "{%4, %5};\n" -+ : "=r"(d0), "=r"(d1) -+ : "r"(a0), -+ "r"(b0), -+ "r"(c0), "r"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x32_S32S4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x32 TN -+struct SM80_16x8x32_S32S4S4S32_TN -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.s4.s4.s32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32S4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x32 TN -+struct SM80_16x8x32_S32S4S4S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.s4.s4.s32.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32S4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x64 TN -+struct SM80_16x8x64_S32S4S4S32_TN -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x64_S32S4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x64 TN -+struct SM80_16x8x64_S32S4S4S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x64_S32S4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 8x8x32 TN -+struct SM80_8x8x32_S32S4U4S32_TN -+{ -+ using DRegisters = uint32_t[2]; -+ using ARegisters = uint32_t[1]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, -+ uint32_t const& a0, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m8n8k32.row.col.s32.s4.u4.s32 " -+ "{%0, %1}," -+ "{%2}," -+ "{%3}," -+ "{%4, %5};\n" -+ : "=r"(d0), "=r"(d1) -+ : "r"(a0), -+ "r"(b0), -+ "r"(c0), "r"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x32_S32S4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 8x8x32 TN -+struct SM80_8x8x32_S32S4U4S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[2]; -+ using ARegisters = uint32_t[1]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, -+ uint32_t const& a0, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m8n8k32.row.col.s32.s4.u4.s32.satfinite " -+ "{%0, %1}," -+ "{%2}," -+ "{%3}," -+ "{%4, %5};\n" -+ : "=r"(d0), "=r"(d1) -+ : "r"(a0), -+ "r"(b0), -+ "r"(c0), "r"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x32_S32S4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x32 TN -+struct SM80_16x8x32_S32S4U4S32_TN -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.s4.u4.s32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32S4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x32 TN -+struct SM80_16x8x32_S32S4U4S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.s4.u4.s32.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32S4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x64 TN -+struct SM80_16x8x64_S32S4U4S32_TN -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x64_S32S4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x64 TN -+struct SM80_16x8x64_S32S4U4S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x64_S32S4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 8x8x32 TN -+struct SM80_8x8x32_S32U4S4S32_TN -+{ -+ using DRegisters = uint32_t[2]; -+ using ARegisters = uint32_t[1]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, -+ uint32_t const& a0, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m8n8k32.row.col.s32.u4.s4.s32 " -+ "{%0, %1}," -+ "{%2}," -+ "{%3}," -+ "{%4, %5};\n" -+ : "=r"(d0), "=r"(d1) -+ : "r"(a0), -+ "r"(b0), -+ "r"(c0), "r"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x32_S32U4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 8x8x32 TN -+struct SM80_8x8x32_S32U4S4S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[2]; -+ using ARegisters = uint32_t[1]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, -+ uint32_t const& a0, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m8n8k32.row.col.s32.u4.s4.s32.satfinite " -+ "{%0, %1}," -+ "{%2}," -+ "{%3}," -+ "{%4, %5};\n" -+ : "=r"(d0), "=r"(d1) -+ : "r"(a0), -+ "r"(b0), -+ "r"(c0), "r"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x32_S32U4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x32 TN -+struct SM80_16x8x32_S32U4S4S32_TN -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.u4.s4.s32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32U4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x32 TN -+struct SM80_16x8x32_S32U4S4S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.u4.s4.s32.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32U4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x64 TN -+struct SM80_16x8x64_S32U4S4S32_TN -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x64_S32U4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x64 TN -+struct SM80_16x8x64_S32U4S4S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x64_S32U4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 8x8x32 TN -+struct SM80_8x8x32_S32U4U4S32_TN -+{ -+ using DRegisters = uint32_t[2]; -+ using ARegisters = uint32_t[1]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, -+ uint32_t const& a0, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m8n8k32.row.col.s32.u4.u4.s32 " -+ "{%0, %1}," -+ "{%2}," -+ "{%3}," -+ "{%4, %5};\n" -+ : "=r"(d0), "=r"(d1) -+ : "r"(a0), -+ "r"(b0), -+ "r"(c0), "r"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x32_S32U4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 8x8x32 TN -+struct SM80_8x8x32_S32U4U4S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[2]; -+ using ARegisters = uint32_t[1]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, -+ uint32_t const& a0, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m8n8k32.row.col.s32.u4.u4.s32.satfinite " -+ "{%0, %1}," -+ "{%2}," -+ "{%3}," -+ "{%4, %5};\n" -+ : "=r"(d0), "=r"(d1) -+ : "r"(a0), -+ "r"(b0), -+ "r"(c0), "r"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x32_S32U4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x32 TN -+struct SM80_16x8x32_S32U4U4S32_TN -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.u4.u4.s32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32U4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x32 TN -+struct SM80_16x8x32_S32U4U4S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.u4.u4.s32.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x32_S32U4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x64 TN -+struct SM80_16x8x64_S32U4U4S32_TN -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x64_S32U4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x64 TN -+struct SM80_16x8x64_S32U4U4S32_TN_SATURATE -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x64_S32U4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 8x8x128 TN -+struct SM80_8x8x128_S32U1U1S32_TN_XORPOPC -+{ -+ using DRegisters = uint32_t[2]; -+ using ARegisters = uint32_t[1]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, -+ uint32_t const& a0, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.xor.popc " -+ "{%0, %1}," -+ "{%2}," -+ "{%3}," -+ "{%4, %5};\n" -+ : "=r"(d0), "=r"(d1) -+ : "r"(a0), -+ "r"(b0), -+ "r"(c0), "r"(c1)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_8x8x128_S32U1U1S32_TN_XORPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x128 TN -+struct SM80_16x8x128_S32U1U1S32_TN_XORPOPC -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[2]; -+ using BRegisters = uint32_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, -+ uint32_t const& b0, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.xor.popc " -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), -+ "r"(b0), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x128_S32U1U1S32_TN_XORPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x256 TN -+struct SM80_16x8x256_S32U1U1S32_TN_XORPOPC -+{ -+ using DRegisters = uint32_t[4]; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint32_t[2]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint32_t const& b0, uint32_t const& b1, -+ uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM80_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "r"(b0), "r"(b1), -+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM80_16x8x256_S32U1U1S32_TN_XORPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cute -diff --git a/3rdparty/cutlass/include/cute/arch/mma_sm90.hpp b/3rdparty/cutlass/include/cute/arch/mma_sm90.hpp -new file mode 100644 -index 0000000..08fe2b2 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/arch/mma_sm90.hpp -@@ -0,0 +1,961 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include -+ -+#include -+ -+// Config -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) -+# define CUTE_ARCH_MMA_SM90_ENABLED -+#endif -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cute { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x4 TN -+struct SM90_16x8x4_F64F64F64F64_TN -+{ -+ using DRegisters = double[4]; -+ using ARegisters = double[2]; -+ using BRegisters = double[1]; -+ using CRegisters = double[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(double & d0, double & d1, double & d2, double & d3, -+ double const& a0, double const& a1, -+ double const& b0, -+ double const& c0, double const& c1, double const& c2, double const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64" -+ "{%0, %1, %2, %3}," -+ "{%4, %5}," -+ "{%6}," -+ "{%7, %8, %9, %10};\n" -+ : "=d"(d0), "=d"(d1), "=d"(d2), "=d"(d3) -+ : "d"(a0), "d"(a1), -+ "d"(b0), -+ "d"(c0), "d"(c1), "d"(c2), "d"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_16x8x4_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x8 TN -+struct SM90_16x8x8_F64F64F64F64_TN -+{ -+ using DRegisters = double[4]; -+ using ARegisters = double[4]; -+ using BRegisters = double[2]; -+ using CRegisters = double[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(double & d0, double & d1, double & d2, double & d3, -+ double const& a0, double const& a1, double const& a2, double const& a3, -+ double const& b0, double const& b1, -+ double const& c0, double const& c1, double const& c2, double const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k8.row.col.f64.f64.f64.f64" -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ "{%8, %9}," -+ "{%10, %11, %12, %13};\n" -+ : "=d"(d0), "=d"(d1), "=d"(d2), "=d"(d3) -+ : "d"(a0), "d"(a1), "d"(a2), "d"(a3), -+ "d"(b0), "d"(b1), -+ "d"(c0), "d"(c1), "d"(c2), "d"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_16x8x8_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x16 TN -+struct SM90_16x8x16_F64F64F64F64_TN -+{ -+ using DRegisters = double[4]; -+ using ARegisters = double[8]; -+ using BRegisters = double[4]; -+ using CRegisters = double[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(double & d0, double & d1, double & d2, double & d3, -+ double const& a0, double const& a1, double const& a2, double const& a3, -+ double const& a4, double const& a5, double const& a6, double const& a7, -+ double const& b0, double const& b1, double const& b2, double const& b3, -+ double const& c0, double const& c1, double const& c2, double const& c3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.f64.f64.f64.f64" -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7, %8, %9, %10, %11}," -+ "{%12, %13, %14, %15}," -+ "{%16, %17, %18, %19};\n" -+ : "=d"(d0), "=d"(d1), "=d"(d2), "=d"(d3) -+ : "d"(a0), "d"(a1), "d"(a2), "d"(a3), -+ "d"(a4), "d"(a5), "d"(a6), "d"(a7), -+ "d"(b0), "d"(b1), "d"(b2), "d"(b3), -+ "d"(c0), "d"(c1), "d"(c2), "d"(c3)); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_16x8x16_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x4 TN -+struct SM90_16x8x4_C64C64C64C64_TN -+{ -+ using DRegisters = complex[4]; -+ using ARegisters = complex[2]; -+ using BRegisters = complex[1]; -+ using CRegisters = complex[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(complex & d0, complex & d1, -+ complex & d2, complex & d3, -+ complex const& a0, complex const& a1, -+ complex const& b0, -+ complex const& c0, complex const& c1, -+ complex const& c2, complex const& c3) -+ { -+ // Because thrust::complex does not provide a mutable ref -+ double& rd0 = reinterpret_cast(d0)[0]; -+ double& id0 = reinterpret_cast(d0)[1]; -+ double& rd1 = reinterpret_cast(d1)[0]; -+ double& id1 = reinterpret_cast(d1)[1]; -+ double& rd2 = reinterpret_cast(d2)[0]; -+ double& id2 = reinterpret_cast(d2)[1]; -+ double& rd3 = reinterpret_cast(d3)[0]; -+ double& id3 = reinterpret_cast(d3)[1]; -+ -+ // d.real() = a.real() * b.real() + c.real(); -+ SM90_16x8x4_F64F64F64F64_TN::fma( -+ rd0, rd1, rd2, rd3, -+ a0.real(), a1.real(), -+ b0.real(), -+ c0.real(), c1.real(), c2.real(), c3.real()); -+ -+ // d.imag() = a.imag() * b.real() + c.imag(); -+ SM90_16x8x4_F64F64F64F64_TN::fma( -+ id0, id1, id2, id3, -+ a0.imag(), a1.imag(), -+ b0.real(), -+ c0.imag(), c1.imag(), c2.imag(), c3.imag()); -+ -+ // d.real() = -a.imag() * b.imag() + d.real(); -+ SM90_16x8x4_F64F64F64F64_TN::fma( -+ rd0, rd1, rd2, rd3, -+ -a0.imag(), -a1.imag(), -+ b0.imag(), -+ d0.real(), d1.real(), d2.real(), d3.real()); -+ -+ // d.imag() = a.real() * b.imag() + d.imag(); -+ SM90_16x8x4_F64F64F64F64_TN::fma( -+ id0, id1, id2, id3, -+ a0.real(), a1.real(), -+ b0.imag(), -+ d0.imag(), d1.imag(), d2.imag(), d3.imag()); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x8 TN -+struct SM90_16x8x8_C64C64C64C64_TN -+{ -+ using DRegisters = complex[4]; -+ using ARegisters = complex[4]; -+ using BRegisters = complex[2]; -+ using CRegisters = complex[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(complex & d0, complex & d1, -+ complex & d2, complex & d3, -+ complex const& a0, complex const& a1, -+ complex const& a2, complex const& a3, -+ complex const& b0, complex const& b1, -+ complex const& c0, complex const& c1, -+ complex const& c2, complex const& c3) -+ { -+ // Because thrust::complex does not provide a mutable ref -+ double& rd0 = reinterpret_cast(d0)[0]; -+ double& id0 = reinterpret_cast(d0)[1]; -+ double& rd1 = reinterpret_cast(d1)[0]; -+ double& id1 = reinterpret_cast(d1)[1]; -+ double& rd2 = reinterpret_cast(d2)[0]; -+ double& id2 = reinterpret_cast(d2)[1]; -+ double& rd3 = reinterpret_cast(d3)[0]; -+ double& id3 = reinterpret_cast(d3)[1]; -+ -+ // d.real() = a.real() * b.real() + c.real(); -+ SM90_16x8x8_F64F64F64F64_TN::fma( -+ rd0, rd1, rd2, rd3, -+ a0.real(), a1.real(), a2.real(), a3.real(), -+ b0.real(), b1.real(), -+ c0.real(), c1.real(), c2.real(), c3.real()); -+ -+ // d.imag() = a.imag() * b.real() + c.imag(); -+ SM90_16x8x8_F64F64F64F64_TN::fma( -+ id0, id1, id2, id3, -+ a0.imag(), a1.imag(), a2.imag(), a3.imag(), -+ b0.real(), b1.real(), -+ c0.imag(), c1.imag(), c2.imag(), c3.imag()); -+ -+ // d.real() = -a.imag() * b.imag() + d.real(); -+ SM90_16x8x8_F64F64F64F64_TN::fma( -+ rd0, rd1, rd2, rd3, -+ -a0.imag(), -a1.imag(), -a2.imag(), -a3.imag(), -+ b0.imag(), b1.imag(), -+ d0.real(), d1.real(), d2.real(), d3.real()); -+ -+ // d.imag() = a.real() * b.imag() + d.imag(); -+ SM90_16x8x8_F64F64F64F64_TN::fma( -+ id0, id1, id2, id3, -+ a0.real(), a1.real(), a2.real(), a3.real(), -+ b0.imag(), b1.imag(), -+ d0.imag(), d1.imag(), d2.imag(), d3.imag()); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 16x8x16 TN -+struct SM90_16x8x16_C64C64C64C64_TN -+{ -+ using DRegisters = complex[4]; -+ using ARegisters = complex[8]; -+ using BRegisters = complex[4]; -+ using CRegisters = complex[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(complex & d0, complex & d1, -+ complex & d2, complex & d3, -+ complex const& a0, complex const& a1, -+ complex const& a2, complex const& a3, -+ complex const& a4, complex const& a5, -+ complex const& a6, complex const& a7, -+ complex const& b0, complex const& b1, -+ complex const& b2, complex const& b3, -+ complex const& c0, complex const& c1, -+ complex const& c2, complex const& c3) -+ { -+ // Because thrust::complex does not provide a mutable ref -+ double& rd0 = reinterpret_cast(d0)[0]; -+ double& id0 = reinterpret_cast(d0)[1]; -+ double& rd1 = reinterpret_cast(d1)[0]; -+ double& id1 = reinterpret_cast(d1)[1]; -+ double& rd2 = reinterpret_cast(d2)[0]; -+ double& id2 = reinterpret_cast(d2)[1]; -+ double& rd3 = reinterpret_cast(d3)[0]; -+ double& id3 = reinterpret_cast(d3)[1]; -+ -+ // d.real() = a.real() * b.real() + c.real(); -+ SM90_16x8x16_F64F64F64F64_TN::fma( -+ rd0, rd1, rd2, rd3, -+ a0.real(), a1.real(), a2.real(), a3.real(), -+ a4.real(), a5.real(), a6.real(), a7.real(), -+ b0.real(), b1.real(), b2.real(), b3.real(), -+ c0.real(), c1.real(), c2.real(), c3.real()); -+ -+ // d.imag() = a.imag() * b.real() + c.imag(); -+ SM90_16x8x16_F64F64F64F64_TN::fma( -+ id0, id1, id2, id3, -+ a0.imag(), a1.imag(), a2.imag(), a3.imag(), -+ a4.imag(), a5.imag(), a6.imag(), a7.imag(), -+ b0.real(), b1.real(), b2.real(), b3.real(), -+ c0.imag(), c1.imag(), c2.imag(), c3.imag()); -+ -+ // d.real() = -a.imag() * b.imag() + d.real(); -+ SM90_16x8x16_F64F64F64F64_TN::fma( -+ rd0, rd1, rd2, rd3, -+ -a0.imag(), -a1.imag(), -a2.imag(), -a3.imag(), -+ -a4.imag(), -a5.imag(), -a6.imag(), -a7.imag(), -+ b0.imag(), b1.imag(), b2.imag(), b3.imag(), -+ d0.real(), d1.real(), d2.real(), d3.real()); -+ -+ // d.imag() = a.real() * b.imag() + d.imag(); -+ SM90_16x8x16_F64F64F64F64_TN::fma( -+ id0, id1, id2, id3, -+ a0.real(), a1.real(), a2.real(), a3.real(), -+ a4.real(), a5.real(), a6.real(), a7.real(), -+ b0.imag(), b1.imag(), b2.imag(), b3.imag(), -+ d0.imag(), d1.imag(), d2.imag(), d3.imag()); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cute -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include -+#include -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cute { -+namespace GMMA { -+ -+template< -+ class ElementA, -+ class ElementB, -+ class ElementC, -+ class TileShape_MNK, -+ GMMA::Major MajorA = GMMA::Major::K, -+ GMMA::Major MajorB = GMMA::Major::K, -+ auto... Args // e.g. GMMA::ScaleOut::One, [GMMA::ScaleIn::One, GMMA::ScaleIn::One] -+ // But most commonly leave empty for defaults -+> -+CUTE_HOST_DEVICE constexpr -+auto -+ss_op_selector() -+{ -+ static_assert(is_static::value, "TileShape_MNK must be static."); -+ static_assert(rank(TileShape_MNK{}) == 3, "TileShape_MNK must be rank 3."); -+ static_assert(size<0>(TileShape_MNK{}) % 64 == 0, "Tile_M must be a multiple of 64."); -+ auto Tile_N = size<1>(TileShape_MNK{}); -+ -+ // FP16 accumulator -+ if constexpr (std::is_same_v) { -+ static_assert(std::is_same_v, "Element types for AB must be half if ElementC is half."); -+ static_assert(std::is_same_v, "Element types for AB must be half if ElementC is half."); -+ static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); -+ -+ // Dispatch against the Tile N mode size -+ if constexpr (Tile_N % 256 == 0) { -+ return SM90_64x256x16_F16F16F16_SS{}; -+ } -+ else if constexpr (Tile_N % 192 == 0) { -+ return SM90_64x192x16_F16F16F16_SS{}; -+ } -+ else if constexpr (Tile_N % 128 == 0) { -+ return SM90_64x128x16_F16F16F16_SS{}; -+ } -+ else if constexpr (Tile_N % 96 == 0) { -+ return SM90_64x96x16_F16F16F16_SS{}; -+ } -+ else if constexpr (Tile_N % 64 == 0) { -+ return SM90_64x64x16_F16F16F16_SS{}; -+ } -+ else if constexpr (Tile_N % 32 == 0) { -+ return SM90_64x32x16_F16F16F16_SS{}; -+ } -+ else if constexpr (Tile_N % 16 == 0) { -+ return SM90_64x16x16_F16F16F16_SS{}; -+ } -+ else if constexpr (Tile_N % 8 == 0) { -+ return SM90_64x8x16_F16F16F16_SS{}; -+ } -+ else { -+ static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); -+ } -+ } -+ -+ // FP32 accumulator -+ else if constexpr (std::is_same_v) { -+ -+ // FP16 inputs -+ if constexpr (std::is_same_v) { -+ static_assert(std::is_same_v, "ElementA and ElementB must be the same type for this config."); -+ static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); -+ if constexpr (Tile_N % 256 == 0) { -+ return SM90_64x256x16_F32F16F16_SS{}; -+ } -+ else if constexpr (Tile_N % 192 == 0) { -+ return SM90_64x192x16_F32F16F16_SS{}; -+ } -+ else if constexpr (Tile_N % 128 == 0) { -+ return SM90_64x128x16_F32F16F16_SS{}; -+ } -+ else if constexpr (Tile_N % 96 == 0) { -+ return SM90_64x96x16_F32F16F16_SS{}; -+ } -+ else if constexpr (Tile_N % 64 == 0) { -+ return SM90_64x64x16_F32F16F16_SS{}; -+ } -+ else if constexpr (Tile_N % 32 == 0) { -+ return SM90_64x32x16_F32F16F16_SS{}; -+ } -+ else if constexpr (Tile_N % 16 == 0) { -+ return SM90_64x16x16_F32F16F16_SS{}; -+ } -+ else if constexpr (Tile_N % 8 == 0) { -+ return SM90_64x8x16_F32F16F16_SS{}; -+ } -+ else { -+ static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); -+ } -+ } -+ -+ // BF16 inputs -+ else if constexpr (std::is_same_v) { -+ static_assert(std::is_same_v, "ElementA and ElementB must be the same type for this config."); -+ static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); -+ -+ if constexpr (Tile_N % 256 == 0) { -+ return SM90_64x256x16_F32BF16BF16_SS{}; -+ } -+ else if constexpr (Tile_N % 192 == 0) { -+ return SM90_64x192x16_F32BF16BF16_SS{}; -+ } -+ else if constexpr (Tile_N % 128 == 0) { -+ return SM90_64x128x16_F32BF16BF16_SS{}; -+ } -+ else if constexpr (Tile_N % 96 == 0) { -+ return SM90_64x96x16_F32BF16BF16_SS{}; -+ } -+ else if constexpr (Tile_N % 64 == 0) { -+ return SM90_64x64x16_F32BF16BF16_SS{}; -+ } -+ else if constexpr (Tile_N % 32 == 0) { -+ return SM90_64x32x16_F32BF16BF16_SS{}; -+ } -+ else if constexpr (Tile_N % 16 == 0) { -+ return SM90_64x16x16_F32BF16BF16_SS{}; -+ } -+ else if constexpr (Tile_N % 8 == 0) { -+ return SM90_64x8x16_F32BF16BF16_SS{}; -+ } -+ else { -+ static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); -+ } -+ } -+ -+ // TF32 inputs -+ else if constexpr (std::is_same_v) { -+ static_assert(std::is_same_v, "ElementA and ElementB must be the same type for this config."); -+ static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); -+ static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); -+ static_assert(size<2>(TileShape_MNK{}) % 8 == 0, "Tile_K must be a multiple of 8."); -+ -+ if constexpr (Tile_N % 256 == 0) { -+ return SM90_64x256x8_F32TF32TF32_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 192 == 0) { -+ return SM90_64x192x8_F32TF32TF32_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 128 == 0) { -+ return SM90_64x128x8_F32TF32TF32_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 96 == 0) { -+ return SM90_64x96x8_F32TF32TF32_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 64 == 0) { -+ return SM90_64x64x8_F32TF32TF32_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 32 == 0) { -+ return SM90_64x32x8_F32TF32TF32_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 16 == 0) { -+ return SM90_64x16x8_F32TF32TF32_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 8 == 0) { -+ return SM90_64x8x8_F32TF32TF32_SS_TN{}; -+ } -+ else { -+ static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); -+ } -+ } -+ else { -+ static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); -+ } -+ } -+ -+ // S32 accumulator -+ else if constexpr (std::is_same_v) { -+ static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); -+ static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); -+ static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); -+ -+ // ElementA == int8_t && ElementB == int8_t -+ if constexpr (std::is_same_v && std::is_same_v) { -+ if constexpr (Tile_N % 256 == 0) { -+ return SM90_64x256x32_S32S8S8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 192 == 0) { -+ return SM90_64x192x32_S32S8S8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 128 == 0) { -+ return SM90_64x128x32_S32S8S8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 96 == 0) { -+ return SM90_64x96x32_S32S8S8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 64 == 0) { -+ return SM90_64x64x32_S32S8S8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 32 == 0) { -+ return SM90_64x32x32_S32S8S8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 16 == 0) { -+ return SM90_64x16x32_S32S8S8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 8 == 0) { -+ return SM90_64x8x32_S32S8S8_SS_TN{}; -+ } -+ else { -+ static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); -+ } -+ } -+ -+ // ElementA == int8_t && ElementB == uint8_t -+ else if constexpr (std::is_same_v && std::is_same_v) { -+ static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); -+ -+ if constexpr (Tile_N % 256 == 0) { -+ return SM90_64x256x32_S32S8U8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 192 == 0) { -+ return SM90_64x192x32_S32S8U8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 128 == 0) { -+ return SM90_64x128x32_S32S8U8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 96 == 0) { -+ return SM90_64x96x32_S32S8U8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 64 == 0) { -+ return SM90_64x64x32_S32S8U8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 32 == 0) { -+ return SM90_64x32x32_S32S8U8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 16 == 0) { -+ return SM90_64x16x32_S32S8U8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 8 == 0) { -+ return SM90_64x8x32_S32S8U8_SS_TN{}; -+ } -+ else { -+ static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); -+ } -+ } -+ -+ // ElementA == uint8_t && ElementB == int8_t -+ else if constexpr (std::is_same_v && std::is_same_v) { -+ static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); -+ -+ if constexpr (Tile_N % 256 == 0) { -+ return SM90_64x256x32_S32U8S8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 192 == 0) { -+ return SM90_64x192x32_S32U8S8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 128 == 0) { -+ return SM90_64x128x32_S32U8S8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 96 == 0) { -+ return SM90_64x96x32_S32U8S8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 64 == 0) { -+ return SM90_64x64x32_S32U8S8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 32 == 0) { -+ return SM90_64x32x32_S32U8S8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 16 == 0) { -+ return SM90_64x16x32_S32U8S8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 8 == 0) { -+ return SM90_64x8x32_S32U8S8_SS_TN{}; -+ } -+ else { -+ static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); -+ } -+ } -+ -+ // ElementA == uint8_t && ElementB == uint8_t -+ else if constexpr (std::is_same_v && std::is_same_v) { -+ static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); -+ -+ if constexpr (Tile_N % 256 == 0) { -+ return SM90_64x256x32_S32U8U8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 192 == 0) { -+ return SM90_64x192x32_S32U8U8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 128 == 0) { -+ return SM90_64x128x32_S32U8U8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 96 == 0) { -+ return SM90_64x96x32_S32U8U8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 64 == 0) { -+ return SM90_64x64x32_S32U8U8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 32 == 0) { -+ return SM90_64x32x32_S32U8U8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 16 == 0) { -+ return SM90_64x16x32_S32U8U8_SS_TN{}; -+ } -+ else if constexpr (Tile_N % 8 == 0) { -+ return SM90_64x8x32_S32U8U8_SS_TN{}; -+ } -+ else { -+ static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); -+ } -+ } -+ } -+ -+ // Unknown accumulator type -+ else { -+ static_assert(sizeof(ElementC) == 0, "Unknown ElementC accumulator type."); -+ } -+} -+ -+template< -+ class ElementA, -+ class ElementB, -+ class ElementC, -+ class TileShape_MNK, -+ GMMA::Major MajorA = GMMA::Major::K, -+ GMMA::Major MajorB = GMMA::Major::K, -+ auto... Args // e.g. GMMA::ScaleOut::One, [GMMA::ScaleIn::One, GMMA::ScaleIn::One] -+ // But most commonly leave empty for defaults -+> -+CUTE_HOST_DEVICE constexpr -+auto -+rs_op_selector() -+{ -+ static_assert(is_static::value, "TileShape_MNK must be static."); -+ static_assert(rank(TileShape_MNK{}) == 3, "TileShape_MNK must be rank 3."); -+ static_assert(size<0>(TileShape_MNK{}) % 64 == 0, "Tile_M must be a multiple of 64."); -+ static_assert(MajorA == GMMA::Major::K, "Register source A operand GMMAs must have K-major A layout."); -+ auto Tile_N = size<1>(TileShape_MNK{}); -+ -+ // FP16 accumulator -+ if constexpr (std::is_same_v) { -+ static_assert(std::is_same_v, "Element types for AB must be half if ElementC is half."); -+ static_assert(std::is_same_v, "Element types for AB must be half if ElementC is half."); -+ static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); -+ -+ // Dispatch against the Tile N mode size -+ if constexpr (Tile_N % 256 == 0) { -+ return SM90_64x256x16_F16F16F16_RS{}; -+ } -+ else if constexpr (Tile_N % 192 == 0) { -+ return SM90_64x192x16_F16F16F16_RS{}; -+ } -+ else if constexpr (Tile_N % 128 == 0) { -+ return SM90_64x128x16_F16F16F16_RS{}; -+ } -+ else if constexpr (Tile_N % 96 == 0) { -+ return SM90_64x96x16_F16F16F16_RS{}; -+ } -+ else if constexpr (Tile_N % 64 == 0) { -+ return SM90_64x64x16_F16F16F16_RS{}; -+ } -+ else if constexpr (Tile_N % 32 == 0) { -+ return SM90_64x32x16_F16F16F16_RS{}; -+ } -+ else if constexpr (Tile_N % 16 == 0) { -+ return SM90_64x16x16_F16F16F16_RS{}; -+ } -+ else if constexpr (Tile_N % 8 == 0) { -+ return SM90_64x8x16_F16F16F16_RS{}; -+ } -+ else { -+ static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); -+ } -+ } -+ -+ // FP32 accumulator -+ else if constexpr (std::is_same_v) { -+ static_assert(std::is_same_v, "ElementA and ElementB must be the same type for this config."); -+ static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); -+ -+ // FP16 inputs -+ if constexpr (std::is_same_v) { -+ if constexpr (Tile_N % 256 == 0) { -+ return SM90_64x256x16_F32F16F16_RS{}; -+ } -+ else if constexpr (Tile_N % 192 == 0) { -+ return SM90_64x192x16_F32F16F16_RS{}; -+ } -+ else if constexpr (Tile_N % 128 == 0) { -+ return SM90_64x128x16_F32F16F16_RS{}; -+ } -+ else if constexpr (Tile_N % 96 == 0) { -+ return SM90_64x96x16_F32F16F16_RS{}; -+ } -+ else if constexpr (Tile_N % 64 == 0) { -+ return SM90_64x64x16_F32F16F16_RS{}; -+ } -+ else if constexpr (Tile_N % 32 == 0) { -+ return SM90_64x32x16_F32F16F16_RS{}; -+ } -+ else if constexpr (Tile_N % 16 == 0) { -+ return SM90_64x16x16_F32F16F16_RS{}; -+ } -+ else if constexpr (Tile_N % 8 == 0) { -+ return SM90_64x8x16_F32F16F16_RS{}; -+ } -+ else { -+ static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); -+ } -+ } -+ -+ // BF16 inputs -+ else if constexpr (std::is_same_v) { -+ static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); -+ -+ if constexpr (Tile_N % 256 == 0) { -+ return SM90_64x256x16_F32BF16BF16_RS{}; -+ } -+ else if constexpr (Tile_N % 192 == 0) { -+ return SM90_64x192x16_F32BF16BF16_RS{}; -+ } -+ else if constexpr (Tile_N % 128 == 0) { -+ return SM90_64x128x16_F32BF16BF16_RS{}; -+ } -+ else if constexpr (Tile_N % 96 == 0) { -+ return SM90_64x96x16_F32BF16BF16_RS{}; -+ } -+ else if constexpr (Tile_N % 64 == 0) { -+ return SM90_64x64x16_F32BF16BF16_RS{}; -+ } -+ else if constexpr (Tile_N % 32 == 0) { -+ return SM90_64x32x16_F32BF16BF16_RS{}; -+ } -+ else if constexpr (Tile_N % 16 == 0) { -+ return SM90_64x16x16_F32BF16BF16_RS{}; -+ } -+ else if constexpr (Tile_N % 8 == 0) { -+ return SM90_64x8x16_F32BF16BF16_RS{}; -+ } -+ else { -+ static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); -+ } -+ } -+ -+ // TF32 inputs -+ else if constexpr (std::is_same_v) { -+ static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); -+ static_assert(size<2>(TileShape_MNK{}) % 8 == 0, "Tile_K must be a multiple of 8."); -+ -+ if constexpr (Tile_N % 256 == 0) { -+ return SM90_64x256x8_F32TF32TF32_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 192 == 0) { -+ return SM90_64x192x8_F32TF32TF32_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 128 == 0) { -+ return SM90_64x128x8_F32TF32TF32_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 96 == 0) { -+ return SM90_64x96x8_F32TF32TF32_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 64 == 0) { -+ return SM90_64x64x8_F32TF32TF32_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 32 == 0) { -+ return SM90_64x32x8_F32TF32TF32_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 16 == 0) { -+ return SM90_64x16x8_F32TF32TF32_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 8 == 0) { -+ return SM90_64x8x8_F32TF32TF32_RS_TN{}; -+ } -+ else { -+ static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); -+ } -+ } -+ -+ else { -+ static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); -+ } -+ } -+ -+ // S32 accumulator -+ else if constexpr (std::is_same_v) { -+ static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); -+ static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); -+ -+ // ElementA == int8_t && ElementB == int8_t -+ if constexpr (std::is_same_v && std::is_same_v) { -+ if constexpr (Tile_N % 256 == 0) { -+ return SM90_64x256x32_S32S8S8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 192 == 0) { -+ return SM90_64x192x32_S32S8S8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 128 == 0) { -+ return SM90_64x128x32_S32S8S8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 96 == 0) { -+ return SM90_64x96x32_S32S8S8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 64 == 0) { -+ return SM90_64x64x32_S32S8S8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 32 == 0) { -+ return SM90_64x32x32_S32S8S8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 16 == 0) { -+ return SM90_64x16x32_S32S8S8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 8 == 0) { -+ return SM90_64x8x32_S32S8S8_RS_TN{}; -+ } -+ else { -+ static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); -+ } -+ } -+ -+ // ElementA == int8_t && ElementB == uint8_t -+ else if constexpr (std::is_same_v && std::is_same_v) { -+ static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); -+ -+ if constexpr (Tile_N % 256 == 0) { -+ return SM90_64x256x32_S32S8U8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 192 == 0) { -+ return SM90_64x192x32_S32S8U8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 128 == 0) { -+ return SM90_64x128x32_S32S8U8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 96 == 0) { -+ return SM90_64x96x32_S32S8U8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 64 == 0) { -+ return SM90_64x64x32_S32S8U8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 32 == 0) { -+ return SM90_64x32x32_S32S8U8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 16 == 0) { -+ return SM90_64x16x32_S32S8U8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 8 == 0) { -+ return SM90_64x8x32_S32S8U8_RS_TN{}; -+ } -+ else { -+ static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); -+ } -+ } -+ -+ // ElementA == uint8_t && ElementB == int8_t -+ else if constexpr (std::is_same_v && std::is_same_v) { -+ static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); -+ -+ if constexpr (Tile_N % 256 == 0) { -+ return SM90_64x256x32_S32U8S8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 192 == 0) { -+ return SM90_64x192x32_S32U8S8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 128 == 0) { -+ return SM90_64x128x32_S32U8S8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 96 == 0) { -+ return SM90_64x96x32_S32U8S8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 64 == 0) { -+ return SM90_64x64x32_S32U8S8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 32 == 0) { -+ return SM90_64x32x32_S32U8S8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 16 == 0) { -+ return SM90_64x16x32_S32U8S8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 8 == 0) { -+ return SM90_64x8x32_S32U8S8_RS_TN{}; -+ } -+ else { -+ static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); -+ } -+ } -+ -+ // ElementA == uint8_t && ElementB == uint8_t -+ else if constexpr (std::is_same_v && std::is_same_v) { -+ static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); -+ -+ if constexpr (Tile_N % 256 == 0) { -+ return SM90_64x256x32_S32U8U8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 192 == 0) { -+ return SM90_64x192x32_S32U8U8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 128 == 0) { -+ return SM90_64x128x32_S32U8U8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 96 == 0) { -+ return SM90_64x96x32_S32U8U8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 64 == 0) { -+ return SM90_64x64x32_S32U8U8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 32 == 0) { -+ return SM90_64x32x32_S32U8U8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 16 == 0) { -+ return SM90_64x16x32_S32U8U8_RS_TN{}; -+ } -+ else if constexpr (Tile_N % 8 == 0) { -+ return SM90_64x8x32_S32U8U8_RS_TN{}; -+ } -+ else { -+ static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); -+ } -+ } -+ } -+ -+ // Unknown accumulator type -+ else { -+ static_assert(sizeof(ElementC) == 0, "Unknown ElementC accumulator type."); -+ } -+} -+} // end namespace GMMA -+} // end namespace cute -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cute/arch/mma_sm90_desc.hpp b/3rdparty/cutlass/include/cute/arch/mma_sm90_desc.hpp -new file mode 100644 -index 0000000..abac517 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/arch/mma_sm90_desc.hpp -@@ -0,0 +1,131 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include -+ -+#include -+ -+// Config -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) -+# define CUTE_ARCH_MMA_SM90_ENABLED -+#endif -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cute { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+// GMMA Descriptor and utilities -+ -+// GMMA enums and utilities -+namespace GMMA -+{ -+ -+enum class LayoutType : uint8_t { -+ INTERLEAVE = 0, -+ B128 = 1, -+ B64 = 2, -+ B32 = 3, -+}; -+ -+CUTE_HOST_DEVICE char const* to_string(LayoutType const& t) { -+ switch (t) { -+ case LayoutType::INTERLEAVE: return "INTERLEAVE"; -+ case LayoutType::B128: return "B128"; -+ case LayoutType::B64: return "B64"; -+ case LayoutType::B32: return "B32"; -+ } -+ return nullptr; -+} -+ -+// Output operator for all enums in this namespace -+CUTE_HOST std::ostream& operator<<(std::ostream& os, LayoutType const& t) { -+ char const* s = to_string(t); -+ if (s) { -+ std::operator<<(os, s); // Explicit call to avoid ambiguity -+ } else { -+ os.setstate(std::ios_base::failbit); -+ } -+ return os; -+} -+ -+} // end namespace GMMA -+ -+union GmmaDescriptor -+{ -+ uint64_t desc_; -+ uint32_t reg32_[2]; -+ uint16_t reg16_[4]; -+ -+ // Bitfield implementation avoids the need for shifts in assignment -+ struct { -+ // start_address, bit [0,14), 4LSB not included -+ uint16_t start_address_ : 14, : 2; // 14 bits [0,14), 2 bits unused -+ // leading dimension byte offset, bit [16,30), 4LSB not included -+ // For N: This is the stride from the first col to the second col of the 8x2 brick in INTERLEAVED -+ // Unused for all SWIZZLE_* layouts (and assumed to be 1) -+ // For T: This is the stride from the first 8 rows to the next 8 rows. -+ uint16_t leading_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused -+ // stride dimension byte offset, bit [32,46), 4LSB not included -+ // For N: This is the stride from the first 8 rows to the next 8 rows. -+ // For T: This is the stride fro mthe first 8 cols to the next 8 cols. -+ uint16_t stride_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused -+ // base_offset, bit [49,52) -+ // Valid only for SWIZZLE_128B and SWIZZLE_64B -+ uint8_t : 1, base_offset_ : 3, : 4; // 1 bit unused, 3 bits [1,4), 4 bits unused -+ // layout type, bit [62,64) -+ // SWIZZLE_NONE = 0, SWIZZLE_32B = 3, SWIZZLE_64B = 2, SWIZZLE_128B = 1 -+ uint8_t : 6, layout_type_ : 2; // 6 bits unused, 2 bits [6,8) -+ }; -+ -+ // Decay to a uint64_t -+ CUTE_HOST_DEVICE constexpr -+ operator uint64_t() const noexcept { return desc_; } -+ -+ // Printer -+ CUTE_HOST_DEVICE friend void print(GmmaDescriptor const& t) -+ { -+ printf("GmmaDescriptor: 0x%016lx\n", t.desc_); -+ printf(" start_addr : 0x%04x\n", t.start_address_); -+ printf(" leading_off: 0x%04x (%d)\n", t.leading_byte_offset_, t.leading_byte_offset_); -+ printf(" stride_off : 0x%04x (%d)\n", t.stride_byte_offset_, t.stride_byte_offset_); -+ printf(" base_offset: 0x%01x\n", t.base_offset_); -+ printf(" layout_type: 0x%01x (%s)\n", t.layout_type_, to_string(static_cast(t.layout_type_))); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cute -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cute/arch/mma_sm90_gmma.hpp b/3rdparty/cutlass/include/cute/arch/mma_sm90_gmma.hpp -new file mode 100644 -index 0000000..25a1d17 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/arch/mma_sm90_gmma.hpp -@@ -0,0 +1,12265 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+#include -+ -+// Config -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) -+# define CUTE_ARCH_MMA_SM90_ENABLED -+#endif -+ -+namespace cute { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+// Warpgroup sync primitives -+ -+CUTE_HOST_DEVICE -+void -+warpgroup_arrive() -+{ -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile ("wgmma.fence.sync.aligned;\n" ::: "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use wgmma.fence without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+} -+ -+template -+CUTE_HOST_DEVICE -+void -+warpgroup_wait() -+{ -+ static_assert(N >= 0 && N <= 7, "_warpgroup.wait {N}; must be in range [0, 7]"); -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use wgmma.wait_group without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+} -+ -+// Marks the commit point for one or more sized batch of warpgroup MMAs. -+CUTE_HOST_DEVICE -+void -+warpgroup_commit_batch() -+{ -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory"); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use wgmma.commit_group without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+} -+ -+CUTE_HOST_DEVICE -+void -+warpgroup_fence_operand(uint32_t& reg) { -+ asm volatile("" : "+r"(reg) :: "memory"); -+} -+ -+CUTE_HOST_DEVICE -+void -+warpgroup_fence_operand(float& reg) { -+ asm volatile("" : "+f"(reg) :: "memory"); -+} -+ -+namespace GMMA { -+ -+enum class Major { -+ K = 0, -+ MN = 1 -+}; -+ -+enum class ScaleOut { -+ Zero = 0, -+ One = 1 -+}; -+ -+enum class ScaleIn { -+ Neg = -1, -+ One = 1 -+}; -+ -+} // namespace GMMA -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+// GMMA PTX definitions: C = (scaleA * A) * (scaleB * B) + (scaleD * C) -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x8x16 F16+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x8x16_F16F16F16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 " -+ "{%0, %1}," -+ " %2," -+ " %3," -+ " %4, %5, %6, %7, %8;\n" -+ : "+r"(d0), "+r"(d1) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x8x16 F16+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x8x16_F16F16F16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[2]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 " -+ "{%0, %1}," -+ "{%2, %3, %4, %5}," -+ " %6," -+ " %7, %8, %9, %10;\n" -+ : "+r"(d0), "+r"(d1) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x16x16 F16+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x16x16_F16F16F16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k16.f16.f16.f16 " -+ "{%0, %1, %2, %3}," -+ " %4," -+ " %5," -+ " %6, %7, %8, %9, %10;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x16x16 F16+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x16x16_F16F16F16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k16.f16.f16.f16 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ " %8," -+ " %9, %10, %11, %12;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x32x16 F16+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x32x16_F16F16F16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k16.f16.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ " %8," -+ " %9," -+ " %10, %11, %12, %13, %14;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), -+ "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x32x16 F16+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x32x16_F16F16F16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[8]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k16.f16.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ "{%8, %9, %10, %11}," -+ " %12," -+ " %13, %14, %15, %16;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), -+ "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x64x16 F16+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x64x16_F16F16F16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ " %16," -+ " %17," -+ " %18, %19, %20, %21, %22;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x64x16 F16+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x64x16_F16F16F16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[16]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ "{%16, %17, %18, %19}," -+ " %20," -+ " %21, %22, %23, %24;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x96x16 F16+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x96x16_F16F16F16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[24]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k16.f16.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23}," -+ " %24," -+ " %25," -+ " %26, %27, %28, %29, %30;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x96x16 F16+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x96x16_F16F16F16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[24]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k16.f16.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23}," -+ "{%24, %25, %26, %27}," -+ " %28," -+ " %29, %30, %31, %32;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x128x16 F16+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x128x16_F16F16F16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k16.f16.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ " %32," -+ " %33," -+ " %34, %35, %36, %37, %38;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x128x16 F16+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x128x16_F16F16F16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[32]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k16.f16.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ "{%32, %33, %34, %35}," -+ " %36," -+ " %37, %38, %39, %40;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x192x16 F16+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x192x16_F16F16F16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k16.f16.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ " %48," -+ " %49," -+ " %50, %51, %52, %53, %54;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x192x16 F16+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x192x16_F16F16F16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[48]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k16.f16.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ "{%48, %49, %50, %51}," -+ " %52," -+ " %53, %54, %55, %56;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x256x16 F16+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x256x16_F16F16F16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k16.f16.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ " %64," -+ " %65," -+ " %66, %67, %68, %69, %70;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x256x16 F16+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x256x16_F16F16F16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[64]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k16.f16.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ "{%64, %65, %66, %67}," -+ " %68," -+ " %69, %70, %71, %72;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x8x16 F32+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x8x16_F32F16F16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d0, float & d1, float & d2, float & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16 " -+ "{%0, %1, %2, %3}," -+ " %4," -+ " %5," -+ " %6, %7, %8, %9, %10;\n" -+ : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x8x16 F32+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x8x16_F32F16F16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[4]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ float & d0, float & d1, float & d2, float & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ " %8," -+ " %9, %10, %11, %12;\n" -+ : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x16x16 F32+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x16x16_F32F16F16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d0, float & d1, float & d2, float & d3, -+ float & d4, float & d5, float & d6, float & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ " %8," -+ " %9," -+ " %10, %11, %12, %13, %14;\n" -+ : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), -+ "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x16x16 F32+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x16x16_F32F16F16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[8]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ float & d0, float & d1, float & d2, float & d3, -+ float & d4, float & d5, float & d6, float & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ "{%8, %9, %10, %11}," -+ " %12," -+ " %13, %14, %15, %16;\n" -+ : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), -+ "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x32x16 F32+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x32x16_F32F16F16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ " %16," -+ " %17," -+ " %18, %19, %20, %21, %22;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x32x16 F32+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x32x16_F32F16F16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[16]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ "{%16, %17, %18, %19}," -+ " %20," -+ " %21, %22, %23, %24;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x64x16 F32+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x64x16_F32F16F16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ " %32," -+ " %33," -+ " %34, %35, %36, %37, %38;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x64x16 F32+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x64x16_F32F16F16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[32]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ "{%32, %33, %34, %35}," -+ " %36," -+ " %37, %38, %39, %40;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x96x16 F32+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x96x16_F32F16F16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31, -+ float & d32, float & d33, float & d34, float & d35, -+ float & d36, float & d37, float & d38, float & d39, -+ float & d40, float & d41, float & d42, float & d43, -+ float & d44, float & d45, float & d46, float & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k16.f32.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ " %48," -+ " %49," -+ " %50, %51, %52, %53, %54;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), -+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), -+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), -+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), -+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x96x16 F32+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x96x16_F32F16F16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[48]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31, -+ float & d32, float & d33, float & d34, float & d35, -+ float & d36, float & d37, float & d38, float & d39, -+ float & d40, float & d41, float & d42, float & d43, -+ float & d44, float & d45, float & d46, float & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k16.f32.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ "{%48, %49, %50, %51}," -+ " %52," -+ " %53, %54, %55, %56;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), -+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), -+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), -+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), -+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x128x16 F32+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x128x16_F32F16F16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31, -+ float & d32, float & d33, float & d34, float & d35, -+ float & d36, float & d37, float & d38, float & d39, -+ float & d40, float & d41, float & d42, float & d43, -+ float & d44, float & d45, float & d46, float & d47, -+ float & d48, float & d49, float & d50, float & d51, -+ float & d52, float & d53, float & d54, float & d55, -+ float & d56, float & d57, float & d58, float & d59, -+ float & d60, float & d61, float & d62, float & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ " %64," -+ " %65," -+ " %66, %67, %68, %69, %70;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), -+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), -+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), -+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), -+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), -+ "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), -+ "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), -+ "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), -+ "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x128x16 F32+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x128x16_F32F16F16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[64]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31, -+ float & d32, float & d33, float & d34, float & d35, -+ float & d36, float & d37, float & d38, float & d39, -+ float & d40, float & d41, float & d42, float & d43, -+ float & d44, float & d45, float & d46, float & d47, -+ float & d48, float & d49, float & d50, float & d51, -+ float & d52, float & d53, float & d54, float & d55, -+ float & d56, float & d57, float & d58, float & d59, -+ float & d60, float & d61, float & d62, float & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ "{%64, %65, %66, %67}," -+ " %68," -+ " %69, %70, %71, %72;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), -+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), -+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), -+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), -+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), -+ "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), -+ "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), -+ "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), -+ "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x192x16 F32+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x192x16_F32F16F16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31, -+ float & d32, float & d33, float & d34, float & d35, -+ float & d36, float & d37, float & d38, float & d39, -+ float & d40, float & d41, float & d42, float & d43, -+ float & d44, float & d45, float & d46, float & d47, -+ float & d48, float & d49, float & d50, float & d51, -+ float & d52, float & d53, float & d54, float & d55, -+ float & d56, float & d57, float & d58, float & d59, -+ float & d60, float & d61, float & d62, float & d63, -+ float & d64, float & d65, float & d66, float & d67, -+ float & d68, float & d69, float & d70, float & d71, -+ float & d72, float & d73, float & d74, float & d75, -+ float & d76, float & d77, float & d78, float & d79, -+ float & d80, float & d81, float & d82, float & d83, -+ float & d84, float & d85, float & d86, float & d87, -+ float & d88, float & d89, float & d90, float & d91, -+ float & d92, float & d93, float & d94, float & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k16.f32.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ " %96," -+ " %97," -+ " %98, %99, %100, %101, %102;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), -+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), -+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), -+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), -+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), -+ "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), -+ "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), -+ "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), -+ "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), -+ "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), -+ "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), -+ "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), -+ "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), -+ "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), -+ "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), -+ "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), -+ "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x192x16 F32+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x192x16_F32F16F16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[96]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31, -+ float & d32, float & d33, float & d34, float & d35, -+ float & d36, float & d37, float & d38, float & d39, -+ float & d40, float & d41, float & d42, float & d43, -+ float & d44, float & d45, float & d46, float & d47, -+ float & d48, float & d49, float & d50, float & d51, -+ float & d52, float & d53, float & d54, float & d55, -+ float & d56, float & d57, float & d58, float & d59, -+ float & d60, float & d61, float & d62, float & d63, -+ float & d64, float & d65, float & d66, float & d67, -+ float & d68, float & d69, float & d70, float & d71, -+ float & d72, float & d73, float & d74, float & d75, -+ float & d76, float & d77, float & d78, float & d79, -+ float & d80, float & d81, float & d82, float & d83, -+ float & d84, float & d85, float & d86, float & d87, -+ float & d88, float & d89, float & d90, float & d91, -+ float & d92, float & d93, float & d94, float & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k16.f32.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ "{%96, %97, %98, %99}," -+ " %100," -+ " %101, %102, %103, %104;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), -+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), -+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), -+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), -+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), -+ "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), -+ "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), -+ "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), -+ "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), -+ "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), -+ "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), -+ "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), -+ "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), -+ "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), -+ "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), -+ "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), -+ "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x256x16 F32+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x256x16_F32F16F16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d000, float & d001, float & d002, float & d003, -+ float & d004, float & d005, float & d006, float & d007, -+ float & d008, float & d009, float & d010, float & d011, -+ float & d012, float & d013, float & d014, float & d015, -+ float & d016, float & d017, float & d018, float & d019, -+ float & d020, float & d021, float & d022, float & d023, -+ float & d024, float & d025, float & d026, float & d027, -+ float & d028, float & d029, float & d030, float & d031, -+ float & d032, float & d033, float & d034, float & d035, -+ float & d036, float & d037, float & d038, float & d039, -+ float & d040, float & d041, float & d042, float & d043, -+ float & d044, float & d045, float & d046, float & d047, -+ float & d048, float & d049, float & d050, float & d051, -+ float & d052, float & d053, float & d054, float & d055, -+ float & d056, float & d057, float & d058, float & d059, -+ float & d060, float & d061, float & d062, float & d063, -+ float & d064, float & d065, float & d066, float & d067, -+ float & d068, float & d069, float & d070, float & d071, -+ float & d072, float & d073, float & d074, float & d075, -+ float & d076, float & d077, float & d078, float & d079, -+ float & d080, float & d081, float & d082, float & d083, -+ float & d084, float & d085, float & d086, float & d087, -+ float & d088, float & d089, float & d090, float & d091, -+ float & d092, float & d093, float & d094, float & d095, -+ float & d096, float & d097, float & d098, float & d099, -+ float & d100, float & d101, float & d102, float & d103, -+ float & d104, float & d105, float & d106, float & d107, -+ float & d108, float & d109, float & d110, float & d111, -+ float & d112, float & d113, float & d114, float & d115, -+ float & d116, float & d117, float & d118, float & d119, -+ float & d120, float & d121, float & d122, float & d123, -+ float & d124, float & d125, float & d126, float & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ " %128," -+ " %129," -+ " %130, %131, %132, %133, %134;\n" -+ : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), -+ "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), -+ "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), -+ "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), -+ "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), -+ "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), -+ "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), -+ "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), -+ "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), -+ "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), -+ "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), -+ "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), -+ "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), -+ "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), -+ "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), -+ "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), -+ "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), -+ "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), -+ "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), -+ "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), -+ "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), -+ "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), -+ "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), -+ "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), -+ "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), -+ "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), -+ "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), -+ "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), -+ "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), -+ "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), -+ "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), -+ "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x256x16 F32+=F16*F16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x256x16_F32F16F16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[128]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, -+ uint64_t const& desc_b, -+ float & d000, float & d001, float & d002, float & d003, -+ float & d004, float & d005, float & d006, float & d007, -+ float & d008, float & d009, float & d010, float & d011, -+ float & d012, float & d013, float & d014, float & d015, -+ float & d016, float & d017, float & d018, float & d019, -+ float & d020, float & d021, float & d022, float & d023, -+ float & d024, float & d025, float & d026, float & d027, -+ float & d028, float & d029, float & d030, float & d031, -+ float & d032, float & d033, float & d034, float & d035, -+ float & d036, float & d037, float & d038, float & d039, -+ float & d040, float & d041, float & d042, float & d043, -+ float & d044, float & d045, float & d046, float & d047, -+ float & d048, float & d049, float & d050, float & d051, -+ float & d052, float & d053, float & d054, float & d055, -+ float & d056, float & d057, float & d058, float & d059, -+ float & d060, float & d061, float & d062, float & d063, -+ float & d064, float & d065, float & d066, float & d067, -+ float & d068, float & d069, float & d070, float & d071, -+ float & d072, float & d073, float & d074, float & d075, -+ float & d076, float & d077, float & d078, float & d079, -+ float & d080, float & d081, float & d082, float & d083, -+ float & d084, float & d085, float & d086, float & d087, -+ float & d088, float & d089, float & d090, float & d091, -+ float & d092, float & d093, float & d094, float & d095, -+ float & d096, float & d097, float & d098, float & d099, -+ float & d100, float & d101, float & d102, float & d103, -+ float & d104, float & d105, float & d106, float & d107, -+ float & d108, float & d109, float & d110, float & d111, -+ float & d112, float & d113, float & d114, float & d115, -+ float & d116, float & d117, float & d118, float & d119, -+ float & d120, float & d121, float & d122, float & d123, -+ float & d124, float & d125, float & d126, float & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ "{%128, %129, %130, %131}," -+ " %132," -+ " %133, %134, %135, %136;\n" -+ : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), -+ "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), -+ "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), -+ "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), -+ "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), -+ "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), -+ "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), -+ "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), -+ "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), -+ "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), -+ "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), -+ "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), -+ "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), -+ "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), -+ "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), -+ "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), -+ "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), -+ "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), -+ "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), -+ "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), -+ "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), -+ "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), -+ "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), -+ "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), -+ "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), -+ "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), -+ "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), -+ "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), -+ "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), -+ "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), -+ "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), -+ "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) -+ : "r"(a000), "r"(a001), "r"(a002), "r"(a003), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x8x16 F32+=BF16*BF16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x8x16_F32BF16BF16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d0, float & d1, float & d2, float & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16 " -+ "{%0, %1, %2, %3}," -+ " %4," -+ " %5," -+ " %6, %7, %8, %9, %10;\n" -+ : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x8x16 F32+=BF16*BF16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x8x16_F32BF16BF16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[4]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ float & d0, float & d1, float & d2, float & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ " %8," -+ " %9, %10, %11, %12;\n" -+ : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x16x16 F32+=BF16*BF16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x16x16_F32BF16BF16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d0, float & d1, float & d2, float & d3, -+ float & d4, float & d5, float & d6, float & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ " %8," -+ " %9," -+ " %10, %11, %12, %13, %14;\n" -+ : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), -+ "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x16x16 F32+=BF16*BF16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x16x16_F32BF16BF16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[8]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ float & d0, float & d1, float & d2, float & d3, -+ float & d4, float & d5, float & d6, float & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ "{%8, %9, %10, %11}," -+ " %12," -+ " %13, %14, %15, %16;\n" -+ : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), -+ "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x32x16 F32+=BF16*BF16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x32x16_F32BF16BF16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ " %16," -+ " %17," -+ " %18, %19, %20, %21, %22;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x32x16 F32+=BF16*BF16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x32x16_F32BF16BF16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[16]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ "{%16, %17, %18, %19}," -+ " %20," -+ " %21, %22, %23, %24;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x64x16 F32+=BF16*BF16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x64x16_F32BF16BF16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ " %32," -+ " %33," -+ " %34, %35, %36, %37, %38;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x64x16 F32+=BF16*BF16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x64x16_F32BF16BF16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[32]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ "{%32, %33, %34, %35}," -+ " %36," -+ " %37, %38, %39, %40;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x96x16 F32+=BF16*BF16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x96x16_F32BF16BF16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31, -+ float & d32, float & d33, float & d34, float & d35, -+ float & d36, float & d37, float & d38, float & d39, -+ float & d40, float & d41, float & d42, float & d43, -+ float & d44, float & d45, float & d46, float & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k16.f32.bf16.bf16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ " %48," -+ " %49," -+ " %50, %51, %52, %53, %54;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), -+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), -+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), -+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), -+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x96x16 F32+=BF16*BF16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x96x16_F32BF16BF16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[48]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31, -+ float & d32, float & d33, float & d34, float & d35, -+ float & d36, float & d37, float & d38, float & d39, -+ float & d40, float & d41, float & d42, float & d43, -+ float & d44, float & d45, float & d46, float & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k16.f32.bf16.bf16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ "{%48, %49, %50, %51}," -+ " %52," -+ " %53, %54, %55, %56;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), -+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), -+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), -+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), -+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x128x16 F32+=BF16*BF16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x128x16_F32BF16BF16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31, -+ float & d32, float & d33, float & d34, float & d35, -+ float & d36, float & d37, float & d38, float & d39, -+ float & d40, float & d41, float & d42, float & d43, -+ float & d44, float & d45, float & d46, float & d47, -+ float & d48, float & d49, float & d50, float & d51, -+ float & d52, float & d53, float & d54, float & d55, -+ float & d56, float & d57, float & d58, float & d59, -+ float & d60, float & d61, float & d62, float & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ " %64," -+ " %65," -+ " %66, %67, %68, %69, %70;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), -+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), -+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), -+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), -+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), -+ "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), -+ "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), -+ "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), -+ "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x128x16 F32+=BF16*BF16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x128x16_F32BF16BF16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[64]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31, -+ float & d32, float & d33, float & d34, float & d35, -+ float & d36, float & d37, float & d38, float & d39, -+ float & d40, float & d41, float & d42, float & d43, -+ float & d44, float & d45, float & d46, float & d47, -+ float & d48, float & d49, float & d50, float & d51, -+ float & d52, float & d53, float & d54, float & d55, -+ float & d56, float & d57, float & d58, float & d59, -+ float & d60, float & d61, float & d62, float & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ "{%64, %65, %66, %67}," -+ " %68," -+ " %69, %70, %71, %72;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), -+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), -+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), -+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), -+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), -+ "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), -+ "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), -+ "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), -+ "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x192x16 F32+=BF16*BF16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x192x16_F32BF16BF16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31, -+ float & d32, float & d33, float & d34, float & d35, -+ float & d36, float & d37, float & d38, float & d39, -+ float & d40, float & d41, float & d42, float & d43, -+ float & d44, float & d45, float & d46, float & d47, -+ float & d48, float & d49, float & d50, float & d51, -+ float & d52, float & d53, float & d54, float & d55, -+ float & d56, float & d57, float & d58, float & d59, -+ float & d60, float & d61, float & d62, float & d63, -+ float & d64, float & d65, float & d66, float & d67, -+ float & d68, float & d69, float & d70, float & d71, -+ float & d72, float & d73, float & d74, float & d75, -+ float & d76, float & d77, float & d78, float & d79, -+ float & d80, float & d81, float & d82, float & d83, -+ float & d84, float & d85, float & d86, float & d87, -+ float & d88, float & d89, float & d90, float & d91, -+ float & d92, float & d93, float & d94, float & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k16.f32.bf16.bf16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ " %96," -+ " %97," -+ " %98, %99, %100, %101, %102;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), -+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), -+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), -+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), -+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), -+ "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), -+ "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), -+ "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), -+ "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), -+ "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), -+ "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), -+ "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), -+ "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), -+ "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), -+ "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), -+ "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), -+ "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x192x16 F32+=BF16*BF16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x192x16_F32BF16BF16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[96]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31, -+ float & d32, float & d33, float & d34, float & d35, -+ float & d36, float & d37, float & d38, float & d39, -+ float & d40, float & d41, float & d42, float & d43, -+ float & d44, float & d45, float & d46, float & d47, -+ float & d48, float & d49, float & d50, float & d51, -+ float & d52, float & d53, float & d54, float & d55, -+ float & d56, float & d57, float & d58, float & d59, -+ float & d60, float & d61, float & d62, float & d63, -+ float & d64, float & d65, float & d66, float & d67, -+ float & d68, float & d69, float & d70, float & d71, -+ float & d72, float & d73, float & d74, float & d75, -+ float & d76, float & d77, float & d78, float & d79, -+ float & d80, float & d81, float & d82, float & d83, -+ float & d84, float & d85, float & d86, float & d87, -+ float & d88, float & d89, float & d90, float & d91, -+ float & d92, float & d93, float & d94, float & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k16.f32.bf16.bf16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ "{%96, %97, %98, %99}," -+ " %100," -+ " %101, %102, %103, %104;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), -+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), -+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), -+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), -+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), -+ "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), -+ "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), -+ "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), -+ "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), -+ "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), -+ "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), -+ "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), -+ "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), -+ "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), -+ "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), -+ "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), -+ "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x256x16 F32+=BF16*BF16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x256x16_F32BF16BF16_SS -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d000, float & d001, float & d002, float & d003, -+ float & d004, float & d005, float & d006, float & d007, -+ float & d008, float & d009, float & d010, float & d011, -+ float & d012, float & d013, float & d014, float & d015, -+ float & d016, float & d017, float & d018, float & d019, -+ float & d020, float & d021, float & d022, float & d023, -+ float & d024, float & d025, float & d026, float & d027, -+ float & d028, float & d029, float & d030, float & d031, -+ float & d032, float & d033, float & d034, float & d035, -+ float & d036, float & d037, float & d038, float & d039, -+ float & d040, float & d041, float & d042, float & d043, -+ float & d044, float & d045, float & d046, float & d047, -+ float & d048, float & d049, float & d050, float & d051, -+ float & d052, float & d053, float & d054, float & d055, -+ float & d056, float & d057, float & d058, float & d059, -+ float & d060, float & d061, float & d062, float & d063, -+ float & d064, float & d065, float & d066, float & d067, -+ float & d068, float & d069, float & d070, float & d071, -+ float & d072, float & d073, float & d074, float & d075, -+ float & d076, float & d077, float & d078, float & d079, -+ float & d080, float & d081, float & d082, float & d083, -+ float & d084, float & d085, float & d086, float & d087, -+ float & d088, float & d089, float & d090, float & d091, -+ float & d092, float & d093, float & d094, float & d095, -+ float & d096, float & d097, float & d098, float & d099, -+ float & d100, float & d101, float & d102, float & d103, -+ float & d104, float & d105, float & d106, float & d107, -+ float & d108, float & d109, float & d110, float & d111, -+ float & d112, float & d113, float & d114, float & d115, -+ float & d116, float & d117, float & d118, float & d119, -+ float & d120, float & d121, float & d122, float & d123, -+ float & d124, float & d125, float & d126, float & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ " %128," -+ " %129," -+ " %130, %131, %132, %133, %134;\n" -+ : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), -+ "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), -+ "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), -+ "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), -+ "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), -+ "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), -+ "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), -+ "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), -+ "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), -+ "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), -+ "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), -+ "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), -+ "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), -+ "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), -+ "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), -+ "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), -+ "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), -+ "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), -+ "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), -+ "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), -+ "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), -+ "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), -+ "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), -+ "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), -+ "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), -+ "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), -+ "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), -+ "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), -+ "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), -+ "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), -+ "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), -+ "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x256x16 F32+=BF16*BF16 -+template< -+ GMMA::Major tnspA, -+ GMMA::Major tnspB, -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x256x16_F32BF16BF16_RS -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[128]; -+ -+ static_assert(tnspA == GMMA::Major::K, -+ "Register source operand A must have K major layout."); -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, -+ uint64_t const& desc_b, -+ float & d000, float & d001, float & d002, float & d003, -+ float & d004, float & d005, float & d006, float & d007, -+ float & d008, float & d009, float & d010, float & d011, -+ float & d012, float & d013, float & d014, float & d015, -+ float & d016, float & d017, float & d018, float & d019, -+ float & d020, float & d021, float & d022, float & d023, -+ float & d024, float & d025, float & d026, float & d027, -+ float & d028, float & d029, float & d030, float & d031, -+ float & d032, float & d033, float & d034, float & d035, -+ float & d036, float & d037, float & d038, float & d039, -+ float & d040, float & d041, float & d042, float & d043, -+ float & d044, float & d045, float & d046, float & d047, -+ float & d048, float & d049, float & d050, float & d051, -+ float & d052, float & d053, float & d054, float & d055, -+ float & d056, float & d057, float & d058, float & d059, -+ float & d060, float & d061, float & d062, float & d063, -+ float & d064, float & d065, float & d066, float & d067, -+ float & d068, float & d069, float & d070, float & d071, -+ float & d072, float & d073, float & d074, float & d075, -+ float & d076, float & d077, float & d078, float & d079, -+ float & d080, float & d081, float & d082, float & d083, -+ float & d084, float & d085, float & d086, float & d087, -+ float & d088, float & d089, float & d090, float & d091, -+ float & d092, float & d093, float & d094, float & d095, -+ float & d096, float & d097, float & d098, float & d099, -+ float & d100, float & d101, float & d102, float & d103, -+ float & d104, float & d105, float & d106, float & d107, -+ float & d108, float & d109, float & d110, float & d111, -+ float & d112, float & d113, float & d114, float & d115, -+ float & d116, float & d117, float & d118, float & d119, -+ float & d120, float & d121, float & d122, float & d123, -+ float & d124, float & d125, float & d126, float & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ "{%128, %129, %130, %131}," -+ " %132," -+ " %133, %134, %135, %136;\n" -+ : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), -+ "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), -+ "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), -+ "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), -+ "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), -+ "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), -+ "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), -+ "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), -+ "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), -+ "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), -+ "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), -+ "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), -+ "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), -+ "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), -+ "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), -+ "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), -+ "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), -+ "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), -+ "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), -+ "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), -+ "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), -+ "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), -+ "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), -+ "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), -+ "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), -+ "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), -+ "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), -+ "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), -+ "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), -+ "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), -+ "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), -+ "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) -+ : "r"(a000), "r"(a001), "r"(a002), "r"(a003), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x8x8 TN F32+=TF32*TF32 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x8x8_F32TF32TF32_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d0, float & d1, float & d2, float & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32 " -+ "{%0, %1, %2, %3}," -+ " %4," -+ " %5," -+ " %6, %7, %8;\n" -+ : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x8x8 TN F32+=TF32*TF32 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x8x8_F32TF32TF32_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ float & d0, float & d1, float & d2, float & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ " %8," -+ " %9, %10, %11;\n" -+ : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x16x8 TN F32+=TF32*TF32 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x16x8_F32TF32TF32_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d0, float & d1, float & d2, float & d3, -+ float & d4, float & d5, float & d6, float & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k8.f32.tf32.tf32 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ " %8," -+ " %9," -+ " %10, %11, %12;\n" -+ : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), -+ "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x16x8 TN F32+=TF32*TF32 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x16x8_F32TF32TF32_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ float & d0, float & d1, float & d2, float & d3, -+ float & d4, float & d5, float & d6, float & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k8.f32.tf32.tf32 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ "{%8, %9, %10, %11}," -+ " %12," -+ " %13, %14, %15;\n" -+ : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), -+ "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x32x8 TN F32+=TF32*TF32 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x32x8_F32TF32TF32_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k8.f32.tf32.tf32 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ " %16," -+ " %17," -+ " %18, %19, %20;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x32x8 TN F32+=TF32*TF32 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x32x8_F32TF32TF32_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k8.f32.tf32.tf32 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ "{%16, %17, %18, %19}," -+ " %20," -+ " %21, %22, %23;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x64x8 TN F32+=TF32*TF32 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x64x8_F32TF32TF32_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ " %32," -+ " %33," -+ " %34, %35, %36;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x64x8 TN F32+=TF32*TF32 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x64x8_F32TF32TF32_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ "{%32, %33, %34, %35}," -+ " %36," -+ " %37, %38, %39;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x96x8 TN F32+=TF32*TF32 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x96x8_F32TF32TF32_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31, -+ float & d32, float & d33, float & d34, float & d35, -+ float & d36, float & d37, float & d38, float & d39, -+ float & d40, float & d41, float & d42, float & d43, -+ float & d44, float & d45, float & d46, float & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k8.f32.tf32.tf32 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ " %48," -+ " %49," -+ " %50, %51, %52;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), -+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), -+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), -+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), -+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x96x8 TN F32+=TF32*TF32 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x96x8_F32TF32TF32_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31, -+ float & d32, float & d33, float & d34, float & d35, -+ float & d36, float & d37, float & d38, float & d39, -+ float & d40, float & d41, float & d42, float & d43, -+ float & d44, float & d45, float & d46, float & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k8.f32.tf32.tf32 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ "{%48, %49, %50, %51}," -+ " %52," -+ " %53, %54, %55;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), -+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), -+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), -+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), -+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x128x8 TN F32+=TF32*TF32 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x128x8_F32TF32TF32_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31, -+ float & d32, float & d33, float & d34, float & d35, -+ float & d36, float & d37, float & d38, float & d39, -+ float & d40, float & d41, float & d42, float & d43, -+ float & d44, float & d45, float & d46, float & d47, -+ float & d48, float & d49, float & d50, float & d51, -+ float & d52, float & d53, float & d54, float & d55, -+ float & d56, float & d57, float & d58, float & d59, -+ float & d60, float & d61, float & d62, float & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k8.f32.tf32.tf32 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ " %64," -+ " %65," -+ " %66, %67, %68;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), -+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), -+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), -+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), -+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), -+ "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), -+ "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), -+ "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), -+ "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x128x8 TN F32+=TF32*TF32 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x128x8_F32TF32TF32_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31, -+ float & d32, float & d33, float & d34, float & d35, -+ float & d36, float & d37, float & d38, float & d39, -+ float & d40, float & d41, float & d42, float & d43, -+ float & d44, float & d45, float & d46, float & d47, -+ float & d48, float & d49, float & d50, float & d51, -+ float & d52, float & d53, float & d54, float & d55, -+ float & d56, float & d57, float & d58, float & d59, -+ float & d60, float & d61, float & d62, float & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k8.f32.tf32.tf32 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ "{%64, %65, %66, %67}," -+ " %68," -+ " %69, %70, %71;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), -+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), -+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), -+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), -+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), -+ "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), -+ "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), -+ "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), -+ "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x192x8 TN F32+=TF32*TF32 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x192x8_F32TF32TF32_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31, -+ float & d32, float & d33, float & d34, float & d35, -+ float & d36, float & d37, float & d38, float & d39, -+ float & d40, float & d41, float & d42, float & d43, -+ float & d44, float & d45, float & d46, float & d47, -+ float & d48, float & d49, float & d50, float & d51, -+ float & d52, float & d53, float & d54, float & d55, -+ float & d56, float & d57, float & d58, float & d59, -+ float & d60, float & d61, float & d62, float & d63, -+ float & d64, float & d65, float & d66, float & d67, -+ float & d68, float & d69, float & d70, float & d71, -+ float & d72, float & d73, float & d74, float & d75, -+ float & d76, float & d77, float & d78, float & d79, -+ float & d80, float & d81, float & d82, float & d83, -+ float & d84, float & d85, float & d86, float & d87, -+ float & d88, float & d89, float & d90, float & d91, -+ float & d92, float & d93, float & d94, float & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k8.f32.tf32.tf32 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ " %96," -+ " %97," -+ " %98, %99, %100;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), -+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), -+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), -+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), -+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), -+ "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), -+ "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), -+ "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), -+ "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), -+ "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), -+ "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), -+ "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), -+ "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), -+ "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), -+ "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), -+ "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), -+ "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x192x8 TN F32+=TF32*TF32 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x192x8_F32TF32TF32_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ float & d00, float & d01, float & d02, float & d03, -+ float & d04, float & d05, float & d06, float & d07, -+ float & d08, float & d09, float & d10, float & d11, -+ float & d12, float & d13, float & d14, float & d15, -+ float & d16, float & d17, float & d18, float & d19, -+ float & d20, float & d21, float & d22, float & d23, -+ float & d24, float & d25, float & d26, float & d27, -+ float & d28, float & d29, float & d30, float & d31, -+ float & d32, float & d33, float & d34, float & d35, -+ float & d36, float & d37, float & d38, float & d39, -+ float & d40, float & d41, float & d42, float & d43, -+ float & d44, float & d45, float & d46, float & d47, -+ float & d48, float & d49, float & d50, float & d51, -+ float & d52, float & d53, float & d54, float & d55, -+ float & d56, float & d57, float & d58, float & d59, -+ float & d60, float & d61, float & d62, float & d63, -+ float & d64, float & d65, float & d66, float & d67, -+ float & d68, float & d69, float & d70, float & d71, -+ float & d72, float & d73, float & d74, float & d75, -+ float & d76, float & d77, float & d78, float & d79, -+ float & d80, float & d81, float & d82, float & d83, -+ float & d84, float & d85, float & d86, float & d87, -+ float & d88, float & d89, float & d90, float & d91, -+ float & d92, float & d93, float & d94, float & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k8.f32.tf32.tf32 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ "{%96, %97, %98, %99}," -+ " %100," -+ " %101, %102, %103;\n" -+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), -+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), -+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), -+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), -+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), -+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), -+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), -+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), -+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), -+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), -+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), -+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), -+ "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), -+ "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), -+ "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), -+ "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), -+ "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), -+ "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), -+ "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), -+ "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), -+ "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), -+ "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), -+ "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), -+ "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x256x8 TN F32+=TF32*TF32 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x256x8_F32TF32TF32_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ float & d000, float & d001, float & d002, float & d003, -+ float & d004, float & d005, float & d006, float & d007, -+ float & d008, float & d009, float & d010, float & d011, -+ float & d012, float & d013, float & d014, float & d015, -+ float & d016, float & d017, float & d018, float & d019, -+ float & d020, float & d021, float & d022, float & d023, -+ float & d024, float & d025, float & d026, float & d027, -+ float & d028, float & d029, float & d030, float & d031, -+ float & d032, float & d033, float & d034, float & d035, -+ float & d036, float & d037, float & d038, float & d039, -+ float & d040, float & d041, float & d042, float & d043, -+ float & d044, float & d045, float & d046, float & d047, -+ float & d048, float & d049, float & d050, float & d051, -+ float & d052, float & d053, float & d054, float & d055, -+ float & d056, float & d057, float & d058, float & d059, -+ float & d060, float & d061, float & d062, float & d063, -+ float & d064, float & d065, float & d066, float & d067, -+ float & d068, float & d069, float & d070, float & d071, -+ float & d072, float & d073, float & d074, float & d075, -+ float & d076, float & d077, float & d078, float & d079, -+ float & d080, float & d081, float & d082, float & d083, -+ float & d084, float & d085, float & d086, float & d087, -+ float & d088, float & d089, float & d090, float & d091, -+ float & d092, float & d093, float & d094, float & d095, -+ float & d096, float & d097, float & d098, float & d099, -+ float & d100, float & d101, float & d102, float & d103, -+ float & d104, float & d105, float & d106, float & d107, -+ float & d108, float & d109, float & d110, float & d111, -+ float & d112, float & d113, float & d114, float & d115, -+ float & d116, float & d117, float & d118, float & d119, -+ float & d120, float & d121, float & d122, float & d123, -+ float & d124, float & d125, float & d126, float & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k8.f32.tf32.tf32 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ " %128," -+ " %129," -+ " %130, %131, %132;\n" -+ : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), -+ "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), -+ "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), -+ "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), -+ "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), -+ "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), -+ "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), -+ "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), -+ "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), -+ "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), -+ "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), -+ "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), -+ "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), -+ "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), -+ "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), -+ "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), -+ "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), -+ "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), -+ "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), -+ "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), -+ "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), -+ "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), -+ "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), -+ "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), -+ "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), -+ "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), -+ "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), -+ "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), -+ "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), -+ "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), -+ "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), -+ "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// GMMA 64x256x8 TN F32+=TF32*TF32 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, -+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, -+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -+> -+struct SM90_64x256x8_F32TF32TF32_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = float[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, -+ uint64_t const& desc_b, -+ float & d000, float & d001, float & d002, float & d003, -+ float & d004, float & d005, float & d006, float & d007, -+ float & d008, float & d009, float & d010, float & d011, -+ float & d012, float & d013, float & d014, float & d015, -+ float & d016, float & d017, float & d018, float & d019, -+ float & d020, float & d021, float & d022, float & d023, -+ float & d024, float & d025, float & d026, float & d027, -+ float & d028, float & d029, float & d030, float & d031, -+ float & d032, float & d033, float & d034, float & d035, -+ float & d036, float & d037, float & d038, float & d039, -+ float & d040, float & d041, float & d042, float & d043, -+ float & d044, float & d045, float & d046, float & d047, -+ float & d048, float & d049, float & d050, float & d051, -+ float & d052, float & d053, float & d054, float & d055, -+ float & d056, float & d057, float & d058, float & d059, -+ float & d060, float & d061, float & d062, float & d063, -+ float & d064, float & d065, float & d066, float & d067, -+ float & d068, float & d069, float & d070, float & d071, -+ float & d072, float & d073, float & d074, float & d075, -+ float & d076, float & d077, float & d078, float & d079, -+ float & d080, float & d081, float & d082, float & d083, -+ float & d084, float & d085, float & d086, float & d087, -+ float & d088, float & d089, float & d090, float & d091, -+ float & d092, float & d093, float & d094, float & d095, -+ float & d096, float & d097, float & d098, float & d099, -+ float & d100, float & d101, float & d102, float & d103, -+ float & d104, float & d105, float & d106, float & d107, -+ float & d108, float & d109, float & d110, float & d111, -+ float & d112, float & d113, float & d114, float & d115, -+ float & d116, float & d117, float & d118, float & d119, -+ float & d120, float & d121, float & d122, float & d123, -+ float & d124, float & d125, float & d126, float & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k8.f32.tf32.tf32 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ "{%128, %129, %130, %131}," -+ " %132," -+ " %133, %134, %135;\n" -+ : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), -+ "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), -+ "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), -+ "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), -+ "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), -+ "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), -+ "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), -+ "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), -+ "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), -+ "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), -+ "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), -+ "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), -+ "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), -+ "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), -+ "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), -+ "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), -+ "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), -+ "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), -+ "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), -+ "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), -+ "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), -+ "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), -+ "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), -+ "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), -+ "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), -+ "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), -+ "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), -+ "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), -+ "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), -+ "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), -+ "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), -+ "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) -+ : "r"(a000), "r"(a001), "r"(a002), "r"(a003), -+ "l"(desc_b), -+ "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x8x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x8x32_S32S8S8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8 " -+ "{%0, %1, %2, %3}," -+ " %4," -+ " %5," -+ " %6;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x8x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x8x32_S32S8S8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite " -+ "{%0, %1, %2, %3}," -+ " %4," -+ " %5," -+ " %6;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x16x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x16x32_S32S8S8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ " %8," -+ " %9," -+ " %10;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), -+ "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x16x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x16x32_S32S8S8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ " %8," -+ " %9," -+ " %10;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), -+ "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x32x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x32x32_S32S8S8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ " %16," -+ " %17," -+ " %18;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x32x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x32x32_S32S8S8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ " %16," -+ " %17," -+ " %18;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x64x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x64x32_S32S8S8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ " %32," -+ " %33," -+ " %34;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x64x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x64x32_S32S8S8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ " %32," -+ " %33," -+ " %34;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x96x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x96x32_S32S8S8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ " %48," -+ " %49," -+ " %50;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x96x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x96x32_S32S8S8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ " %48," -+ " %49," -+ " %50;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x128x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x128x32_S32S8S8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ " %64," -+ " %65," -+ " %66;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x128x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x128x32_S32S8S8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ " %64," -+ " %65," -+ " %66;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x192x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x192x32_S32S8S8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, -+ uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, -+ uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, -+ uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, -+ uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, -+ uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, -+ uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, -+ uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, -+ uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ " %96," -+ " %97," -+ " %98;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), -+ "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), -+ "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), -+ "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), -+ "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), -+ "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), -+ "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), -+ "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), -+ "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x192x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x192x32_S32S8S8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, -+ uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, -+ uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, -+ uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, -+ uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, -+ uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, -+ uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, -+ uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, -+ uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ " %96," -+ " %97," -+ " %98;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), -+ "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), -+ "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), -+ "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), -+ "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), -+ "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), -+ "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), -+ "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), -+ "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x256x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x256x32_S32S8S8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, -+ uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, -+ uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, -+ uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, -+ uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, -+ uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, -+ uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, -+ uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, -+ uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, -+ uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, -+ uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, -+ uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, -+ uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, -+ uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, -+ uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, -+ uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, -+ uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, -+ uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, -+ uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, -+ uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, -+ uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, -+ uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, -+ uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, -+ uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, -+ uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, -+ uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, -+ uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, -+ uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, -+ uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, -+ uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, -+ uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, -+ uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ " %128," -+ " %129," -+ " %130;\n" -+ : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), -+ "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), -+ "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), -+ "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), -+ "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), -+ "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), -+ "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), -+ "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), -+ "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), -+ "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), -+ "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), -+ "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), -+ "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), -+ "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), -+ "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), -+ "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), -+ "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), -+ "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), -+ "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), -+ "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), -+ "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), -+ "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), -+ "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), -+ "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), -+ "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), -+ "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), -+ "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), -+ "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), -+ "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), -+ "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), -+ "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), -+ "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x256x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x256x32_S32S8S8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, -+ uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, -+ uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, -+ uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, -+ uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, -+ uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, -+ uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, -+ uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, -+ uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, -+ uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, -+ uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, -+ uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, -+ uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, -+ uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, -+ uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, -+ uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, -+ uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, -+ uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, -+ uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, -+ uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, -+ uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, -+ uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, -+ uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, -+ uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, -+ uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, -+ uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, -+ uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, -+ uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, -+ uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, -+ uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, -+ uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, -+ uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ " %128," -+ " %129," -+ " %130;\n" -+ : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), -+ "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), -+ "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), -+ "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), -+ "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), -+ "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), -+ "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), -+ "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), -+ "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), -+ "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), -+ "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), -+ "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), -+ "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), -+ "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), -+ "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), -+ "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), -+ "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), -+ "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), -+ "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), -+ "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), -+ "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), -+ "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), -+ "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), -+ "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), -+ "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), -+ "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), -+ "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), -+ "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), -+ "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), -+ "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), -+ "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), -+ "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x8x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x8x32_S32S8S8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ " %8," -+ " %9;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x8x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x8x32_S32S8S8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ " %8," -+ " %9;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x16x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x16x32_S32S8S8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ "{%8, %9, %10, %11}," -+ " %12," -+ " %13;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), -+ "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x16x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x16x32_S32S8S8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ "{%8, %9, %10, %11}," -+ " %12," -+ " %13;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), -+ "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x32x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x32x32_S32S8S8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ "{%16, %17, %18, %19}," -+ " %20," -+ " %21;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x32x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x32x32_S32S8S8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ "{%16, %17, %18, %19}," -+ " %20," -+ " %21;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x64x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x64x32_S32S8S8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ "{%32, %33, %34, %35}," -+ " %36," -+ " %37;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x64x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x64x32_S32S8S8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ "{%32, %33, %34, %35}," -+ " %36," -+ " %37;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x96x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x96x32_S32S8S8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ "{%48, %49, %50, %51}," -+ " %52," -+ " %53;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x96x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x96x32_S32S8S8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ "{%48, %49, %50, %51}," -+ " %52," -+ " %53;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x128x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x128x32_S32S8S8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ "{%64, %65, %66, %67}," -+ " %68," -+ " %69;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x128x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x128x32_S32S8S8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ "{%64, %65, %66, %67}," -+ " %68," -+ " %69;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x192x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x192x32_S32S8S8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, -+ uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, -+ uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, -+ uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, -+ uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, -+ uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, -+ uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, -+ uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, -+ uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ "{%96, %97, %98, %99}," -+ " %100," -+ " %101;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), -+ "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), -+ "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), -+ "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), -+ "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), -+ "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), -+ "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), -+ "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), -+ "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x192x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x192x32_S32S8S8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, -+ uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, -+ uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, -+ uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, -+ uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, -+ uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, -+ uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, -+ uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, -+ uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ "{%96, %97, %98, %99}," -+ " %100," -+ " %101;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), -+ "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), -+ "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), -+ "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), -+ "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), -+ "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), -+ "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), -+ "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), -+ "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x256x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x256x32_S32S8S8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, -+ uint64_t const& desc_b, -+ uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, -+ uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, -+ uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, -+ uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, -+ uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, -+ uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, -+ uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, -+ uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, -+ uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, -+ uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, -+ uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, -+ uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, -+ uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, -+ uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, -+ uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, -+ uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, -+ uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, -+ uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, -+ uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, -+ uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, -+ uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, -+ uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, -+ uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, -+ uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, -+ uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, -+ uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, -+ uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, -+ uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, -+ uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, -+ uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, -+ uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, -+ uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ "{%128, %129, %130, %131}," -+ " %132," -+ " %133;\n" -+ : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), -+ "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), -+ "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), -+ "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), -+ "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), -+ "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), -+ "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), -+ "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), -+ "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), -+ "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), -+ "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), -+ "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), -+ "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), -+ "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), -+ "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), -+ "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), -+ "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), -+ "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), -+ "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), -+ "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), -+ "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), -+ "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), -+ "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), -+ "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), -+ "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), -+ "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), -+ "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), -+ "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), -+ "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), -+ "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), -+ "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), -+ "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) -+ : "r"(a000), "r"(a001), "r"(a002), "r"(a003), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x256x32 TN S32+=S8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x256x32_S32S8S8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, -+ uint64_t const& desc_b, -+ uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, -+ uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, -+ uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, -+ uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, -+ uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, -+ uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, -+ uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, -+ uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, -+ uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, -+ uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, -+ uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, -+ uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, -+ uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, -+ uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, -+ uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, -+ uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, -+ uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, -+ uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, -+ uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, -+ uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, -+ uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, -+ uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, -+ uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, -+ uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, -+ uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, -+ uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, -+ uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, -+ uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, -+ uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, -+ uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, -+ uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, -+ uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ "{%128, %129, %130, %131}," -+ " %132," -+ " %133;\n" -+ : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), -+ "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), -+ "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), -+ "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), -+ "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), -+ "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), -+ "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), -+ "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), -+ "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), -+ "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), -+ "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), -+ "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), -+ "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), -+ "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), -+ "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), -+ "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), -+ "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), -+ "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), -+ "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), -+ "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), -+ "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), -+ "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), -+ "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), -+ "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), -+ "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), -+ "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), -+ "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), -+ "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), -+ "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), -+ "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), -+ "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), -+ "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) -+ : "r"(a000), "r"(a001), "r"(a002), "r"(a003), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x8x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x8x32_S32S8U8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8 " -+ "{%0, %1, %2, %3}," -+ " %4," -+ " %5," -+ " %6;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x8x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x8x32_S32S8U8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8.satfinite " -+ "{%0, %1, %2, %3}," -+ " %4," -+ " %5," -+ " %6;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x16x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x16x32_S32S8U8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ " %8," -+ " %9," -+ " %10;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), -+ "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x16x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x16x32_S32S8U8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ " %8," -+ " %9," -+ " %10;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), -+ "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x32x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x32x32_S32S8U8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ " %16," -+ " %17," -+ " %18;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x32x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x32x32_S32S8U8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ " %16," -+ " %17," -+ " %18;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x64x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x64x32_S32S8U8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ " %32," -+ " %33," -+ " %34;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x64x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x64x32_S32S8U8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ " %32," -+ " %33," -+ " %34;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x96x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x96x32_S32S8U8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ " %48," -+ " %49," -+ " %50;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x96x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x96x32_S32S8U8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ " %48," -+ " %49," -+ " %50;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x128x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x128x32_S32S8U8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ " %64," -+ " %65," -+ " %66;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x128x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x128x32_S32S8U8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ " %64," -+ " %65," -+ " %66;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x192x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x192x32_S32S8U8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, -+ uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, -+ uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, -+ uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, -+ uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, -+ uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, -+ uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, -+ uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, -+ uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ " %96," -+ " %97," -+ " %98;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), -+ "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), -+ "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), -+ "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), -+ "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), -+ "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), -+ "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), -+ "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), -+ "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x192x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x192x32_S32S8U8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, -+ uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, -+ uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, -+ uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, -+ uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, -+ uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, -+ uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, -+ uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, -+ uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ " %96," -+ " %97," -+ " %98;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), -+ "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), -+ "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), -+ "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), -+ "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), -+ "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), -+ "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), -+ "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), -+ "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x256x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x256x32_S32S8U8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, -+ uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, -+ uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, -+ uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, -+ uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, -+ uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, -+ uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, -+ uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, -+ uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, -+ uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, -+ uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, -+ uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, -+ uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, -+ uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, -+ uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, -+ uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, -+ uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, -+ uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, -+ uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, -+ uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, -+ uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, -+ uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, -+ uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, -+ uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, -+ uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, -+ uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, -+ uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, -+ uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, -+ uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, -+ uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, -+ uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, -+ uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ " %128," -+ " %129," -+ " %130;\n" -+ : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), -+ "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), -+ "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), -+ "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), -+ "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), -+ "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), -+ "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), -+ "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), -+ "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), -+ "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), -+ "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), -+ "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), -+ "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), -+ "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), -+ "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), -+ "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), -+ "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), -+ "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), -+ "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), -+ "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), -+ "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), -+ "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), -+ "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), -+ "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), -+ "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), -+ "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), -+ "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), -+ "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), -+ "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), -+ "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), -+ "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), -+ "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x256x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x256x32_S32S8U8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, -+ uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, -+ uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, -+ uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, -+ uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, -+ uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, -+ uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, -+ uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, -+ uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, -+ uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, -+ uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, -+ uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, -+ uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, -+ uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, -+ uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, -+ uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, -+ uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, -+ uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, -+ uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, -+ uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, -+ uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, -+ uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, -+ uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, -+ uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, -+ uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, -+ uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, -+ uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, -+ uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, -+ uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, -+ uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, -+ uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, -+ uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ " %128," -+ " %129," -+ " %130;\n" -+ : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), -+ "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), -+ "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), -+ "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), -+ "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), -+ "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), -+ "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), -+ "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), -+ "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), -+ "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), -+ "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), -+ "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), -+ "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), -+ "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), -+ "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), -+ "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), -+ "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), -+ "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), -+ "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), -+ "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), -+ "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), -+ "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), -+ "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), -+ "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), -+ "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), -+ "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), -+ "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), -+ "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), -+ "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), -+ "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), -+ "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), -+ "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x8x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x8x32_S32S8U8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ " %8," -+ " %9;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x8x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x8x32_S32S8U8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ " %8," -+ " %9;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x16x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x16x32_S32S8U8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ "{%8, %9, %10, %11}," -+ " %12," -+ " %13;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), -+ "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x16x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x16x32_S32S8U8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ "{%8, %9, %10, %11}," -+ " %12," -+ " %13;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), -+ "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x32x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x32x32_S32S8U8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ "{%16, %17, %18, %19}," -+ " %20," -+ " %21;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x32x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x32x32_S32S8U8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ "{%16, %17, %18, %19}," -+ " %20," -+ " %21;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x64x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x64x32_S32S8U8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ "{%32, %33, %34, %35}," -+ " %36," -+ " %37;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x64x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x64x32_S32S8U8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ "{%32, %33, %34, %35}," -+ " %36," -+ " %37;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x96x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x96x32_S32S8U8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ "{%48, %49, %50, %51}," -+ " %52," -+ " %53;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x96x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x96x32_S32S8U8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ "{%48, %49, %50, %51}," -+ " %52," -+ " %53;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x128x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x128x32_S32S8U8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ "{%64, %65, %66, %67}," -+ " %68," -+ " %69;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x128x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x128x32_S32S8U8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ "{%64, %65, %66, %67}," -+ " %68," -+ " %69;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x192x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x192x32_S32S8U8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, -+ uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, -+ uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, -+ uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, -+ uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, -+ uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, -+ uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, -+ uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, -+ uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ "{%96, %97, %98, %99}," -+ " %100," -+ " %101;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), -+ "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), -+ "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), -+ "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), -+ "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), -+ "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), -+ "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), -+ "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), -+ "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x192x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x192x32_S32S8U8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, -+ uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, -+ uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, -+ uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, -+ uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, -+ uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, -+ uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, -+ uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, -+ uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ "{%96, %97, %98, %99}," -+ " %100," -+ " %101;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), -+ "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), -+ "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), -+ "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), -+ "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), -+ "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), -+ "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), -+ "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), -+ "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x256x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x256x32_S32S8U8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, -+ uint64_t const& desc_b, -+ uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, -+ uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, -+ uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, -+ uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, -+ uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, -+ uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, -+ uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, -+ uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, -+ uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, -+ uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, -+ uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, -+ uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, -+ uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, -+ uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, -+ uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, -+ uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, -+ uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, -+ uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, -+ uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, -+ uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, -+ uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, -+ uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, -+ uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, -+ uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, -+ uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, -+ uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, -+ uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, -+ uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, -+ uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, -+ uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, -+ uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, -+ uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ "{%128, %129, %130, %131}," -+ " %132," -+ " %133;\n" -+ : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), -+ "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), -+ "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), -+ "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), -+ "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), -+ "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), -+ "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), -+ "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), -+ "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), -+ "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), -+ "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), -+ "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), -+ "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), -+ "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), -+ "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), -+ "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), -+ "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), -+ "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), -+ "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), -+ "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), -+ "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), -+ "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), -+ "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), -+ "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), -+ "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), -+ "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), -+ "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), -+ "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), -+ "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), -+ "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), -+ "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), -+ "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) -+ : "r"(a000), "r"(a001), "r"(a002), "r"(a003), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x256x32 TN S32+=S8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x256x32_S32S8U8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, -+ uint64_t const& desc_b, -+ uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, -+ uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, -+ uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, -+ uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, -+ uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, -+ uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, -+ uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, -+ uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, -+ uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, -+ uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, -+ uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, -+ uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, -+ uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, -+ uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, -+ uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, -+ uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, -+ uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, -+ uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, -+ uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, -+ uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, -+ uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, -+ uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, -+ uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, -+ uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, -+ uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, -+ uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, -+ uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, -+ uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, -+ uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, -+ uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, -+ uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, -+ uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ "{%128, %129, %130, %131}," -+ " %132," -+ " %133;\n" -+ : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), -+ "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), -+ "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), -+ "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), -+ "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), -+ "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), -+ "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), -+ "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), -+ "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), -+ "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), -+ "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), -+ "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), -+ "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), -+ "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), -+ "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), -+ "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), -+ "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), -+ "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), -+ "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), -+ "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), -+ "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), -+ "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), -+ "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), -+ "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), -+ "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), -+ "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), -+ "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), -+ "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), -+ "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), -+ "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), -+ "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), -+ "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) -+ : "r"(a000), "r"(a001), "r"(a002), "r"(a003), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x8x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x8x32_S32U8S8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8 " -+ "{%0, %1, %2, %3}," -+ " %4," -+ " %5," -+ " %6;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x8x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x8x32_S32U8S8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8.satfinite " -+ "{%0, %1, %2, %3}," -+ " %4," -+ " %5," -+ " %6;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x16x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x16x32_S32U8S8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ " %8," -+ " %9," -+ " %10;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), -+ "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x16x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x16x32_S32U8S8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ " %8," -+ " %9," -+ " %10;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), -+ "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x32x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x32x32_S32U8S8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ " %16," -+ " %17," -+ " %18;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x32x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x32x32_S32U8S8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ " %16," -+ " %17," -+ " %18;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x64x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x64x32_S32U8S8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ " %32," -+ " %33," -+ " %34;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x64x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x64x32_S32U8S8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ " %32," -+ " %33," -+ " %34;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x96x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x96x32_S32U8S8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ " %48," -+ " %49," -+ " %50;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x96x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x96x32_S32U8S8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ " %48," -+ " %49," -+ " %50;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x128x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x128x32_S32U8S8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ " %64," -+ " %65," -+ " %66;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x128x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x128x32_S32U8S8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ " %64," -+ " %65," -+ " %66;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x192x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x192x32_S32U8S8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, -+ uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, -+ uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, -+ uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, -+ uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, -+ uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, -+ uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, -+ uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, -+ uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ " %96," -+ " %97," -+ " %98;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), -+ "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), -+ "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), -+ "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), -+ "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), -+ "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), -+ "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), -+ "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), -+ "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x192x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x192x32_S32U8S8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, -+ uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, -+ uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, -+ uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, -+ uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, -+ uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, -+ uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, -+ uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, -+ uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ " %96," -+ " %97," -+ " %98;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), -+ "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), -+ "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), -+ "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), -+ "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), -+ "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), -+ "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), -+ "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), -+ "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x256x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x256x32_S32U8S8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, -+ uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, -+ uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, -+ uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, -+ uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, -+ uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, -+ uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, -+ uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, -+ uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, -+ uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, -+ uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, -+ uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, -+ uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, -+ uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, -+ uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, -+ uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, -+ uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, -+ uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, -+ uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, -+ uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, -+ uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, -+ uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, -+ uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, -+ uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, -+ uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, -+ uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, -+ uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, -+ uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, -+ uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, -+ uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, -+ uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, -+ uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ " %128," -+ " %129," -+ " %130;\n" -+ : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), -+ "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), -+ "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), -+ "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), -+ "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), -+ "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), -+ "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), -+ "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), -+ "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), -+ "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), -+ "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), -+ "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), -+ "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), -+ "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), -+ "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), -+ "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), -+ "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), -+ "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), -+ "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), -+ "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), -+ "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), -+ "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), -+ "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), -+ "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), -+ "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), -+ "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), -+ "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), -+ "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), -+ "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), -+ "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), -+ "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), -+ "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x256x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x256x32_S32U8S8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, -+ uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, -+ uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, -+ uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, -+ uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, -+ uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, -+ uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, -+ uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, -+ uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, -+ uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, -+ uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, -+ uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, -+ uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, -+ uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, -+ uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, -+ uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, -+ uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, -+ uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, -+ uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, -+ uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, -+ uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, -+ uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, -+ uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, -+ uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, -+ uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, -+ uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, -+ uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, -+ uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, -+ uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, -+ uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, -+ uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, -+ uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ " %128," -+ " %129," -+ " %130;\n" -+ : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), -+ "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), -+ "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), -+ "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), -+ "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), -+ "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), -+ "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), -+ "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), -+ "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), -+ "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), -+ "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), -+ "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), -+ "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), -+ "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), -+ "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), -+ "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), -+ "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), -+ "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), -+ "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), -+ "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), -+ "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), -+ "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), -+ "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), -+ "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), -+ "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), -+ "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), -+ "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), -+ "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), -+ "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), -+ "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), -+ "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), -+ "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x8x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x8x32_S32U8S8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ " %8," -+ " %9;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x8x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x8x32_S32U8S8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ " %8," -+ " %9;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x16x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x16x32_S32U8S8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ "{%8, %9, %10, %11}," -+ " %12," -+ " %13;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), -+ "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x16x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x16x32_S32U8S8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ "{%8, %9, %10, %11}," -+ " %12," -+ " %13;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), -+ "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x32x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x32x32_S32U8S8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ "{%16, %17, %18, %19}," -+ " %20," -+ " %21;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x32x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x32x32_S32U8S8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ "{%16, %17, %18, %19}," -+ " %20," -+ " %21;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x64x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x64x32_S32U8S8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ "{%32, %33, %34, %35}," -+ " %36," -+ " %37;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x64x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x64x32_S32U8S8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ "{%32, %33, %34, %35}," -+ " %36," -+ " %37;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x96x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x96x32_S32U8S8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ "{%48, %49, %50, %51}," -+ " %52," -+ " %53;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x96x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x96x32_S32U8S8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ "{%48, %49, %50, %51}," -+ " %52," -+ " %53;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x128x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x128x32_S32U8S8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ "{%64, %65, %66, %67}," -+ " %68," -+ " %69;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x128x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x128x32_S32U8S8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ "{%64, %65, %66, %67}," -+ " %68," -+ " %69;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x192x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x192x32_S32U8S8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, -+ uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, -+ uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, -+ uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, -+ uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, -+ uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, -+ uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, -+ uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, -+ uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ "{%96, %97, %98, %99}," -+ " %100," -+ " %101;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), -+ "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), -+ "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), -+ "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), -+ "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), -+ "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), -+ "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), -+ "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), -+ "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x192x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x192x32_S32U8S8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, -+ uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, -+ uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, -+ uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, -+ uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, -+ uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, -+ uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, -+ uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, -+ uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ "{%96, %97, %98, %99}," -+ " %100," -+ " %101;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), -+ "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), -+ "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), -+ "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), -+ "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), -+ "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), -+ "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), -+ "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), -+ "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x256x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x256x32_S32U8S8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, -+ uint64_t const& desc_b, -+ uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, -+ uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, -+ uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, -+ uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, -+ uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, -+ uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, -+ uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, -+ uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, -+ uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, -+ uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, -+ uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, -+ uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, -+ uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, -+ uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, -+ uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, -+ uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, -+ uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, -+ uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, -+ uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, -+ uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, -+ uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, -+ uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, -+ uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, -+ uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, -+ uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, -+ uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, -+ uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, -+ uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, -+ uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, -+ uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, -+ uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, -+ uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ "{%128, %129, %130, %131}," -+ " %132," -+ " %133;\n" -+ : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), -+ "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), -+ "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), -+ "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), -+ "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), -+ "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), -+ "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), -+ "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), -+ "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), -+ "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), -+ "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), -+ "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), -+ "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), -+ "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), -+ "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), -+ "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), -+ "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), -+ "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), -+ "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), -+ "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), -+ "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), -+ "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), -+ "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), -+ "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), -+ "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), -+ "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), -+ "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), -+ "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), -+ "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), -+ "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), -+ "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), -+ "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) -+ : "r"(a000), "r"(a001), "r"(a002), "r"(a003), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x256x32 TN S32+=U8*S8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x256x32_S32U8S8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, -+ uint64_t const& desc_b, -+ uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, -+ uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, -+ uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, -+ uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, -+ uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, -+ uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, -+ uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, -+ uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, -+ uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, -+ uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, -+ uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, -+ uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, -+ uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, -+ uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, -+ uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, -+ uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, -+ uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, -+ uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, -+ uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, -+ uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, -+ uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, -+ uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, -+ uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, -+ uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, -+ uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, -+ uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, -+ uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, -+ uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, -+ uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, -+ uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, -+ uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, -+ uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ "{%128, %129, %130, %131}," -+ " %132," -+ " %133;\n" -+ : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), -+ "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), -+ "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), -+ "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), -+ "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), -+ "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), -+ "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), -+ "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), -+ "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), -+ "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), -+ "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), -+ "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), -+ "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), -+ "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), -+ "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), -+ "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), -+ "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), -+ "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), -+ "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), -+ "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), -+ "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), -+ "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), -+ "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), -+ "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), -+ "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), -+ "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), -+ "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), -+ "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), -+ "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), -+ "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), -+ "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), -+ "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) -+ : "r"(a000), "r"(a001), "r"(a002), "r"(a003), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x8x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x8x32_S32U8U8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 " -+ "{%0, %1, %2, %3}," -+ " %4," -+ " %5," -+ " %6;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x8x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x8x32_S32U8U8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8.satfinite " -+ "{%0, %1, %2, %3}," -+ " %4," -+ " %5," -+ " %6;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x16x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x16x32_S32U8U8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ " %8," -+ " %9," -+ " %10;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), -+ "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x16x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x16x32_S32U8U8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ " %8," -+ " %9," -+ " %10;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), -+ "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x32x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x32x32_S32U8U8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ " %16," -+ " %17," -+ " %18;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x32x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x32x32_S32U8U8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ " %16," -+ " %17," -+ " %18;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x64x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x64x32_S32U8U8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ " %32," -+ " %33," -+ " %34;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x64x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x64x32_S32U8U8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ " %32," -+ " %33," -+ " %34;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x96x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x96x32_S32U8U8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ " %48," -+ " %49," -+ " %50;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x96x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x96x32_S32U8U8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ " %48," -+ " %49," -+ " %50;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x128x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x128x32_S32U8U8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ " %64," -+ " %65," -+ " %66;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x128x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x128x32_S32U8U8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ " %64," -+ " %65," -+ " %66;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x192x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x192x32_S32U8U8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, -+ uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, -+ uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, -+ uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, -+ uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, -+ uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, -+ uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, -+ uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, -+ uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ " %96," -+ " %97," -+ " %98;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), -+ "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), -+ "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), -+ "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), -+ "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), -+ "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), -+ "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), -+ "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), -+ "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x192x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x192x32_S32U8U8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, -+ uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, -+ uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, -+ uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, -+ uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, -+ uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, -+ uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, -+ uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, -+ uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ " %96," -+ " %97," -+ " %98;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), -+ "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), -+ "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), -+ "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), -+ "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), -+ "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), -+ "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), -+ "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), -+ "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x256x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x256x32_S32U8U8_SS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, -+ uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, -+ uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, -+ uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, -+ uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, -+ uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, -+ uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, -+ uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, -+ uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, -+ uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, -+ uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, -+ uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, -+ uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, -+ uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, -+ uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, -+ uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, -+ uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, -+ uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, -+ uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, -+ uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, -+ uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, -+ uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, -+ uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, -+ uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, -+ uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, -+ uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, -+ uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, -+ uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, -+ uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, -+ uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, -+ uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, -+ uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ " %128," -+ " %129," -+ " %130;\n" -+ : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), -+ "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), -+ "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), -+ "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), -+ "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), -+ "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), -+ "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), -+ "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), -+ "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), -+ "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), -+ "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), -+ "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), -+ "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), -+ "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), -+ "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), -+ "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), -+ "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), -+ "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), -+ "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), -+ "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), -+ "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), -+ "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), -+ "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), -+ "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), -+ "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), -+ "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), -+ "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), -+ "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), -+ "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), -+ "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), -+ "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), -+ "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x256x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x256x32_S32U8U8_SS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint64_t[1]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint64_t const& desc_a, -+ uint64_t const& desc_b, -+ uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, -+ uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, -+ uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, -+ uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, -+ uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, -+ uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, -+ uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, -+ uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, -+ uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, -+ uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, -+ uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, -+ uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, -+ uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, -+ uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, -+ uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, -+ uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, -+ uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, -+ uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, -+ uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, -+ uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, -+ uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, -+ uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, -+ uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, -+ uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, -+ uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, -+ uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, -+ uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, -+ uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, -+ uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, -+ uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, -+ uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, -+ uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ " %128," -+ " %129," -+ " %130;\n" -+ : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), -+ "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), -+ "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), -+ "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), -+ "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), -+ "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), -+ "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), -+ "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), -+ "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), -+ "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), -+ "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), -+ "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), -+ "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), -+ "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), -+ "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), -+ "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), -+ "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), -+ "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), -+ "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), -+ "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), -+ "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), -+ "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), -+ "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), -+ "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), -+ "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), -+ "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), -+ "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), -+ "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), -+ "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), -+ "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), -+ "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), -+ "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) -+ : "l"(desc_a), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x8x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x8x32_S32U8U8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ " %8," -+ " %9;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x8x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x8x32_S32U8U8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[4]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8.satfinite " -+ "{%0, %1, %2, %3}," -+ "{%4, %5, %6, %7}," -+ " %8," -+ " %9;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x16x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x16x32_S32U8U8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ "{%8, %9, %10, %11}," -+ " %12," -+ " %13;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), -+ "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x16x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x16x32_S32U8U8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[8]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, -+ uint64_t const& desc_b, -+ uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, -+ uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7}," -+ "{%8, %9, %10, %11}," -+ " %12," -+ " %13;\n" -+ : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), -+ "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) -+ : "r"(a0), "r"(a1), "r"(a2), "r"(a3), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x32x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x32x32_S32U8U8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ "{%16, %17, %18, %19}," -+ " %20," -+ " %21;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x32x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x32x32_S32U8U8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[16]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15}," -+ "{%16, %17, %18, %19}," -+ " %20," -+ " %21;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x64x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x64x32_S32U8U8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ "{%32, %33, %34, %35}," -+ " %36," -+ " %37;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x64x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x64x32_S32U8U8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[32]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31}," -+ "{%32, %33, %34, %35}," -+ " %36," -+ " %37;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x96x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x96x32_S32U8U8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ "{%48, %49, %50, %51}," -+ " %52," -+ " %53;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x96x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x96x32_S32U8U8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[48]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47}," -+ "{%48, %49, %50, %51}," -+ " %52," -+ " %53;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x128x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x128x32_S32U8U8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ "{%64, %65, %66, %67}," -+ " %68," -+ " %69;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x128x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x128x32_S32U8U8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[64]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63}," -+ "{%64, %65, %66, %67}," -+ " %68," -+ " %69;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x192x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x192x32_S32U8U8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, -+ uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, -+ uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, -+ uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, -+ uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, -+ uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, -+ uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, -+ uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, -+ uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ "{%96, %97, %98, %99}," -+ " %100," -+ " %101;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), -+ "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), -+ "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), -+ "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), -+ "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), -+ "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), -+ "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), -+ "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), -+ "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x192x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x192x32_S32U8U8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[96]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, -+ uint64_t const& desc_b, -+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, -+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, -+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, -+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, -+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, -+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, -+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, -+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, -+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, -+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, -+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, -+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, -+ uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, -+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, -+ uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, -+ uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, -+ uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, -+ uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, -+ uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, -+ uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, -+ uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, -+ uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, -+ uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, -+ uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95}," -+ "{%96, %97, %98, %99}," -+ " %100," -+ " %101;\n" -+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), -+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), -+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), -+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), -+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), -+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), -+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), -+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), -+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), -+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), -+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), -+ "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), -+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), -+ "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), -+ "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), -+ "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), -+ "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), -+ "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), -+ "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), -+ "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), -+ "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), -+ "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), -+ "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), -+ "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) -+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x256x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x256x32_S32U8U8_RS_TN -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, -+ uint64_t const& desc_b, -+ uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, -+ uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, -+ uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, -+ uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, -+ uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, -+ uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, -+ uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, -+ uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, -+ uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, -+ uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, -+ uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, -+ uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, -+ uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, -+ uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, -+ uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, -+ uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, -+ uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, -+ uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, -+ uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, -+ uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, -+ uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, -+ uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, -+ uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, -+ uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, -+ uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, -+ uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, -+ uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, -+ uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, -+ uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, -+ uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, -+ uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, -+ uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8 " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ "{%128, %129, %130, %131}," -+ " %132," -+ " %133;\n" -+ : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), -+ "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), -+ "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), -+ "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), -+ "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), -+ "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), -+ "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), -+ "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), -+ "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), -+ "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), -+ "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), -+ "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), -+ "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), -+ "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), -+ "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), -+ "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), -+ "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), -+ "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), -+ "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), -+ "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), -+ "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), -+ "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), -+ "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), -+ "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), -+ "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), -+ "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), -+ "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), -+ "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), -+ "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), -+ "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), -+ "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), -+ "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) -+ : "r"(a000), "r"(a001), "r"(a002), "r"(a003), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// MMA 64x256x32 TN S32+=U8*U8 -+template< -+ GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -+> -+struct SM90_64x256x32_S32U8U8_RS_TN_SATURATE -+{ -+ using DRegisters = void; -+ using ARegisters = uint32_t[4]; -+ using BRegisters = uint64_t[1]; -+ using CRegisters = uint32_t[128]; -+ -+ CUTE_HOST_DEVICE static void -+ fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, -+ uint64_t const& desc_b, -+ uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, -+ uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, -+ uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, -+ uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, -+ uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, -+ uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, -+ uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, -+ uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, -+ uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, -+ uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, -+ uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, -+ uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, -+ uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, -+ uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, -+ uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, -+ uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, -+ uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, -+ uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, -+ uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, -+ uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, -+ uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, -+ uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, -+ uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, -+ uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, -+ uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, -+ uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, -+ uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, -+ uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, -+ uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, -+ uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, -+ uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, -+ uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) -+ { -+#if defined(CUTE_ARCH_MMA_SM90_ENABLED) -+ asm volatile( -+ "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8.satfinite " -+ "{%0, %1, %2, %3, %4, %5, %6, %7, " -+ " %8, %9, %10, %11, %12, %13, %14, %15, " -+ " %16, %17, %18, %19, %20, %21, %22, %23, " -+ " %24, %25, %26, %27, %28, %29, %30, %31, " -+ " %32, %33, %34, %35, %36, %37, %38, %39, " -+ " %40, %41, %42, %43, %44, %45, %46, %47, " -+ " %48, %49, %50, %51, %52, %53, %54, %55, " -+ " %56, %57, %58, %59, %60, %61, %62, %63, " -+ " %64, %65, %66, %67, %68, %69, %70, %71, " -+ " %72, %73, %74, %75, %76, %77, %78, %79, " -+ " %80, %81, %82, %83, %84, %85, %86, %87, " -+ " %88, %89, %90, %91, %92, %93, %94, %95, " -+ " %96, %97, %98, %99, %100, %101, %102, %103, " -+ " %104, %105, %106, %107, %108, %109, %110, %111, " -+ " %112, %113, %114, %115, %116, %117, %118, %119, " -+ " %120, %121, %122, %123, %124, %125, %126, %127}," -+ "{%128, %129, %130, %131}," -+ " %132," -+ " %133;\n" -+ : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), -+ "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), -+ "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), -+ "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), -+ "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), -+ "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), -+ "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), -+ "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), -+ "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), -+ "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), -+ "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), -+ "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), -+ "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), -+ "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), -+ "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), -+ "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), -+ "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), -+ "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), -+ "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), -+ "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), -+ "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), -+ "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), -+ "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), -+ "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), -+ "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), -+ "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), -+ "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), -+ "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), -+ "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), -+ "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), -+ "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), -+ "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) -+ : "r"(a000), "r"(a001), "r"(a002), "r"(a003), -+ "l"(desc_b), -+ "n"(int32_t(scaleD))); -+#else -+ CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cute -diff --git a/3rdparty/cutlass/include/cute/arch/util.hpp b/3rdparty/cutlass/include/cute/arch/util.hpp -new file mode 100644 -index 0000000..007781f ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/arch/util.hpp -@@ -0,0 +1,178 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+#if (! defined (__clang__) && __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2) -+ extern "C" { -+ // This NVVM intrinsic is subject to change in future versions of CUDA. -+ // Clients should not call it directly. -+ CUTE_DEVICE uint32_t __nvvm_get_smem_pointer(void*); -+ } -+#endif -+ -+namespace cute -+{ -+ -+/// CUTE helper to cast SMEM pointer to unsigned -+CUTE_HOST_DEVICE -+uint32_t -+cast_smem_ptr_to_uint(void const* const ptr) -+{ -+// We prefer to use the new CVTA intrinsics if they are available, otherwise we will fall back to -+// the previous internal intrinsics if they are available. -+#if (! defined (__clang__) && defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ >= 11) -+ // -+ // This NVVM intrinsic converts an address in shared memory to a plain -+ // unsigned integer. This is necessary to pass to shared memory instructions -+ // in inline PTX. -+ // -+ // In CUDA 11 and beyond, this replaces __nvvm_get_smem_pointer() [only available in 10.2]. -+ // -+ //__device__ size_t __cvta_generic_to_shared(void* ptr); -+ -+ /// CUTE helper to get SMEM pointer -+ return static_cast(__cvta_generic_to_shared(ptr)); -+ -+#elif (! defined (__clang__) && defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2) -+ -+ return __nvvm_get_smem_pointer(ptr); -+ -+#elif defined(__CUDA_ARCH__) -+ -+ uint32_t smem_ptr; -+ -+ asm( -+ "{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n" -+ : "=r"(smem_ptr) : "l"(ptr)); -+ -+ return smem_ptr; -+ -+#else -+ -+ -+ (void) ptr; -+ printf("ERROR: cast_smem_ptr_to_uint not supported but used.\n"); -+ return 0; -+ -+#endif -+} -+ -+// -+// Utility for pointer interfaces -+// -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+void -+explode(Fn fn, -+ PtrS&& s, int_sequence, -+ PtrD&& d, int_sequence) -+{ -+ return fn(s[Is]..., d[Id]...); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+void -+explode(Fn fn, -+ PtrA&& a, int_sequence, -+ PtrB&& b, int_sequence, -+ PtrC&& c, int_sequence) -+{ -+ return fn(a[Ia]..., b[Ib]..., c[Ic]...); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+void -+explode(Fn fn, -+ PtrD&& d, int_sequence, -+ PtrA&& a, int_sequence, -+ PtrB&& b, int_sequence, -+ PtrC&& c, int_sequence) -+{ -+ return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]...); -+} -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE constexpr -+void -+explode(Fn fn, PtrS&& s, PtrD&& d) -+{ -+ return detail::explode(fn, -+ s, make_int_sequence{}, -+ d, make_int_sequence{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+void -+explode(Fn fn, PtrA&& a, PtrB&& b, PtrC&& c) -+{ -+ return detail::explode(fn, -+ a, make_int_sequence{}, -+ b, make_int_sequence{}, -+ c, make_int_sequence{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+void -+explode(Fn fn, PtrD&& d, PtrA&& a, PtrB&& b, PtrC&& c) -+{ -+ return detail::explode(fn, -+ d, make_int_sequence{}, -+ a, make_int_sequence{}, -+ b, make_int_sequence{}, -+ c, make_int_sequence{}); -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/atom/copy_atom.hpp b/3rdparty/cutlass/include/cute/atom/copy_atom.hpp -new file mode 100644 -index 0000000..2c5d9c5 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/atom/copy_atom.hpp -@@ -0,0 +1,671 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+#include -+#include -+ -+#include -+ -+namespace cute { -+ -+// Generic copy_unpack for any Copy_Traits -+template -+CUTE_HOST_DEVICE constexpr -+void -+copy_unpack(Copy_Traits const&, -+ Tensor const& src, -+ Tensor & dst) -+{ -+ // Specializations can generalize on these checks -+ //static_assert(is_smem::value, "Expected smem for this Copy_Traits"); -+ //static_assert(is_rmem::value, "Expected rmem for this Copy_Traits"); -+ -+ using RegistersSrc = typename Operation::SRegisters; -+ using RegistersDst = typename Operation::DRegisters; -+ using RegTypeSrc = typename std::remove_extent::type; -+ using RegTypeDst = typename std::remove_extent::type; -+ constexpr int RegNumSrc = std::extent::value; -+ constexpr int RegNumDst = std::extent::value; -+ -+ Tensor rS = recast(src); -+ Tensor rD = recast(dst); -+ -+ CUTE_STATIC_ASSERT_V(size(rS) == Int{}, -+ "In CopyAtom, src layout doesn't vectorize into registers. This src layout is incompatible with this tiled copy."); -+ CUTE_STATIC_ASSERT_V(size(rD) == Int{}, -+ "In CopyAtom, dst layout doesn't vectorize into registers. This dst layout is incompatible with this tiled copy."); -+ -+ detail::explode(Operation::copy, -+ rS, make_int_sequence{}, -+ rD, make_int_sequence{}); -+} -+ -+ -+template -+struct Copy_Atom; -+ -+template -+struct Copy_Atom : Copy_Atom, T> -+{}; -+ -+template -+struct Copy_Atom, T> -+ : Copy_Traits -+{ -+ using Traits = Copy_Traits; -+ -+ // Bit and Thr layouts from the Copy_Traits -+ using ThrID = typename Traits::ThrID; -+ using BitLayoutSrc = typename Traits::SrcLayout; -+ using BitLayoutDst = typename Traits::DstLayout; -+ using BitLayoutRef = typename Traits::RefLayout; -+ -+ using ValType = T; -+ -+ using ValLayoutSrc = decltype(upcast::value>(BitLayoutSrc{})); -+ using ValLayoutDst = decltype(upcast::value>(BitLayoutDst{})); -+ using ValLayoutRef = decltype(upcast::value>(BitLayoutRef{})); -+ -+ CUTE_STATIC_ASSERT_V(size<0>(ValLayoutSrc{}) == size(ThrID{}), "CopyOperation is not valid for Src of ValType."); -+ CUTE_STATIC_ASSERT_V(size<0>(ValLayoutDst{}) == size(ThrID{}), "CopyOperation is not valid for Dst of ValType."); -+ CUTE_STATIC_ASSERT_V(size<0>(ValLayoutRef{}) == size(ThrID{}), "CopyOperation is not valid for Ref of ValType."); -+ -+ static constexpr int NumValSrc = size<1>(ValLayoutSrc{}); -+ static constexpr int NumValDst = size<1>(ValLayoutDst{}); -+ -+ // Additional Trait parameters/transformations -+ template -+ CUTE_HOST_DEVICE -+ auto -+ with(TraitsArgs&&... args) const { -+ auto traits = Traits::with(std::forward(args)...); -+ return Copy_Atom{traits}; -+ } -+ -+ // Print thread and data layouts for debugging -+ CUTE_HOST_DEVICE static -+ void -+ print_all() -+ { -+ print("ThrID: "); print(ThrID{}); print("\n"); -+ print("BitLayoutSrc: "); print(BitLayoutSrc{}); print("\n"); -+ print("BitLayoutDst: "); print(BitLayoutDst{}); print("\n"); -+ print("BitLayoutRef: "); print(BitLayoutRef{}); print("\n"); -+ print("ValLayoutSrc: "); print(ValLayoutSrc{}); print("\n"); -+ print("ValLayoutDst: "); print(ValLayoutDst{}); print("\n"); -+ print("ValLayoutRef: "); print(ValLayoutRef{}); print("\n"); -+ print("ValueType: %db", sizeof_bits::value); print("\n"); -+ } -+ -+ // -+ // Tensor call interfaces -+ // -+ -+ // Cast, check, and call -+ template -+ CUTE_HOST_DEVICE -+ void -+ call(Tensor const& src, -+ Tensor & dst) const -+ { -+ static_assert(SLayout::rank == 1, "Expected rank-1 src tensor"); -+ static_assert(DLayout::rank == 1, "Expected rank-1 dst tensor"); -+ -+ if constexpr (is_constant::value || is_constant::value) { -+ // Dispatch to unpack for instruction -+ return copy_unpack(*this, src, dst); -+ } else { -+ // Recurse if needed by peeling the tensor mode -+ return copy(*this, tensor<0>(src), tensor<0>(dst)); -+ } -+ } -+ -+ // Accept mutable temporaries -+ template -+ CUTE_HOST_DEVICE -+ void -+ call(Tensor const& src, -+ Tensor && dst) const -+ { -+ return call(src, dst); -+ } -+}; -+ -+// -+// A tiling of copy atoms -+// -+ -+template coord [Need not be 2D...] -+ class ShapeTile_MN> // coord space -+struct TiledCopy : Copy_Atom -+{ -+ // Layout information from the CopyAtom -+ using AtomThrID = typename Copy_Atom::ThrID; // thrid -> thr_idx -+ using AtomLayoutSrc = typename Copy_Atom::ValLayoutSrc; // (thr,val) -> offset -+ using AtomLayoutDst = typename Copy_Atom::ValLayoutDst; // (thr,val) -> offset -+ using AtomLayoutRef = typename Copy_Atom::ValLayoutRef; // (thr,val) -> offset -+ -+ using AtomNumThr = decltype(size<0>(AtomLayoutRef{})); -+ using AtomNumVal = decltype(size<1>(AtomLayoutRef{})); -+ -+ // Layout information for the TiledCopy -+ using Tiler_MN = ShapeTile_MN; -+ using TiledShape_MN = decltype(shape(ShapeTile_MN{})); -+ using TiledLayout_TV = LayoutCopy_TV; -+ using TiledNumThr = decltype(size<0>(TiledLayout_TV{})); -+ using TiledNumVal = decltype(size<1>(TiledLayout_TV{})); -+ -+ CUTE_STATIC_ASSERT_V(TiledNumThr{} % AtomNumThr{} == Int<0>{}, "TiledCopy uses too few thrs for selected CopyAtom"); -+ CUTE_STATIC_ASSERT_V(TiledNumVal{} % AtomNumVal{} == Int<0>{}, "TiledCopy uses too few vals for selected CopyAtom"); -+ -+ // Tile a tensor or a layout from shape -+ // (M,N,...) -+ // to shape -+ // ((ThrV,ThrX),FrgV,(RestM,RestN,...)) -+ // where -+ // ThrV: The threads local to a COPY_ATOM Src. -+ // ThrX: The threads tiled across COPY_ATOMs Src. -+ // FrgV: The values local to a COPY_ATOM Src. -+ // RestM: The values tiled in M. -+ // RestN: The values tiled in N. -+ template -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ tidfrg_S(STensor&& stensor) -+ { -+ return thrfrg(stensor, right_inverse(AtomLayoutRef{}).compose(AtomLayoutSrc{})); -+ } -+ -+ // Tile a tensor or a layout from shape -+ // (M,N,...) -+ // to shape -+ // ((ThrV,ThrX),FrgV,(RestM,RestN,...)) -+ // where -+ // ThrV: The threads local to a COPY_ATOM Dst. -+ // ThrX: The threads tiled across COPY_ATOMs Dst. -+ // FrgV: The values local to a COPY_ATOM Dst. -+ // RestM: The values tiled in M. -+ // RestN: The values tiled in N. -+ template -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ tidfrg_D(DTensor&& dtensor) -+ { -+ return thrfrg(dtensor, right_inverse(AtomLayoutRef{}).compose(AtomLayoutDst{})); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ thrfrg(Tensor&& tensor, Ref2TrgLayout const& ref2trg) -+ { -+ constexpr int R = remove_cvref_t::rank; -+ static_assert(R >= rank_v, "Rank of tensor to be partitioned too small."); -+ // Generalize the dimension checks for arbitrary rank -+ //CUTE_STATIC_ASSERT_V(size<0>(stensor) % size<0>(TiledShape_MNK{}) == Int<0>{}); -+ //CUTE_STATIC_ASSERT_V(size<1>(stensor) % size<1>(TiledShape_MNK{}) == Int<0>{}); -+ -+ // Take the thrs/vals that the atom is interested in -+ // NOTE: Assumes the AtomNumThr are contiguous and identity within TiledThrID -+ auto atom_layout_TV = zipped_divide(TiledLayout_TV{}, make_shape(AtomNumThr{}, AtomNumVal{})); -+ // ((atom_tid,atom_val),(rest_tid,rest_val)) -> (m,n) -+ -+ // Transform to the trg layout -+ auto trg_layout_TV = atom_layout_TV.compose(ref2trg, _); -+ // ((trg_tid,trg_val),(rest_tid,rest_val)) -> (m,n) -+ -+ // Transform the thrs mode from thrid to thr_idx -+ // NOTE: Assumes the AtomNumThr are contiguous and identity within TiledThrID -+ auto thrval2mn = coalesce(zip(trg_layout_TV), Shape<_1,Shape<_1,_1>>{}); -+ // ((trg_tid,rest_tid),(trg_val,rest_val)) -> (m,n) -+ -+ /// ================== -+ -+ // Tile the tensor for TiledLayout -+ auto t_tensor = zipped_divide(tensor, Tiler_MN{}); -+ // ((TileM,TileN,...),(RestM,RestN,...)) -+ -+ // Transform the tile mode -+ auto tv_tensor = t_tensor.compose(thrval2mn, _); -+ // ((thrid,val),(RM,RN,...)) -+ -+ // Unfold and return -+ return tv_tensor(make_coord(_,_), _); -+ } -+ -+ // retile_S and retile_D assume they are working with the reference layout -- they are the same -+ template -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ retile(Tensor&& tensor) -+ { -+ constexpr int R = remove_cvref_t::rank; -+ // Assert that AtomLayoutSrc|Dst is identity so we can skip the Ref transformation -+ -+ // Assume the first size<0>(tensor) elements are the first val_ids in TiledLayout_TV. -+ // Then, we only need the shape+layout of those size<0>(tensor) elements in TiledLayout_TV -+ // and that shape is what we gather from the other modes of tensor -+ -+ auto V = size<0>(tensor); -+ -+ auto frg_layout_mn = upcast(right_inverse(TiledLayout_TV{}).with_shape(TiledShape_MN{})); -+ // (m,n) -> v_idx -- The shape and order of the V inside of TiledLayout_TV -+ -+ auto frg_layout_v = zipped_divide(logical_product(make_layout(V), right_inverse(frg_layout_mn)), make_layout(AtomNumVal{})); -+ // (atom_vals,rest_vals) -> (v,m,n) -+ -+ /// ======= -+ -+ // Tile the tensor for TileFrg -+ auto t_tensor = zipped_divide(tensor, prepend(product_each(shape(frg_layout_mn)), V)); -+ // ((TileV,TileM,TileN,...),(1,RestM,RestN,...)) -+ -+ // Transform the tile mode -+ auto v_tensor = t_tensor.compose(frg_layout_v, _); -+ // ((atom_vals,rest_vals),(1,RM,RN,...)) -+ -+ // Unfold and return -+ return v_tensor(_, append(Int<0>{},_)); -+ } -+ -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ get_layoutS_MN() -+ { -+ // (M,N) -> (M,N) -+ auto ref_S = make_layout(TiledShape_MN{}); -+ // (thr_idx,val_idx) -> (M,N) -+ auto layoutS_TV = tidfrg_S(ref_S); -+ // (M,K) -> (thr_idx,val_idx) -+ auto layoutS_MK = right_inverse(layoutS_TV).with_shape(shape(ref_S)); -+ -+ // athrid = (v,m,k) -> thr_idx -+ auto thrID_S = make_layout(size<0>(TiledLayout_TV{})); -+ -+ return cute::make_tuple(layoutS_MK, thrID_S); -+ } -+ -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ get_layoutS_TV() -+ { -+ // (M,N) -> (M,N) -+ auto ref_S = make_layout(TiledShape_MN{}); -+ // (thr_idx,val_idx) -> (M,N) -+ return tidfrg_S(ref_S)(_,_,Int<0>{}); -+ } -+ -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ get_layoutD_MN() -+ { -+ // (M,N) -> (M,N) -+ auto ref_D = make_layout(TiledShape_MN{}); -+ // (thr_idx,val_idx) -> (M,N) -+ auto layoutD_TV = tidfrg_D(ref_D); -+ // (M,K) -> (thr_idx,val_idx) -+ auto layoutD_MK = right_inverse(layoutD_TV).with_shape(shape(ref_D)); -+ -+ // athrid = (v,m,k) -> thr_idx -+ auto thrID_D = make_layout(size<0>(TiledLayout_TV{})); -+ -+ return cute::make_tuple(layoutD_MK, thrID_D); -+ } -+ -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ get_layoutD_TV() -+ { -+ // (M,N) -> (M,N) -+ auto ref_D = make_layout(TiledShape_MN{}); -+ // (thr_idx,val_idx) -> (M,N) -+ return tidfrg_D(ref_D)(_,_,Int<0>{}); -+ } -+ -+ template -+ struct ThrCopy : Copy_Atom -+ { -+ ThrIdx thr_idx_; -+ -+ CUTE_HOST_DEVICE -+ ThrCopy(ThrIdx const& thr_idx) : thr_idx_(thr_idx) {} -+ -+ template -+ CUTE_HOST_DEVICE -+ auto -+ partition_S(STensor&& stensor) { -+ //static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename Copy_Atom::ValType), -+ // "Expected ValType for tiling SrcTensor."); -+ auto thr_tensor = make_tensor(std::forward(stensor).data(), tidfrg_S(stensor.layout())); -+ return thr_tensor(thr_idx_, _, repeat>(_)); -+ } -+ -+ template -+ CUTE_HOST_DEVICE -+ auto -+ partition_D(DTensor&& dtensor) { -+ //static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename Copy_Atom::ValType), -+ // "Expected ValType for tiling DstTensor."); -+ auto thr_tensor = make_tensor(std::forward(dtensor).data(), tidfrg_D(dtensor.layout())); -+ return thr_tensor(thr_idx_, _, repeat>(_)); -+ } -+ -+ template -+ CUTE_HOST_DEVICE static -+ auto -+ retile_S(STensor&& stensor) { -+ static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename Copy_Atom::ValType), -+ "Expected ValType for tiling SrcTensor."); -+ return make_tensor(std::forward(stensor).data(), TiledCopy::retile(stensor.layout())); -+ } -+ -+ template -+ CUTE_HOST_DEVICE static -+ auto -+ retile_D(DTensor&& dtensor) { -+ static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename Copy_Atom::ValType), -+ "Expected ValType for tiling DstTensor."); -+ return make_tensor(std::forward(dtensor).data(), TiledCopy::retile(dtensor.layout())); -+ } -+ }; -+ -+ template ::value)> -+ CUTE_HOST_DEVICE static -+ auto -+ get_slice(ThrIdx const& thr_idx) -+ { -+ return ThrCopy(thr_idx); -+ } -+ -+ template ::value)> -+ CUTE_HOST_DEVICE static -+ auto -+ get_thread_slice(ThrIdx const& thr_idx) -+ { -+ return get_slice(thr_idx); -+ } -+}; -+ -+ -+template -+CUTE_HOST_DEVICE -+auto -+make_tiled_copy_impl(Copy_Atom const& atom, -+ LayoutCopy_TV const&, -+ Tile const&) -+{ -+ return TiledCopy, LayoutCopy_TV, Tile>{atom}; -+} -+ -+// -+// These tile the Copy_Atom as a whole -+// -+ -+template -+CUTE_HOST_DEVICE -+auto -+make_tiled_copy_A(Copy_Atom const& copy_atom, -+ TiledMMA const& tiled_mma) -+{ -+ using MNK = typename TiledMMA::TiledShape_MNK; -+ return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), make_shape(size<0>(MNK{}),size<2>(MNK{}))); -+} -+ -+template -+CUTE_HOST_DEVICE -+auto -+make_tiled_copy_B(Copy_Atom const& copy_atom, -+ TiledMMA const& tiled_mma) -+{ -+ using MNK = typename TiledMMA::TiledShape_MNK; -+ return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutB_TV(), make_shape(size<1>(MNK{}),size<2>(MNK{}))); -+} -+ -+template -+CUTE_HOST_DEVICE -+auto -+make_tiled_copy_C(Copy_Atom const& copy_atom, -+ TiledMMA const& tiled_mma) -+{ -+ using MNK = typename TiledMMA::TiledShape_MNK; -+ return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), make_shape(size<0>(MNK{}),size<1>(MNK{}))); -+} -+ -+template > -+CUTE_HOST_DEVICE -+auto -+make_tiled_copy(Copy_Atom const& copy_atom, -+ ThrLayout const& thr_layout = {}, // (m,n) -> thr_idx -+ ValLayout const& val_layout = {}) -+{ -+ constexpr int R = cute::max(rank_v, rank_v); -+ -+ auto thr_layout_mn = append(thr_layout, Layout<_1>{}); -+ auto val_layout_mn = append(val_layout, Layout<_1>{}); -+ -+ // Take the raked_products to compute the Layout_MN -+ auto layout_mn = raked_product(thr_layout_mn, val_layout_mn); -+ auto layout_tv = right_inverse(layout_mn).with_shape(make_shape(size(thr_layout), size(val_layout))); -+ -+ //print("thr_layout: "); print(thr_layout_mn); print("\n"); -+ //print("val_layout: "); print(val_layout_mn); print("\n"); -+ //print("layout_mn : "); print(layout_mn); print("\n"); -+ //print("layout_tv : "); print(layout_tv); print("\n"); -+ -+ return make_tiled_copy_impl(copy_atom, layout_tv, product_each(shape(layout_mn))); -+} -+ -+// Make a TiledCopy out of the copy_atom that matches the Src-Layout of tiled_copy -+template -+CUTE_HOST_DEVICE -+auto -+make_tiled_copy_S(Copy_Atom const& copy_atom, -+ TiledCopy const& tiled_copy) -+{ -+ return make_tiled_copy_impl(copy_atom, tiled_copy.get_layoutS_TV(), typename TiledCopy::Tiler_MN{}); -+} -+ -+// Make a TiledCopy out of the copy_atom that matches the Dst-Layout of tiled_copy -+template -+CUTE_HOST_DEVICE -+auto -+make_tiled_copy_D(Copy_Atom const& copy_atom, -+ TiledCopy const& tiled_copy) -+{ -+ return make_tiled_copy_impl(copy_atom, tiled_copy.get_layoutD_TV(), typename TiledCopy::Tiler_MN{}); -+} -+ -+// -+// Size -+// -+ -+// The logical size of a TileCopy -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tile_size(TiledCopy const&) -+{ -+ return size(typename TiledCopy::TiledShape_MN{}); -+} -+ -+// The number of threads involved in a TiledCopy -+template -+CUTE_HOST_DEVICE constexpr -+auto -+size(TiledCopy const&) -+{ -+ return typename TiledCopy::TiledNumThr{}; -+} -+ -+// -+// Display utilities -+// -+ -+template -+CUTE_HOST_DEVICE -+auto -+print_latex(TiledCopy const& copy) -+{ -+ auto [layoutS_MN, thrID_S] = copy.get_layoutS_MN(); -+ auto [layoutD_MN, thrID_D] = copy.get_layoutD_MN(); -+ -+ print_latex_copy(layoutS_MN, thrID_S, -+ layoutD_MN, thrID_D); -+} -+ -+// MNK Copy Layout to Latex TIKZ -- 8-value color coded by thread -+template -+CUTE_HOST_DEVICE -+void -+print_latex_copy(LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and tid -> thr_idx -+ LayoutD const& D, ThrIDD const& TD) // (m,n) -> (tid,vid) and tid -> thr_idx -+{ -+ CUTE_STATIC_ASSERT_V(rank(S) == Int<2>{}); -+ CUTE_STATIC_ASSERT_V(rank(D) == Int<2>{}); -+ -+ assert(size<0>(S) == size<0>(D)); -+ assert(size<1>(S) == size<1>(D)); -+ -+ char const* latex_header = -+ "\\documentclass{standalone}\n" -+ "\\usepackage{tikz}\n" -+ "\\usetikzlibrary{external}\n" -+ "\\tikzexternalize\n" -+ "\\begin{document}\n" -+ "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},box/.style={rectangle,draw=black,thick,minimum size=1cm,anchor=center}]\n\n"; -+ char const* latex_footer = -+ "\\end{tikzpicture}\n" -+ "\\end{document}\n"; -+ -+ char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}", -+ "{rgb,255:red,175;green,255;blue,175}", -+ "{rgb,255:red,255;green,255;blue,175}", -+ "{rgb,255:red,255;green,175;blue,175}", -+ "{rgb,255:red,210;green,210;blue,255}", -+ "{rgb,255:red,210;green,255;blue,210}", -+ "{rgb,255:red,255;green,255;blue,210}", -+ "{rgb,255:red,255;green,210;blue,210}",}; -+ -+ // Header -+ printf("%% LayoutS: "); print(S); printf("\n"); -+ printf("%% ThrIDS : "); print(TS); printf("\n"); -+ printf("%% LayoutD: "); print(D); printf("\n"); -+ printf("%% ThrIDD : "); print(TD); printf("\n\n"); -+ -+ printf(latex_header); -+ -+ // S starting at 0,0 -+ for (int i = 0; i < size<0>(S); ++i) { -+ for (int j = 0; j < size<1>(S); ++j) { -+ int thrid = S(i,j) % size(TS); -+ int val_idx = S(i,j) / size(TS); -+ int thr_idx = TS(thrid); -+ -+ printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", -+ color_map[thr_idx % 8], -+ i, j, -+ thr_idx, val_idx); -+ } -+ } -+ -+ // D starting at 0,size<1>(S)+3 -+ for (int i = 0; i < size<0>(D); ++i) { -+ for (int j = 0; j < size<1>(D); ++j) { -+ int thrid = D(i,j) % size(TD); -+ int val_idx = D(i,j) / size(TD); -+ int thr_idx = TD(thrid); -+ -+ printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", -+ color_map[thr_idx % 8], -+ i, j + size<1>(S) + 3, -+ thr_idx, val_idx); -+ } -+ } -+ -+ // S Labels -+ for (int i = 0, j = -1; i < size<0>(S); ++i) { -+ printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i); -+ } -+ for (int j = 0, i = -1; j < size<1>(S); ++j) { -+ printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, j); -+ } -+ // D Labels -+ for (int i = 0, j = size<1>(D); i < size<0>(S); ++i) { -+ printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j + size<1>(S) + 3, i); -+ } -+ for (int j = 0, i = -1; j < size<1>(D); ++j) { -+ printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j + size<1>(S) + 3, j); -+ } -+ -+ // Footer -+ printf(latex_footer); -+} -+ -+} // end namespace cute -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include -+#include -+#include -+#include -+// Config -+#if (__CUDACC_VER_MAJOR__ >= 12) -+# define CUTE_COPY_ATOM_TMA_SM90_ENABLED -+#endif -+ -+#if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED) -+#include -+#endif -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cute/atom/copy_traits.hpp b/3rdparty/cutlass/include/cute/atom/copy_traits.hpp -new file mode 100644 -index 0000000..83cb056 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/atom/copy_traits.hpp -@@ -0,0 +1,76 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+namespace cute -+{ -+ -+template -+struct Copy_Traits -+{ -+ static_assert(sizeof(CopyOperation) == 0, "Copy_Traits not implemented for this Copy_Operation."); -+}; -+ -+template -+struct Copy_Traits> -+{ -+ // Logical thread id to thread idx (one-thread) -+ using ThrID = Layout<_1>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = Layout::value>>>; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = Layout::value>>>; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = SrcLayout; -+}; -+ -+template <> -+struct Copy_Traits -+{ -+ // Logical thread id to thread idx (one-thread) -+ using ThrID = Layout<_1>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = Layout, Stride<_0,_0>>; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = Layout, Stride<_0,_0>>; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = SrcLayout; -+}; -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/atom/copy_traits_sm75.hpp b/3rdparty/cutlass/include/cute/atom/copy_traits_sm75.hpp -new file mode 100644 -index 0000000..13eb166 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/atom/copy_traits_sm75.hpp -@@ -0,0 +1,143 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+#include -+ -+#include -+ -+namespace cute -+{ -+ -+template <> -+struct Copy_Traits -+{ -+ // Logical thread id to thread idx (warp) -+ using ThrID = Layout<_32>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = Layout,_128>, -+ Stride, _1>>; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = Layout, -+ Stride<_32, _1>>; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = DstLayout; -+}; -+ -+template <> -+struct Copy_Traits -+{ -+ // Logical thread id to thread idx (warp) -+ using ThrID = Layout<_32>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = Layout,_128>, -+ Stride, _1>>; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = Layout>, -+ Stride<_32,Stride< _1,_1024>>>; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = DstLayout; -+}; -+ -+template <> -+struct Copy_Traits -+{ -+ // Logical thread id to thread idx (warp) -+ using ThrID = Layout<_32>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = Layout, -+ Stride<_128, _1>>; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = Layout>, -+ Stride<_32,Stride< _1,_1024>>>; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = DstLayout; -+}; -+ -+template <> -+struct Copy_Traits -+{ -+ // Logical thread id to thread idx (warp) -+ using ThrID = Layout<_32>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = Layout,_128>, -+ Stride, _1>>; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = Layout,Shape <_16, _2>>, -+ Stride,Stride< _1,_128>>>; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = DstLayout; -+}; -+ -+template <> -+struct Copy_Traits -+{ -+ // Logical thread id to thread idx (warp) -+ using ThrID = Layout<_32>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = Layout,_128>, -+ Stride, _1>>; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = Layout,Shape <_16, _2, _2>>, -+ Stride,Stride< _1,_128,_1024>>>; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = DstLayout; -+}; -+ -+template <> -+struct Copy_Traits -+{ -+ // Logical thread id to thread idx (warp) -+ using ThrID = Layout<_32>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = Layout, -+ Stride<_128, _1>>; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = Layout,Shape <_16, _2, _4>>, -+ Stride,Stride< _1,_128,_1024>>>; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = DstLayout; -+}; -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/atom/copy_traits_sm80.hpp b/3rdparty/cutlass/include/cute/atom/copy_traits_sm80.hpp -new file mode 100644 -index 0000000..089d193 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/atom/copy_traits_sm80.hpp -@@ -0,0 +1,98 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+#include -+ -+#include -+ -+namespace cute -+{ -+ -+template -+struct Copy_Traits> -+{ -+ // Logical thread id to thread idx (one-thread) -+ using ThrID = Layout<_1>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = Layout::value>>>; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = Layout::value>>>; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = SrcLayout; -+}; -+ -+template -+struct Copy_Traits> -+{ -+ // Logical thread id to thread idx (one-thread) -+ using ThrID = Layout<_1>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = Layout::value>>>; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = Layout::value>>>; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = SrcLayout; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Element copy selector -+template -+CUTE_HOST_DEVICE constexpr -+auto -+select_elementwise_copy(SrcTensor const&, DstTensor const&) -+{ -+ using SrcType = typename SrcTensor::value_type; -+ using DstType = typename DstTensor::value_type; -+ -+#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) -+ if constexpr (is_gmem::value && is_smem::value && -+ sizeof(SrcType) == sizeof(DstType) && -+ (sizeof(SrcType) == 4 || sizeof(SrcType) == 8 || sizeof(SrcType) == 16)) -+ { -+ return SM80_CP_ASYNC_CACHEALWAYS{}; -+ } else { -+ return UniversalCopy{}; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+#else -+ return UniversalCopy{}; -+#endif -+} -+ -+} -diff --git a/3rdparty/cutlass/include/cute/atom/copy_traits_sm90.hpp b/3rdparty/cutlass/include/cute/atom/copy_traits_sm90.hpp -new file mode 100644 -index 0000000..8c5e843 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/atom/copy_traits_sm90.hpp -@@ -0,0 +1,132 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include -+ -+namespace cute -+{ -+ -+template <> -+struct Copy_Traits -+{ -+ // Logical thread id to thread idx (warp) -+ using ThrID = Layout<_32>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = typename Copy_Traits::DstLayout; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = typename Copy_Traits::SrcLayout; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = SrcLayout; -+}; -+ -+template <> -+struct Copy_Traits -+{ -+ // Logical thread id to thread idx (warp) -+ using ThrID = Layout<_32>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = typename Copy_Traits::DstLayout; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = typename Copy_Traits::SrcLayout; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = SrcLayout; -+}; -+ -+template <> -+struct Copy_Traits -+{ -+ // Logical thread id to thread idx (warp) -+ using ThrID = Layout<_32>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = typename Copy_Traits::DstLayout; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = typename Copy_Traits::SrcLayout; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = SrcLayout; -+}; -+ -+template <> -+struct Copy_Traits -+{ -+ // Logical thread id to thread idx (warp) -+ using ThrID = Layout<_32>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = typename Copy_Traits::DstLayout; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = typename Copy_Traits::SrcLayout; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = SrcLayout; -+}; -+ -+template <> -+struct Copy_Traits -+{ -+ // Logical thread id to thread idx (warp) -+ using ThrID = Layout<_32>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = typename Copy_Traits::DstLayout; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = typename Copy_Traits::SrcLayout; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = SrcLayout; -+}; -+ -+template <> -+struct Copy_Traits -+{ -+ // Logical thread id to thread idx (warp) -+ using ThrID = Layout<_32>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = typename Copy_Traits::DstLayout; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = typename Copy_Traits::SrcLayout; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = SrcLayout; -+}; -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/atom/copy_traits_sm90_tma.hpp b/3rdparty/cutlass/include/cute/atom/copy_traits_sm90_tma.hpp -new file mode 100644 -index 0000000..18e22bf ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/atom/copy_traits_sm90_tma.hpp -@@ -0,0 +1,795 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+#include -+ -+#include -+ -+namespace cute -+{ -+ -+////////////////////////////////////////////////////////////////////////////// -+///////////////////////////// TMA_LOAD /////////////////////////////////////// -+////////////////////////////////////////////////////////////////////////////// -+ -+struct SM90_TMA_LOAD_OP : SM90_TMA_LOAD {}; -+ -+// The executable SM90_TMA_LOAD with tma_desc and tma_mbar -+template -+struct Copy_Traits -+{ -+ using ThrID = Layout<_1>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = Layout>; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = Layout>; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = SrcLayout; -+ -+ // SM90_TMA_LOAD arguments -+ TmaDescriptor const& tma_desc_; -+ uint64_t& tma_load_mbar_; -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ void -+ copy_unpack_(void const* const dst_ptr, -+ Coord const& src_coord, seq) const -+ { -+#if 0 -+ print("THR (%d,%d,%d) BLK (%d,%d,%d)\n", -+ threadIdx.x, threadIdx.y, threadIdx.z, -+ blockIdx.x, blockIdx.y, blockIdx.z); -+ print(" TMA Coord "); print(src_coord); print("\n"); -+ print(" TMA Shape "); print(make_tuple(uint64_t(tma_desc_.size0_), -+ uint64_t(tma_desc_.size1_), -+ uint64_t(tma_desc_.size2_), -+ uint64_t(tma_desc_.size3_))); print("\n"); -+#endif -+ -+ SM90_TMA_LOAD::copy(&tma_desc_, -+ tma_load_mbar_, -+ dst_ptr, -+ get(src_coord)...); -+ } -+ -+ // This is the copy_unpack dispatch for this Copy_Traits -+ // Src needs to be a gmem tensor with TmaCoordIterator .data() -+ // Dst needs to be a smem tensor -+ template -+ CUTE_HOST_DEVICE friend constexpr -+ void -+ copy_unpack(Copy_Traits const& traits, -+ Tensor const& src, -+ Tensor & dst) -+ { -+ //static_assert(is_gmem::value, "Expected gmem src for SM90_TMA_LOAD"); // TMA spoofed src tensor -+ static_assert(is_smem::value, "Expected smem dst for SM90_TMA_LOAD"); -+ -+ traits.copy_unpack_(dst.data().get(), src.data().coord_, tuple_seq{}); -+ } -+}; -+ -+// The non-executable SM90_TMA_LOAD with tma_desc and no tma_mbar -+// Use .with(tma_mbar) to construct an executable version -+template -+struct Copy_Traits -+{ -+ using ThrID = Layout<_1>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = Layout>; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = Layout>; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = SrcLayout; -+ -+ // SM90_TMA_LOAD arguments -+ TmaDescriptor tma_desc_; -+ GmemStrides g_stride_; -+ -+ // Return TmaDescriptor/TensorMap -+ CUTE_HOST_DEVICE constexpr -+ TmaDescriptor const* -+ get_tma_descriptor() const { -+ return &tma_desc_; -+ } -+ -+ // Construct an executable SM90_TMA_LOAD with tma_mbar -+ CUTE_HOST_DEVICE constexpr -+ Copy_Traits -+ with(uint64_t& tma_mbar, uint16_t const& multicast_mask = 0) const { -+ // We accept multicast_mask here to keep the API for both atoms consistent -+ // assert(multicast_mask == 0); -+ (void) multicast_mask; -+ return {tma_desc_, tma_mbar}; -+ } -+ -+ // Generate the TMA coord tensor -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ get_tma_tensor(GShape const& g_shape) const { -+ static_assert(is_congruent::value); -+ constexpr int tma_rank = decltype(cute::min(rank(flatten(g_stride_)), Int<5>{}))::value; -+ return make_tensor(ArithmeticTupleIterator(as_arithmetic_tuple(repeat(Int<0>{}))), -+ g_shape, -+ g_stride_); -+ } -+ -+ // Don't try to execute a copy with SM90_TMA_LOAD before calling .with() -+ template -+ CUTE_HOST_DEVICE friend constexpr void -+ copy_unpack(Copy_Traits const& traits, -+ Tensor const& src, -+ Tensor & dst) = delete; -+}; -+ -+////////////////////////////////////////////////////////////////////////////// -+///////////////////////////// TMA_LOAD_MULTICAST ///////////////////////////// -+////////////////////////////////////////////////////////////////////////////// -+ -+struct SM90_TMA_LOAD_MULTICAST_OP : SM90_TMA_LOAD_MULTICAST {}; -+ -+template -+struct Copy_Traits -+{ -+ using ThrID = Layout<_1>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = Layout>; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = Layout>; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = SrcLayout; -+ -+ // SM90_TMA_LOAD_MULTICAST arguments -+ TmaDescriptor const& tma_desc_; -+ uint64_t& tma_load_mbar_; -+ uint16_t const& multicast_mask_; -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ void -+ copy_unpack_(void const* const dst_ptr, -+ Coord const& src_coord, seq) const -+ { -+#if 0 -+ print("THR (%d,%d,%d) BLK (%d,%d,%d)\n", -+ threadIdx.x, threadIdx.y, threadIdx.z, -+ blockIdx.x, blockIdx.y, blockIdx.z); -+ print(" TMA Coord "); print(src_coord); print("\n"); -+ print(" TMA Shape "); print(make_tuple(uint64_t(tma_desc_.size0_), -+ uint64_t(tma_desc_.size1_), -+ uint64_t(tma_desc_.size2_), -+ uint64_t(tma_desc_.size3_))); print("\n"); -+#endif -+ -+ SM90_TMA_LOAD_MULTICAST::copy(&tma_desc_, -+ tma_load_mbar_, -+ multicast_mask_, -+ dst_ptr, -+ get(src_coord)...); -+ } -+ -+ template -+ CUTE_HOST_DEVICE friend constexpr -+ void -+ copy_unpack(Copy_Traits const& traits, -+ Tensor const& src, -+ Tensor & dst) -+ { -+ //static_assert(is_gmem::value, "Expected gmem src for SM90_TMA_LOAD"); // TMA spoofed src tensor -+ static_assert(is_smem::value, "Expected smem dst for SM90_TMA_LOAD_MULTICAST"); -+ -+ traits.copy_unpack_(dst.data().get(), src.data().coord_, tuple_seq{}); -+ } -+}; -+ -+template -+struct Copy_Traits -+{ -+ using ThrID = Layout<_1>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = Layout>; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = Layout>; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = SrcLayout; -+ -+ // SM90_TMA_LOAD_MULTICAST arguments -+ TmaDescriptor tma_desc_; -+ GmemStrides g_stride_; -+ -+ // Return TmaDescriptor/TensorMap -+ CUTE_HOST_DEVICE constexpr -+ TmaDescriptor const* -+ get_tma_descriptor() const { -+ return &tma_desc_; -+ } -+ -+ // Construct an executable SM90_TMA_LOAD_MULTICAST with tma_mbar -+ CUTE_HOST_DEVICE constexpr -+ Copy_Traits -+ with(uint64_t& tma_load_mbar, uint16_t const& multicast_mask) const { -+ return {tma_desc_, tma_load_mbar, multicast_mask}; -+ } -+ -+ // Generate the TMA coord tensor -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ get_tma_tensor(GShape const& g_shape) const { -+ static_assert(is_congruent::value); -+ constexpr int tma_rank = decltype(cute::min(rank(flatten(g_stride_)), Int<5>{}))::value; -+ return make_tensor(ArithmeticTupleIterator(as_arithmetic_tuple(repeat(Int<0>{}))), -+ g_shape, -+ g_stride_); -+ } -+ -+ // Don't try to execute a copy with SM90_TMA_LOAD_MULTICAST before calling .with() -+ template -+ CUTE_HOST_DEVICE friend constexpr void -+ copy_unpack(Copy_Traits const& traits, -+ Tensor const& src, -+ Tensor & dst) = delete; -+}; -+ -+////////////////////////////////////////////////////////////////////////////// -+///////////////////////////// TMA_STORE ////////////////////////////////////// -+////////////////////////////////////////////////////////////////////////////// -+ -+// The executable SM90_TMA_STORE with tma_desc -+template -+struct Copy_Traits -+{ -+ using ThrID = Layout<_1>; -+ -+ // Map from (src-thr,src-val) to bit -+ using SrcLayout = Layout>; -+ // Map from (dst-thr,dst-val) to bit -+ using DstLayout = Layout>; -+ -+ // Reference map from (thr,val) to bit -+ using RefLayout = SrcLayout; -+ -+ // SM90_TMA_STORE arguments -+ TmaDescriptor tma_desc_; -+ GmemStrides g_stride_; -+ -+ // Generate the TMA coord tensor -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ get_tma_tensor(GShape const& g_shape) const { -+ static_assert(is_congruent::value); -+ constexpr int tma_rank = decltype(cute::min(rank(flatten(g_stride_)), Int<5>{}))::value; -+ return make_tensor(ArithmeticTupleIterator(as_arithmetic_tuple(repeat(Int<0>{}))), -+ g_shape, -+ g_stride_); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ void -+ copy_unpack_(void const* const src_ptr, -+ Coord const& dst_coord, seq) const -+ { -+#if 0 -+ print("THR (%d,%d,%d) BLK (%d,%d,%d)\n", -+ threadIdx.x, threadIdx.y, threadIdx.z, -+ blockIdx.x, blockIdx.y, blockIdx.z); -+ print(" TMA Coord "); print(dst_coord); print("\n"); -+ print(" TMA Shape "); print(make_tuple(uint64_t(tma_desc_.size0_), -+ uint64_t(tma_desc_.size1_), -+ uint64_t(tma_desc_.size2_), -+ uint64_t(tma_desc_.size3_))); print("\n"); -+#endif -+ -+ SM90_TMA_STORE::copy(&tma_desc_, -+ src_ptr, -+ get(dst_coord)...); -+ } -+ -+ // This is the copy_unpack dispatch for this Copy_Traits -+ // Src needs to be a smem tensor -+ // Dst needs to be a gmem tensor with TmaCoordIterator .data() -+ template -+ CUTE_HOST_DEVICE friend constexpr -+ void -+ copy_unpack(Copy_Traits const& traits, -+ Tensor const& src, -+ Tensor & dst) -+ { -+ static_assert(is_smem::value, "Expected smem src for SM90_TMA_STORE"); -+ //static_assert(is_gmem::value, "Expected gmem dst for SM90_TMA_STORE"); // TMA spoofed src tensor -+ -+ traits.copy_unpack_(src.data().get(), dst.data().coord_, tuple_seq{}); -+ } -+}; -+ -+// -+// MAKE_TMA_COPY and related -+// -+ -+template -+TMA::SmemSwizzleBits -+get_tma_swizzle_bits(ComposedLayout,Offset,SLayout>) -+{ -+ static_assert(M == 4, "Expected 128b=16B=(2^4)B base swizzle."); -+ static_assert(S == 3, "Unsupported layout swizzle"); -+ -+ switch (B) { -+ default: static_assert(0 <= B && B <= 3, "Expected B = 0,1,2, or 3. Unsupported layout swizzle."); -+ case 3: return TMA::SmemSwizzleBits::B128; -+ case 2: return TMA::SmemSwizzleBits::B64; -+ case 1: return TMA::SmemSwizzleBits::B32; -+ case 0: return TMA::SmemSwizzleBits::DISABLE; -+ } -+} -+ -+template -+TMA::SmemSwizzleBits -+get_tma_swizzle_bits(Layout) -+{ -+ return TMA::SmemSwizzleBits::DISABLE; -+} -+ -+template -+auto -+get_nonswizzle_layout(ComposedLayout,Offset,SLayout> const& slayout) -+{ -+ return slayout.layout_fn(); -+} -+ -+template -+auto -+get_nonswizzle_layout(Layout const& slayout) -+{ -+ return slayout; -+} -+ -+/** Make a CuTe CTA-collective TiledCopy for a TMA operation. -+ * -+ * @param CopyOp The target copy operation: SM90_TMA_LOAD, SM90_TMA_LOAD_MULTICAST, SM90_TMA_STORE -+ * @param gtensor The GMEM Tensor to be involved in the TMA. -+ * @param slayout The SMEM Layout to be involved in the TMA. -+ * @param cta_tile The CTA-local tile that each CTA will be tiling GMEM with. -+ * This is often the blk_shape that is used to tile the GMEM for CTAs: -+ * local_tile(gtensor, blk_shape, blk_coord) -> CTA-local tile of gtensor -+ * @param cluster_size When using SM90_TMA_LOAD_MULTICAST, this can be a (static) power-of-2 <= 16 -+ * defining the multicast size (used to further partition the SMEM) -+ * Else, static-1 -+ * -+ * This code attempts to maximize the TMA box size. It does this by tracing -+ * the SMEM "vector" -- the inverse of the smem layout -- to find the largest -+ * contiguous array of smem that can be written to/from global memory given -+ * the constraints that the TMA instruction imposes. -+ * -+ * This is accomplished by assigning "basis" strides to the GMEM to track which -+ * modes of SMEM map to which modes of GMEM, then reorder the modes of GMEM according -+ * to the SMEM vector, and then using those GMEM/SMEM modes to fill in the desc. -+ * -+ * Examples: -+ using T = float; -+ T* gptr = nullptr; -+ -+ { -+ // Simple 2D -+ Tensor gtensor = make_tensor(gptr, make_shape(1024, 256), GenRowMajor{}); // K-Major GMEM -+ auto slayout = make_layout(make_shape(_64{}, _32{}), GenRowMajor{}); // K-Major SMEM -+ auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout); -+ } -+ -+ { -+ // GMMA 2D -+ Tensor gtensor = make_tensor(gptr, make_shape(1024, 256)); // MN-Major GMEM -+ auto slayout = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, make_shape(_128{},_64{})); // MN-Major Swizzled+Tiled 128x64 SMEM -+ auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout); -+ } -+ -+ { -+ // 3D -+ Tensor gtensor = make_tensor(gptr, make_shape(1024, 32, 512), make_stride(64, Int<1>{}, 65536)); // GMEM -+ auto slayout = make_layout(make_shape(_16{}, _8{}, _2{}), make_stride(_16{}, _1{}, _8{})); // SMEM w/ same major-mode -+ auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout); -+ } -+ -+ { -+ // cuTENSOR 4D -+ auto layout = make_shape(make_shape(32,40),make_shape(make_shape(8,8),656)); // GMEM -+ auto cta_tile = make_shape(_128{},make_shape(_32{},_2{})); // GMEM Tiling: -+ // Take 128-elem from m: m0 must divide 128, -+ // m-last may be predicated -+ // Take 32-elem from k0, 2-elem from k1 -+ auto slayout = make_layout(cta_tile); // Col-Major SMEM -+ auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout, cta_tile, Int<1>{}); -+ } -+ * -+ * Check the TMA box size and desc: -+ print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); -+ print("TMA desc : "); print(tma.tma_desc_); print("\n"); -+ * -+ * Usage: -+ Tensor mA = tma_a.get_tma_tensor(make_shape(M,N)); // (M,N) TMA coord tensor -+ Tensor gA = local_tile(mA, cta_tile, cta_coord); // (BLK_M,BLK_N) TMA coord tensor for this CTA -+ Tensor sA = make_tensor(make_smem_ptr(sptr), slayout); // (BLK_M,BLK_N) SMEM tensor -+ -+ auto cta_tma = tma.get_slice(cta_idx_in_cluster); // Slice for multicast partitioning -+ Tensor tAgA = cta_tma.partition_S(gA); // Partition for src -+ Tensor tAsA = cta_tma.partition_D(sA); // Partition for dst -+ -+ copy(tma.with(barrier, mcast_mask), tAgA, tAsA); // copy with supporting TMA params -+ */ -+template -+CUTE_HOST -+auto -+make_tma_copy(CopyOp, -+ Tensor const& gtensor, -+ SLayout const& slayout, -+ CTA_Tile const& cta_tile, -+ Cluster_Size const& cluster_size) -+{ -+ static_assert((std::is_same::value && is_constant<1, Cluster_Size>::value) || -+ (std::is_same::value) || -+ (std::is_same::value && is_constant<1, Cluster_Size>::value)); -+ -+ using T = typename Tensor::value_type; -+ -+ // -+ // TMA parameter checking -+ // -+ -+ auto flat_glayout = flatten(gtensor.layout()); -+ -+ CUTE_STATIC_ASSERT_V(rank(flatten(cta_tile)) <= Int<5>{}, -+ "CTA_Tile cannot have more than five modes, TMA arch restriction."); -+ CUTE_STATIC_ASSERT_V(rank(flat_glayout) <= Int<5>{} || rank(flatten(cta_tile)) <= Int<4>{}, -+ "If GTensor has more than five modes, then CTA_Tile cannot have more than four modes. TMA multimode."); -+ CUTE_STATIC_ASSERT_V(compatible(product_each(shape(slayout)), shape(cta_tile)), -+ "CTA_Tile must be compatible with SLayout."); -+ CUTE_STATIC_ASSERT_V(is_integral{} && has_single_bit(cluster_size) && cluster_size <= Int<16>{}, -+ "Expecting a pow2 integral Cluster_Size leq 16."); -+ CUTE_STATIC_ASSERT_V(size(slayout) % cluster_size == Int<0>{}, -+ "ClusterShape must divide domain size of slayout."); -+ -+ // -+ // TMA slayout manipulation -+ // -+ -+ auto tma_multimode = rank(flat_glayout) > Int<5>{}; -+ -+ // Invert the smem to get the largest contiguous vector in the smem layout -+ auto inv_smem_layout = right_inverse(get_nonswizzle_layout(slayout)); -+ // trunc_smem_idx -> trunc_smem_coord -+ -+ // Map from smem idx to a gmem mode -+ auto sidx_to_gmode = flatten(composition(make_identity_layout(cta_tile), inv_smem_layout)); -+ -+ // Truncate any incompatibilities -+ auto smem_rank = find_if(stride(sidx_to_gmode), [](auto e){ -+ [[maybe_unused]] auto v = basis_value(e); -+ return not is_constant<1,decltype(v)>{}; -+ }); -+ static_assert(smem_rank > 0, "Could not find a common smem-gmem vectorization for TMA."); -+ constexpr int smem_tma_rank = cute::min(int(smem_rank), (tma_multimode ? 4 : 5)); -+ -+ // Keep only the static-1 basis modes into gmem -+ auto sidx_to_gmode_cluster_trunc = take<0,smem_tma_rank>(sidx_to_gmode); -+ // Keep only the portion each multicast CTA will be responsible for -+ auto sidx_to_gmode_cta_trunc = composition(sidx_to_gmode_cluster_trunc, shape_div(size(sidx_to_gmode_cluster_trunc), cluster_size)); -+ -+ // -+ // TMA gtensor manipulation -+ // -+ -+ // Generate a TupleBasis for the gtensor -+ auto flat_gbasis = make_basis_like(shape(flat_glayout)); -+ -+ // Fold the flat_gbasis into the glayout -+ auto glayout_basis = make_layout(shape(gtensor), -+ stride(composition(make_layout(repeat_like(shape(flat_glayout), Int<2>{}), flat_gbasis), -+ make_layout(repeat_like(shape(gtensor), Int<2>{}))))); -+ -+ // Tile the modes of gtensor with cta_tile -+ auto cta_glayout_basis = composition(glayout_basis, cta_tile); -+ -+ // Check that the cta_tile selects modes from gtensor properly -+ for_each(flatten(stride(cta_glayout_basis)), [](auto d) { -+ static_assert(is_constant<1, decltype(d.value())>::value, -+ "CTA_Tile does not faithfully partition the GMEM, it should select the number of elements from each mode of glayout."); -+ }); -+ -+ // Tile the modes of gtensor again with the truncated cta_tile o inv_smem_layout -+ auto tma_layout_cta_trunc = flatten(composition(glayout_basis, sidx_to_gmode_cta_trunc)); -+ -+ // Append any missing basis on the end as size-1 modes b/c they got truncated -+ auto missing_basis = fold(stride(tma_layout_cta_trunc), flat_gbasis, [](auto init, auto e){ -+ auto k = find(init, e); -+ return remove(init); -+ }); -+ -+ // The appended map from truncated smem codomain to gmem mode: trunc_smem_idx -> gmem_mode -+ auto tma_layout_cta = flatten(make_layout(tma_layout_cta_trunc, -+ make_layout(repeat(Int<1>{}), missing_basis))); -+ -+#if 0 -+ print("g_layout : "); print(gtensor.layout()); print("\n"); -+ print("s_layout : "); print(slayout); print("\n"); -+ print("cta_tile : "); print(cta_tile); print("\n"); -+ print("cluster_size : "); print(cluster_size); print("\n"); -+ print("flat_gbasis : "); print(flat_gbasis); print("\n"); -+ print("cta_glayout : "); print(cta_glayout_basis); print("\n"); -+ print("inv_smem : "); print(inv_smem_layout); print("\n"); -+ print("sidx_to_gmode : "); print(sidx_to_gmode); print("\n"); -+ print("missing_b : "); print(missing_basis); print("\n"); -+ print("tma_layout_cta: "); print(tma_layout_cta); print("\n"); -+#endif -+ -+ // -+ // TMA gmem desc info -+ // -+ -+ constexpr int TmaRANK = cute::min(rank(flat_glayout), 5); -+ void* gmem_address = (void*) gtensor.data(); -+ -+ cute::array gmem_prob_shape = {1,1,1,1,1}; -+ cute::array gmem_prob_stride = {0,0,0,0,0}; -+ for_each(make_seq{}, [&](auto i) { -+ // NOTE : WAR g++-7.3.5, let it deduce e rather than fuse with below -+ auto e = stride(tma_layout_cta); -+ constexpr int j = decltype(e.mode())::value; -+ constexpr int tma_i = i < 5 ? i : 4; -+ -+ // Problem stride -+ uint64_t stride_j = stride(flat_glayout) * sizeof(T); -+ uint64_t old_stride = gmem_prob_stride[tma_i]; -+ gmem_prob_stride[tma_i] = gcd(gmem_prob_stride[tma_i], stride_j); -+ -+ // Problem shape -+ uint64_t shape_j = shape(flat_glayout); -+ if (gmem_prob_stride[tma_i] != 0) { -+ // We're "resetting" this TMA mode and using it as a "multimode" -+ // Recurrence: g_shape = (s_i - 1) * (d_i / gcd_j d_j) + 1 -+ gmem_prob_shape[tma_i] = (gmem_prob_shape[tma_i]-1) * (old_stride / gmem_prob_stride[tma_i]) -+ + (shape_j-1) * (stride_j / gmem_prob_stride[tma_i]) -+ + 1; -+ } else { -+ gmem_prob_shape[tma_i] = shape_j; -+ } -+ }); -+ -+ assert((reinterpret_cast(gmem_address) & 0b1111) == 0); // Address must be 16B-aligned -+ -+ assert(gmem_prob_shape[0] >= (uint64_t(1))); // Size must be min 1 -+ assert(gmem_prob_shape[0] <= (uint64_t(1) << 32)); // Size must be max 2^32 -+ assert(gmem_prob_shape[1] >= (uint64_t(1))); // Size must be min 1 -+ assert(gmem_prob_shape[1] <= (uint64_t(1) << 32)); // Size must be max 2^32 -+ assert(gmem_prob_shape[2] >= (uint64_t(1))); // Size must be min 1 -+ assert(gmem_prob_shape[2] <= (uint64_t(1) << 32)); // Size must be max 2^32 -+ assert(gmem_prob_shape[3] >= (uint64_t(1))); // Size must be min 1 -+ assert(gmem_prob_shape[3] <= (uint64_t(1) << 32)); // Size must be max 2^32 -+ assert(gmem_prob_shape[4] >= (uint64_t(1))); // Size must be min 1 -+ assert(gmem_prob_shape[4] <= (uint64_t(1) << 32)); // Size must be max 2^32 -+ -+ assert((gmem_prob_stride[0]) == sizeof(T)); // First stride is implicitly 1 -+ assert((gmem_prob_stride[1]) < (uint64_t(1) << 40)); // Stride must be max 2^40 -+ assert((gmem_prob_stride[1] & 0b1111) == 0); // Stride must be multiple of 16B (128b) -+ assert((gmem_prob_stride[2]) < (uint64_t(1) << 40)); // Stride must be max 2^40 -+ assert((gmem_prob_stride[2] & 0b1111) == 0); // Stride must be multiple of 16B (128b) -+ assert((gmem_prob_stride[3]) < (uint64_t(1) << 40)); // Stride must be max 2^40 -+ assert((gmem_prob_stride[3] & 0b1111) == 0); // Stride must be multiple of 16B (128b) -+ assert((gmem_prob_stride[4]) < (uint64_t(1) << 40)); // Stride must be max 2^40 -+ assert((gmem_prob_stride[4] & 0b1111) == 0); // Stride must be multiple of 16B (128b) -+ -+ // -+ // TMA smem desc info -+ // -+ -+ // TMA smem box size -+ cute::array smem_box_shape = {1,1,1,1,1}; -+ for_each(make_seq{}, [&](auto i) { -+ uint32_t shape_i = shape(tma_layout_cta); -+ constexpr int tma_i = i < 5 ? i : 4; -+ if (tma_multimode && tma_i == 4) { -+ // We're "reusing" this TMA mode and using it as a "multimode" -+ smem_box_shape[tma_i] = 1; -+ } else { -+ smem_box_shape[tma_i] = shape_i; -+ } -+ }); -+ -+ // TMA smem mode strides -+ [[maybe_unused]] cute::array smem_box_stride = {1,1,1,1,1}; -+ -+ assert(smem_box_shape[0] >= (uint64_t(1))); // Size must be min 1 -+ assert(smem_box_shape[0] <= (uint64_t(1) << 8)); // Size must be max 2^8 -+ assert(smem_box_shape[0] >= (uint64_t(1))); // Size must be min 1 -+ assert(smem_box_shape[0] <= (uint64_t(1) << 8)); // Size must be max 2^8 -+ assert(smem_box_shape[0] >= (uint64_t(1))); // Size must be min 1 -+ assert(smem_box_shape[0] <= (uint64_t(1) << 8)); // Size must be max 2^8 -+ assert(smem_box_shape[0] >= (uint64_t(1))); // Size must be min 1 -+ assert(smem_box_shape[0] <= (uint64_t(1) << 8)); // Size must be max 2^8 -+ -+ assert(smem_box_stride[0] >= (uint32_t(1))); // Stride must be min 1 -+ assert(smem_box_stride[0] <= (uint32_t(8))); // Stride must be max 2^3 -+ assert(smem_box_stride[1] >= (uint32_t(1))); // Stride must be min 1 -+ assert(smem_box_stride[1] <= (uint32_t(8))); // Stride must be max 2^3 -+ assert(smem_box_stride[2] >= (uint32_t(1))); // Stride must be min 1 -+ assert(smem_box_stride[2] <= (uint32_t(8))); // Stride must be max 2^3 -+ assert(smem_box_stride[3] >= (uint32_t(1))); // Stride must be min 1 -+ assert(smem_box_stride[3] <= (uint32_t(8))); // Stride must be max 2^3 -+ assert(smem_box_stride[4] >= (uint32_t(1))); // Stride must be min 1 -+ assert(smem_box_stride[4] <= (uint32_t(8))); // Stride must be max 2^3 -+ -+ // -+ // Construct the descriptor -+ // -+ -+ TmaDescriptor tma_desc = {0}; -+ -+#if (__CUDACC_VER_MAJOR__ >= 12) -+ -+ // -+ // TMA general info -+ // -+ -+ cuuint32_t tma_dim = TmaRANK; -+ CUtensorMapDataType tma_format = TMA::to_CUtensorMapDataType(); -+ CUtensorMapInterleave tma_interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; -+ CUtensorMapL2promotion tma_l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_NONE; -+ CUtensorMapFloatOOBfill tma_oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; -+ -+ // TMA smem swizzle type -+ CUtensorMapSwizzle smem_swizzle = TMA::to_CUtensorMapSwizzle(get_tma_swizzle_bits(slayout)); -+ -+ CUresult result = cuTensorMapEncodeTiled( -+ &tma_desc, -+ tma_format, -+ tma_dim, -+ gmem_address, -+ gmem_prob_shape.data(), -+ gmem_prob_stride.data() + 1, // gmem_prob_stride[0] implicitly 1 -+ smem_box_shape.data(), -+ smem_box_stride.data(), -+ tma_interleave, -+ smem_swizzle, -+ tma_l2Promotion, -+ tma_oobFill); -+ -+ if (result != CUDA_SUCCESS) { -+ std::cerr << "TMA Desc Addr: " << &tma_desc -+ << "\nformat " << tma_format -+ << "\ndim " << tma_dim -+ << "\ngmem_address " << gmem_address -+ << "\nglobalDim " << gmem_prob_shape -+ << "\nglobalStrides " << gmem_prob_stride -+ << "\nboxDim " << smem_box_shape -+ << "\nelementStrides " << smem_box_stride -+ << "\ninterleave " << tma_interleave -+ << "\nswizzle " << smem_swizzle -+ << "\nl2Promotion " << tma_l2Promotion -+ << "\noobFill " << tma_oobFill << std::endl; -+ std::cerr << "Error: Failed to intialize the TMA descriptor " << result << std::endl; -+ assert(false); -+ } -+#endif // (__CUDACC_VER_MAJOR__ >= 12) -+ -+ // -+ // Construct the Copy_Traits -+ // -+ -+ // Finally, get the inverse permutation of the E bases for the mocked gmem stride -+ auto gmem_stride_bases_flat = transform(make_seq{}, [&](auto i) { -+ auto k = find(stride(tma_layout_cta), E{}); -+ // NOTE: gcc 7.3.5 WAR -- avoid if constexpr -+ int32_t tma_coord_stride = int32_t(stride(flat_glayout) * sizeof(T) / (gmem_prob_stride[4] != 0 ? gmem_prob_stride[4] : 16)); -+ return conditional_return(tma_multimode && (k >= Int<4>{}), -+ E<4>{} * tma_coord_stride, // The 4th TMA mode is the multimode, use int32_t coord stride -+ E{}); -+ }); -+ -+ // Give that the profile of gtensor and fold it -+ auto gmem_stride_bases = stride(composition(make_layout(repeat_like(shape(flat_glayout), Int<2>{}), gmem_stride_bases_flat), -+ make_layout(repeat_like(shape(gtensor), Int<2>{})))); -+ -+ constexpr int num_bits = size(sidx_to_gmode_cta_trunc) * sizeof(T) * 8; -+ using Traits = Copy_Traits, decltype(gmem_stride_bases)>; -+ -+#if 0 -+ print("num_bits : "); print(num_bits); print("\n"); -+ print("g_stride_bases: "); print(gmem_stride_bases); print("\n"); -+#endif -+ -+ // -+ // Construct the TiledCopy -+ // -+ -+ // The ThrVal layout for 1 TMA instruction within cta_tile -+ auto layout_tv_1 = composition(inv_smem_layout, make_layout(make_shape(cluster_size, size(sidx_to_gmode_cta_trunc)), GenRowMajor{})); -+ // The ThrVal layout for N TMA instructions within cta_tile -+ auto layout_tv = tile_to_shape(layout_tv_1, make_shape(cluster_size, size(cta_tile)/cluster_size)); -+ -+#if 0 -+ print("layout_tv : "); print(layout_tv); print("\n"); -+#endif -+ -+ return TiledCopy, decltype(layout_tv), decltype(cta_tile)>{tma_desc, gmem_stride_bases}; -+} -+ -+// Explicit defaulting -+template -+CUTE_HOST -+auto -+make_tma_copy(CopyOp const& copy_op, -+ Tensor const& gtensor, -+ SLayout const& slayout) -+{ -+ return make_tma_copy(copy_op, gtensor, slayout, product_each(shape(slayout)), Int<1>{}); -+} -+ -+template -+CUTE_HOST -+auto -+make_tma_copy(CopyOp const& copy_op, -+ Tensor const& gtensor, -+ SLayout const& slayout, -+ Cluster_Size const& cluster_size) -+{ -+ return make_tma_copy(copy_op, gtensor, slayout, product_each(shape(slayout)), cluster_size); -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/atom/mma_atom.hpp b/3rdparty/cutlass/include/cute/atom/mma_atom.hpp -new file mode 100644 -index 0000000..c3025f5 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/atom/mma_atom.hpp -@@ -0,0 +1,1081 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+#include -+#include -+#include -+#include -+ -+namespace cute { -+ -+// Generic mma_unpack for any MMA_Traits -+template -+CUTE_HOST_DEVICE constexpr -+void -+mma_unpack(MMA_Traits const&, -+ Tensor & D, -+ Tensor const& A, -+ Tensor const& B, -+ Tensor const& C) -+{ -+ static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); -+ static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); -+ static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); -+ static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); -+ -+ // Register value types from the MMA_Operation register arrays -+ using RegTypeD = typename std::remove_extent::type; -+ using RegTypeA = typename std::remove_extent::type; -+ using RegTypeB = typename std::remove_extent::type; -+ using RegTypeC = typename std::remove_extent::type; -+ constexpr int RegNumD = std::extent::value; -+ constexpr int RegNumA = std::extent::value; -+ constexpr int RegNumB = std::extent::value; -+ constexpr int RegNumC = std::extent::value; -+ -+ Tensor rA = recast(A); -+ Tensor rB = recast(B); -+ -+ CUTE_STATIC_ASSERT_V(size(rA) == Int{}); -+ CUTE_STATIC_ASSERT_V(size(rB) == Int{}); -+ -+ if constexpr (std::is_same::value) -+ { -+ static_assert(std::is_same::value, "GMMA C and D value_type must match."); -+ static_assert(std::is_same::value, "GMMA C and D layouts must match."); -+ // assert((void*)&C == (void*)&D); -+ -+ Tensor rC = recast(D); // NOTE: D and C are same, so use mutable D -+ -+ //CUTE_STATIC_ASSERT_V(size(rC) == Int{}); -+ -+ detail::explode(Operation::fma, -+ rA, make_int_sequence{}, -+ rB, make_int_sequence{}, -+ rC, make_int_sequence{}); -+ } else -+ { -+ Tensor rD = recast(D); -+ Tensor rC = recast(C); -+ -+ CUTE_STATIC_ASSERT_V(size(rD) == Int{}); -+ CUTE_STATIC_ASSERT_V(size(rC) == Int{}); -+ -+ detail::explode(Operation::fma, -+ rD, make_int_sequence{}, -+ rA, make_int_sequence{}, -+ rB, make_int_sequence{}, -+ rC, make_int_sequence{}); -+ } -+} -+ -+ -+namespace detail { -+ -+template -+struct FrgTypeA_or_Default { using type = typename X::ElementAVal; }; -+template -+struct FrgTypeA_or_Default> { using type = typename X::ElementAFrg; }; -+ -+template -+struct FrgTypeB_or_Default { using type = typename X::ElementBVal; }; -+template -+struct FrgTypeB_or_Default> { using type = typename X::ElementBFrg; }; -+ -+template -+struct FrgTypeC_or_Default { using type = typename X::ElementCVal; }; -+template -+struct FrgTypeC_or_Default> { using type = typename X::ElementCFrg; }; -+ -+} // end namespace detail -+ -+template -+struct MMA_Atom; -+ -+template -+struct MMA_Atom : MMA_Atom> -+{}; -+ -+template -+struct MMA_Atom> -+ : MMA_Traits -+{ -+ using Traits = MMA_Traits; -+ -+ // Element value types from the MMA_Traits -+ using ValTypeD = typename Traits::ElementDVal; -+ using ValTypeA = typename Traits::ElementAVal; -+ using ValTypeB = typename Traits::ElementBVal; -+ using ValTypeC = typename Traits::ElementCVal; -+ -+ // Thr-Val layouts from the MMA_Traits -+ using Shape_MNK = typename Traits::Shape_MNK; -+ using ThrID = typename Traits::ThrID; -+ using LayoutC_TV = typename Traits::CLayout; -+ using LayoutA_TV = typename Traits::ALayout; -+ using LayoutB_TV = typename Traits::BLayout; -+ -+ // Fragment value types from the MMA_Traits (optional, defaults to Val type) -+ using FrgTypeD = typename detail::FrgTypeC_or_Default::type; -+ using FrgTypeA = typename detail::FrgTypeA_or_Default::type; -+ using FrgTypeB = typename detail::FrgTypeB_or_Default::type; -+ using FrgTypeC = typename detail::FrgTypeC_or_Default::type; -+ -+ // Additional Trait parameters/transformations -+ template -+ CUTE_HOST_DEVICE -+ auto -+ with(TraitsArgs&&... args) const { -+ auto traits = Traits::with(std::forward(args)...); -+ return MMA_Atom{traits}; -+ } -+ -+ // Print thread and data layouts for debugging -+ CUTE_HOST_DEVICE static -+ void -+ print_all() -+ { -+ print("ThrID: "); print(ThrID{}); print("\n"); -+ print("LayoutA_TV: "); print(LayoutA_TV{}); print("\n"); -+ print("LayoutB_TV: "); print(LayoutB_TV{}); print("\n"); -+ print("LayoutC_TV: "); print(LayoutC_TV{}); print("\n"); -+ } -+ -+ // -+ // Tensor call interfaces -+ // -+ -+ // Cast, check, and call fma -+ template -+ CUTE_HOST_DEVICE constexpr -+ void -+ call(Tensor & D, -+ Tensor const& A, -+ Tensor const& B, -+ Tensor const& C) const -+ { -+ static_assert(DLayout::rank == 1, "Expected rank-1 D tensor"); -+ static_assert(ALayout::rank == 1, "Expected rank-1 A tensor"); -+ static_assert(BLayout::rank == 1, "Expected rank-1 B tensor"); -+ static_assert(CLayout::rank == 1, "Expected rank-1 C tensor"); -+ -+ return mma_unpack(*this, D, A, B, C); -+ } -+ -+ // Three arguments reproduces C -+ template -+ CUTE_HOST_DEVICE constexpr -+ void -+ call(Tensor const& A, -+ Tensor const& B, -+ Tensor & C) const -+ { -+ return call(C, A, B, C); -+ } -+ -+ // -+ // make_fragment_A|B|C -+ // These functions are awkward as they expect already-partitioned tensors -+ // resulting from a previous call to partition_A|B|C -+ // The reasoning is that we can inspect the layout of the partitioned data -+ // and attempt to match it in generated fragment to promote vectorization -+ // when copying from partition to fragment. -+ // -+ -+ template -+ CUTE_HOST_DEVICE static constexpr -+ auto -+ make_fragment_C(CTensor&& ctensor) -+ { -+ // Check that this tensor is likely already partitioned -+ CUTE_STATIC_ASSERT_V(rank(ctensor) >= Int<3>{}); // VMN -+ CUTE_STATIC_ASSERT_V(size<0>(ctensor) == size<1>(LayoutC_TV{})); -+ -+ // C is a bit special because we are after accumulators here -+ // The input/output type doesn't have to match the accumulator type -+ //static_assert(std::is_same::value_type>::value, "Expecting ValTypeC type"); -+ -+ // We'll never base the accumulator layout on the input tensor layout, so just return a FrgTypeC tensor -+ return make_tensor(shape(ctensor)); -+ } -+ -+ template -+ CUTE_HOST_DEVICE static constexpr -+ auto -+ make_fragment_A(ATensor&& atensor) -+ { -+ // Check that this tensor is likely already partitioned -+ CUTE_STATIC_ASSERT_V(rank(atensor) >= Int<3>{}); // VMK -+ CUTE_STATIC_ASSERT_V(size<0>(atensor) == size<1>(LayoutA_TV{})); -+ static_assert(std::is_same::value_type>::value, "Expecting ValTypeA type"); -+ -+ if constexpr (has_dereference::value) { -+ return recast(std::forward(atensor)); -+ } else { -+ return make_tensor(make_fragment_like(atensor.layout())); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+ } -+ -+ template -+ CUTE_HOST_DEVICE static constexpr -+ auto -+ make_fragment_B(BTensor&& btensor) -+ { -+ // Check that this tensor is likely already partitioned -+ CUTE_STATIC_ASSERT_V(rank(btensor) >= Int<3>{}); // VNK -+ CUTE_STATIC_ASSERT_V(size<0>(btensor) == size<1>(LayoutB_TV{})); -+ static_assert(std::is_same::value_type>::value, "Expecting ValTypeB type"); -+ -+ if constexpr (has_dereference::value) { -+ return recast(std::forward(btensor)); -+ } else { -+ return make_tensor(make_fragment_like(btensor.layout())); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+ } -+}; -+ -+// -+// A tiling of mma atoms -+// -+ -+template -+struct ThrMMA; -+ -+template >, -+ class ValLayoutMNK = Layout>, -+ class PermutationsMNK = Tile> -+struct TiledMMA : MMA_Atom -+{ -+ static_assert(rank_v == 3, "TiledMMA requires rank-3 AtomLayoutMNK"); -+ static_assert(rank_v == 3, "TiledMMA requires rank-3 ValLayoutMNK"); -+ static_assert(rank_v == 3, "TiledMMA requires rank-3 PermutationsMNK"); -+ -+ using AtomShape_MNK = typename MMA_Atom::Shape_MNK; -+ -+ using AtomLayoutC_TV = typename MMA_Atom::LayoutC_TV; -+ using AtomLayoutA_TV = typename MMA_Atom::LayoutA_TV; -+ using AtomLayoutB_TV = typename MMA_Atom::LayoutB_TV; -+ -+ // ThrV -> thread_idx -+ using AtomThrID = typename MMA_Atom::ThrID; -+ -+ // (M,N,K) -+ using TiledShape_MNK = decltype(make_shape(size<0>(AtomShape_MNK{})*size<0>(AtomLayoutMNK{})*size<0>(ValLayoutMNK{}), -+ size<1>(AtomShape_MNK{})*size<1>(AtomLayoutMNK{})*size<1>(ValLayoutMNK{}), -+ size<2>(AtomShape_MNK{})*size<2>(AtomLayoutMNK{})*size<2>(ValLayoutMNK{}))); -+ -+ // thrid = (ThrV,ThrM,ThrN,ThrK) -> thr_idx -+ using ThrLayoutVMNK = decltype(tiled_product(AtomThrID{}, AtomLayoutMNK{})); -+ -+ // thr_idx -> (ThrV,ThrM,ThrN,ThrK) -+ using TidLayout = decltype(right_inverse(ThrLayoutVMNK{})); -+ -+ // Tile a tensor or a layout from shape -+ // (M,N,...) -+ // to shape -+ // ((ThrV,(ThrM,ThrN)),(FrgV,(RestM,RestN,...))) -+ // where -+ // ThrV: The threads local to an MMA. layout<0>(ThrLayoutVMNK): ThrV -> thread_idx -+ // ThrM: The threads tiled in M. layout<1>(ThrLayoutVMNK): ThrM -> thread_idx -+ // ThrN: The threads tiled in N. layout<2>(ThrLayoutVMNK): ThrN -> thread_idx -+ // FrgV: The values local to an MMA. -+ // RestM: The values tiled in M. -+ // RestN: The values tiled in N. -+ template -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ thrfrg_C(CTensor&& ctensor) -+ { -+ CUTE_STATIC_ASSERT_V(rank(ctensor) >= Int<2>{}); -+ CUTE_STATIC_ASSERT_V(size<0>(ctensor) % size<0>(TiledShape_MNK{}) == Int<0>{}); -+ CUTE_STATIC_ASSERT_V(size<1>(ctensor) % size<1>(TiledShape_MNK{}) == Int<0>{}); -+ -+ // Reorder the tensor for the TiledAtom -+ auto t_tile = make_tile(left_inverse(get<0>(PermutationsMNK{})), -+ left_inverse(get<1>(PermutationsMNK{}))); -+ auto t_tensor = logical_divide(ctensor, t_tile); // (PermM,PermN) -+ -+ // Tile the tensor for the Atom -+ auto a_tile = make_tile(make_layout(size<0>(AtomShape_MNK{})), -+ make_layout(size<1>(AtomShape_MNK{}))); -+ auto a_tensor = zipped_divide(t_tensor, a_tile); // ((AtomM,AtomN),(RestM,RestN)) -+ -+ // Transform the Atom mode from (M,K) to (Thr,Val) -+ auto tv_tensor = a_tensor.compose(AtomLayoutC_TV{},_); // ((ThrV,FrgV),(RestM,RestN)) -+ -+ // Tile the tensor for the C-threads -+ auto thr_tile = make_tile(_, -+ make_tile(make_layout(size<1>(ThrLayoutVMNK{})), -+ make_layout(size<2>(ThrLayoutVMNK{})))); -+ auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrN)),(FrgV,(RestM,RestN))) -+ -+ return thr_tensor; -+ } -+ -+ // Tile from (M,N,...) -+ // to (thr_idx,(FrgV,(RestM,RestN,...))) -+ template -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ tidfrg_C(CTensor&& ctensor) -+ { -+ // Don't need a ctile composition because ThrK is last mode in TidLayout -+ -+ return thrfrg_C(ctensor).compose(TidLayout{}, _); -+ } -+ -+ // Tile a tensor or a layout from shape -+ // (M,K,...) -+ // to shape -+ // ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK,...))) -+ // where -+ // ThrV: The threads local to an MMA. layout<0>(ThrLayoutVMNK): ThrV -> thread_idx -+ // ThrM: The threads tiled in M. layout<1>(ThrLayoutVMNK): ThrM -> thread_idx -+ // ThrK: The threads tiled in K. layout<3>(ThrLayoutVMNK): ThrK -> thread_idx -+ // FrgV: The values local to an MMA. -+ // RestM: The values tiled in M. -+ // RestK: The values tiled in K. -+ template -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ thrfrg_A(ATensor&& atensor) -+ { -+ CUTE_STATIC_ASSERT_V(rank(atensor) >= Int<2>{}); -+ CUTE_STATIC_ASSERT_V(size<0>(atensor) % size<0>(TiledShape_MNK{}) == Int<0>{}); -+ CUTE_STATIC_ASSERT_V(size<1>(atensor) % size<2>(TiledShape_MNK{}) == Int<0>{}); -+ -+ // Reorder the tensor for the TiledAtom -+ auto t_tile = make_tile(left_inverse(get<0>(PermutationsMNK{})), -+ left_inverse(get<2>(PermutationsMNK{}))); -+ auto t_tensor = logical_divide(atensor, t_tile); // (PermM,PermK) -+ -+ // Tile the tensor for the Atom -+ auto a_tile = make_tile(make_layout(size<0>(AtomShape_MNK{})), -+ make_layout(size<2>(AtomShape_MNK{}))); -+ auto a_tensor = zipped_divide(t_tensor, a_tile); // ((AtomM,AtomK),(RestM,RestK)) -+ -+ // Transform the Atom mode from (M,K) to (Thr,Val) -+ auto tv_tensor = a_tensor.compose(AtomLayoutA_TV{},_); // ((ThrV,FrgV),(RestM,RestK)) -+ -+ // Tile the tensor for the Thread -+ auto thr_tile = make_tile(_, -+ make_tile(make_layout(size<1>(ThrLayoutVMNK{})), -+ make_layout(size<3>(ThrLayoutVMNK{})))); -+ auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK))) -+ -+ return thr_tensor; -+ } -+ -+ // Tile from (M,K,...) -+ // to (thr_idx,(FrgV,(RestM,RestK,...))) -+ template -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ tidfrg_A(ATensor&& atensor) -+ { -+ auto atile = make_tile(_, -+ make_tile(make_layout(make_shape (size<1>(ThrLayoutVMNK{}), size<2>(ThrLayoutVMNK{})), -+ make_stride( Int<1>{} , Int<0>{} )), -+ _)); -+ // (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) -+ -+ return thrfrg_A(atensor).compose(atile, _).compose(TidLayout{}, _); -+ } -+ -+ // Tile a tensor or a layout from shape -+ // (N,K,...) -+ // to shape -+ // ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -+ // where -+ // ThrV: The threads local to an MMA. layout<0>(ThrLayoutVMNK): ThrV -> thread_idx -+ // ThrN: The threads tiled in N. layout<2>(ThrLayoutVMNK): ThrN -> thread_idx -+ // ThrK: The threads tiled in K. layout<3>(ThrLayoutVMNK): ThrK -> thread_idx -+ // FrgV: The values local to an MMA. -+ // RestN: The values tiled in N. -+ // RestK: The values tiled in K. -+ template -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ thrfrg_B(BTensor&& btensor) -+ { -+ CUTE_STATIC_ASSERT_V(rank(btensor) >= Int<2>{}); -+ CUTE_STATIC_ASSERT_V(size<0>(btensor) % size<1>(TiledShape_MNK{}) == Int<0>{}); -+ CUTE_STATIC_ASSERT_V(size<1>(btensor) % size<2>(TiledShape_MNK{}) == Int<0>{}); -+ -+ // Reorder the tensor for the TiledAtom -+ auto t_tile = make_tile(left_inverse(get<1>(PermutationsMNK{})), -+ left_inverse(get<2>(PermutationsMNK{}))); -+ auto t_tensor = logical_divide(btensor, t_tile); // (PermN,PermK) -+ -+ // Tile the tensor for the Atom -+ auto a_tile = make_tile(make_layout(size<1>(AtomShape_MNK{})), -+ make_layout(size<2>(AtomShape_MNK{}))); -+ auto a_tensor = zipped_divide(t_tensor, a_tile); // ((AtomN,AtomK),(RestN,RestK)) -+ -+ // Transform the Atom mode from (M,K) to (Thr,Val) -+ auto tv_tensor = a_tensor.compose(AtomLayoutB_TV{},_); // ((ThrV,FrgV),(RestN,RestK)) -+ -+ // Tile the tensor for the Thread -+ auto thr_tile = make_tile(_, -+ make_tile(make_layout(size<2>(ThrLayoutVMNK{})), -+ make_layout(size<3>(ThrLayoutVMNK{})))); -+ auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK))) -+ -+ return thr_tensor; -+ } -+ -+ // Tile from (N,K,...) -+ // to (thr_idx,(FrgV,(RestN,RestK,...))) -+ template -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ tidfrg_B(BTensor&& btensor) -+ { -+ auto btile = make_tile(_, -+ make_tile(make_layout(make_shape (size<1>(ThrLayoutVMNK{}), size<2>(ThrLayoutVMNK{})), -+ make_stride( Int<0>{} , Int<1>{} )), -+ _)); -+ // (ThrV,(ThrN,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) -+ -+ return thrfrg_B(btensor).compose(btile, _).compose(TidLayout{}, _); -+ } -+ -+ template ::value)> -+ CUTE_HOST_DEVICE static constexpr -+ auto -+ get_slice(ThrIdx const& thr_idx) -+ { -+ auto thr_vmnk = ThrLayoutVMNK{}.get_flat_coord(thr_idx); -+ return ThrMMA(thr_vmnk); -+ } -+ -+ template ::value)> -+ CUTE_HOST_DEVICE static constexpr -+ auto -+ get_thread_slice(ThrIdx const& thr_idx) -+ { -+ return get_slice(thr_idx); -+ } -+ -+ // -+ // Utility for printing and visualization -+ // -+ -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ get_layoutC_MN() -+ { -+ // (M,N) -> (M,N) -+ auto ref_C = make_layout(make_shape(size<0>(TiledShape_MNK{}), size<1>(TiledShape_MNK{}))); -+ // (cthrid,val) -> (M,N) -+ auto layoutC_TV = thrfrg_C(ref_C); -+ // (M,N) -> (cthrid,frg) -+ auto layoutC_MN = right_inverse(layoutC_TV).with_shape(shape(ref_C)); -+ -+ // cthrid = (v,m,n) -> thr_idx -+ auto thrID_C = ThrLayoutVMNK{}(_,_,_,Int<0>{}); -+ -+ return cute::make_tuple(layoutC_MN, thrID_C); -+ } -+ -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ get_layoutC_TV() -+ { -+ // (M,N) -> (M,N) -+ auto ref_C = make_layout(make_shape(size<0>(TiledShape_MNK{}), size<1>(TiledShape_MNK{}))); -+ -+ return tidfrg_C(ref_C); -+ } -+ -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ get_layoutA_MK() -+ { -+ // (M,K) -> (M,K) -+ auto ref_A = make_layout(make_shape(size<0>(TiledShape_MNK{}), size<2>(TiledShape_MNK{}))); -+ // (athrid,val) -> (M,K) -+ auto layoutA_TV = thrfrg_A(ref_A); -+ // (M,K) -> (athrid,frg) -+ auto layoutA_MK = right_inverse(layoutA_TV).with_shape(shape(ref_A)); -+ -+ // athrid = (v,m,k) -> thr_idx -+ auto thrID_A = ThrLayoutVMNK{}(_,_,Int<0>{},_); -+ -+ return cute::make_tuple(layoutA_MK, thrID_A); -+ } -+ -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ get_layoutA_TV() -+ { -+ // (M,K) -> (M,K) -+ auto ref_A = make_layout(make_shape(size<0>(TiledShape_MNK{}), size<2>(TiledShape_MNK{}))); -+ -+ return tidfrg_A(ref_A); -+ } -+ -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ get_layoutB_NK() -+ { -+ // (N,K) -> (N,K) -+ auto ref_B = make_layout(make_shape(size<1>(TiledShape_MNK{}), size<2>(TiledShape_MNK{}))); -+ // (bthrid,val) -> (N,K) -+ auto layoutB_TV = thrfrg_B(ref_B); -+ // (N,K) -> (bthrid,frg) -+ auto layoutB_NK = right_inverse(layoutB_TV).with_shape(shape(ref_B)); -+ -+ // bthrid = (v,n,k) -> thr_idx -+ auto thrID_B = ThrLayoutVMNK{}(_,Int<0>{},_,_); -+ -+ return cute::make_tuple(layoutB_NK, thrID_B); -+ } -+ -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ get_layoutB_TV() -+ { -+ // (N,K) -> (N,K) -+ auto ref_B = make_layout(make_shape(size<1>(TiledShape_MNK{}), size<2>(TiledShape_MNK{}))); -+ -+ return tidfrg_B(ref_B); -+ } -+}; -+ -+template -+struct ThrMMA : TiledMMA -+{ -+ // Use ThrVMNK and thrfrg rather than thr_idx and tidfrg -+ // to support swizzled threads partitioning dynamic layouts -+ ThrVMNK thr_vmnk_; -+ -+ CUTE_HOST_DEVICE constexpr -+ ThrMMA(ThrVMNK const& thr_vmnk) : thr_vmnk_(thr_vmnk) {} -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ partition_C(CTensor&& ctensor) const -+ { -+ auto thr_tensor = make_tensor(std::forward(ctensor).data(), thrfrg_C(ctensor.layout())); -+ -+ auto thr_vmn = make_coord(get<0>(thr_vmnk_), make_coord(get<1>(thr_vmnk_), get<2>(thr_vmnk_))); -+ return thr_tensor(thr_vmn, make_coord(_, repeat(thr_tensor)>(_))); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ partition_A(ATensor&& atensor) const -+ { -+ auto thr_tensor = make_tensor(std::forward(atensor).data(), thrfrg_A(atensor.layout())); -+ -+ auto thr_vmk = make_coord(get<0>(thr_vmnk_), make_coord(get<1>(thr_vmnk_), get<3>(thr_vmnk_))); -+ return thr_tensor(thr_vmk, make_coord(_, repeat(thr_tensor)>(_))); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ partition_B(BTensor&& btensor) const -+ { -+ auto thr_tensor = make_tensor(std::forward(btensor).data(), thrfrg_B(btensor.layout())); -+ -+ auto thr_vnk = make_coord(get<0>(thr_vmnk_), make_coord(get<2>(thr_vmnk_), get<3>(thr_vmnk_))); -+ return thr_tensor(thr_vnk, make_coord(_, repeat(thr_tensor)>(_))); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ partition_fragment_C(CTensor&& ctensor) const -+ { -+ return make_fragment_C(partition_C(ctensor)); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ partition_fragment_A(ATensor&& atensor) const -+ { -+ return make_fragment_A(partition_A(atensor)); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ partition_fragment_B(BTensor&& btensor) const -+ { -+ return make_fragment_B(partition_B(btensor)); -+ } -+}; -+ -+// -+// These tile the MMA_Atom as a whole -+// -+ -+template >, -+ class MMAValLayout = Layout>, -+ class Permutations = Tile> -+CUTE_HOST_DEVICE constexpr -+auto -+make_tiled_mma(MMA_Atom const&, -+ MMAThrLayout const& thr_layout = {}, -+ MMAValLayout const& val_layout = {}, -+ Permutations const& permutations = {}) -+{ -+ auto thr_layout_mnk = append<3>(thr_layout, Layout<_1>{}); -+ auto val_layout_mnk = append<3>(val_layout, Layout<_1>{}); -+ auto permutation_mnk = append<3>(permutations, _); -+ -+ return TiledMMA, -+ decltype(thr_layout_mnk), -+ decltype(val_layout_mnk), -+ decltype(permutation_mnk)>{}; -+} -+ -+template >, -+ class MMAValLayout = Layout>, -+ class Permutations = Tile> -+CUTE_HOST_DEVICE constexpr -+auto -+make_tiled_mma(MMA_Op const&, -+ MMAThrLayout const& thr_layout = {}, -+ MMAValLayout const& val_layout = {}, -+ Permutations const& permutations = {}) -+{ -+ // Attempt to wrap in an MMA_Atom<> and forward -+ return make_tiled_mma(MMA_Atom{}, thr_layout, val_layout, permutations); -+} -+ -+// -+// partition_fragment_C -- static context -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+partition_fragment_C(TiledMMA, Shape_MN shapeMN) -+{ -+ constexpr int R = rank_v; -+ static_assert(R >= 2, "Must have at least rank-2"); -+ auto atomMNK = typename TiledMMA::AtomShape_MNK{}; -+ auto thrVMNK = typename TiledMMA::ThrLayoutVMNK{}; -+ -+ auto V = size<1>(typename TiledMMA::AtomLayoutC_TV{}); -+ auto M = shape_div(size<0>(shapeMN), size<0>(atomMNK) * size<1>(thrVMNK)); -+ auto N = shape_div(size<1>(shapeMN), size<1>(atomMNK) * size<2>(thrVMNK)); -+ auto frg_shape = tuple_cat(make_shape(V,M,N), take<2,R>(shapeMN)); -+ -+ return make_tensor::FrgTypeC>(frg_shape); -+} -+ -+// partition_fragment_A and partition_fragment_B often depend on the -+// layout of A and B and/or the thread_idx that is requesting the partition. -+// For these reasons, they should not be used in a static context. -+// See TiledMMA::get_slice(thr_idx).partition_fragment_A(tensorA) instead. -+ -+// -+// Size -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tile_size(TiledMMA const& mma) -+{ -+ return size(typename TiledMMA::TiledShape_MNK{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+size(TiledMMA const& mma) -+{ -+ return size(typename TiledMMA::ThrLayoutVMNK{}); -+} -+ -+// -+// Display utilities -+// -+ -+template -+CUTE_HOST_DEVICE -+auto -+print_latex(TiledMMA const& mma) -+{ -+ auto layout_and_thrid_C = mma.get_layoutC_MN(); -+ auto layoutC_MN = get<0>(layout_and_thrid_C); -+ auto thrID_C = get<1>(layout_and_thrid_C); -+ -+ auto layout_and_thrid_A = mma.get_layoutA_MK(); -+ auto layoutA_MK = get<0>(layout_and_thrid_A); -+ auto thrID_A = get<1>(layout_and_thrid_A); -+ -+ auto layout_and_thrid_B = mma.get_layoutB_NK(); -+ auto layoutB_NK = get<0>(layout_and_thrid_B); -+ auto thrID_B = get<1>(layout_and_thrid_B); -+ -+ print_latex_mma(layoutC_MN, thrID_C, -+ layoutA_MK, thrID_A, -+ layoutB_NK, thrID_B); -+} -+ -+// EXPERIMENTAL -- Doesn't work with Swizzled Thr TileMMAs... -+template -+CUTE_HOST_DEVICE -+auto -+print_latex_2(TiledMMA const& mma) -+{ -+ print_latex_mma(typename TiledMMA::TiledShape_MNK{}, -+ mma.get_layoutC_TV(), -+ mma.get_layoutA_TV(), -+ mma.get_layoutB_TV()); -+} -+ -+// MNK MMA Layout to console printer -- 8-value color coded by thread -+template -+CUTE_HOST_DEVICE -+void -+print_layout_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and tid -> thr_idx -+ LayoutA const& A, ThrIDA const& TA, // (m,k) -> (tid,vid) and tid -> thr_idx -+ LayoutB const& B, ThrIDB const& TB) // (n,k) -> (tid,vid) and tid -> thr_idx -+{ -+ CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{}); -+ CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{}); -+ CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{}); -+ -+ assert(size<0>(A) == size<0>(C)); -+ assert(size<0>(B) == size<1>(C)); -+ assert(size<1>(A) == size<1>(B)); -+ -+ int a_width = size<1>(A) * 6 + 4; -+ -+ // Print out B (white-shifted) k-by-n -+ for (int k = 0; k < size<1>(B); ++k) { -+ // Header -+ printf("%*s", a_width, ""); -+ for (int n = 0; n < size<0>(B); ++n) printf("+-----"); -+ printf("+\n"); -+ // Values -+ printf("%*s", a_width, ""); -+ for (int n = 0; n < size<0>(B); ++n) printf("|T%02dV%1d", int(TB(B(n,k) % size(TB))), int(B(n,k) / size(TB))); -+ printf("|\n"); -+ } -+ // Footer -+ printf("%*s", a_width, ""); -+ for (int n = 0; n < size<0>(B); ++n) printf("+-----"); -+ printf("+\n\n"); -+ -+ // Print out A m-by-k and C m-by-n -+ for (int m = 0; m < size<0>(A); ++m) { -+ // Header -+ for (int k = 0; k < size<1>(A); ++k) printf("+-----"); -+ printf("+ "); -+ for (int n = 0; n < size<1>(C); ++n) printf("+-----"); -+ printf("+\n"); -+ // Values -+ for (int k = 0; k < size<1>(A); ++k) printf("|T%02dV%1d", int(TA(A(m,k) % size(TA))), int(A(m,k) / size(TA))); -+ printf("| "); -+ for (int n = 0; n < size<1>(C); ++n) printf("|T%02dV%1d", int(TC(C(m,n) % size(TC))), int(C(m,n) / size(TC))); -+ printf("|\n"); -+ } -+ // Footer -+ for (int k = 0; k < size<1>(A); ++k) printf("+-----"); -+ printf("+ "); -+ for (int n = 0; n < size<1>(C); ++n) printf("+-----"); -+ printf("+\n"); -+} -+ -+// MNK MMA Layout to Latex TIKZ -- 8-value color coded by thread -+template -+CUTE_HOST_DEVICE -+void -+print_latex_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and tid -> thr_idx -+ LayoutA const& A, ThrIDA const& TA, // (m,k) -> (tid,vid) and tid -> thr_idx -+ LayoutB const& B, ThrIDB const& TB) // (n,k) -> (tid,vid) and tid -> thr_idx -+{ -+ CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{}); -+ CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{}); -+ CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{}); -+ -+ assert(size<0>(A) == size<0>(C)); -+ assert(size<0>(B) == size<1>(C)); -+ assert(size<1>(A) == size<1>(B)); -+ -+ char const* latex_header = -+ "\\documentclass{standalone}\n" -+ "\\usepackage{tikz}\n" -+ "\\usetikzlibrary{external}\n" -+ "\\tikzexternalize\n" -+ "\\begin{document}\n" -+ "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},box/.style={rectangle,draw=black,thick,minimum size=1cm,anchor=center}]\n\n"; -+ char const* latex_footer = -+ "\\end{tikzpicture}\n" -+ "\\end{document}\n"; -+ -+ char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}", -+ "{rgb,255:red,175;green,255;blue,175}", -+ "{rgb,255:red,255;green,255;blue,175}", -+ "{rgb,255:red,255;green,175;blue,175}", -+ "{rgb,255:red,210;green,210;blue,255}", -+ "{rgb,255:red,210;green,255;blue,210}", -+ "{rgb,255:red,255;green,255;blue,210}", -+ "{rgb,255:red,255;green,210;blue,210}"}; -+ -+ // Header -+ printf("%% LayoutC: "); print(C); printf("\n"); -+ printf("%% ThrIDC : "); print(TC); printf("\n"); -+ printf("%% LayoutA: "); print(A); printf("\n"); -+ printf("%% ThrIDA : "); print(TA); printf("\n"); -+ printf("%% LayoutB: "); print(B); printf("\n"); -+ printf("%% ThrIDB : "); print(TB); printf("\n\n"); -+ -+ printf(latex_header); -+ -+ // C starting at 0,0 -+ for (int m = 0; m < size<0>(C); ++m) { -+ for (int n = 0; n < size<1>(C); ++n) { -+ int thrid = C(m,n) % size(TC); -+ int val_idx = C(m,n) / size(TC); -+ int thr_idx = TC(thrid); -+ -+ printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", -+ color_map[thr_idx % 8], -+ m, n, -+ thr_idx, val_idx); -+ } -+ } -+ -+ // A starting at 0,-size<1>(A)-1 -+ for (int m = 0; m < size<0>(A); ++m) { -+ for (int k = 0; k < size<1>(A); ++k) { -+ int thrid = A(m,k) % size(TA); -+ int val_idx = A(m,k) / size(TA); -+ int thr_idx = TA(thrid); -+ -+ printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", -+ color_map[thr_idx % 8], -+ m, k-1-size<1>(A), -+ thr_idx, val_idx); -+ } -+ } -+ -+ // B starting at -size<1>(B)-1,0 -+ for (int n = 0; n < size<0>(B); ++n) { -+ for (int k = 0; k < size<1>(B); ++k) { -+ int thrid = B(n,k) % size(TB); -+ int val_idx = B(n,k) / size(TB); -+ int thr_idx = TB(thrid); -+ -+ printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", -+ color_map[thr_idx % 8], -+ k-1-size<1>(B), n, -+ thr_idx, val_idx); -+ } -+ } -+ -+ // A labels -+ for (int m = 0, k = -1; m < size<0>(A); ++m) { -+ printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k-1-size<1>(A), m); -+ } -+ for (int k = 0, m = -1; k < size<1>(A); ++k) { -+ printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k-1-size<1>(A), k); -+ } -+ // B labels -+ for (int n = 0, k = -1; n < size<0>(B); ++n) { -+ printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", k-1-size<1>(B), n, n); -+ } -+ for (int k = 0, n = -1; k < size<1>(B); ++k) { -+ printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", k-1-size<1>(B), n, k); -+ } -+ -+ // Footer -+ printf(latex_footer); -+} -+ -+// ThrVal MMA Layout to Latex TIKZ -- 8-value color coded by thread -+template -+CUTE_HOST_DEVICE -+void -+print_latex_mma(Shape_MNK const& shape_mnk, -+ LayoutC const& C, // (thr_idx,vid) -> (m,n) -+ LayoutA const& A, // (thr_idx,vid) -> (m,k) -+ LayoutB const& B) // (thr_idx,vid) -> (n,k) -+{ -+ CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{}); -+ CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{}); -+ CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{}); -+ -+ char const* latex_header = -+ "\\documentclass{standalone}\n" -+ "\\usepackage{tikz}\n" -+ "\\usetikzlibrary{external}\n" -+ "\\tikzexternalize\n" -+ "\\begin{document}\n" -+ "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},box/.style={rectangle,draw=black,thick,minimum size=1cm,anchor=center}]\n\n"; -+ char const* latex_footer = -+ "\\end{tikzpicture}\n" -+ "\\end{document}\n"; -+ -+ char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}", -+ "{rgb,255:red,175;green,255;blue,175}", -+ "{rgb,255:red,255;green,255;blue,175}", -+ "{rgb,255:red,255;green,175;blue,175}", -+ "{rgb,255:red,210;green,210;blue,255}", -+ "{rgb,255:red,210;green,255;blue,210}", -+ "{rgb,255:red,255;green,255;blue,210}", -+ "{rgb,255:red,255;green,210;blue,210}"}; -+ -+ // Header -+ printf("%% Shape_MNK: "); print(shape_mnk); printf("\n"); -+ printf("%% LayoutC : "); print(C); printf("\n"); -+ printf("%% LayoutA : "); print(A); printf("\n"); -+ printf("%% LayoutB : "); print(B); printf("\n\n"); -+ -+ printf(latex_header); -+ -+ int M = size<0>(shape_mnk); -+ int N = size<1>(shape_mnk); -+ int K = size<2>(shape_mnk); -+ -+ // C starting at 0,0 -+ bool c_filled[M][N] = {}; -+ for (int t = 0; t < size<0>(C); ++t) { -+ for (int v = 0; v < size<1>(C); ++v) { -+ int m = C(t,v) % M; -+ int n = C(t,v) / M; -+ -+ if (not c_filled[m][n]) { -+ printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", -+ color_map[t % 8], -+ m, n, -+ t, v); -+ c_filled[m][n] = true; -+ } -+ } -+ } -+ -+ // A starting at 0,-size<1>(A)-1 -+ bool a_filled[M][K] = {}; -+ for (int t = 0; t < size<0>(A); ++t) { -+ for (int v = 0; v < size<1>(A); ++v) { -+ int m = A(t,v) % M; -+ int k = A(t,v) / M; -+ -+ if (not a_filled[m][k]) { -+ printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", -+ color_map[t % 8], -+ m, k - 1 - K, -+ t, v); -+ a_filled[m][k] = true; -+ } -+ } -+ } -+ -+ // B starting at -size<1>(B)-1,0 -+ bool b_filled[N][K] = {}; -+ for (int t = 0; t < size<0>(B); ++t) { -+ for (int v = 0; v < size<1>(B); ++v) { -+ int n = B(t,v) % N; -+ int k = B(t,v) / N; -+ -+ if (not b_filled[n][k]) { -+ printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", -+ color_map[t % 8], -+ k - 1 - K, n, -+ t, v); -+ b_filled[n][k] = true; -+ } -+ } -+ } -+ -+ // A labels -+ for (int m = 0, k = -1; m < M; ++m) { -+ printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k - 1 - K, m); -+ } -+ for (int k = 0, m = -1; k < K; ++k) { -+ printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k - 1 - K, k); -+ } -+ // B labels -+ for (int n = 0, k = -1; n < N; ++n) { -+ printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", k - 1 - K, n, n); -+ } -+ for (int k = 0, n = -1; k < K; ++k) { -+ printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", k - 1 - K, n, k); -+ } -+ -+ // Footer -+ printf(latex_footer); -+} -+ -+} // namespace cute -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cute/atom/mma_traits.hpp b/3rdparty/cutlass/include/cute/atom/mma_traits.hpp -new file mode 100644 -index 0000000..a8c3323 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/atom/mma_traits.hpp -@@ -0,0 +1,70 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+namespace cute -+{ -+ -+template -+struct MMA_Traits -+{ -+ static_assert(sizeof(MMAOperation) == 0, "MMA_Traits not implemented for this MMA_Operation."); -+}; -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = D; -+ using ElementAVal = A; -+ using ElementBVal = B; -+ using ElementCVal = C; -+ -+ // Logical shape of the MMA -+ using Shape_MNK = Shape<_1,_1,_1>; -+ -+ // Logical thread id (tid) -> tidx -+ using ThrID = Layout<_1>; -+ -+ // (Logical thread id (tid), Logical value id (vid)) -> coord -+ -+ // (tid,vid) -> (m,k) -+ using ALayout = Layout>; -+ // (tid,vid) -> (n,k) -+ using BLayout = Layout>; -+ // (tid,vid) -> (m,n) -+ using CLayout = Layout>; -+}; -+ -+} // namespace cute -diff --git a/3rdparty/cutlass/include/cute/atom/mma_traits_sm61.hpp b/3rdparty/cutlass/include/cute/atom/mma_traits_sm61.hpp -new file mode 100644 -index 0000000..85d4e98 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/atom/mma_traits_sm61.hpp -@@ -0,0 +1,73 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+ -+namespace cute -+{ -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using Shape_MNK = Shape<_1,_1,_4>; -+ using ThrID = Layout<_1>; -+ using ALayout = Layout>; -+ using BLayout = Layout>; -+ using CLayout = Layout>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int16_t; -+ using ElementBVal = int16_t; -+ using ElementCVal = int32_t; -+ -+ using Shape_MNK = Shape<_1,_1,_2>; -+ using ThrID = Layout<_1>; -+ using ALayout = Layout>; -+ using BLayout = Layout>; -+ using CLayout = Layout>; -+}; -+ -+} // namespace cute -diff --git a/3rdparty/cutlass/include/cute/atom/mma_traits_sm70.hpp b/3rdparty/cutlass/include/cute/atom/mma_traits_sm70.hpp -new file mode 100644 -index 0000000..7943035 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/atom/mma_traits_sm70.hpp -@@ -0,0 +1,198 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+ -+namespace cute -+{ -+ -+namespace { -+ -+// Logical thread id to thread idx (quadpair) -+using SM70_QuadPair = Layout, -+ Stride<_1,_16>>; -+// (T8,V4) -> (M8,K4) -+using SM70_8x4_Row = Layout, -+ Stride<_1,_8>>; -+// (T8,V4) -> (M8,K4) -+using SM70_8x4_Col = Layout,_4>, -+ Stride,_1>>; -+// (T8,V8) -> (M8,N8) -+using SM70_8x8_16b = Layout, -+ Stride<_1,_8>>; -+// (T8,V8) -> (M8,N8) -+using SM70_8x8_32b = Layout,Shape <_2,_2, _2>>, -+ Stride,Stride<_8,_2,_32>>>; -+ -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using Shape_MNK = Shape<_8,_8,_4>; -+ using ThrID = SM70_QuadPair; -+ using ALayout = SM70_8x4_Row; -+ using BLayout = SM70_8x4_Row; -+ using CLayout = SM70_8x8_16b; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using Shape_MNK = Shape<_8,_8,_4>; -+ using ThrID = SM70_QuadPair; -+ using ALayout = SM70_8x4_Col; -+ using BLayout = SM70_8x4_Col; -+ using CLayout = SM70_8x8_16b; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using Shape_MNK = Shape<_8,_8,_4>; -+ using ThrID = SM70_QuadPair; -+ using ALayout = SM70_8x4_Col; -+ using BLayout = SM70_8x4_Row; -+ using CLayout = SM70_8x8_16b; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using Shape_MNK = Shape<_8,_8,_4>; -+ using ThrID = SM70_QuadPair; -+ using ALayout = SM70_8x4_Row; -+ using BLayout = SM70_8x4_Col; -+ using CLayout = SM70_8x8_16b; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using Shape_MNK = Shape<_8,_8,_4>; -+ using ThrID = SM70_QuadPair; -+ using ALayout = SM70_8x4_Row; -+ using BLayout = SM70_8x4_Row; -+ using CLayout = SM70_8x8_32b; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using Shape_MNK = Shape<_8,_8,_4>; -+ using ThrID = SM70_QuadPair; -+ using ALayout = SM70_8x4_Col; -+ using BLayout = SM70_8x4_Col; -+ using CLayout = SM70_8x8_32b; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using Shape_MNK = Shape<_8,_8,_4>; -+ using ThrID = SM70_QuadPair; -+ using ALayout = SM70_8x4_Col; -+ using BLayout = SM70_8x4_Row; -+ using CLayout = SM70_8x8_32b; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using Shape_MNK = Shape<_8,_8,_4>; -+ using ThrID = SM70_QuadPair; -+ using ALayout = SM70_8x4_Row; -+ using BLayout = SM70_8x4_Col; -+ using CLayout = SM70_8x8_32b; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+} // namespace cute -diff --git a/3rdparty/cutlass/include/cute/atom/mma_traits_sm75.hpp b/3rdparty/cutlass/include/cute/atom/mma_traits_sm75.hpp -new file mode 100644 -index 0000000..405e871 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/atom/mma_traits_sm75.hpp -@@ -0,0 +1,81 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+ -+namespace cute -+{ -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using Shape_MNK = Shape<_16,_8,_8>; -+ using ThrID = Layout<_32>; -+ using ALayout = Layout,Shape < _2,_2>>, -+ Stride,Stride<_16,_1>>>; -+ using BLayout = Layout,_2>, -+ Stride,_8>>; -+ using CLayout = Layout,Shape < _2,_2>>, -+ Stride,Stride<_16,_1>>>; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using Shape_MNK = Shape<_8,_8,_16>; -+ using ThrID = Layout<_32>; -+ using ALayout = Layout,_4>, -+ Stride,_8>>; -+ using BLayout = Layout,_4>, -+ Stride,_8>>; -+ using CLayout = Layout,_2>, -+ Stride,_8>>; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cute -diff --git a/3rdparty/cutlass/include/cute/atom/mma_traits_sm80.hpp b/3rdparty/cutlass/include/cute/atom/mma_traits_sm80.hpp -new file mode 100644 -index 0000000..6636b7a ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/atom/mma_traits_sm80.hpp -@@ -0,0 +1,446 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+#include -+ -+#include -+ -+#include -+ -+#include -+ -+namespace cute -+{ -+ -+namespace { -+ -+// (T32,V1) -> (M8,N8) -+using SM80_8x4 = Layout,_1>, -+ Stride,_0>>; -+// (T32,V2) -> (M8,N8) -+using SM80_8x8_Row = Layout,_2>, -+ Stride,_8>>; -+// (T32,V4) -> (M8,N16) -+using SM80_8x16_Row = Layout,_4>, -+ Stride,_8>>; -+// (T32,V4) -> (M16,N8) -+using SM80_16x8_Row = Layout,Shape < _2,_2>>, -+ Stride,Stride<_16,_8>>>; -+ -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+//////////////////////// fp16 = fp16 * fp16 + fp16 //////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using Shape_MNK = Shape<_16,_8,_8>; -+ using ThrID = Layout<_32>; -+ using ALayout = SM80_16x8_Row; -+ using BLayout = SM80_8x8_Row; -+ using CLayout = SM80_16x8_Row; -+}; -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using Shape_MNK = Shape<_16,_8,_16>; -+ using ThrID = Layout<_32>; -+ using ALayout = Layout,Shape < _2,_2, _2>>, -+ Stride,Stride<_16,_8,_128>>>; -+ using BLayout = Layout,Shape <_2, _2>>, -+ Stride,Stride<_8,_64>>>; -+ using CLayout = SM80_16x8_Row; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+//////////////////////// fp32 = fp16 * fp16 + fp32 //////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+//////////////////////// fp32 = bf16 * bf16 + fp32 //////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits -+{ -+ using ElementDVal = float; -+ using ElementAVal = bfloat16_t; -+ using ElementBVal = bfloat16_t; -+ using ElementCVal = float; -+}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits -+{ -+ using ElementDVal = float; -+ using ElementAVal = bfloat16_t; -+ using ElementBVal = bfloat16_t; -+ using ElementCVal = float; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+//////////////////////// fp32 = tf32 * tf32 + fp32 //////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = float; -+ using ElementAVal = cutlass::tfloat32_t; -+ using ElementBVal = cutlass::tfloat32_t; -+ using ElementCVal = float; -+ -+ using Shape_MNK = Shape<_16,_8,_4>; -+ using ThrID = Layout<_32>; -+ using ALayout = Layout,_2>, -+ Stride,_8>>; -+ using BLayout = SM80_8x4; -+ using CLayout = SM80_16x8_Row; -+}; -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = float; -+ using ElementAVal = cutlass::tfloat32_t; -+ using ElementBVal = cutlass::tfloat32_t; -+ using ElementCVal = float; -+ -+ using Shape_MNK = Shape<_16,_8,_8>; -+ using ThrID = Layout<_32>; -+ using ALayout = Layout,Shape <_2, _2>>, -+ Stride,Stride<_8,_64>>>; -+ using BLayout = Layout, _2>, -+ Stride,_32>>; -+ using CLayout = SM80_16x8_Row; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+//////////////////////// fp64 = fp64 * fp64 + fp64 //////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = double; -+ using ElementAVal = double; -+ using ElementBVal = double; -+ using ElementCVal = double; -+ -+ using Shape_MNK = Shape<_8,_8,_4>; -+ using ThrID = Layout<_32>; -+ using ALayout = SM80_8x4; -+ using BLayout = SM80_8x4; -+ using CLayout = SM80_8x8_Row; -+}; -+ -+// Custom complex fp64 MMA composed of 4 fp64 MMAs -- same layouts -+template <> -+struct MMA_Traits -+ : MMA_Traits -+{ -+ using ElementDVal = complex; -+ using ElementAVal = complex; -+ using ElementBVal = complex; -+ using ElementCVal = complex; -+}; -+ -+// Custom complex fp64 MMA composed of 3 fp64 MMAs -- same layouts -+template <> -+struct MMA_Traits -+ : MMA_Traits -+{ -+ using ElementDVal = typename SM80_8x8x4_GC64C64C64GC64_TN::GaussComplex; -+ using ElementAVal = complex; -+ using ElementBVal = complex; -+ using ElementCVal = typename SM80_8x8x4_GC64C64C64GC64_TN::GaussComplex; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+/////////////////////////// s32 = s8 * s8 + s32 /////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using Shape_MNK = Shape<_8,_8,_16>; -+ using ThrID = Layout<_32>; -+ using ALayout = SM80_8x16_Row; -+ using BLayout = SM80_8x16_Row; -+ using CLayout = SM80_8x8_Row; -+}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits {}; -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using Shape_MNK = Shape<_16,_8,_16>; -+ using ThrID = Layout<_32>; -+ using ALayout = Layout,Shape < _4,_2>>, -+ Stride,Stride<_16,_8>>>; -+ using BLayout = SM80_8x16_Row; -+ using CLayout = SM80_16x8_Row; -+}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits {}; -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using Shape_MNK = Shape<_16,_8,_32>; -+ using ThrID = Layout<_32>; -+ using ALayout = Layout,Shape < _4,_2, _2>>, -+ Stride,Stride<_16,_8,_256>>>; -+ using BLayout = Layout, Shape <_4, _2>>, -+ Stride, Stride<_8,_128>>>; -+ using CLayout = SM80_16x8_Row; -+}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits {}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+/////////////////////////// s32 = s8 * u8 + s32 /////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits {}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits {}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits {}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+/////////////////////////// s32 = u8 * s8 + s32 /////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits {}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits {}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits {}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+/////////////////////////// s32 = u8 * u8 + s32 /////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits {}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits {}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits {}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+/////////////////////////// s32 = b1 ^ b1 + s32 /////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = cute::uint1b_t; -+ using ElementBVal = cute::uint1b_t; -+ using ElementCVal = int32_t; -+ -+ using Shape_MNK = Shape<_16,_8,_256>; -+ using ThrID = Layout<_32>; -+ using ALayout = Layout>, -+ Stride<_64,Stride<_64,_16,_8,_2048>>>; -+ using BLayout = Layout>, -+ Stride<_32,Stride< _1,_1024>>>; -+ using CLayout = SM80_16x8_Row; -+}; -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/atom/mma_traits_sm90.hpp b/3rdparty/cutlass/include/cute/atom/mma_traits_sm90.hpp -new file mode 100644 -index 0000000..b7a12b9 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/atom/mma_traits_sm90.hpp -@@ -0,0 +1,132 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+#include -+ -+#include -+ -+namespace cute { -+ -+/////////////////////////////////////////////////////////////////////////////// -+//////////////////////// fp64 = fp64 * fp64 + fp64 //////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = double; -+ using ElementAVal = double; -+ using ElementBVal = double; -+ using ElementCVal = double; -+ -+ using Shape_MNK = Shape<_16,_8,_4>; -+ using ThrID = Layout<_32>; -+ using ALayout = Layout,_2>, -+ Stride,_8>>; -+ using BLayout = Layout,_1>, -+ Stride,_0>>; -+ using CLayout = Layout,Shape < _2,_2>>, -+ Stride,Stride<_16,_8>>>; -+}; -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = double; -+ using ElementAVal = double; -+ using ElementBVal = double; -+ using ElementCVal = double; -+ -+ using Shape_MNK = Shape<_16,_8,_8>; -+ using ThrID = Layout<_32>; -+ using ALayout = Layout,Shape <_2, _2>>, -+ Stride,Stride<_8,_64>>>; -+ using BLayout = Layout, _2>, -+ Stride,_32>>; -+ using CLayout = Layout,Shape < _2,_2>>, -+ Stride,Stride<_16,_8>>>; -+}; -+ -+template <> -+struct MMA_Traits -+{ -+ using ElementDVal = double; -+ using ElementAVal = double; -+ using ElementBVal = double; -+ using ElementCVal = double; -+ -+ using Shape_MNK = Shape<_16,_8,_16>; -+ using ThrID = Layout<_32>; -+ using ALayout = Layout,Shape <_2, _4>>, -+ Stride,Stride<_8,_64>>>; -+ using BLayout = Layout, _4>, -+ Stride,_32>>; -+ using CLayout = Layout,Shape < _2,_2>>, -+ Stride,Stride<_16,_8>>>; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////// -+//////////////////////// cfp64 = cfp64 * cfp64 + cfp64 //////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits -+{ -+ using ElementDVal = complex; -+ using ElementAVal = complex; -+ using ElementBVal = complex; -+ using ElementCVal = complex; -+}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits -+{ -+ using ElementDVal = complex; -+ using ElementAVal = complex; -+ using ElementBVal = complex; -+ using ElementCVal = complex; -+}; -+ -+template <> -+struct MMA_Traits -+ : MMA_Traits -+{ -+ using ElementDVal = complex; -+ using ElementAVal = complex; -+ using ElementBVal = complex; -+ using ElementCVal = complex; -+}; -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/atom/mma_traits_sm90_gmma.hpp b/3rdparty/cutlass/include/cute/atom/mma_traits_sm90_gmma.hpp -new file mode 100644 -index 0000000..d390daf ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/atom/mma_traits_sm90_gmma.hpp -@@ -0,0 +1,2975 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+#include -+ -+#include -+ -+namespace cute { -+ -+namespace GMMA { -+ -+/////////////////////////////////////////// -+// Common layouts for GMMA Shared Memory // -+/////////////////////////////////////////// -+ -+// M|N-major GMMA layouts in units of bits -+using Layout_MN_INTER_Atom_Bits = Layout,Stride<_1,_128>>; -+using Layout_MN_SW32_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1, _256>>>; -+using Layout_MN_SW64_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1, _512>>>; -+using Layout_MN_SW128_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1,_1024>>>; -+ -+// K-major GMMA layouts in units of bits -+using Layout_K_INTER_Atom_Bits = Layout,Stride<_128,_1>>; -+using Layout_K_SW32_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride< _256,_1>>>; -+using Layout_K_SW64_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride< _512,_1>>>; -+using Layout_K_SW128_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1024,_1>>>; -+ -+// M|N-major layouts in units of Type -+template -+using Layout_MN_INTER_Atom = decltype(upcast::value>(Layout_MN_INTER_Atom_Bits{})); -+template -+using Layout_MN_SW32_Atom = decltype(upcast::value>(Layout_MN_SW32_Atom_Bits{})); -+template -+using Layout_MN_SW64_Atom = decltype(upcast::value>(Layout_MN_SW64_Atom_Bits{})); -+template -+using Layout_MN_SW128_Atom = decltype(upcast::value>(Layout_MN_SW128_Atom_Bits{})); -+ -+// K-major layouts in units of Type -+template -+using Layout_K_INTER_Atom = decltype(upcast::value>(Layout_K_INTER_Atom_Bits{})); -+template -+using Layout_K_SW32_Atom = decltype(upcast::value>(Layout_K_SW32_Atom_Bits{})); -+template -+using Layout_K_SW64_Atom = decltype(upcast::value>(Layout_K_SW64_Atom_Bits{})); -+template -+using Layout_K_SW128_Atom = decltype(upcast::value>(Layout_K_SW128_Atom_Bits{})); -+ -+// With GMMA::Major param -+template -+using Layout_INTER_Atom = typename std::conditional, -+ Layout_K_INTER_Atom>::type; -+template -+using Layout_SW32_Atom = typename std::conditional, -+ Layout_K_SW32_Atom>::type; -+template -+using Layout_SW64_Atom = typename std::conditional, -+ Layout_K_SW64_Atom>::type; -+template -+using Layout_SW128_Atom = typename std::conditional, -+ Layout_K_SW128_Atom>::type; -+ -+// Helper for GMMA smem selection that considers a tensor TileShape: -+// (BLK_MN, BLK_K) -+// or hierarchically -+// ((BLK_MN0,BLK_MN1,...),(BLK_K0,BLK_K1,...)) -+// and returns the largest GMMA::Layout that fits BLK_MN0 and BLK_K0 -+template -+CUTE_HOST_DEVICE constexpr -+auto -+smem_selector() -+{ -+ auto BLK_MN0 = size<0>(BLK_MN{}); -+ auto BLK_K0 = size<0>(BLK_K{}); -+ -+ static_assert(BLK_MN0 % 8 == 0, "BLK_MN0 must be a multiple of 8."); -+ static_assert(BLK_K0 % 8 == 0, "BLK_K0 must be a multiple of 8."); -+ -+ -+ if constexpr (major == GMMA::Major::MN) { -+ if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW128_Atom{}) == 0) { -+ return GMMA::Layout_MN_SW128_Atom{}; -+ } else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW64_Atom{}) == 0) { -+ return GMMA::Layout_MN_SW64_Atom{}; -+ } else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW32_Atom{}) == 0) { -+ return GMMA::Layout_MN_SW32_Atom{}; -+ } else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom{}) == 0) { -+ return GMMA::Layout_MN_INTER_Atom{}; -+ } else { -+ static_assert(BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom{}) == 0, -+ "BLK_MN0 must be a multiple of size<0>(GMMA::Layout_MN_INTER_Atom{})"); -+ } -+ } else if constexpr (major == GMMA::Major::K) { -+ if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW128_Atom{}) == 0) { -+ return GMMA::Layout_K_SW128_Atom{}; -+ } else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW64_Atom{}) == 0) { -+ return GMMA::Layout_K_SW64_Atom{}; -+ } else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW32_Atom{}) == 0) { -+ return GMMA::Layout_K_SW32_Atom{}; -+ } else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom{}) == 0) { -+ return GMMA::Layout_K_INTER_Atom{}; -+ } else { -+ static_assert(BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom{}) == 0, -+ "BLK_K0 must be a multiple of size<1>(GMMA::Layout_K_INTER_Atom{})"); -+ } -+ } -+} -+ -+// -+// Tensor to LayoutType utility -+// -+ -+// smem_ptr_swizzle LayoutType -+template -+CUTE_HOST_DEVICE constexpr -+LayoutType -+layout_type(Tensor>>, -+ Layout> const&) -+{ -+ static_assert(M == 4, "Unsupported layout swizzle"); -+ static_assert(0 <= B && B <= 3, "Unsupported layout swizzle"); -+ static_assert(S == 3, "Unsupported layout swizzle"); -+ -+ switch (B) { -+ case 0: return LayoutType::INTERLEAVE; -+ case 1: return LayoutType::B32; -+ case 2: return LayoutType::B64; -+ case 3: return LayoutType::B128; -+ } -+ return LayoutType::INTERLEAVE; // ERROR -+} -+ -+// smem_ptr non-swizzled LayoutType -+template -+CUTE_HOST_DEVICE constexpr -+LayoutType -+layout_type(Tensor>, -+ Layout> const&) -+{ -+ return LayoutType::INTERLEAVE; -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+// Construction method for GMMA Descriptors -+/////////////////////////////////////////////////////////////////////////////// -+ -+/** -+* /////////////////////////////// -+* // make_gmma_desc // -+* /////////////////////////////// -+* Each GmmaDescriptor Major-MN describes a canonical layout of the form -+* -+* LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((T,1,m),(8,k)):((1,T,SBO),(1T,LBO)) -+* LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((T,2,m),(8,k)):((1,T,LBO),(2T,SBO)) -+* LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((T,4,m),(8,k)):((1,T,LBO),(4T,SBO)) -+* LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((T,8,m),(8,k)):((1,T,LBO),(8T,SBO)) -+* -+* where -+* T : sizeof(uint128_t) / sizeof(value_type) -+* m : integer in [1,16] corresponding to GMMA shape -+* k : integer in [1,32] corresponding to GMMA shape -+* SBO: stride byte offset -+* LBO: leading byte offset -+* -+* See GMMA::Layout_MN_XXX_Atom for building canonical GmmaDescriptor Major-MN layouts. -+* For example, -+* auto smem_layout = tile_to_shape(Layout_MN_SW128_Atom{}, Shape<_128,_64>{}); -+* is guaranteed to be accepted by make_gmma_desc for appropriate value_type. -+* -+* ////////////////////////////// -+* // make_gmma_desc // -+* ////////////////////////////// -+* Each GmmaDescriptor Major-K describes a canonical layout of the form -+* -+* LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((8,m),(T,2)):((1T,SBO),(1,LBO)) -+* LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((8,m),(T,2)):((2T,SBO),(1, T )) -+* LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((8,m),(T,2)):((4T,SBO),(1, T )) -+* LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((8,m),(T,2)):((8T,SBO),(1, T )) -+* -+* See GMMA::Layout_K_XXX_Atom for building canonical GmmaDescriptor Major-K layouts. -+* For example, -+* auto smem_layout = tile_to_shape(Layout_K_SW128_Atom{}, Shape<_128,_64>{}); -+* is guaranteed to be accepted by make_gmma_desc for appropriate value_type. -+*/ -+template -+CUTE_HOST_DEVICE constexpr -+GmmaDescriptor -+make_gmma_desc(Tensor const& tensor) -+{ -+ static_assert(is_smem::value, "GMMA Descriptors can only be constructed on smem."); -+ static_assert(TLayout::rank == 2, "GMMA Descriptors can only be constructed on rank-2 tensors."); -+ using value_type = typename TEngine::value_type; -+ -+ Tensor u128_tensor = recast(tensor); -+ -+ // Result -+ GmmaDescriptor desc; -+ -+ // Layout type -+ constexpr GMMA::LayoutType LAYOUT_TYPE = GMMA::layout_type(u128_tensor); -+ desc.layout_type_ = uint8_t(LAYOUT_TYPE); -+ -+ // Start address (4LSB not included) -+ uint32_t start_address = cast_smem_ptr_to_uint(u128_tensor.data().get()); -+ desc.start_address_ = start_address >> 4; -+ -+ constexpr uint8_t base_offset = 0; -+ desc.base_offset_ = base_offset; -+ -+ // LayoutType meta -+ constexpr int W = LAYOUT_TYPE == GMMA::LayoutType::INTERLEAVE ? 1 : -+ LAYOUT_TYPE == GMMA::LayoutType::B32 ? 2 : -+ LAYOUT_TYPE == GMMA::LayoutType::B64 ? 4 : -+ LAYOUT_TYPE == GMMA::LayoutType::B128 ? 8 : -1; -+ -+ if constexpr (MajorMode == GMMA::Major::MN) -+ { -+ /* In units of uint128_t, each GmmaDescriptor Major-MN describes a canonical layout of the form -+ * -+ * LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((1,n),(8,k)):((X,SBO),(1,LBO)) -+ * LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((2,n),(8,k)):((1,LBO),(2,SBO)) -+ * LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((4,n),(8,k)):((1,LBO),(4,SBO)) -+ * LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((8,n),(8,k)):((1,LBO),(8,SBO)) -+ */ -+ static_assert(size<1>(u128_tensor) == Int<(256 / cute::sizeof_bits::value)>{}, // K size -+ "Not a canonical GMMA_MN Layout: Expected K-size 256/sizeof_bits."); -+ -+ // Construct the canonical GMMA T Layout with shape ((W,n),(8,2)) -+ Layout canonical_layout = logical_divide(layout(u128_tensor), make_tile(Layout,_1>{}, Layout,_1>{})); -+ -+ // Check ranks of canonical -+ CUTE_STATIC_ASSERT_V(rank<0>(canonical_layout) == Int<2>{}, "Not a canonical GMMA_MN Layout: No flat offset mode"); -+ CUTE_STATIC_ASSERT_V(rank<1>(canonical_layout) == Int<2>{}, "Not a canonical GMMA_MN Layout: No flat offset mode"); -+ // Check canonical mode strides -+ constexpr uint32_t stride_00 = stride<0,0>(canonical_layout); -+ constexpr uint32_t expected_stride_00 = LAYOUT_TYPE == GMMA::LayoutType::INTERLEAVE ? stride<0,0>(canonical_layout) : 1; -+ static_assert(stride_00 == expected_stride_00, "Not a canonical GMMA_MN Layout: Expected stride failure."); -+ constexpr uint32_t stride_10 = stride<1,0>(canonical_layout); -+ constexpr uint32_t expected_stride_10 = W; -+ static_assert(stride_10 == expected_stride_10, "Not a canonical GMMA_MN Layout: Expected stride failure."); -+ -+ // stride dimension byte offset and leading dimension byte offset (4LSB not included == uint128_t units) -+ constexpr uint32_t stride_01 = stride<0,1>(canonical_layout); -+ constexpr uint32_t stride_11 = stride<1,1>(canonical_layout); -+ -+ desc.stride_byte_offset_ = (LAYOUT_TYPE == GMMA::LayoutType::INTERLEAVE) ? stride_01 : stride_11; -+ desc.leading_byte_offset_ = (LAYOUT_TYPE == GMMA::LayoutType::INTERLEAVE) ? stride_11 : stride_01; -+ } -+ else if constexpr (MajorMode == GMMA::Major::K) -+ { -+ /* In units of uint128_t, each GmmaDescriptor Major-K describes a canonical layout of the form -+ * -+ * LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((8,n),2):((1,SBO),LBO) -+ * LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((8,n),2):((2,SBO),1) -+ * LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((8,n),2):((4,SBO),1) -+ * LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((8,n),2):((8,SBO),1) -+ */ -+ CUTE_STATIC_ASSERT_V(size<0>(u128_tensor) % Int<8>{} == Int<0>{}, // N|M size -+ "Not a canonical GMMA_K Layout: Expected MN-size multiple of 8."); -+ CUTE_STATIC_ASSERT_V(size<1>(u128_tensor) == Int<2>{}, // K size -+ "Not a canonical GMMA_K Layout: Expected K-size 2 (in units of uint128_t)."); -+ -+ // Construct the canonical GMMA N Layout with shape ((8,n),(2,1)) -+ Layout canonical_layout = logical_divide(layout(u128_tensor), make_tile(Layout<_8,_1>{}, Layout<_2,_1>{})); -+ -+ // Check ranks of canonical -+ CUTE_STATIC_ASSERT_V(rank<0>(canonical_layout) == Int<2>{}, "Not a canonical GMMA_K Layout: No flat offset mode"); -+ CUTE_STATIC_ASSERT_V(rank<1>(canonical_layout) == Int<2>{}, "Not a canonical GMMA_K Layout: No flat offset mode"); -+ // Check canonical mode strides -+ constexpr uint32_t stride_00 = stride<0,0>(canonical_layout); -+ constexpr uint32_t expected_stride_00 = W; -+ static_assert(stride_00 == expected_stride_00, "Not a canonical GMMA_K Layout: Expected stride failure."); -+ constexpr uint32_t stride_10 = stride<1,0>(canonical_layout); -+ constexpr uint32_t expected_stride_10 = (LAYOUT_TYPE == GMMA::LayoutType::INTERLEAVE) ? stride<1,0>(canonical_layout) : 1; -+ static_assert(stride_10 == expected_stride_10, "Not a canonical GMMA_K Layout: Expected stride failure."); -+ -+ // stride dimension byte offset and leading dimension byte offset (4LSB not included == uint128_t units) -+ constexpr uint32_t stride_01 = stride<0,1>(canonical_layout); -+ -+ desc.stride_byte_offset_ = stride_01; -+ desc.leading_byte_offset_ = stride_10; -+ } else { -+ static_assert(MajorMode != GMMA::Major::MN && MajorMode != GMMA::Major::K, "Unrecognized MajorMode!"); -+ } -+ -+#if 0 -+ // DEBUG and SANITY -+ assert((start_address & 0b0000001111) == 0); // Must be 16B aligned (4LSB are 0) no negotiation -+ assert((start_address & 0b1110000000) == 0); // Assert base_offset is 0, generalize later -+ if (thread0()) { -+ print("smem_desc input tensor: "); print(tensor.data()); print(" o "); print(tensor.layout()); print("\n"); -+ print("smem_desc uint128_t tensor: "); print(u128_tensor.data()); print(" o "); print(u128_tensor.layout()); print("\n"); -+ //print(" desc canonical layout: "); print(canonical_layout); print("\n"); -+ print(desc); -+ } -+#endif -+ -+ return desc; -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+// Higher level GMMA Descriptor utilities -+/////////////////////////////////////////////////////////////////////////////// -+ -+struct gmma_descriptor_iterator -+{ -+ GmmaDescriptor desc_; -+ -+ // Dereference returns the GmmaDescriptor -+ CUTE_HOST_DEVICE constexpr -+ GmmaDescriptor const& operator*() const { return desc_; } -+ -+ // Advance and return a new GmmaDescriptor -+ template -+ CUTE_HOST_DEVICE constexpr -+ GmmaDescriptor operator[](Index const& i) const { return *(*this + i); } -+ -+ // Return an advanced iterator -+ template -+ CUTE_HOST_DEVICE constexpr -+ gmma_descriptor_iterator operator+(Index const& offset) const -+ { -+ // offset is in the units of uint128_t (4LSB of start_address not included) -+ -+ //GmmaDescriptor desc = desc_; -+ //desc.start_address_ += uint16_t(offset); -+ //desc.reg32_[0] += uint16_t(offset); // Generates better asm than adding to the bitfield -+ -+ // May need to update base_offset if swizzle alignment isn't guaranteed -+ //desc.base_offset_ = 0; -+ //assert((desc.start_address_ & 0b111000) == 0); // Assert base_offset is 0, generalize later -+ -+ //return {desc}; -+ -+ // The above seems to not work for some reason... -+ return {desc_ + uint64_t(offset)}; -+ } -+}; -+ -+template -+struct smem_desc : gmma_descriptor_iterator {}; -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_gmma_desc_fragment(Tensor const& t) -+{ -+ // Cast to a uint128_t tensor for GMMA Desc iteration -+ return make_tensor(gmma_descriptor_iterator{make_gmma_desc(tensor<0>(t))}, -+ recast(t).layout()); -+} -+ -+// Recast a tensor to a tensor of gmma_descriptor_iterator -+template -+CUTE_HOST_DEVICE constexpr -+auto -+recast(Tensor&& tensor, type_list>) -+{ -+ return make_gmma_desc_fragment(tensor); -+} -+ -+// Recast a gmma_descriptor_iterator Tensor to uint64_t, it's RegType -+template -+CUTE_HOST_DEVICE constexpr -+auto -+recast(Tensor,TLayout> const& tensor, type_list) -+{ -+ static_assert(std::is_same::value, "Can only cast descriptors to uint64_t."); -+ return make_tensor(tensor.data(), Layout<_1,_0>{}); -+} -+ -+} // end namespace GMMA -+ -+// Fence between the async destination accumulators of GMMA & source for their dependent use -+template -+CUTE_HOST_DEVICE -+void -+warpgroup_fence_operand(Tensor& frg) { -+ CUTE_STATIC_ASSERT(is_static::value); -+ if constexpr (std::is_same_v) { -+ auto f32_frg = recast(frg); -+ CUTE_UNROLL -+ for (int i = 0; i < size(f32_frg); ++i) { -+ warpgroup_fence_operand(f32_frg(i)); -+ } -+ } -+ else { -+ CUTE_STATIC_ASSERT(is_rmem::value); -+ auto u32_frg = recast(frg); -+ CUTE_UNROLL -+ for (int i = 0; i < size(u32_frg); ++i) { -+ warpgroup_fence_operand(u32_frg(i)); -+ } -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+//////////////////////////// MMA_TRAITS /////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+namespace GMMA { -+ -+// Accumulator layouts -+using CLayout_64x8 = Layout,Shape < _2,_2>>, -+ Stride,Stride<_64,_8>>>; -+ -+using CLayout_64x16 = Layout,Shape < _2,_2, _2>>, -+ Stride,Stride<_64,_8,_512>>>; -+ -+using CLayout_64x32 = Layout,Shape < _2,_2, _4>>, -+ Stride,Stride<_64,_8,_512>>>; -+ -+using CLayout_64x64 = Layout,Shape < _2,_2, _8>>, -+ Stride,Stride<_64,_8,_512>>>; -+ -+using CLayout_64x96 = Layout,Shape < _2,_2, _12>>, -+ Stride,Stride<_64,_8,_512>>>; -+ -+using CLayout_64x128 = Layout,Shape < _2,_2, _16>>, -+ Stride,Stride<_64,_8,_512>>>; -+ -+using CLayout_64x192 = Layout,Shape < _2,_2, _24>>, -+ Stride,Stride<_64,_8,_512>>>; -+ -+using CLayout_64x256 = Layout,Shape < _2,_2, _32>>, -+ Stride,Stride<_64,_8,_512>>>; -+ -+// Register source layout for 32-bit value types -+using ALayout_64x8 = Layout,Shape < _2, _2>>, -+ Stride,Stride< _8,_256>>>; -+ -+// Register source layout for 16-bit value types -+using ALayout_64x16 = CLayout_64x16; -+ -+// Register source layout for 8-bit value types -+using ALayout_64x32 = Layout,Shape < _4,_2, _2>>, -+ Stride,Stride<_64,_8,_1024>>>; -+ -+// Shared memory source layouts for any value type -+template -+using ABLayout = Layout,Int>>, -+ Stride< _0,Stride< _1,Int>>>; -+ -+} // namespace GMMA -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_8,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout< 8, 16>; -+ using CLayout = GMMA::CLayout_64x8; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_8,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout< 8, 16>; -+ using CLayout = GMMA::CLayout_64x8; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_16,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout< 16, 16>; -+ using CLayout = GMMA::CLayout_64x16; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_16,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout< 16, 16>; -+ using CLayout = GMMA::CLayout_64x16; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_32,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout< 32, 16>; -+ using CLayout = GMMA::CLayout_64x32; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_32,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout< 32, 16>; -+ using CLayout = GMMA::CLayout_64x32; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_64,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout< 64, 16>; -+ using CLayout = GMMA::CLayout_64x64; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_64,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout< 64, 16>; -+ using CLayout = GMMA::CLayout_64x64; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_96,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout< 96, 16>; -+ using CLayout = GMMA::CLayout_64x96; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_96,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout< 96, 16>; -+ using CLayout = GMMA::CLayout_64x96; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_128,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout<128, 16>; -+ using CLayout = GMMA::CLayout_64x128; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_128,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout<128, 16>; -+ using CLayout = GMMA::CLayout_64x128; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_192,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout<192, 16>; -+ using CLayout = GMMA::CLayout_64x192; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_192,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout<192, 16>; -+ using CLayout = GMMA::CLayout_64x192; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_256,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout<256, 16>; -+ using CLayout = GMMA::CLayout_64x256; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = half_t; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = half_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_256,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout<256, 16>; -+ using CLayout = GMMA::CLayout_64x256; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_8,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout< 8, 16>; -+ using CLayout = GMMA::CLayout_64x8; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_8,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout< 8, 16>; -+ using CLayout = GMMA::CLayout_64x8; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_16,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout< 16, 16>; -+ using CLayout = GMMA::CLayout_64x16; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_16,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout< 16, 16>; -+ using CLayout = GMMA::CLayout_64x16; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_32,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout< 32, 16>; -+ using CLayout = GMMA::CLayout_64x32; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_32,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout< 32, 16>; -+ using CLayout = GMMA::CLayout_64x32; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_64,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout< 64, 16>; -+ using CLayout = GMMA::CLayout_64x64; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_64,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout< 64, 16>; -+ using CLayout = GMMA::CLayout_64x64; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_96,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout< 96, 16>; -+ using CLayout = GMMA::CLayout_64x96; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_96,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout< 96, 16>; -+ using CLayout = GMMA::CLayout_64x96; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_128,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout<128, 16>; -+ using CLayout = GMMA::CLayout_64x128; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_128,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout<128, 16>; -+ using CLayout = GMMA::CLayout_64x128; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_192,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout<192, 16>; -+ using CLayout = GMMA::CLayout_64x192; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_192,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout<192, 16>; -+ using CLayout = GMMA::CLayout_64x192; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_256,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout<256, 16>; -+ using CLayout = GMMA::CLayout_64x256; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = half_t; -+ using ElementBVal = half_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_256,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout<256, 16>; -+ using CLayout = GMMA::CLayout_64x256; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = bfloat16_t; -+ using ElementBVal = bfloat16_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_8,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout< 8, 16>; -+ using CLayout = GMMA::CLayout_64x8; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = bfloat16_t; -+ using ElementBVal = bfloat16_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_8,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout< 8, 16>; -+ using CLayout = GMMA::CLayout_64x8; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = bfloat16_t; -+ using ElementBVal = bfloat16_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_16,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout< 16, 16>; -+ using CLayout = GMMA::CLayout_64x16; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = bfloat16_t; -+ using ElementBVal = bfloat16_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_16,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout< 16, 16>; -+ using CLayout = GMMA::CLayout_64x16; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = bfloat16_t; -+ using ElementBVal = bfloat16_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_32,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout< 32, 16>; -+ using CLayout = GMMA::CLayout_64x32; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = bfloat16_t; -+ using ElementBVal = bfloat16_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_32,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout< 32, 16>; -+ using CLayout = GMMA::CLayout_64x32; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = bfloat16_t; -+ using ElementBVal = bfloat16_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_64,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout< 64, 16>; -+ using CLayout = GMMA::CLayout_64x64; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = bfloat16_t; -+ using ElementBVal = bfloat16_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_64,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout< 64, 16>; -+ using CLayout = GMMA::CLayout_64x64; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = bfloat16_t; -+ using ElementBVal = bfloat16_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_96,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout< 96, 16>; -+ using CLayout = GMMA::CLayout_64x96; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = bfloat16_t; -+ using ElementBVal = bfloat16_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_96,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout< 96, 16>; -+ using CLayout = GMMA::CLayout_64x96; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = bfloat16_t; -+ using ElementBVal = bfloat16_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_128,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout<128, 16>; -+ using CLayout = GMMA::CLayout_64x128; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = bfloat16_t; -+ using ElementBVal = bfloat16_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_128,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout<128, 16>; -+ using CLayout = GMMA::CLayout_64x128; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = bfloat16_t; -+ using ElementBVal = bfloat16_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_192,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout<192, 16>; -+ using CLayout = GMMA::CLayout_64x192; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = bfloat16_t; -+ using ElementBVal = bfloat16_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_192,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout<192, 16>; -+ using CLayout = GMMA::CLayout_64x192; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = bfloat16_t; -+ using ElementBVal = bfloat16_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_256,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 16>; -+ using BLayout = GMMA::ABLayout<256, 16>; -+ using CLayout = GMMA::CLayout_64x256; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = bfloat16_t; -+ using ElementBVal = bfloat16_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_256,_16>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x16; -+ using BLayout = GMMA::ABLayout<256, 16>; -+ using CLayout = GMMA::CLayout_64x256; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = tfloat32_t; -+ using ElementBVal = tfloat32_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_8,_8>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 8>; -+ using BLayout = GMMA::ABLayout< 8, 8>; -+ using CLayout = GMMA::CLayout_64x8; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = tfloat32_t; -+ using ElementBVal = tfloat32_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_8,_8>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x8; -+ using BLayout = GMMA::ABLayout< 8, 8>; -+ using CLayout = GMMA::CLayout_64x8; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = tfloat32_t; -+ using ElementBVal = tfloat32_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_16,_8>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 8>; -+ using BLayout = GMMA::ABLayout< 16, 8>; -+ using CLayout = GMMA::CLayout_64x16; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = tfloat32_t; -+ using ElementBVal = tfloat32_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_16,_8>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x8; -+ using BLayout = GMMA::ABLayout< 16, 8>; -+ using CLayout = GMMA::CLayout_64x16; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = tfloat32_t; -+ using ElementBVal = tfloat32_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_32,_8>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 8>; -+ using BLayout = GMMA::ABLayout< 32, 8>; -+ using CLayout = GMMA::CLayout_64x32; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = tfloat32_t; -+ using ElementBVal = tfloat32_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_32,_8>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x8; -+ using BLayout = GMMA::ABLayout< 32, 8>; -+ using CLayout = GMMA::CLayout_64x32; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = tfloat32_t; -+ using ElementBVal = tfloat32_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_64,_8>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 8>; -+ using BLayout = GMMA::ABLayout< 64, 8>; -+ using CLayout = GMMA::CLayout_64x64; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = tfloat32_t; -+ using ElementBVal = tfloat32_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_64,_8>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x8; -+ using BLayout = GMMA::ABLayout< 64, 8>; -+ using CLayout = GMMA::CLayout_64x64; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = tfloat32_t; -+ using ElementBVal = tfloat32_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_96,_8>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 8>; -+ using BLayout = GMMA::ABLayout< 96, 8>; -+ using CLayout = GMMA::CLayout_64x96; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = tfloat32_t; -+ using ElementBVal = tfloat32_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_96,_8>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x8; -+ using BLayout = GMMA::ABLayout< 96, 8>; -+ using CLayout = GMMA::CLayout_64x96; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = tfloat32_t; -+ using ElementBVal = tfloat32_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_128,_8>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 8>; -+ using BLayout = GMMA::ABLayout<128, 8>; -+ using CLayout = GMMA::CLayout_64x128; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = tfloat32_t; -+ using ElementBVal = tfloat32_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_128,_8>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x8; -+ using BLayout = GMMA::ABLayout<128, 8>; -+ using CLayout = GMMA::CLayout_64x128; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = tfloat32_t; -+ using ElementBVal = tfloat32_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_192,_8>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 8>; -+ using BLayout = GMMA::ABLayout<192, 8>; -+ using CLayout = GMMA::CLayout_64x192; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = tfloat32_t; -+ using ElementBVal = tfloat32_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_192,_8>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x8; -+ using BLayout = GMMA::ABLayout<192, 8>; -+ using CLayout = GMMA::CLayout_64x192; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = tfloat32_t; -+ using ElementBVal = tfloat32_t; -+ using ElementCVal = float; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_256,_8>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 8>; -+ using BLayout = GMMA::ABLayout<256, 8>; -+ using CLayout = GMMA::CLayout_64x256; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = float; -+ using ElementAVal = tfloat32_t; -+ using ElementBVal = tfloat32_t; -+ using ElementCVal = float; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_256,_8>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x8; -+ using BLayout = GMMA::ABLayout<256, 8>; -+ using CLayout = GMMA::CLayout_64x256; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_8,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 8, 32>; -+ using CLayout = GMMA::CLayout_64x8; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_16,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 16, 32>; -+ using CLayout = GMMA::CLayout_64x16; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_32,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 32, 32>; -+ using CLayout = GMMA::CLayout_64x32; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_64,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 64, 32>; -+ using CLayout = GMMA::CLayout_64x64; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_96,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 96, 32>; -+ using CLayout = GMMA::CLayout_64x96; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_128,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout<128, 32>; -+ using CLayout = GMMA::CLayout_64x128; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_192,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout<192, 32>; -+ using CLayout = GMMA::CLayout_64x192; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_256,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout<256, 32>; -+ using CLayout = GMMA::CLayout_64x256; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_8,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 8, 32>; -+ using CLayout = GMMA::CLayout_64x8; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_16,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 16, 32>; -+ using CLayout = GMMA::CLayout_64x16; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_32,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 32, 32>; -+ using CLayout = GMMA::CLayout_64x32; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_64,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 64, 32>; -+ using CLayout = GMMA::CLayout_64x64; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_96,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 96, 32>; -+ using CLayout = GMMA::CLayout_64x96; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_128,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout<128, 32>; -+ using CLayout = GMMA::CLayout_64x128; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_192,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout<192, 32>; -+ using CLayout = GMMA::CLayout_64x192; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_256,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout<256, 32>; -+ using CLayout = GMMA::CLayout_64x256; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_8,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 8, 32>; -+ using CLayout = GMMA::CLayout_64x8; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_16,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 16, 32>; -+ using CLayout = GMMA::CLayout_64x16; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_32,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 32, 32>; -+ using CLayout = GMMA::CLayout_64x32; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_64,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 64, 32>; -+ using CLayout = GMMA::CLayout_64x64; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_96,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 96, 32>; -+ using CLayout = GMMA::CLayout_64x96; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_128,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout<128, 32>; -+ using CLayout = GMMA::CLayout_64x128; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_192,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout<192, 32>; -+ using CLayout = GMMA::CLayout_64x192; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_256,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout<256, 32>; -+ using CLayout = GMMA::CLayout_64x256; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_8,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 8, 32>; -+ using CLayout = GMMA::CLayout_64x8; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_16,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 16, 32>; -+ using CLayout = GMMA::CLayout_64x16; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_32,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 32, 32>; -+ using CLayout = GMMA::CLayout_64x32; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_64,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 64, 32>; -+ using CLayout = GMMA::CLayout_64x64; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_96,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 96, 32>; -+ using CLayout = GMMA::CLayout_64x96; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_128,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout<128, 32>; -+ using CLayout = GMMA::CLayout_64x128; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_192,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout<192, 32>; -+ using CLayout = GMMA::CLayout_64x192; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = int8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_256,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout<256, 32>; -+ using CLayout = GMMA::CLayout_64x256; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_8,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 8, 32>; -+ using CLayout = GMMA::CLayout_64x8; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_16,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 16, 32>; -+ using CLayout = GMMA::CLayout_64x16; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_32,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 32, 32>; -+ using CLayout = GMMA::CLayout_64x32; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_64,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 64, 32>; -+ using CLayout = GMMA::CLayout_64x64; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_96,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 96, 32>; -+ using CLayout = GMMA::CLayout_64x96; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_128,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout<128, 32>; -+ using CLayout = GMMA::CLayout_64x128; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_192,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout<192, 32>; -+ using CLayout = GMMA::CLayout_64x192; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_256,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout<256, 32>; -+ using CLayout = GMMA::CLayout_64x256; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_8,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 8, 32>; -+ using CLayout = GMMA::CLayout_64x8; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_16,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 16, 32>; -+ using CLayout = GMMA::CLayout_64x16; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_32,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 32, 32>; -+ using CLayout = GMMA::CLayout_64x32; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_64,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 64, 32>; -+ using CLayout = GMMA::CLayout_64x64; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_96,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 96, 32>; -+ using CLayout = GMMA::CLayout_64x96; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_128,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout<128, 32>; -+ using CLayout = GMMA::CLayout_64x128; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_192,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout<192, 32>; -+ using CLayout = GMMA::CLayout_64x192; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = int8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_256,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout<256, 32>; -+ using CLayout = GMMA::CLayout_64x256; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_8,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 8, 32>; -+ using CLayout = GMMA::CLayout_64x8; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_16,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 16, 32>; -+ using CLayout = GMMA::CLayout_64x16; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_32,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 32, 32>; -+ using CLayout = GMMA::CLayout_64x32; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_64,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 64, 32>; -+ using CLayout = GMMA::CLayout_64x64; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_96,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout< 96, 32>; -+ using CLayout = GMMA::CLayout_64x96; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_128,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout<128, 32>; -+ using CLayout = GMMA::CLayout_64x128; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_192,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout<192, 32>; -+ using CLayout = GMMA::CLayout_64x192; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementAFrg = GMMA::smem_desc; -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_256,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ABLayout< 64, 32>; -+ using BLayout = GMMA::ABLayout<256, 32>; -+ using CLayout = GMMA::CLayout_64x256; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_8,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 8, 32>; -+ using CLayout = GMMA::CLayout_64x8; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_16,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 16, 32>; -+ using CLayout = GMMA::CLayout_64x16; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_32,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 32, 32>; -+ using CLayout = GMMA::CLayout_64x32; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_64,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 64, 32>; -+ using CLayout = GMMA::CLayout_64x64; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_96,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout< 96, 32>; -+ using CLayout = GMMA::CLayout_64x96; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_128,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout<128, 32>; -+ using CLayout = GMMA::CLayout_64x128; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_192,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout<192, 32>; -+ using CLayout = GMMA::CLayout_64x192; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MMA_Traits> -+{ -+ using ElementDVal = int32_t; -+ using ElementAVal = uint8_t; -+ using ElementBVal = uint8_t; -+ using ElementCVal = int32_t; -+ -+ using ElementBFrg = GMMA::smem_desc; -+ -+ using Shape_MNK = Shape<_64,_256,_32>; -+ using ThrID = Layout<_128>; -+ using ALayout = GMMA::ALayout_64x32; -+ using BLayout = GMMA::ABLayout<256, 32>; -+ using CLayout = GMMA::CLayout_64x256; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/config.hpp b/3rdparty/cutlass/include/cute/config.hpp -new file mode 100644 -index 0000000..b2f4de8 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/config.hpp -@@ -0,0 +1,121 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#if defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA) -+# define CUTE_HOST_DEVICE __forceinline__ __host__ __device__ -+# define CUTE_DEVICE __forceinline__ __device__ -+# define CUTE_HOST __forceinline__ __host__ -+#else -+# define CUTE_HOST_DEVICE inline -+# define CUTE_DEVICE inline -+# define CUTE_HOST inline -+#endif // CUTE_HOST_DEVICE, CUTE_DEVICE -+ -+#if defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA) -+# define CUTE_UNROLL #pragma unroll -+# define CUTE_NO_UNROLL #pragma unroll 1 -+#else -+# define CUTE_UNROLL -+# define CUTE_NO_UNROLL -+#endif // CUTE_UNROLL -+ -+#if defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA) -+# define CUTE_INLINE_CONSTANT static const __device__ -+#else -+# define CUTE_INLINE_CONSTANT static constexpr -+#endif -+ -+// Some versions of GCC < 11 have trouble deducing that a -+// function with "auto" return type and all of its returns in an "if -+// constexpr ... else" statement must actually return. Thus, GCC -+// emits spurious "missing return statement" build warnings. -+// Developers can suppress these warnings by using the -+// CUTE_GCC_UNREACHABLE macro, which must be followed by a semicolon. -+// It's harmless to use the macro for other GCC versions or other -+// compilers, but it has no effect. -+#if ! defined(CUTE_GCC_UNREACHABLE) -+# if defined(__GNUC__) && __GNUC__ < 11 -+ // GCC 10, but not 7.5, 9.4.0, or 11, issues "missing return -+ // statement" warnings without this little bit of help. -+# define CUTE_GCC_UNREACHABLE __builtin_unreachable() -+# else -+# define CUTE_GCC_UNREACHABLE -+# endif -+#endif -+ -+// -+// Assertion helpers -+// -+ -+#include -+ -+#define CUTE_STATIC_ASSERT static_assert -+#define CUTE_STATIC_ASSERT_V(x,...) static_assert(decltype(x)::value, ##__VA_ARGS__) -+ -+#if defined(__CUDA_ARCH__) -+# define CUTE_RUNTIME_ASSERT(x) asm volatile ("brkpt;\n" ::: "memory") -+#else -+# define CUTE_RUNTIME_ASSERT(x) assert(0 && x) -+#endif -+ -+// -+// IO -+// -+ -+#include -+#include -+#include -+ -+// -+// Support -+// -+ -+#include -+ -+// -+// Basic types -+// -+ -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+ -+// -+// Debugging utilities -+// -+ -+#include -+#include -diff --git a/3rdparty/cutlass/include/cute/container/alignment.hpp b/3rdparty/cutlass/include/cute/container/alignment.hpp -new file mode 100644 -index 0000000..49101fa ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/container/alignment.hpp -@@ -0,0 +1,70 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+ -+namespace cute -+{ -+ -+// Test if a pointer is aligned to N bytes -+template -+CUTE_HOST_DEVICE constexpr -+bool -+is_byte_aligned(void const* const ptr) -+{ -+ static_assert(N > 0 && (N & (N - 1)) == 0, "N must be a power of 2 in alignment check"); -+ return (reinterpret_cast(ptr) & (N-1)) == 0; -+} -+ -+#if defined(__CUDACC__) -+# define CUTE_ALIGNAS(n) __align__(n) -+#else -+# define CUTE_ALIGNAS(n) alignas(n) -+#endif -+ -+template -+struct aligned_struct {}; -+ -+template <> struct CUTE_ALIGNAS( 1) aligned_struct< 1> {}; -+template <> struct CUTE_ALIGNAS( 2) aligned_struct< 2> {}; -+template <> struct CUTE_ALIGNAS( 4) aligned_struct< 4> {}; -+template <> struct CUTE_ALIGNAS( 8) aligned_struct< 8> {}; -+template <> struct CUTE_ALIGNAS( 16) aligned_struct< 16> {}; -+template <> struct CUTE_ALIGNAS( 32) aligned_struct< 32> {}; -+template <> struct CUTE_ALIGNAS( 64) aligned_struct< 64> {}; -+template <> struct CUTE_ALIGNAS(128) aligned_struct<128> {}; -+template <> struct CUTE_ALIGNAS(256) aligned_struct<256> {}; -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/container/array.hpp b/3rdparty/cutlass/include/cute/container/array.hpp -new file mode 100644 -index 0000000..571ac08 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/container/array.hpp -@@ -0,0 +1,282 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+#include -+ -+#include -+ -+namespace cute -+{ -+ -+template -+struct array -+{ -+ using value_type = T; -+ using size_type = std::size_t; -+ using difference_type = std::ptrdiff_t; -+ using reference = value_type&; -+ using const_reference = const value_type&; -+ using pointer = value_type*; -+ using const_pointer = const value_type*; -+ using iterator = pointer; -+ using const_iterator = const_pointer; -+ -+ CUTE_HOST_DEVICE constexpr -+ reference operator[](size_type pos) -+ { -+ return begin()[pos]; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_reference operator[](size_type pos) const -+ { -+ return begin()[pos]; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ reference front() -+ { -+ return *begin(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_reference front() const -+ { -+ return *begin(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ reference back() -+ { -+ // return *rbegin(); -+ return operator[](N-1); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_reference back() const -+ { -+ // return *rbegin(); -+ return operator[](N-1); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ T* data() -+ { -+ return __elems_; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ T const* data() const -+ { -+ return __elems_; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ iterator begin() -+ { -+ return data(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator begin() const -+ { -+ return data(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator cbegin() -+ { -+ return begin(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator cbegin() const -+ { -+ return begin(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ iterator end() -+ { -+ return data() + size(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator end() const -+ { -+ return data() + size(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator cend() -+ { -+ return end(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator cend() const -+ { -+ return end(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ bool empty() const -+ { -+ return size() == 0; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ size_type size() const -+ { -+ return N; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ size_type max_size() const -+ { -+ return size(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ void fill(const T& value) -+ { -+ for (auto& e : *this) { -+ e = value; -+ } -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ void clear() -+ { -+ fill(T(0)); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ void swap(array& other) -+ { -+ using std::swap; -+ for (size_type i = 0; i < size(); ++i) { -+ swap((*this)[i], other[i]); -+ } -+ } -+ -+ value_type __elems_[N > 0 ? N : 1]; -+}; -+ -+ -+template -+CUTE_HOST_DEVICE constexpr -+bool operator==(array const& lhs, array const& rhs) -+{ -+ for (std::size_t i = 0; i < N; ++i) { -+ if (lhs[i] != rhs[i]) { -+ return false; -+ } -+ } -+ return true; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+void clear(array& a) -+{ -+ a.fill(T(0)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+void fill(array& a, T const& value) -+{ -+ a.fill(value); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+void swap(array& a, array& b) -+{ -+ a.swap(b); -+} -+ -+} // end cute -+ -+ -+// -+// Specialize tuple-related functionality for cute::array -+// -+ -+#include -+ -+namespace cute -+{ -+ -+template -+CUTE_HOST_DEVICE constexpr -+T& get(array& a) -+{ -+ static_assert(I < N, "Index out of range"); -+ return a[I]; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+T const& get(array const& a) -+{ -+ static_assert(I < N, "Index out of range"); -+ return a[I]; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+T&& get(array&& a) -+{ -+ static_assert(I < N, "Index out of range"); -+ return std::move(a[I]); -+} -+ -+} // end namespace cute -+ -+namespace std -+{ -+ -+template -+struct tuple_size> -+ : std::integral_constant -+{}; -+ -+template -+struct tuple_element> -+{ -+ using type = T; -+}; -+ -+} // end std -diff --git a/3rdparty/cutlass/include/cute/container/array_aligned.hpp b/3rdparty/cutlass/include/cute/container/array_aligned.hpp -new file mode 100644 -index 0000000..b1b3572 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/container/array_aligned.hpp -@@ -0,0 +1,276 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+#include -+ -+namespace cute -+{ -+ -+template -+struct array_aligned -+ : public aligned_struct -+{ -+ /// Make sure the Alignment makes sense wrt the size of elements. -+ static_assert(Alignment == 16 || Alignment >= sizeof(T), "Alignment is too small"); -+ /// Alignment must be a power of two -+ static_assert(has_single_bit(Alignment), "Alignment must be a power of two"); -+ -+ using value_type = T; -+ using size_type = std::size_t; -+ using difference_type = std::ptrdiff_t; -+ using reference = value_type&; -+ using const_reference = const value_type&; -+ using pointer = value_type*; -+ using const_pointer = const value_type*; -+ using iterator = pointer; -+ using const_iterator = const_pointer; -+ -+ CUTE_HOST_DEVICE constexpr -+ reference operator[](size_type pos) -+ { -+ return begin()[pos]; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_reference operator[](size_type pos) const -+ { -+ return begin()[pos]; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ reference front() -+ { -+ return *begin(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_reference front() const -+ { -+ return *begin(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ reference back() -+ { -+ // return *rbegin(); -+ return operator[](N-1); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_reference back() const -+ { -+ // return *rbegin(); -+ return operator[](N-1); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ T* data() -+ { -+ return reinterpret_cast(storage); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ T const* data() const -+ { -+ return reinterpret_cast(storage); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ iterator begin() -+ { -+ return data(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator begin() const -+ { -+ return data(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator cbegin() -+ { -+ return begin(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator cbegin() const -+ { -+ return begin(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ iterator end() -+ { -+ return data() + size(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator end() const -+ { -+ return data() + size(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator cend() -+ { -+ return end(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator cend() const -+ { -+ return end(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ bool empty() const -+ { -+ return size() == 0; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ size_type size() const -+ { -+ return N; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ size_type max_size() const -+ { -+ return size(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ void fill(T const& value) -+ { -+ for (auto& e : *this) { -+ e = value; -+ } -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ void clear() -+ { -+ fill(T(0)); -+ } -+ -+ // Not private, we want trivial type -+ //private: -+ -+ /// Storage type to use for Elements -+ using StorageType = typename uint_byte(Alignment)>::type; -+ -+ /// Ensure that there's enough storage for all elements -+ static_assert(sizeof(StorageType) <= Alignment, "StorageType is too big for given alignment"); -+ -+ /// Number of elements in the storage -+ static constexpr std::size_t storageN = (sizeof(T)*N + sizeof(StorageType) - 1) / sizeof(StorageType); -+ -+ /// The storage. -+ StorageType storage[storageN > 0 ? storageN : 1]; -+}; -+ -+// -+// Operators -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+void clear(array_aligned& a) -+{ -+ a.clear(); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+void fill(array_aligned& a, T const& value) -+{ -+ a.fill(value); -+} -+ -+} // end namespace cute -+ -+// -+// Specialize tuple-related functionality for cute::array -+// -+ -+#include -+ -+namespace cute -+{ -+ -+template -+CUTE_HOST_DEVICE constexpr -+T& get(array_aligned& a) -+{ -+ static_assert(I < N, "Index out of range"); -+ return a[I]; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+T const& get(array_aligned const& a) -+{ -+ static_assert(I < N, "Index out of range"); -+ return a[I]; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+T&& get(array_aligned&& a) -+{ -+ static_assert(I < N, "Index out of range"); -+ return std::move(a[I]); -+} -+ -+} // end namespace cute -+ -+namespace std -+{ -+ -+template -+struct tuple_size> -+ : std::integral_constant -+{}; -+ -+template -+struct tuple_element> -+{ -+ using type = T; -+}; -+ -+} // end std -diff --git a/3rdparty/cutlass/include/cute/container/array_subbyte.hpp b/3rdparty/cutlass/include/cute/container/array_subbyte.hpp -new file mode 100644 -index 0000000..a217a67 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/container/array_subbyte.hpp -@@ -0,0 +1,613 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Statically sized array of elements that accommodates subbyte trivial types -+ in a packed storage. -+*/ -+ -+#pragma once -+ -+#include -+ -+#include // sizeof_bits -+ -+namespace cute -+{ -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Statically sized array for any data type -+template -+class array_subbyte -+{ -+ public: -+ -+ /// Number of total bits in the array -+ static constexpr int kSizeBits = sizeof_bits::value * N; -+ -+ /// Storage type -+ using Storage = typename std::conditional< -+ (kSizeBits % 32) == 0, -+ uint32_t, -+ typename std::conditional< -+ (kSizeBits % 16) == 0, -+ uint16_t, -+ uint8_t -+ >::type -+ >::type; -+ -+ -+ /// Number of logical elements per stored object -+ static constexpr int kElementsPerStoredItem = sizeof_bits::value / sizeof_bits::value; -+ -+ /// Number of storage elements -+ static constexpr std::size_t kStorageElements = (N + kElementsPerStoredItem - 1) / kElementsPerStoredItem; -+ -+ /// Bitmask for covering one item -+ static constexpr Storage bit_mask_ = ((Storage(1) << sizeof_bits::value) - 1); -+ -+ // -+ // C++ standard members with reference and iterator types omitted -+ // -+ -+ using value_type = T; -+ using pointer = value_type*; -+ using const_pointer = value_type const*; -+ -+ using size_type = std::size_t; -+ using difference_type = std::ptrdiff_t; -+ -+ // -+ // References -+ // -+ -+ /// Reference object inserts or extracts sub-byte items -+ class reference { -+ /// Pointer to storage element -+ Storage* ptr_; -+ -+ /// Index into elements packed into Storage object -+ int idx_; -+ -+ public: -+ -+ /// Default ctor -+ CUTE_HOST_DEVICE constexpr -+ reference() : ptr_(nullptr), idx_(0) {} -+ -+ /// Ctor -+ CUTE_HOST_DEVICE constexpr -+ reference(Storage* ptr, int idx = 0) : ptr_(ptr), idx_(idx) {} -+ -+ /// Assignment -+ CUTE_HOST_DEVICE constexpr -+ reference& operator=(T x) { -+ Storage item = (reinterpret_cast(x) & bit_mask_); -+ Storage kUpdateMask = Storage(~(bit_mask_ << (idx_ * sizeof_bits::value))); -+ *ptr_ = Storage((*ptr_ & kUpdateMask) | (item << (idx_ * sizeof_bits::value))); -+ return *this; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ T get() const { -+ Storage item = Storage((*ptr_ >> (idx_ * sizeof_bits::value)) & bit_mask_); -+ return reinterpret_cast(item); -+ } -+ -+ /// Extract to type T -- disable if T == bool -+ template ::value)> -+ CUTE_HOST_DEVICE constexpr -+ operator T() const { -+ return get(); -+ } -+ -+ // Extract to bool -- potentially faster impl -+ CUTE_HOST_DEVICE constexpr -+ operator bool() const { -+ return bool((*ptr_) & (bit_mask_ << (idx_ * sizeof_bits::value))); -+ } -+ -+ /// Explicit cast to int -+ CUTE_HOST_DEVICE constexpr -+ explicit operator int() const { -+ return int(get()); -+ } -+ -+ /// Explicit cast to float -+ CUTE_HOST_DEVICE constexpr -+ explicit operator float() const { -+ return float(get()); -+ } -+ }; -+ -+ /// Reference object extracts sub-byte items -+ class const_reference { -+ -+ /// Pointer to storage element -+ Storage const* ptr_; -+ -+ /// Index into elements packed into Storage object -+ int idx_; -+ -+ public: -+ -+ /// Default ctor -+ CUTE_HOST_DEVICE constexpr -+ const_reference(): ptr_(nullptr), idx_(0) { } -+ -+ /// Ctor -+ CUTE_HOST_DEVICE constexpr -+ const_reference(Storage const* ptr, int idx = 0): ptr_(ptr), idx_(idx) { } -+ -+ CUTE_HOST_DEVICE constexpr -+ const T get() const { -+ Storage item = Storage((*ptr_ >> (idx_ * sizeof_bits::value)) & bit_mask_); -+ return reinterpret_cast(item); -+ } -+ -+ /// Extract to type T -- disable if T == bool -+ template ::value)> -+ CUTE_HOST_DEVICE constexpr -+ operator T() const { -+ return get(); -+ } -+ -+ // Extract to bool -- potentially faster impl -+ CUTE_HOST_DEVICE constexpr -+ operator bool() const { -+ return bool((*ptr_) & (bit_mask_ << (idx_ * sizeof_bits::value))); -+ } -+ -+ /// Explicit cast to int -+ CUTE_HOST_DEVICE constexpr -+ explicit operator int() const { -+ return int(get()); -+ } -+ -+ /// Explicit cast to float -+ CUTE_HOST_DEVICE constexpr -+ explicit operator float() const { -+ return float(get()); -+ } -+ }; -+ -+ // -+ // Iterators -+ // -+ -+ /// Bidirectional iterator over elements -+ class iterator { -+ -+ /// Pointer to storage element -+ Storage* ptr_; -+ -+ /// Index into elements packed into Storage object -+ int idx_; -+ -+ public: -+ -+ CUTE_HOST_DEVICE constexpr -+ iterator(): ptr_(nullptr), idx_(0) { } -+ -+ CUTE_HOST_DEVICE constexpr -+ iterator(Storage* ptr, int idx = 0): ptr_(ptr), idx_(idx) { } -+ -+ CUTE_HOST_DEVICE constexpr -+ iterator& operator++() { -+ ++idx_; -+ if (idx_ == kElementsPerStoredItem) { -+ ++ptr_; -+ idx_ = 0; -+ } -+ return *this; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ iterator& operator--() { -+ if (idx_) { -+ --idx_; -+ } else { -+ --ptr_; -+ idx_ = kElementsPerStoredItem - 1; -+ } -+ return *this; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ iterator operator++(int) { -+ iterator ret(*this); -+ ++(*this); -+ return ret; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ iterator operator--(int) { -+ iterator ret(*this); -+ --(*this); -+ return ret; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ iterator& operator+=(int k) { -+ idx_ += k; -+ ptr_ += idx_ / kElementsPerStoredItem; -+ idx_ = idx_ % kElementsPerStoredItem; -+ return *this; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ iterator operator+(int k) const { -+ return iterator(ptr_,idx_) += k; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ reference operator*() const { -+ return reference(ptr_, idx_); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ reference operator[](int k) const { -+ return *(*this + k); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ bool operator==(iterator const& other) const { -+ return ptr_ == other.ptr_ && idx_ == other.idx_; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ bool operator!=(iterator const& other) const { -+ return !(*this == other); -+ } -+ }; -+ -+ /// Bidirectional constant iterator over elements -+ class const_iterator { -+ -+ /// Pointer to storage element -+ Storage const* ptr_; -+ -+ /// Index into elements packed into Storage object -+ int idx_; -+ -+ public: -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator(): ptr_(nullptr), idx_(0) { } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator(Storage const* ptr, int idx = 0): ptr_(ptr), idx_(idx) { } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator& operator++() { -+ ++idx_; -+ if (idx_ == kElementsPerStoredItem) { -+ ++ptr_; -+ idx_ = 0; -+ } -+ return *this; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator& operator--() { -+ if (idx_) { -+ --idx_; -+ } else { -+ --ptr_; -+ idx_ = kElementsPerStoredItem - 1; -+ } -+ return *this; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator operator++(int) { -+ iterator ret(*this); -+ ++idx_; -+ if (idx_ == kElementsPerStoredItem) { -+ ++ptr_; -+ idx_ = 0; -+ } -+ return ret; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator operator--(int) { -+ iterator ret(*this); -+ if (idx_) { -+ --idx_; -+ } else { -+ --ptr_; -+ idx_ = kElementsPerStoredItem - 1; -+ } -+ return ret; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator& operator+=(int k) { -+ idx_ += k; -+ ptr_ += idx_ / kElementsPerStoredItem; -+ idx_ = idx_ % kElementsPerStoredItem; -+ return *this; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator operator+(int k) const { -+ return const_iterator(ptr_,idx_) += k; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_reference operator*() const { -+ return const_reference(ptr_, idx_); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_reference operator[](int k) const { -+ return *(*this + k); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ bool operator==(iterator const& other) const { -+ return ptr_ == other.ptr_ && idx_ == other.idx_; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ bool operator!=(iterator const& other) const { -+ return !(*this == other); -+ } -+ }; -+ -+private: -+ -+ /// Internal storage -+ Storage storage[kStorageElements]; -+ -+public: -+ -+ CUTE_HOST_DEVICE constexpr -+ array_subbyte() { } -+ -+ CUTE_HOST_DEVICE constexpr -+ array_subbyte(array_subbyte const& x) { -+ CUTE_UNROLL -+ for (unsigned i = 0; i < kStorageElements; ++i) { -+ storage[i] = x.storage[i]; -+ } -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ size_type size() const { -+ return N; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ size_type max_size() const { -+ return N; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ bool empty() const { -+ return !N; -+ } -+ -+ /// Efficient clear method -+ CUTE_HOST_DEVICE constexpr -+ void clear() { -+ CUTE_UNROLL -+ for (unsigned i = 0; i < kStorageElements; ++i) { -+ storage[i] = Storage(0); -+ } -+ } -+ -+ // Efficient fill method -+ CUTE_HOST_DEVICE constexpr -+ void fill(T const& value) { -+ Storage item = (reinterpret_cast(value) & bit_mask_); -+ -+ // Reproduce the value over the bits of the storage item -+ CUTE_UNROLL -+ for (unsigned s = sizeof_bits::value; s < sizeof_bits::value; s *= 2) { -+ item |= item << s; -+ } -+ -+ CUTE_UNROLL -+ for (unsigned i = 0; i < kStorageElements; ++i) { -+ storage[i] = item; -+ } -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ reference at(size_type pos) { -+ return reference(storage + pos / kElementsPerStoredItem, pos % kElementsPerStoredItem); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_reference at(size_type pos) const { -+ return const_reference(storage + pos / kElementsPerStoredItem, pos % kElementsPerStoredItem); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ reference operator[](size_type pos) { -+ return at(pos); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_reference operator[](size_type pos) const { -+ return at(pos); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ reference front() { -+ return at(0); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_reference front() const { -+ return at(0); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ reference back() { -+ return reference(storage + kStorageElements - 1, kElementsPerStoredItem - 1); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_reference back() const { -+ return const_reference(storage + kStorageElements - 1, kElementsPerStoredItem - 1); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ pointer data() { -+ return reinterpret_cast(storage); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_pointer data() const { -+ return reinterpret_cast(storage); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ Storage* raw_data() { -+ return storage; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ Storage const* raw_data() const { -+ return storage; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ iterator begin() { -+ return iterator(storage); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator begin() const { -+ return const_iterator(storage); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator cbegin() const { -+ return begin(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ iterator end() { -+ return iterator(storage + N / kElementsPerStoredItem, N % kElementsPerStoredItem); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator end() const { -+ return const_iterator(storage + N / kElementsPerStoredItem, N % kElementsPerStoredItem); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ const_iterator cend() const { -+ return end(); -+ } -+ -+ // -+ // Comparison operators -+ // -+ -+}; -+ -+// -+// Operators -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+void clear(array_subbyte& a) -+{ -+ a.clear(); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+void fill(array_subbyte& a, T const& value) -+{ -+ a.fill(value); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cute -+ -+// -+// Specialize tuple-related functionality for cute::array_subbyte -+// -+ -+#include -+ -+namespace cute -+{ -+ -+template -+CUTE_HOST_DEVICE constexpr -+T& get(array_subbyte& a) -+{ -+ static_assert(I < N, "Index out of range"); -+ return a[I]; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+T const& get(array_subbyte const& a) -+{ -+ static_assert(I < N, "Index out of range"); -+ return a[I]; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+T&& get(array_subbyte&& a) -+{ -+ static_assert(I < N, "Index out of range"); -+ return std::move(a[I]); -+} -+ -+} // end namespace cute -+ -+namespace std -+{ -+ -+template -+struct tuple_size> -+ : std::integral_constant -+{}; -+ -+template -+struct tuple_element> -+{ -+ using type = T; -+}; -+ -+} // end namespace std -diff --git a/3rdparty/cutlass/include/cute/container/array_view.hpp b/3rdparty/cutlass/include/cute/container/array_view.hpp -new file mode 100644 -index 0000000..51b3ccc ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/container/array_view.hpp -@@ -0,0 +1,274 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+#include -+ -+#include -+ -+namespace cute -+{ -+ -+template -+struct array_view -+{ -+ using value_type = T; -+ using size_type = std::size_t; -+ using difference_type = std::ptrdiff_t; -+ using reference = value_type&; -+ using const_reference = const value_type&; -+ using pointer = value_type*; -+ using const_pointer = const value_type*; -+ using iterator = pointer; -+ using const_iterator = const_pointer; -+ -+ array_view(array& a) -+ : __elems_(a.data()) {} -+ -+ CUTE_HOST_DEVICE -+ reference operator[](size_type pos) -+ { -+ return begin()[pos]; -+ } -+ -+ CUTE_HOST_DEVICE -+ const_reference operator[](size_type pos) const -+ { -+ return begin()[pos]; -+ } -+ -+ CUTE_HOST_DEVICE -+ reference front() -+ { -+ return *begin(); -+ } -+ -+ CUTE_HOST_DEVICE -+ const_reference front() const -+ { -+ return *begin(); -+ } -+ -+ CUTE_HOST_DEVICE -+ reference back() -+ { -+ // return *rbegin(); -+ return operator[](N-1); -+ } -+ -+ CUTE_HOST_DEVICE -+ const_reference back() const -+ { -+ // return *rbegin(); -+ return operator[](N-1); -+ } -+ -+ CUTE_HOST_DEVICE -+ T* data() -+ { -+ return __elems_; -+ } -+ -+ CUTE_HOST_DEVICE -+ const T* data() const -+ { -+ return __elems_; -+ } -+ -+ CUTE_HOST_DEVICE -+ iterator begin() -+ { -+ return data(); -+ } -+ -+ CUTE_HOST_DEVICE -+ const_iterator begin() const -+ { -+ return data(); -+ } -+ -+ CUTE_HOST_DEVICE -+ const_iterator cbegin() -+ { -+ return begin(); -+ } -+ -+ CUTE_HOST_DEVICE -+ const_iterator cbegin() const -+ { -+ return begin(); -+ } -+ -+ CUTE_HOST_DEVICE -+ iterator end() -+ { -+ return data() + size(); -+ } -+ -+ CUTE_HOST_DEVICE -+ const_iterator end() const -+ { -+ return data() + size(); -+ } -+ -+ CUTE_HOST_DEVICE -+ const_iterator cend() -+ { -+ return end(); -+ } -+ -+ CUTE_HOST_DEVICE -+ const_iterator cend() const -+ { -+ return end(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ bool empty() const -+ { -+ return size() == 0; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ size_type size() const -+ { -+ return N; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ size_type max_size() const -+ { -+ return size(); -+ } -+ -+ CUTE_HOST_DEVICE -+ void fill(const T& value) -+ { -+ for(auto& e : *this) -+ { -+ e = value; -+ } -+ } -+ -+ CUTE_HOST_DEVICE -+ void swap(array_view& other) -+ { -+ using std::swap; -+ swap(__elems_, other.__elems_); -+ } -+ -+ value_type* __elems_; -+}; -+ -+ -+template -+CUTE_HOST_DEVICE -+bool operator==(const array_view& lhs, const array_view& rhs) -+{ -+ for(std::size_t i = 0; i < N; ++i) -+ { -+ if(lhs[i] != rhs[i]) return false; -+ } -+ -+ return true; -+} -+ -+template -+CUTE_HOST_DEVICE -+void clear(array_view& a) -+{ -+ a.fill(T(0)); -+} -+ -+template -+CUTE_HOST_DEVICE -+void swap(array_view& a, array_view& b) -+{ -+ a.swap(b); -+} -+ -+} // end cute -+ -+ -+// -+// Specialize tuple-related functionality for cute::array_view -+// -+ -+#include -+ -+namespace cute -+{ -+ -+template -+CUTE_HOST_DEVICE constexpr -+T& -+get(array_view& a) -+{ -+ static_assert(I < N, "Index out of range"); -+ return a[I]; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+const T& -+get(const array_view& a) -+{ -+ static_assert(I < N, "Index out of range"); -+ return a[I]; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+T&& -+get(array_view&& a) -+{ -+ static_assert(I < N, "Index out of range"); -+ return std::move(a[I]); -+} -+ -+} // end namespace cute -+ -+namespace std -+{ -+ -+template -+struct tuple_size> -+ : std::integral_constant -+{}; -+ -+template -+struct tuple_element> -+{ -+ using type = T; -+}; -+ -+} // end std -diff --git a/3rdparty/cutlass/include/cute/container/bit_field.hpp b/3rdparty/cutlass/include/cute/container/bit_field.hpp -new file mode 100644 -index 0000000..06b0875 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/container/bit_field.hpp -@@ -0,0 +1,131 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Portable bit field that supports byte and word straddling that can -+ be used in unions to bit-wise define parameters. -+*/ -+ -+#pragma once -+ -+#include -+ -+#include // uint_bit_t -+ -+namespace cute -+{ -+ -+class dummy_type {}; -+ -+template -+struct bit_field -+{ -+ static_assert(0 < NumBits && NumBits <= 64, "bit_fields with more than 64 bits are not supported."); -+ -+ // value_type: Use the smallest value type that fits NumBits -+ static constexpr uint32_t value_type_bits = (NumBits <= 8) ? 8 : -+ (NumBits <= 16) ? 16 : -+ (NumBits <= 32) ? 32 : 64; -+ using value_type = cute::uint_bit_t; -+ // storage_type: Use the smallest storage_type that avoids boundary crossing -+ static constexpr uint32_t storage_type_bits = (BitStart / 8 == (BitStart + NumBits - 1) / 8) ? 8 : -+ (BitStart / 16 == (BitStart + NumBits - 1) / 16) ? 16 : -+ (BitStart / 32 == (BitStart + NumBits - 1) / 32) ? 32 : 64; -+ using storage_type = cute::uint_bit_t; -+ -+ static_assert(sizeof(OtherValueType) == sizeof(value_type) || std::is_same::value, -+ "sizeof(OtherValueType) must be same as sizeof(value_type)."); -+ -+ // Number of storage values needed: ceil_div(BitStart + NumBits, storage_type_bits) -+ static constexpr uint32_t N = (BitStart + NumBits + storage_type_bits - 1) / storage_type_bits; -+ // Index of storage value for BitStart -+ static constexpr uint32_t idx = BitStart / storage_type_bits; -+ // Bit of data_[idx] for BitStart -+ static constexpr uint32_t bit_lo = BitStart % storage_type_bits; -+ // Number of bits in data_[idx] used for NumBits if straddling, else 0 -+ static constexpr uint32_t bit_hi = (idx + 1 < N) ? (storage_type_bits - bit_lo) : 0; -+ -+ // NumBits mask -+ static constexpr value_type mask = (NumBits < 64) ? ((uint64_t(1) << NumBits) - 1) : uint64_t(-1); -+ // NumBits mask for BitStart -+ static constexpr storage_type mask_lo = storage_type(mask) << bit_lo; -+ // NumBits mask for leftover bits in data_[idx+1] if straddling, else 0 -+ static constexpr storage_type mask_hi = (idx + 1 < N) ? (storage_type(mask) >> bit_hi) : 0; -+ -+ storage_type data_[N]; -+ -+ // Get value -+ CUTE_HOST_DEVICE constexpr -+ value_type get() const { -+ storage_type result = (data_[idx] & mask_lo) >> bit_lo; -+ if constexpr (bit_hi) { -+ result |= (data_[idx+1] & mask_hi) << bit_hi; -+ } -+ return static_cast(result); -+ } -+ -+ // Set value -+ CUTE_HOST_DEVICE constexpr -+ void set(value_type x) { -+ storage_type item = static_cast(x & mask); -+ data_[idx] = static_cast((data_[idx] & ~mask_lo) | (item << bit_lo)); -+ if constexpr (bit_hi) { -+ data_[idx+1] = static_cast((data_[idx+1] & ~mask_hi) | (item >> bit_hi)); -+ } -+ } -+ -+ // Assign value -+ CUTE_HOST_DEVICE constexpr -+ bit_field& operator=(value_type x) { -+ set(x); -+ return *this; -+ } -+ -+ // Cast to value -+ CUTE_HOST_DEVICE constexpr -+ operator value_type () const { -+ return get(); -+ } -+ -+ // Assign OtherValueType -+ CUTE_HOST_DEVICE constexpr -+ bit_field& operator=(OtherValueType x) { -+ return *this = *reinterpret_cast(&x); -+ } -+ -+ // Cast to OtherValueType -+ CUTE_HOST_DEVICE constexpr -+ operator OtherValueType () const { -+ value_type x = get(); -+ return *reinterpret_cast(&x); -+ } -+}; -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/container/tuple.hpp b/3rdparty/cutlass/include/cute/container/tuple.hpp -new file mode 100644 -index 0000000..1b3ffa4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/container/tuple.hpp -@@ -0,0 +1,671 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+#include -+ -+#include -+#include -+ -+#include // cute::true_type, cute::false_type -+//#include // Advanced optimizations -+ -+#if 0 -+// -+// Use of agency::tuple is functional, but is over-engineered for our purposes... -+// This tends to result in slow compilation times and unintentionally propagated cvref types -+// -+ -+#include -+ -+namespace cute -+{ -+ -+using agency::tuple; -+ -+using agency::make_tuple; -+using agency::tuple_cat; -+ -+} // end namespace cute -+#endif -+ -+// cute::tuple is like std::tuple, with two differences. -+// -+// 1. It works on both host and device. -+// 2. Its template arguments must be semiregular types. -+// -+// Semiregular types are default constructible and copyable. -+// They include "value types" like int or float, -+// but do _not_ include references like int& or float&. -+// (See std::tie for an example of a tuple of references.) -+// -+// This is simplified over the implementation in std:: and agency:: by ignoring much of -+// the conversion SFINAE, special overloading, and avoiding cvref template types. -+// Furthermore, the empty base optimization (EBO) is MORE aggressive by avoiding -+// construction calls, and ignoring any need for unique element addresses. -+// -+// Over the agency::tuple implementation, this appears to accelerate compilation times by over 3x. -+ -+namespace cute -+{ -+ -+namespace detail -+{ -+ -+// EBO stands for "empty base optimization." -+// We use this technique to ensure that cute::tuple -+// doesn't need to waste space storing any template arguments -+// of cute::tuple that have no data (like integral_constant). -+// Otherwise, cute::tuple would need to spend at least 1 byte -+// for each of its template arguments. -+// -+// EBO always "holds" a single value of type T. -+// N is like an array index that TupleBase uses -+// to access the desired tuple element. -+template ::value> -+struct EBO; -+ -+// Specialization for types T that have no data; -+// the "static tuple leaf." Valid T here include -+// integral_constant, Int, -+// and any other semiregular type -+// for which std::is_empty_v is true. -+template -+struct EBO -+{ -+ CUTE_HOST_DEVICE constexpr -+ EBO() {} -+ -+ CUTE_HOST_DEVICE constexpr -+ EBO(T const&) {} -+}; -+ -+template -+CUTE_HOST_DEVICE constexpr T getv(EBO const&) -+{ return {}; } -+ -+// Specialization for types T that are not empty; -+// the "dynamic tuple leaf." Valid T here include int, -+// any other integral or floating-point type, -+// or any semiregular type for which std::is_empty_v is false. -+template -+struct EBO -+{ -+ CUTE_HOST_DEVICE constexpr -+ EBO() : t_{} {} -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ EBO(U const& u) : t_{u} {} -+ -+ T t_; -+}; -+ -+template -+CUTE_HOST_DEVICE constexpr T const& getv(EBO const& x) -+{ return x.t_; } -+ -+template -+CUTE_HOST_DEVICE constexpr T& getv(EBO& x) -+{ return x.t_; } -+ -+template -+CUTE_HOST_DEVICE constexpr T&& getv(EBO&& x) -+{ return static_cast(x.t_); } -+ -+template -+struct TupleBase; -+ -+// Base class of cute::tuple. -+// It inherits from EBO for each (i, t) in (I..., T...). -+// The actual storage (for nonempty t) lives in the base classes. -+// index_sequence is a way to wrap up a sequence of zero or more -+// compile-time integer values in a single type. -+// We only ever use index_sequence<0, 1, ..., sizeof...(T)> in practice, -+// as the type alias TupleBase below indicates. -+template -+struct TupleBase, T...> -+ : EBO... -+{ -+ CUTE_HOST_DEVICE constexpr -+ TupleBase() {} -+ -+ template -+ CUTE_HOST_DEVICE constexpr explicit -+ TupleBase(U const&... u) -+ : EBO(u)... {} -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ TupleBase(TupleBase, U...> const& u) -+ : EBO(getv(static_cast const&>(u)))... {} -+}; -+ -+} // end namespace detail -+ -+// make_index_sequence returns index_sequence<0, 1, ..., K-1>. -+template -+using TupleBase = detail::TupleBase, T...>; -+ -+// This is the actual cute::tuple class. -+// The storage (if any) lives in TupleBase's EBO base classes. -+template -+struct tuple : TupleBase -+{ -+ CUTE_HOST_DEVICE constexpr -+ tuple() {} -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ tuple(U const&... u) : TupleBase(u...) {} -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ tuple(tuple const& u) -+ : TupleBase(static_cast const&>(u)) {} -+}; -+ -+// -+// get for cute::tuple (just like std::get for std::tuple) -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+get(tuple const& t) noexcept -+{ -+ static_assert(I < sizeof...(T), "Index out of range"); -+ return detail::getv(t); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+get(tuple& t) noexcept -+{ -+ static_assert(I < sizeof...(T), "Index out of range"); -+ return detail::getv(t); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+get(tuple&& t) noexcept -+{ -+ static_assert(I < sizeof...(T), "Index out of range"); -+ return detail::getv(static_cast&&>(t)); -+} -+ -+// -+// Custom is_tuple trait simply checks the existence of std::tuple_size -+// and assumes std::get(.), std::tuple_element -+// -+namespace detail { -+ -+template -+std::integral_constant::value >= 0> has_tuple_size(int); -+ -+template -+std::false_type has_tuple_size(...); -+ -+} // end namespace detail -+ -+template -+struct is_tuple : decltype(detail::has_tuple_size(0)) {}; -+ -+// -+// make_tuple (value-based implementation) -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+tuple -+make_tuple(T const&... t) -+{ -+ return {t...}; -+} -+ -+// -+// tuple_cat concatenates multiple cute::tuple into a single cute::tuple, -+// just like std::tuple_cat for std::tuple. -+// -+ -+#if 0 -+// Original implementation -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tuple_cat(T0 const& t0, T1 const& t1, -+ std::index_sequence, std::index_sequence) -+{ -+ return cute::make_tuple(get(t0)..., get(t1)...); -+} -+ -+} // end namespace detail -+ -+CUTE_HOST_DEVICE constexpr -+tuple<> -+tuple_cat() -+{ -+ return {}; -+} -+ -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+Tuple const& -+tuple_cat(Tuple const& t) -+{ -+ return t; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tuple_cat(T0 const& t0, T1 const& t1) -+{ -+ return detail::tuple_cat(t0, t1, -+ std::make_index_sequence::value>{}, -+ std::make_index_sequence::value>{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, Ts const&... ts) -+{ -+ return cute::tuple_cat(cute::tuple_cat(t0,t1),t2,ts...); -+} -+#endif -+ -+#if 1 -+// Extended implementation -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tuple_cat(T0 const& t0, T1 const& t1, -+ std::index_sequence, std::index_sequence) -+{ -+ return cute::make_tuple(get(t0)..., get(t1)...); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, -+ std::index_sequence, std::index_sequence, std::index_sequence) -+{ -+ return cute::make_tuple(get(t0)..., get(t1)..., get(t2)...); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, -+ std::index_sequence, std::index_sequence, std::index_sequence, std::index_sequence) -+{ -+ return cute::make_tuple(get(t0)..., get(t1)..., get(t2)..., get(t3)...); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4, -+ std::index_sequence, std::index_sequence, std::index_sequence, std::index_sequence, std::index_sequence) -+{ -+ return cute::make_tuple(get(t0)..., get(t1)..., get(t2)..., get(t3)..., get(t4)...); -+} -+ -+} // end namespace detail -+ -+CUTE_HOST_DEVICE constexpr -+tuple<> -+tuple_cat() -+{ -+ return {}; -+} -+ -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+Tuple const& -+tuple_cat(Tuple const& t) -+{ -+ return t; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tuple_cat(T0 const& t0, T1 const& t1) -+{ -+ return detail::tuple_cat(t0, t1, -+ std::make_index_sequence::value>{}, -+ std::make_index_sequence::value>{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2) -+{ -+ return detail::tuple_cat(t0, t1, t2, -+ std::make_index_sequence::value>{}, -+ std::make_index_sequence::value>{}, -+ std::make_index_sequence::value>{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3) -+{ -+ return detail::tuple_cat(t0, t1, t2, t3, -+ std::make_index_sequence::value>{}, -+ std::make_index_sequence::value>{}, -+ std::make_index_sequence::value>{}, -+ std::make_index_sequence::value>{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4) -+{ -+ return detail::tuple_cat(t0, t1, t2, t3, t4, -+ std::make_index_sequence::value>{}, -+ std::make_index_sequence::value>{}, -+ std::make_index_sequence::value>{}, -+ std::make_index_sequence::value>{}, -+ std::make_index_sequence::value>{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4, T5 const& t5, Ts const&... ts) -+{ -+ return cute::tuple_cat(cute::tuple_cat(t0,t1,t2,t3,t4), t5, ts...); -+} -+#endif -+ -+#if 0 -+// Outer-Inner indexing trick to concat all tuples at once -+ -+namespace detail { -+ -+template -+struct tuple_cat_helper -+{ -+ static constexpr cute::array ns = {Ns...}; -+ -+ static constexpr std::size_t total_size() { -+ std::size_t sum = 0; -+ for (std::size_t n : ns) sum += n; -+ return sum; -+ } -+ static constexpr std::size_t total_size_ = total_size(); -+ -+ static constexpr auto values() { -+ cute::array outer_inner = {}; -+ -+ std::size_t idx = 0; -+ for (std::size_t i = 0; i < ns.size(); ++i) { -+ for (std::size_t j = 0; j < ns[i]; ++j, ++idx) { -+ outer_inner[idx][0] = i; -+ outer_inner[idx][1] = j; -+ } -+ } -+ return outer_inner; -+ } -+ static constexpr auto outer_inner_ = values(); -+ -+ using total_sequence = std::make_index_sequence; -+}; -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tuple_cat(Tuple const& t, std::index_sequence) -+{ -+ return cute::make_tuple(get(get(t))...); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tuple_cat(T0 const& t0, T1 const& t1, -+ std::index_sequence, std::index_sequence) -+{ -+ return cute::make_tuple(get(t0)..., get(t1)...); -+} -+ -+} // end namespace detail -+ -+CUTE_HOST_DEVICE constexpr -+tuple<> -+tuple_cat() -+{ -+ return {}; -+} -+ -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+Tuple const& -+tuple_cat(Tuple const& t) -+{ -+ return t; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tuple_cat(T0 const& t0, T1 const& t1) -+{ -+ return detail::tuple_cat(t0, t1, -+ std::make_index_sequence::value>{}, -+ std::make_index_sequence::value>{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tuple_cat(Tuples const&... ts) -+{ -+ using Helper = detail::tuple_cat_helper::value...>; -+ return detail::tuple_cat(make_tuple(ts...), typename Helper::total_sequence{}); -+} -+#endif -+ -+// -+// Equality operators -+// -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+equal_impl(TupleA const& a, TupleB const& b) -+{ -+ if constexpr (I == std::tuple_size::value) { -+ return cute::true_type{}; // Terminal: TupleA is exhausted -+ } else if constexpr (I == std::tuple_size::value) { -+ return cute::false_type{}; // Terminal: TupleA is not exhausted, TupleB is exhausted -+ } else { -+ return (get(a) == get(b)) && equal_impl(a,b); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+} // end namespace detail -+ -+template ::value && is_tuple::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+operator==(TupleT const& t, TupleU const& u) -+{ -+ return detail::equal_impl<0>(t, u); -+} -+ -+template ::value ^ is_tuple::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+operator==(TupleT const& t, TupleU const& u) -+{ -+ return cute::false_type{}; -+} -+ -+template ::value && is_tuple::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+operator!=(TupleT const& t, TupleU const& u) -+{ -+ return !(t == u); -+} -+ -+template ::value ^ is_tuple::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+operator!=(TupleT const& t, TupleU const& u) -+{ -+ return cute::true_type{}; -+} -+ -+// -+// Comparison operators -+// -+ -+// -+// There are many ways to compare tuple of elements and because CuTe is built -+// on parameterizing layouts of coordinates, some comparisons are appropriate -+// only in certain cases. -+// -- lexicographical comparison [reverse, reflected, revref] -+// -- colexicographical comparison [reverse, reflected, revref] -+// -- element-wise comparison [any,all] -+// This can be very confusing. To avoid errors in selecting the appropriate -+// comparison, op<|op<=|op>|op>= are *not* implemented for cute::tuple. -+// -+// That said, see int_tuple for more explicitly named common comparison ops. -+// -+ -+// -+// Shortcuts -+// -+ -+//using std::get; -+using std::tuple_size; -+using std::tuple_element; -+using std::tuple_element_t; -+ -+// -+// Display utilities -+// -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE void print_tuple(Tuple const& t, -+ std::index_sequence, char s = '(', char e = ')') -+{ -+ using eat = int[]; -+ using cute::print; -+ (void) eat {(print(s), 0), -+ (print(Is == 0 ? "" : ","), print(get(t)), 0)..., -+ (print(e), 0)}; -+} -+ -+template -+CUTE_HOST std::ostream& print_tuple_os(std::ostream& os, Tuple const& t, -+ std::index_sequence, char s = '(', char e = ')') -+{ -+ using eat = int[]; -+ (void) eat {(void(os << s), 0), -+ (void(os << (Is == 0 ? "" : ",") << get(t)), 0)..., -+ (void(os << e), 0)}; -+ return os; -+} -+ -+} // end namespace detail -+ -+template ::value)> -+CUTE_HOST_DEVICE void print(Tuple const& t) -+{ -+ return detail::print_tuple(t, std::make_index_sequence::value>{}); -+} -+ -+template ::value)> -+CUTE_HOST std::ostream& operator<<(std::ostream& os, Tuple const& t) -+{ -+ return detail::print_tuple_os(os, t, std::make_index_sequence::value>{}); -+} -+ -+} // end namespace cute -+ -+// -+// std:: compatability -+// -+ -+namespace std -+{ -+ -+template -+struct tuple_size> -+ : std::integral_constant -+{}; -+ -+template -+struct tuple_element> -+ : std::tuple_element> -+{}; -+ -+} // end std -diff --git a/3rdparty/cutlass/include/cute/container/type_list.hpp b/3rdparty/cutlass/include/cute/container/type_list.hpp -new file mode 100644 -index 0000000..c082a6d ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/container/type_list.hpp -@@ -0,0 +1,84 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+namespace cute -+{ -+ -+template -+struct type_c { -+ using type = T; -+}; -+ -+template -+struct type_list {}; -+ -+} // end namespace cute -+ -+// -+// Specialize tuple-related functionality for cute::type_list -+// -+ -+#include -+#include -+ -+namespace cute -+{ -+ -+template -+CUTE_HOST_DEVICE constexpr -+std::tuple_element_t> -+get(type_list&) noexcept { -+ return {}; -+} -+template -+CUTE_HOST_DEVICE constexpr -+std::tuple_element_t> -+get(type_list const& t) noexcept { -+ return {}; -+} -+ -+} // end namespace cute -+ -+namespace std -+{ -+ -+template -+struct tuple_size> -+ : std::integral_constant -+{}; -+ -+template -+struct tuple_element> -+ : cute::type_c>::type> -+{}; -+ -+} // end namespace std -diff --git a/3rdparty/cutlass/include/cute/int_tuple.hpp b/3rdparty/cutlass/include/cute/int_tuple.hpp -new file mode 100644 -index 0000000..045e721 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/int_tuple.hpp -@@ -0,0 +1,827 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+#include -+#include -+ -+namespace cute -+{ -+ -+template -+using IntTuple = cute::tuple; -+ -+// Construct an IntTuple with all value-elements -+template -+CUTE_HOST_DEVICE constexpr -+IntTuple -+make_int_tuple(Ts const&... t) -+{ -+ return {t...}; -+} -+ -+/** if rank(int) == 1, then get<0>(int) should work too -+ */ -+template >::value)> -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+get(T&& t) noexcept -+{ -+ static_assert(I == 0, "Index out of range"); -+ return static_cast(t); -+} -+ -+/** Custom recursive get for anything that implements get(.) -+ */ -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+get(Tuple&& t) noexcept -+{ -+ return get(get(static_cast(t))); -+} -+ -+// -+// rank -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+rank(IntTuple const& t) -+{ -+ if constexpr (sizeof...(Is) == 0) { -+ if constexpr (is_tuple::value) { -+ return Int::value>{}; -+ } else { -+ return Int<1>{}; -+ } -+ } else { -+ return rank(get(t)); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+using rank_t = decltype(rank(std::declval())); -+ -+template -+static constexpr int rank_v = rank_t::value; -+ -+// -+// shape -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+shape(IntTuple const& s) -+{ -+ if constexpr (is_tuple::value) { -+ return transform(s, [](auto const& a) { return shape(a); }); -+ } else { -+ return s; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+shape(IntTuple const& s) -+{ -+ if constexpr (is_tuple::value) { -+ return shape(get(s)); -+ } else { -+ return get(shape(s)); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// max -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+max(T0 const& t0, Ts const&... ts) -+{ -+ if constexpr (is_tuple::value) { -+ return cute::max(cute::apply(t0, [](auto const&... a){ return cute::max(a...); }), ts...); -+ } else if constexpr (sizeof...(Ts) == 0) { -+ return t0; -+ } else { -+ return cute::max(t0, cute::max(ts...)); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// min -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+min(T0 const& t0, Ts const&... ts) -+{ -+ if constexpr (is_tuple::value) { -+ return cute::min(cute::apply(t0, [](auto const&... a){ return cute::min(a...); }), ts...); -+ } else if constexpr (sizeof...(Ts) == 0) { -+ return t0; -+ } else { -+ return cute::min(t0, cute::min(ts...)); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// depth -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+depth(IntTuple const& t) -+{ -+ if constexpr (sizeof...(Is) == 0) { -+ if constexpr (is_tuple::value) { -+ return Int<1>{} + cute::apply(t, [](auto const&... v){ return cute::max(depth(v)...); }); -+ } else { -+ return Int<0>{}; -+ } -+ } else { -+ return depth(get(t)); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+using depth_t = decltype(depth(std::declval())); -+ -+template -+static constexpr int depth_v = depth_t::value; -+ -+// -+// product -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+product(IntTuple const& a) -+{ -+ if constexpr (is_tuple::value) { -+ return cute::apply(a, [](auto const&... v){ return (Int<1>{} * ... * product(v)); }); -+ } else { -+ return a; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// Product of a subrange -+template -+CUTE_HOST_DEVICE constexpr -+auto -+product(Tuple const& a) -+{ -+ return detail::apply(a, [](auto const&... v){ return (Int<1>{} * ... * product(v)); }, make_range{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+product_each(Tuple const& t) -+{ -+ return transform(t, [](auto const& x) { return product(x); }); -+} -+ -+// Return the product of elements in a mode -+template -+CUTE_HOST_DEVICE constexpr -+auto -+size(IntTuple const& a) -+{ -+ if constexpr (sizeof...(Is) == 0) { -+ return product(a); -+ } else { -+ return product(get(a)); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+static constexpr int size_v = decltype(size(std::declval()))::value; -+ -+// -+// sum -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+sum(IntTuple const& a) -+{ -+ if constexpr (is_tuple::value) { -+ return cute::apply(a, [](auto const&... v){ return (Int<0>{} + ... + sum(v)); }); -+ } else { -+ return a; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// inner_product -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+inner_product(IntTupleA const& a, IntTupleB const& b) -+{ -+ if constexpr (is_tuple::value && is_tuple::value) { -+ static_assert(tuple_size::value == tuple_size::value, "Mismatched ranks"); -+ return transform_apply(a, b, [](auto const& x, auto const& y) { return inner_product(x,y); }, -+ [](auto const&... v) { return (Int<0>{} + ... + v); }); -+ } else { -+ return a * b; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// ceil_div -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+ceil_div(IntTupleA const& a, IntTupleB const& b) -+{ -+ if constexpr (is_tuple::value && is_tuple::value) { -+ static_assert(tuple_size::value >= tuple_size::value, "Mismatched ranks"); -+ constexpr int R = tuple_size::value; // Missing ranks in TupleB are implictly 1 -+ return transform(a, append(b,Int<1>{}), [](auto const& x, auto const& y) { return ceil_div(x,y); }); -+ } else { -+ return (a + b - Int<1>{}) / b; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+/** Division for Shapes -+ */ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+shape_div(IntTupleA const& a, IntTupleB const& b) -+{ -+ if constexpr (is_tuple::value) { -+ if constexpr (is_tuple::value) { // tuple tuple -+ static_assert(tuple_size::value == tuple_size::value, "Mismatched ranks"); -+ return transform(a, b, [](auto const& x, auto const& y) { return shape_div(x,y); }); -+ } else { // tuple int -+ auto const [result, rest] = fold(a, make_tuple(make_tuple(), b), -+ [] (auto const& init, auto const& ai) { -+ return make_tuple(append(get<0>(init), shape_div(ai, get<1>(init))), shape_div(get<1>(init), ai)); -+ }); -+ return result; -+ } -+ } else { -+ if constexpr (is_tuple::value) { // int tuple -+ return shape_div(a, product(b)); -+ } else { // int int -+ //assert(a % b == 0 || b % a == 0); -+ return a / b != 0 ? a / b : signum(a) * signum(b); // divide with rounding away from zero -+ } -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+/** Division for Shapes that are static constants -+ * @pre t % u == 0 || u % t == 0 -+ * @result if t % u == 0, then t / u -+ * if u % t == 0, then signum(t) * signum(u) -+ */ -+template -+CUTE_HOST_DEVICE constexpr -+constant -+shape_div(constant const&, constant const&) -+{ -+ static_assert(t % u == 0 || u % t == 0, "Static shape_div failure"); -+ return {}; -+} -+ -+/** Return a tuple the same profile as A scaled by corresponding elements in B -+ */ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+elem_scale(A const& a, B const& b) -+{ -+ if constexpr (is_tuple::value) { -+ return transform(a, b, [](auto const& x, auto const& y) { return elem_scale(x,y); }); -+ } else { -+ return a * product(b); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+/** Test if two IntTuple have the same profile (hierarchical rank division) -+ */ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+congruent(IntTupleA const& a, IntTupleB const& b) -+{ -+ return bool_constant::value>{}; -+} -+ -+template -+using is_congruent = decltype(congruent(std::declval(), std::declval())); -+ -+/** Test if Shape B is compatible with Shape A: -+ * Any coordinate into A can also be used as a coordinate into B -+ * A <= B is a partially ordered set of factored shapes -+ */ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+compatible(IntTupleA const& a, IntTupleB const& b) -+{ -+ if constexpr (is_tuple::value && is_tuple::value) { -+ if constexpr (tuple_size::value != tuple_size::value) { -+ return false_type{}; -+ } else { -+ return transform_apply(a, b, [](auto const& x, auto const& y) { return compatible(x,y); }, -+ [](auto const&... z) { return (true_type{} && ... && z); }); -+ } -+ } else if constexpr (is_integral::value) { -+ return a == size(b); -+ } else if constexpr (is_integral::value) { -+ return false_type{}; -+ } else { -+ return compatible(shape(a), shape(b)); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+using is_compatible = decltype(compatible(std::declval(), std::declval())); -+ -+/** Replace the elements of Tuple B that are paired with an Int<0> with an Int<1> -+ */ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+filter_zeros(IntTupleA const& a, IntTupleB const& b) -+{ -+ if constexpr (is_tuple::value) { -+ return transform(a, b, [](auto const& x, auto const& y) { return filter_zeros(x,y); }); -+ } else if constexpr (is_constant<0, IntTupleA>::value) { -+ return Int<1>{}; -+ } else { -+ return b; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+filter_zeros(Tuple const& t) -+{ -+ return filter_zeros(t, t); -+} -+ -+// -+// Converters and constructors with arrays and params -+// -+ -+/** Make an IntTuple of rank N from an Indexable array. -+ * Access elements up to a dynamic index n, then use init (requires compatible types) -+ * Consider cute::take if all indexing is known to be valid -+ * \code -+ * std::vector a = {6,3,4}; -+ * auto tup = make_int_tuple<5>(a, a.size(), 0) // (6,3,4,0,0) -+ * \endcode -+ */ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_int_tuple(Indexable const& t, int n, T const& init) -+{ -+ static_assert(N > 0); -+ if constexpr (N == 1) { -+ return 0 < n ? t[0] : init; -+ } else { -+ return transform(make_seq{}, [&](auto i) { return i < n ? t[i] : init; }); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+/** Fill the dynamic values of a Tuple with values from another Tuple -+ * \code -+ * auto params = make_int_tuple(6,3,4); -+ * cute::tuple, cute::tuple>, int, Int<2>> result; -+ * fill_int_tuple_from(result, params); // (_1,(6,3,_3),4,_2) -+ * \endcode -+ */ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+fill_int_tuple_from(Tuple& result, TupleV const& vals) -+{ -+ return fold(result, vals, [](auto const& init, auto&& r) { -+ if constexpr (is_static>::value) { // Skip static elements of result -+ return init; -+ } else if constexpr (is_tuple>::value) { // Recurse into tuples -+ return fill_int_tuple_from(r, init); -+ } else { // Assign and consume arg -+ static_assert(tuple_size>::value > 0, "Not enough values to fill with!"); -+ r = get<0>(init); -+ return remove<0>(init); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+ }); -+} -+ -+/** Make a "Tuple" by filling in the dynamic values in order from the arguments -+ * \code -+ * using result_t = cute::tuple, cute::tuple>, int, Int<2>>; -+ * auto result = make_int_tuple_from(6,3,4); // (_1,(6,3,_3),4,_2) -+ * \endcode -+ */ -+template -+CUTE_HOST_DEVICE constexpr -+Tuple -+make_int_tuple_from(Ts const&... ts) -+{ -+ Tuple result = Tuple{}; -+ fill_int_tuple_from(result, make_tuple(ts...)); -+ return result; -+} -+ -+/** Convert a tuple to a flat homogeneous array of type T -+ * \code -+ * auto tup = make_tuple(Int<1>{}, make_tuple(6,3,Int<3>{}),4,Int<2>{}); -+ * cute::array result = to_array(tup); // [1,6,3,3,4,2] -+ * \endcode -+ */ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+to_array(IntTuple const& t) -+{ -+ auto flat_t = flatten_to_tuple(t); -+ constexpr int N = tuple_size::value; -+ cute::array result; -+ for_each(make_seq{}, [&] (auto i) { result[i] = get(flat_t); }); -+ return result; -+} -+ -+// -+// Comparison operators -+// -+ -+// -+// There are many ways to compare tuple of elements and because CuTe is built -+// on parameterizing layouts of coordinates, some comparisons are appropriate -+// only in certain cases. -+// -- lexicographical comparison [reverse, reflected, revref] : Correct for coords in RowMajor Layout -+// -- colexicographical comparison [reverse, reflected, revref] : Correct for coords in ColMajor Layout -+// -- element-wise comparison [any,all] : -+// This can be very confusing. To avoid errors in selecting the appropriate -+// comparison, op<|op<=|op>|op>= are *not* implemented for cute::tuple. -+// -+// When actually desiring to order coordinates, the user should map them to -+// their indices within the Layout they came from: -+// e.g. layoutX(coordA) < layoutX(coordB) -+// That said, we implement the three most common ways to compare tuples below. -+// These are implemented with slighly more explicit names than op<. -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+lex_less(IntTupleA const& a, IntTupleB const& b); -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+colex_less(IntTupleA const& a, IntTupleB const& b); -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+elem_less(IntTupleA const& a, IntTupleB const& b); -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+lex_less_impl(TupleA const& a, TupleB const& b) -+{ -+ if constexpr (I == tuple_size::value) { -+ return cute::false_type{}; // Terminal: TupleB is exhausted -+ } else if constexpr (I == tuple_size::value) { -+ return cute::true_type{}; // Terminal: TupleA is exhausted, TupleB is not exhausted -+ } else { -+ return lex_less(get(a), get(b)) || (get(a) == get(b) && lex_less_impl(a,b)); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+colex_less_impl(TupleA const& a, TupleB const& b) -+{ -+ if constexpr (I == tuple_size::value) { -+ return cute::false_type{}; // Terminal: TupleB is exhausted -+ } else if constexpr (I == tuple_size::value) { -+ return cute::true_type{}; // Terminal: TupleA is exhausted, TupleB is not exhausted -+ } else { -+ constexpr std::size_t A = tuple_size::value - 1 - I; -+ constexpr std::size_t B = tuple_size::value - 1 - I; -+ return colex_less(get(a), get(b)) || (get(a) == get(b) && colex_less_impl(a,b)); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+elem_less_impl(TupleA const& a, TupleB const& b) -+{ -+ if constexpr (I == tuple_size::value) { -+ return cute::true_type{}; // Terminal: TupleA is exhausted -+ } else if constexpr (I == tuple_size::value) { -+ return cute::false_type{}; // Terminal: TupleA is not exhausted, TupleB is exhausted -+ } else { -+ return elem_less(get(a), get(b)) && elem_less_impl(a,b); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+} // end namespace detail -+ -+// Lexicographical comparison -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+lex_less(IntTupleA const& a, IntTupleB const& b) -+{ -+ if constexpr (is_tuple::value && is_tuple::value) { -+ return detail::lex_less_impl<0>(a, b); -+ } else { -+ return a < b; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+lex_leq(T const& t, U const& u) { -+ return !lex_less(u, t); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+lex_gtr(T const& t, U const& u) { -+ return lex_less(u, t); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+lex_geq(T const& t, U const& u) { -+ return !lex_less(t, u); -+} -+ -+// Colexicographical comparison -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+colex_less(IntTupleA const& a, IntTupleB const& b) -+{ -+ if constexpr (is_tuple::value && is_tuple::value) { -+ return detail::colex_less_impl<0>(a, b); -+ } else { -+ return a < b; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+colex_leq(T const& t, U const& u) { -+ return !colex_less(u, t); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+colex_gtr(T const& t, U const& u) { -+ return colex_less(u, t); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+colex_geq(T const& t, U const& u) { -+ return !colex_less(t, u); -+} -+ -+// Elementwise [all] comparison -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+elem_less(IntTupleA const& a, IntTupleB const& b) -+{ -+ if constexpr (is_tuple::value && is_tuple::value) { -+ return detail::elem_less_impl<0>(a, b); -+ } else { -+ return a < b; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+elem_leq(T const& t, U const& u) { -+ return !elem_less(u, t); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+elem_gtr(T const& t, U const& u) { -+ return elem_less(u, t); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+elem_geq(T const& t, U const& u) { -+ return !elem_less(t, u); -+} -+ -+/** Increment a (dynamic) coord lexicographically within a shape -+ * \code -+ * auto shape = make_shape(1,2,make_shape(2,3),3); -+ * -+ * int i = 0; -+ * for (auto coord = repeat_like(shape, 0); back(coord) != back(shape); increment(coord, shape)) { -+ * std::cout << i++ << ": " << coord << std::endl; -+ * } -+ * assert(i == size(shape)); -+ * \endcode -+ */ -+template -+CUTE_HOST_DEVICE constexpr -+void -+increment(Coord& coord, Shape const& shape); -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+void -+increment(Coord& coord, Shape const& shape, seq) -+{ -+ cute::increment(get(coord), get(shape)); -+ if constexpr (sizeof...(Is) != 0) { -+ if (back(get(coord)) == back(get(shape))) { -+ back(get(coord)) = 0; -+ increment(coord, shape, seq{}); -+ } -+ } -+} -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE constexpr -+void -+increment(Coord& coord, Shape const& shape) -+{ -+ if constexpr (is_integral::value && is_integral::value) { -+ ++coord; -+ } else if constexpr (is_tuple::value && is_tuple::value) { -+ static_assert(tuple_size::value == tuple_size::value, "Mismatched ranks"); -+ detail::increment(coord, shape, tuple_seq{}); -+ } else { -+ static_assert(sizeof(Coord) == 0, "Invalid parameters"); -+ } -+} -+ -+struct ForwardCoordIteratorSentinal -+{}; -+ -+// A forward iterator for a coordinate that starts from zero and goes to shape -+template -+struct ForwardCoordIterator -+{ -+ static_assert(is_congruent::value); -+ -+ CUTE_HOST_DEVICE constexpr -+ Coord const& operator*() const { return coord; } -+ -+ CUTE_HOST_DEVICE constexpr -+ ForwardCoordIterator& operator++() { increment(coord, shape); return *this; } -+ -+ // Sentinal for the end of the implied range -+ CUTE_HOST_DEVICE constexpr -+ bool operator< (ForwardCoordIteratorSentinal const&) const { return back(coord) < back(shape); } -+ CUTE_HOST_DEVICE constexpr -+ bool operator==(ForwardCoordIteratorSentinal const&) const { return back(coord) == back(shape); } -+ CUTE_HOST_DEVICE constexpr -+ bool operator!=(ForwardCoordIteratorSentinal const&) const { return back(coord) != back(shape); } -+ // NOTE: These are expensive, avoid use -+ CUTE_HOST_DEVICE constexpr -+ bool operator< (ForwardCoordIterator const& other) const { return colex_less(coord, other.coord); } -+ CUTE_HOST_DEVICE constexpr -+ bool operator==(ForwardCoordIterator const& other) const { return coord == other.coord; } -+ CUTE_HOST_DEVICE constexpr -+ bool operator!=(ForwardCoordIterator const& other) const { return coord != other.coord; } -+ -+ Coord coord; -+ Shape const& shape; -+}; -+ -+// A forward iterator for a coordinate that starts from zero -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_coord_iterator(Shape const& shape) -+{ -+ auto coord = repeat_like(shape, int(0)); -+ return ForwardCoordIterator{coord,shape}; -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/layout.hpp b/3rdparty/cutlass/include/cute/layout.hpp -new file mode 100644 -index 0000000..fe937ee ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/layout.hpp -@@ -0,0 +1,1638 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+#include -+#include -+ -+namespace cute -+{ -+ -+// Aliases -+ -+template -+using Shape = IntTuple; -+ -+template -+using Stride = IntTuple; -+ -+template -+using Step = IntTuple; -+ -+template -+using Coord = IntTuple; -+ -+template -+CUTE_HOST_DEVICE constexpr -+Shape -+make_shape(Ts const&... t) { -+ return {t...}; -+} -+template -+CUTE_HOST_DEVICE constexpr -+Stride -+make_stride(Ts const&... t) { -+ return {t...}; -+} -+template -+CUTE_HOST_DEVICE constexpr -+Step -+make_step(Ts const&... t) { -+ return {t...}; -+} -+template -+CUTE_HOST_DEVICE constexpr -+Coord -+make_coord(Ts const&... t) { -+ return {t...}; -+} -+ -+ -+template > -+struct Layout -+ : private cute::tuple // EBO for static layouts -+{ -+ // Avoid bad CTAD: -+ // Layout smem = GMMA::Layout_MN_SW128_Atom; -+ // Should fail because smem is a ComposedLayout (SwizzleLayout) and not a Layout -+ static_assert(is_integral::value || is_tuple::value); -+ -+ // Expensive in compilation time... -+ //static_assert(is_congruent::value, -+ // "Shape and Stride must have the same hierarchical structure"); -+ //static_assert(is_integral::value || is_tuple::value); -+ -+ // NOTE: This defaults static Shapes/Strides correctly, but not dynamic -+ CUTE_HOST_DEVICE constexpr -+ Layout(LogicalShape const& logical_shape = {}, -+ LogicalStride const& logical_stride = {}) -+ : cute::tuple(logical_shape, logical_stride) -+ {} -+ -+ // -+ // Accessors -+ // -+ -+ static constexpr int rank = rank_v ; -+ -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ layout() { -+ return *this; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ layout() const { -+ return *this; -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ shape() { -+ return get<0,I...>(static_cast&>(*this)); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ shape() const { -+ return get<0,I...>(static_cast const&>(*this)); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ stride() { -+ return get<1,I...>(static_cast&>(*this)); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ stride() const { -+ return get<1,I...>(static_cast const&>(*this)); -+ } -+ -+ // -+ // Mappings -+ // -+ -+ // Map a logical coordinate to a linear index (Coord has no Underscore slice operators) -+ // OR -+ // Slice the layout and return the sublayout (Coord has an Underscore slice op) -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ operator()(Coord const& coord) const { -+ if constexpr (has_underscore::value) { -+ return slice(coord, *this); -+ } else { -+ return crd2idx(coord, shape(), stride()); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+ } -+ -+ // Convenience function for multi-dimensional coordinates -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) const { -+ return operator()(make_coord(c0,c1,cs...)); -+ } -+ -+ // Map a linear index to a hier ND logical coordinate -+ // NOTE: Dangerous and error-prone -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ operator[](Int const& linear_idx) const { -+ static_assert(is_integral::value); -+ return get_hier_coord(linear_idx); -+ } -+ -+ // -+ // Compose -+ // -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ compose(OtherLayout const& other) const { -+ return composition(*this, other); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ compose(Layouts const&... layouts) const { -+ return composition(*this, make_tile(layouts...)); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ with_shape(OtherShape const& shape) const { -+ return composition(*this, make_layout(shape)); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ with_shape(Shapes const&... shapes) const { -+ return composition(*this, make_layout(make_shape(shapes...))); -+ } -+ -+ // -+ // Tile -+ // -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ tile(OtherLayout const& other) const { -+ return tiled_divide(*this, other); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ tile(Layouts const&... layouts) const { -+ return tiled_divide(*this, make_tile(layouts...)); -+ } -+ -+ // -+ // Utility -+ // -+ -+ // -+ // Index to Coordinate -+ // -+ -+ // NOTE: Only valid for compact layouts -+ -+ // Return the (hierarchical) ND logical coordinate corresponding to the linear index -+ // @post crd2idx(@a result, shape(), stride()) == idx -+ // @post congruent(@a result, shape()) -+ template ::value)> -+ CUTE_HOST_DEVICE constexpr -+ auto -+ get_hier_coord(IInt const& idx) const { -+ return cute::idx2crd(idx, shape(), stride()); -+ } -+ -+ // Return the (flat) ND logical coordinate corresponding to the linear index -+ // @post crd2idx(@a result, shape(), stride()) == idx -+ // @post rank(@a result) == rank(shape()) && depth(@a result) == 1 -+ template ::value)> -+ CUTE_HOST_DEVICE constexpr -+ auto -+ get_flat_coord(IInt const& idx) const { -+ return cute::crd2crd(this->get_hier_coord(idx), shape(), repeat(Int<1>{})); -+ } -+ -+ // Return the generalized column-major 1D logical coordinate corresponding to the linear index -+ // @post crd2idx(@a result, shape(), stride()) == idx -+ // @post is_integral::value -+ template ::value)> -+ CUTE_HOST_DEVICE constexpr -+ auto -+ get_1d_coord(IInt const& idx) const { -+ return cute::crd2idx(this->get_hier_coord(idx), shape()); -+ } -+ -+ // -+ // Coordinate to Coordinate -+ // -+ -+#if 0 -+ // Return the (hierarchical) ND logical coordinate corresponding to the linear index -+ // @post congruent(@a result, shape()) -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ crd_2_hier_coord(Coord const& crd) const { -+ return cute::crd2crd(crd, shape(), shape()); -+ } -+ -+ // Return the (flat) ND logical coordinate corresponding to the linear index -+ // @post rank(@a result) == rank(shape()) && depth(@a result) == 1 -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ crd_2_flat_coord(Coord const& crd) const { -+ return cute::crd2crd(crd, shape(), product_each(shape())); -+ } -+ -+ // Return the generalized column-major 1D logical coordinate corresponding to the linear index -+ // @post is_integral::value -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ crd_2_1d_coord(Coord const& crd) const { -+ //return cute::crd2crd(crd, shape(), product(shape())); -+ return cute::crd2idx(crd, shape()); -+ } -+#endif -+}; -+ -+ -+template -+struct is_layout : false_type {}; -+template -+struct is_layout> : true_type {}; -+ -+ -+template ::value || is_integral::value) && -+ (is_tuple::value || is_integral::value))> -+CUTE_HOST_DEVICE constexpr -+auto -+make_layout(Shape const& shape, Stride const& stride) -+{ -+ return Layout(shape, stride); -+} -+ -+template ::value || is_integral::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+make_layout(Shape const& shape) -+{ -+ return make_layout(shape, compact_col_major(shape)); -+} -+ -+// Construct a layout from multiple layouts by -+// concatenating each layout as an independent mode -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_layout(Layout const&... layouts) -+{ -+ return make_layout(make_shape (layouts.shape()...), -+ make_stride(layouts.stride()...)); -+} -+ -+// -+// Convenience tags for common layouts -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_layout(Shape const& shape, GenColMajor) -+{ -+ return make_layout(shape, compact_col_major(shape)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_layout(Shape const& shape, GenRowMajor) -+{ -+ return make_layout(shape, compact_row_major(shape)); -+} -+ -+// Follow the same ordering induced by the strides, but make the layout compact -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_ordered_layout(Shape const& shape, Order const& order) -+{ -+ static_assert(is_static::value && is_static::value); -+ return make_layout(shape, compact_order(shape, order)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_ordered_layout(Layout const& layout) -+{ -+ return make_ordered_layout(layout.shape(), layout.stride()); -+} -+ -+// Make a layout of the same shape that is either ordered or colmajor depending on staticness -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_layout_like(Layout const& layout) -+{ -+ if constexpr (is_static::value && is_static::value) { -+ return make_ordered_layout(layout.shape(), layout.stride()); -+ } else { -+ return make_layout(layout.shape()); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// Make a layout of the same shape, -+// with mode-0 being colmajor then following the the mode order in layout -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_fragment_like(Layout const& layout) -+{ -+ auto shape = replace<0>(layout.shape(), size<0>(layout)); -+ auto order = replace<0>(layout.stride(), Int<0>{}); -+ if constexpr (is_static::value && is_static::value) { -+ return make_ordered_layout(shape, order); -+ } else { -+ return make_layout(layout.shape()); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_identity_layout(Shape const& shape) -+{ -+ return make_layout(shape, make_basis_like(shape)); -+} -+ -+// -+// Operations to manipulate Layouts like a tuple of pairs -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+get(Layout const& layout) -+{ -+ // Let the static_asserts in get(shape|stride) catch problems -+ return make_layout(get(layout.shape()), get(layout.stride())); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+take(Layout const& layout) -+{ -+ // Let the static_asserts in take(shape|stride) catch problems -+ return make_layout(take(layout.shape()), take(layout.stride())); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+flatten(Layout const& layout) -+{ -+ return make_layout(flatten(layout.shape()), flatten(layout.stride())); -+} -+ -+// -+// Utilities -+// -+ -+// Return the layout of a mode -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+layout(Layout const& layout) -+{ -+ if constexpr (sizeof...(Is) == 0) { -+ return layout; -+ } else { -+ return get(layout); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// Return the shape of a mode -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+shape(Layout& layout) -+{ -+ return layout.template shape(); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+shape(Layout const& layout) -+{ -+ return layout.template shape(); -+} -+ -+// Return the stride of a mode -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+stride(Layout& layout) -+{ -+ return layout.template stride(); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+stride(Layout const& layout) -+{ -+ return layout.template stride(); -+} -+ -+// Return the number of elements in a mode -+template -+CUTE_HOST_DEVICE constexpr -+auto -+size(Layout const& layout) -+{ -+ return size(shape(layout)); -+} -+ -+// Return the number of modes -+template -+CUTE_HOST_DEVICE constexpr -+auto -+rank(Layout const& layout) -+{ -+ return rank(shape(layout)); -+} -+ -+// Return the depth of the layout -+template -+CUTE_HOST_DEVICE constexpr -+auto -+depth(Layout const& layout) -+{ -+ return depth(shape(layout)); -+} -+ -+// Return the codomain size of a mode -+// @return M smallest integer such that @a sub_layout(c) < M for all c < size(@a sub_layout) -+// where sub_layout = get(layout). -+template -+CUTE_HOST_DEVICE constexpr -+auto -+cosize(Layout const& layout) -+{ -+ // Protect against negative strides -+ auto abs_sub_layout = make_layout(shape(layout), -+ transform_leaf(stride(layout), abs_fn{})); -+ return abs_sub_layout(size(abs_sub_layout) - Int<1>{}) + Int<1>{}; -+} -+ -+template -+using cosize_t = decltype(cosize(std::declval())); -+ -+template -+static constexpr int cosize_v = cosize_t::value; -+ -+// Equality -+// Return a static or dynamic boolean -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator==(Layout const& layoutA, Layout const& layoutB) -+{ -+ return layoutA.shape() == layoutB.shape() && layoutA.stride() == layoutB.stride(); -+} -+ -+// With crd2idx(coord, shape), makes sense to have crd2idx(coord, Layout) as well -+template -+CUTE_HOST_DEVICE constexpr -+auto -+crd2idx(Coord const& c, Layout const& layout) -+{ -+ return crd2idx(c, layout.shape(), layout.stride()); -+} -+ -+// -+// Slice and Dice a layout -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+slice(Coord const& c, Layout const& layout) -+{ -+ return make_layout(slice(c, layout.shape()), -+ slice(c, layout.stride())); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+slice_and_offset(Coord const& c, Layout const& layout) -+{ -+ return cute::make_tuple(slice(c, layout), crd2idx(c, layout)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+dice(Coord const& c, Layout const& layout) -+{ -+ return make_layout(dice(c, layout.shape()), -+ dice(c, layout.stride())); -+} -+ -+// -+// Transform the modes of a layout -+// -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+transform_layout(Tuple const& t, F&& f, seq) -+{ -+ return make_layout(f(get(t))...); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+transform_layout(Tuple0 const& t0, Tuple1 const& t1, F&& f, seq, seq, seq) -+{ -+ return make_layout(f(get(t0),get(t1))..., get(t0)..., get(t1)...); -+} -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+transform_layout(Tuple const& t, F&& f) -+{ -+ return detail::transform_layout(t, f, make_seq{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+transform_layout(Tuple0 const& t0, Tuple1 const& t1, F&& f) -+{ -+ constexpr int R0 = decltype(rank(t0))::value; -+ constexpr int R1 = decltype(rank(t1))::value; -+ constexpr int R = (R0 < R1) ? R0 : R1; -+ return detail::transform_layout(t0, t1, f, make_seq{}, make_range{}, make_range{}); -+} -+ -+// -+// Coalesce and Filter -+// -+ -+namespace detail { -+ -+// Look at each element and the front of the stack (in order of priority) -+// front(NewLayout) get(Layout) -+// s0:d0 _1:d1 => continue -+// _1:d0 s1:d1 => replace_front s1:d1 -+// s0:s1*d1 s1:d1 => replace_front s0*s1:d1 -+// s0:d0 s1:d1 => prepend s1:d1 -+// -+// @pre OldShape and OldStride are flat -+template -+CUTE_HOST_DEVICE constexpr -+auto -+bw_coalesce(OldShape const& old_shape, OldStride const& old_stride, -+ NewShape const& new_shape, NewStride const& new_stride) -+{ -+ if constexpr (I == -1) { -+ // Base case, we're done -+ if constexpr (is_constant<1, NewShape>::value) { -+ return Layout<_1,_0>{}; -+ } else { -+ return Layout{new_shape,new_stride}; -+ } -+ } else if constexpr (is_constant<1, decltype(get(old_shape))>::value) { -+ // shape(layout) == _1, skip it and continue -+ return bw_coalesce(old_shape, old_stride, new_shape, new_stride); -+ } else if constexpr (is_constant<1, NewShape>::value) { -+ // Replace our shape-1 with anything (Can only happen on input new_shape/new_stride) -+ return bw_coalesce(old_shape, old_stride, get(old_shape), get(old_stride)); -+ } else if constexpr (is_constant(old_shape) * get(old_stride) == get<0>(new_stride))>::value) { -+ // Merge modes because the shapes and strides match -+ return bw_coalesce(old_shape, old_stride, -+ replace_front(new_shape, get(old_shape) * get<0>(new_shape)), -+ replace_front(new_stride, get(old_stride))); -+ } else { -+ // Can't replace or merge, so prepend a new mode -+ return bw_coalesce(old_shape, old_stride, -+ prepend(new_shape, get(old_shape)), -+ prepend(new_stride, get(old_stride))); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+} // end namespace detail -+ -+// Combine all the modes that are possible to combine -+// Does not respect the profile of the layout, but does preserve total size -+template -+CUTE_HOST_DEVICE constexpr -+auto -+coalesce(Layout const& layout) -+{ -+ auto flat_shape = flatten(layout.shape()); -+ auto flat_stride = flatten(layout.stride()); -+ -+ constexpr int R = decltype(rank(flat_shape))::value; -+ return detail::bw_coalesce(flat_shape, flat_stride, get(flat_shape), get(flat_stride)); -+} -+ -+// Apply coalesce at the terminals of trg_profile -+template -+CUTE_HOST_DEVICE constexpr -+auto -+coalesce(Layout const& layout, IntTuple const& trg_profile) -+{ -+ if constexpr (is_tuple::value) { -+ static_assert(tuple_size::value <= Layout::rank); -+ return transform_layout(layout, trg_profile, [](auto const& l, auto const& t) { return coalesce(l,t); }); -+ } else { -+ return coalesce(layout); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// Replace the modes in layout that have a 0-stride with a 1-size -+template -+CUTE_HOST_DEVICE constexpr -+auto -+filter_zeros(Layout const& layout) -+{ -+ return make_layout(filter_zeros(layout.stride(), layout.shape()), layout.stride()); -+} -+ -+// Remove all of the 0-strides and 1-sizes -+// Return 1-shape if empty -+template -+CUTE_HOST_DEVICE constexpr -+auto -+filter(Layout const& layout) -+{ -+ return coalesce(filter_zeros(layout)); -+} -+ -+// Apply filter at the terminals of trg_profile -+template -+CUTE_HOST_DEVICE constexpr -+auto -+filter(Layout const& layout, IntTuple const& trg_profile) -+{ -+ if constexpr (is_tuple::value) { -+ static_assert(tuple_size::value <= Layout::rank); -+ return transform_layout(layout, trg_profile, [](auto const& l, auto const& t) { return filter(l,t); }); -+ } else { -+ return filter(layout); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// Append, Prepend, Replace -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+append(Layout const& layout, -+ Layout const& x = {}) -+{ -+ return make_layout(append(layout.shape(), x.shape()), -+ append(layout.stride(), x.stride())); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+prepend(Layout const& layout, -+ Layout const& x = {}) -+{ -+ return make_layout(prepend(layout.shape(), x.shape()), -+ prepend(layout.stride(), x.stride())); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+replace(Layout const& layout, -+ Layout const& x) -+{ -+ return make_layout(replace(layout.shape(), x.shape()), -+ replace(layout.stride(), x.stride())); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+group(Layout const& layout) -+{ -+ return make_layout(group(layout.shape()), -+ group(layout.stride())); -+} -+ -+// -+// Composition of two layouts: lhs o rhs -+// @post compatible(rhs, result) -+// @post result(c) = lhs(rhs(c)) -+// for all c in the domain of result -+// -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+composition(Layout const& lhs, -+ RShape const& rhs_shape, RStride const& rhs_stride) -+{ -+ if constexpr (is_tuple::value) { -+ // Apply the right-distributivity of Layout composition -+ return transform_layout(rhs_shape, rhs_stride, [&](auto const& s, auto const& d) { return composition(lhs, s, d); }); -+ } else -+ if constexpr (is_scaled_basis::value) { -+ // Special case for a ScaledBasis stride -+ return composition(get(lhs), rhs_shape, rhs_stride.value()); -+ } else -+ if constexpr (is_integral::value) { -+ // Integral Rstride (and RShape) -+ -+ // NOTE: Should only flatten once for efficiency -+ auto flat_shape = flatten(lhs.shape()); -+ auto flat_stride = flatten(lhs.stride()); -+ [[maybe_unused]] constexpr int R = rank(flat_shape); -+ -+ if constexpr (is_constant<0, RStride>::value) { -+ // Special case shortcut for any static stride-0 -+ return Layout{rhs_shape, rhs_stride}; -+ } else -+ if constexpr (is_integral::value) { -+ // Special case shortcut for any integral LShape -+ auto result_stride = rhs_stride * flat_stride; -+ return Layout{rhs_shape, result_stride}; -+ } else -+ if constexpr (is_constant<1, RStride>::value) { -+ // Special case shortcut for any static stride-1 -+ auto result_shape_0 = take<0,R-1>(flat_shape); -+ -+ // Mod out the rhs_shape from the lhs.shape() -+ auto const [result_shape_1, rest_shape] = fold(result_shape_0, make_tuple(make_tuple(), rhs_shape), -+ [] (auto const& init, auto const& si) { -+ return make_tuple(append(get<0>(init), cute::min(abs(si), get<1>(init))), shape_div(get<1>(init), abs(si))); -+ }); -+ -+ // Jump into coalesce and append (rest_shape, get(lhs.stride()) -+ return detail::bw_coalesce(result_shape_1, flat_stride, rest_shape, get(flat_stride)); -+ } else -+ { -+ // General case -+ auto result_shape_0 = take<0,R-1>(flat_shape); -+ auto result_stride_0 = take<0,R-1>(flat_stride); -+ -+ // Divide out the rhs_stride from the lhs.shape() -+ auto const [result_shape_1, rest_stride] = fold(result_shape_0, make_tuple(make_tuple(), rhs_stride), -+ [] (auto const& init, auto const& di) { -+ return make_tuple(append(get<0>(init), shape_div(di, get<1>(init))), shape_div(get<1>(init), di)); -+ }); -+ -+ // Apply any lhs.shape() changes to the stride -+ auto result_stride_1 = elem_scale(result_stride_0, shape_div(result_shape_0, result_shape_1)); -+ -+ // Mod out the rhs_shape from the lhs.shape() -+ auto const [result_shape_2, rest_shape] = fold(result_shape_1, make_tuple(make_tuple(), rhs_shape), -+ [] (auto const& init, auto const& si) { -+ return make_tuple(append(get<0>(init), cute::min(abs(si), get<1>(init))), shape_div(get<1>(init), abs(si))); -+ }); -+ -+ // Jump into coalesce and append (rest_shape, rest_stride * get(lhs.stride()) -+ return detail::bw_coalesce(result_shape_2, result_stride_1, rest_shape, rest_stride * get(flat_stride)); -+ } -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+composition(Layout const& lhs, -+ Layout const& rhs) -+{ -+ //return detail::composition(flatten(lhs), rhs.shape(), rhs.stride()); -+ return detail::composition(lhs, rhs.shape(), rhs.stride()); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+composition(Layout const& lhs, -+ IntTuple const& rhs) -+{ -+ if constexpr (is_tuple::value) { -+ static_assert(tuple_size::value <= Layout::rank); -+ // Drop any modes of lhs that aren't hit by rhs -+ return detail::transform_layout(lhs, rhs, [](auto const& l, auto const& r) { return composition(l,r); }, make_seq::value>{}, seq<>{}, seq<>{}); -+ } else if constexpr (is_underscore::value) { -+ return lhs; -+ } else { -+ return composition(lhs, make_layout(rhs)); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// Complement -+// -+// Build the complement of a layout. -+// @post size(@a result) >= @a cosize_hi / size(filter(@a layout))); -+// @post For all i in [1,size(@a result)), -+// @a result(i) < @a result(i-1) -+// For all j in [0, size(@a layout)), -+// @a result(i) != @a layout(j) -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+complement(Layout const& layout, CoSizeHi const& cosize_hi) -+{ -+ // Remove the stride-0 modes, the size-1 modes, and flatten the layout -+ auto flat_layout = filter(layout); -+ -+ if constexpr (is_constant<0, decltype(flat_layout.stride())>::value) { -+ // Special case for stride-0 layout -+ return make_layout(cosize_hi); -+ } else { -+ // General case -+ constexpr int R = decltype(rank(flat_layout))::value; -+ static_assert(R == 1 || is_static::value, -+ "Dynamic-stride complement only for rank-1 layouts"); -+ -+ // Should just be a sort and a fold... -+ // Then we could even handle dynamic strides (but they would destroy all static strides) -+ auto result = fold(make_seq{}, -+ make_tuple(flat_layout.shape(), -+ flat_layout.stride(), -+ make_tuple(), -+ make_tuple(Int<1>{})), -+ [](auto const& init, auto i) -+ { -+ auto curr_stride = cute::min(get<1>(init)); -+ auto curr_idx = find(get<1>(init), curr_stride); -+ auto curr_shape = get(get<0>(init)); -+ -+ return make_tuple(remove(get<0>(init)), // Remove the curr shape -+ remove(get<1>(init)), // Remove the curr stride -+ append(get<2>(init), curr_stride / get<3,i>(init)), // new shape = curr_stride / last_stride -+ append(get<3>(init), curr_shape * curr_stride)); // new stride = curr_shape * curr_stride -+ }); -+ -+ // Append the last shape mode -+ auto result_stride = get<3>(result); -+ auto result_shape = append(get<2>(result), get<1,0>(result) / back(result_stride)); // new shape = curr_stride / last_stride -+ -+ // Compute the rest_stride -+ auto rest_stride = get<0,0>(result) * get<1,0>(result); -+ //return make_layout(append(result_shape, ceil_div(cosize_hi, rest_stride)), append(result_stride, rest_stride)); -+ // Jump into coalesce and append (ceil_div(cosize_hi, rest_stride), rest_stride) -+ return detail::bw_coalesce(result_shape, result_stride, ceil_div(cosize_hi, rest_stride), rest_stride); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+complement(Layout const& layout) -+{ -+ return complement(layout, cosize(layout)); -+} -+ -+// -+// Right-Inverse and Left-Inverse -+// -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+inverse_seq(Shape const& shape, Stride const& stride, seq) -+{ -+ if constexpr (I == decltype(rank(stride))::value) { -+ return seq{}; -+ } else { -+ //auto next_stride = get(shape) * get(stride); -+ using next_stride = decltype(get(shape) * get(stride)); // NOTE: WAR for g++-7 -+ -+ if constexpr (is_static::value) { -+ auto next_idx = find_if(stride, [](auto a) { return is_constant{}; }); -+ return inverse_seq(shape, stride, seq{}); -+ } else { -+ return seq{}; -+ } -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+} // end namespace detail -+ -+// -+// Build the right-inverse of a layout -+// @pre is_static -+// @result A layout @a result such that -+// @a layout(@a result(i)) == i for all i < size(@a result) -+// @result A layout @a result such that -+// composition(@a layout, @a result) is identical to make_layout(shape(result)) -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+right_inverse(Layout const& layout) -+{ -+ auto flat_layout = coalesce(layout); -+ auto astride = transform_leaf(flat_layout.stride(), abs_fn{}); -+ -+ // Find Int<1>{}, the starting idx, and follow the strides to gen inverse_seq -+ auto next_I = find_if(astride, [](auto a) { return is_constant<1, decltype(a)>{}; }); -+ [[maybe_unused]] auto iseq = detail::inverse_seq(flat_layout.shape(), astride, seq<>{}); -+ -+ if constexpr (tuple_size::value == 0) { -+ return Layout<_1,_0>{}; // Empty case, nothing found -+ } else { -+ // Generate the corresponding new strides and construct -+ auto rstride = compact_col_major(flat_layout.shape()); -+ return make_layout(unwrap(transform(iseq, [&](auto i) { return shape(flat_layout); })), -+ unwrap(transform(iseq, [&](auto i) { return signum(stride(flat_layout)) * get(rstride); }))); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+CUTE_HOST_DEVICE constexpr -+auto -+right_inverse(Underscore const& _) -+{ -+ return _; -+} -+ -+// -+// Build the left-inverse of a layout -+// @pre is_static -+// @pre not has_int0 // @a layout has no 0-strides (is injective) -+// @result A layout @a result such that -+// @a result(@a layout(i)) == i for all i < size(@a layout) -+// @result A layout @a result such that -+// composition(@a result, @a layout) is identical to make_layout(shape(layout)) -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+left_inverse(Layout const& layout) -+{ -+ return right_inverse(make_layout(layout, complement(layout))); -+} -+ -+CUTE_HOST_DEVICE constexpr -+auto -+left_inverse(Underscore const& _) -+{ -+ return _; -+} -+ -+// -+// Max Common Vector -+// -+ -+/* Return Int such that N is the maximum number of continguous elements -+ * that logically correspond in the layouts of @a a and @a b. This is, -+ * the number of elements that could reasonably be "vectorized" in the layouts. -+ * -+ * @returns Int with N >= 1 -+ * @post For all 0 <= n < N, a(b[n]) == n (NOTE: Problems with negative strides/coords in this post-condition) -+ */ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+max_common_vector(Layout const& a, Layout const& b) -+{ -+ if constexpr (is_static>::value && -+ is_static>::value) -+ { -+ auto result = coalesce(composition(a, right_inverse(b))); -+ -+ if constexpr (is_constant<1, decltype(stride<0>(result))>::value) { -+ return shape<0>(result); -+ } else { -+ return Int<1>{}; -+ } -+ } else { -+ // Dynamic case NOTE: could weaken if we assume dynamic strides are large and multiples of the vector -+ return Int<1>{}; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// Zip -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+zip(Layout const& layout) -+{ -+ return make_layout(zip(layout.shape()), -+ zip(layout.stride())); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+zip(Layout const& layoutA, -+ Layout const& layoutB) -+{ -+ return make_layout(zip(layoutA.shape(), layoutB.shape()), -+ zip(layoutA.stride(), layoutB.stride())); -+} -+ -+// -+// Tile unzip -+// Logical product and logical divide (on layouts) produce rank-2 results by design. -+// Follow the profile of @a tile and zip the rank-2 modes located at the terminals into -+// their own mode. -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tile_unzip(Layout const& layout, -+ IntTuple const& tile) -+{ -+ return make_layout(zip2_by(layout.shape(), tile), -+ zip2_by(layout.stride(), tile)); -+} -+ -+// -+// Logical divide -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+logical_divide(Layout const& layout, -+ Layout const& tile) -+{ -+ //CUTE_STATIC_ASSERT_V(size(layout) % size(tile) == Int<0>{}, -+ // "Tiling does not evenly divide the block"); -+ // NOTE: With tiles that have stride-0, this doesn't have to be true -+ -+ return composition(layout, make_layout(tile, complement(tile, size(layout)))); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+logical_divide(Layout const& layout, -+ IntTuple const& tile) -+{ -+ if constexpr (is_tuple::value) { -+ static_assert(tuple_size::value <= Layout::rank, "logical_divide: Too many modes in tile."); -+ return transform_layout(layout, tile, [](auto const& l, auto const& t) { return logical_divide(l,t); }); -+ } else if constexpr (is_underscore::value) { -+ return layout; -+ } else if constexpr (is_integral::value) { -+ return logical_divide(layout, make_layout(tile)); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// Convenience operator -+// that produces layouts like ((BLK_A,BLK_B,...),(a,b,...,x,y)) -+// by gathering the tile modes and residuals into a rank-2 result. -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+zipped_divide(Layout const& layout, -+ Tile const& tile) -+{ -+ return tile_unzip(logical_divide(layout, tile), tile); -+} -+ -+// Same as zipped_divide, but unpacks the second mode: ((BLK_A,BLK_B,...),a,b,...,x,y) -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tiled_divide(Layout const& layout, -+ Tile const& tile) -+{ -+ auto div = zipped_divide(layout, tile); -+ -+ auto R = rank<1>(div); -+ return div(_, repeat(_)); -+} -+ -+// -+// Logical product -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+logical_product(Layout const& layout, -+ Layout const& tile) -+{ -+ return make_layout(layout, composition(complement(layout, size(layout)*cosize(tile)), tile)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+logical_product(Layout const& layout, -+ IntTuple const& tile) -+{ -+ if constexpr (is_tuple::value) { -+ static_assert(tuple_size::value <= Layout::rank); -+ return transform_layout(layout, tile, [](auto const& l, auto const& t) { return logical_product(l,t); }); -+ } else if constexpr (is_underscore::value) { -+ return layout; -+ } else if constexpr (is_integral::value) { -+ return logical_product(layout, make_layout(tile)); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// Convenience operator -+// that produces layouts like ((BLK_A,BLK_B,...),(a,b,...,x,y)) -+// by gathering the block modes and products into a rank-2 result. -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+zipped_product(Layout const& layout, -+ Tile const& tile) -+{ -+ return tile_unzip(logical_product(layout, tile), tile); -+} -+ -+// Same as zipped_product, but unpacks the second mode: ((BLK_A,BLK_B,...),a,b,...,x,y) -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tiled_product(Layout const& layout, -+ Tile const& tile) -+{ -+ auto div = zipped_product(layout, tile); -+ -+ auto R = rank(tile); -+ return div(_, repeat(_)); -+} -+ -+// Attempts to reproduce layout "block" over layout "layout" -+// That is, think of every element of "layout" as a "block" -+// and return the layout of the resulting structure -+template -+CUTE_HOST_DEVICE constexpr -+auto -+blocked_product(Layout const& block, -+ Layout const& layout) -+{ -+ constexpr int R = cute::max(rank_v, rank_v); -+ auto padded_block = append(block); -+ auto padded_layout = append(layout); -+ -+ auto result = logical_product(padded_block, padded_layout); -+ -+ return coalesce(zip(get<0>(result), get<1>(result)), repeat(Int<1>{})); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+raked_product(Layout const& block, -+ Layout const& layout) -+{ -+ constexpr int R = cute::max(rank_v, rank_v); -+ auto padded_block = append(block); -+ auto padded_layout = append(layout); -+ -+ auto result = logical_product(padded_block, padded_layout); -+ -+ return coalesce(zip(get<1>(result), get<0>(result)), repeat(Int<1>{})); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tile_to_shape(Layout const& layout, -+ TrgShape const& trg_shape, -+ ModeOrder const& ord_shape = {}) -+{ -+ CUTE_STATIC_ASSERT_V(rank(layout) <= rank(trg_shape), "Rank of layout must be <= rank of target shape."); -+ constexpr int R = rank_v; -+ -+ auto padded_layout = append(layout); -+ -+ auto layout_shape = product_each(padded_layout.shape()); -+ auto target_shape = product_each(trg_shape); -+ -+ // Assert proper division -+ CUTE_STATIC_ASSERT_V(sum(transform(target_shape, layout_shape, modulus{})) == Int<0>{}, -+ "Layout shape does not divide the target shape."); -+ -+ auto product_shape = shape_div(target_shape, layout_shape); -+ -+ return coalesce(blocked_product(padded_layout, make_ordered_layout(product_shape, ord_shape)), product_shape); -+} -+ -+// -+// Upcast -+// For stride-1 mode, divide size by N. Divide all other strides by N. -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+upcast(Shape const& shape, Stride const& stride) -+{ -+ if constexpr (is_tuple::value) { // tuple stride -+ return transform_layout(shape, stride, [](auto const& s, auto const& d) { return upcast(s,d); }); -+ } else if constexpr (is_constant<0, Stride>::value) { // static-0 stride -+ return Layout{shape,stride}; -+ } else if constexpr (is_static::value) { // static stride -+ return make_layout(shape_div(shape, shape_div(Int{}, abs(stride))), -+ shape_div(stride, Int{})); -+ } else { // dynamic stride -+ // assume dynamic strides are larger than N and divisible -+ // assert(stride % N == 0); -+ return make_layout(shape, safe_div(stride, Int{})); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+upcast(Layout const& layout) -+{ -+ return upcast(layout.shape(), layout.stride()); -+} -+ -+// -+// Downcast -+// For stride-1 mode, multiply size by N. Multiply all other strides by N. -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+downcast(Shape const& shape, Stride const& stride) -+{ -+ if constexpr (is_tuple::value) { -+ return transform_layout(shape, stride, [](auto const& s, auto const& d) { return downcast(s,d); }); -+ } else if constexpr (is_constant<1, Stride>::value || is_constant<-1, Stride>::value) { -+ return make_layout(shape * Int{}, stride); -+ } else { -+ return make_layout(shape, stride * Int{}); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+downcast(Layout const& layout) -+{ -+ CUTE_STATIC_ASSERT(has_int1::value, "Downcast requires adjacent elements"); -+ return downcast(layout.shape(), layout.stride()); -+} -+ -+// -+// Recast -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+recast(Layout const& layout) -+{ -+ if constexpr (sizeof(NewType) == sizeof(OldType)) { -+ return layout; -+ } else if constexpr (sizeof(NewType) > sizeof(OldType)) { -+ static_assert(sizeof(NewType) % sizeof(OldType) == 0, "NewType must be a multiple of OldType"); -+ return upcast(layout); -+ } else if constexpr (sizeof(NewType) < sizeof(OldType)) { -+ static_assert(sizeof(OldType) % sizeof(NewType) == 0, "NewType must be a divisor of OldType"); -+ return downcast(layout); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// Display utilities -+// -+ -+template -+CUTE_HOST_DEVICE void print(Layout const& layout) -+{ -+ print(layout.shape()); print(":"); print(layout.stride()); -+} -+ -+template -+CUTE_HOST std::ostream& operator<<(std::ostream& os, Layout const& layout) -+{ -+ return os << shape(layout) << ":" << stride(layout); -+} -+ -+// Generic 2D Layout to console table -+template -+CUTE_HOST_DEVICE -+void -+print_layout(Layout const& layout) // (m,n) -> idx -+{ -+ CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{}); -+ -+ int idx_width = num_digits(cosize(layout)) + 2; -+ const char* delim = "+-----------------------"; -+ -+ print(layout); print("\n"); -+ -+ // Column indices -+ print(" "); -+ for (int n = 0; n < size<1>(layout); ++n) { printf(" %*d ", idx_width-2, n); } -+ printf("\n"); -+ -+ // Print out A m-by-n -+ for (int m = 0; m < size<0>(layout); ++m) { -+ // Header -+ print(" "); -+ for (int n = 0; n < size<1>(layout); ++n) { printf("%.*s", idx_width+1, delim); } -+ printf("+\n"); -+ // Values -+ printf("%2d ", m); // Row indices -+ for (int n = 0; n < size<1>(layout); ++n) { printf("| %*d ", idx_width-2, int(layout(m,n))); } -+ printf("|\n"); -+ } -+ // Footer -+ print(" "); -+ for (int n = 0; n < size<1>(layout); ++n) { printf("%.*s", idx_width+1, delim); } -+ printf("+\n"); -+} -+ -+// Generic ThrVal 2D Layout to console table -+template -+CUTE_HOST_DEVICE -+void -+print_layout(Layout const& layout, ThrID const& thrid) // (m,n) -> (tid,vid) and tid -> thr_idx -+{ -+ CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{}); -+ -+ print(layout); print("\n"); -+ print(thrid); print("\n"); -+ -+ // Print out m-by-n -+ for (int m = 0; m < size<0>(layout); ++m) { -+ // Header -+ for (int n = 0; n < size<1>(layout); ++n) printf("+------"); -+ printf("+\n"); -+ // Values -+ for (int n = 0; n < size<1>(layout); ++n) printf("|%03d-%02d", int(thrid(layout(m,n) % size(thrid))), int(layout(m,n) / size(thrid))); -+ printf("|\n"); -+ } -+ // Footer -+ for (int n = 0; n < size<1>(layout); ++n) printf("+------"); -+ printf("+\n"); -+} -+ -+// Generic 2D Layout to Latex printer -- B&W 8-value color coding -+template -+CUTE_HOST_DEVICE -+void -+print_latex(Layout const& layout) // (m,n) -> idx -+{ -+ CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{}); -+ -+ char const* latex_header = -+ "\\documentclass[convert]{standalone}\n" -+ "\\usepackage{tikz}\n\n" -+ "\\begin{document}\n" -+ "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},box/.style={rectangle,draw=black,thick,minimum size=1cm,anchor=center,font=\\Large}]\n\n"; -+ char const* latex_footer = -+ "\\end{tikzpicture}\n" -+ "\\end{document}\n"; -+ -+ char const* color_map[8] = {"black!00", -+ "black!40", -+ "black!20", -+ "black!60", -+ "black!10", -+ "black!50", -+ "black!30", -+ "black!70"}; -+ -+ // Header -+ printf("%% Layout: "); print(layout); printf("\n"); -+ -+ printf(latex_header); -+ -+ // Layout -+ for (int i = 0; i < size<0>(layout); ++i) { -+ for (int j = 0; j < size<1>(layout); ++j) { -+ int idx = layout(i,j); -+ -+ printf("\\node[box,fill=%s] at (%d,%d) {%d};\n", -+ color_map[idx % 8], -+ i, j, -+ idx); -+ } -+ } -+ -+ // Labels -+ for (int i = 0, j = -1; i < size<0>(layout); ++i) { -+ printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i); -+ } -+ for (int j = 0, i = -1; j < size<1>(layout); ++j) { -+ printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, j); -+ } -+ -+ // Footer -+ printf(latex_footer); -+} -+ -+// Generic ThrVal 2D Layout to Latex TIKZ -- 8-value color coded by thread -+template -+CUTE_HOST_DEVICE -+void -+print_latex(Layout const& layout, ThrID const& thr) // (m,n) -> (tid,vid) and tid -> thr_idx -+{ -+ CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{}); -+ -+ char const* latex_header = -+ "\\documentclass[convert]{standalone}\n" -+ "\\usepackage{tikz}\n\n" -+ "\\begin{document}\n" -+ "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},box/.style={rectangle,draw=black,thick,minimum size=1cm,anchor=center}]\n\n"; -+ char const* latex_footer = -+ "\\end{tikzpicture}\n" -+ "\\end{document}\n"; -+ -+ char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}", -+ "{rgb,255:red,175;green,255;blue,175}", -+ "{rgb,255:red,255;green,255;blue,175}", -+ "{rgb,255:red,255;green,175;blue,175}", -+ "{rgb,255:red,210;green,210;blue,255}", -+ "{rgb,255:red,210;green,255;blue,210}", -+ "{rgb,255:red,255;green,255;blue,210}", -+ "{rgb,255:red,255;green,210;blue,210}"}; -+ -+ // Header -+ printf("%% layout: "); print(layout); printf("\n"); -+ printf("%% thrid: "); print(thr); printf("\n\n"); -+ -+ printf(latex_header); -+ -+ // Layout -+ for (int i = 0; i < size<0>(layout); ++i) { -+ for (int j = 0; j < size<1>(layout); ++j) { -+ int thrid = layout(i,j) % size(thr); -+ int val_idx = layout(i,j) / size(thr); -+ int thr_idx = thr(thrid); -+ -+ printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", -+ color_map[thr_idx % 8], -+ i, j, -+ thr_idx, val_idx); -+ } -+ } -+ -+ // Labels -+ for (int i = 0, j = -1; i < size<0>(layout); ++i) { -+ printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i); -+ } -+ for (int j = 0, i = -1; j < size<1>(layout); ++j) { -+ printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, j); -+ } -+ -+ // Footer -+ printf(latex_footer); -+} -+ -+} // end namespace cute -+ -+// -+// Extended Layouts -+// -+ -+#include -diff --git a/3rdparty/cutlass/include/cute/numeric/arithmetic_tuple.hpp b/3rdparty/cutlass/include/cute/numeric/arithmetic_tuple.hpp -new file mode 100644 -index 0000000..33471e4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/numeric/arithmetic_tuple.hpp -@@ -0,0 +1,388 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+#include -+#include -+ -+namespace cute -+{ -+ -+template -+struct ArithmeticTuple : tuple -+{ -+ template -+ CUTE_HOST_DEVICE constexpr -+ ArithmeticTuple(ArithmeticTuple const& u) -+ : tuple(static_cast const&>(u)) {} -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ ArithmeticTuple(tuple const& u) -+ : tuple(u) {} -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ ArithmeticTuple(U const&... u) -+ : tuple(u...) {} -+}; -+ -+template -+struct is_tuple> : true_type {}; -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_arithmetic_tuple(T const&... t) { -+ return ArithmeticTuple(t...); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+as_arithmetic_tuple(tuple const& t) { -+ return ArithmeticTuple(t); -+} -+ -+// -+// Numeric operators -+// -+ -+// Addition -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator+(ArithmeticTuple const& t, ArithmeticTuple const& u) { -+ constexpr int R = cute::max(int(sizeof...(T)), int(sizeof...(U))); -+ return transform_apply(append(t,Int<0>{}), append(u,Int<0>{}), plus{}, [](auto const&... a){ return make_arithmetic_tuple(a...); }); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator+(ArithmeticTuple const& t, tuple const& u) { -+ constexpr int R = cute::max(int(sizeof...(T)), int(sizeof...(U))); -+ return transform_apply(append(t,Int<0>{}), append(u,Int<0>{}), plus{}, [](auto const&... a){ return make_arithmetic_tuple(a...); }); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator+(tuple const& t, ArithmeticTuple const& u) { -+ constexpr int R = cute::max(int(sizeof...(T)), int(sizeof...(U))); -+ return transform_apply(append(t,Int<0>{}), append(u,Int<0>{}), plus{}, [](auto const&... a){ return make_arithmetic_tuple(a...); }); -+} -+ -+// -+// Special cases -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator+(constant, ArithmeticTuple const& u) { -+ return u; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator+(ArithmeticTuple const& t, constant) { -+ return t; -+} -+ -+// -+// ArithmeticTupleIterator -+// -+ -+template -+struct ArithmeticTupleIterator -+{ -+ ArithTuple coord_; -+ -+ CUTE_HOST_DEVICE constexpr -+ ArithmeticTupleIterator() : coord_() {} -+ CUTE_HOST_DEVICE constexpr -+ ArithmeticTupleIterator(ArithTuple const& coord) : coord_(coord) {} -+ -+ CUTE_HOST_DEVICE constexpr -+ ArithTuple const& operator*() const { return coord_; } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto operator+(Coord const& c) const { -+ return ArithmeticTupleIterator(coord_ + c); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto operator[](Coord const& c) const { return *(*this + c); } -+}; -+ -+template -+CUTE_HOST_DEVICE void print(ArithmeticTupleIterator const& iter) { -+ printf("ArithTuple"); print(iter.coord_); -+} -+ -+// -+// ArithmeticTuple "basis" elements -+// -+ -+// Abstract value: -+// A ScaledBasis is a (at least) rank-N0 ArithmeticTuple: -+// (_0,_0,...,T,_0,...) -+ -+template -+struct ScaledBasis : private tuple -+{ -+ CUTE_HOST_DEVICE constexpr -+ ScaledBasis(T const& t = {}) : tuple(t) {} -+ -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) value() { return get<0>(static_cast &>(*this)); } -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) value() const { return get<0>(static_cast const&>(*this)); } -+ -+ CUTE_HOST_DEVICE static constexpr -+ auto mode() { return Int{}; } -+}; -+ -+template -+struct is_scaled_basis : false_type {}; -+template -+struct is_scaled_basis> : true_type {}; -+ -+template -+struct is_integral> : true_type {}; -+ -+template -+CUTE_HOST_DEVICE constexpr auto -+basis_value(T const& e) { -+ return e; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr auto -+basis_value(ScaledBasis const& e) { -+ return basis_value(e.value()); -+} -+ -+namespace detail { -+ -+template -+struct Basis; -+ -+template <> -+struct Basis<> { -+ using type = Int<1>; -+}; -+ -+template -+struct Basis { -+ using type = ScaledBasis::type, N>; -+}; -+ -+} // end namespace detail -+ -+template -+using E = typename detail::Basis::type; -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+as_arithmetic_tuple(T const& t, seq, seq) { -+ return make_arithmetic_tuple((void(I),Int<0>{})..., t, (void(J),Int<0>{})...); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+as_arithmetic_tuple(ArithmeticTuple const& t, seq, seq) { -+ return make_arithmetic_tuple(get(t)..., (void(J),Int<0>{})...); -+} -+ -+} // end namespace detail -+ -+// Turn a ScaledBases into a rank-M ArithmeticTuple -+// with N prefix 0s: (_0,_0,...N...,_0,T,_0,...,_0,_0) -+template -+CUTE_HOST_DEVICE constexpr -+auto -+as_arithmetic_tuple(ScaledBasis const& t) { -+ static_assert(M > N, "Mismatched ranks"); -+ return detail::as_arithmetic_tuple(t.value(), make_seq{}, make_seq{}); -+} -+ -+// Turn an ArithmeticTuple into a rank-M ArithmeticTuple -+// with postfix 0s: (t0,t1,t2,...,_0,...,_0,_0) -+template -+CUTE_HOST_DEVICE constexpr -+auto -+as_arithmetic_tuple(ArithmeticTuple const& t) { -+ static_assert(M >= sizeof...(T), "Mismatched ranks"); -+ return detail::as_arithmetic_tuple(t, make_seq{}, make_seq{}); -+} -+ -+// Return... -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_basis_like(Shape const& shape) -+{ -+ if constexpr (is_integral::value) { -+ return Int<1>{}; -+ } else { -+ // Generate bases for each rank of shape -+ return transform(tuple_seq{}, [&](auto I) { -+ // Generate bases for each rank of shape_i and add an i on front -+ constexpr int i = decltype(I)::value; // NOTE: nvcc workaround -+ return transform_leaf(make_basis_like(get(shape)), [&](auto e) { return ScaledBasis{}; }); -+ }); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// Equality -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator==(ScaledBasis, Int) { -+ return false_type{}; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator==(Int, ScaledBasis) { -+ return false_type{}; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator==(ScaledBasis const& t, ScaledBasis const& u) { -+ return bool_constant{} && t.value() == u.value(); -+} -+ -+// Multiplication -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+operator*(A const& a, ScaledBasis const& e) { -+ return ScaledBasis{a*e.value()}; -+} -+ -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+operator*(ScaledBasis const& e, B const& b) { -+ return ScaledBasis{e.value()*b}; -+} -+ -+// Addition -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator+(ScaledBasis const& t, ArithmeticTuple const& u) { -+ constexpr int R = cute::max(N+1, int(sizeof...(U))); -+ return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator+(ArithmeticTuple const& t, ScaledBasis const& u) { -+ constexpr int R = cute::max(int(sizeof...(T)), M+1); -+ return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator+(ScaledBasis const& t, ScaledBasis const& u) { -+ constexpr int R = cute::max(N+1,M+1); -+ return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator+(constant, ScaledBasis const& u) { -+ return u; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator+(ScaledBasis const& t, constant) { -+ return t; -+} -+ -+// -+// Display utilities -+// -+ -+template -+CUTE_HOST_DEVICE void print(ScaledBasis const& e) { -+ printf("%d:", N); print(e.value()); -+} -+ -+template -+CUTE_HOST std::ostream& operator<<(std::ostream& os, ScaledBasis const& e) { -+ return os << N << ":" << e.value(); -+} -+ -+} // end namespace cute -+ -+ -+namespace std -+{ -+ -+template -+struct tuple_size> -+ : std::integral_constant -+{}; -+ -+template -+struct tuple_element> -+ : std::tuple_element> -+{}; -+ -+} // end namespace std -diff --git a/3rdparty/cutlass/include/cute/numeric/bfloat.hpp b/3rdparty/cutlass/include/cute/numeric/bfloat.hpp -new file mode 100644 -index 0000000..94f64ab ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/numeric/bfloat.hpp -@@ -0,0 +1,51 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+ -+namespace cute { -+ -+using cutlass::bfloat16_t; -+ -+// -+// Display utilities -+// -+ -+CUTE_HOST std::ostream& operator<<(std::ostream& os, bfloat16_t const& v) -+{ -+ return os << float(v); -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/numeric/complex.hpp b/3rdparty/cutlass/include/cute/numeric/complex.hpp -new file mode 100644 -index 0000000..3790ebd ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/numeric/complex.hpp -@@ -0,0 +1,163 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+//#if defined(__CUDA_ARCH__) -+//# include -+//#else -+//# include -+//#endif -+ -+// With CUDA 11.4, builds show spurious "-Wconversion" warnings -+// on line 656 of thrust/detail/type_traits.h. -+// These pragmas suppress the warnings. -+#pragma GCC diagnostic push -+#pragma GCC diagnostic ignored "-Wconversion" -+#include -+#pragma GCC diagnostic pop -+ -+#include -+ -+namespace cute -+{ -+ -+//#if defined(__CUDA_ARCH__) -+//template -+//using complex = cuda::std::complex; -+//#else -+//template -+//using complex = std::complex; -+//#endif -+ -+//template -+//using complex = thrust::complex; -+ -+using thrust::complex; -+ -+template -+CUTE_HOST_DEVICE -+T real(complex const& z) { -+ return z.real(); -+} -+ -+template -+CUTE_HOST_DEVICE -+T imag(complex const& z) { -+ return z.imag(); -+} -+ -+template -+CUTE_HOST_DEVICE -+complex conj(complex const& z) { -+ return complex(real(z), -imag(z)); -+} -+ -+// cute::conj forwards scalars -+template -+CUTE_HOST_DEVICE -+T conj(T z) { -+ return z; -+} -+ -+//CUTE_HOST_DEVICE constexpr -+//float conj(float z) { return z; } -+//CUTE_HOST_DEVICE constexpr -+//double conj(double z) { return z; } -+ -+/// Fused multiply-add for complex numbers -+template -+CUTE_HOST_DEVICE constexpr -+void -+fma(complex & d, -+ complex const& a, -+ complex const& b, -+ complex const& c) -+{ -+ d.real(c.real() + a.real() * b.real()); -+ d.imag(c.imag() + a.real() * b.imag()); -+ d.real(d.real() - a.imag() * b.imag()); -+ d.imag(d.imag() + a.imag() * b.real()); -+} -+ -+/// Fused multiply-add for triplets -+template -+CUTE_HOST_DEVICE constexpr -+void -+fma(complex const& a, -+ complex const& b, -+ complex & c) -+{ -+ return fma(c, a, b, c); -+} -+ -+/// Used to determine the real-valued underlying type of a numeric type T -+template -+struct RealType { -+ using Type = T; -+}; -+ -+/// Partial specialization for complex-valued type -+template -+struct RealType> { -+ using Type = T; -+}; -+ -+////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct is_complex { -+ static bool const value = false; -+}; -+ -+template -+struct is_complex> { -+ static bool const value = true; -+}; -+ -+////////////////////////////////////////////////////////////////////////////////////////////////// -+// Display utilities -+ -+template -+CUTE_HOST std::ostream& operator<<(std::ostream& os, complex const& z) -+{ -+ T _r = z.real(); -+ T _i = z.imag(); -+ -+ if (bool(_i)) { -+ return os << _r << "+i" << _i; -+ } else { -+ return os << _r; -+ } -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/numeric/float8.hpp b/3rdparty/cutlass/include/cute/numeric/float8.hpp -new file mode 100644 -index 0000000..3fa471d ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/numeric/float8.hpp -@@ -0,0 +1,43 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+ -+namespace cute { -+ -+using cutlass::float_e4m3_t; -+using cutlass::float_e5m2_t; -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/numeric/half.hpp b/3rdparty/cutlass/include/cute/numeric/half.hpp -new file mode 100644 -index 0000000..704ba28 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/numeric/half.hpp -@@ -0,0 +1,41 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+#include -+#include -+ -+namespace cute { -+ -+using cutlass::half_t; -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/numeric/int.hpp b/3rdparty/cutlass/include/cute/numeric/int.hpp -new file mode 100644 -index 0000000..a08297f ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/numeric/int.hpp -@@ -0,0 +1,129 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include -+#include -+ -+namespace cute -+{ -+ -+// -+// Signed integers -+// -+ -+using int8_t = std::int8_t; -+using int16_t = std::int16_t; -+using int32_t = std::int32_t; -+using int64_t = std::int64_t; -+ -+template struct int_bit; -+template <> struct int_bit< 2> { using type = cute::int2b_t; }; -+template <> struct int_bit< 4> { using type = cute::int4b_t; }; -+template <> struct int_bit< 8> { using type = int8_t; }; -+template <> struct int_bit< 16> { using type = int16_t; }; -+template <> struct int_bit< 32> { using type = int32_t; }; -+template <> struct int_bit< 64> { using type = int64_t; }; -+ -+template -+using int_bit_t = typename int_bit::type; -+ -+template -+using int_byte = int_bit<8*N>; -+ -+template -+using int_byte_t = typename int_byte::type; -+ -+// -+// Unsigned integers -+// -+ -+using uint8_t = std::uint8_t; -+using uint16_t = std::uint16_t; -+using uint32_t = std::uint32_t; -+using uint64_t = std::uint64_t; -+ -+template struct uint_bit; -+template <> struct uint_bit< 1> { using type = cute::uint1b_t; }; -+template <> struct uint_bit< 2> { using type = cute::uint2b_t; }; -+template <> struct uint_bit< 4> { using type = cute::uint4b_t; }; -+template <> struct uint_bit< 8> { using type = uint8_t; }; -+template <> struct uint_bit< 16> { using type = uint16_t; }; -+template <> struct uint_bit< 32> { using type = uint32_t; }; -+template <> struct uint_bit< 64> { using type = uint64_t; }; -+template <> struct uint_bit<128> { using type = cute::uint128_t; }; -+ -+template -+using uint_bit_t = typename uint_bit::type; -+ -+template -+using uint_byte = uint_bit<8*N>; -+ -+template -+using uint_byte_t = typename uint_byte::type; -+ -+// -+// sizeof_bytes -+// -+ -+template -+struct sizeof_bytes { -+ static constexpr std::size_t value = sizeof(T); -+}; -+template -+static constexpr int sizeof_bytes_v = sizeof_bytes::value; -+ -+// -+// sizeof_bits -+// -+ -+template -+struct sizeof_bits { -+ static constexpr std::size_t value = sizeof(T) * 8; -+}; -+template <> -+struct sizeof_bits { -+ static constexpr std::size_t value = 1; -+}; -+template -+struct sizeof_bits> { -+ static constexpr std::size_t value = Bits; -+}; -+template -+static constexpr int sizeof_bits_v = sizeof_bits::value; -+ -+} // namespace cute -diff --git a/3rdparty/cutlass/include/cute/numeric/integer_sequence.hpp b/3rdparty/cutlass/include/cute/numeric/integer_sequence.hpp -new file mode 100644 -index 0000000..73a83f7 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/numeric/integer_sequence.hpp -@@ -0,0 +1,139 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include // std::integer_sequence -+ -+#include -+ -+namespace cute -+{ -+ -+using std::integer_sequence; -+using std::make_integer_sequence; -+ -+namespace detail { -+ -+template -+struct make_integer_range_impl; -+ -+template -+struct make_integer_range_impl, Begin> { -+ using type = integer_sequence; -+}; -+ -+} // end namespace detail -+ -+template -+using make_integer_range = typename detail::make_integer_range_impl< -+ T, -+ make_integer_sequence 0) ? (End-Begin) : 0>, -+ Begin>::type; -+ -+// -+// Common aliases -+// -+ -+// int_sequence -+ -+template -+using int_sequence = integer_sequence; -+ -+template -+using make_int_sequence = make_integer_sequence; -+ -+template -+using make_int_range = make_integer_range; -+ -+// index_sequence -+ -+template -+using index_sequence = integer_sequence; -+ -+template -+using make_index_sequence = make_integer_sequence; -+ -+template -+using make_index_range = make_integer_range; -+ -+// -+// Shortcuts -+// -+ -+template -+using seq = int_sequence; -+ -+template -+using make_seq = make_int_sequence; -+ -+template -+using make_range = make_int_range; -+ -+template -+using tuple_seq = make_seq>::value>; -+ -+} // end namespace cute -+ -+ -+// -+// Specialize tuple-related functionality for cute::integer_sequence -+// -+ -+#include -+#include -+ -+namespace cute -+{ -+ -+template -+CUTE_HOST_DEVICE constexpr -+std::tuple_element_t> -+get(integer_sequence) { -+ static_assert(I < sizeof...(Ints), "Index out of range"); -+ return {}; -+} -+ -+} // end namespace cute -+ -+namespace std -+{ -+ -+template -+struct tuple_size> -+ : std::integral_constant -+{}; -+ -+template -+struct tuple_element> -+ : std::tuple_element...>> -+{}; -+ -+} // end namespace std -diff --git a/3rdparty/cutlass/include/cute/numeric/integer_subbyte.hpp b/3rdparty/cutlass/include/cute/numeric/integer_subbyte.hpp -new file mode 100644 -index 0000000..3d24a95 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/numeric/integer_subbyte.hpp -@@ -0,0 +1,233 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include -+#include -+ -+namespace cute { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct integer_subbyte -+{ -+ /// Storage type -+ using Storage = uint8_t; -+ -+ /// Number of bits -+ static_assert(Bits <= 8*sizeof(Storage), "Require a subbyte of bits in integer_subbyte"); -+ -+ /// External type -+ using xint_t = typename std::conditional::type; -+ -+ /// Bitmask for truncation from larger integers -+ static constexpr Storage bits_mask_ = Storage((1 << Bits) - 1); -+ /// Bitmask for the sign bit -+ static constexpr Storage sign_mask_ = Storage((Signed ? 1 : 0) << (Bits - 1)); -+ -+ // -+ // Data members -+ // -+ -+ Storage storage; -+ -+ // -+ // Methods -+ // -+ -+ /// No operation -+ CUTE_HOST_DEVICE constexpr -+ integer_subbyte() {} -+ -+ /// Conversion from integer type -+ CUTE_HOST_DEVICE constexpr -+ integer_subbyte(int value) // NOTE: Sign extension? -+ : storage(reinterpret_cast(value) & bits_mask_) {} -+ -+ CUTE_HOST_DEVICE constexpr -+ integer_subbyte(unsigned value) -+ : storage(reinterpret_cast(value) & bits_mask_) {} -+ -+ /// Convert to int or unsigned -+ CUTE_HOST_DEVICE constexpr -+ operator xint_t() const { -+ if (sign_mask_ & storage) { // Sign extend -+ return xint_t(storage) | ~xint_t(bits_mask_); -+ } else { -+ return xint_t(storage); -+ } -+ } -+ -+ /// Equality -+ CUTE_HOST_DEVICE constexpr -+ bool operator==(integer_subbyte const& rhs) const { -+ return storage == rhs.storage; -+ } -+ -+ /// Inequality -+ CUTE_HOST_DEVICE constexpr -+ bool operator!=(integer_subbyte const& rhs) const { -+ return storage != rhs.storage; -+ } -+ -+ /// Less than or equal -+ CUTE_HOST_DEVICE constexpr -+ bool operator<=(integer_subbyte const& rhs) const { -+ if (sign_mask_ & storage) { -+ return !(rhs.storage < storage); -+ } else { -+ return storage < rhs.storage; -+ } -+ } -+ -+ /// Less than -+ CUTE_HOST_DEVICE constexpr -+ bool operator<(integer_subbyte const& rhs) const { -+ if (sign_mask_ & storage) { -+ return !(rhs.storage <= storage); -+ } else { -+ return storage < rhs.storage; -+ } -+ } -+ -+ /// Greater than or equal -+ CUTE_HOST_DEVICE constexpr -+ bool operator>=(integer_subbyte const& rhs) const { -+ return !(*this < rhs); -+ } -+ -+ /// Greater than -+ CUTE_HOST_DEVICE constexpr -+ bool operator>(integer_subbyte const& rhs) const { -+ return !(*this <= rhs); -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// 1-bit unsigned integer type -+using uint1b_t = integer_subbyte<1, false>; -+ -+/// 2-bit integer type -+using int2b_t = integer_subbyte<2, true>; -+ -+/// 2-bit unsigned integer type -+using uint2b_t = integer_subbyte<2, false>; -+ -+/// 4-bit integer type -+using int4b_t = integer_subbyte<4, true>; -+ -+/// 4-bit unsigned integer type -+using uint4b_t = integer_subbyte<4, false>; -+ -+/// 1-bit binary type -+using bin1_t = bool; -+ -+} // namespace cute -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if !defined(__CUDACC_RTC__) -+ -+#include -+ -+namespace std { -+ -+template <> -+struct numeric_limits { -+ CUTE_HOST_DEVICE static constexpr -+ cute::uint1b_t const lowest() noexcept { return 0; } -+ CUTE_HOST_DEVICE static constexpr -+ cute::uint1b_t const min() noexcept { return 0; } -+ CUTE_HOST_DEVICE static constexpr -+ cute::uint1b_t const max() noexcept { return 1; } -+ static constexpr bool is_integer = true; -+ static constexpr bool is_signed = false; -+}; -+ -+template <> -+struct numeric_limits { -+ CUTE_HOST_DEVICE static constexpr -+ cute::int2b_t lowest() noexcept { return -2; } -+ CUTE_HOST_DEVICE static constexpr -+ cute::int2b_t min() noexcept { return -2; } -+ CUTE_HOST_DEVICE static constexpr -+ cute::int2b_t max() noexcept { return 1; } -+ static constexpr bool is_integer = true; -+ static constexpr bool is_signed = true; -+}; -+ -+template <> -+struct numeric_limits { -+ CUTE_HOST_DEVICE static constexpr -+ cute::uint2b_t const lowest() noexcept { return 0; } -+ CUTE_HOST_DEVICE static constexpr -+ cute::uint2b_t const min() noexcept { return 0; } -+ CUTE_HOST_DEVICE static constexpr -+ cute::uint2b_t const max() noexcept { return 3; } -+ static constexpr bool is_integer = true; -+ static constexpr bool is_signed = false; -+}; -+ -+template <> -+struct numeric_limits { -+ CUTE_HOST_DEVICE static constexpr -+ cute::int4b_t lowest() noexcept { return -8; } -+ CUTE_HOST_DEVICE static constexpr -+ cute::int4b_t min() noexcept { return -8; } -+ CUTE_HOST_DEVICE static constexpr -+ cute::int4b_t max() noexcept { return 7; } -+ static constexpr bool is_integer = true; -+ static constexpr bool is_signed = true; -+}; -+ -+template <> -+struct numeric_limits { -+ CUTE_HOST_DEVICE static constexpr -+ cute::uint4b_t const lowest() noexcept { return 0; } -+ CUTE_HOST_DEVICE static constexpr -+ cute::uint4b_t const min() noexcept { return 0; } -+ CUTE_HOST_DEVICE static constexpr -+ cute::uint4b_t const max() noexcept { return 15; } -+ static constexpr bool is_integer = true; -+ static constexpr bool is_signed = false; -+}; -+ -+} // namespace std -+ -+#endif -diff --git a/3rdparty/cutlass/include/cute/numeric/integral_constant.hpp b/3rdparty/cutlass/include/cute/numeric/integral_constant.hpp -new file mode 100644 -index 0000000..106763d ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/numeric/integral_constant.hpp -@@ -0,0 +1,414 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+ -+namespace cute -+{ -+ -+template -+struct constant : std::integral_constant { -+ static constexpr T value = v; -+ using value_type = T; -+ using type = constant; -+ CUTE_HOST_DEVICE constexpr operator value_type() const noexcept { return value; } -+ CUTE_HOST_DEVICE constexpr value_type operator()() const noexcept { return value; } -+}; -+ -+template -+using integral_constant = constant; -+ -+template -+using bool_constant = constant; -+ -+using true_type = bool_constant; -+using false_type = bool_constant; -+ -+// -+// Traits -+// -+ -+// Use std::is_integral to match built-in integral types (int, int64_t, unsigned, etc) -+// Use cute::is_integral to match both built-in integral types AND constant -+ -+template -+struct is_integral : bool_constant::value> {}; -+template -+struct is_integral> : true_type {}; -+ -+// is_static detects if an (abstract) value is defined completely by it's type (no members) -+ -+template -+struct is_static : bool_constant::value> {}; -+ -+// is_constant detects if a type is a constant and if v is equal to a value -+ -+template -+struct is_constant : false_type {}; -+template -+struct is_constant > : bool_constant {}; -+template -+struct is_constant const > : bool_constant {}; -+template -+struct is_constant const&> : bool_constant {}; -+template -+struct is_constant &> : bool_constant {}; -+template -+struct is_constant &&> : bool_constant {}; -+ -+// -+// Specializations -+// -+ -+template -+using Int = constant; -+ -+using _m32 = Int<-32>; -+using _m24 = Int<-24>; -+using _m16 = Int<-16>; -+using _m12 = Int<-12>; -+using _m10 = Int<-10>; -+using _m9 = Int<-9>; -+using _m8 = Int<-8>; -+using _m7 = Int<-7>; -+using _m6 = Int<-6>; -+using _m5 = Int<-5>; -+using _m4 = Int<-4>; -+using _m3 = Int<-3>; -+using _m2 = Int<-2>; -+using _m1 = Int<-1>; -+using _0 = Int<0>; -+using _1 = Int<1>; -+using _2 = Int<2>; -+using _3 = Int<3>; -+using _4 = Int<4>; -+using _5 = Int<5>; -+using _6 = Int<6>; -+using _7 = Int<7>; -+using _8 = Int<8>; -+using _9 = Int<9>; -+using _10 = Int<10>; -+using _12 = Int<12>; -+using _16 = Int<16>; -+using _24 = Int<24>; -+using _32 = Int<32>; -+using _64 = Int<64>; -+using _96 = Int<96>; -+using _128 = Int<128>; -+using _192 = Int<192>; -+using _256 = Int<256>; -+using _512 = Int<512>; -+using _1024 = Int<1024>; -+using _2048 = Int<2048>; -+using _4096 = Int<4096>; -+using _8192 = Int<8192>; -+ -+/***************/ -+/** Operators **/ -+/***************/ -+ -+#define CUTE_LEFT_UNARY_OP(OP) \ -+ template \ -+ CUTE_HOST_DEVICE constexpr \ -+ constant \ -+ operator OP (constant) { \ -+ return {}; \ -+ } -+#define CUTE_RIGHT_UNARY_OP(OP) \ -+ template \ -+ CUTE_HOST_DEVICE constexpr \ -+ constant \ -+ operator OP (constant) { \ -+ return {}; \ -+ } -+ -+#define CUTE_BINARY_OP(OP) \ -+ template \ -+ CUTE_HOST_DEVICE constexpr \ -+ constant \ -+ operator OP (constant, constant) { \ -+ return {}; \ -+ } -+ -+CUTE_LEFT_UNARY_OP(+); -+CUTE_LEFT_UNARY_OP(-); -+CUTE_LEFT_UNARY_OP(~); -+CUTE_LEFT_UNARY_OP(!); -+CUTE_LEFT_UNARY_OP(*); -+ -+CUTE_BINARY_OP( +); -+CUTE_BINARY_OP( -); -+CUTE_BINARY_OP( *); -+CUTE_BINARY_OP( /); -+CUTE_BINARY_OP( %); -+CUTE_BINARY_OP( &); -+CUTE_BINARY_OP( |); -+CUTE_BINARY_OP( ^); -+CUTE_BINARY_OP(<<); -+CUTE_BINARY_OP(>>); -+ -+CUTE_BINARY_OP(&&); -+CUTE_BINARY_OP(||); -+ -+CUTE_BINARY_OP(==); -+CUTE_BINARY_OP(!=); -+CUTE_BINARY_OP( >); -+CUTE_BINARY_OP( <); -+CUTE_BINARY_OP(>=); -+CUTE_BINARY_OP(<=); -+ -+#undef CUTE_BINARY_OP -+#undef CUTE_LEFT_UNARY_OP -+#undef CUTE_RIGHT_UNARY_OP -+ -+// -+// Mixed static-dynamic special cases -+// -+ -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+constant -+operator*(constant, U) { -+ return {}; -+} -+ -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+constant -+operator*(U, constant) { -+ return {}; -+} -+ -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+constant -+operator/(constant, U) { -+ return {}; -+} -+ -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+constant -+operator%(U, constant) { -+ return {}; -+} -+ -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+constant -+operator%(U, constant) { -+ return {}; -+} -+ -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+constant -+operator%(constant, U) { -+ return {}; -+} -+ -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+constant -+operator&(constant, U) { -+ return {}; -+} -+ -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+constant -+operator&(U, constant) { -+ return {}; -+} -+ -+template ::value && !bool(t))> -+CUTE_HOST_DEVICE constexpr -+constant -+operator&&(constant, U) { -+ return {}; -+} -+ -+template ::value && !bool(t))> -+CUTE_HOST_DEVICE constexpr -+constant -+operator&&(U, constant) { -+ return {}; -+} -+ -+template ::value && bool(t))> -+CUTE_HOST_DEVICE constexpr -+constant -+operator||(constant, U) { -+ return {}; -+} -+ -+template ::value && bool(t))> -+CUTE_HOST_DEVICE constexpr -+constant -+operator||(U, constant) { -+ return {}; -+} -+ -+// -+// Named functions from math.hpp -+// -+ -+#define CUTE_NAMED_UNARY_FN(OP) \ -+ template \ -+ CUTE_HOST_DEVICE constexpr \ -+ constant \ -+ OP (constant) { \ -+ return {}; \ -+ } -+ -+#define CUTE_NAMED_BINARY_FN(OP) \ -+ template \ -+ CUTE_HOST_DEVICE constexpr \ -+ constant \ -+ OP (constant, constant) { \ -+ return {}; \ -+ } \ -+ \ -+ template ::value)> \ -+ CUTE_HOST_DEVICE constexpr \ -+ auto \ -+ OP (constant, U u) { \ -+ return OP(t,u); \ -+ } \ -+ \ -+ template ::value)> \ -+ CUTE_HOST_DEVICE constexpr \ -+ auto \ -+ OP (T t, constant) { \ -+ return OP(t,u); \ -+ } -+ -+CUTE_NAMED_UNARY_FN(abs); -+CUTE_NAMED_UNARY_FN(signum); -+CUTE_NAMED_UNARY_FN(has_single_bit); -+ -+CUTE_NAMED_BINARY_FN(max); -+CUTE_NAMED_BINARY_FN(min); -+CUTE_NAMED_BINARY_FN(shiftl); -+CUTE_NAMED_BINARY_FN(shiftr); -+CUTE_NAMED_BINARY_FN(gcd); -+CUTE_NAMED_BINARY_FN(lcm); -+ -+#undef CUTE_NAMED_UNARY_FN -+#undef CUTE_NAMED_BINARY_FN -+ -+// -+// Other functions -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+constant -+safe_div(constant, constant) { -+ static_assert(t % u == 0, "Static safe_div requires t % u == 0"); -+ return {}; -+} -+ -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+safe_div(constant, U u) { -+ return t / u; -+} -+ -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+safe_div(T t, constant) { -+ return t / u; -+} -+ -+// cute::true_type prefers standard conversion to std::true_type -+// over user-defined conversion to bool -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+conditional_return(std::true_type, TrueType&& t, FalseType&&) { -+ return static_cast(t); -+} -+ -+// cute::false_type prefers standard conversion to std::false_type -+// over user-defined conversion to bool -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+conditional_return(std::false_type, TrueType&&, FalseType&& f) { -+ return static_cast(f); -+} -+ -+// TrueType and FalseType must have a common type -+template -+CUTE_HOST_DEVICE constexpr -+auto -+conditional_return(bool b, TrueType const& t, FalseType const& f) { -+ return b ? t : f; -+} -+ -+// -+// Display utilities -+// -+ -+template -+CUTE_HOST_DEVICE void print(integral_constant const&) { -+ printf("_%d", N); -+} -+ -+template -+CUTE_HOST std::ostream& operator<<(std::ostream& os, integral_constant const&) { -+ return os << "_" << N; -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/numeric/math.hpp b/3rdparty/cutlass/include/cute/numeric/math.hpp -new file mode 100644 -index 0000000..03e8379 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/numeric/math.hpp -@@ -0,0 +1,319 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include -+ -+namespace cute -+{ -+ -+// -+// Common Operations -+// -+ -+template ::value && -+ std::is_arithmetic::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+max(T const& t, U const& u) { -+ return t < u ? u : t; -+} -+ -+template ::value && -+ std::is_arithmetic::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+min(T const& t, U const& u) { -+ return t < u ? t : u; -+} -+ -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+abs(T const& t) { -+ if constexpr (std::is_signed::value) { -+ return t < T(0) ? -t : t; -+ } else { -+ return t; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// C++17 operations -+// -+ -+// Greatest common divisor of two integers -+template ::value && -+ std::is_integral::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+gcd(T t, U u) { -+ while (true) { -+ if (t == 0) { return u; } -+ u %= t; -+ if (u == 0) { return t; } -+ t %= u; -+ } -+} -+ -+// Least common multiple of two integers -+template ::value && -+ std::is_integral::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+lcm(T const& t, U const& u) { -+ return (t / gcd(t,u)) * u; -+} -+ -+// -+// C++20 operations -+// -+ -+// Checks if a number is an integral power of two -+template -+CUTE_HOST_DEVICE constexpr -+bool -+has_single_bit(T x) { -+ return x != 0 && (x & (x - 1)) == 0; -+} -+ -+// Smallest number of bits needed to represent the given value -+// bit_width( 0b0000 ) = 0 -+// bit_width( 0b0001 ) = 1 -+// bit_width( 0b0010 ) = 2 -+// bit_width( 0b0011 ) = 2 -+// bit_width( 0b0100 ) = 3 -+// bit_width( 0b0101 ) = 3 -+// bit_width( 0b0110 ) = 3 -+// bit_width( 0b0111 ) = 3 -+template -+CUTE_HOST_DEVICE constexpr -+T -+bit_width(T x) { -+ static_assert(std::is_unsigned::value, "Only to be used for unsigned types."); -+ constexpr int N = (std::numeric_limits::digits == 64 ? 6 : -+ (std::numeric_limits::digits == 32 ? 5 : -+ (std::numeric_limits::digits == 16 ? 4 : -+ (std::numeric_limits::digits == 8 ? 3 : (assert(false),0))))); -+ T r = 0; -+ for (int i = N - 1; i >= 0; --i) { -+ T shift = (x > ((T(1) << (T(1) << i))-1)) << i; -+ x >>= shift; -+ r |= shift; -+ } -+ return r + (x != 0); -+} -+ -+// Smallest integral power of two not less than the given value -+// bit_ceil( 0b00000000 ) = 0b00000001 -+// bit_ceil( 0b00000001 ) = 0b00000001 -+// bit_ceil( 0b00000010 ) = 0b00000010 -+// bit_ceil( 0b00000011 ) = 0b00000100 -+// bit_ceil( 0b00000100 ) = 0b00000100 -+// bit_ceil( 0b00000101 ) = 0b00001000 -+// bit_ceil( 0b00000110 ) = 0b00001000 -+// bit_ceil( 0b00000111 ) = 0b00001000 -+// bit_ceil( 0b00001000 ) = 0b00001000 -+// bit_ceil( 0b00001001 ) = 0b00010000 -+template -+CUTE_HOST_DEVICE constexpr -+T -+bit_ceil(T x) { -+ return x == 0 ? T(1) : (T(1) << bit_width(x - 1)); -+} -+ -+// Largest integral power of two not greater than the given value -+// bit_floor( 0b00000000 ) = 0b00000000 -+// bit_floor( 0b00000001 ) = 0b00000001 -+// bit_floor( 0b00000010 ) = 0b00000010 -+// bit_floor( 0b00000011 ) = 0b00000010 -+// bit_floor( 0b00000100 ) = 0b00000100 -+// bit_floor( 0b00000101 ) = 0b00000100 -+// bit_floor( 0b00000110 ) = 0b00000100 -+// bit_floor( 0b00000111 ) = 0b00000100 -+// bit_floor( 0b00001000 ) = 0b00001000 -+// bit_floor( 0b00001001 ) = 0b00001000 -+template -+CUTE_HOST_DEVICE constexpr -+T -+bit_floor(T x) { -+ return x == 0 ? 0 : (T(1) << (bit_width(x) - 1)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr T rotl(T x, int s); -+template -+CUTE_HOST_DEVICE constexpr T rotr(T x, int s); -+ -+// Computes the result of circular bitwise left-rotation -+template -+CUTE_HOST_DEVICE constexpr -+T -+rotl(T x, int s) { -+ constexpr int N = std::numeric_limits::digits; -+ return s == 0 ? x : s > 0 ? (x << s) | (x >> (N - s)) : rotr(x, -s); -+} -+ -+// Computes the result of circular bitwise right-rotation -+template -+CUTE_HOST_DEVICE constexpr -+T -+rotr(T x, int s) { -+ constexpr int N = std::numeric_limits::digits; -+ return s == 0 ? x : s > 0 ? (x >> s) | (x << (N - s)) : rotl(x, -s); -+} -+ -+// Counts the number of consecutive 0 bits, starting from the most significant bit -+// countl_zero( 0b00000000 ) = 8 -+// countl_zero( 0b11111111 ) = 0 -+// countl_zero( 0b00011100 ) = 3 -+template -+CUTE_HOST_DEVICE constexpr -+T -+countl_zero(T x) { -+ return std::numeric_limits::digits - bit_width(x); -+} -+ -+// Counts the number of consecutive 1 bits, starting from the most significant bit -+// countl_one( 0b00000000 ) = 0 -+// countl_one( 0b11111111 ) = 8 -+// countl_one( 0b11100011 ) = 3 -+template -+CUTE_HOST_DEVICE constexpr -+T -+countl_one(T x) { -+ return countl_zero(~x); -+} -+ -+// Counts the number of consecutive 0 bits, starting from the least significant bit -+// countr_zero( 0b00000000 ) = 8 -+// countr_zero( 0b11111111 ) = 0 -+// countr_zero( 0b00011100 ) = 2 -+template -+CUTE_HOST_DEVICE constexpr -+T -+countr_zero(T x) { -+ return x == 0 ? std::numeric_limits::digits : bit_width(T(x & T(-x))) - 1; // bit_width of the LSB -+} -+ -+// Counts the number of consecutive 1 bits, starting from the least significant bit -+// countr_one( 0b00000000 ) = 0 -+// countr_one( 0b11111111 ) = 8 -+// countr_one( 0b11100011 ) = 2 -+template -+CUTE_HOST_DEVICE constexpr -+T -+countr_one(T x) { -+ return countr_zero(~x); -+} -+ -+// Counts the number of 1 bits in an unsigned integer -+// popcount( 0b00000000 ) = 0 -+// popcount( 0b11111111 ) = 8 -+// popcount( 0b00011101 ) = 4 -+template -+CUTE_HOST_DEVICE constexpr -+int -+popcount(T x) { -+ int c = 0; -+ while (x) { -+ ++c; -+ x &= x - 1; // clear the least significant bit set -+ } -+ return c; -+} -+ -+// -+// Custom operations -+// -+ -+// Computes the result of bitwise left-shift -+template -+CUTE_HOST_DEVICE constexpr -+T -+shiftl(T x, int s) { -+ return s >= 0 ? (x << s) : (x >> -s); -+} -+ -+// Computes the result of bitwise right-shift -+template -+CUTE_HOST_DEVICE constexpr -+T -+shiftr(T x, int s) { -+ return s >= 0 ? (x >> s) : (x << -s); -+} -+ -+// Returns 1 if x > 0, -1 if x < 0, and 0 if x is zero. -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+int -+signum(T const& x) { -+ return T(0) < x; -+} -+ -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+int -+signum(T const& x) { -+ return (T(0) < x) - (x < T(0)); -+} -+ -+// Safe divide -+// @pre t % u == 0 -+// @result t / u -+template ::value && -+ std::is_integral::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+safe_div(T const& t, U const& u) { -+ //assert(t % u == 0); -+ return t / u; -+} -+ -+} // namespace cute -diff --git a/3rdparty/cutlass/include/cute/numeric/real.hpp b/3rdparty/cutlass/include/cute/numeric/real.hpp -new file mode 100644 -index 0000000..d85e304 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/numeric/real.hpp -@@ -0,0 +1,56 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+namespace cute -+{ -+ -+/// Generic fused multiply-add -+template -+CUTE_HOST_DEVICE constexpr -+void -+fma(D& d, A const& a, B const& b, C const& c) -+{ -+ d = a * b + c; -+} -+ -+/// Fused multiply-add for triplets -+template -+CUTE_HOST_DEVICE constexpr -+void -+fma(A const& a, B const& b, C& c) -+{ -+ return fma(c, a, b, c); -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/numeric/tfloat.hpp b/3rdparty/cutlass/include/cute/numeric/tfloat.hpp -new file mode 100644 -index 0000000..bb68b70 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/numeric/tfloat.hpp -@@ -0,0 +1,51 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+ -+namespace cute { -+ -+using cutlass::tfloat32_t; -+ -+// -+// Display utilities -+// -+ -+CUTE_HOST std::ostream& operator<<(std::ostream& os, tfloat32_t const& v) -+{ -+ return os << float(v); -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/numeric/uint128.hpp b/3rdparty/cutlass/include/cute/numeric/uint128.hpp -new file mode 100644 -index 0000000..fb02441 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/numeric/uint128.hpp -@@ -0,0 +1,259 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#include -+#include -+#include -+#include -+#endif -+ -+#include -+ -+/// Optionally enable GCC's built-in type -+#if defined(__x86_64) && !defined(__CUDA_ARCH__) -+# if defined(__GNUC__) && 0 -+# define CUTE_UINT128_NATIVE -+# elif defined(_MSC_VER) -+# define CUTE_INT128_ARITHMETIC -+# include -+# endif -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cute { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///! Unsigned 128b integer type -+struct alignas(16) uint128_t -+{ -+ /// Size of one part of the uint's storage in bits -+ static constexpr int storage_bits_ = 64; -+ -+ struct hilo -+ { -+ uint64_t lo; -+ uint64_t hi; -+ }; -+ -+ // Use a union to store either low and high parts or, if present, a built-in 128b integer type. -+ union -+ { -+ struct hilo hilo_; -+ -+#if defined(CUTE_UINT128_NATIVE) -+ unsigned __int128 native; -+#endif // defined(CUTE_UINT128_NATIVE) -+ }; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTE_HOST_DEVICE constexpr -+ uint128_t() : hilo_{0, 0} {} -+ -+ /// Constructor from uint64 -+ CUTE_HOST_DEVICE constexpr -+ uint128_t(uint64_t lo_) : hilo_{lo_, 0} {} -+ -+ /// Constructor from two 64b unsigned integers -+ CUTE_HOST_DEVICE constexpr -+ uint128_t(uint64_t lo_, uint64_t hi_) : hilo_{lo_, hi_} {} -+ -+ /// Optional constructor from native value -+#if defined(CUTE_UINT128_NATIVE) -+ uint128_t(unsigned __int128 value) : native(value) { } -+#endif -+ -+ /// Lossily cast to uint64 -+ CUTE_HOST_DEVICE constexpr -+ explicit operator uint64_t() const -+ { -+ return hilo_.lo; -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ static void exception() -+ { -+ //static_assert(sizeof(Dummy) == 0, "Not implemented exception!"); -+ //abort(); -+ //printf("uint128 not implemented!\n"); -+ } -+ -+ /// Add -+ CUTE_HOST_DEVICE constexpr -+ uint128_t operator+(uint128_t const& rhs) const -+ { -+ uint128_t y; -+#if defined(CUTE_UINT128_NATIVE) -+ y.native = native + rhs.native; -+#else -+ y.hilo_.lo = hilo_.lo + rhs.hilo_.lo; -+ y.hilo_.hi = hilo_.hi + rhs.hilo_.hi + (!y.hilo_.lo && (rhs.hilo_.lo)); -+#endif -+ return y; -+ } -+ -+ /// Subtract -+ CUTE_HOST_DEVICE constexpr -+ uint128_t operator-(uint128_t const& rhs) const -+ { -+ uint128_t y; -+#if defined(CUTE_UINT128_NATIVE) -+ y.native = native - rhs.native; -+#else -+ y.hilo_.lo = hilo_.lo - rhs.hilo_.lo; -+ y.hilo_.hi = hilo_.hi - rhs.hilo_.hi - (rhs.hilo_.lo && y.hilo_.lo > hilo_.lo); -+#endif -+ return y; -+ } -+ -+ /// Multiply by unsigned 64b integer yielding 128b integer -+ CUTE_HOST_DEVICE constexpr -+ uint128_t operator*(uint64_t const& rhs) const -+ { -+ uint128_t y; -+#if defined(CUTE_UINT128_NATIVE) -+ y.native = native * rhs; -+#elif defined(CUTE_INT128_ARITHMETIC) -+ // Multiply by the low part -+ y.hilo_.lo = _umul128(hilo_.lo, rhs, &y.hilo_.hi); -+ -+ // Add the high part and ignore the overflow -+ uint64_t overflow; -+ y.hilo_.hi += _umul128(hilo_.hi, rhs, &overflow); -+#else -+ exception(); -+#endif -+ return y; -+ } -+ -+ /// Divide 128b operation by 64b operation yielding a 64b quotient -+ CUTE_HOST_DEVICE constexpr -+ uint64_t operator/(uint64_t const& divisor) const -+ { -+ uint64_t quotient = 0; -+#if defined(CUTE_UINT128_NATIVE) -+ quotient = uint64_t(native / divisor); -+#elif defined(CUTE_INT128_ARITHMETIC) -+ // implemented using MSVC's arithmetic intrinsics -+ uint64_t remainder = 0; -+ quotient = _udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); -+#else -+ exception(); -+#endif -+ return quotient; -+ } -+ -+ /// Divide 128b operation by 64b operation yielding a 64b quotient -+ CUTE_HOST_DEVICE constexpr -+ uint64_t operator%(uint64_t const& divisor) const -+ { -+ uint64_t remainder = 0; -+#if defined(CUTE_UINT128_NATIVE) -+ remainder = uint64_t(native % divisor); -+#elif defined(CUTE_INT128_ARITHMETIC) -+ // implemented using MSVC's arithmetic intrinsics -+ (void)_udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); -+#else -+ exception(); -+#endif -+ return remainder; -+ } -+ -+ /// Computes the quotient and remainder in a single method. -+ CUTE_HOST_DEVICE constexpr -+ uint64_t divmod(uint64_t &remainder, uint64_t divisor) const -+ { -+ uint64_t quotient = 0; -+#if defined(CUTE_UINT128_NATIVE) -+ quotient = uint64_t(native / divisor); -+ remainder = uint64_t(native % divisor); -+#elif defined(CUTE_INT128_ARITHMETIC) -+ // implemented using MSVC's arithmetic intrinsics -+ quotient = _udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); -+#else -+ exception(); -+#endif -+ return quotient; -+ } -+ -+ /// Left-shifts a 128b unsigned integer -+ CUTE_HOST_DEVICE constexpr -+ uint128_t operator<<(int sh) const -+ { -+ if (sh == 0) { -+ return *this; -+ } -+ else if (sh >= storage_bits_) { -+ return uint128_t(0, hilo_.lo << (sh - storage_bits_)); -+ } -+ else { -+ return uint128_t( -+ (hilo_.lo << sh), -+ (hilo_.hi << sh) | uint64_t(hilo_.lo >> (storage_bits_ - sh)) -+ ); -+ } -+ } -+ -+ /// Right-shifts a 128b unsigned integer -+ CUTE_HOST_DEVICE constexpr -+ uint128_t operator>>(int sh) const -+ { -+ if (sh == 0) { -+ return *this; -+ } -+ else if (sh >= storage_bits_) { -+ return uint128_t((hilo_.hi >> (sh - storage_bits_)), 0); -+ } -+ else { -+ return uint128_t( -+ (hilo_.lo >> sh) | (hilo_.hi << (storage_bits_ - sh)), -+ (hilo_.hi >> sh) -+ ); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cute -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cute/pointer.hpp b/3rdparty/cutlass/include/cute/pointer.hpp -new file mode 100644 -index 0000000..40ce5d1 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/pointer.hpp -@@ -0,0 +1,322 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+#include -+ -+namespace cute -+{ -+ -+// -+// has_dereference to determine if a type is a pointer concept -+// -+ -+template -+struct has_dereference : std::false_type { -+}; -+ -+template -+struct has_dereference())>> : std::true_type { -+}; -+ -+// -+// Pointer categories -+// -+ -+template -+struct is_gmem : false_type {}; -+ -+template -+struct is_smem : false_type {}; -+ -+// Anything that is not gmem or smem is rmem -+template -+struct is_rmem : bool_constant< not (is_gmem::value || is_smem::value)> {}; -+ -+// -+// A very simplified wrapper for pointers -- use for constructing tagged pointers -+// -+template -+struct device_ptr -+{ -+ using value_type = T; -+ -+ CUTE_HOST_DEVICE constexpr -+ device_ptr(T* ptr) : ptr_(ptr) {} -+ -+ CUTE_HOST_DEVICE constexpr -+ T* get() const { return ptr_; } -+ -+ CUTE_HOST_DEVICE constexpr -+ T& operator*() const { return *ptr_; } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ T& operator[](Index const& i) const { return ptr_[i]; } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ DerivedType operator+(Index const& i) const { return {ptr_ + i}; } -+ -+ CUTE_HOST_DEVICE constexpr friend -+ std::ptrdiff_t operator-(device_ptr const& a, -+ device_ptr const& b) { -+ return a.ptr_ - b.ptr_; -+ } -+ -+ T* ptr_; -+}; -+ -+// -+// gmem_ptr -+// -+ -+template -+struct gmem_ptr : device_ptr> { -+ using device_ptr>::device_ptr; -+}; -+ -+template -+CUTE_HOST_DEVICE constexpr -+gmem_ptr -+make_gmem_ptr(T* ptr) { -+ return {ptr}; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+gmem_ptr -+make_gmem_ptr(void* ptr) { -+ return {reinterpret_cast(ptr)}; -+} -+ -+template -+struct is_gmem> : true_type {}; -+ -+// -+// smem_ptr -+// -+ -+template -+struct smem_ptr : device_ptr> { -+ using device_ptr>::device_ptr; -+}; -+ -+template -+CUTE_HOST_DEVICE constexpr -+smem_ptr -+make_smem_ptr(T* ptr) { -+ return {ptr}; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+smem_ptr -+make_smem_ptr(void* ptr) { -+ return {reinterpret_cast(ptr)}; -+} -+ -+template -+struct is_smem> : true_type {}; -+ -+// -+// rmem_ptr -+// -+ -+template -+struct rmem_ptr : device_ptr> { -+ using device_ptr>::device_ptr; -+}; -+ -+template -+CUTE_HOST_DEVICE constexpr -+rmem_ptr -+make_rmem_ptr(T* ptr) { -+ return {ptr}; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+rmem_ptr -+make_rmem_ptr(void* ptr) { -+ return {reinterpret_cast(ptr)}; -+} -+ -+template -+struct is_rmem> : true_type {}; -+ -+// -+// counting iterator -- quick and dirty -+// -+ -+struct counting -+{ -+ using index_type = int; -+ using value_type = index_type; -+ -+ CUTE_HOST_DEVICE constexpr -+ counting() : n_(0) {} -+ CUTE_HOST_DEVICE constexpr -+ counting(index_type const& n) : n_(n) {} -+ -+ CUTE_HOST_DEVICE constexpr -+ index_type operator[](index_type const& i) const { return n_ + i; } -+ -+ CUTE_HOST_DEVICE constexpr -+ index_type const& operator*() const { return n_; } -+ -+ CUTE_HOST_DEVICE constexpr -+ counting operator+(index_type const& i) const { return {n_ + i}; } -+ CUTE_HOST_DEVICE constexpr -+ counting& operator++() { ++n_; return *this; } -+ -+ CUTE_HOST_DEVICE constexpr -+ bool operator==(counting const& other) const { return n_ == other.n_; } -+ CUTE_HOST_DEVICE constexpr -+ bool operator!=(counting const& other) const { return n_ != other.n_; } -+ -+ CUTE_HOST_DEVICE constexpr -+ bool operator< (counting const& other) const { return n_ < other.n_; } -+ -+ index_type n_; -+}; -+ -+// -+// recast -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+recast(T* ptr) { -+ return reinterpret_cast(ptr); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+recast(T const* ptr) { -+ return reinterpret_cast(ptr); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+recast(gmem_ptr const& ptr) { -+ return make_gmem_ptr(recast(ptr.ptr_)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+recast(gmem_ptr const& ptr) { -+ return make_gmem_ptr(recast(ptr.ptr_)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+recast(smem_ptr const& ptr) { -+ return make_smem_ptr(recast(ptr.ptr_)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+recast(smem_ptr const& ptr) { -+ return make_smem_ptr(recast(ptr.ptr_)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+recast(rmem_ptr const& ptr) { -+ return make_rmem_ptr(recast(ptr.ptr_)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+recast(rmem_ptr const& ptr) { -+ return make_rmem_ptr(recast(ptr.ptr_)); -+} -+ -+// -+// Display utilities -+// -+ -+template -+CUTE_HOST_DEVICE void print(T const* const ptr) -+{ -+ printf("raw_ptr_%db(%p)", int(8*sizeof(T)), ptr); -+} -+ -+template -+CUTE_HOST_DEVICE void print(gmem_ptr const& ptr) -+{ -+ printf("gmem_ptr_%db(%p)", int(8*sizeof(T)), ptr.get()); -+} -+ -+template -+CUTE_HOST_DEVICE void print(smem_ptr const& ptr) -+{ -+ printf("smem_ptr_%db(%p)", int(8*sizeof(T)), ptr.get()); -+} -+ -+template -+CUTE_HOST_DEVICE void print(rmem_ptr const& ptr) -+{ -+ printf("rmem_ptr_%db(%p)", int(8*sizeof(T)), ptr.get()); -+} -+ -+template -+CUTE_HOST std::ostream& operator<<(std::ostream& os, gmem_ptr const& ptr) -+{ -+ return os << "gmem_ptr_" << int(8*sizeof(T)) << "b"; -+} -+ -+template -+CUTE_HOST std::ostream& operator<<(std::ostream& os, smem_ptr const& ptr) -+{ -+ return os << "smem_ptr_" << int(8*sizeof(T)) << "b"; -+} -+ -+template -+CUTE_HOST std::ostream& operator<<(std::ostream& os, rmem_ptr const& ptr) -+{ -+ return os << "rmem_ptr_" << int(8*sizeof(T)) << "b"; -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/stride.hpp b/3rdparty/cutlass/include/cute/stride.hpp -new file mode 100644 -index 0000000..5fb0da8 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/stride.hpp -@@ -0,0 +1,411 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+namespace cute -+{ -+ -+/** crd2idx maps a coordinate within to an index -+ * This is computed as follows: -+ * [coord, shape, and stride are all integers => step forward by stride] -+ * op(c, s, d) => c * d -+ * [coord is integer, shape and stride are tuple => divmod coord for each mode] -+ * op(c, (s,S), (d,D)) => op(c % prod(s), s, d) + op(c / prod(s), (S), (D)) -+ * [coord, shape, and stride are all tuples => consider each mode independently] -+ * op((c,C), (s,S), (d,D)) => op(c, s, d) + op((C), (S), (D)) -+ */ -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+crd2idx(Coord const& coord, -+ Shape const& shape, -+ Stride const& stride); -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+crd2idx_ttt(Coord const& coord, -+ Shape const& shape, -+ Stride const& stride, seq) -+{ -+ return (... + crd2idx(get(coord), get(shape), get(stride))); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+crd2idx_itt(CInt const& coord, -+ STuple const& shape, -+ DTuple const& stride, seq) -+{ -+ if constexpr (sizeof...(Is) == 0) { // Avoid recursion and mod on single/last iter -+ return crd2idx(coord, get(shape), get(stride)); -+ } else { // General case -+ return crd2idx(coord % product(get(shape)), get(shape), get(stride)) -+ + crd2idx_itt(coord / product(get(shape)), shape, stride, seq{}); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+crd2idx(Coord const& coord, -+ Shape const& shape, -+ Stride const& stride) -+{ -+ if constexpr (is_tuple::value) { -+ if constexpr (is_tuple::value) { // tuple tuple tuple -+ static_assert(tuple_size::value == tuple_size< Shape>::value, "Mismatched Ranks"); -+ static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); -+ return detail::crd2idx_ttt(coord, shape, stride, tuple_seq{}); -+ } else { // tuple "int" "int" -+ static_assert(sizeof(Coord) == 0, "Invalid parameters"); -+ } -+ } else { -+ if constexpr (is_tuple::value) { // "int" tuple tuple -+ static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); -+ return detail::crd2idx_itt(coord, shape, stride, tuple_seq{}); -+ } else { // "int" "int" "int" -+ return coord * stride; -+ } -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// If we know Stride is default [CompactColMajor], then we can take shortcuts -+// -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+crd2idx_horner(CTuple const& coord, -+ STuple const& shape, seq) -+{ -+ if constexpr (sizeof...(Is) == 0) { // No recursion on single/last iter -+ return get(coord); -+ } else { // General case -+ return get(coord) + get(shape) * crd2idx_horner(coord, shape, seq{}); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+crd2idx(Coord const& coord, -+ Shape const& shape) -+{ -+ static_assert(decltype(congruent(coord,shape))::value, "Mismatched Ranks"); -+ if constexpr (is_tuple::value) { -+ // Flatten and apply Horner's method -+ auto flat_coord = flatten(coord); -+ auto flat_shape = flatten(shape); -+ return detail::crd2idx_horner(flat_coord, flat_shape, tuple_seq{}); -+ } else { -+ return coord; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+/** idx2crd splits an index to a coordinate within . -+ * -+ * This is computed as follows: -+ * [index, shape, and stride are all integers => determine 1D coord] -+ * op(i, s, d) => (i / d) % s -+ * [index is integer, shape and stride are tuple => determine component for each mode] -+ * op(i, (s,S), (d,D)) => (op(i, s, d), op(i, S, D)...) -+ * [index, shape, and stride are all tuples => consider each mode independently] -+ * op((i,I), (s,S), (d,D)) => (op(i, s, d), op((I), (S), (D))) -+ * -+ * NOTE: This only works for compact shape+stride layouts. A more general version would -+ * apply to all surjective layouts -+ */ -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+idx2crd(Index const& idx, -+ Shape const& shape, -+ Stride const& stride) -+{ -+ if constexpr (is_tuple::value) { -+ if constexpr (is_tuple::value) { // tuple tuple tuple -+ static_assert(tuple_size::value == tuple_size< Shape>::value, "Mismatched Ranks"); -+ static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); -+ return transform(idx, shape, stride, [](auto const& i, auto const& s, auto const& d){ return idx2crd(i,s,d); }); -+ } else { // tuple "int" "int" -+ static_assert(sizeof(Index) == 0, "Invalid parameters"); -+ } -+ } else { -+ if constexpr (is_tuple::value) { -+ if constexpr (is_tuple::value) { // "int" tuple tuple -+ static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); -+ return transform(shape, stride, [&](auto const& s, auto const& d){ return idx2crd(idx,s,d); }); -+ } else { // "int" tuple "int" -+ return transform(shape, compact_col_major(shape, stride), [&](auto const& s, auto const& d){ return idx2crd(idx,s,d); }); -+ } -+ } else { // "int" "int" "int" -+ return (idx / stride) % shape; -+ } -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// If we know Stride is default [CompactColMajor], then we can take shortcuts -+// -+ -+//(idx / 1) % s0 -+//(idx / s0) % s1 -+//(idx / (s0 * s1)) % s2 -+//... -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+idx2crd(Index const& idx, -+ Shape const& shape) -+{ -+ if constexpr (is_tuple::value) { -+ if constexpr (is_tuple::value) { // tuple tuple -+ static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); -+ return transform(idx, shape, [](auto const& i, auto const& s) { return idx2crd(i,s); }); -+ } else { // tuple "int" -+ static_assert(sizeof(Index) == 0, "Invalid parameters"); -+ } -+ } else { -+ if constexpr (is_tuple::value) { // "int" tuple -+ return idx2crd(idx, shape, compact_col_major(shape)); -+ } else { // "int" "int" -+ return idx; -+ } -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// crd2crd -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+crd2crd(Coord const& coord, -+ SShape const& src_shape, -+ DShape const& dst_shape) -+{ -+ if constexpr (is_tuple::value && is_tuple::value && is_tuple::value) { -+ static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); -+ static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); -+ return transform(coord, src_shape, dst_shape, [](auto const& c, auto const& s, auto const& d) { return crd2crd(c,s,d); }); -+ } else { -+ // assert(size(src_shape) == size(dst_shape)) -+ return idx2crd(crd2idx(coord, src_shape), dst_shape); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// Compact Major -+// -+ -+// General tag for common layouts and dispatching -+struct GenColMajor {}; -+struct GenRowMajor {}; -+ -+template , class Major = GenColMajor> -+CUTE_HOST_DEVICE constexpr -+auto -+compact_major(Shape const& shape, -+ Current const& current = {}, -+ Major const& major = {}); -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+compact_major_ti(Shape const& shape, -+ Current const& current, -+ GenColMajor const& major, seq) -+{ -+ return cute::make_tuple(compact_major(get(shape), current * product<0,Is>(shape), major)...); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+compact_major_ti(Shape const& shape, -+ Current const& current, -+ GenRowMajor const& major, seq) -+{ -+ constexpr int E = tuple_size::value; -+ return cute::make_tuple(compact_major(get(shape), current * product(shape), major)...); -+} -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+compact_major(Shape const& shape, -+ Current const& current, -+ Major const& major) -+{ -+ if constexpr (is_tuple::value) { -+ if constexpr (is_tuple::value) { // tuple tuple -+ static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); -+ return transform(shape, current, [&](auto const& s, auto const& c){ return compact_major(s,c,major); }); -+ } else { // tuple int -+ return detail::compact_major_ti(shape, current, major, tuple_seq{}); -+ } -+ } else { -+ if constexpr (is_tuple::value) { // int tuple -+ static_assert(sizeof(Shape) == 0, "Invalid parameters"); -+ } else { // int int -+ if constexpr (is_constant<1, Shape>::value) { -+ return Int<0>{}; // If current is dynamic, this could save a reg -+ } else { -+ return current; -+ } -+ } -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// Compact Col Major -+// -+ -+template > -+CUTE_HOST_DEVICE constexpr -+auto -+compact_col_major(Shape const& shape, -+ Current const& current = {}) -+{ -+ return compact_major(shape, current, GenColMajor{}); -+} -+ -+template -+using ColMajor = decltype(compact_col_major(std::declval())); -+ -+// -+// Compact Row Major -+// -+ -+template > -+CUTE_HOST_DEVICE constexpr -+auto -+compact_row_major(Shape const& shape, -+ Current const& current = {}) -+{ -+ return compact_major(shape, current, GenRowMajor{}); -+} -+ -+template -+using RowMajor = decltype(compact_row_major(std::declval())); -+ -+// -+// Compact Order -- compute a compact stride based on an ordering of the modes -+// -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+compact_order(Shape const& shape, Order const& order, -+ OrigShape const& orig_shape, OrigOrder const& orig_order) -+{ -+ if constexpr (is_tuple::value) { -+ return transform(shape, order, [&](auto const& x, auto const& y) { return compact_order(x, y, orig_shape, orig_order); }); -+ } else { -+ auto d = product(transform(orig_shape, orig_order, -+ [&](auto const& s, auto const& o) { -+ return conditional_return(o < order, product(s), Int<1>{}); -+ })); -+ return compact_col_major(shape, d); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+compact_order(Shape const& shape, Order const& order) -+{ -+ static_assert(is_congruent::value, "Need congruence of shape and order."); -+ return detail::compact_order(shape, order, flatten_to_tuple(shape), flatten_to_tuple(order)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+compact_order(Shape const& shape, GenColMajor const& major) -+{ -+ return compact_major(shape, Int<1>{}, major); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+compact_order(Shape const& shape, GenRowMajor const& major) -+{ -+ return compact_major(shape, Int<1>{}, major); -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/swizzle.hpp b/3rdparty/cutlass/include/cute/swizzle.hpp -new file mode 100644 -index 0000000..0a13e55 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/swizzle.hpp -@@ -0,0 +1,497 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+#include -+#include -+#include -+ -+namespace cute -+{ -+ -+// A generic Swizzle functor -+/* 0bxxxxxxxxxxxxxxxYYYxxxxxxxZZZxxxx -+ * ^--^ MBase is the number of least-sig bits to keep constant -+ * ^-^ ^-^ BBits is the number of bits in the mask -+ * ^---------^ SShift is the distance to shift the YYY mask -+ * (pos shifts YYY to the right, neg shifts YYY to the left) -+ * -+ * e.g. Given -+ * 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxZZxxx -+ * the result is -+ * 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxAAxxx where AA = ZZ xor YY -+ */ -+template -+struct Swizzle -+{ -+ static constexpr int num_bits = BBits; -+ static constexpr int num_base = MBase; -+ static constexpr int num_shft = SShift; -+ -+ static_assert(num_base >= 0, "MBase must be positive."); -+ static_assert(num_bits >= 0, "BBits must be positive."); -+ static_assert(abs(num_shft) >= num_bits, "abs(SShift) must be more than BBits."); -+ -+ // using 'int' type here to avoid unintentially casting to unsigned... unsure. -+ using bit_msk = cute::constant; -+ using yyy_msk = cute::constant; -+ using zzz_msk = cute::constant; -+ using msk_sft = cute::constant; -+ -+ static constexpr uint32_t swizzle_code = uint32_t(yyy_msk{} | zzz_msk{}); -+ -+ template ::value)> -+ CUTE_HOST_DEVICE constexpr static -+ auto -+ apply(Offset const& offset) -+ { -+ return offset ^ shiftr(offset & yyy_msk{}, msk_sft{}); // ZZZ ^= YYY -+ } -+ -+ template ::value)> -+ CUTE_HOST_DEVICE constexpr -+ auto -+ operator()(Offset const& offset) const -+ { -+ return apply(offset); -+ } -+}; -+ -+// Translation for legacy SwizzleXor -+// TODO: Deprecate -+template -+using SwizzleXor = Swizzle; -+ -+// -+// make_swizzle<0b1000, 0b0100>() -> Swizzle<1,2,1> -+// make_swizzle<0b11000000, 0b00000110>() -> Swizzle<2,1,5> -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_swizzle() -+{ -+ constexpr uint32_t BZ = popcount(Y); // Number of swizzle bits -+ constexpr uint32_t BY = popcount(Z); // Number of swizzle bits -+ static_assert(BZ == BY, "Number of bits in Y and Z don't match"); -+ constexpr uint32_t TZ_Y = countr_zero(Y); // Number of trailing zeros in Y -+ constexpr uint32_t TZ_Z = countr_zero(Z); // Number of trailing zeros in Z -+ constexpr uint32_t M = cute::min(TZ_Y, TZ_Z) % 32; -+ constexpr int32_t S = int32_t(TZ_Y) - int32_t(TZ_Z); // Difference in trailing zeros -+ static_assert((Y | Z) == Swizzle::swizzle_code, "Something went wrong."); -+ return Swizzle{}; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+composition(Swizzle, Swizzle) -+{ -+ static_assert(S0 == S1, "Can only merge swizzles of the same shift."); -+ constexpr uint32_t Y = Swizzle::yyy_msk::value ^ Swizzle::yyy_msk::value; -+ constexpr uint32_t Z = Swizzle::zzz_msk::value ^ Swizzle::zzz_msk::value; -+ return make_swizzle(); -+ -+ //return ComposedFn, Swizzle>{}; -+} -+ -+// -+// Upcast and Downcast -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+upcast(Swizzle const& swizzle) -+{ -+ static_assert(has_single_bit(N), "N must be a power of two"); -+ constexpr int log2_n = bit_width(uint32_t(N)) - 1; -+ constexpr int NewM = M - log2_n; -+ if constexpr (NewM >= 0) { -+ return Swizzle{}; -+ } else { -+ return Swizzle{}; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+downcast(Swizzle const& swizzle) -+{ -+ static_assert(has_single_bit(N), "N must be a power of two"); -+ constexpr int log2_n = bit_width(uint32_t(N)) - 1; -+ return Swizzle{}; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+recast(Swizzle const& swizzle) -+{ -+ if constexpr (sizeof_bits::value == sizeof_bits::value) { -+ return swizzle; -+ } else if constexpr (sizeof_bits::value > sizeof_bits::value) { -+ static_assert(sizeof_bits::value % sizeof_bits::value == 0, "NewType must be a multiple of OldType"); -+ return upcast::value/sizeof_bits::value>(swizzle); -+ } else if constexpr (sizeof_bits::value < sizeof_bits::value) { -+ static_assert(sizeof_bits::value % sizeof_bits::value == 0, "NewType must be a divisor of OldType"); -+ return downcast::value/sizeof_bits::value>(swizzle); -+ } -+} -+ -+// -+// Utility for slicing and swizzle "offsets" -+// -+ -+// For swizzle functions, it is often needed to keep track of which bits are -+// consumed and which bits are free. Furthermore, it is useful to know whether -+// each of these bits is known statically or dynamically. -+ -+// MixedBits is an integer class where some bits are known statically and some -+// bits are known dynamically. These sets of bits are disjoint and it is known -+// statically which bits are known dynamically. -+ -+// MixedBits can only be manipulated through bitwise operations -+ -+// Abstract value: StaticInt | (dynamic_int_ & StaticFlags) -+template // 0: static, 1: dynamic -+struct MixedBits -+{ -+ // Representation invariants -+ static_assert(StaticFlags != 0, "Should be at least one dynamic bit in MixedBits."); -+ static_assert((StaticInt & StaticFlags) == 0, "No static/dynamic overlap allowed in MixedBits."); -+ // assert((dynamic_int_ & ~F) == 0); -+ -+ DynamicType dynamic_int_; -+}; -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_mixed_bits(constant const&, DynamicType const& d, constant const&) -+{ -+ static_assert(is_integral::value); -+ if constexpr (is_static::value) { -+ static_assert((s & DynamicType::value & f) == 0, "No static/dynamic overlap allowed."); -+ return constant{} | (d & constant{}); // Just return a static int -+ } else if constexpr (f == 0) { -+ return constant{}; // Just return a static int -+ } else { -+ return MixedBits{d & f}; // MixedBits -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// Explicit conversion for now -- consider casting on plus or minus -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+to_integral(MixedBits const& m) -+{ -+ //return S | (m.dynamic_int_ & F); -+ return S | m.dynamic_int_; -+} -+ -+// Any cute::is_integral -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+to_integral(I const& i) -+{ -+ return i; -+} -+ -+// -+// Operators -+// -+ -+// Equality -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator==(MixedBits const& m, constant const&) -+{ -+ return (S0 == (S1 & ~F0)) && (m.dynamic_int_ == (S1 & F0)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator==(constant const& s, MixedBits const& m) -+{ -+ return m == s; -+} -+ -+// Bitwise AND -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator&(MixedBits const& m0, MixedBits const& m1) -+{ -+ // Truth table for (S0,D0,F0) & (S1,D1,F1) -> (S,D,F) -+ // S0D0F0 | 0X0 | 001 | 011 | 1X0 | -+ // S1D1F1 -+ // 0X0 | 0X0 | 0X0 | 0X0 | 0X0 | -+ // 001 | 0X0 | 001 | 001 | 001 | -+ // 011 | 0X0 | 001 | 011 | 011 | -+ // 1X0 | 0X0 | 001 | 011 | 1X0 | -+ -+ return make_mixed_bits(constant{}, -+ //(S0 | m0.dynamic_int_) & (S1 | m1.dynamic_int_), -+ ((S1 & F0) & m0.dynamic_int_) | ((S0 & F1) & m1.dynamic_int_) | (m0.dynamic_int_ & m1.dynamic_int_), -+ constant{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator&(MixedBits const& m, constant const&) -+{ -+ return make_mixed_bits(constant{}, -+ m.dynamic_int_, -+ constant{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator&(constant const& s, MixedBits const& m) -+{ -+ return m & s; -+} -+ -+// Bitwise OR -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator|(MixedBits const& m0, MixedBits const& m1) -+{ -+ // Truth table for (S0,D0,F0) | (S1,D1,F1) -> (S,D,F) -+ // S0D0F0 | 0X0 | 001 | 011 | 1X0 | -+ // S1D1F1 -+ // 0X0 | 0X0 | 001 | 011 | 1X0 | -+ // 001 | 001 | 001 | 011 | 1X0 | -+ // 011 | 011 | 011 | 011 | 1X0 | -+ // 1X0 | 1X0 | 1X0 | 1X0 | 1X0 | -+ -+ return make_mixed_bits(constant{}, -+ ((~S1 & F0) & m0.dynamic_int_) | ((~S0 & F1) & m1.dynamic_int_), -+ constant{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator|(MixedBits const& m, constant const&) -+{ -+ return make_mixed_bits(constant{}, -+ m.dynamic_int_, -+ constant{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator|(constant const& s, MixedBits const& m) -+{ -+ return m | s; -+} -+ -+// Bitwise XOR -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator^(MixedBits const& m0, MixedBits const& m1) -+{ -+ // Truth table for (S0,D0,F0) ^ (S1,D1,F1) -> (S,D,F) -+ // S0D0F0 | 0X0 | 001 | 011 | 1X0 | -+ // S1D1F1 -+ // 0X0 | 0X0 | 001 | 011 | 1X0 | -+ // 001 | 001 | 001 | 011 | 011 | -+ // 011 | 011 | 011 | 001 | 001 | -+ // 1X0 | 1X0 | 011 | 001 | 0X0 | -+ -+ return make_mixed_bits(constant{}, -+ (S0 | m0.dynamic_int_) ^ (S1 | m1.dynamic_int_), -+ constant{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator^(MixedBits const& m, constant const&) -+{ -+ return make_mixed_bits(constant{}, -+ (S0 | m.dynamic_int_) ^ S1, -+ constant{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+operator^(constant const& s, MixedBits const& m) -+{ -+ return m ^ s; -+} -+ -+// -+// upcast and downcast -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+safe_div(MixedBits const& m, constant const& s) -+{ -+ static_assert(has_single_bit(S1), "Only divide MixedBits by powers of two."); -+ return make_mixed_bits(safe_div(constant{}, s), -+ safe_div(m.dynamic_int_, s), -+ safe_div(constant{}, s)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+upcast(MixedBits const& m) -+{ -+ static_assert(has_single_bit(N), "Only divide MixedBits by powers of two."); -+ return safe_div(m, constant{}); -+} -+ -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+upcast(T const& m) -+{ -+ return safe_div(m, constant{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+downcast(MixedBits const& m) -+{ -+ static_assert(has_single_bit(N), "Only scale MixedBits by powers of two."); -+ return make_mixed_bits(constant{}, -+ m.dynamic_int_ * N, -+ constant{}); -+} -+ -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+downcast(T const& m) -+{ -+ return m * constant{}; -+} -+ -+// -+// Convert a Pow2Layout+Coord to a MixedBits -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+to_mixed_bits(Shape const& shape, Stride const& stride, Coord const& coord) -+{ -+ if constexpr (is_tuple::value && is_tuple::value && is_tuple::value) { -+ static_assert(tuple_size::value == tuple_size::value, "Mismatched ranks"); -+ static_assert(tuple_size::value == tuple_size::value, "Mismatched ranks"); -+ return transform_apply(shape, stride, coord, [](auto const& s, auto const& d, auto const& c) { return to_mixed_bits(s,d,c); }, -+ [](auto const&... a) { return (a ^ ...); }); -+ } else if constexpr (is_integral::value && is_integral::value && is_integral::value) { -+ static_assert(decltype(shape*stride)::value == 0 || has_single_bit(decltype(shape*stride)::value), "Requires pow2 shape*stride."); -+ return make_mixed_bits(Int<0>{}, coord * stride, (shape - Int<1>{}) * stride); -+ } else { -+ static_assert(is_integral::value && is_integral::value && is_integral::value, "Either Shape, Stride, and Coord must be all tuples, or they must be all integral (in the sense of cute::is_integral)."); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+to_mixed_bits(Layout const& layout, Coord const& coord) -+{ -+ return to_mixed_bits(layout.shape(), layout.stride(), idx2crd(coord, layout.shape())); -+} -+ -+// -+// Display utilities -+// -+ -+template -+CUTE_HOST_DEVICE void print(MixedBits const& m) -+{ -+ printf("M_%u|(%u&%u)=%u", S, uint32_t(m.dynamic_int_), F, to_integral(m)); -+} -+ -+template -+CUTE_HOST std::ostream& operator<<(std::ostream& os, MixedBits const& m) -+{ -+ return os << "M_" << S << "|(" << uint32_t(m.dynamic_int_) << "&" << F << ")=" << to_integral(m); -+} -+ -+template -+CUTE_HOST_DEVICE void print(Swizzle const&) -+{ -+ print("S<%d,%d,%d>", B, M, S); -+} -+ -+template -+CUTE_HOST std::ostream& operator<<(std::ostream& os, Swizzle const&) -+{ -+ return os << "S<" << B << "," << M << "," << S << ">"; -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/swizzle_layout.hpp b/3rdparty/cutlass/include/cute/swizzle_layout.hpp -new file mode 100644 -index 0000000..1376a47 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/swizzle_layout.hpp -@@ -0,0 +1,1010 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+#include -+ -+/* This implements a ComposedLayout of the form -+ * InvolutionFn o OffsetPlus o Layout -+ * where the InvolutionFn need not be linear (hence the need for the Offset). -+ * -+ * This ComposedLayout provides similar coordinate-to-index mapping and layout manipulations, -+ * but is not considered a "normal" layout. -+ * For example, this layout provides size() functions, but does not provide stride() functions. -+ * -+ * Furthermore, for known InvolutionFns, this layout attempts to decay itself -+ * to a normal-layout with dynamic or static strides. -+ * This is possible by determining the subdomain of the Involution function -+ * that is identity and testing if the right Layout's codomain is contained -+ * within it. -+ */ -+ -+namespace cute -+{ -+ -+// A Layout of non-trivially composable functions: F o I o L -+template -+struct ComposedLayout -+ : private cute::tuple // EBO for static layouts -+{ -+ CUTE_HOST_DEVICE constexpr -+ ComposedLayout(InvolutionFn const& fn = {}, -+ IntermediateOffset const& offset = {}, -+ Layout const& layout = {}) -+ : cute::tuple(fn, offset, layout) -+ {} -+ -+ // -+ // Accessors -+ // -+ -+ static constexpr int rank = Layout::rank; -+ -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ swizzle_fn() const { -+ return get<0>(static_cast const&>(*this)); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ offset_fn() const { -+ return get<1>(static_cast const&>(*this)); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ layout_fn() const { -+ return get<2>(static_cast const&>(*this)); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ layout() const { -+ return *this; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ shape() const { -+ return layout_fn().shape(); -+ } -+ -+ // Doesn't really make sense to ask for the strides of this "layout" -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ stride() const = delete; -+ -+ // -+ // Mappings -+ // -+ -+ // Map a logical coordinate to a linear index (Coord has no Underscore slice operators) -+ // OR -+ // Slice the layout and return the sublayout (Coord has an Underscore slice op) -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ operator()(Coord const& coord) const { -+ if constexpr (has_underscore::value) { -+ return slice(coord, *this); -+ } else { -+ return swizzle_fn()(to_integral(offset_fn()) + layout_fn()(coord)); // (F o L)(c) -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+ } -+ -+ // Map a 1D linear coordinate to a flat ND logical coordinate -+ template ::value)> -+ CUTE_HOST_DEVICE constexpr -+ auto -+ operator[](Int const& linear_idx) const { -+ return get_flat_coord(linear_idx); -+ } -+ -+ // Convenience function for multi-dimensional coordinates -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) const { -+ return operator()(make_coord(c0,c1,cs...)); -+ } -+ -+ // -+ // Compose -+ // -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ compose(OtherLayout const& other) const { -+ return composition(*this, other); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ compose(Layouts const&... layouts) const { -+ return composition(*this, make_tile(layouts...)); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ with_shape(OtherShape const& shape) const { -+ return composition(*this, make_layout(shape)); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ with_shape(Shapes const&... shapes) const { -+ return composition(*this, make_layout(make_shape(shapes...))); -+ } -+ -+ // -+ // Tile -+ // -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ tile(OtherLayout const& other) const { -+ return tiled_divide(*this, other); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ tile(Layouts const&... layouts) const { -+ return tiled_divide(*this, make_tile(layouts...)); -+ } -+ -+ // -+ // Utility -+ // -+ -+ // -+ // Index to Coordinate -+ // -+ -+ // NOTE Only valid for compact layouts -+ -+ // Return the (hierarchical) ND logical coordinate corresponding to the linear index -+ // @post this->crd2idx(@a result) == idx -+ // @post congruent(@a result, shape()) -+ template ::value)> -+ CUTE_HOST_DEVICE constexpr -+ auto -+ get_hier_coord(IInt const& idx) const { -+ return layout_fn().get_hier_coord(swizzle_fn()(idx) - to_integral(offset_fn())); // (L^-1 o F)(k) -+ } -+ -+ // Return the (flat) ND logical coordinate corresponding to the linear index -+ // @post this->crd2idx(@a result) == idx -+ // @post rank(@a result) == rank(shape()) && depth(@a result) == 1 -+ template ::value)> -+ CUTE_HOST_DEVICE constexpr -+ auto -+ get_flat_coord(IInt const& idx) const { -+ return layout_fn().get_flat_coord(swizzle_fn()(idx) - to_integral(offset_fn())); // (L^-1 o F)(k) -+ } -+ -+ // Return the generalized column-major 1D logical coordinate corresponding to the linear index -+ // @post this->crd2idx(@a result) == idx -+ // @post is_integral::value -+ template ::value)> -+ CUTE_HOST_DEVICE constexpr -+ auto -+ get_1d_coord(IInt const& idx) const { -+ return layout_fn().get_1d_coord(swizzle_fn()(idx) - to_integral(offset_fn())); // (L^-1 o F)(k) -+ } -+}; -+ -+template -+struct is_layout> : true_type {}; -+ -+template -+struct is_composed_layout : false_type {}; -+template -+struct is_composed_layout> : true_type {}; -+ -+// -+// Constructors -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_layout(Swizzle const& sxor) -+{ -+ return composition(sxor, Layout,Int<1>>{}); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_layout(ComposedLayout const& a, Layout const& b) -+{ -+ return composition(a.swizzle_fn(), a.offset_fn(), make_layout(a.layout_fn(), b)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_layout(Layout const& a, ComposedLayout const& b) -+{ -+ return composition(b.swizzle_fn(), b.offset_fn(), make_layout(a, b.layout_fn())); -+} -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+transfer_swizzle(Layout const& old_layout, -+ Layout const& new_layout) -+{ -+ // Our goal is to determine a new swizzle for the strides in new_layout for consistent vectorizations -+ -+ // This is accomplished by identifying -+ // S o L :=: S? o L* -+ // We identify the "active" portion of S by computing (P o L)(c*) where P is a projection generated by S -+ // Then that active identifier is transformed through the layouts: -+ // L*(L[(P o L)(c*)]) -+ // which is a new swizzle identifier for S?, the new swizzle -+ -+ // Projections of the swizzle layout for composition, P -+ auto swizzle_only_zy = make_layout(make_shape (Int<(1 << M)>{}, Int<(1 << B)>{}, Int<(1 << (abs(S)-B))>{}, Int<(1 << B )>{}, Int<1>{}), -+ make_stride( Int<0>{}, Int<(1 << M)>{}, Int<0>{}, Int<(1 << (M+abs(S)))>{}, Int<0>{})); -+ -+ // Compose with the tile to get the swizzle projection, P o L [The Z and Y contributing portions of L] -+ auto layout_only_zy = composition(swizzle_only_zy, old_layout); -+ // Transform the end coordinate to get the active bits of the swizzle, (P o L)(c*) -+ auto swizzle_active_bits = layout_only_zy(size(layout_only_zy)-Int<1>{}); -+ -+ // Get the Z bit and the Y bits -- keep only those that are active in Z *and* Y -+ auto zzz_msk = typename Swizzle::zzz_msk{}; -+ auto yyy_msk = typename Swizzle::yyy_msk{}; -+ auto msk_sft = typename Swizzle::msk_sft{}; -+ auto active_Z = swizzle_active_bits & shiftr(swizzle_active_bits, msk_sft) & zzz_msk; -+ auto active_Y = swizzle_active_bits & shiftr(swizzle_active_bits, -msk_sft) & yyy_msk; -+ -+ // Pass the identifiers through the old layout and new layout to make a new swizzle identifier, L*(L[(P o L)(c*)]) -+ auto new_active_Z = new_layout(old_layout.get_1d_coord(active_Z)); -+ auto new_active_Y = new_layout(old_layout.get_1d_coord(active_Y)); -+ -+ // Use this new swizzle identifier to construct the new swizzle for new_layout -+ // (this also makes sure it's a "valid" swizzle that Swizzle can represent) -+ return composition(make_swizzle(), new_layout); -+} -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_fragment_like(ComposedLayout,Offset,Layout> const& layout) -+{ -+ return detail::transfer_swizzle(layout.layout_fn(), make_fragment_like(layout.layout_fn())); -+} -+ -+// -+// Utilities -+// -+ -+// Return the layout of a mode -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+layout(ComposedLayout const& clayout) -+{ -+ return composition(clayout.swizzle_fn(), clayout.offset_fn(), layout(clayout.layout_fn())); -+} -+ -+// Return the shape of a mode -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+shape(ComposedLayout const& layout) -+{ -+ return shape(layout.layout_fn()); -+} -+ -+// Doesn't make sense to directly ask for the strides of this "layout" -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+stride(ComposedLayout const& layout) = delete; -+ -+// Return the number of elements in a mode -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+size(ComposedLayout const& layout) -+{ -+ return size(layout.layout_fn()); -+} -+ -+// Return the number of modes -+template -+CUTE_HOST_DEVICE constexpr -+auto -+rank(ComposedLayout const& layout) -+{ -+ return rank(layout.layout_fn()); -+} -+ -+// Return the depth of the layout -+template -+CUTE_HOST_DEVICE constexpr -+auto -+depth(ComposedLayout const& layout) -+{ -+ return depth(layout.layout_fn()); -+} -+ -+// Return the codomain size of a mode -+template -+CUTE_HOST_DEVICE constexpr -+auto -+cosize(ComposedLayout const& layout) -+{ -+ return cosize(layout.layout_fn()); -+} -+ -+// -+// Operations to manipulate Layouts like a tuple of pairs -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+get(ComposedLayout const& a) -+{ -+ return composition(a.swizzle_fn(), a.offset_fn(), get(a.layout_fn())); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+take(ComposedLayout const& a) -+{ -+ return composition(a.swizzle_fn(), a.offset_fn(), take(a.layout_fn())); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+flatten(ComposedLayout const& a) -+{ -+ return composition(a.swizzle_fn(), a.offset_fn(), flatten(a.layout_fn())); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+append(ComposedLayout const& a, X const& x) -+{ -+ return composition(a.swizzle_fn(), a.offset_fn(), append(a.layout_fn(), x)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+group(ComposedLayout const& a) -+{ -+ return composition(a.swizzle_fn(), a.offset_fn(), group(a.layout_fn())); -+} -+ -+// -+// Slice a ComposedLayout -+// -+ -+namespace detail { -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_swizzle_strides(true_type, -+ IntZ const& Z, -+ IntY const& Y, -+ Offset const& offset, -+ int_sequence) -+{ -+ // Below is an optimized/compressed version of: -+ //return make_tuple((swizzle(offset + Z*Int<(1 << I)>{}) - swizzle(offset))...); -+ // with knowledge of Swizzle, I... ranges for each B bits, -+ // and the layout won't slice along z-bits that are already set -+ -+ // y\z 0 1 -+ // 0 Z DC -+ // 1 -Z DC -+ -+ return cute::make_tuple(conditional_return((offset & (Y << Int{})) == Int<0>{}, Z << Int{}, -(Z << Int{}))...); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_swizzle_strides(false_type, -+ IntZ const& Z, -+ IntY const& Y, -+ Offset const& offset, -+ int_sequence) -+{ -+ // Below is an optimized/compressed version of: -+ //return make_tuple((swizzle(offset + Y*Int<(1 << I)>{}) - swizzle(offset))...); -+ // with knowledge of Swizzle, I... ranges for each B bits, -+ // and the layout won't slice along y-bits that are already set -+ -+ // y\z 0 1 -+ // 0 Y+Z Y-Z -+ // 1 DC DC -+ -+ return cute::make_tuple(conditional_return((offset & (Z << Int{})) == Int<0>{}, (Y+Z) << Int{}, (Y-Z) << Int{})...); -+} -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+slice_and_offset(Coord const& coord, ComposedLayout,Offset,Layout> const& layout) -+{ -+ if constexpr (all_underscore::value) { -+ // Skip the expensive/complicated attempt to decay to a normal layout and just reshape -+ return cute::make_tuple(composition(layout.swizzle_fn(), layout.offset_fn(), slice(coord, layout.layout_fn())), Int<0>{}); -+ } else { -+ -+ // Projections of the swizzle layout for composition -+ auto sw = make_layout(make_shape(Int<(1 << M)>{}, Int<(1 << B)>{}, Int<(1 << (abs(S)-B))>{}, Int<(1 << B)>{}, Int<1>{})); -+ -+ auto swizzle_anti_zy = make_layout(shape(sw), -+ make_stride(stride<0>(sw), Int<0>{}, stride<2>(sw), Int<0>{}, size(sw))); -+ auto swizzle_only_zy = make_layout(shape(sw), -+ make_stride( Int<0>{}, stride<1>(sw), Int<0>{}, stride<3>(sw), Int<0>{})); -+ -+ // The portion of the layout that is not yet consumed -+ auto sliced_layout = slice(coord, layout.layout_fn()); -+ -+ // If the sliced_layout hits two bits that are swizzled together, then don't attempt to decay -+ -+ // Compose with the layout to get the swizzle projection, P o L [The Z and Y contributing portions of L] -+ // (this also tests that shape/stride of layout compose with swizzle) -+ auto sliced_layout_only_zy = composition(swizzle_only_zy, sliced_layout); -+ // Transform the end coordinate to get the active bits of the swizzle, (P o L)(c*) -+ auto swizzle_active_bits = sliced_layout_only_zy(size(sliced_layout_only_zy)-Int<1>{}); -+ // Determine if any active bits collide under the swizzle -+ auto hit_ZandY = !(swizzle_active_bits & ~layout.swizzle_fn()(swizzle_active_bits)); -+ -+ // The portion of the layout that we are consuming now -+ auto diced_layout = dice(coord, layout.layout_fn()); -+ auto diced_coord = dice(coord, coord); -+ -+ auto diced_layout_anti_zy = composition(swizzle_anti_zy, diced_layout); -+ auto diced_layout_only_zy = composition(swizzle_only_zy, diced_layout); -+ -+ // New swizzle and offset -+ auto swizzle = layout.swizzle_fn(); -+ // offset_only_zy interacts with swizzle and gets accumulated with layout.offset_fn() -+ // being careful about the static/dynamic contributions from diced_layout and diced_coord -+ auto offset_only_zy = layout.offset_fn() ^ to_mixed_bits(diced_layout_only_zy, diced_coord); -+ // offset_anti_zy always gets passed through, no interaction with swizzle -+ auto offset_anti_zy = diced_layout_anti_zy(diced_coord); -+ -+ // If Layout's codomain hits on Y AND Z, then it's not reducible -+ // If Layout's codomain hits on Y XOR Z, then it's dynamic-normal -+ // If Layout's codomain hits on neither Y NOR Z, then it's static-normal -+ -+ // Test the sliced layout for hit_X & hit_Y for potential decay -+ if constexpr (is_constant::value) -+ { // Hits on Y AND Z, so it's not reducible -+ return cute::make_tuple(composition(swizzle, offset_only_zy, sliced_layout), offset_anti_zy); -+ } else -+ { // Misses on Y or Z, so it's static-normal or dynamic-normal -+ -+ // Lowest bit of the Z and Y masks -+ auto Z = typename Swizzle::zzz_msk{} & -typename Swizzle::zzz_msk{}; -+ auto Y = typename Swizzle::yyy_msk{} & -typename Swizzle::yyy_msk{}; -+ auto stride_lo = detail::make_swizzle_strides(Z < Y, Z, Y, offset_only_zy, make_int_sequence{}); -+ auto stride_hi = detail::make_swizzle_strides(Z > Y, Z, Y, offset_only_zy, make_int_sequence{}); -+ -+ // Construct a (dynamic) layout that we can perform the composition with -+ auto swizzle_layout = make_layout(make_shape (Int<(1 << M)>{}, repeat(Int<2>{}), Int<(1 << (abs(S)-B))>{}, repeat(Int<2>{}), Int< 1>{}), -+ make_stride(Int< 1>{}, stride_lo, Int<(1 << (M+B))>{}, stride_hi , Int<(1 << (M+B+abs(S)))>{})); -+ -+ // Decay to a normal layout with offset -+ return cute::make_tuple(composition(swizzle_layout, sliced_layout), -+ swizzle(to_integral(offset_only_zy)) + offset_anti_zy); -+ } -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+slice(Coord const& coord, ComposedLayout const& layout) -+{ -+ return get<0>(slice_and_offset(coord, layout)); -+} -+ -+// -+// composition -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+composition(Swizzle const& sxor, -+ Offset const& offset, -+ Layout const& layout) -+{ -+ return ComposedLayout>{sxor, offset, layout}; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+composition(Swizzle const& sxor, -+ Offset const& offset, -+ ComposedLayout const& layout) -+{ -+ // Assume disjoint swizzles and offsets for commutivity -+ return composition(composition(sxor,layout.swizzle_fn()), offset ^ layout.offset_fn(), layout.layout_fn()); -+} -+ -+// Ignore identity case -+template -+CUTE_HOST_DEVICE constexpr -+auto -+composition(Swizzle<0,M,S> const&, -+ Int<0> const&, -+ Layout const& layout) -+{ -+ return layout; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+composition(Swizzle const& sxor, -+ Layout const& layout) -+{ -+ return composition(sxor, Int<0>{}, layout); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+composition(ComposedLayout const& a, -+ LayoutOrTile const& b) -+{ -+ return composition(a.swizzle_fn(), a.offset_fn(), composition(a.layout_fn(), b)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+composition(Layout const& a, -+ Swizzle const& b) -+{ -+ // Get the Z bits and the Y bits -+ auto active_Y = a(typename Swizzle::yyy_msk{}); -+ auto active_Z = a(typename Swizzle::zzz_msk{}); -+ -+ // Works in simple cases... but could be greatly generalized -+ -+ return composition(make_swizzle(), a); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+composition(Layout const& a, -+ ComposedLayout const& b) -+{ -+ CUTE_STATIC_ASSERT_V(b.offset_fn() == Int<0>{}, "Require Swizzle offset == 0."); -+ -+ return composition(composition(a, b.swizzle_fn()), b.layout_fn()); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+composition(ComposedLayout const& a, -+ ComposedLayout const& b) -+{ -+ auto asb = composition(a.layout_fn(), b); -+ -+ return composition(composition(a.swizzle_fn(),asb.swizzle_fn()), asb.offset_fn(), asb.layout_fn()); -+} -+ -+// -+// complement -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+complement(ComposedLayout const& layout, CoSizeHi const& cosize_hi) -+{ -+ // Assume there is no swizzle component in the complement -+ return complement(layout.layout_fn(), cosize_hi); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+complement(ComposedLayout const& layout) -+{ -+ return complement(layout, cosize(layout)); -+} -+ -+// -+// inverse -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+right_inverse(ComposedLayout const& layout) -+{ -+ CUTE_STATIC_ASSERT_V(layout.offset_fn() == Int<0>{}, "Requires 0-offset."); -+ return composition(right_inverse(layout.layout_fn()), layout.swizzle_fn()); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+left_inverse(ComposedLayout const& layout) -+{ -+ CUTE_STATIC_ASSERT_V(layout.offset_fn() == Int<0>{}, "Requires 0-offset."); -+ return composition(left_inverse(layout.layout_fn()), layout.swizzle_fn()); -+} -+ -+// -+// Other operations -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+max_common_vector(ComposedLayout,Offset,SLayout> const& a, -+ Layout const& b) -+{ -+ // This assumes that Offset is in the YZ domain of the Swizzle... -+ return cute::min(Int<(1 << M)>{}, max_common_vector(a.layout_fn(), b)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+max_common_vector(Layout const& a, -+ ComposedLayout,Offset,SLayout> const& b) -+{ -+ return max_common_vector(b, a); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+max_common_vector(ComposedLayout,Offset0,SLayout0> const& a, -+ ComposedLayout,Offset1,SLayout1> const& b) -+{ -+ auto result = coalesce(composition(a, right_inverse(b))); -+ -+ if constexpr (is_constant<1, decltype(stride<0>(result.layout_fn()))>::value) { -+ return shape<0>(result); -+ } else { -+ return Int<1>{}; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+zip(ComposedLayout const& a) -+{ -+ return composition(a.swizzle_fn(), a.offset_fn(), zip(a.layout_fn())); -+} -+ -+// Partitions -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+logical_divide(ComposedLayout const& a, -+ Tile const& b) -+{ -+ return composition(a.swizzle_fn(), a.offset_fn(), logical_divide(a.layout_fn(), b)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tile_unzip(ComposedLayout const& a, -+ Tile const& b) -+{ -+ return composition(a.swizzle_fn(), a.offset_fn(), tile_unzip(a.layout_fn(), b)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tiled_divide(ComposedLayout const& a, -+ Tile const& b) -+{ -+ return composition(a.swizzle_fn(), a.offset_fn(), tiled_divide(a.layout_fn(), b)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+zipped_divide(ComposedLayout const& a, -+ Tile const& b) -+{ -+ return composition(a.swizzle_fn(), a.offset_fn(), zipped_divide(a.layout_fn(), b)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+logical_product(ComposedLayout const& a, -+ Tile const& b) -+{ -+ return composition(a.swizzle_fn(), a.offset_fn(), logical_product(a.layout_fn(), b)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tiled_product(ComposedLayout const& a, -+ Tile const& b) -+{ -+ return composition(a.swizzle_fn(), a.offset_fn(), tiled_product(a.layout_fn(), b)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+blocked_product(ComposedLayout const& a, -+ Tile const& b) -+{ -+ return composition(a.swizzle_fn(), a.offset_fn(), blocked_product(a.layout_fn(), b)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+raked_product(ComposedLayout const& a, -+ Tile const& b) -+{ -+ return composition(a.swizzle_fn(), a.offset_fn(), raked_product(a.layout_fn(), b)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tile_to_shape(ComposedLayout const& layout, -+ Shape const& trg_shape, -+ ModeOrder const& ord_shape = {}) -+{ -+ return composition(layout.swizzle_fn(), layout.offset_fn(), tile_to_shape(layout.layout_fn(), trg_shape, ord_shape)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+filter(ComposedLayout const& layout, Shape const& trg_profile) -+{ -+ return composition(layout.swizzle_fn(), layout.offset_fn(), filter(layout.layout_fn(), trg_profile)); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+coalesce(ComposedLayout const& layout) -+{ -+ return composition(layout.swizzle_fn(), layout.offset_fn(), coalesce(layout.layout_fn())); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+coalesce(ComposedLayout const& layout, Shape const& trg_profile) -+{ -+ return composition(layout.swizzle_fn(), layout.offset_fn(), coalesce(layout.layout_fn(), trg_profile)); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+// ComposedLayout as second argument is often more difficult... -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+logical_product(Layout const& block, -+ ComposedLayout,Offset,LayoutT> const& tile) -+{ -+ CUTE_STATIC_ASSERT_V(tile.offset_fn() == Int<0>{}, "Require Swizzle offset == 0."); -+ // The new layout -- if swizzle wasn't an issue, this is the result -+ // our goal is to determine a new swizzle for these strides -+ auto new_layout = logical_product(block, tile.layout_fn()); -+ -+ // This is accomplished by identifying -+ // S o L :=: S? o L* -+ // We identify the "active" portion of S by computing (P o L)(c*) where P is a projection generated by S -+ // Then that active identifier is transformed through the layouts: -+ // L*(L[(P o L)(c*)]) -+ // which is a new swizzle identifier for S?, the new swizzle -+ -+ // Projections of the swizzle layout for composition, P -+ auto swizzle_only_zy = make_layout(make_shape (Int<(1 << M)>{}, Int<(1 << B)>{}, Int<(1 << (abs(S)-B))>{}, Int<(1 << B )>{}, Int<1>{}), -+ make_stride( Int<0>{}, Int<(1 << M)>{}, Int<0>{}, Int<(1 << (M+abs(S)))>{}, Int<0>{})); -+ -+ // Compose with the tile to get the swizzle projection, P o L [The Z and Y contributing portions of L] -+ auto layout_only_zy = composition(swizzle_only_zy, tile.layout_fn()); -+ // Transform the end coordinate to get the active bits of the swizzle, (P o L)(c*) -+ auto swizzle_active_bits = layout_only_zy(size(layout_only_zy)-Int<1>{}); -+ // Get the Z bit and the Y bits -+ auto active_Z = swizzle_active_bits & typename Swizzle::zzz_msk{}; -+ auto active_Y = swizzle_active_bits & typename Swizzle::yyy_msk{}; -+ -+ // Pass the identifiers through the old layout and new layout to make a new swizzle identifier, L*(L[(P o L)(c*)]) -+ auto new_active_Z = new_layout(Int<0>{}, tile.layout_fn()[active_Z]); -+ auto new_active_Y = new_layout(Int<0>{}, tile.layout_fn()[active_Y]); -+ -+ // Use this new swizzle identifier to construxt the new swizzle for new_layout -+ // (this also makes sure it's a "valid" swizzle that Swizzle can represent) -+ return composition(make_swizzle(), new_layout); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+tiled_product(Layout const& block, -+ ComposedLayout const& tile) -+{ -+ /// Avoid swizzle slice -+ auto result = logical_product(block, tile); -+ return composition(result.swizzle_fn(), result.offset_fn(), result.layout_fn()(_, repeat>(_))); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+blocked_product(Layout const& block, -+ ComposedLayout const& layout) -+{ -+ constexpr int R = cute::max(rank_v, rank_v); -+ auto padded_block = append(block, Layout<_1,_0>{}); -+ auto padded_layout = append(layout, Layout<_1,_0>{}); -+ -+ auto result = logical_product(padded_block, padded_layout); -+ -+ return composition(result.swizzle_fn(), -+ result.offset_fn(), -+ coalesce(zip(get<0>(result.layout_fn()), get<1>(result.layout_fn())), repeat(Int<1>{}))); -+} -+ -+// -+// Upcast and Downcast -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+upcast(ComposedLayout const& layout) -+{ -+ return composition(upcast(layout.swizzle_fn()), upcast(layout.offset_fn()), upcast(layout.layout_fn())); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+downcast(ComposedLayout const& layout) -+{ -+ return composition(downcast(layout.swizzle_fn()), downcast(layout.offset_fn()), downcast(layout.layout_fn())); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+recast(ComposedLayout const& layout) -+{ -+ if constexpr (sizeof(NewType) == sizeof(OldType)) { -+ return layout; -+ } else if constexpr (sizeof(NewType) > sizeof(OldType)) { -+ static_assert(sizeof(NewType) % sizeof(OldType) == 0, "NewType must be a multiple of OldType"); -+ return upcast(layout); -+ } else if constexpr (sizeof(NewType) < sizeof(OldType)) { -+ static_assert(sizeof(OldType) % sizeof(NewType) == 0, "NewType must be a divisor of OldType"); -+ return downcast(layout); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// Display utilities -+// -+ -+template -+CUTE_HOST_DEVICE void print(ComposedLayout const& layout) -+{ -+ print(layout.swizzle_fn()); print(" o "); print(layout.offset_fn()); print(" o "); print(layout.layout_fn()); -+} -+ -+template -+CUTE_HOST std::ostream& operator<<(std::ostream& os, ComposedLayout const& layout) -+{ -+ return os << layout.swizzle_fn() << " o " << layout.offset_fn() << " o " << layout.layout_fn(); -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/swizzle_ptr.hpp b/3rdparty/cutlass/include/cute/swizzle_ptr.hpp -new file mode 100644 -index 0000000..ed77acb ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/swizzle_ptr.hpp -@@ -0,0 +1,282 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+#include -+#include -+#include -+ -+#include -+#include -+#include -+ -+/* This implements a swizzle pointer of the form -+ * InvolutionFn o PtrAdd -+ * where the InvolutionFn need not be linear. -+ * -+ * This differs subtly from swizzle_layout because the smem pointer is used -+ * as the offset. That means that swizzle_layout will implement position-independent -+ * swizzle layouts, while swizzle_ptr implements position-dependent swizzle tensors. -+ * Arch chose to design hardware with position-dependent swizzles. -+ * -+ * For clarity: -+ * NormalLayout : DeRef <- PtrAdd <- [Layout] -+ * ComposedLayout: DeRef <- PtrAdd <- [Swizzle <- OffsetAdd <- Layout] -+ * SwizzlePtr : [DeRef <- Swizzle <- PtrAdd] <- Layout -+ * -+ * Furthermore, for known swizzles, this pointer attempts to decay itself -+ * to a normal-pointer with a new layout containing dynamic or static strides. -+ * This is possible by determining the subdomain of the InvolutionFn -+ * that is identity and testing if the Layout's codomain is contained -+ * within it. -+ */ -+ -+namespace cute -+{ -+ -+template -+struct smem_ptr_swizzle -+{ -+ static_assert(std::is_empty::value, "Swizzle can't have state."); -+ -+ CUTE_HOST_DEVICE constexpr -+ T* get() const -+ { -+ return ptr_; -+ } -+ -+ CUTE_HOST_DEVICE constexpr static -+ Swizzle get_swizzle() -+ { -+ return {}; -+ } -+ -+ CUTE_HOST_DEVICE constexpr static -+ T* apply_swizzle(T* ptr) -+ { -+ return reinterpret_cast(Swizzle::apply(reinterpret_cast(ptr))); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ T& operator*() const -+ { -+ return *apply_swizzle(get()); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ T& operator[](Int const& i) const -+ { -+ return *apply_swizzle(get() + i); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ smem_ptr_swizzle operator+(Int const& i) const -+ { -+ return {ptr_ + i}; -+ } -+ -+ T* ptr_; -+}; -+ -+template -+struct is_smem> : true_type {}; -+ -+// Make a swizzle pointer -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_smem_ptr(T* ptr, Swizzle const& swizzle) -+{ -+ return smem_ptr_swizzle{ptr}; -+} -+ -+// A model of a nullptr smem_ptr with B == sizeof_bits::value -+// That represents an unset pointer. This is a placeholder type that is waiting for an smem_ptr -+template -+struct smem_ptr_flag_bits : Int<0> {}; -+ -+using smem_ptr_flag = smem_ptr_flag_bits<1>; -+ -+// A flagged construction method to transform ComposedLayout -+// Make a swizzle pointer tensor and check that the intended type size matches -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_tensor(smem_ptr const& ptr, -+ ComposedLayout,Layout> const& layout) -+{ -+ static_assert(B == sizeof_bits::value, "Expected a B-bit pointer type."); -+ return make_tensor(make_smem_ptr(ptr.get(), layout.swizzle_fn()), -+ layout.layout_fn()); -+} -+ -+// Specialization for immediate decay -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_tensor(smem_ptr_swizzle>& p, Layout const& layout) -+{ -+ return make_tensor(make_smem_ptr(p.ptr_), layout); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_tensor(smem_ptr_swizzle> const& p, Layout const& layout) -+{ -+ return make_tensor(make_smem_ptr(p.ptr_), layout); -+} -+ -+// NOTE: To preserve smem_ptr_flag_bits under recast ops -+template -+CUTE_HOST_DEVICE constexpr -+auto -+upcast(ComposedLayout,Layout> const& layout) -+{ -+ return composition(layout.swizzle_fn(), smem_ptr_flag_bits{}, upcast(layout.layout_fn())); -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+downcast(ComposedLayout,Layout> const& layout) -+{ -+ return composition(layout.swizzle_fn(), smem_ptr_flag_bits{}, downcast(layout.layout_fn())); -+} -+ -+// -+// Recast -+// Swizzle operates on the pointer address, so it doesn't care about the type -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+recast(smem_ptr_swizzle const& ptr) -+{ -+ return smem_ptr_swizzle{recast(ptr.ptr_)}; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+recast(smem_ptr_swizzle const& ptr) -+{ -+ return smem_ptr_swizzle{recast(ptr.ptr_)}; -+} -+ -+// -+// Conversion with swizzle_layout -+// -+ -+template -+CUTE_HOST_DEVICE -+auto -+as_position_independent_swizzle_layout(ComposedLayout,Layout> const& layout) -+{ -+ return composition(recast,uint_bit_t>(layout.swizzle_fn()), Int<0>{}, layout.layout_fn()); -+} -+ -+template -+CUTE_HOST_DEVICE -+auto -+as_position_independent_swizzle_tensor(Tensor>, Layout> const& tensor) -+{ -+ { -+ uint32_t address = cast_smem_ptr_to_uint(tensor.data().get()); -+ uint32_t mask = ((uint32_t(1) << Swizzle::num_base) - 1) & (Swizzle::swizzle_code); -+ assert((address & mask) == 0); // Alignment to the Base, Z, and Y of Swizzle -+ } -+ auto new_swizzle = recast,uint_bit_t>>(tensor.data().get_swizzle()); -+ return make_tensor(make_smem_ptr(tensor.data().get()), composition(new_swizzle, Int<0>{}, tensor.layout())); -+} -+ -+template -+CUTE_HOST_DEVICE -+auto -+as_position_independent_swizzle_tensor(Tensor>, Layout>& tensor) -+{ -+ { -+ uint32_t address = cast_smem_ptr_to_uint(tensor.data().get()); -+ uint32_t mask = ((uint32_t(1) << Swizzle::num_base) - 1) & (Swizzle::swizzle_code); -+ assert((address & mask) == 0); // Alignment to the Base, Z, and Y of Swizzle -+ } -+ auto new_swizzle = recast,uint_bit_t>>(tensor.data().get_swizzle()); -+ return make_tensor(make_smem_ptr(tensor.data().get()), composition(new_swizzle, Int<0>{}, tensor.layout())); -+} -+ -+template -+CUTE_HOST_DEVICE -+auto -+as_position_independent_swizzle_tensor(Tensor>, Layout>&& tensor) -+{ -+ return as_position_independent_swizzle_tensor(tensor); -+} -+ -+// -+// Print -+// -+ -+// Capture and cast smem_ptr_flag Layouts to offset-0 layouts -+template -+CUTE_HOST_DEVICE -+void -+print_latex(ComposedLayout,Layout> const& layout) -+{ -+ auto new_swizzle = recast,uint_bit_t>(layout.swizzle_fn()); -+ print_latex(composition(new_swizzle, Int<0>{}, layout.layout_fn())); -+} -+ -+template -+CUTE_HOST_DEVICE void print(smem_ptr_flag_bits const& ptr) -+{ -+ printf("smem_ptr_%db(unset)", B); -+} -+ -+template -+CUTE_HOST_DEVICE void print(smem_ptr_swizzle> const& ptr) -+{ -+ printf("smem_ptr_S<%d,%d,%d>_%db(%p)", B, M, S, int(8*sizeof(T)), ptr.get()); -+} -+ -+template -+CUTE_HOST std::ostream& operator<<(std::ostream& os, smem_ptr_swizzle> const&) -+{ -+ return os << "smem_ptr_S<" << B << "," << M << "," << S << ">_" << int(8*sizeof(T)) << "b"; -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/tensor.hpp b/3rdparty/cutlass/include/cute/tensor.hpp -new file mode 100644 -index 0000000..e88c22b ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/tensor.hpp -@@ -0,0 +1,900 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+ -+#include -+#include -+#include -+ -+namespace cute -+{ -+ -+// -+// Engine -- owning or non-owning data store -+// -+ -+// concept Engine { -+// using value_type = ; -+// iterator begin(); -+// }; -+ -+template -+using ArrayEngine = typename std::conditional<(sizeof_bits::value % 8 == 0), -+ array_aligned, -+ array_subbyte>::type; -+ -+template -+struct ViewEngine -+{ -+ using value_type = typename cute::remove_cvref())>::type; -+ -+ using iterator = Iterator; -+ iterator storage_; -+ -+ CUTE_HOST_DEVICE constexpr -+ iterator const& -+ begin() const { -+ return storage_; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ iterator& -+ begin() { -+ return storage_; -+ } -+}; -+ -+template -+struct is_rmem> : is_rmem {}; -+template -+struct is_smem> : is_smem {}; -+template -+struct is_gmem> : is_gmem {}; -+template -+struct ConstViewEngine -+{ -+ using value_type = typename cute::remove_cvref())>::type; -+ -+ using iterator = Iterator; -+ iterator storage_; -+ -+ CUTE_HOST_DEVICE constexpr -+ iterator const& -+ begin() const { -+ return storage_; -+ } -+}; -+ -+template -+struct is_rmem> : is_rmem {}; -+template -+struct is_smem> : is_smem {}; -+template -+struct is_gmem> : is_gmem {}; -+// -+// Tensor -+// -+ -+template -+struct Tensor -+{ -+ using value_type = typename Engine::value_type; -+ //using pointer = typename engine_traits::pointer; -+ //using const_pointer = typename engine_traits::const_pointer; -+ //using reference = typename engine_traits::reference; -+ //using const_reference = typename engine_traits::const_reference; -+ -+ using engine_type = Engine; -+ using layout_type = Layout; -+ -+ CUTE_HOST_DEVICE constexpr -+ Tensor() {} -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ Tensor(Ptr const& ptr, Layout const& layout) -+ : rep_(layout, ptr) { -+ } -+ -+ // -+ // Accessors -+ // -+ -+ static constexpr int rank = Layout::rank; -+ -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ tensor() const { -+ return *this; -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ layout() const { -+ return get<0>(rep_); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ engine() const { -+ return get<1>(rep_); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ engine() { -+ return get<1>(rep_); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ data() const { -+ return engine().begin(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ data() { -+ return engine().begin(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ shape() const { -+ return layout().shape(); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ auto -+ size() const { -+ return cute::size(shape()); -+ } -+ -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ stride() const { -+ return layout().stride(); -+ } -+ -+ // -+ // Indexing op() and op[] -+ // -+ -+ // Index into this tensor like an array by computing the offset via layout() -+ template -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ operator[](Coord const& coord) { -+ return data()[layout()(coord)]; -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ operator[](Coord const& coord) const { -+ return data()[layout()(coord)]; -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ operator()(Coord const& coord) { -+ if constexpr (has_underscore::value) { -+ auto const& [sliced_layout,offset] = slice_and_offset(coord, layout()); -+ return make_tensor(data() + offset, sliced_layout); -+ } else { -+ return data()[layout()(coord)]; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ operator()(Coord const& coord) const { -+ if constexpr (has_underscore::value) { -+ auto const& [sliced_layout,offset] = slice_and_offset(coord, layout()); -+ return make_tensor(data() + offset, sliced_layout); -+ } else { -+ return data()[layout()(coord)]; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+ } -+ -+ // op() convenience function for multi-dimensional coordinates -+ template -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) { -+ return operator()(make_coord(c0,c1,cs...)); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ decltype(auto) -+ operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) const { -+ return operator()(make_coord(c0,c1,cs...)); -+ } -+ -+ // -+ // Compose -+ // -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ compose(Layouts const&... layouts) { -+ return make_tensor(data(), layout().compose(layouts...)); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ compose(Layouts const&... layouts) const { -+ return make_tensor(data(), layout().compose(layouts...)); -+ } -+ -+ // -+ // Tile -+ // -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ tile(Layouts const&... layouts) { -+ return make_tensor(data(), layout().tile(layouts...)); -+ } -+ -+ template -+ CUTE_HOST_DEVICE constexpr -+ auto -+ tile(Layouts const&... layouts) const { -+ return make_tensor(data(), layout().tile(layouts...)); -+ } -+ -+ // -+ // Utility -+ // -+ -+ template ::value)> -+ CUTE_HOST_DEVICE constexpr -+ auto -+ get_1d_coord(Int const& linear_idx) const { -+ return layout().get_1d_coord(linear_idx); -+ } -+ -+ template ::value)> -+ CUTE_HOST_DEVICE constexpr -+ auto -+ get_hier_coord(Int const& linear_idx) const { -+ return layout().get_hier_coord(linear_idx); -+ } -+ -+ template ::value)> -+ CUTE_HOST_DEVICE constexpr -+ auto -+ get_flat_coord(Int const& linear_idx) const { -+ return layout().get_flat_coord(linear_idx); -+ } -+ -+ cute::tuple rep_; -+}; -+ -+ -+template -+struct is_tensor : false_type {}; -+template -+struct is_tensor> : true_type {}; -+ -+template -+struct is_rmem> : is_rmem {}; -+template -+struct is_smem> : is_smem {}; -+template -+struct is_gmem> : is_gmem {}; -+// -+// Make an owning Tensor that will allocate a static array -+// -+ -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+make_tensor(Layout const& layout) -+{ -+ static_assert(is_static::value, "Dynamic owning tensors not supported"); -+ using Engine = ArrayEngine>; -+ return Tensor(); -+} -+ -+// e.g. make_tensor(12) -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+make_tensor(LayoutArg const& arg, LayoutArgs const&... args) -+{ -+ return make_tensor(make_layout(arg, args...)); -+} -+ -+// -+// Make a non-owning Tensor that will use a pointer (view) -+// -+ -+template ::value && -+ is_layout::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+make_tensor(Iterator const& iter, Layout const& layout) -+{ -+ using Engine = ViewEngine; -+ return Tensor(iter, layout); -+} -+ -+// e.g. make_tensor(vec.data(), 12) -+template ::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+make_tensor(Iterator const& iter, LayoutArg const& arg, LayoutArgs const&... args) -+{ -+ return make_tensor(iter, make_layout(arg, args...)); -+} -+ -+// -+// make_tensor_like -- make a register tensor the same type and shape as another -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_tensor_like(Tensor const& tensor) -+{ -+ using value_type = typename Tensor::value_type; -+ return make_tensor(tensor.shape()); -+} -+ -+// -+// make_fragment_like -- make a register tensor the same type, shape, and (if possible) order as another tensor -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_fragment_like(Tensor const& tensor) -+{ -+ using value_type = typename Tensor::value_type; -+ return make_tensor(make_layout_like(tensor.layout())); -+} -+ -+// -+// make_identity_tensor -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_identity_tensor(Shape const& shape) -+{ -+ return make_tensor(ArithmeticTupleIterator(as_arithmetic_tuple(repeat_like(shape, Int<0>{}))), -+ make_identity_layout(shape)); -+} -+ -+// -+// Utilities -+// -+ -+// Return the subtensor of a mode -+template >::value)> -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+tensor(Tensor&& tensor) -+{ -+ return std::forward(tensor); -+} -+ -+template >::value)> -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+tensor(Tensor&& tensor) -+{ -+ return make_tensor(std::forward(tensor).data(), get(tensor.layout())); -+} -+ -+// Return the subtensor of a range of modes -+template >::value)> -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+take(Tensor&& tensor) -+{ -+ return make_tensor(std::forward(tensor).data(), take(tensor.layout())); -+} -+ -+// Return the layout of a mode -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+layout(Tensor const& tensor) -+{ -+ return layout(tensor.layout()); -+} -+ -+// Return the shape of a mode -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+shape(Tensor const& tensor) -+{ -+ return shape(tensor.layout()); -+} -+ -+// Return the stride of a mode -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+stride(Tensor const& tensor) -+{ -+ return stride(tensor.layout()); -+} -+ -+// Return the number of elements in a mode -+template -+CUTE_HOST_DEVICE constexpr -+decltype(auto) -+size(Tensor const& tensor) -+{ -+ return size(tensor.layout()); -+} -+ -+// Return the rank of a mode -+template -+CUTE_HOST_DEVICE constexpr -+auto -+rank(Tensor const& tensor) -+{ -+ return rank(tensor.layout()); -+} -+ -+// Return the depth of a mode -+template -+CUTE_HOST_DEVICE constexpr -+auto -+depth(Tensor const& tensor) -+{ -+ return depth(tensor.layout()); -+} -+ -+// -+// Operations to manipulate Tensors like a Layout -+// -+ -+template >::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+flatten(Tensor&& tensor) -+{ -+ return make_tensor(std::forward(tensor).data(), flatten(tensor.layout())); -+} -+ -+template >::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+coalesce(Tensor&& tensor) -+{ -+ return make_tensor(std::forward(tensor).data(), coalesce(tensor.layout())); -+} -+ -+template >::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+coalesce(Tensor&& tensor, Profile const& profile) -+{ -+ return make_tensor(std::forward(tensor).data(), coalesce(tensor.layout(), profile)); -+} -+ -+// Group the modes [B,E) into a single mode -+// e.g. group<2,4>(make_tensor(Layout>{})) -+// => make_tensor(Layout,_5,_6>>{}) -+template >::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+group_modes(Tensor&& tensor) -+{ -+ return make_tensor(std::forward(tensor).data(), -+ group(tensor.layout())); -+} -+ -+// -+// Recast -+// -+ -+// NOTE: This is very dangerous to do -+// -- doesn't check dynamic integer divisibility -+// -- doesn't check alignment -+ -+// A tagged version for dispatching -+template >::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+recast(Tensor&& tensor, type_list) -+{ -+ using OldType = typename remove_cvref_t::value_type; -+ auto old_layout = tensor.layout(); -+ auto new_layout = recast(old_layout); -+ -+ // If this is an upcast of a normal Layout with static negative strides, then offset as well -+ if constexpr (sizeof(OldType) < sizeof(NewType) && not is_composed_layout::value) { -+ auto shape_diff = transform(flatten(old_layout.shape()), flatten(new_layout.shape()), minus{}); -+ auto extent_diff = transform(shape_diff, flatten(old_layout.stride()), multiplies{}); -+ auto offset = fold(extent_diff, Int<0>{}, [](auto const& i, auto const& a) { return i + cute::min(a,Int<0>{}); }); -+ -+ return make_tensor(recast(std::forward(tensor).data() + offset), new_layout); -+ } else { -+ return make_tensor(recast(std::forward(tensor).data() ), new_layout); -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template >::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+recast(Tensor&& tensor) -+{ -+ return recast(std::forward(tensor), type_list{}); -+} -+ -+// -+// max_common_vector -+// -+ -+/* Return Int such that N is the maximum number of continguous elements -+ * that logically correspond in the tensors of @a a and @a b. This is, -+ * the number of elements that could reasonably be vectorized into a single load/store. -+ * -+ * @returns Int with N >= 0 -+ * -+ * A return value of Int<0> indicates that no such conclusion can be made and no -+ * vectorization should be attempted. -+ */ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+max_common_vector(Tensor const& a, -+ Tensor const& b) -+{ -+ using SrcType = typename Tensor::value_type; -+ using DstType = typename Tensor::value_type; -+ -+ using SrcRef = decltype(*(a.data())); -+ using DstRef = decltype(*(b.data())); -+ -+ // Determine if vectorization candidates at all -+ if constexpr (// Should be the same value_types, else the copy is also performing a cast -+ sizeof(SrcType) == sizeof(DstType) && -+ // The types should be trivially copyable so that vectorization is valid -+ std::is_trivially_copyable::value && -+ std::is_trivially_copyable::value && -+ // Should be load/storing real data, rather than implicit iterators or such -+ std::is_reference::value && -+ std::is_reference::value) -+ { -+ return max_common_vector(a.layout(), b.layout()); -+ } else { -+ return Int<0>{}; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// Key algebraic operations -+// -+ -+template >::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+logical_divide(Tensor && tensor, -+ Tile const& tile) -+{ -+ return make_tensor(std::forward(tensor).data(), -+ logical_divide(tensor.layout(), tile)); -+} -+ -+// zipped_divide is logical_divide with modes gathered into standard form ((BLK_A,BLK_B),(a,b)) -+template >::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+zipped_divide(Tensor && tensor, -+ Tile const& tile) // Layout or Tile -+{ -+ return make_tensor(std::forward(tensor).data(), -+ zipped_divide(tensor.layout(), tile)); -+} -+ -+// tiled_divide is logical_divide with the second output mode flattened ((BLK_A,BLK_B),a,b) -+template >::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+tiled_divide(Tensor && tensor, -+ Tile const& tile) // Layout or Tile -+{ -+ return make_tensor(std::forward(tensor).data(), -+ tiled_divide(tensor.layout(), tile)); -+} -+ -+// logical_product on a Tensor doesn't make sense since it often increases cosize -+ -+// -+// Logicial Divide utilities: local_partition and local_tile -+// -+ -+template >::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+local_partition(Tensor && tensor, -+ Tile const& tile, -+ Coord const& coord) -+{ -+ constexpr int R1 = decltype(rank(tensor))::value; -+ -+ // Split the modes of tensor according to the modes of tile -+ // zipped_divide returns something like ((VEC_A,VEC_B,...),(a,b,...)) -+ -+ // The_coord is the coord into the first mode, flatten the rest -+ return zipped_divide(std::forward(tensor), tile)(coord, repeat(_)); -+} -+ -+template >::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+local_partition(Tensor && tensor, -+ Tile const& tile, -+ Coord const& coord, -+ Projection const& proj) -+{ -+ return local_partition(std::forward(tensor), -+ dice(proj, tile), -+ dice(proj, coord)); -+} -+ -+// Special case with Layout and Integral that extracts the coord first -+// e.g. local_partition(tensor, ThrLayout, threadIdx.x) -+template >::value && -+ is_integral::value)> -+CUTE_HOST_DEVICE -+auto -+local_partition(Tensor && tensor, -+ Layout const& tile, -+ Index const& index) -+{ -+ return local_partition(std::forward(tensor), -+ product_each(shape(tile)), -+ tile.get_flat_coord(index)); -+} -+ -+// Special case with Layout and Integral that extracts the coord first -+// e.g. local_partition(tensor, ThrLayout, threadIdx.x, Step<_1,X,_1>{}) -+template >::value && -+ is_integral::value)> -+CUTE_HOST_DEVICE -+auto -+local_partition(Tensor && tensor, -+ Layout const& tile, -+ Index const& index, -+ Projection const& proj) -+{ -+ return local_partition(std::forward(tensor), -+ dice(proj, product_each(shape(tile))), -+ dice(proj, tile).get_flat_coord(index)); -+} -+ -+template >::value)> -+CUTE_HOST_DEVICE constexpr -+auto -+local_tile(Tensor && tensor, -+ Tile const& tile, -+ Coord const& coord) -+{ -+ constexpr int R0 = decltype(rank(tile))::value; -+ constexpr int R1 = decltype(rank(tensor))::value; -+ -+ // Split the modes of tensor according to the modes of tile -+ // zipped_divide returns something like ((VEC_A,VEC_B,...),(a,b,...)) -+ -+ // The padded_coord is the coord into the second mode, flatten the rest -+ return zipped_divide(std::forward(tensor), tile)(repeat(_), append(coord,_)); -+} -+ -+template >::value)> -+CUTE_HOST_DEVICE -+auto -+local_tile(Tensor && tensor, -+ Tile const& tile, -+ Coord const& coord, -+ Proj const& proj) -+{ -+ return local_tile(std::forward(tensor), -+ dice(proj, tile), -+ dice(proj, coord)); -+} -+ -+// -+// Display utilities -+// -+ -+template -+CUTE_HOST_DEVICE void print_tensor(Tensor const& tensor) -+{ -+ auto format = get_format(tensor(0)); -+ using type = typename decltype(format)::type; -+ -+ if constexpr (Layout::rank == 1) -+ { -+ for (int m = 0; m < size(tensor); ++m) { -+ printf(format.format, format.digits, type(tensor(m))); -+ printf("\n"); -+ } -+ } else -+ if constexpr (Layout::rank == 2) -+ { -+ for (int m = 0; m < size<0>(tensor); ++m) { -+ for (int n = 0; n < size<1>(tensor); ++n) { -+ printf(format.format, format.digits, type(tensor(m,n))); -+ } -+ printf("\n"); -+ } -+ } else -+ if constexpr (Layout::rank == 3) -+ { -+ print_tensor(tensor(_,_,0)); -+ for (int k = 1; k < size<2>(tensor); ++k) { -+ for (int i = 0; i < format.digits*size<1>(tensor); ++i) { print("-"); } print("\n"); -+ print_tensor(tensor(_,_,k)); -+ } -+ } else -+ if constexpr (Layout::rank == 4) -+ { -+ print_tensor(tensor(_,_,_,0)); -+ for (int p = 1; p < size<3>(tensor); ++p) { -+ for (int i = 0; i < format.digits*size<1>(tensor); ++i) { print("="); } print("\n"); -+ print_tensor(tensor(_,_,_,p)); -+ } -+ } -+} -+ -+template -+CUTE_HOST_DEVICE void print(Tensor const& tensor) -+{ -+ print(tensor.layout()); print("\n"); -+ print_tensor(tensor); -+} -+ -+template -+CUTE_HOST std::ostream& print_tensor_os(std::ostream& os, Tensor const& tensor) -+{ -+ int digits = 9; -+ -+ if constexpr (Layout::rank == 1) -+ { -+ for (int m = 0; m < size(tensor); ++m) { -+ os << std::setw(digits) << tensor(m) << std::endl; -+ } -+ } else -+ if constexpr (Layout::rank == 2) -+ { -+ for (int m = 0; m < size<0>(tensor); ++m) { -+ for (int n = 0; n < size<1>(tensor); ++n) { -+ os << std::setw(digits) << tensor(m,n); -+ } -+ os << std::endl; -+ } -+ } else -+ if constexpr (Layout::rank == 3) -+ { -+ print_tensor_os(os, tensor(_,_,0)); -+ for (int k = 1; k < size<2>(tensor); ++k) { -+ for (int i = 0; i < digits*size<1>(tensor); ++i) { os << "-"; } os << std::endl; -+ print_tensor_os(os, tensor(_,_,k)); -+ } -+ } else -+ if constexpr (Layout::rank == 4) -+ { -+ print_tensor_os(os, tensor(_,_,_,0)); -+ for (int p = 1; p < size<3>(tensor); ++p) { -+ for (int i = 0; i < digits*size<1>(tensor); ++i) { os << "="; } os << std::endl; -+ print_tensor_os(os, tensor(_,_,_,p)); -+ } -+ } -+ -+ return os; -+} -+ -+template -+CUTE_HOST std::ostream& operator<<(std::ostream& os, Tensor const& tensor) -+{ -+ os << tensor.layout() << std::endl; -+ return print_tensor_os(os, tensor); -+} -+ -+} // end namespace cute -+ -+// -+// Extended Engines -+// -+ -+#include -+ -+// -+// Tensor Algorithms -+// -+ -+#include -+#include -+#include -+#include -+#include -+#include -diff --git a/3rdparty/cutlass/include/cute/tensor_predicate.hpp b/3rdparty/cutlass/include/cute/tensor_predicate.hpp -new file mode 100644 -index 0000000..730f219 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/tensor_predicate.hpp -@@ -0,0 +1,63 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+namespace cute -+{ -+ -+template -+struct ConstantTensor -+{ -+ template -+ CUTE_HOST_DEVICE constexpr -+ T const& -+ operator()(Coords const&...) const { -+ return val_; -+ } -+ -+ T val_; -+}; -+ -+struct TrivialPredTensor -+{ -+ template -+ CUTE_HOST_DEVICE constexpr -+ true_type -+ operator()(Coords const&...) const { -+ return {}; -+ } -+}; -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/tile.hpp b/3rdparty/cutlass/include/cute/tile.hpp -new file mode 100644 -index 0000000..b2fa2e8 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/tile.hpp -@@ -0,0 +1,58 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+namespace cute -+{ -+ -+// -+// A Tile is not a Layout, it's a tuple of Layouts or Tiles or Underscores -+// -+ -+template -+using Tile = tuple; -+ -+template -+using is_tile = is_tuple; -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+make_tile(Layouts const&... layouts) -+{ -+ return Tile(layouts...); -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/underscore.hpp b/3rdparty/cutlass/include/cute/underscore.hpp -new file mode 100644 -index 0000000..d79b4ee ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/underscore.hpp -@@ -0,0 +1,148 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+#include -+#include -+#include -+ -+namespace cute -+{ -+ -+// For slicing -+struct Underscore : Int<0> {}; -+ -+CUTE_INLINE_CONSTANT Underscore _; -+ -+// Treat Underscore as an integral like integral_constant -+template <> -+struct is_integral : true_type {}; -+ -+template -+struct is_underscore : false_type {}; -+template <> -+struct is_underscore : true_type {}; -+ -+// Tuple trait for detecting static member element -+template -+struct has_elem : false_type {}; -+template -+struct has_elem : true_type {}; -+template -+struct has_elem::value> > -+ : has_elem > {}; -+template -+struct has_elem> -+ : disjunction, Elem>...> {}; -+ -+// Tuple trait for detecting static member element -+template -+struct all_elem : false_type {}; -+template -+struct all_elem : true_type {}; -+template -+struct all_elem::value> > -+ : all_elem > {}; -+template -+struct all_elem> -+ : conjunction, Elem>...> {}; -+ -+// Tuple trait for detecting Underscore member -+template -+using has_underscore = has_elem; -+ -+template -+using all_underscore = all_elem; -+ -+template -+using has_int1 = has_elem>; -+ -+template -+using has_int0 = has_elem>; -+ -+// -+// Slice keeps only the elements of Tuple B that are paired with an Underscore -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+slice(A const& a, B const& b) -+{ -+ if constexpr (is_tuple::value) { -+ static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); -+ return filter_tuple(a, b, [](auto const& x, auto const& y) { return slice(x,y); }); -+ } else if constexpr (is_underscore::value) { -+ return cute::tuple{b}; -+ } else { -+ return cute::tuple<>{}; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// Dice keeps only the elements of Tuple B that are paired with an Int -+// -+ -+template -+CUTE_HOST_DEVICE constexpr -+auto -+dice(A const& a, B const& b) -+{ -+ if constexpr (is_tuple::value) { -+ static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); -+ return filter_tuple(a, b, [](auto const& x, auto const& y) { return dice(x,y); }); -+ } else if constexpr (is_underscore::value) { -+ return cute::tuple<>{}; -+ } else { -+ return cute::tuple{b}; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// -+// Display utilities -+// -+ -+CUTE_HOST_DEVICE void print(Underscore const&) { -+ printf("_"); -+} -+ -+CUTE_HOST std::ostream& operator<<(std::ostream& os, Underscore const&) { -+ return os << "_"; -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/util/debug.hpp b/3rdparty/cutlass/include/cute/util/debug.hpp -new file mode 100644 -index 0000000..9a62143 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/util/debug.hpp -@@ -0,0 +1,153 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+/** -+ * \file -+ * \brief Debugging and logging functionality -+ */ -+ -+#include -+ -+#include -+ -+namespace cute -+{ -+ -+/****************************************************************************** -+ * Debug and logging macros -+ ******************************************************************************/ -+ -+/** -+ * Formats and prints the given message to stdout -+ */ -+#if !defined(CUTE_LOG) -+# if !defined(__CUDA_ARCH__) -+# define CUTE_LOG(format, ...) printf(format, __VA_ARGS__) -+# else -+# define CUTE_LOG(format, ...) \ -+ printf("[block (%d,%d,%d), thread (%d,%d,%d)]: " format, \ -+ blockIdx.x, blockIdx.y, blockIdx.z, \ -+ threadIdx.x, threadIdx.y, threadIdx.z, \ -+ __VA_ARGS__); -+# endif -+#endif -+ -+/** -+ * Formats and prints the given message to stdout only if DEBUG is defined -+ */ -+#if !defined(CUTE_LOG_DEBUG) -+# ifdef DEBUG -+# define CUTE_LOG_DEBUG(format, ...) CUTE_LOG(format, __VA_ARGS__) -+# else -+# define CUTE_LOG_DEBUG(format, ...) -+# endif -+#endif -+ -+/** -+ * \brief Perror macro with exit -+ */ -+#if !defined(CUTE_ERROR_EXIT) -+# define CUTE_ERROR_EXIT(e) \ -+ do { \ -+ cudaError_t code = (e); \ -+ if (code != cudaSuccess) { \ -+ fprintf(stderr, "<%s:%d> %s:\n %s: %s\n", \ -+ __FILE__, __LINE__, #e, \ -+ cudaGetErrorName(code), cudaGetErrorString(code)); \ -+ fflush(stderr); \ -+ exit(0); \ -+ } \ -+ } while (0) -+#endif -+ -+#if !defined(CUTE_CHECK_LAST) -+# define CUTE_CHECK_LAST() CUTE_ERROR_EXIT(cudaPeekAtLastError()); CUTE_ERROR_EXIT(cudaDeviceSynchronize()) -+#endif -+ -+#if !defined(CUTE_CHECK_ERROR) -+# define CUTE_CHECK_ERROR(e) CUTE_ERROR_EXIT(e) -+#endif -+ -+// A dummy function that uses compilation failure to print a type -+template -+CUTE_HOST_DEVICE -+void -+print_type(T&&) { -+ static_assert(sizeof(T) < 0, "Printing type T."); -+} -+ -+// -+// Device-specific helpers -+// -+// e.g. -+// if (thread0()) print(...); -+// if (block0()) print(...); -+// if (thread(42)) print(...); -+ -+CUTE_HOST_DEVICE -+bool -+thread(int tid, int bid) -+{ -+#if defined(__CUDA_ARCH__) -+ return (threadIdx.x + threadIdx.y*blockDim.x + threadIdx.z*blockDim.x*blockDim.y == tid) -+ && ( blockIdx.x + blockIdx.y* gridDim.x + blockIdx.z* gridDim.x* gridDim.y == bid); -+#else -+ return true; -+#endif -+} -+ -+CUTE_HOST_DEVICE -+bool -+thread(int tid) -+{ -+ return thread(tid, 0); -+} -+ -+CUTE_HOST_DEVICE -+bool -+thread0() -+{ -+ return thread(0,0); -+} -+ -+CUTE_HOST_DEVICE -+bool -+block0() -+{ -+#if defined(__CUDA_ARCH__) -+ return !(blockIdx.x | blockIdx.y | blockIdx.z); -+#else -+ return true; -+#endif -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/util/print.hpp b/3rdparty/cutlass/include/cute/util/print.hpp -new file mode 100644 -index 0000000..ec774b0 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/util/print.hpp -@@ -0,0 +1,140 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+// -+// CUDA compatible print and printf -+// -+ -+namespace cute -+{ -+ -+CUTE_HOST_DEVICE -+int -+num_digits(int x) -+{ -+ return (x < 10 ? 1 : -+ (x < 100 ? 2 : -+ (x < 1000 ? 3 : -+ (x < 10000 ? 4 : -+ (x < 100000 ? 5 : -+ (x < 1000000 ? 6 : -+ (x < 10000000 ? 7 : -+ (x < 100000000 ? 8 : -+ (x < 1000000000 ? 9 : -+ 10))))))))); -+} -+ -+template -+struct format_and_size { -+ using type = T; -+ char const* format; -+ int digits; -+}; -+ -+CUTE_HOST_DEVICE -+format_and_size -+get_format(bool) { -+ return {"%*d", 3}; -+} -+ -+CUTE_HOST_DEVICE -+format_and_size -+get_format(int32_t) { -+ return {"%*d", 5}; -+} -+ -+CUTE_HOST_DEVICE -+format_and_size -+get_format(uint32_t) { -+ return {"%*d", 5}; -+} -+ -+CUTE_HOST_DEVICE -+format_and_size -+get_format(int64_t) { -+ return {"%*d", 5}; -+} -+ -+CUTE_HOST_DEVICE -+format_and_size -+get_format(uint64_t) { -+ return {"%*d", 5}; -+} -+ -+CUTE_HOST_DEVICE -+format_and_size -+get_format(half_t) { -+ return {"%*.2f", 8}; -+} -+ -+CUTE_HOST_DEVICE -+format_and_size -+get_format(float) { -+ return {"%*.2e", 10}; -+} -+ -+CUTE_HOST_DEVICE -+format_and_size -+get_format(double) { -+ return {"%*.3e", 11}; -+} -+ -+// -+// print dispatcher -+// -+ -+CUTE_HOST_DEVICE -+void -+print(char const& c) { -+ printf("%c", c); -+} -+ -+template ::value)> -+CUTE_HOST_DEVICE -+void -+print(T const& a) { -+ printf("%d", int(a)); -+} -+ -+template -+CUTE_HOST_DEVICE -+void -+print(char const* format, T const&... t) { -+ printf(format, t...); -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cute/util/type_traits.hpp b/3rdparty/cutlass/include/cute/util/type_traits.hpp -new file mode 100644 -index 0000000..4d37eb9 ---- /dev/null -+++ b/3rdparty/cutlass/include/cute/util/type_traits.hpp -@@ -0,0 +1,101 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+#define __CUTE_REQUIRES(...) typename std::enable_if<(__VA_ARGS__)>::type* = nullptr -+#define __CUTE_REQUIRES_V(...) typename std::enable_if::type* = nullptr -+ -+namespace cute -+{ -+ -+using std::conjunction; -+using std::conjunction_v; -+ -+using std::disjunction; -+using std::disjunction_v; -+ -+using std::negation; -+using std::negation_v; -+ -+using std::void_t; -+ -+// C++20 -+// using std::remove_cvref; -+template -+struct remove_cvref { -+ using type = std::remove_cv_t>; -+}; -+ -+// C++20 -+// using std::remove_cvref_t; -+template -+using remove_cvref_t = typename remove_cvref::type; -+ -+// -+// is_valid -+// -+ -+namespace detail { -+ -+template ()(std::declval()...))> -+CUTE_HOST_DEVICE constexpr auto -+is_valid_impl(int) { return std::true_type{}; } -+ -+template -+CUTE_HOST_DEVICE constexpr auto -+is_valid_impl(...) { return std::false_type{}; } -+ -+template -+struct is_valid_fn { -+ template -+ CUTE_HOST_DEVICE constexpr auto -+ operator()(Args&&...) const { return is_valid_impl(int{}); } -+}; -+ -+} // end namespace detail -+ -+template -+CUTE_HOST_DEVICE constexpr auto -+is_valid(F&&) { -+ return detail::is_valid_fn{}; -+} -+ -+template -+CUTE_HOST_DEVICE constexpr auto -+is_valid(F&&, Args&&...) { -+ return detail::is_valid_impl(int{}); -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/include/cutlass/aligned_buffer.h b/3rdparty/cutlass/include/cutlass/aligned_buffer.h -new file mode 100644 -index 0000000..1b29277 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/aligned_buffer.h -@@ -0,0 +1,129 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief AlignedBuffer is a container for trivially copyable elements suitable for use in -+ unions and shared memory. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Modifies semantics of cutlass::Array<> to provide guaranteed alignment. -+template < -+ typename T, -+ int N, -+ int Align = 16 -+> -+struct AlignedBuffer { -+ -+ /// Internal storage type -+ using Storage = uint8_t; -+ -+ /// Number of logical elements held in buffer -+ static int const kCount = N; -+ -+ /// Alignment requirement in bytes -+ static int const kAlign = Align; -+ -+ /// Number of storage elements -+ static int const kBytes = -+ (sizeof_bits::value * N + 7) / 8; -+ -+private: -+ -+ /// Internal storage -+ alignas(Align) Storage storage[kBytes]; -+ -+public: -+ -+ // -+ // C++ standard members -+ // -+ -+ typedef T value_type; -+ typedef size_t size_type; -+ typedef ptrdiff_t difference_type; -+ typedef value_type *pointer; -+ typedef value_type const * const_pointer; -+ -+ using Array = Array; -+ using reference = typename Array::reference; -+ using const_reference = typename Array::const_reference; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ pointer data() { -+ return reinterpret_cast(storage); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_pointer data() const { -+ return reinterpret_cast(storage); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Storage * raw_data() { -+ return storage; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Storage const * raw_data() const { -+ return storage; -+ } -+ -+ -+ CUTLASS_HOST_DEVICE -+ constexpr bool empty() const { -+ return !kCount; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ constexpr size_type size() const { -+ return kCount; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ constexpr size_type max_size() const { -+ return kCount; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/include/cutlass/arch/arch.h b/3rdparty/cutlass/include/cutlass/arch/arch.h -new file mode 100644 -index 0000000..043bfac ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/arch.h -@@ -0,0 +1,107 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines tags for architecture-specific configurations. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace arch { -+ -+#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) -+ -+/// Computes laneId within a warp -+CUTLASS_DEVICE -+int LaneId() { -+ int ret; -+ asm ("mov.u32 %0, %%laneid;" : "=r"(ret) : ); -+ return ret; -+} -+ -+/// Computes SM number the thread is running on -+CUTLASS_DEVICE -+int SmId() { -+ int ret; -+ asm ("mov.u32 %0, %%smid;" : "=r"(ret) : ); -+ return ret; -+} -+ -+#endif -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+struct Sm50 { -+ static int const kMinComputeCapability = 50; -+}; -+struct Sm60 { -+ static int const kMinComputeCapability = 60; -+}; -+struct Sm61 { -+ static int const kMinComputeCapability = 61; -+}; -+struct Sm70 { -+ static int const kMinComputeCapability = 70; -+}; -+struct Sm72 { -+ static int const kMinComputeCapability = 72; -+}; -+struct Sm75 { -+ static int const kMinComputeCapability = 75; -+}; -+struct Sm80 { -+ static int const kMinComputeCapability = 80; -+}; -+struct Sm86 { -+ static int const kMinComputeCapability = 86; -+}; -+ -+struct Sm90 { -+ static int const kMinComputeCapability = 90; -+}; -+ -+/// Triggers a breakpoint on the device -+CUTLASS_DEVICE -+void device_breakpoint() { -+#if defined(__CUDA_ARCH__) -+ asm volatile (" brkpt;\n"); -+#endif -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace arch -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/arch/barrier.h b/3rdparty/cutlass/include/cutlass/arch/barrier.h -new file mode 100644 -index 0000000..34f0b4e ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/barrier.h -@@ -0,0 +1,404 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2011-2019, NVIDIA CORPORATION. All rights reserved. -+ * -+ * Redistribution and use in source and binary forms, with or without modification, are not permit- -+ * ted. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR -+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND -+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, -+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; -+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Barrier Operations on SM90+ -+*/ -+ -+#pragma once -+ -+#include -+#include -+ -+namespace cutlass { -+/// @brief -+namespace arch { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && (__CUDACC_VER_MAJOR__ >= 12) -+#define CUDA_BARRIER_ENABLED 1 -+#else -+#define CUDA_BARRIER_ENABLED 0 -+#endif -+ -+class NamedBarrier { -+ -+ // Data Members: -+ -+ // Range = [1 , NUM_THREADS_PER_CTA] -+ // Range % warp-size (i.e 32) == 0 -+ uint32_t const num_threads_; -+ -+ // Range : [0, 15] -+ uint32_t const id_; -+ -+ public: -+ -+ CUTLASS_DEVICE -+ NamedBarrier(uint32_t num_threads, uint32_t id = 0) -+ : num_threads_(num_threads), id_(id) {} -+ -+ CUTLASS_DEVICE -+ void arrive_and_wait() const { -+ NamedBarrier::arrive_and_wait(num_threads_, id_); -+ } -+ -+ CUTLASS_DEVICE -+ void arrive() const { -+ NamedBarrier::arrive(num_threads_, id_); -+ } -+ -+ CUTLASS_DEVICE -+ void sync() const { -+ NamedBarrier::arrive_and_wait(); -+ } -+ -+ // Static variants -+ CUTLASS_DEVICE -+ static void arrive_and_wait(uint32_t num_threads, uint32_t barrier_id) { -+#if CUDA_BARRIER_ENABLED -+ asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads)); -+#else -+ asm volatile ("brkpt;\n" ::); -+#endif -+ } -+ -+ CUTLASS_DEVICE -+ static void arrive(uint32_t num_threads, uint32_t barrier_id) { -+#if CUDA_BARRIER_ENABLED -+ asm volatile("bar.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads)); -+#else -+ asm volatile ("brkpt;\n" ::); -+#endif -+ } -+ -+ CUTLASS_DEVICE -+ static void sync(uint32_t num_threads, uint32_t barrier_id) { -+ NamedBarrier::arrive_and_wait(num_threads, barrier_id); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Hopper introduces a new cluster-wide barrier which handle with Cluster-wide AW behaviour. -+// This is an extension to the Ampere AW barriers -+// Note : Ampere AW Barriers have a larger max-arrive count (2^30) than Hopper AW Barriers (2^20). -+struct ClusterBarrier { -+ -+ using ValueType = uint64_t; -+ -+protected: -+ // Can never be initializated - can only be aliased to smem -+ ValueType barrier_; -+ -+public: -+ -+ CUTLASS_DEVICE -+ ClusterBarrier() = delete; -+ -+ CUTLASS_DEVICE -+ void init(uint32_t arrive_count) const { -+ ClusterBarrier::init(&this->barrier_, arrive_count); -+ } -+ -+ CUTLASS_DEVICE -+ uint32_t test_wait(uint32_t phase, uint32_t pred=true) const { -+ return ClusterBarrier::test_wait(&this->barrier_, phase, pred); -+ } -+ -+ CUTLASS_DEVICE -+ void wait(uint32_t phase) const { -+ ClusterBarrier::wait(&this->barrier_, phase); -+ } -+ -+ // Barrier arrive on local smem -+ CUTLASS_DEVICE -+ void arrive() const { -+ ClusterBarrier::arrive(&this->barrier_); -+ } -+ -+ // Remote SMEM arrive with a perdicate (usually done to pick the thread doing the arrive) -+ CUTLASS_DEVICE -+ void arrive(uint32_t cta_id, uint32_t pred = true ) const { -+ ClusterBarrier::arrive(&this->barrier_, cta_id, pred); -+ } -+ -+ // -+ // Static Versions -+ // -+ CUTLASS_DEVICE -+ static void init(ValueType const* smem_ptr, uint32_t arrive_count) { -+#if CUDA_BARRIER_ENABLED -+ uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile( -+ "{\n\t" -+ "mbarrier.init.shared.b64 [%1], %0; \n" -+ "}" -+ : -+ : "r"(arrive_count), "r"(smem_addr)); -+#else -+ asm volatile ("brkpt;\n" ::); -+#endif -+ } -+ -+ // Static version of wait - in case we don't want to burn a register -+ CUTLASS_DEVICE -+ static void wait(ValueType const* smem_ptr, uint32_t phase) { -+#if CUDA_BARRIER_ENABLED -+ uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); -+ // Arbitrarily large timer value after which try-wait expires and re-tries. -+ uint32_t ticks = 0x989680; -+ asm volatile( -+ "{\n\t" -+ ".reg .pred P1; \n\t" -+ "LAB_WAIT: \n\t" -+ "mbarrier.try_wait.parity.shared.b64 P1, [%0], %1, %2; \n\t" -+ "@P1 bra.uni DONE; \n\t" -+ "bra.uni LAB_WAIT; \n\t" -+ "DONE: \n\t" -+ "}" -+ : -+ : "r"(smem_addr), "r"(phase), "r"(ticks)); -+ -+#else -+ asm volatile ("brkpt;\n" ::); -+#endif -+ } -+ -+ CUTLASS_DEVICE -+ static uint32_t test_wait(ValueType const* smem_ptr, uint32_t phase, uint32_t pred) { -+#if CUDA_BARRIER_ENABLED -+ uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); -+ uint32_t waitComplete; -+ -+ asm volatile( -+ "{\n\t" -+ ".reg .pred P1; \n\t" -+ ".reg .pred P2; \n\t" -+ "setp.eq.u32 P2, %3, 1;\n\t" -+ "@P2 mbarrier.test_wait.parity.shared.b64 P1, [%1], %2; \n\t" -+ "selp.b32 %0, 1, 0, P1; \n\t" -+ "}" -+ : "=r"(waitComplete) -+ : "r"(smem_addr), "r"(phase), "r"(pred)); -+ -+ return waitComplete; -+#else -+ asm volatile ("brkpt;\n" ::); -+#endif -+ return 0; -+ } -+ -+ // Static Predicated version of the above - in case we know the address. -+ CUTLASS_DEVICE -+ static void arrive(ValueType const* smem_ptr, uint32_t cta_id, uint32_t pred) { -+#if CUDA_BARRIER_ENABLED -+ uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile( -+ "{\n\t" -+ ".reg .pred p;\n\t" -+ ".reg .b32 remAddr32;\n\t" -+ "setp.eq.u32 p, %2, 1;\n\t" -+ "@p mapa.shared::cluster.u32 remAddr32, %0, %1;\n\t" -+ "@p mbarrier.arrive.shared::cluster.b64 _, [remAddr32];\n\t" -+ "}" -+ : -+ : "r"(smem_addr), "r"(cta_id), "r"(pred)); -+#else -+ asm volatile ("brkpt;\n" ::); -+#endif -+ } -+ -+ // Barrier arrive on local smem -+ CUTLASS_DEVICE -+ static void arrive(ValueType const* smem_ptr) { -+#if CUDA_BARRIER_ENABLED -+ uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); -+ uint64_t state = 0; -+ asm volatile( -+ "{\n\t" -+ "mbarrier.arrive.shared.b64 %1, [%0];\n\t" -+ "}" -+ : -+ : "r"(smem_addr), "l"(state)); -+#else -+ asm volatile ("brkpt;\n" ::); -+#endif -+ } -+ -+ CUTLASS_DEVICE -+ static void invalidate(ValueType const* smem_ptr) { -+#if CUDA_BARRIER_ENABLED -+ uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile( -+ "{\n\t" -+ "mbarrier.ival.shared.b64 [%0]; \n\t" -+ "}" -+ : -+ : "r"(smem_addr)); -+#else -+ asm volatile ("brkpt;\n" ::); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// SM90 also introduces a new type of cluster-barrier which supports sync. -+// not just based on Arrive Count, but also transaction count (in bytes) -+struct ClusterTransactionBarrier : public ClusterBarrier { -+ -+ CUTLASS_DEVICE -+ ClusterTransactionBarrier() = delete; -+ -+ // Performs an arrive operation + bytes reset -+ CUTLASS_DEVICE -+ void arrive_and_reset_bytes(uint32_t transaction_bytes) const { -+ ClusterTransactionBarrier::arrive_and_reset_bytes(&this->barrier_, transaction_bytes); -+ } -+ -+ // Performs an arrive operation + bytes reset -+ CUTLASS_DEVICE -+ void arrive_and_reset_bytes(uint32_t transaction_bytes, uint32_t cta_id) const { -+ ClusterTransactionBarrier::arrive_and_reset_bytes(&this->barrier_, transaction_bytes , cta_id, true); -+ } -+ -+ CUTLASS_DEVICE -+ void commit(uint32_t transaction_bytes, uint32_t pred = 1) const { -+ uint32_t cta_rank = cute::block_rank_in_cluster(); -+ ClusterTransactionBarrier::commit(&this->barrier_, cta_rank, transaction_bytes, pred); -+ } -+ -+ CUTLASS_DEVICE -+ void commit(uint32_t dst_cta_id, uint32_t transaction_bytes, uint32_t pred) const { -+ ClusterTransactionBarrier::commit(&this->barrier_, dst_cta_id, transaction_bytes, pred); -+ } -+ -+ // -+ // Static Versions -+ // -+ -+ // Performs an arrive operation + bytes reset -+ CUTLASS_DEVICE -+ static void arrive_and_reset_bytes(ValueType const* smem_ptr, uint32_t transaction_bytes) { -+#if CUDA_BARRIER_ENABLED -+ uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile( -+ "{\n\t" -+ "mbarrier.arrive.expect_tx.shared.b64 _, [%1], %0; \n\t" -+ "}" -+ : -+ : "r"(transaction_bytes), "r"(smem_addr)); -+#else -+ asm volatile ("brkpt;\n" ::); -+#endif -+ } -+ -+ // Performs an arrive operation + bytes reset for a remote cta_id in a Cluster -+ CUTLASS_DEVICE -+ static void arrive_and_reset_bytes( -+ ValueType const* smem_ptr, uint32_t transaction_bytes, uint32_t cta_id, uint32_t pred) { -+#if CUDA_BARRIER_ENABLED -+ uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile( -+ "{\n\t" -+ ".reg .pred p;\n\t" -+ ".reg .b32 remAddr32;\n\t" -+ "setp.eq.u32 p, %2, 1;\n\t" -+ "@p mapa.shared::cluster.u32 remAddr32, %0, %1;\n\t" -+ "@p mbarrier.arrive.expect_tx.shared::cluster.b64 _, [remAddr32], %3;\n\t" -+ "}" -+ : -+ : "r"(smem_addr), "r"(cta_id), "r"(pred), "r"(transaction_bytes)); -+#else -+ asm volatile ("brkpt;\n" ::); -+#endif -+ } -+ -+ // Performs an bytes reset without doing an arrive operation -+ CUTLASS_DEVICE -+ static void reset_bytes(ValueType const* smem_ptr, uint32_t transaction_bytes) { -+#if CUDA_BARRIER_ENABLED -+ uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); -+ asm volatile( -+ "{\n\t" -+ "mbarrier.expect_tx.shared.b64 [%1], %0; \n\t" -+ "}" -+ : -+ : "r"(transaction_bytes), "r"(smem_addr)); -+#else -+ asm volatile ("brkpt;\n" ::); -+#endif -+ } -+ -+ // Increments transaction bytes in the barrier -+ CUTLASS_DEVICE -+ static void commit( -+ ValueType const* smem_ptr, uint32_t dst_cta_id, uint32_t transaction_bytes, uint32_t pred = 1) { -+#if CUDA_BARRIER_ENABLED -+ uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); -+ smem_addr = cute::set_block_rank(smem_addr, dst_cta_id); -+ asm volatile( -+ "{\n\t" -+ ".reg .pred p;\n\t" -+ "setp.eq.u32 p, %2, 1;\n\t" -+ "@p mbarrier.complete_tx.shared::cluster.relaxed.cluster.b64 [%1], %0;" -+ "}" -+ : -+ : "r"(transaction_bytes), "r"(smem_addr), "r"(pred)); -+#else -+ asm volatile ("brkpt;\n" ::); -+#endif -+ } -+}; -+ -+// Helps with visibility of barrier init operations across warps / cta / cluster -+// Available as a separate function so as to batch inits across barriers and fence once -+// Note : It must be composed with an appropriate sync instruction with the right scope -+// to ensure visibility eg. __syncthreads() or a cluster_arrive() + cluster_wait() -+CUTLASS_DEVICE -+void fence_barrier_init() { -+#if CUDA_BARRIER_ENABLED -+ asm volatile( -+ "{\n\t" -+ "fence.mbarrier_init.release.cluster; \n" -+ "}" -+ ::); -+#else -+ asm volatile ("brkpt;\n" ::); -+#endif -+} -+ -+// Issue a shared memory fence for async operations -+CUTLASS_DEVICE -+void fence_view_async_shared() { -+#if CUDA_BARRIER_ENABLED -+ asm volatile ( -+ "{\n\t" -+ "fence.proxy.async.shared::cta; \n" -+ "}" -+ ::); -+#else -+ asm volatile ("brkpt;\n" ::); -+#endif -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+} // end namespace arch -+} // end namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/arch/cache_operation.h b/3rdparty/cutlass/include/cutlass/arch/cache_operation.h -new file mode 100644 -index 0000000..fa70c4c ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/cache_operation.h -@@ -0,0 +1,66 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Directives related to cache operations -+*/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+namespace cutlass { -+namespace arch { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Controls PTX cache operations -+struct CacheOperation { -+ enum Kind { -+ /// Cache at all levels - accessed again -+ Always, -+ /// Cache at global level -+ Global, -+ /// Streaming - likely to be accessed once -+ Streaming, -+ /// Indicates the line will not be used again -+ LastUse, -+ /// Don't cache, and fetch again -+ Volatile, -+ /// Write back at all coherent levels -+ WriteBack, -+ /// Write through to system memory -+ WriteThrough -+ }; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace arch -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/arch/memory.h b/3rdparty/cutlass/include/cutlass/arch/memory.h -new file mode 100644 -index 0000000..b2a9468 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/memory.h -@@ -0,0 +1,474 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Architecture-specific operators on memory -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+namespace cutlass { -+namespace arch { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Fragment type to store loaded data -+ typename AccessType, -+ /// The bytes of loading -+ int LoadBytes -+ > -+struct global_load; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Specializations -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \ -+ (__CUDACC_VER_MAJOR__ > 11)) && \ -+ defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && \ -+ ! (defined(__clang__) && defined(__CUDA__)) -+ #define CUTLASS_ENABLE_L2_PREFETCH 1 -+#else -+ #define CUTLASS_ENABLE_L2_PREFETCH 0 -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// The redundant mov PTX instruction is used to enforce the compiler to -+// keep the initializing code before ld.global -+template -+struct global_load { -+ CUTLASS_DEVICE -+ global_load(AccessType &D, void const *ptr, bool pred_guard) { -+ uint4 *data = reinterpret_cast(&D); -+ -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " setp.ne.b32 p, %9, 0;\n" -+ " mov.b32 %0, %10;\n" -+ " mov.b32 %1, %11;\n" -+ " mov.b32 %2, %12;\n" -+ " mov.b32 %3, %13;\n" -+ " mov.b32 %4, %14;\n" -+ " mov.b32 %5, %15;\n" -+ " mov.b32 %6, %16;\n" -+ " mov.b32 %7, %17;\n" -+#if CUTLASS_ENABLE_L2_PREFETCH -+ " @p ld.global.L2::128B.v4.u32 {%0, %1, %2, %3}, [%8];\n" -+ " @p ld.global.L2::128B.v4.u32 {%4, %5, %6, %7}, [%18];\n" -+#else -+ " @p ld.global.v4.u32 {%0, %1, %2, %3}, [%8];\n" -+ " @p ld.global.v4.u32 {%4, %5, %6, %7}, [%18];\n" -+#endif -+ "}\n" -+ : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w), -+ "=r"(data[1].x), "=r"(data[1].y), "=r"(data[1].z), "=r"(data[1].w) -+ : "l"(ptr), "r"((int)pred_guard), "r"(data[0].x), "r"(data[0].y), -+ "r"(data[0].z), "r"(data[0].w), "r"(data[1].x), "r"(data[1].y), -+ "r"(data[1].z), "r"(data[1].w), "l"(((uint8_t *)ptr) + 16)); -+ } -+}; -+ -+template -+struct global_load { -+ CUTLASS_DEVICE -+ global_load(AccessType &D, void const *ptr, bool pred_guard) { -+ uint4 &data = reinterpret_cast(D); -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " setp.ne.b32 p, %5, 0;\n" -+ " mov.b32 %0, %6;\n" -+ " mov.b32 %1, %7;\n" -+ " mov.b32 %2, %8;\n" -+ " mov.b32 %3, %9;\n" -+#if CUTLASS_ENABLE_L2_PREFETCH -+ " @p ld.global.L2::128B.v4.u32 {%0, %1, %2, %3}, [%4];\n" -+#else -+ " @p ld.global.v4.u32 {%0, %1, %2, %3}, [%4];\n" -+#endif -+ "}\n" -+ : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w) -+ : "l"(ptr), "r"((int)pred_guard), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w)); -+ } -+}; -+ -+template -+struct global_load { -+ CUTLASS_DEVICE -+ global_load(AccessType &D, void const *ptr, bool pred_guard) { -+ uint2 &data = reinterpret_cast(D); -+ -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " setp.ne.b32 p, %3, 0;\n" -+ " mov.b32 %0, %4;\n" -+ " mov.b32 %1, %5;\n" -+#if CUTLASS_ENABLE_L2_PREFETCH -+ " @p ld.global.L2::128B.v2.u32 {%0, %1}, [%2];\n" -+#else -+ " @p ld.global.v2.u32 {%0, %1}, [%2];\n" -+#endif -+ "}\n" -+ : "=r"(data.x), "=r"(data.y) -+ : "l"(ptr), "r"((int)pred_guard), "r"(data.x), "r"(data.y)); -+ } -+}; -+ -+template -+struct global_load { -+ CUTLASS_DEVICE -+ global_load(AccessType &D, void const *ptr, bool pred_guard) { -+ unsigned &data = reinterpret_cast(D); -+ -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " setp.ne.b32 p, %2, 0;\n" -+ " mov.b32 %0, %3;\n" -+#if CUTLASS_ENABLE_L2_PREFETCH -+ " @p ld.global.L2::128B.u32 %0, [%1];\n" -+#else -+ " @p ld.global.u32 %0, [%1];\n" -+#endif -+ "}\n" -+ : "=r"(data) -+ : "l"(ptr), "r"((int)pred_guard), "r"(data)); -+ } -+}; -+ -+template -+struct global_load { -+ CUTLASS_DEVICE -+ global_load(AccessType &D, void const *ptr, bool pred_guard) { -+ uint16_t &data = reinterpret_cast(D); -+ -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " setp.ne.b32 p, %2, 0;\n" -+ " mov.b16 %0, %3;\n" -+#if CUTLASS_ENABLE_L2_PREFETCH -+ " @p ld.global.L2::128B.u16 %0, [%1];\n" -+#else -+ " @p ld.global.u16 %0, [%1];\n" -+#endif -+ "}\n" -+ : "=h"(data) -+ : "l"(ptr), "r"((int)pred_guard), "h"(data)); -+ } -+}; -+ -+template -+struct global_load { -+ CUTLASS_DEVICE -+ global_load(AccessType &D, void const *ptr, bool pred_guard) { -+ if (pred_guard) D = *(reinterpret_cast(ptr)); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Fragment type to store data -+ typename AccessType, -+ /// The bytes of storing -+ int StoreBytes -+ > -+struct global_store; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Specializations -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+template -+struct global_store { -+ CUTLASS_DEVICE -+ global_store(AccessType const &D, void *ptr, bool pred_guard) { -+ uint4 const *data = reinterpret_cast(&D); -+ -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " setp.ne.b32 p, %5, 0;\n" -+ " @p st.global.v4.u32 [%0], {%1, %2, %3, %4};\n" -+ " @p st.global.v4.u32 [%6], {%7, %8, %9, %10};\n" -+ " @p st.global.v4.u32 [%11], {%12, %13, %14, %15};\n" -+ " @p st.global.v4.u32 [%16], {%17, %18, %19, %20};\n" -+ "}\n" -+ : -+ : "l"(ptr), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), -+ "r"(data[0].w), "r"((int)pred_guard), "l"(((uint8_t *)ptr) + 16), -+ "r"(data[1].x), "r"(data[1].y), "r"(data[1].z), "r"(data[1].w), -+ "l"(((uint8_t *)ptr) + 32), -+ "r"(data[2].x), "r"(data[2].y), "r"(data[2].z), "r"(data[2].w), -+ "l"(((uint8_t *)ptr) + 48), -+ "r"(data[3].x), "r"(data[3].y), "r"(data[3].z), "r"(data[3].w)); -+ } -+}; -+ -+ -+template -+struct global_store { -+ CUTLASS_DEVICE -+ global_store(AccessType const &D, void *ptr, bool pred_guard) { -+ uint4 const *data = reinterpret_cast(&D); -+ -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " setp.ne.b32 p, %5, 0;\n" -+ " @p st.global.v4.u32 [%0], {%1, %2, %3, %4};\n" -+ " @p st.global.v4.u32 [%6], {%7, %8, %9, %10};\n" -+ "}\n" -+ : -+ : "l"(ptr), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), -+ "r"(data[0].w), "r"((int)pred_guard), "l"(((uint8_t *)ptr) + 16), -+ "r"(data[1].x), "r"(data[1].y), "r"(data[1].z), "r"(data[1].w)); -+ } -+}; -+ -+template -+struct global_store { -+ CUTLASS_DEVICE -+ global_store(AccessType const &D, void *ptr, bool pred_guard) { -+ uint4 const &data = reinterpret_cast(D); -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " setp.ne.b32 p, %5, 0;\n" -+ " @p st.global.v4.u32 [%0], {%1, %2, %3, %4};\n" -+ "}\n" -+ : -+ : "l"(ptr), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w), "r"((int)pred_guard)); -+ } -+}; -+ -+template -+struct global_store { -+ CUTLASS_DEVICE -+ global_store(AccessType const &D, void *ptr, bool pred_guard) { -+ uint2 const &data = reinterpret_cast(D); -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " setp.ne.b32 p, %3, 0;\n" -+ " @p st.global.v2.u32 [%0], {%1, %2};\n" -+ "}\n" -+ : -+ : "l"(ptr), "r"(data.x), "r"(data.y), "r"((int)pred_guard)); -+ } -+}; -+ -+template -+struct global_store { -+ CUTLASS_DEVICE -+ global_store(AccessType const &D, void *ptr, bool pred_guard) { -+ uint32_t const &data = reinterpret_cast(D); -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " setp.ne.b32 p, %2, 0;\n" -+ " @p st.global.u32 [%0], %1;\n" -+ "}\n" -+ : -+ : "l"(ptr), "r"(data), "r"((int)pred_guard)); -+ } -+}; -+ -+template -+struct global_store { -+ CUTLASS_DEVICE -+ global_store(AccessType const &D, void *ptr, bool pred_guard) { -+ uint16_t const &data = reinterpret_cast(D); -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " setp.ne.b32 p, %2, 0;\n" -+ " @p st.global.u16 [%0], %1;\n" -+ "}\n" -+ : -+ : "l"(ptr), "h"(data), "r"((int)pred_guard)); -+ } -+}; -+ -+template -+struct global_store { -+ CUTLASS_DEVICE -+ global_store(AccessType const &D, void *ptr, bool pred_guard) { -+ if (pred_guard) *(reinterpret_cast(ptr)) = D; -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// ld.shared -+template -+CUTLASS_DEVICE -+void shared_load(void *dst, uint32_t ptr); -+ -+/// ld.shared - 16b -+template <> -+CUTLASS_DEVICE -+void shared_load<2>(void *dst, uint32_t ptr) { -+ asm volatile("ld.shared.u16 %0, [%1];\n" -+ : "=h"(*reinterpret_cast(dst)) -+ : "r"(ptr)); -+} -+ -+/// ld.shared - 32b -+template <> -+CUTLASS_DEVICE -+void shared_load<4>(void *dst, uint32_t ptr) { -+ asm volatile("ld.shared.u32 %0, [%1];\n" -+ : "=r"(*reinterpret_cast(dst)) -+ : "r"(ptr)); -+} -+ -+/// ld.shared - 64b -+template <> -+CUTLASS_DEVICE -+void shared_load<8>(void *dst, uint32_t ptr) { -+ uint2 *dst_u64 = reinterpret_cast(dst); -+ asm volatile("ld.shared.v2.u32 {%0, %1}, [%2];\n" -+ : -+ "=r"(dst_u64->x), -+ "=r"(dst_u64->y) -+ : "r"(ptr)); -+} -+ -+/// ld.shared - 128b -+template <> -+CUTLASS_DEVICE -+void shared_load<16>(void *dst, uint32_t ptr) { -+ uint4 *dst_u128 = reinterpret_cast(dst); -+ asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];\n" -+ : -+ "=r"(dst_u128->x), -+ "=r"(dst_u128->y), -+ "=r"(dst_u128->z), -+ "=r"(dst_u128->w) -+ : "r"(ptr)); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// st.shared -+template -+CUTLASS_DEVICE -+void shared_store(uint32_t ptr, void const *src); -+ -+/// st.shared - 16b -+template <> -+CUTLASS_DEVICE -+void shared_store<2>(uint32_t ptr, void const *src) { -+ asm volatile("st.shared.u16 [%0], %1;\n" -+ : : -+ "r"(ptr), -+ "h"(*reinterpret_cast(src)) -+ ); -+} -+ -+/// st.shared - 32b -+template <> -+CUTLASS_DEVICE -+void shared_store<4>(uint32_t ptr, void const *src) { -+ asm volatile("st.shared.u32 [%0], %1;\n" -+ : : -+ "r"(ptr), -+ "r"(*reinterpret_cast(src)) -+ ); -+} -+ -+/// st.shared - 64b -+template <> -+CUTLASS_DEVICE -+void shared_store<8>(uint32_t ptr, void const *src) { -+ uint2 const *dst_u64 = reinterpret_cast(src); -+ asm volatile("st.shared.v2.u32 [%0], {%1, %2};\n" -+ : : -+ "r"(ptr), -+ "r"(dst_u64->x), -+ "r"(dst_u64->y) -+ ); -+} -+ -+/// st.shared - 128b -+template <> -+CUTLASS_DEVICE -+void shared_store<16>(uint32_t ptr, void const *src) { -+ uint4 const *dst_u128 = reinterpret_cast(src); -+ asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};\n" -+ : : -+ "r"(ptr), -+ "r"(dst_u128->x), -+ "r"(dst_u128->y), -+ "r"(dst_u128->z), -+ "r"(dst_u128->w) -+ ); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace arch -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include "memory_sm75.h" -+#include "memory_sm80.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/arch/memory_sm75.h b/3rdparty/cutlass/include/cutlass/arch/memory_sm75.h -new file mode 100644 -index 0000000..ba59364 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/memory_sm75.h -@@ -0,0 +1,279 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Architecture-specific operators on memory added for SM75 -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+#include "cute/arch/util.hpp" -+ -+namespace cutlass { -+namespace arch { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Layout of destination matrix (column-major implies transpose) -+ typename Layout, -+ /// .x1, .x2, or .x4 -+ int MatrixCount -+> -+inline __device__ void ldsm(Array & D, void const* ptr); -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Determine the appropriate way to target PTX's "ldmatrix" instruction. -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2) || (__CUDACC_VER_MAJOR__ >= 11) -+ -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) -+#define CUDA_LDMATRIX_ACTIVATED 1 -+#endif -+ -+#define CUDA_LDMATRIX_SUPPORTED 1 -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// CUTLASS helper to get SMEM pointer -+inline __device__ unsigned cutlass_get_smem_pointer(void *ptr) { -+ return cute::cast_smem_ptr_to_uint(ptr); -+} -+ -+/// CUTLASS helper to get SMEM pointer -+inline __device__ unsigned cutlass_get_smem_pointer(void const *ptr) { -+ return cutlass_get_smem_pointer(const_cast(ptr)); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+inline __device__ void ldsm( -+ Array & D, -+ void const* ptr) { -+ -+ #if defined(CUDA_LDMATRIX_ACTIVATED) -+ -+ unsigned addr = cutlass_get_smem_pointer(ptr); -+ -+ int x; -+ asm volatile ("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];" : "=r"(x) : "r"(addr)); -+ reinterpret_cast(D) = x; -+ -+ #else -+ -+ CUTLASS_UNUSED(D); -+ CUTLASS_UNUSED(ptr); -+ CUTLASS_NOT_IMPLEMENTED(); -+ -+ #endif -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+inline __device__ void ldsm( -+ Array & D, -+ void const* ptr) { -+ -+ #if defined(CUDA_LDMATRIX_ACTIVATED) -+ -+ unsigned addr = cutlass_get_smem_pointer(ptr); -+ -+ int x, y; -+ asm volatile ("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];" : "=r"(x), "=r"(y) : "r"(addr)); -+ reinterpret_cast(D) = make_int2(x, y); -+ -+ #else -+ -+ CUTLASS_UNUSED(D); -+ CUTLASS_UNUSED(ptr); -+ CUTLASS_NOT_IMPLEMENTED(); -+ -+ #endif -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+inline __device__ void ldsm( -+ Array & D, -+ void const* ptr) { -+ -+ #if defined(CUDA_LDMATRIX_ACTIVATED) -+ -+ unsigned addr = cutlass_get_smem_pointer(ptr); -+ -+ int x, y, z, w; -+ asm volatile ("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];" : "=r"(x), "=r"(y), "=r"(z), "=r"(w) : "r"(addr)); -+ reinterpret_cast(D) = make_int4(x, y, z, w); -+ -+ #else -+ -+ CUTLASS_UNUSED(D); -+ CUTLASS_UNUSED(ptr); -+ CUTLASS_NOT_IMPLEMENTED(); -+ -+ #endif -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Transpose on 16b granularity -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+inline __device__ void ldsm( -+ Array & D, -+ void const* ptr) { -+ -+ #if CUDA_LDMATRIX_ACTIVATED -+ -+ unsigned addr = cutlass_get_smem_pointer(ptr); -+ -+ int x; -+ asm volatile ("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];" : "=r"(x) : "r"(addr)); -+ reinterpret_cast(D) = x; -+ -+ #else -+ -+ CUTLASS_UNUSED(D); -+ CUTLASS_UNUSED(ptr); -+ CUTLASS_NOT_IMPLEMENTED(); -+ -+ #endif -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+inline __device__ void ldsm( -+ Array & D, -+ void const* ptr) { -+ -+ #if defined(CUDA_LDMATRIX_ACTIVATED) -+ -+ unsigned addr = cutlass_get_smem_pointer(ptr); -+ -+ int x, y; -+ asm volatile ("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];" : "=r"(x), "=r"(y) : "r"(addr)); -+ reinterpret_cast(D) = make_int2(x, y); -+ -+ #else -+ -+ CUTLASS_UNUSED(D); -+ CUTLASS_UNUSED(ptr); -+ CUTLASS_NOT_IMPLEMENTED(); -+ -+ #endif -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+inline __device__ void ldsm( -+ Array & D, -+ void const* ptr) { -+ -+ #if defined(CUDA_LDMATRIX_ACTIVATED) -+ -+ unsigned addr = cutlass_get_smem_pointer(ptr); -+ -+ int x, y, z, w; -+ asm volatile ("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];" : "=r"(x), "=r"(y), "=r"(z), "=r"(w) : "r"(addr)); -+ reinterpret_cast(D) = make_int4(x, y, z, w); -+ -+ #else -+ -+ CUTLASS_UNUSED(D); -+ CUTLASS_UNUSED(ptr); -+ CUTLASS_NOT_IMPLEMENTED(); -+ -+ #endif -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct shared_load_op { -+ CUTLASS_DEVICE -+ shared_load_op(AccessType &D, void const *ptr) { -+ D = *reinterpret_cast(ptr); -+ } -+}; -+ -+template -+CUTLASS_DEVICE void shared_load(AccessType &D, void const *ptr) { -+ shared_load_op(D, ptr); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct shared_load_op { -+ CUTLASS_DEVICE -+ shared_load_op(AccessType &D, void const *ptr) { -+ unsigned addr = cutlass_get_smem_pointer(ptr); -+ -+ uint4 v; -+ asm volatile ("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];" : -+ "=r"(v.x), "=r"(v.y), "=r"(v.z), "=r"(v.w) : "r"(addr)); -+ -+ D = reinterpret_cast(v); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct shared_load_op { -+ CUTLASS_DEVICE -+ shared_load_op(AccessType &D, void const *ptr) { -+ unsigned addr = cutlass_get_smem_pointer(ptr); -+ -+ uint2 v; -+ asm volatile ("ld.shared.v2.b32 {%0, %1}, [%2];" : -+ "=r"(v.x), "=r"(v.y) : "r"(addr)); -+ -+ D = reinterpret_cast(v); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace arch -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/arch/memory_sm80.h b/3rdparty/cutlass/include/cutlass/arch/memory_sm80.h -new file mode 100644 -index 0000000..04bab1d ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/memory_sm80.h -@@ -0,0 +1,466 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Architecture-specific operators on memory added for SM80 -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/complex.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/arch/cache_operation.h" -+ -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ #define CUDA_CP_ASYNC_ACTIVATED 1 -+#else -+ #define CUDA_CP_ASYNC_ACTIVATED 0 -+#endif -+ -+namespace cutlass { -+namespace arch { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Initiates an asynchronous copy from global memory to shared memory. -+/// -+/// cp.async -+/// -+template < -+ /// Size of the access in bytes -+ int SizeInBytes, -+ /// Cache operation -+ CacheOperation::Kind cache_op = CacheOperation::Always> -+struct cp_async; -+ -+/// Initiates an asynchronous copy from global memory to shared memory. Rather than predicate -+/// the entire transfer, zeros are written to SMEM if the guard predicate is false. -+/// -+/// cp.async -+/// -+template < -+ /// Size of the access in bytes -+ int SizeInBytes, -+ /// Cache operation -+ CacheOperation::Kind cache_op = CacheOperation::Always> -+struct cp_async_zfill; -+ -+/// Initiates an asynchronous copy from global memory to shared memory. Rather than predicate -+/// the entire transfer, nans (0x7eff) are written to SMEM if the guard predicate is false. -+/// -+/// cp.async -+/// -+template < -+ /// Size of the access in bytes -+ int SizeInBytes, -+ /// Cache operation -+ CacheOperation::Kind cache_op = CacheOperation::Always> -+struct cp_async_nan; -+ -+/// Either 0 or 1 are written to SMEM based on input element type -+/// Used for diagonal elements of triangular matrix of BLAS3 functions -+/// -+/// st.shared -+/// -+template < -+ /// Type of Element -+ typename Element, -+ /// If the data is for a Hermitian matrix diagonal -+ bool IsHermitianData = false> -+struct cp_async_diag; -+ -+static const uint32_t OOB_NAN_F16 = 0x7eff; -+static const uint32_t OOB_NAN_F16x2 = ((OOB_NAN_F16 << 16) | OOB_NAN_F16); -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization -+template < -+ /// Size of the access in bytes -+ int SizeInBytes> -+struct cp_async { -+ -+ /// Copy -+ CUTLASS_DEVICE -+ cp_async(void *smem_ptr, void const *global_ptr, bool pred_guard = true) { -+ #if CUDA_CP_ASYNC_ACTIVATED -+ -+ // Make sure the size is supported. -+ static_assert((SizeInBytes == 4 || SizeInBytes == 8 || SizeInBytes == 16), -+ "Size is not supported"); -+ -+ unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); -+ -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " setp.ne.b32 p, %0, 0;\n" -+#if CUTLASS_ENABLE_L2_PREFETCH -+ " @p cp.async.ca.shared.global.L2::128B [%1], [%2], %3;\n" -+#else -+ " @p cp.async.ca.shared.global [%1], [%2], %3;\n" -+#endif -+ "}\n" ::"r"((int)pred_guard), -+ "r"(smem_int_ptr), "l"(global_ptr), "n"(SizeInBytes)); -+ -+ #else -+ using AccessType = Array; -+ -+ if (pred_guard) { -+ *static_cast(smem_ptr) = *static_cast(global_ptr); -+ } -+ #endif -+ } -+}; -+ -+/// Partial specialization -+template < -+ /// Size of the access in bytes -+ int SizeInBytes> -+struct cp_async_zfill { -+ -+ /// Copy with zero fill -+ CUTLASS_DEVICE -+ cp_async_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard) { -+ #if CUDA_CP_ASYNC_ACTIVATED -+ -+ // Make sure the size is supported. -+ static_assert((SizeInBytes == 4 || SizeInBytes == 8 || SizeInBytes == 16), -+ "Size is not supported"); -+ -+ unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); -+ int src_in_bytes = (pred_guard ? SizeInBytes : 0); -+ -+ asm volatile( -+#if CUTLASS_ENABLE_L2_PREFETCH -+ "cp.async.ca.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), -+#else -+ "cp.async.ca.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), -+#endif -+ "l"(global_ptr), "n"(SizeInBytes), "r"(src_in_bytes)); -+ -+ #else -+ using AccessType = Array; -+ -+ if (pred_guard) { -+ *static_cast(smem_ptr) = *static_cast(global_ptr); -+ } -+ else { -+ AccessType zeros; -+ zeros.clear(); -+ *static_cast(smem_ptr) = zeros; -+ } -+ #endif -+ } -+}; -+ -+/// Partial specialization -+template <> -+struct cp_async_nan<16, CacheOperation::Always> { -+ static int const kSizeInBytes = 16; -+ -+ /// Copy with nan fill -+ CUTLASS_DEVICE -+ cp_async_nan(void *smem_ptr, void const *global_ptr, bool pred_guard) { -+ #if CUDA_CP_ASYNC_ACTIVATED -+ -+ static __constant__ uint4 OOB_NAN_F16x8 = {OOB_NAN_F16x2, OOB_NAN_F16x2, -+ OOB_NAN_F16x2, OOB_NAN_F16x2}; -+ -+ unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); -+ -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " setp.ne.b32 p, %0, 0;\n" -+#if CUTLASS_ENABLE_L2_PREFETCH -+ " @p cp.async.ca.shared.global.L2::128B [%1], [%2], %3;\n" -+#else -+ " @p cp.async.ca.shared.global [%1], [%2], %3;\n" -+#endif -+ " @!p st.shared.v4.u32 [%1], {%4, %5, %6, %7};\n" -+ "}\n" -+ : -+ : "r"((int)pred_guard), "r"(smem_int_ptr), "l"(global_ptr), -+ "n"(kSizeInBytes), "r"(OOB_NAN_F16x8.x), "r"(OOB_NAN_F16x8.y), "r"(OOB_NAN_F16x8.z), -+ "r"(OOB_NAN_F16x8.w)); -+ -+ #else -+ -+ CUTLASS_UNUSED(smem_ptr); -+ CUTLASS_UNUSED(global_ptr); -+ CUTLASS_UNUSED(pred_guard); -+ CUTLASS_NOT_IMPLEMENTED(); -+ -+ #endif -+ } -+}; -+ -+/// Partial specialization to write one (1) -+template -+struct cp_async_diag { -+ using Element = Element_; -+ -+ CUTLASS_DEVICE -+ cp_async_diag(void *smem_ptr) { -+ #if CUDA_CP_ASYNC_ACTIVATED -+ -+ /// Values for the diagonal elements of the triangular input matrix -+ static __constant__ uint2 DIAG_DATA_DOUBLE_ONE = {0x3ff00000, 0x00000000}; -+ static __constant__ uint1 DIAG_DATA_FLOAT_ONE = {0x3f800000}; -+ static __constant__ uint1 DIAG_DATA_ZERO = {0x00000000}; -+ -+ unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); -+ -+ if (platform::is_same>::value) { -+ asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};\n" -+ : : -+ "r"(smem_int_ptr), "r"(DIAG_DATA_DOUBLE_ONE.y), "r"(DIAG_DATA_DOUBLE_ONE.x), -+ "r"(DIAG_DATA_ZERO.x), "r"(DIAG_DATA_ZERO.x)); -+ } else if (platform::is_same>::value) { -+ asm volatile("st.shared.v2.u32 [%0], {%1, %2};\n" -+ : : -+ "r"(smem_int_ptr), "r"(DIAG_DATA_FLOAT_ONE.x), "r"(DIAG_DATA_ZERO.x)); -+ } else if (platform::is_same::value) { -+ asm volatile("st.shared.v2.u32 [%0], {%1, %2};\n" -+ : : -+ "r"(smem_int_ptr), "r"(DIAG_DATA_DOUBLE_ONE.y),"r"(DIAG_DATA_DOUBLE_ONE.x)); -+ } else if (platform::is_same::value) { -+ asm volatile("st.shared.u32 [%0], %1;\n" -+ : : -+ "r"(smem_int_ptr), "r"(DIAG_DATA_FLOAT_ONE.x)); -+ } else { -+ CUTLASS_UNUSED(smem_int_ptr); -+ CUTLASS_NOT_IMPLEMENTED(); -+ } -+ -+ #else -+ -+ CUTLASS_UNUSED(smem_ptr); -+ CUTLASS_NOT_IMPLEMENTED(); -+ -+ #endif -+ } -+}; -+ -+/// Partial specialization to write zero for the imaginary part of Hermitian data -+template -+struct cp_async_diag { -+ using Element = Element_; -+ -+ CUTLASS_DEVICE -+ cp_async_diag(void *smem_ptr) { -+ #if CUDA_CP_ASYNC_ACTIVATED -+ -+ /// Values for the diagonal elements of the triangular input matrix -+ static __constant__ uint1 DIAG_DATA_ZERO = {0x00000000}; -+ -+ unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); -+ -+ if (platform::is_same>::value) { -+ asm volatile("st.shared.v2.u32 [%0], {%1, %2};\n" -+ : : -+ "r"(smem_int_ptr), "r"(DIAG_DATA_ZERO.x), "r"(DIAG_DATA_ZERO.x)); -+ } else if (platform::is_same>::value) { -+ asm volatile("st.shared.u32 [%0], %1;\n" -+ : : -+ "r"(smem_int_ptr), "r"(DIAG_DATA_ZERO.x)); -+ } else { -+ CUTLASS_UNUSED(smem_int_ptr); -+ CUTLASS_NOT_IMPLEMENTED(); -+ } -+ -+ #else -+ -+ CUTLASS_UNUSED(smem_ptr); -+ CUTLASS_NOT_IMPLEMENTED(); -+ -+ #endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization -+template < -+ /// Size of the access in bytes -+ int SizeInBytes> -+struct cp_async { -+ -+ /// Copy -+ CUTLASS_DEVICE -+ cp_async(void *smem_ptr, void const *global_ptr, bool pred_guard = true) { -+ #if CUDA_CP_ASYNC_ACTIVATED -+ -+ static_assert(SizeInBytes == 16, -+ "cp.async only supports CacheOperation::Global when access size is 16B."); -+ -+ unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); -+ -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " setp.ne.b32 p, %0, 0;\n" -+#if CUTLASS_ENABLE_L2_PREFETCH -+ " @p cp.async.cg.shared.global.L2::128B [%1], [%2], %3;\n" -+#else -+ " @p cp.async.cg.shared.global [%1], [%2], %3;\n" -+#endif -+ "}\n" ::"r"((int)pred_guard), -+ "r"(smem_int_ptr), "l"(global_ptr), "n"(SizeInBytes)); -+ -+ #else -+ using AccessType = Array; -+ -+ if (pred_guard) { -+ *static_cast(smem_ptr) = *static_cast(global_ptr); -+ } -+ #endif -+ } -+}; -+ -+/// Partial specialization -+template < -+ /// Size of the access in bytes -+ int SizeInBytes> -+struct cp_async_zfill { -+ -+ /// Copy with zero fill -+ CUTLASS_DEVICE -+ cp_async_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard = true) { -+ #if CUDA_CP_ASYNC_ACTIVATED -+ -+ static_assert(SizeInBytes == 16, -+ "cp.async only supports CacheOperation::Global when access size is 16B."); -+ -+ unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); -+ int src_in_bytes = (pred_guard ? SizeInBytes : 0); -+ -+ asm volatile( -+#if CUTLASS_ENABLE_L2_PREFETCH -+ "cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), -+#else -+ "cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), -+#endif -+ "l"(global_ptr), "n"(SizeInBytes), "r"(src_in_bytes)); -+ -+ #else -+ using AccessType = Array; -+ -+ if (pred_guard) { -+ *static_cast(smem_ptr) = *static_cast(global_ptr); -+ } -+ else { -+ AccessType zeros; -+ zeros.clear(); -+ *static_cast(smem_ptr) = zeros; -+ } -+ #endif -+ } -+}; -+ -+/// Partial specialization -+template <> -+struct cp_async_nan<16, CacheOperation::Global> { -+ static int const kSizeInBytes = 16; -+ -+ /// Copy with nan fill -+ CUTLASS_DEVICE -+ cp_async_nan(void *smem_ptr, void const *global_ptr, bool pred_guard) { -+ #if CUDA_CP_ASYNC_ACTIVATED -+ -+ static __constant__ uint4 OOB_NAN_F16x8 = {OOB_NAN_F16x2, OOB_NAN_F16x2, -+ OOB_NAN_F16x2, OOB_NAN_F16x2}; -+ -+ unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); -+ -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " setp.ne.b32 p, %0, 0;\n" -+#if CUTLASS_ENABLE_L2_PREFETCH -+ " @p cp.async.cg.shared.global.L2::128B [%1], [%2], %3;\n" -+#else -+ " @p cp.async.cg.shared.global [%1], [%2], %3;\n" -+#endif -+ " @!p st.shared.v4.u32 [%1], {%4, %5, %6, %7};\n" -+ "}\n" -+ : -+ : "r"((int)pred_guard), "r"(smem_int_ptr), "l"(global_ptr), -+ "n"(kSizeInBytes), "r"(OOB_NAN_F16x8.x), "r"(OOB_NAN_F16x8.y), "r"(OOB_NAN_F16x8.z), -+ "r"(OOB_NAN_F16x8.w)); -+ -+ #else -+ -+ CUTLASS_UNUSED(smem_ptr); -+ CUTLASS_UNUSED(global_ptr); -+ CUTLASS_UNUSED(pred_guard); -+ CUTLASS_NOT_IMPLEMENTED(); -+ -+ #endif -+ } -+}; -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block. -+CUTLASS_DEVICE -+void cp_async_fence() { -+ #if CUDA_CP_ASYNC_ACTIVATED -+ asm volatile("cp.async.commit_group;\n" ::); -+ #endif -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Blocks until all but previous cp.async.commit_group operations have committed. -+template -+CUTLASS_DEVICE void cp_async_wait() { -+ #if CUDA_CP_ASYNC_ACTIVATED -+ asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); -+ #endif -+} -+ -+/// Blocks until all previous cp.async.commit_group operations have committed. -+template <> -+CUTLASS_DEVICE void cp_async_wait<0>() { -+ #if CUDA_CP_ASYNC_ACTIVATED -+ asm volatile("cp.async.wait_all;\n" ::); -+ #endif -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace arch -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/arch/mma.h b/3rdparty/cutlass/include/cutlass/arch/mma.h -new file mode 100644 -index 0000000..7d4d693 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/mma.h -@@ -0,0 +1,227 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates exposing architecture support for multiply-add operations -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/functional.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/arch/arch.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace arch { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tag indicating the operation implied by MMA. -+struct OpMultiplyAdd {}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tag indicating the result is saturated to MAX_FLOAT|MIN_FLOAT or MAX_INT|MIN_INT -+struct OpMultiplyAddSaturate {}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tag indicating the input is converted to a narrower type (BF16) -+struct OpMultiplyAddFastBF16 {}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tag indicating the input is converted to a narrower type (F16) -+struct OpMultiplyAddFastF16 {}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tag indicating the input is converted to 2 (big and small) TF32 components -+// Perform 3xTF32 or 4xTF32 for every F32 output element -+struct OpMultiplyAddFastF32 {}; -+ -+/// Tag indicating the input is converted to 2 (big and small) TF32 components -+// Perform 3xTF32 or 4xTF32 for every complex output element -+struct OpMultiplyAddComplexFastF32 {}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tag indicating the complex multiply-add operation -+struct OpMultiplyAddComplex {}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tag indicating the gaussian complex multiply-add operation -+struct OpMultiplyAddGaussianComplex {}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tag indicating the inner product is defined by (XOR, POPC) -+struct OpXorPopc {}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tag classifying math operators as thread-level operations. -+struct OpClassSimt {}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tag classifing operators as Tensor Core operations. -+struct OpClassTensorOp {}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Tag classifing operators as WMMA Tensor Core operations -+struct OpClassWmmaTensorOp {}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation -+template < -+ /// Size of the matrix product (concept: GemmShape) -+ typename Shape_, -+ /// Number of threads participating -+ int kThreads_, -+ /// Data type of A elements -+ typename ElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Data type of B elements -+ typename ElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Element type of C matrix -+ typename ElementC, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Inner product operator -+ typename Operator -+> -+struct Mma; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation - specialized for 1x1x1x1 matrix multiply operation -+template < -+ /// Data type of A elements -+ typename ElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Data type of B elements -+ typename ElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Element type of C matrix -+ typename ElementC_, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Inner product operator -+ typename Operator_ -+> -+struct Mma, 1, ElementA, LayoutA, ElementB, LayoutB, ElementC_, LayoutC, Operator_> { -+ -+ using Shape = gemm::GemmShape<1, 1, 1>; -+ using Operator = Operator_; -+ using ElementC = ElementC_; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array &d, -+ Array const &a, -+ Array const &b, -+ Array const &c -+ ) { -+ -+ multiply_add op; -+ -+ d[0] = op(a[0], b[0], c[0]); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Specifies internal data type for computation -+struct SPFormatType { -+ enum Kind { -+ Thread -+ }; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation -+template < -+ /// Size of the matrix product (concept: GemmShape) -+ typename Shape_, -+ /// Number of threads participating -+ int kThreads_, -+ /// Data type of A elements -+ typename ElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Data type of B elements -+ typename ElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Element type of C matrix -+ typename ElementC, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Inner product operator -+ typename Operator, -+ /// Specifies meta data format -+ SPFormatType::Kind SPFormat = SPFormatType::Thread -+> -+struct SparseMma; -+ -+} // namespace arch -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Specializations for each compute capability -+// -+ -+#include "cutlass/arch/mma_sm50.h" -+#include "cutlass/arch/mma_sm60.h" -+#include "cutlass/arch/mma_sm61.h" -+#include "cutlass/arch/mma_sm70.h" -+#include "cutlass/arch/mma_sm75.h" -+#include "cutlass/arch/mma_sm80.h" -+#include "cutlass/arch/mma_sparse_sm80.h" -+#include "cutlass/arch/mma_sm90.h" -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/arch/mma_sm50.h b/3rdparty/cutlass/include/cutlass/arch/mma_sm50.h -new file mode 100644 -index 0000000..8aca344 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/mma_sm50.h -@@ -0,0 +1,432 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Matrix multiply -+*/ -+ -+#pragma once -+ -+#include "cutlass/arch/mma.h" -+#include "cutlass/complex.h" -+#include "cutlass/quaternion.h" -+#include "cutlass/functional.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/gemm/gemm.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace arch { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation -+template < -+ /// Layout of A matrix -+ typename LayoutA, -+ /// Layout of B matrix -+ typename LayoutB, -+ /// Layout of C matrix -+ typename LayoutC -+> -+struct Mma, 1, float, LayoutA, float, LayoutB, float, LayoutC, OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<1, 1, 1>; -+ using Operator = OpMultiplyAdd; -+ using ElementC = float; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array &d, -+ Array const &a, -+ Array const &b, -+ Array const &c -+ ) { -+ d[0] = a[0] * b[0] + c[0]; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation -+template < -+ /// Layout of A matrix -+ typename LayoutA, -+ /// Layout of B matrix -+ typename LayoutB, -+ /// Layout of C matrix -+ typename LayoutC -+> -+struct Mma, 1, double, LayoutA, double, LayoutB, double, LayoutC, OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<1, 1, 1>; -+ using Operator = OpMultiplyAdd; -+ using ElementC = double; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array &d, -+ Array const &a, -+ Array const &b, -+ Array const &c -+ ) { -+ -+ d[0] = a[0] * b[0] + c[0]; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation -+template < -+ /// Layout of A matrix -+ typename LayoutA, -+ /// Layout of B matrix -+ typename LayoutB, -+ /// Layout of C matrix -+ typename LayoutC -+> -+struct Mma, 1, int, LayoutA, int, LayoutB, int, LayoutC, OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<1, 1, 1>; -+ using Operator = OpMultiplyAdd; -+ using ElementC = int; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array &d, -+ Array const &a, -+ Array const &b, -+ Array const &c -+ ) { -+ -+ d[0] = a[0] * b[0] + c[0]; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation -+template < -+ /// Layout of A matrix -+ typename LayoutA, -+ /// Layout of B matrix -+ typename LayoutB, -+ /// Layout of C matrix -+ typename LayoutC -+> -+struct Mma< -+ gemm::GemmShape<1, 1, 1>, -+ 1, -+ complex, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<1, 1, 1>; -+ using Operator = OpMultiplyAddComplex; -+ using ElementC = complex; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array, 1> &d, -+ Array, 1> const &a, -+ Array, 1> const &b, -+ Array, 1> const &c -+ ) { -+ -+ d[0].real() = a[0].real() * b[0].real() + c[0].real(); -+ d[0].imag() = a[0].imag() * b[0].real() + c[0].imag(); -+ d[0].real() = -a[0].imag() * b[0].imag() + d[0].real(); -+ d[0].imag() = a[0].real() * b[0].imag() + d[0].imag(); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation -+template < -+ /// Layout of A matrix -+ typename LayoutA, -+ /// Layout of B matrix -+ typename LayoutB, -+ /// Layout of C matrix -+ typename LayoutC -+> -+struct Mma< -+ gemm::GemmShape<1, 1, 1>, -+ 1, -+ complex, -+ LayoutA, -+ float, -+ LayoutB, -+ complex, -+ LayoutC, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<1, 1, 1>; -+ using Operator = OpMultiplyAddComplex; -+ using ElementC = complex; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array, 1> &d, -+ Array, 1> const &a, -+ Array const &b, -+ Array, 1> const &c -+ ) { -+ -+ d[0].real() = a[0].real() * b[0] + c[0].real(); -+ d[0].imag() = a[0].imag() * b[0] + c[0].imag(); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation -+template < -+ /// Layout of A matrix -+ typename LayoutA, -+ /// Layout of B matrix -+ typename LayoutB, -+ /// Layout of C matrix -+ typename LayoutC -+> -+struct Mma< -+ gemm::GemmShape<1, 1, 1>, -+ 1, -+ float, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<1, 1, 1>; -+ using Operator = OpMultiplyAddComplex; -+ using ElementC = complex; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array, 1> &d, -+ Array const &a, -+ Array, 1> const &b, -+ Array, 1> const &c -+ ) { -+ -+ d[0].real() = a[0] * b[0].real() + c[0].real(); -+ d[0].imag() = a[0] * b[0].imag() + d[0].imag(); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation -+template < -+ /// Layout of A matrix -+ typename LayoutA, -+ /// Layout of B matrix -+ typename LayoutB, -+ /// Layout of C matrix -+ typename LayoutC -+> -+struct Mma< -+ gemm::GemmShape<1, 1, 1>, -+ 1, -+ complex, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<1, 1, 1>; -+ using Operator = OpMultiplyAddComplex; -+ using ElementC = complex; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array, 1> &d, -+ Array, 1> const &a, -+ Array, 1> const &b, -+ Array, 1> const &c -+ ) { -+ -+ d[0].real() = a[0].real() * b[0].real() + c[0].real(); -+ d[0].imag() = a[0].imag() * b[0].real() + c[0].imag(); -+ d[0].real() = -a[0].imag() * b[0].imag() + d[0].real(); -+ d[0].imag() = a[0].real() * b[0].imag() + d[0].imag(); -+ } -+}; -+ -+/// Matrix multiply-add operation -+template < -+ /// Layout of A matrix -+ typename LayoutA, -+ /// Layout of B matrix -+ typename LayoutB, -+ /// Layout of C matrix -+ typename LayoutC -+> -+struct Mma< -+ gemm::GemmShape<1, 1, 1>, -+ 1, -+ complex, -+ LayoutA, -+ double, -+ LayoutB, -+ complex, -+ LayoutC, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<1, 1, 1>; -+ using Operator = OpMultiplyAddComplex; -+ using ElementC = complex; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array, 1> &d, -+ Array, 1> const &a, -+ Array const &b, -+ Array, 1> const &c -+ ) { -+ -+ d[0].real() = a[0].real() * b[0] + c[0].real(); -+ d[0].imag() = a[0].imag() * b[0] + c[0].imag(); -+ } -+}; -+ -+/// Matrix multiply-add operation -+template < -+ /// Layout of A matrix -+ typename LayoutA, -+ /// Layout of B matrix -+ typename LayoutB, -+ /// Layout of C matrix -+ typename LayoutC -+> -+struct Mma< -+ gemm::GemmShape<1, 1, 1>, -+ 1, -+ double, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<1, 1, 1>; -+ using Operator = OpMultiplyAddComplex; -+ using ElementC = complex; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array, 1> &d, -+ Array const &a, -+ Array, 1> const &b, -+ Array, 1> const &c -+ ) { -+ -+ d[0].real() = a[0] * b[0].real() + c[0].real(); -+ d[0].imag() = a[0] * b[0].imag() + d[0].imag(); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation -+template < -+ /// Layout of A matrix -+ typename LayoutA, -+ /// Layout of B matrix -+ typename LayoutB, -+ /// Layout of C matrix -+ typename LayoutC -+> -+struct Mma, 1, half_t, LayoutA, half_t, LayoutB, float, LayoutC, OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<1, 1, 1>; -+ using Operator = OpMultiplyAdd; -+ using ElementC = float; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array &d, -+ Array const &a, -+ Array const &b, -+ Array const &c -+ ) { -+ d[0] = float(a[0]) * float(b[0]) + c[0]; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation for Quaternions -+template < -+ /// Layout of A matrix -+ typename LayoutA, -+ /// Layout of B matrix -+ typename LayoutB, -+ /// Layout of C matrix -+ typename LayoutC -+> -+struct Mma, 1, Quaternion, LayoutA, Quaternion, LayoutB, Quaternion, LayoutC, OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<1, 1, 1>; -+ using Operator = OpMultiplyAdd; -+ using Element = Quaternion; -+ using ElementC = Element; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array &d, -+ Array const &a, -+ Array const &b, -+ Array const &c -+ ) { -+ multiply_add op; -+ d[0] = op(a[0], b[0], c[0]); -+ } -+ -+}; -+ -+} -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/arch/mma_sm60.h b/3rdparty/cutlass/include/cutlass/arch/mma_sm60.h -new file mode 100644 -index 0000000..349c838 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/mma_sm60.h -@@ -0,0 +1,252 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Matrix multiply -+*/ -+ -+#pragma once -+ -+#include -+ -+#include "cutlass/arch/mma.h" -+ -+#include "cutlass/layout/matrix.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace arch { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation -+template -+struct Mma< -+ gemm::GemmShape<2,1,1>, -+ 1, -+ half_t, -+ LayoutA, -+ half_t, -+ LayoutB, -+ half_t, -+ LayoutC, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<2, 1, 1>; -+ using Operator = OpMultiplyAdd; -+ using ElementC = half_t; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array &d, -+ Array const &a, -+ Array const &b, -+ Array const &c -+ ) { -+ -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)) -+ -+ __half2 const & A = reinterpret_cast<__half2 const &>(a); -+ __half2 B = __half2half2(reinterpret_cast<__half const &>(b)); -+ __half2 const & C = reinterpret_cast<__half2 const &>(c); -+ -+ __half2 D = __hfma2(A, B, C); -+ -+ d = reinterpret_cast &>(D); -+ -+#else -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 2; ++i) { -+ d[i] = a[i] * b[0] + c[i]; -+ } -+#endif -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation -+template -+struct Mma< -+ gemm::GemmShape<1,2,1>, -+ 1, -+ half_t, -+ LayoutA, -+ half_t, -+ LayoutB, -+ half_t, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<1, 2, 1>; -+ using Operator = OpMultiplyAdd; -+ using ElementC = half_t; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array &d, -+ Array const &a, -+ Array const &b, -+ Array const &c -+ ) { -+ -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)) -+ -+ __half2 const & A = __half2half2(reinterpret_cast<__half const &>(a)); -+ __half2 B = reinterpret_cast<__half2 const &>(b); -+ __half2 const & C = reinterpret_cast<__half2 const &>(c); -+ -+ __half2 D = __hfma2(A, B, C); -+ -+ d = reinterpret_cast &>(D); -+ -+#else -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 2; ++i) { -+ d[i] = a[0] * b[i] + c[i]; -+ } -+#endif -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation -+template <> -+struct Mma < -+ gemm::GemmShape<2, 2, 1>, -+ 1, -+ half_t, -+ layout::ColumnMajor, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::ColumnMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<2, 2, 1>; -+ using Operator = OpMultiplyAdd; -+ using ElementC = half_t; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array &d, -+ Array const &a, -+ Array const &b, -+ Array const &c -+ ) { -+ -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)) -+ -+ __half2 const & A = reinterpret_cast<__half2 const &>(a); -+ __half2 Blo = __low2half2(reinterpret_cast<__half2 const &>(b)); -+ __half2 Bhi = __high2half2(reinterpret_cast<__half2 const &>(b)); -+ -+ __half2 const *C = reinterpret_cast<__half2 const *>(&c); -+ -+ __half2 Dlo = __hfma2(A, Blo, C[0]); -+ __half2 Dhi = __hfma2(A, Bhi, C[1]); -+ -+ Array * D = reinterpret_cast *>(&d); -+ -+ D[0] = reinterpret_cast const &>(Dlo); -+ D[1] = reinterpret_cast const &>(Dhi); -+ -+#else -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < 2; ++j) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 2; ++i) { -+ d[i + 2 * j] = a[i] * b[j] + c[i + 2 * j]; -+ } -+ } -+#endif -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation -+template <> -+struct Mma< -+ gemm::GemmShape<2, 2, 1>, -+ 1, -+ half_t, -+ layout::ColumnMajor, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<2, 2, 1>; -+ using Operator = OpMultiplyAdd; -+ using ElementC = half_t; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array &d, -+ Array const &a, -+ Array const &b, -+ Array const &c -+ ) { -+ -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)) -+ -+ __half2 Alo = __low2half2(reinterpret_cast<__half2 const &>(a)); -+ __half2 Ahi = __high2half2(reinterpret_cast<__half2 const &>(a)); -+ __half2 const & B = reinterpret_cast<__half2 const &>(b); -+ -+ __half2 const *C = reinterpret_cast<__half2 const *>(&c); -+ -+ __half2 Dlo = __hfma2(Alo, B, C[0]); -+ __half2 Dhi = __hfma2(Ahi, B, C[0]); -+ -+ Array * D = reinterpret_cast *>(&d); -+ -+ D[0] = reinterpret_cast &>(Dlo); -+ D[1] = reinterpret_cast &>(Dhi); -+#else -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 2; ++i) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < 2; ++j) { -+ d[i * 2 + j] = a[i] * b[j] + c[i * 2 + j]; -+ } -+ } -+#endif -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} -+} -diff --git a/3rdparty/cutlass/include/cutlass/arch/mma_sm61.h b/3rdparty/cutlass/include/cutlass/arch/mma_sm61.h -new file mode 100644 -index 0000000..a1af935 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/mma_sm61.h -@@ -0,0 +1,142 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Matrix multiply -+*/ -+ -+#pragma once -+ -+#include "cutlass/layout/matrix.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace arch { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation -+template -+struct Mma< -+ gemm::GemmShape<1,1,4>, -+ 1, -+ int8_t, -+ LayoutA, -+ int8_t, -+ LayoutB, -+ int, -+ LayoutC, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<1, 1, 4>; -+ using Operator = OpMultiplyAdd; -+ using ElementC = int; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array &d, -+ Array const &a, -+ Array const &b, -+ Array const &c -+ ) { -+ -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)) -+ -+ unsigned const &A = reinterpret_cast(a); -+ unsigned const &B = reinterpret_cast(b); -+ -+ asm volatile("dp4a.s32.s32 %0, %1, %2, %3;" -+ : "=r"(d[0]) -+ : "r"(A), "r"(B), "r"(c[0])); -+ -+#else -+ -+ d[0] = c[0]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < 4; ++k) { -+ d[0] += a[k] * b[k]; -+ } -+ -+#endif -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation -+template -+struct Mma< -+ gemm::GemmShape<1, 1, 2>, -+ 1, -+ int16_t, -+ layout::RowMajor, -+ int16_t, -+ layout::ColumnMajor, -+ int, -+ LayoutC, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<1, 1, 2>; -+ using Operator = OpMultiplyAdd; -+ using ElementC = int; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array &d, -+ Array const &a, -+ Array const &b, -+ Array const &c -+ ) { -+ -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)) -+ -+ unsigned const &A = reinterpret_cast(a); -+ unsigned const &B = reinterpret_cast(b); -+ -+ asm volatile("dp2a.s32.s32 %0, %1, %2, %3;" -+ : "=r"(d[0]) -+ : "r"(A), "r"(B), "r"(c[0])); -+#else -+ d[0] = c[0]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < 2; ++k) { -+ d[0] += a[k] * b[k]; -+ } -+#endif -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} -+} -diff --git a/3rdparty/cutlass/include/cutlass/arch/mma_sm70.h b/3rdparty/cutlass/include/cutlass/arch/mma_sm70.h -new file mode 100644 -index 0000000..9f93714 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/mma_sm70.h -@@ -0,0 +1,665 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Matrix multiply -+*/ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include "mma.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+ -+#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1)) -+#define CUTLASS_ARCH_MMA_SM70_SUPPORTED -+#endif -+ -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)) -+ -+#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 &&__CUDACC_VER_MINOR__ >= 1)) -+#define CUTLASS_ARCH_MMA_SM70_ENABLED -+#endif -+ -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace arch { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Matrix multiply accumulate 884 - FP16 accumulation -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: F16 = F16 * F16 + F16 -+template <> -+struct Mma< -+ gemm::GemmShape<8,8,4>, -+ 8, -+ half_t, -+ layout::ColumnMajor, -+ half_t, -+ layout::ColumnMajor, -+ half_t, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<8, 8, 4>; -+ -+ using ElementA = half_t; -+ using LayoutA = layout::ColumnMajor; -+ using FragmentA = Array; -+ -+ using ElementB = half_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = half_t; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm70; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) { -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) -+ -+ unsigned const *A = reinterpret_cast(&a); -+ unsigned const *B = reinterpret_cast(&b); -+ unsigned const *C = reinterpret_cast(&c); -+ unsigned *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k4.col.col.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]) -+ ); -+ -+#else -+ assert(0); -+ #if defined(__CUDA_ARCH__) -+ asm volatile ("brkpt;\n" ::); -+ #endif -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: F16 = F16 * F16 + F16 -+template <> -+struct Mma< -+ gemm::GemmShape<8, 8, 4>, -+ 8, -+ half_t, -+ layout::ColumnMajor, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<8, 8, 4>; -+ -+ using ElementA = half_t; -+ using LayoutA = layout::ColumnMajor; -+ using FragmentA = Array; -+ -+ using ElementB = half_t; -+ using LayoutB = layout::RowMajor; -+ using FragmentB = Array; -+ -+ using ElementC = half_t; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm70; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) { -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) -+ -+ unsigned const *A = reinterpret_cast(&a); -+ unsigned const *B = reinterpret_cast(&b); -+ unsigned const *C = reinterpret_cast(&c); -+ unsigned *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k4.col.row.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]) -+ ); -+ -+#else -+ assert(0); -+ #if defined(__CUDA_ARCH__) -+ asm volatile ("brkpt;\n" ::); -+ #endif -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: F16 = F16 * F16 + F16 -+template <> -+struct Mma< -+ gemm::GemmShape<8, 8, 4>, -+ 8, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::ColumnMajor, -+ half_t, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<8, 8, 4>; -+ -+ using ElementA = half_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = half_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = half_t; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm70; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) { -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) -+ -+ unsigned const *A = reinterpret_cast(&a); -+ unsigned const *B = reinterpret_cast(&b); -+ unsigned const *C = reinterpret_cast(&c); -+ unsigned *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]) -+ ); -+ -+#else -+ assert(0); -+ #if defined(__CUDA_ARCH__) -+ asm volatile ("brkpt;\n" ::); -+ #endif -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: F16 = F16 * F16 + F16 -+template <> -+struct Mma< -+ gemm::GemmShape<8, 8, 4>, -+ 8, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<8, 8, 4>; -+ -+ using ElementA = half_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = half_t; -+ using LayoutB = layout::RowMajor; -+ using FragmentB = Array; -+ -+ using ElementC = half_t; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm70; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) { -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) -+ -+ unsigned const *A = reinterpret_cast(&a); -+ unsigned const *B = reinterpret_cast(&b); -+ unsigned const *C = reinterpret_cast(&c); -+ unsigned *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]) -+ ); -+ -+#else -+ assert(0); -+ #if defined(__CUDA_ARCH__) -+ asm volatile ("brkpt;\n" ::); -+ #endif -+#endif -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Matrix multiply accumulate 884 - FP32 accumulation -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: F32 = F16 * F16 + F32 -+template <> -+struct Mma< -+ gemm::GemmShape<8, 8, 4>, -+ 8, -+ half_t, -+ layout::ColumnMajor, -+ half_t, -+ layout::ColumnMajor, -+ float, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<8, 8, 4>; -+ -+ using ElementA = half_t; -+ using LayoutA = layout::ColumnMajor; -+ using FragmentA = Array; -+ -+ using ElementB = half_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = float; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm70; -+ -+ /// Multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) { -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) -+ -+ unsigned const *A = reinterpret_cast(&a); -+ unsigned const *B = reinterpret_cast(&b); -+ float const *C = reinterpret_cast(&c); -+ float *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k4.col.col.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, " -+ "{%12,%13,%14,%15,%16,%17,%18,%19};\n" -+ : "=f"(D[0]), -+ "=f"(D[1]), -+ "=f"(D[2]), -+ "=f"(D[3]), -+ "=f"(D[4]), -+ "=f"(D[5]), -+ "=f"(D[6]), -+ "=f"(D[7]) -+ : "r"(A[0]), -+ "r"(A[1]), -+ "r"(B[0]), -+ "r"(B[1]), -+ "f"(C[0]), -+ "f"(C[1]), -+ "f"(C[2]), -+ "f"(C[3]), -+ "f"(C[4]), -+ "f"(C[5]), -+ "f"(C[6]), -+ "f"(C[7]) -+ ); -+ -+#else -+ assert(0); -+ #if defined(__CUDA_ARCH__) -+ asm volatile ("brkpt;\n" ::); -+ #endif -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: F32 = F16 * F16 + F32 -+template <> -+struct Mma< -+ gemm::GemmShape<8, 8, 4>, -+ 8, -+ half_t, -+ layout::ColumnMajor, -+ half_t, -+ layout::RowMajor, -+ float, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<8, 8, 4>; -+ -+ using ElementA = half_t; -+ using LayoutA = layout::ColumnMajor; -+ using FragmentA = Array; -+ -+ using ElementB = half_t; -+ using LayoutB = layout::RowMajor; -+ using FragmentB = Array; -+ -+ using ElementC = float; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm70; -+ -+ /// Multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) { -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) -+ -+ unsigned const *A = reinterpret_cast(&a); -+ unsigned const *B = reinterpret_cast(&b); -+ float const *C = reinterpret_cast(&c); -+ float *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, " -+ "{%12,%13,%14,%15,%16,%17,%18,%19};\n" -+ : "=f"(D[0]), -+ "=f"(D[1]), -+ "=f"(D[2]), -+ "=f"(D[3]), -+ "=f"(D[4]), -+ "=f"(D[5]), -+ "=f"(D[6]), -+ "=f"(D[7]) -+ : "r"(A[0]), -+ "r"(A[1]), -+ "r"(B[0]), -+ "r"(B[1]), -+ "f"(C[0]), -+ "f"(C[1]), -+ "f"(C[2]), -+ "f"(C[3]), -+ "f"(C[4]), -+ "f"(C[5]), -+ "f"(C[6]), -+ "f"(C[7]) -+ ); -+ -+#else -+ assert(0); -+ #if defined(__CUDA_ARCH__) -+ asm volatile ("brkpt;\n" ::); -+ #endif -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: F32 = F16 * F16 + F32 -+template <> -+struct Mma< -+ gemm::GemmShape<8, 8, 4>, -+ 8, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::ColumnMajor, -+ float, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<8, 8, 4>; -+ -+ using ElementA = half_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = half_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = float; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm70; -+ -+ /// Multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) { -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) -+ -+ unsigned const *A = reinterpret_cast(&a); -+ unsigned const *B = reinterpret_cast(&b); -+ float const *C = reinterpret_cast(&c); -+ float *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, " -+ "{%12,%13,%14,%15,%16,%17,%18,%19};\n" -+ : "=f"(D[0]), -+ "=f"(D[1]), -+ "=f"(D[2]), -+ "=f"(D[3]), -+ "=f"(D[4]), -+ "=f"(D[5]), -+ "=f"(D[6]), -+ "=f"(D[7]) -+ : "r"(A[0]), -+ "r"(A[1]), -+ "r"(B[0]), -+ "r"(B[1]), -+ "f"(C[0]), -+ "f"(C[1]), -+ "f"(C[2]), -+ "f"(C[3]), -+ "f"(C[4]), -+ "f"(C[5]), -+ "f"(C[6]), -+ "f"(C[7]) -+ ); -+ -+#else -+ assert(0); -+ #if defined(__CUDA_ARCH__) -+ asm volatile ("brkpt;\n" ::); -+ #endif -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: F32 = F16 * F16 + F32 -+template <> -+struct Mma< -+ gemm::GemmShape<8, 8, 4>, -+ 8, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::RowMajor, -+ float, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<8, 8, 4>; -+ -+ using ElementA = half_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = half_t; -+ using LayoutB = layout::RowMajor; -+ using FragmentB = Array; -+ -+ using ElementC = float; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm70; -+ -+ /// Multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) { -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) -+ -+ unsigned const *A = reinterpret_cast(&a); -+ unsigned const *B = reinterpret_cast(&b); -+ float const *C = reinterpret_cast(&c); -+ float *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, " -+ "{%12,%13,%14,%15,%16,%17,%18,%19};\n" -+ : "=f"(D[0]), -+ "=f"(D[1]), -+ "=f"(D[2]), -+ "=f"(D[3]), -+ "=f"(D[4]), -+ "=f"(D[5]), -+ "=f"(D[6]), -+ "=f"(D[7]) -+ : "r"(A[0]), -+ "r"(A[1]), -+ "r"(B[0]), -+ "r"(B[1]), -+ "f"(C[0]), -+ "f"(C[1]), -+ "f"(C[2]), -+ "f"(C[3]), -+ "f"(C[4]), -+ "f"(C[5]), -+ "f"(C[6]), -+ "f"(C[7]) -+ ); -+ -+#else -+ assert(0); -+ #if defined(__CUDA_ARCH__) -+ asm volatile ("brkpt;\n" ::); -+ #endif -+#endif -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation specialized for the entire warp -+template < -+ typename LayoutA, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename Operator -+> -+struct Mma< -+ gemm::GemmShape<16, 16, 4>, -+ 32, -+ half_t, -+ LayoutA, -+ half_t, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ Operator -+> : -+ public Mma< -+ gemm::GemmShape<8, 8, 4>, -+ 8, -+ half_t, -+ LayoutA, -+ half_t, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ Operator> { -+ -+ using Shape = gemm::GemmShape<16, 16, 4>; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace arch -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/arch/mma_sm75.h b/3rdparty/cutlass/include/cutlass/arch/mma_sm75.h -new file mode 100644 -index 0000000..1402e76 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/mma_sm75.h -@@ -0,0 +1,1301 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Matrix multiply for SM75 -+*/ -+ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include "cutlass/arch/wmma.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+// CUDA Toolkit includes for nvcuda::wmma needed for binarized matrix multiply. -+#include -+#include "cutlass/wmma_array.h" -+#endif -+ -+// CUTLASS includes -+#include "cutlass/arch/mma.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)) -+ -+#define CUTLASS_ARCH_MMA_SM75_SUPPORTED 1 -+ -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)) -+#define CUTLASS_ARCH_MMA_SM75_ENABLED -+#endif -+#endif -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace arch { -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Matrix Multiply 1688 - FP16 accumulation -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation - F16 = F16 * F16 + F16 -+template <> -+struct Mma< -+ gemm::GemmShape<16, 8, 8>, -+ 32, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::ColumnMajor, -+ half_t, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16, 8, 8>; -+ -+ using ElementA = half_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = half_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = half_t; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm75; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -+ -+ unsigned const *A = reinterpret_cast(&a); -+ unsigned const *B = reinterpret_cast(&b); -+ unsigned const *C = reinterpret_cast(&c); -+ unsigned *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0,%1}, {%2,%3}, {%4}, {%5,%6};\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(C[0]), "r"(C[1])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Matrix Multiply 1688 - FP32 accumulation -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: F32 = F16 * F16 + F32 -+template <> -+struct Mma< -+ gemm::GemmShape<16, 8, 8>, -+ 32, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::ColumnMajor, -+ float, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16, 8, 8>; -+ -+ using ElementA = half_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = half_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = float; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm75; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, -+ FragmentC const &c) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -+ -+ unsigned const *A = reinterpret_cast(&a); -+ unsigned const *B = reinterpret_cast(&b); -+ float const *C = reinterpret_cast(&c); -+ float *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" -+ : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) -+ : -+ "r"(A[0]), "r"(A[1]), -+ "r"(B[0]), -+ "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]) -+ ); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Integer matrix multiply .8816 (8b) -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: S32 = S8 * S8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<8, 8, 16>, -+ 32, -+ int8_t, -+ layout::RowMajor, -+ int8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<8, 8, 16>; -+ -+ using ElementA = int8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm75; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -+ -+ unsigned const & A = reinterpret_cast(a); -+ unsigned const & B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U8 * S8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<8, 8, 16>, -+ 32, -+ uint8_t, -+ layout::RowMajor, -+ int8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<8, 8, 16>; -+ -+ using ElementA = uint8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm75; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -+ -+ unsigned const & A = reinterpret_cast(a); -+ unsigned const & B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k16.row.col.s32.u8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = S8 * U8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<8, 8, 16>, -+ 32, -+ int8_t, -+ layout::RowMajor, -+ uint8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<8, 8, 16>; -+ -+ using ElementA = int8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm75; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -+ -+ unsigned const & A = reinterpret_cast(a); -+ unsigned const & B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k16.row.col.s8.u8 {%0,%1}, {%2}, {%3}, {%4,%5};\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -+ -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U8 * U8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<8, 8, 16>, -+ 32, -+ uint8_t, -+ layout::RowMajor, -+ uint8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<8, 8, 16>; -+ -+ using ElementA = uint8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm75; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -+ -+ unsigned const & A = reinterpret_cast(a); -+ unsigned const & B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k16.row.col.s32.u8.u8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Integer matrix multiply (8b) with SATURATE -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: S32 = S8 * S8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<8,8,16>, -+ 32, -+ int8_t, -+ layout::RowMajor, -+ int8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<8,8,16>; -+ -+ using ElementA = int8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAddSaturate; -+ using ArchTag = arch::Sm75; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -+ -+ unsigned const & A = reinterpret_cast(a); -+ unsigned const & B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U8 * S8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<8,8,16>, -+ 32, -+ uint8_t, -+ layout::RowMajor, -+ int8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<8,8,16>; -+ -+ using ElementA = uint8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAddSaturate; -+ using ArchTag = arch::Sm75; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -+ -+ unsigned const & A = reinterpret_cast(a); -+ unsigned const & B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.u8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = S8 * U8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<8,8,16>, -+ 32, -+ int8_t, -+ layout::RowMajor, -+ uint8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<8,8,16>; -+ -+ using ElementA = int8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAddSaturate; -+ using ArchTag = arch::Sm75; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -+ -+ unsigned const & A = reinterpret_cast(a); -+ unsigned const & B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.u8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U8 * U8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<8,8,16>, -+ 32, -+ uint8_t, -+ layout::RowMajor, -+ uint8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<8,8,16>; -+ -+ using ElementA = uint8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAddSaturate; -+ using ArchTag = arch::Sm75; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -+ -+ unsigned const & A = reinterpret_cast(a); -+ unsigned const & B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.u8.u8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Integer matrix multiply (4b) -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: S32 = S4 * S4 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<8,8,32>, -+ 32, -+ int4b_t, -+ layout::RowMajor, -+ int4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<8,8,32>; -+ -+ using ElementA = int4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm75; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -+ -+ unsigned const & A = reinterpret_cast(a); -+ unsigned const & B = reinterpret_cast(b); -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k32.row.col.s32.s4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U4 * S4 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<8,8,32>, -+ 32, -+ uint4b_t, -+ layout::RowMajor, -+ int4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<8,8,32>; -+ -+ using ElementA = uint4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm75; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -+ -+ unsigned const & A = reinterpret_cast(a); -+ unsigned const & B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k32.row.col.s32.u4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = S4 * U4 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<8,8,32>, -+ 32, -+ int4b_t, -+ layout::RowMajor, -+ uint4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<8,8,32>; -+ -+ using ElementA = int4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm75; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -+ -+ unsigned const & A = reinterpret_cast(a); -+ unsigned const & B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k32.row.col.s32.s4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U4 * U4 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<8,8,32>, -+ 32, -+ uint4b_t, -+ layout::RowMajor, -+ uint4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<8,8,32>; -+ -+ using ElementA = uint4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm75; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -+ -+ unsigned const & A = reinterpret_cast(a); -+ unsigned const & B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k32.row.col.s32.u4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Integer matrix multiply (4b) - SATURATE -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: S32 = S4 * S4 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<8,8,32>, -+ 32, -+ int4b_t, -+ layout::RowMajor, -+ int4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<8,8,32>; -+ -+ using ElementA = int4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAddSaturate; -+ using ArchTag = arch::Sm75; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -+ -+ unsigned const & A = reinterpret_cast(a); -+ unsigned const & B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.s4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U4 * S4 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<8,8,32>, -+ 32, -+ uint4b_t, -+ layout::RowMajor, -+ int4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<8,8,32>; -+ -+ using ElementA = uint4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAddSaturate; -+ using ArchTag = arch::Sm75; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -+ -+ unsigned const & A = reinterpret_cast(a); -+ unsigned const & B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = S4 * U4 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<8,8,32>, -+ 32, -+ int4b_t, -+ layout::RowMajor, -+ uint4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<8,8,32>; -+ -+ using ElementA = int4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAddSaturate; -+ using ArchTag = arch::Sm75; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -+ -+ unsigned const & A = reinterpret_cast(a); -+ unsigned const & B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.s4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U4 * U4 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<8,8,32>, -+ 32, -+ uint4b_t, -+ layout::RowMajor, -+ uint4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<8,8,32>; -+ -+ using ElementA = uint4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAddSaturate; -+ using ArchTag = arch::Sm75; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -+ -+ unsigned const & A = reinterpret_cast(a); -+ unsigned const & B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// b1 ^ b1 + s32 => s32 -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation -+template <> -+struct Mma< -+ gemm::GemmShape<8,8,128>, -+ 32, -+ uint1b_t, -+ layout::RowMajor, -+ uint1b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpXorPopc> { -+ -+ using Shape = gemm::GemmShape<8,8,128>; -+ -+ using ElementA = uint1b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint1b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpXorPopc; -+ using ArchTag = arch::Sm75; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -+ -+#if (__CUDA_ARCH__ >= 900) || (defined(CUTLASS_ARCH_WMMA_ENABLED)) -+ using WmmaFragmentA = nvcuda::wmma::fragment< -+ nvcuda::wmma::matrix_a, -+ Shape::kM, -+ Shape::kN, -+ Shape::kK, -+ nvcuda::wmma::experimental::precision::b1, -+ nvcuda::wmma::row_major>; -+ -+ using WmmaFragmentB = nvcuda::wmma::fragment< -+ nvcuda::wmma::matrix_b, -+ Shape::kM, -+ Shape::kN, -+ Shape::kK, -+ nvcuda::wmma::experimental::precision::b1, -+ nvcuda::wmma::col_major>; -+ -+ using WmmaFragmentC = nvcuda::wmma::fragment< -+ nvcuda::wmma::accumulator, -+ Shape::kM, -+ Shape::kN, -+ Shape::kK, -+ int>; -+ -+ WmmaFragmentA const & A = reinterpret_cast(a); -+ WmmaFragmentB const & B = reinterpret_cast(b); -+ -+ WmmaFragmentC const & C = reinterpret_cast(c); -+ WmmaFragmentC & D = reinterpret_cast(d); -+ -+ nvcuda::wmma::bmma_sync(D, A, B, C, nvcuda::wmma::experimental::bmmaBitOpXOR, -+ nvcuda::wmma::experimental::bmmaAccumulateOpPOPC); -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); // WMMA must be supported to issue binary matrix multiply-accumulate instructions. -+ -+#endif // defined(CUTLASS_ARCH_WMMA_ENABLED) -+ -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace arch -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/arch/mma_sm80.h b/3rdparty/cutlass/include/cutlass/arch/mma_sm80.h -new file mode 100644 -index 0000000..8682ae1 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/mma_sm80.h -@@ -0,0 +1,2185 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Matrix multiply -+*/ -+ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+#include "mma.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) -+ -+#define CUTLASS_ARCH_MMA_SM80_SUPPORTED 1 -+ -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) -+#define CUTLASS_ARCH_MMA_SM80_ENABLED -+#endif -+#endif -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace arch { -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Matrix Multiply 1688 - Float BF16, FP32 accumulation -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation - F32 = bf16 * bf16 + F32 -+template <> -+struct Mma< -+ gemm::GemmShape<16, 8, 8>, -+ 32, -+ bfloat16_t, -+ layout::RowMajor, -+ bfloat16_t, -+ layout::ColumnMajor, -+ float, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16, 8, 8>; -+ -+ using ElementA = bfloat16_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = bfloat16_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = float; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, -+ FragmentC const &c) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ float const *C = reinterpret_cast(&c); -+ float *D = reinterpret_cast(&d); -+ -+ asm( -+ "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 " -+ "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" -+ : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) -+ : -+ "r"(A[0]), "r"(A[1]), -+ "r"(B[0]), -+ "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]) -+ ); -+ -+#else -+ -+ CUTLASS_UNUSED(d); -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_NOT_IMPLEMENTED(); -+ -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Matrix Multiply 1684 - Float TF32 -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: F32 = tf32 * tf32 + F32 -+template <> -+struct Mma< -+ gemm::GemmShape<16, 8, 4>, -+ 32, -+ tfloat32_t, -+ layout::RowMajor, -+ tfloat32_t, -+ layout::ColumnMajor, -+ float, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16, 8, 4>; -+ -+ using ElementA = tfloat32_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = tfloat32_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = float; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ float const *C = reinterpret_cast(&c); -+ float *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32 {%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" -+ : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) -+ : -+ "r"(A[0]), "r"(A[1]), -+ "r"(B[0]), -+ "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]) -+ ); -+ -+#else -+ -+ CUTLASS_UNUSED(d); -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_NOT_IMPLEMENTED(); -+ -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Matrix Multiply 1688 - Float TF32 -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: F32 = tf32 * tf32 + F32 -+template <> -+struct Mma, 32, tfloat32_t, layout::RowMajor, -+ tfloat32_t, layout::ColumnMajor, float, layout::RowMajor, -+ OpMultiplyAdd> { -+ using Shape = gemm::GemmShape<16, 8, 8>; -+ -+ using ElementA = tfloat32_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = tfloat32_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = float; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, -+ FragmentC const &c) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ float const *C = reinterpret_cast(&c); -+ float *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 " -+ "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" -+ : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); -+ -+#else -+ -+ CUTLASS_UNUSED(d); -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_NOT_IMPLEMENTED(); -+ -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Matrix Multiply 16816 -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: F16 = F16 * F16 + F16 -+template <> -+struct Mma< -+ gemm::GemmShape<16, 8, 16>, -+ 32, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::ColumnMajor, -+ half_t, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16, 8, 16>; -+ -+ using ElementA = half_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = half_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = half_t; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, -+ FragmentC const &c) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ uint32_t const *C = reinterpret_cast(&c); -+ uint32_t *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), -+ "r"(B[0]), "r"(B[1]), -+ "r"(C[0]), "r"(C[1]) -+ ); -+ -+#else -+ -+ CUTLASS_UNUSED(d); -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_NOT_IMPLEMENTED(); -+ -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: F32 = bf16 * bf16 + F32 -+template <> -+struct Mma< -+ gemm::GemmShape<16, 8, 16>, -+ 32, -+ bfloat16_t, -+ layout::RowMajor, -+ bfloat16_t, -+ layout::ColumnMajor, -+ float, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16, 8, 16>; -+ -+ using ElementA = bfloat16_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = bfloat16_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = float; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ float const *C = reinterpret_cast(&c); -+ float *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " -+ "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" -+ : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); -+ -+#else -+ -+ CUTLASS_UNUSED(d); -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_NOT_IMPLEMENTED(); -+ -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: F32 = F16 * F16 + F32 -+template <> -+struct Mma< -+ gemm::GemmShape<16, 8, 16>, -+ 32, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::ColumnMajor, -+ float, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16, 8, 16>; -+ -+ using ElementA = half_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = half_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = float; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ float const *C = reinterpret_cast(&c); -+ float *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " -+ "{%10,%11,%12,%13};\n" -+ : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); -+ -+#else -+ -+ CUTLASS_UNUSED(d); -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_NOT_IMPLEMENTED(); -+ -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Matrix Multiply 884 - F64 -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: F64 = F64 * F64 + F64 -+template <> -+struct Mma< -+ gemm::GemmShape<8,8,4>, -+ 32, -+ double, -+ layout::RowMajor, -+ double, -+ layout::ColumnMajor, -+ double, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<8,8,4>; -+ -+ using ElementA = double; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = double; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = double; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ -+ using ArchTag = arch::Sm80; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, -+ FragmentC const &c) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ double const & A = reinterpret_cast(a); -+ double const & B = reinterpret_cast(b); -+ -+ double const *C = reinterpret_cast(&c); -+ double *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 {%0,%1}, {%2}, {%3}, {%4,%5};\n" -+ : "=d"(D[0]), "=d"(D[1]) -+ : "d"(A), "d"(B), "d"(C[0]), "d"(C[1])); -+ -+#else -+ -+ CUTLASS_UNUSED(d); -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_NOT_IMPLEMENTED(); -+ -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Matrix Multiply 16816 - S8 input, S32 accumulation -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: S32 = S8 * S8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,16>, -+ 32, -+ int8_t, -+ layout::RowMajor, -+ int8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16,8,16>; -+ -+ using ElementA = int8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const &B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0,%1,%2,%3}, {%4,%5}, {%6}, " -+ "{%7,%8,%9,%10};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), -+ "r"(C[3])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U8 * S8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,16>, -+ 32, -+ uint8_t, -+ layout::RowMajor, -+ int8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16,8,16>; -+ -+ using ElementA = uint8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const &B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32 {%0,%1,%2,%3}, {%4,%5}, {%6}, " -+ "{%7,%8,%9,%10};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), -+ "r"(C[3])); -+ -+#else -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = S8 * U8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,16>, -+ 32, -+ int8_t, -+ layout::RowMajor, -+ uint8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16,8,16>; -+ -+ using ElementA = int8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const &B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32 {%0,%1,%2,%3}, {%4,%5}, {%6}, " -+ "{%7,%8,%9,%10};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), -+ "r"(C[3])); -+ -+#else -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U8 * U8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,16>, -+ 32, -+ uint8_t, -+ layout::RowMajor, -+ uint8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16,8,16>; -+ -+ using ElementA = uint8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const &B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32 {%0,%1,%2,%3}, {%4,%5}, {%6}, " -+ "{%7,%8,%9,%10};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), -+ "r"(C[3])); -+ -+ -+#else -+ assert(0); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Matrix Multiply 16816 - S8 input, S32 accumulation - SATURATE -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: S32 = S8 * S8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,16>, -+ 32, -+ int8_t, -+ layout::RowMajor, -+ int8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<16,8,16>; -+ -+ using ElementA = int8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAddSaturate; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const &B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " -+ "{%6}, {%7,%8,%9,%10};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), -+ "r"(C[3])); -+ -+#else -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U8 * S8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,16>, -+ 32, -+ uint8_t, -+ layout::RowMajor, -+ int8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<16,8,16>; -+ -+ using ElementA = uint8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAddSaturate; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const &B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " -+ "{%6}, {%7,%8,%9,%10};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), -+ "r"(C[3])); -+ -+#else -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = S8 * U8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,16>, -+ 32, -+ int8_t, -+ layout::RowMajor, -+ uint8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<16,8,16>; -+ -+ using ElementA = int8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAddSaturate; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const &B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " -+ "{%6}, {%7,%8,%9,%10};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), -+ "r"(C[3])); -+ -+#else -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U8 * U8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,16>, -+ 32, -+ uint8_t, -+ layout::RowMajor, -+ uint8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<16,8,16>; -+ -+ using ElementA = uint8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAddSaturate; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const &B = reinterpret_cast(b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " -+ "{%6}, {%7,%8,%9,%10};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), -+ "r"(C[3])); -+ -+#else -+ assert(0); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Matrix Multiply 16832 - S8 input, S32 accumulation -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: S32 = S8 * S8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,32>, -+ 32, -+ int8_t, -+ layout::RowMajor, -+ int8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16,8,32>; -+ -+ using ElementA = int8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9}, {%10,%11,%12,%13};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); -+ -+#else -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U8 * S8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,32>, -+ 32, -+ uint8_t, -+ layout::RowMajor, -+ int8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16,8,32>; -+ -+ using ElementA = uint8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9}, {%10,%11,%12,%13};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); -+ -+#else -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = S8 * U8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,32>, -+ 32, -+ int8_t, -+ layout::RowMajor, -+ uint8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16,8,32>; -+ -+ using ElementA = int8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9}, {%10,%11,%12,%13};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); -+ -+#else -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U8 * U8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,32>, -+ 32, -+ uint8_t, -+ layout::RowMajor, -+ uint8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16,8,32>; -+ -+ using ElementA = uint8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9}, {%10,%11,%12,%13};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); -+ -+#else -+ assert(0); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Matrix Multiply 16832 - S8 input, S32 accumulation - SATURATE -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: S32 = S8 * S8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,32>, -+ 32, -+ int8_t, -+ layout::RowMajor, -+ int8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<16,8,32>; -+ -+ using ElementA = int8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const * A = reinterpret_cast(&a); -+ uint32_t const * B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, " -+ "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); -+ -+#else -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U8 * S8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,32>, -+ 32, -+ uint8_t, -+ layout::RowMajor, -+ int8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<16,8,32>; -+ -+ using ElementA = uint8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAddSaturate; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, " -+ "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); -+ -+#else -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = S8 * U8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,32>, -+ 32, -+ int8_t, -+ layout::RowMajor, -+ uint8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<16,8,32>; -+ -+ using ElementA = int8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, " -+ "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); -+ -+#else -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U8 * U8 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,32>, -+ 32, -+ uint8_t, -+ layout::RowMajor, -+ uint8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<16,8,32>; -+ -+ using ElementA = uint8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAddSaturate; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, " -+ "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); -+ -+#else -+ assert(0); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Matrix Multiply 16864 - S4 input, S32 accumulation -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: S32 = S4 * S4 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16, 8, 64>, -+ 32, -+ cutlass::int4b_t, -+ layout::RowMajor, -+ cutlass::int4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16, 8, 64>; -+ -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9}, {%10,%11,%12,%13};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U4 * S4 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16, 8, 64>, -+ 32, -+ cutlass::uint4b_t, -+ layout::RowMajor, -+ cutlass::int4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16, 8, 64>; -+ -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9}, {%10,%11,%12,%13};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = S4 * U4 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16, 8, 64>, -+ 32, -+ cutlass::int4b_t, -+ layout::RowMajor, -+ cutlass::uint4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16, 8, 64>; -+ -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9}, {%10,%11,%12,%13};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U4 * U4 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16, 8, 64>, -+ 32, -+ cutlass::uint4b_t, -+ layout::RowMajor, -+ cutlass::uint4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16, 8, 64>; -+ -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9}, {%10,%11,%12,%13};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Matrix Multiply 16864 - S4 input, S32 accumulation - SATURATE -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: S32 = S4 * S4 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16, 8, 64>, -+ 32, -+ cutlass::int4b_t, -+ layout::RowMajor, -+ cutlass::int4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<16, 8, 64>; -+ -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const * A = reinterpret_cast(&a); -+ uint32_t const * B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32.satfinite {%0,%1,%2,%3}, " -+ "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U4 * S4 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16, 8, 64>, -+ 32, -+ cutlass::uint4b_t, -+ layout::RowMajor, -+ cutlass::int4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<16, 8, 64>; -+ -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAddSaturate; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32.satfinite {%0,%1,%2,%3}, " -+ "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = S4 * U4 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16, 8, 64>, -+ 32, -+ cutlass::int4b_t, -+ layout::RowMajor, -+ cutlass::uint4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<16, 8, 64>; -+ -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32.satfinite {%0,%1,%2,%3}, " -+ "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U4 * U4 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16, 8, 64>, -+ 32, -+ cutlass::uint4b_t, -+ layout::RowMajor, -+ cutlass::uint4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate> { -+ -+ using Shape = gemm::GemmShape<16, 8, 64>; -+ -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAddSaturate; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32.satfinite {%0,%1,%2,%3}, " -+ "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = B1 & B1 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,256>, -+ 32, -+ cutlass::uint1b_t, -+ layout::RowMajor, -+ cutlass::uint1b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16,8,256>; -+ -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int32_t; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc {%0,%1,%2,%3}, " -+ "{%4,%5,%6,%7}, " -+ "{%8,%9}, {%10,%11,%12,%13};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); -+ -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Matrix Multiply 168256 - B1 input, S32 accumulation - XOR,POPC -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: S32 = B1 & B1 + S32 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,256>, -+ 32, -+ cutlass::uint1b_t, -+ layout::RowMajor, -+ cutlass::uint1b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpXorPopc> { -+ -+ using Shape = gemm::GemmShape<16,8,256>; -+ -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpXorPopc; -+ using ArchTag = arch::Sm80; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c -+ ) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ asm volatile( -+ "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc {%0,%1,%2,%3}, " -+ "{%4,%5,%6,%7}, " -+ "{%8,%9}, {%10,%11,%12,%13};\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+ -+#endif // defined(CUTLASS_ARCH_MMA_SM80_ENABLED) -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace arch -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/arch/mma_sm90.h b/3rdparty/cutlass/include/cutlass/arch/mma_sm90.h -new file mode 100644 -index 0000000..1d0745b ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/mma_sm90.h -@@ -0,0 +1,264 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Matrix multiply -+*/ -+ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include "mma.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 8)) -+ #define CUTLASS_ARCH_MMA_SM90_F64_MMA_SUPPORTED -+ #if (!defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED)) -+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) -+ #define CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED -+ #endif -+ #endif -+#endif -+ -+#if (__CUDACC_VER_MAJOR__ >= 12) -+ #define CUTLASS_ARCH_MMA_SM90_SUPPORTED -+ #if (!defined(CUTLASS_ARCH_MMA_SM90_ENABLED)) -+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) -+ #define CUTLASS_ARCH_MMA_SM90_ENABLED -+ #endif -+ #endif -+#endif -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace arch { -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// Matrix Multiply-Add 16x8x4 fp64 -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: F64 = F64 * F64 + F64 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,4>, -+ 32, -+ double, -+ layout::RowMajor, -+ double, -+ layout::ColumnMajor, -+ double, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16,8,4>; -+ -+ using ElementA = double; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = double; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = double; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ -+ using ArchTag = arch::Sm90; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, -+ FragmentC const &c) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+ double const *A = reinterpret_cast(&a); -+ double const *B = reinterpret_cast(&b); -+ -+ double const *C = reinterpret_cast(&c); -+ double *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64.rn {%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" -+ : "=d"(D[0]), "=d"(D[1]), "=d"(D[2]), "=d"(D[3]) -+ : "d"(A[0]), "d"(A[1]), -+ "d"(B[0]), -+ "d"(C[0]), "d"(C[1]), "d"(C[2]), "d"(C[3])); -+ -+#else -+ CUTLASS_UNUSED(d); -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_NOT_IMPLEMENTED(); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// Matrix Multiply-Add 16x8x8 fp64 -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: F64 = F64 * F64 + F64 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,8>, -+ 32, -+ double, -+ layout::RowMajor, -+ double, -+ layout::ColumnMajor, -+ double, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16,8,8>; -+ -+ using ElementA = double; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = double; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = double; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ -+ using ArchTag = arch::Sm90; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, -+ FragmentC const &c) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+ double const *A = reinterpret_cast(&a); -+ double const *B = reinterpret_cast(&b); -+ -+ double const *C = reinterpret_cast(&c); -+ double *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m16n8k8.row.col.f64.f64.f64.f64 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" -+ : "=d"(D[0]), "=d"(d[1]), "=d"(d[2]), "=d"(d[3]) -+ : "d"(A[0]), "d"(A[1]), "d"(A[2]), "d"(A[3]), -+ "d"(B[0]), "d"(B[1]), -+ "d"(C[0]), "d"(C[1]), "d"(C[2]), "d"(C[3])); -+ -+#else -+ -+ CUTLASS_UNUSED(d); -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_NOT_IMPLEMENTED(); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// Matrix Multiply-Add 16x8x16 fp64 -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: F64 = F64 * F64 + F64 -+template <> -+struct Mma< -+ gemm::GemmShape<16,8,16>, -+ 32, -+ double, -+ layout::RowMajor, -+ double, -+ layout::ColumnMajor, -+ double, -+ layout::RowMajor, -+ OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<16,8,16>; -+ -+ using ElementA = double; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = double; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = double; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using Operator = OpMultiplyAdd; -+ -+ using ArchTag = arch::Sm90; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, -+ FragmentC const &c) const { -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+ double const *A = reinterpret_cast(&a); -+ double const *B = reinterpret_cast(&b); -+ -+ double const *C = reinterpret_cast(&c); -+ double *D = reinterpret_cast(&d); -+ -+ asm volatile("mma.sync.aligned.m16n8k16.row.col.f64.f64.f64.f64 {%0, %1, %2, %3}, {%4, %5, %6, %7, %8, %9, %10, %11}, {%12, %13, %14, %15}, {%16, %17, %18, %19};\n" -+ : "=d"(D[0]), "=d"(D[1]), "=d"(D[2]), "=d"(D[3]) -+ : "d"(A[0]), "d"(A[2]), "d"(A[2]), "d"(A[3]), "d"(A[4]), "d"(A[5]), "d"(A[6]), "d"(A[7]) -+ "d"(B[0]), "d"(B[1]), "d"(B[2]), "d"(B[3]), -+ "d"(C[0]), "d"(C[1]), "d"(C[2]), "d"(C[3])); -+ -+#else -+ CUTLASS_NOT_IMPLEMENTED(); -+#endif -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace arch -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/include/cutlass/arch/mma_sparse_sm80.h b/3rdparty/cutlass/include/cutlass/arch/mma_sparse_sm80.h -new file mode 100644 -index 0000000..a1f5b1d ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/mma_sparse_sm80.h -@@ -0,0 +1,1685 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Sparse matrix multiply accumulate for SM80 -+*/ -+ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include "mma.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 1)) -+ -+#define CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED 1 -+ -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) -+#define CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED -+#endif -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace arch { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Sparse Matrix Multiply 16832 -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: F16 = F16 * F16 + F16 -+template <> -+struct SparseMma< -+ gemm::GemmShape<16, 8, 32>, -+ 32, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::ColumnMajor, -+ half_t, -+ layout::RowMajor, -+ OpMultiplyAdd, -+ SPFormatType::Thread -+> { -+ -+ using Shape = gemm::GemmShape<16, 8, 32>; -+ -+ using ElementA = half_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = half_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = half_t; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 2; -+ -+ static int const kMaxID2 = 2; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, -+ FragmentC const &c, uint32_t const &E, int const id2) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ uint32_t const *C = reinterpret_cast(&c); -+ uint32_t *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) { -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {%0,%1}, " -+ "{%2,%3,%4,%5}, {%6,%7,%8,%9}, {%10,%11}, %12, 0x0;\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(B[2]), "r"(B[3]), "r"(C[0]), "r"(C[1]), "r"(E)); -+ } -+ else if (id2 == 1) { -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {%0,%1}, " -+ "{%2,%3,%4,%5}, {%6,%7,%8,%9}, {%10,%11}, %12, 0x1;\n" -+ : "=r"(D[0]), "=r"(D[1]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(B[2]), "r"(B[3]), "r"(C[0]), "r"(C[1]), "r"(E)); -+ } -+ else { -+ assert(0); -+ } -+#else -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: F32 = F16 * F16 + F32 -+template <> -+struct SparseMma< -+ gemm::GemmShape<16, 8, 32>, -+ 32, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::ColumnMajor, -+ float, -+ layout::RowMajor, -+ OpMultiplyAdd, -+ SPFormatType::Thread -+ > { -+ -+ using Shape = gemm::GemmShape<16, 8, 32>; -+ -+ using ElementA = half_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = half_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = float; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 2; -+ -+ static int const kMaxID2 = 2; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, -+ FragmentC const &c, uint32_t const &E, int const id2) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ float const *C = reinterpret_cast(&c); -+ float *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) { -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, " -+ "{%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" -+ : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(B[2]), "r"(B[3]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), -+ "r"(E)); -+ } -+ else if (id2 == 1) { -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, " -+ "{%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n" -+ : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), -+ "r"(B[2]), "r"(B[3]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), -+ "r"(E)); -+ } -+ else { -+ assert(0); -+ } -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Sparse Matrix Multiply 16832 - Float BF16, FP32 accumulation -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: F32 = bf16 * bf16 + F32 -+template <> -+struct SparseMma, 32, bfloat16_t, layout::RowMajor, -+ bfloat16_t, layout::ColumnMajor, float, layout::RowMajor, -+ OpMultiplyAdd, SPFormatType::Thread> { -+ using Shape = gemm::GemmShape<16, 8, 32>; -+ -+ using ElementA = bfloat16_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = bfloat16_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = float; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 2; -+ -+ static int const kMaxID2 = 2; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, -+ FragmentC const &c, uint32_t const &E, int const id2) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ float const *C = reinterpret_cast(&c); -+ float *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) { -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32 " -+ "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" -+ : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); -+ } else if (id2 == 1) { -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32 " -+ "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n" -+ : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); -+ } else { -+ assert(0); -+ } -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Sparse Matrix Multiply 16816 - Float TF32 -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: F32 = tf32 * tf32 + F32 -+template <> -+struct SparseMma, 32, tfloat32_t, layout::RowMajor, -+ tfloat32_t, layout::ColumnMajor, float, layout::RowMajor, -+ OpMultiplyAdd, SPFormatType::Thread> { -+ using Shape = gemm::GemmShape<16, 8, 16>; -+ -+ using ElementA = tfloat32_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = tfloat32_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = float; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 4; -+ -+ static int const kMaxID2 = 2; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, -+ FragmentC const &c, uint32_t const &E, int const id2) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ float const *C = reinterpret_cast(&c); -+ float *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) { -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k16.row.col.f32.tf32.tf32.f32 " -+ "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" -+ : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); -+ } else if (id2 == 1) { -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k16.row.col.f32.tf32.tf32.f32 " -+ "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n" -+ : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); -+ } else { -+ assert(0); -+ } -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Sparse Matrix Multiply 16864 - S8 input, S32 accumulation -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: S32 = S8 * S8 + S32 -+template <> -+struct SparseMma< -+ gemm::GemmShape<16,8,64>, -+ 32, -+ int8_t, -+ layout::RowMajor, -+ int8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd, -+ SPFormatType::Thread> { -+ -+ using Shape = gemm::GemmShape<16,8,64>; -+ -+ using ElementA = int8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 2; -+ -+ static int const kMaxID2 = 1; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c, -+ uint32_t const &E, -+ int const id2 -+ ) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k64.row.col.s32.s8.s8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); -+ else -+ assert(0); -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = S8 * U8 + S32 -+template <> -+struct SparseMma< -+ gemm::GemmShape<16,8,64>, -+ 32, -+ int8_t, -+ layout::RowMajor, -+ uint8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd, -+ SPFormatType::Thread> { -+ -+ using Shape = gemm::GemmShape<16,8,64>; -+ -+ using ElementA = int8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 2; -+ -+ static int const kMaxID2 = 1; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c, -+ uint32_t const &E, -+ int const id2 -+ ) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k64.row.col.s32.s8.u8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); -+ else -+ assert(0); -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U8 * S8 + S32 -+template <> -+struct SparseMma< -+ gemm::GemmShape<16,8,64>, -+ 32, -+ uint8_t, -+ layout::RowMajor, -+ int8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd, -+ SPFormatType::Thread> { -+ -+ using Shape = gemm::GemmShape<16,8,64>; -+ -+ using ElementA = uint8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 2; -+ -+ static int const kMaxID2 = 1; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c, -+ uint32_t const &E, -+ int const id2 -+ ) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k64.row.col.s32.u8.s8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); -+ else -+ assert(0); -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U8 * U8 + S32 -+template <> -+struct SparseMma< -+ gemm::GemmShape<16,8,64>, -+ 32, -+ uint8_t, -+ layout::RowMajor, -+ uint8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd, -+ SPFormatType::Thread> { -+ -+ using Shape = gemm::GemmShape<16,8,64>; -+ -+ using ElementA = uint8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 2; -+ -+ static int const kMaxID2 = 1; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c, -+ uint32_t const &E, -+ int const id2 -+ ) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k64.row.col.s32.u8.u8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); -+ else -+ assert(0); -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Sparse Matrix Multiply 16864 - S8 input, S32 accumulation - SATURATE -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: S32 = S8 * S8 + S32 -+template <> -+struct SparseMma< -+ gemm::GemmShape<16,8,64>, -+ 32, -+ int8_t, -+ layout::RowMajor, -+ int8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate, -+ SPFormatType::Thread> { -+ -+ using Shape = gemm::GemmShape<16,8,64>; -+ -+ using ElementA = int8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 2; -+ -+ static int const kMaxID2 = 1; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c, -+ uint32_t const &E, -+ int const id2 -+ ) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k64.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); -+ else -+ assert(0); -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = S8 * U8 + S32 -+template <> -+struct SparseMma< -+ gemm::GemmShape<16,8,64>, -+ 32, -+ int8_t, -+ layout::RowMajor, -+ uint8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate, -+ SPFormatType::Thread> { -+ -+ using Shape = gemm::GemmShape<16,8,64>; -+ -+ using ElementA = int8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 2; -+ -+ static int const kMaxID2 = 1; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c, -+ uint32_t const &E, -+ int const id2 -+ ) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k64.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); -+ else -+ assert(0); -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U8 * S8 + S32 -+template <> -+struct SparseMma< -+ gemm::GemmShape<16,8,64>, -+ 32, -+ uint8_t, -+ layout::RowMajor, -+ int8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate, -+ SPFormatType::Thread> { -+ -+ using Shape = gemm::GemmShape<16,8,64>; -+ -+ using ElementA = uint8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = int8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 2; -+ -+ static int const kMaxID2 = 1; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c, -+ uint32_t const &E, -+ int const id2 -+ ) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k64.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); -+ else -+ assert(0); -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U8 * U8 + S32 -+template <> -+struct SparseMma< -+ gemm::GemmShape<16,8,64>, -+ 32, -+ uint8_t, -+ layout::RowMajor, -+ uint8_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate, -+ SPFormatType::Thread> { -+ -+ using Shape = gemm::GemmShape<16,8,64>; -+ -+ using ElementA = uint8_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = uint8_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 2; -+ -+ static int const kMaxID2 = 1; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c, -+ uint32_t const &E, -+ int const id2 -+ ) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k64.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); -+ else -+ assert(0); -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Sparse Matrix Multiply 168128 - S4 input, S32 accumulation -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: S32 = S4 * S4 + S32 -+template <> -+struct SparseMma< -+ gemm::GemmShape<16,8,128>, -+ 32, -+ cutlass::int4b_t, -+ layout::RowMajor, -+ cutlass::int4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd, -+ SPFormatType::Thread> { -+ -+ using Shape = gemm::GemmShape<16,8,128>; -+ -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 2; -+ -+ static int const kMaxID2 = 1; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c, -+ uint32_t const &E, -+ int const id2 -+ ) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k128.row.col.s32.s4.s4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); -+ else -+ assert(0); -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = S4 * U4 + S32 -+template <> -+struct SparseMma< -+ gemm::GemmShape<16,8,128>, -+ 32, -+ cutlass::int4b_t, -+ layout::RowMajor, -+ cutlass::uint4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd, -+ SPFormatType::Thread> { -+ -+ using Shape = gemm::GemmShape<16,8,128>; -+ -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 2; -+ -+ static int const kMaxID2 = 1; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c, -+ uint32_t const &E, -+ int const id2 -+ ) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k128.row.col.s32.s4.u4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); -+ else -+ assert(0); -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U4 * S4 + S32 -+template <> -+struct SparseMma< -+ gemm::GemmShape<16,8,128>, -+ 32, -+ cutlass::uint4b_t, -+ layout::RowMajor, -+ cutlass::int4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd, -+ SPFormatType::Thread> { -+ -+ using Shape = gemm::GemmShape<16,8,128>; -+ -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 2; -+ -+ static int const kMaxID2 = 1; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c, -+ uint32_t const &E, -+ int const id2 -+ ) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k128.row.col.s32.u4.s4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); -+ else -+ assert(0); -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U4 * U4 + S32 -+template <> -+struct SparseMma< -+ gemm::GemmShape<16,8,128>, -+ 32, -+ cutlass::uint4b_t, -+ layout::RowMajor, -+ cutlass::uint4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAdd, -+ SPFormatType::Thread> { -+ -+ using Shape = gemm::GemmShape<16,8,128>; -+ -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 2; -+ -+ static int const kMaxID2 = 1; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c, -+ uint32_t const &E, -+ int const id2 -+ ) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k128.row.col.s32.u4.u4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); -+ else -+ assert(0); -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// Sparse Matrix Multiply 168128 - S4 input, S32 accumulation - SATURATE -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Matrix multiply-add operation: S32 = S4 * S4 + S32 -+template <> -+struct SparseMma< -+ gemm::GemmShape<16,8,128>, -+ 32, -+ cutlass::int4b_t, -+ layout::RowMajor, -+ cutlass::int4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate, -+ SPFormatType::Thread> { -+ -+ using Shape = gemm::GemmShape<16,8,128>; -+ -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 2; -+ -+ static int const kMaxID2 = 1; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c, -+ uint32_t const &E, -+ int const id2 -+ ) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k128.row.col.s32.s4.s4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); -+ else -+ assert(0); -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = S4 * U4 + S32 -+template <> -+struct SparseMma< -+ gemm::GemmShape<16,8,128>, -+ 32, -+ cutlass::int4b_t, -+ layout::RowMajor, -+ cutlass::uint4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate, -+ SPFormatType::Thread> { -+ -+ using Shape = gemm::GemmShape<16,8,128>; -+ -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 2; -+ -+ static int const kMaxID2 = 1; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c, -+ uint32_t const &E, -+ int const id2 -+ ) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k128.row.col.s32.s4.u4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); -+ else -+ assert(0); -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U4 * S4 + S32 -+template <> -+struct SparseMma< -+ gemm::GemmShape<16,8,128>, -+ 32, -+ cutlass::uint4b_t, -+ layout::RowMajor, -+ cutlass::int4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate, -+ SPFormatType::Thread> { -+ -+ using Shape = gemm::GemmShape<16,8,128>; -+ -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 2; -+ -+ static int const kMaxID2 = 1; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c, -+ uint32_t const &E, -+ int const id2 -+ ) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k128.row.col.s32.u4.s4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); -+ else -+ assert(0); -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+/// Matrix multiply-add operation: S32 = U4 * U4 + S32 -+template <> -+struct SparseMma< -+ gemm::GemmShape<16,8,128>, -+ 32, -+ cutlass::uint4b_t, -+ layout::RowMajor, -+ cutlass::uint4b_t, -+ layout::ColumnMajor, -+ int, -+ layout::RowMajor, -+ OpMultiplyAddSaturate, -+ SPFormatType::Thread> { -+ -+ using Shape = gemm::GemmShape<16,8,128>; -+ -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = layout::RowMajor; -+ using FragmentA = Array; -+ -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = layout::ColumnMajor; -+ using FragmentB = Array; -+ -+ using ElementC = int; -+ using LayoutC = layout::RowMajor; -+ using FragmentC = Array; -+ -+ using FragmentE = uint32_t; -+ -+ using Operator = OpMultiplyAdd; -+ using ArchTag = arch::Sm80; -+ -+ static int const kSparse = 2; -+ -+ static int const kMetaSizeInBits = 2; -+ -+ static int const kMaxID2 = 1; -+ -+ /// Computes multiply-add -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA const &a, -+ FragmentB const &b, -+ FragmentC const &c, -+ uint32_t const &E, -+ int const id2 -+ ) const { -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) -+ -+ uint32_t const *A = reinterpret_cast(&a); -+ uint32_t const *B = reinterpret_cast(&b); -+ -+ int const *C = reinterpret_cast(&c); -+ int *D = reinterpret_cast(&d); -+ -+ if (id2 == 0) -+ asm volatile( -+ "mma.sp.sync.aligned.m16n8k128.row.col.s32.u4.u4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " -+ "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" -+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) -+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), -+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); -+ else -+ assert(0); -+ -+#else -+ -+ CUTLASS_UNUSED(a); -+ CUTLASS_UNUSED(b); -+ CUTLASS_UNUSED(c); -+ CUTLASS_UNUSED(d); -+ assert(0); -+#endif -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace arch -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/arch/reg_reconfig.h b/3rdparty/cutlass/include/cutlass/arch/reg_reconfig.h -new file mode 100644 -index 0000000..2b74a22 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/reg_reconfig.h -@@ -0,0 +1,68 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief PTX for CTA Reconfiguration -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)) -+ #if (defined(__CUDA_ARCH_FEAT_SM90_ALL)) -+ #define CUDA_CTA_RECONFIG_ACTIVATED 1 -+ #endif -+#else -+ #define CUDA_CTA_RECONFIG_ACTIVATED 0 -+#endif -+ -+namespace cutlass { -+namespace arch { -+ -+template -+CUTLASS_DEVICE -+void warpgroup_reg_alloc(){ -+#if CUDA_CTA_RECONFIG_ACTIVATED -+ asm volatile( "setmaxnreg.inc.sync.aligned.u32 %0;\n" : : "n"(RegCount) ); -+#endif -+} -+ -+template -+CUTLASS_DEVICE -+void warpgroup_reg_dealloc(){ -+#if CUDA_CTA_RECONFIG_ACTIVATED -+ asm volatile( "setmaxnreg.dec.sync.aligned.u32 %0;\n" : : "n"(RegCount) ); -+#endif -+} -+ -+} // namespace arch -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/arch/simd.h b/3rdparty/cutlass/include/cutlass/arch/simd.h -new file mode 100644 -index 0000000..71128c2 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/simd.h -@@ -0,0 +1,125 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates exposing SIMD operators -+*/ -+ -+#pragma once -+ -+#include "../array.h" -+#include "../numeric_types.h" -+ -+namespace cutlass { -+namespace arch { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Element-wise operators -+// -+ -+CUTLASS_HOST_DEVICE -+template -+Array operator*(Array const &a, Array const &b) { -+ Array d; -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ d[i] = a[i] * b[i]; -+ } -+ return d; -+} -+ -+CUTLASS_HOST_DEVICE -+template -+Array operator+(Array const &a, Array const &b) { -+ Array d; -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ d[i] = a[i] + b[i]; -+ } -+ return d; -+} -+ -+CUTLASS_HOST_DEVICE -+template -+Array operator-(Array const &a, Array const &b) { -+ Array d; -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ d[i] = a[i] - b[i]; -+ } -+ return d; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Multiply-accumulate operators -+// -+ -+CUTLASS_HOST_DEVICE -+template -+Array mac(Array const &a, Array const &b, Array const &c) { -+ Array d; -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ d[i] = a[i] * b[i] + c[i]; -+ } -+ return d; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Dot product operator -+// -+ -+CUTLASS_HOST_DEVICE -+template -+Accumulator dot(Array const &a, Array const &b, Accumulator accum) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ accum += a[i] * b[i]; -+ } -+ return accum; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace arch -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include "simd_sm60.h" -+#include "simd_sm61.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/arch/simd_sm60.h b/3rdparty/cutlass/include/cutlass/arch/simd_sm60.h -new file mode 100644 -index 0000000..16d528b ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/simd_sm60.h -@@ -0,0 +1,116 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates exposing SIMD operators for SM60 -+*/ -+ -+#pragma once -+ -+#include "simd.h" -+ -+namespace cutlass { -+namespace arch { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Element-wise operators - specialized for half_t x 2 -+// -+ -+CUTLASS_HOST_DEVICE -+template <> -+Array operator*(Array const &a, Array const &b) { -+ Array d; -+ -+ // TODO -+ -+ return d; -+} -+ -+CUTLASS_HOST_DEVICE -+template <> -+Array operator+(AArray const &a, Array const &b) { -+ Array d; -+ -+ // TODO -+ -+ return d; -+} -+ -+CUTLASS_HOST_DEVICE -+template <> -+Array operator-(Array const &a, Array const &b) { -+ Array d; -+ -+ // TODO -+ -+ return d; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Multiply-accumulate operators - specialized for half_t x 2 -+CUTLASS_HOST_DEVICE -+template <> -+Array mac(Array const &a, Array const &b, Array const &c) { -+ Array d; -+ -+ // TODO -+ -+ return d; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Dot product operator - specialized for half_t <- (half_t * half_t) x 2 + half_t -+CUTLASS_HOST_DEVICE -+template <> -+half_t dot(Array const &a, Array const &b, half_t accum) { -+ -+ // TODO -+ -+ return accum; -+} -+ -+/// Dot product operator - specialized for float <- (half_t * half_t) x 2 + float -+CUTLASS_HOST_DEVICE -+template <> -+float dot(Array const &a, Array const &b, float accum) { -+ -+ // TODO -+ -+ return accum; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace arch -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/arch/simd_sm61.h b/3rdparty/cutlass/include/cutlass/arch/simd_sm61.h -new file mode 100644 -index 0000000..ba9abb7 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/simd_sm61.h -@@ -0,0 +1,147 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates exposing SIMD operators for SM61 -+*/ -+ -+#pragma once -+ -+#include "simd.h" -+ -+namespace cutlass { -+namespace arch { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Dot product operator - specialized for int32_t <- (int8_t * int8_t) x 4 + int32_t -+CUTLASS_HOST_DEVICE -+template <> -+int32_t dot(Array const &a, Array const &b, int32_t accum) { -+ -+ return accum; -+} -+ -+/// Dot product operator - specialized for int32_t <- (uint8_t * int8_t) x 4 + int32_t -+CUTLASS_HOST_DEVICE -+template <> -+int32_t dot(Array const &a, Array const &b, int32_t accum) { -+ -+ return accum; -+} -+ -+/// Dot product operator - specialized for int32_t <- (int8_t * uint8_t) x 4 + int32_t -+CUTLASS_HOST_DEVICE -+template <> -+int32_t dot(Array const &a, Array const &b, int32_t accum) { -+ -+ return accum; -+} -+ -+/// Dot product operator - specialized for int32_t <- (uint8_t * uint8_t) x 4 + int32_t -+CUTLASS_HOST_DEVICE -+template <> -+int32_t dot(Array const &a, Array const &b, int32_t accum) { -+ -+ return accum; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Dot product operator - specialized for int32_t <- (int16_t * int8_t) x 2 + int32_t -+CUTLASS_HOST_DEVICE -+template <> -+int32_t dot(Array const &a, Array const &b, int32_t accum) { -+ -+ return accum; -+} -+ -+/// Dot product operator - specialized for int32_t <- (uint16_t * int8_t) x 2 + int32_t -+CUTLASS_HOST_DEVICE -+template <> -+int32_t dot(Array const &a, Array const &b, int32_t accum) { -+ -+ return accum; -+} -+ -+/// Dot product operator - specialized for int32_t <- (int16_t * int8_t) x 2 + int32_t -+CUTLASS_HOST_DEVICE -+template <> -+int32_t dot(Array const &a, Array const &b, int32_t accum) { -+ -+ return accum; -+} -+ -+/// Dot product operator - specialized for int32_t <- (uint16_t * int8_t) x 2 + int32_t -+CUTLASS_HOST_DEVICE -+template <> -+int32_t dot(Array const &a, Array const &b, int32_t accum) { -+ -+ return accum; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Dot product operator - specialized for int32_t <- (int16_t * int16_t) x 2 + int32_t -+CUTLASS_HOST_DEVICE -+template <> -+int32_t dot(Array const &a, Array const &b, int32_t accum) { -+ -+ return accum; -+} -+ -+/// Dot product operator - specialized for int32_t <- (uint16_t * int16_t) x 2 + int32_t -+CUTLASS_HOST_DEVICE -+template <> -+int32_t dot(Array const &a, Array const &b, int32_t accum) { -+ -+ return accum; -+} -+ -+/// Dot product operator - specialized for int32_t <- (int16_t * int16_t) x 2 + int32_t -+CUTLASS_HOST_DEVICE -+template <> -+int32_t dot(Array const &a, Array const &b, int32_t accum) { -+ -+ return accum; -+} -+ -+/// Dot product operator - specialized for int32_t <- (uint16_t * int16_t) x 2 + int32_t -+CUTLASS_HOST_DEVICE -+template <> -+int32_t dot(Array const &a, Array const &b, int32_t accum) { -+ -+ return accum; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace arch -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/arch/wmma.h b/3rdparty/cutlass/include/cutlass/arch/wmma.h -new file mode 100644 -index 0000000..db54e45 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/wmma.h -@@ -0,0 +1,223 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates exposing architecture support for warp matrix multiply-add (WMMA) operations -+*/ -+ -+#pragma once -+ -+// CUTLASS WMMA does not support clang at present. -+#if !(defined(__clang__) && defined(__CUDA__)) -+ -+#if (__CUDACC_VER_MAJOR__ >= 9) -+#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 700)) -+#define CUTLASS_ARCH_WMMA_ENABLED -+#define CUTLASS_ARCH_WMMA_SM70_ENABLED -+#endif -+#endif -+ -+#if (__CUDACC_VER_MAJOR__ >= 10) -+#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 720)) -+#define CUTLASS_ARCH_INTEGER_MATRIX_MULTIPLY_ENABLED -+#define CUTLASS_ARCH_WMMA_SM72_ENABLED -+#endif -+#endif -+ -+#if (__CUDACC_VER_MAJOR__ >= 10) -+#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 750)) -+#define CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED -+#define CUTLASS_ARCH_WMMA_SM75_ENABLED -+#endif -+#endif -+ -+#endif //!(defined(__clang__) && defined(__CUDA__)) -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+ -+#include -+#include "cutlass/arch/mma.h" -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/gemm/gemm.h" -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace arch { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////// -+/// Statically maps cutlass data types => nvcuda::wmma data types -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template -+struct CutlassToWmmaDataType{ -+ using Type = Type_; -+}; -+ -+/// Statically maps cutlass::half_t => __half -+template<> -+struct CutlassToWmmaDataType { -+ using Type = __half; -+}; -+ -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11) -+template<> -+struct CutlassToWmmaDataType { -+ using Type = __nv_bfloat16; -+}; -+#endif -+ -+/// Statically maps int8_t => char -+template<> -+struct CutlassToWmmaDataType { -+ using Type = signed char; -+}; -+ -+/// Statically maps uint8_t => char -+template<> -+struct CutlassToWmmaDataType { -+ using Type = unsigned char; -+}; -+ -+/// Statically maps int32_t => int -+template<> -+struct CutlassToWmmaDataType { -+ using Type = int; -+}; -+ -+#if defined(CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED) -+/// Statically maps cutlass::int4b_t => experimental::precision::s4 -+template<> -+struct CutlassToWmmaDataType { -+ using Type = nvcuda::wmma::experimental::precision::s4; -+}; -+ -+/// Statically maps cutlass::uint4b_t => experimental::precision::s4 -+template<> -+struct CutlassToWmmaDataType { -+ using Type = nvcuda::wmma::experimental::precision::u4; -+}; -+ -+/// Statically maps cutlass::uint1b_t => experimental::precision::b1 -+template<> -+struct CutlassToWmmaDataType { -+ using Type = nvcuda::wmma::experimental::precision::b1; -+}; -+#endif -+ -+//////////////////////////////////////////////////////////////////////////////////////////////// -+/// Statically maps cutlass::layout => nvcuda::wmma layout tags -+//////////////////////////////////////////////////////////////////////////////////////////////// -+template -+struct CutlassToWmmaLayout { -+}; -+ -+/// Statically maps cutlass::layout::RowMajor => nvcuda::wmma::row_major layout tags -+template <> -+struct CutlassToWmmaLayout { -+ using Layout = nvcuda::wmma::row_major; -+ static nvcuda::wmma::layout_t const value = nvcuda::wmma::layout_t::mem_row_major; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////// -+/// Statically maps cutlass::layout::RowMajor => nvcuda::wmma::row_major layout tags -+//////////////////////////////////////////////////////////////////////////////////////////////// -+template <> -+struct CutlassToWmmaLayout { -+ using Layout = nvcuda::wmma::col_major; -+ static nvcuda::wmma::layout_t const value = nvcuda::wmma::layout_t::mem_col_major; -+}; -+//////////////////////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////////////////////// -+/// Statically maps nvcuda::wmma data types => cutlass data types -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template -+struct WmmaToCutlassDataType{ -+ using Type = Type_; -+}; -+ -+/// Statically maps __half => cutlass::half_t -+template<> -+struct WmmaToCutlassDataType<__half> { -+ using Type = cutlass::half_t; -+}; -+ -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11) -+template<> -+struct WmmaToCutlassDataType<__nv_bfloat16> { -+ using Type = cutlass::bfloat16_t; -+}; -+#endif -+ -+//////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// WMMA template structure defines nvcuda::wmma::fragments and static assertion chaeks -+// for a specific template paramterized data type (Element[A|B|C]), layout (Layout[A|B|C]), -+// and native wmma size (Shape) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ typename Shape_, ///< Size of the matrix product (concept: GemmShape) -+ typename ElementA_, ///< Data type of A elements -+ typename LayoutA_, ///< Layout of A matrix (concept: MatrixLayout) -+ typename ElementB_, ///< Data type of B elements -+ typename LayoutB_, ///< Layout of B matrix (concept: MatrixLayout) -+ typename ElementC_, ///< Element type of C matrix -+ typename LayoutC_, /// Layout of C matrix (concept: MatrixLayout) -+ typename Operator_ = cutlass::arch::OpMultiplyAdd ///< Inner product operator (multiply-add, xor.popc) -+> -+struct Wmma; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace arch -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Specializations for each compute capability -+// -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include "cutlass/arch/wmma_sm70.h" -+#endif -+ -+#ifdef CUTLASS_ARCH_WMMA_SM72_ENABLED -+#include "cutlass/arch/wmma_sm72.h" -+#endif -+ -+#ifdef CUTLASS_ARCH_WMMA_SM75_ENABLED -+#include "cutlass/arch/wmma_sm75.h" -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif //CUTLASS_ARCH_WMMA_ENABLED -diff --git a/3rdparty/cutlass/include/cutlass/arch/wmma_sm70.h b/3rdparty/cutlass/include/cutlass/arch/wmma_sm70.h -new file mode 100644 -index 0000000..0658474 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/wmma_sm70.h -@@ -0,0 +1,136 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Matrix multiply -+*/ -+ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+#include "cutlass/layout/matrix.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+namespace cutlass { -+namespace arch { -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// WMMA template structure defines nvcuda::wmma::fragments and static assert for -+// wmma native instruction sizes supported for half -+// -+//////////////////////////////////////////////////////////////////////////////// -+template < -+typename Shape_, -+typename LayoutA_, -+typename LayoutB_, -+typename ElementC_, -+typename LayoutC_> -+struct Wmma< -+ Shape_, ///< Size of the matrix product (concept: GemmShape) -+ cutlass::half_t, ///< ElementA -+ LayoutA_, ///< LayoutA -+ cutlass::half_t, ///< ElementB -+ LayoutB_, ///< LayoutB -+ ElementC_, ///< ElementC -+ LayoutC_, ///< LayoutC -+ cutlass::arch::OpMultiplyAdd ///< Operator (multiply-add, xor.popc) -+> { -+ -+#if defined(CUTLASS_ARCH_WMMA_SM70_ENABLED) -+ using Shape = Shape_; -+ using ElementA = cutlass::half_t; -+ using LayoutA = LayoutA_; -+ using ElementB = cutlass::half_t; -+ using LayoutB = LayoutB_; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using Operator = cutlass::arch::OpMultiplyAdd; -+ using ArchTag = arch::Sm70; -+ -+ // check supported wmma shape for the given multiplicand data types -+ static_assert( -+ platform::is_same, Shape>::value || -+ platform::is_same, Shape>::value || -+ platform::is_same, Shape>::value, -+ "Supported list of wmma operator shape for f16 multiplicands are: 16x16x16, 8x32x16, and 32x8x16"); -+ -+ // check supported wmma output data type for the given multiplicand data types -+ static_assert( -+ platform::is_same::value || platform::is_same::value, -+ "Supported of wmma output data type for f16 multiplicands are: f16 and f32"); -+ -+ // Wmma Fragment -+ using FragmentA = nvcuda::wmma::fragment< -+ nvcuda::wmma::matrix_a, -+ Shape::kM, -+ Shape::kN, -+ Shape::kK, -+ typename CutlassToWmmaDataType::Type, -+ typename CutlassToWmmaLayout::Layout>; -+ -+ using FragmentB = nvcuda::wmma::fragment< -+ nvcuda::wmma::matrix_b, -+ Shape::kM, -+ Shape::kN, -+ Shape::kK, -+ typename CutlassToWmmaDataType::Type, -+ typename CutlassToWmmaLayout::Layout>; -+ -+ using FragmentC = nvcuda::wmma::fragment< -+ nvcuda::wmma::accumulator, -+ Shape::kM, -+ Shape::kN, -+ Shape::kK, -+ typename CutlassToWmmaDataType::Type>; -+ -+ /// Performs a nvcuda::wmma matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &D, -+ FragmentA const &A, -+ FragmentB const &B, -+ FragmentC const &C) const { -+ -+ nvcuda::wmma::mma_sync(D, A, B, C); -+ } -+#else -+ static_assert(false, "wmma.mma.sync for floating point multiplicands is avialable only for SM70 and beyond"); -+#endif -+ -+}; -+ -+} // namespace arch -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/arch/wmma_sm72.h b/3rdparty/cutlass/include/cutlass/arch/wmma_sm72.h -new file mode 100644 -index 0000000..c20e1b3 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/wmma_sm72.h -@@ -0,0 +1,210 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Matrix multiply -+*/ -+ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+#include "cutlass/layout/matrix.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+namespace cutlass { -+namespace arch { -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// WMMA template structure defines nvcuda::wmma::fragments and static assert for -+// wmma native instruction sizes supported for int8_t -+// -+//////////////////////////////////////////////////////////////////////////////// -+template < -+typename Shape_, -+typename LayoutA_, -+typename LayoutB_, -+typename LayoutC_> -+struct Wmma< -+ Shape_, ///< Size of the matrix product (concept: GemmShape) -+ int8_t, ///< ElementA -+ LayoutA_, ///< LayoutA -+ int8_t, ///< ElementB -+ LayoutB_, ///< LayoutB -+ int32_t, ///< ElementC -+ LayoutC_, ///< LayoutC -+ cutlass::arch::OpMultiplyAdd ///< Operator (multiply-add, xor.popc) -+> { -+#if defined(CUTLASS_ARCH_WMMA_SM72_ENABLED) -+ using Shape = Shape_; -+ using ElementA = int8_t; -+ using LayoutA = LayoutA_; -+ using ElementB = int8_t; -+ using LayoutB = LayoutB_; -+ using ElementC = int32_t; -+ using LayoutC = LayoutC_; -+ using Operator = cutlass::arch::OpMultiplyAdd; -+ using ArchTag = arch::Sm72; -+ -+ // check supported wmma shape for the given multiplicand data types -+ static_assert( -+ platform::is_same, Shape>::value || -+ platform::is_same, Shape>::value || -+ platform::is_same, Shape>::value, -+ "Supported list of wmma operator shape for s8 multiplicands are: 16x16x16, 8x32x16, and 32x8x16"); -+ -+ -+ // Wmma Fragment -+ using FragmentA = nvcuda::wmma::fragment< -+ nvcuda::wmma::matrix_a, -+ Shape::kM, -+ Shape::kN, -+ Shape::kK, -+ typename CutlassToWmmaDataType::Type, -+ typename CutlassToWmmaLayout::Layout>; -+ -+ using FragmentB = nvcuda::wmma::fragment< -+ nvcuda::wmma::matrix_b, -+ Shape::kM, -+ Shape::kN, -+ Shape::kK, -+ typename CutlassToWmmaDataType::Type, -+ typename CutlassToWmmaLayout::Layout>; -+ -+ using FragmentC = nvcuda::wmma::fragment< -+ nvcuda::wmma::accumulator, -+ Shape::kM, -+ Shape::kN, -+ Shape::kK, -+ typename CutlassToWmmaDataType::Type>; -+ -+ /// Performs a nvcuda::wmma matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &D, -+ FragmentA const &A, -+ FragmentB const &B, -+ FragmentC const &C) const { -+ -+ nvcuda::wmma::mma_sync(D, A, B, C); -+ } -+ -+#else -+ static_assert(false, "wmma.mma.sync interger type multiplicands is avialable only for SM72 and beyond"); -+#endif -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// WMMA template structure defines nvcuda::wmma::fragments and static assert for -+// wmma native instruction sizes supported for uint8_t -+// -+//////////////////////////////////////////////////////////////////////////////// -+template < -+typename Shape_, -+typename LayoutA_, -+typename LayoutB_, -+typename LayoutC_> -+struct Wmma< -+ Shape_, ///< Size of the matrix product (concept: GemmShape) -+ uint8_t, ///< ElementA -+ LayoutA_, ///< LayoutA -+ uint8_t, ///< ElementB -+ LayoutB_, ///< LayoutB -+ int32_t, ///< ElementC -+ LayoutC_, ///< LayoutC -+ cutlass::arch::OpMultiplyAdd ///< Operator (multiply-add, xor.popc) -+> { -+#if defined(CUTLASS_ARCH_WMMA_SM72_ENABLED) -+ using Shape = Shape_; -+ using ElementA = uint8_t; -+ using LayoutA = LayoutA_; -+ using ElementB = uint8_t; -+ using LayoutB = LayoutB_; -+ using ElementC = int32_t; -+ using LayoutC = LayoutC_; -+ using Operator = cutlass::arch::OpMultiplyAdd; -+ using ArchTag = arch::Sm72; -+ -+ // check supported wmma shape for the given multiplicand data types -+ static_assert( -+ platform::is_same, Shape>::value || -+ platform::is_same, Shape>::value || -+ platform::is_same, Shape>::value, -+ "Supported list of wmma operator shape for u8 multiplicands are: 16x16x16, 8x32x16, and 32x8x16"); -+ -+ // Wmma Fragment -+ using FragmentA = nvcuda::wmma::fragment< -+ nvcuda::wmma::matrix_a, -+ Shape::kM, -+ Shape::kN, -+ Shape::kK, -+ typename CutlassToWmmaDataType::Type, -+ typename CutlassToWmmaLayout::Layout>; -+ -+ using FragmentB = nvcuda::wmma::fragment< -+ nvcuda::wmma::matrix_b, -+ Shape::kM, -+ Shape::kN, -+ Shape::kK, -+ typename CutlassToWmmaDataType::Type, -+ typename CutlassToWmmaLayout::Layout>; -+ -+ using FragmentC = nvcuda::wmma::fragment< -+ nvcuda::wmma::accumulator, -+ Shape::kM, -+ Shape::kN, -+ Shape::kK, -+ typename CutlassToWmmaDataType::Type>; -+ -+ /// Performs a nvcuda::wmma matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &D, -+ FragmentA const &A, -+ FragmentB const &B, -+ FragmentC const &C) const { -+ -+ nvcuda::wmma::mma_sync(D, A, B, C); -+ } -+ -+#else -+ static_assert(false, "wmma.mma.sync interger type multiplicands is avialable only for SM72 and beyond"); -+#endif -+ -+}; -+ -+} // namespace arch -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/arch/wmma_sm75.h b/3rdparty/cutlass/include/cutlass/arch/wmma_sm75.h -new file mode 100644 -index 0000000..89d030f ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/arch/wmma_sm75.h -@@ -0,0 +1,207 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Matrix multiply -+*/ -+ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+#include "cutlass/layout/matrix.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+namespace cutlass { -+namespace arch { -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// WMMA template structure defines nvcuda::wmma::fragments and static assert for -+// wmma native instruction sizes supported for cutlass::int4b_t (experimental::s4). -+// -+//////////////////////////////////////////////////////////////////////////////// -+template < -+typename Shape_, -+typename LayoutA_, -+typename LayoutB_, -+typename LayoutC_> -+struct Wmma< -+ Shape_, ///< Size of the matrix product (concept: GemmShape) -+ cutlass::int4b_t, ///< ElementA -+ LayoutA_, ///< LayoutA -+ cutlass::int4b_t, ///< ElementB -+ LayoutB_, ///< LayoutB -+ int32_t, ///< ElementC -+ LayoutC_, ///< LayoutC -+ cutlass::arch::OpMultiplyAdd ///< Operator (multiply-add, xor.popc) -+> { -+#if defined(CUTLASS_ARCH_WMMA_SM75_ENABLED) -+ using Shape = Shape_; -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = LayoutA_; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = LayoutB_; -+ using ElementC = int32_t; -+ using LayoutC = LayoutC_; -+ using Operator = cutlass::arch::OpMultiplyAdd; -+ using ArchTag = arch::Sm75; -+ -+ // check supported wmma shape for the given multiplicand data types -+ static_assert( -+ platform::is_same, Shape>::value, -+ "Supported list of wmma operator shape for s8 multiplicands is: 8x8x32"); -+ -+ -+ // Wmma Fragment -+ using FragmentA = nvcuda::wmma::fragment< -+ nvcuda::wmma::matrix_a, -+ Shape::kM, -+ Shape::kN, -+ Shape::kK, -+ typename CutlassToWmmaDataType::Type, -+ typename CutlassToWmmaLayout::Layout>; -+ -+ using FragmentB = nvcuda::wmma::fragment< -+ nvcuda::wmma::matrix_b, -+ Shape::kM, -+ Shape::kN, -+ Shape::kK, -+ typename CutlassToWmmaDataType::Type, -+ typename CutlassToWmmaLayout::Layout>; -+ -+ using FragmentC = nvcuda::wmma::fragment< -+ nvcuda::wmma::accumulator, -+ Shape::kM, -+ Shape::kN, -+ Shape::kK, -+ typename CutlassToWmmaDataType::Type>; -+ -+ /// Performs a nvcuda::wmma matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &D, -+ FragmentA const &A, -+ FragmentB const &B, -+ FragmentC const &C) const { -+ nvcuda::wmma::mma_sync(D, A, B, C); -+ -+ } -+ -+#else -+ static_assert(false, "wmma.mma.sync interger type multiplicands is avialable only for SM75 and beyond"); -+#endif -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// WMMA template structure defines nvcuda::wmma::fragments and static assert for -+// wmma native instruction sizes supported for cutlass::uint1b_t (experimental::b1). -+// -+//////////////////////////////////////////////////////////////////////////////// -+template < -+typename Shape_, -+typename LayoutA_, -+typename LayoutB_, -+typename LayoutC_> -+struct Wmma< -+ Shape_, ///< Size of the matrix product (concept: GemmShape) -+ cutlass::uint1b_t, ///< ElementA -+ LayoutA_, ///< LayoutA -+ cutlass::uint1b_t, ///< ElementB -+ LayoutB_, ///< LayoutB -+ int32_t, ///< ElementC -+ LayoutC_, ///< LayoutC -+ cutlass::arch::OpXorPopc ///< Operator (multiply-add, xor.popc) -+> { -+#if defined(CUTLASS_ARCH_WMMA_SM75_ENABLED) -+ using Shape = Shape_; -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = LayoutA_; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = LayoutB_; -+ using ElementC = int32_t; -+ using LayoutC = LayoutC_; -+ using Operator = cutlass::arch::OpXorPopc; -+ using ArchTag = arch::Sm75; -+ -+ // check supported wmma shape for the given multiplicand data types -+ static_assert( -+ platform::is_same, Shape>::value, -+ "Supported list of wmma operator shape for b1 multiplicands is: 8x8x128"); -+ -+ -+ // Wmma Fragment -+ using FragmentA = nvcuda::wmma::fragment< -+ nvcuda::wmma::matrix_a, -+ Shape::kM, -+ Shape::kN, -+ Shape::kK, -+ typename CutlassToWmmaDataType::Type, -+ typename CutlassToWmmaLayout::Layout>; -+ -+ using FragmentB = nvcuda::wmma::fragment< -+ nvcuda::wmma::matrix_b, -+ Shape::kM, -+ Shape::kN, -+ Shape::kK, -+ typename CutlassToWmmaDataType::Type, -+ typename CutlassToWmmaLayout::Layout>; -+ -+ using FragmentC = nvcuda::wmma::fragment< -+ nvcuda::wmma::accumulator, -+ Shape::kM, -+ Shape::kN, -+ Shape::kK, -+ typename CutlassToWmmaDataType::Type>; -+ -+ /// Performs a nvcuda::wmma matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &D, -+ FragmentA const &A, -+ FragmentB const &B, -+ FragmentC const &C) const { -+ nvcuda::wmma::bmma_sync(D, A, B, C, nvcuda::wmma::experimental::bmmaBitOpXOR, -+ nvcuda::wmma::experimental::bmmaAccumulateOpPOPC); -+ } -+ -+#else -+ static_assert(false, "wmma.mma.sync interger type multiplicands is avialable only for SM75 and beyond"); -+#endif -+ -+}; -+ -+} // namespace arch -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/array.h b/3rdparty/cutlass/include/cutlass/array.h -new file mode 100644 -index 0000000..9fe245b ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/array.h -@@ -0,0 +1,2457 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Statically sized array of elements that accommodates all CUTLASS-supported numeric types -+ and is safe to use in a union. -+*/ -+ -+#pragma once -+#include "cutlass/cutlass.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/half.h" -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Statically sized array for any data type -+template < -+ typename T, -+ int N, -+ bool RegisterSized = sizeof_bits::value >= 32 -+> -+class Array; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines the size of an Array<> in bits -+template -+struct sizeof_bits > { -+ static int const value = -+ int(sizeof(typename Array::Storage)) * 8 * int(Array::kStorageElements); -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Returns true if the argument is a power of 2 -+CUTLASS_HOST_DEVICE -+constexpr bool ispow2(unsigned x) { -+ return x && (!(x & (x - 1))); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Returns the largest power of two not greater than the argument. -+CUTLASS_HOST_DEVICE -+constexpr unsigned floor_pow_2(unsigned x) { -+ return (x == 0 || ispow2(x)) ? x : ((floor_pow_2(x >> 1)) << 1); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Statically sized array for any data type -+template < -+ typename T, -+ int N -+> -+class Array { -+public: -+ -+ /// Storage type -+ using Storage = T; -+ -+ /// Element type -+ using Element = T; -+ -+ /// Number of storage elements -+ //static std::size_t const kStorageElements = N; -+ static size_t const kStorageElements = N; -+ -+ /// Number of logical elements -+ static size_t const kElements = N; -+ -+ // -+ // C++ standard members -+ // -+ -+ typedef T value_type; -+ typedef size_t size_type; -+ typedef ptrdiff_t difference_type; -+ typedef value_type &reference; -+ typedef value_type const & const_reference; -+ typedef value_type *pointer; -+ typedef value_type const * const_pointer; -+ -+ // -+ // Iterators -+ // -+ -+ /// Bidirectional iterator over elements -+ class iterator { -+ -+ /// Pointer to object -+ T *ptr_; -+ -+ public: -+ -+ CUTLASS_HOST_DEVICE -+ iterator(): ptr_(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ iterator(T *_ptr): ptr_(_ptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ iterator &operator++() { -+ ++ptr_; -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ iterator &operator--() { -+ --ptr_; -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ iterator operator++(int) { -+ iterator ret(*this); -+ ++ptr_; -+ return ret; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ iterator operator--(int) { -+ iterator ret(*this); -+ --ptr_; -+ return ret; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ T &operator*() const { -+ return *ptr_; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool operator==(iterator const &other) const { -+ return ptr_ == other.ptr_; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool operator!=(iterator const &other) const { -+ return ptr_ != other.ptr_; -+ } -+ }; -+ -+ /// Bidirectional constant iterator over elements -+ class const_iterator { -+ -+ /// Pointer to object -+ const T *ptr_; -+ -+ public: -+ -+ CUTLASS_HOST_DEVICE -+ const_iterator(): ptr_(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ const_iterator(T const *_ptr): ptr_(_ptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ const_iterator &operator++() { -+ ++ptr_; -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_iterator &operator--() { -+ --ptr_; -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_iterator operator++(int) { -+ const_iterator ret(*this); -+ ++ptr_; -+ return ret; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_iterator operator--(int) { -+ const_iterator ret(*this); -+ --ptr_; -+ return ret; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ T const &operator*() const { -+ return *ptr_; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool operator==(const_iterator const &other) const { -+ return ptr_ == other.ptr_; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool operator!=(const_iterator const &other) const { -+ return ptr_ != other.ptr_; -+ } -+ }; -+ -+ /// Bidirectional iterator over elements -+ class reverse_iterator { -+ -+ /// Pointer to object -+ T *ptr_; -+ -+ public: -+ -+ CUTLASS_HOST_DEVICE -+ reverse_iterator(): ptr_(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ reverse_iterator(T *_ptr): ptr_(_ptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ reverse_iterator &operator++() { -+ --ptr_; -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ reverse_iterator &operator--() { -+ ++ptr_; -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ reverse_iterator operator++(int) { -+ iterator ret(*this); -+ --ptr_; -+ return ret; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ reverse_iterator operator--(int) { -+ iterator ret(*this); -+ ++ptr_; -+ return ret; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ T &operator*() const { -+ return *(ptr_ - 1); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool operator==(reverse_iterator const &other) const { -+ return ptr_ == other.ptr_; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool operator!=(reverse_iterator const &other) const { -+ return ptr_ != other.ptr_; -+ } -+ }; -+ -+ /// Bidirectional constant iterator over elements -+ class const_reverse_iterator { -+ -+ /// Pointer to object -+ T const *ptr_; -+ -+ public: -+ -+ CUTLASS_HOST_DEVICE -+ const_reverse_iterator(): ptr_(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ const_reverse_iterator(T const *_ptr): ptr_(_ptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ const_reverse_iterator &operator++() { -+ --ptr_; -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_reverse_iterator &operator--() { -+ ++ptr_; -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_reverse_iterator operator++(int) { -+ const_reverse_iterator ret(*this); -+ --ptr_; -+ return ret; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_reverse_iterator operator--(int) { -+ const_reverse_iterator ret(*this); -+ ++ptr_; -+ return ret; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ T const &operator*() const { -+ return *(ptr_ - 1); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool operator==(const_iterator const &other) const { -+ return ptr_ == other.ptr_; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool operator!=(const_iterator const &other) const { -+ return ptr_ != other.ptr_; -+ } -+ }; -+ -+private: -+ -+ /// Internal storage -+ Storage storage[kElements]; -+ -+public: -+ -+ #if 0 -+ CUTLASS_HOST_DEVICE -+ Array() { } -+ -+ CUTLASS_HOST_DEVICE -+ Array(Array const &x) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kElements; ++i) { -+ storage[i] = x.storage[i]; -+ } -+ } -+ #endif -+ -+ /// Efficient clear method -+ CUTLASS_HOST_DEVICE -+ void clear() { -+ fill(T(0)); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ reference at(size_type pos) { -+ return reinterpret_cast(storage[pos]); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_reference at(size_type pos) const { -+ return reinterpret_cast(storage[pos]); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ reference operator[](size_type pos) { -+ return reinterpret_cast(storage[pos]); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_reference operator[](size_type pos) const { -+ return reinterpret_cast(storage[pos]); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ reference front() { -+ return reinterpret_cast(storage[0]); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_reference front() const { -+ return reinterpret_cast(storage[0]); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ reference back() { -+ return reinterpret_cast(storage[kStorageElements - 1]); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_reference back() const { -+ return reinterpret_cast(storage[kStorageElements - 1]); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ pointer data() { -+ return reinterpret_cast(storage); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_pointer data() const { -+ return reinterpret_cast(storage); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ pointer raw_data() { -+ return reinterpret_cast(storage); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_pointer raw_data() const { -+ return reinterpret_cast(storage); -+ } -+ -+ -+ CUTLASS_HOST_DEVICE -+ constexpr bool empty() const { -+ return !kElements; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ constexpr size_type size() const { -+ return kElements; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ constexpr size_type max_size() const { -+ return kElements; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void fill(T const &value) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kElements; ++i) { -+ storage[i] = static_cast(value); -+ } -+ } -+ -+ CUTLASS_HOST_DEVICE -+ iterator begin() { -+ return iterator(storage); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_iterator begin() const { -+ return cbegin(); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_iterator cbegin() const { -+ return const_iterator(storage); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ iterator end() { -+ return iterator(reinterpret_cast(storage + kStorageElements)); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_iterator end() const { -+ return cend(); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_iterator cend() const { -+ return const_iterator(reinterpret_cast(storage + kStorageElements)); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ reverse_iterator rbegin() { -+ return reverse_iterator(reinterpret_cast(storage + kStorageElements)); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_reverse_iterator rbegin() const { -+ return crbegin(); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_reverse_iterator crbegin() const { -+ return const_reverse_iterator(reinterpret_cast(storage + kStorageElements)); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ reverse_iterator rend() { -+ return reverse_iterator(reinterpret_cast(storage)); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_reverse_iterator rend() const { -+ return crend(); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_reverse_iterator crend() const { -+ return const_reverse_iterator(reinterpret_cast(storage)); -+ } -+ -+ // -+ // Comparison operators -+ // -+ -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+// Factories -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+CUTLASS_HOST_DEVICE -+Array make_Array(Element x) { -+ Array m; -+ m[0] = x; -+ return m; -+} -+ -+template -+CUTLASS_HOST_DEVICE -+Array make_Array(Element x, Element y) { -+ Array m; -+ m[0] = x; -+ m[1] = y; -+ return m; -+} -+ -+template -+CUTLASS_HOST_DEVICE -+Array make_Array(Element x, Element y, Element z) { -+ Array m; -+ m[0] = x; -+ m[1] = y; -+ m[2] = z; -+ return m; -+} -+ -+template -+CUTLASS_HOST_DEVICE -+Array make_Array(Element x, Element y, Element z, Element w) { -+ Array m; -+ m[0] = x; -+ m[1] = y; -+ m[2] = z; -+ m[3] = w; -+ return m; -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// functional.h numeric specializations -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct absolute_value_op< Array > { -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &lhs) const { -+ -+ Array result; -+ absolute_value_op scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(lhs[i]); -+ } -+ -+ return result; -+ } -+}; -+ -+template -+struct plus> { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &lhs, Array const &rhs) const { -+ -+ Array result; -+ plus scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(lhs[i], rhs[i]); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &lhs, T const &scalar) const { -+ -+ Array result; -+ plus scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(lhs[i], scalar); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()( T const &scalar, Array const &rhs) const { -+ -+ Array result; -+ plus scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(scalar, rhs[i]); -+ } -+ -+ return result; -+ } -+}; -+template -+struct minus> { -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &lhs, Array const &rhs) const { -+ -+ Array result; -+ minus scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(lhs[i], rhs[i]); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &lhs, T const &scalar) const { -+ -+ Array result; -+ minus scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(lhs[i], scalar); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()( T const &scalar, Array const &rhs) const { -+ -+ Array result; -+ minus scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(scalar, rhs[i]); -+ } -+ -+ return result; -+ } -+}; -+ -+template -+struct multiplies> { -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &lhs, Array const &rhs) const { -+ -+ Array result; -+ multiplies scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(lhs[i], rhs[i]); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &lhs, T const &scalar) const { -+ -+ Array result; -+ multiplies scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(lhs[i], scalar); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()( T const &scalar, Array const &rhs) const { -+ -+ Array result; -+ multiplies scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(scalar, rhs[i]); -+ } -+ -+ return result; -+ } -+}; -+ -+template -+struct divides> { -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &lhs, Array const &rhs) const { -+ -+ Array result; -+ divides scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(lhs[i], rhs[i]); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &lhs, T const &scalar) const { -+ -+ Array result; -+ divides scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(lhs[i], scalar); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()( T const &scalar, Array const &rhs) const { -+ -+ Array result; -+ divides scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(scalar, rhs[i]); -+ } -+ -+ return result; -+ } -+}; -+ -+template -+struct maximum> { -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &lhs, Array const &rhs) const { -+ -+ Array result; -+ maximum scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(lhs[i], rhs[i]); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &lhs, T const &scalar) const { -+ -+ Array result; -+ maximum scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(lhs[i], scalar); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()( T const &scalar, Array const &rhs) const { -+ -+ Array result; -+ maximum scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(scalar, rhs[i]); -+ } -+ -+ return result; -+ } -+}; -+ -+template -+struct minimum> { -+ -+ CUTLASS_HOST_DEVICE -+ static T scalar_op(T const &lhs, T const &rhs) { -+ return (rhs < lhs ? rhs : lhs); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &lhs, Array const &rhs) const { -+ -+ Array result; -+ minimum scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(lhs[i], rhs[i]); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &lhs, T const &scalar) const { -+ -+ Array result; -+ minimum scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(lhs[i], scalar); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()( T const &scalar, Array const &rhs) const { -+ -+ Array result; -+ minimum scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(scalar, rhs[i]); -+ } -+ -+ return result; -+ } -+}; -+ -+template -+struct negate> { -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &lhs) const { -+ -+ Array result; -+ negate scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(lhs[i]); -+ } -+ -+ return result; -+ } -+}; -+ -+/// Fused multiply-add -+template -+struct multiply_add, Array, Array> { -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &a, Array const &b, Array const &c) const { -+ -+ Array result; -+ multiply_add scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(a[i], b[i], c[i]); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &a, T const &scalar, Array const &c) const { -+ -+ Array result; -+ multiply_add scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(a[i], scalar, c[i]); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(T const &scalar, Array const &b, Array const &c) const { -+ -+ Array result; -+ multiply_add scalar_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = scalar_op(scalar, b[i], c[i]); -+ } -+ -+ return result; -+ } -+}; -+ -+/// Fused multiply-add-relu0 -+template -+struct multiply_add_relu0, Array, Array> { -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &a, Array const &b, Array const &c) const { -+ -+ Array result; -+ multiply_add scalar_op; -+ maximum mx; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = mx(scalar_op(a[i], b[i], c[i]), T(0)); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &a, T const &scalar, Array const &c) const { -+ -+ Array result; -+ multiply_add scalar_op; -+ maximum mx; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = mx(scalar_op(a[i], scalar, c[i]), T(0)); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(T const &scalar, Array const &b, Array const &c) const { -+ -+ Array result; -+ multiply_add scalar_op; -+ maximum mx; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = mx(scalar_op(scalar, b[i], c[i]), T(0)); -+ } -+ -+ return result; -+ } -+}; -+ -+ -+template -+struct conjugate > { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &a) const { -+ -+ conjugate conj_op; -+ -+ Array ca; -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ ca[i] = conj_op(a[i]); -+ } -+ return ca; -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// functional.h numeric specializations targeting SIMD instructions in device code. -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct plus> { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const & lhs, Array const &rhs) const { -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); -+ __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hadd2(lhs_ptr[i], rhs_ptr[i]); -+ } -+ -+ if (N % 2) { -+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); -+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); -+ __half d_residual = __hadd(a_residual_ptr[N - 1], b_residual_ptr[N - 1]); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = lhs[i] + rhs[i]; -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(half_t const & lhs, Array const &rhs) const { -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); -+ __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hadd2(lhs_pair, rhs_ptr[i]); -+ } -+ -+ if (N % 2) { -+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); -+ __half d_residual = __hadd(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = lhs + rhs[i]; -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const & lhs, half_t const &rhs) const { -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); -+ __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hadd2(lhs_ptr[i], rhs_pair); -+ } -+ -+ if (N % 2) { -+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); -+ __half d_residual = __hadd(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs)); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = lhs[i] + rhs; -+ } -+ #endif -+ -+ return result; -+ } -+}; -+ -+template -+struct minus> { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const & lhs, Array const &rhs) const { -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); -+ __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hsub2(lhs_ptr[i], rhs_ptr[i]); -+ } -+ -+ if (N % 2) { -+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); -+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); -+ __half d_residual = __hsub(a_residual_ptr[N - 1], b_residual_ptr[N - 1]); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = lhs[i] - rhs[i]; -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(half_t const & lhs, Array const &rhs) const { -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); -+ __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hsub2(lhs_pair, rhs_ptr[i]); -+ } -+ -+ if (N % 2) { -+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); -+ __half d_residual = __hsub(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = lhs - rhs[i]; -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const & lhs, half_t const &rhs) const { -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); -+ __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hsub2(lhs_ptr[i], rhs_pair); -+ } -+ -+ if (N % 2) { -+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); -+ __half d_residual = __hsub(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs)); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = lhs[i] - rhs; -+ } -+ #endif -+ -+ return result; -+ } -+}; -+ -+template -+struct multiplies> { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const & lhs, Array const &rhs) const { -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); -+ __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hmul2(lhs_ptr[i], rhs_ptr[i]); -+ } -+ -+ if (N % 2) { -+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); -+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); -+ __half d_residual = __hmul(a_residual_ptr[N - 1], b_residual_ptr[N - 1]); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = lhs[i] * rhs[i]; -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(half_t const & lhs, Array const &rhs) const { -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); -+ __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hmul2(lhs_pair, rhs_ptr[i]); -+ } -+ -+ if (N % 2) { -+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); -+ -+ __half d_residual = __hmul( -+ reinterpret_cast<__half const &>(lhs), -+ b_residual_ptr[N - 1]); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = lhs * rhs[i]; -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const & lhs, half_t const &rhs) const { -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); -+ __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hmul2(lhs_ptr[i], rhs_pair); -+ } -+ -+ if (N % 2) { -+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); -+ -+ __half d_residual = __hmul( -+ a_residual_ptr[N - 1], -+ reinterpret_cast<__half const &>(rhs)); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = lhs[i] * rhs; -+ } -+ #endif -+ -+ return result; -+ } -+}; -+ -+template -+struct divides> { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const & lhs, Array const &rhs) const { -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); -+ __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __h2div(lhs_ptr[i], rhs_ptr[i]); -+ } -+ -+ if (N % 2) { -+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); -+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); -+ -+ __half d_residual = __hdiv( -+ a_residual_ptr[N - 1], -+ b_residual_ptr[N - 1]); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = lhs[i] / rhs[i]; -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(half_t const & lhs, Array const &rhs) const { -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); -+ __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __h2div(lhs_pair, rhs_ptr[i]); -+ } -+ -+ if (N % 2) { -+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); -+ -+ __half d_residual = __hdiv( -+ reinterpret_cast<__half const &>(lhs), -+ b_residual_ptr[N - 1]); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = lhs / rhs[i]; -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const & lhs, half_t const &rhs) const { -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); -+ __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __h2div(lhs_ptr[i], rhs_pair); -+ } -+ -+ if (N % 2) { -+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); -+ -+ __half d_residual = __hdiv( -+ a_residual_ptr[N - 1], -+ reinterpret_cast<__half const &>(rhs)); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = lhs[i] / rhs; -+ } -+ #endif -+ -+ return result; -+ } -+}; -+ -+template -+struct negate> { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const & lhs) const { -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 const *source_ptr = reinterpret_cast<__half2 const *>(&lhs); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hneg2(source_ptr[i]); -+ } -+ -+ if (N % 2) { -+ half_t x = lhs[N - 1]; -+ __half lhs_val = -reinterpret_cast<__half const &>(x); -+ result[N - 1] = reinterpret_cast(lhs_val); -+ } -+ -+ #else -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = -lhs[i]; -+ } -+ #endif -+ -+ return result; -+ } -+}; -+ -+/// Fused multiply-add -+template -+struct multiply_add, Array, Array> { -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()( -+ Array const &a, -+ Array const &b, -+ Array const &c) const { -+ -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); -+ __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); -+ __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hfma2(a_ptr[i], b_ptr[i], c_ptr[i]); -+ } -+ -+ if (N % 2) { -+ -+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); -+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); -+ __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); -+ -+ __half d_residual = __hfma( -+ a_residual_ptr[N - 1], -+ b_residual_ptr[N - 1], -+ c_residual_ptr[N - 1]); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ multiply_add op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = op(a[i], b[i], c[i]); -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()( -+ half_t const &a, -+ Array const &b, -+ Array const &c) const { -+ -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 a_pair = __half2half2(reinterpret_cast<__half const &>(a)); -+ __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); -+ __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hfma2(a_pair, b_ptr[i], c_ptr[i]); -+ } -+ -+ if (N % 2) { -+ -+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); -+ __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); -+ __half d_residual = __hfma( -+ reinterpret_cast<__half const &>(a), -+ b_residual_ptr[N - 1], -+ c_residual_ptr[N - 1]); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ multiply_add op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = op(a, b[i], c[i]); -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()( -+ Array const &a, -+ half_t const &b, -+ Array const &c) const { -+ -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); -+ __half2 b_pair = __half2half2(reinterpret_cast<__half const &>(b)); -+ __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hfma2(a_ptr[i], b_pair, c_ptr[i]); -+ } -+ -+ if (N % 2) { -+ -+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); -+ __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); -+ -+ __half d_residual = __hfma( -+ a_residual_ptr[N - 1], -+ reinterpret_cast<__half const &>(b), -+ c_residual_ptr[N - 1]); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ multiply_add op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = op(a[i], b, c[i]); -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()( -+ Array const &a, -+ Array const &b, -+ half_t const &c) const { -+ -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); -+ __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); -+ __half2 c_pair = __half2half2(reinterpret_cast<__half const &>(c)); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hfma2(a_ptr[i], b_ptr[i], c_pair); -+ } -+ -+ if (N % 2) { -+ -+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); -+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); -+ -+ __half d_residual = __hfma( -+ a_residual_ptr[N - 1], -+ b_residual_ptr[N - 1], -+ reinterpret_cast<__half const &>(c)); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ multiply_add op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = op(a[i], b[i], c); -+ } -+ #endif -+ -+ return result; -+ } -+}; -+ -+/// Fused multiply-add-relu0 -+template -+struct multiply_add_relu0, Array, Array> { -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()( -+ Array const &a, -+ Array const &b, -+ Array const &c) const { -+ -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); -+ __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); -+ __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hfma2_relu(a_ptr[i], b_ptr[i], c_ptr[i]); -+ } -+ -+ if (N % 2) { -+ -+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); -+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); -+ __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); -+ -+ __half d_residual = __hfma_relu( -+ a_residual_ptr[N - 1], -+ b_residual_ptr[N - 1], -+ c_residual_ptr[N - 1]); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ multiply_add op; -+ maximum mx; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = mx(op(a[i], b[i], c[i]), (half_t)0); -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()( -+ half_t const &a, -+ Array const &b, -+ Array const &c) const { -+ -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 a_pair = __half2half2(reinterpret_cast<__half const &>(a)); -+ __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); -+ __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hfma2_relu(a_pair, b_ptr[i], c_ptr[i]); -+ } -+ -+ if (N % 2) { -+ -+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); -+ __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); -+ __half d_residual = __hfma_relu( -+ reinterpret_cast<__half const &>(a), -+ b_residual_ptr[N - 1], -+ c_residual_ptr[N - 1]); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ multiply_add op; -+ maximum mx; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = mx(op(a, b[i], c[i]), half_t(0)); -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()( -+ Array const &a, -+ half_t const &b, -+ Array const &c) const { -+ -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); -+ __half2 b_pair = __half2half2(reinterpret_cast<__half const &>(b)); -+ __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hfma2_relu(a_ptr[i], b_pair, c_ptr[i]); -+ } -+ -+ if (N % 2) { -+ -+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); -+ __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); -+ -+ __half d_residual = __hfma_relu( -+ a_residual_ptr[N - 1], -+ reinterpret_cast<__half const &>(b), -+ c_residual_ptr[N - 1]); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ multiply_add op; -+ maximum mx; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = mx(op(a[i], b, c[i]), half_t(0)); -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()( -+ Array const &a, -+ Array const &b, -+ half_t const &c) const { -+ -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); -+ __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); -+ __half2 c_pair = __half2half2(reinterpret_cast<__half const &>(c)); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hfma2_relu(a_ptr[i], b_ptr[i], c_pair); -+ } -+ -+ if (N % 2) { -+ -+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); -+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); -+ -+ __half d_residual = __hfma_relu( -+ a_residual_ptr[N - 1], -+ b_residual_ptr[N - 1], -+ reinterpret_cast<__half const &>(c)); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ multiply_add op; -+ maximum mx; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = mx(op(a[i], b[i], c), half_t(0)); -+ } -+ #endif -+ -+ return result; -+ } -+}; -+ -+template -+struct minimum> { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const & lhs, Array const &rhs) const { -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); -+ __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hmin2(lhs_ptr[i], rhs_ptr[i]); -+ } -+ -+ if (N % 2) { -+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); -+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); -+ -+ __half d_residual = __hmin( -+ a_residual_ptr[N - 1], -+ b_residual_ptr[N - 1]); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = (rhs[i] < lhs[i] ? rhs[i] : lhs[i]); -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(half_t const & lhs, Array const &rhs) const { -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); -+ __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hmin2(lhs_pair, rhs_ptr[i]); -+ } -+ -+ if (N % 2) { -+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); -+ -+ __half d_residual = __hmin( -+ reinterpret_cast<__half const &>(lhs), -+ b_residual_ptr[N - 1]); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = (rhs[i] < lhs ? rhs[i] : lhs); -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const & lhs, half_t const &rhs) const { -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); -+ __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hmin2(lhs_ptr[i], rhs_pair); -+ } -+ -+ if (N % 2) { -+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); -+ -+ __half d_residual = __hmin( -+ a_residual_ptr[N - 1], -+ reinterpret_cast<__half const &>(rhs)); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = (rhs < lhs[i] ? rhs : lhs[i]); -+ } -+ #endif -+ -+ return result; -+ } -+}; -+ -+template -+struct maximum> { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const & lhs, Array const &rhs) const { -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); -+ __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hmax2(lhs_ptr[i], rhs_ptr[i]); -+ } -+ -+ if (N % 2) { -+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); -+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); -+ -+ __half d_residual = __hmax( -+ a_residual_ptr[N - 1], -+ b_residual_ptr[N - 1]); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = (lhs[i] < rhs[i] ? rhs[i] : lhs[i]); -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(half_t const & lhs, Array const &rhs) const { -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); -+ __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hmax2(lhs_pair, rhs_ptr[i]); -+ } -+ -+ if (N % 2) { -+ __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); -+ -+ __half d_residual = __hmax( -+ reinterpret_cast<__half const &>(lhs), -+ b_residual_ptr[N - 1]); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = (lhs < rhs[i] ? rhs[i] : lhs); -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const & lhs, half_t const &rhs) const { -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ -+ __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); -+ __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); -+ __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = __hmax2(lhs_ptr[i], rhs_pair); -+ } -+ -+ if (N % 2) { -+ __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); -+ -+ __half d_residual = __hmax( -+ a_residual_ptr[N - 1], -+ reinterpret_cast<__half const &>(rhs)); -+ -+ result[N - 1] = reinterpret_cast(d_residual); -+ } -+ -+ #else -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = (lhs[i] < rhs ? rhs : lhs[i]); -+ } -+ #endif -+ -+ return result; -+ } -+}; -+ -+/// Fused multiply-add -+template -+struct multiply_add, Array, Array> { -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()( -+ Array const &a, -+ Array const &b, -+ Array const &c) const { -+ -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ -+ unsigned *result_ptr = reinterpret_cast(&result); -+ unsigned const *a_ptr = reinterpret_cast(&a); -+ unsigned const *b_ptr = reinterpret_cast(&b); -+ unsigned const *c_ptr = reinterpret_cast(&c); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n" -+ : "=r"(result_ptr[i]) -+ : "r"(a_ptr[i]), "r"(b_ptr[i]), "r"(c_ptr[i]) -+ ); -+ } -+ -+ if (N % 2) { -+ -+ uint16_t *result_ptr = reinterpret_cast(&result); -+ uint16_t const *a_residual_ptr = reinterpret_cast(&a); -+ uint16_t const *b_residual_ptr = reinterpret_cast(&b); -+ uint16_t const *c_residual_ptr = reinterpret_cast(&c); -+ -+ asm ("fma.rn.bf16 %0, %1, %2, %3;\n" -+ : "=h"(result_ptr[N - 1]) -+ : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[N - 1]), "h"(c_residual_ptr[N - 1]) -+ ); -+ } -+ -+ #else -+ -+ multiply_add op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = op(a[i], b[i], c[i]); -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()( -+ bfloat16_t const &a, -+ Array const &b, -+ Array const &c) const { -+ -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ -+ unsigned *result_ptr = reinterpret_cast(&result); -+ -+ unsigned const *b_ptr = reinterpret_cast(&b); -+ unsigned const *c_ptr = reinterpret_cast(&c); -+ -+ unsigned a_packed = static_cast(a.raw()); -+ a_packed = (a_packed | (a_packed << 16)); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n" -+ : "=r"(result_ptr[i]) -+ : "r"(a_packed), "r"(b_ptr[i]), "r"(c_ptr[i]) -+ ); -+ } -+ -+ if (N % 2) { -+ -+ uint16_t *result_ptr = reinterpret_cast(&result); -+ uint16_t const *a_residual_ptr = reinterpret_cast(&a); -+ uint16_t const *b_residual_ptr = reinterpret_cast(&b); -+ uint16_t const *c_residual_ptr = reinterpret_cast(&c); -+ -+ asm ("fma.rn.bf16 %0, %1, %2, %3;\n" -+ : "=h"(result_ptr[N - 1]) -+ : "h"(a_residual_ptr[0]), "h"(b_residual_ptr[N - 1]), "h"(c_residual_ptr[N - 1]) -+ ); -+ } -+ -+ #else -+ -+ multiply_add op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = op(a, b[i], c[i]); -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()( -+ Array const &a, -+ bfloat16_t const &b, -+ Array const &c) const { -+ -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ -+ unsigned *result_ptr = reinterpret_cast(&result); -+ -+ unsigned const *a_ptr = reinterpret_cast(&a); -+ unsigned const *c_ptr = reinterpret_cast(&c); -+ -+ unsigned b_packed = static_cast(b.raw()); -+ b_packed = (b_packed | (b_packed << 16)); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n" -+ : "=r"(result_ptr[i]) -+ : "r"(a_ptr[i]), "r"(b_packed), "r"(c_ptr[i]) -+ ); -+ } -+ -+ if (N % 2) { -+ -+ uint16_t *result_ptr = reinterpret_cast(&result); -+ uint16_t const *a_residual_ptr = reinterpret_cast(&a); -+ uint16_t const *b_residual_ptr = reinterpret_cast(&b); -+ uint16_t const *c_residual_ptr = reinterpret_cast(&c); -+ -+ asm ("fma.rn.bf16 %0, %1, %2, %3;\n" -+ : "=h"(result_ptr[N - 1]) -+ : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[0]), "h"(c_residual_ptr[N - 1]) -+ ); -+ } -+ -+ #else -+ -+ multiply_add op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = op(a[i], b, c[i]); -+ } -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()( -+ Array const &a, -+ Array const &b, -+ bfloat16_t const &c) const { -+ -+ Array result; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ -+ unsigned *result_ptr = reinterpret_cast(&result); -+ -+ unsigned const *a_ptr = reinterpret_cast(&a); -+ unsigned const *b_ptr = reinterpret_cast(&b); -+ -+ unsigned c_packed = static_cast(c.raw()); -+ c_packed = (c_packed | (c_packed << 16)); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n" -+ : "=r"(result_ptr[i]) -+ : "r"(a_ptr[i]), "r"(b_ptr[i]), "r"(c_packed) -+ ); -+ } -+ -+ if (N % 2) { -+ -+ uint16_t *result_ptr = reinterpret_cast(&result); -+ uint16_t const *a_residual_ptr = reinterpret_cast(&a); -+ uint16_t const *b_residual_ptr = reinterpret_cast(&b); -+ uint16_t const *c_residual_ptr = reinterpret_cast(&c); -+ -+ asm ("fma.rn.bf16 %0, %1, %2, %3;\n" -+ : "=h"(result_ptr[N - 1]) -+ : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[N - 1]), "h"(c_residual_ptr[0]) -+ ); -+ } -+ -+ #else -+ -+ multiply_add op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = op(a[i], b[i], c); -+ } -+ #endif -+ -+ return result; -+ } -+}; -+ -+ -+/// bit_and -+template -+struct bit_and> { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &a, Array const &b) const { -+ using ArrayType = Array; -+ using Storage = typename ArrayType::Storage; -+ ArrayType result; -+ -+ Storage *result_data = result.raw_data(); -+ Storage const *a_data = a.raw_data(); -+ Storage const *b_data = b.raw_data(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < ArrayType::kStorageElements; ++i) { -+ result_data[i] = (a_data[i] & b_data[i]); -+ } -+ -+ return result; -+ } -+}; -+ -+ -+/// bit_or -+template -+struct bit_or> { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &a, Array const &b) const { -+ using ArrayType = Array; -+ using Storage = typename ArrayType::Storage; -+ ArrayType result; -+ -+ Storage *result_data = result.raw_data(); -+ Storage const *a_data = a.raw_data(); -+ Storage const *b_data = b.raw_data(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < ArrayType::kStorageElements; ++i) { -+ result_data[i] = (a_data[i] | b_data[i]); -+ } -+ -+ return result; -+ } -+}; -+ -+ -+/// bit_not -+template -+struct bit_not> { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &a) const { -+ using ArrayType = Array; -+ using Storage = typename ArrayType::Storage; -+ ArrayType result; -+ -+ Storage *result_data = result.raw_data(); -+ Storage const *a_data = a.raw_data(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < ArrayType::kStorageElements; ++i) { -+ result_data[i] = (~a_data[i]); -+ } -+ -+ return result; -+ } -+}; -+ -+ -+/// bit_xor -+template -+struct bit_xor> { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &a, Array const &b) const { -+ using ArrayType = Array; -+ using Storage = typename ArrayType::Storage; -+ ArrayType result; -+ -+ Storage *result_data = result.raw_data(); -+ Storage const *a_data = a.raw_data(); -+ Storage const *b_data = b.raw_data(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < ArrayType::kStorageElements; ++i) { -+ result_data[i] = (a_data[i] ^ b_data[i]); -+ } -+ -+ return result; -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// Operator overloads -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+CUTLASS_HOST_DEVICE -+Array operator+(Array const &lhs, Array const &rhs) { -+ plus> op; -+ return op(lhs, rhs); -+} -+ -+template -+CUTLASS_HOST_DEVICE -+Array operator-(Array const &lhs, Array const &rhs) { -+ minus> op; -+ return op(lhs, rhs); -+} -+ -+template -+CUTLASS_HOST_DEVICE -+Array operator-(Array const &lhs) { -+ negate> op; -+ return op(lhs); -+} -+ -+template -+CUTLASS_HOST_DEVICE -+Array operator*(Array const &lhs, Array const &rhs) { -+ multiplies> op; -+ return op(lhs, rhs); -+} -+ -+template -+CUTLASS_HOST_DEVICE -+Array operator*(T lhs, Array const &rhs) { -+ multiplies> op; -+ return op(lhs, rhs); -+} -+ -+template -+CUTLASS_HOST_DEVICE -+Array operator*(Array const &lhs, T rhs) { -+ multiplies> op; -+ return op(lhs, rhs); -+} -+ -+template -+CUTLASS_HOST_DEVICE -+Array operator/(Array const &lhs, Array const &rhs) { -+ divides> op; -+ return op(lhs, rhs); -+} -+ -+template -+CUTLASS_HOST_DEVICE -+Array fma(Array const &a, Array const &b, Array const &c) { -+ multiply_add> op; -+ return op(a, b, c); -+} -+ -+template -+CUTLASS_HOST_DEVICE -+Array fma(T a, Array const &b, Array const &c) { -+ multiply_add> op; -+ return op(a, b, c); -+} -+ -+template -+CUTLASS_HOST_DEVICE -+Array fma(Array const &a, T b, Array const &c) { -+ multiply_add> op; -+ return op(a, b, c); -+} -+ -+template -+CUTLASS_HOST_DEVICE -+Array fma(Array const &a, Array const &b, T c) { -+ multiply_add> op; -+ return op(a, b, c); -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+ -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include "cutlass/array_subbyte.h" -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+// AlignedArray -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Aligned array type -+template < -+ /// Element type -+ typename T, -+ /// Number of elements in the array -+ int N, -+ /// Alignment requirement in bytes -+ int Alignment = sizeof_bits::value * N / 8 -+> -+class alignas(Alignment) AlignedArray: public Array { -+public: -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/include/cutlass/array_planar_complex.h b/3rdparty/cutlass/include/cutlass/array_planar_complex.h -new file mode 100644 -index 0000000..4503b77 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/array_planar_complex.h -@@ -0,0 +1,103 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing warp-level matrix multiply-accumulate operations. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Array holding planar complex elements -+template -+struct ArrayPlanarComplex { -+ -+ /// Underlying real element -+ using Element = Element_; -+ -+ /// Number of logical elements -+ static size_t const kElements = N; -+ -+ /// Underlying Fragment of real-valued elemenets -+ using ArrayReal = Array; -+ -+public: -+ -+ /// Fragment of real-valued elements representing the real part -+ ArrayReal real; -+ -+ /// Fragment of real-valued elements representing the imaginary part -+ ArrayReal imag; -+ -+public: -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ArrayPlanarComplex() { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ArrayPlanarComplex( -+ ArrayReal const &real_, -+ ArrayReal const &imag_ -+ ): -+ real(real_), imag(imag_) { } -+ -+ /// Sets the array to zero efficiently -+ CUTLASS_HOST_DEVICE -+ void clear() { -+ real.clear(); -+ imag.clear(); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Helper to deduce template arguments -+template -+CUTLASS_HOST_DEVICE -+ArrayPlanarComplex -+make_ArrayPlanarComplex(Array const &real, Array const &imag) { -+ return ArrayPlanarComplex(real, imag); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/array_subbyte.h b/3rdparty/cutlass/include/cutlass/array_subbyte.h -new file mode 100644 -index 0000000..ac30422 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/array_subbyte.h -@@ -0,0 +1,564 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Statically sized array of elements that accommodates all CUTLASS-supported numeric types -+ and is safe to use in a union. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/platform/platform.h" -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Statically sized array for any data type -+template < -+ typename T, -+ int N -+> -+class Array { -+public: -+ -+ static int const kSizeBits = sizeof_bits::value * N; -+ -+ /// Storage type -+ using Storage = typename platform::conditional< -+ ((kSizeBits % 32) != 0), -+ typename platform::conditional< -+ ((kSizeBits % 16) != 0), -+ uint8_t, -+ uint16_t -+ >::type, -+ uint32_t -+ >::type; -+ -+ /// Element type -+ using Element = T; -+ -+ /// Number of logical elements per stored object -+ static int const kElementsPerStoredItem = int(sizeof(Storage) * 8) / sizeof_bits::value; -+ -+ /// Number of storage elements -+ static size_t const kStorageElements = N / kElementsPerStoredItem; -+ -+ /// Number of logical elements -+ static size_t const kElements = N; -+ -+ /// Bitmask for covering one item -+ static Storage const kMask = ((Storage(1) << sizeof_bits::value) - 1); -+ -+ // -+ // C++ standard members with pointer types removed -+ // -+ -+ typedef T value_type; -+ typedef size_t size_type; -+ typedef ptrdiff_t difference_type; -+ typedef value_type *pointer; -+ typedef value_type const *const_pointer; -+ -+ // -+ // References -+ // -+ -+ /// Reference object inserts or extracts sub-byte items -+ class reference { -+ /// Pointer to storage element -+ Storage *ptr_; -+ -+ /// Index into elements packed into Storage object -+ int idx_; -+ -+ public: -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ reference(): ptr_(nullptr), idx_(0) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ reference(Storage *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } -+ -+ /// Assignment -+ CUTLASS_HOST_DEVICE -+ reference &operator=(T x) { -+ Storage item = (reinterpret_cast(x) & kMask); -+ -+ Storage kUpdateMask = Storage(~(kMask << (idx_ * sizeof_bits::value))); -+ *ptr_ = Storage(((*ptr_ & kUpdateMask) | (item << idx_ * sizeof_bits::value))); -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ T get() const { -+ Storage item = Storage((*ptr_ >> (idx_ * sizeof_bits::value)) & kMask); -+ return reinterpret_cast(item); -+ } -+ -+ /// Extract -+ CUTLASS_HOST_DEVICE -+ operator T() const { -+ return get(); -+ } -+ -+ /// Explicit cast to int -+ CUTLASS_HOST_DEVICE -+ explicit operator int() const { -+ return int(get()); -+ } -+ -+ /// Explicit cast to float -+ CUTLASS_HOST_DEVICE -+ explicit operator float() const { -+ return float(get()); -+ } -+ }; -+ -+ /// Reference object extracts sub-byte items -+ class const_reference { -+ -+ /// Pointer to storage element -+ Storage const *ptr_; -+ -+ /// Index into elements packed into Storage object -+ int idx_; -+ -+ public: -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ const_reference(): ptr_(nullptr), idx_(0) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ const_reference(Storage const *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } -+ -+ CUTLASS_HOST_DEVICE -+ const T get() const { -+ Storage item = (*ptr_ >> (idx_ * sizeof_bits::value)) & kMask; -+ return reinterpret_cast(item); -+ } -+ -+ /// Extract -+ CUTLASS_HOST_DEVICE -+ operator T() const { -+ Storage item = Storage(Storage(*ptr_ >> Storage(idx_ * sizeof_bits::value)) & kMask); -+ return reinterpret_cast(item); -+ } -+ -+ /// Explicit cast to int -+ CUTLASS_HOST_DEVICE -+ explicit operator int() const { -+ return int(get()); -+ } -+ -+ /// Explicit cast to float -+ CUTLASS_HOST_DEVICE -+ explicit operator float() const { -+ return float(get()); -+ } -+ }; -+ -+ // -+ // Iterators -+ // -+ -+ /// Bidirectional iterator over elements -+ class iterator { -+ -+ /// Pointer to storage element -+ Storage *ptr_; -+ -+ /// Index into elements packed into Storage object -+ int idx_; -+ -+ public: -+ -+ CUTLASS_HOST_DEVICE -+ iterator(): ptr_(nullptr), idx_(0) { } -+ -+ CUTLASS_HOST_DEVICE -+ iterator(Storage *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } -+ -+ CUTLASS_HOST_DEVICE -+ iterator &operator++() { -+ ++idx_; -+ if (idx_ == kElementsPerStoredItem) { -+ ++ptr_; -+ idx_ = 0; -+ } -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ iterator &operator--() { -+ if (!idx_) { -+ --ptr_; -+ idx_ = kElementsPerStoredItem - 1; -+ } -+ else { -+ --idx_; -+ } -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ iterator operator++(int) { -+ iterator ret(*this); -+ ++idx_; -+ if (idx_ == kElementsPerStoredItem) { -+ ++ptr_; -+ idx_ = 0; -+ } -+ return ret; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ iterator operator--(int) { -+ iterator ret(*this); -+ if (!idx_) { -+ --ptr_; -+ idx_ = kElementsPerStoredItem - 1; -+ } -+ else { -+ --idx_; -+ } -+ return ret; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ reference operator*() const { -+ return reference(ptr_, idx_); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool operator==(iterator const &other) const { -+ return ptr_ == other.ptr_ && idx_ == other.idx_; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool operator!=(iterator const &other) const { -+ return !(*this == other); -+ } -+ }; -+ -+ /// Bidirectional constant iterator over elements -+ class const_iterator { -+ -+ /// Pointer to storage element -+ Storage const *ptr_; -+ -+ /// Index into elements packed into Storage object -+ int idx_; -+ -+ public: -+ -+ CUTLASS_HOST_DEVICE -+ const_iterator(): ptr_(nullptr), idx_(0) { } -+ -+ CUTLASS_HOST_DEVICE -+ const_iterator(Storage const *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } -+ -+ CUTLASS_HOST_DEVICE -+ iterator &operator++() { -+ ++idx_; -+ if (idx_ == kElementsPerStoredItem) { -+ ++ptr_; -+ idx_ = 0; -+ } -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ iterator &operator--() { -+ if (!idx_) { -+ --ptr_; -+ idx_ = kElementsPerStoredItem - 1; -+ } -+ else { -+ --idx_; -+ } -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ iterator operator++(int) { -+ iterator ret(*this); -+ ++idx_; -+ if (idx_ == kElementsPerStoredItem) { -+ ++ptr_; -+ idx_ = 0; -+ } -+ return ret; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ iterator operator--(int) { -+ iterator ret(*this); -+ if (!idx_) { -+ --ptr_; -+ idx_ = kElementsPerStoredItem - 1; -+ } -+ else { -+ --idx_; -+ } -+ return ret; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_reference operator*() const { -+ return const_reference(ptr_, idx_); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool operator==(iterator const &other) const { -+ return ptr_ == other.ptr_ && idx_ == other.idx_; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool operator!=(iterator const &other) const { -+ return !(*this == other); -+ } -+ }; -+ -+ /// Bidirectional iterator over elements -+ class reverse_iterator { -+ -+ /// Pointer to storage element -+ Storage *ptr_; -+ -+ /// Index into elements packed into Storage object -+ int idx_; -+ -+ public: -+ -+ CUTLASS_HOST_DEVICE -+ reverse_iterator(): ptr_(nullptr), idx_(0) { } -+ -+ CUTLASS_HOST_DEVICE -+ reverse_iterator(Storage *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } -+ }; -+ -+ /// Bidirectional constant iterator over elements -+ class const_reverse_iterator { -+ -+ /// Pointer to storage element -+ Storage const *ptr_; -+ -+ /// Index into elements packed into Storage object -+ int idx_; -+ -+ public: -+ -+ CUTLASS_HOST_DEVICE -+ const_reverse_iterator(): ptr_(nullptr), idx_(0) { } -+ -+ CUTLASS_HOST_DEVICE -+ const_reverse_iterator(Storage const *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } -+ }; -+ -+private: -+ -+ /// Internal storage -+ Storage storage[kStorageElements]; -+ -+public: -+ -+ #if 0 -+ CUTLASS_HOST_DEVICE -+ Array() { } -+ -+ CUTLASS_HOST_DEVICE -+ Array(Array const &x) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < int(kStorageElements); ++i) { -+ storage[i] = x.storage[i]; -+ } -+ } -+ #endif -+ -+ /// Efficient clear method -+ CUTLASS_HOST_DEVICE -+ void clear() { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < int(kStorageElements); ++i) { -+ storage[i] = Storage(0); -+ } -+ } -+ -+ CUTLASS_HOST_DEVICE -+ reference at(size_type pos) { -+ return reference(storage + pos / kElementsPerStoredItem, pos % kElementsPerStoredItem); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_reference at(size_type pos) const { -+ return const_reference(storage + pos / kElementsPerStoredItem, pos % kElementsPerStoredItem); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ reference operator[](size_type pos) { -+ return at(pos); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_reference operator[](size_type pos) const { -+ return at(pos); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ reference front() { -+ return at(0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_reference front() const { -+ return at(0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ reference back() { -+ return reference(storage + kStorageElements - 1, kElementsPerStoredItem - 1); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_reference back() const { -+ return const_reference(storage + kStorageElements - 1, kElementsPerStoredItem - 1); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ pointer data() { -+ return reinterpret_cast(storage); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_pointer data() const { -+ return reinterpret_cast(storage); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Storage * raw_data() { -+ return storage; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Storage const * raw_data() const { -+ return storage; -+ } -+ -+ -+ CUTLASS_HOST_DEVICE -+ constexpr bool empty() const { -+ return !kElements; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ constexpr size_type size() const { -+ return kElements; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ constexpr size_type max_size() const { -+ return kElements; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void fill(T const &value) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kElementsPerStoredItem; ++i) { -+ reference ref(storage, i); -+ ref = value; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 1; i < kStorageElements; ++i) { -+ storage[i] = storage[0]; -+ } -+ } -+ -+ CUTLASS_HOST_DEVICE -+ iterator begin() { -+ return iterator(storage); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_iterator cbegin() const { -+ return const_iterator(storage); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ iterator end() { -+ return iterator(storage + kStorageElements); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_iterator cend() const { -+ return const_iterator(storage + kStorageElements); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ reverse_iterator rbegin() { -+ return reverse_iterator(storage + kStorageElements); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_reverse_iterator crbegin() const { -+ return const_reverse_iterator(storage + kStorageElements); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ reverse_iterator rend() { -+ return reverse_iterator(storage); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ const_reverse_iterator crend() const { -+ return const_reverse_iterator(storage); -+ } -+ -+ // -+ // Comparison operators -+ // -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/barrier.h b/3rdparty/cutlass/include/cutlass/barrier.h -new file mode 100644 -index 0000000..85a178b ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/barrier.h -@@ -0,0 +1,197 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Implementation of a CTA-wide barrier for inter-CTA synchronization. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// CTA-wide semaphore for inter-CTA synchronization. -+struct Barrier -+{ -+ -+public: -+ -+ /// Flag type -+ using T = int; -+ -+ /// Initial flag value -+ static const T INIT = 0; -+ -+ -+protected: -+ -+ /// Load flag, as a strong acquire operation (int specialization) -+ CUTLASS_DEVICE -+ static int ld_acquire(int *ptr) -+ { -+ int state = 0; -+ -+#if (__CUDA_ARCH__ >= 700) -+ /// SM70 and newer use memory consistency qualifiers -+ -+ // Acquire pattern using acquire modifier -+ asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr)); -+ -+#else -+ asm volatile ("ld.cg.global.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr)); -+#endif // (__CUDA_ARCH__ >= 700) -+ -+ return state; -+ } -+ -+ -+ /// Reduce into flag, with release pattern (int specialization) -+ CUTLASS_DEVICE -+ static void red_release(int *ptr, int val) -+ { -+#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) -+#if (__CUDA_ARCH__ >= 700) -+ /// SM70 and newer use memory consistency qualifiers -+ -+ // Release pattern using acq_rel fence + relaxed modifier. (The fence also releases data -+ // that was weakly-written by other threads prior to the last syncthreads) -+ asm volatile ("fence.acq_rel.gpu;\n"); -+ asm volatile ("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(ptr), "r"(val)); -+ -+#else -+ __threadfence(); -+ atomicAdd(ptr, val); -+#endif // (__CUDA_ARCH__ >= 700) -+#endif -+ } -+ -+ -+public: -+ -+ /// Uses thread[0] to wait for at least the specified count of signals on the given flag counter -+ CUTLASS_DEVICE -+ static void wait_lt(void *lock_ptr, int thread_idx, int flag_idx, int count) -+ { -+#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) -+ T *flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; -+ -+ if (thread_idx == 0) -+ { -+ // Spin-loop -+ #pragma unroll 1 -+ while(ld_acquire(flag_ptr) < count) {} -+ } -+ -+ __syncthreads(); -+#endif -+ } -+ -+ /// Uses thread[0] to wait for at least the specified count of signals on the given flag counter -+ CUTLASS_DEVICE -+ static void wait_eq(void *lock_ptr, int thread_idx, int flag_idx, T val = 1) -+ { -+#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) -+ T *flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; -+ -+ if (thread_idx == 0) -+ { -+ // Spin-loop -+ #pragma unroll 1 -+ while(ld_acquire(flag_ptr) != val) {} -+ } -+ __syncthreads(); -+#endif -+ } -+ -+ /// Uses thread[0] to wait for the specified count of signals on the given flag counter -+ CUTLASS_DEVICE -+ static void wait_eq_reset(void *lock_ptr, int thread_idx, int flag_idx, T val = 1) { -+#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) -+ T *flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; -+ -+ if (thread_idx == 0) -+ { -+ // Spin-loop -+ #pragma unroll 1 -+ while(atomicCAS(flag_ptr, val, 0) != val) {} -+ } -+ -+ __syncthreads(); -+#endif -+ } -+ -+ /// Increment the arrival count for a flag -+ CUTLASS_DEVICE -+ static void arrive_inc(void *lock_ptr, int thread_idx, int flag_idx) -+ { -+#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) -+ T* flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; -+ -+ __syncthreads(); -+ -+ if (thread_idx == 0) -+ { -+ red_release(flag_ptr, 1); -+ } -+#endif -+ } -+ -+ -+ /// Increment the arrival counts for a range of flags -+ CUTLASS_DEVICE -+ static void arrive_range_inc(void *lock_ptr, int thread_idx, int first_flag_idx, int count = 1) -+ { -+#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) -+ int flag_idx = first_flag_idx + thread_idx; -+ T* flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; -+ -+ // Barrier to make sure all other threads in block have written their data -+ __syncthreads(); -+ -+ // Select threads increment their flags -+ if (thread_idx < count) { -+ red_release(flag_ptr, 1); -+ } -+#endif -+ } -+}; -+ -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/bfloat16.h b/3rdparty/cutlass/include/cutlass/bfloat16.h -new file mode 100644 -index 0000000..b660cd4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/bfloat16.h -@@ -0,0 +1,500 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief Defines a proxy class for storing non-standard 16-bit floating point values with -+ 8 bits of exponent and 7 bit of mantissa. -+*/ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include "cutlass/floating_point_nvrtc.h" -+#else -+#include -+#include -+#include -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+ -+namespace cutlass { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Floating-point type with 8 bits of exponent and 7 bits of mantissa. -+struct alignas(2) bfloat16_t { -+ -+ // -+ // Data members -+ // -+ -+ /// Storage type -+ uint16_t storage; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs from an unsigned short -+ CUTLASS_HOST_DEVICE -+ static bfloat16_t bitcast(uint16_t x) { -+ bfloat16_t h; -+ h.storage = x; -+ return h; -+ } -+ -+ /// Default constructor -+ bfloat16_t() = default; -+ -+ /// Floating-point conversion - round toward nearest -+ CUTLASS_HOST_DEVICE -+ explicit bfloat16_t(float x) { -+ -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11) -+ -+ asm("cvt.rn.bf16.f32 %0, %1;\n" : "=h"(storage) : "f"(x)); -+ -+ #else -+ uint32_t bits; -+ -+ #if defined(__CUDA_ARCH__) -+ bits = reinterpret_cast(x); -+ #else -+ std::memcpy(&bits, &x, sizeof(bits)); -+ #endif -+ -+ if ((bits & 0x7f800000) != 0x7f800000) { -+ -+ bool mantissa_bit = ((bits & (1 << 16)) != 0); -+ bool round_bit = ((bits & (1 << 15)) != 0); -+ bool sticky_bit = ((bits & ((1 << 15) - 1)) != 0); -+ -+ if ((round_bit && sticky_bit) || (round_bit && mantissa_bit)) { -+ bits += uint32_t(1 << 16); -+ } -+ } -+ else if (bits & ~0xff800000) { -+ bits = 0x7fffffff; -+ } -+ -+ storage = uint16_t((bits >> 16) & 0xffff); -+ #endif -+ } -+ -+ /// Floating-point conversion - round toward nearest -+ CUTLASS_HOST_DEVICE -+ explicit bfloat16_t(double x): bfloat16_t(float(x)) { -+ -+ } -+ -+ /// Integer conversion - round toward nearest -+ CUTLASS_HOST_DEVICE -+ explicit bfloat16_t(int x) { -+ float flt = static_cast(x); -+ uint32_t bits; -+ -+ #if defined(__CUDA_ARCH__) -+ bits = reinterpret_cast(flt); -+ #else -+ std::memcpy(&bits, &flt, sizeof(bits)); -+ #endif -+ -+ storage = uint16_t(bits >> 16); -+ } -+ -+ /// Converts to float -+ CUTLASS_HOST_DEVICE -+ operator float() const { -+ unsigned bits = (unsigned(storage) << 16); -+ #if defined(__CUDA_ARCH__) -+ return reinterpret_cast(bits); -+ #else -+ float flt; -+ std::memcpy(&flt, &bits, sizeof(flt)); -+ return flt; -+ #endif -+ } -+ -+ /// Converts to float -+ CUTLASS_HOST_DEVICE -+ explicit operator double() const { -+ return double(float(*this)); -+ } -+ -+ /// Converts to int -+ CUTLASS_HOST_DEVICE -+ explicit operator int() const { -+ return int(float(*this)); -+ } -+ -+ /// Casts to bool -+ CUTLASS_HOST_DEVICE -+ explicit operator bool() const { -+ return (float(*this) != 0.0f); -+ } -+ -+ /// Obtains raw bits -+ CUTLASS_HOST_DEVICE -+ uint16_t raw() const { -+ return storage; -+ } -+ /// Returns the sign bit -+ CUTLASS_HOST_DEVICE -+ bool signbit() const { -+ return ((raw() & 0x8000) != 0); -+ } -+ -+ /// Returns the biased exponent -+ CUTLASS_HOST_DEVICE -+ int exponent_biased() const { -+ return int((raw() >> 7) & 0x0ff); -+ } -+ -+ /// Returns the unbiased exponent -+ CUTLASS_HOST_DEVICE -+ int exponent() const { -+ return exponent_biased() - 127; -+ } -+ -+ /// Returns the mantissa -+ CUTLASS_HOST_DEVICE -+ int mantissa() const { -+ return int(raw() & 0x7f); -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+CUTLASS_HOST_DEVICE -+bool signbit(cutlass::bfloat16_t const& h) { -+ return h.signbit(); -+} -+ -+CUTLASS_HOST_DEVICE -+cutlass::bfloat16_t abs(cutlass::bfloat16_t const& h) { -+ return cutlass::bfloat16_t::bitcast(h.raw() & 0x7fffffff); -+} -+ -+CUTLASS_HOST_DEVICE -+bool isnan(cutlass::bfloat16_t const& h) { -+ return (h.exponent_biased() == 0x0ff) && h.mantissa(); -+} -+ -+CUTLASS_HOST_DEVICE -+bool isfinite(cutlass::bfloat16_t const& h) { -+ return (h.exponent_biased() != 0x0ff); -+} -+ -+CUTLASS_HOST_DEVICE -+cutlass::bfloat16_t nan_bf16(const char*) { -+ // NVIDIA canonical NaN -+ return cutlass::bfloat16_t::bitcast(0x7fff); -+} -+ -+CUTLASS_HOST_DEVICE -+bool isinf(cutlass::bfloat16_t const& h) { -+ return (h.exponent_biased() == 0x0ff) && !h.mantissa(); -+} -+ -+CUTLASS_HOST_DEVICE -+bool isnormal(cutlass::bfloat16_t const& h) { -+ return h.exponent_biased() && h.exponent_biased() != 0x0ff; -+} -+ -+CUTLASS_HOST_DEVICE -+int fpclassify(cutlass::bfloat16_t const& h) { -+ int exp = h.exponent_biased(); -+ int mantissa = h.mantissa(); -+ if (exp == 0x0ff) { -+ if (mantissa) { -+ return FP_NAN; -+ } -+ else { -+ return FP_INFINITE; -+ } -+ } -+ else if (!exp) { -+ if (mantissa) { -+ return FP_SUBNORMAL; -+ } -+ else { -+ return FP_ZERO; -+ } -+ } -+ return FP_NORMAL; -+} -+ -+CUTLASS_HOST_DEVICE -+cutlass::bfloat16_t sqrt(cutlass::bfloat16_t const& h) { -+#if defined(__CUDACC_RTC__) -+ return cutlass::bfloat16_t(sqrtf(float(h))); -+#else -+ return cutlass::bfloat16_t(std::sqrt(float(h))); -+#endif -+} -+ -+CUTLASS_HOST_DEVICE -+bfloat16_t copysign(bfloat16_t const& a, bfloat16_t const& b) { -+ -+ uint16_t a_bits; -+ uint16_t b_bits; -+ -+ #if defined(__CUDA_ARCH__) -+ a_bits = reinterpret_cast(a); -+ b_bits = reinterpret_cast(b); -+ #else -+ std::memcpy(&a_bits, &a, sizeof(a_bits)); -+ std::memcpy(&b_bits, &b, sizeof(b_bits)); -+ #endif -+ -+ uint16_t a_mag = (a_bits & 0x7fff); -+ uint16_t b_sign = (b_bits & 0x8000); -+ uint16_t result = (a_mag | b_sign); -+ -+ return bfloat16_t::bitcast(result); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Standard Library operations and definitions -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace std { -+ -+#if !defined(__CUDACC_RTC__) -+/// Numeric limits -+template <> -+struct numeric_limits { -+ static bool const is_specialized = true; -+ static bool const is_signed = true; -+ static bool const is_integer = false; -+ static bool const is_exact = false; -+ static bool const has_infinity = true; -+ static bool const has_quiet_NaN = true; -+ static bool const has_signaling_NaN = false; -+ static std::float_denorm_style const has_denorm = std::denorm_present; -+ static bool const has_denorm_loss = true; -+ static std::float_round_style const round_style = std::round_to_nearest; -+ static bool const is_iec559 = false; -+ static bool const is_bounded = true; -+ static bool const is_modulo = false; -+ static int const digits = 7; -+ -+ /// Least positive value -+ CUTLASS_HOST_DEVICE -+ static cutlass::bfloat16_t min() { return cutlass::bfloat16_t::bitcast(0x01); } -+ -+ /// Minimum finite value -+ CUTLASS_HOST_DEVICE -+ static cutlass::bfloat16_t lowest() { return cutlass::bfloat16_t::bitcast(0xff7f); } -+ -+ /// Maximum finite value -+ CUTLASS_HOST_DEVICE -+ static cutlass::bfloat16_t max() { return cutlass::bfloat16_t::bitcast(0x7f7f); } -+ -+ /// Returns smallest finite value -+ CUTLASS_HOST_DEVICE -+ static cutlass::bfloat16_t epsilon() { return cutlass::bfloat16_t::bitcast(0x1000); } -+ -+ /// Returns smallest finite value -+ CUTLASS_HOST_DEVICE -+ static cutlass::bfloat16_t round_error() { return cutlass::bfloat16_t(0.5f); } -+ -+ /// Returns smallest finite value -+ CUTLASS_HOST_DEVICE -+ static cutlass::bfloat16_t infinity() { return cutlass::bfloat16_t::bitcast(0x7f80); } -+ -+ /// Returns smallest finite value -+ CUTLASS_HOST_DEVICE -+ static cutlass::bfloat16_t quiet_NaN() { return cutlass::bfloat16_t::bitcast(0x7fff); } -+ -+ /// Returns smallest finite value -+ CUTLASS_HOST_DEVICE -+ static cutlass::bfloat16_t signaling_NaN() { return cutlass::bfloat16_t::bitcast(0x7fff); } -+ -+ /// Returns smallest finite value -+ CUTLASS_HOST_DEVICE -+ static cutlass::bfloat16_t denorm_min() { return cutlass::bfloat16_t::bitcast(0x1); } -+}; -+#endif -+ -+} // namespace std -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Arithmetic operators -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+CUTLASS_HOST_DEVICE -+bool operator==(bfloat16_t const& lhs, bfloat16_t const& rhs) { -+ return float(lhs) == float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator!=(bfloat16_t const& lhs, bfloat16_t const& rhs) { -+ return float(lhs) != float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator<(bfloat16_t const& lhs, bfloat16_t const& rhs) { -+ return float(lhs) < float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator<=(bfloat16_t const& lhs, bfloat16_t const& rhs) { -+ return float(lhs) <= float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator>(bfloat16_t const& lhs, bfloat16_t const& rhs) { -+ return float(lhs) > float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator>=(bfloat16_t const& lhs, bfloat16_t const& rhs) { -+ return float(lhs) >= float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bfloat16_t operator+(bfloat16_t const& lhs, bfloat16_t const& rhs) { -+ return bfloat16_t(float(lhs) + float(rhs)); -+} -+ -+CUTLASS_HOST_DEVICE -+bfloat16_t operator-(bfloat16_t const& lhs) { -+ return bfloat16_t(-float(lhs)); -+} -+ -+CUTLASS_HOST_DEVICE -+bfloat16_t operator-(bfloat16_t const& lhs, bfloat16_t const& rhs) { -+ return bfloat16_t(float(lhs) - float(rhs)); -+} -+ -+CUTLASS_HOST_DEVICE -+bfloat16_t operator*(bfloat16_t const& lhs, bfloat16_t const& rhs) { -+ return bfloat16_t(float(lhs) * float(rhs)); -+} -+ -+CUTLASS_HOST_DEVICE -+bfloat16_t operator/(bfloat16_t const& lhs, bfloat16_t const& rhs) { -+ return bfloat16_t(float(lhs) / float(rhs)); -+} -+ -+CUTLASS_HOST_DEVICE -+bfloat16_t& operator+=(bfloat16_t & lhs, bfloat16_t const& rhs) { -+ lhs = bfloat16_t(float(lhs) + float(rhs)); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+bfloat16_t& operator-=(bfloat16_t & lhs, bfloat16_t const& rhs) { -+ lhs = bfloat16_t(float(lhs) - float(rhs)); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+bfloat16_t& operator*=(bfloat16_t & lhs, bfloat16_t const& rhs) { -+ lhs = bfloat16_t(float(lhs) * float(rhs)); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+bfloat16_t& operator/=(bfloat16_t & lhs, bfloat16_t const& rhs) { -+ lhs = bfloat16_t(float(lhs) / float(rhs)); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+bfloat16_t& operator++(bfloat16_t & lhs) { -+ float tmp(lhs); -+ ++tmp; -+ lhs = bfloat16_t(tmp); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+bfloat16_t& operator--(bfloat16_t & lhs) { -+ float tmp(lhs); -+ --tmp; -+ lhs = bfloat16_t(tmp); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+bfloat16_t operator++(bfloat16_t & lhs, int) { -+ bfloat16_t ret(lhs); -+ float tmp(lhs); -+ tmp++; -+ lhs = bfloat16_t(tmp); -+ return ret; -+} -+ -+CUTLASS_HOST_DEVICE -+bfloat16_t operator--(bfloat16_t & lhs, int) { -+ bfloat16_t ret(lhs); -+ float tmp(lhs); -+ tmp--; -+ lhs = bfloat16_t(tmp); -+ return ret; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// User-defined literals -+// -+ -+CUTLASS_HOST_DEVICE -+cutlass::bfloat16_t operator "" _bf16(long double x) { -+ return cutlass::bfloat16_t(float(x)); -+} -+ -+CUTLASS_HOST_DEVICE -+cutlass::bfloat16_t operator "" _bf16(unsigned long long int x) { -+ return cutlass::bfloat16_t(int(x)); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/blas3.h b/3rdparty/cutlass/include/cutlass/blas3.h -new file mode 100644 -index 0000000..f5f8a09 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/blas3.h -@@ -0,0 +1,176 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Basic include for CUTLASS BLAS3/HPC code. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_types.h" -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Enumerated type describing the type of kernel (based on input or output matrices). -+enum class BlasMode { -+ kGemm, -+ kSymmetric, -+ kHermitian, -+ kTriangular, -+ kInvalid -+}; -+ -+/// Enumerated type describing the fill mode for matrices for BLAS functions. -+enum class FillMode { -+ kFull, /// The entire tensor is covered. -+ kLower, /// The 'lower' part of a tensor is covered including diagonal -+ kUpper, /// The 'upper' part of a tensor is covered including diaognal -+ kDiagonal, /// Only diagonal elements are covered. -+ kNone, /// No element is covered. -+ kInvalid -+}; -+ -+/// Enumerated type describing the diagonal property of matrices for BLAS functions. -+enum class DiagType { -+ kNonUnit, -+ kUnit, -+ kZero, // Only used internally for computing SYMM/HEMM -+ kInvalid -+}; -+ -+/// Enumerated type describing the side dense matrix is in matrix equation for BLAS functions. -+enum class SideMode { -+ kLeft, -+ kRight, -+ kInvalid -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines FillMode inversions -+template -+struct InvertFillMode; -+ -+/// Invert FillMode lower to upper -+template <> -+struct InvertFillMode { -+ static FillMode const mode = FillMode::kUpper; -+}; -+ -+/// Invert FillMode upper to lower -+template <> -+struct InvertFillMode { -+ static FillMode const mode = FillMode::kLower; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines SideMode inversions -+template -+struct InvertSideMode; -+ -+/// Invert SideMode left to right -+template <> -+struct InvertSideMode { -+ static SideMode const mode = SideMode::kRight; -+}; -+ -+/// Invert SideMode right to left -+template <> -+struct InvertSideMode { -+ static SideMode const mode = SideMode::kLeft; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines correct compare operation for Triangular matrix boundary -+template -+struct TrMatrixCompareOp { -+ using Index = int32_t; -+ using Type = typename platform::conditional< -+ (kFillMode == FillMode::kLower), -+ greater_equal, -+ less_equal>::type; -+}; -+ -+template -+struct TrMatrixCompareOp { -+ using Index = int32_t; -+ using Type = typename platform::conditional< -+ (kFillMode == FillMode::kLower), -+ greater_equal, -+ less_equal>::type; -+}; -+ -+template -+struct TrMatrixCompareOp { -+ using Index = int32_t; -+ using Type = typename platform::conditional< -+ (kFillMode == FillMode::kLower), -+ greater, -+ less>::type; -+}; -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+// Returns precision in terms of bits (based on datatype) to fill tensors with. -+// Defaults to 5 bits of mantissa for TF32 and FP32 (with implicit round-offs). -+// Also defines acceptable mantissa result variance/error. -+template -+struct MantissaInBits { -+ static int constexpr bits = 5; -+ static double constexpr error = 1.0e-7; -+}; -+ -+// Full precision is supported for FP64 -+template <> -+struct MantissaInBits { -+ static int constexpr bits = 30; -+ static double constexpr error = 1.0e-15; -+}; -+ -+template <> -+struct MantissaInBits> { -+ static int constexpr bits = 30; -+ static double constexpr error = 1.0e-15; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/include/cutlass/block_striped.h b/3rdparty/cutlass/include/cutlass/block_striped.h -new file mode 100644 -index 0000000..563e619 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/block_striped.h -@@ -0,0 +1,267 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Utilities for performing block-striped access (load, store, reduce) of trivially-copyable, -+ statically-sized array types to global memory. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/wmma_array.h" -+#include "cutlass/functional.h" -+#include "cutlass/complex.h" -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// AccessWidth -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes the maximal power-of-two that evenly divides the size of T, capped at Limit -+template < -+ typename T, -+ int Limit> -+struct AccessWidth -+{ -+ // Inductive case -+ template < -+ int ObjectBytes, /// Size of T in bytes -+ int AlignBytes, /// Template induction variable -+ bool IsAligned = /// Whether ObjectBytes is an even multiple of AlignBytes -+ ((AlignBytes <= Limit) && (ObjectBytes % AlignBytes == 0))> -+ struct Detail -+ { -+ static const int value = Detail::value; -+ }; -+ -+ // Base case (ObjectBytes is not an even multiple of AlignBytes) -+ template < -+ int ObjectBytes, /// Size of T in bytes -+ int AlignBytes> /// Template induction variable -+ struct Detail -+ { -+ static const int value = AlignBytes / 2; -+ }; -+ -+ /// The maximal power-of-two that evenly divides the size of T -+ static const int value = Detail< -+ (int) sizeof(T), -+ 1>::value; -+}; -+ -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// StripedAccessType -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// ReinterpretCast type for striping a trivially-copyable type in global memory -+/// (Default specialization. Striping granularity is type T.) -+template < -+ typename T, /// Data type -+ int TransferBytes = /// Data access width (16 byte max for global memory access on current architectures) -+ AccessWidth::value> -+struct alignas(TransferBytes) StripedAccessType : public T -+{}; -+ -+ -+/// ReinterpretCast type for striping a trivially-copyable type in global memory -+/// (Specialization for cutlass::Array. Striping granularity is a multiple of T.) -+template < -+ typename T, /// Array element type -+ int N, /// Number of elements in array -+ bool RegisterSized, /// T is register-sized -+ int TransferBytes> /// Data access width -+struct StripedAccessType< -+ Array, -+ TransferBytes> -+: public AlignedArray< -+ T, // Element type of StripedAccessType -+ __NV_STD_MAX(1, TransferBytes / (int) sizeof(T)), // Number of elements T in StripedAccessType -+ TransferBytes> // Alignment of StripedAccessType -+{}; -+ -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+ -+/// ReinterpretCast type for striping a trivially-copyable type in global memory -+/// (Specialization for cutlass::WmmaFragmentArray. Striping granularity is a multiple of T.) -+template< -+ typename Use, -+ int m, -+ int n, -+ int k, -+ typename ElementT, -+ typename Layout, -+ int kFragments, -+ int TransferBytes> -+struct StripedAccessType< -+ WmmaFragmentArray, kFragments>, -+ TransferBytes> -+: public AlignedArray< -+ ElementT, -+ __NV_STD_MAX(1, TransferBytes / (int) sizeof(ElementT)), -+ TransferBytes> -+{}; -+ -+#endif // if defined(CUTLASS_ARCH_WMMA_ENABLED) -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// BlockStriped -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Utility for performing block-striped access (load, store) of trivially-copyable, -+/// statically-sized array types to global memory -+template < -+ int BlockThreads, -+ typename ArrayT, -+ typename AccessT = StripedAccessType > -+struct BlockStriped -+{ -+ /// Number of striped accesses -+ static const int kStripes = int(sizeof(ArrayT) / sizeof(AccessT)); -+ static_assert(kStripes > 0, "AccessT type must be smaller than or equal to ArrayT type"); -+ -+ /// Load -+ CUTLASS_DEVICE -+ static void load(ArrayT &data, ArrayT *ptr, int thread_idx) -+ { -+ AccessT *access_input = reinterpret_cast(ptr); -+ AccessT *access_data = reinterpret_cast(&data); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kStripes; ++i) { -+ access_data[i] = access_input[(BlockThreads * i) + thread_idx]; -+ } -+ } -+ -+ /// Load & Add -+ CUTLASS_DEVICE -+ static void load_add(ArrayT &data, ArrayT *ptr, int thread_idx) -+ { -+ AccessT *access_input = reinterpret_cast(ptr); -+ AccessT *access_data = reinterpret_cast(&data); -+ -+ plus add; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kStripes; ++i) -+ { -+ access_data[i] = add(access_data[i], access_input[(BlockThreads * i) + thread_idx]); -+ } -+ } -+ -+ /// Store -+ CUTLASS_DEVICE -+ static void store(ArrayT *ptr, const ArrayT &data, int thread_idx) -+ { -+ AccessT *access_output = reinterpret_cast(ptr); -+ const AccessT *access_data = reinterpret_cast(&data); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kStripes; ++i) { -+ access_output[(BlockThreads * i) + thread_idx] = access_data[i]; -+ } -+ } -+ -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// BlockStripedReduce -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Utility for performing block-striped access (load, store, reduce) of trivially-copyable, -+/// statically-sized array types to global memory. -+/// (Default specialization) -+template < -+ int BlockThreads, -+ typename ArrayT, -+ typename ElementT = typename StripedAccessType::Element> -+struct BlockStripedReduce : -+ BlockStriped< -+ BlockThreads, -+ ArrayT, -+ ElementT> -+{ -+ /// Reduce -+ CUTLASS_DEVICE -+ static void reduce(ArrayT *ptr, const ArrayT &data, int thread_idx) -+ { -+ cutlass::red reduce; -+ ElementT *access_output = reinterpret_cast(ptr); -+ const ElementT *access_data = reinterpret_cast(&data); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < BlockStripedReduce::kStripes; ++i) { -+ reduce(access_output + (BlockThreads * i) + thread_idx, access_data[i]); -+ } -+ } -+}; -+ -+ -+/// Utility for performing block-striped access (load, store, reduce) of trivially-copyable, -+/// statically-sized array types to global memory. -+/// (Specialization for half_t. Uses half2 vectorized-reduction.) -+template < -+ int BlockThreads, -+ typename ArrayT> -+struct BlockStripedReduce : -+ BlockStriped< -+ BlockThreads, -+ ArrayT, -+ half2> -+{ -+ static_assert(BlockStripedReduce::kStripes % 2 == 0, "Array of half must be even number in length"); -+ -+ /// Reduce -+ CUTLASS_DEVICE -+ static void reduce(ArrayT *ptr, const ArrayT &data, int thread_idx) -+ { -+ cutlass::red reduce; -+ half2 *access_output = reinterpret_cast(ptr); -+ const half2 *access_data = reinterpret_cast(&data); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < BlockStripedReduce::kStripes; ++i) -+ { -+ reduce(access_output + (BlockThreads * i) + thread_idx, access_data[i]); -+ } -+ } -+}; -+ -+ -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/include/cutlass/cluster_launch.hpp b/3rdparty/cutlass/include/cutlass/cluster_launch.hpp -new file mode 100644 -index 0000000..4843540 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/cluster_launch.hpp -@@ -0,0 +1,156 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief PTX for TMA Tensor Memory Access operators on memory added for SM90 -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include "cutlass/cutlass.h" -+#include "cutlass/trace.h" -+ -+#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))) -+# define CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED -+#endif -+ -+namespace cutlass { -+ -+#ifndef NDEBUG -+#define Return_Status(cudaError_t_status) \ -+ if (cudaError_t_status != cudaSuccess) { \ -+ fprintf(stderr, \ -+ "[ ERROR: CUDA Runtime ] %s:%d: %s\n", \ -+ __FILE__, \ -+ __LINE__, \ -+ cudaGetErrorString(cudaError_t_status)); \ -+ return Status::kInvalid; \ -+ } else { \ -+ return Status::kSuccess; \ -+ } -+#else -+#define Return_Status(cudaError_t_status) \ -+ if (cudaError_t_status != cudaSuccess) { \ -+ return Status::kInvalid; \ -+ } else { \ -+ return Status::kSuccess; \ -+ } -+#endif -+ -+struct ClusterLauncher { -+ constexpr static int MaxClusterSize = 32; -+ -+ // Check for hardware compatibility -+ static inline __host__ -+ Status check_cluster_dims(dim3 const& grid, dim3 const& cluster) { -+ if (((cluster.x * cluster.y * cluster.z) <= MaxClusterSize) && -+ (grid.x % cluster.x == 0) && (grid.y % cluster.y == 0) && (grid.z % cluster.z == 0)) { -+ return Status::kSuccess; -+ } -+ else { -+ CUTLASS_TRACE_HOST("ClusterLauncher: Invalid cluster configuration -- aborting launch."); -+ return Status::kInvalid; -+ } -+ } -+ -+ static inline __host__ -+ Status -+#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) -+ init(void const* kernel_function) -+#else -+ init(void const* /* kernel_function */) -+#endif -+ { -+#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) -+ // This attribute was added in CUDA 11.8. -+ cudaError_t status = -+ cudaFuncSetAttribute( -+ kernel_function, cudaFuncAttributeNonPortableClusterSizeAllowed, 1); -+ Return_Status(status); -+#else -+ return Status::kInvalid; -+#endif -+ } -+ -+ // This is the method we expect to use going forward -+ static inline __host__ -+ Status launch( -+ dim3 const& grid_dims, -+ dim3 const& cluster_dims, -+ dim3 const& block_dims, -+ size_t const& smem_size, -+ cudaStream_t& cuda_stream, -+ void const* kernel, -+ void** kernel_params) { -+#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) -+ if (check_cluster_dims(grid_dims, cluster_dims) != Status::kSuccess) { -+ CUTLASS_TRACE_HOST("ClusterLauncher: check_cluster_dims() failed. Aborting."); -+ return Status::kInvalid; -+ } -+ -+ auto init_status = init(kernel); -+ if (init_status != Status::kSuccess) { -+ CUTLASS_TRACE_HOST("ClusterLauncher: init(kernel) failed with status " << int(init_status) << ". Aborting."); -+ return Status::kInvalid; -+ } -+ -+ cudaLaunchConfig_t launch_config; -+ launch_config.gridDim = {grid_dims.x, grid_dims.y, grid_dims.z}; -+ launch_config.blockDim = {block_dims.x, block_dims.y, block_dims.z}; -+ launch_config.dynamicSmemBytes = smem_size; -+ launch_config.stream = cuda_stream; -+ -+ cudaLaunchAttribute launch_attribute[1]; -+ launch_attribute[0].id = cudaLaunchAttributeClusterDimension; -+ launch_attribute[0].val.clusterDim.x = cluster_dims.x; -+ launch_attribute[0].val.clusterDim.y = cluster_dims.y; -+ launch_attribute[0].val.clusterDim.z = cluster_dims.z; -+ -+ launch_config.attrs = launch_attribute; -+ launch_config.numAttrs = 1; -+ -+ CUTLASS_TRACE_HOST("ClusterLauncher: Launching GPC_CLUSTER_GRID GridDims = " -+ "(" << grid_dims.x << ", " << grid_dims.y << ", " << grid_dims.z << "), " -+ "And ClusterDims = " -+ "(" << cluster_dims.x << ", " << cluster_dims.y << ", " << cluster_dims.z << ")\n"); -+ -+ cudaError_t status = cudaLaunchKernelExC(&launch_config, kernel, kernel_params); -+ Return_Status(status); -+#else -+ CUTLASS_TRACE_HOST("ClusterLauncher: CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED not defined! Aborting cluster launch."); -+ return Status::kInvalid; -+#endif -+ } -+}; -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/complex.h b/3rdparty/cutlass/include/cutlass/complex.h -new file mode 100644 -index 0000000..089f474 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/complex.h -@@ -0,0 +1,705 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/functional.h" -+#include "cutlass/half.h" -+#include "cutlass/real.h" -+ -+#include "cutlass/bfloat16.h" -+#include "cutlass/tfloat32.h" -+ -+#include "cutlass/fast_math.h" -+ -+#if !defined(__CUDACC_RTC__) -+#include -+#endif -+ -+namespace cutlass { -+ -+ -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Enumeraed type describing a transformation on a complex value. -+enum class ComplexTransform { -+ kNone, -+ kConjugate -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines ComplexTransform inversions -+template -+struct InvertComplexTransform; -+ -+/// Invert ComplexTransform from kNone to kConjugate -+template <> -+struct InvertComplexTransform { -+ static ComplexTransform const transform = ComplexTransform::kConjugate; -+}; -+ -+/// Invert ComplexTransform from kConjugate to kNone -+template <> -+struct InvertComplexTransform { -+ static ComplexTransform const transform = ComplexTransform::kNone; -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Accessors for CUDA complex types -+// -+ -+#if !defined(__CUDACC_RTC__) -+/// Returns the real part of the complex number -+CUTLASS_HOST_DEVICE -+float const &real(cuFloatComplex const &z) { return z.x; } -+ -+/// Returns the real part of the complex number -+CUTLASS_HOST_DEVICE -+float &real(cuFloatComplex &z) { return z.x; } -+ -+/// Returns the real part of the complex number -+CUTLASS_HOST_DEVICE -+double const &real(cuDoubleComplex const &z) { return z.x; } -+ -+/// Returns the real part of the complex number -+CUTLASS_HOST_DEVICE -+double &real(cuDoubleComplex &z) { return z.x; } -+ -+/// Returns the imaginary part of the complex number -+CUTLASS_HOST_DEVICE -+float const &imag(cuFloatComplex const &z) { return z.y; } -+ -+/// Returns the imaginary part of the complex number -+CUTLASS_HOST_DEVICE -+float &imag(cuFloatComplex &z) { return z.y; } -+ -+/// Returns the imaginary part of the complex number -+CUTLASS_HOST_DEVICE -+double const &imag(cuDoubleComplex const &z) { return z.y; } -+ -+/// Returns the imaginary part of the complex number -+CUTLASS_HOST_DEVICE -+double &imag(cuDoubleComplex &z) { return z.y; } -+#endif -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Class for representing and manipulating complex numbers with conversions from built-in CUDA -+/// complex types. -+ -+template -+class complex -+{ -+ public: -+ /// Type alias for scalar type -+ using value_type = T; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Real part -+ T _real; -+ -+ /// Imaginary part -+ T _imag; -+ -+ public: -+ -+// -+// Methods -+// -+ -+ /// Default constructor -+ complex() = default; -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ complex(T r) : _real(r), _imag(T(0)) {} -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ complex(T r, T i) : _real(r), _imag(i) {} -+ -+ /// Constructor -+ template -+ CUTLASS_HOST_DEVICE -+ complex(complex const &z) : _real(static_cast(z.real())), _imag(static_cast(z.imag())) {} -+ -+ -+ #if !defined(__CUDACC_RTC__) -+ /// Conversion from cuFloatComplex -+ CUTLASS_HOST_DEVICE -+ complex(cuFloatComplex const &z) : _real(static_cast(cuCrealf(z))), _imag(static_cast(cuCimagf(z))) {} -+ -+ /// Conversion from cuDoubleComplex -+ CUTLASS_HOST_DEVICE -+ complex(cuDoubleComplex const &z) : _real(static_cast(cuCreal(z))), _imag(static_cast(cuCimag(z))) {} -+ #endif -+ -+ /// Assignment -+ template -+ CUTLASS_HOST_DEVICE -+ complex& operator=(complex const &z) -+ { -+ _real = static_cast(z.real()); -+ _imag = static_cast(z.imag()); -+ return *this; -+ } -+ -+ /// Equality operator -+ CUTLASS_HOST_DEVICE bool operator==(complex const &rhs) const { -+ return this->real() == rhs.real() && this->imag() == rhs.imag(); -+ } -+ -+ /// Inequality operator -+ CUTLASS_HOST_DEVICE bool operator!=(complex const &rhs) const { -+ return !(*this == rhs); -+ } -+ -+ /// Addition -+ template -+ CUTLASS_HOST_DEVICE complex operator+(complex const &rhs) const { -+ return complex(this->real() + rhs.real(), this->imag() + rhs.imag()); -+ } -+ -+ /// Reduction into memory address. Components may update out of order. -+ template -+ CUTLASS_DEVICE void red(complex *ptr) const { -+ static_assert(platform::is_same::value, "Component type must match"); -+ cutlass::red reduce; -+ reduce(&ptr->_real, _real); -+ reduce(&ptr->_imag, _imag); -+ } -+ -+ /// Reduction into memory address. Components may update out of order. (Half specialization) -+ CUTLASS_DEVICE void red(complex *ptr) const { -+ static_assert(platform::is_same::value, "Component type must match"); -+ half2 *h2_ptr = reinterpret_cast(ptr); -+ half2 h2_data = reinterpret_cast(*this); -+ cutlass::red reduce; -+ reduce(h2_ptr, h2_data); -+ } -+ -+ /// Subtraction -+ template -+ CUTLASS_HOST_DEVICE complex operator-(complex const &rhs) const { -+ return complex(this->real() - rhs.real(), this->imag() - rhs.imag()); -+ } -+ -+ /// Multiplication -+ template -+ CUTLASS_HOST_DEVICE complex operator*(complex const &rhs) const { -+ return complex(this->real() * rhs.real() - this->imag() * rhs.imag(), -+ this->real() * rhs.imag() + this->imag() * rhs.real()); -+ } -+ -+ /// Scalar Multiplication -+ template -+ CUTLASS_HOST_DEVICE complex operator*(A const &s) const { -+ return complex(this->real() * s, this->imag() * s); -+ } -+ -+ /// Division -+ template -+ CUTLASS_HOST_DEVICE complex operator/(complex const &rhs) const { -+ T d = T(rhs.real() * rhs.real() + rhs.imag() * rhs.imag()); -+ -+ return complex( -+ (real() * rhs.real() + imag() * rhs.imag()) / d, -+ (imag() * rhs.real() - real() * rhs.imag()) / d -+ ); -+ } -+ -+ /// Scalar Division -+ template -+ CUTLASS_HOST_DEVICE complex operator/(A const &s) const { -+ return complex(this->real() / s, this->imag() / s); -+ } -+ -+ /// Addition -+ template -+ CUTLASS_HOST_DEVICE complex &operator+=(complex const &rhs) { -+ *this = *this + rhs; -+ return *this; -+ } -+ -+ /// Subtraction -+ template -+ CUTLASS_HOST_DEVICE complex &operator-=(complex const &rhs) { -+ *this = *this - rhs; -+ return *this; -+ } -+ -+ /// Multiplication -+ template -+ CUTLASS_HOST_DEVICE complex &operator*=(complex const &rhs) { -+ *this = *this * rhs; -+ return *this; -+ } -+ -+ /// Scalar multiplication -+ template -+ CUTLASS_HOST_DEVICE complex &operator*=(A s) { -+ *this = *this * s; -+ return *this; -+ } -+ -+ /// Division -+ template -+ CUTLASS_HOST_DEVICE complex &operator/=(complex const &rhs) { -+ *this = *this / rhs; -+ return *this; -+ } -+ -+ /// Accesses the real part of the complex number -+ CUTLASS_HOST_DEVICE -+ T const &real() const { return _real; } -+ -+ /// Accesses the real part of the complex number -+ CUTLASS_HOST_DEVICE -+ T &real() { return _real; } -+ -+ /// Accesses the imaginary part of the complex number -+ CUTLASS_HOST_DEVICE -+ T const &imag() const { return _imag; } -+ -+ /// Accesses the imaginary part of the complex number -+ CUTLASS_HOST_DEVICE -+ T &imag() { return _imag; } -+ -+ -+ #if !defined(__CUDACC_RTC__) -+ /// Converts to cuFloatComplex -+ CUTLASS_HOST_DEVICE -+ explicit operator cuFloatComplex() const { return make_cuFloatComplex(float(real()), float(imag())); } -+ -+ /// Converts to cuDoubleComplex -+ CUTLASS_HOST_DEVICE -+ explicit operator cuDoubleComplex() const { return make_cuDoubleComplex(real(), imag()); } -+ #endif -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Accessors for complex template -+// -+ -+/// Returns the real part of the complex number -+template -+CUTLASS_HOST_DEVICE T const &real(complex const &z) { -+ return z.real(); -+} -+ -+/// Returns the real part of the complex number -+template -+CUTLASS_HOST_DEVICE T &real(complex &z) { -+ return z.real(); -+} -+ -+/// Returns the imaginary part of the complex number -+template -+CUTLASS_HOST_DEVICE T const &imag(complex const &z) { -+ return z.imag(); -+} -+ -+/// Returns the imaginary part of the complex number -+template -+CUTLASS_HOST_DEVICE T &imag(complex &z) { -+ return z.imag(); -+} -+ -+/// Returns the real part of the real number -+template -+CUTLASS_HOST_DEVICE T const &real(T const &r) { -+ return r; -+} -+ -+/// Returns the real part of the real number -+template -+CUTLASS_HOST_DEVICE T &real(T &r) { -+ return r; -+} -+ -+/// Returns the imaginary part of the real number -+template -+CUTLASS_HOST_DEVICE T const &imag(T const &r) { -+ return T(); -+} -+ -+/// Returns the imaginary part of the complex number -+template -+CUTLASS_HOST_DEVICE T &imag(T &r) { -+ return T(); -+} -+ -+// -+// Output operators -+// -+ -+#if !defined(__CUDACC_RTC__) -+template -+std::ostream &operator<<(std::ostream &out, complex const &z) { -+ T _r = real(z); -+ T _i = imag(z); -+ -+ if (bool(_i)) { -+ return out << _r << "+i" << _i; -+ } -+ return out << _r; -+} -+#endif -+ -+// -+// Non-member operators defined for complex types -+// -+ -+ -+// -+// Non-member functions defined for complex numbers -+// -+ -+/// Returns the magnitude of the complex number -+template -+CUTLASS_HOST_DEVICE T abs(complex const &z) { -+ return sqrt(norm(z)); -+} -+ -+/// Returns the magnitude of the complex number -+template -+CUTLASS_HOST_DEVICE T arg(complex const &z) { -+ return atan2(imag(z), real(z)); -+} -+ -+/// Returns the squared magnitude of a real number -+template -+CUTLASS_HOST_DEVICE T norm(T const &z) { -+ return z * z; -+} -+ -+/// Returns the squared magnitude of a real number -+template <> -+CUTLASS_HOST_DEVICE int8_t norm(int8_t const &z) { -+ return static_cast(z * z); -+} -+ -+/// Returns the squared magnitude of a complex number -+template -+CUTLASS_HOST_DEVICE double norm(complex const &z) { -+ return real(z) * real(z) + imag(z) * imag(z); -+} -+ -+/// Norm-accumulate calculation -+template -+CUTLASS_HOST_DEVICE R norm_accumulate(T const &x, R const & accumulator) { -+ return accumulator + static_cast(x) * static_cast(x); -+} -+ -+/// Norm accumulate specialized for complex types -+template -+CUTLASS_HOST_DEVICE R norm_accumulate(complex const &z, R const &accumulator) { -+ return accumulator + static_cast(real(z)) * static_cast(real(z)) + -+ static_cast(imag(z)) * static_cast(imag(z)); -+} -+ -+/// Returns the complex conjugate -+CUTLASS_HOST_DEVICE float conj(float const &z) { -+ return z; -+} -+ -+/// Returns the complex conjugate -+CUTLASS_HOST_DEVICE double conj(double const &z) { -+ return z; -+} -+ -+/// Returns the complex conjugate -+template -+CUTLASS_HOST_DEVICE complex conj(complex const &z) { -+ return complex(real(z), -imag(z)); -+} -+/// Indentity transform for non-complex types -+template -+CUTLASS_HOST_DEVICE T conj(T const &z) { -+ static_assert( !platform::is_same::value && -+ !platform::is_same::value && -+ !platform::is_same>::value && -+ !platform::is_same>::value, "May not be a complex data type"); -+ return z; -+} -+ -+/// Projects the complex number z onto the Riemann sphere -+template -+CUTLASS_HOST_DEVICE complex proj(complex const &z) { -+ T d = real(z) * real(z) + imag(z) * imag(z) + T(1); -+ return complex((T(2) * real(z)) / d, (T(2) * imag(z)) / d); -+} -+ -+/// Returns a complex number with magnitude r and phase theta -+template -+CUTLASS_HOST_DEVICE complex polar(T const &r, T const &theta = T()) { -+ return complex(r * cos(theta), r * sin(theta)); -+} -+ -+/// Computes the complex exponential of z. -+template -+CUTLASS_HOST_DEVICE complex exp(complex const &z) { -+ return complex(fast_exp(real(z)) * fast_cos(imag(z)), fast_exp(real(z)) * fast_sin(imag(z))); -+} -+ -+/// Computes the log of z -+template -+CUTLASS_HOST_DEVICE complex log(complex const &z) { -+ return complex(log(abs(z)), arg(z)); -+} -+ -+/// Computes the log base 10 of z -+template -+CUTLASS_HOST_DEVICE complex log10(complex const &z) { -+ return log(z) / T(log(T(10))); -+} -+ -+/// Computes the square root of complex number z -+template -+CUTLASS_HOST_DEVICE complex sqrt(complex const &z) { -+ return sqrt(T(2)) / T(2) * -+ complex(sqrt(sqrt(norm(z)) + real(z)), -+ (imag(z) < 0 ? T(-1) : T(1)) * sqrt(sqrt(norm(z)) - real(z))); -+} -+ -+/// Computes the cosine of complex z. -+template -+CUTLASS_HOST_DEVICE complex cos(complex const &z) { -+ return (exp(z) + exp(-z)) / T(2); -+} -+ -+/// Computes the sin of complex z. -+template -+CUTLASS_HOST_DEVICE complex sin(complex const &z) { -+ return (exp(-z) - exp(z)) * complex(T(0), T(1) / T(2)); -+} -+ -+/// Comparison -+template -+CUTLASS_HOST_DEVICE bool operator<(complex const &lhs, complex const &rhs) { -+ //TODO -+ return true; -+} -+ -+////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for complex-valued type. -+template -+struct RealType< complex > -+{ -+ using Type = T; -+ -+ /// Number of elements -+ static int const kExtent = 2; -+ -+ CUTLASS_HOST_DEVICE -+ static complex from_real(double x) { -+ return complex(static_cast(x)); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+CUTLASS_HOST_DEVICE -+cutlass::complex from_real >(double r) { -+ return cutlass::complex(half_t(r)); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+cutlass::complex from_real >(double r) { -+ return cutlass::complex(float(r)); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+cutlass::complex from_real >(double r) { -+ return cutlass::complex(r); -+} -+ -+////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct is_complex { -+ static bool const value = false; -+}; -+ -+template -+struct is_complex> { -+ static bool const value = true; -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// functional.h numeric specializations -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Squares with optional conversion -+template -+struct magnitude_squared, Output> { -+ CUTLASS_HOST_DEVICE -+ Output operator()(complex lhs) const { -+ multiplies mul_op; -+ -+ Output y_r = Output(lhs.real()); -+ Output y_i = Output(lhs.imag()); -+ -+ return mul_op(y_r, y_r) + mul_op(y_i, y_i); -+ } -+}; -+ -+/// Fused multiply-add -+template -+struct multiply_add, complex, complex> { -+ CUTLASS_HOST_DEVICE -+ complex operator()( -+ complex const &a, -+ complex const &b, -+ complex const &c) const { -+ -+ T real = c.real(); -+ T imag = c.imag(); -+ -+ real += a.real() * b.real(); -+ real += -a.imag() * b.imag(); -+ imag += a.real() * b.imag(); -+ imag += a.imag () * b.real(); -+ -+ return complex{ -+ real, -+ imag -+ }; -+ } -+}; -+ -+/// Fused multiply-add -+template -+struct multiply_add, T, complex> { -+ CUTLASS_HOST_DEVICE -+ complex operator()( -+ complex const &a, -+ T const &b, -+ complex const &c) const { -+ -+ T real = c.real(); -+ T imag = c.imag(); -+ -+ real += a.real() * b; -+ imag += a.imag () * b; -+ -+ return complex{ -+ real, -+ imag -+ }; -+ } -+}; -+ -+/// Fused multiply-add -+template -+struct multiply_add, complex> { -+ CUTLASS_HOST_DEVICE -+ complex operator()( -+ T const &a, -+ complex const &b, -+ complex const &c) const { -+ -+ T real = c.real(); -+ T imag = c.imag(); -+ -+ real += a * b.real(); -+ imag += a * b.imag(); -+ -+ return complex{ -+ real, -+ imag -+ }; -+ } -+}; -+ -+/// Conjugate -+template -+struct conjugate> { -+ CUTLASS_HOST_DEVICE -+ complex operator()(complex const &a) const { -+ return conj(a); -+ } -+}; -+ -+/// Computes the square of a difference with optional conversion -+template -+struct magnitude_squared_difference, Output> { -+ CUTLASS_HOST_DEVICE -+ Output operator()(complex lhs, complex rhs) const { -+ multiplies mul_op; -+ -+ Output y_r = Output(lhs.real()) - Output(rhs.real()); -+ Output y_i = Output(lhs.imag()) - Output(rhs.imag()); -+ -+ return mul_op(y_r, y_r) + mul_op(y_i, y_i); -+ } -+}; -+ -+/// Reduces value into the data pointed to by ptr (complex specialization) -+template -+struct red> { -+ CUTLASS_DEVICE -+ void operator()(complex *ptr, const complex &data) -+ { -+ data.red(ptr); -+ } -+}; -+ -+ -+////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/constants.h b/3rdparty/cutlass/include/cutlass/constants.h -new file mode 100644 -index 0000000..ca7ea89 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/constants.h -@@ -0,0 +1,1239 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/* \file -+ \brief Boost-style constant definitions for floating-point types. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/complex.h" -+ -+/////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace constants { -+ -+/////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Primary templates -+// -+ -+/// Returns 1, the multiplicative identity element -+template CUTLASS_HOST_DEVICE T one(); -+ -+/// Returns 0, the additive identity element -+template CUTLASS_HOST_DEVICE T zero(); -+ -+/// Returns 2 -+template CUTLASS_HOST_DEVICE T two(); -+ -+/// Returns pi, approximately 3.141 -+template CUTLASS_HOST_DEVICE T pi(); -+ -+/// Returns 2 * pi -+template CUTLASS_HOST_DEVICE T two_pi(); -+ -+/// Returns pi / 2 -+template CUTLASS_HOST_DEVICE T half_pi(); -+ -+/// Returns sqrt(pi) -+template CUTLASS_HOST_DEVICE T root_pi(); -+ -+/// Returns sqrt(pi / 2) -+template CUTLASS_HOST_DEVICE T root_half_pi(); -+ -+/// Returns sqrt(2 * pi) -+template CUTLASS_HOST_DEVICE T root_two_pi(); -+ -+/// Returns sqrt(ln(4)) -+template CUTLASS_HOST_DEVICE T root_ln_four(); -+ -+/// Returns e, approximately 2.718... -+template CUTLASS_HOST_DEVICE T e(); -+ -+/// Returns (1/2) -+template CUTLASS_HOST_DEVICE T half(); -+ -+/// Returns sqrt(2), approximately 1.414... -+template CUTLASS_HOST_DEVICE T root_two(); -+ -+/// Returns sqrt(2)/2, approximately 0.707... -+template CUTLASS_HOST_DEVICE T half_root_two(); -+ -+/// Returns ln(2), approximately 0.693... -+template CUTLASS_HOST_DEVICE T ln_two(); -+ -+/// Returns ln(ln(2)), approximately -0.3665... -+template CUTLASS_HOST_DEVICE T ln_ln_two(); -+ -+/// Returns 1/3, approximately 0.333... -+template CUTLASS_HOST_DEVICE T third(); -+ -+/// Returns 2/3, approximately 0.666... -+template CUTLASS_HOST_DEVICE T twothirds(); -+ -+/// Returns pi - 3, approximately 0.1416... -+template CUTLASS_HOST_DEVICE T pi_minus_three(); -+ -+/// Returns 4 - pi, approximately 0.858... -+template CUTLASS_HOST_DEVICE T four_minus_pi(); -+ -+ -+///////////////////////////////////////////////////////////////////////////////////// -+ -+// Specialization for double -+ -+/// Returns 1, the multiplicative identity element (specialization for double) -+template <> CUTLASS_HOST_DEVICE double one() { -+ uint64_t bits = 0x3ff0000000000000ull; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 1, the multiplicative identity element (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex one< complex >() { -+ return complex(one(), double()); -+} -+ -+/// Returns 0, the additive identity element (specialization for double) -+template <> CUTLASS_HOST_DEVICE double zero() { -+ uint64_t bits = 0x0ull; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 0, the additive identity element (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex zero< complex >() { -+ return complex(zero(), double()); -+} -+ -+/// Returns 2 (specialization for double) -+template <> CUTLASS_HOST_DEVICE double two() { -+ uint64_t bits = 0x4000000000000000ull; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 2 (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex two< complex >() { -+ return complex(two(), double()); -+} -+ -+/// Returns pi, approximately 3.141 (specialization for double) -+template <> CUTLASS_HOST_DEVICE double pi() { -+ uint64_t bits = 0x400921fb54442d18ull; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns pi, approximately 3.141 (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex pi< complex >() { -+ return complex(pi(), double()); -+} -+ -+/// Returns 2 * pi (specialization for double) -+template <> CUTLASS_HOST_DEVICE double two_pi() { -+ uint64_t bits = 0x401921fb54442d18ull; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 2 * pi (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex two_pi< complex >() { -+ return complex(two_pi(), double()); -+} -+ -+/// Returns pi / 2 (specialization for double) -+template <> CUTLASS_HOST_DEVICE double half_pi() { -+ uint64_t bits = 0x3ff921fb54442d18ull; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns pi / 2 (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex half_pi< complex >() { -+ return complex(half_pi(), double()); -+} -+ -+/// Returns sqrt(pi) (specialization for double) -+template <> CUTLASS_HOST_DEVICE double root_pi() { -+ uint64_t bits = 0x3ffc5bf891b4ef6aull; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(pi) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_pi< complex >() { -+ return complex(root_pi(), double()); -+} -+ -+/// Returns sqrt(pi / 2) (specialization for double) -+template <> CUTLASS_HOST_DEVICE double root_half_pi() { -+ uint64_t bits = 0x3ff40d931ff62705ull; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(pi / 2) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_half_pi< complex >() { -+ return complex(root_half_pi(), double()); -+} -+ -+/// Returns sqrt(2 * pi) (specialization for double) -+template <> CUTLASS_HOST_DEVICE double root_two_pi() { -+ uint64_t bits = 0x40040d931ff62705ull; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(2 * pi) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_two_pi< complex >() { -+ return complex(root_two_pi(), double()); -+} -+ -+/// Returns sqrt(ln(4)) (specialization for double) -+template <> CUTLASS_HOST_DEVICE double root_ln_four() { -+ uint64_t bits = 0x3ff2d6abe44afc43ull; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(ln(4)) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_ln_four< complex >() { -+ return complex(root_ln_four(), double()); -+} -+ -+/// Returns e, approximately 2.718... (specialization for double) -+template <> CUTLASS_HOST_DEVICE double e() { -+ uint64_t bits = 0x4005bf0a8b145769ull; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns e, approximately 2.718... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex e< complex >() { -+ return complex(e(), double()); -+} -+ -+/// Returns (1/2) (specialization for double) -+template <> CUTLASS_HOST_DEVICE double half() { -+ uint64_t bits = 0x3fe0000000000000ull; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns (1/2) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex half< complex >() { -+ return complex(half(), double()); -+} -+ -+/// Returns sqrt(2), approximately 1.414... (specialization for double) -+template <> CUTLASS_HOST_DEVICE double root_two() { -+ uint64_t bits = 0x3ff6a09e667f3bcdull; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(2), approximately 1.414... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_two< complex >() { -+ return complex(root_two(), double()); -+} -+ -+/// Returns sqrt(2)/2, approximately 0.707... (specialization for double) -+template <> CUTLASS_HOST_DEVICE double half_root_two() { -+ uint64_t bits = 0x3fe6a09e667f3bcdull; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(2)/2, approximately 0.707... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex half_root_two< complex >() { -+ return complex(half_root_two(), double()); -+} -+ -+/// Returns ln(2), approximately 0.693... (specialization for double) -+template <> CUTLASS_HOST_DEVICE double ln_two() { -+ uint64_t bits = 0x3fe62e42fefa39efull; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns ln(2), approximately 0.693... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex ln_two< complex >() { -+ return complex(ln_two(), double()); -+} -+ -+/// Returns ln(ln(2)), approximately -0.3665... (specialization for double) -+template <> CUTLASS_HOST_DEVICE double ln_ln_two() { -+ uint64_t bits = 0xbfd774f29bdd6b9full; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns ln(ln(2)), approximately -0.3665... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex ln_ln_two< complex >() { -+ return complex(ln_ln_two(), double()); -+} -+ -+/// Returns 1/3, approximately 0.333... (specialization for double) -+template <> CUTLASS_HOST_DEVICE double third() { -+ uint64_t bits = 0x3fd5555555555555ull; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 1/3, approximately 0.333... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex third< complex >() { -+ return complex(third(), double()); -+} -+ -+/// Returns 2/3, approximately 0.666... (specialization for double) -+template <> CUTLASS_HOST_DEVICE double twothirds() { -+ uint64_t bits = 0x3fe5555555555555ull; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 2/3, approximately 0.666... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex twothirds< complex >() { -+ return complex(twothirds(), double()); -+} -+ -+/// Returns pi - 3, approximately 0.1416... (specialization for double) -+template <> CUTLASS_HOST_DEVICE double pi_minus_three() { -+ uint64_t bits = 0x3fc21fb54442d180ull; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns pi - 3, approximately 0.1416... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex pi_minus_three< complex >() { -+ return complex(pi_minus_three(), double()); -+} -+ -+/// Returns 4 - pi, approximately 0.858... (specialization for double) -+template <> CUTLASS_HOST_DEVICE double four_minus_pi() { -+ uint64_t bits = 0x3feb7812aeef4ba0ull; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 4 - pi, approximately 0.858... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex four_minus_pi< complex >() { -+ return complex(four_minus_pi(), double()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////// -+ -+// Specialization for float -+ -+/// Returns 1, the multiplicative identity element (specialization for float) -+template <> CUTLASS_HOST_DEVICE float one() { -+ uint32_t bits = 0x3f800000u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 1, the multiplicative identity element (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex one< complex >() { -+ return complex(one(), float()); -+} -+ -+/// Returns 0, the additive identity element (specialization for float) -+template <> CUTLASS_HOST_DEVICE float zero() { -+ uint32_t bits = 0x0u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 0, the additive identity element (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex zero< complex >() { -+ return complex(zero(), float()); -+} -+ -+/// Returns 2 (specialization for float) -+template <> CUTLASS_HOST_DEVICE float two() { -+ uint32_t bits = 0x40000000u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 2 (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex two< complex >() { -+ return complex(two(), float()); -+} -+ -+/// Returns pi, approximately 3.141 (specialization for float) -+template <> CUTLASS_HOST_DEVICE float pi() { -+ uint32_t bits = 0x40490fdbu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns pi, approximately 3.141 (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex pi< complex >() { -+ return complex(pi(), float()); -+} -+ -+/// Returns 2 * pi (specialization for float) -+template <> CUTLASS_HOST_DEVICE float two_pi() { -+ uint32_t bits = 0x40c90fdbu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 2 * pi (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex two_pi< complex >() { -+ return complex(two_pi(), float()); -+} -+ -+/// Returns pi / 2 (specialization for float) -+template <> CUTLASS_HOST_DEVICE float half_pi() { -+ uint32_t bits = 0x3fc90fdbu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns pi / 2 (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex half_pi< complex >() { -+ return complex(half_pi(), float()); -+} -+ -+/// Returns sqrt(pi) (specialization for float) -+template <> CUTLASS_HOST_DEVICE float root_pi() { -+ uint32_t bits = 0x3fe2dfc5u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(pi) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_pi< complex >() { -+ return complex(root_pi(), float()); -+} -+ -+/// Returns sqrt(pi / 2) (specialization for float) -+template <> CUTLASS_HOST_DEVICE float root_half_pi() { -+ uint32_t bits = 0x3fa06c99u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(pi / 2) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_half_pi< complex >() { -+ return complex(root_half_pi(), float()); -+} -+ -+/// Returns sqrt(2 * pi) (specialization for float) -+template <> CUTLASS_HOST_DEVICE float root_two_pi() { -+ uint32_t bits = 0x40206c99u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(2 * pi) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_two_pi< complex >() { -+ return complex(root_two_pi(), float()); -+} -+ -+/// Returns sqrt(ln(4)) (specialization for float) -+template <> CUTLASS_HOST_DEVICE float root_ln_four() { -+ uint32_t bits = 0x3f96b55fu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(ln(4)) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_ln_four< complex >() { -+ return complex(root_ln_four(), float()); -+} -+ -+/// Returns e, approximately 2.718... (specialization for float) -+template <> CUTLASS_HOST_DEVICE float e() { -+ uint32_t bits = 0x402df854u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns e, approximately 2.718... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex e< complex >() { -+ return complex(e(), float()); -+} -+ -+/// Returns (1/2) (specialization for float) -+template <> CUTLASS_HOST_DEVICE float half() { -+ uint32_t bits = 0x3f000000u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns (1/2) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex half< complex >() { -+ return complex(half(), float()); -+} -+ -+/// Returns sqrt(2), approximately 1.414... (specialization for float) -+template <> CUTLASS_HOST_DEVICE float root_two() { -+ uint32_t bits = 0x3fb504f3u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(2), approximately 1.414... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_two< complex >() { -+ return complex(root_two(), float()); -+} -+ -+/// Returns sqrt(2)/2, approximately 0.707... (specialization for float) -+template <> CUTLASS_HOST_DEVICE float half_root_two() { -+ uint32_t bits = 0x3f3504f3u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(2)/2, approximately 0.707... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex half_root_two< complex >() { -+ return complex(half_root_two(), float()); -+} -+ -+/// Returns ln(2), approximately 0.693... (specialization for float) -+template <> CUTLASS_HOST_DEVICE float ln_two() { -+ uint32_t bits = 0x3f317218u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns ln(2), approximately 0.693... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex ln_two< complex >() { -+ return complex(ln_two(), float()); -+} -+ -+/// Returns ln(ln(2)), approximately -0.3665... (specialization for float) -+template <> CUTLASS_HOST_DEVICE float ln_ln_two() { -+ uint32_t bits = 0xbebba795u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns ln(ln(2)), approximately -0.3665... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex ln_ln_two< complex >() { -+ return complex(ln_ln_two(), float()); -+} -+ -+/// Returns 1/3, approximately 0.333... (specialization for float) -+template <> CUTLASS_HOST_DEVICE float third() { -+ uint32_t bits = 0x3eaaaaabu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 1/3, approximately 0.333... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex third< complex >() { -+ return complex(third(), float()); -+} -+ -+/// Returns 2/3, approximately 0.666... (specialization for float) -+template <> CUTLASS_HOST_DEVICE float twothirds() { -+ uint32_t bits = 0x3f2aaaabu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 2/3, approximately 0.666... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex twothirds< complex >() { -+ return complex(twothirds(), float()); -+} -+ -+/// Returns pi - 3, approximately 0.1416... (specialization for float) -+template <> CUTLASS_HOST_DEVICE float pi_minus_three() { -+ uint32_t bits = 0x3e10fdaau; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns pi - 3, approximately 0.1416... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex pi_minus_three< complex >() { -+ return complex(pi_minus_three(), float()); -+} -+ -+/// Returns 4 - pi, approximately 0.858... (specialization for float) -+template <> CUTLASS_HOST_DEVICE float four_minus_pi() { -+ uint32_t bits = 0x3f5bc095u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 4 - pi, approximately 0.858... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex four_minus_pi< complex >() { -+ return complex(four_minus_pi(), float()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////// -+ -+// Specialization for tfloat32_t -+ -+/// Returns 1, the multiplicative identity element (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t one() { -+ uint32_t bits = 0x3f801000u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 1, the multiplicative identity element (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex one< complex >() { -+ return complex(one(), tfloat32_t()); -+} -+ -+/// Returns 0, the additive identity element (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t zero() { -+ uint32_t bits = 0x1000u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 0, the additive identity element (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex zero< complex >() { -+ return complex(zero(), tfloat32_t()); -+} -+ -+/// Returns 2 (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t two() { -+ uint32_t bits = 0x40001000u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 2 (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex two< complex >() { -+ return complex(two(), tfloat32_t()); -+} -+ -+/// Returns pi, approximately 3.141 (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t pi() { -+ uint32_t bits = 0x40491fdbu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns pi, approximately 3.141 (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex pi< complex >() { -+ return complex(pi(), tfloat32_t()); -+} -+ -+/// Returns 2 * pi (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t two_pi() { -+ uint32_t bits = 0x40c91fdbu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 2 * pi (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex two_pi< complex >() { -+ return complex(two_pi(), tfloat32_t()); -+} -+ -+/// Returns pi / 2 (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t half_pi() { -+ uint32_t bits = 0x3fc91fdbu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns pi / 2 (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex half_pi< complex >() { -+ return complex(half_pi(), tfloat32_t()); -+} -+ -+/// Returns sqrt(pi) (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t root_pi() { -+ uint32_t bits = 0x3fe2efc5u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(pi) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_pi< complex >() { -+ return complex(root_pi(), tfloat32_t()); -+} -+ -+/// Returns sqrt(pi / 2) (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t root_half_pi() { -+ uint32_t bits = 0x3fa07c99u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(pi / 2) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_half_pi< complex >() { -+ return complex(root_half_pi(), tfloat32_t()); -+} -+ -+/// Returns sqrt(2 * pi) (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t root_two_pi() { -+ uint32_t bits = 0x40207c99u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(2 * pi) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_two_pi< complex >() { -+ return complex(root_two_pi(), tfloat32_t()); -+} -+ -+/// Returns sqrt(ln(4)) (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t root_ln_four() { -+ uint32_t bits = 0x3f96c55fu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(ln(4)) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_ln_four< complex >() { -+ return complex(root_ln_four(), tfloat32_t()); -+} -+ -+/// Returns e, approximately 2.718... (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t e() { -+ uint32_t bits = 0x402e0854u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns e, approximately 2.718... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex e< complex >() { -+ return complex(e(), tfloat32_t()); -+} -+ -+/// Returns (1/2) (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t half() { -+ uint32_t bits = 0x3f001000u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns (1/2) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex half< complex >() { -+ return complex(half(), tfloat32_t()); -+} -+ -+/// Returns sqrt(2), approximately 1.414... (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t root_two() { -+ uint32_t bits = 0x3fb514f3u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(2), approximately 1.414... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_two< complex >() { -+ return complex(root_two(), tfloat32_t()); -+} -+ -+/// Returns sqrt(2)/2, approximately 0.707... (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t half_root_two() { -+ uint32_t bits = 0x3f3514f3u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(2)/2, approximately 0.707... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex half_root_two< complex >() { -+ return complex(half_root_two(), tfloat32_t()); -+} -+ -+/// Returns ln(2), approximately 0.693... (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t ln_two() { -+ uint32_t bits = 0x3f318218u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns ln(2), approximately 0.693... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex ln_two< complex >() { -+ return complex(ln_two(), tfloat32_t()); -+} -+ -+/// Returns ln(ln(2)), approximately -0.3665... (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t ln_ln_two() { -+ uint32_t bits = 0xbebbb795u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns ln(ln(2)), approximately -0.3665... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex ln_ln_two< complex >() { -+ return complex(ln_ln_two(), tfloat32_t()); -+} -+ -+/// Returns 1/3, approximately 0.333... (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t third() { -+ uint32_t bits = 0x3eaabaabu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 1/3, approximately 0.333... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex third< complex >() { -+ return complex(third(), tfloat32_t()); -+} -+ -+/// Returns 2/3, approximately 0.666... (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t twothirds() { -+ uint32_t bits = 0x3f2abaabu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 2/3, approximately 0.666... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex twothirds< complex >() { -+ return complex(twothirds(), tfloat32_t()); -+} -+ -+/// Returns pi - 3, approximately 0.1416... (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t pi_minus_three() { -+ uint32_t bits = 0x3e110daau; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns pi - 3, approximately 0.1416... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex pi_minus_three< complex >() { -+ return complex(pi_minus_three(), tfloat32_t()); -+} -+ -+/// Returns 4 - pi, approximately 0.858... (specialization for tfloat32_t) -+template <> CUTLASS_HOST_DEVICE tfloat32_t four_minus_pi() { -+ uint32_t bits = 0x3f5bd095u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 4 - pi, approximately 0.858... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex four_minus_pi< complex >() { -+ return complex(four_minus_pi(), tfloat32_t()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////// -+ -+// Specialization for half_t -+ -+/// Returns 1, the multiplicative identity element (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t one() { -+ uint16_t bits = 0x3c00u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 1, the multiplicative identity element (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex one< complex >() { -+ return complex(one(), half_t()); -+} -+ -+/// Returns 0, the additive identity element (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t zero() { -+ uint16_t bits = 0x0u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 0, the additive identity element (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex zero< complex >() { -+ return complex(zero(), half_t()); -+} -+ -+/// Returns 2 (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t two() { -+ uint16_t bits = 0x4000u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 2 (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex two< complex >() { -+ return complex(two(), half_t()); -+} -+ -+/// Returns pi, approximately 3.141 (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t pi() { -+ uint16_t bits = 0x4248u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns pi, approximately 3.141 (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex pi< complex >() { -+ return complex(pi(), half_t()); -+} -+ -+/// Returns 2 * pi (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t two_pi() { -+ uint16_t bits = 0x4648u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 2 * pi (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex two_pi< complex >() { -+ return complex(two_pi(), half_t()); -+} -+ -+/// Returns pi / 2 (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t half_pi() { -+ uint16_t bits = 0x3e48u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns pi / 2 (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex half_pi< complex >() { -+ return complex(half_pi(), half_t()); -+} -+ -+/// Returns sqrt(pi) (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t root_pi() { -+ uint16_t bits = 0x3f17u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(pi) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_pi< complex >() { -+ return complex(root_pi(), half_t()); -+} -+ -+/// Returns sqrt(pi / 2) (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t root_half_pi() { -+ uint16_t bits = 0x3d03u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(pi / 2) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_half_pi< complex >() { -+ return complex(root_half_pi(), half_t()); -+} -+ -+/// Returns sqrt(2 * pi) (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t root_two_pi() { -+ uint16_t bits = 0x4103u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(2 * pi) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_two_pi< complex >() { -+ return complex(root_two_pi(), half_t()); -+} -+ -+/// Returns sqrt(ln(4)) (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t root_ln_four() { -+ uint16_t bits = 0x3cb6u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(ln(4)) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_ln_four< complex >() { -+ return complex(root_ln_four(), half_t()); -+} -+ -+/// Returns e, approximately 2.718... (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t e() { -+ uint16_t bits = 0x4170u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns e, approximately 2.718... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex e< complex >() { -+ return complex(e(), half_t()); -+} -+ -+/// Returns (1/2) (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t half() { -+ uint16_t bits = 0x3800u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns (1/2) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex half< complex >() { -+ return complex(half(), half_t()); -+} -+ -+/// Returns sqrt(2), approximately 1.414... (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t root_two() { -+ uint16_t bits = 0x3da8u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(2), approximately 1.414... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_two< complex >() { -+ return complex(root_two(), half_t()); -+} -+ -+/// Returns sqrt(2)/2, approximately 0.707... (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t half_root_two() { -+ uint16_t bits = 0x39a8u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(2)/2, approximately 0.707... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex half_root_two< complex >() { -+ return complex(half_root_two(), half_t()); -+} -+ -+/// Returns ln(2), approximately 0.693... (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t ln_two() { -+ uint16_t bits = 0x398cu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns ln(2), approximately 0.693... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex ln_two< complex >() { -+ return complex(ln_two(), half_t()); -+} -+ -+/// Returns ln(ln(2)), approximately -0.3665... (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t ln_ln_two() { -+ uint16_t bits = 0xb5ddu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns ln(ln(2)), approximately -0.3665... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex ln_ln_two< complex >() { -+ return complex(ln_ln_two(), half_t()); -+} -+ -+/// Returns 1/3, approximately 0.333... (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t third() { -+ uint16_t bits = 0x3555u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 1/3, approximately 0.333... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex third< complex >() { -+ return complex(third(), half_t()); -+} -+ -+/// Returns 2/3, approximately 0.666... (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t twothirds() { -+ uint16_t bits = 0x3955u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 2/3, approximately 0.666... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex twothirds< complex >() { -+ return complex(twothirds(), half_t()); -+} -+ -+/// Returns pi - 3, approximately 0.1416... (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t pi_minus_three() { -+ uint16_t bits = 0x3088u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns pi - 3, approximately 0.1416... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex pi_minus_three< complex >() { -+ return complex(pi_minus_three(), half_t()); -+} -+ -+/// Returns 4 - pi, approximately 0.858... (specialization for half_t) -+template <> CUTLASS_HOST_DEVICE half_t four_minus_pi() { -+ uint16_t bits = 0x3adeu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 4 - pi, approximately 0.858... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex four_minus_pi< complex >() { -+ return complex(four_minus_pi(), half_t()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////// -+ -+// Specialization for bfloat16_t -+ -+/// Returns 1, the multiplicative identity element (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t one() { -+ uint16_t bits = 0x3f80u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 1, the multiplicative identity element (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex one< complex >() { -+ return complex(one(), bfloat16_t()); -+} -+ -+/// Returns 0, the additive identity element (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t zero() { -+ uint16_t bits = 0x0u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 0, the additive identity element (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex zero< complex >() { -+ return complex(zero(), bfloat16_t()); -+} -+ -+/// Returns 2 (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t two() { -+ uint16_t bits = 0x4000u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 2 (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex two< complex >() { -+ return complex(two(), bfloat16_t()); -+} -+ -+/// Returns pi, approximately 3.141 (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t pi() { -+ uint16_t bits = 0x4049u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns pi, approximately 3.141 (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex pi< complex >() { -+ return complex(pi(), bfloat16_t()); -+} -+ -+/// Returns 2 * pi (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t two_pi() { -+ uint16_t bits = 0x40c9u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 2 * pi (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex two_pi< complex >() { -+ return complex(two_pi(), bfloat16_t()); -+} -+ -+/// Returns pi / 2 (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t half_pi() { -+ uint16_t bits = 0x3fc9u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns pi / 2 (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex half_pi< complex >() { -+ return complex(half_pi(), bfloat16_t()); -+} -+ -+/// Returns sqrt(pi) (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t root_pi() { -+ uint16_t bits = 0x3fe3u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(pi) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_pi< complex >() { -+ return complex(root_pi(), bfloat16_t()); -+} -+ -+/// Returns sqrt(pi / 2) (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t root_half_pi() { -+ uint16_t bits = 0x3fa0u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(pi / 2) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_half_pi< complex >() { -+ return complex(root_half_pi(), bfloat16_t()); -+} -+ -+/// Returns sqrt(2 * pi) (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t root_two_pi() { -+ uint16_t bits = 0x4020u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(2 * pi) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_two_pi< complex >() { -+ return complex(root_two_pi(), bfloat16_t()); -+} -+ -+/// Returns sqrt(ln(4)) (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t root_ln_four() { -+ uint16_t bits = 0x3f97u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(ln(4)) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_ln_four< complex >() { -+ return complex(root_ln_four(), bfloat16_t()); -+} -+ -+/// Returns e, approximately 2.718... (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t e() { -+ uint16_t bits = 0x402eu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns e, approximately 2.718... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex e< complex >() { -+ return complex(e(), bfloat16_t()); -+} -+ -+/// Returns (1/2) (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t half() { -+ uint16_t bits = 0x3f00u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns (1/2) (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex half< complex >() { -+ return complex(half(), bfloat16_t()); -+} -+ -+/// Returns sqrt(2), approximately 1.414... (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t root_two() { -+ uint16_t bits = 0x3fb5u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(2), approximately 1.414... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex root_two< complex >() { -+ return complex(root_two(), bfloat16_t()); -+} -+ -+/// Returns sqrt(2)/2, approximately 0.707... (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t half_root_two() { -+ uint16_t bits = 0x3f35u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns sqrt(2)/2, approximately 0.707... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex half_root_two< complex >() { -+ return complex(half_root_two(), bfloat16_t()); -+} -+ -+/// Returns ln(2), approximately 0.693... (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t ln_two() { -+ uint16_t bits = 0x3f31u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns ln(2), approximately 0.693... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex ln_two< complex >() { -+ return complex(ln_two(), bfloat16_t()); -+} -+ -+/// Returns ln(ln(2)), approximately -0.3665... (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t ln_ln_two() { -+ uint16_t bits = 0xbebcu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns ln(ln(2)), approximately -0.3665... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex ln_ln_two< complex >() { -+ return complex(ln_ln_two(), bfloat16_t()); -+} -+ -+/// Returns 1/3, approximately 0.333... (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t third() { -+ uint16_t bits = 0x3eabu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 1/3, approximately 0.333... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex third< complex >() { -+ return complex(third(), bfloat16_t()); -+} -+ -+/// Returns 2/3, approximately 0.666... (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t twothirds() { -+ uint16_t bits = 0x3f2bu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 2/3, approximately 0.666... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex twothirds< complex >() { -+ return complex(twothirds(), bfloat16_t()); -+} -+ -+/// Returns pi - 3, approximately 0.1416... (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t pi_minus_three() { -+ uint16_t bits = 0x3e11u; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns pi - 3, approximately 0.1416... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex pi_minus_three< complex >() { -+ return complex(pi_minus_three(), bfloat16_t()); -+} -+ -+/// Returns 4 - pi, approximately 0.858... (specialization for bfloat16_t) -+template <> CUTLASS_HOST_DEVICE bfloat16_t four_minus_pi() { -+ uint16_t bits = 0x3f5cu; -+ return reinterpret_cast(bits); -+} -+ -+/// Returns 4 - pi, approximately 0.858... (specialization for complex) -+template <> CUTLASS_HOST_DEVICE complex four_minus_pi< complex >() { -+ return complex(four_minus_pi(), bfloat16_t()); -+} -+/////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace constants -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/conv2d_problem_size.h b/3rdparty/cutlass/include/cutlass/conv/conv2d_problem_size.h -new file mode 100644 -index 0000000..2bc4eb0 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/conv2d_problem_size.h -@@ -0,0 +1,652 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief This file contains definitions and utility functions for describing convolution problem sizes. -+ -+ Conv2dProblem desciption: -+ activation (NHWC), -+ filter (KRSC), -+ output (NPQK), -+ pading (pad_h, pad_w), -+ stride (stride_h, stride_w), -+ dilation (dilation_h, dilation_w). -+ -+ Free functions to map: -+ Map tensor extents (Conv2d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_extent(ConvolutionOperator) -+ Map tensor sizes (Conv2d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_size(ConvolutionOperator) -+ Map tensor problem sizes (Conv2d -> ImplicitGemm): implicit_gemm_problem_size(ConvolutionOperator) -+*/ -+ -+#pragma once -+ -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/tensor_coord.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/functional.h" -+ -+namespace cutlass { -+namespace conv { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Problem size structure -+struct Conv2dProblemSize { -+ -+ // Conv2d strictly problem size parameters -+ int N, H, W, C, P, Q, K, R, S; -+ int pad_h, pad_w; -+ int stride_h, stride_w; -+ int dilation_h, dilation_w; -+ Mode mode; -+ -+ // Conv2d implementation-related parameters -+ int split_k_slices; -+ int groups; -+ -+ // -+ // Methods -+ // -+ -+public: -+ CUTLASS_HOST_DEVICE -+ Conv2dProblemSize(): -+ N(0), H(0), W(0), C(0), P(0), Q(0), K(0), R(0), S(0), -+ pad_h(0), pad_w(0), stride_h(1), stride_w(1), dilation_h(1), dilation_w(1), -+ mode(Mode::kConvolution), split_k_slices(1), groups(1) { } -+ -+ /// Constructor for default padding, stride, dilation, and split-K -+ CUTLASS_HOST_DEVICE -+ Conv2dProblemSize( -+ int N, -+ int H, -+ int W, -+ int C, -+ int P, -+ int Q, -+ int K, -+ int R, -+ int S, -+ Mode mode -+ ): -+ N(N), H(H), W(W), C(C), P(P), Q(Q), K(K), R(R), S(S), -+ pad_h(R / 2), pad_w(S / 2), stride_h(1), stride_w(1), dilation_h(1), dilation_w(1), -+ mode(mode), split_k_slices(1), groups (1) { } -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ Conv2dProblemSize( -+ int N, -+ int H, -+ int W, -+ int C, -+ int K, -+ int R, -+ int S, -+ int P, -+ int Q, -+ int pad_h, -+ int pad_w, -+ int stride_h, -+ int stride_w, -+ int dilation_h, -+ int dilation_w, -+ Mode mode, -+ int split_k_slices = 1, -+ int groups = 1 -+ ): -+ N(N), H(H), W(W), C(C), K(K), R(R), S(S), P(P), Q(Q), -+ pad_h(pad_h), pad_w(pad_w), stride_h(stride_h), stride_w(stride_w), -+ dilation_h(dilation_h), dilation_w(dilation_w), -+ mode(mode), split_k_slices(split_k_slices), groups (groups) { } -+ -+ /// Constructs convolution problem size from cutlass Tensor4DCoord and MatrixCoord -+ // set user-defined output size and sets P and Q (include all data members in ctor) -+ CUTLASS_HOST_DEVICE -+ Conv2dProblemSize( -+ cutlass::Tensor4DCoord input_size, // NHWC -+ cutlass::Tensor4DCoord filter_size, // KRSC -+ cutlass::Tensor4DCoord padding, // pad_h, _, pad_w, _ -+ cutlass::MatrixCoord stride, // stride_h, stride_w -+ cutlass::MatrixCoord dilation, // dilation_h, dilation_w -+ cutlass::Tensor4DCoord output_size, // NPQK -+ cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation, -+ int split_k_slices = 1, -+ int groups = 1 -+ ): -+ N(input_size.n()), H(input_size.h()), W(input_size.w()), C(input_size.c()), -+ K(filter_size.n()), R(filter_size.h()), S(filter_size.w()), -+ pad_h(padding[0]), pad_w(padding[2]), -+ stride_h(stride.row()), stride_w(stride.column()), -+ dilation_h(dilation.row()), dilation_w(dilation.column()), -+ P(output_size.h()), Q(output_size.w()), -+ mode(mode), split_k_slices(split_k_slices), groups(groups) {} -+ -+ /// Constructs convolution problem size from cutlass Tensor4DCoord and MatrixCoord -+ // computes output size and sets P and Q (skip output from ctor arguments) -+ CUTLASS_HOST_DEVICE -+ Conv2dProblemSize( -+ cutlass::Tensor4DCoord input_size, // NHWC -+ cutlass::Tensor4DCoord filter_size, // KRSC -+ cutlass::Tensor4DCoord padding, // pad_h, _, pad_w, _ -+ cutlass::MatrixCoord stride, // stride_h, stride_w -+ cutlass::MatrixCoord dilation, // dilation_h, dilation_w -+ cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation, -+ int split_k_slices = 1, -+ int groups = 1 -+ ): -+ N(input_size.n()), H(input_size.h()), W(input_size.w()), C(input_size.c()), -+ K(filter_size.n()), R(filter_size.h()), S(filter_size.w()), -+ pad_h(padding[0]), pad_w(padding[2]), -+ stride_h(stride.row()), stride_w(stride.column()), -+ dilation_h(dilation.row()), dilation_w(dilation.column()), -+ mode(mode), split_k_slices(split_k_slices), groups(groups) { -+ // set output P and Q -+ P = ((H + pad_h * 2 - R * dilation_h) / stride_h) + 1; -+ Q = ((W + pad_w * 2 - S * dilation_w) / stride_w) + 1; -+ } -+ -+ /// Constructs convolution problem size from cutlass Tensor4DCoord and MatrixCoord -+ // set user-defined output size and sets P and Q (skip padding, striding, and dilation) -+ CUTLASS_HOST_DEVICE -+ Conv2dProblemSize( -+ cutlass::Tensor4DCoord input_size, // NHWC -+ cutlass::Tensor4DCoord filter_size, // KRSC -+ cutlass::Tensor4DCoord output_size, // NPQK -+ cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation, -+ int split_k_slices = 1, -+ int groups = 1 -+ ): -+ N(input_size.n()), H(input_size.h()), W(input_size.w()), C(input_size.c()), -+ K(filter_size.n()), R(filter_size.h()), S(filter_size.w()), -+ P(output_size.h()), Q(output_size.w()), -+ pad_h(R / 2), pad_w(S / 2), stride_h(1), stride_w(1), -+ dilation_h(1), dilation_w(1), -+ mode(mode), split_k_slices(split_k_slices), groups(groups) {} -+ -+ // Reset covolution mode in the problem -+ CUTLASS_HOST_DEVICE -+ Conv2dProblemSize reset_mode(cutlass::conv::Mode mode_) { -+ Conv2dProblemSize tmp(*this); -+ tmp.mode = mode_; -+ return tmp; -+ } -+ -+ // Reset covolution mode in the problem -+ CUTLASS_HOST_DEVICE -+ Conv2dProblemSize reset_split_k_slices(int split_k_slices_) { -+ Conv2dProblemSize tmp(*this); -+ tmp.split_k_slices = split_k_slices_; -+ return tmp; -+ } -+ -+ /// Equality operator (ignores mode and split_k_slice) -+ CUTLASS_HOST_DEVICE -+ bool operator==(Conv2dProblemSize const &conv) const { -+ return ( -+ (N == conv.N) && (H == conv.H) && (W == conv.W) && (C == conv.C) && -+ (K == conv.K) && (R == conv.R) && (S == conv.S) && -+ (P == conv.P) && (Q == conv.Q) && -+ (pad_h == conv.pad_h) && (pad_w == conv.pad_w) && -+ (stride_h == conv.stride_h) && (stride_w == conv.stride_w) && -+ (dilation_h == conv.dilation_h) && (dilation_w == conv.dilation_w) -+ ); -+ } -+ -+ /// Inequality operator -+ CUTLASS_HOST_DEVICE -+ bool operator!=(Conv2dProblemSize const &rhs) const { -+ return !(*this == rhs); -+ } -+ -+ /// Returns activation extent as Tensor4DCoord -+ CUTLASS_HOST_DEVICE -+ cutlass::Tensor4DCoord activation_extent() const { -+ -+ return cutlass::Tensor4DCoord ({N, H, W, C}); -+ } -+ -+ /// Returns filter extent as Tensor4DCoord -+ CUTLASS_HOST_DEVICE -+ cutlass::Tensor4DCoord filter_extent() const { -+ -+ return cutlass::Tensor4DCoord ({K, R, S, C / groups}); -+ } -+ -+ /// Returns output extent as Tensor4DCoord -+ CUTLASS_HOST_DEVICE -+ cutlass::Tensor4DCoord output_extent() const { -+ -+ return cutlass::Tensor4DCoord ({N, P, Q, K}); -+ } -+ -+ /// Returns activation size in number of elements -+ CUTLASS_HOST_DEVICE -+ int64_t activation_size() const { -+ -+ return (N * H * W * C); -+ } -+ -+ /// Returns filter size in number of elements -+ CUTLASS_HOST_DEVICE -+ int64_t filter_size() const { -+ -+ return (K * R * S * C / groups); -+ } -+ -+ /// Returns output size in number of elements -+ CUTLASS_HOST_DEVICE -+ int64_t output_size() const { -+ -+ return (N * P * Q * K); -+ } -+ -+ /// Returns padding as Tensor4DCoord -+ CUTLASS_HOST_DEVICE -+ cutlass::Tensor4DCoord padding() const { -+ -+ return cutlass::Tensor4DCoord ({pad_h, pad_h, pad_w, pad_w}); -+ } -+ -+ /// Returns stride as MatrixCoord -+ CUTLASS_HOST_DEVICE -+ cutlass::MatrixCoord stride() const { -+ -+ return cutlass::MatrixCoord ({stride_h, stride_w}); -+ } -+ -+ /// Returns dilation as MatrixCoord -+ CUTLASS_HOST_DEVICE -+ cutlass::MatrixCoord dilation() const { -+ -+ return cutlass::MatrixCoord ({dilation_h, dilation_w}); -+ } -+ -+ ///////////////////////////////////////////////////////////////// -+ // Methods used for strided dgrad implementation -+ ///////////////////////////////////////////////////////////////// -+ /// Number of filter r positions to accumulate in gemm-k dim -+ CUTLASS_HOST_DEVICE -+ int num_gemm_k_filter_r(int r) const { -+ return ((R - r + stride_h - 1) / stride_h); -+ } -+ -+ /// Number of filter s positions to accumulate in gemm-k dim -+ CUTLASS_HOST_DEVICE -+ int num_gemm_k_filter_s(int s) const { -+ return ((S - s + stride_w - 1) / stride_w); -+ } -+ -+ /// Number of filter positions to accumulate in gemm-k dim -+ CUTLASS_HOST_DEVICE -+ int num_gemm_k_filter_positions(int r, int s) const { -+ return num_gemm_k_filter_r(r) * num_gemm_k_filter_s(s); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+// ImplicitGemm helper functions // -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Determine the problem size of the implicit GEMM operation -+CUTLASS_HOST_DEVICE -+cutlass::gemm::GemmCoord implicit_gemm_problem_size( -+ Operator conv_operator, -+ Conv2dProblemSize const &problem_size) { -+ // Compute problem size -+ switch (conv_operator) { -+ case Operator::kFprop: -+ return gemm::GemmCoord( -+ problem_size.N * problem_size.P * problem_size.Q, -+ problem_size.K, -+ problem_size.R * problem_size.S * problem_size.C / problem_size.groups -+ ); -+ case Operator::kDgrad: -+ return gemm::GemmCoord( -+ problem_size.N * problem_size.H * problem_size.W, -+ problem_size.C, -+ problem_size.R * problem_size.S * problem_size.K -+ ); -+ case Operator::kWgrad: -+ return gemm::GemmCoord( -+ problem_size.K, -+ problem_size.R * problem_size.S * problem_size.C, -+ problem_size.N * problem_size.P * problem_size.Q -+ ); -+ default: -+ break; -+ } -+ return gemm::GemmCoord(); -+} -+ -+// Determine the number of gemm_k iterations for conv2d problem using implicit gemm algorithm -+CUTLASS_HOST_DEVICE -+int implicit_gemm_k_iterations( -+ Operator conv_operator, -+ int threadblock_K, -+ Conv2dProblemSize const &problem_size, -+ IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic, -+ GroupMode group_mode = GroupMode::kNone, -+ int threadblock_N = 0) { -+ -+ int iterations = 0; -+ -+ if (group_mode == GroupMode::kNone) { -+ -+ if (algorithm == IteratorAlgorithm::kFixedChannels) { -+ -+ int positions_per_iteration = threadblock_K / problem_size.C; -+ switch (conv_operator) { -+ case Operator::kFprop: -+ iterations = (problem_size.R * problem_size.S + positions_per_iteration - 1 ) / positions_per_iteration; -+ break; -+ -+ default: -+ break; -+ } -+ } -+ else if (algorithm == IteratorAlgorithm::kFewChannels) { -+ -+ switch (conv_operator) { -+ case Operator::kFprop: -+ iterations = (problem_size.R * problem_size.S * problem_size.C + threadblock_K - 1 ) / threadblock_K; -+ break; -+ -+ default: -+ break; -+ } -+ } -+ else { -+ int elements_per_split_k_slice = 0; -+ -+ switch (conv_operator) { -+ case Operator::kFprop: -+ elements_per_split_k_slice = (problem_size.C + problem_size.split_k_slices - 1) / problem_size.split_k_slices; -+ iterations = problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); -+ break; -+ -+ case Operator::kDgrad: -+ elements_per_split_k_slice = (problem_size.K + problem_size.split_k_slices - 1) / problem_size.split_k_slices; -+ iterations = problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); -+ break; -+ -+ case Operator::kWgrad: -+ elements_per_split_k_slice = (problem_size.N * problem_size.P * problem_size.Q + problem_size.split_k_slices - 1) / problem_size.split_k_slices; -+ iterations = (elements_per_split_k_slice + threadblock_K - 1) / threadblock_K; -+ break; -+ -+ default: -+ break; -+ } -+ } -+ -+ } else if (group_mode == GroupMode::kDepthwise) { -+ int channels_per_cta = threadblock_N; -+ -+ if (algorithm == IteratorAlgorithm::kAnalytic) { -+ switch (conv_operator) { -+ case Operator::kFprop: -+ iterations = problem_size.R * problem_size.S * -+ ((channels_per_cta + threadblock_K - 1) / threadblock_K); -+ break; -+ -+ default: -+ break; -+ } -+ } -+ } else { // Group conv -+ -+ int channels_per_group = problem_size.C / problem_size.groups; -+ int k_per_group = problem_size.K / problem_size.groups; -+ -+ if (algorithm == IteratorAlgorithm::kAnalytic) { -+ switch (conv_operator) { -+ case Operator::kFprop: -+ iterations = problem_size.R * problem_size.S * ((channels_per_group + threadblock_K - 1) / threadblock_K); -+ // In group conv, if k_per_group < threadblock_N, one Threadblock will calculate multiple groups -+ if (problem_size.groups != 1) { -+ if (k_per_group < threadblock_N) { -+ iterations *= threadblock_N / k_per_group; -+ } -+ } -+ break; -+ -+ default: -+ break; -+ } -+ } else if (algorithm == IteratorAlgorithm::kOptimized) { -+ // Current optimized iterator only support GroupMode::kSingleGroup -+ if (group_mode == GroupMode::kSingleGroup) { -+ switch (conv_operator) { -+ case Operator::kFprop: -+ iterations = problem_size.R * problem_size.S * ((channels_per_group + threadblock_K - 1) / threadblock_K); -+ break; -+ -+ default: -+ break; -+ } -+ } -+ } -+ -+ } -+ -+ return iterations; -+} -+ -+ -+template -+CUTLASS_HOST_DEVICE -+int depthwise_gemm_k_iterations( -+ Operator conv_operator, -+ int threadblock_K, -+ Conv2dProblemSize const &problem_size, -+ IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic, -+ GroupMode group_mode = GroupMode::kNone, -+ int threadblock_N = 0) { -+ -+ int n = problem_size.N; -+ int p = (problem_size.P + Output_P - 1) / Output_P; -+ int q = (problem_size.Q + Output_Q - 1) / Output_Q; -+ -+ int iterations = (n * p * q + problem_size.split_k_slices - 1) / problem_size.split_k_slices; -+ return iterations; -+} -+ -+ -+CUTLASS_HOST_DEVICE -+int implicit_gemm_k_iterations_per_channel( -+ Operator conv_operator, -+ int threadblock_K, -+ Conv2dProblemSize const &problem_size, -+ IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic) { -+ -+ int iterations = 0; //0 means not applicable -+ if (algorithm == IteratorAlgorithm::kAnalytic || algorithm == IteratorAlgorithm::kOptimized) { -+ switch (conv_operator) { -+ case Operator::kFprop: -+ iterations = problem_size.R * problem_size.S; -+ break; -+ -+ case Operator::kDgrad: -+ iterations = problem_size.R * problem_size.S; -+ break; -+ -+ default: -+ break; -+ } -+ } -+ return iterations; -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Mapping function (ImplicitGemm A, B, C -> Conv Activation, Filter, Output) -+//////////////////////////////////////////////////////////////////////////////// -+/// Returns ImplicitGemm tensor A extent as Tensor4DCoord -+CUTLASS_HOST_DEVICE -+cutlass::Tensor4DCoord implicit_gemm_tensor_a_extent( -+ Operator conv_operator, -+ Conv2dProblemSize const &problem_size) { -+ switch (conv_operator) { -+ case cutlass::conv::Operator::kFprop: return problem_size.activation_extent(); -+ case cutlass::conv::Operator::kDgrad: return problem_size.output_extent(); -+ case cutlass::conv::Operator::kWgrad: return problem_size.output_extent(); -+ default : break; -+ } -+ return cutlass::Tensor4DCoord(); -+} -+ -+/// Returns ImplicitGemm tensor B extent as Tensor4DCoord -+CUTLASS_HOST_DEVICE -+cutlass::Tensor4DCoord implicit_gemm_tensor_b_extent( -+ Operator conv_operator, -+ Conv2dProblemSize const &problem_size) { -+ switch (conv_operator) { -+ case cutlass::conv::Operator::kFprop: return problem_size.filter_extent(); -+ case cutlass::conv::Operator::kDgrad: return problem_size.filter_extent(); -+ case cutlass::conv::Operator::kWgrad: return problem_size.activation_extent(); -+ default : break; -+ } -+ return cutlass::Tensor4DCoord(); -+} -+ -+/// Returns ImplicitGemm tensor C extent as Tensor4DCoord -+CUTLASS_HOST_DEVICE -+cutlass::Tensor4DCoord implicit_gemm_tensor_c_extent( -+ Operator conv_operator, -+ Conv2dProblemSize const &problem_size) { -+ switch (conv_operator) { -+ case cutlass::conv::Operator::kFprop: return problem_size.output_extent(); -+ case cutlass::conv::Operator::kDgrad: return problem_size.activation_extent(); -+ case cutlass::conv::Operator::kWgrad: return problem_size.filter_extent(); -+ default : break; -+ } -+ return cutlass::Tensor4DCoord(); -+} -+ -+/// Returns ImplicitGemm tensor A size in number of elements -+CUTLASS_HOST_DEVICE -+int64_t implicit_gemm_tensor_a_size( -+ Operator conv_operator, -+ Conv2dProblemSize const &problem_size) { -+ switch (conv_operator) { -+ case cutlass::conv::Operator::kFprop: return problem_size.activation_size(); -+ case cutlass::conv::Operator::kDgrad: return problem_size.output_size(); -+ case cutlass::conv::Operator::kWgrad: return problem_size.output_size(); -+ default : break; -+ } -+ return 0; -+} -+ -+/// Returns ImplicitGemm tensor B size in number of elements -+CUTLASS_HOST_DEVICE -+int64_t implicit_gemm_tensor_b_size( -+ Operator conv_operator, -+ Conv2dProblemSize const &problem_size) { -+ switch (conv_operator) { -+ case cutlass::conv::Operator::kFprop: return problem_size.filter_size(); -+ case cutlass::conv::Operator::kDgrad: return problem_size.filter_size(); -+ case cutlass::conv::Operator::kWgrad: return problem_size.activation_size(); -+ default : break; -+ } -+ return 0; -+} -+ -+/// Returns ImplicitGemm tensor C size in number of elements -+CUTLASS_HOST_DEVICE -+int64_t implicit_gemm_tensor_c_size( -+ Operator conv_operator, -+ Conv2dProblemSize const &problem_size) { -+ switch (conv_operator) { -+ case cutlass::conv::Operator::kFprop: return problem_size.output_size(); -+ case cutlass::conv::Operator::kDgrad: return problem_size.activation_size(); -+ case cutlass::conv::Operator::kWgrad: return problem_size.filter_size(); -+ default : break; -+ } -+ return 0; -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+// Strided dgrad helper functions // -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+// Returns number of CTAs tile M to cover valid MMAs per starting filter postion -+CUTLASS_HOST_DEVICE -+int strided_dgrad_tile_m_per_filter( -+ Conv2dProblemSize const &problem_size, -+ int tile_size_m) { -+ -+ // Compute NHW rows in Dx output that needs MMA per starting filter position -+ int rows_h_per_filter = (problem_size.H + problem_size.stride_h - 1) / problem_size.stride_h; -+ int rows_w_per_filter = (problem_size.W + problem_size.stride_w - 1) / problem_size.stride_w; -+ int rows_nhw_per_filter = problem_size.N * rows_h_per_filter * rows_w_per_filter; -+ -+ // Number of CTAs tile M to cover valid MMAs per starting filter postion -+ int tile_m_per_filter = (rows_nhw_per_filter + tile_size_m - 1) / tile_size_m; -+ -+ return tile_m_per_filter; -+} -+ -+// Computes starting Dx coord (h, w) for given starting filter postion -+CUTLASS_HOST_DEVICE -+void strided_dgrad_starting_coords( -+ Conv2dProblemSize const &problem_size, -+ FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod, -+ int r, int s, -+ int &start_h, int &start_w) { -+ -+ // function locals for remainder by fast divmod -+ int pad_h_rem_, pad_w_rem_; -+ -+ // start_h = std::abs(problem_size.stride_h - ((problem_size.pad_h % problem_size.stride_h) - r)) % problem_size.stride_h; -+ stride_h_divmod.divmod(pad_h_rem_, problem_size.pad_h); -+ int r_ = absolute_value(problem_size.stride_h - (pad_h_rem_ - r)); -+ stride_h_divmod.divmod(start_h, r_); -+ -+ //start_w = std::abs(problem_size.stride_w - ((problem_size.pad_w % problem_size.stride_w) - s)) % problem_size.stride_w; -+ stride_w_divmod.divmod(pad_w_rem_, problem_size.pad_w); -+ int s_ = absolute_value(problem_size.stride_w - (pad_w_rem_ - s)); -+ stride_w_divmod.divmod(start_w, s_); -+} -+ -+} // namespace conv -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/conv3d_problem_size.h b/3rdparty/cutlass/include/cutlass/conv/conv3d_problem_size.h -new file mode 100644 -index 0000000..5bef4ff ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/conv3d_problem_size.h -@@ -0,0 +1,477 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief This file contains definitions and utility functions for describing convolution problem sizes. -+ -+ Conv3dProblem desciption: -+ activation (NDHWC), -+ filter (KTRSC), -+ output (NZPQK), -+ pading (pad_d, pad_h, pad_w), -+ stride (stride_d, stride_h, stride_w), -+ dilation (dilation_d, dilation_h, dilation_w). -+ -+ Free functions to map: -+ Map tensor extents (Conv3d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_extent(ConvolutionOperator) -+ Map tensor sizes (Conv3d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_size(ConvolutionOperator) -+ Map tensor problem sizes (Conv3d -> ImplicitGemm): implicit_gemm_problem_size(ConvolutionOperator) -+*/ -+ -+#pragma once -+ -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+ -+namespace cutlass { -+namespace conv { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Problem size structure -+struct Conv3dProblemSize : public Conv2dProblemSize { -+ // -+ // Type definitions -+ // -+ -+ // 3D coordinate for padding, stride, and dilation in (d, h, w) dimensions -+ using Coord3D = Coord<3>; -+ -+ // -+ // Data members -+ // -+ -+ // Conv3d strictly problem size parameters -+ int D, T, Z; // input depth, filter depth, output depth -+ int pad_d; // padding in depth dimension -+ int stride_d; // stride in depth dimension -+ int dilation_d; // dilation in depth dimension -+ -+ // -+ // Methods -+ // -+public: -+ CUTLASS_HOST_DEVICE -+ Conv3dProblemSize(): -+ D(0), T(0), Z(0), -+ pad_d(0), -+ stride_d(1), -+ dilation_d(1), -+ Conv2dProblemSize() { } -+ -+ /// Constructor for default padding, stride, dilation, and split-K -+ CUTLASS_HOST_DEVICE -+ Conv3dProblemSize( -+ int N, -+ int D, -+ int H, -+ int W, -+ int C, -+ int Z, -+ int P, -+ int Q, -+ int K, -+ int T, -+ int R, -+ int S, -+ Mode mode -+ ): -+ D(D), T(T), Z(Z), -+ pad_d(T / 2), stride_d(1), dilation_d(1), -+ Conv2dProblemSize(N, H, W, C, P, Q, K, R, S, mode) { } -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ Conv3dProblemSize( -+ int N, -+ int D, -+ int H, -+ int W, -+ int C, -+ int K, -+ int T, -+ int R, -+ int S, -+ int Z, -+ int P, -+ int Q, -+ int pad_d, -+ int pad_h, -+ int pad_w, -+ int stride_d, -+ int stride_h, -+ int stride_w, -+ int dilation_d, -+ int dilation_h, -+ int dilation_w, -+ Mode mode, -+ int split_k_slices = 1, -+ int groups = 1 -+ ): -+ D(D), T(T), Z(Z), -+ pad_d(pad_d), stride_d(stride_d), dilation_d(dilation_d), -+ Conv2dProblemSize( -+ N, H, W, C, K, R, S, P, Q, -+ pad_h, pad_w, -+ stride_h, stride_w, -+ dilation_h, dilation_w, -+ mode, split_k_slices, groups) { } -+ -+ /// Constructs convolution problem size from cutlass Tensor5DCoord and Coord3D -+ // set *user-defined* output size and sets Z, P, and Q (include all data members in ctor) -+ CUTLASS_HOST_DEVICE -+ Conv3dProblemSize( -+ cutlass::Tensor5DCoord input_size, // NDHWC -+ cutlass::Tensor5DCoord filter_size, // KTRSC -+ Coord3D padding, // pad_d, pad_h, pad_w -+ Coord3D stride, // stride_d, stride_h, stride_w -+ Coord3D dilation, // dilation_d, dilation_h, dilation_w -+ cutlass::Tensor5DCoord output_size, // NZPQK -+ cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation, -+ int split_k_slices = 1, -+ int groups = 1 -+ ): -+ D(input_size.d()), T(filter_size.d()), Z(output_size.d()), -+ pad_d(padding[0]), stride_d(stride[0]), dilation_d(dilation[0]), -+ Conv2dProblemSize( -+ {input_size.n(), input_size.h(), input_size.w(), input_size.c()}, -+ {filter_size.n(), filter_size.h(), filter_size.w(), filter_size.c()}, -+ {padding[1], padding[1], padding[2], padding[2]}, -+ {stride[1], stride[2]}, -+ {dilation[1], dilation[2]}, -+ {output_size.n(), output_size.h(), output_size.w(), output_size.c()}, -+ mode, split_k_slices, groups -+ ) { } -+ -+ /// Constructs convolution problem size from cutlass Tensor5DCoord and Coord3D -+ // *computes* output size and sets Z, P and Q (include all data members in ctor) -+ CUTLASS_HOST_DEVICE -+ Conv3dProblemSize( -+ cutlass::Tensor5DCoord input_size, // NDHWC -+ cutlass::Tensor5DCoord filter_size, // KTRSC -+ Coord3D padding, // pad_d, pad_h, pad_w -+ Coord3D stride, // stride_d, stride_h, stride_w -+ Coord3D dilation, // dilation_d, dilation_h, dilation_w -+ cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation, -+ int split_k_slices = 1, -+ int groups = 1 -+ ): -+ D(input_size.d()), T(filter_size.d()), -+ pad_d(padding[0]), stride_d(stride[0]), dilation_d(dilation[0]), -+ Conv2dProblemSize( -+ {input_size.n(), input_size.h(), input_size.w(), input_size.c()}, -+ {filter_size.n(), filter_size.h(), filter_size.w(), filter_size.c()}, -+ {padding[1], padding[1], padding[2], padding[2]}, -+ {stride[1], stride[2]}, -+ {dilation[1], dilation[2]}, -+ mode, split_k_slices, groups -+ ) { -+ // set output Z -+ Z = ((D + pad_d * 2 - T * dilation_d) / stride_d) + 1; -+ } -+ -+ /// Equality operator (ignores mode and split_k_slice) -+ CUTLASS_HOST_DEVICE -+ bool operator==(Conv3dProblemSize const &conv) const { -+ return ( -+ (N == conv.N) && (D == conv.D) && (H == conv.H) && (W == conv.W) && (C == conv.C) && -+ (K == conv.K) && (T == conv.T) && (R == conv.R) && (S == conv.S) && -+ (Z == conv.Z) &&(P == conv.P) && (Q == conv.Q) && -+ (pad_d == conv.pad_d) && (pad_h == conv.pad_h) && (pad_w == conv.pad_w) && -+ (stride_d == conv.stride_d) && (stride_h == conv.stride_h) && (stride_w == conv.stride_w) && -+ (dilation_d == conv.dilation_d) && (dilation_h == conv.dilation_h) && (dilation_w == conv.dilation_w) -+ ); -+ } -+ -+ /// Inequality operator -+ CUTLASS_HOST_DEVICE -+ bool operator!=(Conv3dProblemSize const &rhs) const { -+ return !(*this == rhs); -+ } -+ -+ // Reset covolution mode in the problem -+ CUTLASS_HOST_DEVICE -+ Conv3dProblemSize reset_mode(cutlass::conv::Mode mode_) { -+ Conv3dProblemSize tmp(*this); -+ tmp.mode = mode_; -+ return tmp; -+ } -+ -+ // Reset covolution mode in the problem -+ CUTLASS_HOST_DEVICE -+ Conv3dProblemSize reset_split_k_slices(int split_k_slices_) { -+ Conv3dProblemSize tmp(*this); -+ tmp.split_k_slices = split_k_slices_; -+ return tmp; -+ } -+ -+ /// Returns activation extent as Tensor5DCoord -+ CUTLASS_HOST_DEVICE -+ cutlass::Tensor5DCoord activation_extent() const { -+ -+ return cutlass::Tensor5DCoord ({N, D, H, W, C}); -+ } -+ -+ /// Returns filter extent as Tensor5DCoord -+ CUTLASS_HOST_DEVICE -+ cutlass::Tensor5DCoord filter_extent() const { -+ -+ return cutlass::Tensor5DCoord ({K, T, R, S, C}); -+ } -+ -+ /// Returns output extent as Tensor5DCoord -+ CUTLASS_HOST_DEVICE -+ cutlass::Tensor5DCoord output_extent() const { -+ -+ return cutlass::Tensor5DCoord ({N, Z, P, Q, K}); -+ } -+ -+ /// Returns activation size in number of elements -+ CUTLASS_HOST_DEVICE -+ int64_t activation_size() const { -+ -+ return (N * D * H * W * C); -+ } -+ -+ /// Returns filter size in number of elements -+ CUTLASS_HOST_DEVICE -+ int64_t filter_size() const { -+ -+ return (K * T * R * S * C); -+ } -+ -+ /// Returns output size in number of elements -+ CUTLASS_HOST_DEVICE -+ int64_t output_size() const { -+ -+ return (N * Z * P * Q * K); -+ } -+ -+ /// Returns output extent as Tensor5DCoord -+ CUTLASS_HOST_DEVICE -+ Coord3D padding() const { -+ -+ return Coord3D ({pad_d, pad_h, pad_w}); -+ } -+ -+ /// Returns stride as MatrixCoord -+ CUTLASS_HOST_DEVICE -+ Coord3D stride() const { -+ -+ return Coord3D ({stride_d, stride_h, stride_w}); -+ } -+ -+ /// Returns dilation as MatrixCoord -+ CUTLASS_HOST_DEVICE -+ Coord3D dilation() const { -+ -+ return Coord3D ({dilation_d, dilation_h, dilation_w}); -+ } -+ -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+// ImplicitGemm helper functions // -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Determine the problem size of the implicit GEMM operation -+CUTLASS_HOST_DEVICE -+cutlass::gemm::GemmCoord implicit_gemm_problem_size( -+ Operator conv_operator, -+ Conv3dProblemSize const &problem_size) { -+ // Compute problem size -+ switch (conv_operator) { -+ case Operator::kFprop: -+ return gemm::GemmCoord( -+ problem_size.N * problem_size.Z * problem_size.P * problem_size.Q, -+ problem_size.K, -+ problem_size.T * problem_size.R * problem_size.S * problem_size.C -+ ); -+ case Operator::kDgrad: -+ return gemm::GemmCoord( -+ problem_size.N * problem_size.D * problem_size.H * problem_size.W, -+ problem_size.C, -+ problem_size.T * problem_size.R * problem_size.S * problem_size.K -+ ); -+ case Operator::kWgrad: -+ return gemm::GemmCoord( -+ problem_size.K, -+ problem_size.T * problem_size.R * problem_size.S * problem_size.C, -+ problem_size.N * problem_size.Z * problem_size.P * problem_size.Q -+ ); -+ default: -+ break; -+ } -+ return gemm::GemmCoord(); -+} -+ -+// Determine the number of gemm_k iterations for conv2d problem using implicit gemm algorithm -+CUTLASS_HOST_DEVICE -+int implicit_gemm_k_iterations( -+ Operator conv_operator, -+ int threadblock_K, -+ Conv3dProblemSize const &problem_size, -+ IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic, -+ GroupMode group_mode = GroupMode::kNone, -+ int threadblock_N = 0) { -+ -+ int iterations = 0; -+ int elements_per_split_k_slice = 0; -+ if (group_mode == GroupMode::kNone) { -+ switch (conv_operator) { -+ case Operator::kFprop: -+ elements_per_split_k_slice = (problem_size.C + problem_size.split_k_slices - 1) / problem_size.split_k_slices; -+ iterations = problem_size.T * problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); -+ break; -+ -+ case Operator::kDgrad: -+ elements_per_split_k_slice = (problem_size.K + problem_size.split_k_slices - 1) / problem_size.split_k_slices; -+ iterations = problem_size.T * problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); -+ break; -+ -+ case Operator::kWgrad: -+ elements_per_split_k_slice = (problem_size.N * problem_size.Z * problem_size.P * problem_size.Q + problem_size.split_k_slices - 1) / problem_size.split_k_slices; -+ iterations = (elements_per_split_k_slice + threadblock_K - 1) / threadblock_K; -+ break; -+ -+ default: -+ break; -+ } -+ } else if (group_mode == GroupMode::kDepthwise) { -+ int channels_per_cta = threadblock_N; -+ -+ if (algorithm == IteratorAlgorithm::kAnalytic) { -+ switch (conv_operator) { -+ case Operator::kFprop: -+ iterations = problem_size.T * problem_size.R * problem_size.S * -+ ((channels_per_cta + threadblock_K - 1) / threadblock_K); -+ break; -+ -+ default: -+ break; -+ } -+ } -+ } -+ -+ return iterations; -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Mapping function (ImplicitGemm A, B, C -> Conv Activation, Filter, Output) -+//////////////////////////////////////////////////////////////////////////////// -+/// Returns ImplicitGemm tensor A extent as Tensor5DCoord -+CUTLASS_HOST_DEVICE -+cutlass::Tensor5DCoord implicit_gemm_tensor_a_extent( -+ Operator conv_operator, -+ Conv3dProblemSize const &problem_size) { -+ switch (conv_operator) { -+ case cutlass::conv::Operator::kFprop: return problem_size.activation_extent(); -+ case cutlass::conv::Operator::kDgrad: return problem_size.output_extent(); -+ case cutlass::conv::Operator::kWgrad: return problem_size.output_extent(); -+ default : break; -+ } -+ return cutlass::Tensor5DCoord(); -+} -+ -+/// Returns ImplicitGemm tensor B extent as Tensor5DCoord -+CUTLASS_HOST_DEVICE -+cutlass::Tensor5DCoord implicit_gemm_tensor_b_extent( -+ Operator conv_operator, -+ Conv3dProblemSize const &problem_size) { -+ switch (conv_operator) { -+ case cutlass::conv::Operator::kFprop: return problem_size.filter_extent(); -+ case cutlass::conv::Operator::kDgrad: return problem_size.filter_extent(); -+ case cutlass::conv::Operator::kWgrad: return problem_size.activation_extent(); -+ default : break; -+ } -+ return cutlass::Tensor5DCoord(); -+} -+ -+/// Returns ImplicitGemm tensor C extent as Tensor5DCoord -+CUTLASS_HOST_DEVICE -+cutlass::Tensor5DCoord implicit_gemm_tensor_c_extent( -+ Operator conv_operator, -+ Conv3dProblemSize const &problem_size) { -+ switch (conv_operator) { -+ case cutlass::conv::Operator::kFprop: return problem_size.output_extent(); -+ case cutlass::conv::Operator::kDgrad: return problem_size.activation_extent(); -+ case cutlass::conv::Operator::kWgrad: return problem_size.filter_extent(); -+ default : break; -+ } -+ return cutlass::Tensor5DCoord(); -+} -+ -+/// Returns ImplicitGemm tensor A size in number of elements -+CUTLASS_HOST_DEVICE -+int64_t implicit_gemm_tensor_a_size( -+ Operator conv_operator, -+ Conv3dProblemSize const &problem_size) { -+ switch (conv_operator) { -+ case cutlass::conv::Operator::kFprop: return problem_size.activation_size(); -+ case cutlass::conv::Operator::kDgrad: return problem_size.output_size(); -+ case cutlass::conv::Operator::kWgrad: return problem_size.output_size(); -+ default : break; -+ } -+ return 0; -+} -+ -+/// Returns ImplicitGemm tensor B size in number of elements -+CUTLASS_HOST_DEVICE -+int64_t implicit_gemm_tensor_b_size( -+ Operator conv_operator, -+ Conv3dProblemSize const &problem_size) { -+ switch (conv_operator) { -+ case cutlass::conv::Operator::kFprop: return problem_size.filter_size(); -+ case cutlass::conv::Operator::kDgrad: return problem_size.filter_size(); -+ case cutlass::conv::Operator::kWgrad: return problem_size.activation_size(); -+ default : break; -+ } -+ return 0; -+} -+ -+/// Returns ImplicitGemm tensor C size in number of elements -+CUTLASS_HOST_DEVICE -+int64_t implicit_gemm_tensor_c_size( -+ Operator conv_operator, -+ Conv3dProblemSize const &problem_size) { -+ switch (conv_operator) { -+ case cutlass::conv::Operator::kFprop: return problem_size.output_size(); -+ case cutlass::conv::Operator::kDgrad: return problem_size.activation_size(); -+ case cutlass::conv::Operator::kWgrad: return problem_size.filter_size(); -+ default : break; -+ } -+ return 0; -+} -+ -+} // namespace conv -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/convolution.h b/3rdparty/cutlass/include/cutlass/conv/convolution.h -new file mode 100644 -index 0000000..0647edf ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/convolution.h -@@ -0,0 +1,167 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief -+ -+This file contains definitions and utility functions for describing convolution problem sizes in terms of -+activation (NHWC), filter (KRSC), output (NPQK), pading (pad_h, pad_w), stride (stride_h, stride_w), -+dilation (dilation_h, dilation_w). Furthermore, it defines helper functions to map cutlass' implicit gemm -+tensor extents, sizes, data types to that of convolutions extents, sizes, and data types. -+ -+ * Mapping convolutions to Gemm computation * -+ -+Cutlass employs ImplicitGemm algorithm to implement convolutions. ImplicitGemm algorithm runs gemm operation -+on convolution tensors Activation, Filter, and Output . The underlying gemm operation follows the standard -+gemm definition: -+ -+ C = A * B + C -+ -+ A and B are input matrices -+ C is source and output matrix -+ -+ -+For the three convolutional operators (Fprop, Dgrad, Wgrad), ImplicitGemm matrices A, B, and C are mapped on -+to convolution tensors Activation, Filter and Output as per the below table: -+ -+ ___________________________________________________________________________ -+ ConvolutionalOperator | A | B | C -+ ___________________________________________________________________________ -+ | | | | | -+ | Fprop | Activation | Filter | Output | -+ | Dgrad | Output | Filter | Activation | -+ | Wgrad | Output | Activation | Filter | -+ ___________________________________________________________________________ -+ -+In convolution codebase, DO NOT mix using (A, B, C) with (Acvitation, Filter, Output). -+ -+For example, a convolution class/function with A, B, Output is confusing and error-prone. Instead use below -+mapping functions and adhere to using either A, B, C or Acvitation, Filter, Output. -+ -+Map elements' data types (ImplicitGemm -> Conv): GemmToConvElementMap -+Map elements' data types (Conv -> ImplicitGemm): ConvToGemmElementMap -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/tensor_coord.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+ -+namespace cutlass { -+namespace conv { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Convolutional operator -+enum class Operator { -+ kFprop, -+ kDgrad, -+ kWgrad -+}; -+ -+/// Distinguishes convolution from cross correlation -+enum class Mode { -+ kCrossCorrelation, -+ kConvolution -+}; -+ -+/// Selects among several implementation variants trading off performance with simplicity -+enum class IteratorAlgorithm { -+ kAnalytic, ///< functionally correct in all cases but lower performance -+ kOptimized, ///< optimized for R <= 32, S <= 32 and unity-stride dgrad -+ kFixedChannels, ///< Analytic algorithm optimized for fixed channel count (C == AccessSize) -+ kFewChannels, ///< Analytic algorithm optimized for few channels (C divisible by AccessSize) -+ kFixedStrideDilation ///< Optimized for fixed stride and dilation -+}; -+ -+/// Distinguishes among partial specializations that accelerate certain problems where convolution -+/// stride is unit. -+enum class StrideSupport { -+ kStrided, ///< arbitrary convolution stride -+ kUnity, ///< unit convolution stride -+ kFixed ///< fixed convolution stride -+}; -+ -+/// Identifies split-K mode -+enum class SplitKMode { -+ kNone, -+ kSerial, -+ kParallel -+}; -+ -+/// Identifies group mode -+enum class GroupMode { -+ kNone, -+ kSingleGroup, ///< One CTA calculates one group or less -+ kMultipleGroup, ///< One CTA calculates multiple groups -+ kDepthwise ///< One CTA calculates cta_n groups (problem_size.C == problem_size.K == problem_size.groups) -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Shape of a tensor -+template < -+ int N = 1, -+ int H = 1, -+ int W = 1, -+ int C = 1 -+> -+struct TensorNHWCShape { -+ static int const kN = N; -+ static int const kH = H; -+ static int const kW = W; -+ static int const kC = C; -+ -+ static int const kHW = H * W; -+ static int const kNHW = N * kHW; -+ static int const kNHWC = N * H * W * C; -+ -+ static int const kCount = kNHWC; -+ -+ // -+ // Static member functions -+ // -+ -+ /// Returns a Coord object -+ CUTLASS_HOST_DEVICE -+ static Coord<4> toCoord() { -+ return make_Coord(kN, kH, kW, kC); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace conv -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/device/direct_convolution.h b/3rdparty/cutlass/include/cutlass/conv/device/direct_convolution.h -new file mode 100644 -index 0000000..d7f28f1 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/device/direct_convolution.h -@@ -0,0 +1,269 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Template for device-level Depthwise Convolution -+*/ -+ -+#pragma once -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/device_kernel.h" -+#include "cutlass/conv/convolution.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class DirectConvolution { -+public: -+ -+ using UnderlyingKernel = DirectConvolutionKernel_; -+ -+ using ElementA = typename UnderlyingKernel::ElementA; -+ using LayoutA = typename UnderlyingKernel::LayoutA; -+ using ElementB = typename UnderlyingKernel::ElementB; -+ using LayoutB = typename UnderlyingKernel::LayoutB; -+ using ElementC = typename UnderlyingKernel::ElementC; -+ using LayoutC = typename UnderlyingKernel::LayoutC; -+ using ElementAccumulator = typename UnderlyingKernel::ElementAccumulator; -+ using ElementCompute = typename UnderlyingKernel::ElementCompute; -+ using OperatorClass = typename UnderlyingKernel::OperatorClass; -+ using ArchTag = typename UnderlyingKernel::ArchTag; -+ using ThreadblockShape = typename UnderlyingKernel::ThreadblockShape; -+ using WarpShape = typename UnderlyingKernel::WarpShape; -+ using InstructionShape = typename UnderlyingKernel::InstructionShape; -+ using ThreadblockSwizzle = typename UnderlyingKernel::ThreadblockSwizzle; -+ using EpilogueOutputOp = typename UnderlyingKernel::EpilogueOutputOp; -+ static int const kStages = UnderlyingKernel::kStages; -+ static int const kConvDim = UnderlyingKernel::kConvDim; -+ using WarpMmaOperator = typename UnderlyingKernel::WarpMmaOperator; -+ using ArchMmaOperator = typename UnderlyingKernel::ArchMmaOperator; -+ using MathOperator = typename UnderlyingKernel::MathOperator; -+ -+ static cutlass::conv::Operator const kConvolutionalOperator = UnderlyingKernel::kConvolutionalOperator; -+ static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = UnderlyingKernel::kIteratorAlgorithm; -+ static cutlass::conv::StrideSupport const kStrideSupport = UnderlyingKernel::kStrideSupport; -+ static cutlass::conv::GroupMode const kGroupMode = UnderlyingKernel::kGroupMode; -+ -+ static int const kWarpCount = -+ (ThreadblockShape::kM / WarpShape::kM) * -+ (ThreadblockShape::kN / WarpShape::kN) * -+ (ThreadblockShape::kK / WarpShape::kK); -+ -+ /// Argument structure -+ using Arguments = typename UnderlyingKernel::Arguments; -+ -+ using ReorderKernel = typename UnderlyingKernel::ReorderKernel; -+ -+ private: -+ -+ /// Kernel parameters object -+ typename UnderlyingKernel::Params params_; -+ -+public: -+ -+ /// Constructs Implicit GEMM -+ DirectConvolution() { } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ // dispatch to iterators -+ Status status = UnderlyingKernel::Mma::IteratorA::can_implement(args.problem_size); -+ if (Status::kSuccess != status) { -+ return status; -+ } -+ -+ status = UnderlyingKernel::Mma::IteratorB::can_implement(args.problem_size); -+ if (Status::kSuccess != status) { -+ return status; -+ } -+ -+ if (kGroupMode != conv::GroupMode::kDepthwise) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ // C and K should be multiple of groups -+ if (args.problem_size.K != args.problem_size.groups && -+ args.problem_size.C != args.problem_size.groups) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ -+ static int const kAlignmentC = UnderlyingKernel::Epilogue::OutputTileIterator::kElementsPerAccess; -+ if (kConvolutionalOperator == conv::Operator::kFprop) { -+ if (args.problem_size.K % kAlignmentC) -+ return Status::kErrorMisalignedOperand; -+ } else if (kConvolutionalOperator == conv::Operator::kDgrad) { -+ if (args.problem_size.C % kAlignmentC) -+ return Status::kErrorMisalignedOperand; -+ } else if (kConvolutionalOperator == conv::Operator::kWgrad) { -+ if (args.problem_size.C % kAlignmentC) -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape( -+ threadblock_swizzle.get_tiled_shape( -+ kConvolutionalOperator, -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.problem_size.split_k_slices)); -+ -+ if (!(grid.y <= std::numeric_limits::max() && -+ grid.z <= std::numeric_limits::max())) { -+ -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ return 0; -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ // initialize the params structure from the arguments -+ params_ = typename UnderlyingKernel::Params( -+ args, -+ static_cast(workspace) -+ ); -+ -+ int smem_size = int(sizeof(typename UnderlyingKernel::SharedStorage)); -+ -+ if (smem_size >= (48 << 10)) { -+ cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ // update the params structure from the arguments -+ params_.ptr_A = args.ref_A.data(); -+ params_.ptr_B = args.ref_B.data(); -+ params_.ptr_C = args.ref_C.data(); -+ params_.ptr_D = args.ref_D.data(); -+ params_.output_op = args.output_op; -+ params_.ptr_reordered_B = args.ref_reordered_B.data();; -+ params_.semaphore = static_cast(workspace); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ // Launch reorder kernel -+ if (params_.ptr_reordered_B != nullptr) { -+ dim3 grid = ReorderKernel::get_grid_shape(params_); -+ dim3 block = ReorderKernel::get_block_shape(); -+ -+ cutlass::Kernel<<>>(params_); -+ } -+ -+ // Launch main kernel -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); -+ dim3 block(32 * kWarpCount, 1, 1); -+ -+ // Dynamic SMEM size based on input params. -+ int smem_size = int(params_.get_smem_size()); -+ -+ // Make sure we can use that much shared memory. -+ cudaError_t status = -+ cudaFuncSetAttribute(cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); -+ if (status != cudaSuccess) -+ return Status::kErrorInternal; -+ -+ -+ cutlass::Kernel<<>>(params_); -+ -+ cudaError_t result = cudaGetLastError(); -+ -+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+ -+ int get_smem_size() { return int(params_.get_smem_size()); } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} -+} -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/device/implicit_gemm_convolution.h b/3rdparty/cutlass/include/cutlass/conv/device/implicit_gemm_convolution.h -new file mode 100644 -index 0000000..50bdc47 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/device/implicit_gemm_convolution.h -@@ -0,0 +1,328 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Template for device-level Implicit GEMM Convolution -+*/ -+ -+#pragma once -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/device_kernel.h" -+#include "cutlass/conv/convolution.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class ImplicitGemmConvolution { -+public: -+ -+ using UnderlyingKernel = ImplicitGemmKernel_; -+ -+ using ElementA = typename UnderlyingKernel::ElementA; -+ using LayoutA = typename UnderlyingKernel::LayoutA; -+ using ElementB = typename UnderlyingKernel::ElementB; -+ using LayoutB = typename UnderlyingKernel::LayoutB; -+ using ElementC = typename UnderlyingKernel::ElementC; -+ using LayoutC = typename UnderlyingKernel::LayoutC; -+ using ElementAccumulator = typename UnderlyingKernel::ElementAccumulator; -+ using ElementCompute = typename UnderlyingKernel::ElementCompute; -+ using OperatorClass = typename UnderlyingKernel::OperatorClass; -+ using ArchTag = typename UnderlyingKernel::ArchTag; -+ using ThreadblockShape = typename UnderlyingKernel::ThreadblockShape; -+ using WarpShape = typename UnderlyingKernel::WarpShape; -+ using InstructionShape = typename UnderlyingKernel::InstructionShape; -+ using ThreadblockSwizzle = typename UnderlyingKernel::ThreadblockSwizzle; -+ using EpilogueOutputOp = typename UnderlyingKernel::EpilogueOutputOp; -+ static int const kStages = UnderlyingKernel::kStages; -+ static int const kConvDim = UnderlyingKernel::kConvDim; -+ using WarpMmaOperator = typename UnderlyingKernel::WarpMmaOperator; -+ using ArchMmaOperator = typename UnderlyingKernel::ArchMmaOperator; -+ using MathOperator = typename UnderlyingKernel::MathOperator; -+ -+ static cutlass::conv::Operator const kConvolutionalOperator = UnderlyingKernel::kConvolutionalOperator; -+ static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = UnderlyingKernel::kIteratorAlgorithm; -+ static cutlass::conv::StrideSupport const kStrideSupport = UnderlyingKernel::kStrideSupport; -+ static cutlass::conv::GroupMode const kGroupMode = UnderlyingKernel::kGroupMode; -+ -+ static int const kWarpCount = -+ (ThreadblockShape::kM / WarpShape::kM) * -+ (ThreadblockShape::kN / WarpShape::kN) * -+ (ThreadblockShape::kK / WarpShape::kK); -+ -+ /// Argument structure -+ using Arguments = typename UnderlyingKernel::Arguments; -+ -+private: -+ -+ /// Kernel parameters object -+ typename UnderlyingKernel::Params params_; -+ -+public: -+ -+ /// Constructs Implicit GEMM -+ ImplicitGemmConvolution() { } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ // dispatch to iterators -+ Status status = UnderlyingKernel::Mma::IteratorA::can_implement(args.problem_size); -+ if (Status::kSuccess != status) { -+ return status; -+ } -+ -+ status = UnderlyingKernel::Mma::IteratorB::can_implement(args.problem_size); -+ if (Status::kSuccess != status) { -+ return status; -+ } -+ -+ // check group conv constraint -+ if (args.problem_size.groups != 1) { -+ if (kGroupMode == conv::GroupMode::kNone) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ // C and K should be multiple of groups -+ if (args.problem_size.K % args.problem_size.groups || -+ args.problem_size.C % args.problem_size.groups) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ // split-k is not supported -+ if (args.problem_size.split_k_slices != 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ int k_per_group = args.problem_size.K / args.problem_size.groups; -+ // k_per_group should be multiple of ThreadblockShape N, one CTA calculate one group -+ if (kGroupMode == conv::GroupMode::kSingleGroup && k_per_group % ThreadblockShape::kN) { -+ return Status::kErrorInvalidProblem; -+ } -+ // ThreadblockShape::kN should be divisible by k_per_group, one CTA calculate multiple groups -+ if (kGroupMode == conv::GroupMode::kMultipleGroup && ThreadblockShape::kN % k_per_group) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ // current optimized iterator algo only supports SingleGroup mode -+ if (kIteratorAlgorithm == IteratorAlgorithm::kOptimized && -+ kGroupMode != conv::GroupMode::kSingleGroup) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ static int const kAlignmentC = UnderlyingKernel::Epilogue::OutputTileIterator::kElementsPerAccess; -+ if (kConvolutionalOperator == conv::Operator::kFprop) { -+ if (args.problem_size.K % kAlignmentC) -+ return Status::kErrorMisalignedOperand; -+ } else if (kConvolutionalOperator == conv::Operator::kDgrad) { -+ if (args.problem_size.C % kAlignmentC) -+ return Status::kErrorMisalignedOperand; -+ } else if (kConvolutionalOperator == conv::Operator::kWgrad) { -+ if (args.problem_size.C % kAlignmentC) -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ // check for unsupported problem sizes for strided dgrad implementation -+ if (kConvolutionalOperator == conv::Operator::kDgrad && -+ kStrideSupport == conv::StrideSupport::kStrided) { -+ -+ // split-k (serial or parallel) is not supported for strided dgrad -+ if(args.problem_size.split_k_slices > 1) { -+ return Status::kErrorNotSupported; -+ } -+ -+ // dilation > {1x1} is not supported for strided dgrad -+ if(args.problem_size.dilation_h > 1 || args.problem_size.dilation_w > 1) { -+ return Status::kErrorNotSupported; -+ } -+ } -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape( -+ threadblock_swizzle.get_tiled_shape( -+ kConvolutionalOperator, -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.problem_size.split_k_slices)); -+ -+ if (!(grid.y <= std::numeric_limits::max() && -+ grid.z <= std::numeric_limits::max())) { -+ -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ size_t workspace_bytes = 0; -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape( -+ kConvolutionalOperator, -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.problem_size.split_k_slices); -+ -+ if(args.split_k_mode == SplitKMode::kParallel) { -+ -+ // Split-K parallel: CTAs in k-dimension write the partial results in a temporary workspace. -+ // The user needs to call a reduction operator to optain the final output tensor -+ workspace_bytes = -+ sizeof(ElementAccumulator) * -+ size_t(cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, args.problem_size)) * -+ size_t(grid_tiled_shape.k()); -+ } -+ -+ else if(args.split_k_mode == SplitKMode::kSerial && args.problem_size.split_k_slices > 1) { -+ -+ // Split-K serial: The user workspace is used to store semaphore and serialize writing the -+ // final reduced output to user's output tensor -+ workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n()); -+ } -+ -+ return workspace_bytes; -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ if (args.problem_size.split_k_slices > 1) { -+ -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ cudaError_t status = cudaMemsetAsync(workspace, 0, get_workspace_size(args), stream); -+ -+ if (status != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ // initialize the params structure from the arguments -+ params_ = typename UnderlyingKernel::Params( -+ args, -+ static_cast(workspace) -+ ); -+ -+ int smem_size = int(sizeof(typename UnderlyingKernel::SharedStorage)); -+ -+ if (smem_size >= (48 << 10)) { -+ cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ // update the params structure from the arguments -+ params_.ptr_A = args.ref_A.data(); -+ params_.ptr_B = args.ref_B.data(); -+ params_.ptr_C = args.ref_C.data(); -+ params_.ptr_D = args.ref_D.data(); -+ params_.output_op = args.output_op; -+ params_.semaphore = static_cast(workspace); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); -+ dim3 block(32 * kWarpCount, 1, 1); -+ -+ int smem_size = int(sizeof(typename UnderlyingKernel::SharedStorage)); -+ -+ cutlass::Kernel<<>>(params_); -+ -+ cudaError_t result = cudaGetLastError(); -+ -+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} -+} -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h b/3rdparty/cutlass/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h -new file mode 100644 -index 0000000..2f434bd ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h -@@ -0,0 +1,268 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Template for device-level fused activation's scale+bias+relu and Implicit GEMM Convolution -+*/ -+ -+#pragma once -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/device_kernel.h" -+#include "cutlass/conv/convolution.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class ImplicitGemmConvolutionFusion { -+public: -+ -+ using ImplicitGemmFusionKernel = ImplicitGemmFusionKernel_; -+ -+ using ElementA = typename ImplicitGemmFusionKernel::ElementA; -+ using LayoutA = typename ImplicitGemmFusionKernel::LayoutA; -+ using ElementB = typename ImplicitGemmFusionKernel::ElementB; -+ using LayoutB = typename ImplicitGemmFusionKernel::LayoutB; -+ -+// using ElementScaleBias = typename ImplicitGemmFusionKernel::ElementScaleBias; -+// using LayoutScaleBias = typename ImplicitGemmFusionKernel::LayoutScaleBias; -+ -+ using ElementC = typename ImplicitGemmFusionKernel::ElementC; -+ using LayoutC = typename ImplicitGemmFusionKernel::LayoutC; -+ using ElementAccumulator = typename ImplicitGemmFusionKernel::ElementAccumulator; -+ using ElementCompute = typename ImplicitGemmFusionKernel::ElementCompute; -+ using OperatorClass = typename ImplicitGemmFusionKernel::OperatorClass; -+ using ArchTag = typename ImplicitGemmFusionKernel::ArchTag; -+ using ThreadblockShape = typename ImplicitGemmFusionKernel::ThreadblockShape; -+ using WarpShape = typename ImplicitGemmFusionKernel::WarpShape; -+ using InstructionShape = typename ImplicitGemmFusionKernel::InstructionShape; -+ using ThreadblockSwizzle = typename ImplicitGemmFusionKernel::ThreadblockSwizzle; -+ using EpilogueOutputOp = typename ImplicitGemmFusionKernel::EpilogueOutputOp; -+ static int const kStages = ImplicitGemmFusionKernel::kStages; -+ static int const kConvDim = ImplicitGemmFusionKernel::kConvDim; -+ using WarpMmaOperator = typename ImplicitGemmFusionKernel::WarpMmaOperator; -+ using ArchMmaOperator = typename ImplicitGemmFusionKernel::ArchMmaOperator; -+ using MathOperator = typename ImplicitGemmFusionKernel::MathOperator; -+ -+ static cutlass::conv::Operator const kConvolutionalOperator = ImplicitGemmFusionKernel::kConvolutionalOperator; -+ static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = ImplicitGemmFusionKernel::kIteratorAlgorithm; -+ -+ static int const kWarpCount = -+ (ThreadblockShape::kM / WarpShape::kM) * -+ (ThreadblockShape::kN / WarpShape::kN) * -+ (ThreadblockShape::kK / WarpShape::kK); -+ -+ /// Argument structure -+ using Arguments = typename ImplicitGemmFusionKernel::Arguments; -+ -+private: -+ -+ /// Kernel parameters object -+ typename ImplicitGemmFusionKernel::Params params_; -+ -+public: -+ -+ /// Constructs Implicit GEMM -+ ImplicitGemmConvolutionFusion() { } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ // dispatch to iterators -+ Status status = ImplicitGemmFusionKernel::Mma::IteratorA::can_implement(args.problem_size); -+ if (Status::kSuccess != status) { -+ return status; -+ } -+ -+ status = ImplicitGemmFusionKernel::Mma::IteratorB::can_implement(args.problem_size); -+ if (Status::kSuccess != status) { -+ return status; -+ } -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape( -+ threadblock_swizzle.get_tiled_shape( -+ cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size), -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.problem_size.split_k_slices)); -+ -+ if (!(grid.y <= std::numeric_limits::max() && -+ grid.z <= std::numeric_limits::max())) { -+ -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ size_t workspace_bytes = 0; -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape( -+ cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size), -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.problem_size.split_k_slices); -+ -+ if(args.split_k_mode == SplitKMode::kParallel) { -+ -+ // Split-K parallel: CTAs in k-dimension write the partial results in a temporary workspace. -+ // The user needs to call a reduction operator to optain the final output tensor -+ workspace_bytes = -+ sizeof(ElementAccumulator) * -+ size_t(cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, args.problem_size)) * -+ size_t(grid_tiled_shape.k()); -+ } -+ -+ else if(args.split_k_mode == SplitKMode::kSerial && args.problem_size.split_k_slices > 1) { -+ -+ // Split-K serial: The user workspace is used to store semaphore and serialize writing the -+ // final reduced output to user's output tensor -+ workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n()); -+ } -+ -+ return workspace_bytes; -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ if (args.problem_size.split_k_slices > 1) { -+ -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ cudaError_t status = cudaMemsetAsync(workspace, 0, get_workspace_size(args), stream); -+ -+ if (status != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ // initialize the params structure from the arguments -+ params_ = typename ImplicitGemmFusionKernel::Params( -+ args, -+ static_cast(workspace) -+ ); -+ -+ int smem_size = int(sizeof(typename ImplicitGemmFusionKernel::SharedStorage)); -+ -+ if (smem_size >= (48 << 10)) { -+ cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Initializes Impicit GEMM state from arguments. -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ // update the params structure from the arguments -+ params_.ptr_A = args.ref_A.data(); -+ params_.ptr_B = args.ref_B.data(); -+ params_.ptr_scale = args.ref_A_scale.data(); -+ params_.ptr_bias = args.ref_A_bias.data(); -+ params_.ptr_C = args.ref_C.data(); -+ params_.ptr_D = args.ref_D.data(); -+ params_.output_op = args.output_op; -+ params_.semaphore = static_cast(workspace); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); -+ dim3 block(32 * kWarpCount, 1, 1); -+ -+ int smem_size = int(sizeof(typename ImplicitGemmFusionKernel::SharedStorage)); -+ -+ cutlass::Kernel<<>>(params_); -+ -+ cudaError_t result = cudaGetLastError(); -+ -+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} -+} -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d.h b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d.h -new file mode 100644 -index 0000000..cb7980b ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d.h -@@ -0,0 +1,272 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level implicit GEMM convolution definitions for threadblock-scoped epilogue. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/threadblock/default_mma.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/conv/threadblock/threadblock_swizzle.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_with_reduction.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/threadblock/conv2d_tile_iterator.h" -+#include "cutlass/conv/threadblock/implicit_gemm_pipelined.h" -+#include "cutlass/conv/threadblock/implicit_gemm_multistage.h" -+#include "cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h" -+#include "cutlass/conv/threadblock/implicit_gemm_wgrad_fusion_multistage.h" -+#include "cutlass/conv/kernel/implicit_gemm_convolution.h" -+#include "cutlass/conv/kernel/implicit_gemm_convolution_fusion.h" -+#include "cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template < -+ typename ArchTag, -+ typename Shape, -+ typename WarpMmaTensorOp, -+ int PartitionsK, -+ typename OutputOp -+> -+struct DefaultConvEpilogue { -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ OutputOp, -+ OutputOp::kCount -+ >::Epilogue; -+}; -+ -+template < -+ typename Shape, -+ typename WarpMmaTensorOp, -+ int PartitionsK, -+ typename OutputOp -+> -+struct DefaultConvEpilogue< -+ arch::Sm70, -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ OutputOp -+> { -+ -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ OutputOp, -+ OutputOp::kCount -+ >::Epilogue; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ArchTag, -+ typename Shape, -+ typename WarpMmaTensorOp, -+ int PartitionsK, -+ typename ElementOutput, -+ typename ElementTensor, -+ typename ElementVector, -+ typename OutputOp, -+ int ElementsPerAccess -+> -+struct DefaultConvEpilogueWithBroadcastTensorOp { -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithBroadcastTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ ElementOutput, -+ ElementTensor, -+ ElementVector, -+ OutputOp, -+ ElementsPerAccess -+ >::Epilogue; -+}; -+ -+template < -+ typename Shape, -+ typename WarpMmaTensorOp, -+ int PartitionsK, -+ typename ElementOutput, -+ typename ElementTensor, -+ typename ElementVector, -+ typename OutputOp, -+ int ElementsPerAccess -+> -+struct DefaultConvEpilogueWithBroadcastTensorOp< -+ arch::Sm70, -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ ElementOutput, -+ ElementTensor, -+ ElementVector, -+ OutputOp, -+ ElementsPerAccess -+ > { -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithBroadcastVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ ElementOutput, -+ ElementTensor, -+ ElementVector, -+ OutputOp, -+ ElementsPerAccess -+ >::Epilogue; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ArchTag, -+ typename Shape, -+ typename WarpMmaTensorOp, -+ int PartitionsK, -+ typename ElementOutput, -+ typename OutputOp, -+ typename ReductionOp, -+ int ElementsPerAccess -+> -+struct DefaultConvEpilogueWithReductionTensorOp { -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ ElementOutput, -+ OutputOp, -+ ReductionOp, -+ ElementsPerAccess -+ >::Epilogue; -+}; -+ -+template < -+ typename Shape, -+ typename WarpMmaTensorOp, -+ int PartitionsK, -+ typename ElementOutput, -+ typename OutputOp, -+ typename ReductionOp, -+ int ElementsPerAccess -+> -+struct DefaultConvEpilogueWithReductionTensorOp< -+ arch::Sm70, -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ ElementOutput, -+ OutputOp, -+ ReductionOp, -+ ElementsPerAccess -+ > { -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithReductionVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ ElementOutput, -+ OutputOp, -+ ReductionOp, -+ ElementsPerAccess -+ >::Epilogue; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Defaults for strided Dgrad -+template < -+ typename ArchTag, -+ typename Shape, -+ typename WarpMmaTensorOp, -+ int PartitionsK, -+ typename OutputOp -+> -+struct DefaultConvEpilogueStridedDgrad { -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOpStridedDgrad< -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ OutputOp, -+ OutputOp::kCount -+ >::Epilogue; -+}; -+ -+template < -+ typename Shape, -+ typename WarpMmaTensorOp, -+ int PartitionsK, -+ typename OutputOp -+> -+struct DefaultConvEpilogueStridedDgrad< -+ arch::Sm70, -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ OutputOp -+> { -+ -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueVoltaTensorOpStridedDgrad< -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ OutputOp, -+ OutputOp::kCount -+ >::Epilogue; -+}; -+ -+} // namespace detail -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_dgrad.h b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_dgrad.h -new file mode 100644 -index 0000000..6a54120 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_dgrad.h -@@ -0,0 +1,1927 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped -+ matrix multiply-add with the appropriate threadblock-scoped epilogue. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/conv/kernel/default_conv2d.h" -+ -+#include "cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv2d_tile_iterator.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines a kernel for Conv2dDgrad -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, -+ conv::StrideSupport StrideSupport = StrideSupport::kStrided, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = 128 / cutlass::sizeof_bits::value, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = 128 / cutlass::sizeof_bits::value -+> struct DefaultConv2dDgrad; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// OpClassTensorOp convolutions -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm Dgrad Strided and -+// multistage pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dDgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ StrideSupport::kStrided, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ StrideSupport::kStrided, -+ AccessTypeA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ StrideSupport::kStrided, -+ AccessTypeB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * AlignmentB) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ CacheOpB, -+ MmaPolicy, -+ Stages -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOpStridedDgrad< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kDgrad -+ >; -+}; -+ -+/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm Dgrad Strided -+// and 2 stage pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dDgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ StrideSupport::kStrided, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIteratorStridedDgrad< -+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ StrideSupport::kStrided, -+ AccessTypeA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIteratorStridedDgrad< -+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ StrideSupport::kStrided, -+ AccessTypeB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename detail::DefaultConvEpilogueStridedDgrad< -+ ArchTag, -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kDgrad -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm Dgrad Unity Strided -+// and multistage pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dDgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ StrideSupport::kUnity, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ StrideSupport::kUnity, -+ AccessTypeA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ StrideSupport::kUnity, -+ AccessTypeB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * AlignmentB) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ CacheOpB, -+ MmaPolicy, -+ Stages -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kDgrad -+ >; -+}; -+ -+/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm Dgrad Unity -+// 2 stage pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dDgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ StrideSupport::kUnity, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ StrideSupport::kUnity, -+ AccessTypeA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ StrideSupport::kUnity, -+ AccessTypeB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename detail::DefaultConvEpilogue< -+ ArchTag, -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kDgrad -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dDgrad specialization for optimized IteratorAlgorithm Dgrad Unity Strided -+// and multistage pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dDgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport::kUnity, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ StrideSupport::kUnity, -+ AccessTypeA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ StrideSupport::kUnity, -+ AccessTypeB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * AlignmentB) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ CacheOpB, -+ MmaPolicy, -+ Stages -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kDgrad -+ >; -+}; -+ -+/// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm Dgrad Strided and -+// multistage pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dDgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport::kStrided, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ StrideSupport::kStrided, -+ AccessTypeA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ StrideSupport::kStrided, -+ AccessTypeB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * AlignmentB) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ CacheOpB, -+ MmaPolicy, -+ Stages -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOpStridedDgrad< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kDgrad -+ >; -+}; -+ -+/// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm Dgrad Strided -+// and 2 stage pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dDgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport::kStrided, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIteratorStridedDgrad< -+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ StrideSupport::kStrided, -+ AccessTypeA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIteratorStridedDgrad< -+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ StrideSupport::kStrided, -+ AccessTypeB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename detail::DefaultConvEpilogueStridedDgrad< -+ ArchTag, -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kDgrad -+ >; -+}; -+ -+/// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm Dgrad Unity -+// 2 stage pipeline -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dDgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport::kUnity, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ StrideSupport::kUnity, -+ AccessTypeA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ StrideSupport::kUnity, -+ AccessTypeB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename detail::DefaultConvEpilogue< -+ ArchTag, -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kDgrad -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// OpClassSimt convolutions -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm, -+/// multi-stage pipeline, and FFMA-based mainloop for SM80 -+ -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dDgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ conv::StrideSupport::kUnity, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ conv::StrideSupport::kUnity -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ conv::StrideSupport::kUnity -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Always, -+ MmaPolicy, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ WarpMmaSimtOp, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kDgrad -+ >; -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dDgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ conv::StrideSupport::kStrided, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ conv::StrideSupport::kStrided -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ conv::StrideSupport::kStrided -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Always, -+ MmaPolicy, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< -+ ThreadblockShape, -+ WarpMmaSimtOp, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kDgrad -+ >; -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm, -+/// multi-stage pipeline, and FFMA-based mainloop for SM80 -+ -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dDgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport::kUnity, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ StrideSupport::kUnity -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ StrideSupport::kUnity -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Always, -+ MmaPolicy, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ WarpMmaSimtOp, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kDgrad -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dDgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ conv::StrideSupport::kStrided, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ conv::StrideSupport::kStrided -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ conv::StrideSupport::kStrided -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Always, -+ MmaPolicy, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< -+ ThreadblockShape, -+ WarpMmaSimtOp, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kDgrad -+ >; -+ -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm, -+/// 2 stage pipeline, and FFMA-based mainloop for SM50 -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dDgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ conv::StrideSupport::kUnity, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ conv::StrideSupport::kUnity -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ conv::StrideSupport::kUnity -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ WarpMmaSimtOp, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kDgrad -+ >; -+ -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dDgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ conv::StrideSupport::kStrided, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIteratorStridedDgrad< -+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ conv::StrideSupport::kStrided -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIteratorStridedDgrad< -+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ conv::StrideSupport::kStrided -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< -+ ThreadblockShape, -+ WarpMmaSimtOp, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kDgrad -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm, -+/// 2 stage pipeline, and FFMA-based mainloop for SM50 -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dDgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport::kUnity, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ StrideSupport::kUnity -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ StrideSupport::kUnity -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ WarpMmaSimtOp, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kDgrad -+ >; -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dDgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ conv::StrideSupport::kStrided, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIteratorStridedDgrad< -+ cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ conv::StrideSupport::kStrided -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIteratorStridedDgrad< -+ cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ conv::StrideSupport::kStrided -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< -+ ThreadblockShape, -+ WarpMmaSimtOp, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kDgrad -+ >; -+ -+}; -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop.h b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop.h -new file mode 100644 -index 0000000..3e16d17 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop.h -@@ -0,0 +1,1989 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped -+ matrix multiply-add with the appropriate threadblock-scoped epilogue. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/conv/kernel/default_conv2d.h" -+ -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h" -+ -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines a kernel for Conv2dFprop -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, -+ conv::StrideSupport StrideSupport = StrideSupport::kStrided, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = 128 / cutlass::sizeof_bits::value, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = 128 / cutlass::sizeof_bits::value -+> struct DefaultConv2dFprop; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// OpClassTensorOp convolutions -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage -+/// pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA, -+ AccessTypeA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB, -+ AccessTypeB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * AlignmentB) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ CacheOpB, -+ MmaPolicy, -+ Stages -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage -+/// pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kFixedChannels, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorFixedChannels< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA, -+ AccessTypeA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorFixedChannels< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB, -+ AccessTypeB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * AlignmentB) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ CacheOpB, -+ MmaPolicy, -+ Stages -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and two stage -+/// pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kFixedChannels, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorFixedChannels< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA, -+ AccessTypeA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorFixedChannels< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB, -+ AccessTypeB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage -+/// pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kFewChannels, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorFewChannels< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA, -+ AccessTypeA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorFewChannels< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB, -+ AccessTypeB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * AlignmentB) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ CacheOpB, -+ MmaPolicy, -+ Stages -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage -+/// pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kFewChannels, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorFewChannels< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA, -+ AccessTypeA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorFewChannels< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB, -+ AccessTypeB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * AlignmentB) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage -+/// pipeline with interleaved layout. -+template < -+ typename ElementA, -+ typename ElementB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB, -+ int InterleavedK -+> -+struct DefaultConv2dFprop < -+ ElementA, -+ layout::TensorNCxHWx, -+ ElementB, -+ layout::TensorCxRSKx, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ Stages, MathOperatorTag, true>; -+ -+ // Define iterators over tiles from the A operand -+ // Note GEMM shared memory threadmap is used here because conv global memory -+ // layout needs to be mapped to fprop which is similar to the crosswise -+ // layout which is used by the interleaved GEMM shared memory threadmap. -+ // The Interleaved GEMM global memory layout is similar to the congruous -+ // layout. -+ using ThreadMapA = typename MmaCore::SmemThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, layout::TensorNCxHWx, -+ ThreadMapA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ // Note GEMM shared memory threadmap is used here because conv global memory -+ // layout needs to be mapped to fprop which is similar to the crosswise -+ // layout which is used by the interleaved GEMM shared memory threadmap. -+ // The Interleaved GEMM global memory layout is similar to the congruous -+ // layout. -+ using ThreadMapB = typename MmaCore::SmemThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, layout::TensorCxRSKx, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Global, -+ MmaPolicy, -+ Stages -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount, -+ InterleavedK -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm -+/// and 2 stage pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA, -+ AccessTypeA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB, -+ AccessTypeB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename detail::DefaultConvEpilogue< -+ ArchTag, -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and 2 stage -+/// pipeline with interleaved layout. -+template < -+ typename ElementA, -+ typename ElementB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB, -+ int InterleavedK -+> -+struct DefaultConv2dFprop < -+ ElementA, -+ layout::TensorNCxHWx, -+ ElementB, -+ layout::TensorCxRSKx, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ 2, MathOperatorTag, true>; -+ -+ // Define iterators over tiles from the A operand -+ // Note GEMM shared memory threadmap is used here because conv global memory -+ // layout needs to be mapped to fprop which is similar to the crosswise -+ // layout which is used by the interleaved GEMM shared memory threadmap. -+ // The Interleaved GEMM global memory layout is similar to the congruous -+ // layout. -+ using ThreadMapA = typename MmaCore::SmemThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, layout::TensorNCxHWx, -+ ThreadMapA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ // Note GEMM shared memory threadmap is used here because conv global memory -+ // layout needs to be mapped to fprop which is similar to the crosswise -+ // layout which is used by the interleaved GEMM shared memory threadmap. -+ // The Interleaved GEMM global memory layout is similar to the congruous -+ // layout. -+ using ThreadMapB = typename MmaCore::SmemThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, layout::TensorCxRSKx, -+ ThreadMapB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount, -+ InterleavedK -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Optimzed IteratorAlgorithm and -+/// multistage pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag -+ >; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ LayoutA, -+ ThreadMapA, -+ AccessTypeA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ LayoutB, -+ ThreadMapB, -+ AccessTypeB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * AlignmentB) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ CacheOpB, -+ MmaPolicy, -+ Stages -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Optimzed IteratorAlgorithm and -+// multistage pipeline with interleaved layout. -+template < -+ typename ElementA, -+ typename ElementB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB, -+ int InterleavedK -+> -+struct DefaultConv2dFprop < -+ ElementA, -+ layout::TensorNCxHWx, -+ ElementB, -+ layout::TensorCxRSKx, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ Stages, MathOperatorTag, true -+ >; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::SmemThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ layout::TensorNCxHWx, -+ ThreadMapA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::SmemThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ layout::TensorCxRSKx, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Global, -+ MmaPolicy, -+ Stages -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount, -+ InterleavedK -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm -+/// and 2 stage pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ LayoutA, -+ ThreadMapA, -+ AccessTypeA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ LayoutB, -+ ThreadMapB, -+ AccessTypeB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename detail::DefaultConvEpilogue< -+ ArchTag, -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and 2 stage -+/// pipeline with interleaved layout. -+template < -+ typename ElementA, -+ typename ElementB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB, -+ int InterleavedK -+> -+struct DefaultConv2dFprop < -+ ElementA, -+ layout::TensorNCxHWx, -+ ElementB, -+ layout::TensorCxRSKx, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajorInterleaved, -+ ElementB, layout::RowMajorInterleaved, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ 2, MathOperatorTag, true>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::SmemThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, layout::TensorNCxHWx, -+ ThreadMapA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::SmemThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, layout::TensorCxRSKx, -+ ThreadMapB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount, -+ InterleavedK -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// OpClassSimt convolutions -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm, -+/// multi-stage pipeline, and FFMA-based mainloop for SM80 -+ -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Always, -+ MmaPolicy, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ WarpMmaSimtOp, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm, -+/// multi-stage pipeline, and FFMA-based mainloop for SM80 -+ -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ LayoutA, -+ ThreadMapA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ LayoutB, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Always, -+ MmaPolicy, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ WarpMmaSimtOp, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm, -+/// 2 stage pipeline, and FFMA-based mainloop for SM50 -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ WarpMmaSimtOp, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm, -+/// 2 stage pipeline, and FFMA-based mainloop for SM50 -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ LayoutA, -+ ThreadMapA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ LayoutB, -+ ThreadMapB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ WarpMmaSimtOp, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h -new file mode 100644 -index 0000000..da48878 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h -@@ -0,0 +1,357 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief -+ Default kernel-level fused activation's scale+bias+relu and implicit GEMM convolution -+ definitions that combine threadblock-scoped matrix multiply-add with the -+ appropriate threadblock-scoped epilogue. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/conv/kernel/default_conv2d.h" -+ -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h" -+#include "cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h" -+#include "cutlass/gemm/warp/scale_bias_tile_iterator.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines a kernel for fused batch norm and Conv2dFprop -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementScaleBias, -+ typename LayoutScaleBias, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, -+ conv::StrideSupport StrideSupport = StrideSupport::kStrided -+> struct DefaultConv2dFpropFusion; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// OpClassTensorOp convolutions -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage -+/// pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementScaleBias, -+ typename LayoutScaleBias, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag -+> -+struct DefaultConv2dFpropFusion < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementScaleBias, -+ LayoutScaleBias, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using IteratorScaleBias = -+ cutlass::conv::threadblock::PredicatedScaleBiasVectorAccessIterator< -+ cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, -+ LayoutScaleBias>; -+ -+ using SmemIteratorScaleBias = -+ cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator< -+ cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, -+ LayoutScaleBias>; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ static int const kThreadCount = 32; -+ -+ // Warp-level iterators to load scale and bias vectors -+ using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator< -+ MatrixShape, ElementScaleBias, -+ LayoutScaleBias, MatrixShape, -+ typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount, -+ MmaCore::WarpCount::kK>; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmFpropFusionMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Global, -+ IteratorScaleBias, -+ SmemIteratorScaleBias, -+ arch::CacheOperation::Always, -+ MmaPolicy, -+ WarpIteratorScaleBias, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ 1, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Optimzed IteratorAlgorithm and -+/// multistage pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementScaleBias, -+ typename LayoutScaleBias, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag -+> -+struct DefaultConv2dFpropFusion < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementScaleBias, -+ LayoutScaleBias, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag -+ >; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ LayoutA, -+ ThreadMapA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ LayoutB, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using IteratorScaleBias = -+ cutlass::conv::threadblock::PredicatedScaleBiasVectorAccessIterator< -+ cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, -+ LayoutScaleBias>; -+ -+ using SmemIteratorScaleBias = -+ cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator< -+ cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, -+ LayoutScaleBias>; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ static int const kThreadCount = 32; -+ -+ // Warp-level iterators to load scale and bias vectors -+ using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator< -+ MatrixShape, ElementScaleBias, -+ LayoutScaleBias, MatrixShape, -+ typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount, -+ MmaCore::WarpCount::kK>; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmFpropFusionMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Global, -+ IteratorScaleBias, -+ SmemIteratorScaleBias, -+ arch::CacheOperation::Always, -+ MmaPolicy, -+ WarpIteratorScaleBias, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ 1, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h -new file mode 100644 -index 0000000..d744ae8 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h -@@ -0,0 +1,130 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Defines a GEMM with Reduction based on an existing UniversalGemm kernel. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h" -+#include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, -+ conv::StrideSupport StrideSupport = StrideSupport::kStrided, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = 128 / cutlass::sizeof_bits::value, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = 128 / cutlass::sizeof_bits::value -+> -+struct DefaultConv2dFpropWithBroadcast { -+ -+ using ImplicitGemmBase = typename DefaultConv2dFprop< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+ >::Kernel; -+ -+ // Replace epilogue -+ using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastTensorOp< -+ ArchTag, -+ typename ImplicitGemmBase::Epilogue::Shape, -+ typename ImplicitGemmBase::Epilogue::WarpMmaOperator, -+ ImplicitGemmBase::Epilogue::kPartitionsK, -+ ElementC, -+ typename EpilogueOutputOp::ElementT, -+ ElementC, -+ EpilogueOutputOp, -+ ImplicitGemmBase::Epilogue::kElementsPerAccess -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< -+ typename ImplicitGemmBase::Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h -new file mode 100644 -index 0000000..00b8c90 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h -@@ -0,0 +1,130 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Defines a GEMM with Reduction based on an existing UniversalGemm kernel. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_with_reduction.h" -+#include "cutlass/epilogue/threadblock/epilogue_with_reduction.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename EpilogueReductionOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, -+ conv::StrideSupport StrideSupport = StrideSupport::kStrided, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = 128 / cutlass::sizeof_bits::value, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = 128 / cutlass::sizeof_bits::value -+> -+struct DefaultConv2dFpropWithReduction { -+ -+ using ImplicitGemmBase = typename DefaultConv2dFprop< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+ >::Kernel; -+ -+ // Replace epilogue -+ using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithReductionTensorOp< -+ ArchTag, -+ typename ImplicitGemmBase::Epilogue::Shape, -+ typename ImplicitGemmBase::Epilogue::WarpMmaOperator, -+ ImplicitGemmBase::Epilogue::kPartitionsK, -+ ElementC, -+ EpilogueOutputOp, -+ EpilogueReductionOp, -+ ImplicitGemmBase::Epilogue::kElementsPerAccess -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< -+ typename ImplicitGemmBase::Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_group_fprop.h b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_group_fprop.h -new file mode 100644 -index 0000000..cdd89e0 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_group_fprop.h -@@ -0,0 +1,490 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped -+ matrix multiply-add with the appropriate threadblock-scoped epilogue. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/conv/kernel/default_conv2d.h" -+ -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h" -+ -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines a kernel for Conv2dGroupFpro -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::GroupMode GroupMode, -+ conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, -+ conv::StrideSupport StrideSupport = StrideSupport::kStrided, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = 128 / cutlass::sizeof_bits::value, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = 128 / cutlass::sizeof_bits::value -+> struct DefaultConv2dGroupFprop; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// OpClassTensorOp convolutions -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dGroupFprop specialization for Analytic IteratorAlgorithm and multistage -+/// pipeline that supports all GroupMode. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::GroupMode GroupMode, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dGroupFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ GroupMode, -+ IteratorAlgorithm::kAnalytic, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ static_assert(std::is_same::value, -+ "Current group conv only support NHWC layout"); -+ static_assert(std::is_same::value, -+ "Current group conv only support NHWC layout"); -+ static_assert(std::is_same::value, -+ "Current group conv only support NHWC layout"); -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA, -+ AccessTypeA, -+ GroupMode -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB, -+ AccessTypeB, -+ GroupMode -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * AlignmentB) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ CacheOpB, -+ MmaPolicy, -+ Stages -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop, -+ Conv2dProblemSize, -+ GroupMode -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dGroupFprop specialization for Optimized IteratorAlgorithm and multistage -+/// pipeline that supports GroupMode::kSingleGroup. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dGroupFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ GroupMode::kSingleGroup, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ static_assert(std::is_same::value, -+ "Current group conv only support NHWC layout"); -+ static_assert(std::is_same::value, -+ "Current group conv only support NHWC layout"); -+ static_assert(std::is_same::value, -+ "Current group conv only support NHWC layout"); -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA, -+ AccessTypeA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB, -+ AccessTypeB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * AlignmentB) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ CacheOpB, -+ MmaPolicy, -+ Stages -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop, -+ Conv2dProblemSize, -+ GroupMode::kSingleGroup -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dGroupFprop specialization for Optimized IteratorAlgorithm and -+/// 2 stage pipeline that supports GroupMode::kSingleGroup. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dGroupFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ GroupMode::kSingleGroup, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ static_assert(std::is_same::value, -+ "Current group conv only support NHWC layout"); -+ static_assert(std::is_same::value, -+ "Current group conv only support NHWC layout"); -+ static_assert(std::is_same::value, -+ "Current group conv only support NHWC layout"); -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ LayoutA, -+ ThreadMapA, -+ AccessTypeA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ LayoutB, -+ ThreadMapB, -+ AccessTypeB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename detail::DefaultConvEpilogue< -+ ArchTag, -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop, -+ Conv2dProblemSize, -+ GroupMode::kSingleGroup -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_wgrad.h b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_wgrad.h -new file mode 100644 -index 0000000..099bb6c ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_wgrad.h -@@ -0,0 +1,1011 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped -+ matrix multiply-add with the appropriate threadblock-scoped epilogue. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/conv/kernel/default_conv2d.h" -+ -+#include "cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv2d_tile_iterator.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dWgrad -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, -+ conv::StrideSupport StrideSupport = StrideSupport::kStrided, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = 128 / cutlass::sizeof_bits::value, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = 128 / cutlass::sizeof_bits::value -+> struct DefaultConv2dWgrad; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// OpClassTensorOp convolutions -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dWgrad specialization for Analytic IteratorAlgorithm and multistage -+// pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dWgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ AccessTypeA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ AccessTypeB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Always, -+ MmaPolicy, -+ Stages -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kWgrad -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dWgrad specialization for Analytic IteratorAlgorithm and two -+// pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dWgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ AccessTypeA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ AccessTypeB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename detail::DefaultConvEpilogue< -+ ArchTag, -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kWgrad -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dWgrad specialization for Optimized IteratorAlgorithm and multistage -+// pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dWgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ AccessTypeA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ AccessTypeB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Always, -+ MmaPolicy, -+ Stages -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kWgrad -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dWgrad specialization for Optimized IteratorAlgorithm and two -+// pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultConv2dWgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::AlignedArray; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ AccessTypeA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB, -+ AccessTypeB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ // Define the epilogue -+ using Epilogue = typename detail::DefaultConvEpilogue< -+ ArchTag, -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ EpilogueOutputOp -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kWgrad -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// OpClassSimt convolutions -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines a kernel for Conv2dWgrad specialization for Analytic IteratorAlgorithm, -+/// multi-stage pipeline, and FFMA-based mainloop for SM80 -+ -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AccessTypeA, -+ int AccessTypeB -+> -+struct DefaultConv2dWgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ StrideSupport, -+ AccessTypeA, -+ AccessTypeB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Always, -+ MmaPolicy, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ WarpMmaSimtOp, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kWgrad -+ >; -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dWgrad specialization for Optimized IteratorAlgorithm, -+/// multi-stage pipeline, and FFMA-based mainloop for SM80 -+ -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AccessTypeA, -+ int AccessTypeB -+> -+struct DefaultConv2dWgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport, -+ AccessTypeA, -+ AccessTypeB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Always, -+ MmaPolicy, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ WarpMmaSimtOp, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kWgrad -+ >; -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dWgrad specialization for Analytic IteratorAlgorithm, -+/// 2 stage pipeline, and FFMA-based mainloop for SM50 -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AccessTypeA, -+ int AccessTypeB -+> -+struct DefaultConv2dWgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ StrideSupport, -+ AccessTypeA, -+ AccessTypeB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ WarpMmaSimtOp, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kWgrad -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dWgrad specialization for Optimized IteratorAlgorithm, -+/// 2 stage pipeline, and FFMA-based mainloop for SM50 -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AccessTypeA, -+ int AccessTypeB -+> -+struct DefaultConv2dWgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport, -+ AccessTypeA, -+ AccessTypeB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ WarpMmaSimtOp, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kWgrad -+ >; -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_wgrad_fusion.h b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_wgrad_fusion.h -new file mode 100644 -index 0000000..62bf177 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv2d_wgrad_fusion.h -@@ -0,0 +1,325 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief -+ Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped -+ matrix multiply-add with the appropriate threadblock-scoped epilogue. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/conv/kernel/default_conv2d.h" -+ -+#include "cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv2d_tile_iterator.h" -+#include "cutlass/conv/threadblock/predicated_scale_bias_vector_iterator.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dWgrad -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementScaleBias, -+ typename LayoutScaleBias, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, -+ conv::StrideSupport StrideSupport = StrideSupport::kStrided -+> struct DefaultConv2dWgradFusion; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// OpClassTensorOp convolutions -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dWgrad specialization for Analytic IteratorAlgorithm and multistage -+// pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementScaleBias, -+ typename LayoutScaleBias, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag -+> -+struct DefaultConv2dWgradFusion < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementScaleBias, -+ LayoutScaleBias, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using IteratorScaleBias = -+ cutlass::conv::threadblock::PredicatedScaleBiasVectorIterator< -+ cutlass::MatrixShape<1, WarpShape::kN>, -+ ElementScaleBias, -+ LayoutScaleBias>; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmWgradFusionMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Always, -+ IteratorScaleBias, -+ MmaPolicy, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ 1, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kWgrad -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dWgrad specialization for Optimized IteratorAlgorithm and multistage -+// pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementScaleBias, -+ typename LayoutScaleBias, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag -+> -+struct DefaultConv2dWgradFusion < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementScaleBias, -+ LayoutScaleBias, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using IteratorScaleBias = -+ cutlass::conv::threadblock::PredicatedScaleBiasVectorIterator< -+ cutlass::MatrixShape<1, WarpShape::kN>, -+ ElementScaleBias, -+ LayoutScaleBias>; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmWgradFusionMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Always, -+ IteratorScaleBias, -+ MmaPolicy, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ 1, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kWgrad -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv3d_dgrad.h b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv3d_dgrad.h -new file mode 100644 -index 0000000..01e895c ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv3d_dgrad.h -@@ -0,0 +1,303 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped -+ matrix multiply-add with the appropriate threadblock-scoped epilogue. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/conv/kernel/default_conv2d.h" -+ -+#include "cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h" -+ -+#include "cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_tile_iterator.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines a kernel for Conv3dDgrad -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, -+ conv::StrideSupport StrideSupport = StrideSupport::kStrided -+> struct DefaultConv3dDgrad; -+ -+/// Defines a kernel for Conv3dDgrad specialization for Analytic IteratorAlgorithm Dgrad Strided -+// and multistage pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag -+> -+struct DefaultConv3dDgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic, -+ StrideSupport::kStrided -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ StrideSupport::kStrided -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv3dDgradFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Global, -+ MmaPolicy, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ 1, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kDgrad, -+ Conv3dProblemSize -+ >; -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv3dDgrad specialization for Optimized IteratorAlgorithm Dgrad Strided -+// and multistage pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag -+> -+struct DefaultConv3dDgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport::kUnity -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA, -+ StrideSupport::kUnity -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ -+ using IteratorB = -+ cutlass::conv::threadblock::Conv3dDgradFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Global, -+ MmaPolicy, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ 1, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kDgrad, -+ Conv3dProblemSize -+ >; -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop.h b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop.h -new file mode 100644 -index 0000000..9c8f8cf ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop.h -@@ -0,0 +1,515 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped -+ matrix multiply-add with the appropriate threadblock-scoped epilogue. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/conv/kernel/default_conv2d.h" -+ -+#include "cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h" -+ -+ -+#include "cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines a kernel for Conv2dFprop -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, -+ conv::StrideSupport StrideSupport = StrideSupport::kStrided -+> struct DefaultConv3dFprop; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv3dFprop specialization for Analytic Iterator Algorithm -+/// and 2 stage pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag -+> -+struct DefaultConv3dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename detail::DefaultConvEpilogue< -+ ArchTag, -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ 1, -+ EpilogueOutputOp -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop, -+ Conv3dProblemSize -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage -+// pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag -+> -+struct DefaultConv3dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Global, -+ MmaPolicy, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ 1, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop, -+ Conv3dProblemSize -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv3dFprop specialization for Optimized Iterator Algorithm -+/// and 2 stage pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag -+> -+struct DefaultConv3dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ LayoutA, -+ ThreadMapA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ LayoutB, -+ ThreadMapB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename detail::DefaultConvEpilogue< -+ ArchTag, -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ 1, -+ EpilogueOutputOp -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop, -+ Conv3dProblemSize -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv3dFprop specialization for Optimized IteratorAlgorithm and multistage -+// pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag -+> -+struct DefaultConv3dFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ LayoutA, -+ ThreadMapA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ -+ using IteratorB = -+ cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ LayoutB, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Global, -+ MmaPolicy, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ 1, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop, -+ Conv3dProblemSize -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop_fusion.h b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop_fusion.h -new file mode 100644 -index 0000000..66cbbcd ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop_fusion.h -@@ -0,0 +1,360 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level fused activation's scale+bias+relu and implicit GEMM convolution -+ definitions that combine threadblock-scoped matrix multiply-add with the -+ appropriate threadblock-scoped epilogue. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/conv/kernel/default_conv2d.h" -+ -+#include "cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h" -+#include "cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h" -+#include "cutlass/gemm/warp/scale_bias_tile_iterator.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines a kernel for fused batch norm and Conv3dFprop -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementScaleBias, -+ typename LayoutScaleBias, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, -+ conv::StrideSupport StrideSupport = StrideSupport::kStrided -+> struct DefaultConv3dFpropFusion; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// OpClassTensorOp convolutions -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv3dFprop specialzation for Analytic IteratorAlgorithm and multistage -+/// pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementScaleBias, -+ typename LayoutScaleBias, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag -+> -+struct DefaultConv3dFpropFusion < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementScaleBias, -+ LayoutScaleBias, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using IteratorScaleBias = -+ cutlass::conv::threadblock::PredicatedScaleBiasVectorAccessIterator< -+ cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, -+ LayoutScaleBias>; -+ -+ using SmemIteratorScaleBias = -+ cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator< -+ cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, -+ LayoutScaleBias>; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ static int const kThreadCount = 32; -+ -+ // Warp-level iterators to load scale and bias vectors -+ using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator< -+ MatrixShape, ElementScaleBias, -+ LayoutScaleBias, MatrixShape, -+ typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount, -+ MmaCore::WarpCount::kK>; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmFpropFusionMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Global, -+ IteratorScaleBias, -+ SmemIteratorScaleBias, -+ arch::CacheOperation::Always, -+ MmaPolicy, -+ WarpIteratorScaleBias, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ 1, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop, -+ Conv3dProblemSize -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv3dFprop specialzation for Optimzed IteratorAlgorithm and -+/// multistage pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementScaleBias, -+ typename LayoutScaleBias, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag -+> -+struct DefaultConv3dFpropFusion < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementScaleBias, -+ LayoutScaleBias, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, -+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, MathOperatorTag -+ >; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ LayoutA, -+ ThreadMapA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ LayoutB, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using IteratorScaleBias = -+ cutlass::conv::threadblock::PredicatedScaleBiasVectorAccessIterator< -+ cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, -+ LayoutScaleBias>; -+ -+ using SmemIteratorScaleBias = -+ cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator< -+ cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, -+ LayoutScaleBias>; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ static int const kThreadCount = 32; -+ -+ // Warp-level iterators to load scale and bias vectors -+ using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator< -+ MatrixShape, ElementScaleBias, -+ LayoutScaleBias, MatrixShape, -+ typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount, -+ MmaCore::WarpCount::kK>; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmFpropFusionMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Global, -+ IteratorScaleBias, -+ SmemIteratorScaleBias, -+ arch::CacheOperation::Always, -+ MmaPolicy, -+ WarpIteratorScaleBias, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ 1, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop, -+ Conv3dProblemSize -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv3d_wgrad.h b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv3d_wgrad.h -new file mode 100644 -index 0000000..3807911 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/kernel/default_conv3d_wgrad.h -@@ -0,0 +1,509 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped -+ matrix multiply-add with the appropriate threadblock-scoped epilogue. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/conv/kernel/default_conv2d.h" -+ -+#include "cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h" -+#include "cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv2dWgrad -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, -+ conv::StrideSupport StrideSupport = StrideSupport::kStrided -+> struct DefaultConv3dWgrad; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv3dWgrad specialization for Analytic IteratorAlgorithm and multistage -+// pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag -+> -+struct DefaultConv3dWgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Always, -+ MmaPolicy, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ 1, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kWgrad, -+ Conv3dProblemSize -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines a kernel for Conv3dWgrad specialization for Analytic IteratorAlgorithm and two -+// pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag -+> -+struct DefaultConv3dWgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kAnalytic -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename detail::DefaultConvEpilogue< -+ ArchTag, -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ 1, -+ EpilogueOutputOp -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kWgrad, -+ Conv3dProblemSize -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a kernel for Conv3dWgrad specialization for Optimized IteratorAlgorithm and multistage -+// pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag -+> -+struct DefaultConv3dWgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, -+ Stages, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmMultistage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ arch::CacheOperation::Always, -+ IteratorB, -+ SmemIteratorB, -+ arch::CacheOperation::Always, -+ MmaPolicy, -+ Stages -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ 1, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kWgrad, -+ Conv3dProblemSize -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines a kernel for Conv3dWgrad specialization for Optimized IteratorAlgorithm and two -+// pipeline. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag -+> -+struct DefaultConv3dWgrad < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, -+ ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, -+ 2, MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementA, -+ ThreadMapA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, -+ ThreadMapB -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::ImplicitGemmPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename detail::DefaultConvEpilogue< -+ ArchTag, -+ ThreadblockShape, -+ WarpMmaTensorOp, -+ 1, -+ EpilogueOutputOp -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kWgrad, -+ Conv3dProblemSize -+ >; -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/kernel/default_depthwise_fprop.h b/3rdparty/cutlass/include/cutlass/conv/kernel/default_depthwise_fprop.h -new file mode 100644 -index 0000000..df57e30 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/kernel/default_depthwise_fprop.h -@@ -0,0 +1,588 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level Depthwise implicit GEMM convolution definitions combine threadblock-scoped -+ matrix multiply-add with the appropriate threadblock-scoped epilogue. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/conv/kernel/default_conv2d.h" -+#include "cutlass/conv/kernel/direct_convolution.h" -+ -+#include "cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h" -+ -+#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" -+#include "cutlass/conv/threadblock/depthwise_fprop_pipelined.h" -+ -+// Direct Conv Related Header files -+#include "cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_optimized.h" -+#include "cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_fixed_stride_dilation.h" -+ -+#include "cutlass/conv/threadblock/depthwise_fprop_filter_tile_access_iterator_direct_conv_optimized.h" -+#include "cutlass/conv/threadblock/depthwise_fprop_direct_conv_multistage.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines a kernel for DepthwiseFprop -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, -+ conv::StrideSupport StrideSupport = StrideSupport::kStrided, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = 128 / cutlass::sizeof_bits::value, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = cutlass::sizeof_bits::value / cutlass::sizeof_bits::value -+> struct DefaultDepthwiseFprop; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines a kernel for DepthwiseFprop with direct convolution algorithm -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename OperatorClass, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename ThreadBlockOutputShape, -+ typename FilterShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, -+ conv::StrideSupport StrideSupport = StrideSupport::kStrided, -+ // MatrixShape -+ typename StrideShape = cutlass::MatrixShape<-1, -1>, -+ // MatrixShape< Height, Width> -+ typename DilationShape = cutlass::MatrixShape<-1, -1>, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = 128 / cutlass::sizeof_bits::value, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = 128 / cutlass::sizeof_bits::value -+> struct DefaultDepthwiseDirect2dConvFprop; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// OpClassSimt convolutions -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines a kernel for Depthwise specialization for Analytic IteratorAlgorithm -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultDepthwiseFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ MathOperatorTag, // cutlass::arch::OpMultiplyAdd -+ IteratorAlgorithm::kAnalytic, -+ StrideSupport, -+ AlignmentA, -+ AlignmentB -+> { -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::conv::threadblock::DepthwiseMmaCoreWithLaneAccessSize< -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ layout::RowMajor, -+ ElementB, -+ layout::ColumnMajor, -+ ElementAccumulator, -+ layout::RowMajor, -+ arch::OpClassSimt, -+ 128, -+ sizeof_bits::value, -+ 2, -+ MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, -+ ThreadMapA -+ > -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::TileIterator< -+ cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB, -+ AccessTypeB, -+ cutlass::conv::GroupMode::kDepthwise -+ > -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ -+ // Define the Mma -+ using Mma = threadblock::DepthwiseFpropPipelined< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ IteratorB, -+ SmemIteratorB, -+ ElementC, -+ LayoutC, -+ MmaPolicy -+ >; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ WarpMmaSimtOp, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop, -+ Conv2dProblemSize, -+ cutlass::conv::GroupMode::kDepthwise -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines a kernel for Depthwise specialization for direct 2d conv implementation, -+/// multiple stage pipeline, and SIMT-based mainloop -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename ThreadBlockOutputShape, -+ typename FilterShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ typename StrideShape, -+ typename DilationShape, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultDepthwiseDirect2dConvFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ ThreadBlockOutputShape, -+ FilterShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kOptimized, -+ StrideSupport, -+ StrideShape, -+ DilationShape, -+ AlignmentA, -+ AlignmentB -+> { -+ // One warp handles the entrie groups per cta. -+ static_assert(ThreadblockShape::kN == WarpShape::kN, -+ "ThreadblockShape::kN should be same as WarpShape::kN "); -+ static_assert(ThreadblockShape::kK == FilterShape::kCount && WarpShape::kK == FilterShape::kCount, -+ "ThreadblockShape::kK and WarpShape::kK should be same as filter size"); -+ static_assert(ThreadblockShape::kM % WarpShape::kM == 0, -+ "ThreadblockShape::kM must be divisible by WarpShape shape::kM"); -+ static_assert(ThreadBlockOutputShape::kN, "ThreadBlockOutputShape::kN should be 1"); -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::conv::threadblock::DepthwiseDirectConvMmaCoreWithLaneAccessSize< -+ ThreadblockShape, -+ ThreadBlockOutputShape, -+ FilterShape, -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ layout::RowMajor, -+ ElementB, -+ layout::ColumnMajor, -+ ElementAccumulator, -+ layout::RowMajor, -+ arch::OpClassSimt, -+ 128, -+ 128, -+ Stages, -+ MathOperatorTag>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized< -+ cutlass::MatrixShape, // < outputShape:KMNK, groups per cta> -+ ThreadBlockOutputShape, -+ ElementA, LayoutA, -+ ThreadMapA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ using ThreadOutputShape = typename MmaCore::ThreadOutputShape; -+ static cutlass::arch::CacheOperation::Kind const CacheOpA = -+ ((sizeof_bits::value * AlignmentA) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * AlignmentB) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultDirectConvEpilogueSimt< -+ ThreadblockShape, // < outputShape:KMNK, groups per cta> -+ WarpMmaSimtOp, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount, -+ ThreadOutputShape, -+ ThreadBlockOutputShape -+ >::Epilogue; -+ -+ // Define the Mma -+ using Mma = threadblock::DepthwiseFpropDirectConvMultipleStage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ CacheOpA, -+ IteratorB, -+ SmemIteratorB, -+ CacheOpB, -+ MmaPolicy, -+ Stages, -+ Epilogue -+ >; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::DirectConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop, -+ Conv2dProblemSize, -+ cutlass::conv::GroupMode::kDepthwise, -+ ThreadBlockOutputShape -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Defines a kernel for Depthwise specialization for direct 2d conv implementation, -+/// multiple stage pipeline, and SIMT-based mainloop -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementAccumulator, -+ typename ArchTag, -+ typename ThreadblockShape, -+ typename ThreadBlockOutputShape, -+ typename FilterShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename EpilogueOutputOp, -+ typename ThreadblockSwizzle, -+ int Stages, -+ typename MathOperatorTag, -+ conv::StrideSupport StrideSupport, -+ typename StrideShape, -+ typename DilationShape, -+ int AlignmentA, -+ int AlignmentB -+> -+struct DefaultDepthwiseDirect2dConvFprop < -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ ThreadBlockOutputShape, -+ FilterShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kFixedStrideDilation, -+ StrideSupport, -+ StrideShape, -+ DilationShape, -+ AlignmentA, -+ AlignmentB, -+> { -+ -+ -+ -+ // One warp handles the entrie groups per cta. -+ static_assert(ThreadblockShape::kN == WarpShape::kN, -+ "ThreadblockShape::kN should be same as WarpShape::kN "); -+ static_assert(ThreadblockShape::kK == FilterShape::kCount && WarpShape::kK == FilterShape::kCount, -+ "ThreadblockShape::kK and WarpShape::kK should be same as filter size"); -+ static_assert(ThreadblockShape::kM % WarpShape::kM == 0, -+ "ThreadblockShape::kM must be divisible by WarpShape shape::kM"); -+ static_assert(ThreadBlockOutputShape::kN, "ThreadBlockOutputShape::kN should be 1"); -+ -+ static_assert(StrideShape::kRow >= 0 && StrideShape::kColumn >= 0, "Stride should be fixed"); -+ static_assert(DilationShape::kRow >= 0 && DilationShape::kColumn >= 0, "Stride should be fixed"); -+ -+ // Activations loaded by threadblock -+ static int const ActivationShapeH = (ThreadBlockOutputShape::kH - 1) * StrideShape::kRow + -+ (FilterShape::kRow - 1) * DilationShape::kRow + 1; -+ -+ static int const ActivationShapeW = (ThreadBlockOutputShape::kW - 1) * StrideShape::kColumn + -+ (FilterShape::kColumn - 1) * DilationShape::kColumn + 1; -+ -+ using ActivationShape = -+ cutlass::conv::TensorNHWCShape<1, ActivationShapeH, ActivationShapeW, ThreadblockShape::kN >; -+ -+ // Define the core components from GEMM -+ using MmaCore = typename cutlass::conv::threadblock::DepthwiseDirectConvMmaCoreWithLaneAccessSize< -+ ThreadblockShape, -+ ThreadBlockOutputShape, -+ FilterShape, -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ layout::RowMajor, -+ ElementB, -+ layout::ColumnMajor, -+ ElementAccumulator, -+ layout::RowMajor, -+ arch::OpClassSimt, -+ 128, -+ 128, -+ Stages, -+ MathOperatorTag, -+ IteratorAlgorithm::kFixedStrideDilation, -+ StrideShape, -+ DilationShape, -+ ActivationShape>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using IteratorA = -+ cutlass::conv::threadblock::DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation< -+ cutlass::MatrixShape, // < outputShape:KMNK, groups per cta> -+ ThreadBlockOutputShape, -+ StrideShape, -+ DilationShape, -+ ActivationShape, -+ ElementA, LayoutA, -+ ThreadMapA -+ >; -+ -+ using SmemIteratorA = typename MmaCore::SmemIteratorA; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::AlignedArray; -+ using IteratorB = -+ cutlass::conv::threadblock::DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, -+ ThreadMapB -+ >; -+ -+ using SmemIteratorB = typename MmaCore::SmemIteratorB; -+ -+ // Warp-level GEMM components -+ using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; -+ using MmaPolicy = typename MmaCore::MmaPolicy; -+ using ThreadOutputShape = typename MmaCore::ThreadOutputShape; -+ static cutlass::arch::CacheOperation::Kind const CacheOpA = -+ ((sizeof_bits::value * AlignmentA) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * AlignmentB) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultDirectConvEpilogueSimt< -+ ThreadblockShape, // < outputShape:KMNK, groups per cta> -+ WarpMmaSimtOp, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount, -+ ThreadOutputShape, -+ ThreadBlockOutputShape -+ >::Epilogue; -+ -+ // Define the Mma -+ using Mma = threadblock::DepthwiseFpropDirectConvMultipleStage< -+ ThreadblockShape, -+ IteratorA, -+ SmemIteratorA, -+ CacheOpA, -+ IteratorB, -+ SmemIteratorB, -+ CacheOpB, -+ MmaPolicy, -+ Stages, -+ Epilogue, -+ IteratorAlgorithm::kFixedStrideDilation -+ >; -+ -+ // Define the kernel -+ using Kernel = cutlass::conv::kernel::DirectConvolution< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ conv::Operator::kFprop, -+ Conv2dProblemSize, -+ cutlass::conv::GroupMode::kDepthwise, -+ ThreadBlockOutputShape -+ >; -+}; -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/kernel/direct_convolution.h b/3rdparty/cutlass/include/cutlass/conv/kernel/direct_convolution.h -new file mode 100644 -index 0000000..ef7a920 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/kernel/direct_convolution.h -@@ -0,0 +1,505 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a multi-staged Depthwise Convolution kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/semaphore.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+#include "cutlass/epilogue/threadblock/output_iterator_parameter.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Parameters structure -+template > ///! OutputShape per ThreadBlock -+struct DirectConvolutionParams { -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using ThreadBlockOutputShape = ThreadBlockOutputShape_; -+ static Operator const kConvolutionalOperator = ConvOperator; -+ using ConvProblemSize = ConvProblemSize_; -+ using Arguments = Arguments_; -+ using ConvOutputIteratorParameter = ConvOutputIteratorParameter_; -+ -+ using ThreadblockShape = typename Mma::Shape; -+ static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm; -+ static conv::GroupMode const kGroupMode = GroupMode_; -+ static int const kStages = Mma::kStages; -+ -+ ConvProblemSize problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ gemm::GemmCoord implicit_gemm_problem_size; -+ int swizzle_log_tile; -+ int smem_size_; -+ -+ int gemm_k_iterations; -+ int gemm_k_iterations_per_channel; -+ typename Mma::IteratorA::Params iterator_A; -+ typename Mma::IteratorA::Element const *ptr_A; -+ typename Mma::IteratorB::Params iterator_B; -+ typename Mma::IteratorB::Element const *ptr_B; -+ typename Mma::IteratorB::Element *ptr_reordered_B; -+ typename Epilogue::OutputTileIterator::Params iterator_C; -+ typename Epilogue::OutputTileIterator::Element *ptr_C; -+ typename Epilogue::OutputTileIterator::Params iterator_D; -+ typename Epilogue::OutputTileIterator::Element *ptr_D; -+ typename EpilogueOutputOp::Params output_op; -+ int *semaphore; -+ SplitKMode split_k_mode; -+ int split_k_slices; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ DirectConvolutionParams() : swizzle_log_tile(0), gemm_k_iterations(0) {} -+ -+ /// -+ CUTLASS_HOST_DEVICE -+ DirectConvolutionParams(Arguments const &args, int *semaphore = nullptr) -+ : problem_size(args.problem_size), -+ implicit_gemm_problem_size( -+ cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size)), -+ iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())), -+ ptr_A(args.ref_A.data()), -+ iterator_B(Mma::IteratorB::getParams(args.problem_size, args.ref_B.layout())), -+ ptr_B(args.ref_B.data()), -+ ptr_reordered_B(args.ref_reordered_B.data()), -+ iterator_C(ConvOutputIteratorParameter::layout(args.ref_C), args.problem_size), -+ ptr_C(args.ref_C.data()), -+ iterator_D(ConvOutputIteratorParameter::layout(args.ref_D), args.problem_size), -+ ptr_D(args.ref_D.data()), -+ output_op(args.output_op), -+ semaphore(semaphore), -+ split_k_mode(args.split_k_mode), -+ split_k_slices(args.problem_size.split_k_slices) { -+ gemm_k_iterations = -+ depthwise_gemm_k_iterations(kConvolutionalOperator, -+ ThreadblockShape::kK, -+ args.problem_size, -+ kIteratorAlgorithm, -+ kGroupMode, -+ ThreadblockShape::kN); -+ -+ gemm_k_iterations_per_channel = implicit_gemm_k_iterations_per_channel( -+ kConvolutionalOperator, ThreadblockShape::kK, args.problem_size, kIteratorAlgorithm); -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ grid_tiled_shape = threadblock_swizzle.get_tiled_shape( -+ kConvolutionalOperator, -+ problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.problem_size.split_k_slices); -+ -+ swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); -+ -+ // Dynamic SMEM usage because stride and dilation are runtime params. -+ smem_size_ = (iterator_A.activation_size * kStages + iterator_B.filter_size); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ int get_smem_size() { -+ // Dynamic Smem Size -+ return smem_size_; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template -+struct ReorderKernel { -+ using Params = Params_; -+ using ElementB = ElementB_; -+ -+ union SharedStorage {}; -+ -+ static unsigned int const kReorderKernelThreadPerCTA = 128; -+ -+ CUTLASS_HOST_DEVICE -+ ReorderKernel() {} -+ -+ CUTLASS_HOST_DEVICE -+ static dim3 get_grid_shape(Params const ¶ms) { -+ return dim3{static_cast( -+ (params.problem_size.filter_size() + kReorderKernelThreadPerCTA - 1) / -+ kReorderKernelThreadPerCTA), -+ 1, -+ 1}; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static dim3 get_block_shape() { return dim3{kReorderKernelThreadPerCTA, 1, 1}; } -+ -+ CUTLASS_HOST_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ int64_t m = static_cast(params.problem_size.groups); -+ int64_t n = static_cast(params.problem_size.filter_size() / params.problem_size.K); -+ const ElementB *src_with_type = static_cast(params.ptr_B); -+ ElementB *dst_with_type = static_cast(params.ptr_reordered_B); -+ -+ int64_t linear_index = blockIdx.x * kReorderKernelThreadPerCTA + threadIdx.x; -+ int64_t index_m = linear_index / n; -+ int64_t index_n = linear_index % n; -+ int64_t new_linear_index = index_m + index_n * m; -+ -+ if (linear_index < m * n) { -+ dst_with_type[new_linear_index] = src_with_type[linear_index]; -+ } -+ return; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) -+ typename ConvProblemSize_ = Conv2dProblemSize, ///! Convolutional operator on 2D or 3D problem -+ conv::GroupMode GroupMode_ = conv::GroupMode::kNone, ///! Group mode -+ typename ThreadBlockOutputShape_ = cutlass::conv::TensorNHWCShape<1, 1, 1, 1> -+> -+struct DirectConvolution { -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using ThreadBlockOutputShape = ThreadBlockOutputShape_; -+ static Operator const kConvolutionalOperator = ConvOperator; -+ -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using ElementC = typename EpilogueOutputOp::ElementOutput; -+ -+ /// Set output tensor C layout -+ using LayoutC = LayoutA; -+ -+ using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; -+ using ElementCompute = typename EpilogueOutputOp::ElementCompute; -+ -+ using WarpMmaOperator = typename Mma::Policy::Operator; -+ -+ using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; -+ using MathOperator = typename ArchMmaOperator::Operator; -+ -+ using OperatorClass = typename WarpMmaOperator::OperatorClass; -+ using ArchTag = typename WarpMmaOperator::ArchTag; -+ -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename WarpMmaOperator::Shape; -+ using InstructionShape = typename cutlass::gemm::GemmShape<1, 1, 1>; -+ -+ static int const kStages = Mma::kStages; -+ static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm; -+ static StrideSupport const kStrideSupport = Mma::IteratorA::kStrideSupport; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ using TensorRefA = typename Mma::IteratorA::TensorRef; -+ using TensorRefB = typename Mma::IteratorB::TensorRef; -+ using TensorRefC = cutlass::TensorRef; -+ -+ /// Check iterator A and B convolution dimension are the same and -+ // set device::ImplicitGemmConvolution::kConvDim -+ static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim, -+ "Convolution on different different dimensions is not supported"); -+ static int const kConvDim = Mma::IteratorA::kConvDim; -+ -+ /// Conv dimension and problem size structure (Conv2d or Conv3d) -+ using ConvProblemSize = ConvProblemSize_; -+ -+ static conv::GroupMode const kGroupMode = GroupMode_; -+ -+ -+ // -+ // -+ // -+ using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter< -+ LayoutC, -+ typename Epilogue::OutputTileIterator::Layout, -+ TensorRefC, -+ ConvOperator, -+ ConvProblemSize -+ >; -+ -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ ConvProblemSize problem_size; -+ TensorRefA ref_A; -+ TensorRefB ref_B; -+ TensorRefB ref_reordered_B; -+ TensorRefC ref_C; -+ TensorRefC ref_D; -+ typename EpilogueOutputOp::Params output_op; -+ SplitKMode split_k_mode; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments() { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ ConvProblemSize const & problem_size -+ ): -+ problem_size(problem_size) { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ ConvProblemSize const & problem_size, -+ TensorRefA const & ref_A, -+ TensorRefB const & ref_B, -+ TensorRefC const & ref_C, -+ TensorRefC const & ref_D, -+ typename EpilogueOutputOp::Params const & output_op, -+ TensorRefB const & ref_reordered_B = nullptr, -+ SplitKMode const & split_k_mode = SplitKMode::kSerial -+ ): -+ problem_size(problem_size), -+ ref_A(ref_A), -+ ref_B(ref_B), -+ ref_C(ref_C), -+ ref_D(ref_D), -+ output_op(output_op), -+ ref_reordered_B(ref_reordered_B), -+ split_k_mode(split_k_mode) -+ { -+ -+ } -+ -+ }; -+ -+ using Params = -+ typename cutlass::conv::kernel::DirectConvolutionParams; -+ -+ using ReorderKernel = typename cutlass::conv::kernel::ReorderKernel; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ DirectConvolution() { } -+ -+ /// Executes one ImplicitGEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_idx = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if threadblock is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) { -+ -+ return; -+ } -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ int iterator_column_offset = 0; -+ int filter_row_offset = 0; -+ if (kGroupMode != GroupMode::kNone) { -+ if (kGroupMode == GroupMode::kDepthwise) { -+ iterator_column_offset += threadblock_tile_idx.n() * Mma::Shape::kN; -+ } -+ } -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.iterator_A, -+ params.problem_size, -+ params.ptr_A, -+ thread_idx, -+ MatrixCoord( -+ threadblock_tile_idx.m() + threadblock_tile_idx.k(), -+ iterator_column_offset -+ ) -+ ); -+ -+ typename Mma::IteratorB iterator_B( -+ params.iterator_B, -+ params.problem_size, -+ params.ptr_reordered_B, -+ thread_idx, -+ MatrixCoord( -+ filter_row_offset, -+ iterator_column_offset -+ ) -+ ); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ // Compute logical position within grid -+ threadblock_tile_idx = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ -+ MatrixCoord threadblock_offset( -+ threadblock_tile_idx.m() + threadblock_tile_idx.k(), -+ threadblock_tile_idx.n() * Mma::Shape::kN -+ ); -+ -+ // Tile iterator writing to destination tensor -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.iterator_D, -+ params.ptr_D, -+ ConvOutputIteratorParameter::extent(params.problem_size), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator reading from source accumulator tensor -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.iterator_C, -+ params.ptr_C, -+ ConvOutputIteratorParameter::extent(params.problem_size), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ -+ // Construct the epilogue -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ -+ // Compute threadblock-scoped matrix multiply-add -+ // Epilogue is fused in the mainloop -+ mma(params.gemm_k_iterations, -+ accumulators, -+ iterator_A, -+ params.iterator_A, -+ iterator_B, -+ params.iterator_B, -+ accumulators, -+ epilogue, -+ output_op, -+ iterator_D, -+ iterator_C, -+ params.split_k_slices); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h b/3rdparty/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h -new file mode 100644 -index 0000000..11ac967 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h -@@ -0,0 +1,456 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined Implicit GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/semaphore.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+#include "cutlass/epilogue/threadblock/output_iterator_parameter.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) -+ typename ConvProblemSize_ = Conv2dProblemSize, ///! Convolutional operator on 2D or 3D problem -+ conv::GroupMode GroupMode_ = conv::GroupMode::kNone ///! Group mode -+> -+struct ImplicitGemmConvolution { -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static Operator const kConvolutionalOperator = ConvOperator; -+ -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using ElementC = typename EpilogueOutputOp::ElementOutput; -+ -+ /// Set output tensor C layout -+ using LayoutC = LayoutA; -+ -+ using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; -+ using ElementCompute = typename EpilogueOutputOp::ElementCompute; -+ -+ using WarpMmaOperator = typename Mma::Policy::Operator; -+ -+ using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; -+ using MathOperator = typename ArchMmaOperator::Operator; -+ -+ using OperatorClass = typename WarpMmaOperator::OperatorClass; -+ using ArchTag = typename WarpMmaOperator::ArchTag; -+ -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename WarpMmaOperator::Shape; -+ using InstructionShape = typename ArchMmaOperator::Shape; -+ -+ static int const kStages = Mma::kStages; -+ static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm; -+ static StrideSupport const kStrideSupport = Mma::IteratorA::kStrideSupport; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ using TensorRefA = typename Mma::IteratorA::TensorRef; -+ using TensorRefB = typename Mma::IteratorB::TensorRef; -+ using TensorRefC = cutlass::TensorRef; -+ -+ /// Check iterator A and B convolution dimension are the same and -+ // set device::ImplicitGemmConvolution::kConvDim -+ static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim, -+ "Convolution on different different dimensions is not supported"); -+ static int const kConvDim = Mma::IteratorA::kConvDim; -+ -+ /// Conv dimension and problem size structure (Conv2d or Conv3d) -+ using ConvProblemSize = ConvProblemSize_; -+ -+ static conv::GroupMode const kGroupMode = GroupMode_; -+ -+ /// Wgrad C stride idx for implicit gemm algorithm -+ // Conv2d row-major matrix C (KxRSC) -+ // Conv3d row-major matrix C (KxTRSC) -+ static int const kWgradCStrideIdx = -+ platform::is_same::value ? 2 : 3; -+ -+ /// This chooses the appropriate stride element of the C tensor. -+ static int const kTensorCStrideIdx = -+ (kConvolutionalOperator == conv::Operator::kWgrad ? kWgradCStrideIdx : 0); -+ -+ // -+ // -+ // -+ using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter< -+ LayoutC, -+ typename Epilogue::OutputTileIterator::Layout, -+ TensorRefC, -+ ConvOperator, -+ ConvProblemSize -+ >; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ ConvProblemSize problem_size; -+ TensorRefA ref_A; -+ TensorRefB ref_B; -+ TensorRefC ref_C; -+ TensorRefC ref_D; -+ typename EpilogueOutputOp::Params output_op; -+ SplitKMode split_k_mode; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments() { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ ConvProblemSize const & problem_size -+ ): -+ problem_size(problem_size) { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ ConvProblemSize const & problem_size, -+ TensorRefA const & ref_A, -+ TensorRefB const & ref_B, -+ TensorRefC const & ref_C, -+ TensorRefC const & ref_D, -+ typename EpilogueOutputOp::Params const & output_op, -+ SplitKMode const & split_k_mode = SplitKMode::kSerial -+ ): -+ problem_size(problem_size), -+ ref_A(ref_A), -+ ref_B(ref_B), -+ ref_C(ref_C), -+ ref_D(ref_D), -+ output_op(output_op), -+ split_k_mode(split_k_mode) -+ { -+ -+ } -+ -+ }; -+ -+ /// Parameters structure -+ struct Params { -+ ConvProblemSize problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ gemm::GemmCoord implicit_gemm_problem_size; -+ int swizzle_log_tile; -+ -+ int gemm_k_iterations; -+ int gemm_k_iterations_per_channel; -+ typename Mma::IteratorA::Params iterator_A; -+ typename Mma::IteratorA::Element const *ptr_A; -+ typename Mma::IteratorB::Params iterator_B; -+ typename Mma::IteratorB::Element const *ptr_B; -+ typename Epilogue::OutputTileIterator::Params iterator_C; -+ typename Epilogue::OutputTileIterator::Element *ptr_C; -+ typename Epilogue::OutputTileIterator::Params iterator_D; -+ typename Epilogue::OutputTileIterator::Element *ptr_D; -+ typename EpilogueOutputOp::Params output_op; -+ int *semaphore; -+ SplitKMode split_k_mode; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): swizzle_log_tile(0), gemm_k_iterations(0) { } -+ -+ /// -+ CUTLASS_HOST_DEVICE -+ Params( -+ Arguments const &args, -+ int *semaphore = nullptr -+ ): -+ problem_size(args.problem_size), -+ implicit_gemm_problem_size(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size)), -+ iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())), -+ ptr_A(args.ref_A.data()), -+ iterator_B(args.problem_size, args.ref_B.layout()), -+ ptr_B(args.ref_B.data()), -+ iterator_C(ConvOutputIteratorParameter::layout(args.ref_C)), -+ ptr_C(args.ref_C.data()), -+ iterator_D(ConvOutputIteratorParameter::layout(args.ref_D)), -+ ptr_D(args.ref_D.data()), -+ output_op(args.output_op), -+ semaphore(semaphore), -+ split_k_mode(args.split_k_mode) -+ { -+ gemm_k_iterations = implicit_gemm_k_iterations( -+ kConvolutionalOperator, -+ ThreadblockShape::kK, -+ args.problem_size, -+ kIteratorAlgorithm, -+ kGroupMode, -+ ThreadblockShape::kN); -+ -+ gemm_k_iterations_per_channel = implicit_gemm_k_iterations_per_channel( -+ kConvolutionalOperator, ThreadblockShape::kK, args.problem_size, kIteratorAlgorithm); -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ grid_tiled_shape = threadblock_swizzle.get_tiled_shape( -+ implicit_gemm_problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.problem_size.split_k_slices); -+ -+ swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); -+ } -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ ImplicitGemmConvolution() { } -+ -+ /// Executes one ImplicitGEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_idx = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) { -+ -+ return; -+ } -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ int iterator_A_column_offset = threadblock_tile_idx.k() * Mma::Shape::kK; -+ if (kGroupMode != GroupMode::kNone) { -+ if (kGroupMode != GroupMode::kDepthwise) { -+ int k_per_group = params.problem_size.K / params.problem_size.groups; -+ int group_idx = threadblock_tile_idx.n() * Mma::Shape::kN / k_per_group; -+ int channels_per_group = params.problem_size.C / params.problem_size.groups; -+ iterator_A_column_offset += group_idx * channels_per_group; -+ } else { -+ iterator_A_column_offset += threadblock_tile_idx.n() * Mma::Shape::kN; -+ } -+ } -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.iterator_A, -+ params.problem_size, -+ params.ptr_A, -+ thread_idx, -+ MatrixCoord( -+ threadblock_tile_idx.m() * Mma::Shape::kM, -+ iterator_A_column_offset -+ ) -+ ); -+ -+ typename Mma::IteratorB iterator_B( -+ params.iterator_B, -+ params.problem_size, -+ params.ptr_B, -+ thread_idx, -+ MatrixCoord( -+ threadblock_tile_idx.k() * Mma::Shape::kK, -+ threadblock_tile_idx.n() * Mma::Shape::kN -+ ) -+ ); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = canonical_warp_idx(); -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma(params.gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, params.gemm_k_iterations_per_channel); -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ // Construct the semaphore. -+ int block_idx = threadblock_tile_idx.m() + threadblock_tile_idx.n() * params.grid_tiled_shape.m(); -+ -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ // Compute logical position within grid -+ threadblock_tile_idx = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. -+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op.set_k_partition(threadblock_tile_idx.k(), params.grid_tiled_shape.k()); -+ } -+ -+ MatrixCoord threadblock_offset( -+ threadblock_tile_idx.m() * Mma::Shape::kM, -+ threadblock_tile_idx.n() * Mma::Shape::kN -+ ); -+ -+ // Tile iterator writing to destination tensor -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.iterator_D, -+ params.ptr_D, -+ ConvOutputIteratorParameter::extent(params.problem_size), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator reading from source accumulator tensor -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.iterator_C, -+ params.ptr_C, -+ ConvOutputIteratorParameter::extent(params.problem_size), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ -+ // Construct the epilogue -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_idx.k()) { -+ iterator_C = iterator_D; -+ } -+ -+ semaphore.wait(threadblock_tile_idx.k()); -+ -+ } -+ // Each split-k-slice writes to a unique tensor location -+ else if (params.split_k_mode == SplitKMode::kParallel) { -+ iterator_D.add_pointer_offset(threadblock_tile_idx.k() * -+ cutlass::conv::implicit_gemm_tensor_c_size(ConvOperator, params.problem_size)); -+ } -+ -+ // Run efficient epilogue -+ epilogue(output_op, iterator_D, accumulators, iterator_C); -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_idx.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_idx.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_fusion.h b/3rdparty/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_fusion.h -new file mode 100644 -index 0000000..b740c90 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_fusion.h -@@ -0,0 +1,463 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined fused activation's scale+bias+relu and Implicit GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/semaphore.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+#include "cutlass/epilogue/threadblock/output_iterator_parameter.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) -+ typename ConvProblemSize_ = Conv2dProblemSize ///! Convolutional operator on 2D or 3D problem -+> -+struct ImplicitGemmConvolutionFusion { -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static Operator const kConvolutionalOperator = ConvOperator; -+ -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ -+ using ElementScaleBias = typename Mma::IteratorScaleBias::Element; -+ using LayoutScaleBias = typename Mma::IteratorScaleBias::Layout; -+ -+ using ElementC = typename EpilogueOutputOp::ElementOutput; -+ using LayoutC = LayoutA; -+ -+ using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; -+ using ElementCompute = typename EpilogueOutputOp::ElementCompute; -+ -+ using WarpMmaOperator = typename Mma::Policy::Operator; -+ -+ using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; -+ using MathOperator = typename ArchMmaOperator::Operator; -+ -+ using OperatorClass = typename WarpMmaOperator::OperatorClass; -+ using ArchTag = typename WarpMmaOperator::ArchTag; -+ -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename WarpMmaOperator::Shape; -+ using InstructionShape = typename ArchMmaOperator::Shape; -+ -+ static int const kStages = Mma::kStages; -+ static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ using TensorRefA = typename Mma::IteratorA::TensorRef; -+ using TensorRefB = typename Mma::IteratorB::TensorRef; -+ using TensorRefScaleBias = typename Mma::IteratorScaleBias::TensorRef; -+ using TensorRefC = cutlass::TensorRef; -+ -+ /// Check iterator A and B convolution dimension are the same and -+ // set device::ImplicitGemmConvolution::kConvDim -+ static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim, -+ "Convolution on different different dimensions is not supported"); -+ static int const kConvDim = Mma::IteratorA::kConvDim; -+ -+ /// Conv dimension and problem size structure (Conv2d or Conv3d) -+ using ConvProblemSize = ConvProblemSize_; -+ -+ static conv::GroupMode const kGroupMode = conv::GroupMode::kNone; -+ -+ /// Wgrad C stride idx for implicit gemm algorithm -+ // Conv2d row-major matrix C (KxRSC) -+ // Conv3d row-major matrix C (KxTRSC) -+ static int const kWgradCStrideIdx = -+ platform::is_same::value ? 2 : 3; -+ -+ /// This chooses the appropriate stride element of the C tensor. -+ static int const kTensorCStrideIdx = -+ (kConvolutionalOperator == conv::Operator::kWgrad ? kWgradCStrideIdx : 0); -+ -+ // -+ // -+ // -+ using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter< -+ LayoutC, -+ typename Epilogue::OutputTileIterator::Layout, -+ TensorRefC, -+ ConvOperator, -+ ConvProblemSize -+ >; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ ConvProblemSize problem_size; -+ TensorRefA ref_A; -+ TensorRefB ref_B; -+ TensorRefScaleBias ref_scale; -+ TensorRefScaleBias ref_bias; -+ TensorRefC ref_C; -+ TensorRefC ref_D; -+ typename EpilogueOutputOp::Params output_op; -+ SplitKMode split_k_mode; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments() { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ ConvProblemSize const & problem_size -+ ): -+ problem_size(problem_size) { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ ConvProblemSize const & problem_size, -+ TensorRefA const & ref_A, -+ TensorRefB const & ref_B, -+ TensorRefScaleBias const & ref_scale, -+ TensorRefScaleBias const & ref_bias, -+ TensorRefC const & ref_C, -+ TensorRefC const & ref_D, -+ typename EpilogueOutputOp::Params const & output_op, -+ SplitKMode const & split_k_mode = SplitKMode::kSerial -+ ): -+ problem_size(problem_size), -+ ref_A(ref_A), -+ ref_B(ref_B), -+ ref_scale(ref_scale), -+ ref_bias(ref_bias), -+ ref_C(ref_C), -+ ref_D(ref_D), -+ output_op(output_op), -+ split_k_mode(split_k_mode) -+ { -+ -+ } -+ -+ }; -+ -+ /// Parameters structure -+ struct Params { -+ ConvProblemSize problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ gemm::GemmCoord implicit_gemm_problem_size; -+ int swizzle_log_tile; -+ int gemm_k_iterations; -+ typename Mma::IteratorA::Params iterator_A; -+ typename Mma::IteratorA::Element const *ptr_A; -+ typename Mma::IteratorB::Params iterator_B; -+ typename Mma::IteratorB::Element const *ptr_B; -+ typename Mma::IteratorScaleBias::Params iterator_scale_bias; -+ typename Mma::IteratorScaleBias::Element const *ptr_scale; -+ typename Mma::IteratorScaleBias::Element const *ptr_bias; -+ typename Epilogue::OutputTileIterator::Params iterator_C; -+ typename Epilogue::OutputTileIterator::Element *ptr_C; -+ typename Epilogue::OutputTileIterator::Params iterator_D; -+ typename Epilogue::OutputTileIterator::Element *ptr_D; -+ typename EpilogueOutputOp::Params output_op; -+ int *semaphore; -+ SplitKMode split_k_mode; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): swizzle_log_tile(0), gemm_k_iterations(0) { } -+ -+ /// -+ CUTLASS_HOST_DEVICE -+ Params( -+ Arguments const &args, -+ int *semaphore = nullptr -+ ): -+ problem_size(args.problem_size), -+ implicit_gemm_problem_size(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size)), -+ iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())), -+ ptr_A(args.ref_A.data()), -+ iterator_B(args.problem_size, args.ref_B.layout()), -+ ptr_B(args.ref_B.data()), -+ iterator_scale_bias(args.problem_size, args.ref_scale.layout()), -+ ptr_scale(args.ref_scale.data()), -+ ptr_bias(args.ref_bias.data()), -+ iterator_C(ConvOutputIteratorParameter::layout(args.ref_C)), -+ ptr_C(args.ref_C.data()), -+ iterator_D(ConvOutputIteratorParameter::layout(args.ref_D)), -+ ptr_D(args.ref_D.data()), -+ output_op(args.output_op), -+ semaphore(semaphore), -+ split_k_mode(args.split_k_mode) -+ { -+ gemm_k_iterations = implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape::kK, args.problem_size); -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ grid_tiled_shape = threadblock_swizzle.get_tiled_shape( -+ implicit_gemm_problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.problem_size.split_k_slices); -+ -+ swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); -+ } -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ ImplicitGemmConvolutionFusion() { } -+ -+ /// Executes one ImplicitGEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_idx = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) { -+ -+ return; -+ } -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A operand -+ typename Mma::IteratorA iterator_A( -+ params.iterator_A, -+ params.problem_size, -+ params.ptr_A, -+ thread_idx, -+ MatrixCoord( -+ threadblock_tile_idx.m() * Mma::Shape::kM, -+ threadblock_tile_idx.k() * Mma::Shape::kK -+ ) -+ ); -+ -+ // Construct iterators to B operand -+ typename Mma::IteratorB iterator_B( -+ params.iterator_B, -+ params.problem_size, -+ params.ptr_B, -+ thread_idx, -+ MatrixCoord( -+ threadblock_tile_idx.k() * Mma::Shape::kK, -+ threadblock_tile_idx.n() * Mma::Shape::kN -+ ) -+ ); -+ -+ // Construct iterators to A scale/bias vector -+ typename Mma::IteratorScaleBias iterator_scale_bias( -+ params.iterator_scale_bias, -+ params.problem_size, -+ params.ptr_scale, -+ params.ptr_bias, -+ thread_idx, -+ MatrixCoord( -+ 0, (kConvolutionalOperator == conv::Operator::kFprop) ? -+ (threadblock_tile_idx.k() * Mma::Shape::kK) : -+ // Wgrad -+ (threadblock_tile_idx.n() * Mma::Shape::kN) -+ ) -+ ); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = canonical_warp_idx(); -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma(params.gemm_k_iterations, accumulators, iterator_A, -+ iterator_B, iterator_scale_bias, accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ // Construct the semaphore. -+ int block_idx = threadblock_tile_idx.m() + threadblock_tile_idx.n() * params.grid_tiled_shape.m(); -+ -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ // Compute logical position within grid -+ threadblock_tile_idx = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. -+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op.set_k_partition(threadblock_tile_idx.k(), params.grid_tiled_shape.k()); -+ } -+ -+ MatrixCoord threadblock_offset( -+ threadblock_tile_idx.m() * Mma::Shape::kM, -+ threadblock_tile_idx.n() * Mma::Shape::kN -+ ); -+ -+ // Tile iterator writing to destination tensor -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.iterator_D, -+ params.ptr_D, -+ ConvOutputIteratorParameter::extent(params.problem_size), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator reading from source accumulator tensor -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.iterator_C, -+ params.ptr_C, -+ ConvOutputIteratorParameter::extent(params.problem_size), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Construct the epilogue -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_idx.k()) { -+ iterator_C = iterator_D; -+ } -+ -+ semaphore.wait(threadblock_tile_idx.k()); -+ -+ } -+ // Each split-k-slice writes to a unique tensor location -+ else if (params.split_k_mode == SplitKMode::kParallel) { -+ iterator_D.add_pointer_offset(threadblock_tile_idx.k() * -+ cutlass::conv::implicit_gemm_tensor_c_size(ConvOperator, params.problem_size)); -+ } -+ -+ // Run efficient epilogue -+ epilogue(output_op, iterator_D, accumulators, iterator_C); -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_idx.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_idx.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h b/3rdparty/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h -new file mode 100644 -index 0000000..7304cbd ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h -@@ -0,0 +1,492 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined Implicit GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/semaphore.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+#include "cutlass/epilogue/threadblock/output_iterator_parameter.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) -+ typename ConvProblemSize_ = Conv2dProblemSize ///! Convolutional operator on 2D or 3D problem -+> -+struct ImplicitGemmConvolutionStridedDgrad { -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static Operator const kConvolutionalOperator = ConvOperator; -+ -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using ElementC = typename EpilogueOutputOp::ElementOutput; -+ -+ /// Set output tensor C layout -+ using LayoutC = LayoutA; -+ -+ using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; -+ using ElementCompute = typename EpilogueOutputOp::ElementCompute; -+ -+ using WarpMmaOperator = typename Mma::Policy::Operator; -+ -+ using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; -+ using MathOperator = typename ArchMmaOperator::Operator; -+ -+ using OperatorClass = typename WarpMmaOperator::OperatorClass; -+ using ArchTag = typename WarpMmaOperator::ArchTag; -+ -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename WarpMmaOperator::Shape; -+ using InstructionShape = typename ArchMmaOperator::Shape; -+ -+ static int const kStages = Mma::kStages; -+ static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm; -+ static StrideSupport const kStrideSupport = Mma::IteratorA::kStrideSupport; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ using TensorRefA = typename Mma::IteratorA::TensorRef; -+ using TensorRefB = typename Mma::IteratorB::TensorRef; -+ using TensorRefC = cutlass::TensorRef; -+ -+ /// Check iterator A and B convolution dimension are the same and -+ // set device::ImplicitGemmConvolution::kConvDim -+ static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim, -+ "Convolution on different different dimensions is not supported"); -+ static int const kConvDim = Mma::IteratorA::kConvDim; -+ -+ /// Conv dimension and problem size structure (Conv2d or Conv3d) -+ using ConvProblemSize = ConvProblemSize_; -+ -+ static conv::GroupMode const kGroupMode = conv::GroupMode::kNone; -+ -+ /// Wgrad C stride idx for implicit gemm algorithm -+ // Conv2d row-major matrix C (KxRSC) -+ // Conv3d row-major matrix C (KxTRSC) -+ static int const kWgradCStrideIdx = -+ platform::is_same::value ? 2 : 3; -+ -+ /// This chooses the appropriate stride element of the C tensor. -+ static int const kTensorCStrideIdx = -+ (kConvolutionalOperator == conv::Operator::kWgrad ? kWgradCStrideIdx : 0); -+ -+ // Strided dgrad uses a specialized threadblock swizzle for functionality and performance -+ static_assert((platform::is_same::value) || -+ (platform::is_same>::value) || -+ (platform::is_same>::value) || -+ (platform::is_same>::value), -+ "Needs ThreadblockSwizzle type specialized for strided dgrad"); -+ -+ // -+ // -+ // -+ using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter< -+ LayoutC, -+ typename Epilogue::OutputTileIterator::Layout, -+ TensorRefC, -+ ConvOperator, -+ ConvProblemSize -+ >; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ ConvProblemSize problem_size; -+ TensorRefA ref_A; -+ TensorRefB ref_B; -+ TensorRefC ref_C; -+ TensorRefC ref_D; -+ typename EpilogueOutputOp::Params output_op; -+ SplitKMode split_k_mode; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments() { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ ConvProblemSize const & problem_size -+ ): -+ problem_size(problem_size) { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ ConvProblemSize const & problem_size, -+ TensorRefA const & ref_A, -+ TensorRefB const & ref_B, -+ TensorRefC const & ref_C, -+ TensorRefC const & ref_D, -+ typename EpilogueOutputOp::Params const & output_op, -+ SplitKMode const & split_k_mode = SplitKMode::kSerial -+ ): -+ problem_size(problem_size), -+ ref_A(ref_A), -+ ref_B(ref_B), -+ ref_C(ref_C), -+ ref_D(ref_D), -+ output_op(output_op), -+ split_k_mode(split_k_mode) -+ { -+ -+ } -+ -+ }; -+ -+ /// Parameters structure -+ struct Params { -+ ConvProblemSize problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ FastDivmod stride_h_divmod; -+ FastDivmod stride_w_divmod; -+ int gemm_k_iterations; -+ typename Mma::IteratorA::Params iterator_A; -+ typename Mma::IteratorA::Element const *ptr_A; -+ typename Mma::IteratorB::Params iterator_B; -+ typename Mma::IteratorB::Element const *ptr_B; -+ typename Epilogue::OutputTileIterator::Params iterator_C; -+ typename Epilogue::OutputTileIterator::Element *ptr_C; -+ typename Epilogue::OutputTileIterator::Params iterator_D; -+ typename Epilogue::OutputTileIterator::Element *ptr_D; -+ typename EpilogueOutputOp::Params output_op; -+ int *semaphore; -+ SplitKMode split_k_mode; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): gemm_k_iterations(0) { } -+ -+ /// -+ CUTLASS_HOST_DEVICE -+ Params( -+ Arguments const &args, -+ int *semaphore = nullptr -+ ): -+ problem_size(args.problem_size), -+ stride_h_divmod(args.problem_size.stride_h), -+ stride_w_divmod(args.problem_size.stride_w), -+ iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())), -+ ptr_A(args.ref_A.data()), -+ iterator_B(args.problem_size, args.ref_B.layout()), -+ ptr_B(args.ref_B.data()), -+ iterator_C(ConvOutputIteratorParameter::layout(args.ref_C), args.problem_size, ThreadblockShape::kM), -+ ptr_C(args.ref_C.data()), -+ iterator_D(ConvOutputIteratorParameter::layout(args.ref_D), args.problem_size, ThreadblockShape::kM), -+ ptr_D(args.ref_D.data()), -+ output_op(args.output_op), -+ semaphore(semaphore), -+ split_k_mode(args.split_k_mode) -+ { -+ gemm_k_iterations = implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape::kK, args.problem_size); -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ grid_tiled_shape = threadblock_swizzle.get_tiled_shape( -+ kConvolutionalOperator, -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.problem_size.split_k_slices); -+ } -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ ImplicitGemmConvolutionStridedDgrad() { } -+ -+ /// Executes one ImplicitGEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_idx = -+ threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) { -+ -+ return; -+ } -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Compute starting filter position for strided dgrad -+ int tile_m_per_filter = strided_dgrad_tile_m_per_filter(params.problem_size, -+ ThreadblockShape::kM); -+ int filter_tile_m = (threadblock_tile_idx.m() / tile_m_per_filter); -+ -+ -+ // The subsequent fast_divmod() operations are equivalent to the following logical computation: -+ // -+ // int start_r = filter_tile_m / (params.problem_size.stride_w); -+ // int start_s = filter_tile_m % (params.problem_size.stride_w); -+ -+ int start_r, start_s; -+ params.stride_w_divmod(start_r, start_s, filter_tile_m); -+ -+ int filter_r = start_r; -+ int filter_s = start_s; -+ -+ if (params.problem_size.mode == Mode::kConvolution) { -+ filter_r = (params.problem_size.R - 1 - filter_r); -+ filter_s = (params.problem_size.S - 1 - filter_s); -+ } -+ -+ // Starting h, w positions for filter position in gemm_k=0 -+ int start_h, start_w; -+ strided_dgrad_starting_coords( -+ params.problem_size, -+ params.stride_h_divmod, params.stride_w_divmod, -+ filter_r, filter_s, -+ start_h, start_w); -+ -+ if (start_h >= params.problem_size.H || start_w >= params.problem_size.W) { -+ return; -+ } -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = canonical_warp_idx(); -+ int lane_idx = threadIdx.x % 32; -+ -+ // Check if CTA contributes valid MMA (Dy * w) and accumulator will be non-zero after MMA -+ if (start_r < params.problem_size.R && start_s < params.problem_size.S) { -+ // Scale gemm_k_iterations for strided dgrad -+ int gemm_k_iterations = (params.gemm_k_iterations / (params.problem_size.R * params.problem_size.S) -+ ) * params.problem_size.num_gemm_k_filter_positions(start_r, start_s); -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.iterator_A, -+ params.problem_size, -+ params.ptr_A, -+ thread_idx, -+ params.stride_h_divmod, params.stride_w_divmod, -+ start_r, start_s, -+ MatrixCoord( -+ threadblock_tile_idx.m() * Mma::Shape::kM, -+ threadblock_tile_idx.k() * Mma::Shape::kK -+ ) -+ ); -+ -+ typename Mma::IteratorB iterator_B( -+ params.iterator_B, -+ params.problem_size, -+ params.ptr_B, -+ thread_idx, -+ start_r, start_s, -+ MatrixCoord( -+ threadblock_tile_idx.k() * Mma::Shape::kK, -+ threadblock_tile_idx.n() * Mma::Shape::kN -+ ) -+ ); -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); -+ } -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ // Construct the semaphore. -+ int block_idx = threadblock_tile_idx.m() + threadblock_tile_idx.n() * params.grid_tiled_shape.m(); -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ // Compute logical position within grid -+ threadblock_tile_idx = -+ threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. -+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op.set_k_partition(threadblock_tile_idx.k(), params.grid_tiled_shape.k()); -+ } -+ -+ MatrixCoord threadblock_offset( -+ threadblock_tile_idx.m() * Mma::Shape::kM, -+ threadblock_tile_idx.n() * Mma::Shape::kN -+ ); -+ -+ // Tile iterator writing to destination tensor -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.iterator_D, -+ params.ptr_D, -+ ConvOutputIteratorParameter::extent(params.problem_size), -+ thread_idx, -+ params.stride_h_divmod, params.stride_w_divmod, -+ start_r, start_s, -+ threadblock_offset -+ ); -+ -+ // Construct the epilogue -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ if (output_op.is_source_needed()) -+ { -+ // Tile iterator reading from source accumulator tensor -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.iterator_C, -+ params.ptr_C, -+ ConvOutputIteratorParameter::extent(params.problem_size), -+ thread_idx, -+ params.stride_h_divmod, params.stride_w_divmod, -+ start_r, start_s, -+ threadblock_offset); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_idx.k()) { -+ iterator_C = iterator_D; -+ } -+ -+ semaphore.wait(threadblock_tile_idx.k()); -+ } -+ -+ // Run epilogue with addend source iterator -+ epilogue(output_op, iterator_D, accumulators, iterator_C); -+ } -+ else -+ { -+ // Run epilogue without addend source iterator -+ epilogue(output_op, iterator_D, accumulators); -+ } -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_idx.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_idx.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h b/3rdparty/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h -new file mode 100644 -index 0000000..3fa7dac ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h -@@ -0,0 +1,499 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined Implicit GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/semaphore.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+#include "cutlass/epilogue/threadblock/output_iterator_parameter.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) -+ typename ConvProblemSize_ = Conv2dProblemSize ///! Convolutional operator on 2D or 3D problem -+> -+struct ImplicitGemmConvolutionWithFusedEpilogue { -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static Operator const kConvolutionalOperator = ConvOperator; -+ -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using ElementC = typename EpilogueOutputOp::ElementOutput; -+ -+ /// Set output tensor C layout -+ using LayoutC = LayoutA; -+ -+ using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; -+ using ElementCompute = typename EpilogueOutputOp::ElementCompute; -+ -+ using WarpMmaOperator = typename Mma::Policy::Operator; -+ -+ using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; -+ using MathOperator = typename ArchMmaOperator::Operator; -+ -+ using OperatorClass = typename WarpMmaOperator::OperatorClass; -+ using ArchTag = typename WarpMmaOperator::ArchTag; -+ -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename WarpMmaOperator::Shape; -+ using InstructionShape = typename ArchMmaOperator::Shape; -+ -+ static int const kStages = Mma::kStages; -+ static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm; -+ static StrideSupport const kStrideSupport = Mma::IteratorA::kStrideSupport; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ using TensorRefA = typename Mma::IteratorA::TensorRef; -+ using TensorRefB = typename Mma::IteratorB::TensorRef; -+ using TensorRefC = cutlass::TensorRef; -+ -+ /// Check iterator A and B convolution dimension are the same and -+ // set device::ImplicitGemmConvolution::kConvDim -+ static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim, -+ "Convolution on different different dimensions is not supported"); -+ static int const kConvDim = Mma::IteratorA::kConvDim; -+ -+ /// Conv dimension and problem size structure (Conv2d or Conv3d) -+ using ConvProblemSize = ConvProblemSize_; -+ -+ static conv::GroupMode const kGroupMode = conv::GroupMode::kNone; -+ -+ /// Wgrad C stride idx for implicit gemm algorithm -+ // Conv2d row-major matrix C (KxRSC) -+ // Conv3d row-major matrix C (KxTRSC) -+ static int const kWgradCStrideIdx = -+ platform::is_same::value ? 2 : 3; -+ -+ /// This chooses the appropriate stride element of the C tensor. -+ static int const kTensorCStrideIdx = -+ (kConvolutionalOperator == conv::Operator::kWgrad ? kWgradCStrideIdx : 0); -+ -+ // -+ // -+ // -+ using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter< -+ LayoutC, -+ typename Epilogue::OutputTileIterator::Layout, -+ TensorRefC, -+ ConvOperator, -+ ConvProblemSize -+ >; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ ConvProblemSize problem_size; -+ TensorRefA ref_A; -+ TensorRefB ref_B; -+ TensorRefC ref_C; -+ TensorRefC ref_D; -+ -+ typename EpilogueOutputOp::Params output_op; -+ SplitKMode split_k_mode; -+ -+ void * ptr_Vector; -+ void * ptr_Tensor; -+ -+ typename LayoutC::Stride::Index ldr; -+ typename LayoutC::Stride::Index ldt; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments() { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ ConvProblemSize const & problem_size -+ ): -+ problem_size(problem_size) { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ ConvProblemSize const & problem_size, -+ TensorRefA const & ref_A, -+ TensorRefB const & ref_B, -+ TensorRefC const & ref_C, -+ TensorRefC const & ref_D, -+ typename EpilogueOutputOp::Params const & output_op, -+ SplitKMode const & split_k_mode = SplitKMode::kSerial, -+ void * ptr_Vector = nullptr, -+ void * ptr_Tensor = nullptr, -+ typename LayoutC::Stride::Index ldr = 0, -+ typename LayoutC::Stride::Index ldt = 0 -+ ): -+ problem_size(problem_size), -+ ref_A(ref_A), -+ ref_B(ref_B), -+ ref_C(ref_C), -+ ref_D(ref_D), -+ output_op(output_op), -+ split_k_mode(split_k_mode), -+ ptr_Vector(ptr_Vector), -+ ptr_Tensor(ptr_Tensor), -+ ldr(ldr), -+ ldt(ldt) -+ { -+ -+ } -+ -+ }; -+ -+ /// Parameters structure -+ struct Params { -+ ConvProblemSize problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ gemm::GemmCoord implicit_gemm_problem_size; -+ int swizzle_log_tile; -+ -+ int gemm_k_iterations; -+ typename Mma::IteratorA::Params iterator_A; -+ typename Mma::IteratorA::Element const *ptr_A; -+ typename Mma::IteratorB::Params iterator_B; -+ typename Mma::IteratorB::Element const *ptr_B; -+ typename Epilogue::OutputTileIterator::Params iterator_C; -+ typename Epilogue::OutputTileIterator::Element *ptr_C; -+ typename Epilogue::OutputTileIterator::Params iterator_D; -+ typename Epilogue::OutputTileIterator::Element *ptr_D; -+ typename EpilogueOutputOp::Params output_op; -+ int *semaphore; -+ SplitKMode split_k_mode; -+ -+ typename Epilogue::TensorTileIterator::Params params_Tensor; -+ void * ptr_Vector; -+ typename LayoutC::Stride::Index ldr; -+ void * ptr_Tensor; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ swizzle_log_tile(0), -+ gemm_k_iterations(0), -+ ptr_Vector(nullptr), -+ ldr(0), -+ ptr_Tensor(nullptr) -+ { } -+ -+ /// -+ CUTLASS_HOST_DEVICE -+ Params( -+ Arguments const &args, -+ int *semaphore = nullptr -+ ): -+ problem_size(args.problem_size), -+ implicit_gemm_problem_size(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size)), -+ iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())), -+ ptr_A(args.ref_A.data()), -+ iterator_B(args.problem_size, args.ref_B.layout()), -+ ptr_B(args.ref_B.data()), -+ iterator_C(ConvOutputIteratorParameter::layout(args.ref_C)), -+ ptr_C(args.ref_C.data()), -+ iterator_D(ConvOutputIteratorParameter::layout(args.ref_D)), -+ ptr_D(args.ref_D.data()), -+ output_op(args.output_op), -+ semaphore(semaphore), -+ split_k_mode(args.split_k_mode), -+ params_Tensor(args.ldt), -+ ptr_Vector(args.ptr_Vector), -+ ldr(args.ldr), -+ ptr_Tensor(args.ptr_Tensor) -+ -+ { -+ gemm_k_iterations = implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape::kK, args.problem_size); -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ grid_tiled_shape = threadblock_swizzle.get_tiled_shape( -+ implicit_gemm_problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.problem_size.split_k_slices); -+ -+ swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); -+ } -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ ImplicitGemmConvolutionWithFusedEpilogue() { } -+ -+ /// Executes one ImplicitGEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_idx = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) { -+ -+ return; -+ } -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.iterator_A, -+ params.problem_size, -+ params.ptr_A, -+ thread_idx, -+ MatrixCoord( -+ threadblock_tile_idx.m() * Mma::Shape::kM, -+ threadblock_tile_idx.k() * Mma::Shape::kK -+ ) -+ ); -+ -+ typename Mma::IteratorB iterator_B( -+ params.iterator_B, -+ params.problem_size, -+ params.ptr_B, -+ thread_idx, -+ MatrixCoord( -+ threadblock_tile_idx.k() * Mma::Shape::kK, -+ threadblock_tile_idx.n() * Mma::Shape::kN -+ ) -+ ); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = canonical_warp_idx(); -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma(params.gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ // Construct the semaphore. -+ int block_idx = threadblock_tile_idx.m() + threadblock_tile_idx.n() * params.grid_tiled_shape.m(); -+ -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ // Compute logical position within grid -+ threadblock_tile_idx = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. -+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op.set_k_partition(threadblock_tile_idx.k(), params.grid_tiled_shape.k()); -+ } -+ -+ MatrixCoord threadblock_offset( -+ threadblock_tile_idx.m() * Mma::Shape::kM, -+ threadblock_tile_idx.n() * Mma::Shape::kN -+ ); -+ -+ // Tile iterator writing to destination tensor -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.iterator_D, -+ params.ptr_D, -+ ConvOutputIteratorParameter::extent(params.problem_size), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator reading from source accumulator tensor -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.iterator_C, -+ params.ptr_C, -+ ConvOutputIteratorParameter::extent(params.problem_size), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ typename Epilogue::ElementTensor *ptr_Tensor = -+ static_cast(params.ptr_Tensor); -+ -+ // Define the reduction output pointer and move to the appropriate place -+ typename Epilogue::ElementVector *ptr_Vector = -+ static_cast(params.ptr_Vector); -+ -+ // Additional tensor to load from -+ typename Epilogue::TensorTileIterator tensor_iterator( -+ params.params_Tensor, -+ // Only the final block outputs Tensor -+ ((params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) && -+ (params.grid_tiled_shape.k() != threadblock_tile_idx.k() + 1)) -+ ? nullptr -+ : ptr_Tensor, -+ ConvOutputIteratorParameter::extent(params.problem_size), -+ thread_idx, -+ threadblock_offset); -+ -+ // Construct the epilogue -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Move to appropriate location for this output tile -+ if (ptr_Vector) { -+ ptr_Vector += threadblock_offset.column() + threadblock_tile_idx.m() * params.ldr; -+ } -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_idx.k()) { -+ iterator_C = iterator_D; -+ } -+ -+ semaphore.wait(threadblock_tile_idx.k()); -+ -+ } -+ // Each split-k-slice writes to a unique tensor location -+ else if (params.split_k_mode == SplitKMode::kParallel) { -+ iterator_D.add_pointer_offset(threadblock_tile_idx.k() * -+ cutlass::conv::implicit_gemm_tensor_c_size(ConvOperator, params.problem_size)); -+ } -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue(output_op, -+ // Only the final block uses Vector -+ ((params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) && -+ (params.grid_tiled_shape.k() != threadblock_tile_idx.k() + 1)) -+ ? nullptr -+ : ptr_Vector, -+ iterator_D, -+ accumulators, -+ iterator_C, -+ tensor_iterator, -+ ConvOutputIteratorParameter::extent(params.problem_size), -+ threadblock_offset); -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_idx.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_idx.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/thread/depthwise_mma.h b/3rdparty/cutlass/include/cutlass/conv/thread/depthwise_mma.h -new file mode 100644 -index 0000000..8f84563 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/thread/depthwise_mma.h -@@ -0,0 +1,325 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates exposing architecture support for depthwise convolution -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/arch/mma.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/thread/mma.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// MMA operation -+template < -+ /// Size of the matrix product (concept: GemmShape) -+ typename Shape_, -+ /// Number of threads participating -+ int kThreads_, -+ /// Data type of A elements -+ typename ElementA, -+ /// Data type of B elements -+ typename ElementB, -+ /// Element type of C matrix -+ typename ElementC, -+ /// Inner product operator -+ typename Operator -+> -+struct ElementwiseInnerProduct; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// General implementation -+template < -+ /// Size of the matrix product (concept: GemmShape) -+ typename Shape_, -+ /// Data type of A elements -+ typename ElementA_, -+ /// Data type of B elements -+ typename ElementB_, -+ /// Element type of C matrix -+ typename ElementC_> -+struct ElementwiseInnerProduct { -+ using Shape = Shape_; -+ using Operator = arch::OpMultiplyAdd; -+ using ElementC = ElementC_; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()(Array &d, -+ Array const &a, -+ Array const &b, -+ Array const &c) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Shape::kN; ++i) { -+ d[i] = a[i] * b[i] + c[i]; -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Specialization of half_t -+template <> -+struct ElementwiseInnerProduct< -+ gemm::GemmShape<2, 2, 1>, -+ 1, -+ half_t, -+ half_t, -+ half_t, -+ arch::OpMultiplyAdd> { -+ -+ using Shape = gemm::GemmShape<2, 2, 1>; -+ using Operator = arch::OpMultiplyAdd; -+ using ElementC = half_t; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array &d, -+ Array const &a, -+ Array const &b, -+ Array const &c -+ ) { -+ -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)) -+ -+ __half2 const & A = reinterpret_cast<__half2 const &>(a); -+ __half2 const & B = reinterpret_cast<__half2 const &>(b); -+ __half2 const & C = reinterpret_cast<__half2 const &>(c); -+ -+ __half2 tmp_D = __hfma2(A, B, C); -+ -+ d = reinterpret_cast const &>(tmp_D); -+ -+#else -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 2; ++i) { -+ d[i] = a[i] * b[i] + c[i]; -+ } -+#endif -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape, -+ /// Data type of A elements -+ typename ElementA, -+ /// Data type of B elements -+ typename ElementB, -+ /// Element type of C matrix -+ typename ElementC, -+ /// Concept: arch::OpMultiplyAdd or arch::Mma<> -+ typename Operator = arch::OpMultiplyAdd, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+struct DepthwiseDirectConvElementwiseInnerProduct; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Gemplate that handles all packed matrix layouts -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Data type of A elements -+ typename ElementA_, -+ /// Data type of B elements -+ typename ElementB_, -+ /// Element type of C matrix -+ typename ElementC_, -+ /// Operator used to compute GEMM -+ typename Operator_ -+> -+struct DepthwiseDirectConvElementwiseInnerProductGeneric { -+ -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ /// Data type of operand A -+ using ElementA = ElementA_; -+ -+ /// Data type of operand B -+ using ElementB = ElementB_; -+ -+ /// Element type of operand C -+ using ElementC = ElementC_; -+ -+ /// Underlying mathematical operator -+ using Operator = Operator_; -+ -+ /// A operand storage -+ using FragmentA = Array; -+ -+ /// B operand storage -+ using FragmentB = Array; -+ -+ /// C operand storage -+ using FragmentC = Array; -+ -+ /// Instruction -+ using MmaOp = cutlass::conv::thread::ElementwiseInnerProduct< -+ gemm::GemmShape, -+ 1, -+ ElementA, -+ ElementB, -+ ElementC, -+ Operator>; -+ -+ -+ // -+ // Methods -+ // -+ -+ /// Computes a matrix product D = A * B + C -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC & D, -+ FragmentA const & A, -+ FragmentB const & B, -+ FragmentC const & C) { -+ Array *ptr_D = reinterpret_cast *>(&D); -+ Array const *ptr_A = -+ reinterpret_cast const *>(&A); -+ Array const *ptr_B = -+ reinterpret_cast const *>(&B); -+ -+ MmaOp mma_op; -+ -+ // Copy accumulators -+ D = C; -+ -+ // Compute matrix product -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Shape::kN / MmaOp::Shape::kN; ++n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Shape::kM; ++m) { -+ -+ Array tmpD = ptr_D[m * Shape::kN / MmaOp::Shape::kN + n]; -+ Array tmpA = ptr_A[m * Shape::kN / MmaOp::Shape::kN + n]; -+ Array tmpB = ptr_B[n]; -+ -+ mma_op(tmpD, tmpA, tmpB, tmpD); -+ -+ ptr_D[m * Shape::kN / MmaOp::Shape::kN + n] = tmpD; -+ -+ } -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Data type of A elements -+ typename ElementA_, -+ /// Data type of B elements -+ typename ElementB_, -+ /// Element type of C matrix -+ typename ElementC_ -+> -+struct DepthwiseDirectConvElementwiseInnerProduct< -+ Shape_, -+ ElementA_, -+ ElementB_, -+ ElementC_, -+ arch::OpMultiplyAdd -+ > { -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ /// Data type of operand A -+ using ElementA = ElementA_; -+ -+ /// Data type of operand B -+ using ElementB = ElementB_; -+ -+ /// Element type of operand C -+ using ElementC = ElementC_; -+ -+ /// Underlying mathematical operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ /// A operand storage -+ using FragmentA = -+ Array; // output_tile_size per thread * groups_per_thread -+ -+ /// B operand storage -+ using FragmentB = Array; // 1 * groups_per_thread -+ -+ /// C operand storage -+ using FragmentC = -+ Array; // output_tile_size per thread * groups_per_thread -+ -+ static bool const use_optimized = 0; -+ -+ using ArchMmaOperator = DepthwiseDirectConvElementwiseInnerProductGeneric; -+ -+ // -+ // Methods -+ // -+ -+ /// Computes a matrix product D = A * B + C -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC & D, -+ FragmentA const & A, -+ FragmentB const & B, -+ FragmentC const & C) { -+ -+ ArchMmaOperator mma; -+ -+ mma(D, A, B, C); -+ -+ } -+}; -+ -+} // namespace thread -+} // namespace conv -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h -new file mode 100644 -index 0000000..9464074 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h -@@ -0,0 +1,485 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/threadblock/conv2d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_, -+ conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity, -+ typename AccessType_ = cutlass::AlignedArray -+> -+class Conv2dDgradFilterTileAccessIteratorAnalytic; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Conv2dDgradFilterTileAccessIteratorAnalytic strided dgrad needs special handling to skip MMAs -+// on non-contributing w positions -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_, -+ typename AccessType_ -+> -+class Conv2dDgradFilterTileAccessIteratorAnalytic < -+ Shape_, -+ Element_, -+ ThreadMap_, -+ conv::StrideSupport::kStrided, -+ AccessType_ -+> { -+public: -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNHWC; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ static_assert(sizeof_bits::value >= 8, -+ "DGRAD requires elements of size 8b or larger."); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv2dAnalyticParams; -+ -+private: -+ -+ Params const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ char const *pointer_; -+ -+ // For a fixed filter position (r,s) find and fill offset_k_, offset_c_ in strided and contiguous dimension -+ int filter_r_; -+ int filter_s_; -+ int start_r_; -+ int start_s_; -+ int offset_k_[ThreadMap::Iterations::kStrided]; -+ int offset_c_[ThreadMap::Iterations::kContiguous]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradFilterTileAccessIteratorAnalytic( -+ Params const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ int start_r, int start_s, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ filter_r_(start_r), -+ filter_s_(start_s), -+ start_r_(start_r), -+ start_s_(start_s) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ offset_c_[c] = threadblock_offset.column() + thread_coord.contiguous() -+ + c * ThreadMap::Delta::kContiguous; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ offset_k_[s] = -+ threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; -+ } -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ // Moves filter_s -+ filter_s_ += problem_size_.stride_w; -+ if (filter_s_ < problem_size_.S) { -+ return; -+ } -+ // Restore filter_s -+ filter_s_ = start_s_; -+ -+ // Move filter_r -+ filter_r_ += problem_size_.stride_h; -+ if (filter_r_ < problem_size_.R) { -+ return; -+ } -+ // Restore filter_r -+ filter_r_ = start_r_; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ offset_k_[s] += Shape::kRow * problem_size_.split_k_slices; -+ } -+ } -+ -+ /// Returns the coordinate in the filter tensor w that is currently pointed to -+ /// by the iterator. -+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ -+ int k = offset_k_[iteration_strided_]; -+ int c = offset_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements; -+ -+ return TensorCoord(k, filter_r_, filter_s_, c); -+ } -+ -+ /// Returns true if the current coordinate is within the filter tensor w -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ -+ TensorCoord coord = at(); -+ -+ return coord.n() < problem_size_.K && coord.c() < problem_size_.C; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ LongIndex offset = params_.layout(coord); -+ -+ return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradFilterTileAccessIteratorAnalytic &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Conv2dDgradFilterTileAccessIteratorAnalytic unity strided dgrad is more performant for dgrad -+// on problem sizes with stride = {1x1} -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_, -+ typename AccessType_ -+> -+class Conv2dDgradFilterTileAccessIteratorAnalytic < -+ Shape_, -+ Element_, -+ ThreadMap_, -+ conv::StrideSupport::kUnity, -+ AccessType_ -+>{ -+public: -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNHWC; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ static_assert(sizeof_bits::value >= 8, -+ "DGRAD requires elements of size 8b or larger."); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv2dAnalyticParams; -+ -+private: -+ -+ Params const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ char const *pointer_; -+ -+ // For a fixed filter position (r,s) find and fill offset_k_, offset_c_ in strided and contiguous dimension -+ int filter_r_; -+ int filter_s_; -+ int offset_k_[ThreadMap::Iterations::kStrided]; -+ int offset_c_[ThreadMap::Iterations::kContiguous]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradFilterTileAccessIteratorAnalytic( -+ Params const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ filter_r_(0), -+ filter_s_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ offset_c_[c] = threadblock_offset.column() + thread_coord.contiguous() -+ + c * ThreadMap::Delta::kContiguous; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ offset_k_[s] = -+ threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; -+ } -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ // moves to the next tile -+ ++filter_s_; -+ if (filter_s_ < problem_size_.S) { -+ return; -+ } -+ filter_s_ = 0; -+ ++filter_r_; -+ if (filter_r_ < problem_size_.R) { -+ return; -+ } -+ filter_r_ = 0; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ offset_k_[s] += Shape::kRow * problem_size_.split_k_slices; -+ } -+ } -+ -+ /// Returns the coordinate in the filter tensor w that is currently pointed to -+ /// by the iterator. -+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ -+ int k = offset_k_[iteration_strided_]; -+ int c = offset_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements; -+ -+ return TensorCoord(k, filter_r_, filter_s_, c); -+ } -+ -+ /// Returns true if the current coordinate is within the filter tensor w -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ -+ TensorCoord coord = at(); -+ -+ return coord.n() < problem_size_.K && coord.c() < problem_size_.C; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ LongIndex offset = params_.layout(coord); -+ -+ return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradFilterTileAccessIteratorAnalytic &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h -new file mode 100644 -index 0000000..bd5aa70 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h -@@ -0,0 +1,619 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+ -+#include "cutlass/conv/threadblock/conv2d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_, -+ conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity, -+ typename AccessType_ = cutlass::AlignedArray -+> -+class Conv2dDgradFilterTileAccessIteratorOptimized; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Conv2dDgradFilterTileAccessIteratorOptimized unity strided dgrad is more performant for dgrad -+// on problem sizes with stride = {1x1} -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_, -+ typename AccessType_ -+> -+class Conv2dDgradFilterTileAccessIteratorOptimized < -+ Shape_, -+ Element_, -+ ThreadMap_, -+ conv::StrideSupport::kStrided, -+ AccessType_ -+ > { -+public: -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNHWC; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ // -+ // Parameters structure -+ // -+ -+ struct Params : Conv2dStridedDgradFilterIteratorOptimizedParams { -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Conv2dStridedDgradFilterIteratorOptimizedParams const &base): -+ Conv2dStridedDgradFilterIteratorOptimizedParams(base) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ Conv2dProblemSize const &problem_size, -+ Layout const &layout -+ ): -+ Conv2dStridedDgradFilterIteratorOptimizedParams( -+ problem_size, -+ layout, -+ sizeof_bits::value, -+ {Shape::kRow, Shape::kColumn}, -+ ThreadMap::kThreads, -+ ThreadMap::kElementsPerAccess, -+ {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, -+ {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided} -+ ) { } -+ -+ }; -+ -+private: -+ -+ Conv2dStridedDgradFilterIteratorOptimizedParams const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ char const *pointer_; -+ -+ uint32_t predicates_[kAccessesPerVector]; -+ int filter_k_; -+ int filter_r_; -+ int filter_s_; -+ -+ int start_r_; -+ int start_s_; -+ -+ int64_t reset_bytes_s_; -+ int64_t reset_bytes_r_; -+ -+ // -+ // Assertions -+ // -+ -+ // We map predicates into bits packed in this uint32_t container -+ static_assert(ThreadMap::Iterations::kStrided * -+ ThreadMap::Iterations::kContiguous < sizeof(predicates_) * 8, -+ "Currently, the number of loads per iteration is limited by the size of the predicates container."); -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradFilterTileAccessIteratorOptimized( -+ Conv2dStridedDgradFilterIteratorOptimizedParams const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ int start_r, int start_s, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ predicates_{0}, -+ filter_r_(start_r), -+ filter_s_(start_s), -+ start_r_(start_r), -+ start_s_(start_s) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ filter_k_ = threadblock_offset.row() + thread_coord.strided(); -+ Index column = threadblock_offset.column() + thread_coord.contiguous(); -+ -+ reset_bytes_s_ = (problem_size_.num_gemm_k_filter_s(start_s_) - 1) * params_.inc_next[0]; -+ reset_bytes_r_ = reset_bytes_s_ + -+ (problem_size_.num_gemm_k_filter_r(start_r_) - 1) * params_.inc_next[1]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ int filter_k = filter_k_ + s * ThreadMap::Delta::kStrided; -+ int filter_c = column + c * ThreadMap::Delta::kContiguous; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ -+ uint32_t pred = ((filter_k < problem_size_.K && (filter_c + v * AccessType::kElements) < problem_size_.C) ? 1u : 0); -+ -+ int pred_idx = c + s * ThreadMap::Iterations::kContiguous; -+ -+ predicates_[v] |= (pred << pred_idx); -+ } -+ } -+ } -+ -+ TensorCoord coord{filter_k_, filter_r_, filter_s_, column}; -+ -+ pointer_ += params_.layout(coord) * sizeof_bits::value / 8; -+ -+ set_iteration_index(0); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_DEVICE -+ void advance() { -+ -+ int next_idx = 0; -+ LongIndex reset_bytes = params_.reset_bytes; -+ -+ // Move filter_s by stride_w -+ filter_s_ += problem_size_.stride_w; -+ if (filter_s_ >= problem_size_.S) { -+ -+ // Restore filter_s -+ filter_s_ = start_s_; -+ -+ // Move filter_r by stride_h -+ filter_r_ += problem_size_.stride_h; -+#if 0 -+ bool check = (filter_r_ < problem_size_.R); -+ -+ filter_r_ = check ? filter_r_ : start_r_; -+ next_idx = check ? 1 : 2; -+ reset_bytes += (check ? reset_bytes_s_ : reset_bytes_r_); -+#else -+ asm volatile( -+ "{\n\t" -+ " .reg .pred %%p;\n\t" -+ " .reg .s64 t1;\n\t" -+ " setp.lt.s32 %%p, %3, %4;\n\t" -+ " selp.s32 %0, %3, %5, %%p;\n\t" -+ " selp.s32 %1, 1, 2, %%p;\n\t" -+ " selp.s64 t1, %6, %7, %%p;\n\t" -+ " add.s64 %2, %8, t1;\n\t" -+ "}\n" -+ : "=r"(filter_r_), "=r"(next_idx), "=l"(reset_bytes) -+ : "r"(filter_r_), "r"(problem_size_.R), "r"(start_r_), -+ "l"(reset_bytes_s_), "l"(reset_bytes_r_), "l"(reset_bytes)); -+#endif -+ } -+ -+ // offset pointers by offset_bytes -+ pointer_ += (params_.inc_next[next_idx] - reset_bytes); -+ -+ if (next_idx == 2) { -+ filter_k_ += params_.filter_k_delta; -+ } -+ -+ // Clear predicates if needed -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ if (filter_k_ + s * ThreadMap::Delta::kStrided >= problem_size_.K) { -+ uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ predicates_[v] = (predicates_[v] & (~kClearMask)); -+ } -+ } -+ } -+ } -+ -+ /// Returns true if the current coordinate is within the filter tensor W -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous; -+ return (predicates_[iteration_vector_] & (1u << pred_idx)); -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ return reinterpret_cast(pointer_ + -+ iteration_contiguous_ * ThreadMap::Delta::kContiguous * sizeof_bits::value / 8) + iteration_vector_; -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradFilterTileAccessIteratorOptimized &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ -+ // Move to the next K coordinate within the tile -+ pointer_ += params_.inc_next_strided; -+ -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Conv2dDgradFilterTileAccessIteratorOptimized unity strided dgrad is more performant for dgrad -+// on problem sizes with stride = {1x1} -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_, -+ typename AccessType_ -+> -+class Conv2dDgradFilterTileAccessIteratorOptimized < -+ Shape_, -+ Element_, -+ ThreadMap_, -+ conv::StrideSupport::kUnity, -+ AccessType_ -+ > { -+public: -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNHWC; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ // -+ // Parameters structure -+ // -+ -+ struct Params : Conv2dDgradFilterIteratorOptimizedParams { -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Conv2dDgradFilterIteratorOptimizedParams const &base): -+ Conv2dDgradFilterIteratorOptimizedParams(base) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ Conv2dProblemSize const &problem_size, -+ Layout const &layout -+ ): -+ Conv2dDgradFilterIteratorOptimizedParams( -+ problem_size, -+ layout, -+ sizeof_bits::value, -+ {Shape::kRow, Shape::kColumn}, -+ ThreadMap::kThreads, -+ ThreadMap::kElementsPerAccess, -+ {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, -+ {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided} -+ ) { } -+ -+ }; -+ -+private: -+ -+ Conv2dDgradFilterIteratorOptimizedParams const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ char const *pointer_; -+ -+ uint32_t predicates_[kAccessesPerVector]; -+ int filter_rs_; -+ int filter_k_; -+ -+ // -+ // Assertions -+ // -+ -+ // We map predicates into bits packed in this uint32_t container -+ static_assert(ThreadMap::Iterations::kStrided * -+ ThreadMap::Iterations::kContiguous < sizeof(predicates_) * 8, -+ "Currently, the number of loads per iteration is limited by the size of the predicates container."); -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradFilterTileAccessIteratorOptimized( -+ Conv2dDgradFilterIteratorOptimizedParams const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ predicates_{0}, -+ filter_rs_(0), -+ filter_k_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ filter_k_ = threadblock_offset.row() + thread_coord.strided(); -+ Index column = threadblock_offset.column() + thread_coord.contiguous(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ int filter_k = filter_k_ + s * ThreadMap::Delta::kStrided; -+ int filter_c = column + c * ThreadMap::Delta::kContiguous; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ -+ uint32_t pred = ((filter_k < problem_size_.K && (filter_c + v * AccessType::kElements) < problem_size_.C) ? 1u : 0); -+ -+ int pred_idx = c + s * ThreadMap::Iterations::kContiguous; -+ -+ predicates_[v] |= (pred << pred_idx); -+ } -+ } -+ } -+ -+ pointer_ += ( -+ filter_k_ * params.layout.stride()[2] + column -+ ) * sizeof_bits::value / 8; -+ -+ set_iteration_index(0); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ -+ LongIndex next = params_.inc_next_rs; -+ -+ // moves to the next tile -+ ++filter_rs_; -+ if (filter_rs_ == params_.RS) { -+ -+ filter_rs_ = 0; -+ next = params_.inc_next_k; -+ filter_k_ += params_.filter_k_delta; -+ } -+ -+ // Clear predicates if needed -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ if (filter_k_ + s * ThreadMap::Delta::kStrided >= problem_size_.K) { -+ uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ predicates_[v] = (predicates_[v] & (~kClearMask)); -+ } -+ } -+ } -+ -+ pointer_ += next; -+ } -+ -+ /// Returns true if the current coordinate is within the filter tensor W -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous; -+ return (predicates_[iteration_vector_] & (1u << pred_idx)); -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ return reinterpret_cast(pointer_ + -+ iteration_contiguous_ * ThreadMap::Delta::kContiguous * sizeof_bits::value / 8) + iteration_vector_; -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradFilterTileAccessIteratorOptimized &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ -+ // Move to the next K coordinate within the tile -+ pointer_ += params_.inc_next_strided; -+ -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h -new file mode 100644 -index 0000000..08f5465 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h -@@ -0,0 +1,606 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/functional.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/threadblock/conv2d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_, -+ conv::StrideSupport StrideSupport_ = conv::StrideSupport::kStrided, -+ typename AccessType_ = cutlass::AlignedArray -+> -+class Conv2dDgradOutputGradientTileAccessIteratorAnalytic; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Conv2dDgradOutputGradientTileAccessIteratorAnalytic strided dgrad needs special handling using -+// unscaled coordinations -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_, -+ typename AccessType_ -+> -+class Conv2dDgradOutputGradientTileAccessIteratorAnalytic < -+ Shape_, -+ Element_, -+ ThreadMap_, -+ conv::StrideSupport::kStrided, -+ AccessType_ -+> { -+public: -+ -+ // -+ // Types -+ // -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNHWC; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ static_assert(sizeof_bits::value >= 8, -+ "DGRAD requires elements of size 8b or greater."); -+ -+ // -+ // Simpligying assertions -+ // -+ -+ static_assert(ThreadMap::Iterations::kContiguous == 1, -+ "Require Iterations::kContiguous == 1"); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv2dDgradOutputGradientTileAccessIteratorAnalyticParams; -+ -+private: -+ -+ Params const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ char const *pointer_; -+ -+ int filter_k_; -+ int filter_r_; -+ int filter_s_; -+ int start_r_; -+ int start_s_; -+ -+ int offset_n_[ThreadMap::Iterations::kStrided]; -+ int offset_p_[ThreadMap::Iterations::kStrided]; -+ int offset_q_[ThreadMap::Iterations::kStrided]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradOutputGradientTileAccessIteratorAnalytic( -+ Params const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod, -+ int start_r, int start_s, -+ MatrixCoord const &threadblock_offset = MatrixCoord() // threadblock offset - units are whole CTA tiles -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ filter_k_(0), -+ filter_r_(start_r), -+ filter_s_(start_s), -+ start_r_(start_r), -+ start_s_(start_s) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ filter_k_ = threadblock_offset.column() + thread_coord.contiguous(); -+ -+ int filter_r = filter_r_; -+ int filter_s = filter_s_; -+ -+ if (problem_size_.mode == Mode::kConvolution) { -+ filter_r = (problem_size_.R - 1 - filter_r); -+ filter_s = (problem_size_.S - 1 - filter_s); -+ } -+ -+ // Starting h, w positions for filter position in gemm_k=0 -+ int start_h, start_w; -+ strided_dgrad_starting_coords( -+ problem_size_, -+ stride_h_divmod, stride_w_divmod, -+ filter_r, filter_s, -+ start_h, start_w); -+ -+ // Effective P and Q for filter position required for remapping NHW rows -+ int P = (problem_size_.H - start_h + problem_size_.stride_h - 1) / problem_size_.stride_h; -+ int Q = (problem_size_.W - start_w + problem_size_.stride_w - 1) / problem_size_.stride_w; -+ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ int offset_npq = (threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided) % params_.tiled_rows_per_filter; -+ -+ // (STEP 1) [reorder NHW rows to start with same filter positions] -+ offset_n_[s] = offset_npq / (P * Q); -+ int residual = offset_npq % (P * Q); -+ -+ int p = (residual / Q); -+ int q = (residual % Q); -+ -+ int mapped_h = (start_h + p * problem_size_.stride_h); -+ int mapped_w = (start_w + q * problem_size_.stride_w); -+ -+ // Access (p, q) coordinates for Dy tensor and a filter position in gemm_k=0 -+ // note that (h + pad_h - filter_r) and (w + pad_w - filter_s) are divisible -+ // by stride_h and stride_w -+ offset_p_[s] = (mapped_h + problem_size_.pad_h - filter_r) / problem_size_.stride_h; -+ offset_q_[s] = (mapped_w + problem_size_.pad_w - filter_s) / problem_size_.stride_w; -+ } -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { -+ return Params(problem_size, -+ layout, -+ sizeof_bits::value, -+ {Shape::kRow, Shape::kColumn}); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ -+ // Move filter_s by stride_w -+ filter_s_ += problem_size_.stride_w; -+ if (filter_s_ < problem_size_.S) { -+ return; -+ } -+ -+ // Restore filter_s -+ filter_s_ = start_s_; -+ -+ // Move filter_r by stride_h -+ filter_r_ += problem_size_.stride_h; -+ if (filter_r_ < problem_size_.R) { -+ return; -+ } -+ -+ // Restore filter_r -+ filter_r_ = start_r_; -+ -+ // Move filter_k -+ filter_k_ += Shape_::kColumn * problem_size_.split_k_slices; -+ } -+ -+ /// Returns the coordinate in the output tensor Dy that is currently pointed to -+ /// by the iterator. -+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ int n = offset_n_[iteration_strided_]; -+ int p = offset_p_[iteration_strided_]; -+ int q = offset_q_[iteration_strided_]; -+ -+ int conv_sign = (problem_size_.mode == Mode::kConvolution ? 1 : -1); -+ -+ p += (conv_sign * (filter_r_ / problem_size_.stride_h)); -+ q += (conv_sign * (filter_s_ / problem_size_.stride_w)); -+ -+ int k = filter_k_ + iteration_vector_ * AccessType::kElements; -+ -+ return TensorCoord( -+ n, -+ p, -+ q, -+ k); -+ } -+ -+ -+ /// Returns true if the current coordinate is within the output tensor Dy -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ -+ TensorCoord coord = at(); -+ -+ return -+ coord.n() < problem_size_.N && -+ coord.h() >= 0 && coord.h() < problem_size_.P && -+ coord.w() >= 0 && coord.w() < problem_size_.Q && -+ coord.c() < problem_size_.K; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ LongIndex offset = params_.layout(coord); -+ -+ return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradOutputGradientTileAccessIteratorAnalytic &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.K % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Conv2dDgradOutputGradientTileAccessIteratorAnalytic for unity strides can be optimized by -+// eliminating modulo arithmetic to compute unscaled coordinates -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_, -+ typename AccessType_ -+> -+class Conv2dDgradOutputGradientTileAccessIteratorAnalytic < -+ Shape_, -+ Element_, -+ ThreadMap_, -+ conv::StrideSupport::kUnity, -+ AccessType_ -+> { -+public: -+ -+ // -+ // Types -+ // -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNHWC; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ static_assert(sizeof_bits::value >= 8, -+ "DGRAD requires elements of size 8b or greater."); -+ -+ // -+ // Simpligying assertions -+ // -+ -+ static_assert(ThreadMap::Iterations::kContiguous == 1, -+ "Require Iterations::kContiguous == 1"); -+ -+ // -+ // Parameters structure -+ // -+ -+ struct Params { -+ -+ Layout layout; -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ Conv2dProblemSize const &problem_size, -+ Layout const &layout -+ ): layout(layout) { -+ -+ } -+ }; -+ -+private: -+ -+ Params const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ char const *pointer_; -+ -+ int filter_k_; -+ int filter_r_; -+ int filter_s_; -+ -+ int offset_n_[ThreadMap::Iterations::kStrided]; -+ int offset_w_[ThreadMap::Iterations::kStrided]; -+ int offset_h_[ThreadMap::Iterations::kStrided]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradOutputGradientTileAccessIteratorAnalytic( -+ Params const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() // threadblock offset - units are whole CTA tiles -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ filter_k_(0), -+ filter_r_(0), -+ filter_s_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ filter_k_ = threadblock_offset.column() + thread_coord.contiguous(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ int offset_nhw = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; -+ -+ offset_n_[s] = offset_nhw / (problem_size_.H * problem_size_.W); -+ int residual = offset_nhw % (problem_size_.H * problem_size_.W); -+ -+ offset_h_[s] = residual / problem_size_.W; -+ offset_w_[s] = residual % problem_size_.W; -+ } -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { -+ return Params(problem_size, layout); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ // move to the next tile -+ ++filter_s_; -+ if (filter_s_ < problem_size_.S) { -+ return; -+ } -+ filter_s_ = 0; -+ ++filter_r_; -+ if (filter_r_ < problem_size_.R) { -+ return; -+ } -+ filter_r_ = 0; -+ -+ filter_k_ += Shape_::kColumn * problem_size_.split_k_slices; -+ } -+ -+ /// Returns the coordinate in the output tensor Dy that is currently pointed to -+ /// by the iterator. -+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ -+ int n = offset_n_[iteration_strided_]; -+ int h = offset_h_[iteration_strided_]; -+ int w = offset_w_[iteration_strided_]; -+ -+ int r = filter_r_; -+ int s = filter_s_; -+ -+ if (problem_size_.mode == Mode::kConvolution) { -+ r = (problem_size_.R - 1 - r); -+ s = (problem_size_.S - 1 - s); -+ } -+ -+ int p = (h + problem_size_.pad_h - r * problem_size_.dilation_h) / problem_size_.stride_h; -+ int q = (w + problem_size_.pad_w - s * problem_size_.dilation_w) / problem_size_.stride_w; -+ -+ int k = filter_k_ + iteration_vector_ * AccessType::kElements; -+ -+ return TensorCoord(n, p, q, k); -+ } -+ -+ /// Returns true if the current coordinate is within the output tensor Dy -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ -+ TensorCoord coord = at(); -+ -+ return coord.n() < problem_size_.N && -+ coord.h() >= 0 && coord.h() < problem_size_.P && -+ coord.w() >= 0 && coord.w() < problem_size_.Q && -+ coord.c() < problem_size_.K; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ LongIndex offset = params_.layout(coord); -+ -+ return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradOutputGradientTileAccessIteratorAnalytic &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // Conv2dDgradFilterTileAccessIteratorAnalytic unity stride specialization -+ // only supports (stride_h, stride_w) = (1, 1) -+ if (problem_size.stride() != MatrixCoord({1, 1})) { -+ return Status::kErrorNotSupported; -+ } -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.K % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h -new file mode 100644 -index 0000000..38d94ac ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h -@@ -0,0 +1,821 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). -+*/ -+ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/threadblock/conv2d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_, -+ conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity, -+ typename AccessType_ = cutlass::AlignedArray -+> -+class Conv2dDgradOutputGradientTileAccessIteratorOptimized; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// Conv2dDgradOutputGradientTileAccessIteratorOptimized strided dgrad needs special handling -+// to skip MMAs (Dx = Dy * w) on invalid filter positions -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_, -+ typename AccessType_ -+> -+class Conv2dDgradOutputGradientTileAccessIteratorOptimized < -+ Shape_, -+ Element_, -+ ThreadMap_, -+ conv::StrideSupport::kStrided, -+ AccessType_ -+> { -+public: -+ -+ // -+ // Types -+ // -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNHWC; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ using Mask = uint64_t; -+ -+ static_assert(sizeof_bits::value >= 8, -+ "DGRAD requires elements of size 8b or greater."); -+ -+ // -+ // Simpligying assertions -+ // -+ -+ static_assert(ThreadMap::Iterations::kContiguous == 1, -+ "Require Iterations::kContiguous == 1"); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv2dStridedDgradOutputGradientIteratorOptimizedParams; -+ -+private: -+ -+ Params const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ -+ // One pointer per access -+ char const *pointer_[ThreadMap::Iterations::kStrided]; -+ -+ int filter_k_; -+ int filter_r_; -+ int filter_s_; -+ int start_r_; -+ int start_s_; -+ int64_t reset_bytes_s_; -+ int64_t reset_bytes_r_; -+ -+ Index masks_[ThreadMap::Iterations::kStrided][kAccessesPerVector][2]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradOutputGradientTileAccessIteratorOptimized( -+ Params const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod, -+ int start_r, int start_s, -+ MatrixCoord const &threadblock_offset = MatrixCoord() // threadblock offset - units are whole CTA tiles -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ filter_k_(0), -+ filter_r_(start_r), -+ filter_s_(start_s), -+ start_r_(start_r), -+ start_s_(start_s) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ filter_k_ = threadblock_offset.column() + thread_coord.contiguous(); -+ -+ reset_bytes_s_ = (problem_size_.num_gemm_k_filter_s(start_s_) - 1) * params_.inc_next[0]; -+ -+ reset_bytes_r_ = (problem_size_.num_gemm_k_filter_s(start_s_) - 1) * params_.inc_next[0] + -+ (problem_size_.num_gemm_k_filter_r(start_r_) - 1) * params_.inc_next[1]; -+ -+ int offset_n[ThreadMap::Iterations::kStrided]; -+ int offset_p[ThreadMap::Iterations::kStrided]; -+ int offset_q[ThreadMap::Iterations::kStrided]; -+ -+ int filter_r = filter_r_; -+ int filter_s = filter_s_; -+ -+ if (problem_size_.mode == Mode::kConvolution) { -+ filter_r = (problem_size_.R - 1 - filter_r); -+ filter_s = (problem_size_.S - 1 - filter_s); -+ } -+ -+ // Starting h, w positions for filter position in gemm_k=0 -+ int start_h, start_w; -+ strided_dgrad_starting_coords( -+ problem_size_, -+ stride_h_divmod, stride_w_divmod, -+ filter_r, filter_s, -+ start_h, start_w); -+ -+ -+ // Effective starting P and Q for filter position required for remapping NHW rows -+ int P = (problem_size_.H - start_h + problem_size_.stride_h - 1) / problem_size_.stride_h; -+ int Q = (problem_size_.W - start_w + problem_size_.stride_w - 1) / problem_size_.stride_w; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ -+ pointer_[s] = reinterpret_cast(ptr); -+ -+ int offset_npq = (threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided) % params_.tiled_rows_per_filter; -+ -+ // (STEP 1) [reorder NHW rows to start with same filter positions] -+ offset_n[s] = offset_npq / (P * Q); -+ int residual = offset_npq % (P * Q); -+ -+ int p = (residual / Q); -+ int q = (residual % Q); -+ -+ int mapped_h = (start_h + p * problem_size_.stride_h); -+ int mapped_w = (start_w + q * problem_size_.stride_w); -+ -+ // Access (p, q) coordinates for Dy tensor for filter position in gemm_k=0 -+ // note that (h + pad_h - filter_r) and (w + pad_w - filter_s) are ensured to be -+ // divisible by stride_h and stride_w -+ offset_p[s] = (mapped_h + problem_size_.pad_h - filter_r) / problem_size_.stride_h; -+ offset_q[s] = (mapped_w + problem_size_.pad_w - filter_s) / problem_size_.stride_w; -+ -+ // Intialize pointers for gemm_k=0 -+ TensorCoord coord{offset_n[s], offset_p[s], offset_q[s], filter_k_}; -+ -+ pointer_[s] += params_.layout(coord) * sizeof_bits::value / 8; -+ } -+ -+ // -+ // Precompute mask predicates -+ // -+ clear_mask(); -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int r = start_r; r < problem_size_.R; r += problem_size_.stride_h) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { -+ -+ int p = offset_p[s_idx] ; -+ -+ p += (params_.conv_sign * (r / problem_size_.stride_h)); -+ -+ bool pred = (offset_n[s_idx] < problem_size_.N && p >= 0 && p < problem_size_.P); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { -+ masks_[s_idx][v_idx][0] |= (pred << r); -+ } -+ } -+ } -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for(int s = start_s; s < problem_size_.S; s += problem_size_.stride_w) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { -+ -+ int q = offset_q[s_idx]; -+ q += (params_.conv_sign * (s / problem_size_.stride_w)); -+ -+ bool pred = (q >=0 && q < problem_size_.Q); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { -+ masks_[s_idx][v_idx][1] |= (pred << s); -+ } -+ } -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { -+ clear_mask(v_idx, (filter_k_ + v_idx * AccessType::kElements) >= problem_size.K); -+ } -+ -+ set_iteration_index(0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { -+ return Params(problem_size, -+ layout, -+ sizeof_bits::value, -+ {Shape::kRow, Shape::kColumn}); -+ } -+ -+private: -+ -+ /// Adds a pointer offset in units of element -+ CUTLASS_HOST_DEVICE -+ void add_byte_offset_(LongIndex byte_offset, LongIndex byte_reset = 0) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ pointer_[s] += byte_offset - byte_reset; -+ } -+ } -+ -+public: -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ add_byte_offset_(pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ CUTLASS_DEVICE -+ void advance() { -+ -+ int next_idx = 0; -+ int64_t reset_bytes = 0; -+ -+ // Move filter_s by stride_w -+ filter_s_ += problem_size_.stride_w; -+ if (filter_s_ >= problem_size_.S) { -+ -+ // Restore filter_s -+ filter_s_ = start_s_; -+ -+ // Move filter_r by stride_h -+ filter_r_ += problem_size_.stride_h; -+#if 0 -+ if (filter_r_ < problem_size_.R) { -+ -+ next_idx = 1; -+ -+ // Restore bytes in q coordinate (Mma in filter s dimenstion) -+ reset_bytes = reset_bytes_s_; -+ -+ } else { -+ -+ // Restore filter_r -+ filter_r_ = start_r_; -+ -+ next_idx = 2; -+ -+ // Restore bytes in p and q coordinate (Mma in filter s and r dimenstion) -+ reset_bytes = reset_bytes_r_; -+ } -+#else -+ asm volatile( -+ "{\n\t" -+ " .reg .pred %%p;\n\t" -+ " setp.lt.s32 %%p, %3, %4;\n\t" -+ " selp.s32 %0, %3, %5, %%p;\n\t" -+ " selp.s32 %1, 1, 2, %%p;\n\t" -+ " selp.s64 %2, %6, %7, %%p;\n\t" -+ "}\n" -+ : "=r"(filter_r_), "=r"(next_idx), "=l"(reset_bytes) -+ : "r"(filter_r_), "r"(problem_size_.R), "r"(start_r_), -+ "l"(reset_bytes_s_), "l"(reset_bytes_r_)); -+#endif -+ } -+ -+ // offset pointers by offset_bytes -+ add_byte_offset_(params_.inc_next[next_idx] - reset_bytes); -+ -+ if (next_idx == 2) { -+ filter_k_ += params_.filter_k_delta; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { -+ clear_mask(v_idx, (filter_k_ + v_idx * AccessType::kElements) >= problem_size_.K); -+ } -+ } -+ -+ /// Clears the predicates -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool clear = true) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ masks_[s][v][0] = clear ? Mask(0) : masks_[s][v][0]; -+ masks_[s][v][1] = clear ? Mask(0) : masks_[s][v][1]; -+ } -+ } -+ } -+ -+ /// Clears the predicates -+ CUTLASS_HOST_DEVICE -+ void clear_mask(int v, bool clear = true) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ masks_[s][v][0] = clear ? Mask(0) : masks_[s][v][0]; -+ masks_[s][v][1] = clear ? Mask(0) : masks_[s][v][1]; -+ } -+ } -+ -+ /// Returns true if the current coordinate is within the output tensor Dy -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ return -+ (masks_[iteration_strided_][iteration_vector_][0] & (Index(1) << filter_r_)) && -+ (masks_[iteration_strided_][iteration_vector_][1] & (Index(1) << filter_s_)); -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ return reinterpret_cast(pointer_[iteration_strided_]) + iteration_vector_; -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradOutputGradientTileAccessIteratorOptimized &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.K % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ // Limit on filter size -+ if (problem_size.R > 32 || problem_size.S > 32) { -+ return Status::kErrorNotSupported; -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// Conv2dDgradOutputGradientTileAccessIteratorOptimized unity stride dgrad is optimized for dgrad -+// with problem stride = {1x1} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_, -+ typename AccessType_ -+> -+class Conv2dDgradOutputGradientTileAccessIteratorOptimized < -+ Shape_, -+ Element_, -+ ThreadMap_, -+ conv::StrideSupport::kUnity, -+ AccessType_ -+> { -+public: -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNHWC; -+ using TensorCoord = typename Layout::TensorCoord; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ using Mask = uint64_t; -+ -+ // -+ // Simplifying assertions -+ // -+ static_assert(ThreadMap::Iterations::kContiguous == 1, -+ "Require Iterations::kContiguous == 1"); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv2dDgradOutputGradientIteratorOptimizedParams; -+ -+private: -+ -+ Conv2dDgradOutputGradientIteratorOptimizedParams const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ -+ // One pointer per access -+ char const *pointer_[ThreadMap::Iterations::kStrided]; -+ -+ // current filter position (r, s) -+ int filter_r_; -+ int filter_s_; -+ int filter_k_; -+ -+ Index masks_[ThreadMap::Iterations::kStrided][kAccessesPerVector][2]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradOutputGradientTileAccessIteratorOptimized( -+ Conv2dDgradOutputGradientIteratorOptimizedParams const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() // tile index - units are threadblock-scoped tiles -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ filter_k_(0), -+ filter_r_(0), -+ filter_s_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ filter_k_ = threadblock_offset.column() + thread_coord.contiguous(); -+ -+ int offset_n[ThreadMap::Iterations::kStrided]; -+ int offset_h[ThreadMap::Iterations::kStrided]; -+ int offset_w[ThreadMap::Iterations::kStrided]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ -+ pointer_[s] = reinterpret_cast(ptr); -+ -+ int offset_nhw = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; -+ -+ // The subseqnet fast_divmod() operations are equivalent to the following logical computation: -+ // -+ // -+ // offset_n[s] = offset_nhw / (problem_size_.H * problem_size_.W); -+ // int residual = offset_nhw % (problem_size_.H * problem_size_.W); -+ // -+ // offset_h[s] = residual / problem_size_.W; -+ // offset_w[s] = residual % problem_size_.W; -+ // -+ -+ int residual; -+ -+ params_.hw_divmod(offset_n[s], residual, offset_nhw); -+ params_.w_divmod(offset_h[s], offset_w[s], residual); -+ -+ TensorCoord coord = at_(offset_n[s], offset_h[s], offset_w[s], 0, 0); -+ -+ pointer_[s] += params_.layout(coord) * sizeof_bits::value / 8; -+ } -+ -+ clear_mask(); -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int r = 0; r < problem_size_.R; ++r) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { -+ -+ int r_ = r; -+ if (problem_size_.mode == Mode::kConvolution) { -+ r_ = problem_size_.R - 1 - r; -+ } -+ -+ int p = offset_h[s_idx] + problem_size_.pad_h - r_ * problem_size_.dilation_h; -+ -+ bool pred = (offset_n[s_idx] < problem_size_.N && p >= 0 && p < problem_size_.P); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { -+ masks_[s_idx][v_idx][0] |= (pred << r); -+ } -+ } -+ } -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int s = 0; s < problem_size_.S; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { -+ -+ int s_ = s; -+ if (problem_size_.mode == Mode::kConvolution) { -+ s_ = problem_size_.S - 1 - s; -+ } -+ -+ int q = offset_w[s_idx] + problem_size_.pad_w - s_ * problem_size_.dilation_w; -+ -+ bool pred = (q >= 0 && q < problem_size_.Q); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { -+ masks_[s_idx][v_idx][1] |= (pred << s); -+ } -+ } -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { -+ clear_mask(v_idx, filter_k_ + v_idx * AccessType::kElements >= problem_size.K); -+ } -+ -+ set_iteration_index(0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { -+ return Params(problem_size, -+ layout, -+ sizeof_bits::value, -+ {Shape::kRow, Shape::kColumn}, -+ ThreadMap::kThreads, -+ ThreadMap::kElementsPerAccess, -+ {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, -+ {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}); -+ } -+ -+private: -+ -+ /// Returns the coordinate in the output gradient tensor dy that is correspoinding to -+ // activation nhw and filter position k, r, s -+ CUTLASS_HOST_DEVICE -+ TensorCoord at_(int n, int h, int w, int r, int s) const { -+ -+ if (problem_size_.mode == Mode::kConvolution) { -+ r = problem_size_.R - 1 - r; -+ s = problem_size_.S - 1 - s; -+ } -+ -+ int p = h + problem_size_.pad_h - r * problem_size_.dilation_h; -+ int q = w + problem_size_.pad_w - s * problem_size_.dilation_w; -+ -+ return TensorCoord(n, p, q, filter_k_); -+ } -+ -+ /// Adds a pointer offset in units of element -+ CUTLASS_HOST_DEVICE -+ void add_byte_offset_(LongIndex byte_offset) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ pointer_[s] += byte_offset; -+ } -+ } -+ -+public: -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ add_byte_offset_(pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ -+ int next_idx = 0; -+ -+ // moves to the next tile -+ ++filter_s_; -+ if (filter_s_ == problem_size_.S) { -+ filter_s_ = 0; -+ ++filter_r_; -+ -+ if (filter_r_ < problem_size_.R) { -+ next_idx = 1; -+ } -+ else { -+ filter_r_ = 0; -+ next_idx = 2; -+ } -+ } -+ -+ add_byte_offset_(params_.inc_next[next_idx]); -+ -+ if (next_idx == 2) { -+ filter_k_ += params_.filter_k_delta; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { -+ clear_mask(v_idx, (filter_k_ + v_idx * AccessType::kElements) >= problem_size_.K); -+ } -+ } -+ -+ /// Clears the predicates -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool clear = true) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ masks_[s][v][0] = clear ? Mask(0) : masks_[s][v][0]; -+ masks_[s][v][1] = clear ? Mask(0) : masks_[s][v][1]; -+ } -+ } -+ } -+ -+ /// Clears the predicates -+ CUTLASS_HOST_DEVICE -+ void clear_mask(int v, bool clear = true) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ masks_[s][v][0] = clear ? Mask(0) : masks_[s][v][0]; -+ masks_[s][v][1] = clear ? Mask(0) : masks_[s][v][1]; -+ } -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ -+ return -+ (masks_[iteration_strided_][iteration_vector_][0] & (Index(1) << filter_r_)) && -+ (masks_[iteration_strided_][iteration_vector_][1] & (Index(1) << filter_s_)); -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ return reinterpret_cast(pointer_[iteration_strided_]) + iteration_vector_; -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradOutputGradientTileAccessIteratorOptimized &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // This is specialized for unit stride -+ if (problem_size.stride() != MatrixCoord({1, 1})) { -+ return Status::kErrorNotSupported; -+ } -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.K % AccessType::kElements) { -+ return Status::kErrorNotSupported; -+ } -+ -+ // Limit on filter size -+ if (problem_size.R > 32 || problem_size.S > 32) { -+ return Status::kErrorNotSupported; -+ } -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h -new file mode 100644 -index 0000000..e667ddd ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h -@@ -0,0 +1,332 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC or TensorNCxHWx layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/threadblock/conv2d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename Layout_, -+ typename ThreadMap_, -+ typename AccessType_ = cutlass::AlignedArray, -+ conv::GroupMode GroupMode_ = conv::GroupMode::kNone -+> -+class Conv2dFpropActivationTileAccessIteratorAnalytic { -+public: -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = Layout_; -+ using TensorCoord = typename Layout::TensorCoord; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ static conv::GroupMode const kGroupMode = GroupMode_; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ // -+ // Simplifying assertions -+ // -+ static_assert(ThreadMap::Iterations::kContiguous == 1, -+ "Require Iterations::kContiguous == 1"); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv2dAnalyticParams; -+ -+private: -+ -+ Params const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ char const *pointer_; -+ -+ int filter_c_; -+ int filter_r_; -+ int filter_s_; -+ int filter_c_init_; -+ int group_idx_offset_; -+ int channels_per_group_; -+ int crs_cnt_; -+ int crs_per_group_; -+ -+ int offset_n_[ThreadMap::Iterations::kStrided]; -+ int offset_p_[ThreadMap::Iterations::kStrided]; -+ int offset_q_[ThreadMap::Iterations::kStrided]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropActivationTileAccessIteratorAnalytic( -+ Params const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() // tile index - units are threadblock-scoped tiles -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ crs_cnt_(0), -+ group_idx_offset_(0), -+ filter_c_(0), -+ filter_r_(0), -+ filter_s_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ filter_c_ = threadblock_offset.column() + thread_coord.contiguous(); -+ -+ if (kGroupMode != conv::GroupMode::kNone) { -+ filter_c_init_ = filter_c_; -+ channels_per_group_ = problem_size_.C / problem_size_.groups; -+ crs_per_group_ = problem_size_.S * problem_size_.R * ((channels_per_group_ + Shape::kColumn - 1) / Shape::kColumn); -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ int offset_npq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; -+ -+ offset_n_[s] = offset_npq / (problem_size_.P * problem_size_.Q); -+ int residual = offset_npq % (problem_size_.P * problem_size_.Q); -+ -+ offset_p_[s] = residual / problem_size_.Q; -+ offset_q_[s] = residual % problem_size_.Q; -+ } -+ -+ set_iteration_index(0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { -+ return Params(problem_size, layout); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ // moves to the next tile -+ if (kGroupMode != conv::GroupMode::kNone) { -+ ++crs_cnt_; -+ } -+ -+ ++filter_s_; -+ if (filter_s_ < problem_size_.S) { -+ return; -+ } -+ filter_s_ = 0; -+ ++filter_r_; -+ if (filter_r_ < problem_size_.R) { -+ return; -+ } -+ filter_r_ = 0; -+ -+ if (kGroupMode == conv::GroupMode::kNone) { -+ filter_c_ += Shape::kColumn * problem_size_.split_k_slices; -+ } else { -+ if (crs_cnt_ == crs_per_group_) { -+ // moves to next group -+ crs_cnt_ = 0; -+ ++group_idx_offset_; -+ filter_c_ = group_idx_offset_ * channels_per_group_ + filter_c_init_; -+ } else { -+ filter_c_ += Shape::kColumn * problem_size_.split_k_slices; -+ } -+ } -+ } -+ -+ /// Returns the coordinate in the activations tensor X that is currently pointed to -+ /// by the iterator. -+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ int n = offset_n_[iteration_strided_]; -+ int p = offset_p_[iteration_strided_]; -+ int q = offset_q_[iteration_strided_]; -+ -+ int r = filter_r_; -+ int s = filter_s_; -+ -+ if (problem_size_.mode == Mode::kConvolution) { -+ r = (problem_size_.R - 1 - filter_r_); -+ s = (problem_size_.S - 1 - filter_s_); -+ } -+ -+ int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h; -+ int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w; -+ -+ int c = filter_c_ + iteration_vector_ * AccessType::kElements; -+ -+ return TensorCoord(n, h, w, c); -+ } -+ -+ /// Returns true if the current coordinate is within the activations tensor X -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ -+ TensorCoord coord = at(); -+ -+ return coord.n() < problem_size_.N && -+ coord.h() >= 0 && coord.h() < problem_size_.H && -+ coord.w() >= 0 && coord.w() < problem_size_.W && -+ coord.c() < problem_size_.C; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ LongIndex offset = params_.layout(coord); -+ -+ AccessType const *ptr = reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ -+ return ptr; -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropActivationTileAccessIteratorAnalytic &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (platform::is_same>::value) { -+ if (problem_size.C % 32) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ if (platform::is_same>::value) { -+ if (problem_size.C % 64) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h -new file mode 100644 -index 0000000..1b668ce ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h -@@ -0,0 +1,360 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC or TensorNCxHWx layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/threadblock/conv2d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename Layout_, -+ typename ThreadMap_, -+ typename AccessType_ = cutlass::AlignedArray -+> -+class Conv2dFpropActivationTileAccessIteratorFewChannels { -+public: -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = Layout_; -+ using TensorCoord = typename Layout::TensorCoord; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kFewChannels; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ static int const kPositionsPerTile = Shape::kColumn; -+ -+ static int const kAccessesPerVector = kElementsPerAccess / AccessType::kElements; -+ -+ static bool const kUseFastDivmodPrologue = true; -+ static bool const kUseFastDivmodMainloop = true; -+ -+ static int const kStrideH = 0; -+ static int const kStrideW = 0; -+ static int const kDilationH = 0; -+ static int const kDilationW = 0; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ // -+ // Simplifying assertions -+ // -+ static_assert(ThreadMap::Iterations::kContiguous == 1, -+ "Require Iterations::kContiguous == 1"); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv2dFewChannelsParams; -+ -+private: -+ -+ Params const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ char const *pointer_; -+ -+ int rsc_index_; -+ int offset_n_[ThreadMap::Iterations::kStrided]; -+ int offset_p_[ThreadMap::Iterations::kStrided]; -+ int offset_q_[ThreadMap::Iterations::kStrided]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropActivationTileAccessIteratorFewChannels( -+ Params const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() // tile index - units are threadblock-scoped tiles -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ rsc_index_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ rsc_index_ = (threadblock_offset.column() + thread_coord.contiguous()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ int offset_npq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; -+ -+ if (kUseFastDivmodPrologue) { -+ int residual = params_.divmod_Q.divmod(offset_q_[s], offset_npq); -+ offset_n_[s] = params_.divmod_P.divmod(offset_p_[s], residual); -+ } -+ else { -+ offset_n_[s] = offset_npq / (problem_size_.P * problem_size_.Q); -+ int residual = offset_npq % (problem_size_.P * problem_size_.Q); -+ -+ offset_p_[s] = residual / problem_size_.Q; -+ offset_q_[s] = residual % problem_size_.Q; -+ } -+ } -+ -+ set_iteration_index(0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { -+ return Params(problem_size, layout); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ -+ rsc_index_ += kPositionsPerTile * problem_size_.split_k_slices; -+ } -+ -+ /// Returns the coordinate in the activations tensor X that is currently pointed to -+ /// by the iterator. -+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ int n = offset_n_[iteration_strided_]; -+ int p = offset_p_[iteration_strided_]; -+ int q = offset_q_[iteration_strided_]; -+ -+ int rsc_index = rsc_index_ + iteration_vector_ * AccessType::kElements; -+ -+ int r = 0; -+ int s = 0; -+ int c = 0; -+ -+ if (kUseFastDivmodMainloop) { -+ int rs_index = params_.divmod_C.divmod(c, rsc_index); -+ r = params_.divmod_S.divmod(s, rs_index); -+ } -+ else { -+ c = (rsc_index % problem_size_.C); -+ -+ int rs_index = (rsc_index / problem_size_.C); -+ s = (rs_index % problem_size_.S); -+ r = (rs_index / problem_size_.S); -+ } -+ -+ if (problem_size_.mode == Mode::kConvolution) { -+ r = (problem_size_.R - 1 - r); -+ s = (problem_size_.S - 1 - s); -+ } -+ -+ int stride_h = kStrideH; -+ if (!kStrideH) { -+ stride_h = problem_size_.stride_h; -+ } -+ -+ int stride_w = kStrideW; -+ if (!kStrideW) { -+ stride_w = problem_size_.stride_w; -+ } -+ -+ int dilation_h = kDilationH; -+ if (!kDilationH) { -+ dilation_h = problem_size_.dilation_h; -+ } -+ -+ int dilation_w = kDilationW; -+ if (!kDilationW) { -+ dilation_w = problem_size_.dilation_w; -+ } -+ -+ int h = p * stride_h - problem_size_.pad_h + r * dilation_h; -+ int w = q * stride_w - problem_size_.pad_w + s * dilation_w; -+ -+ return TensorCoord(n, h, w, c); -+ } -+ -+ /// Returns true if the current coordinate is within the activations tensor X -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ -+ TensorCoord coord = at(); -+ -+ bool in_bounds = -+ coord.n() < problem_size_.N && -+ coord.h() >= 0 && coord.h() < problem_size_.H && -+ coord.w() >= 0 && coord.w() < problem_size_.W && -+ coord.c() < problem_size_.C; -+ -+ return in_bounds; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ -+ int32_t offset = -+ coord.n() * params_.stride_n + -+ coord.h() * params_.stride_h + -+ coord.w() * params_.stride_w + -+ coord.c(); -+ -+ AccessType const *ptr = reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ -+ return ptr; -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropActivationTileAccessIteratorFewChannels &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (kDilationH && problem_size.dilation_h != kDilationH) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (kDilationW && problem_size.dilation_w != kDilationW) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (kStrideH && problem_size.stride_h != kStrideH) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (kStrideW && problem_size.stride_w != kStrideW) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (platform::is_same>::value) { -+ if (problem_size.C % 32) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ if (platform::is_same>::value) { -+ if (problem_size.C % 64) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h -new file mode 100644 -index 0000000..3e680f4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h -@@ -0,0 +1,353 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC or TensorNCxHWx layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/threadblock/conv2d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename Layout_, -+ typename ThreadMap_, -+ typename AccessType_ = cutlass::AlignedArray -+> -+class Conv2dFpropActivationTileAccessIteratorFixedChannels { -+public: -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = Layout_; -+ using TensorCoord = typename Layout::TensorCoord; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kFixedChannels; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ static int const kFilterPositionsPerTile = Shape::kColumn / AccessType::kElements; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static bool const kUseFastDivmodPrologue = true; -+ static bool const kUseFastDivmodMainloop = true; -+ -+ static int const kStrideH = 0; -+ static int const kStrideW = 0; -+ static int const kDilationH = 0; -+ static int const kDilationW = 0; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ // -+ // Simplifying assertions -+ // -+ static_assert(ThreadMap::Iterations::kContiguous == 1, -+ "Require Iterations::kContiguous == 1"); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv2dFewChannelsParams; -+ -+private: -+ -+ Params const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ char const *pointer_; -+ -+ int rs_index_; -+ int offset_n_[ThreadMap::Iterations::kStrided]; -+ int offset_p_[ThreadMap::Iterations::kStrided]; -+ int offset_q_[ThreadMap::Iterations::kStrided]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropActivationTileAccessIteratorFixedChannels( -+ Params const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() // tile index - units are threadblock-scoped tiles -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ rs_index_(0) { -+ -+ // -+ // This requires problem_size.C == AccessType::kElements -+ // -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ rs_index_ = (threadblock_offset.column() + thread_coord.contiguous()) / AccessType::kElements; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ int offset_npq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; -+ -+ if (kUseFastDivmodPrologue) { -+ int residual = params_.divmod_Q.divmod(offset_q_[s], offset_npq); -+ offset_n_[s] = params_.divmod_P.divmod(offset_p_[s], residual); -+ } -+ else { -+ offset_n_[s] = offset_npq / (problem_size_.P * problem_size_.Q); -+ int residual = offset_npq % (problem_size_.P * problem_size_.Q); -+ -+ offset_p_[s] = residual / problem_size_.Q; -+ offset_q_[s] = residual % problem_size_.Q; -+ } -+ } -+ -+ set_iteration_index(0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { -+ return Params(problem_size, layout); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ -+ rs_index_ += kFilterPositionsPerTile * problem_size_.split_k_slices; -+ } -+ -+ /// Returns the coordinate in the activations tensor X that is currently pointed to -+ /// by the iterator. -+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ int n = offset_n_[iteration_strided_]; -+ int p = offset_p_[iteration_strided_]; -+ int q = offset_q_[iteration_strided_]; -+ -+ int rs_index = rs_index_ + iteration_vector_; -+ -+ int r = 0; -+ int s = 0; -+ -+ if (kUseFastDivmodMainloop) { -+ r = params_.divmod_S.divmod(s, rs_index); -+ } -+ else { -+ s = (rs_index % problem_size_.S); -+ r = (rs_index / problem_size_.S); -+ } -+ -+ if (problem_size_.mode == Mode::kConvolution) { -+ r = (problem_size_.R - 1 - r); -+ s = (problem_size_.S - 1 - s); -+ } -+ -+ int stride_h = kStrideH; -+ if (!kStrideH) { -+ stride_h = problem_size_.stride_h; -+ } -+ -+ int stride_w = kStrideW; -+ if (!kStrideW) { -+ stride_w = problem_size_.stride_w; -+ } -+ -+ int dilation_h = kDilationH; -+ if (!kDilationH) { -+ dilation_h = problem_size_.dilation_h; -+ } -+ -+ int dilation_w = kDilationW; -+ if (!kDilationW) { -+ dilation_w = problem_size_.dilation_w; -+ } -+ -+ int h = p * stride_h - problem_size_.pad_h + r * dilation_h; -+ int w = q * stride_w - problem_size_.pad_w + s * dilation_w; -+ -+ return TensorCoord(n, h, w, 0); -+ } -+ -+ /// Returns true if the current coordinate is within the activations tensor X -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ -+ TensorCoord coord = at(); -+ -+ return coord.n() < problem_size_.N && -+ coord.h() >= 0 && coord.h() < problem_size_.H && -+ coord.w() >= 0 && coord.w() < problem_size_.W; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ -+ int32_t offset = -+ coord.n() * params_.stride_n + -+ coord.h() * params_.stride_h + -+ coord.w() * params_.stride_w + coord.c(); -+ -+ AccessType const *ptr = reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ -+ return ptr; -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropActivationTileAccessIteratorFixedChannels &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C != AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (kDilationH && problem_size.dilation_h != kDilationH) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (kDilationW && problem_size.dilation_w != kDilationW) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (kStrideH && problem_size.stride_h != kStrideH) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (kStrideW && problem_size.stride_w != kStrideW) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (platform::is_same>::value) { -+ if (problem_size.C % 32) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ if (platform::is_same>::value) { -+ if (problem_size.C % 64) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h -new file mode 100644 -index 0000000..fb1fcfc ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h -@@ -0,0 +1,422 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC or TensorNCxHWx layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/threadblock/conv2d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename Layout_, -+ typename ThreadMap_, -+ typename AccessType_ = cutlass::AlignedArray -+> -+class Conv2dFpropActivationTileAccessIteratorOptimized { -+public: -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = Layout_; -+ using TensorCoord = typename Layout::TensorCoord; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ using Mask = uint64_t; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ // -+ // Simplifying assertions -+ // -+ static_assert(ThreadMap::Iterations::kContiguous == 1, -+ "Require Iterations::kContiguous == 1"); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv2dFpropActivationIteratorOptimizedParams; -+ -+private: -+ -+ Params const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ -+ // One pointer per access -+ char const *pointer_[ThreadMap::Iterations::kStrided]; -+ -+ // current filter position (r, s) -+ int filter_r_; -+ int filter_s_; -+ int filter_c_; -+ -+ Index masks_[ThreadMap::Iterations::kStrided][kAccessesPerVector][2]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropActivationTileAccessIteratorOptimized( -+ Params const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() // tile index - units are threadblock-scoped tiles -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ filter_c_(0), -+ filter_r_(0), -+ filter_s_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ filter_c_ = threadblock_offset.column() + thread_coord.contiguous(); -+ -+ int offset_n[ThreadMap::Iterations::kStrided]; -+ int offset_p[ThreadMap::Iterations::kStrided]; -+ int offset_q[ThreadMap::Iterations::kStrided]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ -+ pointer_[s] = reinterpret_cast(ptr); -+ -+ int offset_npq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; -+ -+ // The subseqnet fast_divmod() operations are equivalent to the following logical computation: -+ // -+ // -+ // offset_n[s] = offset_npq / (problem_size_.P * problem_size_.Q); -+ // int residual = offset_npq % (problem_size_.P * problem_size_.Q); -+ // -+ // offset_p[s] = residual / problem_size_.Q; -+ // offset_q[s] = residual % problem_size_.Q; -+ // -+ -+ int residual; -+ -+ params.pq_divmod(offset_n[s], residual, offset_npq); -+ params.q_divmod(offset_p[s], offset_q[s], residual); -+ -+ TensorCoord coord = at_(offset_n[s], offset_p[s], offset_q[s], 0, 0); -+ -+ pointer_[s] += params_.layout(coord) * sizeof_bits::value / 8; -+ } -+ -+ clear_mask(); -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int r = 0; r < problem_size_.R; ++r) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { -+ -+ int r_ = r; -+ if (problem_size_.mode == Mode::kConvolution) { -+ r_ = problem_size_.R - 1 - r; -+ } -+ -+ int h = offset_p[s_idx] * problem_size_.stride_h - problem_size_.pad_h + r_ * problem_size_.dilation_h; -+ -+ bool pred = (offset_n[s_idx] < problem_size_.N && h >= 0 && h < problem_size_.H); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { -+ masks_[s_idx][v_idx][0] |= (pred << r); -+ } -+ } -+ } -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int s = 0; s < problem_size_.S; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { -+ -+ int s_ = s; -+ if (problem_size_.mode == Mode::kConvolution) { -+ s_ = problem_size_.S - 1 - s; -+ } -+ -+ int w = offset_q[s_idx] * problem_size_.stride_w - problem_size_.pad_w + s_ * problem_size_.dilation_w; -+ -+ bool pred = (w >= 0 && w < problem_size_.W); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { -+ masks_[s_idx][v_idx][1] |= (pred << s); -+ } -+ } -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { -+ clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= problem_size_.C); -+ } -+ -+ set_iteration_index(0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { -+ return Params(problem_size, -+ layout, -+ sizeof_bits::value, -+ {Shape::kRow, Shape::kColumn}, -+ ThreadMap::kThreads, -+ ThreadMap::kElementsPerAccess, -+ {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, -+ {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}); -+ } -+ -+private: -+ -+ /// Returns the coordinate in the activations tensor X that is correspoinding to -+ // output npq and filter position r, s -+ CUTLASS_HOST_DEVICE -+ TensorCoord at_(int n, int p, int q, int r, int s) const { -+ -+ if (problem_size_.mode == Mode::kConvolution) { -+ r = problem_size_.R - 1 - r; -+ s = problem_size_.S - 1 - s; -+ } -+ -+ int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h; -+ int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w; -+ -+ return TensorCoord(n, h, w, filter_c_); -+ } -+ -+ /// Adds a pointer offset in units of element -+ CUTLASS_HOST_DEVICE -+ void add_byte_offset_(LongIndex byte_offset) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ pointer_[s] += byte_offset; -+ } -+ } -+ -+public: -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ add_byte_offset_(pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ -+ int next_idx = 0; -+ -+ // moves to the next tile -+ ++filter_s_; -+ if (filter_s_ == problem_size_.S) { -+ filter_s_ = 0; -+ ++filter_r_; -+ -+ if (filter_r_ < problem_size_.R) { -+ next_idx = 1; -+ } -+ else { -+ filter_r_ = 0; -+ next_idx = 2; -+ } -+ } -+ -+ add_byte_offset_(params_.inc_next[next_idx]); -+ -+ if (next_idx == 2) { -+ filter_c_ += params_.filter_c_delta; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { -+ clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= problem_size_.C); -+ } -+ } -+ -+ /// Clears the predicates -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool clear = true) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ masks_[s][v][0] = clear ? 0 : masks_[s][v][0]; -+ masks_[s][v][1] = clear ? 0 : masks_[s][v][1]; -+ } -+ } -+ } -+ -+ /// Clears the predicates -+ CUTLASS_HOST_DEVICE -+ void clear_mask(int v, bool clear = true) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ masks_[s][v][0] = clear ? 0 : masks_[s][v][0]; -+ masks_[s][v][1] = clear ? 0 : masks_[s][v][1]; -+ } -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ -+ return -+ (masks_[iteration_strided_][iteration_vector_][0] & (Index(1) << filter_r_)) && -+ (masks_[iteration_strided_][iteration_vector_][1] & (Index(1) << filter_s_)); -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ return reinterpret_cast(pointer_[iteration_strided_]) + iteration_vector_; -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropActivationTileAccessIteratorOptimized &operator++() { -+ -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (platform::is_same>::value) { -+ if (problem_size.C % 32) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ if (platform::is_same>::value) { -+ if (problem_size.C % 64) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ // Conv2dFpropActivationTileAccessIteratorOptimized has constraint on filter positions -+ // due to the number of mask bits. -+ if (problem_size.R > 32 || problem_size.S > 32) { -+ return Status::kErrorNotSupported; -+ } -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h -new file mode 100644 -index 0000000..5c7dbd7 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h -@@ -0,0 +1,319 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC or TensorCxRSKx layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/threadblock/conv2d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename Layout_, -+ typename ThreadMap_, -+ typename AccessType_ = cutlass::AlignedArray, -+ conv::GroupMode GroupMode_ = conv::GroupMode::kNone -+> -+class Conv2dFpropFilterTileAccessIteratorAnalytic { -+public: -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = Layout_; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ static conv::GroupMode const kGroupMode = GroupMode_; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ // -+ // Simplifying assertions -+ // -+ static_assert(ThreadMap::Iterations::kContiguous == 1, -+ "Require Iterations::kContiguous == 1"); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv2dAnalyticParams; -+ -+private: -+ -+ Params const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ char const *pointer_; -+ -+ int filter_r_; -+ int filter_s_; -+ int filter_c_; -+ int filter_c_init_; -+ int crs_cnt_; -+ int crs_per_group_; -+ int group_idx_offset_c_; -+ int channels_per_group_; -+ -+ int offset_k_[ThreadMap::Iterations::kStrided]; -+ int group_idx_offset_k_[ThreadMap::Iterations::kStrided]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropFilterTileAccessIteratorAnalytic( -+ Params const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ crs_cnt_(0), -+ group_idx_offset_c_(0), -+ filter_r_(0), -+ filter_s_(0), -+ filter_c_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ filter_c_ = threadblock_offset.row() + thread_coord.contiguous(); -+ -+ if (kGroupMode != conv::GroupMode::kNone) { -+ filter_c_init_ = filter_c_; -+ if (kGroupMode == conv::GroupMode::kDepthwise){ -+ channels_per_group_ = 1; -+ crs_per_group_ = problem_size_.S * problem_size_.R; -+ } else { -+ channels_per_group_ = problem_size_.C / problem_size_.groups; -+ crs_per_group_ = problem_size_.S * problem_size_.R * ((channels_per_group_ + Shape::kRow - 1) / Shape::kRow); -+ } -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ offset_k_[s] = threadblock_offset.column() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; -+ if (kGroupMode != conv::GroupMode::kNone && kGroupMode != conv::GroupMode::kDepthwise) { -+ group_idx_offset_k_[s] = (thread_coord.strided() + s * ThreadMap::Delta::kStrided) / (problem_size_.K / problem_size_.groups); -+ } -+ } -+ -+ set_iteration_index(0); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * 8 / sizeof_bits::value; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ // moves to the next tile -+ if (kGroupMode != conv::GroupMode::kNone) { -+ ++crs_cnt_; -+ } -+ -+ ++filter_s_; -+ if (filter_s_ < problem_size_.S) { -+ return; -+ } -+ filter_s_ = 0; -+ -+ ++filter_r_; -+ if (filter_r_ < problem_size_.R) { -+ return; -+ } -+ filter_r_ = 0; -+ -+ if (kGroupMode == conv::GroupMode::kNone) { -+ filter_c_ += Shape::kRow * problem_size_.split_k_slices; -+ } else { -+ if (crs_cnt_ == crs_per_group_) { -+ crs_cnt_ = 0; -+ filter_c_ = filter_c_init_; -+ if (kGroupMode != conv::GroupMode::kDepthwise) { -+ // moves to next group -+ ++group_idx_offset_c_; -+ } -+ } else { -+ filter_c_ += Shape::kRow * problem_size_.split_k_slices; -+ } -+ } -+ } -+ -+ /// Returns the coordinate in the filter tensor W that is currently pointed to -+ /// by the iterator. -+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ -+ int k = offset_k_[iteration_strided_]; -+ int c = filter_c_ + iteration_vector_ * AccessType::kElements; -+ -+ return TensorCoord(k, filter_r_, filter_s_, c); -+ } -+ -+ /// Returns true if the current coordinate is within the activations tensor W -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ -+ TensorCoord coord = at(); -+ -+ if (kGroupMode == conv::GroupMode::kNone) { -+ return coord.n() < problem_size_.K && coord.c() < problem_size_.C; -+ } else if (kGroupMode == conv::GroupMode::kDepthwise) { -+ return coord.n() < problem_size_.K && coord.c() < 1; // channels_per_group_ is always equal to ONE. -+ } else { -+ return coord.n() < problem_size_.K && coord.c() < channels_per_group_ && -+ group_idx_offset_c_ == group_idx_offset_k_[iteration_strided_]; -+ } -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ LongIndex offset = params_.layout(coord); -+ -+ return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropFilterTileAccessIteratorAnalytic &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.K % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (platform::is_same>::value) { -+ if (problem_size.K % 32) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ if (platform::is_same>::value) { -+ if (problem_size.K % 64) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h -new file mode 100644 -index 0000000..f0a3219 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h -@@ -0,0 +1,289 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC or TensorCxRSKx layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/threadblock/conv2d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename Layout_, -+ typename ThreadMap_, -+ typename AccessType_ = cutlass::AlignedArray -+> -+class Conv2dFpropFilterTileAccessIteratorFewChannels { -+public: -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = Layout_; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kFewChannels; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ static int const kPositionsPerTile = Shape::kRow; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static bool const kUseFastDivmodPrologue = true; -+ static bool const kUseFastDivmodMainloop = true; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ // -+ // Simplifying assertions -+ // -+ static_assert(ThreadMap::Iterations::kContiguous == 1, -+ "Require Iterations::kContiguous == 1"); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv2dFewChannelsParams; -+ -+private: -+ -+ Params const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ char const *pointer_; -+ -+ int rsc_index_; -+ -+ int offset_k_[ThreadMap::Iterations::kStrided]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropFilterTileAccessIteratorFewChannels( -+ Params const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ rsc_index_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ rsc_index_ = (threadblock_offset.row() + thread_coord.contiguous()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ offset_k_[s] = threadblock_offset.column() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; -+ } -+ -+ set_iteration_index(0); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * 8 / sizeof_bits::value; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ // moves to the next tile -+ rsc_index_ += kPositionsPerTile * problem_size_.split_k_slices; -+ } -+ -+ /// Returns the coordinate in the filter tensor W that is currently pointed to -+ /// by the iterator. -+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ -+ int rsc_index = rsc_index_ + iteration_vector_ * AccessType::kElements; -+ -+ int c = 0; -+ int s = 0; -+ int r = 0; -+ -+ if (kUseFastDivmodMainloop) { -+ int rs_index = params_.divmod_C.divmod(c, rsc_index); -+ r = params_.divmod_S.divmod(s, rs_index); -+ } -+ else { -+ c = (rsc_index % problem_size_.C); -+ int rs_index = (rsc_index / problem_size_.C); -+ -+ s = (rs_index % problem_size_.S); -+ r = (rs_index / problem_size_.S); -+ } -+ -+ int k = offset_k_[iteration_strided_]; -+ -+ return TensorCoord(k, r, s, c); -+ } -+ -+ /// Returns true if the current coordinate is within the activations tensor W -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ -+ TensorCoord coord = at(); -+ -+ bool in_bounds = -+ coord.n() < problem_size_.K && -+ coord.h() >= 0 && -+ coord.h() < problem_size_.R && -+ coord.c() < problem_size_.C; -+ -+ return in_bounds; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ -+ int32_t offset = -+ coord.n() * params_.stride_n + -+ coord.h() * params_.stride_h + -+ coord.w() * params_.stride_w + -+ coord.c(); -+ -+ return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropFilterTileAccessIteratorFewChannels &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (platform::is_same>::value) { -+ if (problem_size.K % 32) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ if (platform::is_same>::value) { -+ if (problem_size.K % 64) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h -new file mode 100644 -index 0000000..6536f62 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h -@@ -0,0 +1,275 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC or TensorCxRSKx layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/threadblock/conv2d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename Layout_, -+ typename ThreadMap_, -+ typename AccessType_ = cutlass::AlignedArray -+> -+class Conv2dFpropFilterTileAccessIteratorFixedChannels { -+public: -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = Layout_; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kFixedChannels; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ static int const kFilterPositionsPerTile = Shape::kRow / AccessType::kElements; -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static bool const kUseFastDivmodPrologue = true; -+ static bool const kUseFastDivmodMainloop = true; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ // -+ // Simplifying assertions -+ // -+ static_assert(ThreadMap::Iterations::kContiguous == 1, -+ "Require Iterations::kContiguous == 1"); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv2dFewChannelsParams; -+ -+private: -+ -+ Params const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ char const *pointer_; -+ -+ int rs_index_; -+ -+ int offset_k_[ThreadMap::Iterations::kStrided]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropFilterTileAccessIteratorFixedChannels( -+ Params const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ rs_index_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ rs_index_ = (threadblock_offset.row() + thread_coord.contiguous()) / AccessType::kElements; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ offset_k_[s] = threadblock_offset.column() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; -+ } -+ -+ set_iteration_index(0); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * 8 / sizeof_bits::value; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ // moves to the next tile -+ rs_index_ += kFilterPositionsPerTile * problem_size_.split_k_slices; -+ } -+ -+ /// Returns the coordinate in the filter tensor W that is currently pointed to -+ /// by the iterator. -+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ -+ int rs_index = rs_index_ + iteration_vector_; -+ -+ int r = 0; -+ int s = 0; -+ -+ if (kUseFastDivmodMainloop) { -+ r = params_.divmod_S.divmod(s, rs_index); -+ } -+ else { -+ s = (rs_index % problem_size_.S); -+ r = (rs_index / problem_size_.S); -+ } -+ -+ int k = offset_k_[iteration_strided_]; -+ -+ return TensorCoord(k, r, s, 0); -+ } -+ -+ /// Returns true if the current coordinate is within the activations tensor W -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ -+ TensorCoord coord = at(); -+ -+ return coord.n() < problem_size_.K && coord.h() >= 0 && coord.h() < problem_size_.R; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ -+ int32_t offset = -+ coord.n() * params_.stride_n + -+ coord.h() * params_.stride_h + -+ coord.w() * params_.stride_w + coord.c(); -+ -+ return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropFilterTileAccessIteratorFixedChannels &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C != AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (platform::is_same>::value) { -+ if (problem_size.K % 32) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ if (platform::is_same>::value) { -+ if (problem_size.K % 64) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h -new file mode 100644 -index 0000000..a85c620 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h -@@ -0,0 +1,317 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC or TensorCxRSKx layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+ -+#include "cutlass/conv/threadblock/conv2d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename Layout_, -+ typename ThreadMap_, -+ typename AccessType_ = cutlass::AlignedArray -+> -+class Conv2dFpropFilterTileAccessIteratorOptimized{ -+public: -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = Layout_; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ // -+ // Simplifying assertions -+ // -+ static_assert(ThreadMap::Iterations::kContiguous == 1, -+ "Require Iterations::kContiguous == 1"); -+ -+ // -+ // Parameters structure -+ // -+ -+ struct Params : Conv2dFpropFilterIteratorOptimizedParams { -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Conv2dFpropFilterIteratorOptimizedParams const &base): -+ Conv2dFpropFilterIteratorOptimizedParams(base) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ Conv2dProblemSize const &problem_size, -+ Layout const &layout -+ ): -+ Conv2dFpropFilterIteratorOptimizedParams( -+ problem_size, -+ layout, -+ sizeof_bits::value, -+ {Shape::kRow, Shape::kColumn}, -+ ThreadMap::kThreads, -+ ThreadMap::kElementsPerAccess, -+ {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, -+ {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided} -+ ) { -+ -+ } -+ }; -+ -+private: -+ -+ Conv2dFpropFilterIteratorOptimizedParams const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ char const *pointer_; -+ -+ uint32_t predicates_[kAccessesPerVector]; -+ int filter_rs_; -+ int filter_c_; -+ int channels_per_group_; -+ -+ // -+ // Assertions -+ // -+ -+ // We map predicates into bits packed in this uint32_t container -+ static_assert(ThreadMap::Iterations::kStrided < sizeof(predicates_) * 8, -+ "Currently, the number of loads per iteration is limited by the size of the predicates container."); -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropFilterTileAccessIteratorOptimized( -+ Conv2dFpropFilterIteratorOptimizedParams const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ predicates_{0}, -+ filter_rs_(0), -+ filter_c_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ filter_c_ = threadblock_offset.row() + thread_coord.contiguous(); -+ Index column = threadblock_offset.column() + thread_coord.strided(); -+ channels_per_group_ = problem_size_.C / problem_size_.groups; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ uint32_t pred = ((column + s * ThreadMap::Delta::kStrided < problem_size_.K) ? 1u : 0); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { -+ predicates_[v_idx] |= (pred << s); -+ } -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { -+ clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= channels_per_group_); -+ } -+ -+ pointer_ += ( -+ params_.layout({filter_c_, column}) -+ ) * sizeof_bits::value / 8; -+ -+ set_iteration_index(0); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ -+ LongIndex next = params_.inc_next_rs; -+ -+ // moves to the next tile -+ ++filter_rs_; -+ if (filter_rs_ == params_.RS) { -+ -+ filter_rs_ = 0; -+ next = params_.inc_next_c; -+ filter_c_ += params_.filter_c_delta; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { -+ clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= channels_per_group_); -+ } -+ -+ pointer_ += next; -+ } -+ -+ /// Clears the predicates -+ CUTLASS_HOST_DEVICE -+ void clear_mask(int v, bool clear = true) { -+ predicates_[v] = clear ? 0u : predicates_[v]; -+ } -+ -+ /// Returns true if the current coordinate is within the filter tensor W -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return (predicates_[iteration_vector_] & (1u << iteration_strided_)); -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ return reinterpret_cast(pointer_) + iteration_vector_; -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropFilterTileAccessIteratorOptimized &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ -+ // Move to the next K coordinate within the tile -+ pointer_ += params_.inc_next_k; -+ -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (platform::is_same>::value) { -+ if (problem_size.K % 32) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ if (platform::is_same>::value) { -+ if (problem_size.K % 64) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_params.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_params.h -new file mode 100644 -index 0000000..d96dee8 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_params.h -@@ -0,0 +1,893 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief Extracts the host-params objects into non-template code. -+*/ -+ -+#pragma once -+ -+#define TRACE_CONV_PARAMS_INITIALIZERS_ENABLED 0 -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+ -+#if TRACE_CONV_PARAMS_INITIALIZERS_ENABLED -+#include -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Params structure used for all Conv2d analytic tile iterators -+template< typename Layout_ = layout::TensorNHWC > -+struct Conv2dAnalyticParams { -+ -+ using Layout = Layout_; -+ -+ Layout layout; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dAnalyticParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dAnalyticParams( -+ Conv2dProblemSize const &, // unused; placeholder to match other Params interfaces. -+ Layout const &layout -+ ): layout(layout) { -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Params structure used for all Conv2d analytic tile iterators -+template< typename Layout_ = layout::TensorNHWC > -+struct Conv2dFewChannelsParams { -+ -+ using Layout = Layout_; -+ -+ -+ int32_t stride_w; -+ int32_t stride_h; -+ int32_t stride_n; -+ -+ FastDivmod divmod_P; -+ FastDivmod divmod_Q; -+ FastDivmod divmod_S; -+ FastDivmod divmod_C; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dFewChannelsParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dFewChannelsParams( -+ Conv2dProblemSize const &problem_size, // unused; placeholder to match other Params interfaces. -+ Layout const &layout -+ ): -+ stride_w(int32_t(layout.stride()[0])), -+ stride_h(int32_t(layout.stride()[1])), -+ stride_n(int32_t(layout.stride()[2])), -+ divmod_P(problem_size.P), -+ divmod_Q(problem_size.Q), -+ divmod_S(problem_size.S), -+ divmod_C(problem_size.C) -+ { -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Parameters structure used for Conv2dDgradOutputGradientTileAccessIteratorAnalyticParams -+struct Conv2dDgradOutputGradientTileAccessIteratorAnalyticParams { -+ -+ using Layout = layout::TensorNHWC; -+ -+ Layout layout; -+ int tiled_rows_per_filter; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradOutputGradientTileAccessIteratorAnalyticParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradOutputGradientTileAccessIteratorAnalyticParams( -+ Conv2dProblemSize const &problem_size, -+ Layout const &layout, ///< layout object -+ int element_size_bits, ///< size of each element in bits -+ MatrixCoord threadblock_shape -+ ): layout(layout) { -+ -+ int tile_m_per_filter = strided_dgrad_tile_m_per_filter(problem_size, threadblock_shape.row()); -+ -+ tiled_rows_per_filter = tile_m_per_filter * threadblock_shape.row(); -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if TRACE_CONV_PARAMS_INITIALIZERS_ENABLED -+ -+CUTLASS_HOST_DEVICE -+void TraceIteratorParams( -+ char const *conv_operator, -+ char const *operand, -+ int element_size_bits, -+ MatrixCoord threadblock_shape, -+ int thread_count, -+ int access_size, -+ layout::PitchLinearCoord threadmap_iterations, -+ layout::PitchLinearCoord threadmap_delta -+) { -+ -+#if !defined(__CUDA_ARCH__) -+ -+ char const *fname = "conv_iterator_params.csv"; -+ -+ std::ifstream test(fname); -+ bool file_exists = test.is_open(); -+ -+ if (file_exists) { -+ test.close(); -+ } -+ -+ std::ofstream trace("conv_iterator_params.csv", std::ofstream::app); -+ -+ if (!file_exists) { -+ trace -+ << "Operator,Operand,ElementSize,CtaRows,CtaColumns,ThreadCount,AccessSize," -+ << "IterationsContiguous,IterationsStrided,DeltaContiguous,DeltaStrided\n"; -+ } -+ -+ trace << conv_operator << "," << operand << "," << element_size_bits << "," -+ << threadblock_shape.row() << "," << threadblock_shape.column() -+ << "," << thread_count << "," << access_size -+ << "," << threadmap_iterations.contiguous() << "," << threadmap_iterations.strided() -+ << "," << threadmap_delta.contiguous() << "," << threadmap_delta.strided() << "\n"; -+#endif -+} -+ -+#define TRACE_CONV_INITIALIZERS(conv_op, operand, element_size, cta_shape, thread_count, access_size, iterations, delta) \ -+ TraceIteratorParams(conv_op, operand, element_size, cta_shape, thread_count, access_size, iterations, delta); -+ -+#else -+ -+#define TRACE_CONV_INITIALIZERS(conv_op, operand, element_size, cta_shape, thread_count, access_size, iterations, delta) {} -+ -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Parameters structure used for Conv2dFpropActivationTileIteratorOptimized -+template< typename Layout_ = layout::TensorNHWC > -+struct Conv2dFpropActivationIteratorOptimizedParams; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Parameters structure used for Conv2dFpropActivationTileIteratorOptimized -+template<> -+struct Conv2dFpropActivationIteratorOptimizedParams { -+ -+ using Layout = layout::TensorNHWC; -+ -+ Layout layout; -+ -+ int64_t inc_next[3]; // {next S, next R, next C} -+ int filter_c_delta; // number of logical elements to add to filter_c_ -+ int PQ; // product of P*Q -+ -+ FastDivmod pq_divmod; -+ FastDivmod q_divmod; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropActivationIteratorOptimizedParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropActivationIteratorOptimizedParams( -+ Conv2dProblemSize const &problem_size, -+ Layout const &layout, ///< layout object -+ int element_size_bits, ///< size of each element in bits -+ MatrixCoord threadblock_shape, -+ int thread_count, -+ int access_size, -+ layout::PitchLinearCoord threadmap_iterations, -+ layout::PitchLinearCoord threadmap_delta -+ ): -+ layout(layout), -+ PQ(problem_size.P * problem_size.Q), -+ pq_divmod(PQ), -+ q_divmod(problem_size.Q) { -+ -+ TRACE_CONV_INITIALIZERS("conv2d_fprop", "activation", -+ element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); -+ -+ int conv_sign = (problem_size.mode == Mode::kConvolution ? -1 : 1); -+ -+ // next S -+ inc_next[0] = conv_sign * ( -+ int64_t(layout.stride()[0]) * problem_size.dilation_w -+ ) * element_size_bits / 8; -+ -+ // next R -+ inc_next[1] = conv_sign * ( -+ int64_t(layout.stride()[1]) * problem_size.dilation_h -+ - (problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w -+ ) * element_size_bits / 8; -+ -+ // next C -+ inc_next[2] = ( -+ threadblock_shape.column() * problem_size.split_k_slices -+ - conv_sign * int64_t(problem_size.R - 1) * layout.stride()[1] * problem_size.dilation_h -+ - conv_sign * int64_t(problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w -+ ) * element_size_bits / 8; -+ -+ // logical offset added to internal channel counter - units are elements, not bytes -+ filter_c_delta = threadblock_shape.column() * problem_size.split_k_slices; -+ } -+ -+#if ENABLE_CONV2D_PARAMS_PRINT -+ /// Prints internal state. -+ CUTLASS_HOST_DEVICE -+ void print() { -+ auto stride = layout.stride(); -+ printf( -+ "Conv2dFpropActivationIteratorOptimizedParams:\n" -+ " layout(w: %d, h: %d, n: %d)\n" -+ " inc_next[%ld, %ld, %ld]\n" -+ " filter_c_delta(%d) - PQ(%d)\n" -+ " pq_divmod(divisor: %d, multiplier: %u, shift_right: %u)\n" -+ " q_divmod(divisor: %d, multiplier: %u, shift_right: %u)\n", -+ stride[0], stride[1], stride[2], -+ inc_next[0], inc_next[1], inc_next[2], -+ filter_c_delta, -+ PQ, -+ pq_divmod.divisor, -+ pq_divmod.multiplier, -+ pq_divmod.shift_right, -+ q_divmod.divisor, -+ q_divmod.multiplier, -+ q_divmod.shift_right -+ ); -+ } -+#endif -+}; -+ -+/// Parameters structure used for Conv2dFpropActivationTileIteratorOptimized -+template -+struct Conv2dFpropActivationIteratorOptimizedParams> { -+ static int const kInterleaved = Interleaved_; -+ -+ using Layout = layout::TensorNCxHWx; -+ -+ Layout layout; -+ -+ int64_t inc_next[3]; // {next S, next R, next C} -+ int filter_c_delta; // number of logical elements to add to filter_c_ -+ int PQ; // product of P*Q -+ -+ FastDivmod pq_divmod; -+ FastDivmod q_divmod; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropActivationIteratorOptimizedParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropActivationIteratorOptimizedParams( -+ Conv2dProblemSize const &problem_size, -+ Layout const &layout, ///< layout object -+ int element_size_bits, ///< size of each element in bits -+ MatrixCoord threadblock_shape, -+ int thread_count, -+ int access_size, -+ layout::PitchLinearCoord threadmap_iterations, -+ layout::PitchLinearCoord threadmap_delta -+ ): -+ layout(layout), PQ(problem_size.P * problem_size.Q), pq_divmod(PQ), q_divmod(problem_size.Q) { -+ -+ TRACE_CONV_INITIALIZERS("conv2d_fprop", "activation", -+ element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); -+ -+ int conv_sign = (problem_size.mode == Mode::kConvolution ? -1 : 1); -+ -+ // next S -+ inc_next[0] = conv_sign * (kInterleaved * problem_size.dilation_w) * element_size_bits / 8; -+ -+ // next R -+ inc_next[1] = conv_sign * ( -+ int64_t(layout.stride()[0]) * problem_size.dilation_h -+ - (problem_size.S - 1) * kInterleaved * problem_size.dilation_w -+ ) * element_size_bits / 8; -+ -+ // next C -+ inc_next[2] = ( -+ threadblock_shape.column() * problem_size.split_k_slices / kInterleaved * int64_t(layout.stride()[1]) -+ - conv_sign * int64_t(problem_size.R - 1) * layout.stride()[0] * problem_size.dilation_h -+ - conv_sign * int64_t(problem_size.S - 1) * kInterleaved * problem_size.dilation_w -+ ) * element_size_bits / 8; -+ -+ // logical offset added to internal channel counter - units are elements, not bytes -+ filter_c_delta = threadblock_shape.column() * problem_size.split_k_slices; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template< typename Layout_ = layout::TensorNHWC > -+struct Conv2dFpropFilterIteratorOptimizedParams; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template<> -+struct Conv2dFpropFilterIteratorOptimizedParams -+{ -+ -+ using Layout = layout::TensorNHWC; -+ -+ Layout layout; -+ int RS; -+ int filter_c_delta; -+ -+ int64_t inc_next_k; // offset in units of bytes to next K position -+ int64_t inc_next_rs; // offset in units of bytes to next RS position -+ int64_t inc_next_c; // offset in units of bytes to next C position -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropFilterIteratorOptimizedParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropFilterIteratorOptimizedParams( -+ Conv2dProblemSize const &problem_size, -+ Layout const &layout, -+ int element_size_bits, ///< size of each element in bits -+ MatrixCoord threadblock_shape, -+ int thread_count, -+ int access_size, -+ layout::PitchLinearCoord threadmap_iterations, -+ layout::PitchLinearCoord threadmap_delta -+ ): -+ layout(layout) { -+ -+ TRACE_CONV_INITIALIZERS("conv2d_fprop", "filter", -+ element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); -+ -+ RS = problem_size.R * problem_size.S; -+ -+ inc_next_k = (int64_t(layout.stride()[2]) * threadmap_delta.strided() * element_size_bits) / 8; -+ -+ inc_next_rs = -+ ( int64_t(layout.stride()[0]) -+ - int64_t(layout.stride()[2]) * (threadmap_iterations.strided() - 1) * threadmap_delta.strided() -+ ) * element_size_bits / 8; -+ -+ inc_next_c = -+ ( -+ threadblock_shape.row() * problem_size.split_k_slices -+ - int64_t(RS - 1) * layout.stride()[0] -+ - int64_t(threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2] -+ ) * element_size_bits / 8; -+ -+ filter_c_delta = threadblock_shape.row() * problem_size.split_k_slices; -+ } -+ -+#if ENABLE_CONV2D_PARAMS_PRINT -+ /// Prints internal state. -+ CUTLASS_HOST_DEVICE -+ void print() { -+ auto stride = layout.stride(); -+ printf( -+ "Conv2dFpropFilterIteratorOptimizedParams:\n" -+ " layout[%d, %d, %d]\n" -+ " RS(%d), filter_c_delta(%d), inc_next(k: %ld, rs: %ld, c: %ld)\n", -+ stride[0], stride[1], stride[2], -+ RS, -+ filter_c_delta, -+ inc_next_k, inc_next_rs, inc_next_c -+ ); -+ } -+#endif -+}; -+ -+template -+struct Conv2dFpropFilterIteratorOptimizedParams> -+{ -+ static int const kInterleaved = Interleaved_; -+ using Layout = layout::TensorCxRSKx; -+ -+ Layout layout; -+ int RS; -+ int filter_c_delta; -+ -+ int64_t inc_next_k; // offset in units of bytes to next K position -+ int64_t inc_next_rs; // offset in units of bytes to next RS position -+ int64_t inc_next_c; // offset in units of bytes to next C position -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropFilterIteratorOptimizedParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dFpropFilterIteratorOptimizedParams( -+ Conv2dProblemSize const &problem_size, -+ Layout const &layout, -+ int element_size_bits, ///< size of each element in bits -+ MatrixCoord threadblock_shape, -+ int thread_count, -+ int access_size, -+ layout::PitchLinearCoord threadmap_iterations, -+ layout::PitchLinearCoord threadmap_delta -+ ): -+ layout(layout) { -+ -+ TRACE_CONV_INITIALIZERS("conv2d_fprop", "filter", -+ element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); -+ -+ RS = problem_size.R * problem_size.S; -+ -+ inc_next_k = (kInterleaved * threadmap_delta.strided() * element_size_bits) / 8; -+ -+ inc_next_rs = -+ ( int64_t(layout.stride()[0]) -+ - kInterleaved * (threadmap_iterations.strided() - 1) * threadmap_delta.strided() -+ ) * element_size_bits / 8; -+ -+ inc_next_c = -+ ( -+ threadblock_shape.row() * problem_size.split_k_slices / kInterleaved * int64_t(layout.stride()[2]) -+ - int64_t(RS - 1) * layout.stride()[0] -+ - int64_t(threadmap_iterations.strided() - 1) * threadmap_delta.strided() * kInterleaved -+ ) * element_size_bits / 8; -+ -+ filter_c_delta = threadblock_shape.row() * problem_size.split_k_slices; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// Dgrad Optimized Dy params (layout::TensorNHWC) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Parameters object for Conv2d DGRAD OutputGradient (dy) iterator -+struct Conv2dDgradOutputGradientIteratorOptimizedParams { -+ -+ using Layout = layout::TensorNHWC; -+ -+ Layout layout; -+ -+ int64_t inc_next[3]; // {next S, next R, next K} -+ -+ int filter_k_delta; // number of logical elements to add to filter_k_ -+ -+ int HW; // product of H*W -+ -+ FastDivmod hw_divmod; -+ FastDivmod w_divmod; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradOutputGradientIteratorOptimizedParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradOutputGradientIteratorOptimizedParams( -+ Conv2dProblemSize const &problem_size, -+ Layout const &layout, -+ int element_size_bits, ///< size of each element in bits -+ MatrixCoord threadblock_shape, -+ int thread_count, -+ int access_size, -+ layout::PitchLinearCoord threadmap_iterations, -+ layout::PitchLinearCoord threadmap_delta -+ ): -+ layout(layout), -+ HW(problem_size.H *problem_size.W), -+ hw_divmod(HW), -+ w_divmod(problem_size.W) { -+ -+ TRACE_CONV_INITIALIZERS("conv2d_dgrad", "output_gradient", -+ element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); -+ -+ int conv_sign = (problem_size.mode == Mode::kConvolution ? 1 : -1); -+ -+ // next S -+ inc_next[0] = conv_sign * ( -+ (int64_t)layout.stride()[0] * problem_size.dilation_w -+ ) * element_size_bits / 8; -+ -+ // next R -+ inc_next[1] = conv_sign * ( -+ (int64_t)layout.stride()[1] * problem_size.dilation_h -+ - (problem_size.S - 1) * (int64_t)layout.stride()[0] * problem_size.dilation_w -+ ) * element_size_bits / 8; -+ -+ // next K -+ inc_next[2] = ( -+ threadblock_shape.column() * problem_size.split_k_slices -+ - conv_sign * (problem_size.R - 1) * (int64_t)layout.stride()[1] * problem_size.dilation_h -+ - conv_sign * (problem_size.S - 1) * (int64_t)layout.stride()[0] * problem_size.dilation_w -+ ) * element_size_bits / 8; -+ -+ // logical offset added to internal channel counter - units are elements, not bytes -+ filter_k_delta = threadblock_shape.column() * problem_size.split_k_slices; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// Strided Dgrad Optimized Dy params (layout::TensorNHWC) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+struct Conv2dStridedDgradOutputGradientIteratorOptimizedParams { -+ -+ using Layout = layout::TensorNHWC; -+ -+ Layout layout; -+ -+ int64_t inc_next[3]; // {next S, next R, next K} -+ -+ int filter_k_delta; // number of logical elements to add to filter_k_ -+ -+ int tiled_rows_per_filter; -+ -+ int conv_sign; -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dStridedDgradOutputGradientIteratorOptimizedParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dStridedDgradOutputGradientIteratorOptimizedParams( -+ Conv2dProblemSize const &problem_size, -+ Layout const &layout, ///< layout object -+ int element_size_bits, ///< size of each element in bits -+ MatrixCoord threadblock_shape -+ ): layout(layout) { -+ -+ int tile_m_per_filter = strided_dgrad_tile_m_per_filter(problem_size, threadblock_shape.row()); -+ -+ tiled_rows_per_filter = tile_m_per_filter * threadblock_shape.row(); -+ -+ conv_sign = (problem_size.mode == Mode::kConvolution ? 1 : -1); -+ -+ // next S -+ inc_next[0] = conv_sign * ( -+ (int64_t)layout.stride()[0] * problem_size.dilation_w -+ ) * element_size_bits / 8; -+ -+ // next R -+ inc_next[1] = conv_sign * ( -+ (int64_t)layout.stride()[1] * problem_size.dilation_h -+ ) * element_size_bits / 8; -+ -+ // next K -+ inc_next[2] = ( -+ threadblock_shape.column() * problem_size.split_k_slices -+ ) * element_size_bits / 8; -+ -+ // logical offset added to internal channel counter - units are elements, not bytes -+ filter_k_delta = threadblock_shape.column() * problem_size.split_k_slices; -+ } -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////////////////////// -+// Dgrad Optimized w params (layout::TensorNHWC) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+struct Conv2dDgradFilterIteratorOptimizedParams { -+ -+ using Layout = layout::TensorNHWC; -+ -+ Layout layout; -+ int RS; -+ int filter_k_delta; -+ -+ int64_t inc_next_strided; // offset in units of bytes to next K coordinate within tile -+ int64_t inc_next_rs; // offset in units of bytes to next RS position -+ int64_t inc_next_k; // offset in units of bytes to next K position in subsequent tile -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradFilterIteratorOptimizedParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dDgradFilterIteratorOptimizedParams( -+ Conv2dProblemSize const &problem_size, -+ Layout const &layout, -+ int element_size_bits, ///< size of each element in bits -+ MatrixCoord threadblock_shape, -+ int thread_count, -+ int access_size, -+ layout::PitchLinearCoord threadmap_iterations, -+ layout::PitchLinearCoord threadmap_delta -+ ): -+ layout(layout), RS(problem_size.R * problem_size.S) { -+ -+ TRACE_CONV_INITIALIZERS("conv2d_dgrad", "filter", -+ element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); -+ -+ inc_next_strided = ((int64_t)layout.stride()[2] * threadmap_delta.strided() * element_size_bits) / 8; -+ -+ inc_next_rs = -+ ( (int64_t)layout.stride()[0] -+ - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[2] -+ ) * element_size_bits / 8; -+ -+ inc_next_k = -+ ( -+ threadblock_shape.row() * problem_size.split_k_slices * (int64_t)layout.stride()[2] -+ - (problem_size.R * problem_size.S - 1) * (int64_t)layout.stride()[0] -+ - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[2] -+ ) * element_size_bits / 8; -+ -+ filter_k_delta = threadblock_shape.row() * problem_size.split_k_slices; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////////////////////// -+// StridedDgrad Optimized w params (layout::TensorNHWC) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+struct Conv2dStridedDgradFilterIteratorOptimizedParams { -+ -+ using Layout = layout::TensorNHWC; -+ -+ Layout layout; -+ int RS; -+ int filter_k_delta; -+ -+ int64_t inc_next_strided; // offset in units of bytes to next K coordinate within tile -+ int64_t inc_next[3]; // {next S, next R, next K} -+ int64_t reset_bytes; // offset in units of bytes to move back the pointer -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Conv2dStridedDgradFilterIteratorOptimizedParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dStridedDgradFilterIteratorOptimizedParams( -+ Conv2dProblemSize const &problem_size, -+ Layout const &layout, -+ int element_size_bits, ///< size of each element in bits -+ MatrixCoord threadblock_shape, -+ int thread_count, -+ int access_size, -+ layout::PitchLinearCoord threadmap_iterations, -+ layout::PitchLinearCoord threadmap_delta -+ ): -+ layout(layout), RS(problem_size.R * problem_size.S) { -+ -+ TRACE_CONV_INITIALIZERS("conv2d_dgrad", "filter", -+ element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); -+ -+ inc_next_strided = (layout.stride()[2] * threadmap_delta.strided() * element_size_bits) / 8; -+ -+ // next S -+ inc_next[0] = -+ ( (int64_t)layout.stride()[0] * problem_size.stride_w -+ //- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2] -+ ) * element_size_bits / 8; -+ -+ // next R -+ inc_next[1] = -+ ( (int64_t)layout.stride()[1] * problem_size.stride_h -+ //- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2] -+ ) * element_size_bits / 8; -+ -+ // next K -+ inc_next[2] = -+ ( -+ threadblock_shape.row() * problem_size.split_k_slices * (int64_t)layout.stride()[2] -+ //- (problem_size.R * problem_size.S - 1) * layout.stride()[0] -+ //- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2] -+ ) * element_size_bits / 8; -+ -+ // offset in units of bytes to move the pointer in backward direction -+ reset_bytes = (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[2] -+ * element_size_bits / 8; -+ -+ filter_k_delta = threadblock_shape.row() * problem_size.split_k_slices; -+ } -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Parameters object for Conv2d WGRAD Output Gradient (dy) iterator -+struct Conv2dWgradOutputGradientIteratorOptimizedParams { -+ -+ using Layout = layout::TensorNHWC; -+ -+ Layout layout; -+ -+ int NPQ; // precomputd product of N*P*Q for clearing predicates -+ -+ FastDivmod pq_divmod; -+ FastDivmod q_divmod; -+ -+ int64_t offset_next_strided; // offset in units of bytes to next npq coordinate within tile -+ int64_t offset_next_contiguous; // offset in units of bytes to next k coordinate within tile -+ int64_t inc_next_npq; // offset in units of bytes to next npq position in subsequent tile -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dWgradOutputGradientIteratorOptimizedParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dWgradOutputGradientIteratorOptimizedParams( -+ Conv2dProblemSize const &problem_size, -+ Layout const &layout, -+ int element_size_bits, ///< size of each element in bits -+ MatrixCoord threadblock_shape, -+ int thread_count, -+ int access_size, -+ layout::PitchLinearCoord threadmap_iterations, -+ layout::PitchLinearCoord threadmap_delta -+ ): -+ layout(layout), -+ NPQ(problem_size.N * problem_size.P * problem_size.Q), -+ pq_divmod(problem_size.P * problem_size.Q), -+ q_divmod(problem_size.Q) { -+ -+ TRACE_CONV_INITIALIZERS("conv2d_wgrad", "output_gradient", -+ element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); -+ -+ // Incremental offsets in unites of bytes (number of elements) * sizeof_bits::value / 8 -+ offset_next_strided = (threadmap_delta.strided() * (int64_t)layout.stride()[0]) -+ * element_size_bits / 8; -+ -+ offset_next_contiguous = (threadmap_delta.contiguous()) -+ * element_size_bits / 8; -+ -+ inc_next_npq = (threadblock_shape.column() * problem_size.split_k_slices * (int64_t)layout.stride()[0]) -+ * element_size_bits / 8; -+ } -+}; -+ -+struct Conv2dWgradActivationIteratorOptimizedParams { -+ -+ using Layout = layout::TensorNHWC; -+ -+ Layout layout; -+ -+ FastDivmod sc_divmod; -+ FastDivmod pq_divmod; -+ FastDivmod q_divmod; -+ FastDivmod c_divmod; -+ FastDivmod s_divmod; -+ int small_channel_conv_s_offset; -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Conv2dWgradActivationIteratorOptimizedParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dWgradActivationIteratorOptimizedParams( -+ Conv2dProblemSize const &problem_size, -+ Layout const &layout -+ ): -+ layout(layout), -+ sc_divmod(problem_size.S * problem_size.C), -+ pq_divmod(problem_size.P * problem_size.Q), -+ q_divmod(problem_size.Q), -+ c_divmod(problem_size.C), -+ s_divmod(problem_size.S * problem_size.dilation_w), -+ small_channel_conv_s_offset((problem_size.S - 1) * problem_size.dilation_w - problem_size.pad_w) { -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dWgradActivationIteratorOptimizedParams( -+ Conv2dProblemSize const &problem_size, -+ Layout const &layout, -+ int element_size_bits, ///< size of each element in bits -+ MatrixCoord threadblock_shape, -+ int thread_count, -+ int access_size, -+ layout::PitchLinearCoord threadmap_iterations, -+ layout::PitchLinearCoord threadmap_delta -+ ): -+ Conv2dWgradActivationIteratorOptimizedParams( -+ problem_size, -+ layout -+ ) { -+ -+ TRACE_CONV_INITIALIZERS("conv2d_wgrad", "activation", -+ element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); -+ } -+}; -+ -+struct PredicatedScaleBiasVectorAccessIteratorParams { -+ public: -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIteratorParams() { } -+ -+ // Default ctor -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIteratorParams( -+ Conv2dProblemSize const &problem_size, -+ layout::PitchLinear const &layout) {} -+ -+ // Default ctor -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIteratorParams( -+ Conv2dProblemSize const &problem_size, -+ layout::RowMajor const &layout) {} -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_tile_iterator.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_tile_iterator.h -new file mode 100644 -index 0000000..9c1742d ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_tile_iterator.h -@@ -0,0 +1,337 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template wraps the tile access iterator concept to load whole tiles from tensors in -+ memory used for implicit GEMM convolution. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class TileIterator { -+public: -+ using TileAccessIterator = TileAccessIterator_; -+ -+ using Shape = typename TileAccessIterator::Shape; -+ using Element = typename TileAccessIterator::Element; -+ using Layout = typename TileAccessIterator::Layout; -+ using TensorCoord = typename Layout::TensorCoord; -+ using ThreadMap = typename TileAccessIterator::ThreadMap; -+ using AccessType = typename TileAccessIterator::AccessType; -+ using TensorRef = typename TileAccessIterator::TensorRef; -+ using Index = typename TileAccessIterator::Index; -+ using LongIndex = typename TileAccessIterator::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = TileAccessIterator::kIteratorAlgorithm; -+ static StrideSupport const kStrideSupport = TileAccessIterator::kStrideSupport; -+ using Params = typename TileAccessIterator::Params; -+ static int const kConvDim = TileAccessIterator::kConvDim; -+ using ConvProblemSize = typename TileAccessIterator::ConvProblemSize; -+ static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array< -+ Element, -+ ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; -+ -+private: -+ -+ /// Internal state -+ TileAccessIterator tile_access_iterator_; -+ -+public: -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ TileIterator( -+ Params const ¶ms, -+ ConvProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ tile_access_iterator_(params, problem_size, ptr, thread_idx, threadblock_offset) { } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(ConvProblemSize const &problem_size, Layout const &layout) { -+ return TileAccessIterator::getParams(problem_size, layout); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ tile_access_iterator_.set_iteration_index(index); -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ tile_access_iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ TileIterator &operator++() { -+ tile_access_iterator_.advance(); -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ TileIterator operator++(int) { -+ TileIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ -+ frag.clear(); -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ -+ int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); -+ -+ cutlass::arch::global_load< -+ AccessType, -+ sizeof(AccessType) -+ >( -+ frag_ptr[idx], -+ tile_access_iterator_.get() + pointer_offset, -+ tile_access_iterator_.valid() -+ ); -+ -+ ++tile_access_iterator_; -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ tile_access_iterator_.set_iteration_index(0); -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ CUTLASS_DEVICE -+ void advance() { -+ tile_access_iterator_.advance(); -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(ConvProblemSize const &problem_size) { -+ -+ // dispatch to iterator implementation -+ return TileAccessIterator::can_implement(problem_size); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// Strided Dgrad Tile Iterator -+template -+class TileIteratorStridedDgrad { -+public: -+ using TileAccessIterator = TileAccessIterator_; -+ -+ using Shape = typename TileAccessIterator::Shape; -+ using Element = typename TileAccessIterator::Element; -+ using Layout = typename TileAccessIterator::Layout; -+ using TensorCoord = typename Layout::TensorCoord; -+ using ThreadMap = typename TileAccessIterator::ThreadMap; -+ using AccessType = typename TileAccessIterator::AccessType; -+ using TensorRef = typename TileAccessIterator::TensorRef; -+ using Index = typename TileAccessIterator::Index; -+ using LongIndex = typename TileAccessIterator::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = TileAccessIterator::kIteratorAlgorithm; -+ static StrideSupport const kStrideSupport = TileAccessIterator::kStrideSupport; -+ using Params = typename TileAccessIterator::Params; -+ static int const kConvDim = TileAccessIterator::kConvDim; -+ using ConvProblemSize = typename TileAccessIterator::ConvProblemSize; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array< -+ Element, -+ ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; -+ -+private: -+ -+ /// Internal state -+ TileAccessIterator tile_access_iterator_; -+ -+public: -+ -+ /// Constructor (output gradient (Dy) OperandA ctor) -+ CUTLASS_HOST_DEVICE -+ TileIteratorStridedDgrad( -+ Params const ¶ms, -+ ConvProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod, -+ int start_r, int start_s, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ tile_access_iterator_( -+ params, -+ problem_size, -+ ptr, -+ thread_idx, -+ stride_h_divmod, stride_w_divmod, -+ start_r, start_s, -+ threadblock_offset) { } -+ -+ /// Constructor (filter (w) OperandB ctor) -+ CUTLASS_HOST_DEVICE -+ TileIteratorStridedDgrad( -+ Params const ¶ms, -+ ConvProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ int start_r, int start_s, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ tile_access_iterator_(params, -+ problem_size, -+ ptr, -+ thread_idx, -+ start_r, start_s, -+ threadblock_offset) { } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(ConvProblemSize const &problem_size, Layout const &layout) { -+ return TileAccessIterator::getParams(problem_size, layout); -+ } -+ -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ tile_access_iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ TileIteratorStridedDgrad &operator++() { -+ tile_access_iterator_.advance(); -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ TileIteratorStridedDgrad operator++(int) { -+ TileIteratorStridedDgrad self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ -+ frag.clear(); -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ cutlass::arch::global_load< -+ AccessType, -+ sizeof(AccessType) -+ >( -+ frag_ptr[c + s * ThreadMap::Iterations::kContiguous], -+ tile_access_iterator_.get() + pointer_offset, -+ tile_access_iterator_.valid() -+ ); -+ -+ ++tile_access_iterator_; -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ tile_access_iterator_.set_iteration_index(0); -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ CUTLASS_DEVICE -+ void advance() { -+ tile_access_iterator_.advance(); -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(ConvProblemSize const &problem_size) { -+ -+ // dispatch to iterator implementation -+ return TileAccessIterator::can_implement(problem_size); -+ } -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h -new file mode 100644 -index 0000000..6e73115 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h -@@ -0,0 +1,285 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM B (activation tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/threadblock/conv2d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_, -+ typename AccessType_ = cutlass::AlignedArray -+> -+class Conv2dWgradActivationTileAccessIteratorAnalytic { -+public: -+ -+ // -+ // Types -+ // -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNHWC; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ static_assert(sizeof_bits::value >= 8, -+ "WGRAD requires elements of size 8b or greater."); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv2dAnalyticParams; -+ -+private: -+ -+ Params const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ char const *pointer_; -+ -+ // Filter postion (r,s,c) in contiguous dimension stays constant for each gemm_iteration_k -+ int filter_r_[ThreadMap::Iterations::kContiguous]; -+ int filter_s_[ThreadMap::Iterations::kContiguous]; -+ int filter_c_[ThreadMap::Iterations::kContiguous]; -+ -+ int offset_npq_[ThreadMap::Iterations::kStrided]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dWgradActivationTileAccessIteratorAnalytic( -+ Params const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)) -+ { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ // initialize r,s,c filter position for every contiguous iteration -+ CUTLASS_PRAGMA_UNROLL -+ for(int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ int rsc_offset = threadblock_offset.column() + thread_coord.contiguous() -+ + c * ThreadMap::Delta::kContiguous; -+ -+ filter_r_[c] = rsc_offset / (problem_size_.S * problem_size_.C); -+ int residual = rsc_offset % (problem_size_.S * problem_size_.C); -+ -+ filter_s_[c] = residual / problem_size_.C; -+ filter_c_[c] = residual % problem_size_.C; -+ } -+ -+ // initialize n, p, q offset for every strided iteration -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ -+ offset_npq_[s] = threadblock_offset.row() + thread_coord.strided() -+ + s * ThreadMap::Delta::kStrided; -+ } -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ -+ // moves to the next GEMM-K offset (offset_npq_) in GEMM-B by a CTA-K tile -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ offset_npq_[s] += Shape::kRow * problem_size_.split_k_slices; -+ } -+ } -+ -+ /// Returns the coordinate in the activation tensor x that is currently pointed to -+ /// by the iterator. -+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ int r, s, c; -+ -+ if (kAccessesPerVector == 1) { -+ /// One 128b aligned access fetching more than one element -+ c = filter_c_[iteration_contiguous_]; -+ r = filter_r_[iteration_contiguous_]; -+ s = filter_s_[iteration_contiguous_]; -+ } -+ else { -+ /// Multiple access to support non-128b alignment in contiguous dimenstion -+ c = (filter_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements) % problem_size_.C; -+ int wrap_c = (filter_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements) / problem_size_.C; -+ s = (filter_s_[iteration_contiguous_] + wrap_c) % problem_size_.S; -+ int wrap_s = (filter_s_[iteration_contiguous_] + wrap_c) / problem_size_.S; -+ r = filter_r_[iteration_contiguous_] + wrap_s; -+ } -+ -+ if (problem_size_.mode == Mode::kConvolution) { -+ r = (problem_size_.R - 1 - r); -+ s = (problem_size_.S - 1 - s); -+ } -+ -+ int n = offset_npq_[iteration_strided_] / (problem_size_.P * problem_size_.Q); -+ int residual = offset_npq_[iteration_strided_] % (problem_size_.P * problem_size_.Q); -+ -+ int p = residual / problem_size_.Q; -+ int q = residual % problem_size_.Q; -+ -+ int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h; -+ int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w; -+ -+ return TensorCoord(n, h, w, c); -+ } -+ -+ /// Returns true if the current coordinate is within the activation tensor x -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ TensorCoord coord = at(); -+ -+ return coord.n() < problem_size_.N && -+ coord.h() >= 0 && coord.h() < problem_size_.H && -+ coord.w() >= 0 && coord.w() < problem_size_.W; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ LongIndex offset = params_.layout(coord); -+ -+ return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dWgradActivationTileAccessIteratorAnalytic &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.K % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h -new file mode 100644 -index 0000000..8871735 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h -@@ -0,0 +1,321 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM B (activation tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_, -+ typename AccessType_ = cutlass::AlignedArray -+> -+class Conv2dWgradActivationTileAccessIteratorOptimized { -+public: -+ -+ // -+ // Types -+ // -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNHWC; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ static_assert(sizeof_bits::value >= 8, -+ "WGRAD requires elements of size 8b or greater."); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv2dWgradActivationIteratorOptimizedParams; -+ -+private: -+ -+ Conv2dWgradActivationIteratorOptimizedParams const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ char const *pointer_; -+ -+ // Precomputed effective filter postion (r,s) in contiguous dimension stays constant for each gemm_iteration_k -+ // required for npq -> nhw translation -+ int precomputed_filter_r_[ThreadMap::Iterations::kContiguous]; -+ int precomputed_filter_s_[ThreadMap::Iterations::kContiguous]; -+ -+ // Channel dimension in contiguous dimension stays constant for each gemm_iteration_k -+ int filter_c_[ThreadMap::Iterations::kContiguous]; -+ -+ int offset_npq_[ThreadMap::Iterations::kStrided]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dWgradActivationTileAccessIteratorOptimized( -+ Conv2dWgradActivationIteratorOptimizedParams const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)) -+ { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ // initialize r,s,c filter position for every contiguous iteration -+ CUTLASS_PRAGMA_UNROLL -+ for(int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ int rsc_offset = threadblock_offset.column() + thread_coord.contiguous() -+ + c * ThreadMap::Delta::kContiguous; -+ -+ // The subseqnet fast_divmod() operations are equivalent to the following logical computation: -+ // -+ // -+ // filter_r_[c] = rsc_offset / (problem_size_.S * problem_size_.C); -+ // int residual = rsc_offset % (problem_size_.S * problem_size_.C); -+ // -+ // filter_s_[c] = residual / problem_size_.C; -+ // filter_c_[c] = residual % problem_size_.C; -+ -+ int residual; -+ params_.sc_divmod(precomputed_filter_r_[c], residual, rsc_offset); -+ params_.c_divmod(precomputed_filter_s_[c], filter_c_[c], residual); -+ -+ int r = precomputed_filter_r_[c]; -+ int s = precomputed_filter_s_[c]; -+ -+ if (problem_size_.mode == Mode::kConvolution) { -+ r = (problem_size_.R - 1 - r); -+ s = (problem_size_.S - 1 - s); -+ } -+ -+ precomputed_filter_r_[c] = -problem_size_.pad_h + r * problem_size_.dilation_h; -+ precomputed_filter_s_[c] = -problem_size_.pad_w + s * problem_size_.dilation_w; -+ } -+ -+ // initialize n, p, q offset for every strided iteration -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ -+ offset_npq_[s] = threadblock_offset.row() + thread_coord.strided() -+ + s * ThreadMap::Delta::kStrided; -+ } -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ -+ // moves to the next GEMM-K offset (offset_npq_) in GEMM-B by a CTA-K tile -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ offset_npq_[s] += Shape::kRow * problem_size_.split_k_slices; -+ } -+ } -+ -+ /// Returns the coordinate in the activation tensor x that is currently pointed to -+ /// by the iterator. -+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ int r = precomputed_filter_r_[iteration_contiguous_]; -+ int s = precomputed_filter_s_[iteration_contiguous_]; -+ int c = filter_c_[iteration_contiguous_]; -+ -+ if (kAccessesPerVector > 1) { -+ // This code section is only to support non-128b alignment -+ // Multiple access to support non-128b alignment in contiguous dimenstion -+ int wrap_c; -+ params_.c_divmod(wrap_c, c, c + iteration_vector_ * AccessType::kElements); -+ -+ if (problem_size_.mode == Mode::kConvolution) { -+ s -= (problem_size_.dilation_w * wrap_c); -+ -+ int wrap_s; -+ params_.s_divmod(wrap_s, s, params_.small_channel_conv_s_offset - s); -+ s = params_.small_channel_conv_s_offset - s; -+ -+ r -= (problem_size_.dilation_h * wrap_s); -+ -+ } else { -+ s += (problem_size_.dilation_w * wrap_c); -+ -+ int wrap_s; -+ params_.s_divmod(wrap_s, s, s + problem_size_.pad_w); -+ s -= problem_size_.pad_w; -+ -+ r += (problem_size_.dilation_h * wrap_s); -+ } -+ } -+ -+ // The subseqnet fast_divmod() operations are equivalent to the following logical computation: -+ // -+ // -+ // int n = offset_npq_[iteration_strided_] / (problem_size_.P * problem_size_.Q); -+ // int residual = offset_npq_[iteration_strided_] % (problem_size_.P * problem_size_.Q); -+ // -+ // int p = residual / problem_size_.Q; -+ // int q = residual % problem_size_.Q; -+ -+ int residual, n, p, q; -+ -+ params_.pq_divmod(n, residual, offset_npq_[iteration_strided_]); -+ params_.q_divmod(p, q, residual); -+ -+ int h = p * problem_size_.stride_h + r; -+ int w = q * problem_size_.stride_w + s; -+ -+ return TensorCoord(n, h, w, c); -+ } -+ -+ /// Returns true if the current coordinate is within the activation tensor x -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ TensorCoord coord = at(); -+ -+ return coord.n() < problem_size_.N && -+ coord.h() >= 0 && coord.h() < problem_size_.H && -+ coord.w() >= 0 && coord.w() < problem_size_.W; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ LongIndex offset = params_.layout(coord); -+ -+ return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dWgradActivationTileAccessIteratorOptimized &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.K % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h -new file mode 100644 -index 0000000..97fd31e ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h -@@ -0,0 +1,260 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+#include "cutlass/conv/threadblock/conv2d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_, -+ typename AccessType_ = cutlass::AlignedArray -+> -+class Conv2dWgradOutputGradientTileAccessIteratorAnalytic { -+public: -+ -+ // -+ // Types -+ // -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNHWC; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ static_assert(sizeof_bits::value >= 8, -+ "WGRAD requires elements of size 8b or greater."); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv2dAnalyticParams; -+ -+private: -+ -+ Params const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ char const *pointer_; -+ -+ int filter_k_[ThreadMap::Iterations::kContiguous]; -+ -+ int offset_npq_[ThreadMap::Iterations::kStrided]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dWgradOutputGradientTileAccessIteratorAnalytic( -+ Params const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ // initialize filter_k for every contiguous iteration -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ filter_k_[c] = threadblock_offset.row() + thread_coord.contiguous() -+ + c * ThreadMap::Delta::kContiguous; -+ } -+ -+ // initialize n, p, q offset for every strided iteration -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ offset_npq_[s] = threadblock_offset.column() + thread_coord.strided() -+ + s * ThreadMap::Delta::kStrided; -+ -+ } -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { -+ return Params(problem_size, layout); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ // moves to the next GEMM-K offset (offset_npq_) in GEMM-A by a CTA-K tile -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ offset_npq_[s] += Shape::kColumn * problem_size_.split_k_slices; -+ } -+ } -+ -+ /// Returns the coordinate in the output gradient tensor Dy that is currently pointed to -+ /// by the iterator. -+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ -+ int npq = offset_npq_[iteration_strided_]; -+ -+ int n = npq / (problem_size_.P * problem_size_.Q); -+ int residual = npq % (problem_size_.P * problem_size_.Q); -+ -+ int p = residual / problem_size_.Q; -+ int q = residual % problem_size_.Q; -+ -+ int k = filter_k_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements; -+ -+ return TensorCoord(n, p, q, k); -+ } -+ -+ -+ /// Returns true if the current coordinate is within the output gradient tensor Dy -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ TensorCoord coord = at(); -+ -+ return coord.n() < problem_size_.N && -+ coord.h() < problem_size_.P && -+ coord.w() < problem_size_.Q && -+ coord.c() < problem_size_.K; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ LongIndex offset = params_.layout(coord); -+ -+ return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dWgradOutputGradientTileAccessIteratorAnalytic &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h -new file mode 100644 -index 0000000..6725ed4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h -@@ -0,0 +1,310 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_, -+ typename AccessType_ = cutlass::AlignedArray -+> -+class Conv2dWgradOutputGradientTileAccessIteratorOptimized { -+public: -+ -+ // -+ // Types -+ // -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNHWC; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ static_assert(sizeof_bits::value >= 8, -+ "WGRAD requires elements of size 8b or greater."); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv2dWgradOutputGradientIteratorOptimizedParams; -+ -+private: -+ -+ Conv2dWgradOutputGradientIteratorOptimizedParams const ¶ms_; -+ Conv2dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ char const *pointer_; -+ -+ uint32_t predicates_[kAccessesPerVector]; -+ int filter_k_; -+ int offset_npq_; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv2dWgradOutputGradientTileAccessIteratorOptimized( -+ Conv2dWgradOutputGradientIteratorOptimizedParams const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ predicates_{0}, -+ filter_k_(0), -+ offset_npq_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ filter_k_ = threadblock_offset.row() + thread_coord.contiguous(); -+ offset_npq_ = threadblock_offset.column() + thread_coord.strided(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ int filter_k = filter_k_ + c * ThreadMap::Delta::kContiguous; -+ int offset_npq = offset_npq_ + s * ThreadMap::Delta::kStrided; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ bool predicate = valid_(at_(offset_npq, filter_k + v * AccessType::kElements)); -+ -+ uint32_t pred = (predicate ? 1u : 0); -+ -+ int pred_idx = c + s * ThreadMap::Iterations::kContiguous; -+ -+ predicates_[v] |= (pred << pred_idx); -+ } -+ } -+ } -+ -+ // Offset pointer to (iteration_strided_, iteration_contiguous_) = (0, 0) -+ pointer_ += ( -+ offset_npq_ * params.layout.stride()[0] + filter_k_ -+ ) * sizeof_bits::value / 8; -+ -+ set_iteration_index(0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { -+ return Params(problem_size, -+ layout, -+ sizeof_bits::value, -+ {Shape::kRow, Shape::kColumn}, -+ ThreadMap::kThreads, -+ ThreadMap::kElementsPerAccess, -+ {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, -+ {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ // moves to the next GEMM-K offset (offset_npq_) in GEMM-A by a CTA-K tile -+ offset_npq_ += Shape::kColumn * problem_size_.split_k_slices; -+ -+ // Clear predicates if needed -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ if (offset_npq_ + s * ThreadMap::Delta::kStrided >= params_.NPQ) { -+ uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ predicates_[v] = (predicates_[v] & (~kClearMask)); -+ } -+ } -+ } -+ -+ pointer_ += params_.inc_next_npq; -+ } -+ -+private: -+ /// Returns the coordinate in the output gradient tensor Dy that is pointed to -+ /// by offset_npq and k. -+ CUTLASS_HOST_DEVICE -+ TensorCoord at_(int offset_npq, int k) const { -+ -+ // The subsequent fast_divmod() operations are equivalent to the following logical computation: -+ // -+ // -+ // int npq = offset_npq; -+ // int n = npq / (problem_size_.P * problem_size_.Q); -+ // int residual = npq % (problem_size_.P * problem_size_.Q); -+ // -+ // int p = residual / problem_size_.Q; -+ // int q = residual % problem_size_.Q; -+ -+ int residual, n, p, q; -+ -+ params_.pq_divmod(n, residual, offset_npq); -+ params_.q_divmod(p, q, residual); -+ -+ return TensorCoord(n, p, q, k); -+ } -+ -+ /// Returns true if the coord is within the output gradient tensor Dy -+ CUTLASS_HOST_DEVICE -+ bool valid_(TensorCoord coord) const { -+ -+ return coord.n() < problem_size_.N && -+ coord.c() < problem_size_.K; -+ } -+ -+public: -+ -+ /// Returns true if the current coordinate is within the output gradient tensor Dy -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ -+ LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous; -+ return (predicates_[iteration_vector_] & (1u << pred_idx)); -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ return reinterpret_cast( -+ pointer_ + -+ iteration_strided_ * params_.offset_next_strided + -+ iteration_contiguous_ * params_.offset_next_contiguous -+ ) + iteration_vector_; -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv2dWgradOutputGradientTileAccessIteratorOptimized &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h -new file mode 100644 -index 0000000..8566f07 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h -@@ -0,0 +1,268 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNDHWC layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_ -+> -+class Conv3dDgradFilterTileAccessIteratorAnalytic { -+public: -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNDHWC; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AlignedArray; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 3; -+ using ConvProblemSize = typename conv::Conv3dProblemSize; -+ static int const kAccessesPerVector = 1; -+ -+ static_assert(sizeof_bits::value >= 8, -+ "DGRAD requires elements of size 8b or larger."); -+ -+ // -+ // Parameters structure -+ // -+ -+ struct Params { -+ -+ Layout layout; -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ Conv3dProblemSize const &problem_size, -+ Layout const &layout -+ ): layout(layout) { -+ -+ } -+ }; -+ -+private: -+ -+ Params const ¶ms_; -+ Conv3dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ char const *pointer_; -+ -+ // For a fixed filter position (t,r,s) find and fill offset_k_, offset_c_ in strided and contiguous dimension -+ int filter_t_; -+ int filter_r_; -+ int filter_s_; -+ int offset_k_[ThreadMap::Iterations::kStrided]; -+ int offset_c_[ThreadMap::Iterations::kContiguous]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dDgradFilterTileAccessIteratorAnalytic( -+ Params const ¶ms, -+ Conv3dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ filter_t_(0), -+ filter_r_(0), -+ filter_s_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ offset_c_[c] = threadblock_offset.column() + thread_coord.contiguous() -+ + c * ThreadMap::Delta::kContiguous; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ offset_k_[s] = -+ threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; -+ } -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ // moves to the next tile -+ ++filter_s_; -+ if (filter_s_ < problem_size_.S) { -+ return; -+ } -+ filter_s_ = 0; -+ ++filter_r_; -+ if (filter_r_ < problem_size_.R) { -+ return; -+ } -+ filter_r_ = 0; -+ ++filter_t_; -+ if (filter_t_ < problem_size_.T) { -+ return; -+ } -+ filter_t_ = 0; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ offset_k_[s] += Shape::kRow * problem_size_.split_k_slices; -+ } -+ } -+ -+ /// Returns the coordinate in the filter tensor w that is currently pointed to -+ /// by the iterator. -+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ -+ int c = offset_c_[iteration_contiguous_]; -+ int k = offset_k_[iteration_strided_]; -+ -+ return TensorCoord(k, filter_t_, filter_r_, filter_s_, c); -+ } -+ -+ /// Returns true if the current coordinate is within the filter tensor w -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ -+ TensorCoord coord = at(); -+ -+ return coord.n() < problem_size_.K && coord.c() < problem_size_.C; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ LongIndex offset = params_.layout(coord); -+ -+ return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv3dDgradFilterTileAccessIteratorAnalytic &operator++() { -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv3dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % (128/sizeof_bits::value)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h -new file mode 100644 -index 0000000..b9876ff ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h -@@ -0,0 +1,289 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+ -+#include "cutlass/conv/threadblock/conv3d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_, -+ conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity -+> -+class Conv3dDgradFilterTileAccessIteratorOptimized { -+public: -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNDHWC; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AlignedArray; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; -+ static StrideSupport const kStrideSupport = StrideSupport_; -+ static int const kConvDim = 3; -+ using ConvProblemSize = typename conv::Conv3dProblemSize; -+ static int const kAccessesPerVector = 1; -+ -+ // -+ // Parameters structure -+ // -+ -+ struct Params : Conv3dDgradFilterIteratorOptimizedParams { -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Conv3dDgradFilterIteratorOptimizedParams const &base): -+ Conv3dDgradFilterIteratorOptimizedParams(base) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ Conv3dProblemSize const &problem_size, -+ Layout const &layout -+ ): -+ Conv3dDgradFilterIteratorOptimizedParams( -+ problem_size, -+ layout, -+ sizeof_bits::value, -+ {Shape::kRow, Shape::kColumn}, -+ ThreadMap::kThreads, -+ ThreadMap::kElementsPerAccess, -+ {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, -+ {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided} -+ ) { } -+ -+ }; -+ -+private: -+ -+ Conv3dDgradFilterIteratorOptimizedParams const ¶ms_; -+ Conv3dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ char const *pointer_; -+ -+ uint32_t predicates_; -+ int filter_trs_; -+ int filter_k_; -+ -+ // -+ // Assertions -+ // -+ -+ // We map predicates into bits packed in this uint32_t container -+ static_assert(ThreadMap::Iterations::kStrided * -+ ThreadMap::Iterations::kContiguous < sizeof(predicates_) * 8, -+ "Currently, the number of loads per iteration is limited by the size of the predicates container."); -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dDgradFilterTileAccessIteratorOptimized( -+ Conv3dDgradFilterIteratorOptimizedParams const ¶ms, -+ Conv3dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ predicates_(0), -+ filter_trs_(0), -+ filter_k_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ filter_k_ = threadblock_offset.row() + thread_coord.strided(); -+ Index column = threadblock_offset.column() + thread_coord.contiguous(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ int filter_k = filter_k_ + s * ThreadMap::Delta::kStrided; -+ int filter_c = column + c * ThreadMap::Delta::kContiguous; -+ -+ uint32_t pred = ((filter_k < problem_size_.K && filter_c < problem_size_.C) ? 1u : 0); -+ -+ int pred_idx = c + s * ThreadMap::Iterations::kContiguous; -+ -+ predicates_ |= (pred << pred_idx); -+ } -+ } -+ -+ pointer_ += ( -+ filter_k_ * params.layout.stride()[3] + column -+ ) * sizeof_bits::value / 8; -+ -+ set_iteration_index(0); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ -+ LongIndex next = params_.inc_next_trs; -+ -+ // moves to the next tile -+ ++filter_trs_; -+ if (filter_trs_ == params_.TRS) { -+ -+ filter_trs_ = 0; -+ next = params_.inc_next_k; -+ filter_k_ += params_.filter_k_delta; -+ } -+ -+ // Clear predicates if needed -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ if (filter_k_ + s * ThreadMap::Delta::kStrided >= problem_size_.K) { -+ uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous); -+ -+ predicates_ = (predicates_ & (~kClearMask)); -+ } -+ } -+ -+ pointer_ += next; -+ } -+ -+ /// Returns true if the current coordinate is within the filter tensor W -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous; -+ return (predicates_ & (1u << pred_idx)); -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ return reinterpret_cast(pointer_ + -+ iteration_contiguous_ * ThreadMap::Delta::kContiguous * sizeof_bits::value / 8); -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv3dDgradFilterTileAccessIteratorOptimized &operator++() { -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ -+ // Move to the next K coordinate within the tile -+ pointer_ += params_.inc_next_strided; -+ -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv3dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % (128/sizeof_bits::value)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h -new file mode 100644 -index 0000000..5c399e2 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h -@@ -0,0 +1,343 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNDHWC layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_, -+ conv::StrideSupport StrideSupport_ = conv::StrideSupport::kStrided -+> -+class Conv3dDgradOutputGradientTileAccessIteratorAnalytic; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Conv3dDgradOutputGradientTileAccessIteratorAnalytic strided dgrad needs special handling using -+// unscaled coordinations -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_ -+> -+class Conv3dDgradOutputGradientTileAccessIteratorAnalytic < -+ Shape_, -+ Element_, -+ ThreadMap_, -+ conv::StrideSupport::kStrided -+> { -+public: -+ -+ // -+ // Types -+ // -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNDHWC; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AlignedArray; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 3; -+ using ConvProblemSize = typename conv::Conv3dProblemSize; -+ static int const kAccessesPerVector = 1; -+ -+ static_assert(sizeof_bits::value >= 8, -+ "DGRAD requires elements of size 8b or greater."); -+ -+ // -+ // Simpligying assertions -+ // -+ -+ static_assert(ThreadMap::Iterations::kContiguous == 1, -+ "Require Iterations::kContiguous == 1"); -+ -+ // -+ // Parameters structure -+ // -+ -+ struct Params { -+ -+ Layout layout; -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ConvProblemSize const &problem_size, -+ Layout const &layout -+ ): layout(layout) { -+ -+ } -+ }; -+ -+private: -+ -+ Params const ¶ms_; -+ ConvProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ char const *pointer_; -+ -+ int filter_k_; -+ int filter_t_; -+ int filter_r_; -+ int filter_s_; -+ -+ int offset_n_[ThreadMap::Iterations::kStrided]; -+ int offset_d_[ThreadMap::Iterations::kStrided]; -+ int offset_w_[ThreadMap::Iterations::kStrided]; -+ int offset_h_[ThreadMap::Iterations::kStrided]; -+ -+private: -+ -+ /// Returns the coordinate in the output tensor Dy that is currently pointed to -+ /// by the iterator but DOES NOT scale by the convolution stride. This is needed -+ /// to compute predicates in the valid() method. The return value of the public at() -+ /// method is correctly scaled. -+ CUTLASS_HOST_DEVICE -+ TensorCoord unscaled_at_() const { -+ int n = offset_n_[iteration_strided_]; -+ int d = offset_d_[iteration_strided_]; -+ int h = offset_h_[iteration_strided_]; -+ int w = offset_w_[iteration_strided_]; -+ -+ int t = filter_t_; -+ int r = filter_r_; -+ int s = filter_s_; -+ -+ if (problem_size_.mode == Mode::kConvolution) { -+ t = (problem_size_.T - 1 - t); -+ r = (problem_size_.R - 1 - r); -+ s = (problem_size_.S - 1 - s); -+ } -+ -+ int z = (d + problem_size_.pad_d - t * problem_size_.dilation_d); -+ int p = (h + problem_size_.pad_h - r * problem_size_.dilation_h); -+ int q = (w + problem_size_.pad_w - s * problem_size_.dilation_w); -+ -+ return TensorCoord(n, z, p, q, filter_k_); -+ } -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dDgradOutputGradientTileAccessIteratorAnalytic( -+ Params const ¶ms, -+ ConvProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() // threadblock offset - units are whole CTA tiles -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ filter_k_(0), -+ filter_t_(0), -+ filter_r_(0), -+ filter_s_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ filter_k_ = threadblock_offset.column() + thread_coord.contiguous(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ int offset_ndhw = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; -+ -+ offset_n_[s] = offset_ndhw / (problem_size_.D * problem_size_.H * problem_size_.W); -+ int residual = offset_ndhw % (problem_size_.D * problem_size_.H * problem_size_.W); -+ -+ offset_d_[s] = residual / (problem_size_.H * problem_size_.W); -+ residual = residual % (problem_size_.H * problem_size_.W); -+ -+ offset_h_[s] = residual / problem_size_.W; -+ offset_w_[s] = residual % problem_size_.W; -+ } -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(Conv3dProblemSize const &problem_size, Layout const &layout) { -+ return Params(problem_size, layout); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ // move to the next tile -+ ++filter_s_; -+ if (filter_s_ < problem_size_.S) { -+ return; -+ } -+ filter_s_ = 0; -+ ++filter_r_; -+ if (filter_r_ < problem_size_.R) { -+ return; -+ } -+ filter_r_ = 0; -+ ++filter_t_; -+ if (filter_t_ < problem_size_.T) { -+ return; -+ } -+ filter_t_ = 0; -+ -+ filter_k_ += Shape_::kColumn * problem_size_.split_k_slices; -+ } -+ -+ /// Returns the coordinate in the output tensor Dy that is currently pointed to -+ /// by the iterator. -+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ -+ TensorCoord coord = unscaled_at_(); -+ -+ return TensorCoord( -+ coord.n(), -+ coord.d() / problem_size_.stride_d, -+ coord.h() / problem_size_.stride_h, -+ coord.w() / problem_size_.stride_w, -+ coord.c()); -+ } -+ -+ -+ /// Returns true if the current coordinate is within the output tensor Dy -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ -+ TensorCoord unscaled_coord = unscaled_at_(); -+ TensorCoord coord = at(); -+ -+ return -+ !(unscaled_coord.d() % problem_size_.stride_d) && -+ !(unscaled_coord.h() % problem_size_.stride_h) && -+ !(unscaled_coord.w() % problem_size_.stride_w) && -+ coord.n() < problem_size_.N && -+ coord.d() >= 0 && coord.d() < problem_size_.Z && -+ coord.h() >= 0 && coord.h() < problem_size_.P && -+ coord.w() >= 0 && coord.w() < problem_size_.Q && -+ coord.c() < problem_size_.K; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ LongIndex offset = params_.layout(coord); -+ -+ return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv3dDgradOutputGradientTileAccessIteratorAnalytic &operator++() { -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(ConvProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.K % (128/sizeof_bits::value)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h -new file mode 100644 -index 0000000..f834a34 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h -@@ -0,0 +1,489 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNDHWC layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+#include "cutlass/conv/threadblock/conv3d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_, -+ conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity -+> -+class Conv3dDgradOutputGradientTileAccessIteratorOptimized { -+public: -+ -+ static_assert(StrideSupport_ == conv::StrideSupport::kUnity, -+ "Only unit-stride dgrad is supported at this time."); -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNDHWC; -+ using TensorCoord = typename Layout::TensorCoord; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AlignedArray; -+ using TensorRef = cutlass::TensorRef; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity; -+ static int const kConvDim = 3; -+ using ConvProblemSize = typename conv::Conv3dProblemSize; -+ using Coord3D = Coord<3>; -+ static int const kAccessesPerVector = 1; -+ using Mask = uint64_t; -+ -+ // -+ // Simplifying assertions -+ // -+ static_assert(ThreadMap::Iterations::kContiguous == 1, -+ "Require Iterations::kContiguous == 1"); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv3dDgradOutputGradientIteratorOptimizedParams; -+ -+private: -+ -+ Params const ¶ms_; -+ ConvProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ -+ -+ // One pointer per access -+ char const *pointer_[ThreadMap::Iterations::kStrided]; -+ -+ // current filter position (t, r, s) -+ int filter_t_; -+ int filter_r_; -+ int filter_s_; -+ int filter_k_; -+ -+ Index masks_[ThreadMap::Iterations::kStrided][3]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dDgradOutputGradientTileAccessIteratorOptimized( -+ Params const ¶ms, -+ ConvProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() // tile index - units are threadblock-scoped tiles -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ filter_k_(0), -+ filter_t_(0), -+ filter_r_(0), -+ filter_s_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ filter_k_ = threadblock_offset.column() + thread_coord.contiguous(); -+ -+ int offset_n[ThreadMap::Iterations::kStrided]; -+ int offset_d[ThreadMap::Iterations::kStrided]; -+ int offset_h[ThreadMap::Iterations::kStrided]; -+ int offset_w[ThreadMap::Iterations::kStrided]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ -+ pointer_[s] = reinterpret_cast(ptr); -+ -+ int offset_ndhw = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; -+ -+ // The subseqnet fast_divmod() operations are equivalent to the following logical computation: -+ // -+ // -+ // offset_n[s] = offset_ndhw / (problem_size_.D * problem_size_.H * problem_size_.W); -+ // int residual = offset_ndhw % (problem_size_.D * problem_size_.H * problem_size_.W); -+ // -+ // -+ // offset_d[s] = residual / (problem_size_.H * problem_size_.W); -+ // residual = residual % (problem_size_.H * problem_size_.W); -+ // -+ // offset_h[s] = residual / problem_size_.W; -+ // offset_w[s] = residual % problem_size_.W; -+ // -+ -+ int residual; -+ -+ // input: (ndhw offset) output: (n offset and resudial (dhw offset)) -+ params_.dhw_divmod(offset_n[s], residual, offset_ndhw); -+ // input: (dhw offset) output: (d offset and resudial (hw)) -+ params_.hw_divmod(offset_d[s], residual, residual); -+ // input: (hw offset) output: (h offset and resudial (w offset)) -+ params_.w_divmod(offset_h[s], offset_w[s], residual); -+ -+ TensorCoord coord = at_(offset_n[s], offset_d[s], offset_h[s], offset_w[s], 0, 0, 0); -+ -+ pointer_[s] += params_.layout(coord) * sizeof_bits::value / 8; -+ } -+ -+ clear_mask(); -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int t = 0; t < problem_size_.T; ++t) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { -+ -+ int t_ = t; -+ if (problem_size_.mode == Mode::kConvolution) { -+ t_ = problem_size_.T - 1 - t; -+ } -+ -+ int z = offset_d[s_idx] + problem_size_.pad_d - t_ * problem_size_.dilation_d; -+ -+ bool pred = (offset_n[s_idx] < problem_size_.N && z >= 0 && z < problem_size_.Z); -+ masks_[s_idx][0] |= (pred << t); -+ } -+ } -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int r = 0; r < problem_size_.R; ++r) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { -+ -+ int r_ = r; -+ if (problem_size_.mode == Mode::kConvolution) { -+ r_ = problem_size_.R - 1 - r; -+ } -+ -+ int p = offset_h[s_idx] + problem_size_.pad_h - r_ * problem_size_.dilation_h; -+ -+ bool pred = (p >= 0 && p < problem_size_.P); -+ masks_[s_idx][1] |= (pred << r); -+ } -+ } -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int s = 0; s < problem_size_.S; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { -+ -+ int s_ = s; -+ if (problem_size_.mode == Mode::kConvolution) { -+ s_ = problem_size_.S - 1 - s; -+ } -+ -+ int q = offset_w[s_idx] + problem_size_.pad_w - s_ * problem_size_.dilation_w; -+ -+ bool pred = (q >= 0 && q < problem_size_.Q); -+ masks_[s_idx][2] |= (pred << s); -+ } -+ } -+ -+ if (filter_k_ >= problem_size.K) { -+ clear_mask(); -+ } -+ -+ set_iteration_index(0); -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(Conv3dProblemSize const &problem_size, Layout const &layout) { -+ return Params(problem_size, -+ layout, -+ sizeof_bits::value, -+ {Shape::kRow, Shape::kColumn}, -+ ThreadMap::kThreads, -+ ThreadMap::kElementsPerAccess, -+ {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, -+ {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}); -+ } -+ -+private: -+ -+ -+ /// Returns the coordinate in the output gradient tensor dy that is correspoinding to -+ // activation ndhw and filter position k, t, r, s -+ CUTLASS_HOST_DEVICE -+ TensorCoord at_(int n, int d, int h, int w, int t, int r, int s) const { -+ -+ if (problem_size_.mode == Mode::kConvolution) { -+ t = problem_size_.T - 1 - t; -+ r = problem_size_.R - 1 - r; -+ s = problem_size_.S - 1 - s; -+ } -+ -+ int z = d + problem_size_.pad_d - t * problem_size_.dilation_d; -+ int p = h + problem_size_.pad_h - r * problem_size_.dilation_h; -+ int q = w + problem_size_.pad_w - s * problem_size_.dilation_w; -+ -+ return TensorCoord(n, z, p, q, filter_k_); -+ } -+ -+ -+ /// Adds a pointer offset in units of element -+ CUTLASS_HOST_DEVICE -+ void add_byte_offset_(LongIndex byte_offset) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ pointer_[s] += byte_offset; -+ } -+ } -+ -+ /// Clears the predicates -+ CUTLASS_HOST_DEVICE -+ void clear_mask_(bool clear) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ -+ // We are using inline PTX assembly here to avoid an CUDA C++ compilation -+ // artifact in which control flow instructions are generated. Instead, our -+ // intent is to predicate the mov instructions. -+ #if defined(__CUDA_ARCH__) -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " .reg .u32 m;" -+ " mov.u32 m, %2;" -+ " setp.ne.b32 p, %1, 0;\n" -+ " @p mov.u32 m, 0;\n" -+ " mov.u32 %0, m;\n" -+ "}\n" -+ : -+ "=r"(masks_[s][0]) -+ : -+ "r"((int)clear), -+ "r"(masks_[s][0]) -+ ); -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " .reg .u32 m;" -+ " mov.u32 m, %2;" -+ " setp.ne.b32 p, %1, 0;\n" -+ " @p mov.u32 m, 0;\n" -+ " mov.u32 %0, m;\n" -+ "}\n" -+ : -+ "=r"(masks_[s][1]) -+ : -+ "r"((int)clear), -+ "r"(masks_[s][1]) -+ ); -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " .reg .u32 m;" -+ " mov.u32 m, %2;" -+ " setp.ne.b32 p, %1, 0;\n" -+ " @p mov.u32 m, 0;\n" -+ " mov.u32 %0, m;\n" -+ "}\n" -+ : -+ "=r"(masks_[s][2]) -+ : -+ "r"((int)clear), -+ "r"(masks_[s][2]) -+ ); -+ #else -+ if (clear) { -+ masks_[s][0] = 0; -+ masks_[s][1] = 0; -+ masks_[s][2] = 0; -+ } -+ #endif -+ } -+ } -+ -+public: -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ add_byte_offset_(pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ -+ int next_idx = 0; -+ -+ // moves to the next tile -+ ++filter_s_; -+ if (filter_s_ == problem_size_.S) { -+ -+ filter_s_ = 0; -+ ++filter_r_; -+ next_idx = 1; -+ -+ if (filter_r_ == problem_size_.R) { -+ filter_r_ = 0; -+ ++filter_t_; -+ -+ if (filter_t_ < problem_size_.T) { -+ next_idx = 2; -+ } -+ else { -+ filter_t_ = 0; -+ next_idx = 3; -+ } -+ } -+ } -+ -+ add_byte_offset_(params_.inc_next[next_idx]); -+ -+ if (next_idx == 3) { -+ filter_k_ += params_.filter_k_delta; -+ } -+ -+ clear_mask_(filter_k_ >= problem_size_.K); -+ } -+ -+ -+ /// Clears the predicates -+ CUTLASS_HOST_DEVICE -+ void clear_mask() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ masks_[s][0] = Mask(0); -+ masks_[s][1] = Mask(0); -+ masks_[s][2] = Mask(0); -+ } -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ -+ return -+ (masks_[iteration_strided_][0] & (Index(1) << filter_t_)) && -+ (masks_[iteration_strided_][1] & (Index(1) << filter_r_)) && -+ (masks_[iteration_strided_][2] & (Index(1) << filter_s_)); -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ return reinterpret_cast(pointer_[iteration_strided_]); -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv3dDgradOutputGradientTileAccessIteratorOptimized &operator++() { -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(ConvProblemSize const &problem_size) { -+ -+ // This is specialized for unit stride -+ if (problem_size.stride() != Coord3D({1, 1, 1})) { -+ return Status::kErrorNotSupported; -+ } -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.K % (128/sizeof_bits::value)) { -+ return Status::kErrorNotSupported; -+ } -+ -+ // Limit on filter size -+ if (problem_size.T > 32 || problem_size.R > 32 || problem_size.S > 32) { -+ return Status::kErrorNotSupported; -+ } -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h -new file mode 100644 -index 0000000..0519ebe ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h -@@ -0,0 +1,291 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNDHWC layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+#include "cutlass/conv/threadblock/conv3d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_ -+> -+class Conv3dFpropActivationTileAccessIteratorAnalytic { -+public: -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNDHWC; -+ using TensorCoord = typename Layout::TensorCoord; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AlignedArray; -+ using TensorRef = cutlass::TensorRef; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 3; -+ using ConvProblemSize = typename conv::Conv3dProblemSize; -+ static int const kAccessesPerVector = 1; -+ -+ // -+ // Simplifying assertions -+ // -+ static_assert(ThreadMap::Iterations::kContiguous == 1, -+ "Require Iterations::kContiguous == 1"); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv3dAnalyticParams; -+ -+private: -+ -+ Params const ¶ms_; -+ ConvProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ char const *pointer_; -+ -+ int filter_t_; -+ int filter_r_; -+ int filter_s_; -+ int filter_c_; -+ -+ int offset_n_[ThreadMap::Iterations::kStrided]; -+ int offset_z_[ThreadMap::Iterations::kStrided]; -+ int offset_p_[ThreadMap::Iterations::kStrided]; -+ int offset_q_[ThreadMap::Iterations::kStrided]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dFpropActivationTileAccessIteratorAnalytic( -+ Params const ¶ms, -+ ConvProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() // tile index - units are threadblock-scoped tiles -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ filter_t_(0), -+ filter_r_(0), -+ filter_s_(0), -+ filter_c_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ filter_c_ = threadblock_offset.column() + thread_coord.contiguous(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ int offset_nzpq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; -+ -+ offset_n_[s] = offset_nzpq / (problem_size_.Z * problem_size_.P * problem_size_.Q); -+ int residual = offset_nzpq % (problem_size_.Z * problem_size_.P * problem_size_.Q); -+ -+ offset_z_[s] = residual / (problem_size_.P * problem_size_.Q); -+ residual = residual % (problem_size_.P * problem_size_.Q); -+ -+ offset_p_[s] = residual / problem_size_.Q; -+ offset_q_[s] = residual % problem_size_.Q; -+ } -+ -+ set_iteration_index(0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(Conv3dProblemSize const &problem_size, Layout const &layout) { -+ return Params(problem_size, layout); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ // moves to the next tile -+ ++filter_s_; -+ if (filter_s_ < problem_size_.S) { -+ return; -+ } -+ filter_s_ = 0; -+ ++filter_r_; -+ if (filter_r_ < problem_size_.R) { -+ return; -+ } -+ filter_r_ = 0; -+ ++filter_t_; -+ if (filter_t_ < problem_size_.T) { -+ return; -+ } -+ filter_t_ = 0; -+ -+ filter_c_ += Shape::kColumn * problem_size_.split_k_slices; -+ } -+ -+ /// Returns the coordinate in the activations tensor X that is currently pointed to -+ /// by the iterator. -+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ int n = offset_n_[iteration_strided_]; -+ int z = offset_z_[iteration_strided_]; -+ int p = offset_p_[iteration_strided_]; -+ int q = offset_q_[iteration_strided_]; -+ -+ int t = filter_t_; -+ int r = filter_r_; -+ int s = filter_s_; -+ -+ if (problem_size_.mode == Mode::kConvolution) { -+ t = (problem_size_.T - 1 - filter_t_); -+ r = (problem_size_.R - 1 - filter_r_); -+ s = (problem_size_.S - 1 - filter_s_); -+ } -+ -+ int d = z * problem_size_.stride_d - problem_size_.pad_d + t * problem_size_.dilation_d; -+ int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h; -+ int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w; -+ -+ return TensorCoord(n, d, h, w, filter_c_); -+ } -+ -+ /// Returns true if the current coordinate is within the activations tensor X -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ -+ TensorCoord coord = at(); -+ -+ return coord.n() < problem_size_.N && -+ coord.d() >= 0 && coord.d() < problem_size_.D && -+ coord.h() >= 0 && coord.h() < problem_size_.H && -+ coord.w() >= 0 && coord.w() < problem_size_.W && -+ coord.c() < problem_size_.C; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ LongIndex offset = params_.layout(coord); -+ -+ AccessType const *ptr = reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ -+ return ptr; -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv3dFpropActivationTileAccessIteratorAnalytic &operator++() { -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(ConvProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % (128/sizeof_bits::value)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h -new file mode 100644 -index 0000000..c51eb59 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h -@@ -0,0 +1,478 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNDHWC layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+#include "cutlass/conv/threadblock/conv3d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename Layout_, -+ typename ThreadMap_ -+> -+class Conv3dFpropActivationTileAccessIteratorOptimized { -+public: -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = Layout_; -+ using TensorCoord = typename Layout::TensorCoord; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AlignedArray; -+ using TensorRef = cutlass::TensorRef; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 3; -+ using ConvProblemSize = typename conv::Conv3dProblemSize; -+ static int const kAccessesPerVector = 1; -+ using Mask = uint64_t; -+ -+ // -+ // Simplifying assertions -+ // -+ static_assert(ThreadMap::Iterations::kContiguous == 1, -+ "Require Iterations::kContiguous == 1"); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv3dFpropActivationIteratorOptimizedParams; -+ -+private: -+ -+ Conv3dFpropActivationIteratorOptimizedParams const ¶ms_; -+ Conv3dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ -+ // One pointer per access -+ char const *pointer_[ThreadMap::Iterations::kStrided]; -+ -+ // current filter position (t, r, s) -+ int filter_t_; -+ int filter_r_; -+ int filter_s_; -+ int filter_c_; -+ -+ // mask for t, r, and s -+ Index masks_[ThreadMap::Iterations::kStrided][3]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dFpropActivationTileAccessIteratorOptimized( -+ Conv3dFpropActivationIteratorOptimizedParams const ¶ms, -+ Conv3dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() // tile index - units are threadblock-scoped tiles -+ ) : -+ params_(params), -+ problem_size_(problem_size), -+ filter_t_(0), -+ filter_r_(0), -+ filter_s_(0), -+ filter_c_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ filter_c_ = threadblock_offset.column() + thread_coord.contiguous(); -+ -+ int offset_n[ThreadMap::Iterations::kStrided]; -+ int offset_z[ThreadMap::Iterations::kStrided]; -+ int offset_p[ThreadMap::Iterations::kStrided]; -+ int offset_q[ThreadMap::Iterations::kStrided]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ -+ pointer_[s] = reinterpret_cast(ptr); -+ -+ int offset_nzpq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; -+ -+ // The subseqnet fast_divmod() operations are equivalent to the following logical computation: -+ // -+ // -+ // offset_n[s] = offset_nzpq / (problem_size_.Z * problem_size_.P * problem_size_.Q); -+ // int residual = offset_nzpq % (problem_size_.Z * problem_size_.P * problem_size_.Q); -+ // -+ // offset_z[s] = residual / (problem_size_.P * problem_size_.Q); -+ // residual = residual % (problem_size_.P * problem_size_.Q); -+ // -+ // offset_p[s] = residual / problem_size_.Q; -+ // offset_q[s] = residual % problem_size_.Q; -+ // -+ -+ int residual; -+ -+ // input: (nzpq offset) output: (n offset and resudial (zpq offset)) -+ params.zpq_divmod(offset_n[s], residual, offset_nzpq); -+ // input: (zpq offset) output: (z offset and resudial (pq)) -+ params.pq_divmod(offset_z[s], residual, residual); -+ // input: (pq offset) output: (p offset and resudial (q offset)) -+ params.q_divmod(offset_p[s], offset_q[s], residual); -+ -+ TensorCoord coord = at_(offset_n[s], offset_z[s], offset_p[s], offset_q[s], 0, 0, 0); -+ -+ pointer_[s] += params_.layout(coord) * sizeof_bits::value / 8; -+ } -+ -+ clear_mask(); -+ -+ // mask predicates for filter position T -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int t = 0; t < problem_size_.T; ++t) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { -+ -+ int t_ = t; -+ if (problem_size_.mode == Mode::kConvolution) { -+ t_ = problem_size_.T - 1 - t; -+ } -+ -+ int d = offset_z[s_idx] * problem_size_.stride_d - problem_size_.pad_d + t_ * problem_size_.dilation_d; -+ -+ bool pred = (offset_n[s_idx] < problem_size_.N && d >= 0 && d < problem_size_.D); -+ masks_[s_idx][0] |= (pred << t); -+ } -+ } -+ -+ // mask predicates for filter position R -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int r = 0; r < problem_size_.R; ++r) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { -+ -+ int r_ = r; -+ if (problem_size_.mode == Mode::kConvolution) { -+ r_ = problem_size_.R - 1 - r; -+ } -+ -+ int h = offset_p[s_idx] * problem_size_.stride_h - problem_size_.pad_h + r_ * problem_size_.dilation_h; -+ -+ bool pred = (h >= 0 && h < problem_size_.H); -+ masks_[s_idx][1] |= (pred << r); -+ } -+ } -+ -+ // mask predicates for filter position S -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int s = 0; s < problem_size_.S; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { -+ -+ int s_ = s; -+ if (problem_size_.mode == Mode::kConvolution) { -+ s_ = problem_size_.S - 1 - s; -+ } -+ -+ int w = offset_q[s_idx] * problem_size_.stride_w - problem_size_.pad_w + s_ * problem_size_.dilation_w; -+ -+ bool pred = (w >= 0 && w < problem_size_.W); -+ masks_[s_idx][2] |= (pred << s); -+ } -+ } -+ -+ if (filter_c_ >= problem_size.C) { -+ clear_mask(); -+ } -+ -+ set_iteration_index(0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(Conv3dProblemSize const &problem_size, Layout const &layout) { -+ return Params(problem_size, -+ layout, -+ sizeof_bits::value, -+ {Shape::kRow, Shape::kColumn}, -+ ThreadMap::kThreads, -+ ThreadMap::kElementsPerAccess, -+ {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, -+ {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}); -+ } -+ -+private: -+ -+ /// Returns the coordinate in the activations tensor X that is correspoinding to -+ // output nzpq and filter position t, r, s -+ CUTLASS_HOST_DEVICE -+ TensorCoord at_(int n, int z, int p, int q, int t, int r, int s) const { -+ -+ if (problem_size_.mode == Mode::kConvolution) { -+ t = problem_size_.T - 1 - t; -+ r = problem_size_.R - 1 - r; -+ s = problem_size_.S - 1 - s; -+ } -+ -+ int d = z * problem_size_.stride_d - problem_size_.pad_d + t * problem_size_.dilation_d; -+ int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h; -+ int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w; -+ -+ return TensorCoord(n, d, h, w, filter_c_); -+ } -+ -+ /// Adds a pointer offset in units of element -+ CUTLASS_HOST_DEVICE -+ void add_byte_offset_(LongIndex byte_offset) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ pointer_[s] += byte_offset; -+ } -+ } -+ -+ -+ /// Clears the predicates -+ CUTLASS_HOST_DEVICE -+ void clear_mask_(bool clear) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ -+ // We are using inline PTX assembly here to avoid an CUDA C++ compilation -+ // artifact in which control flow instructions are generated. Instead, our -+ // intent is to predicate the mov instructions. -+ #if defined(__CUDA_ARCH__) -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " .reg .u32 m;" -+ " mov.u32 m, %2;" -+ " setp.ne.b32 p, %1, 0;\n" -+ " @p mov.u32 m, 0;\n" -+ " mov.u32 %0, m;\n" -+ "}\n" -+ : -+ "=r"(masks_[s][0]) -+ : -+ "r"((int)clear), -+ "r"(masks_[s][0]) -+ ); -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " .reg .u32 m;" -+ " mov.u32 m, %2;" -+ " setp.ne.b32 p, %1, 0;\n" -+ " @p mov.u32 m, 0;\n" -+ " mov.u32 %0, m;\n" -+ "}\n" -+ : -+ "=r"(masks_[s][1]) -+ : -+ "r"((int)clear), -+ "r"(masks_[s][1]) -+ ); -+ asm volatile( -+ "{\n" -+ " .reg .pred p;\n" -+ " .reg .u32 m;" -+ " mov.u32 m, %2;" -+ " setp.ne.b32 p, %1, 0;\n" -+ " @p mov.u32 m, 0;\n" -+ " mov.u32 %0, m;\n" -+ "}\n" -+ : -+ "=r"(masks_[s][2]) -+ : -+ "r"((int)clear), -+ "r"(masks_[s][2]) -+ ); -+ #else -+ if (clear) { -+ masks_[s][0] = 0; -+ masks_[s][1] = 0; -+ masks_[s][2] = 0; -+ } -+ #endif -+ } -+ } -+ -+public: -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ add_byte_offset_(pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ -+ int next_idx = 0; -+ -+ // moves to the next tile -+ ++filter_s_; -+ if (filter_s_ == problem_size_.S) { -+ -+ filter_s_ = 0; -+ ++filter_r_; -+ next_idx = 1; -+ -+ if (filter_r_ == problem_size_.R) { -+ filter_r_ = 0; -+ ++filter_t_; -+ -+ if (filter_t_ < problem_size_.T) { -+ next_idx = 2; -+ } -+ else { -+ filter_t_ = 0; -+ next_idx = 3; -+ } -+ } -+ } -+ -+ add_byte_offset_(params_.inc_next[next_idx]); -+ -+ if (next_idx == 3) { -+ filter_c_ += params_.filter_c_delta; -+ } -+ -+ clear_mask_(filter_c_ >= problem_size_.C); -+ } -+ -+ /// Clears the predicates -+ CUTLASS_HOST_DEVICE -+ void clear_mask() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ masks_[s][0] = Mask(0); -+ masks_[s][1] = Mask(0); -+ masks_[s][2] = Mask(0); -+ } -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ -+ return -+ (masks_[iteration_strided_][0] & (Index(1) << filter_t_)) && -+ (masks_[iteration_strided_][1] & (Index(1) << filter_r_)) && -+ (masks_[iteration_strided_][2] & (Index(1) << filter_s_)); -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ return reinterpret_cast(pointer_[iteration_strided_]); -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv3dFpropActivationTileAccessIteratorOptimized &operator++() { -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv3dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % (128/sizeof_bits::value)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ // Conv3dFpropActivationTileAccessIteratorOptimized has constraint on filter positions -+ // due to the number of mask bits. -+ if (problem_size.T > 32 || problem_size.R > 32 || problem_size.S > 32) { -+ return Status::kErrorNotSupported; -+ } -+ return Status::kSuccess; -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h -new file mode 100644 -index 0000000..41d87fe ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h -@@ -0,0 +1,253 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNDHWC layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+#include "cutlass/conv/threadblock/conv3d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_ -+> -+class Conv3dFpropFilterTileAccessIteratorAnalytic { -+public: -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNDHWC; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AlignedArray; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 3; -+ using ConvProblemSize = typename conv::Conv3dProblemSize; -+ static int const kAccessesPerVector = 1; -+ -+ // -+ // Simplifying assertions -+ // -+ static_assert(ThreadMap::Iterations::kContiguous == 1, -+ "Require Iterations::kContiguous == 1"); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Conv3dAnalyticParams; -+ -+private: -+ -+ Params const ¶ms_; -+ ConvProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ char const *pointer_; -+ -+ int filter_t_; -+ int filter_r_; -+ int filter_s_; -+ int filter_c_; -+ -+ int offset_k_[ThreadMap::Iterations::kStrided]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dFpropFilterTileAccessIteratorAnalytic( -+ Params const ¶ms, -+ ConvProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ filter_t_(0), -+ filter_r_(0), -+ filter_s_(0), -+ filter_c_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ filter_c_ = threadblock_offset.row() + thread_coord.contiguous(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ offset_k_[s] = threadblock_offset.column() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; -+ } -+ -+ set_iteration_index(0); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * 8 / sizeof_bits::value; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ // moves to the next tile -+ ++filter_s_; -+ if (filter_s_ < problem_size_.S) { -+ return; -+ } -+ filter_s_ = 0; -+ -+ ++filter_r_; -+ if (filter_r_ < problem_size_.R) { -+ return; -+ } -+ filter_r_ = 0; -+ -+ ++filter_t_; -+ if (filter_t_ < problem_size_.T) { -+ return; -+ } -+ filter_t_ = 0; -+ -+ filter_c_ += Shape::kRow * problem_size_.split_k_slices; -+ } -+ -+ /// Returns the coordinate in the filter tensor W that is currently pointed to -+ /// by the iterator. -+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ -+ int k = offset_k_[iteration_strided_]; -+ -+ return TensorCoord(k, filter_t_, filter_r_, filter_s_, filter_c_); -+ } -+ -+ /// Returns true if the current coordinate is within the activations tensor W -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ -+ TensorCoord coord = at(); -+ -+ return coord.n() < problem_size_.K && -+ coord.c() < problem_size_.C; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ LongIndex offset = params_.layout(coord); -+ -+ return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv3dFpropFilterTileAccessIteratorAnalytic &operator++() { -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(ConvProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.K % (128/sizeof_bits::value)) { -+ return Status::kErrorInvalidProblem; -+ } -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h -new file mode 100644 -index 0000000..c6c6f6f ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h -@@ -0,0 +1,277 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC or TensorCxRSKx layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+ -+#include "cutlass/conv/threadblock/conv3d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename Layout_, -+ typename ThreadMap_ -+> -+class Conv3dFpropFilterTileAccessIteratorOptimized{ -+public: -+ -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = Layout_; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AlignedArray; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 3; -+ using ConvProblemSize = typename conv::Conv3dProblemSize; -+ static int const kAccessesPerVector = 1; -+ -+ // -+ // Simplifying assertions -+ // -+ static_assert(ThreadMap::Iterations::kContiguous == 1, -+ "Require Iterations::kContiguous == 1"); -+ -+ // -+ // Parameters structure -+ // -+ -+ struct Params : Conv3dFpropFilterIteratorOptimizedParams { -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Conv3dFpropFilterIteratorOptimizedParams const &base): -+ Conv3dFpropFilterIteratorOptimizedParams(base) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ Conv3dProblemSize const &problem_size, -+ Layout const &layout -+ ): -+ Conv3dFpropFilterIteratorOptimizedParams( -+ problem_size, -+ layout, -+ sizeof_bits::value, -+ {Shape::kRow, Shape::kColumn}, -+ ThreadMap::kThreads, -+ ThreadMap::kElementsPerAccess, -+ {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, -+ {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided} -+ ) { -+ -+ } -+ }; -+ -+private: -+ -+ Conv3dFpropFilterIteratorOptimizedParams const ¶ms_; -+ Conv3dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ char const *pointer_; -+ -+ uint32_t predicates_; -+ int filter_trs_; -+ int filter_c_; -+ -+ // -+ // Assertions -+ // -+ -+ // We map predicates into bits packed in this uint32_t container -+ static_assert(ThreadMap::Iterations::kStrided < sizeof(predicates_) * 8, -+ "Currently, the number of loads per iteration is limited by the size of the predicates container."); -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dFpropFilterTileAccessIteratorOptimized( -+ Conv3dFpropFilterIteratorOptimizedParams const ¶ms, -+ Conv3dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ predicates_{0}, -+ filter_trs_(0), -+ filter_c_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ filter_c_ = threadblock_offset.row() + thread_coord.contiguous(); -+ Index column = threadblock_offset.column() + thread_coord.strided(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ uint32_t pred = ((column + s * ThreadMap::Delta::kStrided < problem_size_.K) ? 1u : 0); -+ predicates_ |= (pred << s); -+ } -+ -+ if (filter_c_ >= problem_size.C) { -+ predicates_ = 0u; -+ } -+ -+ pointer_ += ( -+ params_.layout({filter_c_, column}) -+ ) * sizeof_bits::value / 8; -+ -+ set_iteration_index(0); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ -+ LongIndex next = params_.inc_next_trs; -+ -+ // moves to the next tile -+ ++filter_trs_; -+ if (filter_trs_ == params_.TRS) { -+ -+ filter_trs_ = 0; -+ next = params_.inc_next_c; -+ filter_c_ += params_.filter_c_delta; -+ } -+ -+ if (filter_c_ >= problem_size_.C) { -+ predicates_ = 0; -+ } -+ -+ pointer_ += next; -+ } -+ -+ /// Returns true if the current coordinate is within the filter tensor W -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return (predicates_ & (1u << iteration_strided_)); -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ return reinterpret_cast(pointer_); -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv3dFpropFilterTileAccessIteratorOptimized &operator++() { -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ -+ // Move to the next K coordinate within the tile -+ pointer_ += params_.inc_next_k; -+ -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv3dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % (128/sizeof_bits::value)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_params.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_params.h -new file mode 100644 -index 0000000..180dca5 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_params.h -@@ -0,0 +1,508 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief Extracts the host-params objects into non-template code. -+*/ -+ -+#pragma once -+ -+#define TRACE_CONV_PARAMS_INITIALIZERS_ENABLED 0 -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/threadblock/conv2d_params.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+ -+#if TRACE_CONV_PARAMS_INITIALIZERS_ENABLED -+#include -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Params structure used for all Conv3d analytic tile iterators -+template< typename Layout_ = layout::TensorNDHWC > -+struct Conv3dAnalyticParams { -+ -+ using Layout = Layout_; -+ -+ Layout layout; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dAnalyticParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dAnalyticParams( -+ Conv3dProblemSize const &, // unused; placeholder to match other Params interfaces. -+ Layout const &layout -+ ): layout(layout) { -+ -+ } -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Parameters structure used for Conv3dFpropActivationTileIteratorOptimized -+template< typename Layout_ = layout::TensorNDHWC > -+struct Conv3dFpropActivationIteratorOptimizedParams; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Parameters structure used for Conv3dFpropActivationTileIteratorOptimized -+template<> -+struct Conv3dFpropActivationIteratorOptimizedParams { -+ -+ using Layout = layout::TensorNDHWC; -+ -+ Layout layout; -+ -+ int64_t inc_next[4]; // {next S, next R, next T, next C} -+ int filter_c_delta; // number of logical elements to add to filter_c_ -+ int ZPQ; // product of Z*P*Q -+ int PQ; // product of P*Q -+ -+ FastDivmod zpq_divmod; -+ FastDivmod pq_divmod; -+ FastDivmod q_divmod; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dFpropActivationIteratorOptimizedParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dFpropActivationIteratorOptimizedParams( -+ Conv3dProblemSize const &problem_size, -+ Layout const &layout, ///< layout object -+ int element_size_bits, ///< size of each element in bits -+ MatrixCoord threadblock_shape, -+ int thread_count, -+ int access_size, -+ layout::PitchLinearCoord threadmap_iterations, -+ layout::PitchLinearCoord threadmap_delta -+ ): -+ layout(layout), -+ PQ(problem_size.P * problem_size.Q), -+ ZPQ(problem_size.Z * problem_size.P * problem_size.Q), -+ zpq_divmod(ZPQ), -+ pq_divmod(PQ), -+ q_divmod(problem_size.Q) { -+ -+ TRACE_CONV_INITIALIZERS("conv3d_fprop", "activation", -+ element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); -+ -+ -+ int conv_sign = (problem_size.mode == Mode::kConvolution ? -1 : 1); -+ -+ // next S -+ inc_next[0] = conv_sign * ( -+ int64_t(layout.stride()[0]) * problem_size.dilation_w -+ ) * element_size_bits / 8; -+ -+ // next R -+ inc_next[1] = conv_sign * ( -+ int64_t(layout.stride()[1]) * problem_size.dilation_h -+ - (problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w -+ ) * element_size_bits / 8; -+ -+ // next T -+ inc_next[2] = conv_sign * ( -+ int64_t(layout.stride()[2]) * problem_size.dilation_d -+ - (problem_size.R - 1) * layout.stride()[1] * problem_size.dilation_h -+ - (problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w -+ ) * element_size_bits / 8; -+ -+ // next C -+ inc_next[3] = ( -+ threadblock_shape.column() * problem_size.split_k_slices -+ - conv_sign * int64_t(problem_size.T - 1) * layout.stride()[2] * problem_size.dilation_d -+ - conv_sign * int64_t(problem_size.R - 1) * layout.stride()[1] * problem_size.dilation_h -+ - conv_sign * int64_t(problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w -+ ) * element_size_bits / 8; -+ -+ // logical offset added to internal channel counter - units are elements, not bytes -+ filter_c_delta = threadblock_shape.column() * problem_size.split_k_slices; -+ } -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+template< typename Layout_ = layout::TensorNDHWC > -+struct Conv3dFpropFilterIteratorOptimizedParams; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template<> -+struct Conv3dFpropFilterIteratorOptimizedParams -+{ -+ -+ using Layout = layout::TensorNDHWC; -+ -+ Layout layout; -+ int TRS; -+ int filter_c_delta; -+ -+ int64_t inc_next_k; // offset in units of bytes to next K position -+ int64_t inc_next_trs; // offset in units of bytes to next TRS position -+ int64_t inc_next_c; // offset in units of bytes to next C position -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Conv3dFpropFilterIteratorOptimizedParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dFpropFilterIteratorOptimizedParams( -+ Conv3dProblemSize const &problem_size, -+ Layout const &layout, -+ int element_size_bits, ///< size of each element in bits -+ MatrixCoord threadblock_shape, -+ int thread_count, -+ int access_size, -+ layout::PitchLinearCoord threadmap_iterations, -+ layout::PitchLinearCoord threadmap_delta -+ ): -+ layout(layout) { -+ -+ TRACE_CONV_INITIALIZERS("conv3d_fprop", "filter", -+ element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); -+ -+ TRS = problem_size.T * problem_size.R * problem_size.S; -+ -+ inc_next_k = (int64_t(layout.stride()[3]) * threadmap_delta.strided() * element_size_bits) / 8; -+ -+ inc_next_trs = -+ ( int64_t(layout.stride()[0]) -+ - int64_t(layout.stride()[3]) * (threadmap_iterations.strided() - 1) * threadmap_delta.strided() -+ ) * element_size_bits / 8; -+ -+ inc_next_c = -+ ( -+ threadblock_shape.row() * problem_size.split_k_slices -+ - int64_t(TRS - 1) * layout.stride()[0] -+ - int64_t(threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[3] -+ ) * element_size_bits / 8; -+ -+ filter_c_delta = threadblock_shape.row() * problem_size.split_k_slices; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Parameters object for Conv3d DGRAD OutputGradient (dy) iterator -+struct Conv3dDgradOutputGradientIteratorOptimizedParams { -+ -+ using Layout = layout::TensorNDHWC; -+ -+ Layout layout; -+ -+ int64_t inc_next[4]; // {next S, next R, next T, next K} -+ int filter_k_delta; // number of logical elements to add to filter_k_ -+ -+ FastDivmod dhw_divmod; -+ FastDivmod hw_divmod; -+ FastDivmod w_divmod; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dDgradOutputGradientIteratorOptimizedParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dDgradOutputGradientIteratorOptimizedParams( -+ Conv3dProblemSize const &problem_size, -+ Layout const &layout, ///< layout object -+ int element_size_bits, ///< size of each element in bits -+ MatrixCoord threadblock_shape, -+ int thread_count, -+ int access_size, -+ layout::PitchLinearCoord threadmap_iterations, -+ layout::PitchLinearCoord threadmap_delta -+ ): -+ layout(layout), -+ dhw_divmod(problem_size.D * problem_size.H * problem_size.W), -+ hw_divmod(problem_size.H * problem_size.W), -+ w_divmod(problem_size.W) { -+ -+ TRACE_CONV_INITIALIZERS("conv3d_dgrad", "output_gradient", -+ element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); -+ -+ int conv_sign = (problem_size.mode == Mode::kConvolution ? 1 : -1); -+ -+ // next S -+ inc_next[0] = conv_sign * ( -+ int64_t(layout.stride()[0]) * problem_size.dilation_w -+ ) * element_size_bits / 8; -+ -+ // next R -+ inc_next[1] = conv_sign * ( -+ int64_t(layout.stride()[1]) * problem_size.dilation_h -+ - (problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w -+ ) * element_size_bits / 8; -+ -+ // next T -+ inc_next[2] = conv_sign * ( -+ int64_t(layout.stride()[2]) * problem_size.dilation_d -+ - (problem_size.R - 1) * layout.stride()[1] * problem_size.dilation_h -+ - (problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w -+ ) * element_size_bits / 8; -+ -+ // next K -+ inc_next[3] = ( -+ threadblock_shape.column() * problem_size.split_k_slices -+ - conv_sign * int64_t(problem_size.T - 1) * layout.stride()[2] * problem_size.dilation_d -+ - conv_sign * int64_t(problem_size.R - 1) * layout.stride()[1] * problem_size.dilation_h -+ - conv_sign * int64_t(problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w -+ ) * element_size_bits / 8; -+ -+ // logical offset added to internal channel counter - units are elements, not bytes -+ filter_k_delta = threadblock_shape.column() * problem_size.split_k_slices; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Parameters object for Conv2d DGRAD Filter (w) iterator -+struct Conv3dDgradFilterIteratorOptimizedParams { -+ -+ using Layout = layout::TensorNDHWC; -+ -+ Layout layout; -+ int TRS; -+ int filter_k_delta; -+ -+ int64_t inc_next_strided; // offset in units of bytes to next K coordinate within tile -+ int64_t inc_next_trs; // offset in units of bytes to next TRS position -+ int64_t inc_next_k; // offset in units of bytes to next K position in subsequent tile -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Conv3dDgradFilterIteratorOptimizedParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dDgradFilterIteratorOptimizedParams( -+ Conv3dProblemSize const &problem_size, -+ Layout const &layout, -+ int element_size_bits, ///< size of each element in bits -+ MatrixCoord threadblock_shape, -+ int thread_count, -+ int access_size, -+ layout::PitchLinearCoord threadmap_iterations, -+ layout::PitchLinearCoord threadmap_delta -+ ): -+ layout(layout), TRS(problem_size.T * problem_size.R * problem_size.S) { -+ -+ TRACE_CONV_INITIALIZERS("conv3d_dgrad", "filter", -+ element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); -+ -+ inc_next_strided = ((int64_t)layout.stride()[3] * threadmap_delta.strided() * element_size_bits) / 8; -+ -+ inc_next_trs = -+ ( (int64_t)layout.stride()[0] -+ - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[3] -+ ) * element_size_bits / 8; -+ -+ inc_next_k = -+ ( -+ threadblock_shape.row() * problem_size.split_k_slices * (int64_t)layout.stride()[3] -+ - (problem_size.T * problem_size.R * problem_size.S - 1) * (int64_t)layout.stride()[0] -+ - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[3] -+ ) * element_size_bits / 8; -+ -+ filter_k_delta = threadblock_shape.row() * problem_size.split_k_slices; -+ } -+}; -+ -+/// Parameters object for Conv3d WGRAD OutputGradient iterator -+struct Conv3dWgradOutputGradientIteratorOptimizedParams { -+ -+ using Layout = layout::TensorNDHWC; -+ using LongIndex = typename Layout::LongIndex; -+ -+ Layout layout; -+ -+ int NZPQ; // precomputd product of N*Z*P*Q for clearing predicates -+ int ZPQ; // product of Z*P*Q -+ unsigned zpq_mul; // precomputed quantities for fast computation of div/% by ZPQ -+ unsigned zpq_shr; // in device code. -+ -+ int PQ; // product of P*Q -+ unsigned pq_mul; // precomputed quantities for fast computation of div/% by PQ -+ unsigned pq_shr; // in device code. -+ -+ unsigned q_mul; // precomputed quantities for fast computation of div/% by Q -+ unsigned q_shr; // in device code. -+ -+ LongIndex offset_next_strided; // offset in units of bytes to next nzpq coordinate within tile -+ LongIndex offset_next_contiguous; // offset in units of bytes to next k coordinate within tile -+ LongIndex inc_next_nzpq; // offset in units of bytes to next nzpq position in subsequent tile -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dWgradOutputGradientIteratorOptimizedParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dWgradOutputGradientIteratorOptimizedParams( -+ Conv3dProblemSize const &problem_size, -+ Layout const &layout, -+ int element_size_bits, -+ MatrixCoord threadblock_shape, -+ int thread_count, -+ int access_size, -+ layout::PitchLinearCoord threadmap_iterations, -+ layout::PitchLinearCoord threadmap_delta -+ ): layout(layout) { -+ -+ TRACE_CONV_INITIALIZERS("conv3d_wgrad", "output_gradient", -+ element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); -+ -+ // Incremental offsets in unites of bytes (number of elements) * element_size_bits / 8 -+ offset_next_strided = (threadmap_delta.strided() * (int64_t)layout.stride()[0]) -+ * element_size_bits / 8; -+ -+ offset_next_contiguous = (threadmap_delta.contiguous()) -+ * element_size_bits / 8; -+ -+ inc_next_nzpq = (threadblock_shape.column() * problem_size.split_k_slices * (int64_t)layout.stride()[0]) -+ * element_size_bits / 8; -+ -+ // Precompute several quantities for fast modulo arithmetic. -+ NZPQ = problem_size.N * problem_size.Z * problem_size.P * problem_size.Q; -+ ZPQ = problem_size.Z * problem_size.P * problem_size.Q; -+ find_divisor(zpq_mul, zpq_shr, ZPQ); -+ -+ PQ = problem_size.P * problem_size.Q; -+ find_divisor(pq_mul, pq_shr, PQ); -+ -+ find_divisor(q_mul, q_shr, problem_size.Q); -+ -+ } -+}; -+ -+/// Parameters object for Conv3d WGRAD Activation Tile Access Iterator -+struct Conv3dWgradActivationIteratorOptimizedParams { -+ -+ using Layout = layout::TensorNDHWC; -+ -+ Layout layout; -+ -+ int RSC; // product of R*S*C -+ unsigned rsc_mul; // precomputed quantities for fast computation of div/% by RSC -+ unsigned rsc_shr; // in device code. -+ -+ int SC; // product of S*C -+ unsigned sc_mul; // precomputed quantities for fast computation of div/% by SC -+ unsigned sc_shr; // in device code. -+ -+ unsigned c_mul; // precomputed quantities for fast computation of div/% by C -+ unsigned c_shr; // in device code. -+ -+ int ZPQ; // product of Z*P*Q -+ unsigned zpq_mul; // precomputed quantities for fast computation of div/% by ZPQ -+ unsigned zpq_shr; // in device code. -+ -+ int PQ; // product of P*Q -+ unsigned pq_mul; // precomputed quantities for fast computation of div/% by PQ -+ unsigned pq_shr; // in device code. -+ -+ unsigned q_mul; // precomputed quantities for fast computation of div/% by Q -+ unsigned q_shr; // in device code. -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Conv3dWgradActivationIteratorOptimizedParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dWgradActivationIteratorOptimizedParams( -+ Conv3dProblemSize const &problem_size, -+ Layout const &layout, -+ int element_size_bits, -+ MatrixCoord threadblock_shape, -+ int thread_count, -+ int access_size, -+ layout::PitchLinearCoord threadmap_iterations, -+ layout::PitchLinearCoord threadmap_delta -+ ): layout(layout) { -+ -+ TRACE_CONV_INITIALIZERS("conv3d_wgrad", "activation", -+ element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); -+ -+ // Precompute several quantities for fast modulo arithmetic. -+ RSC = problem_size.R * problem_size.S * problem_size.C; -+ find_divisor(rsc_mul, rsc_shr, RSC); -+ -+ SC = problem_size.S * problem_size.C; -+ find_divisor(sc_mul, sc_shr, SC); -+ -+ find_divisor(c_mul, c_shr, problem_size.C); -+ -+ ZPQ = problem_size.Z * problem_size.P * problem_size.Q; -+ find_divisor(zpq_mul, zpq_shr, ZPQ); -+ -+ PQ = problem_size.P * problem_size.Q; -+ find_divisor(pq_mul, pq_shr, PQ); -+ -+ find_divisor(q_mul, q_shr, problem_size.Q); -+ -+ } -+}; -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h -new file mode 100644 -index 0000000..d9fe9ad ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h -@@ -0,0 +1,289 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM B (activation tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNDHWC layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_ -+> -+class Conv3dWgradActivationTileAccessIteratorAnalytic { -+public: -+ -+ // -+ // Types -+ // -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNDHWC; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AlignedArray; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 3; -+ using ConvProblemSize = typename conv::Conv3dProblemSize; -+ -+ static int const kAccessesPerVector = 1; -+ -+ static_assert(sizeof_bits::value >= 8, -+ "WGRAD requires elements of size 8b or greater."); -+ -+ // -+ // Parameters structure -+ // -+ -+ struct Params { -+ -+ Layout layout; -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ Conv3dProblemSize const &problem_size, -+ Layout const &layout -+ ): layout(layout) { -+ -+ } -+ }; -+ -+private: -+ -+ Params const ¶ms_; -+ Conv3dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ char const *pointer_; -+ -+ // Filter postion (t,r,s,c) in contiguous dimension stays constant for each gemm_iteration_k -+ int filter_t_[ThreadMap::Iterations::kContiguous]; -+ int filter_r_[ThreadMap::Iterations::kContiguous]; -+ int filter_s_[ThreadMap::Iterations::kContiguous]; -+ int filter_c_[ThreadMap::Iterations::kContiguous]; -+ -+ int offset_nzpq_[ThreadMap::Iterations::kStrided]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dWgradActivationTileAccessIteratorAnalytic( -+ Params const ¶ms, -+ Conv3dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ // initialize t,r,s,c filter position for every contiguous iteration -+ CUTLASS_PRAGMA_UNROLL -+ for(int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ int trsc_offset = threadblock_offset.column() + thread_coord.contiguous() -+ + c * ThreadMap::Delta::kContiguous; -+ -+ filter_t_[c] = trsc_offset / (problem_size_.R * problem_size_.S * problem_size_.C); -+ int residual = trsc_offset % (problem_size_.R * problem_size_.S * problem_size_.C); -+ -+ filter_r_[c] = residual / (problem_size_.S * problem_size_.C); -+ residual = residual % (problem_size_.S * problem_size_.C); -+ -+ filter_s_[c] = residual / problem_size_.C; -+ filter_c_[c] = residual % problem_size_.C; -+ -+ } -+ -+ // initialize n, z, p, q offset for every strided iteration -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ -+ offset_nzpq_[s] = threadblock_offset.row() + thread_coord.strided() -+ + s * ThreadMap::Delta::kStrided; -+ } -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ -+ // moves to the next GEMM-K offset (offset_nzpq_) in GEMM-B by a CTA-K tile -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ offset_nzpq_[s] += Shape::kRow * problem_size_.split_k_slices; -+ } -+ } -+ -+ /// Returns the coordinate in the activation tensor x that is currently pointed to -+ /// by the iterator. -+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ -+ int t = filter_t_[iteration_contiguous_]; -+ int r = filter_r_[iteration_contiguous_]; -+ int s = filter_s_[iteration_contiguous_]; -+ -+ if (problem_size_.mode == Mode::kConvolution) { -+ t = (problem_size_.T - 1 - t); -+ r = (problem_size_.R - 1 - r); -+ s = (problem_size_.S - 1 - s); -+ } -+ -+ int n = offset_nzpq_[iteration_strided_] / (problem_size_.Z * problem_size_.P * problem_size_.Q); -+ int residual = offset_nzpq_[iteration_strided_] % (problem_size_.Z * problem_size_.P * problem_size_.Q); -+ -+ int z = residual / (problem_size_.P * problem_size_.Q); -+ residual = residual % (problem_size_.P * problem_size_.Q); -+ -+ int p = residual / problem_size_.Q; -+ int q = residual % problem_size_.Q; -+ -+ int d = z * problem_size_.stride_d - problem_size_.pad_d + t * problem_size_.dilation_d; -+ int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h; -+ int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w; -+ -+ return TensorCoord(n, d, h, w, filter_c_[iteration_contiguous_]); -+ } -+ -+ /// Returns true if the current coordinate is within the activation tensor x -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ TensorCoord coord = at(); -+ -+ return coord.n() < problem_size_.N && -+ coord.d() >= 0 && coord.d() < problem_size_.D && -+ coord.h() >= 0 && coord.h() < problem_size_.H && -+ coord.w() >= 0 && coord.w() < problem_size_.W && -+ coord.c() < problem_size_.C; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ LongIndex offset = params_.layout(coord); -+ -+ return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv3dWgradActivationTileAccessIteratorAnalytic &operator++() { -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv3dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.K % (128/sizeof_bits::value)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h -new file mode 100644 -index 0000000..2d56341 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h -@@ -0,0 +1,319 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM B (activation tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNDHWC layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+#include "cutlass/conv/threadblock/conv3d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_ -+> -+class Conv3dWgradActivationTileAccessIteratorOptimized { -+public: -+ -+ // -+ // Types -+ // -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNDHWC; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AlignedArray; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 3; -+ using ConvProblemSize = typename conv::Conv3dProblemSize; -+ static int const kAccessesPerVector = 1; -+ static_assert(sizeof_bits::value >= 8, -+ "WGRAD requires elements of size 8b or greater."); -+ -+ // -+ // Parameters structure -+ // -+ -+ struct Params : Conv3dWgradActivationIteratorOptimizedParams { -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(Conv3dWgradActivationIteratorOptimizedParams const &base) -+ : Conv3dWgradActivationIteratorOptimizedParams(base) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(Conv3dProblemSize const &problem_size, Layout const &layout) -+ : Conv3dWgradActivationIteratorOptimizedParams( -+ problem_size, -+ layout, -+ sizeof_bits::value, -+ {Shape::kRow, Shape::kColumn}, -+ ThreadMap::kThreads, -+ ThreadMap::kElementsPerAccess, -+ {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, -+ {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}) {} -+ }; -+ -+private: -+ -+ Params const ¶ms_; -+ Conv3dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ char const *pointer_; -+ -+ // Precomputed effective filter postion (t,r,s) in contiguous dimension stays constant for each gemm_iteration_k -+ // required for nzpq -> ndhw translation -+ int precomputed_filter_t_[ThreadMap::Iterations::kContiguous]; -+ int precomputed_filter_r_[ThreadMap::Iterations::kContiguous]; -+ int precomputed_filter_s_[ThreadMap::Iterations::kContiguous]; -+ -+ // Channel dimension in contiguous dimension stays constant for each gemm_iteration_k -+ int filter_c_[ThreadMap::Iterations::kContiguous]; -+ -+ int offset_nzpq_[ThreadMap::Iterations::kStrided]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dWgradActivationTileAccessIteratorOptimized( -+ Params const ¶ms, -+ Conv3dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ // initialize t,r,s,c filter position for every contiguous iteration -+ CUTLASS_PRAGMA_UNROLL -+ for(int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ int trsc_offset = threadblock_offset.column() + thread_coord.contiguous() -+ + c * ThreadMap::Delta::kContiguous; -+ -+ // The subseqnet fast_divmod() operations are equivalent to the following logical computation: -+ // -+ // -+ // filter_t_[c] = trsc_offset / (problem_size_.R * problem_size_.S * problem_size_.C); -+ // int residual = trsc_offset % (problem_size_.R * problem_size_.S * problem_size_.C); -+ // -+ // filter_r_[c] = residual / (problem_size_.S * problem_size_.C); -+ // residual = residual % (problem_size_.S * problem_size_.C); -+ // -+ // filter_s_[c] = residual / problem_size_.C; -+ // filter_c_[c] = residual % problem_size_.C; -+ -+ int residual; -+ fast_divmod(precomputed_filter_t_[c], residual, trsc_offset, params_.RSC, params_.rsc_mul, params_.rsc_shr); -+ fast_divmod(precomputed_filter_r_[c], residual, residual, params_.SC, params_.sc_mul, params_.sc_shr); -+ fast_divmod(precomputed_filter_s_[c], filter_c_[c], residual, problem_size_.C, params_.c_mul, params_.c_shr); -+ -+ int t = precomputed_filter_t_[c]; -+ int r = precomputed_filter_r_[c]; -+ int s = precomputed_filter_s_[c]; -+ -+ if (problem_size_.mode == Mode::kConvolution) { -+ t = (problem_size_.T - 1 - t); -+ r = (problem_size_.R - 1 - r); -+ s = (problem_size_.S - 1 - s); -+ } -+ -+ // efective t,r,s for every contiguous dimension -+ precomputed_filter_t_[c] = - problem_size_.pad_d + t * problem_size_.dilation_d; -+ precomputed_filter_r_[c] = - problem_size_.pad_h + r * problem_size_.dilation_h; -+ precomputed_filter_s_[c] = - problem_size_.pad_w + s * problem_size_.dilation_w; -+ -+ -+ } -+ -+ // initialize n, z, p, q offset for every strided iteration -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ -+ offset_nzpq_[s] = threadblock_offset.row() + thread_coord.strided() -+ + s * ThreadMap::Delta::kStrided; -+ } -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ -+ // moves to the next GEMM-K offset (offset_nzpq_) in GEMM-B by a CTA-K tile -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ offset_nzpq_[s] += Shape::kRow * problem_size_.split_k_slices; -+ } -+ } -+ -+ /// Returns the coordinate in the activation tensor x that is currently pointed to -+ /// by the iterator. -+ -+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ -+ // The subseqnet fast_divmod() operations are equivalent to the following logical computation: -+ // -+ // -+ // int n = offset_nzpq_[iteration_strided_] / (problem_size_.Z * problem_size_.P * problem_size_.Q); -+ // int residual = offset_nzpq_[iteration_strided_] % (problem_size_.Z * problem_size_.P * problem_size_.Q); -+ // -+ // int z = residual / (problem_size_.P * problem_size_.Q); -+ // residual = residual % (problem_size_.P * problem_size_.Q); -+ // -+ // int p = residual / problem_size_.Q; -+ // int q = residual % problem_size_.Q; -+ -+ int residual, n, z, p, q; -+ fast_divmod(n, residual, offset_nzpq_[iteration_strided_], params_.ZPQ, params_.zpq_mul, params_.zpq_shr); -+ fast_divmod(z, residual, residual, params_.PQ, params_.pq_mul, params_.pq_shr); -+ fast_divmod(p, q, residual, problem_size_.Q, params_.q_mul, params_.q_shr); -+ -+ int d = z * problem_size_.stride_d + precomputed_filter_t_[iteration_contiguous_]; -+ int h = p * problem_size_.stride_h + precomputed_filter_r_[iteration_contiguous_];; -+ int w = q * problem_size_.stride_w + precomputed_filter_s_[iteration_contiguous_]; -+ -+ return TensorCoord(n, d, h, w, filter_c_[iteration_contiguous_]); -+ } -+ -+ /// Returns true if the current coordinate is within the activation tensor x -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ TensorCoord coord = at(); -+ -+ return coord.n() < problem_size_.N && -+ coord.d() >= 0 && coord.d() < problem_size_.D && -+ coord.h() >= 0 && coord.h() < problem_size_.H && -+ coord.w() >= 0 && coord.w() < problem_size_.W && -+ coord.c() < problem_size_.C; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ LongIndex offset = params_.layout(coord); -+ -+ return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv3dWgradActivationTileAccessIteratorOptimized &operator++() { -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv3dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.K % (128/sizeof_bits::value)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h -new file mode 100644 -index 0000000..c21d3f9 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h -@@ -0,0 +1,267 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNDHWC layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_ -+> -+class Conv3dWgradOutputGradientTileAccessIteratorAnalytic { -+public: -+ -+ // -+ // Types -+ // -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNDHWC; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AlignedArray; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 3; -+ using ConvProblemSize = typename conv::Conv3dProblemSize; -+ static int const kAccessesPerVector = 1; -+ static_assert(sizeof_bits::value >= 8, -+ "WGRAD requires elements of size 8b or greater."); -+ -+ // -+ // Parameters structure -+ // -+ -+ struct Params { -+ -+ Layout layout; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ Conv3dProblemSize const &problem_size, -+ Layout const &layout -+ ): layout(layout) { -+ -+ } -+ }; -+ -+private: -+ -+ Params const ¶ms_; -+ Conv3dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ char const *pointer_; -+ -+ int filter_k_[ThreadMap::Iterations::kContiguous]; -+ -+ int offset_nzpq_[ThreadMap::Iterations::kStrided]; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dWgradOutputGradientTileAccessIteratorAnalytic( -+ Params const ¶ms, -+ Conv3dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)) { -+ -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ // initialize filter_k for every contiguous iteration -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ filter_k_[c] = threadblock_offset.row() + thread_coord.contiguous() -+ + c * ThreadMap::Delta::kContiguous; -+ } -+ -+ // initialize n, p, q offset for every strided iteration -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ offset_nzpq_[s] = threadblock_offset.column() + thread_coord.strided() -+ + s * ThreadMap::Delta::kStrided; -+ -+ } -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(Conv3dProblemSize const &problem_size, Layout const &layout) { -+ return Params(problem_size, layout); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ // moves to the next GEMM-K offset (offset_nzpq_) in GEMM-A by a CTA-K tile -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ offset_nzpq_[s] += Shape::kColumn * problem_size_.split_k_slices; -+ } -+ } -+ -+ /// Returns the coordinate in the output gradient tensor Dy that is currently pointed to -+ /// by the iterator. -+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ -+ int nzpq = offset_nzpq_[iteration_strided_]; -+ -+ int n = nzpq / (problem_size_.Z * problem_size_.P * problem_size_.Q); -+ int residual = nzpq % (problem_size_.Z * problem_size_.P * problem_size_.Q); -+ -+ int z = residual / (problem_size_.P * problem_size_.Q); -+ residual = residual % (problem_size_.P * problem_size_.Q); -+ -+ int p = residual / problem_size_.Q; -+ int q = residual % problem_size_.Q; -+ -+ return TensorCoord(n, z, p, q, filter_k_[iteration_contiguous_]); -+ } -+ -+ -+ /// Returns true if the current coordinate is within the output gradient tensor Dy -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ TensorCoord coord = at(); -+ -+ return coord.n() < problem_size_.N && -+ coord.d() < problem_size_.Z && -+ coord.h() < problem_size_.P && -+ coord.w() < problem_size_.Q && -+ coord.c() < problem_size_.K; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ TensorCoord coord = at(); -+ LongIndex offset = params_.layout(coord); -+ -+ return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv3dWgradOutputGradientTileAccessIteratorAnalytic &operator++() { -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv3dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % (128/sizeof_bits::value)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h -new file mode 100644 -index 0000000..7a79983 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h -@@ -0,0 +1,310 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNDHWC layout of tensors in Global Memory. -+ -+ The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), -+ backward data gradient (Dgrad), and backward weight gradient (Wgrad). -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+#include "cutlass/conv/threadblock/conv3d_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, -+ typename Element_, -+ typename ThreadMap_ -+> -+class Conv3dWgradOutputGradientTileAccessIteratorOptimized { -+public: -+ -+ // -+ // Types -+ // -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorNDHWC; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AlignedArray; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 3; -+ using ConvProblemSize = typename conv::Conv3dProblemSize; -+ static int const kAccessesPerVector = 1; -+ static_assert(sizeof_bits::value >= 8, -+ "WGRAD requires elements of size 8b or greater."); -+ -+ // -+ // Parameters structure -+ // -+ -+ struct Params : Conv3dWgradOutputGradientIteratorOptimizedParams { -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(Conv3dWgradOutputGradientIteratorOptimizedParams const &base) -+ : Conv3dWgradOutputGradientIteratorOptimizedParams(base) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(Conv3dProblemSize const &problem_size, Layout const &layout) -+ : Conv3dWgradOutputGradientIteratorOptimizedParams( -+ problem_size, -+ layout, -+ sizeof_bits::value, -+ {Shape::kRow, Shape::kColumn}, -+ ThreadMap::kThreads, -+ ThreadMap::kElementsPerAccess, -+ {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, -+ {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}) {} -+ }; -+ -+private: -+ -+ Params const ¶ms_; -+ Conv3dProblemSize const &problem_size_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ char const *pointer_; -+ -+ uint32_t predicates_; -+ int filter_k_; -+ int offset_nzpq_; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ Conv3dWgradOutputGradientTileAccessIteratorOptimized( -+ Params const ¶ms, -+ Conv3dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ predicates_(0), -+ filter_k_(0), -+ offset_nzpq_(0) { -+ -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ filter_k_ = threadblock_offset.row() + thread_coord.contiguous(); -+ offset_nzpq_ = threadblock_offset.column() + thread_coord.strided(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ int filter_k = filter_k_ + c * ThreadMap::Delta::kContiguous; -+ int offset_nzpq = offset_nzpq_ + s * ThreadMap::Delta::kStrided; -+ -+ bool predicate = valid_(at_(offset_nzpq, filter_k)); -+ -+ uint32_t pred = (predicate ? 1u : 0); -+ -+ int pred_idx = c + s * ThreadMap::Iterations::kContiguous; -+ -+ predicates_ |= (pred << pred_idx); -+ } -+ } -+ -+ // Offset pointer to (iteration_strided_, iteration_contiguous_) = (0, 0) -+ pointer_ += ( -+ offset_nzpq_ * params.layout.stride()[0] + filter_k_ -+ ) * sizeof_bits::value / 8; -+ -+ set_iteration_index(0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(Conv3dProblemSize const &problem_size, Layout const &layout) { -+ return Params(problem_size, layout); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ // moves to the next GEMM-K offset (offset_npq_) in GEMM-A by a CTA-K tile -+ offset_nzpq_ += Shape::kColumn * problem_size_.split_k_slices; -+ -+ // Clear predicates if needed -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ if (offset_nzpq_ + s * ThreadMap::Delta::kStrided >= params_.NZPQ) { -+ uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous); -+ predicates_ = (predicates_ & (~kClearMask)); -+ } -+ } -+ pointer_ += params_.inc_next_nzpq; -+ } -+ -+private: -+ /// Returns the coordinate in the output gradient tensor Dy that is (offset_nzpq, k) pointed to -+ /// by the iterator. -+ CUTLASS_HOST_DEVICE -+ TensorCoord at_(int offset_nzpq, int k) const { -+ -+ // The subseqnet fast_divmod() operations are equivalent to the following logical computation: -+ // -+ // -+ // int nzpq = offset_nzpq_; -+ // int n = nzpq / (problem_size_.Z * problem_size_.P * problem_size_.Q); -+ // int residual = nzpq % (problem_size_.Z * problem_size_.P * problem_size_.Q); -+ // -+ // int z = residual / (problem_size_.P * problem_size_.Q); -+ // residual = residual % (problem_size_.P * problem_size_.Q); -+ // -+ // int p = residual / problem_size_.Q; -+ // int q = residual % problem_size_.Q; -+ -+ int residual, n, z, p, q; -+ fast_divmod(n, residual, offset_nzpq, params_.ZPQ, params_.zpq_mul, params_.zpq_shr); -+ fast_divmod(z, residual, residual, params_.PQ, params_.pq_mul, params_.pq_shr); -+ fast_divmod(p, q, residual, problem_size_.Q, params_.q_mul, params_.q_shr); -+ -+ return TensorCoord(n, z, p, q, k); -+ } -+ -+ /// Returns true if the coord is within the output gradient tensor Dy -+ CUTLASS_HOST_DEVICE -+ bool valid_(TensorCoord coord) const { -+ -+ return coord.n() < problem_size_.N && -+ coord.c() < problem_size_.K; -+ } -+ -+public: -+ -+ /// Returns true if the current coordinate is within the output gradient tensor Dy -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ -+ LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous; -+ return (predicates_ & (1u << pred_idx)); -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ -+ return reinterpret_cast( -+ pointer_ + -+ iteration_strided_ * params_.offset_next_strided + -+ iteration_contiguous_ * params_.offset_next_contiguous -+ ); -+ -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ Conv3dWgradOutputGradientTileAccessIteratorOptimized &operator++() { -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv3dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % (128/sizeof_bits::value)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_direct_conv_params.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_direct_conv_params.h -new file mode 100644 -index 0000000..86b5bc4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_direct_conv_params.h -@@ -0,0 +1,230 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief Extracts the host-params objects into non-template code. -+*/ -+ -+#pragma once -+ -+#define TRACE_CONV_PARAMS_INITIALIZERS_ENABLED 0 -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+ -+#if TRACE_CONV_PARAMS_INITIALIZERS_ENABLED -+#include -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Parameters structure used for DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized -+template -+struct Depthwise2dFpropDirectConvParams; -+ -+/// Parameters structure used for DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation -+template -+struct Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams; -+ -+/// Parameters structure used for DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized -+template -+struct Depthwise2dFpropDirectConvFilterIteratorParams; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Parameters structure used for DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized -+template<> -+struct Depthwise2dFpropDirectConvParams { -+ -+ using Layout = layout::TensorNHWC; -+ -+ Layout layout; -+ -+ int32_t activation_tile_h; -+ int32_t activation_tile_w; -+ int32_t activation_tile_hw; -+ FastDivmod activation_tile_w_divmod; -+ -+ int filter[2]; -+ int stride[2]; -+ int dilation[2]; -+ int inc_next[2]; -+ FastDivmod pq_divmod; -+ FastDivmod q_divmod; -+ -+ int activation_load_count; -+ int activation_storage_elements; -+ int activation_size; -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Depthwise2dFpropDirectConvParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ Depthwise2dFpropDirectConvParams( -+ Conv2dProblemSize const &problem_size, -+ Layout const &layout, ///< layout object -+ MatrixCoord threadblock_shape, ///< CTA threadblock Shape -+ Layout::TensorCoord threadblock_output_shape, ///< Output tile Shape per threadblock -+ const int element_size_bits, ///< bits of activation element -+ const int thread_count, ///< threads per threadblock -+ const int thread_count_contiguous, ///< number of threads for continuous dimension -+ const int element_per_load) ///< element per each load -+ : layout(layout) { -+ -+ filter[0] = problem_size.S; -+ filter[1] = problem_size.R; -+ -+ stride[0] = problem_size.stride_w; -+ stride[1] = problem_size.stride_h; -+ -+ dilation[0] = problem_size.dilation_w; -+ dilation[1] = problem_size.dilation_h; -+ -+ // Compute activation_tile size per threadblock because stride and dilation are runtime params. -+ activation_tile_h = (threadblock_output_shape.h() - 1) * problem_size.stride_h + -+ (problem_size.R - 1) * problem_size.dilation_h + 1; -+ activation_tile_w = (threadblock_output_shape.w() - 1) * problem_size.stride_w + -+ (problem_size.S - 1) * problem_size.dilation_w + 1; -+ activation_tile_hw = activation_tile_h * activation_tile_w; -+ -+ activation_tile_w_divmod = FastDivmod(activation_tile_w); -+ -+ /// Below two values could not be templatized because the stride and dilation are runtime params -+ activation_load_count = (thread_count_contiguous * activation_tile_hw + (thread_count - 1)) / thread_count; -+ activation_storage_elements = activation_load_count * element_per_load * thread_count; -+ activation_size = activation_storage_elements * element_size_bits / 8; -+ -+ // Fastdivmod for output P, Q -+ int tiles_p = -+ (problem_size.P + (threadblock_output_shape.h() - 1)) / (threadblock_output_shape.h()); -+ int tiles_q = (problem_size.Q + (threadblock_output_shape.w() - 1)) / -+ (threadblock_output_shape.w()); -+ -+ pq_divmod = FastDivmod(tiles_p * tiles_q); -+ q_divmod = FastDivmod(tiles_q); -+ -+ // next S -+ inc_next[0] = problem_size.dilation_w; -+ // next R -+ inc_next[1] = (activation_tile_w * problem_size.dilation_h - (problem_size.S - 1) * problem_size.dilation_w); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Parameters structure used for DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation -+template <> -+struct Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams { -+ using Layout = layout::TensorNHWC; -+ -+ Layout layout; -+ -+ FastDivmod pq_divmod; -+ FastDivmod q_divmod; -+ -+ int activation_size; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams() {} -+ -+ CUTLASS_HOST_DEVICE -+ Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams( -+ Conv2dProblemSize const &problem_size, -+ Layout const &layout, ///< Layout object -+ MatrixCoord threadblock_shape, ///< Threadblock Shape -+ Layout::TensorCoord threadblock_output_shape, ///< Output tile Shape per threadblock -+ const int activation_size_ ///< Activation size loaded by iterator -+ ) -+ : layout(layout), -+ activation_size(activation_size_) { -+ // Fastdivmod for output P, Q -+ int tiles_p = -+ (problem_size.P + (threadblock_output_shape.h() - 1)) / (threadblock_output_shape.h()); -+ int tiles_q = -+ (problem_size.Q + (threadblock_output_shape.w() - 1)) / (threadblock_output_shape.w()); -+ -+ pq_divmod = FastDivmod(tiles_p * tiles_q); -+ q_divmod = FastDivmod(tiles_q); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Parameters structure used for DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized -+template <> -+struct Depthwise2dFpropDirectConvFilterIteratorParams { -+ using Layout = layout::TensorNHWC; -+ -+ Layout layout; -+ -+ int filter_size; -+ -+ bool is_convolution; -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Depthwise2dFpropDirectConvFilterIteratorParams() {} -+ -+ CUTLASS_HOST_DEVICE -+ Depthwise2dFpropDirectConvFilterIteratorParams( -+ Conv2dProblemSize const &problem_size, -+ Layout const &layout, ///< Layout object -+ MatrixCoord threadblock_shape, ///< Threadblock Shape -+ const int filter_size_) ///< Filter size loaded by iterator -+ : layout(layout), -+ filter_size(filter_size_), -+ is_convolution(problem_size.mode == Mode::kConvolution){} -+}; -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_fixed_stride_dilation.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_fixed_stride_dilation.h -new file mode 100644 -index 0000000..80ec5d0 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_fixed_stride_dilation.h -@@ -0,0 +1,314 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC layout of tensors in Global Memory. -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/threadblock/depthwise_direct_conv_params.h" -+#include "cutlass/coord.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template > -+class DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation { -+ public: -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using OutputTileShape = OutputTileShape_; -+ using Element = Element_; -+ using Layout = Layout_; -+ using TensorCoord = typename Layout::TensorCoord; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ // Compilation value of stride , dialtion and activation shape -+ using StrideShape = StrideShape_; -+ using DilationShape = DilationShape_; -+ using ActivationShape = ActivationShape_; -+ -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ static int const kActivationSize = ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess * ThreadMap::kThreads * -+ sizeof_bits::value / 8; -+ -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ // -+ // Simplifying assertions -+ // -+ static_assert(ThreadMap::Iterations::kContiguous == 1, "Require Iterations::kContiguous == 1"); -+ -+ static_assert(OutputTileShape::kN == 1, "Require OutputTileShape::kN == 1"); -+ static_assert(OutputTileShape::kC == Shape::kColumn, "Require OutputTile shape == channels per threadblock"); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams; -+ -+ private: -+ Conv2dProblemSize const &problem_size_; -+ Params const ¶ms_; -+ char const *pointer_; -+ -+ // Base channels for current threadblock -+ int base_c_; -+ // Base activation index for current threadblock -+ int offset_intial_npq_; -+ // Base activation coord for current threadblock -+ TensorCoord activatioin_base_; -+ // Intial thread positioin -+ int offset_initial_hwc_; -+ // Overall load instruction per thread. -+ int iterator_load_; -+ // thread loading position. -+ int iterator_hwc_; -+ // activation N is inside the Tensor or not -+ bool valid_n_; -+ -+ public: -+ -+ -+ CUTLASS_HOST_DEVICE -+ DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation( -+ Params const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = -+ MatrixCoord() -+ ) -+ : params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ offset_intial_npq_(threadblock_offset.row()), -+ offset_initial_hwc_(thread_idx), -+ iterator_load_(0) { -+ -+ base_c_ = threadblock_offset.column(); -+ -+ set_iteration_index(0); -+ -+ set_activation_coord(offset_intial_npq_); -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_activation_coord(int offset_npq) { -+ int offset_inital_n, offset_inital_p, offset_inital_q; -+ int residual; -+ -+ params_.pq_divmod(offset_inital_n, residual, offset_npq); -+ params_.q_divmod(offset_inital_p, offset_inital_q, residual); -+ -+ int base_n = offset_inital_n; -+ -+ int base_h = -+ offset_inital_p * OutputTileShape::kH * StrideShape::kRow - problem_size_.pad_h; -+ -+ int base_w = -+ offset_inital_q * OutputTileShape::kW * StrideShape::kColumn - problem_size_.pad_w; -+ -+ activatioin_base_ = TensorCoord(base_n, base_h, base_w, base_c_); -+ -+ valid_n_ = activatioin_base_.n() < problem_size_.N; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { -+ return Params( -+ problem_size, -+ layout, -+ {Shape::kRow, Shape::kColumn}, -+ {OutputTileShape::kN, OutputTileShape::kH, OutputTileShape::kW, OutputTileShape::kC}, -+ kActivationSize); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iterator_hwc_ = offset_initial_hwc_ + index * ThreadMap::kThreads; -+ iterator_load_ = index; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ // Go to next threadblock -+ offset_intial_npq_ += problem_size_.split_k_slices; -+ -+ set_iteration_index(0); -+ -+ set_activation_coord(offset_intial_npq_); -+ } -+ -+ /// Returns the coordinate in the activations tensor X that is currently pointed to -+ /// by the iterator. -+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ int c = iterator_hwc_ % ThreadMap::Detail::ShapeVec::kContiguous ; -+ int next = iterator_hwc_ / ThreadMap::Detail::ShapeVec::kContiguous ; -+ int h = next / ActivationShape::kW; -+ int w = next % ActivationShape::kW; -+ -+ c = c * AccessType::kElements; -+ -+ return activatioin_base_ + TensorCoord(0, h, w, c); -+ } -+ -+ /// Returns true if the current coordinate is within the activations tensor X -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ TensorCoord coord = at(); -+ bool valid_c = coord.c() < problem_size_.C; -+ bool valid_h = coord.h() >= 0 && coord.h() < problem_size_.H; -+ bool valid_w = coord.w() >= 0 && coord.w() < problem_size_.W; -+ return valid_n_ ? valid_c & valid_h & valid_w : 0; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ TensorCoord coord = at(); -+ LongIndex offset = params_.layout(coord); -+ -+ AccessType const *ptr = -+ reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ -+ return ptr; -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation &operator++() { -+ -+ ++iterator_load_; -+ iterator_hwc_ += ThreadMap::kThreads; -+ -+ if (iterator_load_ < ThreadMap::Iterations::kCount) { -+ return *this; -+ } -+ -+ iterator_load_ = 0; -+ iterator_hwc_ = offset_initial_hwc_; -+ -+ return *this; -+ } -+ -+ /// Determines the activation size loaded by iterator -+ CUTLASS_HOST_DEVICE -+ int get_load_size() { -+ return kActivationSize; -+ } -+ -+ /// Determines the iterations needed -+ CUTLASS_HOST_DEVICE -+ int get_iteration_num() { -+ return ThreadMap::Iterations::kCount; -+ } -+ -+ /// Determines whether the Depthwise fprop can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check stride and dilation constraint -+ if (problem_size.stride_h != StrideShape::kRow || problem_size.stride_w != StrideShape::kColumn) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (problem_size.dilation_h != DilationShape::kRow || problem_size.dilation_w != DilationShape::kColumn) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_optimized.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_optimized.h -new file mode 100644 -index 0000000..3439d46 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_optimized.h -@@ -0,0 +1,291 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC layout of tensors in Global Memory. -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/threadblock/depthwise_direct_conv_params.h" -+#include "cutlass/coord.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template > -+class DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized { -+ public: -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using OutputTileShape = OutputTileShape_; -+ using Element = Element_; -+ using Layout = Layout_; -+ using TensorCoord = typename Layout::TensorCoord; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ // -+ // Simplifying assertions -+ // -+ static_assert(ThreadMap::Iterations::kContiguous == 1, "Require Iterations::kContiguous == 1"); -+ -+ static_assert(OutputTileShape::kN == 1, "Require OutputTileShape::kN == 1"); -+ static_assert(OutputTileShape::kC == Shape::kColumn, "Require OutputTile shape == channels per threadblock"); -+ -+ // -+ // Parameters structure -+ // -+ -+ using Params = Depthwise2dFpropDirectConvParams; -+ -+ private: -+ Conv2dProblemSize const &problem_size_; -+ Params const ¶ms_; -+ char const *pointer_; -+ -+ // Base channels for current threadblock -+ int base_c_; -+ // Base activation index for current threadblock -+ int offset_intial_npq_; -+ // Base activation coord for current threadblock -+ TensorCoord activatioin_base_; -+ // Intial thread positioin -+ int offset_initial_hwc_; -+ // Overall load instruction per thread. -+ int iterator_load_; -+ // thread loading position. -+ int iterator_hwc_; -+ // Number of loads for activations tensor X. -+ const int number_of_loads_; -+ -+ public: -+ -+ -+ CUTLASS_HOST_DEVICE -+ DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized( -+ Params const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = -+ MatrixCoord() -+ ) -+ : params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ offset_intial_npq_(threadblock_offset.row()), -+ offset_initial_hwc_(thread_idx), -+ iterator_load_(0), -+ number_of_loads_(params.activation_load_count) { -+ -+ base_c_ = threadblock_offset.column(); -+ -+ set_activation_coord(offset_intial_npq_); -+ -+ set_iteration_index(0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_activation_coord(int offset_npq) { -+ int offset_inital_n, offset_inital_p, offset_inital_q; -+ int residual; -+ -+ params_.pq_divmod(offset_inital_n, residual, offset_npq); -+ params_.q_divmod(offset_inital_p, offset_inital_q, residual); -+ -+ int base_n = offset_inital_n; -+ -+ int base_h = -+ offset_inital_p * OutputTileShape::kH * problem_size_.stride_h - problem_size_.pad_h; -+ -+ int base_w = -+ offset_inital_q * OutputTileShape::kW * problem_size_.stride_w - problem_size_.pad_w; -+ -+ activatioin_base_ = TensorCoord(base_n, base_h, base_w, base_c_); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { -+ return Params( -+ problem_size, -+ layout, -+ {Shape::kRow, Shape::kColumn}, -+ {OutputTileShape::kN, OutputTileShape::kH, OutputTileShape::kW, OutputTileShape::kC}, -+ sizeof_bits::value, -+ ThreadMap::kThreads, -+ ThreadMap::Detail::ShapeVec::kContiguous, -+ ThreadMap::kElementsPerAccess); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iterator_hwc_ = offset_initial_hwc_ + index * ThreadMap::kThreads; -+ iterator_load_ = index; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ // Go to next threadblock -+ offset_intial_npq_ += problem_size_.split_k_slices; -+ -+ set_activation_coord(offset_intial_npq_); -+ } -+ -+ /// Returns the coordinate in the activations tensor X that is currently pointed to -+ /// by the iterator. -+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ -+ int c = iterator_hwc_ % ThreadMap::Detail::ShapeVec::kContiguous ; -+ int next = iterator_hwc_ / ThreadMap::Detail::ShapeVec::kContiguous ; -+ int h, w; -+ params_.activation_tile_w_divmod(h, w, next) ; -+ -+ c = c * AccessType::kElements; -+ -+ return activatioin_base_ + TensorCoord(0, h, w, c); -+ } -+ -+ /// Returns true if the current coordinate is within the activations tensor X -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ TensorCoord coord = at(); -+ -+ return coord.n() < problem_size_.N && coord.h() >= 0 && coord.h() < problem_size_.H && -+ coord.w() >= 0 && coord.w() < problem_size_.W && coord.c() < problem_size_.C; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ TensorCoord coord = at(); -+ LongIndex offset = params_.layout(coord); -+ -+ AccessType const *ptr = -+ reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); -+ -+ return ptr; -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized &operator++() { -+ -+ ++iterator_load_; -+ iterator_hwc_ += ThreadMap::kThreads; -+ -+ if (iterator_load_ < number_of_loads_) { -+ return *this; -+ } -+ -+ iterator_load_ = 0; -+ iterator_hwc_ = offset_initial_hwc_; -+ -+ return *this; -+ } -+ -+ /// Determines the activation size loaded by iterator -+ CUTLASS_HOST_DEVICE -+ int get_load_size() { -+ return params_.activation_size; -+ } -+ -+ /// Determines the iterations needed -+ CUTLASS_HOST_DEVICE -+ int get_iteration_num() { -+ return number_of_loads_; -+ } -+ -+ /// Determines whether the Depthwise fprop can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.C % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_direct_conv_multistage.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_direct_conv_multistage.h -new file mode 100644 -index 0000000..26bbe57 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_direct_conv_multistage.h -@@ -0,0 +1,551 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a multistage threadblock-scoped Implicit GEMM Convolution kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/cache_operation.h" -+#include "cutlass/conv/threadblock/depthwise_mma_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Cache operation for operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Epilogue stores the data into global memory -+ typename Epilogue_, -+ /// iterator implementation variants -+ conv::IteratorAlgorithm IteratorAlgorithm_ = conv::IteratorAlgorithm::kOptimized, -+ /// Used for partial specialization -+ typename Enable = bool> -+class DepthwiseFpropDirectConvMultipleStage : -+ public DepthwiseDirectConvMmaBase { -+public: -+ ///< Base class -+ using Base = DepthwiseDirectConvMmaBase; -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ ///< Iterates over tiles of A operand in global memory -+ using IteratorA = IteratorA_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB = IteratorB_; -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ using Epilogue = Epilogue_; -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ static conv::IteratorAlgorithm const kItertorAlgorithm = IteratorAlgorithm_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of accumulator tile -+ -+ using ElementC = typename Policy::Operator::ElementC; -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Internal structure exposed for introspection. -+ struct Detail { -+ -+ /// Number of cp.async instructions to load one stage of operand A -+ static int const AsyncCopyIterationsPerStageA = -+ IteratorA::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const AsyncCopyIterationsPerStageB = -+ IteratorB::ThreadMap::Iterations::kCount; -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB = -+ (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ }; -+ -+ private: -+ -+ using WarpLoadedFragmentA = typename Operator::FragmentA; -+ using WarpLoadedFragmentB = typename Operator::FragmentB; -+ using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; -+ using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ DepthwiseFpropDirectConvMultipleStage( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), -+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) -+ { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance(IteratorA &iterator_A, -+ IteratorB &iterator_B, -+ int group_start_A = 0, -+ int group_start_B = 0) { -+ if (kItertorAlgorithm == conv::IteratorAlgorithm::kFixedStrideDilation) { -+ // Number of iterators is a static value. -+ iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); -+ this->smem_iterator_A_.set_iteration_index(group_start_A); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast(this->smem_iterator_A_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_A.get(), iterator_A.valid()); -+ -+ ++iterator_A; -+ } -+ ++this->smem_iterator_A_; -+ } -+ } else { -+ // Number of iterators is a runtime value. -+ iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); -+ this->smem_iterator_A_.set_iteration_index(group_start_A); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < iterator_A.get_iteration_num(); ++j) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast(this->smem_iterator_A_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_A.get(), iterator_A.valid()); -+ -+ ++iterator_A; -+ } -+ ++this->smem_iterator_A_; -+ } -+ } -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations, -+ ///< destination accumulator tile -+ FragmentC &accum, -+ ///< iterator over A operand in global memory -+ IteratorA &iterator_A, -+ ///< Params of global memory iterator -+ typename IteratorA::Params const &iterator_a_params, -+ ///< iterator over B operand in global memory -+ IteratorB &iterator_B, -+ ///< Params of global memory iterator -+ typename IteratorB::Params const &iterator_b_params, -+ ///< initial value of accumulator -+ FragmentC const &src_accum, -+ /// Epilogue -+ Epilogue &epilogue, -+ ///< Output operator -+ typename Epilogue::OutputOp const &output_op, -+ ///< Tile iterator for destination -+ typename Epilogue::OutputTileIterator &destination_iterator, -+ ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) -+ typename Epilogue::OutputTileIterator &source_iterator, -+ -+ int split_k_slices = 1 -+ ) { -+ -+ // -+ // Prologue -+ // -+ -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { -+ -+ if (stage == 0) { -+ iterator_B.set_iteration_index(0); -+ this->smem_iterator_B_.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast(this->smem_iterator_B_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_B.get(), iterator_B.valid()); -+ -+ ++iterator_B; -+ } -+ -+ ++this->smem_iterator_B_; -+ } -+ } -+ -+ if(kItertorAlgorithm == conv::IteratorAlgorithm::kFixedStrideDilation){ -+ // Number of iterators is compilation static. -+ iterator_A.set_iteration_index(0); -+ this->smem_iterator_A_.set_iteration_index(0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast(this->smem_iterator_A_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_A.get(), iterator_A.valid()); -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ -+ } else { -+ // Number of iterators is a runtime value. -+ iterator_A.set_iteration_index(0); -+ this->smem_iterator_A_.set_iteration_num(iterator_A.get_iteration_num()); -+ this->smem_iterator_A_.set_iteration_index(0); -+ -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < iterator_A.get_iteration_num(); ++j) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast(this->smem_iterator_A_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_A.get(), iterator_A.valid()); -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ } -+ -+ // Move to the next stage -+ iterator_A.advance(); -+ -+ this->smem_iterator_A_.add_tile_offset({1, 0}); -+ -+ // Inserts a fence to group cp.async instructions into stages. -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ ///////////////////////////////////////////////////////////////////////////// -+ // Waits until kStages-2 stages have committed. -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA warp_loaded_frag_A[2]; -+ WarpLoadedFragmentB warp_loaded_frag_B[2]; -+ WarpTransformedFragmentA warp_transformed_frag_A[2]; -+ WarpTransformedFragmentB warp_transformed_frag_B[2]; -+ -+ Operator warp_mma; -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A_.setup_initial_status(iterator_a_params); -+ -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ int smem_write_stage_idx = Base::kStages - 1; -+ int smem_read_stage_idx = 0; -+ -+ warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], -+ warp_loaded_frag_A[0], warp_loaded_frag_B[0]); -+ -+ // -+ // Mainloop -+ // -+ -+ unsigned int iterations = 0; -+ constexpr int inner_loop_iterations = round_up(Base::kWarpGemmIterations, 2); -+ -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > (-Base::kStages + 1);) { // Each iteration is a cta tile. -+ -+ accum.clear(); -+ -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < inner_loop_iterations; ++warp_mma_k) { -+ if (Base::kWarpGemmIterations % 2 == 0 || warp_mma_k + 1 != Base::kWarpGemmIterations) { -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. -+ -+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Shape::kK); -+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Shape::kK); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ } -+ -+ if (warp_mma_k > 0) -+ warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ warp_loaded_frag_A[warp_mma_k % 2], -+ warp_loaded_frag_B[warp_mma_k % 2]); -+ -+ // Issue global->shared copies for the next stage -+ int group_start_iteration_A, group_start_iteration_B; -+ -+ if (warp_mma_k == 0) { -+ group_start_iteration_A = 0; -+ group_start_iteration_B = 0; -+ copy_tiles_and_advance( -+ iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); -+ } -+ -+ if (warp_mma_k < Base::kWarpGemmIterations) { -+ warp_mma( -+ accum, -+ warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ accum -+ ); -+ } -+ -+ if (warp_mma_k + 1 == inner_loop_iterations) -+ warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ -+ if (warp_mma_k + 2 == inner_loop_iterations) { -+ // Inserts a fence to group cp.async instructions into stages. -+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages of cp.async have committed -+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next cta -+ iterator_A.advance(); -+ -+ this->smem_iterator_A_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_A_.add_tile_offset({-Base::kStages, 0}); -+ -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_A_.advance(- (Base::kStages-1) * iterator_A.get_load_size()); -+ smem_read_stage_idx = 0; -+ } else { -+ this->warp_tile_iterator_A_.advance(iterator_A.get_load_size()); -+ ++smem_read_stage_idx; -+ } -+ -+ if (kItertorAlgorithm == conv::IteratorAlgorithm::kFixedStrideDilation) { -+ this->warp_tile_iterator_A_.setup_initial_status(iterator_a_params); -+ } -+ -+ // goback to start position. B has no multiple stage -+ this->warp_tile_iterator_B_.add_tile_offset({-Policy::kPartitionsK * Shape::kK, 0}); -+ -+ --gemm_k_iterations; -+ } -+ } -+ -+ // -+ // Epilogue -+ // -+ int32_t smem_base_offset = iterator_B.get_load_size() + (iterations % Base::kStages) * iterator_A.get_load_size(); -+ -+ destination_iterator.set_tile_index(iterations * split_k_slices); -+ -+ source_iterator.set_tile_index(iterations * split_k_slices); -+ -+ epilogue(output_op, destination_iterator, accum, source_iterator, smem_base_offset); -+ -+ ++iterations; -+ } -+ -+ // Insert fence and wait for all outstanding cp.async operations to commit. -+ cutlass::arch::cp_async_fence(); -+ cutlass::arch::cp_async_wait<0>(); -+ __syncthreads(); -+ -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_filter_tile_access_iterator_direct_conv_optimized.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_filter_tile_access_iterator_direct_conv_optimized.h -new file mode 100644 -index 0000000..e9153c9 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_filter_tile_access_iterator_direct_conv_optimized.h -@@ -0,0 +1,261 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) -+ matrix from memory. -+ -+ This iterator assumes TensorNHWC layout of tensors in Global Memory. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/threadblock/conv2d_params.h" -+#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+template > -+class DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized { -+public: -+ // -+ // Types -+ // -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = Layout_; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ using TensorRef = cutlass::TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; -+ static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; -+ static int const kConvDim = 2; -+ using ConvProblemSize = typename conv::Conv2dProblemSize; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static int const kFilterSize = ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess * ThreadMap::kThreads * -+ sizeof_bits::value / 8; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ // -+ // Simplifying assertions -+ // -+ static_assert(ThreadMap::Iterations::kContiguous == 1, -+ "Require Iterations::kContiguous == 1"); -+ -+ // -+ // Parameters structure -+ // -+ using Params = Depthwise2dFpropDirectConvFilterIteratorParams; -+ -+ protected: -+ -+ Conv2dProblemSize const &problem_size_; -+ Params const ¶ms_; -+ LongIndex iteration_contiguous_; -+ LongIndex iteration_strided_; -+ LongIndex iteration_vector_; -+ char const *pointer_; -+ -+ int filter_k_; -+ int offset_trs_[ThreadMap::Iterations::kStrided]; -+ -+public: -+ -+ -+ -+ CUTLASS_HOST_DEVICE -+ DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized( -+ Params const ¶ms, -+ Conv2dProblemSize const &problem_size, -+ Element const *ptr, -+ int thread_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ params_(params), -+ problem_size_(problem_size), -+ pointer_(reinterpret_cast(ptr)), -+ filter_k_(0) { -+ -+ layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); -+ -+ filter_k_ = threadblock_offset.column() + thread_coord.contiguous(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ offset_trs_[s] = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; -+ } -+ -+ set_iteration_index(0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { -+ return Params(problem_size, layout, {Shape::kRow, Shape::kColumn}, kFilterSize); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(Index index) { -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset * 8 / sizeof_bits::value; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ // Do nothing because the filter is persistent in the SMEM -+ } -+ -+ /// Returns the coordinate in the filter tensor W that is currently pointed to -+ /// by the iterator. -+ CUTLASS_HOST_DEVICE -+ TensorCoord at() const { -+ -+ int k = filter_k_ + iteration_vector_ * AccessType::kElements; -+ int trs = offset_trs_[iteration_strided_]; -+ -+ return TensorCoord(k, trs, 0 , 0); // As a 2D-matrix -+ } -+ -+ /// Returns true if the current coordinate is within the activations tensor W -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ -+ TensorCoord coord = at(); -+ -+ return coord.n() < problem_size_.K && -+ coord.h() < Shape::kColumn; -+ } -+ -+ /// Returns a pointer to the vector starting at the current coordinate -+ CUTLASS_HOST_DEVICE -+ AccessType const *get() const { -+ TensorCoord coord = at(); -+ int64_t offset = coord.n(); -+ if (params_.is_convolution) { -+ offset += (Shape::kColumn - coord.h() - 1)* problem_size_.K; -+ } else { -+ offset += coord.h() * problem_size_.K; -+ } -+ -+ return reinterpret_cast(pointer_ + -+ offset * sizeof_bits::value / 8); -+ } -+ -+ /// Increments to the next memory access -+ CUTLASS_HOST_DEVICE -+ DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized &operator++() { -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ iteration_vector_ = 0; -+ -+ ++iteration_contiguous_; -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ iteration_contiguous_ = 0; -+ -+ ++iteration_strided_; -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Determines the filter size loaded by iterator -+ CUTLASS_HOST_DEVICE -+ int get_load_size() { -+ return kFilterSize; -+ } -+ -+ /// Determines whether the Implicit GEMM can execute the given problem. -+ CUTLASS_HOST_DEVICE -+ static Status can_implement(Conv2dProblemSize const &problem_size) { -+ -+ // check alignment constraint on iterator's contiguous dimension -+ if (problem_size.K % AccessType::kElements) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ // check whether runtime filter size is same as templated filter size. -+ if ((problem_size.R * problem_size.S) != Shape::kColumn) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_pipelined.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_pipelined.h -new file mode 100644 -index 0000000..fd43e40 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_pipelined.h -@@ -0,0 +1,336 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/numeric_conversion.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/threadblock/mma_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Transformation applied to A operand -+ typename TransformA_ = NumericArrayConverter< -+ typename SmemIteratorA_::Element, -+ typename IteratorA_::Element, -+ IteratorA_::Fragment::kElements>, -+ /// -+ /// Transformation applied to A operand -+ typename TransformB_ = NumericArrayConverter< -+ typename SmemIteratorB_::Element, -+ typename IteratorB_::Element, -+ IteratorB_::Fragment::kElements>, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+class DepthwiseFpropPipelined : public gemm::threadblock::MmaBase { -+public: -+ -+ ///< Base class -+ using Base = gemm::threadblock::MmaBase; -+ -+ using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory -+ using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory -+ using ElementC = ElementC_; ///< Data type of accumulator matrix -+ using LayoutC = LayoutC_; ///< Layout of accumulator matrix -+ using Policy = Policy_; ///< Policy describing tuning details -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ -+ using TransformA = TransformA_; -+ using TransformB = TransformB_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of operand A loaded from global memory -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Fragment of operand B loaded from global memory -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Fragment of accumulator tile -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Obtain the arch tag from the warp-level operator -+ using ArchTag = typename Policy::Operator::ArchTag; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = Operator::kTransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = Operator::kTransformB; -+ -+ // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) -+ static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); -+ -+private: -+ -+ using WarpFragmentA = typename Operator::FragmentA; -+ using WarpFragmentB = typename Operator::FragmentB; -+ -+protected: -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ DepthwiseFpropPipelined( -+ typename Base::SharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ int thread_idx, ///< ID within the threadblock -+ int warp_idx, ///< ID of warp -+ int lane_idx ///< ID of each thread within a warp -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), -+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { -+ -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ int gemm_k_iterations, ///< number of iterations of the mainloop -+ FragmentC &accum, ///< destination accumulator tile -+ IteratorA iterator_A, ///< iterator over A operand in global memory -+ IteratorB iterator_B, ///< iterator over B operand in global memory -+ FragmentC const &src_accum, ///< source accumulator tile -+ int gemm_k_iterations_per_channel = 0, ///< number of iterations per channel -+ TransformA transform_A = TransformA(), ///< transformation applied to A fragment -+ TransformB transform_B = TransformB()) { ///< transformation applied to B fragment -+ -+ // -+ // Prologue -+ // -+ -+ // Perform accumulation in the 'd' output operand -+ accum = src_accum; -+ -+ FragmentA tb_frag_A; -+ FragmentB tb_frag_B; -+ -+ tb_frag_A.clear(); -+ tb_frag_B.clear(); -+ -+ // The last kblock is loaded in the prolog -+ iterator_A.load(tb_frag_A); -+ iterator_B.load(tb_frag_B); -+ -+ ++iterator_A; -+ ++iterator_B; -+ -+ this->smem_iterator_A_.store(transform_A(tb_frag_A)); -+ this->smem_iterator_B_.store(transform_B(tb_frag_B)); -+ -+ ++this->smem_iterator_A_; -+ ++this->smem_iterator_B_; -+ -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math instructions -+ WarpFragmentA warp_frag_A[2]; -+ WarpFragmentB warp_frag_B[2]; -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A_.load(warp_frag_A[0]); -+ this->warp_tile_iterator_B_.load(warp_frag_B[0]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ Operator warp_mma; -+ -+ int smem_write_stage_idx = 1; -+ // Depthwise specific -+ int channel_start_index = 0; -+ int rs_plane_idx = 0; -+ -+ // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing -+ // shared memory loads (which have the tighest latency requirement). -+ -+ // -+ // Mainloop -+ // -+ -+ // Note: The main loop does not support Base::kWarpGemmIterations == 2. -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > 0; --gemm_k_iterations) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ if(rs_plane_idx == gemm_k_iterations_per_channel - 1){ -+ // Reset interation index. -+ iterator_B.set_iteration_index(0); -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group -+ // as the case may be. -+ -+ if (warp_mma_k == Base::kWarpGemmIterations - 1) { -+ -+ // Write fragments to shared memory -+ this->smem_iterator_A_.store(transform_A(tb_frag_A)); -+ -+ this->smem_iterator_B_.store(transform_B(tb_frag_B)); -+ -+ __syncthreads(); -+ -+ if(rs_plane_idx == gemm_k_iterations_per_channel - 1){ -+ // Move to next set of filter groups. -+ channel_start_index += Base::kWarpGemmIterations; -+ } -+ -+ ++this->smem_iterator_A_; -+ ++this->smem_iterator_B_; -+ -+ // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory -+ if (smem_write_stage_idx == 1) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ } -+ else { -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, -+ 0}); -+ } -+ -+ smem_write_stage_idx ^= 1; -+ } -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(channel_start_index + (warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.set_kgroup_index(channel_start_index + (warp_mma_k + 1) % Base::kWarpGemmIterations); -+ -+ this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ if (warp_mma_k == 0) { -+ -+ iterator_A.load(tb_frag_A); -+ iterator_B.load(tb_frag_B); -+ -+ ++iterator_A; -+ ++iterator_B; -+ } -+ -+ warp_mma(accum, warp_frag_A[warp_mma_k % 2], -+ warp_frag_B[warp_mma_k % 2], accum); -+ } -+ -+ rs_plane_idx = (rs_plane_idx == gemm_k_iterations_per_channel - 1) ? 0: (rs_plane_idx + 1); -+ -+ } -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_mma_base.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_mma_base.h -new file mode 100644 -index 0000000..e839b9a ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_mma_base.h -@@ -0,0 +1,229 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a directconv threadblock-scoped Depthwise kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Policy object describing MmaTensorOp -+template < -+ /// Warp-level GEMM operator (concept: gemm::warp::Mma) -+ typename Operator_, -+ /// Padding used for A operand in shared memory (concept: MatrixShape) -+ typename SmemPaddingA_, -+ /// Padding used for B operand in shared memory (concept: MatrixShape) -+ typename SmemPaddingB_, -+ /// -+ typename ThreadMapA_, -+ /// -+ typename ThreadMapB_, -+ /// Number of partitions of K dimension of GEMM -+ int PartitionsK = 1> -+struct DepthwiseDirectConvMmaPolicy { -+ /// Warp-level GEMM operator (concept: gemm::warp::MmaTensorOp or gemm::warp::MmaSimt) -+ using Operator = Operator_; -+ -+ /// Padding used for A operand in shared memory -+ using SmemPaddingA = SmemPaddingA_; -+ -+ /// Padding used for B operand in shared memory -+ using SmemPaddingB = SmemPaddingB_; -+ -+ using ThreadMapA = ThreadMapA_; -+ using ThreadMapB = ThreadMapB_; -+ -+ /// Number of partitions of K dimension -+ static int const kPartitionsK = PartitionsK; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class DepthwiseDirectConvMmaBase { -+ public: -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Shape describing the overall GEMM computed from shared memory -+ /// by each warp. -+ using WarpGemm = typename Policy::Operator::Shape; -+ -+ /// Shape describing the number of warps filling the CTA -+ using WarpCount = cutlass::gemm:: -+ GemmShape; -+ -+ /// Number of warp-level GEMM oeprations -+ /// kWarpGemmIterations could be even and odd. -+ static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Tensor reference to the A operand -+ using TensorRefA = TensorRef; -+ -+ /// Tensor reference to the B operand -+ using TensorRefB = TensorRef; -+ -+ static_assert(kWarpGemmIterations > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ -+ // -+ // Nested structs -+ // -+ -+ /// Shared storage object needed by threadblock-scoped GEMM -+ class SharedStorage { -+ public: -+ // -+ // Type definitions -+ // -+ -+ /// Shape of the A matrix operand in shared memory -+ using ShapeA = MatrixShape<1, // Not determined at compile-time :( -+ Shape::kN + Policy::SmemPaddingA::kRow>; -+ -+ /// Shape of the B matrix operand in shared memory -+ using ShapeB = MatrixShape; // Tile N = 64? -+ -+ public: -+ // -+ // Data members -+ // -+ -+ // Let persistent B matrix in front of dynamic matrix A -+ /// Buffer for B operand -+ AlignedBuffer operand_B; -+ -+ /// Buffer for A operand -+ /// Not be determined at compile-time -- Just to get a Smem start address. -+ AlignedBuffer operand_A; -+ public: -+ // -+ // Methods -+ // -+ -+ /// Returns a layout object for the A matrix -+ CUTLASS_DEVICE -+ static typename Operator::LayoutA LayoutA() { -+ return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); -+ } -+ -+ /// Returns a layout object for the B matrix -+ CUTLASS_HOST_DEVICE -+ static typename Operator::LayoutB LayoutB() { -+ return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); -+ } -+ -+ /// Returns a TensorRef to the A operand -+ CUTLASS_HOST_DEVICE -+ TensorRefA operand_A_ref() { return TensorRefA{operand_A.data(), LayoutA()}; } -+ -+ /// Returns a TensorRef to the B operand -+ CUTLASS_HOST_DEVICE -+ TensorRefB operand_B_ref() { return TensorRefB{operand_B.data(), LayoutB()}; } -+ }; -+ -+ protected: -+ // -+ // Data members -+ // -+ -+ /// Iterator to load a warp-scoped tile of A operand from shared memory -+ typename Operator::IteratorA warp_tile_iterator_A_; -+ -+ /// Iterator to load a warp-scoped tile of B operand from shared memory -+ typename Operator::IteratorB warp_tile_iterator_B_; -+ -+ public: -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ DepthwiseDirectConvMmaBase( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx) -+ : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), -+ warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h -new file mode 100644 -index 0000000..dadd2b4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h -@@ -0,0 +1,952 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines basic properties needed by CTA-level GEMMs assuming expectations about data -+ layout of the global memory fragments, data types, and internal tile sizes. -+ -+ Partial specializations for threadblock::Mma operations targeting depthwise related simt instructions. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/gemm/warp/mma.h" -+ -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/warp/mma_depthwise_simt.h" -+ -+#include "cutlass/gemm/threadblock/mma_pipelined.h" -+#include "cutlass/gemm/threadblock/mma_singlestage.h" -+ -+#include "cutlass/gemm/threadblock/mma_base.h" -+#include "cutlass/conv/threadblock/depthwise_mma_base.h" -+ -+#include "cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear_direct_conv.h" -+ -+#include "cutlass/arch/cache_operation.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+namespace detail { -+// -+// Convert a WarpShapeM which is the whole tile of elements into the number of elements (2D) held by -+// each partitions within warp. -+// The goal is for each thread's tile of elements to be as square as -+// possible for performance (4x4 will be faster than 2x8). -+template // The number of partitions within the warp -+struct SimtWarpShape { -+ // kP * kQ * WarpNumThreadsM = WarpShapeM -+ // If needed, enable more specializations. -+}; -+template <> -+struct SimtWarpShape<4, 4> { -+ static constexpr int kP = 1; -+ static constexpr int kQ = 1; -+}; -+ -+template <> -+struct SimtWarpShape<4, 2> { -+ static constexpr int kP = 2; -+ static constexpr int kQ = 1; -+}; -+ -+template <> -+struct SimtWarpShape<4, 1> { -+ static constexpr int kP = 2; -+ static constexpr int kQ = 2; -+}; -+ -+template <> -+struct SimtWarpShape<8, 1> { -+ static constexpr int kP = 2; -+ static constexpr int kQ = 4; -+}; -+template <> -+struct SimtWarpShape<8, 2> { -+ static constexpr int kP = 2; -+ static constexpr int kQ = 2; -+}; -+template <> -+struct SimtWarpShape<8, 4> { -+ static constexpr int kP = 1; -+ static constexpr int kQ = 2; -+}; -+ -+template <> -+struct SimtWarpShape<16, 1> { -+ static constexpr int kP = 4; -+ static constexpr int kQ = 4; -+}; -+template <> -+struct SimtWarpShape<16, 2> { -+ static constexpr int kP = 2; -+ static constexpr int kQ = 4; -+}; -+template <> -+struct SimtWarpShape<16, 4> { -+ static constexpr int kP = 2; -+ static constexpr int kQ = 2; -+}; -+ -+template -+struct SimtWarpShape<25, WarpNumThreadsM> { -+ static_assert(WarpNumThreadsM == 1, "WarpShapeM could not be evenly splited by threads"); -+ static constexpr int kP = 5; -+ static constexpr int kQ = 5; -+}; -+ -+template <> -+struct SimtWarpShape<32, 1> { -+ static constexpr int kP = 4; -+ static constexpr int kQ = 8; -+}; -+ -+template <> -+struct SimtWarpShape<32, 2> { -+ static constexpr int kP = 4; -+ static constexpr int kQ = 4; -+}; -+ -+template <> -+struct SimtWarpShape<32, 4> { -+ static constexpr int kP = 2; -+ static constexpr int kQ = 4; -+}; -+ -+} // namespace detail -+ -+template < -+ /// Shape of threadblock-scoped matrix multiply operator -+ typename Shape, -+ /// Shape of warp-level matrix multiply operator -+ typename WarpShape, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape, -+ /// Element data type of A operand -+ typename ElementA, -+ /// Layout of operand A -+ typename LayoutA, -+ /// Element data type of B operand -+ typename ElementB, -+ /// Layout of operand B -+ typename LayoutB, -+ /// Data type of accumulator -+ typename ElementC, -+ /// Layout of accumulator -+ typename LayoutC, -+ /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) -+ typename OperatorClass, -+ /// Size of a warp-scoped per thread access -+ int kLaneAccessSizeA_ = 0, -+ /// Size of a warp-scoped per thread access -+ int kLaneAccessSizeB_ = 0, -+ /// Number of stages -+ int Stages = 2, -+ /// Operation performed by MMA -+ typename Operator = typename platform::conditional< -+ (platform::is_same::value) && -+ (platform::is_same::value || -+ platform::is_same::value || -+ platform::is_same::value || -+ platform::is_same::value), -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::arch::OpMultiplyAdd>::type, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor = false, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA = -+ cutlass::arch::CacheOperation::Global, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB = -+ cutlass::arch::CacheOperation::Global, -+ /// per-element transformation for elements of A -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// per-element transformation for elements of B -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ bool IsComplex = false // (is_complex::value || is_complex::value) -+> -+struct DepthwiseMmaCoreWithLaneAccessSize; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Shape of threadblock-scoped matrix multiply operator -+ typename Shape, -+ /// Shape of threadblock-scoped output tile -+ typename ThreadBlockOutputShape, -+ /// Shape of filter shape per threadblock -+ typename FilterShape, -+ /// Shape of warp-level matrix multiply operator -+ typename WarpShape, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape, -+ /// Element data type of A operand -+ typename ElementA, -+ /// Layout of operand A -+ typename LayoutA, -+ /// Element data type of B operand -+ typename ElementB, -+ /// Layout of operand B -+ typename LayoutB, -+ /// Data type of accumulator -+ typename ElementC, -+ /// Layout of accumulator -+ typename LayoutC, -+ /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) -+ typename OperatorClass, -+ /// Size of a warp-scoped per thread access -+ int kLaneAccessSizeA_ = 0, -+ /// Size of a warp-scoped per thread access -+ int kLaneAccessSizeB_ = 0, -+ /// Number of stages -+ int Stages = 2, -+ /// Operation performed by MMA -+ typename Operator = typename platform::conditional< -+ (platform::is_same::value) && -+ (platform::is_same::value || -+ platform::is_same::value || -+ platform::is_same::value || -+ platform::is_same::value), -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::arch::OpMultiplyAdd>::type, -+ /// Iterator algo type -+ conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, -+ /// Stride ( MatrixShape ) -+ typename StrideShape = cutlass::MatrixShape<-1, -1>, -+ /// Dilation ( MatrixShape ) -+ typename DilationShape = cutlass::MatrixShape<-1, -1>, -+ /// Activation Shape loaded by threadblock -+ typename ActivationShape = cutlass::conv::TensorNHWCShape<-1,-1,-1,-1>, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor = false, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA = -+ cutlass::arch::CacheOperation::Global, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB = -+ cutlass::arch::CacheOperation::Global, -+ /// per-element transformation for elements of A -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// per-element transformation for elements of B -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ bool IsComplex = false // (is_complex::value || is_complex::value) -+> -+struct DepthwiseDirectConvMmaCoreWithLaneAccessSize; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Shape of threadblock-scoped matrix multiply operator -+ typename Shape, -+ /// Shape of warp-level matrix multiply operator -+ typename WarpShape, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape, -+ /// Element data type of A operand -+ typename ElementA, -+ /// Layout of operand A -+ typename LayoutA, -+ /// Element data type of B operand -+ typename ElementB, -+ /// Layout of operand B -+ typename LayoutB, -+ /// Data type of accumulator -+ typename ElementC, -+ /// Layout of accumulator -+ typename LayoutC, -+ /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) -+ typename OperatorClass, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// per-element transformation for elements of A -+ ComplexTransform TransformA, -+ /// per-element transformation for elements of B -+ ComplexTransform TransformB, -+ bool IsComplex -+> -+struct DepthwiseMmaCoreWithLaneAccessSize< -+ Shape, WarpShape, InstructionShape, -+ ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, -+ OperatorClass, -1, -1, Stages, Operator, AccumulatorsInRowMajor, -+ CacheOpA, CacheOpB, TransformA, TransformB, IsComplex -+> : cutlass::gemm::threadblock::DefaultMmaCore< -+ Shape, WarpShape, InstructionShape, -+ ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, -+ OperatorClass, Stages, Operator, AccumulatorsInRowMajor, -+ CacheOpA, CacheOpB, TransformA, TransformB, IsComplex -+> {}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: row-major -+/// B: column-major -+/// Operator: simt class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Size of a warp-scoped per thread access (a value of -1 indicates the default) -+ int kLaneAccessSizeA_, -+ /// Size of a warp-scoped per thread access (a value of -1 indicates the default) -+ int kLaneAccessSizeB_, -+ /// Operation performed by GEMM -+ typename Operator_> -+struct DepthwiseMmaCoreWithLaneAccessSize, -+ ElementA_, -+ layout::RowMajor, -+ ElementB_, -+ layout::ColumnMajor, -+ ElementC_, -+ LayoutC_, -+ arch::OpClassSimt, -+ kLaneAccessSizeA_, -+ kLaneAccessSizeB_, -+ 2, -+ Operator_> : public cutlass::gemm::threadblock::DefaultMmaCore, -+ ElementA_, -+ layout::RowMajor, -+ ElementB_, -+ layout::ColumnMajor, -+ ElementC_, -+ LayoutC_, -+ arch::OpClassSimt, -+ 2, -+ Operator_> { -+ using Base = cutlass::gemm::threadblock::DefaultMmaCore, -+ ElementA_, -+ layout::RowMajor, -+ ElementB_, -+ layout::ColumnMajor, -+ ElementC_, -+ LayoutC_, -+ arch::OpClassSimt, -+ 2, -+ Operator_>; -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using ElementA = ElementA_; -+ using LayoutA = layout::RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassSimt; -+ -+ static int const kLaneAccessSizeA = kLaneAccessSizeA_; -+ static int const kLaneAccessSizeB = kLaneAccessSizeB_; -+ -+ // Divisility requirements -+ static_assert( kLaneAccessSizeA > 0 && kLaneAccessSizeB > 0, -+ "Size of a warp-scoped per thread access should be larger then ZERO" ); -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ /// Number of warps present -+ using WarpCount = typename Base::WarpCount; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = cutlass::gemm::warp::WarpSize::value; -+ -+ static int const kElementsPerAccess = 1; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajor; -+ using SmemLayoutB = layout::RowMajor; -+ -+ // -+ // Iterators to write to shared memory are same as base class -+ // -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level op -+ static const int WarpNumThreadsM = cutlass::gemm::threadblock::detail::simt_get_warp_threads_m(); -+ static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM; -+ static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; -+ static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; -+ static const int numElementsA = kLaneAccessSizeA / sizeof_bits::value; -+ static const int numElementsB = kLaneAccessSizeB / sizeof_bits::value; -+ static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); -+ static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); -+ -+ static int const kPaddingM = cutlass::gemm::threadblock::detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); -+ static int const kPaddingN = cutlass::gemm::threadblock::detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); -+ -+ static_assert(!(kPaddingM % LaneM) && !(kPaddingN % LaneN), -+ "Padding must be divisible by Lane"); -+ -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 1>; -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::RowMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::conv::warp::MmaDepthwiseSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) -+ >; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = cutlass::gemm::threadblock::MmaPolicy< -+ MmaWarpSimt, -+ MatrixShape, // skew for A matrix to avoid SMEM bank conflicts -+ MatrixShape<0, kPaddingN>, // skew for B matrix to avoid SMEM bank conflicts -+ WarpCount::kK -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: row-major -+/// B: row-major -+/// Operator: simt class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of threadblock-scoped output tile (concept: TensorNHWCShape) -+ typename ThreadBlockOutputShape_, -+ /// Shape of filter shape per threadblock -+ typename FilterShape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Size of a warp-scoped per thread access -+ int kLaneAccessSizeA_, -+ /// Number of stages -+ int Stages_, -+ /// Operation performed by GEMM -+ typename Operator_> -+struct DepthwiseDirectConvMmaCoreWithLaneAccessSize, -+ ElementA_, -+ layout::RowMajor, -+ ElementB_, -+ layout::ColumnMajor, -+ ElementC_, -+ LayoutC_, -+ arch::OpClassSimt, -+ kLaneAccessSizeA_, -+ 128, -+ Stages_, -+ Operator_> { -+ using Shape = Shape_; -+ using FilterShape = FilterShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using ElementA = ElementA_; -+ using LayoutA = layout::RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassSimt; -+ -+ static int const kLaneAccessSizeB = 128; -+ -+ // Divisility requirements -+ static_assert( kLaneAccessSizeB > 0, -+ "Size of a warp-scoped per thread access should be larger then ZERO" ); -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ /// Number of warps present -+ using WarpCount = cutlass::gemm::GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ 1 -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = cutlass::gemm::warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ // For Gmem load -+ static int const kElementsPerAccessA = 128 / sizeof_bits::value; -+ static int const kElementsPerAccessB = 128 / sizeof_bits::value; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::RowMajor; -+ using SmemLayoutB = layout::RowMajor; -+ -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, // Set kStrided = 1 because activation shape is runtime value. -+ kThreads, -+ kElementsPerAccessA -+ >; -+ -+ /// ThreadMap of iterator A -+ using SmemThreadMapA = IteratorThreadMapA; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIteratorDirectConv< -+ MatrixShape<1, Shape::kN>, // set kRow is 1 because it is a runtime value -+ ElementA, -+ SmemLayoutA, -+ 0, -+ SmemThreadMapA, // was IteratorThreadMapA -+ true // Dynamic iterations. -+ >; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccessB -+ >; -+ -+ /// Transpose the ThreadMap of iterator B -+ using SmemThreadMapB = IteratorThreadMapB; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIteratorDirectConv< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 0, -+ SmemThreadMapB, // was IteratorThreadMapB -+ false // static iterations. -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ // Groups per threads -+ // Fp32: 2 groups -+ // Fp16: 2 groups -+ static const int GroupsPerThread = sizeof(ElementB) > 1 ? 2 : 4; -+ // Define the warp-level op -+ static const int WarpNumThreadsN = cutlass::const_min(WarpShape::kN / GroupsPerThread, kWarpSize); -+ static const int WarpNumThreadsM = kWarpSize / WarpNumThreadsN; -+ -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ -+ // Get output P, Q per thread -+ static const int TileP = cutlass::conv::threadblock::detail::SimtWarpShape::kP; -+ static const int TileQ = cutlass::conv::threadblock::detail::SimtWarpShape::kQ; -+ -+ static const int LaneLayout = 1; -+ static const int numElementsB = kLaneAccessSizeB / sizeof_bits::value; -+ static const int LaneN = cutlass::const_min(numElementsB, WarpShape::kN / WarpNumThreadsN); -+ -+ // Define the output tile computed by each thread -+ using ThreadOutputShape = cutlass::conv::TensorNHWCShape<1, TileP, TileQ, LaneN>; -+ -+ // Fetch the channel with same access size -+ static const int LaneM = LaneN; -+ -+ // No paddings -+ static int const kPaddingM = 0; -+ static int const kPaddingN = 0; -+ -+ static_assert(!(kPaddingM % LaneM) && !(kPaddingN % LaneN), -+ "Padding must be divisible by Lane"); -+ -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 1>; -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::RowMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::conv::warp::MmaDepthwiseDirectConvSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ FilterShape, /// Shape of filter shape per threadblock - concept: gemm::GemmShape -+ ThreadOutputShape, /// Size of the output tile computed by thread - concept: conv::TensorNHWCShape<> -+ ThreadBlockOutputShape_, /// Size of the output tile computed by threadblock - concept: conv::TensorNHWCShape<> -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) -+ >; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = cutlass::conv::threadblock::DepthwiseDirectConvMmaPolicy< -+ MmaWarpSimt, -+ MatrixShape, // skew for A matrix to avoid SMEM bank conflicts -+ MatrixShape<0, kPaddingN>, // skew for B matrix to avoid SMEM bank conflicts -+ IteratorThreadMapA, -+ IteratorThreadMapB, -+ WarpCount::kK -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: row-major -+/// B: row-major -+/// Operator: simt class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of threadblock-scoped output tile (concept: TensorNHWCShape) -+ typename ThreadBlockOutputShape_, -+ /// Shape of filter shape per threadblock -+ typename FilterShape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Size of a warp-scoped per thread access -+ int kLaneAccessSizeA_, -+ /// Number of stages -+ int Stages_, -+ /// Operation performed by GEMM -+ typename Operator_, -+ /// Stride ( MatrixShape ) -+ typename StrideShape_, -+ /// Dilation ( MatrixShape ) -+ typename DilationShape_, -+ /// Activation Shape loaded by threadblock -+ typename ActivationShape_> -+struct DepthwiseDirectConvMmaCoreWithLaneAccessSize, -+ ElementA_, -+ layout::RowMajor, -+ ElementB_, -+ layout::ColumnMajor, -+ ElementC_, -+ LayoutC_, -+ arch::OpClassSimt, -+ kLaneAccessSizeA_, -+ 128, -+ Stages_, -+ Operator_, -+ IteratorAlgorithm::kFixedStrideDilation, -+ StrideShape_, -+ DilationShape_, -+ ActivationShape_> { -+ using Shape = Shape_; -+ using FilterShape = FilterShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using ElementA = ElementA_; -+ using LayoutA = layout::RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassSimt; -+ using StrideShape = StrideShape_; -+ using DilationShape = DilationShape_; -+ using ThreadBlockOutputShape = ThreadBlockOutputShape_; -+ using ActivationShape = ActivationShape_; -+ -+ static int const kLaneAccessSizeB = 128; -+ -+ // Divisility requirements -+ static_assert( kLaneAccessSizeB > 0, -+ "Size of a warp-scoped per thread access should be larger then ZERO" ); -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ /// Number of warps present -+ using WarpCount = cutlass::gemm::GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ 1 -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = cutlass::gemm::warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ // For Gmem load -+ static int const kElementsPerAccessA = 128 / sizeof_bits::value; -+ static int const kElementsPerAccessB = 128 / sizeof_bits::value; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::RowMajor; -+ using SmemLayoutB = layout::RowMajor; -+ -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccessA -+ >; -+ -+ /// ThreadMap of iterator A -+ using SmemThreadMapA = IteratorThreadMapA; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIteratorDirectConv< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 0, -+ SmemThreadMapA, // was IteratorThreadMapA -+ false // static iterations. -+ >; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccessB -+ >; -+ -+ /// Transpose the ThreadMap of iterator B -+ using SmemThreadMapB = IteratorThreadMapB; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIteratorDirectConv< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 0, -+ SmemThreadMapB, // was IteratorThreadMapB -+ false // static iterations. -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ // Groups per threads -+ // Fp32: 2 groups -+ // Fp16: 2 groups -+ static const int GroupsPerThread = sizeof(ElementB) > 1 ? 2 : 4; -+ // Define the warp-level op -+ static const int WarpNumThreadsN = cutlass::const_min(WarpShape::kN / GroupsPerThread, kWarpSize); -+ static const int WarpNumThreadsM = kWarpSize / WarpNumThreadsN; -+ -+ static const int TileP = cutlass::conv::threadblock::detail::SimtWarpShape::kP; -+ static const int TileQ = cutlass::conv::threadblock::detail::SimtWarpShape::kQ; -+ -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ -+ static const int LaneLayout = 1; -+ static const int numElementsB = kLaneAccessSizeB / sizeof_bits::value; -+ static const int LaneN = cutlass::const_min(numElementsB, WarpShape::kN / WarpNumThreadsN); -+ -+ // Define the output tile computed by each thread -+ using ThreadOutputShape = cutlass::conv::TensorNHWCShape<1, TileP, TileQ, LaneN>; -+ -+ // Fetch the channel with same access size -+ static const int LaneM = LaneN; -+ -+ // No paddings -+ static int const kPaddingM = 0; -+ static int const kPaddingN = 0; -+ -+ static_assert(!(kPaddingM % LaneM) && !(kPaddingN % LaneN), -+ "Padding must be divisible by Lane"); -+ -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 1>; -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::RowMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::conv::warp::MmaDepthwiseDirectConvSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ FilterShape, /// Shape of filter shape per threadblock - concept: gemm::GemmShape -+ ThreadOutputShape, /// Size of the output tile computed by thread - concept: conv::TensorNHWCShape<> -+ ThreadBlockOutputShape, /// Size of the output tile computed by threadblock - concept: conv::TensorNHWCShape<> -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy, /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) -+ IteratorAlgorithm::kFixedStrideDilation, /// Iterator algo type -+ StrideShape, /// Stride ( MatrixShape ) -+ DilationShape, /// Dilation ( MatrixShape ) -+ ActivationShape /// Activation Shape loaded by threadblock -+ >; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = cutlass::conv::threadblock::DepthwiseDirectConvMmaPolicy< -+ MmaWarpSimt, -+ MatrixShape, // skew for A matrix to avoid SMEM bank conflicts -+ MatrixShape<0, kPaddingN>, // skew for B matrix to avoid SMEM bank conflicts -+ IteratorThreadMapA, -+ IteratorThreadMapB, -+ WarpCount::kK -+ >; -+}; -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h -new file mode 100644 -index 0000000..cc33c69 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h -@@ -0,0 +1,802 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a multistage threadblock-scoped fused activation's -+ scale+bias+relu and Implicit GEMM Convolution kernel. -+ -+ The original implicit gemm will store out-of-bound data as zeroes in the -+ shared memory because zeros into the tensor core, zeroes out of the tensor -+ cores. The result is remained the same. When fusing scale+bias+relu -+ into the mainloop, it is no longer true because -+ -+ 0 x scale + bias = bias -+ -+ which is no longer always 0. So, instead of storing zeroes, this fused -+ kernel stores the out-of-bound data as a special NaN (0x7eff), when applying -+ scale+bias+relu, the code is like -+ -+ if (data == 0x7eff) -+ data = 0; -+ else -+ data = scale+bias+relu(data, scale, bias); -+ -+ See include/cutlass/conv/warp/scale_bias_relu_transformation.h for the -+ elementwise computation. See include/cutlass/arch/memory_sm80.h for nan fill. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/cache_operation.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/gemm/warp/scale_bias_tile_iterator.h" -+#include "cutlass/conv/warp/scale_bias_relu_transform.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Element type of scale and bias vectors -+ typename ElementScaleBias_, -+ /// Layout of scale and bias vectors -+ typename LayoutScaleBias_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// WarpIterator to load Scale or Bias vector from the shared memory -+ typename WarpIteratorScaleBias_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class MmaFpropFusionBase { -+ public: -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ ///< Element type of scale and bias vectors -+ using ElementScaleBias = ElementScaleBias_; -+ -+ /// Layout of scale and bias vectors -+ using LayoutScaleBias = LayoutScaleBias_; -+ -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ ///< WarpIterator to load Scale or Bias vector from the shared memory -+ using WarpIteratorScaleBias = WarpIteratorScaleBias_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Shape describing the overall GEMM computed from shared memory -+ /// by each warp. -+ using WarpGemm = typename Policy::Operator::Shape; -+ -+ /// Shape describing the number of warps filling the CTA -+ using WarpCount = cutlass::gemm::GemmShape; -+ -+ /// Number of warp-level GEMM oeprations -+ static int const kWarpGemmIterations = -+ (WarpGemm::kK / Operator::Policy::MmaShape::kK); -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Tensor reference to the A operand -+ using TensorRefA = TensorRef; -+ -+ /// Tensor reference to the scale and bias vectors -+ using TensorRefScaleBias = TensorRef; -+ -+ /// Tensor reference to the B operand -+ using TensorRefB = TensorRef; -+ -+ static_assert(kWarpGemmIterations > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ -+ static_assert((kWarpGemmIterations % 2) == 0, -+ "Inner loop iteration must be an even number."); -+ -+ // -+ // Nested structs -+ // -+ -+ /// Shared storage object needed by threadblock-scoped GEMM -+ class SharedStorage { -+ public: -+ // -+ // Type definitions -+ // -+ -+ /// Shape of the A matrix operand in shared memory -+ using ShapeA = MatrixShape; -+ -+ /// Shape of the A scale and bias vectors in shared memory -+ using ShapeScaleBias = -+ MatrixShape<1 + Policy::SmemPaddingA::kRow, -+ 2 * Shape::kK * kStages + Policy::SmemPaddingA::kColumn>; -+ -+ /// Shape of the B matrix operand in shared memory -+ using ShapeB = -+ MatrixShape; -+ -+ public: -+ // -+ // Data members -+ // -+ -+ /// Buffer for A operand -+ AlignedBuffer operand_A; -+ -+ /// Buffer for B operand -+ AlignedBuffer operand_B; -+ -+ /// Buffer for A operand Scale and Bias -+ AlignedBuffer operand_A_scale_bias; -+ -+ public: -+ -+ // -+ // Methods -+ // -+ -+ /// Returns a layout object for the A matrix -+ CUTLASS_DEVICE -+ static typename Operator::LayoutA LayoutA() { -+ return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); -+ } -+ -+ /// Returns a layout object for the B matrix -+ CUTLASS_HOST_DEVICE -+ static typename Operator::LayoutB LayoutB() { -+ return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); -+ } -+ -+ /// Returns a layout object for the A scale and bias vectors -+ CUTLASS_DEVICE -+ static LayoutScaleBias LayoutScaleBias() { -+ return LayoutScaleBias::packed( -+ {ShapeScaleBias::kRow, ShapeScaleBias::kColumn}); -+ } -+ -+ /// Returns a TensorRef to the A operand -+ CUTLASS_HOST_DEVICE -+ TensorRefA operand_A_ref() { -+ return TensorRefA{operand_A.data(), LayoutA()}; -+ } -+ -+ /// Returns a TensorRef to the B operand -+ CUTLASS_HOST_DEVICE -+ TensorRefB operand_B_ref() { -+ return TensorRefB{operand_B.data(), LayoutB()}; -+ } -+ -+ /// Returns a TensorRef to the A operand Scale vector -+ CUTLASS_HOST_DEVICE -+ TensorRefScaleBias operand_A_scale_bias_ref() { -+ return TensorRefScaleBias{operand_A_scale_bias.data(), LayoutScaleBias()}; -+ } -+ }; -+ -+ protected: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to load a warp-scoped tile of A operand from shared memory -+ typename Operator::IteratorA warp_tile_iterator_A_; -+ -+ /// Iterator to load a warp-scoped tile of A operand scale and bias vector -+ /// from shared memory -+ WarpIteratorScaleBias warp_tile_iterator_A_scale_bias_; -+ -+ /// Iterator to load a warp-scoped tile of B operand from shared memory -+ typename Operator::IteratorB warp_tile_iterator_B_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ MmaFpropFusionBase( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx) -+ : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), -+ warp_tile_iterator_A_scale_bias_( -+ shared_storage.operand_A_scale_bias_ref(), lane_idx), -+ warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Cache operation for operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// Iterates over vectors of scale and bias vector in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorScaleBias_, -+ /// Iterates over vectors of scale and bias vector in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorScaleBias_, -+ /// Cache operation for scale/bias operand -+ cutlass::arch::CacheOperation::Kind CacheOpScaleBias, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// WarpIterator to load Scale or Bias vector from the shared memory -+ typename WarpIteratorScaleBias_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class ImplicitGemmFpropFusionMultistage -+ : public MmaFpropFusionBase { -+ public: -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ ///< Iterates over tiles of A operand in global memory -+ using IteratorA = IteratorA_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB = IteratorB_; -+ ///< Iterates over tiles of the scale and bias vectors in global memory -+ using IteratorScaleBias = IteratorScaleBias_; -+ ///< WarpIterator to load Scale or Bias vector from the shared memory -+ using WarpIteratorScaleBias = WarpIteratorScaleBias_; -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ ///< Base class -+ using Base = MmaFpropFusionBase; -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ using SmemIteratorScaleBias = SmemIteratorScaleBias_; -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpScaleBias = -+ CacheOpScaleBias; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of accumulator tile -+ -+ using ElementC = typename Policy::Operator::ElementC; -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Internal structure exposed for introspection. -+ struct Detail { -+ -+ static_assert(Base::kWarpGemmIterations > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ -+ /// Number of cp.async instructions to load one stage of operand A -+ static int const AsyncCopyIterationsPerStageA = -+ IteratorA::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const AsyncCopyIterationsPerStageB = -+ IteratorB::ThreadMap::Iterations::kCount; -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Number of cp.async instructions to load on group of operand A -+ static int const kAccessesPerGroupA = -+ (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB = -+ (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ }; -+ -+ private: -+ -+ using WarpLoadedFragmentA = typename Operator::FragmentA; -+ using WarpLoadedFragmentB = typename Operator::FragmentB; -+ using WarpLoadedFragmentScaleBias = -+ typename WarpIteratorScaleBias::Fragment; -+ -+ using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; -+ using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of A operand scale vector to shared memory -+ SmemIteratorScaleBias smem_iterator_A_scale_bias_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ ImplicitGemmFpropFusionMultistage( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx) -+ : Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), -+ smem_iterator_A_scale_bias_(shared_storage.operand_A_scale_bias_ref(), -+ thread_idx), -+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_A_scale_bias_.add_tile_offset( -+ {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance(IteratorA &iterator_A, -+ IteratorScaleBias &iterator_A_scale_bias, -+ IteratorB &iterator_B, int group_start_A = 0, -+ int group_start_B = 0) { -+ iterator_A.set_iteration_index(group_start_A); -+ this->smem_iterator_A_.set_iteration_index(group_start_A); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { -+ -+ if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / 8; -+ -+ // Uses nan fill for out of bound data -+ cutlass::arch::cp_async_nan( -+ dst_ptr, iterator_A.get(), iterator_A.valid()); -+ -+ ++iterator_A; -+ -+ ++this->smem_iterator_A_; -+ } -+ } -+ -+ // Async Copy for operand A scale and bias vector. Scale and bias vectors -+ // are small. One iteration is enough. -+ if (group_start_A == 0) { -+ typename IteratorScaleBias::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_scale_bias_.get()); -+ -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorScaleBias::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async( -+ dst_ptr, iterator_A_scale_bias.get(), iterator_A_scale_bias.valid()); -+ } -+ -+ iterator_B.set_iteration_index(group_start_B); -+ -+ this->smem_iterator_B_.set_iteration_index(group_start_B); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { -+ if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr, iterator_B.get(), iterator_B.valid()); -+ -+ ++iterator_B; -+ ++this->smem_iterator_B_; -+ } -+ } -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations, -+ ///< destination accumulator tile -+ FragmentC &accum, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B, -+ ///< iterator over scale and bias vectors in global memory -+ IteratorScaleBias iterator_A_scale_bias, -+ ///< initial value of accumulator -+ FragmentC const &src_accum, -+ ///< number of iterations per channel -+ int gemm_k_iterations_per_channel = 0, -+ ///< Imaginary strides used for planar-complex only - ignored here -+ int64_t imag_stride_A = 0, -+ int64_t imag_stride_B = 0) { -+ -+ // -+ // Prologue -+ // -+ -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; -+ ++stage, --gemm_k_iterations) { -+ -+ iterator_A.set_iteration_index(0); -+ this->smem_iterator_A_.set_iteration_index(0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / 8; -+ -+ // Uses Nan fill for out of bound data -+ cutlass::arch::cp_async_nan( -+ dst_ptr, iterator_A.get(), iterator_A.valid()); -+ -+ ++iterator_A; -+ ++this->smem_iterator_A_; -+ } -+ -+ // Async Copy for operand A scale and bias vectors. Scale and bias -+ // vectors are small. One iteration is enough. -+ { -+ typename IteratorScaleBias::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_scale_bias_.get()); -+ -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorScaleBias::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async( -+ dst_ptr, iterator_A_scale_bias.get(), iterator_A_scale_bias.valid()); -+ } -+ -+ iterator_B.set_iteration_index(0); -+ this->smem_iterator_B_.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr, iterator_B.get(), iterator_B.valid()); -+ -+ ++iterator_B; -+ ++this->smem_iterator_B_; -+ } -+ -+ // Move to the next stage -+ iterator_A.advance(); -+ iterator_A_scale_bias.advance(); -+ iterator_B.advance(); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_A_scale_bias_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Inserts a fence to group cp.async instructions into stages. -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // Perform accumulation in the 'd' output operand -+ accum = src_accum; -+ -+ // Waits until kStages-2 stages have committed. -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA warp_loaded_frag_A[2]; -+ WarpLoadedFragmentB warp_loaded_frag_B[2]; -+ WarpLoadedFragmentScaleBias warp_loaded_frag_A_scale_bias[2]; -+ WarpTransformedFragmentA warp_transformed_frag_A[2]; -+ WarpTransformedFragmentB warp_transformed_frag_B[2]; -+ -+ Operator warp_mma; -+ cutlass::conv::warp::FpropScaleBiasReluTransform -+ elementwise_transform; -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_A_scale_bias_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); -+ this->warp_tile_iterator_A_scale_bias_.load( -+ warp_loaded_frag_A_scale_bias[0]); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_A_scale_bias_; -+ ++this->warp_tile_iterator_B_; -+ -+ // Start issuing the first group of the next stage outside of the mainloop -+ copy_tiles_and_advance(iterator_A, iterator_A_scale_bias, iterator_B); -+ -+ int smem_write_stage_idx = Base::kStages - 1; -+ int smem_read_stage_idx = 0; -+ -+ warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], -+ warp_loaded_frag_A[0], warp_loaded_frag_B[0]); -+ -+ elementwise_transform(warp_transformed_frag_A[0], -+ warp_loaded_frag_A_scale_bias[0]); -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > (-Base::kStages + 1);) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; -+ ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. -+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_A_scale_bias_.set_kgroup_index( -+ (warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_A_scale_bias_.load( -+ warp_loaded_frag_A_scale_bias[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_A_scale_bias_; -+ ++this->warp_tile_iterator_B_; -+ -+ if (warp_mma_k > 0) { -+ warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ warp_loaded_frag_A[warp_mma_k % 2], -+ warp_loaded_frag_B[warp_mma_k % 2]); -+ -+ elementwise_transform(warp_transformed_frag_A[warp_mma_k % 2], -+ warp_loaded_frag_A_scale_bias[warp_mma_k % 2]); -+ } -+ -+ warp_mma( -+ accum, -+ warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ accum -+ ); -+ -+ // Issue global->shared copies for the next stage -+ int group_start_iteration_A, group_start_iteration_B; -+ -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations) { -+ group_start_iteration_A = 0; -+ group_start_iteration_B = 0; -+ } else { -+ group_start_iteration_A = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB; -+ } -+ -+ copy_tiles_and_advance(iterator_A, iterator_A_scale_bias, iterator_B, -+ group_start_iteration_A, -+ group_start_iteration_B); -+ -+ -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations) { -+ warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ -+ elementwise_transform( -+ warp_transformed_frag_A[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A_scale_bias[(warp_mma_k + 1) % 2]); -+ } -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations) { -+ // Inserts a fence to group cp.async instructions into stages. -+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages of cp.async have committed -+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_A.advance(); -+ iterator_A_scale_bias.advance(); -+ iterator_B.advance(); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_A_scale_bias_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_A_scale_bias_.add_tile_offset( -+ {0, -Base::kStages}); -+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_A_scale_bias_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ --gemm_k_iterations; -+ } -+ } -+ -+ } -+ -+ // Insert fence and wait for all outstanding cp.async operations to commit. -+ cutlass::arch::cp_async_fence(); -+ cutlass::arch::cp_async_wait<0>(); -+ __syncthreads(); -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/implicit_gemm_multistage.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/implicit_gemm_multistage.h -new file mode 100644 -index 0000000..80dc435 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/implicit_gemm_multistage.h -@@ -0,0 +1,542 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a multistage threadblock-scoped Implicit GEMM Convolution kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/cache_operation.h" -+#include "cutlass/gemm/threadblock/mma_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Cache operation for operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class ImplicitGemmMultistage : -+ public gemm::threadblock::MmaBase { -+public: -+ ///< Base class -+ using Base = gemm::threadblock::MmaBase; -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ ///< Iterates over tiles of A operand in global memory -+ using IteratorA = IteratorA_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB = IteratorB_; -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of accumulator tile -+ -+ using ElementC = typename Policy::Operator::ElementC; -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Internal structure exposed for introspection. -+ struct Detail { -+ -+ /// Number of cp.async instructions to load one stage of operand A -+ static int const AsyncCopyIterationsPerStageA = -+ IteratorA::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const AsyncCopyIterationsPerStageB = -+ IteratorB::ThreadMap::Iterations::kCount; -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Number of cp.async instructions to load on group of operand A -+ static int const kAccessesPerGroupA = -+ (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB = -+ (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ }; -+ -+ private: -+ -+ using WarpLoadedFragmentA = typename Operator::FragmentA; -+ using WarpLoadedFragmentB = typename Operator::FragmentB; -+ using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; -+ using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ ImplicitGemmMultistage( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), -+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) -+ { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance( -+ IteratorA &iterator_A, IteratorB &iterator_B, -+ int group_start_A = 0, int group_start_B = 0) { -+ -+ iterator_A.set_iteration_index(group_start_A * -+ IteratorA::kAccessesPerVector); -+ this->smem_iterator_A_.set_iteration_index(group_start_A); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { -+ -+ if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_A.get(), iterator_A.valid()); -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ } -+ -+ iterator_B.set_iteration_index(group_start_B * -+ IteratorB::kAccessesPerVector); -+ -+ this->smem_iterator_B_.set_iteration_index(group_start_B); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { -+ if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_B.get(), iterator_B.valid()); -+ -+ ++iterator_B; -+ } -+ ++this->smem_iterator_B_; -+ } -+ } -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations, -+ ///< destination accumulator tile -+ FragmentC &accum, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B, -+ ///< initial value of accumulator -+ FragmentC const &src_accum, -+ ///< number of iterations per channel -+ int gemm_k_iterations_per_channel = 0, -+ ///< Imaginary strides used for planar-complex only - ignored here -+ int64_t imag_stride_A = 0, -+ int64_t imag_stride_B = 0) { -+ -+ // -+ // Prologue -+ // -+ -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; -+ ++stage, --gemm_k_iterations) { -+ -+ iterator_A.set_iteration_index(0); -+ this->smem_iterator_A_.set_iteration_index(0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_A.get(), iterator_A.valid()); -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ -+ iterator_B.set_iteration_index(0); -+ this->smem_iterator_B_.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_B.get(), iterator_B.valid()); -+ -+ ++iterator_B; -+ } -+ -+ ++this->smem_iterator_B_; -+ } -+ -+ // Move to the next stage -+ iterator_A.advance(); -+ iterator_B.advance(); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Inserts a fence to group cp.async instructions into stages. -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // Perform accumulation in the 'd' output operand -+ accum = src_accum; -+ -+ // Waits until kStages-2 stages have committed. -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA warp_loaded_frag_A[2]; -+ WarpLoadedFragmentB warp_loaded_frag_B[2]; -+ WarpTransformedFragmentA warp_transformed_frag_A[2]; -+ WarpTransformedFragmentB warp_transformed_frag_B[2]; -+ -+ Operator warp_mma; -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ // Start issuing the first group of the next stage outside of the mainloop -+ copy_tiles_and_advance(iterator_A, iterator_B); -+ -+ int smem_write_stage_idx = Base::kStages - 1; -+ int smem_read_stage_idx = 0; -+ -+ warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], -+ warp_loaded_frag_A[0], warp_loaded_frag_B[0]); -+ -+ // tf32x3 kernels use staging accumulation. warp_mma uses a temporary -+ // accumulator and this temporary accumulator is added to the final -+ // accumulator once in every mainloop iteration. -+ plus plus_accum; -+ -+ FragmentC tmp_accum; -+ -+ if (platform::is_same::value -+ || platform::is_same::value) { -+ tmp_accum.clear(); -+ } -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > (-Base::kStages + 1);) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; -+ ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. -+ -+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ if (warp_mma_k > 0) -+ warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ warp_loaded_frag_A[warp_mma_k % 2], -+ warp_loaded_frag_B[warp_mma_k % 2]); -+ -+ // Issue global->shared copies for the next stage -+ int group_start_iteration_A, group_start_iteration_B; -+ -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations) { -+ group_start_iteration_A = 0; -+ group_start_iteration_B = 0; -+ } else { -+ group_start_iteration_A = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB; -+ } -+ -+ copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, -+ group_start_iteration_B); -+ -+ if (platform::is_same::value -+ || platform::is_same::value) { -+ warp_mma( -+ tmp_accum, -+ warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ tmp_accum -+ ); -+ -+ if (warp_mma_k == 0) { -+ accum = plus_accum(accum, tmp_accum); -+ tmp_accum.clear(); -+ } -+ } else { -+ warp_mma( -+ accum, -+ warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ accum -+ ); -+ } -+ -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations) -+ warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations) { -+ // Inserts a fence to group cp.async instructions into stages. -+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages of cp.async have committed -+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_A.advance(); -+ iterator_B.advance(); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ --gemm_k_iterations; -+ } -+ } -+ -+ } -+ -+ if (platform::is_same::value -+ || platform::is_same::value) { -+ accum = plus_accum(accum, tmp_accum); -+ } -+ -+ // Insert fence and wait for all outstanding cp.async operations to commit. -+ cutlass::arch::cp_async_fence(); -+ cutlass::arch::cp_async_wait<0>(); -+ __syncthreads(); -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/implicit_gemm_pipelined.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/implicit_gemm_pipelined.h -new file mode 100644 -index 0000000..4a36ef5 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/implicit_gemm_pipelined.h -@@ -0,0 +1,320 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/numeric_conversion.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/threadblock/mma_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Transformation applied to A operand -+ typename TransformA_ = NumericArrayConverter< -+ typename SmemIteratorA_::Element, -+ typename IteratorA_::Element, -+ IteratorA_::Fragment::kElements>, -+ /// -+ /// Transformation applied to A operand -+ typename TransformB_ = NumericArrayConverter< -+ typename SmemIteratorB_::Element, -+ typename IteratorB_::Element, -+ IteratorB_::Fragment::kElements>, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+class ImplicitGemmPipelined : public gemm::threadblock::MmaBase { -+public: -+ -+ ///< Base class -+ using Base = gemm::threadblock::MmaBase; -+ -+ using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory -+ using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory -+ using ElementC = ElementC_; ///< Data type of accumulator matrix -+ using LayoutC = LayoutC_; ///< Layout of accumulator matrix -+ using Policy = Policy_; ///< Policy describing tuning details -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ -+ using TransformA = TransformA_; -+ using TransformB = TransformB_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of operand A loaded from global memory -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Fragment of operand B loaded from global memory -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Fragment of accumulator tile -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Obtain the arch tag from the warp-level operator -+ using ArchTag = typename Policy::Operator::ArchTag; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = Operator::kTransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = Operator::kTransformB; -+ -+ // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) -+ static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); -+ -+private: -+ -+ using WarpFragmentA = typename Operator::FragmentA; -+ using WarpFragmentB = typename Operator::FragmentB; -+ -+protected: -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ ImplicitGemmPipelined( -+ typename Base::SharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ int thread_idx, ///< ID within the threadblock -+ int warp_idx, ///< ID of warp -+ int lane_idx ///< ID of each thread within a warp -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), -+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { -+ -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ int gemm_k_iterations, ///< number of iterations of the mainloop -+ FragmentC &accum, ///< destination accumulator tile -+ IteratorA iterator_A, ///< iterator over A operand in global memory -+ IteratorB iterator_B, ///< iterator over B operand in global memory -+ FragmentC const &src_accum, ///< source accumulator tile -+ int gemm_k_iterations_per_channel = 0, ///< number of iterations per channel -+ TransformA transform_A = TransformA(), ///< transformation applied to A fragment -+ TransformB transform_B = TransformB()) { ///< transformation applied to B fragment -+ -+ // -+ // Prologue -+ // -+ -+ // Perform accumulation in the 'd' output operand -+ accum = src_accum; -+ -+ FragmentA tb_frag_A; -+ FragmentB tb_frag_B; -+ -+ tb_frag_A.clear(); -+ tb_frag_B.clear(); -+ -+ // The last kblock is loaded in the prolog -+ iterator_A.load(tb_frag_A); -+ iterator_B.load(tb_frag_B); -+ -+ ++iterator_A; -+ ++iterator_B; -+ -+ this->smem_iterator_A_.store(transform_A(tb_frag_A)); -+ this->smem_iterator_B_.store(transform_B(tb_frag_B)); -+ -+ ++this->smem_iterator_A_; -+ ++this->smem_iterator_B_; -+ -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math instructions -+ WarpFragmentA warp_frag_A[2]; -+ WarpFragmentB warp_frag_B[2]; -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A_.load(warp_frag_A[0]); -+ this->warp_tile_iterator_B_.load(warp_frag_B[0]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ Operator warp_mma; -+ -+ int smem_write_stage_idx = 1; -+ -+ // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing -+ // shared memory loads (which have the tighest latency requirement). -+ -+ // -+ // Mainloop -+ // -+ -+ // Note: The main loop does not support Base::kWarpGemmIterations == 2. -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > 0; --gemm_k_iterations) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group -+ // as the case may be. -+ -+ if (warp_mma_k == Base::kWarpGemmIterations - 1) { -+ -+ // Write fragments to shared memory -+ this->smem_iterator_A_.store(transform_A(tb_frag_A)); -+ -+ this->smem_iterator_B_.store(transform_B(tb_frag_B)); -+ -+ __syncthreads(); -+ -+ ++this->smem_iterator_A_; -+ ++this->smem_iterator_B_; -+ -+ // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory -+ if (smem_write_stage_idx == 1) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ } -+ else { -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, -+ 0}); -+ } -+ -+ smem_write_stage_idx ^= 1; -+ } -+ -+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ -+ this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ if (warp_mma_k == 0) { -+ -+ iterator_A.load(tb_frag_A); -+ iterator_B.load(tb_frag_B); -+ -+ ++iterator_A; -+ ++iterator_B; -+ } -+ -+ warp_mma(accum, warp_frag_A[warp_mma_k % 2], -+ warp_frag_B[warp_mma_k % 2], accum); -+ } -+ } -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/implicit_gemm_wgrad_fusion_multistage.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/implicit_gemm_wgrad_fusion_multistage.h -new file mode 100644 -index 0000000..13b5a34 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/implicit_gemm_wgrad_fusion_multistage.h -@@ -0,0 +1,729 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a multistage threadblock-scoped fused activation's scale+bias+relu and -+ Implicit GEMM Convolution kernel. -+ -+ The original implicit gemm will store out-of-bound data as zeroes in the -+ shared memory because zeros into the tensor core, zeroes out of the tensor -+ cores. The result is remained the same. When fusing scale+bias+relu -+ into the mainloop, it is no longer true because -+ -+ 0 x scale + bias = bias -+ -+ which is no longer always 0. So, instead of storing zeroes, this fused -+ kernel stores the out-of-bound data as a special NaN (0x7eff), when applying -+ scale+bias+relu, the code is like -+ -+ if (data == 0x7eff) -+ data = 0; -+ else -+ data = scale+bias+relu(data, scale, bias); -+ -+ The biggest difference compared with the fused Fprop and scale+bias+relu is -+ that scale and bias are loop invariant in Wgrad so that they only needs to -+ be loaded once before the mainloop. -+ -+ See include/cutlass/conv/warp/scale_bias_relu_transformation.h for the -+ elementwise computation. See include/cutlass/arch/memory_sm80.h for nan fill. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/cache_operation.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/gemm/warp/scale_bias_tile_iterator.h" -+#include "cutlass/conv/warp/scale_bias_relu_transform.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Element type of scale and bias vectors -+ typename ElementScaleBias_, -+ /// Layout of scale and bias vectors -+ typename LayoutScaleBias_, -+ /// Element type of scale and bias vectors -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class MmaWgradFusionBase { -+ public: -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ ///< Element type of scale and bias vectors -+ using ElementScaleBias = ElementScaleBias_; -+ -+ /// Layout of scale and bias vectors -+ using LayoutScaleBias = LayoutScaleBias_; -+ -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Shape describing the overall GEMM computed from shared memory -+ /// by each warp. -+ using WarpGemm = typename Policy::Operator::Shape; -+ -+ /// Shape describing the number of warps filling the CTA -+ using WarpCount = cutlass::gemm::GemmShape; -+ -+ /// Number of warp-level GEMM oeprations -+ static int const kWarpGemmIterations = -+ (WarpGemm::kK / Operator::Policy::MmaShape::kK); -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Tensor reference to the A operand -+ using TensorRefA = TensorRef; -+ -+ /// Tensor reference to the B operand -+ using TensorRefB = TensorRef; -+ -+ static_assert(kWarpGemmIterations > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ -+ static_assert((kWarpGemmIterations % 2) == 0, -+ "Inner loop iteration must be an even number."); -+ -+ // -+ // Nested structs -+ // -+ -+ /// Shared storage object needed by threadblock-scoped GEMM -+ class SharedStorage { -+ public: -+ // -+ // Type definitions -+ // -+ -+ /// Shape of the A matrix operand in shared memory -+ using ShapeA = MatrixShape; -+ -+ /// Shape of the B matrix operand in shared memory -+ using ShapeB = -+ MatrixShape; -+ -+ public: -+ // -+ // Data members -+ // -+ -+ /// Buffer for A operand -+ AlignedBuffer operand_A; -+ -+ /// Buffer for B operand -+ AlignedBuffer operand_B; -+ -+ public: -+ -+ // -+ // Methods -+ // -+ -+ /// Returns a layout object for the A matrix -+ CUTLASS_DEVICE -+ static typename Operator::LayoutA LayoutA() { -+ return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); -+ } -+ -+ /// Returns a layout object for the B matrix -+ CUTLASS_HOST_DEVICE -+ static typename Operator::LayoutB LayoutB() { -+ return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); -+ } -+ -+ /// Returns a TensorRef to the A operand -+ CUTLASS_HOST_DEVICE -+ TensorRefA operand_A_ref() { -+ return TensorRefA{operand_A.data(), LayoutA()}; -+ } -+ -+ /// Returns a TensorRef to the B operand -+ CUTLASS_HOST_DEVICE -+ TensorRefB operand_B_ref() { -+ return TensorRefB{operand_B.data(), LayoutB()}; -+ } -+ }; -+ -+ protected: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to load a warp-scoped tile of A operand from shared memory -+ typename Operator::IteratorA warp_tile_iterator_A_; -+ -+ /// Iterator to load a warp-scoped tile of B operand from shared memory -+ typename Operator::IteratorB warp_tile_iterator_B_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ MmaWgradFusionBase( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx) -+ : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), -+ warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Cache operation for operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// Iterates over vectors of scale and bias vector in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorScaleBias_, -+ /// Iterates over vectors of scale and bias vector i -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class ImplicitGemmWgradFusionMultistage -+ : public MmaWgradFusionBase { -+ public: -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ ///< Iterates over tiles of A operand in global memory -+ using IteratorA = IteratorA_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB = IteratorB_; -+ ///< Iterates over tiles of the scale and bias vectors in global memory -+ using IteratorScaleBias = IteratorScaleBias_; -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ ///< Base class -+ using Base = MmaWgradFusionBase; -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of accumulator tile -+ -+ using ElementC = typename Policy::Operator::ElementC; -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Internal structure exposed for introspection. -+ struct Detail { -+ -+ /// Number of cp.async instructions to load one stage of operand A -+ static int const AsyncCopyIterationsPerStageA = -+ IteratorA::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const AsyncCopyIterationsPerStageB = -+ IteratorB::ThreadMap::Iterations::kCount; -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Number of cp.async instructions to load on group of operand A -+ static int const kAccessesPerGroupA = -+ (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB = -+ (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ -+ static int const kBBufferSize = -+ ((sizeof(typename Operator::ElementC) == 4) && -+ ((platform::is_same::value && -+ platform::is_same::value)) && -+ (Operator::Shape::kM >= 64 && Operator::Shape::kN >= 64)) -+ ? 1 -+ : 2; -+ }; -+ -+ private: -+ -+ using WarpLoadedFragmentA = typename Operator::FragmentA; -+ using WarpLoadedFragmentB = typename Operator::FragmentB; -+ using WarpLoadedFragmentScaleBias = typename IteratorScaleBias::Fragment; -+ -+ using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; -+ using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+ int warp_idx_m_; -+ -+ int warp_idx_n_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ ImplicitGemmWgradFusionMultistage( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx) -+ : Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), -+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { -+ -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ warp_idx_m_ = warp_idx_mn % Base::WarpCount::kM; -+ warp_idx_n_ = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {warp_idx_m_, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n_}); -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance(IteratorA &iterator_A, -+ IteratorB &iterator_B, -+ int group_start_A = 0, int group_start_B = 0) { -+ -+ iterator_A.set_iteration_index(group_start_A); -+ this->smem_iterator_A_.set_iteration_index(group_start_A); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { -+ -+ if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr, iterator_A.get(), iterator_A.valid()); -+ -+ ++iterator_A; -+ -+ ++this->smem_iterator_A_; -+ } -+ } -+ -+ iterator_B.set_iteration_index(group_start_B); -+ -+ this->smem_iterator_B_.set_iteration_index(group_start_B); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { -+ if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / 8; -+ -+ // Uses nan fill for out of bound data -+ cutlass::arch::cp_async_nan( -+ dst_ptr, iterator_B.get(), iterator_B.valid()); -+ -+ ++iterator_B; -+ ++this->smem_iterator_B_; -+ } -+ } -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations, -+ ///< destination accumulator tile -+ FragmentC &accum, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B, -+ ///< iterator over scale and bias vectors in global memory -+ IteratorScaleBias iterator_B_scale_bias, -+ ///< initial value of accumulator -+ FragmentC const &src_accum, -+ ///< number of iterations per channel -+ int gemm_k_iterations_per_channel = 0, -+ ///< Imaginary strides used for planar-complex only - ignored here -+ int64_t imag_stride_A = 0, -+ int64_t imag_stride_B = 0) { -+ -+ // -+ // Prologue -+ // -+ -+ WarpLoadedFragmentScaleBias warp_loaded_frag_B_scale_bias; -+ iterator_B_scale_bias.add_tile_offset({0, warp_idx_n_}); -+ iterator_B_scale_bias.load(warp_loaded_frag_B_scale_bias); -+ -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; -+ ++stage, --gemm_k_iterations) { -+ -+ iterator_A.set_iteration_index(0); -+ this->smem_iterator_A_.set_iteration_index(0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr, iterator_A.get(), iterator_A.valid()); -+ -+ ++iterator_A; -+ ++this->smem_iterator_A_; -+ } -+ -+ iterator_B.set_iteration_index(0); -+ this->smem_iterator_B_.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / 8; -+ -+ // Uses Nan fill for out of bound data -+ cutlass::arch::cp_async_nan( -+ dst_ptr, iterator_B.get(), iterator_B.valid()); -+ -+ ++iterator_B; -+ ++this->smem_iterator_B_; -+ } -+ -+ // Move to the next stage -+ iterator_A.advance(); -+ iterator_B.advance(); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Inserts a fence to group cp.async instructions into stages. -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // Perform accumulation in the 'd' output operand -+ accum = src_accum; -+ -+ // Waits until kStages-2 stages have committed. -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA warp_loaded_frag_A[Detail::kBBufferSize]; -+ WarpLoadedFragmentB warp_loaded_frag_B[2]; -+ WarpTransformedFragmentA warp_transformed_frag_A[Detail::kBBufferSize]; -+ WarpTransformedFragmentB warp_transformed_frag_B[2]; -+ -+ Operator warp_mma; -+ cutlass::conv::warp::WgradScaleBiasReluTransform -+ elementwise_transform; -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ // Start issuing the first group of the next stage outside of the mainloop -+ copy_tiles_and_advance(iterator_A, iterator_B); -+ -+ int smem_write_stage_idx = Base::kStages - 1; -+ int smem_read_stage_idx = 0; -+ -+ warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], -+ warp_loaded_frag_A[0], warp_loaded_frag_B[0]); -+ -+ elementwise_transform(warp_transformed_frag_B[0], -+ warp_loaded_frag_B_scale_bias); -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > (-Base::kStages + 1);) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; -+ ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. -+ -+ if (Detail::kBBufferSize == 2) { -+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % Detail::kBBufferSize]); -+ ++this->warp_tile_iterator_A_; -+ } -+ -+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_B_; -+ -+ if (warp_mma_k > 0) { -+ warp_mma.transform(warp_transformed_frag_A[warp_mma_k % Detail::kBBufferSize], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ warp_loaded_frag_A[warp_mma_k % Detail::kBBufferSize], -+ warp_loaded_frag_B[warp_mma_k % 2]); -+ -+ elementwise_transform(warp_transformed_frag_B[warp_mma_k % 2], -+ warp_loaded_frag_B_scale_bias); -+ } -+ -+ warp_mma( -+ accum, -+ warp_transformed_frag_A[warp_mma_k % Detail::kBBufferSize], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ accum -+ ); -+ -+ if (Detail::kBBufferSize == 1) { -+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); -+ ++this->warp_tile_iterator_A_; -+ -+ } -+ -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations) { -+ warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % Detail::kBBufferSize], -+ warp_transformed_frag_B[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A[(warp_mma_k + 1) % Detail::kBBufferSize], -+ warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ -+ elementwise_transform( -+ warp_transformed_frag_B[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B_scale_bias); -+ } -+ -+ // Issue global->shared copies for the next stage -+ int group_start_iteration_A, group_start_iteration_B; -+ -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations) { -+ group_start_iteration_A = 0; -+ group_start_iteration_B = 0; -+ } else { -+ group_start_iteration_A = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB; -+ } -+ -+ copy_tiles_and_advance(iterator_A, iterator_B, -+ group_start_iteration_A, -+ group_start_iteration_B); -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations) { -+ // Inserts a fence to group cp.async instructions into stages. -+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages of cp.async have committed -+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_A.advance(); -+ iterator_B.advance(); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ --gemm_k_iterations; -+ } -+ } -+ -+ } -+ -+ // Insert fence and wait for all outstanding cp.async operations to commit. -+ cutlass::arch::cp_async_fence(); -+ cutlass::arch::cp_async_wait<0>(); -+ __syncthreads(); -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h -new file mode 100644 -index 0000000..8b5b111 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h -@@ -0,0 +1,470 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Templates calculating the address and predicates to the load of scale and bias vectors. -+ -+ This iterator uses masks to guard out-of-bounds accesses. -+ -+ A precomputed "Params" object minimizes the amount of state that must be -+ stored in registers, and integer addition is used to advance the pointer -+ through memory. -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/conv/threadblock/conv2d_params.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// PredicatedScaleBiasVectorAccessIterator -+/// -+template -+class PredicatedScaleBiasVectorAccessIterator; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator for fprop pitch-linear data. -+/// -+template -+class PredicatedScaleBiasVectorAccessIterator { -+ public: -+ -+ using ThreadblockShape = ThreadblockShape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ConstPointer = const Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ static int const kElementsPerAccess = 128 / sizeof_bits::value; -+ static int const kThreads = ThreadblockShape::kContiguous / kElementsPerAccess; -+ -+ using AccessType = AlignedArray; -+ -+ using Params = PredicatedScaleBiasVectorAccessIteratorParams; -+ -+ private: -+ /// Internal pointer type permits fast address arithmetic -+ using BytePointer = char *; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Parameters object with precomputed internal state -+ Params const ¶ms_; -+ -+ /// Internal pointer to first access of tile -+ BytePointer pointer_; -+ -+ int problem_size_trs; -+ int problem_size_c; -+ int filter_trs_; -+ -+ TensorCoord thread_offset_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Extent of tensor -+ Conv2dProblemSize const &problem_size, -+ /// Pointer to the start of the scale vector -+ ConstPointer scale_pointer, -+ /// Pointer to the start of the bias vector -+ ConstPointer bias_pointer, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : params_(params), -+ problem_size_trs(problem_size.R * problem_size.S), -+ problem_size_c(problem_size.C), -+ filter_trs_(0) { -+ pointer_ = (thread_id < kThreads) -+ ? reinterpret_cast( -+ const_cast(scale_pointer)) -+ : reinterpret_cast( -+ const_cast(bias_pointer)); -+ -+ // Per-thread offset in logical coordinates of tensor -+ int thread_base = (thread_id < kThreads) ? 0 : kThreads; -+ -+ thread_offset_ = -+ threadblock_offset + -+ TensorCoord((thread_id - thread_base) * kElementsPerAccess, 0); -+ -+ set_iteration_index(0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Extent of tensor -+ Conv3dProblemSize const &problem_size, -+ /// Pointer to the start of the scale vector -+ ConstPointer scale_pointer, -+ /// Pointer to the start of the bias vector -+ ConstPointer bias_pointer, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : params_(params), -+ problem_size_trs(problem_size.T * problem_size.R * problem_size.S), -+ problem_size_c(problem_size.C), -+ filter_trs_(0) { -+ pointer_ = (thread_id < kThreads) -+ ? reinterpret_cast( -+ const_cast(scale_pointer)) -+ : reinterpret_cast( -+ const_cast(bias_pointer)); -+ -+ // Per-thread offset in logical coordinates of tensor -+ int thread_base = (thread_id < kThreads) ? 0 : kThreads; -+ -+ thread_offset_ = -+ threadblock_offset + -+ TensorCoord((thread_id - thread_base) * kElementsPerAccess, 0); -+ -+ set_iteration_index(0); -+ } -+ -+ /// Construct a PredicatedTileAccessIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Extent of tensor -+ Conv2dProblemSize const &problem_size, -+ /// Pointer to start of scale vector -+ ConstPointer scale_pointer, -+ /// Pointer to start of scale vector -+ ConstPointer bias_pointer, -+ ///< ID of each participating thread -+ int thread_id) -+ : PredicatedScaleBiasVectorAccessIterator(params, problem_size, -+ scale_pointer, bias_pointer, -+ thread_id, make_Coord(0, 0)) {} -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Extent of tensor -+ Conv3dProblemSize const &problem_size, -+ /// Pointer to start of scale vector -+ ConstPointer scale_pointer, -+ /// Pointer to start of scale vector -+ ConstPointer bias_pointer, -+ ///< ID of each participating thread -+ int thread_id) -+ : PredicatedScaleBiasVectorAccessIterator(params, problem_size, -+ scale_pointer, bias_pointer, -+ thread_id, make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) {} -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole threadblock tiles -+ CUTLASS_DEVICE -+ void add_tile_offset( -+ TensorCoord const &tile_offset) { -+ thread_offset_ = -+ thread_offset_ + -+ TensorCoord(ThreadblockShape::kContiguous * tile_offset.contiguous(), 0); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ -+ return reinterpret_cast( -+ pointer_ + -+ (thread_offset_.contiguous() * sizeof_bits::value / 8)); -+ } -+ -+ /// Increment and return an instance to self. -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIterator &operator++() { -+ return *this; -+ } -+ -+ /// Increment and return an instance to self. -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ // moves to the next tile -+ ++filter_trs_; -+ if (filter_trs_ == problem_size_trs) { -+ filter_trs_ = 0; -+ add_tile_offset(TensorCoord(1, 0)); -+ } -+ } -+ -+ /// Increment and return an instance to self. -+ CUTLASS_DEVICE -+ PredicatedScaleBiasVectorAccessIterator operator++(int) { -+ PredicatedScaleBiasVectorAccessIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ uint32_t enabled = 0; -+ -+#if defined(_MSC_VER) || (__CUDACC_VER_MAJOR__ < 11) -+ enabled = threadIdx.x < kThreads * 2; -+#else -+ asm volatile( -+ "{\n" -+ " .reg .u32 tid_reg;\n" -+ " .reg .pred p;\n" -+ " mov.u32 tid_reg, %%tid.x;\n" -+ " setp.lt.u32 p, tid_reg, %1;\n" -+ " selp.u32 %0, 1, 0, p;\n" -+ "}\n" : "+r"(enabled) :"n"(kThreads * 2)); -+#endif -+ -+ return ((thread_offset_.contiguous() < problem_size_c) && enabled); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator for row-major data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class PredicatedScaleBiasVectorAccessIterator { -+ public: -+ -+ using ThreadblockShape = ThreadblockShape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ConstPointer = const Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedScaleBiasVectorAccessIterator< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ static int const kElementsPerAccess = UnderlyingIterator::kElementsPerAccess; -+ -+ using Params = PredicatedScaleBiasVectorAccessIteratorParams; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIterator( -+ ///< Precomputed parameters object -+ Params const ¶ms, -+ ///< Extent of tensor -+ Conv2dProblemSize const &problem_size, -+ ///< Pointer to the start of the scale vector -+ ConstPointer scale_pointer, -+ ///< Pointer to the start of the bias vector -+ ConstPointer bias_pointer, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : iterator_(params, problem_size, scale_pointer, bias_pointer, -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.column(), -+ threadblock_offset.row())) {} -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIterator( -+ ///< Precomputed parameters object -+ Params const ¶ms, -+ ///< Extent of tensor -+ Conv3dProblemSize const &problem_size, -+ ///< Pointer to the start of the scale vector -+ ConstPointer scale_pointer, -+ ///< Pointer to the start of the bias vector -+ ConstPointer bias_pointer, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : iterator_(params, problem_size, scale_pointer, bias_pointer, -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.column(), -+ threadblock_offset.row())) {} -+ -+ /// Construct a PredicatedTileAccessIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Conv2dProblemSize const &problem_size, ///< Extent of tensor -+ ConstPointer scale_pointer, ///< Pointer to the start of the scale vector -+ ConstPointer bias_pointer, ///< Pointer to the start of the bias vector -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedScaleBiasVectorAccessIterator(params, problem_size, -+ scale_pointer, bias_pointer, -+ thread_id, make_Coord(0, 0)) {} -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Conv3dProblemSize const &problem_size, ///< Extent of tensor -+ ConstPointer scale_pointer, ///< Pointer to the start of the scale vector -+ ConstPointer bias_pointer, ///< Pointer to the start of the bias vector -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedScaleBiasVectorAccessIterator(params, problem_size, -+ scale_pointer, bias_pointer, -+ thread_id, make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// threadblock tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIterator operator++(int) { -+ PredicatedScaleBiasVectorAccessIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Increment and return an instance to self. -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ iterator_.advance(); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return iterator_.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/predicated_scale_bias_vector_iterator.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/predicated_scale_bias_vector_iterator.h -new file mode 100644 -index 0000000..98b4c82 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/predicated_scale_bias_vector_iterator.h -@@ -0,0 +1,371 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Templates calculating the address and predicates to the load of scale and bias vectors. -+ -+ This iterator uses masks to guard out-of-bounds accesses. -+ -+ A precomputed "Params" object minimizes the amount of state that must be -+ stored in registers, and integer addition is used to advance the pointer -+ through memory. -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// PredicatedScaleBiasVectorIterator -+/// -+template -+class PredicatedScaleBiasVectorIterator; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIterator for wgrad pitch-linear data. -+/// -+template -+class PredicatedScaleBiasVectorIterator { -+ public: -+ -+ using WarpShape = WarpShape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ConstPointer = const Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ static int const kElementsPerAccess = 1; -+ -+ using AccessType = AlignedArray; -+ -+ static int const kIterations = WarpShape::kContiguous / 8; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array<__half2, 2 * kIterations * kElementsPerAccess>; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ using Params = Conv2dWgradActivationIteratorOptimizedParams; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Parameters object with precomputed internal state -+ Params const ¶ms_; -+ -+ /// Internal pointer to first access of tile -+ ConstPointer scale_pointer_; -+ ConstPointer bias_pointer_; -+ -+ /// Size of tensor -+ Conv2dProblemSize problem_size_; -+ -+ int32_t thread_offset_; -+ -+ // Channel dimension in contiguous dimension stays constant for each gemm_iteration_k -+ int32_t filter_c_[kIterations]; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Extent of tensor -+ Conv2dProblemSize const &problem_size, -+ /// Pointer to the start of the scale vector -+ ConstPointer scale_pointer, -+ /// Pointer to the start of the bias vector -+ ConstPointer bias_pointer, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : params_(params), -+ problem_size_(problem_size), -+ scale_pointer_(scale_pointer), -+ bias_pointer_(bias_pointer) { -+ -+ thread_offset_ = threadblock_offset.contiguous() + (thread_id % 32) / 4; -+ } -+ -+ /// Construct a PredicatedTileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Extent of tensor -+ Conv2dProblemSize const &problem_size, -+ /// Pointer to start of scale vector -+ ConstPointer scale_pointer, -+ /// Pointer to start of scale vector -+ ConstPointer bias_pointer, -+ ///< ID of each participating thread -+ int thread_id) -+ : PredicatedScaleBiasVectorIterator(params, problem_size, -+ scale_pointer, bias_pointer, -+ thread_id, make_Coord(0, 0)) {} -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole warp tiles -+ CUTLASS_DEVICE -+ void add_tile_offset( -+ TensorCoord const &tile_offset) { -+ -+ thread_offset_ += (WarpShape::kContiguous * tile_offset.contiguous()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int c = 0; c < kIterations; ++c) { -+ int rsc_offset = thread_offset_ + c * 8; -+ -+ int residual, tmp; -+ params_.sc_divmod(tmp, residual, rsc_offset); -+ params_.c_divmod(tmp, filter_c_[c], residual); -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ -+ frag.fill(__float2half2_rn(0.0f)); -+ __half2 *frag_ptr = reinterpret_cast<__half2 *>(&frag); -+ -+ // load scale -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < kIterations; ++c) { -+ -+ cutlass::arch::global_load< -+ __half, -+ sizeof(AccessType) -+ >( -+ frag_ptr[c * 2].x, -+ scale_pointer_ + filter_c_[c], -+ true -+ ); -+ } -+ -+ // load bias -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < kIterations; ++c) { -+ -+ cutlass::arch::global_load< -+ __half, -+ sizeof(AccessType) -+ >( -+ frag_ptr[c * 2 + 1].x, -+ bias_pointer_ + filter_c_[c], -+ true -+ ); -+ } -+ -+ // duplicate scale -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < kIterations; ++c) { -+ frag_ptr[c * 2].y = frag_ptr[c * 2].x; -+ } -+ -+ // duplicate bias -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < kIterations; ++c) { -+ frag_ptr[c * 2 + 1].y = frag_ptr[c * 2 + 1].x; -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIterator for row-major data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class PredicatedScaleBiasVectorIterator { -+ public: -+ -+ using WarpShape = WarpShape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ConstPointer = const Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedScaleBiasVectorIterator< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ static int const kElementsPerAccess = UnderlyingIterator::kElementsPerAccess; -+ using Fragment = typename UnderlyingIterator::Fragment; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedScaleBiasVectorIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Conv2dProblemSize const &problem_size, Layout const &layout) -+ : params_(problem_size, layout::TensorNHWC(0, 0, 0)){}; -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorIterator( -+ ///< Precomputed parameters object -+ Params const ¶ms, -+ ///< Extent of tensor -+ Conv2dProblemSize const &problem_size, -+ ///< Pointer to the start of the scale vector -+ ConstPointer scale_pointer, -+ ///< Pointer to the start of the bias vector -+ ConstPointer bias_pointer, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : iterator_(params.params_, problem_size, scale_pointer, bias_pointer, -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.column(), -+ threadblock_offset.row())) {} -+ -+ /// Construct a PredicatedTileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Conv2dProblemSize const &problem_size, ///< Extent of tensor -+ ConstPointer scale_pointer, ///< Pointer to the start of the scale vector -+ ConstPointer bias_pointer, ///< Pointer to the start of the bias vector -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedScaleBiasVectorIterator(params, problem_size, -+ scale_pointer, bias_pointer, -+ thread_id, make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// threadblock tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ iterator_.load(frag); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/conv/threadblock/threadblock_swizzle.h b/3rdparty/cutlass/include/cutlass/conv/threadblock/threadblock_swizzle.h -new file mode 100644 -index 0000000..0ed0b24 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/threadblock/threadblock_swizzle.h -@@ -0,0 +1,193 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Implements several possible threadblock-swizzling functions mapping blockIdx to -+ Convolution problems. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/platform/platform.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+CUTLASS_HOST_DEVICE -+static int get_strided_dgrad_tile_m( -+ cutlass::conv::Conv2dProblemSize const &problem_size, -+ int tile_size_m) { -+ -+ // CTAs in M dimension per starting filter position -+ int tile_m_per_filter = strided_dgrad_tile_m_per_filter(problem_size, tile_size_m); -+ -+ // Inflate number of CTAs in M dimension to cover every strating filter position even those that -+ // may fall out of valid MMA (Dy * w) but are needed to apply epilogue (beta * Dx_source) -+ // and point-wise fusion -+ int tile_m = tile_m_per_filter * int(problem_size.stride().product()); -+ -+ // There is a possible performance optimization here that leads up to 2x speeds than the current -+ // CUTLASS strided dgrad performance for stride > filter, i.e., stride={2x2} and filter={1x1}) -+ // -+ // * Optimization * -+ // Only launch CTAs in M dimenstion which contribute to a row in Dx output -+ // -+ // -+ // * Constraints * -+ // (A) stride <= filter, for example, stride={2x2} and filter={3x3}: -+ // - (A.1): There are no constraints for this case and the optimization does -+ // affect this case functionality or performance. -+ // (B) stride > filter, for example, stride={2x2} and filter={1x1}: -+ // - (B.1): Dx output tensor should be zero initialized -+ // - (B.2): The kernel epilogue cannot apply beta. Thus, beta should be zero -+ -+ return tile_m; -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Threadblock swizzling function for strided dgrad convolution -+struct StridedDgradHorizontalThreadblockSwizzle : -+ public gemm::threadblock::GemmHorizontalThreadblockSwizzle { -+ -+ using Base = gemm::threadblock::GemmHorizontalThreadblockSwizzle; -+ -+ CUTLASS_HOST_DEVICE -+ StridedDgradHorizontalThreadblockSwizzle() { } -+ -+ /// Returns the shape of the problem in units of logical tiles -+ /// For ImplicitGemmConvolution Conv2d problem size: conv_operator(NPQK, NHWC, KRSC) -+ CUTLASS_HOST_DEVICE -+ gemm::GemmCoord get_tiled_shape( -+ cutlass::conv::Operator conv_operator, -+ cutlass::conv::Conv2dProblemSize const &problem_size, -+ gemm::GemmCoord tile_size, -+ int split_k_slices) const { -+ -+ gemm::GemmCoord implicit_gemm_problem_size = -+ cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size); -+ -+ // compute number of tiles in m dimension -+ int tile_m = get_strided_dgrad_tile_m(problem_size, tile_size.m()); -+ -+ // compute number of tiles in n dimenstion -+ int tile_n = (implicit_gemm_problem_size.n() + tile_size.n() - 1) / tile_size.n(); -+ -+ return gemm::GemmCoord( -+ tile_m, -+ tile_n, -+ split_k_slices); -+ } -+ -+ /// Returns the shape of the problem in units of logical tiles -+ /// For GEMM problem size (MxNxK) (Do not use base class get_tiled_shape()) -+ private: -+ using Base::get_tiled_shape; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Threadblock swizzling function for strided dgrad convolution -+template -+struct StridedDgradIdentityThreadblockSwizzle : -+ public gemm::threadblock::GemmIdentityThreadblockSwizzle { -+ -+ using Base = gemm::threadblock::GemmIdentityThreadblockSwizzle; -+ -+ CUTLASS_HOST_DEVICE -+ StridedDgradIdentityThreadblockSwizzle() { } -+ -+ /// Returns the shape of the problem in units of logical tiles -+ /// For ImplicitGemmConvolution Conv2d problem size: conv_operator(NPQK, NHWC, KRSC) -+ CUTLASS_HOST_DEVICE -+ gemm::GemmCoord get_tiled_shape( -+ cutlass::conv::Operator conv_operator, -+ cutlass::conv::Conv2dProblemSize const &problem_size, -+ gemm::GemmCoord tile_size, -+ int split_k_slices) const { -+ -+ gemm::GemmCoord implicit_gemm_problem_size = -+ cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size); -+ -+ // compute number of tiles in m dimension -+ int tile_m = get_strided_dgrad_tile_m(problem_size, tile_size.m()); -+ -+ // compute number of tiles in n dimenstion -+ int tile_n = (implicit_gemm_problem_size.n() + tile_size.n() - 1) / tile_size.n(); -+ -+ return gemm::GemmCoord( -+ tile_m, -+ tile_n, -+ split_k_slices); -+ } -+ -+ /// Returns the shape of the problem in units of logical tiles -+ /// For GEMM problem size (MxNxK) (Do not use base class get_tiled_shape()) -+ private: -+ using Base::get_tiled_shape; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Threadblock swizzling function for GEMMs -+template -+struct DepthwiseDirect2dConvIdentityThreadblockSwizzle -+ : public gemm::threadblock::GemmIdentityThreadblockSwizzle { -+ CUTLASS_HOST_DEVICE -+ DepthwiseDirect2dConvIdentityThreadblockSwizzle() {} -+ -+ /// Returns the shape of the problem in units of logical tiles -+ CUTLASS_HOST_DEVICE -+ gemm::GemmCoord get_tiled_shape(cutlass::conv::Operator conv_operator, -+ cutlass::conv::Conv2dProblemSize const &problem_size, -+ gemm::GemmCoord tile_size, -+ int split_k_slices) const { -+ -+ gemm::GemmCoord implicit_gemm_problem_size = -+ cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size); -+ -+ return gemm::GemmCoord(1, -+ (implicit_gemm_problem_size.n() + tile_size.n() - 1) / tile_size.n(), -+ split_k_slices); -+ } -+}; -+ -+} // namespace threadblock -+} // namespace conv -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/conv/warp/mma_depthwise_simt.h b/3rdparty/cutlass/include/cutlass/conv/warp/mma_depthwise_simt.h -new file mode 100644 -index 0000000..ae49cc1 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/warp/mma_depthwise_simt.h -@@ -0,0 +1,381 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing warp-level matrix multiply-accumulate operations. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma.h" -+ -+#include "cutlass/gemm/thread/mma.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/thread/depthwise_mma.h" -+ -+ -+#include "cutlass/gemm/warp/mma_simt_tile_iterator.h" -+#include "cutlass/gemm/warp/mma_simt_policy.h" -+ -+#include "cutlass/gemm/warp/mma_simt.h" -+#include "cutlass/conv/warp/mma_depthwise_simt_tile_iterator.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Data type of A elements -+ typename ElementA_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename ElementB_, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename ElementC_, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Shape of the warp in units of thread (concept: MmaSimtPolicy) -+ typename Policy_, -+ /// Number of partitions along K dimension -+ int PartitionsK = 1, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ /// Used for partial specialization -+ typename Enable = bool> -+class MmaDepthwiseSimt -+ : public cutlass::gemm::warp:: -+ MmaSimt { -+ using Base = cutlass::gemm::warp:: -+ MmaSimt; -+ -+public: -+ /// Shape of warp-level matrix operation (concept: GemmShape) -+ using Shape = Shape_; -+ -+ /// Data type of multiplicand A -+ using ElementA = ElementA_; -+ -+ /// Layout of multiplicand A -+ using LayoutA = LayoutA_; -+ -+ /// Data type of multiplicand B -+ using ElementB = ElementB_; -+ -+ /// Layout of multiplicand B -+ using LayoutB = LayoutB_; -+ -+ /// Data type of accumulator matrix C -+ using ElementC = ElementC_; -+ -+ /// Layout of accumulator matrix C -+ using LayoutC = LayoutC_; -+ -+ /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) -+ using Policy = Policy_; -+ -+ /// Indicates class of matrix operator -+ using OperatorClass = arch::OpClassSimt; -+ -+ /// Hard-coded for now -+ using ArchTag = arch::Sm50; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = TransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = TransformB; -+ -+public: -+ -+ /// Iterates over the B operand in memory -+ using IteratorB = cutlass::conv::warp::DepthwiseMmaSimtTileIterator< -+ MatrixShape, -+ cutlass::gemm::Operand::kB, -+ ElementB, -+ LayoutB, -+ Policy, -+ PartitionsK, -+ Shape::kK -+ >; -+ -+ /// Storage for B tile -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Storage for transformed A tile -+ using TransformedFragmentB = FragmentB; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_DEVICE -+ MmaDepthwiseSimt():Base() {} -+}; -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Shape of filter shape per threadblock - concept: gemm::GemmShape -+ typename FilterShape_, -+ /// Shape of the output tile computed by thread- concept: conv::TensorNHWCShape<> -+ typename ThreadOutputShape_, -+ /// Shape of the output tile computed by threadblock - concept: conv::TensorNHWCShape<> -+ typename ThreadBlockOutputShape_, -+ /// Data type of A elements -+ typename ElementA_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename ElementB_, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename ElementC_, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Shape of the warp in units of thread (concept: MmaSimtPolicy) -+ typename Policy_, -+ /// Iterator algo type -+ conv::IteratorAlgorithm IteratorAlgorithm_ = IteratorAlgorithm::kAnalytic, -+ /// Stride ( MatrixShape ) -+ typename StrideShape_ = cutlass::MatrixShape<-1, -1>, -+ /// Dilation ( MatrixShape ) -+ typename DilationShape_ = cutlass::MatrixShape<-1, -1>, -+ /// Activation Shape loaded by threadblock -+ typename ActivationShape_ = cutlass::conv::TensorNHWCShape<-1,-1,-1,-1>, -+ /// Number of partitions along K dimension -+ int PartitionsK = 1, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ /// Used for partial specialization -+ typename Enable = bool> -+class MmaDepthwiseDirectConvSimt { -+ public: -+ /// Shape of warp-level matrix operation (concept: GemmShape) -+ using Shape = Shape_; -+ -+ /// Shape of filter shape per threadblock - concept: gemm::GemmShape -+ using FilterShape = FilterShape_; -+ -+ /// Shape of the output tile computed by thread- concept: conv::TensorNHWCShape<> -+ using ThreadOutputShape = ThreadOutputShape_; -+ -+ /// Shape of the output tile computed by threadblock - concept: conv::TensorNHWCShape<> -+ using ThreadBlockOutputShape = ThreadBlockOutputShape_; -+ -+ /// Data type of multiplicand A -+ using ElementA = ElementA_; -+ -+ /// Layout of multiplicand A -+ using LayoutA = LayoutA_; -+ -+ /// Data type of multiplicand B -+ using ElementB = ElementB_; -+ -+ /// Layout of multiplicand B -+ using LayoutB = LayoutB_; -+ -+ /// Data type of accumulator matrix C -+ using ElementC = ElementC_; -+ -+ /// Layout of accumulator matrix C -+ using LayoutC = LayoutC_; -+ -+ /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) -+ using Policy = Policy_; -+ -+ /// Iterator algo type -+ static conv::IteratorAlgorithm const IteratorAlgorithm = IteratorAlgorithm_; -+ -+ /// Stride ( MatrixShape ) -+ using StrideShape = StrideShape_; -+ -+ /// Dilation ( MatrixShape ) -+ using DilationShape = DilationShape_; -+ -+ /// Activation Shape loaded by threadblock -+ using ActivationShape = ActivationShape_; -+ -+ /// Indicates class of matrix operator -+ using OperatorClass = arch::OpClassSimt; -+ -+ /// Hard-coded for now -+ using ArchTag = arch::Sm50; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = TransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = TransformB; -+ -+ static constexpr bool use_dp4a = (platform::is_same< layout::ColumnMajorInterleaved<4>, LayoutA>::value || -+ platform::is_same< layout::RowMajorInterleaved<4>, LayoutA >::value) && -+ platform::is_same< ElementA, int8_t >::value && -+ platform::is_same< ElementB, int8_t >::value; -+ -+ using dp4a_type = typename platform::conditional< use_dp4a , int8_t, bool >::type; -+ -+ /// Thread-level matrix multiply accumulate operator -+ using ThreadMma = cutlass::conv::thread::DepthwiseDirectConvElementwiseInnerProduct< -+ cutlass::gemm::GemmShape< -+ Shape::kM / Policy::WarpShape::kRow, // number of output pixels proccessed per thread -+ Shape::kN / Policy::WarpShape::kColumn, // number of channels proccessed per thread -+ 1>, -+ ElementA, -+ ElementB, -+ ElementC, -+ arch::OpMultiplyAdd, -+ dp4a_type -+ >; -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ using ArchMmaOperator = typename ThreadMma::ArchMmaOperator; -+ -+ /// Indicates math operator -+ using MathOperator = typename ArchMmaOperator::Operator; -+ -+ /// Shape of the underlying instruction -+ using InstructionShape = cutlass::gemm::GemmShape<1,1,use_dp4a ? 4 : 1>; -+ -+public: -+ -+ /// Iterates over the A operand in memory -+ using IteratorA = cutlass::conv::warp::DepthwiseDirect2dConvSimtTileIterator< -+ MatrixShape, // per warp -+ FilterShape, -+ ThreadOutputShape, -+ ThreadBlockOutputShape, -+ cutlass::gemm::Operand::kA, -+ ElementA, -+ Policy, -+ IteratorAlgorithm, -+ StrideShape, -+ DilationShape, -+ ActivationShape, -+ PartitionsK, -+ Shape::kK -+ >; -+ -+ /// Storage for A tile -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Storage for transformed A tile -+ using TransformedFragmentA = FragmentA; -+ -+ /// Iterates over the B operand in memory -+ using IteratorB = cutlass::gemm::warp::MmaSimtTileIterator< -+ MatrixShape<1, Shape::kN>, -+ cutlass::gemm::Operand::kB, -+ ElementB, -+ LayoutB, -+ Policy, -+ PartitionsK, -+ Shape::kK -+ >; -+ -+ /// Storage for B tile -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Storage for transformed A tile -+ using TransformedFragmentB = FragmentB; -+ -+ /// Iterates over the C operand in memory -+ using IteratorC = cutlass::gemm::warp::MmaSimtTileIterator< -+ MatrixShape, -+ cutlass::gemm::Operand::kC, -+ ElementC, -+ LayoutC, -+ Policy -+ >; -+ -+ /// Storage for C tile -+ using FragmentC = typename ThreadMma::FragmentC; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_DEVICE -+ MmaDepthwiseDirectConvSimt() {} -+ -+ /// Performs a warp-level matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA a, -+ FragmentB b, -+ FragmentC const &c, int group_idx = 0) const { -+ -+ ThreadMma mma; -+ -+ mma(d, a, b, c); -+ } -+ -+ /// Transform the mma operands to the required types -+ CUTLASS_DEVICE -+ void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, -+ FragmentA const &A, FragmentB const &B) const { -+ //TODO: Implement this -+ dst_A = A; -+ dst_B = B; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace conv -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/conv/warp/mma_depthwise_simt_tile_iterator.h b/3rdparty/cutlass/include/cutlass/conv/warp/mma_depthwise_simt_tile_iterator.h -new file mode 100644 -index 0000000..b750a4b ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/warp/mma_depthwise_simt_tile_iterator.h -@@ -0,0 +1,862 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Describes the lane policy used by warp-level matrix multiply operators targeting SIMT -+ instructions -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/conv/convolution.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+ -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma_simt_policy.h" -+#include "cutlass/gemm/warp/mma_simt_tile_iterator.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Iterates over operands to warp-level matrix multiply operations targeting SIMT instructions -+/// -+/// concept: MutableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Operand identity -+ cutlass::gemm::Operand Operand, -+ /// Data type of A elements -+ typename Element_, -+ /// Layout of operand -+ typename Layout_, -+ /// Shape of the warp in units of thread (concept: MmaSimtPolicy) -+ typename Policy_, -+ /// Number of partitions along K dimension - used in sliced-K -+ int PartitionsK = 1, -+ /// Group Size along kPartition - used in sliced-K -+ int PartitionGroupSize = 1 -+> -+class DepthwiseMmaSimtTileIterator; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for B operands of row-major layouts -+/// -+/// Concept: MutableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of A elements -+ typename Element_, -+ /// Shape of the warp in units of thread (concept: MmaSimtPolicy) -+ typename Policy_, -+ /// Number of partitions along K dimension -+ int PartitionsK, -+ /// Group Size along kPartition - used in sliced-K -+ int PartitionGroupSize> -+class DepthwiseMmaSimtTileIterator -+ : public cutlass::gemm::warp::MmaSimtTileIterator { -+ -+ using Base = cutlass::gemm::warp::MmaSimtTileIterator; -+ public: -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static cutlass::gemm::Operand const kOperand = cutlass::gemm::Operand::kB; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of policy -+ using Layout = layout::RowMajor; -+ -+ /// Decomposition of elements among threads -+ using Policy = Policy_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = typename Base::TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Thread-level shape of a fragment -+ using ThreadShape = typename Base::ThreadShape; -+ -+ /// Number of individual loads -+ using Iterations = typename Base::Iterations; -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+ static_assert(Policy::LaneMmaShape::kN == 1, "Each thread should be 1 element per LDS along the k-dim"); -+ -+private: -+ -+ MatrixCoord lane_offset_; -+ int channel_idx_; -+ int base_channel_idx_; -+ int warps_n_; -+ -+ public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ DepthwiseMmaSimtTileIterator():Base() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ DepthwiseMmaSimtTileIterator( -+ TensorRef ref, -+ int lane_id -+ ) : Base(ref, lane_id) { -+ -+ // compute offset based on thread ID and lane layout -+ typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); -+ -+ warps_n_ = -1; -+ channel_idx_ = 0; -+ base_channel_idx_ = 0; -+ lane_offset_ = lane_layout.inverse(lane_id) * MatrixCoord(0, Policy::LaneMmaShape::kN); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ DepthwiseMmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) { -+ -+ if(warps_n_ == -1){ -+ warps_n_ = coord.column(); -+ } -+ -+ Base::add_tile_offset(coord); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. (vector loads) -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ Array *dst_ptr = -+ reinterpret_cast *>(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Iterations::kRow; ++k) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Iterations::kColumn; ++n) { -+ -+ void const *ptr = this->ref_.data() + -+ this->ref_.offset({-(channel_idx_ - base_channel_idx_), -+ n * Policy::WarpShape::kColumn}) + -+ pointer_offset / Policy::LaneMmaShape::kN; -+ -+ // Base_k of a warp + Base_k of current threads. -+ int thread_k_base_idx = -+ warps_n_ * Shape::kColumn / Policy::LaneMmaShape::kN + lane_offset_.column(); -+ -+ if (channel_idx_ + k == thread_k_base_idx + n * Policy::WarpShape::kColumn) { -+ // Depthwise kernel would only do computation when channel == k. -+ // Loads an element when the current computation channel == the k corresponding to this thread. -+ arch::shared_load(dst_ptr[n + k * Iterations::kColumn], ptr); -+ } else { -+ // Reduce SMEM load -+ dst_ptr[n + k * Iterations::kColumn].fill(Element(0)); -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ if(k_group % PartitionGroupSize == 0 && k_group != 0){ -+ base_channel_idx_ = k_group; -+ } -+ channel_idx_ = k_group; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Size of filter (concept: gemm::GemmShape) -+ typename FilterShape_, -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename ThreadOutputShape_, -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename ThreadBlockOutputShape_, -+ /// Operand identity -+ cutlass::gemm::Operand Operand, -+ /// Data type of A elements -+ typename Element_, -+ /// Shape of the warp in units of thread (concept: MmaSimtPolicy) -+ typename Policy_, -+ /// Iterator algo type -+ conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, -+ /// Stride ( MatrixShape ) -+ typename StrideShape = cutlass::MatrixShape<-1, -1>, -+ /// Dilation ( MatrixShape ) -+ typename DilationShape = cutlass::MatrixShape<-1, -1>, -+ /// Activation Shape loaded by threadblock -+ typename ActivationShape = cutlass::conv::TensorNHWCShape<-1,-1,-1,-1>, -+ /// Number of partitions along K dimension - used in sliced-K -+ int PartitionsK = 1, -+ /// Group Size along kPartition - used in sliced-K -+ int PartitionGroupSize = 1> -+class DepthwiseDirect2dConvSimtTileIterator; -+ -+ -+/// Specialization for A operands of row-major layouts -+/// -+/// Concept: MutableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Size of filter (concept: gemm::GemmShape) -+ typename FilterShape_, -+ /// Size of the matrix to load (concept: TensorNHWC) -+ typename ThreadOutputShape_, -+ /// Size of the matrix to load (concept: TensorNHWC) -+ typename ThreadBlockOutputShape_, -+ /// Data type of A elements -+ typename Element_, -+ /// Shape of the warp in units of thread (concept: MmaSimtPolicy) -+ typename Policy_, -+ /// Iterator algo type -+ conv::IteratorAlgorithm IteratorAlgorithm, -+ /// Stride ( MatrixShape ) -+ typename StrideShape, -+ /// Dilation ( MatrixShape ) -+ typename DilationShape, -+ /// Activation Shape loaded by threadblock -+ typename ActivationShape, -+ /// Number of partitions along K dimension - used in sliced-K -+ int PartitionsK, -+ /// Group Size along kPartition - used in sliced-K -+ int PartitionGroupSize> -+class DepthwiseDirect2dConvSimtTileIterator { -+ public: -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Shape of filter (concept: gemm::GemmShape) -+ using FilterShape = FilterShape_; -+ -+ /// Shape of tile to load (concept: TensorNHWC) -+ using ThreadOutputShape = ThreadOutputShape_; -+ -+ /// Shape of tile to load (concept: TensorNHWC) -+ using ThreadBlockOutputShape = ThreadBlockOutputShape_; -+ -+ /// Operand tag -+ static cutlass::gemm::Operand const kOperand = cutlass::gemm::Operand::kA; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of policy -+ using Layout = layout::RowMajor; -+ -+ /// Decomposition of elements among threads -+ using Policy = Policy_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ // -+ // Derived quantities -+ // -+ -+ static_assert(!(Shape::kRow % Policy::WarpShape::kRow), -+ "The warp-level GEMM M size must be divisible by the number of threads arranged along the M dimension."); -+ -+ static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero."); -+ static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero."); -+ static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero."); -+ static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero."); -+ -+// Thread-level shape of a fragment -+ using ThreadShape = MatrixShape< -+ ThreadOutputShape::kNHW, // Output tile shape Computed by current threads -+ ThreadOutputShape::kC -+ >; -+ -+ static_assert(!(ThreadShape::kColumn % Policy::LaneMmaShape::kN), -+ "Thread-level GEMM must be divisible by Policy::LaneMmaShape."); -+ -+ /// Number of individual loads -+ using Iterations = MatrixShape< -+ ThreadShape::kRow, -+ ThreadShape::kColumn / Policy::LaneMmaShape::kN -+ >; -+ -+ using ThreadTileCount = MatrixShape< -+ ThreadBlockOutputShape::kH / ThreadOutputShape::kH, -+ ThreadBlockOutputShape::kW / ThreadOutputShape::kW -+ >; -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array; -+ -+protected: -+ -+ /// Internal reference -+ cutlass::TensorRef, layout::RowMajor> ref_; -+ -+ int activation_offset[ThreadOutputShape::kH][ThreadOutputShape::kW][Iterations::kColumn]; -+ int iterator_r_; -+ int iterator_s_; -+ int iterator_offset_; -+ -+ int inc_next_s_ ; -+ int inc_next_r_ ; -+ -+ MatrixCoord lane_offset_; -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ DepthwiseDirect2dConvSimtTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ DepthwiseDirect2dConvSimtTileIterator( -+ TensorRef ref, -+ int lane_id -+ ) { -+ -+ // compute offset based on thread ID and lane layout -+ typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); -+ -+ // Set channel offset -+ lane_offset_ = lane_layout.inverse(lane_id) * MatrixCoord(0, Policy::LaneMmaShape::kN); -+ -+ ref.add_coord_offset(lane_offset_); -+ -+ ref_.reset(reinterpret_cast *>(ref.data()), -+ ref.stride(0) / Policy::LaneMmaShape::kN); -+ -+ iterator_r_ = 0; -+ iterator_s_ = 0; -+ iterator_offset_ = 0; -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ DepthwiseDirect2dConvSimtTileIterator &add_pointer_offset(LongIndex offset) { -+ ref_.add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ template -+ CUTLASS_HOST_DEVICE -+ void setup_initial_status(Params const& params) { -+ -+ inc_next_s_ = params.inc_next[0]; -+ inc_next_r_ = params.inc_next[1]; -+ -+ // Get base HW offset of current threads -+ int threadgroup = threadIdx.x / (ThreadBlockOutputShape::kC / ThreadOutputShape::kC); -+ int base_p_ = -+ (threadgroup / (ThreadTileCount::kColumn)) * ThreadOutputShape::kH; -+ int base_q_ = -+ (threadgroup % (ThreadTileCount::kColumn)) * ThreadOutputShape::kW; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int p = 0; p < ThreadOutputShape::kH; ++p) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int q = 0; q < ThreadOutputShape::kW; ++q) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int col = 0; col < Iterations::kColumn; ++col) { -+ int base_w = (base_q_ + q) * params.stride[0]; -+ int base_h = (base_p_ + p) * params.stride[1]; -+ -+ int offset = base_h * params.activation_tile_w + base_w; -+ activation_offset[p][q][col] = offset; -+ } -+ } -+ } -+ } -+ -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ DepthwiseDirect2dConvSimtTileIterator &add_tile_offset(TensorCoord const &coord) { -+ // Set warp row and col start -+ lane_offset_ = MatrixCoord({lane_offset_.row() + coord.row() * Shape::kRow, lane_offset_.column()}); -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ void advance(int32_t pointer_offset) { -+ ref_.reset(ref_.data() + pointer_offset / sizeof(Element) / Policy::LaneMmaShape::kN); -+ iterator_s_ = 0; -+ iterator_r_ = 0; -+ iterator_offset_ = 0; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ DepthwiseDirect2dConvSimtTileIterator &operator++() { -+ ++iterator_s_; -+ if (iterator_s_ < FilterShape::kColumn) { -+ iterator_offset_ += inc_next_s_; -+ -+ return *this; -+ } -+ -+ iterator_s_ = 0; -+ -+ ++iterator_r_; -+ if (iterator_r_ < FilterShape::kRow) { -+ iterator_offset_ += inc_next_r_; -+ return *this; -+ } -+ -+ iterator_r_ = 0; -+ iterator_offset_ = 0; -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ DepthwiseDirect2dConvSimtTileIterator & operator--() { -+ // Do nothing -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. (vector loads) -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ -+ Array *dst_ptr = -+ reinterpret_cast *>(&frag); -+ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int p = 0; p < ThreadOutputShape::kH; ++p) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int q = 0; q < ThreadOutputShape::kW; ++q) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Iterations::kColumn; ++n) { -+ void const *ptr = ref_.data() + -+ ref_.offset({activation_offset[p][q][n] + (iterator_offset_), -+ n * Policy::WarpShape::kColumn}) + -+ pointer_offset / Policy::LaneMmaShape::kN; -+ arch::shared_load(dst_ptr[n + q + p * ThreadOutputShape::kW], ptr); -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { -+ // Do nothing at present. -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag, Index pointer_offset) const { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ // no operation here -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/// Specialization for A operands of row-major layouts -+/// -+/// Concept: MutableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Size of filter (concept: gemm::GemmShape) -+ typename FilterShape_, -+ /// Size of the matrix to load (concept: TensorNHWC) -+ typename ThreadOutputShape_, -+ /// Size of the matrix to load (concept: TensorNHWC) -+ typename ThreadBlockOutputShape_, -+ /// Data type of A elements -+ typename Element_, -+ /// Shape of the warp in units of thread (concept: MmaSimtPolicy) -+ typename Policy_, -+ /// Stride ( MatrixShape ) -+ typename StrideShape_, -+ /// Dilation ( MatrixShape ) -+ typename DilationShape_, -+ /// Activation Shape loaded by threadblock -+ typename ActivationShape_, -+ /// Number of partitions along K dimension - used in sliced-K -+ int PartitionsK, -+ /// Group Size along kPartition - used in sliced-K -+ int PartitionGroupSize> -+class DepthwiseDirect2dConvSimtTileIterator { -+ public: -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Shape of filter (concept: gemm::GemmShape) -+ using FilterShape = FilterShape_; -+ -+ /// Shape of tile to load (concept: TensorNHWC) -+ using ThreadOutputShape = ThreadOutputShape_; -+ -+ /// Shape of tile to load (concept: TensorNHWC) -+ using ThreadBlockOutputShape = ThreadBlockOutputShape_; -+ -+ /// Stride ( MatrixShape ) -+ using StrideShape = StrideShape_; -+ -+ /// Dilation ( MatrixShape ) -+ using DilationShape = DilationShape_; -+ -+ /// Activation Shape loaded by threadblock -+ using ActivationShape = ActivationShape_; -+ -+ /// Operand tag -+ static cutlass::gemm::Operand const kOperand = cutlass::gemm::Operand::kA; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of policy -+ using Layout = layout::RowMajor; -+ -+ /// Decomposition of elements among threads -+ using Policy = Policy_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ // -+ // Derived quantities -+ // -+ -+ static_assert(!(Shape::kRow % Policy::WarpShape::kRow), -+ "The warp-level GEMM M size must be divisible by the number of threads arranged " -+ "along the M dimension."); -+ -+ static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero."); -+ static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero."); -+ static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero."); -+ static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, -+ "Shape::kRow / Policy::WarpShape::kRow must be greater than zero."); -+ -+ // Activations loaded by threadblock -+ static int const ThreadActivationShapeH = (ThreadOutputShape::kH - 1) * StrideShape::kRow + -+ (FilterShape::kRow - 1) * DilationShape::kRow + 1; -+ -+ static int const ThreadActivationShapeW = (ThreadOutputShape::kW - 1) * StrideShape::kColumn + -+ (FilterShape::kColumn - 1) * DilationShape::kColumn + 1; -+ -+ using ThreadActivationShape = cutlass::conv:: -+ TensorNHWCShape<1, ThreadActivationShapeH, ThreadActivationShapeW, ThreadOutputShape::kC>; -+ -+ // Thread-level shape of a fragment -+ using ThreadShape = -+ MatrixShape; -+ -+ static_assert(!(ThreadShape::kColumn % Policy::LaneMmaShape::kN), -+ "Thread-level GEMM must be divisible by Policy::LaneMmaShape."); -+ -+ /// Number of individual loads -+ using Iterations = -+ MatrixShape; -+ -+ using ThreadTileCount = MatrixShape; -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array; -+ -+ protected: -+ /// Internal reference -+ cutlass::TensorRef, layout::RowMajor> ref_; -+ -+ Array -+ activation[ThreadActivationShape::kH][ThreadActivationShape::kW][Iterations::kColumn]; -+ int iterator_r_; -+ int iterator_s_; -+ -+ -+ MatrixCoord lane_offset_; -+ -+ public: -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ DepthwiseDirect2dConvSimtTileIterator() {} -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ DepthwiseDirect2dConvSimtTileIterator(TensorRef ref, int lane_id) { -+ // compute offset based on thread ID and lane layout -+ typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); -+ -+ // Set channel offset -+ lane_offset_ = lane_layout.inverse(lane_id) * MatrixCoord(0, Policy::LaneMmaShape::kN); -+ -+ ref.add_coord_offset(lane_offset_); -+ -+ ref_.reset(reinterpret_cast *>(ref.data()), -+ ref.stride(0) / Policy::LaneMmaShape::kN); -+ -+ iterator_r_ = 0; -+ iterator_s_ = 0; -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ DepthwiseDirect2dConvSimtTileIterator &add_pointer_offset(LongIndex offset) { -+ ref_.add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ template -+ CUTLASS_HOST_DEVICE void setup_initial_status( -+ Params const ¶ms) { -+ -+ // Get base HW offset of current threads -+ int threadgroup = threadIdx.x / (ThreadBlockOutputShape::kC / ThreadOutputShape::kC); -+ int base_h = -+ (threadgroup / (ThreadTileCount::kColumn)) * ThreadOutputShape::kH * StrideShape::kRow; -+ int base_w = -+ (threadgroup % (ThreadTileCount::kColumn)) * ThreadOutputShape::kW * StrideShape::kColumn; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int h = 0; h < ThreadActivationShape::kH; ++h) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int w = 0; w < ThreadActivationShape::kW; ++w) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int col = 0; col < Iterations::kColumn; ++col) { -+ int offset = (base_h + h) * ActivationShape::kW + (base_w + w); -+ -+ void const *ptr = ref_.data() + ref_.offset({offset, col * Policy::WarpShape::kColumn}); -+ arch::shared_load(activation[h][w][col], ptr); -+ } -+ } -+ } -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ DepthwiseDirect2dConvSimtTileIterator &add_tile_offset(TensorCoord const &coord) { -+ // Set warp row and col start -+ lane_offset_ = -+ MatrixCoord({lane_offset_.row() + coord.row() * Shape::kRow, lane_offset_.column()}); -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ void advance(int32_t pointer_offset) { -+ ref_.reset(ref_.data() + pointer_offset / sizeof(Element) / Policy::LaneMmaShape::kN); -+ iterator_s_ = 0; -+ iterator_r_ = 0; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ DepthwiseDirect2dConvSimtTileIterator &operator++() { -+ ++iterator_s_; -+ if (iterator_s_ < FilterShape::kColumn) { -+ return *this; -+ } -+ -+ iterator_s_ = 0; -+ -+ ++iterator_r_; -+ if (iterator_r_ < FilterShape::kRow) { -+ return *this; -+ } -+ -+ iterator_r_ = 0; -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ DepthwiseDirect2dConvSimtTileIterator &operator--() { -+ // Do nothing -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. (vector loads) -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ Array *dst_ptr = -+ reinterpret_cast *>(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int p = 0; p < ThreadOutputShape::kH; ++p) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int q = 0; q < ThreadOutputShape::kW; ++q) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Iterations::kColumn; ++n) { -+ const int h = p * StrideShape::kRow + iterator_r_ * DilationShape::kRow; -+ const int w = q * StrideShape::kColumn + iterator_s_ * DilationShape::kColumn; -+ -+ dst_ptr[n + q + p * ThreadOutputShape::kW] = activation[h][w][n]; -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { load_with_pointer_offset(frag, 0); } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { -+ // Do nothing at present. -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag, Index pointer_offset) const { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ // no operation here -+ } -+}; -+ -+} // namespace warp -+} // namespace conv -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/conv/warp/scale_bias_relu_transform.h b/3rdparty/cutlass/include/cutlass/conv/warp/scale_bias_relu_transform.h -new file mode 100644 -index 0000000..a1a4dff ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/conv/warp/scale_bias_relu_transform.h -@@ -0,0 +1,223 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing warp-level per channel scale+bias+relu before -+ matrix multiply-accumulate operations targeting Tensor Cores. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/platform/platform.h" -+ -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/arch/mma_sm75.h" -+#include "cutlass/arch/mma_sm80.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_policy.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace conv { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct FpropScaleBiasReluTransform { -+ -+ using T = typename FragmentActivations::Element; -+ -+ static int const NumActivations = FragmentActivations::kElements; -+ static int const NumScaleBias = FragmentScaleBias::kElements; -+ static int const MmaElements = 2; -+ // One element has one scale and one bias -+ static int const MmaScaleBiasPair = 2; -+ // 16816 has 2 columns -+ static int const MmaCols = 2; -+ -+ using MmaOperand = Array; -+ using ScaleBiasOperand = Array; -+ -+ CUTLASS_DEVICE -+ void transform(MmaOperand &activations, ScaleBiasOperand const &scale_bias) { -+ -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) -+ uint32_t *ptr_activations = reinterpret_cast(&activations); -+ uint32_t const *ptr_scale_bias = reinterpret_cast(&scale_bias); -+ -+ // Apply per channel scale+bias+relu if the data is not a special NaN -+ // (0x7eff). If it is a special NaN (0x7eff), hard code the output to 0. -+ -+ // We assumes the pair of FP16 are either both inbound or both out-of-bound. -+ // It requires C to be an even number. -+ asm volatile( -+ "{\n\t" -+ " .reg .pred %%p;\n\t" -+ " .reg .b32 t1;\n\t" -+ " setp.eq.u32 %%p, %2, %4;\n\t" -+ " fma.rn.f16x2.relu t1, %1, %2, %3;\n" -+ " selp.u32 %0, 0, t1, %%p;\n\t" -+ "}\n" -+ : "=r"(ptr_activations[0]) -+ : "r"(ptr_scale_bias[0]), "r"(ptr_activations[0]), -+ "r"(ptr_scale_bias[1]), "n"(cutlass::arch::OOB_NAN_F16x2)); -+#else -+ // TODO: write emulation code -+ assert(0); -+#endif -+ } -+ -+ CUTLASS_DEVICE -+ void operator()(FragmentActivations &activations, -+ FragmentScaleBias const &scale_bias) { -+ MmaOperand *ptr_activations = reinterpret_cast(&activations); -+ ScaleBiasOperand const *ptr_scale_bias = -+ reinterpret_cast(&scale_bias); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < (NumActivations / MmaElements); ++i) { -+ transform(ptr_activations[i], ptr_scale_bias[(i / MmaScaleBiasPair) % MmaCols]); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct WgradScaleBiasReluTransform { -+ -+ using T = typename FragmentActivations::Element; -+ -+ static int const NumActivations = FragmentActivations::kElements; -+ static int const NumScaleBias = FragmentScaleBias::kElements; -+ static int const MmaElements = 2; -+ // One element has one scale and one bias -+ static int const MmaScaleBiasPair = 2; -+ // 16816 has 2 rows -+ static int const MmaRows = 2; -+ -+ using MmaOperand = Array; -+ using ScaleBiasOperand = Array<__half2, MmaScaleBiasPair>; -+ -+ CUTLASS_DEVICE -+ void transform(MmaOperand &activations, ScaleBiasOperand const &scale_bias) { -+ -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) -+ -+ __half2 *ptr_activations = reinterpret_cast<__half2 *>(&activations); -+ uint32_t const *ptr_scale_bias = reinterpret_cast(&scale_bias); -+ -+#if 1 -+ // CUDA + PTX version -+ -+ bool h1_oob = (reinterpret_cast(ptr_activations[0].x) == cutlass::arch::OOB_NAN_F16); -+ bool h2_oob = (reinterpret_cast(ptr_activations[0].y) == cutlass::arch::OOB_NAN_F16); -+ -+ // Apply per channel scale+bias+relu if the data is not a special NaN -+ // (0x7eff). If it is a special NaN (0x7eff), hard code the output to 0. -+ -+ // We cannot gurantee that the pair of F16 are both in bound or both -+ // out-of-bound because C x R x S can be an odd number. -+ asm volatile( -+ "{\n\t" -+ " fma.rn.f16x2.relu %0, %1, %2, %3;\n" -+ "}" -+ : "=r"(reinterpret_cast(ptr_activations[0])) -+ : "r"(ptr_scale_bias[0]), "r"(reinterpret_cast(ptr_activations[0])), -+ "r"(ptr_scale_bias[1])); -+ -+ reinterpret_cast(ptr_activations[0]) = h1_oob ? -+ (reinterpret_cast(ptr_activations[0]) & 0xffff0000) : -+ reinterpret_cast(ptr_activations[0]); -+ -+ reinterpret_cast(ptr_activations[0]) = h2_oob ? -+ (reinterpret_cast(ptr_activations[0]) & 0xffff) : -+ reinterpret_cast(ptr_activations[0]); -+#else -+ // pure PTX version -+ -+ // Apply per channel scale+bias+relu if the data is not a special NaN -+ // (0x7eff). If it is a special NaN (0x7eff), hard code the output to 0. -+ asm volatile( -+ "{\n" -+ " .reg .b16 t1, t2;\n" -+ " .reg .b32 t3, t4, t5, t6;\n" -+ " .reg .pred p1, p2;\n" -+ " mov.b32 {t1, t2}, %2;\n" -+ " setp.eq.s16 p1, t1, %4;\n" -+ " setp.eq.s16 p2, t2, %4;\n" -+ " fma.rn.f16x2.relu t3, %1, %2, %3;\n" -+ " and.b32 t4, t3, %5;\n" -+ " selp.b32 t5, t4, t3, p1;\n" -+ " and.b32 t6, t5, %6;\n" -+ " selp.b32 %0, t6, t5, p2;\n" -+ "}\n" -+ : "=r"(reinterpret_cast(ptr_activations[0])) -+ : "r"(ptr_scale_bias[0]), "r"(reinterpret_cast(ptr_activations[0])), -+ "r"(ptr_scale_bias[1]), "n"(cutlass::arch::OOB_NAN_F16), "n"(0xffff0000), "n"(0x0000ffff)); -+#endif -+#else -+ // TODO: write emulation code -+ assert(0); -+#endif -+ } -+ -+ CUTLASS_DEVICE -+ void operator()(FragmentActivations &activations, -+ FragmentScaleBias const &scale_bias) { -+ MmaOperand *ptr_activations = reinterpret_cast(&activations); -+ ScaleBiasOperand const *ptr_scale_bias = -+ reinterpret_cast(&scale_bias); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < (NumActivations / MmaElements); ++i) { -+ transform(ptr_activations[i], ptr_scale_bias[(i / MmaRows)]); -+ } -+ } -+}; -+} // namespace warp -+} // namespace conv -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/coord.h b/3rdparty/cutlass/include/cutlass/coord.h -new file mode 100644 -index 0000000..4558385 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/coord.h -@@ -0,0 +1,480 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief A Coord is a coordinate of arbitrary rank into a tensor or matrix -+*/ -+ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Statically-sized array specifying Coords within a tensor -+template < -+ int Rank_, ///< Logical rank of coordinate -+ typename Index_ = int, ///< Index type used for each dimension -+ typename LongIndex_ = int64_t ///< Long index type used for linear offsets -+> -+struct Coord { -+ -+public: -+ -+ // -+ // Type and constant definitions -+ // -+ -+ /// Number of elements in Coord -+ static int const kRank = Rank_; -+ -+ /// Index type used to store elements -+ using Index = Index_; -+ -+ /// Type used to represent linear offsets -+ using LongIndex = LongIndex_; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Indices -+ Index idx[kRank]; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor initializes uniformly -+ CUTLASS_HOST_DEVICE -+ explicit Coord(Index value = Index(0)) { -+ for (int i = 0; i < kRank; ++i) { -+ idx[i] = value; -+ } -+ } -+ -+ /// Constructs from an array of integers -+ CUTLASS_HOST_DEVICE -+ Coord(Index const (&_idx)[kRank]) { -+ for (int i = 0; i < kRank; ++i) { -+ idx[i] = _idx[i]; -+ } -+ } -+ -+ /// Constructs from some other Coord -+ template -+ CUTLASS_HOST_DEVICE -+ Coord(Coord other) { -+ for (int i = 0; i < kRank; ++i) { -+ idx[i] = other[i]; -+ } -+ } -+ -+ /// Returns a slice of the Coord which may be larger or smaller in rank -+ /// than this. -+ template -+ CUTLASS_HOST_DEVICE -+ Coord slice(int start = 0, Index identity = 0) const { -+ Coord result; -+ for (int i = 0; i < Slice; ++i) { -+ if (i + start < kRank) { -+ result[i] = idx[i + start]; -+ } -+ else { -+ result[i] = identity; -+ } -+ } -+ return result; -+ } -+ -+ /// Returns the index of the dimension with least value -+ CUTLASS_HOST_DEVICE -+ int min_dim_index() const { -+ int i = 0; -+ for (int j = 1; j < kRank; ++j) { -+ if (idx[j] < idx[i]) { -+ i = j; -+ } -+ } -+ return i; -+ } -+ -+ /// Returns the index of the dimension with greatest value -+ CUTLASS_HOST_DEVICE -+ int max_dim_index() const { -+ int i = 0; -+ for (int j = 1; j < kRank; ++j) { -+ if (idx[j] > idx[i]) { -+ i = j; -+ } -+ } -+ return i; -+ } -+ -+ /// Returns true if Coord is non-zero. -+ CUTLASS_HOST_DEVICE -+ explicit operator bool() const { -+ for (int i = 0; i < kRank; ++i) { -+ if (idx[i]) { -+ return true; -+ } -+ } -+ return false; -+ } -+ -+ /// Returns true if Coord is uniformly zero. -+ CUTLASS_HOST_DEVICE -+ bool operator!() const { -+ for (int i = 0; i < kRank; ++i) { -+ if (idx[i]) { -+ return false; -+ } -+ } -+ return true; -+ } -+ -+ /// Element-wise addition -+ CUTLASS_HOST_DEVICE -+ Coord operator+(Coord const& b) const { -+ Coord c; -+ for (int i = 0; i < kRank; ++i) { -+ c.idx[i] = idx[i] + b.idx[i]; -+ } -+ return c; -+ } -+ -+ /// Element-wise subtraction -+ CUTLASS_HOST_DEVICE -+ Coord operator-(Coord const& b) const { -+ Coord c; -+ for (int i = 0; i < kRank; ++i) { -+ c.idx[i] = idx[i] - b.idx[i]; -+ } -+ return c; -+ } -+ -+ /// Element-wise multiplication -+ CUTLASS_HOST_DEVICE -+ Coord operator*(Coord const& b) const { -+ Coord c; -+ for (int i = 0; i < kRank; ++i) { -+ c.idx[i] = idx[i] * b.idx[i]; -+ } -+ return c; -+ } -+ -+ /// Element-wise division -+ CUTLASS_HOST_DEVICE -+ Coord operator/(Coord const& b) const { -+ Coord c; -+ for (int i = 0; i < kRank; ++i) { -+ c.idx[i] = idx[i] / b.idx[i]; -+ } -+ return c; -+ } -+ -+ /// In-place addition -+ CUTLASS_HOST_DEVICE -+ Coord& operator+=(Coord const& b) { -+ for (int i = 0; i < kRank; ++i) { -+ idx[i] += b.idx[i]; -+ } -+ return *this; -+ } -+ -+ /// In-place subtraction -+ CUTLASS_HOST_DEVICE -+ Coord& operator-=(Coord const& b) { -+ for (int i = 0; i < kRank; ++i) { -+ idx[i] -= b.idx[i]; -+ } -+ return *this; -+ } -+ -+ /// In-place multiplication -+ CUTLASS_HOST_DEVICE -+ Coord& operator*=(Coord const& b) { -+ for (int i = 0; i < kRank; ++i) { -+ idx[i] *= b.idx[i]; -+ } -+ return *this; -+ } -+ -+ /// In-place division -+ CUTLASS_HOST_DEVICE -+ Coord& operator/=(Coord const& b) { -+ for (int i = 0; i < kRank; ++i) { -+ idx[i] /= b.idx[i]; -+ } -+ return *this; -+ } -+ -+ /// Member access operator -+ CUTLASS_HOST_DEVICE Index& operator[](int dim) { return idx[dim]; } -+ -+ /// Member access operator -+ CUTLASS_HOST_DEVICE Index const& operator[](int dim) const { return idx[dim]; } -+ -+ /// Computes the dot product with anotherCoord object -+ CUTLASS_HOST_DEVICE -+ LongIndex dot(Coord const& b, LongIndex sum = LongIndex(0)) const { -+ for (int i = 0; i < kRank; ++i) { -+ sum += idx[i] * b.idx[i]; -+ } -+ return sum; -+ } -+ -+ /// Gets the index of a given Coord element -+ template -+ CUTLASS_HOST_DEVICE Index& at() { -+ return idx[Dim]; -+ } -+ -+ /// Access via index; may limit unrolling potential -+ CUTLASS_HOST_DEVICE -+ Index& at(int dim) { return idx[dim]; } -+ -+ /// Gets the index of a given Coord element -+ template -+ CUTLASS_HOST_DEVICE Index const& at() const { -+ return idx[Dim]; -+ } -+ -+ /// Access via index; may limit unrolling potential -+ CUTLASS_HOST_DEVICE -+ Index const& at(int dim) const { return idx[dim]; } -+ -+ /// Determines if two Coord<> objects are equal -+ CUTLASS_HOST_DEVICE -+ bool operator==(Coord const& b) const { -+ bool equal = true; -+ for (int i = 0; equal && i < kRank; ++i) { -+ equal = (idx[i] == b.idx[i]); -+ } -+ return equal; -+ } -+ -+ /// Not equal -+ CUTLASS_HOST_DEVICE -+ bool operator!=(Coord const& b) const { return !(*this == b); } -+ -+ /// Clamps a coordinate to a range specified by maximum and minimum values -+ CUTLASS_HOST_DEVICE -+ Coord& clamp(Coord const& max, Coord const& min = Coord()) { -+ for (int i = 0; i < kRank; ++i) { -+ idx[i] = __NV_STD_MAX(__NV_STD_MIN(idx[i], max.idx[i]), min.idx[i]); -+ } -+ return *this; -+ } -+ -+ /// Returns the sum of all elements -+ CUTLASS_HOST_DEVICE -+ Index sum() const { -+ Index sum_(idx[0]); -+ for (int i = 1; i < kRank; ++i) { -+ sum_ += idx[i]; -+ } -+ return sum_; -+ } -+ -+ /// Returns the product of all elements -+ CUTLASS_HOST_DEVICE -+ LongIndex product() const { -+ LongIndex product_(idx[0]); -+ for (int i = 1; i < kRank; ++i) { -+ product_ *= idx[i]; -+ } -+ return product_; -+ } -+ -+ /// Less than operator -+ CUTLASS_HOST_DEVICE -+ bool operator<(Coord const &b) const { -+ for (int i = 0; i < kRank; ++i) { -+ if (!(idx[i] < b[i])) { -+ return false; -+ } -+ } -+ return true; -+ } -+ -+ /// Less than or equals operator -+ CUTLASS_HOST_DEVICE -+ bool operator<=(Coord const &b) const { -+ for (int i = 0; i < kRank; ++i) { -+ if (!(idx[i] <= b[i])) { -+ return false; -+ } -+ } -+ return true; -+ } -+ -+ /// Greater than operator -+ CUTLASS_HOST_DEVICE -+ bool operator>(Coord const &b) const { -+ return !(*this <= b); -+ } -+ -+ /// Greater than or equals operator -+ CUTLASS_HOST_DEVICE -+ bool operator>=(Coord const &b) const { -+ return !(*this < b); -+ } -+}; -+ -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+ -+/// Scalar multiplication -+template -+CUTLASS_HOST_DEVICE -+Coord operator*(Index s, Coord coord) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Rank; ++i) { -+ coord[i] *= s; -+ } -+ return coord; -+} -+ -+/// Scalar multiplication -+template -+CUTLASS_HOST_DEVICE -+Coord operator*(Coord coord, Index s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Rank; ++i) { -+ coord[i] *= s; -+ } -+ return coord; -+} -+ -+/// Scalar division -+template -+CUTLASS_HOST_DEVICE -+Coord operator/(Index s, Coord coord) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Rank; ++i) { -+ coord[i] = s / coord[i]; -+ } -+ return coord; -+} -+ -+/// Scalar division -+template -+CUTLASS_HOST_DEVICE -+Coord operator/(Coord coord, Index s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Rank; ++i) { -+ coord[i] /= s; -+ } -+ return coord; -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Integer-valued make_Coord -+// -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Helper to make a 2-element coordinate -+template -+CUTLASS_HOST_DEVICE -+Coord<1, T> make_Coord(T _0) { -+ T values[1] = {_0}; -+ return Coord<1, T>(values); -+} -+ -+/// Helper to make a 2-element coordinate -+template -+CUTLASS_HOST_DEVICE -+Coord<2, T> make_Coord(T _0, T _1) { -+ T values[2] = {_0, _1}; -+ return Coord<2, T>(values); -+} -+ -+/// Helper to make a 3-element coordinate -+template -+CUTLASS_HOST_DEVICE -+Coord<3, T> make_Coord(T _0, T _1, T _2) { -+ T values[3] = {_0, _1, _2}; -+ return Coord<3, T>(values); -+} -+ -+/// Helper to make a 4-element coordinate -+template -+CUTLASS_HOST_DEVICE -+Coord<4, T> make_Coord(T _0, T _1, T _2, T _3) { -+ T values[4] = {_0, _1, _2, _3}; -+ return Coord<4, T>(values); -+} -+ -+/// Helper to make a 5-element coordinate -+template -+CUTLASS_HOST_DEVICE -+Coord<5, T> make_Coord(T _0, T _1, T _2, T _3, T _4) { -+ T values[5] = {_0, _1, _2, _3, _4}; -+ return Coord<5, T>(values); -+} -+ -+/// Helper to make a 1-element coordinate -+template -+CUTLASS_HOST_DEVICE -+Coordmake_Coord_with_padding(T _0) { -+ Coord coord; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = N - 1; i > 0; --i) { -+ coord[i] = 0; -+ } -+ -+ coord[0] = _0; -+ -+ return coord; -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/include/cutlass/core_io.h b/3rdparty/cutlass/include/cutlass/core_io.h -new file mode 100644 -index 0000000..4d15432 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/core_io.h -@@ -0,0 +1,289 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Helpers for printing cutlass/core objects -+*/ -+ -+#pragma once -+ -+#include -+#include -+ -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix.h" -+#include "cutlass/quaternion.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Output operator for CUDA built-in dim3 type -+inline std::ostream &operator<<(std::ostream &out, dim3 d) { -+ return out << d.x << ", " << d.y << ", " << d.z; -+} -+ -+/// Output operator for CUDA built-in error type -+inline std::ostream &operator<<(std::ostream &out, cudaError_t error) { -+#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) -+ return out << cudaGetErrorString(error); -+#endif -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// stream operators for cutlass namespace // -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+inline -+std::ostream& operator<<(std::ostream& out, Array const& v) { -+ for (int i = 0; i < Rank; ++i) { -+ out << (i ? ", " : "") << v[i]; -+ } -+ return out; -+} -+ -+template -+inline -+std::ostream& operator<<(std::ostream& out, Coord const& coord) { -+ for (int i = 0; i < Rank; ++i) { -+ out << (i ? ", " : "") << coord[i]; -+ } -+ return out; -+} -+ -+inline -+std::istream & operator>>(std::istream &stream, half_t &x) { -+ float tmp; -+ stream >> tmp; -+ x = static_cast(tmp); -+ return stream; -+} -+ -+inline -+std::ostream & operator<<(std::ostream &out, half_t const &x) { -+ return out << float(x); -+} -+ -+inline -+std::ostream & operator<<(std::ostream &out, bfloat16_t const &x) { -+ return out << float(x); -+} -+ -+inline -+std::ostream & operator<<(std::ostream &out, tfloat32_t const &x) { -+ return out << float(x); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Helper to enable formatted printing of CUTLASS scalar types to an ostream -+template -+struct ScalarIO { -+ -+ /// Value to print -+ T value; -+ -+ /// Default ctor -+ ScalarIO() { } -+ -+ /// Constructs from a value -+ ScalarIO(T value): value(value) {} -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Default printing to ostream -+template -+inline std::ostream &operator<<(std::ostream &out, ScalarIO const &scalar) { -+ return out << scalar.value; -+} -+ -+/// Printing to ostream of int8_t as integer rather than character -+template <> -+inline std::ostream &operator<<(std::ostream &out, ScalarIO const &scalar) { -+ return out << int(scalar.value); -+} -+ -+/// Printing to ostream of uint8_t as integer rather than character -+template <> -+inline std::ostream &operator<<(std::ostream &out, ScalarIO const &scalar) { -+ return out << unsigned(scalar.value); -+} -+ -+ -+/// Default printing to ostream for MatrixShape -+template -+inline -+std::ostream & operator<<(std::ostream &out, MatrixShape const &matrix_shape) { -+ out << "cutlass::MatrixShape::(kRow, kColumn) {" -+ << cutlass::MatrixShape::kRow <<"," -+ << cutlass::MatrixShape::kColumn <<"}"; -+ return out; -+} -+ -+ -+/// Prints matrix to ostream -+template -+std::ostream & operator<<(std::ostream &out, Matrix const &rhs) { -+ -+ for (int i = 0; i < Rows; ++i) { -+ for (int j = 0; j < Columns; ++j) { -+ ScalarIO element(rhs.at(i, j)); -+ out << (j ? ", " : "") << element; -+ } -+ out << "\\n"; -+ } -+ -+ return out; -+} -+ -+template -+std::ostream &operator<<(std::ostream &out, Quaternion const &rhs) { -+ -+ out << ScalarIO(rhs.w()) << " "; -+ if (rhs.x() >= 0) { -+ out << "+"; -+ } -+ -+ out << ScalarIO(rhs.x()) << "*i "; -+ if (rhs.y() >= 0) { -+ out << "+"; -+ } -+ -+ out << ScalarIO(rhs.y()) << "*j "; -+ if (rhs.z() >= 0) { -+ out << "+"; -+ } -+ -+ out << ScalarIO(rhs.z()) << "*k"; -+ -+ return out; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// stream operators for cutlass::gemm namespace // -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+namespace gemm { -+ -+/// Default printing to ostream for GemmShape -+template -+inline -+std::ostream & operator<<(std::ostream &out, GemmShape const &gemm_shape) { -+ out << "cutlass::gemm::GemmShape::(kM, kN, kK) {" -+ << cutlass::gemm::GemmShape::kM <<"," -+ << cutlass::gemm::GemmShape::kN <<"," -+ << cutlass::gemm::GemmShape::kK << "}"; -+ return out; -+} -+ -+/// Default printing to ostream for GemmCoord -+inline -+std::ostream & operator<<(std::ostream &out, GemmCoord const &gemm_coord) { -+ out << "cutlass::gemm::GemmCoord {" -+ << gemm_coord.m() <<"," -+ << gemm_coord.n() <<"," -+ << gemm_coord.k() << "}"; -+ return out; -+} -+ -+} //namespace gemm -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// stream operators for cutlass namespace // -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Default printing to ostream for PitchLinearShape -+template < int Contiguous, int Strided> -+inline -+std::ostream & operator<<(std::ostream &out, PitchLinearShape const &pitch_linear_shape) { -+ out << "cutlass::PitchLinearShape:(kContiguous, kStrided) {" -+ << cutlass::layout::PitchLinearShape::kContiguous <<"," -+ << cutlass::layout::PitchLinearShape::kStrided <<"}"; -+ return out; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// stream operators for cutlass::conv namespace // -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+namespace conv { -+/// Default printing to ostream for Conv2dProblemSize -+inline -+std::ostream& operator<<(std::ostream& out, Conv2dProblemSize const& problem) { -+ out << "NHWC: (" << problem.N << ", " << problem.H << ", " << problem.W << ", " << problem.C << ")" << std::endl -+ << "KRSC: (" << problem.K << ", " << problem.R << ", " << problem.S << ", " << problem.C / problem.groups << ")" << std::endl -+ << "NPQK: (" << problem.N << ", " << problem.P << ", " << problem.Q << ", " << problem.K << ")" << std::endl -+ << "groups: (" << problem.groups << ")" << std::endl -+ << "Pad_h, Pad_w: (" << problem.pad_h << ", " << problem.pad_w << ")" << std::endl -+ << "Stride_h, Stride_w: (" << problem.stride_h << ", " << problem.stride_w << ")" << std::endl -+ << "Dilation_h, Dilation_w: (" << problem.dilation_h << ", " << problem.dilation_w << ")" << std::endl -+ << "split_k_slices: (" << problem.split_k_slices << ")" << std::endl -+ << "mode: (" << ((problem.mode==conv::Mode::kConvolution) ? "conv" : "xcross") << ")"; -+ -+ return out; -+} -+ -+ -+/// Default printing to ostream for Conv3dProblemSize -+inline -+std::ostream& operator<<(std::ostream& out, Conv3dProblemSize const& problem) { -+ out << "NDHWC: (" << problem.N << ", " << problem.D << ", " << problem.H << ", " << problem.W << ", " << problem.C << ")" << std::endl -+ << "KTRSC: (" << problem.K << ", " << problem.T << ", " << problem.R << ", " << problem.S << ", " << problem.C << ")" << std::endl -+ << "NZPQK: (" << problem.N << ", " << problem.Z << ", " << problem.P << ", " << problem.Q << ", " << problem.K << ")" << std::endl -+ << "pad_d, pad_h, pad_w: (" << problem.pad_d << ", " << problem.pad_h << ", " << problem.pad_w << ")" << std::endl -+ << "stride_d, stride_h, stride_w: (" << problem.stride_d << ", " << problem.stride_h << ", " << problem.stride_w << ")" << std::endl -+ << "dilation_d, dilation_h, dilation_w: (" << problem.dilation_d << ", " << problem.dilation_h << ", " << problem.dilation_w << ")" << std::endl -+ << "split_k_slices: (" << problem.split_k_slices << ") " << std::endl -+ << "mode: (" << ((problem.mode==conv::Mode::kConvolution) ? "conv" : "xcross") << ")"; -+ -+ return out; -+} -+ -+} // namespace conv -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/cutlass.h b/3rdparty/cutlass/include/cutlass/cutlass.h -new file mode 100644 -index 0000000..12bc3a3 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/cutlass.h -@@ -0,0 +1,227 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Basic include for CUTLASS. -+*/ -+ -+#pragma once -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#ifdef CUTLASS_NAMESPACE -+#define concat_tok(a, b) a ## b -+#define mkcutlassnamespace(pre, ns) concat_tok(pre, ns) -+#define cutlass mkcutlassnamespace(cutlass_, CUTLASS_NAMESPACE) -+#endif -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) -+#define CUTLASS_HOST_DEVICE __forceinline__ __device__ __host__ -+#define CUTLASS_DEVICE __forceinline__ __device__ -+#elif defined(__CUDACC_RTC__) -+#define CUTLASS_HOST_DEVICE __forceinline__ __device__ -+#define CUTLASS_DEVICE __forceinline__ __device__ -+#else -+#define CUTLASS_HOST_DEVICE inline -+#define CUTLASS_DEVICE inline -+#endif -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+CUTLASS_HOST_DEVICE void __CUTLASS_UNUSED(T const &) -+{ } -+ -+#if defined(__GNUC__) -+ #define CUTLASS_UNUSED(expr) __CUTLASS_UNUSED(expr) -+#else -+ #define CUTLASS_UNUSED(expr) do { ; } while (&expr != &expr) -+#endif -+ -+#if !defined(__CUDACC_RTC__) -+ -+#include -+ -+ #if defined(__CUDA_ARCH__) -+ #if defined(_MSC_VER) -+ #define CUTLASS_NOT_IMPLEMENTED() { printf("%s not implemented\n", __FUNCSIG__); asm volatile ("brkpt;\n"); } -+ #else -+ #define CUTLASS_NOT_IMPLEMENTED() { printf("%s not implemented\n", __PRETTY_FUNCTION__); asm volatile ("brkpt;\n"); } -+ #endif -+ -+ #else -+ #if defined(_MSC_VER) -+ #define CUTLASS_NOT_IMPLEMENTED() assert(0 && __FUNCSIG__) -+ #else -+ #define CUTLASS_NOT_IMPLEMENTED() assert(0 && __PRETTY_FUNCTION__) -+ #endif -+ #endif -+#endif -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+/// Status code returned by CUTLASS operations -+enum class Status { -+ kSuccess, ///< Operation was successful. -+ kErrorMisalignedOperand, ///< operands fail alignment requirements. -+ kErrorInvalidDataType, ///< DataType fails requirement. -+ kErrorInvalidLayout, ///< Layout fails alignment requirement. -+ kErrorInvalidProblem, ///< Specified problem size is not supported by operator. -+ kErrorNotSupported, ///< Operation is not supported on current device. -+ kErrorWorkspaceNull, ///< The given workspace is null when it is required to be non-null. -+ kErrorInternal, ///< An error within CUTLASS occurred. -+ kErrorArchMismatch, ///< CUTLASS runs on a device that it was not compiled for. -+ kErrorInsufficientDriver, ///< CUTLASS runs with a driver that is too old. -+ kErrorMemoryAllocation, ///< Kernel launch failed due to insufficient device memory. -+ kInvalid ///< Status is unspecified. -+}; -+ -+/// Convert cutlass status to status strings -+CUTLASS_HOST_DEVICE -+static char const* cutlassGetStatusString(cutlass::Status status) { -+ switch (status) { -+ case cutlass::Status::kSuccess: -+ return "Success"; -+ case cutlass::Status::kErrorMisalignedOperand: -+ return "Error Misaligned Operand"; -+ case cutlass::Status::kErrorInvalidDataType: -+ return "Error Invalid Data Type"; -+ case cutlass::Status::kErrorInvalidLayout: -+ return "Error Invalid Layout"; -+ case cutlass::Status::kErrorInvalidProblem: -+ return "Error Invalid Problem"; -+ case cutlass::Status::kErrorNotSupported: -+ return "Error Not Supported"; -+ case cutlass::Status::kErrorWorkspaceNull: -+ return "Error Workspace Null"; -+ case cutlass::Status::kErrorInternal: -+ return "Error Internal"; -+ case cutlass::Status::kErrorInsufficientDriver: -+ return "Error Insufficient Driver"; -+ case cutlass::Status::kErrorArchMismatch: -+ return "Error Architecture Mismatch"; -+ case cutlass::Status::kErrorMemoryAllocation: -+ return "Error Memory Allocation failed"; -+ case cutlass::Status::kInvalid: break; -+ } -+ -+ return "Invalid status"; -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+#ifndef CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED -+#define CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED 0 -+#endif -+ -+ -+// CUDA 10.1 introduces the mma instruction -+#if !defined(CUTLASS_ENABLE_TENSOR_CORE_MMA) -+#define CUTLASS_ENABLE_TENSOR_CORE_MMA 0 -+#endif -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#define CUTLASS_ASSERT(x) assert(x) -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// CUTLASS_PRAGMA_(UNROLL|NO_UNROLL) optimization directives for the CUDA compiler. -+#if defined(__CUDA_ARCH__) -+ #if defined(__CUDACC_RTC__) || (defined(__clang__) && defined(__CUDA__)) -+ #define CUTLASS_PRAGMA_UNROLL _Pragma("unroll") -+ #define CUTLASS_PRAGMA_NO_UNROLL _Pragma("unroll 1") -+ #else -+ #define CUTLASS_PRAGMA_UNROLL #pragma unroll -+ #define CUTLASS_PRAGMA_NO_UNROLL #pragma unroll 1 -+ #endif -+ -+ #define CUTLASS_GEMM_LOOP CUTLASS_PRAGMA_NO_UNROLL -+ -+#else -+ -+ #define CUTLASS_PRAGMA_UNROLL -+ #define CUTLASS_PRAGMA_NO_UNROLL -+ #define CUTLASS_GEMM_LOOP -+ -+#endif -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+static const int NumThreadsPerWarp = 32; -+static const int NumThreadsPerWarpGroup = 128; -+static const int NumThreadsPerHalfWarp = NumThreadsPerWarp / 2; -+static const int NumThreadsPerQuad = 4; -+static const int NumThreadsPerQuadPair = NumThreadsPerQuad * 2; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Helper function to return true when called by thread 0 of threadblock 0. -+CUTLASS_HOST_DEVICE bool thread0() { -+ #if defined(__CUDA_ARCH__) -+ return (!threadIdx.x && !threadIdx.y && !threadIdx.z) && (!blockIdx.x && !blockIdx.y && !blockIdx.z); -+ #else -+ return false; -+ #endif -+} -+ -+/// Returns a warp-uniform value indicating the canonical warp index of the calling threads. -+/// Threads within the warp must be converged. -+CUTLASS_DEVICE -+int canonical_warp_idx() { -+ #if defined(__CUDA_ARCH__) -+ return __shfl_sync(0xffffffff, threadIdx.x / NumThreadsPerWarp, 0); -+ #else -+ return 0; -+ #endif -+} -+ -+/// Returns a warp-uniform value indicating the canonical warp group index of the calling threads. -+/// Threads within the warp must be converged. -+CUTLASS_DEVICE -+int canonical_warp_group_idx() { -+ #if defined(__CUDA_ARCH__) -+ return __shfl_sync(0xffffffff, threadIdx.x / NumThreadsPerWarpGroup, 0); -+ #else -+ return 0; -+ #endif -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/device_kernel.h b/3rdparty/cutlass/include/cutlass/device_kernel.h -new file mode 100644 -index 0000000..68042e3 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/device_kernel.h -@@ -0,0 +1,113 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for generic CUTLASS kernel. -+*/ -+ -+#pragma once -+ -+// __grid_constant__ was introduced in CUDA 11.7. -+#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 7))) -+# define CUTLASS_GRID_CONSTANT_SUPPORTED -+#endif -+ -+// __grid_constant__ can be enabled only on SM70+ -+#if defined(CUTLASS_GRID_CONSTANT_SUPPORTED) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) -+# define CUTLASS_GRID_CONSTANT_ENABLED -+#endif -+ -+#if ! defined(CUTLASS_GRID_CONSTANT) -+# if defined(CUTLASS_GRID_CONSTANT_ENABLED) -+# define CUTLASS_GRID_CONSTANT __grid_constant__ -+# else -+# define CUTLASS_GRID_CONSTANT -+# endif -+#endif -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Generic CUTLASS kernel template. -+template -+__global__ -+void Kernel(typename Operator::Params params) { -+ // Dynamic shared memory base pointer -+ extern __shared__ int SharedStorageBase[]; -+ -+ // Declare pointer to dynamic shared memory. -+ typename Operator::SharedStorage *shared_storage = -+ reinterpret_cast(SharedStorageBase); -+ -+ Operator op; -+ -+ op(params, *shared_storage); -+} -+ -+ -+/// Generic CUTLASS kernel template. -+template -+__global__ -+void Kernel2(typename Operator::Params params) { -+ // Dynamic shared memory base pointer -+ extern __shared__ int SharedStorageBase[]; -+ -+ // Declare pointer to dynamic shared memory. -+ typename Operator::SharedStorage *shared_storage = -+ reinterpret_cast(SharedStorageBase); -+ -+ Operator::invoke(params, *shared_storage); -+ -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+// -+// 3.0 specific launch -+// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Generic CUTLASS kernel template. -+template -+__global__ __launch_bounds__(Operator::MaxThreadsPerBlock, Operator::MinBlocksPerMultiprocessor) -+void device_kernel(CUTLASS_GRID_CONSTANT typename Operator::Params const params) -+{ -+ // Dynamic shared memory base pointer -+ extern __shared__ char smem[]; -+ -+ Operator op; -+ op(params, smem); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+} /// namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/collective/collective_epilogue.hpp b/3rdparty/cutlass/include/cutlass/epilogue/collective/collective_epilogue.hpp -new file mode 100644 -index 0000000..5b1b924 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/collective/collective_epilogue.hpp -@@ -0,0 +1,49 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -+ * -+ * Redistribution and use in source and binary forms, with or without modification, are permitted -+ * provided that the following conditions are met: -+ * * Redistributions of source code must retain the above copyright notice, this list of -+ * conditions and the following disclaimer. -+ * * Redistributions in binary form must reproduce the above copyright notice, this list of -+ * conditions and the following disclaimer in the documentation and/or other materials -+ * provided with the distribution. -+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used -+ * to endorse or promote products derived from this software without specific prior written -+ * permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR -+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND -+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, -+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; -+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass::epilogue::collective { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ class DispatchPolicy, -+ class... Args -+> -+struct CollectiveEpilogue { -+ static_assert(std::is_void_v, "Could not find an epilogue specialization."); -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass::epilogue::collective -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include "default_epilogue.hpp" -+#include "epilogue.hpp" -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/collective/default_epilogue.hpp b/3rdparty/cutlass/include/cutlass/epilogue/collective/default_epilogue.hpp -new file mode 100644 -index 0000000..71499b5 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/collective/default_epilogue.hpp -@@ -0,0 +1,195 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Functor performing elementwise operations used by epilogues. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cute/tensor.hpp" -+#include "cute/numeric/int.hpp" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace collective { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies an element wise operation to all elements within the fragment -+/// and writes them out to destination storage. -+template < -+ class StrideC_, -+ class StrideD_, -+ class ThreadEpilogueOp_ -+> -+class DefaultEpilogue { -+public: -+ // -+ // Type Aliases -+ // -+ // derived types of output thread level operator -+ using ThreadEpilogueOp = ThreadEpilogueOp_; -+ using ElementOutput = typename ThreadEpilogueOp::ElementOutput; -+ using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; -+ using ElementCompute = typename ThreadEpilogueOp::ElementCompute; -+ using ElementScalar = ElementCompute; -+ using ElementC = typename ThreadEpilogueOp::ElementC; -+ using StrideC = StrideC_; -+ using ElementD = typename ThreadEpilogueOp::ElementD; -+ using StrideD = StrideD_; -+ -+ static const int kOutputAlignment = ThreadEpilogueOp::kCount; -+ using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; -+ -+ static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); -+ static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); -+ -+ struct SharedStorage { }; -+ -+ // Params of epilogue::collective contain the epilogue::thread params -+ struct Params { -+ ElementC const* ptr_C = nullptr; -+ StrideC dC{}; -+ ElementD* ptr_D = nullptr; -+ StrideD dD{}; -+ typename ThreadEpilogueOp::Params thread_params{}; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ template -+ static constexpr Params -+ to_underlying_arguments(Args const& args, void* workspace) { -+ (void) workspace; -+ return {args.epilogue_params}; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ DefaultEpilogue(Params const& params_) : params(params_) { } -+ -+ template< -+ class ProblemShapeMNKL, -+ class BlockShapeMNK, -+ class BlockCoordMNKL, -+ class FrgEngine, class FrgLayout, -+ class TiledMma, -+ class ResidueMNK -+ > -+ CUTLASS_HOST_DEVICE void -+ operator()( -+ ProblemShapeMNKL problem_shape_mnkl, -+ BlockShapeMNK blk_shape_MNK, -+ BlockCoordMNKL blk_coord_mnkl, -+ cute::Tensor const& accumulators, -+ TiledMma tiled_mma, -+ ResidueMNK residue_mnk, -+ int thread_idx, -+ char* smem_buf) -+ { -+ using namespace cute; -+ using X = Underscore; -+ -+ static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); -+ static_assert(is_static::value, "ThreadBlock tile shape must be static"); -+ static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); -+ static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); -+ -+ (void) smem_buf; -+ ThreadEpilogueOp epilogue_op{params.thread_params}; -+ -+ // Separate out problem shape for convenience -+ auto M = get<0>(problem_shape_mnkl); -+ auto N = get<1>(problem_shape_mnkl); -+ auto L = get<3>(problem_shape_mnkl); -+ -+ // Represent the full output tensor -+ Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), params.dC); // (m,n,l) -+ Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), params.dD); // (m,n,l) -+ Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) -+ Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) -+ -+ // Slice to get the tile this CTA is responsible for -+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; -+ Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) -+ Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) -+ -+ // Partition source and destination tiles to match the accumulator partitioning -+ auto thr_mma = tiled_mma.get_thread_slice(thread_idx); -+ Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) -+ Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N) -+ -+ static_assert(is_static::value, "Accumulator layout must be static"); -+ CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD), -+ "Source and destination must have the same number of elements."); -+ CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators), -+ "Accumulator count must have the same destination element count."); -+ -+ // Make an identity coordinate tensor for predicating our output MN tile -+ auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); -+ Tensor tCcD = thr_mma.partition_C(cD); -+ -+ // source is needed -+ if (epilogue_op.is_source_needed()) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < size(accumulators); ++i) { -+ if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { -+ tCgD(i) = epilogue_op(accumulators(i), tCgC(i)); -+ } -+ } -+ } -+ // source is not needed, avoid load -+ else { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < size(accumulators); ++i) { -+ if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { -+ tCgD(i) = epilogue_op(accumulators(i)); -+ } -+ } -+ } -+ } -+ -+private: -+ Params params; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace collective -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/collective/default_transposed_epilogue.hpp b/3rdparty/cutlass/include/cutlass/epilogue/collective/default_transposed_epilogue.hpp -new file mode 100644 -index 0000000..7e38acd ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/collective/default_transposed_epilogue.hpp -@@ -0,0 +1,203 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Functor performing elementwise operations used by epilogues. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cute/tensor.hpp" -+#include "cute/numeric/int.hpp" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace collective { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+using namespace cute; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies an element wise operation to all elements within the fragment -+/// and writes them out to destination storage. -+template < -+ class StrideC_, -+ class StrideD_, -+ class ThreadEpilogueOp_ -+> -+class DefaultTransposedEpilogue { -+ -+public: -+ // -+ // Type Aliases -+ // -+ // derived types of output thread level operator -+ using ThreadEpilogueOp = ThreadEpilogueOp_; -+ using ElementOutput = typename ThreadEpilogueOp::ElementOutput; -+ using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; -+ using ElementCompute = typename ThreadEpilogueOp::ElementCompute; -+ using ElementScalar = ElementCompute; -+ using ElementC = typename ThreadEpilogueOp::ElementC; -+ using StrideC = StrideC_; -+ using ElementD = typename ThreadEpilogueOp::ElementD; -+ using StrideD = StrideD_; -+ -+ static const int kOutputAlignment = ThreadEpilogueOp::kCount; -+ using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; -+ -+ static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); -+ static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); -+ -+ struct SharedStorage { }; -+ -+ // Params of epilogue::collective contain the epilogue::thread params -+ struct Params { -+ ElementC const* ptr_C = nullptr; -+ StrideC dC{}; -+ ElementD* ptr_D = nullptr; -+ StrideD dD{}; -+ typename ThreadEpilogueOp::Params thread_params{}; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ template -+ static constexpr Params -+ to_underlying_arguments(Args const& args, void* workspace) { -+ (void) workspace; -+ return {args.epilogue_params}; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ DefaultTransposedEpilogue(Params const& params_) : params(params_) { } -+ -+ template< -+ class ProblemShapeMNKL, -+ class BlockShapeMNK, -+ class BlockCoordMNKL, -+ class FrgEngine, class FrgLayout, -+ class TiledMma, -+ class ResidueMNK -+ > -+ CUTLASS_HOST_DEVICE void -+ operator()( -+ ProblemShapeMNKL problem_shape_mnkl, -+ BlockShapeMNK blk_shape_MNK, -+ BlockCoordMNKL blk_coord_mnkl, -+ cute::Tensor const& accumulators, -+ TiledMma tiled_mma, -+ ResidueMNK residue_mnk, -+ int thread_idx, -+ char* smem_buf) -+ { -+ using namespace cute; -+ using X = Underscore; -+ -+ static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); -+ static_assert(is_static::value, "ThreadBlock tile shape must be static"); -+ static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); -+ static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); -+ -+ (void) smem_buf; -+ ThreadEpilogueOp epilogue_op{params.thread_params}; -+ -+ // Separate out problem shape for convenience -+ auto M = get<0>(problem_shape_mnkl); -+ auto N = get<1>(problem_shape_mnkl); -+ auto L = get<3>(problem_shape_mnkl); -+ -+ // Tranpose stride C/D. -+ auto stride_c = make_stride(get<1>(params.dC), get<0>(params.dC), get<2>(params.dC)); -+ auto stride_d = make_stride(get<1>(params.dD), get<0>(params.dD), get<2>(params.dD)); -+ -+ // Represent the full output tensor -+ Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), stride_c); // (m,n,l) -+ Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d); // (m,n,l) -+ Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) -+ Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) -+ -+ // Slice to get the tile this CTA is responsible for -+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; -+ Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) -+ Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) -+ -+ // Partition source and destination tiles to match the accumulator partitioning -+ auto thr_mma = tiled_mma.get_thread_slice(thread_idx); -+ Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) -+ Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N) -+ -+ static_assert(is_static::value, "Accumulator layout must be static"); -+ CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD), -+ "Source and destination must have the same number of elements."); -+ CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators), -+ "Accumulator count must have the same destination element count."); -+ -+ auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); -+ Tensor tCcD = thr_mma.partition_C(cD); -+ -+ // source is needed -+ if (epilogue_op.is_source_needed()) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < size(accumulators); ++i) { -+ if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { -+ tCgD(i) = epilogue_op(accumulators(i), tCgC(i)); -+ } -+ } -+ } -+ // source is not needed, avoid load -+ else { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < size(accumulators); ++i) { -+ if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { -+ tCgD(i) = epilogue_op(accumulators(i)); -+ } -+ } -+ } -+ } -+ -+private: -+ Params params; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace collective -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/collective/epilogue.hpp b/3rdparty/cutlass/include/cutlass/epilogue/collective/epilogue.hpp -new file mode 100644 -index 0000000..565e752 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/collective/epilogue.hpp -@@ -0,0 +1,322 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Functor performing elementwise operations used by epilogues. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cute/tensor.hpp" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace collective { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies an element wise operation to all elements within the fragment -+/// and writes it out to destination storage. -+/// -+/// Ways to generalize this: -+/// - CTA tile shape -+/// - vectorization requirements (GMEM) -+/// - vectoriz(able) transform() -+/// -+template < -+ class StrideC_, -+ class StrideD_, -+ class ThreadEpilogueOp_, -+ class SmemLayout_, -+ class CopyAtomR2S_, -+ class TiledCopyS2R_, -+ class CopyAtomR2G_ -+> -+class Epilogue { -+public: -+ // -+ // Type Aliases -+ // -+ // derived types of output thread level operator -+ using ThreadEpilogueOp = ThreadEpilogueOp_; -+ using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; -+ using ElementCompute = typename ThreadEpilogueOp::ElementCompute; -+ using ElementScalar = ElementCompute; -+ using ElementOutput = typename ThreadEpilogueOp::ElementOutput; -+ using ElementC = typename ThreadEpilogueOp::ElementC; -+ using StrideC = StrideC_; -+ using ElementD = typename ThreadEpilogueOp::ElementD; -+ using StrideD = StrideD_; -+ -+ using SmemLayout = SmemLayout_; -+ using CopyAtomR2S = CopyAtomR2S_; -+ using TiledCopyS2R = TiledCopyS2R_; -+ using CopyAtomR2G = CopyAtomR2G_; -+ -+ static const int kOutputAlignment = ThreadEpilogueOp::kCount; -+ using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; -+ -+ static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); -+ static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); -+ -+ struct SharedStorage -+ { -+ cute::array_aligned> smem_epilogue; -+ }; -+ -+ // Params of epilogue::collective contain the epilogue::thread params -+ struct Params { -+ ElementC const* ptr_C = nullptr; -+ StrideC dC{}; -+ ElementD* ptr_D = nullptr; -+ StrideD dD{}; -+ typename ThreadEpilogueOp::Params thread_params{}; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ template -+ static constexpr Params -+ to_underlying_arguments(Args const& args, void* workspace) { -+ (void) workspace; -+ return {args.epilogue_params}; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Epilogue(Params const& params_) : params(params_) { }; -+ -+ template< -+ class ProblemShapeMNKL, -+ class BlockShapeMNK, -+ class BlockCoordMNKL, -+ class FrgEngine, class FrgLayout, -+ class TiledMma, -+ class ResidueMNK -+ > -+ CUTLASS_DEVICE void -+ operator()( -+ ProblemShapeMNKL problem_shape_mnkl, -+ BlockShapeMNK blk_shape_MNK, -+ BlockCoordMNKL blk_coord_mnkl, -+ cute::Tensor const& accumulators, // (MMA,MMA_M,MMA_N) -+ TiledMma tiled_mma, -+ ResidueMNK residue_mnk, -+ int thread_idx, -+ char* smem_buf) -+ { -+ using namespace cute; -+ using X = Underscore; -+ -+ static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); -+ static_assert(is_static::value, "ThreadBlock tile shape must be static"); -+ static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); -+ static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); -+ -+ // synchronizing function for smem reads/writes -+#if CUDA_BARRIER_ENABLED -+ auto synchronize = [] () { NamedBarrier::sync(typename TiledCopyS2R::TiledNumThr{}, 0); }; -+#else -+ auto synchronize = [] () { __syncthreads(); }; -+#endif -+ -+ ThreadEpilogueOp epilogue_op{this->params.thread_params}; -+ -+ // Separate out problem shape for convenience -+ auto M = get<0>(problem_shape_mnkl); -+ auto N = get<1>(problem_shape_mnkl); -+ auto L = get<3>(problem_shape_mnkl); -+ -+ // Represent the full output tensor -+ Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), params.dC); // (m,n,l) -+ Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), params.dD); // (m,n,l) -+ Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) -+ Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) -+ -+ // Slice to get the tile this CTA is responsible for -+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; -+ Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) -+ Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) -+ -+ // Construct a tensor in SMEM that we can partition for rearranging data -+ SharedStorage& storage = *reinterpret_cast(smem_buf); -+ Tensor sC = make_tensor(make_smem_ptr(storage.smem_epilogue.data()), SmemLayout{}); // (SMEM_M,SMEM_N) -+ -+ // Partition sC to match the accumulator partitioning -+ auto tC = make_tiled_copy_C(CopyAtomR2S{}, tiled_mma).get_thread_slice(thread_idx); -+ Tensor tCaC = tC.retile_S(accumulators); // ((Atom,AtomNum), MMA_M, MMA_N) -+ Tensor tCsC = tC.partition_D(sC); // ((Atom,AtomNum),PIPE_M,PIPE_N) -+ -+ // Tile gD and gC by the shape of SmemLayout first -+ auto tile = make_shape(size<0>(sC), size<1>(sC)); -+ Tensor gCt = local_tile(gC, tile, _); // (SMEM_M,SMEM_N,TILE_M,TILE_N) -+ Tensor gDt = local_tile(gD, tile, _); // (SMEM_M,SMEM_N,TILE_M,TILE_N) -+ -+ // Partition sC, gC, and gD for the output -+ auto tD = TiledCopyS2R{}.get_thread_slice(thread_idx); -+ Tensor tDsC = tD.partition_S(sC); // ((Atom,AtomNum),ATOM_M,ATOM_N) -+ Tensor tDgC = tD.partition_D(gCt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) -+ Tensor tDgD = tD.partition_D(gDt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) -+ -+ // Allocate intermediate registers on the dst tensors -+ Tensor tDrC = make_tensor(take<0,3>(shape(tDgC))); // ((Atom,AtomNum),ATOM_M,ATOM_N) -+ Tensor tDrD = make_tensor(shape(tDrC)); // ((Atom,AtomNum),ATOM_M,ATOM_N) -+ -+ // Repeat the D-partitioning for coordinates and predication -+ Tensor cD = make_identity_tensor(make_shape(size<0>(gD),size<1>(gD))); // (BLK_M,BLK_N) -> (blk_m,blk_n) -+ Tensor cDt = local_tile(cD, tile, _); // (SMEM_M,SMEM_N,TILE_M,TILE_N) -+ Tensor tDcD = tD.partition_D(cDt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) -+ -+ CUTE_STATIC_ASSERT(size<1>(tCaC) % size<3>(tDgC) == 0); // TILE_M divides MMA_M -+ CUTE_STATIC_ASSERT(size<2>(tCaC) % size<4>(tDgC) == 0); // TILE_N divides MMA_N -+ CUTE_STATIC_ASSERT(typename TiledCopyS2R::TiledNumThr{} == size<0>(typename TiledMma::AtomLayoutC_TV{})); -+ -+#if 0 -+ if (thread_idx == 0 && m_coord == 0 && n_coord == 0) { -+ print("aC : "); print(accumulators.layout()); print("\n"); -+ print("gC : "); print(gC.layout()); print("\n"); -+ print("gD : "); print(gD.layout()); print("\n"); -+ print("sC : "); print(sC.layout()); print("\n"); -+ print("\n"); -+ print("tCsC : "); print(tCsC.layout()); print("\n"); -+ print("tCaC : "); print(tCaC.layout()); print("\n"); -+ print("\n"); -+ print("gDt : "); print(gDt.layout()); print("\n"); -+ print("tDsC : "); print(tDsC.layout()); print("\n"); -+ print("tDrC : "); print(tDrC.layout()); print("\n"); -+ print("\n"); -+ print("tDrD : "); print(tDrD.layout()); print("\n"); -+ print("tDgC : "); print(tDgC.layout()); print("\n"); -+ print("tDgD : "); print(tDgD.layout()); print("\n"); -+ print("\n"); -+ } -+#endif -+ -+ // For each tiling needed for SmemLayout to cover shape(gD) -+ CUTLASS_PRAGMA_UNROLL -+ for (int step_m = 0; step_m < size<2>(cDt); ++step_m) -+ { -+ CUTLASS_PRAGMA_UNROLL -+ for (int step_n = 0; step_n < size<3>(cDt); ++step_n) -+ { -+ // Step 1. Copy to SMEM -+ CUTLASS_PRAGMA_UNROLL -+ for (int pipe_m = 0; pipe_m < size<1>(tCsC); ++pipe_m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int pipe_n = 0; pipe_n < size<2>(tCsC); ++pipe_n) { -+ int mma_m = step_m * size<1>(tCsC) + pipe_m; -+ int mma_n = step_n * size<2>(tCsC) + pipe_n; -+ -+ copy(tC, tCaC(_,mma_m,mma_n), tCsC(_,pipe_m,pipe_n)); -+ } -+ } -+ -+ // Step 2. Wait for SMEM writes to complete -+ synchronize(); -+ -+ // Step 3. Copy from SMEM into a fragment -+ copy(tD, tDsC, tDrC); -+ -+ // Step 4. Wait for SMEM reads to complete -+ synchronize(); -+ -+ Tensor tDgDmn = tDgD(_,_,_,step_m,step_n); -+ Tensor tDcDmn = tDcD(_,_,_,step_m,step_n); -+ -+ if (epilogue_op.is_source_needed()) { -+ // source is needed -+ Tensor tDgCmn = tDgC(_,_,_,step_m,step_n); -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < size<1>(tDgDmn); ++m) -+ { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < size<2>(tDgDmn); ++n) -+ { -+ // Predication -+ if (get<0>(tDcDmn(0,m,n)) < get<0>(residue_mnk) && -+ get<1>(tDcDmn(0,m,n)) < get<1>(residue_mnk)) -+ { -+ // Step 5. Elementwise operation with conversion -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < size<0>(tDrC); ++i) { -+ tDrD(i,m,n) = epilogue_op(tDrC(i,m,n), tDgCmn(i,m,n)); -+ } -+ // Step 6. Copy to GMEM -+ copy(CopyAtomR2G{}, tDrD(_,m,n), tDgDmn(_,m,n)); -+ } -+ } -+ } -+ } -+ else { -+ // source is not needed, avoid load and lift compute -+ -+ // Step 5. Elementwise operation with conversion -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < size(tDrC); ++i) { -+ tDrD(i) = epilogue_op(tDrC(i)); -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < size<1>(tDgDmn); ++m) -+ { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < size<2>(tDgDmn); ++n) -+ { -+ // Predication -+ if (get<0>(tDcDmn(0,m,n)) < get<0>(residue_mnk) && -+ get<1>(tDcDmn(0,m,n)) < get<1>(residue_mnk)) -+ { -+ // Step 6. Copy to GMEM -+ copy(CopyAtomR2G{}, tDrD(_,m,n), tDgDmn(_,m,n)); -+ } -+ } -+ } -+ } -+ } -+ } -+ } -+ -+private: -+ Params params; -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace collective -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/dispatch_policy.hpp b/3rdparty/cutlass/include/cutlass/epilogue/dispatch_policy.hpp -new file mode 100644 -index 0000000..de318d5 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/dispatch_policy.hpp -@@ -0,0 +1,39 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -+ * -+ * Redistribution and use in source and binary forms, with or without modification, are permitted -+ * provided that the following conditions are met: -+ * * Redistributions of source code must retain the above copyright notice, this list of -+ * conditions and the following disclaimer. -+ * * Redistributions in binary form must reproduce the above copyright notice, this list of -+ * conditions and the following disclaimer in the documentation and/or other materials -+ * provided with the distribution. -+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used -+ * to endorse or promote products derived from this software without specific prior written -+ * permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR -+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND -+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, -+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; -+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass::epilogue { -+ -+////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Collective Epilogue Policies -+// -+ -+////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass::epilogue -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/activation.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/activation.h -new file mode 100644 -index 0000000..484f2cc ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/activation.h -@@ -0,0 +1,705 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief This extends the contents of cutlass/functional.h with frequently used activation functions. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/constants.h" -+#include "cutlass/complex.h" -+#include "cutlass/array.h" -+#include "cutlass/half.h" -+#include "cutlass/functional.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template -+struct LinearCombinationGenericParams { -+ T alpha; ///< scales accumulators -+ T beta; ///< scales source tensor -+ T const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory -+ T const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ LinearCombinationGenericParams(): -+ alpha(T(1)), -+ beta(T(0)), -+ alpha_ptr(nullptr), -+ beta_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ LinearCombinationGenericParams( -+ T alpha, -+ T beta = T(0) -+ ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ LinearCombinationGenericParams( -+ T const *alpha_ptr, -+ T const *beta_ptr = nullptr -+ ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Identity operator -+template -+struct Identity { -+ static const bool kIsHeavy=false; -+ -+ CUTLASS_HOST_DEVICE -+ T operator()(T value) const { -+ return value; -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &value, Params const ¶ms_) const { -+ return this->operator()(value); -+ } -+}; -+ -+template -+struct Identity > { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &value) const { -+ return value; -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &value, Params const ¶ms_) const { -+ return this->operator()(value); -+ } -+}; -+ -+/// ReLu operator - propagates NaNs -+/// Always put threshold in the right hand side of max to propagate NaN. -+template -+struct ReLu { -+ static const bool kIsHeavy=false; -+ CUTLASS_HOST_DEVICE -+ T operator()(T const & threshold, T value) const { -+ maximum mx; -+ -+ return mx(value, threshold); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ T operator()(T value) const { -+ maximum mx; -+ -+ return mx(value, T(0)); -+ } -+ -+ /// Host-constructable parameters structure -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ T operator()(T value, Params const ¶ms_) const { -+ return this->operator()(value); -+ } -+}; -+ -+template -+struct ReLu> { -+ static const bool kIsHeavy=false; -+ CUTLASS_HOST_DEVICE -+ Array operator()(T const & threshold, Array const &frag) const { -+ maximum > mx; -+ -+ return mx(frag, threshold); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &frag) const { -+ maximum > mx; -+ return mx(frag, T(0)); -+ } -+ -+ /// Host-constructable parameters structure -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &frag, Params const ¶ms_) const { -+ return this->operator()(frag); -+ } -+}; -+ -+// Leaky Relu operator -+template -+struct LeakyReLU { -+ -+ struct Params: LinearCombinationGenericParams { -+ T leaky_alpha; ///< leaky_alpha -+ -+ // Methods -+ using LinearCombinationGenericParams::LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ LinearCombinationGenericParams(), -+ leaky_alpha(T(1)) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ T alpha, -+ T beta, -+ T leaky_alpha = T(1) -+ ): LinearCombinationGenericParams(alpha, beta), leaky_alpha(leaky_alpha) {} -+ }; -+ -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &value, T const & alpha_recip) const { -+ T res = value > T(0) ? value : value * alpha_recip; -+ return res; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &value, Params const ¶ms_) const { -+ this->operator()(value, params_.leaky_alpha); -+ } -+}; -+ -+template -+struct LeakyReLU > { -+ -+ struct Params: LinearCombinationGenericParams { -+ T leaky_alpha; ///< leaky_alpha -+ using LinearCombinationGenericParams::LinearCombinationGenericParams; -+ -+ // Methods -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ LinearCombinationGenericParams(), -+ leaky_alpha(T(1)) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ T alpha, -+ T beta, -+ T leaky_alpha = T(1) -+ ): LinearCombinationGenericParams(alpha, beta), leaky_alpha(leaky_alpha) {} -+ }; -+ -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &value, T const & alpha_recip) const { -+ Array y; -+ LeakyReLU leaky_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < int(value.size()); ++i) { -+ y[i] = leaky_op(value[i], alpha_recip); -+ } -+ -+ return y; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &value, Params const ¶ms_) const { -+ return this->operator()(value, params_.leaky_alpha); -+ } -+}; -+ -+// Tanh operator -+template -+struct Tanh { -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &scalar) const { -+ return fast_tanh(scalar); -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &scalar, Params const ¶ms_) const { -+ return this->operator()(scalar); -+ } -+}; -+ -+template -+struct Tanh > { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &value) const { -+ Array y; -+ Tanh tanh_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ y[i] = tanh_op(value[i]); -+ } -+ -+ return y; -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &value, Params const ¶ms_) const { -+ return this->operator()(value); -+ } -+}; -+ -+template -+struct Tanh> { -+ using T = half_t; -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const& z) const { -+ fast_tanh_op> tanh; -+ return tanh(z); -+ -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &value, Params const ¶ms_) const { -+ return this->operator()(value); -+ } -+}; -+ -+// Sigmoid operator -+template -+struct Sigmoid { -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &scalar) const { -+ return T(1) / (T(1) + fast_exp(-scalar)); -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &scalar, Params const ¶ms_) const { -+ return this->operator()(scalar); -+ } -+}; -+ -+template -+struct Sigmoid > { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &value) const { -+ Array y; -+ Sigmoid sigmoid_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ y[i] = sigmoid_op(value[i]); -+ } -+ -+ return y; -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &value, Params const ¶ms_) const { -+ return this->operator()(value); -+ } -+}; -+ -+template -+struct Sigmoid> { -+ using T = half_t; -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const& z) const { -+ plus> add; -+ -+#if defined(CUTLASS_USE_TANH_FOR_SIGMOID) -+ multiplies> mul; -+ fast_tanh_op> tanh; -+ return mul(add(tanh(mul(z, cutlass::constants::half())), cutlass::constants::one()), -+ cutlass::constants::half()); -+#else -+ divides> div; -+ negate> neg; -+ fast_exp_op> fast_exp; -+ return div(cutlass::constants::one(), -+ add(cutlass::constants::one(), -+ fast_exp(neg(z)))); -+#endif -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &z, Params const ¶ms_) const { -+ return this->operator()(z); -+ } -+}; -+ -+// SiLu (swish) operator introduced by Elfwing et al. in the following paper -+// "Sigmoid-Weighted Linear Units for Neural Network Function Approximation in Reinforcement Learning" (2017) -+// https://arxiv.org/pdf/1702.03118.pdf -+// It is used in EfficientNet and YOLOv5, for example. -+// Reference: https://pytorch.org/docs/stable/generated/torch.nn.SiLU.html -+template -+struct SiLu { -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &scalar) const { -+ Sigmoid sigmoid; -+ return scalar * sigmoid(scalar); -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &scalar, Params const ¶ms_) const { -+ return this->operator()(scalar); -+ } -+}; -+ -+template -+struct SiLu> { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &value) const { -+ Sigmoid> sigmoid_op; -+ multiplies> mul; -+ return mul(value, sigmoid_op(value)); -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &value, Params const ¶ms_) const { -+ return this->operator()(value); -+ } -+}; -+ -+// Hardswish operator introduced by Howard et al. in the following paper -+// "Searching for MobileNetV3" (2019) -+// https://arxiv.org/pdf/1905.02244.pdf -+// It is used in models based on MobilenetNetV3. -+// Reference: https://pytorch.org/docs/stable/generated/torch.nn.Hardswish.html -+template -+struct HardSwish { -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &x) const { -+ minimum mn; -+ maximum mx; -+ T relu6 = mn(mx(x + T(3), T(0)), T(6)); -+ return x * relu6 / T(6); -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &x, Params const ¶ms_) const { -+ return this->operator()(x); -+ } -+}; -+ -+template <> -+struct HardSwish { -+ using T = float; -+ -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &x) const { -+ minimum mn; -+ maximum mx; -+ T relu6 = mn(mx(x + T(3), T(0)), T(6)); -+ return x * relu6 * 0.16666667f; -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &x, Params const ¶ms_) const { -+ return this->operator()(x); -+ } -+}; -+ -+template -+struct HardSwish > { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &value) const { -+ Array y; -+ HardSwish hardswish_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ y[i] = hardswish_op(value[i]); -+ } -+ -+ return y; -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &x, Params const ¶ms_) const { -+ return this->operator()(x); -+ } -+}; -+ -+template -+struct HardSwish > { -+ using T = half_t; -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &value) const { -+ minimum > mn; -+ maximum > mx; -+ multiplies > mul; -+ plus > add; -+ -+ return mul(mul(mn(mx(add(value, T(3)), T(0)), T(6)), value), T(0.16666667f)); -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &x, Params const ¶ms_) const { -+ return this->operator()(x); -+ } -+}; -+ -+// -+// GELU function definitions implemented as described by -+// Hendrycks, D., and Gimpel, K. in -+// "Gaussian Error Linear Units (GELUs)." (2020) -+// https://arxiv.org/pdf/1606.08415.pdf -+// -+// Floating-point constants are Taylor coefficients described in the paper. -+// -+ -+// GELU operator -+template -+struct GELU { -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &scalar) const { -+ return T(cutlass::constants::half() * scalar * -+ (cutlass::constants::one() + (T)erff((float)(scalar / cutlass::constants::root_two())))); -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &scalar, Params const ¶ms_) const { -+ return this->operator()(scalar); -+ } -+}; -+ -+template <> -+struct GELU { -+ CUTLASS_HOST_DEVICE -+ float operator()(float const &scalar) const { -+ return cutlass::constants::half() * scalar * -+ (cutlass::constants::one() + erff( scalar / cutlass::constants::root_two() )); -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ float operator()(float const &scalar, Params const ¶ms_) const { -+ return this->operator()(scalar); -+ } -+}; -+ -+template <> -+struct GELU { -+ CUTLASS_HOST_DEVICE -+ double operator()(double const &scalar) const { -+ return cutlass::constants::half() * scalar * -+ (cutlass::constants::one() + erf( scalar / cutlass::constants::root_two() )); -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ double operator()(double const &scalar, Params const ¶ms_) const { -+ return this->operator()(scalar); -+ } -+}; -+ -+template -+struct GELU > { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &value) const { -+ Array y; -+ GELU gelu_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ y[i] = gelu_op(value[i]); -+ } -+ -+ return y; -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &value, Params const ¶ms_) const { -+ return this->operator()(value); -+ } -+}; -+ -+// GELU operator implemented using the Taylor series approximation -+template -+struct GELU_taylor { -+ static const bool kIsHeavy=true; -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &z) const { -+ -+ T k0 = T(0.7978845608028654); -+ T k1 = T(0.044715); -+ -+ return T(cutlass::constants::half() * z * -+ (cutlass::constants::one() + fast_tanh(k0 * z * (cutlass::constants::one() + k1 * z * z)))); -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &scalar, Params const ¶ms_) const { -+ return this->operator()(scalar); -+ } -+}; -+ -+template -+struct GELU_taylor > { -+ static const bool kIsHeavy=true; -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &z) const { -+ -+ using T = half_t; -+ Array y; -+ -+ half_t k0 = half_t(0.7978845608028654); -+ half_t k1 = half_t(0.044715); -+ -+ multiply_add> fma; -+ multiplies> mul; -+ plus> add; -+ -+ fast_tanh_op> tanh; -+ -+ Array u = mul(mul(k0, z), fma(mul(k1, z), z, cutlass::constants::one())); -+ -+ y = mul(mul(z, cutlass::constants::half()), add(cutlass::constants::one(), tanh(u))); -+ -+ return y; -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &value, Params const ¶ms_) const { -+ return this->operator()(value); -+ } -+}; -+ -+template -+struct GELU_taylor > { -+ static const bool kIsHeavy=true; -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &value) const { -+ Array y; -+ GELU_taylor gelu_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ y[i] = gelu_op(value[i]); -+ } -+ -+ return y; -+ } -+ -+ using Params = LinearCombinationGenericParams; -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &value, Params const ¶ms_) const { -+ return this->operator()(value); -+ } -+}; -+ -+/// Computes backwards pass for GELU operator assuming d_t is the layer gradient and -+/// z is computed from the forward pass. -+template -+struct dGELU { -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &d_t, T const &z) const { -+ -+ T k0 = T(0.7978845608028654); -+ T k1 = T(0.044715); -+ T k2 = T(0.1070322243); -+ -+ T tanh_out = fast_tanh(k0 * z * (1 + k1 * z * z)); -+ -+ T ff = constants::half() * z * ((1 - tanh_out * tanh_out) * (k0 + k2 * z * z)) + -+ constants::half() * (1 + tanh_out); -+ -+ return ff * d_t; -+ } -+}; -+ -+template -+struct dGELU > { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &d_t, Array const &z) const { -+ Array y; -+ dGELU gelu_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ y[i] = gelu_op(d_t[i], z[i]); -+ } -+ -+ return y; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/conversion_op.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/conversion_op.h -new file mode 100644 -index 0000000..98e3beb ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/conversion_op.h -@@ -0,0 +1,132 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Functor performing conversion operations used by epilogues. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Converts the result without other operations -+/// -+template < -+ typename ElementOutput_, ///< Data type used to load and store tensors -+ int Count, ///< Number of elements computed per operation -+ typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest -+> -+class Convert { -+public: -+ -+ using ElementOutput = ElementOutput_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementAccumulator_; -+ -+ static int const kCount = Count; -+ -+ using FragmentOutput = Array; -+ using FragmentAccumulator = Array; -+ using ComputeFragment = FragmentAccumulator; -+ -+ static FloatRoundStyle const kRound = Round; -+ -+ static bool const kIsHeavy = false; -+ -+ /// Host-constructable parameters structure -+ struct Params { -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ }; -+ -+public: -+ -+ /// Constructs the function object, possibly loading from pointers in host memory -+ CUTLASS_HOST_DEVICE -+ Convert(Params const ¶ms = Params()) { -+ -+ } -+ -+ /// Functionally required for serial reduction in the epilogue -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ -+ } -+ -+ /// Returns true if source is needed based on state of runtime arguments -+ CUTLASS_HOST_DEVICE -+ constexpr bool is_source_needed() const { -+ return false; -+ } -+ -+ /// Constexpr function to enable the compiler to optimize away the source loading if it is -+ /// never needed. -+ CUTLASS_HOST_DEVICE -+ constexpr bool is_source_ever_needed() const { -+ return false; -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentOutput const &source = FragmentOutput(), -+ ElementCompute uniform = ElementCompute(0)) const { -+ -+ // Convert to destination numeric type -+ NumericArrayConverter destination_converter; -+ -+ return destination_converter(accumulator); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination.h -new file mode 100644 -index 0000000..0c4b384 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination.h -@@ -0,0 +1,306 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Functor performing linear combination operations used by epilogues. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/epilogue/thread/scale_type.h" -+#include "cutlass/epilogue/thread/linear_combination_params.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies a linear combination operator to an array of elements. -+/// -+/// D = alpha * accumulator + beta * source + uniform -+/// -+template < -+ typename ElementOutput_, ///< Data type used to load and store tensors -+ int Count, ///< Number of elements computed per operation. -+ ///< Usually it is 128/sizeof_bits, -+ ///< but we use 64 or 32 sometimes when there are not enough data to store -+ typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type -+ typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination -+ ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest, -+ typename ElementSource_ = ElementOutput_ -+> -+class LinearCombination { -+public: -+ -+ using ElementOutput = ElementOutput_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementCompute_; -+ using ElementC = ElementSource_; -+ using ElementD = ElementOutput_; -+ -+ static int const kCount = Count; -+ static const ScaleType::Kind kScale = Scale; -+ using FragmentOutput = Array; -+ using FragmentAccumulator = Array; -+ using ComputeFragment = Array; -+ -+ using ParamsBase = LinearCombinationParams; -+ static FloatRoundStyle const kRound = Round; -+ -+ /// Host-constructable parameters structure -+ struct Params : ParamsBase{ -+ ElementCompute alpha; ///< scales accumulators -+ ElementCompute beta; ///< scales source tensor -+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory -+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ ParamsBase( -+ ElementCompute(1), -+ ElementCompute(0) -+ ), -+ alpha(ElementCompute(1)), -+ beta(ElementCompute(0)), -+ alpha_ptr(nullptr), -+ beta_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute alpha, -+ ElementCompute beta -+ ): -+ ParamsBase(alpha, beta), -+ alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute alpha -+ ): -+ ParamsBase(alpha, ElementCompute(0)), -+ alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute const *alpha_ptr, -+ ElementCompute const *beta_ptr -+ ): -+ ParamsBase(*alpha_ptr, *beta_ptr), -+ alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute const *alpha_ptr -+ ): -+ ParamsBase(*alpha_ptr, ElementCompute(0)), -+ alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ParamsBase const& base -+ ): ParamsBase(base), alpha_ptr(nullptr), beta_ptr(nullptr) { -+ #if defined(__CUDA_ARCH__) -+ alpha = reinterpret_cast(base.alpha_data); -+ beta = reinterpret_cast(base.beta_data); -+ #else -+ memcpy( alpha, base.alpha_data, sizeof(ElementCompute) ); -+ memcpy( beta, base.alpha_data, sizeof(ElementCompute) ); -+ #endif -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ ElementCompute alpha_; -+ ElementCompute beta_; -+ -+public: -+ -+ /// Constructs the function object, possibly loading from pointers in host memory -+ CUTLASS_HOST_DEVICE -+ LinearCombination(Params const ¶ms) { -+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); -+ beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); -+ } -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ if (Scale == ScaleType::NoBetaScaling) return true; -+ -+ if (Scale == ScaleType::OnlyAlphaScaling) return false; -+ -+ if (Scale == ScaleType::Nothing) return false; -+ -+ return beta_ != ElementCompute(0); -+ } -+ -+ /// Functionally required for serial reduction in the epilogue -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ if (k_partition) { -+ beta_ = ElementCompute(1); -+ } -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentOutput const &source) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter source_converter; -+ NumericArrayConverter accumulator_converter; -+ -+ // Convert to destination numeric type -+ NumericArrayConverter destination_converter; -+ -+ ComputeFragment converted_source = source_converter(source); -+ ComputeFragment converted_accumulator = accumulator_converter(accumulator); -+ -+ if (Scale == ScaleType::Nothing) -+ return destination_converter(converted_accumulator); -+ -+ // Perform binary operations -+ ComputeFragment intermediate; -+ -+ multiplies mul_add_source; -+ multiply_add mul_add_accumulator; -+ -+ if (Scale == ScaleType::NoBetaScaling) -+ intermediate = converted_source; -+ else -+ intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform -+ -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ -+ return destination_converter(intermediate); -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter accumulator_converter; -+ -+ // Convert to destination numeric type -+ NumericArrayConverter destination_converter; -+ -+ ComputeFragment converted_accumulator = accumulator_converter(accumulator); -+ -+ if (Scale == ScaleType::Nothing) -+ return destination_converter(converted_accumulator); -+ -+ // Perform binary operations -+ ComputeFragment intermediate; -+ multiplies mul_accumulator; -+ -+ intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum -+ -+ return destination_converter(intermediate); -+ } -+ -+ // -+ // Specializations for scalar (for use with cute::collective::DefaultEpilogue) -+ // -+ CUTLASS_HOST_DEVICE -+ ElementD operator()(ElementAccumulator const accumulator, ElementC const source) const { -+ // Convert everything to Compute type, do compute, and then store to output type -+ NumericConverter accumulator_converter; -+ [[maybe_unused]] NumericConverter source_converter; -+ NumericConverter destination_converter; -+ -+ // Convert to destination numeric type -+ -+ ElementCompute converted_accumulator = accumulator_converter(accumulator); -+ if constexpr (Scale == ScaleType::Nothing) { -+ return destination_converter(converted_accumulator); -+ } -+ -+ // Perform binary operations -+ ElementCompute intermediate; -+ multiplies multiply; -+ multiply_add madd; -+ -+ if constexpr (Scale == ScaleType::NoBetaScaling) { -+ intermediate = source_converter(source); -+ } -+ else { -+ intermediate = multiply(beta_, source); // X = beta * C + uniform -+ } -+ -+ intermediate = madd(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ return destination_converter(intermediate); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ ElementD operator()(ElementAccumulator const accumulator) const { -+ // Convert everything to Compute type, do compute, and then store to output type -+ NumericConverter accumulator_converter; -+ NumericConverter destination_converter; -+ ElementCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Convert to destination numeric type -+ if constexpr (Scale == ScaleType::Nothing) { -+ return destination_converter(converted_accumulator); -+ } -+ -+ // Perform binary operations -+ ElementCompute intermediate; -+ multiplies multiply; -+ -+ intermediate = multiply(alpha_, accumulator); // D = alpha * Accum -+ return destination_converter(intermediate); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h -new file mode 100644 -index 0000000..6892efb ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h -@@ -0,0 +1,260 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Functor performing linear combination operations used by epilogues. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+ -+#include "cutlass/epilogue/thread/activation.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// This base class is meant to define the concept required of the -+/// EpilogueWithBroadcast::OutputOp -+template < -+ typename ElementC_, -+ typename ElementAccumulator_, -+ typename ElementCompute_, -+ typename ElementZ_, -+ typename ElementT_, -+ int ElementsPerAccess, -+ typename ElementwiseOp_ = Identity, -+ typename BinaryOp_ = plus -+> -+class LinearCombinationBiasElementwise { -+public: -+ -+ using ElementOutput = ElementC_; -+ using ElementC = ElementC_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementCompute_; -+ using ElementZ = ElementZ_; -+ using ElementT = ElementT_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static int const kCount = kElementsPerAccess; -+ -+ using ElementwiseOp = ElementwiseOp_; -+ using BinaryOp = BinaryOp_; -+ -+ // Indicates that this epilogue applies only one binary operation -+ static bool const kIsSingleSource = true; -+ -+ using FragmentAccumulator = Array; -+ using FragmentCompute = Array; -+ using FragmentC = Array; -+ using FragmentZ = Array; -+ using FragmentT = Array; -+ -+ using FragmentOutput = FragmentZ; -+ -+ static bool const kIsHeavy = ElementwiseOp::kIsHeavy; -+ -+ /// If true, the 'Z' tensor is stored -+ static bool const kStoreZ = true; -+ -+ /// If true, the 'T' tensor is stored -+ static bool const kStoreT = true; -+ -+ /// Host-constructable parameters structure -+ struct Params { -+ -+ ElementCompute alpha; ///< scales accumulators -+ ElementCompute beta; ///< scales source tensor -+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory -+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ alpha(ElementCompute(1)), -+ beta(ElementCompute(0)), -+ alpha_ptr(nullptr), -+ beta_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute alpha, -+ ElementCompute beta -+ ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute alpha -+ ): alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute const *alpha_ptr, -+ ElementCompute const *beta_ptr -+ ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute const *alpha_ptr -+ ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr) { -+ -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ ElementCompute alpha_; -+ ElementCompute beta_; -+ bool skip_elementwise_; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructor from Params -+ CUTLASS_HOST_DEVICE -+ LinearCombinationBiasElementwise(Params const ¶ms) { -+ -+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); -+ beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); -+ skip_elementwise_ = false; -+ } -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ return beta_ != ElementCompute(0); -+ } -+ -+ /// Functionally required for serial reduction in the epilogue -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ if (k_partition) { -+ beta_ = ElementCompute(1); -+ } -+ -+ if (k_partition != k_partition_count - 1) { -+ skip_elementwise_ = true; -+ } -+ } -+ -+ /// Applies the operation when is_source_needed() is true -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentZ &frag_Z, -+ FragmentT &frag_T, -+ FragmentAccumulator const &AB, -+ FragmentC const &frag_C, -+ FragmentCompute const &V) const { -+ -+ ElementwiseOp elementwise_op; -+ BinaryOp binary_op; -+ -+ FragmentCompute tmp_Accum = NumericArrayConverter()(AB); -+ FragmentCompute tmp_C = NumericArrayConverter()(frag_C); -+ FragmentCompute result_Z; -+ FragmentCompute result_T; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kElementsPerAccess; ++i) { -+ ElementCompute z = binary_op(alpha_ * tmp_Accum[i] + beta_ * tmp_C[i], V[i]); -+ result_T[i] = z; -+ result_Z[i] = skip_elementwise_ ? z : elementwise_op(z); -+ } -+ -+ NumericArrayConverter convert_z; -+ frag_Z = convert_z(result_Z); -+ -+ NumericArrayConverter convert_t; -+ frag_T = convert_t(result_T); -+ } -+ -+ /// Applies the operation when is_source_needed() is false -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentZ &frag_Z, -+ FragmentT &frag_T, -+ FragmentAccumulator const &AB, -+ FragmentCompute const &V) const { -+ -+ ElementwiseOp elementwise_op; -+ BinaryOp binary_op; -+ -+ FragmentCompute tmp_Accum = NumericArrayConverter()(AB); -+ FragmentCompute result_Z; -+ FragmentCompute result_T; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kElementsPerAccess; ++i) { -+ ElementCompute z = binary_op(alpha_ * tmp_Accum[i], V[i]); -+ result_T[i] = z; -+ result_Z[i] = skip_elementwise_ ? z : elementwise_op(z); -+ } -+ -+ NumericArrayConverter convert_z; -+ frag_Z = convert_z(result_Z); -+ -+ NumericArrayConverter convert_t; -+ frag_T = convert_t(result_T); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_bias_relu.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_bias_relu.h -new file mode 100644 -index 0000000..b095c91 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_bias_relu.h -@@ -0,0 +1,450 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Functor performing linear combination operations used by epilogues. -+*/ -+ -+#pragma once -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/epilogue/thread/activation.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template -+struct ArrayMaximum { -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()( -+ Array const &lhs, -+ Array const &rhs) const { -+ -+ Array result; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < ElementsPerAccess; ++i) { -+ result[i] = fmax(lhs[i], rhs[i]); -+ } -+ -+ return result; -+ } -+}; -+ -+template -+struct ArrayMaximum { -+ -+ CUTLASS_DEVICE -+ Array operator()( -+ Array const &lhs, -+ Array const &rhs) const { -+ -+ Array result; -+ -+ #if __CUDA_ARCH__ >= 800 -+ int const kVectorCount = ElementsPerAccess / 2; -+ -+ -+ __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(lhs.raw_data()); -+ __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(rhs.raw_data()); -+ __half2 *res_ptr = reinterpret_cast<__half2 *>(result.raw_data()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kVectorCount; ++i) { -+ res_ptr[i] = __hmax2(lhs_ptr[i], rhs_ptr[i]); -+ } -+ -+ #else -+ __half const *lhs_ptr = reinterpret_cast<__half const *>(lhs.raw_data()); -+ __half const *rhs_ptr = reinterpret_cast<__half const *>(rhs.raw_data()); -+ __half *res_ptr = reinterpret_cast<__half *>(result.raw_data()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < ElementsPerAccess; ++i) { -+ res_ptr[i] = ((lhs_ptr[i] < rhs_ptr[i]) ? rhs_ptr[i] : lhs_ptr[i]); -+ } -+ -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_DEVICE -+ Array operator()( -+ Array const &lhs, -+ half_t const &rhs) const { -+ -+ Array result; -+ -+ #if __CUDA_ARCH__ >= 800 -+ int const kVectorCount = ElementsPerAccess / 2; -+ -+ -+ __half rhs_raw = reinterpret_cast<__half const &>(rhs); -+ __half2 rhs_pair = __half2half2(rhs_raw); -+ -+ __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(lhs.raw_data()); -+ __half2 *res_ptr = reinterpret_cast<__half2 *>(result.raw_data()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kVectorCount; ++i) { -+ res_ptr[i] = __hmax2(lhs_ptr[i], rhs_pair); -+ } -+ -+ #else -+ -+ __half const *lhs_ptr = reinterpret_cast<__half const *>(lhs.raw_data()); -+ __half const rhs_raw = reinterpret_cast<__half const &>(rhs); -+ __half *res_ptr = reinterpret_cast<__half *>(result.raw_data()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < ElementsPerAccess; ++i) { -+ res_ptr[i] = ((lhs_ptr[i] < rhs_raw) ? rhs_raw : lhs_ptr[i]); -+ } -+ -+ #endif -+ -+ return result; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct ReluConditional { -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ bool conditional[], -+ Array const &fragment, -+ Element threshold) const { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < ElementsPerAccess; ++i) { -+ conditional[i] = !(fragment[i] < threshold); -+ } -+ } -+}; -+ -+template -+struct ReluConditional { -+ -+ CUTLASS_DEVICE -+ void operator()( -+ bool conditional[], -+ Array const &fragment, -+ half_t threshold) const { -+ -+ __half y = reinterpret_cast<__half const &>(threshold); -+ __half const *x = reinterpret_cast<__half const *>(fragment.raw_data()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < ElementsPerAccess; ++i) { -+ conditional[i] = !__hlt(x[i], y); -+ } -+ } -+}; -+ -+} // namespace detail -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// This is a partial specialization for fused Bias and ReLU. It supports the option of packing -+/// ReLU conditionals in a bit vector that may be used by backwards passes as an optimization. -+/// -+/// This class can only be used with cutlass::epilogue::threadblock::EpilogueWithBroadcast<>. -+/// -+/// This base class is meant to define the concept required of the -+/// EpilogueWithBroadcast::OutputOp -+template < -+ typename ElementC_, -+ typename ElementAccumulator_, -+ typename ElementCompute_, -+ typename ElementZ_, -+ int ElementsPerAccess, -+ bool StoreT = true -+> -+class LinearCombinationBiasRelu { -+public: -+ -+ using ElementOutput = ElementC_; -+ using ElementC = ElementC_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementCompute_; -+ using ElementZ = ElementZ_; -+ -+ using ElementT = uint1b_t; -+ -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static int const kCount = kElementsPerAccess; -+ -+ using ElementwiseOp = ReLu; -+ using BinaryOp = plus; -+ -+ // Indicates that this epilogue applies only one binary operation -+ static bool const kIsSingleSource = true; -+ -+ using FragmentAccumulator = Array; -+ using FragmentCompute = Array; -+ using FragmentC = Array; -+ using FragmentZ = Array; -+ using FragmentT = Array; -+ -+ /// If true, the 'Z' tensor is stored -+ static bool const kStoreZ = true; -+ -+ /// If true, the 'T' tensor is stored -+ static bool const kStoreT = StoreT; -+ -+ /// Host-constructable parameters structure -+ struct Params { -+ -+ ElementCompute alpha; ///< scales accumulators -+ ElementCompute beta; ///< scales source tensor -+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory -+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory -+ ElementZ threshold; ///< ReLu threshold -+ -+ // -+ // Methods -+ // -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ alpha(ElementCompute(1)), -+ beta(ElementCompute()), -+ alpha_ptr(nullptr), -+ beta_ptr(nullptr), -+ threshold(ElementCompute()) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute alpha, -+ ElementCompute beta, -+ ElementCompute threshold_ = ElementCompute() -+ ): -+ alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { -+ -+ NumericConverter convert_threshold; -+ -+ threshold = convert_threshold(threshold_); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute alpha -+ ): alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr), threshold(ElementZ()) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute const *alpha_ptr, -+ ElementCompute const *beta_ptr, -+ ElementCompute threshold_ = ElementCompute() -+ ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { -+ -+ NumericConverter convert_threshold; -+ -+ threshold = convert_threshold(threshold_); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute const *alpha_ptr -+ ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr), threshold(ElementZ()) { -+ } -+ -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ ElementCompute alpha_; -+ ElementCompute beta_; -+ ElementZ threshold_; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructor from Params -+ CUTLASS_HOST_DEVICE -+ LinearCombinationBiasRelu(Params const ¶ms) { -+ -+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); -+ beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); -+ threshold_ = params.threshold; -+ } -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ return beta_ != ElementCompute(0); -+ } -+ -+ /// Functionally required for serial reduction in the epilogue -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ if (k_partition) { -+ beta_ = ElementCompute(1); -+ } -+ -+ if (k_partition != k_partition_count - 1) { -+ // set to NaN to make ReLU no-op for all except last k partitions -+ int64_t allones = -1; -+ threshold_ = reinterpret_cast(allones); -+ } -+ } -+ -+ /// Applies the operation when is_source_needed() is true -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentZ &frag_Z, -+ FragmentT &frag_T, -+ FragmentAccumulator const &AB, -+ FragmentC const &frag_C, -+ FragmentCompute const &V) const { -+ -+ BinaryOp binary_op; -+ -+ FragmentCompute tmp_Accum = NumericArrayConverter()(AB); -+ FragmentCompute tmp_C = NumericArrayConverter()(frag_C); -+ FragmentCompute result_Z; -+ -+ bool conditions[kElementsPerAccess]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kElementsPerAccess; ++i) { -+ -+ ElementCompute z = alpha_ * tmp_Accum[i]; -+ z += beta_ * tmp_C[i]; -+ -+ z = binary_op(z, V[i]); -+ result_Z[i] = z; -+ } -+ -+ NumericArrayConverter convert_z; -+ frag_Z = convert_z(result_Z); -+ -+ // -+ // Compute condition -+ // -+ -+ detail::ReluConditional relu_conditional; -+ relu_conditional(conditions, frag_Z, threshold_); -+ -+ detail::ArrayMaximum maximum_op; -+ frag_Z = maximum_op(frag_Z, threshold_); -+ -+ if (kStoreT) { -+ PackPredicates pack_predicates; -+ frag_T = pack_predicates(conditions); -+ } -+ } -+ -+ /// Applies the operation when is_source_needed() is false -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentZ &frag_Z, -+ FragmentT &frag_T, -+ FragmentAccumulator const &AB, -+ FragmentCompute const &V) const { -+ -+ BinaryOp binary_op; -+ -+ FragmentCompute tmp_Accum = NumericArrayConverter()(AB); -+ FragmentCompute result_Z; -+ -+ bool conditions[kElementsPerAccess]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kElementsPerAccess; ++i) { -+ ElementCompute z = binary_op(alpha_ * tmp_Accum[i], V[i]); -+ result_Z[i] = z; -+ } -+ -+ NumericArrayConverter convert_z; -+ frag_Z = convert_z(result_Z); -+ -+ // -+ // Compute condition -+ // -+ -+ detail::ReluConditional relu_conditional; -+ relu_conditional(conditions, frag_Z, threshold_); -+ -+ detail::ArrayMaximum maximum_op; -+ frag_Z = maximum_op(frag_Z, threshold_); -+ -+ // -+ // Compute conditions -+ // -+ -+ // -+ // Store -+ // -+ if (kStoreT) { -+ PackPredicates pack_predicates; -+ frag_T = pack_predicates(conditions); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_clamp.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_clamp.h -new file mode 100644 -index 0000000..fdfe171 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_clamp.h -@@ -0,0 +1,693 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Functor performing linear scaling operations used by epilogues. Values are clamped before -+ converting to the output element type. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+/// Single source of truth for whether to unroll for `LinearCombinationClamp()` -+constexpr bool LinearCombinationClampIsHeavy() { -+ return false; -+} -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies a linear combination operator to an array of elements then clamps the output before -+/// converting to the output element type. -+/// -+/// D = alpha * accumulator + beta * source + uniform -+/// -+template < -+ typename ElementOutput_, ///< Data type used to load and store tensors -+ int Count, ///< Number of elements computed per operation -+ ///< Usually it is 128/sizeof_bits, -+ ///< but we use 64 or 32 sometimes when there are not enough data to store -+ typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type -+ typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination -+ ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest -+> -+class LinearCombinationClamp { -+public: -+ -+ using ElementOutput = ElementOutput_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementCompute_; -+ -+ static int const kCount = Count; -+ -+ using FragmentOutput = Array; -+ using FragmentAccumulator = Array; -+ using ComputeFragment = Array; -+ -+ static FloatRoundStyle const kRound = Round; -+ -+ static bool const kIsHeavy = detail::LinearCombinationClampIsHeavy(); -+ -+ /// Host-constructable parameters structure -+ struct Params { -+ -+ ElementCompute alpha; ///< scales accumulators -+ ElementCompute beta; ///< scales source tensor -+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory -+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ alpha(ElementCompute(1)), -+ beta(ElementCompute(0)), -+ alpha_ptr(nullptr), -+ beta_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute alpha, -+ ElementCompute beta -+ ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute alpha -+ ): alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute const *alpha_ptr, -+ ElementCompute const *beta_ptr -+ ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute const *alpha_ptr -+ ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr) { -+ -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ ElementCompute alpha_; -+ ElementCompute beta_; -+ -+public: -+ -+ /// Constructs the function object, possibly loading from pointers in host memory -+ CUTLASS_HOST_DEVICE -+ LinearCombinationClamp(Params const ¶ms) { -+ -+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); -+ beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); -+ } -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ if (Scale == ScaleType::NoBetaScaling) return true; -+ -+ if (Scale == ScaleType::OnlyAlphaScaling) return false; -+ -+ if (Scale == ScaleType::Nothing) return false; -+ -+ return beta_ != ElementCompute(0); -+ } -+ -+ /// Functionally required for serial reduction in the epilogue -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ if (k_partition) { -+ beta_ = ElementCompute(1); -+ } -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentOutput const &source, -+ ElementCompute uniform = ElementCompute(0)) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter source_converter; -+ NumericArrayConverter accumulator_converter; -+ -+ ComputeFragment converted_source = source_converter(source); -+ ComputeFragment converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ -+ ComputeFragment intermediate; -+ -+ multiplies mul_add_source; -+ multiply_add mul_add_accumulator; -+ -+ minimum min_accumulator; -+ maximum max_accumulator; -+ -+ if (Scale == ScaleType::NoBetaScaling) { -+ intermediate = converted_source; -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ } else if (Scale == ScaleType::Nothing) { -+ intermediate = converted_accumulator; -+ } else { -+ intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ } -+ -+ /// Clamping constant value -+ ElementCompute const kClampMax = -+ ElementCompute(platform::numeric_limits::max()); -+ -+ ElementCompute const kClampMin = -+ ElementCompute(platform::numeric_limits::lowest()); -+ -+ intermediate = max_accumulator(intermediate, kClampMin); -+ intermediate = min_accumulator(intermediate, kClampMax); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter destination_converter; -+ -+ return destination_converter(intermediate); -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter accumulator_converter; -+ -+ ComputeFragment converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ -+ ComputeFragment intermediate; -+ -+ multiplies mul_accumulator; -+ -+ minimum min_accumulator; -+ maximum max_accumulator; -+ -+ if (Scale == ScaleType::Nothing) { -+ intermediate = converted_accumulator; -+ } else { -+ intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum -+ } -+ -+ /// Clamping constant value -+ ElementCompute const kClampMax = -+ ElementCompute(platform::numeric_limits::max()); -+ -+ ElementCompute const kClampMin = -+ ElementCompute(platform::numeric_limits::lowest()); -+ -+ intermediate = max_accumulator(intermediate, kClampMin); -+ intermediate = min_accumulator(intermediate, kClampMax); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter destination_converter; -+ -+ return destination_converter(intermediate); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Conditional guards to enable partial specialization for packed integers -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && ((__CUDACC_VER_MAJOR__ > 10) || ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2))) -+ -+/// Applies a linear combination operator to an array of elements then clamps the output before -+/// converting to the output element type. -+/// -+/// D = alpha * accumulator + beta * source + uniform -+/// -+template < -+ typename ElementOutput_, ///< Data type used to load and store tensors -+ int Count, ///< Number of elements computed per operation -+ ScaleType::Kind Scale, ///< Control Alpha and Beta scaling -+ FloatRoundStyle Round -+> -+class LinearCombinationClamp { -+public: -+ -+ using ElementOutput = ElementOutput_; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ -+ static_assert( -+ platform::numeric_limits::is_integer, -+ "This elementwise op expects the output to be int."); -+ -+ static int const kCount = Count; -+ -+ using FragmentOutput = Array; -+ using FragmentAccumulator = Array; -+ using ComputeFragment = Array; -+ -+ static FloatRoundStyle const kRound = Round; -+ -+ static bool const kIsHeavy = detail::LinearCombinationClampIsHeavy(); -+ -+ /// Host-constructable parameters structure -+ struct Params { -+ -+ ElementCompute alpha; ///< scales accumulators -+ ElementCompute beta; ///< scales source tensor -+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory -+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ alpha(ElementCompute(1)), -+ beta(ElementCompute(0)), -+ alpha_ptr(nullptr), -+ beta_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute alpha, -+ ElementCompute beta -+ ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute alpha -+ ): alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute const *alpha_ptr, -+ ElementCompute const *beta_ptr -+ ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute const *alpha_ptr -+ ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr) { -+ -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ ElementCompute alpha_; -+ ElementCompute beta_; -+ -+public: -+ -+ /// Constructs the function object, possibly loading from pointers in host memory -+ CUTLASS_HOST_DEVICE -+ LinearCombinationClamp(Params const ¶ms) { -+ -+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); -+ beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); -+ } -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ if (Scale == ScaleType::NoBetaScaling) return true; -+ -+ if (Scale == ScaleType::OnlyAlphaScaling) return false; -+ -+ if (Scale == ScaleType::Nothing) return false; -+ -+ return beta_ != ElementCompute(0); -+ } -+ -+ /// Functionally required for serial reduction in the epilogue -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ if (k_partition) { -+ beta_ = ElementCompute(1); -+ } -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentOutput const &source, -+ ElementCompute uniform = ElementCompute(0)) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter source_converter; -+ NumericArrayConverter accumulator_converter; -+ -+ ComputeFragment converted_source = source_converter(source); -+ ComputeFragment converted_accumulator = accumulator_converter(accumulator); -+ -+ // Compute linear scaling in floating point -+ ComputeFragment intermediate; -+ -+ multiplies mul_add_source; -+ multiply_add mul_add_accumulator; -+ -+ // Float min-max -+ if (Scale == ScaleType::NoBetaScaling) { -+ intermediate = converted_source; -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ } else if (Scale == ScaleType::Nothing) { -+ intermediate = converted_accumulator; -+ } else { -+ intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ } -+ -+ // Convert floats back to INT -+ FragmentAccumulator scaled_accumulator; -+ -+ NumericArrayConverter compute_converter; -+ -+ scaled_accumulator = compute_converter(intermediate); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter destination_converter; -+ -+ return destination_converter(scaled_accumulator); -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()(FragmentAccumulator const &accumulator) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter accumulator_converter; -+ -+ ComputeFragment converted_accumulator = accumulator_converter(accumulator); -+ -+ // Compute linear scaling in floating point -+ ComputeFragment intermediate; -+ -+ multiplies mul_add_accumulator; -+ -+ // Float min-max -+ if (Scale == ScaleType::Nothing) { -+ intermediate = converted_accumulator; -+ } else { -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum -+ } -+ -+ // Convert floats back to INT -+ FragmentAccumulator scaled_accumulator; -+ -+ NumericArrayConverter compute_converter; -+ -+ scaled_accumulator = compute_converter(intermediate); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter destination_converter; -+ -+ return destination_converter(scaled_accumulator); -+ } -+}; -+ -+#endif // Conditional guards to enable partial specialization for packed integers -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies a linear combination operator to an array of elements then clamps -+/// the output before converting to the output element type. -+/// -+/// D = alpha * accumulator + beta * source + uniform -+/// -+/// Note: The below method only when problem_size_K <= 256 for signed int8 gemm -+/// or problem_size_K <= 128 for unsigned int8 gemm. The default approach is -+/// above. -+/// TODO: Add logic to fallback to the default approach -+template < -+ /// Data type used to load and store< tensors -+ typename ElementOutput_, -+ /// Number of elements computed per operation -+ int Count, -+ ///< Control Alpha and Beta scaling -+ ScaleType::Kind Scale = ScaleType::Default, -+ /// Rounding mode -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest> -+class FastLinearCombinationClamp { -+ public: -+ using ElementOutput = ElementOutput_; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ -+ static_assert( -+ platform::numeric_limits::is_integer, -+ "This elementwise op expects the output to be int."); -+ -+ static int const kCount = Count; -+ -+ using FragmentOutput = Array; -+ using FragmentAccumulator = Array; -+ using ComputeFragment = Array; -+ -+ static FloatRoundStyle const kRound = Round; -+ -+ static bool const kIsHeavy = false; -+ -+ /// Host-constructable parameters structure -+ struct Params { -+ /// scales accumulators -+ ElementCompute alpha; -+ /// scales source tensor -+ ElementCompute beta; -+ /// pointer to accumulator scalar - if not null, loads it from memory -+ ElementCompute const *alpha_ptr; -+ /// pointer to source scalar - if not null, loads it from memory -+ ElementCompute const *beta_ptr; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params() -+ : alpha(ElementCompute(1)), -+ beta(ElementCompute(0)), -+ alpha_ptr(nullptr), -+ beta_ptr(nullptr) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(ElementCompute alpha, ElementCompute beta) -+ : alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(ElementCompute alpha) -+ : alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr) -+ : alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(ElementCompute const *alpha_ptr) -+ : alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ ElementCompute alpha_; -+ ElementCompute beta_; -+ -+ public: -+ /// Constructs the function object, possibly loading from pointers in host -+ /// memory -+ CUTLASS_HOST_DEVICE -+ FastLinearCombinationClamp(Params const ¶ms) { -+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); -+ beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); -+ } -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ if (Scale == ScaleType::NoBetaScaling) return true; -+ -+ if (Scale == ScaleType::OnlyAlphaScaling) return false; -+ -+ if (Scale == ScaleType::Nothing) return false; -+ -+ return beta_ != ElementCompute(0); -+ } -+ -+ /// Functionally required for serial reduction in the epilogue -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ if (k_partition) { -+ beta_ = ElementCompute(1); -+ } -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()(FragmentAccumulator const &accumulator, -+ FragmentOutput const &source, -+ ElementCompute uniform = ElementCompute(0)) const { -+ // Convert source to interal compute numeric type -+ FastNumericArrayConverter -+ source_converter; -+ FastNumericArrayConverter -+ accumulator_converter; -+ -+ ComputeFragment converted_source = source_converter(source); -+ ComputeFragment converted_accumulator = accumulator_converter(accumulator); -+ -+ // Compute linear scaling in floating point -+ ComputeFragment intermediate; -+ -+ multiplies mul_add_source; -+ multiply_add mul_add_accumulator; -+ -+ minimum min_accumulator; -+ maximum max_accumulator; -+ -+ // Float min-max -+ if (Scale == ScaleType::NoBetaScaling) { -+ intermediate = converted_source; -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ } else if (Scale == ScaleType::Nothing) { -+ intermediate = converted_accumulator; -+ } else { -+ intermediate = -+ mul_add_source(beta_, converted_source); // X = beta * C + uniform -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, -+ intermediate); // D = alpha * Accum + X -+ } -+ -+ /// Clamping constant value -+ ElementCompute const kClamp = -+ ElementCompute(1 << (sizeof_bits::value - 1)); -+ -+ intermediate = max_accumulator(intermediate, -kClamp); -+ intermediate = min_accumulator(intermediate, kClamp - ElementCompute(1)); -+ -+ // Convert to destination numeric type -+ FastNumericArrayConverter -+ destination_converter; -+ -+ return destination_converter(intermediate); -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()(FragmentAccumulator const &accumulator) const { -+ -+ // Convert source to interal compute numeric type -+ FastNumericArrayConverter -+ accumulator_converter; -+ -+ ComputeFragment converted_accumulator = accumulator_converter(accumulator); -+ -+ // Compute linear scaling in floating point -+ ComputeFragment intermediate; -+ -+ multiplies mul_accumulator; -+ -+ minimum min_accumulator; -+ maximum max_accumulator; -+ -+ // Float min-max -+ if (Scale == ScaleType::Nothing) { -+ intermediate = converted_accumulator; -+ } else { -+ intermediate = mul_accumulator(alpha_, converted_accumulator); -+ } -+ -+ /// Clamping constant value -+ ElementCompute const kClamp = -+ ElementCompute(1 << (sizeof_bits::value - 1)); -+ -+ intermediate = max_accumulator(intermediate, -kClamp); -+ intermediate = min_accumulator(intermediate, kClamp - ElementCompute(1)); -+ -+ // Convert to destination numeric type -+ FastNumericArrayConverter -+ destination_converter; -+ -+ return destination_converter(intermediate); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_dgelu.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_dgelu.h -new file mode 100644 -index 0000000..d026a8c ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_dgelu.h -@@ -0,0 +1,250 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ -+ \brief Functor performing linear combination followed by dGelu operation -+*/ -+ -+#pragma once -+ -+#include -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/constants.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/epilogue/thread/activation.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies a linear combination operator to an array of elements. -+/// -+/// D = alpha * accumulator + beta * source + uniform -+/// -+template < -+ typename ElementCompute_, ///< Data type returned by this functor -+ typename ElementAccumulator_, ///< Data type of accumulators -+ typename ElementSource_, ///< Data type of source tensor -+ typename ElementTensor_, ///< Data type of additional tensor -+ int Count, ///< Number of elements computed per operation -+ ///< Usually it is 128/sizeof_bits, -+ ///< but we use 64 or 32 sometimes when there are not enough data to store -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest -+> -+class LinearCombinationDGelu { -+public: -+ -+ using ElementOutput = ElementSource_; -+ using ElementCompute = ElementCompute_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementSource = ElementSource_; -+ using ElementTensor = ElementTensor_; -+ -+ static bool const kIsHeavy = true; -+ -+ static int const kCount = Count; -+ -+ using FragmentCompute = Array; -+ using FragmentAccumulator = Array; -+ using FragmentSource = Array; -+ using FragmentTensor = Array; -+ -+ static FloatRoundStyle const kRound = Round; -+ -+ /// Host-constructable parameters structure -+ struct Params { -+ -+ ElementCompute alpha; ///< scales accumulators -+ ElementCompute beta; ///< scales source tensor -+ ElementCompute threshold; ///< minimum value that is output -+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory -+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ alpha(ElementCompute(1)), -+ beta(ElementCompute(0)), -+ threshold(ElementCompute(0)), -+ alpha_ptr(nullptr), -+ beta_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute alpha, -+ ElementCompute beta, -+ ElementCompute threshold = ElementCompute(0) -+ ): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute const *alpha_ptr, -+ ElementCompute const *beta_ptr, -+ ElementCompute threshold = ElementCompute(0) -+ ): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { -+ -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ ElementCompute alpha_; -+ ElementCompute beta_; -+ ElementCompute threshold_; -+ bool participates_in_reduction_; -+ -+public: -+ -+ /// Constructs the function object, possibly loading from pointers in host memory -+ CUTLASS_HOST_DEVICE -+ LinearCombinationDGelu(Params const ¶ms) { -+ -+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); -+ beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); -+ threshold_ = params.threshold; -+ participates_in_reduction_ = true; -+ } -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ return beta_ != ElementCompute(0); -+ } -+ -+ /// Returns true if the threadblock computes the reduction -+ CUTLASS_HOST_DEVICE -+ bool participates_in_reduction() const { -+ return participates_in_reduction_; -+ } -+ -+ /// Functionally required for serial reduction in the epilogue -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ if (k_partition) { -+ beta_ = ElementCompute(1); -+ } -+ -+ if (k_partition != k_partition_count - 1) { -+ // set to NaN to make ReLU no-op for all except last k partitions -+ int64_t allones = -1; -+ threshold_ = reinterpret_cast(allones); -+ // Avoid computing the reduction if this isn't the final Split-K slice -+ participates_in_reduction_ = false; -+ } -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ CUTLASS_HOST_DEVICE -+ FragmentCompute operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentSource const &source, -+ FragmentTensor const &tensor) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter source_converter; -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_source = source_converter(source); -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ FragmentCompute intermediate; -+ -+ multiplies mul_add_source; -+ multiply_add mul_add_accumulator; -+ -+ intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ -+ dGELU gelu_op; -+ -+ // dGelu -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ intermediate[i] = gelu_op(intermediate[i], ElementCompute(tensor[i])); -+ } -+ -+ return intermediate; -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator -+ CUTLASS_HOST_DEVICE -+ FragmentCompute operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentTensor const &tensor) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ FragmentCompute intermediate; -+ -+ multiplies mul_accumulator; -+ -+ intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum -+ -+ dGELU gelu_op; -+ -+ // dGelu with conversion -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ intermediate[i] = gelu_op(intermediate[i], ElementCompute(tensor[i])); -+ } -+ -+ return intermediate; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_drelu.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_drelu.h -new file mode 100644 -index 0000000..f05da6d ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_drelu.h -@@ -0,0 +1,452 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Functor performing linear combination with a maximum operation used by epilogues. -+*/ -+ -+#pragma once -+ -+#include -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/epilogue/thread/activation.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies a linear combination operator to an array of elements. -+/// -+/// D = alpha * accumulator + beta * source + uniform -+/// -+template < -+ typename ElementCompute_, ///< Data type returned by this functor -+ typename ElementAccumulator_, ///< Data type of accumulators -+ typename ElementSource_, ///< Data type of source tensor -+ typename ElementTensor_, ///< Data type of additional tensor -+ int Count, ///< Number of elements computed per operation -+ ///< Usually it is 128/sizeof_bits, -+ ///< but we use 64 or 32 sometimes when there are not enough data to store -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest -+> -+class LinearCombinationDRelu { -+public: -+ -+ using ElementOutput = ElementSource_; -+ using ElementCompute = ElementCompute_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementSource = ElementSource_; -+ using ElementTensor = ElementTensor_; -+ -+ static int const kCount = Count; -+ -+ using FragmentCompute = Array; -+ using FragmentAccumulator = Array; -+ using FragmentSource = Array; -+ using FragmentTensor = Array; -+ -+ static FloatRoundStyle const kRound = Round; -+ -+ /// Host-constructable parameters structure -+ struct Params { -+ -+ ElementCompute alpha; ///< scales accumulators -+ ElementCompute beta; ///< scales source tensor -+ ElementCompute threshold; ///< minimum value that is output -+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory -+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ alpha(ElementCompute(1)), -+ beta(ElementCompute(0)), -+ threshold(ElementCompute(0)), -+ alpha_ptr(nullptr), -+ beta_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute alpha, -+ ElementCompute beta, -+ ElementCompute threshold = ElementCompute(0) -+ ): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute const *alpha_ptr, -+ ElementCompute const *beta_ptr, -+ ElementCompute threshold = ElementCompute(0) -+ ): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { -+ -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ ElementCompute alpha_; -+ ElementCompute beta_; -+ ElementTensor threshold_; -+ bool participates_in_reduction_; -+ -+public: -+ -+ /// Constructs the function object, possibly loading from pointers in host memory -+ CUTLASS_HOST_DEVICE -+ LinearCombinationDRelu(Params const ¶ms) { -+ -+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); -+ beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); -+ threshold_ = ElementTensor(params.threshold); -+ participates_in_reduction_ = true; -+ } -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ return beta_ != ElementCompute(0); -+ } -+ -+ /// Returns true if the threadblock computes the reduction -+ CUTLASS_HOST_DEVICE -+ bool participates_in_reduction() const { -+ return participates_in_reduction_; -+ } -+ -+ /// Functionally required for serial reduction in the epilogue -+ CUTLASS_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ if (k_partition) { -+ beta_ = ElementCompute(1); -+ } -+ -+ if (k_partition != k_partition_count - 1) { -+ // set to NaN to make ReLU no-op for all except last k partitions -+ int64_t allones = -1; -+ threshold_ = reinterpret_cast(allones); -+ participates_in_reduction_ = false; -+ } -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ CUTLASS_HOST_DEVICE -+ FragmentCompute operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentSource const &source, -+ FragmentTensor const &tensor) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter source_converter; -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_source = source_converter(source); -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ FragmentCompute intermediate; -+ -+ multiplies mul_add_source; -+ multiply_add mul_add_accumulator; -+ -+ intermediate = mul_add_source(beta_, converted_source); // X = beta * C -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ -+ // dReLU = (cond ? dy : 0) -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ ElementTensor cond = tensor[i]; -+ if (cond <= threshold_) { -+ intermediate[i] = ElementCompute(); -+ } -+ } -+ -+ return intermediate; -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator -+ CUTLASS_HOST_DEVICE -+ FragmentCompute operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentTensor const &tensor) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ FragmentCompute intermediate; -+ -+ multiplies mul_accumulator; -+ -+ intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum -+ -+ // dReLU = (cond ? dy : 0) -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ ElementTensor cond = tensor[i]; -+ if (cond <= threshold_) { -+ intermediate[i] = ElementCompute(); -+ } -+ } -+ -+ return intermediate; -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies a linear combination operator to an array of elements. -+/// -+/// D = alpha * accumulator + beta * source + uniform -+/// -+template < -+ typename ElementCompute_, ///< Data type returned by this functor -+ typename ElementAccumulator_, ///< Data type of accumulators -+ typename ElementSource_, ///< Data type of source tensor -+ int Count, ///< Number of elements computed per operation -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest -+> -+class LinearCombinationDReluConditionalBits { -+public: -+ -+ using ElementOutput = ElementSource_; -+ using ElementCompute = ElementCompute_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementSource = ElementSource_; -+ using ElementTensor = uint1b_t; -+ -+ static bool const kIsHeavy = false; -+ -+ static int const kCount = Count; -+ -+ using FragmentCompute = Array; -+ using FragmentAccumulator = Array; -+ using FragmentSource = Array; -+ using FragmentTensor = Array; -+ -+ static FloatRoundStyle const kRound = Round; -+ -+ /// Host-constructable parameters structure -+ struct Params { -+ -+ ElementCompute alpha; ///< scales accumulators -+ ElementCompute beta; ///< scales source tensor -+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory -+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ alpha(ElementCompute(1)), -+ beta(ElementCompute(0)), -+ alpha_ptr(nullptr), -+ beta_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute alpha, -+ ElementCompute beta -+ ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute const *alpha_ptr, -+ ElementCompute const *beta_ptr -+ ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { -+ -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ ElementCompute alpha_; -+ ElementCompute beta_; -+ FragmentTensor predicate_mask_; -+ bool participates_in_reduction_; -+ -+public: -+ -+ /// Constructs the function object, possibly loading from pointers in host memory -+ CUTLASS_HOST_DEVICE -+ LinearCombinationDReluConditionalBits(Params const ¶ms) { -+ -+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); -+ beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); -+ participates_in_reduction_ = true; -+ predicate_mask_.clear(); -+ } -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ return beta_ != ElementCompute(0); -+ } -+ -+ /// Returns true if the threadblock computes the reduction -+ CUTLASS_HOST_DEVICE -+ bool participates_in_reduction() const { -+ return participates_in_reduction_; -+ } -+ -+ /// Functionally required for serial reduction in the epilogue -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ predicate_mask_.clear(); -+ -+ if (k_partition) { -+ beta_ = ElementCompute(1); -+ } -+ -+ if (k_partition != k_partition_count - 1) { -+ // Avoid computing the reduction if this isn't the final Split-K slice -+ participates_in_reduction_ = false; -+ -+ bit_not not_op; -+ predicate_mask_ = not_op(predicate_mask_); -+ } -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ CUTLASS_DEVICE -+ FragmentCompute operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentSource const &source, -+ FragmentTensor const &tensor) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter source_converter; -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_source = source_converter(source); -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ FragmentCompute intermediate; -+ -+ multiplies mul_add_source; -+ multiply_add mul_add_accumulator; -+ -+ intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ -+ bit_or or_op; -+ -+ FragmentTensor predicates = or_op(tensor, predicate_mask_); -+ -+ // Obtain from packed bits -+ bool conditions[kCount]; -+ UnpackPredicates unpack_predicates; -+ -+ unpack_predicates(conditions, predicates); -+ -+ // dReLU = (cond ? dy : 0) -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ if (!conditions[i]) { -+ intermediate[i] = ElementCompute(); -+ } -+ } -+ -+ return intermediate; -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator -+ CUTLASS_HOST_DEVICE -+ FragmentCompute operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentTensor const &tensor) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ FragmentCompute intermediate; -+ -+ multiplies mul_accumulator; -+ -+ intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum -+ -+ bit_or or_op; -+ -+ FragmentTensor predicates = or_op(tensor, predicate_mask_); -+ -+ // Obtain from packed bits -+ bool conditions[kCount]; -+ UnpackPredicates unpack_predicates; -+ -+ unpack_predicates(conditions, predicates); -+ -+ // dReLU = (cond ? dy : 0) -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ if (!conditions[i]) { -+ intermediate[i] = ElementCompute(); -+ } -+ } -+ -+ return intermediate; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_gelu.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_gelu.h -new file mode 100644 -index 0000000..0a68c16 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_gelu.h -@@ -0,0 +1,70 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Functor performing linear combination with GELU operations used by epilogues. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/epilogue/thread/activation.h" -+#include "cutlass/epilogue/thread/linear_combination_generic.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies a linear combination operator followed by the GELU activation to an array of elements. -+/// -+/// D = gelu(alpha * accumulator + beta * source + uniform) -+/// -+template < -+ typename ElementOutput_, ///< Data type used to load and store tensors -+ int Count, ///< Number of elements computed per operation -+ ///< Usually it is 128/sizeof_bits, -+ ///< but we use 64 or 32 sometimes when there are not enough data to store -+ typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type -+ typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination -+ ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest -+> -+using LinearCombinationGELU = LinearCombinationGeneric; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_generic.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_generic.h -new file mode 100644 -index 0000000..71ada3f ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_generic.h -@@ -0,0 +1,207 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Functor performing linear combination operations used by epilogues. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/epilogue/thread/scale_type.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies a linear combination operator followed by an activation function to an array of elements. -+/// -+/// D = activation(alpha * accumulator + beta * source + uniform) -+/// -+template < -+ template class ActivationFunctor, -+ typename ElementOutput_, ///< Data type used to load and store tensors -+ int Count, ///< Number of elements computed per operation -+ ///< Usually it is 128/sizeof_bits, -+ ///< but we use 64 or 32 sometimes when there are not enough data to store -+ typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type -+ typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination -+ ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest, -+ bool IsHeavy = false -+> -+class LinearCombinationGeneric { -+public: -+ -+ using ElementOutput = ElementOutput_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementCompute_; -+ -+ static bool const kIsHeavy = IsHeavy; -+ static int const kCount = Count; -+ static const ScaleType::Kind kScale = Scale; -+ -+ using FragmentOutput = Array; -+ using FragmentAccumulator = Array; -+ using FragmentCompute = Array; -+ -+ static FloatRoundStyle const kRound = Round; -+ -+ /// Host-constructable parameters structure -+ using Params = typename ActivationFunctor::Params; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ Params params_; -+ bool skip_elementwise_; -+ -+public: -+ -+ /// Constructs the function object, possibly loading from pointers in host memory -+ CUTLASS_HOST_DEVICE -+ LinearCombinationGeneric(Params const ¶ms) { -+ params_ = params; -+ params_.alpha = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); -+ params_.beta = (params.beta_ptr ? *params.beta_ptr : params.beta); -+ skip_elementwise_ = false; -+ } -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ if (Scale == ScaleType::NoBetaScaling) return true; -+ -+ if (Scale == ScaleType::OnlyAlphaScaling) return false; -+ -+ if (Scale == ScaleType::Nothing) return false; -+ -+ return params_.beta != ElementCompute(0); -+ } -+ -+ /// Functionally required for serial reduction in the epilogue -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ if (k_partition) { -+ params_.beta = ElementCompute(1); -+ } -+ -+ if (k_partition != k_partition_count - 1) { -+ skip_elementwise_ = true; -+ } -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentOutput const &source) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter source_converter; -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_source = source_converter(source); -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ -+ FragmentCompute intermediate; -+ -+ multiplies mul_add_source; -+ multiply_add mul_add_accumulator; -+ ActivationFunctor activation; -+ -+ if (Scale == ScaleType::NoBetaScaling) { -+ intermediate = converted_source; -+ intermediate = mul_add_accumulator(params_.alpha, converted_accumulator, intermediate); // D = alpha * Accum + X -+ } else if (Scale == ScaleType::Nothing) { -+ intermediate = converted_accumulator; -+ } else { -+ intermediate = mul_add_source(params_.beta, converted_source); // X = beta * C + uniform -+ intermediate = mul_add_accumulator(params_.alpha, converted_accumulator, intermediate); // D = alpha * Accum + X -+ } -+ -+ intermediate = skip_elementwise_ ? intermediate : activation(intermediate, params_); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter destination_converter; -+ -+ return destination_converter(intermediate); -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ -+ FragmentCompute intermediate; -+ -+ multiplies mul_add_accumulator; -+ ActivationFunctor activation; -+ -+ if (Scale == ScaleType::Nothing) { -+ intermediate = converted_accumulator; -+ } else { -+ intermediate = mul_add_accumulator(params_.alpha, converted_accumulator); // D = alpha * Accum -+ } -+ -+ intermediate = skip_elementwise_ ? intermediate : activation(intermediate, params_); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter destination_converter; -+ -+ return destination_converter(intermediate); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_hardswish.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_hardswish.h -new file mode 100644 -index 0000000..3bd4b89 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_hardswish.h -@@ -0,0 +1,69 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Functor performing linear combination with HardSwish operations used by epilogues. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/epilogue/thread/activation.h" -+#include "cutlass/epilogue/thread/linear_combination_generic.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies a linear combination operator followed by the HardSwish activation to an array of elements. -+/// -+/// D = hardswish(alpha * accumulator + beta * source + uniform) -+/// -+template < -+ typename ElementOutput_, ///< Data type used to load and store tensors -+ int Count, ///< Number of elements computed per operation -+ ///< Usually it is 128/sizeof_bits, -+ ///< but we use 64 or 32 sometimes when there are not enough data to store -+ typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type -+ typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination -+ ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest -+> -+using LinearCombinationHardSwish = LinearCombinationGeneric; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_leaky_relu.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_leaky_relu.h -new file mode 100644 -index 0000000..ebee6b4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_leaky_relu.h -@@ -0,0 +1,230 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/epilogue/thread/activation.h" -+#include "cutlass/epilogue/thread/scale_type.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies a linear combination operator to an array of elements. -+/// -+/// D = alpha * accumulator + beta * source + uniform -+/// -+template < -+ typename ElementOutput_, ///< Data type used to load and store tensors -+ int Count, ///< Number of elements computed per operation -+ typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type -+ typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination -+ ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest -+> -+class LinearCombinationLeakyRelu { -+public: -+ -+ using ElementOutput = ElementOutput_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementCompute_; -+ -+ static int const kCount = Count; -+ static const ScaleType::Kind kScale = Scale; -+ -+ using FragmentOutput = Array; -+ using FragmentAccumulator = Array; -+ using ComputeFragment = Array; -+ -+ static FloatRoundStyle const kRound = Round; -+ -+ /// Host-constructable parameters structure -+ struct Params { -+ -+ ElementCompute alpha; ///< scales accumulators -+ ElementCompute beta_bias; ///< scales bias tensor -+ ElementCompute leaky_alpha; ///< leaky_alpha -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ alpha(ElementCompute(1)), -+ beta_bias(ElementCompute(0)), -+ leaky_alpha(ElementCompute(1)) -+ { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute alpha, -+ ElementCompute beta_bias, -+ ElementCompute leaky_alpha = ElementCompute(1) -+ ): alpha(alpha), beta_bias(beta_bias), leaky_alpha(leaky_alpha) { -+ -+ } -+ -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ ElementCompute alpha_; -+ ElementCompute beta_bias_; -+ ElementCompute leaky_alpha_recip_; -+ -+public: -+ -+ /// Constructs the function object, possibly loading from pointers in host memory -+ CUTLASS_HOST_DEVICE -+ LinearCombinationLeakyRelu(Params const ¶ms) { -+ alpha_ = (params.alpha); -+ beta_bias_ = (params.beta_bias); -+ leaky_alpha_recip_ = (ElementCompute(params.leaky_alpha)); -+ } -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ if (Scale == ScaleType::NoBetaScaling) return true; -+ -+ if (Scale == ScaleType::OnlyAlphaScaling) return false; -+ -+ if (Scale == ScaleType::Nothing) return false; -+ -+ return beta_bias_ != ElementCompute(0); -+ } -+ -+ /// Functionally required for serial reduction in the epilogue -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition) { -+ if (k_partition) { -+ beta_bias_ = ElementCompute(1); -+ } -+ } -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ if (k_partition) { -+ beta_bias_ = ElementCompute(1); -+ } -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentOutput const &source) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter source_converter; -+ NumericArrayConverter accumulator_converter; -+ -+ ComputeFragment converted_source = source_converter(source); -+ ComputeFragment converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ ComputeFragment intermediate; -+ -+ multiplies mul_add_source; -+ multiply_add mul_add_accumulator; -+ -+ LeakyReLU leakyrelu; -+ -+ if (Scale == ScaleType::NoBetaScaling) { -+ intermediate = converted_source; -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ } else if (Scale == ScaleType::Nothing) { -+ intermediate = converted_accumulator; -+ } else { -+ intermediate = mul_add_source(beta_bias_, converted_source); // X = beta * C + uniform -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ } -+ // Compute threshold optionally -+ intermediate = leakyrelu(intermediate, leaky_alpha_recip_); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter destination_converter; -+ -+ return destination_converter(intermediate); -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter accumulator_converter; -+ -+ ComputeFragment converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ ComputeFragment intermediate; -+ -+ multiplies mul_accumulator; -+ LeakyReLU leakyrelu; -+ //printf("in doing with bias"); -+ if (Scale == ScaleType::Nothing) { -+ intermediate = converted_accumulator; -+ } else { -+ intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum -+ } -+ -+ // Compute threshold optionally -+ intermediate = leakyrelu(intermediate, leaky_alpha_recip_); -+ -+ -+ // Convert to destination numeric type -+ NumericArrayConverter destination_converter; -+ -+ return destination_converter(intermediate); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_params.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_params.h -new file mode 100644 -index 0000000..a3f825e ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_params.h -@@ -0,0 +1,75 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief -+*/ -+ -+#pragma once -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct LinearCombinationParams { -+ uint64_t alpha_data[2]; -+ uint64_t beta_data[2]; -+ -+ CUTLASS_HOST_DEVICE -+ LinearCombinationParams() -+ : alpha_data {0lu, 0lu}, beta_data {0lu, 0lu} -+ { } -+ -+ template -+ CUTLASS_HOST_DEVICE -+ LinearCombinationParams(ElementCompute alpha, ElementCompute beta) -+ : alpha_data {0lu, 0lu}, beta_data {0lu, 0lu} -+ { -+ #if defined(__CUDA_ARCH__) -+ reinterpret_cast(alpha_data) = alpha; -+ reinterpret_cast(beta_data) = beta; -+ #else -+ memcpy( alpha_data, &alpha, sizeof(ElementCompute) ); -+ memcpy( beta_data, &beta, sizeof(ElementCompute) ); -+ #endif -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_planar_complex.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_planar_complex.h -new file mode 100644 -index 0000000..005e301 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_planar_complex.h -@@ -0,0 +1,237 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Functor performing linear combination operations on planar-complex arrays -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/complex.h" -+#include "cutlass/array_planar_complex.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies a linear combination operator to arrays of planar-complex elements. -+/// -+/// D = alpha * accumulator + beta * source + uniform -+/// -+/// Note, as with most CUTLASS components for planar complex, the template arguments describe -+/// the underlying real data type. -+template < -+ typename ElementOutput_, ///< Data type used to load and store tensors -+ int Count, ///< Number of elements computed per operation -+ ///< Usually it is 128/sizeof_bits, -+ ///< but we use 64 or 32 sometimes when there are not enough data to store -+ typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type -+ typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest -+> -+class LinearCombinationPlanarComplex { -+public: -+ -+ using ElementOutput = ElementOutput_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementCompute_; -+ -+ static int const kCount = Count; -+ -+ using FragmentOutput = ArrayPlanarComplex; -+ using FragmentAccumulator = ArrayPlanarComplex; -+ using ComputeFragment = ArrayPlanarComplex; -+ -+ static FloatRoundStyle const kRound = Round; -+ -+ /// Host-constructable parameters structure -+ struct Params { -+ -+ complex alpha; ///< scales accumulators -+ complex beta; ///< scales source tensor -+ complex const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory -+ complex const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ alpha(ElementCompute(1)), -+ beta(ElementCompute(0)), -+ alpha_ptr(nullptr), -+ beta_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ complex alpha, -+ complex beta -+ ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ complex const *alpha_ptr, -+ complex const *beta_ptr -+ ): alpha(complex()), beta(complex()), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { -+ -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ complex alpha_; -+ complex beta_; -+ -+public: -+ -+ /// Constructs the function object, possibly loading from pointers in host memory -+ CUTLASS_HOST_DEVICE -+ LinearCombinationPlanarComplex(Params const ¶ms) { -+ -+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); -+ beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); -+ } -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ return beta_.real() != ElementCompute(0) || beta_.imag() != ElementCompute(0); -+ } -+ -+ /// Functionally required for serial reduction in the epilogue -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ if (k_partition) { -+ beta_ = ElementCompute(1); -+ } -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentOutput const &source) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter source_converter; -+ NumericArrayConverter accumulator_converter; -+ -+ ComputeFragment converted_source( -+ source_converter(source.real), -+ source_converter(source.imag)); -+ -+ ComputeFragment converted_accumulator( -+ accumulator_converter(accumulator.real), -+ accumulator_converter(accumulator.imag)); -+ -+ // Perform binary operations -+ ComputeFragment intermediate; -+ -+ multiplies > mul_op; -+ multiply_add > mul_add_op; -+ -+ // complex multiply: I = beta * C -+ intermediate.real = mul_op(beta_.real(), converted_source.real); -+ intermediate.imag = mul_op(beta_.real(), converted_source.imag); -+ -+ intermediate.real = mul_add_op(-beta_.imag(), converted_source.imag, intermediate.real); -+ intermediate.imag = mul_add_op( beta_.imag(), converted_source.real, intermediate.imag); -+ -+ // complex multiply-add: I = alpha * AB + I -+ intermediate.real = mul_add_op(alpha_.real(), converted_accumulator.real, intermediate.real); -+ intermediate.imag = mul_add_op(alpha_.real(), converted_accumulator.imag, intermediate.imag); -+ -+ intermediate.real = mul_add_op(-alpha_.imag(), converted_accumulator.imag, intermediate.real); -+ intermediate.imag = mul_add_op( alpha_.imag(), converted_accumulator.real, intermediate.imag); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter destination_converter; -+ -+ return FragmentOutput( -+ destination_converter(intermediate.real), -+ destination_converter(intermediate.imag)); -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter accumulator_converter; -+ -+ ComputeFragment converted_accumulator( -+ accumulator_converter(accumulator.real), -+ accumulator_converter(accumulator.imag)); -+ -+ // Perform binary operations -+ ComputeFragment intermediate; -+ -+ multiplies > mul_op; -+ multiply_add > mul_add_op; -+ -+ // complex multiply-add: I = alpha * AB + I -+ intermediate.real = mul_add_op(alpha_.real(), converted_accumulator.real); -+ intermediate.imag = mul_add_op(alpha_.real(), converted_accumulator.imag); -+ -+ intermediate.real = mul_add_op(-alpha_.imag(), converted_accumulator.imag, intermediate.real); -+ intermediate.imag = mul_add_op( alpha_.imag(), converted_accumulator.real, intermediate.imag); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter destination_converter; -+ -+ return FragmentOutput( -+ destination_converter(intermediate.real), -+ destination_converter(intermediate.imag)); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_relu.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_relu.h -new file mode 100644 -index 0000000..eb1b436 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_relu.h -@@ -0,0 +1,570 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Functor performing linear combination with a maximum operation used by epilogues. -+*/ -+ -+#pragma once -+ -+#include -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/epilogue/thread/activation.h" -+#include "cutlass/epilogue/thread/scale_type.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+/// Single source of truth for whether to unroll for `LinearCombinationClamp()` -+constexpr bool LinearCombinationReluIsHeavy() { -+ return false; -+} -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies a linear combination operator to an array of elements. -+/// -+/// D = alpha * accumulator + beta * source + uniform -+/// -+template < -+ typename ElementOutput_, ///< Data type used to load and store tensors -+ int Count, ///< Number of elements computed per operation -+ ///< Usually it is 128/sizeof_bits, -+ ///< but we use 64 or 32 sometimes when there are not enough data to store -+ typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type -+ typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination -+ ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest -+> -+class LinearCombinationRelu { -+public: -+ -+ using ElementOutput = ElementOutput_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementCompute_; -+ -+ static int const kCount = Count; -+ static const ScaleType::Kind kScale = Scale; -+ -+ using FragmentOutput = Array; -+ using FragmentAccumulator = Array; -+ using FragmentCompute = Array; -+ using FragmentScaleBias = Array; -+ -+ static FloatRoundStyle const kRound = Round; -+ -+ static bool const kIsHeavy = detail::LinearCombinationReluIsHeavy(); -+ -+ /// Host-constructable parameters structure -+ struct Params { -+ -+ ElementCompute alpha; ///< scales accumulators -+ ElementCompute beta; ///< scales source tensor -+ ElementCompute threshold; ///< minimum value that is output -+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory -+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ alpha(ElementCompute(1)), -+ beta(ElementCompute(0)), -+ threshold(ElementCompute(0)), -+ alpha_ptr(nullptr), -+ beta_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute alpha, -+ ElementCompute beta = ElementCompute(0), -+ ElementCompute threshold = ElementCompute(0) -+ ): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute const *alpha_ptr, -+ ElementCompute const *beta_ptr = nullptr, -+ ElementCompute threshold = ElementCompute(0) -+ ): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { -+ -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ ElementCompute alpha_; -+ ElementCompute beta_; -+ ElementCompute threshold_; -+ -+public: -+ -+ /// Constructs the function object, possibly loading from pointers in host memory -+ CUTLASS_HOST_DEVICE -+ LinearCombinationRelu(Params const ¶ms) { -+ -+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); -+ beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); -+ threshold_ = params.threshold; -+ } -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ if (Scale == ScaleType::NoBetaScaling) return true; -+ -+ if (Scale == ScaleType::OnlyAlphaScaling) return false; -+ -+ if (Scale == ScaleType::OnlyAlphaPerChannelScaling) return false; -+ -+ if (Scale == ScaleType::Nothing) return false; -+ -+ return beta_ != ElementCompute(0); -+ } -+ -+ /// Functionally required for serial reduction in the epilogue -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ if (k_partition) { -+ beta_ = ElementCompute(1); -+ } -+ -+ if (k_partition != k_partition_count - 1) { -+ // set to NaN to make ReLU no-op for all except last k partitions -+ int64_t allones = -1; -+ threshold_ = reinterpret_cast(allones); -+ } -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentOutput const &source) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter source_converter; -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_source = source_converter(source); -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ FragmentCompute intermediate; -+ -+ multiplies mul_add_source; -+ multiply_add mul_add_accumulator; -+ ReLu relu; -+ -+ if (Scale == ScaleType::NoBetaScaling) { -+ intermediate = converted_source; -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ } else if (Scale == ScaleType::Nothing) { -+ intermediate = converted_accumulator; -+ } else { -+ intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ } -+ -+ // Compute threshold optionally -+ intermediate = relu(threshold_, intermediate); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter destination_converter; -+ -+ return destination_converter(intermediate); -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ FragmentCompute intermediate; -+ -+ multiplies mul_accumulator; -+ ReLu relu; -+ -+ if (Scale == ScaleType::Nothing) { -+ intermediate = converted_accumulator; -+ } else { -+ intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum -+ } -+ -+ // Compute threshold optionally -+ intermediate = relu(threshold_, intermediate); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter destination_converter; -+ -+ return destination_converter(intermediate); -+ } -+ -+ /// Computes per-channel linear scaling and bias : D = scale * accumulator + bias -+ /// Scale and Bias are from input Fragment -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentScaleBias const &scale, -+ FragmentScaleBias const &bias) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform per-channel scale and bias -+ FragmentCompute intermediate; -+ -+ multiply_add mul_add_accumulator; -+ -+ if(Scale == ScaleType::OnlyAlphaPerChannelScaling) -+ intermediate = mul_add_accumulator(scale, converted_accumulator, bias); // D = scale * Accum + bias -+ else -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, bias); // D = alpha * Accum + bias -+ -+ ReLu relu; -+ -+ // Compute threshold optionally -+ intermediate = relu(threshold_, intermediate); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter destination_converter; -+ -+ return destination_converter(intermediate); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Conditional guards to enable partial specialization for packed integers -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && ((__CUDACC_VER_MAJOR__ > 10) || ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2))) -+ -+/// Applies a linear combination operator to an array of elements. -+/// -+/// D = alpha * accumulator + beta * source + uniform -+/// -+/// Special handling for int types -+ -+template < -+ typename ElementOutput_, ///< Data type used to load and store tensors -+ int Count, ///< Number of elements computed per operation -+ ScaleType::Kind Scale, ///< Control Alpha and Beta scaling -+ FloatRoundStyle Round -+> -+class LinearCombinationRelu { -+public: -+ -+ using ElementOutput = ElementOutput_; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ -+ static bool const kIsHeavy = detail::LinearCombinationReluIsHeavy(); -+ -+ static int const kCount = Count; -+ static const ScaleType::Kind kScale = Scale; -+ -+ using FragmentOutput = Array; -+ using FragmentAccumulator = Array; -+ using FragmentCompute = Array; -+ using FragmentScaleBias = Array; -+ -+ static FloatRoundStyle const kRound = Round; -+ -+ /// Host-constructable parameters structure -+ struct Params { -+ -+ ElementCompute alpha; ///< scales accumulators -+ ElementCompute beta; ///< scales source tensor -+ ElementCompute threshold; ///< minimum value that is output -+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory -+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ alpha(ElementCompute(1)), -+ beta(ElementCompute(0)), -+ threshold(ElementCompute(0)), -+ alpha_ptr(nullptr), -+ beta_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute alpha, -+ ElementCompute beta = ElementCompute(0), -+ ElementCompute threshold = ElementCompute(0) -+ ): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute const *alpha_ptr, -+ ElementCompute const *beta_ptr = nullptr, -+ ElementCompute threshold = ElementCompute(0) -+ ): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { -+ -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ ElementCompute alpha_; -+ ElementCompute beta_; -+ ElementCompute threshold_; -+ -+public: -+ -+ /// Constructs the function object, possibly loading from pointers in host memory -+ CUTLASS_HOST_DEVICE -+ LinearCombinationRelu(Params const ¶ms) { -+ -+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); -+ beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); -+ threshold_ = params.threshold; -+ } -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ if (Scale == ScaleType::NoBetaScaling) return true; -+ -+ if (Scale == ScaleType::OnlyAlphaScaling) return false; -+ -+ if (Scale == ScaleType::OnlyAlphaPerChannelScaling) return false; -+ -+ if (Scale == ScaleType::Nothing) return false; -+ -+ return beta_ != ElementCompute(0); -+ } -+ -+ /// Functionally required for serial reduction in the epilogue -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ if (k_partition) { -+ beta_ = ElementCompute(1); -+ } -+ -+ if (k_partition != k_partition_count - 1) { -+ // set to NaN to make ReLU no-op for all except last k partitions -+ int64_t allones = -1; -+ threshold_ = reinterpret_cast(allones); -+ } -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentOutput const &source) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter source_converter; -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_source = source_converter(source); -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ FragmentCompute intermediate; -+ -+ multiplies mul_add_source; -+ multiply_add mul_add_accumulator; -+ ReLu relu; -+ -+ if (Scale == ScaleType::NoBetaScaling) { -+ intermediate = converted_source; -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ } else if (Scale == ScaleType::Nothing) { -+ intermediate = converted_accumulator; -+ } else { -+ intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ } -+ -+ // Compute threshold optionally -+ intermediate = relu(threshold_, intermediate); -+ -+ if (platform::numeric_limits::is_integer) { -+ // Convert floats back to INT -+ FragmentAccumulator scaled_accumulator; -+ -+ NumericArrayConverter compute_converter; -+ -+ scaled_accumulator = compute_converter(intermediate); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter -+ destination_converter; -+ -+ return destination_converter(scaled_accumulator); -+ } else { -+ NumericArrayConverter -+ destination_converter; -+ return destination_converter(intermediate); -+ } -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ FragmentCompute intermediate; -+ -+ multiplies mul_accumulator; -+ ReLu relu; -+ -+ if (Scale == ScaleType::Nothing) { -+ intermediate = converted_accumulator; -+ } else { -+ intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum -+ } -+ -+ // Compute threshold optionally -+ intermediate = relu(threshold_, intermediate); -+ -+ if (platform::numeric_limits::is_integer) { -+ // Convert floats back to INT -+ FragmentAccumulator scaled_accumulator; -+ -+ NumericArrayConverter compute_converter; -+ -+ scaled_accumulator = compute_converter(intermediate); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter -+ destination_converter; -+ -+ return destination_converter(scaled_accumulator); -+ } else { -+ NumericArrayConverter -+ destination_converter; -+ return destination_converter(intermediate); -+ } -+ } -+ -+ /// Computes per-channel linear scaling and bias : D = scale * accumulator + bias -+ /// Scale and Bias are from input Fragment -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentScaleBias const &scale, -+ FragmentScaleBias const &bias) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform per-channel scale and bias -+ FragmentCompute intermediate; -+ -+ multiply_add mul_add_accumulator; -+ -+ if(Scale == ScaleType::OnlyAlphaPerChannelScaling) -+ intermediate = mul_add_accumulator(scale, converted_accumulator, bias); // D = scale * Accum + bias -+ else -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, bias); // D = alpha * Accum + bias -+ -+ ReLu relu; -+ -+ // Compute threshold optionally -+ intermediate = relu(threshold_, intermediate); -+ -+ if (platform::numeric_limits::is_integer) { -+ // Convert floats back to INT -+ FragmentAccumulator scaled_accumulator; -+ -+ NumericArrayConverter compute_converter; -+ -+ scaled_accumulator = compute_converter(intermediate); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter -+ destination_converter; -+ -+ return destination_converter(scaled_accumulator); -+ } else { -+ NumericArrayConverter -+ destination_converter; -+ return destination_converter(intermediate); -+ } -+ } -+}; -+ -+#endif // Conditional guards to enable partial specialization for packed integers -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_relu0.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_relu0.h -new file mode 100644 -index 0000000..3cffd93 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_relu0.h -@@ -0,0 +1,541 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Functor performing linear combination with a relu operation used by epilogues. -+ This one only supports relu0 and tries to folding relu into other instructions. Thus, -+ serial splitk is not supported by this one. For example, relu can be folded into -+ hfma2/hmul2 for sm80+ -+*/ -+ -+#pragma once -+ -+#include -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/epilogue/thread/activation.h" -+#include "cutlass/epilogue/thread/scale_type.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+/// Single source of truth for whether to unroll for `LinearCombinationClamp()` -+constexpr bool LinearCombinationRelu0IsHeavy() { -+ return false; -+} -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies a linear combination operator to an array of elements. -+/// -+/// D = alpha * accumulator + beta * source + uniform -+/// -+template < -+ typename ElementOutput_, ///< Data type used to load and store tensors -+ int Count, ///< Number of elements computed per operation -+ ///< Usually it is 128/sizeof_bits, -+ ///< but we use 64 or 32 sometimes when there are not enough data to store -+ typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type -+ typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination -+ ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest -+> -+class LinearCombinationRelu0 { -+public: -+ -+ using ElementOutput = ElementOutput_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementCompute_; -+ -+ static int const kCount = Count; -+ static const ScaleType::Kind kScale = Scale; -+ -+ using FragmentOutput = Array; -+ using FragmentAccumulator = Array; -+ using FragmentCompute = Array; -+ using FragmentScaleBias = Array; -+ -+ static FloatRoundStyle const kRound = Round; -+ -+ static bool const kIsHeavy = detail::LinearCombinationRelu0IsHeavy(); -+ -+ /// Host-constructable parameters structure -+ struct Params { -+ -+ ElementCompute alpha; ///< scales accumulators -+ ElementCompute beta; ///< scales source tensor -+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory -+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ alpha(ElementCompute(1)), -+ beta(ElementCompute(0)), -+ alpha_ptr(nullptr), -+ beta_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute alpha, -+ ElementCompute beta = ElementCompute(0) -+ ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute const *alpha_ptr, -+ ElementCompute const *beta_ptr = nullptr -+ ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { -+ -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ ElementCompute alpha_; -+ ElementCompute beta_; -+ -+public: -+ -+ /// Constructs the function object, possibly loading from pointers in host memory -+ CUTLASS_HOST_DEVICE -+ LinearCombinationRelu0(Params const ¶ms) { -+ -+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); -+ beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); -+ } -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ if (Scale == ScaleType::NoBetaScaling) return true; -+ -+ if (Scale == ScaleType::OnlyAlphaScaling) return false; -+ -+ if (Scale == ScaleType::Nothing) return false; -+ -+ return beta_ != ElementCompute(0); -+ } -+ -+ /// This is used for serial reduction which is not supported by Relu0 -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ assert(k_partition == 0); -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentOutput const &source) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter source_converter; -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_source = source_converter(source); -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ FragmentCompute intermediate; -+ -+ multiplies mul_add_source; -+ multiply_add_relu0 mul_add_relu0_accumulator; -+ ReLu relu; -+ -+ if (Scale == ScaleType::NoBetaScaling) { -+ intermediate = converted_source; -+ intermediate = mul_add_relu0_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ } else if (Scale == ScaleType::Nothing) { -+ intermediate = converted_accumulator; -+ -+ // Compute threshold optionally -+ intermediate = relu(intermediate); -+ } else { -+ intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform -+ intermediate = mul_add_relu0_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ } -+ -+ // Convert to destination numeric type -+ NumericArrayConverter destination_converter; -+ -+ return destination_converter(intermediate); -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ FragmentCompute intermediate; -+ -+ multiplies mul_accumulator; -+ ReLu relu; -+ -+ if (Scale == ScaleType::Nothing) { -+ intermediate = converted_accumulator; -+ } else { -+ intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum -+ } -+ -+ // Compute threshold optionally -+ intermediate = relu(intermediate); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter destination_converter; -+ -+ return destination_converter(intermediate); -+ } -+ -+ /// Computes per-channel linear scaling and bias : D = scale * accumulator + bias -+ /// Scale and Bias are from input Fragment -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentScaleBias const &scale, -+ FragmentScaleBias const &bias) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform per-channel scale and bias -+ FragmentCompute intermediate; -+ -+ multiply_add mul_add_accumulator; -+ -+ if(Scale == ScaleType::OnlyAlphaPerChannelScaling) -+ intermediate = mul_add_accumulator(scale, converted_accumulator, bias); // D = scale * Accum + bias -+ else -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, bias); // D = alpha * Accum + bias -+ -+ ReLu relu; -+ -+ // Compute threshold optionally -+ intermediate = relu(intermediate); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter destination_converter; -+ -+ return destination_converter(intermediate); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Conditional guards to enable partial specialization for packed integers -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && ((__CUDACC_VER_MAJOR__ > 10) || ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2))) -+ -+/// Applies a linear combination operator to an array of elements. -+/// -+/// D = alpha * accumulator + beta * source + uniform -+/// -+/// Special handling for int types -+ -+template < -+ typename ElementOutput_, ///< Data type used to load and store tensors -+ int Count, ///< Number of elements computed per operation -+ ScaleType::Kind Scale, ///< Control Alpha and Beta scaling -+ FloatRoundStyle Round -+> -+class LinearCombinationRelu0 { -+public: -+ -+ using ElementOutput = ElementOutput_; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ -+ static bool const kIsHeavy = detail::LinearCombinationRelu0IsHeavy(); -+ -+ static int const kCount = Count; -+ static const ScaleType::Kind kScale = Scale; -+ -+ using FragmentOutput = Array; -+ using FragmentAccumulator = Array; -+ using FragmentCompute = Array; -+ using FragmentScaleBias = Array; -+ -+ static FloatRoundStyle const kRound = Round; -+ -+ /// Host-constructable parameters structure -+ struct Params { -+ -+ ElementCompute alpha; ///< scales accumulators -+ ElementCompute beta; ///< scales source tensor -+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory -+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ alpha(ElementCompute(1)), -+ beta(ElementCompute(0)), -+ alpha_ptr(nullptr), -+ beta_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute alpha, -+ ElementCompute beta = ElementCompute(0) -+ ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute const *alpha_ptr, -+ ElementCompute const *beta_ptr = nullptr -+ ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { -+ -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ ElementCompute alpha_; -+ ElementCompute beta_; -+ -+public: -+ -+ /// Constructs the function object, possibly loading from pointers in host memory -+ CUTLASS_HOST_DEVICE -+ LinearCombinationRelu0(Params const ¶ms) { -+ -+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); -+ beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); -+ } -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ if (Scale == ScaleType::NoBetaScaling) return true; -+ -+ if (Scale == ScaleType::OnlyAlphaScaling) return false; -+ -+ if (Scale == ScaleType::Nothing) return false; -+ -+ return beta_ != ElementCompute(0); -+ } -+ -+ /// This is used for serial reduction which is not supported by Relu0 -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ assert(k_partition == 0); -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentOutput const &source) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter source_converter; -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_source = source_converter(source); -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ FragmentCompute intermediate; -+ -+ multiplies mul_add_source; -+ multiply_add mul_add_accumulator; -+ ReLu relu; -+ -+ if (Scale == ScaleType::NoBetaScaling) { -+ intermediate = converted_source; -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ } else if (Scale == ScaleType::Nothing) { -+ intermediate = converted_accumulator; -+ } else { -+ intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ } -+ -+ // Compute threshold optionally -+ intermediate = relu(intermediate); -+ -+ if (platform::numeric_limits::is_integer) { -+ // Convert floats back to INT -+ FragmentAccumulator scaled_accumulator; -+ -+ NumericArrayConverter compute_converter; -+ -+ scaled_accumulator = compute_converter(intermediate); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter -+ destination_converter; -+ -+ return destination_converter(scaled_accumulator); -+ } else { -+ NumericArrayConverter -+ destination_converter; -+ return destination_converter(intermediate); -+ } -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ FragmentCompute intermediate; -+ -+ multiplies mul_accumulator; -+ ReLu relu; -+ -+ if (Scale == ScaleType::Nothing) { -+ intermediate = converted_accumulator; -+ } else { -+ intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum -+ } -+ -+ // Compute threshold optionally -+ intermediate = relu(intermediate); -+ -+ if (platform::numeric_limits::is_integer) { -+ // Convert floats back to INT -+ FragmentAccumulator scaled_accumulator; -+ -+ NumericArrayConverter compute_converter; -+ -+ scaled_accumulator = compute_converter(intermediate); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter -+ destination_converter; -+ -+ return destination_converter(scaled_accumulator); -+ } else { -+ NumericArrayConverter -+ destination_converter; -+ return destination_converter(intermediate); -+ } -+ } -+ -+ /// Computes per-channel linear scaling and bias : D = scale * accumulator + bias -+ /// Scale and Bias are from input Fragment -+ CUTLASS_HOST_DEVICE -+ FragmentOutput operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentScaleBias const &scale, -+ FragmentScaleBias const &bias) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform per-channel scale and bias -+ FragmentCompute intermediate; -+ -+ multiply_add mul_add_accumulator; -+ -+ if(Scale == ScaleType::OnlyAlphaPerChannelScaling) -+ intermediate = mul_add_accumulator(scale, converted_accumulator, bias); // D = scale * Accum + bias -+ else -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, bias); // D = alpha * Accum + bias -+ -+ ReLu relu; -+ -+ // Compute threshold optionally -+ intermediate = relu(intermediate); -+ -+ if (platform::numeric_limits::is_integer) { -+ // Convert floats back to INT -+ FragmentAccumulator scaled_accumulator; -+ -+ NumericArrayConverter compute_converter; -+ -+ scaled_accumulator = compute_converter(intermediate); -+ -+ // Convert to destination numeric type -+ NumericArrayConverter -+ destination_converter; -+ -+ return destination_converter(scaled_accumulator); -+ } else { -+ NumericArrayConverter -+ destination_converter; -+ return destination_converter(intermediate); -+ } -+ } -+}; -+ -+#endif // Conditional guards to enable partial specialization for packed integers -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_residual_block.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_residual_block.h -new file mode 100644 -index 0000000..7c47c24 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_residual_block.h -@@ -0,0 +1,302 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Epilogue functor specialized for residual blocks in deep neural networks. -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+namespace detail { -+ -+/// Dummy class used to designate that the second binary operator in the epilogue is unsued -+template -+class NoOp {}; -+ -+} -+ -+/// Models a residual block of the form: UnaryOp(BinaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual1), residual2)) -+template class ActivationOp_, -+ template class BinaryOp1_, -+ template class UnaryOp_, -+ template class BinaryOp2_ = detail::NoOp> -+class LinearCombinationResidualBlock { -+public: -+ static bool const kIsSingleSource = false; -+ -+ using ElementOutput = ElementC_; -+ using ElementC = ElementC_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementCompute_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static int const kCount = kElementsPerAccess; -+ -+ using UnaryOp = UnaryOp_>; -+ using BinaryOp1 = BinaryOp1_>; -+ using BinaryOp2 = BinaryOp2_>; -+ using ActivationOp = ActivationOp_>; -+ -+ using FragmentAccumulator = Array; -+ using FragmentCompute = Array; -+ using FragmentC = Array; -+ using FragmentOutput = Array; -+ -+ using ElementZ = ElementOutput_; -+ using ElementT = ElementZ; -+ using FragmentZ = Array; -+ using FragmentT = Array; -+ -+ static bool const kIsHeavy = true; -+ static bool const kStoreZ = true; -+ static bool const kStoreT = false; -+ -+ /// Host-constructable parameters structure -+ struct Params { -+ -+ ElementCompute alpha; ///< scales accumulators -+ ElementCompute beta; ///< scales residual input -+ ElementCompute const *alpha_ptr{nullptr}; ///< pointer to accumulator scalar - if not null, loads it from memory -+ ElementCompute const *beta_ptr{nullptr}; ///< pointer to residual scalar - if not null, loads it from memory -+ -+ CUTLASS_HOST_DEVICE -+ Params() : alpha(ElementCompute(1)), beta(ElementCompute(1)) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(ElementCompute alpha, ElementCompute beta) -+ : alpha(alpha), beta(beta) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr) -+ : alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {} -+ }; -+ -+private: -+ -+ ElementCompute alpha_; -+ ElementCompute beta_; -+ bool skip_elementwise_; -+ -+public: -+ -+ /// Constructor from Params -+ CUTLASS_HOST_DEVICE -+ LinearCombinationResidualBlock(Params const ¶ms) { -+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); -+ beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); -+ skip_elementwise_ = false; -+ } -+ -+ /// The "source" tensor corresponds to the residual input -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { return true; } -+ -+ /// Functionally required for serial reduction in the epilogue -+ /// IMPORTANT: Split-k is supported only when ActivationOp is Identity. -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ if (k_partition) { -+ beta_ = ElementCompute(1); -+ } -+ -+ if (k_partition != k_partition_count - 1) { -+ skip_elementwise_ = true; -+ } -+ } -+ -+ /// Applies the operation UnaryOp(BinaryOp(BinaryOp(ActivationOp(AB + bias), residual1), residual2)) -+ CUTLASS_HOST_DEVICE -+ void operator()(FragmentOutput &frag_Z, FragmentOutput &, FragmentAccumulator const &AB, -+ FragmentC const &residual1, FragmentC const &residual2, -+ FragmentCompute const &bias) const { -+ UnaryOp unary_op; -+ BinaryOp1 binary_op1; -+ BinaryOp2 binary_op2; -+ ActivationOp activation; -+ -+ FragmentCompute tmp_Accum = -+ NumericArrayConverter()(AB); -+ FragmentCompute tmp_residual1 = -+ NumericArrayConverter()(residual1); -+ FragmentCompute tmp_residual2 = -+ NumericArrayConverter()(residual2); -+ -+ FragmentCompute z = -+ binary_op2(binary_op1(activation(alpha_ * tmp_Accum + bias), beta_ * tmp_residual1), beta_ * tmp_residual2); -+ FragmentCompute result_Z = skip_elementwise_ ? z : unary_op(z); -+ -+ NumericArrayConverter convert_z; -+ frag_Z = convert_z(result_Z); -+ } -+ -+ /// Should never be called -+ CUTLASS_HOST_DEVICE -+ void operator()(FragmentOutput &, FragmentOutput &, FragmentAccumulator const &, -+ FragmentCompute const &) const {} -+}; -+ -+/// Models a residual block of the form: UnaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual)) -+template class ActivationOp_, -+ template class BinaryOp1_, -+ template class UnaryOp_> -+class LinearCombinationResidualBlock { -+public: -+ static bool const kIsSingleSource = true; -+ -+ using ElementOutput = ElementC_; -+ using ElementC = ElementC_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementCompute_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static int const kCount = kElementsPerAccess; -+ -+ using UnaryOp = UnaryOp_>; -+ using BinaryOp = BinaryOp1_>; -+ using ActivationOp = ActivationOp_>; -+ -+ using FragmentAccumulator = Array; -+ using FragmentCompute = Array; -+ using FragmentC = Array; -+ using FragmentOutput = Array; -+ -+ using ElementZ = ElementOutput_; -+ using ElementT = ElementZ; -+ using FragmentZ = Array; -+ using FragmentT = Array; -+ -+ static bool const kIsHeavy = true; -+ static bool const kStoreZ = true; -+ static bool const kStoreT = false; -+ -+ /// Host-constructable parameters structure -+ struct Params { -+ -+ ElementCompute alpha; ///< scales accumulators -+ ElementCompute beta; ///< scales residual input -+ ElementCompute const *alpha_ptr{nullptr}; ///< pointer to accumulator scalar - if not null, loads it from memory -+ ElementCompute const *beta_ptr{nullptr}; ///< pointer to residual scalar - if not null, loads it from memory -+ -+ CUTLASS_HOST_DEVICE -+ Params() : alpha(ElementCompute(1)), beta(ElementCompute(1)) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(ElementCompute alpha, ElementCompute beta) -+ : alpha(alpha), beta(beta) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr) -+ : alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {} -+ }; -+ -+private: -+ -+ ElementCompute alpha_; -+ ElementCompute beta_; -+ bool skip_elementwise_; -+ -+public: -+ -+ /// Constructor from Params -+ CUTLASS_HOST_DEVICE -+ LinearCombinationResidualBlock(Params const ¶ms) { -+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); -+ beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); -+ skip_elementwise_ = false; -+ } -+ -+ /// The "source" tensor corresponds to the residual input -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { return true; } -+ -+ /// Functionally required for serial reduction in the epilogue -+ /// IMPORTANT: Split-k is supported only when ActivationOp is Identity. -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ if (k_partition) { -+ beta_ = ElementCompute(1); -+ } -+ -+ if (k_partition != k_partition_count - 1) { -+ skip_elementwise_ = true; -+ } -+ } -+ -+ /// Applies the operation UnaryOp(BinaryOp(ActivationOp(AB + bias), residual)) -+ CUTLASS_HOST_DEVICE -+ void operator()(FragmentOutput &frag_Z, FragmentOutput &, FragmentAccumulator const &AB, -+ FragmentC const &residual, -+ FragmentCompute const &bias) const { -+ UnaryOp unary_op; -+ BinaryOp binary_op; -+ ActivationOp activation; -+ -+ FragmentCompute tmp_Accum = -+ NumericArrayConverter()(AB); -+ FragmentCompute tmp_residual = -+ NumericArrayConverter()(residual); -+ -+ FragmentCompute z = -+ binary_op(activation(alpha_ * tmp_Accum + bias), beta_ * tmp_residual); -+ FragmentCompute result_Z = skip_elementwise_ ? z : unary_op(z); -+ -+ NumericArrayConverter convert_z; -+ frag_Z = convert_z(result_Z); -+ } -+ -+ /// Should never be called -+ CUTLASS_HOST_DEVICE -+ void operator()(FragmentOutput &, FragmentOutput &, FragmentAccumulator const &, -+ FragmentCompute const &) const {} -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_sigmoid.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_sigmoid.h -new file mode 100644 -index 0000000..c449d23 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_sigmoid.h -@@ -0,0 +1,70 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Functor performing linear combination with Sigmoid operations used by epilogues. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/epilogue/thread/activation.h" -+#include "cutlass/epilogue/thread/linear_combination_generic.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies a linear combination operator followed by the Sigmoid activation, to an array of elements. -+/// -+/// D = sigmoid(alpha * accumulator + beta * source + uniform) -+/// -+template < -+ typename ElementOutput_, ///< Data type used to load and store tensors -+ int Count, ///< Number of elements computed per operation -+ ///< Usually it is 128/sizeof_bits, -+ ///< but we use 64 or 32 sometimes when there are not enough data to store -+ typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type -+ typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination -+ ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest -+> -+using LinearCombinationSigmoid = LinearCombinationGeneric; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_silu.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_silu.h -new file mode 100644 -index 0000000..222f6de ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_silu.h -@@ -0,0 +1,69 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Functor performing linear combination with SiLU operations used by epilogues. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/epilogue/thread/activation.h" -+#include "cutlass/epilogue/thread/linear_combination_generic.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies a linear combination operator folllowed by the SiLU activation to an array of elements. -+/// -+/// D = silu(alpha * accumulator + beta * source + uniform) -+/// -+template < -+ typename ElementOutput_, ///< Data type used to load and store tensors -+ int Count, ///< Number of elements computed per operation -+ ///< Usually it is 128/sizeof_bits, -+ ///< but we use 64 or 32 sometimes when there are not enough data to store -+ typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type -+ typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination -+ ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest -+> -+using LinearCombinationSilu = LinearCombinationGeneric; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_with_elementwise.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_with_elementwise.h -new file mode 100644 -index 0000000..aac19b0 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/linear_combination_with_elementwise.h -@@ -0,0 +1,234 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ -+ \brief Functor performing linear combination with elementwise -+*/ -+ -+#pragma once -+ -+#include -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/constants.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/epilogue/thread/activation.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies a linear combination operator to an array of elements. -+/// -+/// D = alpha * accumulator + beta * source + uniform -+/// -+template < -+ typename ElementCompute_, ///< Data type returned by this functor -+ typename ElementAccumulator_, ///< Data type of accumulators -+ typename ElementSource_, ///< Data type of source tensor -+ typename ElementTensor_, ///< Data type of additional tensor -+ int Count, ///< Number of elements computed per operation -+ ///< Usually it is 128/sizeof_bits, -+ ///< but we use 64 or 32 sometimes when there are not enough data to store -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest -+> -+class LinearCombinationWithElementwise { -+public: -+ -+ using ElementOutput = ElementSource_; -+ using ElementCompute = ElementCompute_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementSource = ElementSource_; -+ using ElementTensor = ElementTensor_; -+ -+ static bool const kIsHeavy = true; -+ -+ static int const kCount = Count; -+ -+ using FragmentCompute = Array; -+ using FragmentAccumulator = Array; -+ using FragmentSource = Array; -+ using FragmentTensor = Array; -+ -+ static FloatRoundStyle const kRound = Round; -+ -+ /// Host-constructable parameters structure -+ struct Params { -+ -+ ElementCompute alpha; ///< scales accumulators -+ ElementCompute beta; ///< scales source tensor -+ ElementCompute threshold; ///< minimum value that is output -+ ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory -+ ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ alpha(ElementCompute(1)), -+ beta(ElementCompute(0)), -+ threshold(ElementCompute(0)), -+ alpha_ptr(nullptr), -+ beta_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute alpha, -+ ElementCompute beta, -+ ElementCompute threshold = ElementCompute(0) -+ ): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementCompute const *alpha_ptr, -+ ElementCompute const *beta_ptr, -+ ElementCompute threshold = ElementCompute(0) -+ ): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { -+ -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ ElementCompute alpha_; -+ ElementCompute beta_; -+ ElementCompute threshold_; -+ bool participates_in_reduction_; -+ -+public: -+ -+ /// Constructs the function object, possibly loading from pointers in host memory -+ CUTLASS_HOST_DEVICE -+ LinearCombinationWithElementwise(Params const ¶ms) { -+ -+ alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); -+ beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); -+ threshold_ = params.threshold; -+ participates_in_reduction_ = true; -+ } -+ -+ /// Returns true if source is needed -+ CUTLASS_HOST_DEVICE -+ bool is_source_needed() const { -+ return beta_ != ElementCompute(0); -+ } -+ -+ /// Returns true if the threadblock computes the reduction -+ CUTLASS_HOST_DEVICE -+ bool participates_in_reduction() const { -+ return participates_in_reduction_; -+ } -+ -+ /// Functionally required for serial reduction in the epilogue -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { -+ if (k_partition) { -+ beta_ = ElementCompute(1); -+ } -+ -+ if (k_partition != k_partition_count - 1) { -+ // set to NaN to make ReLU no-op for all except last k partitions -+ int64_t allones = -1; -+ threshold_ = reinterpret_cast(allones); -+ // Avoid computing the reduction if this isn't the final Split-K slice -+ participates_in_reduction_ = false; -+ } -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator + beta * source -+ CUTLASS_HOST_DEVICE -+ FragmentCompute operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentSource const &source, -+ FragmentTensor const &tensor) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter source_converter; -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_source = source_converter(source); -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ FragmentCompute intermediate; -+ -+ multiplies mul_add_source; -+ multiply_add mul_add_accumulator; -+ -+ intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform -+ intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X -+ -+ return intermediate; -+ } -+ -+ /// Computes linear scaling: D = alpha * accumulator -+ CUTLASS_HOST_DEVICE -+ FragmentCompute operator()( -+ FragmentAccumulator const &accumulator, -+ FragmentTensor const &tensor) const { -+ -+ // Convert source to interal compute numeric type -+ NumericArrayConverter accumulator_converter; -+ -+ FragmentCompute converted_accumulator = accumulator_converter(accumulator); -+ -+ // Perform binary operations -+ FragmentCompute intermediate; -+ -+ multiplies mul_accumulator; -+ -+ intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum -+ -+ return intermediate; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/reduction_op.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/reduction_op.h -new file mode 100644 -index 0000000..f904856 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/reduction_op.h -@@ -0,0 +1,97 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Functor performing reduction operations used by epilogues. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Applies a reduction sum to an array of elements. -+/// -+/// -+template < -+ typename Element_, ///< Data type used to load and store tensors -+ int Count ///< Number of elements computed per operation -+> -+class ReductionOpPlus { -+public: -+ -+ using Element = Element_; -+ static int const kCount = Count; -+ -+ using Fragment = Array; -+ using Operator = plus; -+ -+ /// Host-constructable parameters structure -+ struct Params { }; -+ -+private: -+ -+ /// reduction operator -+ Operator operator_; -+ -+public: -+ -+ /// Constructs the function object, possibly loading from pointers in host memory -+ CUTLASS_HOST_DEVICE -+ ReductionOpPlus(Params const ¶ms) { -+ -+ } -+ -+ /// Computes Compute => -+ CUTLASS_HOST_DEVICE -+ Fragment operator()( -+ Fragment const &lhs, -+ Fragment const &rhs) const { -+ -+ return operator_(lhs, rhs); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/thread/scale_type.h b/3rdparty/cutlass/include/cutlass/epilogue/thread/scale_type.h -new file mode 100644 -index 0000000..f229927 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/thread/scale_type.h -@@ -0,0 +1,62 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Enum defines the behaviors of the epilogue. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Specifies internal data type for computation -+struct ScaleType { -+ enum Kind { -+ Default, // alpha x C + beta x D -+ NoBetaScaling, // alpha x C + D -+ OnlyAlphaScaling, // alpha x C -+ OnlyAlphaPerChannelScaling, // alpha_vec x C -+ Nothing // C -+ }; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace epilogue -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op.h -new file mode 100644 -index 0000000..1b25816 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op.h -@@ -0,0 +1,255 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped complex GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/epilogue/thread/linear_combination_relu.h" -+#include "cutlass/epilogue/thread/linear_combination_gelu.h" -+#include "cutlass/epilogue/thread/linear_combination_sigmoid.h" -+#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" -+ -+#include "cutlass/epilogue/thread/conversion_op.h" -+#include "cutlass/epilogue/thread/reduction_op.h" -+ -+#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" -+ -+#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h" -+#include "cutlass/epilogue/warp/fragment_iterator_gaussian_complex_tensor_op.h" -+#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+#include "cutlass/epilogue/threadblock/shared_load_iterator.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Specialization and defines sensible defaults for epilogues for complex*complex case -+// 4 real-valued mma operations (Complex) -+// A = (ar + j ai), B (br +j bi), D = AB -+// D = dr + j di = (ar*br - ai*bi) + j (ar*bi + ai*br) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ /// Epilouge Shape -+ typename Shape_, -+ /// Warp-level mma operator -+ typename WarpMmaTensorOp_, -+ /// Number of k partitions -+ int PartitionsK, -+ /// Epilogue output operator -+ typename OutputOp_, -+ /// Elements accessed by inner-most loop of AccumulatorFragmentIterator::load() -+ int ElementsPerAccess, -+ /// Multiply-add operator -+ /// Selects between (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator_ = arch::OpMultiplyAddComplex -+> -+struct DefaultEpilogueComplexTensorOp { -+ -+ using Shape = Shape_; -+ using WarpMmaTensorOp = WarpMmaTensorOp_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputOp = OutputOp_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ using Operator = Operator_; -+ -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using LayoutC = typename WarpMmaTensorOp::LayoutC; -+ using ElementAccumulator = typename WarpMmaTensorOp::ElementC; -+ -+ // -+ // Thread map -+ // -+ -+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp< -+ Shape, -+ typename WarpMmaTensorOp::Shape, -+ kPartitionsK, -+ ElementOutput, -+ kElementsPerAccess -+ >::Type; -+ -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ OutputTileThreadMap, -+ ElementOutput -+ >; -+ -+ using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorComplexTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::ElementC, -+ typename WarpMmaTensorOp::Policy::Operator::FragmentC, -+ LayoutC -+ >; -+ -+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ ElementAccumulator, -+ LayoutC -+ >; -+ -+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< -+ typename OutputTileThreadMap::CompactedThreadMap, -+ ElementAccumulator -+ >; -+ -+ /// Hard-coded padding elements added -+ using Padding = cutlass::MatrixShape<0, 0>; -+ -+ // -+ // Define the epilogue -+ // -+ using Epilogue = cutlass::epilogue::threadblock::Epilogue< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputTileIterator, -+ AccumulatorFragmentIterator, -+ WarpTileIterator, -+ SharedLoadIterator, -+ OutputOp, -+ Padding -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Partial specialization and defines sensible defaults for epilogues for complex*complex case -+// 3 real-valued mma operations (Gaussian Complex) -+// A = (ar + j ai), B = (br +j bi), D = AB -+// P1 = (ar + ai) * br, P2 = - ar * (br - bi), P3 = ai * (br + bi) -+// D = dr + j di = (P1 - P3) + j (P1 + P2) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ typename Shape_, -+ typename WarpMmaTensorOp_, -+ int PartitionsK, -+ typename OutputOp_, -+ int ElementsPerAccess -+> -+struct DefaultEpilogueComplexTensorOp { -+ -+ using Shape = Shape_; -+ using WarpMmaTensorOp = WarpMmaTensorOp_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputOp = OutputOp_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ using Operator = arch::OpMultiplyAddGaussianComplex; -+ -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using LayoutC = typename WarpMmaTensorOp::LayoutC; -+ using ElementAccumulator = typename WarpMmaTensorOp::ElementC; -+ -+ // -+ // Thread map -+ // -+ -+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp< -+ Shape, -+ typename WarpMmaTensorOp::Shape, -+ kPartitionsK, -+ ElementOutput, -+ kElementsPerAccess -+ >::Type; -+ -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ OutputTileThreadMap, -+ ElementOutput -+ >; -+ -+ using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorGaussianComplexTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::ElementC, -+ typename WarpMmaTensorOp::Policy::Operator::FragmentC, -+ LayoutC -+ >; -+ -+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ ElementAccumulator, -+ LayoutC -+ >; -+ -+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< -+ typename OutputTileThreadMap::CompactedThreadMap, -+ ElementAccumulator -+ >; -+ -+ /// Hard-coded padding elements added -+ using Padding = cutlass::MatrixShape<0, 0>; -+ -+ // -+ // Define the epilogue -+ // -+ using Epilogue = cutlass::epilogue::threadblock::Epilogue< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputTileIterator, -+ AccumulatorFragmentIterator, -+ WarpTileIterator, -+ SharedLoadIterator, -+ OutputOp, -+ Padding -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op_blas3.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op_blas3.h -new file mode 100644 -index 0000000..966d44c ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op_blas3.h -@@ -0,0 +1,264 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped complex GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/epilogue/thread/linear_combination_relu.h" -+#include "cutlass/epilogue/thread/linear_combination_gelu.h" -+#include "cutlass/epilogue/thread/linear_combination_sigmoid.h" -+#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" -+ -+#include "cutlass/epilogue/thread/conversion_op.h" -+#include "cutlass/epilogue/thread/reduction_op.h" -+ -+#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" -+ -+#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h" -+#include "cutlass/epilogue/warp/fragment_iterator_gaussian_complex_tensor_op.h" -+#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator_blas3.h" -+#include "cutlass/epilogue/threadblock/shared_load_iterator.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Specialization and defines sensible defaults for epilogues for complex*complex case -+// 4 real-valued mma operations (Complex) -+// A = (ar + j ai), B (br +j bi), D = AB -+// D = dr + j di = (ar*br - ai*bi) + j (ar*bi + ai*br) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ /// Epilouge Shape -+ typename Shape_, -+ /// Warp-level mma operator -+ typename WarpMmaTensorOp_, -+ /// Number of k partitions -+ int PartitionsK, -+ /// Epilogue output operator -+ typename OutputOp_, -+ /// Elements accessed by inner-most loop of AccumulatorFragmentIterator::load() -+ int ElementsPerAccess, -+ /// Multiply-add operator -+ /// Selects between (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator_ = arch::OpMultiplyAddComplex, -+ /// Is for a symmetric kernel -+ BlasMode BlasMode_ = BlasMode::kGemm -+> -+struct DefaultEpilogueComplexTensorOpBlas3 { -+ -+ using Shape = Shape_; -+ using WarpMmaTensorOp = WarpMmaTensorOp_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputOp = OutputOp_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ using Operator = Operator_; -+ static BlasMode const kBlasMode = BlasMode_; -+ -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using LayoutC = typename WarpMmaTensorOp::LayoutC; -+ using ElementAccumulator = typename WarpMmaTensorOp::ElementC; -+ -+ // -+ // Thread map -+ // -+ -+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp< -+ Shape, -+ typename WarpMmaTensorOp::Shape, -+ kPartitionsK, -+ ElementOutput, -+ kElementsPerAccess -+ >::Type; -+ -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorBlas3< -+ OutputTileThreadMap, -+ ElementOutput -+ , kBlasMode -+ >; -+ -+ using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorComplexTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::ElementC, -+ typename WarpMmaTensorOp::Policy::Operator::FragmentC, -+ LayoutC -+ >; -+ -+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ ElementAccumulator, -+ LayoutC -+ >; -+ -+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< -+ typename OutputTileThreadMap::CompactedThreadMap, -+ ElementAccumulator -+ >; -+ -+ /// Hard-coded padding elements added -+ using Padding = cutlass::MatrixShape<0, 0>; -+ -+ // -+ // Define the epilogue -+ // -+ using Epilogue = cutlass::epilogue::threadblock::Epilogue< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputTileIterator, -+ AccumulatorFragmentIterator, -+ WarpTileIterator, -+ SharedLoadIterator, -+ OutputOp, -+ Padding -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Partial specialization and defines sensible defaults for epilogues for complex*complex case -+// 3 real-valued mma operations (Gaussian Complex) -+// A = (ar + j ai), B = (br +j bi), D = AB -+// P1 = (ar + ai) * br, P2 = - ar * (br - bi), P3 = ai * (br + bi) -+// D = dr + j di = (P1 - P3) + j (P1 + P2) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ typename Shape_, -+ typename WarpMmaTensorOp_, -+ int PartitionsK, -+ typename OutputOp_, -+ int ElementsPerAccess, -+ BlasMode BlasMode_ -+> -+struct DefaultEpilogueComplexTensorOpBlas3 { -+ -+ using Shape = Shape_; -+ using WarpMmaTensorOp = WarpMmaTensorOp_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputOp = OutputOp_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ using Operator = arch::OpMultiplyAddGaussianComplex; -+ static BlasMode const kBlasMode = BlasMode_; -+ -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using LayoutC = typename WarpMmaTensorOp::LayoutC; -+ using ElementAccumulator = typename WarpMmaTensorOp::ElementC; -+ -+ // -+ // Thread map -+ // -+ -+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp< -+ Shape, -+ typename WarpMmaTensorOp::Shape, -+ kPartitionsK, -+ ElementOutput, -+ kElementsPerAccess -+ >::Type; -+ -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorBlas3< -+ OutputTileThreadMap, -+ ElementOutput, -+ kBlasMode -+ >; -+ -+ using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorGaussianComplexTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::ElementC, -+ typename WarpMmaTensorOp::Policy::Operator::FragmentC, -+ LayoutC -+ >; -+ -+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ ElementAccumulator, -+ LayoutC -+ >; -+ -+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< -+ typename OutputTileThreadMap::CompactedThreadMap, -+ ElementAccumulator -+ >; -+ -+ /// Hard-coded padding elements added -+ using Padding = cutlass::MatrixShape<0, 0>; -+ -+ // -+ // Define the epilogue -+ // -+ using Epilogue = cutlass::epilogue::threadblock::Epilogue< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputTileIterator, -+ AccumulatorFragmentIterator, -+ WarpTileIterator, -+ SharedLoadIterator, -+ OutputOp, -+ Padding -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_direct_store.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_direct_store.h -new file mode 100644 -index 0000000..fc93eb0 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_direct_store.h -@@ -0,0 +1,74 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Direct store epilogue -+*/ -+ -+#pragma once -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#include "cutlass/epilogue/threadblock/epilogue_direct_store.h" -+#include "cutlass/epilogue/threadblock/direct_store_epilogue_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Given a properly constructed epilogue, returns a direct store epilogue -+template -+struct DefaultEpilogueDirectStore { -+ -+ using OutputTileIterator = DirectStoreEpilogueIterator; -+ -+ using Epilogue = EpilogueDirectStore< -+ typename EpilogueTensorOp::Shape, -+ typename EpilogueTensorOp::WarpMmaOperator, -+ EpilogueTensorOp::kPartitionsK, -+ OutputTileIterator, -+ typename EpilogueTensorOp::AccumulatorFragmentIterator, -+ typename EpilogueTensorOp::WarpTileIterator, -+ typename EpilogueTensorOp::SharedLoadIterator, -+ typename EpilogueTensorOp::OutputOp -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_planar_complex.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_planar_complex.h -new file mode 100644 -index 0000000..872e425 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_planar_complex.h -@@ -0,0 +1,241 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Constructs a default epilogue for planar complex outputs. -+ -+ This template reuses components for real-valued epilogues and applies them to planar complex -+ output matrices. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/array_planar_complex.h" -+ -+#include "cutlass/arch/arch.h" -+ -+#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue_planar_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues. -+template < -+ typename ThreadblockShape_, -+ typename WarpMma_, -+ typename OpcodeClass_, -+ typename ArchTag_, -+ int PartitionsK, -+ typename OutputOp_, -+ int ElementsPerAccess -+> -+struct DefaultEpiloguePlanarComplex; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues. -+template < -+ typename ThreadblockShape_, -+ typename WarpMmaOperator_, -+ int PartitionsK, -+ typename OutputOp_, -+ int ElementsPerAccess -+> -+struct DefaultEpiloguePlanarComplex< -+ ThreadblockShape_, -+ WarpMmaOperator_, -+ arch::OpClassTensorOp, -+ arch::Sm70, -+ PartitionsK, -+ OutputOp_, -+ ElementsPerAccess> { -+ -+ using RealEpilogue = DefaultEpilogueVoltaTensorOp< -+ ThreadblockShape_, -+ WarpMmaOperator_, -+ PartitionsK, -+ OutputOp_, -+ ElementsPerAccess -+ >; -+ -+ using Epilogue = EpiloguePlanarComplex< -+ ThreadblockShape_, -+ WarpMmaOperator_, -+ PartitionsK, -+ typename RealEpilogue::OutputTileIterator, -+ typename RealEpilogue::AccumulatorFragmentIterator, -+ typename RealEpilogue::WarpTileIterator, -+ typename RealEpilogue::SharedLoadIterator, -+ OutputOp_, -+ typename RealEpilogue::Padding -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues. -+template < -+ typename ThreadblockShape_, -+ typename WarpMmaOperator_, -+ int PartitionsK, -+ typename OutputOp_, -+ int ElementsPerAccess -+> -+struct DefaultEpiloguePlanarComplex< -+ ThreadblockShape_, -+ WarpMmaOperator_, -+ arch::OpClassTensorOp, -+ arch::Sm75, -+ PartitionsK, -+ OutputOp_, -+ ElementsPerAccess> { -+ -+ using RealEpilogue = DefaultEpilogueTensorOp< -+ ThreadblockShape_, -+ WarpMmaOperator_, -+ PartitionsK, -+ OutputOp_, -+ ElementsPerAccess -+ >; -+ -+ using Epilogue = EpiloguePlanarComplex< -+ ThreadblockShape_, -+ WarpMmaOperator_, -+ PartitionsK, -+ typename RealEpilogue::OutputTileIterator, -+ typename RealEpilogue::AccumulatorFragmentIterator, -+ typename RealEpilogue::WarpTileIterator, -+ typename RealEpilogue::SharedLoadIterator, -+ OutputOp_, -+ typename RealEpilogue::Padding -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues. -+template < -+ typename ThreadblockShape_, -+ typename WarpMmaOperator_, -+ int PartitionsK, -+ typename OutputOp_, -+ int ElementsPerAccess -+> -+struct DefaultEpiloguePlanarComplex< -+ ThreadblockShape_, -+ WarpMmaOperator_, -+ arch::OpClassTensorOp, -+ arch::Sm80, -+ PartitionsK, -+ OutputOp_, -+ ElementsPerAccess> { -+ -+ using RealEpilogue = DefaultEpilogueTensorOp< -+ ThreadblockShape_, -+ WarpMmaOperator_, -+ PartitionsK, -+ OutputOp_, -+ ElementsPerAccess -+ >; -+ -+ using Epilogue = EpiloguePlanarComplex< -+ ThreadblockShape_, -+ WarpMmaOperator_, -+ PartitionsK, -+ typename RealEpilogue::OutputTileIterator, -+ typename RealEpilogue::AccumulatorFragmentIterator, -+ typename RealEpilogue::WarpTileIterator, -+ typename RealEpilogue::SharedLoadIterator, -+ OutputOp_, -+ typename RealEpilogue::Padding -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues. -+template < -+ typename ThreadblockShape_, -+ typename WarpMmaOperator_, -+ typename ArchTag_, -+ int PartitionsK, -+ typename OutputOp_, -+ int ElementsPerAccess -+> -+struct DefaultEpiloguePlanarComplex< -+ ThreadblockShape_, -+ WarpMmaOperator_, -+ arch::OpClassSimt, -+ ArchTag_, -+ PartitionsK, -+ OutputOp_, -+ ElementsPerAccess> { -+ -+ using RealEpilogue = DefaultEpilogueSimt< -+ ThreadblockShape_, -+ WarpMmaOperator_, -+ OutputOp_, -+ ElementsPerAccess -+ >; -+ -+ using Epilogue = EpiloguePlanarComplex< -+ ThreadblockShape_, -+ WarpMmaOperator_, -+ PartitionsK, -+ typename RealEpilogue::OutputTileIterator, -+ typename RealEpilogue::AccumulatorFragmentIterator, -+ typename RealEpilogue::WarpTileIterator, -+ typename RealEpilogue::SharedLoadIterator, -+ OutputOp_, -+ typename RealEpilogue::Padding -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_simt.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_simt.h -new file mode 100644 -index 0000000..3214d19 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_simt.h -@@ -0,0 +1,422 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using SIMT. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+ -+#include "cutlass/arch/mma.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma.h" -+ -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/epilogue/thread/linear_combination_clamp.h" -+#include "cutlass/epilogue/thread/linear_combination_relu.h" -+#include "cutlass/epilogue/thread/linear_combination_gelu.h" -+#include "cutlass/epilogue/thread/linear_combination_sigmoid.h" -+#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" -+#include "cutlass/epilogue/thread/conversion_op.h" -+#include "cutlass/epilogue/thread/reduction_op.h" -+ -+#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" -+ -+#include "cutlass/epilogue/warp/fragment_iterator_simt.h" -+#include "cutlass/epilogue/warp/tile_iterator_simt.h" -+#include "cutlass/epilogue/threadblock/default_thread_map_simt.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+ -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator_direct_conv.h" -+#include "cutlass/epilogue/threadblock/shared_load_iterator.h" -+#include "cutlass/epilogue/threadblock/shared_load_iterator_pitch_liner.h" -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/threadblock/epilogue_depthwise.h" -+ -+#include "cutlass/layout/permute.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues for SimtOps. -+template < -+ typename Shape_, -+ typename WarpMmaSimt_, -+ typename OutputOp_, -+ int ElementsPerAccess, -+ bool ScatterD = false, -+ typename PermuteDLayout = layout::NoPermute -+> -+struct DefaultEpilogueSimt { -+ -+ using Shape = Shape_; -+ using WarpMmaSimt = WarpMmaSimt_; -+ using OutputOp = OutputOp_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static const int kPartitionsK = Shape::kK / WarpMmaSimt::Shape::kK; -+ -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using LayoutC = typename WarpMmaSimt::LayoutC; -+ using ElementAccumulator = typename WarpMmaSimt::ElementC; -+ -+ // -+ // Thread map -+ // -+ -+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapSimt< -+ Shape, -+ typename WarpMmaSimt::Shape, -+ typename WarpMmaSimt::Policy, -+ kPartitionsK, -+ ElementOutput, -+ kElementsPerAccess -+ >::Type; -+ -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ OutputTileThreadMap, -+ ElementOutput, -+ ScatterD, -+ PermuteDLayout -+ >; -+ -+ using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorSimt< -+ typename WarpMmaSimt::Shape, -+ typename WarpMmaSimt::ThreadMma, -+ layout::RowMajor, -+ typename WarpMmaSimt::Policy -+ >; -+ -+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorSimt< -+ typename WarpMmaSimt::Shape, -+ typename WarpMmaSimt::ThreadMma, -+ ElementAccumulator, -+ layout::RowMajor, -+ typename WarpMmaSimt::Policy -+ >; -+ -+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< -+ typename OutputTileThreadMap::CompactedThreadMap, -+ ElementAccumulator -+ >; -+ -+ /// Hard-coded padding elements added -+ using Padding = typename WarpTileIterator::Padding; -+ -+ // -+ // Define the epilogue -+ // -+ using Epilogue = cutlass::epilogue::threadblock::Epilogue< -+ Shape, -+ WarpMmaSimt, -+ kPartitionsK, -+ OutputTileIterator, -+ AccumulatorFragmentIterator, -+ WarpTileIterator, -+ SharedLoadIterator, -+ OutputOp, -+ Padding -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues for SimtOps. -+template < -+ typename Shape_, -+ typename WarpMmaSimt_, -+ typename OutputOp_, -+ int ElementsPerAccess -+> -+struct DefaultEpilogueSimtStridedDgrad { -+ -+ using Shape = Shape_; -+ using WarpMmaSimt = WarpMmaSimt_; -+ using OutputOp = OutputOp_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static const int kPartitionsK = Shape::kK / WarpMmaSimt::Shape::kK; -+ -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using LayoutC = typename WarpMmaSimt::LayoutC; -+ using ElementAccumulator = typename WarpMmaSimt::ElementC; -+ -+ // -+ // Thread map -+ // -+ -+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapSimt< -+ Shape, -+ typename WarpMmaSimt::Shape, -+ typename WarpMmaSimt::Policy, -+ kPartitionsK, -+ ElementOutput, -+ kElementsPerAccess -+ >::Type; -+ -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorStridedDgrad< -+ OutputTileThreadMap, -+ ElementOutput -+ >; -+ -+ using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorSimt< -+ typename WarpMmaSimt::Shape, -+ typename WarpMmaSimt::ThreadMma, -+ layout::RowMajor, -+ typename WarpMmaSimt::Policy -+ >; -+ -+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorSimt< -+ typename WarpMmaSimt::Shape, -+ typename WarpMmaSimt::ThreadMma, -+ ElementAccumulator, -+ layout::RowMajor, -+ typename WarpMmaSimt::Policy -+ >; -+ -+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< -+ typename OutputTileThreadMap::CompactedThreadMap, -+ ElementAccumulator -+ >; -+ -+ /// Hard-coded padding elements added -+ using Padding = typename WarpTileIterator::Padding; -+ -+ // -+ // Define the epilogue -+ // -+ using Epilogue = cutlass::epilogue::threadblock::Epilogue< -+ Shape, -+ WarpMmaSimt, -+ kPartitionsK, -+ OutputTileIterator, -+ AccumulatorFragmentIterator, -+ WarpTileIterator, -+ SharedLoadIterator, -+ OutputOp, -+ Padding -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues for SimtOps. -+template < -+ int Rank, -+ typename Shape_, -+ typename WarpMmaSimt_, -+ typename OutputOp_, -+ int ElementsPerAccess -+> -+struct DefaultEpilogueSimtAffineRankN { -+ -+ using Shape = Shape_; -+ using WarpMmaSimt = WarpMmaSimt_; -+ using OutputOp = OutputOp_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static const int kPartitionsK = Shape::kK / WarpMmaSimt::Shape::kK; -+ -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using LayoutC = typename WarpMmaSimt::LayoutC; -+ using ElementAccumulator = typename WarpMmaSimt::ElementC; -+ -+ // -+ // Thread map -+ // -+ -+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapSimt< -+ Shape, -+ typename WarpMmaSimt::Shape, -+ typename WarpMmaSimt::Policy, -+ kPartitionsK, -+ ElementOutput, -+ kElementsPerAccess -+ >::Type; -+ -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorAffineRankN< -+ OutputTileThreadMap, -+ ElementOutput, -+ Rank -+ >; -+ -+ using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorSimt< -+ typename WarpMmaSimt::Shape, -+ typename WarpMmaSimt::ThreadMma, -+ layout::RowMajor, -+ typename WarpMmaSimt::Policy -+ >; -+ -+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorSimt< -+ typename WarpMmaSimt::Shape, -+ typename WarpMmaSimt::ThreadMma, -+ ElementAccumulator, -+ layout::RowMajor, -+ typename WarpMmaSimt::Policy -+ >; -+ -+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< -+ typename OutputTileThreadMap::CompactedThreadMap, -+ ElementAccumulator -+ >; -+ -+ /// Hard-coded padding elements added -+ using Padding = typename WarpTileIterator::Padding; -+ -+ // -+ // Define the epilogue -+ // -+ using Epilogue = cutlass::epilogue::threadblock::Epilogue< -+ Shape, -+ WarpMmaSimt, -+ kPartitionsK, -+ OutputTileIterator, -+ AccumulatorFragmentIterator, -+ WarpTileIterator, -+ SharedLoadIterator, -+ OutputOp, -+ Padding -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues for SimtOps. -+template , -+ typename ThreadBlockOutputShape_ = cutlass::conv::TensorNHWCShape<1, 1, 1, 1> > -+struct DefaultDirectConvEpilogueSimt { -+ using Shape = Shape_; -+ using WarpMmaSimt = WarpMmaSimt_; -+ using WarpShape = typename WarpMmaSimt::Shape; -+ using OutputOp = OutputOp_; -+ using ThreadOutputShape = ThreadOutputShape_; -+ using ThreadBlockOutputShape = ThreadBlockOutputShape_; -+ static int const kElementsPerAccess = ElementsPerAccess_; -+ -+ -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using LayoutC = typename WarpMmaSimt::LayoutC; -+ using ElementAccumulator = typename WarpMmaSimt::ElementC; -+ -+ /// Number of threads total -+ using WarpCount = gemm::GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN -+ >; -+ -+ static int const kWarpSize = cutlass::gemm::warp::WarpSize::value; -+ -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ // -+ // Thread map -+ // -+ -+ using OutputTileThreadMap = cutlass::transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorDirectConv< -+ OutputTileThreadMap, -+ ElementOutput, -+ ThreadOutputShape, -+ ThreadBlockOutputShape -+ >; -+ -+ using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorSimt< -+ typename WarpMmaSimt::Shape, -+ typename WarpMmaSimt::ThreadMma, -+ layout::RowMajor, -+ typename WarpMmaSimt::Policy -+ >; -+ -+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorSimtDirect2dConv< -+ typename WarpMmaSimt::Shape, -+ ThreadOutputShape, -+ ThreadBlockOutputShape, -+ typename WarpMmaSimt::ThreadMma, -+ ElementAccumulator, -+ layout::RowMajor, -+ typename WarpMmaSimt::Policy -+ >; -+ -+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorPitchLiner< -+ OutputTileThreadMap, -+ ElementAccumulator -+ >; -+ -+ /// Hard-coded padding elements added -+ using Padding = typename WarpTileIterator::Padding; -+ // -+ // Define the epilogue -+ // -+ using Epilogue = cutlass::epilogue::threadblock::EpilogueDepthwise< -+ Shape, -+ ThreadOutputShape, -+ ThreadBlockOutputShape, -+ WarpMmaSimt, -+ OutputTileIterator, -+ AccumulatorFragmentIterator, -+ WarpTileIterator, -+ SharedLoadIterator, -+ OutputOp, -+ Padding -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h -new file mode 100644 -index 0000000..77411f3 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h -@@ -0,0 +1,808 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+ -+#include "cutlass/platform/platform.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/epilogue/thread/linear_combination_clamp.h" -+#include "cutlass/epilogue/thread/linear_combination_relu.h" -+#include "cutlass/epilogue/thread/linear_combination_relu0.h" -+#include "cutlass/epilogue/thread/linear_combination_gelu.h" -+#include "cutlass/epilogue/thread/linear_combination_sigmoid.h" -+#include "cutlass/epilogue/thread/linear_combination_hardswish.h" -+#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" -+ -+#include "cutlass/epilogue/thread/conversion_op.h" -+#include "cutlass/epilogue/thread/reduction_op.h" -+ -+#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" -+ -+#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h" -+#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h" -+#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" -+#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h" -+#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h" -+#include "cutlass/epilogue/threadblock/shared_load_iterator.h" -+#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/threadblock/interleaved_epilogue.h" -+ -+#include "cutlass/layout/permute.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template < -+ typename ElementOutput, -+ typename ElementAccumulator, -+ int ElementsPerAccess, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename ThreadMap -+> -+struct DefaultIteratorsTensorOp { -+ -+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape, -+ InstructionShape, -+ ElementAccumulator, -+ layout::RowMajor -+ >; -+ -+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< -+ ThreadMap, -+ ElementAccumulator -+ >; -+ -+ static int const kFragmentsPerIteration = 1; -+}; -+ -+/// Partial specialization for float <= float x 4 -+template < -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename ThreadMap -+> -+struct DefaultIteratorsTensorOp { -+ -+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape, -+ InstructionShape, -+ float, -+ layout::RowMajor -+ >; -+ -+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< -+ ThreadMap, -+ float -+ >; -+ -+ static int const kFragmentsPerIteration = 2; -+}; -+ -+/// Partial specialization for int32_t <= int32_t x 4 -+template < -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename ThreadMap -+> -+struct DefaultIteratorsTensorOp { -+ -+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape, -+ InstructionShape, -+ int32_t, -+ layout::RowMajor -+ >; -+ -+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< -+ ThreadMap, -+ int32_t -+ >; -+ -+ static int const kFragmentsPerIteration = 1; -+}; -+ -+/// Partial specialization for float <= int32_t x 4 -+template < -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename ThreadMap -+> -+struct DefaultIteratorsTensorOp { -+ -+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape, -+ InstructionShape, -+ int32_t, -+ layout::RowMajor -+ >; -+ -+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< -+ ThreadMap, -+ int32_t -+ >; -+ -+ static int const kFragmentsPerIteration = 1; -+}; -+ -+/// Partial specialization for half <= float x 8 epilogues avoids shared memory bank conflicts. -+template < -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename ThreadMap -+> -+struct DefaultIteratorsTensorOp< -+ half_t, -+ float, -+ 8, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ ThreadMap> { -+ -+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpMixed< -+ WarpShape, -+ InstructionShape, -+ float, -+ 32, -+ 16, -+ 8, -+ 8 -+ >; -+ -+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorMixed< -+ ThreadMap, -+ float, -+ 32, -+ 16, -+ 8, -+ 8 -+ >; -+ -+ static int const kFragmentsPerIteration = 2; -+}; -+ -+/// Partial specialization for int8/int4b_t <= int32 x 16/8 epilogues avoids shared memory bank conflicts. -+/// Threadblock::kN = 256 still has bank conflicts. -+template < -+ typename ElementOutput, -+ int ElementsPerAccess, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename ThreadMap -+> -+struct DefaultIteratorsTensorOp< -+ ElementOutput, -+ int32_t, -+ ElementsPerAccess, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ ThreadMap> { -+ -+ static_assert(platform::is_same::value || -+ platform::is_same::value || -+ platform::is_same::value || -+ platform::is_same::value, -+ "ElementOutput needs to be 4 or 8 bit (unsigned) int."); -+ -+ static_assert((ElementsPerAccess == 16 || ElementsPerAccess == 8), -+ "ElementsPerAccess needs to be 16 or 8."); -+ -+ using WarpTileIteratorMixed = cutlass::epilogue::warp::TileIteratorTensorOpMixed< -+ WarpShape, -+ InstructionShape, -+ int32_t, -+ 32, -+ cutlass::sizeof_bits::value, -+ ElementsPerAccess, -+ 8 -+ >; -+ -+ using WarpTileIteratorNotMixed = cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape, -+ InstructionShape, -+ int32_t, -+ layout::RowMajor -+ >; -+ -+ using WarpTileIterator = typename platform::conditional< -+ (ThreadblockShape::kN == 256), -+ WarpTileIteratorNotMixed, -+ WarpTileIteratorMixed>::type; -+ -+ using SharedLoadIteratorMixed = cutlass::epilogue::threadblock::SharedLoadIteratorMixed< -+ ThreadMap, -+ int32_t, -+ 32, -+ cutlass::sizeof_bits::value, -+ ElementsPerAccess, -+ 8 -+ >; -+ -+ using SharedLoadIteratorNotMixed = cutlass::epilogue::threadblock::SharedLoadIterator< -+ ThreadMap, -+ int32_t -+ >; -+ -+ using SharedLoadIterator = typename platform::conditional< -+ (ThreadblockShape::kN == 256), -+ SharedLoadIteratorNotMixed, -+ SharedLoadIteratorMixed>::type; -+ -+ static int const kFragmentsPerIteration = 1; -+}; -+ -+/// Partial specialization for float_e4m3_t <= float x 16/8 epilogues avoids shared memory bank conflicts. -+/// Threadblock::kN = 256 still has bank conflicts. -+template < -+ int ElementsPerAccess, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename ThreadMap -+> -+struct DefaultIteratorsTensorOp< -+ cutlass::float_e4m3_t, -+ float, -+ ElementsPerAccess, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ ThreadMap> { -+ -+ using ElementOutput = cutlass::float_e4m3_t; -+ -+ static_assert((ElementsPerAccess == 16 || ElementsPerAccess == 8), -+ "ElementsPerAccess needs to be 16 or 8."); -+ -+ using WarpTileIteratorMixed = cutlass::epilogue::warp::TileIteratorTensorOpMixed< -+ WarpShape, -+ InstructionShape, -+ float, -+ 32, -+ cutlass::sizeof_bits::value, -+ ElementsPerAccess, -+ 8 -+ >; -+ -+ using WarpTileIteratorNotMixed = cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape, -+ InstructionShape, -+ float, -+ layout::RowMajor -+ >; -+ -+ using WarpTileIterator = typename platform::conditional< -+ (ThreadblockShape::kN == 256), -+ WarpTileIteratorNotMixed, -+ WarpTileIteratorMixed>::type; -+ -+ using SharedLoadIteratorMixed = cutlass::epilogue::threadblock::SharedLoadIteratorMixed< -+ ThreadMap, -+ float, -+ 32, -+ cutlass::sizeof_bits::value, -+ ElementsPerAccess, -+ 8 -+ >; -+ -+ using SharedLoadIteratorNotMixed = cutlass::epilogue::threadblock::SharedLoadIterator< -+ ThreadMap, -+ float -+ >; -+ -+ using SharedLoadIterator = typename platform::conditional< -+ (ThreadblockShape::kN == 256), -+ SharedLoadIteratorNotMixed, -+ SharedLoadIteratorMixed>::type; -+ -+ static int const kFragmentsPerIteration = 1; -+}; -+ -+/// Partial specialization for float_e5m2_t <= float x 16/8 epilogues avoids shared memory bank conflicts. -+/// Threadblock::kN = 256 still has bank conflicts. -+template < -+ int ElementsPerAccess, -+ typename ThreadblockShape, -+ typename WarpShape, -+ typename InstructionShape, -+ typename ThreadMap -+> -+struct DefaultIteratorsTensorOp< -+ cutlass::float_e5m2_t, -+ float, -+ ElementsPerAccess, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ ThreadMap> { -+ -+ using ElementOutput = cutlass::float_e5m2_t; -+ -+ static_assert((ElementsPerAccess == 16 || ElementsPerAccess == 8), -+ "ElementsPerAccess needs to be 16 or 8."); -+ -+ using WarpTileIteratorMixed = cutlass::epilogue::warp::TileIteratorTensorOpMixed< -+ WarpShape, -+ InstructionShape, -+ float, -+ 32, -+ cutlass::sizeof_bits::value, -+ ElementsPerAccess, -+ 8 -+ >; -+ -+ using WarpTileIteratorNotMixed = cutlass::epilogue::warp::TileIteratorTensorOp< -+ WarpShape, -+ InstructionShape, -+ float, -+ layout::RowMajor -+ >; -+ -+ using WarpTileIterator = typename platform::conditional< -+ (ThreadblockShape::kN == 256), -+ WarpTileIteratorNotMixed, -+ WarpTileIteratorMixed>::type; -+ -+ using SharedLoadIteratorMixed = cutlass::epilogue::threadblock::SharedLoadIteratorMixed< -+ ThreadMap, -+ float, -+ 32, -+ cutlass::sizeof_bits::value, -+ ElementsPerAccess, -+ 8 -+ >; -+ -+ using SharedLoadIteratorNotMixed = cutlass::epilogue::threadblock::SharedLoadIterator< -+ ThreadMap, -+ float -+ >; -+ -+ using SharedLoadIterator = typename platform::conditional< -+ (ThreadblockShape::kN == 256), -+ SharedLoadIteratorNotMixed, -+ SharedLoadIteratorMixed>::type; -+ -+ static int const kFragmentsPerIteration = 1; -+}; -+ -+} // namespace detail -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues for TensorOps. -+template < -+ typename Shape_, -+ typename WarpMmaTensorOp_, -+ int PartitionsK, -+ typename OutputOp_, -+ int ElementsPerAccess, -+ bool ScatterD = false, -+ typename PermuteDLayout = layout::NoPermute -+> -+struct DefaultEpilogueTensorOp { -+ -+ using Shape = Shape_; -+ using WarpMmaTensorOp = WarpMmaTensorOp_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputOp = OutputOp_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using LayoutC = typename WarpMmaTensorOp::LayoutC; -+ using ElementAccumulator = typename WarpMmaTensorOp::ElementC; -+ -+ // -+ // Thread map -+ // -+ -+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp< -+ Shape, -+ typename WarpMmaTensorOp::Shape, -+ kPartitionsK, -+ ElementOutput, -+ kElementsPerAccess -+ >::Type; -+ -+ static bool const UseCUDAStore = platform::is_same::value; -+ -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ OutputTileThreadMap, -+ ElementOutput, -+ ScatterD, -+ PermuteDLayout, -+ UseCUDAStore -+ >; -+ -+ using AccumulatorFragmentIterator = typename platform::conditional::value, -+ cutlass::epilogue::warp::FragmentIteratorComplexTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::ElementC, -+ typename WarpMmaTensorOp::Policy::Operator::FragmentC, -+ LayoutC>, -+ cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::ElementC, -+ typename WarpMmaTensorOp::Policy::Operator::FragmentC, -+ LayoutC> >::type; -+ -+ /// Support several implementations depending on structure of epilogue -+ using DefaultIterators = detail::DefaultIteratorsTensorOp< -+ ElementOutput, -+ ElementAccumulator, -+ kElementsPerAccess, -+ Shape, -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename OutputTileThreadMap::CompactedThreadMap -+ >; -+ -+ using WarpTileIterator = typename DefaultIterators::WarpTileIterator; -+ using SharedLoadIterator = typename DefaultIterators::SharedLoadIterator; -+ -+ /// Hard-coded padding elements added -+ using Padding = cutlass::MatrixShape<0, 64 / sizeof_bits::value * 4>; -+ -+ static int const kFragmentsPerIteration = (kPartitionsK == 1 ? DefaultIterators::kFragmentsPerIteration : 1); -+ -+ // -+ // Define the epilogue -+ // -+ using Epilogue = cutlass::epilogue::threadblock::Epilogue< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputTileIterator, -+ AccumulatorFragmentIterator, -+ WarpTileIterator, -+ SharedLoadIterator, -+ OutputOp, -+ Padding, -+ kFragmentsPerIteration -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues for TensorOps. -+template < -+ typename Shape_, -+ typename WarpMmaTensorOp_, -+ int PartitionsK, -+ typename OutputOp_, -+ int ElementsPerAccess -+> -+struct DefaultEpilogueTensorOpStridedDgrad { -+ -+ using Shape = Shape_; -+ using WarpMmaTensorOp = WarpMmaTensorOp_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputOp = OutputOp_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using LayoutC = typename WarpMmaTensorOp::LayoutC; -+ using ElementAccumulator = typename WarpMmaTensorOp::ElementC; -+ -+ // -+ // Thread map -+ // -+ -+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp< -+ Shape, -+ typename WarpMmaTensorOp::Shape, -+ kPartitionsK, -+ ElementOutput, -+ kElementsPerAccess -+ >::Type; -+ -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorStridedDgrad< -+ OutputTileThreadMap, -+ ElementOutput -+ >; -+ -+ using AccumulatorFragmentIterator = typename platform::conditional::value, -+ cutlass::epilogue::warp::FragmentIteratorComplexTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::ElementC, -+ typename WarpMmaTensorOp::Policy::Operator::FragmentC, -+ LayoutC>, -+ cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::ElementC, -+ typename WarpMmaTensorOp::Policy::Operator::FragmentC, -+ LayoutC> >::type; -+ -+ /// Support several implementations depending on structure of epilogue -+ using DefaultIterators = detail::DefaultIteratorsTensorOp< -+ ElementOutput, -+ ElementAccumulator, -+ kElementsPerAccess, -+ Shape, -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename OutputTileThreadMap::CompactedThreadMap -+ >; -+ -+ using WarpTileIterator = typename DefaultIterators::WarpTileIterator; -+ using SharedLoadIterator = typename DefaultIterators::SharedLoadIterator; -+ -+ /// Hard-coded padding elements added -+ using Padding = cutlass::MatrixShape<0, 64 / sizeof_bits::value * 4>; -+ -+ static int const kFragmentsPerIteration = (kPartitionsK == 1 ? DefaultIterators::kFragmentsPerIteration : 1); -+ -+ // -+ // Define the epilogue -+ // -+ using Epilogue = cutlass::epilogue::threadblock::Epilogue< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputTileIterator, -+ AccumulatorFragmentIterator, -+ WarpTileIterator, -+ SharedLoadIterator, -+ OutputOp, -+ Padding, -+ kFragmentsPerIteration -+ >; -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues for TensorOps. -+template < -+ int Rank, -+ typename Shape_, -+ typename WarpMmaTensorOp_, -+ int PartitionsK, -+ typename OutputOp_, -+ int ElementsPerAccess -+> -+struct DefaultEpilogueTensorOpAffineRankN { -+ -+ using Shape = Shape_; -+ using WarpMmaTensorOp = WarpMmaTensorOp_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputOp = OutputOp_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using LayoutC = typename WarpMmaTensorOp::LayoutC; -+ using ElementAccumulator = typename WarpMmaTensorOp::ElementC; -+ -+ // -+ // Thread map -+ // -+ -+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp< -+ Shape, -+ typename WarpMmaTensorOp::Shape, -+ kPartitionsK, -+ ElementOutput, -+ kElementsPerAccess -+ >::Type; -+ -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorAffineRankN< -+ OutputTileThreadMap, -+ ElementOutput, -+ Rank -+ >; -+ -+ // Map to the row major iterator since the iterator selection for affineN is the same. -+ using AccumulatorFragmentIterator = typename platform::conditional::value, -+ cutlass::epilogue::warp::FragmentIteratorComplexTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::ElementC, -+ typename WarpMmaTensorOp::Policy::Operator::FragmentC, -+ layout::RowMajor>, -+ cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::ElementC, -+ typename WarpMmaTensorOp::Policy::Operator::FragmentC, -+ layout::RowMajor> >::type; -+ -+ /// Support several implementations depending on structure of epilogue -+ using DefaultIterators = detail::DefaultIteratorsTensorOp< -+ ElementOutput, -+ ElementAccumulator, -+ kElementsPerAccess, -+ Shape, -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename OutputTileThreadMap::CompactedThreadMap -+ >; -+ -+ using WarpTileIterator = typename DefaultIterators::WarpTileIterator; -+ using SharedLoadIterator = typename DefaultIterators::SharedLoadIterator; -+ -+ /// Hard-coded padding elements added -+ using Padding = cutlass::MatrixShape<0, 64 / sizeof_bits::value * 4>; -+ -+ static int const kFragmentsPerIteration = (kPartitionsK == 1 ? DefaultIterators::kFragmentsPerIteration : 1); -+ -+ // -+ // Define the epilogue -+ // -+ using Epilogue = cutlass::epilogue::threadblock::Epilogue< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputTileIterator, -+ AccumulatorFragmentIterator, -+ WarpTileIterator, -+ SharedLoadIterator, -+ OutputOp, -+ Padding, -+ kFragmentsPerIteration -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// Defines sensible defaults for epilogues for TensorOps which uses -+/// intereleaved output layout. For this case, shared memory is not needed. -+template -+struct DefaultInterleavedEpilogueTensorOp { -+ using Shape = Shape_; -+ using WarpMmaTensorOp = WarpMmaTensorOp_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputOp = OutputOp_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using LayoutC = typename WarpMmaTensorOp::LayoutC; -+ using ElementAccumulator = typename WarpMmaTensorOp::ElementC; -+ -+ // -+ // Thread map -+ // -+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock:: -+ DefaultInterleavedThreadMapTensorOp< -+ Shape, typename WarpMmaTensorOp::Shape, kPartitionsK, ElementOutput, -+ kElementsPerAccess, InterleavedK>::Type; -+ -+ using OutputTileIterator = -+ cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator< -+ OutputTileThreadMap, ElementOutput, InterleavedK>; -+ -+ using AccumulatorFragmentIterator = -+ cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::ElementC, -+ typename WarpMmaTensorOp::Policy::Operator::FragmentC, -+ LayoutC>; -+ -+ // -+ // Define the epilogue -+ // -+ using Epilogue = cutlass::epilogue::threadblock::InterleavedEpilogue< -+ Shape, WarpMmaTensorOp, kPartitionsK, OutputTileIterator, -+ AccumulatorFragmentIterator, OutputOp, InterleavedK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues for TensorOps which uses -+/// intereleaved output layout. For this case, shared memory is not needed. -+template -+struct DefaultInterleavedConvEpilogue { -+ using Shape = Shape_; -+ using WarpMmaTensorOp = WarpMmaTensorOp_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputOp = OutputOp_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using ElementAccumulator = typename WarpMmaTensorOp::ElementC; -+ -+ // -+ // Thread map -+ // -+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock:: -+ DefaultInterleavedConvThreadMapTensorOp< -+ Shape, typename WarpMmaTensorOp::Shape, kPartitionsK, ElementOutput, -+ kElementsPerAccess, InterleavedK>::Type; -+ -+ using OutputTileIterator = -+ cutlass::epilogue::threadblock::InterleavedConvPredicatedTileIterator< -+ OutputTileThreadMap, ElementOutput, InterleavedK>; -+ -+ using AccumulatorFragmentIterator = -+ cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::ElementC, -+ typename WarpMmaTensorOp::Policy::Operator::FragmentC, -+ // can reuse the gemm version here to do element selection -+ layout::ColumnMajorInterleaved>; -+ -+ // -+ // Define the epilogue -+ // -+ using Epilogue = cutlass::epilogue::threadblock::InterleavedEpilogue< -+ Shape, WarpMmaTensorOp, kPartitionsK, OutputTileIterator, -+ AccumulatorFragmentIterator, OutputOp, InterleavedK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op_blas3.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op_blas3.h -new file mode 100644 -index 0000000..aef4961 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op_blas3.h -@@ -0,0 +1,175 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/epilogue/thread/linear_combination_clamp.h" -+#include "cutlass/epilogue/thread/linear_combination_relu.h" -+#include "cutlass/epilogue/thread/linear_combination_gelu.h" -+#include "cutlass/epilogue/thread/linear_combination_sigmoid.h" -+#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" -+ -+#include "cutlass/epilogue/thread/conversion_op.h" -+#include "cutlass/epilogue/thread/reduction_op.h" -+ -+#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" -+ -+#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h" -+#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h" -+#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" -+#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h" -+#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator_blas3.h" -+#include "cutlass/epilogue/threadblock/shared_load_iterator.h" -+#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/threadblock/interleaved_epilogue.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues for TensorOps. -+template < -+ typename Shape_, -+ typename WarpMmaTensorOp_, -+ int PartitionsK, -+ typename OutputOp_, -+ int ElementsPerAccess, -+ /// Is for a symmetric kernel -+ BlasMode BlasMode_ = BlasMode::kGemm -+> -+struct DefaultEpilogueTensorOpBlas3 { -+ -+ using Shape = Shape_; -+ using WarpMmaTensorOp = WarpMmaTensorOp_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputOp = OutputOp_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static BlasMode const kBlasMode = BlasMode_; -+ -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using LayoutC = typename WarpMmaTensorOp::LayoutC; -+ using ElementAccumulator = typename WarpMmaTensorOp::ElementC; -+ -+ // -+ // Thread map -+ // -+ -+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp< -+ Shape, -+ typename WarpMmaTensorOp::Shape, -+ kPartitionsK, -+ ElementOutput, -+ kElementsPerAccess -+ >::Type; -+ -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorBlas3< -+ OutputTileThreadMap, -+ ElementOutput, -+ kBlasMode -+ >; -+ -+ using AccumulatorFragmentIterator = typename std::conditional::value, -+ cutlass::epilogue::warp::FragmentIteratorComplexTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::ElementC, -+ typename WarpMmaTensorOp::Policy::Operator::FragmentC, -+ LayoutC>, -+ cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::ElementC, -+ typename WarpMmaTensorOp::Policy::Operator::FragmentC, -+ LayoutC> >::type; -+ -+ /// Support several implementations depending on structure of epilogue -+ using DefaultIterators = detail::DefaultIteratorsTensorOp< -+ ElementOutput, -+ ElementAccumulator, -+ kElementsPerAccess, -+ Shape, -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename OutputTileThreadMap::CompactedThreadMap -+ >; -+ -+ using WarpTileIterator = typename DefaultIterators::WarpTileIterator; -+ using SharedLoadIterator = typename DefaultIterators::SharedLoadIterator; -+ -+ /// Hard-coded padding elements added -+ using Padding = cutlass::MatrixShape<0, 64 / sizeof_bits::value * 4>; -+ -+ // -+ // Define the epilogue -+ // -+ using Epilogue = cutlass::epilogue::threadblock::Epilogue< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputTileIterator, -+ AccumulatorFragmentIterator, -+ WarpTileIterator, -+ SharedLoadIterator, -+ OutputOp, -+ Padding -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h -new file mode 100644 -index 0000000..9936f96 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h -@@ -0,0 +1,337 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops on Volta. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/epilogue/thread/linear_combination_clamp.h" -+#include "cutlass/epilogue/thread/linear_combination_relu.h" -+#include "cutlass/epilogue/thread/linear_combination_gelu.h" -+#include "cutlass/epilogue/thread/linear_combination_sigmoid.h" -+#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" -+ -+#include "cutlass/epilogue/thread/conversion_op.h" -+#include "cutlass/epilogue/thread/reduction_op.h" -+ -+#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h" -+#include "cutlass/epilogue/threadblock/shared_load_iterator.h" -+ -+#include "cutlass/epilogue/warp/fragment_iterator_volta_tensor_op.h" -+#include "cutlass/epilogue/warp/tile_iterator_volta_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_thread_map_volta_tensor_op.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+ -+#include "cutlass/layout/permute.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues for TensorOps. -+template < -+ typename Shape_, -+ typename WarpMmaTensorOp_, -+ int PartitionsK, -+ typename OutputOp_, -+ int ElementsPerAccess, -+ bool ScatterD = false, -+ typename PermuteDLayout = layout::NoPermute -+> -+struct DefaultEpilogueVoltaTensorOp { -+ -+ using Shape = Shape_; -+ using WarpMmaTensorOp = WarpMmaTensorOp_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputOp = OutputOp_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using LayoutC = typename WarpMmaTensorOp::LayoutC; -+ using ElementAccumulator = typename WarpMmaTensorOp::ElementC; -+ -+ // -+ // Thread map -+ // -+ -+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ typename WarpMmaTensorOp::Shape, -+ kPartitionsK, -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator -+ >::Type; -+ -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ OutputTileThreadMap, -+ ElementOutput, -+ ScatterD, -+ PermuteDLayout -+ >; -+ -+ using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorVoltaTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ gemm::GemmShape<32, 32, 4>, -+ ElementAccumulator, -+ LayoutC -+ >; -+ -+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorVoltaTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ gemm::GemmShape<32, 32, 4>, -+ ElementAccumulator, -+ LayoutC -+ >; -+ -+ static int const kSharedMemAlignment = sizeof_bits::value * WarpTileIterator::kElementsPerAccess / 8; -+ -+ static_assert(kSharedMemAlignment == 8, "Shared memory alignment must be 8B"); -+ -+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< -+ typename OutputTileThreadMap::CompactedThreadMap, -+ ElementAccumulator, -+ kSharedMemAlignment -+ >; -+ -+ /// Hard-coded padding elements added -+ using Padding = typename WarpTileIterator::Padding; -+ -+ // -+ // Define the epilogue -+ // -+ using Epilogue = cutlass::epilogue::threadblock::Epilogue< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputTileIterator, -+ AccumulatorFragmentIterator, -+ WarpTileIterator, -+ SharedLoadIterator, -+ OutputOp, -+ Padding -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues for TensorOps. -+template < -+ typename Shape_, -+ typename WarpMmaTensorOp_, -+ int PartitionsK, -+ typename OutputOp_, -+ int ElementsPerAccess -+> -+struct DefaultEpilogueVoltaTensorOpStridedDgrad { -+ -+ using Shape = Shape_; -+ using WarpMmaTensorOp = WarpMmaTensorOp_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputOp = OutputOp_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using LayoutC = typename WarpMmaTensorOp::LayoutC; -+ using ElementAccumulator = typename WarpMmaTensorOp::ElementC; -+ -+ // -+ // Thread map -+ // -+ -+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ typename WarpMmaTensorOp::Shape, -+ kPartitionsK, -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator -+ >::Type; -+ -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorStridedDgrad< -+ OutputTileThreadMap, -+ ElementOutput -+ >; -+ -+ using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorVoltaTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ gemm::GemmShape<32, 32, 4>, -+ ElementAccumulator, -+ LayoutC -+ >; -+ -+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorVoltaTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ gemm::GemmShape<32, 32, 4>, -+ ElementAccumulator, -+ LayoutC -+ >; -+ -+ static int const kSharedMemAlignment = sizeof_bits::value * WarpTileIterator::kElementsPerAccess / 8; -+ -+ static_assert(kSharedMemAlignment == 8, "Shared memory alignment must be 8B"); -+ -+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< -+ typename OutputTileThreadMap::CompactedThreadMap, -+ ElementAccumulator, -+ kSharedMemAlignment -+ >; -+ -+ /// Hard-coded padding elements added -+ using Padding = typename WarpTileIterator::Padding; -+ -+ // -+ // Define the epilogue -+ // -+ using Epilogue = cutlass::epilogue::threadblock::Epilogue< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputTileIterator, -+ AccumulatorFragmentIterator, -+ WarpTileIterator, -+ SharedLoadIterator, -+ OutputOp, -+ Padding -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues for TensorOps. -+template < -+ int Rank, -+ typename Shape_, -+ typename WarpMmaTensorOp_, -+ int PartitionsK, -+ typename OutputOp_, -+ int ElementsPerAccess -+> -+struct DefaultEpilogueVoltaTensorOpAffineRankN { -+ -+ using Shape = Shape_; -+ using WarpMmaTensorOp = WarpMmaTensorOp_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputOp = OutputOp_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using LayoutC = typename WarpMmaTensorOp::LayoutC; -+ using ElementAccumulator = typename WarpMmaTensorOp::ElementC; -+ -+ // -+ // Thread map -+ // -+ -+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ typename WarpMmaTensorOp::Shape, -+ kPartitionsK, -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator -+ >::Type; -+ -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorAffineRankN< -+ OutputTileThreadMap, -+ ElementOutput, -+ Rank -+ >; -+ -+ using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorVoltaTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ gemm::GemmShape<32, 32, 4>, -+ ElementAccumulator, -+ LayoutC -+ >; -+ -+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorVoltaTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ gemm::GemmShape<32, 32, 4>, -+ ElementAccumulator, -+ LayoutC -+ >; -+ -+ static int const kSharedMemAlignment = sizeof_bits::value * WarpTileIterator::kElementsPerAccess / 8; -+ -+ static_assert(kSharedMemAlignment == 8, "Shared memory alignment must be 8B"); -+ -+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< -+ typename OutputTileThreadMap::CompactedThreadMap, -+ ElementAccumulator, -+ kSharedMemAlignment -+ >; -+ -+ /// Hard-coded padding elements added -+ using Padding = typename WarpTileIterator::Padding; -+ -+ // -+ // Define the epilogue -+ // -+ using Epilogue = cutlass::epilogue::threadblock::Epilogue< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputTileIterator, -+ AccumulatorFragmentIterator, -+ WarpTileIterator, -+ SharedLoadIterator, -+ OutputOp, -+ Padding -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h -new file mode 100644 -index 0000000..381cb30 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h -@@ -0,0 +1,183 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h" -+ -+#include "cutlass/layout/permute.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues for TensorOps. -+template < -+ typename Shape, -+ typename WarpMmaTensorOp, -+ int PartitionsK, -+ typename ElementOutput, -+ typename ElementTensor, -+ typename ElementVector, -+ typename OutputOp, -+ int ElementsPerAccess, -+ bool ScatterD = false, -+ typename PermuteDLayout = layout::NoPermute -+> -+struct DefaultEpilogueWithBroadcastTensorOp { -+ -+ /// Use defaults related to the existing epilogue -+ using Base = DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ OutputOp, -+ ElementsPerAccess -+ >; -+ -+ // -+ // Stores the result z = (y = GEMM(A, B, C), broadcast) -+ // -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ typename Base::OutputTileThreadMap, -+ ElementOutput, -+ ScatterD, -+ PermuteDLayout -+ >; -+ -+ // -+ // Additional tensor tile iterator - stores t = Elementwise(z) -+ // -+ using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ typename Base::OutputTileThreadMap, -+ ElementTensor -+ >; -+ -+ /// Define the epilogue -+ using Epilogue = EpilogueWithBroadcast< -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ OutputTileIterator, -+ TensorTileIterator, -+ ElementVector, -+ typename Base::AccumulatorFragmentIterator, -+ typename Base::WarpTileIterator, -+ typename Base::SharedLoadIterator, -+ OutputOp, -+ typename Base::Padding, -+ Base::kFragmentsPerIteration -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues for VoltaTensorOps. -+template < -+ typename Shape, -+ typename WarpMmaTensorOp, -+ int PartitionsK, -+ typename ElementOutput, -+ typename ElementTensor, -+ typename ElementVector, -+ typename OutputOp, -+ int ElementsPerAccess -+> -+struct DefaultEpilogueWithBroadcastVoltaTensorOp { -+ -+ /// Use defaults related to the existing epilogue -+ using Base = DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ OutputOp, -+ ElementsPerAccess -+ >; -+ -+ // -+ // Stores the result z = (y = GEMM(A, B, C), broadcast) -+ // -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ typename Base::OutputTileThreadMap, -+ ElementOutput -+ >; -+ -+ // -+ // Additional tensor tile iterator - stores t = Elementwise(z) -+ // -+ using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ typename Base::OutputTileThreadMap, -+ ElementTensor -+ >; -+ -+ /// Define the epilogue -+ using Epilogue = EpilogueWithBroadcast< -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ OutputTileIterator, -+ TensorTileIterator, -+ ElementVector, -+ typename Base::AccumulatorFragmentIterator, -+ typename Base::WarpTileIterator, -+ typename Base::SharedLoadIterator, -+ OutputOp, -+ typename Base::Padding -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_with_reduction.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_with_reduction.h -new file mode 100644 -index 0000000..3c85551 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_with_reduction.h -@@ -0,0 +1,177 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/threadblock/epilogue_with_reduction.h" -+ -+#include "cutlass/layout/permute.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues for TensorOps. -+template < -+ typename Shape, -+ typename WarpMmaTensorOp, -+ int PartitionsK, -+ typename ElementOutput, -+ typename OutputOp, -+ typename ReductionOp, -+ int ElementsPerAccess, -+ bool ScatterD = false, -+ typename PermuteDLayout = layout::NoPermute -+> -+struct DefaultEpilogueWithReductionTensorOp { -+ -+ /// Use defaults related to the existing epilogue -+ using Base = DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ OutputOp, -+ ElementsPerAccess -+ >; -+ -+ /// Additional tensor tile iterator -+ using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ typename Base::OutputTileThreadMap, -+ typename OutputOp::ElementTensor -+ >; -+ -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ typename Base::OutputTileThreadMap, -+ ElementOutput, -+ ScatterD, -+ PermuteDLayout -+ >; -+ -+ /// Define the epilogue -+ using Epilogue = EpilogueWithReduction< -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ OutputTileIterator, -+ TensorTileIterator, -+ typename WarpMmaTensorOp::ElementC, -+ typename Base::AccumulatorFragmentIterator, -+ typename Base::WarpTileIterator, -+ typename Base::SharedLoadIterator, -+ typename Base::OutputOp, -+ ReductionOp, -+ typename Base::Padding -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues for TensorOps. -+template < -+ typename Shape, -+ typename WarpMmaTensorOp, -+ int PartitionsK, -+ typename ElementOutput, -+ typename OutputOp, -+ typename ReductionOp, -+ int ElementsPerAccess, -+ bool ScatterD = false, -+ typename PermuteDLayout = layout::NoPermute -+> -+struct DefaultEpilogueWithReductionVoltaTensorOp { -+ -+ /// Use defaults related to the existing epilogue -+ using Base = DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ OutputOp, -+ ElementsPerAccess -+ >; -+ -+ /// Additional tensor tile iterator -+ using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ typename Base::OutputTileThreadMap, -+ typename OutputOp::ElementTensor -+ >; -+ -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ typename Base::OutputTileThreadMap, -+ ElementOutput, -+ ScatterD, -+ PermuteDLayout -+ >; -+ -+ /// Define the epilogue -+ using Epilogue = EpilogueWithReduction< -+ Shape, -+ WarpMmaTensorOp, -+ PartitionsK, -+ OutputTileIterator, -+ TensorTileIterator, -+ typename WarpMmaTensorOp::ElementC, -+ typename Base::AccumulatorFragmentIterator, -+ typename Base::WarpTileIterator, -+ typename Base::SharedLoadIterator, -+ typename Base::OutputOp, -+ ReductionOp, -+ typename Base::Padding -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h -new file mode 100644 -index 0000000..f95e4ea ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h -@@ -0,0 +1,165 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using WMMA. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/epilogue/thread/linear_combination_clamp.h" -+#include "cutlass/epilogue/thread/linear_combination_relu.h" -+#include "cutlass/epilogue/thread/linear_combination_gelu.h" -+#include "cutlass/epilogue/thread/linear_combination_sigmoid.h" -+#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" -+ -+#include "cutlass/epilogue/thread/conversion_op.h" -+#include "cutlass/epilogue/thread/reduction_op.h" -+ -+#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" -+ -+#include "cutlass/epilogue/warp/fragment_iterator_wmma_tensor_op.h" -+#include "cutlass/epilogue/warp/tile_iterator_wmma_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_thread_map_wmma_tensor_op.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+#include "cutlass/epilogue/threadblock/shared_load_iterator.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+ -+#include "cutlass/layout/permute.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines sensible defaults for epilogues for WMMA TensorOps. -+template < -+ typename Shape_, -+ typename WarpMmaTensorOp_, -+ int PartitionsK, -+ typename OutputOp_, -+ int ElementsPerAccess, -+ bool ScatterD = false, -+ typename PermuteDLayout = layout::NoPermute -+> -+struct DefaultEpilogueWmmaTensorOp { -+ -+ using Shape = Shape_; -+ using WarpMmaTensorOp = WarpMmaTensorOp_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputOp = OutputOp_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using LayoutC = typename WarpMmaTensorOp::LayoutC; -+ using ElementAccumulator = typename WarpMmaTensorOp::ElementC; -+ -+ // -+ // Thread map -+ // -+ -+ using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapWmmaTensorOp< -+ Shape, -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ kPartitionsK, -+ ElementOutput, -+ kElementsPerAccess -+ >::Type; -+ -+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ OutputTileThreadMap, -+ ElementOutput, -+ ScatterD, -+ PermuteDLayout -+ >; -+ -+ using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorWmmaTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::ElementC, -+ typename WarpMmaTensorOp::Policy::Operator::FragmentC, -+ LayoutC -+ >; -+ -+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorWmmaTensorOp< -+ typename WarpMmaTensorOp::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::Shape, -+ typename WarpMmaTensorOp::Policy::Operator::FragmentC, -+ LayoutC -+ >; -+ -+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< -+ typename OutputTileThreadMap::CompactedThreadMap, -+ ElementAccumulator -+ >; -+ -+ /// Hard-coded padding elements added -+ using Padding = typename WarpTileIterator::Padding; -+ -+ // -+ // Define the epilogue -+ // -+ using Epilogue = cutlass::epilogue::threadblock::Epilogue< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputTileIterator, -+ AccumulatorFragmentIterator, -+ WarpTileIterator, -+ SharedLoadIterator, -+ OutputOp, -+ Padding -+ >; -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_simt.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_simt.h -new file mode 100644 -index 0000000..363d1e5 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_simt.h -@@ -0,0 +1,127 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief -+ -+*/ -+ -+#pragma once -+ -+#include "predicated_tile_iterator.h" -+#include "cutlass/gemm/gemm.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines the optimal thread map for SIMT accumulator layouts -+template < -+ typename ThreadblockShape_, -+ typename WarpShape_, -+ typename MmaSimtPolicy_, -+ int PartitionsK, -+ typename Element_, -+ int ElementsPerAccess -+> -+struct DefaultThreadMapSimt { -+ -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using MmaSimtPolicy = MmaSimtPolicy_; -+ static int const kPartitionsK = PartitionsK; -+ using Element = Element_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ // -+ // Definitions -+ // -+ -+ struct Detail { -+ -+ static int const kWarpSize = 32; -+ -+ static_assert( -+ !(ThreadblockShape::kM % WarpShape::kM) && -+ !(ThreadblockShape::kN % WarpShape::kN), "Divisibility"); -+ -+ /// Number of warps -+ using WarpCount = gemm::GemmShape< -+ ThreadblockShape::kM / WarpShape::kM, -+ ThreadblockShape::kN / WarpShape::kN, -+ kPartitionsK -+ >; -+ -+ /// Computes number of thread-level matrix multiplies are needed to span a warp -+ static int const kGroupCount = -+ WarpShape::kM / (MmaSimtPolicy::WarpShape::kRow * MmaSimtPolicy::LaneMmaShape::kM); -+ -+ /// Number of participating threads -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Number of iterations -+ static int const kIterations = MmaSimtPolicy::LaneMmaShape::kM * kGroupCount; -+ }; -+ -+ // -+ // ThreadMap -+ // -+ -+ /// ThreadMap to be used by epilogue::PredicatedTileIterator satisfying concept OutputTileThreadMap -+ using Type = OutputTileOptimalThreadMap< -+ OutputTileShape< // Shape -+ ThreadblockShape::kN, -+ 1, -+ MmaSimtPolicy::WarpShape::kRow, -+ Detail::WarpCount::kM, -+ 1>, -+ OutputTileShape< // Count -+ 1, -+ MmaSimtPolicy::LaneMmaShape::kM, -+ Detail::kGroupCount, -+ 1, -+ Detail::kIterations>, -+ Detail::kThreads, -+ kElementsPerAccess, -+ sizeof_bits::value -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_tensor_op.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_tensor_op.h -new file mode 100644 -index 0000000..14972d0 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_tensor_op.h -@@ -0,0 +1,208 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief -+ -+*/ -+ -+#pragma once -+ -+#include "predicated_tile_iterator.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/layout/pitch_linear.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines the optimal thread map for TensorOp accumulator layouts -+template < -+ typename ThreadblockShape_, -+ typename WarpShape_, -+ int PartitionsK, -+ typename Element_, -+ int ElementsPerAccess -+> -+struct DefaultThreadMapTensorOp { -+ -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ static int const kPartitionsK = PartitionsK; -+ using Element = Element_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ // -+ // Definitions -+ // -+ -+ struct Detail { -+ -+ /// Tensor Operations fundamentally perform operations on 8 rows -+ static int const kTensorOpRows = 8; -+ static int const kWarpSize = 32; -+ -+ static_assert( -+ !(ThreadblockShape::kM % WarpShape::kM) && -+ !(ThreadblockShape::kN % WarpShape::kN), "Divisibility"); -+ -+ /// Number of warps -+ using WarpCount = gemm::GemmShape< -+ ThreadblockShape::kM / WarpShape::kM, -+ ThreadblockShape::kN / WarpShape::kN, -+ kPartitionsK -+ >; -+ -+ /// Number of participating threads -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ }; -+ -+ // -+ // ThreadMap -+ // -+ -+ /// ThreadMap to be used by epilogue::PredicatedTileIterator satisfying concept OutputTileThreadMap -+ using Type = OutputTileOptimalThreadMap < -+ OutputTileShape, -+ OutputTileShape<1, WarpShape::kM / Detail::kTensorOpRows, 1, 1, WarpShape::kM / Detail::kTensorOpRows>, -+ Detail::kThreads, -+ kElementsPerAccess, -+ sizeof_bits::value -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines the optimal thread map for TensorOp accumulator layouts -+template -+struct DefaultInterleavedThreadMapTensorOp { -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ static int const kPartitionsK = PartitionsK; -+ using Element = Element_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static int const kInterleavedK = InterleavedK; -+ -+ // -+ // Definitions -+ // -+ -+ struct Detail { -+ /// Tensor Operations fundamentally perform operations on 8 rows -+ static int const kTensorOpRows = 8; -+ static int const kWarpSize = 32; -+ -+ static_assert(!(ThreadblockShape::kM % WarpShape::kM) && -+ !(ThreadblockShape::kN % WarpShape::kN), -+ "Divisibility"); -+ -+ /// Number of warps -+ using WarpCount = -+ gemm::GemmShape; -+ -+ /// Number of participating threads -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ }; -+ -+ // -+ // ThreadMap -+ // -+ -+ /// ThreadMap to be used by epilogue::PredicatedTileIterator satisfying concept -+ /// InterleavedOutputTileThreadMap -+ using Type = InterleavedOutputTileThreadMap< -+ layout::PitchLinearShape, -+ layout::PitchLinearShape, -+ Detail::kThreads, kElementsPerAccess, sizeof_bits::value>; -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines the optimal thread map for TensorOp accumulator layouts -+template -+struct DefaultInterleavedConvThreadMapTensorOp { -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ static int const kPartitionsK = PartitionsK; -+ using Element = Element_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static int const kInterleavedK = InterleavedK; -+ -+ // -+ // Definitions -+ // -+ -+ struct Detail { -+ /// Tensor Operations fundamentally perform operations on 8 rows -+ static int const kTensorOpRows = 8; -+ static int const kWarpSize = 32; -+ -+ static_assert(!(ThreadblockShape::kM % WarpShape::kM) && -+ !(ThreadblockShape::kN % WarpShape::kN), -+ "Divisibility"); -+ -+ /// Number of warps -+ using WarpCount = -+ gemm::GemmShape; -+ -+ /// Number of participating threads -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ }; -+ -+ // -+ // ThreadMap -+ // -+ -+ /// ThreadMap to be used by epilogue::MaskedTileIterator satisfying concept -+ /// InterleavedOutputTileThreadMap -+ using Type = InterleavedConvOutputTileThreadMap< -+ MatrixShape, -+ MatrixShape, -+ Detail::kThreads, kElementsPerAccess, sizeof_bits::value>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_volta_tensor_op.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_volta_tensor_op.h -new file mode 100644 -index 0000000..1c0edb1 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_volta_tensor_op.h -@@ -0,0 +1,228 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief -+ -+*/ -+ -+#pragma once -+ -+#include "predicated_tile_iterator.h" -+#include "cutlass/gemm/gemm.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines the optimal thread map for TensorOp accumulator layouts -+template < -+ typename ThreadblockShape, -+ typename WarpShape, -+ int PartitionsK, -+ typename ElementOutput, -+ int ElementsPerAccess, -+ typename ElementAccumulator -+> -+struct DefaultThreadMapVoltaTensorOp; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines the optimal thread map for TensorOp accumulator layouts -+template < -+ typename ThreadblockShape_, -+ typename WarpShape_, -+ int PartitionsK, -+ typename ElementOutput_, -+ int ElementsPerAccess -+> -+struct DefaultThreadMapVoltaTensorOp< -+ ThreadblockShape_, -+ WarpShape_, -+ PartitionsK, -+ ElementOutput_, -+ ElementsPerAccess, -+ half_t> { -+ -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ static int const kPartitionsK = PartitionsK; -+ using ElementOutput = ElementOutput_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ using ElementAccumulator = half_t; -+ -+ // -+ // Definitions -+ // -+ -+ struct Detail { -+ -+ static int const kTensorOpRows = 16; -+ static int const kWarpSize = 32; -+ static int const kInterleavedTilesM = WarpShape::kM / 32; -+ -+ static_assert( -+ !(ThreadblockShape::kM % WarpShape::kM) && -+ !(ThreadblockShape::kN % WarpShape::kN), "Divisibility"); -+ -+ /// Number of warps -+ using WarpCount = gemm::GemmShape< -+ ThreadblockShape::kM / WarpShape::kM, -+ ThreadblockShape::kN / WarpShape::kN, -+ kPartitionsK -+ >; -+ -+ /// Number of participating threads -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape< -+ ThreadblockShape::kN, // column -+ 4, // row -+ 4, // group -+ WarpCount::kM, // cluster -+ 1 // tile -+ >; -+ -+ /// Number of iterations per subspace -+ using Count = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 2, // row -+ kInterleavedTilesM, // group -+ 1, // cluster -+ WarpShape::kM / kTensorOpRows // iterations -+ >; -+ }; -+ -+ // -+ // ThreadMap -+ // -+ -+ /// ThreadMap to be used by epilogue::PredicatedTileIterator satisfying concept OutputTileThreadMap -+ using Type = OutputTileOptimalThreadMap < -+ typename Detail::Shape, -+ typename Detail::Count, -+ Detail::kThreads, -+ kElementsPerAccess, -+ sizeof_bits::value -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines the optimal thread map for TensorOp accumulator layouts -+template < -+ typename ThreadblockShape_, -+ typename WarpShape_, -+ int PartitionsK, -+ typename ElementOutput_, -+ int ElementsPerAccess -+> -+struct DefaultThreadMapVoltaTensorOp< -+ ThreadblockShape_, -+ WarpShape_, -+ PartitionsK, -+ ElementOutput_, -+ ElementsPerAccess, -+ float> { -+ -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ static int const kPartitionsK = PartitionsK; -+ using ElementOutput = ElementOutput_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ using ElementAccumulator = float; -+ -+ // -+ // Definitions -+ // -+ -+ struct Detail { -+ -+ static int const kTensorOpRows = 16; -+ static int const kWarpSize = 32; -+ static int const kInterleavedTilesM = WarpShape::kM / 32; -+ -+ static_assert( -+ !(ThreadblockShape::kM % WarpShape::kM) && -+ !(ThreadblockShape::kN % WarpShape::kN), "Divisibility"); -+ -+ /// Number of warps -+ using WarpCount = gemm::GemmShape< -+ ThreadblockShape::kM / WarpShape::kM, -+ ThreadblockShape::kN / WarpShape::kN, -+ kPartitionsK -+ >; -+ -+ /// Number of participating threads -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape< -+ ThreadblockShape::kN, // column -+ 4, // row -+ 4, // group -+ WarpCount::kM, // cluster -+ 1 // tile -+ >; -+ -+ /// Number of iterations per subspace -+ using Count = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 2, // row -+ kInterleavedTilesM, // group -+ 1, // cluster -+ WarpShape::kM / kTensorOpRows // iterations -+ >; -+ }; -+ -+ // -+ // ThreadMap -+ // -+ -+ /// ThreadMap to be used by epilogue::PredicatedTileIterator satisfying concept OutputTileThreadMap -+ using Type = OutputTileOptimalThreadMap < -+ typename Detail::Shape, -+ typename Detail::Count, -+ Detail::kThreads, -+ kElementsPerAccess, -+ sizeof_bits::value -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_wmma_tensor_op.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_wmma_tensor_op.h -new file mode 100644 -index 0000000..929762b ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_wmma_tensor_op.h -@@ -0,0 +1,113 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief -+ -+*/ -+ -+#pragma once -+ -+#include "predicated_tile_iterator.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/layout/pitch_linear.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines the optimal thread map for Wmma TensorOp accumulator layouts -+template < -+ typename ThreadblockShape_, -+ typename WarpShape_, -+ typename InstructionShape_, -+ int PartitionsK, -+ typename Element_, -+ int ElementsPerAccess -+> -+struct DefaultThreadMapWmmaTensorOp { -+ -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ static int const kPartitionsK = PartitionsK; -+ using Element = Element_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ // -+ // Definitions -+ // -+ -+ struct Detail { -+ -+ /// Wmma Tensor Operations fundamentally perform operations on InstructionShape::kM rows -+ static int const kTensorOpRows = InstructionShape::kM; -+ static int const kWarpSize = 32; -+ -+ static_assert( -+ !(ThreadblockShape::kM % WarpShape::kM) && -+ !(ThreadblockShape::kN % WarpShape::kN), "Divisibility"); -+ -+ /// Number of warps -+ using WarpCount = gemm::GemmShape< -+ ThreadblockShape::kM / WarpShape::kM, -+ ThreadblockShape::kN / WarpShape::kN, -+ kPartitionsK -+ >; -+ -+ /// Number of participating threads -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ }; -+ -+ // -+ // ThreadMap -+ // -+ -+ /// ThreadMap to be used by epilogue::PredicatedTileIterator satisfying concept OutputTileThreadMap -+ using Type = OutputTileOptimalThreadMap < -+ OutputTileShape, -+ OutputTileShape<1, WarpShape::kM / Detail::kTensorOpRows, 1, 1, WarpShape::kM / Detail::kTensorOpRows>, -+ Detail::kThreads, -+ kElementsPerAccess, -+ sizeof_bits::value -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/direct_store_epilogue_iterator.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/direct_store_epilogue_iterator.h -new file mode 100644 -index 0000000..afacca2 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/direct_store_epilogue_iterator.h -@@ -0,0 +1,142 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template -+class DirectStoreEpilogueIterator { -+public: -+ -+ using Element = Element_; -+ -+ using Layout = layout::RowMajor; -+ using TensorRef = TensorRef; -+ using ConstTensorRef = typename TensorRef::ConstTensorRef; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using TensorCoord = MatrixCoord; -+ -+ static int const kElementsPerAccess = 1; -+ -+ /// Uses a non-template class -+ struct Params : PredicatedTileIteratorParams { -+ using Base = PredicatedTileIteratorParams; -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) { -+ stride = layout.stride(0) * sizeof(Element); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Base const &base) : -+ Base(base) { } -+ }; -+ -+public: -+ -+ // -+ // Data members -+ // -+ -+ Element *pointer; // pointer to the output matrix -+ -+ LongIndex stride; // stride in elements between rows -+ -+ TensorCoord extent; // extent of output matrix -+ -+ int thread_idx; // thread index -+ -+ TensorCoord threadblock_offset; -+ -+public: -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ DirectStoreEpilogueIterator( -+ PredicatedTileIteratorParams const & params, -+ Element *pointer_, -+ TensorCoord extent_, -+ int thread_idx_, -+ TensorCoord threadblock_offset_ = TensorCoord(), -+ int const * indices = nullptr -+ ): -+ pointer(pointer_), -+ stride(params.stride / sizeof(Element)), -+ extent(extent_), -+ thread_idx(thread_idx_), -+ threadblock_offset(threadblock_offset_) -+ { -+ -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue.h -new file mode 100644 -index 0000000..7672a59 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue.h -@@ -0,0 +1,535 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+ The shared memory resource is time-sliced across warps. -+*/ -+ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/vector.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/tensor_coord.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/functional.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/regular_tile_iterator.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue_base.h" -+#include "cutlass/epilogue/threadblock/epilogue_base_streamk.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Epilogue operator -+template < -+ typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) -+ typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) -+ int PartitionsK, ///< Number of partitions of the K dimension -+ typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors -+ typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators -+ typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM -+ typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM -+ typename OutputOp_, ///< Output operator -+ typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape) -+ int FragmentsPerPartition = 1, ///< Used to coarsten the epilogue granularity -+ int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large -+ (!IsEpilogueFunctorHeavy::value) -+> -+class Epilogue : -+ public EpilogueBase< -+ Shape_, -+ typename WarpMmaOperator_::Shape, -+ PartitionsK, -+ AccumulatorFragmentIterator_, -+ WarpTileIterator_, -+ Padding_, -+ FragmentsPerPartition>, -+ public EpilogueBaseStreamK< -+ Shape_, -+ PartitionsK, -+ WarpMmaOperator_, -+ AccumulatorFragmentIterator_> -+{ -+ -+public: -+ -+ using Base = EpilogueBase< -+ Shape_, -+ typename WarpMmaOperator_::Shape, -+ PartitionsK, -+ AccumulatorFragmentIterator_, -+ WarpTileIterator_, -+ Padding_, -+ FragmentsPerPartition>; -+ -+ using BaseStreamK = EpilogueBaseStreamK< -+ Shape_, -+ PartitionsK, -+ WarpMmaOperator_, -+ AccumulatorFragmentIterator_>; -+ -+ using Shape = Shape_; -+ using WarpMmaOperator = WarpMmaOperator_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputTileIterator = OutputTileIterator_; -+ using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; -+ using WarpTileIterator = WarpTileIterator_; -+ using SharedLoadIterator = SharedLoadIterator_; -+ using OutputOp = OutputOp_; -+ using Padding = Padding_; -+ using Layout = layout::RowMajor; -+ using LongIndex = typename Layout::LongIndex; -+ -+ /// Number of warps per block -+ using WarpCount = typename Base::WarpCount; -+ -+ /// Number of threads per block -+ static int const kBlockThreads = 32 * WarpCount::kCount; -+ -+ /// Per-thread accumulator tile type -+ using AccumulatorTile = typename Base::AccumulatorTile; -+ -+ /// Numerical accumulation element type -+ using ElementAccumulator = typename WarpMmaOperator::ElementC; -+ -+ /// Fragment type used by the accumulator tile's fragment iterator -+ using AccumulatorFragment = typename AccumulatorFragmentIterator::Fragment; -+ -+ /// Output element -+ using ElementOutput = typename OutputTileIterator::Element; -+ -+ /// Output access size -+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; -+ -+ /// Tensor reference to destination tensor -+ using TensorRef = typename OutputTileIterator::TensorRef; -+ -+ /// Tensor reference to sync tensor -+ using SyncTensorRef = typename cutlass::TensorRef; -+ -+ /// Const tensor reference to source tensor -+ using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; -+ -+ /// Vector type used by the global output iterator -+ using OutputAccessType = Array< -+ typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>; -+ -+ /// Vector type used by the shared output iterator -+ using AccumulatorAccessType = Array; -+ -+ static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK; -+ -+ static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles; -+ -+ -+public: -+ -+ static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements, -+ "Mismatch between shared load iterator and output tile iterator."); -+ -+ static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero."); -+ -+ static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess), -+ "Divisibility"); -+ -+ static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, "One of these must be exactly 1."); -+ -+ -+public: -+ -+ /// Aspect for when epilogue source is not needed -+ struct SourceAspectNotNeeded -+ { -+ /// Constructor -+ CUTLASS_DEVICE -+ SourceAspectNotNeeded() -+ {} -+ -+ /// Invoke the output functor over each vector of output -+ CUTLASS_DEVICE -+ void apply_output_operator( -+ typename OutputTileIterator::Fragment &output_fragment, -+ OutputOp const &output_op, -+ typename SharedLoadIterator::Fragment const &aligned_accum_fragment) -+ { -+ OutputAccessType *output_frag_ptr = -+ reinterpret_cast(&output_fragment); -+ -+ AccumulatorAccessType const *compute_frag_ptr = -+ reinterpret_cast(&aligned_accum_fragment); -+ -+ int const kOutputOpIterations = -+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kOutputOpIterations; ++i) -+ { -+ // Call the output operator -+ output_frag_ptr[i] = output_op(compute_frag_ptr[i]); -+ } -+ } -+ }; -+ -+ -+ /// Aspect for when epilogue source is needed -+ struct SourceAspectNeeded -+ { -+ OutputTileIterator source_iterator; -+ -+ typename OutputTileIterator::Fragment source_fragment; -+ -+ /// Invoke the output functor over each vector of output -+ CUTLASS_DEVICE -+ static void apply_output_operator( -+ typename OutputTileIterator::Fragment &output_fragment, -+ OutputOp const &output_op, -+ typename SharedLoadIterator::Fragment const &aligned_accum_fragment, -+ typename OutputTileIterator::Fragment const &source_fragment) -+ { -+ OutputAccessType *output_frag_ptr = -+ reinterpret_cast(&output_fragment); -+ -+ AccumulatorAccessType const *compute_frag_ptr = -+ reinterpret_cast(&aligned_accum_fragment); -+ -+ OutputAccessType const *source_frag_ptr = -+ reinterpret_cast(&source_fragment); -+ -+ int const kOutputOpIterations = -+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kOutputOpIterations; ++i) -+ { -+ // Call the output operator -+ output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]); -+ } -+ } -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ SourceAspectNeeded(OutputTileIterator source_iterator) : -+ source_iterator(source_iterator) -+ { -+ source_fragment.clear(); -+ } -+ -+ /// Invoke the output functor over each vector of output -+ CUTLASS_DEVICE -+ void apply_output_operator( -+ typename OutputTileIterator::Fragment &output_fragment, -+ OutputOp const &output_op, -+ typename SharedLoadIterator::Fragment const &aligned_accum_fragment) -+ { -+ // Load addend source fragment from global memory -+ source_iterator.load(source_fragment); -+ ++source_iterator; -+ -+ apply_output_operator(output_fragment, output_op, aligned_accum_fragment, source_fragment); -+ } -+ }; -+ -+ -+private: -+ -+ /// Loads fragment from shared memory aligned with output tensor -+ SharedLoadIterator shared_load_iterator_; -+ -+ /// Thread index in the threadblock -+ int thread_idx; -+ -+ /// Warp index in the threadblock -+ int warp_idx; -+ -+public: -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ Epilogue( -+ typename Base::SharedStorage &shared_storage, ///< Shared storage object -+ int thread_idx, ///< ID of a thread within the threadblock -+ int warp_idx, ///< ID of warp within threadblock -+ int lane_idx) ///< Id of thread within warp -+ : -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ BaseStreamK(thread_idx), -+ shared_load_iterator_(shared_storage.reference(), thread_idx), -+ thread_idx(thread_idx), -+ warp_idx(warp_idx) -+ {} -+ -+ -+ /// Aggregates the accumulator sets shared by peer blocks in the global workspace, -+ /// performing epilogue computations, writing to output -+ CUTLASS_DEVICE -+ void reduce( -+ int peer_idx_begin, -+ int peer_idx_end, -+ int reduce_fragment_idx, -+ void *element_workspace, -+ OutputOp const &output_op, ///< Output operator -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ OutputTileIterator source_iterator) ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) -+ { -+ // Redcuce peer accumulator fragments into one fragment -+ AccumulatorFragment accum_fragment; -+ BaseStreamK::reduce(accum_fragment, peer_idx_begin, peer_idx_end, reduce_fragment_idx, element_workspace); -+ -+ // Store fragment to shared memory -+ this->warp_tile_iterator_.store(accum_fragment); -+ -+ __syncthreads(); -+ -+ // Initialize/load source-fragment data -+ typename OutputTileIterator::Fragment source_fragment; -+ source_fragment.clear(); -+ -+ if (output_op.is_source_needed()) -+ { -+ source_iterator += reduce_fragment_idx; -+ source_iterator.load(source_fragment); -+ } -+ -+ // Load fragment from shared memory -+ typename SharedLoadIterator::Fragment aligned_accum_fragment; -+ shared_load_iterator_.load(aligned_accum_fragment); -+ -+ // Add fragments shared by other k partitions -+ if (kPartitionsK > 1) -+ { -+ plus add_fragments; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for ( int i = 1; i < kPartitionsK; ++i) { -+ typename SharedLoadIterator::Fragment aligned_addend_fragment; -+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); -+ shared_load_iterator_.load(aligned_addend_fragment); -+ aligned_accum_fragment = add_fragments(aligned_accum_fragment, aligned_addend_fragment); -+ } -+ } -+ -+ // Compute the output result -+ typename OutputTileIterator::Fragment output_fragment; -+ -+ // Apply the output operator -+ SourceAspectNeeded::apply_output_operator( -+ output_fragment, -+ output_op, -+ aligned_accum_fragment, -+ source_fragment); -+ -+ // Store the final result -+ destination_iterator += reduce_fragment_idx; -+ destination_iterator.store(output_fragment); -+ } -+ -+ -+ /// Perform the epilogue computations and stream the result to global memory. -+ CUTLASS_DEVICE -+ void operator()( -+ OutputOp const &output_op, ///< Output operator -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators) ///< Complete warp-level accumulator tile -+ { -+ operator()(output_op, destination_iterator, accumulators, SourceAspectNotNeeded()); -+ } -+ -+ -+ /// Perform the epilogue computations and stream the result to global memory. Implements -+ /// two alternative codepaths, depending on whether the output op requires addend data to be loaded. -+ CUTLASS_DEVICE -+ void operator()( -+ OutputOp const &output_op, ///< Output operator -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ OutputTileIterator source_iterator ) ///< Tile iterator for addend source -+ { -+ if (output_op.is_source_needed()) -+ { -+ operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator)); -+ } -+ else -+ { -+ operator()(output_op, destination_iterator, accumulators, SourceAspectNotNeeded()); -+ } -+ } -+ -+ -+ /// Perform the epilogue computations and stream the result to global memory. Implements a -+ /// single codepath, regardless of whether the output op requires addend data to be loaded -+ CUTLASS_DEVICE -+ void unified( -+ OutputOp const &output_op, ///< Output operator -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ OutputTileIterator source_iterator ) ///< Tile iterator for addend source -+ { -+ if (!output_op.is_source_needed()) -+ { -+ source_iterator.clear_mask(); -+ __syncthreads(); // Dummy (CUDA 11.0) -+ } -+ -+ operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator)); -+ } -+ -+ -+ /// Streams the result to global memory -+ template -+ CUTLASS_DEVICE -+ void operator()( -+ OutputOp const &output_op, ///< Output operator -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ SourceAspect source) -+ { -+ // Iterator over warp-level accumulator fragment -+ AccumulatorFragmentIterator accum_fragment_iterator(accumulators); -+ -+ // -+ // Iterate over accumulator tile -+ // -+ -+ #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration : 1) -+ for (int iter = 0; iter < OutputTileIterator::kIterations; iter += Base::kFragmentsPerIteration) -+ { -+ -+ // -+ // Convert and store fragment -+ // -+ -+ __syncthreads(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int p = 0; p < Base::kFragmentsPerIteration; ++p) -+ { -+ typename AccumulatorFragmentIterator::Fragment accum_fragment; -+ -+ accum_fragment_iterator.load(accum_fragment); -+ ++accum_fragment_iterator; -+ -+ this->warp_tile_iterator_.store(accum_fragment); -+ -+ if (p < Base::kFragmentsPerIteration - 1) { -+ this->warp_tile_iterator_.add_pointer_offset(kSmemPointerOffset); -+ } -+ } -+ -+ if (Base::kFragmentsPerIteration > 1) { -+ this->warp_tile_iterator_.add_pointer_offset(kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); -+ } -+ -+ -+ // -+ // Load fragments from shared memory -+ // -+ -+ __syncthreads(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int p = 0; p < Base::kFragmentsPerIteration; ++p) -+ { -+ typename SharedLoadIterator::Fragment aligned_accum_fragment; -+ shared_load_iterator_.load(aligned_accum_fragment); -+ -+ if (p < Base::kFragmentsPerIteration - 1) -+ { -+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); -+ } -+ else if (kPartitionsK > 1) -+ { -+ plus add_fragments; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for ( int i = 1; i < kPartitionsK; ++i) { -+ typename SharedLoadIterator::Fragment aligned_accum_fragment_addend; -+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); -+ shared_load_iterator_.load(aligned_accum_fragment_addend); -+ aligned_accum_fragment = add_fragments(aligned_accum_fragment, aligned_accum_fragment_addend); -+ } -+ -+ shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset); -+ } -+ -+ // -+ // Compute the output result -+ // -+ -+ typename OutputTileIterator::Fragment output_fragment; -+ source.apply_output_operator(output_fragment, output_op, aligned_accum_fragment); -+ -+ // -+ // Store the final result -+ // -+ -+ destination_iterator.store(output_fragment); -+ ++destination_iterator; -+ } -+ -+ if (Base::kFragmentsPerIteration > 1) { -+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); -+ } -+ } -+ } -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_base.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_base.h -new file mode 100644 -index 0000000..cad06bb ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_base.h -@@ -0,0 +1,240 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#if !defined(__CUDACC_RTC__) -+#include -+#include -+#endif -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/vector.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/tensor_coord.h" -+#include "cutlass/aligned_buffer.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/transform/pitch_linear_thread_map.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+// -+// This is used for metaprogramming epilogue functors. If they define -+// `static bool const kIsHeavy = true;`, then the epilogue functor itself is -+// not inlined. This results in smaller code and is advantageous if the epilogue -+// functor consists of many instructions. -+// -+// If the epilogue functor does not define `kIsHeavy` or if it is `false`, then -+// the behavior from CUTLASS 2.5 and before is retained. The epilogue is fully -+// unrolled and inlined. -+// -+ -+template -+struct TypeSink { typedef void type; }; -+ -+template using TypeSinkT = typename TypeSink::type; -+ -+template struct IsEpilogueFunctorHeavy { -+ static bool const value = false; -+}; -+ -+template struct IsEpilogueFunctorHeavy > { -+ static bool const value = T::kIsHeavy; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Base class for epilogues defining warp-level -+template < -+ typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) -+ typename WarpShape_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) -+ int PartitionsK, ///< Number of partitions of the K dimension -+ typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators -+ typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM -+ typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape) -+ int FragmentsPerIteration = 1 -+> -+class EpilogueBase { -+public: -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ static int const kPartitionsK = PartitionsK; -+ using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; -+ using WarpTileIterator = WarpTileIterator_; -+ using Padding = Padding_; -+ -+ /// Output layout is always row-major -+ using Layout = layout::RowMajor; -+ -+ /// The complete warp-level accumulator tile -+ using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile; -+ -+ /// Accumulator element -+ using ElementAccumulator = typename AccumulatorTile::Element; -+ -+ /// Number of warps -+ using WarpCount = gemm::GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ kPartitionsK -+ >; -+ -+ /// Use this to control the granularity of one epilogue 'iteration' -+ static int const kFragmentsPerIteration = FragmentsPerIteration; -+ -+public: -+ -+ /// Shared storage allocation needed by the epilogue -+ struct SharedStorage { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Element type of shared memory -+ using Element = typename WarpTileIterator::Element; -+ -+ /// Tensor reference to shared memory allocation -+ using TensorRef = typename WarpTileIterator::TensorRef; -+ -+ /// Layout of shared memory allocation -+ using Layout = typename WarpTileIterator::Layout; -+ -+ /// Logical shape of the shared memory tile written to by all warps. -+ using Shape = MatrixShape< -+ WarpCount::kM * WarpTileIterator::Shape::kRow * WarpCount::kK, -+ WarpCount::kN * WarpTileIterator::Shape::kColumn -+ >; -+ -+ /// Shape of the shared memory allocation for the epilogue -+ using StorageShape = MatrixShape< -+ (Shape::kRow + Padding::kRow) * kFragmentsPerIteration, -+ Shape::kColumn + Padding::kColumn -+ >; -+ -+ // -+ // Data members -+ // -+ -+ AlignedBuffer storage; -+ -+ // -+ // Methods -+ // -+ -+ /// Returns a pointer to the shared memory buffer -+ CUTLASS_DEVICE -+ Element *data() { -+ return storage.data(); -+ } -+ -+ /// Returns a tensor reference to the shared memory buffer -+ CUTLASS_DEVICE -+ TensorRef reference() { -+ return TensorRef( -+ storage.data(), -+ Layout::packed({StorageShape::kRow, StorageShape::kColumn})); -+ } -+ }; -+ -+protected: -+ -+ // -+ // Data members -+ // -+ -+ SharedStorage &shared_storage_; -+ -+ /// Stores a warp's fragment of accumulators to SMEM -+ WarpTileIterator warp_tile_iterator_; -+ -+public: -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ EpilogueBase( -+ SharedStorage &shared_storage, ///< Shared storage object -+ int thread_idx, ///< ID of a thread within the threadblock -+ int warp_idx, ///< ID of warp within threadblock -+ int lane_idx ///< Id of thread within warp -+ ): -+ shared_storage_(shared_storage), -+ warp_tile_iterator_(shared_storage.reference(), lane_idx) { -+ -+ // Compute warp location within threadblock tile by mapping the warp_id to three coordinates: -+ // -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_k = warp_idx / (WarpCount::kM * WarpCount::kN); -+ int warp_mn = warp_idx % (WarpCount::kM * WarpCount::kN); -+ int warp_m = warp_mn % WarpCount::kM; -+ int warp_n = warp_mn / WarpCount::kM; -+ -+ MatrixCoord warp_offset{warp_k * WarpCount::kM + warp_m, warp_n}; -+ -+ warp_tile_iterator_.add_tile_offset(warp_offset); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_base_streamk.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_base_streamk.h -new file mode 100644 -index 0000000..2be1aeb ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_base_streamk.h -@@ -0,0 +1,197 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Basic subset of epilogue functionality for supporting StreamK decompositions -+*/ -+ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/functional.h" -+#include "cutlass/block_striped.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// StreamK epilogue functionality for cross-block accumulator fragment reduction -+template < -+ typename Shape, ///< Shape of threadblock tile (concept: GemmShape) -+ int PartitionsK, -+ typename WarpMmaOperator, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) -+ typename AccumulatorFragmentIterator> ///< Iterator for enumerating fragments within the per-thread tile of raw accumulators -+class EpilogueBaseStreamK -+{ -+ -+protected: -+ -+ /// The per-thread tile of raw accumulators -+ using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile; -+ -+ /// Number of warps -+ using WarpCount = gemm::GemmShape< -+ Shape::kM / WarpMmaOperator::Shape::kM, -+ Shape::kN / WarpMmaOperator::Shape::kN, -+ PartitionsK>; -+ -+ /// Number of threads per block -+ static int const kBlockThreads = 32 * WarpCount::kCount; -+ -+ /// Numerical accumulation element type -+ using ElementAccumulator = typename WarpMmaOperator::ElementC; -+ -+ /// Fragment type used by the accumulator tile's fragment iterator -+ using AccumulatorFragment = typename AccumulatorFragmentIterator::Fragment; -+ -+public: -+ -+ /// Number of AccumulatorTile fragments per thread -+ static int const kAccumulatorFragments = AccumulatorFragmentIterator::Policy::kIterations; -+ -+protected: -+ -+ /// Number of AccumulatorTile fragments per block output tile -+ static int const kOutputTileFragments = kBlockThreads * kAccumulatorFragments; -+ -+ /// Block-striped transfer utility for sharing AccumulatorFragment -+ using BlockStripedT = BlockStriped; -+ -+ /// AccumulatorFragment stride in the shared workspace between different peer blocks (each thread block can share accumulators for up to two block output tiles) -+ static const int kPeerFragmentStride = kOutputTileFragments * 2; -+ -+public: -+ -+ /// Workspace bytes per thread block -+ static size_t const kWorkspaceBytesPerBlock =sizeof(AccumulatorFragment) * kPeerFragmentStride; -+ -+public: -+ -+ /// Thread index in the threadblock -+ int thread_idx; -+ -+public: -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ EpilogueBaseStreamK( -+ int thread_idx) ///< ID of a thread within the threadblock -+ : -+ thread_idx(thread_idx) -+ {} -+ -+ -+ /// Aggregates the accumulator sets shared by peer blocks in the global workspace -+ CUTLASS_DEVICE -+ void reduce( -+ AccumulatorFragment &accum_fragment, ///< [out] sum of all shared accumulator fragments for these peer partials -+ int peer_idx_begin, -+ int peer_idx_end, -+ int reduce_fragment_idx, -+ void *workspace_ptr) -+ { -+ plus add_fragments; -+ -+ AccumulatorFragment *fragment_workspace = reinterpret_cast(workspace_ptr); -+ -+ int fragment_offset = (peer_idx_begin * kPeerFragmentStride) + (reduce_fragment_idx * kBlockThreads); -+ -+ // Load first peer fragment -+ BlockStripedT::load(accum_fragment, fragment_workspace + fragment_offset, this->thread_idx); -+ -+ fragment_offset += kPeerFragmentStride; // Move to next peer -+ fragment_offset += kOutputTileFragments; // Move to the set of fragments for this peer's "non-started" output tile -+ -+ // Reduce fragments from additional peers -+ #pragma unroll 2 -+ for (; fragment_offset < peer_idx_end * kPeerFragmentStride; fragment_offset += kPeerFragmentStride) -+ { -+ // Load peer fragment -+ AccumulatorFragment addend_fragment; -+ BlockStripedT::load(addend_fragment, fragment_workspace + fragment_offset, this->thread_idx); -+ -+ // Add peer fragment -+ accum_fragment = add_fragments(accum_fragment, addend_fragment); -+ } -+ } -+ -+ -+ /// Shares the accumulator set with peers in the global workspace -+ CUTLASS_DEVICE -+ void share( -+ int peer_idx, -+ void *workspace_ptr, -+ AccumulatorTile const &accumulators, -+ bool started_tile) ///< Whether this thread block computed the first work volume for the current output tile -+ { -+ AccumulatorFragment *fragment_workspace = reinterpret_cast(workspace_ptr); -+ -+ int fragment_offset = peer_idx * kPeerFragmentStride; -+ -+ if (!started_tile) { -+ // Move to the set of fragments for the "non-started" output tile -+ fragment_offset += kOutputTileFragments; -+ } -+ -+ AccumulatorFragmentIterator accum_fragment_iterator(accumulators); -+ -+ // Convert raw accumulator tile to fragments and store -+ CUTLASS_PRAGMA_UNROLL -+ for (int iter = 0; iter < kAccumulatorFragments; ++iter) -+ { -+ // Acquire reordered accumulator fragment -+ AccumulatorFragment accum_fragment; -+ accum_fragment_iterator.load(accum_fragment); -+ ++accum_fragment_iterator; -+ -+ // Store accumulator fragment -+ BlockStripedT::store(fragment_workspace + fragment_offset, accum_fragment, this->thread_idx); -+ -+ fragment_offset += kBlockThreads; -+ } -+ } -+ -+}; -+ -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_depthwise.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_depthwise.h -new file mode 100644 -index 0000000..d5a52ea ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_depthwise.h -@@ -0,0 +1,335 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for Depthwise convoltuion -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/epilogue/thread/conversion_op.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/epilogue/thread/reduction_op.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Epilogue operator -+template -+class EpilogueDepthwise { -+ public: -+ using Shape = Shape_; -+ using WarpShape = typename WarpMmaOperator_::Shape; -+ using ThreadOutputShape = ThreadOutputShape_; -+ using ThreadBlockOutputShape = ThreadBlockOutputShape_; -+ using WarpMmaOperator = WarpMmaOperator_; -+ using OutputTileIterator = OutputTileIterator_; -+ using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; -+ using WarpTileIterator = WarpTileIterator_; -+ using SharedLoadIterator = SharedLoadIterator_; -+ using OutputOp = OutputOp_; -+ using Padding = Padding_; -+ -+ using Layout = layout::RowMajor; -+ using LongIndex = typename Layout::LongIndex; -+ -+ /// The complete warp-level accumulator tile -+ using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile; -+ -+ /// Accumulator element -+ using ElementAccumulator = typename WarpTileIterator::Element; -+ -+ /// Output element -+ using ElementOutput = typename OutputTileIterator::Element; -+ -+ /// Output access size -+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; -+ -+ /// Tensor reference to destination tensor -+ using TensorRef = typename OutputTileIterator::TensorRef; -+ -+ /// Tensor reference to sync tensor -+ using SyncTensorRef = typename cutlass::TensorRef; -+ -+ /// Const tensor reference to source tensor -+ using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; -+ -+ /// Array type used to output -+ using OutputAccessType = -+ Array; -+ -+ /// Array type used by output functor -+ using AccumulatorAccessType = -+ Array; -+ -+ /// Number of warps -+ using WarpCount = -+ gemm::GemmShape; -+ -+ public: -+ static_assert(SharedLoadIterator::Fragment::kElements == -+ OutputTileIterator::Fragment::kElements, -+ "Mismatch between shared load iterator and output tile iterator."); -+ -+ static_assert(OutputTileIterator::kElementsPerAccess, -+ "OutputTileIterator::kElementsPerAccess must not be zero."); -+ -+ static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess), -+ "Divisibility"); -+ -+ /// Shared storage allocation needed by the epilogue -+ struct SharedStorage { -+ // -+ // Type definitions -+ // -+ -+ /// Element type of shared memory -+ using Element = typename WarpTileIterator::Element; -+ -+ /// Tensor reference to shared memory allocation -+ using TensorRef = typename WarpTileIterator::TensorRef; -+ -+ /// Layout of shared memory allocation -+ using Layout = typename WarpTileIterator::Layout; -+ -+ /// Logical shape of the shared memory tile written to by all warps. -+ using Shape = MatrixShape; -+ -+ /// Shape of the shared memory allocation for the epilogue -+ using StorageShape = MatrixShape; -+ -+ // -+ // Data members -+ // -+ -+ AlignedBuffer storage; -+ -+ // -+ // Methods -+ // -+ -+ /// Returns a pointer to the shared memory buffer -+ CUTLASS_DEVICE -+ Element *data() { return storage.data(); } -+ -+ /// Returns a tensor reference to the shared memory buffer -+ CUTLASS_DEVICE -+ TensorRef reference() { -+ return TensorRef(storage.data(), Layout::packed({StorageShape::kRow, StorageShape::kColumn})); -+ } -+ }; -+ -+ private: -+ /// Loads fragment from shared memory aligned with output tensor -+ SharedLoadIterator shared_load_iterator_; -+ -+ /// Stores a warp's fragment of accumulators to SMEM -+ WarpTileIterator warp_tile_iterator_; -+ -+ LongIndex warp_offset; -+ int thread_idx; -+ int warp_idx; -+ int lane_idx; -+ int warp_m, warp_n; // warp coordinates within a cta -+ int tid_m, tid_n; // thread coordinates within a warp -+ -+ public: -+ /// Constructor -+ CUTLASS_DEVICE -+ EpilogueDepthwise(SharedStorage &shared_storage, ///< Shared storage object -+ int thread_idx_, ///< ID of a thread within the threadblock -+ int warp_idx_, ///< ID of warp within threadblock -+ int lane_idx_ ///< Id of thread within warp -+ ) -+ : thread_idx(thread_idx_), -+ warp_idx(warp_idx_), -+ lane_idx(lane_idx_), -+ shared_load_iterator_(shared_storage.reference(), thread_idx_), -+ warp_tile_iterator_(shared_storage.reference(), thread_idx_, lane_idx_) {} -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void operator()(OutputOp const &output_op, ///< Output operator -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ OutputTileIterator source_iterator, ///< Threadblock tile coordinate in GEMM (in -+ ///< units of threadblock tiles) -+ const int smem_base_offset) { ///< SMEM base offset for epilogue operation -+ // initiate the smem base offset for different output tile. -+ warp_tile_iterator_.set_smem_base_address(smem_base_offset); -+ -+ shared_load_iterator_.set_smem_base_address(smem_base_offset); -+ -+ if (!output_op.is_source_needed()) { -+ compute_source_not_needed_(output_op, destination_iterator, accumulators); -+ } else { -+ compute_source_needed_(output_op, destination_iterator, accumulators, source_iterator); -+ } -+ } -+ -+ private: -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void compute_source_needed_( -+ OutputOp const &output_op, ///< Output operator -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) -+ -+ typename OutputTileIterator::Fragment source_fragment; -+ -+ source_fragment.clear(); -+ -+ source_iterator.load(source_fragment); -+ -+ // store to smem -+ warp_tile_iterator_.store(accumulators); -+ -+ __syncthreads(); -+ -+ typename SharedLoadIterator::Fragment aligned_accum_fragment; -+ -+ // load from smem -+ shared_load_iterator_.load(aligned_accum_fragment); -+ -+ typename OutputTileIterator::Fragment output_fragment; -+ -+ apply_output_operator_(output_fragment, output_op, aligned_accum_fragment, source_fragment); -+ -+ // Store to GMEM -+ destination_iterator.store(output_fragment); -+ } -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void compute_source_not_needed_( -+ OutputOp const &output_op, ///< Output operator -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) -+ -+ // store to smem -+ warp_tile_iterator_.store(accumulators); -+ -+ __syncthreads(); -+ -+ typename SharedLoadIterator::Fragment aligned_accum_fragment; -+ -+ // load from smem -+ shared_load_iterator_.load(aligned_accum_fragment); -+ -+ typename OutputTileIterator::Fragment output_fragment; -+ -+ apply_output_operator_source_not_needed_(output_fragment, output_op, aligned_accum_fragment); -+ -+ // Store to GMEM -+ destination_iterator.store(output_fragment); -+ } -+ -+ /// Helper to invoke the output functor over each vector of output -+ CUTLASS_DEVICE -+ void apply_output_operator_( -+ typename OutputTileIterator::Fragment &output_fragment, -+ OutputOp const &output_op, ///< Output operator -+ typename SharedLoadIterator::Fragment const &aligned_accum_fragment, -+ typename OutputTileIterator::Fragment const &source_fragment) { -+ -+ OutputAccessType *output_frag_ptr = -+ reinterpret_cast(&output_fragment); -+ -+ AccumulatorAccessType const *compute_frag_ptr = -+ reinterpret_cast(&aligned_accum_fragment); -+ -+ OutputAccessType const *source_frag_ptr = -+ reinterpret_cast(&source_fragment); -+ -+ int const kOutputOpIterations = -+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kOutputOpIterations; ++i) { -+ // Call the output operator -+ output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]); -+ } -+ } -+ -+ /// Helper to invoke the output functor over each vector of output -+ CUTLASS_DEVICE -+ void apply_output_operator_source_not_needed_( -+ typename OutputTileIterator::Fragment &output_fragment, -+ OutputOp const &output_op, ///< Output operator -+ typename SharedLoadIterator::Fragment const &aligned_accum_fragment) { -+ OutputAccessType *output_frag_ptr = reinterpret_cast(&output_fragment); -+ -+ AccumulatorAccessType const *compute_frag_ptr = -+ reinterpret_cast(&aligned_accum_fragment); -+ -+ int const kOutputOpIterations = -+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kOutputOpIterations; ++i) { -+ // Call the output operator -+ output_frag_ptr[i] = output_op(compute_frag_ptr[i]); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_direct_store.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_direct_store.h -new file mode 100644 -index 0000000..8cd4791 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_direct_store.h -@@ -0,0 +1,347 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs and convolution using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/epilogue/thread/conversion_op.h" -+#include "cutlass/epilogue/thread/reduction_op.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Epilogue operator -+template < -+ typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) -+ typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) -+ int PartitionsK, ///< Number of partitions of the K dimension -+ typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors -+ typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators -+ typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM -+ typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM -+ typename OutputOp_ ///< Output operator -+> -+class EpilogueDirectStore { -+public: -+ -+ using Shape = Shape_; -+ using WarpMmaOperator = WarpMmaOperator_; -+ using WarpShape = typename WarpMmaOperator_::Shape; -+ static int const kPartitionsK = PartitionsK; -+ using OutputTileIterator = OutputTileIterator_; -+ using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; -+ using WarpTileIterator = WarpTileIterator_; -+ using OutputOp = OutputOp_; -+ using Padding = MatrixShape<0, 0>; -+ -+ using Layout = layout::RowMajor; -+ using LongIndex = typename Layout::LongIndex; -+ -+ /// The complete warp-level accumulator tile -+ using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile; -+ -+ /// Accumulator element -+ using ElementAccumulator = typename WarpTileIterator::Element; -+ -+ /// Output element -+ using ElementOutput = typename OutputTileIterator::Element; -+ -+ /// Output access size -+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; -+ -+ /// Tensor reference to destination tensor -+ using TensorRef = typename OutputTileIterator::TensorRef; -+ -+ /// Tensor reference to sync tensor -+ using SyncTensorRef = typename cutlass::TensorRef; -+ -+ /// Const tensor reference to source tensor -+ using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; -+ -+ /// Array type used to output -+ using OutputAccessType = Array< -+ typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>; -+ -+ /// Array type used by output functor -+ using AccumulatorAccessType = Array; -+ -+ /// Number of warps -+ using WarpCount = gemm::GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ kPartitionsK -+ >; -+ -+ /// Use this to control the granularity of one epilogue 'iteration' -+ static int const kFragmentsPerIteration = 1; -+ -+ static int constexpr kSmemTiles = 1; -+ static int constexpr kSmemPointerOffset = 0; -+ -+ /// Shared storage allocation needed by the epilogue -+ struct SharedStorage { } ; -+ -+private: -+ -+ // Assume accumulator tile is multipile interleaved 32x32 tile. -+ static int const kElementsPerPartial = 4; -+ using EleShapePerPatial = typename platform::conditional< -+ platform::is_same::value, -+ MatrixShape<2, 2>, -+ MatrixShape<1, 4> >::type; -+ static int const kElementsPerMma = 8; -+ static int const kAccumulatorPatials = 2; -+ using QuadShapePerPatialMma = MatrixShape<4, 4>; -+ -+ static_assert(OutputOp::kCount >= 2, -+ "The direct store epilogue for Tensor Ops requires the output functor have kCount >= 2."); -+ -+private: -+ -+ LongIndex warp_offset; -+ int thread_idx; -+ int warp_idx; -+ int lane_idx; -+ int warp_m, warp_n; // warp coordinates within a cta -+ int tid_m, tid_n; // thread coordinates within a warp -+ -+public: -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ EpilogueDirectStore( -+ SharedStorage &shared_storage, ///< Shared storage object -+ int thread_idx_, ///< ID of a thread within the threadblock -+ int warp_idx_, ///< ID of warp within threadblock -+ int lane_idx_ ///< Id of thread within warp -+ ): -+ thread_idx(thread_idx_), -+ warp_idx(warp_idx_), -+ lane_idx(lane_idx_) -+ { -+ -+ // warp offsetting calculations -+ warp_offset = warp_idx * WarpShape::kM * WarpShape::kN; -+ int warp_id_mn = warp_idx % (WarpCount::kM * WarpShape::kN); -+ warp_m = warp_id_mn % WarpCount::kM; -+ warp_n = warp_id_mn / WarpCount::kM; -+ MatrixCoord warp_offset_coord(warp_m*WarpShape::kM, warp_n*WarpShape::kN); -+ -+ // thread offsetting calculations -+ int quad = (lane_idx >> 2); -+ int lane_in_quad = (lane_idx & 3); -+ -+ // this seems to be te correct layout -+ tid_m = quad; -+ tid_n = 2 * lane_in_quad; -+ } -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void operator()( -+ OutputOp const &output_op, ///< Output operator -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) -+ -+ if (!output_op.is_source_needed()) { -+ compute_source_not_needed_(output_op, destination_iterator, accumulators); -+ } -+ else { -+ compute_source_needed_(output_op, destination_iterator, accumulators, source_iterator); -+ } -+ } -+ -+private: -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void compute_source_needed_( -+ OutputOp const &output_op, ///< Output operator -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) -+ -+ const int kAccumBlockN = 2; -+ const int kThreadsM = 8; -+ const int kThreadsN = 4; -+ const int kBlockM = WarpShape::kM / kThreadsM; -+ -+ /// Array type used to output -+ using OutputAccessType = AlignedArray; -+ -+ /// Array type passed to the output operator - unused elements are optimized away -+ using OutputFragmentType = Array; -+ -+ /// Array type used by output functor -+ using AccumulatorAccessType = Array; -+ -+ /// Array type used by output functor -+ using AccumulatorFragmentType = Array; -+ -+ AccumulatorAccessType const *accumulator_pair = reinterpret_cast(&accumulators); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int accum_m_idx = 0; accum_m_idx < WarpShape::kM / kThreadsM; accum_m_idx++) { -+ -+ int accum_m = kThreadsM * accum_m_idx; -+ int mL = destination_iterator.threadblock_offset.row() + WarpShape::kM * warp_m + tid_m + accum_m; -+ int nL_base = destination_iterator.threadblock_offset.column() + WarpShape::kN * warp_n + tid_n; -+ -+ ElementOutput *output_ptr = destination_iterator.pointer + mL * destination_iterator.stride; -+ ElementOutput *source_ptr = source_iterator.pointer + mL * source_iterator.stride; -+ -+ int const kIterationsN = WarpShape::kN / kThreadsN / kAccumBlockN; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int accum_n_idx = 0; accum_n_idx < kIterationsN; accum_n_idx++) { -+ -+ int accum_idx = accum_m_idx + kBlockM * accum_n_idx; -+ int accum_n = kThreadsM * accum_n_idx; -+ -+ // mL and nL are logical coordinate in 2D mapping of epilogue's 4D output -+ int nL = nL_base + accum_n; -+ -+ bool guard = (mL < destination_iterator.extent.row()) && (nL < destination_iterator.extent.column()); -+ -+ AccumulatorFragmentType accum_fragment; -+ reinterpret_cast(accum_fragment) = accumulator_pair[accum_idx]; -+ -+ OutputFragmentType output_fragment; -+ -+ if(guard) { -+ reinterpret_cast(output_fragment) = -+ *reinterpret_cast(source_ptr + nL); -+ } -+ -+ // Perform output operator -+ output_fragment = output_op(accum_fragment, output_fragment); -+ -+ if(guard) { -+ // Store -+ *reinterpret_cast(output_ptr + nL) = reinterpret_cast(output_fragment); -+ } -+ } -+ } -+ } -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void compute_source_not_needed_( -+ OutputOp const &output_op, ///< Output operator -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) -+ -+ const int kAccumBlockN = 2; -+ const int kThreadsM = 8; -+ const int kThreadsN = 4; -+ const int kBlockM = WarpShape::kM / kThreadsM; -+ -+ /// Array type used to output -+ using OutputAccessType = AlignedArray; -+ -+ /// Array type passed to the output operator - unused elements are optimized away -+ using OutputFragmentType = Array; -+ -+ /// Array type used by output functor -+ using AccumulatorAccessType = Array; -+ -+ /// Array type used by output functor -+ using AccumulatorFragmentType = Array; -+ -+ AccumulatorAccessType const *accumulator_pair = reinterpret_cast(&accumulators); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int accum_m_idx = 0; accum_m_idx < WarpShape::kM / kThreadsM; accum_m_idx++) { -+ -+ int accum_m = kThreadsM * accum_m_idx; -+ int mL = destination_iterator.threadblock_offset.row() + WarpShape::kM * warp_m + tid_m + accum_m; -+ int nL_base = destination_iterator.threadblock_offset.column() + WarpShape::kN * warp_n + tid_n; -+ -+ ElementOutput *output_ptr = destination_iterator.pointer + mL * destination_iterator.stride; -+ -+ int const kIterationsN = WarpShape::kN / kThreadsN / kAccumBlockN; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int accum_n_idx = 0; accum_n_idx < kIterationsN; accum_n_idx++) { -+ -+ int accum_idx = accum_m_idx + kBlockM * accum_n_idx; -+ int accum_n = kThreadsM * accum_n_idx; -+ -+ // mL and nL are logical coordinate in 2D mapping of epilogue's 4D output -+ int nL = nL_base + accum_n; -+ -+ bool guard = (mL < destination_iterator.extent.row()) && (nL < destination_iterator.extent.column()); -+ -+ AccumulatorFragmentType accum_fragment; -+ reinterpret_cast(accum_fragment) = accumulator_pair[accum_idx]; -+ -+ OutputFragmentType output_fragment; -+ -+ // Perform output operator -+ output_fragment = output_op(accum_fragment); -+ -+ if(guard) { -+ -+ // Store -+ *reinterpret_cast(output_ptr + nL) = -+ reinterpret_cast(output_fragment); -+ } -+ } -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h -new file mode 100644 -index 0000000..927035b ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h -@@ -0,0 +1,212 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/vector.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/tensor_coord.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/functional.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/regular_tile_iterator.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue_base.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+#include "cutlass/numeric_types.h" -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Epilogue operator -+template < -+ typename ElementAccumulator_, -+ typename ElementOutput_, -+ typename ThreadBlockShape_, ///< Shape of threadblock tile (concept: GemmShape) -+ typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) -+ bool ReduceKForA_ -+> -+class EpilogueGemmKReduction { -+ -+public: -+ -+ using ThreadBlockShape = ThreadBlockShape_; -+ using WarpMmaOperator = WarpMmaOperator_; -+ using WarpShape = typename WarpMmaOperator::Shape; -+ using Layout = layout::RowMajor; -+ using LongIndex = typename Layout::LongIndex; -+ -+ /// Accumulator element -+ using ElementAccumulator = ElementAccumulator_; -+ -+ /// Output element -+ using ElementOutput = ElementOutput_; -+ -+ /// Output access size -+ static int const kElementsPerAccess = 1; -+ -+ static bool const kReduceKForA = ReduceKForA_; -+ -+ static int const kThreadBlockSize = kReduceKForA ? ThreadBlockShape::kM : ThreadBlockShape::kN; -+ -+ static int const kWarpSize = kReduceKForA ? WarpShape::kM : WarpShape::kN; -+ -+ static int const kIterations = kWarpSize / 8; -+ -+ using FragmentAccumulator = Array; -+ -+private: -+ -+ int thread_offset_; -+ ElementOutput* pointer_; -+ int col_; -+public: -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ EpilogueGemmKReduction( -+ int thread_idx, ///< ID of a thread within the threadblock -+ int warp_idx, ///< ID of warp within threadblock -+ int lane_idx, ///< Id of thread within warp -+ int threadblock_offset, -+ ElementOutput* pointer -+ ) -+ { -+ col_ = lane_idx % 4; -+ thread_offset_ = threadblock_offset * kThreadBlockSize -+ + warp_idx * kWarpSize -+ + lane_idx / 4 + col_ * 8; -+ -+ pointer_ = pointer + LongIndex(thread_offset_); -+ } -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void operator()( -+ int size, -+ FragmentAccumulator &gemm_k_with_reduction_accumulation, -+ bool LoadForSerialSplitK -+ ) { -+ bool guard[kIterations / 4]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kIterations / 4; ++i) { -+ guard[i] = ((thread_offset_ + i * 32) < size); -+ } -+ -+ Array source; -+ source.clear(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kIterations / 4; ++i) { -+ ElementOutput *source_ptr = reinterpret_cast(&source); -+ cutlass::arch::global_load( -+ source_ptr[i], -+ (void *)(pointer_ + i * 32), -+ guard[i] && LoadForSerialSplitK); -+ -+ } -+ -+ FragmentAccumulator sum = gemm_k_with_reduction_accumulation; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kIterations; ++i) { -+ sum[i] += __shfl_xor_sync(0xffffffff, sum[i], 1); -+ sum[i] += __shfl_xor_sync(0xffffffff, sum[i], 2); -+ } -+ -+ Array intermediate; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kIterations / 4; ++i) { -+ if (col_ == 0) { -+ intermediate[i] = sum[0 + i * 4]; -+ } -+ -+ if (col_ == 1) { -+ intermediate[i] = sum[1 + i * 4]; -+ } -+ -+ if (col_ == 2) { -+ intermediate[i] = sum[2 + i * 4]; -+ } -+ -+ if (col_ == 3) { -+ intermediate[i] = sum[3 + i * 4]; -+ } -+ } -+ -+ NumericArrayConverter source_converter; -+ Array converted_source = source_converter(source); -+ -+ plus> plus_source; -+ intermediate = plus_source(intermediate, converted_source); -+ -+ NumericArrayConverter converter; -+ Array result = converter(intermediate); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kIterations / 4; ++i) { -+ cutlass::arch::global_store(result[i], -+ (void *)(pointer_ + i * 32), guard[i]); -+ } -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_planar_complex.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_planar_complex.h -new file mode 100644 -index 0000000..1c70bed ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_planar_complex.h -@@ -0,0 +1,401 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/array_planar_complex.h" -+#include "cutlass/layout/vector.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/tensor_coord.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/functional.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/regular_tile_iterator.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue_base.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Epilogue operator for planar-complex output representations. -+/// -+/// Note, as with most CUTLASS components for planar complex, the template arguments describe -+/// the underlying real data type. -+template < -+ typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) -+ typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) -+ int PartitionsK, ///< Number of partitions of the K dimension -+ typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors -+ typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators -+ typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM -+ typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM -+ typename OutputOp_, ///< Output operator -+ typename Padding_ ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape) -+> -+class EpiloguePlanarComplex { -+public: -+ -+ using Shape = Shape_; -+ using WarpMmaOperator = WarpMmaOperator_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputTileIterator = OutputTileIterator_; -+ using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; -+ using WarpTileIterator = WarpTileIterator_; -+ using SharedLoadIterator = SharedLoadIterator_; -+ using OutputOp = OutputOp_; -+ using Padding = Padding_; -+ -+ /// Output layout is always row-major -+ using Layout = layout::RowMajor; -+ using LongIndex = typename Layout::LongIndex; -+ -+ /// The complete warp-level accumulator tile -+ using AccumulatorTile = ArrayPlanarComplex< -+ typename WarpMmaOperator::FragmentC::Element, -+ WarpMmaOperator::FragmentC::kElements -+ >; -+ -+ /// Accumulator element -+ using ElementAccumulator = typename WarpTileIterator::Element; -+ -+ /// Output element -+ using ElementOutput = typename OutputTileIterator::Element; -+ -+ /// Output access size -+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; -+ -+ /// Tensor reference to destination tensor -+ using TensorRef = typename OutputTileIterator::TensorRef; -+ -+ /// Tensor reference to sync tensor -+ using SyncTensorRef = typename cutlass::TensorRef; -+ -+ /// Const tensor reference to source tensor -+ using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; -+ -+ /// Array type used to output -+ using OutputAccessType = Array< -+ typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>; -+ -+ /// Array type used by output functor -+ using AccumulatorAccessType = Array; -+ -+ /// Shape of each warp-level operation -+ using WarpShape = typename WarpMmaOperator::Shape; -+ -+ /// Number of warps -+ using WarpCount = gemm::GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ kPartitionsK -+ >; -+ -+ /// Shared memory allocation -+ struct SharedStorage { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Element type of shared memory -+ using Element = typename WarpTileIterator::Element; -+ -+ /// Tensor reference to shared memory allocation -+ using TensorRef = typename WarpTileIterator::TensorRef; -+ -+ /// Layout of shared memory allocation -+ using Layout = typename WarpTileIterator::Layout; -+ -+ /// Logical shape of the shared memory tile written to by all warps. -+ using Shape = MatrixShape< -+ WarpCount::kM * WarpTileIterator::Shape::kRow * WarpCount::kK, -+ WarpCount::kN * WarpTileIterator::Shape::kColumn -+ >; -+ -+ /// Shape of the shared memory allocation for the epilogue -+ using StorageShape = MatrixShape< -+ Shape::kRow + Padding::kRow, -+ Shape::kColumn + Padding::kColumn -+ >; -+ -+ static int const kImaginaryStride = StorageShape::kCount; -+ -+ // -+ // Data members -+ // -+ -+ AlignedBuffer storage; -+ -+ // -+ // Methods -+ // -+ -+ /// Returns a pointer to the shared memory buffer -+ CUTLASS_DEVICE -+ Element *data() { -+ return storage.data(); -+ } -+ -+ /// Returns a tensor reference to the shared memory buffer -+ CUTLASS_DEVICE -+ TensorRef reference() { -+ return TensorRef( -+ storage.data(), -+ Layout::packed({StorageShape::kRow, StorageShape::kColumn})); -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ SharedStorage &shared_storage_; -+ -+ /// Loads fragment from shared memory aligned with output tensor -+ SharedLoadIterator shared_load_iterator_; -+ -+ /// Stores a warp's fragment of accumulators to SMEM -+ WarpTileIterator warp_tile_iterator_; -+ -+public: -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ EpiloguePlanarComplex( -+ SharedStorage &shared_storage, ///< Shared storage object -+ int thread_idx, ///< ID of a thread within the threadblock -+ int warp_idx, ///< ID of warp within threadblock -+ int lane_idx ///< Id of thread within warp -+ ): -+ shared_storage_(shared_storage), -+ shared_load_iterator_(shared_storage.reference(), thread_idx), -+ warp_tile_iterator_(shared_storage.reference(), lane_idx) { -+ -+ // Compute warp location within threadblock tile by mapping the warp_id to three coordinates: -+ // -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_k = warp_idx / (WarpCount::kM * WarpCount::kN); -+ int warp_mn = warp_idx % (WarpCount::kM * WarpCount::kN); -+ int warp_m = warp_mn % WarpCount::kM; -+ int warp_n = warp_mn / WarpCount::kM; -+ -+ MatrixCoord warp_offset{warp_k * WarpCount::kM + warp_m, warp_n}; -+ -+ warp_tile_iterator_.add_tile_offset(warp_offset); -+ } -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void operator()( -+ OutputOp const &output_op, ///< Output operator -+ OutputTileIterator destination_iterator_real, ///< Tile iterator for destination -+ OutputTileIterator destination_iterator_imag, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ OutputTileIterator source_iterator_real, ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) -+ OutputTileIterator source_iterator_imag) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) -+ -+ typename OutputTileIterator::Fragment source_fragment_real; -+ typename OutputTileIterator::Fragment source_fragment_imag; -+ -+ if (!output_op.is_source_needed()) { -+ source_iterator_real.clear_mask(); -+ source_iterator_imag.clear_mask(); -+ } -+ -+ source_fragment_real.clear(); -+ source_fragment_imag.clear(); -+ -+ // -+ // Iterator over warp-level accumulator fragment -+ // -+ -+ AccumulatorFragmentIterator accum_fragment_iterator_real(accumulators.real); -+ AccumulatorFragmentIterator accum_fragment_iterator_imag(accumulators.imag); -+ -+ // -+ // Iterate over accumulator tile -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { -+ -+ // -+ // Load the source -+ // -+ -+ source_iterator_real.load(source_fragment_real); -+ source_iterator_imag.load(source_fragment_imag); -+ -+ ++source_iterator_real; -+ ++source_iterator_imag; -+ -+ // -+ // Convert and store fragment -+ // -+ -+ __syncthreads(); -+ -+ typename AccumulatorFragmentIterator::Fragment accum_fragment_real; -+ typename AccumulatorFragmentIterator::Fragment accum_fragment_imag; -+ -+ accum_fragment_iterator_real.load(accum_fragment_real); -+ accum_fragment_iterator_imag.load(accum_fragment_imag); -+ -+ ++accum_fragment_iterator_real; -+ ++accum_fragment_iterator_imag; -+ -+ this->warp_tile_iterator_.store(accum_fragment_real); -+ this->warp_tile_iterator_.store_with_pointer_offset(accum_fragment_imag, SharedStorage::kImaginaryStride); -+ -+ __syncthreads(); -+ -+ // -+ // Load fragments from shared memory -+ // -+ -+ typename SharedLoadIterator::Fragment aligned_accum_fragment_real[kPartitionsK]; -+ typename SharedLoadIterator::Fragment aligned_accum_fragment_imag[kPartitionsK]; -+ -+ shared_load_iterator_.load(aligned_accum_fragment_real[0]); -+ shared_load_iterator_.load_with_pointer_offset(aligned_accum_fragment_imag[0], SharedStorage::kImaginaryStride); -+ -+ // If the number of k-slices is > 1 - perform a reduction amongst the k-slices -+ static_assert(kPartitionsK == 1, "Sliced-K not supported for planar complex at this time"); -+ -+ // -+ // Compute the output result -+ // -+ -+ typename OutputTileIterator::Fragment output_fragment_real; -+ typename OutputTileIterator::Fragment output_fragment_imag; -+ -+ apply_output_operator_( -+ output_fragment_real, -+ output_fragment_imag, -+ output_op, -+ aligned_accum_fragment_real[0], -+ aligned_accum_fragment_imag[0], -+ source_fragment_real, -+ source_fragment_imag); -+ -+ // -+ // Store the final result -+ // -+ -+ destination_iterator_real.store(output_fragment_real); -+ destination_iterator_imag.store(output_fragment_imag); -+ -+ ++destination_iterator_real; -+ ++destination_iterator_imag; -+ } -+ } -+ -+private: -+ -+ /// Helper to invoke the output functor over each vector of output -+ CUTLASS_DEVICE -+ void apply_output_operator_( -+ typename OutputTileIterator::Fragment &output_fragment_real, -+ typename OutputTileIterator::Fragment &output_fragment_imag, -+ OutputOp const &output_op, ///< Output operator -+ typename SharedLoadIterator::Fragment const &aligned_accum_fragment_real, -+ typename SharedLoadIterator::Fragment const &aligned_accum_fragment_imag, -+ typename OutputTileIterator::Fragment const &source_fragment_real, -+ typename OutputTileIterator::Fragment const &source_fragment_imag) { -+ -+ OutputAccessType *output_frag_real_ptr = -+ reinterpret_cast(&output_fragment_real); -+ -+ OutputAccessType *output_frag_imag_ptr = -+ reinterpret_cast(&output_fragment_imag); -+ -+ AccumulatorAccessType const *compute_frag_real_ptr = -+ reinterpret_cast(&aligned_accum_fragment_real); -+ -+ AccumulatorAccessType const *compute_frag_imag_ptr = -+ reinterpret_cast(&aligned_accum_fragment_imag); -+ -+ OutputAccessType const *source_frag_real_ptr = -+ reinterpret_cast(&source_fragment_real); -+ -+ OutputAccessType const *source_frag_imag_ptr = -+ reinterpret_cast(&source_fragment_imag); -+ -+ int const kOutputOpIterations = -+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kOutputOpIterations; ++i) { -+ -+ // Call the output operator -+ auto result_fragment = output_op( -+ make_ArrayPlanarComplex(compute_frag_real_ptr[i], compute_frag_imag_ptr[i]), -+ make_ArrayPlanarComplex(source_frag_real_ptr[i], source_frag_imag_ptr[i]) -+ ); -+ -+ output_frag_real_ptr[i] = result_fragment.real; -+ output_frag_imag_ptr[i] = result_fragment.imag; -+ } -+ } -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_smem_accumulator.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_smem_accumulator.h -new file mode 100644 -index 0000000..6dabe72 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_smem_accumulator.h -@@ -0,0 +1,230 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMM/CONV to store accumulator in shared memory after -+ applying scale, bias loaded from global memory and element-wise operations. -+ -+ This Epilogue is typically used in fused GEMM/CONV to stage the intermediate accumulator. -+ -+*/ -+ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/vector.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/tensor_coord.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/functional.h" -+ -+#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h" -+#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Epilogue operator -+template < -+ typename SmemTileIterator_, ///< Shared memory Tile iterator to output to shared memory -+ typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators -+ typename ScaleBiasIterator_, ///< Iterator to load scale and bias from global memory -+ typename OutputOp_ ///< Output operator -+> -+class EpilogueSmemAccumulator { -+ -+public: -+ -+ using SmemTileIterator = SmemTileIterator_; -+ -+ using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; -+ -+ using ScaleBiasIterator = ScaleBiasIterator_; -+ -+ using OutputOp = OutputOp_; -+ -+ /// Fragment of accumulator tile -+ using FragmentAccumulator = typename AccumulatorFragmentIterator::Fragment; -+ -+ /// The complete warp-level accumulator tile -+ using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile; -+ -+ /// Fragment of Scale and Bias loaded from global memory -+ using FragmentScaleBias = typename ScaleBiasIterator::Fragment; -+ -+ static const bool PerChannelScale = (OutputOp::kScale == -+ epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling); -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ EpilogueSmemAccumulator() {} -+ -+ /// Streams the result to shared memory -+ CUTLASS_DEVICE -+ void operator()( -+ OutputOp const &output_op, ///< Output operator -+ SmemTileIterator smem_iterator, ///< Tile iterator for destination in shared memory -+ AccumulatorTile const &accumulator, ///< Complete warp-level accumulator tile -+ ScaleBiasIterator scale_iterator, ///< iterator for scale vector in global memory -+ ScaleBiasIterator bias_iterator) { ///< iterator for bias vector in global memory -+ -+ -+ // Fragment to load scale bias from global memory -+ FragmentScaleBias tb_frag_scale; -+ FragmentScaleBias tb_frag_bias; -+ -+ /// Fragment Iterator to load slice of accumulator tile -+ AccumulatorFragmentIterator frag_iterator_accum(accumulator); -+ FragmentAccumulator tb_frag_accum; -+ -+ /// Epilogue output fragment -+ typename SmemTileIterator::Fragment tb_frag_smem; -+ -+ /// Load scale and bias from global memory -+ -+ if(PerChannelScale) -+ scale_iterator.load(tb_frag_scale); -+ -+ bias_iterator.load(tb_frag_bias); -+ -+ /// Iterate over the accumulator tile and store to shared memory -+ CUTLASS_PRAGMA_UNROLL -+ for (int rid = 0; rid < AccumulatorFragmentIterator::TileIterations::kRow; ++rid) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cid = 0; cid < AccumulatorFragmentIterator::TileIterations::kColumn; ++cid) { -+ -+ using AccumulatorAccessType = typename OutputOp::FragmentAccumulator; -+ using ScaleBiasAccessType = typename OutputOp::FragmentScaleBias; -+ using FragmentSmemAccessType = typename OutputOp::FragmentOutput; -+ -+ -+ ScaleBiasAccessType const * scale_frag_ptr = -+ reinterpret_cast(&tb_frag_scale); -+ ScaleBiasAccessType const * bias_frag_ptr = -+ reinterpret_cast(&tb_frag_bias); -+ -+ FragmentSmemAccessType * smem_frag_ptr = -+ reinterpret_cast(&tb_frag_smem); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int idx = 0; idx < AccumulatorFragmentIterator::kIterationsPerTile; ++idx) { -+ frag_iterator_accum.load(tb_frag_accum); -+ ++frag_iterator_accum; -+ -+ AccumulatorAccessType const * accumulator_frag_ptr = -+ reinterpret_cast(&tb_frag_accum); -+ const int kOutputIterations = FragmentAccumulator::kElements / OutputOp::kCount; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int it = 0; it < kOutputIterations; it++) { -+ smem_frag_ptr[idx * kOutputIterations + it] = output_op(accumulator_frag_ptr[it], -+ scale_frag_ptr[cid * kOutputIterations + it], bias_frag_ptr[cid * kOutputIterations + it]); -+ } -+ } -+ -+ smem_iterator.store(tb_frag_smem); -+ ++smem_iterator; -+ -+ } -+ } -+ } -+ -+ /// Streams the result to shared memory -+ CUTLASS_DEVICE -+ void operator()( -+ OutputOp const &output_op, ///< Output operator -+ SmemTileIterator smem_iterator, ///< Tile iterator for destination in shared memory -+ AccumulatorTile const &accumulator) { ///< Complete warp-level accumulator tile -+ -+ /// Fragment Iterator to load slice of accumulator tile -+ AccumulatorFragmentIterator frag_iterator_accum(accumulator); -+ FragmentAccumulator tb_frag_accum; -+ -+ /// Epilogue output fragment -+ typename SmemTileIterator::Fragment tb_frag_smem; -+ -+ /// Iterate over the accumulator tile and store to shared memory -+ CUTLASS_PRAGMA_UNROLL -+ for (int rid = 0; rid < AccumulatorFragmentIterator::TileIterations::kRow; ++rid) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cid = 0; cid < AccumulatorFragmentIterator::TileIterations::kColumn; ++cid) { -+ -+ using AccumulatorAccessType = typename OutputOp::FragmentAccumulator; -+ using FragmentSmemAccessType = typename OutputOp::FragmentOutput; -+ -+ FragmentSmemAccessType * smem_frag_ptr = -+ reinterpret_cast(&tb_frag_smem); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int idx = 0; idx < AccumulatorFragmentIterator::kIterationsPerTile; ++idx) { -+ frag_iterator_accum.load(tb_frag_accum); -+ ++frag_iterator_accum; -+ -+ AccumulatorAccessType const * accumulator_frag_ptr = -+ reinterpret_cast(&tb_frag_accum); -+ const int kOutputIterations = FragmentAccumulator::kElements / OutputOp::kCount; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int it = 0; it < kOutputIterations; it++) { -+ smem_frag_ptr[idx * kOutputIterations + it] = output_op(accumulator_frag_ptr[it]); -+ } -+ } -+ -+ smem_iterator.store(tb_frag_smem); -+ ++smem_iterator; -+ -+ } -+ } -+ } -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h -new file mode 100644 -index 0000000..de70352 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h -@@ -0,0 +1,513 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue visitor for threadblock scoped GEMMs that process softmax computations in epilogue. -+ -+ The epilogue finds max values in each row of the row-major output matrix and stores them. -+ The max values are also used for a further round of threadblock scoped reduction operation, where -+ the partial reduction results are stored in a pre-allocated array and used for further full reduction. -+ -+*/ -+ -+#pragma once -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/fast_math.h" -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+template < -+ typename ThreadblockShape_, -+ int ThreadCount, -+ typename OutputTileIterator_, -+ typename ElementAccumulator_, -+ typename ElementNorm_, -+ typename ElementSum_, -+ typename ElementSoftmaxCompute_, -+ typename ElementwiseFunctor_, -+ bool UseMasking_ = false -+> -+class EpilogueVisitorSoftmax { -+public: -+ -+ using ThreadblockShape = ThreadblockShape_; -+ static int const kThreadCount = ThreadCount; -+ -+ using OutputTileIterator = OutputTileIterator_; -+ using ElementwiseFunctor = ElementwiseFunctor_; -+ -+ static int const kIterations = OutputTileIterator::kIterations; -+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; -+ -+ using ElementOutput = typename OutputTileIterator::Element; -+ using LayoutOutput = cutlass::layout::RowMajor; -+ using ElementAccumulator = ElementAccumulator_; -+ -+ using ElementNorm = ElementNorm_; -+ using ElementSum = ElementSum_; -+ using ElementSoftmaxCompute = ElementSoftmaxCompute_; -+ -+ using AccumulatorFragment = Array; -+ using SoftmaxFragment = Array; -+ using OutputVector = Array; -+ using TensorRefD = TensorRef; -+ -+ static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth; -+ static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1); -+ static bool const kUseMasking = UseMasking_; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ typename ElementwiseFunctor::Params elementwise; -+ int64_t batch_stride_C; -+ int64_t batch_stride_D; -+ int64_t batch_stride_Max; -+ int64_t batch_stride_Sum; -+ -+ // -+ // Methods -+ // -+ Arguments(): -+ batch_stride_C(0), -+ batch_stride_D(0), -+ batch_stride_Max(0), -+ batch_stride_Sum(0) -+ { -+ -+ } -+ -+ Arguments( -+ typename ElementwiseFunctor::Params elementwise_ -+ ): -+ elementwise(elementwise_), -+ batch_stride_C(0), -+ batch_stride_D(0), -+ batch_stride_Max(0), -+ batch_stride_Sum(0) -+ { -+ -+ } -+ -+ Arguments( -+ typename ElementwiseFunctor::Params elementwise_, -+ int64_t batch_stride_C_, -+ int64_t batch_stride_D_, -+ int64_t batch_stride_Max_, -+ int64_t batch_stride_Sum_ -+ ): -+ elementwise(elementwise_), -+ batch_stride_C(batch_stride_C_), -+ batch_stride_D(batch_stride_D_), -+ batch_stride_Max(batch_stride_Max_), -+ batch_stride_Sum(batch_stride_Sum_) -+ { -+ -+ } -+ -+ }; -+ -+ struct Params { -+ -+ typename ElementwiseFunctor::Params elementwise; -+ int64_t batch_stride_C; -+ int64_t batch_stride_D; -+ int64_t batch_stride_Max; -+ int64_t batch_stride_Sum; -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Params() -+ { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args): -+ elementwise(args.elementwise), -+ batch_stride_C(args.batch_stride_C), -+ batch_stride_D(args.batch_stride_D), -+ batch_stride_Max(args.batch_stride_Max), -+ batch_stride_Sum(args.batch_stride_Sum) -+ { -+ -+ } -+ }; -+ -+ /// Shared storage -+ struct SharedStorage { -+ -+ }; -+ -+private: -+ -+ Params const & params_; -+ SharedStorage & shared_storage_; -+ MatrixCoord extent_; -+ MatrixCoord extent_real_; -+ ElementwiseFunctor elementwise_; -+ -+ OutputTileIterator iterator_C_; -+ OutputTileIterator iterator_D_; -+ typename OutputTileIterator::Fragment fragment_C_; -+ typename OutputTileIterator::Fragment fragment_D_; -+ -+ ElementAccumulator alpha_; -+ ElementAccumulator beta_; -+ -+ ElementNorm *ptr_Max_; -+ ElementSum *ptr_Sum_; -+ -+ int column_offset_; -+ -+ ElementSoftmaxCompute accum_max_; -+ ElementSoftmaxCompute accum_sum_; -+ -+ MatrixCoord thread_offset_; -+ -+ float infinity_; -+ -+public: -+ -+ CUTLASS_DEVICE -+ EpilogueVisitorSoftmax( -+ Params const ¶ms, -+ SharedStorage &shared_storage, -+ cutlass::MatrixCoord const &problem_size, -+ int thread_idx, -+ int warp_idx, -+ int lane_idx, -+ typename OutputTileIterator::Params params_C, -+ typename OutputTileIterator::Params params_D, -+ typename OutputTileIterator::Element *ptr_C, -+ typename OutputTileIterator::Element *ptr_D, -+ ElementNorm *ptr_Max = nullptr, -+ ElementSum *ptr_Sum = nullptr, -+ cutlass::MatrixCoord const &threadblock_offset = cutlass::MatrixCoord(0, 0), -+ int column_offset = 0, -+ cutlass::MatrixCoord const &problem_size_real = cutlass::MatrixCoord(0, 0), -+ float infinity = 10000.0f -+ ): -+ params_(params), -+ shared_storage_(shared_storage), -+ extent_(problem_size), -+ elementwise_(params.elementwise), -+ iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset), -+ iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset), -+ ptr_Max_(ptr_Max), -+ ptr_Sum_(ptr_Sum), -+ column_offset_(column_offset), -+ extent_real_(problem_size_real), -+ infinity_(infinity) -+ { -+ alpha_ = (params.elementwise.alpha_ptr ? *params.elementwise.alpha_ptr : params.elementwise.alpha); -+ beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta); -+ -+ if (beta_ == ElementAccumulator()) { -+ iterator_C_.clear_mask(); -+ } -+ } -+ -+ /// Helper to indicate split-K behavior -+ CUTLASS_DEVICE -+ void set_k_partition( -+ int split_k_index, ///< Index of this threadblock within split-K partitioned scheme -+ int split_k_slices) { ///< Total number of split-K slices -+ -+ } -+ -+ /// Called to set the batch index -+ CUTLASS_DEVICE -+ void set_batch_index(int batch_idx) { -+ iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C); -+ iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D); -+ } -+ -+ /// Called at the start of the epilogue just before iterating over accumulator slices -+ CUTLASS_DEVICE -+ void begin_epilogue() { -+ -+ } -+ -+ /// Called at the start of one step before starting accumulator exchange -+ CUTLASS_DEVICE -+ void begin_step(int step_idx) { -+ fragment_D_.clear(); -+ fragment_C_.clear(); -+ -+ if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { -+ iterator_C_.load(fragment_C_); -+ ++iterator_C_; -+ } -+ -+ } -+ -+ /// Called at the start of a row -+ CUTLASS_DEVICE -+ void begin_row(int row_idx) { -+ // Clear accumulators for max and sum when starting a whole row -+ clear_accum_(); -+ -+ } -+ -+ /// Called after accumulators have been exchanged for each accumulator vector -+ CUTLASS_DEVICE -+ void visit( -+ int iter_idx, -+ int row_idx, -+ int column_idx, -+ int frag_idx, -+ AccumulatorFragment const &accum) { -+ -+ using Mul = cutlass::multiplies; -+ using Minus = cutlass::minus; -+ using Exp = cutlass::fast_exp_op; -+ -+ Minus minus; -+ Exp exponential; -+ -+ SoftmaxFragment result; -+ -+ NumericArrayConverter source_converter; -+ OutputVector &source_vector = reinterpret_cast(&fragment_C_)[frag_idx]; -+ -+ if (elementwise_.kScale == cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { -+ result = source_converter(elementwise_(accum)); -+ }else{ -+ result = source_converter(elementwise_(accum, source_vector)); -+ } -+ -+ thread_offset_ = -+ iterator_D_.thread_start() + -+ OutputTileIterator::ThreadMap::iteration_offset(frag_idx); -+ -+ bool column_guard = (thread_offset_.column() < extent_.column()); -+ -+ if (kUseMasking) { -+ int elements_in_boundary = extent_real_.column() - thread_offset_.column(); -+ elements_in_boundary = (elements_in_boundary > kElementsPerAccess) ? kElementsPerAccess : elements_in_boundary; -+ elementwise_padding_(result, elements_in_boundary); -+ } -+ -+ ElementSoftmaxCompute accum_max_prev = accum_max_; -+ -+ // Compute the maximum within one row -+ if (!column_idx) { -+ // This is the first fragment in a new row -+ if (column_guard) { -+ accum_max_ = maximum_accumulator_(result); -+ } -+ } -+ else { -+ // This is an additional fragment in the same row -+ if (column_guard) { -+ accum_max_ = maximum_accumulator_(result, accum_max_); -+ } -+ } -+ -+ // proactively compute max in warps -+ accum_max_ = warp_reduce_max_(accum_max_); -+ -+ ElementSoftmaxCompute updater = fast_exp(accum_max_prev - accum_max_); -+ -+ SoftmaxFragment intermediate = exponential(minus(result, accum_max_)); -+ -+ if (kHasMultiStepsInRow) { -+ if (!column_idx) { -+ accum_sum_ = (column_guard) ? \ -+ sum_accumulator_(intermediate) : ElementSoftmaxCompute(0); -+ } else { -+ // Algorithm in $3.1, https://arxiv.org/pdf/2205.14135v1.pdf -+ // S* = S* x updater + sum_row(P'), where updater = exp(M* - M_row) -+ accum_sum_ = (column_guard) ? \ -+ sum_accumulator_(intermediate, accum_sum_ * updater) : accum_sum_ * updater; -+ } -+ } else { -+ accum_sum_ = (column_guard) ? sum_accumulator_(intermediate, accum_sum_) : ElementSoftmaxCompute(0); -+ } -+ -+ // Convert to the output -+ NumericArrayConverter output_converter; -+ OutputVector &output = reinterpret_cast(&fragment_D_)[frag_idx]; -+ output = output_converter(result); -+ } -+ -+ /// Called at the end of a row -+ CUTLASS_DEVICE -+ void end_row(int row_idx) { -+ -+ using ConvertSumOutput = cutlass::NumericConverter; -+ using ConvertNormOutput = cutlass::NumericConverter; -+ -+ ConvertSumOutput convert_sum_output; -+ ConvertNormOutput convert_norm_output; -+ -+ // Compute accumulate sum only in the last step -+ accum_sum_ = warp_reduce_sum_(accum_sum_); -+ -+ bool is_first_thread_in_tile = ((threadIdx.x % kThreadsPerRow) == 0); -+ bool row_guard = thread_offset_.row() < extent_.row(); -+ bool is_write_thread = row_guard && is_first_thread_in_tile; -+ -+ int block_batch = blockIdx.z; -+ -+ ElementNorm *curr_ptr_max = ptr_Max_ + thread_offset_.row() + column_offset_ + block_batch * params_.batch_stride_Max; -+ ElementSum *curr_ptr_sum = ptr_Sum_ + thread_offset_.row() + column_offset_ + block_batch * params_.batch_stride_Sum; -+ -+ arch::global_store( -+ convert_norm_output(accum_max_), -+ (void *)curr_ptr_max, -+ is_write_thread); -+ -+ arch::global_store( -+ convert_sum_output(accum_sum_), -+ (void *)curr_ptr_sum, -+ is_write_thread); -+ -+ // Clear accumulators for max and sum when finishing a whole row -+ clear_accum_(); -+ -+ } -+ -+ /// Called after all accumulator elements have been visited -+ CUTLASS_DEVICE -+ void end_step(int step_idx) { -+ -+ iterator_D_.store(fragment_D_); -+ ++iterator_D_; -+ } -+ -+ /// Called after all steps have been completed -+ CUTLASS_DEVICE -+ void end_epilogue() { -+ -+ } -+ -+private: -+ -+ CUTLASS_DEVICE -+ void elementwise_padding_(SoftmaxFragment &result, int elements_in_boundary) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < SoftmaxFragment::kElements; ++i) { -+ result[i] = (i < elements_in_boundary) ? result[i] : ElementSoftmaxCompute(-infinity_); -+ } -+ } -+ -+ CUTLASS_DEVICE -+ ElementSoftmaxCompute warp_reduce_sum_(ElementSoftmaxCompute sum_) { -+ int half_thread_in_row = (kThreadsPerRow >> 1); -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = half_thread_in_row; i > 0; i >>= 1) { -+ ElementSoftmaxCompute tmp = __shfl_xor_sync(0xFFFFFFFF, sum_, i); -+ sum_ += tmp; -+ } -+ return sum_; -+ } -+ -+ CUTLASS_DEVICE -+ ElementSoftmaxCompute warp_reduce_max_(ElementSoftmaxCompute max_) { -+ int half_thread_in_row = (kThreadsPerRow >> 1); -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = half_thread_in_row; i > 0; i >>= 1) { -+ ElementSoftmaxCompute tmp = __shfl_xor_sync(0xFFFFFFFF, max_, i); -+ max_ = fast_max(max_, tmp); -+ } -+ return max_; -+ } -+ -+ CUTLASS_DEVICE -+ void clear_accum_() { -+ -+ uint32_t float_max_bits = 0xff7fffff; // -FLT_MAX -+ float min_float = reinterpret_cast(float_max_bits); -+ accum_max_ = ElementSoftmaxCompute(min_float); -+ accum_sum_ = ElementSoftmaxCompute(0); -+ } -+ -+ CUTLASS_DEVICE -+ ElementSoftmaxCompute sum_accumulator_(SoftmaxFragment const &accum) { -+ ElementSoftmaxCompute sum_ = ElementSoftmaxCompute(0); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < SoftmaxFragment::kElements; ++i) { -+ sum_ += ElementSoftmaxCompute(accum[i]); -+ } -+ -+ return sum_; -+ } -+ -+ CUTLASS_DEVICE -+ ElementSoftmaxCompute sum_accumulator_(SoftmaxFragment const &accum, ElementSoftmaxCompute sum_) { -+ // ElementSoftmaxCompute sum_ = ElementSoftmaxCompute(0); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < SoftmaxFragment::kElements; ++i) { -+ sum_ += ElementSoftmaxCompute(accum[i]); -+ } -+ -+ return sum_; -+ } -+ -+ CUTLASS_DEVICE -+ ElementSoftmaxCompute maximum_accumulator_(SoftmaxFragment const &accum) { -+ ElementSoftmaxCompute max_ = accum[0]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 1; i < SoftmaxFragment::kElements; ++i) { -+ max_ = fast_max(max_, ElementSoftmaxCompute(accum[i])); -+ } -+ -+ return max_; -+ } -+ -+ CUTLASS_DEVICE -+ ElementSoftmaxCompute maximum_accumulator_(SoftmaxFragment const &accum, ElementSoftmaxCompute max_) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < SoftmaxFragment::kElements; ++i) { -+ max_ = fast_max(max_, ElementSoftmaxCompute(accum[i])); -+ } -+ -+ return max_; -+ } -+}; -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h -new file mode 100644 -index 0000000..9c9f716 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h -@@ -0,0 +1,1540 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#include -+#else -+#include -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/tensor_coord.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/functional.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/layout/vector.h" -+#include "cutlass/layout/tensor.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/regular_tile_iterator.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue_base.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+ -+#include "cutlass/numeric_types.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// This base class is meant to define the concept required of the -+/// EpilogueWithBroadcast::OutputOp -+template < -+ typename ElementC_, -+ typename ElementAccumulator_, -+ typename ElementCompute_, -+ typename ElementZ_, -+ typename ElementT_, -+ int ElementsPerAccess, -+ bool StoreZ = true, -+ bool StoreT = true -+> -+struct EpilogueWithBroadcastOpBase { -+ -+ using ElementOutput = ElementC_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementCompute_; -+ using ElementZ = ElementZ_; -+ using ElementT = ElementT_; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ using FragmentAccumulator = Array; -+ using FragmentCompute = Array; -+ using FragmentC = Array; -+ using FragmentZ = Array; -+ using FragmentT = Array; -+ -+ /// If true, the 'Z' tensor is stored -+ static bool const kStoreZ = StoreZ; -+ -+ /// If true, the 'T' tensor is stored -+ static bool const kStoreT = StoreT; -+ -+ /// Parameters structure - required -+ struct Params { }; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructor from Params -+ EpilogueWithBroadcastOpBase(Params const ¶ms_) { } -+ -+ /// Determine if the source is needed. May return false if -+ bool is_source_needed() const { -+ return true; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_k_partition(int k_partition, int k_partition_count) { } -+ -+ /// Applies the operation when is_source_needed() is true -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentZ &frag_Z, -+ FragmentT &frag_T, -+ FragmentAccumulator const &AB, -+ FragmentC const &frag_C1, -+ FragmentC const &frag_C2, -+ FragmentCompute const &V) const { -+ -+ } -+ -+ /// Applies the operation when is_source_needed() is false -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentZ &frag_Z, -+ FragmentT &frag_T, -+ FragmentAccumulator const &AB, -+ FragmentCompute const &V) const { -+ -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Epilogue operator with bias vector broadcast over columns. -+/// -+/// Computes the following: -+/// -+/// -+/// Z, T = OutputOp(AB, C, Broadcast) -+/// -+/// if (ElementwiseOp::kStoreZ) { -+/// store(converted_u); -+/// } -+/// -+/// if (ElementwiseOp::kStoreT) { -+/// store(v); -+/// } -+/// -+template < -+ typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) -+ typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) -+ int PartitionsK, ///< Number of partitions of the K dimension -+ typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors (z) -+ typename TensorTileIterator_, ///< Additional tile iterator for tensor-valued operands (t) -+ typename ElementVector_, ///< Pointer to broadcast vector -+ typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators -+ typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM -+ typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM -+ typename OutputOp_, ///< Output operator - concept is EpilogueWithBroadcastOp -+ typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape) -+ int FragmentsPerPartition = 1, ///< Used to coarsten the epilogue granularity -+ int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large -+ (!IsEpilogueFunctorHeavy::value), -+ bool IsSingleSource = OutputOp_::kIsSingleSource -+> -+class EpilogueWithBroadcast; -+ -+template < -+ typename Shape_, -+ typename WarpMmaOperator_, -+ int PartitionsK, -+ typename OutputTileIterator_, -+ typename TensorTileIterator_, -+ typename ElementVector_, -+ typename AccumulatorFragmentIterator_, -+ typename WarpTileIterator_, -+ typename SharedLoadIterator_, -+ typename OutputOp_, -+ typename Padding_, -+ int FragmentsPerPartition, -+ int IterationsUnroll -+> -+class EpilogueWithBroadcast< -+ Shape_, -+ WarpMmaOperator_, -+ PartitionsK, -+ OutputTileIterator_, -+ TensorTileIterator_, -+ ElementVector_, -+ AccumulatorFragmentIterator_, -+ WarpTileIterator_, -+ SharedLoadIterator_, -+ OutputOp_, -+ Padding_, -+ FragmentsPerPartition, -+ IterationsUnroll, -+ false -+> : -+ public EpilogueBase< -+ Shape_, -+ typename WarpMmaOperator_::Shape, -+ PartitionsK, -+ AccumulatorFragmentIterator_, -+ WarpTileIterator_, -+ Padding_, -+ FragmentsPerPartition> { -+ -+public: -+ -+ using Base = EpilogueBase< -+ Shape_, -+ typename WarpMmaOperator_::Shape, -+ PartitionsK, -+ AccumulatorFragmentIterator_, -+ WarpTileIterator_, -+ Padding_, -+ FragmentsPerPartition>; -+ -+ static bool const kIsSingleSource = false; -+ using Shape = Shape_; -+ using WarpMmaOperator = WarpMmaOperator_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputTileIterator = OutputTileIterator_; -+ using TensorTileIterator = TensorTileIterator_; -+ using ElementVector = ElementVector_; -+ using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; -+ using WarpTileIterator = WarpTileIterator_; -+ using SharedLoadIterator = SharedLoadIterator_; -+ using OutputOp = OutputOp_; -+ using Padding = Padding_; -+ -+ using Layout = layout::RowMajor; -+ using LongIndex = typename Layout::LongIndex; -+ -+ /// The complete warp-level accumulator tile -+ using AccumulatorTile = typename Base::AccumulatorTile; -+ -+ /// Accumulator element -+ using ElementAccumulator = typename WarpTileIterator::Element; -+ -+ /// Compute data type produced by the output op -+ using ElementCompute = typename OutputOp::ElementCompute; -+ -+ /// Compute fragment -+ using FragmentCompute = Array; -+ -+ /// Thread map used by output tile iterators -+ using ThreadMap = typename OutputTileIterator::ThreadMap; -+ -+ /// Fragment object used to store the broadcast values -+ using BroadcastFragment = Array< -+ ElementCompute, -+ ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess>; -+ -+ /// Output element -+ using ElementOutput = typename OutputTileIterator::Element; -+ -+ /// Data type of additional tensor -+ using ElementTensor = typename TensorTileIterator::Element; -+ -+ /// Output access size -+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; -+ -+ /// Tensor reference to destination tensor -+ using TensorRef = typename OutputTileIterator::TensorRef; -+ -+ /// Tensor reference to sync tensor -+ using SyncTensorRef = typename cutlass::TensorRef; -+ -+ /// Const tensor reference to source tensor -+ using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; -+ -+ /// Array type used to output -+ using OutputAccessType = Array< -+ typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>; -+ -+ /// Array type used by output functor -+ using AccumulatorAccessType = Array; -+ -+ /// Array type used by output functor -+ using ComputeAccessType = Array; -+ -+ /// Tensor access type -+ using TensorAccessType = Array; -+ -+ /// Number of warps -+ using WarpCount = typename Base::WarpCount; -+ -+ /// Shared memory allocation from epilogue base class -+ using BaseSharedStorage = typename Base::SharedStorage; -+ -+ static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK; -+ static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles; -+ -+ /// Used for the broadcast -+ struct BroadcastDetail { -+ -+ /// Number of threads per warp -+ static int const kWarpSize = 32; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ -+ /// Number of distinct scalar column indices handled by each thread -+ static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess; -+ -+ /// Number of distinct scalar row indices handled by each thread -+ static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn; -+ -+ /// Number of threads per threadblock -+ static int const kThreadCount = kWarpSize * WarpCount::kCount; -+ -+ /// Number of distinct threads per row of output tile -+ static int const kThreadsPerRow = (Shape::kN / kColumnsPerThread); -+ -+ /// Number of distinct threads which must be reduced during the final reduction phase within the threadblock. -+ static int const kThreadRows = kThreadCount / kThreadsPerRow; -+ -+ /// I'm not sure what I meant here. -+ static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount); -+ -+ /// Shape of the shared memory allocation for the epilogue -+ using StorageShape = MatrixShape< -+ kThreadRows, -+ Shape::kN -+ >; -+ -+ /// Debug printing -+ CUTLASS_DEVICE -+ static void print() { -+#if 0 -+ printf("BroadcastDetail {\n"); -+ printf( -+ " kColumnsPerThread: %d\nkRowsPerThread: %d\n,kThreadCount: %d\nkThreadsPerRow: %d\n" -+ "kThreadRows: %d\nThreadAccessesPerRow: %d\nStorageShape: %d x %d (count: %d)\n", -+ kColumnsPerThread, -+ kRowsPerThread, -+ kThreadCount, -+ kThreadsPerRow, -+ kThreadRows, -+ kThreadAccessesPerRow, -+ StorageShape::kRow, -+ StorageShape::kColumn, -+ StorageShape::kCount -+ ); -+ printf("};\n"); -+#endif -+ } -+ }; -+ -+ /// Shared storage structure (shadows base) with additional SMEM buffer for reduction -+ struct SharedStorage { -+ union { -+ BaseSharedStorage base; -+ }; -+ -+ CUTLASS_HOST_DEVICE -+ SharedStorage() { } -+ }; -+ -+public: -+ -+ -+ static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements, -+ "Mismatch between shared load iterator and output tile iterator."); -+ -+ static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero."); -+ -+ static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess), -+ "Divisibility"); -+ -+private: -+ -+ /// Loads fragment from shared memory aligned with output tensor -+ SharedLoadIterator shared_load_iterator_; -+ -+ /// Thread index within the threadblock -+ int thread_idx_; -+ -+public: -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ EpilogueWithBroadcast( -+ SharedStorage &shared_storage, ///< Shared storage object -+ int thread_idx, ///< ID of a thread within the threadblock -+ int warp_idx, ///< ID of warp within threadblock -+ int lane_idx ///< Id of thread within warp -+ ): -+ Base(shared_storage.base, thread_idx, warp_idx, lane_idx), -+ shared_load_iterator_(shared_storage.base.reference(), thread_idx), -+ thread_idx_(thread_idx) -+ { -+ -+ } -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void operator()( -+ OutputOp const &output_op, ///< Output operator -+ ElementVector const * broadcast_ptr, ///< Broadcast vector -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ OutputTileIterator source_iterator1, ///< Tile iterator for first source accumulator matrix -+ OutputTileIterator source_iterator2, ///< Tile iterator for second source accumulator matrix -+ TensorTileIterator tensor_iterator, ///< Threadblock tile iterator for additional tensor operand -+ MatrixCoord const &problem_size = ///< Problem size needed to guard against out-of-bounds accesses -+ MatrixCoord(Shape::kM, Shape::kN), -+ MatrixCoord const &threadblock_offset = ///< Threadblock's initial offset within the problem size space -+ MatrixCoord()) { -+ -+ BroadcastFragment broadcast_fragment; -+ -+ load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset); -+ -+ if (!output_op.is_source_needed()) { -+ compute_source_not_needed_( -+ output_op, -+ broadcast_fragment, -+ destination_iterator, -+ accumulators, -+ tensor_iterator); -+ } -+ else { -+ compute_source_needed_( -+ output_op, -+ broadcast_fragment, -+ destination_iterator, -+ accumulators, -+ source_iterator1, -+ source_iterator2, -+ tensor_iterator); -+ } -+ } -+ -+private: -+ -+ CUTLASS_DEVICE -+ void load_broadcast_fragment_( -+ BroadcastFragment & broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns -+ ElementVector const * broadcast_ptr, ///< Broadcast vector -+ MatrixCoord const &problem_size, ///< Problem size needed to guard against out-of-bounds accesses -+ MatrixCoord const &threadblock_offset ///< Threadblock's initial offset within the problem size space -+ ) { -+ -+ broadcast_fragment.clear(); -+ -+ // If no pointer is supplied, set with all zeros and avoid memory accesses -+ if (!broadcast_ptr) { -+ return; -+ } -+ -+ int thread_initial_column = ThreadMap::initial_offset(thread_idx_).column(); -+ -+ int thread_column_idx = threadblock_offset.column() + thread_initial_column; -+ broadcast_ptr += thread_initial_column; -+ -+ NumericArrayConverter converter; -+ using AccessType = AlignedArray; -+ using ComputeFragmentType = Array; -+ -+ ComputeFragmentType *frag_ptr = reinterpret_cast(&broadcast_fragment); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < ThreadMap::Iterations::kColumn; ++j) { -+ -+ AccessType loaded; -+ -+ loaded.clear(); -+ -+ if (thread_column_idx < problem_size.column()) { -+ loaded = *reinterpret_cast(broadcast_ptr); -+ } -+ -+ ComputeFragmentType cvt = converter(loaded); -+ frag_ptr[j] = cvt; -+ -+ thread_column_idx += ThreadMap::Delta::kColumn; -+ broadcast_ptr += ThreadMap::Delta::kColumn; -+ } -+ } -+ -+ template -+ struct acc2smem_source_not_needed; -+ -+ template -+ struct acc2smem_source_not_needed> { -+ template -+ CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator, -+ WarpTileIterator &warp_tile_iterator) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Advance; i++) { -+ ++accum_fragment_iterator; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { -+ typename AccumulatorFragmentIterator::Fragment accum_fragment; -+ -+ accum_fragment_iterator.load(accum_fragment); -+ ++accum_fragment_iterator; -+ -+ warp_tile_iterator.store(accum_fragment); -+ if (p < Base::kFragmentsPerIteration - 1) { -+ warp_tile_iterator.add_pointer_offset(kSmemPointerOffset); -+ } -+ } -+ -+ if (Base::kFragmentsPerIteration > 1) { -+ warp_tile_iterator.add_pointer_offset(kSmemPointerOffset * -+ (1 - Base::kFragmentsPerIteration)); -+ } -+ } -+ -+ CUTLASS_DEVICE -+ static void push(size_t pos, -+ AccumulatorFragmentIterator const &iterator_begin, -+ WarpTileIterator &warp_tile_iterator) { -+ int dummy[] = { -+ (pos == (Seq * Base::kFragmentsPerIteration)) && -+ (helper(iterator_begin, warp_tile_iterator), 0)...}; -+ -+ CUTLASS_UNUSED(dummy[0]); -+ } -+ }; -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void compute_source_not_needed_( -+ OutputOp const &output_op, ///< Output operator -+ BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additioanl tensor operand -+ ) { -+ -+ // -+ // Iterator over warp-level accumulator fragment -+ // -+ -+ AccumulatorFragmentIterator accum_fragment_iterator(accumulators); -+ -+ // -+ // Iterate over accumulator tile -+ // -+ -+ // CUTLASS_PRAGMA_UNROLL -+ #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration : 1) -+ for (int iter = 0; iter < OutputTileIterator::kIterations; iter += Base::kFragmentsPerIteration) { -+ -+ // -+ // Convert and store fragment -+ // -+ -+ -+ __syncthreads(); -+ -+ acc2smem_source_not_needed< -+ cutlass::make_index_sequence>::push(iter, -+ accum_fragment_iterator, -+ this->warp_tile_iterator_); -+ -+ __syncthreads(); -+ -+ // -+ // Load fragments from shared memory -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { -+ -+ -+ typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; -+ -+ shared_load_iterator_.load(aligned_accum_fragment[0]); -+ -+ if (p < Base::kFragmentsPerIteration - 1) { -+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); -+ } -+ else if (kPartitionsK > 1) { -+ -+ plus add_fragments; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for ( int i = 1; i < kPartitionsK; ++i) { -+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); -+ shared_load_iterator_.load(aligned_accum_fragment[i]); -+ aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]); -+ } -+ -+ shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset); -+ } -+ -+ // -+ // Apply output operation -+ // -+ -+ typename OutputTileIterator::Fragment frag_Z; -+ typename TensorTileIterator::Fragment frag_T; -+ -+ apply_output_operator_source_not_needed_( -+ frag_Z, -+ frag_T, -+ output_op, -+ aligned_accum_fragment[0], -+ broadcast_fragment); -+ -+ // -+ // Conditionally store fragments -+ // -+ -+ if (OutputOp::kStoreZ) { -+ destination_iterator.store(frag_Z); -+ ++destination_iterator; -+ } -+ -+ if (OutputOp::kStoreT) { -+ tensor_iterator.store(frag_T); -+ ++tensor_iterator; -+ } -+ } -+ -+ if (Base::kFragmentsPerIteration > 1) { -+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); -+ } -+ } -+ } -+ -+ -+ template -+ struct acc2smem_source_needed; -+ -+ template -+ struct acc2smem_source_needed> { -+ template -+ CUTLASS_DEVICE -+ static void helper(AccumulatorFragmentIterator accum_fragment_iterator, -+ WarpTileIterator &warp_tile_iterator) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Advance; i++) { -+ ++accum_fragment_iterator; -+ } -+ -+ typename AccumulatorFragmentIterator::Fragment accum_fragment; -+ accum_fragment_iterator.load(accum_fragment); -+ warp_tile_iterator.store(accum_fragment); -+ } -+ -+ CUTLASS_DEVICE -+ static void push(size_t pos, -+ AccumulatorFragmentIterator const &iterator_begin, -+ WarpTileIterator &warp_tile_iterator) { -+ int dummy[] = {(pos == Seq) && (helper(iterator_begin, warp_tile_iterator), 0)...}; -+ } -+ }; -+ -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void compute_source_needed_( -+ OutputOp const &output_op, ///< Output operator -+ BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ OutputTileIterator source_iterator1, ///< Tile iterator for first source accumulator matrix -+ OutputTileIterator source_iterator2, ///< Tile iterator for second source accumulator matrix -+ TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additioanl tensor operand -+ ) { -+ -+ typename OutputTileIterator::Fragment source_fragment1; -+ source_fragment1.clear(); -+ typename OutputTileIterator::Fragment source_fragment2; -+ source_fragment2.clear(); -+ -+ // -+ // Iterator over warp-level accumulator fragment -+ // -+ -+ AccumulatorFragmentIterator accum_fragment_iterator(accumulators); -+ -+ // -+ // Iterate over accumulator tile -+ // -+ -+ #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1) -+ for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { -+ -+ // -+ // Load the source -+ // -+ -+ source_iterator1.load(source_fragment1); -+ ++source_iterator1; -+ -+ source_iterator2.load(source_fragment2); -+ ++source_iterator2; -+ -+ // -+ // Convert and store fragment -+ // -+ -+ __syncthreads(); -+ -+ acc2smem_source_needed>::push( -+ iter, accum_fragment_iterator, this->warp_tile_iterator_); -+ -+ __syncthreads(); -+ -+ // -+ // Load fragments from shared memory -+ // -+ -+ typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; -+ -+ shared_load_iterator_.load(aligned_accum_fragment[0]); -+ -+ // If the number of k-slices is > 1 - perform a reduction amongst the k-slices -+ if (kPartitionsK > 1) -+ { -+ plus add_fragments; -+ const int tile_row_offset = Base::SharedStorage::StorageShape::kRow / PartitionsK; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for ( int i = 1; i < kPartitionsK; ++i) { -+ shared_load_iterator_.add_tile_offset({tile_row_offset , 0}); -+ shared_load_iterator_.load(aligned_accum_fragment[i]); -+ aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]); -+ } -+ -+ shared_load_iterator_.add_tile_offset({-1 * (kPartitionsK-1) * tile_row_offset, 0}); -+ } -+ -+ // -+ // Apply output operation -+ // -+ -+ typename OutputTileIterator::Fragment frag_Z; -+ typename TensorTileIterator::Fragment frag_T; -+ -+ apply_output_operator_( -+ frag_Z, -+ frag_T, -+ output_op, -+ aligned_accum_fragment[0], -+ source_fragment1, -+ source_fragment2, -+ broadcast_fragment); -+ -+ // -+ // Conditionally store fragments -+ // -+ -+ if (OutputOp::kStoreZ) { -+ destination_iterator.store(frag_Z); -+ ++destination_iterator; -+ } -+ -+ if (OutputOp::kStoreT) { -+ tensor_iterator.store(frag_T); -+ ++tensor_iterator; -+ } -+ } -+ } -+ -+ /// Helper to invoke the output functor over each vector of output -+ CUTLASS_DEVICE -+ void apply_output_operator_( -+ typename OutputTileIterator::Fragment &frag_Z, -+ typename TensorTileIterator::Fragment &frag_T, -+ OutputOp const &output_op, -+ typename SharedLoadIterator::Fragment const &frag_AB, -+ typename OutputTileIterator::Fragment const &frag_C1, -+ typename OutputTileIterator::Fragment const &frag_C2, -+ BroadcastFragment const &frag_Broadcast) { -+ -+ using AccessTypeZ = Array; -+ using AccessTypeT = Array; -+ using AccessTypeBroadcast = Array; -+ -+ AccessTypeZ *frag_Z_ptr = reinterpret_cast(&frag_Z); -+ AccessTypeT *frag_T_ptr = reinterpret_cast(&frag_T); -+ -+ AccumulatorAccessType const *frag_AB_ptr = -+ reinterpret_cast(&frag_AB); -+ -+ OutputAccessType const *frag_C1_ptr = -+ reinterpret_cast(&frag_C1); -+ -+ OutputAccessType const *frag_C2_ptr = -+ reinterpret_cast(&frag_C2); -+ -+ AccessTypeBroadcast const *frag_Broadcast_ptr = -+ reinterpret_cast(&frag_Broadcast); -+ -+ int const kOutputOpIterations = -+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kOutputOpIterations; ++i) { -+ output_op( -+ frag_Z_ptr[i], -+ frag_T_ptr[i], -+ frag_AB_ptr[i], -+ frag_C1_ptr[i], -+ frag_C2_ptr[i], -+ frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]); -+ } -+ } -+ -+ /// Helper to invoke the output functor over each vector of output -+ CUTLASS_DEVICE -+ void apply_output_operator_source_not_needed_( -+ typename OutputTileIterator::Fragment &frag_Z, -+ typename TensorTileIterator::Fragment &frag_T, -+ OutputOp const &output_op, -+ typename SharedLoadIterator::Fragment const &frag_AB, -+ BroadcastFragment const &frag_Broadcast) { -+ -+ using AccessTypeZ = Array; -+ using AccessTypeT = Array; -+ using AccessTypeBroadcast = Array; -+ -+ AccessTypeZ *frag_Z_ptr = reinterpret_cast(&frag_Z); -+ AccessTypeT *frag_T_ptr = reinterpret_cast(&frag_T); -+ -+ AccumulatorAccessType const *frag_AB_ptr = -+ reinterpret_cast(&frag_AB); -+ -+ AccessTypeBroadcast const *frag_Broadcast_ptr = -+ reinterpret_cast(&frag_Broadcast); -+ -+ int const kOutputOpIterations = -+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kOutputOpIterations; ++i) { -+ -+ output_op( -+ frag_Z_ptr[i], -+ frag_T_ptr[i], -+ frag_AB_ptr[i], -+ frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]); -+ } -+ } -+}; -+ -+ -+template < -+ typename Shape_, -+ typename WarpMmaOperator_, -+ int PartitionsK, -+ typename OutputTileIterator_, -+ typename TensorTileIterator_, -+ typename ElementVector_, -+ typename AccumulatorFragmentIterator_, -+ typename WarpTileIterator_, -+ typename SharedLoadIterator_, -+ typename OutputOp_, -+ typename Padding_, -+ int FragmentsPerPartition, -+ int IterationsUnroll -+> -+class EpilogueWithBroadcast< -+ Shape_, -+ WarpMmaOperator_, -+ PartitionsK, -+ OutputTileIterator_, -+ TensorTileIterator_, -+ ElementVector_, -+ AccumulatorFragmentIterator_, -+ WarpTileIterator_, -+ SharedLoadIterator_, -+ OutputOp_, -+ Padding_, -+ FragmentsPerPartition, -+ IterationsUnroll, -+ true -+> : -+ public EpilogueBase< -+ Shape_, -+ typename WarpMmaOperator_::Shape, -+ PartitionsK, -+ AccumulatorFragmentIterator_, -+ WarpTileIterator_, -+ Padding_, -+ FragmentsPerPartition> { -+ -+public: -+ -+ using Base = EpilogueBase< -+ Shape_, -+ typename WarpMmaOperator_::Shape, -+ PartitionsK, -+ AccumulatorFragmentIterator_, -+ WarpTileIterator_, -+ Padding_, -+ FragmentsPerPartition>; -+ -+ static bool const kIsSingleSource = true; -+ using Shape = Shape_; -+ using WarpMmaOperator = WarpMmaOperator_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputTileIterator = OutputTileIterator_; -+ using TensorTileIterator = TensorTileIterator_; -+ using ElementVector = ElementVector_; -+ using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; -+ using WarpTileIterator = WarpTileIterator_; -+ using SharedLoadIterator = SharedLoadIterator_; -+ using OutputOp = OutputOp_; -+ using Padding = Padding_; -+ -+ using Layout = layout::RowMajor; -+ using LongIndex = typename Layout::LongIndex; -+ -+ /// The complete warp-level accumulator tile -+ using AccumulatorTile = typename Base::AccumulatorTile; -+ -+ /// Accumulator element -+ using ElementAccumulator = typename WarpTileIterator::Element; -+ -+ /// Compute data type produced by the output op -+ using ElementCompute = typename OutputOp::ElementCompute; -+ -+ /// Compute fragment -+ using FragmentCompute = Array; -+ -+ /// Thread map used by output tile iterators -+ using ThreadMap = typename OutputTileIterator::ThreadMap; -+ -+ /// Fragment object used to store the broadcast values -+ using BroadcastFragment = Array< -+ ElementCompute, -+ ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess>; -+ -+ /// Output element -+ using ElementOutput = typename OutputTileIterator::Element; -+ -+ /// Data type of additional tensor -+ using ElementTensor = typename TensorTileIterator::Element; -+ -+ /// Output access size -+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; -+ -+ /// Tensor reference to destination tensor -+ using TensorRef = typename OutputTileIterator::TensorRef; -+ -+ /// Tensor reference to sync tensor -+ using SyncTensorRef = typename cutlass::TensorRef; -+ -+ /// Const tensor reference to source tensor -+ using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; -+ -+ /// Array type used to output -+ using OutputAccessType = Array< -+ typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>; -+ -+ /// Array type used by output functor -+ using AccumulatorAccessType = Array; -+ -+ /// Array type used by output functor -+ using ComputeAccessType = Array; -+ -+ /// Tensor access type -+ using TensorAccessType = Array; -+ -+ /// Number of warps -+ using WarpCount = typename Base::WarpCount; -+ -+ /// Shared memory allocation from epilogue base class -+ using BaseSharedStorage = typename Base::SharedStorage; -+ -+ static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK; -+ static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles; -+ -+ /// Used for the broadcast -+ struct BroadcastDetail { -+ -+ /// Number of threads per warp -+ static int const kWarpSize = 32; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ -+ /// Number of distinct scalar column indices handled by each thread -+ static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess; -+ -+ /// Number of distinct scalar row indices handled by each thread -+ static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn; -+ -+ /// Number of threads per threadblock -+ static int const kThreadCount = kWarpSize * WarpCount::kCount; -+ -+ /// Number of distinct threads per row of output tile -+ static int const kThreadsPerRow = (Shape::kN / kColumnsPerThread); -+ -+ /// Number of distinct threads which must be reduced during the final reduction phase within the threadblock. -+ static int const kThreadRows = kThreadCount / kThreadsPerRow; -+ -+ /// I'm not sure what I meant here. -+ static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount); -+ -+ /// Shape of the shared memory allocation for the epilogue -+ using StorageShape = MatrixShape< -+ kThreadRows, -+ Shape::kN -+ >; -+ -+ /// Debug printing -+ CUTLASS_DEVICE -+ static void print() { -+#if 0 -+ printf("BroadcastDetail {\n"); -+ printf( -+ " kColumnsPerThread: %d\nkRowsPerThread: %d\n,kThreadCount: %d\nkThreadsPerRow: %d\n" -+ "kThreadRows: %d\nThreadAccessesPerRow: %d\nStorageShape: %d x %d (count: %d)\n", -+ kColumnsPerThread, -+ kRowsPerThread, -+ kThreadCount, -+ kThreadsPerRow, -+ kThreadRows, -+ kThreadAccessesPerRow, -+ StorageShape::kRow, -+ StorageShape::kColumn, -+ StorageShape::kCount -+ ); -+ printf("};\n"); -+#endif -+ } -+ }; -+ -+ /// Shared storage structure (shadows base) with additional SMEM buffer for reduction -+ struct SharedStorage { -+ union { -+ BaseSharedStorage base; -+ }; -+ -+ CUTLASS_HOST_DEVICE -+ SharedStorage() { } -+ }; -+ -+public: -+ -+ -+ static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements, -+ "Mismatch between shared load iterator and output tile iterator."); -+ -+ static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero."); -+ -+ static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess), -+ "Divisibility"); -+ -+private: -+ -+ /// Loads fragment from shared memory aligned with output tensor -+ SharedLoadIterator shared_load_iterator_; -+ -+ /// Thread index within the threadblock -+ int thread_idx_; -+ -+public: -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ EpilogueWithBroadcast( -+ SharedStorage &shared_storage, ///< Shared storage object -+ int thread_idx, ///< ID of a thread within the threadblock -+ int warp_idx, ///< ID of warp within threadblock -+ int lane_idx ///< Id of thread within warp -+ ): -+ Base(shared_storage.base, thread_idx, warp_idx, lane_idx), -+ shared_load_iterator_(shared_storage.base.reference(), thread_idx), -+ thread_idx_(thread_idx) -+ { -+ -+ } -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void operator()( -+ OutputOp const &output_op, ///< Output operator -+ ElementVector const * broadcast_ptr, ///< Broadcast vector -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ OutputTileIterator source_iterator, ///< Tile iterator for source accumulator matrix -+ TensorTileIterator tensor_iterator, ///< Threadblock tile iterator for additional tensor operand -+ MatrixCoord const &problem_size = ///< Problem size needed to guard against out-of-bounds accesses -+ MatrixCoord(Shape::kM, Shape::kN), -+ MatrixCoord const &threadblock_offset = ///< Threadblock's initial offset within the problem size space -+ MatrixCoord()) { -+ -+ BroadcastFragment broadcast_fragment; -+ -+ load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset); -+ -+ if (!output_op.is_source_needed()) { -+ compute_source_not_needed_( -+ output_op, -+ broadcast_fragment, -+ destination_iterator, -+ accumulators, -+ tensor_iterator); -+ } -+ else { -+ compute_source_needed_( -+ output_op, -+ broadcast_fragment, -+ destination_iterator, -+ accumulators, -+ source_iterator, -+ tensor_iterator); -+ } -+ } -+ -+private: -+ -+ CUTLASS_DEVICE -+ void load_broadcast_fragment_( -+ BroadcastFragment & broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns -+ ElementVector const * broadcast_ptr, ///< Broadcast vector -+ MatrixCoord const &problem_size, ///< Problem size needed to guard against out-of-bounds accesses -+ MatrixCoord const &threadblock_offset ///< Threadblock's initial offset within the problem size space -+ ) { -+ -+ broadcast_fragment.clear(); -+ -+ // If no pointer is supplied, set with all zeros and avoid memory accesses -+ if (!broadcast_ptr) { -+ return; -+ } -+ -+ int thread_initial_column = ThreadMap::initial_offset(thread_idx_).column(); -+ -+ int thread_column_idx = threadblock_offset.column() + thread_initial_column; -+ broadcast_ptr += thread_initial_column; -+ -+ NumericArrayConverter converter; -+ using AccessType = AlignedArray; -+ using ComputeFragmentType = Array; -+ -+ ComputeFragmentType *frag_ptr = reinterpret_cast(&broadcast_fragment); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < ThreadMap::Iterations::kColumn; ++j) { -+ -+ AccessType loaded; -+ -+ loaded.clear(); -+ -+ if (thread_column_idx < problem_size.column()) { -+ loaded = *reinterpret_cast(broadcast_ptr); -+ } -+ -+ ComputeFragmentType cvt = converter(loaded); -+ frag_ptr[j] = cvt; -+ -+ thread_column_idx += ThreadMap::Delta::kColumn; -+ broadcast_ptr += ThreadMap::Delta::kColumn; -+ } -+ } -+ -+ template -+ struct acc2smem_source_not_needed; -+ -+ template -+ struct acc2smem_source_not_needed> { -+ template -+ CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator, -+ WarpTileIterator &warp_tile_iterator) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Advance; i++) { -+ ++accum_fragment_iterator; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { -+ typename AccumulatorFragmentIterator::Fragment accum_fragment; -+ -+ accum_fragment_iterator.load(accum_fragment); -+ ++accum_fragment_iterator; -+ -+ warp_tile_iterator.store(accum_fragment); -+ if (p < Base::kFragmentsPerIteration - 1) { -+ warp_tile_iterator.add_pointer_offset(kSmemPointerOffset); -+ } -+ } -+ -+ if (Base::kFragmentsPerIteration > 1) { -+ warp_tile_iterator.add_pointer_offset(kSmemPointerOffset * -+ (1 - Base::kFragmentsPerIteration)); -+ } -+ } -+ -+ CUTLASS_DEVICE -+ static void push(size_t pos, -+ AccumulatorFragmentIterator const &iterator_begin, -+ WarpTileIterator &warp_tile_iterator) { -+ int dummy[] = { -+ (pos == (Seq * Base::kFragmentsPerIteration)) && -+ (helper(iterator_begin, warp_tile_iterator), 0)...}; -+ -+ CUTLASS_UNUSED(dummy[0]); -+ } -+ }; -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void compute_source_not_needed_( -+ OutputOp const &output_op, ///< Output operator -+ BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additioanl tensor operand -+ ) { -+ -+ // -+ // Iterator over warp-level accumulator fragment -+ // -+ -+ AccumulatorFragmentIterator accum_fragment_iterator(accumulators); -+ -+ // -+ // Iterate over accumulator tile -+ // -+ -+ // CUTLASS_PRAGMA_UNROLL -+ #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration : 1) -+ for (int iter = 0; iter < OutputTileIterator::kIterations; iter += Base::kFragmentsPerIteration) { -+ -+ // -+ // Convert and store fragment -+ // -+ -+ -+ __syncthreads(); -+ -+ acc2smem_source_not_needed< -+ cutlass::make_index_sequence>::push(iter, -+ accum_fragment_iterator, -+ this->warp_tile_iterator_); -+ -+ __syncthreads(); -+ -+ // -+ // Load fragments from shared memory -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { -+ -+ -+ typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; -+ -+ shared_load_iterator_.load(aligned_accum_fragment[0]); -+ -+ if (p < Base::kFragmentsPerIteration - 1) { -+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); -+ } -+ else if (kPartitionsK > 1) { -+ -+ plus add_fragments; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for ( int i = 1; i < kPartitionsK; ++i) { -+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); -+ shared_load_iterator_.load(aligned_accum_fragment[i]); -+ aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]); -+ } -+ -+ shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset); -+ } -+ -+ // -+ // Apply output operation -+ // -+ -+ typename OutputTileIterator::Fragment frag_Z; -+ typename TensorTileIterator::Fragment frag_T; -+ -+ apply_output_operator_source_not_needed_( -+ frag_Z, -+ frag_T, -+ output_op, -+ aligned_accum_fragment[0], -+ broadcast_fragment); -+ -+ // -+ // Conditionally store fragments -+ // -+ -+ if (OutputOp::kStoreZ) { -+ destination_iterator.store(frag_Z); -+ ++destination_iterator; -+ } -+ -+ if (OutputOp::kStoreT) { -+ tensor_iterator.store(frag_T); -+ ++tensor_iterator; -+ } -+ } -+ -+ if (Base::kFragmentsPerIteration > 1) { -+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); -+ } -+ } -+ } -+ -+ -+ template -+ struct acc2smem_source_needed; -+ -+ template -+ struct acc2smem_source_needed> { -+ template -+ CUTLASS_DEVICE -+ static void helper(AccumulatorFragmentIterator accum_fragment_iterator, -+ WarpTileIterator &warp_tile_iterator) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Advance; i++) { -+ ++accum_fragment_iterator; -+ } -+ -+ typename AccumulatorFragmentIterator::Fragment accum_fragment; -+ accum_fragment_iterator.load(accum_fragment); -+ warp_tile_iterator.store(accum_fragment); -+ } -+ -+ CUTLASS_DEVICE -+ static void push(size_t pos, -+ AccumulatorFragmentIterator const &iterator_begin, -+ WarpTileIterator &warp_tile_iterator) { -+ int dummy[] = {(pos == Seq) && (helper(iterator_begin, warp_tile_iterator), 0)...}; -+ } -+ }; -+ -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void compute_source_needed_( -+ OutputOp const &output_op, ///< Output operator -+ BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ OutputTileIterator source_iterator, ///< Tile iterator for source accumulator matrix -+ TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additioanl tensor operand -+ ) { -+ -+ typename OutputTileIterator::Fragment source_fragment; -+ source_fragment.clear(); -+ -+ // -+ // Iterator over warp-level accumulator fragment -+ // -+ -+ AccumulatorFragmentIterator accum_fragment_iterator(accumulators); -+ -+ // -+ // Iterate over accumulator tile -+ // -+ -+ #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1) -+ for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { -+ -+ // -+ // Load the source -+ // -+ -+ source_iterator.load(source_fragment); -+ ++source_iterator; -+ -+ // -+ // Convert and store fragment -+ // -+ -+ __syncthreads(); -+ -+ acc2smem_source_needed>::push( -+ iter, accum_fragment_iterator, this->warp_tile_iterator_); -+ -+ __syncthreads(); -+ -+ // -+ // Load fragments from shared memory -+ // -+ -+ typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; -+ -+ shared_load_iterator_.load(aligned_accum_fragment[0]); -+ -+ // If the number of k-slices is > 1 - perform a reduction amongst the k-slices -+ if (kPartitionsK > 1) -+ { -+ plus add_fragments; -+ const int tile_row_offset = Base::SharedStorage::StorageShape::kRow / PartitionsK; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for ( int i = 1; i < kPartitionsK; ++i) { -+ shared_load_iterator_.add_tile_offset({tile_row_offset , 0}); -+ shared_load_iterator_.load(aligned_accum_fragment[i]); -+ aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]); -+ } -+ -+ shared_load_iterator_.add_tile_offset({-1 * (kPartitionsK-1) * tile_row_offset, 0}); -+ } -+ -+ // -+ // Apply output operation -+ // -+ -+ typename OutputTileIterator::Fragment frag_Z; -+ typename TensorTileIterator::Fragment frag_T; -+ -+ apply_output_operator_( -+ frag_Z, -+ frag_T, -+ output_op, -+ aligned_accum_fragment[0], -+ source_fragment, -+ broadcast_fragment); -+ -+ // -+ // Conditionally store fragments -+ // -+ -+ if (OutputOp::kStoreZ) { -+ destination_iterator.store(frag_Z); -+ ++destination_iterator; -+ } -+ -+ if (OutputOp::kStoreT) { -+ tensor_iterator.store(frag_T); -+ ++tensor_iterator; -+ } -+ } -+ } -+ -+ /// Helper to invoke the output functor over each vector of output -+ CUTLASS_DEVICE -+ void apply_output_operator_( -+ typename OutputTileIterator::Fragment &frag_Z, -+ typename TensorTileIterator::Fragment &frag_T, -+ OutputOp const &output_op, -+ typename SharedLoadIterator::Fragment const &frag_AB, -+ typename OutputTileIterator::Fragment const &frag_C, -+ BroadcastFragment const &frag_Broadcast) { -+ -+ using AccessTypeZ = Array; -+ using AccessTypeT = Array; -+ using AccessTypeBroadcast = Array; -+ -+ AccessTypeZ *frag_Z_ptr = reinterpret_cast(&frag_Z); -+ AccessTypeT *frag_T_ptr = reinterpret_cast(&frag_T); -+ -+ AccumulatorAccessType const *frag_AB_ptr = -+ reinterpret_cast(&frag_AB); -+ -+ OutputAccessType const *frag_C_ptr = -+ reinterpret_cast(&frag_C); -+ -+ AccessTypeBroadcast const *frag_Broadcast_ptr = -+ reinterpret_cast(&frag_Broadcast); -+ -+ int const kOutputOpIterations = -+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kOutputOpIterations; ++i) { -+ output_op( -+ frag_Z_ptr[i], -+ frag_T_ptr[i], -+ frag_AB_ptr[i], -+ frag_C_ptr[i], -+ frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]); -+ } -+ } -+ -+ /// Helper to invoke the output functor over each vector of output -+ CUTLASS_DEVICE -+ void apply_output_operator_source_not_needed_( -+ typename OutputTileIterator::Fragment &frag_Z, -+ typename TensorTileIterator::Fragment &frag_T, -+ OutputOp const &output_op, -+ typename SharedLoadIterator::Fragment const &frag_AB, -+ BroadcastFragment const &frag_Broadcast) { -+ -+ using AccessTypeZ = Array; -+ using AccessTypeT = Array; -+ using AccessTypeBroadcast = Array; -+ -+ AccessTypeZ *frag_Z_ptr = reinterpret_cast(&frag_Z); -+ AccessTypeT *frag_T_ptr = reinterpret_cast(&frag_T); -+ -+ AccumulatorAccessType const *frag_AB_ptr = -+ reinterpret_cast(&frag_AB); -+ -+ AccessTypeBroadcast const *frag_Broadcast_ptr = -+ reinterpret_cast(&frag_Broadcast); -+ -+ int const kOutputOpIterations = -+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kOutputOpIterations; ++i) { -+ -+ output_op( -+ frag_Z_ptr[i], -+ frag_T_ptr[i], -+ frag_AB_ptr[i], -+ frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]); -+ } -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_reduction.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_reduction.h -new file mode 100644 -index 0000000..6e76f7e ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_reduction.h -@@ -0,0 +1,823 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/tensor_coord.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/functional.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/layout/vector.h" -+#include "cutlass/layout/tensor.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/regular_tile_iterator.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue_base.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Epilogue operator with reduction over each column -+template < -+ typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) -+ typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) -+ int PartitionsK, ///< Number of partitions of the K dimension -+ typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors -+ typename TensorTileIterator_, ///< Additional tile iterator for tensor-valued operands -+ typename ElementVector_, ///< Pointer to reduction vector -+ typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators -+ typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM -+ typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM -+ typename OutputOp_, ///< Output operator -+ typename ReductionOp_, ///< Reduction operator -+ typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape) -+ int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large -+ (!IsEpilogueFunctorHeavy::value) -+> -+class EpilogueWithReduction : -+ public EpilogueBase< -+ Shape_, -+ typename WarpMmaOperator_::Shape, -+ PartitionsK, -+ AccumulatorFragmentIterator_, -+ WarpTileIterator_, -+ Padding_> { -+ -+public: -+ -+ using Base = EpilogueBase< -+ Shape_, -+ typename WarpMmaOperator_::Shape, -+ PartitionsK, -+ AccumulatorFragmentIterator_, -+ WarpTileIterator_, -+ Padding_>; -+ -+ using Shape = Shape_; -+ using WarpMmaOperator = WarpMmaOperator_; -+ static int const kPartitionsK = PartitionsK; -+ using OutputTileIterator = OutputTileIterator_; -+ using TensorTileIterator = TensorTileIterator_; -+ using ElementVector = ElementVector_; -+ using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; -+ using WarpTileIterator = WarpTileIterator_; -+ using SharedLoadIterator = SharedLoadIterator_; -+ using OutputOp = OutputOp_; -+ using ReductionOp = ReductionOp_; -+ using Padding = Padding_; -+ -+ using Layout = layout::RowMajor; -+ using LongIndex = typename Layout::LongIndex; -+ -+ static bool const kIsSingleSource = true; -+ -+ /// The complete warp-level accumulator tile -+ using AccumulatorTile = typename Base::AccumulatorTile; -+ -+ /// Accumulator element -+ using ElementAccumulator = typename WarpTileIterator::Element; -+ -+ /// Compute data type produced by the output op -+ using ElementCompute = typename OutputOp::ElementCompute; -+ -+ /// Compute fragment -+ using FragmentCompute = Array; -+ -+ /// Thread map used by output tile iterators -+ using ThreadMap = typename OutputTileIterator::ThreadMap; -+ -+ /// Fragment object used in reduction -+ using ReductionFragment = Array< -+ ElementAccumulator, -+ ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess>; -+ -+ /// Output element -+ using ElementOutput = typename OutputTileIterator::Element; -+ -+ /// Data type of additional tensor -+ using ElementTensor = typename TensorTileIterator::Element; -+ -+ /// Output access size -+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; -+ -+ /// Tensor reference to destination tensor -+ using TensorRef = typename OutputTileIterator::TensorRef; -+ -+ /// Tensor reference to sync tensor -+ using SyncTensorRef = typename cutlass::TensorRef; -+ -+ /// Const tensor reference to source tensor -+ using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; -+ -+ /// Array type used to output -+ using OutputAccessType = Array< -+ typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>; -+ -+ /// Array type used by output functor -+ using AccumulatorAccessType = Array; -+ -+ /// Array type used by output functor -+ using ComputeAccessType = Array; -+ -+ /// Tensor access type -+ using TensorAccessType = Array; -+ -+ /// Number of warps -+ using WarpCount = typename Base::WarpCount; -+ -+ /// Shared memory allocation from epilogue base class -+ using BaseSharedStorage = typename Base::SharedStorage; -+ -+ /// Used for the reduction -+ struct ReductionDetail { -+ -+ /// If true, accumulator coordinates are computed and out-of-bounds checks are enabled when -+ /// performing the reduction. -+ static bool const kOobCheck = false; -+ -+ /// Number of threads per warp -+ static int const kWarpSize = 32; -+ -+ /// Number of distinct scalar column indices handled by each thread -+ static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess; -+ -+ /// Number of distinct scalar row indices handled by each thread -+ static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn; -+ -+ /// Number of threads per threadblock -+ static int const kThreadCount = kWarpSize * WarpCount::kCount; -+ -+ /// Number of distinct threads per row of output tile -+ static int const kThreadsPerRow = (Shape::kN / kColumnsPerThread); -+ -+ /// Number of distinct threads which must be reduced during the final reduction phase within the threadblock. -+ static int const kThreadRows = kThreadCount / kThreadsPerRow; -+ -+ /// I'm not sure what I meant here. -+ static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount); -+ -+ /// Shape of the shared memory allocation for the epilogue -+ using StorageShape = MatrixShape< -+ kThreadRows, -+ Shape::kN -+ >; -+ -+ /// Debug printing -+ CUTLASS_DEVICE -+ static void print() { -+#if 0 -+ printf("ReductionDetail {\n"); -+ printf( -+ " kElementsPerAccess:%d\nkColumnsPerThread: %d\nkRowsPerThread: %d\n,kThreadCount: %d\nkThreadsPerRow: %d\n" -+ "kThreadRows: %d\nThreadAccessesPerRow: %d\nStorageShape: %d x %d (count: %d)\n", -+ kElementsPerAccess, -+ kColumnsPerThread, -+ kRowsPerThread, -+ kThreadCount, -+ kThreadsPerRow, -+ kThreadRows, -+ kThreadAccessesPerRow, -+ StorageShape::kRow, -+ StorageShape::kColumn, -+ StorageShape::kCount -+ ); -+ printf("};\n"); -+#endif -+ } -+ }; -+ -+ /// Shared storage structure (shadows base) with additional SMEM buffer for reduction -+ struct SharedStorage { -+ union { -+ BaseSharedStorage base; -+ AlignedArray reduction; ///< Shared storage for reduction -+ }; -+ -+ CUTLASS_HOST_DEVICE -+ SharedStorage() { } -+ }; -+ -+public: -+ -+ -+ static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements, -+ "Mismatch between shared load iterator and output tile iterator."); -+ -+ static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero."); -+ -+ static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess), -+ "Divisibility"); -+ -+private: -+ -+ /// Loads fragment from shared memory aligned with output tensor -+ SharedLoadIterator shared_load_iterator_; -+ -+ /// Shared memory pointer fo rreduction -+ ElementAccumulator *reduction_ptr_; -+ -+ /// Thread index within the threadblock -+ int thread_idx_; -+ -+public: -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ EpilogueWithReduction( -+ SharedStorage &shared_storage, ///< Shared storage object -+ int thread_idx, ///< ID of a thread within the threadblock -+ int warp_idx, ///< ID of warp within threadblock -+ int lane_idx ///< Id of thread within warp -+ ): -+ Base(shared_storage.base, thread_idx, warp_idx, lane_idx), -+ shared_load_iterator_(shared_storage.base.reference(), thread_idx), -+ reduction_ptr_(shared_storage.reduction.data()), -+ thread_idx_(thread_idx) -+ { -+ -+ } -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void operator()( -+ OutputOp const &output_op, ///< Output operator -+ ElementVector * reduction_output_ptr, ///< Reduction output vector -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ OutputTileIterator source_iterator, ///< Tile iterator for source accumulator matrix -+ TensorTileIterator tensor_iterator, ///< Threadblock tile iterator for additional tensor operand -+ MatrixCoord const &problem_size = ///< Problem size needed to guard against out-of-bounds accesses -+ MatrixCoord(Shape::kM, Shape::kN), -+ MatrixCoord const &threadblock_offset = ///< Threadblock's initial offset within the problem size space -+ MatrixCoord()) { -+ -+ ReductionFragment reduction_fragment; -+ reduction_fragment.clear(); -+ -+ if (!output_op.is_source_needed()) { -+ compute_source_not_needed_( -+ output_op, -+ reduction_fragment, -+ destination_iterator, -+ accumulators, -+ tensor_iterator, -+ problem_size, -+ threadblock_offset); -+ } -+ else { -+ compute_source_needed_( -+ output_op, -+ reduction_fragment, -+ destination_iterator, -+ accumulators, -+ source_iterator, -+ tensor_iterator, -+ problem_size, -+ threadblock_offset); -+ } -+ -+ if (output_op.participates_in_reduction()) { -+ reduction_(problem_size, threadblock_offset, reduction_output_ptr, reduction_fragment); -+ } -+ } -+ -+private: -+ -+ /// Perform the reduction -+ CUTLASS_DEVICE -+ void reduction_( -+ MatrixCoord const &problem_size, ///< Problem size needed to guard against out-of-bounds accesses -+ MatrixCoord const &threadblock_offset, ///< Problem size needed to guard against out-of-bounds accesses -+ ElementVector * reduction_output_ptr, ///< Reduction output vector -+ ReductionFragment const & reduction_fragment) { -+ -+ // -+ // Store the partially reduced value to SMEM -+ // -+ -+ // Guard against uses of the existing SMEM tile -+ __syncthreads(); -+ -+ using AccessType = AlignedArray; -+ -+ // -+ // Determine a compacted thread arrangement to store to SMEM. -+ // -+ int const kThreadsPerRow = Shape::kN / (ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess); -+ -+ MatrixCoord thread_offset( -+ thread_idx_ / kThreadsPerRow, -+ (thread_idx_ % kThreadsPerRow) * ThreadMap::kElementsPerAccess); -+ -+ // -+ // Each thread store its fragment to a SMEM -+ // -+ -+ AccessType *aligned_reduction_ptr = reinterpret_cast( -+ &reduction_ptr_[thread_offset.row() * Shape::kN + thread_offset.column()]); -+ -+ AccessType const *frag_ptr = reinterpret_cast(&reduction_fragment); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { -+ int col_idx = column * ThreadMap::Delta::kColumn / ThreadMap::kElementsPerAccess; -+ -+ aligned_reduction_ptr[col_idx] = frag_ptr[column]; -+ } -+ -+ __syncthreads(); -+ -+ // -+ // Now, threads are assigned several columns of the output. They fetch over all rows from -+ // the compacted SMEM tile and perform a reduction. -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < ReductionDetail::kThreadAccessesPerRow; ++j) { -+ int column_idx = thread_idx_ + j * ReductionDetail::kThreadCount; -+ -+ ReductionOp reduction_op; -+ ElementAccumulator reduction_element = ElementAccumulator(); -+ -+ int output_column_idx = threadblock_offset.column() + column_idx; -+ -+ if (column_idx < Shape::kN && output_column_idx < problem_size.column()) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ReductionDetail::kThreadRows; ++row) { -+ if (row) { -+ auto frag = reduction_ptr_[row * Shape::kN + column_idx]; -+ -+ reduction_element = reduction_op(reduction_element, frag); -+ } -+ else { -+ -+ reduction_element = reduction_ptr_[column_idx]; -+ } -+ } -+ -+ // Store -+ reduction_output_ptr[column_idx] = ElementVector(reduction_element); -+ } -+ } -+ } -+ -+ template -+ struct acc2smem; -+ -+ template -+ struct acc2smem> { -+ template -+ CUTLASS_DEVICE -+ static void helper(AccumulatorFragmentIterator accum_fragment_iterator, -+ WarpTileIterator &warp_tile_iterator) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Advance; i++) { -+ ++accum_fragment_iterator; -+ } -+ -+ typename AccumulatorFragmentIterator::Fragment accum_fragment; -+ accum_fragment_iterator.load(accum_fragment); -+ warp_tile_iterator.store(accum_fragment); -+ } -+ -+ CUTLASS_DEVICE -+ static void push(size_t pos, -+ AccumulatorFragmentIterator const &iterator_begin, -+ WarpTileIterator &warp_tile_iterator) { -+ int dummy[] = {(pos == Seq) && (helper(iterator_begin, warp_tile_iterator), 0)...}; -+ } -+ }; -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void compute_source_not_needed_( -+ OutputOp const &output_op, ///< Output operator -+ ReductionFragment &reduction_fragment, ///< Fragment containing the accumulated partial reduction over columns -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ TensorTileIterator tensor_iterator, ///< Threadblock tile iterator for additioanl tensor operand -+ MatrixCoord const &problem_size, ///< Problem size needed to guard against out-of-bounds accesses -+ MatrixCoord const &threadblock_offset ///< Threadblock's initial offset within the problem size space -+ ) { -+ -+ // -+ // Iterator over warp-level accumulator fragment -+ // -+ -+ typename TensorTileIterator::Fragment tensor_fragment; -+ tensor_fragment.clear(); -+ -+ AccumulatorFragmentIterator accum_fragment_iterator(accumulators); -+ -+ // -+ // Iterate over accumulator tile -+ // -+ -+ #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1) -+ for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { -+ -+ // -+ // Convert and store fragment -+ // -+ -+ tensor_iterator.load(tensor_fragment); -+ ++tensor_iterator; -+ -+ __syncthreads(); -+ -+ acc2smem>::push( -+ iter, accum_fragment_iterator, this->warp_tile_iterator_); -+ -+ __syncthreads(); -+ -+ // -+ // Load fragments from shared memory -+ // -+ -+ typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; -+ -+ shared_load_iterator_.load(aligned_accum_fragment[0]); -+ -+ // -+ // If the number of k-slices is > 1 - perform a reduction amongst the k-slices -+ // -+ if (kPartitionsK > 1) -+ { -+ plus add_fragments; -+ const int tile_row_offset = Base::SharedStorage::StorageShape::kRow / PartitionsK; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for ( int i = 1; i < kPartitionsK; ++i) { -+ shared_load_iterator_.add_tile_offset({tile_row_offset , 0}); -+ shared_load_iterator_.load(aligned_accum_fragment[i]); -+ aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]); -+ } -+ -+ shared_load_iterator_.add_tile_offset({-1 * (kPartitionsK-1) * tile_row_offset, 0}); -+ } -+ -+ // -+ // Compute the output result -+ // -+ -+ FragmentCompute compute_fragment; -+ -+ apply_output_operator_source_not_needed_( -+ reduction_fragment, -+ compute_fragment, -+ output_op, -+ aligned_accum_fragment[0], -+ tensor_fragment, -+ destination_iterator); -+ -+ // -+ // Store the final result -+ // -+ -+ NumericArrayConverter converter; -+ -+ typename OutputTileIterator::Fragment output_fragment = converter(compute_fragment); -+ -+ destination_iterator.store(output_fragment); -+ ++destination_iterator; -+ } -+ } -+ -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void compute_source_needed_( -+ OutputOp const &output_op, ///< Output operator -+ ReductionFragment &reduction_fragment, ///< Fragment containing the accumulated partial reduction over columns -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ OutputTileIterator source_iterator, ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) -+ TensorTileIterator tensor_iterator, ///< Threadblock tile iterator for additioanl tensor operand -+ MatrixCoord const &problem_size, ///< Problem size needed to guard against out-of-bounds accesses -+ MatrixCoord const &threadblock_offset ///< Threadblock's initial offset within the problem size space -+ ) { -+ -+ typename OutputTileIterator::Fragment source_fragment; -+ source_fragment.clear(); -+ -+ typename TensorTileIterator::Fragment tensor_fragment; -+ tensor_fragment.clear(); -+ -+ // -+ // Iterator over warp-level accumulator fragment -+ // -+ -+ AccumulatorFragmentIterator accum_fragment_iterator(accumulators); -+ -+ // -+ // Iterate over accumulator tile -+ // -+ -+ #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1) -+ for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { -+ -+ // -+ // Load the source -+ // -+ -+ source_fragment.clear(); -+ source_iterator.load(source_fragment); -+ ++source_iterator; -+ -+ tensor_iterator.load(tensor_fragment); -+ ++tensor_iterator; -+ -+ // -+ // Convert and store fragment -+ // -+ -+ __syncthreads(); -+ -+ acc2smem>::push( -+ iter, accum_fragment_iterator, this->warp_tile_iterator_); -+ -+ __syncthreads(); -+ -+ // -+ // Load fragments from shared memory -+ // -+ -+ typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; -+ -+ shared_load_iterator_.load(aligned_accum_fragment[0]); -+ -+ // If the number of k-slices is > 1 - perform a reduction amongst the k-slices -+ if (kPartitionsK > 1) -+ { -+ plus add_fragments; -+ const int tile_row_offset = Base::SharedStorage::StorageShape::kRow / PartitionsK; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for ( int i = 1; i < kPartitionsK; ++i) { -+ shared_load_iterator_.add_tile_offset({tile_row_offset , 0}); -+ shared_load_iterator_.load(aligned_accum_fragment[i]); -+ aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]); -+ } -+ -+ shared_load_iterator_.add_tile_offset({-1 * (kPartitionsK-1) * tile_row_offset, 0}); -+ } -+ -+ // -+ // Compute the output result -+ // -+ -+ FragmentCompute compute_fragment; -+ -+ apply_output_operator_( -+ reduction_fragment, -+ compute_fragment, -+ output_op, -+ aligned_accum_fragment[0], -+ source_fragment, -+ tensor_fragment, -+ destination_iterator); -+ -+ // -+ // Convert and store the final result -+ // -+ -+ NumericArrayConverter converter; -+ -+ typename OutputTileIterator::Fragment output_fragment = converter(compute_fragment); -+ -+ destination_iterator.store(output_fragment); -+ ++destination_iterator; -+ } -+ } -+ -+ /// Helper to invoke the output functor over each vector of output -+ CUTLASS_DEVICE -+ void apply_output_operator_( -+ ReductionFragment &reduction_fragment, -+ FragmentCompute &compute_fragment, -+ OutputOp const &output_op, ///< Output operator -+ typename SharedLoadIterator::Fragment const &aligned_accum_fragment, -+ typename OutputTileIterator::Fragment const &source_fragment, -+ typename TensorTileIterator::Fragment const &tensor_fragment, -+ OutputTileIterator const & destination_iterator) { -+ -+ ComputeAccessType *compute_frag_ptr = -+ reinterpret_cast(&compute_fragment); -+ -+ AccumulatorAccessType const *accum_frag_ptr = -+ reinterpret_cast(&aligned_accum_fragment); -+ -+ OutputAccessType const *source_frag_ptr = -+ reinterpret_cast(&source_fragment); -+ -+ TensorAccessType const *tensor_frag_ptr = -+ reinterpret_cast(&tensor_fragment); -+ -+ int const kOutputOpIterations = -+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kOutputOpIterations; ++i) { -+ -+ // Call the output operator -+ compute_frag_ptr[i] = output_op(accum_frag_ptr[i], source_frag_ptr[i], tensor_frag_ptr[i]); -+ } -+ -+ // -+ // Partial reduction over each column -+ // -+ -+ ReductionOp reduction_op; -+ -+ typename OutputTileIterator::Mask mask; -+ destination_iterator.get_mask(mask); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ReductionDetail::kColumnsPerThread; ++column) { -+ -+ int column_vector_idx = column / ThreadMap::kElementsPerAccess; -+ bool column_guard = mask.predicates[column_vector_idx]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ReductionDetail::kRowsPerThread; ++row) { -+ -+ bool fetch; -+ if (ReductionDetail::kOobCheck) { -+ int row_idx = (row % ThreadMap::Iterations::kRow); -+ int residual = (row / ThreadMap::Iterations::kRow); -+ -+ int group_idx = (residual % ThreadMap::Iterations::kGroup); -+ residual = (residual / ThreadMap::Iterations::kGroup); -+ -+ int cluster_idx = (residual % ThreadMap::Iterations::kCluster); -+ -+ int row_offset = row_idx * ThreadMap::Delta::kRow -+ + group_idx * ThreadMap::Delta::kGroup -+ + cluster_idx * ThreadMap::Delta::kCluster; -+ -+ int output_row = destination_iterator.thread_start_row() + row_offset; -+ -+ fetch = (output_row < destination_iterator.extent_row() && column_guard); -+ } -+ else { -+ fetch = true; -+ } -+ -+ ElementCompute value = ElementCompute(); -+ if (fetch) { -+ value = compute_fragment[row * ReductionDetail::kColumnsPerThread + column]; -+ } -+ -+ reduction_fragment[column] = reduction_op( -+ reduction_fragment[column], -+ value); -+ } -+ } -+ } -+ -+ /// Helper to invoke the output functor over each vector of output -+ CUTLASS_DEVICE -+ void apply_output_operator_source_not_needed_( -+ ReductionFragment &reduction_fragment, -+ FragmentCompute &compute_fragment, -+ OutputOp const &output_op, ///< Output operator -+ typename SharedLoadIterator::Fragment const &aligned_accum_fragment, -+ typename TensorTileIterator::Fragment const &tensor_fragment, -+ OutputTileIterator const & destination_iterator -+ ) { -+ -+ ComputeAccessType *compute_frag_ptr = -+ reinterpret_cast(&compute_fragment); -+ -+ AccumulatorAccessType const *accum_frag_ptr = -+ reinterpret_cast(&aligned_accum_fragment); -+ -+ TensorAccessType const *tensor_frag_ptr = -+ reinterpret_cast(&tensor_fragment); -+ -+ int const kOutputOpIterations = -+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kOutputOpIterations; ++i) { -+ -+ // Call the output operator -+ compute_frag_ptr[i] = output_op(accum_frag_ptr[i], tensor_frag_ptr[i]); -+ } -+ -+ // -+ // Partial reduction over each column -+ // -+ -+ ReductionOp reduction_op; -+ -+ typename OutputTileIterator::Mask mask; -+ destination_iterator.get_mask(mask); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ReductionDetail::kColumnsPerThread; ++column) { -+ -+ int column_vector_idx = column / ThreadMap::kElementsPerAccess; -+ bool column_guard = mask.predicates[column_vector_idx]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ReductionDetail::kRowsPerThread; ++row) { -+ -+ bool fetch; -+ if (ReductionDetail::kOobCheck) { -+ int row_idx = (row % ThreadMap::Iterations::kRow); -+ int residual = (row / ThreadMap::Iterations::kRow); -+ -+ int group_idx = (residual % ThreadMap::Iterations::kGroup); -+ residual = (residual / ThreadMap::Iterations::kGroup); -+ -+ int cluster_idx = (residual % ThreadMap::Iterations::kCluster); -+ -+ int row_offset = row_idx * ThreadMap::Delta::kRow -+ + group_idx * ThreadMap::Delta::kGroup -+ + cluster_idx * ThreadMap::Delta::kCluster; -+ -+ int output_row = destination_iterator.thread_start_row() + row_offset; -+ -+ fetch = (output_row < destination_iterator.extent_row() && column_guard); -+ } -+ else { -+ fetch = true; -+ } -+ -+ ElementCompute value = ElementCompute(); -+ if (fetch) { -+ value = compute_fragment[row * ReductionDetail::kColumnsPerThread + column]; -+ } -+ -+ reduction_fragment[column] = reduction_op( -+ reduction_fragment[column], -+ value); -+ } -+ } -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_visitor.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_visitor.h -new file mode 100644 -index 0000000..6c54353 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_with_visitor.h -@@ -0,0 +1,409 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Generic epilogue for implementing certain kinds of fused epilogue behavior. -+*/ -+ -+#pragma once -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/semaphore.h" -+#include "cutlass/epilogue/threadblock/epilogue_base.h" -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+class EpilogueFusedVisitorConcept { -+public: -+ -+ static int const kIterations = 1; -+ static int const kElementsPerAccess = 4; -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using AccumulatorFragment = Array; -+ -+ /// Arguments structure -+ struct Arguments { }; -+ -+ /// Params structure -+ struct Params { -+ -+ Params() { } -+ Params(Arguments const &args) { } -+ }; -+ -+ /// Shared storage -+ struct SharedStorage { }; -+ -+public: -+ -+ CUTLASS_DEVICE -+ EpilogueFusedVisitorConcept( -+ Params const ¶ms, ///< Parameters routed to the epilogue -+ SharedStorage &shared_storage, ///< Shared storage needed by the functors here -+ MatrixCoord const &problem_size, ///< Problem size of the output -+ int thread_idx, ///< Thread index within the threadblock -+ int warp_idx, ///< Warp index within the threadblock -+ int lane_idx, ///< Lane index within the warp -+ MatrixCoord const &threadblock_offset = MatrixCoord(0, 0)) { ///< Coordinate -+ -+ } -+ -+ /// Helper to indicate split-K behavior -+ CUTLASS_DEVICE -+ void set_k_partition( -+ int split_k_index, ///< Index of this threadblock within split-K partitioned scheme -+ int split_k_slices) { ///< Total number of split-K slices -+ -+ } -+ -+ /// Called to set the batch index -+ CUTLASS_DEVICE -+ void set_batch_index(int batch_idx) { -+ -+ } -+ -+ /// Called at the start of the epilogue just before iterating over accumulator slices -+ CUTLASS_DEVICE -+ void begin_epilogue() { -+ -+ } -+ -+ /// Called at the start of one step before starting accumulator exchange -+ CUTLASS_DEVICE -+ void begin_step(int step_idx) { -+ -+ } -+ -+ /// Called at the start of a row -+ CUTLASS_DEVICE -+ void begin_row(int row_idx) { -+ -+ } -+ -+ /// Called after accumulators have been exchanged for each accumulator vector -+ CUTLASS_DEVICE -+ void visit( -+ int iter_idx, -+ int row_idx, -+ int column_idx, -+ int frag_idx, -+ AccumulatorFragment const &accum) { -+ -+ } -+ -+ /// Called at the end of a row -+ CUTLASS_DEVICE -+ void end_row(int row_idx) { -+ -+ } -+ -+ /// Called after all accumulator elements have been visited -+ CUTLASS_DEVICE -+ void end_step(int step_idx) { -+ -+ } -+ -+ /// Called after all steps have been completed -+ CUTLASS_DEVICE -+ void end_epilogue() { -+ -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Epilogue operator -+template < -+ typename Visitor_, ///< Functor containing fused operations (satisfies EpilogueFusedVisitorConcept) -+ typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) -+ typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) -+ int PartitionsK, ///< Number of partitions of the K dimension -+ typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators -+ typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM -+ typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM -+ typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape) -+ int FragmentsPerPartition = 1, ///< Used to coarsten the epilogue granularity -+ int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large -+ (true || !IsEpilogueFunctorHeavy::value) -+> -+class EpilogueWithVisitor : -+ public EpilogueBase< -+ Shape_, -+ typename WarpMmaOperator_::Shape, -+ PartitionsK, -+ AccumulatorFragmentIterator_, -+ WarpTileIterator_, -+ Padding_, -+ FragmentsPerPartition> { -+ -+public: -+ -+ using Visitor = Visitor_; -+ -+ using Base = EpilogueBase< -+ Shape_, -+ typename WarpMmaOperator_::Shape, -+ PartitionsK, -+ AccumulatorFragmentIterator_, -+ WarpTileIterator_, -+ Padding_, -+ FragmentsPerPartition>; -+ -+ using Shape = Shape_; -+ using WarpMmaOperator = WarpMmaOperator_; -+ static int const kPartitionsK = PartitionsK; -+ -+ using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; -+ using WarpTileIterator = WarpTileIterator_; -+ using SharedLoadIterator = SharedLoadIterator_; -+ using Padding = Padding_; -+ -+ using Layout = layout::RowMajor; -+ using LongIndex = typename Layout::LongIndex; -+ -+ /// The complete warp-level accumulator tile -+ using AccumulatorTile = typename Base::AccumulatorTile; -+ -+ /// Accumulator element -+ using ElementAccumulator = typename WarpTileIterator::Element; -+ -+ /// Output access size -+ static int const kElementsPerAccess = Visitor::kElementsPerAccess; -+ -+ /// Tensor reference to sync tensor -+ using SyncTensorRef = typename cutlass::TensorRef; -+ -+ /// Array type used by output functor -+ using AccumulatorAccessType = Array< -+ typename WarpTileIterator::Element, kElementsPerAccess>; -+ -+ /// Number of warps -+ using WarpCount = typename Base::WarpCount; -+ -+ static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK; -+ static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles; -+ -+ using SharedStorage = typename Base::SharedStorage; -+ -+private: -+ -+ /// Loads fragment from shared memory aligned with output tensor -+ SharedLoadIterator shared_load_iterator_; -+ -+public: -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ EpilogueWithVisitor( -+ SharedStorage &shared_storage, ///< Shared storage object -+ int thread_idx, ///< ID of a thread within the threadblock -+ int warp_idx, ///< ID of warp within threadblock -+ int lane_idx ///< Id of thread within warp -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ shared_load_iterator_(shared_storage.reference(), thread_idx) -+ { -+ -+ } -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void operator()( -+ Visitor & visitor, -+ AccumulatorTile const &accumulators) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) -+ -+ visitor.begin_epilogue(); -+ -+ // -+ // Iterator over warp-level accumulator fragment -+ // -+ -+ AccumulatorFragmentIterator accum_fragment_iterator(accumulators); -+ -+ // -+ // Iterate over accumulator tile -+ // -+ -+ #pragma unroll(IterationsUnroll ? Visitor::kIterations : 1) -+ for (int iter_idx = 0; iter_idx < Visitor::kIterations; ++iter_idx) { -+ -+ // -+ // Load the source -+ // -+ -+ visitor.begin_step(iter_idx); -+ -+ // -+ // Convert and store fragment -+ // -+ -+ __syncthreads(); -+ -+ acc2smem_source_needed>::push( -+ iter_idx, accum_fragment_iterator, this->warp_tile_iterator_); -+ -+ __syncthreads(); -+ -+ // -+ // Load fragments from shared memory -+ // -+ -+ typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; -+ -+ shared_load_iterator_.load(aligned_accum_fragment[0]); -+ -+ // If the number of k-slices is > 1 - perform a reduction amongst the k-slices -+ if (kPartitionsK > 1) { -+ -+ plus add_fragments; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for ( int i = 1; i < kPartitionsK; ++i) { -+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); -+ shared_load_iterator_.load(aligned_accum_fragment[i]); -+ aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]); -+ } -+ -+ shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset); -+ } -+ -+ // -+ // Iterate over output fragments -+ // -+ -+ AccumulatorAccessType const *accum_frag_ptr = -+ reinterpret_cast(&aligned_accum_fragment[0]); -+ -+ int const kAccumulatorFragmentCount = AccumulatorTile::kElements / (Visitor::kIterations * AccumulatorAccessType::kElements); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int idx = 0; idx < kAccumulatorFragmentCount; ++idx) { -+ -+ int row_idx = idx / SharedLoadIterator::ThreadMap::Iterations::kColumn; -+ int col_idx = idx % SharedLoadIterator::ThreadMap::Iterations::kColumn; -+ -+ // Start a new row of the output fragment -+ if (!col_idx) { -+ visitor.begin_row(row_idx); -+ } -+ -+ visitor.visit( -+ iter_idx, -+ row_idx, -+ col_idx, -+ idx, -+ accum_frag_ptr[idx] -+ ); -+ -+ // End the row of the output fragment -+ if (col_idx + 1 == SharedLoadIterator::ThreadMap::Iterations::kColumn) { -+ visitor.end_row(row_idx); -+ } -+ } -+ -+ // -+ // Conclude the step -+ // -+ -+ visitor.end_step(iter_idx); -+ } -+ -+ visitor.end_epilogue(); -+ } -+ -+private: -+ -+ -+ template -+ struct acc2smem_source_needed; -+ -+ template -+ struct acc2smem_source_needed> { -+ template -+ CUTLASS_DEVICE -+ static void helper(AccumulatorFragmentIterator accum_fragment_iterator, -+ WarpTileIterator &warp_tile_iterator) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Advance; i++) { -+ ++accum_fragment_iterator; -+ } -+ -+ typename AccumulatorFragmentIterator::Fragment accum_fragment; -+ accum_fragment_iterator.load(accum_fragment); -+ warp_tile_iterator.store(accum_fragment); -+ } -+ -+ CUTLASS_DEVICE -+ static void push(size_t pos, -+ AccumulatorFragmentIterator const &iterator_begin, -+ WarpTileIterator &warp_tile_iterator) { -+ int dummy[] = {(pos == Seq) && (helper(iterator_begin, warp_tile_iterator), 0)...}; -+ } -+ }; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Helper to create an EpilogueWithVisitor from an existing epilogue -+template -+struct EpilogueWithVisitorFromExistingEpilogue { -+ -+ using Epilogue = EpilogueWithVisitor< -+ Visitor_, -+ typename Existing_::Shape, -+ typename Existing_::WarpMmaOperator, -+ Existing_::kPartitionsK, -+ typename Existing_::AccumulatorFragmentIterator, -+ typename Existing_::WarpTileIterator, -+ typename Existing_::SharedLoadIterator, -+ typename Existing_::Padding, -+ Existing_::kFragmentsPerIteration, -+ IterationsUnroll -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_workspace.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_workspace.h -new file mode 100644 -index 0000000..5034af3 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_workspace.h -@@ -0,0 +1,197 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs. -+ -+ This does not attempt to target any particular output layout. Instead, each threadblock -+ streams out its accumulator elements using 128b store operations. This assumes all threadblocks -+ have unique output tiles. -+ -+ The target data layout is: -+ - threadblock indices mapped to linear offsets as (m, n, k), where m is fastest-changing -+ - threadblock output space partitioned into warps; each warp's region is contiguous -+ - per-thread accumulators partitioned into 128b accesses -+ - output memory striped across the threads of a warp -+ -+ This enables very fast streaming of data, completely limited by the memory system. No predication -+ or data exchange is performed, and each threadblock is assumed to have a full region of memory -+ to write to. -+ -+ This epilogue establishes an upper bound for epilogue performance and is suitable for -+ reductions across the GEMM K dimension which require a separate workspace. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, ///< shape of accumulator tile (concept: MatrixShape) -+ int WarpCount, ///< number of warps -+ typename FragmentC_ ///< warp-level GEMM operator (concept: gemm::warp::Mma) -+> -+class EpilogueWorkspace { -+public: -+ -+ using Shape = Shape_; -+ using FragmentC = FragmentC_; -+ using ElementC = typename FragmentC::value_type; -+ -+ static int const kWarpCount = WarpCount; -+ -+ /// Optimize for 128b accesses -+ static int const kAccessSizeInBits = 128; -+ -+ /// Warp size from the perspective of memory operations -+ static int const kWarpSize = 32; -+ -+ /// Vector length of accesses -+ static int const kElementsPerAccess = -+ kAccessSizeInBits / sizeof_bits::value; -+ -+ /// Number of stores per thread -+ static int const kIterations = FragmentC::kElements / kElementsPerAccess; -+ -+ static_assert( -+ !(FragmentC::kElements % kElementsPerAccess), -+ "The number of accumulators must be divisible by the access size."); -+ -+ /// Total number of vectorized accesses in warp (in units of vector) -+ static int const kWarpAccesses = kIterations * kWarpSize; -+ -+ /// Total number of vectorized accesses in threadblock tile (in units of vector) -+ static int const kThreadblockAccesses = kWarpAccesses * kWarpCount; -+ -+ /// Parameters structure -+ struct Params { -+ -+ /// Pointer to C matrix -+ ElementC *ptr_C; -+ -+ /// Stride between tiles along the GEMM N dimension (in units of vectors) -+ int stride_n; -+ -+ /// Stride between tiles along the GEMM K dimension (in units of vectors) -+ int stride_k; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ ElementC *ptr_C, ///< Pointer to C matrix -+ int stride_n_, ///< Stride between tiles along the GEMM N dimension (in units of ElementC) -+ int stride_k_ ///< Stride between tiles along the GEMM K dimension (in units of ElementC) -+ ): -+ ptr_C(ptr_C), stride_n(stride_n_ / kElementsPerAccess), stride_k(stride_k_ / kElementsPerAccess) { -+ -+ } -+ }; -+ -+ /// Shared storage allocation needed by the epilogue -+ struct SharedStorage { -+ // Intentionally empty -+ }; -+ -+private: -+ -+ struct alignas((kAccessSizeInBits / 8)) AccessType { -+ Array storage; -+ }; -+ -+ /// Constant reference to parameters object -+ AccessType *pointer_; -+ -+ /// Stride between tiles along the n dimension (in vectors) -+ int stride_n_; -+ -+ /// Stride between tiles along the k dimension (in vectors) -+ int stride_k_; -+ -+public: -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ EpilogueWorkspace( -+ Params const ¶ms, ///< Host-constructable params object -+ SharedStorage &, ///< Shared storage object -+ int warp_idx, ///< ID of warp within threadblock -+ int lane_idx ///< Id of thread within warp -+ -+ ): -+ pointer_(reinterpret_cast(params.ptr_C)), -+ stride_n_(params.stride_n), -+ stride_k_(params.stride_k) { -+ -+ // Add per-thread offset -+ pointer_ += lane_idx + warp_idx * kWarpAccesses; -+ } -+ -+ /// Streams the result to global memory -+ CUTLASS_DEVICE -+ void operator()( -+ cutlass::gemm::GemmCoord problem_size, ///< Problem size of GEMM (units of ElementC) -+ cutlass::gemm::GemmCoord tb_tile_coord, ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) -+ FragmentC const &accum) { ///< Accumulator tile -+ -+ // Compute offset for entire threadblock (note, per-thread offset has been folded in already) -+ AccessType *pointer = pointer_ + -+ tb_tile_coord.m() * kThreadblockAccesses + -+ tb_tile_coord.n() * stride_n_ + -+ tb_tile_coord.k() * stride_k_; -+ -+ // Cast to vectorized view of accumulator fragments -+ AccessType const * src_pointer = reinterpret_cast(&accum); -+ -+ // Write out accumulators at full speed -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kIterations; ++i) { -+ pointer[i * kWarpSize] = src_pointer[i]; -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/interleaved_epilogue.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/interleaved_epilogue.h -new file mode 100644 -index 0000000..b4d1bbe ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/interleaved_epilogue.h -@@ -0,0 +1,407 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/vector.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/tensor_coord.h" -+#include "cutlass/aligned_buffer.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/regular_tile_iterator.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue_base_streamk.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Epilogue operator without splitk -+template < -+ /// Shape of threadblock tile (concept: GemmShape) -+ typename Shape_, -+ /// Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) -+ typename WarpMmaOperator_, -+ /// Number of partitions of the K dimension -+ int PartitionsK, -+ /// Tile iterator reading and writing output tensors -+ typename OutputTileIterator_, -+ /// Fragment iterator selecting accumulators -+ typename AccumulatorFragmentIterator_, -+ /// Output operator -+ typename OutputOp_, -+ /// Number of interleaved k -+ int InterleavedK> -+class InterleavedEpilogue : -+ public EpilogueBaseStreamK< -+ Shape_, -+ PartitionsK, -+ WarpMmaOperator_, -+ AccumulatorFragmentIterator_> -+{ -+public: -+ -+ using BaseStreamK = EpilogueBaseStreamK< -+ Shape_, -+ PartitionsK, -+ WarpMmaOperator_, -+ AccumulatorFragmentIterator_>; -+ -+ using Shape = Shape_; -+ using WarpMmaOperator = WarpMmaOperator_; -+ static int const kPartitionsK = PartitionsK; -+ using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; -+ using OutputTileIterator = OutputTileIterator_; -+ using OutputOp = OutputOp_; -+ -+ /// The complete warp-level accumulator tile -+ using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile; -+ -+ /// Fragment type used by the accumulator tile's fragment iterator -+ using AccumulatorFragment = typename AccumulatorFragmentIterator::Fragment; -+ -+ /// Accumulator element -+ using ElementAccumulator = typename AccumulatorTile::Element; -+ -+ /// Output element -+ using ElementOutput = typename OutputTileIterator::Element; -+ -+ /// Output access size -+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; -+ -+ /// Tensor reference to destination tensor -+ using TensorRef = typename OutputTileIterator::TensorRef; -+ -+ /// Tensor reference to sync tensor -+ using SyncTensorRef = -+ typename cutlass::TensorRef; -+ -+ /// Const tensor reference to source tensor -+ using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; -+ -+ /// Array type used to output -+ using OutputAccessType = Array; -+ -+ /// Array type used by output functor -+ using AccumulatorAccessType = -+ Array; -+ -+ /// Number of warps -+ using WarpCount = -+ gemm::GemmShape; -+ -+public: -+ -+ static_assert(OutputTileIterator::kElementsPerAccess, -+ "This must not be zero."); -+ -+ static_assert(!(OutputTileIterator::Fragment::kElements % -+ OutputTileIterator::kElementsPerAccess), -+ "Divisibility"); -+ -+public: -+ -+ /// Aspect for when epilogue source is not needed -+ struct SourceAspectNotNeeded -+ { -+ /// Constructor -+ CUTLASS_DEVICE -+ SourceAspectNotNeeded() -+ {} -+ -+ /// Invoke the output functor over each vector of output -+ CUTLASS_DEVICE -+ void apply_output_operator( -+ typename OutputTileIterator::Fragment &output_fragment, -+ OutputOp const &output_op, -+ typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment) -+ { -+ OutputAccessType *output_frag_ptr = -+ reinterpret_cast(&output_fragment); -+ -+ AccumulatorAccessType const *compute_frag_ptr = -+ reinterpret_cast(&aligned_accum_fragment); -+ -+ int const kOutputOpIterations = -+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kOutputOpIterations; ++i) -+ { -+ // Call the output operator -+ output_frag_ptr[i] = output_op(compute_frag_ptr[i]); -+ } -+ } -+ }; -+ -+ -+ /// Aspect for when epilogue source is needed -+ struct SourceAspectNeeded -+ { -+ OutputTileIterator source_iterator; -+ -+ typename OutputTileIterator::Fragment source_fragment; -+ -+ /// Invoke the output functor over each vector of output -+ CUTLASS_DEVICE -+ static void apply_output_operator( -+ typename OutputTileIterator::Fragment &output_fragment, -+ OutputOp const &output_op, -+ typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment, -+ typename OutputTileIterator::Fragment const &source_fragment) -+ { -+ OutputAccessType *output_frag_ptr = -+ reinterpret_cast(&output_fragment); -+ -+ AccumulatorAccessType const *compute_frag_ptr = -+ reinterpret_cast(&aligned_accum_fragment); -+ -+ OutputAccessType const *source_frag_ptr = -+ reinterpret_cast(&source_fragment); -+ -+ int const kOutputOpIterations = -+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kOutputOpIterations; ++i) -+ { -+ // Call the output operator -+ output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]); -+ } -+ } -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ SourceAspectNeeded(OutputTileIterator source_iterator) : -+ source_iterator(source_iterator) -+ { -+ source_fragment.clear(); -+ } -+ -+ /// Invoke the output functor over each vector of output -+ CUTLASS_DEVICE -+ void apply_output_operator( -+ typename OutputTileIterator::Fragment &output_fragment, -+ OutputOp const &output_op, -+ typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment) -+ { -+ // Load addend source fragment from global memory -+ source_iterator.load(source_fragment); -+ ++source_iterator; -+ -+ apply_output_operator(output_fragment, output_op, aligned_accum_fragment, source_fragment); -+ } -+ }; -+ -+ -+ /// Shared storage allocation needed by the epilogue -+ struct SharedStorage {}; -+ -+ -+public: -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ InterleavedEpilogue( -+ SharedStorage &shared_storage, ///< Shared storage object -+ int thread_idx, ///< ID of a thread within the threadblock -+ int warp_idx, ///< ID of warp within threadblock -+ int lane_idx) ///< Id of thread within warp -+ : -+ BaseStreamK(thread_idx) -+ {} -+ -+ -+ /// Aggregates the accumulator sets shared by peer blocks in the global workspace, -+ /// performing epilogue computations, writing to output -+ CUTLASS_DEVICE -+ void reduce( -+ int peer_idx_begin, -+ int peer_idx_end, -+ int reduce_fragment_idx, -+ void *element_workspace, -+ OutputOp const &output_op, ///< Output operator -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ OutputTileIterator source_iterator) ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) -+ { -+ // Redcuce peer accumulator fragments into one fragment -+ AccumulatorFragment accum_fragment; -+ BaseStreamK::reduce(accum_fragment, peer_idx_begin, peer_idx_end, reduce_fragment_idx, element_workspace); -+ -+ // Source-fragment data (zero-initialized for scenarios where the -+ // output operator allows us to skip loading it from global input) -+ typename OutputTileIterator::Fragment source_fragment; -+ source_fragment.clear(); -+ -+ if (output_op.is_source_needed()) -+ { -+ source_iterator += reduce_fragment_idx; -+ source_iterator.load(source_fragment); -+ } -+ -+ // Compute the output result -+ typename OutputTileIterator::Fragment output_fragment; -+ -+ // Apply the output operator -+ SourceAspectNeeded::apply_output_operator(output_fragment, output_op, accum_fragment, source_fragment); -+ -+ // Store the final result -+ destination_iterator += reduce_fragment_idx; -+ destination_iterator.store(output_fragment); -+ } -+ -+ -+ /// Perform the epilogue computations and stream the result to global memory. -+ CUTLASS_DEVICE -+ void operator()( -+ OutputOp const &output_op, ///< Output operator -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators) ///< Complete warp-level accumulator tile -+ { -+ operator()(output_op, destination_iterator, accumulators, SourceAspectNotNeeded()); -+ } -+ -+ -+ /// Perform the epilogue computations and stream the result to global memory. Implements -+ /// two alternative codepaths, depending on whether the output op requires addend data to be loaded. -+ CUTLASS_DEVICE -+ void operator()( -+ OutputOp const &output_op, ///< Output operator -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ OutputTileIterator source_iterator ) ///< Tile iterator for addend source -+ { -+ if (output_op.is_source_needed()) -+ { -+ operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator)); -+ } -+ else -+ { -+ operator()(output_op, destination_iterator, accumulators, SourceAspectNotNeeded()); -+ } -+ } -+ -+ -+ /// Perform the epilogue computations and stream the result to global memory. Implements a -+ /// single codepath, regardless of whether the output op requires addend data to be loaded -+ CUTLASS_DEVICE -+ void unified( -+ OutputOp const &output_op, ///< Output operator -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ OutputTileIterator source_iterator ) ///< Tile iterator for addend source -+ { -+ if (!output_op.is_source_needed()) -+ { -+ source_iterator.clear_mask(); -+ __syncthreads(); // Dummy (CUDA 11.0) -+ } -+ -+ operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator)); -+ } -+ -+ -+ /// Streams the result to global memory -+ template -+ CUTLASS_DEVICE -+ void operator()( -+ OutputOp const &output_op, ///< Output operator -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile -+ SourceAspect source) -+ { -+ // -+ // Iterator over warp-level accumulator fragment -+ // -+ -+ AccumulatorFragmentIterator accum_fragment_iterator(accumulators); -+ -+ // -+ // Iterate over accumulator tile -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { -+ -+ // -+ // Convert fragment -+ // -+ -+ typename AccumulatorFragmentIterator::Fragment accum_fragment; -+ -+ accum_fragment_iterator.load(accum_fragment); -+ ++accum_fragment_iterator; -+ -+ // -+ // Compute the output result -+ // -+ -+ typename OutputTileIterator::Fragment output_fragment; -+ source.apply_output_operator(output_fragment, output_op, accum_fragment); -+ -+ // -+ // Store the final result -+ // -+ -+ destination_iterator.set_iteration_index(iter); -+ destination_iterator.store(output_fragment); -+ ++destination_iterator; -+ } -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/output_iterator_parameter.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/output_iterator_parameter.h -new file mode 100644 -index 0000000..8cfba76 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/output_iterator_parameter.h -@@ -0,0 +1,92 @@ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/tensor_ref.h" -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+template< -+ typename TensorLayout_, ///! The original output tensor layout -+ typename OutputIteratorLayout_, ///! Layout used by epilogue output iterator -+ typename TensorRef_, ///! Input tensor to epilogue output iterator -+ conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) -+ typename ConvProblemSize_ ///! Convolutional operator on 2D or 3D problem -+> -+struct ConvOutputIteratorParameter { -+ -+ using TensorLayout = TensorLayout_; -+ using OutputIteratorLayout = OutputIteratorLayout_; -+ using OutputTensorCoord = typename OutputIteratorLayout::TensorCoord; -+ using TensorRef = TensorRef_; -+ static conv::Operator const kConvolutionalOperator = ConvOperator; -+ using ConvProblemSize = ConvProblemSize_; -+ -+ /// Wgrad stride idx for implicit gemm algorithm -+ // Conv2d row-major matrix (KxRSC) -+ // Conv3d row-major matrix (KxTRSC) -+ static int const kWgradStrideIdx = -+ platform::is_same::value ? 2 : 3; -+ -+ /// This chooses the appropriate stride element of the C tensor. -+ static int const kTensorStrideIdx = -+ (kConvolutionalOperator == conv::Operator::kWgrad ? kWgradStrideIdx : 0); -+ -+ -+ CUTLASS_HOST_DEVICE -+ static OutputIteratorLayout layout(const TensorRef & ref) { -+ return ref.stride(kTensorStrideIdx); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static OutputTensorCoord extent(ConvProblemSize problem_size) { -+ return conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(); -+ } -+ -+}; -+ -+ -+ -+template < -+ int InterleavedK, -+ typename TensorRef_, -+ conv::Operator ConvOperator, -+ typename ConvProblemSize_ -+> -+struct ConvOutputIteratorParameter< -+ layout::TensorNCxHWx, -+ layout::TensorNCxHWx, -+ TensorRef_, -+ ConvOperator, -+ ConvProblemSize_> -+{ -+ -+ using TensorLayout = typename layout::TensorNCxHWx; -+ using OutputIteratorLayout = typename layout::TensorNCxHWx; -+ using OutputTensorCoord = typename OutputIteratorLayout::TensorCoord; -+ using TensorRef = TensorRef_; -+ static conv::Operator const kConvolutionalOperator = ConvOperator; -+ using ConvProblemSize = ConvProblemSize_; -+ -+ CUTLASS_HOST_DEVICE -+ static OutputIteratorLayout layout(const TensorRef & ref) { -+ return ref.stride(); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static OutputTensorCoord extent(ConvProblemSize problem_size) { -+ return problem_size.output_extent(); -+ } -+ -+}; -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/output_tile_thread_map.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/output_tile_thread_map.h -new file mode 100644 -index 0000000..828b7a7 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/output_tile_thread_map.h -@@ -0,0 +1,626 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Metaprogram for determining the mapping of output elements to threads for epilogue tiles. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/fast_math.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tuple defining point in output tile -+template < -+ int Column, -+ int Row, -+ int Group, -+ int Cluster, -+ int Tile -+> -+struct OutputTileShape { -+ static int const kColumn = Column; -+ static int const kRow = Row; -+ static int const kGroup = Group; -+ static int const kCluster = Cluster; -+ static int const kTile = Tile; -+ -+ static int const kCount = kColumn * kRow * kGroup * kCluster * kTile; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct OutputTileThreadMapHelpers { -+ -+ /// Determines the iteration index of a vector access according to the thread map -+ CUTLASS_HOST_DEVICE -+ static void iteration_index( -+ int &column_idx, -+ int &row_idx, -+ int &group_idx, -+ int &cluster_idx, -+ int &tile_idx, -+ int iter_idx) { -+ -+ column_idx = iter_idx % Iterations::kColumn; -+ int residual = iter_idx / Iterations::kColumn; -+ -+ row_idx = residual % Iterations::kRow; -+ residual = residual / Iterations::kRow; -+ -+ group_idx = residual % Iterations::kGroup; -+ residual = residual / Iterations::kGroup; -+ -+ cluster_idx = residual % Iterations::kCluster; -+ tile_idx = residual / Iterations::kCluster; -+ } -+ -+ /// Computes the offset of a given vector access -+ CUTLASS_HOST_DEVICE -+ static MatrixCoord iteration_offset(int iter_idx) { -+ -+ int column_idx; -+ int row_idx; -+ int group_idx; -+ int cluster_idx; -+ int tile_idx; -+ -+ iteration_index(column_idx, row_idx, group_idx, cluster_idx, tile_idx, iter_idx); -+ -+ return -+ MatrixCoord( -+ row_idx * Delta::kRow + -+ group_idx * Delta::kGroup + -+ cluster_idx * Delta::kCluster + -+ tile_idx * Delta::kTile, -+ -+ column_idx * Delta::kColumn); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+template < -+ typename ThreadMap_, -+ typename Shape_, -+ typename Iterations_, -+ typename Delta_, -+ typename Count_ -+> -+struct OutputTileThreadMap : public OutputTileThreadMapHelpers { -+ -+ /// Conventional thread map (concept: ThreadMap) -+ using ThreadMap = ThreadMap_; -+ -+ /// Number of threads participating in the operation -+ static int const kThreads = ThreadMap::kThreads; -+ -+ /// Number of scalar elements per access -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ -+ /// Shape of the tile -+ using Shape = Shape_; -+ -+ /// Iterations performed by each thread -+ using Iterations = Iterations_; -+ -+ /// Delta between accesses -+ using Delta = Delta_; -+ -+ /// Number of iterator iterations -+ using Count = Count_; -+ -+ /// Initial offset function -+ CUTLASS_HOST_DEVICE -+ static MatrixCoord initial_offset(int thread_idx) { -+ -+ using Index = typename layout::PitchLinearCoord::Index; -+ -+ layout::PitchLinearCoord coord = ThreadMap::initial_offset(thread_idx); -+ -+ Index cluster = coord.strided() / (Shape::kGroup * Shape::kRow); -+ Index cluster_residual = coord.strided() % (Shape::kGroup * Shape::kRow); -+ -+ Index group = cluster_residual / (Shape::kRow); -+ Index row = cluster_residual % (Shape::kRow); -+ -+ return MatrixCoord{ -+ row + group * Shape::kRow * Count::kRow -+ + cluster * Shape::kGroup * Count::kGroup * Shape::kRow * Count::kRow, -+ coord.contiguous() -+ }; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+/// RowArrangement determines how one or more warps cover a region of consecutive rows. -+template < -+ typename Shape, -+ int WarpsRemaining, -+ int ElementsPerAccess, -+ int ElementSize, -+ bool Is2dTile -+> -+struct RowArrangement; -+ -+/// RowArrangement in which each warp's access is a 1D tiled arrangement. -+template < -+ typename Shape, -+ int WarpsRemaining, -+ int ElementsPerAccess, -+ int ElementSize -+> -+struct RowArrangement { -+ static int const kWarpSize = 32; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static int const kElementSize = ElementSize; -+ -+ static int const kIterationsRow = 1; -+ static int const kDeltaRow = 1; -+ static int const kIterationsColumn = Shape::kColumn / kElementsPerAccess / kWarpSize; -+ static int const kDeltaColumn = kWarpSize * kElementsPerAccess; -+ -+ static int const kAccessWidth = kWarpSize; -+ static int const kAccessRows = 1; -+ static int const kWarpPartitionsRow = 1; -+ static int const kWarpPartitionsColumn = WarpsRemaining; -+}; -+ -+/// RowArrangement in which each warp's access is a 2D tiled arrangement. -+template < -+ typename Shape, -+ int WarpsRemaining, -+ int ElementsPerAccess, -+ int ElementSize -+> -+struct RowArrangement { -+ -+ static int const kMemoryAccessSize = 256; // Preferred access size -+ static int const kWarpSize = 32; -+ -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static int const kElementSize = ElementSize; -+ -+ struct Detail { -+ static int const kShapeRow = Shape::kRow / WarpsRemaining; -+ static int const kShapeWidth = Shape::kColumn / kElementsPerAccess; -+ -+ static int const kTargetMemoryAccessWidth = -+ kMemoryAccessSize / (kElementsPerAccess * kElementSize / 8); -+ -+ static int const kTargetAccessRows = kWarpSize / kTargetMemoryAccessWidth; -+ }; -+ -+ static int const kAccessWidth = -+ (Detail::kTargetAccessRows > Detail::kShapeRow ? -+ kWarpSize / Detail::kShapeRow -+ : const_min( -+ Detail::kShapeWidth, -+ const_min(kWarpSize, kMemoryAccessSize / (kElementsPerAccess * kElementSize / 8)) -+ )); -+ -+ static int const kAccessRows = -+ (Detail::kTargetAccessRows > Detail::kShapeRow ? -+ Detail::kShapeRow -+ : const_min(Shape::kRow, kWarpSize / kAccessWidth)); -+ -+ static int const kIterationsRow = Detail::kShapeRow / kAccessRows; -+ static int const kDeltaRow = kAccessRows; -+ -+ static int const kIterationsColumn = Detail::kShapeWidth / kAccessWidth; -+ static int const kDeltaColumn = kAccessWidth * kElementsPerAccess; -+ -+ static_assert( kAccessWidth * kElementsPerAccess <= Shape::kColumn, "Accessing too many elements per access"); -+ static_assert( kIterationsColumn > 0, "Iteration Count Column must be > 0" ); -+ static_assert( kIterationsRow > 0, "Iteration Count Row must be > 0" ); -+ -+ static int const kWarpPartitionsRow = 1; -+ static int const kWarpPartitionsColumn = 1; -+}; -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template metaprogram for partitioning a 4D space across warps to achieve several performance -+/// objectives: -+/// -+/// - coalesced memory accesses in units of 128 Byte lines -+/// - minimal address arithmetic -+/// - minimal predicate calculations -+/// -+template < -+ typename Shape_, -+ typename Count_, -+ int Threads, -+ int ElementsPerAccess, -+ int ElementSize -+> -+struct OutputTileOptimalThreadMap { -+ -+ using Shape = Shape_; -+ using Count = Count_; -+ -+ static int const kWarpSize = 32; -+ static int const kThreads = Threads; -+ static int const kWarpCount = kThreads / kWarpSize; -+ -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static int const kElementSize = ElementSize; -+ -+ // -+ // Metaprogram computation -+ // -+ -+ struct Detail { -+ -+ // Clusters -+ static int const kIterationsCluster = -+ ((Shape::kCluster > kWarpCount) ? -+ Shape::kCluster / kWarpCount -+ : 1); -+ -+ static int const kDeltaCluster = -+ ((Shape::kCluster > kWarpCount) ? -+ Shape::kRow * Count::kRow * Shape::kGroup * Count::kGroup * Shape::kCluster / kIterationsCluster -+ : 1); -+ -+ static int const kCompactedDeltaCluster = -+ ((Shape::kCluster > kWarpCount) ? -+ Shape::kRow * Shape::kGroup * Shape::kCluster / kIterationsCluster -+ : 1); -+ -+ static int const kWarpPartitionsCluster = -+ ((Shape::kCluster > kWarpCount) ? -+ kWarpCount -+ : kWarpCount / Shape::kCluster); -+ -+ static int const kWarpsRemainingForGroups = -+ ((Shape::kCluster > kWarpCount) ? 1 : kWarpCount / Shape::kCluster); -+ -+ // Groups -+ static int const kIterationsGroup = -+ ((Shape::kGroup > kWarpsRemainingForGroups) ? -+ Shape::kGroup / kWarpsRemainingForGroups -+ : 1); -+ -+ static int const kDeltaGroup = -+ ((Shape::kGroup > kWarpsRemainingForGroups) ? -+ Shape::kRow * Count::kRow * Shape::kGroup / kIterationsGroup -+ : 1); -+ -+ static int const kCompactedDeltaGroup = -+ ((Shape::kGroup > kWarpsRemainingForGroups) ? -+ Shape::kRow * Shape::kGroup / kIterationsGroup -+ : 1); -+ -+ static int const kWarpPartitionsGroup = -+ ((Shape::kGroup > kWarpsRemainingForGroups) ? -+ 1 -+ : kWarpsRemainingForGroups / Shape::kGroup); -+ -+ static int const kWarpsRemainingForRows = -+ ((Shape::kGroup > kWarpsRemainingForGroups) ? -+ 1 -+ : kWarpsRemainingForGroups / Shape::kGroup); -+ -+ // Rows -+ using RowArrangement = detail::RowArrangement< -+ Shape, -+ kWarpsRemainingForRows, -+ kElementsPerAccess, -+ kElementSize, -+ (Shape::kRow > kWarpsRemainingForRows) -+ >; -+ -+ // Warp partitions -+ using WarpPartitions = OutputTileShape< -+ RowArrangement::kWarpPartitionsColumn, -+ RowArrangement::kWarpPartitionsRow, -+ kWarpPartitionsGroup, -+ kWarpPartitionsCluster, -+ 1>; -+ -+ static int const kAccessWidth = RowArrangement::kAccessWidth; -+ static int const kAccessRows = RowArrangement::kAccessRows; -+ }; -+ -+ // -+ // Output -+ // -+ -+ using Iterations = OutputTileShape< -+ Detail::RowArrangement::kIterationsColumn, -+ Detail::RowArrangement::kIterationsRow, -+ Detail::kIterationsGroup, -+ Detail::kIterationsCluster, -+ 1>; -+ -+ using Delta = OutputTileShape< -+ Detail::RowArrangement::kDeltaColumn, -+ Detail::RowArrangement::kDeltaRow, -+ Detail::kDeltaGroup, -+ Detail::kDeltaCluster, -+ 1>; -+ -+ /// Initial offset function -+ CUTLASS_DEVICE -+ static MatrixCoord initial_offset(int thread_idx) { -+ -+ int warp_idx = __shfl_sync(0xffffffff, thread_idx / kWarpSize, 0); -+ int lane_idx = thread_idx % kWarpSize; -+ -+ // Compute warp location -+ int cluster_idx = warp_idx / Detail::WarpPartitions::kCluster; -+ int residual_cluster = warp_idx % Detail::WarpPartitions::kCluster; -+ -+ int group_idx = residual_cluster / Detail::WarpPartitions::kGroup; -+ int residual_group = residual_cluster % Detail::WarpPartitions::kGroup; -+ -+ int row_idx = residual_group / Detail::WarpPartitions::kRow; -+ int col_idx = residual_group % Detail::WarpPartitions::kRow; -+ -+ // Compute per-lane offset -+ int lane_row_offset = lane_idx / Detail::kAccessWidth; -+ int lane_col_offset = lane_idx % Detail::kAccessWidth; -+ -+ // Compute coordinate in output space -+ int cluster_offset = cluster_idx * Shape::kRow * Count::kRow * Shape::kGroup * Count::kGroup; -+ int group_offset = group_idx * Shape::kRow * Count::kRow; -+ int row_offset = row_idx * Iterations::kRow * Detail::kAccessRows; -+ int column_offset = col_idx * Iterations::kColumn * Detail::kAccessWidth * kElementsPerAccess; -+ -+ return MatrixCoord( -+ cluster_offset + group_offset + row_offset + lane_row_offset, -+ column_offset + lane_col_offset * kElementsPerAccess -+ ); -+ } -+ -+ /// Computes the offset of a given vector access -+ CUTLASS_HOST_DEVICE -+ static MatrixCoord iteration_offset(int iter_idx) { -+ return OutputTileThreadMapHelpers::iteration_offset(iter_idx); -+ } -+ -+ /// Compacted thread map in which the 4D region is contiguous -+ struct CompactedThreadMap { -+ -+ -+ using Shape = Shape_; -+ -+ using TileShape = MatrixShape< -+ Shape::kTile * Shape::kCluster * Shape::kGroup * Shape::kRow, -+ Shape::kColumn -+ >; -+ -+ using Iterations = OutputTileShape< -+ Detail::RowArrangement::kIterationsColumn, -+ Detail::RowArrangement::kIterationsRow, -+ Detail::kIterationsGroup, -+ Detail::kIterationsCluster, -+ 1>; -+ -+ using Delta = OutputTileShape< -+ Detail::RowArrangement::kDeltaColumn, -+ Detail::RowArrangement::kDeltaRow, -+ Detail::kCompactedDeltaGroup, -+ Detail::kCompactedDeltaCluster, -+ 1>; -+ -+ /// Number of elements within each vector access -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ /// Number of threads -+ static int const kThreads = Threads; -+ -+ /// Function to compute each thread's initial offset -+ CUTLASS_DEVICE -+ static MatrixCoord initial_offset(int thread_idx) { -+ -+ int warp_idx = __shfl_sync(0xffffffff, thread_idx / kWarpSize, 0); -+ int lane_idx = thread_idx % kWarpSize; -+ -+ // Compute warp location -+ int cluster_idx = warp_idx / Detail::WarpPartitions::kCluster; -+ int residual_cluster = warp_idx % Detail::WarpPartitions::kCluster; -+ -+ int group_idx = residual_cluster / Detail::WarpPartitions::kGroup; -+ int residual_group = residual_cluster % Detail::WarpPartitions::kGroup; -+ -+ int row_idx = residual_group / Detail::WarpPartitions::kRow; -+ int col_idx = residual_group % Detail::WarpPartitions::kRow; -+ -+ // Compute per-lane offset -+ int lane_row_offset = lane_idx / Detail::kAccessWidth; -+ int lane_col_offset = lane_idx % Detail::kAccessWidth; -+ -+ // Compute coordinate in output space -+ int cluster_offset = cluster_idx * Shape::kRow * Shape::kGroup; -+ int group_offset = group_idx * Shape::kRow; -+ int row_offset = row_idx * Iterations::kRow * Detail::kAccessRows; -+ int column_offset = col_idx * Iterations::kColumn * Detail::kAccessWidth * kElementsPerAccess; -+ -+ MatrixCoord coord( -+ cluster_offset + group_offset + row_offset + lane_row_offset, -+ column_offset + lane_col_offset * kElementsPerAccess -+ ); -+ -+ return coord; -+ } -+ }; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template metaprogram for partitioning a 3D interleaved layout across warps -+/// to achieve several performance objectives: -+/// -+/// - coalesced memory accesses in units of 64 Byte lines -+/// - minimal address arithmetic -+/// - minimal predicate calculations -+/// -+template -+struct InterleavedOutputTileThreadMap { -+ using WarpCount = WarpCount_; -+ -+ static int const kWarpSize = 32; -+ static int const kThreads = Threads; -+ static int const kWarpCount = kThreads / kWarpSize; -+ -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static int const kElementSize = ElementSize; -+ -+ // -+ // Metaprogram computation -+ // -+ -+ struct Detail {}; -+ -+ // -+ // Output -+ // -+ -+ using Iterations = Iterations_; -+ -+ using Delta = layout::PitchLinearShape; -+ -+ /// Initial offset function -+ CUTLASS_HOST_DEVICE -+ static layout::PitchLinearCoord initial_offset(int thread_idx) { -+ int warp_idx = thread_idx / kWarpSize; -+ int lane_idx = thread_idx % kWarpSize; -+ -+ // Compute warp location -+ layout::PitchLinearCoord warp_footprint{ -+ Delta::kContiguous * Iterations::kContiguous, -+ Delta::kStrided * Iterations::kStrided}; -+ -+ layout::PitchLinearCoord warp_offset{warp_idx % WarpCount::kContiguous, -+ warp_idx / WarpCount::kContiguous}; -+ -+ // Compute per-lane offset -+ layout::PitchLinearCoord thread_offset_in_warp{ -+ lane_idx * kElementsPerAccess, 0}; -+ -+ layout::PitchLinearCoord thread_offset_in_threadblock_tile = -+ warp_footprint * warp_offset + thread_offset_in_warp; -+ -+ return thread_offset_in_threadblock_tile; -+ } -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template metaprogram for partitioning a 4D interleaved layout across warps -+/// to achieve several performance objectives: -+/// -+/// - coalesced memory accesses in units of 64 Byte lines -+/// - minimal address arithmetic -+/// - minimal predicate calculations -+/// -+template -+struct InterleavedConvOutputTileThreadMap { -+ using WarpCount = WarpCount_; -+ -+ static int const kWarpSize = 32; -+ static int const kThreads = Threads; -+ static int const kWarpCount = kThreads / kWarpSize; -+ -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static int const kElementSize = ElementSize; -+ -+ // -+ // Metaprogram computation -+ // -+ -+ struct Detail {}; -+ -+ // -+ // Output -+ // -+ -+ using Iterations = Iterations_; -+ -+ using Delta = MatrixShape; -+ -+ /// Initial offset function -+ CUTLASS_HOST_DEVICE -+ static MatrixCoord initial_offset(int thread_idx) { -+ int warp_idx = thread_idx / kWarpSize; -+ int lane_idx = thread_idx % kWarpSize; -+ -+ // Compute warp location -+ MatrixCoord warp_footprint{ -+ Delta::kRow * Iterations::kRow, -+ Delta::kColumn * Iterations::kColumn, -+ }; -+ -+ MatrixCoord warp_offset{warp_idx % WarpCount::kRow, -+ warp_idx / WarpCount::kRow}; -+ -+ // Compute per-lane offset -+ MatrixCoord thread_offset_in_warp{lane_idx / 4, -+ (lane_idx % 4) * kElementsPerAccess}; -+ -+ MatrixCoord thread_offset_in_threadblock_tile = -+ warp_footprint * warp_offset + thread_offset_in_warp; -+ -+ return thread_offset_in_threadblock_tile; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h -new file mode 100644 -index 0000000..685b6bb ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h -@@ -0,0 +1,1351 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/permute.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator used to load and store output tile from global memory in epilogue. -+/// -+/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator -+/// -+template < -+ typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) -+ typename Element_, ///< Element data type -+ bool ScatterD = false, ///< Scatter D operand or not -+ typename PermuteDLayout = layout::NoPermute, ///< Permute D operand or not -+ bool UseCUDAStore = false -+> -+class PredicatedTileIterator { -+public: -+ using ThreadMap = ThreadMap_; -+ using Shape = typename ThreadMap::Shape; -+ -+ using Element = Element_; -+ -+ using Layout = layout::RowMajor; -+ using TensorRef = TensorRef; -+ using ConstTensorRef = typename TensorRef::ConstTensorRef; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using TensorCoord = MatrixCoord; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ static int const kThreads = ThreadMap::kThreads; -+ static int const kIterations = ThreadMap::Count::kTile; -+ -+ static_assert( ThreadMap::Iterations::kRow > 0,"ThreadMap::Iterations::kRow must be > 0"); -+ static_assert( ThreadMap::Iterations::kGroup > 0,"ThreadMap::Iterations::kGroup must be > 0"); -+ static_assert( ThreadMap::Iterations::kCluster > 0,"ThreadMap::Iterations::kCluster must be > 0"); -+ static_assert( ThreadMap::Iterations::kColumn > 0,"ThreadMap::Iterations::kColumn must be > 0"); -+ -+ /// Fragment object -+ using Fragment = Array< -+ Element, -+ ThreadMap::Iterations::kColumn * -+ ThreadMap::Iterations::kRow * -+ ThreadMap::Iterations::kGroup * -+ ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>; -+ -+ /// Memory access size -+ using AccessType = AlignedArray; -+ -+ // -+ // Parameters struct -+ // -+ -+ /// Uses a non-template class -+ struct Params : PredicatedTileIteratorParams { -+ using Base = PredicatedTileIteratorParams; -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout): -+ PredicatedTileIteratorParams( -+ layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, -+ make_OutputTileThreadMapDesc() -+ ) -+ { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Base const &base) : -+ Base(base) { } -+ }; -+ -+ /// Mask object -+ struct Mask { -+ -+ static int const kCount = ThreadMap::Iterations::kColumn; -+ -+ /// Predicate state -+ bool predicates[kCount]; -+ -+ // -+ // Mask -+ // -+ CUTLASS_HOST_DEVICE -+ Mask() { -+ enable(); -+ } -+ -+ ///< Efficiently disables all accesses guarded by mask -+ CUTLASS_HOST_DEVICE void clear() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ predicates[i] = false; -+ } -+ } -+ -+ ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask -+ CUTLASS_DEVICE void enable() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ predicates[i] = true; -+ } -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters structure containing reference and precomputed state. -+ PredicatedTileIteratorParams params_; -+ -+ /// Byte-level pointer. This pointer is usually for both load() and store(), unless PermuteD is performed. When having PermuteD, byte_pointer_ is only for load(). -+ uint8_t *byte_pointer_; -+ -+ /// Byte-level pointer for store(). Due to PermuteD Op, store_byte_pointer_ may be with different address computation compared to byte_pointer_. -+ uint8_t *store_byte_pointer_; -+ -+ /// Array of boolean values to contain steady-state predicates -+ Mask mask_; -+ -+ /// Extent of the matrix tile in rows -+ Index extent_row_; -+ -+ /// Extent of the matrix tile in rows -+ Index extent_column_; -+ -+ /// A thread's starting row position (assuming steady-state predicates have been computed) -+ Index thread_start_row_; -+ -+ /// A thread's starting column -+ Index thread_start_column_; -+ -+ /// Internal state counter -+ int state_[3]; -+ -+ /// Scatter indices -+ int const *indices_; -+ -+ /// Whether to perform Permute Op -+ bool PermuteD; -+ /// PermuteDLayout -+ mutable PermuteDLayout permute_layout_; -+ -+ // -+ // Static asserts about internal strides -+ // -+ -+ static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); -+ static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); -+ static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides"); -+ -+private: -+ -+ // -+ // Methods -+ // -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ PredicatedTileIterator( -+ PredicatedTileIteratorParams const & params, -+ Element *pointer, -+ TensorCoord extent, -+ int thread_idx, -+ TensorCoord threadblock_offset = TensorCoord(), -+ int const *indices = nullptr -+ ): -+ params_(params), indices_(indices) -+ { -+ -+ TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; -+ -+ extent_row_ = extent.row(); -+ extent_column_ = extent.column(); -+ -+ thread_start_row_ = thread_offset.row(); -+ thread_start_column_ = thread_offset.column(); -+ -+ // Initialize predicates -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { -+ -+ mask_.predicates[c] = ((thread_offset.column() -+ + ThreadMap::Delta::kColumn * c) < extent.column()); -+ } -+ -+ // Null pointer performs no accesses -+ if (!pointer) { -+ mask_.clear(); -+ } -+ -+ if (ScatterD && !indices) { -+ mask_.clear(); -+ } -+ -+ // Initialize byte_pointer_ -+ byte_pointer_ = reinterpret_cast(pointer) + -+ LongIndex(thread_offset.row()) * LongIndex(params_.stride) + -+ LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; -+ -+ if (ScatterD) { -+ byte_pointer_ = reinterpret_cast(pointer) + -+ LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; -+ } -+ -+ // store_byte_pointer_ is set to be the same with byte_pointer_ unless PermuteD is used. -+ store_byte_pointer_ = byte_pointer_; -+ -+ // Initialize PermuteD. If PermuteD is true, store_byte_pointer_ is initialized accordingly. -+ if (platform::is_same::value) { -+ PermuteD = false; -+ }else{ -+ PermuteD = true; -+ store_byte_pointer_ = reinterpret_cast(pointer); -+ permute_layout_ = PermuteDLayout(extent, -+ params_.stride * kElementsPerAccess / sizeof(AccessType)); -+ } -+ -+ // Initialize internal state counter -+ state_[0] = state_[1] = state_[2] = 0; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ store_byte_pointer_ += pointer_offset * sizeof_bits::value / 8; -+ byte_pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment &frag, int64_t byte_offset) const { -+ -+ uint8_t *byte_pointer = byte_pointer_; -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ -+ int frag_row_idx = -+ (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ int row_offset = row * ThreadMap::Delta::kRow -+ + group * ThreadMap::Delta::kGroup -+ + cluster * ThreadMap::Delta::kCluster; -+ -+ bool row_guard = ((row_offset + thread_start_row_) < extent_row_); -+ -+ AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); -+ -+ if (ScatterD && row_guard) { -+ assert(indices_); -+ -+ memory_pointer = reinterpret_cast(byte_pointer + byte_offset + -+ LongIndex(indices_[row_offset + thread_start_row_]) * LongIndex(params_.stride)); -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { -+ -+ bool guard = row_guard && mask_.predicates[column]; -+ -+ cutlass::arch::global_load< -+ AccessType, -+ sizeof(AccessType) -+ >( -+ frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + -+ column], -+ (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / -+ kElementsPerAccess], -+ guard); -+ } -+ -+ if (row + 1 < ThreadMap::Iterations::kRow) { -+ if (!ScatterD) { -+ byte_pointer += params_.increment_row; -+ } -+ } -+ } -+ -+ if (group + 1 < ThreadMap::Iterations::kGroup) { -+ byte_pointer += params_.increment_group; -+ } -+ } -+ -+ if (cluster + 1 < ThreadMap::Iterations::kCluster) { -+ byte_pointer += params_.increment_cluster; -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) const { -+ -+ load_with_byte_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const &frag, int64_t byte_offset) const { -+ uint8_t *byte_pointer = store_byte_pointer_; -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ -+ int frag_row_idx = -+ (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ int row_offset = row * ThreadMap::Delta::kRow -+ + group * ThreadMap::Delta::kGroup -+ + cluster * ThreadMap::Delta::kCluster; -+ -+ bool row_guard = ((row_offset + thread_start_row_) < extent_row_); -+ -+ AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); -+ -+ if (ScatterD && row_guard) { -+ assert(indices_); -+ -+ memory_pointer = reinterpret_cast(byte_pointer + byte_offset + -+ LongIndex(indices_[row_offset + thread_start_row_]) * LongIndex(params_.stride)); -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { -+ -+ bool guard = row_guard && mask_.predicates[column]; -+ -+ int col_offset = column * ThreadMap::Delta::kColumn; -+ -+ if (PermuteD) { -+ int col = col_offset + thread_start_column_; -+ int row = row_offset + thread_start_row_; -+ -+ TensorCoord init_coord(row, col); -+ -+ // Locate memory_pointer -+ memory_pointer = reinterpret_cast(byte_pointer + byte_offset -+ + permute_layout_(init_coord) * sizeof(AccessType) / kElementsPerAccess); -+ } -+ -+ if (UseCUDAStore) { -+ if (guard) { -+ memory_pointer[0] = -+ frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column]; -+ } -+ } else { -+ cutlass::arch::global_store( -+ frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], -+ (void *)&memory_pointer[0], -+ guard); -+ } -+ -+ if (!PermuteD) { -+ memory_pointer += (ThreadMap::Delta::kColumn / kElementsPerAccess); -+ } -+ } -+ -+ if (row + 1 < ThreadMap::Iterations::kRow) { -+ if (!ScatterD && !PermuteD) { -+ byte_pointer += params_.increment_row; -+ } -+ } -+ } -+ -+ if (group + 1 < ThreadMap::Iterations::kGroup) { -+ byte_pointer += params_.increment_group; -+ } -+ } -+ -+ if (cluster + 1 < ThreadMap::Iterations::kCluster) { -+ byte_pointer += params_.increment_cluster; -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) const { -+ -+ store_with_byte_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void downsample_load_with_byte_offset(Fragment &frag, int64_t byte_offset, int convolution_P, int convolution_Q, int add_P, int add_Q, int problem_N) const { -+ -+ uint8_t *byte_pointer = byte_pointer_; -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ -+ int frag_row_idx = -+ (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ int row_offset = row * ThreadMap::Delta::kRow -+ + group * ThreadMap::Delta::kGroup -+ + cluster * ThreadMap::Delta::kCluster; -+ -+ bool row_guard = ((row_offset + thread_start_row_) < extent_row_); -+ -+ int output_row = row_offset + thread_start_row_; -+ int output_N = output_row / (convolution_P * convolution_Q); -+ int output_PQ = output_row % (convolution_P * convolution_Q); -+ int output_P = output_PQ / convolution_Q; -+ int output_Q = output_PQ % convolution_Q; -+ -+ int input_row = output_N * 2 * convolution_P * 2 * convolution_Q + -+ (2 * output_P + add_P) * 2 * convolution_Q + 2 * output_Q + add_Q; -+ -+ int64_t byte_offset = (input_row-output_row)*problem_N*sizeof(float); -+ -+ AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { -+ -+ bool guard = row_guard && mask_.predicates[column]; -+ -+ cutlass::arch::global_load< -+ AccessType, -+ sizeof(AccessType) -+ >( -+ frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + -+ column], -+ (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / -+ kElementsPerAccess], -+ guard); -+ } -+ -+ if (row + 1 < ThreadMap::Iterations::kRow) { -+ byte_pointer += params_.increment_row; -+ } -+ } -+ -+ if (group + 1 < ThreadMap::Iterations::kGroup) { -+ byte_pointer += params_.increment_group; -+ } -+ } -+ -+ if (cluster + 1 < ThreadMap::Iterations::kCluster) { -+ byte_pointer += params_.increment_cluster; -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void upsample_load_with_byte_offset(Fragment &frag, int64_t byte_offset, int convolution_P, int convolution_Q, int add_P, int add_Q, int problem_N) const { -+ -+ uint8_t *byte_pointer = byte_pointer_; -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ -+ int frag_row_idx = -+ (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ int row_offset = row * ThreadMap::Delta::kRow -+ + group * ThreadMap::Delta::kGroup -+ + cluster * ThreadMap::Delta::kCluster; -+ -+ bool row_guard = ((row_offset + thread_start_row_) < extent_row_); -+ -+ int output_row = row_offset + thread_start_row_; -+ int output_N = output_row / (convolution_P * convolution_Q); -+ int output_PQ = output_row % (convolution_P * convolution_Q); -+ int output_P = output_PQ / convolution_Q; -+ int output_Q = output_PQ % convolution_Q; -+ int row_add_P = add_P; -+ int row_add_Q = add_Q; -+ if (output_P > convolution_P - 2) row_add_P = 0; -+ if (output_Q > convolution_Q - 2) row_add_Q = 0; -+ -+ int input_row = output_N * (convolution_P/2) * (convolution_Q/2) + -+ ((output_P + row_add_P)/2) * (convolution_Q/2) + (output_Q + row_add_Q)/2; -+ -+ int64_t byte_offset = (input_row-output_row)*problem_N*sizeof(float); -+ -+ AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { -+ -+ bool guard = row_guard && mask_.predicates[column]; -+ -+ cutlass::arch::global_load< -+ AccessType, -+ sizeof(AccessType) -+ >( -+ frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + -+ column], -+ (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / -+ kElementsPerAccess], -+ guard); -+ } -+ -+ if (row + 1 < ThreadMap::Iterations::kRow) { -+ byte_pointer += params_.increment_row; -+ } -+ } -+ -+ if (group + 1 < ThreadMap::Iterations::kGroup) { -+ byte_pointer += params_.increment_group; -+ } -+ } -+ -+ if (cluster + 1 < ThreadMap::Iterations::kCluster) { -+ byte_pointer += params_.increment_cluster; -+ } -+ } -+ } -+ -+ CUTLASS_DEVICE -+ MatrixCoord thread_start() const { -+ return MatrixCoord(thread_start_row_, thread_start_column_); -+ } -+ -+ /// Need to get the thread start row from the tile iterator -+ CUTLASS_DEVICE -+ int32_t thread_start_row() const { -+ return thread_start_row_; -+ } -+ -+ /// Need to get the thread start row from the tile iterator -+ CUTLASS_DEVICE -+ int32_t thread_start_column() const { -+ return thread_start_column_; -+ } -+ -+ /// Extent of the matrix in rows -+ CUTLASS_DEVICE -+ Index extent_row() const { -+ return extent_row_; -+ } -+ -+ /// Extent of the matrix in columns -+ CUTLASS_DEVICE -+ Index extent_column() const { -+ return extent_column_; -+ } -+ -+ /// Advances to the next position to load or store -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator &operator++() { -+ -+ ++state_[0]; -+ -+ if (!ScatterD && !PermuteD) { -+ store_byte_pointer_ += params_.advance_row; -+ } -+ -+ if (!ScatterD) { -+ byte_pointer_ += params_.advance_row; -+ } -+ -+ thread_start_row_ += ThreadMap::Shape::kRow; -+ -+ if (state_[0] == ThreadMap::Count::kRow) { -+ -+ state_[0] = 0; -+ ++state_[1]; -+ byte_pointer_ += params_.advance_group; -+ store_byte_pointer_ += params_.advance_group; -+ -+ thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * -+ ThreadMap::Shape::kRow * ThreadMap::Count::kRow; -+ -+ if (state_[1] == ThreadMap::Count::kGroup) { -+ -+ state_[1] = 0; -+ ++state_[2]; -+ byte_pointer_ += params_.advance_cluster; -+ store_byte_pointer_ += params_.advance_cluster; -+ -+ thread_start_row_ += ThreadMap::Count::kGroup * -+ ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; -+ -+ if (state_[2] == ThreadMap::Count::kCluster) { -+ state_[2] = 0; -+ byte_pointer_ += params_.advance_tile; -+ store_byte_pointer_ += params_.advance_tile; -+ -+ thread_start_row_ += ThreadMap::Shape::kGroup * ThreadMap::Shape::kRow -+ * ThreadMap::Shape::kCluster * ThreadMap::Shape::kTile; -+ } -+ } -+ } -+ -+ return *this; -+ } -+ -+ /// Advances a number of positions to load or store -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator &operator+=(int increment) -+ { -+ // Row -+ state_[0] += increment; -+ int increment_row = state_[0] / ThreadMap::Count::kRow; -+ state_[0] = state_[0] % ThreadMap::Count::kRow; -+ -+ byte_pointer_ += (params_.advance_row * increment); -+ store_byte_pointer_ += (params_.advance_row * increment); -+ thread_start_row_ += (ThreadMap::Shape::kRow * increment); -+ -+ // Group -+ state_[1] += increment_row; -+ int increment_group = state_[1] / ThreadMap::Count::kGroup; -+ state_[1] = state_[1] % ThreadMap::Count::kGroup; -+ -+ byte_pointer_ += (params_.advance_group * increment_row); -+ store_byte_pointer_ += (params_.advance_group * increment_row); -+ thread_start_row_ += -+ (ThreadMap::Shape::kGroup - 1) * -+ ThreadMap::Shape::kRow * -+ ThreadMap::Count::kRow * -+ increment_row; -+ -+ -+ // Cluster -+ state_[2] += increment_group; -+ int increment_cluster = state_[2] / ThreadMap::Count::kCluster; -+ state_[2] = state_[2] % ThreadMap::Count::kCluster; -+ -+ byte_pointer_ += (params_.advance_cluster * increment_group); -+ store_byte_pointer_ += (params_.advance_cluster * increment_group); -+ thread_start_row_ += -+ ThreadMap::Count::kGroup * -+ ThreadMap::Shape::kGroup * -+ ThreadMap::Count::kRow * -+ ThreadMap::Shape::kRow * -+ increment_group; -+ -+ // Tile -+ byte_pointer_ += (params_.advance_tile * increment_cluster); -+ store_byte_pointer_ += (params_.advance_tile * increment_cluster); -+ thread_start_row_ += -+ ThreadMap::Shape::kGroup * -+ ThreadMap::Shape::kRow * -+ ThreadMap::Shape::kCluster * -+ ThreadMap::Shape::kTile * -+ increment_cluster; -+ -+ return *this; -+ } -+ -+ ///< Efficiently disables all accesses guarded by mask -+ CUTLASS_DEVICE void clear_mask() { -+ mask_.clear(); -+ } -+ -+ ///< Efficiently enables all accesses guarded by mask -+ CUTLASS_DEVICE void enable_mask() { -+ mask_.enable(); -+ } -+ -+ ///< Sets the mask -+ CUTLASS_DEVICE void get_mask(Mask &mask) const { -+ mask = mask_; -+ } -+ -+ ///< Sets the mask -+ CUTLASS_DEVICE void set_mask(Mask const &mask) { -+ mask_ = mask; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator used to load output tile from global memory in epilogue. -+/// -+/// Satisfies: ReadableTileIterator | InterleavedPredicatedTileIterator | ForwardTileIterator -+/// -+template < -+ typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) -+ typename Element_, ///< Element data type -+ int InterleavedN ///< Number of Interleaved N -+> -+class InterleavedPredicatedTileIterator { -+public: -+ using ThreadMap = ThreadMap_; -+ -+ using Element = Element_; -+ -+ using Layout = layout::ColumnMajorInterleaved; -+ using TensorRef = TensorRef; -+ using ConstTensorRef = typename TensorRef::ConstTensorRef; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using TensorCoord = layout::PitchLinearCoord; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ static int const kThreads = ThreadMap::kThreads; -+ static int const kIterations = ThreadMap::Iterations::kCount; -+ -+ /// Fragment object -+ using Fragment = Array; -+ -+ /// Memory access size -+ using AccessType = AlignedArray; -+ -+ /// Uses a non-template class -+ struct Params : InterleavedPredicatedTileIteratorParams { -+ using Base = InterleavedPredicatedTileIteratorParams; -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout): -+ Base( -+ layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, -+ make_InterleavedPredicatedTileIteratorDesc() -+ ) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Base const &base) : -+ Base(base) { } -+ }; -+ -+ /// Mask object -+ struct Mask { -+ static int const kCount = (ThreadMap::Iterations::kContiguous < 8) -+ ? 8 -+ : ThreadMap::Iterations::kContiguous; -+ -+ /// Predicate state -+ bool predicates[kCount]; -+ -+ // -+ // Mask -+ // -+ CUTLASS_HOST_DEVICE -+ Mask() { -+ enable(); -+ } -+ -+ ///< Efficiently disables all accesses guarded by mask -+ CUTLASS_HOST_DEVICE void clear() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ predicates[i] = false; -+ } -+ } -+ -+ ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask -+ CUTLASS_DEVICE void enable() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ predicates[i] = true; -+ } -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters structure containing reference and precomputed state. -+ Params params_; -+ -+ /// Byte-level pointer -+ uint8_t *byte_pointer_; -+ -+ /// Array of boolean values to contain steady-state predicates -+ Mask mask_; -+ -+ /// Extent of the matrix tile in columns -+ Index extent_col_; -+ -+ /// A thread's starting column position (assuming steady-state predicates have -+ /// been computed) -+ Index thread_start_col_; -+ -+ /// Internal iteration counter -+ int iteration_contiguous_; -+ -+ int iteration_strided_; -+ -+private: -+ -+ // -+ // Methods -+ // -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ InterleavedPredicatedTileIterator( -+ Params const & params, -+ Element *pointer, -+ TensorCoord extent, -+ int thread_idx, -+ TensorCoord threadblock_offset, -+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization -+ ): -+ params_(params) { -+ TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + -+ TensorCoord(threadblock_offset.contiguous() * InterleavedN, -+ threadblock_offset.strided() / InterleavedN); -+ -+ extent_col_ = extent.strided() / InterleavedN; -+ thread_start_col_ = thread_offset.strided(); -+ -+ // Initialize predicates -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ mask_.predicates[c] = -+ ((thread_offset.contiguous() + ThreadMap::Delta::kContiguous * c) < -+ (extent.contiguous() * InterleavedN)); -+ } -+ -+ // Initialize pointer -+ byte_pointer_ = reinterpret_cast(pointer) + -+ LongIndex(thread_offset.strided()) * LongIndex(params_.stride) + -+ LongIndex(thread_offset.contiguous()) * sizeof(AccessType) / kElementsPerAccess; -+ -+ // Initialize internal state counter -+ iteration_contiguous_ = iteration_strided_ = 0; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ byte_pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ -+ uint8_t *byte_pointer = byte_pointer_; -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ AccessType *memory_pointer = reinterpret_cast(byte_pointer); -+ -+ int col_offset = iteration_strided_ * ThreadMap::Delta::kStrided; -+ -+ bool col_guard = ((thread_start_col_ + col_offset) < extent_col_); -+ -+ bool guard = col_guard && mask_.predicates[iteration_contiguous_]; -+ -+ cutlass::arch::global_load< -+ AccessType, -+ sizeof(AccessType) -+ >( -+ *frag_ptr, -+ (void *)memory_pointer, -+ guard); -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ uint8_t *byte_pointer = byte_pointer_; -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ AccessType *memory_pointer = reinterpret_cast(byte_pointer); -+ -+ int col_offset = iteration_strided_ * ThreadMap::Delta::kStrided; -+ -+ bool col_guard = ((thread_start_col_ + col_offset) < extent_col_); -+ -+ bool guard = col_guard && mask_.predicates[iteration_contiguous_]; -+ -+ cutlass::arch::global_store( -+ *frag_ptr, (void *)memory_pointer, guard); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int iteration) { -+ iteration_contiguous_ = iteration % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = iteration / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Advances to the next position to load or store -+ CUTLASS_HOST_DEVICE -+ InterleavedPredicatedTileIterator &operator++() { -+ -+ ++iteration_contiguous_; -+ byte_pointer_ += params_.advance_row; -+ -+ if (iteration_contiguous_ == ThreadMap::Iterations::kContiguous) { -+ -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ byte_pointer_ += params_.advance_column; -+ -+ if (iteration_strided_ == ThreadMap::Iterations::kStrided) { -+ iteration_strided_ = 0; -+ } -+ } -+ -+ return *this; -+ } -+ -+ /// Advances a number of positions to load or store -+ CUTLASS_HOST_DEVICE -+ InterleavedPredicatedTileIterator &operator+=(int increment) -+ { -+ // Contiguous -+ iteration_contiguous_ += increment; -+ int increment_strided = iteration_contiguous_ / ThreadMap::Iterations::kContiguous; -+ iteration_contiguous_ = iteration_contiguous_ % ThreadMap::Iterations::kContiguous; -+ byte_pointer_ += (params_.advance_row * increment); -+ -+ // Strided -+ iteration_strided_ += increment_strided; -+ byte_pointer_ += (params_.advance_column * increment_strided); -+ -+ return *this; -+ } -+ -+ ///< Efficiently disables all accesses guarded by mask -+ CUTLASS_DEVICE void clear_mask() { -+ mask_.clear(); -+ } -+ -+ ///< Efficiently enables all accesses guarded by mask -+ CUTLASS_DEVICE void enable_mask() { -+ mask_.enable(); -+ } -+ -+ ///< Sets the mask -+ CUTLASS_DEVICE void get_mask(Mask &mask) { -+ mask = mask_; -+ } -+ -+ ///< Sets the mask -+ CUTLASS_DEVICE void set_mask(Mask const &mask) { -+ mask_ = mask; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator used to load output tile from global memory in epilogue. -+/// -+/// Satisfies: ReadableTileIterator | InterleavedMaskedTileIterator | ForwardTileIterator -+/// -+template < -+ typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) -+ typename Element_, ///< Element data type -+ int InterleavedN ///< Number of Interleaved N -+> -+class InterleavedConvPredicatedTileIterator { -+public: -+ using ThreadMap = ThreadMap_; -+ -+ using Element = Element_; -+ -+ using Layout = layout::TensorNCxHWx; -+ using TensorRef = TensorRef; -+ using ConstTensorRef = typename TensorRef::ConstTensorRef; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using TensorCoord = Tensor4DCoord; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ static int const kThreads = ThreadMap::kThreads; -+ static int const kIterations = ThreadMap::Iterations::kCount; -+ -+ /// Fragment object -+ using Fragment = Array; -+ -+ /// Memory access size -+ using AccessType = AlignedArray; -+ -+ // -+ // Parameters struct -+ // -+ -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ LongIndex stride_col; ///< stride in bytes between columns -+ LongIndex stride_row; ///< stride in bytes between rows -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Status initialize(typename Layout::Stride stride_) { -+ stride_col = stride_[1]; -+ stride_row = stride_[2]; -+ -+ return Status::kSuccess; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params() { -+ initialize(cutlass::make_Coord(0, 0, 0)); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) { -+ -+ initialize(layout.stride()); -+ } -+ }; -+ -+ /// Mask object -+ struct Mask { -+ static int const kCount = -+ (ThreadMap::Iterations::kRow < 8) ? 8 : ThreadMap::Iterations::kRow; -+ -+ /// Predicate state -+ bool predicates[kCount]; -+ -+ // -+ // Mask -+ // -+ CUTLASS_HOST_DEVICE -+ Mask() { -+ enable(); -+ } -+ -+ ///< Efficiently disables all accesses guarded by mask -+ CUTLASS_HOST_DEVICE void clear() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ predicates[i] = false; -+ } -+ } -+ -+ ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask -+ CUTLASS_DEVICE void enable() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ predicates[i] = true; -+ } -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters structure containing reference and precomputed state. -+ Params params_; -+ -+ /// Byte-level pointer -+ uint8_t *byte_pointer_; -+ -+ /// Array of boolean values to contain steady-state predicates -+ Mask mask_; -+ -+ /// Extent of the matrix tile in columns -+ Index extent_col_; -+ -+ /// Extent of the matrix tile in rows -+ Index extent_row_; -+ -+ /// Extent of the matrix tile in pq -+ Index extent_pq_; -+ -+ /// A thread's starting row position (assuming steady-state predicates have -+ /// been computed) -+ Index thread_start_row_; -+ -+ /// A thread's starting column position (assuming steady-state predicates have -+ /// been computed) -+ Index thread_start_col_; -+ -+ /// Internal iteration counter -+ LongIndex iteration_row_; -+ LongIndex iteration_col_; -+ -+ uint32_t pq_mul_; -+ -+ uint32_t pq_shr_; -+ -+private: -+ -+ // -+ // Methods -+ // -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ InterleavedConvPredicatedTileIterator( -+ Params const & params, -+ Element *pointer, -+ TensorCoord extent, -+ int thread_idx, -+ MatrixCoord threadblock_offset -+ ): -+ params_(params) { -+ MatrixCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; -+ -+ extent_col_ = extent.c(); -+ extent_pq_ = extent.h() * extent.w(); -+ extent_row_ = extent.n() * extent_pq_; -+ -+ find_divisor(pq_mul_, pq_shr_, extent_pq_); -+ -+ thread_start_row_ = thread_offset.row(); -+ thread_start_col_ = thread_offset.column(); -+ -+ // Initialize predicates -+ CUTLASS_PRAGMA_UNROLL -+ for (int r = 0; r < ThreadMap::Iterations::kRow; ++r) { -+ mask_.predicates[r] = -+ ((thread_offset.row() + ThreadMap::Delta::kRow * r) < extent_row_); -+ } -+ -+ // Initialize pointer -+ byte_pointer_ = reinterpret_cast(pointer) + -+ ((thread_start_col_ / InterleavedN) * params_.stride_col + -+ (thread_start_col_ % InterleavedN)) * -+ sizeof_bits::value / 8; -+ -+ // Initialize internal state counter -+ iteration_row_ = iteration_col_ = 0; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ byte_pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ -+ int col_offset = iteration_col_ * ThreadMap::Delta::kColumn; -+ bool col_guard = ((thread_start_col_ + col_offset) < extent_col_); -+ bool guard = col_guard && mask_.predicates[iteration_row_]; -+ -+ int n, pq_rem; -+ -+ fast_divmod(n, pq_rem, -+ thread_start_row_ + iteration_row_ * ThreadMap::Delta::kRow, -+ extent_pq_, pq_mul_, pq_shr_); -+ -+ uint8_t *byte_pointer = -+ byte_pointer_ + (n * params_.stride_row + pq_rem * InterleavedN) * -+ sizeof_bits::value / 8; -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ AccessType const *memory_pointer = -+ reinterpret_cast(byte_pointer); -+ -+ cutlass::arch::global_load< -+ AccessType, -+ sizeof(AccessType) -+ >( -+ *frag_ptr, -+ (void *)memory_pointer, -+ guard); -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ -+ int col_offset = iteration_col_ * ThreadMap::Delta::kColumn; -+ bool col_guard = ((thread_start_col_ + col_offset) < extent_col_); -+ bool guard = col_guard && mask_.predicates[iteration_row_]; -+ -+ int n, pq_rem; -+ -+ fast_divmod(n, pq_rem, -+ thread_start_row_ + iteration_row_ * ThreadMap::Delta::kRow, -+ extent_pq_, pq_mul_, pq_shr_); -+ -+ uint8_t *byte_pointer = -+ byte_pointer_ + (n * params_.stride_row + pq_rem * InterleavedN) * -+ sizeof_bits::value / 8; -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ AccessType *memory_pointer = reinterpret_cast(byte_pointer); -+ -+ cutlass::arch::global_store( -+ *frag_ptr, (void *)memory_pointer, guard); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int iteration) { -+ iteration_row_ = iteration % ThreadMap::Iterations::kRow; -+ iteration_col_ = iteration / ThreadMap::Iterations::kRow; -+ } -+ -+ /// Advances to the next position to load or store -+ CUTLASS_HOST_DEVICE -+ InterleavedConvPredicatedTileIterator &operator++() { -+ -+ ++iteration_row_; -+ -+ if (iteration_row_ == ThreadMap::Iterations::kRow) { -+ -+ iteration_row_ = 0; -+ ++iteration_col_; -+ byte_pointer_ += params_.stride_col; -+ -+ if (iteration_col_ == ThreadMap::Iterations::kColumn) { -+ iteration_col_ = 0; -+ } -+ } -+ -+ return *this; -+ } -+ -+ ///< Efficiently disables all accesses guarded by mask -+ CUTLASS_DEVICE void clear_mask() { -+ mask_.clear(); -+ } -+ -+ ///< Efficiently enables all accesses guarded by mask -+ CUTLASS_DEVICE void enable_mask() { -+ mask_.enable(); -+ } -+ -+ ///< Sets the mask -+ CUTLASS_DEVICE void get_mask(Mask &mask) { -+ mask = mask_; -+ } -+ -+ ///< Sets the mask -+ CUTLASS_DEVICE void set_mask(Mask const &mask) { -+ mask_ = mask; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h -new file mode 100644 -index 0000000..505f529 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h -@@ -0,0 +1,615 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator used to load and store output tile from global memory in epilogue. -+/// -+/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator -+/// -+/// It provides a fast path for the case Rank = 2 which does not need div/rem to -+/// calculate modes. -+ -+template < -+ typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) -+ typename Element_, ///< Element data type -+ int Rank -+> -+class PredicatedTileIteratorAffineRankN { -+public: -+ using ThreadMap = ThreadMap_; -+ using Shape = typename ThreadMap::Shape; -+ -+ using Element = Element_; -+ -+ using Layout = layout::AffineRankN; -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using ConstTensorRef = typename TensorRef::ConstTensorRef; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ static int const kThreads = ThreadMap::kThreads; -+ static int const kIterations = ThreadMap::Count::kTile; -+ -+ static_assert( ThreadMap::Iterations::kRow > 0,"ThreadMap::Iterations::kRow must be > 0"); -+ static_assert( ThreadMap::Iterations::kGroup > 0,"ThreadMap::Iterations::kGroup must be > 0"); -+ static_assert( ThreadMap::Iterations::kCluster > 0,"ThreadMap::Iterations::kCluster must be > 0"); -+ static_assert( ThreadMap::Iterations::kColumn > 0,"ThreadMap::Iterations::kColumn must be > 0"); -+ static_assert( !(Layout::kRank % 2), -+ "Layout rank must be even. This assumes the first half of the modes correspond to the 'row' " -+ "and the second half of the modes correspond to the 'column'"); -+ -+ static bool const kBigEndian = false; -+ -+ /// Fragment object -+ using Fragment = Array< -+ Element, -+ ThreadMap::Iterations::kColumn * -+ ThreadMap::Iterations::kRow * -+ ThreadMap::Iterations::kGroup * -+ ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>; -+ -+ /// Memory access size -+ using AccessType = AlignedArray; -+ -+ // -+ // Parameters struct -+ // -+ -+ /// Parameters structure -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ Layout layout; -+ -+ /// Stride in units of bytes along M modes -+ Coord stride_m; -+ -+ /// Stride in units of bytes along N modes -+ Coord stride_n; -+ -+ /// Fast divmod objects divided by tensor extents -+ FastDivmod divmod_m[(Layout::kRank == 2) ? 1 : (Layout::kRank/2 - 1)]; -+ -+ /// Fast divmod objects divided by tensor extents -+ FastDivmod divmod_n[(Layout::kRank == 2) ? 1 : (Layout::kRank/2 - 1)]; -+ -+ int64_t rank2_inc_col; -+ int64_t rank2_inc_row; -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(TensorCoord const &extent, Layout const &layout_): layout(layout_) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Layout::kRank / 2; ++i) { -+ stride_m[i] = OffsetBytes(layout_.stride()[i]); -+ stride_n[i] = OffsetBytes(layout_.stride()[i + Layout::kRank / 2]); -+ } -+ -+ if (kBigEndian) { -+ // "Big Endian" scheme -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Layout::kRank / 2 - 1; ++i) { -+ divmod_m[i] = FastDivmod(extent[i + 1]); -+ divmod_n[i] = FastDivmod(extent[i + Layout::kRank / 2 + 1]); -+ } -+ } -+ else { -+ // "Little Endian" scheme -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Layout::kRank / 2 - 1; ++i) { -+ divmod_m[i] = FastDivmod(extent[i]); -+ divmod_n[i] = FastDivmod(extent[i + Layout::kRank / 2]); -+ } -+ } -+ -+ #if 0 -+ // -+ // Debug print statements to verify extents and strides are passed correctly. -+ // -+ printf("PredicatedTileIteratorAffine::Params() entered\n"); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Layout::kRank; ++i) { -+ printf(" extent[%d]: %d\n", i, extent[i]); -+ } -+ for (int i = 0; i < Layout::kRank; ++i) { -+ printf(" stride[%d]: %ld\n", i, layout_.stride()[i]); -+ } -+ printf("PredicatedTileIteratorAffine::Params() returning\n"); -+ #endif -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout_): layout(layout_) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Layout::kRank / 2; ++i) { -+ stride_m[i] = OffsetBytes(layout_.stride()[i]); -+ stride_n[i] = OffsetBytes(layout_.stride()[i + Layout::kRank / 2]); -+ } -+ -+ rank2_inc_col = ThreadMap::Delta::kColumn * stride_n[0]; -+ rank2_inc_row = ThreadMap::Delta::kRow * stride_m[0]; -+ } -+ }; -+ -+ /// Mask object -+ struct Mask { -+ -+ static int const kCount = ThreadMap::Iterations::kColumn; -+ -+ /// Predicate state -+ bool predicates[kCount]; -+ -+ // -+ // Mask -+ // -+ CUTLASS_HOST_DEVICE -+ Mask() { -+ enable(); -+ } -+ -+ ///< Efficiently disables all accesses guarded by mask -+ CUTLASS_HOST_DEVICE void clear() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ predicates[i] = false; -+ } -+ } -+ -+ ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask -+ CUTLASS_DEVICE void enable() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ predicates[i] = true; -+ } -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters structure containing reference and precomputed state. -+ Params params_; -+ -+ /// Byte-level pointer -+ uint8_t *byte_pointer_; -+ -+ /// Array of boolean values to contain steady-state predicates -+ Mask mask_; -+ -+ /// Extent of the matrix tile in rows -+ Index extent_row_; -+ -+ /// Extent of the matrix tile in columns -+ Index extent_col_; -+ -+ /// A thread's starting row position (assuming steady-state predicates have been computed) -+ Index thread_start_row_; -+ -+ /// A thread's starting column position (assuming steady-state predicates have been computed) -+ Index thread_start_column_; -+ -+ /// Internal state counter -+ int state_[3]; -+ -+ /// Offsets in columns, cached for performance -+ int64_t offset_modes_n_[ThreadMap::Iterations::kColumn]; -+ -+ // -+ // Static asserts about internal strides -+ // -+ -+ static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); -+ static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); -+ -+private: -+ -+ // -+ // Methods -+ // -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ PredicatedTileIteratorAffineRankN( -+ Params const & params, -+ Element *pointer, -+ MatrixCoord extent, -+ int thread_idx, -+ MatrixCoord threadblock_offset = MatrixCoord(), -+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization -+ ): -+ params_(params) -+ { -+ -+ MatrixCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; -+ -+ extent_row_ = extent.row(); -+ extent_col_ = extent.column(); -+ -+ thread_start_row_ = thread_offset.row(); -+ thread_start_column_ = thread_offset.column(); -+ -+ if (Layout::kRank > 2) { -+ // Initialize predicates -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { -+ -+ // -+ // Compute coordinate and decompose into N modes -+ // -+ -+ int coord_n = thread_start_column_ + c * ThreadMap::Delta::kColumn; -+ -+ mask_.predicates[c] = coord_n < extent.column(); -+ -+ Coord modes_n; -+ -+ int64_t offset_modes_n = 0; -+ -+ if (kBigEndian) { -+ modes_n = CoordinateDecomposition(coord_n, params_.divmod_n); -+ -+ offset_modes_n = dot(modes_n, params_.stride_n); -+ } -+ else { -+ modes_n = CoordinateDecompositionLittleEndian(coord_n, params_.divmod_n); -+ -+ offset_modes_n = dot(modes_n, params_.stride_n); -+ } -+ -+ offset_modes_n_[c] = offset_modes_n; -+ -+ } -+ -+ if (!pointer) { -+ mask_.clear(); -+ } -+ } -+ -+ // Initialize pointer -+ byte_pointer_ = reinterpret_cast(pointer); -+ -+ // Initialize internal state counter -+ state_[0] = state_[1] = state_[2] = 0; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ byte_pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment &frag, int64_t byte_offset) { -+ uint8_t const *byte_pointer = byte_pointer_; -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ -+ int row_begin = thread_start_row_ + group * ThreadMap::Delta::kGroup + cluster * ThreadMap::Delta::kCluster; -+ int64_t offset_modes_m = row_begin * params_.stride_m[0]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ -+ int frag_row_idx = -+ (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ // -+ // Compute coordinate and decompose into M modes -+ // -+ -+ int coord_m = row * ThreadMap::Delta::kRow + row_begin; -+ -+ Coord modes_m; -+ -+ if (Layout::kRank > 2) { -+ if (kBigEndian) { -+ modes_m = CoordinateDecomposition(coord_m, params_.divmod_m); -+ } else { -+ modes_m = CoordinateDecompositionLittleEndian(coord_m, params_.divmod_m); -+ } -+ -+ offset_modes_m = dot(modes_m, params_.stride_m); -+ } -+ -+ // -+ // Compute the offset due to modes M -+ // -+ -+ bool row_guard = (coord_m < extent_row_); -+ int64_t offset_modes_n = thread_start_column_ * params_.stride_n[0]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { -+ -+ // -+ // Compute coordinate and decompose into N modes -+ // -+ -+ if (Layout::kRank > 2) { -+ offset_modes_n = offset_modes_n_[column]; -+ } -+ -+ // -+ // Compute the pointer and access -+ // -+ bool guard; -+ -+ if (Layout::kRank > 2) { -+ guard = row_guard && mask_.predicates[column]; -+ } else { -+ guard = (coord_m < extent_row_) && -+ ((thread_start_column_ + ThreadMap::Delta::kColumn * column) < extent_col_); -+ } -+ -+ cutlass::arch::global_load< -+ AccessType, -+ sizeof(AccessType) -+ >( -+ frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], -+ (void *)(byte_pointer + offset_modes_m + offset_modes_n + byte_offset), -+ guard -+ ); -+ -+ if (Layout::kRank == 2) { -+ offset_modes_n += params_.rank2_inc_col; -+ } -+ } -+ -+ if (Layout::kRank == 2) { -+ offset_modes_m += params_.rank2_inc_row; -+ } -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ -+ load_with_byte_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const &frag, int64_t byte_offset) { -+ uint8_t *byte_pointer = byte_pointer_; -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ -+ int row_begin = thread_start_row_ + group * ThreadMap::Delta::kGroup + cluster * ThreadMap::Delta::kCluster; -+ int64_t offset_modes_m = row_begin * params_.stride_m[0]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ -+ int frag_row_idx = -+ (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ // -+ // Compute coordinate and decompose into M modes -+ // -+ -+ int coord_m = row * ThreadMap::Delta::kRow + row_begin; -+ -+ Coord modes_m; -+ -+ if (Layout::kRank > 2) { -+ if (kBigEndian) { -+ modes_m = CoordinateDecomposition(coord_m, params_.divmod_m); -+ } else { -+ modes_m = CoordinateDecompositionLittleEndian(coord_m, params_.divmod_m); -+ } -+ -+ offset_modes_m = dot(modes_m, params_.stride_m); -+ } -+ -+ // -+ // Compute the offset due to modes M -+ // -+ -+ bool row_guard = (coord_m < extent_row_); -+ int64_t offset_modes_n = thread_start_column_ * params_.stride_n[0]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { -+ -+ // -+ // Compute coordinate and decompose into N modes -+ // -+ -+ if (Layout::kRank > 2) { -+ offset_modes_n = offset_modes_n_[column]; -+ } -+ -+ // -+ // Compute the pointer and access -+ // -+ bool guard; -+ if (Layout::kRank > 2) { -+ guard = row_guard && mask_.predicates[column]; -+ } else { -+ guard = (coord_m < extent_row_) && ((thread_start_column_ + ThreadMap::Delta::kColumn * column) < extent_col_); -+ } -+ -+ cutlass::arch::global_store( -+ frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], -+ (void *)(byte_pointer + offset_modes_m + offset_modes_n + byte_offset), -+ guard); -+ -+ if (Layout::kRank == 2) { -+ offset_modes_n += params_.rank2_inc_col; -+ } -+ } -+ -+ if (Layout::kRank == 2) { -+ offset_modes_m += params_.rank2_inc_row; -+ } -+ } -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ -+ store_with_byte_offset(frag, 0); -+ } -+ -+ /// Advances to the next position to load or store -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorAffineRankN &operator++() { -+ -+ ++state_[0]; -+ thread_start_row_ += ThreadMap::Shape::kRow; -+ -+ if (state_[0] == ThreadMap::Count::kRow) { -+ -+ state_[0] = 0; -+ ++state_[1]; -+ -+ thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * -+ ThreadMap::Shape::kRow * ThreadMap::Count::kRow; -+ -+ if (state_[1] == ThreadMap::Count::kGroup) { -+ -+ state_[1] = 0; -+ ++state_[2]; -+ -+ thread_start_row_ += ThreadMap::Count::kGroup * -+ ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; -+ -+ if (state_[2] == ThreadMap::Count::kCluster) { -+ state_[2] = 0; -+ } -+ } -+ } -+ -+ return *this; -+ } -+ -+ ///< Efficiently disables all accesses guarded by mask -+ CUTLASS_DEVICE void clear_mask() { -+ mask_.clear(); -+ } -+ -+ ///< Efficiently enables all accesses guarded by mask -+ CUTLASS_DEVICE void enable_mask() { -+ mask_.enable(); -+ } -+ -+ ///< Sets the mask -+ CUTLASS_DEVICE void get_mask(Mask &mask) { -+ mask = mask_; -+ } -+ -+ ///< Sets the mask -+ CUTLASS_DEVICE void set_mask(Mask const &mask) { -+ mask_ = mask; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_affine_layout_params.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_affine_layout_params.h -new file mode 100644 -index 0000000..7832fde ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_affine_layout_params.h -@@ -0,0 +1,156 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/fast_math.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ int Rank -+> -+struct PredicatedTileIteratorAffineLayoutRankNParams { -+ using Layout = layout::AffineRankN; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ static bool const kBigEndian = false; -+ -+ // -+ // Data members -+ // -+ -+ Layout layout; -+ -+ /// Stride in units of bytes along M modes -+ Coord stride_m; -+ -+ /// Stride in units of bytes along N modes -+ Coord stride_n; -+ -+ /// Fast divmod objects divided by tensor extents -+ FastDivmod divmod_m[(Layout::kRank == 2) ? 1 : (Layout::kRank/2 - 1)]; -+ -+ /// Fast divmod objects divided by tensor extents -+ FastDivmod divmod_n[(Layout::kRank == 2) ? 1 : (Layout::kRank/2 - 1)]; -+ -+ int64_t rank2_inc_col; -+ int64_t rank2_inc_row; -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorAffineLayoutRankNParams() { } -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorAffineLayoutRankNParams(TensorCoord const &extent, -+ Layout const &layout_, -+ int64_t element_sizeof_bits) -+ : layout(layout_) -+ { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Layout::kRank / 2; ++i) { -+ stride_m[i] = OffsetBytes(layout_.stride()[i], element_sizeof_bits); -+ stride_n[i] = OffsetBytes(layout_.stride()[i + Layout::kRank / 2], element_sizeof_bits); -+ } -+ -+ if (kBigEndian) { -+ // "Big Endian" scheme -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Layout::kRank / 2 - 1; ++i) { -+ divmod_m[i] = FastDivmod(extent[i + 1]); -+ divmod_n[i] = FastDivmod(extent[i + Layout::kRank / 2 + 1]); -+ } -+ } -+ else { -+ // "Little Endian" scheme -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Layout::kRank / 2 - 1; ++i) { -+ divmod_m[i] = FastDivmod(extent[i]); -+ divmod_n[i] = FastDivmod(extent[i + Layout::kRank / 2]); -+ } -+ } -+ -+ #if 0 -+ // -+ // Debug print statements to verify extents and strides are passed correctly. -+ // -+ printf("PredicatedTileIteratorAffine::Params() entered\n"); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Layout::kRank; ++i) { -+ printf(" extent[%d]: %d\n", i, extent[i]); -+ } -+ for (int i = 0; i < Layout::kRank; ++i) { -+ printf(" stride[%d]: %ld\n", i, layout_.stride()[i]); -+ } -+ printf("PredicatedTileIteratorAffine::Params() returning\n"); -+ #endif -+ } -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorAffineLayoutRankNParams(Layout const &layout_, -+ int32_t threadmap_delta_kColumn, -+ int32_t threadmap_delta_kRow, -+ int64_t element_sizeof_bits) -+ : layout(layout_) -+ { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Layout::kRank / 2; ++i) { -+ stride_m[i] = OffsetBytes(layout_.stride()[i], element_sizeof_bits); -+ stride_n[i] = OffsetBytes(layout_.stride()[i + Layout::kRank / 2], element_sizeof_bits); -+ } -+ -+ rank2_inc_col = threadmap_delta_kColumn * stride_n[0]; -+ rank2_inc_row = threadmap_delta_kRow * stride_m[0]; -+ } -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_blas3.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_blas3.h -new file mode 100644 -index 0000000..9aab017 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_blas3.h -@@ -0,0 +1,633 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator used to load and store output tile from global memory in epilogue. -+/// -+/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator -+/// -+template < -+ typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) -+ typename Element_, ///< Element data type -+ BlasMode BlasMode_ = BlasMode::kGemm ///< Tile Iterator for a Symmetric or Hermitian Kernel -+> -+class PredicatedTileIteratorBlas3 { -+public: -+ using ThreadMap = ThreadMap_; -+ using Shape = typename ThreadMap::Shape; -+ -+ using Element = Element_; -+ -+ using Layout = layout::RowMajor; -+ using TensorRef = TensorRef; -+ using ConstTensorRef = typename TensorRef::ConstTensorRef; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using TensorCoord = MatrixCoord; -+ -+ static BlasMode const kBlasMode = BlasMode_; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ static int const kThreads = ThreadMap::kThreads; -+ static int const kIterations = ThreadMap::Count::kTile; -+ -+ static_assert( ThreadMap::Iterations::kRow > 0,"ThreadMap::Iterations::kRow must be > 0"); -+ static_assert( ThreadMap::Iterations::kGroup > 0,"ThreadMap::Iterations::kGroup must be > 0"); -+ static_assert( ThreadMap::Iterations::kCluster > 0,"ThreadMap::Iterations::kCluster must be > 0"); -+ static_assert( ThreadMap::Iterations::kColumn > 0,"ThreadMap::Iterations::kColumn must be > 0"); -+ -+ /// Fragment object -+ using Fragment = Array< -+ Element, -+ ThreadMap::Iterations::kColumn * -+ ThreadMap::Iterations::kRow * -+ ThreadMap::Iterations::kGroup * -+ ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>; -+ -+ /// Memory access size -+ using AccessType = AlignedArray; -+ static_assert( AccessType::kElements == 1, "BLAS3 Epilogue must use AccessType::kElements as 1"); -+ -+ // -+ // Parameters struct -+ // -+ -+ /// Uses a non-template class -+ struct Params : PredicatedTileIteratorParams { -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout): -+ PredicatedTileIteratorParams( -+ layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, -+ make_OutputTileThreadMapDesc() -+ ) -+ { -+ -+ } -+ }; -+ -+ /// Mask object -+ struct Mask { -+ -+ static int const kCount = ThreadMap::Iterations::kColumn; -+ -+ /// Predicate state -+ bool predicates[kCount]; -+ -+ // -+ // Mask -+ // -+ CUTLASS_HOST_DEVICE -+ Mask() { -+ enable(); -+ } -+ -+ ///< Efficiently disables all accesses guarded by mask -+ CUTLASS_HOST_DEVICE void clear() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ predicates[i] = false; -+ } -+ } -+ -+ ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask -+ CUTLASS_DEVICE void enable() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ predicates[i] = true; -+ } -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters structure containing reference and precomputed state. -+ PredicatedTileIteratorParams params_; -+ -+ /// Byte-level pointer -+ uint8_t *byte_pointer_; -+ -+ /// Fill Mode for a tile on diagonal of a symmetric kernel -+ cutlass::FillMode fill_mode; -+ -+ /// Array of boolean values to contain steady-state predicates -+ Mask mask_; -+ -+ /// Extent of the matrix tile in rows -+ Index extent_row_; -+ -+ /// A thread's starting row position (assuming steady-state predicates have been computed) -+ Index thread_start_row_; -+ -+ /// Internal state counter -+ int state_[3]; -+ -+ /// Starting address of the matrix -+ size_t matrix_start_addr; -+ -+ static_assert((kBlasMode == BlasMode::kSymmetric || kBlasMode == BlasMode::kHermitian), -+ "Unsupported blas3 mode."); -+ -+private: -+ -+ // -+ // Methods -+ // -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ PredicatedTileIteratorBlas3( -+ PredicatedTileIteratorParams const & params, -+ Element *pointer, -+ TensorCoord extent, -+ int thread_idx, -+ TensorCoord threadblock_offset -+ , cutlass::FillMode fill_mode -+ ): -+ params_(params), fill_mode(fill_mode) -+ { -+ -+ TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; -+ -+ extent_row_ = extent.row(); -+ thread_start_row_ = thread_offset.row(); -+ -+ // Initialize predicates -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { -+ -+ mask_.predicates[c] = ((thread_offset.column() -+ + ThreadMap::Delta::kColumn * c) < extent.column()); -+ } -+ -+ // Check Symmetric kernel modes (Lower and Upper - for diagonal CTAs, None for rest CTAs) -+ if ((kBlasMode == BlasMode::kSymmetric || kBlasMode == BlasMode::kHermitian) && -+ fill_mode == cutlass::FillMode::kInvalid) { -+ arch::device_breakpoint(); -+ } -+ -+ // Starting address of the matrix -+ matrix_start_addr = reinterpret_cast(pointer); -+ -+ // Initialize pointer -+ byte_pointer_ = reinterpret_cast(pointer) + -+ LongIndex(thread_offset.row()) * LongIndex(params_.stride) + -+ LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; -+ -+ // Initialize internal state counter -+ state_[0] = state_[1] = state_[2] = 0; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ byte_pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment &frag, int64_t byte_offset) { -+ -+ uint8_t *byte_pointer = byte_pointer_; -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ -+ int frag_row_idx = -+ (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ int row_offset = row * ThreadMap::Delta::kRow -+ + group * ThreadMap::Delta::kGroup -+ + cluster * ThreadMap::Delta::kCluster; -+ -+ bool row_guard = ((row_offset + thread_start_row_) < extent_row_); -+ -+ AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { -+ -+ bool guard = row_guard && mask_.predicates[column]; -+ -+ cutlass::arch::global_load< -+ AccessType, -+ sizeof(AccessType) -+ >( -+ frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + -+ column], -+ (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / -+ kElementsPerAccess], -+ guard); -+ } -+ -+ if (row + 1 < ThreadMap::Iterations::kRow) { -+ byte_pointer += params_.increment_row; -+ } -+ } -+ -+ if (group + 1 < ThreadMap::Iterations::kGroup) { -+ byte_pointer += params_.increment_group; -+ } -+ } -+ -+ if (cluster + 1 < ThreadMap::Iterations::kCluster) { -+ byte_pointer += params_.increment_cluster; -+ } -+ } -+ } -+ -+ /// Loads a fragment on the diagonal of a symmetric kernel to memory -+ CUTLASS_DEVICE -+ void load_symmetric_with_byte_offset(Fragment &frag, int64_t byte_offset) { -+ -+ uint8_t *byte_pointer = byte_pointer_; -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ bool isLowerMode = (fill_mode == cutlass::FillMode::kLower) ? true : false; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ -+ int frag_row_idx = -+ (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ int row_offset = row * ThreadMap::Delta::kRow -+ + group * ThreadMap::Delta::kGroup -+ + cluster * ThreadMap::Delta::kCluster; -+ -+ bool row_guard = ((row_offset + thread_start_row_) < extent_row_); -+ -+ AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); -+ -+ // Offset of row from beginning of the matrix per thread -+ size_t row_start_offset = (size_t)memory_pointer - matrix_start_addr; -+ -+ // Absolute row index -+ int row_index = int(row_start_offset/params_.stride); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { -+ -+ bool guard = row_guard && mask_.predicates[column]; -+ -+ // Offset of column from beginning of row per thread -+ size_t col_start_offset = row_start_offset + -+ (column * ThreadMap::Delta::kColumn / kElementsPerAccess) * sizeof(AccessType); -+ -+ // Absolute column index -+ size_t col_index = (col_start_offset%params_.stride)/sizeof(AccessType); -+ guard = guard && ( (isLowerMode && row_index >= col_index) || -+ (!isLowerMode && row_index <= col_index) ); -+ -+ cutlass::arch::global_load< -+ AccessType, -+ sizeof(AccessType) -+ >( -+ frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + -+ column], -+ (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / -+ kElementsPerAccess], -+ guard); -+ -+ // The imaginary parts of the diagonal elements of a complex element are assumed and set to zero -+ if (guard && kBlasMode == BlasMode::kHermitian && cutlass::is_complex::value) { -+ Element *scalar_ptr = reinterpret_cast(frag_ptr); -+ -+ if (row_index == col_index) { -+ scalar_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column] = -+ real(scalar_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column]); -+ } -+ } -+ } -+ -+ if (row + 1 < ThreadMap::Iterations::kRow) { -+ byte_pointer += params_.increment_row; -+ } -+ } -+ -+ if (group + 1 < ThreadMap::Iterations::kGroup) { -+ byte_pointer += params_.increment_group; -+ } -+ } -+ -+ if (cluster + 1 < ThreadMap::Iterations::kCluster) { -+ byte_pointer += params_.increment_cluster; -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ -+ if (fill_mode == cutlass::FillMode::kNone) { -+ load_with_byte_offset(frag, 0); -+ } -+ else { -+ load_symmetric_with_byte_offset(frag, 0); -+ } -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const &frag, int64_t byte_offset) { -+ uint8_t *byte_pointer = byte_pointer_; -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ -+ int frag_row_idx = -+ (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ int row_offset = row * ThreadMap::Delta::kRow -+ + group * ThreadMap::Delta::kGroup -+ + cluster * ThreadMap::Delta::kCluster; -+ -+ bool row_guard = ((row_offset + thread_start_row_) < extent_row_); -+ -+ AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { -+ -+ bool guard = row_guard && mask_.predicates[column]; -+ -+ cutlass::arch::global_store( -+ frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], -+ (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess], -+ guard); -+ } -+ -+ if (row + 1 < ThreadMap::Iterations::kRow) { -+ byte_pointer += params_.increment_row; -+ } -+ } -+ -+ if (group + 1 < ThreadMap::Iterations::kGroup) { -+ byte_pointer += params_.increment_group; -+ } -+ } -+ -+ if (cluster + 1 < ThreadMap::Iterations::kCluster) { -+ byte_pointer += params_.increment_cluster; -+ } -+ } -+ } -+ -+ /// Stores a fragment on the diagonal of a symmetric kernel to memory -+ CUTLASS_DEVICE -+ void store_symmetric_with_byte_offset(Fragment const &frag, int64_t byte_offset) { -+ uint8_t *byte_pointer = byte_pointer_; -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ bool isLowerMode = (fill_mode == cutlass::FillMode::kLower) ? true : false; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ -+ int frag_row_idx = -+ (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ int row_offset = row * ThreadMap::Delta::kRow -+ + group * ThreadMap::Delta::kGroup -+ + cluster * ThreadMap::Delta::kCluster; -+ -+ bool row_guard = ((row_offset + thread_start_row_) < extent_row_); -+ -+ AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); -+ -+ // Offset of row from beginning of the matrix per thread -+ size_t row_start_offset = (size_t)memory_pointer - matrix_start_addr; -+ -+ // Absolute row index -+ int row_index = int(row_start_offset/params_.stride); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { -+ -+ bool guard = row_guard && mask_.predicates[column]; -+ -+ // Offset of column from beginning of row per thread -+ size_t col_start_offset = row_start_offset + -+ (column * ThreadMap::Delta::kColumn / kElementsPerAccess) * sizeof(AccessType); -+ -+ // Absolute column index -+ size_t col_index = (col_start_offset%params_.stride)/sizeof(AccessType); -+ -+ guard = guard && ( (isLowerMode && row_index >= col_index) || -+ (!isLowerMode && row_index <= col_index) ); -+ -+ // The imaginary parts of the diagonal elements of a complex element are assumed and set to zero -+ if (guard && kBlasMode == BlasMode::kHermitian && cutlass::is_complex::value) { -+ -+ AccessType *frag_ptr_modify = const_cast(frag_ptr); -+ Element *scalar_ptr = reinterpret_cast(frag_ptr_modify); -+ -+ if (row_index == col_index) { -+ scalar_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column] = -+ real(scalar_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column]); -+ } -+ } -+ -+ cutlass::arch::global_store( -+ frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + -+ column], -+ (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / -+ kElementsPerAccess], -+ guard); -+ } -+ -+ if (row + 1 < ThreadMap::Iterations::kRow) { -+ byte_pointer += params_.increment_row; -+ } -+ } -+ -+ if (group + 1 < ThreadMap::Iterations::kGroup) { -+ byte_pointer += params_.increment_group; -+ } -+ } -+ -+ if (cluster + 1 < ThreadMap::Iterations::kCluster) { -+ byte_pointer += params_.increment_cluster; -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ -+ if (fill_mode == cutlass::FillMode::kNone) { -+ store_with_byte_offset(frag, 0); -+ } -+ else { -+ store_symmetric_with_byte_offset(frag, 0); -+ } -+ -+ } -+ -+ /// Advances to the next position to load or store -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorBlas3 &operator++() { -+ -+ ++state_[0]; -+ byte_pointer_ += params_.advance_row; -+ thread_start_row_ += ThreadMap::Shape::kRow; -+ -+ if (state_[0] == ThreadMap::Count::kRow) { -+ -+ state_[0] = 0; -+ ++state_[1]; -+ byte_pointer_ += params_.advance_group; -+ -+ thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * -+ ThreadMap::Shape::kRow * ThreadMap::Count::kRow; -+ -+ if (state_[1] == ThreadMap::Count::kGroup) { -+ -+ state_[1] = 0; -+ ++state_[2]; -+ byte_pointer_ += params_.advance_cluster; -+ -+ thread_start_row_ += ThreadMap::Count::kGroup * -+ ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; -+ -+ if (state_[2] == ThreadMap::Count::kCluster) { -+ state_[2] = 0; -+ byte_pointer_ += params_.advance_tile; -+ } -+ } -+ } -+ -+ return *this; -+ } -+ -+ ///< Efficiently disables all accesses guarded by mask -+ CUTLASS_DEVICE void clear_mask() { -+ mask_.clear(); -+ } -+ -+ ///< Efficiently enables all accesses guarded by mask -+ CUTLASS_DEVICE void enable_mask() { -+ mask_.enable(); -+ } -+ -+ ///< Sets the mask -+ CUTLASS_DEVICE void get_mask(Mask &mask) { -+ mask = mask_; -+ } -+ -+ ///< Sets the mask -+ CUTLASS_DEVICE void set_mask(Mask const &mask) { -+ mask_ = mask; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_direct_conv.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_direct_conv.h -new file mode 100644 -index 0000000..a641f60 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_direct_conv.h -@@ -0,0 +1,445 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/permute.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator used to load and store output tile from global memory in epilogue. -+/// -+/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator -+/// -+template < -+ typename ThreadMap_, ///< Thread map (conept: PitchLinearThreadMap) -+ typename Element_, ///< Element data type -+ typename ThreadOutputShape_ = cutlass::conv::TensorNHWCShape<1, 1, 1, 1>, -+ typename ThreadBlockOutputShape_ = cutlass::conv::TensorNHWCShape<1, 1, 1, 1> -+> -+class PredicatedTileIteratorDirectConv { -+public: -+ using ThreadMap = ThreadMap_; -+ using Shape = typename ThreadMap::Shape; -+ using ThreadOutputShape = ThreadOutputShape_; -+ using ThreadBlockOutputShape = ThreadBlockOutputShape_; -+ -+ using Element = Element_; -+ -+ using Layout = layout::RowMajor; -+ using TensorRef = TensorRef; -+ using ConstTensorRef = typename TensorRef::ConstTensorRef; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using TensorCoord = MatrixCoord; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ static int const kThreads = ThreadMap::kThreads; -+ -+ using ConvProblemSize = typename cutlass::conv::Conv2dProblemSize; -+ -+ /// Fragment object -+ using Fragment = Array; -+ -+ /// Memory access size -+ using AccessType = AlignedArray; -+ -+ static int const kLoadsPerAccess = AccessType::kElements / AccessType::kElements; -+ -+ using ThreadTileCount = MatrixShape< -+ ThreadBlockOutputShape::kH / ThreadOutputShape::kH, -+ ThreadBlockOutputShape::kW / ThreadOutputShape::kW -+ >; -+ -+ // -+ // Parameters struct -+ // -+ -+ /// Uses a non-template class -+ struct Params : PredicatedTileIteratorDirect2dConvParams { -+ using Base = PredicatedTileIteratorDirect2dConvParams; -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout, cutlass::conv::Conv2dProblemSize const &problem_size): -+ PredicatedTileIteratorDirect2dConvParams( -+ layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, -+ problem_size, -+ {ThreadBlockOutputShape::kH, ThreadBlockOutputShape::kW} -+ ) -+ { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Base const &base) : -+ Base(base) { } -+ }; -+ -+ /// Mask object -+ struct Mask { -+ -+ static int const kCount = ThreadMap::Iterations::kContiguous; -+ -+ /// Predicate state -+ bool predicates[kCount]; -+ -+ // -+ // Mask -+ // -+ CUTLASS_HOST_DEVICE -+ Mask() { -+ enable(); -+ } -+ -+ ///< Efficiently disables all accesses guarded by mask -+ CUTLASS_HOST_DEVICE void clear() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ predicates[i] = false; -+ } -+ } -+ -+ ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask -+ CUTLASS_DEVICE void enable() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ predicates[i] = true; -+ } -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters structure containing reference and precomputed state. -+ PredicatedTileIteratorDirect2dConvParams params_; -+ -+ /// Byte-level pointer -+ uint8_t *byte_pointer_; -+ -+ /// -+ Element *pointer_; -+ -+ -+ /// Array of boolean values to contain steady-state predicates -+ Mask mask_; -+ -+ /// Extent of the matrix tile in rows -+ Index extent_row_; -+ -+ /// Extent of the matrix tile in rows -+ Index extent_column_; -+ -+ /// A thread's starting row position (assuming steady-state predicates have been computed) -+ Index thread_start_row_; -+ -+ /// A thread's starting column -+ Index thread_start_column_; -+ -+ /// Initial thread ouput location -+ int thread_start_n_, thread_start_p_, thread_start_q_; -+ -+ /// Current threadblock tile index -+ int tile_index_; -+ -+ // -+ // Static asserts about internal strides -+ // -+ -+ static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); -+ static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); -+ static_assert(sizeof(PredicatedTileIteratorDirect2dConvParams::stride) == 8, "Expected 64b strides"); -+ -+private: -+ -+ // -+ // Methods -+ // -+ -+ -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ PredicatedTileIteratorDirectConv( -+ PredicatedTileIteratorDirect2dConvParams const & params, -+ Element *pointer, -+ TensorCoord extent, -+ int thread_idx, -+ TensorCoord threadblock_offset = TensorCoord() -+ ): -+ params_(params), pointer_(pointer) -+ { -+ -+ TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); -+ -+ extent_row_ = extent.row(); -+ extent_column_ = extent.column(); -+ -+ // stride dim (PQ) -+ thread_start_row_ = thread_offset.column(); -+ // contiguous dim (Channels) -+ thread_start_column_ = threadblock_offset.column() + thread_offset.row(); -+ -+ tile_index_ = threadblock_offset.row(); -+ -+ set_tile_index(0); -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void set_tile_index(const int index) { -+ -+ int residual; -+ params_.pq_divmod(thread_start_n_, residual, tile_index_ + index); -+ params_.q_divmod(thread_start_p_, thread_start_q_, residual); -+ -+ // Compute the base output coord of ThreadBlock -+ thread_start_p_ *= ThreadBlockOutputShape::kH; -+ thread_start_q_ *= ThreadBlockOutputShape::kW; -+ -+ // Initialize predicates -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ mask_.predicates[c] = ((thread_start_column_ -+ + c * ThreadMap::Delta::kContiguous) < extent_column_); -+ } -+ -+ // Null pointer performs no accesses -+ if (!pointer_) { -+ mask_.clear(); -+ } -+ -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ byte_pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment &frag, int64_t byte_offset) const { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ int frag_base_idx = s * ThreadMap::Iterations::kContiguous + c; -+ -+ int current_row = thread_start_row_ + s * ThreadMap::Delta::kStrided; -+ int p = current_row / ThreadBlockOutputShape::kW; -+ int q = current_row % ThreadBlockOutputShape::kW; -+ -+ int current_p = thread_start_p_ + p; -+ int current_q = thread_start_q_ + q; -+ -+ bool row_guard = (current_p) < params_.P && (current_q) < params_.Q && -+ (thread_start_n_ < params_.N) && current_row < ThreadMap::Shape::kStrided; -+ -+ int output_row_offset = -+ thread_start_n_ * params_.stride_n + current_p * params_.stride_p + current_q; -+ -+ uint8_t *byte_pointer = -+ reinterpret_cast(pointer_) + -+ LongIndex(output_row_offset) * LongIndex(params_.stride) + -+ LongIndex(thread_start_column_ + c * ThreadMap::Delta::kContiguous) * -+ sizeof(AccessType) / kElementsPerAccess; -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); -+ -+ bool guard = row_guard && mask_.predicates[c]; -+ -+ cutlass::arch::global_load( -+ frag_ptr[frag_base_idx], (void *)&memory_pointer[0], guard); -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) const { -+ load_with_byte_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const &frag, int64_t byte_offset) const { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ int frag_base_idx = s * ThreadMap::Iterations::kContiguous + c; -+ -+ int current_row = thread_start_row_ + s * ThreadMap::Delta::kStrided; -+ int p = current_row / ThreadBlockOutputShape::kW; -+ int q = current_row % ThreadBlockOutputShape::kW; -+ -+ int current_p = thread_start_p_ + p; -+ int current_q = thread_start_q_ + q; -+ -+ bool row_guard = (current_p) < params_.P && (current_q) < params_.Q && -+ (thread_start_n_ < params_.N) && current_row < ThreadMap::Shape::kStrided; -+ -+ int output_row_offset = -+ thread_start_n_ * params_.stride_n + current_p * params_.stride_p + current_q; -+ -+ uint8_t *byte_pointer = -+ reinterpret_cast(pointer_) + -+ LongIndex(output_row_offset) * LongIndex(params_.stride) + -+ LongIndex(thread_start_column_ + c * ThreadMap::Delta::kContiguous) * -+ sizeof(AccessType) / kElementsPerAccess; -+ -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); -+ -+ bool guard = row_guard && mask_.predicates[c]; -+ -+ cutlass::arch::global_store( -+ frag_ptr[frag_base_idx], (void *)&memory_pointer[0], guard); -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) const { -+ -+ store_with_byte_offset(frag, 0); -+ } -+ -+ CUTLASS_DEVICE -+ MatrixCoord thread_start() const { -+ return MatrixCoord(thread_start_row_, thread_start_column_); -+ } -+ -+ /// Need to get the thread start row from the tile iterator -+ CUTLASS_DEVICE -+ int32_t thread_start_row() const { -+ return thread_start_row_; -+ } -+ -+ /// Need to get the thread start row from the tile iterator -+ CUTLASS_DEVICE -+ int32_t thread_start_column() const { -+ return thread_start_column_; -+ } -+ -+ /// Extent of the matrix in rows -+ CUTLASS_DEVICE -+ Index extent_row() const { -+ return extent_row_; -+ } -+ -+ /// Extent of the matrix in columns -+ CUTLASS_DEVICE -+ Index extent_column() const { -+ return extent_column_; -+ } -+ -+ /// Advances to the next position to load or store -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorDirectConv &operator++() { -+ // do nothing -+ -+ return *this; -+ } -+ -+ ///< Efficiently disables all accesses guarded by mask -+ CUTLASS_DEVICE void clear_mask() { -+ mask_.clear(); -+ } -+ -+ ///< Efficiently enables all accesses guarded by mask -+ CUTLASS_DEVICE void enable_mask() { -+ mask_.enable(); -+ } -+ -+ ///< Sets the mask -+ CUTLASS_DEVICE void get_mask(Mask &mask) const { -+ mask = mask_; -+ } -+ -+ ///< Sets the mask -+ CUTLASS_DEVICE void set_mask(Mask const &mask) { -+ mask_ = mask; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h -new file mode 100644 -index 0000000..937409a ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h -@@ -0,0 +1,475 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/conv/conv2d_problem_size.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct OutputTileShapeDesc { -+ -+ int column; -+ int row; -+ int group; -+ int cluster; -+ int tile; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ OutputTileShapeDesc(): column(0), row(0), group(0), cluster(0), tile(0) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ OutputTileShapeDesc( -+ int column_, -+ int row_, -+ int group_, -+ int cluster_, -+ int tile_ -+ ): -+ column(column_), -+ row(row_), -+ group(group_), -+ cluster(cluster_), -+ tile(tile_) { } -+ -+ /// Total number of points in the 5D space -+ CUTLASS_HOST_DEVICE -+ int count() const { -+ return column * row * group * cluster * tile; -+ } -+ -+ #if 0 -+ CUTLASS_HOST_DEVICE -+ void print() const { -+ printf("{%d, %d, %d, %d, %d}", column, row, group, cluster, tile); -+ } -+ #endif -+}; -+ -+/// Helper template to construct an OutputTileShapeDesc from a OutputTileShape template. -+template -+CUTLASS_HOST_DEVICE -+OutputTileShapeDesc make_OutputTileShapeDesc() { -+ return OutputTileShapeDesc( -+ Shape::kColumn, -+ Shape::kRow, -+ Shape::kGroup, -+ Shape::kCluster, -+ Shape::kTile -+ ); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Thread map description -+struct OutputTileThreadMapDesc { -+ -+ int threads; -+ int elements_per_access; -+ OutputTileShapeDesc shape; -+ OutputTileShapeDesc iterations; -+ OutputTileShapeDesc delta; -+ OutputTileShapeDesc count; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ OutputTileThreadMapDesc() { } -+ -+ CUTLASS_HOST_DEVICE -+ OutputTileThreadMapDesc( -+ int threads_, -+ int elements_per_access_, -+ OutputTileShapeDesc shape_, -+ OutputTileShapeDesc iterations_, -+ OutputTileShapeDesc delta_, -+ OutputTileShapeDesc count_ -+ ): -+ threads(threads_), -+ elements_per_access(elements_per_access_), -+ shape(shape_), -+ iterations(iterations_), -+ delta(delta_), -+ count(count_) -+ { -+ -+ } -+}; -+ -+/// Helper template to construct an OutputTileShapeDesc from a OutputTileThreadMap template. -+template -+CUTLASS_HOST_DEVICE -+OutputTileThreadMapDesc make_OutputTileThreadMapDesc() { -+ return OutputTileThreadMapDesc( -+ ThreadMap::kThreads, -+ ThreadMap::kElementsPerAccess, -+ make_OutputTileShapeDesc(), -+ make_OutputTileShapeDesc(), -+ make_OutputTileShapeDesc(), -+ make_OutputTileShapeDesc() -+ ); -+} -+/////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Parameters struct for PredicatedTileIterator -+// -+ -+struct PredicatedTileIteratorParams { -+ -+ using Index = int32_t; -+ using LongIndex = int64_t; -+ -+ // -+ // Data members -+ // -+ -+ LongIndex stride; ///< stride in bytes between rows -+ -+ LongIndex increment_row; ///< increment quantity (in bytes) to advance when moving between rows -+ LongIndex increment_group; ///< increment quantity (in bytes) to advance when moving to the next group -+ LongIndex increment_cluster; ///< increment quantity (in bytes) to advance when moving to the next cluster -+ -+ LongIndex advance_row; ///< amount to add to move to the next 'row' position -+ LongIndex advance_group; ///< amount to add to move to the next 'group' position -+ LongIndex advance_cluster; ///< amount to add to move to the next 'cluster' position -+ LongIndex advance_tile; ///< amount to add to move to the next 'tile' -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Status initialize(LongIndex stride_, OutputTileThreadMapDesc thread_map) { -+ -+ stride = stride_; -+ -+ increment_row = stride * thread_map.delta.row; -+ -+ increment_group = stride * thread_map.delta.group -+ - stride * thread_map.delta.row * (thread_map.iterations.row - 1); -+ -+ increment_cluster = stride * thread_map.delta.cluster -+ - stride * thread_map.delta.group * (thread_map.iterations.group - 1) -+ - stride * thread_map.delta.row * (thread_map.iterations.row - 1); -+ -+ advance_row = stride * thread_map.shape.row; -+ -+ advance_group = -+ stride * -+ (thread_map.shape.group - 1) * thread_map.shape.row * thread_map.count.row; -+ -+ advance_cluster = -+ stride * -+ thread_map.count.group * -+ thread_map.shape.group * -+ thread_map.count.row * -+ thread_map.shape.row; -+ -+ advance_tile = -+ stride * -+ thread_map.shape.group * -+ thread_map.shape.row * -+ thread_map.shape.cluster * -+ thread_map.shape.tile; -+ -+ return Status::kSuccess; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Status initialize(Index stride_, OutputTileThreadMapDesc thread_map) { -+ return initialize(LongIndex(stride_), thread_map); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorParams() { -+ initialize(LongIndex(0), OutputTileThreadMapDesc()); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorParams(Index stride, OutputTileThreadMapDesc thread_map) { -+ initialize(stride, thread_map); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorParams(LongIndex stride, OutputTileThreadMapDesc thread_map) { -+ initialize(stride, thread_map); -+ } -+}; -+ -+ -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Parameters struct for PredicatedTileIteratorDirect2dConv -+// -+ -+struct PredicatedTileIteratorDirect2dConvParams{ -+ using Index = int32_t; -+ using LongIndex = int64_t; -+ -+ // -+ // Data members -+ // -+ FastDivmod pq_divmod; -+ FastDivmod q_divmod; -+ -+ LongIndex stride; -+ LongIndex stride_n; -+ LongIndex stride_p; -+ -+ int N; -+ int P; -+ int Q; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Status initialize(LongIndex stride_, -+ cutlass::conv::Conv2dProblemSize const &problem_size, -+ MatrixCoord threadblock_output_shape) { -+ stride = stride_; // The stride per row of output tensor (bytes) -+ stride_n = problem_size.P * problem_size.Q; -+ stride_p = problem_size.Q ; -+ -+ N = problem_size.N; -+ P = problem_size.P; -+ Q = problem_size.Q; -+ -+ // Fastdivmod for output O, P, Q -+ if(threadblock_output_shape.row() != 0 && threadblock_output_shape.column() !=0 ){ -+ int tiles_p = -+ (problem_size.P + (threadblock_output_shape.row() - 1)) / (threadblock_output_shape.row()); -+ int tiles_q = (problem_size.Q + (threadblock_output_shape.column() - 1)) / -+ (threadblock_output_shape.column()); -+ -+ pq_divmod = FastDivmod(tiles_p * tiles_q); -+ q_divmod = FastDivmod(tiles_q); -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Status initialize( -+ Index stride_, -+ cutlass::conv::Conv2dProblemSize const &problem_size = cutlass::conv::Conv2dProblemSize(), -+ MatrixCoord threadblock_output_shape = MatrixCoord()) { -+ return initialize(LongIndex(stride_), problem_size, threadblock_output_shape); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorDirect2dConvParams() { initialize(LongIndex(0)); } -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorDirect2dConvParams(Index stride, -+ cutlass::conv::Conv2dProblemSize const &problem_size, -+ MatrixCoord threadblock_output_shape) { -+ initialize(stride, problem_size, threadblock_output_shape); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorDirect2dConvParams(LongIndex stride, -+ cutlass::conv::Conv2dProblemSize const &problem_size, -+ MatrixCoord threadblock_output_shape) { -+ initialize(stride, problem_size, threadblock_output_shape); -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+// InterleavedPredicatedTileIterator -+/////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Predicated tile access iterator descriptor object containing template dependent state -+struct InterleavedPredicatedTileIteratorDesc { -+ -+ int element_size_bits; -+ int elements_per_access; -+ int threadmap_warp_size; -+ layout::PitchLinearCoord threadmap_iterations; -+ layout::PitchLinearCoord threadmap_delta; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ InterleavedPredicatedTileIteratorDesc() { } -+ -+ CUTLASS_HOST_DEVICE -+ InterleavedPredicatedTileIteratorDesc( -+ int element_size_bits_, -+ int elements_per_access_, -+ int threadmap_warp_size_, -+ layout::PitchLinearCoord threadmap_iterations_, -+ layout::PitchLinearCoord threadmap_delta_ -+ ): -+ element_size_bits(element_size_bits_), -+ elements_per_access(elements_per_access_), -+ threadmap_warp_size(threadmap_warp_size_), -+ threadmap_iterations(threadmap_iterations_), -+ threadmap_delta(threadmap_delta_) { } -+}; -+ -+// -+// Parameters struct InterleavedPredicatedTileIterator -+// -+ -+struct InterleavedPredicatedTileIteratorParams { -+ -+ using Index = int32_t; -+ using LongIndex = int64_t; -+ -+ // -+ // Data members -+ // -+ -+ LongIndex stride; ///< stride in bytes between rows -+ LongIndex advance_row; ///< amount to add to move to the next 'row' position -+ LongIndex advance_column; ///< amount to add to move to the next 'column' position -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Status initialize(LongIndex stride_, InterleavedPredicatedTileIteratorDesc desc) { -+ -+ stride = stride_; -+ -+ advance_row = desc.threadmap_delta.contiguous() * desc.element_size_bits / 8; -+ -+ advance_column = stride_ - desc.threadmap_iterations.contiguous() * -+ desc.elements_per_access * -+ desc.element_size_bits * -+ desc.threadmap_warp_size / 8; -+ -+ return Status::kSuccess; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ InterleavedPredicatedTileIteratorParams() { -+ initialize(LongIndex(0), InterleavedPredicatedTileIteratorDesc()); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ InterleavedPredicatedTileIteratorParams(Index stride, InterleavedPredicatedTileIteratorDesc desc) { -+ initialize(stride, desc); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ InterleavedPredicatedTileIteratorParams(LongIndex stride, InterleavedPredicatedTileIteratorDesc desc) { -+ initialize(stride, desc); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Helper template to construct an OutputTileShapeDesc from a OutputTileThreadMap template. -+template -+CUTLASS_HOST_DEVICE -+InterleavedPredicatedTileIteratorDesc make_InterleavedPredicatedTileIteratorDesc() { -+ return InterleavedPredicatedTileIteratorDesc( -+ sizeof_bits::value, -+ ThreadMap::kElementsPerAccess, -+ ThreadMap::kWarpSize, -+ {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, -+ {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided} -+ ); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Helper template to construct an MakePredicatedTileIteratorDesc from a template -+// dependent state -+template -+ struct MakePredicatedTileIteratorDesc; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator for layout::RowMajor output data. -+template -+struct MakePredicatedTileIteratorDesc < -+ Element, layout::RowMajor, ThreadMap> { -+ -+ CUTLASS_HOST_DEVICE -+ OutputTileThreadMapDesc operator()() { -+ -+ return make_OutputTileThreadMapDesc(); -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator for layout::ColumnMajorInterleaved output data. -+template -+struct MakePredicatedTileIteratorDesc < -+ Element, layout::ColumnMajorInterleaved, ThreadMap> { -+ -+ CUTLASS_HOST_DEVICE -+ InterleavedPredicatedTileIteratorDesc operator()() { -+ -+ return make_InterleavedPredicatedTileIteratorDesc(); -+ } -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_predicates.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_predicates.h -new file mode 100644 -index 0000000..36202be ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_predicates.h -@@ -0,0 +1,309 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief PredicatedTileIteratorPredicates. -+ -+ PredicatedTileIteratorPredicates enables both upper and lower bounds for predicates. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator predicates used to bound computations in epilogue. -+/// -+/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator -+/// -+template < -+ typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) -+ typename Element_ ///< Element data type -+> -+class PredicatedTileIteratorPredicates { -+public: -+ using ThreadMap = ThreadMap_; -+ using Shape = typename ThreadMap::Shape; -+ -+ using Element = Element_; -+ -+ using Layout = layout::RowMajor; -+ using TensorRef = TensorRef; -+ using ConstTensorRef = typename TensorRef::ConstTensorRef; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using TensorCoord = MatrixCoord; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ static int const kThreads = ThreadMap::kThreads; -+ static int const kIterations = ThreadMap::Count::kTile; -+ -+ static_assert( ThreadMap::Iterations::kRow > 0,"ThreadMap::Iterations::kRow must be > 0"); -+ static_assert( ThreadMap::Iterations::kGroup > 0,"ThreadMap::Iterations::kGroup must be > 0"); -+ static_assert( ThreadMap::Iterations::kCluster > 0,"ThreadMap::Iterations::kCluster must be > 0"); -+ static_assert( ThreadMap::Iterations::kColumn > 0,"ThreadMap::Iterations::kColumn must be > 0"); -+ -+ /// Fragment object -+ using Fragment = Array< -+ Element, -+ ThreadMap::Iterations::kColumn * -+ ThreadMap::Iterations::kRow * -+ ThreadMap::Iterations::kGroup * -+ ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>; -+ -+ /// Memory access size -+ using AccessType = AlignedArray; -+ -+ // -+ // Parameters struct -+ // -+ -+ /// Uses a non-template class -+ struct Params : PredicatedTileIteratorParams { -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout): -+ PredicatedTileIteratorParams( -+ layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, -+ make_OutputTileThreadMapDesc() -+ ) -+ { -+ -+ } -+ }; -+ -+ /// Mask object -+ struct Mask { -+ -+ static int const kCount = ThreadMap::Iterations::kColumn; -+ -+ /// Predicate state -+ bool predicates[kCount]; -+ -+ // -+ // Mask -+ // -+ CUTLASS_HOST_DEVICE -+ Mask() { -+ enable(); -+ } -+ -+ ///< Efficiently disables all accesses guarded by mask -+ CUTLASS_HOST_DEVICE void clear() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ predicates[i] = false; -+ } -+ } -+ -+ ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask -+ CUTLASS_DEVICE void enable() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ predicates[i] = true; -+ } -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters structure containing reference and precomputed state. -+ PredicatedTileIteratorParams params_; -+ -+ /// Array of boolean values to contain steady-state predicates -+ Mask mask_; -+ -+ /// Extent of the matrix tile in rows -+ Index lower_extent_row_; -+ Index upper_extent_row_; -+ -+ /// A thread's starting row position (assuming steady-state predicates have been computed) -+ Index thread_start_row_; -+ -+ /// Internal state counter -+ int state_[3]; -+ -+ // -+ // Static asserts about internal strides -+ // -+ -+ static_assert(sizeof(lower_extent_row_) == 4, "Expected 32b extents"); -+ static_assert(sizeof(upper_extent_row_) == 4, "Expected 32b extents"); -+ static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); -+ static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides"); -+ -+private: -+ -+ // -+ // Methods -+ // -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ PredicatedTileIteratorPredicates( -+ PredicatedTileIteratorParams const & params, -+ TensorCoord lower_extent, -+ TensorCoord upper_extent, -+ int thread_idx, -+ TensorCoord threadblock_offset = TensorCoord() -+ ): -+ params_(params) -+ { -+ -+ TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; -+ -+ lower_extent_row_ = lower_extent.row(); -+ upper_extent_row_ = upper_extent.row(); -+ thread_start_row_ = thread_offset.row(); -+ -+ // Initialize predicates -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { -+ -+ mask_.predicates[c] = ((thread_offset.column() -+ + ThreadMap::Delta::kColumn * c) < upper_extent.column()) && -+ ((thread_offset.column() + ThreadMap::Delta::kColumn * c) >= lower_extent.column()); -+ } -+ -+ // Initialize internal state counter -+ state_[0] = state_[1] = state_[2] = 0; -+ } -+ -+ /// Advances to the next position to load or store -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorPredicates &operator++() { -+ -+ ++state_[0]; -+ thread_start_row_ += ThreadMap::Shape::kRow; -+ -+ if (state_[0] == ThreadMap::Count::kRow) { -+ -+ state_[0] = 0; -+ ++state_[1]; -+ -+ thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * -+ ThreadMap::Shape::kRow * ThreadMap::Count::kRow; -+ -+ if (state_[1] == ThreadMap::Count::kGroup) { -+ -+ state_[1] = 0; -+ ++state_[2]; -+ -+ thread_start_row_ += ThreadMap::Count::kGroup * -+ ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; -+ -+ if (state_[2] == ThreadMap::Count::kCluster) { -+ state_[2] = 0; -+ } -+ } -+ } -+ -+ return *this; -+ } -+ -+ ///< Efficiently disables all accesses guarded by mask -+ CUTLASS_DEVICE void clear_mask() { -+ mask_.clear(); -+ } -+ -+ ///< Efficiently enables all accesses guarded by mask -+ CUTLASS_DEVICE void enable_mask() { -+ mask_.enable(); -+ } -+ -+ ///< Gets the mask -+ CUTLASS_DEVICE void get_mask(Mask &mask) { -+ mask = mask_; -+ } -+ -+ ///< Sets the mask -+ CUTLASS_DEVICE void set_mask(Mask const &mask) { -+ mask_ = mask; -+ } -+ -+ ///< Gets lower_extent_row_ -+ CUTLASS_DEVICE Index get_lower_extent_row() { -+ return lower_extent_row_; -+ } -+ -+ ///< Gets upper_extent_row_ -+ CUTLASS_DEVICE Index get_upper_extent_row() { -+ return upper_extent_row_; -+ } -+ -+ ///< Gets thread_start_row_ -+ CUTLASS_DEVICE Index get_thread_start_row() { -+ return thread_start_row_; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h -new file mode 100644 -index 0000000..1e8c71e ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h -@@ -0,0 +1,479 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace epilogue { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator used to load and store output tile from global memory in epilogue. -+/// -+/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator -+/// -+template < -+ typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) -+ typename Element_ ///< Element data type -+> -+class PredicatedTileIteratorStridedDgrad { -+public: -+ using ThreadMap = ThreadMap_; -+ using Shape = typename ThreadMap::Shape; -+ -+ using Element = Element_; -+ -+ using Layout = layout::RowMajor; -+ using TensorRef = TensorRef; -+ using ConstTensorRef = typename TensorRef::ConstTensorRef; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using TensorCoord = MatrixCoord; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ static int const kThreads = ThreadMap::kThreads; -+ static int const kIterations = ThreadMap::Count::kTile; -+ -+ static_assert( ThreadMap::Iterations::kRow > 0,"ThreadMap::Iterations::kRow must be > 0"); -+ static_assert( ThreadMap::Iterations::kGroup > 0,"ThreadMap::Iterations::kGroup must be > 0"); -+ static_assert( ThreadMap::Iterations::kCluster > 0,"ThreadMap::Iterations::kCluster must be > 0"); -+ static_assert( ThreadMap::Iterations::kColumn > 0,"ThreadMap::Iterations::kColumn must be > 0"); -+ -+ /// Fragment object -+ using Fragment = Array< -+ Element, -+ ThreadMap::Iterations::kColumn * -+ ThreadMap::Iterations::kRow * -+ ThreadMap::Iterations::kGroup * -+ ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>; -+ -+ /// Memory access size -+ using AccessType = AlignedArray; -+ -+ // -+ // Parameters struct -+ // -+ -+ /// Uses a non-template class -+ struct Params : PredicatedTileIteratorParams { -+ -+ /// Convolution problem size -+ cutlass::conv::Conv2dProblemSize problem_size; -+ int tiled_rows_per_filter; -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout, cutlass::conv::Conv2dProblemSize problem_size_, int threadblock_row): -+ problem_size(problem_size_), -+ PredicatedTileIteratorParams( -+ layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, -+ make_OutputTileThreadMapDesc() -+ ) -+ { -+ -+ int tile_m_per_filter = strided_dgrad_tile_m_per_filter(problem_size, threadblock_row); -+ -+ tiled_rows_per_filter = tile_m_per_filter * threadblock_row; -+ } -+ }; -+ -+ /// Mask object -+ struct Mask { -+ -+ static int const kCount = ThreadMap::Iterations::kColumn; -+ -+ /// Predicate state -+ bool predicates[kCount]; -+ -+ // -+ // Mask -+ // -+ CUTLASS_HOST_DEVICE -+ Mask() { -+ enable(); -+ } -+ -+ ///< Efficiently disables all accesses guarded by mask -+ CUTLASS_HOST_DEVICE void clear() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ predicates[i] = false; -+ } -+ } -+ -+ ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask -+ CUTLASS_DEVICE void enable() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kCount; ++i) { -+ predicates[i] = true; -+ } -+ } -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters structure containing reference and precomputed state. -+ Params params_; -+ -+ /// Byte-level pointer -+ uint8_t *byte_pointer_; -+ -+ /// Array of boolean values to contain steady-state predicates -+ Mask mask_; -+ -+ /// Extent of the matrix tile in rows -+ Index extent_row_; -+ -+ /// Starting Dx h and w dimenstion for strided dgrad mapping -+ int start_h_, start_w_; -+ -+ /// Effective Dy P and Q dimenstions for strided dgrad mapping -+ int p_, q_; -+ -+ /// A thread's starting row position (assuming steady-state predicates have been computed) -+ Index thread_start_row_; -+ -+ /// A thread's starting column position (assuming steady-state predicates have been computed) -+ Index thread_start_column_; -+ -+ /// Internal state counter -+ int state_[3]; -+ -+ // -+ // Static asserts about internal strides -+ // -+ -+ static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); -+ static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); -+ static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides"); -+ -+private: -+ -+ // -+ // Methods -+ // -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ PredicatedTileIteratorStridedDgrad( -+ Params const & params, -+ Element *pointer, -+ TensorCoord extent, -+ int thread_idx, -+ FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod, -+ int start_r, int start_s, -+ TensorCoord threadblock_offset = TensorCoord() -+ ): -+ params_(params) -+ { -+ -+ TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; -+ -+ int r = start_r; -+ int s = start_s; -+ -+ if (params_.problem_size.mode == cutlass::conv::Mode::kConvolution) { -+ r = (params_.problem_size.R - 1 - r); -+ s = (params_.problem_size.S - 1 - s); -+ } -+ -+ // compute starting coordinates in Dx start_h_ and start_w_ -+ strided_dgrad_starting_coords( -+ params_.problem_size, -+ stride_h_divmod, stride_w_divmod, -+ r, s, -+ start_h_, start_w_); -+ -+ p_ = (params_.problem_size.H - start_h_ + params_.problem_size.stride_h - 1) / params_.problem_size.stride_h; -+ q_ = (params_.problem_size.W - start_w_ + params_.problem_size.stride_w - 1) / params_.problem_size.stride_w; -+ -+ extent_row_ = extent.row(); -+ thread_start_row_ = thread_offset.row(); -+ thread_start_column_ = thread_offset.column(); -+ -+ // Initialize predicates -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { -+ -+ mask_.predicates[c] = ((thread_offset.column() -+ + ThreadMap::Delta::kColumn * c) < extent.column()); -+ } -+ -+ // Null pointer performs no accesses -+ if (!pointer) { -+ mask_.clear(); -+ } -+ -+ // Initialize pointer -+ byte_pointer_ = reinterpret_cast(pointer); -+ -+ // Initialize internal state counter -+ state_[0] = state_[1] = state_[2] = 0; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ byte_pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment &frag, int64_t byte_offset) { -+ -+ uint8_t *byte_pointer = byte_pointer_; -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ -+ int frag_row_idx = -+ (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ int row_offset = row * ThreadMap::Delta::kRow -+ + group * ThreadMap::Delta::kGroup -+ + cluster * ThreadMap::Delta::kCluster; -+ -+ // remapping rows to find the mapped_row_offset -+ int npq_offset = (row_offset + thread_start_row_) % params_.tiled_rows_per_filter; -+ -+ // (STEP 4.a) [order NHW rows to be loaded and stored in output Dx NHWxC layout] -+ int n = npq_offset / (p_ * q_); -+ int residual = npq_offset % (p_ * q_); -+ int p = residual / q_; -+ int q = residual % q_; -+ -+ int mapped_row_offset = n * (params_.problem_size.H * params_.problem_size.W) + -+ (start_h_ + p * params_.problem_size.stride_h) * params_.problem_size.W + -+ (start_w_ + q * params_.problem_size.stride_w); -+ bool row_guard = mapped_row_offset < extent_row_; -+ -+ int64_t row_byte_offset = mapped_row_offset * params_.stride; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { -+ -+ int64_t column_byte_offset = (thread_start_column_ + column * ThreadMap::Delta::kColumn) * (sizeof_bits::value / 8); -+ -+ bool guard = row_guard && mask_.predicates[column]; -+ -+ cutlass::arch::global_load< -+ AccessType, -+ sizeof(AccessType) -+ >( -+ frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + -+ column], -+ (void *)(byte_pointer + row_byte_offset + column_byte_offset + byte_offset), -+ guard); -+ } -+ } -+ } -+ } -+ } -+ -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ -+ load_with_byte_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const &frag, int64_t byte_offset) { -+ uint8_t *byte_pointer = byte_pointer_; -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ -+ int frag_row_idx = -+ (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ int row_offset = row * ThreadMap::Delta::kRow -+ + group * ThreadMap::Delta::kGroup -+ + cluster * ThreadMap::Delta::kCluster; -+ -+ // remapping rows to find the mapped_row_offset -+ int npq_offset = (row_offset + thread_start_row_) % params_.tiled_rows_per_filter; -+ -+ // (STEP 4.a) [order NHW rows to be loaded and stored in output Dx NHWxC layout] -+ int n = npq_offset / (p_ * q_); -+ int residual = npq_offset % (p_ * q_); -+ int p = residual / q_; -+ int q = residual % q_; -+ -+ int mapped_row_offset = n * (params_.problem_size.H * params_.problem_size.W) + -+ (start_h_ + p * params_.problem_size.stride_h) * params_.problem_size.W + -+ (start_w_ + q * params_.problem_size.stride_w); -+ bool row_guard = mapped_row_offset < extent_row_; -+ -+ int64_t row_byte_offset = mapped_row_offset * params_.stride; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { -+ -+ int64_t column_byte_offset = (thread_start_column_ + column * ThreadMap::Delta::kColumn) * (sizeof_bits::value / 8); -+ -+ bool guard = row_guard && mask_.predicates[column]; -+ -+ cutlass::arch::global_store( -+ frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], -+ (void *)(byte_pointer + row_byte_offset + column_byte_offset + byte_offset), -+ guard); -+ } -+ } -+ } -+ } -+ } -+ -+ -+ /// Stores a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ -+ store_with_byte_offset(frag, 0); -+ } -+ -+ /// Advances to the next position to load or store -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorStridedDgrad &operator++() { -+ -+ ++state_[0]; -+ -+ thread_start_row_ += ThreadMap::Shape::kRow; -+ -+ if (state_[0] == ThreadMap::Count::kRow) { -+ -+ state_[0] = 0; -+ ++state_[1]; -+ -+ thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * -+ ThreadMap::Shape::kRow * ThreadMap::Count::kRow; -+ -+ if (state_[1] == ThreadMap::Count::kGroup) { -+ -+ state_[1] = 0; -+ ++state_[2]; -+ -+ thread_start_row_ += ThreadMap::Count::kGroup * -+ ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; -+ -+ if (state_[2] == ThreadMap::Count::kCluster) { -+ state_[2] = 0; -+ } -+ } -+ } -+ -+ return *this; -+ } -+ -+ ///< Efficiently disables all accesses guarded by mask -+ CUTLASS_DEVICE void clear_mask() { -+ mask_.clear(); -+ } -+ -+ ///< Efficiently enables all accesses guarded by mask -+ CUTLASS_DEVICE void enable_mask() { -+ mask_.enable(); -+ } -+ -+ ///< Sets the mask -+ CUTLASS_DEVICE void get_mask(Mask &mask) { -+ mask = mask_; -+ } -+ -+ ///< Sets the mask -+ CUTLASS_DEVICE void set_mask(Mask const &mask) { -+ mask_ = mask; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator.h -new file mode 100644 -index 0000000..197a4df ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator.h -@@ -0,0 +1,223 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/tensor_ref.h" -+ -+#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator used to load output tile from shared memory in epilogue. -+/// -+/// Satisfies: ReadableTileIterator -+/// -+template < -+ typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) -+ typename Element_, ///< Element data type -+ int MaxAlignment = ThreadMap_::kElementsPerAccess * sizeof_bits::value / 8 -+> -+class SharedLoadIterator { -+public: -+ using ThreadMap = ThreadMap_; -+ using Shape = typename ThreadMap::TileShape; -+ -+ using Element = Element_; -+ -+ using Layout = layout::RowMajor; -+ using TensorRef = TensorRef; -+ using ConstTensorRef = typename TensorRef::ConstTensorRef; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using TensorCoord = MatrixCoord; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ -+ static int const kMinAlignment = ThreadMap_::kElementsPerAccess * sizeof_bits::value / 8; -+ -+ static int const kAlignment = (MaxAlignment < kMinAlignment ? MaxAlignment : kMinAlignment); -+ -+ static int const kThreads = ThreadMap::kThreads; -+ -+ /// Fragment object -+ using Fragment = Array< -+ Element, -+ ThreadMap::Iterations::kColumn * -+ ThreadMap::Iterations::kRow * -+ ThreadMap::Iterations::kGroup * -+ ThreadMap::Iterations::kCluster * -+ ThreadMap::kElementsPerAccess>; -+ -+ /// Memory access size -+ using AccessType = AlignedArray< -+ Element, -+ ThreadMap::kElementsPerAccess, -+ kAlignment>; -+ -+ /// Vector type used for SMEM loads -+ using LoadType = AlignedArray< -+ Element, -+ const_min(128 / sizeof_bits::value, ThreadMap::kElementsPerAccess), -+ const_min(16, kAlignment) -+ >; -+ -+ static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Byte-level pointer -+ uint8_t *byte_pointer_; -+ -+ /// Stride along adjacent rows -+ int stride_; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ SharedLoadIterator( -+ TensorRef ref, -+ int thread_idx -+ ): -+ byte_pointer_(reinterpret_cast(ref.data())), -+ stride_((ref.stride(0) * sizeof_bits::value) / 8) { -+ -+ TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); -+ -+ // Initialize pointer -+ byte_pointer_ += -+ thread_offset.row() * stride_ + -+ thread_offset.column() * sizeof(AccessType) / kElementsPerAccess; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ byte_pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &offset) { -+ byte_pointer_ += -+ offset.row() * Shape::kRow * stride_ + -+ offset.column() * Shape::kColumn * sizeof_bits::value / 8; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ -+ uint8_t const *byte_pointer = byte_pointer_ + -+ row * ThreadMap::Delta::kRow * stride_ + -+ group * ThreadMap::Delta::kGroup* stride_ + -+ cluster * ThreadMap::Delta::kCluster * stride_ + -+ pointer_offset * sizeof_bits::value / 8; -+ -+ int frag_row_idx = -+ (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ LoadType *frag_ptr = reinterpret_cast(&frag); -+ LoadType const *memory_pointer = reinterpret_cast(byte_pointer); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { -+ -+ int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kLoadsPerAccess; ++v) { -+ frag_ptr[frag_idx * kLoadsPerAccess + v] = -+ memory_pointer[(column * ThreadMap::Delta::kColumn / kElementsPerAccess) * kLoadsPerAccess + v]; -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void set_smem_base_address(Index address) { -+ } -+ -+ /// Loads a fragment -+ CUTLASS_DEVICE -+ void load(Fragment &frag) const { -+ -+ load_with_pointer_offset(frag, 0); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator_mixed.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator_mixed.h -new file mode 100644 -index 0000000..a471137 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator_mixed.h -@@ -0,0 +1,585 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops optimized for mixed-precision. -+ -+ This assumes the shared memory tile is in a permuted layout which avoids bank conflicts on loading. -+ -+ When the fragment is loaded into registers, it matches the row-major thread map assumed by -+ the predicated tile iterator writing to global memory. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/tensor_ref.h" -+ -+#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator used to load output tile from shared memory in epilogue. -+/// -+/// Satisfies: ReadableTileIterator -+/// -+template < -+ typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) -+ typename Element_, ///< Accumulator data type -+ int ElementSizeBits_, ///< Size of accumulator in bits -+ int OutputSizeBits_, ///< Size of output element in bits -+ int ElementsPerAccess, ///< Vector length of output vector -+ int ContiguousLanes ///< Number of lanes in the warp writing to contiguous elements -+ /// in the global memory tensor -+> -+class SharedLoadIteratorMixed; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator used to load output tile from shared memory in epilogue. -+/// -+/// Satisfies: ReadableTileIterator -+/// -+template < -+ typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) -+ typename Element_ ///< Accumulator data type -+> -+class SharedLoadIteratorMixed { -+public: -+ using ThreadMap = ThreadMap_; -+ using Shape = typename ThreadMap::Shape; -+ -+ using Element = Element_; -+ -+ using Layout = layout::RowMajor; -+ using TensorRef = TensorRef; -+ using ConstTensorRef = typename TensorRef::ConstTensorRef; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using TensorCoord = MatrixCoord; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ -+ static int const kAlignment = ThreadMap::kElementsPerAccess * sizeof_bits::value / 8; -+ -+ static int const kThreads = ThreadMap::kThreads; -+ -+ /// Fragment object -+ using Fragment = Array< -+ Element, -+ ThreadMap::Iterations::kColumn * -+ ThreadMap::Iterations::kRow * -+ ThreadMap::Iterations::kGroup * -+ ThreadMap::Iterations::kCluster * -+ ThreadMap::kElementsPerAccess>; -+ -+ /// Memory access size -+ using AccessType = AlignedArray< -+ Element, -+ ThreadMap::kElementsPerAccess, -+ kAlignment>; -+ -+ /// Vector type used for SMEM loads -+ using LoadType = AlignedArray< -+ Element, -+ const_min(128 / sizeof_bits::value, ThreadMap::kElementsPerAccess), -+ const_min(16, kAlignment) -+ >; -+ -+ static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Byte-level pointer -+ LoadType const *pointers_[kLoadsPerAccess]; -+ -+ /// Stride along adjacent rows in units of LoadType -+ int stride_; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ SharedLoadIteratorMixed( -+ TensorRef ref, -+ int thread_idx -+ ): -+ stride_((ref.stride(0) / LoadType::kElements)) { -+ -+ TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); -+ -+ // Initialize pointers -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kLoadsPerAccess; ++i) { -+ pointers_[i] = reinterpret_cast(ref.data()); -+ -+ int col_idx = (thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess; -+ int bank_offset = (col_idx * int(sizeof(LoadType)) / 128) % kLoadsPerAccess; -+ -+ col_idx += (bank_offset + i) % kLoadsPerAccess; -+ -+ pointers_[i] += thread_offset.row() * stride_ + col_idx; -+ } -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kLoadsPerAccess; ++i) { -+ pointers_[i] += pointer_offset / LoadType::kElements; -+ } -+ } -+ -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &offset) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kLoadsPerAccess; ++i) { -+ pointers_[i] += -+ offset.row() * Shape::kRow * stride_ + -+ offset.column() * Shape::kColumn / LoadType::kElements; -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ -+ int row_ptr_offset = -+ row * ThreadMap::Delta::kRow * stride_ + -+ group * ThreadMap::Delta::kGroup* stride_ + -+ cluster * ThreadMap::Delta::kCluster * stride_ + -+ pointer_offset / LoadType::kElements; -+ -+ int frag_row_idx = (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ LoadType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { -+ -+ int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kLoadsPerAccess; ++v) { -+ -+ int vector_idx = (column * ThreadMap::Delta::kColumn / kElementsPerAccess * kLoadsPerAccess); -+ -+ LoadType const *memory_pointer = pointers_[v] + row_ptr_offset; -+ -+ frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[vector_idx]; -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ /// Set base smem address -+ CUTLASS_DEVICE -+ void set_smem_base_address(Index address) {} -+ -+ /// Loads a fragment -+ CUTLASS_DEVICE -+ void load(Fragment &frag) const { -+ -+ load_with_pointer_offset(frag, 0); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for int32_t x 16 => int8_t/int4b_t x 16 -+template < -+ typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) -+ int OutputSizeBits_ ///< Size of output element in bits -+> -+class SharedLoadIteratorMixed { -+public: -+ using ThreadMap = ThreadMap_; -+ using Shape = typename ThreadMap::Shape; -+ -+ using Element = int32_t; -+ -+ using Layout = layout::RowMajor; -+ using TensorRef = TensorRef; -+ using ConstTensorRef = typename TensorRef::ConstTensorRef; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using TensorCoord = MatrixCoord; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ -+ static int const kAlignment = 16; -+ -+ static int const kThreads = ThreadMap::kThreads; -+ -+ /// Fragment object -+ using Fragment = Array< -+ Element, -+ ThreadMap::Iterations::kColumn * -+ ThreadMap::Iterations::kRow * -+ ThreadMap::Iterations::kGroup * -+ ThreadMap::Iterations::kCluster * -+ ThreadMap::kElementsPerAccess>; -+ -+ /// Memory access size -+ using AccessType = AlignedArray< -+ Element, -+ 16, -+ kAlignment>; -+ -+ /// Vector type used for SMEM loads -+ using LoadType = AlignedArray< -+ Element, -+ 4, -+ 16 -+ >; -+ -+ static int const kLoadsPerAccess = 4; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Byte-level pointer -+ LoadType const *pointers_[kLoadsPerAccess]; -+ -+ /// Stride along adjacent rows in units of LoadType -+ int stride_; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ SharedLoadIteratorMixed( -+ TensorRef ref, -+ int thread_idx -+ ): -+ stride_((ref.stride(0) / LoadType::kElements)) { -+ -+ TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); -+ -+ // Initialize pointers -+ LoadType const *base_ptr = reinterpret_cast(ref.data()) + thread_offset.row() * stride_; -+ -+ int lane_col_idx = thread_offset.column() / 16; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kLoadsPerAccess; ++i) { -+ int lane_offset = (lane_col_idx % 2) * 4 | ((lane_col_idx / 2) * 8) | ((lane_col_idx / 2) ^ i); -+ -+ pointers_[i] = base_ptr + lane_offset; -+ } -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kLoadsPerAccess; ++i) { -+ pointers_[i] += pointer_offset / LoadType::kElements; -+ } -+ } -+ -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &offset) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kLoadsPerAccess; ++i) { -+ pointers_[i] += -+ offset.row() * Shape::kRow * stride_ + -+ offset.column() * Shape::kColumn / LoadType::kElements; -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ -+ int row_ptr_offset = -+ row * ThreadMap::Delta::kRow * stride_ + -+ group * ThreadMap::Delta::kGroup* stride_ + -+ cluster * ThreadMap::Delta::kCluster * stride_ + -+ pointer_offset / LoadType::kElements; -+ -+ int frag_row_idx = (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ LoadType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { -+ -+ int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kLoadsPerAccess; ++v) { -+ -+ LoadType const *memory_pointer = pointers_[v]; -+ -+ frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[row_ptr_offset]; -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ /// Set base smem address -+ CUTLASS_DEVICE -+ void set_smem_base_address(Index address) {} -+ -+ /// Loads a fragment -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ -+ load_with_pointer_offset(frag, 0); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for int32_t x 8 => int8_t/int4b_t x 8 -+template < -+ typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) -+ int OutputSizeBits_ -+> -+class SharedLoadIteratorMixed { -+public: -+ using ThreadMap = ThreadMap_; -+ using Shape = typename ThreadMap::Shape; -+ -+ using Element = int32_t; -+ -+ using Layout = layout::RowMajor; -+ using TensorRef = TensorRef; -+ using ConstTensorRef = typename TensorRef::ConstTensorRef; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using TensorCoord = MatrixCoord; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ -+ static int const kAlignment = 8; -+ -+ static int const kThreads = ThreadMap::kThreads; -+ -+ /// Fragment object -+ using Fragment = Array< -+ Element, -+ ThreadMap::Iterations::kColumn * -+ ThreadMap::Iterations::kRow * -+ ThreadMap::Iterations::kGroup * -+ ThreadMap::Iterations::kCluster * -+ ThreadMap::kElementsPerAccess>; -+ -+ /// Memory access size -+ using AccessType = AlignedArray< -+ Element, -+ 8, -+ kAlignment>; -+ -+ /// Vector type used for SMEM loads -+ using LoadType = AlignedArray< -+ Element, -+ 4, -+ 16 -+ >; -+ -+ static int const kLoadsPerAccess = 2; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Byte-level pointer -+ LoadType const *pointers_[kLoadsPerAccess]; -+ -+ /// Stride along adjacent rows in units of LoadType -+ int stride_; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ SharedLoadIteratorMixed( -+ TensorRef ref, -+ int thread_idx -+ ): -+ stride_((ref.stride(0) / LoadType::kElements)) { -+ -+ TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); -+ -+ // Initialize pointers -+ LoadType const *base_ptr = reinterpret_cast(ref.data()) + thread_offset.row() * stride_; -+ -+ int lane_col_idx = thread_offset.column() / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kLoadsPerAccess; ++i) { -+ int lane_offset = (lane_col_idx % 8) * 2 | ((lane_col_idx / 4) ^ i); -+ -+ pointers_[i] = base_ptr + lane_offset; -+ } -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kLoadsPerAccess; ++i) { -+ pointers_[i] += pointer_offset / LoadType::kElements; -+ } -+ } -+ -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &offset) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kLoadsPerAccess; ++i) { -+ pointers_[i] += -+ offset.row() * Shape::kRow * stride_ + -+ offset.column() * Shape::kColumn / LoadType::kElements; -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { -+ -+ int row_ptr_offset = -+ row * ThreadMap::Delta::kRow * stride_ + -+ group * ThreadMap::Delta::kGroup* stride_ + -+ cluster * ThreadMap::Delta::kCluster * stride_ + -+ pointer_offset / LoadType::kElements; -+ -+ int frag_row_idx = (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); -+ -+ LoadType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { -+ -+ int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kLoadsPerAccess; ++v) { -+ -+ LoadType const *memory_pointer = pointers_[v]; -+ -+ frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[row_ptr_offset]; -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ /// Set base smem address -+ CUTLASS_DEVICE -+ void set_smem_base_address(Index address) {} -+ -+ /// Loads a fragment -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ -+ load_with_pointer_offset(frag, 0); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator_pitch_liner.h b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator_pitch_liner.h -new file mode 100644 -index 0000000..df8676e ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator_pitch_liner.h -@@ -0,0 +1,194 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. -+ -+ This assumes the shared memory tile is in a permuted layout which avoids bank conflicts on loading. -+ -+ When the fragment is loaded into registers, it matches the row-major thread map assumed by -+ the predicated tile iterator writing to global memory. -+ -+ The epilogue rearranges the result of a matrix product through shared memory to match canonical -+ tensor layouts in global memory. Epilogues support conversion and reduction operations. -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_ref.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator used to load output tile from shared memory in epilogue. -+/// -+/// Satisfies: ReadableTileIterator -+/// -+template ::value / 8> -+class SharedLoadIteratorPitchLiner { -+ public: -+ using ThreadMap = ThreadMap_; -+ using Element = Element_; -+ -+ using Layout = layout::RowMajor; -+ using TensorRef = TensorRef; -+ using ConstTensorRef = typename TensorRef::ConstTensorRef; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using TensorCoord = MatrixCoord; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ -+ static int const kMinAlignment = -+ ThreadMap_::kElementsPerAccess * sizeof_bits::value / 8; -+ -+ static int const kAlignment = (MaxAlignment < kMinAlignment ? MaxAlignment : kMinAlignment); -+ -+ static int const kThreads = ThreadMap::kThreads; -+ -+ /// Fragment object -+ using Fragment = Array; -+ -+ /// Memory access size -+ using AccessType = AlignedArray; -+ -+ /// Vector type used for SMEM loads -+ using LoadType = -+ AlignedArray::value, ThreadMap::kElementsPerAccess), -+ const_min(16, kAlignment)>; -+ -+ static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Byte-level pointer -+ uint8_t *byte_pointer_; -+ -+ /// Stride along adjacent rows -+ int stride_; -+ -+ /// Base address offset -+ Index base_smem_address_; -+ -+ public: -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ SharedLoadIteratorPitchLiner(TensorRef ref, int thread_idx) -+ : byte_pointer_(reinterpret_cast(ref.data())), -+ stride_((ref.stride(0) * sizeof_bits::value) / 8), -+ base_smem_address_(0) { -+ TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); -+ -+ // Initialize pointer -+ // thread_offset.row() is contiguous dim -+ // thread_offset.column() is stride dim -+ byte_pointer_ += thread_offset.row() * sizeof(AccessType) / kElementsPerAccess+ -+ thread_offset.column() * stride_ ; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ byte_pointer_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &offset) { -+ byte_pointer_ += -+ offset.row() * ThreadMap::StorageShape::kContiguous * sizeof(AccessType) / kElementsPerAccess + -+ offset.column() * ThreadMap::StorageShape::kStrided * stride_; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ uint8_t const *byte_pointer = -+ byte_pointer_ + s * ThreadMap::Delta::kStrided * stride_ + -+ c * ThreadMap::Delta::kContiguous * ThreadMap::kElementsPerAccess * -+ sizeof_bits::value / 8 + -+ pointer_offset * sizeof_bits::value / 8 + base_smem_address_; -+ -+ int frag_base_idx = s * ThreadMap::Iterations::kContiguous + c; -+ -+ LoadType *frag_ptr = reinterpret_cast(&frag); -+ -+ LoadType const *memory_pointer = reinterpret_cast(byte_pointer); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kLoadsPerAccess; ++v) { -+ frag_ptr[frag_base_idx * kLoadsPerAccess + v] = memory_pointer[v]; -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void set_smem_base_address(Index address) { base_smem_address_ = address; } -+ -+ /// Loads a fragment -+ CUTLASS_DEVICE -+ void load(Fragment &frag) const { load_with_pointer_offset(frag, 0); } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h b/3rdparty/cutlass/include/cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h -new file mode 100644 -index 0000000..6dd04ed ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h -@@ -0,0 +1,187 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief This defines a "fragment" iterator for visiting the fragments of an accumulator tile -+ that participate in one warp-level store operation. -+ -+ Typically, the accumulator tile is the largest single block of register-backed storage -+ within the kernel. Storing it to memory is best accomplished by partitioning it into -+ smaller tiles and storing these sequentially. -+ -+ Round trips through shared memory during the Epilogue phase require partitioning, as -+ shared memory capacity is typically insufficient for a threadblock's total accumulator -+ size. -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/epilogue/warp/tensor_op_policy.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace warp { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// -+template < -+ typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) -+ typename OperatorShape, ///< matrix multiply operation shape (concept: gemm::GemmShape) -+ typename OperatorElementC, ///< matrix multiply operation data type (concept: data type) -+ typename OperatorFragmentC, ///< matrix multiply operation fragment (concept: Array) -+ typename Layout ///< target shared memory layout -+> -+class FragmentIteratorComplexTensorOp; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Partial specialization for row-major shared memory -+template < -+ typename WarpShape_, ///< shape of the warp-level GEMM tile -+ typename OperatorShape_, ///< underlying real-valued matrix multiply operation shape (concept: gemm::GemmShape) -+ typename OperatorElementC_, ///< underlying real-valued matrix multiply operation data type -+ typename OperatorFragmentC_ ///< underlying real-valued matrix multiply operation fragment (concept: Array) -+> -+class FragmentIteratorComplexTensorOp { -+public: -+ -+ using WarpShape = WarpShape_; -+ using OperatorShape = OperatorShape_; -+ using OperatorElementC = OperatorElementC_; -+ using OperatorFragmentC = OperatorFragmentC_; -+ using Layout = layout::RowMajor; -+ -+ using Policy = TensorOpPolicy; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = Array< -+ complex, -+ Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>; -+ -+ static int const kRealIndex = 0; -+ -+ /// Offset into the accumulator fragment -+ static int const kImaginaryIndex = -+ OperatorFragmentC::kElements * Policy::OperatorCount::kRow * Policy::OperatorCount::kColumn; -+ -+ /// This is the complete warp-level accumulator tile. -+ using AccumulatorTile = Array; -+ -+ /// This is the complete warp-level accumulator tile. -+ using OutputAccumulatorTile = Array, kImaginaryIndex>; -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = Policy::kIterations; -+ -+private: -+ -+ /// Internal access type -+ using AccessType = Array; -+ -+ using FragmentAccessType = Array, Policy::kElementsPerAccess>; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Accumulator tile -+ AccessType const *accumulators_; -+ -+ /// Internal index -+ int index_; -+ -+public: -+ -+ /// Constructs an iterator -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorComplexTensorOp(AccumulatorTile const &accum): -+ accumulators_(reinterpret_cast(&accum)), -+ index_(0) { -+ -+ } -+ -+ /// Increments -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorComplexTensorOp &operator++() { -+ ++index_; -+ return *this; -+ } -+ -+ /// Decrements -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorComplexTensorOp &operator--() { -+ --index_; -+ return *this; -+ } -+ -+ /// Loads a fragment from the referenced part of the accumulator tile -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag, int index_offset = 0) const { -+ -+ int index = index_ + index_offset; -+ -+ FragmentAccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { -+ -+ int accumulator_access_offset = -+ index + n * Policy::kAccumulatorColumnStride / Policy::kElementsPerAccess; -+ -+ auto const & real_accum_array = accumulators_[accumulator_access_offset + kRealIndex]; -+ auto const & imag_accum_array = accumulators_[accumulator_access_offset + kImaginaryIndex / Policy::kElementsPerAccess]; -+ -+ // Pack real and imaginary parts into a structure. This is likely to result in MOVs -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Policy::kElementsPerAccess; ++i) { -+ -+ frag_ptr[n][i].real() = real_accum_array[i]; -+ frag_ptr[n][i].imag() = imag_accum_array[i]; -+ } -+ } -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/warp/fragment_iterator_gaussian_complex_tensor_op.h b/3rdparty/cutlass/include/cutlass/epilogue/warp/fragment_iterator_gaussian_complex_tensor_op.h -new file mode 100644 -index 0000000..f55c4bd ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/warp/fragment_iterator_gaussian_complex_tensor_op.h -@@ -0,0 +1,194 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief This defines a "fragment" iterator for visiting the fragments of an accumulator tile -+ that participate in one warp-level store operation. -+ -+ Typically, the accumulator tile is the largest single block of register-backed storage -+ within the kernel. Storing it to memory is best accomplished by partitioning it into -+ smaller tiles and storing these sequentially. -+ -+ Round trips through shared memory during the Epilogue phase require partitioning, as -+ shared memory capacity is typically insufficient for a threadblock's total accumulator -+ size. -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/epilogue/warp/tensor_op_policy.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace warp { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// -+template < -+ typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) -+ typename OperatorShape, ///< matrix multiply operation shape (concept: gemm::GemmShape) -+ typename OperatorElementC, ///< matrix multiply operation data type (concept: data type) -+ typename OperatorFragmentC, ///< matrix multiply operation fragment (concept: Array) -+ typename Layout ///< target shared memory layout -+> -+class FragmentIteratorGaussianComplexTensorOp; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Partial specialization for row-major shared memory -+template < -+ typename WarpShape_, ///< shape of the warp-level GEMM tile -+ typename OperatorShape_, ///< underlying real-valued matrix multiply operation shape (concept: gemm::GemmShape) -+ typename OperatorElementC_, ///< underlying real-valued matrix multiply operation data type -+ typename OperatorFragmentC_ ///< underlying real-valued matrix multiply operation fragment (concept: Array) -+> -+class FragmentIteratorGaussianComplexTensorOp { -+public: -+ -+ using WarpShape = WarpShape_; -+ using OperatorShape = OperatorShape_; -+ using OperatorElementC = OperatorElementC_; -+ using OperatorFragmentC = OperatorFragmentC_; -+ using Layout = layout::RowMajor; -+ -+ using Policy = TensorOpPolicy; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = Array< -+ complex, -+ Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>; -+ -+ /// Size of one part of accumulator of 3-part accumulator in units of number of OperatorElementC -+ static int const kElementsAccumulatorPerPart = -+ OperatorFragmentC::kElements * Policy::OperatorCount::kRow * Policy::OperatorCount::kColumn; -+ -+ /// Offset into the accumulator fragment part 1 -+ static int const kPart1Index = kElementsAccumulatorPerPart * 0; -+ -+ /// Offset into the accumulator fragment part 2 -+ static int const kPart2Index = kElementsAccumulatorPerPart * 1; -+ -+ /// Offset into the accumulator fragment part 3 -+ static int const kPart3Index = kElementsAccumulatorPerPart * 2; -+ -+ /// This is the complete warp-level accumulator tile holding part1, part2, and part3 -+ using AccumulatorTile = Array; -+ -+ /// This is the complete warp-level accumulator tile holding final output of complex type -+ using OutputAccumulatorTile = Array, kElementsAccumulatorPerPart>; -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = Policy::kIterations; -+ -+private: -+ -+ /// Internal access type -+ using AccessType = Array; -+ -+ using FragmentAccessType = Array, Policy::kElementsPerAccess>; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Accumulator tile -+ AccessType const *accumulators_; -+ -+ /// Internal index -+ int index_; -+ -+public: -+ -+ /// Constructs an iterator -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorGaussianComplexTensorOp(AccumulatorTile const &accum): -+ accumulators_(reinterpret_cast(&accum)), -+ index_(0) { -+ } -+ -+ /// Increments -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorGaussianComplexTensorOp &operator++() { -+ ++index_; -+ return *this; -+ } -+ -+ /// Decrements -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorGaussianComplexTensorOp &operator--() { -+ --index_; -+ return *this; -+ } -+ -+ /// Loads a fragment from the referenced part of the accumulator tile -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag, int index_offset = 0) const { -+ -+ int index = index_ + index_offset; -+ -+ FragmentAccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { -+ -+ int accumulator_access_offset = -+ index + n * Policy::kAccumulatorColumnStride / Policy::kElementsPerAccess; -+ -+ auto const & part1_accum_array = accumulators_[accumulator_access_offset + kPart1Index]; -+ auto const & part2_accum_array = accumulators_[accumulator_access_offset + kPart2Index / Policy::kElementsPerAccess]; -+ auto const & part3_accum_array = accumulators_[accumulator_access_offset + kPart3Index / Policy::kElementsPerAccess]; -+ -+ // Pack parts 1, 2, and 3 into a structure. This is likely to result in MOVs -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Policy::kElementsPerAccess; ++i) { -+ -+ frag_ptr[n][i].real() = part1_accum_array[i] - part3_accum_array[i]; -+ frag_ptr[n][i].imag() = part1_accum_array[i] + part2_accum_array[i]; -+ } -+ } -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/warp/fragment_iterator_simt.h b/3rdparty/cutlass/include/cutlass/epilogue/warp/fragment_iterator_simt.h -new file mode 100644 -index 0000000..b181c81 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/warp/fragment_iterator_simt.h -@@ -0,0 +1,164 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief This defines a "fragment" iterator for visiting the fragments of an accumulator tile -+ that participate in one warp-level store operation. -+ -+ Typically, the accumulator tile is the largest single block of register-backed storage -+ within the kernel. Storing it to memory is best accomplished by partitioning it into -+ smaller tiles and storing these sequentially. -+ -+ Round trips through shared memory during the Epilogue phase require partitioning, as -+ shared memory capacity is typically insufficient for a threadblock's total accumulator -+ size. -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/epilogue/warp/simt_policy.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fragment iterator for SIMT accumulator arrangements -+template < -+ typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) -+ typename Operator, ///< matrix multiply operation (concept: arch::Mma) -+ typename Layout, ///< target shared memory layout -+ typename MmaSimtPolicy ///< policy defining lane arrangement (concept: MmaSimtPolicy) -+> -+class FragmentIteratorSimt; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for row-major shared memory -+template < -+ typename WarpShape_, ///< shape of the warp-level GEMM tile -+ typename Operator_ , ///< matrix multiply operator (concept: arch::Mma) -+ typename MmaSimtPolicy_ ///< policy defining lane arrangement (concept: MmaSimtPolicy) -+> -+class FragmentIteratorSimt { -+public: -+ -+ using WarpShape = WarpShape_; -+ using Operator = Operator_; -+ using Layout = layout::RowMajor; -+ -+ /// Policy for warp-level epilogue components -+ using Policy = SimtPolicy; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = Array< -+ typename Operator::ElementC, -+ Policy::kElementsPerIteration>; -+ -+ /// This is the complete warp-level accumulator tile. -+ using AccumulatorTile = Array< -+ typename Operator::ElementC, -+ Policy::kAccumulatorElementCount>; -+ -+ using OutputAccumulatorTile = AccumulatorTile; -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = Policy::kIterations; -+ -+private: -+ -+ /// Internal access type -+ using AccessType = Array; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Accumulator tile -+ AccessType const *accumulators_; -+ -+ /// Internal index -+ int index_; -+ -+public: -+ -+ /// Constructs an iterator -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorSimt(AccumulatorTile const &accum): -+ accumulators_(reinterpret_cast(&accum)), -+ index_(0) { -+ -+ } -+ -+ /// Increments -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorSimt &operator++() { -+ ++index_; -+ return *this; -+ } -+ -+ /// Decrements -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorSimt &operator--() { -+ --index_; -+ return *this; -+ } -+ -+ /// Loads a fragment from the referenced part of the accumulator tile -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag, int index_offset = 0) const { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::kAccessesPerIteration; ++n) { -+ -+ int accumulator_access_offset = index_ * Policy::kAccessesPerIteration + n; -+ -+ frag_ptr[n] = accumulators_[accumulator_access_offset]; -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/warp/fragment_iterator_tensor_op.h b/3rdparty/cutlass/include/cutlass/epilogue/warp/fragment_iterator_tensor_op.h -new file mode 100644 -index 0000000..f9b20a6 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/warp/fragment_iterator_tensor_op.h -@@ -0,0 +1,277 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief This defines a "fragment" iterator for visiting the fragments of an accumulator tile -+ that participate in one warp-level store operation. -+ -+ Typically, the accumulator tile is the largest single block of register-backed storage -+ within the kernel. Storing it to memory is best accomplished by partitioning it into -+ smaller tiles and storing these sequentially. -+ -+ Round trips through shared memory during the Epilogue phase require partitioning, as -+ shared memory capacity is typically insufficient for a threadblock's total accumulator -+ size. -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/epilogue/warp/tensor_op_policy.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace warp { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// -+template < -+ typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) -+ typename OperatorShape, ///< matrix multiply operation shape (concept: gemm::GemmShape) -+ typename OperatorElementC, ///< matrix multiply operation data type (concept: data type) -+ typename OperatorFragmentC, ///< matrix multiply operation fragment (concept: Array) -+ typename Layout ///< target shared memory layout -+> -+class FragmentIteratorTensorOp; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for row-major shared memory -+template < -+ typename WarpShape_, ///< shape of the warp-level GEMM tile -+ typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape) -+ typename OperatorElementC_, ///< matrix multiply operation data type (concept: data type) -+ typename OperatorFragmentC_ ///< matrix multiply operation fragment (concept: Array) -+> -+class FragmentIteratorTensorOp { -+public: -+ -+ using WarpShape = WarpShape_; -+ using OperatorShape = OperatorShape_; -+ using OperatorElementC = OperatorElementC_; -+ using OperatorFragmentC = OperatorFragmentC_; -+ using Layout = layout::RowMajor; -+ -+ using Policy = TensorOpPolicy; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = Array< -+ OperatorElementC, -+ Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>; -+ -+ /// This is the complete warp-level accumulator tile. -+ using AccumulatorTile = Array< -+ OperatorElementC, -+ OperatorFragmentC::kElements * Policy::OperatorCount::kRow * Policy::OperatorCount::kColumn>; -+ -+ using OutputAccumulatorTile = AccumulatorTile; -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = Policy::kIterations; -+ using TileIterations = typename Policy::TileIterations; -+ static int const kIterationsPerTile = kIterations / TileIterations::kCount; -+ -+private: -+ -+ /// Internal access type -+ using AccessType = Array; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Accumulator tile -+ AccessType const *accumulators_; -+ -+ /// Internal index -+ int index_; -+ -+public: -+ -+ /// Constructs an iterator -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorTensorOp(AccumulatorTile const &accum): -+ accumulators_(reinterpret_cast(&accum)), -+ index_(0) { -+ } -+ -+ /// Increments -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorTensorOp &operator++() { -+ ++index_; -+ return *this; -+ } -+ -+ /// Decrements -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorTensorOp &operator--() { -+ --index_; -+ return *this; -+ } -+ -+ /// Loads a fragment from the referenced part of the accumulator tile -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag, int index_offset = 0) const { -+ -+ int index = index_ + index_offset; -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { -+ -+ int accumulator_access_offset = -+ index + n * Policy::kAccumulatorColumnStride / Policy::kElementsPerAccess; -+ -+ frag_ptr[n] = accumulators_[accumulator_access_offset]; -+ } -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Dedicated to interleaved layout -+template < -+ /// shape of the warp-level GEMM tile -+ typename WarpShape_, -+ /// matrix multiply operator shape (concept: gemm::GemmShape) -+ typename OperatorShape_, -+ /// matrix multiply operator data type (concept: data type) -+ typename OperatorElementC_, -+ /// matrix multiply operator fragment (concept: Array) -+ typename OperatorFragmentC_, -+ /// number of interleaved k -+ int InterleavedK> -+class FragmentIteratorTensorOp> { -+ public: -+ using WarpShape = WarpShape_; -+ using OperatorShape = OperatorShape_; -+ using OperatorElementC = OperatorElementC_; -+ using OperatorFragmentC = OperatorFragmentC_; -+ static int const kInterleavedK = InterleavedK; -+ using Layout = layout::ColumnMajorInterleaved; -+ -+ using Policy = TensorOpPolicy; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = -+ Array; -+ -+ /// This is the complete warp-level accumulator tile. -+ using AccumulatorTile = -+ Array; -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = Policy::kIterations; -+ using TileIterations = typename Policy::TileIterations; -+ static int const kIterationsPerTile = kIterations / TileIterations::kCount; -+ -+ private: -+ /// Internal access type -+ using AccessType = -+ Array; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Accumulator tile -+ AccessType const *accumulators_; -+ -+ /// Internal index -+ int index_; -+ -+ public: -+ /// Constructs an iterator -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorTensorOp(AccumulatorTile const &accum) -+ : accumulators_(reinterpret_cast(&accum)), -+ index_(0) {} -+ -+ /// Increments -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorTensorOp &operator++() { -+ ++index_; -+ return *this; -+ } -+ -+ /// Decrements -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorTensorOp &operator--() { -+ --index_; -+ return *this; -+ } -+ -+ /// Loads a fragment from the referenced part of the accumulator tile -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag, int index_offset = 0) const { -+ int index = index_ + index_offset; -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < (InterleavedK / OperatorShape::kN); ++n) { -+ int index_m = index % (Policy::OperatorCount::kRow * -+ Policy::kIterationsPerInstruction); -+ int index_n = index / (Policy::OperatorCount::kRow * -+ Policy::kIterationsPerInstruction); -+ int accumulator_access_offset = -+ (index_m / Policy::kIterationsPerInstruction) * -+ (Policy::OperatorCount::kColumn * -+ Policy::kIterationsPerInstruction) + -+ (index_m % Policy::kIterationsPerInstruction) + -+ index_n * (InterleavedK / OperatorShape::kN) * -+ Policy::kIterationsPerInstruction + -+ n * Policy::kIterationsPerInstruction; -+ -+ frag_ptr[n] = accumulators_[accumulator_access_offset]; -+ } -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/warp/fragment_iterator_volta_tensor_op.h b/3rdparty/cutlass/include/cutlass/epilogue/warp/fragment_iterator_volta_tensor_op.h -new file mode 100644 -index 0000000..d37e82e ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/warp/fragment_iterator_volta_tensor_op.h -@@ -0,0 +1,269 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief This defines a "fragment" iterator for visiting the fragments of an accumulator tile -+ that participate in one warp-level store operation. -+ -+ Typically, the accumulator tile is the largest single block of register-backed storage -+ within the kernel. Storing it to memory is best accomplished by partitioning it into -+ smaller tiles and storing these sequentially. -+ -+ Round trips through shared memory during the Epilogue phase require partitioning, as -+ shared memory capacity is typically insufficient for a threadblock's total accumulator -+ size. -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/epilogue/warp/volta_tensor_op_policy.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// -+template < -+ typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) -+ typename InterleavedTileShape, ///< shape of indivisible instruction-level arrangement (concept: GemmShape) -+ typename ElementC, ///< Accumulator layout -+ typename Layout ///< target shared memory layout -+> -+class FragmentIteratorVoltaTensorOp; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for row-major shared memory -+template < -+ typename WarpShape_ ///< shape of warp-level GEMM (concept: MatrixShape) -+> -+class FragmentIteratorVoltaTensorOp, half_t, layout::RowMajor> { -+public: -+ -+ using WarpShape = WarpShape_; -+ using InterleavedTileShape = gemm::GemmShape<32, 32, 4>; -+ using ElementC = half_t; -+ using Layout = layout::RowMajor; -+ -+ /// Policy operator -+ using Policy = VoltaTensorOpPolicy; -+ -+ /// Array type for aligned memory accesses -+ using AccessType = typename Policy::AccessType; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = typename Policy::Fragment; -+ -+ /// This is the complete warp-level accumulator tile. -+ using AccumulatorTile = typename Policy::AccumulatorTile; -+ -+ using OutputAccumulatorTile = AccumulatorTile; -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = Policy::kIterations; -+ -+private: -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Accumulator tile -+ AccessType const *accumulators_; -+ -+ /// Internal index -+ int index_; -+ -+public: -+ -+ /// Constructs an iterator -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorVoltaTensorOp(AccumulatorTile const &accum): -+ accumulators_(reinterpret_cast(&accum)), -+ index_(0) { -+ -+ } -+ -+ /// Increments -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorVoltaTensorOp &operator++() { -+ ++index_; -+ return *this; -+ } -+ -+ /// Decrements -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorVoltaTensorOp &operator--() { -+ --index_; -+ return *this; -+ } -+ -+ /// Loads a fragment from the referenced part of the accumulator tile -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag, int index_offset = 0) const { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ static int const kAccessesPerMma = Policy::kElementsPerMma / Policy::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; ++tile_n) { -+ -+ int tile_access_idx = -+ (tile_n * Policy::TileIterations::kRow + (index_ & 2) / 2) * Policy::MmaIterations::kCount * kAccessesPerMma; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn * kAccessesPerMma; ++mma_n) { -+ -+ int mma_access_idx = ((mma_n & 1) * 2 + (index_ & 1)) * kAccessesPerMma + (mma_n & 2) / 2; -+ -+ frag_ptr[tile_n * Policy::MmaIterations::kColumn * kAccessesPerMma + -+ mma_n] = accumulators_[tile_access_idx + mma_access_idx]; -+ } -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for row-major shared memory -+template < -+ typename WarpShape_ ///< shape of warp-level GEMM (concept: MatrixShape) -+> -+class FragmentIteratorVoltaTensorOp, float, layout::RowMajor> { -+public: -+ -+ using WarpShape = WarpShape_; -+ using InterleavedTileShape = gemm::GemmShape<32, 32, 4>; -+ using ElementC = float; -+ using Layout = layout::RowMajor; -+ -+ /// Policy operator -+ using Policy = VoltaTensorOpPolicy; -+ -+ /// Array type for aligned memory accesses -+ using AccessType = typename Policy::AccessType; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = typename Policy::Fragment; -+ -+ /// This is the complete warp-level accumulator tile. -+ using AccumulatorTile = typename Policy::AccumulatorTile; -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = Policy::kIterations; -+ -+private: -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Accumulator tile -+ AccessType const *accumulators_; -+ -+ /// Internal index -+ int index_; -+ -+public: -+ -+ /// Constructs an iterator -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorVoltaTensorOp(AccumulatorTile const &accum): -+ accumulators_(reinterpret_cast(&accum)), -+ index_(0) { -+ } -+ -+ /// Increments -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorVoltaTensorOp &operator++() { -+ ++index_; -+ return *this; -+ } -+ -+ /// Decrements -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorVoltaTensorOp &operator--() { -+ --index_; -+ return *this; -+ } -+ -+ /// Loads a fragment from the referenced part of the accumulator tile -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag, int index_offset = 0) const { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ int const kRegsPerMmaRow = 2; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int reg_row = 0; reg_row < Policy::kRowsPerMmaTile; ++reg_row) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; ++tile_n) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn * 2; ++mma_n) { -+ -+ int mma_idx = (index_ & 1) + (index_ & 2) * Policy::MmaIterations::kCount / 2 + -+ (tile_n * Policy::TileIterations::kRow) * Policy::MmaIterations::kCount + (mma_n & 1) * 2; -+ -+ int reg_offset = reg_row * kRegsPerMmaRow + (mma_n & 2) * 2; -+ int reg_idx = mma_idx * Policy::kElementsPerMma + reg_offset; -+ -+ *frag_ptr = accumulators_[reg_idx / Policy::kElementsPerAccess]; -+ ++frag_ptr; -+ } -+ } -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+} // namespace warp -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/warp/fragment_iterator_wmma_tensor_op.h b/3rdparty/cutlass/include/cutlass/epilogue/warp/fragment_iterator_wmma_tensor_op.h -new file mode 100644 -index 0000000..225e0f0 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/warp/fragment_iterator_wmma_tensor_op.h -@@ -0,0 +1,164 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief This defines a "fragment" iterator for visiting the fragments of an accumulator tile -+ that participate in one warp-level store operation. -+ -+ Typically, the accumulator tile is the largest single block of register-backed storage -+ within the kernel. Storing it to memory is best accomplished by partitioning it into -+ smaller tiles and storing these sequentially. -+ -+ Round trips through shared memory during the Epilogue phase require partitioning, as -+ shared memory capacity is typically insufficient for a threadblock's total accumulator -+ size. -+*/ -+ -+#pragma once -+ -+#if !(defined(__clang__) && defined(__CUDA__)) -+ -+#include "cutlass/wmma_array.h" -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/epilogue/warp/wmma_tensor_op_policy.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace warp { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// -+template < -+ typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) -+ typename OperatorShape, ///< matrix multiply operation shape (concept: gemm::GemmShape) -+ typename OperatorElementC, ///< matrix multiply operation data type (concept: data type) -+ typename OperatorFragmentC, ///< matrix multiply operation fragment (concept: nvcuda::cuda::fragment) -+ typename Layout ///< target shared memory layout -+> -+class FragmentIteratorWmmaTensorOp; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for row-major shared memory -+template < -+ typename WarpShape_, ///< shape of the warp-level GEMM tile -+ typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape) -+ typename OperatorElementC_, ///< matrix multiply operation data type (concept: data type) -+ typename OperatorFragmentC_ ///< matrix multiply operation fragment (concept: nvcuda::cuda::fragment) -+> -+class FragmentIteratorWmmaTensorOp { -+public: -+ -+ using WarpShape = WarpShape_; -+ using OperatorShape = OperatorShape_; -+ using OperatorElementC = OperatorElementC_; -+ using OperatorFragmentC = OperatorFragmentC_; -+ using Layout = layout::RowMajor; -+ -+ using Policy = WmmaTensorOpPolicy; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = WmmaFragmentArray; -+ -+ /// This is the complete warp-level accumulator tile. -+ using AccumulatorTile = WmmaFragmentArray; -+ -+ using OutputAccumulatorTile = AccumulatorTile; -+ -+private: -+ -+ /// Internal access type -+ using AccessType = WmmaFragmentArray; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Accumulator tile -+ AccessType const *accumulators_; -+ -+ /// Internal index -+ int index_; -+ -+public: -+ -+ /// Constructs an iterator -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorWmmaTensorOp(AccumulatorTile const &accum): -+ accumulators_(reinterpret_cast(&accum)), -+ index_(0) { -+ } -+ -+ /// Increments -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorWmmaTensorOp &operator++() { -+ ++index_; -+ return *this; -+ } -+ -+ /// Decrements -+ CUTLASS_HOST_DEVICE -+ FragmentIteratorWmmaTensorOp &operator--() { -+ --index_; -+ return *this; -+ } -+ -+ /// Loads a fragment from the referenced part of the accumulator tile -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag, int index_offset = 0) const { -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int n=0; n < Policy::OperatorCount::kColumn; n++) { -+ -+ int accumulator_access_offset = index_ * Policy::OperatorCount::kColumn + n; -+ -+ frag_ptr[n] = accumulators_[accumulator_access_offset]; -+ } -+ } -+}; -+ -+ -+} // namespace warp -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#else -+#error (defined(__clang__) && defined(__CUDA__)) -+#endif // !defined(__clang__) -+ -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/warp/simt_policy.h b/3rdparty/cutlass/include/cutlass/epilogue/warp/simt_policy.h -new file mode 100644 -index 0000000..21bca80 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/warp/simt_policy.h -@@ -0,0 +1,107 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines basic structures needed for implementing the warp-scoped phase of the epilogue. -+ These quantities assume a 'column-major' arrangement of SimtOp instructions, of which -+ a row-oriented slice is visible per iteration. -+*/ -+ -+#pragma once -+ -+#include "cutlass/matrix_shape.h" -+#include "cutlass/layout/matrix.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename WarpShape, ///< shape of warp-level GEMM (concept: GemmShape) -+ typename Operator, ///< matrix multiply operation (concept: arch::Mma) -+ typename Layout, ///< destination layout in shared memory -+ typename MmaSimtPolicy ///< policy defining lane arrangement (concept: MmaSimtPolicy) -+> -+struct SimtPolicy; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for row-major -+template < -+ typename WarpShape_, ///< shape of warp-level GEMM (concept: MatrixShape) -+ typename Operator_, ///< matrix multiply operation (concept: arch::Mma) -+ typename MmaSimtPolicy_ ///< policy defining lane arrangement (concept: MmaSimtPolicy) -+> -+struct SimtPolicy { -+ -+ using WarpShape = WarpShape_; -+ using Operator = Operator_; -+ using MmaSimtPolicy = MmaSimtPolicy_; -+ -+ static_assert(!(WarpShape::kM % MmaSimtPolicy::WarpShape::kRow), "Divisibility"); -+ static_assert(!(WarpShape::kN % MmaSimtPolicy::WarpShape::kColumn), "Divisibility"); -+ -+ /// Number of iterations -+ static int const kIterations = WarpShape::kM / MmaSimtPolicy::WarpShape::kRow; -+ -+ /// Number of accumulators written per iteration -+ static int const kElementsPerIteration = -+ (WarpShape::kN / MmaSimtPolicy::WarpShape::kColumn); -+ -+ /// Total number of accumulators -+ static int const kAccumulatorElementCount = kElementsPerIteration * kIterations; -+ -+ /// Number of consecutive elements -+ static int const kElementsPerAccess = MmaSimtPolicy::LaneMmaShape::kN; -+ -+ /// Number of rows per epilogue iteration -+ static int const kRowsPerIteration = MmaSimtPolicy::WarpShape::kRow; -+ -+ /// Number of accesses made in one iteration -+ static int const kAccessesPerIteration = kElementsPerIteration / kElementsPerAccess; -+ -+ /// Number of elements in between accumulator chunks of (LaneMmaShape::kM x LaneMmaShape::kN) -+ using Delta = MatrixShape< -+ MmaSimtPolicy::WarpShape::kRow * MmaSimtPolicy::LaneMmaShape::kM, -+ MmaSimtPolicy::WarpShape::kColumn * MmaSimtPolicy::LaneMmaShape::kN -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/warp/tensor_op_policy.h b/3rdparty/cutlass/include/cutlass/epilogue/warp/tensor_op_policy.h -new file mode 100644 -index 0000000..e0d1f6f ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/warp/tensor_op_policy.h -@@ -0,0 +1,148 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines basic structures needed for implementing the warp-scoped phase of the epilogue. -+ These quantities assume a 'column-major' arrangement of TensorOp instructions, of which -+ a row-oriented slice is visible per iteration. -+*/ -+ -+#pragma once -+ -+#include "cutlass/matrix_shape.h" -+#include "cutlass/layout/matrix.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace warp { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Policy details related to the epilogue -+template < -+ typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) -+ typename OperatorShape, ///< matrix multiply operation shape (concept: gemm:GemmShape) -+ typename Layout ///< target shared memory layout -+> -+struct TensorOpPolicy; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for row-major -+template < -+ typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) -+ typename OperatorShape ///< matrix multiply operation shape (concept: gemm::GemmShape) -+> -+struct TensorOpPolicy { -+ -+ /// Number of operations -+ using OperatorCount = MatrixShape< -+ (WarpShape::kM + OperatorShape::kM - 1) / OperatorShape::kM, -+ (WarpShape::kN + OperatorShape::kN - 1) / OperatorShape::kN -+ >; -+ -+ // -+ // Hard-coded constants regarding Tensor Operations -+ // -+ -+ static int const kElementsPerAccess = 2; -+ static int const kRowsPerIteration = 8; -+ static bool const kDivisible = -+ !(WarpShape::kM % OperatorShape::kM) && !(WarpShape::kN % OperatorShape::kN); -+ -+ // -+ // Derived quantities -+ // -+ -+ // Number of 'externally visible' iterations per actual instruction -+ static int const kIterationsPerInstruction = OperatorShape::kM / kRowsPerIteration; -+ -+ // Number of externally visible iterations -+ static int const kIterations = OperatorCount::kRow * kIterationsPerInstruction; -+ -+ using TileIterations = MatrixShape; -+ -+ static int const kAccumulatorRowStride = kElementsPerAccess; -+ static int const kAccumulatorColumnStride = kElementsPerAccess * OperatorCount::kRow * kIterationsPerInstruction; -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for column-major-interleaved -+template < -+ typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) -+ typename OperatorShape, ///< matrix multiply operation (concept: arch::Mma) -+ int InterleavedK ///< number of interleaved k -+ > -+struct TensorOpPolicy > { -+ /// Number of operations -+ using OperatorCount = MatrixShape; -+ -+ // -+ // Hard-coded constants regarding Tensor Operations -+ // -+ -+ static int const kElementsPerAccess = 2; -+ static int const kRowsPerIteration = 8; -+ -+ // -+ // Derived quantities -+ // -+ -+ // Number of 'externally visible' iterations per actual instruction -+ static int const kIterationsPerInstruction = -+ OperatorShape::kM / kRowsPerIteration; -+ -+ // Number of externally visible iterations -+ static int const kIterations = WarpShape::kN / InterleavedK * -+ OperatorCount::kRow * -+ kIterationsPerInstruction; -+ -+ static int const kElementsPerIteration = InterleavedK / OperatorShape::kN * kElementsPerAccess; -+ -+ static int const kAccessPerIteration = kElementsPerIteration / kElementsPerAccess; -+ -+ // Number of externally visible iterations -+ //static int const kTileIterations = OperatorCount::kRow * kIterationsPerInstruction; -+ using TileIterations = MatrixShape<1, WarpShape::kN / InterleavedK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/warp/tile_iterator_simt.h b/3rdparty/cutlass/include/cutlass/epilogue/warp/tile_iterator_simt.h -new file mode 100644 -index 0000000..5ef4b2e ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/warp/tile_iterator_simt.h -@@ -0,0 +1,785 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+ -+#include "cutlass/epilogue/warp/simt_policy.h" -+ -+#define CUTLASS_SIMT_EPILOGUE_USE_SCALAR_STORES 1 -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template for reading and writing tiles of accumulators to shared memory -+template < -+ typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) -+ typename Operator, ///< matrix multiply operation (concept: arch::Mma) -+ typename Element, ///< data type of element to be written -+ typename Layout, ///< target shared memory layout -+ typename MmaSimtPolicy ///< policy defining lane arrangement (concept: MmaSimtPolicy) -+> -+class TileIteratorSimt; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template for reading and writing tiles of accumulators to shared memory -+template < -+ typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape) -+ typename Operator_, ///< matrix multiply operation (concept: arch::Mma) -+ typename Element_, ///< data type of element to be written -+ typename MmaSimtPolicy_ ///< policy defining lane arrangement (concept: MmaSimtPolicy) -+> -+class TileIteratorSimt { -+public: -+ -+ using WarpShape = WarpShape_; -+ using Operator = Operator_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ -+ using TensorRef = TensorRef; ///< Tensor Reference object -+ using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor -+ using Index = typename TensorRef::Index; -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ using Policy = SimtPolicy; -+ -+ /// Shape of the tile in memory -+ using Shape = MatrixShape< -+ Policy::kRowsPerIteration, -+ WarpShape::kN -+ >; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = Array< -+ typename Operator::ElementC, -+ Policy::kElementsPerIteration>; -+ -+ /// This is the complete warp-level accumulator tile. -+ using AccumulatorTile = Array< -+ typename Operator::ElementC, -+ Policy::kAccumulatorElementCount>; -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = Policy::kIterations; -+ -+ /// Padding quantity -+ using Padding = MatrixShape< -+ 0, -+ 4 * Policy::kElementsPerAccess -+#if CUTLASS_SIMT_EPILOGUE_USE_SCALAR_STORES -+ + 1 -+#endif -+ >; -+ -+private: -+ -+#if CUTLASS_SIMT_EPILOGUE_USE_SCALAR_STORES -+ /// Storage type for accessing memory -+ using AccessType = AlignedArray< -+ Element, -+ 1 -+ >; -+ -+#else -+ /// Storage type for accessing memory -+ using AccessType = AlignedArray< -+ Element, -+ Policy::kElementsPerAccess -+ >; -+#endif -+ -+ // -+ // Data members -+ // -+ -+ /// Internal pointer to memory -+ AccessType *pointer_; -+ -+ /// Internal layout object -+ Layout layout_; -+ -+public: -+ -+ /// Default constructor -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimt(): pointer_(nullptr) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimt( -+ TensorRef const &ref, -+ unsigned lane_id -+ ): -+ pointer_(reinterpret_cast(ref.data())), -+ layout_(ref.stride()[0] / AccessType::kElements) { -+ -+ auto lane_layout = Policy::MmaSimtPolicy::get_lane_layout(); -+ MatrixCoord lane_offset = lane_layout.inverse(lane_id); -+ -+ pointer_ += layout_({ -+ lane_offset.row(), -+ lane_offset.column() * Policy::kElementsPerAccess / int(AccessType::kElements) -+ }); -+ } -+ -+ /// Adds a pointer offset -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimt & add_pointer_offset(Index pointer_offset) { -+ pointer_ += pointer_offset / AccessType::kElements; -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimt & add_tile_offset(TensorCoord const &tile_offset) { -+ -+ pointer_ += layout_({ -+ tile_offset.row() * Shape::kRow, -+ (tile_offset.column() * Shape::kColumn / int(AccessType::kElements)) -+ }); -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimt & operator+=(TensorCoord const &tile_offset) { -+ -+ add_tile_offset(tile_offset); -+ -+ return *this; -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+#if CUTLASS_SIMT_EPILOGUE_USE_SCALAR_STORES -+ // de-vectorized stores -+ using ScalarAccessType = AlignedArray; -+ ScalarAccessType const *scalarFragPtr = reinterpret_cast(&frag); -+ ScalarAccessType *scalarPointer = reinterpret_cast(pointer_) + pointer_offset; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::kAccessesPerIteration; ++n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < Policy::kElementsPerAccess; s++) { -+ scalarPointer[n * Policy::MmaSimtPolicy::WarpShape::kColumn * Policy::kElementsPerAccess + s] = scalarFragPtr[n * Policy::kElementsPerAccess + s]; -+ } -+ } -+#else -+ // original vector stores -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::kAccessesPerIteration; ++n) { -+ pointer_[n * Policy::MmaSimtPolicy::WarpShape::kColumn + pointer_offset / int(AccessType::kElements)] = frag_ptr[n]; -+ } -+#endif -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::kAccessesPerIteration; ++n) { -+ frag_ptr[n] = pointer_[n * Policy::MmaSimtPolicy::WarpShape::kColumn + pointer_offset / int(AccessType::kElements)]; -+ } -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Set smem base address -+ CUTLASS_HOST_DEVICE -+ void set_smem_base_address(Index address) { -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template for reading and writing tiles of accumulators to shared memory -+template -+class TileIteratorSimtDirectConv { -+ public: -+ -+ using WarpShape = WarpShape_; -+ using Operator = Operator_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ -+ using TensorRef = TensorRef; ///< Tensor Reference object -+ using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor -+ using Index = typename TensorRef::Index; -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ using Policy = SimtPolicy; -+ -+ /// Shape of the tile in memory -+ using Shape = MatrixShape; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = Array; -+ -+ /// This is the complete warp-level accumulator tile. -+ using AccumulatorTile = Array; -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = Policy::kIterations; -+ -+ /// Padding quantity -+ using Padding = MatrixShape<0, -+ 0 -+ >; -+ -+private: -+ /// Storage type for accessing memory -+ using AccessType = AlignedArray< -+ Element, -+ Policy::kElementsPerAccess -+ >; -+ -+ // -+ // Data members -+ // -+ -+ /// Internal pointer to memory -+ AccessType *pointer_; -+ -+ /// Internal layout object -+ Layout layout_; -+ -+ /// Base smem offset; -+ Index base_smem_address_; -+ -+ public: -+ /// Default constructor -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimtDirectConv() : pointer_(nullptr) {} -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimtDirectConv( -+ TensorRef const &ref, -+ unsigned lane_id -+ ): -+ pointer_(reinterpret_cast(ref.data())), -+ layout_(ref.stride()[0] / AccessType::kElements) { -+ -+ auto lane_layout = Policy::MmaSimtPolicy::get_lane_layout(); -+ MatrixCoord lane_offset = lane_layout.inverse(lane_id); -+ -+ pointer_ += layout_({ -+ lane_offset.row(), -+ lane_offset.column() * Policy::kElementsPerAccess / int(AccessType::kElements) -+ }); -+ } -+ -+ /// Adds a pointer offset -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimtDirectConv & add_pointer_offset(Index pointer_offset) { -+ pointer_ += pointer_offset / AccessType::kElements; -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimtDirectConv & add_tile_offset(TensorCoord const &tile_offset) { -+ -+ pointer_ += layout_({ -+ tile_offset.row() * Shape::kRow, -+ (tile_offset.column() * Shape::kColumn / int(AccessType::kElements)) -+ }); -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimtDirectConv & operator+=(TensorCoord const &tile_offset) { -+ -+ add_tile_offset(tile_offset); -+ -+ return *this; -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ -+ // original vector stores -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ AccessType * load_pointer_ = reinterpret_cast(reinterpret_cast(pointer_) + base_smem_address_); -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::kAccessesPerIteration; ++n) { -+ load_pointer_[n * Policy::MmaSimtPolicy::WarpShape::kColumn + pointer_offset / int(AccessType::kElements)] = frag_ptr[n]; -+ } -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::kAccessesPerIteration; ++n) { -+ frag_ptr[n] = pointer_[n * Policy::MmaSimtPolicy::WarpShape::kColumn + pointer_offset / int(AccessType::kElements)]; -+ } -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Set smem base address -+ CUTLASS_HOST_DEVICE -+ void set_smem_base_address(Index address){ -+ base_smem_address_ = address; -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Template for reading and writing tiles of accumulators to shared memory -+template -+class TileIteratorSimtDirect2dConv { -+ public: -+ using WarpShape = WarpShape_; -+ using ThreadOutputShape = ThreadOutputShape_; -+ using ThreadBlockOutputShape = ThreadBlockOutputShape_; -+ using Operator = Operator_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ using MmaSimtPolicy = MmaSimtPolicy_; -+ -+ using TensorRef = TensorRef; ///< Tensor Reference object -+ using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor -+ using Index = typename TensorRef::Index; -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ // Thread-level shape of a fragment -+ using ThreadShape = MatrixShape; -+ -+ static_assert(!(ThreadShape::kColumn % MmaSimtPolicy::LaneMmaShape::kN), -+ "Thread-level GEMM must be divisible by Policy::LaneMmaShape."); -+ -+ using ThreadTileCount = MatrixShape; -+ -+ using Iterations = -+ MatrixShape; -+ -+ /// This is the complete warp-level accumulator tile. -+ using AccumulatorTile = typename Operator::FragmentC; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = AccumulatorTile; -+ -+ /// Padding quantity -+ using Padding = MatrixShape<0, 0>; -+ -+ private: -+ // Storage type for accessing memory -+ using AccessType = AlignedArray; -+ // -+ // Data members -+ // -+ -+ /// Internal pointer to memory -+ AccessType *pointer_; -+ -+ /// Internal layout object -+ Layout layout_; -+ -+ /// Base smem offset; -+ Index base_smem_address_; -+ -+ public: -+ /// Default constructor -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimtDirect2dConv() : pointer_(nullptr) {} -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimtDirect2dConv(TensorRef const &ref, unsigned thread_id, unsigned lane_id) -+ : pointer_(reinterpret_cast(ref.data())), -+ layout_(ref.stride()[0] / AccessType::kElements) { -+ -+ auto lane_layout = MmaSimtPolicy::get_lane_layout(); -+ -+ MatrixCoord lane_offset = lane_layout.inverse(lane_id); -+ -+ // Get base HW offset of current threads -+ const int threadgroup = thread_id / (ThreadBlockOutputShape::kC / ThreadOutputShape::kC); -+ const int base_p = (threadgroup / (ThreadTileCount::kColumn)) * ThreadOutputShape::kH; -+ const int base_q = (threadgroup % (ThreadTileCount::kColumn)) * ThreadOutputShape::kW; -+ -+ const int row_offset = base_p * ThreadBlockOutputShape::kW + base_q; -+ -+ pointer_ += layout_( -+ {row_offset, -+ lane_offset.column() * MmaSimtPolicy::LaneMmaShape::kN / int(AccessType::kElements)}); -+ } -+ -+ /// Adds a pointer offset -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimtDirect2dConv &add_pointer_offset(Index pointer_offset) { -+ pointer_ += pointer_offset / AccessType::kElements; -+ return *this; -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ AccessType *storer_pointer_ = -+ reinterpret_cast(reinterpret_cast(pointer_) + base_smem_address_); -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int h = 0; h < ThreadOutputShape::kH; ++h) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int w = 0; w < ThreadOutputShape::kW; ++w) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int col = 0; col < Iterations::kColumn; ++col) { -+ int offset = (w + h * ThreadBlockOutputShape::kW) * -+ (ThreadBlockOutputShape::kC / AccessType::kElements) + -+ col; -+ storer_pointer_[offset + pointer_offset / int(AccessType::kElements)] = -+ frag_ptr[w + h * ThreadOutputShape::kW + col]; -+ } -+ } -+ } -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -+ -+ /// Set smem base address -+ CUTLASS_HOST_DEVICE -+ void set_smem_base_address(Index address) { base_smem_address_ = address; } -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template for reading and writing tiles of accumulators to shared memory -+template < -+ typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape) -+ typename Operator_, ///< matrix multiply operation (concept: arch::Mma) -+ typename Element_, ///< data type of element to be written -+ typename Layout_, ///< target shared memory layout -+ typename MmaSimtPolicy_ ///< policy defining lane arrangement (concept: MmaSimtPolicy) -+> -+class TileIteratorSimtCanonical { -+public: -+ -+ using WarpShape = WarpShape_; -+ using Operator = Operator_; -+ using Element = Element_; -+ using Layout = Layout_; -+ -+ using TensorRef = TensorRef; ///< Tensor Reference object -+ using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor -+ using Index = typename TensorRef::Index; -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ using Policy = SimtPolicy; -+ -+ /// Shape of the tile in memory -+ using Shape = MatrixShape< -+ Policy::kRowsPerIteration, -+ WarpShape::kN -+ >; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = Array< -+ typename Operator::ElementC, -+ Policy::kElementsPerIteration>; -+ -+ /// This is the complete warp-level accumulator tile. -+ using AccumulatorTile = Array< -+ typename Operator::ElementC, -+ Policy::kAccumulatorElementCount>; -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = Policy::kIterations; -+ -+ /// Padding quantity -+ using Padding = MatrixShape< -+ 0, -+ 4 * Policy::kElementsPerAccess + 1 -+ >; -+ -+private: -+ -+ /// Storage type for accessing memory -+ using AccessType = AlignedArray< -+ Element, -+ 1 -+ >; -+ -+ // -+ // Data members -+ // -+ -+ /// Internal pointer to memory -+ AccessType *pointer_; -+ -+ /// Internal layout object -+ Layout layout_; -+ -+ /// Guard to indicate whether the shape is divisible -+ bool divisible_; -+ -+ /// Extent of the output tensor -+ MatrixCoord extent_; -+ -+ /// Thread offset -+ MatrixCoord thread_offset_; -+ -+public: -+ -+ /// Default constructor -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimtCanonical(): pointer_(nullptr) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimtCanonical( -+ TensorRef const &ref, -+ unsigned lane_id -+ ): -+ pointer_(reinterpret_cast(ref.data())), -+ layout_(ref.stride()[0] / AccessType::kElements), -+ divisible_(true), -+ extent_(WarpShape::kM, WarpShape::kN) { -+ -+ auto lane_layout = Policy::MmaSimtPolicy::get_lane_layout(); -+ MatrixCoord lane_offset = lane_layout.inverse(lane_id); -+ -+ thread_offset_ = { -+ lane_offset.row() * Shape::kRow, -+ lane_offset.column() * Policy::kElementsPerAccess -+ }; -+ -+ pointer_ += layout_({ -+ lane_offset.row() * Shape::kRow, -+ lane_offset.column() * Policy::kElementsPerAccess / int(AccessType::kElements) -+ }); -+ } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimtCanonical( -+ TensorRef const &ref, -+ TensorCoord const &extent, -+ unsigned lane_id -+ ): -+ pointer_(reinterpret_cast(ref.data())), -+ layout_(ref.stride()[0] / AccessType::kElements), -+ divisible_(false), -+ extent_(extent) { -+ -+ auto lane_layout = Policy::MmaSimtPolicy::get_lane_layout(); -+ MatrixCoord lane_offset = lane_layout.inverse(lane_id); -+ -+ thread_offset_ = { -+ lane_offset.row() * Shape::kRow, -+ lane_offset.column() * Policy::kElementsPerAccess -+ }; -+ -+ pointer_ += layout_({ -+ lane_offset.row() * Shape::kRow, -+ lane_offset.column() * Policy::kElementsPerAccess / int(AccessType::kElements) -+ }); -+ } -+ -+ /// Adds a pointer offset -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimtCanonical & add_pointer_offset(Index pointer_offset) { -+ pointer_ += pointer_offset / AccessType::kElements; -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimtCanonical & add_tile_offset(TensorCoord const &tile_offset) { -+ -+ MatrixCoord coord_offset( -+ tile_offset.row(), -+ tile_offset.column() * Shape::kColumn -+ ); -+ -+ thread_offset_ += coord_offset; -+ -+ pointer_ += layout_({ -+ coord_offset.row(), -+ coord_offset.column() -+ }); -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimtCanonical & operator+=(TensorCoord const &tile_offset) { -+ -+ add_tile_offset(tile_offset); -+ -+ return *this; -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ -+ // de-vectorized stores -+ using ScalarAccessType = AlignedArray; -+ ScalarAccessType const *scalarFragPtr = reinterpret_cast(&frag); -+ ScalarAccessType *scalarPointer = reinterpret_cast(pointer_) + pointer_offset; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::kAccessesPerIteration; ++n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < Policy::kElementsPerAccess; s++) { -+ -+ int ptr_idx = n * Policy::MmaSimtPolicy::WarpShape::kColumn * Policy::kElementsPerAccess + s; -+ int frag_idx = n * Policy::kElementsPerAccess + s; -+ -+ int col = thread_offset_.column() + ptr_idx; -+ -+ if (divisible_ || (thread_offset_.row() < extent_.row() && col < extent_.column())) { -+ scalarPointer[ptr_idx] = scalarFragPtr[frag_idx]; -+ } -+ } -+ } -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ -+ // de-vectorized loads -+ using ScalarAccessType = AlignedArray; -+ ScalarAccessType *scalarFragPtr = reinterpret_cast(&frag); -+ ScalarAccessType const *scalarPointer = reinterpret_cast(pointer_) + pointer_offset; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::kAccessesPerIteration; ++n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < Policy::kElementsPerAccess; s++) { -+ -+ int ptr_idx = n * Policy::MmaSimtPolicy::WarpShape::kColumn * Policy::kElementsPerAccess + s; -+ int frag_idx = n * Policy::kElementsPerAccess + s; -+ -+ int col = thread_offset_.column() + ptr_idx; -+ -+ if (divisible_ || (thread_offset_.row() < extent_.row() && col < extent_.column())) { -+ scalarFragPtr[frag_idx] = scalarPointer[ptr_idx]; -+ } -+ } -+ } -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ TileIteratorSimtCanonical & operator++() { -+ return add_tile_offset({1, 0}); -+ } -+ -+ /// Set smem base address -+ CUTLASS_HOST_DEVICE -+ void set_smem_base_address(Index address) { -+ } -+}; -+ -+ -+} // namespace warp -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/warp/tile_iterator_tensor_op.h b/3rdparty/cutlass/include/cutlass/epilogue/warp/tile_iterator_tensor_op.h -new file mode 100644 -index 0000000..a1eb5c9 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/warp/tile_iterator_tensor_op.h -@@ -0,0 +1,671 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+ -+#include "cutlass/epilogue/warp/tensor_op_policy.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template for reading and writing tiles of accumulators to shared memory -+template < -+ typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) -+ typename OperatorShape, ///< matrix multiply operation shape (concept: gemm::GemmShape) -+ typename Element, ///< data type of element to be written -+ typename Layout ///< target shared memory layout -+> -+class TileIteratorTensorOp; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template for reading and writing tiles of accumulators to shared memory -+template < -+ typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape) -+ typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape) -+ typename Element_ ///< data type of element to be written -+> -+class TileIteratorTensorOp { -+public: -+ -+ using WarpShape = WarpShape_; -+ using OperatorShape = OperatorShape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ -+ using TensorLayout = Layout; -+ using TensorRef = TensorRef; ///< Tensor Reference object -+ using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor -+ using Index = typename TensorRef::Index; -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ using Policy = TensorOpPolicy; -+ -+ /// Shape of the tile in memory -+ using Shape = MatrixShape< -+ Policy::kRowsPerIteration, -+ WarpShape::kN -+ >; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = Array< -+ Element, -+ Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>; -+ -+ /// This is the complete warp-level accumulator tile. -+ //using AccumulatorTile = typename Operator::FragmentC; -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = Policy::kIterations; -+ -+ /// Number of times this iterator can be incremented -+ using TileIterations = typename Policy::TileIterations; -+ -+ // Internal constants -+ struct Detail { -+ static int const kLanesInQuad = 4; -+ }; -+ -+ /// Padding quantity -+ using Padding = MatrixShape< -+ 0, -+ Detail::kLanesInQuad * Policy::kElementsPerAccess>; -+ -+private: -+ -+ /// Storage type for accessing memory -+ using AccessType = AlignedArray; -+ -+ // -+ // Data members -+ // -+ -+ /// Internal pointer to memory -+ AccessType *pointer_; -+ -+ /// Internal layout object -+ Layout layout_; -+ -+ /// Thread offset -+ MatrixCoord thread_offset_; -+ -+public: -+ -+ /// Default constructor -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOp(): pointer_(nullptr) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOp( -+ TensorRef const &ref, -+ unsigned lane_id -+ ): -+ pointer_(reinterpret_cast(ref.data())), -+ layout_(ref.stride()[0] / Policy::kElementsPerAccess) { -+ -+ int quad_id = (lane_id / Detail::kLanesInQuad); -+ int lane_in_quad = (lane_id % Detail::kLanesInQuad); -+ -+ thread_offset_ = { -+ quad_id, lane_in_quad * Policy::kElementsPerAccess -+ }; -+ -+ pointer_ += layout_({thread_offset_.row(), thread_offset_.column() / Policy::kElementsPerAccess}); -+ } -+ -+ /// Adds a pointer offset -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOp & add_pointer_offset(Index pointer_offset) { -+ pointer_ += pointer_offset / Policy::kElementsPerAccess; -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOp & add_tile_offset(TensorCoord const &tile_offset) { -+ -+ MatrixCoord coord_offset( -+ tile_offset.row() * Shape::kRow, -+ tile_offset.column() * Shape::kColumn -+ ); -+ -+ thread_offset_ += coord_offset; -+ -+ pointer_ += layout_({ -+ coord_offset.row(), -+ coord_offset.column() / Policy::kElementsPerAccess -+ }); -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOp & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { -+ pointer_[n * Detail::kLanesInQuad + pointer_offset / Policy::kElementsPerAccess] = frag_ptr[n]; -+ } -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { -+ frag_ptr[n] = pointer_[n * Detail::kLanesInQuad + pointer_offset / Policy::kElementsPerAccess]; -+ } -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOp & operator++() { -+ return add_tile_offset({1, 0}); -+ } -+ -+ /// Set smem base address -+ CUTLASS_HOST_DEVICE -+ void set_smem_base_address(Index address) { -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template for reading and writing tiles of accumulators to shared memory -+template < -+ typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape) -+ typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape) -+ typename Element_, ///< data type of element to be written -+ int InterleavedK ///< number of interleaved k -+> -+class TileIteratorTensorOp > { -+public: -+ -+ using WarpShape = WarpShape_; -+ using OperatorShape = OperatorShape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajorInterleaved; -+ using TensorLayout = Layout; ///< shared memory tensor ref layout -+ -+ using TensorRef = TensorRef; ///< Tensor Reference object -+ using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor -+ using Index = typename TensorRef::Index; -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ using Policy = TensorOpPolicy; -+ -+ /// Shape of the tile in memory -+ using Shape = MatrixShape< -+// Policy::kRowsPerIteration, -+ WarpShape::kM, -+ InterleavedK -+ >; -+ -+ /// This is the fragment size produced by one tile -+ using Fragment = Array< -+ Element, -+ Policy::OperatorCount::kRow * Policy::kIterationsPerInstruction -+ * Policy::kElementsPerIteration>; -+ -+ /// This is the fragment size produced by one iteration -+// using Fragment = Array< -+// Element, Policy::kElementsPerIteration >; -+ -+ /// This is the complete warp-level accumulator tile. -+ //using AccumulatorTile = typename Operator::FragmentC; -+ -+ /// Number of times this iterator can be incremented -+ using TileIterations = typename Policy::TileIterations; -+ -+ // Internal constants -+ struct Detail { -+ static int const kLanesInQuad = 4; -+ }; -+ -+ /// Padding quantity -+ using Padding = MatrixShape< -+ 0, -+ Detail::kLanesInQuad * Policy::kElementsPerIteration>; -+ -+private: -+ -+ /// Storage type for accessing memory -+ using AccessType = AlignedArray; -+ -+ // -+ // Data members -+ // -+ -+ /// Internal pointer to memory -+ AccessType *pointer_; -+ -+ /// Internal layout object -+ TensorLayout layout_; -+ -+ /// Thread offset -+ MatrixCoord thread_offset_; -+ -+public: -+ -+ /// Default constructor -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOp(): pointer_(nullptr) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOp( -+ TensorRef const &ref, -+ unsigned lane_id -+ ): -+ pointer_(reinterpret_cast(ref.data())), -+ layout_(ref.stride()[0]) { -+ -+ int quad_id = (lane_id / Detail::kLanesInQuad); -+ int lane_in_quad = (lane_id % Detail::kLanesInQuad); -+ -+ thread_offset_ = { -+ quad_id, lane_in_quad * Policy::kElementsPerIteration -+ }; -+ -+ pointer_ += (layout_({thread_offset_.row(), thread_offset_.column()}) / Policy::kElementsPerAccess); -+ } -+ -+ /// Adds a pointer offset -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOp & add_pointer_offset(Index pointer_offset) { -+ pointer_ += pointer_offset / Policy::kElementsPerAccess; -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOp & add_tile_offset(TensorCoord const &tile_offset) { -+ -+ MatrixCoord coord_offset( -+ tile_offset.row() * Shape::kRow, -+ tile_offset.column() * Shape::kColumn -+ ); -+ -+ thread_offset_ += coord_offset; -+ -+ pointer_ += (layout_({ -+ coord_offset.row(), -+ coord_offset.column() -+ }) / Policy::kElementsPerAccess); -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOp & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::OperatorCount::kRow * Policy::kIterationsPerInstruction; n++ ) { -+ -+ AccessType *ptr = pointer_ + layout_({n * Policy::kRowsPerIteration, 0}) / Policy::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int a = 0; a < Policy::kAccessPerIteration; ++a) { -+ ptr[a + pointer_offset / Policy::kElementsPerAccess] = frag_ptr[n * Policy::kAccessPerIteration + a]; -+ -+// printf("store thread %d, address %p, bank %ld\n", threadIdx.x, pointer_+a+n*Detail::kLanesInQuad, -+// ((long long)(pointer_+a+n*Detail::kLanesInQuad)>>2)&0x1f); -+ } -+ } -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::OperatorCount::kRow * Policy::kIterationsPerInstruction; n++ ) { -+ -+ AccessType *ptr = pointer_ + layout_({n * Policy::kRowsPerIteration, 0}) / Policy::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int a = 0; a < Policy::kAccessPerIteration; ++a) { -+ frag_ptr[n * Policy::kAccessPerIteration + a] = ptr[a + pointer_offset / Policy::kElementsPerAccess]; -+ } -+ } -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOp & operator++() { -+ return add_tile_offset({0, 1}); -+ } -+ -+ /// Set smem base address -+ CUTLASS_HOST_DEVICE -+ void set_smem_base_address(Index address) { -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template for reading and writing tiles of accumulators to shared memory -+template < -+ typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape) -+ typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape) -+ typename Element_, ///< data type of element to be written -+ typename Layout_ -+> -+class TileIteratorTensorOpCanonical { -+public: -+ -+ using WarpShape = WarpShape_; -+ using OperatorShape = OperatorShape_; -+ using Element = Element_; -+ using Layout = Layout_; -+ -+ using TensorRef = TensorRef; ///< Tensor Reference object -+ using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor -+ using Index = typename TensorRef::Index; -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ using Policy = TensorOpPolicy; -+ -+ static int const kAccessSize = 1; -+ static int const kAccessCount = Policy::kElementsPerAccess / kAccessSize; -+ -+ /// Shape of the tile in memory -+ using Shape = MatrixShape< -+ Policy::kRowsPerIteration, -+ WarpShape::kN -+ >; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = Array< -+ Element, -+ Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>; -+ -+ /// This is the complete warp-level accumulator tile. -+ //using AccumulatorTile = typename Operator::FragmentC; -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = Policy::kIterations; -+ -+ // Internal constants -+ struct Detail { -+ static int const kLanesInQuad = 4; -+ }; -+ -+ /// Padding quantity -+ using Padding = MatrixShape< -+ 0, -+ Detail::kLanesInQuad * Policy::kElementsPerAccess>; -+ -+private: -+ -+ /// Storage type for accessing memory -+ using AccessType = AlignedArray; -+ -+ // -+ // Data members -+ // -+ -+ /// Internal pointer to memory -+ AccessType *pointer_; -+ -+ /// Internal layout object -+ Layout layout_; -+ -+ /// Guard to indicate whether the shape is divisible -+ bool divisible_; -+ -+ /// Extent of the output tensor -+ MatrixCoord extent_; -+ -+ /// Thread offset -+ MatrixCoord thread_offset_; -+ -+public: -+ -+ /// Default constructor -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpCanonical(): pointer_(nullptr) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpCanonical( -+ TensorRef const &ref, -+ unsigned lane_id -+ ): -+ pointer_(reinterpret_cast(ref.data())), -+ layout_(ref.stride()[0]), -+ divisible_(true), -+ extent_(WarpShape::kM, WarpShape::kN) { -+ -+ int quad_id = (lane_id / Detail::kLanesInQuad); -+ int lane_in_quad = (lane_id % Detail::kLanesInQuad); -+ -+ thread_offset_ = { -+ quad_id, lane_in_quad * Policy::kElementsPerAccess -+ }; -+ -+ pointer_ += layout_({thread_offset_.row(), thread_offset_.column()}); -+ } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpCanonical( -+ TensorRef const &ref, -+ TensorCoord const &extent, -+ unsigned lane_id -+ ): -+ pointer_(reinterpret_cast(ref.data())), -+ layout_(ref.stride()[0]), -+ divisible_(false), -+ extent_(extent) { -+ -+ int quad_id = (lane_id / Detail::kLanesInQuad); -+ int lane_in_quad = (lane_id % Detail::kLanesInQuad); -+ -+ thread_offset_ = { -+ quad_id, lane_in_quad * Policy::kElementsPerAccess -+ }; -+ -+ pointer_ += layout_({thread_offset_.row(), thread_offset_.column()}); -+ } -+ -+ /// Adds a pointer offset -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpCanonical & add_pointer_offset(Index pointer_offset) { -+ pointer_ += pointer_offset; -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpCanonical & add_tile_offset(TensorCoord const &tile_offset) { -+ -+ MatrixCoord coord_offset( -+ tile_offset.row() * Shape::kRow, -+ tile_offset.column() * Shape::kColumn -+ ); -+ -+ thread_offset_ += coord_offset; -+ -+ pointer_ += layout_({ -+ coord_offset.row(), -+ coord_offset.column() -+ }); -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpCanonical & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int a = 0; a < kAccessCount; ++a) { -+ -+ int ptr_idx = n * Detail::kLanesInQuad * kAccessCount + pointer_offset + a; -+ int frag_idx = n * kAccessCount + a; -+ -+ int col = thread_offset_.column() + n * Detail::kLanesInQuad * Policy::kElementsPerAccess + a; -+ -+ if (divisible_ || (thread_offset_.row() < extent_.row() && col < extent_.column())) { -+ pointer_[ptr_idx] = frag_ptr[frag_idx]; -+ } -+ } -+ } -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int a = 0; a < kAccessCount; ++a) { -+ -+ int ptr_idx = n * Detail::kLanesInQuad * kAccessCount + pointer_offset + a; -+ int frag_idx = n * kAccessCount + a; -+ -+ int col = thread_offset_.column() + n * Detail::kLanesInQuad * Policy::kElementsPerAccess + a; -+ -+ if (divisible_ || (thread_offset_.row() < extent_.row() && col < extent_.column())) { -+ frag_ptr[frag_idx] = pointer_[ptr_idx]; -+ } -+ } -+ } -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpCanonical & operator++() { -+ return add_tile_offset({1, 0}); -+ } -+ -+ /// Set smem base address -+ CUTLASS_HOST_DEVICE -+ void set_smem_base_address(Index address) { -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h b/3rdparty/cutlass/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h -new file mode 100644 -index 0000000..3bbc942 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h -@@ -0,0 +1,727 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/epilogue/warp/tensor_op_policy.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// This is an optimization available on CUDA 11.2 and beyond that eliminates branches in the epilogue. -+#define CUTLASS_EPILOGUE_WARP_TILE_ITERATOR_TENSOR_OP_MIXED_OPTIMIZATION_ENABLED ((__CUDACC_VER_MAJOR__ * 10 + __CUDACC_VER_MINOR__) >= 112) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template for reading and writing tiles of accumulators to shared memory. This is optimized -+/// for mixed-precision epilogues in which the accumulators are 32b in width, but the output -+/// data type is smaller. -+template < -+ typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape) -+ typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape) -+ typename Element_, ///< data type of accumulator element -+ int ElementSizeBits, ///< Size of accumulator element in bits -+ int OutputSizeBits, ///< Size of output element in bits -+ int OutputElementCount, ///< number of elements in output vector -+ int ContiguousLanes ///< Number of consecutive lanes writing to contiguous memory -+> -+class TileIteratorTensorOpMixed { -+public: -+ -+ using WarpShape = WarpShape_; -+ using OperatorShape = OperatorShape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ static int const kOutputElementCount = OutputElementCount; -+ -+ using TensorRef = TensorRef; ///< Tensor Reference object -+ using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor -+ using Index = typename TensorRef::Index; -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ using Policy = TensorOpPolicy; -+ -+ /// Shape of the tile in memory -+ using Shape = MatrixShape< -+ Policy::kRowsPerIteration, -+ WarpShape::kN -+ >; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = Array< -+ Element, -+ Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>; -+ -+ /// This is the complete warp-level accumulator tile. -+ //using AccumulatorTile = typename Operator::FragmentC; -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = Policy::kIterations; -+ -+ // Internal constants -+ struct Detail { -+ static int const kLanesInQuad = 4; -+ -+ /// Number of pointers needed to write accumulators -+ static int const kPointerCount = -+ (OutputElementCount * sizeof_bits::value) / (const_min(128, OutputElementCount * sizeof_bits::value)); -+ -+ static_assert(kPointerCount <= 4, "Can only accommodate four pointers at present."); -+ static_assert(sizeof(Element) == 4, "This can only be used with 32b accumulator data types (f32, s32)."); -+ }; -+ -+ /// Padding quantity -+ using Padding = MatrixShape< -+ 0, -+ Detail::kLanesInQuad * Policy::kElementsPerAccess>; -+ -+private: -+ -+ /// Storage type for accessing memory -+ using AccessType = AlignedArray; -+ -+ // -+ // Data members -+ // -+ -+ /// Internal pointer to memory -+ AccessType *pointers_[Detail::kPointerCount]; -+ -+ /// Stride in units of AccessType -+ int stride_; -+ -+ /// Logical column in which warp tile is aligned -+ int warp_column_; -+ -+public: -+ -+ /// Default constructor -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpMixed() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int64_t i = 0; i < Detail::kPointerCount; ++i) { -+ pointers_[i] = nullptr; -+ } -+ } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpMixed( -+ TensorRef const &ref, -+ unsigned lane_id -+ ): -+ stride_(ref.stride()[0] / Policy::kElementsPerAccess), -+ warp_column_(0) { -+ -+ int quad_id = (lane_id / Detail::kLanesInQuad); -+ int lane_in_quad = (lane_id % Detail::kLanesInQuad); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int64_t i = 0; i < Detail::kPointerCount; ++i) { -+ AccessType *ptr = reinterpret_cast(ref.data()) + quad_id * stride_; -+ int column_idx = (lane_in_quad % 2) + (((lane_in_quad / 2) + i) % Detail::kPointerCount) * 2; -+ -+ ptr += column_idx; -+ -+ if (i == 0) { -+ pointers_[0 % Detail::kPointerCount] = ptr; -+ } -+ else if (i == 1) { -+ pointers_[1 % Detail::kPointerCount] = ptr; -+ } -+ else if (i == 2) { -+ pointers_[2 % Detail::kPointerCount] = ptr; -+ } -+ else if (i == 3) { -+ pointers_[3 % Detail::kPointerCount] = ptr; -+ } -+ } -+ } -+ -+ /// Adds a pointer offset -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpMixed & add_pointer_offset(Index pointer_offset) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int64_t i = 0; i < Detail::kPointerCount; ++i) { -+ pointers_[i] += pointer_offset / Policy::kElementsPerAccess; -+ } -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpMixed & add_tile_offset(TensorCoord const &tile_offset) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int64_t i = 0; i < Detail::kPointerCount; ++i) { -+ pointers_[i] += tile_offset.row() * Shape::kRow * stride_ + -+ tile_offset.column() * Shape::kColumn / Policy::kElementsPerAccess; -+ } -+ -+ warp_column_ += tile_offset.column() * Shape::kColumn; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpMixed & operator+=(TensorCoord const &tile_offset) { -+ return add_tile_offset(tile_offset); -+ } -+ -+ /// Store -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ AccessType *ptr = pointers_[0]; -+ -+#if CUTLASS_EPILOGUE_WARP_TILE_ITERATOR_TENSOR_OP_MIXED_OPTIMIZATION_ENABLED -+ -+ // When the optimization is enabled, small tiles require separate logic. -+ bool kN32_optimization = (WarpShape::kN * Detail::kLanesInQuad * Policy::kElementsPerAccess * sizeof_bits::value) % 1024 == 0; -+ if (kN32_optimization) { -+ int ptr_idx = ((warp_column_ * sizeof_bits::value) / 1024) % Detail::kPointerCount; -+ if (ptr_idx == 0) { -+ ptr = pointers_[0]; -+ } else if (ptr_idx == 1) { -+ ptr = pointers_[1]; -+ } else if (ptr_idx == 2) { -+ ptr = pointers_[2]; -+ } else if (ptr_idx == 3) { -+ ptr = pointers_[3]; -+ } -+ } -+ -+#endif -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int64_t n = 0; n < Policy::OperatorCount::kColumn; ++n) { -+ -+#if CUTLASS_EPILOGUE_WARP_TILE_ITERATOR_TENSOR_OP_MIXED_OPTIMIZATION_ENABLED -+ -+ // -+ // When the optimization is enabled, this expression suffices to obtain the SMEM pointer. -+ // -+ if (WarpShape::kN == 64) { -+ ptr = pointers_[n / 4]; -+ } -+ else if (!kN32_optimization) -+#endif -+ { -+ // This is the reference implementation -+ int column_idx = warp_column_ + n * Detail::kLanesInQuad * Policy::kElementsPerAccess; -+ int ptr_idx = ((column_idx * sizeof_bits::value) / 1024) % Detail::kPointerCount; -+ -+ if (ptr_idx == 0) { -+ ptr = pointers_[0 % Detail::kPointerCount]; -+ } -+ else if (ptr_idx == 1) { -+ ptr = pointers_[1 % Detail::kPointerCount]; -+ } -+ else if (ptr_idx == 2) { -+ ptr = pointers_[2 % Detail::kPointerCount]; -+ } -+ else if (ptr_idx == 3) { -+ ptr = pointers_[3 % Detail::kPointerCount]; -+ } -+ } -+ -+ int offset = n * Detail::kLanesInQuad + pointer_offset / Policy::kElementsPerAccess; -+ ptr[offset] = frag_ptr[n]; -+ } -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int64_t n = 0; n < Policy::OperatorCount::kColumn; ++n) { -+ -+ int column_idx = warp_column_ + n * Detail::kLanesInQuad * Policy::kElementsPerAccess; -+ int ptr_idx = ((column_idx * sizeof_bits::value) / 1024) % Detail::kPointerCount; -+ -+ AccessType const *smem_ptr = pointers_[ptr_idx]; -+ frag_ptr[n] = smem_ptr[n * Detail::kLanesInQuad + pointer_offset / Policy::kElementsPerAccess]; -+ } -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Set smem base address -+ CUTLASS_HOST_DEVICE -+ void set_smem_base_address(Index address) { -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for int32_t x 16 => int8_t/int4b_t x 16 -+template < -+ typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape) -+ typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape), -+ int OutputSizeBits ///< Size of output element in bits -+> -+class TileIteratorTensorOpMixed { -+public: -+ -+ using WarpShape = WarpShape_; -+ using OperatorShape = OperatorShape_; -+ using Element = int32_t; -+ using Layout = layout::RowMajor; -+ static int const kOutputElementCount = 16; -+ -+ using TensorRef = TensorRef; ///< Tensor Reference object -+ using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor -+ using Index = typename TensorRef::Index; -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ using Policy = TensorOpPolicy; -+ -+ /// Shape of the tile in memory -+ using Shape = MatrixShape< -+ Policy::kRowsPerIteration, -+ WarpShape::kN -+ >; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = Array< -+ Element, -+ Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>; -+ -+ /// This is the complete warp-level accumulator tile. -+ //using AccumulatorTile = typename Operator::FragmentC; -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = Policy::kIterations; -+ -+ // Internal constants -+ struct Detail { -+ static int const kLanesInQuad = 4; -+ -+ /// Number of pointers needed to write accumulators -+ static int const kPointerCount = 2; -+ -+ /// Offsets added -+ static int const kOffsetCount = 4; -+ -+ static_assert(sizeof(Element) == 4, "This can only be used with 32b accumulator data types (f32, s32)."); -+ }; -+ -+ /// Padding quantity -+ using Padding = MatrixShape<0, Detail::kLanesInQuad * 2>; -+ -+private: -+ -+ /// Storage type for accessing memory -+ using AccessType = AlignedArray; -+ -+ // -+ // Data members -+ // -+ -+ /// Internal pointer to memory -+ AccessType *pointers_[Detail::kPointerCount]; -+ -+ /// Stride in units of AccessType -+ int stride_; -+ -+ /// Uniform offset in bytes added to warp tile iterator -+ int uniform_offset_[Detail::kOffsetCount]; -+ -+public: -+ -+ /// Default constructor -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpMixed() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int64_t i = 0; i < Detail::kPointerCount; ++i) { -+ pointers_[i] = nullptr; -+ } -+ } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpMixed( -+ TensorRef const &ref, -+ unsigned lane_id -+ ): -+ stride_(ref.stride()[0] / AccessType::kElements) { -+ -+ int quad_id = (lane_id / Detail::kLanesInQuad); -+ int lane_in_quad = (lane_id % Detail::kLanesInQuad); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Detail::kPointerCount; ++i) { -+ AccessType *ptr = reinterpret_cast(ref.data()) + quad_id * stride_; -+ int column_idx = lane_in_quad ^ (i * 2); -+ -+ ptr += column_idx; -+ -+ if (i == 0) { -+ pointers_[0] = ptr; -+ } -+ else if (i == 1) { -+ pointers_[1] = ptr; -+ } -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Detail::kOffsetCount; ++i) { -+ uniform_offset_[i] = (i ^ 0) * 4 * sizeof(AccessType); -+ } -+ } -+ -+ /// Adds a pointer offset -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpMixed & add_pointer_offset(Index pointer_offset) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int64_t i = 0; i < Detail::kPointerCount; ++i) { -+ pointers_[i] += pointer_offset / AccessType::kElements; -+ } -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpMixed & add_tile_offset(TensorCoord const &tile_offset) { -+ -+ int ptr_offset = tile_offset.row() * Shape::kRow * stride_ + -+ tile_offset.column() * Shape::kColumn / AccessType::kElements; -+ -+ pointers_[0] += ptr_offset; -+ pointers_[1] += ptr_offset; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Detail::kOffsetCount; ++i) { -+ uniform_offset_[i] = (i ^ tile_offset.column()) * 4 * sizeof(AccessType); -+ } -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpMixed & operator+=(TensorCoord const &tile_offset) { -+ return add_tile_offset(tile_offset); -+ } -+ -+ /// Store -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { -+ -+ int ptr_idx = (n / 4); -+ int offset_idx = (n % 4); -+ -+ AccessType *ptr; -+ if (ptr_idx == 0) { -+ ptr = pointers_[0]; -+ } -+ else if (ptr_idx == 1) { -+ ptr = pointers_[1]; -+ } -+ -+ int offset = (n / 4) * 16 + pointer_offset / AccessType::kElements; -+ -+#if 0 -+ // -+ // Using inline PTX to avoid generic memory -+ // -+ AccessType *smem_ptr = pointers_[ptr_idx]; -+ smem_ptr[offset] = frag_ptr[n]; -+#else -+ uint32_t smem_addr = arch::cutlass_get_smem_pointer(ptr); -+ uint32_t const *data = reinterpret_cast(frag_ptr + n); -+ uint32_t offset_in_bytes = offset * sizeof(AccessType) + uniform_offset_[offset_idx]; -+ -+ asm volatile( -+ "{ .reg .u32 smem_ptr; add.u32 smem_ptr, %0, %1; st.shared.v2.u32 [smem_ptr], {%2, %3}; }\n" -+ : : "r"(smem_addr), "r"(offset_in_bytes), "r"(data[0]), "r"(data[1]) -+ ); -+#endif -+ } -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Set smem base address -+ CUTLASS_HOST_DEVICE -+ void set_smem_base_address(Index address) { -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for int32_t x 8 => int8_t/int4b_t x 8 -+template < -+ typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape) -+ typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape) -+ int OutputSizeBits ///< Size of output element in bits -+> -+class TileIteratorTensorOpMixed { -+public: -+ -+ using WarpShape = WarpShape_; -+ using OperatorShape = OperatorShape_; -+ using Element = int32_t; -+ using Layout = layout::RowMajor; -+ static int const kOutputElementCount = 8; -+ -+ using TensorRef = TensorRef; ///< Tensor Reference object -+ using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor -+ using Index = typename TensorRef::Index; -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ using Policy = TensorOpPolicy; -+ -+ /// Shape of the tile in memory -+ using Shape = MatrixShape< -+ Policy::kRowsPerIteration, -+ WarpShape::kN -+ >; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = Array< -+ Element, -+ Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>; -+ -+ /// This is the complete warp-level accumulator tile. -+ //using AccumulatorTile = typename Operator::FragmentC; -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = Policy::kIterations; -+ -+ // Internal constants -+ struct Detail { -+ static int const kLanesInQuad = 4; -+ -+ /// Number of pointers needed to write accumulators -+ static int const kPointerCount = 2; -+ -+ static_assert(sizeof(Element) == 4, "This can only be used with 32b accumulator data types (f32, s32)."); -+ }; -+ -+ /// Padding quantity -+ using Padding = MatrixShape<0, Detail::kLanesInQuad * 2>; -+ -+private: -+ -+ /// Storage type for accessing memory -+ using AccessType = AlignedArray; -+ -+ // -+ // Data members -+ // -+ -+ /// Internal pointer to memory -+ AccessType *pointers_[Detail::kPointerCount]; -+ -+ /// Stride in units of AccessType -+ int stride_; -+ -+public: -+ -+ /// Default constructor -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpMixed() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int64_t i = 0; i < Detail::kPointerCount; ++i) { -+ pointers_[i] = nullptr; -+ } -+ } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpMixed( -+ TensorRef const &ref, -+ unsigned lane_id -+ ): -+ stride_(ref.stride()[0] / AccessType::kElements) { -+ -+ int quad_id = (lane_id / Detail::kLanesInQuad); -+ int lane_in_quad = (lane_id % Detail::kLanesInQuad); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Detail::kPointerCount; ++i) { -+ AccessType *ptr = reinterpret_cast(ref.data()) + quad_id * stride_; -+ int column_idx = lane_in_quad ^ (i * 2); -+ -+ ptr += column_idx; -+ -+ if (i == 0) { -+ pointers_[0] = ptr; -+ } -+ else if (i == 1) { -+ pointers_[1] = ptr; -+ } -+ } -+ } -+ -+ /// Adds a pointer offset -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpMixed & add_pointer_offset(Index pointer_offset) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int64_t i = 0; i < Detail::kPointerCount; ++i) { -+ pointers_[i] += pointer_offset / AccessType::kElements; -+ } -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpMixed & add_tile_offset(TensorCoord const &tile_offset) { -+ -+ int ptr_offset = tile_offset.row() * Shape::kRow * stride_ + -+ tile_offset.column() * Shape::kColumn / AccessType::kElements; -+ -+ pointers_[0] += ptr_offset; -+ pointers_[1] += ptr_offset; -+ -+ if (tile_offset.column() % 2) { -+ auto tmp = pointers_[0]; -+ pointers_[0] = pointers_[1]; -+ pointers_[1] = tmp; -+ } -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorTensorOpMixed & operator+=(TensorCoord const &tile_offset) { -+ return add_tile_offset(tile_offset); -+ } -+ -+ /// Store -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { -+ -+ int ptr_idx = (n / 4); -+ -+ AccessType *ptr; -+ if (ptr_idx == 0) { -+ ptr = pointers_[0]; -+ } -+ else if (ptr_idx == 1) { -+ ptr = pointers_[1]; -+ } -+ -+ int offset = (n / 4) * 16 + pointer_offset / AccessType::kElements + (n % 4) * 4; -+ -+#if 0 -+ // -+ // Using inline PTX to avoid generic memory -+ // -+ AccessType *smem_ptr = pointers_[ptr_idx]; -+ smem_ptr[offset] = frag_ptr[n]; -+#else -+ uint32_t smem_addr = arch::cutlass_get_smem_pointer(ptr); -+ uint32_t const *data = reinterpret_cast(frag_ptr + n); -+ uint32_t offset_in_bytes = offset * sizeof(AccessType); -+ -+ asm volatile( -+ "{ .reg .u32 smem_ptr; add.u32 smem_ptr, %0, %1; st.shared.v2.u32 [smem_ptr], {%2, %3}; }\n" -+ : : "r"(smem_addr), "r"(offset_in_bytes), "r"(data[0]), "r"(data[1]) -+ ); -+#endif -+ } -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Set smem base address -+ CUTLASS_HOST_DEVICE -+ void set_smem_base_address(Index address) { -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#undef CUTLASS_EPILOGUE_WARP_TILE_ITERATOR_TENSOR_OP_MIXED_OPTIMIZATION_ENABLED -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/warp/tile_iterator_volta_tensor_op.h b/3rdparty/cutlass/include/cutlass/epilogue/warp/tile_iterator_volta_tensor_op.h -new file mode 100644 -index 0000000..a4cabd7 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/warp/tile_iterator_volta_tensor_op.h -@@ -0,0 +1,440 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+ -+#include "cutlass/epilogue/warp/tensor_op_policy.h" -+#include "cutlass/epilogue/warp/volta_tensor_op_policy.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template for reading and writing tiles of accumulators to shared memory -+template < -+ typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) -+ typename InterleavedTileShape, ///< shape of indivisible instruction-level arrangement (concept: GemmShape) -+ typename ElementC, ///< Accumulator layout -+ typename Layout ///< target shared memory layout -+> -+struct TileIteratorVoltaTensorOp; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template for reading and writing tiles of accumulators to shared memory -+template < -+ typename WarpShape_ ///< shape of warp-level GEMM (concept: MatrixShape) -+> -+struct TileIteratorVoltaTensorOp, half_t, layout::RowMajor> { -+public: -+ -+ using WarpShape = WarpShape_; -+ using InterleavedTileShape = gemm::GemmShape<32, 32, 4>; -+ using Element = half_t; -+ using Layout = layout::RowMajor; -+ -+ using TensorRef = TensorRef; ///< Tensor Reference object -+ using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor -+ using Index = typename TensorRef::Index; -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ using Policy = VoltaTensorOpPolicy; -+ -+ /// Shape of the tile in memory -+ using Shape = MatrixShape< -+ Policy::kRowsPerIteration, -+ WarpShape::kN -+ >; -+ -+ /// Array type for aligned memory accesses -+ using AccessType = typename Policy::AccessType; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = typename Policy::Fragment; -+ -+ /// This is the complete warp-level accumulator tile. -+ using AccumulatorTile = typename Policy::AccumulatorTile; -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = Policy::kIterations; -+ -+ /// Number of elements per access -+ static int const kElementsPerAccess = Policy::kElementsPerAccess; -+ -+ // Internal constants -+ struct Detail { -+ static int const kLanesInQuad = 4; -+ static int const kRowsPerQuad = 4; -+ static int const kColumnsPerQuad = 8; -+ static int const kAccessesPerQuad = kColumnsPerQuad / Policy::kElementsPerAccess; -+ static int const kAccessQuadDelta = 16; -+ }; -+ -+ /// Padding quantity -+ using Padding = MatrixShape< -+ 0, -+ Policy::kElementsPerAccess>; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Internal pointer to memory -+ AccessType *pointer_; -+ -+ /// Internal layout object -+ Layout layout_; -+ -+public: -+ -+ /// Default constructor -+ CUTLASS_HOST_DEVICE -+ TileIteratorVoltaTensorOp(): pointer_(nullptr) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ TileIteratorVoltaTensorOp( -+ TensorRef const &ref, -+ unsigned lane_id -+ ): -+ pointer_(reinterpret_cast(ref.data())), -+ layout_(ref.stride()[0] / Policy::kElementsPerAccess) { -+ -+ int quad_id = lane_id / Detail::kLanesInQuad; -+ int lane_in_quad = (lane_id % Detail::kLanesInQuad); -+ -+ int quad_row_idx = ((quad_id & 4) >> 1) + (quad_id & 1); -+ int quad_col_idx = ((quad_id & 2) >> 1); -+ -+ int row = quad_row_idx * Detail::kRowsPerQuad + lane_in_quad; -+ int column = quad_col_idx * Detail::kColumnsPerQuad; -+ -+ pointer_ += layout_({row, column / kElementsPerAccess}); -+ } -+ -+ /// Adds a pointer offset -+ CUTLASS_HOST_DEVICE -+ TileIteratorVoltaTensorOp & add_pointer_offset(Index pointer_offset) { -+ pointer_ += pointer_offset / Policy::kElementsPerAccess; -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorVoltaTensorOp & add_tile_offset(TensorCoord const &tile_offset) { -+ -+ pointer_ += layout_({ -+ tile_offset.row() * Shape::kRow, -+ tile_offset.column() * Shape::kColumn / Policy::kElementsPerAccess}); -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorVoltaTensorOp & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ /// Store -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int tile_idx = 0; tile_idx < Policy::TileIterations::kColumn; ++tile_idx) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int access_idx = 0; access_idx < Policy::kAccessesPerInterleavedTile; ++access_idx) { -+ -+ int access_quad = access_idx / 2; -+ int access = access_idx % 2; -+ -+ int ptr_offset = tile_idx * InterleavedTileShape::kN / Policy::kElementsPerAccess + -+ access_quad * Detail::kAccessQuadDelta / Policy::kElementsPerAccess + -+ access + pointer_offset / Policy::kElementsPerAccess; -+ -+ int frag_idx = tile_idx * Policy::kAccessesPerInterleavedTile + access_idx; -+ -+ AccessType access_vector = frag_ptr[frag_idx]; -+ -+ pointer_[ptr_offset] = access_vector; -+ } -+ } -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int tile_idx = 0; tile_idx < Policy::TileIterations::kColumn; ++tile_idx) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int access_idx = 0; access_idx < Policy::kAccessesPerInterleavedTile; ++access_idx) { -+ -+ int access_quad = access_idx / 2; -+ int access = access_idx % 2; -+ -+ int ptr_offset = tile_idx * Detail::kTileDelta + access_quad * Detail::kAccessQuadDelta + -+ access + pointer_offset / Policy::kElementsPerAccess; -+ -+ int frag_idx = tile_idx * Policy::kAccessesPerInterleavedTile + access_idx; -+ -+ frag_ptr[frag_idx] = pointer_[ptr_offset]; -+ } -+ } -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load(Fragment const &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Set smem base address -+ CUTLASS_HOST_DEVICE -+ void set_smem_base_address(Index address) { -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template for reading and writing tiles of accumulators to shared memory -+template < -+ typename WarpShape_ ///< shape of warp-level GEMM (concept: MatrixShape) -+> -+struct TileIteratorVoltaTensorOp, float, layout::RowMajor> { -+public: -+ -+ using WarpShape = WarpShape_; -+ using InterleavedTileShape = gemm::GemmShape<32, 32, 4>; -+ using Element = float; -+ using Layout = layout::RowMajor; -+ -+ using TensorRef = TensorRef; ///< Tensor Reference object -+ using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor -+ using Index = typename TensorRef::Index; -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ using Policy = VoltaTensorOpPolicy; -+ -+ /// Shape of the tile in memory -+ using Shape = MatrixShape< -+ Policy::kRowsPerIteration, -+ WarpShape::kN -+ >; -+ -+ /// Array type for aligned memory accesses -+ using AccessType = typename Policy::AccessType; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = typename Policy::Fragment; -+ -+ /// This is the complete warp-level accumulator tile. -+ using AccumulatorTile = typename Policy::AccumulatorTile; -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = Policy::kIterations; -+ -+ /// Number of elements per access -+ static int const kElementsPerAccess = Policy::kElementsPerAccess; -+ -+ // Internal constants -+ struct Detail { -+ static int const kLanesInQuad = 4; -+ static int const kRowsPerQuad = 4; -+ static int const kColumnsPerQuad = 8; -+ static int const kAccessesPerQuad = kColumnsPerQuad / Policy::kElementsPerAccess; -+ static int const kAccessQuadDelta = 16; -+ }; -+ -+ /// Padding quantity -+ using Padding = MatrixShape< -+ 0, -+ Policy::kElementsPerAccess>; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Internal pointer to memory -+ AccessType *pointer_; -+ -+ /// Internal layout object -+ Layout layout_; -+ -+public: -+ -+ /// Default constructor -+ CUTLASS_HOST_DEVICE -+ TileIteratorVoltaTensorOp(): pointer_(nullptr) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ TileIteratorVoltaTensorOp( -+ TensorRef const &ref, -+ unsigned lane_id -+ ): -+ pointer_(reinterpret_cast(ref.data())), -+ layout_(ref.stride()[0] / Policy::kElementsPerAccess) { -+ -+ int quad_id = lane_id / Detail::kLanesInQuad; -+ int lane_in_quad = (lane_id % Detail::kLanesInQuad); -+ -+ int const kQuadRowDelta = 4; -+ int const kQuadColumnDelta = 2 * Policy::MmaIterations::kColumn; -+ -+ int quad_row_offset = ((quad_id & 4) / 2 + (quad_id & 1)) * kQuadRowDelta; -+ int quad_column_offset = (quad_id & 2) / 2 * kQuadColumnDelta; -+ -+ int thread_row_offset = (lane_in_quad & 1); -+ int thread_column_offset = (lane_in_quad & 2) / 2; -+ -+ int row = quad_row_offset + thread_row_offset; -+ int column = quad_column_offset + thread_column_offset; -+ -+ pointer_ += layout_({row, column}); -+ } -+ -+ /// Adds a pointer offset -+ CUTLASS_HOST_DEVICE -+ TileIteratorVoltaTensorOp & add_pointer_offset(Index pointer_offset) { -+ pointer_ += pointer_offset / Policy::kElementsPerAccess; -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorVoltaTensorOp & add_tile_offset(TensorCoord const &tile_offset) { -+ -+ pointer_ += layout_({ -+ tile_offset.row() * Shape::kRow, -+ tile_offset.column() * Shape::kColumn / Policy::kElementsPerAccess}); -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorVoltaTensorOp & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ /// Store -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ int const kAccessesPerRow = Policy::TileIterations::kColumn * Policy::MmaIterations::kColumn * 2; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row_idx = 0; row_idx < Policy::kRowsPerMmaTile; ++row_idx) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int access_idx = 0; access_idx < kAccessesPerRow; ++access_idx) { -+ -+ int frag_idx = row_idx * kAccessesPerRow + access_idx; -+ -+ int ptr_column_offset = (access_idx & 1) * 2 + -+ (access_idx & 2) * Policy::MmaIterations::kColumn * 2 + -+ (access_idx & 4) * Policy::MmaIterations::kColumn * 2; -+ -+ int ptr_row_offset = row_idx * 2; -+ -+ int ptr_offset = layout_({ptr_row_offset, ptr_column_offset}) + pointer_offset / Policy::kElementsPerAccess; -+ -+ pointer_[ptr_offset] = frag_ptr[frag_idx]; -+ } -+ } -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ assert(0); // TODO -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load(Fragment const &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Set smem base address -+ CUTLASS_HOST_DEVICE -+ void set_smem_base_address(Index address) { -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/warp/tile_iterator_wmma_tensor_op.h b/3rdparty/cutlass/include/cutlass/epilogue/warp/tile_iterator_wmma_tensor_op.h -new file mode 100644 -index 0000000..6856b3e ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/warp/tile_iterator_wmma_tensor_op.h -@@ -0,0 +1,227 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief -+*/ -+ -+#pragma once -+ -+#if !(defined(__clang__) && defined(__CUDA__)) -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/wmma_array.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/tensor_ref.h" -+ -+#include "cutlass/epilogue/warp/wmma_tensor_op_policy.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template for reading and writing tiles of accumulators to shared memory -+template < -+ typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) -+ typename OperatorShape, ///< matrix multiply operation shape (concept: gemm::GemmShape) -+ typename OperatorFragment, ///< wmma fragment to be written (concept: nvcuda::wmma::fragment) -+ typename Layout ///< target shared memory layout -+> -+class TileIteratorWmmaTensorOp; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template for reading and writing tiles of accumulators to shared memory -+template < -+ typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape) -+ typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape) -+ typename OperatorFragment_ ///< wmma fragment to be written (concept: nvcuda::wmma::fragment) -+> -+class TileIteratorWmmaTensorOp { -+public: -+ -+ using WarpShape = WarpShape_; -+ using OperatorShape = OperatorShape_; -+ using OperatorFragment = OperatorFragment_; -+ using Layout = layout::RowMajor; -+ -+ // -+ // Derived types -+ // -+ using WmmaDataType = typename OperatorFragment::element_type; -+ using Element = typename cutlass::arch::WmmaToCutlassDataType::Type; ///< Data Type of element stored in nvcuda::wmma::frament -+ using TensorRef = TensorRef; ///< Tensor Reference object -+ using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor -+ using Index = typename TensorRef::Index; -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ using Policy = WmmaTensorOpPolicy; -+ -+ /// Shape of the tile in memory -+ using Shape = MatrixShape< -+ Policy::kRowsPerIteration, -+ WarpShape::kN -+ >; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = WmmaFragmentArray; -+ -+ -+ /// This is the complete warp-level accumulator tile. -+ //using AccumulatorTile = typename Operator::FragmentC; -+ -+ -+ /// Padding quantity -+ // (Epilogue shared memory padding for WMMA Gemm kernel is set to run optimaly on Turing) -+ using Padding = MatrixShape< -+ 0, -+ 4 * Policy::kElementsPerAccess -+ >; -+ -+private: -+ -+ /// Storage type for accessing memory -+ //using AccessType = AlignedArray; -+ -+ // -+ // Data members -+ // -+ -+ /// Internal pointer to shared memory -+ TensorRef ref_; -+ -+ -+public: -+ -+ /// Default constructor -+ CUTLASS_HOST_DEVICE -+ TileIteratorWmmaTensorOp(): ref_(nullptr) { -+ -+ } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ TileIteratorWmmaTensorOp( -+ TensorRef const &ref, -+ unsigned lane_id -+ ): ref_(ref) { -+ } -+ -+ /// Adds a pointer offset -+ CUTLASS_HOST_DEVICE -+ TileIteratorWmmaTensorOp & add_pointer_offset(Index pointer_offset) { -+ ref_.add_pointer_offset(pointer_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorWmmaTensorOp & add_tile_offset(TensorCoord const &tile_offset) { -+ ref_.add_coord_offset({tile_offset.row() * OperatorShape::kM, tile_offset.column() * WarpShape::kN}); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_HOST_DEVICE -+ TileIteratorWmmaTensorOp & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ -+ for(int n=0; n < Policy::OperatorCount::kColumn; n++) { -+ -+ WmmaDataType* ptr = reinterpret_cast (ref_.data() + ref_.offset({0, n * OperatorShape::kN}) + pointer_offset); -+ -+ nvcuda::wmma::store_matrix_sync( -+ ptr, -+ frag[n], -+ ref_.stride()[0], -+ nvcuda::wmma::layout_t::mem_row_major -+ ); -+ -+ } -+ } -+ -+ /// Store -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ -+ for(int n=0; n < Policy::OperatorCount::kColumn; n++) { -+ -+ WmmaDataType* ptr = reinterpret_cast (ref_.data() + ref_.offset({0, n * OperatorShape::kN}) + pointer_offset); -+ -+ nvcuda::wmma::load_matrix_sync( -+ frag[n], -+ ptr, -+ ref_.stride()[0], -+ nvcuda::wmma::layout_t::mem_row_major -+ ); -+ -+ } -+ } -+ -+ /// Load -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ -+ /// Set smem base address -+ CUTLASS_HOST_DEVICE -+ void set_smem_base_address(Index address) { -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // !defined(__clang__) -+ -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/warp/volta_tensor_op_policy.h b/3rdparty/cutlass/include/cutlass/epilogue/warp/volta_tensor_op_policy.h -new file mode 100644 -index 0000000..dede3fd ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/warp/volta_tensor_op_policy.h -@@ -0,0 +1,195 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines basic structures needed for implementing the warp-scoped phase of the epilogue. -+ These quantities assume a 'column-major' arrangement of TensorOp instructions, of which -+ a row-oriented slice is visible per iteration. -+*/ -+ -+#pragma once -+ -+#include "cutlass/matrix_shape.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/gemm/gemm.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Policy details related to the epilogue -+template < -+ typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) -+ typename InterleavedTileShape, ///< shape of indivisible instruction-level arrangement (concept: GemmShape) -+ typename ElementC, ///< Accumulator layout -+ typename Layout ///< target shared memory layout -+> -+struct VoltaTensorOpPolicy; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for row-major -+template < -+ typename WarpShape_ ///< shape of warp-level GEMM (concept: GemmShape) -+> -+struct VoltaTensorOpPolicy, half_t, layout::RowMajor> { -+ -+ using WarpShape = WarpShape_; -+ using InterleavedTileShape = gemm::GemmShape<32, 32, 4>; -+ using ElementC = half_t; -+ using Layout = layout::RowMajor; -+ -+ /// Shape of one warp-levelinstruction -+ using InstructionShape = gemm::GemmShape<16, 16, 4>; -+ -+ /// Number of mma operations performed for one 32x32x4 interleaved tile -+ using MmaIterations = MatrixShape< -+ InterleavedTileShape::kM / InstructionShape::kM, -+ InterleavedTileShape::kN / InstructionShape::kN -+ >; -+ -+ /// Number of 32x32x4 interleaved tiles performed to cover the warp-level GEMM shape -+ using TileIterations = MatrixShape< -+ WarpShape::kM / InterleavedTileShape::kM, -+ WarpShape::kN / InterleavedTileShape::kN -+ >; -+ -+ /// Number of accumulator elements owned by each thread per Mma -+ static int const kElementsPerMma = 8; -+ static int const kRowsPerIteration = 16; -+ -+ // -+ // Hard-coded constants regarding Tensor Operations -+ // -+ -+ /// Number of accumulator elements stored per memory instruction to shared memory -+ static int const kElementsPerAccess = 4; -+ -+ /// Number of accesses performed per interleaved tile -+ static int const kAccessesPerInterleavedTile = 4; -+ -+ /// Total number of iterations needed to cover the entire tile -+ static int const kIterations = TileIterations::kRow * 2; -+ -+ // -+ // Derived types -+ // -+ -+ /// Array type for aligned memory accesses -+ using AccessType = AlignedArray; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = Array< -+ ElementC, -+ kElementsPerAccess * kAccessesPerInterleavedTile * TileIterations::kColumn>; -+ -+ /// This is the complete warp-level accumulator tile. -+ using AccumulatorTile = Array< -+ ElementC, -+ TileIterations::kCount * MmaIterations::kCount * kElementsPerMma>; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for row-major -+template < -+ typename WarpShape_ ///< shape of warp-level GEMM (concept: MatrixShape) -+> -+struct VoltaTensorOpPolicy, float, layout::RowMajor> { -+ -+ using WarpShape = WarpShape_; -+ using InterleavedTileShape = gemm::GemmShape<32, 32, 4>; -+ using ElementC = float; -+ using Layout = layout::RowMajor; -+ -+ /// Shape of one warp-levelinstruction -+ using InstructionShape = gemm::GemmShape<16, 16, 4>; -+ -+ /// Number of mma operations performed for one 32x32x4 interleaved tile -+ using MmaIterations = MatrixShape< -+ InterleavedTileShape::kM / InstructionShape::kM, -+ InterleavedTileShape::kN / InstructionShape::kN -+ >; -+ -+ /// Number of 32x32x4 interleaved tiles performed to cover the warp-level GEMM shape -+ using TileIterations = MatrixShape< -+ WarpShape::kM / InterleavedTileShape::kM, -+ WarpShape::kN / InterleavedTileShape::kN -+ >; -+ -+ /// Number of accumulator elements owned by each thread per Mma -+ static int const kElementsPerMma = 8; -+ static int const kRowsPerIteration = 16; -+ -+ // -+ // Hard-coded constants regarding Tensor Operations -+ // -+ -+ /// Number of accumulator elements stored per memory instruction to shared memory -+ static int const kElementsPerAccess = 2; -+ -+ /// Number of accesses performed per interleaved tile -+ static int const kAccessesPerInterleavedTile = 8; -+ -+ /// Number of rows per interleaved tile -+ static int const kRowsPerMmaTile = 2; -+ -+ /// Total number of iterations needed to cover the entire tile -+ static int const kIterations = TileIterations::kRow * MmaIterations::kRow; -+ -+ // -+ // Derived types -+ // -+ -+ /// Array type for aligned memory accesses -+ using AccessType = AlignedArray; -+ -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = Array< -+ ElementC, -+ kElementsPerAccess * kAccessesPerInterleavedTile * TileIterations::kColumn>; -+ -+ /// This is the complete warp-level accumulator tile. -+ using AccumulatorTile = Array< -+ ElementC, -+ TileIterations::kCount * MmaIterations::kCount * kElementsPerMma>; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace epilogue -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/epilogue/warp/wmma_tensor_op_policy.h b/3rdparty/cutlass/include/cutlass/epilogue/warp/wmma_tensor_op_policy.h -new file mode 100644 -index 0000000..bbce5cb ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/epilogue/warp/wmma_tensor_op_policy.h -@@ -0,0 +1,101 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines basic structures needed for implementing the warp-scoped phase of the epilogue. -+ These quantities assume a 'column-major' arrangement of TensorOp instructions, of which -+ a row-oriented slice is visible per iteration. -+*/ -+ -+#pragma once -+ -+#include "cutlass/arch/wmma.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/layout/matrix.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace warp { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Policy details related to the epilogue -+template < -+ typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) -+ typename OperatorShape, ///< matrix multiply operation shape (concept: gemm:GemmShape) -+ typename Layout ///< target shared memory layout -+> -+struct WmmaTensorOpPolicy; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for row-major -+template < -+ typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) -+ typename OperatorShape ///< matrix multiply operation shape (concept: gemm::GemmShape) -+> -+struct WmmaTensorOpPolicy { -+ -+ /// Number of operations -+ using OperatorCount = MatrixShape< -+ WarpShape::kM / OperatorShape::kM, -+ WarpShape::kN / OperatorShape::kN -+ >; -+ -+ // -+ // Hard-coded constants regarding Tensor Operations -+ // -+ static int const kElementsPerAccess = 2; -+ static int const kRowsPerIteration = OperatorShape::kM; -+ static int const kWmmaFragmentsPerAccess = 1; -+ -+ // -+ // Derived quantities -+ // -+ -+ // Number of externally visible iterations -+ static int const kIterations = OperatorCount::kRow; -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif -+ -diff --git a/3rdparty/cutlass/include/cutlass/fast_math.h b/3rdparty/cutlass/include/cutlass/fast_math.h -new file mode 100644 -index 0000000..c449def ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/fast_math.h -@@ -0,0 +1,975 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#include -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/uint128.h" -+#include "cutlass/coord.h" -+#include "cutlass/numeric_types.h" -+ -+/** -+ * \file -+ * \brief Math utilities -+ */ -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+CUTLASS_HOST_DEVICE void swap(T &lhs, T &rhs) { -+ T tmp = lhs; -+ lhs = rhs; -+ rhs = tmp; -+} -+ -+/****************************************************************************** -+ * Static math utilities -+ ******************************************************************************/ -+ -+/// Mixed precision dot product -+template -+CUTLASS_HOST_DEVICE LongIndex dot( -+ Coord const &coord, -+ Coord const &stride, -+ LongIndex acc = LongIndex()) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < N; ++n) { -+ acc += LongIndex(coord[n]) * stride[n]; -+ } -+ return acc; -+} -+ -+/** -+ * Statically determine if N is a power-of-two -+ */ -+template -+struct is_pow2 { -+ static bool const value = ((N & (N - 1)) == 0); -+}; -+ -+/** -+ * Statically determine log2(N), rounded down -+ */ -+template -+struct log2_down { -+ /// Static logarithm value -+ enum { value = log2_down> 1), Count + 1>::value }; -+}; -+ -+// Base case -+template -+struct log2_down { -+ enum { value = Count }; -+}; -+ -+/** -+ * Statically determine log2(N), rounded up -+ */ -+template -+struct log2_up { -+ /// Static logarithm value -+ enum { value = log2_up> 1), Count + 1>::value }; -+}; -+ -+// Base case -+template -+struct log2_up { -+ enum { value = ((1 << Count) < N) ? Count + 1 : Count }; -+}; -+ -+/** -+ * Statically estimate sqrt(N) to the nearest power-of-two -+ */ -+template -+struct sqrt_est { -+ enum { value = 1 << (log2_up::value / 2) }; -+}; -+ -+/** -+ * For performing a constant-division with a compile-time assertion that the -+ * Divisor evenly-divides the Dividend. -+ */ -+template -+struct divide_assert { -+ enum { value = Dividend / Divisor }; -+ -+ static_assert((Dividend % Divisor == 0), "Not an even multiple"); -+}; -+ -+/****************************************************************************** -+ * Rounding -+ ******************************************************************************/ -+ -+/** -+ * Round dividend up to the nearest multiple of divisor -+ */ -+template -+CUTLASS_HOST_DEVICE dividend_t round_nearest(dividend_t dividend, divisor_t divisor) { -+ return ((dividend + divisor - 1) / divisor) * divisor; -+} -+ -+/** -+ * Greatest common divisor -+ */ -+template -+CUTLASS_HOST_DEVICE value_t gcd(value_t a, value_t b) { -+ for (;;) { -+ if (a == 0) return b; -+ b %= a; -+ if (b == 0) return a; -+ a %= b; -+ } -+} -+ -+/** -+ * Least common multiple -+ */ -+template -+CUTLASS_HOST_DEVICE value_t lcm(value_t a, value_t b) { -+ value_t temp = gcd(a, b); -+ -+ return temp ? (a / temp * b) : 0; -+} -+ -+/// Returns the smallest value in the half-open range [a, a+b) that is a multiple of b -+CUTLASS_HOST_DEVICE -+constexpr int round_up(int a, int b) { -+ return ((a + b - 1) / b) * b; -+} -+ -+/// Returns the ceiling of (a / b) -+CUTLASS_HOST_DEVICE -+constexpr int ceil_div(int a, int b) { -+ return (a + b - 1) / b; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/** -+ * log2 computation, what's the -+ * difference between the below codes and -+ * log2_up/down codes? -+ */ -+template -+CUTLASS_HOST_DEVICE value_t clz(value_t x) { -+ for (int i = 31; i >= 0; --i) { -+ if ((1 << i) & x) return 31 - i; -+ } -+ return 32; -+} -+ -+template -+CUTLASS_HOST_DEVICE value_t find_log2(value_t x) { -+ int a = int(31 - clz(x)); -+ a += (x & (x - 1)) != 0; // Round up, add 1 if not a power of 2. -+ return a; -+} -+ -+ -+/** -+ * Find divisor, using find_log2 -+ */ -+CUTLASS_HOST_DEVICE -+void find_divisor(unsigned int& mul, unsigned int& shr, unsigned int denom) { -+ if (denom == 1) { -+ mul = 0; -+ shr = 0; -+ } else { -+ unsigned int p = 31 + find_log2(denom); -+ unsigned m = unsigned(((1ull << p) + unsigned(denom) - 1) / unsigned(denom)); -+ -+ mul = m; -+ shr = p - 32; -+ } -+} -+ -+/** -+ * Find quotient and remainder using device-side intrinsics -+ */ -+CUTLASS_HOST_DEVICE -+void fast_divmod(int& quo, int& rem, int src, int div, unsigned int mul, unsigned int shr) { -+ -+ #if defined(__CUDA_ARCH__) -+ // Use IMUL.HI if div != 1, else simply copy the source. -+ quo = (div != 1) ? __umulhi(src, mul) >> shr : src; -+ #else -+ quo = int((div != 1) ? int(((int64_t)src * mul) >> 32) >> shr : src); -+ #endif -+ -+ // The remainder. -+ rem = src - (quo * div); -+} -+ -+// For long int input -+CUTLASS_HOST_DEVICE -+void fast_divmod(int& quo, int64_t& rem, int64_t src, int div, unsigned int mul, unsigned int shr) { -+ -+ #if defined(__CUDA_ARCH__) -+ // Use IMUL.HI if div != 1, else simply copy the source. -+ quo = (div != 1) ? __umulhi(src, mul) >> shr : src; -+ #else -+ quo = int((div != 1) ? ((src * mul) >> 32) >> shr : src); -+ #endif -+ // The remainder. -+ rem = src - (quo * div); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Object to encapsulate the fast division+modulus operation. -+/// -+/// This object precomputes two values used to accelerate the computation and is best used -+/// when the divisor is a grid-invariant. In this case, it may be computed in host code and -+/// marshalled along other kernel arguments using the 'Params' pattern. -+/// -+/// Example: -+/// -+/// -+/// int quotient, remainder, dividend, divisor; -+/// -+/// FastDivmod divmod(divisor); -+/// -+/// divmod(quotient, remainder, dividend); -+/// -+/// // quotient = (dividend / divisor) -+/// // remainder = (dividend % divisor) -+/// -+struct FastDivmod { -+ -+ int divisor; -+ unsigned int multiplier; -+ unsigned int shift_right; -+ -+ /// Find quotient and remainder using device-side intrinsics -+ CUTLASS_HOST_DEVICE -+ void fast_divmod(int& quotient, int& remainder, int dividend) const { -+ -+#if defined(__CUDA_ARCH__) -+ // Use IMUL.HI if divisor != 1, else simply copy the source. -+ quotient = (divisor != 1) ? __umulhi(dividend, multiplier) >> shift_right : dividend; -+#else -+ quotient = int((divisor != 1) ? int(((int64_t)dividend * multiplier) >> 32) >> shift_right : dividend); -+#endif -+ -+ // The remainder. -+ remainder = dividend - (quotient * divisor); -+ } -+ -+ /// For long int input -+ CUTLASS_HOST_DEVICE -+ void fast_divmod(int& quotient, int64_t& remainder, int64_t dividend) const { -+ -+#if defined(__CUDA_ARCH__) -+ // Use IMUL.HI if divisor != 1, else simply copy the source. -+ quotient = (divisor != 1) ? __umulhi(dividend, multiplier) >> shift_right : dividend; -+#else -+ quotient = int((divisor != 1) ? ((dividend * multiplier) >> 32) >> shift_right : dividend); -+#endif -+ // The remainder. -+ remainder = dividend - (quotient * divisor); -+ } -+ -+ -+ /// Construct the FastDivmod object, in host code ideally. -+ /// -+ /// This precomputes some values based on the divisor and is computationally expensive. -+ -+ CUTLASS_HOST_DEVICE -+ FastDivmod(): divisor(0), multiplier(0), shift_right(0) { } -+ -+ CUTLASS_HOST_DEVICE -+ FastDivmod(int divisor): divisor(divisor) { -+ -+ if (divisor != 1) { -+ unsigned int p = 31 + find_log2(divisor); -+ unsigned m = unsigned(((1ull << p) + unsigned(divisor) - 1) / unsigned(divisor)); -+ -+ multiplier = m; -+ shift_right = p - 32; -+ } else { -+ multiplier = 0; -+ shift_right = 0; -+ } -+ } -+ -+ /// Computes integer division and modulus using precomputed values. This is computationally -+ /// inexpensive. -+ CUTLASS_HOST_DEVICE -+ void operator()(int "ient, int &remainder, int dividend) const { -+ fast_divmod(quotient, remainder, dividend); -+ } -+ -+ /// Computes integer division using precomputed values. This is computationally -+ /// inexpensive. -+ CUTLASS_HOST_DEVICE -+ int div(int dividend) const { -+ int quotient, remainder; -+ fast_divmod(quotient, remainder, dividend); -+ return quotient; -+ } -+ -+ /// Computes integer division and modulus using precomputed values. This is computationally -+ /// inexpensive. -+ /// -+ /// Simply returns the quotient -+ CUTLASS_HOST_DEVICE -+ int divmod(int &remainder, int dividend) const { -+ int quotient; -+ fast_divmod(quotient, remainder, dividend); -+ return quotient; -+ } -+ -+ /// Computes integer division and modulus using precomputed values. This is computationally -+ /// inexpensive. -+ CUTLASS_HOST_DEVICE -+ void operator()(int "ient, int64_t &remainder, int64_t dividend) const { -+ fast_divmod(quotient, remainder, dividend); -+ } -+ -+ /// Computes integer division and modulus using precomputed values. This is computationally -+ /// inexpensive. -+ CUTLASS_HOST_DEVICE -+ int divmod(int64_t &remainder, int64_t dividend) const { -+ int quotient; -+ fast_divmod(quotient, remainder, dividend); -+ return quotient; -+ } -+ -+ /// Returns the divisor when cast to integer -+ CUTLASS_HOST_DEVICE -+ operator int() const { return divisor; } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Object to encapsulate the fast division+modulus operation for 64b integer division. -+/// -+/// This object precomputes two values used to accelerate the computation and is best used -+/// when the divisor is a grid-invariant. In this case, it may be computed in host code and -+/// marshalled along other kernel arguments using the 'Params' pattern. -+/// -+/// Example: -+/// -+/// -+/// uint64_t quotient, remainder, dividend, divisor; -+/// -+/// FastDivmodU64 divmod(divisor); -+/// -+/// divmod(quotient, remainder, dividend); -+/// -+/// // quotient = (dividend / divisor) -+/// // remainder = (dividend % divisor) -+/// -+struct FastDivmodU64 { -+ -+ uint64_t divisor; -+ uint64_t multiplier; -+ unsigned int shift_right; -+ unsigned int round_up; -+ -+ // -+ // Static methods -+ // -+ -+ /// Computes b, where 2^b is the greatest power of two that is less than or equal to x -+ CUTLASS_HOST_DEVICE -+ static uint32_t integer_log2(uint64_t x) { -+ uint32_t n = 0; -+ while (x >>= 1) { -+ ++n; -+ } -+ return n; -+ } -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ FastDivmodU64(): divisor(0), multiplier(0), shift_right(0), round_up(0) { } -+ -+ /// Construct the FastDivmod object, in host code ideally. -+ /// -+ /// This precomputes some values based on the divisor and is computationally expensive. -+ CUTLASS_HOST_DEVICE -+ FastDivmodU64(uint64_t divisor_): divisor(divisor_), multiplier(1), shift_right(0), round_up(0) { -+ -+ if (divisor) { -+ shift_right = integer_log2(divisor); -+ -+ if ((divisor & (divisor - 1)) == 0) { -+ multiplier = 0; -+ } -+ else { -+ uint64_t power_of_two = (uint64_t(1) << shift_right); -+ uint64_t multiplier_lo = uint128_t(0, power_of_two) / divisor; -+ multiplier = uint128_t(power_of_two, power_of_two) / divisor; -+ round_up = (multiplier_lo == multiplier ? 1 : 0); -+ } -+ } -+ } -+ -+ /// Returns the quotient of floor(dividend / divisor) -+ CUTLASS_HOST_DEVICE -+ uint64_t divide(uint64_t dividend) const { -+ uint64_t quotient = 0; -+ -+ #ifdef __CUDA_ARCH__ -+ uint64_t x = dividend; -+ if (multiplier) { -+ x = __umul64hi(dividend + round_up, multiplier); -+ } -+ quotient = (x >> shift_right); -+ #else -+ // TODO - use proper 'fast' division here also. No reason why x86-code shouldn't be optimized. -+ quotient = dividend / divisor; -+ #endif -+ -+ return quotient; -+ } -+ -+ /// Computes the remainder given a computed quotient and dividend -+ CUTLASS_HOST_DEVICE -+ uint64_t modulus(uint64_t quotient, uint64_t dividend) const { -+ return uint32_t(dividend - quotient * divisor); -+ } -+ -+ /// Returns the quotient of floor(dividend / divisor) and computes the remainder -+ CUTLASS_HOST_DEVICE -+ uint64_t divmod(uint64_t &remainder, uint64_t dividend) const { -+ uint64_t quotient = divide(dividend); -+ remainder = modulus(quotient, dividend); -+ return quotient; -+ } -+ -+ /// Computes integer division and modulus using precomputed values. This is computationally -+ /// inexpensive. -+ CUTLASS_HOST_DEVICE -+ void operator()(uint64_t "ient, uint64_t &remainder, uint64_t dividend) const { -+ quotient = divmod(remainder, dividend); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes the coordinate decomposition from a linear index (64-bit linear index => coord) -+/// -+/// This decomposition is accelerated by the FastDivmodU64 object. It is assumed that -+/// a coordinate of indices can be decomposed by div/mod operations. -+/// Note, is assumed that element divmod[0] divides by extent[1]. -+/// -+/// For example, assume 4-D coordinate (n, p, q, c) is mapped to a linear index `npqc`. This -+/// can be decomposed via three divide and modulus operations: -+/// -+/// c = npqc % C; | divmod[2] = FastDivmodU64(C) -+/// npq = npqc / C; | coord[3] = c -+/// -+/// q = npq % Q; | divmod[1] = FastDivmodU64(Q) -+/// np = npq / Q; | coord[2] = q -+/// -+/// p = np % P; | divmod[0] = FastDivmodU64(P) -+/// n = np / P; | coord[1] = p -+/// -+/// | coord[0] = n -+/// -+template -+CUTLASS_HOST_DEVICE Coord CoordinateDecomposition( -+ uint64_t linear_idx, ///< Linear index to decompose -+ FastDivmodU64 const *divmod) { ///< Pointer to array of Rank-1 FastDivmodU64 objects -+ -+ static_assert(Rank > 0, "CoordinateDecomposition requires Rank=1 or greater."); -+ -+ Coord coord; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = Rank; i > 1; --i) { -+ uint64_t remainder; -+ linear_idx = divmod[i - 2].divmod(remainder, linear_idx); -+ coord[i - 1] = int(remainder); -+ } -+ -+ coord[0] = int(linear_idx); -+ -+ return coord; -+} -+ -+/// Computes the coordinate decomposition from a linear index (32-bit linear index => coord) -+template -+CUTLASS_HOST_DEVICE Coord CoordinateDecomposition( -+ int linear_idx, ///< Linear index to decompose -+ FastDivmod const *divmod) { ///< Pointer to array of Rank-1 FastDivmodU64 objects -+ -+ static_assert(Rank > 0, "CoordinateDecomposition requires Rank=1 or greater."); -+ -+ Coord coord; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = Rank; i > 1; --i) { -+ int remainder; -+ linear_idx = divmod[i - 2].divmod(remainder, linear_idx); -+ coord[i - 1] = int(remainder); -+ } -+ -+ coord[0] = int(linear_idx); -+ -+ return coord; -+} -+ -+template -+CUTLASS_HOST_DEVICE Coord CoordinateDecompositionLittleEndian( -+ uint64_t linear_idx, ///< Linear index to decompose -+ FastDivmodU64 const *divmod) { ///< Pointer to array of Rank-1 FastDivmodU64 objects -+ -+ static_assert(Rank > 0, "CoordinateDecomposition requires Rank=1 or greater."); -+ -+ Coord coord; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Rank - 1; ++i) { -+ uint64_t remainder; -+ linear_idx = divmod[i].divmod(remainder, linear_idx); -+ coord[i] = int(remainder); -+ } -+ -+ coord[Rank - 1] = int(linear_idx); -+ -+ return coord; -+} -+ -+/// Computes the coordinate decomposition from a linear index (32-bit linear index => coord) -+template -+CUTLASS_HOST_DEVICE Coord CoordinateDecompositionLittleEndian( -+ int linear_idx, ///< Linear index to decompose -+ FastDivmod const *divmod) { ///< Pointer to array of Rank-1 FastDivmodU64 objects -+ -+ static_assert(Rank > 0, "CoordinateDecomposition requires Rank=1 or greater."); -+ -+ Coord coord; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Rank - 1; ++i) { -+ int remainder; -+ linear_idx = divmod[i].divmod(remainder, linear_idx); -+ coord[i] = int(remainder); -+ } -+ -+ coord[Rank - 1] = int(linear_idx); -+ -+ return coord; -+} -+ -+/// Safely computes the offset of a linear index in bytes for all types -+template -+CUTLASS_HOST_DEVICE int64_t OffsetBytes(int64_t index) { -+ -+ static_assert( -+ (sizeof_bits::value >= 8 && !(sizeof_bits::value % 8)) || -+ (sizeof_bits::value < 8 && !(8 % sizeof_bits::value)), -+ "Size of numeric type in bits must either be divisible by 8 bits, or 8 bits must be divisible by the size."); -+ -+ if (sizeof_bits::value >= 8) { -+ return index * (sizeof_bits::value / 8); -+ } -+ else { -+ int const kElementsPerByte = ((8 / sizeof_bits::value) + ((sizeof_bits::value >= 8) ? 1 : 0)); -+ return index / kElementsPerByte; -+ } -+} -+ -+CUTLASS_HOST_DEVICE int64_t OffsetBytes(int64_t index, int64_t element_sizeof_bits) { -+ if (element_sizeof_bits >= 8) { -+ return index * (element_sizeof_bits / 8); -+ } -+ else { -+ int64_t const kElementsPerByte = ((8 / element_sizeof_bits) + ((element_sizeof_bits >= 8) ? 1 : 0)); -+ return index / kElementsPerByte; -+ } -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// Min/Max -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct Min { -+ static int const kValue = (A < B) ? A : B; -+}; -+ -+template -+struct Max { -+ static int const kValue = (A > B) ? A : B; -+}; -+ -+CUTLASS_HOST_DEVICE -+constexpr int const_min(int a, int b) { -+ return (b < a ? b : a); -+} -+ -+CUTLASS_HOST_DEVICE -+constexpr int const_max(int a, int b) { -+ return (b > a ? b : a); -+} -+ -+template -+CUTLASS_HOST_DEVICE -+T fast_min(T a, T b) { -+ return (b < a ? b : a); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+float fast_min(float a, float b) { -+ return fminf(a, b); -+} -+ -+template -+CUTLASS_HOST_DEVICE -+T fast_max(T a, T b) { -+ return (a < b ? b : a); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+float fast_max(float a, float b) { -+ return fmaxf(a, b); -+} -+ -+CUTLASS_HOST_DEVICE -+float fast_cos(float theta) { -+ #if defined(__CUDA_ARCH__) -+ return ::cosf(theta); -+ #else -+ return std::cos(theta); -+ #endif -+} -+ -+CUTLASS_HOST_DEVICE -+double fast_cos(double theta) { -+ #if defined(__CUDA_ARCH__) -+ return ::cos(theta); -+ #else -+ return std::cos(theta); -+ #endif -+} -+ -+CUTLASS_HOST_DEVICE -+float fast_sin(float theta) { -+ #if defined(__CUDA_ARCH__) -+ return ::sinf(theta); -+ #else -+ return std::sin(theta); -+ #endif -+} -+ -+CUTLASS_HOST_DEVICE -+double fast_sin(double theta) { -+ #if defined(__CUDA_ARCH__) -+ return ::sin(theta); -+ #else -+ return std::sin(theta); -+ #endif -+} -+ -+CUTLASS_HOST_DEVICE -+float fast_acos(float theta) { -+ #if defined(__CUDA_ARCH__) -+ return ::acosf(theta); -+ #else -+ return std::acos(theta); -+ #endif -+} -+ -+CUTLASS_HOST_DEVICE -+double fast_acos(double theta) { -+ #if defined(__CUDA_ARCH__) -+ return ::acos(theta); -+ #else -+ return std::acos(theta); -+ #endif -+} -+ -+CUTLASS_HOST_DEVICE -+float fast_asin(float theta) { -+ #if defined(__CUDA_ARCH__) -+ return ::asinf(theta); -+ #else -+ return std::asin(theta); -+ #endif -+} -+ -+CUTLASS_HOST_DEVICE -+double fast_asin(double theta) { -+ #if defined(__CUDA_ARCH__) -+ return ::asin(theta); -+ #else -+ return std::asin(theta); -+ #endif -+} -+ -+CUTLASS_HOST_DEVICE -+float fast_sqrt(float theta) { -+ #if defined(__CUDA_ARCH__) -+ return ::sqrtf(theta); -+ #else -+ return std::sqrt(theta); -+ #endif -+} -+ -+CUTLASS_HOST_DEVICE -+double fast_sqrt(double theta) { -+ #if defined(__CUDA_ARCH__) -+ return ::sqrt(theta); -+ #else -+ return std::sqrt(theta); -+ #endif -+} -+ -+CUTLASS_HOST_DEVICE -+float fast_exp(float x) { -+ #if defined(__CUDA_ARCH__) -+ return ::expf(x); -+ #else -+ return std::exp(x); -+ #endif -+} -+ -+CUTLASS_HOST_DEVICE -+double fast_exp(double x) { -+ #if defined(__CUDA_ARCH__) -+ return ::exp(x); -+ #else -+ return std::exp(x); -+ #endif -+} -+ -+CUTLASS_HOST_DEVICE -+half_t fast_exp(half_t x) { -+ #if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 10) && (__CUDA_ARCH__ >= 750) -+ return (half_t)(::hexp(x.to_half())); -+ #else -+ return (half_t)(fast_exp(float(x))); -+ #endif -+} -+ -+CUTLASS_HOST_DEVICE -+float fast_log(float x) { -+ #if defined(__CUDA_ARCH__) -+ return ::logf(x); -+ #else -+ return std::log(x); -+ #endif -+} -+ -+CUTLASS_HOST_DEVICE -+double fast_log(double x) { -+ #if defined(__CUDA_ARCH__) -+ return ::log(x); -+ #else -+ return std::log(x); -+ #endif -+} -+ -+CUTLASS_HOST_DEVICE -+float fast_tanh(float x) { -+ #if defined(__CUDA_ARCH__) -+ #if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDA_ARCH__ >= 750) -+ float y; -+ asm volatile ( "tanh.approx.f32 %0, %1; " : "=f"(y) : "f"(x)); -+ return y; -+ #else -+ return ::tanhf(x); -+ #endif -+ #else -+ return std::tanh(x); -+ #endif -+} -+ -+CUTLASS_HOST_DEVICE -+double fast_tanh(double x) { -+ #if defined(__CUDA_ARCH__) -+ return ::tanh(x); -+ #else -+ return std::tanh(x); -+ #endif -+} -+ -+CUTLASS_HOST_DEVICE -+half_t fast_tanh(half_t x) { -+ #if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 11) && (__CUDA_ARCH__ >= 750) -+ -+ asm volatile ( "tanh.approx.f16 %0, %1;" : "=h"(x.raw()) : "h"(x.raw())); -+ return x; -+ -+ #else -+ return half_t(fast_tanh(float(x))); -+ #endif -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct fast_exp_op { -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &rhs) const { -+ return fast_exp(rhs); -+ } -+}; -+ -+#if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 10) && (__CUDA_ARCH__ >= 750) -+template -+struct fast_exp_op> { -+ CUTLASS_DEVICE -+ Array operator()(Array const &rhs) const { -+ -+ Array result; -+ -+ // use x2 specialization -+ __half2 const *in = reinterpret_cast<__half2 const *>(&rhs); -+ __half2 *out = reinterpret_cast<__half2 *>(&result); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ out[i] = ::h2exp(in[i]); -+ } -+ -+ // residual -+ if (N % 2) { -+ half_t last = rhs[N - 1]; -+ result[N - 1] = half_t(::hexp(last.to_half())); -+ } -+ -+ return result; -+ } -+}; -+#endif // #if defined(__CUDA_ARCH__) -+ -+template -+struct fast_exp_op> { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &rhs) const { -+ -+ fast_exp_op fast_op; -+ Array y; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ y[i] = fast_op(rhs[i]); -+ } -+ -+ return y; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct fast_tanh_op { -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &rhs) const { -+ return fast_tanh(rhs); -+ } -+}; -+ -+#if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 11) && (__CUDA_ARCH__ >= 750) -+template -+struct fast_tanh_op> { -+ CUTLASS_DEVICE -+ Array operator()(Array const &rhs) const { -+ -+ Array result; -+ -+ // use x2 specialization -+ uint32_t const *in = reinterpret_cast(&rhs); -+ uint32_t *out = reinterpret_cast(&result); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ asm volatile ("tanh.approx.f16x2 %0, %1;" : "=r"(out[i]) : "r"(in[i])); -+ } -+ -+ // residual -+ if (N % 2) { -+ uint16_t const *in = reinterpret_cast(&rhs); -+ uint16_t *out = reinterpret_cast(&result); -+ asm volatile ("tanh.approx.f16 %0, %1;" : "=h"(out[N - 1]) : "h"(in[N - 1])); -+ } -+ -+ return result; -+ } -+}; -+#endif // #if defined(__CUDA_ARCH__) -+ -+template -+struct fast_tanh_op> { -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &rhs) const { -+ -+ fast_tanh_op fast_op; -+ Array y; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ y[i] = fast_op(rhs[i]); -+ } -+ -+ return y; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Absolute value function -+template -+CUTLASS_HOST_DEVICE -+T absolute_value(T x) { -+ if (x < T()) { -+ return -x; -+ } -+ return x; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/include/cutlass/float8.h b/3rdparty/cutlass/include/cutlass/float8.h -new file mode 100644 -index 0000000..93e3209 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/float8.h -@@ -0,0 +1,1215 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief Defines a class for using IEEE half-precision floating-point types in host or -+ device code. -+*/ -+#pragma once -+ -+#include -+ -+#include "cutlass/cutlass.h" -+ -+#if defined(__CUDACC_RTC__) -+ -+#include "cutlass/floating_point_nvrtc.h" -+ -+#else -+// -+// Standard Library headers belong here to avoid conflicts with NVRTC. -+// -+#include -+#include -+#include -+#include -+#endif -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if (__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8)) -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) -+#ifndef CUDA_PTX_FP8_CVT_ENABLED -+#define CUDA_PTX_FP8_CVT_ENABLED 1 -+#endif -+#endif -+#endif -+ -+#ifdef __GNUC__ -+// Ignore checks on reinterpret-casts that are being used for bitcasts. -+#pragma GCC diagnostic ignored "-Wstrict-aliasing" -+#endif -+ -+namespace cutlass { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// FP8 Has 2 encodings possible : E4M3 and E5M2 -+// -+// E4M3 : 7 | 6 5 4 3 | 2 1 0 -+// E5M2 : 7 | 6 5 4 3 2 | 1 0 -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+enum class FloatEncoding { -+ E4M3, -+ E5M2 -+}; -+ -+template -+struct alignas(1) float8_base { -+ -+ static constexpr bool IS_E4M3 = (T == FloatEncoding::E4M3); -+ static constexpr bool IS_E5M2 = (T == FloatEncoding::E5M2); -+ -+ // Number of Bits representing mantissa and exponents -+ static constexpr int FP32_NUM_BITS = 32; -+ static constexpr int FP32_NUM_EXPONENT_BITS = 8; -+ static constexpr int FP32_NUM_MANTISSA_BITS = 23; -+ static constexpr uint32_t FP32_NAN = 0x7fffffff; -+ static constexpr uint32_t FP32_INFINITY_MASK = 0x7f800000; -+ static constexpr int FP32_MAX_EXPONENT = 127; -+ static constexpr int FP32_MIN_EXPONENT = -126; -+ static constexpr int FP32_EXPONENT_BIAS = 127; -+ -+ static constexpr int FP16_NUM_BITS = 16; -+ static constexpr int FP16_NUM_EXPONENT_BITS = 5; -+ static constexpr int FP16_NUM_MANTISSA_BITS = 10; -+ static constexpr uint16_t FP16_NAN = 0x7fff; -+ static constexpr uint16_t FP16_INFINITY_MASK = 0x7c00; -+ static constexpr int FP16_MAX_EXPONENT = 15; -+ static constexpr int FP16_MIN_EXPONENT = -14; -+ static constexpr int FP16_EXPONENT_BIAS = 15; -+ -+ static constexpr int FP8_NUM_BITS = 8; -+ static constexpr int FP8_NUM_EXPONENT_BITS = IS_E4M3 ? 4 : 5; -+ static constexpr int FP8_NUM_MANTISSA_BITS = IS_E4M3 ? 3 : 2; -+ static constexpr uint8_t FP8_NAN = 0x7f; // Also F8_INF -+ static constexpr uint8_t FP8_INFINITY_MASK = IS_E4M3 ? 0x78 : 0x7c; -+ static constexpr int FP8_MAX_EXPONENT = IS_E4M3 ? 7 : 15; -+ static constexpr int FP8_MIN_EXPONENT = IS_E4M3 ? -6 : -14; -+ static constexpr int FP8_EXPONENT_BIAS = IS_E4M3 ? 7 : 15; -+ -+ static constexpr uint8_t FP8_EXPONENT_MASK = (1 << FP8_NUM_EXPONENT_BITS) - 1; -+ static constexpr uint8_t FP8_MANTISSA_MASK = (1 << FP8_NUM_MANTISSA_BITS) - 1; -+ -+ static constexpr uint8_t FP8_MAX_FLT = (IS_E4M3 ? 0x7e : 0x7b); -+ -+ // 256 in float -+ static constexpr uint32_t FP8_SAT_VAL_FP32 = 0x43800000; -+ -+ // -+ // Data members -+ // -+ -+ /// Data container -+ uint8_t storage; -+ -+ /// Ctors. -+ CUTLASS_HOST_DEVICE -+ float8_base() : storage(0) { } -+ -+ /// Is finite implementation -+ CUTLASS_HOST_DEVICE -+ static bool isfinite(float flt) { -+ uint32_t s; -+ -+ #if defined(__CUDA_ARCH__) -+ s = reinterpret_cast(flt); -+ #else -+ std::memcpy(&s, &flt, sizeof(s)); -+ #endif -+ -+ return (s & 0x7f800000) < 0x7f800000; -+ } -+ -+ /// Is NaN implementation -+ CUTLASS_HOST_DEVICE -+ static bool isnan(float flt) { -+ uint32_t s; -+ -+ #if defined(__CUDA_ARCH__) -+ s = reinterpret_cast(flt); -+ #else -+ std::memcpy(&s, &flt, sizeof(s)); -+ #endif -+ -+ return (s & 0x7fffffff) > 0x7f800000; -+ } -+ -+ /// Is infinite implementation -+ CUTLASS_HOST_DEVICE -+ static bool isinf(float flt) { -+ uint32_t s; -+ -+ #if defined(__CUDA_ARCH__) -+ s = reinterpret_cast(flt); -+ #else -+ std::memcpy(&s, &flt, sizeof(s)); -+ #endif -+ -+ // Sign = 0 for +inf, 1 for -inf -+ // Exponent = all ones -+ // Mantissa = all zeros -+ return (s == 0x7f800000) || (s == 0xff800000); -+ } -+ -+ /// FP32 -> FP8 conversion - rounds to nearest even -+ CUTLASS_HOST_DEVICE -+ static uint8_t convert_float_to_fp8(float const& flt) { -+ -+ // software implementation rounds toward nearest even -+ uint32_t s; -+ -+ #if defined(__CUDA_ARCH__) -+ s = reinterpret_cast(flt); -+ #else -+ std::memcpy(&s, &flt, sizeof(s)); -+ #endif -+ -+ // Extract the bits in the FP32 type -+ uint8_t sign = uint8_t((s >> 24 & 0x80)); -+ int8_t exp = uint8_t(((s >> FP32_NUM_MANTISSA_BITS) & 0xff) - FP32_EXPONENT_BIAS); -+ int mantissa = s & 0x7fffff; -+ uint8_t u = 0; -+ -+ uint8_t const kF8_NaN = 0x7f; -+ -+ // NaN => NaN -+ if (isnan(flt)) { -+ return kF8_NaN; -+ } -+ -+ // Inf => MAX_FLT (satfinite) -+ if (isinf(flt)) { -+ return sign | FP8_MAX_FLT; -+ } -+ -+ // Special handling -+ if ( exp == -128 ) { -+ // int8 range is from -128 to 127 -+ // So 255(inf) - 127(bias) = 128 - will show up as -128 -+ -+ // satfinite -+ return (sign | FP8_MAX_FLT); -+ } -+ -+ int sticky_bit = 0; -+ -+ bool skip_sign = false; -+ bool may_be_nan = false; -+ -+ if ( (exp >= FP8_MIN_EXPONENT) && (exp <= FP8_MAX_EXPONENT) ) { -+ // normal fp32 to normal fp8 -+ exp = uint8_t(exp + uint8_t(FP8_EXPONENT_BIAS)); -+ u = uint8_t(((exp & FP8_EXPONENT_MASK) << FP8_NUM_MANTISSA_BITS)); -+ u = uint8_t(u | (mantissa >> (FP32_NUM_MANTISSA_BITS - FP8_NUM_MANTISSA_BITS))); -+ } else if(exp < FP8_MIN_EXPONENT) { -+ // normal single-precision to subnormal float8-precision representation -+ int rshift = (FP8_MIN_EXPONENT - exp); -+ if (rshift < FP32_NUM_BITS) { -+ mantissa |= (1 << FP32_NUM_MANTISSA_BITS); -+ -+ sticky_bit = ((mantissa & ((1 << rshift) - 1)) != 0); -+ -+ mantissa = (mantissa >> rshift); -+ u = (uint8_t(mantissa >> (FP32_NUM_MANTISSA_BITS- FP8_NUM_MANTISSA_BITS)) & FP8_MANTISSA_MASK); -+ } else { -+ mantissa = 0; -+ u = 0; -+ } -+ // Exponent > FP8_MAX_EXPONENT - this is a special case done to match HW -+ // 0x4380_0000 to 0x43e0_0000 - maps from 256 to 448, and does not saturate / inf. -+ } else { -+ if( exp == (FP8_MAX_EXPONENT + 1) ) { -+ uint8_t mantissa_tmp = uint8_t(mantissa >> (FP32_NUM_MANTISSA_BITS - FP8_NUM_MANTISSA_BITS)); -+ if( mantissa_tmp < FP8_MANTISSA_MASK) { -+ exp = uint8_t(exp + uint8_t(FP8_EXPONENT_BIAS)); -+ u = uint8_t(exp << FP8_NUM_MANTISSA_BITS) | mantissa_tmp; -+ may_be_nan = (mantissa_tmp == (FP8_MANTISSA_MASK-1)); -+ } else { -+ // satfinite -+ return (sign | FP8_MAX_FLT); -+ } -+ } else{ -+ // satfinite -+ return (sign | FP8_MAX_FLT); -+ } -+ } -+ -+ // round to nearest even -+ int NUM_BITS_SHIFT = FP32_NUM_MANTISSA_BITS - (FP8_NUM_MANTISSA_BITS + 1); -+ int round_bit = ((mantissa >> NUM_BITS_SHIFT) & 1); -+ sticky_bit |= ((mantissa & ((1 << NUM_BITS_SHIFT) - 1)) != 0); -+ -+ if ((round_bit && sticky_bit) || (round_bit && (u & 1))) { -+ u = uint8_t(u + 1); -+ if( may_be_nan ) { -+ skip_sign = true; -+ } -+ } -+ -+ if (u > FP8_MAX_FLT) { -+ // satfinite -+ u = (sign | FP8_MAX_FLT); -+ } -+ -+ if( ! skip_sign ) { -+ u |= sign; -+ } -+ -+ return u; -+ } -+ -+ -+ /// Converts a fp8 value stored as a uint8_t to a float -+ CUTLASS_HOST_DEVICE -+ static float convert_fp8_to_float(uint8_t const& x) { -+ -+ uint32_t constexpr kF32_NaN = 0x7fffffff; -+ -+ uint8_t const &f8 = x; -+ int sign = (f8 >> (FP8_NUM_BITS - 1)) & 1; -+ int exp = (f8 >> FP8_NUM_MANTISSA_BITS) & FP8_EXPONENT_MASK; -+ int mantissa = f8 & FP8_MANTISSA_MASK; -+ unsigned f = (sign << (FP32_NUM_BITS-1)); -+ -+ if (IS_E4M3 && exp == 15 && mantissa == 0x7) { -+ f = kF32_NaN; -+ } -+ else if (exp > 0 && (IS_E4M3 || exp < (FP8_MAX_EXPONENT + FP8_EXPONENT_BIAS + 1))) { -+ // normal -+ exp += (FP32_EXPONENT_BIAS - FP8_EXPONENT_BIAS); -+ f = f | -+ (exp << FP32_NUM_MANTISSA_BITS) | -+ (mantissa << (FP32_NUM_MANTISSA_BITS-FP8_NUM_MANTISSA_BITS)); -+ } else if (exp == 0) { -+ if (mantissa) { -+ // subnormal -+ exp += (FP32_EXPONENT_BIAS - FP8_EXPONENT_BIAS) + 1; -+ while ((mantissa & (1 << FP8_NUM_MANTISSA_BITS)) == 0) { -+ mantissa <<= 1; -+ exp--; -+ } -+ mantissa &= FP8_MANTISSA_MASK; -+ f = f | -+ (exp << FP32_NUM_MANTISSA_BITS) | -+ (mantissa << (FP32_NUM_MANTISSA_BITS-FP8_NUM_MANTISSA_BITS)); -+ } else { -+ // sign-preserving zero -+ } -+ } else { -+ if(mantissa == 0){ -+ // Sign-preserving infinity -+ f = (f | 0x7f800000); -+ } else { -+ // Canonical NaN -+ f = kF32_NaN; -+ } -+ } -+ -+ #if defined(__CUDA_ARCH__) -+ return reinterpret_cast(f); -+ #else -+ float flt; -+ std::memcpy(&flt, &f, sizeof(flt)); -+ return flt; -+ #endif -+ } -+}; -+ -+ -+// Forward declaration of float_e5m2_t to define float_e4m3_t <=> float_e5m2_t -+// conversions in class float_e4m3_t -+struct float_e5m2_t; -+ -+ -+/////////////////////////////////////////////////////////////// -+/// -+/// floating-point 8 type : E4M3 -+/// -+/////////////////////////////////////////////////////////////// -+struct alignas(1) float_e4m3_t : float8_base { -+ -+ using Base = float8_base; -+ -+ static constexpr int MAX_EXPONENT = Base::FP8_MAX_EXPONENT; -+ -+ // -+ // Static conversion operators -+ // -+ -+ /// Constructs from an uint8_t -+ CUTLASS_HOST_DEVICE -+ static float_e4m3_t bitcast(uint8_t x) { -+ float_e4m3_t f; -+ f.storage = x; -+ return f; -+ } -+ -+ /// FP32 -> FP8 conversion - rounds to nearest even -+ CUTLASS_HOST_DEVICE -+ static float_e4m3_t from_float(float const& flt) { -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ uint16_t tmp; -+ float y = float(); -+ asm volatile("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(tmp) : "f"(y), "f"(flt)); -+ -+ return *reinterpret_cast(&tmp); -+ #else -+ return bitcast(Base::convert_float_to_fp8(flt)); -+ #endif -+ } -+ -+ /// FP16 -> E5M2 conversion - rounds to nearest even -+ CUTLASS_HOST_DEVICE -+ static float_e4m3_t from_half(half const& flt) { -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ uint16_t tmp = 0; -+ uint32_t bits = reinterpret_cast(flt); -+ asm volatile("cvt.rn.satfinite.e4m3x2.f16x2 %0, %1;" : "=h"(tmp) : "r"(bits)); -+ -+ return *reinterpret_cast(&tmp); -+ #else -+ return bitcast(Base::convert_float_to_fp8(float(flt))); -+ #endif -+ } -+ -+ // E4M3 -> half -+ CUTLASS_HOST_DEVICE -+ static half to_half(float_e4m3_t const& x) { -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ uint16_t bits = x.storage; -+ uint32_t packed; -+ asm volatile("cvt.rn.f16x2.e4m3x2 %0, %1;\n" : "=r"(packed) : "h"(bits)); -+ -+ return reinterpret_cast(packed).x; -+ #else -+ return half(Base::convert_fp8_to_float(x.storage)); -+ #endif -+ } -+ -+ // E4M3 -> Float -+ CUTLASS_HOST_DEVICE -+ static float to_float(float_e4m3_t const& x) { -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ uint16_t bits = x.storage; -+ uint32_t packed; -+ asm volatile("cvt.rn.f16x2.e4m3x2 %0, %1;\n" : "=r"(packed) : "h"(bits)); -+ -+ return float(reinterpret_cast(packed).x); -+ #else -+ return Base::convert_fp8_to_float(x.storage); -+ #endif -+ } -+ -+ // -+ // Methods -+ // -+ -+ /// Default constructor -+ CUTLASS_HOST_DEVICE -+ float_e4m3_t() : Base() { } -+ -+ /// Reinterpret cast from CUDA's FP8 type -+ CUTLASS_HOST_DEVICE -+ float_e4m3_t(float_e4m3_t const& x) { -+ #if defined(__CUDA_ARCH__) -+ storage = reinterpret_cast(x); -+ #else -+ uint8_t raw = x.storage; -+ std::memcpy(&storage, &raw, sizeof(storage)); -+ #endif -+ } -+ -+ /// Floating point conversion -+ CUTLASS_HOST_DEVICE -+ explicit float_e4m3_t(float x) { -+ storage = from_float(x).storage; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ explicit float_e4m3_t(half x) { -+ storage = from_half(x).storage; -+ } -+ -+ /// Floating point conversion -+ CUTLASS_HOST_DEVICE -+ explicit float_e4m3_t(double x): float_e4m3_t(float(x)) { -+ } -+ -+ /// Integer conversion -+ CUTLASS_HOST_DEVICE -+ explicit float_e4m3_t(int x): float_e4m3_t(float(x)) { -+ } -+ -+ /// E5M2 conversion. Defined after float_e5m2_t is defined. -+ CUTLASS_HOST_DEVICE -+ explicit float_e4m3_t(float_e5m2_t x); -+ -+ /// Assignment -+ CUTLASS_HOST_DEVICE -+ float_e4m3_t & operator=(float_e4m3_t const &x) { -+ #if defined(__CUDA_ARCH__) -+ storage = reinterpret_cast(x); -+ #else -+ uint8_t raw = x.storage; -+ std::memcpy(&storage, &raw, sizeof(storage)); -+ #endif -+ return *this; -+ } -+ -+ /// Converts to float -+ CUTLASS_HOST_DEVICE -+ operator float() const { -+ return to_float(*this); -+ } -+ -+ /// Converts to half -+ CUTLASS_HOST_DEVICE -+ operator half() const { -+ return to_half(*this); -+ } -+ -+ /// Converts to float -+ CUTLASS_HOST_DEVICE -+ explicit operator double() const { -+ return double(to_float(*this)); -+ } -+ -+ /// Converts to int -+ CUTLASS_HOST_DEVICE -+ explicit operator int() const { -+ #if defined(__CUDA_ARCH__) -+ return __half2int_rn(to_half(*this)); -+ #else -+ return int(to_float(*this)); -+ #endif -+ } -+ -+ /// Casts to bool -+ CUTLASS_HOST_DEVICE -+ explicit operator bool() const { -+ #if defined(__CUDA_ARCH__) -+ return bool(__half2int_rn(to_half(*this))); -+ #else -+ return bool(int(to_float(*this))); -+ #endif -+ } -+ -+ /// Accesses raw internal state -+ CUTLASS_HOST_DEVICE -+ uint8_t& raw() { -+ return storage; -+ } -+ -+ /// Accesses raw internal state -+ CUTLASS_HOST_DEVICE -+ uint8_t raw() const { -+ return storage; -+ } -+ -+ /// Returns the sign bit -+ CUTLASS_HOST_DEVICE -+ bool signbit() const { -+ return ((storage & (1 << (Base::FP8_NUM_BITS - 1))) != 0); -+ } -+ -+ /// Returns the biased exponent -+ CUTLASS_HOST_DEVICE -+ int exponent_biased() const { -+ return int((storage >> FP8_NUM_MANTISSA_BITS) & Base::FP8_EXPONENT_MASK); -+ } -+ -+ /// Returns the unbiased exponent -+ CUTLASS_HOST_DEVICE -+ int exponent() const { -+ return exponent_biased() - 15; -+ } -+ -+ /// Returns the mantissa -+ CUTLASS_HOST_DEVICE -+ int mantissa() const { -+ return int(storage & Base::FP8_MANTISSA_MASK); -+ } -+}; -+ -+/////////////////////////////////////////////////////////////// -+/// -+/// floating-point 8 type : E5M2 -+/// -+/////////////////////////////////////////////////////////////// -+struct alignas(1) float_e5m2_t : float8_base { -+ -+ using Base = float8_base; -+ -+ static constexpr int MAX_EXPONENT = Base::FP8_MAX_EXPONENT; -+ -+ // -+ // Static conversion operators -+ // -+ -+ /// Constructs from an uint8_t -+ CUTLASS_HOST_DEVICE -+ static float_e5m2_t bitcast(uint8_t x) { -+ float_e5m2_t f; -+ f.storage = x; -+ return f; -+ } -+ -+ /// FP32 -> FP8 conversion - rounds to nearest even -+ CUTLASS_HOST_DEVICE -+ static float_e5m2_t from_float(float const& flt) { -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ uint16_t tmp; -+ float y = float(); -+ asm volatile("cvt.rn.satfinite.e5m2x2.f32 %0, %1, %2;" : "=h"(tmp) : "f"(y), "f"(flt)); -+ -+ return *reinterpret_cast(&tmp); -+ #else -+ return bitcast(Base::convert_float_to_fp8(flt)); -+ #endif -+ } -+ -+ /// FP16 -> E5M2 conversion - rounds to nearest even -+ CUTLASS_HOST_DEVICE -+ static float_e5m2_t from_half(half const& flt) { -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ uint16_t tmp = 0; -+ uint32_t bits = reinterpret_cast(flt); -+ asm volatile("cvt.rn.satfinite.e5m2x2.f16x2 %0, %1;" : "=h"(tmp) : "r"(bits)); -+ -+ return *reinterpret_cast(&tmp); -+ #else -+ return bitcast(Base::convert_float_to_fp8(float(flt))); -+ #endif -+ } -+ -+ // E5M2 -> half -+ CUTLASS_HOST_DEVICE -+ static half to_half(float_e5m2_t const& x) { -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ uint16_t bits = x.storage; -+ uint32_t packed; -+ asm volatile("cvt.rn.f16x2.e5m2x2 %0, %1;\n" : "=r"(packed) : "h"(bits)); -+ -+ return reinterpret_cast(packed).x; -+ #else -+ return half(Base::convert_fp8_to_float(x.storage)); -+ #endif -+ } -+ -+ // E5M2 -> Float -+ CUTLASS_HOST_DEVICE -+ static float to_float(float_e5m2_t const& x) { -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ uint16_t bits = x.storage; -+ uint32_t packed; -+ asm volatile("cvt.rn.f16x2.e5m2x2 %0, %1;\n" : "=r"(packed) : "h"(bits)); -+ -+ return float(reinterpret_cast(packed).x); -+ #else -+ return Base::convert_fp8_to_float(x.storage); -+ #endif -+ } -+ -+ // -+ // Methods -+ // -+ -+ /// Default constructor -+ CUTLASS_HOST_DEVICE -+ float_e5m2_t() : Base() { } -+ -+ /// Reinterpret cast from CUDA's FP8 type -+ CUTLASS_HOST_DEVICE -+ float_e5m2_t(float_e5m2_t const& x) { -+ #if defined(__CUDA_ARCH__) -+ storage = reinterpret_cast(x); -+ #else -+ uint8_t raw = x.storage; -+ std::memcpy(&storage, &raw, sizeof(storage)); -+ #endif -+ } -+ -+ /// Floating point conversion -+ CUTLASS_HOST_DEVICE -+ explicit float_e5m2_t(float x) { -+ storage = from_float(x).storage; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ explicit float_e5m2_t(half x) { -+ storage = from_half(x).storage; -+ } -+ -+ /// Floating point conversion -+ CUTLASS_HOST_DEVICE -+ explicit float_e5m2_t(double x): float_e5m2_t(float(x)) { -+ } -+ -+ /// Integer conversion -+ CUTLASS_HOST_DEVICE -+ explicit float_e5m2_t(int x): float_e5m2_t(float(x)) { -+ } -+ -+ /// E4M3 conversion -+ CUTLASS_HOST_DEVICE -+ explicit float_e5m2_t(float_e4m3_t x); -+ -+ /// Assignment -+ CUTLASS_HOST_DEVICE -+ float_e5m2_t & operator=(float_e5m2_t const &x) { -+ #if defined(__CUDA_ARCH__) -+ storage = reinterpret_cast(x); -+ #else -+ uint8_t raw = x.storage; -+ std::memcpy(&storage, &raw, sizeof(storage)); -+ #endif -+ return *this; -+ } -+ -+ /// Converts to float -+ CUTLASS_HOST_DEVICE -+ operator float() const { -+ return to_float(*this); -+ } -+ -+ /// Converts to half -+ CUTLASS_HOST_DEVICE -+ operator half() const { -+ return to_half(*this); -+ } -+ -+ /// Converts to float -+ CUTLASS_HOST_DEVICE -+ explicit operator double() const { -+ return double(to_float(*this)); -+ } -+ -+ /// Converts to int -+ CUTLASS_HOST_DEVICE -+ explicit operator int() const { -+ #if defined(__CUDA_ARCH__) -+ return __half2int_rn(to_half(*this)); -+ #else -+ return int(to_float(*this)); -+ #endif -+ } -+ -+ /// Casts to bool -+ CUTLASS_HOST_DEVICE -+ explicit operator bool() const { -+ #if defined(__CUDA_ARCH__) -+ return bool(__half2int_rn(to_half(*this))); -+ #else -+ return bool(int(to_float(*this))); -+ #endif -+ } -+ -+ /// Accesses raw internal state -+ CUTLASS_HOST_DEVICE -+ uint8_t& raw() { -+ return storage; -+ } -+ -+ /// Accesses raw internal state -+ CUTLASS_HOST_DEVICE -+ uint8_t raw() const { -+ return storage; -+ } -+ -+ /// Returns the sign bit -+ CUTLASS_HOST_DEVICE -+ bool signbit() const { -+ return ((storage & (1 << (Base::FP8_NUM_BITS - 1))) != 0); -+ } -+ -+ /// Returns the biased exponent -+ CUTLASS_HOST_DEVICE -+ int exponent_biased() const { -+ return int((storage >> FP8_NUM_MANTISSA_BITS) & Base::FP8_EXPONENT_MASK); -+ } -+ -+ /// Returns the unbiased exponent -+ CUTLASS_HOST_DEVICE -+ int exponent() const { -+ return exponent_biased() - 15; -+ } -+ -+ /// Returns the mantissa -+ CUTLASS_HOST_DEVICE -+ int mantissa() const { -+ return int(storage & Base::FP8_MANTISSA_MASK); -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Arithmetic operators -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+CUTLASS_HOST_DEVICE -+bool operator==(float_e4m3_t const& lhs, float_e4m3_t const& rhs) { -+ return float(lhs) == float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator!=(float_e4m3_t const& lhs, float_e4m3_t const& rhs) { -+ return float(lhs) != float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator<(float_e4m3_t const& lhs, float_e4m3_t const& rhs) { -+ return float(lhs) < float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator<=(float_e4m3_t const& lhs, float_e4m3_t const& rhs) { -+ return float(lhs) <= float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator>(float_e4m3_t const& lhs, float_e4m3_t const& rhs) { -+ return float(lhs) > float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator>=(float_e4m3_t const& lhs, float_e4m3_t const& rhs) { -+ return float(lhs) >= float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+float_e4m3_t operator+(float_e4m3_t const& lhs, float_e4m3_t const& rhs) { -+ return float_e4m3_t(float(lhs) + float(rhs)); -+} -+ -+CUTLASS_HOST_DEVICE -+float_e4m3_t operator-(float_e4m3_t const& lhs) { -+ return float_e4m3_t(-float(lhs)); -+} -+ -+CUTLASS_HOST_DEVICE -+float_e4m3_t operator-(float_e4m3_t const& lhs, float_e4m3_t const& rhs) { -+ return float_e4m3_t(float(lhs) - float(rhs)); -+} -+ -+CUTLASS_HOST_DEVICE -+float_e4m3_t operator*(float_e4m3_t const& lhs, float_e4m3_t const& rhs) { -+ return float_e4m3_t(float(lhs) * float(rhs)); -+} -+ -+CUTLASS_HOST_DEVICE -+float_e4m3_t operator/(float_e4m3_t const& lhs, float_e4m3_t const& rhs) { -+ return float_e4m3_t(float(lhs) / float(rhs)); -+} -+ -+CUTLASS_HOST_DEVICE -+float_e4m3_t& operator+=(float_e4m3_t & lhs, float_e4m3_t const& rhs) { -+ lhs = float_e4m3_t(float(lhs) + float(rhs)); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+float_e4m3_t& operator-=(float_e4m3_t & lhs, float_e4m3_t const& rhs) { -+ lhs = float_e4m3_t(float(lhs) - float(rhs)); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+float_e4m3_t& operator*=(float_e4m3_t & lhs, float_e4m3_t const& rhs) { -+ lhs = float_e4m3_t(float(lhs) * float(rhs)); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+float_e4m3_t& operator/=(float_e4m3_t & lhs, float_e4m3_t const& rhs) { -+ lhs = float_e4m3_t(float(lhs) / float(rhs)); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+float_e4m3_t& operator++(float_e4m3_t & lhs) { -+ float tmp(lhs); -+ ++tmp; -+ lhs = float_e4m3_t(tmp); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+float_e4m3_t& operator--(float_e4m3_t & lhs) { -+ float tmp(lhs); -+ --tmp; -+ lhs = float_e4m3_t(tmp); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+float_e4m3_t operator++(float_e4m3_t & lhs, int) { -+ float_e4m3_t ret(lhs); -+ float tmp(lhs); -+ tmp++; -+ lhs = float_e4m3_t(tmp); -+ return ret; -+} -+ -+CUTLASS_HOST_DEVICE -+float_e4m3_t operator--(float_e4m3_t & lhs, int) { -+ float_e4m3_t ret(lhs); -+ float tmp(lhs); -+ tmp--; -+ lhs = float_e4m3_t(tmp); -+ return ret; -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator==(float_e5m2_t const& lhs, float_e5m2_t const& rhs) { -+ return float(lhs) == float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator!=(float_e5m2_t const& lhs, float_e5m2_t const& rhs) { -+ return float(lhs) != float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator<(float_e5m2_t const& lhs, float_e5m2_t const& rhs) { -+ return float(lhs) < float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator<=(float_e5m2_t const& lhs, float_e5m2_t const& rhs) { -+ return float(lhs) <= float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator>(float_e5m2_t const& lhs, float_e5m2_t const& rhs) { -+ return float(lhs) > float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator>=(float_e5m2_t const& lhs, float_e5m2_t const& rhs) { -+ return float(lhs) >= float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+float_e5m2_t operator+(float_e5m2_t const& lhs, float_e5m2_t const& rhs) { -+ return float_e5m2_t(float(lhs) + float(rhs)); -+} -+ -+CUTLASS_HOST_DEVICE -+float_e5m2_t operator-(float_e5m2_t const& lhs) { -+ return float_e5m2_t(-float(lhs)); -+} -+ -+CUTLASS_HOST_DEVICE -+float_e5m2_t operator-(float_e5m2_t const& lhs, float_e5m2_t const& rhs) { -+ return float_e5m2_t(float(lhs) - float(rhs)); -+} -+ -+CUTLASS_HOST_DEVICE -+float_e5m2_t operator*(float_e5m2_t const& lhs, float_e5m2_t const& rhs) { -+ return float_e5m2_t(float(lhs) * float(rhs)); -+} -+ -+CUTLASS_HOST_DEVICE -+float_e5m2_t operator/(float_e5m2_t const& lhs, float_e5m2_t const& rhs) { -+ return float_e5m2_t(float(lhs) / float(rhs)); -+} -+ -+CUTLASS_HOST_DEVICE -+float_e5m2_t& operator+=(float_e5m2_t & lhs, float_e5m2_t const& rhs) { -+ lhs = float_e5m2_t(float(lhs) + float(rhs)); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+float_e5m2_t& operator-=(float_e5m2_t & lhs, float_e5m2_t const& rhs) { -+ lhs = float_e5m2_t(float(lhs) - float(rhs)); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+float_e5m2_t& operator*=(float_e5m2_t & lhs, float_e5m2_t const& rhs) { -+ lhs = float_e5m2_t(float(lhs) * float(rhs)); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+float_e5m2_t& operator/=(float_e5m2_t & lhs, float_e5m2_t const& rhs) { -+ lhs = float_e5m2_t(float(lhs) / float(rhs)); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+float_e5m2_t& operator++(float_e5m2_t & lhs) { -+ float tmp(lhs); -+ ++tmp; -+ lhs = float_e5m2_t(tmp); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+float_e5m2_t& operator--(float_e5m2_t & lhs) { -+ float tmp(lhs); -+ --tmp; -+ lhs = float_e5m2_t(tmp); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+float_e5m2_t operator++(float_e5m2_t & lhs, int) { -+ float_e5m2_t ret(lhs); -+ float tmp(lhs); -+ tmp++; -+ lhs = float_e5m2_t(tmp); -+ return ret; -+} -+ -+CUTLASS_HOST_DEVICE -+float_e5m2_t operator--(float_e5m2_t & lhs, int) { -+ float_e5m2_t ret(lhs); -+ float tmp(lhs); -+ tmp--; -+ lhs = float_e5m2_t(tmp); -+ return ret; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// float_e4m3_t <=> float_e5m2_t conversions -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// float_e4m3_t <= float_e5m2_t -+CUTLASS_HOST_DEVICE -+float_e4m3_t::float_e4m3_t(float_e5m2_t x) { -+ storage = from_float(float_e5m2_t::to_float(x)).storage; -+} -+ -+/// float_e5m2_t <= float_e4m3_t -+CUTLASS_HOST_DEVICE -+float_e5m2_t::float_e5m2_t(float_e4m3_t x) { -+ storage = from_float(float_e4m3_t::to_float(x)).storage; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Standard Library operations and definitions -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if !defined(__CUDACC_RTC__) -+namespace std { -+ -+/// Numeric limits common to all float8 types -+template -+struct float8_base_numeric_limits { -+private: -+ using F8Type = T; -+public: -+ static bool const is_specialized = true; -+ static bool const is_signed = true; -+ static bool const is_integer = false; -+ static bool const is_exact = false; -+ static bool const has_quiet_NaN = true; -+ static bool const has_signaling_NaN = false; -+ static std::float_denorm_style const has_denorm = std::denorm_present; -+ static bool const has_denorm_loss = true; -+ static std::float_round_style const round_style = std::round_to_nearest; -+ static bool const is_iec559 = false; -+ static bool const is_bounded = true; -+ static bool const is_modulo = false; -+ static int const digits = F8Type::FP8_NUM_MANTISSA_BITS; -+ -+ /// Least positive value -+ static F8Type min() { return F8Type::bitcast(0x01); } -+ -+ /// Maximum finite value -+ static F8Type max() { return F8Type::bitcast(F8Type::FP8_MAX_FLT); } -+ -+ /// Returns maximum rounding error -+ static F8Type round_error() { return F8Type(0.5f); } -+ -+ /// Returns positive infinity value -+ static F8Type infinity() { return F8Type::bitcast(F8Type::FP8_INFINITY_MASK); } -+ -+ /// Returns quiet NaN value -+ static F8Type quiet_NaN() { return F8Type::bitcast(F8Type::FP8_NAN); } -+ -+ /// Returns signaling NaN value -+ static F8Type signaling_NaN() { return F8Type::bitcast(F8Type::FP8_NAN); } -+ -+ /// Returns smallest positive subnormal value -+ static F8Type denorm_min() { return F8Type::bitcast(0x01); } -+}; -+ -+/// Numeric limits for float_e4m3_t -+template <> -+struct numeric_limits : -+ public float8_base_numeric_limits { -+ static bool const has_infinity = false; -+ -+ /// Minimum finite value -+ static cutlass::float_e4m3_t lowest() { return cutlass::float_e4m3_t::bitcast(0xfe); } -+ -+ /// Returns smallest finite value -+ static cutlass::float_e4m3_t epsilon() { return cutlass::float_e4m3_t::bitcast(0x20); } -+}; -+ -+/// Numeric limits for float_e5m2_t -+template <> -+struct numeric_limits : -+ public float8_base_numeric_limits { -+ static bool const has_infinity = true; -+ -+ /// Minimum finite value -+ static cutlass::float_e5m2_t lowest() { return cutlass::float_e5m2_t::bitcast(0xfb); } -+ -+ /// Returns smallest finite value -+ static cutlass::float_e5m2_t epsilon() { return cutlass::float_e5m2_t::bitcast(0x34); } -+}; -+ -+} // namespace std -+#endif -+ -+namespace platform { -+ -+/// Numeric limits common to all float8 types -+template -+struct float8_base_numeric_limits { -+private: -+ using F8Type = T; -+public: -+ static bool const is_specialized = true; -+ static bool const is_signed = true; -+ static bool const is_integer = false; -+ static bool const is_exact = false; -+ static bool const has_quiet_NaN = true; -+ static bool const has_signaling_NaN = false; -+#if !defined(__CUDACC_RTC__) -+ static std::float_denorm_style const has_denorm = std::denorm_present; -+#endif -+ static bool const has_denorm_loss = true; -+#if !defined(__CUDACC_RTC__) -+ static std::float_round_style const round_style = std::round_to_nearest; -+#endif -+ static bool const is_iec559 = false; -+ static bool const is_bounded = true; -+ static bool const is_modulo = false; -+ static int const digits = F8Type::FP8_NUM_MANTISSA_BITS; -+ -+ /// Least positive value -+ static F8Type min() { return F8Type::bitcast(0x01); } -+ -+ /// Maximum finite value -+ static F8Type max() { return F8Type::bitcast(F8Type::FP8_MAX_FLT); } -+ -+ /// Returns maximum rounding error -+ static F8Type round_error() { return F8Type(0.5f); } -+ -+ /// Returns positive infinity value -+ static F8Type infinity() { return F8Type::bitcast(F8Type::FP8_INFINITY_MASK); } -+ -+ /// Returns quiet NaN value -+ static F8Type quiet_NaN() { return F8Type::bitcast(F8Type::FP8_NAN); } -+ -+ /// Returns signaling NaN value -+ static F8Type signaling_NaN() { return F8Type::bitcast(F8Type::FP8_NAN); } -+ -+ /// Returns smallest positive subnormal value -+ static F8Type denorm_min() { return F8Type::bitcast(0x01); } -+}; -+ -+/// std::numeric_limits -+template -+struct numeric_limits; -+ -+/// Numeric limits for float_e4m3_t -+template <> -+struct numeric_limits : -+ public float8_base_numeric_limits { -+ static bool const has_infinity = false; -+ -+ /// Minimum finite value -+ static cutlass::float_e4m3_t lowest() { return cutlass::float_e4m3_t::bitcast(0xfe); } -+ -+ /// Returns smallest finite value -+ static cutlass::float_e4m3_t epsilon() { return cutlass::float_e4m3_t::bitcast(0x20); } -+}; -+ -+/// Numeric limits for float_e5m2_t -+template <> -+struct numeric_limits : -+ public float8_base_numeric_limits { -+ static bool const has_infinity = true; -+ -+ /// Minimum finite value -+ static cutlass::float_e5m2_t lowest() { return cutlass::float_e5m2_t::bitcast(0xfb); } -+ -+ /// Returns smallest finite value -+ static cutlass::float_e5m2_t epsilon() { return cutlass::float_e5m2_t::bitcast(0x34); } -+}; -+ -+} // namespace platform -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// User-defined literals -+// -+ -+CUTLASS_HOST_DEVICE -+cutlass::float_e4m3_t operator "" _fe4m3(long double x) { -+ return cutlass::float_e4m3_t(float(x)); -+} -+ -+CUTLASS_HOST_DEVICE -+cutlass::float_e4m3_t operator "" _fe4m3(unsigned long long int x) { -+ return cutlass::float_e4m3_t(int(x)); -+} -+ -+CUTLASS_HOST_DEVICE -+cutlass::float_e5m2_t operator "" _fe5m2(long double x) { -+ return cutlass::float_e5m2_t(float(x)); -+} -+ -+CUTLASS_HOST_DEVICE -+cutlass::float_e5m2_t operator "" _fe5m2(unsigned long long int x) { -+ return cutlass::float_e5m2_t(int(x)); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/floating_point_nvrtc.h b/3rdparty/cutlass/include/cutlass/floating_point_nvrtc.h -new file mode 100644 -index 0000000..99deff5 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/floating_point_nvrtc.h -@@ -0,0 +1,65 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief Defines categories for floating point numbers for use in NVRTC-compiled code -+*/ -+ -+#pragma once -+ -+namespace cutlass { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// All floating-point numbers can be put in one of these categories. -+enum { -+ FP_NAN = -+# define FP_NAN 0 -+ FP_NAN, -+ FP_INFINITE = -+# define FP_INFINITE 1 -+ FP_INFINITE, -+ FP_ZERO = -+# define FP_ZERO 2 -+ FP_ZERO, -+ FP_SUBNORMAL = -+# define FP_SUBNORMAL 3 -+ FP_SUBNORMAL, -+ FP_NORMAL = -+# define FP_NORMAL 4 -+ FP_NORMAL -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/functional.h b/3rdparty/cutlass/include/cutlass/functional.h -new file mode 100644 -index 0000000..277bad5 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/functional.h -@@ -0,0 +1,490 @@ -+ /*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Define basic numeric operators -+ -+ This is inspired by the Standard Library's header. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/half.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+#include -+#endif // defined(CUTLASS_ARCH_WMMA_ENABLED) -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct absolute_value_op { -+ CUTLASS_HOST_DEVICE -+ T operator()(T lhs) const { -+ return abs(lhs); -+ } -+}; -+ -+template <> -+struct absolute_value_op { -+ CUTLASS_HOST_DEVICE -+ float operator()(float lhs) const { return fabs(lhs); } -+}; -+ -+template -+struct plus { -+ CUTLASS_HOST_DEVICE -+ T operator()(T lhs, T const &rhs) const { -+ lhs += rhs; -+ return lhs; -+ } -+}; -+ -+template -+struct minus { -+ CUTLASS_HOST_DEVICE -+ T operator()(T lhs, T const &rhs) const { -+ lhs -= rhs; -+ return lhs; -+ } -+}; -+ -+template -+struct multiplies { -+ CUTLASS_HOST_DEVICE -+ T operator()(T lhs, T const &rhs) const { -+ lhs *= rhs; -+ return lhs; -+ } -+}; -+ -+// Maximum with nan propogation -+// To propgate the NANs, the "max" of a two element that contains NaNs should also return a NaN -+template -+struct maximum_with_nan_propogation { -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &lhs, T const &rhs) const { -+ return lhs > rhs or std::isnan(lhs) ? lhs : rhs; -+ } -+}; -+ -+template <> -+struct maximum_with_nan_propogation { -+ CUTLASS_HOST_DEVICE -+ float operator()(float const lhs, float const rhs) const { -+ float res; -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ asm volatile("max.NaN.f32 %0, %1, %2;\n" : "=f"(res) : "f"(lhs), "f"(rhs)); -+#else -+ res = lhs > rhs or std::isnan(lhs) ? lhs : rhs; -+#endif -+ return res; -+ } -+}; -+ -+/// Squares with optional conversion -+template -+struct square { -+ CUTLASS_HOST_DEVICE -+ Output operator()(T lhs) const { -+ multiplies mul_op; -+ -+ Output y = Output(lhs); -+ return mul_op(y, y); -+ } -+}; -+ -+/// Returns the magnitude squared of an element. -+template -+struct magnitude_squared { -+ CUTLASS_HOST_DEVICE -+ Output operator()(T lhs) const { -+ multiplies mul_op; -+ -+ Output y = Output(lhs); -+ return mul_op(y, y); -+ } -+}; -+ -+/// Computes the square of a difference with optional conversion -+template -+struct square_difference { -+ CUTLASS_HOST_DEVICE -+ Output operator()(T lhs, T rhs) const { -+ multiplies mul_op; -+ -+ Output y = Output(lhs) - Output(rhs); -+ return mul_op(y, y); -+ } -+}; -+ -+/// Computes the square of a difference with optional conversion -+template -+struct magnitude_squared_difference { -+ CUTLASS_HOST_DEVICE -+ Output operator()(T lhs, T rhs) const { -+ multiplies mul_op; -+ -+ Output y = Output(lhs) - Output(rhs); -+ return mul_op(y, y); -+ } -+}; -+ -+/// Divides -+template -+struct divides { -+ CUTLASS_HOST_DEVICE -+ T operator()(T lhs, T const &rhs) const { -+ lhs /= rhs; -+ return lhs; -+ } -+}; -+ -+/// Negate -+template -+struct negate { -+ CUTLASS_HOST_DEVICE -+ T operator()(T lhs) const { -+ return -lhs; -+ } -+}; -+ -+/// Greater equal -+template -+struct greater_equal { -+ CUTLASS_HOST_DEVICE -+ bool operator()(T const &lhs, T const &rhs) const { -+ return (lhs >= rhs); -+ } -+}; -+ -+/// Greater -+template -+struct greater { -+ CUTLASS_HOST_DEVICE -+ bool operator()(T const &lhs, T const &rhs) const { -+ return (lhs > rhs); -+ } -+}; -+ -+/// Less equal -+template -+struct less_equal { -+ CUTLASS_HOST_DEVICE -+ bool operator()(T const &lhs, T const &rhs) const { -+ return (lhs <= rhs); -+ } -+}; -+ -+/// Less -+template -+struct less { -+ CUTLASS_HOST_DEVICE -+ bool operator()(T const &lhs, T const &rhs) const { -+ return (lhs < rhs); -+ } -+}; -+ -+template -+struct maximum { -+ -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &lhs, T const &rhs) const { -+ return (lhs < rhs ? rhs : lhs); -+ } -+}; -+ -+template <> -+struct maximum { -+ CUTLASS_HOST_DEVICE -+ float operator()(float const &lhs, float const &rhs) const { -+ return fmaxf(lhs, rhs); -+ } -+}; -+ -+template -+struct minimum { -+ -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &lhs, T const &rhs) const { -+ return (rhs < lhs ? rhs : lhs); -+ } -+}; -+ -+template <> -+struct minimum { -+ CUTLASS_HOST_DEVICE -+ float operator()(float const &lhs, float const &rhs) const { -+ return fminf(lhs, rhs); -+ } -+}; -+ -+/// Fused multiply-add -+template -+struct multiply_add { -+ CUTLASS_HOST_DEVICE -+ C operator()(A const &a, B const &b, C const &c) const { -+ return C(a) * C(b) + c; -+ } -+}; -+ -+/// Fused multiply-add -+template -+struct multiply_add_relu0 { -+ CUTLASS_HOST_DEVICE -+ C operator()(A const &a, B const &b, C const &c) const { -+ maximum mx; -+ return mx(C(a) * C(b) + c, C(0)); -+ } -+}; -+ -+/// Fused multiply-add -+template -+struct and_add { -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &a, T const &b, T const &c) const { -+ return ((a & b) + c); -+ } -+}; -+ -+ -+/// Fused multiply-add -+template -+struct xor_add { -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &a, T const &b, T const &c) const { -+ return ((a ^ b) + c); -+ } -+}; -+ -+template -+struct conjugate { -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &a) const { -+ return a; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct logical_and { -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &a, T const &b) const { -+ return ((a && b) ? T(1) : T()); -+ } -+}; -+ -+template -+struct logical_or { -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &a, T const &b) const { -+ return ((a || b) ? T(1) : T()); -+ } -+}; -+ -+template -+struct logical_not { -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &a) const { -+ return T(!(a)); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct bit_and { -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &a, T const &b) const { -+ return a & b; -+ } -+}; -+ -+template -+struct bit_or { -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &a, T const &b) const { -+ return a | b; -+ } -+}; -+ -+template -+struct bit_not { -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &a) const { -+ return ~a; -+ } -+}; -+ -+template -+struct bit_xor { -+ CUTLASS_HOST_DEVICE -+ T operator()(T const &a, T const &b) const { -+ return a ^ b; -+ } -+}; -+ -+ -+ -+////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Reduces value into the data pointed to by ptr -+template -+struct red -+{ -+ CUTLASS_DEVICE -+ void operator()(T *ptr, const T &data) -+ { -+ atomicAdd(ptr, data); -+ } -+}; -+ -+ -+/// Reduces value into the data pointed to by ptr (double specialization) -+template<> -+struct red -+{ -+ CUTLASS_DEVICE -+ void operator()(double *ptr, const double &data) -+ { -+#if !defined(__CUDA_ARCH__) -+ CUTLASS_UNUSED(ptr); -+ CUTLASS_UNUSED(data); -+#elif (__CUDA_ARCH__ >= 600) -+ -+ atomicAdd(ptr, data); -+ -+#else -+ -+ // Use CAS loop -+ unsigned long long int* ptr_int = reinterpret_cast(ptr); -+ unsigned long long int old_int = *ptr_int; -+ unsigned long long int assumed_int; -+ -+ do { -+ double update = data + __longlong_as_double(old_int); -+ assumed_int = old_int; -+ old_int = atomicCAS(ptr_int, assumed_int, __double_as_longlong(update)); -+ } while (assumed_int != old_int); -+ -+#endif // (__CUDA_ARCH__ >= 600) -+ } -+}; -+ -+ -+/// Reduces value into the data pointed to by ptr (half2 specialization) -+template<> -+struct red -+{ -+ CUTLASS_DEVICE -+ void operator()(half2 *ptr, const half2 &data) -+ { -+#if !defined(__CUDA_ARCH__) -+ CUTLASS_UNUSED(ptr); -+ CUTLASS_UNUSED(data); -+#elif (__CUDA_ARCH__ >= 600) -+ -+ // Vector-2 atomic reduction requires .target sm_60 or higher -+ uint32_t word = reinterpret_cast(data); -+ asm volatile ("red.gpu.global.add.noftz.f16x2 [%0], %1;\n" : : "l"(ptr), "r"(word)); -+ -+#else -+ -+ // Use CAS loop -+ uint32_t *ptr_int = reinterpret_cast(ptr); -+ uint32_t old_int = *ptr_int; -+ uint32_t assumed_int; -+ -+ do -+ { -+ half2 old = reinterpret_cast(old_int); -+ -+ half hi = __hadd(__high2half(old), __high2half(data)); -+ half lo = __hadd(__low2half(old), __low2half(data)); -+ half2 update = __halves2half2(hi, lo); -+ uint32_t update_int = reinterpret_cast(update); -+ -+ assumed_int = old_int; -+ old_int = atomicCAS(ptr_int, assumed_int, update_int); -+ -+ } while (assumed_int != old_int); -+ -+#endif // (__CUDA_ARCH__ >= 600) -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Partial specializations for nvcuda::wmma::fragment -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+ -+template -+struct plus> -+{ -+ using Fragment = nvcuda::wmma::fragment; -+ using ElementType = typename Fragment::element_type; -+ -+ CUTLASS_HOST_DEVICE -+ Fragment operator()(Fragment const &lhs, Fragment const &rhs) const -+ { -+ Fragment result; -+ plus scalar_op; -+ -+ ElementType *result_elts = reinterpret_cast(&result); -+ const ElementType *lhs_elts = reinterpret_cast(&lhs); -+ const ElementType *rhs_elts = reinterpret_cast(&rhs); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Fragment::num_elements; i++) { -+ result_elts[i] = scalar_op(lhs_elts[i], rhs_elts[i]); -+ } -+ -+ return result; -+ } -+}; -+ -+#endif // defined(CUTLASS_ARCH_WMMA_ENABLED) -+ -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/collective/collective_builder.hpp b/3rdparty/cutlass/include/cutlass/gemm/collective/collective_builder.hpp -new file mode 100644 -index 0000000..3cd68a4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/collective/collective_builder.hpp -@@ -0,0 +1,78 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#include "collective_mma.hpp" -+ -+namespace cutlass::gemm::collective { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Used to specify stage counts or dispatch to automatic computation of stage count -+template -+struct StageCount { static constexpr int value = num_stages; }; -+struct StageCountAuto {}; -+ -+// Used to automatically let the builder pick the kernel schedule. -+// Can be overridden with kernel schedule tags in cutlass/gemm/dispatch_policy.hpp -+struct KernelScheduleAuto {}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ class ArchTag, -+ class OpClass, -+ class ElementA, -+ class GmemLayoutA, -+ int AlignmentA, -+ class ElementB, -+ class GmemLayoutB, -+ int AlignmentB, -+ class ElementAccumulator, -+ class TileShape_MNK, -+ class ClusterShape_MNK, -+ class StageCountType, -+ class KernelScheduleType, -+ class Enable = void -+> -+struct CollectiveBuilder { -+ static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters."); -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass::gemm::collective -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include "builders/sm90_gmma_builder.inl" -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/collective/collective_mma.hpp b/3rdparty/cutlass/include/cutlass/gemm/collective/collective_mma.hpp -new file mode 100644 -index 0000000..a2a9067 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/collective/collective_mma.hpp -@@ -0,0 +1,71 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass::gemm::collective { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ class DispatchPolicy, -+ class TileShape, -+ class ElementA, -+ class StrideA, -+ class ElementB, -+ class StrideB, -+ class TiledMma, -+ class GmemTiledCopyA, -+ class SmemLayoutAtomA, -+ class SmemCopyAtomA, -+ class TransformA, -+ class GmemTiledCopyB, -+ class SmemLayoutAtomB, -+ class SmemCopyAtomB, -+ class TransformB -+> -+struct CollectiveMma { -+ static_assert(sizeof(ElementA) == 0, "Could not find a mainloop specialization."); -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass::gemm::collective -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include "sm70_mma_twostage.hpp" -+#include "sm80_mma_multistage.hpp" -+#include "sm90_mma_multistage_gmma_ss.hpp" -+#include "sm90_mma_tma_gmma_ss.hpp" -+#include "sm90_mma_tma_gmma_ss_warpspecialized.hpp" -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/collective/sm70_mma_twostage.hpp b/3rdparty/cutlass/include/cutlass/gemm/collective/sm70_mma_twostage.hpp -new file mode 100644 -index 0000000..11e5515 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/collective/sm70_mma_twostage.hpp -@@ -0,0 +1,588 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/dispatch_policy.hpp" -+ -+#include "cute/algorithm/functional.hpp" -+#include "cute/atom/mma_atom.hpp" -+#include "cute/algorithm/gemm.hpp" -+#include "cute/atom/mma_atom.hpp" -+#include "cute/tensor_predicate.hpp" -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass::gemm::collective { -+using namespace cute; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ class TileShape_, -+ class ElementA_, -+ class StrideA_, -+ class ElementB_, -+ class StrideB_, -+ class TiledMma_, -+ class GmemTiledCopyA_, -+ class SmemLayoutAtomA_, -+ class SmemCopyAtomA_, -+ class TransformA_, -+ class GmemTiledCopyB_, -+ class SmemLayoutAtomB_, -+ class SmemCopyAtomB_, -+ class TransformB_> -+struct CollectiveMma< -+ MainloopSm70TwoStageUnpredicated, -+ TileShape_, -+ ElementA_, -+ StrideA_, -+ ElementB_, -+ StrideB_, -+ TiledMma_, -+ GmemTiledCopyA_, -+ SmemLayoutAtomA_, -+ SmemCopyAtomA_, -+ TransformA_, -+ GmemTiledCopyB_, -+ SmemLayoutAtomB_, -+ SmemCopyAtomB_, -+ TransformB_> -+{ -+ // -+ // Type Aliases -+ // -+ using DispatchPolicy = MainloopSm70TwoStageUnpredicated; -+ using TileShape = TileShape_; -+ using ElementA = ElementA_; -+ using StrideA = StrideA_; -+ using ElementB = ElementB_; -+ using StrideB = StrideB_; -+ using TiledMma = TiledMma_; -+ using ElementAccumulator = typename TiledMma::ValTypeC; -+ using GmemTiledCopyA = GmemTiledCopyA_; -+ using GmemTiledCopyB = GmemTiledCopyB_; -+ using SmemLayoutAtomA = SmemLayoutAtomA_; -+ using SmemLayoutAtomB = SmemLayoutAtomB_; -+ using SmemCopyAtomA = SmemCopyAtomA_; -+ using SmemCopyAtomB = SmemCopyAtomB_; -+ using TransformA = TransformA_; -+ using TransformB = TransformB_; -+ using ArchTag = typename DispatchPolicy::ArchTag; -+ -+ static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); -+ static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ -+ static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); -+ static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ -+ using SmemLayoutA = decltype(tile_to_shape( -+ SmemLayoutAtomA{}, -+ make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})))); -+ using SmemLayoutB = decltype(tile_to_shape( -+ SmemLayoutAtomB{}, -+ make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})))); -+ -+ struct SharedStorage -+ { -+ cute::array_aligned> smem_a; -+ cute::array_aligned> smem_b; -+ }; -+ -+ struct Params { -+ ElementA const* ptr_A; -+ StrideA dA; -+ ElementB const* ptr_B; -+ StrideB dB; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CollectiveMma() = default; -+ -+ template -+ static constexpr Params -+ to_underlying_arguments(Args const& args, void* workspace) { -+ (void) workspace; -+ return {args.ptr_A, args.dA, args.ptr_B, args.dB}; -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ template < -+ class FrgTensorD, -+ class TensorA, -+ class TensorB, -+ class FrgTensorC, -+ class KTileIterator, -+ class ResidueMNK -+ > -+ CUTLASS_DEVICE void -+ operator() ( -+ FrgTensorD &accum, -+ TensorA gA, -+ TensorB gB, -+ FrgTensorC const &src_accum, -+ KTileIterator k_tile_iter, int k_tile_count, -+ ResidueMNK residue_mnk, -+ int thread_idx, -+ char *smem_buf) -+ { -+ using namespace cute; -+ -+ (void)residue_mnk; -+ -+ static_assert(is_rmem::value, "D tensor must be rmem resident."); -+ static_assert(is_gmem::value, "A tensor must be gmem resident."); -+ static_assert(is_gmem::value, "B tensor must be gmem resident."); -+ static_assert(is_rmem::value, "C tensor must be rmem resident."); -+ static_assert(rank(SmemLayoutA{}) == 2, -+ "MainloopTwoStage must not have a smem shape with a pipeline mode."); -+ static_assert(rank(SmemLayoutB{}) == 2, -+ "MainloopTwoStage must not have a smem shape with a pipeline mode."); -+ -+ // Construct shared memory tiles -+ SharedStorage& storage = *reinterpret_cast(smem_buf); -+ Tensor sA = make_tensor(make_smem_ptr(storage.smem_a.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) -+ Tensor sB = make_tensor(make_smem_ptr(storage.smem_b.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) -+ -+ // Partition the copying of A and B tiles across the threads -+ GmemTiledCopyA gmem_tiled_copy_a; -+ GmemTiledCopyB gmem_tiled_copy_b; -+ auto copy_a_thr = gmem_tiled_copy_a.get_slice(thread_idx); -+ auto copy_b_thr = gmem_tiled_copy_b.get_slice(thread_idx); -+ -+ Tensor tAgA = copy_a_thr.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) -+ Tensor tAsA = copy_a_thr.partition_D(sA); // (ACPY,ACPY_M,ACPY_K) -+ Tensor tBgB = copy_b_thr.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) -+ Tensor tBsB = copy_b_thr.partition_D(sB); // (BCPY,BCPY_N,BCPY_K) -+ -+ // Allocate the register tiles for double buffering -- same shape as partitioned data -+ Tensor tArA = make_fragment_like(tAsA); // (ACPY,ACPY_M,ACPY_K) -+ Tensor tBrB = make_fragment_like(tBsB); // (BCPY,BCPY_N,BCPY_K) -+ -+ // Tile MMA compute thread partitions and allocate accumulators -+ TiledMma tiled_mma; -+ auto thr_mma = tiled_mma.get_thread_slice(thread_idx); -+ Tensor tCrA = thr_mma.partition_fragment_A(sA); // (MMA,MMA_M,MMA_K) -+ Tensor tCrB = thr_mma.partition_fragment_B(sB); // (MMA,MMA_M,MMA_K) -+ -+ CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(accum)); // MMA_M -+ CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(src_accum)); // MMA_M -+ CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(accum)); // MMA_N -+ CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(src_accum)); // MMA_N -+ CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K -+ -+ // -+ // Copy Atom retiling -+ // -+ -+ auto thr_copy_A = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma).get_thread_slice(thread_idx); -+ Tensor tCsA = thr_copy_A.partition_S(sA); -+ Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA); -+ CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M -+ -+ auto thr_copy_B = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma).get_thread_slice(thread_idx); -+ Tensor tCsB = thr_copy_B.partition_S(sB); -+ Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB); -+ CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N -+ -+ // -+ // Prologue -+ // -+ -+ // Copy gmem to rmem for the first k_tile -+ copy(gmem_tiled_copy_a, tAgA(_,_,_,*k_tile_iter), tArA); -+ copy(gmem_tiled_copy_b, tBgB(_,_,_,*k_tile_iter), tBrB); -+ if (--k_tile_count > 0) ++k_tile_iter; -+ // Copy rmem to smem -+ copy(tArA, tAsA); -+ copy(tBrB, tBsB); -+ // Clear accumulators -+ __syncthreads(); -+ -+ // Load A, B smem->rmem for k=0 -+ copy(tCsA(_,_,0), tCrA_copy_view(_,_,0)); -+ copy(tCsB(_,_,0), tCrB_copy_view(_,_,0)); -+ // -+ // Mainloop -+ // -+ -+ // Size of the k-tiles's outer product mode (k) -+ auto K_BLOCK_MAX = size<2>(tCrA); -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ while (k_tile_count > -1) -+ { -+ // Pipeline the outer products with a static for loop -+ for_each(make_int_sequence{}, [&] (auto k_block) -+ { -+ if (k_block == K_BLOCK_MAX - 1) -+ { -+ __syncthreads(); -+ -+ // Copy rmem to smem -+ copy(tArA, tAsA); -+ copy(tBrB, tBsB); -+ __syncthreads(); -+ } -+ -+ // Load A, B smem->rmem for k+1 -+ int k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX; // static -+ copy(tCsA(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next)); -+ copy(tCsB(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next)); -+ if (k_block == 0) -+ { -+ // Copy gmem to rmem -+ copy(gmem_tiled_copy_a, tAgA(_,_,_,*k_tile_iter), tArA); -+ copy(gmem_tiled_copy_b, tBgB(_,_,_,*k_tile_iter), tBrB); -+ if (--k_tile_count > 0) ++k_tile_iter; -+ } -+ -+ // transform before compute -+ cute::transform(tCrA(_,_,k_block), TransformA{}); -+ cute::transform(tCrB(_,_,k_block), TransformB{}); -+ -+ // Thread-level register gemm for k -+ // disambiguate gemm (shared with the namespace name) -+ cute::gemm(tiled_mma, accum, tCrA(_,_,k_block), tCrB(_,_,k_block), src_accum); -+ }); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ class TileShape_, -+ class ElementA_, -+ class StrideA_, -+ class ElementB_, -+ class StrideB_, -+ class TiledMma_, -+ class GmemTiledCopyA_, -+ class SmemLayoutAtomA_, -+ class SmemCopyAtomA_, -+ class TransformA_, -+ class GmemTiledCopyB_, -+ class SmemLayoutAtomB_, -+ class SmemCopyAtomB_, -+ class TransformB_> -+struct CollectiveMma< -+ MainloopSm70TwoStage, -+ TileShape_, -+ ElementA_, -+ StrideA_, -+ ElementB_, -+ StrideB_, -+ TiledMma_, -+ GmemTiledCopyA_, -+ SmemLayoutAtomA_, -+ SmemCopyAtomA_, -+ TransformA_, -+ GmemTiledCopyB_, -+ SmemLayoutAtomB_, -+ SmemCopyAtomB_, -+ TransformB_> -+{ -+ // -+ // Type Aliases -+ // -+ using DispatchPolicy = MainloopSm70TwoStage; -+ using TileShape = TileShape_; -+ using ElementA = ElementA_; -+ using StrideA = StrideA_; -+ using ElementB = ElementB_; -+ using StrideB = StrideB_; -+ using TiledMma = TiledMma_; -+ using ElementAccumulator = typename TiledMma::ValTypeC; -+ using GmemTiledCopyA = GmemTiledCopyA_; -+ using GmemTiledCopyB = GmemTiledCopyB_; -+ using SmemLayoutAtomA = SmemLayoutAtomA_; -+ using SmemLayoutAtomB = SmemLayoutAtomB_; -+ using SmemCopyAtomA = SmemCopyAtomA_; -+ using SmemCopyAtomB = SmemCopyAtomB_; -+ using TransformA = TransformA_; -+ using TransformB = TransformB_; -+ using ArchTag = typename DispatchPolicy::ArchTag; -+ -+ static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); -+ static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ -+ static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); -+ static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ -+ using SmemLayoutA = decltype(tile_to_shape( -+ SmemLayoutAtomA{}, -+ make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})))); -+ using SmemLayoutB = decltype(tile_to_shape( -+ SmemLayoutAtomB{}, -+ make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})))); -+ -+ struct SharedStorage -+ { -+ cute::array_aligned> smem_a; -+ cute::array_aligned> smem_b; -+ }; -+ -+ struct Params { -+ ElementA const* ptr_A; -+ StrideA dA; -+ ElementB const* ptr_B; -+ StrideB dB; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CollectiveMma() = default; -+ -+ template -+ static constexpr Params -+ to_underlying_arguments(Args const& args, void* workspace) { -+ (void) workspace; -+ return {args.ptr_A, args.dA, args.ptr_B, args.dB}; -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ template < -+ class FrgTensorD, -+ class TensorA, -+ class TensorB, -+ class FrgTensorC, -+ class KTileIterator, -+ class ResidueMNK -+ > -+ CUTLASS_DEVICE void -+ operator() ( -+ FrgTensorD &accum, -+ TensorA gA, -+ TensorB gB, -+ FrgTensorC const &src_accum, -+ KTileIterator k_tile_iter, int k_tile_count, -+ ResidueMNK residue_mnk, -+ int thread_idx, -+ char *smem_buf) -+ { -+ using namespace cute; -+ -+ static_assert(is_rmem::value, "D tensor must be rmem resident."); -+ static_assert(is_gmem::value, "A tensor must be gmem resident."); -+ static_assert(is_gmem::value, "B tensor must be gmem resident."); -+ static_assert(is_rmem::value, "C tensor must be rmem resident."); -+ static_assert(rank(SmemLayoutA{}) == 2, -+ "MainloopTwoStage must not have a smem shape with a pipeline mode."); -+ static_assert(rank(SmemLayoutB{}) == 2, -+ "MainloopTwoStage must not have a smem shape with a pipeline mode."); -+ -+ // Construct shared memory tiles -+ SharedStorage& storage = *reinterpret_cast(smem_buf); -+ Tensor sA = make_tensor(make_smem_ptr(storage.smem_a.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) -+ Tensor sB = make_tensor(make_smem_ptr(storage.smem_b.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) -+ -+ // Shift tensor so residue_k is at origin (Can't read any k_coord < residue_k) -+ // This aligns the tensor with BLK_K for all but the 0th k_tile -+ gA.data() = &gA(0, get<2>(residue_mnk), 0); -+ gB.data() = &gB(0, get<2>(residue_mnk), 0); -+ -+ // Partition the copying of A and B tiles across the threads -+ GmemTiledCopyA gmem_tiled_copy_a; -+ GmemTiledCopyB gmem_tiled_copy_b; -+ auto gmem_thr_copy_a = gmem_tiled_copy_a.get_slice(thread_idx); -+ auto gmem_thr_copy_b = gmem_tiled_copy_b.get_slice(thread_idx); -+ -+ Tensor tAgA = gmem_thr_copy_a.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) -+ Tensor tAsA = gmem_thr_copy_a.partition_D(sA); // (ACPY,ACPY_M,ACPY_K,PIPE) -+ Tensor tBgB = gmem_thr_copy_b.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) -+ Tensor tBsB = gmem_thr_copy_b.partition_D(sB); // (BCPY,BCPY_N,BCPY_K,PIPE) -+ -+ // Allocate the register tiles for double buffering -- same shape as partitioned data -+ Tensor tArA = make_fragment_like(tAsA); // (ACPY,ACPY_M,ACPY_K) -+ Tensor tBrB = make_fragment_like(tBsB); // (BCPY,BCPY_N,BCPY_K) -+ -+ // -+ // PREDICATES -+ // -+ -+ // Allocate predicate tensors for m and n -+ Tensor tApA = make_tensor(make_shape(size<1>(tAsA), size<2>(tAsA)), Stride<_1,_0>{}); -+ Tensor tBpB = make_tensor(make_shape(size<1>(tBsB), size<2>(tBsB)), Stride<_1,_0>{}); -+ -+ // Construct identity layout for sA and sB -+ Tensor cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); // (BLK_M,BLK_K) -> (blk_m,blk_k) -+ Tensor cB = make_identity_tensor(make_shape(size<0>(sB), size<1>(sB))); // (BLK_N,BLK_K) -> (blk_n,blk_k) -+ -+ // Repeat the partitioning with identity layouts -+ Tensor tAcA = gmem_thr_copy_a.partition_S(cA); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) -+ Tensor tBcB = gmem_thr_copy_b.partition_S(cB); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) -+ -+ // Set predicates for m bounds -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < size<0>(tApA); ++m) { -+ tApA(m,0) = get<0>(tAcA(0,m,0)) < get<0>(residue_mnk); // blk_m coord < residue_m -+ } -+ // Set predicates for n bounds -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < size<0>(tBpB); ++n) { -+ tBpB(n,0) = get<0>(tBcB(0,n,0)) < get<1>(residue_mnk); // blk_n coord < residue_n -+ } -+ -+ // -+ // PREFETCH -+ // -+ -+ // Clear the rmem tiles to account for predicated off loads -+ clear(tArA); -+ clear(tBrB); -+ -+ // Start async loads for 0th k-tile, where we take care of the k residue -+ { -+ Tensor tAgAk = tAgA(_,_,_,*k_tile_iter); -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < size<2>(tArA); ++k) { -+ if (get<1>(tAcA(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < residue_k (gA shifted) -+ copy_if(gmem_tiled_copy_a, tApA(_,k), tAgAk(_,_,k), tArA(_,_,k)); -+ } -+ } -+ Tensor tBgBk = tBgB(_,_,_,*k_tile_iter); -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < size<2>(tBrB); ++k) { -+ if (get<1>(tBcB(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < residue_k (gB shifted) -+ copy_if(gmem_tiled_copy_b, tBpB(_,k), tBgBk(_,_,k), tBrB(_,_,k)); -+ } -+ } -+ ++k_tile_iter; -+ --k_tile_count; -+ } -+ -+ // Tile MMA compute thread partitions and allocate accumulators -+ TiledMma tiled_mma; -+ auto thr_mma = tiled_mma.get_thread_slice(thread_idx); -+ Tensor tCrA = thr_mma.make_fragment_A(thr_mma.partition_A(sA)); // (MMA,MMA_M,MMA_K) -+ Tensor tCrB = thr_mma.make_fragment_B(thr_mma.partition_B(sB)); // (MMA,MMA_M,MMA_K) -+ -+ CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(accum)); // MMA_M -+ CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(src_accum)); // MMA_M -+ CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(accum)); // MMA_N -+ CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(src_accum)); // MMA_N -+ CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K -+ -+ // -+ // Copy Atom retiling -+ // -+ -+ auto thr_copy_A = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma).get_thread_slice(thread_idx); -+ Tensor tCsA = thr_copy_A.partition_S(sA); -+ Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA); -+ CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M -+ -+ auto thr_copy_B = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma).get_thread_slice(thread_idx); -+ Tensor tCsB = thr_copy_B.partition_S(sB); -+ Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB); -+ CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N -+ -+ // -+ // Prologue -+ // -+ -+ // Copy rmem to smem -+ copy(tArA, tAsA); -+ copy(tBrB, tBsB); -+ // Clear accumulators -+ __syncthreads(); -+ -+ // Load A, B smem->rmem for k=0 -+ copy(tCsA(_,_,0), tCrA_copy_view(_,_,0)); -+ copy(tCsB(_,_,0), tCrB_copy_view(_,_,0)); -+ // -+ // Mainloop -+ // -+ -+ // Size of the k-tiles's outer product mode (k) -+ auto K_BLOCK_MAX = size<2>(tCrA); -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ while (k_tile_count > -1) -+ { -+ // Pipeline the outer products with a static for loop -+ for_each(make_int_sequence{}, [&] (auto k_block) -+ { -+ if (k_block == K_BLOCK_MAX - 1) -+ { -+ __syncthreads(); -+ -+ // Copy rmem to smem -+ copy(tArA, tAsA); -+ copy(tBrB, tBsB); -+ __syncthreads(); -+ } -+ -+ // Load A, B smem->rmem for k+1 -+ int k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX; // static -+ copy(tCsA(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next)); -+ copy(tCsB(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next)); -+ if (k_block == 0) -+ { -+ if (k_tile_count <= 0) { -+ clear(tApA); -+ clear(tBpB); -+ } -+ copy_if(gmem_tiled_copy_a, tApA, tAgA(_,_,_,*k_tile_iter), tArA); -+ copy_if(gmem_tiled_copy_b, tBpB, tBgB(_,_,_,*k_tile_iter), tBrB); -+ ++k_tile_iter; -+ --k_tile_count; -+ } -+ -+ // transform before compute -+ cute::transform(tCrA(_,_,k_block), TransformA{}); -+ cute::transform(tCrB(_,_,k_block), TransformB{}); -+ -+ // Thread-level register gemm for k -+ // disambiguate gemm (shared with the namespace name) -+ cute::gemm(tiled_mma, accum, tCrA(_,_,k_block), tCrB(_,_,k_block), src_accum); -+ }); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass::gemm::collective -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/collective/sm80_mma_multistage.hpp b/3rdparty/cutlass/include/cutlass/gemm/collective/sm80_mma_multistage.hpp -new file mode 100644 -index 0000000..6ba6ccc ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/collective/sm80_mma_multistage.hpp -@@ -0,0 +1,680 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/dispatch_policy.hpp" -+ -+#include "cute/algorithm/functional.hpp" -+#include "cute/atom/mma_atom.hpp" -+#include "cute/algorithm/gemm.hpp" -+#include "cute/tensor_predicate.hpp" -+#include "cute/numeric/arithmetic_tuple.hpp" -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass::gemm::collective { -+using namespace cute; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ int Stages, -+ class TileShape_, -+ class ElementA_, -+ class StrideA_, -+ class ElementB_, -+ class StrideB_, -+ class TiledMma_, -+ class GmemTiledCopyA_, -+ class SmemLayoutAtomA_, -+ class SmemCopyAtomA_, -+ class TransformA_, -+ class GmemTiledCopyB_, -+ class SmemLayoutAtomB_, -+ class SmemCopyAtomB_, -+ class TransformB_> -+struct CollectiveMma< -+ MainloopSm80CpAsyncUnpredicated, -+ TileShape_, -+ ElementA_, -+ StrideA_, -+ ElementB_, -+ StrideB_, -+ TiledMma_, -+ GmemTiledCopyA_, -+ SmemLayoutAtomA_, -+ SmemCopyAtomA_, -+ TransformA_, -+ GmemTiledCopyB_, -+ SmemLayoutAtomB_, -+ SmemCopyAtomB_, -+ TransformB_> -+{ -+ // -+ // Type Aliases -+ // -+ using DispatchPolicy = MainloopSm80CpAsyncUnpredicated; -+ using TileShape = TileShape_; -+ using ElementA = ElementA_; -+ using StrideA = StrideA_; -+ using ElementB = ElementB_; -+ using StrideB = StrideB_; -+ using TiledMma = TiledMma_; -+ using ElementAccumulator = typename TiledMma::ValTypeC; -+ using GmemTiledCopyA = GmemTiledCopyA_; -+ using GmemTiledCopyB = GmemTiledCopyB_; -+ using SmemLayoutAtomA = SmemLayoutAtomA_; -+ using SmemLayoutAtomB = SmemLayoutAtomB_; -+ using SmemCopyAtomA = SmemCopyAtomA_; -+ using SmemCopyAtomB = SmemCopyAtomB_; -+ using TransformA = TransformA_; -+ using TransformB = TransformB_; -+ using ArchTag = typename DispatchPolicy::ArchTag; -+ -+ static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); -+ static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ -+ static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); -+ static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ -+ using SmemLayoutA = decltype(tile_to_shape( -+ SmemLayoutAtomA{}, -+ make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}))); -+ using SmemLayoutB = decltype(tile_to_shape( -+ SmemLayoutAtomB{}, -+ make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}))); -+ -+ static_assert(DispatchPolicy::Stages >= 2, "CpAsync mainloop must have at least 2 stages in the pipeline."); -+ -+ struct SharedStorage -+ { -+ cute::array_aligned> smem_a; -+ cute::array_aligned> smem_b; -+ }; -+ -+ struct Params { -+ ElementA const* ptr_A; -+ StrideA dA; -+ ElementB const* ptr_B; -+ StrideB dB; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CollectiveMma() = default; -+ -+ template -+ static constexpr Params -+ to_underlying_arguments(Args const& args, void* workspace) { -+ (void) workspace; -+ return {args.ptr_A, args.dA, args.ptr_B, args.dB}; -+ } -+ -+ /// Perform a collective-scoped matrix multiply-accumulate -+ template < -+ class FrgTensorD, -+ class TensorA, -+ class TensorB, -+ class FrgTensorC, -+ class KTileIterator, -+ class ResidueMNK -+ > -+ CUTLASS_DEVICE void -+ operator() ( -+ FrgTensorD &accum, -+ TensorA gA, -+ TensorB gB, -+ FrgTensorC const &src_accum, -+ KTileIterator k_tile_iter, int k_tile_count, -+ ResidueMNK residue_mnk, -+ int thread_idx, -+ char *smem_buf) -+ { -+ using namespace cute; -+ -+ static_assert(is_rmem::value, "D tensor must be rmem resident."); -+ static_assert(is_gmem::value, "A tensor must be gmem resident."); -+ static_assert(is_gmem::value, "B tensor must be gmem resident."); -+ static_assert(is_rmem::value, "C tensor must be rmem resident."); -+ static_assert(rank(SmemLayoutA{}) == 3, -+ "MainloopSm80CpAsync must have a pipeline mode in the smem layout."); -+ static_assert(rank(SmemLayoutB{}) == 3, -+ "MainloopSm80CpAsync must have a pipeline mode in the smem layout."); -+ -+ // Construct shared memory tiles -+ SharedStorage& storage = *reinterpret_cast(smem_buf); -+ Tensor sA = make_tensor(make_smem_ptr(storage.smem_a.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) -+ Tensor sB = make_tensor(make_smem_ptr(storage.smem_b.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) -+ -+ CUTE_STATIC_ASSERT_V(size<0>(gA) == size<0>(sA)); // BLK_M -+ CUTE_STATIC_ASSERT_V(size<1>(gA) == size<1>(sA)); // BLK_K -+ CUTE_STATIC_ASSERT_V(size<0>(gB) == size<0>(sB)); // BLK_N -+ CUTE_STATIC_ASSERT_V(size<1>(gB) == size<1>(sB)); // BLK_K -+ CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // BLK_K -+ CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE -+ CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE -+ -+ // Partition the copying of A and B tiles across the threads -+ GmemTiledCopyA gmem_tiled_copy_A; -+ GmemTiledCopyB gmem_tiled_copy_B; -+ auto gmem_thr_copy_A = gmem_tiled_copy_A.get_slice(thread_idx); -+ auto gmem_thr_copy_B = gmem_tiled_copy_B.get_slice(thread_idx); -+ -+ Tensor tAgA = gmem_thr_copy_A.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) -+ Tensor tAsA = gmem_thr_copy_A.partition_D(sA); // (ACPY,ACPY_M,ACPY_K,PIPE) -+ Tensor tBgB = gmem_thr_copy_B.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) -+ Tensor tBsB = gmem_thr_copy_B.partition_D(sB); // (BCPY,BCPY_N,BCPY_K,PIPE) -+ -+ // -+ // PREDICATES -+ // -+ -+ (void) residue_mnk; -+ //assert(residue_mnk == make_tuple(0,0,0)); -+ -+ // -+ // PREFETCH -+ // -+ -+ // Start async loads for all pipes but the last -+ CUTLASS_PRAGMA_UNROLL -+ for (int k_pipe = 0; k_pipe < DispatchPolicy::Stages-1; ++k_pipe) { -+ copy(gmem_tiled_copy_A, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,k_pipe)); -+ copy(gmem_tiled_copy_B, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,k_pipe)); -+ cp_async_fence(); -+ --k_tile_count; -+ if (k_tile_count > 0) { ++k_tile_iter; } -+ } -+ -+ // -+ // MMA Atom partitioning -+ // -+ -+ // Tile MMA compute thread partitions and allocate accumulators -+ TiledMma tiled_mma; -+ auto thr_mma = tiled_mma.get_thread_slice(thread_idx); -+ Tensor tCrA = thr_mma.partition_fragment_A(sA(_,_,0)); // (MMA,MMA_M,MMA_K) -+ Tensor tCrB = thr_mma.partition_fragment_B(sB(_,_,0)); // (MMA,MMA_N,MMA_K) -+ -+ CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(accum)); // MMA_M -+ CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(src_accum)); // MMA_M -+ CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(accum)); // MMA_N -+ CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(src_accum)); // MMA_N -+ CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K -+ CUTE_STATIC_ASSERT_V(size(gmem_tiled_copy_A) == size(tiled_mma)); -+ CUTE_STATIC_ASSERT_V(size(gmem_tiled_copy_B) == size(tiled_mma)); -+ -+ // -+ // Copy Atom retiling -+ // -+ -+ auto smem_tiled_copy_A = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma); -+ auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); -+ Tensor tCsA = smem_thr_copy_A.partition_S(sA); // (CPY,CPY_M,CPY_K,PIPE) -+ Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); // (CPY,CPY_M,CPY_K) -+ CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // CPY_M -+ CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K -+ -+ auto smem_tiled_copy_B = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma); -+ auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); -+ Tensor tCsB = smem_thr_copy_B.partition_S(sB); // (CPY,CPY_N,CPY_K,PIPE) -+ Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); // (CPY,CPY_N,CPY_K) -+ CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // CPY_N -+ CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCrB_copy_view)); // CPY_K -+ -+ // -+ // PIPELINED MAIN LOOP -+ // -+ -+ // Current pipe index in smem to read from -+ int smem_pipe_read = 0; -+ // Current pipe index in smem to write to -+ int smem_pipe_write = DispatchPolicy::Stages-1; -+ -+ Tensor tCsA_p = tCsA(_,_,_,smem_pipe_read); -+ Tensor tCsB_p = tCsB(_,_,_,smem_pipe_read); -+ -+ // Size of the register pipeline -+ auto K_BLOCK_MAX = size<2>(tCrA); -+ -+ // PREFETCH register pipeline -+ if (K_BLOCK_MAX > 1) { -+ // Wait until our first prefetched tile is loaded in -+ cp_async_wait(); -+ __syncthreads(); -+ -+ // Prefetch the first rmem from the first k-tile -+ copy(smem_tiled_copy_A, tCsA_p(_,_,Int<0>{}), tCrA_copy_view(_,_,Int<0>{})); -+ copy(smem_tiled_copy_B, tCsB_p(_,_,Int<0>{}), tCrB_copy_view(_,_,Int<0>{})); -+ } -+ -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for ( ; k_tile_count > -(DispatchPolicy::Stages-1); --k_tile_count) -+ { -+ // Pipeline the outer products with a static for loop. -+ // -+ // Note, the for_each() function is required here to ensure `k_block` is of type Int. -+ for_each(make_int_sequence{}, [&] (auto k_block) -+ { -+ if (k_block == K_BLOCK_MAX - 1) -+ { -+ // Slice the smem_pipe_read smem -+ tCsA_p = tCsA(_,_,_,smem_pipe_read); -+ tCsB_p = tCsB(_,_,_,smem_pipe_read); -+ -+ // Commit the smem for smem_pipe_read -+ cp_async_wait(); -+ __syncthreads(); -+ } -+ -+ // Load A, B shmem->regs for k_block+1 -+ auto k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX; // static -+ copy(smem_tiled_copy_A, tCsA_p(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next)); -+ copy(smem_tiled_copy_B, tCsB_p(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next)); -+ // Copy gmem to smem before computing gemm on each k-pipe -+ if (k_block == 0) -+ { -+ copy(gmem_tiled_copy_A, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,smem_pipe_write)); -+ copy(gmem_tiled_copy_B, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,smem_pipe_write)); -+ cp_async_fence(); -+ if (k_tile_count > 0) { ++k_tile_iter; } -+ -+ // Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe) -+ smem_pipe_write = smem_pipe_read; -+ ++smem_pipe_read; -+ smem_pipe_read = (smem_pipe_read == DispatchPolicy::Stages) ? 0 : smem_pipe_read; -+ } -+ -+ // Transform before compute -+ cute::transform(tCrA(_,_,k_block), TransformA{}); -+ cute::transform(tCrB(_,_,k_block), TransformB{}); -+ // Thread-level register gemm for k_block -+ cute::gemm(tiled_mma, accum, tCrA(_,_,k_block), tCrB(_,_,k_block), src_accum); -+ }); -+ -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ int Stages, -+ class TileShape_, -+ class ElementA_, -+ class StrideA_, -+ class ElementB_, -+ class StrideB_, -+ class TiledMma_, -+ class GmemTiledCopyA_, -+ class SmemLayoutAtomA_, -+ class SmemCopyAtomA_, -+ class TransformA_, -+ class GmemTiledCopyB_, -+ class SmemLayoutAtomB_, -+ class SmemCopyAtomB_, -+ class TransformB_> -+struct CollectiveMma< -+ MainloopSm80CpAsync, -+ TileShape_, -+ ElementA_, -+ StrideA_, -+ ElementB_, -+ StrideB_, -+ TiledMma_, -+ GmemTiledCopyA_, -+ SmemLayoutAtomA_, -+ SmemCopyAtomA_, -+ TransformA_, -+ GmemTiledCopyB_, -+ SmemLayoutAtomB_, -+ SmemCopyAtomB_, -+ TransformB_> -+{ -+ // -+ // Type Aliases -+ // -+ using DispatchPolicy = MainloopSm80CpAsync; -+ using TileShape = TileShape_; -+ using ElementA = ElementA_; -+ using StrideA = StrideA_; -+ using ElementB = ElementB_; -+ using StrideB = StrideB_; -+ using TiledMma = TiledMma_; -+ using ElementAccumulator = typename TiledMma::ValTypeC; using GmemTiledCopyA = GmemTiledCopyA_; -+ using GmemTiledCopyB = GmemTiledCopyB_; -+ using SmemLayoutAtomA = SmemLayoutAtomA_; -+ using SmemLayoutAtomB = SmemLayoutAtomB_; -+ using SmemCopyAtomA = SmemCopyAtomA_; -+ using SmemCopyAtomB = SmemCopyAtomB_; -+ using TransformA = TransformA_; -+ using TransformB = TransformB_; -+ using ArchTag = typename DispatchPolicy::ArchTag; -+ -+ static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); -+ static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ -+ static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); -+ static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ -+ using SmemLayoutA = decltype(tile_to_shape( -+ SmemLayoutAtomA{}, -+ make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}))); -+ using SmemLayoutB = decltype(tile_to_shape( -+ SmemLayoutAtomB{}, -+ make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}))); -+ -+ static_assert(DispatchPolicy::Stages >= 2, "CpAsync mainloop must have at least 2 stages in the pipeline."); -+ -+ struct SharedStorage -+ { -+ cute::array_aligned> smem_a; -+ cute::array_aligned> smem_b; -+ }; -+ -+ struct Params { -+ ElementA const* ptr_A; -+ StrideA dA; -+ ElementB const* ptr_B; -+ StrideB dB; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CollectiveMma() = default; -+ -+ template -+ static constexpr Params -+ to_underlying_arguments(Args const& args, void* workspace) { -+ (void) workspace; -+ return {args.ptr_A, args.dA, args.ptr_B, args.dB}; -+ } -+ -+ /// Perform a collective-scoped matrix multiply-accumulate -+ template < -+ class FrgTensorD, -+ class TensorA, -+ class TensorB, -+ class FrgTensorC, -+ class KTileIterator, -+ class ResidueMNK -+ > -+ CUTLASS_DEVICE void -+ operator() ( -+ FrgTensorD &accum, -+ TensorA gA, // (BLK_M, BLK_K, K_TILES) -+ TensorB gB, // (BLK_N, BLK_K, K_TILES) -+ FrgTensorC const &src_accum, -+ KTileIterator k_tile_iter, int k_tile_count, -+ ResidueMNK residue_mnk, -+ int thread_idx, -+ char *smem_buf) -+ { -+ using namespace cute; -+ -+ static_assert(is_rmem::value, "D tensor must be rmem resident."); -+ static_assert(is_gmem::value, "A tensor must be gmem resident."); -+ static_assert(is_gmem::value, "B tensor must be gmem resident."); -+ static_assert(is_rmem::value, "C tensor must be rmem resident."); -+ static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); -+ static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); -+ -+ // Construct shared memory tiles -+ SharedStorage& storage = *reinterpret_cast(smem_buf); -+ Tensor sA = make_tensor(make_smem_ptr(storage.smem_a.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) -+ Tensor sB = make_tensor(make_smem_ptr(storage.smem_b.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) -+ -+ CUTE_STATIC_ASSERT_V(size<0>(gA) == size<0>(sA)); // BLK_M -+ CUTE_STATIC_ASSERT_V(size<1>(gA) == size<1>(sA)); // BLK_K -+ CUTE_STATIC_ASSERT_V(size<0>(gB) == size<0>(sB)); // BLK_N -+ CUTE_STATIC_ASSERT_V(size<1>(gB) == size<1>(sB)); // BLK_K -+ CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // BLK_K -+ CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE -+ CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE -+ -+ // Shift tensor so residue_k is at origin (Can't read any k_coord < residue_k) -+ // This aligns the tensor with BLK_K for all but the 0th k_tile -+ gA.data() = &gA(0, get<2>(residue_mnk), 0); -+ gB.data() = &gB(0, get<2>(residue_mnk), 0); -+ -+ // Partition the copying of A and B tiles across the threads -+ GmemTiledCopyA gmem_tiled_copy_A; -+ GmemTiledCopyB gmem_tiled_copy_B; -+ auto gmem_thr_copy_A = gmem_tiled_copy_A.get_slice(thread_idx); -+ auto gmem_thr_copy_B = gmem_tiled_copy_B.get_slice(thread_idx); -+ -+ Tensor tAgA = gmem_thr_copy_A.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) -+ Tensor tAsA = gmem_thr_copy_A.partition_D(sA); // (ACPY,ACPY_M,ACPY_K,PIPE) -+ Tensor tBgB = gmem_thr_copy_B.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) -+ Tensor tBsB = gmem_thr_copy_B.partition_D(sB); // (BCPY,BCPY_N,BCPY_K,PIPE) -+ -+ // -+ // PREDICATES -+ // -+ -+ // Allocate predicate tensors for m and n -+ Tensor tApA = make_tensor(make_shape(size<1>(tAsA), size<2>(tAsA)), Stride<_1,_0>{}); -+ Tensor tBpB = make_tensor(make_shape(size<1>(tBsB), size<2>(tBsB)), Stride<_1,_0>{}); -+ -+ // Construct identity layout for sA and sB -+ Tensor cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); // (BLK_M,BLK_K) -> (blk_m,blk_k) -+ Tensor cB = make_identity_tensor(make_shape(size<0>(sB), size<1>(sB))); // (BLK_N,BLK_K) -> (blk_n,blk_k) -+ -+ // Repeat the partitioning with identity layouts -+ Tensor tAcA = gmem_thr_copy_A.partition_S(cA); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) -+ Tensor tBcB = gmem_thr_copy_B.partition_S(cB); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) -+ -+ // Set predicates for m bounds -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < size<0>(tApA); ++m) { -+ tApA(m,0) = get<0>(tAcA(0,m,0)) < get<0>(residue_mnk); // blk_m coord < residue_m -+ } -+ // Set predicates for n bounds -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < size<0>(tBpB); ++n) { -+ tBpB(n,0) = get<0>(tBcB(0,n,0)) < get<1>(residue_mnk); // blk_n coord < residue_n -+ } -+ -+ // -+ // PREFETCH -+ // -+ -+ // Clear the smem tiles to account for predicated off loads -+ clear(tAsA); -+ clear(tBsB); -+ -+ // Start async loads for 0th k-tile, where we take care of the k residue -+ { -+ constexpr int k_pipe = 0; -+ -+ Tensor tAgAk = tAgA(_,_,_,*k_tile_iter); -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < size<2>(tAsA); ++k) { -+ if (get<1>(tAcA(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < residue_k (gA shifted) -+ copy_if(gmem_tiled_copy_A, tApA(_,k), tAgAk(_,_,k), tAsA(_,_,k,k_pipe)); -+ } -+ } -+ Tensor tBgBk = tBgB(_,_,_,*k_tile_iter); -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < size<2>(tBsB); ++k) { -+ if (get<1>(tBcB(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < residue_k (gB shifted) -+ copy_if(gmem_tiled_copy_B, tBpB(_,k), tBgBk(_,_,k), tBsB(_,_,k,k_pipe)); -+ } -+ } -+ cp_async_fence(); -+ ++k_tile_iter; -+ --k_tile_count; -+ } -+ -+ // Start async loads for 1st k-tile onwards, no k-residue handling needed -+ CUTLASS_PRAGMA_UNROLL -+ for (int k_pipe = 1; k_pipe < DispatchPolicy::Stages-1; ++k_pipe) { -+ if (k_tile_count <= 0) { -+ clear(tApA); -+ clear(tBpB); -+ } -+ copy_if(gmem_tiled_copy_A, tApA, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,k_pipe)); // CpAsync -+ copy_if(gmem_tiled_copy_B, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,k_pipe)); // CpAsync -+ cp_async_fence(); -+ ++k_tile_iter; -+ --k_tile_count; -+ } -+ -+ // -+ // MMA Atom partitioning -+ // -+ -+ // Tile MMA compute thread partitions and allocate accumulators -+ TiledMma tiled_mma; -+ auto thr_mma = tiled_mma.get_thread_slice(thread_idx); -+ Tensor tCrA = thr_mma.partition_fragment_A(sA(_,_,0)); // (MMA,MMA_M,MMA_K) -+ Tensor tCrB = thr_mma.partition_fragment_B(sB(_,_,0)); // (MMA,MMA_N,MMA_K) -+ -+ CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(accum)); // MMA_M -+ CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(src_accum)); // MMA_M -+ CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(accum)); // MMA_N -+ CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(src_accum)); // MMA_N -+ CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K -+ -+ // -+ // Copy Atom retiling -+ // -+ -+ auto smem_tiled_copy_A = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma); -+ auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); -+ Tensor tCsA = smem_thr_copy_A.partition_S(sA); // (CPY,CPY_M,CPY_K,PIPE) -+ Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); // (CPY,CPY_M,CPY_K) -+ CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // CPY_M -+ CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K -+ -+ auto smem_tiled_copy_B = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma); -+ auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); -+ Tensor tCsB = smem_thr_copy_B.partition_S(sB); // (CPY,CPY_N,CPY_K,PIPE) -+ Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); // (CPY,CPY_N,CPY_K) -+ CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // CPY_N -+ CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCrB_copy_view)); // CPY_K -+ -+ // -+ // PIPELINED MAIN LOOP -+ // -+ -+ // Current pipe index in smem to read from -+ int smem_pipe_read = 0; -+ // Current pipe index in smem to write to -+ int smem_pipe_write = DispatchPolicy::Stages-1; -+ -+ Tensor tCsA_p = tCsA(_,_,_,smem_pipe_read); -+ Tensor tCsB_p = tCsB(_,_,_,smem_pipe_read); -+ -+ // Size of the register pipeline -+ auto K_BLOCK_MAX = size<2>(tCrA); -+ -+ // PREFETCH register pipeline -+ if (K_BLOCK_MAX > 1) { -+ // Wait until our first prefetched tile is loaded in -+ cp_async_wait(); -+ __syncthreads(); -+ -+ // Prefetch the first rmem from the first k-tile -+ copy(smem_tiled_copy_A, tCsA_p(_,_,Int<0>{}), tCrA_copy_view(_,_,Int<0>{})); -+ copy(smem_tiled_copy_B, tCsB_p(_,_,Int<0>{}), tCrB_copy_view(_,_,Int<0>{})); -+ } -+ -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for ( ; k_tile_count > -(DispatchPolicy::Stages-1); --k_tile_count) -+ { -+ // Pipeline the outer products with a static for loop. -+ // -+ // Note, the for_each() function is required here to ensure `k_block` is of type Int. -+ for_each(make_int_sequence{}, [&] (auto k_block) -+ { -+ if (k_block == K_BLOCK_MAX - 1) -+ { -+ // Slice the smem_pipe_read smem -+ tCsA_p = tCsA(_,_,_,smem_pipe_read); -+ tCsB_p = tCsB(_,_,_,smem_pipe_read); -+ -+ // Commit the smem for smem_pipe_read -+ cp_async_wait(); -+ __syncthreads(); -+ } -+ -+ // Load A, B shmem->regs for k_block+1 -+ auto k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX; // static -+ copy(smem_tiled_copy_A, tCsA_p(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next)); -+ copy(smem_tiled_copy_B, tCsB_p(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next)); -+ // Copy gmem to smem before computing gemm on each k-pipe -+ if (k_block == 0) -+ { -+ // Set all predicates to false if we are going to overshoot bounds -+ if (k_tile_count <= 0) { -+ clear(tApA); -+ clear(tBpB); -+ } -+ copy_if(gmem_tiled_copy_A, tApA, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,smem_pipe_write)); -+ copy_if(gmem_tiled_copy_B, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,smem_pipe_write)); -+ cp_async_fence(); -+ ++k_tile_iter; -+ -+ // Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe) -+ smem_pipe_write = smem_pipe_read; -+ ++smem_pipe_read; -+ smem_pipe_read = (smem_pipe_read == DispatchPolicy::Stages) ? 0 : smem_pipe_read; -+ } -+ -+ // Transform before compute -+ cute::transform(tCrA(_,_,k_block), TransformA{}); -+ cute::transform(tCrB(_,_,k_block), TransformB{}); -+ // Thread-level register gemm for k_block -+ cute::gemm(tiled_mma, accum, tCrA(_,_,k_block), tCrB(_,_,k_block), src_accum); -+ }); -+ -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass::gemm::collective -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss.hpp b/3rdparty/cutlass/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss.hpp -new file mode 100644 -index 0000000..3b1921b ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss.hpp -@@ -0,0 +1,596 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/dispatch_policy.hpp" -+#include "cutlass/pipeline.hpp" -+#include "cute/arch/cluster_sm90.hpp" -+#include "cutlass/arch/reg_reconfig.h" -+ -+#include "cute/arch/copy_sm90.hpp" -+#include "cute/atom/mma_atom.hpp" -+#include "cute/algorithm/gemm.hpp" -+ -+#include -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass::gemm::collective { -+using namespace cute; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ int Stages, -+ class ClusterShape, -+ class TileShape_, -+ class ElementA_, -+ class StrideA_, -+ class ElementB_, -+ class StrideB_, -+ class TiledMma_, -+ class GmemTiledCopyA_, -+ class SmemLayoutAtomA_, -+ class SmemCopyAtomA_, -+ class TransformA_, -+ class GmemTiledCopyB_, -+ class SmemLayoutAtomB_, -+ class SmemCopyAtomB_, -+ class TransformB_> -+struct CollectiveMma< -+ MainloopSm90CpAsyncGmmaUnpredicated, -+ TileShape_, -+ ElementA_, -+ StrideA_, -+ ElementB_, -+ StrideB_, -+ TiledMma_, -+ GmemTiledCopyA_, -+ SmemLayoutAtomA_, -+ SmemCopyAtomA_, -+ TransformA_, -+ GmemTiledCopyB_, -+ SmemLayoutAtomB_, -+ SmemCopyAtomB_, -+ TransformB_> -+{ -+ // -+ // Type Aliases -+ // -+ using DispatchPolicy = MainloopSm90CpAsyncGmmaUnpredicated; -+ using TileShape = TileShape_; -+ using ElementA = ElementA_; -+ using StrideA = StrideA_; -+ using ElementB = ElementB_; -+ using StrideB = StrideB_; -+ using TiledMma = TiledMma_; -+ using ElementAccumulator = typename TiledMma::ValTypeC; -+ using GmemTiledCopyA = GmemTiledCopyA_; -+ using GmemTiledCopyB = GmemTiledCopyB_; -+ using SmemLayoutAtomA = SmemLayoutAtomA_; -+ using SmemLayoutAtomB = SmemLayoutAtomB_; -+ using SmemCopyAtomA = SmemCopyAtomA_; -+ using SmemCopyAtomB = SmemCopyAtomB_; -+ using TransformA = TransformA_; -+ using TransformB = TransformB_; -+ using ArchTag = typename DispatchPolicy::ArchTag; -+ -+ static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); -+ static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ -+ static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); -+ static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ -+ using SmemLayoutA = decltype(tile_to_shape( -+ SmemLayoutAtomA{}, -+ make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}))); -+ using SmemLayoutB = decltype(tile_to_shape( -+ SmemLayoutAtomB{}, -+ make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}))); -+ -+ static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); -+ static_assert(std::is_base_of::value && -+ std::is_base_of::value, -+ "MMA atom must source both A and B operand from smem_desc for this mainloop."); -+ -+ struct SharedStorage -+ { -+ cute::array_aligned> smem_a; -+ cute::array_aligned> smem_b; -+ }; -+ -+ struct Params { -+ ElementA const* ptr_A; -+ StrideA dA; -+ ElementB const* ptr_B; -+ StrideB dB; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CollectiveMma() = default; -+ -+ template -+ static constexpr Params -+ to_underlying_arguments(Args const& args, void* workspace) { -+ (void) workspace; -+ return {args.ptr_A, args.dA, args.ptr_B, args.dB}; -+ } -+ -+ /// Perform a collective-scoped matrix multiply-accumulate -+ template < -+ class TensorA, -+ class TensorB, -+ class FrgTensorC, -+ class KTileIterator, -+ class ResidueMNK -+ > -+ CUTLASS_DEVICE void -+ operator() ( -+ TensorA gA, -+ TensorB gB, -+ FrgTensorC& accum, -+ KTileIterator k_tile_iter, int k_tile_count, -+ ResidueMNK residue_mnk, -+ int thread_idx, -+ char *smem_buf, -+ Params const& mainloop_params) -+ { -+ using namespace cute; -+ -+ (void) residue_mnk; -+ -+ static_assert(is_gmem::value, "A tensor must be gmem resident."); -+ static_assert(is_gmem::value, "B tensor must be gmem resident."); -+ static_assert(is_rmem::value, "C tensor must be rmem resident."); -+ static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2."); -+ static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2."); -+ static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); -+ static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); -+ static_assert(std::is_same::value, -+ "SM90 warpgroup MMA must specify transforms through MMA_Atom."); -+ static_assert(std::is_same::value, -+ "SM90 warpgroup MMA must specify transforms through MMA_Atom."); -+ static_assert(std::is_same::value, -+ "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); -+ static_assert(std::is_same::value, -+ "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); -+ -+ SharedStorage& storage = *reinterpret_cast(smem_buf); -+ Tensor sA = make_tensor(make_smem_ptr(storage.smem_a.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) -+ Tensor sB = make_tensor(make_smem_ptr(storage.smem_b.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) -+ -+ // Partition the copying of A and B tiles across the threads -+ GmemTiledCopyA gmem_tiled_copy_a; -+ GmemTiledCopyB gmem_tiled_copy_b; -+ auto gmem_thr_copy_a = gmem_tiled_copy_a.get_slice(thread_idx); -+ auto gmem_thr_copy_b = gmem_tiled_copy_b.get_slice(thread_idx); -+ -+ Tensor tAgA = gmem_thr_copy_a.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) -+ Tensor tAsA = gmem_thr_copy_a.partition_D(sA); // (ACPY,ACPY_M,ACPY_K,PIPE) -+ Tensor tBgB = gmem_thr_copy_b.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) -+ Tensor tBsB = gmem_thr_copy_b.partition_D(sB); // (BCPY,BCPY_N,BCPY_K,PIPE) -+ -+ // Tile MMA atom and compute thread partitions across A, B and C -+ TiledMma tiled_mma; -+ auto thr_mma = tiled_mma.get_thread_slice(thread_idx); -+ -+ // Allocate registers for pipelining -+ Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) -+ Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) -+ -+ Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_N,MMA_K,PIPE) -+ Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_N,PIPE) -+ -+ CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M -+ CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N -+ CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K -+ CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE -+ CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tAsA)); // PIPE -+ CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tBsB)); // PIPE -+ CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE -+ CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE -+ -+ // -+ // Prologue -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k_pipe = 0; k_pipe < DispatchPolicy::Stages-1; ++k_pipe) { -+ copy(gmem_tiled_copy_a, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,k_pipe)); -+ copy(gmem_tiled_copy_b, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,k_pipe)); -+ cp_async_fence(); -+ ++k_tile_iter; -+ --k_tile_count; -+ } -+ -+ // Current pipe index in smem to read from -+ int smem_pipe_read = 0; -+ // Current pipe index in smem to write to -+ int smem_pipe_write = DispatchPolicy::Stages-1; -+ -+ // -+ // Pipelined Main Loop -+ // -+ CUTLASS_PRAGMA_NO_UNROLL -+ for ( ; k_tile_count > -(DispatchPolicy::Stages-1); --k_tile_count) -+ { -+ // Copy gmem to smem before computing gemm on each k-pipe -+ // pipe index in smem where the next gmem tile will be read into -+ copy(gmem_tiled_copy_a, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,smem_pipe_write)); -+ copy(gmem_tiled_copy_b, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,smem_pipe_write)); -+ cp_async_fence(); -+ if (k_tile_count > 0) { ++k_tile_iter; } -+ -+ // -+ // Compute on k_tile -+ // -+ warpgroup_fence_operand(accum); -+ warpgroup_arrive(); -+ -+ cp_async_wait(); -+ cute::gemm(tiled_mma, tCrA(_,_,_,smem_pipe_read), tCrB(_,_,_,smem_pipe_read), accum); -+ warpgroup_commit_batch(); -+ -+ // -+ // Advance the pipe -+ // -+ ++smem_pipe_read; -+ smem_pipe_read = (smem_pipe_read == DispatchPolicy::Stages) ? smem_pipe_read = 0 : smem_pipe_read; -+ -+ ++smem_pipe_write; -+ smem_pipe_write = (smem_pipe_write == DispatchPolicy::Stages) ? smem_pipe_write = 0 : smem_pipe_write; -+ -+ // Wait for the pipeline MMAs to drain -+ warpgroup_wait<0>(); -+ warpgroup_fence_operand(accum); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ int Stages, -+ class ClusterShape, -+ class TileShape_, -+ class ElementA_, -+ class StrideA_, -+ class ElementB_, -+ class StrideB_, -+ class TiledMma_, -+ class GmemTiledCopyA_, -+ class SmemLayoutAtomA_, -+ class SmemCopyAtomA_, -+ class TransformA_, -+ class GmemTiledCopyB_, -+ class SmemLayoutAtomB_, -+ class SmemCopyAtomB_, -+ class TransformB_> -+struct CollectiveMma< -+ MainloopSm90CpAsyncGmma, -+ TileShape_, -+ ElementA_, -+ StrideA_, -+ ElementB_, -+ StrideB_, -+ TiledMma_, -+ GmemTiledCopyA_, -+ SmemLayoutAtomA_, -+ SmemCopyAtomA_, -+ TransformA_, -+ GmemTiledCopyB_, -+ SmemLayoutAtomB_, -+ SmemCopyAtomB_, -+ TransformB_> -+{ -+ // -+ // Type Aliases -+ // -+ using DispatchPolicy = MainloopSm90CpAsyncGmma; -+ using TileShape = TileShape_; -+ using ElementA = ElementA_; -+ using StrideA = StrideA_; -+ using ElementB = ElementB_; -+ using StrideB = StrideB_; -+ using TiledMma = TiledMma_; -+ using ElementAccumulator = typename TiledMma::ValTypeC; using GmemTiledCopyA = GmemTiledCopyA_; -+ using GmemTiledCopyB = GmemTiledCopyB_; -+ using SmemLayoutAtomA = SmemLayoutAtomA_; -+ using SmemLayoutAtomB = SmemLayoutAtomB_; -+ using SmemCopyAtomA = SmemCopyAtomA_; -+ using SmemCopyAtomB = SmemCopyAtomB_; -+ using TransformA = TransformA_; -+ using TransformB = TransformB_; -+ using ArchTag = typename DispatchPolicy::ArchTag; -+ -+ static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); -+ static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ -+ static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); -+ static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ -+ using SmemLayoutA = decltype(tile_to_shape( -+ SmemLayoutAtomA{}, -+ make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}))); -+ using SmemLayoutB = decltype(tile_to_shape( -+ SmemLayoutAtomB{}, -+ make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}))); -+ -+ static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); -+ static_assert(std::is_base_of::value && -+ std::is_base_of::value, -+ "MMA atom must source both A and B operand from smem_desc for this mainloop."); -+ -+ struct SharedStorage -+ { -+ cute::array_aligned> smem_a; -+ cute::array_aligned> smem_b; -+ }; -+ -+ struct Params { -+ ElementA const* ptr_A; -+ StrideA dA; -+ ElementB const* ptr_B; -+ StrideB dB; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CollectiveMma() = default; -+ -+ template -+ static constexpr Params -+ to_underlying_arguments(Args const& args, void* workspace) { -+ (void) workspace; -+ return {args.ptr_A, args.dA, args.ptr_B, args.dB}; -+ } -+ -+ /// Perform a collective-scoped matrix multiply-accumulate -+ template < -+ class FrgTensorD, -+ class TensorA, -+ class TensorB, -+ class FrgTensorC, -+ class KTileIterator, -+ class ResidueMNK -+ > -+ CUTLASS_DEVICE void -+ operator() ( -+ FrgTensorD &accum, -+ TensorA gA, -+ TensorB gB, -+ FrgTensorC const &src_accum, -+ KTileIterator k_tile_iter, int k_tile_count, -+ ResidueMNK residue_mnk, -+ int thread_idx, -+ char *smem_buf) -+ { -+ using namespace cute; -+ -+ static_assert(is_rmem::value, "D tensor must be rmem resident."); -+ static_assert(is_gmem::value, "A tensor must be gmem resident."); -+ static_assert(is_gmem::value, "B tensor must be gmem resident."); -+ static_assert(is_rmem::value, "C tensor must be rmem resident."); -+ static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2."); -+ static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2."); -+ static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); -+ static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); -+ static_assert(std::is_same::value, -+ "SM90 warpgroup MMA must specify transforms through MMA_Atom."); -+ static_assert(std::is_same::value, -+ "SM90 warpgroup MMA must specify transforms through MMA_Atom."); -+ static_assert(std::is_same::value, -+ "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); -+ static_assert(std::is_same::value, -+ "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); -+ -+ SharedStorage& storage = *reinterpret_cast(smem_buf); -+ Tensor sA = make_tensor(make_smem_ptr(storage.smem_a.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) -+ Tensor sB = make_tensor(make_smem_ptr(storage.smem_b.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) -+ -+ // Shift tensor so residue_k is at origin (Can't read any k_coord < residue_k) -+ // This aligns the tensor with BLK_K for all but the 0th k_tile -+ gA.data() = &gA(0, get<2>(residue_mnk), 0); -+ gB.data() = &gB(0, get<2>(residue_mnk), 0); -+ -+ // Partition the copying of A and B tiles across the threads -+ GmemTiledCopyA gmem_tiled_copy_a; -+ GmemTiledCopyB gmem_tiled_copy_b; -+ auto gmem_thr_copy_a = gmem_tiled_copy_a.get_slice(thread_idx); -+ auto gmem_thr_copy_b = gmem_tiled_copy_b.get_slice(thread_idx); -+ -+ Tensor tAgA = gmem_thr_copy_a.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) -+ Tensor tAsA = gmem_thr_copy_a.partition_D(sA); // (ACPY,ACPY_M,ACPY_K,PIPE) -+ Tensor tBgB = gmem_thr_copy_b.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) -+ Tensor tBsB = gmem_thr_copy_b.partition_D(sB); // (BCPY,BCPY_N,BCPY_K,PIPE) -+ -+ // -+ // PREDICATES -+ // -+ -+ // Allocate predicate tensors for m and n -+ Tensor tApA = make_tensor(make_shape(size<1>(tAsA), size<2>(tAsA)), Stride<_1,_0>{}); -+ Tensor tBpB = make_tensor(make_shape(size<1>(tBsB), size<2>(tBsB)), Stride<_1,_0>{}); -+ -+ // Construct identity layout for sA and sB -+ Tensor cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); // (BLK_M,BLK_K) -> (blk_m,blk_k) -+ Tensor cB = make_identity_tensor(make_shape(size<0>(sB), size<1>(sB))); // (BLK_N,BLK_K) -> (blk_n,blk_k) -+ -+ // Repeat the partitioning with identity layouts -+ Tensor tAcA = gmem_thr_copy_a.partition_S(cA); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) -+ Tensor tBcB = gmem_thr_copy_b.partition_S(cB); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) -+ -+ // Set predicates for m bounds -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < size<0>(tApA); ++m) { -+ tApA(m,0) = get<0>(tAcA(0,m,0)) < get<0>(residue_mnk); // blk_m coord < residue_m -+ } -+ // Set predicates for n bounds -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < size<0>(tBpB); ++n) { -+ tBpB(n,0) = get<0>(tBcB(0,n,0)) < get<1>(residue_mnk); // blk_n coord < residue_n -+ } -+ -+ // -+ // Prologue/PREFETCH -+ // -+ -+ // Clear the smem tiles to account for predicated off loads -+ clear(tAsA); -+ clear(tBsB); -+ -+ // Start async loads for 0th k-tile, where we take care of the k residue -+ { -+ constexpr int k_pipe = 0; -+ -+ Tensor tAgAk = tAgA(_,_,_,*k_tile_iter); -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < size<2>(tAsA); ++k) { -+ if (get<1>(tAcA(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < residue_k (gA shifted) -+ copy_if(gmem_tiled_copy_a, tApA(_,k), tAgAk(_,_,k), tAsA(_,_,k,k_pipe)); -+ } -+ } -+ Tensor tBgBk = tBgB(_,_,_,*k_tile_iter); -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < size<2>(tBsB); ++k) { -+ if (get<1>(tBcB(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < residue_k (gB shifted) -+ copy_if(gmem_tiled_copy_b, tBpB(_,k), tBgBk(_,_,k), tBsB(_,_,k,k_pipe)); -+ } -+ } -+ cp_async_fence(); -+ ++k_tile_iter; -+ --k_tile_count; -+ } -+ -+ // Start async loads for 1st k-tile onwards, no k-residue handling needed -+ CUTLASS_PRAGMA_UNROLL -+ for (int k_pipe = 1; k_pipe < DispatchPolicy::Stages-1; ++k_pipe) { -+ if (k_tile_count <= 0) { -+ clear(tApA); -+ clear(tBpB); -+ } -+ copy_if(gmem_tiled_copy_a, tApA, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,k_pipe)); // CpAsync -+ copy_if(gmem_tiled_copy_b, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,k_pipe)); // CpAsync -+ cp_async_fence(); -+ ++k_tile_iter; -+ --k_tile_count; -+ } -+ -+ // -+ // MMA Atom partitioning -+ // -+ -+ // Tile MMA atom and compute thread partitions across A, B and C -+ TiledMma tiled_mma; -+ auto thr_mma = tiled_mma.get_thread_slice(thread_idx); -+ -+ // Allocate registers for pipelining -+ Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) -+ Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) -+ -+ Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_N,MMA_K,PIPE) -+ Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_N,PIPE) -+ -+ CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M -+ CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(src_accum)); // M -+ CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N -+ CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(src_accum)); // N -+ CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K -+ CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE -+ CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tAsA)); // PIPE -+ CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tBsB)); // PIPE -+ CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE -+ CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE -+ -+ // Current pipe index in smem to read from -+ int smem_pipe_read = 0; -+ // Current pipe index in smem to write to -+ int smem_pipe_write = DispatchPolicy::Stages-1; -+ -+ // -+ // Pipelined Main Loop -+ // -+ CUTLASS_PRAGMA_NO_UNROLL -+ for ( ; k_tile_count > -(DispatchPolicy::Stages-1); --k_tile_count) -+ { -+ // -+ // Copy gmem to smem for *k_tile_iter -+ // -+ if (k_tile_count <= 0) { -+ clear(tApA); -+ clear(tBpB); -+ } -+ copy_if(gmem_tiled_copy_a, tApA, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,smem_pipe_write)); // CpAsync -+ copy_if(gmem_tiled_copy_b, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,smem_pipe_write)); // CpAsync -+ cp_async_fence(); -+ ++k_tile_iter; -+ -+ // -+ // Compute on k_tile -+ // -+ warpgroup_fence_operand(accum); -+ warpgroup_arrive(); -+ -+ cp_async_wait(); -+ cute::gemm(tiled_mma, accum, tCrA(_,_,_,smem_pipe_read), tCrB(_,_,_,smem_pipe_read), src_accum); -+ warpgroup_commit_batch(); -+ -+ // -+ // Advance the pipe -+ // -+ ++smem_pipe_read; -+ smem_pipe_read = (smem_pipe_read == DispatchPolicy::Stages) ? smem_pipe_read = 0 : smem_pipe_read; -+ -+ ++smem_pipe_write; -+ smem_pipe_write = (smem_pipe_write == DispatchPolicy::Stages) ? smem_pipe_write = 0 : smem_pipe_write; -+ -+ // Wait for the pipeline MMAs to drain -+ warpgroup_wait<0>(); -+ warpgroup_fence_operand(accum); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass::gemm::collective -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp b/3rdparty/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp -new file mode 100644 -index 0000000..25eaffb ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp -@@ -0,0 +1,480 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cute/arch/cluster_sm90.hpp" -+#include "cute/arch/copy_sm90.hpp" -+#include "cutlass/gemm/dispatch_policy.hpp" -+ -+#include "cute/algorithm/functional.hpp" -+#include "cute/atom/mma_atom.hpp" -+#include "cute/algorithm/gemm.hpp" -+#include "cute/tensor_predicate.hpp" -+#include "cute/numeric/arithmetic_tuple.hpp" -+#include "cutlass/pipeline.hpp" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass::gemm::collective { -+using namespace cute; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ int Stages, -+ class ClusterShape, -+ int PipelineAsyncMmaStages, -+ class TileShape_, -+ class ElementA_, -+ class StrideA_, -+ class ElementB_, -+ class StrideB_, -+ class TiledMma_, -+ class GmemTiledCopyA_, -+ class SmemLayoutAtomA_, -+ class SmemCopyAtomA_, -+ class TransformA_, -+ class GmemTiledCopyB_, -+ class SmemLayoutAtomB_, -+ class SmemCopyAtomB_, -+ class TransformB_> -+struct CollectiveMma< -+ MainloopSm90TmaGmma, -+ TileShape_, -+ ElementA_, -+ StrideA_, -+ ElementB_, -+ StrideB_, -+ TiledMma_, -+ GmemTiledCopyA_, -+ SmemLayoutAtomA_, -+ SmemCopyAtomA_, -+ TransformA_, -+ GmemTiledCopyB_, -+ SmemLayoutAtomB_, -+ SmemCopyAtomB_, -+ TransformB_> -+{ -+ // -+ // Type Aliases -+ // -+ using DispatchPolicy = MainloopSm90TmaGmma; -+ using TileShape = TileShape_; -+ using ElementA = ElementA_; -+ using StrideA = StrideA_; -+ using ElementB = ElementB_; -+ using StrideB = StrideB_; -+ using TiledMma = TiledMma_; -+ using ElementAccumulator = typename TiledMma::ValTypeC; -+ using GmemTiledCopyA = GmemTiledCopyA_; -+ using GmemTiledCopyB = GmemTiledCopyB_; -+ using SmemLayoutAtomA = SmemLayoutAtomA_; -+ using SmemLayoutAtomB = SmemLayoutAtomB_; -+ using SmemCopyAtomA = SmemCopyAtomA_; -+ using SmemCopyAtomB = SmemCopyAtomB_; -+ using TransformA = TransformA_; -+ using TransformB = TransformB_; -+ using ArchTag = typename DispatchPolicy::ArchTag; -+ -+ using MainloopPipeline = cutlass::PipelineTmaAsync< -+ DispatchPolicy::Stages, -+ typename DispatchPolicy::ClusterShape>; -+ -+ using PipelineParams = typename MainloopPipeline::Params; -+ using PipelineState = typename cutlass::PipelineState; -+ -+ static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); -+ static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ -+ static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); -+ static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ -+ // Tile along K mode first before tiling over MN. PIPE mode last as usual. -+ // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. -+ using SmemLayoutA = decltype(tile_to_shape( -+ SmemLayoutAtomA{}, -+ make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), -+ Step<_2,_1,_3>{})); -+ using SmemLayoutB = decltype(tile_to_shape( -+ SmemLayoutAtomB{}, -+ make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), -+ Step<_2,_1,_3>{})); -+ -+ static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); -+ static_assert(std::is_base_of::value && -+ std::is_base_of::value, -+ "MMA atom must source both A and B operand from smem_desc for this mainloop."); -+ static_assert(std::is_same_v || std::is_same_v, -+ "GmemTiledCopy - invalid SM90 TMA copy atom specified."); -+ static_assert(std::is_same_v || std::is_same_v, -+ "GmemTiledCopy - invalid SM90 TMA copy atom specified."); -+ -+ // TMA converts f32 input to tf32 when copying from GMEM to SMEM -+ // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. -+ static constexpr bool ConvertF32toTF32A = std::is_same_v; -+ static constexpr bool ConvertF32toTF32B = std::is_same_v; -+ using InternalElementA = std::conditional_t>>; -+ using InternalElementB = std::conditional_t>>; -+ -+ struct SharedStorage -+ { -+ cute::array_aligned> smem_A; -+ cute::array_aligned> smem_B; -+ -+ using PipelineStorage = typename MainloopPipeline::SharedStorage; -+ alignas(16) PipelineStorage pipeline_storage; -+ }; -+ -+ struct Params { -+ InternalElementA const* ptr_A; -+ StrideA dA; -+ InternalElementB const* ptr_B; -+ StrideB dB; -+ // Assumption: StrideA is congruent with Problem_MK -+ using TMA_A = decltype(make_tma_copy( -+ GmemTiledCopyA{}, -+ make_tensor(ptr_A, repeat_like(StrideA{}, int32_t(0)), dA), -+ SmemLayoutA{}(_,_,0), -+ make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), -+ size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any -+ // Assumption: StrideB is congruent with Problem_NK -+ using TMA_B = decltype(make_tma_copy( -+ GmemTiledCopyB{}, -+ make_tensor(ptr_B, repeat_like(StrideB{}, int32_t(0)), dB), -+ SmemLayoutB{}(_,_,0), -+ make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), -+ size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any -+ TMA_A tma_load_a; -+ TMA_B tma_load_b; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ template -+ static constexpr Params -+ to_underlying_arguments(Args const& args, void* workspace) { -+ (void) workspace; -+ // Optionally append _1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) -+ auto problem_shape_MNKL = append<4>(args.problem_shape, Int<1>{}); -+ auto M = get<0>(problem_shape_MNKL); -+ auto N = get<1>(problem_shape_MNKL); -+ auto K = get<2>(problem_shape_MNKL); -+ auto L = get<3>(problem_shape_MNKL); -+ -+ auto reinterpreted_ptr_A = reinterpret_cast(args.ptr_A); -+ auto reinterpreted_ptr_B = reinterpret_cast(args.ptr_B); -+ -+ Tensor tensor_a = make_tensor(reinterpreted_ptr_A, make_layout(make_shape(M,K,L), args.dA)); -+ Tensor tensor_b = make_tensor(reinterpreted_ptr_B, make_layout(make_shape(N,K,L), args.dB)); -+ typename Params::TMA_A tma_load_a = make_tma_copy( -+ GmemTiledCopyA{}, -+ tensor_a, -+ SmemLayoutA{}(_,_,cute::Int<0>{}), -+ make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), -+ size<1>(ClusterShape{})); // mcast along N mode for this M load, if any -+ typename Params::TMA_B tma_load_b = make_tma_copy( -+ GmemTiledCopyB{}, -+ tensor_b, -+ SmemLayoutB{}(_,_,cute::Int<0>{}), -+ make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), -+ size<0>(ClusterShape{})); // mcast along M mode for this N load, if any -+ return { -+ reinterpreted_ptr_A, -+ args.dA, -+ reinterpreted_ptr_B, -+ args.dB, -+ tma_load_a, -+ tma_load_b -+ }; -+ } -+ -+ /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance -+ CUTLASS_DEVICE -+ static void prefetch_tma_descriptors(Params const& mainloop_params) -+ { -+ cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); -+ cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); -+ } -+ -+ /// Perform a collective-scoped matrix multiply-accumulate -+ template < -+ class TensorA, class TMA_LOAD_A, -+ class TensorB, class TMA_LOAD_B, -+ class FrgTensorC, -+ class KTileIterator -+ > -+ CUTLASS_DEVICE void -+ operator() ( -+ TensorA const& gA, TMA_LOAD_A& tma_load_a, -+ TensorB const& gB, TMA_LOAD_B& tma_load_b, -+ FrgTensorC& accum, -+ KTileIterator k_tile_iter, int k_tile_count, -+ int thread_idx, -+ char* shared_memory, -+ Params const& mainloop_params) -+ { -+ using namespace cute; -+ -+ static_assert(is_rmem::value, "C tensor must be rmem resident."); -+ static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2."); -+ static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2."); -+ static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); -+ static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); -+ static_assert(std::is_void_v, -+ "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); -+ static_assert(std::is_void_v, -+ "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); -+ -+ SharedStorage& storage = *reinterpret_cast(shared_memory); -+ Tensor sA = make_tensor(make_smem_ptr(storage.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) -+ Tensor sB = make_tensor(make_smem_ptr(storage.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) -+ -+ // -+ // Prepare the TMA loads for A and B -+ // -+ dim3 cluster_local_block_id = cute::block_id_in_cluster(); -+ auto block_tma_a = tma_load_a.get_slice(cluster_local_block_id.y); -+ auto block_tma_b = tma_load_b.get_slice(cluster_local_block_id.x); -+ -+ // Applies the mapping from block_tma_a -+ Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) -+ Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) -+ -+ Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) -+ Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) -+ -+ // -+ // Prepare TMA membars and PREFETCH -+ // -+ -+ // Number of pipelined k-tiles in smem -+ constexpr int K_PIPE_MAX = DispatchPolicy::Stages; -+ -+ // NOTE: Another parameter: Partition the pipeline between active MMAs and active TMAs -+ // Tunable via the dispatch policy to tollerate latencies evenly across the math and compute stages -+ // K_PIPE_MMAS: The max number of active MMA pipes at beginning of every loop -+ // K_PIPE_TMAS: The max number of active TMA pipes at beginning of every loop (geq 1) -+ constexpr int K_PIPE_MMAS = DispatchPolicy::PipelineAsyncMmaStages; -+ constexpr int K_PIPE_TMAS = K_PIPE_MAX - K_PIPE_MMAS; -+ static_assert(0 <= K_PIPE_MMAS && K_PIPE_MMAS < K_PIPE_MAX); -+ static_assert(0 < K_PIPE_TMAS && K_PIPE_TMAS <= K_PIPE_MAX); -+ -+ static_assert(K_PIPE_MMAS < K_PIPE_MAX - 1); -+ -+ // Set the bytes transferred in this TMA transaction (may involve multiple issues) -+ constexpr uint32_t TmaTransactionBytes = static_cast( -+ (size<0>(sA) * size<1>(sA) * sizeof(InternalElementA)) + -+ (size<0>(sB) * size<1>(sB) * sizeof(InternalElementB))); -+ -+ -+ // Obtain warp index -+ int warp_idx = canonical_warp_idx(); -+ int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; -+ -+ PipelineParams params; -+ params.transaction_bytes = TmaTransactionBytes; -+ params.role = MainloopPipeline::ThreadCategory::ProducerConsumer; -+ params.is_leader = warp_group_thread_idx == 0; -+ params.num_consumers = NumThreadsPerWarpGroup; -+ -+ MainloopPipeline pipeline( -+ storage.pipeline_storage, -+ params); -+ -+ // State variables used for iterating the circular buffer -+ // smem_pipe_read / release is used by the consumer of SMEM data - i.e MMA -+ // smem_pipe_write is used by the producer of SMEM data - i.e TMA -+ PipelineState smem_pipe_read; -+ PipelineState smem_pipe_release; -+ PipelineState smem_pipe_write = cutlass::make_producer_start_state(); -+ -+ // We need this to guarantee that the Pipeline init is visible -+ // To all producers and consumer blocks in the Cluster -+ if constexpr (size(ClusterShape{}) > 1) { -+ cute::cluster_arrive_relaxed(); -+ cute::cluster_wait(); -+ } -+ else { -+ __syncthreads(); -+ } -+ -+ // Set predicate for the lowest lane_id in the warp -+ int lane_predicate = cute::elect_one_sync(); -+ -+ uint16_t mcast_mask_a = 0; -+ uint16_t mcast_mask_b = 0; -+ // Keep a copy to know when to stop issuing loads -+ int k_tile_count_tma = k_tile_count; -+ -+ // Issue TmaLoads (Prologue fetches) -+ if (warp_idx == 0 && lane_predicate == 1) { -+ // Maps the tile -> block, value -+ if constexpr (std::is_same_v) { -+ auto block_layout = Layout{}; // (m,n) -> block_id -+ for (int n = 0; n < size<1>(block_layout); ++n) { -+ mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); -+ } -+ } -+ -+ if constexpr (std::is_same_v) { -+ auto block_layout = Layout{}; // (m,n) -> block_id -+ for (int m = 0; m < size<0>(block_layout); ++m) { -+ mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); -+ } -+ } -+ -+ // Issue the prologue loads -+ int prologue_tma_count = min(K_PIPE_MAX, k_tile_count); -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < prologue_tma_count; ++stage) { -+ pipeline.producer_acquire(smem_pipe_write); -+ using BarrierType = typename MainloopPipeline::ValueType; -+ BarrierType* tma_barrier = pipeline.producer_get_barrier(stage); -+ -+ copy(tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,stage)); -+ copy(tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,stage)); -+ ++k_tile_iter; -+ ++smem_pipe_write; -+ } -+ k_tile_count_tma -= prologue_tma_count; -+ } -+ -+ // -+ // Define C accumulators and A/B partitioning -+ // -+ -+ TiledMma tiled_mma; -+ auto thread_mma = tiled_mma.get_thread_slice(thread_idx); -+ -+ Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) -+ Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) -+ -+ // Allocate "fragments/descriptors" -+ Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) -+ Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) -+ -+ CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M -+ CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N -+ CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K -+ CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE -+ CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tAsA)); // PIPE -+ CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tBsB)); // PIPE -+ CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE -+ CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE -+ -+ __syncthreads(); -+ -+ warpgroup_fence_operand(accum); -+ // Prologue MMAs -+ CUTLASS_PRAGMA_UNROLL -+ for (int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); -+ prologue_mma_count > 0; --prologue_mma_count) -+ { -+ // WAIT on smem_pipe_read until it's data is available -+ pipeline.consumer_wait(smem_pipe_read); -+ warpgroup_arrive(); -+ cute::gemm(tiled_mma, tCrA(_,_,_,smem_pipe_read.index()), tCrB(_,_,_,smem_pipe_read.index()), accum); // (V,M,K) x (V,N,K) => (V,M,N) -+ warpgroup_commit_batch(); -+ ++smem_pipe_read; -+ --k_tile_count; -+ } -+ warpgroup_fence_operand(accum); -+ -+ // -+ // PIPELINED MAIN LOOP -+ // -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for ( ; k_tile_count > 0; --k_tile_count) -+ { -+ // WAIT on smem_pipe_read until data is available -+ pipeline.consumer_wait(smem_pipe_read); -+ -+ // -+ // Compute on k_tile -+ // -+ -+ warpgroup_fence_operand(accum); -+ warpgroup_arrive(); -+ cute::gemm(tiled_mma, tCrA(_,_,_,smem_pipe_read.index()), tCrB(_,_,_,smem_pipe_read.index()), accum); // (V,M,K) x (V,N,K) => (V,M,N) -+ warpgroup_commit_batch(); -+ -+ /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed -+ warpgroup_wait(); -+ warpgroup_fence_operand(accum); -+ -+ pipeline.consumer_release(smem_pipe_release); // UNLOCK wr stage, done _computing_ on it -+ -+ // -+ // Copy gmem to smem for *k_tile_iter -+ // -+ -+ // Do Acquire & Load only if needed - helps with both performance and also corner case illegal barrier-ops -+ if (warp_idx == 0 && lane_predicate == 1 && (k_tile_count_tma > 0) ) { -+ pipeline.producer_acquire(smem_pipe_write); // LOCK wr stage, for _writing_ -+ -+ using BarrierType = typename MainloopPipeline::ValueType; -+ BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write.index()); -+ -+ copy(tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,smem_pipe_write.index())); -+ copy(tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,smem_pipe_write.index())); -+ ++smem_pipe_write; -+ ++k_tile_iter; -+ --k_tile_count_tma; -+ } -+ -+ // Advance consumer pipeline -+ ++smem_pipe_read; -+ ++smem_pipe_release; -+ } -+ -+ // Wait on all GMMAs -+ warpgroup_wait<0>(); -+ warpgroup_fence_operand(accum); -+ -+ // Workaround for ensuring Smem destruction doesn't happen accidentally -+ if constexpr (size(typename DispatchPolicy::ClusterShape{}) > 1) { -+ cute::cluster_arrive(); -+ cute::cluster_wait(); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass::gemm::collective -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp b/3rdparty/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp -new file mode 100644 -index 0000000..41b0f13 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp -@@ -0,0 +1,494 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cute/arch/cluster_sm90.hpp" -+#include "cute/arch/copy_sm90.hpp" -+#include "cutlass/gemm/dispatch_policy.hpp" -+ -+#include "cute/algorithm/functional.hpp" -+#include "cute/atom/mma_atom.hpp" -+#include "cute/algorithm/gemm.hpp" -+#include "cute/tensor_predicate.hpp" -+#include "cute/numeric/arithmetic_tuple.hpp" -+#include "cutlass/pipeline.hpp" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass::gemm::collective { -+using namespace cute; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// WarpSpecialized Mainloop -+template < -+ int Stages, -+ class ClusterShape, -+ class KernelSchedule, -+ class TileShape_, -+ class ElementA_, -+ class StrideA_, -+ class ElementB_, -+ class StrideB_, -+ class TiledMma_, -+ class GmemTiledCopyA_, -+ class SmemLayoutAtomA_, -+ class SmemCopyAtomA_, -+ class TransformA_, -+ class GmemTiledCopyB_, -+ class SmemLayoutAtomB_, -+ class SmemCopyAtomB_, -+ class TransformB_> -+struct CollectiveMma< -+ MainloopSm90TmaGmmaWarpSpecialized, -+ TileShape_, -+ ElementA_, -+ StrideA_, -+ ElementB_, -+ StrideB_, -+ TiledMma_, -+ GmemTiledCopyA_, -+ SmemLayoutAtomA_, -+ SmemCopyAtomA_, -+ TransformA_, -+ GmemTiledCopyB_, -+ SmemLayoutAtomB_, -+ SmemCopyAtomB_, -+ TransformB_> -+{ -+ // -+ // Type Aliases -+ // -+ using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecialized; -+ using TileShape = TileShape_; -+ using ElementA = ElementA_; -+ using StrideA = StrideA_; -+ using ElementB = ElementB_; -+ using StrideB = StrideB_; -+ using TiledMma = TiledMma_; -+ using ElementAccumulator = typename TiledMma::ValTypeC; -+ using GmemTiledCopyA = GmemTiledCopyA_; -+ using GmemTiledCopyB = GmemTiledCopyB_; -+ using SmemLayoutAtomA = SmemLayoutAtomA_; -+ using SmemLayoutAtomB = SmemLayoutAtomB_; -+ using SmemCopyAtomA = SmemCopyAtomA_; -+ using SmemCopyAtomB = SmemCopyAtomB_; -+ using TransformA = TransformA_; -+ using TransformB = TransformB_; -+ using ArchTag = typename DispatchPolicy::ArchTag; -+ -+ using MainloopPipeline = cutlass::PipelineTmaAsync< -+ DispatchPolicy::Stages, -+ typename DispatchPolicy::ClusterShape>; -+ using PipelineState = cutlass::PipelineState; -+ -+ using PipelineParams = typename MainloopPipeline::Params; -+ -+ static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); -+ static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ -+ static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); -+ static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); -+ -+ // Tile along K mode first before tiling over MN. PIPE mode last as usual. -+ // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. -+ using SmemLayoutA = decltype(tile_to_shape( -+ SmemLayoutAtomA{}, -+ make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), -+ Step<_2,_1,_3>{})); -+ using SmemLayoutB = decltype(tile_to_shape( -+ SmemLayoutAtomB{}, -+ make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), -+ Step<_2,_1,_3>{})); -+ -+ static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); -+ static_assert(std::is_base_of::value && -+ std::is_base_of::value, -+ "MMA atom must source both A and B operand from smem_desc for this mainloop."); -+ static_assert(std::is_same_v || std::is_same_v, -+ "GmemTiledCopy - invalid SM90 TMA copy atom specified."); -+ static_assert(std::is_same_v || std::is_same_v, -+ "GmemTiledCopy - invalid SM90 TMA copy atom specified."); -+ -+ // TMA converts f32 input to tf32 when copying from GMEM to SMEM -+ // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. -+ static constexpr bool ConvertF32toTF32A = std::is_same_v; -+ static constexpr bool ConvertF32toTF32B = std::is_same_v; -+ using InternalElementA = std::conditional_t>>; -+ using InternalElementB = std::conditional_t>>; -+ -+ struct SharedStorage -+ { -+ cute::array_aligned> smem_A; -+ cute::array_aligned> smem_B; -+ -+ using PipelineStorage = typename MainloopPipeline::SharedStorage; -+ alignas(16) PipelineStorage pipeline_storage; -+ }; -+ -+ struct Params { -+ InternalElementA const* ptr_A; -+ StrideA dA; -+ InternalElementB const* ptr_B; -+ StrideB dB; -+ // Assumption: StrideA is congruent with Problem_MK -+ using TMA_A = decltype(make_tma_copy( -+ GmemTiledCopyA{}, -+ make_tensor(ptr_A, repeat_like(StrideA{}, int32_t(0)), dA), -+ SmemLayoutA{}(_,_,0), -+ make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), -+ size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any -+ // Assumption: StrideB is congruent with Problem_NK -+ using TMA_B = decltype(make_tma_copy( -+ GmemTiledCopyB{}, -+ make_tensor(ptr_B, repeat_like(StrideB{}, int32_t(0)), dB), -+ SmemLayoutB{}(_,_,0), -+ make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), -+ size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any -+ TMA_A tma_load_a; -+ TMA_B tma_load_b; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ template -+ static constexpr Params -+ to_underlying_arguments(Args const& args, void* workspace) { -+ (void) workspace; -+ // Optionally append _1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) -+ auto problem_shape_MNKL = append<4>(args.problem_shape, Int<1>{}); -+ auto M = get<0>(problem_shape_MNKL); -+ auto N = get<1>(problem_shape_MNKL); -+ auto K = get<2>(problem_shape_MNKL); -+ auto L = get<3>(problem_shape_MNKL); -+ -+ auto reinterpreted_ptr_A = reinterpret_cast(args.ptr_A); -+ auto reinterpreted_ptr_B = reinterpret_cast(args.ptr_B); -+ -+ Tensor tensor_a = make_tensor(reinterpreted_ptr_A, make_layout(make_shape(M,K,L), args.dA)); -+ Tensor tensor_b = make_tensor(reinterpreted_ptr_B, make_layout(make_shape(N,K,L), args.dB)); -+ typename Params::TMA_A tma_load_a = make_tma_copy( -+ GmemTiledCopyA{}, -+ tensor_a, -+ SmemLayoutA{}(_,_,cute::Int<0>{}), -+ make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), -+ size<1>(ClusterShape{})); // mcast along N mode for this M load, if any -+ typename Params::TMA_B tma_load_b = make_tma_copy( -+ GmemTiledCopyB{}, -+ tensor_b, -+ SmemLayoutB{}(_,_,cute::Int<0>{}), -+ make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), -+ size<0>(ClusterShape{})); // mcast along M mode for this N load, if any -+ return { -+ reinterpreted_ptr_A, -+ args.dA, -+ reinterpreted_ptr_B, -+ args.dB, -+ tma_load_a, -+ tma_load_b -+ }; -+ } -+ -+ static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; -+ static constexpr int K_PIPE_MMAS = 1; -+ static constexpr uint32_t TmaTransactionBytes = -+ (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof(ElementA)))+ -+ (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof(ElementB))); -+ -+ CUTLASS_DEVICE -+ static MainloopPipeline make_pipeline(char* shared_memory, PipelineParams params){ -+ SharedStorage& shared_storage = *reinterpret_cast(shared_memory); -+ return {shared_storage.pipeline_storage, params}; -+ } -+ -+ /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance -+ CUTLASS_DEVICE -+ static void prefetch_tma_descriptors(Params const& mainloop_params) -+ { -+ cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); -+ cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); -+ } -+ -+ /// Perform a collective-scoped matrix multiply-accumulate -+ /// Producer Perspective -+ template < -+ class TensorA, class TMA_LOAD_A, -+ class TensorB, class TMA_LOAD_B, -+ class KTileIterator -+ > -+ CUTLASS_DEVICE void -+ dma(MainloopPipeline pipeline, -+ PipelineState smem_pipe_write, -+ TensorA const& gA, TMA_LOAD_A& tma_load_a, -+ TensorB const& gB, TMA_LOAD_B& tma_load_b, -+ KTileIterator k_tile_iter, int k_tile_count, -+ int thread_idx, -+ char* shared_memory) -+ { -+ -+ using namespace cute; -+ int warp_idx = canonical_warp_idx(); -+ int warp_idx_in_warp_group = warp_idx % 4; -+ int lane_predicate = cute::elect_one_sync(); -+ -+ if (warp_idx_in_warp_group == 0 and lane_predicate) { -+ SharedStorage& shared_storage = *reinterpret_cast(shared_memory); -+ Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) -+ Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) -+ -+ // -+ // Prepare the TMA loads for A and B -+ // -+ -+ dim3 cluster_local_block_id = cute::block_id_in_cluster(); -+ auto block_tma_a = tma_load_a.get_slice(cluster_local_block_id.y); -+ auto block_tma_b = tma_load_b.get_slice(cluster_local_block_id.x); -+ -+ // Applies the mapping from block_tma_a -+ Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) -+ Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) -+ -+ Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) -+ Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) -+ -+ uint16_t mcast_mask_a = 0; -+ uint16_t mcast_mask_b = 0; -+ -+ // Issue TmaLoads -+ // Maps the tile -> block, value -+ if constexpr (std::is_same_v) { -+ auto block_layout = Layout{}; // (m,n) -> block_id -+ for (int n = 0; n < size<1>(block_layout); ++n) { -+ mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); -+ } -+ } -+ -+ if constexpr (std::is_same_v) { -+ auto block_layout = Layout{}; // (m,n) -> block_id -+ for (int m = 0; m < size<0>(block_layout); ++m) { -+ mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); -+ } -+ } -+ -+ // Issue the prologue loads -+ int k_tile_prologue = min(k_tile_count, K_PIPE_MAX); -+ CUTLASS_PRAGMA_UNROLL -+ for (int count = 0; count < k_tile_prologue; ++count) { -+ pipeline.producer_acquire(smem_pipe_write); -+ int write_stage = smem_pipe_write.index(); -+ using BarrierType = typename MainloopPipeline::ValueType; -+ BarrierType* tma_barrier = pipeline.producer_get_barrier(write_stage); -+ -+ copy(tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); -+ copy(tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); -+ ++k_tile_iter; -+ ++smem_pipe_write; -+ } -+ k_tile_count -= k_tile_prologue; -+ -+ // Mainloop -+ CUTLASS_PRAGMA_NO_UNROLL -+ for ( ; k_tile_count > 0; --k_tile_count) -+ { -+ // LOCK smem_pipe_write for _writing_ -+ pipeline.producer_acquire(smem_pipe_write); -+ -+ // -+ // Copy gmem to smem for *k_tile_iter -+ // -+ -+ int write_stage = smem_pipe_write.index(); -+ using BarrierType = typename MainloopPipeline::ValueType; -+ BarrierType* tma_barrier = pipeline.producer_get_barrier(write_stage); -+ -+ copy(tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); -+ copy(tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); -+ ++k_tile_iter; -+ -+ // Advance smem_pipe_write -+ ++smem_pipe_write; -+ } -+ } -+ } -+ -+ /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster -+ CUTLASS_DEVICE void -+ dma_epilogue(MainloopPipeline pipeline, -+ PipelineState smem_pipe_write) -+ { -+ int warp_idx = canonical_warp_idx(); -+ int warp_idx_in_warp_group = warp_idx % 4; -+ int lane_predicate = cute::elect_one_sync(); -+ -+ // Issue the epilogue waits -+ if (warp_idx_in_warp_group == 0 and lane_predicate) { -+ /* This helps avoid early exit of blocks in Cluster -+ * Waits for all stages to either be released (all -+ * Consumer UNLOCKs), or if the stage was never used -+ * then would just be acquired since the phase was -+ * still inverted from make_producer_start_state -+ */ -+ for (int count = 0; count < K_PIPE_MAX; ++count) { -+ pipeline.producer_acquire(smem_pipe_write); -+ ++smem_pipe_write; -+ } -+ } -+ } -+ -+ /// Perform a collective-scoped matrix multiply-accumulate -+ /// Consumer Perspective -+ template < -+ class FrgTensorC -+ > -+ CUTLASS_DEVICE void -+ mma(MainloopPipeline pipeline, -+ PipelineState smem_pipe_read, -+ FrgTensorC& accum, -+ int k_tile_count, -+ int thread_idx, -+ char* shared_memory, -+ Params const& mainloop_params -+ ) -+ { -+ using namespace cute; -+ -+ static_assert(is_rmem::value, "C tensor must be rmem resident."); -+ static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); -+ static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); -+ static_assert(std::is_void_v, -+ "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); -+ static_assert(std::is_void_v, -+ "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); -+ -+ SharedStorage& shared_storage = *reinterpret_cast(shared_memory); -+ Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) -+ Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) -+ -+ // -+ // Define C accumulators and A/B partitioning -+ // -+ -+ TiledMma tiled_mma; -+ auto thread_mma = tiled_mma.get_thread_slice(thread_idx); -+ -+ Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) -+ Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) -+ -+ // Allocate "fragments/descriptors" -+ Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) -+ Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) -+ -+ CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M -+ CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N -+ CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K -+ CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE -+ CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE -+ CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE -+ -+ // -+ // PIPELINED MAIN LOOP -+ // -+ static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), -+ "ERROR : Incorrect number of MMAs in flight"); -+ -+ // We release buffers to producer warps(dma) with some mmas in flight -+ PipelineState smem_pipe_release = smem_pipe_read; -+ -+ // Prologue GMMAs -+ int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); -+ -+ warpgroup_fence_operand(accum); -+ CUTLASS_PRAGMA_UNROLL -+ for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) -+ { -+ // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) -+ pipeline.consumer_wait(smem_pipe_read); -+ -+ int read_stage = smem_pipe_read.index(); -+ warpgroup_arrive(); -+ cute::gemm(tiled_mma, tCrA(_,_,_,read_stage), tCrB(_,_,_,read_stage), accum); // (V,M,K) x (V,N,K) => (V,M,N) -+ warpgroup_commit_batch(); -+ -+ ++smem_pipe_read; -+ } -+ -+ warpgroup_fence_operand(accum); -+ // Mainloop GMMAs -+ k_tile_count -= prologue_mma_count; -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for ( ; k_tile_count > 0; --k_tile_count) -+ { -+ // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) -+ pipeline.consumer_wait(smem_pipe_read); -+ -+ // -+ // Compute on k_tile -+ // -+ -+ int read_stage = smem_pipe_read.index(); -+ warpgroup_fence_operand(accum); -+ warpgroup_arrive(); -+ cute::gemm(tiled_mma, tCrA(_,_,_,read_stage), tCrB(_,_,_,read_stage), accum); // (V,M,K) x (V,N,K) => (V,M,N) -+ warpgroup_commit_batch(); -+ -+ /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed -+ warpgroup_wait(); -+ warpgroup_fence_operand(accum); -+ -+ pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it -+ -+ // Advance smem_pipe_read and smem_pipe_release -+ ++smem_pipe_read; -+ ++smem_pipe_release; -+ } -+ -+ // Wait on all GMMAs to complete -+ warpgroup_wait<0>(); -+ warpgroup_fence_operand(accum); -+ -+ for (int count = 0; count < prologue_mma_count; ++count) { -+ pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it -+ ++smem_pipe_release; -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass::gemm::collective -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/base_grouped.h b/3rdparty/cutlass/include/cutlass/gemm/device/base_grouped.h -new file mode 100644 -index 0000000..2e9398a ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/base_grouped.h -@@ -0,0 +1,479 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief Base device-level grouped kernel. -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/gemm/kernel/gemm_universal.h" -+ -+#include "cutlass/gemm/kernel/default_gemm_universal.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+ -+#include "cutlass/trace.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// GEMM Grouped -+template -+class BaseGrouped { -+public: -+ -+ using BaseKernel = BaseKernel_; -+ -+ using ElementA = typename BaseKernel::ElementA; -+ using LayoutA = typename BaseKernel::LayoutA; -+ using TensorRefA = TensorRef; -+ static ComplexTransform const kTransformA = BaseKernel::kTransformA; -+ static int const kAlignmentA = BaseKernel::kAlignmentA; -+ -+ using ElementB = typename BaseKernel::ElementB; -+ using LayoutB = typename BaseKernel::LayoutB; -+ using TensorRefB = TensorRef; -+ static ComplexTransform const kTransformB = BaseKernel::kTransformB; -+ static int const kAlignmentB = BaseKernel::kAlignmentB; -+ -+ using ElementC = typename BaseKernel::ElementC; -+ using LayoutC = typename BaseKernel::LayoutC; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ static int const kAlignmentC = BaseKernel::kAlignmentC; -+ -+ using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC; -+ -+ using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp; -+ using ThreadblockSwizzle = typename BaseKernel::ThreadblockSwizzle; -+ -+ using Operator = typename BaseKernel::Operator; -+ using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator; -+ -+ using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; -+ using MathOperator = typename WarpMmaOperator::MathOperator; -+ using OperatorClass = typename WarpMmaOperator::OperatorClass; -+ using ArchTag = typename WarpMmaOperator::ArchTag; -+ using ThreadblockShape = typename BaseKernel::Mma::Shape; -+ using WarpShape = typename BaseKernel::WarpShape; -+ using InstructionShape = typename BaseKernel::InstructionShape; -+ static int const kStages = BaseKernel::Mma::kStages; -+ -+ /// Argument structure -+ using Arguments = typename BaseKernel::Arguments; -+ -+ using ProblemInfo = typename BaseKernel::ProblemVisitor::ProblemInfo; -+ -+protected: -+ -+ /// Kernel parameters object -+ typename BaseKernel::Params params_; -+ -+private: -+ -+ /// Get the number of tiles across all problems in a group -+ static int32_t group_tile_count(const cutlass::gemm::GemmCoord* problem_sizes_ptr, int problem_count) { -+ int32_t tiles = 0; -+ for (int32_t i = 0; i < problem_count; ++i) { -+ cutlass::gemm::GemmCoord problem = problem_sizes_ptr[i]; -+ BaseKernel::ProblemVisitor::possibly_transpose_problem(problem); -+ tiles += problem_tile_count(problem); -+ } -+ return tiles; -+ } -+ -+ /// Copy from `data` to `workspace` -+ Status copy_to_workspace(void* workspace, void* data, size_t bytes) { -+ cudaError_t cuda_error = cudaMemcpy(workspace, data, bytes, cudaMemcpyHostToDevice); -+ if (cuda_error != cudaSuccess) { -+ // Call cudaGetLastError() to clear the error bit -+ cuda_error = cudaGetLastError(); -+ CUTLASS_TRACE_HOST( -+ " cudaMemcpy() returned error " -+ << cudaGetErrorString(cuda_error)); -+ return Status::kErrorInternal; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Precomputes scheduling information for the grouped GEMM -+ Status precompute(Arguments const &args, int32_t tile_count, void* workspace) { -+ size_t workspace_bytes = get_workspace_size(args); -+ std::vector host_workspace(workspace_bytes); -+ BaseKernel::ProblemVisitor::host_precompute(args.host_problem_sizes, -+ args.problem_count, -+ args.threadblock_count, -+ (void*)host_workspace.data()); -+ return copy_to_workspace(workspace, host_workspace.data(), workspace_bytes); -+ } -+ -+ /// Reorder `data` according to `indices` -+ template -+ static void reorder_array(T* data, const std::vector& indices) { -+ // For now, simply create a copy of the data and then copy over to the original. -+ std::vector copy(indices.size()); -+ for (int i = 0; i < indices.size(); ++i) { -+ copy.at(i) = data[indices[i]]; -+ } -+ -+ memcpy(data, copy.data(), indices.size() * sizeof(T)); -+ } -+ -+public: -+ -+ /// Constructs the GEMM. -+ BaseGrouped() { } -+ -+ /// Determines whether the GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ return BaseKernel::can_implement(args); -+ } -+ -+ /// Get the number of tiles in a problem -+ static int32_t problem_tile_count(cutlass::gemm::GemmCoord const &problem) { -+ auto grid = BaseKernel::ProblemVisitor::grid_shape(problem); -+ return BaseKernel::ProblemVisitor::tile_count(grid); -+ } -+ -+ /// Get the number of tiles across all problems in a group -+ static int32_t group_tile_count(Arguments const &args) { -+ if (args.host_problem_sizes == nullptr) { -+ CUTLASS_TRACE_HOST("Received nullptr for `args.host_problem_sizes"); -+ return -1; -+ } -+ -+ return group_tile_count(args.host_problem_sizes, args.problem_count); -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) { -+ return BaseKernel::ProblemVisitor::get_workspace_size(args.host_problem_sizes, -+ args.problem_count, -+ args.threadblock_count); -+ } else { -+ return 0; -+ } -+ } -+ -+ /// Computes the grid shape -+ static dim3 get_grid_shape(Arguments const &args) { -+ -+ return dim3(args.threadblock_count, 1, 1); -+ } -+ -+ /// Computes the maximum number of active blocks per multiprocessor -+ static int maximum_active_blocks(int smem_capacity = -1) { -+ -+ CUTLASS_TRACE_HOST("BaseGrouped::maximum_active_blocks()"); -+ -+ int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); -+ -+ CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); -+ -+ cudaError_t result; -+ if (smem_size > (48 << 10)) { -+ result = cudaFuncSetAttribute(Kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ // Call cudaGetLastError() to clear the error bit -+ result = cudaGetLastError(); -+ CUTLASS_TRACE_HOST( -+ " cudaFuncSetAttribute() returned error " -+ << cudaGetErrorString(result)); -+ return -1; -+ } -+ } -+ -+ int max_active_blocks = -1; -+ result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( -+ &max_active_blocks, -+ Kernel, -+ BaseKernel::kThreadCount, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ // Call cudaGetLastError() to clear the error bit -+ result = cudaGetLastError(); -+ CUTLASS_TRACE_HOST( -+ " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " -+ << cudaGetErrorString(result)); -+ return -1; -+ } -+ -+ CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); -+ return max_active_blocks; -+ } -+ -+ /// Sorts each pointer passed in according to the indices that sort -+ /// `problem_sizes_ptr` in descending order of problem-K dimension. -+ static void sort_problems(int problem_count, -+ cutlass::gemm::GemmCoord* problem_sizes_ptr, -+ int64_t* lda_host_ptr, -+ int64_t* ldb_host_ptr, -+ int64_t* ldc_host_ptr, -+ int64_t* ldd_host_ptr, -+ int64_t* offset_A_ptr, -+ int64_t* offset_B_ptr, -+ int64_t* offset_C_ptr, -+ int64_t* offset_D_ptr) -+ { -+ std::vector indices(problem_count); -+ std::iota(indices.begin(), indices.end(), 0); -+ std::stable_sort(indices.begin(), indices.end(), -+ [&problem_sizes_ptr](size_t i, size_t j) { -+ return problem_sizes_ptr[i].k() > problem_sizes_ptr[j].k(); -+ }); -+ -+ reorder_array(problem_sizes_ptr, indices); -+ reorder_array(lda_host_ptr, indices); -+ reorder_array(ldb_host_ptr, indices); -+ reorder_array(ldc_host_ptr, indices); -+ reorder_array(ldd_host_ptr, indices); -+ reorder_array(offset_A_ptr, indices); -+ reorder_array(offset_B_ptr, indices); -+ reorder_array(offset_C_ptr, indices); -+ reorder_array(offset_D_ptr, indices); -+ } -+ -+ /// Computes the number of threadblocks to launch for the grouped kernel -+ static int sufficient(const cutlass::gemm::GemmCoord* problem_sizes_ptr=nullptr, -+ int problem_count=0, -+ int available_sm_count=-1) { -+ // Determine the number of blocks that would be launched to fill up a single -+ // wave on the GPU with each SM having maximum occupancy. -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ if (result != cudaSuccess) { -+ // Call cudaGetLastError() to clear the error bit -+ result = cudaGetLastError(); -+ CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " -+ << cudaGetErrorString(result)); -+ return 0; -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ if (result != cudaSuccess) { -+ // Call cudaGetLastError() to clear the error bit -+ result = cudaGetLastError(); -+ CUTLASS_TRACE_HOST(" cudaGetDeviceProperties() returned error " -+ << cudaGetErrorString(result)); -+ return 0; -+ } -+ -+ bool override_sm_count = (available_sm_count < 0 || available_sm_count > properties.multiProcessorCount); -+ if (override_sm_count) { -+ available_sm_count = properties.multiProcessorCount; -+ } -+ -+ int max_active_blocks = maximum_active_blocks(); -+ if (max_active_blocks <= 0) { -+ return 0; -+ } -+ -+ int occupancy_based_block_count = available_sm_count * max_active_blocks; -+ -+ if (problem_sizes_ptr == nullptr || problem_count == 0) { -+ return occupancy_based_block_count; -+ } -+ -+ int total_tiles = group_tile_count(problem_sizes_ptr, problem_count); -+ -+ // If the group contains a single problem, launching the exact number of -+ // threadblocks needed to cover the problem minimizes the work performed -+ // per threadblock in finding the next tile to compute. We return total_tiles -+ // unless the user has provided the SM count. -+ if (problem_count == 1 && override_sm_count) { -+ return total_tiles; -+ } -+ -+ // Choose between the full wave of threadblocks and the tile count. If there -+ // are fewer tiles in the group than threadblocks in the full wave, only -+ // some threadblocks will be assigned tiles. Those threadblocks -+ // which are not assigned tiles still need to perform the work of iterating through -+ // problem sizes to determine that they have no work to do. This competes for cycles -+ // with those threadblocks that are assigned tiles to compute. -+ return min(total_tiles, occupancy_based_block_count); -+ } -+ -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ CUTLASS_TRACE_HOST("BaseGrouped::initialize() - workspace " -+ << workspace << ", stream: " << (stream ? "non-null" : "null")); -+ -+ // Workspace -+ size_t workspace_bytes = get_workspace_size(args); -+ -+ if (workspace_bytes && !workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) { -+ int32_t tile_count = group_tile_count(args); -+ Status status = precompute(args, tile_count, workspace); -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ params_ = typename BaseKernel::Params(args, workspace, tile_count); -+ } else { -+ params_ = typename BaseKernel::Params(args, workspace); -+ } -+ -+ // Specify shared memory capacity for kernel. -+ int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); -+ -+ if (smem_size >= (48 << 10)) { -+ cudaError_t result = cudaFuncSetAttribute(Kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ size_t workspace_bytes = get_workspace_size(args); -+ -+ if (workspace_bytes && !workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) { -+ int32_t tile_count = group_tile_count(args); -+ Status status = precompute(args, tile_count, workspace); -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ params_.update(args, workspace, tile_count); -+ } else { -+ params_.update(args, workspace); -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ // -+ // Configure grid and block dimensions -+ // -+ -+ if (!params_.problem_visitor.problem_count) { -+ return Status::kSuccess; -+ } -+ -+ dim3 grid(params_.threadblock_count, 1, 1); -+ dim3 block(BaseKernel::kThreadCount, 1, 1); -+ -+ int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); -+ -+ // -+ // Launch kernel -+ // -+ -+ // Launch -+ cutlass::Kernel<<>>(params_); -+ -+ // -+ // Query for errors -+ // -+ cudaError_t result = cudaGetLastError(); -+ -+ if (result != cudaSuccess) { -+ // Call cudaGetLastError() to clear the error bit -+ result = cudaGetLastError(); -+ CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); -+ return Status::kErrorInternal; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Initializes and runs the kernel. -+ Status operator()( -+ Arguments const &args, -+ void *workspace, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/default_gemm_configuration.h b/3rdparty/cutlass/include/cutlass/gemm/device/default_gemm_configuration.h -new file mode 100644 -index 0000000..46ef274 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/default_gemm_configuration.h -@@ -0,0 +1,818 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Definitions for GEMM structures -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/arch/mma.h" -+#include "cutlass/arch/wmma.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/epilogue/thread/linear_combination_clamp.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename OperatorClass, -+ typename ArchTag, -+ typename ElementA, -+ typename ElementB, -+ typename ElementC, -+ typename ElementAccumulator -+> -+struct DefaultGemmConfiguration; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ArchTag, -+ typename ElementA, -+ typename ElementB, -+ typename ElementC, -+ typename ElementAccumulator> -+struct DefaultGemmConfiguration< -+ arch::OpClassSimt, -+ ArchTag, -+ ElementA, -+ ElementB, -+ ElementC, -+ ElementAccumulator> { -+ -+ static int const kAlignmentA = 1; -+ static int const kAlignmentB = 1; -+ using ThreadblockShape = GemmShape<128, 128, 8>; -+ using WarpShape = GemmShape<32, 64, 8>; -+ using InstructionShape = GemmShape<1, 1, 1>; -+ static int const kStages = 2; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >; -+ -+ using Operator = arch::OpMultiplyAdd; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ArchTag, -+ typename ElementC> -+struct DefaultGemmConfiguration { -+ -+ static int const kAlignmentA = 4; -+ static int const kAlignmentB = 4; -+ using ThreadblockShape = GemmShape<128, 128, 32>; -+ using WarpShape = GemmShape<32, 64, 32>; -+ using InstructionShape = GemmShape<1, 1, 4>; -+ static int const kStages = 2; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 1, -+ int32_t, -+ float -+ >; -+ -+ using Operator = arch::OpMultiplyAdd; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ArchTag, -+ typename ElementA, -+ typename ElementB, -+ typename ElementC, -+ typename ElementAccumulator> -+struct DefaultGemmConfiguration< -+ arch::OpClassWmmaTensorOp, -+ ArchTag, -+ ElementA, -+ ElementB, -+ ElementC, -+ ElementAccumulator> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ static int const kStages = 2; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >; -+ -+ using Operator = arch::OpMultiplyAdd; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA, -+ typename ElementB, -+ typename ElementC, -+ typename ElementAccumulator> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm70, -+ ElementA, -+ ElementB, -+ ElementC, -+ ElementAccumulator> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 32>; -+ using WarpShape = GemmShape<64, 64, 32>; -+ using InstructionShape = GemmShape<8, 8, 4>; -+ static int const kStages = 2; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >; -+ -+ using Operator = arch::OpMultiplyAdd; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA, -+ typename ElementB, -+ typename ElementC, -+ typename ElementAccumulator> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm75, -+ ElementA, -+ ElementB, -+ ElementC, -+ ElementAccumulator> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ using ThreadblockShape = GemmShape<128, 256, 32>; -+ using WarpShape = GemmShape<64, 64, 32>; -+ using InstructionShape = GemmShape<16, 8, 8>; -+ static int const kStages = 2; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >; -+ -+ using Operator = typename platform::conditional< -+ (platform::is_same::value || -+ platform::is_same::value || -+ platform::is_same::value || -+ platform::is_same::value), -+ arch::OpMultiplyAddSaturate, arch::OpMultiplyAdd>::type; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementC> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm75, -+ int8_t, -+ int8_t, -+ ElementC, -+ int32_t> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 64>; -+ using WarpShape = GemmShape<64, 64, 64>; -+ using InstructionShape = GemmShape<8, 8, 16>; -+ static int const kStages = 2; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< -+ ElementC, 128 / sizeof_bits::value, int32_t, float>; -+ -+ using Operator = arch::OpMultiplyAddSaturate; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementC> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm75, -+ int8_t, -+ uint8_t, -+ ElementC, -+ int32_t> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 64>; -+ using WarpShape = GemmShape<64, 64, 64>; -+ using InstructionShape = GemmShape<8, 8, 16>; -+ static int const kStages = 2; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< -+ ElementC, 128 / sizeof_bits::value, int32_t, float>; -+ -+ using Operator = arch::OpMultiplyAddSaturate; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementC> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm75, -+ uint8_t, -+ int8_t, -+ ElementC, -+ int32_t> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 64>; -+ using WarpShape = GemmShape<64, 64, 64>; -+ using InstructionShape = GemmShape<8, 8, 16>; -+ static int const kStages = 2; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< -+ ElementC, 128 / sizeof_bits::value, int32_t, float>; -+ -+ using Operator = arch::OpMultiplyAddSaturate; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementC> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm75, -+ uint8_t, -+ uint8_t, -+ ElementC, -+ int32_t> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 64>; -+ using WarpShape = GemmShape<64, 64, 64>; -+ using InstructionShape = GemmShape<8, 8, 16>; -+ static int const kStages = 2; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< -+ ElementC, 128 / sizeof_bits::value, int32_t, float>; -+ -+ using Operator = arch::OpMultiplyAddSaturate; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementC> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm75, -+ int4b_t, -+ int4b_t, -+ ElementC, -+ int32_t> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 128>; -+ using WarpShape = GemmShape<64, 64, 128>; -+ using InstructionShape = GemmShape<8, 8, 32>; -+ static int const kStages = 2; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< -+ ElementC, 128 / sizeof_bits::value, int32_t, float>; -+ -+ using Operator = arch::OpMultiplyAddSaturate; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementC> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm75, -+ int4b_t, -+ uint4b_t, -+ ElementC, -+ int32_t> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 128>; -+ using WarpShape = GemmShape<64, 64, 128>; -+ using InstructionShape = GemmShape<8, 8, 32>; -+ static int const kStages = 2; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< -+ ElementC, 128 / sizeof_bits::value, int32_t, float>; -+ -+ using Operator = arch::OpMultiplyAddSaturate; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementC> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm75, -+ uint4b_t, -+ int4b_t, -+ ElementC, -+ int32_t> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 128>; -+ using WarpShape = GemmShape<64, 64, 128>; -+ using InstructionShape = GemmShape<8, 8, 32>; -+ static int const kStages = 2; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< -+ ElementC, 128 / sizeof_bits::value, int32_t, float>; -+ -+ using Operator = arch::OpMultiplyAddSaturate; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementC> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm75, -+ uint4b_t, -+ uint4b_t, -+ ElementC, -+ int32_t> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 128>; -+ using WarpShape = GemmShape<64, 64, 128>; -+ using InstructionShape = GemmShape<8, 8, 32>; -+ static int const kStages = 2; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< -+ ElementC, 128 / sizeof_bits::value, int32_t, float>; -+ -+ using Operator = arch::OpMultiplyAddSaturate; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementC> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm75, -+ uint1b_t, -+ uint1b_t, -+ ElementC, -+ int32_t> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 512>; -+ using WarpShape = GemmShape<64, 64, 512>; -+ using InstructionShape = GemmShape<8, 8, 128>; -+ static int const kStages = 2; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< -+ ElementC, 128 / sizeof_bits::value, int32_t, float>; -+ -+ using Operator = arch::OpXorPopc; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct DefaultGemmConfiguration { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 64>; -+ using WarpShape = GemmShape<64, 64, 64>; -+ using InstructionShape = GemmShape<16, 8, 16>; -+ static int const kStages = 3; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombination< -+ ElementC, 128 / sizeof_bits::value, ElementAccumulator, -+ ElementAccumulator>; -+ -+ using Operator = typename platform::conditional< -+ (platform::is_same::value || -+ platform::is_same::value || -+ platform::is_same::value || -+ platform::is_same::value), -+ arch::OpMultiplyAddSaturate, arch::OpMultiplyAdd>::type; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+template -+struct DefaultGemmConfiguration { -+ -+ static int const kAlignmentA = 1; -+ static int const kAlignmentB = 1; -+ -+ using ThreadblockShape = GemmShape<128, 256, 64>; -+ using WarpShape = GemmShape<64, 64, 64>; -+ using InstructionShape = GemmShape<16, 8, 16>; -+ static int const kStages = 3; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombination< -+ ElementC, 128 / sizeof_bits::value, ElementAccumulator, -+ ElementAccumulator>; -+ -+ using Operator = arch::OpMultiplyAdd; -+}; -+ -+ -+template <> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm80, -+ complex, -+ complex, -+ complex, -+ complex -+ > { -+ -+ static int const kAlignmentA = 1; -+ static int const kAlignmentB = 1; -+ -+ using ThreadblockShape = GemmShape<64, 64, 16>; -+ using WarpShape = GemmShape<32, 32, 16>; -+ using InstructionShape = GemmShape<8, 8, 4>; -+ static int const kStages = 3; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombination< -+ complex, 1, complex, -+ complex>; -+ -+ using Operator = arch::OpMultiplyAddComplex; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementC> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm80, -+ int8_t, -+ int8_t, -+ ElementC, -+ int32_t> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 64>; -+ using WarpShape = GemmShape<64, 64, 64>; -+ using InstructionShape = GemmShape<16, 8, 32>; -+ static int const kStages = 3; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< -+ ElementC, 128 / sizeof_bits::value, int32_t, float>; -+ -+ using Operator = arch::OpMultiplyAddSaturate; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementC> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm80, -+ int8_t, -+ uint8_t, -+ ElementC, -+ int32_t> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 64>; -+ using WarpShape = GemmShape<64, 64, 64>; -+ using InstructionShape = GemmShape<16, 8, 32>; -+ static int const kStages = 3; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< -+ ElementC, 128 / sizeof_bits::value, int32_t, float>; -+ -+ using Operator = arch::OpMultiplyAddSaturate; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementC> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm80, -+ uint8_t, -+ int8_t, -+ ElementC, -+ int32_t> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 64>; -+ using WarpShape = GemmShape<64, 64, 64>; -+ using InstructionShape = GemmShape<16, 8, 32>; -+ static int const kStages = 3; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< -+ ElementC, 128 / sizeof_bits::value, int32_t, float>; -+ -+ using Operator = arch::OpMultiplyAddSaturate; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementC> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm80, -+ uint8_t, -+ uint8_t, -+ ElementC, -+ int32_t> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 64>; -+ using WarpShape = GemmShape<64, 64, 64>; -+ using InstructionShape = GemmShape<16, 8, 32>; -+ static int const kStages = 3; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< -+ ElementC, 128 / sizeof_bits::value, int32_t, float>; -+ -+ using Operator = arch::OpMultiplyAddSaturate; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementC> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm80, -+ int4b_t, -+ int4b_t, -+ ElementC, -+ int32_t> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 128>; -+ using WarpShape = GemmShape<64, 64, 128>; -+ using InstructionShape = GemmShape<16, 8, 64>; -+ static int const kStages = 3; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< -+ ElementC, 128 / sizeof_bits::value, int32_t, float>; -+ -+ using Operator = arch::OpMultiplyAddSaturate; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementC> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm80, -+ int4b_t, -+ uint4b_t, -+ ElementC, -+ int32_t> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 128>; -+ using WarpShape = GemmShape<64, 64, 128>; -+ using InstructionShape = GemmShape<16, 8, 64>; -+ static int const kStages = 3; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< -+ ElementC, 128 / sizeof_bits::value, int32_t, float>; -+ -+ using Operator = arch::OpMultiplyAddSaturate; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementC> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm80, -+ uint4b_t, -+ int4b_t, -+ ElementC, -+ int32_t> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 128>; -+ using WarpShape = GemmShape<64, 64, 128>; -+ using InstructionShape = GemmShape<16, 8, 64>; -+ static int const kStages = 3; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< -+ ElementC, 128 / sizeof_bits::value, int32_t, float>; -+ -+ using Operator = arch::OpMultiplyAddSaturate; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementC> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm80, -+ uint4b_t, -+ uint4b_t, -+ ElementC, -+ int32_t> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 128>; -+ using WarpShape = GemmShape<64, 64, 128>; -+ using InstructionShape = GemmShape<16, 8, 64>; -+ static int const kStages = 3; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< -+ ElementC, 128 / sizeof_bits::value, int32_t, float>; -+ -+ using Operator = arch::OpMultiplyAddSaturate; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementC> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm80, -+ uint1b_t, -+ uint1b_t, -+ ElementC, -+ int32_t> { -+ -+ static int const kAlignmentA = 128 / sizeof_bits::value; -+ static int const kAlignmentB = 128 / sizeof_bits::value; -+ -+ using ThreadblockShape = GemmShape<128, 256, 512>; -+ using WarpShape = GemmShape<64, 64, 512>; -+ using InstructionShape = GemmShape<16, 8, 256>; -+ static int const kStages = 3; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< -+ ElementC, 128 / sizeof_bits::value, int32_t, float>; -+ -+ using Operator = arch::OpMultiplyAdd; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct DefaultGemmConfiguration { -+ -+ static int const kAlignmentA = 1; -+ static int const kAlignmentB = 1; -+ -+ using ThreadblockShape = GemmShape<128, 256, 64>; -+ using WarpShape = GemmShape<64, 64, 64>; -+ using InstructionShape = GemmShape<16, 8, 4>; -+ static int const kStages = 3; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombination< -+ ElementC, 128 / sizeof_bits::value, ElementAccumulator, -+ ElementAccumulator>; -+ -+ using Operator = arch::OpMultiplyAdd; -+}; -+ -+template <> -+struct DefaultGemmConfiguration< -+ arch::OpClassTensorOp, -+ arch::Sm90, -+ complex, -+ complex, -+ complex, -+ complex -+ > { -+ -+ static int const kAlignmentA = 1; -+ static int const kAlignmentB = 1; -+ -+ using ThreadblockShape = GemmShape<64, 64, 16>; -+ using WarpShape = GemmShape<32, 32, 16>; -+ using InstructionShape = GemmShape<16, 8, 4>; -+ static int const kStages = 3; -+ -+ using EpilogueOutputOp = epilogue::thread::LinearCombination< -+ complex, 1, complex, -+ complex>; -+ -+ using Operator = arch::OpMultiplyAddComplex; -+}; -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/ell_gemm.h b/3rdparty/cutlass/include/cutlass/gemm/device/ell_gemm.h -new file mode 100644 -index 0000000..d8698a7 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/ell_gemm.h -@@ -0,0 +1,848 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a Block-Ell sparse gemm kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/gemm/kernel/ell_gemm.h" -+ -+#include "cutlass/gemm/kernel/default_ell_gemm.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/*! Blocked-Ell sparse gemm device-level operator. This is an interface to efficient CUTLASS -+ Blocked-Ell kernels that may be invoked from host code. -+ -+ The contributions of this class are: -+ -+ 1. At compile time, it maps data types and high-level structural parameters onto -+ specific CUTLASS components. -+ -+ 2. At runtime, it maps logical arguments to Blocked-Ell problems to kernel parameters. -+ -+ 3. At runtime, it launches kernels on the device. -+ -+ Example of a CUTLASS EllGemm operator is as follows: -+ -+ // -+ // Instantiate the CUTLASS EllGemm operator. -+ // -+ -+ cutlass::gemm::device::EllGemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ cutlass::half_t, 128 / cutlass::sizeof_bits::value, -+ float, float>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 4, // Stages -+ 128 / cutlass::sizeof_bits::value, // Alignment A -+ 128 / cutlass::sizeof_bits::value // Alignment B -+ > ellgemm_op; -+ -+ // -+ // Launch the EllGemm operation on the device -+ // -+ -+ Description of parameters and tensors used to represent the Blocked-Ellpack (ELL) format: -+ a_rows - Rows in the sparse matrix. -+ a_cols - Colums in the sparse matrix. -+ BlockedEllA - Packed matrix (ellValue matrix) that stores non-zero values in -+ consecutive blocks, whose size is (a_rows * a_ell_num_columns) -+ ell_idx - Blocked-ELL Column indices (ellColInd) matrix, whose size is -+ (a_rows / a_ell_blocksize) * (a_ell_num_columns / a_ell_blocksize) -+ a_ell_blocksize - Size of the ELL-Blocks. -+ a_ell_num_columns - Number of columns in the Blocked-Ellpack format (ellValue columns) -+ B - Input dense matrix whose size is (a_cols * n) -+ C/D - Output dense matrix whose size is (a_rows * n) -+ -+ cutlass::Status status = ellgemm_op({ -+ {a_rows, n, a_cols}, // GemmCoord problem_size -+ {BlockedEllA, lda}, // TensorRef ref_BlockedEllA -+ {B, ldb}, // TensorRef ref_B, -+ {C, ldc}, // TensorRef ref_C, -+ {D, ldd}, // TensorRef ref_D, -+ ell_idx, // Blocked-ELL Column indices or ellColInd matrix (const int*) -+ a_ell_num_columns, // Columns in the Blocked-Ellpack (ellValue) matrix (int) -+ a_ell_blocksize, // Size of the ELL-Blocks (int) -+ a_ell_base, // Base index of ellColInd (int) - Zero or One -+ {alpha, beta} // EpilogueOutputOp::Params epilogue_op_params -+ }); -+ -+ A simplified view of the template is listed below. -+ -+ template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ -+ /// Element type for B matrix operand -+ typename ElementB, -+ -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ -+ /// Operator class tag -+ typename OperatorClass, -+ -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. -+ typename ArchTag, -+ -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ -+ /// Number of stages used in the pipelined mainloop -+ int Stages -+ -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA, -+ -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB, -+ -+ /// Supports split-K with serial reduction -+ bool SplitKSerial, -+ -+ /// Operation performed by GEMM -+ typename Operator, -+ -+ /// Sparse matrix is A or not -+ bool IsASparse -+ > -+ class EllGemm; -+*/ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_ = ElementC_, -+ /// Operator class tag -+ typename OperatorClass_ = arch::OpClassTensorOp, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_ = arch::Sm80, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_ = -+ typename threadblock::GemmIdentityThreadblockSwizzle<>, -+ /// Number of stages used in the pipelined mainloop -+ int Stages = -+ DefaultGemmConfiguration::kStages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = -+ DefaultGemmConfiguration::kAlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = -+ DefaultGemmConfiguration::kAlignmentB, -+ /// If true, kernel supports split-K with serial reduction -+ bool SplitKSerial = false, -+ /// Operation performed by GEMM -+ typename Operator_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::Operator, -+ /// Sparse matrix is A or not -+ bool IsASparse = true -+ > -+class EllGemm { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ static bool const kSplitKSerial = SplitKSerial; -+ static ComplexTransform const kTransformA = ComplexTransform::kNone; -+ static ComplexTransform const kTransformB = ComplexTransform::kNone; -+ static bool const kIsASparse = IsASparse; -+ -+ /// Define the kernel -+ using GemmKernel = typename kernel::DefaultEllGemm< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ kStages, -+ kSplitKSerial, -+ Operator, -+ kIsASparse -+ >::GemmKernel; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmCoord problem_size; -+ TensorRef ref_A; -+ TensorRef ref_B; -+ TensorRef ref_C; -+ TensorRef ref_D; -+ const int* ell_idx; -+ int ell_ncol; -+ int ell_blocksize; -+ int ell_base_idx; -+ typename EpilogueOutputOp::Params epilogue; -+ int split_k_slices; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments(): problem_size(0, 0, 0), split_k_slices(1) { -+ -+ } -+ -+ /// Constructs an Arguments structure -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ GemmCoord problem_size_, -+ TensorRef ref_A_, -+ TensorRef ref_B_, -+ TensorRef ref_C_, -+ TensorRef ref_D_, -+ const int* ell_idx_, -+ int ell_ncol_, -+ int ell_blocksize_, -+ int ell_base_idx_, -+ typename EpilogueOutputOp::Params epilogue_ = -+ typename EpilogueOutputOp::Params(), -+ int split_k_slices = 1 -+ ): -+ problem_size(problem_size_), -+ ref_A(ref_A_), -+ ref_B(ref_B_), -+ ref_C(ref_C_), -+ ref_D(ref_D_), -+ ell_idx(ell_idx_), -+ ell_ncol(ell_ncol_), -+ ell_blocksize(ell_blocksize_), -+ ell_base_idx(ell_base_idx_), -+ epilogue(epilogue_), -+ split_k_slices(split_k_slices) { -+ -+ } -+ }; -+ -+private: -+ -+ /// Kernel parameters object -+ typename GemmKernel::Params params_; -+ -+public: -+ -+ /// Constructs the GEMM. -+ EllGemm() { } -+ -+ /// Determines whether the GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ if (!kSplitKSerial && args.split_k_slices > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ Status status = GemmKernel::can_implement( -+ args.problem_size, -+ args.ref_A.non_const_ref(), -+ args.ref_B.non_const_ref(), -+ args.ref_C.non_const_ref(), -+ args.ref_D -+ ); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ size_t bytes = 0; -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {args.ell_blocksize, -+ ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.split_k_slices); -+ -+ tiled_shape.m() *= (args.ell_blocksize + ThreadblockShape::kM - 1 ) / ThreadblockShape::kM; -+ -+ if (kSplitKSerial && args.split_k_slices > 1) { -+ -+ bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); -+ } -+ -+ return bytes; -+ } -+ -+ Status set(Arguments const &args, cutlass::gemm::GemmCoord const &grid_shape, void *workspace){ -+ // Initialize the Params structure -+ params_ = typename GemmKernel::Params{ -+ args.problem_size, -+ grid_shape, -+ args.ref_A.non_const_ref(), -+ args.ref_B.non_const_ref(), -+ args.ref_C.non_const_ref(), -+ args.ref_D, -+ args.ell_idx, -+ args.ell_ncol, -+ args.ell_blocksize, -+ args.ell_base_idx, -+ args.epilogue, -+ static_cast(workspace) -+ }; -+ return Status::kSuccess; -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {args.ell_blocksize, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.split_k_slices); -+ -+ grid_shape.m() *= (args.ell_blocksize + ThreadblockShape::kM - 1 ) / ThreadblockShape::kM; -+ -+ if (kSplitKSerial) { -+ if (args.split_k_slices > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ size_t bytes = get_workspace_size(args); -+ -+ cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ } -+ else { -+ -+ if (args.split_k_slices > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ return set(args, grid_shape, workspace); -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ if (kSplitKSerial && args.split_k_slices > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ } -+ -+ params_.ref_A.reset(args.ref_A.non_const_ref().data()); -+ params_.ref_B.reset(args.ref_B.non_const_ref().data()); -+ params_.ref_C.reset(args.ref_C.non_const_ref().data()); -+ params_.ref_D.reset(args.ref_D.data()); -+ params_.output_op = args.epilogue; -+ params_.semaphore = static_cast(workspace); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); -+ dim3 block(GemmKernel::kThreadCount, 1, 1); -+ -+ cudaError_t result; -+ -+ int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); -+ -+ if (smem_size >= (48 << 10)) { -+ result = cudaFuncSetAttribute(Kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ cutlass::Kernel<<>>(params_); -+ -+ result = cudaGetLastError(); -+ -+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Parital specialization for column-major output exchanges problem size and operand. -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB, -+ /// If true, kernel supports split-K as a serial reduction -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator_, -+ /// Sparse matrix is A or not -+ bool IsASparse> -+class EllGemm { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ using ElementC = ElementC_; -+ using LayoutC = layout::ColumnMajor; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static ComplexTransform const kTransformA = ComplexTransform::kNone; -+ static ComplexTransform const kTransformB = ComplexTransform::kNone; -+ static bool const kSplitKSerial = SplitKSerial; -+ static bool const kIsASparse = false; -+ -+ using UnderlyingOperator = EllGemm< -+ ElementB, -+ typename layout::LayoutTranspose::type, -+ ElementA, -+ typename layout::LayoutTranspose::type, -+ ElementC, -+ layout::RowMajor, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ kAlignmentB, -+ kAlignmentA, -+ SplitKSerial, -+ Operator, -+ kIsASparse -+ >; -+ -+ using UnderlyingArguments = typename UnderlyingOperator::Arguments; -+ using GemmKernel = typename UnderlyingOperator::GemmKernel; -+ static int const kAlignmentC = UnderlyingOperator::kAlignmentC; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmCoord problem_size; -+ TensorRef ref_A; -+ TensorRef ref_B; -+ TensorRef ref_C; -+ TensorRef ref_D; -+ const int* ell_idx; -+ int ell_ncol; -+ int ell_blocksize; -+ int ell_base_idx; -+ typename EpilogueOutputOp::Params epilogue; -+ int split_k_slices; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments() { } -+ -+ /// Constructs an Arguments structure -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ GemmCoord problem_size_, -+ TensorRef ref_A_, -+ TensorRef ref_B_, -+ TensorRef ref_C_, -+ TensorRef ref_D_, -+ const int* ell_idx_, -+ int ell_ncol_, -+ int ell_blocksize_, -+ int ell_base_idx_, -+ typename EpilogueOutputOp::Params epilogue_ = -+ typename EpilogueOutputOp::Params(), -+ int split_k_slices = 1 -+ ): -+ problem_size(problem_size_), -+ ref_A(ref_A_), -+ ref_B(ref_B_), -+ ref_C(ref_C_), -+ ref_D(ref_D_), -+ ell_idx(ell_idx_), -+ ell_ncol(ell_ncol_), -+ ell_blocksize(ell_blocksize_), -+ ell_base_idx(ell_base_idx_), -+ epilogue(epilogue_), -+ split_k_slices(split_k_slices) { } -+ }; -+ -+private: -+ -+ UnderlyingOperator underlying_operator_; -+ -+public: -+ -+ /// Constructs the GEMM. -+ EllGemm() { } -+ -+ /// Helper to construct a transposed equivalent for the underying GEMM operator -+ static UnderlyingArguments to_underlying_arguments(Arguments const &args) { -+ return UnderlyingArguments( -+ {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()}, -+ {args.ref_B.data(), args.ref_B.stride(0)}, -+ {args.ref_A.data(), args.ref_A.stride(0)}, -+ {args.ref_C.data(), args.ref_C.stride(0)}, -+ {args.ref_D.data(), args.ref_D.stride(0)}, -+ args.ell_idx, -+ args.ell_ncol, -+ args.ell_blocksize, -+ args.ell_base_idx, -+ args.epilogue, -+ args.split_k_slices -+ ); -+ } -+ -+ /// Determines whether the GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ return UnderlyingOperator::can_implement(to_underlying_arguments(args)); -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ size_t bytes = 0; -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, args.ell_blocksize, ThreadblockShape::kK}, -+ args.split_k_slices); -+ -+ tiled_shape.n() *= (args.ell_blocksize + ThreadblockShape::kN - 1 ) / ThreadblockShape::kN; -+ -+ if (kSplitKSerial && args.split_k_slices > 1) { -+ -+ bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); -+ } -+ -+ return bytes; -+ } -+ -+ Status set(Arguments const &args, cutlass::gemm::GemmCoord const &grid_shape, void *workspace){ -+ // Initialize the Params structure -+ return underlying_operator_.set(to_underlying_arguments(args), grid_shape, workspace); -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( -+ {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()}, -+ {ThreadblockShape::kM, args.ell_blocksize, ThreadblockShape::kK}, -+ args.split_k_slices); -+ -+ grid_shape.n() *= (args.ell_blocksize + ThreadblockShape::kN - 1 ) / ThreadblockShape::kN; -+ -+ if (kSplitKSerial) { -+ if (args.split_k_slices > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ size_t bytes = get_workspace_size(args); -+ -+ cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ } -+ else { -+ -+ if (args.split_k_slices > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ // Initialize the Params structure -+ set(args, grid_shape, workspace); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ return underlying_operator_.update(to_underlying_arguments(args), workspace); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/gemm.h b/3rdparty/cutlass/include/cutlass/gemm/device/gemm.h -new file mode 100644 -index 0000000..68fa29b ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/gemm.h -@@ -0,0 +1,771 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/gemm/kernel/gemm.h" -+ -+#include "cutlass/gemm/kernel/default_gemm.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+ -+#include "cutlass/layout/permute.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/*! Gemm device-level operator. This is an interface to efficient CUTLASS GEMM kernels that may -+ be invoked from host code. -+ -+ The contributions of this class are: -+ -+ 1. At compile time, it maps data types and high-level structural parameters onto -+ specific CUTLASS components. -+ -+ 2. At runtime, it maps logical arguments to GEMM problems to kernel parameters. -+ -+ 3. At runtime, it launches kernels on the device. -+ -+ The intent is to provide a convenient mechanism for interacting with most plausible GEMM -+ configurations for each supported architecture. Consequently, not all parameters are exposed -+ to the top-level interface. Rather, sensible defaults at each level of the CUTLASS hierarchy -+ are selected to tradeoff simplicity of the interface with flexibility. We expect -+ most configurations to be specified at this level. Applications with more exotic requirements -+ may construct their kernels of interest using CUTLASS components at the threadblock, warp, -+ and thread levels of abstraction. -+ -+ CUTLASS exposes computations using the functor design pattern in which objects compose some -+ internal state with an overloaded function call operator. This enables decoupling of -+ initialization from execution, possibly reducing overhead during steady state phases of -+ application execution. -+ -+ CUTLASS device-level operators expose an Arguments structure encompassing each logical -+ input to the computation. This is distinct from the kernel-level Params structure pattern -+ which contains application-specific precomputed state needed by the device code. -+ -+ Example of a CUTLASS GEMM operator implementing the functionality of cuBLAS's SGEMM NN -+ is as follows: -+ -+ // -+ // Instantiate the CUTLASS GEMM operator. -+ // -+ -+ cutlass::gemm::device::Gemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ > gemm_op; -+ -+ // -+ // Launch the GEMM operation on the device -+ // -+ -+ cutlass::Status status = gemm_op({ -+ {m, n, k}, // GemmCoord problem_size, -+ {A, lda}, // TensorRef ref_A, -+ {B, ldb}, // TensorRef ref_B, -+ {C, ldc}, // TensorRef ref_C, -+ {D, ldd}, // TensorRef ref_D, -+ {alpha, beta} // EpilogueOutputOp::Params epilogue_op_params -+ }); -+ -+ -+ A simplified view of the template is listed below. -+ -+ template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ -+ /// Element type for B matrix operand -+ typename ElementB, -+ -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ -+ /// Operator class tag -+ typename OperatorClass, -+ -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. -+ typename ArchTag, -+ -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ -+ /// Number of stages used in the pipelined mainloop -+ int Stages -+ > -+ class Gemm; -+*/ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_ = ElementC_, -+ /// Operator class tag -+ typename OperatorClass_ = arch::OpClassSimt, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_ = arch::Sm70, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_ = -+ typename threadblock::GemmIdentityThreadblockSwizzle<>, -+ /// Number of stages used in the pipelined mainloop -+ int Stages = -+ DefaultGemmConfiguration::kStages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = -+ DefaultGemmConfiguration::kAlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = -+ DefaultGemmConfiguration::kAlignmentB, -+ /// If true, kernel supports split-K with serial reduction -+ bool SplitKSerial = false, -+ /// Operation performed by GEMM -+ typename Operator_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::Operator, -+ /// Gather operand A by using an index array -+ bool GatherA = false, -+ /// Gather operand B by using an index array -+ bool GatherB = false, -+ /// Scatter result D by using an index array -+ bool ScatterD = false, -+ /// Permute result D -+ typename PermuteDLayout = layout::NoPermute> -+class Gemm { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ static bool const kSplitKSerial = SplitKSerial; -+ static ComplexTransform const kTransformA = ComplexTransform::kNone; -+ static ComplexTransform const kTransformB = ComplexTransform::kNone; -+ -+ /// Define the kernel -+ using GemmKernel = typename kernel::DefaultGemm< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ kStages, -+ kSplitKSerial, -+ Operator, -+ SharedMemoryClearOption::kNone, -+ GatherA, -+ GatherB, -+ ScatterD, -+ PermuteDLayout -+ >::GemmKernel; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmCoord problem_size; -+ TensorRef ref_A; -+ TensorRef ref_B; -+ TensorRef ref_C; -+ TensorRef ref_D; -+ typename EpilogueOutputOp::Params epilogue; -+ int split_k_slices; -+ // For gather+scatter operations -+ int const *gather_A_indices; -+ int const *gather_B_indices; -+ int const *scatter_D_indices; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments(): problem_size(0, 0, 0), split_k_slices(1) { -+ -+ } -+ -+ /// Constructs an Arguments structure -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ GemmCoord problem_size_, -+ TensorRef ref_A_, -+ TensorRef ref_B_, -+ TensorRef ref_C_, -+ TensorRef ref_D_, -+ typename EpilogueOutputOp::Params epilogue_ = -+ typename EpilogueOutputOp::Params(), -+ int split_k_slices = 1, -+ int const *gather_A_indices_ = nullptr, -+ int const *gather_B_indices_ = nullptr, -+ int const *scatter_D_indices_ = nullptr -+ ): -+ problem_size(problem_size_), -+ ref_A(ref_A_), -+ ref_B(ref_B_), -+ ref_C(ref_C_), -+ ref_D(ref_D_), -+ epilogue(epilogue_), -+ split_k_slices(split_k_slices), -+ gather_A_indices(gather_A_indices_), -+ gather_B_indices(gather_B_indices_), -+ scatter_D_indices(scatter_D_indices_) { -+ -+ } -+ }; -+ -+private: -+ -+ /// Kernel parameters object -+ typename GemmKernel::Params params_; -+ -+public: -+ -+ /// Constructs the GEMM. -+ Gemm() { } -+ -+ /// Determines whether the GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ if (!kSplitKSerial && args.split_k_slices > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ Status status = GemmKernel::can_implement( -+ args.problem_size, -+ args.ref_A.non_const_ref(), -+ args.ref_B.non_const_ref(), -+ args.ref_C.non_const_ref(), -+ args.ref_D -+ ); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ size_t bytes = 0; -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.split_k_slices); -+ -+ if (kSplitKSerial && args.split_k_slices > 1) { -+ -+ bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); -+ } -+ -+ return bytes; -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.split_k_slices); -+ -+ if (kSplitKSerial) { -+ if (args.split_k_slices > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ size_t bytes = get_workspace_size(args); -+ -+ cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ } -+ else { -+ -+ if (args.split_k_slices > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ // Initialize the Params structure -+ params_ = typename GemmKernel::Params{ -+ args.problem_size, -+ grid_shape, -+ args.ref_A.non_const_ref(), -+ args.ref_B.non_const_ref(), -+ args.ref_C.non_const_ref(), -+ args.ref_D, -+ args.epilogue, -+ static_cast(workspace), -+ args.gather_A_indices, -+ args.gather_B_indices, -+ args.scatter_D_indices -+ }; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ if (kSplitKSerial && args.split_k_slices > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ } -+ -+ params_.ref_A.reset(args.ref_A.non_const_ref().data()); -+ params_.ref_B.reset(args.ref_B.non_const_ref().data()); -+ params_.ref_C.reset(args.ref_C.non_const_ref().data()); -+ params_.ref_D.reset(args.ref_D.data()); -+ params_.output_op = args.epilogue; -+ params_.semaphore = static_cast(workspace); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); -+ dim3 block(GemmKernel::kThreadCount, 1, 1); -+ -+ cudaError_t result; -+ -+ int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); -+ -+ if (smem_size >= (48 << 10)) { -+ result = cudaFuncSetAttribute(Kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ cutlass::Kernel<<>>(params_); -+ -+ result = cudaGetLastError(); -+ -+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Parital specialization for column-major output exchanges problem size and operand. -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB, -+ /// If true, kernel supports split-K as a serial reduction -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator_, -+ /// Gather operand A by using an index array -+ bool GatherA, -+ /// Gather operand B by using an index array -+ bool GatherB, -+ /// Scatter result D by using an index array -+ bool ScatterD, -+ /// Permute result D -+ typename PermuteDLayout -+> -+class Gemm { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ using ElementC = ElementC_; -+ using LayoutC = layout::ColumnMajor; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static ComplexTransform const kTransformA = ComplexTransform::kNone; -+ static ComplexTransform const kTransformB = ComplexTransform::kNone; -+ static bool const kSplitKSerial = SplitKSerial; -+ -+ using UnderlyingOperator = Gemm< -+ ElementB, -+ typename layout::LayoutTranspose::type, -+ ElementA, -+ typename layout::LayoutTranspose::type, -+ ElementC, -+ layout::RowMajor, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ kAlignmentB, -+ kAlignmentA, -+ SplitKSerial, -+ Operator, -+ GatherB, -+ GatherA, -+ ScatterD, -+ PermuteDLayout -+ >; -+ -+ using UnderlyingArguments = typename UnderlyingOperator::Arguments; -+ using GemmKernel = typename UnderlyingOperator::GemmKernel; -+ static int const kAlignmentC = UnderlyingOperator::kAlignmentC; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmCoord problem_size; -+ TensorRef ref_A; -+ TensorRef ref_B; -+ TensorRef ref_C; -+ TensorRef ref_D; -+ typename EpilogueOutputOp::Params epilogue; -+ int split_k_slices; -+ // For gather+scatter operations -+ int *gather_A_indices; -+ int *gather_B_indices; -+ int *scatter_D_indices; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments() { } -+ -+ /// Constructs an Arguments structure -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ GemmCoord problem_size_, -+ TensorRef ref_A_, -+ TensorRef ref_B_, -+ TensorRef ref_C_, -+ TensorRef ref_D_, -+ typename EpilogueOutputOp::Params epilogue_ = -+ typename EpilogueOutputOp::Params(), -+ int split_k_slices = 1, -+ int *gather_A_indices_ = nullptr, -+ int *gather_B_indices_ = nullptr, -+ int *scatter_D_indices_ = nullptr -+ ): -+ problem_size(problem_size_), -+ ref_A(ref_A_), -+ ref_B(ref_B_), -+ ref_C(ref_C_), -+ ref_D(ref_D_), -+ epilogue(epilogue_), -+ split_k_slices(split_k_slices), -+ gather_A_indices(gather_A_indices_), -+ gather_B_indices(gather_B_indices_), -+ scatter_D_indices(scatter_D_indices_) { } -+ }; -+ -+private: -+ -+ UnderlyingOperator underlying_operator_; -+ -+public: -+ -+ /// Constructs the GEMM. -+ Gemm() { } -+ -+ /// Helper to construct a transposed equivalent for the underying GEMM operator -+ static UnderlyingArguments to_underlying_arguments(Arguments const &args) { -+ return UnderlyingArguments( -+ {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()}, -+ {args.ref_B.data(), args.ref_B.stride(0)}, -+ {args.ref_A.data(), args.ref_A.stride(0)}, -+ {args.ref_C.data(), args.ref_C.stride(0)}, -+ {args.ref_D.data(), args.ref_D.stride(0)}, -+ args.epilogue, -+ args.split_k_slices, -+ args.gather_B_indices, -+ args.gather_A_indices, -+ args.scatter_D_indices -+ ); -+ } -+ -+ /// Determines whether the GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ return UnderlyingOperator::can_implement(to_underlying_arguments(args)); -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.initialize(to_underlying_arguments(args), workspace); -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ return underlying_operator_.update(to_underlying_arguments(args), workspace); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/gemm_array.h b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_array.h -new file mode 100644 -index 0000000..dd244f8 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_array.h -@@ -0,0 +1,737 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/gemm/kernel/gemm_array.h" -+ -+#include "cutlass/gemm/kernel/default_gemm.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/*! Gemm device-level operator. This is an interface to efficient CUTLASS GEMM kernels that may -+ be invoked from host code. -+ -+ The contributions of this class are: -+ -+ 1. At compile time, it maps data types and high-level structural parameters onto -+ specific CUTLASS components. -+ -+ 2. At runtime, it maps logical arguments to GEMM problems to kernel parameters. -+ -+ 3. At runtime, it launches kernels on the device. -+ -+ The intent is to provide a convenient mechanism for interacting with most plausible GEMM -+ configurations for each supported architecture. Consequently, not all parameters are exposed -+ to the top-level interface. Rather, sensible defaults at each level of the CUTLASS hierarchy -+ are selected to tradeoff simplicity of the interface with flexibility. We expect -+ most configurations to be specified at this level. Applications with more exotic requirements -+ may construct their kernels of interest using CUTLASS components at the threadblock, warp, -+ and thread levels of abstraction. -+ -+ CUTLASS exposes computations using the functor design pattern in which objects compose some -+ internal state with an overloaded function call operator. This enables decoupling of -+ initialization from execution, possibly reducing overhead during steady state phases of -+ application execution. -+ -+ CUTLASS device-level operators expose an Arguments structure encompassing each logical -+ input to the computation. This is distinct from the kernel-level Params structure pattern -+ which contains application-specific precomputed state needed by the device code. -+ -+ Example of a CUTLASS GEMM operator implementing the functionality of cuBLAS's SGEMM NN -+ is as follows: -+ -+ // -+ // Instantiate the CUTLASS GEMM operator. -+ // -+ -+ cutlass::gemm::device::Gemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ > gemm_op; -+ -+ // -+ // Launch the GEMM operation on the device -+ // -+ -+ cutlass::Status status = gemm_op({ -+ {m, n, k}, // GemmCoord problem_size, -+ {A, lda}, // TensorRef ref_A, -+ {B, ldb}, // TensorRef ref_B, -+ {C, ldc}, // TensorRef ref_C, -+ {D, ldd}, // TensorRef ref_D, -+ {alpha, beta} // EpilogueOutputOp::Params epilogue_op_params -+ }); -+ -+ -+ A simplified view of the template is listed below. -+ -+ template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ -+ /// Element type for B matrix operand -+ typename ElementB, -+ -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ -+ /// Operator class tag -+ typename OperatorClass, -+ -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. -+ typename ArchTag, -+ -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ -+ /// Number of stages used in the pipelined mainloop -+ int Stages -+ > -+ class Gemm; -+*/ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_ = ElementC_, -+ /// Operator class tag -+ typename OperatorClass_ = arch::OpClassSimt, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_ = arch::Sm70, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_ = threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages = -+ DefaultGemmConfiguration::kStages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = -+ DefaultGemmConfiguration::kAlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = -+ DefaultGemmConfiguration::kAlignmentB, -+ /// Operation performed by GEMM -+ typename Operator_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::Operator -+> -+class GemmArray { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ using Operator = Operator_; -+ -+ /// Define the kernel -+ using DefaultGemmKernel = typename kernel::DefaultGemm< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ kStages, -+ false, -+ Operator -+ >::GemmKernel; -+ -+ using GemmKernel = kernel::GemmArray; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmCoord problem_size; -+ -+ ElementA const * const *ptr_A; -+ LayoutA layout_A; -+ -+ ElementB const * const *ptr_B; -+ LayoutB layout_B; -+ -+ ElementC const * const *ptr_C; -+ LayoutC layout_C; -+ -+ ElementC * const * ptr_D; -+ LayoutC layout_D; -+ -+ typename EpilogueOutputOp::Params epilogue; -+ int batch_count; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments() { } -+ -+ /// Constructs an Arguments structure -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ GemmCoord problem_size_, -+ ElementA const * const *ptr_A_, -+ LayoutA layout_A_, -+ ElementB const * const *ptr_B_, -+ LayoutB layout_B_, -+ ElementC const * const *ptr_C_, -+ LayoutC layout_C_, -+ ElementC * const * ptr_D_, -+ LayoutC layout_D_, -+ typename EpilogueOutputOp::Params epilogue_, -+ int batch_count_ -+ ): -+ problem_size(problem_size_), -+ ptr_A(ptr_A_), -+ layout_A(layout_A_), -+ ptr_B(ptr_B_), -+ layout_B(layout_B_), -+ ptr_C(ptr_C_), -+ layout_C(layout_C_), -+ ptr_D(ptr_D_), -+ layout_D(layout_D_), -+ epilogue(epilogue_), -+ batch_count(batch_count_) { } -+ }; -+ -+private: -+ -+ /// Kernel parameters object -+ typename GemmKernel::Params params_; -+ -+public: -+ -+ /// Constructs the GEMM. -+ GemmArray() { } -+ -+ /// Determines whether the GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ if (args.layout_A.stride(0) % kAlignmentA) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (args.layout_B.stride(0) % kAlignmentB) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (args.layout_C.stride(0) % kAlignmentC) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (args.layout_D.stride(0) % kAlignmentC) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ return 0; -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.batch_count); -+ -+ // Initialize the Params structure -+ params_ = typename GemmKernel::Params{ -+ args.problem_size, -+ grid_shape, -+ args.ptr_A, -+ args.layout_A, -+ args.ptr_B, -+ args.layout_B, -+ args.ptr_C, -+ args.layout_C, -+ args.ptr_D, -+ args.layout_D, -+ args.epilogue, -+ args.batch_count -+ }; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ args.batch_count, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}); -+ -+ params_ = typename GemmKernel::Params{ -+ args.problem_size, -+ grid_shape, -+ args.ptr_A, -+ args.layout_A, -+ args.ptr_B, -+ args.layout_B, -+ args.ptr_C, -+ args.layout_C, -+ args.ptr_D, -+ args.layout_D, -+ args.epilogue, -+ args.batch_count -+ }; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); -+ dim3 block(GemmKernel::kThreadCount, 1, 1); -+ -+ cudaError_t result; -+ -+ int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); -+ if (smem_size >= (48 << 10)) { -+ result = cudaFuncSetAttribute(Kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ cutlass::Kernel<<>>(params_); -+ -+ result = cudaGetLastError(); -+ -+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Parital specialization for column-major output exchanges problem size and operand. -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB, -+ typename Operator_ -+> -+class GemmArray< -+ ElementA_, -+ LayoutA_, -+ ElementB_, -+ LayoutB_, -+ ElementC_, -+ layout::ColumnMajor, -+ ElementAccumulator_, -+ OperatorClass_, -+ ArchTag_, -+ ThreadblockShape_, -+ WarpShape_, -+ InstructionShape_, -+ EpilogueOutputOp_, -+ ThreadblockSwizzle_, -+ Stages, -+ AlignmentA, -+ AlignmentB, -+ Operator_ -+> { -+public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ using ElementC = ElementC_; -+ using LayoutC = layout::ColumnMajor; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static int const kStages = Stages; -+ -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ static bool const kSplitKSerial = false; -+ -+ // -+ using UnderlyingOperator = GemmArray< -+ ElementB, -+ typename layout::LayoutTranspose::type, -+ ElementA, -+ typename layout::LayoutTranspose::type, -+ ElementC, -+ layout::RowMajor, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ kAlignmentB, -+ kAlignmentA -+ >; -+ -+ using UnderlyingArguments = typename UnderlyingOperator::Arguments; -+ using GemmKernel = typename UnderlyingOperator::GemmKernel; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmCoord problem_size; -+ -+ ElementA const * const *ptr_A; -+ LayoutA layout_A; -+ -+ ElementB const * const *ptr_B; -+ LayoutB layout_B; -+ -+ ElementC const * const *ptr_C; -+ LayoutC layout_C; -+ -+ ElementC * const * ptr_D; -+ LayoutC layout_D; -+ -+ typename EpilogueOutputOp::Params epilogue; -+ int batch_count; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments() { } -+ -+ /// Constructs an Arguments structure -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ GemmCoord problem_size_, -+ ElementA const * const *ptr_A_, -+ LayoutA layout_A_, -+ ElementB const * const *ptr_B_, -+ LayoutB layout_B_, -+ ElementC const * const *ptr_C_, -+ LayoutC layout_C_, -+ ElementC * const * ptr_D_, -+ LayoutC layout_D_, -+ typename EpilogueOutputOp::Params epilogue_, -+ int batch_count_ -+ ): -+ problem_size(problem_size_), -+ ptr_A(ptr_A_), -+ layout_A(layout_A_), -+ ptr_B(ptr_B_), -+ layout_B(layout_B_), -+ ptr_C(ptr_C_), -+ layout_C(layout_C_), -+ ptr_D(ptr_D_), -+ layout_D(layout_D_), -+ epilogue(epilogue_), -+ batch_count(batch_count_) { } -+ }; -+ -+private: -+ -+ UnderlyingOperator underlying_operator_; -+ -+public: -+ -+ /// Constructs the GEMM. -+ GemmArray() { } -+ -+ /// Helper to construct a transposed equivalent for the underying GEMM operator -+ static UnderlyingArguments to_underlying_arguments(Arguments const &args) { -+ -+ GemmCoord problem_size{ -+ args.problem_size.n(), -+ args.problem_size.m(), -+ args.problem_size.k() -+ }; -+ -+ return UnderlyingArguments( -+ problem_size, -+ args.ptr_B, -+ args.layout_B.stride(), -+ args.ptr_A, -+ args.layout_A.stride(), -+ args.ptr_C, -+ args.layout_C.stride(), -+ args.ptr_D, -+ args.layout_D.stride(), -+ args.epilogue, -+ args.batch_count -+ ); -+ } -+ -+ /// Determines whether the GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ return UnderlyingOperator::can_implement(to_underlying_arguments(args)); -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.initialize(to_underlying_arguments(args), workspace); -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ return underlying_operator_.update(to_underlying_arguments(args), workspace); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/gemm_batched.h b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_batched.h -new file mode 100644 -index 0000000..6f510e9 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_batched.h -@@ -0,0 +1,703 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined batch GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/gemm/kernel/gemm_batched.h" -+ -+#include "cutlass/gemm/kernel/default_gemm.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/*! Gemm device-level operator. This is an interface to efficient CUTLASS GEMM kernels that may -+ be invoked from host code. -+ -+ The contributions of this class are: -+ -+ 1. At compile time, it maps data types and high-level structural parameters onto -+ specific CUTLASS components. -+ -+ 2. At runtime, it maps logical arguments to GEMM problems to kernel parameters. -+ -+ 3. At runtime, it launches kernels on the device. -+ -+ The intent is to provide a convenient mechanism for interacting with most plausible GEMM -+ configurations for each supported architecture. Consequently, not all parameters are exposed -+ to the top-level interface. Rather, sensible defaults at each level of the CUTLASS hierarchy -+ are selected to tradeoff simplicity of the interface with flexibility. We expect -+ most configurations to be specified at this level. Applications with more exotic requirements -+ may construct their kernels of interest using CUTLASS components at the threadblock, warp, -+ and thread levels of abstraction. -+ -+ CUTLASS exposes computations using the functor design pattern in which objects compose some -+ internal state with an overloaded function call operator. This enables decoupling of -+ initialization from execution, possibly reducing overhead during steady state phases of -+ application execution. -+ -+ CUTLASS device-level operators expose an Arguments structure encompassing each logical -+ input to the computation. This is distinct from the kernel-level Params structure pattern -+ which contains application-specific precomputed state needed by the device code. -+ -+ Example of a CUTLASS GEMM operator implementing the functionality of cuBLAS's SGEMM NN -+ is as follows: -+ -+ // -+ // Instantiate the CUTLASS GEMM operator. -+ // -+ -+ cutlass::gemm::device::Gemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ > gemm_op; -+ -+ // -+ // Launch the GEMM operation on the device -+ // -+ -+ cutlass::Status status = gemm_op({ -+ {m, n, k}, // GemmCoord problem_size, -+ {A, lda}, // TensorRef ref_A, -+ {B, ldb}, // TensorRef ref_B, -+ {C, ldc}, // TensorRef ref_C, -+ {D, ldd}, // TensorRef ref_D, -+ {alpha, beta} // EpilogueOutputOp::Params epilogue_op_params -+ }); -+ -+ -+ A simplified view of the template is listed below. -+ -+ template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ -+ /// Element type for B matrix operand -+ typename ElementB, -+ -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ -+ /// Operator class tag -+ typename OperatorClass, -+ -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. -+ typename ArchTag, -+ -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ -+ /// Number of stages used in the pipelined mainloop -+ int Stages -+ > -+ class Gemm; -+*/ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_ = ElementC_, -+ /// Operator class tag -+ typename OperatorClass_ = arch::OpClassSimt, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_ = arch::Sm70, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_ = threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages = -+ DefaultGemmConfiguration::kStages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = -+ DefaultGemmConfiguration::kAlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = -+ DefaultGemmConfiguration::kAlignmentB, -+ /// Operation performed by GEMM -+ typename Operator_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::Operator -+> -+class GemmBatched { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ using Operator = Operator_; -+ -+ /// Define the kernel -+ using DefaultGemmKernel = typename kernel::DefaultGemm< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ kStages, -+ false, -+ Operator -+ >::GemmKernel; -+ -+ using GemmKernel = kernel::GemmBatched; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmCoord problem_size; -+ TensorRef ref_A; -+ int64_t stride_A; -+ TensorRef ref_B; -+ int64_t stride_B; -+ TensorRef ref_C; -+ int64_t stride_C; -+ TensorRef ref_D; -+ int64_t stride_D; -+ typename EpilogueOutputOp::Params epilogue; -+ int batch_count; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments() { } -+ -+ /// Constructs an Arguments structure -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ GemmCoord problem_size_, -+ TensorRef ref_A_, -+ int64_t stride_A_, -+ TensorRef ref_B_, -+ int64_t stride_B_, -+ TensorRef ref_C_, -+ int64_t stride_C_, -+ TensorRef ref_D_, -+ int64_t stride_D_, -+ typename EpilogueOutputOp::Params epilogue_, -+ int batch_count_ -+ ): -+ problem_size(problem_size_), -+ ref_A(ref_A_), -+ stride_A(stride_A_), -+ ref_B(ref_B_), -+ stride_B(stride_B_), -+ ref_C(ref_C_), -+ stride_C(stride_C_), -+ ref_D(ref_D_), -+ stride_D(stride_D_), -+ epilogue(epilogue_), -+ batch_count(batch_count_) { } -+ }; -+ -+private: -+ -+ /// Kernel parameters object -+ typename GemmKernel::Params params_; -+ -+public: -+ -+ /// Constructs the GEMM. -+ GemmBatched() { } -+ -+ /// Determines whether the GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ if (!TensorRef_aligned(args.ref_A, kAlignmentA) || (args.stride_A % kAlignmentA)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(args.ref_B, kAlignmentB) || (args.stride_B % kAlignmentB)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(args.ref_C, kAlignmentC) || (args.stride_C % kAlignmentC)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(args.ref_D, kAlignmentC) || (args.stride_D % kAlignmentC)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ return 0; -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.batch_count); -+ -+ // Initialize the Params structure -+ params_ = typename GemmKernel::Params{ -+ args.problem_size, -+ grid_shape, -+ args.ref_A.non_const_ref(), -+ args.stride_A, -+ args.ref_B.non_const_ref(), -+ args.stride_B, -+ args.ref_C.non_const_ref(), -+ args.stride_C, -+ args.ref_D, -+ args.stride_D, -+ args.epilogue, -+ args.batch_count -+ }; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ params_.ref_A.reset(args.ref_A.non_const_ref().data()); -+ params_.ref_B.reset(args.ref_B.non_const_ref().data()); -+ params_.ref_C.reset(args.ref_C.non_const_ref().data()); -+ params_.ref_D.reset(args.ref_D.data()); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); -+ dim3 block(GemmKernel::kThreadCount, 1, 1); -+ -+ cudaError_t result; -+ -+ int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); -+ if (smem_size >= (48 << 10)) { -+ result = cudaFuncSetAttribute(Kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ cutlass::Kernel<<>>(params_); -+ -+ result = cudaGetLastError(); -+ -+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Parital specialization for column-major output exchanges problem size and operand. -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB, -+ typename Operator_ -+> -+class GemmBatched< -+ ElementA_, -+ LayoutA_, -+ ElementB_, -+ LayoutB_, -+ ElementC_, -+ layout::ColumnMajor, -+ ElementAccumulator_, -+ OperatorClass_, -+ ArchTag_, -+ ThreadblockShape_, -+ WarpShape_, -+ InstructionShape_, -+ EpilogueOutputOp_, -+ ThreadblockSwizzle_, -+ Stages, -+ AlignmentA, -+ AlignmentB, -+ Operator_ -+> { -+public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ using ElementC = ElementC_; -+ using LayoutC = layout::ColumnMajor; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static int const kStages = Stages; -+ -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ static bool const kSplitKSerial = false; -+ -+ // -+ using UnderlyingOperator = GemmBatched< -+ ElementB, -+ typename layout::LayoutTranspose::type, -+ ElementA, -+ typename layout::LayoutTranspose::type, -+ ElementC, -+ layout::RowMajor, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ kAlignmentB, -+ kAlignmentA -+ >; -+ -+ using UnderlyingArguments = typename UnderlyingOperator::Arguments; -+ using GemmKernel = typename UnderlyingOperator::GemmKernel; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmCoord problem_size; -+ TensorRef ref_A; -+ int64_t stride_A; -+ TensorRef ref_B; -+ int64_t stride_B; -+ TensorRef ref_C; -+ int64_t stride_C; -+ TensorRef ref_D; -+ int64_t stride_D; -+ typename EpilogueOutputOp::Params epilogue; -+ int batch_count; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments() { } -+ -+ /// Constructs an Arguments structure -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ GemmCoord problem_size_, -+ TensorRef ref_A_, -+ int64_t stride_A_, -+ TensorRef ref_B_, -+ int64_t stride_B_, -+ TensorRef ref_C_, -+ int64_t stride_C_, -+ TensorRef ref_D_, -+ int64_t stride_D_, -+ typename EpilogueOutputOp::Params epilogue_, -+ int batch_count_ -+ ): -+ problem_size(problem_size_), -+ ref_A(ref_A_), -+ stride_A(stride_A_), -+ ref_B(ref_B_), -+ stride_B(stride_B_), -+ ref_C(ref_C_), -+ stride_C(stride_C_), -+ ref_D(ref_D_), -+ stride_D(stride_D_), -+ epilogue(epilogue_), -+ batch_count(batch_count_) { } -+ }; -+ -+private: -+ -+ UnderlyingOperator underlying_operator_; -+ -+public: -+ -+ /// Constructs the GEMM. -+ GemmBatched() { } -+ -+ /// Helper to construct a transposed equivalent for the underying GEMM operator -+ static UnderlyingArguments to_underlying_arguments(Arguments const &args) { -+ return UnderlyingArguments( -+ {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()}, -+ {args.ref_B.data(), args.ref_B.stride(0)}, -+ args.stride_B, -+ {args.ref_A.data(), args.ref_A.stride(0)}, -+ args.stride_A, -+ {args.ref_C.data(), args.ref_C.stride(0)}, -+ args.stride_C, -+ {args.ref_D.data(), args.ref_D.stride(0)}, -+ args.stride_D, -+ args.epilogue, -+ args.batch_count -+ ); -+ } -+ -+ /// Determines whether the GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ return UnderlyingOperator::can_implement(to_underlying_arguments(args)); -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.initialize(to_underlying_arguments(args), workspace); -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ return underlying_operator_.update(to_underlying_arguments(args), workspace); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/gemm_complex.h b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_complex.h -new file mode 100644 -index 0000000..5bd856f ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_complex.h -@@ -0,0 +1,717 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/gemm/kernel/gemm.h" -+ -+#include "cutlass/gemm/kernel/default_gemm_complex.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/*! Gemm device-level operator. This is an interface to efficient CUTLASS GEMM -+ kernels that may be invoked from host code. -+ -+ The contributions of this class are: -+ -+ 1. At compile time, it maps data types and high-level structural parameters -+ onto specific CUTLASS components. -+ -+ 2. At runtime, it maps logical arguments to GEMM problems to kernel -+ parameters. -+ -+ 3. At runtime, it launches kernels on the device. -+ -+ The intent is to provide a convenient mechanism for interacting with most -+ plausible GEMM configurations for each supported architecture. Consequently, -+ not all parameters are exposed to the top-level interface. Rather, sensible -+ defaults at each level of the CUTLASS hierarchy are selected to tradeoff -+ simplicity of the interface with flexibility. We expect most configurations to -+ be specified at this level. Applications with more exotic requirements may -+ construct their kernels of interest using CUTLASS components at the -+ threadblock, warp, and thread levels of abstraction. -+ -+ CUTLASS exposes computations using the functor design pattern in which objects -+ compose some internal state with an overloaded function call operator. This -+ enables decoupling of initialization from execution, possibly reducing -+ overhead during steady state phases of application execution. -+ -+ CUTLASS device-level operators expose an Arguments structure encompassing each -+ logical input to the computation. This is distinct from the kernel-level -+ Params structure pattern which contains application-specific precomputed state -+ needed by the device code. -+ -+ Example of a CUTLASS GEMM operator implementing the functionality of cuBLAS's -+ SGEMM NN is as follows: -+ -+ // -+ // Instantiate the CUTLASS GEMM operator. -+ // -+ -+ cutlass::gemm::device::Gemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ > gemm_op; -+ -+ // -+ // Launch the GEMM operation on the device -+ // -+ -+ cutlass::Status status = gemm_op({ -+ {m, n, k}, // GemmCoord problem_size, -+ {A, lda}, // TensorRef ref_A, -+ {B, ldb}, // TensorRef ref_B, -+ {C, ldc}, // TensorRef ref_C, -+ {D, ldd}, // TensorRef ref_D, -+ {alpha, beta} // EpilogueOutputOp::Params epilogue_op_params -+ }); -+ -+ -+ A simplified view of the template is listed below. -+ -+ template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ -+ /// Element type for B matrix operand -+ typename ElementB, -+ -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ -+ /// Operator class tag -+ typename OperatorClass, -+ -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. -+ typename ArchTag, -+ -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ -+ /// Number of stages used in the pipelined mainloop -+ int Stages -+ > -+ class Gemm; -+*/ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_ = ElementC_, -+ /// Operator class tag -+ typename OperatorClass_ = arch::OpClassSimt, -+ /// Tag indicating architecture to tune for. -+ typename ArchTag_ = arch::Sm70, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_ = -+ threadblock::GemmIdentityThreadblockSwizzle<>, -+ /// Number of stages used in the pipelined mainloop -+ int Stages = -+ DefaultGemmConfiguration::kStages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ /// Multiply-add operator -+ // (selects complex or gaussian complex) -+ typename Operator_ = arch::OpMultiplyAddComplex, -+ /// If true, kernel supports split-K with serial reduction -+ bool SplitKSerial = false> -+class GemmComplex { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static int const kStages = Stages; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ using Operator = Operator_; -+ static bool const kSplitKSerial = SplitKSerial; -+ static int const kAlignmentA = 1; -+ static int const kAlignmentB = 1; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ -+ /// Define the kernel -+ using GemmKernel = typename kernel::DefaultGemmComplex< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ kStages, -+ kTransformA, -+ kTransformB, -+ Operator, -+ kSplitKSerial -+ >::GemmKernel; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmCoord problem_size; -+ TensorRef ref_A; -+ TensorRef ref_B; -+ TensorRef ref_C; -+ TensorRef ref_D; -+ typename EpilogueOutputOp::Params epilogue; -+ int split_k_slices; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments(): problem_size(0, 0, 0), split_k_slices(1) { -+ -+ } -+ -+ /// Constructs an Arguments structure -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ GemmCoord problem_size_, -+ TensorRef ref_A_, -+ TensorRef ref_B_, -+ TensorRef ref_C_, -+ TensorRef ref_D_, -+ typename EpilogueOutputOp::Params epilogue_ = -+ typename EpilogueOutputOp::Params(), -+ int split_k_slices = 1 -+ ): -+ problem_size(problem_size_), -+ ref_A(ref_A_), -+ ref_B(ref_B_), -+ ref_C(ref_C_), -+ ref_D(ref_D_), -+ epilogue(epilogue_), -+ split_k_slices(split_k_slices) { -+ -+ } -+ }; -+ -+private: -+ -+ /// Kernel parameters object -+ typename GemmKernel::Params params_; -+ -+public: -+ -+ /// Constructs the GEMM. -+ GemmComplex() { } -+ -+ /// Determines whether the GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ if (!kSplitKSerial && args.split_k_slices > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ if (kSplitKSerial && args.split_k_slices > 1) { -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.split_k_slices); -+ -+ return sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); -+ } -+ -+ return 0; -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.split_k_slices); -+ -+ if (kSplitKSerial) { -+ if (args.split_k_slices > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ size_t bytes = get_workspace_size(args); -+ -+ cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ } -+ else { -+ -+ if (args.split_k_slices > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ // Initialize the Params structure -+ params_ = typename GemmKernel::Params{ -+ args.problem_size, -+ grid_shape, -+ args.ref_A.non_const_ref(), -+ args.ref_B.non_const_ref(), -+ args.ref_C.non_const_ref(), -+ args.ref_D, -+ args.epilogue, -+ static_cast(workspace) -+ }; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ if (kSplitKSerial && args.split_k_slices > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ } -+ -+ params_.ref_A.reset(args.ref_A.non_const_ref().data()); -+ params_.ref_B.reset(args.ref_B.non_const_ref().data()); -+ params_.ref_C.reset(args.ref_C.non_const_ref().data()); -+ params_.ref_D.reset(args.ref_D.data()); -+ params_.semaphore = static_cast(workspace); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); -+ dim3 block(GemmKernel::kThreadCount, 1, 1); -+ -+ cudaError_t result; -+ -+ int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); -+ if (smem_size >= (48 << 10)) { -+ result = cudaFuncSetAttribute(Kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ cutlass::Kernel<<>>(params_); -+ -+ result = cudaGetLastError(); -+ -+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Parital specialization for column-major output exchanges problem size and operand. -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Multiply-add operator -+ // (selects complex or gaussian complex) -+ typename Operator_, -+ /// If true, kernel supports split-K as a serial reduction -+ bool SplitKSerial -+> -+class GemmComplex< -+ ElementA_, -+ LayoutA_, -+ ElementB_, -+ LayoutB_, -+ ElementC_, -+ layout::ColumnMajor, // partially specialized on LayoutC -+ ElementAccumulator_, -+ OperatorClass_, -+ ArchTag_, -+ ThreadblockShape_, -+ WarpShape_, -+ InstructionShape_, -+ EpilogueOutputOp_, -+ ThreadblockSwizzle_, -+ Stages, -+ TransformA, -+ TransformB, -+ Operator_, -+ SplitKSerial -+> { -+public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ using ElementC = ElementC_; -+ using LayoutC = layout::ColumnMajor; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static int const kStages = Stages; -+ using Operator = Operator_; -+ static bool const kSplitKSerial = SplitKSerial; -+ -+ using UnderlyingOperator = GemmComplex< -+ ElementB, -+ typename layout::LayoutTranspose::type, -+ ElementA, -+ typename layout::LayoutTranspose::type, -+ ElementC, -+ layout::RowMajor, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ TransformB, -+ TransformA, -+ Operator, -+ SplitKSerial -+ >; -+ -+ static int const kAlignmentA = UnderlyingOperator::kAlignmentB; -+ static int const kAlignmentB = UnderlyingOperator::kAlignmentA; -+ static int const kAlignmentC = UnderlyingOperator::kAlignmentC; -+ static ComplexTransform const kTransformA = UnderlyingOperator::kTransformB; -+ static ComplexTransform const kTransformB = UnderlyingOperator::kTransformA; -+ -+ using UnderlyingArguments = typename UnderlyingOperator::Arguments; -+ using GemmKernel = typename UnderlyingOperator::GemmKernel; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmCoord problem_size; -+ TensorRef ref_A; -+ TensorRef ref_B; -+ TensorRef ref_C; -+ TensorRef ref_D; -+ typename EpilogueOutputOp::Params epilogue; -+ int split_k_slices; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments() { } -+ -+ /// Constructs an Arguments structure -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ GemmCoord problem_size_, -+ TensorRef ref_A_, -+ TensorRef ref_B_, -+ TensorRef ref_C_, -+ TensorRef ref_D_, -+ typename EpilogueOutputOp::Params epilogue_ = -+ typename EpilogueOutputOp::Params(), -+ int split_k_slices = 1 -+ ): -+ problem_size(problem_size_), -+ ref_A(ref_A_), -+ ref_B(ref_B_), -+ ref_C(ref_C_), -+ ref_D(ref_D_), -+ epilogue(epilogue_), -+ split_k_slices(split_k_slices) { } -+ }; -+ -+private: -+ -+ UnderlyingOperator underlying_operator_; -+ -+public: -+ -+ /// Constructs the GEMM. -+ GemmComplex() { } -+ -+ /// Helper to construct a transposed equivalent for the underying GEMM operator -+ static UnderlyingArguments to_underlying_arguments(Arguments const &args) { -+ return UnderlyingArguments( -+ {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()}, -+ {args.ref_B.data(), args.ref_B.stride(0)}, -+ {args.ref_A.data(), args.ref_A.stride(0)}, -+ {args.ref_C.data(), args.ref_C.stride(0)}, -+ {args.ref_D.data(), args.ref_D.stride(0)}, -+ args.epilogue, -+ args.split_k_slices -+ ); -+ } -+ -+ /// Determines whether the GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ return UnderlyingOperator::can_implement(to_underlying_arguments(args)); -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.initialize(to_underlying_arguments(args), workspace); -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ return underlying_operator_.update(to_underlying_arguments(args), workspace); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/gemm_grouped.h b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_grouped.h -new file mode 100644 -index 0000000..3e932eb ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_grouped.h -@@ -0,0 +1,61 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief Device-level grouped GEMM. -+*/ -+ -+#pragma once -+ -+#include "cutlass/gemm/device/base_grouped.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// GEMM Grouped -+template -+class GemmGrouped : public BaseGrouped { -+public: -+ using GemmKernel = GemmKernel_; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/gemm_layernorm_mainloop_fusion.h b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_layernorm_mainloop_fusion.h -new file mode 100644 -index 0000000..3ebb2a7 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_layernorm_mainloop_fusion.h -@@ -0,0 +1,385 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Device-level GEMM with layernorm elementwise operations fused in mainloop -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/gemm/kernel/gemm_universal.h" -+ -+#include "cutlass/gemm/kernel/default_gemm_layernorm_mainloop_fusion.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+#include "cutlass/gemm/device/gemm_universal_base.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/*! -+ The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and -+ batched array variants. -+*/ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for Scale/Bias vectors -+ typename ElementScaleBias_, -+ /// Layout type for Scale/Bias vectors -+ typename LayoutScaleBias_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_ = ElementC_, -+ /// Operator class tag -+ typename OperatorClass_ = arch::OpClassSimt, -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. -+ typename ArchTag_ = arch::Sm70, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, -+ /// Number of stages used in the pipelined mainloop -+ int Stages = -+ DefaultGemmConfiguration::kStages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = -+ DefaultGemmConfiguration::kAlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = -+ DefaultGemmConfiguration::kAlignmentB, -+ /// Operation performed by GEMM -+ typename Operator_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::Operator -+> -+class GemmLayernormMainloopFusion : -+ public GemmUniversalBase< -+ typename kernel::DefaultGemmLayernormMainloopFusion< -+ ElementA_, -+ LayoutA_, -+ AlignmentA, -+ ElementB_, -+ LayoutB_, -+ AlignmentB, -+ ElementScaleBias_, -+ LayoutScaleBias_, -+ ElementC_, -+ LayoutC_, -+ ElementAccumulator_, -+ OperatorClass_, -+ ArchTag_, -+ ThreadblockShape_, -+ WarpShape_, -+ InstructionShape_, -+ EpilogueOutputOp_, -+ ThreadblockSwizzle_, -+ Stages, -+ Operator_, -+ SharedMemoryClearOption::kNone -+ >::GemmKernel -+ > { -+ -+ public: -+ -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ -+ using Base = GemmUniversalBase< -+ typename kernel::DefaultGemmLayernormMainloopFusion< -+ ElementA_, -+ LayoutA_, -+ AlignmentA, -+ ElementB_, -+ LayoutB_, -+ AlignmentB, -+ ElementScaleBias_, -+ LayoutScaleBias_, -+ ElementC_, -+ LayoutC_, -+ ElementAccumulator_, -+ OperatorClass_, -+ ArchTag_, -+ ThreadblockShape_, -+ WarpShape_, -+ InstructionShape_, -+ EpilogueOutputOp_, -+ ThreadblockSwizzle_, -+ Stages, -+ Operator_, -+ SharedMemoryClearOption::kNone -+ >::GemmKernel -+ >; -+ -+ using Arguments = typename Base::Arguments; -+ using GemmKernel = typename Base::GemmKernel; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Parital specialization for column-major output exchanges problem size and operand. -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for Scale/Bias vectors -+ typename ElementScaleBias_, -+ /// Layout type for Scale/Bias vectors -+ typename LayoutScaleBias_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB, -+ /// Operation performed by GEMM -+ typename Operator_ -+> -+class GemmLayernormMainloopFusion { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ using ElementScaleBias = ElementScaleBias_; -+ using LayoutScaleBias = LayoutScaleBias_; -+ using ElementC = ElementC_; -+ using LayoutC = layout::ColumnMajor; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ -+ using UnderlyingOperator = typename GemmLayernormMainloopFusion< -+ ElementB, -+ typename layout::LayoutTranspose::type, -+ ElementA, -+ typename layout::LayoutTranspose::type, -+ ElementScaleBias, -+ LayoutScaleBias, -+ ElementC, -+ layout::RowMajor, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ kAlignmentB, -+ kAlignmentA, -+ Operator -+ >::Base; -+ -+ using GemmKernel = typename UnderlyingOperator::GemmKernel; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ -+ /// Argument structure -+ using Arguments = typename UnderlyingOperator::Arguments; -+ -+private: -+ -+ UnderlyingOperator underlying_operator_; -+ -+public: -+ -+ /// Constructs the GEMM. -+ GemmLayernormMainloopFusion() { } -+ -+ /// Helper to construct a transposed equivalent for the underlying GEMM operator -+ static Arguments to_underlying_arguments(Arguments const &args) { -+ return args.transposed_problem(); -+ } -+ -+ /// Determines whether the GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ return UnderlyingOperator::can_implement(to_underlying_arguments(args)); -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); -+ } -+ -+ /// Computes the grid shape -+ static dim3 get_grid_shape(Arguments const &args) { -+ return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); -+ } -+ -+ /// Computes the maximum number of active blocks per multiprocessor -+ static int maximum_active_blocks(int smem_capacity = -1) { -+ return UnderlyingOperator::maximum_active_blocks(smem_capacity); -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ return underlying_operator_.update(to_underlying_arguments(args), workspace); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/gemm_sparse.h b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_sparse.h -new file mode 100644 -index 0000000..0366b05 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_sparse.h -@@ -0,0 +1,514 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/gemm/kernel/sparse_gemm.h" -+ -+#include "cutlass/gemm/kernel/default_gemm_sparse.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/*! Gemm device-level operator. This is an interface to efficient CUTLASS GEMM kernels that may -+ be invoked from host code. -+ -+ The contributions of this class are: -+ -+ 1. At compile time, it maps data types and high-level structural parameters onto -+ specific CUTLASS components. -+ -+ 2. At runtime, it maps logical arguments to GEMM problems to kernel parameters. -+ -+ 3. At runtime, it launches kernels on the device. -+ -+ The intent is to provide a convenient mechanism for interacting with most plausible GEMM -+ configurations for each supported architecture. Consequently, not all parameters are exposed -+ to the top-level interface. Rather, sensible defaults at each level of the CUTLASS hierarchy -+ are selected to tradeoff simplicity of the interface with flexibility. We expect -+ most configurations to be specified at this level. Applications with more exotic requirements -+ may construct their kernels of interest using CUTLASS components at the threadblock, warp, -+ and thread levels of abstraction. -+ -+ CUTLASS exposes computations using the functor design pattern in which objects compose some -+ internal state with an overloaded function call operator. This enables decoupling of -+ initialization from execution, possibly reducing overhead during steady state phases of -+ application execution. -+ -+ CUTLASS device-level operators expose an Arguments structure encompassing each logical -+ input to the computation. This is distinct from the kernel-level Params structure pattern -+ which contains application-specific precomputed state needed by the device code. -+ -+ Example of a CUTLASS GEMM operator implementing the functionality of cuBLAS's SGEMM NN -+ is as follows: -+ -+ // -+ // Instantiate the CUTLASS GEMM operator. -+ // -+ -+ cutlass::gemm::device::Gemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ > gemm_op; -+ -+ // -+ // Launch the GEMM operation on the device -+ // -+ -+ cutlass::Status status = gemm_op({ -+ {m, n, k}, // GemmCoord problem_size, -+ {A, lda}, // TensorRef ref_A, -+ {B, ldb}, // TensorRef ref_B, -+ {C, ldc}, // TensorRef ref_C, -+ {D, ldd}, // TensorRef ref_D, -+ {alpha, beta} // EpilogueOutputOp::Params epilogue_op_params -+ }); -+ -+ -+ A simplified view of the template is listed below. -+ -+ template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ -+ /// Element type for B matrix operand -+ typename ElementB, -+ -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ -+ /// Operator class tag -+ typename OperatorClass, -+ -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. -+ typename ArchTag, -+ -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ -+ /// Number of stages used in the pipelined mainloop -+ int Stages -+ > -+ class Gemm; -+*/ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_ = ElementC_, -+ /// Operator class tag -+ typename OperatorClass_ = arch::OpClassSimt, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_ = arch::Sm70, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_ = -+ typename threadblock::GemmIdentityThreadblockSwizzle<>, -+ /// Number of stages used in the pipelined mainloop -+ int Stages = -+ DefaultGemmConfiguration::kStages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = -+ DefaultGemmConfiguration::kAlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = -+ DefaultGemmConfiguration::kAlignmentB, -+ /// If true, kernel supports split-K with serial reduction -+ bool SplitKSerial = false, -+ /// Operation performed by GEMM -+ typename Operator_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::Operator> -+class SparseGemm { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ using MathOperator = Operator; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ static bool const kSplitKSerial = SplitKSerial; -+ static ComplexTransform const kTransformA = ComplexTransform::kNone; -+ static ComplexTransform const kTransformB = ComplexTransform::kNone; -+ -+ /// Define the kernel -+ using GemmKernel = typename kernel::DefaultSparseGemm< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ kStages, -+ kSplitKSerial, -+ Operator -+ >::GemmKernel; -+ -+ using ElementE = typename GemmKernel::ElementE; -+ -+ using LayoutE = typename GemmKernel::LayoutE; -+ -+ static int const kAlignmentE = 128 / sizeof_bits::value; -+ -+ static int const kSparse = GemmKernel::kSparse; -+ static int const kMetaSizeInBits = GemmKernel::kMetaSizeInBits; -+ static int const kElementsPerElementE = GemmKernel::kElementsPerElementE; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmCoord problem_size; -+ TensorRef ref_A; -+ TensorRef ref_B; -+ TensorRef ref_C; -+ TensorRef ref_D; -+ TensorRef ref_E; -+ typename EpilogueOutputOp::Params epilogue; -+ int split_k_slices; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments(): problem_size(0, 0, 0), split_k_slices(1) { -+ -+ } -+ -+ /// Constructs an Arguments structure -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ GemmCoord problem_size_, -+ TensorRef ref_A_, -+ TensorRef ref_B_, -+ TensorRef ref_C_, -+ TensorRef ref_D_, -+ TensorRef ref_E_, -+ typename EpilogueOutputOp::Params epilogue_ = -+ typename EpilogueOutputOp::Params(), -+ int split_k_slices = 1 -+ ): -+ problem_size(problem_size_), -+ ref_A(ref_A_), -+ ref_B(ref_B_), -+ ref_C(ref_C_), -+ ref_D(ref_D_), -+ ref_E(ref_E_), -+ epilogue(epilogue_), -+ split_k_slices(split_k_slices) { -+ -+ } -+ }; -+ -+private: -+ -+ /// Kernel parameters object -+ typename GemmKernel::Params params_; -+ -+public: -+ -+ /// Constructs the GEMM. -+ SparseGemm() { } -+ -+ /// Determines whether the GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ if (!kSplitKSerial && args.split_k_slices > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ Status status = GemmKernel::can_implement( -+ args.problem_size, -+ args.ref_A.non_const_ref(), -+ args.ref_B.non_const_ref(), -+ args.ref_C.non_const_ref(), -+ args.ref_D, -+ args.ref_E.non_const_ref() -+ ); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ size_t bytes = 0; -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.split_k_slices); -+ -+ if (kSplitKSerial && args.split_k_slices > 1) { -+ -+ bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); -+ } -+ -+ return bytes; -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.split_k_slices); -+ -+ if (kSplitKSerial) { -+ if (args.split_k_slices > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ size_t bytes = get_workspace_size(args); -+ -+ cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ } -+ else { -+ -+ if (args.split_k_slices > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ // Initialize the Params structure -+ params_ = typename GemmKernel::Params{ -+ args.problem_size, -+ grid_shape, -+ args.ref_A.non_const_ref(), -+ args.ref_B.non_const_ref(), -+ args.ref_C.non_const_ref(), -+ args.ref_D, -+ args.ref_E.non_const_ref(), -+ args.epilogue, -+ static_cast(workspace) -+ }; -+ -+ int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); -+ if (smem_size >= (48 << 10)) { -+ cudaError_t result = cudaFuncSetAttribute(Kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ if (kSplitKSerial && args.split_k_slices > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ } -+ -+ params_.ref_A.reset(args.ref_A.non_const_ref().data()); -+ params_.ref_B.reset(args.ref_B.non_const_ref().data()); -+ params_.ref_C.reset(args.ref_C.non_const_ref().data()); -+ params_.ref_D.reset(args.ref_D.data()); -+ params_.ref_E.reset(args.ref_E.non_const_ref().data()); -+ params_.output_op = args.epilogue; -+ params_.semaphore = static_cast(workspace); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); -+ dim3 block(GemmKernel::kThreadCount, 1, 1); -+ -+ int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); -+ -+ cutlass::Kernel<<>>(params_); -+ -+ cudaError_t result = cudaGetLastError(); -+ -+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/gemm_splitk_parallel.h b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_splitk_parallel.h -new file mode 100644 -index 0000000..55db955 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_splitk_parallel.h -@@ -0,0 +1,638 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for GEMM performing a reduction over K partitions in parallel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/gemm/kernel/gemm.h" -+ -+#include "cutlass/gemm/kernel/default_gemm_splitk_parallel.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+ -+#include "cutlass/epilogue/thread/conversion_op.h" -+#include "cutlass/reduction/kernel/reduce_split_k.h" -+#include "cutlass/reduction/thread/reduction_operators.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/*! -+ Gemm device-level operator performing parallel reduction over the K partition. -+ -+*/ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_ = ElementC_, -+ /// Operator class tag -+ typename OperatorClass_ = arch::OpClassSimt, -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. -+ typename ArchTag_ = arch::Sm70, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::EpilogueOutputOp, -+ /// Epilogue output operator -+ typename ConvertScaledOp_ = cutlass::epilogue::thread::Convert< -+ ElementAccumulator_, -+ DefaultGemmConfiguration::EpilogueOutputOp::kCount, -+ ElementAccumulator_>, -+ /// Reduction operator -+ typename ReductionOp_ = cutlass::reduction::thread::ReduceAdd< -+ ElementAccumulator_, typename EpilogueOutputOp_::ElementAccumulator, -+ EpilogueOutputOp_::kCount>, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_ = -+ threadblock::GemmSplitKHorizontalThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages = -+ DefaultGemmConfiguration::kStages, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA = -+ DefaultGemmConfiguration::kAlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB = -+ DefaultGemmConfiguration::kAlignmentB, -+ /// Operation performed by GEMM -+ typename Operator_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::Operator> -+class GemmSplitKParallel { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ConvertScaledOp = ConvertScaledOp_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ReductionOp = ReductionOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static int const kStages = Stages; -+ -+ /// GEMM kernel -+ using GemmKernel = typename kernel::DefaultGemmSplitKParallel< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementAccumulator, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ ConvertScaledOp, -+ ThreadblockSwizzle, -+ kStages, -+ Operator -+ >::GemmKernel; -+ -+ /// Reduction kernel -+ using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< -+ cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, -+ EpilogueOutputOp, -+ ReductionOp -+ >; -+ -+ // -+ // -+ // -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmCoord problem_size; -+ TensorRef ref_A; -+ TensorRef ref_B; -+ TensorRef ref_C; -+ TensorRef ref_D; -+ typename EpilogueOutputOp::Params epilogue; -+ int split_k_slices; -+ typename ConvertScaledOp::Params convert; -+ typename ReductionOp::Params reduction; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments() { } -+ -+ /// Constructs an Arguments structure -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ GemmCoord problem_size_, -+ TensorRef ref_A_, -+ TensorRef ref_B_, -+ TensorRef ref_C_, -+ TensorRef ref_D_, -+ typename EpilogueOutputOp::Params epilogue_ = -+ typename EpilogueOutputOp::Params(), -+ int split_k_slices = 1, -+ typename ConvertScaledOp::Params convert_ = -+ typename ConvertScaledOp::Params(), -+ typename ReductionOp::Params reduction_ = -+ typename ReductionOp::Params() -+ ): -+ problem_size(problem_size_), -+ ref_A(ref_A_), -+ ref_B(ref_B_), -+ ref_C(ref_C_), -+ ref_D(ref_D_), -+ epilogue(epilogue_), -+ split_k_slices(split_k_slices), -+ convert(convert_), -+ reduction(reduction_) { } -+ }; -+ -+private: -+ -+ /// Kernel parameters object -+ typename GemmKernel::Params gemm_params_; -+ -+ /// Reduction kernel parameters object -+ typename ReductionKernel::Params reduction_params_; -+ -+public: -+ -+ /// Constructs the GEMM. -+ GemmSplitKParallel() { } -+ -+ /// Determines whether the GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ // TODO -+ -+ return Status::kSuccess; -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.split_k_slices); -+ -+ return sizeof(ElementAccumulator_) * size_t(args.problem_size.m()) * size_t(args.problem_size.n()) * grid_shape.k(); -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace) { -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.split_k_slices); -+ -+ // Define a reference to the workspace - this is an aligned region in device memory. -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ TensorRef ref_workspace( -+ static_cast(workspace), -+ args.problem_size.n()); -+ -+ int64_t partition_stride = int64_t(args.problem_size.m()) * int64_t(args.problem_size.n()); -+ -+ // Initialize the Params structure -+ gemm_params_ = typename GemmKernel::Params{ -+ args.problem_size, -+ grid_shape, -+ args.ref_A.non_const_ref(), -+ args.ref_B.non_const_ref(), -+ ref_workspace, -+ args.convert, -+ partition_stride -+ }; -+ -+ reduction_params_ = typename ReductionKernel::Params( -+ args.problem_size.mn(), -+ grid_shape.k(), -+ partition_stride, -+ ref_workspace, -+ args.ref_D, -+ args.ref_C.non_const_ref(), -+ args.epilogue -+ ); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ gemm_params_.ref_A.reset(args.ref_A.data()); -+ gemm_params_.ref_B.reset(args.ref_B.data()); -+ gemm_params_.ref_D.reset(workspace); -+ -+ reduction_params_.ref_D.reset(args.ref_D.data()); -+ reduction_params_.ref_C.reset(args.ref_C.data()); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ // -+ // Launch GEMM kernel -+ // -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape(gemm_params_.grid_tiled_shape); -+ dim3 block(GemmKernel::kThreadCount, 1, 1); -+ -+ cudaError_t result; -+ -+ int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); -+ if (smem_size >= (48 << 10)) { -+ -+ result = cudaFuncSetAttribute( -+ Kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ Kernel<<>>(gemm_params_); -+ -+ result = cudaGetLastError(); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ -+ // -+ // Launch reduction kernel -+ // -+ -+ block = ReductionKernel::block_shape(); -+ grid = ReductionKernel::grid_shape(gemm_params_.problem_size.mn()); -+ -+ Kernel<<< grid, block, 0, stream >>>(reduction_params_); -+ -+ result = cudaGetLastError(); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ -+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for column-major output -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_, -+ /// Epilogue output operator -+ typename ConvertScaledOp_, -+ /// Reduction operator -+ typename ReductionOp_, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, int kAlignmentA, int kAlignmentB, -+ /// Operation performed by GEMM -+ typename Operator_> -+class GemmSplitKParallel { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using ElementC = ElementC_; -+ using LayoutC = layout::ColumnMajor; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ConvertScaledOp = ConvertScaledOp_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ReductionOp = ReductionOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static int const kStages = Stages; -+ -+ using UnderlyingOperator = GemmSplitKParallel< -+ ElementB, -+ typename layout::LayoutTranspose::type, -+ ElementA, -+ typename layout::LayoutTranspose::type, -+ ElementC, -+ layout::RowMajor, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ConvertScaledOp, -+ ReductionOp, -+ ThreadblockSwizzle, -+ Stages, -+ kAlignmentA, -+ kAlignmentB, -+ Operator -+ >; -+ -+ using UnderlyingArguments = typename UnderlyingOperator::Arguments; -+ using GemmKernel = typename UnderlyingOperator::GemmKernel; -+ using ReductionKernel = typename UnderlyingOperator::ReductionKernel; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmCoord problem_size; -+ TensorRef ref_A; -+ TensorRef ref_B; -+ TensorRef ref_C; -+ TensorRef ref_D; -+ typename EpilogueOutputOp::Params epilogue; -+ int split_k_slices; -+ typename ConvertScaledOp::Params convert; -+ typename ReductionOp::Params reduction; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments() { } -+ -+ /// Constructs an Arguments structure -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ GemmCoord problem_size_, -+ TensorRef ref_A_, -+ TensorRef ref_B_, -+ TensorRef ref_C_, -+ TensorRef ref_D_, -+ typename EpilogueOutputOp::Params epilogue_ = -+ typename EpilogueOutputOp::Params(), -+ int split_k_slices = 1, -+ typename ConvertScaledOp::Params convert_ = -+ typename ConvertScaledOp::Params(), -+ typename ReductionOp::Params reduction_ = -+ typename ReductionOp::Params() -+ ): -+ problem_size(problem_size_), -+ ref_A(ref_A_), -+ ref_B(ref_B_), -+ ref_C(ref_C_), -+ ref_D(ref_D_), -+ epilogue(epilogue_), -+ split_k_slices(split_k_slices), -+ convert(convert_), -+ reduction(reduction_) { } -+ }; -+ -+private: -+ -+ /// Kernel parameters object -+ UnderlyingOperator underlying_operator_; -+ -+public: -+ -+ /// Constructs the GEMM. -+ GemmSplitKParallel() { } -+ -+ /// Helper to construct a transposed equivalent for the underying GEMM operator -+ static UnderlyingArguments to_underlying_arguments(Arguments const &args) { -+ return UnderlyingArguments( -+ {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()}, -+ {args.ref_B.data(), args.ref_B.stride(0)}, -+ {args.ref_A.data(), args.ref_A.stride(0)}, -+ {args.ref_C.data(), args.ref_C.stride(0)}, -+ {args.ref_D.data(), args.ref_D.stride(0)}, -+ args.epilogue, -+ args.split_k_slices, -+ args.convert, -+ args.reduction -+ ); -+ } -+ -+ /// Determines whether the GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ return UnderlyingOperator::can_implement(to_underlying_arguments(args)); -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace) { -+ -+ return underlying_operator_.initialize(to_underlying_arguments(args), workspace); -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ return underlying_operator_.update(to_underlying_arguments(args), workspace); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/gemm_universal.h b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_universal.h -new file mode 100644 -index 0000000..6c19b8a ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_universal.h -@@ -0,0 +1,420 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/gemm/kernel/gemm_universal.h" -+ -+#include "cutlass/gemm/kernel/default_gemm_universal.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+#include "cutlass/gemm/device/gemm_universal_base.h" -+ -+#include "cutlass/layout/permute.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/*! -+ GemmUniversal is a stateful, reusable GEMM handle. Once initialized for a given GEMM computation -+ (problem geometry and data references), it can be reused across different GEMM problems having the -+ geometry. (Once initialized, details regarding problem geometry and references to workspace memory -+ cannot be updated.) -+ -+ The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and -+ batched array variants. -+*/ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_ = ElementC_, -+ /// Operator class tag -+ typename OperatorClass_ = arch::OpClassSimt, -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. -+ typename ArchTag_ = arch::Sm70, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, -+ /// Number of stages used in the pipelined mainloop -+ int Stages = -+ DefaultGemmConfiguration::kStages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = -+ DefaultGemmConfiguration::kAlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = -+ DefaultGemmConfiguration::kAlignmentB, -+ /// Operation performed by GEMM -+ typename Operator_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::Operator, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ /// Gather operand A by using an index array -+ bool GatherA = false, -+ /// Gather operand B by using an index array -+ bool GatherB = false, -+ /// Scatter result D by using an index array -+ bool ScatterD = false, -+ /// Permute result D -+ typename PermuteDLayout = layout::NoPermute -+> -+class GemmUniversal : -+ public GemmUniversalBase< -+ typename kernel::DefaultGemmUniversal< -+ ElementA_, -+ LayoutA_, -+ TransformA, -+ AlignmentA, -+ ElementB_, -+ LayoutB_, -+ TransformB, -+ AlignmentB, -+ ElementC_, -+ LayoutC_, -+ ElementAccumulator_, -+ OperatorClass_, -+ ArchTag_, -+ ThreadblockShape_, -+ WarpShape_, -+ InstructionShape_, -+ EpilogueOutputOp_, -+ ThreadblockSwizzle_, -+ Stages, -+ Operator_, -+ SharedMemoryClearOption::kNone, -+ GatherA, -+ GatherB, -+ ScatterD, -+ PermuteDLayout -+ >::GemmKernel -+ > { -+ -+ public: -+ -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ -+ using Base = GemmUniversalBase< -+ typename kernel::DefaultGemmUniversal< -+ ElementA_, -+ LayoutA_, -+ TransformA, -+ AlignmentA, -+ ElementB_, -+ LayoutB_, -+ TransformB, -+ AlignmentB, -+ ElementC_, -+ LayoutC_, -+ ElementAccumulator_, -+ OperatorClass_, -+ ArchTag_, -+ ThreadblockShape_, -+ WarpShape_, -+ InstructionShape_, -+ EpilogueOutputOp_, -+ ThreadblockSwizzle_, -+ Stages, -+ Operator_, -+ SharedMemoryClearOption::kNone, -+ GatherA, -+ GatherB, -+ ScatterD, -+ PermuteDLayout -+ >::GemmKernel -+ >; -+ -+ using Arguments = typename Base::Arguments; -+ using GemmKernel = typename Base::GemmKernel; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Parital specialization for column-major output exchanges problem size and operand. -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB, -+ /// Operation performed by GEMM -+ typename Operator_, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Gather operand A by using an index array -+ bool GatherA, -+ /// Gather operand B by using an index array -+ bool GatherB, -+ /// Scatter result D by using an index array -+ bool ScatterD, -+ /// Permute result D -+ typename PermuteDLayout -+> -+class GemmUniversal { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ using ElementC = ElementC_; -+ using LayoutC = layout::ColumnMajor; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ -+ using UnderlyingOperator = typename GemmUniversal< -+ ElementB, -+ typename layout::LayoutTranspose::type, -+ ElementA, -+ typename layout::LayoutTranspose::type, -+ ElementC, -+ layout::RowMajor, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ kAlignmentB, -+ kAlignmentA, -+ Operator, -+ kTransformB, -+ kTransformA, -+ GatherB, -+ GatherA, -+ ScatterD, -+ PermuteDLayout -+ >::Base; -+ -+ using GemmKernel = typename UnderlyingOperator::GemmKernel; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ -+ /// Argument structure -+ using Arguments = typename UnderlyingOperator::Arguments; -+ -+private: -+ -+ UnderlyingOperator underlying_operator_; -+ -+public: -+ -+ /// Constructs the GEMM. -+ GemmUniversal() { } -+ -+ /// Helper to construct a transposed equivalent for the underying GEMM operator -+ static Arguments to_underlying_arguments(Arguments const &args) { -+ return args.transposed_problem(); -+ } -+ -+ /// Determines whether the GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ return UnderlyingOperator::can_implement(to_underlying_arguments(args)); -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); -+ } -+ -+ /// Computes the grid shape -+ static dim3 get_grid_shape(Arguments const &args) { -+ return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); -+ } -+ -+ /// Computes the maximum number of active blocks per multiprocessor -+ static int maximum_active_blocks(int smem_capacity = -1) { -+ return UnderlyingOperator::maximum_active_blocks(smem_capacity); -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ return underlying_operator_.update(to_underlying_arguments(args), workspace); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/gemm_universal_adapter.h b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_universal_adapter.h -new file mode 100644 -index 0000000..66884fb ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_universal_adapter.h -@@ -0,0 +1,549 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and -+ batched array variants. -+*/ -+ -+#pragma once -+ -+// common -+#include "cutlass/cutlass.h" -+#include "cutlass/trace.h" -+#include "cutlass/cluster_launch.hpp" -+#include "cutlass/device_kernel.h" -+#include "cutlass/gemm/gemm.h" -+ -+// 2.x -+#include "cutlass/gemm/device/gemm_universal_base.h" -+#include "cutlass/gemm/kernel/gemm_transpose_operands.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+ -+// 3.x -+#include "cutlass/gemm/kernel/gemm_universal.hpp" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass::gemm::device { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/*! -+ GemmUniversalAdapter is a stateful, reusable GEMM handle built around a kernel -+ of type cutlass::gemm::kernel::Gemm or cutlass::gemm::kernel::GemmUniversal. -+ -+ It manages the lifetime of the underlying `kernel::Params` struct, and exposes APIs -+ to create it from the host facing arguments. For power users, new static methods -+ are exposed in 3.x APIs that bypass the stateful methods or args->params lowering. -+ -+ It supports kernel types that implement both the 2.x and 3.0 APIs, -+ however, this is done by specializing the implementation of GemmUniversalAdapter -+ on the two kernel API types, and thus, GemmUniversalAdapter's behaviour might -+ differ between the two specializations. -+*/ -+template -+class GemmUniversalAdapter; -+ -+//////////////////////////////////////////////////////////////////////////////// -+////////////////////////////// CUTLASS 3.x API ///////////////////////////////// -+//////////////////////////////////////////////////////////////////////////////// -+ -+template -+class GemmUniversalAdapter< -+ GemmKernel_, -+ std::enable_if_t::value>> -+{ -+public: -+ using GemmKernel = GemmKernel_; -+ using TileShape = typename GemmKernel::TileShape; -+ using ElementA = typename GemmKernel::ElementA; -+ using ElementB = typename GemmKernel::ElementB; -+ using ElementC = typename GemmKernel::ElementC; -+ using ElementAccumulator = typename GemmKernel::TiledMma::ValTypeC; -+ using DispatchPolicy = typename GemmKernel::DispatchPolicy; -+ using CollectiveMainloop = typename GemmKernel::CollectiveMainloop; -+ using CollectiveEpilogue = typename GemmKernel::CollectiveEpilogue; -+ -+ // Map back to 2.x type as best as possible -+ using LayoutA = gemm::detail::StrideToLayoutTagA_t; -+ using LayoutB = gemm::detail::StrideToLayoutTagB_t; -+ using LayoutC = gemm::detail::StrideToLayoutTagC_t; -+ using LayoutD = gemm::detail::StrideToLayoutTagC_t; -+ -+ // NOTE: 3.0 kernels do not support complex transforms for now ... -+ static ComplexTransform const kTransformA = ComplexTransform::kNone; -+ static ComplexTransform const kTransformB = ComplexTransform::kNone; -+ -+ // Legacy: Assume MultiplyAdd only since we do not use this tag type in 3.0 -+ using MathOperator = cutlass::arch::OpMultiplyAdd; -+ -+ // If our TiledMMA's instruction thread layout size is larger than 1, we know its a tensorop! -+ using OperatorClass = std::conditional_t< -+ (cute::size(typename GemmKernel::TiledMma::AtomThrID{}) > 1), -+ cutlass::arch::OpClassTensorOp, cutlass::arch::OpClassSimt>; -+ -+ using ArchTag = typename GemmKernel::ArchTag; -+ -+ // NOTE: Assume identity swizzle for now -+ static_assert(std::is_void_v, -+ "CUTLASS 3.x kernel types do not support grid swizzle functors yet."); -+ using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; -+ -+ // Assume TiledMma's ShapeMNK is the same as 2.x's ThreadblockShape -+ using ThreadblockShape = cutlass::gemm::GemmShape< -+ cute::size<0>(TileShape{}), -+ cute::size<1>(TileShape{}), -+ cute::size<2>(TileShape{})>; -+ -+ using ClusterShape = cutlass::gemm::GemmShape< -+ cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), -+ cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), -+ cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})>; -+ -+ // Instruction shape is easy too, since we get that directly from our TiledMma's atom shape -+ using InstructionShape = cutlass::gemm::GemmShape< -+ cute::size<0>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}), -+ cute::size<1>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}), -+ cute::size<2>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{})>; -+ -+ // Legacy: provide a correct warp count, but no reliable warp shape -+ static int const kThreadCount = GemmKernel::MaxThreadsPerBlock; -+ -+ // Warp shape is not a primary API type in 3.x -+ // But we can best approximate it by inspecting the TiledMma::TiledShape_MNK -+ // For this, we make the assumption that we always have 4 warps along M, and rest along N, none along K -+ // We also always round up the warp count to 4 if the tiled mma is smaller than 128 threads -+ static constexpr int WarpsInMma = std::max(4, cute::size(typename GemmKernel::TiledMma{}) / 32); -+ static constexpr int WarpsInMmaM = 4; -+ static constexpr int WarpsInMmaN = cute::ceil_div(WarpsInMma, WarpsInMmaM); -+ using WarpCount = cutlass::gemm::GemmShape; -+ using WarpShape = cutlass::gemm::GemmShape< -+ cute::size<0>(typename CollectiveMainloop::TiledMma::TiledShape_MNK{}) / WarpsInMmaM, -+ cute::size<1>(typename CollectiveMainloop::TiledMma::TiledShape_MNK{}) / WarpsInMmaN, -+ cute::size<2>(typename CollectiveMainloop::TiledMma::TiledShape_MNK{})>; -+ -+ static int constexpr kStages = CollectiveMainloop::DispatchPolicy::Stages; -+ -+ // Inspect TiledCopy for A and B to compute the alignment size -+ static int constexpr kAlignmentA = gemm::detail::get_alignment_count_from_gmem_tiled_copy< -+ typename CollectiveMainloop::GmemTiledCopyA, ElementA>(); -+ static int constexpr kAlignmentB = gemm::detail::get_alignment_count_from_gmem_tiled_copy< -+ typename CollectiveMainloop::GmemTiledCopyB, ElementB>(); -+ -+ // NOTE: 3.0 DefaultEpilogues don't support vectorized stores (yet) -+ static int constexpr kAlignmentC = 1; -+ static int constexpr kAlignmentD = 1; -+ -+ using EpilogueOutputOp = typename CollectiveEpilogue::ThreadEpilogueOp; -+ -+ // Split-K preserves splits that are 128b aligned -+ static int constexpr kSplitKAlignment = std::max( -+ 128 / sizeof_bits::value, 128 / sizeof_bits::value); -+ -+ /// Argument structure: User API -+ using Arguments = typename GemmKernel::Arguments; -+ /// Argument structure: Kernel API -+ using Params = typename GemmKernel::Params; -+ -+private: -+ -+ /// Kernel API parameters object -+ Params params_; -+ -+public: -+ -+ /// Determines whether the GEMM can execute the given problem. -+ static Status -+ can_implement(Arguments const& args) { -+ if (GemmKernel::can_implement(args)) { -+ return Status::kSuccess; -+ } -+ else { -+ return Status::kInvalid; -+ } -+ } -+ -+ /// Gets the workspace size -+ static size_t -+ get_workspace_size(Arguments const& args) { -+ size_t workspace_bytes = 0; -+ if (args.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ workspace_bytes += sizeof(int) * size_t(cute::size<0>(TileShape{})) * size_t(cute::size<1>(TileShape{})); -+ } -+ -+ CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); -+ -+ workspace_bytes += GemmKernel::get_workspace_size(args); -+ return workspace_bytes; -+ } -+ -+ /// Computes the grid shape -+ static dim3 -+ get_grid_shape(Arguments const& args) { -+ auto tmp_params = GemmKernel::to_underlying_arguments(args); -+ return GemmKernel::get_grid_shape(tmp_params); -+ } -+ -+ /// Computes the grid shape -+ static dim3 -+ get_grid_shape(Params const& params) { -+ return GemmKernel::get_grid_shape(params); -+ } -+ -+ /// Computes the maximum number of active blocks per multiprocessor -+ static int maximum_active_blocks(int /* smem_capacity */ = -1) { -+ CUTLASS_TRACE_HOST("GemmUniversal::maximum_active_blocks()"); -+ int max_active_blocks = -1; -+ int smem_size = GemmKernel::SharedStorageSize; -+ -+ // first, account for dynamic smem capacity if needed -+ cudaError_t result; -+ if (smem_size >= (48 << 10)) { -+ CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); -+ result = cudaFuncSetAttribute( -+ device_kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ if (cudaSuccess != result) { -+ result = cudaGetLastError(); // to clear the error bit -+ CUTLASS_TRACE_HOST( -+ " cudaFuncSetAttribute() returned error: " -+ << cudaGetErrorString(result)); -+ return -1; -+ } -+ } -+ -+ // query occupancy after setting smem size -+ result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( -+ &max_active_blocks, -+ device_kernel, -+ GemmKernel::MaxThreadsPerBlock, -+ smem_size); -+ -+ if (cudaSuccess != result) { -+ result = cudaGetLastError(); // to clear the error bit -+ CUTLASS_TRACE_HOST( -+ " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " -+ << cudaGetErrorString(result)); -+ return -1; -+ } -+ -+ CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); -+ return max_active_blocks; -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status -+ initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { -+ CUTLASS_TRACE_HOST("GemmUniversal::initialize() - workspace " -+ << workspace << ", stream: " << (stream ? "non-null" : "null")); -+ -+ size_t workspace_bytes = GemmKernel::get_workspace_size(args); -+ CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); -+ -+ if (workspace_bytes) { -+ if (!workspace) { -+ CUTLASS_TRACE_HOST(" error: device workspace must not be null"); -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ if (args.mode == GemmUniversalMode::kGemm) { -+ CUTLASS_TRACE_HOST(" clearing device workspace"); -+ cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream); -+ if (cudaSuccess != result) { -+ result = cudaGetLastError(); // to clear the error bit -+ CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result)); -+ return Status::kErrorInternal; -+ } -+ } -+ } -+ -+ // Initialize the Params structure -+ params_ = GemmKernel::to_underlying_arguments(args, workspace); -+ -+ // account for dynamic smem capacity if needed -+ int smem_size = GemmKernel::SharedStorageSize; -+ if (smem_size >= (48 << 10)) { -+ CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); -+ cudaError_t result = cudaFuncSetAttribute( -+ device_kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ if (cudaSuccess != result) { -+ result = cudaGetLastError(); // to clear the error bit -+ CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); -+ return Status::kErrorInternal; -+ } -+ } -+ return Status::kSuccess; -+ } -+ -+ /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params. -+ Status -+ update(Arguments const& args, void* workspace = nullptr) { -+ CUTLASS_TRACE_HOST("GemmUniversal()::update() - workspace: " << workspace); -+ -+ size_t workspace_bytes = get_workspace_size(args); -+ if (workspace_bytes > 0 && nullptr == workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ params_ = GemmKernel::to_underlying_arguments(args, workspace); -+ return Status::kSuccess; -+ } -+ -+ /// Primary run() entry point API that is static allowing users to create and manage their own params. -+ /// Supplied params struct must be construct by calling GemmKernel::to_underling_arguments() -+ static Status -+ run(Params& params, cudaStream_t stream = nullptr) { -+ CUTLASS_TRACE_HOST("GemmUniversal::run()"); -+ dim3 constexpr block = GemmKernel::get_block_shape(); -+ dim3 const grid = get_grid_shape(params); -+ -+ // configure smem size and carveout -+ int smem_size = GemmKernel::SharedStorageSize; -+ -+ Status launch_result; -+ // Use extended launch API only for mainloops that use it -+ if constexpr(GemmKernel::ArchTag::kMinComputeCapability >= 90) { -+ dim3 cluster(cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), -+ cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), -+ cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})); -+ void const* kernel = (void const*) device_kernel; -+ void* kernel_params[] = {¶ms}; -+ launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params); -+ } -+ else { -+ launch_result = Status::kSuccess; -+ device_kernel<<>>(params); -+ } -+ -+ cudaError_t result = cudaGetLastError(); -+ if (cudaSuccess == result && Status::kSuccess == launch_result) { -+ return Status::kSuccess; -+ } -+ else { -+ CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); -+ return Status::kErrorInternal; -+ } -+ } -+ -+ // -+ // Non-static launch overloads that first create and set the internal params struct of this kernel handle. -+ // -+ -+ /// Launches the kernel after first constructing Params internal state from supplied arguments. -+ Status -+ run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { -+ Status status = initialize(args, workspace, stream); -+ if (Status::kSuccess == status) { -+ status = run(params_, stream); -+ } -+ return status; -+ } -+ -+ /// Launches the kernel after first constructing Params internal state from supplied arguments. -+ Status -+ operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { -+ return run(args, workspace, stream); -+ } -+ -+ /// Overload that allows a user to re-launch the same kernel without updating internal params struct. -+ Status -+ run(cudaStream_t stream = nullptr) { -+ return run(params_, stream); -+ } -+ -+ /// Overload that allows a user to re-launch the same kernel without updating internal params struct. -+ Status -+ operator()(cudaStream_t stream = nullptr) const { -+ return run(params_, stream); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+////////////////////////////// CUTLASS 2.x API ///////////////////////////////// -+//////////////////////////////////////////////////////////////////////////////// -+ -+template -+class GemmUniversalAdapter< -+ GemmKernel_, -+ std::enable_if_t::value>> -+{ -+public: -+ -+ using GemmKernel = GemmKernel_; -+ -+ static bool const kInternalTranspose = -+ platform::is_same::value; -+ -+ using ThreadblockShape = typename GemmKernel::Mma::Shape; -+ using WarpShape = typename GemmKernel::WarpShape; -+ using InstructionShape = typename GemmKernel::InstructionShape; -+ -+ // warp-level, arch-level (instruction), math operator -+ using WarpMmaOperator = typename GemmKernel::Mma::Policy::Operator; -+ using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; -+ using MathOperator = typename WarpMmaOperator::MathOperator; -+ -+ // Operator class and arch tag extract bottom-up -+ // set it for top-level gemm device-level template -+ using OperatorClass = typename WarpMmaOperator::OperatorClass; -+ using ArchTag = typename WarpMmaOperator::ArchTag; -+ -+ // Type, layout, and complex transform deliberately exchanged with B -+ using MapArguments = kernel::detail::MapArguments< -+ typename GemmKernel::ElementA, -+ typename GemmKernel::LayoutA, -+ GemmKernel::kTransformA, -+ GemmKernel::kAlignmentA, -+ typename GemmKernel::ElementB, -+ typename GemmKernel::LayoutB, -+ GemmKernel::kTransformB, -+ GemmKernel::kAlignmentB, -+ typename GemmKernel::LayoutC, -+ kInternalTranspose -+ >; -+ -+ using ElementA = typename MapArguments::ElementA; -+ using LayoutA = typename MapArguments::LayoutA; -+ static ComplexTransform const kTransformA = MapArguments::kTransformA; -+ static int const kAlignmentA = MapArguments::kAlignmentA; -+ -+ using ElementB = typename MapArguments::ElementB; -+ using LayoutB = typename MapArguments::LayoutB; -+ static ComplexTransform const kTransformB = MapArguments::kTransformB; -+ static int const kAlignmentB = MapArguments::kAlignmentB; -+ -+ using ElementC = typename GemmKernel::ElementC; -+ using LayoutC = typename MapArguments::LayoutC; -+ static int const kAlignmentC = GemmKernel::kAlignmentC; -+ -+ using TensorRefA = TensorRef; -+ using TensorRefB = TensorRef; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ -+ static int const kStages = GemmKernel::Mma::kStages; -+ -+ using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; -+ using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; -+ using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; -+ using UnderlyingOperator = GemmUniversalBase; -+ using Arguments = typename UnderlyingOperator::Arguments; -+ -+private: -+ -+ UnderlyingOperator underlying_operator_; -+ -+public: -+ -+ /// Constructs the GEMM. -+ GemmUniversalAdapter() { } -+ -+ /// Helper to construct a transposed equivalent for the underying GEMM operator -+ static Arguments to_underlying_arguments(Arguments const &args) { -+ if (kInternalTranspose) { -+ return args.transposed_problem(); -+ } -+ else { -+ return args; -+ } -+ } -+ -+ /// Determines whether the GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ return UnderlyingOperator::can_implement(to_underlying_arguments(args)); -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); -+ } -+ -+ /// Computes the grid shape -+ static dim3 get_grid_shape(Arguments const &args) { -+ return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); -+ } -+ -+ /// Computes the maximum number of active blocks per multiprocessor -+ static int maximum_active_blocks(int smem_capacity = -1) { -+ return UnderlyingOperator::maximum_active_blocks(smem_capacity); -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); -+ } -+ -+ /// Lightweight update given a subset of arguments. Problem geometry is assumed to -+ /// remain the same. -+ Status update(Arguments const &args) { -+ -+ return underlying_operator_.update(to_underlying_arguments(args)); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass::gemm::device -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/gemm_universal_base.h b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_universal_base.h -new file mode 100644 -index 0000000..cca768a ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_universal_base.h -@@ -0,0 +1,416 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief The universal GEMM accommodates streamk, batched strided, and batched array variants. -+*/ -+ -+ -+#pragma once -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/gemm_universal.h" -+ -+#include "cutlass/gemm/kernel/default_gemm_universal.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+ -+#include "cutlass/trace.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+template -+class GemmUniversalBase { -+public: -+ -+ using GemmKernel = GemmKernel_; -+ using ThreadblockShape = typename GemmKernel::Mma::Shape; -+ -+ using ElementA = typename GemmKernel::ElementA; -+ using LayoutA = typename GemmKernel::LayoutA; -+ using TensorRefA = TensorRef; -+ static ComplexTransform const kTransformA = GemmKernel::kTransformA; -+ -+ using ElementB = typename GemmKernel::ElementB; -+ using LayoutB = typename GemmKernel::LayoutB; -+ using TensorRefB = TensorRef; -+ static ComplexTransform const kTransformB = GemmKernel::kTransformB; -+ -+ using ElementC = typename GemmKernel::ElementC; -+ using LayoutC = typename GemmKernel::LayoutC; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ -+ /// Numerical accumulation element type -+ using ElementAccumulator = typename GemmKernel::Mma::ElementC; -+ -+ using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; -+ using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; -+ using Operator = typename GemmKernel::Operator; -+ -+ /// Argument structure -+ using Arguments = typename GemmKernel::Arguments; -+ -+protected: -+ -+ // -+ // Device properties (uniform across all instances of the current thread) -+ // -+ -+ // Device ordinal -+ thread_local static int device_ordinal_; -+ -+ /// Device SM count -+ thread_local static int device_sms_; -+ -+ /// Kernel SM occupancy (in thread blocks) -+ thread_local static int sm_occupancy_; -+ -+ /// Kernel dynamic shared memory allocation requirement -+ thread_local static int smem_size_; -+ -+ /// Initialize static thread-local members for the thread's current device, -+ /// if necessary. -+ static Status init_device_props() -+ { -+ CUTLASS_TRACE_HOST("GemmUniversalBase::init_device_props()"); -+ -+ cudaError_t cudart_result; -+ -+ // Get current device ordinal -+ int current_ordinal; -+ cudart_result = cudaGetDevice(¤t_ordinal); -+ if (cudart_result != cudaSuccess) { -+ CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " << cudaGetErrorString(cudart_result)); -+ return Status::kErrorInternal; -+ } -+ -+ // Done if matches the current static member -+ if (current_ordinal == device_ordinal_) { -+ // Already initialized -+ return Status::kSuccess; -+ } -+ -+ // Update SM count member -+ cudart_result = cudaDeviceGetAttribute (&device_sms_, cudaDevAttrMultiProcessorCount, current_ordinal); -+ if (cudart_result != cudaSuccess) { -+ CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " << cudaGetErrorString(cudart_result)); -+ return Status::kErrorInternal; -+ } -+ -+ // Update the kernel function's shared memory configuration for the current device -+ smem_size_ = int(sizeof(typename GemmKernel::SharedStorage)); -+ -+ // If requires more than 48KB: configure for extended, dynamic shared memory -+ if (smem_size_ >= (48 << 10)) -+ { -+ cudart_result = cudaFuncSetAttribute( -+ Kernel2, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size_); -+ if (cudart_result != cudaSuccess) { -+ CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(cudart_result)); -+ return Status::kErrorInternal; -+ } -+ -+ cudart_result = cudaFuncSetAttribute( -+ Kernel2, -+ cudaFuncAttributePreferredSharedMemoryCarveout, 100); // 100% shared memory -+ if (cudart_result != cudaSuccess) { -+ CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(cudart_result)); -+ return Status::kErrorInternal; -+ } -+ } -+ -+ // Update SM occupancy member -+ cudart_result = cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags( -+ &sm_occupancy_, -+ Kernel2, -+ GemmKernel::kThreadCount, -+ smem_size_, -+ cudaOccupancyDisableCachingOverride); -+ if (cudart_result != cudaSuccess) { -+ CUTLASS_TRACE_HOST(" cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags() returned error " << cudaGetErrorString(cudart_result)); -+ return Status::kErrorInternal; -+ } -+ -+ // Update device ordinal member on success -+ device_ordinal_ = current_ordinal; -+ -+ CUTLASS_TRACE_HOST(" " -+ "device_ordinal: (" << device_ordinal_ << "), " -+ "device_sms: (" << device_sms_ << "), " -+ "sm_occupancy: (" << sm_occupancy_ << ") " -+ "smem_size: (" << smem_size_ << ") " -+ "GemmKernel::kThreadCount: (" << GemmKernel::kThreadCount << ")"); -+ -+ return Status::kSuccess; -+ } -+ -+ -+protected: -+ -+ // -+ // Instance data members -+ // -+ -+ /// Kernel parameters -+ typename GemmKernel::Params params_; -+ -+ -+ /// Initialize params member -+ Status init_params(Arguments const &args) -+ { -+ // Initialize static device properties, if necessary -+ Status result = init_device_props(); -+ if (result != Status::kSuccess) { -+ return result; -+ } -+ -+ // Initialize params member -+ params_ = typename GemmKernel::Params(args, device_sms_, sm_occupancy_); -+ return Status::kSuccess; -+ } -+ -+public: -+ -+ //--------------------------------------------------------------------------------------------- -+ // Stateless API -+ //--------------------------------------------------------------------------------------------- -+ -+ /// Determines whether the GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) -+ { -+ CUTLASS_TRACE_HOST("GemmUniversalBase::can_implement()"); -+ -+ // Initialize static kernel and device properties, if necessary. -+ Status result = init_device_props(); -+ if (result != Status::kSuccess) { -+ return result; -+ } -+ -+ dim3 grid = get_grid_shape(args); -+ -+ if (!(grid.y <= std::numeric_limits::max() && -+ grid.z <= std::numeric_limits::max())) -+ { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ return GemmKernel::can_implement(args); -+ } -+ -+ -+ /// Returns the workspace size (in bytes) needed for the problem -+ /// geometry expressed by these arguments -+ static size_t get_workspace_size(Arguments const &args) -+ { -+ CUTLASS_TRACE_HOST("GemmUniversalBase::get_workspace_size()"); -+ -+ // Initialize parameters from args -+ GemmUniversalBase base; -+ if (base.init_params(args) != Status::kSuccess) { -+ return 0; -+ } -+ -+ // Get size from parameters -+ size_t workspace_bytes = base.params_.get_workspace_size(); -+ -+ CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); -+ return workspace_bytes; -+ } -+ -+ -+ /// Returns the grid extents in thread blocks to launch -+ static dim3 get_grid_shape(Arguments const &args) -+ { -+ CUTLASS_TRACE_HOST("GemmUniversalBase::get_grid_shape()"); -+ -+ // Initialize parameters from args -+ GemmUniversalBase base; -+ if (base.init_params(args) != Status::kSuccess) { -+ return dim3(0,0,0); -+ } -+ -+ // Get dims from parameters -+ dim3 grid_dims = base.params_.get_grid_dims(); -+ -+ CUTLASS_TRACE_HOST( -+ " tiled_shape: " << base.params_.get_tiled_shape() << "\n" -+ << " grid_dims: {" << grid_dims << "}"); -+ -+ return grid_dims; -+ } -+ -+ -+ /// Returns the maximum number of active thread blocks per multiprocessor -+ static int maximum_active_blocks() -+ { -+ CUTLASS_TRACE_HOST("GemmUniversalBase::maximum_active_blocks()"); -+ -+ // Initialize static device properties, if necessary -+ if (init_device_props() != Status::kSuccess) { -+ return -1; -+ } -+ -+ CUTLASS_TRACE_HOST(" max_active_blocks: " << sm_occupancy_); -+ return sm_occupancy_; -+ } -+ -+ -+ //--------------------------------------------------------------------------------------------- -+ // Stateful API -+ //--------------------------------------------------------------------------------------------- -+ -+ /// Initializes GEMM state from arguments and workspace memory -+ Status initialize( -+ Arguments const &args, -+ void *workspace, -+ cudaStream_t stream = nullptr) -+ { -+ CUTLASS_TRACE_HOST("GemmUniversalBase::initialize() - workspace " -+ << workspace << ", stream: " << (stream ? "non-null" : "null")); -+ -+ // Initialize parameters from args -+ Status result = init_params(args); -+ if (result != Status::kSuccess) { -+ return result; -+ } -+ -+ // Assign and prepare workspace memory -+ return params_.init_workspace(workspace, stream); -+ } -+ -+ -+ /// Lightweight update given a subset of arguments. Problem geometry is assumed to -+ /// remain the same. -+ Status update(Arguments const &args) -+ { -+ CUTLASS_TRACE_HOST("GemmUniversalBase()::update()"); -+ params_.update(args); -+ return Status::kSuccess; -+ } -+ -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) -+ { -+ CUTLASS_TRACE_HOST("GemmUniversalBase::run()"); -+ -+ // Configure grid and block dimensions -+ dim3 block(GemmKernel::kThreadCount, 1, 1); -+ dim3 grid = params_.get_grid_dims(); -+ -+ // Launch kernel -+ CUTLASS_TRACE_HOST(" " -+ "grid: (" << grid << "), " -+ "block: (" << block << "), " -+ "SMEM: (" << smem_size_ << ")"); -+ -+ Kernel2<<>>(params_); -+ -+ // Query for errors -+ cudaError_t result = cudaGetLastError(); -+ if (result != cudaSuccess) { -+ CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); -+ return Status::kErrorInternal; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) -+ { -+ return run(stream); -+ } -+ -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) -+ { -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Static initializers -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Device ordinal -+template -+thread_local int GemmUniversalBase::device_ordinal_ = -1; -+ -+/// Device SM count -+template -+thread_local int GemmUniversalBase::device_sms_ = -1; -+ -+/// Kernel SM occupancy (in thread blocks) -+template -+thread_local int GemmUniversalBase::sm_occupancy_ = -1; -+ -+/// Kernel dynamic shared memory allocation requirement -+template -+thread_local int GemmUniversalBase::smem_size_ = -1; -+ -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/gemm_universal_with_broadcast.h b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_universal_with_broadcast.h -new file mode 100644 -index 0000000..34b3f6c ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_universal_with_broadcast.h -@@ -0,0 +1,386 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Template for a GEMM kernel that can broadcast bias vector in the -+ epigloue. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/gemm/kernel/gemm_universal.h" -+ -+#include "cutlass/gemm/kernel/default_gemm_universal.h" -+#include "cutlass/gemm/kernel/default_gemm_with_broadcast.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+#include "cutlass/gemm/device/gemm_universal_base.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/*! -+ The universal GEMM with a broadcast epilogue. -+ Supports -+*/ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_ = ElementC_, -+ /// Operator class tag -+ typename OperatorClass_ = arch::OpClassSimt, -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. -+ typename ArchTag_ = arch::Sm70, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::InstructionShape, -+ /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' -+ typename EpilogueOutputOp_ = cutlass::epilogue::thread::LinearCombinationBiasElementwise< -+ ElementC_, ElementAccumulator_, ElementAccumulator_, -+ ElementC_, ElementC_, 128 / cutlass::sizeof_bits::value>, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, -+ /// Number of stages used in the pipelined mainloop -+ int Stages = -+ DefaultGemmConfiguration::kStages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = -+ DefaultGemmConfiguration::kAlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = -+ DefaultGemmConfiguration::kAlignmentB, -+ /// Operation performed by GEMM -+ typename Operator_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::Operator, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB = ComplexTransform::kNone -+> -+class GemmUniversalWithBroadcast : -+ public GemmUniversalBase< -+ typename kernel::DefaultGemmWithBroadcast< -+ ElementA_, -+ LayoutA_, -+ TransformA, -+ AlignmentA, -+ ElementB_, -+ LayoutB_, -+ TransformB, -+ AlignmentB, -+ ElementC_, -+ LayoutC_, -+ ElementAccumulator_, -+ OperatorClass_, -+ ArchTag_, -+ ThreadblockShape_, -+ WarpShape_, -+ InstructionShape_, -+ EpilogueOutputOp_, -+ ThreadblockSwizzle_, -+ Stages, -+ Operator_ -+ >::GemmKernel -+ > { -+ -+ public: -+ -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ -+ using Base = GemmUniversalBase< -+ typename kernel::DefaultGemmWithBroadcast< -+ ElementA_, -+ LayoutA_, -+ TransformA, -+ AlignmentA, -+ ElementB_, -+ LayoutB_, -+ TransformB, -+ AlignmentB, -+ ElementC_, -+ LayoutC_, -+ ElementAccumulator_, -+ OperatorClass_, -+ ArchTag_, -+ ThreadblockShape_, -+ WarpShape_, -+ InstructionShape_, -+ EpilogueOutputOp_, -+ ThreadblockSwizzle_, -+ Stages, -+ Operator_ -+ >::GemmKernel -+ >; -+ -+ using Arguments = typename Base::Arguments; -+ using GemmKernel = typename Base::GemmKernel; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Parital specialization for column-major output exchanges problem size and operand. -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB, -+ /// Operation performed by GEMM -+ typename Operator_, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB> -+class GemmUniversalWithBroadcast { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ using ElementC = ElementC_; -+ using LayoutC = layout::ColumnMajor; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ -+ using UnderlyingOperator = typename GemmUniversalWithBroadcast< -+ ElementB, -+ typename layout::LayoutTranspose::type, -+ ElementA, -+ typename layout::LayoutTranspose::type, -+ ElementC, -+ layout::RowMajor, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ kAlignmentB, -+ kAlignmentA, -+ Operator, -+ kTransformB, -+ kTransformA -+ >::Base; -+ -+ using GemmKernel = typename UnderlyingOperator::GemmKernel; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ -+ /// Argument structure -+ using Arguments = typename UnderlyingOperator::Arguments; -+ -+private: -+ -+ UnderlyingOperator underlying_operator_; -+ -+public: -+ -+ /// Constructs the GEMM. -+ GemmUniversalWithBroadcast() { } -+ -+ /// Helper to construct a transposed equivalent for the underying GEMM operator -+ static Arguments to_underlying_arguments(Arguments const &args) { -+ return args.transposed_problem(); -+ } -+ -+ /// Determines whether the GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ return UnderlyingOperator::can_implement(to_underlying_arguments(args)); -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); -+ } -+ -+ /// Computes the grid shape -+ static dim3 get_grid_shape(Arguments const &args) { -+ return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); -+ } -+ -+ /// Computes the maximum number of active blocks per multiprocessor -+ static int maximum_active_blocks(int smem_capacity = -1) { -+ return UnderlyingOperator::maximum_active_blocks(smem_capacity); -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ return underlying_operator_.update(to_underlying_arguments(args), workspace); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/gemm_with_k_reduction.h b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_with_k_reduction.h -new file mode 100644 -index 0000000..c671d7c ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/gemm_with_k_reduction.h -@@ -0,0 +1,415 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a GEMM kernel that can reduce one of the input matrix -+ into a vector along the K dimension. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/gemm/kernel/gemm_with_k_reduction.h" -+ -+#include "cutlass/gemm/kernel/default_gemm_with_k_reduction.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+#include "cutlass/gemm/device/gemm_universal_base.h" -+ -+#include "cutlass/layout/permute.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/*! -+ The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and -+ batched array variants. -+*/ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_ = ElementC_, -+ /// Operator class tag -+ typename OperatorClass_ = arch::OpClassSimt, -+ /// Reduce A or B operand along the K dimension -+ bool ReduceKForA_ = true, -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. -+ typename ArchTag_ = arch::Sm70, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, -+ /// Number of stages used in the pipelined mainloop -+ int Stages = -+ DefaultGemmConfiguration::kStages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = -+ DefaultGemmConfiguration::kAlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = -+ DefaultGemmConfiguration::kAlignmentB, -+ /// Operation performed by GEMM -+ typename Operator_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::Operator, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ /// Gather operand A by using an index array -+ bool GatherA = false, -+ /// Gather operand B by using an index array -+ bool GatherB = false, -+ /// Scatter result D by using an index array -+ bool ScatterD = false, -+ /// Permute result D -+ typename PermuteDLayout = layout::NoPermute -+> -+class GemmWithKReduction : -+ public GemmUniversalBase< -+ typename kernel::DefaultGemmWithKReduction< -+ ElementA_, -+ LayoutA_, -+ TransformA, -+ AlignmentA, -+ ElementB_, -+ LayoutB_, -+ TransformB, -+ AlignmentB, -+ ElementC_, -+ LayoutC_, -+ ElementAccumulator_, -+ OperatorClass_, -+ ReduceKForA_, -+ ArchTag_, -+ ThreadblockShape_, -+ WarpShape_, -+ InstructionShape_, -+ EpilogueOutputOp_, -+ ThreadblockSwizzle_, -+ Stages, -+ Operator_, -+ SharedMemoryClearOption::kNone -+ >::GemmKernel -+ > { -+ -+ public: -+ -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static constexpr int kStages = Stages; -+ static constexpr int kAlignmentA = AlignmentA; -+ static constexpr int kAlignmentB = AlignmentB; -+ static constexpr int kAlignmentC = EpilogueOutputOp::kCount; -+ static constexpr ComplexTransform kTransformA = TransformA; -+ static constexpr ComplexTransform kTransformB = TransformB; -+ -+ using Base = GemmUniversalBase< -+ typename kernel::DefaultGemmWithKReduction< -+ ElementA_, -+ LayoutA_, -+ TransformA, -+ AlignmentA, -+ ElementB_, -+ LayoutB_, -+ TransformB, -+ AlignmentB, -+ ElementC_, -+ LayoutC_, -+ ElementAccumulator_, -+ OperatorClass_, -+ ReduceKForA_, -+ ArchTag_, -+ ThreadblockShape_, -+ WarpShape_, -+ InstructionShape_, -+ EpilogueOutputOp_, -+ ThreadblockSwizzle_, -+ Stages, -+ Operator_, -+ SharedMemoryClearOption::kNone -+ >::GemmKernel -+ >; -+ -+ using Arguments = typename Base::Arguments; -+ using GemmKernel = typename Base::GemmKernel; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Parital specialization for column-major output exchanges problem size and operand. -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Reduce A or B operand along the K dimension -+ bool ReduceKForA_, -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB, -+ /// Operation performed by GEMM -+ typename Operator_, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Gather operand A by using an index array -+ bool GatherA, -+ /// Gather operand B by using an index array -+ bool GatherB, -+ /// Scatter result D by using an index array -+ bool ScatterD, -+ /// Permute result D -+ typename PermuteDLayout -+> -+class GemmWithKReduction { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ using ElementC = ElementC_; -+ using LayoutC = layout::ColumnMajor; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ -+ using UnderlyingOperator = typename GemmWithKReduction< -+ ElementB, -+ typename layout::LayoutTranspose::type, -+ ElementA, -+ typename layout::LayoutTranspose::type, -+ ElementC, -+ layout::RowMajor, -+ ElementAccumulator, -+ OperatorClass, -+ !ReduceKForA_, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ kAlignmentB, -+ kAlignmentA, -+ Operator, -+ kTransformB, -+ kTransformA, -+ GatherB, -+ GatherA, -+ ScatterD, -+ PermuteDLayout -+ >::Base; -+ -+ using GemmKernel = typename UnderlyingOperator::GemmKernel; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ -+ /// Argument structure -+ using Arguments = typename UnderlyingOperator::Arguments; -+ -+private: -+ -+ UnderlyingOperator underlying_operator_; -+ -+public: -+ -+ /// Constructs the GEMM. -+ GemmWithKReduction() = default; -+ -+ /// Helper to construct a transposed equivalent for the underying GEMM operator -+ static Arguments to_underlying_arguments(Arguments const &args) { -+ return args.transposed_problem(); -+ } -+ -+ /// Determines whether the GEMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ return UnderlyingOperator::can_implement(to_underlying_arguments(args)); -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); -+ } -+ -+ /// Computes the grid shape -+ static dim3 get_grid_shape(Arguments const &args) { -+ return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); -+ } -+ -+ /// Computes the maximum number of active blocks per multiprocessor -+ static int maximum_active_blocks(int smem_capacity = -1) { -+ return UnderlyingOperator::maximum_active_blocks(smem_capacity); -+ } -+ -+ /// Initializes GEMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ return underlying_operator_.update(to_underlying_arguments(args), workspace); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/gemv.h b/3rdparty/cutlass/include/cutlass/gemm/device/gemv.h -new file mode 100644 -index 0000000..c62168f ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/gemv.h -@@ -0,0 +1,174 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/gemm/kernel/gemm_universal.h" -+ -+#include "cutlass/gemm/kernel/default_gemm_universal.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+#include "cutlass/gemm/device/gemm_universal_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class Gemv { -+public: -+ -+ using GemvKernel = GemvKernel_; -+ -+ -+ using ElementA = typename GemvKernel::ElementA; -+ using LayoutA = typename GemvKernel::LayoutA; -+ using ElementB = typename GemvKernel::ElementB; -+ using ElementC = typename GemvKernel::ElementC; -+ -+ using ElementAccumulator = typename GemvKernel::ElementAccumulator; -+ using EpilogueOutputOp = typename GemvKernel::EpilogueOutputOp; -+ -+ static ComplexTransform const kTransformA = GemvKernel::kTransformA; -+ static ComplexTransform const kTransformB = GemvKernel::kTransformB; -+ -+ static int const kThreadCount = GemvKernel::kThreadCount; -+ static int const kStages = GemvKernel::kStages; -+ -+ static int const kAlignmentA = GemvKernel::kAlignmentA; -+ static int const kAlignmentB = GemvKernel::kAlignmentB; -+ static int const kAlignmentC = GemvKernel::kAlignmentC; -+ -+ using Arguments = typename GemvKernel::Arguments; -+ using Params = typename GemvKernel::Params; -+ -+private: -+ -+ Params params_; -+ -+public: -+ -+ /// Constructs the Gemv. -+ Gemv() { } -+ -+ /// Determines whether the Gemv can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ return GemvKernel::can_implement(args); -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ return 0; -+ } -+ -+ /// Computes the grid shape -+ static dim3 get_grid_shape(Arguments const &args) { -+ return dim3((args.problem_size.row() + (kThreadCount - 1)) / kThreadCount, 1, args.batch_count % 65565); -+ } -+ -+ /// Initializes Gemv state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ params_ = Params(args); -+ return Status::kSuccess; -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ return params_.update(args); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ dim3 grid = get_grid_shape(params_); -+ dim3 block(GemvKernel::kThreadCount, 1, 1); -+ -+ int smem_size = int(sizeof(typename GemvKernel::SharedStorage)); -+ -+ // Launch -+ cutlass::Kernel<<>>(params_); -+ -+ // -+ // Query for errors -+ // -+ cudaError_t result = cudaGetLastError(); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/rank_2k.h b/3rdparty/cutlass/include/cutlass/gemm/device/rank_2k.h -new file mode 100644 -index 0000000..d333ffa ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/rank_2k.h -@@ -0,0 +1,547 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined Rank2K kernel. Does not compute batching or support split-K. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/gemm/kernel/rank_2k_universal.h" -+ -+#include "cutlass/gemm/kernel/default_rank_2k_universal.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_ = ElementC_, -+ /// Operator class tag -+ typename OperatorClass_ = arch::OpClassTensorOp, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_ = arch::Sm80, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementA_, ElementC_, -+ ElementAccumulator_>::ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementA_, ElementC_, -+ ElementAccumulator_>::WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementA_, ElementC_, -+ ElementAccumulator_>::InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementA_, ElementC_, -+ ElementAccumulator_>::EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_ = -+ typename threadblock::GemmIdentityThreadblockSwizzle<>, -+ /// Number of stages used in the pipelined mainloop -+ int Stages = -+ DefaultGemmConfiguration::kStages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = -+ DefaultGemmConfiguration::kAlignmentA, -+ -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = -+ DefaultGemmConfiguration::kAlignmentB, -+ /// If true, kernel supports split-K with serial reduction -+ bool SplitKSerial = false, -+ /// Operation performed by SYRK -+ typename Operator_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::Operator, -+ /// Complex elementwise transformation -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Complex elementwise transformation -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ /// Blas3 computation mode (symmetric/hermitian) -+ BlasMode BlasMode_ = BlasMode::kSymmetric> -+class Rank2K { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static FillMode const kFillModeC = FillModeC; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ static bool const kSplitKSerial = SplitKSerial; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ static BlasMode const kBlasMode = BlasMode_; -+ static int const kUpdateRank = 2; -+ -+ // static asserts for rank 2k update kernel -+ static_assert(platform::is_same::value, -+ "Rank 2K update operator support same layouts for operandA and B"); -+ -+ /// Define the kernel -+ using Rank2Kkernel = typename kernel::DefaultRank2KUniversal< -+ ElementA, -+ LayoutA, -+ kTransformA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kTransformB, -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ kFillModeC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ kStages, -+ kSplitKSerial, -+ Operator, -+ kBlasMode -+ >::Rank2Kkernel; -+ -+ using Arguments = typename Rank2Kkernel::Arguments; -+ -+private: -+ -+ /// Kernel parameters object -+ typename Rank2Kkernel::Params params_; -+public: -+ -+ /// Constructs the SYRK. -+ Rank2K() { } -+ -+ /// Determines whether the SYRK can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ if (!kSplitKSerial && args.batch_count > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ Status status = Rank2Kkernel::can_implement(args); -+ -+ if (FillModeC != FillMode::kLower && FillModeC != FillMode::kUpper) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ size_t bytes = 0; -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.batch_count); -+ -+ if (kSplitKSerial && args.batch_count > 1) { -+ -+ bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); -+ } -+ -+ return bytes; -+ } -+ -+ /// Initializes SYRK state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.batch_count); -+ -+ if (kSplitKSerial) { -+ if (args.batch_count > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ size_t bytes = get_workspace_size(args); -+ -+ cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ } -+ else { -+ -+ if (args.batch_count > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ int gemm_k_size = args.problem_size.k(); -+ -+ // Initialize the Params structure -+ params_ = typename Rank2Kkernel::Params{ -+ args, -+ grid_tiled_shape, -+ gemm_k_size, -+ static_cast(workspace) -+ }; -+ -+ int smem_size = int(sizeof(typename Rank2Kkernel::SharedStorage)); -+ -+ if (smem_size >= (48 << 10)) { -+ cudaError_t result = cudaFuncSetAttribute(Kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ if (kSplitKSerial && args.batch_count > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ } -+ -+ size_t workspace_bytes = get_workspace_size(args); -+ -+ if (workspace_bytes && !workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ params_.update(args, workspace); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); -+ dim3 block(Rank2Kkernel::kThreadCount, 1, 1); -+ -+ int smem_size = int(sizeof(typename Rank2Kkernel::SharedStorage)); -+ -+ cutlass::Kernel<<>>(params_); -+ -+ cudaError_t result = cudaGetLastError(); -+ -+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Parital specialization for column-major output exchange operand. -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB, -+ /// If true, kernel supports split-K with serial reduction -+ bool SplitKSerial, -+ /// Operation performed by Rank2K update kernel -+ typename Operator_, -+ /// Complex elementwise transformation -+ ComplexTransform TransformA, -+ /// Complex elementwise transformation -+ ComplexTransform TransformB, -+ /// Blas3 computation mode (symmetric/hermitian) -+ BlasMode BlasMode_ -+ > -+class Rank2K { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using ElementC = ElementC_; -+ using LayoutC = layout::ColumnMajor; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static FillMode const kFillModeC = FillModeC; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ static bool const kSplitKSerial = SplitKSerial; -+ static BlasMode const kBlasMode = BlasMode_; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ static int const kUpdateRank = 2; -+ -+ /// Define the kernel -+ using UnderlyingOperator = typename cutlass::gemm::device::Rank2K< -+ ElementB, -+ LayoutB, -+ ElementA, -+ LayoutA, -+ ElementC, -+ layout::RowMajor, -+ InvertFillMode::mode, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ kStages, -+ kAlignmentB, -+ kAlignmentA, -+ kSplitKSerial, -+ Operator, -+ kTransformA, -+ kTransformB, -+ kBlasMode -+ >; -+ -+ -+ /// Argument structure -+ using Arguments = typename UnderlyingOperator::Arguments; -+ using Rank2Kkernel = typename UnderlyingOperator::Rank2Kkernel; -+ -+private: -+ -+ UnderlyingOperator underlying_operator_; -+ -+public: -+ -+ /// Constructs the Rank2K. -+ Rank2K() { } -+ -+ /// Helper to construct a transposed equivalent for the underying Rank2K operator -+ static Arguments to_underlying_arguments(Arguments const &args) { -+ return args.transposed_problem(); -+ } -+ -+ /// Determines whether the Rank2K can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ return UnderlyingOperator::can_implement(to_underlying_arguments(args)); -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); -+ } -+ -+ /// Computes the grid shape -+ static dim3 get_grid_shape(Arguments const &args) { -+ return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); -+ } -+ -+ /// Computes the maximum number of active blocks per multiprocessor -+ static int maximum_active_blocks(int smem_capacity = -1) { -+ return UnderlyingOperator::maximum_active_blocks(smem_capacity); -+ } -+ -+ /// Initializes Rank2K state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ return underlying_operator_.update(to_underlying_arguments(args), workspace); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace Rank2K -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/rank_2k_grouped.h b/3rdparty/cutlass/include/cutlass/gemm/device/rank_2k_grouped.h -new file mode 100644 -index 0000000..f38b07a ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/rank_2k_grouped.h -@@ -0,0 +1,63 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief Device-level grouped Rank2K. -+*/ -+ -+#pragma once -+ -+#include "cutlass/gemm/device/base_grouped.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Rank2K Grouped -+template -+class Rank2KGrouped : public BaseGrouped { -+public: -+ using Rank2Kkernel = Rank2Kkernel_; -+ static const cutlass::FillMode kFillModeC = Rank2Kkernel::kFillModeC; -+ static const cutlass::BlasMode kBlasMode = Rank2Kkernel::kBlasMode; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/rank_k.h b/3rdparty/cutlass/include/cutlass/gemm/device/rank_k.h -new file mode 100644 -index 0000000..a2101a7 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/rank_k.h -@@ -0,0 +1,509 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined RankK kernel. Does not compute batching or support split-K. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/gemm/kernel/rank_k_universal.h" -+ -+#include "cutlass/gemm/kernel/default_rank_k_universal.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_ = ElementC_, -+ /// Operator class tag -+ typename OperatorClass_ = arch::OpClassTensorOp, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_ = arch::Sm80, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementA_, ElementC_, -+ ElementAccumulator_>::ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementA_, ElementC_, -+ ElementAccumulator_>::WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementA_, ElementC_, -+ ElementAccumulator_>::InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementA_, ElementC_, -+ ElementAccumulator_>::EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_ = -+ typename threadblock::GemmIdentityThreadblockSwizzle<>, -+ /// Number of stages used in the pipelined mainloop -+ int Stages = -+ DefaultGemmConfiguration::kStages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = -+ DefaultGemmConfiguration::kAlignmentA, -+ /// If true, kernel supports split-K with serial reduction -+ bool SplitKSerial = false, -+ /// Operation performed by SYRK -+ typename Operator_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementA_, ElementC_, -+ ElementAccumulator_>::Operator, -+ /// Complex elementwise transformation -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Blas3 computation mode (symmetric/hermitian) -+ BlasMode BlasMode_ = BlasMode::kSymmetric> -+class RankK { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static FillMode const kFillModeC = FillModeC; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ static bool const kSplitKSerial = SplitKSerial; -+ static ComplexTransform const kTransformA = TransformA; -+ static BlasMode const kBlasMode = BlasMode_; -+ static int const kUpdateRank = 1; -+ -+ /// Define the kernel -+ using RankKkernel = typename kernel::DefaultRankKUniversal< -+ ElementA, -+ LayoutA, -+ kTransformA, -+ kAlignmentA, -+ ElementC, -+ LayoutC, -+ kFillModeC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ kStages, -+ kSplitKSerial, -+ Operator, -+ kBlasMode -+ >::RankKkernel; -+ -+ using Arguments = typename RankKkernel::Arguments; -+ -+private: -+ -+ /// Kernel parameters object -+ typename RankKkernel::Params params_; -+public: -+ -+ /// Constructs the SYRK. -+ RankK() { } -+ -+ /// Determines whether the SYRK can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ if (!kSplitKSerial && args.batch_count > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ Status status = RankKkernel::can_implement(args); -+ -+ if (FillModeC != FillMode::kLower && FillModeC != FillMode::kUpper) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ size_t bytes = 0; -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.batch_count); -+ -+ if (kSplitKSerial && args.batch_count > 1) { -+ -+ bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); -+ } -+ -+ return bytes; -+ } -+ -+ /// Initializes SYRK state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.batch_count); -+ -+ if (kSplitKSerial) { -+ if (args.batch_count > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ size_t bytes = get_workspace_size(args); -+ -+ cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ } -+ else { -+ -+ if (args.batch_count > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ int gemm_k_size = args.problem_size.k(); -+ -+ // Initialize the Params structure -+ params_ = typename RankKkernel::Params{ -+ args, -+ grid_tiled_shape, -+ gemm_k_size, -+ static_cast(workspace) -+ }; -+ -+ int smem_size = int(sizeof(typename RankKkernel::SharedStorage)); -+ -+ if (smem_size >= (48 << 10)) { -+ cudaError_t result = cudaFuncSetAttribute(Kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ if (kSplitKSerial && args.batch_count > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ } -+ -+ size_t workspace_bytes = get_workspace_size(args); -+ -+ if (workspace_bytes && !workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ params_.update(args, workspace); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); -+ dim3 block(RankKkernel::kThreadCount, 1, 1); -+ -+ int smem_size = int(sizeof(typename RankKkernel::SharedStorage)); -+ -+ cutlass::Kernel<<>>(params_); -+ -+ cudaError_t result = cudaGetLastError(); -+ -+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Parital specialization for column-major output exchange operand. -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA, -+ /// If true, kernel supports split-K with serial reduction -+ bool SplitKSerial, -+ /// Operation performed by RankK update kernel -+ typename Operator_, -+ /// Complex elementwise transformation -+ ComplexTransform TransformA, -+ /// Blas3 computation mode (symmetric/hermitian) -+ BlasMode BlasMode_ -+ > -+class RankK { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using ElementC = ElementC_; -+ using LayoutC = layout::ColumnMajor; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static FillMode const kFillModeC = FillModeC; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ static bool const kSplitKSerial = SplitKSerial; -+ static BlasMode const kBlasMode = BlasMode_; -+ static int const kUpdateRank = 1; -+ -+ // Complex transform for input A matrices (function on input layout) -+ static ComplexTransform const kTransformA = TransformA; -+ -+ /// Define the kernel -+ using UnderlyingOperator = typename cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ layout::RowMajor, -+ InvertFillMode::mode, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ kStages, -+ kAlignmentA, -+ kSplitKSerial, -+ Operator, -+ kTransformA, -+ kBlasMode -+ >; -+ -+ -+ /// Argument structure -+ using Arguments = typename UnderlyingOperator::Arguments; -+ using RankKkernel = typename UnderlyingOperator::RankKkernel; -+ -+private: -+ -+ UnderlyingOperator underlying_operator_; -+ -+public: -+ -+ /// Constructs the RankK. -+ RankK() { } -+ -+ /// Helper to construct a transposed equivalent for the underying RankK operator -+ static Arguments to_underlying_arguments(Arguments const &args) { -+ return args; -+ } -+ -+ /// Determines whether the RankK can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ return UnderlyingOperator::can_implement(to_underlying_arguments(args)); -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); -+ } -+ -+ /// Computes the grid shape -+ static dim3 get_grid_shape(Arguments const &args) { -+ return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); -+ } -+ -+ /// Computes the maximum number of active blocks per multiprocessor -+ static int maximum_active_blocks(int smem_capacity = -1) { -+ return UnderlyingOperator::maximum_active_blocks(smem_capacity); -+ } -+ -+ /// Initializes RankK state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ return underlying_operator_.update(to_underlying_arguments(args), workspace); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace RankK -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/symm.h b/3rdparty/cutlass/include/cutlass/gemm/device/symm.h -new file mode 100755 -index 0000000..57bfeec ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/symm.h -@@ -0,0 +1,602 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined SYMM and HEMM kernels. Does not compute batching or support split-K. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/gemm/kernel/symm_universal.h" -+ -+#include "cutlass/gemm/kernel/default_symm_universal.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Side Mode for A (kLeft or kRight) -+ SideMode SideModeA, -+ /// Fill Mode for A (kLower or kUpper) -+ FillMode FillModeA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_ = ElementC_, -+ /// Operator class tag -+ typename OperatorClass_ = arch::OpClassTensorOp, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_ = arch::Sm80, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_ = epilogue::thread::LinearCombination< -+ ElementC_, -+ 128 / sizeof_bits::value, -+ ElementAccumulator_, -+ ElementAccumulator_, -+ epilogue::thread::ScaleType::OnlyAlphaScaling -+ >, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, -+ /// Number of stages used in the pipelined mainloop -+ int Stages = -+ DefaultGemmConfiguration::kStages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = -+ DefaultGemmConfiguration::kAlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = -+ DefaultGemmConfiguration::kAlignmentB, -+ /// If true, kernel supports split-K with serial reduction -+ bool SplitKSerial = false, -+ /// Operation performed by SYMM -+ typename Operator_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::Operator, -+ /// Blas3 computation mode (symmetric/hermitian) -+ BlasMode BlasMode_ = BlasMode::kSymmetric> -+class Symm { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using ElementAKernel = typename platform::conditional<(SideModeA == SideMode::kRight), ElementB_, ElementA_>::type; -+ using LayoutAKernel = typename platform::conditional<(SideModeA == SideMode::kRight), LayoutB_, LayoutA_>::type; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using ElementBKernel = typename platform::conditional<(SideModeA == SideMode::kRight), ElementA_, ElementB_>::type; -+ using LayoutBKernel = typename platform::conditional<(SideModeA == SideMode::kRight), LayoutA_, LayoutB_>::type; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static SideMode const kSideModeA = SideModeA; -+ static FillMode const kFillModeA = FillModeA; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentAKernel = (SideModeA == SideMode::kRight) ? AlignmentB : AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static int const kAlignmentBKernel = (SideModeA == SideMode::kRight) ? AlignmentA : AlignmentB; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ static bool const kSplitKSerial = SplitKSerial; -+ static BlasMode const kBlasMode = BlasMode_; -+ -+ // static asserts for symm update kernel -+ static_assert(platform::is_same::value, -+ "SYMM update operator support same layouts for operand A and B"); -+ -+ /// Define the kernel -+ using SymmKernel = typename kernel::DefaultSymmUniversal< -+ ElementAKernel, -+ LayoutAKernel, -+ kSideModeA, -+ kFillModeA, -+ kAlignmentAKernel, -+ ElementBKernel, -+ LayoutBKernel, -+ kAlignmentBKernel, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ kStages, -+ kSplitKSerial, -+ Operator, -+ kBlasMode -+ >::SymmKernel; -+ -+ using Arguments = typename SymmKernel::Arguments; -+ -+private: -+ -+ /// Kernel parameters object -+ typename SymmKernel::Params params_; -+public: -+ -+ /// Constructs the SYMM. -+ Symm() { } -+ -+ /// Determines whether the SYMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ if (!kSplitKSerial && args.batch_count > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ Status status = SymmKernel::can_implement(args); -+ -+ if (SideModeA == SideMode::kInvalid) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (FillModeA != FillMode::kLower && FillModeA != FillMode::kUpper) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ size_t bytes = 0; -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.batch_count); -+ -+ if (kSplitKSerial && args.batch_count > 1) { -+ -+ bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); -+ } -+ -+ return bytes; -+ } -+ -+ /// Initializes SYMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.batch_count); -+ -+ if (kSplitKSerial) { -+ if (args.batch_count > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ size_t bytes = get_workspace_size(args); -+ -+ cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ } -+ else { -+ -+ if (args.batch_count > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ int gemm_k_size = args.problem_size.k(); -+ -+ // Swapping argument for A and B, if A was on the right side (problem size doesn't need to change here). -+ if (kSideModeA == SideMode::kRight) { -+ // Initialize the Params structure -+ params_ = typename SymmKernel::Params{ -+ args.swapped_matrices(), -+ grid_tiled_shape, -+ gemm_k_size, -+ static_cast(workspace) -+ }; -+ -+ return Status::kSuccess; -+ } -+ -+ // Initialize the Params structure -+ params_ = typename SymmKernel::Params{ -+ args, -+ grid_tiled_shape, -+ gemm_k_size, -+ static_cast(workspace) -+ }; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ if (kSplitKSerial && args.batch_count > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ } -+ -+ size_t workspace_bytes = get_workspace_size(args); -+ -+ if (workspace_bytes && !workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ params_.update(args, workspace); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); -+ dim3 block(SymmKernel::kThreadCount, 1, 1); -+ -+ int smem_size = int(sizeof(typename SymmKernel::SharedStorage)); -+ -+ if (smem_size >= (48 << 10)) { -+ cudaError_t result = cudaFuncSetAttribute(Kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ cutlass::Kernel<<>>(params_); -+ -+ cudaError_t result = cudaGetLastError(); -+ -+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+//////////////////////////////////////////////////////////////////////////////// -+ -+/******************************************************************************************************** -+ SYMM/HEMM has 4 combinations based on Layouts {RowMajor, ColumnMajor} x Side mode {LeftSide, RightSide} -+ In templates and arguments to cutlass kernel, `matrix A` is always symmetric/hermitian, and `matrix B` is rectangular. -+ (adhering to the cuBLAS convention) -+ -+ Although, cuBLAS SYMM/HEMM only supports ColumnMajor layouts for all matrices (A, B, C/D). -+ -+ For the mainloop and symm kernel, `A` and `B` points to left-side and right-side matrices, respectively. -+ -+ Thus, for LeftSide mode `A` and `B` points to `matrix A` and `matrix B`, respectively. While for -+ the RightSide mode `A` and `B` points to `matrix B` and `matrix A`, respectively. -+ -+ Additionally, CUTLASS GEMM epilogue is always RowMajor, and ColumnMajor output is achieved by -+ transposing the GEMM problem. Thus, ColumnMajor output layout for SYMM/HEMM requires: -+ - Transposing `matrix A` and `matrix B` layouts -+ - Swapping problem size m and n values -+ - Swapping LeftSide and RightSide mode -+ -+ RowMajor output: D = matrix A x matrix B -+ ColumnMajor output: D = matrix A x matrix B -> Transpose (D) = Transpose(matrix B) x Transpose(matrix A) -+ -+ {RowMajor, ColumnMajor} x Side Mode {LeftSide, RightSide} 4 cases: -+ 1. LeftSide mode and RowMajor output (default template) -+ 2. LeftSide mode and ColumnMajor output -+ 3. RightSide mode and RowMajor output -+ 4. RightSide mode and ColumnMajor output -+ -+ Mapping ColumnMajor output layout cases 2 and 4 to RowMajor efficient epilogue implementation: -+ -+ Case 2 -> Case 3: -+ D_col = matrix A x matrix B (LeftSide mode) -+ => Transpose(D_col) = Transpose(matrix B) x Transpose(matrix A) (RightSide mode) -+ -+ swap pointers for `A` and `B` call GEMM mainloop with RowMajor efficient-epilogue -+ -+ Case 4 -> Case 1: -+ D_col = matrix B x matrix A (RightSide mode) -+ => Transpose(D_col) = Transpose(matrix A) x Transpose(matrix B) (LeftSide mode) -+ -+ call GEMM mainloop for with RowMajor efficient-epilogue -+********************************************************************************************************/ -+ -+/// Parital specialization for column-major output exchanges problem size and operand. -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Side Mode for A (kLeft or kRight) -+ SideMode SideModeA, -+ /// Fill Mode for A (kLower or kUpper) -+ FillMode FillModeA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB, -+ /// If true, kernel supports split-K with serial reduction -+ bool SplitKSerial, -+ /// Operation performed by Symm update kernel -+ typename Operator_, -+ /// Blas3 computation mode (symmetric/hermitian) -+ BlasMode BlasMode_ -+ > -+class Symm { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using ElementC = ElementC_; -+ using LayoutC = layout::ColumnMajor; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static SideMode const kSideModeA = SideModeA; -+ static FillMode const kFillModeA = FillModeA; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ static bool const kSplitKSerial = SplitKSerial; -+ static BlasMode const kBlasMode = BlasMode_; -+ -+ /// Define the kernel -+ using UnderlyingOperator = typename cutlass::gemm::device::Symm< -+ ElementA, -+ typename layout::LayoutTranspose::type, -+ InvertSideMode::mode, -+ InvertFillMode::mode, -+ ElementB, -+ typename layout::LayoutTranspose::type, -+ ElementC, -+ layout::RowMajor, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ kStages, -+ kAlignmentA, -+ kAlignmentB, -+ kSplitKSerial, -+ Operator, -+ kBlasMode -+ >; -+ -+ -+ /// Argument structure -+ using Arguments = typename UnderlyingOperator::Arguments; -+ using SymmKernel = typename UnderlyingOperator::SymmKernel; -+ -+private: -+ -+ UnderlyingOperator underlying_operator_; -+ -+public: -+ -+ /// Constructs the Symm. -+ Symm() { } -+ -+ /// Helper to construct a transposed equivalent for the underying SYMM operator -+ static Arguments to_underlying_arguments(Arguments const &args) { -+ return args.transposed_problem_size(); -+ } -+ -+ /// Determines whether the Symm can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ return UnderlyingOperator::can_implement(to_underlying_arguments(args)); -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); -+ } -+ -+ /// Computes the grid shape -+ static dim3 get_grid_shape(Arguments const &args) { -+ return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); -+ } -+ -+ /// Computes the maximum number of active blocks per multiprocessor -+ static int maximum_active_blocks(int smem_capacity = -1) { -+ return UnderlyingOperator::maximum_active_blocks(smem_capacity); -+ } -+ -+ /// Initializes Symm state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ return underlying_operator_.update(to_underlying_arguments(args), workspace); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace Symm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/device/trmm.h b/3rdparty/cutlass/include/cutlass/gemm/device/trmm.h -new file mode 100644 -index 0000000..34816db ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/device/trmm.h -@@ -0,0 +1,758 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a TRMM kernel. Does not compute batching or support split-K. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/gemm/kernel/trmm_universal.h" -+ -+#include "cutlass/gemm/kernel/default_trmm_universal.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/*! Trmm device-level operator. This is an interface to efficient CUTLASS TRMM kernels that may -+ be invoked from host code. -+ -+ The contributions of this class are: -+ -+ 1. At compile time, it maps data types and high-level structural parameters onto -+ specific CUTLASS components. -+ -+ 2. At runtime, it maps logical arguments to TRMM problems to kernel parameters. -+ -+ 3. At runtime, it launches kernels on the device. -+ -+ The intent is to provide a convenient mechanism for interacting with most plausible TRMM -+ configurations for each supported architecture. Consequently, not all parameters are exposed -+ to the top-level interface. Rather, sensible defaults at each level of the CUTLASS hierarchy -+ are selected to tradeoff simplicity of the interface with flexibility. We expect -+ most configurations to be specified at this level. Applications with more exotic requirements -+ may construct their kernels of interest using CUTLASS components at the threadblock, warp, -+ and thread levels of abstraction. -+ -+ CUTLASS exposes computations using the functor design pattern in which objects compose some -+ internal state with an overloaded function call operator. This enables decoupling of -+ initialization from execution, possibly reducing overhead during steady state phases of -+ application execution. -+ -+ CUTLASS device-level operators expose an Arguments structure encompassing each logical -+ input to the computation. This is distinct from the kernel-level Params structure pattern -+ which contains application-specific precomputed state needed by the device code. -+ -+ Example of a CUTLASS TRMM operator implementing the functionality of cuBLAS's STRMM NN -+ is as follows: -+ -+ // -+ // Instantiate the CUTLASS TRMM operator. -+ // -+ -+ cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ > trmm_op; -+ -+ // -+ // Launch the TRMM operation on the device -+ // -+ -+ cutlass::Status status = trmm_op({ -+ cutlass::gemm::GemmUniversalMode, // Trmm Problem Mode -+ {m, n, m/n}, // GemmCoord problem_size (k is based on left- or right-side mode) -+ batch_count, -+ {alpha}, // EpilogueOutputOp::Params epilogue_op_params -+ void const * ptr_A, -+ void const * ptr_B, -+ void const * ptr_C, -+ int64_t batch_stride_A, -+ int64_t batch_stride_B, -+ int64_t batch_stride_C, -+ int lda, -+ int ldb, -+ int ldc -+ }); -+ -+ A simplified view of the template is listed below. -+ -+ template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ -+ /// Side Mode for A (kLeft or kRight) -+ SideMode SideModeA, -+ -+ /// Fill Mode for A (kLower or kUpper) -+ FillMode FillModeA, -+ -+ /// DiagType for A (kNonUnit or kUnit) -+ DiagType DiagTypeA, -+ -+ /// Element type for B matrix operand -+ typename ElementB, -+ -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ -+ /// Operator class tag -+ typename OperatorClass, -+ -+ /// Tag indicating architecture to tune for. This is the minimum SM that -+ /// supports the intended feature. The device kernel can be built -+ /// targeting any SM larger than this number. -+ typename ArchTag, -+ -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA, -+ -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB, -+ -+ /// If true, kernel supports split-K with serial reduction -+ bool SplitKSerial, -+ -+ /// Operation performed by TRMM -+ typename Operator, -+ -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA -+ > -+ class Trmm; -+*/ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Side Mode for A -+ SideMode SideModeA, -+ /// Fill Mode for A -+ FillMode FillModeA, -+ /// DiagType for A -+ DiagType DiagTypeA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_ = ElementC_, -+ /// Operator class tag -+ typename OperatorClass_ = arch::OpClassTensorOp, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_ = arch::Sm80, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_ = epilogue::thread::LinearCombination< -+ ElementC_, -+ 128 / sizeof_bits::value, -+ ElementAccumulator_, -+ ElementAccumulator_, -+ epilogue::thread::ScaleType::OnlyAlphaScaling -+ >, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, -+ /// Number of stages used in the pipelined mainloop -+ int Stages = -+ DefaultGemmConfiguration::kStages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA = -+ DefaultGemmConfiguration::kAlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB = -+ DefaultGemmConfiguration::kAlignmentB, -+ /// If true, kernel supports split-K with serial reduction -+ bool SplitKSerial = false, -+ /// Operation performed by TRMM -+ typename Operator_ = typename DefaultGemmConfiguration< -+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator_>::Operator, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA = ComplexTransform::kNone> -+class Trmm { -+ public: -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ using ElementAKernel = typename platform::conditional<(SideModeA == SideMode::kRight), ElementB_, ElementA_>::type; -+ using LayoutAKernel = typename platform::conditional<(SideModeA == SideMode::kRight), LayoutB_, LayoutA_>::type; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ using ElementBKernel = typename platform::conditional<(SideModeA == SideMode::kRight), ElementA_, ElementB_>::type; -+ using LayoutBKernel = typename platform::conditional<(SideModeA == SideMode::kRight), LayoutA_, LayoutB_>::type; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static SideMode const kSideMode = SideModeA; -+ static FillMode const kFillMode = FillModeA; -+ static DiagType const kDiagType = DiagTypeA; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentAKernel = (SideModeA == SideMode::kRight) ? AlignmentB : AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static int const kAlignmentBKernel = (SideModeA == SideMode::kRight) ? AlignmentA : AlignmentB; -+ static int const kAlignmentC = EpilogueOutputOp::kCount; -+ static bool const kSplitKSerial = SplitKSerial; -+ // Complex Transform don't appply to B -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = ComplexTransform::kNone; -+ static ComplexTransform const kTransformAKernel = (SideModeA == SideMode::kRight) ? -+ ComplexTransform::kNone : TransformA; -+ static ComplexTransform const kTransformBKernel = (SideModeA == SideMode::kRight) ? -+ TransformA : ComplexTransform::kNone; -+ -+ /// Define the kernel -+ using TrmmKernel = typename kernel::DefaultTrmmUniversal< -+ ElementAKernel, -+ LayoutAKernel, -+ kTransformAKernel, -+ kAlignmentAKernel, -+ ElementBKernel, -+ LayoutBKernel, -+ kTransformBKernel, -+ kAlignmentBKernel, -+ kSideMode, -+ kFillMode, -+ kDiagType, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ kStages, -+ kSplitKSerial, -+ Operator -+ >::TrmmKernel; -+ -+ using Arguments = typename TrmmKernel::Arguments; -+ -+private: -+ -+ /// Kernel parameters object -+ typename TrmmKernel::Params params_; -+public: -+ -+ /// Constructs the TRMM. -+ Trmm() { } -+ -+ /// Determines whether the TRMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ if (!kSplitKSerial && args.batch_count > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ Status status = TrmmKernel::can_implement(args); -+ -+ if (SideModeA == SideMode::kInvalid) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (FillModeA == FillMode::kInvalid) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (DiagTypeA == DiagType::kInvalid) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ size_t bytes = 0; -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.batch_count); -+ -+ if (kSplitKSerial && args.batch_count > 1) { -+ -+ bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); -+ } -+ -+ return bytes; -+ } -+ -+ /// Initializes TRMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ // Determine grid shape -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.batch_count); -+ -+ if (kSplitKSerial) { -+ if (args.batch_count > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ size_t bytes = get_workspace_size(args); -+ -+ cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ } -+ else { -+ -+ if (args.batch_count > 1) { -+ return Status::kErrorInvalidProblem; -+ } -+ } -+ -+ int gemm_k_size = args.problem_size.k(); -+ -+ // Swapping argument for A and B, if A was on the right side (problem size doesn't need to change here). -+ if (kSideMode == SideMode::kRight) { -+ // Initialize the Params structure -+ params_ = typename TrmmKernel::Params{ -+ args.swapped_matrices(), -+ grid_tiled_shape, -+ gemm_k_size, -+ static_cast(workspace) -+ }; -+ -+ return Status::kSuccess; -+ } -+ -+ // Initialize the Params structure -+ params_ = typename TrmmKernel::Params{ -+ args, -+ grid_tiled_shape, -+ gemm_k_size, -+ static_cast(workspace) -+ }; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ if (kSplitKSerial && args.batch_count > 1) { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ } -+ -+ size_t workspace_bytes = get_workspace_size(args); -+ -+ if (workspace_bytes && !workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ params_.update(args, workspace); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); -+ dim3 block(TrmmKernel::kThreadCount, 1, 1); -+ -+ int smem_size = int(sizeof(typename TrmmKernel::SharedStorage)); -+ -+ if (smem_size >= (48 << 10)) { -+ cudaError_t result = cudaFuncSetAttribute(Kernel, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ cutlass::Kernel<<>>(params_); -+ -+ cudaError_t result = cudaGetLastError(); -+ -+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+/******************************************************************************************************** -+ TRMM has 4 combinations based on Layouts {RowMajor, ColumnMajor} x Side mode {LeftSide, RightSide} -+ In templates and arguments to cutlass kernel, `matrix A` is always triangular, and `matrix B` is rectangular. -+ (adhering to the cuBLAS convention) -+ -+For the mainloop and trmm kernel, `A` and `B` points to left-side and right-side matrices, respectively. -+ -+ Thus, for LeftSide mode `A` and `B` points to `matrix A` and `matrix B`, respectively. While for -+ the RightSide mode `A` and `B` points to `matrix B` and `matrix A`, respectively. -+ -+ Additionally, CUTLASS GEMM epilogue is always RowMajor, and ColumnMajor output is achieved by -+ transposing the GEMM problem. Thus, ColumnMajor output layout for TRMM requires: -+ - Transposing `matrix A` and `matrix B` layouts -+ - Swapping problem size m and n values -+ - Swapping LeftSide and RightSide mode -+ -+ RowMajor output: D = matrix A x matrix B -+ ColumnMajor output: D = matrix A x matrix B -> Transpose (D) = Transpose(matrix B) x Transpose(matrix A) -+ -+ {RowMajor, ColumnMajor} x Side Mode {LeftSide, RightSide} 4 cases: -+ 1. LeftSide mode and RowMajor output (default template) -+ 2. LeftSide mode and ColumnMajor output -+ 3. RightSide mode and RowMajor output -+ 4. RightSide mode and ColumnMajor output -+ -+ Mapping ColumnMajor output layout cases 2 and 4 to RowMajor efficient epilogue implementation: -+ -+ Case 2 -> Case 3: -+ D_col = matrix A x matrix B (LeftSide mode) -+ => Transpose(D_col) = Transpose(matrix B) x Transpose(matrix A) (RightSide mode) -+ -+ swap pointers for `A` and `B` call GEMM mainloop with RowMajor efficient-epilogue -+ -+ Case 4 -> Case 1: -+ D_col = matrix B x matrix A (RightSide mode) -+ => Transpose(D_col) = Transpose(matrix A) x Transpose(matrix B) (LeftSide mode) -+ -+ call GEMM mainloop for with RowMajor efficient-epilogue -+********************************************************************************************************/ -+ -+/// Parital specialization for column-major output exchanges problem size and operand. -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Side Mode for A -+ SideMode SideModeA, -+ /// Fill Mode for A -+ FillMode FillModeA, -+ /// DiagType for A -+ DiagType DiagTypeA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Epilogue output operator -+ typename EpilogueOutputOp_, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Access granularity of A matrix in units of elements -+ int AlignmentA, -+ /// Access granularity of B matrix in units of elements -+ int AlignmentB, -+ /// If true, kernel supports split-K as a serial reduction -+ bool SplitKSerial, -+ /// Operation performed by TRMM -+ typename Operator_, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA> -+class Trmm { -+ public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ using ElementC = ElementC_; -+ using LayoutC = layout::ColumnMajor; -+ using TensorRefC = TensorRef; -+ using TensorRefD = TensorRef; -+ using ElementAccumulator = ElementAccumulator_; -+ using OperatorClass = OperatorClass_; -+ using ArchTag = ArchTag_; -+ using ThreadblockShape = ThreadblockShape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ using Operator = Operator_; -+ static SideMode const kSideMode = SideModeA; -+ static FillMode const kFillMode = FillModeA; -+ static DiagType const kDiagType = DiagTypeA; -+ // Changing SideMode as we change the layout -+ static SideMode const kSideModeT = (SideModeA == SideMode::kLeft) ? -+ SideMode::kRight : SideMode::kLeft; -+ // Changing FillMode as we change the layout -+ static FillMode const kFillModeT = (FillModeA == FillMode::kLower) ? -+ FillMode::kUpper : FillMode::kLower; -+ static int const kStages = Stages; -+ static int const kAlignmentA = AlignmentA; -+ static int const kAlignmentB = AlignmentB; -+ static ComplexTransform const kTransformA = TransformA; -+ // Complex Transform don't appply to B -+ static ComplexTransform const kTransformB = ComplexTransform::kNone; -+ static bool const kSplitKSerial = SplitKSerial; -+ -+ using UnderlyingOperator = Trmm< -+ ElementA, -+ typename layout::LayoutTranspose::type, -+ kSideModeT, -+ kFillModeT, -+ kDiagType, -+ ElementB, -+ typename layout::LayoutTranspose::type, -+ ElementC, -+ layout::RowMajor, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ kStages, -+ kAlignmentA, -+ kAlignmentB, -+ kSplitKSerial, -+ Operator, -+ TransformA -+ >; -+ -+ using Arguments = typename UnderlyingOperator::Arguments; -+ using TrmmKernel = typename UnderlyingOperator::TrmmKernel; -+ static int const kAlignmentC = UnderlyingOperator::kAlignmentC; -+ -+private: -+ -+ UnderlyingOperator underlying_operator_; -+ -+public: -+ -+ /// Constructs the TRMM. -+ Trmm() { } -+ -+ /// Helper to construct a transposed equivalent for the underying TRMM operator which is identical -+ static Arguments to_underlying_arguments(Arguments const &args) { -+ return args.transposed_problem_size(); -+ } -+ -+ /// Determines whether the TRMM can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ return UnderlyingOperator::can_implement(to_underlying_arguments(args)); -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ -+ return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); -+ } -+ -+ /// Initializes TRMM state from arguments. -+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); -+ } -+ -+ /// Lightweight update given a subset of arguments -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ return underlying_operator_.update(to_underlying_arguments(args), workspace); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ return underlying_operator_.run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/dispatch_policy.hpp b/3rdparty/cutlass/include/cutlass/gemm/dispatch_policy.hpp -new file mode 100644 -index 0000000..a2cd9a1 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/dispatch_policy.hpp -@@ -0,0 +1,144 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include "cutlass/arch/arch.h" -+ -+#include "cute/layout.hpp" -+#include "cute/numeric/integral_constant.hpp" -+ -+////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass::gemm { -+using namespace cute; -+ -+////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Policies for categorical dispatch of mainloop against kernel grid schedules -+// -+struct KernelMultistage { }; -+struct KernelTma { }; -+struct KernelTmaWarpSpecialized { }; -+struct KernelTmaWarpSpecializedPersistent { }; -+ -+// -+// Collective Mainloop Policies -+// -+ -+// 2 stage pipeline through 1 stage in smem, 1 in rmem, WITHOUT predicated gmem loads -+struct MainloopSm70TwoStageUnpredicated { -+ constexpr static int Stages = 2; -+ using ArchTag = arch::Sm70; -+ using Schedule = KernelMultistage; -+ using ClusterShape = Shape<_1,_1,_1>; -+}; -+ -+// 2 stage pipeline through 1 stage in smem, 1 in rmem, with predicated gmem loads -+struct MainloopSm70TwoStage { -+ constexpr static int Stages = 2; -+ using ArchTag = arch::Sm70; -+ using Schedule = KernelMultistage; -+ using ClusterShape = Shape<_1,_1,_1>; -+}; -+ -+// n-buffer in smem (cp.async), pipelined with registers, WITHOUT predicated gmem loads -+template -+struct MainloopSm80CpAsyncUnpredicated { -+ constexpr static int Stages = Stages_; -+ using ArchTag = arch::Sm80; -+ using Schedule = KernelMultistage; -+ using ClusterShape = Shape<_1,_1,_1>; -+}; -+ -+// n-buffer in smem (cp.async), pipelined with registers, with predicated gmem loads -+template -+struct MainloopSm80CpAsync { -+ constexpr static int Stages = Stages_; -+ using ArchTag = arch::Sm80; -+ using Schedule = KernelMultistage; -+ using ClusterShape = Shape<_1,_1,_1>; -+}; -+ -+// n-buffer in smem (cp.async), pipelined with Hopper GMMA, WITHOUT predicated gmem loads -+template< -+ int Stages_, -+ class ClusterShape_ = Shape<_1,_1,_1> -+> -+struct MainloopSm90CpAsyncGmmaUnpredicated { -+ constexpr static int Stages = Stages_; -+ using ClusterShape = ClusterShape_; -+ using ArchTag = arch::Sm90; -+ using Schedule = KernelMultistage; -+}; -+ -+// n-buffer in smem (cp.async), pipelined with Hopper GMMA, with predicated gmem loads -+template< -+ int Stages_, -+ class ClusterShape_ = Shape<_1,_1,_1> -+> -+struct MainloopSm90CpAsyncGmma { -+ constexpr static int Stages = Stages_; -+ using ClusterShape = ClusterShape_; -+ using ArchTag = arch::Sm90; -+ using Schedule = KernelMultistage; -+}; -+ -+// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, static schedule between TMA and GMMA -+template< -+ int Stages_, -+ class ClusterShape_ = Shape<_1,_1,_1>, -+ int PipelineAsyncMmaStages_ = 1 -+> -+struct MainloopSm90TmaGmma { -+ constexpr static int Stages = Stages_; -+ using ClusterShape = ClusterShape_; -+ constexpr static int PipelineAsyncMmaStages = PipelineAsyncMmaStages_; -+ using ArchTag = arch::Sm90; -+ using Schedule = KernelTma; -+}; -+ -+// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp specialized dynamic schedule -+template< -+ int Stages_, -+ class ClusterShape_ = Shape<_1,_1,_1>, -+ class KernelSchedule = KernelTmaWarpSpecialized -+> -+struct MainloopSm90TmaGmmaWarpSpecialized { -+ constexpr static int Stages = Stages_; -+ using ClusterShape = ClusterShape_; -+ using ArchTag = arch::Sm90; -+ using Schedule = KernelSchedule; -+}; -+ -+////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass::gemm -diff --git a/3rdparty/cutlass/include/cutlass/gemm/gemm.h b/3rdparty/cutlass/include/cutlass/gemm/gemm.h -new file mode 100644 -index 0000000..4b76101 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/gemm.h -@@ -0,0 +1,574 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines common types used for all GEMM-like operators. -+*/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/coord.h" -+#include "cutlass/layout/matrix.h" -+#include "cute/layout.hpp" -+#include "cute/arch/copy_sm90.hpp" -+ -+namespace cutlass { -+namespace gemm { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// GEMM operand enumeration: D = A * B + C -+enum class Operand { -+ kA, /// A multiplicand -+ kB, /// B multiplicand -+ kC, /// Source accumulator -+ kD /// Destination accumulator -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Shape of a matrix multiply-add operation -+template < -+ /// Rows of matrix product -+ int M = 1, -+ /// Columns of matrix product -+ int N = 1, -+ /// Inner dimension of matrix product -+ int K = 1 -+> -+struct GemmShape { -+ static int const kM = M; -+ static int const kN = N; -+ static int const kK = K; -+ -+ static int const kMN = M * N; -+ static int const kMK = M * K; -+ static int const kKN = N * K; -+ static int const kMNK = M * N * K; -+ -+ static int const kCount = kMNK; -+ -+ // -+ // Static member functions -+ // -+ -+ /// Returns a Coord object -+ CUTLASS_HOST_DEVICE -+ static Coord<3> toCoord() { -+ return make_Coord(kM, kN, kK); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Type alias of the transpose of a GemmShape -+template < -+ /// concept: GemmShape -+ typename Shape -+> -+using GemmShapeTranspose = GemmShape; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// GemmCoord is a structure derived from Coord<3> that specifies a location within the -+/// coordinate space of a GEMM problem. -+struct GemmCoord : public Coord<3, int> { -+ -+ /// Integer-valued index -+ typedef int Index; -+ -+ /// Base type is a Coord of rank=3 -+ typedef Coord<3, Index> Base; -+ -+ /// GEMM M dimension - rows of the output C matrix -+ static int const kM = 0; -+ -+ /// GEMM N dimension - columns of the output C matrix -+ static int const kN = 1; -+ -+ /// GEMM K dimension - inner dimension of the GEMM problem -+ static int const kK = 2; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ GemmCoord() { } -+ -+ /// Constructs from Coord<3> and a batch -+ CUTLASS_HOST_DEVICE -+ GemmCoord(Coord<3, Index> const &coord): Base(make_Coord(coord[0], coord[1], coord[2])) { } -+ -+ /// Helper to construct from a K, N, M, batch variables -+ CUTLASS_HOST_DEVICE -+ GemmCoord(Index m, Index n, Index k): Base(make_Coord(m, n, k)) { } -+ -+ /// Returns the GEMM M coordinate -+ CUTLASS_HOST_DEVICE -+ Index const & m() const { return this->at(kM); } -+ -+ /// Returns reference to the GEMM M coordinate -+ CUTLASS_HOST_DEVICE -+ Index & m() { return this->at(kM); } -+ -+ /// Returns the GEMM N coordinate -+ CUTLASS_HOST_DEVICE -+ Index const & n() const { return this->at(kN); } -+ -+ /// Returns reference to the GEMM N coordinate -+ CUTLASS_HOST_DEVICE -+ Index & n() { return this->at(kN); } -+ -+ /// Returns the GEMM K coordinate -+ CUTLASS_HOST_DEVICE -+ Index const & k() const { return this->at(kK); } -+ -+ /// Returns reference to the GEMM K coordinate -+ CUTLASS_HOST_DEVICE -+ Index & k() { return this->at(kK); } -+ -+ /// Obtains a Coord<3> from GemmCoord -+ CUTLASS_HOST_DEVICE -+ Coord<3> mnk() const { -+ return make_Coord(m(), n(), k()); -+ } -+ -+ /// Obtains a Coord<3> from GemmCoord -+ CUTLASS_HOST_DEVICE -+ Coord<3> knm() const { -+ return make_Coord(k(), n(), m()); -+ } -+ -+ /// Obtains a Coord<2> from GemmCoord -+ CUTLASS_HOST_DEVICE -+ Coord<2> nm() const { -+ return make_Coord(n(), m()); -+ } -+ -+ /// Obtains a Coord<2> from GemmCoord -+ CUTLASS_HOST_DEVICE -+ Coord<2> mn() const { -+ return make_Coord(m(), n()); -+ } -+ -+ /// Obtains a Coord<2> from GemmCoord -+ CUTLASS_HOST_DEVICE -+ Coord<2> mk() const { -+ return make_Coord(m(), k()); -+ } -+ -+ /// Obtains a Coord<2> from GemmCoord -+ CUTLASS_HOST_DEVICE -+ Coord<2> km() const { -+ return make_Coord(k(), m()); -+ } -+ -+ /// Obtains a Coord<2> from GemmCoord -+ CUTLASS_HOST_DEVICE -+ Coord<2> nk() const { -+ return make_Coord(n(), k()); -+ } -+ -+ /// Obtains a Coord<2> from GemmCoord -+ CUTLASS_HOST_DEVICE -+ Coord<2> kn() const { -+ return make_Coord(k(), n()); -+ } -+ -+ // -+ // Coord operators -+ // -+ -+ /// Element-wise addition -+ CUTLASS_HOST_DEVICE -+ GemmCoord operator+(Base const& b) const { -+ return GemmCoord(Base::operator+(b)); -+ } -+ -+ /// Element-wise subtraction -+ CUTLASS_HOST_DEVICE -+ GemmCoord operator-(Base const& b) const { -+ return GemmCoord(Base::operator-(b)); -+ } -+ -+ /// Element-wise multiplication -+ CUTLASS_HOST_DEVICE -+ GemmCoord operator*(Base const& b) const { -+ return GemmCoord(Base::operator*(b)); -+ } -+ -+ /// Element-wise division -+ CUTLASS_HOST_DEVICE -+ GemmCoord operator/(Base const& b) const { -+ return GemmCoord(Base::operator/(b)); -+ } -+ -+ /// In-place addition -+ CUTLASS_HOST_DEVICE -+ GemmCoord& operator+=(Base const& b) { -+ Base::operator+=(b); -+ return *this; -+ } -+ -+ /// In-place subtraction -+ CUTLASS_HOST_DEVICE -+ GemmCoord& operator-=(Base const& b) { -+ Base::operator-=(b); -+ return *this; -+ } -+ -+ /// In-place multiplication -+ CUTLASS_HOST_DEVICE -+ GemmCoord& operator*=(Base const& b) { -+ Base::operator*=(b); -+ return *this; -+ } -+ -+ /// In-place division -+ CUTLASS_HOST_DEVICE -+ GemmCoord& operator/=(Base const& b) { -+ Base::operator/=(b); -+ return *this; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// BatchedGemmCoord is a structure derived from Coord<4> that specifies a location within the -+/// coordinate space of a batched GEMM problem. -+struct BatchedGemmCoord : public Coord<4, int> { -+ -+ /// Integer-valued index -+ typedef int Index; -+ -+ /// Base type is a Coord of rank=4 -+ typedef Coord<4, Index> Base; -+ -+ /// GEMM M dimension - rows of the output C matrix -+ static int const kM = 0; -+ -+ /// GEMM N dimension - columns of the output C matrix -+ static int const kN = 1; -+ -+ /// GEMM K dimension - inner dimension of the GEMM problem -+ static int const kK = 2; -+ -+ /// GEMM Batch dimension - inner dimension of the GEMM problem -+ static int const kBatch = 3; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ BatchedGemmCoord() { } -+ -+ /// Constructs from Coord<4> -+ CUTLASS_HOST_DEVICE -+ BatchedGemmCoord(Base const &coord): Base(coord) { } -+ -+ /// Helper to construct from a K, N, M, and batch variables -+ CUTLASS_HOST_DEVICE -+ BatchedGemmCoord(Index m, Index n, Index k, Index b): Base(make_Coord(m, n, k, b)) { } -+ -+ /// Returns the GEMM M coordinate -+ CUTLASS_HOST_DEVICE -+ Index const & m() const { return this->at(kM); } -+ -+ /// Returns reference to the GEMM M coordinate -+ CUTLASS_HOST_DEVICE -+ Index & m() { return this->at(kM); } -+ -+ /// Returns the GEMM N coordinate -+ CUTLASS_HOST_DEVICE -+ Index const & n() const { return this->at(kN); } -+ -+ /// Returns reference to the GEMM N coordinate -+ CUTLASS_HOST_DEVICE -+ Index & n() { return this->at(kN); } -+ -+ /// Returns the GEMM K coordinate -+ CUTLASS_HOST_DEVICE -+ Index const & k() const { return this->at(kK); } -+ -+ /// Returns reference to the GEMM K coordinate -+ CUTLASS_HOST_DEVICE -+ Index & k() { return this->at(kK); } -+ -+ /// Returns the GEMM batch coordinate -+ CUTLASS_HOST_DEVICE -+ Index const & batch() const { return this->at(kBatch); } -+ -+ /// Returns reference to the GEMM batch coordinate -+ CUTLASS_HOST_DEVICE -+ Index & batch() { return this->at(kBatch); } -+ -+ /// Obtains a GemmCoord from BatchedGemmCoord -+ CUTLASS_HOST_DEVICE -+ GemmCoord mnk() const { -+ return GemmCoord(m(), n(), k()); -+ } -+ -+ /// Obtains a Coord<4> from BatchedGemmCoord -+ CUTLASS_HOST_DEVICE -+ Coord<4> mnkb() const { -+ return make_Coord(m(), n(), k(), batch()); -+ } -+ -+ // -+ // Coord operators -+ // -+ -+ /// Element-wise addition -+ CUTLASS_HOST_DEVICE -+ BatchedGemmCoord operator+(Base const& b) const { -+ return BatchedGemmCoord(Base::operator+(b)); -+ } -+ -+ /// Element-wise subtraction -+ CUTLASS_HOST_DEVICE -+ BatchedGemmCoord operator-(Base const& b) const { -+ return BatchedGemmCoord(Base::operator-(b)); -+ } -+ -+ /// Element-wise multiplication -+ CUTLASS_HOST_DEVICE -+ BatchedGemmCoord operator*(Base const& b) const { -+ return BatchedGemmCoord(Base::operator*(b)); -+ } -+ -+ /// Element-wise division -+ CUTLASS_HOST_DEVICE -+ BatchedGemmCoord operator/(Base const& b) const { -+ return BatchedGemmCoord(Base::operator/(b)); -+ } -+ -+ /// In-place addition -+ CUTLASS_HOST_DEVICE -+ BatchedGemmCoord& operator+=(Base const& b) { -+ Base::operator+=(b); -+ return *this; -+ } -+ -+ /// In-place subtraction -+ CUTLASS_HOST_DEVICE -+ BatchedGemmCoord& operator-=(Base const& b) { -+ Base::operator-=(b); -+ return *this; -+ } -+ -+ /// In-place multiplication -+ CUTLASS_HOST_DEVICE -+ BatchedGemmCoord& operator*=(Base const& b) { -+ Base::operator*=(b); -+ return *this; -+ } -+ -+ /// In-place division -+ CUTLASS_HOST_DEVICE -+ BatchedGemmCoord& operator/=(Base const& b) { -+ Base::operator/=(b); -+ return *this; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+enum class GemmUniversalMode { -+ kGemm, -+ kGemmSplitKParallel, -+ kBatched, -+ kArray, -+ kInvalid -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Some options for clearing shared memory -+enum class SharedMemoryClearOption { -+ kNone, ///< SMEM is in don't-care state -+ kZfill, ///< Kernels fill out of bounds accesses with zeros -+ kClearLastStage ///< Last SMEM stage is explicitly cleared. Mainloop uses 'kNone' -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// For each cutlass::layout, provides its corresponding cute stride types, 64b by default -+ -+template -+struct TagToStrideA {}; -+ -+// Maps to modes [M, K, L] -+template <> -+struct TagToStrideA { -+ using type = cute::Stride, int64_t>; -+ using tag = layout::RowMajor; -+}; -+ -+// Maps to modes [M, K, L] -+template <> -+struct TagToStrideA { -+ using type = cute::Stride, int64_t, int64_t>; -+ using tag = layout::ColumnMajor; -+}; -+ -+template -+struct TagToStrideB {}; -+ -+// Maps to modes [N, K, L] -+template <> -+struct TagToStrideB { -+ using type = cute::Stride, int64_t, int64_t>; -+ using tag = layout::RowMajor; -+}; -+ -+// Maps to modes [N, K, L] -+template <> -+struct TagToStrideB { -+ using type = cute::Stride, int64_t>; -+ using tag = layout::ColumnMajor; -+}; -+ -+ -+// Maps to modes [N, N, L] -+template -+struct TagToStrideC : TagToStrideA { }; -+ -+// Convenience aliases -+template -+using TagToStrideA_t = typename TagToStrideA::type; -+ -+template -+using TagToStrideB_t = typename TagToStrideB::type; -+ -+template -+using TagToStrideC_t = typename TagToStrideC::type; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+// For 2.x compatibility APIs, provide stride->layout tag mappers -+ -+namespace detail { -+ -+// Note : This method can be used for deducing the Layout Tag of A, C, D Matrices -+template -+constexpr -+auto -+stride_to_layout_tag_A() { -+ // Account for stride types with and without batch mode and batch modes with static zero stride -+ if constexpr (cute::size<0>(StrideAC{}) == 1) { // M major -+ return layout::ColumnMajor{}; -+ } -+ else { // K major -+ return layout::RowMajor{}; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+template -+constexpr -+auto -+stride_to_layout_tag_B() { -+ // Account for stride types with and without batch mode and batch modes with static zero stride -+ if constexpr (cute::size<0>(StrideB{}) == 1) { // N major -+ return layout::RowMajor{}; -+ } -+ else { // K major -+ return layout::ColumnMajor{}; -+ } -+ -+ CUTE_GCC_UNREACHABLE; -+} -+ -+// Inspects a TiledCopy and returns its alignment in terms of element count -+template -+constexpr int -+get_alignment_count_from_gmem_tiled_copy() { -+ // For TMA tiled copies, we know the alignment has to be 128 bits -+ if constexpr (std::is_base_of_v || -+ std::is_base_of_v) { -+ return 128 / sizeof_bits::value; -+ } -+ else -+ { -+ // For non-TMA tiled copies, TiledCopy holds the alignment count directly in its TiledShape_MN -+ return GmemTiledCopy::NumValSrc; -+ } -+} -+ -+// Utilities to map Stride back on to their corresponding layout tags -+template -+struct StrideToLayoutTagA { -+ using type = decltype(detail::stride_to_layout_tag_A()); -+}; -+ -+template -+struct StrideToLayoutTagB { -+ using type = decltype(detail::stride_to_layout_tag_B()); -+}; -+ -+// Maps to modes [N, N, L] -+template -+struct StrideToLayoutTagC : StrideToLayoutTagA { }; -+ -+// Convenience aliases -+template -+using StrideToLayoutTagA_t = typename StrideToLayoutTagA::type; -+ -+template -+using StrideToLayoutTagB_t = typename StrideToLayoutTagB::type; -+ -+template -+using StrideToLayoutTagC_t = typename StrideToLayoutTagC::type; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+// The following two metafunctions are used to detect whether a `kernel::Gemm` or `kernel::GemmUniversal` -+// is implementing the CUTLASS 3.x API or not, by checking if the problem shape type is aliased within or not. -+template -+struct IsCutlass3GemmKernel : std::false_type { }; -+ -+template -+struct IsCutlass3GemmKernel> -+ : std::true_type { }; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_ell_gemm.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_ell_gemm.h -new file mode 100644 -index 0000000..04b14a4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_ell_gemm.h -@@ -0,0 +1,837 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Default kernel-level Blocked-Ell sparse gemm operators. -+ This operator combines threadblock-scoped ELL MMA -+ with the appropriate threadblock-scoped epilogue. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/wmma.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/gemm.h" -+#include "cutlass/gemm/kernel/gemm_pipelined.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/gemm/threadblock/default_mma.h" -+#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h" -+#endif //CUTLASS_ARCH_WMMA_ENABLED -+ -+#include "cutlass/gemm/kernel/ell_gemm.h" -+#include "cutlass/gemm/threadblock/default_ell_mma.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Sparse matrix is A or not -+ bool IsASparse> -+struct DefaultEllGemm; -+ -+//////////////////////////////////////////////////////////////////////////////// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Sparse matrix is A or not -+ bool IsASparse -+> -+struct DefaultEllGemm { -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultEllMma< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ Operator>::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, -+ EpilogueOutputOp::kCount>::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using GemmKernel = kernel::EllGemm; -+}; -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Turing Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// If true, kernel is configured to support serial reduction in the epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Sparse matrix is A or not -+ bool IsASparse -+> -+struct DefaultEllGemm< -+ ElementA, LayoutA, kAlignmentA, -+ ElementB, LayoutB, kAlignmentB, -+ ElementC, layout::RowMajor, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ arch::Sm75, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ SplitKSerial, -+ Operator, -+ IsASparse -+> { -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultEllMma< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementAccumulator, -+ layout::RowMajor, -+ arch::OpClassTensorOp, -+ arch::Sm75, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ 2, -+ Operator -+ >::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ typename Mma::Operator, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using GemmKernel = kernel::EllGemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Integer Matrix Multiply Interleaved layout -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Number of Interleaved k -+ int InterleavedK, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Sparse matrix is A or not -+ bool IsASparse> -+struct DefaultEllGemm< -+ ElementA, layout::ColumnMajorInterleaved, kAlignmentA, -+ ElementB, layout::RowMajorInterleaved, kAlignmentB, ElementC, -+ layout::ColumnMajorInterleaved, int32_t, -+ arch::OpClassTensorOp, arch::Sm80, ThreadblockShape, WarpShape, -+ InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, -+ SplitKSerial, Operator, IsASparse> { -+ using LayoutA = layout::ColumnMajorInterleaved; -+ using LayoutB = layout::RowMajorInterleaved; -+ using LayoutC = layout::ColumnMajorInterleaved; -+ -+ using ElementAccumulator = int32_t; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultEllMma< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, Operator, -+ true>::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock:: -+ DefaultInterleavedEpilogueTensorOp< -+ ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, -+ 64 / sizeof_bits::value, InterleavedK>::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using GemmKernel = kernel::EllGemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Turing Integer Matrix Multiply Interleaved layout -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of Interleaved k -+ int InterleavedK, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Sparse matrix is A or not -+ bool IsASparse> -+struct DefaultEllGemm, -+ kAlignmentA, ElementB, -+ layout::RowMajorInterleaved, kAlignmentB, -+ ElementC, layout::ColumnMajorInterleaved, -+ int32_t, arch::OpClassTensorOp, arch::Sm75, ThreadblockShape, -+ WarpShape, InstructionShape, EpilogueOutputOp, -+ ThreadblockSwizzle, 2, SplitKSerial, Operator, IsASparse> { -+ using LayoutA = layout::ColumnMajorInterleaved; -+ using LayoutB = layout::RowMajorInterleaved; -+ using LayoutC = layout::ColumnMajorInterleaved; -+ -+ using ElementAccumulator = int32_t; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultEllMma< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC, -+ arch::OpClassTensorOp, arch::Sm75, ThreadblockShape, WarpShape, -+ InstructionShape, 2, Operator, true>::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock:: -+ DefaultInterleavedEpilogueTensorOp< -+ ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, -+ 64 / sizeof_bits::value, InterleavedK>::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using GemmKernel = kernel::EllGemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Partial specialization for Volta architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// If true, kernel is configured to support serial reduction in the epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Sparse matrix is A or not -+ bool IsASparse -+> -+struct DefaultEllGemm< -+ ElementA, LayoutA, kAlignmentA, -+ ElementB, LayoutB, kAlignmentB, -+ ElementC, layout::RowMajor, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ arch::Sm70, -+ ThreadblockShape, -+ WarpShape, -+ GemmShape<8, 8, 4>, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ SplitKSerial, -+ Operator, -+ IsASparse -+> { -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultEllMma< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementAccumulator, -+ layout::RowMajor, -+ arch::OpClassTensorOp, -+ arch::Sm70, -+ ThreadblockShape, -+ WarpShape, -+ GemmShape<8, 8, 4>, -+ 2, -+ Operator -+ >::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ ThreadblockShape, -+ typename Mma::Operator, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using GemmKernel = kernel::EllGemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for SIMT -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// If true, kernel is configured to support serial reduction in the epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Sparse matrix is A or not -+ bool IsASparse -+ > -+struct DefaultEllGemm< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementC, -+ layout::RowMajor, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ GemmShape<1, 1, 1>, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ SplitKSerial, -+ Operator, -+ IsASparse> { -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultEllMma< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementAccumulator, -+ layout::RowMajor, -+ arch::OpClassSimt, -+ arch::Sm50, -+ ThreadblockShape, -+ WarpShape, -+ GemmShape<1, 1, 1>, -+ 2, -+ Operator>::ThreadblockMma; -+ -+ static int const kEpilogueElementsPerAccess = EpilogueOutputOp::kCount; -+ static_assert(kEpilogueElementsPerAccess == 1, "simt epilogue must operate on scalars"); -+ -+ /// Define the epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ typename Mma::Operator, -+ EpilogueOutputOp, -+ kEpilogueElementsPerAccess -+ >::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using GemmKernel = kernel::EllGemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Sparse matrix is A or not -+ bool IsASparse -+ > -+struct DefaultEllGemm, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ SplitKSerial, -+ Operator, -+ IsASparse> { -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultEllMma< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, -+ ElementAccumulator, layout::RowMajor, arch::OpClassSimt, arch::Sm80, -+ ThreadblockShape, WarpShape, GemmShape<1, 1, 1>, Stages, -+ Operator>::ThreadblockMma; -+ -+ static int const kEpilogueElementsPerAccess = EpilogueOutputOp::kCount; -+ static_assert(kEpilogueElementsPerAccess == 1, "simt epilogue must operate on scalars"); -+ -+ /// Define the epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ typename Mma::Operator, -+ EpilogueOutputOp, -+ kEpilogueElementsPerAccess -+ >::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using GemmKernel = kernel::EllGemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// Partial specialization for SIMT DP4A -+ -+template < -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Layout type for C matrix operand -+ typename LayoutC, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Sparse matrix is A or not -+ bool IsASparse -+ > -+struct DefaultEllGemm, -+ EpilogueOutputOp, ThreadblockSwizzle, 2, SplitKSerial, -+ Operator, IsASparse> { -+ using InstructionShape = GemmShape<1, 1, 4>; -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ -+ using OperatorClass = arch::OpClassSimt; -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultEllMma::ThreadblockMma; -+ -+ static int const kEpilogueElementsPerAccess = EpilogueOutputOp::kCount; -+ static_assert(kEpilogueElementsPerAccess == 1, "simt epilogue must operate on scalars"); -+ -+ /// Define the epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ typename Mma::Operator, -+ EpilogueOutputOp, -+ kEpilogueElementsPerAccess -+ >::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using GemmKernel = kernel::EllGemm; -+}; -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+//////////////////////////////////////////////////////////////////////////////// -+/// Partial specialization for Wmma Gemm Kernel -+template < -+ ///< Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Sparse matrix is A or not -+ bool IsASparse -+ > -+struct DefaultEllGemm< -+ ElementA, LayoutA, kAlignmentA, -+ ElementB, LayoutB, kAlignmentB, -+ ElementC, LayoutC, -+ ElementAccumulator, -+ arch::OpClassWmmaTensorOp, -+ ArchTag, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ SplitKSerial, -+ Operator, -+ IsASparse> { -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultEllMma< -+ ElementA, LayoutA, kAlignmentA, -+ ElementB, LayoutB, kAlignmentB, -+ ElementAccumulator, LayoutC, -+ arch::OpClassWmmaTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ Stages, -+ Operator>::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWmmaTensorOp< -+ ThreadblockShape, -+ typename Mma::Operator, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using GemmKernel = kernel::EllGemm; -+}; -+//////////////////////////////////////////////////////////////////////////////// -+#endif //CUTLASS_ARCH_WMMA_ENABLED -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm.h -new file mode 100644 -index 0000000..4432008 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm.h -@@ -0,0 +1,1060 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+ -+ Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are -+ accommodated by exchanging A and B operands and assuming transposed layouts. Partial -+ specializations here choose 'device::GemmTransposed' to implement this functionality. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/wmma.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/gemm.h" -+#include "cutlass/gemm/kernel/gemm_pipelined.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/gemm/threadblock/default_mma.h" -+#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+ -+#include "cutlass/layout/permute.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h" -+#endif //CUTLASS_ARCH_WMMA_ENABLED -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, -+ /// Gather operand A by using an index array -+ bool GatherA = false, -+ /// Gather operand B by using an index array -+ bool GatherB = false, -+ /// Scatter result D by using an index array -+ bool ScatterD = false, -+ /// Permute result D -+ typename PermuteDLayout = layout::NoPermute, -+ /// -+ typename Enable = void -+> -+struct DefaultGemm; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Hopper Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear, -+ /// Gather operand A by using an index array -+ bool GatherA, -+ /// Gather operand B by using an index array -+ bool GatherB, -+ /// Scatter result D by using an index array -+ bool ScatterD, -+ /// Permute result D -+ typename PermuteDLayout -+> -+struct DefaultGemm { -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultMma< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ Operator, false, SharedMemoryClear, GatherA, GatherB>::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, ScatterD, PermuteDLayout>::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using GemmKernel = kernel::Gemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operand -+ typename LayoutC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear, -+ /// Gather operand A by using an index array -+ bool GatherA, -+ /// Gather operand B by using an index array -+ bool GatherB, -+ /// Scatter result D by using an index array -+ bool ScatterD, -+ /// Permute result D -+ typename PermuteDLayout -+> -+struct DefaultGemm { -+ -+ static_assert((platform::is_same::value -+ || platform::is_same>::value), -+ "Epilogue in the kernel level must be row major"); -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultMma< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ Operator, false, SharedMemoryClear, GatherA, GatherB>::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using RegularEpilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, ScatterD, PermuteDLayout>::Epilogue; -+ -+ using Affine2Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOpAffineRankN< -+ 2, ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, -+ EpilogueOutputOp::kCount>::Epilogue; -+ -+ using Epilogue = typename platform::conditional::value, -+ RegularEpilogue, -+ Affine2Epilogue>::type; -+ -+ /// Define the kernel-level GEMM operator. -+ using GemmKernel = kernel::Gemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Turing Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// If true, kernel is configured to support serial reduction in the epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear, -+ /// Gather operand A by using an index array -+ bool GatherA, -+ /// Gather operand B by using an index array -+ bool GatherB, -+ /// Scatter result D by using an index array -+ bool ScatterD, -+ /// Permute result D -+ typename PermuteDLayout -+> -+struct DefaultGemm< -+ ElementA, LayoutA, kAlignmentA, -+ ElementB, LayoutB, kAlignmentB, -+ ElementC, layout::RowMajor, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ arch::Sm75, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ SplitKSerial, -+ Operator, -+ SharedMemoryClear, -+ GatherA, -+ GatherB, -+ ScatterD, -+ PermuteDLayout -+> { -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultMma< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementAccumulator, -+ layout::RowMajor, -+ arch::OpClassTensorOp, -+ arch::Sm75, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ 2, -+ Operator, -+ false, -+ SharedMemoryClear, -+ GatherA, -+ GatherB -+ >::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, -+ typename Mma::Operator, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount, -+ ScatterD, -+ PermuteDLayout -+ >::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using GemmKernel = kernel::Gemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Integer Matrix Multiply Interleaved layout -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Number of Interleaved k -+ int InterleavedK, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear> -+struct DefaultGemm< -+ ElementA, layout::ColumnMajorInterleaved, kAlignmentA, -+ ElementB, layout::RowMajorInterleaved, kAlignmentB, ElementC, -+ layout::ColumnMajorInterleaved, int32_t, -+ arch::OpClassTensorOp, arch::Sm80, ThreadblockShape, WarpShape, -+ InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, -+ SplitKSerial, Operator, SharedMemoryClear, false, false, false> { -+ -+ using LayoutA = layout::ColumnMajorInterleaved; -+ using LayoutB = layout::RowMajorInterleaved; -+ using LayoutC = layout::ColumnMajorInterleaved; -+ -+ using ElementAccumulator = int32_t; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultMma< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, -+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, Operator, -+ true, SharedMemoryClear>::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock:: -+ DefaultInterleavedEpilogueTensorOp< -+ ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, -+ 64 / sizeof_bits::value, InterleavedK>::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using GemmKernel = kernel::Gemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Turing Integer Matrix Multiply Interleaved layout -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of Interleaved k -+ int InterleavedK, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear> -+struct DefaultGemm, -+ kAlignmentA, ElementB, -+ layout::RowMajorInterleaved, kAlignmentB, -+ ElementC, layout::ColumnMajorInterleaved, -+ int32_t, arch::OpClassTensorOp, arch::Sm75, ThreadblockShape, -+ WarpShape, InstructionShape, EpilogueOutputOp, -+ ThreadblockSwizzle, 2, SplitKSerial, Operator, SharedMemoryClear, -+ false, false, false> { -+ -+ using LayoutA = layout::ColumnMajorInterleaved; -+ using LayoutB = layout::RowMajorInterleaved; -+ using LayoutC = layout::ColumnMajorInterleaved; -+ -+ using ElementAccumulator = int32_t; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultMma< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC, -+ arch::OpClassTensorOp, arch::Sm75, ThreadblockShape, WarpShape, -+ InstructionShape, 2, Operator, true>::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock:: -+ DefaultInterleavedEpilogueTensorOp< -+ ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, -+ 64 / sizeof_bits::value, InterleavedK>::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using GemmKernel = kernel::Gemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Volta architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// If true, kernel is configured to support serial reduction in the epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear, -+ /// Gather operand A by using an index array -+ bool GatherA, -+ /// Gather operand B by using an index array -+ bool GatherB, -+ /// Scatter result D by using an index array -+ bool ScatterD, -+ /// Permute result D -+ typename PermuteDLayout -+> -+struct DefaultGemm< -+ ElementA, LayoutA, kAlignmentA, -+ ElementB, LayoutB, kAlignmentB, -+ ElementC, layout::RowMajor, -+ ElementAccumulator, -+ arch::OpClassTensorOp, -+ arch::Sm70, -+ ThreadblockShape, -+ WarpShape, -+ GemmShape<8, 8, 4>, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ SplitKSerial, -+ Operator, -+ SharedMemoryClear, -+ GatherA, -+ GatherB, -+ ScatterD, -+ PermuteDLayout -+> { -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultMma< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementAccumulator, -+ layout::RowMajor, -+ arch::OpClassTensorOp, -+ arch::Sm70, -+ ThreadblockShape, -+ WarpShape, -+ GemmShape<8, 8, 4>, -+ 2, -+ Operator, -+ false, -+ SharedMemoryClear, -+ GatherA, -+ GatherB -+ >::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ ThreadblockShape, -+ typename Mma::Operator, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount, -+ ScatterD, -+ PermuteDLayout -+ >::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using GemmKernel = kernel::Gemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for SIMT -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operand -+ typename LayoutC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// If true, kernel is configured to support serial reduction in the epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear, -+ /// Gather operand A by using an index array -+ bool GatherA, -+ /// Gather operand B by using an index array -+ bool GatherB, -+ /// Scatter result D by using an index array -+ bool ScatterD, -+ /// Permute result D -+ typename PermuteDLayout -+ > -+struct DefaultGemm< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ arch::OpClassSimt, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ GemmShape<1, 1, 1>, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ 2, -+ SplitKSerial, -+ Operator, -+ SharedMemoryClear, -+ GatherA, -+ GatherB, -+ ScatterD, -+ PermuteDLayout, -+ typename platform::enable_if< ! platform::is_same::value >::type > { -+ -+ static_assert((platform::is_same::value -+ || platform::is_same>::value), -+ "Epilogue in the kernel level must be row major"); -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultMma< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementAccumulator, -+ LayoutC, -+ arch::OpClassSimt, -+ arch::Sm50, -+ ThreadblockShape, -+ WarpShape, -+ GemmShape<1, 1, 1>, -+ 2, -+ Operator, -+ false, -+ SharedMemoryClear, -+ GatherA, -+ GatherB>::ThreadblockMma; -+ -+ static int const kEpilogueElementsPerAccess = EpilogueOutputOp::kCount; -+ static_assert(kEpilogueElementsPerAccess == 1, "simt epilogue must operate on scalars"); -+ -+ /// Define the epilogue -+ using RegularEpilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ typename Mma::Operator, -+ EpilogueOutputOp, -+ kEpilogueElementsPerAccess, -+ ScatterD, -+ PermuteDLayout -+ >::Epilogue; -+ -+ using Affine2Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimtAffineRankN< -+ 2, -+ ThreadblockShape, -+ typename Mma::Operator, -+ EpilogueOutputOp, -+ kEpilogueElementsPerAccess -+ >::Epilogue; -+ -+ using Epilogue = typename platform::conditional::value, -+ RegularEpilogue, -+ Affine2Epilogue>::type; -+ -+ /// Define the kernel-level GEMM operator. -+ using GemmKernel = kernel::Gemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operand -+ typename LayoutC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear, -+ /// Gather operand A by using an index array -+ bool GatherA, -+ /// Gather operand B by using an index array -+ bool GatherB, -+ /// Scatter result D by using an index array -+ bool ScatterD, -+ /// Permute result D -+ typename PermuteDLayout -+> -+struct DefaultGemm, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ SplitKSerial, -+ Operator, -+ SharedMemoryClear, -+ GatherA, -+ GatherB, -+ ScatterD, -+ PermuteDLayout> { -+ -+ static_assert((platform::is_same::value -+ || platform::is_same>::value), -+ "Epilogue in the kernel level must be row major"); -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultMma< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, -+ ElementAccumulator, LayoutC, arch::OpClassSimt, arch::Sm80, -+ ThreadblockShape, WarpShape, GemmShape<1, 1, 1>, Stages, -+ Operator, false, SharedMemoryClear, GatherA, GatherB>::ThreadblockMma; -+ -+ static int const kEpilogueElementsPerAccess = EpilogueOutputOp::kCount; -+ static_assert(kEpilogueElementsPerAccess == 1, "simt epilogue must operate on scalars"); -+ -+ /// Define the epilogue -+ using RegularEpilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ typename Mma::Operator, -+ EpilogueOutputOp, -+ kEpilogueElementsPerAccess, -+ ScatterD, -+ PermuteDLayout -+ >::Epilogue; -+ -+ using Affine2Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimtAffineRankN< -+ 2, -+ ThreadblockShape, -+ typename Mma::Operator, -+ EpilogueOutputOp, -+ kEpilogueElementsPerAccess -+ >::Epilogue; -+ -+ using Epilogue = typename platform::conditional::value, -+ RegularEpilogue, -+ Affine2Epilogue>::type; -+ -+ /// Define the kernel-level GEMM operator. -+ using GemmKernel = kernel::Gemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// Partial specialization for SIMT DP4A -+ -+template < -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Layout type for C matrix operand -+ typename LayoutC, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear -+> -+struct DefaultGemm, -+ EpilogueOutputOp, ThreadblockSwizzle, 2, SplitKSerial, -+ Operator, SharedMemoryClear, false, false, false> { -+ using InstructionShape = GemmShape<1, 1, 4>; -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ -+ using OperatorClass = arch::OpClassSimt; -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultMma::ThreadblockMma; -+ -+ static int const kEpilogueElementsPerAccess = EpilogueOutputOp::kCount; -+ static_assert(kEpilogueElementsPerAccess == 1, "simt epilogue must operate on scalars"); -+ -+ /// Define the epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ typename Mma::Operator, -+ EpilogueOutputOp, -+ kEpilogueElementsPerAccess -+ >::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using GemmKernel = kernel::Gemm; -+}; -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+//////////////////////////////////////////////////////////////////////////////// -+/// Partial specialization for Wmma Gemm Kernel -+template < -+ ///< Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear -+> -+struct DefaultGemm< -+ ElementA, LayoutA, kAlignmentA, -+ ElementB, LayoutB, kAlignmentB, -+ ElementC, LayoutC, -+ ElementAccumulator, -+ arch::OpClassWmmaTensorOp, -+ ArchTag, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ SplitKSerial, -+ Operator, -+ SharedMemoryClear, -+ false, -+ false, -+ false -+> { -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultMma< -+ ElementA, LayoutA, kAlignmentA, -+ ElementB, LayoutB, kAlignmentB, -+ ElementAccumulator, LayoutC, -+ arch::OpClassWmmaTensorOp, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ Stages, -+ Operator>::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWmmaTensorOp< -+ ThreadblockShape, -+ typename Mma::Operator, -+ kPartitionsK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using GemmKernel = kernel::Gemm; -+}; -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif //CUTLASS_ARCH_WMMA_ENABLED -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_complex.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_complex.h -new file mode 100644 -index 0000000..956068b ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_complex.h -@@ -0,0 +1,404 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+ -+ Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are -+ accommodated by exchanging A and B operands and assuming transposed layouts. Partial -+ specializations here choose 'device::GemmTransposed' to implement this functionality. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/gemm.h" -+#include "cutlass/gemm/kernel/gemm_pipelined.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -+#include "cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h" -+#include "cutlass/gemm/threadblock/default_mma.h" -+#include "cutlass/gemm/threadblock/default_multistage_mma_complex.h" -+#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" -+ -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Multiply-add operator -+ // (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the epilogue -+ bool SplitKSerial -+> -+struct DefaultGemmComplex; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Hopper Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Multiply-add operator -+ // (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the epilogue -+ bool SplitKSerial -+ > -+struct DefaultGemmComplex< -+ ElementA, LayoutA, ElementB, LayoutB, ElementC, -+ layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp, -+ arch::Sm90, ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, Operator, SplitKSerial> { -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex< -+ ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, -+ layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, ThreadblockShape, -+ WarpShape, InstructionShape, Stages, TransformA, TransformB, Operator>::ThreadblockMma; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOp< -+ ThreadblockShape, typename Mma::Operator, 1, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, Operator>::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using GemmKernel = kernel::Gemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Multiply-add operator -+ // (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the epilogue -+ bool SplitKSerial -+ > -+struct DefaultGemmComplex< -+ ElementA, LayoutA, ElementB, LayoutB, ElementC, -+ layout::RowMajor, ElementAccumulator, arch::OpClassSimt, -+ arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, Operator, SplitKSerial> { -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementAccumulator, layout::RowMajor, -+ arch::OpClassSimt, -+ Stages, -+ Operator, -+ false, -+ cutlass::arch::CacheOperation::Global, -+ cutlass::arch::CacheOperation::Global, -+ TransformA, -+ TransformB -+ >; -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, -+ typename MmaCore::IteratorThreadMapA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, -+ typename MmaCore::IteratorThreadMapB>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using Mma = cutlass::gemm::threadblock::MmaPipelined< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, -+ layout::RowMajor, typename MmaCore::MmaPolicy>; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ typename Mma::Operator, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using GemmKernel = kernel::Gemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Multiply-add operator -+ // (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the epilogue -+ bool SplitKSerial -+ > -+struct DefaultGemmComplex< -+ ElementA, LayoutA, ElementB, LayoutB, ElementC, -+ layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp, -+ arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, Operator, SplitKSerial> { -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex< -+ ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, -+ layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, ThreadblockShape, -+ WarpShape, InstructionShape, Stages, TransformA, TransformB, Operator>::ThreadblockMma; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOp< -+ ThreadblockShape, typename Mma::Operator, 1, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, Operator>::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using GemmKernel = kernel::Gemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Multiply-add operator -+ // (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the epilogue -+ bool SplitKSerial -+ > -+struct DefaultGemmComplex< -+ ElementA, LayoutA, ElementB, LayoutB, ElementC, -+ layout::RowMajor, ElementAccumulator, arch::OpClassSimt, -+ arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, Operator, SplitKSerial> { -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex< -+ ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, -+ layout::RowMajor, arch::OpClassSimt, arch::Sm80, ThreadblockShape, -+ WarpShape, InstructionShape, Stages, TransformA, TransformB, Operator>::ThreadblockMma; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ ThreadblockShape, -+ typename Mma::Operator, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using GemmKernel = kernel::Gemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_grouped.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_grouped.h -new file mode 100644 -index 0000000..c44f060 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_grouped.h -@@ -0,0 +1,384 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+ -+ Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are -+ accommodated by exchanging A and B operands and assuming transposed layouts. Partial -+ specializations here choose 'device::GemmTransposed' to implement this functionality. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/complex.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/kernel/gemm_grouped.h" -+#include "cutlass/gemm/kernel/gemm_transpose_operands.h" -+#include "cutlass/gemm/kernel/default_gemm.h" -+#include "cutlass/gemm/kernel/default_gemm_complex.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+ -+#include "cutlass/layout/permute.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Whether the schedule of problems to visit has been precomputed -+ GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly, -+ /// Operation performed by GEMM -+ typename Operator = typename device::DefaultGemmConfiguration< -+ OperatorClass, ArchTag, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator>::Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, -+ /// Permute result D -+ typename PermuteDLayout = layout::NoPermute, -+ /// -+ typename Enable = void -+ > -+struct DefaultGemmGrouped; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Real-valued GEMM kernels -+// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Whether the schedule of problems to visit has been precomputed -+ GroupScheduleMode GroupScheduleMode_, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear, -+ /// Permute result D -+ typename PermuteDLayout -+> -+struct DefaultGemmGrouped< -+ ElementA, -+ LayoutA, -+ ComplexTransform::kNone, // transform A -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ ComplexTransform::kNone, // transform B -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ GroupScheduleMode_, -+ Operator, -+ SharedMemoryClear, -+ PermuteDLayout, -+ typename platform::enable_if< ! cutlass::is_complex::value>::type -+> { -+ -+ // If true, we must construct a 'transposed-and-exchanged' Mma operator. -+ static bool const kInternalTranspose = platform::is_same::value; -+ -+ using MapArguments = kernel::detail::MapArguments< -+ ElementA, -+ LayoutA, -+ ComplexTransform::kNone, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ ComplexTransform::kNone, -+ kAlignmentB, -+ LayoutC, -+ kInternalTranspose -+ >; -+ -+ // Define the default GEMM kernel -+ using DefaultGemmKernel = typename kernel::DefaultGemm< -+ typename MapArguments::ElementA, -+ typename MapArguments::LayoutA, -+ MapArguments::kAlignmentA, -+ typename MapArguments::ElementB, -+ typename MapArguments::LayoutB, -+ MapArguments::kAlignmentB, -+ ElementC, -+ typename MapArguments::LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ true, -+ Operator, -+ SharedMemoryClear, -+ false, /*GatherA*/ -+ false, /*GatherB*/ -+ false, /*ScatterD*/ -+ PermuteDLayout -+ >::GemmKernel; -+ -+ /// Define the kernel in terms of the default kernel -+ using GemmKernel = kernel::GemmGrouped< -+ typename DefaultGemmKernel::Mma, -+ typename DefaultGemmKernel::Epilogue, -+ ThreadblockSwizzle, -+ GroupScheduleMode_, -+ kInternalTranspose -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Complex-valued GEMM kernels -+// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Whether the schedule of problems to visit has been precomputed -+ GroupScheduleMode GroupScheduleMode_, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear -+ > -+struct DefaultGemmGrouped< -+ ElementA, -+ LayoutA, -+ TransformA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ TransformB, -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ GroupScheduleMode_, -+ Operator, -+ SharedMemoryClear, -+ layout::NoPermute, /*PermuteDLayout*/ -+ typename platform::enable_if::value>::type -+> { -+ -+ // If true, we must construct a 'transposed-and-exchanged' Mma operator. -+ static bool const kInternalTranspose = platform::is_same::value; -+ -+ using MapArguments = kernel::detail::MapArguments< -+ ElementA, -+ LayoutA, -+ TransformA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ TransformB, -+ kAlignmentB, -+ LayoutC, -+ kInternalTranspose -+ >; -+ -+ using DefaultGemmKernel = typename kernel::DefaultGemmComplex< -+ typename MapArguments::ElementA, -+ typename MapArguments::LayoutA, -+ typename MapArguments::ElementB, -+ typename MapArguments::LayoutB, -+ ElementC, -+ typename MapArguments::LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MapArguments::kTransformA, -+ MapArguments::kTransformB, -+ Operator, -+ false -+ >::GemmKernel; -+ -+ /// Define the kernel in terms of the default kernel -+ using GemmKernel = kernel::GemmGrouped< -+ typename DefaultGemmKernel::Mma, -+ typename DefaultGemmKernel::Epilogue, -+ ThreadblockSwizzle, -+ GroupScheduleMode_, -+ kInternalTranspose -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_grouped_softmax_mainloop_fusion.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_grouped_softmax_mainloop_fusion.h -new file mode 100644 -index 0000000..323ae5d ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_grouped_softmax_mainloop_fusion.h -@@ -0,0 +1,164 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level softmax-grouped-GEMM -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/complex.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/kernel/gemm_grouped_softmax_mainloop_fusion.h" -+#include "cutlass/gemm/kernel/gemm_transpose_operands.h" -+#include "cutlass/gemm/kernel/default_gemm.h" -+#include "cutlass/gemm/kernel/default_gemm_complex.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+#include "cutlass/gemm/threadblock/default_mma_softmax_mainloop_fusion.h" -+ -+#include "cutlass/layout/permute.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for Scale/Bias vectors -+ typename ElementScaleBias_, -+ /// Layout type for Scale/Bias vectors -+ typename LayoutScaleBias_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Whether the schedule of problems to visit has been precomputed -+ GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly, -+ /// Operation performed by GEMM -+ typename Operator = typename device::DefaultGemmConfiguration< -+ OperatorClass, ArchTag, ElementA_, ElementB_, ElementC_, -+ ElementAccumulator>::Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone -+ > -+struct DefaultGemmGroupedSoftmaxMainloopFusion { -+ // If true, we must construct a 'transposed-and-exchanged' Mma operator. -+ static bool const kInternalTranspose = platform::is_same::value; -+ -+ using MapArguments = kernel::detail::MapArguments< -+ ElementA_, -+ LayoutA_, -+ ComplexTransform::kNone, -+ kAlignmentA, -+ ElementB_, -+ LayoutB_, -+ ComplexTransform::kNone, -+ kAlignmentB, -+ LayoutC_, -+ kInternalTranspose -+ >; -+ -+private: -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultMmaSoftmaxMainloopFusion< -+ typename MapArguments::ElementA, typename MapArguments::LayoutA, MapArguments::kAlignmentA, -+ typename MapArguments::ElementB, typename MapArguments::LayoutB, MapArguments::kAlignmentB, -+ ElementScaleBias_, LayoutScaleBias_, ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, kInternalTranspose, -+ Operator, false, SharedMemoryClear>::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, -+ EpilogueOutputOp::kCount>::Epilogue; -+ -+public: -+ using GemmKernel = kernel::GemmGroupedSoftmaxMainloopFusion< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle, -+ GroupScheduleMode_, -+ kInternalTranspose -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_layernorm_mainloop_fusion.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_layernorm_mainloop_fusion.h -new file mode 100644 -index 0000000..76a405a ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_layernorm_mainloop_fusion.h -@@ -0,0 +1,137 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+ -+ Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are -+ accommodated by exchanging A and B operands and assuming transposed layouts. Partial -+ specializations here choose 'device::GemmTransposed' to implement this functionality. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/wmma.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/gemm_layernorm_mainloop_fusion.h" -+#include "cutlass/gemm/threadblock/default_mma_layernorm_mainloop_fusion.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for Scale/Bias vectors -+ typename ElementScaleBias, -+ /// Layout type for Scale/Bias vectors -+ typename LayoutScaleBias, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone> -+struct DefaultGemmLayernormMainloopFusion { -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultMmaLayernormMainloopFusion< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, -+ ElementScaleBias, LayoutScaleBias, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ Operator, false, SharedMemoryClear>::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, -+ EpilogueOutputOp::kCount>::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using GemmKernel = kernel::GemmLayernormMainloopFusion; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_planar_complex_universal.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_planar_complex_universal.h -new file mode 100644 -index 0000000..e3b58cb ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_planar_complex_universal.h -@@ -0,0 +1,352 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+ -+ Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are -+ accommodated by exchanging A and B operands and assuming transposed layouts. Partial -+ specializations here choose 'device::GemmTransposed' to implement this functionality. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/complex.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/kernel/gemm_planar_complex.h" -+#include "cutlass/gemm/kernel/gemm_planar_complex_array.h" -+#include "cutlass/gemm/kernel/default_gemm.h" -+#include "cutlass/gemm/kernel/default_gemm_complex.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_planar_complex.h" -+#include "cutlass/gemm/threadblock/default_mma_planar_complex_pipelined.h" -+#include "cutlass/gemm/threadblock/default_mma_planar_complex_multistage.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Math operation performed by GEMM (e.g. arch::OpMultiplyAdd) -+ typename Operator, -+ /// Conditional enabling to switch between stages -+ typename Enable = void -+ > -+struct DefaultGemmPlanarComplexUniversal; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for pipelined mainloop -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator -+ > -+struct DefaultGemmPlanarComplexUniversal< -+ ElementA, -+ LayoutA, -+ TransformA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ TransformB, -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ Operator, -+ typename platform::enable_if<(Stages <= 2)>::type -+> { -+ -+ /// Define planar complex valued variants instead -+ using Mma = typename gemm::threadblock::DefaultMmaPlanarComplexPipelined< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementAccumulator, -+ LayoutC, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ Stages, -+ TransformA, -+ TransformB, -+ Operator -+ >::ThreadblockMma; -+ -+ /// Planar complex epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpiloguePlanarComplex< -+ ThreadblockShape, -+ typename Mma::Policy::Operator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape::kK / WarpShape::kK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ /// Define the kernel in terms of the default kernel -+ using GemmKernel = kernel::GemmPlanarComplex< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle -+ >; -+ -+ // Array variant -+ using GemmArrayKernel = kernel::GemmPlanarComplexArray< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for multiple pipeline stages. -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator -+ > -+struct DefaultGemmPlanarComplexUniversal< -+ ElementA, -+ LayoutA, -+ TransformA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ TransformB, -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ Operator, -+ typename platform::enable_if<(Stages > 2)>::type -+> { -+ -+ /// Define planar complex valued variants instead -+ using Mma = typename gemm::threadblock::DefaultMmaPlanarComplexMultistage< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementAccumulator, -+ LayoutC, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ Stages, -+ TransformA, -+ TransformB, -+ Operator -+ >::ThreadblockMma; -+ -+ /// Planar complex epilogue -+ using Epilogue = typename epilogue::threadblock::DefaultEpiloguePlanarComplex< -+ ThreadblockShape, -+ typename Mma::Policy::Operator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape::kK / WarpShape::kK, -+ EpilogueOutputOp, -+ EpilogueOutputOp::kCount -+ >::Epilogue; -+ -+ /// Define the kernel in terms of the default kernel -+ using GemmKernel = kernel::GemmPlanarComplex< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle -+ >; -+ -+ // Array variant -+ using GemmArrayKernel = kernel::GemmPlanarComplexArray< -+ Mma, -+ Epilogue, -+ ThreadblockSwizzle -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_sparse.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_sparse.h -new file mode 100644 -index 0000000..7303e01 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_sparse.h -@@ -0,0 +1,191 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief -+ Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+ -+ Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are -+ accommodated by exchanging A and B operands and assuming transposed layouts. Partial -+ specializations here choose 'device::GemmTransposed' to implement this functionality. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/wmma.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/gemm.h" -+#include "cutlass/gemm/kernel/sparse_gemm.h" -+#include "cutlass/gemm/kernel/gemm_pipelined.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h" -+#include "cutlass/gemm/threadblock/default_sparse_mma.h" -+#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h" -+#endif //CUTLASS_ARCH_WMMA_ENABLED -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultSparseGemm; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultSparseGemm { -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultSparseMma< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ Operator>::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, -+ EpilogueOutputOp::kCount>::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using GemmKernel = kernel::SparseGemm; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_splitk_parallel.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_splitk_parallel.h -new file mode 100644 -index 0000000..7fc9da3 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_splitk_parallel.h -@@ -0,0 +1,136 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+ -+ Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are -+ accommodated by exchanging A and B operands and assuming transposed layouts. Partial -+ specializations here choose 'device::GemmTransposed' to implement this functionality. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/kernel/default_gemm.h" -+#include "cutlass/gemm/kernel/gemm_splitk_parallel.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator -+> -+struct DefaultGemmSplitKParallel { -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate using the basic GEMM's -+ /// mainloop. -+ using Default = DefaultGemm< -+ ElementA_, -+ LayoutA_, -+ kAlignmentA, -+ ElementB_, -+ LayoutB_, -+ kAlignmentB, -+ ElementAccumulator, -+ LayoutC_, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ false, -+ Operator -+ >; -+ -+ /// Define the matrix multiply operator -+ using Mma = typename Default::Mma; -+ -+ /// Define the epilogue -+ using Epilogue = typename Default::Epilogue; -+ -+ /// Define the kernel-level GEMM operator. -+ using GemmKernel = kernel::GemmSplitKParallel; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_universal.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_universal.h -new file mode 100644 -index 0000000..45a825d ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_universal.h -@@ -0,0 +1,382 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+ -+ Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are -+ accommodated by exchanging A and B operands and assuming transposed layouts. Partial -+ specializations here choose 'device::GemmTransposed' to implement this functionality. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/complex.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/kernel/gemm_universal.h" -+#include "cutlass/gemm/kernel/gemm_universal_streamk.h" -+#include "cutlass/gemm/kernel/default_gemm.h" -+#include "cutlass/gemm/kernel/default_gemm_complex.h" -+ -+#include "cutlass/layout/permute.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, -+ /// Gather operand A by using an index array -+ bool GatherA = false, -+ /// Gather operand B by using an index array -+ bool GatherB = false, -+ /// Scatter result D by using an index array -+ bool ScatterD = false, -+ /// Permute result D -+ typename PermuteDLayout = layout::NoPermute, -+ /// -+ typename Enable = void -+ > -+struct DefaultGemmUniversal; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Real-valued GEMM kernels -+// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear, -+ /// Gather operand A by using an index array -+ bool GatherA, -+ /// Gather operand B by using an index array -+ bool GatherB, -+ /// Scatter result D by using an index array -+ bool ScatterD, -+ /// Permute result D -+ typename PermuteDLayout -+> -+struct DefaultGemmUniversal< -+ ElementA, -+ LayoutA, -+ ComplexTransform::kNone, // transform A -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ ComplexTransform::kNone, // transform B -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ Operator, -+ SharedMemoryClear, -+ GatherA, -+ GatherB, -+ ScatterD, -+ PermuteDLayout, -+ typename platform::enable_if< ! cutlass::is_complex::value>::type -+> { -+ -+ using DefaultGemmKernel = typename kernel::DefaultGemm< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ true, -+ Operator, -+ SharedMemoryClear, -+ GatherA, -+ GatherB, -+ ScatterD, -+ PermuteDLayout -+ >::GemmKernel; -+ -+ /// Universal kernel without StreamkFeature member type -+ template -+ class SelectBase : -+ public kernel::GemmUniversal< -+ typename DefaultGemmKernel::Mma, -+ typename DefaultGemmKernel::Epilogue, -+ SwizzleT> -+ {}; -+ -+ /// Universal kernel with StreamkFeature member type -+ template -+ class SelectBase : -+ public kernel::GemmUniversalStreamk< -+ typename DefaultGemmKernel::Mma, -+ typename DefaultGemmKernel::Epilogue, -+ SwizzleT> -+ {}; -+ -+ /// Select kernel by ThreadblockSwizzle's support for StreamkFeature -+ using GemmKernel = SelectBase; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Complex-valued GEMM kernels -+// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear -+ > -+struct DefaultGemmUniversal< -+ ElementA, -+ LayoutA, -+ TransformA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ TransformB, -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ Operator, -+ SharedMemoryClear, -+ false, -+ false, -+ false, -+ layout::NoPermute, -+ typename platform::enable_if::value>::type -+> { -+ -+ using DefaultGemmKernel = typename kernel::DefaultGemmComplex< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ TransformA, -+ TransformB, -+ Operator, -+ false -+ >::GemmKernel; -+ -+ /// Universal kernel without StreamkFeature member type -+ template -+ class SelectBase : -+ public kernel::GemmUniversal< -+ typename DefaultGemmKernel::Mma, -+ typename DefaultGemmKernel::Epilogue, -+ SwizzleT> -+ {}; -+ -+ /// Universal kernel with StreamkFeature member type -+ template -+ class SelectBase : -+ public kernel::GemmUniversalStreamk< -+ typename DefaultGemmKernel::Mma, -+ typename DefaultGemmKernel::Epilogue, -+ SwizzleT> -+ {}; -+ -+ /// Select kernel by ThreadblockSwizzle's support for StreamkFeature -+ using GemmKernel = SelectBase; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_with_broadcast.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_with_broadcast.h -new file mode 100644 -index 0000000..1356b49 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_with_broadcast.h -@@ -0,0 +1,243 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Defines a GEMM with Reduction based on an existing UniversalGemm kernel. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/kernel/gemm_with_fused_epilogue.h" -+#include "cutlass/gemm/kernel/default_gemm_universal.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h" -+#include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// -+ typename Enable = void -+> -+struct DefaultGemmWithBroadcast { -+ -+ using GemmBase = typename DefaultGemmUniversal< -+ ElementA_, LayoutA_, TransformA, kAlignmentA, -+ ElementB_, LayoutB_, TransformB, kAlignmentB, -+ ElementC_, LayoutC_, ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ Operator -+ >::GemmKernel; -+ -+ // Replace epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithBroadcastTensorOp< -+ typename GemmBase::Epilogue::Shape, -+ typename GemmBase::Epilogue::WarpMmaOperator, -+ GemmBase::Epilogue::kPartitionsK, -+ ElementC_, -+ typename EpilogueOutputOp::ElementT, -+ ElementC_, -+ EpilogueOutputOp, -+ GemmBase::Epilogue::kElementsPerAccess -+ >::Epilogue; -+ -+ // Compose the GEMM kernel -+ using GemmKernel = GemmWithFusedEpilogue< -+ typename GemmBase::Mma, -+ Epilogue, -+ ThreadblockSwizzle -+ >; -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Parital specialization: ArchTag = cutlass::arch::Sm70 -+/// -+/// -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// -+ typename Enable -+> -+struct DefaultGemmWithBroadcast< -+ ElementA_, LayoutA_, TransformA, kAlignmentA, -+ ElementB_, LayoutB_, TransformB, kAlignmentB, -+ ElementC_, LayoutC_, -+ ElementAccumulator, -+ OperatorClass, -+ cutlass::arch::Sm70, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ Operator, -+ Enable -+ > { -+ -+ using GemmBase = typename DefaultGemmUniversal< -+ ElementA_, LayoutA_, TransformA, kAlignmentA, -+ ElementB_, LayoutB_, TransformB, kAlignmentB, -+ ElementC_, LayoutC_, ElementAccumulator, -+ OperatorClass, -+ cutlass::arch::Sm70, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ Operator -+ >::GemmKernel; -+ -+ // Replace epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithBroadcastVoltaTensorOp< -+ typename GemmBase::Epilogue::Shape, -+ typename GemmBase::Epilogue::WarpMmaOperator, -+ GemmBase::Epilogue::kPartitionsK, -+ ElementC_, -+ typename EpilogueOutputOp::ElementT, -+ ElementC_, -+ EpilogueOutputOp, -+ GemmBase::Epilogue::kElementsPerAccess -+ >::Epilogue; -+ -+ // Compose the GEMM kernel -+ using GemmKernel = GemmWithFusedEpilogue< -+ typename GemmBase::Mma, -+ Epilogue, -+ ThreadblockSwizzle -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_with_k_reduction.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_with_k_reduction.h -new file mode 100644 -index 0000000..422db5c ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_with_k_reduction.h -@@ -0,0 +1,150 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+ -+ Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are -+ accommodated by exchanging A and B operands and assuming transposed layouts. Partial -+ specializations here choose 'device::GemmTransposed' to implement this functionality. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/wmma.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/gemm_with_k_reduction.h" -+#include "cutlass/gemm/threadblock/default_mma_with_reduction.h" -+#include "cutlass/gemm/threadblock/default_mma_core_with_reduction.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -+#include "cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Reduce A or B along the K dimension -+ bool ReduceKForA_, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, -+ /// -+ typename Enable = void> -+struct DefaultGemmWithKReduction { -+ -+ static const bool kReduceKForA = (platform::is_same::value) ? ReduceKForA_ : !ReduceKForA_; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultMmaWithReduction< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, kReduceKForA, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ Operator, false, SharedMemoryClear>::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, -+ EpilogueOutputOp::kCount>::Epilogue; -+ -+ /// Define the epilogue of the reduction vector -+ using EpilogueGemmKReduction = -+ typename cutlass::epilogue::threadblock::EpilogueGemmKReduction< -+ ElementAccumulator, ElementC, ThreadblockShape, typename Mma::Operator, kReduceKForA>; -+ -+ /// Define the kernel-level GEMM operator. -+ using GemmKernel = kernel::GemmWithKReduction; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_with_reduction.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_with_reduction.h -new file mode 100644 -index 0000000..6e9e647 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemm_with_reduction.h -@@ -0,0 +1,246 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Defines a GEMM with Reduction based on an existing UniversalGemm kernel. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/kernel/gemm_with_fused_epilogue.h" -+#include "cutlass/gemm/kernel/default_gemm_universal.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_with_reduction.h" -+#include "cutlass/epilogue/threadblock/epilogue_with_reduction.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Epilogue reduction operator -+ typename EpilogueReductionOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// -+ typename Enable = void -+> -+struct DefaultGemmWithReduction { -+ -+ using GemmBase = typename DefaultGemmUniversal< -+ ElementA_, LayoutA_, TransformA, kAlignmentA, -+ ElementB_, LayoutB_, TransformB, kAlignmentB, -+ ElementC_, LayoutC_, ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ Operator, -+ SharedMemoryClearOption::kClearLastStage -+ >::GemmKernel; -+ -+ // Replace epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< -+ typename GemmBase::Epilogue::Shape, -+ typename GemmBase::Epilogue::WarpMmaOperator, -+ GemmBase::Epilogue::kPartitionsK, -+ ElementC_, -+ EpilogueOutputOp, -+ EpilogueReductionOp, -+ GemmBase::Epilogue::kElementsPerAccess -+ >::Epilogue; -+ -+ // Compose the GEMM kernel -+ using GemmKernel = GemmWithFusedEpilogue< -+ typename GemmBase::Mma, -+ Epilogue, -+ ThreadblockSwizzle -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Parital specialization: ArchTag = cutlass::arch::Sm70 -+/// -+/// -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Epilogue reduction operator -+ typename EpilogueReductionOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// -+ typename Enable -+> -+struct DefaultGemmWithReduction< -+ ElementA_, LayoutA_, TransformA, kAlignmentA, -+ ElementB_, LayoutB_, TransformB, kAlignmentB, -+ ElementC_, LayoutC_, -+ ElementAccumulator, -+ OperatorClass, -+ cutlass::arch::Sm70, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ EpilogueReductionOp, -+ ThreadblockSwizzle, -+ Stages, -+ Operator, -+ Enable -+ > { -+ -+ using GemmBase = typename DefaultGemmUniversal< -+ ElementA_, LayoutA_, TransformA, kAlignmentA, -+ ElementB_, LayoutB_, TransformB, kAlignmentB, -+ ElementC_, LayoutC_, ElementAccumulator, -+ OperatorClass, -+ cutlass::arch::Sm70, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ Operator -+ >::GemmKernel; -+ -+ // Replace epilogue -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionVoltaTensorOp< -+ typename GemmBase::Epilogue::Shape, -+ typename GemmBase::Epilogue::WarpMmaOperator, -+ GemmBase::Epilogue::kPartitionsK, -+ ElementC_, -+ EpilogueOutputOp, -+ EpilogueReductionOp, -+ GemmBase::Epilogue::kElementsPerAccess -+ >::Epilogue; -+ -+ // Compose the GEMM kernel -+ using GemmKernel = GemmWithFusedEpilogue< -+ typename GemmBase::Mma, -+ Epilogue, -+ ThreadblockSwizzle -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemv.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemv.h -new file mode 100755 -index 0000000..263930c ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_gemv.h -@@ -0,0 +1,132 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include "cutlass/gemm/threadblock/gemv.h" -+#include "cutlass/gemm/threadblock/default_gemv_core.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Size of the ThreadBlock tile - concept: gemm::GemmShape<> -+ typename ThreadBlockShape_, -+ /// Size of the per-thread shape - concept: gemm::GemmShape<> -+ typename ThreadShape_, -+ /// Data type of A elements -+ typename ElementA_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename ElementB_, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Element type of C/D matrix -+ typename ElementCD_, -+ /// Layout of C/D matrix (concept: MatrixLayout) -+ typename LayoutCD_, -+ /// Data type of the accumulator -+ typename ElementAccumulator_ = ElementCD_> -+struct DefaultGemv { -+ -+ /// Shape of Threadblock-level matrix operation (concept: GemmShape) -+ using ThreadBlockShape = ThreadBlockShape_; -+ -+ /// Shape of warp-level matrix operation (concept: GemmShape) -+ using ThreadShape = ThreadShape_; -+ -+ /// Data type of multiplicand A -+ using ElementA = ElementA_; -+ -+ /// Layout of multiplicand A -+ using LayoutA = LayoutA_; -+ -+ /// Data type of multiplicand B -+ using ElementB = ElementB_; -+ -+ /// Layout of multiplicand B -+ using LayoutB = LayoutB_; -+ -+ /// Data type of accumulators -+ using ElementAccumulator = ElementAccumulator_; -+ -+ /// Data type of accumulators (same as C/D) -+ using LayoutAccumulator = LayoutCD_; -+ -+ /// Data type of input/output matrix C/D -+ using ElementCD = ElementCD_; -+ -+ /// Layout of input/output matrix C/D -+ using LayoutCD = LayoutCD_; -+ -+ // Define the core components -+ using Core = typename cutlass::gemm::threadblock::DefaultGemvCore< -+ ThreadBlockShape, ThreadShape, ElementA, LayoutA, ElementB, LayoutB, -+ ElementAccumulator, LayoutAccumulator>; -+ -+ // Define the threadblock-scoped gemv -+ using ThreadBlockGemv = cutlass::gemm::threadblock::Gemv; -+ -+ // Iterator for multiplicand A -+ using IteratorA = typename ThreadBlockGemv::IteratorA; -+ -+ // Iterator for multiplicand B -+ using IteratorB = typename ThreadBlockGemv::IteratorB; -+ -+ /// Policy for the iterator that reads/writes C/D -+ using IteratorPolicyCD = typename platform::conditional< -+ platform::is_same::value, -+ cutlass::transform::PitchLinearTilePolicyStripminedThreadContiguous< -+ layout::PitchLinearShape, Core::kThreadsPerN, ThreadShape::kN>, -+ cutlass::transform::PitchLinearTilePolicyStripminedThreadStrided< -+ layout::PitchLinearShape, Core::kThreadsPerN, ThreadShape::kM>>::type; -+ -+ /// Iterator that reads/writes C/D -+ using IteratorCD = cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, ElementCD, LayoutCD, 0, IteratorPolicyCD>; -+ -+ /// Fragment storage for C/D -+ using FragmentCD = typename IteratorCD::Fragment; -+ -+ // Define the threadblock swizzle -+ using ThreadBlockSwizzle = cutlass::gemm::threadblock::GemvBatchedStridedThreadblockDefaultSwizzle; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_2k.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_2k.h -new file mode 100644 -index 0000000..4573a3a ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_2k.h -@@ -0,0 +1,285 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level Rank2K definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/arch/wmma.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/rank_2k_universal.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/gemm/threadblock/default_mma.h" -+#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op_blas3.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h" -+#endif //CUTLASS_ARCH_WMMA_ENABLED -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Blas3 computation mode -+ BlasMode BlasMode_ = BlasMode::kSymmetric> -+struct DefaultRank2K; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Hopper Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultRank2K< -+ ElementA, LayoutA, kAlignmentA, -+ ElementB, LayoutB, kAlignmentB, -+ ElementC,layout::RowMajor, FillModeC, -+ ElementAccumulator, arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, SplitKSerial, -+ Operator> { -+ /// Define the threadblock-scoped matrix multiply-accumulate (A x BT) -+ using Mma1 = typename cutlass::gemm::threadblock::DefaultMma< -+ ElementA, LayoutA, -+ kAlignmentA, -+ ElementB, typename layout::LayoutTranspose::type, -+ kAlignmentB, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ Operator>::ThreadblockMma; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate (B x AT) -+ using Mma2 = typename cutlass::gemm::threadblock::DefaultMma< -+ ElementB, LayoutB, -+ kAlignmentB, -+ ElementA, typename layout::LayoutTranspose::type, -+ kAlignmentA, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ Operator>::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOpBlas3< -+ ThreadblockShape, typename Mma1::Operator, kPartitionsK, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, BlasMode::kSymmetric>::Epilogue; -+ -+ /// Define the kernel-level Rank2K operator. -+ using Rank2Kkernel = kernel::Rank2KUniversal; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultRank2K< -+ ElementA, LayoutA, kAlignmentA, -+ ElementB, LayoutB, kAlignmentB, -+ ElementC,layout::RowMajor, FillModeC, -+ ElementAccumulator, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, SplitKSerial, -+ Operator> { -+ /// Define the threadblock-scoped matrix multiply-accumulate (A x BT) -+ using Mma1 = typename cutlass::gemm::threadblock::DefaultMma< -+ ElementA, LayoutA, -+ kAlignmentA, -+ ElementB, typename layout::LayoutTranspose::type, -+ kAlignmentB, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ Operator>::ThreadblockMma; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate (B x AT) -+ using Mma2 = typename cutlass::gemm::threadblock::DefaultMma< -+ ElementB, LayoutB, -+ kAlignmentB, -+ ElementA, typename layout::LayoutTranspose::type, -+ kAlignmentA, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ Operator>::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOpBlas3< -+ ThreadblockShape, typename Mma1::Operator, kPartitionsK, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, BlasMode::kSymmetric>::Epilogue; -+ -+ /// Define the kernel-level Rank2K operator. -+ using Rank2Kkernel = kernel::Rank2KUniversal; -+}; -+//////////////////////////////////////////////////////////////////////////////// -+ -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_2k_complex.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_2k_complex.h -new file mode 100644 -index 0000000..dc34fe9 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_2k_complex.h -@@ -0,0 +1,498 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level Rank2K definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/arch/wmma.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/rank_2k_universal.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/gemm/threadblock/default_mma.h" -+#include "cutlass/gemm/threadblock/default_multistage_mma_complex.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op_blas3.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h" -+#endif //CUTLASS_ARCH_WMMA_ENABLED -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Blas3 computation mode -+ BlasMode BlasMode_ = BlasMode::kSymmetric> -+struct DefaultRank2KComplex; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+namespace detail { -+ -+template < -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Complex elementwise transformation -+ ComplexTransform TransformA, -+ /// Complex elementwise transformation -+ ComplexTransform TransformB, -+ /// Blas3 computation mode (symmetric/hermitian) -+ BlasMode BlasMode_ -+ > struct Rank2KTransposedComplexTransform { -+ -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ -+}; -+ -+ // partial specializations for HER2K CUBLAS_OP_N layout (ColumMajor) -+template <> -+ struct Rank2KTransposedComplexTransform < -+ layout::ColumnMajor, layout::ColumnMajor, -+ ComplexTransform::kNone, ComplexTransform::kNone, -+ BlasMode::kHermitian> { -+ -+ static ComplexTransform const kTransformA = ComplexTransform::kConjugate; -+ static ComplexTransform const kTransformB = ComplexTransform::kNone; -+ -+}; -+ -+ // partial specializations for HER2K CUBLAS_OP_C layout (RowMajor + Complex conjugate) -+template <> -+ struct Rank2KTransposedComplexTransform < -+ layout::RowMajor, layout::RowMajor, -+ ComplexTransform::kConjugate, ComplexTransform::kConjugate, -+ BlasMode::kHermitian> { -+ -+ static ComplexTransform const kTransformA = ComplexTransform::kNone; -+ static ComplexTransform const kTransformB = ComplexTransform::kConjugate; -+ -+}; -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Hopper Architecture complex datatype (symmetric) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial> -+struct DefaultRank2KComplex< -+ ElementA, LayoutA, ElementB, LayoutB, ElementC, -+ layout::RowMajor, FillModeC, ElementAccumulator, arch::OpClassTensorOp, -+ arch::Sm90, ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, -+ TransformA, TransformB, Operator, SplitKSerial, BlasMode::kSymmetric> { -+ -+ static BlasMode const kBlasMode = BlasMode::kSymmetric; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate (A x B^T) -+ using Mma1 = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex< -+ ElementA, LayoutA, -+ ElementB, typename layout::LayoutTranspose::type, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ TransformA, TransformB, Operator>::ThreadblockMma; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate (B x A^T) -+ using Mma2 = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex< -+ ElementB, LayoutB, -+ ElementA, typename layout::LayoutTranspose::type, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ TransformA, TransformB, Operator>::ThreadblockMma; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOpBlas3< -+ ThreadblockShape, typename Mma1::Operator, 1, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, Operator, kBlasMode>::Epilogue; -+ -+ /// Define the kernel-level Rank2K operator. -+ using Rank2Kkernel = kernel::Rank2KUniversal; -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Hopper Architecture complex datatype (hermitian) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial> -+struct DefaultRank2KComplex< -+ ElementA, LayoutA, ElementB, LayoutB, ElementC, -+ layout::RowMajor, FillModeC, ElementAccumulator, arch::OpClassTensorOp, -+ arch::Sm90, ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, -+ TransformA, TransformB, Operator, SplitKSerial, BlasMode::kHermitian> { -+ -+ static BlasMode const kBlasMode = BlasMode::kHermitian; -+ -+ // Complex transform for input A and B matrices (function on input layout) -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ -+ using TransposedComplexTransform = detail::Rank2KTransposedComplexTransform< -+ LayoutA, LayoutB, -+ TransformA, TransformB, -+ kBlasMode>; -+ -+ // Complex transform on operandA and operandB (function of blas3 computation) -+ static ComplexTransform const kTransformOperandA = TransposedComplexTransform::kTransformA; -+ static ComplexTransform const kTransformOperandB = TransposedComplexTransform::kTransformB; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate (A x B^H) -+ using Mma1 = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex< -+ ElementA, LayoutA, -+ ElementB, typename layout::LayoutTranspose::type, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ kTransformOperandA, kTransformOperandB, Operator>::ThreadblockMma; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate (B x A^H) -+ using Mma2 = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex< -+ ElementB, LayoutB, -+ ElementA, typename layout::LayoutTranspose::type, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ kTransformOperandA, kTransformOperandB, Operator>::ThreadblockMma; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOpBlas3< -+ ThreadblockShape, typename Mma1::Operator, 1, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, Operator, kBlasMode>::Epilogue; -+ -+ /// Define the kernel-level Rank2K operator. -+ using Rank2Kkernel = kernel::Rank2KUniversal; -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Architecture complex datatype (symmetric) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial> -+struct DefaultRank2KComplex< -+ ElementA, LayoutA, ElementB, LayoutB, ElementC, -+ layout::RowMajor, FillModeC, ElementAccumulator, arch::OpClassTensorOp, -+ arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, -+ TransformA, TransformB, Operator, SplitKSerial, BlasMode::kSymmetric> { -+ -+ static BlasMode const kBlasMode = BlasMode::kSymmetric; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate (A x B^T) -+ using Mma1 = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex< -+ ElementA, LayoutA, -+ ElementB, typename layout::LayoutTranspose::type, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ TransformA, TransformB, Operator>::ThreadblockMma; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate (B x A^T) -+ using Mma2 = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex< -+ ElementB, LayoutB, -+ ElementA, typename layout::LayoutTranspose::type, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ TransformA, TransformB, Operator>::ThreadblockMma; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOpBlas3< -+ ThreadblockShape, typename Mma1::Operator, 1, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, Operator, kBlasMode>::Epilogue; -+ -+ /// Define the kernel-level Rank2K operator. -+ using Rank2Kkernel = kernel::Rank2KUniversal; -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Architecture complex datatype (hermitian) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial> -+struct DefaultRank2KComplex< -+ ElementA, LayoutA, ElementB, LayoutB, ElementC, -+ layout::RowMajor, FillModeC, ElementAccumulator, arch::OpClassTensorOp, -+ arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, -+ TransformA, TransformB, Operator, SplitKSerial, BlasMode::kHermitian> { -+ -+ static BlasMode const kBlasMode = BlasMode::kHermitian; -+ -+ // Complex transform for input A and B matrices (function on input layout) -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ -+ using TransposedComplexTransform = detail::Rank2KTransposedComplexTransform< -+ LayoutA, LayoutB, -+ TransformA, TransformB, -+ kBlasMode>; -+ -+ // Complex transform on operandA and operandB (function of blas3 computation) -+ static ComplexTransform const kTransformOperandA = TransposedComplexTransform::kTransformA; -+ static ComplexTransform const kTransformOperandB = TransposedComplexTransform::kTransformB; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate (A x B^H) -+ using Mma1 = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex< -+ ElementA, LayoutA, -+ ElementB, typename layout::LayoutTranspose::type, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ kTransformOperandA, kTransformOperandB, Operator>::ThreadblockMma; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate (B x A^H) -+ using Mma2 = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex< -+ ElementB, LayoutB, -+ ElementA, typename layout::LayoutTranspose::type, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ kTransformOperandA, kTransformOperandB, Operator>::ThreadblockMma; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOpBlas3< -+ ThreadblockShape, typename Mma1::Operator, 1, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, Operator, kBlasMode>::Epilogue; -+ -+ /// Define the kernel-level Rank2K operator. -+ using Rank2Kkernel = kernel::Rank2KUniversal; -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_2k_grouped.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_2k_grouped.h -new file mode 100644 -index 0000000..a237125 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_2k_grouped.h -@@ -0,0 +1,355 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level grouped Rank2K. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/complex.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/kernel/rank_2k_transpose_operands.h" -+#include "cutlass/gemm/kernel/default_rank_2k.h" -+#include "cutlass/gemm/kernel/default_rank_2k_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Blas3 computation mode -+ BlasMode BlasMode_ = BlasMode::kSymmetric, -+ /// Whether the schedule of problems to visit has been precomputed -+ GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly, -+ /// -+ typename Enable = void -+ > -+struct DefaultRank2KGrouped; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Real-valued grouped Rank2K -+// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Blas3 computation mode -+ BlasMode BlasMode_, -+ /// Whether the schedule of problems to visit has been precomputed -+ GroupScheduleMode GroupScheduleMode_ -+ > -+struct DefaultRank2KGrouped::value>::type -+> { -+ // If true, we must construct a 'transposed-and-exchanged' Rank2K operator. -+ static bool const kInternalTranspose = platform::is_same::value; -+ -+ using MapArguments = kernel::detail::Rank2KMapArguments< -+ ElementA, -+ LayoutA, -+ TransformA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ TransformB, -+ kAlignmentB, -+ LayoutC, -+ FillModeC, -+ kInternalTranspose -+ >; -+ -+ // Define the default grouped Rank2K kernel -+ using DefaultRank2Kkernel = typename kernel::DefaultRank2K< -+ typename MapArguments::ElementA, -+ typename MapArguments::LayoutA, -+ MapArguments::kAlignmentA, -+ typename MapArguments::ElementB, -+ typename MapArguments::LayoutB, -+ MapArguments::kAlignmentB, -+ ElementC, -+ typename MapArguments::LayoutC, -+ MapArguments::kFillModeC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ false, // SplitKSerial -+ Operator, -+ BlasMode_ -+ >::Rank2Kkernel; -+ -+ /// Define the kernel in terms of the default kernel -+ using Rank2Kkernel = kernel::Rank2KGrouped< -+ typename DefaultRank2Kkernel::Mma1, -+ typename DefaultRank2Kkernel::Mma2, -+ typename DefaultRank2Kkernel::Epilogue, -+ ThreadblockSwizzle, -+ TransformA, -+ TransformB, -+ DefaultRank2Kkernel::kFillModeC, -+ DefaultRank2Kkernel::kBlasMode, -+ GroupScheduleMode_, -+ kInternalTranspose -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Complex-valued grouped Rank2K -+// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Blas3 computation mode -+ BlasMode BlasMode_, -+ /// Whether the schedule of problems to visit has been precomputed -+ GroupScheduleMode GroupScheduleMode_ -+ > -+struct DefaultRank2KGrouped::value>::type -+> { -+ // If true, we must construct a 'transposed-and-exchanged' Rank2K operator. -+ static bool const kInternalTranspose = platform::is_same::value; -+ -+ using MapArguments = kernel::detail::Rank2KMapArguments< -+ ElementA, -+ LayoutA, -+ TransformA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ TransformB, -+ kAlignmentB, -+ LayoutC, -+ FillModeC, -+ kInternalTranspose -+ >; -+ -+ // Define the default grouped Rank2K kernel -+ using DefaultRank2Kkernel = typename kernel::DefaultRank2KComplex< -+ typename MapArguments::ElementA, -+ typename MapArguments::LayoutA, -+ typename MapArguments::ElementB, -+ typename MapArguments::LayoutB, -+ ElementC, -+ typename MapArguments::LayoutC, -+ MapArguments::kFillModeC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ MapArguments::kTransformA, -+ MapArguments::kTransformB, -+ Operator, -+ false, // SplitKSerial -+ BlasMode_ -+ >::Rank2Kkernel; -+ -+ /// Define the kernel in terms of the default kernel -+ /// Pass through the user-provided TransformA and TransformB so as to -+ /// correctly set public-facing TransformA and TransformB in kernel::Rank2KGrouped. -+ /// This is needed because kernel::DefaultRank2KComplex may change TransformA and -+ /// TransformB that become template arguments to Mma1 and Mma2. -+ using Rank2Kkernel = kernel::Rank2KGrouped< -+ typename DefaultRank2Kkernel::Mma1, -+ typename DefaultRank2Kkernel::Mma2, -+ typename DefaultRank2Kkernel::Epilogue, -+ ThreadblockSwizzle, -+ TransformA, -+ TransformB, -+ DefaultRank2Kkernel::kFillModeC, -+ DefaultRank2Kkernel::kBlasMode, -+ GroupScheduleMode_, -+ kInternalTranspose -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_2k_universal.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_2k_universal.h -new file mode 100644 -index 0000000..9651300 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_2k_universal.h -@@ -0,0 +1,346 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level Rank 2k definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+ -+ Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are -+ accommodated by exchanging A and B operands and assuming transposed layouts. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+ -+#include "cutlass/complex.h" -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/gemm/kernel/rank_2k_universal.h" -+#include "cutlass/gemm/kernel/default_rank_2k.h" -+#include "cutlass/gemm/kernel/default_rank_2k_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by SYRK -+ typename Operator, -+ /// Blas3 computation mode (symmetric/hermitian) -+ BlasMode BlasMode_ = BlasMode::kSymmetric, -+ /// -+ typename Enable = void -+ > -+struct DefaultRank2KUniversal; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Real-valued Rank 2k update kernels -+// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by Rank2k -+ typename Operator> -+struct DefaultRank2KUniversal< -+ ElementA, -+ LayoutA, -+ ComplexTransform::kNone, // transform A -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ ComplexTransform::kNone, // transform B -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ FillModeC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ SplitKSerial, -+ Operator, -+ BlasMode::kSymmetric, -+ typename std::enable_if< ! cutlass::is_complex::value>::type -+> { -+ -+ using DefaultRank2Kkernel = typename kernel::DefaultRank2K< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ FillModeC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ SplitKSerial, -+ Operator, -+ BlasMode::kSymmetric -+ >::Rank2Kkernel; -+ -+ /// Define the kernel in terms of the default kernel -+ using Rank2Kkernel = kernel::Rank2KUniversal< -+ typename DefaultRank2Kkernel::Mma1, -+ typename DefaultRank2Kkernel::Mma2, -+ typename DefaultRank2Kkernel::Epilogue, -+ ThreadblockSwizzle, -+ FillModeC, -+ BlasMode::kSymmetric -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Complex-valued Rank 2K update kernels -+// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by SYRK -+ typename Operator, -+ // BlasMode -+ BlasMode kBlasMode -+ > -+ -+struct DefaultRank2KUniversal< -+ ElementA, -+ LayoutA, -+ TransformA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ TransformB, -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ FillModeC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ SplitKSerial, -+ Operator, -+ kBlasMode, -+ typename std::enable_if::value>::type -+> { -+ -+ using DefaultRank2Kkernel = typename kernel::DefaultRank2KComplex< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ FillModeC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ TransformA, -+ TransformB, -+ Operator, -+ SplitKSerial, -+ kBlasMode -+ >::Rank2Kkernel; -+ -+ /// Define the kernel in terms of the default kernel -+ using Rank2Kkernel = kernel::Rank2KUniversal< -+ typename DefaultRank2Kkernel::Mma1, -+ typename DefaultRank2Kkernel::Mma2, -+ typename DefaultRank2Kkernel::Epilogue, -+ ThreadblockSwizzle, -+ FillModeC, -+ kBlasMode -+ >; -+}; -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_k.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_k.h -new file mode 100644 -index 0000000..2c0c7a8 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_k.h -@@ -0,0 +1,247 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level RankK definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/arch/wmma.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/rank_k_universal.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/gemm/threadblock/default_mma.h" -+#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op_blas3.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h" -+#endif //CUTLASS_ARCH_WMMA_ENABLED -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Blas3 computation mode -+ BlasMode BlasMode_ = BlasMode::kSymmetric> -+struct DefaultRankK; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Hopper Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultRankK< -+ ElementA, LayoutA, kAlignmentA, -+ ElementC,layout::RowMajor, FillModeC, -+ ElementAccumulator, arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, SplitKSerial, -+ Operator> { -+ /// Define the threadblock-scoped matrix multiply-accumulate (A x AT) -+ using Mma = typename cutlass::gemm::threadblock::DefaultMma< -+ ElementA, LayoutA, -+ kAlignmentA, -+ ElementA, typename layout::LayoutTranspose::type, -+ kAlignmentA, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ Operator>::ThreadblockMma; -+ -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOpBlas3< -+ ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, BlasMode::kSymmetric>::Epilogue; -+ -+ /// Define the kernel-level Rank2 operator. -+ using RankKkernel = kernel::RankKUniversal; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultRankK< -+ ElementA, LayoutA, kAlignmentA, -+ ElementC,layout::RowMajor, FillModeC, -+ ElementAccumulator, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, SplitKSerial, -+ Operator> { -+ /// Define the threadblock-scoped matrix multiply-accumulate (A x AT) -+ using Mma = typename cutlass::gemm::threadblock::DefaultMma< -+ ElementA, LayoutA, -+ kAlignmentA, -+ ElementA, typename layout::LayoutTranspose::type, -+ kAlignmentA, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ Operator>::ThreadblockMma; -+ -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOpBlas3< -+ ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, BlasMode::kSymmetric>::Epilogue; -+ -+ /// Define the kernel-level Rank2 operator. -+ using RankKkernel = kernel::RankKUniversal; -+}; -+//////////////////////////////////////////////////////////////////////////////// -+ -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_k_complex.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_k_complex.h -new file mode 100644 -index 0000000..d7569a9 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_k_complex.h -@@ -0,0 +1,429 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level RankK definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/arch/wmma.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/rank_k_universal.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/gemm/threadblock/default_mma.h" -+#include "cutlass/gemm/threadblock/default_multistage_mma_complex.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op_blas3.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h" -+#endif //CUTLASS_ARCH_WMMA_ENABLED -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Blas3 computation mode -+ BlasMode BlasMode_ = BlasMode::kSymmetric> -+struct DefaultRankKComplex; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+namespace detail { -+ -+template < -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Complex elementwise transformation -+ ComplexTransform TransformA, -+ /// Blas3 computation mode (symmetric/hermitian) -+ BlasMode BlasMode_ -+ > struct RankKTransposedComplexTransform { -+ -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformA; -+ -+}; -+ -+ // partial specializations for HERK CUBLAS_OP_N layout (ColumMajor) -+template <> -+ struct RankKTransposedComplexTransform < -+ layout::ColumnMajor, -+ ComplexTransform::kNone, -+ BlasMode::kHermitian> { -+ -+ static ComplexTransform const kTransformA = ComplexTransform::kConjugate; -+ static ComplexTransform const kTransformB = ComplexTransform::kNone; -+ -+}; -+ -+ // partial specializations for HERK CUBLAS_OP_C layout (RowMajor + Complex conjugate) -+template <> -+ struct RankKTransposedComplexTransform < -+ layout::RowMajor, -+ ComplexTransform::kConjugate, -+ BlasMode::kHermitian> { -+ -+ static ComplexTransform const kTransformA = ComplexTransform::kNone; -+ static ComplexTransform const kTransformB = ComplexTransform::kConjugate; -+ -+}; -+ -+} -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Hopper Architecture complex datatype (symmetric) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial> -+struct DefaultRankKComplex< -+ ElementA, LayoutA, ElementC, -+ layout::RowMajor, FillModeC, ElementAccumulator, arch::OpClassTensorOp, -+ arch::Sm90, ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, -+ TransformA, Operator, SplitKSerial, BlasMode::kSymmetric> { -+ -+ static BlasMode const kBlasMode = BlasMode::kSymmetric; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate (A x B^T) -+ using Mma = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex< -+ ElementA, LayoutA, -+ ElementA, typename layout::LayoutTranspose::type, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ TransformA, TransformA, Operator>::ThreadblockMma; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOpBlas3< -+ ThreadblockShape, typename Mma::Operator, 1, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, Operator, kBlasMode>::Epilogue; -+ -+ /// Define the kernel-level RankK operator. -+ using RankKkernel = kernel::RankKUniversal; -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Hopper Architecture complex datatype (hermitian) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial> -+struct DefaultRankKComplex< -+ ElementA, LayoutA, ElementC, -+ layout::RowMajor, FillModeC, ElementAccumulator, arch::OpClassTensorOp, -+ arch::Sm90, ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, -+ TransformA, Operator, SplitKSerial, BlasMode::kHermitian> { -+ -+ static BlasMode const kBlasMode = BlasMode::kHermitian; -+ -+ // Complex transform for input A and B matrices (function on input layout) -+ static ComplexTransform const kTransformA = TransformA; -+ -+ using TransposedComplexTransform = detail::RankKTransposedComplexTransform< -+ LayoutA, -+ TransformA, -+ kBlasMode>; -+ -+ // Complex transform on operandA and operandB (function of blas3 computation) -+ static ComplexTransform const kTransformOperandA = TransposedComplexTransform::kTransformA; -+ static ComplexTransform const kTransformOperandB = TransposedComplexTransform::kTransformB; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate (A x A^H) -+ using Mma = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex< -+ ElementA, LayoutA, -+ ElementA, typename layout::LayoutTranspose::type, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ kTransformOperandA, kTransformOperandB, Operator>::ThreadblockMma; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOpBlas3< -+ ThreadblockShape, typename Mma::Operator, 1, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, Operator, kBlasMode>::Epilogue; -+ -+ /// Define the kernel-level RankK operator. -+ using RankKkernel = kernel::RankKUniversal; -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Architecture complex datatype (symmetric) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial> -+struct DefaultRankKComplex< -+ ElementA, LayoutA, ElementC, -+ layout::RowMajor, FillModeC, ElementAccumulator, arch::OpClassTensorOp, -+ arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, -+ TransformA, Operator, SplitKSerial, BlasMode::kSymmetric> { -+ -+ static BlasMode const kBlasMode = BlasMode::kSymmetric; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate (A x B^T) -+ using Mma = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex< -+ ElementA, LayoutA, -+ ElementA, typename layout::LayoutTranspose::type, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ TransformA, TransformA, Operator>::ThreadblockMma; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOpBlas3< -+ ThreadblockShape, typename Mma::Operator, 1, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, Operator, kBlasMode>::Epilogue; -+ -+ /// Define the kernel-level RankK operator. -+ using RankKkernel = kernel::RankKUniversal; -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Architecture complex datatype (hermitian) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial> -+struct DefaultRankKComplex< -+ ElementA, LayoutA, ElementC, -+ layout::RowMajor, FillModeC, ElementAccumulator, arch::OpClassTensorOp, -+ arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, -+ TransformA, Operator, SplitKSerial, BlasMode::kHermitian> { -+ -+ static BlasMode const kBlasMode = BlasMode::kHermitian; -+ -+ // Complex transform for input A and B matrices (function on input layout) -+ static ComplexTransform const kTransformA = TransformA; -+ -+ using TransposedComplexTransform = detail::RankKTransposedComplexTransform< -+ LayoutA, -+ TransformA, -+ kBlasMode>; -+ -+ // Complex transform on operandA and operandB (function of blas3 computation) -+ static ComplexTransform const kTransformOperandA = TransposedComplexTransform::kTransformA; -+ static ComplexTransform const kTransformOperandB = TransposedComplexTransform::kTransformB; -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate (A x A^H) -+ using Mma = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex< -+ ElementA, LayoutA, -+ ElementA, typename layout::LayoutTranspose::type, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ kTransformOperandA, kTransformOperandB, Operator>::ThreadblockMma; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOpBlas3< -+ ThreadblockShape, typename Mma::Operator, 1, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, Operator, kBlasMode>::Epilogue; -+ -+ /// Define the kernel-level RankK operator. -+ using RankKkernel = kernel::RankKUniversal; -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_k_universal.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_k_universal.h -new file mode 100644 -index 0000000..b8ce45c ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_rank_k_universal.h -@@ -0,0 +1,305 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level Rank k definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+ -+ Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are -+ accommodated by exchanging A and B operands and assuming transposed layouts. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+ -+#include "cutlass/complex.h" -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/gemm/kernel/rank_k_universal.h" -+#include "cutlass/gemm/kernel/default_rank_k.h" -+#include "cutlass/gemm/kernel/default_rank_k_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by SYRK -+ typename Operator, -+ /// Blas3 computation mode (symmetric/hermitian) -+ BlasMode BlasMode_ = BlasMode::kSymmetric, -+ /// -+ typename Enable = void -+ > -+struct DefaultRankKUniversal; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Real-valued Rank k update kernels -+// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by Rank2k -+ typename Operator> -+struct DefaultRankKUniversal< -+ ElementA, -+ LayoutA, -+ ComplexTransform::kNone, // transform A -+ kAlignmentA, -+ ElementC, -+ LayoutC, -+ FillModeC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ SplitKSerial, -+ Operator, -+ BlasMode::kSymmetric, -+ typename std::enable_if< ! cutlass::is_complex::value>::type -+> { -+ -+ using DefaultRankKkernel = typename kernel::DefaultRankK< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementC, -+ LayoutC, -+ FillModeC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ SplitKSerial, -+ Operator, -+ BlasMode::kSymmetric -+ >::RankKkernel; -+ -+ /// Define the kernel in terms of the default kernel -+ using RankKkernel = kernel::RankKUniversal< -+ typename DefaultRankKkernel::Mma, -+ typename DefaultRankKkernel::Epilogue, -+ ThreadblockSwizzle, -+ FillModeC -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Complex-valued Rank 2K update kernels -+// -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Fill Mode for C (kLower or kUpper) -+ FillMode FillModeC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by SYRK -+ typename Operator, -+ // BlasMode -+ BlasMode kBlasMode -+ > -+ -+struct DefaultRankKUniversal< -+ ElementA, -+ LayoutA, -+ TransformA, -+ kAlignmentA, -+ ElementC, -+ LayoutC, -+ FillModeC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ SplitKSerial, -+ Operator, -+ kBlasMode, -+ typename std::enable_if::value>::type -+> { -+ -+ using DefaultRankKkernel = typename kernel::DefaultRankKComplex< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ FillModeC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ TransformA, -+ Operator, -+ SplitKSerial, -+ kBlasMode -+ >::RankKkernel; -+ -+ /// Define the kernel in terms of the default kernel -+ using RankKkernel = kernel::RankKUniversal< -+ typename DefaultRankKkernel::Mma, -+ typename DefaultRankKkernel::Epilogue, -+ ThreadblockSwizzle, -+ FillModeC -+ >; -+}; -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_symm.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_symm.h -new file mode 100755 -index 0000000..1faf25d ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_symm.h -@@ -0,0 +1,321 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level SYMM/HEMM definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/arch/wmma.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/symm_universal.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/gemm/threadblock/default_trmm.h" -+#include "cutlass/gemm/threadblock/default_mma.h" -+#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h" -+#endif //CUTLASS_ARCH_WMMA_ENABLED -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Side Mode for A (kLeft or kRight) -+ SideMode kSideModeA, -+ /// Fill Mode for A (kLower or kUpper) -+ FillMode kFillModeA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Blas3 computation mode -+ BlasMode BlasMode_ = BlasMode::kSymmetric> -+struct DefaultSymm; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Hopper Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Side Mode for A (kLeft or kRight) -+ SideMode kSideModeA, -+ /// Fill Mode for A (kLower or kUpper) -+ FillMode kFillModeA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultSymm< -+ ElementA, LayoutA, kSideModeA, kFillModeA, kAlignmentA, -+ ElementB, LayoutB, kAlignmentB, -+ ElementC,layout::RowMajor, -+ ElementAccumulator, arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, SplitKSerial, -+ Operator> { -+ -+ /// Define the threadblock-scoped triagular matrix multiply-accumulate -+ /// TRMM - with diagonal: alpha * A * B or alpha * B * A -+ static const DiagType kDiagTypeMma1 = DiagType::kNonUnit; -+ using Mma1 = typename cutlass::gemm::threadblock::DefaultTrmm< -+ ElementA, LayoutA, kAlignmentA, -+ ElementB, LayoutB, kAlignmentB, -+ kSideModeA, kFillModeA, kDiagTypeMma1, -+ ElementAccumulator, layout::RowMajor, -+ arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, -+ Stages, Operator>::ThreadblockMma; -+ -+ /// Define the threadblock-scoped triagular matrix multiply-accumulate -+ /// TRMM - withOUT diagonal: alpha * AT * B or alpha * B * AT -+ static const DiagType kDiagTypeMma2 = DiagType::kZero; -+ using LayoutAMma2 = typename platform::conditional< -+ (kSideModeA == SideMode::kLeft), -+ typename layout::LayoutTranspose::type, -+ LayoutA -+ >::type; -+ using LayoutBMma2 = typename platform::conditional< -+ (kSideModeA == SideMode::kLeft), -+ LayoutB, -+ typename layout::LayoutTranspose::type -+ >::type; -+ using Mma2 = typename cutlass::gemm::threadblock::DefaultTrmm< -+ ElementA, LayoutAMma2, kAlignmentA, -+ ElementB, LayoutBMma2, kAlignmentB, -+ kSideModeA, InvertFillMode::mode, kDiagTypeMma2, -+ ElementAccumulator, layout::RowMajor, -+ arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, -+ Stages, Operator>::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, typename Mma1::Operator, kPartitionsK, EpilogueOutputOp, -+ EpilogueOutputOp::kCount>::Epilogue; -+ -+ /// Define the kernel-level SYMM/HEMM operator. -+ using SymmKernel = kernel::SymmUniversal; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Side Mode for A (kLeft or kRight) -+ SideMode kSideModeA, -+ /// Fill Mode for A (kLower or kUpper) -+ FillMode kFillModeA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultSymm< -+ ElementA, LayoutA, kSideModeA, kFillModeA, kAlignmentA, -+ ElementB, LayoutB, kAlignmentB, -+ ElementC,layout::RowMajor, -+ ElementAccumulator, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, SplitKSerial, -+ Operator> { -+ -+ /// Define the threadblock-scoped triagular matrix multiply-accumulate -+ /// TRMM - with diagonal: alpha * A * B or alpha * B * A -+ static const DiagType kDiagTypeMma1 = DiagType::kNonUnit; -+ using Mma1 = typename cutlass::gemm::threadblock::DefaultTrmm< -+ ElementA, LayoutA, kAlignmentA, -+ ElementB, LayoutB, kAlignmentB, -+ kSideModeA, kFillModeA, kDiagTypeMma1, -+ ElementAccumulator, layout::RowMajor, -+ arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, -+ Stages, Operator>::ThreadblockMma; -+ -+ /// Define the threadblock-scoped triagular matrix multiply-accumulate -+ /// TRMM - withOUT diagonal: alpha * AT * B or alpha * B * AT -+ static const DiagType kDiagTypeMma2 = DiagType::kZero; -+ using LayoutAMma2 = typename platform::conditional< -+ (kSideModeA == SideMode::kLeft), -+ typename layout::LayoutTranspose::type, -+ LayoutA -+ >::type; -+ using LayoutBMma2 = typename platform::conditional< -+ (kSideModeA == SideMode::kLeft), -+ LayoutB, -+ typename layout::LayoutTranspose::type -+ >::type; -+ using Mma2 = typename cutlass::gemm::threadblock::DefaultTrmm< -+ ElementA, LayoutAMma2, kAlignmentA, -+ ElementB, LayoutBMma2, kAlignmentB, -+ kSideModeA, InvertFillMode::mode, kDiagTypeMma2, -+ ElementAccumulator, layout::RowMajor, -+ arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, -+ Stages, Operator>::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, typename Mma1::Operator, kPartitionsK, EpilogueOutputOp, -+ EpilogueOutputOp::kCount>::Epilogue; -+ -+ /// Define the kernel-level SYMM/HEMM operator. -+ using SymmKernel = kernel::SymmUniversal; -+}; -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_symm_complex.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_symm_complex.h -new file mode 100755 -index 0000000..09cb7e5 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_symm_complex.h -@@ -0,0 +1,508 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level SYMM/HEMM definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/arch/wmma.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/symm_universal.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/gemm/threadblock/default_mma.h" -+#include "cutlass/gemm/threadblock/default_multistage_trmm_complex.h" -+#include "cutlass/gemm/threadblock/default_multistage_mma_complex.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h" -+#endif //CUTLASS_ARCH_WMMA_ENABLED -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Side Mode for A (kLeft or kRight) -+ SideMode kSideModeA, -+ /// Fill Mode for A (kLower or kUpper) -+ FillMode kFillModeA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Blas3 computation mode -+ BlasMode BlasMode_ = BlasMode::kSymmetric> -+struct DefaultSymmComplex; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Hopper Architecture complex datatype (symmetric) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Side Mode for A (kLeft or kRight) -+ SideMode kSideModeA, -+ /// Fill Mode for A (kLower or kUpper) -+ FillMode kFillModeA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial> -+struct DefaultSymmComplex< -+ ElementA, LayoutA, kSideModeA, kFillModeA, ElementB, LayoutB, ElementC, -+ layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp, -+ arch::Sm90, ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, -+ Operator, SplitKSerial, BlasMode::kSymmetric> { -+ -+ static BlasMode const kBlasMode = BlasMode::kSymmetric; -+ // Complex Transform don't appply to A or B for SYMM -+ static ComplexTransform const TransformA = ComplexTransform::kNone; -+ static ComplexTransform const TransformB = ComplexTransform::kNone; -+ -+ /// Define the threadblock-scoped triagular matrix multiply-accumulate -+ /// TRMM - with diagonal: alpha * A * B or alpha * B * A -+ static const DiagType kDiagTypeMma1 = DiagType::kNonUnit; -+ using Mma1 = typename cutlass::gemm::threadblock::DefaultMultistageTrmmComplex< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ kSideModeA, kFillModeA, kDiagTypeMma1, -+ ElementAccumulator, layout::RowMajor, -+ arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, -+ Stages, TransformA, TransformB, Operator>::ThreadblockMma; -+ -+ /// Define the threadblock-scoped triagular matrix multiply-accumulate -+ /// TRMM - withOUT diagonal: alpha * AT * B or alpha * B * AT -+ static const DiagType kDiagTypeMma2 = DiagType::kZero; -+ using LayoutAMma2 = typename platform::conditional< -+ (kSideModeA == SideMode::kLeft), -+ typename layout::LayoutTranspose::type, -+ LayoutA -+ >::type; -+ using LayoutBMma2 = typename platform::conditional< -+ (kSideModeA == SideMode::kLeft), -+ LayoutB, -+ typename layout::LayoutTranspose::type -+ >::type; -+ using Mma2 = typename cutlass::gemm::threadblock::DefaultMultistageTrmmComplex< -+ ElementA, LayoutAMma2, -+ ElementB, LayoutBMma2, -+ kSideModeA, InvertFillMode::mode, kDiagTypeMma2, -+ ElementAccumulator, layout::RowMajor, -+ arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, -+ Stages, TransformA, TransformB, Operator>::ThreadblockMma; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOp< -+ ThreadblockShape, typename Mma1::Operator, 1, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, Operator>::Epilogue; -+ -+ /// Define the kernel-level Symm operator. -+ using SymmKernel = kernel::SymmUniversal; -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Hopper Architecture complex datatype (hermitian) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Side Mode for A (kLeft or kRight) -+ SideMode kSideModeA, -+ /// Fill Mode for A (kLower or kUpper) -+ FillMode kFillModeA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial> -+struct DefaultSymmComplex< -+ ElementA, LayoutA, kSideModeA, kFillModeA, ElementB, LayoutB, ElementC, -+ layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp, -+ arch::Sm90, ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, -+ Operator, SplitKSerial, BlasMode::kHermitian> { -+ -+ static BlasMode const kBlasMode = BlasMode::kHermitian; -+ -+ -+ /// Define the threadblock-scoped triagular matrix multiply-accumulate -+ /// TRMM - with diagonal: alpha * A * B or alpha * B * A -+ static const DiagType kDiagTypeMma1 = DiagType::kNonUnit; -+ static ComplexTransform const TransformAMma1 = ComplexTransform::kNone; -+ static ComplexTransform const TransformBMma1 = ComplexTransform::kNone; -+ using Mma1 = typename cutlass::gemm::threadblock::DefaultMultistageTrmmComplex< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ kSideModeA, kFillModeA, kDiagTypeMma1, -+ ElementAccumulator, layout::RowMajor, -+ arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, -+ Stages, TransformAMma1, TransformBMma1, Operator, BlasMode::kHermitian>::ThreadblockMma; -+ -+ /// Define the threadblock-scoped triagular matrix multiply-accumulate -+ /// TRMM - withOUT diagonal - with conjugate transpose: alpha * AT * B or alpha * B * AT -+ static const DiagType kDiagTypeMma2 = DiagType::kZero; -+ using LayoutAMma2 = typename platform::conditional< -+ (kSideModeA == SideMode::kLeft), -+ typename layout::LayoutTranspose::type, -+ LayoutA -+ >::type; -+ using LayoutBMma2 = typename platform::conditional< -+ (kSideModeA == SideMode::kLeft), -+ LayoutB, -+ typename layout::LayoutTranspose::type -+ >::type; -+ static ComplexTransform const TransformAMma2 = (kSideModeA == SideMode::kLeft) ? -+ ComplexTransform::kConjugate : ComplexTransform::kNone; -+ static ComplexTransform const TransformBMma2 = (kSideModeA == SideMode::kLeft) ? -+ ComplexTransform::kNone : ComplexTransform::kConjugate; -+ -+ using Mma2 = typename cutlass::gemm::threadblock::DefaultMultistageTrmmComplex< -+ ElementA, LayoutAMma2, -+ ElementB, LayoutBMma2, -+ kSideModeA, InvertFillMode::mode, kDiagTypeMma2, -+ ElementAccumulator, layout::RowMajor, -+ arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, -+ Stages, TransformAMma2, TransformBMma2, Operator>::ThreadblockMma; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOp< -+ ThreadblockShape, typename Mma1::Operator, 1, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, Operator>::Epilogue; -+ -+ /// Define the kernel-level Symm operator. -+ using SymmKernel = kernel::SymmUniversal; -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Architecture complex datatype (symmetric) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Side Mode for A (kLeft or kRight) -+ SideMode kSideModeA, -+ /// Fill Mode for A (kLower or kUpper) -+ FillMode kFillModeA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial> -+struct DefaultSymmComplex< -+ ElementA, LayoutA, kSideModeA, kFillModeA, ElementB, LayoutB, ElementC, -+ layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp, -+ arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, -+ Operator, SplitKSerial, BlasMode::kSymmetric> { -+ -+ static BlasMode const kBlasMode = BlasMode::kSymmetric; -+ // Complex Transform don't appply to A or B for SYMM -+ static ComplexTransform const TransformA = ComplexTransform::kNone; -+ static ComplexTransform const TransformB = ComplexTransform::kNone; -+ -+ /// Define the threadblock-scoped triagular matrix multiply-accumulate -+ /// TRMM - with diagonal: alpha * A * B or alpha * B * A -+ static const DiagType kDiagTypeMma1 = DiagType::kNonUnit; -+ using Mma1 = typename cutlass::gemm::threadblock::DefaultMultistageTrmmComplex< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ kSideModeA, kFillModeA, kDiagTypeMma1, -+ ElementAccumulator, layout::RowMajor, -+ arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, -+ Stages, TransformA, TransformB, Operator>::ThreadblockMma; -+ -+ /// Define the threadblock-scoped triagular matrix multiply-accumulate -+ /// TRMM - withOUT diagonal: alpha * AT * B or alpha * B * AT -+ static const DiagType kDiagTypeMma2 = DiagType::kZero; -+ using LayoutAMma2 = typename platform::conditional< -+ (kSideModeA == SideMode::kLeft), -+ typename layout::LayoutTranspose::type, -+ LayoutA -+ >::type; -+ using LayoutBMma2 = typename platform::conditional< -+ (kSideModeA == SideMode::kLeft), -+ LayoutB, -+ typename layout::LayoutTranspose::type -+ >::type; -+ using Mma2 = typename cutlass::gemm::threadblock::DefaultMultistageTrmmComplex< -+ ElementA, LayoutAMma2, -+ ElementB, LayoutBMma2, -+ kSideModeA, InvertFillMode::mode, kDiagTypeMma2, -+ ElementAccumulator, layout::RowMajor, -+ arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, -+ Stages, TransformA, TransformB, Operator>::ThreadblockMma; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOp< -+ ThreadblockShape, typename Mma1::Operator, 1, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, Operator>::Epilogue; -+ -+ /// Define the kernel-level Symm operator. -+ using SymmKernel = kernel::SymmUniversal; -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Architecture complex datatype (hermitian) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Side Mode for A (kLeft or kRight) -+ SideMode kSideModeA, -+ /// Fill Mode for A (kLower or kUpper) -+ FillMode kFillModeA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial> -+struct DefaultSymmComplex< -+ ElementA, LayoutA, kSideModeA, kFillModeA, ElementB, LayoutB, ElementC, -+ layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp, -+ arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, -+ Operator, SplitKSerial, BlasMode::kHermitian> { -+ -+ static BlasMode const kBlasMode = BlasMode::kHermitian; -+ -+ -+ /// Define the threadblock-scoped triagular matrix multiply-accumulate -+ /// TRMM - with diagonal: alpha * A * B or alpha * B * A -+ static const DiagType kDiagTypeMma1 = DiagType::kNonUnit; -+ static ComplexTransform const TransformAMma1 = ComplexTransform::kNone; -+ static ComplexTransform const TransformBMma1 = ComplexTransform::kNone; -+ using Mma1 = typename cutlass::gemm::threadblock::DefaultMultistageTrmmComplex< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ kSideModeA, kFillModeA, kDiagTypeMma1, -+ ElementAccumulator, layout::RowMajor, -+ arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, -+ Stages, TransformAMma1, TransformBMma1, Operator, BlasMode::kHermitian>::ThreadblockMma; -+ -+ /// Define the threadblock-scoped triagular matrix multiply-accumulate -+ /// TRMM - withOUT diagonal - with conjugate transpose: alpha * AT * B or alpha * B * AT -+ static const DiagType kDiagTypeMma2 = DiagType::kZero; -+ using LayoutAMma2 = typename platform::conditional< -+ (kSideModeA == SideMode::kLeft), -+ typename layout::LayoutTranspose::type, -+ LayoutA -+ >::type; -+ using LayoutBMma2 = typename platform::conditional< -+ (kSideModeA == SideMode::kLeft), -+ LayoutB, -+ typename layout::LayoutTranspose::type -+ >::type; -+ static ComplexTransform const TransformAMma2 = (kSideModeA == SideMode::kLeft) ? -+ ComplexTransform::kConjugate : ComplexTransform::kNone; -+ static ComplexTransform const TransformBMma2 = (kSideModeA == SideMode::kLeft) ? -+ ComplexTransform::kNone : ComplexTransform::kConjugate; -+ -+ using Mma2 = typename cutlass::gemm::threadblock::DefaultMultistageTrmmComplex< -+ ElementA, LayoutAMma2, -+ ElementB, LayoutBMma2, -+ kSideModeA, InvertFillMode::mode, kDiagTypeMma2, -+ ElementAccumulator, layout::RowMajor, -+ arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, -+ Stages, TransformAMma2, TransformBMma2, Operator>::ThreadblockMma; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOp< -+ ThreadblockShape, typename Mma1::Operator, 1, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, Operator>::Epilogue; -+ -+ /// Define the kernel-level Symm operator. -+ using SymmKernel = kernel::SymmUniversal; -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_symm_universal.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_symm_universal.h -new file mode 100755 -index 0000000..adcf1ff ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_symm_universal.h -@@ -0,0 +1,342 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level SYMM/HEMM definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+ -+ Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are -+ accommodated by exchanging A and B operands and assuming transposed layouts. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+ -+#include "cutlass/complex.h" -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/gemm/kernel/symm_universal.h" -+#include "cutlass/gemm/kernel/default_symm.h" -+#include "cutlass/gemm/kernel/default_symm_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Side Mode for A (kLeft or kRight) -+ SideMode SideModeA, -+ /// Fill Mode for A (kLower or kUpper) -+ FillMode FillModeA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by SYRK -+ typename Operator, -+ /// Blas3 computation mode (symmetric/hermitian) -+ BlasMode BlasMode_ = BlasMode::kSymmetric, -+ /// -+ typename Enable = void -+ > -+struct DefaultSymmUniversal; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Real-valued SYMM/HEMM update kernels -+// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Side Mode for A (kLeft or kRight) -+ SideMode SideModeA, -+ /// Fill Mode for A (kLower or kUpper) -+ FillMode FillModeA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by SYMM/HEMM -+ typename Operator> -+struct DefaultSymmUniversal< -+ ElementA, -+ LayoutA, -+ SideModeA, -+ FillModeA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ SplitKSerial, -+ Operator, -+ BlasMode::kSymmetric, -+ typename std::enable_if< ! cutlass::is_complex::value>::type -+> { -+ -+ using DefaultSymmkernel = typename kernel::DefaultSymm< -+ ElementA, -+ LayoutA, -+ SideModeA, -+ FillModeA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ SplitKSerial, -+ Operator, -+ BlasMode::kSymmetric -+ >::SymmKernel; -+ -+ /// Define the kernel in terms of the default kernel -+ using SymmKernel = kernel::SymmUniversal< -+ typename DefaultSymmkernel::Mma1, -+ typename DefaultSymmkernel::Mma2, -+ typename DefaultSymmkernel::Epilogue, -+ ThreadblockSwizzle, -+ SideModeA, -+ FillModeA -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Complex-valued SYMM/HEMM update kernels -+// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Side Mode for A (kLeft or kRight) -+ SideMode SideModeA, -+ /// Fill Mode for A (kLower or kUpper) -+ FillMode FillModeA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by SYRK -+ typename Operator, -+ // BlasMode -+ BlasMode kBlasMode -+ > -+ -+struct DefaultSymmUniversal< -+ ElementA, -+ LayoutA, -+ SideModeA, -+ FillModeA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ SplitKSerial, -+ Operator, -+ kBlasMode, -+ typename std::enable_if::value>::type -+> { -+ -+ using DefaultSymmkernel = typename kernel::DefaultSymmComplex< -+ ElementA, -+ LayoutA, -+ SideModeA, -+ FillModeA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ Operator, -+ SplitKSerial, -+ kBlasMode -+ >::SymmKernel; -+ -+ /// Define the kernel in terms of the default kernel -+ using SymmKernel = kernel::SymmUniversal< -+ typename DefaultSymmkernel::Mma1, -+ typename DefaultSymmkernel::Mma2, -+ typename DefaultSymmkernel::Epilogue, -+ ThreadblockSwizzle, -+ SideModeA, -+ FillModeA -+ >; -+}; -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_trmm.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_trmm.h -new file mode 100644 -index 0000000..cf2896a ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_trmm.h -@@ -0,0 +1,269 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+// -+/*! \file -+ \brief -+ Default kernel-level TRMM definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/arch/wmma.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/trmm_universal.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/gemm/threadblock/default_mma.h" -+#include "cutlass/gemm/threadblock/default_trmm.h" -+#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h" -+#endif //CUTLASS_ARCH_WMMA_ENABLED -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Side Mode for the kernel -+ SideMode SideMode_, -+ /// Fill Mode for the triangular matrix -+ FillMode FillMode_, -+ /// Diag Type for the triangular matrix -+ DiagType DiagType_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultTrmm; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Hopper Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Side Mode for the kernel -+ SideMode kSideMode, -+ /// Fill Mode for the triangular matrix -+ FillMode kFillMode, -+ /// Diag Type for the triangular matrix -+ DiagType kDiagType, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultTrmm { -+ -+ /// Define the threadblock-scoped triagular matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultTrmm< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, -+ kSideMode, kFillMode, kDiagType, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ Operator>::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, -+ EpilogueOutputOp::kCount>::Epilogue; -+ -+ /// Define the kernel-level TRMM operator. -+ using TrmmKernel = kernel::TrmmUniversal; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentB, -+ /// Side Mode for the kernel -+ SideMode kSideMode, -+ /// Fill Mode for the triangular matrix -+ FillMode kFillMode, -+ /// Diag Type for the triangular matrix -+ DiagType kDiagType, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultTrmm { -+ -+ /// Define the threadblock-scoped triagular matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultTrmm< -+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, -+ kSideMode, kFillMode, kDiagType, -+ ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, Stages, -+ Operator>::ThreadblockMma; -+ -+ static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, -+ EpilogueOutputOp::kCount>::Epilogue; -+ -+ /// Define the kernel-level TRMM operator. -+ using TrmmKernel = kernel::TrmmUniversal; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_trmm_complex.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_trmm_complex.h -new file mode 100644 -index 0000000..4909396 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_trmm_complex.h -@@ -0,0 +1,265 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level TRMM definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+ -+ Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are -+ accommodated by exchanging A and B operands and assuming transposed layouts. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+ -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/epilogue/threadblock/epilogue.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/trmm_universal.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h" -+#include "cutlass/gemm/threadblock/default_mma.h" -+#include "cutlass/gemm/threadblock/default_multistage_trmm_complex.h" -+#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" -+ -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Side Mode for the kernel -+ SideMode SideMode_, -+ /// Fill Mode for the triangular matrix -+ FillMode FillMode_, -+ /// Diag Type for the triangular matrix -+ DiagType DiagType_, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Multiply-add operator -+ // (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the epilogue -+ bool SplitKSerial -+> -+struct DefaultTrmmComplex; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Hopper Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Side Mode for the kernel -+ SideMode kSideMode, -+ /// Fill Mode for the triangular matrix -+ FillMode kFillMode, -+ /// Diag Type for the triangular matrix -+ DiagType kDiagType, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Multiply-add operator -+ // (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the epilogue -+ bool SplitKSerial -+ > -+struct DefaultTrmmComplex< -+ ElementA, LayoutA, ElementB, LayoutB, -+ kSideMode, kFillMode, kDiagType, -+ ElementC, layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp, -+ arch::Sm90, ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, Operator, SplitKSerial> { -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultMultistageTrmmComplex< -+ ElementA, LayoutA, ElementB, LayoutB, -+ kSideMode, kFillMode, kDiagType, -+ ElementAccumulator,layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, ThreadblockShape, -+ WarpShape, InstructionShape, Stages, TransformA, TransformB, Operator>::ThreadblockMma; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOp< -+ ThreadblockShape, typename Mma::Operator, 1, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, Operator>::Epilogue; -+ -+ /// Define the kernel-level TRMM operator. -+ using TrmmKernel = kernel::TrmmUniversal; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Ampere Architecture -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Side Mode for the kernel -+ SideMode kSideMode, -+ /// Fill Mode for the triangular matrix -+ FillMode kFillMode, -+ /// Diag Type for the triangular matrix -+ DiagType kDiagType, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Multiply-add operator -+ // (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator, -+ /// If true, kernel is configured to support serial reduction in the epilogue -+ bool SplitKSerial -+ > -+struct DefaultTrmmComplex< -+ ElementA, LayoutA, ElementB, LayoutB, -+ kSideMode, kFillMode, kDiagType, -+ ElementC, layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp, -+ arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, Operator, SplitKSerial> { -+ -+ /// Define the threadblock-scoped matrix multiply-accumulate -+ using Mma = typename cutlass::gemm::threadblock::DefaultMultistageTrmmComplex< -+ ElementA, LayoutA, ElementB, LayoutB, -+ kSideMode, kFillMode, kDiagType, -+ ElementAccumulator,layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, ThreadblockShape, -+ WarpShape, InstructionShape, Stages, TransformA, TransformB, Operator>::ThreadblockMma; -+ -+ /// Define the epilogue -+ using Epilogue = -+ typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOp< -+ ThreadblockShape, typename Mma::Operator, 1, EpilogueOutputOp, -+ EpilogueOutputOp::kCount, Operator>::Epilogue; -+ -+ /// Define the kernel-level TRMM operator. -+ using TrmmKernel = kernel::TrmmUniversal; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/default_trmm_universal.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_trmm_universal.h -new file mode 100644 -index 0000000..50e8d8d ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/default_trmm_universal.h -@@ -0,0 +1,359 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ Default kernel-level TRMM definitions combine threadblock-scoped matrix multiply-add with -+ the appropriate threadblock-scoped epilogue. -+ -+ Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are -+ accommodated by exchanging A and B operands and assuming transposed layouts. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+ -+#include "cutlass/complex.h" -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/gemm/kernel/trmm_universal.h" -+#include "cutlass/gemm/kernel/default_trmm.h" -+#include "cutlass/gemm/kernel/default_trmm_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Side Mode for the kernel -+ SideMode kSideMode, -+ /// Fill Mode for the triangular matrix -+ FillMode kFillMode, -+ /// Diag Type for the triangular matrix -+ DiagType kDiagType, -+ /// Element type for C and D matrix operands -+ typename ElementC_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by TRMM -+ typename Operator, -+ /// -+ typename Enable = void -+ > -+struct DefaultTrmmUniversal; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Real-valued TRMM kernels -+// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Side Mode for the kernel -+ SideMode kSideMode, -+ /// Fill Mode for the triangular matrix -+ FillMode kFillMode, -+ /// Diag Type for the triangular matrix -+ DiagType kDiagType, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by TRMM -+ typename Operator> -+struct DefaultTrmmUniversal< -+ ElementA, -+ LayoutA, -+ ComplexTransform::kNone, // transform A -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ ComplexTransform::kNone, // transform B -+ kAlignmentB, -+ kSideMode, -+ kFillMode, -+ kDiagType, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ SplitKSerial, -+ Operator, -+ typename std::enable_if< ! cutlass::is_complex::value>::type -+> { -+ -+ using DefaultTrmmKernel = typename kernel::DefaultTrmm< -+ ElementA, -+ LayoutA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ kAlignmentB, -+ kSideMode, -+ kFillMode, -+ kDiagType, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ SplitKSerial, -+ Operator -+ >::TrmmKernel; -+ -+ /// Define the kernel in terms of the default kernel -+ using TrmmKernel = kernel::TrmmUniversal< -+ typename DefaultTrmmKernel::Mma, -+ typename DefaultTrmmKernel::Epilogue, -+ ThreadblockSwizzle, -+ kSideMode, -+ kFillMode, -+ kDiagType -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Complex-valued TRMM kernels -+// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Complex elementwise transformation on A operand -+ ComplexTransform TransformA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Complex elementwise transformation on B operand -+ ComplexTransform TransformB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Side Mode for the kernel -+ SideMode kSideMode, -+ /// Fill Mode for the triangular matrix -+ FillMode kFillMode, -+ /// Diag Type for the triangular matrix -+ DiagType kDiagType, -+ /// Element type for C and D matrix operands -+ typename ElementC, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Epilogue output operator -+ typename EpilogueOutputOp, -+ /// Threadblock-level swizzling operator -+ typename ThreadblockSwizzle, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// If true, kernel is configured to support serial reduction in the -+ /// epilogue -+ bool SplitKSerial, -+ /// Operation performed by TRMM -+ typename Operator -+ > -+struct DefaultTrmmUniversal< -+ ElementA, -+ LayoutA, -+ TransformA, -+ kAlignmentA, -+ ElementB, -+ LayoutB, -+ TransformB, -+ kAlignmentB, -+ kSideMode, -+ kFillMode, -+ kDiagType, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ SplitKSerial, -+ Operator, -+ typename std::enable_if::value>::type -+> { -+ -+ using DefaultTrmmKernel = typename kernel::DefaultTrmmComplex< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ kSideMode, -+ kFillMode, -+ kDiagType, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ OperatorClass, -+ ArchTag, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOutputOp, -+ ThreadblockSwizzle, -+ Stages, -+ TransformA, -+ TransformB, -+ Operator, -+ SplitKSerial -+ >::TrmmKernel; -+ -+ /// Define the kernel in terms of the default kernel -+ using TrmmKernel = kernel::TrmmUniversal< -+ typename DefaultTrmmKernel::Mma, -+ typename DefaultTrmmKernel::Epilogue, -+ ThreadblockSwizzle, -+ kSideMode, -+ kFillMode, -+ kDiagType -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/ell_gemm.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/ell_gemm.h -new file mode 100644 -index 0000000..88a1bd3 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/ell_gemm.h -@@ -0,0 +1,830 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Template for a Block-Ell sparse gemm kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/semaphore.h" -+#include "cutlass/arch/arch.h" -+ -+#include "cutlass/transform/threadblock/ell_iterator.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ bool SplitKSerial, ///! If true, code supporting split-K via serial reduction is enabled. -+ bool IsASparse ///! If true, A is sparse matrix -+> -+struct EllGemm { -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using OutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static bool const kSplitKSerial = SplitKSerial; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ /// Parameters structure -+ struct Params { -+ cutlass::gemm::GemmCoord problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ int swizzle_log_tile; -+ typename Mma::IteratorA::Params params_A; -+ typename Mma::IteratorA::TensorRef ref_A; -+ typename Mma::IteratorB::Params params_B; -+ typename Mma::IteratorB::TensorRef ref_B; -+ typename Epilogue::OutputTileIterator::Params params_C; -+ typename Epilogue::OutputTileIterator::TensorRef ref_C; -+ typename Epilogue::OutputTileIterator::Params params_D; -+ typename Epilogue::OutputTileIterator::TensorRef ref_D; -+ typename OutputOp::Params output_op; -+ int *semaphore; -+ int gemm_k_iterations; -+ int gemm_k_size; -+ const int* ell_idx; -+ int ell_ncol; -+ int ell_blocksize; -+ int ell_base_idx; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): swizzle_log_tile(0), semaphore(0), gemm_k_iterations(0), gemm_k_size(0) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ cutlass::gemm::GemmCoord const & problem_size, -+ cutlass::gemm::GemmCoord const & grid_tiled_shape, -+ typename Mma::IteratorA::TensorRef ref_A, -+ typename Mma::IteratorB::TensorRef ref_B, -+ typename Epilogue::OutputTileIterator::TensorRef ref_C, -+ typename Epilogue::OutputTileIterator::TensorRef ref_D, -+ const int* ell_idx, -+ int ell_ncol, -+ int ell_blocksize, -+ int ell_base_idx, -+ typename OutputOp::Params output_op = typename OutputOp::Params(), -+ int *workspace = nullptr -+ ): -+ problem_size(problem_size), -+ grid_tiled_shape(grid_tiled_shape), -+ swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), -+ params_A(ref_A.layout()), -+ ref_A(ref_A), -+ params_B(ref_B.layout()), -+ ref_B(ref_B), -+ params_C(ref_C.layout()), -+ ref_C(ref_C), -+ params_D(ref_D.layout()), -+ ref_D(ref_D), -+ output_op(output_op), -+ ell_idx(ell_idx), -+ ell_ncol(ell_ncol), -+ ell_blocksize(ell_blocksize), -+ ell_base_idx(ell_base_idx) -+ { -+ -+ int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); -+ -+ gemm_k_size = gemm_k_iterations * Mma::Shape::kK; -+ -+ semaphore = workspace; -+ } -+ }; -+ -+ /// Shared memory storage structure -+ struct SharedStorage { -+ union{ -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ typename cutlass::transform::threadblock::ell::SharedStorage ell; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ EllGemm() { } -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement( -+ cutlass::gemm::GemmCoord const & problem_size, -+ typename Mma::IteratorA::TensorRef ref_A, -+ typename Mma::IteratorB::TensorRef ref_B, -+ typename Epilogue::OutputTileIterator::TensorRef ref_C, -+ typename Epilogue::OutputTileIterator::TensorRef ref_D) { -+ -+ static int const kAlignmentA = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 64 -+ : Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 64 -+ : Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ if (!TensorRef_aligned(ref_A, kAlignmentA)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_B, kAlignmentB)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_C, kAlignmentC)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_D, kAlignmentC)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) || -+ (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) || -+ (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) { -+ -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ int tile_in_ell_block = (params.ell_blocksize + Mma::Shape::kM - 1 ) / Mma::Shape::kM; -+ int ell_block_offset_m = threadblock_tile_offset.m() / tile_in_ell_block; -+ int tile_offset_m = threadblock_tile_offset.m() % tile_in_ell_block; -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); -+ int lane_idx = threadIdx.x % 32; -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // skip computation if matrix is 0 -+ if (params.ell_ncol > 0) { -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ ell_block_offset_m * params.ell_blocksize -+ + tile_offset_m * Mma::Shape::kM, -+ threadblock_tile_offset.k() * params.gemm_k_size -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ threadblock_tile_offset.k() * params.gemm_k_size, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ }; -+ -+ int ell_idx_start = -+ (threadblock_tile_offset.m() / tile_in_ell_block) * -+ (params.ell_ncol / params.ell_blocksize); -+ const int* ell_idx_ptr = &(params.ell_idx[ell_idx_start]); -+ -+ // Problem size is a function of threadblock index in the K dimension -+ int problem_size_k = min( -+ params.problem_size.k(), -+ (threadblock_tile_offset.k() + 1) * params.gemm_k_size); -+ problem_size_k = min(problem_size_k, params.ell_ncol); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = -+ (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.params_A, -+ params.ref_A.data(), -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B( -+ params.params_B, -+ params.ref_B.data(), -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_B); -+ -+ // Define coef for ELL index depending on LayoutB -+ int ell_stride = iterator_B.get_stride(); -+ -+ typename cutlass::transform::threadblock::ell::Iterator ell_iterator( -+ shared_storage.ell, -+ ell_idx_ptr, -+ params.ell_blocksize, -+ params.ell_base_idx, -+ Mma::Shape::kK, -+ problem_size_k, -+ ell_stride, -+ thread_idx -+ ); -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ if (!kSplitKSerial || gemm_k_iterations > 0) { -+ // check if index computations can be skipped -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ constexpr bool is_double = (sizeof(Mma::IteratorA::Element) == 8); -+ constexpr bool is_multiple_alignment = -+ (kAlignmentA > 1) && (kAlignmentB > 1) && (kAlignmentC > 1); -+ const bool is_specialized_blocksize = -+ ((params.ell_blocksize) & (params.ell_blocksize-1)) == 0 -+ && params.ell_blocksize >= Mma::Shape::kK; -+ // Compute threadblock-scoped matrix multiply-add -+ if ((is_double || is_multiple_alignment) && is_specialized_blocksize) { -+ mma.operator()( -+ gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, ell_iterator); -+ } -+ else { -+ mma.operator()( -+ gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, ell_iterator); -+ } -+ } -+ } // if (params.ell_ncols > 0) -+ -+ // -+ // Epilogue -+ // -+ -+ OutputOp output_op(params.output_op); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ ell_block_offset_m = threadblock_tile_offset.m() / tile_in_ell_block; -+ tile_offset_m = threadblock_tile_offset.m() % tile_in_ell_block; -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ ell_block_offset_m * params.ell_blocksize -+ + tile_offset_m * Mma::Shape::kM, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ ); -+ -+ //avoid out of bounds -+ MatrixCoord threadblock_extent( -+ min(params.problem_size.m(), -+ ell_block_offset_m * params.ell_blocksize -+ + min((tile_offset_m + 1) * Mma::Shape::kM, params.ell_blocksize)), -+ min(params.problem_size.n(), -+ (threadblock_tile_offset.n()+1) * Mma::Shape::kN) -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ // Construct the semaphore. -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. -+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.params_C, -+ params.ref_C.data(), -+ threadblock_extent, -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ params.ref_D.data(), -+ threadblock_extent, -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_offset.k()) { -+ iterator_C = iterator_D; -+ } -+ -+ semaphore.wait(threadblock_tile_offset.k()); -+ } -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue(output_op, iterator_D, accumulators, iterator_C); -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_offset.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ } -+}; -+ -+// B is Sparse -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled. -+> -+struct EllGemm { -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using OutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static bool const kSplitKSerial = SplitKSerial; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ /// Parameters structure -+ struct Params { -+ cutlass::gemm::GemmCoord problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ int swizzle_log_tile; -+ typename Mma::IteratorA::Params params_A; -+ typename Mma::IteratorA::TensorRef ref_A; -+ typename Mma::IteratorB::Params params_B; -+ typename Mma::IteratorB::TensorRef ref_B; -+ typename Epilogue::OutputTileIterator::Params params_C; -+ typename Epilogue::OutputTileIterator::TensorRef ref_C; -+ typename Epilogue::OutputTileIterator::Params params_D; -+ typename Epilogue::OutputTileIterator::TensorRef ref_D; -+ typename OutputOp::Params output_op; -+ int *semaphore; -+ int gemm_k_iterations; -+ int gemm_k_size; -+ const int* ell_idx; -+ int ell_ncol; -+ int ell_blocksize; -+ int ell_base_idx; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): swizzle_log_tile(0), semaphore(0), gemm_k_iterations(0), gemm_k_size(0) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ cutlass::gemm::GemmCoord const & problem_size, -+ cutlass::gemm::GemmCoord const & grid_tiled_shape, -+ typename Mma::IteratorA::TensorRef ref_A, -+ typename Mma::IteratorB::TensorRef ref_B, -+ typename Epilogue::OutputTileIterator::TensorRef ref_C, -+ typename Epilogue::OutputTileIterator::TensorRef ref_D, -+ const int* ell_idx, -+ int ell_ncol, -+ int ell_blocksize, -+ int ell_base_idx, -+ typename OutputOp::Params output_op = typename OutputOp::Params(), -+ int *workspace = nullptr -+ ): -+ problem_size(problem_size), -+ grid_tiled_shape(grid_tiled_shape), -+ swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), -+ params_A(ref_A.layout()), -+ ref_A(ref_A), -+ params_B(ref_B.layout()), -+ ref_B(ref_B), -+ params_C(ref_C.layout()), -+ ref_C(ref_C), -+ params_D(ref_D.layout()), -+ ref_D(ref_D), -+ output_op(output_op), -+ ell_idx(ell_idx), -+ ell_ncol(ell_ncol), -+ ell_blocksize(ell_blocksize), -+ ell_base_idx(ell_base_idx) -+ { -+ -+ int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); -+ -+ gemm_k_size = gemm_k_iterations * Mma::Shape::kK; -+ -+ semaphore = workspace; -+ } -+ }; -+ -+ /// Shared memory storage structure -+ struct SharedStorage { -+ union{ -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ typename cutlass::transform::threadblock::ell::SharedStorage ell; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ EllGemm() { } -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement( -+ cutlass::gemm::GemmCoord const & problem_size, -+ typename Mma::IteratorA::TensorRef ref_A, -+ typename Mma::IteratorB::TensorRef ref_B, -+ typename Epilogue::OutputTileIterator::TensorRef ref_C, -+ typename Epilogue::OutputTileIterator::TensorRef ref_D) { -+ -+ static int const kAlignmentA = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 64 -+ : Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 64 -+ : Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ if (!TensorRef_aligned(ref_A, kAlignmentA)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_B, kAlignmentB)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_C, kAlignmentC)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_D, kAlignmentC)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) || -+ (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) || -+ (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) { -+ -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ int tile_in_ell_block = (params.ell_blocksize + Mma::Shape::kN - 1 ) / Mma::Shape::kN; -+ int ell_block_offset_n = threadblock_tile_offset.n() / tile_in_ell_block; -+ int tile_offset_n = threadblock_tile_offset.n() % tile_in_ell_block; -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); -+ int lane_idx = threadIdx.x % 32; -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // skip computation if matrix is 0 -+ if (params.ell_ncol > 0) { -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.k() * params.gemm_k_size, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ threadblock_tile_offset.k() * params.gemm_k_size, -+ ell_block_offset_n * params.ell_blocksize -+ + tile_offset_n * Mma::Shape::kN, -+ }; -+ -+ int ell_idx_start = -+ (threadblock_tile_offset.n() / tile_in_ell_block) * -+ (params.ell_ncol / params.ell_blocksize); -+ const int* ell_idx_ptr = &(params.ell_idx[ell_idx_start]); -+ -+ // Problem size is a function of threadblock index in the K dimension -+ int problem_size_k = min( -+ params.problem_size.k(), -+ (threadblock_tile_offset.k() + 1) * params.gemm_k_size); -+ problem_size_k = min(problem_size_k, params.ell_ncol); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = -+ (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.params_A, -+ params.ref_A.data(), -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B( -+ params.params_B, -+ params.ref_B.data(), -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_B); -+ -+ // Define coef for ELL index depending on LayoutA -+ int ell_stride = iterator_A.get_stride(); -+ -+ typename cutlass::transform::threadblock::ell::Iterator ell_iterator( -+ shared_storage.ell, -+ ell_idx_ptr, -+ params.ell_blocksize, -+ params.ell_base_idx, -+ Mma::Shape::kK, -+ problem_size_k, -+ ell_stride, -+ thread_idx -+ ); -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ if (!kSplitKSerial || gemm_k_iterations > 0) { -+ // check if index computations can be skipped -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ constexpr bool is_double = (sizeof(Mma::IteratorA::Element) == 8); -+ constexpr bool is_multiple_alignment = -+ (kAlignmentA > 1) && (kAlignmentB > 1) && (kAlignmentC > 1); -+ const bool is_specialized_blocksize = -+ ((params.ell_blocksize) & (params.ell_blocksize-1)) == 0 -+ && params.ell_blocksize >= Mma::Shape::kK; -+ // Compute threadblock-scoped matrix multiply-add -+ if ((is_double || is_multiple_alignment) && is_specialized_blocksize) { -+ mma.operator()( -+ gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, ell_iterator); -+ } -+ else { -+ mma.operator()( -+ gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, ell_iterator); -+ } -+ } -+ } // if (params.ell_ncols > 0) -+ -+ // -+ // Epilogue -+ // -+ -+ OutputOp output_op(params.output_op); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ ell_block_offset_n = threadblock_tile_offset.n() / tile_in_ell_block; -+ tile_offset_n = threadblock_tile_offset.n() % tile_in_ell_block; -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ ell_block_offset_n * params.ell_blocksize -+ + tile_offset_n * Mma::Shape::kN -+ ); -+ -+ //avoid out of bounds -+ MatrixCoord threadblock_extent( -+ min(params.problem_size.m(), -+ (threadblock_tile_offset.m()+1) * Mma::Shape::kM), -+ min(params.problem_size.n(), -+ ell_block_offset_n * params.ell_blocksize -+ + min((tile_offset_n + 1) * Mma::Shape::kN, params.ell_blocksize)) -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ // Construct the semaphore. -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. -+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.params_C, -+ params.ref_C.data(), -+ threadblock_extent, -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ params.ref_D.data(), -+ threadblock_extent, -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_offset.k()) { -+ iterator_C = iterator_D; -+ } -+ -+ semaphore.wait(threadblock_tile_offset.k()); -+ } -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue(output_op, iterator_D, accumulators, iterator_C); -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_offset.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm.h -new file mode 100644 -index 0000000..b5064ec ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm.h -@@ -0,0 +1,380 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/semaphore.h" -+#include "cutlass/arch/arch.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled. -+> -+struct Gemm { -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using OutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static bool const kSplitKSerial = SplitKSerial; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ /// Parameters structure -+ struct Params { -+ cutlass::gemm::GemmCoord problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ int swizzle_log_tile; -+ typename Mma::IteratorA::Params params_A; -+ typename Mma::IteratorA::TensorRef ref_A; -+ typename Mma::IteratorB::Params params_B; -+ typename Mma::IteratorB::TensorRef ref_B; -+ typename Epilogue::OutputTileIterator::Params params_C; -+ typename Epilogue::OutputTileIterator::TensorRef ref_C; -+ typename Epilogue::OutputTileIterator::Params params_D; -+ typename Epilogue::OutputTileIterator::TensorRef ref_D; -+ typename OutputOp::Params output_op; -+ int *semaphore; -+ int gemm_k_size; -+ // For gather+scatter operations -+ int const *gather_A_indices; -+ int const *gather_B_indices; -+ int const *scatter_D_indices; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): swizzle_log_tile(0), semaphore(0), gemm_k_size(0) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ cutlass::gemm::GemmCoord const & problem_size, -+ cutlass::gemm::GemmCoord const & grid_tiled_shape, -+ typename Mma::IteratorA::TensorRef ref_A, -+ typename Mma::IteratorB::TensorRef ref_B, -+ typename Epilogue::OutputTileIterator::TensorRef ref_C, -+ typename Epilogue::OutputTileIterator::TensorRef ref_D, -+ typename OutputOp::Params output_op = typename OutputOp::Params(), -+ int *workspace = nullptr, -+ int const *gather_A_indices = nullptr, -+ int const *gather_B_indices = nullptr, -+ int const *scatter_D_indices = nullptr -+ ): -+ problem_size(problem_size), -+ grid_tiled_shape(grid_tiled_shape), -+ swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), -+ params_A(ref_A.layout()), -+ ref_A(ref_A), -+ params_B(ref_B.layout()), -+ ref_B(ref_B), -+ params_C(ref_C.layout()), -+ ref_C(ref_C), -+ params_D(ref_D.layout()), -+ ref_D(ref_D), -+ output_op(output_op), -+ gather_A_indices(gather_A_indices), -+ gather_B_indices(gather_B_indices), -+ scatter_D_indices(scatter_D_indices) { -+ -+ int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); -+ -+ gemm_k_size = gemm_k_iterations * Mma::Shape::kK; -+ -+ semaphore = workspace; -+ } -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Gemm() { } -+ -+ /// Determines whether kernel satisfies alignment -+ CUTLASS_HOST_DEVICE -+ static Status can_implement( -+ cutlass::gemm::GemmCoord const & problem_size, -+ typename Mma::IteratorA::TensorRef ref_A, -+ typename Mma::IteratorB::TensorRef ref_B, -+ typename Epilogue::OutputTileIterator::TensorRef ref_C, -+ typename Epilogue::OutputTileIterator::TensorRef ref_D) { -+ -+ static int const kAlignmentA = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 64 -+ : Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 64 -+ : Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 64 -+ : Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ if (!TensorRef_aligned(ref_A, kAlignmentA)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_B, kAlignmentB)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_C, kAlignmentC)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_D, kAlignmentC)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.k() * params.gemm_k_size, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ threadblock_tile_offset.k() * params.gemm_k_size, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ }; -+ -+ // Problem size is a function of threadblock index in the K dimension -+ int problem_size_k = min( -+ params.problem_size.k(), -+ (threadblock_tile_offset.k() + 1) * params.gemm_k_size); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.params_A, -+ params.ref_A.data(), -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_A, -+ params.gather_A_indices); -+ -+ typename Mma::IteratorB iterator_B( -+ params.params_B, -+ params.ref_B.data(), -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_B, -+ params.gather_B_indices); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = canonical_warp_idx(); -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ if (!kSplitKSerial || gemm_k_iterations > 0) { -+ // Compute threadblock-scoped matrix multiply-add -+ mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); -+ } -+ -+ // -+ // Epilogue -+ // -+ -+ OutputOp output_op(params.output_op); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ // Construct the semaphore. -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. -+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.params_C, -+ params.ref_C.data(), -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset, -+ params.scatter_D_indices -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ params.ref_D.data(), -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset, -+ params.scatter_D_indices -+ ); -+ -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_offset.k()) { -+ iterator_C = iterator_D; -+ } -+ -+ semaphore.wait(threadblock_tile_offset.k()); -+ -+ } -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue(output_op, iterator_D, accumulators, iterator_C); -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_offset.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_array.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_array.h -new file mode 100644 -index 0000000..1862e20 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_array.h -@@ -0,0 +1,264 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_ ///! Threadblock swizzling function -+> -+struct GemmArray { -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using OutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ /// Parameters structure -+ struct Params { -+ cutlass::gemm::GemmCoord problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ int swizzle_log_tile; -+ typename Mma::IteratorA::Params params_A; -+ typename Mma::IteratorA::Element const * const * ptr_A; -+ typename Mma::IteratorB::Params params_B; -+ typename Mma::IteratorB::Element const * const * ptr_B; -+ typename Epilogue::OutputTileIterator::Params params_C; -+ typename Epilogue::OutputTileIterator::Element const * const * ptr_C; -+ typename Epilogue::OutputTileIterator::Params params_D; -+ typename Epilogue::OutputTileIterator::Element * const * ptr_D; -+ int64_t stride_D; -+ typename OutputOp::Params epilogue; -+ int batch_count; -+ int gemm_k_iterations; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params() : -+ swizzle_log_tile(0) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ cutlass::gemm::GemmCoord const & problem_size_, -+ cutlass::gemm::GemmCoord const & grid_tiled_shape_, -+ typename Mma::IteratorA::Element const * const * ptr_A_, -+ typename Mma::IteratorA::Layout layout_A, -+ typename Mma::IteratorB::Element const * const * ptr_B_, -+ typename Mma::IteratorB::Layout layout_B, -+ typename Epilogue::OutputTileIterator::Element const * const * ptr_C_, -+ typename Epilogue::OutputTileIterator::Layout layout_C, -+ typename Epilogue::OutputTileIterator::Element * const * ptr_D_, -+ typename Epilogue::OutputTileIterator::Layout layout_D, -+ typename OutputOp::Params epilogue_, -+ int batch_count_ -+ ): -+ problem_size(problem_size_), -+ grid_tiled_shape(grid_tiled_shape_), -+ swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), -+ params_A(layout_A), -+ ptr_A(ptr_A_), -+ params_B(layout_B), -+ ptr_B(ptr_B_), -+ params_C(layout_C), -+ ptr_C(ptr_C_), -+ params_D(layout_D), -+ ptr_D(ptr_D_), -+ epilogue(epilogue_), -+ batch_count(batch_count_), -+ gemm_k_iterations((problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK) { -+ -+ } -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ GemmArray() { } -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ -+ // Each CTA handles multiple batch indices to accommodate limited range of CUDA grid's Z dimension -+ for (int batch_idx = threadblock_swizzle.get_batch_idx(); -+ batch_idx < params.batch_count; -+ batch_idx += gridDim.z) { -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ 0 -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ 0, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ }; -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.params_A, -+ const_cast(params.ptr_A[batch_idx]), -+ params.problem_size.mk(), -+ thread_idx, -+ tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B( -+ params.params_B, -+ const_cast(params.ptr_B[batch_idx]), -+ params.problem_size.kn(), -+ thread_idx, -+ tb_offset_B); -+ -+ // -+ // Main loop -+ // -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = canonical_warp_idx(); -+ -+ int lane_idx = threadIdx.x % 32; -+ -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma(params.gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ OutputOp output_op(params.epilogue); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ ); -+ -+ // Tile iterator writing to output tile -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.params_C, -+ const_cast(params.ptr_C[batch_idx]), -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator writing to output tile -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ params.ptr_D[batch_idx], -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // run efficient epilogue -+ epilogue(output_op, iterator_D, accumulators, iterator_C); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_batched.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_batched.h -new file mode 100644 -index 0000000..464aeef ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_batched.h -@@ -0,0 +1,279 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_ ///! Threadblock swizzling function -+> -+struct GemmBatched { -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using OutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ /// Parameters structure -+ struct Params { -+ cutlass::gemm::GemmCoord problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ int swizzle_log_tile; -+ typename Mma::IteratorA::Params params_A; -+ typename Mma::IteratorA::TensorRef ref_A; -+ int64_t stride_A; -+ typename Mma::IteratorB::Params params_B; -+ typename Mma::IteratorB::TensorRef ref_B; -+ int64_t stride_B; -+ typename Epilogue::OutputTileIterator::Params params_C; -+ typename Epilogue::OutputTileIterator::TensorRef ref_C; -+ int64_t stride_C; -+ typename Epilogue::OutputTileIterator::Params params_D; -+ typename Epilogue::OutputTileIterator::TensorRef ref_D; -+ int64_t stride_D; -+ typename OutputOp::Params epilogue; -+ int batch_count; -+ int gemm_k_iterations; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params() : swizzle_log_tile(0) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ cutlass::gemm::GemmCoord const & problem_size_, -+ cutlass::gemm::GemmCoord const & grid_tiled_shape_, -+ typename Mma::IteratorA::TensorRef ref_A_, -+ int64_t stride_A_, -+ typename Mma::IteratorB::TensorRef ref_B_, -+ int64_t stride_B_, -+ typename Epilogue::OutputTileIterator::TensorRef ref_C_, -+ int64_t stride_C_, -+ typename Epilogue::OutputTileIterator::TensorRef ref_D_, -+ int64_t stride_D_, -+ typename OutputOp::Params epilogue_, -+ int batch_count_ -+ ): -+ problem_size(problem_size_), -+ grid_tiled_shape(grid_tiled_shape_), -+ swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), -+ params_A(ref_A_.layout()), -+ ref_A(ref_A_), -+ stride_A(stride_A_), -+ params_B(ref_B_.layout()), -+ ref_B(ref_B_), -+ stride_B(stride_B_), -+ params_C(ref_C_.layout()), -+ ref_C(ref_C_), -+ stride_C(stride_C_), -+ params_D(ref_D_.layout()), -+ ref_D(ref_D_), -+ stride_D(stride_D_), -+ epilogue(epilogue_), -+ batch_count(batch_count_), -+ gemm_k_iterations((problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK) { -+ -+ } -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ GemmBatched() { } -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ -+ // Each CTA handles multiple batch indices to accommodate limited range of CUDA grid's Z dimension -+ for (int batch_idx = threadblock_swizzle.get_batch_idx(); -+ batch_idx < params.batch_count; -+ batch_idx += gridDim.z) { -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ 0 -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ 0, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ }; -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.params_A, -+ params.ref_A.data(), -+ params.problem_size.mk(), -+ thread_idx, -+ tb_offset_A); -+ -+ iterator_A.add_pointer_offset(params.stride_A * batch_idx); -+ -+ typename Mma::IteratorB iterator_B( -+ params.params_B, -+ params.ref_B.data(), -+ params.problem_size.kn(), -+ thread_idx, -+ tb_offset_B); -+ -+ iterator_B.add_pointer_offset(params.stride_B * batch_idx); -+ -+ -+ // -+ // Main loop -+ // -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = canonical_warp_idx(); -+ -+ int lane_idx = threadIdx.x % 32; -+ -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma(params.gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ OutputOp output_op(params.epilogue); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ ); -+ -+ // Tile iterator writing to output tile -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.params_C, -+ params.ref_C.data(), -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ iterator_C.add_pointer_offset(params.stride_C * batch_idx); -+ -+ // Tile iterator writing to output tile -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ params.ref_D.data(), -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ iterator_D.add_pointer_offset(params.stride_D * batch_idx); -+ -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // run efficient epilogue -+ epilogue(output_op, iterator_D, accumulators, iterator_C); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_grouped.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_grouped.h -new file mode 100644 -index 0000000..84dc4ae ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_grouped.h -@@ -0,0 +1,481 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Problem visitor for grouped GEMMs -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/semaphore.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/trace.h" -+#include "cutlass/gemm/kernel/gemm_transpose_operands.h" -+#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ GroupScheduleMode GroupScheduleMode_, ///! Type of scheduling to perform -+ bool Transposed = false -+> -+struct GemmGrouped { -+public: -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; -+ static bool const kTransposed = Transposed; -+ -+ // Optional transpose -+ using MapArguments = kernel::detail::MapArguments< -+ typename Mma::IteratorA::Element, -+ typename Mma::IteratorA::Layout, -+ Mma::kTransformA, -+ Mma::IteratorA::AccessType::kElements, -+ typename Mma::IteratorB::Element, -+ typename Mma::IteratorB::Layout, -+ Mma::kTransformB, -+ Mma::IteratorB::AccessType::kElements, -+ typename Mma::LayoutC, -+ kTransposed -+ >; -+ -+ // Public-facing type definitions related to operand element type, layout, and complex conjugate -+ // operation. Must interact with the 'kTransposed' notion. -+ using ElementA = typename MapArguments::ElementA; -+ using LayoutA = typename MapArguments::LayoutA; -+ using ElementB = typename MapArguments::ElementB; -+ using LayoutB = typename MapArguments::LayoutB; -+ using ElementC = typename Epilogue::OutputTileIterator::Element; -+ using LayoutC = typename MapArguments::LayoutC; -+ -+ static ComplexTransform const kTransformA = MapArguments::kTransformA; -+ static ComplexTransform const kTransformB = MapArguments::kTransformB; -+ -+ // Type definitions about the mainloop. -+ using Operator = typename Mma::Operator; -+ using OperatorClass = typename Mma::Operator::OperatorClass; -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename Mma::Operator::Shape; -+ using InstructionShape = typename Mma::Policy::Operator::InstructionShape; -+ using ArchTag = typename Mma::ArchTag; -+ -+ static int const kStages = Mma::kStages; -+ static int const kAlignmentA = MapArguments::kAlignmentA; -+ static int const kAlignmentB = MapArguments::kAlignmentB; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ using ProblemVisitor = GemmGroupedProblemVisitor< -+ ThreadblockShape, -+ kGroupScheduleMode, -+ kThreadCount, -+ kThreadCount, -+ kTransposed>; -+ -+ // -+ // Structures -+ // -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmCoord *problem_sizes; -+ int problem_count; -+ int threadblock_count; -+ -+ typename EpilogueOutputOp::Params output_op; -+ -+ ElementA ** ptr_A; -+ ElementB ** ptr_B; -+ ElementC ** ptr_C; -+ ElementC ** ptr_D; -+ -+ typename LayoutA::Stride::LongIndex *lda; -+ typename LayoutB::Stride::LongIndex *ldb; -+ typename LayoutC::Stride::LongIndex *ldc; -+ typename LayoutC::Stride::LongIndex *ldd; -+ -+ // Only used by device-level operator -+ GemmCoord *host_problem_sizes; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments(): -+ problem_count(0), -+ threadblock_count(0), -+ ptr_A(nullptr), -+ ptr_B(nullptr), -+ ptr_C(nullptr), -+ ptr_D(nullptr), -+ lda(nullptr), -+ ldb(nullptr), -+ ldc(nullptr), -+ ldd(nullptr), -+ host_problem_sizes(nullptr) -+ { -+ -+ } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ GemmCoord *problem_sizes, -+ int problem_count, -+ int threadblock_count, -+ typename EpilogueOutputOp::Params output_op, -+ ElementA ** ptr_A, -+ ElementB ** ptr_B, -+ ElementC ** ptr_C, -+ ElementC ** ptr_D, -+ typename LayoutA::Stride::LongIndex *lda, -+ typename LayoutB::Stride::LongIndex *ldb, -+ typename LayoutC::Stride::LongIndex *ldc, -+ typename LayoutC::Stride::LongIndex *ldd, -+ GemmCoord *host_problem_sizes=nullptr -+ ): -+ problem_sizes(problem_sizes), -+ problem_count(problem_count), -+ threadblock_count(threadblock_count), -+ output_op(output_op), -+ ptr_A(ptr_A), -+ ptr_B(ptr_B), -+ ptr_C(ptr_C), -+ ptr_D(ptr_D), -+ lda(lda), -+ ldb(ldb), -+ ldc(ldc), -+ ldd(ldd), -+ host_problem_sizes(host_problem_sizes) -+ { -+ -+ } -+ }; -+ -+ // -+ // Structure for precomputing values in host memory and passing to kernels -+ // -+ -+ /// Parameters structure -+ struct Params { -+ -+ typename ProblemVisitor::Params problem_visitor; -+ int threadblock_count; -+ -+ typename EpilogueOutputOp::Params output_op; -+ -+ ElementA ** ptr_A; -+ ElementB ** ptr_B; -+ ElementC ** ptr_C; -+ ElementC ** ptr_D; -+ -+ typename LayoutA::Stride::LongIndex *lda; -+ typename LayoutB::Stride::LongIndex *ldb; -+ typename LayoutC::Stride::LongIndex *ldc; -+ typename LayoutC::Stride::LongIndex *ldd; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ ptr_A(nullptr), -+ ptr_B(nullptr), -+ ptr_C(nullptr), -+ ptr_D(nullptr), -+ lda(nullptr), -+ ldb(nullptr), -+ ldc(nullptr), -+ ldd(nullptr) -+ { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args, -+ void *workspace = nullptr, -+ int tile_count = 0): -+ problem_visitor(args.problem_sizes, args.problem_count, workspace, tile_count), -+ threadblock_count(args.threadblock_count), -+ output_op(args.output_op), -+ ptr_A(args.ptr_A), -+ ptr_B(args.ptr_B), -+ ptr_C(args.ptr_C), -+ ptr_D(args.ptr_D), -+ lda(args.lda), -+ ldb(args.ldb), -+ ldc(args.ldc), -+ ldd(args.ldd) -+ { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void update( -+ Arguments const &args, -+ void *workspace = nullptr, -+ int tile_count = 0) { -+ -+ problem_visitor = typename ProblemVisitor::Params(args.problem_sizes, args.problem_count, -+ workspace, tile_count); -+ threadblock_count = args.threadblock_count; -+ output_op = args.output_op; -+ ptr_A = args.ptr_A; -+ ptr_B = args.ptr_B; -+ ptr_C = args.ptr_C; -+ ptr_D = args.ptr_D; -+ lda = args.lda; -+ ldb = args.ldb; -+ ldc = args.ldc; -+ ldd = args.ldd; -+ } -+ }; -+ -+ /// Shared memory storage structure -+ struct SharedStorage { -+ union { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ } kernel; -+ -+ // ProblemVisitor shared storage can't be overlapped with others -+ typename ProblemVisitor::SharedStorage problem_visitor; -+ }; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_DEVICE -+ GemmGrouped() { } -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement(cutlass::gemm::GemmCoord const & problem_size) { -+ return Status::kSuccess; -+ } -+ -+ static Status can_implement(Arguments const &args) { -+ return Status::kSuccess; -+ } -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // -+ // These types shadow the type-level definitions and support the ability to implement -+ // a 'transposed' GEMM that computes the transposed problems. -+ // -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using ElementC = typename Epilogue::OutputTileIterator::Element; -+ using LayoutC = typename Epilogue::OutputTileIterator::Layout; -+ -+ // -+ // Problem visitor. -+ // -+ ProblemVisitor problem_visitor( -+ params.problem_visitor, -+ shared_storage.problem_visitor, -+ blockIdx.x); -+ -+ // Outer 'persistent' loop to iterate over tiles -+ while (problem_visitor.next_tile()) { -+ -+ GemmCoord problem_size = problem_visitor.problem_size(); -+ int32_t problem_idx = problem_visitor.problem_index(); -+ int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); -+ -+ GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); -+ -+ cutlass::gemm::GemmCoord threadblock_offset( -+ int(threadblock_idx / grid_shape.n()) * Mma::Shape::kM, -+ int(threadblock_idx % grid_shape.n()) * Mma::Shape::kN, -+ 0); -+ -+ // Load element pointers. Exchange pointers and strides if working on the transpose -+ ElementA *ptr_A = reinterpret_cast((kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx])); -+ typename LayoutA::LongIndex ldm_A = (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]); -+ -+ ElementB *ptr_B = reinterpret_cast((kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx])); -+ typename LayoutB::LongIndex ldm_B = (kTransposed ? params.lda[problem_idx] : params.ldb[problem_idx]); -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ threadblock_offset.m(), -+ 0, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ 0, -+ threadblock_offset.n() -+ }; -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ LayoutA(ldm_A), -+ ptr_A, -+ {problem_size.m(), problem_size.k()}, -+ thread_idx, -+ tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B( -+ LayoutB(ldm_B), -+ ptr_B, -+ {problem_size.k(), problem_size.n()}, -+ thread_idx, -+ tb_offset_B); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = canonical_warp_idx(); -+ -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Matrix multiply phase -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Wait for all threads to finish their epilogue phases from the previous tile. -+ __syncthreads(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma( -+ gemm_k_iterations, -+ accumulators, -+ iterator_A, -+ iterator_B, -+ accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ ElementC *ptr_C = params.ptr_C[problem_idx]; -+ ElementC *ptr_D = params.ptr_D[problem_idx]; -+ -+ LayoutC layout_C(params.ldc[problem_idx]); -+ LayoutC layout_D(params.ldd[problem_idx]); -+ -+ typename Epilogue::OutputTileIterator::Params params_C(layout_C); -+ typename Epilogue::OutputTileIterator::Params params_D(layout_D); -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C( -+ params_C, -+ ptr_C, -+ problem_size.mn(), -+ thread_idx, -+ threadblock_offset.mn() -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params_D, -+ ptr_D, -+ problem_size.mn(), -+ thread_idx, -+ threadblock_offset.mn() -+ ); -+ -+ Epilogue epilogue( -+ shared_storage.kernel.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue( -+ output_op, -+ iterator_D, -+ accumulators, -+ iterator_C); -+ -+ // Next tile -+ problem_visitor.advance(gridDim.x); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_grouped_problem_visitor.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_grouped_problem_visitor.h -new file mode 100644 -index 0000000..9df78c9 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_grouped_problem_visitor.h -@@ -0,0 +1,122 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Scheduler for grouped GEMM -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/gemm/kernel/grouped_problem_visitor.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+// Helper for correctly representing problem sizes in grouped kernels -+template < -+ typename ThreadblockShape, -+ bool Transposed -+> -+struct GemmGroupedProblemSizeHelper { -+ -+ static bool const kTransposed = Transposed; -+ -+ CUTLASS_HOST_DEVICE -+ static cutlass::gemm::GemmCoord grid_shape(const cutlass::gemm::GemmCoord& problem) { -+ return cutlass::gemm::GemmCoord( -+ ((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), -+ ((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN), -+ 1); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem) { -+ if (kTransposed) { -+ swap(problem.m(), problem.n()); -+ } -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static int32_t tile_count(const cutlass::gemm::GemmCoord& grid) { -+ return grid.m() * grid.n(); -+ } -+}; -+ -+} // namespace detail -+ -+/// Visitor class to abstract away the algorithm for iterating over tiles -+template -+struct GemmGroupedProblemVisitor : public GroupedProblemVisitor< -+ detail::GemmGroupedProblemSizeHelper, -+ ThreadblockShape, -+ GroupScheduleMode_, -+ PrefetchTileCount, -+ ThreadCount> { -+ -+ static bool const kTransposed = Transposed; -+ -+ using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper; -+ using Base = GroupedProblemVisitor; -+ using Params = typename Base::Params; -+ using SharedStorage = typename Base::SharedStorage; -+ -+ // -+ // Methods -+ // -+ CUTLASS_DEVICE -+ GemmGroupedProblemVisitor( -+ Params const ¶ms_, -+ SharedStorage &shared_storage_, -+ int32_t block_idx -+ ): Base (params_, shared_storage_, block_idx) -+ {} -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_grouped_softmax_mainloop_fusion.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_grouped_softmax_mainloop_fusion.h -new file mode 100644 -index 0000000..cac99f5 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_grouped_softmax_mainloop_fusion.h -@@ -0,0 +1,510 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Problem visitor for grouped GEMMs with a softmax fused beforehand -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/semaphore.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/trace.h" -+#include "cutlass/gemm/kernel/gemm_transpose_operands.h" -+#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ GroupScheduleMode GroupScheduleMode_, ///! Type of scheduling to perform -+ bool Transposed = false -+> -+struct GemmGroupedSoftmaxMainloopFusion { -+public: -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; -+ static bool const kTransposed = Transposed; -+ -+ // Optional transpose -+ using MapArguments = kernel::detail::MapArguments< -+ typename Mma::IteratorA::Element, -+ typename Mma::IteratorA::Layout, -+ Mma::kTransformA, -+ Mma::IteratorA::AccessType::kElements, -+ typename Mma::IteratorB::Element, -+ typename Mma::IteratorB::Layout, -+ Mma::kTransformB, -+ Mma::IteratorB::AccessType::kElements, -+ typename Mma::LayoutC, -+ kTransposed -+ >; -+ -+ // Public-facing type definitions related to operand element type, layout, and complex conjugate -+ // operation. Must interact with the 'kTransposed' notion. -+ using ElementA = typename MapArguments::ElementA; -+ using LayoutA = typename MapArguments::LayoutA; -+ using ElementB = typename MapArguments::ElementB; -+ using LayoutB = typename MapArguments::LayoutB; -+ using ElementC = typename Epilogue::OutputTileIterator::Element; -+ using LayoutC = typename MapArguments::LayoutC; -+ -+ using ElementScaleBias = typename Mma::IteratorNormSum::Element; -+ -+ static ComplexTransform const kTransformA = MapArguments::kTransformA; -+ static ComplexTransform const kTransformB = MapArguments::kTransformB; -+ -+ // Type definitions about the mainloop. -+ using Operator = typename Mma::Operator; -+ using OperatorClass = typename Mma::Operator::OperatorClass; -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename Mma::Operator::Shape; -+ using InstructionShape = typename Mma::Policy::Operator::InstructionShape; -+ using ArchTag = typename Mma::ArchTag; -+ -+ static int const kStages = Mma::kStages; -+ static int const kAlignmentA = MapArguments::kAlignmentA; -+ static int const kAlignmentB = MapArguments::kAlignmentB; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ using ProblemVisitor = GemmGroupedProblemVisitor< -+ ThreadblockShape, -+ kGroupScheduleMode, -+ kThreadCount, -+ kThreadCount, -+ kTransposed>; -+ -+ // -+ // Structures -+ // -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmCoord *problem_sizes; -+ int problem_count; -+ int threadblock_count; -+ -+ typename EpilogueOutputOp::Params output_op; -+ -+ ElementA ** ptr_A; -+ ElementB ** ptr_B; -+ ElementC ** ptr_C; -+ ElementC ** ptr_D; -+ void ** ptr_norm; -+ void ** ptr_sum; -+ -+ typename LayoutA::Stride::LongIndex *lda; -+ typename LayoutB::Stride::LongIndex *ldb; -+ typename LayoutC::Stride::LongIndex *ldc; -+ typename LayoutC::Stride::LongIndex *ldd; -+ -+ // Only used by device-level operator -+ GemmCoord *host_problem_sizes; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments(): -+ problem_count(0), -+ threadblock_count(0), -+ ptr_A(nullptr), -+ ptr_B(nullptr), -+ ptr_C(nullptr), -+ ptr_D(nullptr), -+ ptr_norm(nullptr), -+ ptr_sum(nullptr), -+ lda(nullptr), -+ ldb(nullptr), -+ ldc(nullptr), -+ ldd(nullptr), -+ host_problem_sizes(nullptr) -+ { -+ -+ } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ GemmCoord *problem_sizes, -+ int problem_count, -+ int threadblock_count, -+ typename EpilogueOutputOp::Params output_op, -+ ElementA ** ptr_A, -+ ElementB ** ptr_B, -+ ElementC ** ptr_C, -+ ElementC ** ptr_D, -+ void ** ptr_norm, -+ void ** ptr_sum, -+ typename LayoutA::Stride::LongIndex *lda, -+ typename LayoutB::Stride::LongIndex *ldb, -+ typename LayoutC::Stride::LongIndex *ldc, -+ typename LayoutC::Stride::LongIndex *ldd, -+ GemmCoord *host_problem_sizes=nullptr -+ ): -+ problem_sizes(problem_sizes), -+ problem_count(problem_count), -+ threadblock_count(threadblock_count), -+ output_op(output_op), -+ ptr_A(ptr_A), -+ ptr_B(ptr_B), -+ ptr_C(ptr_C), -+ ptr_D(ptr_D), -+ ptr_norm(ptr_norm), -+ ptr_sum(ptr_sum), -+ lda(lda), -+ ldb(ldb), -+ ldc(ldc), -+ ldd(ldd), -+ host_problem_sizes(host_problem_sizes) -+ { -+ -+ } -+ }; -+ -+ // -+ // Structure for precomputing values in host memory and passing to kernels -+ // -+ -+ /// Parameters structure -+ struct Params { -+ -+ typename ProblemVisitor::Params problem_visitor; -+ int threadblock_count; -+ -+ typename EpilogueOutputOp::Params output_op; -+ -+ ElementA ** ptr_A; -+ ElementB ** ptr_B; -+ ElementC ** ptr_C; -+ ElementC ** ptr_D; -+ -+ void ** ptr_norm; -+ void ** ptr_sum; -+ -+ typename LayoutA::Stride::LongIndex *lda; -+ typename LayoutB::Stride::LongIndex *ldb; -+ typename LayoutC::Stride::LongIndex *ldc; -+ typename LayoutC::Stride::LongIndex *ldd; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ ptr_A(nullptr), -+ ptr_B(nullptr), -+ ptr_C(nullptr), -+ ptr_D(nullptr), -+ ptr_norm(nullptr), -+ ptr_sum(nullptr), -+ lda(nullptr), -+ ldb(nullptr), -+ ldc(nullptr), -+ ldd(nullptr) -+ { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args, -+ void *workspace = nullptr, -+ int tile_count = 0): -+ problem_visitor(args.problem_sizes, args.problem_count, workspace, tile_count), -+ threadblock_count(args.threadblock_count), -+ output_op(args.output_op), -+ ptr_A(args.ptr_A), -+ ptr_B(args.ptr_B), -+ ptr_C(args.ptr_C), -+ ptr_D(args.ptr_D), -+ ptr_norm(args.ptr_norm), -+ ptr_sum(args.ptr_sum), -+ lda(args.lda), -+ ldb(args.ldb), -+ ldc(args.ldc), -+ ldd(args.ldd) -+ { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void update( -+ Arguments const &args, -+ void *workspace = nullptr, -+ int tile_count = 0) { -+ -+ problem_visitor = typename ProblemVisitor::Params(args.problem_sizes, args.problem_count, -+ workspace, tile_count); -+ threadblock_count = args.threadblock_count; -+ output_op = args.output_op; -+ ptr_A = args.ptr_A; -+ ptr_B = args.ptr_B; -+ ptr_C = args.ptr_C; -+ ptr_D = args.ptr_D; -+ ptr_norm = args.ptr_norm; -+ ptr_sum = args.ptr_sum; -+ lda = args.lda; -+ ldb = args.ldb; -+ ldc = args.ldc; -+ ldd = args.ldd; -+ } -+ }; -+ -+ /// Shared memory storage structure -+ struct SharedStorage { -+ union { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ } kernel; -+ -+ // ProblemVisitor shared storage can't be overlapped with others -+ typename ProblemVisitor::SharedStorage problem_visitor; -+ }; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_DEVICE -+ GemmGroupedSoftmaxMainloopFusion() { } -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement(cutlass::gemm::GemmCoord const & problem_size) { -+ return Status::kSuccess; -+ } -+ -+ static Status can_implement(Arguments const &args) { -+ return Status::kSuccess; -+ } -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // -+ // These types shadow the type-level definitions and support the ability to implement -+ // a 'transposed' GEMM that computes the transposed problems. -+ // -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using ElementC = typename Epilogue::OutputTileIterator::Element; -+ using LayoutC = typename Epilogue::OutputTileIterator::Layout; -+ -+ // -+ // Problem visitor. -+ // -+ ProblemVisitor problem_visitor( -+ params.problem_visitor, -+ shared_storage.problem_visitor, -+ blockIdx.x); -+ -+ // Outer 'persistent' loop to iterate over tiles -+ while (problem_visitor.next_tile()) { -+ -+ GemmCoord problem_size = problem_visitor.problem_size(); -+ int32_t problem_idx = problem_visitor.problem_index(); -+ int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); -+ -+ GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); -+ -+ cutlass::gemm::GemmCoord threadblock_offset( -+ int(threadblock_idx / grid_shape.n()) * Mma::Shape::kM, -+ int(threadblock_idx % grid_shape.n()) * Mma::Shape::kN, -+ 0); -+ -+ // Load element pointers. Exchange pointers and strides if working on the transpose -+ ElementA *ptr_A = reinterpret_cast((kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx])); -+ typename LayoutA::LongIndex ldm_A = (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]); -+ -+ ElementB *ptr_B = reinterpret_cast((kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx])); -+ typename LayoutB::LongIndex ldm_B = (kTransposed ? params.lda[problem_idx] : params.ldb[problem_idx]); -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ threadblock_offset.m(), -+ 0, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ 0, -+ threadblock_offset.n() -+ }; -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ LayoutA(ldm_A), -+ ptr_A, -+ {problem_size.m(), problem_size.k()}, -+ thread_idx, -+ tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B( -+ LayoutB(ldm_B), -+ ptr_B, -+ {problem_size.k(), problem_size.n()}, -+ thread_idx, -+ tb_offset_B); -+ -+ // Construct iterator to the softmax norm/sum vector -+ typename Mma::IteratorNormSum iterator_norm_sum( -+ problem_size.m(), -+ static_cast(params.ptr_norm[problem_idx]), -+ static_cast(params.ptr_sum[problem_idx]), -+ thread_idx, -+ MatrixCoord(0, threadblock_offset.m()) -+ ); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); -+ -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Matrix multiply phase -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Wait for all threads to finish their epilogue phases from the previous tile. -+ __syncthreads(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma( -+ gemm_k_iterations, -+ accumulators, -+ iterator_A, -+ iterator_B, -+ iterator_norm_sum, -+ accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ ElementC *ptr_C = params.ptr_C[problem_idx]; -+ ElementC *ptr_D = params.ptr_D[problem_idx]; -+ -+ LayoutC layout_C(params.ldc[problem_idx]); -+ LayoutC layout_D(params.ldd[problem_idx]); -+ -+ typename Epilogue::OutputTileIterator::Params params_C(layout_C); -+ typename Epilogue::OutputTileIterator::Params params_D(layout_D); -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C( -+ params_C, -+ ptr_C, -+ problem_size.mn(), -+ thread_idx, -+ threadblock_offset.mn() -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params_D, -+ ptr_D, -+ problem_size.mn(), -+ thread_idx, -+ threadblock_offset.mn() -+ ); -+ -+ Epilogue epilogue( -+ shared_storage.kernel.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue( -+ output_op, -+ iterator_D, -+ accumulators, -+ iterator_C); -+ -+ // Next tile -+ problem_visitor.advance(gridDim.x); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_layernorm_mainloop_fusion.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_layernorm_mainloop_fusion.h -new file mode 100644 -index 0000000..94e2f1d ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_layernorm_mainloop_fusion.h -@@ -0,0 +1,777 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Template for a multistage GEMM kernel with layernorm operations fused in mainloop. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/semaphore.h" -+#include "cutlass/gemm/kernel/params_universal_base.h" -+ -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/trace.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_ ///! Threadblock swizzling function -+> -+struct GemmLayernormMainloopFusion { -+public: -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using ElementC = typename Epilogue::OutputTileIterator::Element; -+ using LayoutC = typename Epilogue::OutputTileIterator::Layout; -+ -+ using ElementScaleBias = typename Mma::IteratorVarMean::Element; -+ using LayoutScaleBias = typename Mma::IteratorVarMean::Layout; -+ -+ static ComplexTransform const kTransformA = Mma::kTransformA; -+ static ComplexTransform const kTransformB = Mma::kTransformB; -+ using Operator = typename Mma::Operator; -+ -+ using OperatorClass = typename Mma::Operator::OperatorClass; -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename Mma::Operator::Shape; -+ using InstructionShape = typename Mma::Policy::Operator::InstructionShape; -+ using ArchTag = typename Mma::ArchTag; -+ -+ static int const kStages = Mma::kStages; -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ /// Split-K preserves splits that are 128b aligned -+ static int const kSplitKAlignment = const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value); -+ -+ // -+ // Structures -+ // -+ -+ /// Argument structure -+ struct Arguments : UniversalArgumentsBase -+ { -+ // -+ // Data members -+ // -+ -+ typename EpilogueOutputOp::Params epilogue; -+ -+ void const * ptr_A; -+ void const * ptr_B; -+ void const * ptr_var; -+ void const * ptr_mean; -+ void const * ptr_gamma; -+ void const * ptr_beta; -+ void const * ptr_C; -+ void * ptr_D; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_var; -+ int64_t batch_stride_mean; -+ int64_t batch_stride_gamma; -+ int64_t batch_stride_beta; -+ int64_t batch_stride_C; -+ -+ typename LayoutA::Stride stride_a; -+ typename LayoutB::Stride stride_b; -+ typename LayoutScaleBias::Stride stride_var; -+ typename LayoutScaleBias::Stride stride_mean; -+ typename LayoutScaleBias::Stride stride_gamma; -+ typename LayoutScaleBias::Stride stride_beta; -+ typename LayoutC::Stride stride_c; -+ typename LayoutC::Stride stride_d; -+ -+ typename LayoutA::Stride::LongIndex lda; -+ typename LayoutB::Stride::LongIndex ldb; -+ typename LayoutScaleBias::Stride::LongIndex ld_var; -+ typename LayoutScaleBias::Stride::LongIndex ld_mean; -+ typename LayoutScaleBias::Stride::LongIndex ld_gamma; -+ typename LayoutScaleBias::Stride::LongIndex ld_beta; -+ typename LayoutC::Stride::LongIndex ldc; -+ typename LayoutC::Stride::LongIndex ldd; -+ -+ int const * ptr_gather_A_indices; -+ int const * ptr_gather_B_indices; -+ int const * ptr_scatter_D_indices; -+ -+ // -+ // Methods -+ // -+ -+ Arguments(): -+ ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr), -+ ptr_var(nullptr), ptr_mean(nullptr), -+ ptr_gamma(nullptr), ptr_beta(nullptr), -+ ptr_gather_A_indices(nullptr), -+ ptr_gather_B_indices(nullptr), -+ ptr_scatter_D_indices(nullptr) -+ {} -+ -+ /// constructs an arguments structure -+ Arguments( -+ GemmUniversalMode mode, -+ GemmCoord problem_size, -+ int batch_count, -+ typename EpilogueOutputOp::Params epilogue, -+ void const * ptr_A, -+ void const * ptr_B, -+ void const * ptr_var, -+ void const * ptr_mean, -+ void const * ptr_gamma, -+ void const * ptr_beta, -+ void const * ptr_C, -+ void * ptr_D, -+ int64_t batch_stride_A, -+ int64_t batch_stride_B, -+ int64_t batch_stride_var, -+ int64_t batch_stride_mean, -+ int64_t batch_stride_gamma, -+ int64_t batch_stride_beta, -+ int64_t batch_stride_C, -+ int64_t batch_stride_D, -+ typename LayoutA::Stride stride_a, -+ typename LayoutB::Stride stride_b, -+ typename LayoutScaleBias::Stride stride_var, -+ typename LayoutScaleBias::Stride stride_mean, -+ typename LayoutScaleBias::Stride stride_gamma, -+ typename LayoutScaleBias::Stride stride_beta, -+ typename LayoutC::Stride stride_c, -+ typename LayoutC::Stride stride_d, -+ int const *ptr_gather_A_indices = nullptr, -+ int const *ptr_gather_B_indices = nullptr, -+ int const *ptr_scatter_D_indices = nullptr) -+ : -+ UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), -+ epilogue(epilogue), -+ ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), -+ ptr_var(ptr_var), ptr_mean(ptr_mean), -+ ptr_gamma(ptr_gamma), ptr_beta(ptr_beta), -+ batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), -+ batch_stride_var(batch_stride_var), batch_stride_mean(batch_stride_mean), -+ batch_stride_gamma(batch_stride_gamma), batch_stride_beta(batch_stride_beta), -+ lda(0), ldb(0), ldc(0), ldd(0), -+ ld_var(0), ld_mean(0), -+ ld_gamma(0), ld_beta(0), -+ stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d), -+ stride_var(stride_var), stride_mean(stride_mean), -+ stride_gamma(stride_gamma), stride_beta(stride_beta), -+ ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices), -+ ptr_scatter_D_indices(ptr_scatter_D_indices) -+ { -+ CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size); -+ } -+ -+ /// constructs an arguments structure -+ Arguments( -+ GemmUniversalMode mode, -+ GemmCoord problem_size, -+ int batch_count, -+ typename EpilogueOutputOp::Params epilogue, -+ void const * ptr_A, -+ void const * ptr_B, -+ void const * ptr_var, -+ void const * ptr_mean, -+ void const * ptr_gamma, -+ void const * ptr_beta, -+ void const * ptr_C, -+ void * ptr_D, -+ int64_t batch_stride_A, -+ int64_t batch_stride_B, -+ int64_t batch_stride_var, -+ int64_t batch_stride_mean, -+ int64_t batch_stride_gamma, -+ int64_t batch_stride_beta, -+ int64_t batch_stride_C, -+ int64_t batch_stride_D, -+ typename LayoutA::Stride::LongIndex lda, -+ typename LayoutB::Stride::LongIndex ldb, -+ typename LayoutScaleBias::Stride::LongIndex ld_var, -+ typename LayoutScaleBias::Stride::LongIndex ld_mean, -+ typename LayoutScaleBias::Stride::LongIndex ld_gamma, -+ typename LayoutScaleBias::Stride::LongIndex ld_beta, -+ typename LayoutC::Stride::LongIndex ldc, -+ typename LayoutC::Stride::LongIndex ldd, -+ int const *ptr_gather_A_indices = nullptr, -+ int const *ptr_gather_B_indices = nullptr, -+ int const *ptr_scatter_D_indices = nullptr) -+ : -+ UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), -+ epilogue(epilogue), -+ ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), -+ ptr_var(ptr_var), ptr_mean(ptr_mean), -+ ptr_gamma(ptr_gamma), ptr_beta(ptr_beta), -+ batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), -+ batch_stride_var(batch_stride_var), batch_stride_mean(batch_stride_mean), -+ batch_stride_gamma(batch_stride_gamma), batch_stride_beta(batch_stride_beta), -+ lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), -+ ld_var(ld_var), ld_mean(ld_mean), -+ ld_gamma(ld_gamma), ld_beta(ld_beta), -+ ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices), -+ ptr_scatter_D_indices(ptr_scatter_D_indices) -+ { -+ stride_a = make_Coord(lda); -+ stride_b = make_Coord(ldb); -+ stride_c = make_Coord(ldc); -+ stride_d = make_Coord(ldd); -+ stride_var = make_Coord(ld_var); -+ stride_mean = make_Coord(ld_mean); -+ stride_gamma = make_Coord(ld_gamma); -+ stride_beta = make_Coord(ld_beta); -+ CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size); -+ } -+ -+ /// Returns arguments for the transposed problem -+ Arguments transposed_problem() const { -+ Arguments args(*this); -+ -+ std::swap(args.problem_size.m(), args.problem_size.n()); -+ std::swap(args.ptr_A, args.ptr_B); -+ std::swap(args.lda, args.ldb); -+ std::swap(args.stride_a, args.stride_b); -+ std::swap(args.batch_stride_A, args.batch_stride_B); -+ std::swap(args.ptr_gather_A_indices, args.ptr_gather_B_indices); -+ -+ return args; -+ } -+ }; -+ -+ -+ // -+ // Structure for precomputing values in host memory and passing to kernels -+ // -+ -+ /// Parameters structure -+ struct Params : UniversalParamsBase< -+ ThreadblockSwizzle, -+ ThreadblockShape, -+ ElementA, -+ ElementB, -+ ElementC> -+ { -+ using ParamsBase = UniversalParamsBase< -+ ThreadblockSwizzle, -+ ThreadblockShape, -+ ElementA, -+ ElementB, -+ ElementC>; -+ -+ // -+ // Data members -+ // -+ -+ typename Mma::IteratorA::Params params_A; -+ typename Mma::IteratorB::Params params_B; -+ typename Epilogue::OutputTileIterator::Params params_C; -+ typename Epilogue::OutputTileIterator::Params params_D; -+ -+ typename EpilogueOutputOp::Params output_op; -+ -+ void * ptr_A; -+ void * ptr_B; -+ void * ptr_var; -+ void * ptr_mean; -+ void * ptr_gamma; -+ void * ptr_beta; -+ void * ptr_C; -+ void * ptr_D; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_var; -+ int64_t batch_stride_mean; -+ int64_t batch_stride_gamma; -+ int64_t batch_stride_beta; -+ int64_t batch_stride_C; -+ -+ int * ptr_gather_A_indices; -+ int * ptr_gather_B_indices; -+ int * ptr_scatter_D_indices; -+ -+ // -+ // Host dispatch API -+ // -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Constructor -+ Params( -+ Arguments const &args, /// GEMM application arguments -+ int device_sms, /// Number of SMs on the device -+ int sm_occupancy) /// Kernel SM occupancy (in thread blocks) -+ : -+ ParamsBase(args, device_sms, sm_occupancy), -+ params_A(args.lda ? make_Coord_with_padding(args.lda) : args.stride_a), -+ params_B(args.ldb ? make_Coord_with_padding(args.ldb) : args.stride_b), -+ params_C(args.ldc ? make_Coord_with_padding(args.ldc) : args.stride_c), -+ params_D(args.ldd ? make_Coord_with_padding(args.ldd) : args.stride_d), -+ output_op(args.epilogue), -+ ptr_A(const_cast(args.ptr_A)), -+ ptr_B(const_cast(args.ptr_B)), -+ ptr_var(const_cast(args.ptr_var)), -+ ptr_mean(const_cast(args.ptr_mean)), -+ ptr_gamma(const_cast(args.ptr_gamma)), -+ ptr_beta(const_cast(args.ptr_beta)), -+ ptr_C(const_cast(args.ptr_C)), -+ ptr_D(args.ptr_D), -+ batch_stride_A(args.batch_stride_A), -+ batch_stride_B(args.batch_stride_B), -+ batch_stride_var(args.batch_stride_var), -+ batch_stride_mean(args.batch_stride_mean), -+ batch_stride_gamma(args.batch_stride_gamma), -+ batch_stride_beta(args.batch_stride_beta), -+ batch_stride_C(args.batch_stride_C), -+ ptr_gather_A_indices(const_cast(args.ptr_gather_A_indices)), -+ ptr_gather_B_indices(const_cast(args.ptr_gather_B_indices)), -+ ptr_scatter_D_indices(const_cast(args.ptr_scatter_D_indices)) -+ {} -+ -+ /// Lightweight update given a subset of arguments. Problem geometry is assumed -+ /// to remain the same. -+ void update(Arguments const &args) -+ { -+ ptr_A = const_cast(args.ptr_A); -+ ptr_B = const_cast(args.ptr_B); -+ ptr_var = const_cast(args.ptr_var); -+ ptr_mean = const_cast(args.ptr_mean); -+ ptr_gamma = const_cast(args.ptr_gamma); -+ ptr_beta = const_cast(args.ptr_beta); -+ ptr_C = const_cast(args.ptr_C); -+ ptr_D = args.ptr_D; -+ -+ ptr_gather_A_indices = const_cast(args.ptr_gather_A_indices); -+ ptr_gather_B_indices = const_cast(args.ptr_gather_B_indices); -+ ptr_scatter_D_indices = const_cast(args.ptr_scatter_D_indices); -+ -+ output_op = args.epilogue; -+ -+ CUTLASS_TRACE_HOST("GemmUniversal::Params::update()"); -+ } -+ }; -+ -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+public: -+ -+ // -+ // Host dispatch API -+ // -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement( -+ cutlass::gemm::GemmCoord const & problem_size) { -+ -+ CUTLASS_TRACE_HOST("GemmUniversal::can_implement()"); -+ -+ static int const kAlignmentA = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 64 -+ : Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 64 -+ : Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 64 -+ : Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ bool isAMisaligned = false; -+ bool isBMisaligned = false; -+ bool isCMisaligned = false; -+ -+ if (platform::is_same::value) { -+ isAMisaligned = problem_size.k() % kAlignmentA; -+ } else if (platform::is_same::value) { -+ isAMisaligned = problem_size.m() % kAlignmentA; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isAMisaligned = problem_size.k() % kAlignmentA; -+ } -+ -+ if (platform::is_same::value) { -+ isBMisaligned = problem_size.n() % kAlignmentB; -+ } else if (platform::is_same::value) { -+ isBMisaligned = problem_size.k() % kAlignmentB; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isBMisaligned = problem_size.k() % kAlignmentB; -+ } -+ -+ if (platform::is_same::value) { -+ isCMisaligned = problem_size.n() % kAlignmentC; -+ } else if (platform::is_same::value) { -+ isCMisaligned = problem_size.m() % kAlignmentC; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isCMisaligned = problem_size.n() % kAlignmentC; -+ } -+ -+ if (isAMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (isBMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (isCMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ CUTLASS_TRACE_HOST(" returning kSuccess"); -+ -+ return Status::kSuccess; -+ } -+ -+ static Status can_implement(Arguments const &args) { -+ return can_implement(args.problem_size); -+ } -+ -+public: -+ -+ // -+ // Device-only API -+ // -+ -+ // Factory invocation -+ CUTLASS_DEVICE -+ static void invoke( -+ Params const ¶ms, -+ SharedStorage &shared_storage) -+ { -+ GemmLayernormMainloopFusion op; -+ op(params, shared_storage); -+ } -+ -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ int offset_k = 0; -+ int problem_size_k = params.problem_size.k(); -+ -+ ElementA *ptr_A = static_cast(params.ptr_A); -+ ElementB *ptr_B = static_cast(params.ptr_B); -+ -+ // -+ // Fetch pointers based on mode. -+ // -+ if (params.mode == GemmUniversalMode::kGemm || -+ params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ -+ if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { -+ -+ problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; -+ } -+ -+ offset_k = threadblock_tile_offset.k() * params.gemm_k_size; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; -+ ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; -+ ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; -+ } -+ -+ __syncthreads(); -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ offset_k, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ offset_k, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ }; -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.params_A, -+ ptr_A, -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_A, -+ params.ptr_gather_A_indices); -+ -+ typename Mma::IteratorB iterator_B( -+ params.params_B, -+ ptr_B, -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_B, -+ params.ptr_gather_B_indices); -+ -+ // Construct iterators to A var/mean vector -+ typename Mma::IteratorVarMean iterator_var_mean( -+ params.problem_size.m(), -+ static_cast(params.ptr_var), -+ static_cast(params.ptr_mean), -+ thread_idx, -+ MatrixCoord(0, (threadblock_tile_offset.m() * Mma::Shape::kM)) -+ ); -+ -+ // Construct iterators to A scale/bias vector -+ typename Mma::IteratorGammaBeta iterator_gamma_beta( -+ problem_size_k, -+ static_cast(params.ptr_gamma), -+ static_cast(params.ptr_beta), -+ thread_idx, -+ MatrixCoord( -+ 0, (threadblock_tile_offset.k() * Mma::Shape::kK) -+ ) -+ ); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); -+ -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma( -+ gemm_k_iterations, -+ accumulators, -+ iterator_A, -+ iterator_B, -+ iterator_var_mean, -+ iterator_gamma_beta, -+ accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ ElementC *ptr_C = static_cast(params.ptr_C); -+ ElementC *ptr_D = static_cast(params.ptr_D); -+ -+ // -+ // Fetch pointers based on mode. -+ // -+ -+ // Construct the semaphore. -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ if (params.mode == GemmUniversalMode::kGemm) { -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. -+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } -+ } -+ else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_C += threadblock_tile_offset.k() * params.batch_stride_C; -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_C = static_cast(params.ptr_C)[threadblock_tile_offset.k()]; -+ ptr_D = static_cast(params.ptr_D)[threadblock_tile_offset.k()]; -+ } -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.params_C, -+ ptr_C, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset, -+ params.ptr_scatter_D_indices -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ ptr_D, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset, -+ params.ptr_scatter_D_indices -+ ); -+ -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_offset.k()) { -+ iterator_C = iterator_D; -+ } -+ -+ semaphore.wait(threadblock_tile_offset.k()); -+ } -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue( -+ output_op, -+ iterator_D, -+ accumulators, -+ iterator_C); -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_offset.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_params.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_params.h -new file mode 100755 -index 0000000..046ad75 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_params.h -@@ -0,0 +1,199 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/semaphore.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" -+#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h" -+ -+#include "cutlass/trace.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct GemmParams { -+ -+ // -+ // Type definitions -+ // -+ using Index = int32_t; -+ using LongIndex = int64_t; -+ -+ using MmaIteratorParams = typename cutlass::transform::threadblock::PredicatedTileAccessIteratorParams; -+ using EpilogueIteratorParams = typename cutlass::epilogue::threadblock::PredicatedTileIteratorParams; -+ -+ // -+ // Data members -+ // -+ -+ cutlass::gemm::GemmCoord problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ int swizzle_log_tile; -+ -+ // Data members for Mma::Iterator::Params -+ MmaIteratorParams params_itr_a; -+ MmaIteratorParams params_itr_b; -+ -+ // Data member for Epilogue::OutputTileIterator::Params -+ EpilogueIteratorParams params_itr_c; -+ EpilogueIteratorParams params_itr_d; -+ -+ -+ GemmUniversalMode mode; -+ int batch_count; -+ int gemm_k_size; -+ -+ void * ptr_A; -+ void * ptr_B; -+ void * ptr_C; -+ void * ptr_D; -+ -+ LongIndex lda; -+ LongIndex ldb; -+ LongIndex ldc; -+ LongIndex ldd; -+ -+ LongIndex batch_stride_A; -+ LongIndex batch_stride_B; -+ LongIndex batch_stride_C; -+ LongIndex batch_stride_D; -+ -+ int *semaphore; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ GemmParams() {} -+ -+ CUTLASS_HOST_DEVICE -+ GemmParams( -+ cutlass::gemm::GemmCoord problem_size_, -+ cutlass::gemm::GemmCoord grid_tiled_shape_, -+ int swizzle_log_tile_, -+ GemmUniversalMode mode_, -+ int batch_count_, -+ int gemm_k_size_, -+ void const * ptr_A_, -+ void const * ptr_B_, -+ void const * ptr_C_, -+ void * ptr_D_, -+ LongIndex lda_, -+ LongIndex ldb_, -+ LongIndex ldc_, -+ LongIndex ldd_, -+ int64_t batch_stride_A_, -+ int64_t batch_stride_B_, -+ int64_t batch_stride_C_, -+ int64_t batch_stride_D_, -+ MmaIteratorParams const & params_itr_a_, -+ MmaIteratorParams const & params_itr_b_, -+ EpilogueIteratorParams const & params_itr_c_, -+ EpilogueIteratorParams const & params_itr_d_, -+ void *workspace_ = nullptr) : -+ problem_size(problem_size_), -+ grid_tiled_shape(grid_tiled_shape_), -+ swizzle_log_tile(swizzle_log_tile_), -+ mode(mode_), -+ batch_count(batch_count_), -+ gemm_k_size(gemm_k_size_), -+ ptr_A(const_cast(ptr_A_)), -+ ptr_B(const_cast(ptr_B_)), -+ ptr_C(const_cast(ptr_C_)), -+ ptr_D(ptr_D_), -+ lda(lda_), -+ ldb(ldb_), -+ ldc(ldc_), -+ ldd(ldd_), -+ batch_stride_A(batch_stride_A_), -+ batch_stride_B(batch_stride_B_), -+ batch_stride_C(batch_stride_C_), -+ batch_stride_D(batch_stride_D_), -+ params_itr_a(params_itr_a_), -+ params_itr_b(params_itr_b_), -+ params_itr_c(params_itr_c_), -+ params_itr_d(params_itr_d_), -+ semaphore(static_cast(workspace_) -+ ) { } -+ -+ -+ CUTLASS_HOST_DEVICE -+ void update( -+ void const * ptr_A_, -+ void const * ptr_B_, -+ void const * ptr_C_, -+ void * ptr_D_, -+ int64_t batch_stride_A_, -+ int64_t batch_stride_B_, -+ int64_t batch_stride_C_, -+ int64_t batch_stride_D_, -+ void *workspace_ = nullptr) { -+ -+ ptr_A = const_cast(ptr_A_); -+ ptr_B = const_cast(ptr_B_); -+ ptr_C = const_cast(ptr_C_); -+ ptr_D = ptr_D_; -+ -+ batch_stride_A = batch_stride_A_; -+ batch_stride_B = batch_stride_B_; -+ batch_stride_C = batch_stride_C_; -+ batch_stride_D = batch_stride_D_; -+ -+ -+ semaphore = static_cast(workspace_); -+ CUTLASS_TRACE_HOST("GemmParams::update()"); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_pipelined.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_pipelined.h -new file mode 100644 -index 0000000..df450d0 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_pipelined.h -@@ -0,0 +1,158 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/array.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+__global__ void GemmPipelined( -+ cutlass::gemm::GemmCoord problem_size, -+ cutlass::gemm::GemmCoord grid_tiled_shape, -+ typename Mma::IteratorA::Params params_A, -+ typename Mma::IteratorA::TensorRef ref_A, -+ typename Mma::IteratorB::Params params_B, -+ typename Mma::IteratorB::TensorRef ref_B, -+ typename Epilogue::Params params_epilogue -+ ) { -+ -+ // Shared storage needed by threadblock-scoped matrix multiply-accumulate -+ __shared__ union { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ } shared_storage; -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ int swizzle_log_tile = ThreadblockSwizzle().get_log_tile(grid_tiled_shape); -+ -+ cutlass::gemm::GemmCoord tb_tile_offset = threadblock_swizzle.get_tile_offset(swizzle_log_tile); -+ -+ if (grid_tiled_shape.m() <= tb_tile_offset.m() || -+ grid_tiled_shape.n() <= tb_tile_offset.n()) { -+ -+ return; -+ } -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ tb_tile_offset.m() * Mma::Shape::kM, -+ tb_tile_offset.k() -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ tb_tile_offset.k(), -+ tb_tile_offset.n() * Mma::Shape::kN -+ }; -+ -+ // Compute position within threadblock -+ int tb_thread_id = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params_A, -+ ref_A.data(), -+ {problem_size.m(), problem_size.k()}, -+ tb_thread_id, -+ tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B( -+ params_B, -+ ref_B.data(), -+ {problem_size.k(), problem_size.n()}, -+ tb_thread_id, -+ tb_offset_B); -+ -+ int warp_id = canonical_warp_idx(); -+ int lane_id = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, tb_thread_id, warp_id, lane_id); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma(problem_size, accumulators, iterator_A, iterator_B, accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ Epilogue epilogue( -+ params_epilogue, -+ shared_storage.epilogue, -+ tb_thread_id, -+ warp_id, -+ lane_id); -+ -+ tb_tile_offset = threadblock_swizzle.get_tile_offset(swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ tb_tile_offset.m() * Mma::Shape::kM, -+ tb_tile_offset.n() * Mma::Shape::kN -+ ); -+ -+ // run efficient epilogue -+ epilogue({problem_size.m(), problem_size.n()}, accumulators, threadblock_offset); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_planar_complex.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_planar_complex.h -new file mode 100644 -index 0000000..7dbc592 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_planar_complex.h -@@ -0,0 +1,715 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/semaphore.h" -+#include "cutlass/gemm/kernel/params_universal_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_ ///! Threadblock swizzling function -+> -+struct GemmPlanarComplex { -+public: -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using ElementC = typename Epilogue::OutputTileIterator::Element; -+ using LayoutC = typename Epilogue::OutputTileIterator::Layout; -+ using Operator = typename Mma::Operator; -+ using ArchTag = typename Mma::ArchTag; -+ -+ static ComplexTransform const kTransformA = Mma::kTransformA; -+ static ComplexTransform const kTransformB = Mma::kTransformB; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ /// Split-K preserves splits that are 128b aligned -+ static int const kSplitKAlignment = const_max( -+ 128 / sizeof_bits::value, -+ 128 / sizeof_bits::value); -+ -+ // -+ // Additional types needed for reflection -+ // -+ -+ using ElementAccumulator = typename Mma::Policy::Operator::ElementC; -+ using OperatorClass = typename Mma::Operator::OperatorClass; -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename Mma::Operator::Shape; -+ using InstructionShape = typename Mma::Policy::Operator::Shape; -+ -+ static int const kStages = Mma::kStages; -+ -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ // -+ // Arguments structure -+ // -+ -+ /// Argument structure -+ struct Arguments : UniversalArgumentsBase -+ { -+ // -+ // Data members -+ // -+ -+ typename EpilogueOutputOp::Params epilogue; -+ -+ void const * ptr_A_real; -+ void const * ptr_A_imag; -+ -+ void const * ptr_B_real; -+ void const * ptr_B_imag; -+ -+ void const * ptr_C_real; -+ void const * ptr_C_imag; -+ -+ void * ptr_D_real; -+ void * ptr_D_imag; -+ -+ typename LayoutA::Stride::Index lda_real; -+ typename LayoutA::Stride::Index lda_imag; -+ typename LayoutB::Stride::Index ldb_real; -+ typename LayoutB::Stride::Index ldb_imag; -+ typename LayoutC::Stride::Index ldc_real; -+ typename LayoutC::Stride::Index ldc_imag; -+ typename LayoutC::Stride::Index ldd_real; -+ typename LayoutC::Stride::Index ldd_imag; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_A_imag; -+ int64_t batch_stride_B; -+ int64_t batch_stride_B_imag; -+ int64_t batch_stride_C; -+ int64_t batch_stride_C_imag; -+ int64_t batch_stride_D_imag; -+ -+ // -+ // Methods -+ // -+ -+ Arguments() : -+ ptr_A_real(nullptr), -+ ptr_A_imag(nullptr), -+ ptr_B_real(nullptr), -+ ptr_B_imag(nullptr), -+ ptr_C_real(nullptr), -+ ptr_C_imag(nullptr), -+ ptr_D_real(nullptr), -+ ptr_D_imag(nullptr) -+ {} -+ -+ /// constructs an arguments structure -+ Arguments( -+ GemmUniversalMode mode, -+ GemmCoord problem_size, -+ int batch_count, -+ typename EpilogueOutputOp::Params epilogue, -+ void const * ptr_A_real, -+ void const * ptr_A_imag, -+ void const * ptr_B_real, -+ void const * ptr_B_imag, -+ void const * ptr_C_real, -+ void const * ptr_C_imag, -+ void * ptr_D_real, -+ void * ptr_D_imag, -+ typename LayoutA::Stride::Index lda_real, -+ typename LayoutA::Stride::Index lda_imag, -+ typename LayoutB::Stride::Index ldb_real, -+ typename LayoutB::Stride::Index ldb_imag, -+ typename LayoutC::Stride::Index ldc_real, -+ typename LayoutC::Stride::Index ldc_imag, -+ typename LayoutC::Stride::Index ldd_real, -+ typename LayoutC::Stride::Index ldd_imag, -+ int64_t batch_stride_A = 0, -+ int64_t batch_stride_A_imag = 0, -+ int64_t batch_stride_B = 0, -+ int64_t batch_stride_B_imag = 0, -+ int64_t batch_stride_C = 0, -+ int64_t batch_stride_C_imag = 0, -+ int64_t batch_stride_D = 0, -+ int64_t batch_stride_D_imag = 0) -+ : -+ UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), -+ epilogue(epilogue), -+ ptr_A_real(ptr_A_real), -+ ptr_A_imag(ptr_A_imag), -+ ptr_B_real(ptr_B_real), -+ ptr_B_imag(ptr_B_imag), -+ ptr_C_real(ptr_C_real), -+ ptr_C_imag(ptr_C_imag), -+ ptr_D_real(ptr_D_real), -+ ptr_D_imag(ptr_D_imag), -+ lda_real(lda_real), -+ lda_imag(lda_imag), -+ ldb_real(ldb_real), -+ ldb_imag(ldb_imag), -+ ldc_real(ldc_real), -+ ldc_imag(ldc_imag), -+ ldd_real(ldd_real), -+ ldd_imag(ldd_imag), -+ batch_stride_A(batch_stride_A), -+ batch_stride_A_imag(batch_stride_A_imag), -+ batch_stride_B(batch_stride_B), -+ batch_stride_B_imag(batch_stride_B_imag), -+ batch_stride_C(batch_stride_C), -+ batch_stride_C_imag(batch_stride_C_imag), -+ batch_stride_D_imag(batch_stride_D_imag) -+ {} -+ -+ /// Returns arguments for the transposed problem -+ Arguments transposed_problem() const { -+ Arguments args(*this); -+ -+ std::swap(args.problem_size.m(), args.problem_size.n()); -+ std::swap(args.ptr_A_real, args.ptr_B_real); -+ std::swap(args.ptr_A_imag, args.ptr_B_imag); -+ std::swap(args.lda_real, args.ldb_real); -+ std::swap(args.lda_imag, args.ldb_imag); -+ std::swap(args.batch_stride_A, args.batch_stride_B); -+ std::swap(args.batch_stride_A_imag, args.batch_stride_B_imag); -+ -+ return args; -+ } -+ }; -+ -+ -+ // -+ // Structure for precomputing values in host memory and passing to kernels -+ // -+ -+ /// Parameters structure -+ struct Params : UniversalParamsBase< -+ ThreadblockSwizzle, -+ ThreadblockShape, -+ ElementA, -+ ElementB, -+ ElementC> -+ { -+ using ParamsBase = UniversalParamsBase< -+ ThreadblockSwizzle, -+ ThreadblockShape, -+ ElementA, -+ ElementB, -+ ElementC>; -+ -+ // -+ // Data members -+ // -+ -+ typename Mma::IteratorA::Params params_A_real; -+ typename Mma::IteratorA::Params params_A_imag; -+ typename Mma::IteratorB::Params params_B_real; -+ typename Mma::IteratorB::Params params_B_imag; -+ typename Epilogue::OutputTileIterator::Params params_C_real; -+ typename Epilogue::OutputTileIterator::Params params_C_imag; -+ typename Epilogue::OutputTileIterator::Params params_D_real; -+ typename Epilogue::OutputTileIterator::Params params_D_imag; -+ -+ typename EpilogueOutputOp::Params output_op; -+ -+ void * ptr_A_real; -+ void * ptr_A_imag; -+ void * ptr_B_real; -+ void * ptr_B_imag; -+ void * ptr_C_real; -+ void * ptr_C_imag; -+ void * ptr_D_real; -+ void * ptr_D_imag; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C; -+ -+ int64_t batch_stride_A_imag; -+ int64_t batch_stride_B_imag; -+ int64_t batch_stride_C_imag; -+ int64_t batch_stride_D_imag; -+ -+ // -+ // Host dispatch API -+ // -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Constructor -+ Params( -+ Arguments const &args, /// GEMM application arguments -+ int device_sms, /// Number of SMs on the device -+ int sm_occupancy) /// Kernel SM occupancy (in thread blocks) -+ : -+ ParamsBase(args, device_sms, sm_occupancy), -+ params_A_real(args.lda_real), -+ params_A_imag(args.lda_imag), -+ params_B_real(args.ldb_real), -+ params_B_imag(args.ldb_imag), -+ params_C_real(args.ldc_real), -+ params_C_imag(args.ldc_imag), -+ params_D_real(args.ldd_real), -+ params_D_imag(args.ldd_imag), -+ output_op(args.epilogue), -+ ptr_A_real(const_cast(args.ptr_A_real)), -+ ptr_A_imag(const_cast(args.ptr_A_imag)), -+ ptr_B_real(const_cast(args.ptr_B_real)), -+ ptr_B_imag(const_cast(args.ptr_B_imag)), -+ ptr_C_real(const_cast(args.ptr_C_real)), -+ ptr_C_imag(const_cast(args.ptr_C_imag)), -+ ptr_D_real(args.ptr_D_real), -+ ptr_D_imag(args.ptr_D_imag), -+ batch_stride_A(args.batch_stride_A), -+ batch_stride_B(args.batch_stride_B), -+ batch_stride_C(args.batch_stride_C), -+ batch_stride_A_imag(args.batch_stride_A_imag), -+ batch_stride_B_imag(args.batch_stride_B_imag), -+ batch_stride_C_imag(args.batch_stride_C_imag), -+ batch_stride_D_imag(args.batch_stride_D_imag) -+ {} -+ -+ /// Returns the workspace size (in bytes) needed for this problem geometry -+ size_t get_workspace_size() const -+ { -+ size_t workspace_bytes = ParamsBase::get_workspace_size(); -+ if (this->mode == GemmUniversalMode::kGemmSplitKParallel) -+ { -+ // Double the size returned by the base class because we need to -+ // accumulate two ElementC components -+ workspace_bytes *= 2; -+ } -+ -+ return workspace_bytes; -+ } -+ -+ /// Lightweight update given a subset of arguments. Problem geometry is assumed -+ /// to remain the same. -+ void update(Arguments const &args) -+ { -+ ptr_A_real = const_cast(args.ptr_A_real); -+ ptr_A_imag = const_cast(args.ptr_A_imag); -+ -+ ptr_B_real = const_cast(args.ptr_B_real); -+ ptr_B_imag = const_cast(args.ptr_B_imag); -+ -+ ptr_C_real = const_cast(args.ptr_C_real); -+ ptr_C_imag = const_cast(args.ptr_C_imag); -+ -+ ptr_D_real = const_cast(args.ptr_D_real); -+ ptr_D_imag = const_cast(args.ptr_D_imag); -+ -+ output_op = args.epilogue; -+ } -+ }; -+ -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+public: -+ -+ // -+ // Host dispatch API -+ // -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement(Arguments const &args) -+ { -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ bool isAMisaligned = false; -+ bool isBMisaligned = false; -+ bool isCMisaligned = false; -+ -+ if (platform::is_same::value) { -+ isAMisaligned = args.problem_size.k() % kAlignmentA; -+ } else if (platform::is_same::value) { -+ isAMisaligned = args.problem_size.m() % kAlignmentA; -+ } -+ -+ if (platform::is_same::value) { -+ isBMisaligned = args.problem_size.n() % kAlignmentB; -+ } else if (platform::is_same::value) { -+ isBMisaligned = args.problem_size.k() % kAlignmentB; -+ } -+ -+ if (platform::is_same::value) { -+ isCMisaligned = args.problem_size.n() % kAlignmentC; -+ } else if (platform::is_same::value) { -+ isCMisaligned = args.problem_size.m() % kAlignmentC; -+ } -+ -+ if (isAMisaligned || isBMisaligned || isCMisaligned) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+public: -+ -+ // -+ // Device-only API -+ // -+ -+ // Factory invocation -+ CUTLASS_DEVICE -+ static void invoke( -+ Params const ¶ms, -+ SharedStorage &shared_storage) -+ { -+ GemmPlanarComplex op; -+ op(params, shared_storage); -+ } -+ -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ int offset_k = 0; -+ int problem_size_k = params.problem_size.k(); -+ -+ ElementA *ptr_A_real = static_cast(params.ptr_A_real); -+ ElementA *ptr_A_imag = static_cast(params.ptr_A_imag); -+ -+ ElementB *ptr_B_real = static_cast(params.ptr_B_real); -+ ElementB *ptr_B_imag = static_cast(params.ptr_B_imag); -+ -+ // -+ // Fetch pointers based on mode. -+ // -+ if (params.mode == GemmUniversalMode::kGemm || -+ params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ -+ if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { -+ -+ problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; -+ } -+ -+ offset_k = threadblock_tile_offset.k() * params.gemm_k_size; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_A_real += int64_t(threadblock_tile_offset.k()) * params.batch_stride_A; -+ ptr_A_imag += int64_t(threadblock_tile_offset.k()) * params.batch_stride_A_imag; -+ ptr_B_real += int64_t(threadblock_tile_offset.k()) * params.batch_stride_B; -+ ptr_B_imag += int64_t(threadblock_tile_offset.k()) * params.batch_stride_B_imag; -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_A_real = static_cast(params.ptr_A_real)[threadblock_tile_offset.k()]; -+ ptr_A_imag = static_cast(params.ptr_A_imag)[threadblock_tile_offset.k()]; -+ ptr_B_real = static_cast(params.ptr_B_real)[threadblock_tile_offset.k()]; -+ ptr_B_imag = static_cast(params.ptr_B_imag)[threadblock_tile_offset.k()]; -+ } -+ -+ __syncthreads(); -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ offset_k, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ offset_k, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ }; -+ -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A_real( -+ params.params_A_real, -+ ptr_A_real, -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_A); -+ -+ typename Mma::IteratorA iterator_A_imag( -+ params.params_A_imag, -+ ptr_A_imag, -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B_real( -+ params.params_B_real, -+ ptr_B_real, -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_B); -+ -+ typename Mma::IteratorB iterator_B_imag( -+ params.params_B_imag, -+ ptr_B_imag, -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_B); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = canonical_warp_idx(); -+ -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma( -+ gemm_k_iterations, -+ accumulators, -+ iterator_A_real, -+ iterator_A_imag, -+ iterator_B_real, -+ iterator_B_imag, -+ accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ ElementC *ptr_C_real = static_cast(params.ptr_C_real); -+ ElementC *ptr_C_imag = static_cast(params.ptr_C_imag); -+ ElementC *ptr_D_real = static_cast(params.ptr_D_real); -+ ElementC *ptr_D_imag = static_cast(params.ptr_D_imag); -+ -+ // -+ // Fetch pointers based on mode. -+ // -+ -+ // Construct the semaphore. -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ if (params.mode == GemmUniversalMode::kGemm) { -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. -+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } -+ } -+ else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ ptr_D_real += threadblock_tile_offset.k() * params.batch_stride_D; -+ ptr_D_imag += threadblock_tile_offset.k() * params.batch_stride_D_imag; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_C_real += int64_t(threadblock_tile_offset.k()) * params.batch_stride_C; -+ ptr_C_imag += int64_t(threadblock_tile_offset.k()) * params.batch_stride_C_imag; -+ ptr_D_real += int64_t(threadblock_tile_offset.k()) * params.batch_stride_D; -+ ptr_D_imag += int64_t(threadblock_tile_offset.k()) * params.batch_stride_D_imag; -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_C_real = static_cast(params.ptr_C_real)[threadblock_tile_offset.k()]; -+ ptr_C_imag = static_cast(params.ptr_C_imag)[threadblock_tile_offset.k()]; -+ ptr_D_real = static_cast(params.ptr_D_real)[threadblock_tile_offset.k()]; -+ ptr_D_imag = static_cast(params.ptr_D_imag)[threadblock_tile_offset.k()]; -+ } -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C_real( -+ params.params_C_real, -+ ptr_C_real, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ typename Epilogue::OutputTileIterator iterator_C_imag( -+ params.params_C_imag, -+ ptr_C_imag, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D_real( -+ params.params_D_real, -+ ptr_D_real, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ typename Epilogue::OutputTileIterator iterator_D_imag( -+ params.params_D_imag, -+ ptr_D_imag, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // -+ // Construct epilogue -+ // -+ -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_offset.k()) { -+ iterator_C_real = iterator_D_real; -+ iterator_C_imag = iterator_D_imag; -+ } -+ -+ semaphore.wait(threadblock_tile_offset.k()); -+ -+ __threadfence(); -+ } -+ -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue( -+ output_op, -+ iterator_D_real, -+ iterator_D_imag, -+ accumulators, -+ iterator_C_real, -+ iterator_C_imag); -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_offset.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_planar_complex_array.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_planar_complex_array.h -new file mode 100644 -index 0000000..21b8011 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_planar_complex_array.h -@@ -0,0 +1,618 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/semaphore.h" -+#include "cutlass/gemm/kernel/params_universal_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_ ///! Threadblock swizzling function -+> -+struct GemmPlanarComplexArray { -+public: -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using ElementC = typename Epilogue::OutputTileIterator::Element; -+ using LayoutC = typename Epilogue::OutputTileIterator::Layout; -+ using Operator = typename Mma::Operator; -+ using ArchTag = typename Mma::ArchTag; -+ -+ static ComplexTransform const kTransformA = Mma::kTransformA; -+ static ComplexTransform const kTransformB = Mma::kTransformB; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ /// Split-K preserves splits that are 128b aligned -+ static int const kSplitKAlignment = const_max( -+ 128 / sizeof_bits::value, -+ 128 / sizeof_bits::value); -+ -+ // -+ // Additional types needed for reflection -+ // -+ -+ using ElementAccumulator = typename Mma::Policy::Operator::ElementC; -+ using OperatorClass = typename Mma::Operator::OperatorClass; -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename Mma::Operator::Shape; -+ using InstructionShape = typename Mma::Policy::Operator::Shape; -+ -+ static int const kStages = Mma::kStages; -+ -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ // -+ // Arguments structure -+ // -+ -+ /// Argument structure -+ struct Arguments : UniversalArgumentsBase -+ { -+ // -+ // Data members -+ // -+ -+ typename EpilogueOutputOp::Params epilogue; -+ -+ int const *ptr_M; -+ int const *ptr_N; -+ int const *ptr_K; -+ -+ void const * const * ptr_A_real; -+ void const * const * ptr_A_imag; -+ -+ void const * const * ptr_B_real; -+ void const * const * ptr_B_imag; -+ -+ void const * const * ptr_C_real; -+ void const * const * ptr_C_imag; -+ -+ void * const * ptr_D_real; -+ void * const * ptr_D_imag; -+ -+ typename LayoutA::Stride::Index lda_real; -+ typename LayoutA::Stride::Index lda_imag; -+ typename LayoutB::Stride::Index ldb_real; -+ typename LayoutB::Stride::Index ldb_imag; -+ typename LayoutC::Stride::Index ldc_real; -+ typename LayoutC::Stride::Index ldc_imag; -+ typename LayoutC::Stride::Index ldd_real; -+ typename LayoutC::Stride::Index ldd_imag; -+ -+ // -+ // Methods -+ // -+ -+ Arguments(): -+ ptr_M(nullptr), -+ ptr_N(nullptr), -+ ptr_K(nullptr), -+ ptr_A_real(nullptr), -+ ptr_A_imag(nullptr), -+ ptr_B_real(nullptr), -+ ptr_B_imag(nullptr), -+ ptr_C_real(nullptr), -+ ptr_C_imag(nullptr), -+ ptr_D_real(nullptr), -+ ptr_D_imag(nullptr) -+ {} -+ -+ /// constructs an arguments structure -+ Arguments( -+ GemmCoord problem_size, -+ int batch_count, -+ typename EpilogueOutputOp::Params epilogue, -+ int const *ptr_M, -+ int const *ptr_N, -+ int const *ptr_K, -+ void const * const * ptr_A_real, -+ void const * const * ptr_A_imag, -+ void const * const * ptr_B_real, -+ void const * const * ptr_B_imag, -+ void const * const * ptr_C_real, -+ void const * const * ptr_C_imag, -+ void * const * ptr_D_real, -+ void * const * ptr_D_imag, -+ typename LayoutA::Stride::Index lda_real, -+ typename LayoutA::Stride::Index lda_imag, -+ typename LayoutB::Stride::Index ldb_real, -+ typename LayoutB::Stride::Index ldb_imag, -+ typename LayoutC::Stride::Index ldc_real, -+ typename LayoutC::Stride::Index ldc_imag, -+ typename LayoutC::Stride::Index ldd_real, -+ typename LayoutC::Stride::Index ldd_imag) -+ : -+ UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), -+ epilogue(epilogue), -+ ptr_M(ptr_M), -+ ptr_N(ptr_N), -+ ptr_K(ptr_K), -+ ptr_A_real(ptr_A_real), -+ ptr_A_imag(ptr_A_imag), -+ ptr_B_real(ptr_B_real), -+ ptr_B_imag(ptr_B_imag), -+ ptr_C_real(ptr_C_real), -+ ptr_C_imag(ptr_C_imag), -+ ptr_D_real(ptr_D_real), -+ ptr_D_imag(ptr_D_imag), -+ lda_real(lda_real), -+ lda_imag(lda_imag), -+ ldb_real(ldb_real), -+ ldb_imag(ldb_imag), -+ ldc_real(ldc_real), -+ ldc_imag(ldc_imag), -+ ldd_real(ldd_real), -+ ldd_imag(ldd_imag) -+ {} -+ -+ /// Returns arguments for the transposed problem -+ Arguments transposed_problem() const { -+ Arguments args(*this); -+ -+ std::swap(args.problem_size.m(), args.problem_size.n()); -+ std::swap(args.ptr_M, args.ptr_N); -+ std::swap(args.ptr_A_real, args.ptr_B_real); -+ std::swap(args.ptr_A_imag, args.ptr_B_imag); -+ std::swap(args.lda_real, args.ldb_real); -+ std::swap(args.lda_imag, args.ldb_imag); -+ -+ return args; -+ } -+ }; -+ -+ -+ // -+ // Structure for precomputing values in host memory and passing to kernels -+ // -+ -+ /// Parameters structure -+ struct Params : UniversalParamsBase< -+ ThreadblockSwizzle, -+ ThreadblockShape, -+ ElementA, -+ ElementB, -+ ElementC> -+ { -+ using ParamsBase = UniversalParamsBase< -+ ThreadblockSwizzle, -+ ThreadblockShape, -+ ElementA, -+ ElementB, -+ ElementC>; -+ -+ // -+ // Data members -+ // -+ -+ typename Mma::IteratorA::Params params_A_real; -+ typename Mma::IteratorA::Params params_A_imag; -+ typename Mma::IteratorB::Params params_B_real; -+ typename Mma::IteratorB::Params params_B_imag; -+ typename Epilogue::OutputTileIterator::Params params_C_real; -+ typename Epilogue::OutputTileIterator::Params params_C_imag; -+ typename Epilogue::OutputTileIterator::Params params_D_real; -+ typename Epilogue::OutputTileIterator::Params params_D_imag; -+ -+ typename EpilogueOutputOp::Params output_op; -+ -+ int const *ptr_M; -+ int const *ptr_N; -+ int const *ptr_K; -+ -+ void const * const * ptr_A_real; -+ void const * const * ptr_A_imag; -+ void const * const * ptr_B_real; -+ void const * const * ptr_B_imag; -+ void const * const * ptr_C_real; -+ void const * const * ptr_C_imag; -+ void * const * ptr_D_real; -+ void * const * ptr_D_imag; -+ -+ // -+ // Host dispatch API -+ // -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Constructor -+ Params( -+ Arguments const &args, /// GEMM application arguments -+ int device_sms, /// Number of SMs on the device -+ int sm_occupancy) /// Kernel SM occupancy (in thread blocks) -+ : -+ ParamsBase(args, device_sms, sm_occupancy), -+ ptr_M(args.ptr_M), -+ ptr_N(args.ptr_N), -+ ptr_K(args.ptr_K), -+ params_A_real(args.lda_real), -+ params_A_imag(args.lda_imag), -+ params_B_real(args.ldb_real), -+ params_B_imag(args.ldb_imag), -+ params_C_real(args.ldc_real), -+ params_C_imag(args.ldc_imag), -+ params_D_real(args.ldd_real), -+ params_D_imag(args.ldd_imag), -+ output_op(args.epilogue), -+ ptr_A_real(args.ptr_A_real), -+ ptr_A_imag(args.ptr_A_imag), -+ ptr_B_real(args.ptr_B_real), -+ ptr_B_imag(args.ptr_B_imag), -+ ptr_C_real(args.ptr_C_real), -+ ptr_C_imag(args.ptr_C_imag), -+ ptr_D_real(args.ptr_D_real), -+ ptr_D_imag(args.ptr_D_imag) -+ {} -+ -+ /// Lightweight update given a subset of arguments. Problem geometry is assumed -+ /// to remain the same. -+ void update(Arguments const &args) -+ { -+ ptr_M = args.ptr_M; -+ ptr_N = args.ptr_N; -+ ptr_K = args.ptr_K; -+ -+ ptr_A_real = args.ptr_A_real; -+ ptr_A_imag = args.ptr_A_imag; -+ -+ ptr_B_real = args.ptr_B_real; -+ ptr_B_imag = args.ptr_B_imag; -+ -+ ptr_C_real = args.ptr_C_real; -+ ptr_C_imag = args.ptr_C_imag; -+ -+ ptr_D_real = args.ptr_D_real; -+ ptr_D_imag = args.ptr_D_imag; -+ -+ output_op = args.epilogue; -+ } -+ }; -+ -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+public: -+ -+ // -+ // Host dispatch API -+ // -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement(Arguments const &args) { -+ -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ bool isAMisaligned = false; -+ bool isBMisaligned = false; -+ bool isCMisaligned = false; -+ -+ if (platform::is_same::value) { -+ isAMisaligned = args.problem_size.k() % kAlignmentA; -+ } else if (platform::is_same::value) { -+ isAMisaligned = args.problem_size.m() % kAlignmentA; -+ } -+ -+ if (platform::is_same::value) { -+ isBMisaligned = args.problem_size.n() % kAlignmentB; -+ } else if (platform::is_same::value) { -+ isBMisaligned = args.problem_size.k() % kAlignmentB; -+ } -+ -+ if (platform::is_same::value) { -+ isCMisaligned = args.problem_size.n() % kAlignmentC; -+ } else if (platform::is_same::value) { -+ isCMisaligned = args.problem_size.m() % kAlignmentC; -+ } -+ -+ if (isAMisaligned || isBMisaligned || isCMisaligned) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ -+public: -+ -+ // -+ // Device-only API -+ // -+ -+ // Factory invocation -+ CUTLASS_DEVICE -+ static void invoke( -+ Params const ¶ms, -+ SharedStorage &shared_storage) -+ { -+ GemmPlanarComplexArray op; -+ op(params, shared_storage); -+ } -+ -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ int batch_idx = threadblock_tile_offset.k(); -+ -+ int problem_size_m = params.problem_size.m(); -+ int problem_size_n = params.problem_size.n(); -+ int problem_size_k = params.problem_size.k(); -+ -+ ElementA *ptr_A_real = static_cast(const_cast(params.ptr_A_real[batch_idx])); -+ ElementA *ptr_A_imag = static_cast(const_cast(params.ptr_A_imag[batch_idx])); -+ -+ ElementB *ptr_B_real = static_cast(const_cast(params.ptr_B_real[batch_idx])); -+ ElementB *ptr_B_imag = static_cast(const_cast(params.ptr_B_imag[batch_idx])); -+ -+ // -+ // If pointers for problem sizes are specified, these are loaded from global memory -+ // -+ -+ if (params.ptr_M) { -+ problem_size_m = params.ptr_M[batch_idx]; -+ } -+ -+ if (params.ptr_N) { -+ problem_size_n = params.ptr_N[batch_idx]; -+ } -+ -+ if (params.ptr_K) { -+ problem_size_k = params.ptr_K[batch_idx]; -+ } -+ -+ int const kBlockCountM = (problem_size_m + Mma::Shape::kM - 1) / Mma::Shape::kM; -+ int const kBlockCountN = (problem_size_n + Mma::Shape::kN - 1) / Mma::Shape::kN; -+ -+ int const kGemmKIterations = (problem_size_k + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // -+ // Each threadblock loops over the logical problem size which the kernel may have discovered -+ // after the grid is launched. -+ // -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int block_m = threadblock_tile_offset.m(); -+ block_m < kBlockCountM; -+ block_m += params.grid_tiled_shape.m()) { -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int block_n = threadblock_tile_offset.n(); -+ block_n < kBlockCountN; -+ block_n += params.grid_tiled_shape.n()) { -+ -+ // -+ // Compute indices within threadblock and warp. -+ // -+ int thread_idx = threadIdx.x; -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = canonical_warp_idx(); -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Proceed with regular GEMM logic. -+ // -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ block_m * Mma::Shape::kM, 0}; -+ cutlass::MatrixCoord tb_offset_B{ 0, block_n * Mma::Shape::kN }; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A_real( -+ params.params_A_real, -+ ptr_A_real, -+ {problem_size_m, problem_size_k}, -+ thread_idx, -+ tb_offset_A); -+ -+ typename Mma::IteratorA iterator_A_imag( -+ params.params_A_imag, -+ ptr_A_imag, -+ {problem_size_m, problem_size_k}, -+ thread_idx, -+ tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B_real( -+ params.params_B_real, -+ ptr_B_real, -+ {problem_size_k, problem_size_n}, -+ thread_idx, -+ tb_offset_B); -+ -+ typename Mma::IteratorB iterator_B_imag( -+ params.params_B_imag, -+ ptr_B_imag, -+ {problem_size_k, problem_size_n}, -+ thread_idx, -+ tb_offset_B); -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma( -+ kGemmKIterations, -+ accumulators, -+ iterator_A_real, -+ iterator_A_imag, -+ iterator_B_real, -+ iterator_B_imag, -+ accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ block_m * Mma::Shape::kM, -+ block_n * Mma::Shape::kN -+ ); -+ -+ ElementC *ptr_C_real = static_cast(const_cast(params.ptr_C_real[batch_idx])); -+ ElementC *ptr_C_imag = static_cast(const_cast(params.ptr_C_imag[batch_idx])); -+ ElementC *ptr_D_real = static_cast(params.ptr_D_real[batch_idx]); -+ ElementC *ptr_D_imag = static_cast(params.ptr_D_imag[batch_idx]); -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C_real( -+ params.params_C_real, -+ ptr_C_real, -+ {problem_size_m, problem_size_n}, -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ typename Epilogue::OutputTileIterator iterator_C_imag( -+ params.params_C_imag, -+ ptr_C_imag, -+ {problem_size_m, problem_size_n}, -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D_real( -+ params.params_D_real, -+ ptr_D_real, -+ {problem_size_m, problem_size_n}, -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ typename Epilogue::OutputTileIterator iterator_D_imag( -+ params.params_D_imag, -+ ptr_D_imag, -+ {problem_size_m, problem_size_n}, -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // -+ // Construct epilogue -+ // -+ -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue( -+ output_op, -+ iterator_D_real, -+ iterator_D_imag, -+ accumulators, -+ iterator_C_real, -+ iterator_C_imag); -+ -+ -+ } // for block_n -+ } // for block_m -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_splitk_parallel.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_splitk_parallel.h -new file mode 100644 -index 0000000..ffb928c ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_splitk_parallel.h -@@ -0,0 +1,253 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for GEMM performing a reduction over K partitions in parallel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_ ///! Threadblock swizzling function -+> -+struct GemmSplitKParallel { -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using OutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ static int const kAlignmentK = Mma::Operator::Shape::kK; -+ -+ /// Parameters structure -+ struct Params { -+ cutlass::gemm::GemmCoord problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ int swizzle_log_tile; -+ typename Mma::IteratorA::Params params_A; -+ typename Mma::IteratorA::TensorRef ref_A; -+ typename Mma::IteratorB::Params params_B; -+ typename Mma::IteratorB::TensorRef ref_B; -+ typename Epilogue::OutputTileIterator::Params params_D; -+ typename Epilogue::OutputTileIterator::TensorRef ref_D; -+ typename OutputOp::Params output_op; -+ int64_t splitk_slice_stride; -+ int gemm_k_size; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): swizzle_log_tile(0) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ cutlass::gemm::GemmCoord const & problem_size, -+ cutlass::gemm::GemmCoord const & grid_tiled_shape, -+ typename Mma::IteratorA::TensorRef ref_A, -+ typename Mma::IteratorB::TensorRef ref_B, -+ typename Epilogue::OutputTileIterator::TensorRef ref_D, -+ typename OutputOp::Params output_op, -+ int64_t splitk_slice_stride -+ ): -+ problem_size(problem_size), -+ grid_tiled_shape(grid_tiled_shape), -+ swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), -+ params_A(ref_A.layout()), -+ ref_A(ref_A), -+ params_B(ref_B.layout()), -+ ref_B(ref_B), -+ params_D(ref_D.layout()), -+ ref_D(ref_D), -+ output_op(output_op), -+ splitk_slice_stride(splitk_slice_stride) { -+ -+ int full_gemm_k_iterations = problem_size.k() / Mma::Shape::kK; -+ int gemm_k_iterations = full_gemm_k_iterations / grid_tiled_shape.k(); -+ -+ gemm_k_size = gemm_k_iterations * Mma::Shape::kK; -+ } -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ GemmSplitKParallel() { } -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.k() * params.gemm_k_size, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ threadblock_tile_offset.k() * params.gemm_k_size, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ }; -+ -+ // Problem size is a function of threadblock index in the K dimension -+ int problem_size_k; -+ if (threadblock_tile_offset.k() + 1 == params.grid_tiled_shape.k()) { -+ problem_size_k = params.problem_size.k(); -+ } -+ else { -+ problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; -+ } -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.params_A, -+ params.ref_A.data(), -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B( -+ params.params_B, -+ params.ref_B.data(), -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_B); -+ -+ int warp_idx = threadIdx.x / 32; -+ int lane_idx = threadIdx.x % 32; -+ -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ OutputOp output_op(params.output_op); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ ); -+ -+ // Tile iterator writing to output tile -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ params.ref_D.data(), -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ iterator_D.add_pointer_offset(params.splitk_slice_stride * threadblock_tile_offset.k()); -+ -+ // Execute the epilogue -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Run efficient epilogue -+ epilogue(output_op, iterator_D, accumulators, iterator_D); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_transpose_operands.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_transpose_operands.h -new file mode 100644 -index 0000000..dec9935 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_transpose_operands.h -@@ -0,0 +1,124 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and -+ batched array variants. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA_, -+ typename LayoutA_, -+ ComplexTransform TransformA, -+ int AlignmentA, -+ typename ElementB_, -+ typename LayoutB_, -+ ComplexTransform TransformB, -+ int AlignmentB, -+ typename LayoutC_, -+ bool Transpose -+> -+struct MapArguments { -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ static ComplexTransform const kTransformA = TransformA; -+ static int const kAlignmentA = AlignmentA; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ static ComplexTransform const kTransformB = TransformB; -+ static int const kAlignmentB = AlignmentB; -+ using LayoutC = LayoutC_; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA_, -+ typename LayoutA_, -+ ComplexTransform TransformA, -+ int AlignmentA, -+ typename ElementB_, -+ typename LayoutB_, -+ ComplexTransform TransformB, -+ int AlignmentB, -+ typename LayoutC_ -+> -+struct MapArguments< -+ ElementA_, -+ LayoutA_, -+ TransformA, -+ AlignmentA, -+ ElementB_, -+ LayoutB_, -+ TransformB, -+ AlignmentB, -+ LayoutC_, -+ true -+> { -+ using ElementA = ElementB_; -+ using LayoutA = typename layout::LayoutTranspose::type; -+ static ComplexTransform const kTransformA = TransformB; -+ static int const kAlignmentA = AlignmentB; -+ using ElementB = ElementA_; -+ using LayoutB = typename layout::LayoutTranspose::type; -+ static ComplexTransform const kTransformB = TransformA; -+ static int const kAlignmentB = AlignmentA; -+ using LayoutC = typename layout::LayoutTranspose::type; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} -+} -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_universal.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_universal.h -new file mode 100644 -index 0000000..fc62c01 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_universal.h -@@ -0,0 +1,694 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/arch/arch.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/semaphore.h" -+#include "cutlass/gemm/kernel/gemm_universal.hpp" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/params_universal_base.h" -+ -+#include "cutlass/trace.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_ ///! Threadblock swizzling function -+> -+class GemmUniversal< -+ Mma_, -+ Epilogue_, -+ ThreadblockSwizzle_, -+ void, -+ // 3.x kernels use the first template argument to define the ProblemShape tuple -+ // We use this invariant to SFINAE dispatch against either the 2.x API or the 3.x API -+ std::enable_if_t::value> -+> { -+public: -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using ElementC = typename Epilogue::OutputTileIterator::Element; -+ using LayoutC = typename Epilogue::OutputTileIterator::Layout; -+ -+ static ComplexTransform const kTransformA = Mma::kTransformA; -+ static ComplexTransform const kTransformB = Mma::kTransformB; -+ using Operator = typename Mma::Operator; -+ -+ using OperatorClass = typename Mma::Operator::OperatorClass; -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename Mma::Operator::Shape; -+ using InstructionShape = typename Mma::Policy::Operator::InstructionShape; -+ using ArchTag = typename Mma::ArchTag; -+ -+ static int const kStages = Mma::kStages; -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ /// Split-K preserves splits that are 128b aligned -+ static int const kSplitKAlignment = const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value); -+ -+ // -+ // Structures -+ // -+ -+ /// Argument structure -+ struct Arguments : UniversalArgumentsBase -+ { -+ // -+ // Data members -+ // -+ -+ typename EpilogueOutputOp::Params epilogue; -+ -+ void const * ptr_A; -+ void const * ptr_B; -+ void const * ptr_C; -+ void * ptr_D; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C; -+ -+ typename LayoutA::Stride stride_a; -+ typename LayoutB::Stride stride_b; -+ typename LayoutC::Stride stride_c; -+ typename LayoutC::Stride stride_d; -+ -+ typename LayoutA::Stride::LongIndex lda; -+ typename LayoutB::Stride::LongIndex ldb; -+ typename LayoutC::Stride::LongIndex ldc; -+ typename LayoutC::Stride::LongIndex ldd; -+ -+ int const * ptr_gather_A_indices; -+ int const * ptr_gather_B_indices; -+ int const * ptr_scatter_D_indices; -+ -+ // -+ // Methods -+ // -+ -+ Arguments(): -+ ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr), -+ ptr_gather_A_indices(nullptr), -+ ptr_gather_B_indices(nullptr), -+ ptr_scatter_D_indices(nullptr) -+ {} -+ -+ /// constructs an arguments structure -+ Arguments( -+ GemmUniversalMode mode, -+ GemmCoord problem_size, -+ int batch_count, -+ typename EpilogueOutputOp::Params epilogue, -+ void const * ptr_A, -+ void const * ptr_B, -+ void const * ptr_C, -+ void * ptr_D, -+ int64_t batch_stride_A, -+ int64_t batch_stride_B, -+ int64_t batch_stride_C, -+ int64_t batch_stride_D, -+ typename LayoutA::Stride stride_a, -+ typename LayoutB::Stride stride_b, -+ typename LayoutC::Stride stride_c, -+ typename LayoutC::Stride stride_d, -+ int const *ptr_gather_A_indices = nullptr, -+ int const *ptr_gather_B_indices = nullptr, -+ int const *ptr_scatter_D_indices = nullptr) -+ : -+ UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), -+ epilogue(epilogue), -+ ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), -+ batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), -+ stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d), -+ ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices), -+ ptr_scatter_D_indices(ptr_scatter_D_indices) -+ { -+ lda = 0; -+ ldb = 0; -+ ldc = 0; -+ ldd = 0; -+ CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size); -+ } -+ -+ /// constructs an arguments structure -+ Arguments( -+ GemmUniversalMode mode, -+ GemmCoord problem_size, -+ int batch_count, -+ typename EpilogueOutputOp::Params epilogue, -+ void const * ptr_A, -+ void const * ptr_B, -+ void const * ptr_C, -+ void * ptr_D, -+ int64_t batch_stride_A, -+ int64_t batch_stride_B, -+ int64_t batch_stride_C, -+ int64_t batch_stride_D, -+ typename LayoutA::Stride::LongIndex lda, -+ typename LayoutB::Stride::LongIndex ldb, -+ typename LayoutC::Stride::LongIndex ldc, -+ typename LayoutC::Stride::LongIndex ldd, -+ int const *ptr_gather_A_indices = nullptr, -+ int const *ptr_gather_B_indices = nullptr, -+ int const *ptr_scatter_D_indices = nullptr -+ ): -+ UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), -+ epilogue(epilogue), -+ ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), -+ batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), -+ lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), -+ ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices), -+ ptr_scatter_D_indices(ptr_scatter_D_indices) -+ { -+ stride_a = make_Coord(lda); -+ stride_b = make_Coord(ldb); -+ stride_c = make_Coord(ldc); -+ stride_d = make_Coord(ldd); -+ CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size); -+ } -+ -+ /// Returns arguments for the transposed problem -+ Arguments transposed_problem() const -+ { -+ Arguments args(*this); -+ -+ std::swap(args.problem_size.m(), args.problem_size.n()); -+ std::swap(args.ptr_A, args.ptr_B); -+ std::swap(args.lda, args.ldb); -+ std::swap(args.stride_a, args.stride_b); -+ std::swap(args.batch_stride_A, args.batch_stride_B); -+ std::swap(args.ptr_gather_A_indices, args.ptr_gather_B_indices); -+ -+ return args; -+ } -+ }; -+ -+ -+ // -+ // Structure for precomputing values in host memory and passing to kernels -+ // -+ -+ /// Parameters structure -+ struct Params : UniversalParamsBase< -+ ThreadblockSwizzle, -+ ThreadblockShape, -+ ElementA, -+ ElementB, -+ ElementC> -+ { -+ using ParamsBase = UniversalParamsBase< -+ ThreadblockSwizzle, -+ ThreadblockShape, -+ ElementA, -+ ElementB, -+ ElementC>; -+ -+ // -+ // Data members -+ // -+ -+ typename Mma::IteratorA::Params params_A; -+ typename Mma::IteratorB::Params params_B; -+ typename Epilogue::OutputTileIterator::Params params_C; -+ typename Epilogue::OutputTileIterator::Params params_D; -+ -+ typename EpilogueOutputOp::Params output_op; -+ -+ void * ptr_A; -+ void * ptr_B; -+ void * ptr_C; -+ void * ptr_D; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C; -+ -+ int * ptr_gather_A_indices; -+ int * ptr_gather_B_indices; -+ int * ptr_scatter_D_indices; -+ -+ // -+ // Host dispatch API -+ // -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Constructor -+ Params( -+ Arguments const &args, /// GEMM application arguments -+ int device_sms, /// Number of SMs on the device -+ int sm_occupancy) /// Kernel SM occupancy (in thread blocks) -+ : -+ ParamsBase(args, device_sms, sm_occupancy), -+ params_A(args.lda ? make_Coord_with_padding(args.lda) : args.stride_a), -+ params_B(args.ldb ? make_Coord_with_padding(args.ldb) : args.stride_b), -+ params_C(args.ldc ? make_Coord_with_padding(args.ldc) : args.stride_c), -+ params_D(args.ldd ? make_Coord_with_padding(args.ldd) : args.stride_d), -+ output_op(args.epilogue), -+ ptr_A(const_cast(args.ptr_A)), -+ ptr_B(const_cast(args.ptr_B)), -+ ptr_C(const_cast(args.ptr_C)), -+ ptr_D(args.ptr_D), -+ batch_stride_A(args.batch_stride_A), -+ batch_stride_B(args.batch_stride_B), -+ batch_stride_C(args.batch_stride_C), -+ ptr_gather_A_indices(const_cast(args.ptr_gather_A_indices)), -+ ptr_gather_B_indices(const_cast(args.ptr_gather_B_indices)), -+ ptr_scatter_D_indices(const_cast(args.ptr_scatter_D_indices)) -+ {} -+ -+ /// Lightweight update given a subset of arguments. Problem geometry is assumed -+ /// to remain the same. -+ void update(Arguments const &args) -+ { -+ CUTLASS_TRACE_HOST("GemmUniversal::Params::update()"); -+ -+ // Update input/output pointers -+ ptr_A = const_cast(args.ptr_A); -+ ptr_B = const_cast(args.ptr_B); -+ ptr_C = const_cast(args.ptr_C); -+ ptr_D = args.ptr_D; -+ -+ ptr_gather_A_indices = const_cast(args.ptr_gather_A_indices); -+ ptr_gather_B_indices = const_cast(args.ptr_gather_B_indices); -+ ptr_scatter_D_indices = const_cast(args.ptr_scatter_D_indices); -+ -+ output_op = args.epilogue; -+ } -+ }; -+ -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+ -+public: -+ -+ // -+ // Host dispatch API -+ // -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement( -+ cutlass::gemm::GemmCoord const & problem_size) -+ { -+ CUTLASS_TRACE_HOST("GemmUniversal::can_implement()"); -+ -+ static int const kAlignmentA = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 64 -+ : Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 64 -+ : Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 64 -+ : Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ bool isAMisaligned = false; -+ bool isBMisaligned = false; -+ bool isCMisaligned = false; -+ -+ if (platform::is_same::value) { -+ isAMisaligned = problem_size.k() % kAlignmentA; -+ } else if (platform::is_same::value) { -+ isAMisaligned = problem_size.m() % kAlignmentA; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isAMisaligned = problem_size.k() % kAlignmentA; -+ } -+ -+ if (platform::is_same::value) { -+ isBMisaligned = problem_size.n() % kAlignmentB; -+ } else if (platform::is_same::value) { -+ isBMisaligned = problem_size.k() % kAlignmentB; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isBMisaligned = problem_size.k() % kAlignmentB; -+ } -+ -+ if (platform::is_same::value) { -+ isCMisaligned = problem_size.n() % kAlignmentC; -+ } else if (platform::is_same::value) { -+ isCMisaligned = problem_size.m() % kAlignmentC; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isCMisaligned = problem_size.n() % kAlignmentC; -+ } -+ -+ if (isAMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (isBMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (isCMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ CUTLASS_TRACE_HOST(" returning kSuccess"); -+ -+ return Status::kSuccess; -+ } -+ -+ static Status can_implement(Arguments const &args) { -+ return can_implement(args.problem_size); -+ } -+ -+ -+public: -+ -+ // -+ // Device-only API -+ // -+ -+ // Factory invocation -+ CUTLASS_DEVICE -+ static void invoke( -+ Params const ¶ms, -+ SharedStorage &shared_storage) -+ { -+ GemmUniversal op; -+ op(params, shared_storage); -+ } -+ -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()( -+ Params const ¶ms, -+ SharedStorage &shared_storage) -+ { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ int offset_k = 0; -+ int problem_size_k = params.problem_size.k(); -+ -+ ElementA *ptr_A = static_cast(params.ptr_A); -+ ElementB *ptr_B = static_cast(params.ptr_B); -+ -+ // -+ // Fetch pointers based on mode. -+ // -+ if (params.mode == GemmUniversalMode::kGemm || -+ params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ -+ if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { -+ -+ problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; -+ } -+ -+ offset_k = threadblock_tile_offset.k() * params.gemm_k_size; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; -+ ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; -+ ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; -+ } -+ -+ __syncthreads(); -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ offset_k, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ offset_k, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ }; -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.params_A, -+ ptr_A, -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_A, -+ params.ptr_gather_A_indices); -+ -+ typename Mma::IteratorB iterator_B( -+ params.params_B, -+ ptr_B, -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_B, -+ params.ptr_gather_B_indices); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = canonical_warp_idx(); -+ -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma( -+ gemm_k_iterations, -+ accumulators, -+ iterator_A, -+ iterator_B, -+ accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ ElementC *ptr_C = static_cast(params.ptr_C); -+ ElementC *ptr_D = static_cast(params.ptr_D); -+ -+ // -+ // Fetch pointers based on mode. -+ // -+ -+ // Construct the semaphore. -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ if (params.mode == GemmUniversalMode::kGemm) { -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. -+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } -+ } -+ else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_C += threadblock_tile_offset.k() * params.batch_stride_C; -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_C = static_cast(params.ptr_C)[threadblock_tile_offset.k()]; -+ ptr_D = static_cast(params.ptr_D)[threadblock_tile_offset.k()]; -+ } -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.params_C, -+ ptr_C, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset, -+ params.ptr_scatter_D_indices -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ ptr_D, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset, -+ params.ptr_scatter_D_indices -+ ); -+ -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_offset.k()) { -+ iterator_C = iterator_D; -+ } -+ -+ semaphore.wait(threadblock_tile_offset.k()); -+ } -+ -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue( -+ output_op, -+ iterator_D, -+ accumulators, -+ iterator_C); -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_offset.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_universal.hpp b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_universal.hpp -new file mode 100644 -index 0000000..cdac6ca ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_universal.hpp -@@ -0,0 +1,72 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass::gemm::kernel { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/* -+ * Stateless universal device GEMM kernel type that treats GEMM as -+ * a composition of a collective mainloop and a collective epilogue. -+ * -+ * Supports both the 2.x and 3.x APIs based on whether the first type is -+ * a cute::tuple<> or not. -+ * 2.x API implementation: cutlass/gemm/kernel/gemm_universal.h -+ * 3.x API implementation: cutlass/gemm/kernel/gemm_*.hpp -+ * -+ * In the following declaration, the name preceding the 'Or' refers to -+ * 3.x API type argument order, and the name succeeding the 'Or' refers to -+ * 2.x API type argument order. Template arguments without two names -+ * belong to the 3.x API only. -+**/ -+template < -+ class ProblemShapeOrThreadblockMma_, // (m, n, k) or (m, n, k, l) -+ class CollectiveMainloopOrEpilogue_, -+ class CollectiveEpilogueOrThreadblockSwizzle_, -+ class GridSwizzle_ = void, -+ class Enable = void -+> -+class GemmUniversal; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass::gemm::kernel -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#include "cutlass/gemm/kernel/sm70_gemm.hpp" -+#include "cutlass/gemm/kernel/sm90_gemm_tma.hpp" -+#include "cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp" -+#include "cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_persistent.hpp" -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_universal_streamk.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_universal_streamk.h -new file mode 100644 -index 0000000..27da66f ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_universal_streamk.h -@@ -0,0 +1,1249 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/barrier.h" -+#include "cutlass/block_striped.h" -+ -+#include "cutlass/trace.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_ ///! Threadblock mapping function -+> -+struct GemmUniversalStreamk { -+public: -+ -+ -+ // -+ // Types and constants -+ // -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using ElementC = typename Epilogue::OutputTileIterator::Element; -+ using LayoutC = typename Epilogue::OutputTileIterator::Layout; -+ -+ /// The per-thread tile of raw accumulators -+ using AccumulatorTile = typename Mma::FragmentC; -+ -+ static ComplexTransform const kTransformA = Mma::kTransformA; -+ static ComplexTransform const kTransformB = Mma::kTransformB; -+ using Operator = typename Mma::Operator; -+ -+ using OperatorClass = typename Mma::Operator::OperatorClass; -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename Mma::Operator::Shape; -+ using InstructionShape = typename Mma::Policy::Operator::InstructionShape; -+ using ArchTag = typename Mma::ArchTag; -+ -+ static int const kStages = Mma::kStages; -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ /// Workspace bytes per thread block -+ static size_t const kWorkspaceBytesPerBlock = -+ __NV_STD_MAX( -+ kThreadCount * sizeof(AccumulatorTile), -+ Epilogue::kWorkspaceBytesPerBlock); -+ -+ /// Block-striped reduction utility -+ using BlockStripedReduceT = BlockStripedReduce; -+ -+ -+ -+ // -+ // Structures -+ // -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmUniversalMode mode; -+ GemmCoord problem_size; -+ int batch_count; // Either (mode == GemmUniversalMode::kBatched) the batch count, or (mode == GemmUniversalMode::kGemm) the tile-splitting factor -+ -+ typename EpilogueOutputOp::Params epilogue; -+ -+ void const * ptr_A; -+ void const * ptr_B; -+ void const * ptr_C; -+ void * ptr_D; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C; -+ int64_t batch_stride_D; -+ -+ typename LayoutA::Stride stride_a; -+ typename LayoutB::Stride stride_b; -+ typename LayoutC::Stride stride_c; -+ typename LayoutC::Stride stride_d; -+ -+ typename LayoutA::Stride::LongIndex lda; -+ typename LayoutB::Stride::LongIndex ldb; -+ typename LayoutC::Stride::LongIndex ldc; -+ typename LayoutC::Stride::LongIndex ldd; -+ -+ int avail_sms; /// The number of SMs that StreamK dispatch heuristics will attempt to load-balance across (-1 defaults to device width, 1 implies classic data-parallel scheduling) -+ -+ -+ // -+ // Methods -+ // -+ -+ /// Default Constructor -+ Arguments(): -+ mode(GemmUniversalMode::kGemm), -+ batch_count(1), -+ ptr_A(nullptr), -+ ptr_B(nullptr), -+ ptr_C(nullptr), -+ ptr_D(nullptr), -+ avail_sms(-1) -+ {} -+ -+ /// Constructor -+ Arguments( -+ GemmUniversalMode mode, -+ GemmCoord problem_size, -+ int batch_split, /// Either (mode == GemmUniversalMode::kBatched) the batch count, or (mode == GemmUniversalMode::kGemm) the tile-splitting factor (1 defaults to StreamK, >1 emulates Split-K) -+ typename EpilogueOutputOp::Params epilogue, -+ void const * ptr_A, -+ void const * ptr_B, -+ void const * ptr_C, -+ void * ptr_D, -+ int64_t batch_stride_A, -+ int64_t batch_stride_B, -+ int64_t batch_stride_C, -+ int64_t batch_stride_D, -+ typename LayoutA::Stride stride_a, -+ typename LayoutB::Stride stride_b, -+ typename LayoutC::Stride stride_c, -+ typename LayoutC::Stride stride_d, -+ int avail_sms = -1 /// The number of SMs that StreamK dispatch heuristics will attempt to load-balance across (-1 defaults to device width, 1 implies classic data-parallel scheduling) -+ ): -+ mode(mode), -+ problem_size(problem_size), -+ batch_count(batch_split), -+ epilogue(epilogue), -+ ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), -+ batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D), -+ stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d), avail_sms(avail_sms) -+ { -+ CUTLASS_TRACE_HOST("GemmUniversalStreamk::Arguments::Arguments() - problem_size: " << problem_size); -+ } -+ -+ /// Constructor -+ Arguments( -+ GemmUniversalMode mode, -+ GemmCoord problem_size, -+ int batch_split, /// Either (mode == GemmUniversalMode::kBatched) the batch count, or (mode == GemmUniversalMode::kGemm) the tile-splitting factor (1 defaults to StreamK, >1 emulates Split-K) -+ typename EpilogueOutputOp::Params epilogue, -+ void const * ptr_A, -+ void const * ptr_B, -+ void const * ptr_C, -+ void * ptr_D, -+ int64_t batch_stride_A, -+ int64_t batch_stride_B, -+ int64_t batch_stride_C, -+ int64_t batch_stride_D, -+ typename LayoutA::Stride::LongIndex lda, -+ typename LayoutB::Stride::LongIndex ldb, -+ typename LayoutC::Stride::LongIndex ldc, -+ typename LayoutC::Stride::LongIndex ldd, -+ int avail_sms = -1 /// The number of SMs that StreamK dispatch heuristics will attempt to load-balance across (-1 defaults to device width, 1 implies classic data-parallel scheduling) -+ ): -+ mode(mode), -+ problem_size(problem_size), -+ batch_count(batch_split), -+ epilogue(epilogue), -+ ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), -+ batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D), -+ lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), avail_sms(avail_sms) -+ { -+ stride_a = make_Coord(lda); -+ stride_b = make_Coord(ldb); -+ stride_c = make_Coord(ldc); -+ stride_d = make_Coord(ldd); -+ CUTLASS_TRACE_HOST("GemmUniversalStreamk::Arguments::Arguments() - problem_size: " << problem_size); -+ } -+ -+ /// Returns arguments for the transposed problem -+ Arguments transposed_problem() const -+ { -+ Arguments args(*this); -+ -+ std::swap(args.problem_size.m(), args.problem_size.n()); -+ std::swap(args.ptr_A, args.ptr_B); -+ std::swap(args.lda, args.ldb); -+ std::swap(args.stride_a, args.stride_b); -+ std::swap(args.batch_stride_A, args.batch_stride_B); -+ -+ return args; -+ } -+ }; -+ -+ -+ /// Parameters structure -+ struct Params -+ { -+ public: -+ -+ // -+ // Data members -+ // -+ -+ void * ptr_A; -+ void * ptr_B; -+ -+ typename Mma::IteratorA::Params params_A; -+ typename Mma::IteratorB::Params params_B; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ -+ GemmUniversalMode mode; -+ -+ ThreadblockSwizzle block_mapping; -+ -+ bool quick_dp; -+ -+ void *barrier_workspace; -+ void *partials_workspace; -+ -+ typename EpilogueOutputOp::Params output_op; -+ -+ void * ptr_D; -+ void * ptr_C; -+ -+ typename Epilogue::OutputTileIterator::Params params_D; -+ typename Epilogue::OutputTileIterator::Params params_C; -+ -+ int64_t batch_stride_D; -+ int64_t batch_stride_C; -+ -+ -+ protected: -+ -+ // -+ // Host-only dispatch-utilities -+ // -+ -+ /// Pad the given allocation size up to the nearest cache line -+ static size_t cacheline_align_up(size_t size) -+ { -+ static const int CACHELINE_SIZE = 128; -+ return (size + CACHELINE_SIZE - 1) / CACHELINE_SIZE * CACHELINE_SIZE; -+ } -+ -+ /// Get the workspace size needed for barrier -+ size_t get_barrier_workspace_size() const -+ { -+ // For atomic reduction, each SK-block needs a synchronization flag. For parallel reduction, -+ // each reduction block needs its own synchronization flag. -+ int sk_blocks = block_mapping.sk_regions() * block_mapping.sk_blocks_per_region(); -+ int num_flags = fast_max(sk_blocks, block_mapping.reduction_blocks); -+ -+ return cacheline_align_up(sizeof(typename Barrier::T) * num_flags); -+ } -+ -+ /// Get the workspace size needed for intermediate partial sums -+ size_t get_partials_workspace_size() const -+ { -+ int sk_blocks = block_mapping.sk_regions() * block_mapping.sk_blocks_per_region(); -+ return cacheline_align_up(kWorkspaceBytesPerBlock * sk_blocks); -+ } -+ -+ -+ public: -+ -+ // -+ // Host dispatch API -+ // -+ -+ /// Default constructor -+ Params() = default; -+ -+ -+ /// Constructor -+ Params( -+ Arguments const &args, /// GEMM application arguments -+ int device_sms, /// Number of SMs on the device -+ int sm_occupancy) /// Kernel SM occupancy (in thread blocks) -+ : -+ params_A(args.lda ? make_Coord_with_padding(args.lda) : args.stride_a), -+ params_B(args.ldb ? make_Coord_with_padding(args.ldb) : args.stride_b), -+ params_C(args.ldc ? make_Coord_with_padding(args.ldc) : args.stride_c), -+ params_D(args.ldd ? make_Coord_with_padding(args.ldd) : args.stride_d), -+ output_op(args.epilogue), -+ mode(args.mode), -+ ptr_A(const_cast(args.ptr_A)), -+ ptr_B(const_cast(args.ptr_B)), -+ ptr_C(const_cast(args.ptr_C)), -+ ptr_D(args.ptr_D), -+ batch_stride_A(args.batch_stride_A), -+ batch_stride_B(args.batch_stride_B), -+ batch_stride_C(args.batch_stride_C), -+ batch_stride_D(args.batch_stride_D), -+ barrier_workspace(nullptr), -+ partials_workspace(nullptr) -+ { -+ // Number of SMs to make available for StreamK decomposition -+ int avail_sms = (args.avail_sms == -1) ? -+ device_sms : -+ fast_min(args.avail_sms, device_sms); -+ -+ // Initialize the block mapping structure -+ block_mapping = ThreadblockSwizzle( -+ typename ThreadblockSwizzle::template KernelTraits(), -+ args.mode, -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.batch_count, -+ sm_occupancy, -+ device_sms, -+ avail_sms); -+ -+ quick_dp = -+ (block_mapping.sk_waves == 0) && -+ (mode == GemmUniversalMode::kGemm) && -+ !block_mapping.cohort_raster && -+ !EpilogueOutputOp(output_op).is_source_needed(); -+ -+ } -+ -+ -+ /// Returns the workspace size (in bytes) needed for these parameters -+ size_t get_workspace_size() const -+ { -+ return -+ get_barrier_workspace_size() + -+ get_partials_workspace_size(); -+ } -+ -+ -+ /// Assign and initialize the specified workspace buffer. Assumes -+ /// the memory allocated to workspace is at least as large as get_workspace_size(). -+ Status init_workspace( -+ void *workspace, -+ cudaStream_t stream = nullptr) -+ { -+ uint8_t *ptr = static_cast(workspace); -+ -+ // Establish partials workspace -+ partials_workspace = nullptr; -+ size_t partials_workspace_bytes = get_partials_workspace_size(); -+ if (partials_workspace_bytes > 0) -+ { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ partials_workspace = ptr; -+ ptr += partials_workspace_bytes; -+ } -+ -+ // Establish barrier workspace -+ barrier_workspace = nullptr; -+ size_t barrier_workspace_bytes = get_barrier_workspace_size(); -+ if (barrier_workspace_bytes > 0) -+ { -+ if (!workspace) { -+ return Status::kErrorWorkspaceNull; -+ } -+ barrier_workspace = ptr; -+ ptr += barrier_workspace_bytes; -+ } -+ -+ // Zero-initialize barrier workspace -+ if (barrier_workspace) -+ { -+ size_t barrier_workspace_bytes = get_barrier_workspace_size(); -+ -+ CUTLASS_TRACE_HOST(" Initialize " << barrier_workspace_bytes << " barrier bytes"); -+ -+ cudaError_t result = cudaMemsetAsync( -+ barrier_workspace, -+ 0, -+ barrier_workspace_bytes, -+ stream); -+ -+ if (result != cudaSuccess) { -+ CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result)); -+ return Status::kErrorInternal; -+ } -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ -+ /// Returns the GEMM volume in thread block tiles -+ cutlass::gemm::GemmCoord get_tiled_shape() const -+ { -+ return block_mapping.tiled_shape(); -+ } -+ -+ -+ /// Returns the total number of thread blocks to launch -+ int get_grid_blocks() const -+ { -+ dim3 grid_dims = get_grid_dims(); -+ return grid_dims.x * grid_dims.y * grid_dims.z; -+ } -+ -+ -+ /// Returns the grid extents in thread blocks to launch -+ dim3 get_grid_dims() const -+ { -+ return block_mapping.get_grid_dims(); -+ } -+ -+ -+ /// Lightweight update given a subset of arguments. Problem geometry is assumed -+ /// to remain the same. -+ void update(Arguments const &args) -+ { -+ CUTLASS_TRACE_HOST("GemmUniversalStreamK::Params::update()"); -+ -+ // Update input/output pointers -+ ptr_A = const_cast(args.ptr_A); -+ ptr_B = const_cast(args.ptr_B); -+ ptr_C = const_cast(args.ptr_C); -+ ptr_D = args.ptr_D; -+ -+ batch_stride_A = args.batch_stride_A; -+ batch_stride_B = args.batch_stride_B; -+ batch_stride_C = args.batch_stride_C; -+ batch_stride_D = args.batch_stride_D; -+ -+ output_op = args.epilogue; -+ } -+ -+ }; -+ -+ /// Tile work descriptor -+ struct TileWorkDesc -+ { -+ /// The linear tile index -+ int tile_idx; -+ -+ /// The location of this tile (in threadblock-tile coordinates) in the output matrix -+ cutlass::gemm::GemmCoord tiled_coord; -+ -+ // The first global-scoped MAC-iteration this threadblock will perform for this tile -+ int iter_begin; -+ -+ // The starting index in the k-domain for MAC-iterations this threadblock will perform for this tile -+ int k_begin; -+ -+ // The ending index (one-past) in the k-domain for MAC-iterations this threadblock will perform for this tile -+ int k_end; -+ -+ /// The number of remaining MAC-iterations this threadblock will perform for this tile -+ int k_iters_remaining; -+ -+ // Whether this block will perform the first iteration of this tile -+ CUTLASS_DEVICE -+ bool tile_started() -+ { -+ return (k_begin == 0); -+ } -+ -+ // Whether this block will perform the last iteration of this tile -+ CUTLASS_DEVICE -+ bool tile_finished(Params const ¶ms) -+ { -+ return (k_end == params.block_mapping.problem_size.k()); -+ } -+ }; -+ -+ -+ /// Shared memory storage structure -+ union SharedStorage -+ { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+ -+protected: -+ -+ // -+ // Data members -+ // -+ -+ /// GEMM problem parameters -+ Params const ¶ms; -+ -+ /// Shared storage reference -+ SharedStorage &shared_storage; -+ -+ /// ID within the threadblock -+ int thread_idx; -+ -+ /// ID of warp -+ int warp_idx; -+ -+ /// ID of each thread within a warp -+ int lane_idx; -+ -+ /// Threadblock scoped epilogue -+ Epilogue epilogue; -+ -+ -+public: -+ -+ // -+ // Host-only dispatch API -+ // -+ -+ /// Determines whether the GEMM problem size satisfies this kernel's -+ /// alignment requirements -+ static Status can_implement( -+ cutlass::gemm::GemmCoord const & problem_size) -+ { -+ CUTLASS_TRACE_HOST("GemmUniversalStreamk::can_implement()"); -+ -+ static int const kAlignmentA = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 64 -+ : Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 64 -+ : Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 64 -+ : Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ bool isAMisaligned = false; -+ bool isBMisaligned = false; -+ bool isCMisaligned = false; -+ -+ if (platform::is_same::value) { -+ isAMisaligned = problem_size.k() % kAlignmentA; -+ } else if (platform::is_same::value) { -+ isAMisaligned = problem_size.m() % kAlignmentA; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isAMisaligned = problem_size.k() % kAlignmentA; -+ } -+ -+ if (platform::is_same::value) { -+ isBMisaligned = problem_size.n() % kAlignmentB; -+ } else if (platform::is_same::value) { -+ isBMisaligned = problem_size.k() % kAlignmentB; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isBMisaligned = problem_size.k() % kAlignmentB; -+ } -+ -+ if (platform::is_same::value) { -+ isCMisaligned = problem_size.n() % kAlignmentC; -+ } else if (platform::is_same::value) { -+ isCMisaligned = problem_size.m() % kAlignmentC; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isCMisaligned = problem_size.n() % kAlignmentC; -+ } -+ -+ if (isAMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (isBMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (isCMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ CUTLASS_TRACE_HOST(" returning kSuccess"); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Determines whether the GEMM problem satisfies this kernel's -+ /// alignment requirements -+ static Status can_implement(Arguments const &args) { -+ return can_implement(args.problem_size); -+ } -+ -+protected: -+ -+ // -+ // Device-only utility methods -+ // -+ -+ /// Iterator for fetching tile fragments from A -+ CUTLASS_DEVICE -+ typename Mma::IteratorA init_iterator_A( -+ TileWorkDesc &tile_work, -+ GemmUniversalMode mode) -+ { -+ // The input A matrix -+ ElementA *ptr_A = static_cast(params.ptr_A); -+ -+ // Update input pointers based on batched/array mode -+ if (mode == GemmUniversalMode::kBatched) { -+ ptr_A += tile_work.tiled_coord.k() * params.batch_stride_A; -+ } -+ if (mode == GemmUniversalMode::kArray) { -+ ptr_A = static_cast(params.ptr_A)[tile_work.tiled_coord.k()]; -+ } -+ -+ int m_begin = tile_work.tiled_coord.m() * Mma::Shape::kM; -+ int m_end = params.block_mapping.problem_size.m(); -+ return Mma::IteratorA( -+ params.params_A, -+ ptr_A, -+ { m_end, tile_work.k_end }, -+ threadIdx.x, -+ { m_begin, tile_work.k_begin }); -+ -+ } -+ -+ -+ /// Iterator for fetching tile fragments from B -+ CUTLASS_DEVICE -+ typename Mma::IteratorB init_iterator_B( -+ TileWorkDesc &tile_work, -+ GemmUniversalMode mode) -+ { -+ // The input B matrix -+ ElementB *ptr_B = static_cast(params.ptr_B); -+ -+ // Update input pointers based on batched/array mode -+ if (mode == GemmUniversalMode::kBatched) { -+ ptr_B += tile_work.tiled_coord.k() * params.batch_stride_B; -+ } -+ if (mode == GemmUniversalMode::kArray) { -+ ptr_B = static_cast(params.ptr_B)[tile_work.tiled_coord.k()]; -+ } -+ -+ int n_begin = tile_work.tiled_coord.n() * Mma::Shape::kN; -+ int n_end = params.block_mapping.problem_size.n(); -+ return Mma::IteratorB( -+ params.params_B, -+ ptr_B, -+ { tile_work.k_end, n_end }, -+ threadIdx.x, -+ { tile_work.k_begin, n_begin }); -+ } -+ -+ -+ CUTLASS_DEVICE -+ void init_dp_tile_work( -+ TileWorkDesc &tile_work, -+ int tile_idx) -+ { -+ // The linear tile index -+ tile_work.tile_idx = tile_idx; -+ -+ // The first global-scoped MAC-iteration this threadblock will perform for this tile -+ tile_work.iter_begin = tile_idx * params.block_mapping.iters_per_tile(); -+ -+ // The number of MAC-iterations this threadblock will perform for this tile -+ tile_work.k_iters_remaining = params.block_mapping.iters_per_tile(); -+ -+ // The starting index in the k-domain for MAC-iterations this threadblock will perform for this tile -+ tile_work.k_begin = 0; -+ -+ // The ending index (one-past) in the k-domain for MAC-iterations this threadblock will perform for this tile -+ tile_work.k_end = params.block_mapping.problem_size.k(); -+ -+ // The location of this tile (in threadblock-tile coordinates) in the output matrix -+ tile_work.tiled_coord = params.block_mapping.get_tile_offset(tile_work.tile_idx); -+ } -+ -+ -+ CUTLASS_DEVICE -+ void init_sk_tile_work( -+ TileWorkDesc &tile_work, -+ int tile_idx, -+ int block_iter_begin, -+ int block_iter_end) -+ { -+ // The linear tile index -+ tile_work.tile_idx = tile_idx; -+ -+ // The first global-scoped MAC-iteration for this tile -+ int tile_iter_begin = tile_idx * params.block_mapping.iters_per_tile(); -+ -+ // The first global-scoped MAC-iteration this threadblock will perform for this tile -+ tile_work.iter_begin = max(block_iter_begin, tile_iter_begin); -+ -+ // The first tile-scoped MAC-iteration this threadblock will perform for this tile -+ int k_iter_begin = tile_work.iter_begin - tile_iter_begin; -+ -+ // The last (one past) tile-scoped MAC-iteration this threadblock will perform for this tile -+ int k_iter_end = block_iter_end - tile_iter_begin; -+ -+ // The number of MAC-iterations this threadblock will perform for this tile -+ tile_work.k_iters_remaining = k_iter_end - k_iter_begin; -+ -+ // The starting index in the k-domain for MAC-iterations this threadblock will perform for this tile -+ tile_work.k_begin = k_iter_begin * Mma::Shape::kK; -+ -+ // The ending index (one-past) in the k-domain for MAC-iterations this threadblock will perform for this tile -+ tile_work.k_end = min( -+ params.block_mapping.problem_size.k(), // extent of k domain -+ (k_iter_end * Mma::Shape::kK)); // extent of the threadblock's global iteration assignment -+ -+ // The location of this tile (in threadblock-tile coordinates) in the output matrix -+ tile_work.tiled_coord = params.block_mapping.get_tile_offset(tile_work.tile_idx); -+ } -+ -+ -+ /// Share accumulators with peers -+ CUTLASS_DEVICE -+ void share_accumulators( -+ AccumulatorTile const &accumulator_tile, -+ int block_idx, -+ int first_block_idx) -+ { -+ AccumulatorTile *accum_tile_workspace = reinterpret_cast(params.partials_workspace); -+ -+ int accum_tile_offset = first_block_idx * kThreadCount; -+ -+ if (block_idx == first_block_idx) -+ { -+ // First peer initializes the workspace partials -+ BlockStripedReduceT::store(accum_tile_workspace + accum_tile_offset, accumulator_tile, thread_idx); -+ } -+ else -+ { -+ // Subsequent peers atomically accumulate into the workspace partials -+ if (ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kAtomic) -+ { -+ // Non-deterministic reduction order: wait for the first peer to have initialized the partials before we add to them -+ Barrier::wait_lt(params.barrier_workspace, thread_idx, first_block_idx, 1); -+ } -+ else -+ { -+ // Turnstile reduction order: wait until the previous peer has written -+ int wait_count = block_idx - first_block_idx; -+ Barrier::wait_eq(params.barrier_workspace, thread_idx, first_block_idx, wait_count); -+ } -+ -+ // Perform reduction in workspace -+ BlockStripedReduceT::reduce(accum_tile_workspace + accum_tile_offset, accumulator_tile, thread_idx); -+ } -+ -+ // Signal our arrival -+ Barrier::arrive_inc(params.barrier_workspace, thread_idx, first_block_idx); -+ } -+ -+ -+ /// Acquire accumulators from peers -+ CUTLASS_DEVICE -+ void acquire_accumulators( -+ AccumulatorTile &accumulator_tile, -+ int block_idx, -+ int first_block_idx) -+ { -+ AccumulatorTile *accum_tile_workspace = reinterpret_cast(params.partials_workspace); -+ -+ // Wait for arrival -+ int num_carry_in = block_idx - first_block_idx; -+ Barrier::wait_eq_reset(params.barrier_workspace, thread_idx, first_block_idx, num_carry_in); -+ -+ // Load and add peer-partials accumulator tile to local accumulator tile -+ int accum_tile_offset = first_block_idx * kThreadCount; -+ BlockStripedReduceT::load_add(accumulator_tile, accum_tile_workspace + accum_tile_offset, thread_idx); -+ } -+ -+ -+ /// Perform epilogue computations and output -+ CUTLASS_DEVICE -+ void do_epilogue( -+ TileWorkDesc &tile_work, -+ AccumulatorTile &accumulator_tile) -+ { -+ ElementC *ptr_C = static_cast(params.ptr_C); -+ ElementC *ptr_D = static_cast(params.ptr_D); -+ -+ // Update pointers for batched/array mode(s) -+ if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_C += tile_work.tiled_coord.k() * params.batch_stride_C; -+ ptr_D += tile_work.tiled_coord.k() * params.batch_stride_D; -+ } -+ if (params.mode == GemmUniversalMode::kArray) { -+ ptr_C = static_cast(params.ptr_C)[tile_work.tiled_coord.k()]; -+ ptr_D = static_cast(params.ptr_D)[tile_work.tiled_coord.k()]; -+ } -+ -+ // Location of this tile in item-coords -+ MatrixCoord threadblock_item_begin( -+ tile_work.tiled_coord.m() * Mma::Shape::kM, -+ tile_work.tiled_coord.n() * Mma::Shape::kN -+ ); -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.params_C, -+ ptr_C, -+ params.block_mapping.problem_size.mn(), -+ thread_idx, -+ threadblock_item_begin); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ ptr_D, -+ params.block_mapping.problem_size.mn(), -+ thread_idx, -+ threadblock_item_begin); -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue.unified( -+ EpilogueOutputOp(params.output_op), -+ iterator_D, -+ accumulator_tile, -+ iterator_C); -+ } -+ -+ -+ CUTLASS_DEVICE -+ void separate_reduction(int reduce_idx) -+ { -+ int peer_idx_begin, peer_idx_last, reduce_tile_idx, reduce_fragment_idx; -+ -+ // Reduce by sk-tile (every tile contributed to by one or more blocks) -+ reduce_tile_idx = reduce_idx / Epilogue::kAccumulatorFragments; -+ reduce_fragment_idx = reduce_idx % Epilogue::kAccumulatorFragments; -+ -+ int iter_tile_first = reduce_tile_idx * params.block_mapping.iters_per_tile(); -+ int iter_tile_last = iter_tile_first + params.block_mapping.iters_per_tile() - 1; -+ -+ peer_idx_begin = params.block_mapping.get_sk_block_idx(iter_tile_first); -+ peer_idx_last = params.block_mapping.get_sk_block_idx(iter_tile_last); -+ -+ // Wait for peers to complete -+ int peer_idx_end = peer_idx_last + 1; -+ int num_peers = peer_idx_end - peer_idx_begin; -+ Barrier::wait_eq_reset( -+ params.barrier_workspace, -+ thread_idx, -+ (reduce_tile_idx * Epilogue::kAccumulatorFragments) + reduce_fragment_idx, -+ num_peers); -+ -+ /// The location of this tile (in threadblock-tile coordinates) in the output matrix -+ GemmCoord tiled_coord = params.block_mapping.get_tile_offset(reduce_tile_idx); -+ -+ // Location of this tile in item-coords -+ MatrixCoord threadblock_item_begin( -+ tiled_coord.m() * Mma::Shape::kM, -+ tiled_coord.n() * Mma::Shape::kN -+ ); -+ -+ ElementC *ptr_C = static_cast(params.ptr_C); -+ ElementC *ptr_D = static_cast(params.ptr_D); -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.params_C, -+ ptr_C, -+ params.block_mapping.problem_size.mn(), -+ thread_idx, -+ threadblock_item_begin); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ ptr_D, -+ params.block_mapping.problem_size.mn(), -+ thread_idx, -+ threadblock_item_begin); -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue.reduce( -+ peer_idx_begin, -+ peer_idx_end, -+ reduce_fragment_idx, -+ params.partials_workspace, -+ EpilogueOutputOp(params.output_op), -+ iterator_D, -+ iterator_C); -+ } -+ -+ -+ CUTLASS_DEVICE -+ void process_tile( -+ TileWorkDesc tile_work, -+ int block_idx, -+ int dp_start_block_idx, -+ int block_iter_begin) -+ { -+ // Initialize input iterators -+ typename Mma::IteratorA iterator_A = init_iterator_A(tile_work, params.mode); -+ typename Mma::IteratorB iterator_B = init_iterator_B(tile_work, params.mode); -+ -+ // Initialize accumulators -+ AccumulatorTile accumulator_tile; -+ accumulator_tile.clear(); -+ -+ // Perform this tile's range of multiply-accumulate (MAC) iterations -+ Mma mma( -+ shared_storage.main_loop, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ mma(tile_work.k_iters_remaining, accumulator_tile, iterator_A, iterator_B, accumulator_tile); -+ -+ if ((ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kAtomic) || -+ (params.block_mapping.reduction_blocks == 0) || -+ (block_idx >= dp_start_block_idx)) -+ { -+ // -+ // Cooperative SK peer reduction or DP block -+ // -+ -+ int first_block_idx = params.block_mapping.get_first_block_idx(tile_work.tile_idx, block_idx); -+ -+ if (!tile_work.tile_finished(params)) { -+ // Non "finishing" SK blocks must share their partial accumulator sums through global scratch workspace -+ share_accumulators(accumulator_tile, block_idx, first_block_idx); -+ } -+ else -+ { -+ // DP blocks and "finishing" SK blocks must perform epilogue operations and write the output tile -+ if (!tile_work.tile_started()) -+ { -+ // A "finishing" SK block must first aggregate its accumulator partial sums with those shared by peer threadblocks -+ acquire_accumulators(accumulator_tile, block_idx, first_block_idx); -+ } -+ -+ do_epilogue(tile_work, accumulator_tile); -+ } -+ } -+ else -+ { -+ // -+ // Separate peer reduction -+ // -+ -+ // Share accumulator partial sums with peer threadblock(s) through scratch workspace -+ epilogue.share(block_idx, params.partials_workspace, accumulator_tile, tile_work.tile_started()); -+ -+ // Signal arrival -+ Barrier::arrive_range_inc( -+ params.barrier_workspace, -+ thread_idx, -+ tile_work.tile_idx * Epilogue::kAccumulatorFragments, -+ Epilogue::kAccumulatorFragments); -+ } -+ } -+ -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void gemm() -+ { -+ // Initialize block's iteration range -+ int tile_idx, block_iter_begin, block_iters_remaining; -+ -+ int sk_padding_start_block_idx = params.block_mapping.sk_regions() * params.block_mapping.sk_blocks_per_region(); -+ int dp_start_block_idx = params.block_mapping.sk_waves * params.block_mapping.avail_sms; -+ int reduce_start_block_idx = dp_start_block_idx + params.block_mapping.dp_blocks; -+ int grid_padding_start_block_idx = reduce_start_block_idx + params.block_mapping.reduction_blocks; -+ -+ int block_idx = params.block_mapping.get_block_idx(); -+ if (block_idx < sk_padding_start_block_idx) -+ { -+ // This is a SK block -+ int block_iter_end; -+ params.block_mapping.get_iter_extents(block_idx, block_iter_begin, block_iter_end); -+ block_iters_remaining = block_iter_end - block_iter_begin; -+ -+ tile_idx = params.block_mapping.get_sk_tile_idx(block_iter_end - 1); -+ } -+ else if (block_idx < dp_start_block_idx) -+ { -+ // This is a filler block -+ return; -+ } -+ else if (block_idx < reduce_start_block_idx) -+ { -+ // This is a DP block -+ int dp_block_idx = block_idx - dp_start_block_idx; -+ int first_dp_tile = (params.block_mapping.cohort_raster) ? 0 : params.block_mapping.sk_tiles; -+ -+ // Blocks in first DP wave get configured number of tiles -+ tile_idx = first_dp_tile + dp_block_idx; -+ int tile_allottment = params.block_mapping.dp_first_wave_tiles; -+ -+ // Blocks in subsequent DP waves get 1 tile -+ if (dp_block_idx >= params.block_mapping.avail_sms) { -+ tile_allottment = 1; -+ tile_idx += (params.block_mapping.dp_first_wave_tiles - 1) * params.block_mapping.avail_sms; -+ } -+ -+ block_iter_begin = 0; -+ block_iters_remaining = params.block_mapping.iters_per_tile() * tile_allottment; -+ } -+ -+ else if ((ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kMixed) && -+ (block_idx < grid_padding_start_block_idx)) -+ { -+ // This is a reduction threadblock -+ int reduce_block_idx = block_idx - reduce_start_block_idx; -+ separate_reduction(reduce_block_idx); -+ return; -+ } -+ else -+ { -+ // This is a filler block -+ return; -+ } -+ -+ // Iteration-processing loop body -+ CUTLASS_PRAGMA_NO_UNROLL -+ while (true) -+ { -+ // Initialize tile work descriptor -+ TileWorkDesc tile_work; -+ if (block_idx >= dp_start_block_idx) -+ { -+ init_dp_tile_work(tile_work, tile_idx); -+ -+ // DP blocks exit if out of bounds or overlap an SK tile (only possible during cohort rasterization, where dp_first_wave_tiles must be 1) -+ if ((tile_idx < params.block_mapping.sk_tiles) || -+ (tile_work.tiled_coord.m() >= params.block_mapping.tiled_shape().m()) || -+ (tile_work.tiled_coord.n() >= params.block_mapping.tiled_shape().n())) -+ { -+ break; -+ } -+ } -+ else -+ { -+ init_sk_tile_work(tile_work, tile_idx, block_iter_begin, block_iter_begin + block_iters_remaining); -+ } -+ -+ // Perform this block's share of work for this tile -+ process_tile(tile_work, block_idx, dp_start_block_idx, block_iter_begin); -+ -+ // Update remaining work for this block -+ block_iters_remaining -= tile_work.k_iters_remaining; -+ if (block_iters_remaining == 0) { -+ // Done -+ break; -+ } -+ -+ // Continue to next tile -+ __syncthreads(); -+ -+ if (block_idx >= dp_start_block_idx) -+ { -+ // DP block consume their tiles at stride -+ tile_idx += params.block_mapping.avail_sms; -+ } -+ else -+ { -+ // SK blocks consume their tiles in backwards order -+ tile_idx--; -+ } -+ } -+ -+ } -+ -+ -+ /// Executes one DP-only GEMM -+ CUTLASS_DEVICE -+ void gemm_dp() -+ { -+ int block_idx = blockIdx.x; -+ int tile_idx = block_idx; -+ -+ TileWorkDesc tile_work; -+ tile_work.tile_idx = tile_idx; -+ tile_work.iter_begin = tile_idx * params.block_mapping.iters_per_tile(); -+ tile_work.k_iters_remaining = params.block_mapping.iters_per_tile(); -+ tile_work.k_begin = 0; -+ tile_work.k_end = params.block_mapping.problem_size.k(); -+ tile_work.tiled_coord = params.block_mapping.get_tile_offset_row_major(tile_work.tile_idx); -+ -+ // Initialize input iterators -+ typename Mma::IteratorA iterator_A = init_iterator_A(tile_work, params.mode); -+ typename Mma::IteratorB iterator_B = init_iterator_B(tile_work, params.mode); -+ -+ // Initialize accumulators -+ AccumulatorTile accumulator_tile; -+ accumulator_tile.clear(); -+ -+ // Perform this tile's range of multiply-accumulate (MAC) iterations -+ Mma mma( -+ shared_storage.main_loop, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ mma(tile_work.k_iters_remaining, accumulator_tile, iterator_A, iterator_B, accumulator_tile); -+ -+ ElementC *ptr_D = static_cast(params.ptr_D); -+ -+ // Location of this tile in item-coords -+ MatrixCoord threadblock_item_begin( -+ tile_work.tiled_coord.m() * Mma::Shape::kM, -+ tile_work.tiled_coord.n() * Mma::Shape::kN -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ ptr_D, -+ params.block_mapping.problem_size.mn(), -+ thread_idx, -+ threadblock_item_begin); -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue( -+ EpilogueOutputOp(params.output_op), -+ iterator_D, -+ accumulator_tile); -+ } -+ -+ -+ -+public: -+ -+ // -+ // Device-only API -+ // -+ -+ // Factory invocation -+ CUTLASS_DEVICE -+ static void invoke( -+ Params const ¶ms, -+ SharedStorage &shared_storage) -+ { -+ GemmUniversalStreamk op(params, shared_storage); -+ op(); -+ } -+ -+ -+ // Constructor -+ CUTLASS_DEVICE -+ GemmUniversalStreamk( -+ Params const ¶ms, -+ SharedStorage &shared_storage) -+ : -+ params(params), -+ shared_storage(shared_storage), -+ thread_idx(threadIdx.x), -+ warp_idx(__shfl_sync(0xffffffff, threadIdx.x / 32, 0)), // broadcast the warp_id computed by lane 0 to ensure dependent code -+ lane_idx(threadIdx.x % 32), -+ epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx) -+ {} -+ -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()() -+ { -+#if (__CUDACC_VER_MAJOR__ > 10) -+ if (params.quick_dp) -+ { -+ // Simple (low-bootstrap latency) GEMM code path for data-parallel only. (kBatched and kArray -+ // modes will only be launched using a data-parallel configurations) -+ gemm_dp(); -+ return; -+ } -+#endif -+ -+ // Generic SK code path -+ gemm(); -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h -new file mode 100644 -index 0000000..8f67bd4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h -@@ -0,0 +1,1487 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Gemm kernel with fused reduction operation. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/layout/layout.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/semaphore.h" -+#include "cutlass/gemm/kernel/params_universal_base.h" -+ -+#include "cutlass/trace.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ bool IsSingleSource = Epilogue_::kIsSingleSource -+> -+struct GemmWithFusedEpilogue; -+ -+// GemmWithFusedEpilogue with two sources -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_ ///! Threadblock swizzling function -+> -+struct GemmWithFusedEpilogue { -+public: -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using ElementC = typename Epilogue::OutputTileIterator::Element; -+ using LayoutC = typename Epilogue::OutputTileIterator::Layout; -+ -+ static ComplexTransform const kTransformA = Mma::kTransformA; -+ static ComplexTransform const kTransformB = Mma::kTransformB; -+ using Operator = typename Mma::Operator; -+ -+ using OperatorClass = typename Mma::Operator::OperatorClass; -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename Mma::Operator::Shape; -+ using InstructionShape = typename Mma::Policy::Operator::InstructionShape; -+ using ArchTag = typename Mma::ArchTag; -+ -+ static int const kStages = Mma::kStages; -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ /// Split-K preserves splits that are 128b aligned -+ static int const kSplitKAlignment = const_max( -+ 128 / sizeof_bits::value, -+ 128 / sizeof_bits::value -+ ); -+ -+ // -+ // Structures -+ // -+ -+ /// Argument structure -+ struct Arguments : UniversalArgumentsBase{ -+ -+ // -+ // Data members -+ // -+ -+ typename EpilogueOutputOp::Params epilogue; -+ -+ void const * ptr_A; -+ void const * ptr_B; -+ void const * ptr_C1; -+ void const * ptr_C2; -+ void * ptr_D; -+ -+ void * ptr_Vector; -+ void * ptr_Tensor; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C1; -+ int64_t batch_stride_C2; -+ int64_t batch_stride_Vector; -+ int64_t batch_stride_Tensor; -+ -+ typename LayoutA::Stride::Index lda; -+ typename LayoutB::Stride::Index ldb; -+ typename LayoutC::Stride::Index ldc1; -+ typename LayoutC::Stride::Index ldc2; -+ typename LayoutC::Stride::Index ldd; -+ typename LayoutC::Stride::Index ldr; -+ typename LayoutC::Stride::Index ldt; -+ -+ // -+ // Methods -+ // -+ -+ Arguments(): -+ ptr_A(nullptr), -+ ptr_B(nullptr), -+ ptr_C1(nullptr), -+ ptr_C2(nullptr), -+ ptr_D(nullptr) -+ {} -+ -+ /// constructs an arguments structure -+ Arguments( -+ GemmUniversalMode mode, -+ GemmCoord problem_size, -+ int batch_count, -+ typename EpilogueOutputOp::Params epilogue, -+ void const * ptr_A, -+ void const * ptr_B, -+ void const * ptr_C1, -+ void const * ptr_C2, -+ void * ptr_D, -+ void * ptr_Vector, -+ void * ptr_Tensor, -+ int64_t batch_stride_A, -+ int64_t batch_stride_B, -+ int64_t batch_stride_C1, -+ int64_t batch_stride_C2, -+ int64_t batch_stride_D, -+ int64_t batch_stride_Vector, -+ int64_t batch_stride_Tensor, -+ typename LayoutA::Stride::Index lda, -+ typename LayoutB::Stride::Index ldb, -+ typename LayoutC::Stride::Index ldc1, -+ typename LayoutC::Stride::Index ldc2, -+ typename LayoutC::Stride::Index ldd, -+ typename LayoutC::Stride::Index ldr, -+ typename LayoutC::Stride::Index ldt) -+ : -+ UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), -+ epilogue(epilogue), -+ ptr_A(ptr_A), ptr_B(ptr_B), ptr_C1(ptr_C1), ptr_C2(ptr_C2), ptr_D(ptr_D), -+ ptr_Vector(ptr_Vector), -+ ptr_Tensor(ptr_Tensor), -+ batch_stride_A(batch_stride_A), -+ batch_stride_B(batch_stride_B), -+ batch_stride_C1(batch_stride_C1), -+ batch_stride_C2(batch_stride_C2), -+ batch_stride_Vector(batch_stride_Vector), -+ batch_stride_Tensor(batch_stride_Tensor), -+ lda(lda), ldb(ldb), ldc1(ldc1), ldc2(ldc2), ldd(ldd), ldr(ldr), ldt(ldt) -+ { -+ CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Arguments::Arguments() - problem_size: " << problem_size); -+ CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction); -+ CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); -+ CUTLASS_TRACE_HOST(" ldr: " << this->ldr); -+ CUTLASS_TRACE_HOST(" ldt: " << this->ldt); -+ } -+ -+ /// Returns arguments for the transposed problem -+ Arguments transposed_problem() const { -+ Arguments args(*this); -+ -+ std::swap(args.problem_size.m(), args.problem_size.n()); -+ std::swap(args.ptr_A, args.ptr_B); -+ std::swap(args.lda, args.ldb); -+ std::swap(args.batch_stride_A, args.batch_stride_B); -+ -+ return args; -+ } -+ }; -+ -+ -+ // -+ // Structure for precomputing values in host memory and passing to kernels -+ // -+ -+ /// Parameters structure -+ struct Params : UniversalParamsBase< -+ ThreadblockSwizzle, -+ ThreadblockShape, -+ ElementA, -+ ElementB, -+ ElementC> -+ { -+ using ParamsBase = UniversalParamsBase< -+ ThreadblockSwizzle, -+ ThreadblockShape, -+ ElementA, -+ ElementB, -+ ElementC>; -+ -+ // -+ // Data members -+ // -+ -+ typename Mma::IteratorA::Params params_A; -+ typename Mma::IteratorB::Params params_B; -+ typename Epilogue::OutputTileIterator::Params params_C1; -+ typename Epilogue::OutputTileIterator::Params params_C2; -+ typename Epilogue::OutputTileIterator::Params params_D; -+ typename Epilogue::TensorTileIterator::Params params_Tensor; -+ typename EpilogueOutputOp::Params output_op; -+ -+ void * ptr_A; -+ void * ptr_B; -+ void * ptr_C1; -+ void * ptr_C2; -+ void * ptr_D; -+ -+ void * ptr_Vector; -+ typename LayoutC::Stride::Index ldr; -+ -+ void * ptr_Tensor; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C1; -+ int64_t batch_stride_C2; -+ int64_t batch_stride_Vector; -+ int64_t batch_stride_Tensor; -+ -+ // -+ // Host dispatch API -+ // -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Constructor -+ Params( -+ Arguments const &args, /// GEMM application arguments -+ int device_sms, /// Number of SMs on the device -+ int sm_occupancy) /// Kernel SM occupancy (in thread blocks) -+ : -+ ParamsBase(args, device_sms, sm_occupancy), -+ params_A(args.lda), -+ params_B(args.ldb), -+ params_C1(args.ldc1), -+ params_C2(args.ldc2), -+ params_D(args.ldd), -+ params_Tensor(args.ldt), -+ output_op(args.epilogue), -+ ptr_A(const_cast(args.ptr_A)), -+ ptr_B(const_cast(args.ptr_B)), -+ ptr_C1(const_cast(args.ptr_C1)), -+ ptr_C2(const_cast(args.ptr_C2)), -+ ptr_D(args.ptr_D), -+ ptr_Vector(args.ptr_Vector), -+ ldr(args.ldr), -+ ptr_Tensor(args.ptr_Tensor), -+ batch_stride_A(args.batch_stride_A), -+ batch_stride_B(args.batch_stride_B), -+ batch_stride_C1(args.batch_stride_C1), -+ batch_stride_C2(args.batch_stride_C2), -+ batch_stride_Vector(args.batch_stride_Vector), -+ batch_stride_Tensor(args.batch_stride_Tensor) -+ { -+ CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Params::Params() - problem_size: " << problem_size); -+ CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction); -+ CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); -+ CUTLASS_TRACE_HOST(" ldr: " << this->ldr); -+ CUTLASS_TRACE_HOST(" ldt: " << args.ldt); -+ } -+ -+ /// Lightweight update given a subset of arguments. Problem geometry is assumed -+ /// to remain the same. -+ CUTLASS_HOST_DEVICE -+ void update(Arguments const &args) -+ { -+ ptr_A = const_cast(args.ptr_A); -+ ptr_B = const_cast(args.ptr_B); -+ ptr_C1 = const_cast(args.ptr_C1); -+ ptr_C2 = const_cast(args.ptr_C2); -+ ptr_D = args.ptr_D; -+ -+ ptr_Vector = args.ptr_Vector; -+ ldr = args.ldr; -+ ptr_Tensor = args.ptr_Tensor; -+ -+ output_op = args.epilogue; -+ -+ CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Params::update()"); -+ CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction); -+ CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); -+ CUTLASS_TRACE_HOST(" ldr: " << this->ldr); -+ } -+ }; -+ -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+public: -+ -+ // -+ // Host dispatch API -+ // -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement( -+ cutlass::gemm::GemmCoord const & problem_size) { -+ -+ CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::can_implement()"); -+ -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ bool isAMisaligned = false; -+ bool isBMisaligned = false; -+ bool isCMisaligned = false; -+ -+ if (platform::is_same::value) { -+ isAMisaligned = problem_size.k() % kAlignmentA; -+ } else if (platform::is_same::value) { -+ isAMisaligned = problem_size.m() % kAlignmentA; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isAMisaligned = problem_size.k() % kAlignmentA; -+ } -+ -+ if (platform::is_same::value) { -+ isBMisaligned = problem_size.n() % kAlignmentB; -+ } else if (platform::is_same::value) { -+ isBMisaligned = problem_size.k() % kAlignmentB; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isBMisaligned = problem_size.k() % kAlignmentB; -+ } -+ -+ if (platform::is_same::value) { -+ isCMisaligned = problem_size.n() % kAlignmentC; -+ } else if (platform::is_same::value) { -+ isCMisaligned = problem_size.m() % kAlignmentC; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isCMisaligned = problem_size.n() % kAlignmentC; -+ } -+ -+ if (isAMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (isBMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (isCMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ CUTLASS_TRACE_HOST(" returning kSuccess"); -+ -+ return Status::kSuccess; -+ } -+ -+ static Status can_implement(Arguments const &args) { -+ return can_implement(args.problem_size); -+ } -+ -+public: -+ -+ // -+ // Device-only API -+ // -+ -+ // Factory invocation -+ CUTLASS_DEVICE -+ static void invoke( -+ Params const ¶ms, -+ SharedStorage &shared_storage) -+ { -+ GemmWithFusedEpilogue op; -+ op(params, shared_storage); -+ } -+ -+ #define SPLIT_K_ENABLED 1 -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ int offset_k = 0; -+ int problem_size_k = params.problem_size.k(); -+ -+ ElementA *ptr_A = static_cast(params.ptr_A); -+ ElementB *ptr_B = static_cast(params.ptr_B); -+ -+ -+ #if SPLIT_K_ENABLED -+ // -+ // Fetch pointers based on mode. -+ // -+ if (params.mode == GemmUniversalMode::kGemm || -+ params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ -+ if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { -+ -+ problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; -+ } -+ -+ offset_k = threadblock_tile_offset.k() * params.gemm_k_size; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; -+ ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; -+ ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; -+ } -+ #endif -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ offset_k, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ offset_k, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ }; -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.params_A, -+ ptr_A, -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B( -+ params.params_B, -+ ptr_B, -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_B); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); -+ -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma( -+ gemm_k_iterations, -+ accumulators, -+ iterator_A, -+ iterator_B, -+ accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ ElementC *ptr_C1 = static_cast(params.ptr_C1); -+ ElementC *ptr_C2 = static_cast(params.ptr_C2); -+ ElementC *ptr_D = static_cast(params.ptr_D); -+ typename Epilogue::ElementTensor *ptr_Tensor = static_cast(params.ptr_Tensor); -+ -+ // Define the reduction output pointer and move to the appropriate place -+ typename Epilogue::ElementVector *ptr_Vector = -+ static_cast(params.ptr_Vector); -+ -+ // -+ // Fetch pointers based on mode. -+ // -+ -+ // -+ // Special path when split-K not enabled. -+ // -+ -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() == 1) { -+ -+ // Tile iterators loading from source tensors. -+ typename Epilogue::OutputTileIterator iterator_C1( -+ params.params_C1, -+ ptr_C1, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ typename Epilogue::OutputTileIterator iterator_C2( -+ params.params_C2, -+ ptr_C2, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ ptr_D, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Additional tensor to load from -+ typename Epilogue::TensorTileIterator tensor_iterator( -+ params.params_Tensor, -+ // Only the final block outputs Tensor -+ ptr_Tensor, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset); -+ -+ // Construct the epilogue -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Move to appropriate location for this output tile -+ if (ptr_Vector) { -+ ptr_Vector += threadblock_offset.column() + threadblock_tile_offset.m() * params.ldr; -+ } -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue(output_op, -+ ptr_Vector, -+ iterator_D, -+ accumulators, -+ iterator_C1, -+ iterator_C2, -+ tensor_iterator, -+ params.problem_size.mn(), -+ threadblock_offset); -+ -+ return; -+ } -+ -+ // -+ // Slower path when split-K or batching is needed -+ // -+ -+ -+ #if SPLIT_K_ENABLED -+ // Construct the semaphore. -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ if (params.mode == GemmUniversalMode::kGemm) { -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. -+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } -+ } -+ else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_C1 += threadblock_tile_offset.k() * params.batch_stride_C1; -+ if (ptr_C2) { -+ ptr_C2 += threadblock_tile_offset.k() * params.batch_stride_C2; -+ } -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ if (ptr_Tensor) { -+ ptr_Tensor += threadblock_tile_offset.k() * params.batch_stride_Tensor; -+ } -+ if (ptr_Vector) { -+ ptr_Vector += threadblock_tile_offset.k() * params.batch_stride_Vector; -+ } -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_C1 = static_cast(params.ptr_C1)[threadblock_tile_offset.k()]; -+ if (ptr_C2) { -+ ptr_C2 = static_cast(params.ptr_C2)[threadblock_tile_offset.k()]; -+ } -+ ptr_D = static_cast(params.ptr_D)[threadblock_tile_offset.k()]; -+ if (ptr_Tensor) { -+ ptr_Tensor = static_cast(params.ptr_Tensor)[threadblock_tile_offset.k()]; -+ } -+ if (ptr_Vector) { -+ ptr_Vector = static_cast(params.ptr_Vector)[threadblock_tile_offset.k()]; -+ } -+ } -+ #endif -+ -+ // Tile iterators loading from source tensors. -+ typename Epilogue::OutputTileIterator iterator_C1( -+ params.params_C1, -+ ptr_C1, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ typename Epilogue::OutputTileIterator iterator_C2( -+ params.params_C2, -+ ptr_C2, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ ptr_D, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Additional tensor to load from -+ typename Epilogue::TensorTileIterator tensor_iterator( -+ params.params_Tensor, -+ // Only the final block outputs Tensor -+ ((params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) && -+ (params.grid_tiled_shape.k() != threadblock_tile_offset.k() + 1)) -+ ? nullptr -+ : ptr_Tensor, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset); -+ -+ // Construct the epilogue -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ #if SPLIT_K_ENABLED -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if ((params.mode == GemmUniversalMode::kGemm) && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_offset.k()) { -+ iterator_C1 = iterator_D; -+ } -+ -+ semaphore.wait(threadblock_tile_offset.k()); -+ -+ } -+ #endif -+ -+ // Move to appropriate location for this output tile -+ if (ptr_Vector) { -+ ptr_Vector += threadblock_offset.column() + threadblock_tile_offset.m() * params.ldr; -+ } -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue(output_op, -+ // Only the final block uses Vector -+ ((params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) && -+ (params.grid_tiled_shape.k() != threadblock_tile_offset.k() + 1)) -+ ? nullptr -+ : ptr_Vector, -+ iterator_D, -+ accumulators, -+ iterator_C1, -+ iterator_C2, -+ tensor_iterator, -+ params.problem_size.mn(), -+ threadblock_offset); -+ -+ // -+ // Release the semaphore -+ // -+ -+ #if SPLIT_K_ENABLED -+ if ((params.mode == GemmUniversalMode::kGemm) && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_offset.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ #endif -+ } -+}; -+ -+// GemmWithFusedEpilogue with one source -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_ ///! Threadblock swizzling function -+> -+struct GemmWithFusedEpilogue { -+public: -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using ElementC = typename Epilogue::OutputTileIterator::Element; -+ using LayoutC = typename Epilogue::OutputTileIterator::Layout; -+ -+ static ComplexTransform const kTransformA = Mma::kTransformA; -+ static ComplexTransform const kTransformB = Mma::kTransformB; -+ using Operator = typename Mma::Operator; -+ -+ using OperatorClass = typename Mma::Operator::OperatorClass; -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename Mma::Operator::Shape; -+ using InstructionShape = typename Mma::Policy::Operator::InstructionShape; -+ using ArchTag = typename Mma::ArchTag; -+ -+ static int const kStages = Mma::kStages; -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ /// Split-K preserves splits that are 128b aligned -+ static int const kSplitKAlignment = const_max( -+ 128 / sizeof_bits::value, -+ 128 / sizeof_bits::value -+ ); -+ -+ // -+ // Structures -+ // -+ -+ /// Argument structure -+ struct Arguments : UniversalArgumentsBase -+ { -+ // -+ // Data members -+ // -+ -+ typename EpilogueOutputOp::Params epilogue; -+ -+ void const * ptr_A; -+ void const * ptr_B; -+ void const * ptr_C; -+ void * ptr_D; -+ -+ void * ptr_Vector; -+ void * ptr_Tensor; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C; -+ int64_t batch_stride_Vector; -+ int64_t batch_stride_Tensor; -+ -+ typename LayoutA::Stride::Index lda; -+ typename LayoutB::Stride::Index ldb; -+ typename LayoutC::Stride::Index ldc; -+ typename LayoutC::Stride::Index ldd; -+ typename LayoutC::Stride::Index ldr; -+ typename LayoutC::Stride::Index ldt; -+ -+ // -+ // Methods -+ // -+ -+ Arguments(): -+ ptr_A(nullptr), -+ ptr_B(nullptr), -+ ptr_C(nullptr), -+ ptr_D(nullptr) -+ {} -+ -+ /// constructs an arguments structure -+ Arguments( -+ GemmUniversalMode mode, -+ GemmCoord problem_size, -+ int batch_count, -+ typename EpilogueOutputOp::Params epilogue, -+ void const * ptr_A, -+ void const * ptr_B, -+ void const * ptr_C, -+ void * ptr_D, -+ void * ptr_Vector, -+ void * ptr_Tensor, -+ int64_t batch_stride_A, -+ int64_t batch_stride_B, -+ int64_t batch_stride_C, -+ int64_t batch_stride_D, -+ int64_t batch_stride_Vector, -+ int64_t batch_stride_Tensor, -+ typename LayoutA::Stride::Index lda, -+ typename LayoutB::Stride::Index ldb, -+ typename LayoutC::Stride::Index ldc, -+ typename LayoutC::Stride::Index ldd, -+ typename LayoutC::Stride::Index ldr, -+ typename LayoutC::Stride::Index ldt) -+ : -+ UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), -+ epilogue(epilogue), -+ ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), -+ ptr_Vector(ptr_Vector), -+ ptr_Tensor(ptr_Tensor), -+ batch_stride_A(batch_stride_A), -+ batch_stride_B(batch_stride_B), -+ batch_stride_C(batch_stride_C), -+ batch_stride_Vector(batch_stride_Vector), -+ batch_stride_Tensor(batch_stride_Tensor), -+ lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), ldr(ldr), ldt(ldt) -+ { -+ CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Arguments::Arguments() - problem_size: " << problem_size); -+ CUTLASS_TRACE_HOST(" ptr_Vector: " << (void *)this->ptr_Vector); -+ CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); -+ CUTLASS_TRACE_HOST(" ldr: " << this->ldr); -+ CUTLASS_TRACE_HOST(" ldt: " << this->ldt); -+ } -+ -+ /// Returns arguments for the transposed problem -+ Arguments transposed_problem() const { -+ Arguments args(*this); -+ -+ std::swap(args.problem_size.m(), args.problem_size.n()); -+ std::swap(args.ptr_A, args.ptr_B); -+ std::swap(args.lda, args.ldb); -+ std::swap(args.batch_stride_A, args.batch_stride_B); -+ -+ return args; -+ } -+ }; -+ -+ -+ // -+ // Structure for precomputing values in host memory and passing to kernels -+ // -+ -+ /// Parameters structure -+ struct Params : UniversalParamsBase< -+ ThreadblockSwizzle, -+ ThreadblockShape, -+ ElementA, -+ ElementB, -+ ElementC> -+ { -+ using ParamsBase = UniversalParamsBase< -+ ThreadblockSwizzle, -+ ThreadblockShape, -+ ElementA, -+ ElementB, -+ ElementC>; -+ -+ // -+ // Data members -+ // -+ -+ typename Mma::IteratorA::Params params_A; -+ typename Mma::IteratorB::Params params_B; -+ typename Epilogue::OutputTileIterator::Params params_C; -+ typename Epilogue::OutputTileIterator::Params params_D; -+ typename Epilogue::TensorTileIterator::Params params_Tensor; -+ -+ typename EpilogueOutputOp::Params output_op; -+ -+ void * ptr_A; -+ void * ptr_B; -+ void * ptr_C; -+ void * ptr_D; -+ -+ void * ptr_Vector; -+ typename LayoutC::Stride::Index ldr; -+ -+ void * ptr_Tensor; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C; -+ int64_t batch_stride_Vector; -+ int64_t batch_stride_Tensor; -+ -+ // -+ // Host dispatch API -+ // -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Constructor -+ Params( -+ Arguments const &args, /// GEMM application arguments -+ int device_sms, /// Number of SMs on the device -+ int sm_occupancy) /// Kernel SM occupancy (in thread blocks) -+ : -+ ParamsBase(args, device_sms, sm_occupancy), -+ params_A(args.lda), -+ params_B(args.ldb), -+ params_C(args.ldc), -+ params_D(args.ldd), -+ params_Tensor(args.ldt), -+ output_op(args.epilogue), -+ ptr_A(const_cast(args.ptr_A)), -+ ptr_B(const_cast(args.ptr_B)), -+ ptr_C(const_cast(args.ptr_C)), -+ ptr_D(args.ptr_D), -+ ptr_Vector(args.ptr_Vector), -+ ldr(args.ldr), -+ ptr_Tensor(args.ptr_Tensor), -+ batch_stride_A(args.batch_stride_A), -+ batch_stride_B(args.batch_stride_B), -+ batch_stride_C(args.batch_stride_C), -+ batch_stride_Vector(args.batch_stride_Vector), -+ batch_stride_Tensor(args.batch_stride_Tensor) -+ { -+ CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Params::Params() - problem_size: " << problem_size); -+ CUTLASS_TRACE_HOST(" ptr_Vector: " << (void *)this->ptr_Vector); -+ CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); -+ CUTLASS_TRACE_HOST(" ldr: " << this->ldr); -+ CUTLASS_TRACE_HOST(" ldt: " << args.ldt); -+ } -+ -+ /// Lightweight update given a subset of arguments. Problem geometry is assumed -+ /// to remain the same. -+ CUTLASS_HOST_DEVICE -+ void update(Arguments const &args) -+ { -+ ptr_A = const_cast(args.ptr_A); -+ ptr_B = const_cast(args.ptr_B); -+ ptr_C = const_cast(args.ptr_C); -+ ptr_D = args.ptr_D; -+ -+ ptr_Vector = args.ptr_Vector; -+ ldr = args.ldr; -+ ptr_Tensor = args.ptr_Tensor; -+ -+ output_op = args.epilogue; -+ -+ CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Params::update()"); -+ CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction); -+ CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); -+ CUTLASS_TRACE_HOST(" ldr: " << this->ldr); -+ } -+ }; -+ -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+public: -+ -+ // -+ // Host dispatch API -+ // -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement( -+ cutlass::gemm::GemmCoord const & problem_size) { -+ -+ CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::can_implement()"); -+ -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ bool isAMisaligned = false; -+ bool isBMisaligned = false; -+ bool isCMisaligned = false; -+ -+ if (platform::is_same::value) { -+ isAMisaligned = problem_size.k() % kAlignmentA; -+ } else if (platform::is_same::value) { -+ isAMisaligned = problem_size.m() % kAlignmentA; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isAMisaligned = problem_size.k() % kAlignmentA; -+ } -+ -+ if (platform::is_same::value) { -+ isBMisaligned = problem_size.n() % kAlignmentB; -+ } else if (platform::is_same::value) { -+ isBMisaligned = problem_size.k() % kAlignmentB; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isBMisaligned = problem_size.k() % kAlignmentB; -+ } -+ -+ if (platform::is_same::value) { -+ isCMisaligned = problem_size.n() % kAlignmentC; -+ } else if (platform::is_same::value) { -+ isCMisaligned = problem_size.m() % kAlignmentC; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isCMisaligned = problem_size.n() % kAlignmentC; -+ } -+ -+ if (isAMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (isBMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (isCMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ CUTLASS_TRACE_HOST(" returning kSuccess"); -+ -+ return Status::kSuccess; -+ } -+ -+ static Status can_implement(Arguments const &args) { -+ return can_implement(args.problem_size); -+ } -+ -+public: -+ -+ // -+ // Device-only API -+ // -+ -+ // Factory invocation -+ CUTLASS_DEVICE -+ static void invoke( -+ Params const ¶ms, -+ SharedStorage &shared_storage) -+ { -+ GemmWithFusedEpilogue op; -+ op(params, shared_storage); -+ } -+ -+ #define SPLIT_K_ENABLED 1 -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ int offset_k = 0; -+ int problem_size_k = params.problem_size.k(); -+ -+ ElementA *ptr_A = static_cast(params.ptr_A); -+ ElementB *ptr_B = static_cast(params.ptr_B); -+ -+ -+ #if SPLIT_K_ENABLED -+ // -+ // Fetch pointers based on mode. -+ // -+ if (params.mode == GemmUniversalMode::kGemm || -+ params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ -+ if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { -+ -+ problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; -+ } -+ -+ offset_k = threadblock_tile_offset.k() * params.gemm_k_size; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; -+ ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; -+ ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; -+ } -+ #endif -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ offset_k, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ offset_k, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ }; -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.params_A, -+ ptr_A, -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B( -+ params.params_B, -+ ptr_B, -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_B); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = canonical_warp_idx(); -+ -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma( -+ gemm_k_iterations, -+ accumulators, -+ iterator_A, -+ iterator_B, -+ accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ ElementC *ptr_C = static_cast(params.ptr_C); -+ ElementC *ptr_D = static_cast(params.ptr_D); -+ typename Epilogue::ElementTensor *ptr_Tensor = static_cast(params.ptr_Tensor); -+ -+ // Define the reduction output pointer and move to the appropriate place -+ typename Epilogue::ElementVector *ptr_Vector = -+ static_cast(params.ptr_Vector); -+ -+ // -+ // Fetch pointers based on mode. -+ // -+ -+ // -+ // Special path when split-K not enabled. -+ // -+ -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() == 1) { -+ -+ // Tile iterators loading from source tensors. -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.params_C, -+ ptr_C, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ ptr_D, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Additional tensor to load from -+ typename Epilogue::TensorTileIterator tensor_iterator( -+ params.params_Tensor, -+ // Only the final block outputs Tensor -+ ptr_Tensor, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset); -+ -+ // Construct the epilogue -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Move to appropriate location for this output tile -+ if (ptr_Vector) { -+ ptr_Vector += threadblock_offset.column() + threadblock_tile_offset.m() * params.ldr; -+ } -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue(output_op, -+ ptr_Vector, -+ iterator_D, -+ accumulators, -+ iterator_C, -+ tensor_iterator, -+ params.problem_size.mn(), -+ threadblock_offset); -+ -+ return; -+ } -+ -+ // -+ // Slower path when split-K or batching is needed -+ // -+ -+ -+ #if SPLIT_K_ENABLED -+ // Construct the semaphore. -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ if (params.mode == GemmUniversalMode::kGemm) { -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. -+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } -+ } -+ else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_C += threadblock_tile_offset.k() * params.batch_stride_C; -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ if (ptr_Tensor) { -+ ptr_Tensor += threadblock_tile_offset.k() * params.batch_stride_Tensor; -+ } -+ if (ptr_Vector) { -+ ptr_Vector += threadblock_tile_offset.k() * params.batch_stride_Vector; -+ } -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_C = static_cast(params.ptr_C)[threadblock_tile_offset.k()]; -+ ptr_D = static_cast(params.ptr_D)[threadblock_tile_offset.k()]; -+ if (ptr_Tensor) { -+ ptr_Tensor = static_cast(params.ptr_Tensor)[threadblock_tile_offset.k()]; -+ } -+ if (ptr_Vector) { -+ ptr_Vector = static_cast(params.ptr_Vector)[threadblock_tile_offset.k()]; -+ } -+ } -+ #endif -+ -+ // Tile iterators loading from source tensors. -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.params_C, -+ ptr_C, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ ptr_D, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Additional tensor to load from -+ typename Epilogue::TensorTileIterator tensor_iterator( -+ params.params_Tensor, -+ // Only the final block outputs Tensor -+ ((params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) && -+ (params.grid_tiled_shape.k() != threadblock_tile_offset.k() + 1)) -+ ? nullptr -+ : ptr_Tensor, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset); -+ -+ // Construct the epilogue -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ #if SPLIT_K_ENABLED -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if ((params.mode == GemmUniversalMode::kGemm) && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_offset.k()) { -+ iterator_C = iterator_D; -+ } -+ -+ semaphore.wait(threadblock_tile_offset.k()); -+ -+ } -+ #endif -+ -+ // Move to appropriate location for this output tile -+ if (ptr_Vector) { -+ ptr_Vector += threadblock_offset.column() + threadblock_tile_offset.m() * params.ldr; -+ } -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue(output_op, -+ // Only the final block uses Vector -+ ((params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) && -+ (params.grid_tiled_shape.k() != threadblock_tile_offset.k() + 1)) -+ ? nullptr -+ : ptr_Vector, -+ iterator_D, -+ accumulators, -+ iterator_C, -+ tensor_iterator, -+ params.problem_size.mn(), -+ threadblock_offset); -+ -+ // -+ // Release the semaphore -+ // -+ -+ #if SPLIT_K_ENABLED -+ if ((params.mode == GemmUniversalMode::kGemm) && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_offset.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ #endif -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_with_k_reduction.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_with_k_reduction.h -new file mode 100644 -index 0000000..8e00e18 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemm_with_k_reduction.h -@@ -0,0 +1,695 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/semaphore.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/gemm/kernel/params_universal_base.h" -+ -+#include "cutlass/trace.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename EpilogueGemmKReduction_, ///! Epilogue -+ typename ThreadblockSwizzle_ ///! Threadblock swizzling function -+> -+struct GemmWithKReduction { -+public: -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using EpilogueGemmKReduction = EpilogueGemmKReduction_; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using ElementC = typename Epilogue::OutputTileIterator::Element; -+ using LayoutC = typename Epilogue::OutputTileIterator::Layout; -+ using LayoutGemmKReduction = cutlass::layout::PitchLinear; -+ -+ static ComplexTransform const kTransformA = Mma::kTransformA; -+ static ComplexTransform const kTransformB = Mma::kTransformB; -+ using Operator = typename Mma::Operator; -+ -+ using OperatorClass = typename Mma::Operator::OperatorClass; -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename Mma::Operator::Shape; -+ using InstructionShape = typename Mma::Policy::Operator::InstructionShape; -+ using ArchTag = typename Mma::ArchTag; -+ -+ static int const kStages = Mma::kStages; -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ /// Split-K preserves splits that are 128b aligned -+ static int const kSplitKAlignment = const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value); -+ -+ static int const kReduceKForA = Mma::kReduceKForA; -+ -+ // -+ // Structures -+ // -+ -+ /// Argument structure -+ struct Arguments : UniversalArgumentsBase -+ { -+ // -+ // Data members -+ // -+ -+ typename EpilogueOutputOp::Params epilogue; -+ -+ void const * ptr_A; -+ void const * ptr_B; -+ void const * ptr_C; -+ void * ptr_D; -+ void * ptr_gemm_k_reduction; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C; -+ int64_t batch_stride_gemm_k_reduction; -+ -+ typename LayoutA::Stride::Index lda; -+ typename LayoutB::Stride::Index ldb; -+ typename LayoutC::Stride::Index ldc; -+ typename LayoutC::Stride::Index ldd; -+ typename LayoutGemmKReduction::Stride::Index ld_gemm_k_reduction; -+ -+ // -+ // Methods -+ // -+ -+ Arguments() : -+ ptr_A(nullptr), -+ ptr_B(nullptr), -+ ptr_C(nullptr), -+ ptr_D(nullptr), -+ ptr_gemm_k_reduction(nullptr) -+ {} -+ -+ /// constructs an arguments structure -+ Arguments( -+ GemmUniversalMode mode, -+ GemmCoord problem_size, -+ int batch_count, -+ typename EpilogueOutputOp::Params epilogue, -+ void const * ptr_A, -+ void const * ptr_B, -+ void const * ptr_C, -+ void * ptr_D, -+ void * ptr_gemm_k_reduction, -+ int64_t batch_stride_A, -+ int64_t batch_stride_B, -+ int64_t batch_stride_C, -+ int64_t batch_stride_D, -+ int64_t batch_stride_gemm_k_reduction, -+ typename LayoutA::Stride::Index lda, -+ typename LayoutB::Stride::Index ldb, -+ typename LayoutC::Stride::Index ldc, -+ typename LayoutC::Stride::Index ldd, -+ typename LayoutGemmKReduction::Stride::Index ld_gemm_k_reduction) -+ : -+ UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), -+ epilogue(epilogue), -+ ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), ptr_gemm_k_reduction(ptr_gemm_k_reduction), -+ batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_gemm_k_reduction(batch_stride_gemm_k_reduction), -+ lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), ld_gemm_k_reduction(ld_gemm_k_reduction) -+ { -+ CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size); -+ } -+ -+ /// Returns arguments for the transposed problem -+ Arguments transposed_problem() const { -+ Arguments args(*this); -+ -+ std::swap(args.problem_size.m(), args.problem_size.n()); -+ std::swap(args.ptr_A, args.ptr_B); -+ std::swap(args.lda, args.ldb); -+ std::swap(args.batch_stride_A, args.batch_stride_B); -+ -+ return args; -+ } -+ }; -+ -+ -+ // -+ // Structure for precomputing values in host memory and passing to kernels -+ // -+ -+ /// Parameters structure -+ struct Params : UniversalParamsBase< -+ ThreadblockSwizzle, -+ ThreadblockShape, -+ ElementA, -+ ElementB, -+ ElementC> -+ { -+ using ParamsBase = UniversalParamsBase< -+ ThreadblockSwizzle, -+ ThreadblockShape, -+ ElementA, -+ ElementB, -+ ElementC>; -+ -+ // -+ // Data members -+ // -+ -+ typename Mma::IteratorA::Params params_A; -+ typename Mma::IteratorB::Params params_B; -+ typename Epilogue::OutputTileIterator::Params params_C; -+ typename Epilogue::OutputTileIterator::Params params_D; -+ -+ typename EpilogueOutputOp::Params output_op; -+ -+ void * ptr_A; -+ void * ptr_B; -+ void * ptr_C; -+ void * ptr_D; -+ void * ptr_gemm_k_reduction; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C; -+ int64_t batch_stride_gemm_k_reduction; -+ -+ // -+ // Host dispatch API -+ // -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Constructor -+ Params( -+ Arguments const &args, /// GEMM application arguments -+ int device_sms, /// Number of SMs on the device -+ int sm_occupancy) /// Kernel SM occupancy (in thread blocks) -+ : -+ ParamsBase(args, device_sms, sm_occupancy), -+ params_A(args.lda), -+ params_B(args.ldb), -+ params_C(args.ldc), -+ params_D(args.ldd), -+ output_op(args.epilogue), -+ ptr_A(const_cast(args.ptr_A)), -+ ptr_B(const_cast(args.ptr_B)), -+ ptr_C(const_cast(args.ptr_C)), -+ batch_stride_A(args.batch_stride_A), -+ batch_stride_B(args.batch_stride_B), -+ batch_stride_C(args.batch_stride_C), -+ batch_stride_gemm_k_reduction(args.batch_stride_gemm_k_reduction), -+ ptr_D(args.ptr_D), -+ ptr_gemm_k_reduction(args.ptr_gemm_k_reduction) -+ {} -+ -+ /// Assign and initialize the specified workspace buffer. Assumes -+ /// the memory allocated to workspace is at least as large as get_workspace_size(). -+ Status init_workspace( -+ void *workspace, -+ cudaStream_t stream = nullptr) -+ { -+ CUTLASS_TRACE_HOST("GemmUniversal::Params::Params() - problem_size: " << this->problem_size); -+ -+ if (this->mode == GemmUniversalMode::kGemmSplitKParallel) { -+ ptr_D = workspace; -+ ptr_gemm_k_reduction = static_cast(workspace) -+ + sizeof(ElementC) * size_t(this->batch_stride_D) * size_t(this->grid_tiled_shape.k()); -+ -+ return Status::kSuccess; -+ } -+ -+ return ParamsBase::init_workspace(workspace, stream); -+ } -+ -+ /// Returns the workspace size (in bytes) needed for this problem geometry -+ size_t get_workspace_size() const -+ { -+ size_t workspace_bytes = ParamsBase::get_workspace_size(); -+ -+ if (this->mode == GemmUniversalMode::kGemmSplitKParallel) -+ { -+ // Split-K parallel always requires a temporary workspace -+ workspace_bytes += -+ sizeof(ElementC) * -+ size_t(batch_stride_gemm_k_reduction) * -+ size_t(this->grid_tiled_shape.k()); -+ } -+ -+ return workspace_bytes; -+ } -+ -+ /// Lightweight update given a subset of arguments. Problem geometry is assumed -+ /// to remain the same. -+ void update(Arguments const &args) -+ { -+ ptr_A = const_cast(args.ptr_A); -+ ptr_B = const_cast(args.ptr_B); -+ ptr_C = const_cast(args.ptr_C); -+ ptr_D = args.ptr_D; -+ ptr_gemm_k_reduction = args.ptr_gemm_k_reduction; -+ -+ output_op = args.epilogue; -+ -+ CUTLASS_TRACE_HOST("GemmUniversal::Params::update()"); -+ } -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+ -+public: -+ -+ // -+ // Host dispatch API -+ // -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement( -+ cutlass::gemm::GemmCoord const & problem_size) { -+ -+ CUTLASS_TRACE_HOST("GemmUniversal::can_implement()"); -+ -+ static int const kAlignmentA = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 64 -+ : Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 64 -+ : Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 64 -+ : Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ bool isAMisaligned = false; -+ bool isBMisaligned = false; -+ bool isCMisaligned = false; -+ -+ if (platform::is_same::value) { -+ isAMisaligned = problem_size.k() % kAlignmentA; -+ } else if (platform::is_same::value) { -+ isAMisaligned = problem_size.m() % kAlignmentA; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isAMisaligned = problem_size.k() % kAlignmentA; -+ } -+ -+ if (platform::is_same::value) { -+ isBMisaligned = problem_size.n() % kAlignmentB; -+ } else if (platform::is_same::value) { -+ isBMisaligned = problem_size.k() % kAlignmentB; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isBMisaligned = problem_size.k() % kAlignmentB; -+ } -+ -+ if (platform::is_same::value) { -+ isCMisaligned = problem_size.n() % kAlignmentC; -+ } else if (platform::is_same::value) { -+ isCMisaligned = problem_size.m() % kAlignmentC; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isCMisaligned = problem_size.n() % kAlignmentC; -+ } -+ -+ if (isAMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for operand A"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (isBMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for operand B"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (isCMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for operand C"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ CUTLASS_TRACE_HOST(" returning kSuccess"); -+ -+ return Status::kSuccess; -+ } -+ -+ -+ static Status can_implement(Arguments const &args) { -+ return can_implement(args.problem_size); -+ } -+ -+ -+public: -+ -+ // -+ // Device-only API -+ // -+ -+ // Factory invocation -+ CUTLASS_DEVICE -+ static void invoke( -+ Params const ¶ms, -+ SharedStorage &shared_storage) -+ { -+ GemmWithKReduction op; -+ op(params, shared_storage); -+ } -+ -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ int offset_k = 0; -+ int problem_size_k = params.problem_size.k(); -+ -+ ElementA *ptr_A = static_cast(params.ptr_A); -+ ElementB *ptr_B = static_cast(params.ptr_B); -+ -+ // -+ // Fetch pointers based on mode. -+ // -+ if (params.mode == GemmUniversalMode::kGemm || -+ params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ -+ if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { -+ -+ problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; -+ } -+ -+ offset_k = threadblock_tile_offset.k() * params.gemm_k_size; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; -+ ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; -+ ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; -+ } -+ -+ __syncthreads(); -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ offset_k, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ offset_k, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ }; -+ -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.params_A, -+ ptr_A, -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B( -+ params.params_B, -+ ptr_B, -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_B); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = canonical_warp_idx(); -+ -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ typename Mma::FragmentReduction gemm_k_accumulators; -+ -+ gemm_k_accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma( -+ gemm_k_iterations, -+ accumulators, -+ iterator_A, -+ iterator_B, -+ accumulators, -+ gemm_k_accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ ElementC *ptr_C = static_cast(params.ptr_C); -+ ElementC *ptr_D = static_cast(params.ptr_D); -+ ElementC *ptr_gemm_k_reduction = static_cast(params.ptr_gemm_k_reduction); -+ -+ // -+ // Fetch pointers based on mode. -+ // -+ -+ // Construct the semaphore. -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ if (params.mode == GemmUniversalMode::kGemm) { -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. -+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } -+ } -+ else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ ptr_gemm_k_reduction += threadblock_tile_offset.k() * params.batch_stride_gemm_k_reduction; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_C += threadblock_tile_offset.k() * params.batch_stride_C; -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_C = static_cast(params.ptr_C)[threadblock_tile_offset.k()]; -+ ptr_D = static_cast(params.ptr_D)[threadblock_tile_offset.k()]; -+ } -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.params_C, -+ ptr_C, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ ptr_D, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_offset.k()) { -+ iterator_C = iterator_D; -+ } -+ -+ semaphore.wait(threadblock_tile_offset.k()); -+ -+ } -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue( -+ output_op, -+ iterator_D, -+ accumulators, -+ iterator_C); -+ -+ if ((kReduceKForA && threadblock_tile_offset.n() == 0) -+ || (!kReduceKForA && threadblock_tile_offset.m() == 0)) { -+ -+ int warp_idx_mn = warp_idx % (Mma::Base::WarpCount::kM * Mma::Base::WarpCount::kN); -+ int warp_idx_m = warp_idx_mn % Mma::Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Mma::Base::WarpCount::kM; -+ -+ if ((kReduceKForA && warp_idx_n == 0) -+ || (!kReduceKForA && warp_idx_m == 0)) { -+ -+ int reduction_warp_idx = kReduceKForA ? warp_idx_m : warp_idx_n; -+ int reduction_threadblock_offset = kReduceKForA ? threadblock_tile_offset.m() : -+ threadblock_tile_offset.n(); -+ int reduction_vector_size = kReduceKForA ? params.problem_size.m() -+ : params.problem_size.n(); -+ EpilogueGemmKReduction epilogue_gemm_k_reduction(thread_idx, -+ reduction_warp_idx, -+ lane_idx, -+ reduction_threadblock_offset, -+ ptr_gemm_k_reduction); -+ epilogue_gemm_k_reduction( -+ reduction_vector_size, -+ gemm_k_accumulators, -+ params.mode == GemmUniversalMode::kGemm -+ && (params.grid_tiled_shape.k() > 1) -+ && (threadblock_tile_offset.k() > 0)); -+ } -+ } -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_offset.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemv.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemv.h -new file mode 100644 -index 0000000..acde3d5 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemv.h -@@ -0,0 +1,289 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/tensor_ref.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/layout/matrix.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA_, -+ typename LayoutA_, -+ typename ElementB_, -+ typename ElementC_, -+ typename ElementAccumulator_, -+ typename EpilogueOutputOp_ -+> -+struct Gemv { -+public: -+ -+ using ElementA = ElementA_; -+ using LayoutA = layout::ColumnMajor; -+ using TensorRefA = TensorRef; -+ -+ static_assert(platform::is_same::value, -+ "Only supported for column-major A matrix"); -+ -+ using ElementB = ElementB_; -+ using ElementC = ElementC_; -+ -+ using ElementAccumulator = ElementAccumulator_; -+ using EpilogueOutputOp = EpilogueOutputOp_; -+ -+ static ComplexTransform const kTransformA = ComplexTransform::kNone; -+ static ComplexTransform const kTransformB = ComplexTransform::kNone; -+ -+ static int const kThreadCount = 32; -+ static int const kStages = 1; -+ -+ static int const kAlignmentA = 1; -+ static int const kAlignmentB = 1; -+ static int const kAlignmentC = 1; -+ -+ // -+ // Structures -+ // -+ -+ /// Argument structure -+ struct Arguments { -+ MatrixCoord problem_size; -+ int32_t batch_count; -+ typename EpilogueOutputOp::Params output_op; -+ -+ TensorRefA ref_A; -+ -+ ElementB const *ptr_B; -+ ElementC const *ptr_C; -+ ElementC *ptr_D; -+ -+ int64_t inc_B; -+ int64_t inc_C; -+ int64_t inc_D; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C; -+ int64_t batch_stride_D; -+ -+ // -+ // Methods -+ // -+ -+ Arguments(): batch_count(0) { } -+ -+ Arguments( -+ MatrixCoord problem_size, -+ int batch_count, -+ typename EpilogueOutputOp::Params output_op, -+ TensorRefA ref_A, -+ void const * ptr_B, -+ void const * ptr_C, -+ void * ptr_D, -+ int64_t inc_B, -+ int64_t inc_C, -+ int64_t inc_D, -+ int64_t batch_stride_A, -+ int64_t batch_stride_B, -+ int64_t batch_stride_C, -+ int64_t batch_stride_D -+ ): -+ problem_size(problem_size), -+ batch_count(batch_count), -+ output_op(output_op), -+ ref_A(ref_A), -+ ptr_B(static_cast(ptr_B)), -+ ptr_C(static_cast(ptr_C)), -+ ptr_D(static_cast(ptr_D)), -+ inc_B(inc_B), -+ inc_C(inc_C), -+ inc_D(inc_D), -+ batch_stride_A(batch_stride_A), -+ batch_stride_B(batch_stride_B), -+ batch_stride_C(batch_stride_C), -+ batch_stride_D(batch_stride_D) -+ { } -+ -+ Arguments( -+ MatrixCoord problem_size, -+ typename EpilogueOutputOp::Params output_op, -+ TensorRefA ref_A, -+ void const * ptr_B, -+ void const * ptr_C, -+ void * ptr_D, -+ int64_t inc_B, -+ int64_t inc_C, -+ int64_t inc_D -+ ): -+ Arguments( -+ problem_size, -+ 1, -+ output_op, -+ ref_A, -+ ptr_B, -+ ptr_C, -+ ptr_D, -+ inc_B, -+ inc_C, -+ inc_D, -+ 1, -+ 1, -+ 1, -+ 1) -+ { } -+ -+ Status update(Arguments const &args) { -+ output_op = args.output_op; -+ ref_A = ref_A; -+ ptr_B = args.ptr_B; -+ ptr_C = args.ptr_C; -+ ptr_D = args.ptr_D; -+ -+ return Status::kSuccess; -+ } -+ }; -+ -+ using Params = Arguments; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ -+ }; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_DEVICE -+ Gemv() { } -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement(cutlass::MatrixCoord const & problem_size) { -+ -+ return Status::kSuccess; -+ } -+ -+ static Status can_implement(Arguments const &args) { -+ return can_implement(args.problem_size); -+ } -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Loop over batch indices -+ for (int batch_idx = blockIdx.z; batch_idx < params.batch_count; batch_idx += gridDim.z) { -+ -+ int i = blockIdx.x * kThreadCount + threadIdx.x; -+ -+ ElementA const *ptr_A = params.ref_A.data() + i; -+ ElementB const *ptr_B = params.ptr_B; -+ -+ ptr_A += batch_idx * params.batch_stride_A; -+ ptr_B += batch_idx * params.batch_stride_B; -+ -+ ElementAccumulator accum = ElementAccumulator(); -+ -+ // Compute inner product -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int k = 0; k < params.problem_size.column(); ++k) { -+ -+ // Fetch from A -+ ElementA a = ElementA(); -+ if (i < params.problem_size.row()) { -+ a = *ptr_A; -+ } -+ ptr_A += params.ref_A.stride(0); -+ -+ // Fetch from B -+ ElementB b = *ptr_B; -+ ptr_B += params.inc_B; -+ -+ // Math -+ accum += ElementAccumulator(a) * ElementAccumulator(b); -+ } -+ -+ // -+ // Epilogue phase -+ // -+ -+ ElementC const *ptr_C = params.ptr_C + i * params.inc_C + batch_idx * params.batch_stride_C; -+ ElementC *ptr_D = params.ptr_D + i * params.inc_D + batch_idx * params.batch_stride_D; -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ typename EpilogueOutputOp::FragmentAccumulator accum_fragment; -+ typename EpilogueOutputOp::FragmentOutput source_fragment; -+ typename EpilogueOutputOp::FragmentOutput output_fragment; -+ -+ accum_fragment[0] = accum; -+ -+ if (i < params.problem_size.row()) { -+ if (output_op.is_source_needed()) { -+ source_fragment[0] = *ptr_C; -+ output_fragment = output_op(accum_fragment, source_fragment); -+ } -+ else { -+ output_fragment = output_op(accum_fragment); -+ } -+ -+ *ptr_D = output_fragment[0]; -+ } -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/gemv_batched_strided.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemv_batched_strided.h -new file mode 100755 -index 0000000..613a279 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/gemv_batched_strided.h -@@ -0,0 +1,244 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/array.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+namespace detail -+{ -+ template -+ struct GemvBatchedStridedEpilogueScaling -+ { -+ ElementAlphaBeta const & alpha; -+ ElementAlphaBeta const & beta; -+ -+ CUTLASS_DEVICE -+ GemvBatchedStridedEpilogueScaling(ElementAlphaBeta& alpha_, ElementAlphaBeta& beta_) : -+ alpha(alpha_), beta(beta_) -+ { } -+ -+ template -+ CUTLASS_DEVICE -+ void operator()(FragmentAccumulator& accumulators, -+ FragmentCD const& fragment_C, -+ FragmentCD& fragment_D) const -+ { -+ using AccType = typename FragmentAccumulator::value_type; -+ using CDType = typename FragmentCD::value_type; -+ -+ static_assert(FragmentCD::kElements == FragmentAccumulator::kElements, -+ "Mistmatch in fragment sizes."); -+ -+ for (int i = 0; i < FragmentCD::kElements; ++i) -+ { -+ if (BetaIsZero) -+ { -+ fragment_D[i] = CDType(accumulators[i] * AccType(alpha)); -+ } -+ else -+ { -+ fragment_D[i] = CDType(accumulators[i] * AccType(alpha) -+ + AccType(fragment_C[i]) * AccType(beta)); -+ } -+ } -+ } -+ }; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+CUTLASS_DEVICE void GemvBatchedStridedDevice( -+ cutlass::gemm::BatchedGemmCoord problem_size, -+ ElementAlphaBeta alpha, -+ ElementAlphaBeta beta, -+ typename GemvKernel::IteratorA::TensorRef ref_A, -+ typename GemvKernel::IteratorA::TensorRef::LongIndex lda, -+ typename GemvKernel::IteratorB::TensorRef ref_B, -+ typename GemvKernel::IteratorB::TensorRef::LongIndex ldb, -+ typename GemvKernel::IteratorCD::TensorRef ref_C, -+ typename GemvKernel::IteratorCD::TensorRef::LongIndex ldc, -+ typename GemvKernel::IteratorCD::TensorRef ref_D, -+ typename GemvKernel::IteratorCD::TensorRef::LongIndex ldd) -+{ -+ using ThreadBlockGemv = typename GemvKernel::ThreadBlockGemv; -+ using ThreadBlockSwizzle = typename GemvKernel::ThreadBlockSwizzle; -+ using EpilogueScale = detail::GemvBatchedStridedEpilogueScaling; -+ -+ ThreadBlockSwizzle swizzler; -+ -+ // Compute initial location in logical coordinates -+ BatchedGemmCoord tb_offset = swizzler.get_tile_offset(); -+ int const batch_idx = swizzler.get_batch_idx(); -+ -+ // Offset to the batch -+ ref_A.add_pointer_offset(batch_idx*lda); -+ ref_B.add_pointer_offset(batch_idx*ldb); -+ -+ // Construct iterators to A and B operands -+ typename GemvKernel::IteratorA::Params params_A(ref_A.layout()); -+ typename GemvKernel::IteratorA iterator_A( -+ params_A, -+ ref_A.data(), -+ { 1, problem_size.k() }, -+ 0, -+ { 0, 0 }); -+ -+ typename GemvKernel::IteratorB::Params params_B(ref_B.layout()); -+ typename GemvKernel::IteratorB iterator_B( -+ params_B, -+ ref_B.data(), -+ { problem_size.k(), problem_size.n() }, -+ threadIdx.x, -+ { 0, tb_offset.n()*ThreadBlockGemv::Shape::kN }); -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ ThreadBlockGemv mma; -+ -+ typename ThreadBlockGemv::FragmentC accumulators; -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped gemv -+ mma(problem_size.mnk(), accumulators, iterator_A, iterator_B, accumulators); -+ -+ // -+ // Epilogue (TODO: Epiloge as template argument) -+ // -+ typename GemvKernel::FragmentCD fragment_CD; -+ -+ // Load C (skip if beta is zero) -+ if (!BetaIsZero) -+ { -+ tb_offset = swizzler.get_tile_offset(); -+ ref_C.add_pointer_offset(batch_idx*ldc); -+ typename GemvKernel::IteratorCD::Params params_C(ref_C.layout()); -+ typename GemvKernel::IteratorCD iterator_C( -+ params_C, -+ ref_C.data(), -+ { 1, problem_size.n() }, -+ threadIdx.x, -+ { 0, tb_offset.n()*ThreadBlockGemv::Shape::kN }); -+ iterator_C.load(fragment_CD); -+ } -+ -+ // Apply alpha/beta scaling -+ EpilogueScale epilogue_scale(alpha, beta); -+ epilogue_scale(accumulators, fragment_CD, fragment_CD); -+ -+ // Store D -+ tb_offset = swizzler.get_tile_offset(); -+ ref_D.add_pointer_offset(batch_idx*ldd); -+ typename GemvKernel::IteratorCD::Params params_D(ref_D.layout()); -+ typename GemvKernel::IteratorCD iterator_D( -+ params_D, -+ ref_D.data(), -+ { 1, problem_size.n() }, -+ threadIdx.x, -+ { 0, tb_offset.n()*ThreadBlockGemv::Shape::kN }); -+ iterator_D.store(fragment_CD); -+} -+ -+template -+__global__ void GemvBatchedStrided( -+ cutlass::gemm::BatchedGemmCoord problem_size, -+ ElementAlphaBeta alpha, -+ ElementAlphaBeta beta, -+ typename GemvKernel::IteratorA::TensorRef ref_A, -+ typename GemvKernel::IteratorA::TensorRef::LongIndex lda, -+ typename GemvKernel::IteratorB::TensorRef ref_B, -+ typename GemvKernel::IteratorB::TensorRef::LongIndex ldb, -+ typename GemvKernel::IteratorCD::TensorRef ref_C, -+ typename GemvKernel::IteratorCD::TensorRef::LongIndex ldc, -+ typename GemvKernel::IteratorCD::TensorRef ref_D, -+ typename GemvKernel::IteratorCD::TensorRef::LongIndex ldd) -+{ -+ GemvBatchedStridedDevice( -+ problem_size, alpha, beta, ref_A, lda, ref_B, ldb, ref_C, ldc, ref_D, ldd -+ ); -+} -+ -+template -+__global__ void GemvBatchedStrided( -+ cutlass::gemm::BatchedGemmCoord problem_size, -+ ElementAlphaBeta alpha, -+ typename GemvKernel::IteratorA::TensorRef ref_A, -+ typename GemvKernel::IteratorA::TensorRef::LongIndex lda, -+ typename GemvKernel::IteratorB::TensorRef ref_B, -+ typename GemvKernel::IteratorB::TensorRef::LongIndex ldb, -+ typename GemvKernel::IteratorCD::TensorRef ref_D, -+ typename GemvKernel::IteratorCD::TensorRef::LongIndex ldd) -+{ -+ GemvBatchedStridedDevice( -+ problem_size, alpha, ElementAlphaBeta(0), ref_A, lda, ref_B, ldb, ref_D, ldd, ref_D, ldd -+ ); -+} -+ -+template -+__global__ void GemvBatchedStrided( -+ cutlass::gemm::BatchedGemmCoord problem_size, -+ typename GemvKernel::IteratorA::TensorRef ref_A, -+ typename GemvKernel::IteratorA::TensorRef::LongIndex lda, -+ typename GemvKernel::IteratorB::TensorRef ref_B, -+ typename GemvKernel::IteratorB::TensorRef::LongIndex ldb, -+ typename GemvKernel::IteratorCD::TensorRef ref_D, -+ typename GemvKernel::IteratorCD::TensorRef::LongIndex ldd) -+{ -+ using ElementAlphaBeta = typename GemvKernel::IteratorCD::Element; -+ GemvBatchedStridedDevice( -+ problem_size, ElementAlphaBeta(1), ElementAlphaBeta(0), ref_A, lda, ref_B, ldb, ref_D, ldd, ref_D, ldd -+ ); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/grouped_problem_visitor.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/grouped_problem_visitor.h -new file mode 100644 -index 0000000..d9f0249 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/grouped_problem_visitor.h -@@ -0,0 +1,464 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Base scheduler for grouped problems -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Enumerated type describing the type of scheduling to perform for the ProblemVisitor -+enum class GroupScheduleMode { -+ // Perform all scheduling on device -+ kDeviceOnly, -+ // Precompute on the host the full sequence of problems to access -+ kHostPrecompute -+}; -+ -+/// Visitor class to abstract away the algorithm for iterating over tiles -+template -+struct BaseGroupedProblemVisitor { -+ using ThreadblockShape = ThreadblockShape_; -+ -+ struct ProblemInfo { -+ static int32_t const kNoPrefetchEntry = -1; -+ int32_t problem_idx; -+ int32_t problem_start; -+ -+ CUTLASS_DEVICE -+ ProblemInfo() : problem_idx(kNoPrefetchEntry), problem_start(kNoPrefetchEntry) {} -+ -+ CUTLASS_DEVICE -+ ProblemInfo(int32_t problem_idx_, int32_t problem_start_) : -+ problem_idx(problem_idx_), problem_start(problem_start_) {} -+ }; -+ -+ struct Params { -+ cutlass::gemm::GemmCoord const *problem_sizes; -+ int32_t problem_count; -+ void const *workspace; -+ int32_t tile_count; -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ Params(): problem_sizes(nullptr), problem_count(0), workspace(nullptr), tile_count(0) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ Params( -+ cutlass::gemm::GemmCoord const *problem_sizes, -+ int32_t problem_count, -+ void const *workspace = nullptr, -+ int32_t tile_count = 0 -+ ): -+ problem_sizes(problem_sizes), -+ problem_count(problem_count), -+ workspace(workspace), -+ tile_count(tile_count) -+ {} -+ -+ }; -+ -+ Params const ¶ms; -+ int32_t tile_idx; -+ int32_t problem_tile_start; -+ int32_t problem_idx; -+ -+ // -+ // Methods -+ // -+ CUTLASS_DEVICE -+ BaseGroupedProblemVisitor( -+ Params const ¶ms_, -+ int32_t block_idx -+ ): -+ params(params_), -+ tile_idx(block_idx), -+ problem_tile_start(0), -+ problem_idx(0) -+ {} -+ -+ /// Get the grid shape -+ CUTLASS_HOST_DEVICE -+ static cutlass::gemm::GemmCoord grid_shape(const cutlass::gemm::GemmCoord& problem) { -+ return ProblemSizeHelper::grid_shape(problem); -+ } -+ -+ /// Gets the global tile index -+ CUTLASS_HOST_DEVICE -+ int32_t tile_index() const { -+ return tile_idx; -+ } -+ -+ /// Gets the index of the problem -+ CUTLASS_HOST_DEVICE -+ int32_t problem_index() const { -+ return problem_idx; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ int32_t threadblock_idx() const { -+ return tile_idx - problem_tile_start; -+ } -+ -+ CUTLASS_DEVICE -+ void advance(int32_t grid_size) { -+ tile_idx += grid_size; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem) { -+ ProblemSizeHelper::possibly_transpose_problem(problem); -+ } -+ -+ /// Returns the problem size for the current problem -+ CUTLASS_HOST_DEVICE -+ cutlass::gemm::GemmCoord problem_size() const { -+ GemmCoord problem = params.problem_sizes[problem_idx]; -+ ProblemSizeHelper::possibly_transpose_problem(problem); -+ return problem; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static int32_t tile_count(const cutlass::gemm::GemmCoord& grid) { -+ return ProblemSizeHelper::tile_count(grid); -+ } -+ -+ static int32_t group_tile_count(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, int32_t problem_count) { -+ int32_t total_tiles = 0; -+ for (int32_t i = 0; i < problem_count; ++i) { -+ auto problem = host_problem_sizes_ptr[i]; -+ possibly_transpose_problem(problem); -+ auto grid = grid_shape(problem); -+ total_tiles += tile_count(grid); -+ } -+ -+ return total_tiles; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ProblemSizeHelper, -+ typename ThreadblockShape, -+ GroupScheduleMode GroupScheduleMode_, -+ int PrefetchTileCount, -+ int ThreadCount -+> -+struct GroupedProblemVisitor; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// ProblemVisitor that performs all scheduling on device -+// -+template -+struct GroupedProblemVisitor: public BaseGroupedProblemVisitor { -+ using Base = BaseGroupedProblemVisitor; -+ using Params = typename Base::Params; -+ static int const kThreadCount = ThreadCount; -+ static bool const kRequiresPrecomputation = false; -+ static int const kThreadsPerWarp = 32; -+ -+ struct SharedStorage {}; -+ -+ // Final tile of the problem loaded by this thread. Each thread will hold -+ // a separate value. -+ int32_t problem_ending_tile; -+ -+ SharedStorage &shared_storage; -+ -+ // -+ // Methods -+ // -+ CUTLASS_DEVICE -+ GroupedProblemVisitor( -+ Params const ¶ms_, -+ SharedStorage &shared_storage_, -+ int32_t block_idx -+ ): Base(params_, block_idx), -+ problem_ending_tile(0), -+ shared_storage(shared_storage_) -+ { -+ this->problem_idx = -1 * kThreadsPerWarp; -+ this->problem_tile_start = 0; -+ } -+ -+ CUTLASS_DEVICE -+ bool next_tile() { -+ // Check whether the tile to compute is within the range of the current problem. -+ int32_t problem_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, this->problem_idx % kThreadsPerWarp); -+ if (this->tile_idx < problem_tile_end) { -+ return true; -+ } -+ -+ // Check whether the tile to compute is within the current group of problems fetched by the warp. -+ // The last tile for this group is the final tile of the problem held by the final thread in the warp. -+ int32_t group_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp-1); -+ -+ // Keep the starting problem for this group in `problem_idx`. This is done to reduce -+ // register pressure. The starting problem for this group is simply the first problem -+ // in the group most recently fetched by the warp. -+ int32_t &group_problem_start = this->problem_idx; -+ group_problem_start = (this->problem_idx / kThreadsPerWarp) * kThreadsPerWarp; -+ -+ // Keep the starting tile for this group in `problem_tile_start`. This is done to reduce -+ // register pressure. -+ int32_t &group_tile_start = this->problem_tile_start; -+ -+ // Each thread in the warp processes a separate problem to advance until -+ // reaching a problem whose starting tile is less less than tile_idx. -+ while (group_tile_end <= this->tile_idx) { -+ group_problem_start += kThreadsPerWarp; -+ if (group_problem_start > this->params.problem_count) { -+ return false; -+ } -+ -+ // Since `group_tile_start` is a reference to `this->problem_tile_start`, this -+ // also sets `this->problem_tile_start`. The fact that `this->problem_tile_start` -+ // is also set here is used later in `next_tile`. -+ group_tile_start = group_tile_end; -+ -+ int lane_idx = threadIdx.x % kThreadsPerWarp; -+ int32_t lane_problem = group_problem_start + lane_idx; -+ -+ // Compute the number of tiles in the problem assigned to each thread. -+ problem_ending_tile = 0; -+ if (lane_problem < this->params.problem_count) { -+ cutlass::gemm::GemmCoord problem = this->params.problem_sizes[lane_problem]; -+ this->possibly_transpose_problem(problem); -+ cutlass::gemm::GemmCoord grid = this->grid_shape(problem); -+ problem_ending_tile = this->tile_count(grid); -+ } -+ -+ // Compute a warp-wide inclusive prefix sum to compute the ending tile index of -+ // each thread's problem. -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 1; i < kThreadsPerWarp; i <<= 1) { -+ int32_t val = __shfl_up_sync(0xffffffff, problem_ending_tile, i); -+ if (lane_idx >= i) { -+ problem_ending_tile += val; -+ } -+ } -+ -+ // The total tile count for this group is now in the final position of the prefix sum -+ int32_t tiles_in_group = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp-1); -+ -+ problem_ending_tile += group_tile_start; -+ group_tile_end += tiles_in_group; -+ } -+ -+ // The next problem to process is the first one that does not have ending tile position -+ // that is greater than or equal to tile index. -+ int32_t problem_idx_in_group = -+ __popc(__ballot_sync(0xffffffff, problem_ending_tile <= this->tile_idx)); -+ -+ this->problem_idx = group_problem_start + problem_idx_in_group; -+ -+ // The starting tile for this problem is the ending tile of the previous problem. In cases -+ // where `problem_idx_in_group` is the first problem in the group, we do not need to reset -+ // `problem_tile_start`, because it is set to the previous group's ending tile in the while -+ // loop above. -+ if (problem_idx_in_group > 0) { -+ this->problem_tile_start = __shfl_sync(0xffffffff, problem_ending_tile, problem_idx_in_group - 1); -+ } -+ -+ return true; -+ } -+ -+ static size_t get_workspace_size(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, -+ int32_t problem_count, -+ int32_t block_count) { -+ return 0; -+ } -+ -+ static void host_precompute(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, -+ int32_t problem_count, -+ int32_t block_count, -+ void* host_workspace_ptr) {} -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// Precomputes schedule on host and prefetches into shared memory -+// -+template -+struct GroupedProblemVisitor : public BaseGroupedProblemVisitor { -+ static_assert(PrefetchTileCount > 0, -+ "GroupedProblemVisitor with GroupScheduleMode `kHostPrecompute` currently requires prefetching to shared memory"); -+ -+ using Base = BaseGroupedProblemVisitor; -+ using Params = typename Base::Params; -+ using ProblemInfo = typename Base::ProblemInfo; -+ static bool const kRequiresPrecomputation = true; -+ -+ static int const kPrefetchTileCount = PrefetchTileCount; -+ static int const kThreadCount = ThreadCount; -+ -+ struct SharedStorage { -+ // Sequence of problem IDs and starting tiles to compute -+ cutlass::Array prefetched_problems; -+ }; -+ -+ int32_t tiles_computed; -+ int32_t iterations_per_block; -+ int32_t block_load_start; -+ SharedStorage &shared_storage; -+ ProblemInfo const *problem_info_ptr; -+ -+ // -+ // Methods -+ // -+ CUTLASS_DEVICE -+ GroupedProblemVisitor( -+ Params const ¶ms_, -+ SharedStorage &shared_storage_, -+ int32_t block_idx -+ ): Base(params_, block_idx), -+ tiles_computed(0), -+ shared_storage(shared_storage_), -+ problem_info_ptr(reinterpret_cast(params_.workspace)) -+ { -+ iterations_per_block = (params_.tile_count - 1 + gridDim.x) / gridDim.x; -+ block_load_start = iterations_per_block * block_idx; -+ // Start prefetching the first set of tiles to compute -+ prefetch_tiles(); -+ } -+ -+ CUTLASS_DEVICE -+ bool next_tile() { -+ if (this->tile_idx >= this->params.tile_count) { -+ return false; -+ } -+ -+ int32_t prefetch_idx = (tiles_computed % kPrefetchTileCount); -+ if (prefetch_idx == 0) { -+ // Ensure all previous stores to shared memory have been completed -+ __syncthreads(); -+ } -+ -+ auto problem_info = shared_storage.prefetched_problems[prefetch_idx]; -+ ++tiles_computed; -+ -+ if ((tiles_computed % kPrefetchTileCount) == 0) { -+ // Begin prefetching next set of tiles. Synchronize first to ensure that -+ // we don't overwrite the current buffer while someone else is using it. -+ __syncthreads(); -+ prefetch_tiles(); -+ } -+ -+ this->problem_idx = problem_info.problem_idx; -+ this->problem_tile_start = problem_info.problem_start; -+ -+ return true; -+ } -+ -+ static size_t get_workspace_size(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, -+ int32_t problem_count, -+ int32_t block_count) { -+ int32_t total_tiles = Base::group_tile_count(host_problem_sizes_ptr, problem_count); -+ int32_t entries_per_block = ((total_tiles - 1 + block_count) / block_count); -+ return sizeof(ProblemInfo) * entries_per_block * block_count; -+ } -+#if !defined(__CUDACC_RTC__) -+ static void host_precompute(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, -+ int32_t problem_count, -+ int32_t block_count, -+ void* host_workspace_ptr) { -+ ProblemInfo* host_problem_info_ptr = reinterpret_cast(host_workspace_ptr); -+ int32_t total_tiles = Base::group_tile_count(host_problem_sizes_ptr, problem_count); -+ int32_t entries_per_block = (total_tiles - 1 + block_count) / block_count; -+ -+ int tile = 0; -+ int start_tile = 0; -+ for (int p_idx = 0; p_idx < problem_count; ++p_idx) { -+ auto problem = host_problem_sizes_ptr[p_idx]; -+ Base::possibly_transpose_problem(problem); -+ auto grid = Base::grid_shape(problem); -+ int tiles = Base::tile_count(grid); -+ ProblemInfo problem_info(p_idx, start_tile); -+ for (int i = 0; i < tiles; ++i, ++tile) { -+ host_problem_info_ptr[(entries_per_block * (tile % block_count)) + (tile / block_count)] = problem_info; -+ } -+ start_tile += tiles; -+ } -+ } -+#endif -+private: -+ CUTLASS_DEVICE -+ void prefetch_tiles() { -+ // TODO: Consider changing to use async copies from global to shared mem -+ CUTLASS_PRAGMA_UNROLL -+ for (int32_t i = 0; i < kPrefetchTileCount; i += kThreadCount) { -+ int32_t offset = threadIdx.x + i; -+ if (offset < kPrefetchTileCount && (tiles_computed + offset < iterations_per_block)) { -+ shared_storage.prefetched_problems[offset] = problem_info_ptr[block_load_start + tiles_computed + offset]; -+ } -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/params_universal_base.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/params_universal_base.h -new file mode 100644 -index 0000000..453379d ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/params_universal_base.h -@@ -0,0 +1,245 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Base functionality for common types of universal GEMM kernel parameters -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/trace.h" -+#include "cutlass/gemm/gemm.h" -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Argument structure -+struct UniversalArgumentsBase -+{ -+ // -+ // Data members -+ // -+ -+ GemmUniversalMode mode; -+ GemmCoord problem_size; -+ int batch_count; -+ -+ int64_t batch_stride_D; -+ -+ // -+ // Methods -+ // -+ -+ UniversalArgumentsBase() : -+ mode(GemmUniversalMode::kGemm), -+ batch_count(1), -+ batch_stride_D(0) -+ {} -+ -+ /// constructs an arguments structure -+ UniversalArgumentsBase( -+ GemmUniversalMode mode, -+ GemmCoord problem_size, -+ int batch_count, -+ int64_t batch_stride_D) -+ : -+ mode(mode), -+ problem_size(problem_size), -+ batch_count(batch_count), -+ batch_stride_D(batch_stride_D) -+ { -+ CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size); -+ } -+}; -+ -+ -+/// Parameters structure -+template < -+ typename ThreadblockSwizzle, -+ typename ThreadblockShape, -+ typename ElementA, -+ typename ElementB, -+ typename ElementC> -+struct UniversalParamsBase -+{ -+ // -+ // Data members -+ // -+ -+ GemmCoord problem_size; -+ GemmCoord grid_tiled_shape; -+ int swizzle_log_tile; -+ -+ GemmUniversalMode mode; -+ int batch_count; -+ int gemm_k_size; -+ -+ int64_t batch_stride_D; -+ -+ int *semaphore; -+ -+ -+ // -+ // Host dispatch API -+ // -+ -+ /// Default constructor -+ UniversalParamsBase() = default; -+ -+ -+ /// Constructor -+ UniversalParamsBase( -+ UniversalArgumentsBase const &args, /// GEMM application arguments -+ int device_sms, /// Number of SMs on the device -+ int sm_occupancy) /// Kernel SM occupancy (in thread blocks) -+ : -+ problem_size(args.problem_size), -+ mode(args.mode), -+ batch_count(args.batch_count), -+ batch_stride_D(args.batch_stride_D), -+ semaphore(nullptr) -+ { -+ ThreadblockSwizzle swizzle; -+ -+ // Get GEMM volume in thread block tiles -+ grid_tiled_shape = swizzle.get_tiled_shape( -+ args.problem_size, -+ {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, -+ args.batch_count); -+ -+ swizzle_log_tile = swizzle.get_log_tile(grid_tiled_shape); -+ -+ // Determine extent of K-dimension assigned to each block -+ gemm_k_size = args.problem_size.k(); -+ -+ if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) -+ { -+ int const kAlignK = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); -+ -+ gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); -+ if (gemm_k_size) { -+ grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); -+ } -+ } -+ } -+ -+ -+ /// Returns the workspace size (in bytes) needed for this problem geometry -+ size_t get_workspace_size() const -+ { -+ size_t workspace_bytes = 0; -+ if (mode == GemmUniversalMode::kGemmSplitKParallel) -+ { -+ // Split-K parallel always requires a temporary workspace -+ workspace_bytes = -+ sizeof(ElementC) * -+ size_t(batch_stride_D) * -+ size_t(grid_tiled_shape.k()); -+ } -+ else if (mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1) -+ { -+ // Serial split-K only requires a temporary workspace if the number of partitions along the -+ // GEMM K dimension is greater than one. -+ workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n()); -+ } -+ -+ return workspace_bytes; -+ } -+ -+ -+ /// Assign and initialize the specified workspace buffer. Assumes -+ /// the memory allocated to workspace is at least as large as get_workspace_size(). -+ Status init_workspace( -+ void *workspace, -+ cudaStream_t stream = nullptr) -+ { -+ semaphore = static_cast(workspace); -+ // Zero-initialize entire workspace -+ if (semaphore) -+ { -+ size_t workspace_bytes = get_workspace_size(); -+ -+ CUTLASS_TRACE_HOST(" Initialize " << workspace_bytes << " workspace bytes"); -+ -+ cudaError_t result = cudaMemsetAsync( -+ semaphore, -+ 0, -+ workspace_bytes, -+ stream); -+ -+ if (result != cudaSuccess) { -+ CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result)); -+ return Status::kErrorInternal; -+ } -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ -+ /// Returns the GEMM volume in thread block tiles -+ GemmCoord get_tiled_shape() const -+ { -+ return grid_tiled_shape; -+ } -+ -+ -+ /// Returns the total number of thread blocks to launch -+ int get_grid_blocks() const -+ { -+ dim3 grid_dims = get_grid_dims(); -+ return grid_dims.x * grid_dims.y * grid_dims.z; -+ } -+ -+ -+ /// Returns the grid extents in thread blocks to launch -+ dim3 get_grid_dims() const -+ { -+ return ThreadblockSwizzle().get_grid_shape(grid_tiled_shape); -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/rank_2k_grouped.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/rank_2k_grouped.h -new file mode 100644 -index 0000000..1c840e7 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/rank_2k_grouped.h -@@ -0,0 +1,704 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Grouped Rank2K kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/trace.h" -+#include "cutlass/gemm/kernel/rank_2k_transpose_operands.h" -+#include "cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma1_, ///! Threadblock-scoped matrix multiply-accumulate (A*B^T) -+ typename Mma2_, ///! Threadblock-scoped matrix multiply-accumulate (B*A^T) -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ ComplexTransform OriginalTransformA_, ///! Public-facing transformation on A -+ ComplexTransform OriginalTransformB_, ///! Public-facing transformation on B -+ FillMode FillModeC_, ///! Fill Mode for C (kLower or kUpper) -+ BlasMode BlasMode_, ///! Blas3 computation mode -+ GroupScheduleMode GroupScheduleMode_, ///! Type of scheduling to perform -+ bool Transposed = false -+> -+struct Rank2KGrouped { -+public: -+ -+ using Mma1 = Mma1_; -+ using Mma2 = Mma2_; -+ -+ static_assert(platform::is_same::value && -+ platform::is_same::value, -+ "Kernel-level grouped Rank2K requires that LayoutC be row major."); -+ -+ // Define generic Mma for usecases that use Kernel::Mma -+ using Mma = Mma1_; -+ -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; -+ static bool const kTransposed = Transposed; -+ -+ // Public-facing type definitions related to operand element type, layout, and complex conjugate -+ // operation. Must interact with the 'kTransposed' notion to reflect the original layout, -+ // fill mode, etc. passed in. -+ // -+ // Recall that a Rank2K operation performs (A x BT) + (B x AT) -+ // This is performed via: -+ // Mma1 = (A x BT) -+ // Mma2 = (B x AT) -+ // -+ // However, if C needs to be transposed, then this is changed to the following: -+ // Mma1 = (B x AT) -+ // Mma2 = (A x BT) -+ // -+ // The transformation above is achieved by swapping the Layouts/Elements/Transforms/etc. -+ // of A and B as they are passed into the instantiations of Mma1 and Mma2. -+ // -+ // Now, given access to only Mma1 and Mma2, as well as whether a transposition has occurred, -+ // we wish to retrieve the original Layouts/Elements/etc. for A and B that were passed into -+ // the device-level call. -+ // -+ // The logic to do this (which is made clearer by referencing the above instantiations) is as follows: -+ // LayoutA = kTransposed ? Mma2::LayoutA : Mma1::LayoutA -+ // LayoutB = kTransposed ? Mma1::LayoutA : Mma2::LayoutA -+ // -+ // We achieve this swapping by passing Mma1::*A and Mma2::*B to Rank2KMapArguments: -+ using MapArgumentsA = kernel::detail::Rank2KMapArguments< -+ typename Mma1::IteratorA::Element, -+ typename Mma1::IteratorA::Layout, -+ Mma1::kTransformA, -+ Mma1::IteratorA::AccessType::kElements, -+ typename Mma2::IteratorA::Element, -+ typename Mma2::IteratorA::Layout, -+ Mma2::kTransformA, -+ Mma2::IteratorA::AccessType::kElements, -+ typename Mma1::LayoutC, -+ FillModeC_, -+ kTransposed -+ >; -+ -+ using ElementA = typename MapArgumentsA::ElementA; -+ using LayoutA = typename MapArgumentsA::LayoutA; -+ static int const kAlignmentA = MapArgumentsA::kAlignmentA; -+ -+ using MapArgumentsB = kernel::detail::Rank2KMapArguments< -+ typename Mma2::IteratorA::Element, -+ typename Mma2::IteratorA::Layout, -+ Mma2::kTransformA, -+ Mma2::IteratorA::AccessType::kElements, -+ typename Mma1::IteratorA::Element, -+ typename Mma1::IteratorA::Layout, -+ Mma1::kTransformA, -+ Mma1::IteratorA::AccessType::kElements, -+ typename Mma2::LayoutC, -+ FillModeC_, -+ kTransposed -+ >; -+ -+ using ElementB = typename MapArgumentsB::ElementA; -+ using LayoutB = typename MapArgumentsB::LayoutA; -+ static int const kAlignmentB = MapArgumentsB::kAlignmentA; -+ -+ // Use the user-provided TransformA and TransformB, rather than those -+ // resulting from MapArguments, because Mma1 and Mma2 may have different -+ // complex transforms than those passed in by the user. -+ // (See kernel/rank_2k_complex.h for an example of this) -+ static cutlass::ComplexTransform const kTransformA = OriginalTransformA_; -+ static cutlass::ComplexTransform const kTransformB = OriginalTransformB_; -+ -+ using ElementC = typename Epilogue::OutputTileIterator::Element; -+ using LayoutC = typename MapArgumentsA::LayoutC; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ static FillMode const kFillModeC = MapArgumentsA::kFillModeC; -+ -+ // Common type definitions for Mma1 and Mma2 -+ using Operator = typename Mma1::Operator; -+ using OperatorClass = typename Mma1::Operator::OperatorClass; -+ using ThreadblockShape = typename Mma1::Shape; -+ using WarpShape = typename Mma1::Operator::Shape; -+ using InstructionShape = typename Mma1::Policy::Operator::InstructionShape; -+ using ArchTag = typename Mma1::ArchTag; -+ -+ static int const kStages = Mma1::kStages; -+ static BlasMode const kBlasMode = BlasMode_; -+ -+private: -+ static FillMode const kInternalFillModeC = FillModeC_; -+ -+public: -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma1::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ using ProblemVisitor = Rank2KGroupedProblemVisitor< -+ ThreadblockShape, -+ kGroupScheduleMode, -+ kThreadCount, -+ kThreadCount, -+ kInternalFillModeC>; -+ -+ // -+ // Structures -+ // -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmUniversalMode mode; -+ GemmCoord *problem_sizes; -+ int problem_count; -+ int threadblock_count; -+ -+ typename EpilogueOutputOp::Params epilogue; -+ -+ ElementA ** ptr_A; -+ ElementB ** ptr_B; -+ ElementC ** ptr_C; -+ ElementC ** ptr_D; -+ -+ typename LayoutA::Stride::LongIndex *lda; -+ typename LayoutB::Stride::LongIndex *ldb; -+ typename LayoutC::Stride::LongIndex *ldc; -+ typename LayoutC::Stride::LongIndex *ldd; -+ -+ // Only used by device-level operator -+ GemmCoord *host_problem_sizes; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments(): -+ mode(GemmUniversalMode::kGemm), -+ problem_count(0), -+ threadblock_count(0), -+ ptr_A(nullptr), -+ ptr_B(nullptr), -+ ptr_C(nullptr), -+ ptr_D(nullptr), -+ lda(nullptr), -+ ldb(nullptr), -+ ldc(nullptr), -+ ldd(nullptr), -+ host_problem_sizes(nullptr) -+ { -+ -+ } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ GemmUniversalMode mode, -+ GemmCoord *problem_sizes, -+ int problem_count, -+ int threadblock_count, -+ typename EpilogueOutputOp::Params epilogue, -+ ElementA ** ptr_A, -+ ElementB ** ptr_B, -+ ElementC ** ptr_C, -+ ElementC ** ptr_D, -+ typename LayoutA::Stride::LongIndex *lda, -+ typename LayoutB::Stride::LongIndex *ldb, -+ typename LayoutC::Stride::LongIndex *ldc, -+ typename LayoutC::Stride::LongIndex *ldd, -+ GemmCoord *host_problem_sizes=nullptr -+ ): -+ mode(mode), -+ problem_sizes(problem_sizes), -+ problem_count(problem_count), -+ threadblock_count(threadblock_count), -+ epilogue(epilogue), -+ ptr_A(ptr_A), -+ ptr_B(ptr_B), -+ ptr_C(ptr_C), -+ ptr_D(ptr_D), -+ lda(lda), -+ ldb(ldb), -+ ldc(ldc), -+ ldd(ldd), -+ host_problem_sizes(host_problem_sizes) -+ { -+ -+ } -+ -+ }; -+ -+ // -+ // Structure for precomputing values in host memory and passing to kernels -+ // -+ -+ /// Parameters structure -+ struct Params { -+ -+ typename ProblemVisitor::Params problem_visitor; -+ int threadblock_count; -+ -+ typename EpilogueOutputOp::Params output_op; -+ -+ GemmUniversalMode mode; -+ int batch_count; -+ -+ ElementA ** ptr_A; -+ ElementB ** ptr_B; -+ ElementC ** ptr_C; -+ ElementC ** ptr_D; -+ -+ typename LayoutA::Stride::LongIndex *lda; -+ typename LayoutB::Stride::LongIndex *ldb; -+ typename LayoutC::Stride::LongIndex *ldc; -+ typename LayoutC::Stride::LongIndex *ldd; -+ -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ mode(cutlass::gemm::GemmUniversalMode::kGemm), -+ ptr_A(nullptr), -+ ptr_B(nullptr), -+ ptr_C(nullptr), -+ ptr_D(nullptr), -+ lda(nullptr), -+ ldb(nullptr), -+ ldc(nullptr), -+ ldd(nullptr) -+ { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args, void *workspace = nullptr, int tile_count = 0): -+ problem_visitor(args.problem_sizes, args.problem_count, workspace, tile_count), -+ threadblock_count(args.threadblock_count), -+ output_op(args.epilogue), -+ ptr_A(args.ptr_A), -+ ptr_B(args.ptr_B), -+ ptr_C(args.ptr_C), -+ ptr_D(args.ptr_D), -+ lda(args.lda), -+ ldb(args.ldb), -+ ldc(args.ldc), -+ ldd(args.ldd) -+ { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void update( -+ Arguments const &args, -+ void *workspace = nullptr, -+ int tile_count = 0) { -+ -+ problem_visitor = typename ProblemVisitor::Params(args.problem_sizes, args.problem_count, workspace, tile_count); -+ threadblock_count = args.threadblock_count; -+ output_op = args.output_op; -+ ptr_A = args.ptr_A; -+ ptr_B = args.ptr_B; -+ ptr_C = args.ptr_C; -+ ptr_D = args.ptr_D; -+ } -+ }; -+ -+ /// Shared memory storage structure -+ struct SharedStorage { -+ union { -+ typename Mma1::SharedStorage mma1_main_loop; -+ typename Mma2::SharedStorage mma2_main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ } kernel; -+ -+ // ProblemVisitor shared storage can't be overlapped with others -+ typename ProblemVisitor::SharedStorage problem_visitor; -+ }; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_DEVICE -+ Rank2KGrouped() { } -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement(cutlass::gemm::GemmCoord const & problem_size) { -+ return Status::kSuccess; -+ } -+ -+ static Status can_implement(Arguments const &args) { -+ return Status::kSuccess; -+ } -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // -+ // Problem visitor. -+ // -+ -+ ProblemVisitor problem_visitor( -+ params.problem_visitor, -+ shared_storage.problem_visitor, -+ blockIdx.x); -+ -+ // Outer 'persistent' loop to iterate over tiles -+ while (problem_visitor.next_tile()) { -+ -+ GemmCoord problem_size = problem_visitor.problem_size(); -+ int32_t problem_idx = problem_visitor.problem_index(); -+ int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); -+ -+ GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = problem_visitor.threadblock_offset(threadblock_idx); -+ -+ // -+ // Perform checks to determine whether the results of this threadblock will be needed. -+ // An example of an unneeded threadblock is one that is assigned to compute in the upper -+ // portion of a Rank2K kernel filled with mode kLower. -+ // -+ // TODO: Consider pushing these checks into ProblemVisitor to avoid spuriously -+ // returning from `next_tile()`. -+ // -+ -+ // Early exit if threadblock is out of range -+ if (grid_shape.m() <= threadblock_tile_offset.m() || -+ grid_shape.n() <= threadblock_tile_offset.n()) { -+ // Next tile -+ problem_visitor.advance(gridDim.x); -+ continue; -+ } -+ -+ // Skip this tile if Fill Mode is Lower and -+ // if the entire tile is above the main diagonal (bottom-left corner is at or above the diagonal) -+ if (kInternalFillModeC == cutlass::FillMode::kLower && -+ (threadblock_tile_offset.m() + 1) * Mma1::Shape::kM <= threadblock_tile_offset.n() * Mma1::Shape::kN) { -+ // Next tile -+ problem_visitor.advance(gridDim.x); -+ continue; -+ } -+ -+ // Skip this tile if Fill Mode is Upper and -+ // if the entire tile is below the main diagonal (top-right corner is at or below the diagonal) -+ if (kInternalFillModeC == cutlass::FillMode::kUpper && -+ threadblock_tile_offset.m() * Mma1::Shape::kM >= (threadblock_tile_offset.n() + 1) * Mma1::Shape::kN) { -+ // Next tile -+ problem_visitor.advance(gridDim.x); -+ continue; -+ } -+ -+ bool tile_on_diagonal = false; -+ // Mark tiles that are being crossed by the main diagonal -+ // (top-right and bottom-left corners are on either side of the diagonal) -+ if ((threadblock_tile_offset.m() + 1) * Mma1::Shape::kM > threadblock_tile_offset.n() * Mma1::Shape::kN -+ && threadblock_tile_offset.m() * Mma1::Shape::kM < (threadblock_tile_offset.n() + 1) * Mma1::Shape::kN) { -+ tile_on_diagonal = true; -+ } -+ -+ int offset_k = 0; -+ int problem_size_k = problem_size.k(); -+ -+ // -+ // Fetch pointers based on mode. -+ // -+ if (params.mode == GemmUniversalMode::kGemm || -+ params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ -+ if (threadblock_tile_offset.k() + 1 < grid_shape.k()) { -+ problem_size_k = (threadblock_tile_offset.k() + 1) * problem_size.k(); -+ } -+ -+ offset_k = threadblock_tile_offset.k() * problem_size.k(); -+ } -+ -+ ElementA *ptr_A = reinterpret_cast((kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx])); -+ typename LayoutA::Stride::LongIndex ldm_A = (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]); -+ -+ ElementB *ptr_B = reinterpret_cast((kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx])); -+ typename LayoutB::Stride::LongIndex ldm_B = (kTransposed ? params.lda[problem_idx] : params.ldb[problem_idx]); -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_MxK{ -+ threadblock_tile_offset.m() * Mma1::Shape::kM, -+ offset_k, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_KxN{ -+ offset_k, -+ threadblock_tile_offset.n() * Mma1::Shape::kN -+ }; -+ -+ // Assume identity swizzle -+ MatrixCoord tb_offset( -+ threadblock_tile_offset.m() * Mma1::Shape::kM, -+ threadblock_tile_offset.n() * Mma1::Shape::kN -+ ); -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands for Mma1 -+ typename Mma1::IteratorA iterator_A( -+ Mma1::IteratorA::Params(ldm_A), -+ ptr_A, -+ {problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_MxK); -+ -+ typename Mma1::IteratorB iterator_BT( -+ Mma1::IteratorB::Params(ldm_B), -+ ptr_B, -+ {problem_size_k, problem_size.n()}, -+ thread_idx, -+ tb_offset_KxN); -+ -+ // Construct iterators to A and B operands for Mma2 -+ typename Mma2::IteratorA iterator_B( -+ Mma2::IteratorA::Params(ldm_B), -+ ptr_B, -+ {problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_MxK); -+ -+ typename Mma2::IteratorB iterator_AT( -+ Mma2::IteratorB::Params(ldm_A), -+ ptr_A, -+ {problem_size_k, problem_size.n()}, -+ thread_idx, -+ tb_offset_KxN); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = canonical_warp_idx(); -+ -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply for Mma1 (A x BT) -+ Mma1 mma1(shared_storage.kernel.mma1_main_loop, thread_idx, warp_idx, lane_idx); -+ -+ // Construct thread-scoped matrix multiply for Mma2 (B x AT) -+ Mma2 mma2(shared_storage.kernel.mma2_main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma1::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size_k - offset_k + Mma1::Shape::kK - 1) / Mma1::Shape::kK; -+ -+ // Wait for all threads to finish their epilogue phases from the previous tile. -+ __syncthreads(); -+ -+ // Compute threadblock-scoped matrix multiply-add (A x BT) -+ mma1( -+ gemm_k_iterations, -+ accumulators, -+ iterator_A, -+ iterator_BT, -+ accumulators); -+ -+ // HER2K kernel needs Alpha to be complex and is conj(Alpha) is applied to the second HERK. -+ if (kBlasMode == BlasMode::kHermitian) { -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * grid_shape.m(); -+ -+ ElementC *ptr_C = static_cast(params.ptr_C[problem_idx]); -+ ElementC *ptr_D = static_cast(params.ptr_D[problem_idx]); -+ -+ // If TB not on diagonal, FillMode doesn't apply. -+ FillMode kFillModeTB = tile_on_diagonal ? kInternalFillModeC : FillMode::kNone; -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C( -+ Epilogue::OutputTileIterator::Params(params.ldc[problem_idx]), -+ ptr_C, -+ problem_size.mn(), -+ thread_idx, -+ tb_offset, -+ kFillModeTB -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ Epilogue::OutputTileIterator::Params(params.ldd[problem_idx]), -+ ptr_D, -+ problem_size.mn(), -+ thread_idx, -+ tb_offset, -+ kFillModeTB -+ ); -+ -+ Epilogue epilogue( -+ shared_storage.kernel.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue( -+ output_op, -+ iterator_D, -+ accumulators, -+ iterator_C); -+ -+ __syncthreads(); -+ -+ accumulators.clear(); -+ } -+ -+ // Compute threadblock-scoped matrix multiply-add (B x AT) -+ mma2( -+ gemm_k_iterations, -+ accumulators, -+ iterator_B, -+ iterator_AT, -+ accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ /* Needed for HER2K where the second HERK is multiplied by conj(alpha) */ -+ typename EpilogueOutputOp::Params second_her2k_params(conj(params.output_op.alpha), 1); -+ EpilogueOutputOp output_op_her2k(second_her2k_params); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * grid_shape.m(); -+ -+ ElementC *ptr_C = static_cast(params.ptr_C[problem_idx]); -+ -+ // HER2K kernel needs Alpha to be complex and is conj(Alpha) is applied to the second HERK. -+ if (kBlasMode == BlasMode::kHermitian) { -+ ptr_C = static_cast(params.ptr_D[problem_idx]); -+ } -+ -+ ElementC *ptr_D = static_cast(params.ptr_D[problem_idx]); -+ -+ // If TB not on diagonal, FillMode doesn't apply. -+ FillMode kFillModeTB = tile_on_diagonal ? kInternalFillModeC : FillMode::kNone; -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C( -+ Epilogue::OutputTileIterator::Params(params.ldc[problem_idx]), -+ ptr_C, -+ problem_size.mn(), -+ thread_idx, -+ tb_offset, -+ kFillModeTB -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ Epilogue::OutputTileIterator::Params(params.ldd[problem_idx]), -+ ptr_D, -+ problem_size.mn(), -+ thread_idx, -+ tb_offset, -+ kFillModeTB -+ ); -+ -+ Epilogue epilogue( -+ shared_storage.kernel.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Execute the epilogue operator to update the destination tensor. -+ if (kBlasMode == BlasMode::kSymmetric) { -+ epilogue( -+ output_op, -+ iterator_D, -+ accumulators, -+ iterator_C); -+ } else { -+ epilogue( -+ output_op_her2k, -+ iterator_D, -+ accumulators, -+ iterator_C); -+ } -+ -+ // Next tile -+ problem_visitor.advance(gridDim.x); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h -new file mode 100644 -index 0000000..92cc2a7 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h -@@ -0,0 +1,376 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Problem visitor for grouped Rank2K operations. -+ -+ This problem visitor is specialized for Rank2K operations, for which matrix C is upper/lower -+ triangular. Using a problem visitor designed for GEMMs for Rank2K problems is inefficient -+ because threadblocks will be frequently assigned to tiles that exit early (e.g., due to -+ being assigned to a tile in the upper-triangular portion of a lower-triangular problem). -+ This can lead to load imbalance among threadblocks, as the GEMM-based scheduler -+ assigns all threadblocks to nearly the same number of tiles, regardless of whether -+ those tiles exit early. -+ -+ Consider an example of a group of four Rank2Ks with matrix C consisting of a grid of 2x2 tiles. -+ Consider a grid of 8 threadblocks. The default GEMM scheduler will assign threadblocks to -+ tiles in the following order: -+ Rank2K 0 Rank2K 1 Rank2K 2 Rank2K 3 -+ 0 1 4 5 0 1 4 5 -+ 2 3 6 7 2 3 6 7 -+ Assuming that the problems are lower triangular, blocks 1 and 5 are continuously assigned -+ to inactive tiles. -+ -+ This problem visitor aims to assign threadblocks to only those tiles which are in the -+ upper/lower triangular portion of a given problem. Using the example above, the resulting -+ assignment would be: -+ Rank2K 0 Rank2K 1 Rank2K 2 Rank2K 3 -+ 0 - 3 - 6 - 1 - -+ 1 2 4 5 7 0 2 3 -+ -+ Achieving the schedule above requires a mapping from threadblock ID to tile coordinates (i, j). -+ We will illustrate this by mapping on a lower-triangular matrix with a 3x3 grid. We first -+ calculate row and column indices assuming one-indexed rows, tiles, and threadblock IDs, and -+ then subtract one to convert to zero-indexed. -+ Col 1 Col 2 Col 3 -+ ---------------------- -+ Row 1 | 1 - - -+ Row 2 | 2 3 - -+ Row 3 | 4 5 6 -+ -+ We next outline this mapping, borrowing from: https://stackoverflow.com/a/40954159 -+ -+ Calculating row i given threadblock ID t -+ ---------------------------------------- -+ For a given row i, all threadblock IDs t in that row satisfy the following: -+ t <= 1 + 2 + 3 + ... + (i-1) + i -+ -+ The closed-form equation for the right-hand side is: i(i+1)/2. -+ Using this, we can solve for i given t: -+ t <= i(i+1)/2 -+ 2t <= i^2 + i -+ 2t <= i^2 + i + 0.25 - 0.25 -+ 2t + 0.25 <= i^2 + i + 0.25 -+ 2t + 0.25 <= (i + 0.5)^2 -+ sqrt(2t + 0.25) - 0.5 <= i -+ -+ To account for fractional values, we set: -+ i = ceil(sqrt(2t + 0.25) - 0.5) -+ -+ To turn this into a zero-indexed row and work with zero-indexed t, we perform: -+ i = ceil(sqrt(2(t+1) + 0.25) - 0.5) - 1 -+ = ceil(sqrt(2t + 2.25) - 0.5) - 1 -+ -+ Calculating column j given threadblock ID t and row i -+ ----------------------------------------------------- -+ For a given row i, all threadblock IDs t in that row also satisfy the following: -+ t > 1 + 2 + 3 + ... + (i-2) + (i-1) -+ --> t > i(i-1)/2 -+ -+ Threadblock IDs within a given row are sequential, so the one-indexed column ID -+ for one-indexed threadblock ID t and row i is: -+ j = t - (i(i-1)/2) -+ -+ The zero-indexed version becomes: -+ j = (t+1) - (i(i+1)/2) -1 -+ = t - (i(i+1)/2) -+ -+ Accounting for non-square grids -+ ------------------------------- -+ Though the overall output problem size for Rank2K problems is guranteed to be square, the -+ grids used in computing may not be square due to using non-square threadblock shapes. For -+ example, a threadblock shape of 64x32 operating on a problem of output size 128x128 would -+ result in a grid of 2x4 tiles. -+ -+ This case can be handled by noting that the output resembles a square grid of 2x2 "macro tiles" -+ each of which contains 2 "true tiles." We can thus first map a threadblock ID to its "macro tile" -+ using the equations above, and then map it to the "true tile" within its "macro tile." In the example -+ of a 2x4 grid, this mapping would look as follows: -+ "Macro grid" "True grid" -+ {0, 1} - 0 1 - - -+ {2, 3} {4, 5} 2 3 4 5 -+ -+ A zero-indexed threadblock ID t is mapped to its "macro tile ID" t_macro as: -+ t_macro = t // r -+ Where r is the ratio of the maximum dimension of the grid to the minimum dimension of the grid -+ (i.e., r = 4 / 2 = 2 in the previous example). -+ -+ One uses t_macro and the calculations above to find the row and column in the square matrix to -+ obtain i_macro and j_macro (zero-indexed). The mapping from (i_macro, j_macro) --> (i, j) -+ is simply the following: -+ if (ThreadblockShape::M > ThreadblockShape::N): -+ r = ThreadblockShape::M / ThreadblockShape::N -+ i = i_macro -+ j = (j_macro * r) + (t % r) -+ elif (ThreadblockShape::M < ThreadblockShape::N): -+ r = ThreadblockShape::N / ThreadblockShape::M -+ i = (i_macro * r) + (t % r) -+ j = j_macro -+ else: -+ i = i_macro -+ j = j_macro -+ -+ Handling cases with grid dimensions that aren't multiples of eachother -+ ---------------------------------------------------------------------- -+ Even though threadblock shapes M and N are typically multiples of one another, the grid -+ for a given problem may not have dimensions of the same ratio as that of the threadblock. -+ For example, a problem of size 132x132 using a threadblock of shape 64x32 will result -+ in a grid of 3x5 tiles. In this case, there is not an integer number of "true tiles" -+ per "macro tile." -+ -+ When this scenario arises, we simply pad the larger dimension of the grid such that -+ there are an integer number of "true tiles" per "macro tile." Thus, the 3x5 grid in -+ the example above will be treated as a 3x6 grid. Row and column positions for each -+ tile are calculated as above. Any threadblocks that map to tiles that are outside the -+ problem range or upper/lower triangular portion (e.g., (2, 5)) will exit early from -+ this problem and may proceed to the next problem in the group. -+ -+ Handling upper-triangular matrices -+ ---------------------------------- -+ The only modification needed for upper-triangular matrices is to swap i_macro and j_macro -+ in the calculations above. -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+ -+#include "cutlass/gemm/kernel/grouped_problem_visitor.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+namespace detail { -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Helpers for calculating offsets for Rank2K problem visitor. These helpers specifically pertain -+// to the conversion from "macro tiles" to "true tiles" in the description above. -+// -+template < -+ typename ThreadblockShape, -+ typename Enable = void -+> -+struct Rank2KGroupedProblemVisitorOffsetHelper; -+ -+// Partial specialization for the case where threadblock shape M > threadblock shape N -+template < -+ typename ThreadblockShape -+> -+struct Rank2KGroupedProblemVisitorOffsetHelper< -+ ThreadblockShape, -+ typename platform::enable_if< (ThreadblockShape::kM > ThreadblockShape::kN) >::type -+> { -+ static_assert(ThreadblockShape::kM % ThreadblockShape::kN == 0, -+ "Rank2KGroupedProblemVisitor with threadblock shape M > threadblock shape N " -+ "requires that threadblock shape M be a multiple of threadblock shape N."); -+ -+ static int32_t const kThreadblockSkewRatio = ThreadblockShape::kM / ThreadblockShape::kN; -+ -+ CUTLASS_HOST_DEVICE -+ static int32_t min_dim(cutlass::gemm::GemmCoord grid) { -+ return grid.m(); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static int32_t macro_row_to_row(int32_t row, int32_t threadblock_id) { -+ return row; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static int32_t macro_col_to_col(int32_t col, int32_t threadblock_id) { -+ return (col * kThreadblockSkewRatio) + (threadblock_id % kThreadblockSkewRatio); -+ } -+}; -+ -+// Partial specialization for the case where threadblock shape M < threadblock shape N -+template < -+ typename ThreadblockShape -+> -+struct Rank2KGroupedProblemVisitorOffsetHelper< -+ ThreadblockShape, -+ typename platform::enable_if< (ThreadblockShape::kM < ThreadblockShape::kN) >::type -+> { -+ -+ static_assert(ThreadblockShape::kN % ThreadblockShape::kM == 0, -+ "Rank2KGroupedProblemVisitor with threadblock shape M < threadblock shape N " -+ "requires that threadblock shape N be a multiple of threadblock shape M."); -+ -+ static int32_t const kThreadblockSkewRatio = ThreadblockShape::kN / ThreadblockShape::kM; -+ -+ CUTLASS_HOST_DEVICE -+ static int32_t min_dim(cutlass::gemm::GemmCoord grid) { -+ return grid.n(); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static int32_t macro_row_to_row(int32_t row, int32_t threadblock_id) { -+ return (row * kThreadblockSkewRatio) + (threadblock_id % kThreadblockSkewRatio); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static int32_t macro_col_to_col(int32_t col, int32_t threadblock_id) { -+ return col; -+ } -+}; -+ -+// Partial specialization for the case where threadblock shape M == threadblock shape N -+// In this case, macro tiles are equivalent to true tiles, so the conversions are -+// identity functions. -+template < -+ typename ThreadblockShape -+> -+struct Rank2KGroupedProblemVisitorOffsetHelper< -+ ThreadblockShape, -+ typename platform::enable_if< (ThreadblockShape::kM == ThreadblockShape::kN) >::type -+> { -+ -+ static int32_t const kThreadblockSkewRatio = 1; -+ -+ CUTLASS_HOST_DEVICE -+ static int32_t min_dim(cutlass::gemm::GemmCoord grid) { -+ return grid.m(); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static int32_t macro_row_to_row(int32_t row, int32_t threadblock_id) { -+ return row; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static int32_t macro_col_to_col(int32_t col, int32_t threadblock_id) { -+ return col; -+ } -+}; -+ -+// Helper for correctly representing problem sizes in grouped kernels -+template -+struct Rank2KGroupedProblemSizeHelper { -+ using OffsetHelper = Rank2KGroupedProblemVisitorOffsetHelper; -+ -+ CUTLASS_HOST_DEVICE -+ static cutlass::gemm::GemmCoord grid_shape(const cutlass::gemm::GemmCoord& problem) { -+ return cutlass::gemm::GemmCoord( -+ ((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), -+ ((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN), -+ 1); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static int32_t tile_count(const cutlass::gemm::GemmCoord& grid) { -+ // Return the number of tiles at or below the diagonal (or at and above -+ // for mode kUpper). We do this by first calculating this value assuming -+ // we have a square matrix of tiles of size `dim x dim` where `dim` is the -+ // minimum among {grid.m(), grid.n()}. We then multiply the resulting value -+ // by OffsetHelper::kThreadblockSkewRatio to account for cases in which there -+ // are more tiles in one dimension than the other. -+ int32_t dim = OffsetHelper::min_dim(grid); -+ int32_t tiles_on_diagonal = dim; -+ int32_t tiles_below_diagonal = ((dim * (dim - 1)) / 2); -+ return (tiles_on_diagonal + tiles_below_diagonal) * OffsetHelper::kThreadblockSkewRatio; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem) {} -+}; -+ -+} // namespace detail -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Default problem visitor for fill modes kUpper and kLower. -+// -+template -+struct Rank2KGroupedProblemVisitor : public GroupedProblemVisitor< -+ detail::Rank2KGroupedProblemSizeHelper, -+ ThreadblockShape, -+ GroupScheduleMode_, -+ PrefetchTileCount, -+ ThreadCount> { -+ -+ static cutlass::FillMode const kFillModeC = FillModeC; -+ -+ static_assert(kFillModeC == cutlass::FillMode::kLower || kFillModeC == cutlass::FillMode::kUpper, -+ "Default Rank2KGroupedProblemVisitor requires fill mode of kLower or kUpper."); -+ -+ using ProblemSizeHelper = detail::Rank2KGroupedProblemSizeHelper; -+ using Base = GroupedProblemVisitor; -+ using OffsetHelper = typename ProblemSizeHelper::OffsetHelper; -+ using Params = typename Base::Params; -+ using SharedStorage = typename Base::SharedStorage; -+ -+ // -+ // Methods -+ // -+ CUTLASS_DEVICE -+ Rank2KGroupedProblemVisitor( -+ Params const ¶ms_, -+ SharedStorage &shared_storage_, -+ int32_t block_idx -+ ): Base(params_, shared_storage_, block_idx) -+ {} -+ -+ CUTLASS_DEVICE -+ cutlass::gemm::GemmCoord threadblock_offset(int32_t threadblock_id) const { -+ int32_t macro_id = threadblock_id / OffsetHelper::kThreadblockSkewRatio; -+ int32_t macro_row = ceil(cutlass::fast_sqrt((2*macro_id) + 2.25) - 0.5) - 1; -+ int32_t macro_col = macro_id - (((macro_row+1) * macro_row)/2); -+ -+ if (kFillModeC == cutlass::FillMode::kUpper) { -+ swap(macro_row, macro_col); -+ } -+ -+ int32_t row = OffsetHelper::macro_row_to_row(macro_row, threadblock_id); -+ int32_t col = OffsetHelper::macro_col_to_col(macro_col, threadblock_id); -+ -+ return cutlass::gemm::GemmCoord(row, col, 0); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/rank_2k_transpose_operands.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/rank_2k_transpose_operands.h -new file mode 100644 -index 0000000..0837a9d ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/rank_2k_transpose_operands.h -@@ -0,0 +1,129 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief Transpositions for Rank2K problems. -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA_, -+ typename LayoutA_, -+ ComplexTransform TransformA, -+ int AlignmentA, -+ typename ElementB_, -+ typename LayoutB_, -+ ComplexTransform TransformB, -+ int AlignmentB, -+ typename LayoutC_, -+ FillMode FillModeC_, -+ bool Transpose -+> -+struct Rank2KMapArguments { -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ static ComplexTransform const kTransformA = TransformA; -+ static int const kAlignmentA = AlignmentA; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ static ComplexTransform const kTransformB = TransformB; -+ static int const kAlignmentB = AlignmentB; -+ using LayoutC = LayoutC_; -+ static FillMode const kFillModeC = FillModeC_; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA_, -+ typename LayoutA_, -+ ComplexTransform TransformA, -+ int AlignmentA, -+ typename ElementB_, -+ typename LayoutB_, -+ ComplexTransform TransformB, -+ int AlignmentB, -+ typename LayoutC_, -+ FillMode FillModeC_ -+> -+struct Rank2KMapArguments< -+ ElementA_, -+ LayoutA_, -+ TransformA, -+ AlignmentA, -+ ElementB_, -+ LayoutB_, -+ TransformB, -+ AlignmentB, -+ LayoutC_, -+ FillModeC_, -+ true -+> { -+ using ElementA = ElementB_; -+ using LayoutA = LayoutB_; -+ static ComplexTransform const kTransformA = TransformB; -+ static int const kAlignmentA = AlignmentB; -+ using ElementB = ElementA_; -+ using LayoutB = LayoutA_; -+ static ComplexTransform const kTransformB = TransformA; -+ static int const kAlignmentB = AlignmentA; -+ using LayoutC = typename layout::LayoutTranspose::type; -+ static FillMode const kFillModeC = InvertFillMode::mode; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} -+} -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/rank_2k_universal.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/rank_2k_universal.h -new file mode 100644 -index 0000000..6d1f4ac ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/rank_2k_universal.h -@@ -0,0 +1,778 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/semaphore.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma1_, ///! Threadblock-scoped matrix multiply-accumulate (A*B^T) -+ typename Mma2_, ///! Threadblock-scoped matrix multiply-accumulate (B*A^T) -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ FillMode FillModeC_, ///! Fill Mode for C (kLower or kUpper) -+ BlasMode BlasMode_ ///! Blas3 computation mode -+> -+struct Rank2KUniversal { -+public: -+ -+ using Mma1 = Mma1_; -+ using Mma2 = Mma2_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ -+ using ElementA = typename Mma1::IteratorA::Element; -+ using ElementB = typename Mma1::IteratorB::Element; -+ -+ // Mma1 (A x B^T) -+ using LayoutA = typename Mma1::IteratorA::Layout; -+ using LayoutBT = typename Mma1::IteratorB::Layout; -+ static ComplexTransform const kMma1TransformA = Mma1::kTransformA; -+ static ComplexTransform const kMma1TransformB = Mma1::kTransformB; -+ -+ // Mma2 (B x A^T) -+ using LayoutB = typename Mma2::IteratorA::Layout; -+ using LayoutAT = typename Mma2::IteratorB::Layout; -+ static ComplexTransform const kMma2TransformA = Mma2::kTransformA; -+ static ComplexTransform const kMma2TransformB = Mma2::kTransformB; -+ -+ // Common type definitions for Mma1 and Mma2 -+ using Operator = typename Mma1::Operator; -+ using OperatorClass = typename Mma1::Operator::OperatorClass; -+ using ThreadblockShape = typename Mma1::Shape; -+ using WarpShape = typename Mma1::Operator::Shape; -+ using InstructionShape = typename Mma1::Policy::Operator::InstructionShape; -+ using ArchTag = typename Mma1::ArchTag; -+ -+ static int const kStages = Mma1::kStages; -+ static int const kAlignmentA = Mma1::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma1::IteratorB::AccessType::kElements; -+ -+ // Output related typedefinitions -+ using ElementC = typename Epilogue::OutputTileIterator::Element; -+ using LayoutC = typename Epilogue::OutputTileIterator::Layout; -+ static FillMode const kFillModeC = FillModeC_; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ static BlasMode const kBlasMode = BlasMode_; -+ -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma1::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ -+ // -+ // Structures -+ // -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmUniversalMode mode; -+ GemmCoord problem_size; -+ int batch_count; -+ -+ typename EpilogueOutputOp::Params epilogue; -+ -+ void const * ptr_A; -+ void const * ptr_B; -+ void const * ptr_C; -+ void * ptr_D; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C; -+ int64_t batch_stride_D; -+ -+ typename LayoutA::Stride::Index lda; -+ typename LayoutB::Stride::Index ldb; -+ typename LayoutC::Stride::Index ldc; -+ typename LayoutC::Stride::Index ldd; -+ -+ // -+ // Methods -+ // -+ -+ Arguments(): -+ mode(GemmUniversalMode::kGemm), -+ batch_count(1), -+ ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr) { } -+ -+ /// constructs an arguments structure -+ Arguments( -+ GemmUniversalMode mode, -+ GemmCoord problem_size, -+ int batch_count, -+ typename EpilogueOutputOp::Params epilogue, -+ void const * ptr_A, -+ void const * ptr_B, -+ void const * ptr_C, -+ void * ptr_D, -+ int64_t batch_stride_A, -+ int64_t batch_stride_B, -+ int64_t batch_stride_C, -+ int64_t batch_stride_D, -+ typename LayoutA::Stride::Index lda, -+ typename LayoutB::Stride::Index ldb, -+ typename LayoutC::Stride::Index ldc, -+ typename LayoutC::Stride::Index ldd -+ ): -+ mode(mode), -+ problem_size(problem_size), -+ batch_count(batch_count), -+ epilogue(epilogue), -+ ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), -+ batch_stride_A(batch_stride_A), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D), -+ lda(lda), ldb(ldb), ldc(ldc), ldd(ldd) { -+ -+ } -+ -+ /// Returns arguments for a the transposed problem -+ Arguments transposed_problem() const { -+ Arguments args(*this); -+ -+ std::swap(args.ptr_A, args.ptr_B); -+ std::swap(args.lda, args.ldb); -+ std::swap(args.batch_stride_A, args.batch_stride_B); -+ -+ return args; -+ } -+ -+ }; -+ -+ // -+ // Structure for precomputing values in host memory and passing to kernels -+ // -+ -+ /// Parameters structure -+ struct Params { -+ -+ cutlass::gemm::GemmCoord problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ int swizzle_log_tile; -+ -+ // Mma1 Iterator A and B params -+ typename Mma1::IteratorA::Params params_A; -+ typename Mma1::IteratorB::Params params_BT; -+ -+ // Mma2 Iterator A and B params -+ typename Mma2::IteratorA::Params params_B; -+ typename Mma2::IteratorB::Params params_AT; -+ -+ typename Epilogue::OutputTileIterator::Params params_C; -+ typename Epilogue::OutputTileIterator::Params params_D; -+ -+ typename EpilogueOutputOp::Params output_op; -+ -+ GemmUniversalMode mode; -+ int batch_count; -+ int gemm_k_size; -+ -+ void * ptr_A; -+ void * ptr_B; -+ void * ptr_C; -+ void * ptr_D; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C; -+ int64_t batch_stride_D; -+ -+ int *semaphore; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ swizzle_log_tile(0), -+ params_A(0), -+ params_BT(0), -+ params_B(0), -+ params_AT(0), -+ params_C(0), -+ params_D(0), -+ batch_count(0), -+ gemm_k_size(0), -+ mode(cutlass::gemm::GemmUniversalMode::kGemm), -+ ptr_A(nullptr), -+ ptr_B(nullptr), -+ ptr_C(nullptr), -+ ptr_D(nullptr), -+ batch_stride_A(0), -+ batch_stride_B(0), -+ batch_stride_C(0), -+ batch_stride_D(0), -+ semaphore(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ Arguments const &args, -+ cutlass::gemm::GemmCoord const & grid_tiled_shape, -+ int gemm_k_size, -+ void *workspace = nullptr -+ ): -+ problem_size(args.problem_size), -+ grid_tiled_shape(grid_tiled_shape), -+ swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), -+ params_A(args.lda), -+ params_BT(args.ldb), -+ params_B(args.ldb), -+ params_AT(args.lda), -+ params_C(args.ldc), -+ params_D(args.ldd), -+ output_op(args.epilogue), -+ mode(args.mode), -+ batch_count(args.batch_count), -+ gemm_k_size(gemm_k_size), -+ ptr_A(const_cast(args.ptr_A)), -+ ptr_B(const_cast(args.ptr_B)), -+ ptr_C(const_cast(args.ptr_C)), -+ ptr_D(const_cast(args.ptr_D)), -+ batch_stride_A(args.batch_stride_A), -+ batch_stride_B(args.batch_stride_B), -+ batch_stride_C(args.batch_stride_C), -+ batch_stride_D(args.batch_stride_D), -+ semaphore(static_cast(workspace)) { -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void update( -+ Arguments const &args, -+ void *workspace = nullptr) { -+ -+ ptr_A = const_cast(args.ptr_A); -+ ptr_B = const_cast(args.ptr_B); -+ ptr_C = const_cast(args.ptr_C); -+ ptr_D = args.ptr_D; -+ -+ output_op = args.epilogue; -+ -+ semaphore = static_cast(workspace); -+ } -+ -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma1::SharedStorage mma1_main_loop; -+ typename Mma2::SharedStorage mma2_main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_DEVICE -+ Rank2KUniversal() { } -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement( -+ cutlass::gemm::GemmCoord const & problem_size) { -+ -+ static int const kAlignmentA = Mma1::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma1::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) || -+ (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) || -+ (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) { -+ -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ static Status can_implement(Arguments const &args) { -+ return can_implement(args.problem_size); -+ } -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ return; -+ } -+ -+ // Early exit if Fill Mode is Lower and -+ // if the entire tile is above the main diagonal (bottom-left corner is at or above the diagonal) -+ if (kFillModeC == cutlass::FillMode::kLower && -+ (threadblock_tile_offset.m() + 1) * Mma1::Shape::kM <= threadblock_tile_offset.n() * Mma1::Shape::kN) { -+ return; -+ } -+ -+ // Early exit if Fill Mode is Upper and -+ // if the entire tile is below the main diagonal (top-right corner is at or below the diagonal) -+ if (kFillModeC == cutlass::FillMode::kUpper && -+ threadblock_tile_offset.m() * Mma1::Shape::kM >= (threadblock_tile_offset.n() + 1) * Mma1::Shape::kN) { -+ return; -+ } -+ -+ bool tile_on_diagonal = false; -+ // Mark tiles that are being crossed by the main diagonal -+ // (top-right and bottom-left corners are on either side of the diagonal) -+ if ((threadblock_tile_offset.m() + 1) * Mma1::Shape::kM > threadblock_tile_offset.n() * Mma1::Shape::kN -+ && threadblock_tile_offset.m() * Mma1::Shape::kM < (threadblock_tile_offset.n() + 1) * Mma1::Shape::kN) { -+ tile_on_diagonal = true; -+ } -+ -+ int offset_k = 0; -+ int problem_size_k = params.problem_size.k(); -+ -+ ElementA *ptr_A = static_cast(params.ptr_A); -+ ElementB *ptr_B = static_cast(params.ptr_B); -+ -+ // -+ // Fetch pointers based on mode. -+ // -+ if (params.mode == GemmUniversalMode::kGemm || -+ params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ -+ if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { -+ -+ problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; -+ } -+ -+ offset_k = threadblock_tile_offset.k() * params.gemm_k_size; -+ } -+ -+ __syncthreads(); -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_MxK{ -+ threadblock_tile_offset.m() * Mma1::Shape::kM, -+ offset_k, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_KxN{ -+ offset_k, -+ threadblock_tile_offset.n() * Mma1::Shape::kN -+ }; -+ -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands for Mma1 -+ typename Mma1::IteratorA iterator_A( -+ params.params_A, -+ ptr_A, -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_MxK); -+ -+ typename Mma1::IteratorB iterator_BT( -+ params.params_BT, -+ ptr_B, -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_KxN); -+ -+ // Construct iterators to A and B operands for Mma2 -+ typename Mma2::IteratorA iterator_B( -+ params.params_B, -+ ptr_B, -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_MxK); -+ -+ typename Mma2::IteratorB iterator_AT( -+ params.params_AT, -+ ptr_A, -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_KxN); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = canonical_warp_idx(); -+ -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply for Mma1 (A x BT) -+ Mma1 mma1(shared_storage.mma1_main_loop, thread_idx, warp_idx, lane_idx); -+ -+ // Construct thread-scoped matrix multiply for Mma2 (B x AT) -+ Mma2 mma2(shared_storage.mma2_main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma1::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size_k - offset_k + Mma1::Shape::kK - 1) / Mma1::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add (A x BT) -+ mma1( -+ gemm_k_iterations, -+ accumulators, -+ iterator_A, -+ iterator_BT, -+ accumulators); -+ -+ // HER2K kernel needs Alpha to be complex and is conj(Alpha) is applied to the second HERK. -+ if (kBlasMode == BlasMode::kHermitian) { -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma1::Shape::kM, -+ threadblock_tile_offset.n() * Mma1::Shape::kN -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ ElementC *ptr_C = static_cast(params.ptr_C); -+ ElementC *ptr_D = static_cast(params.ptr_D); -+ -+ // -+ // Fetch pointers based on mode. -+ // -+ -+ // Construct the semaphore. -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ if (params.mode == GemmUniversalMode::kGemm) { -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. -+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } -+ } -+ else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_C += threadblock_tile_offset.k() * params.batch_stride_C; -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_C = static_cast(params.ptr_C)[threadblock_tile_offset.k()]; -+ ptr_D = static_cast(params.ptr_D)[threadblock_tile_offset.k()]; -+ } -+ -+ -+ // If CTA not on diagonal, FillMode doesn't apply. -+ FillMode kFillModeCTA = tile_on_diagonal ? kFillModeC : FillMode::kNone; -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.params_C, -+ ptr_C, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset, -+ kFillModeCTA -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ ptr_D, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset, -+ kFillModeCTA -+ ); -+ -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_offset.k()) { -+ iterator_C = iterator_D; -+ } -+ -+ semaphore.wait(threadblock_tile_offset.k()); -+ -+ __threadfence(); -+ } -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue( -+ output_op, -+ iterator_D, -+ accumulators, -+ iterator_C); -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_offset.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ -+ __syncthreads(); -+ -+ accumulators.clear(); -+ } -+ -+ // Compute threadblock-scoped matrix multiply-add (B x AT) -+ mma2( -+ gemm_k_iterations, -+ accumulators, -+ iterator_B, -+ iterator_AT, -+ accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ /* Needed for HER2K where the second HERK is multiplied by conj(alpha) */ -+ typename EpilogueOutputOp::Params second_her2k_params(conj(params.output_op.alpha), 1); -+ EpilogueOutputOp output_op_her2k(second_her2k_params); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma1::Shape::kM, -+ threadblock_tile_offset.n() * Mma1::Shape::kN -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ ElementC *ptr_C = static_cast(params.ptr_C); -+ -+ // HER2K kernel needs Alpha to be complex and is conj(Alpha) is applied to the second HERK. -+ if (kBlasMode == BlasMode::kHermitian) { -+ ptr_C = static_cast(params.ptr_D); -+ } -+ -+ ElementC *ptr_D = static_cast(params.ptr_D); -+ -+ // -+ // Fetch pointers based on mode. -+ // -+ -+ // Construct the semaphore. -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ if (params.mode == GemmUniversalMode::kGemm) { -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. -+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ if (kBlasMode == BlasMode::kSymmetric) { -+ output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } else { -+ output_op_her2k.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } -+ } -+ } -+ else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_C += threadblock_tile_offset.k() * params.batch_stride_C; -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_C = static_cast(params.ptr_C)[threadblock_tile_offset.k()]; -+ ptr_D = static_cast(params.ptr_D)[threadblock_tile_offset.k()]; -+ } -+ -+ -+ // If CTA not on diagonal, FillMode doesn't apply. -+ FillMode kFillModeCTA = tile_on_diagonal ? kFillModeC : FillMode::kNone; -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.params_C, -+ ptr_C, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset, -+ kFillModeCTA -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ ptr_D, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset, -+ kFillModeCTA -+ ); -+ -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_offset.k()) { -+ iterator_C = iterator_D; -+ } -+ -+ semaphore.wait(threadblock_tile_offset.k()); -+ -+ __threadfence(); -+ } -+ -+ // Execute the epilogue operator to update the destination tensor. -+ if (kBlasMode == BlasMode::kSymmetric) { -+ epilogue( -+ output_op, -+ iterator_D, -+ accumulators, -+ iterator_C); -+ } else { -+ epilogue( -+ output_op_her2k, -+ iterator_D, -+ accumulators, -+ iterator_C); -+ } -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_offset.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/rank_k_universal.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/rank_k_universal.h -new file mode 100644 -index 0000000..b7d1ad1 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/rank_k_universal.h -@@ -0,0 +1,565 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/semaphore.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ FillMode FillModeC_ ///! Fill Mode for C (kLower or kUpper) -+> -+struct RankKUniversal { -+public: -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using ElementC = typename Epilogue::OutputTileIterator::Element; -+ using LayoutC = typename Epilogue::OutputTileIterator::Layout; -+ static FillMode const kFillModeC = FillModeC_; -+ -+ static ComplexTransform const kTransformA = Mma::kTransformA; -+ static ComplexTransform const kTransformB = Mma::kTransformB; -+ using Operator = typename Mma::Operator; -+ -+ using OperatorClass = typename Mma::Operator::OperatorClass; -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename Mma::Operator::Shape; -+ using InstructionShape = typename Mma::Policy::Operator::InstructionShape; -+ using ArchTag = typename Mma::ArchTag; -+ -+ static int const kStages = Mma::kStages; -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ /// Split-K preserves splits that are 128b aligned -+ static int const kSplitKAlignment = 128 / sizeof_bits::value; -+ -+ // -+ // Structures -+ // -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmUniversalMode mode; -+ GemmCoord problem_size; -+ int batch_count; -+ -+ typename EpilogueOutputOp::Params epilogue; -+ -+ void const * ptr_A; -+ void const * ptr_C; -+ void * ptr_D; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_C; -+ int64_t batch_stride_D; -+ -+ typename LayoutA::Stride::Index lda; -+ typename LayoutB::Stride::Index ldb; -+ typename LayoutC::Stride::Index ldc; -+ typename LayoutC::Stride::Index ldd; -+ -+ // -+ // Methods -+ // -+ -+ Arguments(): -+ mode(GemmUniversalMode::kGemm), -+ batch_count(1), -+ ptr_A(nullptr), ptr_C(nullptr), ptr_D(nullptr) { } -+ -+ /// constructs an arguments structure -+ Arguments( -+ GemmUniversalMode mode, -+ GemmCoord problem_size, -+ int batch_count, -+ typename EpilogueOutputOp::Params epilogue, -+ void const * ptr_A, -+ void const * ptr_C, -+ void * ptr_D, -+ int64_t batch_stride_A, -+ int64_t batch_stride_C, -+ int64_t batch_stride_D, -+ typename LayoutA::Stride::Index lda, -+ typename LayoutC::Stride::Index ldc, -+ typename LayoutC::Stride::Index ldd -+ ): -+ mode(mode), -+ problem_size(problem_size), -+ batch_count(batch_count), -+ epilogue(epilogue), -+ ptr_A(ptr_A), ptr_C(ptr_C), ptr_D(ptr_D), -+ batch_stride_A(batch_stride_A), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D), -+ lda(lda), ldb(ldb), ldc(ldc), ldd(ldd) { -+ -+ } -+ -+ }; -+ -+ // -+ // Structure for precomputing values in host memory and passing to kernels -+ // -+ -+ /// Parameters structure -+ struct Params { -+ -+ cutlass::gemm::GemmCoord problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ int swizzle_log_tile; -+ -+ typename Mma::IteratorA::Params params_A; -+ typename Mma::IteratorB::Params params_B; -+ typename Epilogue::OutputTileIterator::Params params_C; -+ typename Epilogue::OutputTileIterator::Params params_D; -+ -+ typename EpilogueOutputOp::Params output_op; -+ -+ GemmUniversalMode mode; -+ int batch_count; -+ int gemm_k_size; -+ -+ void * ptr_A; -+ void * ptr_B; -+ void * ptr_C; -+ void * ptr_D; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C; -+ int64_t batch_stride_D; -+ -+ int *semaphore; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ swizzle_log_tile(0), -+ params_A(0), -+ params_B(0), -+ params_C(0), -+ params_D(0), -+ batch_count(0), -+ gemm_k_size(0), -+ mode(cutlass::gemm::GemmUniversalMode::kGemm), -+ ptr_A(nullptr), -+ ptr_B(nullptr), -+ ptr_C(nullptr), -+ ptr_D(nullptr), -+ batch_stride_A(0), -+ batch_stride_B(0), -+ batch_stride_C(0), -+ batch_stride_D(0), -+ semaphore(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ Arguments const &args, -+ cutlass::gemm::GemmCoord const & grid_tiled_shape, -+ int gemm_k_size, -+ void *workspace = nullptr -+ ): -+ problem_size(args.problem_size), -+ grid_tiled_shape(grid_tiled_shape), -+ swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), -+ params_A(args.lda), -+ params_B(args.lda), -+ params_C(args.ldc), -+ params_D(args.ldd), -+ output_op(args.epilogue), -+ mode(args.mode), -+ batch_count(args.batch_count), -+ gemm_k_size(gemm_k_size), -+ ptr_A(const_cast(args.ptr_A)), -+ ptr_B(const_cast(args.ptr_A)), -+ ptr_C(const_cast(args.ptr_C)), -+ ptr_D(const_cast(args.ptr_D)), -+ batch_stride_A(args.batch_stride_A), -+ batch_stride_B(args.batch_stride_A), -+ batch_stride_C(args.batch_stride_C), -+ batch_stride_D(args.batch_stride_D), -+ semaphore(static_cast(workspace)) { -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void update( -+ Arguments const &args, -+ void *workspace = nullptr) { -+ -+ ptr_A = const_cast(args.ptr_A); -+ ptr_B = const_cast(args.ptr_A); -+ ptr_C = const_cast(args.ptr_C); -+ ptr_D = args.ptr_D; -+ -+ output_op = args.epilogue; -+ -+ semaphore = static_cast(workspace); -+ } -+ -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_DEVICE -+ RankKUniversal() { } -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement( -+ cutlass::gemm::GemmCoord const & problem_size) { -+ -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) || -+ (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) || -+ (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) { -+ -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ static Status can_implement(Arguments const &args) { -+ return can_implement(args.problem_size); -+ } -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ return; -+ } -+ -+ // Early exit if Fill Mode is Lower and -+ // if the entire tile is above the main diagonal (bottom-left corner is at or above the diagonal) -+ if (kFillModeC == cutlass::FillMode::kLower && -+ (threadblock_tile_offset.m() + 1) * Mma::Shape::kM <= threadblock_tile_offset.n() * Mma::Shape::kN) { -+ return; -+ } -+ -+ // Early exit if Fill Mode is Upper and -+ // if the entire tile is below the main diagonal (top-right corner is at or below the diagonal) -+ if (kFillModeC == cutlass::FillMode::kUpper && -+ threadblock_tile_offset.m() * Mma::Shape::kM >= (threadblock_tile_offset.n() + 1) * Mma::Shape::kN) { -+ return; -+ } -+ -+ bool tile_on_diagonal = false; -+ // Mark tiles that are being crossed by the main diagonal -+ // (top-right and bottom-left corners are on either side of the diagonal) -+ if ((threadblock_tile_offset.m() + 1) * Mma::Shape::kM > threadblock_tile_offset.n() * Mma::Shape::kN -+ && threadblock_tile_offset.m() * Mma::Shape::kM < (threadblock_tile_offset.n() + 1) * Mma::Shape::kN) { -+ tile_on_diagonal = true; -+ } -+ -+ int offset_k = 0; -+ int problem_size_k = params.problem_size.k(); -+ -+ ElementA *ptr_A = static_cast(params.ptr_A); -+ ElementB *ptr_B = static_cast(params.ptr_B); -+ -+ // -+ // Fetch pointers based on mode. -+ // -+ if (params.mode == GemmUniversalMode::kGemm || -+ params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ -+ if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { -+ -+ problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; -+ } -+ -+ offset_k = threadblock_tile_offset.k() * params.gemm_k_size; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; -+ ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; -+ ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; -+ } -+ -+ __syncthreads(); -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ offset_k, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ offset_k, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ }; -+ -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.params_A, -+ ptr_A, -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B( -+ params.params_B, -+ ptr_B, -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_B); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = canonical_warp_idx(); -+ -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma( -+ gemm_k_iterations, -+ accumulators, -+ iterator_A, -+ iterator_B, -+ accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ ElementC *ptr_C = static_cast(params.ptr_C); -+ ElementC *ptr_D = static_cast(params.ptr_D); -+ -+ // -+ // Fetch pointers based on mode. -+ // -+ -+ // Construct the semaphore. -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ if (params.mode == GemmUniversalMode::kGemm) { -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. -+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } -+ } -+ else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_C += threadblock_tile_offset.k() * params.batch_stride_C; -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_C = static_cast(params.ptr_C)[threadblock_tile_offset.k()]; -+ ptr_D = static_cast(params.ptr_D)[threadblock_tile_offset.k()]; -+ } -+ -+ -+ // If CTA not on diagonal, FillMode doesn't apply. -+ FillMode kFillModeCTA = tile_on_diagonal ? kFillModeC : FillMode::kNone; -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.params_C, -+ ptr_C, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset, -+ kFillModeCTA -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ ptr_D, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset, -+ kFillModeCTA -+ ); -+ -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_offset.k()) { -+ iterator_C = iterator_D; -+ } -+ -+ semaphore.wait(threadblock_tile_offset.k()); -+ -+ __threadfence(); -+ } -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue( -+ output_op, -+ iterator_D, -+ accumulators, -+ iterator_C); -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_offset.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/sm70_gemm.hpp b/3rdparty/cutlass/include/cutlass/gemm/kernel/sm70_gemm.hpp -new file mode 100644 -index 0000000..efe51e2 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/sm70_gemm.hpp -@@ -0,0 +1,252 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/kernel_hardware_info.hpp" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/dispatch_policy.hpp" -+ -+#include "cute/tensor.hpp" -+ -+namespace cutlass::gemm::kernel { -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ class ProblemShape_, -+ class CollectiveMainloop_, -+ class CollectiveEpilogue_, -+ class GridSwizzle_ -+> -+class GemmUniversal< -+ ProblemShape_, -+ CollectiveMainloop_, -+ CollectiveEpilogue_, -+ GridSwizzle_, -+ std::enable_if_t>> -+{ -+public: -+ // -+ // Type Aliases -+ // -+ using ProblemShape = ProblemShape_; -+ using GridSwizzle = GridSwizzle_; -+ static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, -+ "ProblemShape{} should be or "); -+ -+ // Mainloop derived types -+ using CollectiveMainloop = CollectiveMainloop_; -+ using TileShape = typename CollectiveMainloop::TileShape; -+ using TiledMma = typename CollectiveMainloop::TiledMma; -+ using ArchTag = typename CollectiveMainloop::ArchTag; -+ using ElementA = typename CollectiveMainloop::ElementA; -+ using StrideA = typename CollectiveMainloop::StrideA; -+ using ElementB = typename CollectiveMainloop::ElementB; -+ using StrideB = typename CollectiveMainloop::StrideB; -+ using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; -+ using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; -+ using MainloopParams = typename CollectiveMainloop::Params; -+ -+ // Epilogue derived types -+ using CollectiveEpilogue = CollectiveEpilogue_; -+ using ElementC = typename CollectiveEpilogue::ElementC; -+ using StrideC = typename CollectiveEpilogue::StrideC; -+ using ElementD = typename CollectiveEpilogue::ElementD; -+ using StrideD = typename CollectiveEpilogue::StrideD; -+ using EpilogueParams = typename CollectiveEpilogue::Params; -+ static_assert(std::is_same_v, -+ "Mainloop and epilogue do not agree on accumulator value type."); -+ -+ static constexpr int SharedStorageSize = cute::max( -+ sizeof(typename CollectiveMainloop::SharedStorage), -+ sizeof(typename CollectiveEpilogue::SharedStorage)); -+ -+ static constexpr uint32_t MaxThreadsPerBlock = cute::size(TiledMma{}); -+ static constexpr uint32_t MinBlocksPerMultiprocessor = 1; -+ -+ // Device side arguments -+ struct Arguments { -+ GemmUniversalMode mode{}; -+ ProblemShape problem_shape{}; -+ ElementA const* ptr_A = nullptr; -+ StrideA dA{}; -+ ElementB const* ptr_B = nullptr; -+ StrideB dB{}; -+ EpilogueParams epilogue_params{}; -+ KernelHardwareInfo hw_info; -+ }; -+ -+ // Kernel entry point API -+ struct Params { -+ GemmUniversalMode mode; -+ ProblemShape problem_shape; -+ MainloopParams mainloop; -+ EpilogueParams epilogue; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ // Convert to underlying arguments. In this case, a simple copy for the aliased type. -+ static -+ Params -+ to_underlying_arguments(Arguments const& args, void* workspace) { -+ (void) workspace; -+ return { -+ args.mode, -+ args.problem_shape, -+ CollectiveMainloop::to_underlying_arguments(args, workspace), -+ CollectiveEpilogue::to_underlying_arguments(args, workspace) -+ }; -+ } -+ -+ static -+ bool -+ can_implement(Arguments const& args) { -+ return args.mode == GemmUniversalMode::kGemm or -+ (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); -+ } -+ -+ static -+ int -+ get_workspace_size(Arguments const& args) { -+ return 0; -+ } -+ -+ static constexpr -+ dim3 -+ get_grid_shape(Params const& params) { -+ int batch_count = 1; -+ if constexpr (rank(ProblemShape{}) == 4) { -+ batch_count = cute::size<3>(params.problem_shape); -+ } -+ -+ return dim3( -+ cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(TileShape{}))), -+ cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(TileShape{}))), -+ batch_count -+ ); -+ } -+ -+ static constexpr -+ dim3 -+ get_block_shape() { -+ return dim3(MaxThreadsPerBlock, 1, 1); -+ } -+ -+ CUTLASS_DEVICE -+ void -+ operator()(Params const& params, char* smem_buf) { -+ using namespace cute; -+ using X = Underscore; -+ -+ // Preconditions -+ CUTE_STATIC_ASSERT(is_static::value); -+ -+ // Separate out problem shape for convenience -+ // Optionally append _1s until problem shape is rank-4 in case its is only rank-3 (MNK) -+ auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); -+ auto M = get<0>(problem_shape_MNKL); -+ auto N = get<1>(problem_shape_MNKL); -+ auto K = get<2>(problem_shape_MNKL); -+ auto L = get<3>(problem_shape_MNKL); -+ -+ // Preconditions -+ static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); -+ static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); -+ static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); -+ static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); -+ -+ // Get the appropriate blocks for this thread block -- potential for thread block locality -+ int thread_idx = int(threadIdx.x); -+ auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) -+ auto [m_coord, n_coord, l_coord] = blockIdx; -+ auto blk_coord_mnkl = make_coord(m_coord, n_coord, _, l_coord); // (m,n,k,l) -+ -+ // Represent the full tensors -+ Tensor mA_mkl = make_tensor(make_gmem_ptr(params.mainloop.ptr_A), make_shape(M,K,L), params.mainloop.dA); //(m,k,l) -+ Tensor mB_nkl = make_tensor(make_gmem_ptr(params.mainloop.ptr_B), make_shape(N,K,L), params.mainloop.dB); //(n,k,l) -+ -+ // Get batch slice -+ Tensor mA_mk = mA_mkl(_,_,l_coord); // (m,k) -+ Tensor mB_nk = mB_nkl(_,_,l_coord); // (n,k) -+ -+ // Slice to get the tiles this thread block is responsible for -+ Tensor gA = local_tile(mA_mk, blk_shape, take<0,3>(blk_coord_mnkl), Step<_1, X,_1>{}); // (BLK_M,BLK_K,k) -+ Tensor gB = local_tile(mB_nk, blk_shape, take<0,3>(blk_coord_mnkl), Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) -+ -+ // Compute tile residues for predication -+ auto m_max_coord = M - size<0>(gA) * get<0>(blk_coord_mnkl); // M - BLK_M * m_coord -+ auto n_max_coord = N - size<0>(gB) * get<1>(blk_coord_mnkl); // N - BLK_N * n_coord -+ auto k_residue = K - size<1>(gA) * size<2>(gA); // K - BLK_K * k_coord_max -+ auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue); -+ -+ // Allocate the tiled_mma and the accumulators for the (M,N) blk_shape -+ TiledMma tiled_mma; -+ Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) -+ clear(accumulators); -+ -+ auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); -+ int k_tile_count = size<2>(gA); -+ -+ // Perform the collective scoped MMA -+ CollectiveMainloop collective_mma; -+ collective_mma( -+ accumulators, -+ gA, -+ gB, -+ accumulators, -+ k_tile_iter, k_tile_count, -+ residue_mnk, -+ thread_idx, -+ smem_buf -+ ); -+ -+ // Epilogue and write to gD -+ CollectiveEpilogue epilogue{params.epilogue}; -+ epilogue( -+ problem_shape_MNKL, -+ blk_shape, -+ blk_coord_mnkl, -+ accumulators, -+ tiled_mma, -+ residue_mnk, -+ thread_idx, -+ smem_buf -+ ); -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass::gemm::kernel -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp b/3rdparty/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp -new file mode 100644 -index 0000000..bd82ed1 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp -@@ -0,0 +1,301 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/kernel_hardware_info.hpp" -+#include "cute/arch/cluster_sm90.hpp" -+#include "cutlass/arch/mma_sm90.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/dispatch_policy.hpp" -+#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" -+ -+#include "cute/tensor.hpp" -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass::gemm::kernel { -+ -+namespace detail { -+ -+// IF_SWAP_AB::value will be true only if: -+// class T has member SwapAB and T::SwapAB is true -+template -+struct IF_SWAP_AB { static constexpr bool value = false; }; -+ -+template -+struct IF_SWAP_AB > -+{ static constexpr bool value = T::SwapAB; }; -+ -+} // namespace -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ class ProblemShape_, -+ class CollectiveMainloop_, -+ class CollectiveEpilogue_, -+ class GridSwizzle_ -+> -+class GemmUniversal< -+ ProblemShape_, -+ CollectiveMainloop_, -+ CollectiveEpilogue_, -+ GridSwizzle_, -+ std::enable_if_t>> -+{ -+public: -+ // -+ // Type Aliases -+ // -+ using ProblemShape = ProblemShape_; -+ using GridSwizzle = GridSwizzle_; -+ static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, -+ "ProblemShape{} should be or "); -+ -+ // Mainloop derived types -+ using CollectiveMainloop = CollectiveMainloop_; -+ using TileShape = typename CollectiveMainloop::TileShape; -+ using TiledMma = typename CollectiveMainloop::TiledMma; -+ using ArchTag = typename CollectiveMainloop::ArchTag; -+ using ElementA = typename CollectiveMainloop::ElementA; -+ using StrideA = typename CollectiveMainloop::StrideA; -+ using ElementB = typename CollectiveMainloop::ElementB; -+ using StrideB = typename CollectiveMainloop::StrideB; -+ using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; -+ using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; -+ using ClusterShape = typename DispatchPolicy::ClusterShape; -+ using MainloopParams = typename CollectiveMainloop::Params; -+ static_assert(ArchTag::kMinComputeCapability >= 90); -+ -+ // Epilogue derived types -+ using CollectiveEpilogue = CollectiveEpilogue_; -+ using ElementC = typename CollectiveEpilogue::ElementC; -+ using StrideC = typename CollectiveEpilogue::StrideC; -+ using ElementD = typename CollectiveEpilogue::ElementD; -+ using StrideD = typename CollectiveEpilogue::StrideD; -+ using EpilogueParams = typename CollectiveEpilogue::Params; -+ static_assert(std::is_same_v, -+ "Mainloop and epilogue do not agree on accumulator value type."); -+ -+ static constexpr int SharedStorageSize = cute::max( -+ sizeof(typename CollectiveMainloop::SharedStorage), -+ sizeof(typename CollectiveEpilogue::SharedStorage)); -+ -+ static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}); -+ static constexpr uint32_t MinBlocksPerMultiprocessor = 1; -+ -+ // Device side arguments -+ struct Arguments { -+ GemmUniversalMode mode{}; -+ ProblemShape problem_shape{}; -+ ElementA const* ptr_A = nullptr; -+ StrideA dA{}; -+ ElementB const* ptr_B = nullptr; -+ StrideB dB{}; -+ EpilogueParams epilogue_params{}; -+ KernelHardwareInfo hw_info; -+ }; -+ -+ // Kernel entry point API -+ struct Params { -+ GemmUniversalMode mode; -+ ProblemShape problem_shape; -+ MainloopParams mainloop; -+ EpilogueParams epilogue; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ // Convert to underlying arguments. In this case, a simple copy for the aliased type. -+ static -+ Params -+ to_underlying_arguments(Arguments const& args, void* workspace) { -+ (void) workspace; -+ auto problem_shape = args.problem_shape; -+ if constexpr (detail::IF_SWAP_AB::value) { -+ // swap M/N -+ get<0>(problem_shape) = get<1>(args.problem_shape); -+ get<1>(problem_shape) = get<0>(args.problem_shape); -+ } -+ return { -+ args.mode, -+ problem_shape, -+ CollectiveMainloop::to_underlying_arguments(args, workspace), -+ CollectiveEpilogue::to_underlying_arguments(args, workspace) -+ }; -+ } -+ -+ CUTLASS_HOST_DEVICE static -+ bool -+ can_implement(Arguments const& args) { -+ return args.mode == GemmUniversalMode::kGemm or -+ (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); -+ } -+ -+ static -+ int -+ get_workspace_size(Arguments const& args) { -+ return 0; -+ } -+ -+ // Computes the kernel launch grid shape based on runtime parameters -+ static constexpr -+ dim3 -+ get_grid_shape(Params const& params) { -+ auto cluster_shape = ClusterShape{}; -+ auto tile_shape = TileShape{}; -+ auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); -+ return detail::PersistentTileSchedulerSm90::get_tiled_blk_shape_mnl( -+ problem_shape_MNKL, tile_shape, cluster_shape); -+ } -+ -+ static constexpr -+ dim3 -+ get_block_shape() { -+ return dim3(MaxThreadsPerBlock, 1, 1); -+ } -+ -+ CUTLASS_DEVICE -+ void -+ operator()(Params const& params, char* smem_buf) { -+ using namespace cute; -+ using X = Underscore; -+ -+ // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. -+ #if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) -+ if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) { -+ printf("ERROR : Arch conditional MMA instruction used without targetting sm90a compute capability. Aborting.\n"); -+ return; -+ } -+ #endif -+ -+ // Preconditions -+ static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); -+ static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); -+ static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); -+ static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); -+ -+ int thread_idx = int(threadIdx.x); -+ int warp_idx = canonical_warp_idx(); -+ int lane_predicate = cute::elect_one_sync(); -+ -+ // Issue Tma Descriptor Prefetch from a single thread -+ if ((warp_idx == 0) && lane_predicate) { -+ CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); -+ } -+ -+ // Separate out problem shape for convenience -+ // Optionally append _1s until problem shape is rank-4 in case its is only rank-3 (MNK) -+ auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); -+ auto M = get<0>(problem_shape_MNKL); -+ auto N = get<1>(problem_shape_MNKL); -+ auto K = get<2>(problem_shape_MNKL); -+ auto L = get<3>(problem_shape_MNKL); -+ -+ // TMA requires special handling of strides to deal with coord codomain mapping -+ // Represent the full tensors -- get these from TMA -+ Tensor mA_mkl = params.mainloop.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) -+ Tensor mB_nkl = params.mainloop.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) -+ -+ // Get the appropriate blocks for this thread block -- potential for thread block locality -+ auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) -+ auto blk_coord = make_coord(_,_,_); // (m,n,k) -- defer the slice -+ -+ // Make tiled views -+ Tensor gA_mkl = local_tile(mA_mkl, blk_shape, blk_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) -+ Tensor gB_nkl = local_tile(mB_nkl, blk_shape, blk_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) -+ -+ // Compute m_coord, n_coord, and l_coord with their post-tiled shapes -+ auto m_coord = idx2crd(int(blockIdx.x), shape<2>(gA_mkl)); -+ auto n_coord = idx2crd(int(blockIdx.y), shape<2>(gB_nkl)); -+ auto l_coord = idx2crd(int(blockIdx.z), shape<4>(gB_nkl)); -+ auto output_tile_coord = make_coord(m_coord, n_coord, _, l_coord); -+ -+ // Slice with m_coord and n_coord -+ Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) -+ Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) -+ -+ // Allocate the tiled_mma and the accumulators for the (M,N) blk_shape -+ TiledMma tiled_mma; -+ Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) -+ -+ clear(accumulators); -+ -+ auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); -+ auto k_tile_count = size<2>(gA); -+ -+ // Perform the collective scoped MMA -+ CollectiveMainloop collective_mma; -+ collective_mma( -+ gA, params.mainloop.tma_load_a, -+ gB, params.mainloop.tma_load_b, -+ accumulators, -+ k_tile_iter, k_tile_count, -+ thread_idx, -+ smem_buf, -+ params.mainloop -+ ); -+ -+ constexpr int BLK_M_RANK = rank<0>(blk_shape); -+ bool m_oob = int(blockIdx.x) >= size<2>(gA_mkl); -+ auto m_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { -+ return m_oob ? 0 : get(M) - get<0,i>(blk_shape) * get(m_coord); -+ })); -+ -+ constexpr int BLK_N_RANK = rank<1>(blk_shape); -+ bool n_oob = int(blockIdx.y) >= size<2>(gB_nkl); -+ auto n_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { -+ return n_oob ? 0 : get(N) - get<1,i>(blk_shape) * get(n_coord); -+ })); -+ auto residue_mnk = make_tuple(m_max_coord, n_max_coord, Int<0>{}); -+ -+ // Epilogue and write to gD -+ CollectiveEpilogue epilogue{params.epilogue}; -+ epilogue( -+ problem_shape_MNKL, -+ blk_shape, -+ output_tile_coord, -+ accumulators, -+ tiled_mma, -+ residue_mnk, -+ thread_idx, -+ smem_buf -+ ); -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass::gemm::kernel -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp b/3rdparty/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp -new file mode 100644 -index 0000000..9fc719e ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp -@@ -0,0 +1,351 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/kernel_hardware_info.hpp" -+#include "cute/arch/cluster_sm90.hpp" -+#include "cutlass/arch/reg_reconfig.h" -+#include "cutlass/arch/mma_sm90.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/dispatch_policy.hpp" -+#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" -+#include "cutlass/pipeline.hpp" -+#include "cute/tensor.hpp" -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass::gemm::kernel { -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ class ProblemShape_, -+ class CollectiveMainloop_, -+ class CollectiveEpilogue_, -+ class GridSwizzle_ -+> -+class GemmUniversal< -+ ProblemShape_, -+ CollectiveMainloop_, -+ CollectiveEpilogue_, -+ GridSwizzle_, -+ std::enable_if_t>> -+{ -+public: -+ // -+ // Type Aliases -+ // -+ using ProblemShape = ProblemShape_; -+ using GridSwizzle = GridSwizzle_; -+ static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, -+ "ProblemShape{} should be or "); -+ -+ // Mainloop derived types -+ using CollectiveMainloop = CollectiveMainloop_; -+ using TileShape = typename CollectiveMainloop::TileShape; -+ using TiledMma = typename CollectiveMainloop::TiledMma; -+ using ArchTag = typename CollectiveMainloop::ArchTag; -+ using ElementA = typename CollectiveMainloop::ElementA; -+ using StrideA = typename CollectiveMainloop::StrideA; -+ using ElementB = typename CollectiveMainloop::ElementB; -+ using StrideB = typename CollectiveMainloop::StrideB; -+ using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; -+ using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; -+ using ClusterShape = typename DispatchPolicy::ClusterShape; -+ using MainloopParams = typename CollectiveMainloop::Params; -+ static_assert(ArchTag::kMinComputeCapability >= 90); -+ -+ // Epilogue derived types -+ using CollectiveEpilogue = CollectiveEpilogue_; -+ using ElementC = typename CollectiveEpilogue::ElementC; -+ using StrideC = typename CollectiveEpilogue::StrideC; -+ using ElementD = typename CollectiveEpilogue::ElementD; -+ using StrideD = typename CollectiveEpilogue::StrideD; -+ using EpilogueParams = typename CollectiveEpilogue::Params; -+ static_assert(std::is_same_v, -+ "Mainloop and epilogue do not agree on accumulator value type."); -+ -+ static constexpr int SharedStorageSize = cute::max( -+ sizeof(typename CollectiveMainloop::SharedStorage), -+ sizeof(typename CollectiveEpilogue::SharedStorage)); -+ -+ static constexpr uint32_t NumDmaWarpGroups = 1; -+ static constexpr uint32_t NumMmaWarpGroups = 1; -+ static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}) + (NumDmaWarpGroups * NumThreadsPerWarpGroup); -+ static constexpr uint32_t MinBlocksPerMultiprocessor = 1; -+ -+ // Device side arguments -+ struct Arguments { -+ GemmUniversalMode mode{}; -+ ProblemShape problem_shape{}; -+ ElementA const* ptr_A = nullptr; -+ StrideA dA{}; -+ ElementB const* ptr_B = nullptr; -+ StrideB dB{}; -+ EpilogueParams epilogue_params{}; -+ KernelHardwareInfo hw_info; -+ }; -+ -+ // Kernel entry point API -+ struct Params { -+ GemmUniversalMode mode; -+ ProblemShape problem_shape; -+ MainloopParams mainloop; -+ EpilogueParams epilogue; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ // Convert to underlying arguments. In this case, a simple copy for the aliased type. -+ static -+ Params -+ to_underlying_arguments(Arguments const& args, void* workspace) { -+ (void) workspace; -+ auto problem_shape = args.problem_shape; -+ if constexpr (detail::IF_SWAP_AB::value) { -+ // swap M/N -+ get<0>(problem_shape) = get<1>(args.problem_shape); -+ get<1>(problem_shape) = get<0>(args.problem_shape); -+ } -+ return { -+ args.mode, -+ problem_shape, -+ CollectiveMainloop::to_underlying_arguments(args, workspace), -+ CollectiveEpilogue::to_underlying_arguments(args, workspace) -+ }; -+ } -+ -+ CUTLASS_HOST_DEVICE static -+ bool -+ can_implement(Arguments const& args) { -+ return args.mode == GemmUniversalMode::kGemm or -+ (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); -+ } -+ -+ static -+ int -+ get_workspace_size(Arguments const& args) { -+ return 0; -+ } -+ -+ // Computes the kernel launch grid shape based on runtime parameters -+ static constexpr -+ dim3 -+ get_grid_shape(Params const& params) { -+ auto cluster_shape = ClusterShape{}; -+ auto tile_shape = TileShape{}; -+ auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); -+ return detail::PersistentTileSchedulerSm90::get_tiled_blk_shape_mnl( -+ problem_shape_MNKL, tile_shape, cluster_shape); -+ } -+ -+ static constexpr -+ dim3 -+ get_block_shape() { -+ return dim3(MaxThreadsPerBlock, 1, 1); -+ } -+ -+ CUTLASS_DEVICE -+ void -+ operator()(Params const& params, char* smem_buf) { -+ using namespace cute; -+ using X = Underscore; -+ -+ // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. -+ #if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) -+ if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) { -+ printf("ERROR : Arch conditional MMA instruction used without targetting sm90a compute capability. Aborting.\n"); -+ return; -+ } -+ #endif -+ -+ enum class WarpGroupRole { -+ Producer = 0, -+ Consumer = 1, -+ }; -+ -+ int thread_idx = int(threadIdx.x); -+ int warp_idx = canonical_warp_idx(); -+ int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; -+ auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); -+ int lane_predicate = cute::elect_one_sync(); -+ -+ // Issue Tma Descriptor Prefetch from a single thread -+ if ((warp_idx == 0) && lane_predicate) { -+ CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); -+ } -+ -+ using Pipeline = typename CollectiveMainloop::MainloopPipeline; -+ -+ using PipelineParams = typename CollectiveMainloop::PipelineParams; -+ PipelineParams params_pipeline; -+ params_pipeline.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; -+ if (warp_group_role == WarpGroupRole::Producer) { -+ params_pipeline.role = Pipeline::ThreadCategory::Producer; -+ } -+ else { -+ params_pipeline.role = Pipeline::ThreadCategory::Consumer; -+ } -+ params_pipeline.is_leader = warp_group_thread_idx == 0; -+ params_pipeline.num_consumers = NumThreadsPerWarpGroup; -+ -+ // Initialize pipeline and setup starting pipeline state for the collectives -+ Pipeline pipeline = CollectiveMainloop::make_pipeline(smem_buf, params_pipeline); -+ -+ auto cluster_wait_fn = [&] () { -+ // We need this to guarantee that the Pipeline init is visible -+ // To all producers and consumer thread blocks in the Cluster -+ if constexpr (size(ClusterShape{}) > 1) { -+ cute::cluster_arrive_relaxed(); -+ return [] () { cute::cluster_wait(); }; -+ } -+ else { -+ __syncthreads(); -+ return [] () {}; // do nothing -+ } -+ } (); -+ -+ // Preconditions -+ static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); -+ static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); -+ static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); -+ static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); -+ -+ // Separate out problem shape for convenience -+ // Optionally append _1s until problem shape is rank-4 in case its is only rank-3 (MNK) -+ auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); -+ auto M = get<0>(problem_shape_MNKL); -+ auto N = get<1>(problem_shape_MNKL); -+ auto K = get<2>(problem_shape_MNKL); -+ auto L = get<3>(problem_shape_MNKL); -+ -+ // TMA requires special handling of strides to deal with coord codomain mapping -+ // Represent the full tensors -- get these from TMA -+ Tensor mA_mkl = params.mainloop.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) -+ Tensor mB_nkl = params.mainloop.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) -+ -+ // Get the appropriate blocks for this thread block -- potential for thread block locality -+ auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) -+ auto blk_coord = make_coord(_,_,_); // (m,n,k) -- defer the slice -+ -+ // Make tiled views -+ Tensor gA_mkl = local_tile(mA_mkl, blk_shape, blk_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) -+ Tensor gB_nkl = local_tile(mB_nkl, blk_shape, blk_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) -+ -+ // Compute m_coord, n_coord, and l_coord with their post-tiled shapes -+ auto m_coord = idx2crd(int(blockIdx.x), shape<2>(gA_mkl)); -+ auto n_coord = idx2crd(int(blockIdx.y), shape<2>(gB_nkl)); -+ auto l_coord = idx2crd(int(blockIdx.z), shape<4>(gB_nkl)); -+ auto output_tile_coord = make_coord(m_coord, n_coord, _, l_coord); -+ -+ // Slice with m_coord and n_coord -+ Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) -+ Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) -+ -+ auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); -+ auto k_tile_count = size<2>(gA); -+ -+ // Wait for all thread blocks in the Cluster -+ cluster_wait_fn(); -+ -+ // In a warp specialized kernel, CollectiveMainloop exposes data movement and compute operations separately -+ CollectiveMainloop collective_mainloop; -+ -+ if (warp_group_role == WarpGroupRole::Producer) { -+ // For the DMA (prologue) - we start with an opposite phase - since we skip all waits -+ // i.e., we know that the buffer is indeed empty -+ typename CollectiveMainloop::PipelineState smem_pipe_write = cutlass::make_producer_start_state(); -+ collective_mainloop.dma( -+ pipeline, -+ smem_pipe_write, -+ gA, params.mainloop.tma_load_a, -+ gB, params.mainloop.tma_load_b, -+ k_tile_iter, k_tile_count, -+ thread_idx, -+ smem_buf -+ ); -+ // Update starting pipeline state for the next tile -+ smem_pipe_write.advance(k_tile_count); -+ // Make sure all Consumer Warp Groups have been waited upon -+ collective_mainloop.dma_epilogue(pipeline, smem_pipe_write); -+ } -+ else if (warp_group_role == WarpGroupRole::Consumer) { -+ typename CollectiveMainloop::PipelineState smem_pipe_read; -+ TiledMma tiled_mma; -+ Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) -+ clear(accumulators); -+ -+ collective_mainloop.mma( -+ pipeline, -+ smem_pipe_read, -+ accumulators, -+ k_tile_count, -+ thread_idx, -+ smem_buf, -+ params.mainloop -+ ); -+ -+ constexpr int BLK_M_RANK = rank<0>(blk_shape); -+ bool m_oob = int(blockIdx.x) >= size<2>(gA_mkl); -+ auto m_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { -+ return m_oob ? 0 : get(M) - get<0,i>(blk_shape) * get(m_coord); -+ })); -+ -+ constexpr int BLK_N_RANK = rank<1>(blk_shape); -+ bool n_oob = int(blockIdx.y) >= size<2>(gB_nkl); -+ auto n_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { -+ return n_oob ? 0 : get(N) - get<1,i>(blk_shape) * get(n_coord); -+ })); -+ auto residue_mnk = make_tuple(m_max_coord, n_max_coord, Int<0>{}); -+ -+ // Epilogue and write to gD -+ CollectiveEpilogue epilogue{params.epilogue}; -+ epilogue( -+ problem_shape_MNKL, -+ blk_shape, -+ output_tile_coord, -+ accumulators, -+ tiled_mma, -+ residue_mnk, -+ warp_group_thread_idx, -+ smem_buf -+ ); -+ } -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass::gemm::kernel -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_persistent.hpp b/3rdparty/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_persistent.hpp -new file mode 100644 -index 0000000..498bfad ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_persistent.hpp -@@ -0,0 +1,487 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/kernel_hardware_info.hpp" -+#include "cutlass/fast_math.h" -+#include "cute/arch/cluster_sm90.hpp" -+#include "cutlass/arch/reg_reconfig.h" -+#include "cutlass/arch/mma_sm90.h" -+#include "cutlass/pipeline.hpp" -+#include "cutlass/trace.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/dispatch_policy.hpp" -+#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" -+ -+#include "cute/tensor.hpp" -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass::gemm::kernel { -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ class ProblemShape_, -+ class CollectiveMainloop_, -+ class CollectiveEpilogue_, -+ class GridSwizzle_ -+> -+class GemmUniversal< -+ ProblemShape_, -+ CollectiveMainloop_, -+ CollectiveEpilogue_, -+ GridSwizzle_, -+ std::enable_if_t>> -+{ -+public: -+ // -+ // Type Aliases -+ // -+ using ProblemShape = ProblemShape_; -+ using GridSwizzle = GridSwizzle_; -+ static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, -+ "ProblemShape{} should be or "); -+ -+ // Mainloop derived types -+ using CollectiveMainloop = CollectiveMainloop_; -+ using TileShape = typename CollectiveMainloop::TileShape; -+ using TiledMma = typename CollectiveMainloop::TiledMma; -+ using ArchTag = typename CollectiveMainloop::ArchTag; -+ using ElementA = typename CollectiveMainloop::ElementA; -+ using StrideA = typename CollectiveMainloop::StrideA; -+ using ElementB = typename CollectiveMainloop::ElementB; -+ using StrideB = typename CollectiveMainloop::StrideB; -+ using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; -+ using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; -+ using ClusterShape = typename DispatchPolicy::ClusterShape; -+ using MainloopParams = typename CollectiveMainloop::Params; -+ static_assert(ArchTag::kMinComputeCapability >= 90); -+ -+ // Epilogue derived types -+ using CollectiveEpilogue = CollectiveEpilogue_; -+ using ElementC = typename CollectiveEpilogue::ElementC; -+ using StrideC = typename CollectiveEpilogue::StrideC; -+ using ElementD = typename CollectiveEpilogue::ElementD; -+ using StrideD = typename CollectiveEpilogue::StrideD; -+ using EpilogueParams = typename CollectiveEpilogue::Params; -+ static_assert(std::is_same_v, -+ "Mainloop and epilogue do not agree on accumulator value type."); -+ -+ static constexpr uint32_t NumDmaWarpGroups = 1; -+ static constexpr uint32_t NumMmaWarpGroups = 2; -+ static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}) + (NumMmaWarpGroups * NumThreadsPerWarpGroup); -+ static constexpr uint32_t MinBlocksPerMultiprocessor = 1; -+ -+ /// Register requirement for DMA and MATH WGs -+ static constexpr uint32_t DmaRegisterRequirement = 40; -+ static constexpr uint32_t MmaRegisterRequirement = 232; -+ -+ /* Order Sequence barrier with two stages: one for Mainloop and one for Epilogue */ -+ static constexpr uint32_t StagesPerMathWarpGroup = 2; -+ using MathWarpGroupOrderBarrier = cutlass::OrderedSequenceBarrier< -+ StagesPerMathWarpGroup, NumMmaWarpGroups>; -+ -+ // Kernel level shared memory storage -+ struct SharedStorage { -+ using MainloopSharedStorage = typename CollectiveMainloop::SharedStorage; -+ using EpilogueSharedStorage = typename CollectiveEpilogue::SharedStorage; -+ using MathWarpGroupOrderBarrierStorage = typename MathWarpGroupOrderBarrier::SharedStorage; -+ -+ MainloopSharedStorage mainloop; -+ EpilogueSharedStorage epilogue; -+ alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order_barrier_storage; -+ }; -+ -+ static constexpr int SharedStorageSize = sizeof(SharedStorage); -+ -+ // Device side arguments -+ struct Arguments { -+ GemmUniversalMode mode{}; -+ ProblemShape problem_shape{}; -+ ElementA const* ptr_A = nullptr; -+ StrideA dA{}; -+ ElementB const* ptr_B = nullptr; -+ StrideB dB{}; -+ EpilogueParams epilogue_params{}; -+ KernelHardwareInfo hw_info; -+ }; -+ -+ // Kernel entry point API -+ struct Params { -+ GemmUniversalMode mode; -+ ProblemShape problem_shape; -+ MainloopParams mainloop; -+ EpilogueParams epilogue; -+ KernelHardwareInfo hw_info; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ // Convert to underlying arguments. In this case, a simple copy for the aliased type. -+ static -+ Params -+ to_underlying_arguments(Arguments const& args, void* workspace) { -+ CUTLASS_TRACE_HOST("to_underlying_arguments():"); -+ -+ (void) workspace; -+ auto problem_shape = args.problem_shape; -+ if constexpr (detail::IF_SWAP_AB::value) { -+ // swap M/N -+ get<0>(problem_shape) = get<1>(args.problem_shape); -+ get<1>(problem_shape) = get<0>(args.problem_shape); -+ } -+ -+ // Get SM count if needed, otherwise use user supplied SM count -+ int sm_count = args.hw_info.sm_count; -+ if (sm_count <= 0) { -+ CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" -+ " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); -+ sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); -+ } -+ -+ CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); -+ return { -+ args.mode, -+ problem_shape, -+ CollectiveMainloop::to_underlying_arguments(args, workspace), -+ CollectiveEpilogue::to_underlying_arguments(args, workspace), -+ {args.hw_info.device_id, sm_count} -+ }; -+ } -+ -+ CUTLASS_HOST_DEVICE static -+ bool -+ can_implement(Arguments const& args) { -+ bool implementable = args.mode == GemmUniversalMode::kGemm or -+ (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); -+ -+ // Number of blocks per problem (without batch) must not exceed 2^31 for the persistent scheduler to calculate using FastDivmod -+ auto problem_shape_MNKL = append<4>(args.problem_shape, Int<1>{}); -+ auto [problem_blocks_m, problem_blocks_n, problem_blocks_l] = -+ detail::PersistentTileSchedulerSm90::get_tiled_blk_shape_mnl(problem_shape_MNKL, TileShape{}, ClusterShape{}); -+ uint64_t problem_blocks = problem_blocks_m * problem_blocks_n * problem_blocks_l; -+ implementable = implementable && (problem_blocks < (uint64_t(1) << 31)); -+ -+ return implementable; -+ } -+ -+ static -+ int -+ get_workspace_size(Arguments const& args) { -+ return 0; -+ } -+ -+ // Computes the kernel launch grid shape based on runtime parameters -+ static constexpr -+ dim3 -+ get_grid_shape(Params const& params) { -+ int sm_count = params.hw_info.sm_count; -+ CUTLASS_TRACE_HOST("get_grid_shape(): Persistent schedule grid plan using SM count = " << sm_count); -+ -+ // Compute the total number of output tiles our problem has -+ auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); -+ auto [problem_blocks_m, problem_blocks_n, problem_blocks_l] = -+ detail::PersistentTileSchedulerSm90::get_tiled_blk_shape_mnl(problem_shape_MNKL, TileShape{}, ClusterShape{}); -+ int problem_blocks_total = problem_blocks_m * problem_blocks_n * problem_blocks_l; -+ -+ // Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently -+ dim3 launch_grid(1, cute::size<1>(ClusterShape{}), 1); -+ -+ // The else path is generic, however, we can avoid some divs if we know Cluster size is 1 -+ if constexpr (size(ClusterShape{}) == 1) { -+ launch_grid.x = std::min(sm_count, problem_blocks_total); -+ } -+ else { -+ /* -+ * Optimal grid size calculation is based on -+ * GH100: 8 GPCs, 72 TPCs (9 TPCs/GPC), 2 SMs/TPC, 144 SMs per full GPU -+ * Hence, maximum SMs per GPC = 18 -+ */ -+ constexpr int max_sm_per_gpc = 18; -+ // Provided SM count could possibly be less than the assumed maximum SMs per GPC -+ int min_num_gpc = sm_count < max_sm_per_gpc ? 1 : sm_count / max_sm_per_gpc; -+ int max_blk_occupancy_per_gpc = max_sm_per_gpc - (max_sm_per_gpc % size(ClusterShape{})); -+ int blk_per_device = min_num_gpc * max_blk_occupancy_per_gpc; -+ -+ launch_grid.x = std::min( -+ blk_per_device / size<1>(ClusterShape{}), -+ problem_blocks_total / size<1>(ClusterShape{})); -+ } -+ -+ return launch_grid; -+ } -+ -+ static constexpr -+ dim3 -+ get_block_shape() { -+ return dim3(MaxThreadsPerBlock, 1, 1); -+ } -+ -+ CUTLASS_DEVICE -+ void -+ operator()(Params const& params, char* smem_buf) { -+ using namespace cute; -+ using X = Underscore; -+ -+ // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. -+ #if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) -+ if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) { -+ printf("ERROR : Arch conditional MMA instruction used without targetting sm90a compute capability. Aborting.\n"); -+ return; -+ } -+ #endif -+ -+ // Preconditions -+ static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); -+ static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); -+ static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); -+ static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); -+ -+ enum class WarpGroupRole { -+ Producer = 0, -+ Consumer0 = 1, -+ Consumer1 = 2 -+ }; -+ -+ // Kernel level shared memory storage -+ SharedStorage& shared_storage = *reinterpret_cast(smem_buf); -+ -+ int thread_idx = int(threadIdx.x); -+ int warp_idx = canonical_warp_idx(); -+ int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; -+ auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); -+ int lane_predicate = cute::elect_one_sync(); -+ -+ // Issue Tma Descriptor Prefetch from a single thread -+ if ((warp_idx == 0) && lane_predicate) { -+ CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); -+ } -+ -+ using Pipeline = typename CollectiveMainloop::MainloopPipeline; -+ using PipelineParams = typename CollectiveMainloop::PipelineParams; -+ PipelineParams params_pipeline; -+ params_pipeline.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; -+ if (warp_group_role == WarpGroupRole::Producer) { -+ params_pipeline.role = Pipeline::ThreadCategory::Producer; -+ } -+ else { -+ params_pipeline.role = Pipeline::ThreadCategory::Consumer; -+ } -+ params_pipeline.is_leader = warp_group_thread_idx == 0; -+ params_pipeline.num_consumers = NumThreadsPerWarpGroup; -+ -+ // Initialize pipeline and setup starting pipeline state for the collectives -+ Pipeline pipeline = CollectiveMainloop::make_pipeline(smem_buf, params_pipeline); -+ typename CollectiveMainloop::PipelineState collective_start_state_pipe; -+ -+ typename MathWarpGroupOrderBarrier::Params params_math_wg_order_barrier; -+ // DMA WG will not participate in these Ordered Barrier syncs -+ params_math_wg_order_barrier.group_id = canonical_warp_group_idx() - static_cast(WarpGroupRole::Consumer0); -+ params_math_wg_order_barrier.group_size = NumThreadsPerWarpGroup; // Number of threads / participants in a group -+ MathWarpGroupOrderBarrier math_wg_order_barrier(shared_storage.math_wg_order_barrier_storage, params_math_wg_order_barrier); -+ -+ auto cluster_wait_fn = [&] () { -+ // We need this to guarantee that the Pipeline init is visible -+ // To all producers and consumer thread blocks in the Cluster -+ if constexpr (size(ClusterShape{}) > 1) { -+ cute::cluster_arrive_relaxed(); -+ return [] () { cute::cluster_wait(); }; -+ } -+ else { -+ __syncthreads(); -+ return [] () {}; // do nothing -+ } -+ } (); -+ -+ // Separate out problem shape for convenience -+ // Optionally append _1s until problem shape is rank-4 in case its is only rank-3 (MNK) -+ auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); -+ auto M = get<0>(problem_shape_MNKL); -+ auto N = get<1>(problem_shape_MNKL); -+ auto K = get<2>(problem_shape_MNKL); -+ auto L = get<3>(problem_shape_MNKL); -+ -+ // TMA requires special handling of strides to deal with coord codomain mapping -+ // Represent the full tensors -- get these from TMA -+ Tensor mA_mkl = params.mainloop.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) -+ Tensor mB_nkl = params.mainloop.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) -+ -+ // Get the appropriate blocks for this thread block -- potential for thread block locality -+ auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) -+ auto blk_coord = make_coord(_,_,_); // (m,n,k) -- defer the slice -+ -+ // Slice to get the tiles this thread block is responsible for -+ Tensor gA_mkl = local_tile(mA_mkl, blk_shape, blk_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) -+ Tensor gB_nkl = local_tile(mB_nkl, blk_shape, blk_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) -+ -+ // Get iterations along k-dimension -+ auto k_tile_count = size<3>(gA_mkl); -+ -+ detail::PersistentTileSchedulerSm90 scheduler(problem_shape_MNKL, blk_shape, ClusterShape{}); -+ -+ if (warp_group_role == WarpGroupRole::Consumer1) { -+ /* Advance 2nd Math WG to the next work tile for the startup */ -+ scheduler.advance_to_next_work(); -+ /* Advance 2nd Math WG pipeline state to the end of 1st Math WG */ -+ collective_start_state_pipe.advance(k_tile_count); -+ } -+ auto work_tile_info = scheduler.get_current_work(); -+ -+ // Perform the collective scoped MMA -+ CollectiveMainloop collective_mainloop; -+ -+ // Wait for all thread blocks in the Cluster -+ cluster_wait_fn(); -+ -+ if (warp_group_role == WarpGroupRole::Producer) { -+ cutlass::arch::warpgroup_reg_dealloc(); -+ -+ // For the DMA (prologue) - we start with an opposite phase - since we skip all waits -+ // i.e., we know that the buffer is indeed empty -+ typename CollectiveMainloop::PipelineState smem_pipe_write = cutlass::make_producer_start_state(); -+ while (work_tile_info.is_valid_tile) { -+ // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape -+ auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); -+ auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); -+ auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); -+ auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); -+ -+ // Slice with our work tile coordinates to construct mainloop tensor views -+ Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) -+ Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) -+ -+ auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); -+ -+ collective_mainloop.dma( -+ pipeline, -+ smem_pipe_write, -+ gA, params.mainloop.tma_load_a, -+ gB, params.mainloop.tma_load_b, -+ k_tile_iter, k_tile_count, -+ thread_idx, -+ reinterpret_cast(&shared_storage.mainloop) -+ ); -+ // Update starting pipeline state for the next tile -+ smem_pipe_write.advance(k_tile_count); -+ scheduler.advance_to_next_work(); -+ work_tile_info = scheduler.get_current_work(); -+ } // Scheduler work fetch loop -+ -+ // Make sure all Consumer Warp Groups have been waited upon -+ collective_mainloop.dma_epilogue(pipeline, smem_pipe_write); -+ } // Producer Warp Group End -+ -+ else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { -+ // Allocate the tiled_mma and the accumulators for the (M,N) blk_shape -+ cutlass::arch::warpgroup_reg_alloc(); -+ -+ while (work_tile_info.is_valid_tile) { -+ // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape -+ auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); -+ auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); -+ auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); -+ auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); -+ -+ // Slice with our work tile coordinates to construct mainloop tensor views -+ Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) -+ Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) -+ -+ auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); -+ -+ TiledMma tiled_mma; -+ Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) -+ clear(accumulators); -+ -+ /* Order two Math WG's MMA one after the other, helps hide Epilogue */ -+ math_wg_order_barrier.wait(); -+ -+ collective_mainloop.mma( -+ pipeline, -+ collective_start_state_pipe, -+ accumulators, -+ k_tile_count, -+ thread_idx, -+ reinterpret_cast(&shared_storage.mainloop), -+ params.mainloop -+ ); -+ -+ /* Cue for next Math WG's MMA to start */ -+ math_wg_order_barrier.arrive(); -+ -+ /* Order two Math WG's Epilogue one after the other */ -+ math_wg_order_barrier.wait(); -+ -+ constexpr int BLK_M_RANK = rank<0>(blk_shape); -+ bool m_oob = int(work_tile_info.M_idx) >= size<2>(gA_mkl); -+ auto m_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { -+ return m_oob ? 0 : get(M) - get<0,i>(blk_shape) * get(m_coord); -+ })); -+ -+ constexpr int BLK_N_RANK = rank<1>(blk_shape); -+ bool n_oob = int(work_tile_info.N_idx) >= size<2>(gB_nkl); -+ auto n_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { -+ return n_oob ? 0 : get(N) - get<1,i>(blk_shape) * get(n_coord); -+ })); -+ auto residue_mnk = make_tuple(m_max_coord, n_max_coord, Int<0>{}); -+ -+ // Epilogue and write to gD -+ CollectiveEpilogue epilogue{params.epilogue}; -+ epilogue( -+ problem_shape_MNKL, -+ blk_shape, -+ blk_coord, -+ accumulators, -+ tiled_mma, -+ residue_mnk, -+ warp_group_thread_idx, -+ reinterpret_cast(&shared_storage.epilogue) -+ ); -+ -+ /* Cue for next Math WG's Epilogue to start */ -+ math_wg_order_barrier.arrive(); -+ -+ // Update starting pipeline state for the next tile -+ collective_start_state_pipe.advance(k_tile_count * NumMmaWarpGroups); -+ -+ scheduler.advance_to_next_work(NumMmaWarpGroups); -+ work_tile_info = scheduler.get_current_work(); -+ } // Scheduler work fetch loop -+ } // Consumer Warp Groups End -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass::gemm::kernel -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp b/3rdparty/cutlass/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp -new file mode 100644 -index 0000000..496d5e0 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp -@@ -0,0 +1,133 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include "cutlass/fast_math.h" -+#include "cute/layout.hpp" -+ -+namespace cutlass::gemm::kernel::detail { -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+// Persistent Thread Block (TB) scheduler -+class PersistentTileSchedulerSm90 { -+ // -+ // Data members -+ // -+ -+private: -+ uint32_t blocks_per_problem_; -+ uint32_t current_work_linear_idx_; -+ uint32_t grid_blocks_total_; -+ -+ FastDivmod divmod_batch_; -+ FastDivmod divmod_grid_y_; -+ FastDivmod divmod_blk_m_; -+ -+ struct WorkTileInfo { -+ int32_t M_idx = 0; -+ int32_t N_idx = 0; -+ int32_t L_idx = 0; -+ uint32_t is_valid_tile = false; -+ }; -+ -+ // -+ // Methods -+ // -+ -+public: -+ -+ template -+ CUTLASS_DEVICE -+ PersistentTileSchedulerSm90(ProblemShapeMNKL problem_shape_mnkl, TileShape tile_shape, ClusterShape cluster_shape) { -+ // We only need the tile and cluster shape during scheduler setup, so let FTAD do the magic -+ static_assert(is_static::value); -+ static_assert(is_static::value); -+ -+ // Round up to nearest multiple of cluster dim along each mode -+ auto [problem_blocks_m, problem_blocks_n, problem_blocks_l] = get_tiled_blk_shape_mnl( -+ problem_shape_mnkl, tile_shape, cluster_shape); -+ -+ blocks_per_problem_ = problem_blocks_m * problem_blocks_n * problem_blocks_l; -+ current_work_linear_idx_ = (int(blockIdx.x) * int(gridDim.y)) + int(blockIdx.y); -+ grid_blocks_total_ = int(gridDim.x) * int(gridDim.y); -+ -+ // Pre-compute our fast div/mods for rasterization so we don't have to pay for DIVs -+ divmod_batch_ = FastDivmod(problem_blocks_m * problem_blocks_n); -+ divmod_grid_y_ = FastDivmod(size<1>(cluster_shape)); -+ divmod_blk_m_ = FastDivmod(problem_blocks_m); -+ } -+ -+ CUTLASS_DEVICE -+ WorkTileInfo -+ get_current_work() const { -+ // Map worker's linear index into the CTA tiled problem shape to the corresponding MNL indices -+ int work_idx_l, remainder; -+ divmod_batch_(work_idx_l, remainder, current_work_linear_idx_); -+ -+ int blk_per_grid_dim, dontcare; -+ divmod_grid_y_(blk_per_grid_dim, dontcare, remainder); -+ -+ int block_idx_m, block_idx_n; -+ divmod_blk_m_(block_idx_n, block_idx_m, blk_per_grid_dim); -+ int work_idx_m = block_idx_m; -+ int work_idx_n = (block_idx_n * gridDim.y) + blockIdx.y; -+ -+ return {work_idx_m, work_idx_n, work_idx_l, current_work_linear_idx_ < blocks_per_problem_}; -+ } -+ -+ CUTLASS_DEVICE -+ void -+ advance_to_next_work(uint32_t advance_count = 1) { -+ current_work_linear_idx_ += grid_blocks_total_ * advance_count; -+ } -+ -+ // Given the inputs, computes the total number of output blocks this problem will compute over -+ // Note that this is only the logical size of our grid, not the physical grid we will actually launch. -+ template -+ CUTLASS_HOST_DEVICE constexpr static -+ dim3 -+ get_tiled_blk_shape_mnl(ProblemShapeMNKL problem_shape_mnkl, BlockShape blk_shape, ClusterShape cluster_shape) { -+ // Across M and N is our Cluster tile, so we must round up the blocks to the nearest whole number of Cluster tiles -+ auto blk_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shape_mnkl), cute::shape<0>(blk_shape))); -+ auto blk_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shape_mnkl), cute::shape<1>(blk_shape))); -+ -+ // Round up to nearest multiple of cluster dim along each mode -+ int problem_blocks_m = round_up(blk_m, cute::size<0>(cluster_shape)); -+ int problem_blocks_n = round_up(blk_n, cute::size<1>(cluster_shape)); -+ -+ // Cluster tile does not span the batch mode, so no extra rounding up required for it -+ int problem_blocks_l = int(cute::size<3>(problem_shape_mnkl)); -+ return {uint32_t(problem_blocks_m), uint32_t(problem_blocks_n), uint32_t(problem_blocks_l)}; -+ } -+}; -+ -+} // namespace cutlass::gemm::kernel::detail -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/sparse_gemm.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/sparse_gemm.h -new file mode 100644 -index 0000000..eba95aa ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/sparse_gemm.h -@@ -0,0 +1,400 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/semaphore.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled. -+> -+struct SparseGemm { -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using OutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ static bool const kSplitKSerial = SplitKSerial; -+ -+ static int const kSparse = Mma::kSparse; -+ static int const kMetaSizeInBits = Mma::kMetaSizeInBits; -+ static int const kMaxID2 = Mma::kMaxID2; -+ static int const kElementsPerElementE = Mma::kElementsPerElementE; -+ -+ using ElementE = typename Mma::ElementE; -+ using LayoutE = typename Mma::LayoutE; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ /// Parameters structure -+ struct Params { -+ cutlass::gemm::GemmCoord problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ int swizzle_log_tile; -+ typename Mma::IteratorA::Params params_A; -+ typename Mma::IteratorA::TensorRef ref_A; -+ typename Mma::IteratorB::Params params_B; -+ typename Mma::IteratorB::TensorRef ref_B; -+ typename Epilogue::OutputTileIterator::Params params_C; -+ typename Epilogue::OutputTileIterator::TensorRef ref_C; -+ typename Epilogue::OutputTileIterator::Params params_D; -+ typename Epilogue::OutputTileIterator::TensorRef ref_D; -+ typename Mma::IteratorE::Params params_E; -+ typename Mma::IteratorE::TensorRef ref_E; -+ typename OutputOp::Params output_op; -+ int *semaphore; -+ int gemm_k_iterations; -+ int gemm_k_size; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): swizzle_log_tile(0), semaphore(0), gemm_k_iterations(0), gemm_k_size(0) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ cutlass::gemm::GemmCoord const & problem_size, -+ cutlass::gemm::GemmCoord const & grid_tiled_shape, -+ typename Mma::IteratorA::TensorRef ref_A, -+ typename Mma::IteratorB::TensorRef ref_B, -+ typename Epilogue::OutputTileIterator::TensorRef ref_C, -+ typename Epilogue::OutputTileIterator::TensorRef ref_D, -+ typename Mma::IteratorE::TensorRef ref_E, -+ typename OutputOp::Params output_op = typename OutputOp::Params(), -+ int *workspace = nullptr -+ ): -+ problem_size(problem_size), -+ grid_tiled_shape(grid_tiled_shape), -+ swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), -+ params_A(ref_A.layout()), -+ ref_A(ref_A), -+ params_B(ref_B.layout()), -+ ref_B(ref_B), -+ params_C(ref_C.layout()), -+ ref_C(ref_C), -+ params_D(ref_D.layout()), -+ ref_D(ref_D), -+ params_E(ref_E.layout()), -+ ref_E(ref_E), -+ output_op(output_op) { -+ -+ int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); -+ -+ gemm_k_size = gemm_k_iterations * Mma::Shape::kK; -+ -+ semaphore = workspace; -+ } -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ SparseGemm() { } -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement( -+ cutlass::gemm::GemmCoord const & problem_size, -+ typename Mma::IteratorA::TensorRef ref_A, -+ typename Mma::IteratorB::TensorRef ref_B, -+ typename Epilogue::OutputTileIterator::TensorRef ref_C, -+ typename Epilogue::OutputTileIterator::TensorRef ref_D, -+ typename Mma::IteratorE::TensorRef ref_E) { -+ -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ static int const kAlignmentE = Mma::IteratorE::AccessType::kElements; -+ -+ if (!TensorRef_aligned(ref_A, kAlignmentA)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_B, kAlignmentB)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_C, kAlignmentC)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_D, kAlignmentC)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (!TensorRef_aligned(ref_E, kAlignmentE)) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if ((problem_size.m() % kAlignmentA) || ((problem_size.k() / kSparse) % kAlignmentA) || -+ (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) || -+ (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC) || -+ (problem_size.m() % kAlignmentE) || ((problem_size.k() / kSparse) % kAlignmentE)) { -+ -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ // The k dimension has to be the multiple of the Threadblock k because out -+ // of bound meta data would be initialized to 0 by acync.zfill but 0 is not -+ // a valid meta data. -+ if (problem_size.k() % Mma::Shape::kK) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ // M dimension has to be multiple of 32 (sparse float) or 16 (sparse int) -+ // because of the row reordering of operand E -+ static int const kAlignmentM = (sizeof(ElementE) == 2) ? 32 : 16; -+ -+ if (problem_size.m() % kAlignmentM) { -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.k() * params.gemm_k_size / kSparse, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ threadblock_tile_offset.k() * params.gemm_k_size, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ }; -+ -+ cutlass::MatrixCoord tb_offset_E{ -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.k() * params.gemm_k_size / kSparse, -+ }; -+ -+ // Problem size is a function of threadblock index in the K dimension -+ int problem_size_k = min( -+ params.problem_size.k(), -+ (threadblock_tile_offset.k() + 1) * params.gemm_k_size); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size_k - tb_offset_B.row() + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A, B, and E operands -+ typename Mma::IteratorA iterator_A( -+ params.params_A, -+ params.ref_A.data(), -+ {params.problem_size.m(), problem_size_k / kSparse}, -+ thread_idx, -+ tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B( -+ params.params_B, -+ params.ref_B.data(), -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_B); -+ -+ typename Mma::IteratorE iterator_E( -+ params.params_E, params.ref_E.data(), -+ {params.problem_size.m(), -+ problem_size_k / kSparse / kElementsPerElementE}, -+ thread_idx, tb_offset_E); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = canonical_warp_idx(); -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ if (!kSplitKSerial || gemm_k_iterations > 0) { -+ // Compute threadblock-scoped matrix multiply-add -+ mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_E, accumulators); -+ } -+ -+ // -+ // Epilogue -+ // -+ -+ OutputOp output_op(params.output_op); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ // Construct the semaphore. -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. -+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.params_C, -+ params.ref_C.data(), -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ params.ref_D.data(), -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_offset.k()) { -+ iterator_C = iterator_D; -+ } -+ -+ semaphore.wait(threadblock_tile_offset.k()); -+ -+ __threadfence(); -+ } -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue(output_op, iterator_D, accumulators, iterator_C); -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_offset.k() + 1; -+ } -+ -+ __threadfence(); -+ semaphore.release(lock); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/symm_universal.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/symm_universal.h -new file mode 100755 -index 0000000..47e7035 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/symm_universal.h -@@ -0,0 +1,698 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/semaphore.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma1_, ///! Threadblock-scoped triangular matrix multiply-accumulate (A*B or B*A) -+ typename Mma2_, ///! Threadblock-scoped triangular matrix multiply-accumulate (AT*B or B*AT) -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ SideMode SideMode_, ///! Side Mode for the kernel (kLeft or kRight) -+ FillMode FillMode_ ///! Fill Mode for triangular matrix (kLower or kUpper) -+> -+struct SymmUniversal { -+public: -+ -+ using Mma1 = Mma1_; -+ using Mma2 = Mma2_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ -+ using ElementA = typename Mma1::IteratorA::Element; -+ using ElementB = typename Mma1::IteratorB::Element; -+ -+ // Mma1 (TRMM - with diagonal: C_tmp = alpha * A * B) -+ using LayoutA = typename Mma1::IteratorA::Layout; -+ using LayoutBT = typename Mma1::IteratorB::Layout; -+ static ComplexTransform const kMma1TransformA = Mma1::kTransformA; -+ static ComplexTransform const kMma1TransformB = Mma1::kTransformB; -+ -+ // Mma2 (TRMM - withOUT diagonal: alpha * AT * B) -+ using LayoutB = typename Mma2::IteratorA::Layout; -+ using LayoutAT = typename Mma2::IteratorB::Layout; -+ static ComplexTransform const kMma2TransformA = Mma2::kTransformA; -+ static ComplexTransform const kMma2TransformB = Mma2::kTransformB; -+ -+ // Common type definitions for Mma1 and Mma2 -+ using Operator = typename Mma1::Operator; -+ using OperatorClass = typename Mma1::Operator::OperatorClass; -+ using ThreadblockShape = typename Mma1::Shape; -+ using WarpShape = typename Mma1::Operator::Shape; -+ using InstructionShape = typename Mma1::Policy::Operator::InstructionShape; -+ using ArchTag = typename Mma1::ArchTag; -+ -+ static int const kStages = Mma1::kStages; -+ static int const kAlignmentA = Mma1::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma1::IteratorB::AccessType::kElements; -+ -+ // Output related typedefinitions -+ using ElementC = typename Epilogue::OutputTileIterator::Element; -+ using LayoutC = typename Epilogue::OutputTileIterator::Layout; -+ static SideMode const kSideModeA = SideMode_; -+ static FillMode const kFillModeA = FillMode_; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma1::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ -+ // -+ // Structures -+ // -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmUniversalMode mode; -+ GemmCoord problem_size; -+ int batch_count; -+ -+ typename EpilogueOutputOp::Params epilogue; -+ -+ void const * ptr_A; -+ void const * ptr_B; -+ void const * ptr_C; -+ void * ptr_D; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C; -+ int64_t batch_stride_D; -+ -+ typename LayoutA::Stride::Index lda; -+ typename LayoutB::Stride::Index ldb; -+ typename LayoutC::Stride::Index ldc; -+ typename LayoutC::Stride::Index ldd; -+ -+ // -+ // Methods -+ // -+ -+ Arguments(): -+ mode(GemmUniversalMode::kGemm), -+ batch_count(1), -+ ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr) { } -+ -+ /// constructs an arguments structure -+ Arguments( -+ GemmUniversalMode mode, -+ GemmCoord problem_size, -+ int batch_count, -+ typename EpilogueOutputOp::Params epilogue, -+ void const * ptr_A, -+ void const * ptr_B, -+ void const * ptr_C, -+ void * ptr_D, -+ int64_t batch_stride_A, -+ int64_t batch_stride_B, -+ int64_t batch_stride_C, -+ int64_t batch_stride_D, -+ typename LayoutA::Stride::Index lda, -+ typename LayoutB::Stride::Index ldb, -+ typename LayoutC::Stride::Index ldc, -+ typename LayoutC::Stride::Index ldd -+ ): -+ mode(mode), -+ problem_size(problem_size), -+ batch_count(batch_count), -+ epilogue(epilogue), -+ ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), -+ batch_stride_A(batch_stride_A), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D), -+ lda(lda), ldb(ldb), ldc(ldc), ldd(ldd) { -+ -+ } -+ -+ /// Returns arguments for the transposed problem sizes -+ Arguments transposed_problem_size() const { -+ Arguments args(*this); -+ -+ std::swap(args.problem_size.m(), args.problem_size.n()); -+ -+ return args; -+ } -+ -+ /// Returns arguments for the transposed matrices -+ Arguments swapped_matrices() const { -+ Arguments args(*this); -+ -+ std::swap(args.ptr_A, args.ptr_B); -+ std::swap(args.lda, args.ldb); -+ std::swap(args.batch_stride_A, args.batch_stride_B); -+ -+ return args; -+ } -+ }; -+ -+ // -+ // Structure for precomputing values in host memory and passing to kernels -+ // -+ -+ /// Parameters structure -+ struct Params { -+ -+ cutlass::gemm::GemmCoord problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ int swizzle_log_tile; -+ -+ // Mma1 Iterator A and B params -+ typename Mma1::IteratorA::Params params_A_mma1; -+ typename Mma1::IteratorB::Params params_B_mma1; -+ -+ // Mma2 Iterator A and B params -+ typename Mma2::IteratorA::Params params_A_mma2; -+ typename Mma2::IteratorB::Params params_B_mma2; -+ -+ typename Epilogue::OutputTileIterator::Params params_C; -+ typename Epilogue::OutputTileIterator::Params params_D; -+ -+ typename EpilogueOutputOp::Params output_op; -+ -+ GemmUniversalMode mode; -+ int batch_count; -+ int gemm_k_size; -+ -+ void * ptr_A; -+ void * ptr_B; -+ void * ptr_C; -+ void * ptr_D; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C; -+ int64_t batch_stride_D; -+ -+ int *semaphore; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ swizzle_log_tile(0), -+ params_A_mma1(0), -+ params_B_mma1(0), -+ params_A_mma2(0), -+ params_B_mma2(0), -+ params_C(0), -+ params_D(0), -+ batch_count(0), -+ gemm_k_size(0), -+ mode(cutlass::gemm::GemmUniversalMode::kGemm), -+ ptr_A(nullptr), -+ ptr_B(nullptr), -+ ptr_C(nullptr), -+ ptr_D(nullptr), -+ batch_stride_A(0), -+ batch_stride_B(0), -+ batch_stride_C(0), -+ batch_stride_D(0), -+ semaphore(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ Arguments const &args, -+ cutlass::gemm::GemmCoord const & grid_tiled_shape, -+ int gemm_k_size, -+ void *workspace = nullptr -+ ): -+ problem_size(args.problem_size), -+ grid_tiled_shape(grid_tiled_shape), -+ swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), -+ params_A_mma1(args.lda), -+ params_B_mma1(args.ldb), -+ params_A_mma2(args.lda), -+ params_B_mma2(args.ldb), -+ params_C(args.ldc), -+ params_D(args.ldd), -+ output_op(args.epilogue), -+ mode(args.mode), -+ batch_count(args.batch_count), -+ gemm_k_size(gemm_k_size), -+ ptr_A(const_cast(args.ptr_A)), -+ ptr_B(const_cast(args.ptr_B)), -+ ptr_C(const_cast(args.ptr_C)), -+ ptr_D(const_cast(args.ptr_D)), -+ batch_stride_A(args.batch_stride_A), -+ batch_stride_B(args.batch_stride_B), -+ batch_stride_C(args.batch_stride_C), -+ batch_stride_D(args.batch_stride_D), -+ semaphore(static_cast(workspace)) { -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void update( -+ Arguments const &args, -+ void *workspace = nullptr) { -+ -+ ptr_A = const_cast(args.ptr_A); -+ ptr_B = const_cast(args.ptr_B); -+ ptr_C = const_cast(args.ptr_C); -+ ptr_D = args.ptr_D; -+ -+ output_op = args.epilogue; -+ -+ semaphore = static_cast(workspace); -+ } -+ -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma1::SharedStorage mma1_main_loop; -+ typename Mma2::SharedStorage mma2_main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_DEVICE -+ SymmUniversal() { } -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement( -+ cutlass::gemm::GemmCoord const & problem_size) { -+ -+ static int const kAlignmentA = Mma1::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma1::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) || -+ (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) || -+ (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) { -+ -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ static Status can_implement(Arguments const &args) { -+ return can_implement(args.problem_size); -+ } -+ -+ /// Executes two GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ return; -+ } -+ -+ int offset_k = 0; -+ int problem_size_k = params.problem_size.k(); -+ -+ ElementA *ptr_A = static_cast(params.ptr_A); -+ ElementB *ptr_B = static_cast(params.ptr_B); -+ -+ // -+ // Fetch pointers based on mode. -+ // -+ if (params.mode == GemmUniversalMode::kGemm || -+ params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ -+ if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { -+ -+ problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; -+ } -+ -+ offset_k = threadblock_tile_offset.k() * params.gemm_k_size; -+ } -+ -+ __syncthreads(); -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_MxK_mma1{ -+ threadblock_tile_offset.m() * Mma1::Shape::kM, -+ offset_k, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_KxN_mma1{ -+ offset_k, -+ threadblock_tile_offset.n() * Mma1::Shape::kN -+ }; -+ -+ cutlass::MatrixCoord tb_offset_MxK_mma2{ -+ threadblock_tile_offset.m() * Mma1::Shape::kM, -+ offset_k, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_KxN_mma2{ -+ offset_k, -+ threadblock_tile_offset.n() * Mma1::Shape::kN -+ }; -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = canonical_warp_idx(); -+ -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply for Mma1 -+ Mma1 mma1(shared_storage.mma1_main_loop, thread_idx, warp_idx, lane_idx); -+ -+ // Construct thread-scoped matrix multiply for Mma2 -+ Mma2 mma2(shared_storage.mma2_main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma1::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size_k - offset_k + Mma1::Shape::kK - 1) / Mma1::Shape::kK; -+ int gemm_k_iterations_mma1 = gemm_k_iterations; -+ int gemm_k_iterations_mma2 = gemm_k_iterations; -+ -+ -+ /****************************************************************************************************** -+ * SYMM (Side Mode, Fill Mode) is made of two TRMMs: -+ First TRMM (Mma1: Side Mode, Fill Mode, Non-Unit Diag): (A * B) or (B * A) -+ Second TRMM (Mma2: Side Mode, Inverted Fill Mode, Unit Diag): (AT * B) or (B * AT) -+ -+ * For the first TRMM (Mma1) of SYMM, the following method is used to calculate the k-iterations: -+ First two cases: (Left Side, Lower Fill) and (Right Side, Upper Fill) are transpose of each other -+ - (Left Side, Lower Fill): calculate bottom of the CTA tile, then find the k-iterations -+ needed to process all elements till that coordinate. -+ - (Right Side, Upper Fill): calculate right end of the CTA tile, then find the k-iterations -+ needed to process all elements till that coordinate. -+ -+ Last two cases: (Left Side, Upper Fill) and (Right Side, Lower Fill) are transpose of each other -+ - (Left Side, Upper Fill): calculate the top of the CTA tile, then find k-iterations -+ that can be skipped for all elements of this tile. -+ - (Right Side, Lower Fill): calculate the left start of the CTA tile, then find k-iterations -+ that can be skipped for all elements of this tile. -+ -+ * For the second TRMM (Mma2) of SYMM, the k-iterations and threadblock offsets are calculated -+ the same way as the first TRMM (Mma1) of same side mode but with inverted fill mode. -+ For example, if the first TRMM is left sided with lower fill, the second TRMM would be -+ left sided with upper fill. -+ ********************************************************************************************************/ -+ -+ if (kSideModeA == SideMode::kLeft && kFillModeA == FillMode::kLower) { -+ -+ int k_iterations_till_diagonal_mma1 = ((threadblock_tile_offset.m() + 1) * Mma1::Shape::kM + Mma1::Shape::kK - 1) / Mma1::Shape::kK; -+ if (k_iterations_till_diagonal_mma1 < gemm_k_iterations) { -+ gemm_k_iterations_mma1 = k_iterations_till_diagonal_mma1; -+ } -+ -+ int k_iterations_till_diagonal_mma2 = ((threadblock_tile_offset.m()) * Mma1::Shape::kM) / Mma1::Shape::kK; -+ if (k_iterations_till_diagonal_mma2 != 0) { -+ tb_offset_MxK_mma2 += cutlass::MatrixCoord({0, k_iterations_till_diagonal_mma2 * Mma1::Shape::kK}); -+ tb_offset_KxN_mma2 += cutlass::MatrixCoord({k_iterations_till_diagonal_mma2 * Mma1::Shape::kK, 0}); -+ gemm_k_iterations_mma2 -= k_iterations_till_diagonal_mma2; -+ } -+ -+ } else if (kSideModeA == SideMode::kRight && kFillModeA == FillMode::kUpper) { -+ -+ int k_iterations_till_diagonal_mma1 = ((threadblock_tile_offset.n() + 1) * Mma1::Shape::kN + Mma1::Shape::kK - 1) / Mma1::Shape::kK; -+ if (k_iterations_till_diagonal_mma1 < gemm_k_iterations) { -+ gemm_k_iterations_mma1 = k_iterations_till_diagonal_mma1; -+ } -+ -+ int k_iterations_till_diagonal_mma2 = ((threadblock_tile_offset.n()) * Mma1::Shape::kN) / Mma1::Shape::kK; -+ if (k_iterations_till_diagonal_mma2 != 0) { -+ tb_offset_MxK_mma2 += cutlass::MatrixCoord({0, k_iterations_till_diagonal_mma2 * Mma1::Shape::kK}); -+ tb_offset_KxN_mma2 += cutlass::MatrixCoord({k_iterations_till_diagonal_mma2 * Mma1::Shape::kK, 0}); -+ gemm_k_iterations_mma2 -= k_iterations_till_diagonal_mma2; -+ } -+ -+ } else if (kSideModeA == SideMode::kLeft && kFillModeA == FillMode::kUpper) { -+ -+ int k_iterations_till_diagonal_mma1 = ((threadblock_tile_offset.m()) * Mma1::Shape::kM) / Mma1::Shape::kK; -+ if (k_iterations_till_diagonal_mma1 != 0) { -+ tb_offset_MxK_mma1 += cutlass::MatrixCoord({0, k_iterations_till_diagonal_mma1 * Mma1::Shape::kK}); -+ tb_offset_KxN_mma1 += cutlass::MatrixCoord({k_iterations_till_diagonal_mma1 * Mma1::Shape::kK, 0}); -+ gemm_k_iterations_mma1 -= k_iterations_till_diagonal_mma1; -+ } -+ -+ int k_iterations_till_diagonal_mma2 = ((threadblock_tile_offset.m() + 1) * Mma1::Shape::kM + Mma1::Shape::kK - 1) / Mma1::Shape::kK; -+ if (k_iterations_till_diagonal_mma2 < gemm_k_iterations) { -+ gemm_k_iterations_mma2 = k_iterations_till_diagonal_mma2; -+ } -+ -+ } else if (kSideModeA == SideMode::kRight && kFillModeA == FillMode::kLower) { -+ -+ int k_iterations_till_diagonal_mma1 = ((threadblock_tile_offset.n()) * Mma1::Shape::kN) / Mma1::Shape::kK; -+ -+ if (k_iterations_till_diagonal_mma1 != 0) { -+ tb_offset_MxK_mma1 += cutlass::MatrixCoord({0, k_iterations_till_diagonal_mma1 * Mma1::Shape::kK}); -+ tb_offset_KxN_mma1 += cutlass::MatrixCoord({k_iterations_till_diagonal_mma1 * Mma1::Shape::kK, 0}); -+ gemm_k_iterations_mma1 -= k_iterations_till_diagonal_mma1; -+ } -+ -+ int k_iterations_till_diagonal_mma2 = ((threadblock_tile_offset.n() + 1) * Mma1::Shape::kN + Mma1::Shape::kK - 1) / Mma1::Shape::kK; -+ if (k_iterations_till_diagonal_mma2 < gemm_k_iterations) { -+ gemm_k_iterations_mma2 = k_iterations_till_diagonal_mma2; -+ } -+ -+ } -+ -+ // Construct iterators to A and B operands for Mma1 -+ typename Mma1::IteratorA iterator_A_mma1( -+ params.params_A_mma1, -+ ptr_A, -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_MxK_mma1); -+ -+ typename Mma1::IteratorB iterator_B_mma1( -+ params.params_B_mma1, -+ ptr_B, -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_KxN_mma1); -+ -+ // Construct iterators to A and B operands for Mma2 -+ typename Mma2::IteratorA iterator_A_mma2( -+ params.params_A_mma2, -+ ptr_A, -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_MxK_mma2); -+ -+ typename Mma2::IteratorB iterator_B_mma2( -+ params.params_B_mma2, -+ ptr_B, -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_KxN_mma2); -+ -+ // Compute threadblock-scoped matrix multiply-add (A x B) or (B x A) -+ mma1( -+ gemm_k_iterations_mma1, -+ accumulators, -+ iterator_A_mma1, -+ iterator_B_mma1, -+ accumulators); -+ -+ // Compute threadblock-scoped matrix multiply-add (AT x B) or (B x AT) -+ mma2( -+ gemm_k_iterations_mma2, -+ accumulators, -+ iterator_A_mma2, -+ iterator_B_mma2, -+ accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = -+ threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma1::Shape::kM, -+ threadblock_tile_offset.n() * Mma1::Shape::kN -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ ElementC *ptr_C = static_cast(params.ptr_C); -+ ElementC *ptr_D = static_cast(params.ptr_D); -+ -+ // -+ // Fetch pointers based on mode. -+ // -+ -+ // Construct the semaphore. -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ if (params.mode == GemmUniversalMode::kGemm) { -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. -+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } -+ } -+ else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_C += threadblock_tile_offset.k() * params.batch_stride_C; -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_C = static_cast(params.ptr_C)[threadblock_tile_offset.k()]; -+ ptr_D = static_cast(params.ptr_D)[threadblock_tile_offset.k()]; -+ } -+ -+ // Tile iterator loading from source tensor. -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.params_C, -+ ptr_C, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ ptr_D, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_offset.k()) { -+ iterator_C = iterator_D; -+ } -+ -+ semaphore.wait(threadblock_tile_offset.k()); -+ -+ __threadfence(); -+ } -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue( -+ output_op, -+ iterator_D, -+ accumulators, -+ iterator_C); -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_offset.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/kernel/trmm_universal.h b/3rdparty/cutlass/include/cutlass/gemm/kernel/trmm_universal.h -new file mode 100644 -index 0000000..7ba223b ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/kernel/trmm_universal.h -@@ -0,0 +1,599 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/semaphore.h" -+#include "cutlass/core_io.h" -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function -+ SideMode SideMode_, ///! Side Mode for the kernel (kLeft or kRight) -+ FillMode FillMode_, ///! Fill Mode for triangular matrix (kLower or kUpper) -+ DiagType DiagType_ ///! Diag Type for triangular matrix (kNonUnit or kUnit) -+> -+struct TrmmUniversal { -+public: -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueOutputOp = typename Epilogue::OutputOp; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using ElementC = typename Epilogue::OutputTileIterator::Element; -+ using LayoutC = typename Epilogue::OutputTileIterator::Layout; -+ static SideMode const kSideMode = SideMode_; -+ static FillMode const kFillMode = FillMode_; -+ static DiagType const kDiagType = DiagType_; -+ -+ static ComplexTransform const kTransformA = Mma::kTransformA; -+ static ComplexTransform const kTransformB = Mma::kTransformB; -+ using Operator = typename Mma::Operator; -+ -+ using OperatorClass = typename Mma::Operator::OperatorClass; -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename Mma::Operator::Shape; -+ using InstructionShape = typename Mma::Policy::Operator::InstructionShape; -+ using ArchTag = typename Mma::ArchTag; -+ -+ static int const kStages = Mma::kStages; -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ /// Split-K preserves splits that are 128b aligned -+ static int const kSplitKAlignment = const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value); -+ -+ // -+ // Structures -+ // -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ GemmUniversalMode mode; -+ GemmCoord problem_size; -+ int batch_count; -+ -+ typename EpilogueOutputOp::Params epilogue; -+ -+ void const * ptr_A; -+ void const * ptr_B; -+ void * ptr_D; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_D; -+ -+ typename LayoutA::Stride::Index lda; -+ typename LayoutB::Stride::Index ldb; -+ typename LayoutC::Stride::Index ldd; -+ -+ // -+ // Methods -+ // -+ -+ Arguments(): -+ mode(GemmUniversalMode::kGemm), -+ batch_count(1), -+ ptr_A(nullptr), ptr_B(nullptr), ptr_D(nullptr) { } -+ -+ /// constructs an arguments structure -+ Arguments( -+ GemmUniversalMode mode, -+ GemmCoord problem_size, -+ int batch_count, -+ typename EpilogueOutputOp::Params epilogue, -+ void const * ptr_A, -+ void const * ptr_B, -+ void * ptr_D, -+ int64_t batch_stride_A, -+ int64_t batch_stride_B, -+ int64_t batch_stride_D, -+ typename LayoutA::Stride::Index lda, -+ typename LayoutB::Stride::Index ldb, -+ typename LayoutC::Stride::Index ldd -+ ): -+ mode(mode), -+ problem_size(problem_size), -+ batch_count(batch_count), -+ epilogue(epilogue), -+ ptr_A(ptr_A), ptr_B(ptr_B), ptr_D(ptr_D), -+ batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_D(batch_stride_D), -+ lda(lda), ldb(ldb), ldd(ldd) { -+ } -+ -+ /// Returns arguments for the transposed problem sizes -+ Arguments transposed_problem_size() const { -+ Arguments args(*this); -+ -+ std::swap(args.problem_size.m(), args.problem_size.n()); -+ -+ return args; -+ } -+ -+ /// Returns arguments for the transposed matrices -+ Arguments swapped_matrices() const { -+ Arguments args(*this); -+ -+ std::swap(args.ptr_A, args.ptr_B); -+ std::swap(args.lda, args.ldb); -+ std::swap(args.batch_stride_A, args.batch_stride_B); -+ -+ return args; -+ } -+ }; -+ -+ // -+ // Structure for precomputing values in host memory and passing to kernels -+ // -+ -+ /// Parameters structure -+ struct Params { -+ -+ cutlass::gemm::GemmCoord problem_size; -+ cutlass::gemm::GemmCoord grid_tiled_shape; -+ int swizzle_log_tile; -+ -+ typename Mma::IteratorA::Params params_A; -+ typename Mma::IteratorB::Params params_B; -+ typename Epilogue::OutputTileIterator::Params params_D; -+ -+ typename EpilogueOutputOp::Params output_op; -+ -+ GemmUniversalMode mode; -+ int batch_count; -+ int gemm_k_size; -+ -+ void * ptr_A; -+ void * ptr_B; -+ void * ptr_D; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_D; -+ -+ int *semaphore; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params(): -+ swizzle_log_tile(0), -+ params_A(0), -+ params_B(0), -+ params_D(0), -+ batch_count(0), -+ gemm_k_size(0), -+ mode(cutlass::gemm::GemmUniversalMode::kGemm), -+ ptr_A(nullptr), -+ ptr_B(nullptr), -+ ptr_D(nullptr), -+ batch_stride_A(0), -+ batch_stride_B(0), -+ batch_stride_D(0), -+ semaphore(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ Arguments const &args, -+ cutlass::gemm::GemmCoord const & grid_tiled_shape, -+ int gemm_k_size, -+ void *workspace = nullptr -+ ): -+ problem_size(args.problem_size), -+ grid_tiled_shape(grid_tiled_shape), -+ swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), -+ params_A(args.lda), -+ params_B(args.ldb), -+ params_D(args.ldd), -+ output_op(args.epilogue), -+ mode(args.mode), -+ batch_count(args.batch_count), -+ gemm_k_size(gemm_k_size), -+ ptr_A(const_cast(args.ptr_A)), -+ ptr_B(const_cast(args.ptr_B)), -+ ptr_D(args.ptr_D), -+ batch_stride_A(args.batch_stride_A), -+ batch_stride_B(args.batch_stride_B), -+ batch_stride_D(args.batch_stride_D), -+ semaphore(static_cast(workspace)) { -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void update( -+ Arguments const &args, -+ void *workspace = nullptr) { -+ -+ ptr_A = const_cast(args.ptr_A); -+ ptr_B = const_cast(args.ptr_B); -+ ptr_D = args.ptr_D; -+ -+ batch_stride_A = args.batch_stride_A; -+ batch_stride_B = args.batch_stride_B; -+ batch_stride_D = args.batch_stride_D; -+ -+ output_op = args.epilogue; -+ -+ semaphore = static_cast(workspace); -+ } -+ -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ }; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_DEVICE -+ TrmmUniversal() { } -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement( -+ cutlass::gemm::GemmCoord const & problem_size) { -+ -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) || -+ (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) || -+ (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) { -+ -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ return Status::kSuccess; -+ } -+ -+ static Status can_implement(Arguments const &args) { -+ return can_implement(args.problem_size); -+ } -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ int offset_k = 0; -+ int problem_size_k = params.problem_size.k(); -+ -+ ElementA *ptr_A = static_cast(params.ptr_A); -+ ElementB *ptr_B = static_cast(params.ptr_B); -+ -+ // -+ // Fetch pointers based on mode. -+ // -+ if (params.mode == GemmUniversalMode::kGemm || -+ params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ -+ if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { -+ -+ problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; -+ } -+ -+ offset_k = threadblock_tile_offset.k() * params.gemm_k_size; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; -+ ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; -+ ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; -+ } -+ -+ __syncthreads(); -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ offset_k, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ offset_k, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ }; -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = canonical_warp_idx(); -+ -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ /****************************************************************************************************** -+ First two cases: (Left Side, Lower Fill) and (Right Side, Upper Fill) are transpose of each other -+ - (Left Side, Lower Fill): calculate bottom of the CTA tile, then find the k-iterations -+ needed to process all elements till that coordinate. -+ - (Right Side, Upper Fill): calculate right end of the CTA tile, then find the k-iterations -+ needed to process all elements till that coordinate. -+ -+ Last two cases: (Left Side, Upper Fill) and (Right Side, Lower Fill) are transpose of each other -+ - (Left Side, Upper Fill): calculate the top of the CTA tile, then find k-iterations -+ that can be skipped for all elements of this tile. -+ - (Right Side, Lower Fill): calculate the left start of the CTA tile, then find k-iterations -+ that can be skipped for all elements of this tile. -+ ********************************************************************************************************/ -+ -+ if (kSideMode == SideMode::kLeft && kFillMode == FillMode::kLower) { -+ -+ int k_iterations_till_diagonal = ((threadblock_tile_offset.m() + 1) * Mma::Shape::kM + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ if (k_iterations_till_diagonal < gemm_k_iterations) { -+ gemm_k_iterations = k_iterations_till_diagonal; -+ } -+ -+ } else if (kSideMode == SideMode::kRight && kFillMode == FillMode::kUpper) { -+ -+ int k_iterations_till_diagonal = ((threadblock_tile_offset.n() + 1) * Mma::Shape::kN + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ if (k_iterations_till_diagonal < gemm_k_iterations) { -+ gemm_k_iterations = k_iterations_till_diagonal; -+ } -+ -+ } else if (kSideMode == SideMode::kLeft && kFillMode == FillMode::kUpper) { -+ -+ int k_iterations_till_diagonal = ((threadblock_tile_offset.m()) * Mma::Shape::kM) / Mma::Shape::kK; -+ -+ if (k_iterations_till_diagonal != 0) { -+ tb_offset_A += cutlass::MatrixCoord({0, k_iterations_till_diagonal * Mma::Shape::kK}); -+ tb_offset_B += cutlass::MatrixCoord({k_iterations_till_diagonal * Mma::Shape::kK, 0}); -+ gemm_k_iterations -= k_iterations_till_diagonal; -+ } -+ -+ } else if (kSideMode == SideMode::kRight && kFillMode == FillMode::kLower) { -+ -+ int k_iterations_till_diagonal = ((threadblock_tile_offset.n()) * Mma::Shape::kN) / Mma::Shape::kK; -+ -+ if (k_iterations_till_diagonal != 0) { -+ tb_offset_A += cutlass::MatrixCoord({0, k_iterations_till_diagonal * Mma::Shape::kK}); -+ tb_offset_B += cutlass::MatrixCoord({k_iterations_till_diagonal * Mma::Shape::kK, 0}); -+ gemm_k_iterations -= k_iterations_till_diagonal; -+ } -+ -+ } -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.params_A, -+ ptr_A, -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B( -+ params.params_B, -+ ptr_B, -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_B); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma( -+ gemm_k_iterations, -+ accumulators, -+ iterator_A, -+ iterator_B, -+ accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ EpilogueOutputOp output_op(params.output_op); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ ElementC *ptr_D = static_cast(params.ptr_D); -+ -+ // -+ // Fetch pointers based on mode. -+ // -+ -+ // Construct the semaphore. -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ if (params.mode == GemmUniversalMode::kGemm) { -+ -+ // If performing a reduction via split-K, fetch the initial synchronization -+ if (params.grid_tiled_shape.k() > 1) { -+ -+ // Fetch the synchronization lock initially but do not block. -+ semaphore.fetch(); -+ -+ // Indicate which position in a serial reduction the output operator is currently updating -+ output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); -+ } -+ } -+ else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_D = static_cast(params.ptr_D)[threadblock_tile_offset.k()]; -+ } -+ -+ -+ // Tile iterator loading from source tensor (although irrelevant to this kernel as beta is zero). -+ typename Epilogue::OutputTileIterator iterator_C( -+ params.params_D, -+ ptr_D, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ // Tile iterator writing to destination tensor. -+ typename Epilogue::OutputTileIterator iterator_D( -+ params.params_D, -+ ptr_D, -+ params.problem_size.mn(), -+ thread_idx, -+ threadblock_offset -+ ); -+ -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ if (threadblock_tile_offset.k()) { -+ iterator_C = iterator_D; -+ } -+ -+ semaphore.wait(threadblock_tile_offset.k()); -+ -+ __threadfence(); -+ } -+ -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue( -+ output_op, -+ iterator_D, -+ accumulators, -+ iterator_C); -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_offset.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/thread/mma.h b/3rdparty/cutlass/include/cutlass/gemm/thread/mma.h -new file mode 100644 -index 0000000..d1f9b69 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/thread/mma.h -@@ -0,0 +1,90 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates exposing architecture support for warp-level multiply-add operations -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/arch/mma.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape, -+ /// Data type of A elements -+ typename ElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Data type of B elements -+ typename ElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Element type of C matrix -+ typename ElementC, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Concept: arch::OpMultiplyAdd or arch::Mma<> -+ typename Operator = arch::OpMultiplyAdd, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+struct Mma; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Overloads specialized for existing architectures -+// -+ -+#include "cutlass/gemm/thread/mma_sm50.h" -+#include "cutlass/gemm/thread/mma_sm60.h" -+#include "cutlass/gemm/thread/mma_sm61.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/thread/mma_sm50.h b/3rdparty/cutlass/include/cutlass/gemm/thread/mma_sm50.h -new file mode 100644 -index 0000000..1573e64 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/thread/mma_sm50.h -@@ -0,0 +1,539 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates exposing architecture support for multiply-add operations -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/arch/mma.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/thread/mma.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Gemplate that handles all packed matrix layouts -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Data type of A elements -+ typename ElementA_, -+ /// Layout of A matrix (concept: layout::MapFunc) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename ElementB_, -+ /// Layout of B matrix (concept: layout::MapFunc) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename ElementC_, -+ /// Layout of C matrix (concept: layout::MapFunc) -+ typename LayoutC_, -+ /// Operator used to compute GEMM -+ typename Operator_ -+> -+struct MmaGeneric { -+ -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ /// Data type of operand A -+ using ElementA = ElementA_; -+ -+ /// Layout of A matrix (concept: layout::MapFunc) -+ using LayoutA = LayoutA_; -+ -+ /// Data type of operand B -+ using ElementB = ElementB_; -+ -+ /// Layout of B matrix (concept: layout::MapFunc) -+ using LayoutB = LayoutB_; -+ -+ /// Element type of operand C -+ using ElementC = ElementC_; -+ -+ /// Layout of C matrix (concept: layout::MapFunc) -+ using LayoutC = LayoutC_; -+ -+ /// Underlying mathematical operator -+ using Operator = Operator_; -+ -+ /// A operand storage -+ using FragmentA = Array; -+ -+ /// B operand storage -+ using FragmentB = Array; -+ -+ /// C operand storage -+ using FragmentC = Array; -+ -+ /// Instruction -+ using MmaOp = arch::Mma< -+ gemm::GemmShape<1,1,1>, -+ 1, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ Operator>; -+ -+ static bool const kMultipleOf2 = ((Shape::kM % 2 == 0) && (Shape::kN % 2 == 0)); -+ -+ // -+ // Methods -+ // -+ -+ /// Computes a matrix product D = A * B + C -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC & D, -+ FragmentA const & A, -+ FragmentB const & B, -+ FragmentC const & C) { -+ -+ TensorRef a_ref( -+ reinterpret_cast(&A), LayoutA::packed({Shape::kM, Shape::kK})); -+ -+ TensorRef b_ref( -+ reinterpret_cast(&B), LayoutB::packed({Shape::kK, Shape::kN})); -+ -+ TensorRef d_ref( -+ reinterpret_cast(&D), LayoutC::packed(make_Coord(Shape::kM, Shape::kN))); -+ -+ MmaOp mma_op; -+ -+ // Copy accumulators -+ D = C; -+ -+ // Compute matrix product -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Shape::kK; ++k) { -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 860) -+ if (kMultipleOf2 && -+ platform::is_same::value && -+ platform::is_same::value && -+ platform::is_same::value) { -+ -+ //2x2 zigzag - m and n loops to increment by 2. Inner loop to process 4 multiply-adds in a 2x2 tile. -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Shape::kN; n+=2) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Shape::kM; m+=2) { -+ -+ int m_serpentine = (n % 4) ? (Shape::kM - 2 - m) : m; -+ -+ //top-left element in 2x2 tile -+ { -+ MatrixCoord mn(m_serpentine, n); -+ MatrixCoord mk(m_serpentine, k); -+ MatrixCoord kn(k, n); -+ Array d; -+ Array a; -+ Array b; -+ d[0] = d_ref.at(mn); -+ a[0] = a_ref.at(mk); -+ b[0] = b_ref.at(kn); -+ mma_op(d, a, b, d); -+ d_ref.at(mn) = d[0]; -+ } -+ -+ //bottom-left element in 2x2 tile -+ { -+ MatrixCoord mn(m_serpentine+1, n); -+ MatrixCoord mk(m_serpentine+1, k); -+ MatrixCoord kn(k, n); -+ Array d; -+ Array a; -+ Array b; -+ d[0] = d_ref.at(mn); -+ a[0] = a_ref.at(mk); -+ b[0] = b_ref.at(kn); -+ mma_op(d, a, b, d); -+ d_ref.at(mn) = d[0]; -+ } -+ -+ //bottom-right element in 2x2 tile -+ { -+ MatrixCoord mn(m_serpentine+1, n+1); -+ MatrixCoord mk(m_serpentine+1, k); -+ MatrixCoord kn(k, n+1); -+ Array d; -+ Array a; -+ Array b; -+ d[0] = d_ref.at(mn); -+ a[0] = a_ref.at(mk); -+ b[0] = b_ref.at(kn); -+ mma_op(d, a, b, d); -+ d_ref.at(mn) = d[0]; -+ } -+ -+ //top-right element in 2x2 tile -+ { -+ MatrixCoord mn(m_serpentine, n+1); -+ MatrixCoord mk(m_serpentine, k); -+ MatrixCoord kn(k, n+1); -+ Array d; -+ Array a; -+ Array b; -+ d[0] = d_ref.at(mn); -+ a[0] = a_ref.at(mk); -+ b[0] = b_ref.at(kn); -+ mma_op(d, a, b, d); -+ d_ref.at(mn) = d[0]; -+ } -+ } -+ } -+ } else -+ #endif -+ { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Shape::kN; ++n) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Shape::kM; ++m) { -+ -+ int m_serpentine = (n % 2) ? (Shape::kM - 1 - m) : m; -+ -+ MatrixCoord mn(m_serpentine, n); -+ MatrixCoord mk(m_serpentine, k); -+ MatrixCoord kn(k, n); -+ -+ Array d; -+ Array a; -+ Array b; -+ -+ d[0] = d_ref.at(mn); -+ a[0] = a_ref.at(mk); -+ b[0] = b_ref.at(kn); -+ -+ mma_op(d, a, b, d); -+ -+ d_ref.at(mn) = d[0]; -+ } -+ } -+ } -+ } -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+/// Matrix multiply-add operation - assumes operand B is not changing -+struct MmaComplexF32_Column { -+ -+ using Shape = gemm::GemmShape<1, 1, 1>; -+ using ElementC = complex; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array, 1> &d, -+ Array, 1> const &a, -+ Array, 1> const &b, -+ Array, 1> const &c -+ ) { -+ -+ d[0].real() = a[0].real() * b[0].real() + c[0].real(); -+ d[0].imag() = a[0].real() * b[0].imag() + d[0].imag(); -+ d[0].real() = -a[0].imag() * b[0].imag() + d[0].real(); -+ d[0].imag() = a[0].imag() * b[0].real() + c[0].imag(); -+ } -+}; -+ -+/// Matrix multiply-add operation - assumes operand A is not changing -+struct MmaComplexF32_Corner { -+ -+ using Shape = gemm::GemmShape<1, 1, 1>; -+ using ElementC = complex; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ Array, 1> &d, -+ Array, 1> const &a, -+ Array, 1> const &b, -+ Array, 1> const &c -+ ) { -+ -+ d[0].real() = -a[0].imag() * b[0].imag() + d[0].real(); -+ d[0].imag() = a[0].real() * b[0].imag() + d[0].imag(); -+ d[0].real() = a[0].real() * b[0].real() + c[0].real(); -+ d[0].imag() = a[0].imag() * b[0].real() + c[0].imag(); -+ } -+}; -+ -+} // namespace detail -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Gemplate that handles all packed matrix layouts -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Layout of A matrix (concept: layout::MapFunc) -+ typename LayoutA_, -+ /// Layout of B matrix (concept: layout::MapFunc) -+ typename LayoutB_, -+ /// Layout of C matrix (concept: layout::MapFunc) -+ typename LayoutC_ -+> -+struct MmaGeneric< -+ Shape_, -+ complex, -+ LayoutA_, -+ complex, -+ LayoutB_, -+ complex, -+ LayoutC_, -+ arch::OpMultiplyAdd> { -+ -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ /// Data type of operand A -+ using ElementA = complex; -+ -+ /// Layout of A matrix (concept: layout::MapFunc) -+ using LayoutA = LayoutA_; -+ -+ /// Data type of operand B -+ using ElementB = complex; -+ -+ /// Layout of B matrix (concept: layout::MapFunc) -+ using LayoutB = LayoutB_; -+ -+ /// Element type of operand C -+ using ElementC = complex; -+ -+ /// Layout of C matrix (concept: layout::MapFunc) -+ using LayoutC = LayoutC_; -+ -+ /// Underlying mathematical operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ /// A operand storage -+ using FragmentA = Array; -+ -+ /// B operand storage -+ using FragmentB = Array; -+ -+ /// C operand storage -+ using FragmentC = Array; -+ -+ /// Instruction -+ using MmaOp = arch::Mma< -+ gemm::GemmShape<1,1,1>, -+ 1, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ Operator>; -+ -+ // -+ // Methods -+ // -+ -+ /// Computes a matrix product D = A * B + C -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC & D, -+ FragmentA const & A, -+ FragmentB const & B, -+ FragmentC const & C) { -+ -+ TensorRef a_ref( -+ reinterpret_cast(&A), LayoutA::packed({Shape::kM, Shape::kK})); -+ -+ TensorRef b_ref( -+ reinterpret_cast(&B), LayoutB::packed({Shape::kK, Shape::kN})); -+ -+ TensorRef d_ref( -+ reinterpret_cast(&D), LayoutC::packed(make_Coord(Shape::kM, Shape::kN))); -+ -+ detail::MmaComplexF32_Column mma_column; -+ detail::MmaComplexF32_Corner mma_corner; -+ -+ // Copy accumulators -+ D = C; -+ -+ // Compute matrix product -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Shape::kK; ++k) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Shape::kN; ++n) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Shape::kM; ++m) { -+ -+ int m_serpentine = (n % 2) ? (Shape::kM - 1 - m) : m; -+ -+ MatrixCoord mn(m_serpentine, n); -+ MatrixCoord mk(m_serpentine, k); -+ MatrixCoord kn(k, n); -+ -+ Array d; -+ Array a; -+ Array b; -+ -+ d[0] = d_ref.at(mn); -+ a[0] = a_ref.at(mk); -+ b[0] = b_ref.at(kn); -+ -+ if ((m == 0 && n) || m == Shape::kM - 1) { -+ mma_corner(d, a, b, d); -+ } -+ else { -+ mma_column(d, a, b, d); -+ } -+ -+ d_ref.at(mn) = d[0]; -+ } -+ } -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Gemplate that handles conventional layouts for FFMA and DFMA GEMM -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Data type of A elements -+ typename ElementA_, -+ /// Layout of A matrix (concept: layout::MapFunc) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename ElementB_, -+ /// Layout of B matrix (concept: layout::MapFunc) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename ElementC_, -+ /// Layout of C matrix (concept: layout::MapFunc) -+ typename LayoutC_ -+> -+struct Mma< -+ Shape_, -+ ElementA_, -+ LayoutA_, -+ ElementB_, -+ LayoutB_, -+ ElementC_, -+ LayoutC_, -+ arch::OpMultiplyAdd, -+ bool> { -+ -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ /// Data type of operand A -+ using ElementA = ElementA_; -+ -+ /// Layout of A matrix (concept: layout::MapFunc) -+ using LayoutA = LayoutA_; -+ -+ /// Data type of operand B -+ using ElementB = ElementB_; -+ -+ /// Layout of B matrix (concept: layout::MapFunc) -+ using LayoutB = LayoutB_; -+ -+ /// Element type of operand C -+ using ElementC = ElementC_; -+ -+ /// Layout of C matrix (concept: layout::MapFunc) -+ using LayoutC = LayoutC_; -+ -+ /// Underlying mathematical operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ /// A operand storage -+ using FragmentA = Array; -+ -+ /// B operand storage -+ using FragmentB = Array; -+ -+ /// C operand storage -+ using FragmentC = Array; -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ using ArchMmaOperator = typename MmaGeneric< -+ Shape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ Operator>::MmaOp; -+ // -+ // Methods -+ // -+ -+ /// Computes a matrix product D = A * B + C -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC & D, -+ FragmentA const & A, -+ FragmentB const & B, -+ FragmentC const & C) { -+ -+ MmaGeneric< -+ Shape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ Operator> mma; -+ -+ mma(D, A, B, C); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/thread/mma_sm60.h b/3rdparty/cutlass/include/cutlass/gemm/thread/mma_sm60.h -new file mode 100644 -index 0000000..e4bcb70 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/thread/mma_sm60.h -@@ -0,0 +1,1178 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates exposing architecture support for multiply-add operations -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/thread/mma.h" -+#include "cutlass/functional.h" -+#include "cutlass/reduction/thread/reduce.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+/// Structure to compute the matrix product for HFMA -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape, -+ -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ -+ /// Type of GEMM inner vs outer product -+ bool -+> -+struct Mma_HFMA2; -+ -+ -+///////////////////////////// -+// Specialization for NNN // -+///////////////////////////// -+ -+template -+struct Mma_HFMA2 < -+ Shape_, -+ layout::ColumnMajor, -+ layout::ColumnMajor, -+ layout::ColumnMajor, -+ true -+ > { -+ -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ /// A operand storage -+ using FragmentA = Array; -+ -+ /// B operand storage -+ using FragmentB = Array; -+ -+ /// C operand storage -+ using FragmentC = Array; -+ -+ /// Underlying mathematical operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ static_assert( -+ !(Shape::kM % 2), -+ "Mma_HFMA2 requires the M dimension to be divisible by 2." -+ ); -+ -+ // -+ // Methods -+ // -+ -+ /// Computes a matrix product D = A * B + C -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC & D, -+ FragmentA const & A, -+ FragmentB const & B, -+ FragmentC const & C) { -+ -+ /// Initialize output with input -+ D = C; -+ -+ /// Use 1x1x1 HFMA2 sequence for bulk of computation -+ using Mma = arch::Mma< -+ gemm::GemmShape<2,1,1>, -+ 1, -+ half_t, -+ layout::ColumnMajor, -+ half_t, -+ layout::ColumnMajor, -+ half_t, -+ layout::ColumnMajor, -+ arch::OpMultiplyAdd>; -+ -+ Array *ptr_D = reinterpret_cast *>(&D); -+ Array const *ptr_A = reinterpret_cast const *>(&A); -+ Array const *ptr_B = reinterpret_cast const *>(&B); -+ -+ Mma mma; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){ -+ -+ Array tmp; -+ Array *ptr_tmp = &tmp; -+ ptr_tmp[0] = ptr_D[n*Shape::kM/2 + m]; -+ -+ mma( -+ tmp, -+ ptr_A[k*Shape::kM/2 + m], -+ ptr_B[n*Shape::kK + k], -+ tmp); -+ -+ ptr_D[n*Shape::kM/2 + m] = ptr_tmp[0]; -+ } -+ } -+ } -+ } -+}; -+ -+///////////////////////////// -+// Specialization for NNT // -+///////////////////////////// -+ -+template -+struct Mma_HFMA2< -+ Shape_, -+ layout::ColumnMajor, -+ layout::ColumnMajor, -+ layout::RowMajor, -+ true -+ > { -+ -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ /// A operand storage -+ using FragmentA = Array; -+ -+ /// B operand storage -+ using FragmentB = Array; -+ -+ /// C operand storage -+ using FragmentC = Array; -+ -+ /// Underlying mathematical operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ static_assert( -+ !(Shape::kN % 2), -+ "Mma_HFMA2 requires the N dimension to be divisible by 2." -+ ); -+ -+ // -+ // Methods -+ // -+ -+ /// Computes a matrix product D = A * B + C -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC & D, -+ FragmentA const & A, -+ FragmentB const & B, -+ FragmentC const & C) { -+ -+ /// Initialize output with input -+ D = C; -+ -+ /// Use 1x2x1 HFMA2 sequence for bulk of computation -+ using Mma = arch::Mma< -+ gemm::GemmShape<1,2,1>, -+ 1, -+ half_t, -+ layout::ColumnMajor, -+ half_t, -+ layout::ColumnMajor, -+ half_t, -+ layout::RowMajor, -+ arch::OpMultiplyAdd>; -+ -+ Array *ptr_D = reinterpret_cast *>(&D); -+ Array const *ptr_A = reinterpret_cast const *>(&A); -+ Array const *ptr_B = reinterpret_cast const *>(&B); -+ -+ Mma mma; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){ -+ -+ Array tmp; -+ Array *ptr_tmp = &tmp; -+ ptr_tmp[0] = ptr_D[m*Shape::kN/2 + n]; -+ -+ Array tmp_B; -+ tmp_B[0] = ptr_B->at(2*n*Shape::kK + k); -+ tmp_B[1] = ptr_B->at((2*n+1)*Shape::kK + k); -+ -+ mma( -+ tmp, -+ ptr_A[k*Shape::kM + m], -+ tmp_B, -+ tmp); -+ -+ ptr_D[m*Shape::kN/2 + n] = ptr_tmp[0]; -+ } -+ } -+ } -+ } -+}; -+ -+ -+///////////////////////////// -+// Specialization for NTN // -+///////////////////////////// -+ -+template -+struct Mma_HFMA2 < -+ Shape_, -+ layout::ColumnMajor, -+ layout::RowMajor, -+ layout::ColumnMajor, -+ true -+ > { -+ -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ /// A operand storage -+ using FragmentA = Array; -+ -+ /// B operand storage -+ using FragmentB = Array; -+ -+ /// C operand storage -+ using FragmentC = Array; -+ -+ /// Underlying mathematical operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ static_assert( -+ !(Shape::kM % 2), -+ "Mma_HFMA2 requires the GEMM M dimension to be divisible by 2." -+ ); -+ -+ // -+ // Methods -+ // -+ -+ /// Computes a matrix product D = A * B + C -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC & D, -+ FragmentA const & A, -+ FragmentB const & B, -+ FragmentC const & C) { -+ -+ /// Initialize output with input -+ D = C; -+ -+ using Mma = arch::Mma< -+ gemm::GemmShape<2,1,1>, -+ 1, -+ half_t, -+ layout::ColumnMajor, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::ColumnMajor, -+ arch::OpMultiplyAdd>; -+ -+ Array *ptr_D = reinterpret_cast *>(&D); -+ Array const *ptr_A = reinterpret_cast const *>(&A); -+ Array const *ptr_B = reinterpret_cast const *>(&B); -+ -+ Mma mma; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Shape::kK / Mma::Shape::kK; ++k) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Shape::kM / Mma::Shape::kM; ++m) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Shape::kN / Mma::Shape::kN; ++n) { -+ -+ Array tmp; -+ Array *ptr_tmp = &tmp; -+ -+ ptr_tmp[0] = ptr_D[m + n * Shape::kM/2]; -+ -+ mma( -+ tmp, -+ ptr_A[m + k * Shape::kM/2], -+ ptr_B[k * Shape::kN + n], -+ tmp); -+ -+ ptr_D[m + n * Shape::kM/2] = ptr_tmp[0]; -+ } -+ } -+ } -+ } -+}; -+ -+///////////////////////////// -+// Specialization for NTT // -+///////////////////////////// -+ -+template -+struct Mma_HFMA2< -+ Shape_, -+ layout::ColumnMajor, -+ layout::RowMajor, -+ layout::RowMajor, -+ true -+ > { -+ -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ /// A operand storage -+ using FragmentA = Array; -+ -+ /// B operand storage -+ using FragmentB = Array; -+ -+ /// C operand storage -+ using FragmentC = Array; -+ -+ /// Underlying mathematical operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ static_assert( -+ !(Shape::kN % 2), -+ "Mma_HFMA2 requires the N dimension to be divisible by 2." -+ ); -+ -+ // -+ // Methods -+ // -+ -+ /// Computes a matrix product D = A * B + C -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC & D, -+ FragmentA const & A, -+ FragmentB const & B, -+ FragmentC const & C) { -+ -+ /// Initialize output with input -+ D = C; -+ -+ /// Use 1x2x1 HFMA2 sequence for bulk of computation -+ using Mma = arch::Mma< -+ gemm::GemmShape<1,2,1>, -+ 1, -+ half_t, -+ layout::ColumnMajor, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::RowMajor, -+ arch::OpMultiplyAdd>; -+ -+ Array *ptr_D = reinterpret_cast *>(&D); -+ Array const *ptr_A = reinterpret_cast const *>(&A); -+ Array const *ptr_B = reinterpret_cast const *>(&B); -+ -+ Mma mma; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){ -+ -+ Array tmp; -+ Array *ptr_tmp = &tmp; -+ ptr_tmp[0] = ptr_D[m*Shape::kN/2 + n]; -+ -+ mma( -+ tmp, -+ ptr_A[k*Shape::kM + m], -+ ptr_B[k*Shape::kN/2 + n], -+ tmp); -+ -+ ptr_D[m*Shape::kN/2 + n] = ptr_tmp[0]; -+ } -+ } -+ } -+ } -+}; -+ -+ -+///////////////////////////// -+// Specialization for TNN // -+///////////////////////////// -+ -+template -+struct Mma_HFMA2 < -+ Shape_, -+ layout::RowMajor, -+ layout::ColumnMajor, -+ layout::ColumnMajor, -+ true -+ > { -+ -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ /// A operand storage -+ using FragmentA = Array; -+ -+ /// B operand storage -+ using FragmentB = Array; -+ -+ /// C operand storage -+ using FragmentC = Array; -+ -+ /// Underlying mathematical operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ static_assert( -+ !(Shape::kM % 2), -+ "Mma_HFMA2 requires the M dimension to be divisible by 2." -+ ); -+ -+ // -+ // Methods -+ // -+ -+ /// Computes a matrix product D = A * B + C -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC & D, -+ FragmentA const & A, -+ FragmentB const & B, -+ FragmentC const & C) { -+ -+ /// Initialize output with input -+ D = C; -+ -+ /// Use 1x1x1 HFMA2 sequence for bulk of computation -+ using Mma = arch::Mma< -+ gemm::GemmShape<2,1,1>, -+ 1, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::ColumnMajor, -+ half_t, -+ layout::ColumnMajor, -+ arch::OpMultiplyAdd>; -+ -+ Array *ptr_D = reinterpret_cast *>(&D); -+ Array const *ptr_A = reinterpret_cast const *>(&A); -+ Array const *ptr_B = reinterpret_cast const *>(&B); -+ -+ Mma mma; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){ -+ -+ Array tmp; -+ Array *ptr_tmp = &tmp; -+ ptr_tmp[0] = ptr_D[n*Shape::kM/2 + m]; -+ -+ Array tmp_A; -+ tmp_A[0] = ptr_A->at(2*m*Shape::kK + k); -+ tmp_A[1] = ptr_A->at((2*m+1)*Shape::kK + k); -+ -+ mma( -+ tmp, -+ tmp_A, -+ ptr_B[n*Shape::kK + k], -+ tmp); -+ -+ ptr_D[n*Shape::kM/2 + m] = ptr_tmp[0]; -+ } -+ } -+ } -+ } -+}; -+ -+///////////////////////////// -+// Specialization for TNT // -+///////////////////////////// -+ -+template -+struct Mma_HFMA2 < -+ Shape_, -+ layout::RowMajor, -+ layout::ColumnMajor, -+ layout::RowMajor, -+ true -+ > { -+ -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ /// A operand storage -+ using FragmentA = Array; -+ -+ /// B operand storage -+ using FragmentB = Array; -+ -+ /// C operand storage -+ using FragmentC = Array; -+ -+ /// Underlying mathematical operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ static_assert( -+ !(Shape::kN % 2), -+ "Mma_HFMA2 requires the N dimension to be divisible by 2." -+ ); -+ -+ // -+ // Methods -+ // -+ -+ /// Computes a matrix product D = A * B + C -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC & D, -+ FragmentA const & A, -+ FragmentB const & B, -+ FragmentC const & C) { -+ -+ /// Initialize output with input -+ D = C; -+ -+ /// Use 1x2x1 HFMA2 sequence for bulk of computation -+ using Mma = arch::Mma< -+ gemm::GemmShape<1,2,1>, -+ 1, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::ColumnMajor, -+ half_t, -+ layout::RowMajor, -+ arch::OpMultiplyAdd>; -+ -+ Array *ptr_D = reinterpret_cast *>(&D); -+ Array const *ptr_A = reinterpret_cast const *>(&A); -+ Array const *ptr_B = reinterpret_cast const *>(&B); -+ -+ Mma mma; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){ -+ -+ Array tmp; -+ Array *ptr_tmp = &tmp; -+ ptr_tmp[0] = ptr_D[m*Shape::kN/2 + n]; -+ -+ Array tmp_B; -+ tmp_B[0] = ptr_B->at(2*n*Shape::kK + k); -+ tmp_B[1] = ptr_B->at((2*n+1)*Shape::kK + k); -+ -+ mma( -+ tmp, -+ ptr_A[m*Shape::kK + k], -+ tmp_B, -+ tmp); -+ -+ ptr_D[m*Shape::kN/2 + n] = ptr_tmp[0]; -+ } -+ } -+ } -+ } -+}; -+ -+///////////////////////////// -+// Specialization for TTN // -+///////////////////////////// -+ -+template -+struct Mma_HFMA2 < -+ Shape_, -+ layout::RowMajor, -+ layout::RowMajor, -+ layout::ColumnMajor, -+ true -+ > { -+ -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ /// A operand storage -+ using FragmentA = Array; -+ -+ /// B operand storage -+ using FragmentB = Array; -+ -+ /// C operand storage -+ using FragmentC = Array; -+ -+ /// Underlying mathematical operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ static_assert( -+ !(Shape::kM % 2), -+ "Mma_HFMA2 requires the M dimension to be divisible by 2." -+ ); -+ -+ // -+ // Methods -+ // -+ -+ /// Computes a matrix product D = A * B + C -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC & D, -+ FragmentA const & A, -+ FragmentB const & B, -+ FragmentC const & C) { -+ -+ /// Initialize output with input -+ D = C; -+ -+ /// Use 1x2x1 HFMA2 sequence for bulk of computation -+ using Mma = arch::Mma< -+ gemm::GemmShape<2,1,1>, -+ 1, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::ColumnMajor, -+ arch::OpMultiplyAdd>; -+ -+ Array *ptr_D = reinterpret_cast *>(&D); -+ Array const *ptr_A = reinterpret_cast const *>(&A); -+ Array const *ptr_B = reinterpret_cast const *>(&B); -+ -+ Mma mma; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){ -+ -+ Array tmp; -+ Array *ptr_tmp = &tmp; -+ ptr_tmp[0] = ptr_D[n*Shape::kM/2 + m]; -+ -+ Array tmp_A; -+ tmp_A[0] = ptr_A->at(2*m*Shape::kK + k); -+ tmp_A[1] = ptr_A->at((2*m+1)*Shape::kK + k); -+ -+ mma( -+ tmp, -+ tmp_A, -+ ptr_B[k*Shape::kN + n], -+ tmp); -+ -+ ptr_D[n*Shape::kM/2 + m] = ptr_tmp[0]; -+ } -+ } -+ } -+ } -+}; -+ -+ -+///////////////////////////// -+// Specialization for TTT // -+///////////////////////////// -+ -+template -+struct Mma_HFMA2< -+ Shape_, -+ layout::RowMajor, -+ layout::RowMajor, -+ layout::RowMajor, -+ true -+ > { -+ -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ /// A operand storage -+ using FragmentA = Array; -+ -+ /// B operand storage -+ using FragmentB = Array; -+ -+ /// C operand storage -+ using FragmentC = Array; -+ -+ /// Underlying mathematical operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ static_assert( -+ !(Shape::kN % 2), -+ "Mma_HFMA2 requires the N dimension to be divisible by 2." -+ ); -+ -+ // -+ // Methods -+ // -+ -+ /// Computes a matrix product D = A * B + C -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC & D, -+ FragmentA const & A, -+ FragmentB const & B, -+ FragmentC const & C) { -+ -+ /// Initialize output with input -+ D = C; -+ -+ /// Use 1x2x1 HFMA2 sequence for bulk of computation -+ using Mma = arch::Mma< -+ gemm::GemmShape<1,2,1>, -+ 1, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::RowMajor, -+ half_t, -+ layout::RowMajor, -+ arch::OpMultiplyAdd>; -+ -+ Array *ptr_D = reinterpret_cast *>(&D); -+ Array const *ptr_A = reinterpret_cast const *>(&A); -+ Array const *ptr_B = reinterpret_cast const *>(&B); -+ -+ Mma mma; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){ -+ -+ Array tmp; -+ Array *ptr_tmp = &tmp; -+ ptr_tmp[0] = ptr_D[m*Shape::kN/2 + n]; -+ -+ mma( -+ tmp, -+ ptr_A[m*Shape::kK + k], -+ ptr_B[k*Shape::kN/2 + n], -+ tmp); -+ -+ ptr_D[m*Shape::kN/2 + n] = ptr_tmp[0]; -+ } -+ } -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////// -+// Specialization for TNT + Inner Product or 1x1x2K + LayoutC = T // -+///////////////////////////////////////////////////////////////////// -+ -+template -+struct Mma_HFMA2< -+ Shape_, -+ LayoutA, -+ LayoutB, -+ layout::RowMajor, -+ false -+ > { -+ -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ /// A operand storage -+ using FragmentA = Array; -+ -+ /// B operand storage -+ using FragmentB = Array; -+ -+ /// C operand storage -+ using FragmentC = Array; -+ -+ /// Underlying mathematical operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ static_assert( -+ !(Shape::kK % 2), -+ "Mma_HFMA2 requires the K dimension to be divisible by 2." -+ ); -+ -+ // -+ // Methods -+ // -+ -+ /// Computes a matrix product D = A * B + C -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC & D, -+ FragmentA const & A, -+ FragmentB const & B, -+ FragmentC const & C) { -+ -+ /// Initialize output with input -+ D = C; -+ -+ /// Use 1x1x2 HFMA2 sequence for bulk of computation -+ using GemmShape = gemm::GemmShape<1,1,2>; -+ -+ Array *ptr_D = reinterpret_cast *>(&D); -+ Array const *ptr_A = reinterpret_cast const *>(&A); -+ Array const *ptr_B = reinterpret_cast const *>(&B); -+ -+ // Inner product is calculated using MACs, followed by final reduction -+ multiply_add> mac; -+ cutlass::reduction::thread::Reduce< plus, Array > reduce; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto n=0; n < Shape::kN / GemmShape::kN; n++){ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto m=0; m < Shape::kM / GemmShape::kM; m++){ -+ -+ Array tmp_C; -+ tmp_C.clear(); -+ Array *ptr_tmp_C = reinterpret_cast *>(&tmp_C); -+ ptr_tmp_C[0] = ptr_D[n*Shape::kM + m]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto k=0; k < Shape::kK / GemmShape::kK; k++){ -+ tmp_C = mac(ptr_A[m*Shape::kK/2 + k], ptr_B[n*Shape::kK/2 + k], tmp_C); -+ } -+ -+ Array res; -+ Array *ptr_res = &res; -+ res = reduce(tmp_C); -+ -+ ptr_D[m*Shape::kN + n] = ptr_res[0]; -+ } -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////// -+// Specialization for TNN + Inner Product or 1x1x2K + LayoutC = N // -+///////////////////////////////////////////////////////////////////// -+ -+template -+struct Mma_HFMA2< -+ Shape_, -+ LayoutA, -+ LayoutB, -+ layout::ColumnMajor, -+ false -+ > { -+ -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ /// A operand storage -+ using FragmentA = Array; -+ -+ /// B operand storage -+ using FragmentB = Array; -+ -+ /// C operand storage -+ using FragmentC = Array; -+ -+ /// Underlying mathematical operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ static_assert( -+ !(Shape::kK % 2), -+ "Mma_HFMA2 requires the K dimension to be divisible by 2." -+ ); -+ -+ // -+ // Methods -+ // -+ -+ /// Computes a matrix product D = A * B + C -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC & D, -+ FragmentA const & A, -+ FragmentB const & B, -+ FragmentC const & C) { -+ -+ /// Initialize output with input -+ D = C; -+ -+ /// Use 1x1x2 HFMA2 sequence for bulk of computation -+ using GemmShape= gemm::GemmShape<1,1,2>; -+ -+ Array *ptr_D = reinterpret_cast *>(&D); -+ Array const *ptr_A = reinterpret_cast const *>(&A); -+ Array const *ptr_B = reinterpret_cast const *>(&B); -+ -+ // Inner product is calculated using MACs, followed by final reduction -+ multiply_add> mac; -+ cutlass::reduction::thread::Reduce< plus, Array > reduce; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto n=0; n < Shape::kN / GemmShape::kN; n++){ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto m=0; m < Shape::kM / GemmShape::kM; m++){ -+ -+ Array tmp_C; -+ tmp_C.clear(); -+ Array *ptr_tmp_C = reinterpret_cast *>(&tmp_C); -+ ptr_tmp_C[0] = ptr_D[n*Shape::kM + m]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(auto k=0; k < Shape::kK / GemmShape::kK; k++){ -+ -+ tmp_C = mac(ptr_A[m*Shape::kK/2 + k], ptr_B[n*Shape::kK/2 + k], tmp_C); -+ -+ } -+ -+ Array res; -+ Array *ptr_res = &res; -+ res = reduce(tmp_C); -+ -+ ptr_D[n*Shape::kM + m] = ptr_res[0]; -+ } -+ } -+ } -+}; -+ -+} // namespace detail -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, typename LayoutA, typename LayoutB, typename LayoutC -+> -+struct Mma< -+ Shape_, -+ half_t, -+ LayoutA, -+ half_t, -+ LayoutB, -+ half_t, -+ LayoutC, -+ arch::OpMultiplyAdd -+ > { -+ -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ /// Data type of operand A -+ using ElementA = half_t; -+ -+ /// Data type of operand B -+ using ElementB = half_t; -+ -+ /// Element type of operand C -+ using ElementC = half_t; -+ -+ /// Underlying mathematical operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ /// A operand storage -+ using FragmentA = Array; -+ -+ /// B operand storage -+ using FragmentB = Array; -+ -+ /// C operand storage -+ using FragmentC = Array; -+ -+ static bool const a_row_major = platform::is_same< LayoutA, layout::RowMajor>::value; -+ static bool const b_column_major = platform::is_same< LayoutB, layout::ColumnMajor>::value; -+ static bool const c_row_major = platform::is_same< LayoutC, layout::RowMajor>::value; -+ static bool const c_column_major = platform::is_same< LayoutC, layout::ColumnMajor>::value; -+ -+ static bool const m_mod2 = !(Shape::kM % 2); -+ static bool const n_mod2 = !(Shape::kN % 2); -+ static bool const k_mod2 = !(Shape::kK % 2); -+ -+ // HFMA based MMA optimizations are of 2 types : -+ // 1. Inner product -+ // 2. Outer product -+ // It is chosen based on LayoutC (for outer product gemm) or -+ // Using LayoutA and LayoutB or shape=1x1x2K (for inner product gemms) -+ // If all fails, we choose the generic MMA -+ static bool const use_outer_prod = (c_column_major && m_mod2) || (c_row_major && n_mod2); -+ static bool const use_inner_prod = (a_row_major && b_column_major && k_mod2) || (Shape::kM==1 && Shape::kN==1 && k_mod2); -+ static bool const use_optimized = (use_outer_prod || use_inner_prod); -+ -+ using ArchMmaOperator = typename platform::conditional< use_optimized, -+ detail::Mma_HFMA2, -+ MmaGeneric -+ >::type; -+ -+ // -+ // Methods -+ // -+ -+ /// Computes a matrix product D = A * B + C -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC & D, -+ FragmentA const & A, -+ FragmentB const & B, -+ FragmentC const & C) { -+ -+ ArchMmaOperator mma; -+ -+ mma(D, A, B, C); -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+ /// Determines whether to enable thread::Gemm<> specializations compatible with SM50 -+ template < -+ typename LayoutA, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB> -+ struct EnableMma_Crow_SM60 { -+ -+ static bool const kIsConventionalLayout = -+ (platform::is_same::value || -+ platform::is_same::value) && -+ (platform::is_same::value || -+ platform::is_same::value); -+ -+ static bool const value = kIsConventionalLayout; -+ }; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes matrix product when C is row-major -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ typename LayoutA_, -+ typename LayoutB_ -+> -+struct Mma< -+ Shape_, -+ half_t, -+ LayoutA_, -+ half_t, -+ LayoutB_, -+ half_t, -+ layout::RowMajor, -+ arch::OpMultiplyAdd, -+ typename platform::enable_if::value>::type>{ -+ -+ using Shape = Shape_; -+ using ElementA = half_t; -+ using LayoutA = LayoutA_; -+ using ElementB = half_t; -+ using LayoutB = LayoutB_; -+ using ElementC = half_t; -+ using LayoutC = layout::RowMajor; -+ using Operator = arch::OpMultiplyAdd; -+ -+ using TransposeMma = Mma< -+ GemmShapeTranspose, -+ half_t, -+ typename layout::LayoutTranspose::type, -+ half_t, -+ typename layout::LayoutTranspose::type, -+ half_t, -+ layout::ColumnMajor, -+ arch::OpMultiplyAdd, -+ bool>; -+ -+ using FragmentA = Array; -+ using FragmentB = Array; -+ using FragmentC = Array; -+ -+ using ArchMmaOperator = typename TransposeMma::ArchMmaOperator; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC & D, -+ FragmentA const & A, -+ FragmentB const & B, -+ FragmentC const & C) { -+ -+ TransposeMma mma; -+ -+ mma(D, B, A, C); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/thread/mma_sm61.h b/3rdparty/cutlass/include/cutlass/gemm/thread/mma_sm61.h -new file mode 100644 -index 0000000..7ef1efb ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/thread/mma_sm61.h -@@ -0,0 +1,284 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates exposing architecture support for multiply-add operations -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/thread/mma.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Gemplate that handles conventional layouts for IDP4A -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_ -+> -+struct Mma< -+ Shape_, -+ int8_t, -+ layout::RowMajor, -+ int8_t, -+ layout::ColumnMajor, -+ int32_t, -+ LayoutC_, -+ arch::OpMultiplyAdd, -+ bool> { -+ -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ /// Data type of operand A -+ using ElementA = int8_t; -+ -+ /// Layout of A matrix (concept: layout::MapFunc) -+ using LayoutA = layout::RowMajor; -+ -+ /// Data type of operand B -+ using ElementB = int8_t; -+ -+ /// Layout of B matrix (concept: layout::MapFunc) -+ using LayoutB = layout::ColumnMajor; -+ -+ /// Element type of operand C -+ using ElementC = int32_t; -+ -+ /// Layout of C matrix (concept: layout::MapFunc) -+ using LayoutC = LayoutC_; -+ -+ /// Underlying mathematical operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ /// A operand storage -+ using FragmentA = Array; -+ -+ /// B operand storage -+ using FragmentB = Array; -+ -+ /// C operand storage -+ using FragmentC = Array; -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ // Use 1x1x4 IDP4A sequence for bulk of computation -+ using ArchMmaOperator = arch::Mma< -+ gemm::GemmShape<1,1,4>, -+ 1, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ arch::OpMultiplyAdd>; -+ -+ // -+ // Methods -+ // -+ -+ /// Computes a matrix product D = A * B + C -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC & D, -+ FragmentA const & A, -+ FragmentB const & B, -+ FragmentC const & C) { -+ -+ TensorRef d( -+ reinterpret_cast(&D), LayoutC::packed({ Shape::kM, Shape::kN })); -+ -+ // Copy accumulators -+ D = C; -+ -+ /// Use 1x1x4 IDP4A sequence for bulk of computation -+ ArchMmaOperator mma; -+ -+ // Compute matrix product -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Shape::kK / ArchMmaOperator::Shape::kK; ++k) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Shape::kN; ++n) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Shape::kM; ++m) { -+ MatrixCoord mn(m, n); -+ -+ Array const *ptr_A = reinterpret_cast const *>(&A); -+ Array const *ptr_B = reinterpret_cast const *>(&B); -+ -+ Array tmp = reinterpret_cast &>(d.at(mn)); -+ -+ mma( -+ tmp, -+ ptr_A[m * Shape::kK / ArchMmaOperator::Shape::kK + k], -+ ptr_B[n * Shape::kK / ArchMmaOperator::Shape::kK + k], -+ tmp); -+ -+ d.at(mn) = reinterpret_cast(tmp); -+ } -+ } -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Gemplate that handles conventional layouts for IDP4A -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_ -+> -+struct Mma< -+ Shape_, -+ int8_t, -+ layout::ColumnMajor, -+ int8_t, -+ layout::RowMajor, -+ int32_t, -+ LayoutC_, -+ arch::OpMultiplyAdd, -+ int8_t> { -+ -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ /// Data type of operand A -+ using ElementA = int8_t; -+ -+ /// Layout of A matrix (concept: layout::MapFunc) -+ using LayoutA = layout::ColumnMajor; -+ -+ /// Data type of operand B -+ using ElementB = int8_t; -+ -+ /// Layout of B matrix (concept: layout::MapFunc) -+ using LayoutB = layout::RowMajor; -+ -+ /// Element type of operand C -+ using ElementC = int32_t; -+ -+ /// Layout of C matrix (concept: layout::MapFunc) -+ using LayoutC = LayoutC_; -+ -+ /// Underlying mathematical operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ /// A operand storage -+ using FragmentA = Array; -+ -+ /// B operand storage -+ using FragmentB = Array; -+ -+ /// C operand storage -+ using FragmentC = Array; -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ /// Use 1x1x4 IDP4A sequence for bulk of computation -+ using ArchMmaOperator = arch::Mma< -+ gemm::GemmShape<1,1,4>, -+ 1, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ arch::OpMultiplyAdd>; -+ -+ // -+ // Methods -+ // -+ -+ /// Computes a matrix product D = A * B + C -+ CUTLASS_HOST_DEVICE -+ void operator()( -+ FragmentC & D, -+ FragmentA const & A, -+ FragmentB const & B, -+ FragmentC const & C) { -+ -+ TensorRef d( -+ reinterpret_cast(&D), LayoutC::packed({ Shape::kM, Shape::kN })); -+ -+ // Copy accumulators -+ D = C; -+ -+ /// Underlying matrix multiply operator -+ ArchMmaOperator mma; -+ -+ Array const *ptr_A = reinterpret_cast const *>(&A); -+ Array const *ptr_B = reinterpret_cast const *>(&B); -+ -+ // Compute matrix product -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Shape::kK / ArchMmaOperator::Shape::kK; ++k) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Shape::kN; ++n) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Shape::kM; ++m) { -+ MatrixCoord mn(m, n); -+ -+ Array tmp = reinterpret_cast &>(d.at(mn)); -+ -+ mma( -+ tmp, -+ ptr_A[m + k * Shape::kM], -+ ptr_B[n + k * Shape::kN], -+ tmp); -+ -+ d.at(mn) = reinterpret_cast(tmp); -+ } -+ } -+ } -+ } -+}; -+ -+} // namespace thread -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_ell_mma.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_ell_mma.h -new file mode 100644 -index 0000000..7e4d765 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_ell_mma.h -@@ -0,0 +1,734 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Default template for a Blocked-Ell MMA. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/arch/wmma.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+#include "cutlass/gemm/threadblock/default_mma_core_wmma.h" -+#endif //CUTLASS_ARCH_WMMA_ENABLED -+ -+#include "cutlass/gemm/threadblock/ell_mma_pipelined.h" -+#include "cutlass/gemm/threadblock/ell_mma_multistage.h" -+#include "cutlass/transform/threadblock/ell_predicated_tile_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation perfomed by GEMM -+ typename Operator, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor = false -+ > -+struct DefaultEllMma; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output (OperatorClass Simt) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultEllMma { -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, -+ arch::OpClassSimt, 2, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA = -+ cutlass::transform::threadblock::EllPredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = -+ cutlass::transform::threadblock::EllPredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::EllMmaPipelined< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, -+ layout::RowMajor, typename MmaCore::MmaPolicy>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output (OperatorClass TensorOp) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Operation performed by GEMM -+ typename Operator -+ > -+struct DefaultEllMma { -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, -+ arch::OpClassTensorOp, 2, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA = -+ cutlass::transform::threadblock::EllPredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = -+ cutlass::transform::threadblock::EllPredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::EllMmaPipelined< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, -+ layout::RowMajor, typename MmaCore::MmaPolicy>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// Specialization for row-major output (OperatorClass TensorOp) -+template < -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Operation performed by GEMM -+ typename Operator -+ > -+struct DefaultEllMma { -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, float, LayoutA, float, -+ LayoutB, float, layout::RowMajor, arch::OpClassTensorOp, 2, -+ arch::OpMultiplyAddFastF16>; -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA = -+ cutlass::transform::threadblock::EllPredicatedTileIterator< -+ cutlass::MatrixShape, -+ float, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = -+ cutlass::transform::threadblock::EllPredicatedTileIterator< -+ cutlass::MatrixShape, -+ float, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::EllMmaPipelined< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ IteratorB, typename MmaCore::SmemIteratorB, float, -+ layout::RowMajor, typename MmaCore::MmaPolicy>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for column-major-interleaved output -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Number of Interleaved K -+ int InterleavedK> -+struct DefaultEllMma, OperatorClass, -+ ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, -+ Operator, true> { -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, -+ layout::ColumnMajorInterleaved, OperatorClass, 2, Operator, -+ true>; -+ -+ static_assert(kAlignmentA == 128 / sizeof_bits::value, -+ "Alignment must match thread data map's vector length"); -+ -+ static_assert(kAlignmentB ==128 / sizeof_bits::value, -+ "Alignment must match thread data map's vector length"); -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA = cutlass::transform::threadblock::EllPredicatedTileIterator< -+ cutlass::MatrixShape, ElementA, -+ LayoutA, 1, typename MmaCore::IteratorThreadMapA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = cutlass::transform::threadblock::EllPredicatedTileIterator< -+ cutlass::MatrixShape, ElementB, -+ LayoutB, 0, typename MmaCore::IteratorThreadMapB>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::EllMmaPipelined< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, -+ layout::ColumnMajorInterleaved, -+ typename MmaCore::MmaPolicy>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Operation perfomed by GEMM -+ typename Operator -+ > -+struct DefaultEllMma { -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, -+ Stages, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ using IteratorA = -+ cutlass::transform::threadblock::EllPredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ using IteratorB = -+ cutlass::transform::threadblock::EllPredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::EllMmaMultistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, -+ typename MmaCore::MmaPolicy, Stages>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output (OperatorClass TensorOp) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Operation perfomed by GEMM -+ typename Operator -+ > -+struct DefaultEllMma { -+ static cutlass::arch::CacheOperation::Kind const CacheOpA = -+ ((sizeof_bits::value * kAlignmentA) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * kAlignmentB) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, Operator, false, CacheOpA, CacheOpB>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ using IteratorA = -+ cutlass::transform::threadblock::EllPredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ using IteratorB = -+ cutlass::transform::threadblock::EllPredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::EllMmaMultistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, -+ typename MmaCore::MmaPolicy, Stages>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for column-major-interleaved output -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Number of Interleaved K -+ int InterleavedK> -+struct DefaultEllMma, OperatorClass, -+ ArchTag, ThreadblockShape, WarpShape, InstructionShape, -+ Stages, Operator, true> { -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, -+ layout::ColumnMajorInterleaved, OperatorClass, Stages, -+ Operator, true>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ using IteratorA = -+ cutlass::transform::threadblock::EllPredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ using IteratorB = -+ cutlass::transform::threadblock::EllPredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::EllMmaMultistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, -+ typename MmaCore::MmaPolicy, Stages>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for SIMT IDP4A Kernels -+template < -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape> -+struct DefaultEllMma, 2, -+ Operator, false> { -+ using InstructionShape = GemmShape<1, 1, 4>; -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using OperatorClass = arch::OpClassSimt; -+ -+ static const bool transposeA = cutlass::platform::is_same< LayoutA, layout::ColumnMajor >::value; -+ static const bool transposeB = cutlass::platform::is_same< LayoutB, layout::RowMajor >::value; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, -+ OperatorClass, 2, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, transposeA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, transposeB>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::EllMmaPipelined< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, -+ layout::RowMajor, typename MmaCore::MmaPolicy>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+/// Specialization for Wmma TensorOp operator with 2 staged pipeline -+template < -+ ///< Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultEllMma { -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, LayoutC, -+ arch::OpClassWmmaTensorOp, 2, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA = -+ cutlass::transform::threadblock::EllPredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = -+ cutlass::transform::threadblock::EllPredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::EllMmaPipelined< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, -+ LayoutC, typename MmaCore::MmaPolicy>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for Wmma TensorOp operator with 1 staged pipeline -+template < -+ ///< Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultEllMma { -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, LayoutC, -+ arch::OpClassWmmaTensorOp, 1, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA = -+ cutlass::transform::threadblock::EllPredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = -+ cutlass::transform::threadblock::EllPredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>; -+ -+ // Define the threadblock-scoped singlestage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaSingleStage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, -+ LayoutC, typename MmaCore::MmaPolicy>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif //CUTLASS_ARCH_WMMA_ENABLED -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_gemv_core.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_gemv_core.h -new file mode 100755 -index 0000000..afb74e7 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_gemv_core.h -@@ -0,0 +1,151 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines basic properties needed by CTA-level batched GEMV assuming expectations about data -+ layout of the global memory fragments, data types, and internal tile sizes. -+ -+ Partial specializations for threadblock::Mma operations targeting SIMT instructions. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/platform/platform.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/thread/mma.h" -+ -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+ -+#include "cutlass/gemm/threadblock/gemv.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+/// Template defininng default vector-matrix multiply operators inferred from threadblock tile size, -+/// global memory data layout. -+template < -+ typename Shape_, /// Shape of the threadblock vector-matrix multiply operator -+ typename ThreadShape_, /// Shape of per-thread vector-matrix multiply operator -+ typename ElementA_, /// Element data type of A operand -+ typename LayoutA_, /// Layout of operand A -+ typename ElementB_, /// Element data type of B operand -+ typename LayoutB_, /// Layout of operand B -+ typename ElementC_, /// Data type of accumulator -+ typename LayoutC_ /// Layout of accumulator -+> -+struct DefaultGemvCore { -+ -+ using Shape = Shape_; -+ using ThreadShape = ThreadShape_; -+ -+ using LayoutA = LayoutA_; -+ using LayoutB = LayoutB_; -+ using LayoutC = LayoutC_; -+ -+ using ElementA = ElementA_; -+ using ElementB = ElementB_; -+ using ElementC = ElementC_; -+ -+ static int const kThreadsPerN = Shape::kN / ThreadShape::kN; -+ -+ using IteratorPolicyA = typename platform::conditional< -+ platform::is_same::value, -+ cutlass::transform::PitchLinearTilePolicyStripminedThreadContiguous< -+ layout::PitchLinearShape, 1, ThreadShape::kK>, -+ cutlass::transform::PitchLinearTilePolicyStripminedThreadStrided< -+ layout::PitchLinearShape, 1, ThreadShape::kM>>::type; -+ -+ using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, ElementA, LayoutA, 1, IteratorPolicyA>; -+ -+ using IteratorPolicyB = typename platform::conditional< -+ platform::is_same::value, -+ cutlass::transform::PitchLinearTilePolicyStripminedThreadContiguous< -+ layout::PitchLinearShape, kThreadsPerN, ThreadShape::kN>, -+ cutlass::transform::PitchLinearTilePolicyStripminedThreadStrided< -+ layout::PitchLinearShape, kThreadsPerN, ThreadShape::kK>>::type; -+ -+ using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, ElementB, LayoutB, 0, IteratorPolicyB>; -+ -+ using IteratorPolicyC = typename platform::conditional< -+ platform::is_same::value, -+ cutlass::transform::PitchLinearTilePolicyStripminedThreadContiguous< -+ layout::PitchLinearShape, kThreadsPerN, ThreadShape::kN>, -+ cutlass::transform::PitchLinearTilePolicyStripminedThreadStrided< -+ layout::PitchLinearShape, kThreadsPerN, ThreadShape::kM>>::type; -+ -+ using IteratorC = cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, ElementC, LayoutC, 0, IteratorPolicyC>; -+ -+ using MmaSimtOp = typename cutlass::gemm::thread::Mma< -+ cutlass::gemm::GemmShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC>; -+ -+ using Operator = MmaSimtOp; -+ -+ // Assertions for correctness -+ static_assert((Shape::kM == 1), "M=1 is required for GEMV"); -+ -+ static_assert((ThreadShape::kM == 1), "M=1 is required for GEMV"); -+ -+ static_assert(Shape::kK % ThreadShape::kK == 0, "Shape::K must be a multiple of ThreadShape::K"); -+ -+ static_assert(((ThreadShape::kK == 1) || -+ (ThreadShape::kK == 2) || -+ (ThreadShape::kK == 4) || -+ (ThreadShape::kK == 8) || -+ (ThreadShape::kK == 16) || -+ (ThreadShape::kK == 32) -+ ), -+ "ThreadShape::K must be a 1, 2, 4, 8, 16 or 32"); -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma.h -new file mode 100644 -index 0000000..7e0b206 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma.h -@@ -0,0 +1,791 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/arch/wmma.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+#include "cutlass/gemm/threadblock/default_mma_core_wmma.h" -+#endif //CUTLASS_ARCH_WMMA_ENABLED -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation perfomed by GEMM -+ typename Operator, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor = false, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, -+ /// Gather operand A by using an index array -+ bool GatherA = false, -+ /// Gather operand B by using an index array -+ bool GatherB = false -+ > -+struct DefaultMma; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output (OperatorClass Simt) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Layout type for C and D matrix operand -+ typename LayoutC, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Gather operand A by using an index array -+ bool GatherA, -+ /// Gather operand B by using an index array -+ bool GatherB -+ > -+struct DefaultMma { -+ -+ static_assert(platform::is_same::value -+ || platform::is_same>::value, -+ "simt epilogue must be row major"); -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, LayoutC, -+ arch::OpClassSimt, 2, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA, GatherA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB, GatherB>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, -+ LayoutC, typename MmaCore::MmaPolicy>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output (OperatorClass TensorOp) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear, -+ /// Gather operand A by using an index array -+ bool GatherA, -+ /// Gather operand B by using an index array -+ bool GatherB -+ > -+struct DefaultMma { -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, -+ arch::OpClassTensorOp, 2, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA, -+ GatherA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB, -+ GatherB>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, -+ layout::RowMajor, typename MmaCore::MmaPolicy>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// Specialization for row-major output (OperatorClass TensorOp) -+template < -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Gather operand A by using an index array -+ bool GatherA, -+ /// Gather operand B by using an index array -+ bool GatherB -+ > -+struct DefaultMma { -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, float, LayoutA, float, -+ LayoutB, float, layout::RowMajor, arch::OpClassTensorOp, 2, -+ arch::OpMultiplyAddFastF16>; -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ float, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA, GatherA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ float, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB, GatherB>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ IteratorB, typename MmaCore::SmemIteratorB, float, -+ layout::RowMajor, typename MmaCore::MmaPolicy>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for column-major-interleaved output -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Number of Interleaved K -+ int InterleavedK> -+struct DefaultMma, OperatorClass, -+ ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, -+ Operator, true, SharedMemoryClearOption::kNone, false, false> { -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, -+ layout::ColumnMajorInterleaved, OperatorClass, 2, Operator, -+ true>; -+ -+ static_assert(kAlignmentA == 128 / sizeof_bits::value, -+ "Alignment must match thread data map's vector length"); -+ -+ static_assert(kAlignmentB ==128 / sizeof_bits::value, -+ "Alignment must match thread data map's vector length"); -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, ElementA, -+ LayoutA, 1, typename MmaCore::IteratorThreadMapA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, ElementB, -+ LayoutB, 0, typename MmaCore::IteratorThreadMapB>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, -+ layout::ColumnMajorInterleaved, -+ typename MmaCore::MmaPolicy>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Layout type for C and D matrix operand -+ typename LayoutC, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Operation perfomed by GEMM -+ typename Operator, -+ /// Gather operand A by using an index array -+ bool GatherA, -+ /// Gather operand B by using an index array -+ bool GatherB -+ > -+struct DefaultMma { -+ -+ static_assert(platform::is_same::value -+ || platform::is_same>::value, -+ "simt epilogue must be row major"); -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, LayoutC, arch::OpClassSimt, -+ Stages, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, AccessTypeA, GatherA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, AccessTypeB, GatherB>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, ElementAccumulator, LayoutC, -+ typename MmaCore::MmaPolicy, Stages>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output (OperatorClass TensorOp) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Layout type for C and D matrix operand -+ typename LayoutC, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Operation perfomed by GEMM -+ typename Operator, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear, -+ /// Gather operand A by using an index array -+ bool GatherA, -+ /// Gather operand B by using an index array -+ bool GatherB -+ > -+struct DefaultMma { -+ -+ static_assert(platform::is_same::value -+ || platform::is_same>::value, -+ "simt epilogue must be row major"); -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpA = -+ ((sizeof_bits::value * kAlignmentA) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * kAlignmentB) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, LayoutC, arch::OpClassTensorOp, -+ Stages, Operator, false, CacheOpA, CacheOpB>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, AccessTypeA, GatherA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, AccessTypeB, GatherB>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, ElementAccumulator, LayoutC, -+ typename MmaCore::MmaPolicy, Stages, SharedMemoryClear>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for column-major-interleaved output -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Number of Interleaved K -+ int InterleavedK> -+struct DefaultMma, OperatorClass, -+ ArchTag, ThreadblockShape, WarpShape, InstructionShape, -+ Stages, Operator, true, SharedMemoryClearOption::kNone, false, false> { -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, -+ layout::ColumnMajorInterleaved, OperatorClass, Stages, -+ Operator, true>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, -+ typename MmaCore::MmaPolicy, Stages>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for SIMT IDP4A Kernels -+template < -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Operation performed by GEMM -+ typename Operator, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape> -+struct DefaultMma, 2, -+ Operator, false, SharedMemoryClearOption::kNone, false, false> { -+ using InstructionShape = GemmShape<1, 1, 4>; -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using OperatorClass = arch::OpClassSimt; -+ -+ static const bool transposeA = platform::is_same< LayoutA, layout::ColumnMajor >::value; -+ static const bool transposeB = platform::is_same< LayoutB, layout::RowMajor >::value; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, -+ OperatorClass, 2, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, transposeA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, transposeB>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, -+ layout::RowMajor, typename MmaCore::MmaPolicy>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+/// Specialization for Wmma TensorOp operator with 2 staged pipeline -+template < -+ ///< Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultMma { -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, LayoutC, -+ arch::OpClassWmmaTensorOp, 2, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, -+ LayoutC, typename MmaCore::MmaPolicy>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for Wmma TensorOp operator with 1 staged pipeline -+template < -+ ///< Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Operation performed by GEMM -+ typename Operator> -+struct DefaultMma { -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, LayoutC, -+ arch::OpClassWmmaTensorOp, 1, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>; -+ -+ // Define the threadblock-scoped singlestage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaSingleStage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, -+ LayoutC, typename MmaCore::MmaPolicy>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif //CUTLASS_ARCH_WMMA_ENABLED -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core.h -new file mode 100644 -index 0000000..3d7ffe9 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core.h -@@ -0,0 +1,116 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines basic properties needed by CTA-level GEMMs assuming expectations about data -+ layout of the global memory fragments, data types, and internal tile sizes. -+ -+ Partial specializations for threadblock::Mma operations targeting TensorOp instructions. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/gemm/warp/mma.h" -+#include "cutlass/gemm/threadblock/mma_pipelined.h" -+#include "cutlass/gemm/threadblock/mma_singlestage.h" -+#include "cutlass/arch/cache_operation.h" -+#include "cutlass/arch/mma.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template defininng default matrix multiply operators inferred from threadblock tile size, -+/// global memory data layout, and target math instruction. -+template < -+ /// Shape of threadblock-scoped matrix multiply operator -+ typename Shape, -+ /// Shape of warp-level matrix multiply operator -+ typename WarpShape, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape, -+ /// Element data type of A operand -+ typename ElementA, -+ /// Layout of operand A -+ typename LayoutA, -+ /// Element data type of B operand -+ typename ElementB, -+ /// Layout of operand B -+ typename LayoutB, -+ /// Data type of accumulator -+ typename ElementC, -+ /// Layout of accumulator -+ typename LayoutC, -+ /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) -+ typename OperatorClass, -+ /// Number of stages -+ int Stages = 2, -+ /// Operation performed by MMA -+ typename Operator = typename platform::conditional< -+ (platform::is_same::value) && -+ (platform::is_same::value || -+ platform::is_same::value || -+ platform::is_same::value || -+ platform::is_same::value), -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::arch::OpMultiplyAdd>::type, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor = false, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA = -+ cutlass::arch::CacheOperation::Global, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB = -+ cutlass::arch::CacheOperation::Global, -+ /// per-element transformation for elements of A -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// per-element transformation for elements of B -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ bool IsComplex = false // (is_complex::value || is_complex::value) -+> -+struct DefaultMmaCore; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_simt.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_simt.h -new file mode 100644 -index 0000000..a6d8ec0 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_simt.h -@@ -0,0 +1,1723 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines basic properties needed by CTA-level GEMMs assuming expectations about data -+ layout of the global memory fragments, data types, and internal tile sizes. -+ -+ Partial specializations for threadblock::Mma operations targeting simt instructions. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/fast_math.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+ -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" -+#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear_2dthreadtile.h" -+ -+#include "cutlass/gemm/warp/mma_simt_policy.h" -+#include "cutlass/gemm/warp/mma_simt.h" -+#include "cutlass/gemm/threadblock/default_mma_core.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+namespace detail { -+ -+// convert a WarpShape which is the whole tile of elements into warp num threads. -+// The goal is for each thread's tile of elements to be as square as possible -+// for performance (4x4 will be faster than 2x8). -+template -+constexpr int simt_get_warp_threads_m() { -+ return (WarpShape::kM > WarpShape::kN) ? 8 : 4; -+} -+ -+/// Computes padding in shared memory to perform efficient transpose without bank conflicts. -+constexpr int simt_transpose_padding(int threads, int crosswise, int size_in_bits) { -+ return (size_in_bits >= 32 ? -+ threads / crosswise / (size_in_bits / 32) : -+ threads / crosswise * (32 / size_in_bits) -+ ); -+} -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major -+/// B: row-major -+/// Operator: simt class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by GEMM -+ typename Operator_> -+struct DefaultMmaCore, ElementA_, -+ layout::ColumnMajor, ElementB_, layout::RowMajor, -+ ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ -+ > { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<1, 1, 1>; -+ using ElementA = ElementA_; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassSimt; -+ static int const PartitionsK = Shape::kK / WarpShape::kK; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ PartitionsK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ static int const kElementsPerAccess = 1; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajor; -+ using SmemLayoutB = layout::RowMajor; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 1, -+ IteratorThreadMapA -+ >; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 0, -+ IteratorThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level op -+ static const int WarpNumThreadsM = detail::simt_get_warp_threads_m(); -+ static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM; -+ static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; -+ static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; -+ static const int numElementsA = 128 / sizeof_bits::value; -+ static const int numElementsB = 128 / sizeof_bits::value; -+ static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); -+ static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 1>; -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::RowMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) -+ >; /// Used for partial specialization -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaWarpSimt, -+ MatrixShape<0, 0>, -+ MatrixShape<0, 0>, -+ WarpCount::kK -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: row-major -+/// B: column-major -+/// Operator: simt class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by GEMM -+ typename Operator_> -+struct DefaultMmaCore, ElementA_, -+ layout::RowMajor, ElementB_, layout::ColumnMajor, -+ ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ -+ > { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<1, 1, 1>; -+ using ElementA = ElementA_; -+ using LayoutA = layout::RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassSimt; -+ static int const PartitionsK = Shape::kK / WarpShape::kK; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ PartitionsK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ static int const kElementsPerAccess = 1; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajor; -+ using SmemLayoutB = layout::RowMajor; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Transpose the ThreadMap of iterator A -+ using SmemThreadMapA = transform::TransposePitchLinearThreadMapSimt; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 1, -+ SmemThreadMapA // was IteratorThreadMapA -+ >; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Transpose the ThreadMap of iterator A -+ using SmemThreadMapB = transform::TransposePitchLinearThreadMapSimt; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 0, -+ SmemThreadMapB // was IteratorThreadMapA -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level op -+ static const int WarpNumThreadsM = detail::simt_get_warp_threads_m(); -+ static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM; -+ static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; -+ static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; -+ static const int numElementsA = 128 / sizeof_bits::value; -+ static const int numElementsB = 128 / sizeof_bits::value; -+ static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); -+ static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); -+ -+ static int const kPaddingM = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); -+ static int const kPaddingN = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); -+ -+ static_assert(!(kPaddingM % LaneM) && !(kPaddingN % LaneN), -+ "Padding must be divisible by Lane"); -+ -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 1>; -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::RowMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) -+ >; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaWarpSimt, -+ MatrixShape, // skew for A matrix to avoid SMEM bank conflicts -+ MatrixShape<0, kPaddingN>, // skew for B matrix to avoid SMEM bank conflicts -+ WarpCount::kK -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: row-major -+/// B: row-major -+/// Operator: simt class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by GEMM -+ typename Operator_> -+struct DefaultMmaCore, ElementA_, -+ layout::RowMajor, ElementB_, layout::RowMajor, ElementC_, -+ LayoutC_, arch::OpClassSimt, 2, Operator_ -+ > { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<1, 1, 1>; -+ using ElementA = ElementA_; -+ using LayoutA = layout::RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassSimt; -+ static int const PartitionsK = Shape::kK / WarpShape::kK; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ PartitionsK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ static int const kElementsPerAccess = 1; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajor; -+ using SmemLayoutB = layout::RowMajor; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Transpose the ThreadMap of iterator A -+ using SmemThreadMapA = transform::TransposePitchLinearThreadMapSimt; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 1, -+ SmemThreadMapA -+ >; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 0, -+ IteratorThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level op -+ static const int WarpNumThreadsM = detail::simt_get_warp_threads_m(); -+ static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM; -+ static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; -+ static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; -+ static const int numElementsA = 128 / sizeof_bits::value; -+ static const int numElementsB = 128 / sizeof_bits::value; -+ static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); -+ static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); -+ -+ static int const kPaddingM = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); -+ -+ static_assert(!(kPaddingM % LaneM), -+ "Padding must be divisible by Lane"); -+ -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 1>; -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::RowMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) -+ >; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaWarpSimt, -+ MatrixShape, // skew for A matrix to avoid SMEM bank conflicts -+ MatrixShape<0, 0>, -+ WarpCount::kK -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major -+/// B: column-major -+/// Operator: simt class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by GEMM -+ typename Operator_> -+struct DefaultMmaCore, ElementA_, -+ layout::ColumnMajor, ElementB_, layout::ColumnMajor, -+ ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ -+ > { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<1, 1, 1>; -+ using ElementA = ElementA_; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassSimt; -+ static int const PartitionsK = Shape::kK / WarpShape::kK; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ PartitionsK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ static int const kElementsPerAccess = 1; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajor; -+ using SmemLayoutB = layout::RowMajor; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 1, -+ IteratorThreadMapA -+ >; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Transpose the ThreadMap of iterator A -+ using SmemThreadMapB = transform::TransposePitchLinearThreadMapSimt; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 0, -+ SmemThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level op -+ static const int WarpNumThreadsM = detail::simt_get_warp_threads_m(); -+ static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM; -+ static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; -+ static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; -+ static const int numElementsA = 128 / sizeof_bits::value; -+ static const int numElementsB = 128 / sizeof_bits::value; -+ static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); -+ static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); -+ -+ static int const kPaddingN = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); -+ -+ static_assert(!(kPaddingN % LaneN), -+ "Padding must be divisible by Lane"); -+ -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 1>; -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::RowMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) -+ >; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaWarpSimt, -+ MatrixShape<0, 0>, -+ MatrixShape<0, kPaddingN>, // skew for B matrix to avoid SMEM bank conflicts -+ WarpCount::kK -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major -+/// B: row-major -+/// Operator: simt class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by GEMM -+ typename Operator_> -+struct DefaultMmaCore, ElementA_, -+ layout::AffineRank2ColumnMajor, ElementB_, layout::AffineRank2RowMajor, -+ ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ -+ > { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<1, 1, 1>; -+ using ElementA = ElementA_; -+ using LayoutA = layout::AffineRank2ColumnMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::AffineRank2RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassSimt; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ using Base = DefaultMmaCore; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = typename Base::SmemLayoutA; -+ using SmemLayoutB = typename Base::SmemLayoutB; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = typename Base::IteratorThreadMapA; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = typename Base::SmemIteratorA; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = typename Base::IteratorThreadMapB; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = typename Base::SmemIteratorB; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = typename Base::MmaPolicy; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: row-major -+/// B: column-major -+/// Operator: simt class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by GEMM -+ typename Operator_> -+struct DefaultMmaCore, ElementA_, -+ layout::AffineRank2RowMajor, ElementB_, layout::AffineRank2ColumnMajor, -+ ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ -+ > { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<1, 1, 1>; -+ using ElementA = ElementA_; -+ using LayoutA = layout::AffineRank2RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::AffineRank2ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassSimt; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ using Base = DefaultMmaCore; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = typename Base::SmemLayoutA; -+ using SmemLayoutB = typename Base::SmemLayoutB; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = typename Base::IteratorThreadMapA; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = typename Base::SmemIteratorA; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = typename Base::IteratorThreadMapB; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = typename Base::SmemIteratorB; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = typename Base::MmaPolicy; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: row-major -+/// B: row-major -+/// Operator: simt class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by GEMM -+ typename Operator_> -+struct DefaultMmaCore, ElementA_, -+ layout::AffineRank2RowMajor, ElementB_, layout::AffineRank2RowMajor, ElementC_, -+ LayoutC_, arch::OpClassSimt, 2, Operator_ -+ > { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<1, 1, 1>; -+ using ElementA = ElementA_; -+ using LayoutA = layout::AffineRank2RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::AffineRank2RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassSimt; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ using Base = DefaultMmaCore; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = typename Base::SmemLayoutA; -+ using SmemLayoutB = typename Base::SmemLayoutB; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = typename Base::IteratorThreadMapA; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = typename Base::SmemIteratorA; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = typename Base::IteratorThreadMapB; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = typename Base::SmemIteratorB; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = typename Base::MmaPolicy; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major -+/// B: column-major -+/// Operator: simt class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by GEMM -+ typename Operator_> -+struct DefaultMmaCore, ElementA_, -+ layout::AffineRank2ColumnMajor, ElementB_, layout::AffineRank2ColumnMajor, -+ ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ -+ > { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<1, 1, 1>; -+ using ElementA = ElementA_; -+ using LayoutA = layout::AffineRank2ColumnMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::AffineRank2ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassSimt; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ using Base = DefaultMmaCore; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = typename Base::SmemLayoutA; -+ using SmemLayoutB = typename Base::SmemLayoutB; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = typename Base::IteratorThreadMapA; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = typename Base::SmemIteratorA; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = typename Base::IteratorThreadMapB; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = typename Base::SmemIteratorB; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = typename Base::MmaPolicy; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major -+/// B: row-major -+/// Operator: simt class, for dp4a -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by GEMM -+ typename Operator_> -+struct DefaultMmaCore, int8_t, -+ layout::ColumnMajor, int8_t, layout::RowMajor, ElementC_, -+ LayoutC_, arch::OpClassSimt, 2, Operator_ -+ > { -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<1, 1, 4>; -+ using ElementA = int8_t; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = int8_t; -+ using LayoutB = layout::RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassSimt; -+ static int const PartitionsK = Shape::kK / WarpShape::kK; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ PartitionsK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajorInterleaved<4>; -+ using SmemLayoutB = layout::RowMajorInterleaved<4>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinear2DThreadTileStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<4, 4> -+ >; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator2dThreadTile< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 1, -+ IteratorThreadMapA -+ >; -+ -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = transform::PitchLinear2DThreadTileStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<4, 4> -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator2dThreadTile< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 0, -+ IteratorThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level op -+ static const int WarpNumThreadsM = detail::simt_get_warp_threads_m(); -+ static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM; -+ static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; -+ static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; -+ static const int numElementsA = 128 / sizeof_bits::value; -+ static const int numElementsB = 128 / sizeof_bits::value; -+ static const int LaneM = cutlass::const_min(4, ThreadTileM); -+ static const int LaneN = cutlass::const_min(4, ThreadTileN); -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 4>; -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::ColumnMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy, /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) -+ PartitionsK /// Number of partitions along K dimension -+ >; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaWarpSimt, -+ MatrixShape<0, 0>, -+ MatrixShape<0, 0>, -+ WarpCount::kK -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Partial specialization: -+// -+/// -+/// A: Row-major -+/// B: Column-major -+/// Operator: simt class, for dp4a -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by GEMM -+ typename Operator_> -+struct DefaultMmaCore, int8_t, -+ layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, -+ LayoutC_, arch::OpClassSimt, 2, Operator_ -+ > { -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<1, 1, 4>; -+ using ElementA = int8_t; -+ using LayoutA = layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassSimt; -+ static int const PartitionsK = Shape::kK / WarpShape::kK; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ PartitionsK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajorInterleaved<4>; -+ using SmemLayoutB = layout::RowMajorInterleaved<4>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinear2DThreadTileStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<4, 4> -+ >; -+ -+ /// Transpose the ThreadMap of iterator A -+ using SmemThreadMapA = transform::TransposePitchLinearThreadMap2DThreadTile; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator2dThreadTile< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 1, -+ SmemThreadMapA -+ >; -+ -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = transform::PitchLinear2DThreadTileStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<4, 4> -+ >; -+ -+ /// Transpose the ThreadMap of iterator A -+ using SmemThreadMapB = transform::TransposePitchLinearThreadMap2DThreadTile; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator2dThreadTile< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 0, -+ SmemThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level op -+ static const int WarpNumThreadsM = detail::simt_get_warp_threads_m(); -+ static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM; -+ static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; -+ static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; -+ static const int numElementsA = 128 / sizeof_bits::value; -+ static const int numElementsB = 128 / sizeof_bits::value; -+ static const int LaneM = cutlass::const_min(4, ThreadTileM); -+ static const int LaneN = cutlass::const_min(4, ThreadTileN); -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 4>; -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::ColumnMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy, /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) -+ PartitionsK /// Number of partitions along K dimension -+ >; -+ -+ static int const kPaddingM = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); -+ static int const kPaddingN = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaWarpSimt, -+ MatrixShape, -+ MatrixShape<0, kPaddingN>, -+ WarpCount::kK -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Partial specialization: -+// -+/// -+/// A: Row-major -+/// B: Row-major -+/// Operator: simt class, for dp4a -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by GEMM -+ typename Operator_> -+struct DefaultMmaCore, int8_t, -+ layout::RowMajor, int8_t, layout::RowMajor, ElementC_, -+ LayoutC_, arch::OpClassSimt, 2, Operator_ -+ > { -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<1, 1, 4>; -+ using ElementA = int8_t; -+ using LayoutA = layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = layout::RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassSimt; -+ static int const PartitionsK = Shape::kK / WarpShape::kK; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ PartitionsK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajorInterleaved<4>; -+ using SmemLayoutB = layout::RowMajorInterleaved<4>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinear2DThreadTileStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<4, 4> -+ >; -+ -+ /// Transpose the ThreadMap of iterator A -+ using SmemThreadMapA = transform::TransposePitchLinearThreadMap2DThreadTile; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator2dThreadTile< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 1, -+ SmemThreadMapA -+ >; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = transform::PitchLinear2DThreadTileStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<4, 4> -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator2dThreadTile< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 0, -+ IteratorThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level op -+ static const int WarpNumThreadsM = detail::simt_get_warp_threads_m(); -+ static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM; -+ static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; -+ static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; -+ static const int numElementsA = 128 / sizeof_bits::value; -+ static const int numElementsB = 128 / sizeof_bits::value; -+ static const int LaneM = cutlass::const_min(4, ThreadTileM); -+ static const int LaneN = cutlass::const_min(4, ThreadTileN); -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 4>; -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::ColumnMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy, /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) -+ PartitionsK /// Number of partitions along K dimension -+ >; -+ -+ static int const kPaddingM = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); -+ static int const kPaddingN = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaWarpSimt, -+ MatrixShape, -+ MatrixShape<0, 0>, -+ WarpCount::kK -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Partial specialization: -+// -+/// -+/// A: Column-major -+/// B: Column-major -+/// Operator: simt class, for dp4a -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by GEMM -+ typename Operator_> -+struct DefaultMmaCore, int8_t, -+ layout::ColumnMajor, int8_t, layout::ColumnMajor, ElementC_, -+ LayoutC_, arch::OpClassSimt, 2, Operator_ -+ > { -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<1, 1, 4>; -+ using ElementA = int8_t; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = int8_t; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassSimt; -+ static int const PartitionsK = Shape::kK / WarpShape::kK; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ PartitionsK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajorInterleaved<4>; -+ using SmemLayoutB = layout::RowMajorInterleaved<4>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinear2DThreadTileStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<4, 4> -+ >; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator2dThreadTile< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 1, -+ IteratorThreadMapA -+ >; -+ -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = transform::PitchLinear2DThreadTileStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<4, 4> -+ >; -+ -+ /// Transpose the ThreadMap of iterator A -+ using SmemThreadMapB = transform::TransposePitchLinearThreadMap2DThreadTile; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator2dThreadTile< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 0, -+ SmemThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level op -+ static const int WarpNumThreadsM = detail::simt_get_warp_threads_m(); -+ static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM; -+ static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; -+ static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; -+ static const int numElementsA = 128 / sizeof_bits::value; -+ static const int numElementsB = 128 / sizeof_bits::value; -+ static const int LaneM = cutlass::const_min(4, ThreadTileM); -+ static const int LaneN = cutlass::const_min(4, ThreadTileN); -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 4>; -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::ColumnMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy, /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) -+ PartitionsK /// Number of partitions along K dimension -+ >; -+ -+ static int const kPaddingM = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); -+ static int const kPaddingN = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaWarpSimt, -+ MatrixShape<0, 0>, -+ MatrixShape<0, kPaddingN>, -+ WarpCount::kK -+ >; -+}; -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sm70.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sm70.h -new file mode 100644 -index 0000000..fc83965 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sm70.h -@@ -0,0 +1,682 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines basic properties needed by CTA-level GEMMs assuming expectations about data -+ layout of the global memory fragments, data types, and internal tile sizes. -+ -+ Partial specializations for threadblock::Mma operations targeting TensorOp instructions. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+ -+#include "cutlass/layout/tensor_op_multiplicand_sm70.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/regular_tile_iterator_tensor_op_sm70.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_sm70.h" -+#include "cutlass/gemm/threadblock/default_mma_core.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major -+/// B: row-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by GEMM -+ typename Operator_> -+struct DefaultMmaCore, ElementA_, -+ layout::ColumnMajor, ElementB_, layout::RowMajor, -+ ElementC_, LayoutC_, arch::OpClassTensorOp, 2, Operator_ -+ > { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<8, 8, 4>; -+ using ElementA = ElementA_; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ Shape::kK / WarpShape::kK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = -+ layout::ColumnMajorVoltaTensorOpMultiplicandCongruous< -+ sizeof_bits::value>; -+ -+ // Shared memory layout -+ using SmemLayoutB = -+ layout::RowMajorVoltaTensorOpMultiplicandBCongruous< -+ sizeof_bits::value>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 1, -+ IteratorThreadMapA -+ >; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 0, -+ IteratorThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using MmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ SmemLayoutA, -+ ElementB, -+ SmemLayoutB, -+ ElementC, -+ LayoutC, -+ Policy -+ >; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaTensorOp, -+ MatrixShape<0, 0>, -+ MatrixShape<0, 0>, -+ WarpCount::kK -+ >; -+}; -+ -+/// Partial specialization: -+/// -+/// A: row-major -+/// B: column-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by GEMM -+ typename Operator_> -+struct DefaultMmaCore, ElementA_, -+ layout::RowMajor, ElementB_, layout::ColumnMajor, -+ ElementC_, LayoutC_, arch::OpClassTensorOp, 2, Operator_ -+ > { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<8, 8, 4>; -+ using ElementA = ElementA_; -+ using LayoutA = layout::RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ Shape::kK / WarpShape::kK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::RowMajorVoltaTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Shape::kK>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::ColumnMajorVoltaTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Shape::kK>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<4, 8>, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 0, -+ IteratorThreadMapA -+ >; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<4, 8>, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 1, -+ IteratorThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using MmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ SmemLayoutA, -+ ElementB, -+ SmemLayoutB, -+ ElementC, -+ LayoutC, -+ Policy -+ >; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaTensorOp, -+ MatrixShape<0, 0>, -+ MatrixShape<0, 0>, -+ WarpCount::kK -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: row-major -+/// B: row-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by GEMM -+ typename Operator_> -+struct DefaultMmaCore, ElementA_, -+ layout::RowMajor, ElementB_, layout::RowMajor, ElementC_, -+ LayoutC_, arch::OpClassTensorOp, 2, Operator_ -+ > { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<8, 8, 4>; -+ using ElementA = ElementA_; -+ using LayoutA = layout::RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ Shape::kK / WarpShape::kK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::RowMajorVoltaTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Shape::kK>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::RowMajorVoltaTensorOpMultiplicandBCongruous< -+ sizeof_bits::value>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<4, 8>, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 0, -+ IteratorThreadMapA -+ >; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 0, -+ IteratorThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using MmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ SmemLayoutA, -+ ElementB, -+ SmemLayoutB, -+ ElementC, -+ LayoutC, -+ Policy -+ >; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaTensorOp, -+ MatrixShape<0, 0>, -+ MatrixShape<0, 0>, -+ WarpCount::kK -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major -+/// B: column-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by GEMM -+ typename Operator_> -+struct DefaultMmaCore, ElementA_, -+ layout::ColumnMajor, ElementB_, layout::ColumnMajor, -+ ElementC_, LayoutC_, arch::OpClassTensorOp, 2, Operator_ -+ > { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<8, 8, 4>; -+ using ElementA = ElementA_; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ Shape::kK / WarpShape::kK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajorVoltaTensorOpMultiplicandCongruous< -+ sizeof_bits::value>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::ColumnMajorVoltaTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Shape::kK>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 1, -+ IteratorThreadMapA -+ >; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<4, 8>, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 1, -+ IteratorThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using MmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ SmemLayoutA, -+ ElementB, -+ SmemLayoutB, -+ ElementC, -+ LayoutC, -+ Policy -+ >; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaTensorOp, -+ MatrixShape<0, 0>, -+ MatrixShape<0, 0>, -+ WarpCount::kK -+ >; -+}; -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sm75.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sm75.h -new file mode 100644 -index 0000000..697c45f ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sm75.h -@@ -0,0 +1,1279 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines basic properties needed by CTA-level GEMMs assuming expectations about data -+ layout of the global memory fragments, data types, and internal tile sizes. -+ -+ Partial specializations for threadblock::Mma operations targeting TensorOp instructions. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/platform/platform.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/regular_tile_iterator_tensor_op.h" -+ -+#include "cutlass/gemm/warp/default_mma_tensor_op.h" -+#include "cutlass/gemm/threadblock/default_mma_core.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major -+/// B: row-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by GEMM -+ typename Operator_> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ Shape::kK / WarpShape::kK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = -+ layout::ColumnMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(ElementA))>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(ElementB))>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 1, -+ IteratorThreadMapA -+ >; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 0, -+ IteratorThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaTensorOp, -+ MatrixShape<0, 0>, -+ MatrixShape<0, 0>, -+ WarpCount::kK -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: row-major -+/// B: column-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by MMA -+ typename Operator_> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ Shape::kK / WarpShape::kK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // Warp thread arrangement -+ static int const kWarpThreadArrangementContiguousA = -+ Shape::kK / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedA = -+ kWarpSize / kWarpThreadArrangementContiguousA; -+ -+ static int const kWarpThreadArrangementContiguousB = -+ Shape::kK / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedB = -+ kWarpSize / kWarpThreadArrangementContiguousB; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Shape::kK>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Shape::kK>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 0, -+ IteratorThreadMapA -+ >; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 1, -+ IteratorThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaTensorOp, -+ MatrixShape<0, 0>, -+ MatrixShape<0, 0>, -+ WarpCount::kK -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: row-major -+/// B: row-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by MMA -+ typename Operator_> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ Shape::kK / WarpShape::kK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // Warp thread arrangement -+ static int const kWarpThreadArrangementContiguousA = -+ Shape::kK / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedA = -+ kWarpSize / kWarpThreadArrangementContiguousA; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Shape::kK>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(ElementB))>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 0, -+ IteratorThreadMapA -+ >; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 0, -+ IteratorThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaTensorOp, -+ MatrixShape<0, 0>, -+ MatrixShape<0, 0>, -+ WarpCount::kK -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major -+/// B: column-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by MMA -+ typename Operator_> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // Warp thread arrangement -+ static int const kWarpThreadArrangementContiguousB = -+ Shape::kK / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedB = -+ kWarpSize / kWarpThreadArrangementContiguousB; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(ElementA))>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Shape::kK>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, ElementA, SmemLayoutA, 1, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, ElementB, SmemLayoutB, 1, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// Below is for arch::OpMultiplyAddFastF16 -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major -+/// B: row-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of accumulator -+ typename LayoutC_> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = float; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = float; -+ using LayoutB = layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ Shape::kK / WarpShape::kK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 256; -+ -+ /// Default Operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(half_t))>; -+ -+ // Shared memory layout -+ using SmemLayoutB = -+ layout::RowMajorTensorOpMultiplicandCongruous::value, -+ int(128 / sizeof(half_t))>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ half_t, -+ SmemLayoutA, -+ 1, -+ IteratorThreadMapA -+ >; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ half_t, -+ SmemLayoutB, -+ 0, -+ IteratorThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, half_t, SmemLayoutA, half_t, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaTensorOp, -+ MatrixShape<0, 0>, -+ MatrixShape<0, 0>, -+ WarpCount::kK -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: row-major -+/// B: column-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of accumulator -+ typename LayoutC_> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = float; -+ using LayoutA = layout::RowMajor; -+ using ElementB = float; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ Shape::kK / WarpShape::kK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 256; -+ -+ /// Default Operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ // Warp thread arrangement -+ static int const kWarpThreadArrangementContiguousA = -+ Shape::kK / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedA = -+ kWarpSize / kWarpThreadArrangementContiguousA; -+ -+ static int const kWarpThreadArrangementContiguousB = -+ Shape::kK / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedB = -+ kWarpSize / kWarpThreadArrangementContiguousB; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = -+ layout::RowMajorTensorOpMultiplicandCrosswise::value, -+ Shape::kK>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Shape::kK>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ half_t, -+ SmemLayoutA, -+ 0, -+ IteratorThreadMapA -+ >; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ half_t, -+ SmemLayoutB, -+ 1, -+ IteratorThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, half_t, SmemLayoutA, half_t, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaTensorOp, -+ MatrixShape<0, 0>, -+ MatrixShape<0, 0>, -+ WarpCount::kK -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: row-major -+/// B: row-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of accumulator -+ typename LayoutC_> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = float; -+ using LayoutA = layout::RowMajor; -+ using ElementB = float; -+ using LayoutB = layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ Shape::kK / WarpShape::kK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 256; -+ -+ /// Default Operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ // Warp thread arrangement -+ static int const kWarpThreadArrangementContiguousA = -+ Shape::kK / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedA = -+ kWarpSize / kWarpThreadArrangementContiguousA; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Shape::kK>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(half_t))>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ half_t, -+ SmemLayoutA, -+ 0, -+ IteratorThreadMapA -+ >; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ half_t, -+ SmemLayoutB, -+ 0, -+ IteratorThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, half_t, SmemLayoutA, half_t, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaTensorOp, -+ MatrixShape<0, 0>, -+ MatrixShape<0, 0>, -+ WarpCount::kK -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major -+/// B: column-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of accumulator -+ typename LayoutC_> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = float; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = float; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 256; -+ -+ /// Default Operator -+ using Operator = arch::OpMultiplyAdd; -+ -+ // Warp thread arrangement -+ static int const kWarpThreadArrangementContiguousB = -+ Shape::kK / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedB = -+ kWarpSize / kWarpThreadArrangementContiguousB; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(half_t))>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Shape::kK>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, half_t, SmemLayoutA, 1, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, half_t, SmemLayoutB, 1, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, half_t, SmemLayoutA, half_t, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, MatrixShape<0, 0>, -+ WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major-interleave -+/// B: row-major-interleave -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+/// -+/// Column/RowMajorInterleved(m, n) is mapped to Column/RowMajor(m -+/// x InterleavedK, n / InterleavedK) so that Column/RowMajor global iterators -+/// can be reused. The shared store iterator is the same as the crosswise shared -+/// store iterator. So, the only thing we need to do is to swap the coordinates -+/// (contiguous <=> strided) used by the global iterator and the shared store -+/// iterator. -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor, -+ /// Number of interleaved k -+ int InterleavedK> -+struct DefaultMmaCore, ElementB_, -+ layout::RowMajorInterleaved, ElementC_, -+ LayoutC_, arch::OpClassTensorOp, 2, Operator_, -+ AccumulatorsInRowMajor> { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::ColumnMajorInterleaved; -+ using ElementB = ElementB_; -+ using LayoutB = layout::RowMajorInterleaved; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassTensorOp; -+ static int const kInterleavedK = InterleavedK; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // Warp thread arrangement -+ static int const kElementsPerAccess = -+ kAccessSizeInBits / sizeof_bits::value; -+ -+ static int const kWarpThreadArrangementContiguous = -+ kInterleavedK / kElementsPerAccess; -+ -+ static int const kWarpThreadArrangementStrided = -+ kWarpSize / kWarpThreadArrangementContiguous; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, kInterleavedK>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, kInterleavedK>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, layout::PitchLinearShape<32, 1>, kElementsPerAccess>; -+ -+ /// Transpose the ThreadMap of iterator A -+ using SmemThreadMapA = transform::TransposePitchLinearThreadMap< -+ IteratorThreadMapA, -+ layout::PitchLinearShape>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, ElementA, SmemLayoutA, 0, -+ SmemThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, layout::PitchLinearShape<32, 1>, kElementsPerAccess>; -+ -+ /// Transpose the ThreadMap of iterator A -+ using SmemThreadMapB = transform::TransposePitchLinearThreadMap< -+ IteratorThreadMapB, -+ layout::PitchLinearShape>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, ElementB, SmemLayoutB, 1, -+ SmemThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK, AccumulatorsInRowMajor>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sm80.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sm80.h -new file mode 100644 -index 0000000..ad232fc ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sm80.h -@@ -0,0 +1,2916 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Defines basic properties needed by CTA-level GEMMs assuming -+ expectations about data layout of the global memory fragments, data types, -+ and internal tile sizes. -+ -+ Partial specializations for threadblock::Mma operations targeting TensorOp -+ instructions. -+ -+ SM80 Multi stage kernel expects stage number to be larger or equal to 3 -+ to use asyncronous copy. -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm80.h" -+ -+#include "cutlass/gemm/warp/mma_simt_policy.h" -+#include "cutlass/gemm/warp/mma_simt.h" -+#include "cutlass/gemm/warp/default_mma_tensor_op.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" -+ -+#include "cutlass/gemm/threadblock/default_mma_core.h" -+#include "cutlass/gemm/threadblock/default_multistage_mma_complex_core.h" -+#include "cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h" -+ -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h" -+#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h" -+#include "cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h" -+#include "cutlass/gemm/threadblock/mma_multistage.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for double-precision -+/// -+/// A: column-major -+/// B: column-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = double; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ static_assert(WarpCount::kCount > 1, -+ "This specialization requires at least two warps."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 64; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ -+ using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicand64bCrosswise; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpStripedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<16, 2>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 1, -+ IteratorThreadMapA>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<16, 2>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 0, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+/// Partial specialization for double-precision -+/// -+/// A: column-major -+/// B: row-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = double; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ static_assert(WarpCount::kCount > 1, -+ "This specialization requires at least two warps."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 64; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpStripedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<16, 2>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 1, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpStripedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<16, 2>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 0, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for double-precision -+/// -+/// A: row-major -+/// B: column-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = double; -+ using LayoutA = layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 64; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::RowMajorTensorOpMultiplicand64bCrosswise; -+ -+ using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicand64bCrosswise; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<16, 2>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 1, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<16, 2>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 0, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// -+/// Partial specialization for double-precision -+/// -+/// A: row-major -+/// B: row-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = double; -+ using LayoutA = layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ static_assert(WarpCount::kCount > 1, -+ "This specialization requires at least two warps."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 64; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::RowMajorTensorOpMultiplicand64bCrosswise; -+ -+ using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<16, 2>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 1, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpStripedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<16, 2>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 0, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for double-precision -+/// -+/// A: column-major -+/// B: column-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = double; -+ using LayoutA = layout::AffineRank2ColumnMajor; -+ using ElementB = double; -+ using LayoutB = layout::AffineRank2ColumnMajor; -+ using ElementC = double; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ using Base = DefaultMmaCore; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = typename Base::SmemLayoutA; -+ using SmemLayoutB = typename Base::SmemLayoutB; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = typename Base::IteratorThreadMapA; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = typename Base::SmemIteratorA; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = typename Base::IteratorThreadMapB; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = typename Base::SmemIteratorB; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = typename Base::MmaPolicy; -+}; -+ -+/// Partial specialization for double-precision -+/// -+/// A: column-major -+/// B: row-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = double; -+ using LayoutA = layout::AffineRank2ColumnMajor; -+ using ElementB = double; -+ using LayoutB = layout::AffineRank2RowMajor; -+ using ElementC = double; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ using Base = DefaultMmaCore; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = typename Base::SmemLayoutA; -+ using SmemLayoutB = typename Base::SmemLayoutB; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = typename Base::IteratorThreadMapA; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = typename Base::SmemIteratorA; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = typename Base::IteratorThreadMapB; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = typename Base::SmemIteratorB; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = typename Base::MmaPolicy; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for double-precision -+/// -+/// A: row-major -+/// B: column-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = double; -+ using LayoutA = layout::AffineRank2RowMajor; -+ using ElementB = double; -+ using LayoutB = layout::AffineRank2ColumnMajor; -+ using ElementC = double; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ using Base = DefaultMmaCore; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = typename Base::SmemLayoutA; -+ using SmemLayoutB = typename Base::SmemLayoutB; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = typename Base::IteratorThreadMapA; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = typename Base::SmemIteratorA; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = typename Base::IteratorThreadMapB; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = typename Base::SmemIteratorB; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = typename Base::MmaPolicy; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// -+/// Partial specialization for double-precision -+/// -+/// A: row-major -+/// B: row-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = double; -+ using LayoutA = layout::AffineRank2RowMajor; -+ using ElementB = double; -+ using LayoutB = layout::AffineRank2RowMajor; -+ using ElementC = double; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ using Base = DefaultMmaCore; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = typename Base::SmemLayoutA; -+ using SmemLayoutB = typename Base::SmemLayoutB; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = typename Base::IteratorThreadMapA; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = typename Base::SmemIteratorA; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = typename Base::IteratorThreadMapB; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = typename Base::SmemIteratorB; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = typename Base::MmaPolicy; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for float-precision -+/// -+/// ElementA: complex -+/// ElementB: complex -+/// ElementC: complex -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Layout for A operand -+ typename LayoutA_, -+ /// Layout for B operand -+ typename LayoutB_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// per-element transformation for elements of A -+ ComplexTransform TransformA_, -+ /// per-element transformation for elements of B -+ ComplexTransform TransformB_ -+ > -+struct DefaultMmaCore< -+ Shape_, WarpShape_, GemmShape<16, 8, 8>, -+ complex, LayoutA_, -+ complex, LayoutB_, -+ complex, LayoutC_, -+ arch::OpClassTensorOp, -+ Stages, -+ Operator_, -+ false, -+ CacheOpA, -+ CacheOpB, -+ TransformA_, TransformB_, true> { -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<16, 8, 8>; -+ using ElementA = complex; -+ using LayoutA = LayoutA_; -+ using ElementB = complex; -+ using LayoutB = LayoutB_; -+ using ElementC = complex; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ static const ComplexTransform TransformA = TransformA_; -+ static const ComplexTransform TransformB = TransformB_; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ static_assert(WarpCount::kCount > 1, -+ "This specialization requires at least two warps."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ static_assert( -+ platform::is_same::value || -+ platform::is_same::value || -+ platform::is_same::value, -+ "The operator tag must indicate complex multiplication."); -+ -+ // -+ // Underlying template -+ // -+ -+ using MmaComplexCore = DefaultMultistageMmaComplexCore< -+ Shape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ arch::OpClassTensorOp, -+ kStages, -+ TransformA, -+ TransformB, -+ Operator, -+ kCacheOpA, -+ kCacheOpB -+ >; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = typename MmaComplexCore::SmemLayoutA; -+ -+ // Shared memory layout -+ using SmemLayoutB = typename MmaComplexCore::SmemLayoutB; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = typename MmaComplexCore::IteratorThreadMapA; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = typename MmaComplexCore::SmemIteratorA; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = typename MmaComplexCore::IteratorThreadMapB; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = typename MmaComplexCore::SmemIteratorB; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename MmaComplexCore::MmaTensorOp; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = typename MmaComplexCore::MmaPolicy; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for double-precision -+/// -+/// ElementA: complex -+/// ElementB: complex -+/// ElementC: complex -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout for A operand -+ typename LayoutA_, -+ /// Layout for B operand -+ typename LayoutB_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// per-element transformation for elements of A -+ ComplexTransform TransformA_, -+ /// per-element transformation for elements of B -+ ComplexTransform TransformB_ -+ > -+struct DefaultMmaCore< -+ Shape_, WarpShape_, InstructionShape_, -+ complex, LayoutA_, -+ complex, LayoutB_, -+ complex, LayoutC_, -+ arch::OpClassTensorOp, -+ Stages, -+ Operator_, -+ false, -+ CacheOpA, -+ CacheOpB, -+ TransformA_, TransformB_, true> { -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = complex; -+ using LayoutA = LayoutA_; -+ using ElementB = complex; -+ using LayoutB = LayoutB_; -+ using ElementC = complex; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ static const ComplexTransform TransformA = TransformA_; -+ static const ComplexTransform TransformB = TransformB_; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ static_assert(WarpCount::kCount > 1, -+ "This specialization requires at least two warps."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 64; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ static_assert( -+ platform::is_same::value || -+ platform::is_same::value, -+ "The operator tag must indicate complex multiplication."); -+ -+ // -+ // Underlying template -+ // -+ -+ using MmaComplexCore = DefaultMultistageMmaComplexCore< -+ Shape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ arch::OpClassTensorOp, -+ kStages, -+ TransformA, -+ TransformB, -+ Operator, -+ kCacheOpA, -+ kCacheOpB -+ >; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = typename MmaComplexCore::SmemLayoutA; -+ -+ // Shared memory layout -+ using SmemLayoutB = typename MmaComplexCore::SmemLayoutB; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = typename MmaComplexCore::IteratorThreadMapA; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = typename MmaComplexCore::SmemIteratorA; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = typename MmaComplexCore::IteratorThreadMapB; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = typename MmaComplexCore::SmemIteratorB; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename MmaComplexCore::MmaTensorOp; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = typename MmaComplexCore::MmaPolicy; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major -+/// B: row-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(ElementA))>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(ElementB))>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 1, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 0, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: row-major -+/// B: column-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // Warp thread arrangement -+ static int const kWarpThreadArrangementContiguousA = -+ Shape::kK / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedA = -+ kWarpSize / kWarpThreadArrangementContiguousA; -+ -+ static int const kWarpThreadArrangementContiguousB = -+ Shape::kK / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedB = -+ kWarpSize / kWarpThreadArrangementContiguousB; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Shape::kK>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Shape::kK>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 0, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 1, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major -+/// B: column-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::ColumnMajor; -+ -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // Warp thread arrangement -+ static int const kWarpThreadArrangementContiguousB = -+ Shape::kK / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedB = -+ kWarpSize / kWarpThreadArrangementContiguousB; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(ElementA))>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Shape::kK>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 1, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 1, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: row-major -+/// B: row-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // Warp thread arrangement -+ static int const kWarpThreadArrangementContiguousA = -+ Shape::kK / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedA = -+ kWarpSize / kWarpThreadArrangementContiguousA; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Shape::kK>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(ElementB))>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 0, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 0, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major-interleaved -+/// B: row-major-interleaved -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+/// -+/// Column/RowMajorInterleved(m, n) is mapped to Column/RowMajor(m -+/// x InterleavedK, n / InterleavedK) so that Column/RowMajor global iterators -+/// can be reused. The shared store iterator is the same as the crosswise shared -+/// store iterator. So, the only thing we need to do is to swap the coordinates -+/// (contiguous <=> strided) used by the global iterator and the shared store -+/// iterator. -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// Number of interleaved K -+ int InterleavedK> -+struct DefaultMmaCore, ElementB_, -+ layout::RowMajorInterleaved, ElementC_, -+ LayoutC_, arch::OpClassTensorOp, Stages, Operator_, -+ AccumulatorsInRowMajor, CacheOpA, CacheOpB> { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::ColumnMajorInterleaved; -+ using ElementB = ElementB_; -+ using LayoutB = layout::RowMajorInterleaved; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ static int const kInterleavedK = InterleavedK; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // Warp thread arrangement -+ static int const kElementsPerAccess = -+ kAccessSizeInBits / sizeof_bits::value; -+ -+ static int const kWarpThreadArrangementContiguous = -+ kInterleavedK / kElementsPerAccess; -+ -+ static int const kWarpThreadArrangementStrided = -+ kWarpSize / kWarpThreadArrangementContiguous; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, kInterleavedK>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, kInterleavedK>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, layout::PitchLinearShape<32, 1>, kElementsPerAccess>; -+ -+ /// Transpose the ThreadMap of iterator A -+ using SmemThreadMapA = transform::TransposePitchLinearThreadMap< -+ IteratorThreadMapA, -+ layout::PitchLinearShape>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 0, -+ SmemThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, layout::PitchLinearShape<32, 1>, kElementsPerAccess>; -+ -+ /// Transpose the ThreadMap of iterator A -+ using SmemThreadMapB = transform::TransposePitchLinearThreadMap< -+ IteratorThreadMapB, -+ layout::PitchLinearShape>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 1, -+ SmemThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK, AccumulatorsInRowMajor>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for SIMT GEMMs using multistage pipeline. -+/// -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by Simt -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // Warp thread arrangement -+ static int const kElementsPerAccess = 1; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajor; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::RowMajor; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 0, -+ IteratorThreadMapA>; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Transpose the ThreadMap of iterator B -+ using SmemThreadMapB = transform::TransposePitchLinearThreadMapSimt; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 1, -+ SmemThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level op -+ static const int WarpNumThreadsM = 4; // TODO need to extract these from template data -+ static const int WarpNumThreadsN = 8; -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; -+ static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; -+ static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; -+ static const int numElementsA = 128 / sizeof_bits::value; -+ static const int numElementsB = 128 / sizeof_bits::value; -+ static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); -+ static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); -+ -+ static_assert(!((Shape::kK / 32) % LaneN), -+ "Padding must be divisible by Lane"); -+ -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 1>; -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::RowMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ >; /// Used for partial specialization -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaWarpSimt, -+ MatrixShape<0, 0>, -+ MatrixShape<0, Shape::kK / 32>, -+ WarpCount::kK>; -+}; -+ -+/// Partial specialization for SIMT GEMMs using multistage pipeline. -+/// -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by Simt -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // Warp thread arrangement -+ static int const kElementsPerAccess = 1; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajor; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::RowMajor; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 0, -+ IteratorThreadMapA>; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 1, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level op -+ static const int WarpNumThreadsM = 4; // TODO need to extract these from template data -+ static const int WarpNumThreadsN = 8; -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; -+ static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; -+ static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; -+ static const int numElementsA = 128 / sizeof_bits::value; -+ static const int numElementsB = 128 / sizeof_bits::value; -+ static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); -+ static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 1>; -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::RowMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ >; /// Used for partial specialization -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaWarpSimt, -+ MatrixShape<0, 0>, -+ MatrixShape<0, 0>, -+ WarpCount::kK>; -+}; -+ -+/// Partial specialization for SIMT GEMMs using multistage pipeline. -+/// -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by Simt -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // Warp thread arrangement -+ static int const kElementsPerAccess = 1; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajor; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::RowMajor; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Transpose the ThreadMap of iterator A -+ using SmemThreadMapA = transform::TransposePitchLinearThreadMapSimt; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 0, -+ SmemThreadMapA>; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Transpose the ThreadMap of iterator B -+ using SmemThreadMapB = transform::TransposePitchLinearThreadMapSimt; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 1, -+ SmemThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level op -+ static const int WarpNumThreadsM = 4; // TODO need to extract these from template data -+ static const int WarpNumThreadsN = 8; -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; -+ static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; -+ static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; -+ static const int numElementsA = 128 / sizeof_bits::value; -+ static const int numElementsB = 128 / sizeof_bits::value; -+ static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); -+ static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); -+ -+ static_assert(!((Shape::kK / 32) % LaneM) && !((Shape::kK / 32) % LaneN), -+ "Padding must be divisible by Lane"); -+ -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 1>; -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::RowMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ >; /// Used for partial specialization -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaWarpSimt, -+ MatrixShape, -+ MatrixShape<0, Shape::kK / 32>, -+ WarpCount::kK>; -+}; -+ -+/// Partial specialization for SIMT GEMMs using multistage pipeline. -+/// -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by Simt -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // Warp thread arrangement -+ static int const kElementsPerAccess = 1; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajor; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::RowMajor; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Transpose the ThreadMap of iterator A -+ using SmemThreadMapA = transform::TransposePitchLinearThreadMapSimt; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 0, -+ SmemThreadMapA>; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 1, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level op -+ static const int WarpNumThreadsM = 4; // TODO need to extract these from template data -+ static const int WarpNumThreadsN = 8; -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; -+ static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; -+ static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; -+ static const int numElementsA = 128 / sizeof_bits::value; -+ static const int numElementsB = 128 / sizeof_bits::value; -+ static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); -+ static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); -+ -+ static_assert(!((Shape::kK / 32) % LaneM), -+ "Padding must be divisible by Lane"); -+ -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 1>; -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::RowMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ >; /// Used for partial specialization -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaWarpSimt, -+ MatrixShape, -+ MatrixShape<0, 0>, -+ WarpCount::kK>; -+}; -+ -+/// Partial specialization for SIMT GEMMs using multistage pipeline. -+/// -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by Simt -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::AffineRank2ColumnMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::AffineRank2RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ using Base = DefaultMmaCore; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = typename Base::SmemLayoutA; -+ using SmemLayoutB = typename Base::SmemLayoutB; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = typename Base::IteratorThreadMapA; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = typename Base::SmemIteratorA; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = typename Base::IteratorThreadMapB; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = typename Base::SmemIteratorB; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = typename Base::MmaPolicy; -+}; -+ -+/// Partial specialization for SIMT GEMMs using multistage pipeline. -+/// -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by Simt -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::AffineRank2RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::AffineRank2ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ using Base = DefaultMmaCore; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = typename Base::SmemLayoutA; -+ using SmemLayoutB = typename Base::SmemLayoutB; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = typename Base::IteratorThreadMapA; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = typename Base::SmemIteratorA; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = typename Base::IteratorThreadMapB; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = typename Base::SmemIteratorB; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = typename Base::MmaPolicy; -+}; -+ -+/// Partial specialization for SIMT GEMMs using multistage pipeline. -+/// -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by Simt -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::AffineRank2ColumnMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::AffineRank2ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ using Base = DefaultMmaCore; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = typename Base::SmemLayoutA; -+ using SmemLayoutB = typename Base::SmemLayoutB; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = typename Base::IteratorThreadMapA; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = typename Base::SmemIteratorA; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = typename Base::IteratorThreadMapB; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = typename Base::SmemIteratorB; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = typename Base::MmaPolicy; -+ -+}; -+ -+/// Partial specialization for SIMT GEMMs using multistage pipeline. -+/// -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by Simt -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::AffineRank2RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::AffineRank2RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ using Base = DefaultMmaCore; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = typename Base::SmemLayoutA; -+ using SmemLayoutB = typename Base::SmemLayoutB; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = typename Base::IteratorThreadMapA; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = typename Base::SmemIteratorA; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = typename Base::IteratorThreadMapB; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = typename Base::SmemIteratorB; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = typename Base::MmaPolicy; -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h -new file mode 100644 -index 0000000..870845f ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h -@@ -0,0 +1,834 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Defines basic properties needed by CTA-level GEMMs assuming -+ expectations about data layout of the global memory fragments, data types, -+ and internal tile sizes. -+ -+ Partial specializations for threadblock::Mma operations targeting sparse -+ TensorOp instructions. -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm80.h" -+ -+#include "cutlass/gemm/warp/mma_simt_policy.h" -+#include "cutlass/gemm/warp/mma_simt.h" -+#include "cutlass/gemm/warp/default_mma_sparse_tensor_op.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -+ -+#include "cutlass/gemm/threadblock/default_mma_core.h" -+ -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h" -+#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h" -+#include "cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h" -+#include "cutlass/gemm/threadblock/mma_sparse_multistage.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template defininng default matrix multiply operators inferred from threadblock tile size, -+/// global memory data layout, and target math instruction. -+template < -+ /// Shape of threadblock-scoped matrix multiply operator -+ typename Shape, -+ /// Shape of warp-level matrix multiply operator -+ typename WarpShape, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape, -+ /// Element data type of A operand -+ typename ElementA, -+ /// Layout of operand A -+ typename LayoutA, -+ /// Element data type of B operand -+ typename ElementB, -+ /// Layout of operand B -+ typename LayoutB, -+ /// Data type of accumulator -+ typename ElementC, -+ /// Layout of accumulator -+ typename LayoutC, -+ /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) -+ typename OperatorClass, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator = typename platform::conditional< -+ (platform::is_same::value) && -+ (platform::is_same::value || -+ platform::is_same::value || -+ platform::is_same::value || -+ platform::is_same::value), -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::arch::OpMultiplyAdd>::type, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor = false -+ /// Cache operation of operand A -+ , cutlass::arch::CacheOperation::Kind CacheOpA = -+ cutlass::arch::CacheOperation::Global, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB = -+ cutlass::arch::CacheOperation::Global -+> -+struct DefaultSparseMmaCore; -+ -+//////////////////////////////////////////////////////////////////////////////// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major -+/// B: row-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultSparseMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ static int const kSparse = 2; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(ElementA))>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(ElementB))>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 1, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 0, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Cache operation of operand E -+ static cutlass::arch::CacheOperation::Kind const kCacheOpE = -+ cutlass::arch::CacheOperation::Global; -+ -+ static int const kInterleavedE = MmaTensorOp::kInterleaved; -+ static int const kMetaSizeInBits = MmaTensorOp::kMetaSizeInBits; -+ static int const kMaxID2 = MmaTensorOp::kMaxID2; -+ static int const kElementsPerElementE = MmaTensorOp::kElementsPerElementE; -+ -+ using ElementE = typename MmaTensorOp::ElementE; -+ using GmemLayoutE = cutlass::layout::ColumnMajorInterleaved; -+ -+ // Shared memory layout. Interleaved layout is mapped to PitchLinear layout. -+ using SmemLayoutE = typename MmaTensorOp::LayoutE; -+ -+ /// ThreadMap of iterator E -+ static int const kElementsPerAccessE = -+ kAccessSizeInBits / sizeof_bits::value; -+ -+ /// E is tiny. Not all warps are needed. -+ static int const kThreadsE = -+ (Shape::kM * Shape::kK / kSparse / kElementsPerElementE / -+ (kAccessSizeInBits / sizeof_bits::value) > -+ kThreads) -+ ? kThreads -+ : (Shape::kM * Shape::kK / kSparse / kElementsPerElementE / -+ (kAccessSizeInBits / sizeof_bits::value)); -+ -+ using IteratorThreadMapE = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreadsE, kElementsPerAccessE>; -+ -+ /// Shared memory iterator to E operand -+ using SmemIteratorE = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, -+ ElementE, SmemLayoutE, 0, IteratorThreadMapE>; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = -+ SparseMmaPolicy, MatrixShape<0, 0>, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: row-major -+/// B: column-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultSparseMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ static int const kSparse = 2; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // Warp thread arrangement -+ static int const kWarpThreadArrangementContiguousA = -+ Shape::kK / kSparse / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedA = -+ kWarpSize / kWarpThreadArrangementContiguousA; -+ -+ // crosswise cannot be larger than 1024 bit. -+ static int const kCrosswiseB = -+ (Shape::kK > (1024 / sizeof_bits::value)) -+ ? (1024 / sizeof_bits::value) -+ : Shape::kK; -+ -+ static int const kWarpThreadArrangementContiguousB = -+ kCrosswiseB / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedB = -+ kWarpSize / kWarpThreadArrangementContiguousB; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Shape::kK / kSparse>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, kCrosswiseB>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 0, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 1, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Cache operation of operand E -+ static cutlass::arch::CacheOperation::Kind const kCacheOpE = -+ cutlass::arch::CacheOperation::Global; -+ -+ static int const kInterleavedE = MmaTensorOp::kInterleaved; -+ static int const kMetaSizeInBits = MmaTensorOp::kMetaSizeInBits; -+ static int const kMaxID2 = MmaTensorOp::kMaxID2; -+ static int const kElementsPerElementE = MmaTensorOp::kElementsPerElementE; -+ -+ using ElementE = typename MmaTensorOp::ElementE; -+ using GmemLayoutE = cutlass::layout::ColumnMajorInterleaved; -+ -+ // Shared memory layout. Interleaved layout is mapped to PitchLinear layout. -+ using SmemLayoutE = typename MmaTensorOp::LayoutE; -+ -+ /// ThreadMap of iterator E -+ static int const kElementsPerAccessE = -+ kAccessSizeInBits / sizeof_bits::value; -+ -+ /// E is tiny. Not all warps are needed. -+ static int const kThreadsE = -+ (Shape::kM * Shape::kK / kSparse / kElementsPerElementE / -+ (kAccessSizeInBits / sizeof_bits::value) > -+ kThreads) -+ ? kThreads -+ : (Shape::kM * Shape::kK / kSparse / kElementsPerElementE / -+ (kAccessSizeInBits / sizeof_bits::value)); -+ -+ using IteratorThreadMapE = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreadsE, kElementsPerAccessE>; -+ -+ -+ /// Shared memory iterator to E operand -+ using SmemIteratorE = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, -+ ElementE, SmemLayoutE, 0, IteratorThreadMapE>; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = -+ SparseMmaPolicy, MatrixShape<0, 0>, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major -+/// B: column-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultSparseMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::ColumnMajor; -+ -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ static int const kSparse = 2; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // Warp thread arrangement -+ // crosswise cannot be larger than 1024 bit. -+ static int const kCrosswiseB = -+ (Shape::kK > (1024 / sizeof_bits::value)) -+ ? (1024 / sizeof_bits::value) -+ : Shape::kK; -+ -+ static int const kWarpThreadArrangementContiguousB = -+ kCrosswiseB / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedB = -+ kWarpSize / kWarpThreadArrangementContiguousB; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(ElementA))>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, kCrosswiseB>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 1, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 1, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Cache operation of operand E -+ static cutlass::arch::CacheOperation::Kind const kCacheOpE = -+ cutlass::arch::CacheOperation::Global; -+ -+ static int const kInterleavedE = MmaTensorOp::kInterleaved; -+ static int const kMetaSizeInBits = MmaTensorOp::kMetaSizeInBits; -+ static int const kMaxID2 = MmaTensorOp::kMaxID2; -+ static int const kElementsPerElementE = MmaTensorOp::kElementsPerElementE; -+ -+ using ElementE = typename MmaTensorOp::ElementE; -+ using GmemLayoutE = cutlass::layout::ColumnMajorInterleaved; -+ -+ // Shared memory layout. Interleaved layout is mapped to PitchLinear layout. -+ using SmemLayoutE = typename MmaTensorOp::LayoutE; -+ -+ /// ThreadMap of iterator E -+ static int const kElementsPerAccessE = -+ kAccessSizeInBits / sizeof_bits::value; -+ -+ /// E is tiny. Not all warps are needed. -+ static int const kThreadsE = -+ (Shape::kM * Shape::kK / kSparse / kElementsPerElementE / -+ (kAccessSizeInBits / sizeof_bits::value) > -+ kThreads) -+ ? kThreads -+ : (Shape::kM * Shape::kK / kSparse / kElementsPerElementE / -+ (kAccessSizeInBits / sizeof_bits::value)); -+ -+ using IteratorThreadMapE = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreadsE, kElementsPerAccessE>; -+ -+ /// Shared memory iterator to E operand -+ using SmemIteratorE = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, -+ ElementE, SmemLayoutE, 0, IteratorThreadMapE>; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = -+ SparseMmaPolicy, MatrixShape<0, 0>, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: row-major -+/// B: row-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultSparseMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ static int const kSparse = 2; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // Warp thread arrangement -+ static int const kWarpThreadArrangementContiguousA = -+ Shape::kK / kSparse / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedA = -+ kWarpSize / kWarpThreadArrangementContiguousA; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Shape::kK / kSparse>; -+ -+ // Shared memory layout -+ using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(ElementB))>; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 0, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 0, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, -+ ElementC, LayoutC, Operator, WarpCount::kK>::Type; -+ -+ /// Cache operation of operand E -+ static cutlass::arch::CacheOperation::Kind const kCacheOpE = -+ cutlass::arch::CacheOperation::Global; -+ -+ static int const kInterleavedE = MmaTensorOp::kInterleaved; -+ static int const kMetaSizeInBits = MmaTensorOp::kMetaSizeInBits; -+ static int const kMaxID2 = MmaTensorOp::kMaxID2; -+ static int const kElementsPerElementE = MmaTensorOp::kElementsPerElementE; -+ -+ using ElementE = typename MmaTensorOp::ElementE; -+ using GmemLayoutE = cutlass::layout::ColumnMajorInterleaved; -+ -+ // Shared memory layout. Interleaved layout is mapped to PitchLinear layout. -+ using SmemLayoutE = typename MmaTensorOp::LayoutE; -+ -+ /// ThreadMap of iterator E -+ static int const kElementsPerAccessE = -+ kAccessSizeInBits / sizeof_bits::value; -+ -+ /// E is tiny. Not all warps are needed. -+ static int const kThreadsE = -+ (Shape::kM * Shape::kK / kSparse / kElementsPerElementE / -+ (kAccessSizeInBits / sizeof_bits::value) > -+ kThreads) -+ ? kThreads -+ : (Shape::kM * Shape::kK / kSparse / kElementsPerElementE / -+ (kAccessSizeInBits / sizeof_bits::value)); -+ -+ using IteratorThreadMapE = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreadsE, kElementsPerAccessE>; -+ -+ /// Shared memory iterator to E operand -+ using SmemIteratorE = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, -+ ElementE, SmemLayoutE, 0, IteratorThreadMapE>; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = -+ SparseMmaPolicy, MatrixShape<0, 0>, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_with_access_size.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_with_access_size.h -new file mode 100644 -index 0000000..0345084 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_with_access_size.h -@@ -0,0 +1,328 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines basic properties needed by CTA-level GEMMs assuming expectations about data -+ layout of the global memory fragments, data types, and internal tile sizes. -+ -+ Partial specializations for threadblock::Mma operations targeting simt instructions. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/gemm/warp/mma.h" -+#include "cutlass/gemm/threadblock/mma_pipelined.h" -+#include "cutlass/gemm/threadblock/mma_singlestage.h" -+#include "cutlass/arch/cache_operation.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+template < -+ /// Shape of threadblock-scoped matrix multiply operator -+ typename Shape, -+ /// Shape of warp-level matrix multiply operator -+ typename WarpShape, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape, -+ /// Element data type of A operand -+ typename ElementA, -+ /// Layout of operand A -+ typename LayoutA, -+ /// Element data type of B operand -+ typename ElementB, -+ /// Layout of operand B -+ typename LayoutB, -+ /// Data type of accumulator -+ typename ElementC, -+ /// Layout of accumulator -+ typename LayoutC, -+ /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) -+ typename OperatorClass, -+ /// Size of a threadblock-scoped access -+ int kAccessSizeInBits = -1, // -1 denoting the default -+ /// Number of stages -+ int Stages = 2, -+ /// Operation performed by MMA -+ typename Operator = typename platform::conditional< -+ (platform::is_same::value) && -+ (platform::is_same::value || -+ platform::is_same::value || -+ platform::is_same::value || -+ platform::is_same::value), -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::arch::OpMultiplyAdd>::type, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor = false, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA = -+ cutlass::arch::CacheOperation::Global, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB = -+ cutlass::arch::CacheOperation::Global, -+ /// per-element transformation for elements of A -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// per-element transformation for elements of B -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ bool IsComplex = false // (is_complex::value || is_complex::value) -+> -+struct DefaultMmaCoreWithAccessSize; -+ -+template < -+ /// Shape of threadblock-scoped matrix multiply operator -+ typename Shape, -+ /// Shape of warp-level matrix multiply operator -+ typename WarpShape, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape, -+ /// Element data type of A operand -+ typename ElementA, -+ /// Layout of operand A -+ typename LayoutA, -+ /// Element data type of B operand -+ typename ElementB, -+ /// Layout of operand B -+ typename LayoutB, -+ /// Data type of accumulator -+ typename ElementC, -+ /// Layout of accumulator -+ typename LayoutC, -+ /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) -+ typename OperatorClass, -+ /// Number of stages -+ int Stages, -+ /// Operation performed by MMA -+ typename Operator, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// per-element transformation for elements of A -+ ComplexTransform TransformA, -+ /// per-element transformation for elements of B -+ ComplexTransform TransformB, -+ bool IsComplex -+> -+struct DefaultMmaCoreWithAccessSize< -+ Shape, WarpShape, InstructionShape, -+ ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, -+ OperatorClass, -1, Stages, Operator, AccumulatorsInRowMajor, -+ CacheOpA, CacheOpB, TransformA, TransformB, IsComplex -+> : DefaultMmaCore< -+ Shape, WarpShape, InstructionShape, -+ ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, -+ OperatorClass, Stages, Operator, AccumulatorsInRowMajor, -+ CacheOpA, CacheOpB, TransformA, TransformB, IsComplex -+> {}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major -+/// B: row-major -+/// Operator: simt class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Size of a threadblock-scoped access (a value of -1 indicates the default) -+ int kAccessSizeInBits_, -+ /// Operation performed by GEMM -+ typename Operator_> -+struct DefaultMmaCoreWithAccessSize>::type, ElementA_, -+ layout::ColumnMajor, ElementB_, layout::RowMajor, -+ ElementC_, LayoutC_, arch::OpClassSimt, kAccessSizeInBits_, 2, Operator_ -+ > { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<1, 1, 1>; -+ using ElementA = ElementA_; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassSimt; -+ static int const PartitionsK = Shape::kK / WarpShape::kK; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ PartitionsK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ static int const kElementsPerAccessDefault = 1; -+ static_assert(kAccessSizeInBits_ == -1 || -+ sizeof_bits::value == sizeof_bits::value || -+ kAccessSizeInBits_ / sizeof_bits::value == kElementsPerAccessDefault, -+ "Non-default value for kAccessSizeInBits_ is only allowed if size(elementA) == sizeof(elementB)"); -+ static int const kElementsPerAccess = (kAccessSizeInBits_ != -1) ? kAccessSizeInBits_ / sizeof_bits::value : kElementsPerAccessDefault; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajor; -+ using SmemLayoutB = layout::RowMajor; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 1, -+ IteratorThreadMapA -+ >; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 0, -+ IteratorThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level op -+ static const int WarpNumThreadsM = detail::simt_get_warp_threads_m(); -+ static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM; -+ static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; -+ static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; -+ static const int numElementsA = 128 / sizeof_bits::value; -+ static const int numElementsB = 128 / sizeof_bits::value; -+ static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); -+ static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 1>; -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::RowMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) -+ >; /// Used for partial specialization -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaWarpSimt, -+ MatrixShape<0, 0>, -+ MatrixShape<0, 0>, -+ WarpCount::kK -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_with_reduction.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_with_reduction.h -new file mode 100644 -index 0000000..d150791 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_with_reduction.h -@@ -0,0 +1,167 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Defines basic properties needed by CTA-level GEMMs assuming -+ expectations about data layout of the global memory fragments, data types, -+ and internal tile sizes. -+ -+ Partial specializations for threadblock::Mma operations targeting TensorOp -+ instructions. -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm80.h" -+ -+#include "cutlass/gemm/warp/default_mma_with_reduction_tensor_op.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" -+ -+#include "cutlass/gemm/threadblock/default_mma_core.h" -+ -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h" -+#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h" -+#include "cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h" -+#include "cutlass/gemm/threadblock/mma_with_reduction_multistage.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template defininng default matrix multiply operators inferred from threadblock tile size, -+/// global memory data layout, and target math instruction. -+template < -+ /// Shape of threadblock-scoped matrix multiply operator -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator -+ typename WarpShape, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape, -+ /// Element data type of A operand -+ typename ElementA, -+ /// Layout of operand A -+ typename LayoutA, -+ /// Element data type of B operand -+ typename ElementB, -+ /// Layout of operand B -+ typename LayoutB, -+ /// Data type of accumulator -+ typename ElementC, -+ /// Layout of accumulator -+ typename LayoutC, -+ /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) -+ typename OperatorClass, -+ /// Reduce operand A or B along K dimension -+ bool ReduceKForA_, -+ /// Number of stages -+ int Stages = 2, -+ /// Operation performed by MMA -+ typename Operator = typename platform::conditional< -+ (platform::is_same::value) && -+ (platform::is_same::value || -+ platform::is_same::value || -+ platform::is_same::value || -+ platform::is_same::value), -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::arch::OpMultiplyAdd>::type, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor = false, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA = -+ cutlass::arch::CacheOperation::Global, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB = -+ cutlass::arch::CacheOperation::Global, -+ /// per-element transformation for elements of A -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// per-element transformation for elements of B -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ bool IsComplex = false// (is_complex::value || is_complex::value) -+> -+struct DefaultMmaWithReductionCore { -+ using Base = DefaultMmaCore; -+ using Shape = Shape_; -+ using IteratorThreadMapA = typename Base::IteratorThreadMapA; -+ using IteratorThreadMapB = typename Base::IteratorThreadMapB; -+ using SmemIteratorA = typename Base::SmemIteratorA; -+ using SmemIteratorB = typename Base::SmemIteratorB; -+ using SmemLayoutA = typename Base::SmemLayoutA; -+ using SmemLayoutB = typename Base::SmemLayoutB; -+ using WarpCount = typename Base::WarpCount; -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaWithReductionTensorOp< -+ WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, -+ ElementC, LayoutC, Operator, ReduceKForA_, WarpCount::kK>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_wmma.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_wmma.h -new file mode 100644 -index 0000000..f4d0a23 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_core_wmma.h -@@ -0,0 +1,712 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines basic properties needed by CTA-level GEMMs assuming expectations about data -+ layout of the global memory fragments, data types, and internal tile sizes. -+ -+ Partial specializations for threadblock::Mma operations targeting TensorOp instructions. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/arch/wmma.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_wmma.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_policy.h" -+#include "cutlass/gemm/threadblock/default_mma_core.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major -+/// B: row-major -+/// Operator: wmma tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ ///< Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by GEMM -+ typename Operator_, -+ /// Number of stages -+ int Stages> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassWmmaTensorOp; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ Shape::kK / WarpShape::kK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // -+ // Shared memory layouts -+ // -+ // NOTE: shared memory layout for wmma is same as the operands' layout in the global memory -+ using SmemLayoutA = LayoutA; -+ using SmemLayoutB = LayoutB; -+ -+ // Pad shared memory to avoid bank conflicts -+ static int const kPaddingA = 128 / sizeof_bits::value; -+ static int const kPaddingB = 128 / sizeof_bits::value; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 1, -+ IteratorThreadMapA -+ >; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 0, -+ IteratorThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Wmma< -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ Operator -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using MmaTensorOp = cutlass::gemm::warp::MmaTensorOpWmma< -+ WarpShape, -+ ElementA, -+ SmemLayoutA, -+ ElementB, -+ SmemLayoutB, -+ ElementC, -+ LayoutC, -+ Policy -+ >; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaTensorOp, -+ MatrixShape, -+ MatrixShape<0, kPaddingB>, -+ WarpCount::kK -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: row-major -+/// B: column-major -+/// Operator: wmma tensorop class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ ///< Shape of threadblock-scoped matrix multiply operator -+ ///< (concept:GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) [allowed -+ /// wmma instruction shapes, e.g., 16x16x16, 32x8x16, 8x32x16,...] -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by GEMM -+ typename Operator_, -+ /// Number of stages -+ int Stages> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassWmmaTensorOp; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ Shape::kK / WarpShape::kK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads per threadblock -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // Warp thread arrangement -+ static int const kWarpThreadArrangementContiguousA = -+ Shape::kK / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedA = -+ kWarpSize / kWarpThreadArrangementContiguousA; -+ -+ static int const kWarpThreadArrangementContiguousB = -+ Shape::kK / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedB = -+ kWarpSize / kWarpThreadArrangementContiguousB; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ // shared memory layout for wmma is same as the operands' layout in global memory -+ using SmemLayoutA = LayoutA; -+ using SmemLayoutB = LayoutB; -+ -+ // Pad shared memory to avoid bank conflicts -+ static int const kPaddingA = 128 / sizeof_bits::value; -+ static int const kPaddingB = 128 / sizeof_bits::value; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 1, -+ IteratorThreadMapA -+ >; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 0, -+ IteratorThreadMapB // SmemThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Wmma< -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ Operator -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using MmaTensorOp = cutlass::gemm::warp::MmaTensorOpWmma< -+ WarpShape, -+ ElementA, -+ SmemLayoutA, -+ ElementB, -+ SmemLayoutB, -+ ElementC, -+ LayoutC, -+ Policy -+ >; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaTensorOp, -+ MatrixShape<0, kPaddingA>, -+ MatrixShape, -+ WarpCount::kK -+ >; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: row-major -+/// B: row-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Number of stages -+ int Stages> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::RowMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::RowMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassWmmaTensorOp; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape< -+ Shape::kM / WarpShape::kM, -+ Shape::kN / WarpShape::kN, -+ Shape::kK / WarpShape::kK -+ >; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && -+ !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." -+ ); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // Warp thread arrangement -+ static int const kWarpThreadArrangementContiguousA = -+ Shape::kK / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedA = -+ kWarpSize / kWarpThreadArrangementContiguousA; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ // shared memory layout for wmma is same as the operands' layout in global memory -+ using SmemLayoutA = LayoutA; -+ using SmemLayoutB = LayoutB; -+ -+ // Pad shared memory to avoid bank conflicts -+ static int const kPaddingA = 128 / sizeof_bits::value; -+ static int const kPaddingB = 128 / sizeof_bits::value; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementA, -+ SmemLayoutA, -+ 1, -+ IteratorThreadMapA -+ >; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, -+ ElementB, -+ SmemLayoutB, -+ 0, -+ IteratorThreadMapB -+ >; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Wmma< -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ Operator -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using MmaTensorOp = cutlass::gemm::warp::MmaTensorOpWmma< -+ WarpShape, -+ ElementA, -+ SmemLayoutA, -+ ElementB, -+ SmemLayoutB, -+ ElementC, -+ LayoutC, -+ Policy -+ >; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaTensorOp, -+ MatrixShape<0, kPaddingA>, -+ MatrixShape<0, kPaddingB>, -+ WarpCount::kK -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization: -+/// -+/// A: column-major -+/// B: column-major -+/// Operator: tensor op class -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A operand -+ typename ElementA_, -+ /// Data type of B operand -+ typename ElementB_, -+ /// Data type of accumulator -+ typename ElementC_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Operation performed by MMA -+ typename Operator_, -+ /// Number of stages -+ int Stages> -+struct DefaultMmaCore { -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = ElementA_; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = ElementB_; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using OperatorClass = arch::OpClassWmmaTensorOp; -+ -+ /// Number of warps present -+ using WarpCount = -+ GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped access -+ static int const kAccessSizeInBits = 128; -+ -+ /// Default Operator -+ using Operator = Operator_; -+ -+ // Warp thread arrangement -+ static int const kWarpThreadArrangementContiguousB = -+ Shape::kK / (kAccessSizeInBits / sizeof_bits::value); -+ -+ static int const kWarpThreadArrangementStridedB = -+ kWarpSize / kWarpThreadArrangementContiguousB; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ // shared memory layout for wmma is same as the operands' layout in global memory -+ using SmemLayoutA = LayoutA; -+ using SmemLayoutB = LayoutB; -+ -+ // Pad shared memory to avoid bank conflicts -+ static int const kPaddingA = 128 / sizeof_bits::value; -+ static int const kPaddingB = 128 / sizeof_bits::value; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileIterator< -+ MatrixShape, ElementA, SmemLayoutA, 1, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kAccessSizeInBits / sizeof_bits::value -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileIterator< -+ MatrixShape, ElementB, SmemLayoutB, 0, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Wmma< -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ Operator -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using MmaTensorOp = cutlass::gemm::warp::MmaTensorOpWmma< -+ WarpShape, -+ ElementA, -+ SmemLayoutA, -+ ElementB, -+ SmemLayoutB, -+ ElementC, -+ LayoutC, -+ Policy -+ >; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaTensorOp, -+ MatrixShape, -+ MatrixShape, -+ WarpCount::kK -+ >; -+}; -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+#endif // defined(CUTLASS_ARCH_WMMA_ENABLED) -+ -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_layernorm_mainloop_fusion.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_layernorm_mainloop_fusion.h -new file mode 100644 -index 0000000..b05c634 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_layernorm_mainloop_fusion.h -@@ -0,0 +1,178 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/gemm/threadblock/default_mma_core.h" -+#include "cutlass/gemm/threadblock/mma_layernorm_mainloop_fusion_multistage.h" -+#include "cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h" -+#include "cutlass/transform/threadblock/predicated_scale_bias_vector_access_iterator.h" -+#include "cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h" -+#include "cutlass/gemm/warp/scale_bias_tile_iterator.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for Scale/Bias vectors -+ typename ElementScaleBias, -+ /// Layout type for Scale/Bias vectors -+ typename LayoutScaleBias, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation perfomed by GEMM -+ typename Operator, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor = false, -+ /// Use zfill or predicate for SM80 out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone -+ > -+struct DefaultMmaLayernormMainloopFusion { -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpA = -+ ((sizeof_bits::value * kAlignmentA) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * kAlignmentB) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpGammaBeta = CacheOpA; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, Operator, false, CacheOpA, CacheOpB>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using IteratorVarMean = -+ cutlass::transform::threadblock::PredicatedScaleBiasVectorIterator< -+ cutlass::MatrixShape<1, WarpShape::kN>, -+ ElementScaleBias, -+ LayoutScaleBias>; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using IteratorGammaBeta = -+ cutlass::transform::threadblock::PredicatedScaleBiasVectorAccessIterator< -+ cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, -+ LayoutScaleBias>; -+ -+ using SmemIteratorGammaBeta = -+ cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator< -+ cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, -+ LayoutScaleBias>; -+ -+ static int const kThreadCount = 32; -+ -+ // Warp-level iterators to load scale and bias vectors -+ using WarpIteratorGammaBeta = cutlass::gemm::warp::ScaleBiasTileIterator< -+ MatrixShape, ElementScaleBias, -+ LayoutScaleBias, MatrixShape, -+ typename MmaCore::MmaTensorOp::IteratorA::Base::Policy, kThreadCount, -+ MmaCore::WarpCount::kK>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaLayernormMainloopFusionMultistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, IteratorVarMean, IteratorGammaBeta, SmemIteratorGammaBeta, -+ CacheOpGammaBeta, -+ ElementAccumulator, layout::RowMajor, -+ typename MmaCore::MmaPolicy, WarpIteratorGammaBeta, Stages, SharedMemoryClear>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_planar_complex_multistage.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_planar_complex_multistage.h -new file mode 100644 -index 0000000..6915b20 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_planar_complex_multistage.h -@@ -0,0 +1,136 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Template for a multistage GEMM kernel. Does not compute batching or support split-K. -+*/ -+ -+#pragma once -+ -+#include "cutlass/arch/arch.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/gemm/threadblock/default_mma.h" -+#include "cutlass/gemm/threadblock/mma_planar_complex_multistage.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ /// Math operator tag (e.g. arch::OpMultiplyAdd) -+ typename Operator = arch::OpMultiplyAdd -+> -+struct DefaultMmaPlanarComplexMultistage { -+ -+ // Construct a planar complex variant from the real-valued variant -+ using RealMmaMultistage = typename DefaultMma< -+ ElementA_, -+ LayoutA_, -+ kAlignmentA, -+ ElementB_, -+ LayoutB_, -+ kAlignmentB, -+ ElementAccumulator_, -+ LayoutC_, -+ OperatorClass_, -+ ArchTag_, -+ ThreadblockShape_, -+ WarpShape_, -+ InstructionShape_, -+ Stages, -+ Operator -+ >::ThreadblockMma; -+ -+ using ThreadblockMma = MmaPlanarComplexMultistage< -+ ThreadblockShape_, -+ typename RealMmaMultistage::IteratorA, -+ typename RealMmaMultistage::SmemIteratorA, -+ cutlass::arch::CacheOperation::Global, -+ typename RealMmaMultistage::IteratorB, -+ typename RealMmaMultistage::SmemIteratorB, -+ cutlass::arch::CacheOperation::Global, -+ ElementAccumulator_, -+ LayoutC_, -+ typename RealMmaMultistage::Policy, -+ Stages, -+ TransformA, -+ TransformB -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_planar_complex_pipelined.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_planar_complex_pipelined.h -new file mode 100644 -index 0000000..a7ae5a4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_planar_complex_pipelined.h -@@ -0,0 +1,130 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+ -+#include "cutlass/gemm/warp/mma_planar_complex.h" -+#include "cutlass/gemm/threadblock/default_mma.h" -+#include "cutlass/gemm/threadblock/mma_planar_complex_pipelined.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ /// Math operator tag (e.g. arch::OpMultiplyAdd) -+ typename Operator = arch::OpMultiplyAdd -+> -+struct DefaultMmaPlanarComplexPipelined { -+ -+ // Construct a planar complex variant from the real-valued variant -+ using RealMma = typename DefaultMma< -+ ElementA_, -+ LayoutA_, -+ kAlignmentA, -+ ElementB_, -+ LayoutB_, -+ kAlignmentB, -+ ElementAccumulator_, -+ LayoutC_, -+ OperatorClass_, -+ ArchTag_, -+ ThreadblockShape_, -+ WarpShape_, -+ InstructionShape_, -+ Stages, -+ Operator -+ >::ThreadblockMma; -+ -+ using ThreadblockMma = MmaPlanarComplexPipelined< -+ ThreadblockShape_, -+ typename RealMma::IteratorA, -+ typename RealMma::SmemIteratorA, -+ typename RealMma::IteratorB, -+ typename RealMma::SmemIteratorB, -+ ElementAccumulator_, -+ LayoutC_, -+ typename RealMma::Policy, -+ Stages, -+ TransformA, -+ TransformB -+ >; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_softmax_mainloop_fusion.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_softmax_mainloop_fusion.h -new file mode 100644 -index 0000000..e8db4d8 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_softmax_mainloop_fusion.h -@@ -0,0 +1,160 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined softmax-GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/gemm/threadblock/default_mma_core.h" -+#include "cutlass/gemm/threadblock/mma_softmax_mainloop_fusion_multistage.h" -+#include "cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h" -+#include "cutlass/transform/threadblock/predicated_scale_bias_vector_access_iterator.h" -+#include "cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h" -+#include "cutlass/gemm/warp/scale_bias_tile_iterator.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for Scale/Bias vectors -+ typename ElementScaleBias, -+ /// Layout type for Scale/Bias vectors -+ typename LayoutScaleBias, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Operator class tag -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Whether problem has been transformed. This determines to which operand -+ /// the softmax is applied. -+ bool InternalTranspose, -+ /// Operation perfomed by GEMM -+ typename Operator, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor = false, -+ /// Use zfill or predicate for SM80 out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone -+ > -+struct DefaultMmaSoftmaxMainloopFusion { -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpA = -+ ((sizeof_bits::value * kAlignmentA) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * kAlignmentB) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpGammaBeta = CacheOpA; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, Operator, false, CacheOpA, CacheOpB>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; -+ -+ /// Define iterators over tiles from scale/bias vectors -+ using IteratorNormSum = -+ cutlass::transform::threadblock::PredicatedScaleBiasVectorIterator< -+ cutlass::MatrixShape<1, WarpShape::kN>, -+ ElementScaleBias, -+ LayoutScaleBias>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaSoftmaxMainloopFusionMultistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, IteratorNormSum, -+ ElementAccumulator, layout::RowMajor, -+ typename MmaCore::MmaPolicy, Stages, InternalTranspose, SharedMemoryClear>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_with_reduction.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_with_reduction.h -new file mode 100644 -index 0000000..bc6957a ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_mma_with_reduction.h -@@ -0,0 +1,141 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" -+#include "cutlass/gemm/threadblock/default_mma_core_with_reduction.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Layout type for C and D matrix operands -+ typename LayoutC, -+ /// Operator class tag -+ typename OperatorClass, -+ /// -+ bool ReduceKForA_, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation perfomed by GEMM -+ typename Operator, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor = false, -+ /// Use zfill or predicate for SM80 out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone -+ > -+struct DefaultMmaWithReduction { -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpA = -+ ((sizeof_bits::value * kAlignmentA) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * kAlignmentB) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaWithReductionCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ ReduceKForA_, Stages, Operator, false, CacheOpA, CacheOpB>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaWithReductionMultistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, -+ typename MmaCore::MmaPolicy, Stages, SharedMemoryClear>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex.h -new file mode 100644 -index 0000000..4bd3530 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex.h -@@ -0,0 +1,159 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Template for a multistage GEMM kernel. Does not compute batching or support split-K. -+*/ -+ -+#pragma once -+ -+#include "cutlass/arch/arch.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+#include "cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator = arch::OpMultiplyAddComplex, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor = false> -+struct DefaultMultistageMmaComplex; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator> -+struct DefaultMultistageMmaComplex { -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplexCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, -+ Stages, TransformA, TransformB, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, -+ typename MmaCore::MmaPolicy, Stages>; -+}; -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core.h -new file mode 100644 -index 0000000..79b4ec3 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core.h -@@ -0,0 +1,119 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines basic properties needed by CTA-level GEMMs assuming -+ expectations about data layout of the global memory fragments, data types, -+ and internal tile sizes. -+ -+ Partial specializations for threadblock::Mma operations targeting TensorOp -+ instructions. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/complex.h" -+ -+#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm80.h" -+ -+#include "cutlass/gemm/warp/mma_simt_policy.h" -+#include "cutlass/gemm/warp/mma_simt.h" -+#include "cutlass/gemm/warp/default_mma_tensor_op.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" -+ -+#include "cutlass/gemm/threadblock/default_mma_core.h" -+ -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+ -+#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h" -+#include "cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h" -+#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template defininng default matrix multiply operators inferred from -+/// threadblock tile size, global memory data layout, and target math -+/// instruction. -+template < -+ /// Shape of threadblock-scoped matrix multiply operator -+ typename Shape, -+ /// Shape of warp-level matrix multiply operator -+ typename WarpShape, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape, -+ /// Element data type of A operand -+ typename ElementA, -+ /// Layout of operand A -+ typename LayoutA, -+ /// Element data type of B operand -+ typename ElementB, -+ /// Layout of operand B -+ typename LayoutB, -+ /// Data type of accumulator -+ typename ElementC, -+ /// Layout of accumulator -+ typename LayoutC, -+ /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) -+ typename OperatorClass, -+ /// Number of stages -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator = arch::OpMultiplyAddComplex, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA = -+ cutlass::arch::CacheOperation::Global, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB = -+ cutlass::arch::CacheOperation::Global> -+struct DefaultMultistageMmaComplexCore; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h -new file mode 100644 -index 0000000..1a7065b ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h -@@ -0,0 +1,1808 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines basic properties needed by CTA-level GEMMs assuming -+ expectations about data layout of the global memory fragments, data types, -+ and internal tile sizes. -+ -+ Partial specializations for threadblock::Mma operations targeting TensorOp -+ instructions. -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm80.h" -+ -+#include "cutlass/gemm/warp/mma_simt_policy.h" -+#include "cutlass/gemm/warp/mma_simt.h" -+#include "cutlass/gemm/warp/default_mma_complex_tensor_op.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" -+ -+#include "cutlass/gemm/threadblock/default_multistage_mma_complex_core.h" -+ -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h" -+#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h" -+#include "cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h" -+#include "cutlass/gemm/threadblock/mma_multistage.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for complex double-precision -+/// -+/// A: column-major -+/// B: row-major -+/// Operator: arch::OpMultiplyAddComplex or arch::OpMultiplyGaussianComplex -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMultistageMmaComplexCore< -+ Shape_, WarpShape_, InstructionShape_, -+ complex, layout::ColumnMajor, -+ complex, layout::RowMajor, -+ complex, LayoutC_, -+ arch::OpClassTensorOp, -+ Stages, -+ TransformA, TransformB, -+ Operator_, -+ CacheOpA, CacheOpB> { -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = complex; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = complex; -+ using LayoutB = layout::RowMajor; -+ using ElementC = complex; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ using Operator = Operator_; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ static_assert(WarpCount::kCount > 1, -+ "This specialization requires at least two warps."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped 128 -+ static int const kAccessSizeInBits = 128; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ -+ using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 1, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 0, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ WarpShape, InstructionShape, -+ ElementA, SmemLayoutA, -+ ElementB, SmemLayoutB, -+ ElementC, LayoutC, -+ kTransformA, kTransformB, -+ Operator>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+ -+/// Partial specialization for complex double-precision -+/// -+/// A: column-major -+/// B: row-major -+/// Operator: arch::OpMultiplyAddComplex or arch::OpMultiplyGaussianComplex -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMultistageMmaComplexCore< -+ Shape_, WarpShape_, InstructionShape_, -+ complex, layout::ColumnMajor, -+ complex, layout::ColumnMajor, -+ complex, LayoutC_, -+ arch::OpClassTensorOp, -+ Stages, -+ TransformA, TransformB, -+ Operator_, -+ CacheOpA, CacheOpB> { -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = complex; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = complex; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = complex; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ using Operator = Operator_; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ static_assert(WarpCount::kCount > 1, -+ "This specialization requires at least two warps."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped 128 -+ static int const kAccessSizeInBits = 128; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 1, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 0, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ WarpShape, InstructionShape, -+ ElementA, SmemLayoutA, -+ ElementB, SmemLayoutB, -+ ElementC, LayoutC, -+ kTransformA, kTransformB, -+ Operator>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for complex double-precision -+/// -+/// A: row-major -+/// B: column-major -+/// Operator: arch::OpMultiplyAddComplex or arch::OpMultiplyGaussianComplex -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMultistageMmaComplexCore< -+ Shape_, WarpShape_, InstructionShape_, -+ complex, layout::RowMajor, -+ complex, layout::ColumnMajor, -+ complex, LayoutC_, -+ arch::OpClassTensorOp, -+ Stages, -+ TransformA, TransformB, -+ Operator_, -+ CacheOpA, CacheOpB> { -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = complex; -+ using LayoutA = layout::RowMajor; -+ using ElementB = complex; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = complex; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ using Operator = Operator_; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ static_assert(WarpCount::kCount > 1, -+ "This specialization requires at least two warps."); -+ -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped 128 -+ static int const kAccessSizeInBits = 128; -+ -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise128x4; -+ using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 1, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 0, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ WarpShape, InstructionShape, -+ ElementA, SmemLayoutA, -+ ElementB, SmemLayoutB, -+ ElementC, LayoutC, -+ kTransformA, kTransformB, -+ Operator>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+ -+/// Partial specialization for complex double-precision -+/// -+/// A: row-major -+/// B: row-major -+/// Operator: arch::OpMultiplyAddComplex or arch::OpMultiplyGaussianComplex -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMultistageMmaComplexCore< -+ Shape_, WarpShape_, InstructionShape_, -+ complex, layout::RowMajor, -+ complex, layout::RowMajor, -+ complex, LayoutC_, -+ arch::OpClassTensorOp, -+ Stages, -+ TransformA, TransformB, -+ Operator_, -+ CacheOpA, CacheOpB> { -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = InstructionShape_; -+ using ElementA = complex; -+ using LayoutA = layout::RowMajor; -+ using ElementB = complex; -+ using LayoutB = layout::RowMajor; -+ using ElementC = complex; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ using Operator = Operator_; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ static_assert(WarpCount::kCount > 1, -+ "This specialization requires at least two warps."); -+ -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped 128 -+ static int const kAccessSizeInBits = 128; -+ -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise128x4; -+ using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 1, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<8, 4>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 0, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ WarpShape, InstructionShape, -+ ElementA, SmemLayoutA, -+ ElementB, SmemLayoutB, -+ ElementC, LayoutC, -+ kTransformA, kTransformB, -+ Operator>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for complex floating-point -+/// -+/// A: column-major -+/// B: column-major -+/// Operator: arch::OpMultiplyAddComplex -+/// Math Instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex) -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMultistageMmaComplexCore< -+ Shape_, WarpShape_, GemmShape<16, 8, 8>, -+ complex, layout::ColumnMajor, -+ complex, layout::ColumnMajor, -+ complex, LayoutC_, -+ arch::OpClassTensorOp, -+ Stages, -+ TransformA, TransformB, -+ Operator_, -+ CacheOpA, CacheOpB> { -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<16, 8, 8>; -+ using ElementA = complex; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = complex; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = complex; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ using Operator = Operator_; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ static_assert(WarpCount::kCount > 1, -+ "This specialization requires at least two warps."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped -+ static int const kAccessSizeInBits = 64; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ -+ using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicand64bCrosswise; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpStripedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<16, 2>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 1, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<16, 2>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 0, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ WarpShape, InstructionShape, -+ ElementA, SmemLayoutA, -+ ElementB, SmemLayoutB, -+ ElementC, LayoutC, -+ kTransformA, kTransformB, -+ Operator>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+ -+/// Partial specialization for complex floating-point -+/// -+/// A: column-major -+/// B: row-major -+/// Operator: arch::OpMultiplyAddComplex -+/// Math Instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex) -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMultistageMmaComplexCore< -+ Shape_, WarpShape_, GemmShape<16, 8, 8>, -+ complex, layout::ColumnMajor, -+ complex, layout::RowMajor, -+ complex, LayoutC_, -+ arch::OpClassTensorOp, -+ Stages, -+ TransformA, TransformB, -+ Operator_, -+ CacheOpA, CacheOpB> { -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<16, 8, 8>; -+ using ElementA = complex; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = complex; -+ using LayoutB = layout::RowMajor; -+ using ElementC = complex; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ using Operator = Operator_; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ static_assert(WarpCount::kCount > 1, -+ "This specialization requires at least two warps."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped -+ static int const kAccessSizeInBits = 64; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ -+ using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpStripedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<16, 2>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 1, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpStripedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<16, 2>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 0, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ WarpShape, InstructionShape, -+ ElementA, SmemLayoutA, -+ ElementB, SmemLayoutB, -+ ElementC, LayoutC, -+ kTransformA, kTransformB, -+ Operator>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for complex floating-point -+/// -+/// A: row-major -+/// B: column-major -+/// Operator: arch::OpMultiplyAddComplex -+/// Math Instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex) -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMultistageMmaComplexCore< -+ Shape_, WarpShape_, GemmShape<16, 8, 8>, -+ complex, layout::RowMajor, -+ complex, layout::ColumnMajor, -+ complex, LayoutC_, -+ arch::OpClassTensorOp, -+ Stages, -+ TransformA, TransformB, -+ Operator_, -+ CacheOpA, CacheOpB> { -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<16, 8, 8>; -+ using ElementA = complex; -+ using LayoutA = layout::RowMajor; -+ using ElementB = complex; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = complex; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ using Operator = Operator_; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ static_assert(WarpCount::kCount > 1, -+ "This specialization requires at least two warps."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped -+ static int const kAccessSizeInBits = 64; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::RowMajorTensorOpMultiplicand64bCrosswise; -+ -+ using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicand64bCrosswise; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<16, 2>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 1, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<16, 2>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 0, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ WarpShape, InstructionShape, -+ ElementA, SmemLayoutA, -+ ElementB, SmemLayoutB, -+ ElementC, LayoutC, -+ kTransformA, kTransformB, -+ Operator>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for complex floating-point -+/// -+/// A: row-major -+/// B: row-major -+/// Operator: arch::OpMultiplyAddComplex -+/// Math Instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex) -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMultistageMmaComplexCore< -+ Shape_, WarpShape_, GemmShape<16, 8, 8>, -+ complex, layout::RowMajor, -+ complex, layout::RowMajor, -+ complex, LayoutC_, -+ arch::OpClassTensorOp, -+ Stages, -+ TransformA, TransformB, -+ Operator_, -+ CacheOpA, CacheOpB> { -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<16, 8, 8>; -+ using ElementA = complex; -+ using LayoutA = layout::RowMajor; -+ using ElementB = complex; -+ using LayoutB = layout::RowMajor; -+ using ElementC = complex; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ using Operator = Operator_; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ static_assert(WarpCount::kCount > 1, -+ "This specialization requires at least two warps."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of a threadblock-scoped -+ static int const kAccessSizeInBits = 64; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::RowMajorTensorOpMultiplicand64bCrosswise; -+ -+ using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<16, 2>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 1, -+ IteratorThreadMapA>; -+ -+ /// ThreadMap of iterator B -+ using IteratorThreadMapB = transform::PitchLinearWarpStripedThreadMap< -+ layout::PitchLinearShape, kThreads, -+ layout::PitchLinearShape<16, 2>, -+ kAccessSizeInBits / sizeof_bits::value>; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 0, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level tensor op -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ WarpShape, InstructionShape, -+ ElementA, SmemLayoutA, -+ ElementB, SmemLayoutB, -+ ElementC, LayoutC, -+ kTransformA, kTransformB, -+ Operator>::Type; -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy, -+ MatrixShape<0, 0>, WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for complex SIMT operation -+/// -+/// A: column-major -+/// B: column-major -+/// Operator: arch::OpMultiplyAddComplex or arch::OpMultiplyGaussianComplex -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ typename RealA, -+ typename RealB, -+ typename RealC, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMultistageMmaComplexCore< -+ Shape_, WarpShape_, GemmShape<1, 1, 1>, -+ complex, layout::ColumnMajor, -+ complex, layout::ColumnMajor, -+ complex, LayoutC_, -+ arch::OpClassSimt, -+ Stages, -+ TransformA, TransformB, -+ Operator_, -+ CacheOpA, CacheOpB> { -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<1, 1, 1>; -+ using ElementA = complex; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = complex; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = complex; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ using Operator = Operator_; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ static_assert(WarpCount::kCount > 1, -+ "This specialization requires at least two warps."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of access -+ static int const kAccessSizeInBits = sizeof_bits::value; -+ -+ /// No vectorized accesses -+ static int const kElementsPerAccess = 1; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajor; -+ -+ using SmemLayoutB = layout::RowMajor; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 0, -+ IteratorThreadMapA>; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Transpose the ThreadMap of iterator B -+ using SmemThreadMapB = transform::TransposePitchLinearThreadMapSimt; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 1, -+ SmemThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level op -+ static const int WarpNumThreadsM = 4; // TODO need to extract these from template data -+ static const int WarpNumThreadsN = 8; -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; -+ static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; -+ static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; -+ static const int numElementsA = 128 / sizeof_bits::value; -+ static const int numElementsB = 128 / sizeof_bits::value; -+ static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); -+ static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 1>; -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::RowMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy, /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ 1, /// 1 partition along K dimension -+ kTransformA, /// Transform for A -+ kTransformB /// Transform for B -+ >; /// Used for partial specialization -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaWarpSimt, -+ MatrixShape<0, 0>, -+ MatrixShape<0, Shape::kK / 32>, -+ WarpCount::kK>; -+}; -+ -+/// Partial specialization for complex SIMT operation -+/// -+/// A: column-major -+/// B: row-major -+/// Operator: arch::OpMultiplyAddComplex or arch::OpMultiplyGaussianComplex -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ typename RealA, -+ typename RealB, -+ typename RealC, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMultistageMmaComplexCore< -+ Shape_, WarpShape_, GemmShape<1, 1, 1>, -+ complex, layout::ColumnMajor, -+ complex, layout::RowMajor, -+ complex, LayoutC_, -+ arch::OpClassSimt, -+ Stages, -+ TransformA, TransformB, -+ Operator_, -+ CacheOpA, CacheOpB> { -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<1, 1, 1>; -+ using ElementA = complex; -+ using LayoutA = layout::ColumnMajor; -+ using ElementB = complex; -+ using LayoutB = layout::RowMajor; -+ using ElementC = complex; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ using Operator = Operator_; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ static_assert(WarpCount::kCount > 1, -+ "This specialization requires at least two warps."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of access -+ static int const kAccessSizeInBits = sizeof_bits::value; -+ -+ /// No vectorized accesses -+ static int const kElementsPerAccess = 1; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajor; -+ -+ using SmemLayoutB = layout::RowMajor; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 0, -+ IteratorThreadMapA>; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 1, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level op -+ static const int WarpNumThreadsM = 4; // TODO need to extract these from template data -+ static const int WarpNumThreadsN = 8; -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; -+ static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; -+ static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; -+ static const int numElementsA = 128 / sizeof_bits::value; -+ static const int numElementsB = 128 / sizeof_bits::value; -+ static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); -+ static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 1>; -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::RowMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy, /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ 1, /// 1 partition along K dimension -+ kTransformA, /// Transform for A -+ kTransformB /// Transform for B -+ >; /// Used for partial specialization -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaWarpSimt, -+ MatrixShape<0, 0>, -+ MatrixShape<0, 0>, // or Shape::kK / 32 -+ WarpCount::kK>; -+}; -+ -+/// Partial specialization for complex SIMT operation -+/// -+/// A: row-major -+/// B: column-major -+/// Operator: arch::OpMultiplyAddComplex or arch::OpMultiplyGaussianComplex -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ typename RealA, -+ typename RealB, -+ typename RealC, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMultistageMmaComplexCore< -+ Shape_, WarpShape_, GemmShape<1, 1, 1>, -+ complex, layout::RowMajor, -+ complex, layout::ColumnMajor, -+ complex, LayoutC_, -+ arch::OpClassSimt, -+ Stages, -+ TransformA, TransformB, -+ Operator_, -+ CacheOpA, CacheOpB> { -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<1, 1, 1>; -+ using ElementA = complex; -+ using LayoutA = layout::RowMajor; -+ using ElementB = complex; -+ using LayoutB = layout::ColumnMajor; -+ using ElementC = complex; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ using Operator = Operator_; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ static_assert(WarpCount::kCount > 1, -+ "This specialization requires at least two warps."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of access -+ static int const kAccessSizeInBits = sizeof_bits::value; -+ -+ /// No vectorized accesses -+ static int const kElementsPerAccess = 1; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajor; -+ -+ using SmemLayoutB = layout::RowMajor; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Transpose the ThreadMap of iterator A -+ using SmemThreadMapA = transform::TransposePitchLinearThreadMapSimt; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 0, -+ SmemThreadMapA>; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Transpose the ThreadMap of iterator B -+ using SmemThreadMapB = transform::TransposePitchLinearThreadMapSimt; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 1, -+ SmemThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level op -+ static const int WarpNumThreadsM = 4; // TODO need to extract these from template data -+ static const int WarpNumThreadsN = 8; -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; -+ static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; -+ static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; -+ static const int numElementsA = 128 / sizeof_bits::value; -+ static const int numElementsB = 128 / sizeof_bits::value; -+ static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); -+ static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 1>; -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::RowMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy, /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ 1, /// 1 partition along K dimension -+ kTransformA, /// Transform for A -+ kTransformB /// Transform for B -+ >; /// Used for partial specialization -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaWarpSimt, -+ MatrixShape, -+ MatrixShape<0, Shape::kK / 32>, -+ WarpCount::kK>; -+}; -+ -+/// Partial specialization for complex SIMT operation -+/// -+/// A: row-major -+/// B: row-major -+/// Operator: arch::OpMultiplyAddComplex or arch::OpMultiplyGaussianComplex -+/// -+/// This uses the default warp-level operator given tile sizes -+template < -+ /// Shape of threadblock-scoped matrix multiply operator (concept: -+ /// GemmShape) -+ typename Shape_, -+ /// Shape of warp-level matrix multiply operator (concept: GemmShape) -+ typename WarpShape_, -+ typename RealA, -+ typename RealB, -+ typename RealC, -+ /// Layout of accumulator -+ typename LayoutC_, -+ /// Number of stages -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator_, -+ /// Cache operation of operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Cache operation of operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB> -+struct DefaultMultistageMmaComplexCore< -+ Shape_, WarpShape_, GemmShape<1, 1, 1>, -+ complex, layout::RowMajor, -+ complex, layout::RowMajor, -+ complex, LayoutC_, -+ arch::OpClassSimt, -+ Stages, -+ TransformA, TransformB, -+ Operator_, -+ CacheOpA, CacheOpB> { -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using InstructionShape = GemmShape<1, 1, 1>; -+ using ElementA = complex; -+ using LayoutA = layout::RowMajor; -+ using ElementB = complex; -+ using LayoutB = layout::RowMajor; -+ using ElementC = complex; -+ using LayoutC = LayoutC_; -+ static int const kStages = Stages; -+ static ComplexTransform const kTransformA = TransformA; -+ static ComplexTransform const kTransformB = TransformB; -+ using Operator = Operator_; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always; -+ -+ /// Number of warps present -+ using WarpCount = GemmShape; -+ -+ // Divisility requirements -+ static_assert( -+ !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), -+ "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); -+ -+ static_assert(WarpCount::kCount > 1, -+ "This specialization requires at least two warps."); -+ -+ /// Number of threads per warp -+ static int const kWarpSize = warp::WarpSize::value; -+ -+ /// Number of threads total -+ static int const kThreads = WarpCount::kCount * kWarpSize; -+ -+ /// Size of access -+ static int const kAccessSizeInBits = sizeof_bits::value; -+ -+ /// No vectorized accesses -+ static int const kElementsPerAccess = 1; -+ -+ // -+ // Shared memory layouts -+ // -+ -+ using SmemLayoutA = layout::ColumnMajor; -+ -+ using SmemLayoutB = layout::RowMajor; -+ -+ // -+ // Iterators to write to shared memory -+ // -+ -+ /// ThreadMap of iterator A -+ using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Transpose the ThreadMap of iterator A -+ using SmemThreadMapA = transform::TransposePitchLinearThreadMapSimt; -+ -+ /// Shared memory iterator to A operand -+ using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementA, SmemLayoutA, 0, -+ SmemThreadMapA>; -+ -+ /// Policy of iterator B -+ using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< -+ layout::PitchLinearShape, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ /// Shared memory iterator to B operand -+ using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< -+ MatrixShape, ElementB, SmemLayoutB, 1, -+ IteratorThreadMapB>; -+ -+ // -+ // Warp-level matrix multiply operator -+ // -+ -+ // Define the warp-level op -+ static const int WarpNumThreadsM = 4; // TODO need to extract these from template data -+ static const int WarpNumThreadsN = 8; -+ static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), -+ "WarpShape must be divisible by ThreadTile shape."); -+ static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; -+ static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; -+ static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; -+ static const int numElementsA = 128 / sizeof_bits::value; -+ static const int numElementsB = 128 / sizeof_bits::value; -+ static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); -+ static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); -+ // these should have max of thread tile also -+ using LaneMmaShape = cutlass::gemm::GemmShape< -+ LaneM, -+ LaneN, -+ 1>; -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape, // WarpShape -+ cutlass::layout::RowMajorInterleaved, // LaneLayout -+ LaneMmaShape -+ >; -+ -+ using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 -+ ElementA, /// Data type of A elements -+ SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) -+ ElementB, /// Data type of B elements -+ SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) -+ ElementC, /// Element type of C matrix -+ LayoutC, /// Layout of C matrix (concept: MatrixLayout) -+ Policy, /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ 1, /// 1 partition along K dimension -+ kTransformA, /// Transform for A -+ kTransformB /// Transform for B -+ >; /// Used for partial specialization -+ -+ /// Policy used to define MmaPipelined -+ using MmaPolicy = MmaPolicy< -+ MmaWarpSimt, -+ MatrixShape, -+ MatrixShape<0, 0>, // or Shape::kK / 32 -+ WarpCount::kK>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_multistage_trmm_complex.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_multistage_trmm_complex.h -new file mode 100644 -index 0000000..367869e ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_multistage_trmm_complex.h -@@ -0,0 +1,556 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Template for a multistage GEMM kernel. Does not compute batching or support split-K. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator_triangular_matrix.h" -+#include "cutlass/gemm/threadblock/mma_blas3_multistage.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Side Mode for the kernel -+ SideMode kSideMode, -+ /// Fill Mode for the triangular matrix -+ FillMode kFillMode, -+ /// Diag Type for the triangular matrix -+ DiagType kDiagType, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator = arch::OpMultiplyAddComplex, -+ /// Blas3 computation mode -+ BlasMode BlasMode_ = BlasMode::kTriangular, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor = false> -+struct DefaultMultistageTrmmComplex; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Side Mode for the kernel -+ SideMode kSideMode, -+ /// Fill Mode for the triangular matrix -+ FillMode kFillMode, -+ /// Diag Type for the triangular matrix -+ DiagType kDiagType, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator> -+struct DefaultMultistageTrmmComplex { -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplexCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, -+ Stages, TransformA, TransformB, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, -+ kSideMode, kFillMode, kDiagType, -+ AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, -+ kSideMode, FillMode::kFull, DiagType::kInvalid, -+ AccessTypeB>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, -+ typename MmaCore::MmaPolicy, Stages, SharedMemoryClearOption::kZfill>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output and right-side mode -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Fill Mode for the triangular matrix -+ FillMode kFillMode, -+ /// Diag Type for the triangular matrix -+ DiagType kDiagType, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator> -+struct DefaultMultistageTrmmComplex { -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplexCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, -+ Stages, TransformA, TransformB, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, -+ SideMode::kRight, FillMode::kFull, DiagType::kInvalid, -+ AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, -+ SideMode::kRight, kFillMode, kDiagType, -+ AccessTypeB>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, -+ typename MmaCore::MmaPolicy, Stages, SharedMemoryClearOption::kZfill>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output with unit diagonal -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Side Mode for the kernel -+ SideMode kSideMode, -+ /// Fill Mode for the triangular matrix -+ FillMode kFillMode, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator> -+struct DefaultMultistageTrmmComplex { -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplexCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, -+ Stages, TransformA, TransformB, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, -+ kSideMode, kFillMode, DiagType::kUnit, -+ AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, -+ kSideMode, FillMode::kFull, DiagType::kInvalid, -+ AccessTypeB>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaBlas3Multistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, -+ typename MmaCore::MmaPolicy, Stages, SharedMemoryClearOption::kZfill>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output and right-side mode, unit diagonal -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Fill Mode for the triangular matrix -+ FillMode kFillMode, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator> -+struct DefaultMultistageTrmmComplex { -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplexCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, -+ Stages, TransformA, TransformB, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, -+ SideMode::kRight, FillMode::kFull, DiagType::kInvalid, -+ AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, -+ SideMode::kRight, kFillMode, DiagType::kUnit, -+ AccessTypeB>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaBlas3Multistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, -+ typename MmaCore::MmaPolicy, Stages, SharedMemoryClearOption::kZfill>; -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output (for TRMM where diagonal imag part is ignored - used by HEMM) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Side Mode for the kernel -+ SideMode kSideMode, -+ /// Fill Mode for the triangular matrix -+ FillMode kFillMode, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator> -+struct DefaultMultistageTrmmComplex { -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplexCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, -+ Stages, TransformA, TransformB, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ // PredicatedTileAccessIteratorTriangularMatrix only tracks diagonal elements, -+ // when DiagType is kUnit -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, -+ kSideMode, kFillMode, DiagType::kUnit, -+ AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, -+ kSideMode, FillMode::kFull, DiagType::kInvalid, -+ AccessTypeB>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaBlas3Multistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, -+ typename MmaCore::MmaPolicy, Stages, SharedMemoryClearOption::kZfill, -+ BlasMode::kHermitian>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output and right-side mode (for TRMM where diagonal imag part is ignored - used by HEMM) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Fill Mode for the triangular matrix -+ FillMode kFillMode, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename OperatorClass, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator> -+struct DefaultMultistageTrmmComplex { -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplexCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, -+ Stages, TransformA, TransformB, Operator>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, -+ SideMode::kRight, FillMode::kFull, DiagType::kInvalid, -+ AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ // PredicatedTileAccessIteratorTriangularMatrix only tracks diagonal elements, -+ // when DiagType is kUnit -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, -+ SideMode::kRight, kFillMode, DiagType::kUnit, -+ AccessTypeB>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaBlas3Multistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, -+ typename MmaCore::MmaPolicy, Stages, SharedMemoryClearOption::kZfill, -+ BlasMode::kHermitian>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_sparse_mma.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_sparse_mma.h -new file mode 100644 -index 0000000..5faa76b ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_sparse_mma.h -@@ -0,0 +1,196 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/arch/wmma.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h" -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+#include "cutlass/gemm/threadblock/default_mma_core_wmma.h" -+#endif //CUTLASS_ARCH_WMMA_ENABLED -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation perfomed by GEMM -+ typename Operator, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor = false -+ > -+struct DefaultSparseMma; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output (OperatorClass TensorOp) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Operation perfomed by GEMM -+ typename Operator -+ > -+struct DefaultSparseMma { -+ static cutlass::arch::CacheOperation::Kind const CacheOpA = -+ ((sizeof_bits::value * kAlignmentA) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * kAlignmentB) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, Operator, false, CacheOpA, CacheOpB>; -+ -+ static int const kSparse = MmaCore::kSparse; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; -+ -+ // Define iterators over tiles from the E operand -+ using ElementE = typename MmaCore::ElementE; -+ using LayoutE = typename MmaCore::GmemLayoutE; -+ using ThreadMapE = typename MmaCore::IteratorThreadMapE; -+ using AccessTypeE = -+ cutlass::Array::value>; -+ using IteratorE = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementE, LayoutE, 1, ThreadMapE, AccessTypeE>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::SparseMmaMultistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, -+ IteratorE, typename MmaCore::SmemIteratorE, MmaCore::kCacheOpE, -+ typename MmaCore::MmaPolicy, Stages>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_trmm.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_trmm.h -new file mode 100644 -index 0000000..8c13d17 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/default_trmm.h -@@ -0,0 +1,445 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+// -+/*! \file -+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/arch/wmma.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator_triangular_matrix.h" -+#include "cutlass/gemm/threadblock/mma_blas3_multistage.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" -+#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+#include "cutlass/gemm/threadblock/default_mma_core_wmma.h" -+#endif //CUTLASS_ARCH_WMMA_ENABLED -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Element type for A matrix operand -+ typename ElementA_, -+ /// Layout type for A matrix operand -+ typename LayoutA_, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB_, -+ /// Layout type for B matrix operand -+ typename LayoutB_, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Side Mode for the kernel -+ SideMode kSideMode, -+ /// Fill Mode for the triangular matrix -+ FillMode kFillMode, -+ /// Diag Type for the triangular matrix -+ DiagType kDiagType, -+ /// Element type for internal accumulation -+ typename ElementAccumulator_, -+ /// Layout type for C and D matrix operands -+ typename LayoutC_, -+ /// Operator class tag -+ typename OperatorClass_, -+ /// Tag indicating architecture to tune for -+ typename ArchTag_, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape_, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape_, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape_, -+ /// Number of stages used in the pipelined mainloop -+ int Stages, -+ /// Operation perfomed by GEMM -+ typename Operator, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor = false -+ > -+struct DefaultTrmm; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output (OperatorClass TensorOp) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Side Mode for the kernel -+ SideMode kSideMode, -+ /// Fill Mode for the triangular matrix -+ FillMode kFillMode, -+ /// Diag Type for the triangular matrix -+ DiagType kDiagType, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Operation perfomed by GEMM -+ typename Operator -+ > -+struct DefaultTrmm { -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpA = -+ ((sizeof_bits::value * kAlignmentA) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * kAlignmentB) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, Operator, false, CacheOpA, CacheOpB>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, kSideMode, kFillMode, kDiagType, AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, kSideMode, FillMode::kFull, DiagType::kInvalid, AccessTypeB>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, -+ typename MmaCore::MmaPolicy, Stages, SharedMemoryClearOption::kZfill>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output, right side mode (OperatorClass TensorOp) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Fill Mode for the triangular matrix -+ FillMode kFillMode, -+ /// Diag Type for the triangular matrix -+ DiagType kDiagType, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Operation perfomed by GEMM -+ typename Operator -+ > -+struct DefaultTrmm { -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpA = -+ ((sizeof_bits::value * kAlignmentA) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * kAlignmentB) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, Operator, false, CacheOpA, CacheOpB>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, SideMode::kRight, FillMode::kFull, DiagType::kInvalid, AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, SideMode::kRight, kFillMode, kDiagType, AccessTypeB>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, -+ typename MmaCore::MmaPolicy, Stages, SharedMemoryClearOption::kZfill>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output with unit diagonal (OperatorClass TensorOp) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Side Mode for the kernel -+ SideMode kSideMode, -+ /// Fill Mode for the triangular matrix -+ FillMode kFillMode, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Operation perfomed by GEMM -+ typename Operator -+ > -+struct DefaultTrmm { -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpA = -+ ((sizeof_bits::value * kAlignmentA) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * kAlignmentB) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, Operator, false, CacheOpA, CacheOpB>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, kSideMode, kFillMode, DiagType::kUnit, AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, kSideMode, FillMode::kFull, DiagType::kInvalid, AccessTypeB>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaBlas3Multistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, -+ typename MmaCore::MmaPolicy, Stages, SharedMemoryClearOption::kZfill>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for row-major output, right side mode, unit diagonal (OperatorClass TensorOp) -+template < -+ /// Element type for A matrix operand -+ typename ElementA, -+ /// Layout type for A matrix operand -+ typename LayoutA, -+ /// Access granularity of A matrix in units of elements -+ int kAlignmentA, -+ /// Element type for B matrix operand -+ typename ElementB, -+ /// Layout type for B matrix operand -+ typename LayoutB, -+ /// Access granularity of B matrix in units of elements -+ int kAlignmentB, -+ /// Fill Mode for the triangular matrix -+ FillMode kFillMode, -+ /// Element type for internal accumulation -+ typename ElementAccumulator, -+ /// Tag indicating architecture to tune for -+ typename ArchTag, -+ /// Threadblock-level tile size (concept: GemmShape) -+ typename ThreadblockShape, -+ /// Warp-level tile size (concept: GemmShape) -+ typename WarpShape, -+ /// Instruction-level tile size (concept: GemmShape) -+ typename InstructionShape, -+ /// Number of stages used in the multistage mainloop -+ int Stages, -+ /// Operation perfomed by GEMM -+ typename Operator -+ > -+struct DefaultTrmm { -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpA = -+ ((sizeof_bits::value * kAlignmentA) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ ((sizeof_bits::value * kAlignmentB) == 128) -+ ? cutlass::arch::CacheOperation::Global -+ : cutlass::arch::CacheOperation::Always; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, -+ Stages, Operator, false, CacheOpA, CacheOpB>; -+ -+ // Define iterators over tiles from the A operand -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using AccessTypeA = cutlass::Array; -+ -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, SideMode::kRight, FillMode::kFull, DiagType::kInvalid, AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeB = cutlass::Array; -+ -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIteratorTriangularMatrix< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, SideMode::kRight, kFillMode, DiagType::kUnit, AccessTypeB>; -+ -+ // Define the threadblock-scoped multistage matrix multiply -+ using ThreadblockMma = cutlass::gemm::threadblock::MmaBlas3Multistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, -+ MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, -+ typename MmaCore::MmaPolicy, Stages, SharedMemoryClearOption::kZfill>; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/ell_mma_multistage.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/ell_mma_multistage.h -new file mode 100644 -index 0000000..3f73b9e ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/ell_mma_multistage.h -@@ -0,0 +1,642 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a multistage threadblock-scoped Blocked-Ell MMA. -+*/ -+ -+#pragma once -+ -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/threadblock/mma_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Cache operation for operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class EllMmaMultistage : -+ public MmaBase { -+public: -+ ///< Base class -+ using Base = MmaBase; -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ ///< Iterates over tiles of A operand in global memory -+ using IteratorA = IteratorA_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB = IteratorB_; -+ ///< Data type of accumulator matrix -+ using ElementC = ElementC_; -+ ///< Layout of accumulator matrix -+ using LayoutC = LayoutC_; -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ using EllIterator = typename cutlass::transform::threadblock::ell::Iterator; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of accumulator tile -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Minimum architecture is Sm80 to support cp.async -+ using ArchTag = arch::Sm80; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = Operator::kTransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = Operator::kTransformB; -+ -+ /// Internal structure exposed for introspection. -+ struct Detail { -+ -+ static_assert(Base::kWarpGemmIterations > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ -+ /// Number of cp.async instructions to load one stage of operand A -+ static int const AsyncCopyIterationsPerStageA = -+ IteratorA::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const AsyncCopyIterationsPerStageB = -+ IteratorB::ThreadMap::Iterations::kCount; -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Number of cp.async instructions to load on group of operand A -+ static int const kAccessesPerGroupA = -+ (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB = -+ (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ }; -+ -+ private: -+ -+ using WarpLoadedFragmentA = typename Operator::FragmentA; -+ using WarpLoadedFragmentB = typename Operator::FragmentB; -+ using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; -+ using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ EllMmaMultistage( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), -+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) -+ { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ } -+ -+ template -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B, EllIterator &ell_iter, -+ int group_start_A = 0, int group_start_B = 0) { -+ iterator_A.set_iteration_index(group_start_A * -+ IteratorA::kAccessesPerVector); -+ this->smem_iterator_A_.set_iteration_index(group_start_A); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { -+ if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_A.get(); -+ bool is_valid = iterator_A.valid(); -+ -+ if (!is_A_sparse){ -+ if (is_offset_constant){ -+ auto ell_offset = ell_iter.get_offset_fast(); -+ is_valid = is_valid && (ell_offset >= 0); -+ gmem_ptr += ell_offset * sizeof(IteratorA::Element) / kSrcBytes; -+ } else { -+ int k_offset = iterator_A.get_k(); -+ auto ell_offset = ell_iter.get_offset(k_offset); -+ is_valid = is_valid && (ell_offset >= 0); -+ gmem_ptr += (ell_offset * sizeof(IteratorA::Element)) / kSrcBytes; -+ } -+ } -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, is_valid); -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ } -+ -+ iterator_B.set_iteration_index(group_start_B * -+ IteratorB::kAccessesPerVector); -+ this->smem_iterator_B_.set_iteration_index(group_start_B); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { -+ if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_B.get(); -+ bool is_valid = iterator_B.valid(); -+ -+ if (is_A_sparse){ -+ if (is_offset_constant){ -+ auto ell_offset = ell_iter.get_offset_fast(); -+ is_valid = is_valid && (ell_offset >= 0); -+ gmem_ptr += ell_offset * sizeof(IteratorB::Element) / kSrcBytes; -+ } else { -+ int k_offset = iterator_B.get_k(); -+ auto ell_offset = ell_iter.get_offset(k_offset); -+ is_valid = is_valid && (ell_offset >= 0); -+ gmem_ptr += ( ell_offset * sizeof(IteratorB::Element)) / kSrcBytes; -+ } -+ } -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, is_valid); -+ -+ ++iterator_B; -+ } -+ ++this->smem_iterator_B_; -+ } -+ } -+ } -+ -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ template -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations, -+ ///< destination accumulator tile -+ FragmentC &accum, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B, -+ ///< initial value of accumulator -+ FragmentC const &src_accum, -+ EllIterator &ell_iterator -+ ) { -+ // -+ // Prologue -+ // -+ -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; -+ ++stage, --gemm_k_iterations) { -+ -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ -+ iterator_A.set_iteration_index(0); -+ this->smem_iterator_A_.set_iteration_index(0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ auto gmem_ptr = iterator_A.get(); -+ bool is_valid = iterator_A.valid(); -+ -+ if (!is_A_sparse){ -+ if (is_offset_constant){ -+ auto ell_offset = ell_iterator.get_offset_fast(); -+ is_valid = is_valid && (ell_offset >= 0); -+ gmem_ptr += ell_offset * sizeof(IteratorA::Element) / kSrcBytes; -+ } else { -+ int k_offset = iterator_A.get_k(); -+ auto ell_offset = ell_iterator.get_offset(k_offset); -+ is_valid = is_valid && (ell_offset >= 0); -+ gmem_ptr += (ell_offset * sizeof(IteratorA::Element)) / kSrcBytes; -+ } -+ } -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, is_valid); -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ -+ iterator_B.set_iteration_index(0); -+ this->smem_iterator_B_.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ auto gmem_ptr = iterator_B.get(); -+ bool is_valid = iterator_B.valid(); -+ -+ if (is_A_sparse){ -+ if (is_offset_constant){ -+ auto ell_offset = ell_iterator.get_offset_fast(); -+ is_valid = is_valid && (ell_offset >= 0); -+ gmem_ptr += ell_offset * sizeof(IteratorB::Element) / kSrcBytes; -+ } else { -+ int k_offset = iterator_B.get_k(); -+ auto ell_offset = ell_iterator.get_offset(k_offset); -+ is_valid = is_valid && (ell_offset >= 0); -+ gmem_ptr += ( ell_offset * sizeof(IteratorB::Element)) / kSrcBytes; -+ } -+ } -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, is_valid); -+ -+ ++iterator_B; -+ } -+ -+ ++this->smem_iterator_B_; -+ } -+ -+ // Move to the next stage -+ iterator_A.add_tile_offset({0, 1}); -+ iterator_B.add_tile_offset({1, 0}); -+ ++ell_iterator; -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Defines the boundary of a stage of cp.async. -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // Perform accumulation in the 'd' output operand -+ accum = src_accum; -+ -+ // Waits until kStages-2 stages have committed. -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA warp_loaded_frag_A[2]; -+ WarpLoadedFragmentB warp_loaded_frag_B[2]; -+ WarpTransformedFragmentA warp_transformed_frag_A[2]; -+ WarpTransformedFragmentB warp_transformed_frag_B[2]; -+ -+ Operator warp_mma; -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ -+ if (is_A_sparse){ -+ iterator_A.ell_add_mask(ell_iterator.get_blocksize()); -+ } -+ else { -+ iterator_B.ell_add_mask(ell_iterator.get_blocksize()); -+ } -+ -+ int smem_write_stage_idx = Base::kStages - 1; -+ int smem_read_stage_idx = 0; -+ -+ warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], -+ warp_loaded_frag_A[0], warp_loaded_frag_B[0]); -+ -+ // tf32x3 kernels use staging accumulation. warp_mma uses a temporary -+ // accumulator and this temporary accumulator is added to the final -+ // accumulator once in every mainloop iteration. -+ plus plus_accum; -+ -+ FragmentC tmp_accum; -+ -+ if (platform::is_same::value -+ || platform::is_same::value) { -+ -+ tmp_accum.clear(); -+ } -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > (-Base::kStages + 1);) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; -+ ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. -+ -+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ if (warp_mma_k > 0) -+ warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ warp_loaded_frag_A[warp_mma_k % 2], -+ warp_loaded_frag_B[warp_mma_k % 2]); -+ -+ if (platform::is_same::value -+ || platform::is_same::value) { -+ -+ warp_mma( -+ tmp_accum, -+ warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ tmp_accum -+ ); -+ -+ if (warp_mma_k == 0) { -+ accum = plus_accum(accum, tmp_accum); -+ tmp_accum.clear(); -+ } -+ } else { -+ warp_mma( -+ accum, -+ warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ accum -+ ); -+ } -+ -+ // Issue global->shared copies for the this stage -+ if (warp_mma_k < Base::kWarpGemmIterations - 1) { -+ int group_start_iteration_A, group_start_iteration_B; -+ -+ group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; -+ -+ copy_tiles_and_advance( -+ iterator_A, iterator_B, ell_iterator, group_start_iteration_A, -+ group_start_iteration_B); -+ } -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations) { -+ int group_start_iteration_A, group_start_iteration_B; -+ group_start_iteration_A = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB; -+ -+ copy_tiles_and_advance( -+ iterator_A, iterator_B, ell_iterator, group_start_iteration_A, -+ group_start_iteration_B); -+ -+ // Inserts a memory fence between stages of cp.async instructions. -+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages have committed. -+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_A.add_tile_offset({0, 1}); -+ iterator_B.add_tile_offset({1, 0}); -+ ++ell_iterator; -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ --gemm_k_iterations; -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ } -+ -+ // Do any conversions feeding the first stage at the end of the loop so -+ // we can start right away on mma instructions -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations) -+ warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ } -+ -+ } -+ -+ if (platform::is_same::value -+ || platform::is_same::value) { -+ accum = plus_accum(accum, tmp_accum); -+ } -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/ell_mma_pipelined.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/ell_mma_pipelined.h -new file mode 100644 -index 0000000..10ff6df ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/ell_mma_pipelined.h -@@ -0,0 +1,376 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped Blocked-Ell MMA. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/numeric_conversion.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/threadblock/mma_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Transformation applied to A operand -+ typename TransformA_ = NumericArrayConverter< -+ typename SmemIteratorA_::Element, -+ typename IteratorA_::Element, -+ IteratorA_::Fragment::kElements>, -+ /// -+ /// Transformation applied to B operand -+ typename TransformB_ = NumericArrayConverter< -+ typename SmemIteratorB_::Element, -+ typename IteratorB_::Element, -+ IteratorB_::Fragment::kElements>, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+class EllMmaPipelined : public MmaBase { -+public: -+ -+ ///< Base class -+ using Base = MmaBase; -+ -+ using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory -+ using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory -+ using ElementC = ElementC_; ///< Data type of accumulator matrix -+ using LayoutC = LayoutC_; ///< Layout of accumulator matrix -+ using Policy = Policy_; ///< Policy describing tuning details -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ -+ using TransformA = TransformA_; -+ using TransformB = TransformB_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of operand A loaded from global memory -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Fragment of operand B loaded from global memory -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Fragment of accumulator tile -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Obtain the arch tag from the warp-level operator -+ using ArchTag = typename Policy::Operator::ArchTag; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = Operator::kTransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = Operator::kTransformB; -+ -+ // staticaly assert kStages for EllMmaPipelined is two (Double-buffered pipeline) -+ static_assert((Base::kStages==2), "EllMmaPipelined requires kStages set to value 2"); -+ -+private: -+ -+ using WarpFragmentA = typename Operator::FragmentA; -+ using WarpFragmentB = typename Operator::FragmentB; -+ -+protected: -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+ using EllIterator = typename cutlass::transform::threadblock::ell::Iterator; -+ -+public: -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ EllMmaPipelined( -+ typename Base::SharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ int thread_idx, ///< ID within the threadblock -+ int warp_idx, ///< ID of warp -+ int lane_idx ///< ID of each thread within a warp -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), -+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { -+ -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ template -+ CUTLASS_DEVICE -+ void operator()( -+ int gemm_k_iterations, ///< number of iterations of the mainloop -+ FragmentC &accum, ///< destination accumulator tile -+ IteratorA iterator_A, ///< iterator over A operand in global memory -+ IteratorB iterator_B, ///< iterator over B operand in global memory -+ FragmentC const &src_accum, ///< source accumulator tile -+ EllIterator &ell_iterator, -+ TransformA transform_A = TransformA(), ///< transformation applied to A fragment -+ TransformB transform_B = TransformB()) { ///< transformation applied to B fragment -+ -+ // -+ // Prologue -+ // -+ -+ // Perform accumulation in the 'd' output operand -+ accum = src_accum; -+ -+ FragmentA tb_frag_A; -+ FragmentB tb_frag_B; -+ -+ tb_frag_A.clear(); -+ tb_frag_B.clear(); -+ -+ // load sparse matrix -+ if (is_A_sparse){ -+ iterator_A.load(tb_frag_A); -+ } else { -+ iterator_B.load(tb_frag_B); -+ } -+ -+ // load dense matrix -+ if (is_offset_constant){ -+ if (is_A_sparse){ -+ iterator_B.load_with_ell_index_fast(tb_frag_B, ell_iterator); -+ } else { -+ iterator_A.load_with_ell_index_fast(tb_frag_A, ell_iterator); -+ } -+ } else { -+ if (is_A_sparse){ -+ iterator_B.load_with_ell_index(tb_frag_B, ell_iterator); -+ } else { -+ iterator_A.load_with_ell_index(tb_frag_A, ell_iterator); -+ } -+ } -+ -+ ++iterator_A; -+ ++iterator_B; -+ ++ell_iterator; -+ -+ this->smem_iterator_A_.store(transform_A(tb_frag_A)); -+ this->smem_iterator_B_.store(transform_B(tb_frag_B)); -+ -+ ++this->smem_iterator_A_; -+ ++this->smem_iterator_B_; -+ -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math instructions -+ WarpFragmentA warp_frag_A[2]; -+ WarpFragmentB warp_frag_B[2]; -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A_.load(warp_frag_A[0]); -+ this->warp_tile_iterator_B_.load(warp_frag_B[0]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ Operator warp_mma; -+ -+ int smem_write_stage_idx = 1; -+ -+ // Avoid reading out of bounds -+ iterator_A.clear_mask(gemm_k_iterations <= 1); -+ iterator_B.clear_mask(gemm_k_iterations <= 1); -+ -+ if (is_A_sparse){ -+ iterator_A.ell_add_mask(ell_iterator.get_blocksize()); -+ } -+ else { -+ iterator_B.ell_add_mask(ell_iterator.get_blocksize()); -+ } -+ -+ // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing -+ // shared memory loads (which have the tighest latency requirement). -+ -+ // -+ // Mainloop -+ // -+ -+ // Note: The main loop does not support Base::kWarpGemmIterations == 2. -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > 0; --gemm_k_iterations) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group -+ // as the case may be. -+ -+ if (warp_mma_k == Base::kWarpGemmIterations - 1) { -+ -+ // Write fragments to shared memory -+ this->smem_iterator_A_.store(transform_A(tb_frag_A)); -+ -+ this->smem_iterator_B_.store(transform_B(tb_frag_B)); -+ -+ __syncthreads(); -+ -+ ++this->smem_iterator_A_; -+ ++this->smem_iterator_B_; -+ -+ // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory -+ if (smem_write_stage_idx == 1) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ } -+ else { -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, -+ 0}); -+ } -+ -+ smem_write_stage_idx ^= 1; -+ } -+ -+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ -+ this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ if (warp_mma_k == 0) { -+ // load sparse matrix -+ if (is_A_sparse){ -+ iterator_A.load(tb_frag_A); -+ } else { -+ iterator_B.load(tb_frag_B); -+ } -+ -+ // load dense matrix -+ if (is_offset_constant){ -+ if (is_A_sparse){ -+ iterator_B.load_with_ell_index_fast(tb_frag_B, ell_iterator); -+ } else { -+ iterator_A.load_with_ell_index_fast(tb_frag_A, ell_iterator); -+ } -+ } else { -+ if (is_A_sparse){ -+ iterator_B.load_with_ell_index(tb_frag_B, ell_iterator); -+ } else { -+ iterator_A.load_with_ell_index(tb_frag_A, ell_iterator); -+ } -+ } -+ -+ ++iterator_A; -+ ++iterator_B; -+ ++ell_iterator; -+ -+ // Avoid reading out of bounds if this was the last loop iteration -+ iterator_A.clear_mask(gemm_k_iterations <= 2); -+ iterator_B.clear_mask(gemm_k_iterations <= 2); -+ } -+ -+ warp_mma(accum, warp_frag_A[warp_mma_k % 2], -+ warp_frag_B[warp_mma_k % 2], accum); -+ } -+ } -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/gemv.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/gemv.h -new file mode 100755 -index 0000000..f0a4b1d ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/gemv.h -@@ -0,0 +1,147 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Template for a threadblock-scoped GEMV kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix-vector product using SIMT math instructions. -+template < -+ class Core_ //< GemvCore -+> -+class Gemv { -+public: -+ using Shape = typename Core_::Shape; -+ -+ /// The MMA operator that computes GEMV -+ using Operator = typename Core_::Operator; -+ -+ /// Iterates over A in global memory -+ using IteratorA = typename Core_::IteratorA; -+ -+ /// Iterates over B in global memory -+ using IteratorB = typename Core_::IteratorB; -+ -+ /// Fragment of operand C loaded from global memory -+ using IteratorC = typename Core_::IteratorC; -+ -+ /// Fragment of operand A loaded from global memory -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Fragment of operand B loaded from global memory -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Fragment of operand accumulator loaded/stored to global memory -+ using FragmentC = typename Operator::FragmentC; -+ -+ /// Shape of the per-thread GEMV operation -+ using ThreadShape = typename Core_::ThreadShape; -+ -+public: -+ CUTLASS_DEVICE -+ Gemv() { } -+ -+ CUTLASS_DEVICE -+ void operator()( -+ GemmCoord const &problem_size, ///< problem size of batched GEMV -+ FragmentC &accum, ///< destination accumulator tile -+ IteratorA iterator_A, ///< iterator over A operand in global memory -+ IteratorB iterator_B, ///< iterator over B operand in global memory -+ FragmentC const &src_accum) { ///< source accumualtor tile -+ -+ // -+ // Prologue -+ // -+ -+ FragmentA frag_A; -+ FragmentB frag_B; -+ frag_A.clear(); -+ frag_B.clear(); -+ -+ iterator_A.load(frag_A); -+ iterator_B.load(frag_B); -+ ++iterator_A; -+ ++iterator_B; -+ -+ // -+ // Mainloop -+ // -+ Operator thread_mma; -+ int gemm_k = problem_size.k(); -+ -+ if (gemm_k < Shape::kK) -+ { -+ iterator_A.clear_mask(); -+ iterator_B.clear_mask(); -+ } -+ -+ // iterate over K to accumulate result -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k > 0; gemm_k -= Shape::kK) { -+ thread_mma(accum, frag_A, frag_B, accum); -+ -+ iterator_A.load(frag_A); -+ iterator_B.load(frag_B); -+ ++iterator_A; -+ ++iterator_B; -+ -+ if (gemm_k < Shape::kK) -+ { -+ iterator_A.clear_mask(); -+ iterator_B.clear_mask(); -+ } -+ } -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/index_remat.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/index_remat.h -new file mode 100644 -index 0000000..1e24568 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/index_remat.h -@@ -0,0 +1,107 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Helpers for rematerializing indices/dimensions in the thread hierarchy from special registers -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Helper to rematerialize block Idx. Reduces register liveness. -+CUTLASS_DEVICE -+int RematerializeThreadIdxX() { -+ return threadIdx.x; -+} -+ -+/// Helper to rematerialize block Idx. Reduces register liveness. -+CUTLASS_DEVICE -+int RematerializeThreadIdxY() { -+ return threadIdx.y; -+} -+ -+/// Helper to rematerialize block Idx. Reduces register liveness. -+CUTLASS_DEVICE -+int RematerializeThreadIdxZ() { -+ return threadIdx.z; -+} -+ -+/// Helper to rematerialize block Idx. Reduces register liveness. -+CUTLASS_DEVICE -+int RematerializeBlockIdxX() { -+ return blockIdx.x; -+} -+ -+/// Helper to rematerialize block Idx. Reduces register liveness. -+CUTLASS_DEVICE -+int RematerializeBlockIdxY() { -+ return blockIdx.y; -+} -+ -+/// Helper to rematerialize block Idx. Reduces register liveness. -+CUTLASS_DEVICE -+int RematerializeBlockIdxZ() { -+ return blockIdx.z; -+} -+ -+/// Helper to rematerialize block Dim. Reduces register liveness. -+CUTLASS_DEVICE -+int RematerializeBlockDimX() { -+ return blockDim.x; -+} -+ -+/// Helper to rematerialize block Dim. Reduces register liveness. -+CUTLASS_DEVICE -+int RematerializeBlockDimY() { -+ return blockDim.y; -+} -+ -+/// Helper to rematerialize block Dim. Reduces register liveness. -+CUTLASS_DEVICE -+int RematerializeBlockDimZ() { -+ return blockDim.z; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+ -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_base.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_base.h -new file mode 100644 -index 0000000..524fdf9 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_base.h -@@ -0,0 +1,236 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/tensor_ref.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Policy object describing MmaTensorOp -+template < -+ /// Warp-level GEMM operator (concept: gemm::warp::Mma) -+ typename Operator_, -+ /// Padding used for A operand in shared memory (concept: MatrixShape) -+ typename SmemPaddingA_, -+ /// Padding used for B operand in shared memory (concept: MatrixShape) -+ typename SmemPaddingB_, -+ /// Number of partitions of K dimension of GEMM -+ int PartitionsK = 1> -+struct MmaPolicy { -+ /// Warp-level GEMM operator (concept: gemm::warp::MmaTensorOp or gemm::warp::MmaSimt) -+ using Operator = Operator_; -+ -+ /// Padding used for A operand in shared memory -+ using SmemPaddingA = SmemPaddingA_; -+ -+ /// Padding used for B operand in shared memory -+ using SmemPaddingB = SmemPaddingB_; -+ -+ /// Number of partitions of K dimension -+ static int const kPartitionsK = PartitionsK; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class MmaBase { -+ public: -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Shape describing the overall GEMM computed from shared memory -+ /// by each warp. -+ using WarpGemm = typename Policy::Operator::Shape; -+ -+ /// Shape describing the number of warps filling the CTA -+ using WarpCount = GemmShape; -+ -+ /// Number of warp-level GEMM oeprations -+ static int const kWarpGemmIterations = -+ (WarpGemm::kK / Operator::Policy::MmaShape::kK); -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Tensor reference to the A operand -+ using TensorRefA = TensorRef; -+ -+ /// Tensor reference to the B operand -+ using TensorRefB = TensorRef; -+ -+ static_assert(kWarpGemmIterations > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ -+ static_assert((kWarpGemmIterations % 2) == 0, -+ "Inner loop iteration must be an even number."); -+ -+ // -+ // Nested structs -+ // -+ -+ /// Shared storage object needed by threadblock-scoped GEMM -+ class SharedStorage { -+ public: -+ // -+ // Type definitions -+ // -+ -+ /// Shape of the A matrix operand in shared memory -+ using ShapeA = MatrixShape; -+ -+ /// Shape of the B matrix operand in shared memory -+ using ShapeB = -+ MatrixShape; -+ -+ public: -+ // -+ // Data members -+ // -+ -+ /// Buffer for A operand -+ AlignedBuffer operand_A; -+ -+ /// Buffer for B operand -+ AlignedBuffer operand_B; -+ -+ public: -+ -+ // -+ // Methods -+ // -+ -+ /// Returns a layout object for the A matrix -+ CUTLASS_DEVICE -+ static typename Operator::LayoutA LayoutA() { -+ return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); -+ } -+ -+ /// Returns a layout object for the B matrix -+ CUTLASS_HOST_DEVICE -+ static typename Operator::LayoutB LayoutB() { -+ return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); -+ } -+ -+ /// Returns a TensorRef to the A operand -+ CUTLASS_HOST_DEVICE -+ TensorRefA operand_A_ref() { -+ return TensorRefA{operand_A.data(), LayoutA()}; -+ } -+ -+ /// Returns a TensorRef to the B operand -+ CUTLASS_HOST_DEVICE -+ TensorRefB operand_B_ref() { -+ return TensorRefB{operand_B.data(), LayoutB()}; -+ } -+ }; -+ -+ protected: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to load a warp-scoped tile of A operand from shared memory -+ typename Operator::IteratorA warp_tile_iterator_A_; -+ -+ /// Iterator to load a warp-scoped tile of B operand from shared memory -+ typename Operator::IteratorB warp_tile_iterator_B_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ MmaBase( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), -+ warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) { -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_blas3_multistage.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_blas3_multistage.h -new file mode 100644 -index 0000000..fa05aac ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_blas3_multistage.h -@@ -0,0 +1,702 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+ Used by BLAS3 kernels that need to treat diagonal elements of a input iterator as a special case. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/threadblock/mma_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Cache operation for operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kZfill, -+ /// Blas3 computation mode -+ BlasMode BlasMode_ = BlasMode::kTriangular, -+ /// Used for partial specialization -+ typename Enable = bool> -+class MmaBlas3Multistage : -+ public MmaBase { -+public: -+ ///< Base class -+ using Base = MmaBase; -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ ///< Iterates over tiles of A operand in global memory -+ using IteratorA = IteratorA_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB = IteratorB_; -+ ///< Data type of accumulator matrix -+ using ElementC = ElementC_; -+ ///< Layout of accumulator matrix -+ using LayoutC = LayoutC_; -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ ///< Blas Mode -+ static BlasMode const kBlasMode = BlasMode_; -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of accumulator tile -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Minimum architecture is Sm80 to support cp.async -+ using ArchTag = arch::Sm80; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = Operator::kTransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = Operator::kTransformB; -+ -+ /// Internal structure exposed for introspection. -+ struct Detail { -+ -+ /// Number of cp.async instructions to load one stage of operand A -+ static int const AsyncCopyIterationsPerStageA = -+ IteratorA::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const AsyncCopyIterationsPerStageB = -+ IteratorB::ThreadMap::Iterations::kCount; -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Number of cp.async instructions to load on group of operand A -+ static int const kAccessesPerGroupA = -+ (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB = -+ (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ }; -+ -+ private: -+ -+ using WarpLoadedFragmentA = typename Operator::FragmentA; -+ using WarpLoadedFragmentB = typename Operator::FragmentB; -+ using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; -+ using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ MmaBlas3Multistage( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), -+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) -+ { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B, -+ int group_start_A = 0, int group_start_B = 0) { -+ iterator_A.set_iteration_index(group_start_A * -+ IteratorA::kAccessesPerVector); -+ this->smem_iterator_A_.set_iteration_index(group_start_A); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { -+ if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_A.get(); -+ bool isvalid = iterator_A.valid(); -+ -+ if (isvalid && iterator_A.getOnDiag()) { -+ // Elements that are on diagonal -+ if (kBlasMode == BlasMode::kHermitian && cutlass::is_complex::value) { -+ /* Copy real part from gmem, write zero for imag part in smem */ -+ /* The following logic to determine kSizeRealBytes is so that compiler doesn't complain when -+ * compiling for not complex datatype and using half the size for cp_async_zfill */ -+ int const kSizeRealBytes = (platform::is_same>::value) ? 8 : 4; -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, true); -+ cutlass::arch::cp_async_diag( -+ reinterpret_cast (dst_ptr + v) + kSizeRealBytes); -+ } else { -+ /* Write one (1) directly to smem*/ -+ cutlass::arch::cp_async_diag(dst_ptr + v); -+ } -+ } else { -+ // Elements that are not of diagonal -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, isvalid); -+ } -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ } -+ -+ iterator_B.set_iteration_index(group_start_B * -+ IteratorB::kAccessesPerVector); -+ this->smem_iterator_B_.set_iteration_index(group_start_B); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { -+ if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_B.get(); -+ bool isvalid = iterator_B.valid(); -+ -+ if (isvalid && iterator_B.getOnDiag()) { -+ // Elements that are on diagonal -+ if (kBlasMode == BlasMode::kHermitian && cutlass::is_complex::value) { -+ /* Copy real part from gmem, write zero for imag part in smem */ -+ int const kSizeRealBytes = (platform::is_same>::value) ? 8 : 4; -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, true); -+ cutlass::arch::cp_async_diag( -+ reinterpret_cast (dst_ptr + v) + kSizeRealBytes); -+ } else { -+ /* Write one (1) directly to smem*/ -+ cutlass::arch::cp_async_diag(dst_ptr + v); -+ } -+ } else { -+ // Elements that are not of diagonal -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, isvalid); -+ } -+ -+ ++iterator_B; -+ } -+ ++this->smem_iterator_B_; -+ } -+ } -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations, -+ ///< destination accumulator tile -+ FragmentC &accum, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B, -+ ///< initial value of accumulator -+ FragmentC const &src_accum) { -+ -+ // -+ // Prologue -+ // -+ -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; -+ ++stage, --gemm_k_iterations) { -+ -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ -+ iterator_A.set_iteration_index(0); -+ this->smem_iterator_A_.set_iteration_index(0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ auto gmem_ptr = iterator_A.get(); -+ bool isvalid = iterator_A.valid(); -+ -+ if (isvalid && iterator_A.getOnDiag()) { -+ // Elements that are on diagonal -+ if (kBlasMode == BlasMode::kHermitian && cutlass::is_complex::value) { -+ /* Copy real part from gmem, write zero for imag part in smem */ -+ int const kSizeRealBytes = (platform::is_same>::value) ? 8 : 4; -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, true); -+ cutlass::arch::cp_async_diag( -+ reinterpret_cast (dst_ptr + v) + kSizeRealBytes); -+ } else { -+ /* Write one (1) directly to smem*/ -+ cutlass::arch::cp_async_diag(dst_ptr + v); -+ } -+ } else { -+ // Elements that are not of diagonal -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, isvalid); -+ } -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ -+ iterator_B.set_iteration_index(0); -+ this->smem_iterator_B_.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ auto gmem_ptr = iterator_B.get(); -+ bool isvalid = iterator_B.valid(); -+ -+ if (isvalid && iterator_B.getOnDiag()) { -+ // Elements that are on diagonal -+ if (kBlasMode == BlasMode::kHermitian && cutlass::is_complex::value) { -+ /* Copy real part from gmem, write zero for imag part in smem */ -+ int const kSizeRealBytes = (platform::is_same>::value) ? 8 : 4; -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, true); -+ cutlass::arch::cp_async_diag( -+ reinterpret_cast (dst_ptr + v) + kSizeRealBytes); -+ } else { -+ /* Write one (1) directly to smem*/ -+ cutlass::arch::cp_async_diag(dst_ptr + v); -+ } -+ } else { -+ // Elements that are not of diagonal -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, isvalid); -+ } -+ -+ ++iterator_B; -+ } -+ -+ ++this->smem_iterator_B_; -+ } -+ -+ // Move to the next stage -+ iterator_A.add_tile_offset({0, 1}); -+ iterator_B.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Defines the boundary of a stage of cp.async. -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // Perform accumulation in the 'd' output operand -+ accum = src_accum; -+ -+ // -+ // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels -+ // so that all accumulator elements outside the GEMM footprint are zero. -+ // -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); -+ -+ typename IteratorA::AccessType zero_A; -+ zero_A.clear(); -+ -+ last_smem_iterator_A.set_iteration_index(0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { -+ -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ last_smem_iterator_A.get()); -+ -+ *dst_ptr = zero_A; -+ -+ ++last_smem_iterator_A; -+ } -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); -+ typename IteratorB::AccessType zero_B; -+ -+ zero_B.clear(); -+ last_smem_iterator_B.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { -+ -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ last_smem_iterator_B.get()); -+ -+ *dst_ptr = zero_B; -+ -+ ++last_smem_iterator_B; -+ } -+ } -+ -+ // Waits until kStages-2 stages have committed. -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA warp_loaded_frag_A[2]; -+ WarpLoadedFragmentB warp_loaded_frag_B[2]; -+ WarpTransformedFragmentA warp_transformed_frag_A[2]; -+ WarpTransformedFragmentB warp_transformed_frag_B[2]; -+ -+ Operator warp_mma; -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ -+ int smem_write_stage_idx = Base::kStages - 1; -+ int smem_read_stage_idx = 0; -+ -+ warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], -+ warp_loaded_frag_A[0], warp_loaded_frag_B[0]); -+ -+ // tf32x3 kernels use staging accumulation. warp_mma uses a temporary -+ // accumulator and this temporary accumulator is added to the final -+ // accumulator once in every mainloop iteration. -+ plus plus_accum; -+ -+ FragmentC tmp_accum; -+ -+ if (platform::is_same::value -+ || platform::is_same::value) { -+ -+ tmp_accum.clear(); -+ } -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > (-Base::kStages + 1);) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; -+ ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. -+ -+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ if (warp_mma_k > 0) -+ warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ warp_loaded_frag_A[warp_mma_k % 2], -+ warp_loaded_frag_B[warp_mma_k % 2]); -+ -+ if (platform::is_same::value -+ || platform::is_same::value) { -+ -+ warp_mma( -+ tmp_accum, -+ warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ tmp_accum -+ ); -+ -+ if (warp_mma_k == 0) { -+ accum = plus_accum(accum, tmp_accum); -+ tmp_accum.clear(); -+ } -+ } else { -+ warp_mma( -+ accum, -+ warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ accum -+ ); -+ } -+ -+ // Issue global->shared copies for the this stage -+ if (warp_mma_k < Base::kWarpGemmIterations - 1) { -+ int group_start_iteration_A, group_start_iteration_B; -+ -+ group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; -+ -+ copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, -+ group_start_iteration_B); -+ } -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations) { -+ int group_start_iteration_A, group_start_iteration_B; -+ group_start_iteration_A = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB; -+ -+ copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, -+ group_start_iteration_B); -+ -+ // Inserts a memory fence between stages of cp.async instructions. -+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages have committed. -+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_A.add_tile_offset({0, 1}); -+ iterator_B.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ --gemm_k_iterations; -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ } -+ -+ // Do any conversions feeding the first stage at the end of the loop so -+ // we can start right away on mma instructions -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations) -+ warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ } -+ -+ } -+ -+ if (platform::is_same::value -+ || platform::is_same::value) { -+ accum = plus_accum(accum, tmp_accum); -+ } -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ // commit and drain all pending and predicated cp.async pnz from the GEMM mainloop -+ cutlass::arch::cp_async_fence(); -+ cutlass::arch::cp_async_wait<0>(); -+ __syncthreads(); -+ } -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_layernorm_mainloop_fusion_multistage.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_layernorm_mainloop_fusion_multistage.h -new file mode 100644 -index 0000000..03055ee ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_layernorm_mainloop_fusion_multistage.h -@@ -0,0 +1,865 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+ -+ It loads two loop invariant vectors, mean and var, in the prologue and -+ stores them in the register file. In the mainloop, it loads two loop -+ variant vectors, gamma and beta, by using cp.async. We will call -+ elementwise operation to apply var, mean, gamma, beta between ldmatrix and -+ warp mma. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h" -+#include "cutlass/gemm/threadblock/mma_base.h" -+#include "cutlass/gemm/warp/layernorm_scale_bias_transform.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Element type of scale and bias vectors -+ typename ElementScaleBias_, -+ /// Layout of scale and bias vectors -+ typename LayoutScaleBias_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// WarpIterator to load Scale or Bias vector from the shared memory -+ typename WarpIteratorGammaBeta_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class MmaMainloopFusionBase { -+ public: -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ ///< Element type of scale and bias vectors -+ using ElementScaleBias = ElementScaleBias_; -+ -+ /// Layout of scale and bias vectors -+ using LayoutScaleBias = LayoutScaleBias_; -+ -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ ///< WarpIterator to load Scale or Bias vector from the shared memory -+ using WarpIteratorGammaBeta = WarpIteratorGammaBeta_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Shape describing the overall GEMM computed from shared memory -+ /// by each warp. -+ using WarpGemm = typename Policy::Operator::Shape; -+ -+ /// Shape describing the number of warps filling the CTA -+ using WarpCount = cutlass::gemm::GemmShape; -+ -+ /// Number of warp-level GEMM oeprations -+ static int const kWarpGemmIterations = -+ (WarpGemm::kK / Operator::Policy::MmaShape::kK); -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Tensor reference to the A operand -+ using TensorRefA = TensorRef; -+ -+ /// Tensor reference to the scale and bias vectors -+ using TensorRefGammaBeta = TensorRef; -+ -+ /// Tensor reference to the B operand -+ using TensorRefB = TensorRef; -+ -+ // -+ // Nested structs -+ // -+ -+ /// Shared storage object needed by threadblock-scoped GEMM -+ class SharedStorage { -+ public: -+ // -+ // Type definitions -+ // -+ -+ /// Shape of the A matrix operand in shared memory -+ using ShapeA = MatrixShape; -+ -+ /// Shape of the A scale and bias vectors in shared memory -+ using ShapeGammaBeta = -+ MatrixShape<1 + Policy::SmemPaddingA::kRow, -+ 2 * Shape::kK * kStages + Policy::SmemPaddingA::kColumn>; -+ -+ /// Shape of the B matrix operand in shared memory -+ using ShapeB = -+ MatrixShape; -+ -+ public: -+ // -+ // Data members -+ // -+ -+ /// Buffer for A operand -+ AlignedBuffer operand_A; -+ -+ /// Buffer for B operand -+ AlignedBuffer operand_B; -+ -+ /// Buffer for A operand Scale and Bias -+ AlignedBuffer operand_A_gamma_beta; -+ -+ public: -+ -+ // -+ // Methods -+ // -+ -+ /// Returns a layout object for the A matrix -+ CUTLASS_DEVICE -+ static typename Operator::LayoutA LayoutA() { -+ return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); -+ } -+ -+ /// Returns a layout object for the B matrix -+ CUTLASS_HOST_DEVICE -+ static typename Operator::LayoutB LayoutB() { -+ return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); -+ } -+ -+ /// Returns a layout object for the A scale and bias vectors -+ CUTLASS_DEVICE -+ static LayoutScaleBias LayoutScaleBias() { -+ return LayoutScaleBias::packed( -+ {ShapeGammaBeta::kRow, ShapeGammaBeta::kColumn}); -+ } -+ -+ /// Returns a TensorRef to the A operand -+ CUTLASS_HOST_DEVICE -+ TensorRefA operand_A_ref() { -+ return TensorRefA{operand_A.data(), LayoutA()}; -+ } -+ -+ /// Returns a TensorRef to the B operand -+ CUTLASS_HOST_DEVICE -+ TensorRefB operand_B_ref() { -+ return TensorRefB{operand_B.data(), LayoutB()}; -+ } -+ -+ /// Returns a TensorRef to the A operand Scale vector -+ CUTLASS_HOST_DEVICE -+ TensorRefGammaBeta operand_A_gamma_beta_ref() { -+ return TensorRefGammaBeta{operand_A_gamma_beta.data(), LayoutScaleBias()}; -+ } -+ }; -+ -+ protected: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to load a warp-scoped tile of A operand from shared memory -+ typename Operator::IteratorA warp_tile_iterator_A_; -+ -+ /// Iterator to load a warp-scoped tile of A operand scale and bias vector -+ /// from shared memory -+ WarpIteratorGammaBeta warp_tile_iterator_A_gamma_beta_; -+ -+ /// Iterator to load a warp-scoped tile of B operand from shared memory -+ typename Operator::IteratorB warp_tile_iterator_B_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ MmaMainloopFusionBase( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx) -+ : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), -+ warp_tile_iterator_A_gamma_beta_( -+ shared_storage.operand_A_gamma_beta_ref(), lane_idx), -+ warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} -+}; -+ -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Cache operation for operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// Iterates over vectors of var and mean vector in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorVarMean_, -+ /// Iterates over vectors of scale and bias vector in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorGammaBeta_, -+ /// Iterates over vectors of scale and bias vector in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorGammaBeta_, -+ /// Cache operation for scale/bias operand -+ cutlass::arch::CacheOperation::Kind CacheOpGammaBeta, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// WarpIterator to load Scale or Bias vector from the shared memory -+ typename WarpIteratorGammaBeta_, -+ /// Number of stages, -+ int Stages, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, -+ /// Used for partial specialization -+ typename Enable = bool> -+class MmaLayernormMainloopFusionMultistage : -+ public MmaMainloopFusionBase { -+public: -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ ///< Iterates over tiles of A operand in global memory -+ using IteratorA = IteratorA_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB = IteratorB_; -+ ///< Iterates over tiles of the var and mean vectors in global memory -+ using IteratorVarMean = IteratorVarMean_; -+ ///< Iterates over tiles of the scale and bias vectors in global memory -+ using IteratorGammaBeta = IteratorGammaBeta_; -+ ///< WarpIterator to load Scale or Bias vector from the shared memory -+ using WarpIteratorGammaBeta = WarpIteratorGammaBeta_; -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ ///< Base class -+ using Base = MmaMainloopFusionBase; -+ -+ ///< Data type of accumulator matrix -+ using ElementC = ElementC_; -+ ///< Layout of accumulator matrix -+ using LayoutC = LayoutC_; -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ using SmemIteratorGammaBeta = SmemIteratorGammaBeta_; -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpGammaBeta = -+ CacheOpGammaBeta; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of accumulator tile -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Minimum architecture is Sm80 to support cp.async -+ using ArchTag = arch::Sm80; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = Operator::kTransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = Operator::kTransformB; -+ -+ /// Internal structure exposed for introspection. -+ struct Detail { -+ -+ static_assert(Base::kWarpGemmIterations > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ -+ /// Number of cp.async instructions to load one stage of operand A -+ static int const AsyncCopyIterationsPerStageA = -+ IteratorA::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const AsyncCopyIterationsPerStageB = -+ IteratorB::ThreadMap::Iterations::kCount; -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Number of cp.async instructions to load on group of operand A -+ static int const kAccessesPerGroupA = -+ (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB = -+ (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ }; -+ -+ private: -+ -+ using WarpLoadedFragmentA = typename Operator::FragmentA; -+ using WarpLoadedFragmentB = typename Operator::FragmentB; -+ using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; -+ using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; -+ -+ using WarpLoadedFragmentVarMean = typename IteratorVarMean::Fragment; -+ using WarpLoadedFragmentGammaBeta = -+ typename WarpIteratorGammaBeta::Fragment; -+ -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of A operand scale vector to shared memory -+ SmemIteratorGammaBeta smem_iterator_A_gamma_beta_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+ int warp_idx_m_; -+ -+ int warp_idx_n_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ MmaLayernormMainloopFusionMultistage( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), -+ smem_iterator_A_gamma_beta_(shared_storage.operand_A_gamma_beta_ref(), -+ thread_idx), -+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) -+ { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ warp_idx_m_ = warp_idx_mn % Base::WarpCount::kM; -+ warp_idx_n_ = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {warp_idx_m_, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_A_gamma_beta_.add_tile_offset( -+ {warp_idx_m_, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n_}); -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance(IteratorA &iterator_A, -+ IteratorGammaBeta &iterator_A_gamma_beta, -+ IteratorB &iterator_B, -+ int group_start_A = 0, int group_start_B = 0) { -+ iterator_A.set_iteration_index(group_start_A * -+ IteratorA::kAccessesPerVector); -+ this->smem_iterator_A_.set_iteration_index(group_start_A); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { -+ if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_A.get(); -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, iterator_A.valid()); -+ } else { -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_A.valid()); -+ } -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ } -+ -+ // Async Copy for operand A scale and bias vector. Scale and bias vectors -+ // are small. One iteration is enough. -+ if (group_start_A == 0) { -+ typename IteratorGammaBeta::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_gamma_beta_.get()); -+ -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorGammaBeta::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async( -+ dst_ptr, iterator_A_gamma_beta.get(), iterator_A_gamma_beta.valid()); -+ } -+ -+ iterator_B.set_iteration_index(group_start_B * -+ IteratorB::kAccessesPerVector); -+ this->smem_iterator_B_.set_iteration_index(group_start_B); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { -+ if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_B.get(); -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, iterator_B.valid()); -+ } else { -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_B.valid()); -+ } -+ -+ ++iterator_B; -+ } -+ ++this->smem_iterator_B_; -+ } -+ } -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations, -+ ///< destination accumulator tile -+ FragmentC &accum, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B, -+ ///< iterator over B operand in global memory -+ IteratorVarMean iterator_var_mean, -+ ///< iterator over scale and bias vectors in global memory -+ IteratorGammaBeta iterator_A_gamma_beta, -+ ///< initial value of accumulator -+ FragmentC const &src_accum) { -+ -+ // -+ // Prologue -+ // -+ // Issue several complete stages -+ -+ WarpLoadedFragmentVarMean warp_loaded_frag_var_mean; -+ iterator_var_mean.add_tile_offset({0, warp_idx_m_}); -+ iterator_var_mean.load(warp_loaded_frag_var_mean); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; -+ ++stage, --gemm_k_iterations) { -+ -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_A_gamma_beta.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ -+ iterator_A.set_iteration_index(0); -+ this->smem_iterator_A_.set_iteration_index(0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_A.get(), iterator_A.valid()); -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ -+ // Async Copy for operand A scale and bias vectors. Scale and bias -+ // vectors are small. One iteration is enough. -+ { -+ typename IteratorGammaBeta::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_gamma_beta_.get()); -+ -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorGammaBeta::kElementsPerAccess / 8; -+ -+ cutlass::arch::cp_async( -+ dst_ptr, iterator_A_gamma_beta.get(), iterator_A_gamma_beta.valid()); -+ } -+ -+ iterator_B.set_iteration_index(0); -+ this->smem_iterator_B_.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_B.get(), iterator_B.valid()); -+ -+ ++iterator_B; -+ } -+ -+ ++this->smem_iterator_B_; -+ } -+ -+ // Move to the next stage -+ iterator_A.add_tile_offset({0, 1}); -+ iterator_A_gamma_beta.add_tile_offset({0, 1}); -+ iterator_B.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_A_gamma_beta_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Defines the boundary of a stage of cp.async. -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // Perform accumulation in the 'd' output operand -+ accum = src_accum; -+ -+ // Waits until kStages-2 stages have committed. -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA warp_loaded_frag_A[2]; -+ WarpLoadedFragmentB warp_loaded_frag_B[2]; -+ WarpLoadedFragmentGammaBeta warp_loaded_frag_A_gamma_beta[2]; -+ WarpTransformedFragmentA warp_transformed_frag_A[2]; -+ WarpTransformedFragmentB warp_transformed_frag_B[2]; -+ -+ Operator warp_mma; -+ cutlass::gemm::warp::LayernormScaleBiasTransform -+ elementwise_transform; -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_A_gamma_beta_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); -+ this->warp_tile_iterator_A_gamma_beta_.load( -+ warp_loaded_frag_A_gamma_beta[0]); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_A_gamma_beta_; -+ ++this->warp_tile_iterator_B_; -+ -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_A_gamma_beta.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ -+ int smem_write_stage_idx = Base::kStages - 1; -+ int smem_read_stage_idx = 0; -+ -+ warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], -+ warp_loaded_frag_A[0], warp_loaded_frag_B[0]); -+ -+ elementwise_transform(warp_transformed_frag_A[0], -+ warp_loaded_frag_var_mean, -+ warp_loaded_frag_A_gamma_beta[0]); -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > (-Base::kStages + 1);) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; -+ ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. -+ -+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_A_gamma_beta_.set_kgroup_index( -+ (warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_A_gamma_beta_.load( -+ warp_loaded_frag_A_gamma_beta[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_A_gamma_beta_; -+ ++this->warp_tile_iterator_B_; -+ -+ if (warp_mma_k > 0) { -+ warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ warp_loaded_frag_A[warp_mma_k % 2], -+ warp_loaded_frag_B[warp_mma_k % 2]); -+ -+ elementwise_transform(warp_transformed_frag_A[warp_mma_k % 2], -+ warp_loaded_frag_var_mean, -+ warp_loaded_frag_A_gamma_beta[warp_mma_k % 2]); -+ } -+ -+ warp_mma( -+ accum, -+ warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ accum -+ ); -+ -+ // Issue global->shared copies for the this stage -+ if (warp_mma_k < Base::kWarpGemmIterations - 1) { -+ int group_start_iteration_A, group_start_iteration_B; -+ -+ group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; -+ -+ copy_tiles_and_advance(iterator_A, iterator_A_gamma_beta, iterator_B, -+ group_start_iteration_A, -+ group_start_iteration_B); -+ } -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations) { -+ int group_start_iteration_A, group_start_iteration_B; -+ group_start_iteration_A = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB; -+ -+ copy_tiles_and_advance(iterator_A, iterator_A_gamma_beta, iterator_B, -+ group_start_iteration_A, -+ group_start_iteration_B); -+ -+ // Inserts a memory fence between stages of cp.async instructions. -+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages have committed. -+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_A.add_tile_offset({0, 1}); -+ iterator_A_gamma_beta.add_tile_offset({0, 1}); -+ iterator_B.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_A_gamma_beta_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_A_gamma_beta_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_A_gamma_beta_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ --gemm_k_iterations; -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_A_gamma_beta.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ } -+ -+ // Do any conversions feeding the first stage at the end of the loop so -+ // we can start right away on mma instructions -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations) { -+ warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ -+ elementwise_transform( -+ warp_transformed_frag_A[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_var_mean, -+ warp_loaded_frag_A_gamma_beta[(warp_mma_k + 1) % 2]); -+ } -+ } -+ -+ } -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ // commit and drain all pending and predicated cp.async pnz from the GEMM mainloop -+ cutlass::arch::cp_async_fence(); -+ cutlass::arch::cp_async_wait<0>(); -+ __syncthreads(); -+ } -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_multistage.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_multistage.h -new file mode 100644 -index 0000000..5f6f852 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_multistage.h -@@ -0,0 +1,746 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/threadblock/mma_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Cache operation for operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, -+ /// Used for partial specialization -+ typename Enable = bool> -+class MmaMultistage : -+ public MmaBase { -+public: -+ ///< Base class -+ using Base = MmaBase; -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ ///< Iterates over tiles of A operand in global memory -+ using IteratorA = IteratorA_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB = IteratorB_; -+ ///< Data type of accumulator matrix -+ using ElementC = ElementC_; -+ ///< Layout of accumulator matrix -+ using LayoutC = LayoutC_; -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of accumulator tile -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Minimum architecture is Sm80 to support cp.async -+ using ArchTag = arch::Sm80; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = Operator::kTransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = Operator::kTransformB; -+ -+ /// Internal structure exposed for introspection. -+ struct Detail { -+ -+ /// Number of cp.async instructions to load one stage of operand A -+ static int const AsyncCopyIterationsPerStageA = -+ IteratorA::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const AsyncCopyIterationsPerStageB = -+ IteratorB::ThreadMap::Iterations::kCount; -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Number of cp.async instructions to load on group of operand A -+ static int const kAccessesPerGroupA = -+ (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB = -+ (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ -+ // Optional staged-accumulation (e.g., tf32x3 kernels) for improved numerical -+ // accuracy, where each mainloop iteration first accumulates into a temporary -+ // set of freshly-cleared accumulators, which are subsequently added to the -+ // final accumulator set. -+ static bool const kStagedAccumulation = -+ platform::is_same::value || -+ platform::is_same::value; -+ -+ }; -+ -+ private: -+ -+ -+ // Structure encapsulating pipeline state live from one iteration to the next -+ struct PipeState { -+ -+ using WarpLoadedFragmentA = typename Operator::FragmentA; -+ using WarpLoadedFragmentB = typename Operator::FragmentB; -+ using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; -+ using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; -+ -+ /// Temporary accumulator to facilitate staged-accumulation -+ FragmentC tmp_accum_; -+ -+ /// Pair of A fragments used to overlap shared memory loads and math instructions -+ WarpLoadedFragmentA warp_loaded_frag_A_[2]; -+ WarpTransformedFragmentA warp_transformed_frag_A_[2]; -+ -+ /// Pair of B fragments used to overlap shared memory loads and math instructions -+ WarpLoadedFragmentB warp_loaded_frag_B_[2]; -+ WarpTransformedFragmentB warp_transformed_frag_B_[2]; -+ }; -+ -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Warp-level MMA operator -+ Operator warp_mma_; -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+ /// Shared memory write stage index -+ int smem_write_stage_idx_; -+ -+ /// Shared memory read stage index -+ int smem_read_stage_idx_; -+ -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ MmaMultistage( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), -+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), -+ smem_write_stage_idx_(0), -+ smem_read_stage_idx_(0) -+ { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ } -+ -+ /// Advance shared memory read-iterators to the next stage -+ CUTLASS_DEVICE -+ void advance_smem_read_stage() -+ { -+ ++smem_read_stage_idx_; -+ -+ if (smem_read_stage_idx_ == Base::kStages) { -+ // Wrap back around to the 'start' of the circular buffer in shared memory -+ this->warp_tile_iterator_A_.add_tile_offset({0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); -+ smem_read_stage_idx_ = 0; -+ } -+ } -+ -+ /// Advance global memory read-iterators and shared memory write-iterators to the stage -+ CUTLASS_DEVICE -+ void advance_smem_write_stage( -+ IteratorA &iterator_A, -+ IteratorB &iterator_B) -+ { -+ // Advance global iterators -+ iterator_A.add_tile_offset({0, 1}); -+ iterator_B.add_tile_offset({1, 0}); -+ -+ // Advance shared iterators -+ smem_iterator_A_.add_tile_offset({0, 1}); -+ smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Increment shared memory write stage index -+ ++smem_write_stage_idx_; -+ -+ if (smem_write_stage_idx_ == Base::kStages) { -+ // Wrap back around to the 'start' of the circular buffer in shared memory -+ smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx_ = 0; -+ } -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B, -+ int group_start_A = 0, int group_start_B = 0) { -+ iterator_A.set_iteration_index(group_start_A * -+ IteratorA::kAccessesPerVector); -+ this->smem_iterator_A_.set_iteration_index(group_start_A); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { -+ if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_A.get(); -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, iterator_A.valid()); -+ } else { -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_A.valid()); -+ } -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ } -+ -+ iterator_B.set_iteration_index(group_start_B * -+ IteratorB::kAccessesPerVector); -+ this->smem_iterator_B_.set_iteration_index(group_start_B); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { -+ if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_B.get(); -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, iterator_B.valid()); -+ } else { -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_B.valid()); -+ } -+ -+ ++iterator_B; -+ } -+ ++this->smem_iterator_B_; -+ } -+ } -+ } -+ -+ /// GEMM prologue. Bootstrap the global->shared memory pipeline by fetching -+ /// the global fragments needed by the first kStages-1 threadblock mainloop iterations -+ CUTLASS_DEVICE -+ void prologue( -+ IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory -+ IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory -+ int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining -+ { -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { -+ -+ // Disable global fetching if done with global fetch iterations -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ -+ iterator_A.set_iteration_index(0); -+ this->smem_iterator_A_.set_iteration_index(0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_A.get(), iterator_A.valid()); -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ -+ iterator_B.set_iteration_index(0); -+ this->smem_iterator_B_.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_B.get(), iterator_B.valid()); -+ -+ ++iterator_B; -+ } -+ -+ ++this->smem_iterator_B_; -+ } -+ -+ // Move to the next write stage -+ advance_smem_write_stage(iterator_A, iterator_B); -+ -+ // Defines the boundary of a stage of cp.async. -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // Optionally clear the remaining stages of SMEM. This is a functional requirement for -+ // some kernels so that all accumulator elements outside the GEMM footprint are zero. -+ if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); -+ typename IteratorA::AccessType zero_A; -+ -+ zero_A.clear(); -+ last_smem_iterator_A.set_iteration_index(0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { -+ -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ last_smem_iterator_A.get()); -+ -+ *dst_ptr = zero_A; -+ -+ ++last_smem_iterator_A; -+ } -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); -+ typename IteratorB::AccessType zero_B; -+ -+ zero_B.clear(); -+ last_smem_iterator_B.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { -+ -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ last_smem_iterator_B.get()); -+ -+ *dst_ptr = zero_B; -+ -+ ++last_smem_iterator_B; -+ } -+ } -+ } -+ -+ -+ /// Wait until we have at least one completed global fetch stage -+ CUTLASS_DEVICE -+ void gmem_wait() -+ { -+ // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed) -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ } -+ -+ -+ /// Perform a threadblock mainloop iteration of matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void mac_loop_iter( -+ PipeState &pipe_state, ///< [in|out] loop-carried pipeline state -+ FragmentC &accum, ///< [in|out] destination accumulator tile -+ IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory -+ IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory -+ int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining -+ { -+ // Unroll the warp-level MMA tiles of a threadblock's mainloop iteration -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { -+ -+ // Load the next warp-tile's A fragment from shared memory -+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]); -+ ++this->warp_tile_iterator_A_; -+ -+ // Load the next warp-tile's B fragment from shared memory -+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]); -+ ++this->warp_tile_iterator_B_; -+ -+ // Except for the first warp-tile, all warp-tiles convert their incoming shared memory fragments as necessary -+ if (warp_mma_k > 0) { -+ warp_mma_.transform( -+ pipe_state.warp_transformed_frag_A_[warp_mma_k % 2], -+ pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], -+ pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], -+ pipe_state.warp_loaded_frag_B_[warp_mma_k % 2]); -+ } -+ -+ // Execute the current warp-tile of MMA operations -+ if (Detail::kStagedAccumulation) { -+ warp_mma_( -+ pipe_state.tmp_accum_, -+ pipe_state.warp_transformed_frag_A_[warp_mma_k % 2], -+ pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], -+ pipe_state.tmp_accum_ -+ ); -+ -+ if (warp_mma_k == 0) { -+ plus plus_accum; -+ accum = plus_accum(accum, pipe_state.tmp_accum_); -+ pipe_state.tmp_accum_.clear(); -+ } -+ } else { -+ warp_mma_( -+ accum, -+ pipe_state.warp_transformed_frag_A_[warp_mma_k % 2], -+ pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], -+ accum -+ ); -+ } -+ -+ // Except for the last warp-tile, all warp-tiles issue their share of -+ // global->shared fragment copies -+ if (warp_mma_k < Base::kWarpGemmIterations - 1) { -+ -+ int group_start_iteration_A, group_start_iteration_B; -+ group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; -+ -+ copy_tiles_and_advance( -+ iterator_A, -+ iterator_B, -+ group_start_iteration_A, -+ group_start_iteration_B); -+ } -+ -+ // The second-to-last warp-tile also: -+ // - performs the last warp-tile's share of global->shared fragment copies -+ // - moves to the next global fetch stage -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations) { -+ -+ // Performs the last warp-tile's share of global->shared fragment copies -+ int group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; -+ int group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; -+ -+ copy_tiles_and_advance( -+ iterator_A, -+ iterator_B, -+ group_start_iteration_A, -+ group_start_iteration_B); -+ -+ // Inserts a memory fence between stages of cp.async instructions. -+ cutlass::arch::cp_async_fence(); -+ -+ // Wait until we have at least one completed global fetch stage -+ gmem_wait(); -+ -+ // Move to the next global fetch stage -+ advance_smem_write_stage(iterator_A, iterator_B); -+ advance_smem_read_stage(); -+ -+ // Disable global fetching when done with global fetch iterations -+ --gemm_k_iterations; -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ } -+ -+ // The last warp-tile also converts the shared memory fragments used by -+ // the first warp-tile of the next iteration, if necessary (so we can -+ // immediately start issuing MMA instructions at the top of the loop ) -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations) { -+ -+ warp_mma_.transform( -+ pipe_state.warp_transformed_frag_A_[(warp_mma_k + 1) % 2], -+ pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2], -+ pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2], -+ pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]); -+ } -+ -+ } -+ } -+ -+ -+ /// Perform the specified number of threadblock mainloop iterations of matrix -+ /// multiply-accumulate. Assumes prologue has been initiated. -+ CUTLASS_DEVICE -+ void gemm_iters( -+ int gemm_k_iterations, ///< number of threadblock mainloop iterations -+ FragmentC &accum, ///< [in|out] accumulator tile -+ IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory -+ IteratorB &iterator_B) ///< [in|out] iterator over B operand in global memory -+ { -+ PipeState pipe_state; -+ -+ // Disable global fetching if done with global fetch iterations -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ -+ // Load first warp-tile's A fragment from shared memory -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[0]); -+ ++this->warp_tile_iterator_A_; -+ -+ // Load first warp-tile's B fragment from shared memory -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[0]); -+ ++this->warp_tile_iterator_B_; -+ -+ // Transform, if necessary, the first warp-tile's shared memory fragments -+ warp_mma_.transform( -+ pipe_state.warp_transformed_frag_A_[0], -+ pipe_state.warp_transformed_frag_B_[0], -+ pipe_state.warp_loaded_frag_A_[0], -+ pipe_state.warp_loaded_frag_B_[0]); -+ -+ if (Detail::kStagedAccumulation) { -+ pipe_state.tmp_accum_.clear(); -+ } -+ -+ // Mainloop -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > (-Base::kStages + 1);) { -+ mac_loop_iter( -+ pipe_state, -+ accum, -+ iterator_A, -+ iterator_B, -+ gemm_k_iterations); -+ } -+ -+ if (Detail::kStagedAccumulation) { -+ plus plus_accum; -+ accum = plus_accum(accum, pipe_state.tmp_accum_); -+ } -+ -+ // Optionally commit and drain all pending and predicated cp.async pnz from the GEMM mainloop -+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ cutlass::arch::cp_async_fence(); -+ cutlass::arch::cp_async_wait<0>(); -+ __syncthreads(); -+ } -+ -+ } -+ -+ -+ /// Prepares the class for another prologue. -+ CUTLASS_DEVICE -+ void wind_down() -+ { -+ // Catch-up the smem-read iterator to the smem-write iterator (so this class can be reused for another tile's prologue) -+ -+ // First, increment remaining warp tiles to get to the next full stage. (Ideally we would -+ // just decrement one tile, but not all iterators implement --() decrement.) -+ #pragma unroll -+ for (int warp_mma_k = 1; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) -+ { -+ this->warp_tile_iterator_A_.set_kgroup_index(warp_mma_k); -+ this->warp_tile_iterator_B_.set_kgroup_index(warp_mma_k); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ } -+ smem_read_stage_idx_++; -+ -+ // Then wrap back two full stages (one for the tile advancing we just did, and one to catch the write iterators) -+ static const int kStageIters = Policy::kPartitionsK * Base::kWarpGemmIterations; -+ if (smem_read_stage_idx_ > 1) -+ { -+ this->warp_tile_iterator_A_.add_tile_offset({0, (-2 * kStageIters)}); -+ this->warp_tile_iterator_B_.add_tile_offset({(-2 * kStageIters), 0}); -+ } -+ else -+ { -+ this->warp_tile_iterator_A_.add_tile_offset({0, ((Base::kStages - 2) * kStageIters)}); -+ this->warp_tile_iterator_B_.add_tile_offset({((Base::kStages - 2) * kStageIters), 0}); -+ } -+ smem_read_stage_idx_ = smem_write_stage_idx_; -+ } -+ -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations, -+ ///< destination accumulator tile -+ FragmentC &accum, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B, -+ ///< initial value of accumulator -+ FragmentC const &src_accum) { -+ -+ // Prologue (start fetching iterations of global fragments into shared memory) -+ prologue(iterator_A, iterator_B, gemm_k_iterations); -+ -+ // Wait until we have at least one completed global fetch stage -+ gmem_wait(); -+ -+ // Initialize destination accumulators with source accumulators -+ accum = src_accum; -+ -+ // Perform the MAC-iterations -+ gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_pipelined.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_pipelined.h -new file mode 100644 -index 0000000..8ada21c ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_pipelined.h -@@ -0,0 +1,439 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/numeric_conversion.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/threadblock/mma_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Transformation applied to A operand -+ typename TransformA_ = NumericArrayConverter< -+ typename SmemIteratorA_::Element, -+ typename IteratorA_::Element, -+ IteratorA_::Fragment::kElements>, -+ /// -+ /// Transformation applied to B operand -+ typename TransformB_ = NumericArrayConverter< -+ typename SmemIteratorB_::Element, -+ typename IteratorB_::Element, -+ IteratorB_::Fragment::kElements>, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+class MmaPipelined : public MmaBase { -+public: -+ -+ ///< Base class -+ using Base = MmaBase; -+ -+ using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory -+ using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory -+ using ElementC = ElementC_; ///< Data type of accumulator matrix -+ using LayoutC = LayoutC_; ///< Layout of accumulator matrix -+ using Policy = Policy_; ///< Policy describing tuning details -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ -+ using TransformA = TransformA_; -+ using TransformB = TransformB_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of operand A loaded from global memory -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Fragment of operand B loaded from global memory -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Fragment of accumulator tile -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Obtain the arch tag from the warp-level operator -+ using ArchTag = typename Policy::Operator::ArchTag; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = Operator::kTransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = Operator::kTransformB; -+ -+ // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) -+ static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); -+ -+protected: -+ -+ // -+ // Data members -+ // -+ -+ /// Warp-level MMA operator -+ Operator warp_mma; -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+ ///< transformation applied to A fragment -+ TransformA transform_A_; -+ -+ ///< transformation applied to B fragment -+ TransformB transform_B_; -+ -+ /// Shared memory write stage index -+ int smem_write_stage_idx; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ MmaPipelined( -+ typename Base::SharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ int thread_idx, ///< ID within the threadblock -+ int warp_idx, ///< ID of warp -+ int lane_idx, ///< ID of each thread within a warp -+ TransformA transform_A = TransformA(), ///< transformation applied to A fragment -+ TransformB transform_B = TransformB() ///< transformation applied to B fragment -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), -+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), -+ transform_A_(transform_A), -+ transform_B_(transform_B), -+ smem_write_stage_idx(0) -+ { -+ -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ } -+ -+ -+ /// Advance shared memory write-iterators to the next stage -+ CUTLASS_DEVICE -+ void advance_smem_write_stage() -+ { -+ ++this->smem_iterator_A_; -+ ++this->smem_iterator_B_; -+ -+ // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory -+ if (smem_write_stage_idx == 1) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ } -+ -+ smem_write_stage_idx ^= 1; -+ } -+ -+ /// Advance shared memory read- and write-iterators to the next stage -+ CUTLASS_DEVICE -+ void advance_smem_stages() -+ { -+ ++this->smem_iterator_A_; -+ ++this->smem_iterator_B_; -+ -+ // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory -+ if (smem_write_stage_idx == 1) { -+ // wrap write stage -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ } -+ else -+ { -+ // wrap read stage -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); -+ } -+ -+ smem_write_stage_idx ^= 1; -+ } -+ -+ -+ /// GEMM prologue. Bootstrap the global->shared memory pipeline by fetching -+ /// the global fragments needed by the first kStages-1 threadblock mainloop iterations -+ CUTLASS_DEVICE -+ void prologue( -+ IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory -+ IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory -+ int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining -+ { -+ // The last kblock is loaded in the prolog -+ -+ // Load A fragment from global A -+ FragmentA tb_frag_A; -+ tb_frag_A.clear(); -+ iterator_A.load(tb_frag_A); -+ ++iterator_A; -+ -+ // Load B fragment from global B -+ FragmentB tb_frag_B; -+ tb_frag_B.clear(); -+ iterator_B.load(tb_frag_B); -+ ++iterator_B; -+ -+ // Store A and B fragments to shared -+ this->smem_iterator_A_.store(transform_A_(tb_frag_A)); -+ this->smem_iterator_B_.store(transform_B_(tb_frag_B)); -+ -+ // Advance write stage -+ advance_smem_write_stage(); -+ } -+ -+ /// Wait until we have at least one completed global fetch stage -+ CUTLASS_DEVICE -+ void gmem_wait() -+ { -+ __syncthreads(); -+ } -+ -+ -+ /// Perform the specified number of threadblock mainloop iterations of matrix -+ /// multiply-accumulate. Assumes prologue has been initiated. -+ CUTLASS_DEVICE -+ void gemm_iters( -+ int gemm_k_iterations, ///< number of threadblock mainloop iterations -+ FragmentC &accum, ///< [in|out] accumulator tile -+ IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory -+ IteratorB &iterator_B) ///< [in|out] iterator over B operand in global memory -+ { -+ using WarpFragmentA = typename Operator::FragmentA; -+ using WarpFragmentB = typename Operator::FragmentB; -+ -+ // Pair of fragments used to overlap shared memory loads and math instructions -+ WarpFragmentA warp_frag_A[2]; -+ WarpFragmentB warp_frag_B[2]; -+ -+ // Load A fragment from shared A -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_A_.load(warp_frag_A[0]); -+ ++this->warp_tile_iterator_A_; -+ -+ // Load B fragment from shared B -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.load(warp_frag_B[0]); -+ ++this->warp_tile_iterator_B_; -+ -+ // Pair of fragments used to overlap global memory loads and math instructions; -+ FragmentA tb_frag_A; -+ FragmentB tb_frag_B; -+ -+ // Avoid reading out of bounds -+ iterator_A.clear_mask(gemm_k_iterations <= 1); -+ iterator_B.clear_mask(gemm_k_iterations <= 1); -+ -+ // -+ // Mainloop -+ // -+ -+ // Note: The main loop does not support Base::kWarpGemmIterations == 2. -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > 0; --gemm_k_iterations) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group -+ // as the case may be. -+ -+ if (warp_mma_k == Base::kWarpGemmIterations - 1) { -+ -+ // Write fragments to shared memory -+ this->smem_iterator_A_.store(transform_A_(tb_frag_A)); -+ -+ this->smem_iterator_B_.store(transform_B_(tb_frag_B)); -+ -+ // Wait until we have at least one completed global fetch stage -+ gmem_wait(); -+ -+ // Advance smem read and write stages -+ advance_smem_stages(); -+ } -+ -+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ -+ this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ if (warp_mma_k == 0) { -+ -+ // Load fragment from global A -+ tb_frag_A.clear(); -+ iterator_A.load(tb_frag_A); -+ ++iterator_A; -+ -+ // Load fragment from global B -+ tb_frag_B.clear(); -+ iterator_B.load(tb_frag_B); -+ ++iterator_B; -+ -+ // Avoid reading out of bounds if this was the last loop iteration -+ iterator_A.clear_mask(gemm_k_iterations <= 2); -+ iterator_B.clear_mask(gemm_k_iterations <= 2); -+ } -+ -+ warp_mma( -+ accum, -+ warp_frag_A[warp_mma_k % 2], -+ warp_frag_B[warp_mma_k % 2], -+ accum); -+ } -+ } -+ -+ } -+ -+ -+ /// Prepares the class for another prologue. -+ CUTLASS_DEVICE -+ void wind_down() -+ { -+ // First, increment remaining warp tiles to catch it up with the write stage. -+ #pragma unroll -+ for (int warp_mma_k = 1; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) -+ { -+ this->warp_tile_iterator_A_.set_kgroup_index(warp_mma_k); -+ this->warp_tile_iterator_B_.set_kgroup_index(warp_mma_k); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ } -+ -+ // If we bumped the read iterators to the end of the circular buffer, wrap them around to -+ // align them with the write iterators -+ if (smem_write_stage_idx == 0) -+ { -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); -+ } -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ int gemm_k_iterations, ///< number of iterations of the mainloop -+ FragmentC &accum, ///< destination accumulator tile -+ IteratorA iterator_A, ///< iterator over A operand in global memory -+ IteratorB iterator_B, ///< iterator over B operand in global memory -+ FragmentC const &src_accum) ///< source accumulator tile -+ { -+ // Prologue -+ prologue(iterator_A, iterator_B, gemm_k_iterations); -+ -+ // Wait until we have at least one completed global fetch stage -+ gmem_wait(); -+ -+ // Perform accumulation in the 'd' output operand -+ accum = src_accum; -+ -+ // Perform the MAC-iterations -+ gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B); -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_base.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_base.h -new file mode 100644 -index 0000000..d21600e ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_base.h -@@ -0,0 +1,208 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class MmaPlanarComplexBase { -+ public: -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Shape describing the overall GEMM computed from shared memory -+ /// by each warp. -+ using WarpGemm = typename Policy::Operator::Shape; -+ -+ /// Shape describing the number of warps filling the CTA -+ using WarpCount = GemmShape; -+ -+ /// Number of warp-level GEMM oeprations -+ static int const kWarpGemmIterations = -+ (WarpGemm::kK / Operator::Policy::MmaShape::kK); -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Tensor reference to the A operand -+ using TensorRefA = TensorRef; -+ -+ /// Tensor reference to the B operand -+ using TensorRefB = TensorRef; -+ -+ // -+ // Nested structs -+ // -+ -+ /// Shared storage object needed by threadblock-scoped GEMM -+ class SharedStorage { -+ public: -+ // -+ // Type definitions -+ // -+ -+ /// Shape of the A matrix operand in shared memory -+ using ShapeA = MatrixShape; -+ -+ /// Stride to the imaginary part of the A operand -+ static int const kImaginaryStrideA = ShapeA::kCount; -+ -+ /// Shape of the B matrix operand in shared memory -+ using ShapeB = -+ MatrixShape; -+ -+ /// Stride to the imaginary part of the A operand -+ static int const kImaginaryStrideB = ShapeB::kCount; -+ -+ public: -+ // -+ // Data members -+ // -+ -+ /// Buffer for A operand -+ AlignedBuffer operand_A; -+ -+ /// Buffer for B operand -+ AlignedBuffer operand_B; -+ -+ public: -+ -+ // -+ // Methods -+ // -+ -+ /// Returns a layout object for the A matrix -+ CUTLASS_DEVICE -+ static typename Operator::LayoutA LayoutA() { -+ return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); -+ } -+ -+ /// Returns a layout object for the B matrix -+ CUTLASS_HOST_DEVICE -+ static typename Operator::LayoutB LayoutB() { -+ return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); -+ } -+ -+ /// Returns a TensorRef to the A operand -+ CUTLASS_HOST_DEVICE -+ TensorRefA operand_A_ref() { -+ return TensorRefA{operand_A.data(), LayoutA()}; -+ } -+ -+ /// Returns a TensorRef to the B operand -+ CUTLASS_HOST_DEVICE -+ TensorRefB operand_B_ref() { -+ return TensorRefB{operand_B.data(), LayoutB()}; -+ } -+ }; -+ -+ protected: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to load a warp-scoped tile of A operand from shared memory -+ typename Operator::IteratorA warp_tile_iterator_A_; -+ -+ /// Iterator to load a warp-scoped tile of B operand from shared memory -+ typename Operator::IteratorB warp_tile_iterator_B_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ MmaPlanarComplexBase( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), -+ warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) { -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_multistage.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_multistage.h -new file mode 100644 -index 0000000..b7edd51 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_multistage.h -@@ -0,0 +1,640 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/array_planar_complex.h" -+#include "cutlass/functional.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/threadblock/mma_planar_complex_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Cache operation for operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Transformation applied to A -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Transformation applied to B -+ ComplexTransform TransformB = ComplexTransform::kNone -+> -+class MmaPlanarComplexMultistage : -+ public MmaPlanarComplexBase { -+public: -+ ///< Base class -+ using Base = MmaPlanarComplexBase; -+ -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ ///< Iterates over tiles of A operand in global memory -+ using IteratorA = IteratorA_; -+ -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB = IteratorB_; -+ -+ ///< Data type of accumulator matrix -+ using ElementC = ElementC_; -+ -+ ///< Layout of accumulator matrix -+ using LayoutC = LayoutC_; -+ -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ ///< Archtecture tag -+ using ArchTag = arch::Sm80; -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ /// Transformation applied to A -+ static ComplexTransform const kTransformA = TransformA; -+ -+ /// Transformation applied to B -+ static ComplexTransform const kTransformB = TransformB; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of accumulator tile -+ using FragmentC = ArrayPlanarComplex< -+ typename Policy::Operator::FragmentC::Element, -+ Policy::Operator::FragmentC::kElements -+ >; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Internal structure exposed for introspection. -+ struct Detail { -+ -+ static_assert(Base::kWarpGemmIterations > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ -+ /// Number of cp.async instructions to load one stage of operand A -+ static int const TBLoadIterationsA = -+ IteratorA::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const TBLoadIterationsB = -+ IteratorB::ThreadMap::Iterations::kCount; -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ static int const kAccessesPerGroupA = -+ (TBLoadIterationsA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ -+ static int const kAccessesPerGroupB = -+ (TBLoadIterationsB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ }; -+ -+ private: -+ -+ using WarpFragmentA = typename Operator::FragmentA; -+ using WarpFragmentB = typename Operator::FragmentB; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ MmaPlanarComplexMultistage( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), -+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) -+ { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ } -+ -+private: -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance( -+ IteratorA &iterator_A_real, -+ IteratorA &iterator_A_imag, -+ -+ IteratorB &iterator_B_real, -+ IteratorB &iterator_B_imag, -+ -+ int group_start_A = 0, -+ int group_start_B = 0) { -+ -+ iterator_A_real.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); -+ iterator_A_imag.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); -+ this->smem_iterator_A_.set_iteration_index(group_start_A); -+ -+ // Load for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { -+ -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast(this->smem_iterator_A_.get()); -+ -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ -+ auto gmem_ptr_real = iterator_A_real.get(); -+ auto gmem_ptr_imag = iterator_A_imag.get(); -+ -+ bool pred_guard = iterator_A_real.valid(); -+ cutlass::arch::cp_async( -+ dst_ptr + v, -+ gmem_ptr_real, -+ pred_guard); -+ cutlass::arch::cp_async( -+ dst_ptr + v + (Base::SharedStorage::kImaginaryStrideA / IteratorA::ThreadMap::kElementsPerAccess), -+ reinterpret_cast(gmem_ptr_imag), -+ pred_guard); -+ -+ ++iterator_A_real; -+ ++iterator_A_imag; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ -+ iterator_B_real.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); -+ iterator_B_imag.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); -+ this->smem_iterator_B_.set_iteration_index(group_start_B); -+ -+ // Load for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast(this->smem_iterator_B_.get()); -+ -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ auto gmem_ptr_real = iterator_B_real.get(); -+ auto gmem_ptr_imag = iterator_B_imag.get(); -+ -+ bool pred_guard = iterator_B_real.valid(); -+ cutlass::arch::cp_async( -+ dst_ptr + v, -+ gmem_ptr_real, -+ pred_guard); -+ cutlass::arch::cp_async( -+ dst_ptr + v + (Base::SharedStorage::kImaginaryStrideB / IteratorB::ThreadMap::kElementsPerAccess), -+ reinterpret_cast(gmem_ptr_imag), -+ pred_guard); -+ -+ ++iterator_B_real; -+ ++iterator_B_imag; -+ } -+ ++this->smem_iterator_B_; -+ } -+ } -+ -+ CUTLASS_DEVICE -+ void warp_mma_planar_complex( -+ Operator & warp_mma, -+ FragmentC &accum, -+ WarpFragmentA const & real_A, -+ WarpFragmentA const & imag_A, -+ WarpFragmentB const & real_B, -+ WarpFragmentB const & imag_B) { -+ -+ cutlass::negate> neg_op_B; -+ -+ WarpFragmentB neg_real_B = neg_op_B(real_B); -+ WarpFragmentB neg_imag_B = neg_op_B(imag_B); -+ -+ warp_mma(accum.real, real_A, real_B, accum.real); -+ -+ if (kTransformB == ComplexTransform::kNone) { -+ warp_mma(accum.imag, real_A, imag_B, accum.imag); -+ } -+ else { -+ warp_mma(accum.imag, real_A, neg_imag_B, accum.imag); -+ } -+ -+ if (kTransformA == ComplexTransform::kNone) { -+ warp_mma(accum.imag, imag_A, real_B, accum.imag); -+ } -+ else { -+ warp_mma(accum.imag, imag_A, neg_real_B, accum.imag); -+ } -+ -+ if (kTransformA == ComplexTransform::kNone ^ kTransformB == ComplexTransform::kNone) { -+ warp_mma(accum.real, imag_A, imag_B, accum.real); -+ } -+ else { -+ warp_mma(accum.real, imag_A, neg_imag_B, accum.real); -+ } -+ } -+ -+public: -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations, -+ ///< destination accumulator tile -+ FragmentC &accum, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A_real, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A_imag, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B_real, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B_imag, -+ ///< initial value of accumulator -+ FragmentC const &src_accum) { -+ -+ // -+ // Prologue -+ // -+ -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; -+ ++stage, --gemm_k_iterations) { -+ -+ iterator_A_real.clear_mask(gemm_k_iterations == 0); -+ iterator_A_imag.clear_mask(gemm_k_iterations == 0); -+ iterator_B_real.clear_mask(gemm_k_iterations == 0); -+ iterator_B_imag.clear_mask(gemm_k_iterations == 0); -+ -+ iterator_A_real.set_iteration_index(0); -+ iterator_A_imag.set_iteration_index(0); -+ -+ this->smem_iterator_A_.set_iteration_index(0); -+ -+ // Load for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::TBLoadIterationsA; ++j) { -+ -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast(this->smem_iterator_A_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; -+ -+ bool pred_guard = iterator_A_real.valid(); -+ -+ auto src_ptr_real = iterator_A_real.get(); -+ auto src_ptr_imag = iterator_A_imag.get(); -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, src_ptr_real, pred_guard); -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v + -+ Base::SharedStorage::kImaginaryStrideA / -+ IteratorA::ThreadMap::kElementsPerAccess, -+ reinterpret_cast(src_ptr_imag), -+ pred_guard); -+ -+ ++iterator_A_real; -+ ++iterator_A_imag; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ -+ iterator_B_real.set_iteration_index(0); -+ iterator_B_imag.set_iteration_index(0); -+ -+ this->smem_iterator_B_.set_iteration_index(0); -+ -+ // Load for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::TBLoadIterationsB; ++j) { -+ -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast(this->smem_iterator_B_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; -+ -+ bool pred_guard = iterator_B_real.valid(); -+ -+ auto src_ptr_real = iterator_B_real.get(); -+ auto src_ptr_imag = iterator_B_imag.get(); -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, src_ptr_real, pred_guard); -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v + -+ Base::SharedStorage::kImaginaryStrideB / -+ IteratorB::ThreadMap::kElementsPerAccess, -+ reinterpret_cast(src_ptr_imag), -+ pred_guard); -+ -+ ++iterator_B_real; -+ ++iterator_B_imag; -+ } -+ -+ ++this->smem_iterator_B_; -+ } -+ -+ // Move to the next stage -+ iterator_A_real.add_tile_offset({0, 1}); -+ iterator_A_imag.add_tile_offset({0, 1}); -+ -+ iterator_B_real.add_tile_offset({1, 0}); -+ iterator_B_imag.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Inserts a memory fence between stages of cp.async instructions -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // Perform accumulation in the 'd' output operand -+ accum = src_accum; -+ -+ // Blocks until all but kStages-2 cp.async stages have committed. -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ -+ WarpFragmentA warp_frag_real_A[2]; -+ WarpFragmentA warp_frag_imag_A[2]; -+ -+ WarpFragmentB warp_frag_real_B[2]; -+ WarpFragmentB warp_frag_imag_B[2]; -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A_.load(warp_frag_real_A[0]); -+ this->warp_tile_iterator_A_.load_with_pointer_offset(warp_frag_imag_A[0], Base::SharedStorage::kImaginaryStrideA); -+ -+ this->warp_tile_iterator_B_.load(warp_frag_real_B[0]); -+ this->warp_tile_iterator_B_.load_with_pointer_offset(warp_frag_imag_B[0], Base::SharedStorage::kImaginaryStrideB); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ iterator_A_real.clear_mask(gemm_k_iterations == 0); -+ iterator_A_imag.clear_mask(gemm_k_iterations == 0); -+ iterator_B_real.clear_mask(gemm_k_iterations == 0); -+ iterator_B_imag.clear_mask(gemm_k_iterations == 0); -+ -+ // Start issuing the first group of the next stage outside of the mainloop -+ copy_tiles_and_advance(iterator_A_real, iterator_A_imag, iterator_B_real, iterator_B_imag); -+ -+ Operator warp_mma; -+ -+ int smem_write_stage_idx = Base::kStages - 1; -+ int smem_read_stage_idx = 0; -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > (-Base::kStages + 1);) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; -+ ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. -+ -+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ -+ this->warp_tile_iterator_A_.load(warp_frag_real_A[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_A_.load_with_pointer_offset(warp_frag_imag_A[(warp_mma_k + 1) % 2], Base::SharedStorage::kImaginaryStrideA); -+ -+ this->warp_tile_iterator_B_.load(warp_frag_real_B[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B_.load_with_pointer_offset(warp_frag_imag_B[(warp_mma_k + 1) % 2], Base::SharedStorage::kImaginaryStrideB); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ // Issue global->shared copies for the next stage -+ int group_start_iteration_A, group_start_iteration_B; -+ -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations) { -+ group_start_iteration_A = 0; -+ group_start_iteration_B = 0; -+ } -+ else { -+ group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; -+ } -+ -+ copy_tiles_and_advance( -+ iterator_A_real, -+ iterator_A_imag, -+ iterator_B_real, -+ iterator_B_imag, -+ group_start_iteration_A, -+ group_start_iteration_B); -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations) { -+ // Inserts a memory fence between stages of cp.async instructions -+ cutlass::arch::cp_async_fence(); -+ -+ // Blocks until all but kStages-2 cp.async stages have committed. -+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_A_real.add_tile_offset({0, 1}); -+ iterator_A_imag.add_tile_offset({0, 1}); -+ -+ iterator_B_real.add_tile_offset({1, 0}); -+ iterator_B_imag.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations}); -+ -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ --gemm_k_iterations; -+ iterator_A_real.clear_mask(gemm_k_iterations == 0); -+ iterator_A_imag.clear_mask(gemm_k_iterations == 0); -+ iterator_B_real.clear_mask(gemm_k_iterations == 0); -+ iterator_B_imag.clear_mask(gemm_k_iterations == 0); -+ } -+ -+ warp_mma_planar_complex( -+ warp_mma, -+ accum, -+ warp_frag_real_A[warp_mma_k % 2], -+ warp_frag_imag_A[warp_mma_k % 2], -+ warp_frag_real_B[warp_mma_k % 2], -+ warp_frag_imag_B[warp_mma_k % 2]); -+ } -+ -+ } -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_pipelined.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_pipelined.h -new file mode 100644 -index 0000000..160c548 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_planar_complex_pipelined.h -@@ -0,0 +1,424 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/aligned_buffer.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/threadblock/mma_planar_complex_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Transformation applied to A -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Transformation applied to B -+ ComplexTransform TransformB = ComplexTransform::kNone -+> -+class MmaPlanarComplexPipelined : -+ public MmaPlanarComplexBase { -+public: -+ ///< Base class -+ using Base = MmaPlanarComplexBase; -+ -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ ///< Iterates over tiles of A operand in global memory -+ using IteratorA = IteratorA_; -+ -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB = IteratorB_; -+ -+ ///< Data type of accumulator matrix -+ using ElementC = ElementC_; -+ -+ ///< Layout of accumulator matrix -+ using LayoutC = LayoutC_; -+ -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ using ArchTag = typename Policy::Operator::ArchTag; -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ -+ /// Transformation applied to A -+ static ComplexTransform const kTransformA = TransformA; -+ -+ /// Transformation applied to B -+ static ComplexTransform const kTransformB = TransformB; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of accumulator tile -+ using FragmentC = ArrayPlanarComplex< -+ typename Policy::Operator::FragmentC::Element, -+ Policy::Operator::FragmentC::kElements -+ >; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ private: -+ -+ using FragmentA = typename IteratorA::Fragment; -+ using FragmentB = typename IteratorB::Fragment; -+ using WarpFragmentA = typename Operator::FragmentA; -+ using WarpFragmentB = typename Operator::FragmentB; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ MmaPlanarComplexPipelined( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), -+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) -+ { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ } -+ -+private: -+ -+ CUTLASS_DEVICE -+ void warp_mma_planar_complex( -+ Operator & warp_mma, -+ FragmentC &accum, -+ WarpFragmentA const & real_A, -+ WarpFragmentA const & imag_A, -+ WarpFragmentB const & real_B, -+ WarpFragmentB const & imag_B) { -+ -+ cutlass::negate> neg_op_B; -+ -+ WarpFragmentB neg_real_B = neg_op_B(real_B); -+ WarpFragmentB neg_imag_B = neg_op_B(imag_B); -+ -+ warp_mma(accum.real, real_A, real_B, accum.real); -+ -+ if (kTransformB == ComplexTransform::kNone) { -+ warp_mma(accum.imag, real_A, imag_B, accum.imag); -+ } -+ else { -+ warp_mma(accum.imag, real_A, neg_imag_B, accum.imag); -+ } -+ -+ if (kTransformA == ComplexTransform::kNone) { -+ warp_mma(accum.imag, imag_A, real_B, accum.imag); -+ } -+ else { -+ warp_mma(accum.imag, imag_A, neg_real_B, accum.imag); -+ } -+ -+ if (kTransformA == ComplexTransform::kNone ^ kTransformB == ComplexTransform::kNone) { -+ warp_mma(accum.real, imag_A, imag_B, accum.real); -+ } -+ else { -+ warp_mma(accum.real, imag_A, neg_imag_B, accum.real); -+ } -+ } -+ -+public: -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations, -+ ///< destination accumulator tile -+ FragmentC &accum, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A_real, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A_imag, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B_real, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B_imag, -+ ///< initial value of accumulator -+ FragmentC const &src_accum) { -+ -+ // -+ // Prologue -+ // -+ -+ // Perform accumulation in the 'd' output operand -+ accum = src_accum; -+ -+ FragmentA tb_frag_A_real; -+ FragmentA tb_frag_A_imag; -+ -+ FragmentB tb_frag_B_real; -+ FragmentB tb_frag_B_imag; -+ -+ tb_frag_A_real.clear(); -+ tb_frag_A_imag.clear(); -+ -+ tb_frag_B_real.clear(); -+ tb_frag_B_imag.clear(); -+ -+ // The last kblock is loaded in the prolog -+ iterator_A_real.load(tb_frag_A_real); -+ iterator_A_imag.load(tb_frag_A_imag); -+ -+ iterator_B_real.load(tb_frag_B_real); -+ iterator_B_imag.load(tb_frag_B_imag); -+ -+ ++iterator_A_real; -+ ++iterator_A_imag; -+ -+ ++iterator_B_real; -+ ++iterator_B_imag; -+ -+ this->smem_iterator_A_.store(tb_frag_A_real); -+ this->smem_iterator_A_.store_with_pointer_offset(tb_frag_A_imag, Base::SharedStorage::kImaginaryStrideA); -+ -+ this->smem_iterator_B_.store(tb_frag_B_real); -+ this->smem_iterator_B_.store_with_pointer_offset(tb_frag_B_imag, Base::SharedStorage::kImaginaryStrideB); -+ -+ ++this->smem_iterator_A_; -+ ++this->smem_iterator_B_; -+ -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math instructions -+ WarpFragmentA warp_frag_real_A[2]; -+ WarpFragmentA warp_frag_imag_A[2]; -+ -+ WarpFragmentB warp_frag_real_B[2]; -+ WarpFragmentB warp_frag_imag_B[2]; -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A_.load(warp_frag_real_A[0]); -+ this->warp_tile_iterator_A_.load_with_pointer_offset(warp_frag_imag_A[0], Base::SharedStorage::kImaginaryStrideA); -+ -+ this->warp_tile_iterator_B_.load(warp_frag_real_B[0]); -+ this->warp_tile_iterator_B_.load_with_pointer_offset(warp_frag_imag_B[0], Base::SharedStorage::kImaginaryStrideB); -+ -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ Operator warp_mma; -+ -+ int smem_write_stage_idx = 1; -+ -+ // Avoid reading out of bounds -+ iterator_A_real.clear_mask(gemm_k_iterations <= 1); -+ iterator_A_imag.clear_mask(gemm_k_iterations <= 1); -+ -+ iterator_B_real.clear_mask(gemm_k_iterations <= 1); -+ iterator_B_imag.clear_mask(gemm_k_iterations <= 1); -+ -+ // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing -+ // shared memory loads (which have the tighest latency requirement). -+ -+ // -+ // Mainloop -+ // -+ -+ // Note: The main loop does not support Base::kWarpGemmIterations == 2. -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > 0; --gemm_k_iterations) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group -+ // as the case may be. -+ -+ if (warp_mma_k == Base::kWarpGemmIterations - 1) { -+ -+ // Write fragments to shared memory -+ this->smem_iterator_A_.store(tb_frag_A_real); -+ this->smem_iterator_A_.store_with_pointer_offset(tb_frag_A_imag, Base::SharedStorage::kImaginaryStrideA); -+ -+ this->smem_iterator_B_.store(tb_frag_B_real); -+ this->smem_iterator_B_.store_with_pointer_offset(tb_frag_B_imag, Base::SharedStorage::kImaginaryStrideB); -+ -+ __syncthreads(); -+ -+ ++this->smem_iterator_B_; -+ ++this->smem_iterator_A_; -+ -+ // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory -+ if (smem_write_stage_idx == 1) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ } -+ else { -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, -+ 0}); -+ } -+ -+ smem_write_stage_idx ^= 1; -+ } -+ -+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ -+ this->warp_tile_iterator_A_.load(warp_frag_real_A[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_A_.load_with_pointer_offset(warp_frag_imag_A[(warp_mma_k + 1) % 2], Base::SharedStorage::kImaginaryStrideA); -+ -+ this->warp_tile_iterator_B_.load(warp_frag_real_B[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B_.load_with_pointer_offset(warp_frag_imag_B[(warp_mma_k + 1) % 2], Base::SharedStorage::kImaginaryStrideB); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ if (warp_mma_k == 0) { -+ -+ iterator_A_real.load(tb_frag_A_real); -+ iterator_A_imag.load(tb_frag_A_imag); -+ -+ iterator_B_real.load(tb_frag_B_real); -+ iterator_B_imag.load(tb_frag_B_imag); -+ -+ ++iterator_A_real; -+ ++iterator_A_imag; -+ ++iterator_B_real; -+ ++iterator_B_imag; -+ -+ // Avoid reading out of bounds if this was the last loop iteration -+ iterator_A_real.clear_mask(gemm_k_iterations <= 2); -+ iterator_A_imag.clear_mask(gemm_k_iterations <= 2); -+ iterator_B_real.clear_mask(gemm_k_iterations <= 2); -+ iterator_B_imag.clear_mask(gemm_k_iterations <= 2); -+ } -+ -+ warp_mma_planar_complex( -+ warp_mma, -+ accum, -+ warp_frag_real_A[warp_mma_k % 2], -+ warp_frag_imag_A[warp_mma_k % 2], -+ warp_frag_real_B[warp_mma_k % 2], -+ warp_frag_imag_B[warp_mma_k % 2]); -+ } -+ } -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_singlestage.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_singlestage.h -new file mode 100644 -index 0000000..3ce8ac8 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_singlestage.h -@@ -0,0 +1,265 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/aligned_buffer.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/threadblock/mma_base.h" -+ -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+class MmaSingleStage : public MmaBase { -+public: -+ -+ ///< Base class -+ using Base = MmaBase; -+ -+ using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory -+ using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory -+ using ElementC = ElementC_; ///< Data type of accumulator matrix -+ using LayoutC = LayoutC_; ///< Layout of accumulator matrix -+ using Policy = Policy_; ///< Policy describing tuning details -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of operand A loaded from global memory -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Fragment of operand B loaded from global memory -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Fragment of accumulator tile -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ using ArchTag = arch::Sm70; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = Operator::kTransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = Operator::kTransformB; -+ -+ // staticaly assert kStages for MmaSingleStage is 1 (single stage mma pipeline) -+ static_assert((Base::kStages==1), "MmaSingleStage requires kStages set to value 1"); -+private: -+ -+ using WarpFragmentA = typename Operator::FragmentA; -+ using WarpFragmentB = typename Operator::FragmentB; -+ -+protected: -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ MmaSingleStage( -+ typename Base::SharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ int thread_idx, ///< ID within the threadblock -+ int warp_idx, ///< ID of warp -+ int lane_idx ///< ID of each thread within a warp -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), -+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { -+ -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ int gemm_k_iterations, ///< number of iterations of the mainloop -+ FragmentC &accum, ///< destination accumulator tile -+ IteratorA iterator_A, ///< iterator over A operand in global memory -+ IteratorB iterator_B, ///< iterator over B operand in global memory -+ FragmentC const &src_accum) { ///< source accumualtor tile -+ -+ // -+ // Prologue -+ // -+ -+ // Perform accumulation in the 'd' output operand -+ accum = src_accum; -+ -+ FragmentA tb_frag_A; -+ FragmentB tb_frag_B; -+ -+ tb_frag_A.clear(); -+ tb_frag_B.clear(); -+ -+ // The last kblock is loaded in the prolog -+ iterator_A.load(tb_frag_A); -+ iterator_B.load(tb_frag_B); -+ -+ ++iterator_A; -+ ++iterator_B; -+ -+ // Pair of fragments used to overlap shared memory loads and math instructions -+ WarpFragmentA warp_frag_A; -+ WarpFragmentB warp_frag_B; -+ -+ Operator warp_mma; -+ -+ // Avoid reading out of bounds -+ iterator_A.clear_mask(gemm_k_iterations <= 1); -+ iterator_B.clear_mask(gemm_k_iterations <= 1); -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > 0; --gemm_k_iterations) { -+ this->smem_iterator_A_.store(tb_frag_A); -+ this->smem_iterator_B_.store(tb_frag_B); -+ -+ __syncthreads(); -+ -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group -+ // as the case may be. -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(warp_mma_k % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.set_kgroup_index(warp_mma_k % Base::kWarpGemmIterations); -+ -+ this->warp_tile_iterator_A_.load(warp_frag_A); -+ this->warp_tile_iterator_B_.load(warp_frag_B); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ warp_mma(accum, warp_frag_A, warp_frag_B, accum); -+ } -+ -+ // Add negative offsets to return smem load iterators to the 'start' of the shared memory -+ this->warp_tile_iterator_A_.add_tile_offset({0, -Policy::kPartitionsK * Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B_.add_tile_offset({-Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); -+ -+ __syncthreads(); -+ -+ iterator_A.load(tb_frag_A); -+ iterator_B.load(tb_frag_B); -+ -+ ++iterator_A; -+ ++iterator_B; -+ -+ // Avoid reading out of bounds if this was the last loop iteration -+ iterator_A.clear_mask(gemm_k_iterations <= 2); -+ iterator_B.clear_mask(gemm_k_iterations <= 2); -+ } -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_softmax_mainloop_fusion_multistage.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_softmax_mainloop_fusion_multistage.h -new file mode 100644 -index 0000000..905283e ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_softmax_mainloop_fusion_multistage.h -@@ -0,0 +1,751 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+ -+ It loads two loop invariant vectors, norm and sum, in the prologue and -+ stores them in the register file. We will call elementwise operation to -+ apply norm and sum between ldmatrix and warp mma. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h" -+#include "cutlass/gemm/threadblock/mma_base.h" -+#include "cutlass/gemm/warp/softmax_scale_bias_transform.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class MmaMainloopFusionBase { -+ public: -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Shape describing the overall GEMM computed from shared memory -+ /// by each warp. -+ using WarpGemm = typename Policy::Operator::Shape; -+ -+ /// Shape describing the number of warps filling the CTA -+ using WarpCount = cutlass::gemm::GemmShape; -+ -+ /// Number of warp-level GEMM oeprations -+ static int const kWarpGemmIterations = -+ (WarpGemm::kK / Operator::Policy::MmaShape::kK); -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Tensor reference to the A operand -+ using TensorRefA = TensorRef; -+ -+ /// Tensor reference to the B operand -+ using TensorRefB = TensorRef; -+ -+ // -+ // Nested structs -+ // -+ -+ /// Shared storage object needed by threadblock-scoped GEMM -+ class SharedStorage { -+ public: -+ // -+ // Type definitions -+ // -+ -+ /// Shape of the A matrix operand in shared memory -+ using ShapeA = MatrixShape; -+ -+ /// Shape of the B matrix operand in shared memory -+ using ShapeB = -+ MatrixShape; -+ -+ public: -+ // -+ // Data members -+ // -+ -+ /// Buffer for A operand -+ AlignedBuffer operand_A; -+ -+ /// Buffer for B operand -+ AlignedBuffer operand_B; -+ -+ public: -+ -+ // -+ // Methods -+ // -+ -+ /// Returns a layout object for the A matrix -+ CUTLASS_DEVICE -+ static typename Operator::LayoutA LayoutA() { -+ return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); -+ } -+ -+ /// Returns a layout object for the B matrix -+ CUTLASS_HOST_DEVICE -+ static typename Operator::LayoutB LayoutB() { -+ return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); -+ } -+ -+ /// Returns a TensorRef to the A operand -+ CUTLASS_HOST_DEVICE -+ TensorRefA operand_A_ref() { -+ return TensorRefA{operand_A.data(), LayoutA()}; -+ } -+ -+ /// Returns a TensorRef to the B operand -+ CUTLASS_HOST_DEVICE -+ TensorRefB operand_B_ref() { -+ return TensorRefB{operand_B.data(), LayoutB()}; -+ } -+ }; -+ -+ protected: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to load a warp-scoped tile of A operand from shared memory -+ typename Operator::IteratorA warp_tile_iterator_A_; -+ -+ /// Iterator to load a warp-scoped tile of B operand from shared memory -+ typename Operator::IteratorB warp_tile_iterator_B_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ MmaMainloopFusionBase( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx) -+ : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), -+ warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} -+}; -+ -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Cache operation for operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// Iterates over vectors of var and mean vector in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorNormSum_, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Whether problem has been transformed. This determines to which operand -+ /// the softmax is applied. -+ bool InternalTranspose, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, -+ /// Used for partial specialization -+ typename Enable = bool> -+class MmaSoftmaxMainloopFusionMultistage : -+ public MmaMainloopFusionBase { -+public: -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ ///< Iterates over tiles of A operand in global memory -+ using IteratorA = IteratorA_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB = IteratorB_; -+ ///< Iterates over tiles of the var and mean vectors in global memory -+ using IteratorNormSum = IteratorNormSum_; -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ ///< Base class -+ using Base = MmaMainloopFusionBase; -+ -+ ///< Data type of accumulator matrix -+ using ElementC = ElementC_; -+ ///< Layout of accumulator matrix -+ using LayoutC = LayoutC_; -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of accumulator tile -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Minimum architecture is Sm80 to support cp.async -+ using ArchTag = arch::Sm80; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = Operator::kTransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = Operator::kTransformB; -+ -+ /// Internal structure exposed for introspection. -+ struct Detail { -+ -+ static_assert(Base::kWarpGemmIterations > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ -+ /// Number of cp.async instructions to load one stage of operand A -+ static int const AsyncCopyIterationsPerStageA = -+ IteratorA::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const AsyncCopyIterationsPerStageB = -+ IteratorB::ThreadMap::Iterations::kCount; -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Number of cp.async instructions to load on group of operand A -+ static int const kAccessesPerGroupA = -+ (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB = -+ (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ }; -+ -+ private: -+ -+ using WarpLoadedFragmentA = typename Operator::FragmentA; -+ using WarpLoadedFragmentB = typename Operator::FragmentB; -+ using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; -+ using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; -+ -+ using WarpLoadedFragmentNormSum = typename IteratorNormSum::Fragment; -+ -+ static bool const kInternalTranspose = InternalTranspose; -+ -+ using SoftmaxFragment = typename platform::conditional::type; -+ -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+ int warp_idx_m_; -+ -+ int warp_idx_n_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ MmaSoftmaxMainloopFusionMultistage( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), -+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) -+ { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ warp_idx_m_ = warp_idx_mn % Base::WarpCount::kM; -+ warp_idx_n_ = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {warp_idx_m_, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n_}); -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance(IteratorA &iterator_A, -+ IteratorB &iterator_B, -+ int group_start_A = 0, int group_start_B = 0) { -+ iterator_A.set_iteration_index(group_start_A * -+ IteratorA::kAccessesPerVector); -+ this->smem_iterator_A_.set_iteration_index(group_start_A); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { -+ if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_A.get(); -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, iterator_A.valid()); -+ } else { -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_A.valid()); -+ } -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ } -+ -+ iterator_B.set_iteration_index(group_start_B * -+ IteratorB::kAccessesPerVector); -+ this->smem_iterator_B_.set_iteration_index(group_start_B); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { -+ if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_B.get(); -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, iterator_B.valid()); -+ } else { -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_B.valid()); -+ } -+ -+ ++iterator_B; -+ } -+ ++this->smem_iterator_B_; -+ } -+ } -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations, -+ ///< destination accumulator tile -+ FragmentC &accum, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B, -+ ///< iterator over B operand in global memory -+ IteratorNormSum iterator_norm_sum, -+ ///< initial value of accumulator -+ FragmentC const &src_accum) { -+ -+ // -+ // Prologue -+ // -+ // Issue several complete stages -+ -+ WarpLoadedFragmentNormSum warp_loaded_frag_norm_sum; -+ iterator_norm_sum.add_tile_offset({0, warp_idx_m_}); -+ iterator_norm_sum.load(warp_loaded_frag_norm_sum); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; -+ ++stage, --gemm_k_iterations) { -+ -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ -+ iterator_A.set_iteration_index(0); -+ this->smem_iterator_A_.set_iteration_index(0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_A.get(), iterator_A.valid()); -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ -+ iterator_B.set_iteration_index(0); -+ this->smem_iterator_B_.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_B.get(), iterator_B.valid()); -+ -+ ++iterator_B; -+ } -+ -+ ++this->smem_iterator_B_; -+ } -+ -+ // Move to the next stage -+ iterator_A.add_tile_offset({0, 1}); -+ iterator_B.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Defines the boundary of a stage of cp.async. -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // Perform accumulation in the 'd' output operand -+ accum = src_accum; -+ -+ // Waits until kStages-2 stages have committed. -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA warp_loaded_frag_A[2]; -+ WarpLoadedFragmentB warp_loaded_frag_B[2]; -+ WarpTransformedFragmentA warp_transformed_frag_A[2]; -+ WarpTransformedFragmentB warp_transformed_frag_B[2]; -+ -+ Operator warp_mma; -+ cutlass::gemm::warp::SoftmaxScaleBiasTransform< -+ SoftmaxFragment, WarpLoadedFragmentNormSum> elementwise_transform; -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ -+ // Start issuing the first group of the next stage outside of the mainloop -+ copy_tiles_and_advance(iterator_A, iterator_B); -+ -+ int smem_write_stage_idx = Base::kStages - 1; -+ int smem_read_stage_idx = 0; -+ -+ warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], -+ warp_loaded_frag_A[0], warp_loaded_frag_B[0]); -+ -+ if (kInternalTranspose) { -+ elementwise_transform(warp_transformed_frag_B[0], -+ warp_loaded_frag_norm_sum); -+ } else { -+ elementwise_transform(warp_transformed_frag_A[0], -+ warp_loaded_frag_norm_sum); -+ } -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > (-Base::kStages + 1);) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; -+ ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. -+ -+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ if (warp_mma_k > 0) { -+ warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ warp_loaded_frag_A[warp_mma_k % 2], -+ warp_loaded_frag_B[warp_mma_k % 2]); -+ -+ if (kInternalTranspose) { -+ elementwise_transform(warp_transformed_frag_B[warp_mma_k % 2], -+ warp_loaded_frag_norm_sum); -+ } else { -+ elementwise_transform(warp_transformed_frag_A[warp_mma_k % 2], -+ warp_loaded_frag_norm_sum); -+ } -+ } -+ -+ // Issue global->shared copies for the next stage -+ int group_start_iteration_A, group_start_iteration_B; -+ -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations) { -+ group_start_iteration_A = 0; -+ group_start_iteration_B = 0; -+ } else { -+ group_start_iteration_A = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB; -+ } -+ -+ copy_tiles_and_advance(iterator_A, iterator_B, -+ group_start_iteration_A, -+ group_start_iteration_B); -+ -+ warp_mma( -+ accum, -+ warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ accum -+ ); -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations) { -+ -+ // Inserts a memory fence between stages of cp.async instructions. -+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages have committed. -+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_A.add_tile_offset({0, 1}); -+ iterator_B.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ --gemm_k_iterations; -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ } -+ -+ // Do any conversions feeding the first stage at the end of the loop so -+ // we can start right away on mma instructions -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations) { -+ warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ -+ if (kInternalTranspose) { -+ elementwise_transform(warp_transformed_frag_B[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_norm_sum); -+ } else { -+ elementwise_transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_norm_sum); -+ } -+ } -+ } -+ -+ } -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ // commit and drain all pending and predicated cp.async pnz from the GEMM mainloop -+ cutlass::arch::cp_async_fence(); -+ cutlass::arch::cp_async_wait<0>(); -+ __syncthreads(); -+ } -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_sparse_base.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_sparse_base.h -new file mode 100644 -index 0000000..9f82a7f ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_sparse_base.h -@@ -0,0 +1,273 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Policy object describing MmaTensorOp -+template < -+ /// Warp-level GEMM operator (concept: gemm::warp::Mma) -+ typename Operator_, -+ /// Padding used for A operand in shared memory (concept: MatrixShape) -+ typename SmemPaddingA_, -+ /// Padding used for B operand in shared memory (concept: MatrixShape) -+ typename SmemPaddingB_, -+ /// Padding used for E operand in shared memory (concept: MatrixShape) -+ typename SmemPaddingE_, -+ /// Number of partitions of K dimension of GEMM -+ int PartitionsK = 1> -+struct SparseMmaPolicy { -+ /// Warp-level GEMM operator (concept: gemm::warp::MmaTensorOp or gemm::warp::MmaSimt) -+ using Operator = Operator_; -+ -+ /// Padding used for A operand in shared memory -+ using SmemPaddingA = SmemPaddingA_; -+ -+ /// Padding used for B operand in shared memory -+ using SmemPaddingB = SmemPaddingB_; -+ -+ /// Padding used for B operand in shared memory -+ using SmemPaddingE = SmemPaddingE_; -+ -+ /// Number of partitions of K dimension -+ static int const kPartitionsK = PartitionsK; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class SparseMmaBase { -+ public: -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// Shape describing the overall GEMM computed from shared memory -+ /// by each warp. -+ using WarpGemm = typename Policy::Operator::Shape; -+ -+ /// Shape describing the number of warps filling the CTA -+ using WarpCount = GemmShape; -+ -+ /// Number of warp-level GEMM oeprations -+ static int const kWarpGemmIterations = -+ (WarpGemm::kK / Operator::Policy::MmaShape::kK); -+ -+ static_assert(kWarpGemmIterations > 1, -+ "The pipelined structure requires at least two warp-level " -+ "GEMM operations."); -+ -+ static_assert((kWarpGemmIterations % 2) == 0, -+ "Inner loop iteration must be an even number."); -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ static int const kSparse = Operator::kSparse; -+ -+ static int const kElementsPerElementE = Operator::kElementsPerElementE; -+ -+ /// Tensor reference to the A operand -+ using TensorRefA = TensorRef; -+ -+ /// Tensor reference to the B operand -+ using TensorRefB = TensorRef; -+ -+ /// Tensor reference to the E operand -+ using TensorRefE = TensorRef; -+ -+ // -+ // Nested structs -+ // -+ -+ /// Shared storage object needed by threadblock-scoped GEMM -+ class SharedStorage { -+ public: -+ // -+ // Type definitions -+ // -+ -+ /// Shape of the A matrix operand in shared memory -+ using ShapeA = MatrixShape; -+ -+ /// Shape of the B matrix operand in shared memory -+ using ShapeB = -+ MatrixShape; -+ -+ /// Shape of the E matrix operand in shared memory -+ using ShapeE = -+ MatrixShape; -+ -+ public: -+ // -+ // Data members -+ // -+ -+ /// Buffer for A operand -+ AlignedBuffer operand_A; -+ -+ /// Buffer for B operand -+ AlignedBuffer operand_B; -+ -+ /// Buffer for E operand -+ AlignedBuffer operand_E; -+ -+ public: -+ -+ // -+ // Methods -+ // -+ -+ /// Returns a layout object for the A matrix -+ CUTLASS_DEVICE -+ static typename Operator::LayoutA LayoutA() { -+ return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); -+ } -+ -+ /// Returns a layout object for the B matrix -+ CUTLASS_HOST_DEVICE -+ static typename Operator::LayoutB LayoutB() { -+ return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); -+ } -+ -+ /// Returns a layout object for the E matrix -+ CUTLASS_HOST_DEVICE -+ static typename Operator::LayoutE LayoutE() { -+ return Operator::LayoutE::packed({ShapeE::kRow, ShapeE::kColumn}); -+ } -+ -+ /// Returns a TensorRef to the A operand -+ CUTLASS_HOST_DEVICE -+ TensorRefA operand_A_ref() { -+ return TensorRefA{operand_A.data(), LayoutA()}; -+ } -+ -+ /// Returns a TensorRef to the B operand -+ CUTLASS_HOST_DEVICE -+ TensorRefB operand_B_ref() { -+ return TensorRefB{operand_B.data(), LayoutB()}; -+ } -+ -+ /// Returns a TensorRef to the E operand -+ CUTLASS_HOST_DEVICE -+ TensorRefE operand_E_ref() { -+ return TensorRefE{operand_E.data(), LayoutE()}; -+ } -+ }; -+ -+ protected: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to load a warp-scoped tile of A operand from shared memory -+ typename Operator::IteratorA warp_tile_iterator_A_; -+ -+ /// Iterator to load a warp-scoped tile of B operand from shared memory -+ typename Operator::IteratorB warp_tile_iterator_B_; -+ -+ /// Iterator to load a warp-scoped tile of E operand from shared memory -+ typename Operator::IteratorE warp_tile_iterator_E_; -+ -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ SparseMmaBase( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), -+ warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx), -+ warp_tile_iterator_E_(shared_storage.operand_E_ref(), lane_idx) { -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_sparse_multistage.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_sparse_multistage.h -new file mode 100644 -index 0000000..beb58c8 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_sparse_multistage.h -@@ -0,0 +1,662 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/threadblock/mma_sparse_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Cache operation for operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Iterates over tiles of E operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorE_, -+ /// Iterates over tiles of E operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorE_, -+ /// Cache operation for operand E -+ cutlass::arch::CacheOperation::Kind CacheOpE, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Used for partial specialization -+ typename Enable = bool> -+class SparseMmaMultistage : -+ public SparseMmaBase { -+public: -+ ///< Base class -+ using Base = SparseMmaBase; -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ ///< Iterates over tiles of A operand in global memory -+ using IteratorA = IteratorA_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB = IteratorB_; -+ ///< Iterates over tiles of E operand in global memory -+ using IteratorE = IteratorE_; -+ ///< Data type of accumulator matrix -+ using ElementC = ElementC_; -+ ///< Layout of accumulator matrix -+ using LayoutC = LayoutC_; -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ using SmemIteratorE = SmemIteratorE_; -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpE = CacheOpE; -+ -+ static int const kSparse = Policy::Operator::kSparse; -+ static int const kMetaSizeInBits = Policy::Operator::kMetaSizeInBits; -+ static int const kMaxID2 = Policy::Operator::kMaxID2; -+ static int const kElementsPerElementE = -+ Policy::Operator::kElementsPerElementE; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of accumulator tile -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ /// ElementE -+ using ElementE = typename IteratorE::Element; -+ -+ /// LayoutE -+ using LayoutE = typename IteratorE::Layout; -+ -+ /// Minimum architecture is Sm80 to support cp.async -+ using ArchTag = arch::Sm80; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = Operator::kTransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = Operator::kTransformB; -+ -+ /// Internal structure exposed for introspection. -+ struct Detail { -+ -+ /// Number of async copies to load one stage of operand A -+ static int const TBLoadIterationsA = -+ IteratorA::ThreadMap::Iterations::kCount; -+ -+ /// Number of async copies to load one stage of operand B -+ static int const TBLoadIterationsB = -+ IteratorB::ThreadMap::Iterations::kCount; -+ -+ /// Number of async copies to load one stage of operand E -+ static int const TBLoadIterationsE = -+ IteratorE::ThreadMap::Iterations::kCount; -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Number of async copies to load one group of operand A -+ static int const kAccessesPerGroupA = -+ (TBLoadIterationsA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ -+ /// Number of async copies to load one group of operand B -+ static int const kAccessesPerGroupB = -+ (TBLoadIterationsB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ -+ /// Number of async copies to load one group of operand E -+ static int const kAccessesPerGroupE = -+ (TBLoadIterationsE + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ -+ /// E operand is tiny. For the most of time, not all the warps are needed -+ /// to load it from the global memory. -+ static int const kValidWarps = IteratorE::ThreadMap::kThreads / 32; -+ -+ /// B operand is twice as big as A which brings very high register pressure. -+ /// We have to sacrifice the double buffer when the warp tile size is big. -+ static int const kBBufferSize = -+ ((sizeof(typename Operator::ElementC) == 4) && -+ ((platform::is_same::value && -+ platform::is_same::value)) && -+ (Operator::Shape::kM >= 64 && Operator::Shape::kN >= 64)) -+ ? 1 -+ : 2; -+ }; -+ -+ private: -+ -+ using WarpLoadedFragmentA = typename Operator::FragmentA; -+ using WarpLoadedFragmentB = typename Operator::FragmentB; -+ using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; -+ using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; -+ using WarpFragmentE = typename Operator::FragmentE; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+ /// Iterator to write threadblock-scoped tile of E operand to shared memory -+ SmemIteratorE smem_iterator_E_; -+ -+ /// Warp id -+ bool is_warp_valid_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ SparseMmaMultistage( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), -+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), -+ smem_iterator_E_(shared_storage.operand_E_ref(), thread_idx) -+ { -+ is_warp_valid_ = warp_idx < Detail::kValidWarps; -+ -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ this->warp_tile_iterator_E_.add_tile_offset( -+ {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B, -+ IteratorE &iterator_E, int group_start_A = 0, -+ int group_start_B = 0, int group_start_E = 0) { -+ iterator_A.set_iteration_index(group_start_A * -+ IteratorA::kAccessesPerVector); -+ this->smem_iterator_A_.set_iteration_index(group_start_A); -+ -+ // async copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { -+ if (group_start_A + j < Detail::TBLoadIterationsA) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_A.get(); -+ -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_A.valid()); -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ } -+ -+ iterator_B.set_iteration_index(group_start_B * -+ IteratorB::kAccessesPerVector); -+ this->smem_iterator_B_.set_iteration_index(group_start_B); -+ -+ // async copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { -+ if (group_start_B + j < Detail::TBLoadIterationsB) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_B.get(); -+ -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_B.valid()); -+ -+ ++iterator_B; -+ } -+ ++this->smem_iterator_B_; -+ } -+ } -+ -+ iterator_E.set_iteration_index(group_start_E); -+ this->smem_iterator_E_.set_iteration_index(group_start_E); -+ -+ // async copy for operand E -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupE; ++j) { -+ if (group_start_E + j < Detail::TBLoadIterationsE) { -+ typename IteratorE::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_E_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorE::ThreadMap::kElementsPerAccess / 8; -+ -+ auto gmem_ptr = iterator_E.get(); -+ -+ cutlass::arch::cp_async( -+ dst_ptr, gmem_ptr, iterator_E.valid() && is_warp_valid_); -+ -+ ++iterator_E; -+ ++this->smem_iterator_E_; -+ } -+ } -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations, -+ ///< destination accumulator tile -+ FragmentC &accum, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B, -+ ///< iterator over E operand in global memory -+ IteratorE iterator_E, -+ ///< initial value of accumulator -+ FragmentC const &src_accum) { -+ -+ // -+ // Prologue -+ // -+ -+ // Issue several complete stages -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; -+ ++stage, --gemm_k_iterations) { -+ -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ iterator_E.clear_mask(gemm_k_iterations == 0); -+ -+ iterator_A.set_iteration_index(0); -+ this->smem_iterator_A_.set_iteration_index(0); -+ -+ // async copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::TBLoadIterationsA; ++j) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_A.get(), iterator_A.valid()); -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ -+ iterator_B.set_iteration_index(0); -+ this->smem_iterator_B_.set_iteration_index(0); -+ -+ // async copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::TBLoadIterationsB; ++j) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_B.get(), iterator_B.valid()); -+ -+ ++iterator_B; -+ } -+ -+ ++this->smem_iterator_B_; -+ } -+ -+ iterator_E.set_iteration_index(0); -+ this->smem_iterator_E_.set_iteration_index(0); -+ -+ // async copy for operand E -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::TBLoadIterationsE; ++j) { -+ typename IteratorE::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_E_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorE::ThreadMap::kElementsPerAccess / 8; -+ if (is_warp_valid_) -+ cutlass::arch::cp_async_zfill( -+ dst_ptr, iterator_E.get(), iterator_E.valid()); -+ -+ ++iterator_E; -+ -+ ++this->smem_iterator_E_; -+ } -+ -+ // Move to the next stage -+ iterator_A.add_tile_offset({0, 1}); -+ iterator_B.add_tile_offset({1, 0}); -+ iterator_E.add_tile_offset({0, 1}); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ this->smem_iterator_E_.add_tile_offset({0, 1}); -+ -+ // cp.async.commit_group - completes a stage -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // Perform accumulation in the 'd' output operand -+ accum = src_accum; -+ -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA warp_loaded_frag_A[2]; -+ WarpLoadedFragmentB warp_loaded_frag_B[Detail::kBBufferSize]; -+ WarpTransformedFragmentA warp_transformed_frag_A[2]; -+ WarpTransformedFragmentB warp_transformed_frag_B[Detail::kBBufferSize]; -+ WarpFragmentE warp_frag_E[2]; -+ -+ Operator warp_mma; -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ this->warp_tile_iterator_E_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); -+ this->warp_tile_iterator_E_.load(warp_frag_E[0]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ ++this->warp_tile_iterator_E_; -+ -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ iterator_E.clear_mask(gemm_k_iterations == 0); -+ -+ int smem_write_stage_idx = Base::kStages - 1; -+ int smem_read_stage_idx = 0; -+ -+ warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], -+ warp_loaded_frag_A[0], warp_loaded_frag_B[0]); -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > (-Base::kStages + 1);) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; -+ ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. -+ -+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_E_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_E_.load(warp_frag_E[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_E_; -+ -+ if (Detail::kBBufferSize == 2) { -+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.load( -+ warp_loaded_frag_B[(warp_mma_k + 1) % Detail::kBBufferSize]); -+ ++this->warp_tile_iterator_B_; -+ } -+ -+ if (warp_mma_k > 0) -+ warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % Detail::kBBufferSize], -+ warp_loaded_frag_A[warp_mma_k % 2], -+ warp_loaded_frag_B[warp_mma_k % Detail::kBBufferSize]); -+ -+ warp_mma( -+ accum, -+ warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % Detail::kBBufferSize], accum, -+ warp_frag_E[warp_mma_k % 2] -+ ); -+ -+ if (Detail::kBBufferSize == 1) { -+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); -+ ++this->warp_tile_iterator_B_; -+ -+ } -+ -+ // Issue global->shared copies for the this stage -+ if (warp_mma_k < Base::kWarpGemmIterations - 1) { -+ int group_start_iteration_A, group_start_iteration_B, group_start_iteration_E; -+ -+ group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; -+ group_start_iteration_E = warp_mma_k * Detail::kAccessesPerGroupE; -+ -+ copy_tiles_and_advance( -+ iterator_A, iterator_B, iterator_E, group_start_iteration_A, -+ group_start_iteration_B, group_start_iteration_E); -+ } -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations) { -+ int group_start_iteration_A, group_start_iteration_B, group_start_iteration_E; -+ group_start_iteration_A = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB; -+ group_start_iteration_E = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupE; -+ -+ copy_tiles_and_advance( -+ iterator_A, iterator_B, iterator_E, group_start_iteration_A, -+ group_start_iteration_B, group_start_iteration_E); -+ -+ // Inserts a memory fence between stages of cp.async instructions. -+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages have committed. -+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_A.add_tile_offset({0, 1}); -+ iterator_B.add_tile_offset({1, 0}); -+ iterator_E.add_tile_offset({0, 1}); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ this->smem_iterator_E_.add_tile_offset({0, 1}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ this->smem_iterator_E_.add_tile_offset({0, -Base::kStages}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations, -+ 0}); -+ this->warp_tile_iterator_E_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ --gemm_k_iterations; -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ iterator_E.clear_mask(gemm_k_iterations == 0); -+ } -+ -+ // Do any conversions feeding the first stage at the end of the loop so -+ // we can start right away on mma instructions -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations) -+ warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ } -+ -+ } -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_with_reduction_multistage.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_with_reduction_multistage.h -new file mode 100644 -index 0000000..fb0e92e ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/mma_with_reduction_multistage.h -@@ -0,0 +1,547 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template for a double-buffered threadblock-scoped GEMM kernel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/threadblock/mma_base.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math -+/// instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Iterates over tiles of A operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorA_, -+ /// Iterates over tiles of A operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorA_, -+ /// Cache operation for operand A -+ cutlass::arch::CacheOperation::Kind CacheOpA, -+ /// Iterates over tiles of B operand in global memory -+ // (concept: ReadableTileIterator | ForwardTileIterator | -+ // MaskedTileIterator) -+ typename IteratorB_, -+ /// Iterates over tiles of B operand in shared memory -+ /// (concept: WriteableTileIterator | RandomAccessTileIterator) -+ typename SmemIteratorB_, -+ /// Cache operation for operand B -+ cutlass::arch::CacheOperation::Kind CacheOpB, -+ /// Data type of accumulator matrix -+ typename ElementC_, -+ /// Data type of accumulator matrix -+ typename LayoutC_, -+ /// Policy describing tuning details (concept: MmaPolicy) -+ typename Policy_, -+ /// Number of stages, -+ int Stages, -+ /// Use zfill or predicate for out-of-bound cp.async -+ SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, -+ /// Used for partial specialization -+ typename Enable = bool> -+class MmaWithReductionMultistage : -+ public MmaBase { -+public: -+ ///< Base class -+ using Base = MmaBase; -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ using Shape = Shape_; -+ ///< Iterates over tiles of A operand in global memory -+ using IteratorA = IteratorA_; -+ ///< Iterates over tiles of B operand in global memory -+ using IteratorB = IteratorB_; -+ ///< Data type of accumulator matrix -+ using ElementC = ElementC_; -+ ///< Layout of accumulator matrix -+ using LayoutC = LayoutC_; -+ ///< Policy describing tuning details -+ using Policy = Policy_; -+ -+ using SmemIteratorA = SmemIteratorA_; -+ using SmemIteratorB = SmemIteratorB_; -+ -+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; -+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; -+ -+ // -+ // Dependent types -+ // -+ -+ /// Fragment of accumulator tile -+ using FragmentC = typename Policy::Operator::FragmentC; -+ -+ /// Warp-level Mma -+ using Operator = typename Policy::Operator; -+ -+ using FragmentReduction = typename Operator::FragmentReduction; -+ -+ /// Minimum architecture is Sm80 to support cp.async -+ using ArchTag = arch::Sm80; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = Operator::kTransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = Operator::kTransformB; -+ -+ static int const kReduceKForA = Operator::kReduceKForA; -+ -+ /// Internal structure exposed for introspection. -+ struct Detail { -+ -+ /// Number of cp.async instructions to load one stage of operand A -+ static int const AsyncCopyIterationsPerStageA = -+ IteratorA::ThreadMap::Iterations::kCount; -+ -+ /// Number of cp.async instructions to load one stage of operand B -+ static int const AsyncCopyIterationsPerStageB = -+ IteratorB::ThreadMap::Iterations::kCount; -+ -+ /// Number of stages -+ static int const kStages = Stages; -+ -+ /// Number of cp.async instructions to load on group of operand A -+ static int const kAccessesPerGroupA = -+ (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ -+ /// Number of cp.async instructions to load on group of operand B -+ static int const kAccessesPerGroupB = -+ (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; -+ }; -+ -+ private: -+ -+ using WarpLoadedFragmentA = typename Operator::FragmentA; -+ using WarpLoadedFragmentB = typename Operator::FragmentB; -+ using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; -+ using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Iterator to write threadblock-scoped tile of A operand to shared memory -+ SmemIteratorA smem_iterator_A_; -+ -+ /// Iterator to write threadblock-scoped tile of B operand to shared memory -+ SmemIteratorB smem_iterator_B_; -+ -+public: -+ -+ /// Construct from tensor references -+ CUTLASS_DEVICE -+ MmaWithReductionMultistage( -+ ///< Shared storage needed for internal use by threadblock-scoped GEMM -+ typename Base::SharedStorage &shared_storage, -+ ///< ID within the threadblock -+ int thread_idx, -+ ///< ID of warp -+ int warp_idx, -+ ///< ID of each thread within a warp -+ int lane_idx -+ ): -+ Base(shared_storage, thread_idx, warp_idx, lane_idx), -+ smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), -+ smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) -+ { -+ // Compute warp location within threadblock tile by mapping the warp_id to -+ // three coordinates: -+ // _m: the warp's position within the threadblock along the M dimension -+ // _n: the warp's position within the threadblock along the N dimension -+ // _k: the warp's position within the threadblock along the K dimension -+ -+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); -+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); -+ -+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; -+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; -+ -+ // Add per-warp offsets in units of warp-level tiles -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); -+ } -+ -+ CUTLASS_DEVICE -+ void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B, -+ int group_start_A = 0, int group_start_B = 0) { -+ iterator_A.set_iteration_index(group_start_A * -+ IteratorA::kAccessesPerVector); -+ this->smem_iterator_A_.set_iteration_index(group_start_A); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { -+ if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_A.get(); -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, iterator_A.valid()); -+ } else { -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_A.valid()); -+ } -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ } -+ -+ iterator_B.set_iteration_index(group_start_B * -+ IteratorB::kAccessesPerVector); -+ this->smem_iterator_B_.set_iteration_index(group_start_B); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { -+ if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ int const kSrcBytes = sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ auto gmem_ptr = iterator_B.get(); -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, gmem_ptr, iterator_B.valid()); -+ } else { -+ cutlass::arch::cp_async( -+ dst_ptr + v, gmem_ptr, iterator_B.valid()); -+ } -+ -+ ++iterator_B; -+ } -+ ++this->smem_iterator_B_; -+ } -+ } -+ } -+ -+ /// Perform a threadblock-scoped matrix multiply-accumulate -+ CUTLASS_DEVICE -+ void operator()( -+ ///< problem size of GEMM -+ int gemm_k_iterations, -+ ///< destination accumulator tile -+ FragmentC &accum, -+ ///< iterator over A operand in global memory -+ IteratorA iterator_A, -+ ///< iterator over B operand in global memory -+ IteratorB iterator_B, -+ ///< initial value of accumulator -+ FragmentC const &src_accum, -+ FragmentReduction &gemm_k_reduction_accum) { -+ -+ // -+ // Prologue -+ // -+ // Issue several complete stages -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int stage = 0; stage < Base::kStages - 1; -+ ++stage, --gemm_k_iterations) { -+ -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ -+ iterator_A.set_iteration_index(0); -+ this->smem_iterator_A_.set_iteration_index(0); -+ -+ // Async Copy for operand A -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { -+ typename IteratorA::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_A_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorA::ThreadMap::kElementsPerAccess / -+ IteratorA::kAccessesPerVector / 8; -+ -+ int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_A.get(), iterator_A.valid()); -+ -+ ++iterator_A; -+ } -+ -+ ++this->smem_iterator_A_; -+ } -+ -+ iterator_B.set_iteration_index(0); -+ this->smem_iterator_B_.set_iteration_index(0); -+ -+ // Async Copy for operand B -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { -+ typename IteratorB::AccessType *dst_ptr = -+ reinterpret_cast( -+ this->smem_iterator_B_.get()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { -+ int const kSrcBytes = -+ sizeof_bits::value * -+ IteratorB::ThreadMap::kElementsPerAccess / -+ IteratorB::kAccessesPerVector / 8; -+ -+ cutlass::arch::cp_async_zfill( -+ dst_ptr + v, iterator_B.get(), iterator_B.valid()); -+ -+ ++iterator_B; -+ } -+ -+ ++this->smem_iterator_B_; -+ } -+ -+ // Move to the next stage -+ iterator_A.add_tile_offset({0, 1}); -+ iterator_B.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Defines the boundary of a stage of cp.async. -+ cutlass::arch::cp_async_fence(); -+ } -+ -+ // Perform accumulation in the 'd' output operand -+ accum = src_accum; -+ -+ // Waits until kStages-2 stages have committed. -+ cutlass::arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Pair of fragments used to overlap shared memory loads and math -+ // instructions -+ WarpLoadedFragmentA warp_loaded_frag_A[2]; -+ WarpLoadedFragmentB warp_loaded_frag_B[2]; -+ WarpTransformedFragmentA warp_transformed_frag_A[2]; -+ WarpTransformedFragmentB warp_transformed_frag_B[2]; -+ -+ Operator warp_mma; -+ -+ this->warp_tile_iterator_A_.set_kgroup_index(0); -+ this->warp_tile_iterator_B_.set_kgroup_index(0); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ -+ int smem_write_stage_idx = Base::kStages - 1; -+ int smem_read_stage_idx = 0; -+ -+ warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], -+ warp_loaded_frag_A[0], warp_loaded_frag_B[0]); -+ -+ // -+ // Mainloop -+ // -+ -+ CUTLASS_GEMM_LOOP -+ for (; gemm_k_iterations > (-Base::kStages + 1);) { -+ // -+ // Loop over GEMM K dimension -+ // -+ -+ // Computes a warp-level GEMM on data held in shared memory -+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate -+ CUTLASS_PRAGMA_UNROLL -+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; -+ ++warp_mma_k) { -+ -+ // Load warp-level tiles from shared memory, wrapping to k offset if -+ // this is the last group as the case may be. -+ -+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); -+ -+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); -+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ -+ ++this->warp_tile_iterator_A_; -+ ++this->warp_tile_iterator_B_; -+ -+ if (warp_mma_k > 0) -+ warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ warp_loaded_frag_A[warp_mma_k % 2], -+ warp_loaded_frag_B[warp_mma_k % 2]); -+ -+ warp_mma( -+ accum, -+ warp_transformed_frag_A[warp_mma_k % 2], -+ warp_transformed_frag_B[warp_mma_k % 2], -+ accum, -+ gemm_k_reduction_accum -+ ); -+ -+ // Issue global->shared copies for the this stage -+ if (warp_mma_k < Base::kWarpGemmIterations - 1) { -+ int group_start_iteration_A, group_start_iteration_B; -+ -+ group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; -+ -+ copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, -+ group_start_iteration_B); -+ } -+ -+ if (warp_mma_k + 2 == Base::kWarpGemmIterations) { -+ int group_start_iteration_A, group_start_iteration_B; -+ group_start_iteration_A = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupA; -+ group_start_iteration_B = -+ (warp_mma_k + 1) * Detail::kAccessesPerGroupB; -+ -+ copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, -+ group_start_iteration_B); -+ -+ // Inserts a memory fence between stages of cp.async instructions. -+ cutlass::arch::cp_async_fence(); -+ -+ // Waits until kStages-2 stages have committed. -+ arch::cp_async_wait(); -+ __syncthreads(); -+ -+ // Move to the next stage -+ iterator_A.add_tile_offset({0, 1}); -+ iterator_B.add_tile_offset({1, 0}); -+ -+ this->smem_iterator_A_.add_tile_offset({0, 1}); -+ this->smem_iterator_B_.add_tile_offset({1, 0}); -+ -+ // Add negative offsets to return iterators to the 'start' of the -+ // circular buffer in shared memory -+ if (smem_write_stage_idx == (Base::kStages - 1)) { -+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); -+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); -+ smem_write_stage_idx = 0; -+ } else { -+ ++smem_write_stage_idx; -+ } -+ -+ if (smem_read_stage_idx == (Base::kStages - 1)) { -+ this->warp_tile_iterator_A_.add_tile_offset( -+ {0, -Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations}); -+ this->warp_tile_iterator_B_.add_tile_offset( -+ {-Base::kStages * Policy::kPartitionsK * -+ Base::kWarpGemmIterations, -+ 0}); -+ smem_read_stage_idx = 0; -+ } else { -+ ++smem_read_stage_idx; -+ } -+ -+ --gemm_k_iterations; -+ iterator_A.clear_mask(gemm_k_iterations == 0); -+ iterator_B.clear_mask(gemm_k_iterations == 0); -+ } -+ -+ // Do any conversions feeding the first stage at the end of the loop so -+ // we can start right away on mma instructions -+ if (warp_mma_k + 1 == Base::kWarpGemmIterations) -+ warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], -+ warp_transformed_frag_B[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_A[(warp_mma_k + 1) % 2], -+ warp_loaded_frag_B[(warp_mma_k + 1) % 2]); -+ } -+ -+ } -+ -+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { -+ // commit and drain all pending and predicated cp.async pnz from the GEMM mainloop -+ cutlass::arch::cp_async_fence(); -+ cutlass::arch::cp_async_wait<0>(); -+ __syncthreads(); -+ } -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/threadblock_swizzle.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/threadblock_swizzle.h -new file mode 100644 -index 0000000..48c1737 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/threadblock_swizzle.h -@@ -0,0 +1,459 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Implements several possible threadblock-swizzling functions mapping blockIdx to -+ GEMM problems. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/platform/platform.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+#include "cutlass/gemm/threadblock/index_remat.h" -+#include "cutlass/gemm/threadblock/threadblock_swizzle_streamk.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Threadblock swizzling function for GEMMs -+template -+struct GemmIdentityThreadblockSwizzle { -+ -+ CUTLASS_HOST_DEVICE -+ GemmIdentityThreadblockSwizzle() { } -+ -+ /// Returns the shape of the problem in units of logical tiles -+ /// *Gemm* problem size: gemm(M, N, K) -+ CUTLASS_HOST_DEVICE -+ GemmCoord get_tiled_shape( -+ GemmCoord problem_size, -+ GemmCoord tile_size, -+ int split_k_slices) const { -+ -+ return GemmCoord( -+ (problem_size.m() + tile_size.m() - 1) / tile_size.m(), -+ (problem_size.n() + tile_size.n() - 1) / tile_size.n(), -+ split_k_slices); -+ } -+ -+ /// Returns the shape of the problem in units of logical tiles -+ /// *ImplicitGemm* Conv2d problem size: conv_operator(NPQK, NHWC, KRSC) -+ CUTLASS_HOST_DEVICE -+ GemmCoord get_tiled_shape( -+ cutlass::conv::Operator conv_operator, -+ cutlass::conv::Conv2dProblemSize const &problem_size, -+ GemmCoord tile_size, -+ int split_k_slices) const { -+ -+ gemm::GemmCoord implicit_gemm_problem_size = -+ cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size); -+ -+ return get_tiled_shape( -+ implicit_gemm_problem_size, tile_size, split_k_slices); -+ } -+ -+ /// Returns the shape of the problem in units of logical tiles -+ /// *ImplicitGemm* Conv3d problem size: conv_operator(NZPQK, NDHWC, KTRSC) -+ CUTLASS_HOST_DEVICE -+ GemmCoord get_tiled_shape( -+ cutlass::conv::Operator conv_operator, -+ cutlass::conv::Conv3dProblemSize const &problem_size, -+ GemmCoord tile_size, -+ int split_k_slices) const { -+ -+ gemm::GemmCoord implicit_gemm_problem_size = -+ cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size); -+ -+ return get_tiled_shape( -+ implicit_gemm_problem_size, tile_size, split_k_slices); -+ } -+ -+ /// Computes CUDA grid dimensions given a size in units of logical tiles -+ CUTLASS_HOST_DEVICE -+ dim3 get_grid_shape(GemmCoord tiled_shape) const { -+ int tile = 1 << get_log_tile(tiled_shape); -+ return dim3(tiled_shape.m() * tile, (tiled_shape.n() + tile - 1) / tile, tiled_shape.k()); -+ } -+ -+ /// Calculates optimal swizzle width -+ CUTLASS_HOST_DEVICE -+ int get_log_tile(GemmCoord tiled_shape) const { -+ auto n = tiled_shape.n(); -+ // Thresholds picked so that it doesn't cause too many no-op CTAs -+ if (N >= 8 && n >= 6) -+ return 3; -+ else if (N >= 4 && n >= 3) -+ return 2; -+ else if (N >= 2 && n >= 2) -+ return 1; -+ else -+ return 0; -+ } -+ -+ /// Obtains the threadblock offset (in units of threadblock-scoped tiles) -+ CUTLASS_DEVICE -+ GemmCoord get_tile_offset(int log_tile) const { -+ int block_idx_x = RematerializeBlockIdxX(); -+ int block_idx_y = RematerializeBlockIdxY(); -+ int block_idx_z = RematerializeBlockIdxZ(); -+ -+ return GemmCoord{(block_idx_x >> log_tile), // -+ (block_idx_y << log_tile) + ((block_idx_x) & ((1 << (log_tile)) - 1)), -+ block_idx_z}; -+ } -+ -+ /// Obtains the threadblock offset (in units of threadblock-scoped tiles) -+ CUTLASS_DEVICE -+ GemmCoord get_tile_offset(GemmCoord tiled_shape) const { -+ -+ int const kTile = N; -+ int block_idx_x = RematerializeBlockIdxX(); -+ int block_idx_y = RematerializeBlockIdxY(); -+ -+ if ((tiled_shape.m() < kTile) || (tiled_shape.n() < kTile)) -+ return GemmCoord{block_idx_x, block_idx_y, RematerializeBlockIdxZ()}; -+ -+ return GemmCoord{ -+ (block_idx_x / kTile), -+ (block_idx_y * kTile) + (block_idx_x % kTile), -+ RematerializeBlockIdxZ() -+ }; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Threadblock swizzling function for GEMMs -+struct GemmHorizontalThreadblockSwizzle { -+ -+ CUTLASS_HOST_DEVICE -+ GemmHorizontalThreadblockSwizzle() { } -+ -+ /// Returns the shape of the problem in units of logical tiles -+ CUTLASS_HOST_DEVICE -+ GemmCoord get_tiled_shape( -+ GemmCoord problem_size, -+ GemmCoord tile_size, -+ int split_k_slices) const { -+ -+ return GemmCoord( -+ (problem_size.m() + tile_size.m() - 1) / tile_size.m(), -+ (problem_size.n() + tile_size.n() - 1) / tile_size.n(), -+ split_k_slices); -+ } -+ -+ /// Computes CUDA grid dimensions given a size in units of logical tiles -+ CUTLASS_HOST_DEVICE -+ dim3 get_grid_shape(GemmCoord tiled_shape) const { -+ return dim3(tiled_shape.n(), tiled_shape.m(), tiled_shape.k()); -+ } -+ -+ /// Calculates optimal swizzle width -+ CUTLASS_HOST_DEVICE -+ int get_log_tile(GemmCoord tiled_shape) const { -+ return 0; -+ } -+ -+ /// Obtains the threadblock offset (in units of threadblock-scoped tiles) -+ CUTLASS_DEVICE -+ GemmCoord get_tile_offset(GemmCoord tiled_shape) const { -+ return GemmCoord{ -+ RematerializeBlockIdxY(), -+ RematerializeBlockIdxX(), -+ RematerializeBlockIdxZ() -+ }; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Threadblock swizzling function for batched GEMMs -+struct GemmBatchedIdentityThreadblockSwizzle { -+ -+ /// Returns the shape of the problem in units of logical tiles -+ CUTLASS_HOST_DEVICE -+ GemmCoord get_tiled_shape( -+ GemmCoord problem_size, -+ GemmCoord tile_size, -+ int batch_count) const { -+ -+ return GemmCoord( -+ (problem_size.m() + tile_size.m() - 1) / tile_size.m(), -+ (problem_size.n() + tile_size.n() - 1) / tile_size.n(), -+ batch_count % (1 << 16)); -+ } -+ -+ /// Computes CUDA grid dimensions given a size in units of logical tiles -+ CUTLASS_HOST_DEVICE -+ dim3 get_grid_shape(GemmCoord tiled_shape) const { -+ return dim3(tiled_shape.m(), tiled_shape.n(), tiled_shape.k()); -+ } -+ -+ /// Calculates optimal swizzle width -+ CUTLASS_HOST_DEVICE -+ int get_log_tile(GemmCoord tiled_shape) const { -+ return 0; -+ } -+ -+ /// Obtains the threadblock offset (in units of threadblock-scoped tiles) -+ CUTLASS_DEVICE -+ GemmCoord get_tile_offset(GemmCoord tiled_shape) const { -+ return GemmCoord{ -+ RematerializeBlockIdxX(), -+ RematerializeBlockIdxY(), -+ RematerializeBlockIdxZ() -+ }; -+ } -+ -+ /// Obtains the threadblock offset (in units of threadblock-scoped tiles) -+ CUTLASS_DEVICE -+ GemmCoord get_tile_offset(int log_tile) const { -+ int block_idx_x = RematerializeBlockIdxX(); -+ int block_idx_y = RematerializeBlockIdxY(); -+ int block_idx_z = RematerializeBlockIdxZ(); -+ -+ return GemmCoord{(block_idx_x >> log_tile), // -+ (block_idx_y << log_tile) + ((block_idx_x) & ((1 << (log_tile)) - 1)), -+ block_idx_z}; -+ } -+ -+ /// Gets the batch index -+ CUTLASS_DEVICE -+ int get_batch_idx() const { -+ return RematerializeBlockIdxZ(); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Threadblock swizzling function for split-K GEMMs -+template -+struct GemmSplitKIdentityThreadblockSwizzle { -+ -+ int const kTile = N; -+ -+ /// Returns the shape of the problem in units of logical tiles -+ CUTLASS_HOST_DEVICE -+ GemmCoord get_tiled_shape( -+ GemmCoord problem_size, -+ GemmCoord tile_size, -+ int partitions) const { -+ -+ return GemmCoord( -+ (problem_size.m() + tile_size.m() - 1) / tile_size.m(), -+ (problem_size.n() + tile_size.n() - 1) / tile_size.n(), -+ partitions); -+ } -+ -+ /// Calculates optimal swizzle width -+ CUTLASS_HOST_DEVICE -+ int get_log_tile(GemmCoord tiled_shape) const { -+ auto n = tiled_shape.n(); -+ // Thresholds picked so that it doesn't cause too many no-op CTAs -+ if (N >= 8 && n >= 6) -+ return 3; -+ else if (N >= 4 && n >= 3) -+ return 2; -+ else if (N >= 2 && n >= 2) -+ return 1; -+ else -+ return 0; -+ } -+ -+ /// Computes CUDA grid dimensions given a size in units of logical tiles -+ CUTLASS_HOST_DEVICE -+ dim3 get_grid_shape(GemmCoord tiled_shape) const { -+ int tile = 1 << get_log_tile(tiled_shape); -+ return dim3(tiled_shape.m() * tile, (tiled_shape.n() + tile - 1) / tile, tiled_shape.k()); -+ } -+ -+ /// Obtains the threadblock offset (in units of threadblock-scoped tiles) -+ CUTLASS_DEVICE -+ GemmCoord get_tile_offset(int log_tile) const { -+ int block_idx_x = RematerializeBlockIdxX(); -+ int block_idx_y = RematerializeBlockIdxY(); -+ int block_idx_z = RematerializeBlockIdxZ(); -+ -+ return GemmCoord{(block_idx_x >> log_tile), // -+ (block_idx_y << log_tile) + ((block_idx_x) & ((1 << (log_tile)) - 1)), -+ block_idx_z}; -+ } -+ -+ /// Obtains the threadblock offset (in units of threadblock-scoped tiles) -+ CUTLASS_DEVICE -+ GemmCoord get_tile_offset(GemmCoord tiled_shape) const { -+ -+ int const kTile = N; -+ int block_idx_x = RematerializeBlockIdxX(); -+ int block_idx_y = RematerializeBlockIdxY(); -+ -+ if ((tiled_shape.m() < kTile) || (tiled_shape.n() < kTile)) -+ return GemmCoord{block_idx_x, block_idx_y, RematerializeBlockIdxZ()}; -+ -+ return GemmCoord{ -+ (block_idx_x / kTile), -+ (block_idx_y * kTile) + (block_idx_x % kTile), -+ RematerializeBlockIdxZ() -+ }; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Threadblock swizzling function for split-K GEMMs -+struct GemmSplitKHorizontalThreadblockSwizzle { -+ -+ /// Returns the shape of the problem in units of logical tiles -+ CUTLASS_HOST_DEVICE -+ GemmCoord get_tiled_shape( -+ GemmCoord problem_size, -+ GemmCoord tile_size, -+ int partitions) const { -+ -+ return GemmCoord( -+ (problem_size.m() + tile_size.m() - 1) / tile_size.m(), -+ (problem_size.n() + tile_size.n() - 1) / tile_size.n(), -+ partitions); -+ } -+ -+ /// Computes CUDA grid dimensions given a size in units of logical tiles -+ CUTLASS_HOST_DEVICE -+ dim3 get_grid_shape(GemmCoord tiled_shape) const { -+ return dim3(tiled_shape.n(), tiled_shape.m(), tiled_shape.k()); -+ } -+ -+ /// Calculates optimal swizzle width -+ CUTLASS_HOST_DEVICE -+ int get_log_tile(GemmCoord tiled_shape) const { -+ return 0; -+ } -+ -+ /// Obtains the threadblock offset (in units of threadblock-scoped tiles) -+ CUTLASS_DEVICE -+ GemmCoord get_tile_offset(int log_tile) const { -+ return GemmCoord{ -+ RematerializeBlockIdxY(), -+ RematerializeBlockIdxX(), -+ RematerializeBlockIdxZ() -+ }; -+ } -+ -+ /// Obtains the threadblock offset (in units of threadblock-scoped tiles) -+ CUTLASS_DEVICE -+ GemmCoord get_tile_offset(GemmCoord tiled_shape) const { -+ return GemmCoord{ -+ RematerializeBlockIdxY(), -+ RematerializeBlockIdxX(), -+ RematerializeBlockIdxZ() -+ }; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Threadblock swizzling function for batched GEMVs -+struct GemvBatchedStridedThreadblockDefaultSwizzle { -+ -+ /// Returns the shape of the problem in units of logical tiles -+ CUTLASS_HOST_DEVICE -+ BatchedGemmCoord get_tiled_shape( -+ BatchedGemmCoord problem_size, -+ BatchedGemmCoord tile_size) const { -+ -+ return BatchedGemmCoord( -+ 1, // M is always 1 -+ (problem_size.n() + tile_size.n() - 1) / tile_size.n(), -+ (problem_size.k() + tile_size.k() - 1) / tile_size.k(), -+ (problem_size.batch() + tile_size.batch() - 1) / tile_size.batch()); -+ } -+ -+ /// Computes CUDA grid dimensions given a size in units of logical tiles -+ CUTLASS_HOST_DEVICE -+ dim3 get_grid_shape(BatchedGemmCoord tiled_shape) const { -+ return dim3(tiled_shape.n(), tiled_shape.batch(), tiled_shape.k()); -+ } -+ -+ /// Calculates optimal swizzle width -+ CUTLASS_HOST_DEVICE -+ int get_log_tile(GemmCoord tiled_shape) const { -+ return 0; -+ } -+ -+ /// Obtains the threadblock offset (in units of threadblock-scoped tiles) -+ CUTLASS_DEVICE -+ BatchedGemmCoord get_tile_offset(int log_tile) const { -+ return BatchedGemmCoord{ -+ 0, // M is always 1 -+ RematerializeBlockIdxX(), -+ RematerializeBlockIdxZ(), -+ RematerializeBlockIdxY(), -+ }; -+ } -+ -+ /// Obtains the threadblock offset (in units of threadblock-scoped tiles) -+ CUTLASS_DEVICE -+ BatchedGemmCoord get_tile_offset() const { -+ return BatchedGemmCoord{ -+ 0, // M is always 1 -+ RematerializeBlockIdxX(), -+ RematerializeBlockIdxZ(), -+ RematerializeBlockIdxY(), -+ }; -+ } -+ -+ /// Gets the batch tile index -+ CUTLASS_DEVICE -+ int get_batch_tile_idx() const { -+ return RematerializeBlockIdxY(); -+ } -+ -+ /// Gets the absolute batch index -+ CUTLASS_DEVICE -+ int get_batch_idx() const { -+ return RematerializeBlockDimY()*RematerializeBlockIdxY() + RematerializeThreadIdxY(); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h b/3rdparty/cutlass/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h -new file mode 100644 -index 0000000..b91046e ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h -@@ -0,0 +1,813 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Implements streamk threadblock mapping blockIdx to GEMM problems. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/platform/platform.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+#include "cutlass/gemm/threadblock/index_remat.h" -+ -+#include -+#include "cutlass/core_io.h" -+#include "cutlass/trace.h" -+ -+ -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Threadblock mapping control for GEMMs -+struct ThreadblockSwizzleStreamK { -+ -+ /// Advertise StreamkFeature -+ using StreamkFeature = void; -+ -+ -+ /// Kernel traits -+ template -+ struct KernelTraits {}; -+ -+ -+ /// Reduction strategy -+ enum ReductionStrategy -+ { -+ kNone, // Data-parallel strategy (no seams, fixup, etc.) -+ -+ kAtomic, // Non-deterministic reduction of SK-block partials using atomic aggregation in L2 -+ -+ kMixed, // Deterministic reduction of SK-block partials employing either: -+ // (a) A separate wave of reduction thread blocks" (for scenarios with lots of -+ // SK-blocks per SK-tile) -+ // (b) Turnstile-ordered atomic aggregation in L2 (for scenarios with few -+ // SK-blocks per SK-tile) -+ }; -+ -+ static ReductionStrategy const kReductionStrategy = kMixed; -+ -+ -+ // -+ // Heuristics -+ // -+ -+ /// Data-parallel wave-quantization efficiency threshold (above which we go data-parallel) -+ static float constexpr kDpEfficiencyThreshold = 0.92f; -+ -+ /// Minimum number of MAC-iterations per streamk block -+ static int const kMinItersPerSkBlock = 2; -+ -+ /// Height in CTAs of a grid rasterization cohort -+ static int const kCohortCtasM = 8; -+ -+ /// Width in CTAs of a grid rasterization cohort -+ static int const kCohortCtasN = 4; -+ -+ /// Number of CTAs per cohort -+ static int const kCtasPerCohort = kCohortCtasN * kCohortCtasM; -+ -+ /// Cost-equivalent number of SM-iterations for fixup I/O -+ static int const kFixupStartupIterEquiv = 10; -+ static int const kFixupPeerIterEquiv = 3; -+ -+ -+ // -+ // Member state -+ // -+ -+ -+ /// The 3D value-extents of the GEMM computation volume (m,n,k) -+ GemmCoord problem_size; -+ -+ /// Div/mod accelerators -+ FastDivmod div_mod_tiled_shape_m; -+ FastDivmod div_mod_tiled_shape_n; -+ FastDivmod div_mod_tiled_cohort_shape_n; -+ FastDivmod div_mod_iters_per_tile; -+ -+ /// Whether to perform cohort CTA rasterization -+ bool cohort_raster; -+ -+ // Whether to pad and remap block indices -+ bool remap_block_indices; -+ -+ /// CTA occupancy per SM -+ int sm_occupancy; -+ -+ /// Number of SMs for dispatch heuristics to load-balance using Stream-K CTAs (wave size) -+ int avail_sms; -+ -+ int dp_blocks; /// Number of data-parallel thread blocks in the grid -+ int dp_first_wave_tiles; /// Number of output tiles each CTA in the first DP wave will produce -+ -+ /// Number of reduction blocks in the grid -+ int reduction_blocks; -+ -+ int sk_waves; -+ int sk_tiles; -+ int sk_big_blocks_per_region; -+ int sk_iters_per_region; -+ -+ /// Div/mod accelerators -+ FastDivmod div_mod_sk_iters_per_normal_block; -+ FastDivmod div_mod_sk_iters_per_big_block; -+ FastDivmod div_mod_sk_iters_per_region; -+ FastDivmod div_mod_sk_regions; //!! used in block map -+ FastDivmod div_mod_sk_blocks_per_region; //!! used in block map -+ -+ /// The batch count -+ int batch_count; -+ -+ -+ // -+ // Host+device interface -+ // -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ ThreadblockSwizzleStreamK() {} -+ -+ /// Returns the GEMM volume in thread block tiles -+ CUTLASS_HOST_DEVICE -+ GemmCoord tiled_shape() const -+ { -+ return GemmCoord( -+ static_cast(div_mod_tiled_shape_m), -+ static_cast(div_mod_tiled_shape_n), -+ batch_count); -+ } -+ -+ /// Number of iterations per output tile -+ CUTLASS_HOST_DEVICE -+ int iters_per_tile() const -+ { -+ return static_cast(div_mod_iters_per_tile); -+ } -+ -+ /// Number of iterations for normal SK-blocks -+ CUTLASS_HOST_DEVICE -+ int sk_iters_per_normal_block() const -+ { -+ return static_cast(div_mod_sk_iters_per_normal_block); -+ } -+ -+ /// Number of SK regions -+ CUTLASS_HOST_DEVICE -+ int sk_regions() const -+ { -+ return static_cast(div_mod_sk_regions); -+ } -+ -+ /// Number of SK blocks per region (splitting factor) -+ CUTLASS_HOST_DEVICE -+ int sk_blocks_per_region() const -+ { -+ return static_cast(div_mod_sk_blocks_per_region); -+ } -+ -+ -+ // -+ // Host-side interface -+ // -+ -+ /// Debug print -+ void Print() -+ { -+#ifndef __CUDA_ARCH__ -+ auto tiles = tiled_shape().mn().product(); -+ std::cout << -+ "problem_size: (" << problem_size.m() << "," << problem_size.n() << ")" << -+ ", tiled_shape: (" << tiled_shape().m() << "," << tiled_shape().n() << ")" << -+ ", tiles: " << tiles << -+ ", dp_tiles: " << tiles - sk_tiles << -+ ", sk_tiles: " << sk_tiles << -+ ", iters_per_tile: " << iters_per_tile() << -+ ", reduction_blocks: " << reduction_blocks << -+ ", dp_blocks: " << dp_blocks << -+ ", dp_waves: " << dp_blocks / avail_sms << -+ ", dp_first_wave_tiles: " << dp_first_wave_tiles << -+ ", sk_blocks_per_region: " << sk_blocks_per_region() << -+ ", sk_regions: " << sk_regions() << -+ ", sk_waves: " << sk_waves << -+ ", sk_iters_per_normal_block: " << sk_iters_per_normal_block() << -+ ", sk_big_blocks_per_region: " << sk_big_blocks_per_region << -+ ", remap_block_indices: " << remap_block_indices << -+ ", cohort_raster: " << cohort_raster << -+ ", sm_occupancy: " << sm_occupancy << -+ ", avail_sms: " << avail_sms << -+ ", num_blocks: " << get_num_blocks() << -+ "\n\n"; -+#endif -+ } -+ -+ -+ // Compute sk_blocks to dispatch for a given number of sk_tiles -+ static void get_sk_blocks( -+ int &sk_blocks, /// [out] -+ int &savings_iters, /// [out] -+ int sk_tiles, -+ int iters_per_tile, -+ int avail_sms, -+ int max_sk_occupancy, -+ bool allow_partial_wave) -+ { -+ savings_iters = INT_MIN; -+ sk_blocks = 0; -+ -+ if (sk_tiles == 0) { -+ return; -+ } -+ -+ int sk_iters = sk_tiles * iters_per_tile; -+ -+ int dp_equiv_waves = (sk_tiles + avail_sms - 1) / avail_sms; -+ int dp_equiv_iters = iters_per_tile * dp_equiv_waves; -+ -+ int min_sk_blocks = (allow_partial_wave) ? fast_min(avail_sms, sk_tiles + 1) : avail_sms; -+ int max_sk_blocks = fast_min(avail_sms * max_sk_occupancy, sk_iters / kMinItersPerSkBlock); -+ -+ for (int trial_sk_blocks = min_sk_blocks; trial_sk_blocks <= max_sk_blocks; ++trial_sk_blocks) -+ { -+ int sk_waves = (trial_sk_blocks + avail_sms - 1) / avail_sms; -+ int max_sk_iters_per_block = (sk_iters + trial_sk_blocks - 1) / trial_sk_blocks; -+ int sk_iter_equiv = max_sk_iters_per_block * sk_waves; -+ -+ int num_peers = ((trial_sk_blocks + sk_tiles - 1) / sk_tiles) + 1; // add one for alignment skew -+ -+ float iter_cost = 0.02f * float(num_peers) * float(sk_iter_equiv); -+ -+ if (trial_sk_blocks % sk_tiles == 0) -+ { -+ // aligned -+ num_peers = (trial_sk_blocks / sk_tiles); -+ -+ iter_cost = 0.0f; -+ } -+ -+ float peer_cost = 2.0f * float(num_peers); -+ -+ float base_cost = 2.0f * float(sk_waves); -+ -+ int fixup_iter_equiv = int(base_cost + iter_cost + peer_cost); -+ -+ int trial_savings_iters = dp_equiv_iters - sk_iter_equiv - fixup_iter_equiv; -+ -+ if (trial_savings_iters >= savings_iters) { -+ savings_iters = trial_savings_iters; -+ sk_blocks = trial_sk_blocks; -+ } -+ } -+ } -+ -+ -+ /// Determine the populations of DP and SK blocks to invoke for the given number of output tiles -+ static void get_blocks( -+ int &dp_tiles, /// [out] -+ int &sk_blocks, /// [out] -+ int output_tiles, -+ int iters_per_tile, -+ int avail_sms, -+ int sm_occupancy) -+ { -+ int full_waves = output_tiles / avail_sms; -+ int full_wave_tiles = full_waves * avail_sms; -+ int partial_wave_tiles = output_tiles - full_wave_tiles; -+ -+ int score = -1; -+ dp_tiles = output_tiles; -+ sk_blocks = 0; -+ -+ if (partial_wave_tiles == 0) -+ { -+ // Perfect quantization -+ return; -+ } -+ -+ if (full_waves < sm_occupancy) -+ { -+ // We're less than full GPU occupancy -+ -+ // Form the SK wave from the partial wave to get us up to full GPU occupancy -+ int max_sk_occupancy = sm_occupancy - full_waves; -+ -+ dp_tiles = full_wave_tiles; -+ -+ get_sk_blocks( -+ sk_blocks, -+ score, -+ partial_wave_tiles, -+ iters_per_tile, -+ avail_sms, -+ max_sk_occupancy, -+ true); // we can run with less than a full wave of SK-blocks -+ -+ if (score < 0) { -+ // not profitable -+ sk_blocks = 0; -+ dp_tiles = output_tiles; -+ } -+ -+ return; -+ } -+ -+ // We're at (or greater) than GPU occupancy -+ -+ if ((sm_occupancy > 1 ) && (full_waves % sm_occupancy == sm_occupancy - 1)) -+ { -+ // If occupancy is more than one CTA per SM, form the SK wave from the partial -+ // wave to get us to full GPU occupancy -+ int max_sk_occupancy = 1; -+ -+ dp_tiles = full_wave_tiles; -+ -+ get_sk_blocks( -+ sk_blocks, -+ score, -+ partial_wave_tiles, -+ iters_per_tile, -+ avail_sms, -+ max_sk_occupancy, -+ true); // we can run with less than a full wave of SK-blocks -+ -+ if (score >= 0) { -+ return; -+ } -+ } -+ -+ // Form the SK wave by combining the last full wave and the partial wave -+ // We're less than full GPU occupancy -+ dp_tiles = full_wave_tiles - avail_sms; -+ -+ int max_sk_occupancy = sm_occupancy - ((full_waves - 1) % sm_occupancy); -+ -+ get_sk_blocks( -+ sk_blocks, -+ score, -+ partial_wave_tiles + avail_sms, -+ iters_per_tile, -+ avail_sms, -+ max_sk_occupancy, -+ false); // we cannot run with less than a full wave of SK-blocks -+ -+ if (score < 0) { -+ // not profitable -+ sk_blocks = 0; -+ dp_tiles = output_tiles; -+ } -+ -+ } -+ -+ /// Constructor: *Gemm* problem size (m, n, k) -+ template -+ ThreadblockSwizzleStreamK( -+ KernelTraits const kernel_traits_, -+ GemmUniversalMode const mode_, -+ GemmCoord const problem_size_, -+ GemmCoord const tile_size_, -+ int const batch_split_, /// Either (mode == GemmUniversalMode::kBatched) the batch count, or (mode == GemmUniversalMode::kGemm) the tile-splitting factor (1 defaults to StreamK, >1 emulates Split-K) -+ int const sm_occupancy_, -+ int const device_sms_, -+ int const avail_sms_) /// The number of SMs that StreamK dispatch heuristics will attempt to load-balance across (-1 defaults to device width, 1 implies classic data-parallel scheduling) -+ : -+ problem_size(problem_size_), -+ batch_count((mode_ == GemmUniversalMode::kBatched) ? batch_split_ : 1), -+ reduction_blocks(0), -+ dp_blocks(0), -+ dp_first_wave_tiles(1), // Default: one tile per DP-block in the first wave of DP blocks -+ sk_tiles(0), -+ sk_big_blocks_per_region(0), -+ sk_iters_per_region(0), -+ sk_waves(0), -+ sm_occupancy(sm_occupancy_), -+ remap_block_indices(false), -+ avail_sms(fast_max(1, avail_sms_)), -+ cohort_raster(false) -+ { -+ int gpu_occupancy = device_sms_ * sm_occupancy; -+ int iters_per_tile = (problem_size.k() + tile_size_.k() - 1) / tile_size_.k(); -+ int sk_iters_per_normal_block = 0; -+ -+ int sk_regions = 1; // Default: a single region of iteration space (across all SK tiles) -+ int sk_blocks_per_region = 0; -+ -+ GemmCoord tiled_shape( -+ (problem_size.m() + tile_size_.m() - 1) / tile_size_.m(), -+ (problem_size.n() + tile_size_.n() - 1) / tile_size_.n(), -+ batch_count); -+ -+ size_t problem_bytes = -+ (sizeof(typename GemmKernel::ElementC) * problem_size.m() * problem_size.n()) + -+ (sizeof(typename GemmKernel::ElementA) * problem_size.m() * problem_size.k()) + -+ (sizeof(typename GemmKernel::ElementB) * problem_size.k() * problem_size.n()); -+ -+ size_t problem_flops = size_t(problem_size.m()) * size_t(problem_size.n()) * size_t(problem_size.k()) * 2; -+ -+ float flops_per_byte = float(problem_flops) / float(problem_bytes); -+ -+ int output_tiles = tiled_shape.m() * tiled_shape.n(); -+ int waves = (output_tiles + avail_sms - 1) / avail_sms; -+ float dp_efficiency = float(output_tiles) / float(waves * avail_sms); -+ -+ // -+ // Determine dispatch composition of DP-tiles and SK-blocks -+ // -+ -+ // Start with a DP-only configuration -+ int dp_tiles = output_tiles; // Number of data-parallel tiles -+ int sk_blocks = 0; // Number of thread blocks to produce the remaining SK tiles -+ -+ // Only kGemm mode allows for SK load balancing -+ if (mode_ == GemmUniversalMode::kGemm) -+ { -+ int split_factor = batch_split_; -+ if (split_factor > 1) -+ { -+ // Split-K override -+ dp_tiles = 0; -+ sk_blocks = output_tiles * split_factor; -+ } -+ else if ((kReductionStrategy != kNone) && // Load-balancing strategy statically enabled -+ (avail_sms > 1)) // Plurality of SMs to load balance across -+ { -+ // Use heuristics -+ get_blocks( -+ dp_tiles, /// [out] -+ sk_blocks, /// [out] -+ output_tiles, -+ iters_per_tile, -+ avail_sms, -+ sm_occupancy); -+ } -+ } -+ -+ sk_tiles = output_tiles - dp_tiles; -+ -+ -+ // Compute SK block iteration details -+ if (sk_blocks > 0) -+ { -+ sk_waves = (sk_blocks + avail_sms - 1) / avail_sms; -+ -+ int sk_iters = sk_tiles * iters_per_tile; -+ sk_blocks = fast_min(sk_blocks, sk_iters); -+ -+ sk_iters_per_normal_block = sk_iters / sk_blocks; -+ int extra_sk_iters = sk_iters - (sk_iters_per_normal_block * sk_blocks); -+ int sk_big_blocks = extra_sk_iters; -+ -+ if ((sk_blocks > sk_tiles) && (sk_blocks % sk_tiles == 0)) -+ { -+ // Split-K decomposition -+ sk_regions = sk_tiles; -+ } -+ -+ sk_blocks_per_region = sk_blocks / sk_regions; -+ sk_big_blocks_per_region = sk_big_blocks / sk_regions; -+ sk_iters_per_region = sk_iters / sk_regions; -+ -+ // Use a separate reduction wave when all of: -+ // - Non-atomic reduction stratgy -+ // - The number of SK waves won't fully occupy the GPU (Otherwise we don't have -+ // a strong-scaling case for more parallel reduction) -+ // - More than three peers working on an SK tile. (This occurs when the ratio of -+ // SK-blocks to SK-tiles > 2, as a single tile may be covered by four SK-blocks, -+ // e.g.:[partial-block | block | block | partial-block] ). With three or -+ // less peers, the two non-finishing SK-blocks are not expexted to contend. -+ if ((kReductionStrategy == kMixed) && -+ (sk_waves < sm_occupancy) && -+ (sk_blocks > 2 * sk_tiles)) -+ { -+ // Launch a reduction block for every accumulator fragment in each SK-tile -+ static const int kAccumulatorFragments = GemmKernel::Epilogue::kAccumulatorFragments; -+ reduction_blocks = sk_tiles * kAccumulatorFragments; -+ -+ } -+ -+ // When we have a multi-occupancy kernel and at least two waves of active blocks (where -+ // at least one wave is SK blocks), we need to (1) dispatch at least four waves, and (2) -+ // remap the block indices so that we can reliably spread the SK blocks evenly across the -+ // device's first SM occupancy valence. Also see get_num_blocks() and get_block_idx(). -+ remap_block_indices = ( -+ (sm_occupancy > 1) && -+ (device_sms_ == avail_sms) && -+ (get_num_active_blocks() > avail_sms * 2)); -+ -+ // Initialize fast div/mod members related to SK -+ div_mod_sk_iters_per_normal_block = FastDivmod(sk_iters_per_normal_block); -+ div_mod_sk_iters_per_big_block = FastDivmod(sk_iters_per_normal_block + 1); -+ div_mod_sk_iters_per_region = FastDivmod(sk_iters_per_region); -+ div_mod_sk_regions = FastDivmod(sk_regions); -+ div_mod_sk_blocks_per_region = FastDivmod(sk_blocks_per_region); -+ } -+ -+ // -+ // Compute DP blocks -+ // -+ -+ dp_blocks = dp_tiles; -+ -+ cutlass::gemm::GemmCoord tiled_cohort_shape( -+ (tiled_shape.m() + kCohortCtasM - 1) / kCohortCtasM, -+ (tiled_shape.n() + kCohortCtasN - 1) / kCohortCtasN, -+ tiled_shape.k()); -+ int cohort_blocks = (tiled_cohort_shape.m() * tiled_cohort_shape.n()) * kCtasPerCohort; -+ float cohort_efficiency = float(dp_blocks) / float(cohort_blocks); -+ -+ // Check if the SK tiles would be in cohorts that are in-bounds -+ bool sk_in_range = true; -+ if (sk_tiles > 0) -+ { -+ int last_sk_tile = sk_tiles - 1; -+ int cohort_tile_idx = last_sk_tile / kCtasPerCohort; -+ int cohort_grid_m = cohort_tile_idx / tiled_cohort_shape.n(); -+ int cohort_grid_n = (cohort_grid_m > 0) ? -+ tiled_cohort_shape.n() - 1 : -+ cohort_tile_idx % tiled_cohort_shape.n(); -+ -+ if ((((cohort_grid_m + 1) * kCohortCtasM) >= tiled_shape.m()) || -+ (((cohort_grid_n + 1) * kCohortCtasN) >= tiled_shape.n())) -+ { -+ sk_in_range = false; -+ } -+ -+ } -+ -+ // Decide if we're going to be doing cohort raster -+ if (sk_in_range && -+ (dp_blocks >= gpu_occupancy * 2) && -+ (cohort_efficiency > 0.85f)) -+ { -+ cohort_raster = true; -+ dp_blocks = cohort_blocks; -+ } -+ else if (sk_waves > 0) -+ { -+ // Update semi-persistence of first DP wave to ensure full grid wavesets -+ // (Only applies when there's an SK component and we're not doing blocked cohort rasterization) -+ int dp_tile_waves = (dp_tiles + avail_sms - 1) / avail_sms; -+ int full_dp_tile_waves = dp_tiles / avail_sms; -+ int waveset_excess = (sk_waves + dp_tile_waves) % sm_occupancy; -+ -+ if (dp_first_wave_tiles + waveset_excess <= full_dp_tile_waves) -+ { -+ dp_first_wave_tiles += waveset_excess; -+ dp_blocks -= (waveset_excess * avail_sms); -+ } -+ } -+ -+ // Setup fast-div/mod for device-side usage -+ div_mod_tiled_shape_m = FastDivmod(tiled_shape.m()); -+ div_mod_tiled_shape_n = FastDivmod(tiled_shape.n()); -+ div_mod_tiled_cohort_shape_n = FastDivmod(tiled_cohort_shape.n()); -+ div_mod_iters_per_tile = FastDivmod(iters_per_tile); -+ -+ } -+ -+ /// Number of blocks performing useful work -+ int get_num_active_blocks() const -+ { -+ return (sk_waves * avail_sms) + dp_blocks + reduction_blocks; -+ } -+ -+ /// Obtains number of threadblocks per GEMM -+ int get_num_blocks() const -+ { -+ int active_blocks = get_num_active_blocks(); -+ if (remap_block_indices) -+ { -+ // Add padding blocks if we are performing remapping in order to dispatch a grid of at least four waves -+ return fast_max(active_blocks, avail_sms * 4); -+ } -+ -+ return active_blocks; -+ } -+ -+ -+ /// Obtains grid extents in CTAs -+ dim3 get_grid_dims() const -+ { -+ return dim3(get_num_blocks(), 1, batch_count); -+ } -+ -+ -+// Guards needed for PyCUTLASS library generation -+#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) -+ -+ // -+ // Device-side interface -+ // -+ -+ /// Proves to the compiler that val is warp-uniform -+ CUTLASS_DEVICE -+ int uniform(int val) const -+ { -+ return __shfl_sync(0xffffffff, val, 0); -+ } -+ -+ /// Obtains number of threadblocks per GEMM -+ CUTLASS_DEVICE -+ int device_num_blocks() const -+ { -+ return gridDim.x; -+ } -+ -+ /// Obtains tile index for the given sk iteration -+ CUTLASS_DEVICE -+ int get_sk_tile_idx(int iter) const -+ { -+ int tile_idx = div_mod_iters_per_tile.div(iter); -+ return uniform(tile_idx); -+ } -+ -+ /// Obtains the batch index -+ CUTLASS_DEVICE -+ int get_batch_idx() const -+ { -+ return RematerializeBlockIdxZ(); -+ } -+ -+ /// Obtains the calling threadblock's tiled coordinates for the given tile index -+ CUTLASS_DEVICE -+ GemmCoord get_tile_offset(int tile_idx) const -+ { -+ int m, n; -+ -+ // row-major raster -+ div_mod_tiled_shape_n(m, n, tile_idx); -+ -+ if (tiled_shape().m() < tiled_shape().n()) -+ { -+ // column-major raster -+ div_mod_tiled_shape_m(n, m, tile_idx); -+ } -+ -+ if (cohort_raster) -+ { -+ // tiled cohort raster -+ int cohort_tile_idx = tile_idx / kCtasPerCohort; -+ int cohort_grid_m, cohort_grid_n; -+ div_mod_tiled_cohort_shape_n(cohort_grid_m, cohort_grid_n, cohort_tile_idx); -+ -+ int block_idx_cohort = tile_idx % kCtasPerCohort; -+ int block_cohort_m = block_idx_cohort / kCohortCtasN; -+ int block_cohort_n = block_idx_cohort % kCohortCtasN; -+ -+ m = (cohort_grid_m * kCohortCtasM) + block_cohort_m; -+ n = (cohort_grid_n * kCohortCtasN) + block_cohort_n; -+ } -+ -+ return GemmCoord(m, n, get_batch_idx()); -+ } -+ -+ /// Obtains the calling threadblock's tiled coordinates for the given tile index (row-major rastorization) -+ CUTLASS_DEVICE -+ GemmCoord get_tile_offset_row_major(int tile_idx) const -+ { -+ // row-major raster -+ int m, n; -+ div_mod_tiled_shape_n(m, n, tile_idx); -+ return GemmCoord(m, n, get_batch_idx()); -+ } -+ -+ /// Obtains calling threadblock's linear threadblock index -+ CUTLASS_DEVICE -+ int get_block_idx() const -+ { -+ int block_idx = RematerializeBlockIdxX(); -+ -+ // Remap the block indices for the first two waves of thread blocks if -+ // we have multi-occupancy and the grid constitutes four or more waves -+ if (remap_block_indices && (block_idx < avail_sms * 2)) -+ { -+ int dest_sm = block_idx / 2; -+ int dest_wave = block_idx % 2; -+ int remapped_block_idx = dest_sm + (dest_wave * avail_sms); -+ block_idx = remapped_block_idx; -+ } -+ -+ // Remap block indices to interleave SK regions to limit intra-region waiting -+ if (block_idx < sk_regions() * sk_blocks_per_region()) -+ { -+ int block_in_region; -+ int region; -+ div_mod_sk_regions(block_in_region, region, block_idx); -+ block_idx = (region * sk_blocks_per_region()) + block_in_region; -+ } -+ -+ return uniform(block_idx); -+ } -+ -+ -+ /// Obtains calling linear threadblock index of the first block to work on the given tile -+ CUTLASS_DEVICE -+ int get_sk_block_idx(int iter) const -+ { -+ int region_idx; -+ int iter_in_region; -+ div_mod_sk_iters_per_region(region_idx, iter_in_region, iter); -+ -+ int big_block_iters = (sk_big_blocks_per_region * sk_iters_per_normal_block()) + sk_big_blocks_per_region; // number of iterations in the region's big blocks -+ int normal_block_iters = iter_in_region - big_block_iters; // number of iterations in the region's normal bocks -+ -+ int big_block_idx_in_region = div_mod_sk_iters_per_big_block.div(iter_in_region); -+ int normal_block_idx_in_region = sk_big_blocks_per_region + div_mod_sk_iters_per_normal_block.div(normal_block_iters); -+ -+ int block_idx_in_region = (big_block_idx_in_region < sk_big_blocks_per_region) ? -+ big_block_idx_in_region : -+ normal_block_idx_in_region; -+ -+ int owning_block_idx = (sk_blocks_per_region() * region_idx) + block_idx_in_region; -+ -+ return owning_block_idx; -+ } -+ -+ /// Obtains iteration extends for the given SK block index -+ CUTLASS_DEVICE -+ void get_iter_extents( -+ int sk_block_idx, -+ int &block_iter_begin, -+ int &block_iter_end) const -+ { -+ int region_idx; -+ int block_idx_in_region; -+ div_mod_sk_blocks_per_region(region_idx, block_idx_in_region, sk_block_idx); -+ -+ block_iter_begin = (region_idx * sk_iters_per_region) + (block_idx_in_region * sk_iters_per_normal_block()); -+ -+ // Adjust extents for the first "num_big_blocks" blocks that get one extra iteration -+ int block_iters = sk_iters_per_normal_block(); -+ if (block_idx_in_region < sk_big_blocks_per_region) { -+ // This is a +1 iteration block -+ block_iter_begin += block_idx_in_region; -+ block_iters++; -+ } else { -+ // This is a regular block -+ block_iter_begin += sk_big_blocks_per_region; -+ } -+ block_iter_end = block_iter_begin + block_iters; -+ } -+ -+ -+ /// Obtains calling linear threadblock index of the first block to work on the given tile -+ CUTLASS_DEVICE -+ int get_first_block_idx(int tile_idx, int block_idx) const -+ { -+ if (tile_idx >= sk_tiles) { -+ // DP tile -+ return block_idx; -+ } -+ -+ int iter = tile_idx * iters_per_tile(); -+ return get_sk_block_idx(iter); -+ } -+ -+#endif // defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/default_mma_complex_tensor_op.h b/3rdparty/cutlass/include/cutlass/gemm/warp/default_mma_complex_tensor_op.h -new file mode 100644 -index 0000000..1c794b1 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/default_mma_complex_tensor_op.h -@@ -0,0 +1,612 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/warp/mma_complex_tensor_op.h" -+#include "cutlass/gemm/warp/mma_complex_tensor_op_fast_f32.h" -+#include "cutlass/gemm/warp/mma_gaussian_complex_tensor_op.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm80.h" -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A elements -+ typename ElementA_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename ElementB_, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename ElementC_, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Complex transform on A operand -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Complex transform on B operand -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ /// Multiply-add operator (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) -+ typename Operator_ = arch::OpMultiplyAddComplex> -+struct DefaultMmaComplexTensorOp; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for complex*complex case -+// 4 real-valued mma operations -+// A = (ar + j ai), B (br +j bi), D = AB -+// D = dr + j di = (ar*br - ai*bi) + j (ar*bi + ai*br) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Real-valued underlying type of complex-valued A operand -+ typename RealElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Real-valued underlying type of complex-valued B operand -+ typename RealElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Real-valued underlying type of complex-valued C operand -+ typename RealElementC, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Complex transform on A operand -+ ComplexTransform TransformA, -+ /// Complex transform on B operand -+ ComplexTransform TransformB> -+struct DefaultMmaComplexTensorOp< -+ WarpShape_, -+ InstructionShape_, -+ complex, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ TransformA, -+ TransformB, -+ arch::OpMultiplyAddComplex> { -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ InstructionShape_, -+ 32, -+ RealElementA, -+ cutlass::layout::RowMajor, -+ RealElementB, -+ cutlass::layout::ColumnMajor, -+ RealElementC, -+ cutlass::layout::RowMajor, -+ arch::OpMultiplyAdd>, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ // Define the warp-level tensor op -+ using Type = cutlass::gemm::warp::MmaComplexTensorOp< -+ WarpShape_, -+ complex, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ Policy, -+ TransformA, -+ TransformB>; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for complex*complex case using GaussianComplex operation -+// 3 real-valued mma operations -+// A = (ar + j ai), B = (br +j bi), D = AB -+// P1 = (ar + ai) * br, P2 = - ar * (br - bi), P3 = ai * (br + bi) -+// D = dr + j di = (P1 - P3) + j (P1 + P2) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Real-valued underlying type of complex-valued A operand -+ typename RealElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Real-valued underlying type of complex-valued B operand -+ typename RealElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Real-valued underlying type of complex-valued C operand -+ typename RealElementC, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Complex transform on A operand -+ ComplexTransform TransformA, -+ /// Complex transform on B operand -+ ComplexTransform TransformB> -+struct DefaultMmaComplexTensorOp< -+ WarpShape_, -+ InstructionShape_, -+ complex, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ TransformA, -+ TransformB, -+ arch::OpMultiplyAddGaussianComplex> { -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ InstructionShape_, -+ 32, -+ RealElementA, -+ cutlass::layout::RowMajor, -+ RealElementB, -+ cutlass::layout::ColumnMajor, -+ RealElementC, -+ cutlass::layout::RowMajor, -+ arch::OpMultiplyAdd>, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ // Define the warp-level tensor op -+ using Type = cutlass::gemm::warp::MmaGaussianComplexTensorOp< -+ WarpShape_, -+ complex, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ Policy, -+ TransformA, -+ TransformB>; -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Partial specialization - input and output types are complex*complex -+// Use TF32 tensor operation internally -+// 4 real-valued mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 operations on TF32 -+// A = (ar + j ai), B (br +j bi), D = AB -+// D = dr + j di = (ar*br - ai*bi) + j (ar*bi + ai*br) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Complex transform on A operand -+ ComplexTransform TransformA, -+ /// Complex transform on B operand -+ ComplexTransform TransformB> -+struct DefaultMmaComplexTensorOp< -+ WarpShape_, -+ InstructionShape_, -+ complex, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ TransformA, -+ TransformB, -+ arch::OpMultiplyAddComplex> { -+ -+ // Complex floating point tensor operation use mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 mma instruction -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ InstructionShape_, -+ 32, -+ tfloat32_t, -+ cutlass::layout::RowMajor, -+ tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ arch::OpMultiplyAdd>, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ // Define the warp-level tensor op -+ using Type = cutlass::gemm::warp::MmaComplexTensorOp< -+ WarpShape_, -+ complex, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ Policy, -+ TransformA, -+ TransformB>; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Partial specialization - input and output types are complex*complex -+// Use BF16 tensor operation internally -+// 4 real-valued mma.sync.aligned.m16n8k8.f32.bf16.bf16.f32 operations on BF16 -+// A = (ar + j ai), B (br +j bi), D = AB -+// D = dr + j di = (ar*br - ai*bi) + j (ar*bi + ai*br) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Complex transform on A operand -+ ComplexTransform TransformA, -+ /// Complex transform on B operand -+ ComplexTransform TransformB> -+struct DefaultMmaComplexTensorOp< -+ WarpShape_, -+ InstructionShape_, -+ complex, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ TransformA, -+ TransformB, -+ arch::OpMultiplyAddFastBF16> { -+ -+ // Complex floating point tensor operation use mma.sync.aligned.m16n8k8.f32.bf16.bf16.f32 mma instruction -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ InstructionShape_, -+ 32, -+ bfloat16_t, -+ cutlass::layout::RowMajor, -+ bfloat16_t, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ arch::OpMultiplyAdd>, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ // Define the warp-level tensor op -+ using Type = cutlass::gemm::warp::MmaComplexTensorOp< -+ WarpShape_, -+ complex, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ Policy, -+ TransformA, -+ TransformB>; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Partial specialization - input and output types are complex*complex -+// Use F16 tensor operation internally -+// 4 real-valued mma.sync.aligned.m16n8k8.f32.f16.f16.f32 operations on F16 -+// A = (ar + j ai), B (br +j bi), D = AB -+// D = dr + j di = (ar*br - ai*bi) + j (ar*bi + ai*br) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Complex transform on A operand -+ ComplexTransform TransformA, -+ /// Complex transform on B operand -+ ComplexTransform TransformB> -+struct DefaultMmaComplexTensorOp< -+ WarpShape_, -+ InstructionShape_, -+ complex, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ TransformA, -+ TransformB, -+ arch::OpMultiplyAddFastF16> { -+ -+ // Complex floating point tensor operation use mma.sync.aligned.m16n8k8.f32.f16.f16.f32 mma instruction -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ InstructionShape_, -+ 32, -+ half_t, -+ cutlass::layout::RowMajor, -+ half_t, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ arch::OpMultiplyAdd>, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ // Define the warp-level tensor op -+ using Type = cutlass::gemm::warp::MmaComplexTensorOp< -+ WarpShape_, -+ complex, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ Policy, -+ TransformA, -+ TransformB>; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// 3xTF32 or 4xTF32 (fast and accurate complex operation) -+/// Partial specialization - input and output types are complex * complex -+// Use 3xTF32 or 4xTF32 tensor operation internally -+// 4 real-valued mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 operations on TF32 -+// A = (ar + j ai), B (br +j bi), D = AB -+// D = dr + j di = 3x[(ar*br - ai*bi) + j (ar*bi + ai*br)] -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Complex transform on A operand -+ ComplexTransform TransformA, -+ /// Complex transform on B operand -+ ComplexTransform TransformB> -+struct DefaultMmaComplexTensorOp< -+ WarpShape_, -+ InstructionShape_, -+ complex, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ TransformA, -+ TransformB, -+ arch::OpMultiplyAddComplexFastF32> { -+ -+ // Complex floating point tensor operation use mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 mma instruction -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ InstructionShape_, -+ 32, -+ tfloat32_t, -+ cutlass::layout::RowMajor, -+ tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ arch::OpMultiplyAdd>, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ // Define the warp-level tensor op -+ using Type = cutlass::gemm::warp::MmaComplexTensorOpFastF32< -+ WarpShape_, -+ complex, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ Policy, -+ TransformA, -+ TransformB>; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for complex*complex case -+// 4 real-valued mma.sync.aligned.m16n8k4.f64.f64.f64.f64 operations -+// A = (ar + j ai), B (br +j bi), D = AB -+// D = dr + j di = (ar*br - ai*bi) + j (ar*bi + ai*br) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename WarpShape_, -+ /// Real-valued underlying type of complex-valued A operand -+ typename RealElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Real-valued underlying type of complex-valued B operand -+ typename RealElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Real-valued underlying type of complex-valued C operand -+ typename RealElementC, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Complex transform on A operand -+ ComplexTransform TransformA, -+ /// Complex transform on B operand -+ ComplexTransform TransformB> -+struct DefaultMmaComplexTensorOp< -+ WarpShape_, -+ GemmShape<16, 8, 4>, -+ complex, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ TransformA, -+ TransformB, -+ arch::OpMultiplyAddComplex> { -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ GemmShape<16, 8, 4>, -+ 32, -+ RealElementA, -+ cutlass::layout::RowMajor, -+ RealElementB, -+ cutlass::layout::ColumnMajor, -+ RealElementC, -+ cutlass::layout::RowMajor, -+ arch::OpMultiplyAdd>, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ // Define the warp-level tensor op -+ using Type = cutlass::gemm::warp::MmaComplexTensorOp< -+ WarpShape_, -+ complex, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ Policy, -+ TransformA, -+ TransformB, -+ true>; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Partial specialization for complex*complex case using GaussianComplex operation -+// 3 real-valued mma.sync.aligned.m16n8k4.f64.f64.f64.f64 operations -+// A = (ar + j ai), B = (br +j bi), D = AB -+// P1 = (ar + ai) * br, P2 = - ar * (br - bi), P3 = ai * (br + bi) -+// D = dr + j di = (P1 - P3) + j (P1 + P2) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename WarpShape_, -+ /// Real-valued underlying type of complex-valued A operand -+ typename RealElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Real-valued underlying type of complex-valued B operand -+ typename RealElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Real-valued underlying type of complex-valued C operand -+ typename RealElementC, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Complex transform on A operand -+ ComplexTransform TransformA, -+ /// Complex transform on B operand -+ ComplexTransform TransformB> -+struct DefaultMmaComplexTensorOp< -+ WarpShape_, -+ GemmShape<16, 8, 4>, -+ complex, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ TransformA, -+ TransformB, -+ arch::OpMultiplyAddGaussianComplex> { -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ GemmShape<16, 8, 4>, -+ 32, -+ RealElementA, -+ cutlass::layout::RowMajor, -+ RealElementB, -+ cutlass::layout::ColumnMajor, -+ RealElementC, -+ cutlass::layout::RowMajor, -+ arch::OpMultiplyAdd>, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ // Define the warp-level tensor op -+ using Type = cutlass::gemm::warp::MmaGaussianComplexTensorOp< -+ WarpShape_, -+ complex, -+ LayoutA, -+ complex, -+ LayoutB, -+ complex, -+ LayoutC, -+ Policy, -+ TransformA, -+ TransformB, -+ true>; -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/default_mma_sparse_tensor_op.h b/3rdparty/cutlass/include/cutlass/gemm/warp/default_mma_sparse_tensor_op.h -new file mode 100644 -index 0000000..89f8f1c ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/default_mma_sparse_tensor_op.h -@@ -0,0 +1,165 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/warp/mma_sparse_tensor_op.h" -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A elements -+ typename ElementA_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename ElementB_, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename ElementC_, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Operator describing the tensor operation -+ typename Operator_ = arch::OpMultiplyAdd, -+ /// Number of partitions along K dimension -+ int PartitionsK = 1, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor = false -+> -+struct DefaultSparseMmaTensorOp; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial Specialization - inputs and output types are float - uses TF32 internally -+template < -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of target matrix multiply instruction (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Number of partitions along K dimension -+ int PartitionsK, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor> -+struct DefaultSparseMmaTensorOp< -+ WarpShape_, -+ InstructionShape_, -+ float, LayoutA, -+ float, LayoutB, -+ float, LayoutC, -+ arch::OpMultiplyAdd, PartitionsK, AccumulatorsInRowMajor> { -+ -+ // Uses TF32 internally -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::SparseMma< -+ InstructionShape_, -+ 32, -+ tfloat32_t, cutlass::layout::RowMajor, -+ tfloat32_t, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::RowMajor, -+ arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> >; -+ -+ // Define the warp-level tensor op -+ using Type = cutlass::gemm::warp::SparseMmaTensorOp< -+ WarpShape_, float, LayoutA, float, LayoutB, float, LayoutC, -+ Policy, PartitionsK, AccumulatorsInRowMajor>; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for m-by-n-by-kgroup -+template < -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A elements -+ typename ElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Data type of B elements -+ typename ElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Element type of C matrix -+ typename ElementC, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Operator describing the tensor operation -+ typename Operator_, -+ /// Number of partitions along K dimension -+ int PartitionsK, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor> -+struct DefaultSparseMmaTensorOp { -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::SparseMma, -+ cutlass::MatrixShape<1, 1> >; -+ -+ // Define the warp-level tensor op -+ using Type = cutlass::gemm::warp::SparseMmaTensorOp< -+ WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, -+ Policy, PartitionsK, AccumulatorsInRowMajor>; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/default_mma_tensor_op.h b/3rdparty/cutlass/include/cutlass/gemm/warp/default_mma_tensor_op.h -new file mode 100644 -index 0000000..3421de9 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/default_mma_tensor_op.h -@@ -0,0 +1,123 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/warp/mma_tensor_op.h" -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A elements -+ typename ElementA_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename ElementB_, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename ElementC_, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Operator describing the tensor operation -+ typename Operator_ = arch::OpMultiplyAdd, -+ /// Number of partitions along K dimension -+ int PartitionsK = 1, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor = false> -+struct DefaultMmaTensorOp; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for m-by-n-by-kgroup -+template < -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A elements -+ typename ElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Data type of B elements -+ typename ElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Element type of C matrix -+ typename ElementC, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Operator describing the tensor operation -+ typename Operator_, -+ /// Number of partitions along K dimension -+ int PartitionsK, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor> -+struct DefaultMmaTensorOp { -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma, -+ cutlass::MatrixShape<1, 1> >; -+ -+ // Define the warp-level tensor op -+ using Type = cutlass::gemm::warp::MmaTensorOp< -+ WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, -+ Policy, PartitionsK, AccumulatorsInRowMajor>; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include "default_mma_tensor_op_sm80.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h b/3rdparty/cutlass/include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h -new file mode 100644 -index 0000000..d4d8026 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h -@@ -0,0 +1,238 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/mma.h" -+#include "cutlass/gemm/warp/mma_tensor_op.h" -+#include "cutlass/gemm/warp/mma_tensor_op_fast_f32.h" -+#include "cutlass/gemm/warp/default_mma_tensor_op.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial Specialization - inputs and output types are float - uses BF16 internally -+template < -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename WarpShape_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Number of partitions along K dimension -+ int PartitionsK, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor> -+struct DefaultMmaTensorOp< -+ WarpShape_, -+ GemmShape<16, 8, 8>, -+ float, LayoutA, -+ float, LayoutB, -+ float, LayoutC, -+ arch::OpMultiplyAddFastBF16, -+ PartitionsK, AccumulatorsInRowMajor> { -+ -+ // Uses BF16 internally -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ GemmShape<16, 8, 8>, -+ 32, -+ bfloat16_t, cutlass::layout::RowMajor, -+ bfloat16_t, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::RowMajor, -+ arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> >; -+ -+ // Define the warp-level tensor op -+ using Type = cutlass::gemm::warp::MmaTensorOp< -+ WarpShape_, float, LayoutA, float, LayoutB, float, LayoutC, -+ Policy, PartitionsK, AccumulatorsInRowMajor>; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial Specialization - inputs and output types are float - uses F16 internally -+template < -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename WarpShape_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Number of partitions along K dimension -+ int PartitionsK, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor> -+struct DefaultMmaTensorOp< -+ WarpShape_, -+ GemmShape<16, 8, 8>, -+ float, LayoutA, -+ float, LayoutB, -+ float, LayoutC, -+ arch::OpMultiplyAddFastF16, -+ PartitionsK, AccumulatorsInRowMajor> { -+ -+ // Uses F16 internally -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ GemmShape<16, 8, 8>, -+ 32, -+ half_t, cutlass::layout::RowMajor, -+ half_t, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::RowMajor, -+ arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> >; -+ -+ // Define the warp-level tensor op -+ using Type = cutlass::gemm::warp::MmaTensorOp< -+ WarpShape_, float, LayoutA, float, LayoutB, float, LayoutC, -+ Policy, PartitionsK, AccumulatorsInRowMajor>; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial Specialization - inputs and output types are float - uses TF32 internally -+template < -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of target matrix multiply instruction (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Number of partitions along K dimension -+ int PartitionsK, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor> -+struct DefaultMmaTensorOp< -+ WarpShape_, -+ InstructionShape_, -+ float, LayoutA, -+ float, LayoutB, -+ float, LayoutC, -+ arch::OpMultiplyAdd, PartitionsK, AccumulatorsInRowMajor> { -+ -+ // Uses TF32 internally -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ InstructionShape_, -+ 32, -+ tfloat32_t, cutlass::layout::RowMajor, -+ tfloat32_t, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::RowMajor, -+ arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> >; -+ -+ // Define the warp-level tensor op -+ using Type = cutlass::gemm::warp::MmaTensorOp< -+ WarpShape_, float, LayoutA, float, LayoutB, float, LayoutC, -+ Policy, PartitionsK, AccumulatorsInRowMajor>; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial Specialization - inputs and output types are float - uses TF32 for Fast Accurate FP32 -+template < -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of target matrix multiply instruction (concept: GemmShape) -+ typename InstructionShape_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Number of partitions along K dimension -+ int PartitionsK, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor> -+struct DefaultMmaTensorOp< -+ WarpShape_, -+ InstructionShape_, -+ float, LayoutA, -+ float, LayoutB, -+ float, LayoutC, -+ arch::OpMultiplyAddFastF32, PartitionsK, AccumulatorsInRowMajor> { -+ -+ // Uses TF32 internally -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ InstructionShape_, -+ 32, -+ cutlass::tfloat32_t, cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::RowMajor, -+ arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> >; -+ -+ // Define the warp-level tensor op -+ using Type = cutlass::gemm::warp::MmaTensorOpFastF32< -+ WarpShape_, float, LayoutA, float, LayoutB, float, LayoutC, -+ Policy, PartitionsK, AccumulatorsInRowMajor>; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include "cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/default_mma_with_reduction_tensor_op.h b/3rdparty/cutlass/include/cutlass/gemm/warp/default_mma_with_reduction_tensor_op.h -new file mode 100644 -index 0000000..63effe8 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/default_mma_with_reduction_tensor_op.h -@@ -0,0 +1,92 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/warp/mma_with_reduction_tensor_op.h" -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A elements -+ typename ElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Data type of B elements -+ typename ElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Element type of C matrix -+ typename ElementC, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Operator describing the tensor operation -+ typename Operator_, -+ /// Reduce operand A or B along K dimension -+ bool ReduceKForA_, -+ /// Number of partitions along K dimension -+ int PartitionsK = 1, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor = false> -+struct DefaultMmaWithReductionTensorOp { -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma, -+ cutlass::MatrixShape<1, 1> >; -+ -+ // Define the warp-level tensor op -+ using Type = cutlass::gemm::warp::MmaWithReductionTensorOp< -+ WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, -+ Policy, ReduceKForA_, PartitionsK, AccumulatorsInRowMajor>; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/default_mma_wmma_tensor_op.h b/3rdparty/cutlass/include/cutlass/gemm/warp/default_mma_wmma_tensor_op.h -new file mode 100644 -index 0000000..4f951d4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/default_mma_wmma_tensor_op.h -@@ -0,0 +1,130 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands. -+*/ -+ -+#pragma once -+ -+#include "cutlass/arch/wmma.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/warp/mma_tensor_op_wmma.h" -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ ///< Size of the Gemm problem (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A elements -+ typename ElementA_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename ElementB_, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename ElementC_, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Operator describing the tensor operation -+ typename Operator_ = arch::OpMultiplyAdd, -+ /// Number of partitions along K dimension -+ int PartitionsK = 1 -+> -+struct DefaultMmaTensorOpWmma; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for m-by-n-by-kgroup -+template < -+ ///< Shape of one matrix production operation (concept: GemmShape) -+ typename WarpShape_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Data type of A elements -+ typename ElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Data type of B elements -+ typename ElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Element type of C matrix -+ typename ElementC, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC, -+ /// Operator describing the tensor operation -+ typename Operator_, -+ /// Number of partitions along K dimension -+ int PartitionsK> -+struct DefaultMmaTensorOpWmma { -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Wmma< -+ InstructionShape_, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ Operator_>, -+ cutlass::MatrixShape<1, 1> >; -+ -+ // Define the warp-level tensor op -+ using Type = cutlass::gemm::warp::MmaTensorOpWmma< -+ WarpShape_, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ Policy, -+ PartitionsK>; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+#endif -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/layernorm_scale_bias_transform.h b/3rdparty/cutlass/include/cutlass/gemm/warp/layernorm_scale_bias_transform.h -new file mode 100644 -index 0000000..c604ef3 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/layernorm_scale_bias_transform.h -@@ -0,0 +1,140 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing warp-level per channel scale+bias+relu before -+ matrix multiply-accumulate operations targeting Tensor Cores. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/platform/platform.h" -+ -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/arch/mma_sm75.h" -+#include "cutlass/arch/mma_sm80.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_policy.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct LayernormScaleBiasTransform { -+ -+ using T = typename FragmentActivations::Element; -+ -+ static int const NumActivations = FragmentActivations::kElements; -+ static int const NumVarMean = FragmentVarMean::kElements; -+ static int const NumGammaBeta = FragmentGammaBeta::kElements; -+ static int const MmaElements = 2; -+ // One element has one scale and one bias -+ static int const MmaScaleBiasPair = 2; -+ // 16816 has 2 columns and 2 rows -+ static int const MmaCols = 2; -+ static int const MmaRows = 2; -+ -+ using MmaOperand = Array; -+ using VarMeanOperand = Array<__half2, MmaScaleBiasPair>; -+ using GammaBetaOperand = Array; -+ -+ CUTLASS_DEVICE -+ void transform(MmaOperand &activations, -+ VarMeanOperand const &var_mean, -+ GammaBetaOperand const &gamma_beta) { -+ -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) -+ uint32_t *ptr_activations = reinterpret_cast(&activations); -+ uint32_t const *ptr_var_mean = reinterpret_cast(&var_mean); -+ uint32_t const *ptr_gamma_beta = reinterpret_cast(&gamma_beta); -+ -+ // Apply per channel scale+bias+relu if the data is not a special NaN -+ // (0x7eff). If it is a special NaN (0x7eff), hard code the output to 0. -+ -+ // We assumes the pair of FP16 are either both inbound or both out-of-bound. -+ // It requires C to be an even number. -+ asm volatile( -+ "{\n\t" -+ " fma.rn.f16x2 %0, %1, %2, %3;\n" -+ " fma.rn.f16x2 %0, %4, %0, %5;\n" -+ "}\n" -+ : "=r"(ptr_activations[0]) -+ : "r"(ptr_var_mean[0]), "r"(ptr_activations[0]), -+ "r"(ptr_var_mean[1]), -+ "r"(ptr_gamma_beta[0]), "r"(ptr_gamma_beta[1])); -+#else -+ // TODO: write emulation code -+ assert(0); -+#endif -+ } -+ -+ CUTLASS_DEVICE -+ void operator()(FragmentActivations &activations, -+ FragmentVarMean const &var_mean, -+ FragmentGammaBeta const &gamma_beta) { -+ MmaOperand *ptr_activations = reinterpret_cast(&activations); -+ VarMeanOperand const *ptr_var_mean = -+ reinterpret_cast(&var_mean); -+ GammaBetaOperand const *ptr_gamma_beta = -+ reinterpret_cast(&gamma_beta); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < (NumActivations / MmaElements); ++i) { -+ transform(ptr_activations[i], -+ ptr_var_mean[i / (MmaCols * MmaRows) * MmaRows + i % MmaRows], -+ ptr_gamma_beta[(i / MmaScaleBiasPair) % MmaCols]); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma.h -new file mode 100644 -index 0000000..1f3ca94 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma.h -@@ -0,0 +1,60 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates exposing architecture support for warp-level multiply-add operations -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Query the number of threads per warp -+template -+struct WarpSize { -+ static int const value = 32; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op.h -new file mode 100644 -index 0000000..7bcf7fe ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op.h -@@ -0,0 +1,1167 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing warp-level matrix multiply-accumulate operations targeting -+ Tensor Cores. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/complex.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/functional.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/arch/mma_sm75.h" -+#include "cutlass/arch/mma_sm80.h" -+#include "cutlass/arch/mma_sm90.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_policy.h" -+#include "cutlass/gemm/warp/mma_tensor_op.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" -+#include "cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template < -+ /// Data type of real & imag members of complex numbers in the SourceFragment -+ typename RealElement, -+ /// Destination fragment required by the mma operation -+ typename DestinationFragment, -+ /// Source fragment holding complex elements -+ typename SourceFragment, -+ /// Number of mma operations performed -+ typename MmaIterations, -+ /// Shape of operand elements -+ typename MmaOperandShape, -+ /// Complex transform on A operand -+ ComplexTransform Transform_, -+ /// Operand A or Operand B -+ Operand Operand_, -+ /// Floating-point rounding style -+ FloatRoundStyle Round_> -+struct UnpackComplexConvertAndPackForMma; -+ -+// Partial specialization for OperandA and Congruous smem layout -+template < -+ typename RealElement, -+ typename DestinationFragment, -+ typename SourceFragment, -+ typename MmaIterations, -+ typename MmaOperandShape, -+ ComplexTransform Transform_, -+ FloatRoundStyle Round_> -+struct UnpackComplexConvertAndPackForMma < -+ RealElement, -+ DestinationFragment, -+ SourceFragment, -+ MmaIterations, -+ MmaOperandShape, -+ Transform_, -+ Operand::kA, -+ Round_> { -+ -+ // -+ // Type definitions -+ // -+ static Operand const kOperand = Operand::kA; -+ static ComplexTransform const kTransform = Transform_; -+ static FloatRoundStyle const kRound = Round_; -+ -+ // Data type of elements in the destination fragment -+ using MmaElement = typename DestinationFragment::Element; -+ -+ // Numeric convertor MmaElement <= RealElement -+ using Converter = NumericConverter; -+ -+ // Operand layout parameters -+ using SourceFragmentLayout = layout::ColumnMajor; -+ static int const kLdm = MmaIterations::kRow * MmaOperandShape::kRow; -+ -+ /// Ctor -+ CUTLASS_DEVICE -+ UnpackComplexConvertAndPackForMma() {} -+ -+ CUTLASS_DEVICE -+ void operator()(DestinationFragment *dest, SourceFragment const &source) { -+ -+ Converter convert_op; -+ SourceFragmentLayout layout(kLdm); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int i=0; i and apply rounding on real and imag parts -+ MmaElement a = convert_op(source[layout(MatrixCoord{row,col})].real()); -+ MmaElement b = convert_op(source[layout(MatrixCoord{row,col})].imag()); -+ -+ // Unpack rounded complex and pack into DestinationFragment for mma operation -+ dest[i][pos] = a; -+ dest[i+MmaIterations::kRow][pos++] = (kTransform == ComplexTransform::kConjugate ? -b : b); -+ -+ } -+ } -+ } -+ } -+}; -+ -+// Partial specialization for OperandB and Congruous smem layout -+template < -+ typename RealElement, -+ typename DestinationFragment, -+ typename SourceFragment, -+ typename MmaIterations, -+ typename MmaOperandShape, -+ ComplexTransform Transform_, -+ FloatRoundStyle Round_> -+struct UnpackComplexConvertAndPackForMma < -+ RealElement, -+ DestinationFragment, -+ SourceFragment, -+ MmaIterations, -+ MmaOperandShape, -+ Transform_, -+ Operand::kB, -+ Round_> { -+ -+ // -+ // Type definitions -+ // -+ static Operand const kOperand = Operand::kB; -+ static ComplexTransform const kTransform = Transform_; -+ static FloatRoundStyle const kRound = Round_; -+ -+ // Data type of elements in the destination fragment -+ using MmaElement = typename DestinationFragment::Element; -+ -+ // Numeric convertor MmaElement <= RealElement -+ using Converter = NumericConverter; -+ -+ // Operand layout parameters -+ using SourceFragmentLayout = layout::RowMajor; -+ static int const kLdm = MmaIterations::kColumn * MmaOperandShape::kColumn; -+ -+ /// Ctor -+ CUTLASS_DEVICE -+ UnpackComplexConvertAndPackForMma() {} -+ -+ CUTLASS_HOST_DEVICE -+ void operator()(DestinationFragment *dest, SourceFragment const &source) { -+ -+ Converter convert_op; -+ SourceFragmentLayout layout(kLdm); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int i=0; i apply rounding on real and imag parts -+ MmaElement a = convert_op(source[layout(MatrixCoord{row,col})].real()); -+ MmaElement b = convert_op(source[layout(MatrixCoord{row,col})].imag()); -+ -+ // Unpack rounded complex and pack into DestinationFragment for mma operation -+ dest[i][pos] = a; -+ dest[i+MmaIterations::kColumn][pos++] = (kTransform == ComplexTransform::kConjugate ? -b : b); -+ } -+ } -+ } -+ } -+}; -+} // namespace detail -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Data type of A elements -+ typename RealElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename RealElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename RealElementC, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ typename Policy_, -+ /// Complex transform on A operand -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Complex transform on B operand -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ /// Do source operands need more than one elements -+ bool GeneralizedOperatorElements = false, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+class MmaComplexTensorOp; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for complex*complex+complex => complex using real-valued TensorOps -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Data type of A elements -+ typename RealElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename RealElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename RealElementC, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ typename Policy_, -+ /// Complex transform on A operand -+ ComplexTransform TransformA, -+ /// Complex transform on B operand -+ ComplexTransform TransformB -+> -+class MmaComplexTensorOp< -+ Shape_, -+ complex, -+ LayoutA_, -+ complex, -+ LayoutB_, -+ complex, -+ LayoutC_, -+ Policy_, -+ TransformA, -+ TransformB> { -+public: -+ /// Shape of warp-level matrix operation (concept: GemmShape) -+ using Shape = Shape_; -+ -+ /// Data type of multiplicand A -+ using ElementA = complex; -+ -+ /// Layout of multiplicand A -+ using LayoutA = LayoutA_; -+ -+ /// Data type of multiplicand B -+ using ElementB = complex; -+ -+ /// Layout of multiplicand B -+ using LayoutB = LayoutB_; -+ -+ /// Data type of accumulator matrix C -+ using ElementC = complex; -+ -+ /// Layout of accumulator matrix C -+ using LayoutC = LayoutC_; -+ -+ /// Shape of the warp in units of thread (concept: MmaLanePolicyTensorOp) -+ using Policy = Policy_; -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ using ArchMmaOperator = typename Policy::Operator; -+ -+ /// Architecture tag from underlying instruction -+ using ArchTag = typename ArchMmaOperator::ArchTag; -+ -+ /// Indicates class of matrix operator -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Shape of underlying instruction -+ using InstructionShape = typename ArchMmaOperator::Shape; -+ -+ /// Indicates math operator -+ using MathOperator = arch::OpMultiplyAddComplex; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = TransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = TransformB; -+ -+ /// Number of threads participating in warp-level matrix product -+ static int const kThreadCount = 32; -+ -+public: -+ -+ /// Iterates over the A operand in memory -+ using IteratorA = MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, -+ Operand::kA, -+ ElementA, -+ LayoutA, -+ MatrixShape, -+ Policy::OpDelta::kRow, -+ 32, -+ 1 -+ >; -+ -+ /// Storage for A tile -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Storage for transformed A tile -+ using TransformedFragmentA = FragmentA; -+ -+ /// Iterates over the B operand in memory -+ using IteratorB = MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, -+ Operand::kB, -+ ElementB, -+ LayoutB, -+ MatrixShape, -+ Policy::OpDelta::kColumn, -+ 32, -+ 1 -+ >; -+ -+ /// Storage for B tile -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Storage for transformed B tile -+ using TransformedFragmentB = FragmentB; -+ -+ static_assert( -+ !(Shape::kM % ArchMmaOperator::Shape::kM) && -+ !(Shape::kN % ArchMmaOperator::Shape::kN), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ -+ /// Number of mma operations performed -+ using MmaIterations = MatrixShape< -+ Shape::kM / ArchMmaOperator::Shape::kM, -+ Shape::kN / ArchMmaOperator::Shape::kN -+ >; -+ -+ /// Iterates over the C operand in memory -+ using IteratorC = MmaTensorOpAccumulatorTileIterator< -+ MatrixShape, -+ ElementC, -+ LayoutC, -+ typename ArchMmaOperator::Shape, -+ typename Policy::OpDelta>; -+ -+ /// Storage for C tile, the accumulator. Note, regardless of multiplicand type, this -+ /// storage arrangement is to be considered 'planar complex' in the sense that all real-valued -+ /// parts are stored consecutively followed by all imaginary parts. This matches the structure -+ /// of Tensor Cores which are always real-valued matrix multiplies. -+ using FragmentC = typename IteratorC::Fragment; -+ -+ static_assert( -+ FragmentC::kElements == 2 * MmaIterations::kCount * ArchMmaOperator::FragmentC::kElements, -+ "Unexpected planar complex fragment length."); -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Underlying real-valued matrix multiply operator (concept: arch::Mma) -+ ArchMmaOperator mma; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_DEVICE -+ MmaComplexTensorOp() {} -+ -+ /// Performs a warp-level matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &D, -+ FragmentA const &A, -+ FragmentB const &B, -+ FragmentC const &C -+ ) const { -+ -+ // Alias types for underlying real-valued matrix multiply operator -+ using MmaOperandA = typename ArchMmaOperator::FragmentA; -+ using MmaOperandB = typename ArchMmaOperator::FragmentB; -+ using MmaOperandC = typename ArchMmaOperator::FragmentC; -+ -+ static_assert(MmaOperandA::kElements == 1, -+ "This implementation only supports math instructions in which exactly one element is needed for the A operand." -+ "We can geneneralize later."); -+ -+ static_assert(MmaOperandB::kElements == 1, -+ "This implementation only supports math instructions in which exactly one element is needed for the B operand." -+ "We can geneneralize later."); -+ -+ D = C; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < MmaIterations::kRow; ++m) { -+ -+ // mma(accum.real(), a.real(), b.real(), accum.real()); -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; ++n) { -+ -+ // Pack operands together. This may result in actual MOVs -+ MmaOperandA operand_A; -+ MmaOperandB operand_B; -+ -+ operand_A[0] = A[m].real(); -+ operand_B[0] = B[n].real(); -+ -+ // Real-valued accumulator part -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow); -+ -+ mma(*accum, operand_A, operand_B, *accum); -+ } -+ -+ // mma(accum.imag(), a.real(), b.imag(), accum.imag()); -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = MmaIterations::kColumn - 1; n >= 0; --n) { -+ -+ // Pack operands together. This may result in actual MOVs -+ MmaOperandA operand_A; -+ MmaOperandB operand_B; -+ -+ operand_A[0] = A[m].real(); -+ operand_B[0] = (kTransformB == ComplexTransform::kConjugate ? -B[n].imag() : B[n].imag()); -+ -+ // Complex-valued accumulator part -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow) + MmaIterations::kCount; -+ -+ mma(*accum, operand_A, operand_B, *accum); -+ } -+ -+ // mma(accum.real(), -a.imag(), b.imag(), accum.real()) -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; ++n) { -+ -+ // Pack operands together. This may result in actual MOVs -+ MmaOperandA operand_A; -+ MmaOperandB operand_B; -+ -+ // A imaginary part is intentionally negated -+ operand_A[0] = (kTransformA == ComplexTransform::kConjugate ? A[m].imag() : -A[m].imag()); -+ operand_B[0] = (kTransformB == ComplexTransform::kConjugate ? -B[n].imag() : B[n].imag()); -+ -+ // Real-valued accumulator part -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow); -+ -+ mma(*accum, operand_A, operand_B, *accum); -+ } -+ -+ // mma(accum.imag(), a.imag(), b.real(), accum.imag()) -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = MmaIterations::kColumn - 1; n >= 0; --n) { -+ -+ // Pack operands together. This may result in actual MOVs -+ MmaOperandA operand_A; -+ MmaOperandB operand_B; -+ -+ operand_A[0] = (kTransformA == ComplexTransform::kConjugate ? -A[m].imag() : A[m].imag()); -+ operand_B[0] = B[n].real(); -+ -+ // Complex-valued accumulator part -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow) + MmaIterations::kCount; -+ -+ mma(*accum, operand_A, operand_B, *accum); -+ } -+ } -+ } -+ -+ /// Transform the mma operands to the required types -+ CUTLASS_DEVICE -+ void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, -+ FragmentA const &A, FragmentB const &B) const { -+ //TODO: Implement this -+ dst_A = A; -+ dst_B = B; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for complex*complex+complex => complex: -+// Operands data type: complex -+// Rounding: float -> tfloat32_t (round half_ulp_truncate nearest) -+// Math instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 -+// Output data type: complex -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ typename Policy_, -+ /// Complex transform on A operand -+ ComplexTransform TransformA, -+ /// Complex transform on B operand -+ ComplexTransform TransformB -+> -+class MmaComplexTensorOp< -+ Shape_, -+ complex, -+ LayoutA_, -+ complex, -+ LayoutB_, -+ complex, -+ LayoutC_, -+ Policy_, -+ TransformA, -+ TransformB> { -+public: -+ /// Shape of warp-level matrix operation (concept: GemmShape) -+ using Shape = Shape_; -+ -+ /// Data type of members of complex multiplicand A -+ using RealElementA = float; -+ -+ /// Data type of multiplicand A -+ using ElementA = complex; -+ -+ /// Layout of multiplicand A -+ using LayoutA = LayoutA_; -+ -+ /// Data type of members of complex multiplicand B -+ using RealElementB = float; -+ -+ /// Data type of multiplicand B -+ using ElementB = complex; -+ -+ /// Layout of multiplicand B -+ using LayoutB = LayoutB_; -+ -+ /// Data type of members of complex accumulator matrix C -+ using RealElementC = float; -+ -+ /// Data type of accumulator matrix C -+ using ElementC = complex; -+ -+ /// Layout of accumulator matrix C -+ using LayoutC = LayoutC_; -+ -+ /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) -+ using Policy = Policy_; -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ using ArchMmaOperator = typename Policy::Operator; -+ -+ /// Shape of underlying instruction -+ using InstructionShape = typename ArchMmaOperator::Shape; -+ -+ /// Underlying arch tag -+ using ArchTag = typename ArchMmaOperator::ArchTag; -+ -+ /// Indicates class of matrix operator -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Indicates math operator -+ using MathOperator = typename arch::OpMultiplyAddComplex; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = TransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = TransformB; -+ -+ /// Number of threads participating in warp-level matrix product -+ static int const kThreadCount = 32; -+ -+public: -+ -+ /// Iterates over the A operand in memory -+ using IteratorA = MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, -+ Operand::kA, -+ ElementA, -+ LayoutA, -+ MatrixShape, -+ Policy::OpDelta::kRow, -+ 32, -+ 1 -+ >; -+ -+ /// Storage for A tile -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Storage for transformed A tile -+ using TransformedFragmentA = -+ Array; -+ -+ /// Iterates over the B operand in memory -+ using IteratorB = MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, -+ Operand::kB, -+ ElementB, -+ LayoutB, -+ MatrixShape, -+ Policy::OpDelta::kColumn, -+ 32, -+ 1 -+ >; -+ -+ /// Storage for B tile -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Storage for transformed B tile -+ using TransformedFragmentB = -+ Array; -+ -+ static_assert( -+ !(Shape::kM % ArchMmaOperator::Shape::kM) && -+ !(Shape::kN % ArchMmaOperator::Shape::kN), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ -+ /// Number of complex products operations performed (one complex product needs four mma instructions) -+ using MmaIterations = MatrixShape< -+ Shape::kM / ArchMmaOperator::Shape::kM, -+ Shape::kN / ArchMmaOperator::Shape::kN -+ >; -+ -+ /// Iterates over the C operand in memory -+ using IteratorC = MmaTensorOpAccumulatorTileIterator< -+ MatrixShape, -+ ElementC, -+ LayoutC, -+ typename ArchMmaOperator::Shape, -+ typename Policy::OpDelta>; -+ -+ /// Storage for C tile, the accumulator. Note, regardless of multiplicand type, this -+ /// storage arrangement is to be considered 'planar complex' in the sense that all real-valued -+ /// parts are stored consecutively followed by all imaginary parts. This matches the structure -+ /// of Tensor Cores which are always real-valued matrix multiplies. -+ using FragmentC = typename IteratorC::Fragment; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Underlying real-valued matrix multiply operator (concept: arch::Mma) -+ ArchMmaOperator mma; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_DEVICE -+ MmaComplexTensorOp() {} -+ -+ /// Performs a warp-level matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &D, -+ TransformedFragmentA const &A, -+ TransformedFragmentB const &B, -+ FragmentC const &C -+ ) const { -+ -+ // Alias types for underlying real-valued matrix multiply operator -+ using InstMmaOperandA = typename ArchMmaOperator::FragmentA; -+ using InstMmaOperandB = typename ArchMmaOperator::FragmentB; -+ using MmaOperandC = typename ArchMmaOperator::FragmentC; -+ -+ static_assert(platform::is_same, typename ArchMmaOperator::Shape>::value, -+ "This implementation only supports mma.m16n8k8 math instructions."); -+ -+ static_assert(InstMmaOperandA::kElements == 4, -+ "This implementation only supports math instructions in which exactly four element is needed for the A operand." -+ "We can geneneralize later."); -+ -+ static_assert(InstMmaOperandB::kElements == 2, -+ "This implementation only supports math instructions in which exactly two element is needed for the B operand." -+ "We can geneneralize later."); -+ -+ // Instruction Operands A & B holding real part followed by imaginary part for mma operations -+ InstMmaOperandA const *operand_A = reinterpret_cast(&A); -+ InstMmaOperandB const *operand_B = reinterpret_cast(&B); -+ -+ // -+ // Accumulate in place -+ // -+ D = C; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < MmaIterations::kRow; ++m) { -+ -+ // mma(accum.real(), a.real(), b.real(), accum.real()); -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; ++n) { -+ -+ // Real-valued accumulator part -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow); -+ -+ mma(*accum, operand_A[m], operand_B[n], *accum); -+ } -+ -+ // mma(accum.imag(), a.real(), b.imag(), accum.imag()); -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = MmaIterations::kColumn - 1; n >= 0; --n) { -+ -+ // Complex-valued accumulator part -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow) + MmaIterations::kCount; -+ -+ mma(*accum, operand_A[m], operand_B[n+MmaIterations::kColumn], *accum); -+ } -+ -+ // mma(accum.real(), a.imag(), -b.imag(), accum.real()) -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; ++n) { -+ -+ // negate OperandB to accumulate -(a.imag()*b.imag()) -+ // negating OperandB emits less instrucitons than negating OperandA as OperandB has less elements -+ negate negate_op; -+ -+ // Real-valued accumulator part -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow); -+ -+ mma(*accum, operand_A[m+MmaIterations::kRow], negate_op(operand_B[n+MmaIterations::kColumn]), *accum); -+ } -+ -+ // mma(accum.imag(), a.imag(), b.real(), accum.imag()) -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = MmaIterations::kColumn - 1; n >= 0; --n) { -+ -+ // Complex-valued accumulator part -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow) + MmaIterations::kCount; -+ -+ mma(*accum, operand_A[m+MmaIterations::kRow], operand_B[n], *accum); -+ } -+ } -+ } -+ -+ /// Transform the mma operands to the required types -+ CUTLASS_DEVICE -+ void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, -+ FragmentA const &A, FragmentB const &B) const { -+ // Alias types for underlying real-valued matrix multiply operator -+ using InstMmaOperandA = typename ArchMmaOperator::FragmentA; -+ using InstMmaOperandB = typename ArchMmaOperator::FragmentB; -+ -+ // -+ // Define conversions from source type to instruction operands' type -+ // -+ -+ FloatRoundStyle const kRoundA = FloatRoundStyle::round_half_ulp_trunc_dntz; -+ FloatRoundStyle const kRoundB = FloatRoundStyle::round_half_ulp_trunc_dntz; -+ -+ detail::UnpackComplexConvertAndPackForMma < -+ RealElementA, -+ InstMmaOperandA, -+ FragmentA, -+ MmaIterations, -+ MatrixShape<2, 2>, -+ kTransformA, -+ Operand::kA, -+ kRoundA> convert_A; -+ -+ detail::UnpackComplexConvertAndPackForMma < -+ RealElementB, -+ InstMmaOperandB, -+ FragmentB, -+ MmaIterations, -+ MatrixShape<2, 1>, -+ kTransformB, -+ Operand::kB, -+ kRoundB> convert_B; -+ -+ // Convert Fragment[A|B] holding complex to InstMmaOperand[A|B] holding InstMmaOperand[A|B]::Element -+ convert_A(reinterpret_cast(&dst_A), A); -+ convert_B(reinterpret_cast(&dst_B), B); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Partial specialization for complex*complex+complex => complex: -+// Operands data type: complex -+// Math instruction: mma.sync.aligned.m16n8k4.f64.f64.f64.f64 -+// Output data type: complex -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ typename Policy_, -+ /// Complex transform on A operand -+ ComplexTransform TransformA, -+ /// Complex transform on B operand -+ ComplexTransform TransformB -+> -+class MmaComplexTensorOp< -+ Shape_, -+ complex, -+ LayoutA_, -+ complex, -+ LayoutB_, -+ complex, -+ LayoutC_, -+ Policy_, -+ TransformA, -+ TransformB, -+ true> { -+public: -+ /// Shape of warp-level matrix operation (concept: GemmShape) -+ using Shape = Shape_; -+ -+ /// Data type of members of complex multiplicand A -+ using RealElementA = double; -+ -+ /// Data type of multiplicand A -+ using ElementA = complex; -+ -+ /// Layout of multiplicand A -+ using LayoutA = LayoutA_; -+ -+ /// Data type of members of complex multiplicand B -+ using RealElementB = double; -+ -+ /// Data type of multiplicand B -+ using ElementB = complex; -+ -+ /// Layout of multiplicand B -+ using LayoutB = LayoutB_; -+ -+ /// Data type of members of complex accumulator matrix C -+ using RealElementC = double; -+ -+ /// Data type of accumulator matrix C -+ using ElementC = complex; -+ -+ /// Layout of accumulator matrix C -+ using LayoutC = LayoutC_; -+ -+ /// Shape of the warp in units of thread (concept: MmaLanePolicyTensorOp) -+ using Policy = Policy_; -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ using ArchMmaOperator = typename Policy::Operator; -+ -+ /// Shape of underlying instruction -+ using InstructionShape = typename ArchMmaOperator::Shape; -+ -+ /// Underlying arch tag -+ using ArchTag = typename ArchMmaOperator::ArchTag; -+ -+ /// Indicates class of matrix operator -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Indicates math operator -+ using MathOperator = typename arch::OpMultiplyAddComplex; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = TransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = TransformB; -+ -+ /// Number of threads participating in warp-level matrix product -+ static int const kThreadCount = 32; -+ -+public: -+ -+ /// Iterates over the A operand in memory -+ using IteratorA = MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, -+ Operand::kA, -+ ElementA, -+ LayoutA, -+ MatrixShape, -+ Policy::OpDelta::kRow, -+ 32, -+ 1 -+ >; -+ -+ /// Storage for A tile -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Storage for transformed A tile -+ using TransformedFragmentA = FragmentA; -+ -+ /// Iterates over the B operand in memory -+ using IteratorB = MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, -+ Operand::kB, -+ ElementB, -+ LayoutB, -+ MatrixShape, -+ Policy::OpDelta::kColumn, -+ 32, -+ 1 -+ >; -+ -+ /// Storage for B tile -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Storage for transformed B tile -+ using TransformedFragmentB = FragmentB; -+ -+ static_assert( -+ !(Shape::kM % ArchMmaOperator::Shape::kM) && -+ !(Shape::kN % ArchMmaOperator::Shape::kN), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ -+ /// Number of mma operations performed -+ using MmaIterations = MatrixShape< -+ Shape::kM / ArchMmaOperator::Shape::kM, -+ Shape::kN / ArchMmaOperator::Shape::kN -+ >; -+ -+ /// Iterates over the C operand in memory -+ using IteratorC = MmaTensorOpAccumulatorTileIterator< -+ MatrixShape, -+ ElementC, -+ LayoutC, -+ typename ArchMmaOperator::Shape, -+ typename Policy::OpDelta>; -+ -+ /// Storage for C tile, the accumulator. Note, regardless of multiplicand type, this -+ /// storage arrangement is to be considered 'planar complex' in the sense that all real-valued -+ /// parts are stored consecutively followed by all imaginary parts. This matches the structure -+ /// of Tensor Cores which are always real-valued matrix multiplies. -+ using FragmentC = typename IteratorC::Fragment; -+ -+ static_assert( -+ FragmentC::kElements == 2 * MmaIterations::kCount * ArchMmaOperator::FragmentC::kElements, -+ "Unexpected planar complex fragment length."); -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Underlying real-valued matrix multiply operator (concept: arch::Mma) -+ ArchMmaOperator mma; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_DEVICE -+ MmaComplexTensorOp() {} -+ -+ /// Performs a warp-level matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &D, -+ FragmentA const &A, -+ FragmentB const &B, -+ FragmentC const &C -+ ) const { -+ -+ // Alias types for underlying real-valued matrix multiply operator -+ using MmaOperandA = typename ArchMmaOperator::FragmentA; -+ using MmaOperandB = typename ArchMmaOperator::FragmentB; -+ using MmaOperandC = typename ArchMmaOperator::FragmentC; -+ -+ D = C; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < MmaIterations::kRow; ++m) { -+ -+ // mma(accum.real(), a.real(), b.real(), accum.real()); -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; ++n) { -+ -+ // Pack operands together. This may result in actual MOVs -+ MmaOperandA operand_A; -+ MmaOperandB operand_B; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mk = 0; mk < MmaOperandA::kElements; ++mk) -+ operand_A[mk] = A[m*MmaOperandA::kElements + mk].real(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int nk = 0; nk < MmaOperandB::kElements; ++nk) -+ operand_B[nk] = B[n*MmaOperandB::kElements + nk].real(); -+ -+ // Real-valued accumulator part -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow); -+ -+ mma(*accum, operand_A, operand_B, *accum); -+ } -+ -+ // mma(accum.imag(), a.real(), b.imag(), accum.imag()); -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = MmaIterations::kColumn - 1; n >= 0; --n) { -+ -+ // Pack operands together. This may result in actual MOVs -+ MmaOperandA operand_A; -+ MmaOperandB operand_B; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mk = 0; mk < MmaOperandA::kElements; ++mk) -+ operand_A[mk] = A[m*MmaOperandA::kElements + mk].real(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int nk = 0; nk < MmaOperandB::kElements; ++nk) -+ operand_B[nk] = (kTransformB == ComplexTransform::kConjugate ? -+ -B[n*MmaOperandB::kElements + nk].imag() : B[n*MmaOperandB::kElements + nk].imag()); -+ -+ // Complex-valued accumulator part -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow) + MmaIterations::kCount; -+ -+ mma(*accum, operand_A, operand_B, *accum); -+ } -+ -+ // mma(accum.real(), -a.imag(), b.imag(), accum.real()) -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; ++n) { -+ -+ // Pack operands together. This may result in actual MOVs -+ MmaOperandA operand_A; -+ MmaOperandB operand_B; -+ -+ // A imaginary part is intentionally negated -+ CUTLASS_PRAGMA_UNROLL -+ for (int mk = 0; mk < MmaOperandA::kElements; ++mk) -+ operand_A[mk] = (kTransformA == ComplexTransform::kConjugate ? -+ A[m*MmaOperandA::kElements + mk].imag() : -A[m*MmaOperandA::kElements + mk].imag()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int nk = 0; nk < MmaOperandB::kElements; ++nk) -+ operand_B[nk] = (kTransformB == ComplexTransform::kConjugate ? -+ -B[n*MmaOperandB::kElements + nk].imag() : B[n*MmaOperandB::kElements + nk].imag()); -+ -+ // Real-valued accumulator part -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow); -+ -+ mma(*accum, operand_A, operand_B, *accum); -+ } -+ -+ // mma(accum.imag(), a.imag(), b.real(), accum.imag()) -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = MmaIterations::kColumn - 1; n >= 0; --n) { -+ -+ // Pack operands together. This may result in actual MOVs -+ MmaOperandA operand_A; -+ MmaOperandB operand_B; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mk = 0; mk < MmaOperandA::kElements; ++mk) -+ operand_A[mk] = (kTransformA == ComplexTransform::kConjugate ? -+ -A[m*MmaOperandA::kElements + mk].imag() : A[m*MmaOperandA::kElements + mk].imag()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int nk = 0; nk < MmaOperandB::kElements; ++nk) -+ operand_B[nk] = B[n*MmaOperandB::kElements + nk].real(); -+ -+ // Complex-valued accumulator part -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow) + MmaIterations::kCount; -+ -+ mma(*accum, operand_A, operand_B, *accum); -+ } -+ } -+ } -+ -+ /// Transform the mma operands to the required types -+ CUTLASS_DEVICE -+ void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, -+ FragmentA const &A, FragmentB const &B) const { -+ dst_A = A; -+ dst_B = B; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// TODO - partial specializations of real*complex and complex*real -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op_fast_f32.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op_fast_f32.h -new file mode 100644 -index 0000000..4db983d ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op_fast_f32.h -@@ -0,0 +1,663 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Templates implementing warp-level matrix multiply-accumulate operations targeting -+ Tensor Cores. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/complex.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/functional.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/arch/mma_sm75.h" -+#include "cutlass/arch/mma_sm80.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_policy.h" -+#include "cutlass/gemm/warp/mma_tensor_op.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" -+#include "cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+namespace detail { -+ -+template < -+ /// Data type of real & imag members of complex numbers in the SourceFragment -+ typename RealElement, -+ /// Destination fragment required by the mma operation -+ typename DestinationFragment, -+ /// Source fragment holding complex elements -+ typename SourceFragment, -+ /// Number of mma operations performed -+ typename MmaIterations, -+ /// Shape of operand elements -+ typename MmaOperandShape, -+ /// Complex transform on A operand -+ ComplexTransform Transform_, -+ /// Operand A or Operand B -+ Operand Operand_, -+ /// Floating-point rounding style for big part -+ FloatRoundStyle RoundBig_, -+ /// Floating-point rounding style for small part -+ FloatRoundStyle RoundSmall_> -+struct UnpackComplexConvertAndPackForMmaFastF32; -+ -+// Partial specialization for OperandA and Congruous smem layout -+template < -+ typename RealElement, -+ typename DestinationFragment, -+ typename SourceFragment, -+ typename MmaIterations, -+ typename MmaOperandShape, -+ ComplexTransform Transform_, -+ FloatRoundStyle RoundBig_, -+ FloatRoundStyle RoundSmall_> -+struct UnpackComplexConvertAndPackForMmaFastF32 < -+ RealElement, -+ DestinationFragment, -+ SourceFragment, -+ MmaIterations, -+ MmaOperandShape, -+ Transform_, -+ Operand::kA, -+ RoundBig_, -+ RoundSmall_> { -+ -+ // -+ // Type definitions -+ // -+ static Operand const kOperand = Operand::kA; -+ static ComplexTransform const kTransform = Transform_; -+ static FloatRoundStyle const kRoundBig = RoundBig_; -+ static FloatRoundStyle const kRoundSmall = RoundSmall_; -+ -+ // Data type of elements in the destination fragment -+ using MmaElement = typename DestinationFragment::Element; -+ -+ // Numeric convertor MmaElementBig, MmaElementSmall <= RealElement -+ using Converter = NumericConverterFastF32; -+ -+ // Operand layout parameters -+ using SourceFragmentLayout = layout::ColumnMajor; -+ static int const kLdm = MmaIterations::kRow * MmaOperandShape::kRow; -+ -+ // BigSmall Fragment holding two TF32 elements (big, small) for every float -+ using BigSmallFragment = Array; -+ -+ /// Index in fargments for the big and small part -+ static int const kBigIndex = 0; -+ static int const kSmallIndex = 1; -+ -+ /// Ctor -+ CUTLASS_DEVICE -+ UnpackComplexConvertAndPackForMmaFastF32() {} -+ -+ CUTLASS_DEVICE -+ void operator()(DestinationFragment *dest, SourceFragment const &source) { -+ -+ Converter convert_op; -+ SourceFragmentLayout layout(kLdm); -+ -+ DestinationFragment *dest_big_ = reinterpret_cast(dest); -+ DestinationFragment *dest_small_ = reinterpret_cast(&dest[MmaIterations::kRow * 2]); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int i=0; i and apply rounding on real and imag parts -+ BigSmallFragment a = convert_op(source[layout(MatrixCoord{row,col})].real()); -+ BigSmallFragment b = convert_op(source[layout(MatrixCoord{row,col})].imag()); -+ -+ // Unpack rounded complex and pack into DestinationFragment for mma operation -+ dest_big_[i][pos] = a[kBigIndex]; -+ dest_big_[i+MmaIterations::kRow][pos] = (kTransform == ComplexTransform::kConjugate ? -b[kBigIndex] : b[kBigIndex]); -+ -+ // Unpack rounded complex and pack into DestinationFragment for mma operation -+ dest_small_[i][pos] = a[kSmallIndex]; -+ dest_small_[i+MmaIterations::kRow][pos] = (kTransform == ComplexTransform::kConjugate ? -b[kSmallIndex] : b[kSmallIndex]); -+ -+ // Next position -+ pos++; -+ } -+ } -+ } -+ } -+}; -+ -+// Partial specialization for OperandB and Congruous smem layout -+template < -+ typename RealElement, -+ typename DestinationFragment, -+ typename SourceFragment, -+ typename MmaIterations, -+ typename MmaOperandShape, -+ ComplexTransform Transform_, -+ FloatRoundStyle RoundBig_, -+ FloatRoundStyle RoundSmall_> -+struct UnpackComplexConvertAndPackForMmaFastF32 < -+ RealElement, -+ DestinationFragment, -+ SourceFragment, -+ MmaIterations, -+ MmaOperandShape, -+ Transform_, -+ Operand::kB, -+ RoundBig_, -+ RoundSmall_> { -+ -+ // -+ // Type definitions -+ // -+ static Operand const kOperand = Operand::kB; -+ static ComplexTransform const kTransform = Transform_; -+ static FloatRoundStyle const kRoundBig = RoundBig_; -+ static FloatRoundStyle const kRoundSmall = RoundSmall_; -+ -+ // Data type of elements in the destination fragment -+ using MmaElement = typename DestinationFragment::Element; -+ -+ // Numeric convertor MmaElementBig, MmaElementSmall <= RealElement -+ using Converter = NumericConverterFastF32; -+ -+ // Operand layout parameters -+ using SourceFragmentLayout = layout::RowMajor; -+ static int const kLdm = MmaIterations::kColumn * MmaOperandShape::kColumn; -+ -+ // BigSmall Fragment holding two TF32 elements (big, small) for every float -+ using BigSmallFragment = Array; -+ -+ /// Index in fargments for the big and small part -+ static int const kBigIndex = 0; -+ static int const kSmallIndex = 1; -+ -+ /// Ctor -+ CUTLASS_DEVICE -+ UnpackComplexConvertAndPackForMmaFastF32() {} -+ -+ CUTLASS_HOST_DEVICE -+ void operator()(DestinationFragment *dest, SourceFragment const &source) { -+ -+ Converter convert_op; -+ SourceFragmentLayout layout(kLdm); -+ -+ DestinationFragment *dest_big_ = reinterpret_cast(dest); -+ DestinationFragment *dest_small_ = reinterpret_cast(&dest[MmaIterations::kColumn * 2]); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int i=0; i apply rounding on real and imag parts -+ BigSmallFragment a = convert_op(source[layout(MatrixCoord{row,col})].real()); -+ BigSmallFragment b = convert_op(source[layout(MatrixCoord{row,col})].imag()); -+ -+ // Unpack rounded complex and pack into DestinationFragment for mma operation -+ dest_big_[i][pos] = a[kBigIndex]; -+ dest_big_[i+MmaIterations::kColumn][pos] = (kTransform == ComplexTransform::kConjugate ? -b[kBigIndex] : b[kBigIndex]); -+ -+ // Unpack rounded complex and pack into DestinationFragment for mma operation -+ dest_small_[i][pos] = a[kSmallIndex]; -+ dest_small_[i+MmaIterations::kColumn][pos] = (kTransform == ComplexTransform::kConjugate ? -b[kSmallIndex] : b[kSmallIndex]); -+ -+ // next position -+ pos++; -+ } -+ } -+ } -+ } -+}; -+} // namespace detail -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Data type of A elements -+ typename RealElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename RealElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename RealElementC, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ typename Policy_, -+ /// Complex transform on A operand -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Complex transform on B operand -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+class MmaComplexTensorOpFastF32; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for complex*complex+complex => complex: -+// Operands data type: complex -+// Rounding: float -> tfloat32_t (round half_ulp_truncate nearest) -+// Math instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 -+// Output data type: complex -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ typename Policy_, -+ /// Complex transform on A operand -+ ComplexTransform TransformA, -+ /// Complex transform on B operand -+ ComplexTransform TransformB, -+ /// Used for partial specialization -+ typename Enable -+> -+class MmaComplexTensorOpFastF32< -+ Shape_, -+ complex, -+ LayoutA_, -+ complex, -+ LayoutB_, -+ complex, -+ LayoutC_, -+ Policy_, -+ TransformA, -+ TransformB, -+ Enable> { -+public: -+ /// Shape of warp-level matrix operation (concept: GemmShape) -+ using Shape = Shape_; -+ -+ /// Data type of members of complex multiplicand A -+ using RealElementA = float; -+ -+ /// Data type of multiplicand A -+ using ElementA = complex; -+ -+ /// Layout of multiplicand A -+ using LayoutA = LayoutA_; -+ -+ /// Data type of members of complex multiplicand B -+ using RealElementB = float; -+ -+ /// Data type of multiplicand B -+ using ElementB = complex; -+ -+ /// Layout of multiplicand B -+ using LayoutB = LayoutB_; -+ -+ /// Data type of members of complex accumulator matrix C -+ using RealElementC = float; -+ -+ /// Data type of accumulator matrix C -+ using ElementC = complex; -+ -+ /// Layout of accumulator matrix C -+ using LayoutC = LayoutC_; -+ -+ /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) -+ using Policy = Policy_; -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ using ArchMmaOperator = typename Policy::Operator; -+ -+ /// Shape of underlying instruction -+ using InstructionShape = typename ArchMmaOperator::Shape; -+ -+ /// Underlying arch tag -+ using ArchTag = typename ArchMmaOperator::ArchTag; -+ -+ /// Indicates class of matrix operator -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Indicates math operator -+ using MathOperator = arch::OpMultiplyAddComplexFastF32; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = TransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = TransformB; -+ -+ /// Number of threads participating in warp-level matrix product -+ static int const kThreadCount = 32; -+ -+ -+ /// Tune F32 to TF32 big small conversion for complex operation -+ /// Different combination of big small conversin can cause different tradeoff -+ /// between speed and accuracy. Generally, use round_half_ulp_truncate can -+ /// improve the performance but hur the accuracy. -+ using ComplexFastF32 = FastF32 < -+ FloatRoundStyle::round_toward_zero, // kRoundBigA -+ FloatRoundStyle::round_half_ulp_truncate, // kRoundSmallA -+ FloatRoundStyle::round_toward_zero, // kRoundBigB -+ FloatRoundStyle::round_half_ulp_truncate, // kRoundSmallB -+ TensorFloat32Op::k3xTF32 // Number of TF32 operations -+ >; -+ -+ /// Index in fargments for the big and small part -+ static int const kBigIndex = 0; -+ static int const kSmallIndex = 1; -+ -+public: -+ -+ /// Iterates over the A operand in memory -+ using IteratorA = MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, -+ Operand::kA, -+ ElementA, -+ LayoutA, -+ MatrixShape, -+ Policy::OpDelta::kRow, -+ 32, -+ 1 -+ >; -+ -+ /// Storage for A tile -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Storage for transformed A tile -+ // (4 times the original FragmentA::kElements) -+ // (real_big), (imag_big), (real_small), (imag_small) -+ using TransformedFragmentA = Array; -+ -+ // Fragment bisecting big and small sections -+ // (real_big, imag_big), (real_small, imag_small) -+ using AccessTypeFragmentA = Array; -+ -+ /// Iterates over the B operand in memory -+ using IteratorB = MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, -+ Operand::kB, -+ ElementB, -+ LayoutB, -+ MatrixShape, -+ Policy::OpDelta::kColumn, -+ 32, -+ 1 -+ >; -+ -+ /// Storage for B tile -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Storage for transformed B tile -+ // (4 times the original FragmentB::kElements) -+ // (real_big), (imag_big), (real_small), (imag_small) -+ using TransformedFragmentB = Array; -+ -+ // Fragment bisecting big and small sections -+ // (real_big, imag_big), (real_small, imag_small) -+ using AccessTypeFragmentB = Array; -+ -+ static_assert( -+ !(Shape::kM % ArchMmaOperator::Shape::kM) && -+ !(Shape::kN % ArchMmaOperator::Shape::kN), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ -+ /// Number of complex products operations performed (one complex product needs four mma instructions) -+ using MmaIterations = MatrixShape< -+ Shape::kM / ArchMmaOperator::Shape::kM, -+ Shape::kN / ArchMmaOperator::Shape::kN -+ >; -+ -+ /// Iterates over the C operand in memory -+ using IteratorC = MmaTensorOpAccumulatorTileIterator< -+ MatrixShape, -+ ElementC, -+ LayoutC, -+ typename ArchMmaOperator::Shape, -+ typename Policy::OpDelta>; -+ -+ /// Storage for C tile, the accumulator. Note, regardless of multiplicand type, this -+ /// storage arrangement is to be considered 'planar complex' in the sense that all real-valued -+ /// parts are stored consecutively followed by all imaginary parts. This matches the structure -+ /// of Tensor Cores which are always real-valued matrix multiplies. -+ using FragmentC = typename IteratorC::Fragment; -+ -+ // -+ // Alias types for underlying real-valued matrix multiply operator -+ // -+ using InstMmaOperandA = typename ArchMmaOperator::FragmentA; -+ using InstMmaOperandB = typename ArchMmaOperator::FragmentB; -+ using MmaOperandC = typename ArchMmaOperator::FragmentC; -+ -+ static_assert(platform::is_same, typename ArchMmaOperator::Shape>::value, -+ "This implementation only supports mma.m16n8k8 math instructions."); -+ -+ static_assert(InstMmaOperandA::kElements == 4, -+ "This implementation only supports math instructions in which exactly four element is needed for the A operand." -+ "We can geneneralize later."); -+ -+ static_assert(InstMmaOperandB::kElements == 2, -+ "This implementation only supports math instructions in which exactly two element is needed for the B operand." -+ "We can geneneralize later."); -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Underlying real-valued matrix multiply operator (concept: arch::Mma) -+ ArchMmaOperator mma; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_DEVICE -+ MmaComplexTensorOpFastF32() {} -+ -+ /// Performs a warp-level matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &D, -+ TransformedFragmentA const &A, -+ TransformedFragmentB const &B, -+ FragmentC const &C -+ ) const { -+ -+ AccessTypeFragmentA const *complex_A = reinterpret_cast(&A); -+ AccessTypeFragmentB const *complex_B = reinterpret_cast(&B); -+ -+ // -+ // Accumulate in place -+ // -+ D = C; -+ -+ -+ complex_mma_operator(D, complex_A[kSmallIndex], complex_B[kBigIndex], D); -+ -+ complex_mma_operator(D, complex_A[kBigIndex], complex_B[kSmallIndex], D); -+ -+ complex_mma_operator(D, complex_A[kBigIndex], complex_B[kBigIndex], D); -+ -+ if (ComplexFastF32::kPrecision == TensorFloat32Op::k4xTF32) -+ complex_mma_operator(D, complex_A[kSmallIndex], complex_B[kSmallIndex], D); -+ } -+ -+ /// Performs a warp-level matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void complex_mma_operator( -+ FragmentC &D, -+ AccessTypeFragmentA const &complex_A, -+ AccessTypeFragmentB const &complex_B, -+ FragmentC const &C -+ ) const { -+ -+ // Instruction Operands A & B holding real part followed by imaginary part for mma operations -+ InstMmaOperandA const *operand_A = reinterpret_cast(&complex_A); -+ InstMmaOperandB const *operand_B = reinterpret_cast(&complex_B); -+ -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < MmaIterations::kRow; ++m) { -+ -+ // mma(accum.real(), a.real(), b.real(), accum.real()); -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; ++n) { -+ -+ // Real-valued accumulator part -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow); -+ -+ mma(*accum, operand_A[m], operand_B[n], *accum); -+ } -+ -+ // mma(accum.imag(), a.real(), b.imag(), accum.imag()); -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = MmaIterations::kColumn - 1; n >= 0; --n) { -+ -+ // Complex-valued accumulator part -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow) + MmaIterations::kCount; -+ -+ mma(*accum, operand_A[m], operand_B[n+MmaIterations::kColumn], *accum); -+ } -+ -+ // mma(accum.real(), a.imag(), -b.imag(), accum.real()) -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; ++n) { -+ -+ // negate OperandB to accumulate -(a.imag()*b.imag()) -+ // negating OperandB emits less instrucitons than negating OperandA as OperandB has less elements -+ negate negate_op; -+ -+ // Real-valued accumulator part -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow); -+ -+ mma(*accum, operand_A[m+MmaIterations::kRow], negate_op(operand_B[n+MmaIterations::kColumn]), *accum); -+ } -+ -+ // mma(accum.imag(), a.imag(), b.real(), accum.imag()) -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = MmaIterations::kColumn - 1; n >= 0; --n) { -+ -+ // Complex-valued accumulator part -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow) + MmaIterations::kCount; -+ -+ mma(*accum, operand_A[m+MmaIterations::kRow], operand_B[n], *accum); -+ } -+ } -+ } -+ -+ /// Transform the mma operands to the required types -+ CUTLASS_DEVICE -+ void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, -+ FragmentA const &A, FragmentB const &B) const { -+ -+ detail::UnpackComplexConvertAndPackForMmaFastF32 < -+ RealElementA, -+ InstMmaOperandA, -+ FragmentA, -+ MmaIterations, -+ MatrixShape<2, 2>, -+ kTransformA, -+ Operand::kA, -+ ComplexFastF32::kRoundBigA, -+ ComplexFastF32::kRoundSmallA> convert_A; -+ -+ detail::UnpackComplexConvertAndPackForMmaFastF32 < -+ RealElementB, -+ InstMmaOperandB, -+ FragmentB, -+ MmaIterations, -+ MatrixShape<2, 1>, -+ kTransformB, -+ Operand::kB, -+ ComplexFastF32::kRoundBigB, -+ ComplexFastF32::kRoundSmallB> convert_B; -+ -+ // Convert Fragment[A|B] holding complex to InstMmaOperand[A|B] holding InstMmaOperand[A|B]::Element -+ convert_A(reinterpret_cast(&dst_A), A); -+ convert_B(reinterpret_cast(&dst_B), B); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h -new file mode 100644 -index 0000000..d872012 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h -@@ -0,0 +1,2493 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm80.h" -+ -+#include "cutlass/platform/platform.h" -+#include "cutlass/fast_math.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for loading 128b vectors of 128b elements. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: PitchLinearShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: PitchLinearShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::TensorOpMultiplicandCongruous128b, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ static_assert(!(Shape::kContiguous % 8) && !(Shape::kStrided % 4), "Divisibility."); -+ -+ static_assert(sizeof_bits::value == 128, "This is specialized for 128b accesses."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::TensorOpMultiplicandCongruous128b; -+ -+ /// Shape of one matrix product operation (concept: GemmShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// Number of partitions along K dimension -+ static int const kPartitionsK = PartitionsK_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Long Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Load two elements per access -+ static int const kElementsPerAccess = 1; -+ -+ /// Policy defining internal details of tile iterator -+ struct Policy { -+ -+ /// Shape of one access -+ using Delta = layout::PitchLinearShape<8, 4>; -+ -+ /// Number of iterations to load -+ using Iterations = layout::PitchLinearShape< -+ Shape::kContiguous / Delta::kContiguous, -+ InstructionShape::kStrided / Delta::kStrided -+ >; -+ }; -+ -+private: -+ -+ /// Not working on this feature at the moment. -+ static_assert(kOpDelta == 1, -+ "Alternative arrangements not supported at present."); -+ -+ /// Pointer type used for accesses -+ using AccessType = AlignedArray; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = -+ Array; -+ -+private: -+ -+ /// Layout object storing stride values -+ StrideIndex stride_; -+ -+ /// Shared memory base pointers - not advanced -+ AccessType const *pointer_; -+ -+ /// Byte offset incremented as iterator advances -+ Index byte_offset_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator(): stride_(0), byte_offset_(0) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ stride_(ref.stride(0) / kElementsPerAccess), byte_offset_(0) { -+ -+ int quad_pair = lane_id / 8; -+ int quad = lane_id / 4; -+ int lane = lane_id % 4; -+ -+ int row = (quad & 1) * 4 + (lane ^ quad_pair); -+ -+ byte_offset_ = (row + quad_pair * stride_) * sizeof(AccessType); -+ -+ pointer_= reinterpret_cast(ref.data()); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ pointer_ += offset; -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ int offset = -+ (tile_offset.contiguous() * Shape::kContiguous) + -+ (tile_offset.strided() * InstructionShape::kStrided * stride_); -+ -+ add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator++() { -+ -+ pointer_ += stride_ * InstructionShape::kStrided; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ load_with_byte_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset in units of bytes -+ Index byte_offset) const { -+ -+ AccessType *fetch_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < Policy::Iterations::kStrided; ++s) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < Policy::Iterations::kContiguous; ++c) { -+ -+ int access_idx = c + s * Policy::Iterations::kContiguous; -+ -+ AccessType const *source_ptr = pointer_ + -+ Policy::Delta::kContiguous * c + -+ Policy::Delta::kStrided * s * stride_; -+ -+ char const *source_byte_ptr = reinterpret_cast(source_ptr) + byte_offset + byte_offset_; -+ -+ AccessType const *source = reinterpret_cast(source_byte_ptr); -+ -+ fetch_ptr[access_idx] = *source; -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ -+ load_with_byte_offset(frag, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ -+ load_with_byte_offset(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ -+ load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ Index pointer_offset = -+ tile_offset.contiguous() * Shape::kContiguous + -+ tile_offset.strided() * InstructionShape::kStrided * stride_; -+ -+ byte_offset += sizeof(AccessType) * pointer_offset; -+ -+ load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Long Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Underlying tile iterator implementation -+ using Base = MmaTensorOpMultiplicandTileIterator< -+ layout::PitchLinearShape, kOperand, Element, -+ layout::TensorOpMultiplicandCongruous128b, -+ layout::PitchLinearShape, -+ kOpDelta, kThreads, PartitionsK_>; -+ -+ public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+private: -+ -+ /// Underlying tile iterator -+ Base iterator_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): iterator_({ref.data(), ref.stride()}, lane_id) { -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator++() { -+ -+ ++iterator_; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator--() { -+ -+ --iterator_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(layout::PitchLinearCoord(tile_offset.column(), tile_offset.row())); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(layout::PitchLinearCoord(-tile_offset.column(), -tile_offset.row())); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ iterator_.load(frag); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset( -+ frag, -+ {tile_offset.strided(), tile_offset.contiguous()}, -+ byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Long Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Underlying tile iterator implementation -+ using Base = MmaTensorOpMultiplicandTileIterator< -+ layout::PitchLinearShape, kOperand, Element, -+ layout::TensorOpMultiplicandCongruous128b, -+ layout::PitchLinearShape, -+ kOpDelta, kThreads, PartitionsK_>; -+ -+ public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+private: -+ -+ /// Underlying tile iterator -+ Base iterator_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): iterator_({ref.data(), ref.stride()}, lane_id) { -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator++() { -+ -+ ++iterator_; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator--() { -+ -+ --iterator_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(layout::PitchLinearCoord(tile_offset.row(), tile_offset.column())); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(layout::PitchLinearCoord(-tile_offset.row(), -tile_offset.column())); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ iterator_.load(frag); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset( -+ frag, -+ {tile_offset.contiguous(), tile_offset.strided()}, -+ byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// -+/// Partial specialization for complex -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of underlying field of reals. -+ typename RealElement, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions, concept: MatrixShape) -+ typename OpDelta_> -+class MmaTensorOpAccumulatorTileIterator< -+ Shape_, complex, cutlass::layout::RowMajor, InstructionShape_, OpDelta_> { -+ public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kC; -+ -+ /// Element type -+ using Element = complex; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::RowMajor; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ using OpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Long Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ struct Policy { -+ static_assert( -+ !(Shape::kRow % InstructionShape::kM) && -+ !(Shape::kColumn % InstructionShape::kN), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ -+ static_assert(platform::is_same::value, -+ "Layouts must be defined for logical MatrixCoord coordinate space."); -+ -+ /// Number of mma operations performed -+ using MmaIterations = MatrixShape; -+ }; -+ -+private: -+ -+ // Assume accumulator tile is an arrangement of 8-by-8 tiles replicated over the entire -+ // shape, with each quad mapped to one row and each thread mapped to 1/4 of the elements -+ // of that row. The accumulators within one row are assumed to be consecutive. -+ static int const kElementsPerAccess = InstructionShape::kN / 4; -+ static int const kRowsPerTile = 8; -+ static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile. It is assumed that the accumulators -+ /// are stored in a planar complex arrangement with the real parts as entirely contiguous -+ /// followed by the imaginary parts. -+ using Fragment = Array; -+ -+ static int const kRealIndex = 0; -+ static int const kImaginaryIndex = Shape::kCount / kThreads; -+ -+private: -+ -+ /// Reference to output tensor -+ TensorRef ref_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ ref_(ref) { -+ -+ int quad = (lane_id >> 2); -+ int lane_in_quad = (lane_id & 3); -+ -+ MatrixCoord lane_offset(quad, lane_in_quad * kElementsPerAccess); -+ -+ ref_.add_coord_offset(lane_offset); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator &add_pointer_offset(LongIndex offset) { -+ ref_.add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ ref_.add_coord_offset(tile_offset * make_Coord(Shape::kRow, Shape::kColumn)); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator++() { -+ // deliberate no-op -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator--() { -+ // deliberate no-op -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ Fragment &frag, ///< fragment to load from the tensor -+ Index pointer_offset) const { ///< loads a tile with a linear offset -+ -+ TensorRef offset_ref(ref_); -+ offset_ref.add_pointer_offset(pointer_offset); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { -+ -+ int mma_accum_start = kAccumulatorRows * kElementsPerAccess * -+ (mma_n * Policy::MmaIterations::kRow + mma_m); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < kAccumulatorRows; ++row) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int col = 0; col < kElementsPerAccess; ++col) { -+ int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + -+ row * kRowsPerTile; -+ int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col; -+ -+ Element z = offset_ref.at({accum_m, accum_n}); -+ -+ frag[mma_accum_start + row * kElementsPerAccess + col + kRealIndex] = z.real(); -+ frag[mma_accum_start + row * kElementsPerAccess + col + kImaginaryIndex] = z.imag(); -+ } -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ Fragment &frag, ///< fragment to load from the tensor -+ Index byte_offset) const { ///< loads a tile with a linear offset -+ -+ load_with_pointer_offset(byte_offset / sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ Fragment &frag, ///< fragment to load from the tensor -+ TensorCoord const &tile_offset) const { ///< loads a tile with a logical offset in units of whole tiles -+ -+ load(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ Fragment &frag, ///< fragment to load from the tensor -+ TensorCoord const &tile_offset, ///< loads a tile with a logical offset in units of whole tiles -+ Index pointer_offset) const { ///< loads a tile with a logical offset AND a pointer offset -+ -+ load_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) const { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory with additional pointer offset -+ CUTLASS_DEVICE -+ void store_with_pointer_offset( -+ Fragment const &frag, ///< fragment to store from the tensor -+ Index pointer_offset) const { ///< store a tile with a linear offset -+ -+ TensorRef offset_ref(ref_); -+ offset_ref.add_pointer_offset(pointer_offset); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { -+ -+ int mma_accum_start = kAccumulatorRows * kElementsPerAccess * -+ (mma_n * Policy::MmaIterations::kRow + mma_m); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < kAccumulatorRows; ++row) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int col = 0; col < kElementsPerAccess; ++col) { -+ int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + -+ row * kRowsPerTile; -+ int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col; -+ int idx = mma_accum_start + row * kElementsPerAccess + col; -+ -+ Element z(frag[kRealIndex + idx], frag[kImaginaryIndex + idx]); -+ -+ offset_ref.at({accum_m, accum_n}) = z; -+ } -+ } -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory with additional pointer offset -+ CUTLASS_DEVICE -+ void store_with_byte_offset( -+ Fragment const &frag, ///< fragment to store from the tensor -+ Index byte_offset) const { ///< store a tile with a linear offset -+ -+ store_with_pointer_offset(byte_offset / sizeof(Element)); -+ } -+ -+ /// Stores a fragment to memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void store( -+ Fragment &frag, ///< fragment to store to the tensor -+ TensorCoord const &tile_offset) const { ///< stores a tile with a logical offset in units of whole tiles -+ -+ store(frag, tile_offset, 0); -+ } -+ -+ /// Stores a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void store( -+ /// fragment to store to the tensor -+ Fragment const &frag, -+ /// stores a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// stores a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ store_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for loading 128b vectors of 128b elements. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: PitchLinearShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: PitchLinearShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::TensorOpMultiplicandCrosswise128x4, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ static_assert(!(Shape::kContiguous % 4) && !(Shape::kStrided % 8), "Divisibility."); -+ -+ static_assert(sizeof_bits::value == 128, "This is specialized for 128b accesses."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::TensorOpMultiplicandCrosswise128x4; -+ -+ /// Shape of one matrix product operation (concept: GemmShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// Number of partitions along K dimension -+ static int const kPartitionsK = PartitionsK_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Long Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Load two elements per access -+ static int const kElementsPerAccess = 1; -+ -+ /// Policy defining internal details of tile iterator -+ struct Policy { -+ -+ /// Shape of one access -+ using Delta = layout::PitchLinearShape<4, 8>; -+ -+ /// Number of iterations to load -+ using Iterations = layout::PitchLinearShape< -+ InstructionShape::kContiguous / Delta::kContiguous, -+ Shape::kStrided / Delta::kStrided -+ >; -+ }; -+ -+private: -+ -+ /// Not working on this feature at the moment. -+ static_assert(kOpDelta == 1, -+ "Alternative arrangements not supported at present."); -+ -+ /// Pointer type used for accesses -+ using AccessType = AlignedArray; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = -+ Array; -+ -+private: -+ -+ /// Layout object storing stride values -+ StrideIndex stride_; -+ -+ /// Shared memory base pointers - not advanced -+ AccessType const *pointer_; -+ -+ /// Byte offset incremented as iterator advances -+ Index byte_offset_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator(): stride_(0), byte_offset_(0) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ stride_(ref.stride(0) / kElementsPerAccess), byte_offset_(0) { -+ -+ int quad = lane_id / 4; -+ int liq = lane_id % 4; -+ -+ int c = liq + (quad & 1) * 4; -+ int s = (quad / 2); -+ -+ byte_offset_ = (c + s * stride_) * sizeof(AccessType); -+ -+ pointer_= reinterpret_cast(ref.data()); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ pointer_ += offset; -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ // Compute the offset in units of elements. Note, the external coordinate system is -+ // approximately transposed with respect to the tiled internal structure -+ int offset = -+ (tile_offset.contiguous() * InstructionShape::kContiguous) * stride_ + -+ (tile_offset.strided() * Shape::kStrided); -+ -+ add_pointer_offset(offset); -+ -+ byte_offset_ ^= (tile_offset.contiguous() & 1) * 4 * sizeof(AccessType); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator++() { -+ -+ pointer_ += stride_ * InstructionShape::kContiguous; -+ -+ byte_offset_ ^= 4 * sizeof(AccessType); -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ load_with_byte_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset in units of bytes -+ Index byte_offset) const { -+ -+ AccessType *fetch_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < Policy::Iterations::kContiguous; ++c) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < Policy::Iterations::kStrided; ++s) { -+ -+ int access_idx = s + c * Policy::Iterations::kStrided; -+ -+ AccessType const *source_ptr = pointer_ + -+ Policy::Delta::kContiguous * c * stride_ + -+ Policy::Delta::kStrided * s; -+ -+ char const *source_byte_ptr = reinterpret_cast(source_ptr) + byte_offset + byte_offset_; -+ -+ AccessType const *source = reinterpret_cast(source_byte_ptr); -+ -+ fetch_ptr[access_idx] = *source; -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ -+ load_with_byte_offset(frag, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ -+ load_with_byte_offset(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ -+ load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ Index pointer_offset = -+ tile_offset.contiguous() * InstructionShape::kContiguous * stride_ + -+ tile_offset.strided() * Shape::kStrided; -+ -+ byte_offset += sizeof(AccessType) * pointer_offset; -+ -+ load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ -+ } -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::RowMajorTensorOpMultiplicandCrosswise128x4, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise128x4; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Long Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Underlying tile iterator implementation -+ using Base = MmaTensorOpMultiplicandTileIterator< -+ layout::PitchLinearShape, kOperand, Element, -+ layout::TensorOpMultiplicandCrosswise128x4, -+ layout::PitchLinearShape, -+ kOpDelta, kThreads, PartitionsK_>; -+ -+ public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+private: -+ -+ /// Underlying tile iterator -+ Base iterator_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): iterator_({ref.data(), ref.stride()}, lane_id) { -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator++() { -+ -+ ++iterator_; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator--() { -+ -+ --iterator_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(layout::PitchLinearCoord(tile_offset.column(), tile_offset.row())); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(layout::PitchLinearCoord(-tile_offset.column(), -tile_offset.row())); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ iterator_.load(frag); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset( -+ frag, -+ {tile_offset.strided(), tile_offset.contiguous()}, -+ byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise128x4, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Long Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Underlying tile iterator implementation -+ using Base = MmaTensorOpMultiplicandTileIterator< -+ layout::PitchLinearShape, kOperand, Element, -+ layout::TensorOpMultiplicandCrosswise128x4, -+ layout::PitchLinearShape, -+ kOpDelta, kThreads, PartitionsK_>; -+ -+ public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+private: -+ -+ /// Underlying tile iterator -+ Base iterator_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): iterator_({ref.data(), ref.stride()}, lane_id) { -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator++() { -+ -+ ++iterator_; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator--() { -+ -+ --iterator_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(layout::PitchLinearCoord(tile_offset.row(), tile_offset.column())); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(layout::PitchLinearCoord(-tile_offset.row(), -tile_offset.column())); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ iterator_.load(frag); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset( -+ frag, -+ {tile_offset.contiguous(), tile_offset.strided()}, -+ byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// Congruous shared memory layout -+// Warp-level iterators for complex*complex + complex => complex -+// The underlying iterators are similar to that for MMA f64*f64 + f64 = f64 -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for loading 128b vectors of 64b elements. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: PitchLinearShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Shape of one matrix product operation (concept: PitchLinearShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, cutlass::complex, -+ cutlass::layout::TensorOpMultiplicandCongruous64b, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ static_assert(!(Shape::kContiguous % 16) && !(Shape::kStrided % 8), "Divisibility."); -+ -+ /// Element type -+ using Element = cutlass::complex; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::TensorOpMultiplicandCongruous64b; -+ -+ /// Shape of one matrix product operation (concept: GemmShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// Number of partitions along K dimension -+ static int const kPartitionsK = PartitionsK_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Long Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Load two elements per access -+ static int const kElementsPerAccess = 2; -+ -+ /// Policy defining internal details of tile iterator -+ struct Policy { -+ -+ /// Shape of one access -+ using Delta = layout::PitchLinearShape<8, 4>; -+ -+ /// Number of iterations to load -+ using Iterations = layout::PitchLinearShape< -+ Shape::kContiguous / kElementsPerAccess / Delta::kContiguous, -+ InstructionShape::kStrided / Delta::kStrided -+ >; -+ -+ }; -+ -+private: -+ -+ /// Not working on this feature at the moment. -+ static_assert(kOpDelta == 1, -+ "Alternative arrangements not supported at present."); -+ -+ /// Pointer type used for accesses -+ using AccessType = AlignedArray; -+ -+ /// Internal counter used to jump to next K partition -+ int k_group_idx_; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = -+ Array; -+ -+private: -+ -+ /// Layout object storing stride values -+ StrideIndex stride_; -+ -+ /// Shared memory base pointers - not advanced -+ AccessType const *pointer_; -+ -+ /// Byte offset incremented as iterator advances -+ Index byte_offset_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator(): stride_(0), byte_offset_(0) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ stride_(ref.stride(0) / kElementsPerAccess), byte_offset_(0), -+ k_group_idx_(0) { -+ -+ int access_strided = lane_id / Policy::Delta::kContiguous; -+ int access_contiguous = (lane_id % Policy::Delta::kContiguous) ^ access_strided; -+ -+ pointer_= reinterpret_cast(ref.data()) + -+ access_contiguous + access_strided * stride_; -+ -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ byte_offset_ += offset * sizeof(Element); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ int offset = -+ (tile_offset.strided() * InstructionShape::kStrided) * stride_ * kElementsPerAccess + -+ tile_offset.contiguous() * Shape::kContiguous; -+ -+ add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator++() { -+ -+ add_tile_offset({0, 1}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the opposite of the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator--() { -+ -+ add_tile_offset({0, -1}); -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ load_with_byte_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset in units of bytes -+ Index byte_offset) const { -+ -+ AccessType *fetch_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < Policy::Iterations::kStrided; ++s) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < Policy::Iterations::kContiguous; ++c) { -+ -+ int access_idx = c + s * Policy::Iterations::kContiguous; -+ -+ AccessType const *source_ptr = pointer_ + -+ Policy::Delta::kContiguous * c + -+ Policy::Delta::kStrided * s * stride_; -+ -+ char const *source_byte_ptr = reinterpret_cast(source_ptr) + byte_offset + byte_offset_; -+ -+ AccessType const *source = reinterpret_cast(source_byte_ptr); -+ -+ fetch_ptr[access_idx] = *source; -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ -+ load_with_byte_offset(frag, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ -+ load_with_byte_offset(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ -+ load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ -+ Index pointer_offset = -+ tile_offset.contiguous() * Shape::kContiguous / Layout::kElementsPerAccess + -+ tile_offset.strided() * InstructionShape::kStrided * stride_; -+ -+ byte_offset += sizeof(AccessType) * pointer_offset; -+ -+ load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// Crosswise shared memory layout -+// Warp-level iterators for complex*complex + complex => complex -+// The underlying iterators are similar to that for f64*f64 + f64 = f64 -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for loading 128b vectors of 64b elements. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: PitchLinearShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Shape of one matrix product operation (concept: PitchLinearShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, complex, -+ cutlass::layout::TensorOpMultiplicand64bCrosswise, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ static_assert(!(Shape::kContiguous % 4) && !(Shape::kStrided % 16), "Divisibility."); -+ -+ static_assert(sizeof_bits>::value == 64, "This is specialized for 64b accesses."); -+ -+ /// Element type -+ using Element = complex; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::TensorOpMultiplicand64bCrosswise; -+ -+ /// Shape of one matrix product operation (concept: GemmShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// Number of partitions along K dimension -+ static int const kPartitionsK = PartitionsK_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Long Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Load two elements per access -+ static int const kElementsPerAccess = 2; -+ -+ /// Policy defining internal details of tile iterator -+ struct Policy { -+ -+ /// Shape of one access -+ using Delta = layout::PitchLinearShape<4, 16>; -+ -+ /// Number of iterations to load -+ using Iterations = layout::PitchLinearShape< -+ InstructionShape::kContiguous / Delta::kContiguous, -+ Shape::kStrided / Delta::kStrided -+ >; -+ -+ }; -+ -+private: -+ -+ /// Not working on this feature at the moment. -+ static_assert(kOpDelta == 1, -+ "Alternative arrangements not supported at present."); -+ -+ /// Pointer type used for accesses -+ using AccessType = AlignedArray; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = -+ Array; -+ -+private: -+ -+ /// Layout object storing stride values -+ StrideIndex stride_; -+ -+ /// Shared memory base pointers - not advanced -+ AccessType const *pointer_; -+ -+ /// Byte offset incremented as iterator advances -+ Index byte_offset_; -+ -+ /// Internal counter for tracking K-group -+ Index k_group_idx_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator(): stride_(0), byte_offset_(0) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ stride_(ref.stride(0) / kElementsPerAccess), byte_offset_(0), -+ k_group_idx_(0) { -+ -+ int access_strided = lane_id / 8; -+ int access_contiguous = (lane_id % 8); -+ -+ byte_offset_ = (access_contiguous + access_strided * stride_) * sizeof(AccessType); -+ -+ pointer_= reinterpret_cast(ref.data()); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ pointer_ += offset / kElementsPerAccess; -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ int offset = (tile_offset.contiguous() * InstructionShape::kContiguous) * -+ stride_ * kElementsPerAccess + -+ tile_offset.strided() * Shape::kStrided; -+ -+ add_pointer_offset(offset); -+ -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset_negative(TensorCoord const &tile_offset) { -+ -+ add_tile_offset(tile_offset); -+ -+ if (k_group_idx_ & 1) -+ byte_offset_ ^= 0x40; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator++() { -+ -+ pointer_ += stride_ * InstructionShape::kContiguous; -+ -+ // xor ptr -+ byte_offset_ ^= 0x40; -+ -+ ++k_group_idx_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ load_with_byte_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset in units of bytes -+ Index byte_offset) const { -+ -+ AccessType *fetch_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < Policy::Iterations::kContiguous; ++c) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < Policy::Iterations::kStrided; ++s) { -+ -+ int access_idx = c * Policy::Iterations::kStrided + s; -+ -+ AccessType const *source_ptr = pointer_ + -+ Policy::Delta::kContiguous * c * stride_ + -+ Policy::Delta::kStrided * s / kElementsPerAccess; -+ -+ char const *source_byte_ptr = reinterpret_cast(source_ptr) + byte_offset + byte_offset_; -+ -+ AccessType const *source = reinterpret_cast(source_byte_ptr); -+ -+ fetch_ptr[access_idx] = *source; -+ } -+ } -+ -+ Element *exchange_ptr = reinterpret_cast(&frag); -+ -+ // exchange on 64b granularity only for fragments held in k=8/2 to k=8 -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = Fragment::kElements/2; i < Fragment::kElements; i += 2) { -+ Element tmp = exchange_ptr[i]; -+ exchange_ptr[i] = exchange_ptr[i + 1]; -+ exchange_ptr[i + 1] = tmp; -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ -+ load_with_byte_offset(frag, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ -+ load_with_byte_offset(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ -+ load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ Index pointer_offset = tile_offset.contiguous() * -+ InstructionShape::kContiguous / -+ Layout::kElementsPerAccess + -+ tile_offset.strided() * Shape::kStrided * stride_; -+ -+ byte_offset += sizeof(AccessType) * pointer_offset; -+ -+ load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ k_group_idx_ = k_group; -+ } -+}; -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op.h -new file mode 100644 -index 0000000..00760a6 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op.h -@@ -0,0 +1,643 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing warp-level matrix multiply-accumulate operations targeting -+ Tensor Cores. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/complex.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/arch/mma_sm75.h" -+#include "cutlass/arch/mma_sm80.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_policy.h" -+#include "cutlass/gemm/warp/mma_tensor_op.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -+#include "cutlass/gemm/warp/mma_gaussian_complex_tensor_op_tile_iterator_sm80.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Data type of A elements -+ typename RealElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename RealElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename RealElementC, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ typename Policy_, -+ /// Complex transform on A operand -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Complex transform on B operand -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ /// Do source operands need more than one elements -+ bool GeneralizedOperatorElements = false, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+class MmaGaussianComplexTensorOp; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for complex*complex+complex => complex using real-valued TensorOps -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Data type of A elements -+ typename RealElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename RealElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename RealElementC, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ typename Policy_, -+ /// Complex transform on A operand -+ ComplexTransform TransformA, -+ /// Complex transform on B operand -+ ComplexTransform TransformB -+> -+class MmaGaussianComplexTensorOp< -+ Shape_, -+ complex, -+ LayoutA_, -+ complex, -+ LayoutB_, -+ complex, -+ LayoutC_, -+ Policy_, -+ TransformA, -+ TransformB> { -+public: -+ /// Shape of warp-level matrix operation (concept: GemmShape) -+ using Shape = Shape_; -+ -+ /// Data type of multiplicand A -+ using ElementA = complex; -+ -+ /// Layout of multiplicand A -+ using LayoutA = LayoutA_; -+ -+ /// Data type of multiplicand B -+ using ElementB = complex; -+ -+ /// Layout of multiplicand B -+ using LayoutB = LayoutB_; -+ -+ /// Data type of accumulator matrix C -+ using ElementC = complex; -+ -+ /// Layout of accumulator matrix C -+ using LayoutC = LayoutC_; -+ -+ /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) -+ using Policy = Policy_; -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ using ArchMmaOperator = typename Policy::Operator; -+ -+ /// Shape of underlying instruction -+ using InstructionShape = typename ArchMmaOperator::Shape; -+ -+ /// Underlying arch tag -+ using ArchTag = typename ArchMmaOperator::ArchTag; -+ -+ /// Indicates class of matrix operator -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Indicates math operator -+ using MathOperator = arch::OpMultiplyAddGaussianComplex; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = TransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = TransformB; -+ -+ -+ /// Number of threads participating in warp-level matrix product -+ static int const kThreadCount = 32; -+ -+public: -+ -+ /// Iterates over the A operand in memory -+ using IteratorA = MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, -+ Operand::kA, -+ ElementA, -+ LayoutA, -+ MatrixShape, -+ Policy::OpDelta::kRow, -+ 32, -+ 1 -+ >; -+ -+ /// Storage for A tile -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Storage for transformed A tile -+ using TransformedFragmentA = FragmentA; -+ -+ /// Iterates over the B operand in memory -+ using IteratorB = MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, -+ Operand::kB, -+ ElementB, -+ LayoutB, -+ MatrixShape, -+ Policy::OpDelta::kColumn, -+ 32, -+ 1 -+ >; -+ -+ /// Storage for B tile -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Storage for transformed B tile -+ using TransformedFragmentB = FragmentB; -+ -+ static_assert( -+ !(Shape::kM % ArchMmaOperator::Shape::kM) && -+ !(Shape::kN % ArchMmaOperator::Shape::kN), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ -+ /// Number of mma operations performed -+ using MmaIterations = MatrixShape< -+ Shape::kM / ArchMmaOperator::Shape::kM, -+ Shape::kN / ArchMmaOperator::Shape::kN -+ >; -+ -+ /// Iterates over the C operand in memory -+ using IteratorC = MmaTensorOpGaussianComplexAccumulatorTileIterator< -+ MatrixShape, -+ ElementC, -+ LayoutC, -+ typename ArchMmaOperator::Shape, -+ typename Policy::OpDelta>; -+ -+ /// Storage for C tile, the accumulator. Note, regardless of multiplicand type, this -+ /// storage arrangement is to be considered 'gaussian complex' in the sense that the accumulation is -+ /// done in three parts namely part1, part2, and part3. The parts 1, 2, and 3 are stored consecutively -+ /// in InteratorC::Frament. This matches the structure of Tensor Cores which are always real-valued matrix multiplies. -+ using FragmentC = typename IteratorC::Fragment; -+ -+ static_assert( -+ FragmentC::kElements == 3 * MmaIterations::kCount * ArchMmaOperator::FragmentC::kElements, -+ "Unexpected gaussian complex fragment length."); -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Underlying real-valued matrix multiply operator (concept: arch::Mma) -+ ArchMmaOperator mma; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_DEVICE -+ MmaGaussianComplexTensorOp() {} -+ -+ /// Performs a warp-level matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &D, -+ FragmentA const &A, -+ FragmentB const &B, -+ FragmentC const &C -+ ) const { -+ -+ // Alias types for underlying real-valued matrix multiply operator -+ using MmaOperandA = typename ArchMmaOperator::FragmentA; -+ using MmaOperandB = typename ArchMmaOperator::FragmentB; -+ using MmaOperandC = typename ArchMmaOperator::FragmentC; -+ -+ static_assert(MmaOperandA::kElements == 1, -+ "This implementation only supports math instructions in which exactly one element is needed for the A operand." -+ "We can geneneralize later."); -+ -+ static_assert(MmaOperandB::kElements == 1, -+ "This implementation only supports math instructions in which exactly one element is needed for the B operand." -+ "We can geneneralize later."); -+ -+ D = C; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < MmaIterations::kRow; ++m) { -+ -+ // mma(accum.part1(), (a.real() + a.imag()), b.real(), accum.part1()); -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; ++n) { -+ -+ // Pack operands together. This may result in actual MOVs -+ MmaOperandA operand_Asum; -+ MmaOperandB operand_Br; -+ -+ operand_Asum[0] = A[m].real() + ((kTransformA == ComplexTransform::kConjugate) ? -A[m].imag() : +A[m].imag()); -+ operand_Br[0] = B[n].real(); -+ -+ // accumulator part1 -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow); -+ -+ mma(*accum, operand_Asum, operand_Br, *accum); -+ } -+ -+ // mma(accum.part2(), -a.real(), (b.real() - b.imag()), accum.part2()); -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = MmaIterations::kColumn - 1; n >= 0; --n) { -+ -+ // Pack operands together. This may result in actual MOVs -+ MmaOperandA operand_Ar; -+ MmaOperandB operand_Bdiff; -+ -+ operand_Ar[0] = -A[m].real(); -+ operand_Bdiff[0] = B[n].real() - ((kTransformB == ComplexTransform::kConjugate) ? -B[n].imag() : +B[n].imag()); -+ -+ // accumulator part2 -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow) + MmaIterations::kCount; -+ -+ mma(*accum, operand_Ar, operand_Bdiff, *accum); -+ } -+ -+ // mma(accum.part3(), a.imag(), (b.real() + b.imag()), accum.part3()) -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; ++n) { -+ -+ // Pack operands together. This may result in actual MOVs -+ MmaOperandA operand_Ai; -+ MmaOperandB operand_Bsum; -+ -+ operand_Ai[0] = (kTransformA == ComplexTransform::kConjugate) ? -A[m].imag() : +A[m].imag(); -+ operand_Bsum[0] = B[n].real() + ((kTransformB == ComplexTransform::kConjugate) ? -B[n].imag() : +B[n].imag()); -+ -+ // accumulator part3 -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow) + 2 * MmaIterations::kCount; -+ -+ mma(*accum, operand_Ai, operand_Bsum, *accum); -+ } -+ } -+ } -+ -+ /// Transform the mma operands to the required types -+ CUTLASS_DEVICE -+ void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, -+ FragmentA const &A, FragmentB const &B) const { -+ //TODO: Implement this -+ dst_A = A; -+ dst_B = B; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for complex*complex+complex => complex using real-valued TensorOps -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Data type of A elements -+ typename RealElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename RealElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename RealElementC, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ typename Policy_, -+ /// Complex transform on A operand -+ ComplexTransform TransformA, -+ /// Complex transform on B operand -+ ComplexTransform TransformB -+> -+class MmaGaussianComplexTensorOp< -+ Shape_, -+ complex, -+ LayoutA_, -+ complex, -+ LayoutB_, -+ complex, -+ LayoutC_, -+ Policy_, -+ TransformA, -+ TransformB, -+ true> { -+public: -+ /// Shape of warp-level matrix operation (concept: GemmShape) -+ using Shape = Shape_; -+ -+ /// Data type of multiplicand A -+ using ElementA = complex; -+ -+ /// Layout of multiplicand A -+ using LayoutA = LayoutA_; -+ -+ /// Data type of multiplicand B -+ using ElementB = complex; -+ -+ /// Layout of multiplicand B -+ using LayoutB = LayoutB_; -+ -+ /// Data type of accumulator matrix C -+ using ElementC = complex; -+ -+ /// Layout of accumulator matrix C -+ using LayoutC = LayoutC_; -+ -+ /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) -+ using Policy = Policy_; -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ using ArchMmaOperator = typename Policy::Operator; -+ -+ /// Shape of underlying instruction -+ using InstructionShape = typename ArchMmaOperator::Shape; -+ -+ /// Underlying arch tag -+ using ArchTag = typename ArchMmaOperator::ArchTag; -+ -+ /// Indicates class of matrix operator -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Indicates math operator -+ using MathOperator = arch::OpMultiplyAddGaussianComplex; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = TransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = TransformB; -+ -+ -+ /// Number of threads participating in warp-level matrix product -+ static int const kThreadCount = 32; -+ -+public: -+ -+ /// Iterates over the A operand in memory -+ using IteratorA = MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, -+ Operand::kA, -+ ElementA, -+ LayoutA, -+ MatrixShape, -+ Policy::OpDelta::kRow, -+ 32, -+ 1 -+ >; -+ -+ /// Storage for A tile -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Storage for transformed A tile -+ using TransformedFragmentA = FragmentA; -+ -+ /// Iterates over the B operand in memory -+ using IteratorB = MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, -+ Operand::kB, -+ ElementB, -+ LayoutB, -+ MatrixShape, -+ Policy::OpDelta::kColumn, -+ 32, -+ 1 -+ >; -+ -+ /// Storage for B tile -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Storage for transformed B tile -+ using TransformedFragmentB = FragmentB; -+ -+ static_assert( -+ !(Shape::kM % ArchMmaOperator::Shape::kM) && -+ !(Shape::kN % ArchMmaOperator::Shape::kN), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ -+ /// Number of mma operations performed -+ using MmaIterations = MatrixShape< -+ Shape::kM / ArchMmaOperator::Shape::kM, -+ Shape::kN / ArchMmaOperator::Shape::kN -+ >; -+ -+ /// Iterates over the C operand in memory -+ using IteratorC = MmaTensorOpGaussianComplexAccumulatorTileIterator< -+ MatrixShape, -+ ElementC, -+ LayoutC, -+ typename ArchMmaOperator::Shape, -+ typename Policy::OpDelta>; -+ -+ /// Storage for C tile, the accumulator. Note, regardless of multiplicand type, this -+ /// storage arrangement is to be considered 'gaussian complex' in the sense that the accumulation is -+ /// done in three parts namely part1, part2, and part3. The parts 1, 2, and 3 are stored consecutively -+ /// in InteratorC::Frament. This matches the structure of Tensor Cores which are always real-valued matrix multiplies. -+ using FragmentC = typename IteratorC::Fragment; -+ -+ static_assert( -+ FragmentC::kElements == 3 * MmaIterations::kCount * ArchMmaOperator::FragmentC::kElements, -+ "Unexpected gaussian complex fragment length."); -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Underlying real-valued matrix multiply operator (concept: arch::Mma) -+ ArchMmaOperator mma; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_DEVICE -+ MmaGaussianComplexTensorOp() {} -+ -+ /// Performs a warp-level matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &D, -+ FragmentA const &A, -+ FragmentB const &B, -+ FragmentC const &C -+ ) const { -+ -+ // Alias types for underlying real-valued matrix multiply operator -+ using MmaOperandA = typename ArchMmaOperator::FragmentA; -+ using MmaOperandB = typename ArchMmaOperator::FragmentB; -+ using MmaOperandC = typename ArchMmaOperator::FragmentC; -+ -+ D = C; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < MmaIterations::kRow; ++m) { -+ -+ // mma(accum.part1(), (a.real() + a.imag()), b.real(), accum.part1()); -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; ++n) { -+ -+ // Pack operands together. This may result in actual MOVs -+ MmaOperandA operand_Asum; -+ MmaOperandB operand_Br; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mk = 0; mk < MmaOperandA::kElements; ++mk) -+ operand_Asum[mk] = A[m*MmaOperandA::kElements + mk].real() + ((kTransformA == ComplexTransform::kConjugate) ? -+ -A[m*MmaOperandA::kElements + mk].imag() : +A[m*MmaOperandA::kElements + mk].imag()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int nk = 0; nk < MmaOperandB::kElements; ++nk) -+ operand_Br[nk] = B[n*MmaOperandB::kElements + nk].real(); -+ -+ // accumulator part1 -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow); -+ -+ mma(*accum, operand_Asum, operand_Br, *accum); -+ } -+ -+ // mma(accum.part2(), -a.real(), (b.real() - b.imag()), accum.part2()); -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = MmaIterations::kColumn - 1; n >= 0; --n) { -+ -+ // Pack operands together. This may result in actual MOVs -+ MmaOperandA operand_Ar; -+ MmaOperandB operand_Bdiff; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mk = 0; mk < MmaOperandA::kElements; ++mk) -+ operand_Ar[mk] = -A[m*MmaOperandA::kElements + mk].real(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int nk = 0; nk < MmaOperandB::kElements; ++nk) -+ operand_Bdiff[nk] = B[n*MmaOperandB::kElements + nk].real() - ((kTransformB == ComplexTransform::kConjugate) ? -+ -B[n*MmaOperandB::kElements + nk].imag() : +B[n*MmaOperandB::kElements + nk].imag()); -+ -+ // accumulator part2 -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow) + MmaIterations::kCount; -+ -+ mma(*accum, operand_Ar, operand_Bdiff, *accum); -+ } -+ -+ // mma(accum.part3(), a.imag(), (b.real() + b.imag()), accum.part3()) -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; ++n) { -+ -+ // Pack operands together. This may result in actual MOVs -+ MmaOperandA operand_Ai; -+ MmaOperandB operand_Bsum; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mk = 0; mk < MmaOperandA::kElements; ++mk) -+ operand_Ai[mk] = (kTransformA == ComplexTransform::kConjugate) ? -+ -A[m*MmaOperandA::kElements + mk].imag() : +A[m*MmaOperandA::kElements + mk].imag(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int nk = 0; nk < MmaOperandB::kElements; ++nk) -+ operand_Bsum[nk] = B[n*MmaOperandB::kElements + nk].real() + ((kTransformB == ComplexTransform::kConjugate) ? -+ -B[n*MmaOperandB::kElements + nk].imag() : +B[n*MmaOperandB::kElements + nk].imag()); -+ -+ // accumulator part3 -+ MmaOperandC *accum = reinterpret_cast(&D) + -+ (m + n * MmaIterations::kRow) + 2 * MmaIterations::kCount; -+ -+ mma(*accum, operand_Ai, operand_Bsum, *accum); -+ } -+ } -+ } -+ -+ /// Transform the mma operands to the required types -+ CUTLASS_DEVICE -+ void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, -+ FragmentA const &A, FragmentB const &B) const { -+ dst_A = A; -+ dst_B = B; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op_tile_iterator_sm80.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op_tile_iterator_sm80.h -new file mode 100644 -index 0000000..1903622 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op_tile_iterator_sm80.h -@@ -0,0 +1,390 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm80.h" -+#include "cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h" -+ -+#include "cutlass/platform/platform.h" -+#include "cutlass/fast_math.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Element type -+ typename Element_, -+ /// Layout of operand in memory -+ typename Layout_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions, concept: MatrixShape) -+ typename OpDelta_> -+class MmaTensorOpGaussianComplexAccumulatorTileIterator; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// -+/// Partial specialization for complex -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of underlying field of reals. -+ typename RealElement, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions, concept: MatrixShape) -+ typename OpDelta_> -+class MmaTensorOpGaussianComplexAccumulatorTileIterator< -+ Shape_, complex, cutlass::layout::RowMajor, InstructionShape_, OpDelta_> { -+ public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kC; -+ -+ /// Element type -+ using Element = complex; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::RowMajor; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ using OpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ struct Policy { -+ static_assert( -+ !(Shape::kRow % InstructionShape::kM) && -+ !(Shape::kColumn % InstructionShape::kN), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ -+ static_assert(platform::is_same::value, -+ "Layouts must be defined for logical MatrixCoord coordinate space."); -+ -+ /// Number of mma operations performed -+ using MmaIterations = MatrixShape; -+ }; -+ -+private: -+ -+ // Assume accumulator tile is an arrangement of 8-by-8 tiles replicated over the entire -+ // shape, with each quad mapped to one row and each thread mapped to 1/4 of the elements -+ // of that row. The accumulators within one row are assumed to be consecutive. -+ static int const kElementsPerAccess = InstructionShape::kN / 4; -+ static int const kRowsPerTile = 8; -+ static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile. It is assumed that the accumulators -+ /// are stored in a gaussian complex arrangement with parts 1, 2, and 3 as entirely contiguous -+ /// arranged as [part1, part2, part3] -+ using Fragment = Array; -+ -+ static int const kPart1Index = (Shape::kCount / kThreads) * 0; -+ static int const kPart2Index = (Shape::kCount / kThreads) * 1; -+ static int const kPart3Index = (Shape::kCount / kThreads) * 2; -+ -+private: -+ -+ /// Reference to output tensor -+ TensorRef ref_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpGaussianComplexAccumulatorTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpGaussianComplexAccumulatorTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ ref_(ref) { -+ -+ int quad = (lane_id >> 2); -+ int lane_in_quad = (lane_id & 3); -+ -+ MatrixCoord lane_offset(quad, lane_in_quad * kElementsPerAccess); -+ -+ ref_.add_coord_offset(lane_offset); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpGaussianComplexAccumulatorTileIterator &add_pointer_offset(LongIndex offset) { -+ ref_.add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpGaussianComplexAccumulatorTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ ref_.add_coord_offset(tile_offset * make_Coord(Shape::kRow, Shape::kColumn)); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpGaussianComplexAccumulatorTileIterator & operator++() { -+ // deliberate no-op -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpGaussianComplexAccumulatorTileIterator & operator--() { -+ // deliberate no-op -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpGaussianComplexAccumulatorTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpGaussianComplexAccumulatorTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ Fragment &frag, ///< fragment to load from the tensor -+ Index pointer_offset) const { ///< loads a tile with a linear offset -+ -+ TensorRef offset_ref(ref_); -+ offset_ref.add_pointer_offset(pointer_offset); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { -+ -+ int mma_accum_start = kAccumulatorRows * kElementsPerAccess * -+ (mma_n * Policy::MmaIterations::kRow + mma_m); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < kAccumulatorRows; ++row) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int col = 0; col < kElementsPerAccess; ++col) { -+ int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + -+ row * kRowsPerTile; -+ int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col; -+ -+ Element z = offset_ref.at({accum_m, accum_n}); -+ -+ frag[mma_accum_start + row * kElementsPerAccess + col + kPart1Index] = z.real() + z.imag(); -+ frag[mma_accum_start + row * kElementsPerAccess + col + kPart2Index] = -z.real(); -+ frag[mma_accum_start + row * kElementsPerAccess + col + kPart3Index] = z.imag(); -+ } -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ Fragment &frag, ///< fragment to load from the tensor -+ Index byte_offset) const { ///< loads a tile with a linear offset -+ -+ load_with_pointer_offset(byte_offset / sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ Fragment &frag, ///< fragment to load from the tensor -+ TensorCoord const &tile_offset) const { ///< loads a tile with a logical offset in units of whole tiles -+ -+ load(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ Fragment &frag, ///< fragment to load from the tensor -+ TensorCoord const &tile_offset, ///< loads a tile with a logical offset in units of whole tiles -+ Index pointer_offset) const { ///< loads a tile with a logical offset AND a pointer offset -+ -+ load_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) const { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory with additional pointer offset -+ CUTLASS_DEVICE -+ void store_with_pointer_offset( -+ Fragment const &frag, ///< fragment to store from the tensor -+ Index pointer_offset) const { ///< store a tile with a linear offset -+ -+ TensorRef offset_ref(ref_); -+ offset_ref.add_pointer_offset(pointer_offset); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { -+ -+ int mma_accum_start = kAccumulatorRows * kElementsPerAccess * -+ (mma_n * Policy::MmaIterations::kRow + mma_m); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < kAccumulatorRows; ++row) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int col = 0; col < kElementsPerAccess; ++col) { -+ int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + -+ row * kRowsPerTile; -+ int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col; -+ int idx = mma_accum_start + row * kElementsPerAccess + col; -+ -+ Element z(frag[kPart1Index + idx] - frag[kPart3Index + idx], -+ frag[kPart1Index + idx] + frag[kPart2Index + idx]); -+ -+ offset_ref.at({accum_m, accum_n}) = z; -+ } -+ } -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory with additional pointer offset -+ CUTLASS_DEVICE -+ void store_with_byte_offset( -+ Fragment const &frag, ///< fragment to store from the tensor -+ Index byte_offset) const { ///< store a tile with a linear offset -+ -+ store_with_pointer_offset(byte_offset / sizeof(Element)); -+ } -+ -+ /// Stores a fragment to memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void store( -+ Fragment &frag, ///< fragment to store to the tensor -+ TensorCoord const &tile_offset) const { ///< stores a tile with a logical offset in units of whole tiles -+ -+ store(frag, tile_offset, 0); -+ } -+ -+ /// Stores a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void store( -+ /// fragment to store to the tensor -+ Fragment const &frag, -+ /// stores a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// stores a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ store_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_planar_complex.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_planar_complex.h -new file mode 100644 -index 0000000..894efd7 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_planar_complex.h -@@ -0,0 +1,182 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing warp-level matrix multiply-accumulate operations. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/complex.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/array_planar_complex.h" -+#include "cutlass/gemm/warp/tile_iterator_planar_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Underlying real-valued warp-level matrix multiply -+ typename Operator_, -+ /// Transformation applied to A operand (typically folded into math instruction) -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Transformation applied to B operand (typically folded into math instruction) -+ ComplexTransform TransformB = ComplexTransform::kNone -+> -+class MmaPlanarComplex { -+public: -+ -+ /// Underlying real-valued warp-level matrix multiply -+ using Operator = Operator_; -+ -+ /// Shape of warp-level matrix multipy -+ using Shape = typename Operator::Shape; -+ -+ /// Transformation applied to A operand (typically folded into math instruction) -+ static ComplexTransform const kTransformA = TransformA; -+ -+ /// Transformation applied to B operand (typically folded into math instruction) -+ static ComplexTransform const kTransformB = TransformB; -+ -+ /// Fragment of elements -+ using FragmentA = ArrayPlanarComplex; -+ -+ /// Iterator into planar complex -+ using IteratorA = TileIteratorPlanarComplex; -+ -+ /// Layout in memory of the A operand -+ using LayoutA = typename Operator::LayoutA; -+ -+ using FragmentB = ArrayPlanarComplex; -+ -+ /// Iterator into planar complex -+ using IteratorB = TileIteratorPlanarComplex; -+ -+ /// Layout in memory of the B operand -+ using LayoutB = typename Operator::LayoutB; -+ -+ /// Tile iterator for accumulator -+ using IteratorC = TileIteratorPlanarComplex; -+ -+ /// Accumulator fragment -+ using FragmentC = ArrayPlanarComplex; -+ -+ /// Layout of accumulator fragment in memory -+ using LayoutC = typename Operator::LayoutC; -+ -+private: -+ -+ /// Number of mma operations performed -+ using MmaIterations = MatrixShape< -+ Operator::Shape::kM / Operator::Policy::Operator::Shape::kM, -+ Operator::Shape::kN / Operator::Policy::Operator::Shape::kN -+ >; -+ -+public: -+ /// Ctor -+ CUTLASS_DEVICE -+ MmaPlanarComplex() {} -+ -+ /// Performs a warp-level matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &D, -+ FragmentA const &A_in, -+ FragmentB const &B_in, -+ FragmentC const &C) const { -+ -+ D.real = C.real; -+ D.imag = C.imag; -+ -+ // -+ // Transform fragments based on conjugate operations. -+ // -+ -+ negate neg_A; -+ -+ FragmentA frag_A; -+ frag_A.real = A_in.real; -+ -+ if (kTransformA == ComplexTransform::kConjugate) { -+ frag_A.imag = neg_A(frag_A.imag); -+ } -+ else { -+ frag_A.imag = frag_A.imag; -+ } -+ -+ FragmentB frag_B; -+ frag_B.real = B_in.real; -+ -+ if (kTransformB == ComplexTransform::kConjugate) { -+ negate neg; -+ frag_B.imag = neg(frag_B.imag); -+ } -+ else { -+ frag_B.imag = frag_B.imag; -+ } -+ -+ // -+ // Accumulated real-valued matrix multiplies -+ // -+ -+ Operator real_mma; -+ -+ // D.i += A.i * B.r -+ real_mma(D.imag, frag_A.imag, frag_B.real, D.imag); -+ -+ // D.r += A.r * B.r -+ real_mma(D.real, frag_A.real, frag_B.real, D.real); -+ -+ // D.i += A.r * B.i -+ real_mma(D.imag, frag_A.real, frag_B.imag, D.imag); -+ -+ // D.r += -A.i * B.i -+ frag_A.imag = neg_A(frag_A.imag); -+ real_mma(D.real, frag_A.imag, frag_B.imag, D.real); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_simt.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_simt.h -new file mode 100644 -index 0000000..9790792 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_simt.h -@@ -0,0 +1,264 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing warp-level matrix multiply-accumulate operations. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma.h" -+ -+#include "cutlass/gemm/thread/mma.h" -+ -+#include "cutlass/gemm/warp/mma_simt_tile_iterator.h" -+#include "cutlass/gemm/warp/mma_simt_policy.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Data type of A elements -+ typename ElementA_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename ElementB_, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename ElementC_, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Shape of the warp in units of thread (concept: MmaSimtPolicy) -+ typename Policy_, -+ /// Number of partitions along K dimension -+ int PartitionsK = 1, -+ /// Complex transformation on operand A -+ ComplexTransform TransformA = ComplexTransform::kNone, -+ /// Complex transformation on operand B -+ ComplexTransform TransformB = ComplexTransform::kNone, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+class MmaSimt { -+public: -+ /// Shape of warp-level matrix operation (concept: GemmShape) -+ using Shape = Shape_; -+ -+ /// Data type of multiplicand A -+ using ElementA = ElementA_; -+ -+ /// Layout of multiplicand A -+ using LayoutA = LayoutA_; -+ -+ /// Data type of multiplicand B -+ using ElementB = ElementB_; -+ -+ /// Layout of multiplicand B -+ using LayoutB = LayoutB_; -+ -+ /// Data type of accumulator matrix C -+ using ElementC = ElementC_; -+ -+ /// Layout of accumulator matrix C -+ using LayoutC = LayoutC_; -+ -+ /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) -+ using Policy = Policy_; -+ -+ /// Indicates class of matrix operator -+ using OperatorClass = arch::OpClassSimt; -+ -+ /// Hard-coded for now -+ using ArchTag = arch::Sm50; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = TransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = TransformB; -+ -+ /// Layout of threads -+ using ThreadLayoutA = typename platform::conditional< platform::is_same< layout::ColumnMajorInterleaved<4>, LayoutA >::value, -+ layout::ColumnMajor, -+ typename platform::conditional < platform::is_same< layout::RowMajorInterleaved<4>, LayoutA >::value, -+ layout::RowMajor, -+ LayoutA>::type -+ >::type; -+ -+ using ThreadLayoutB = typename platform::conditional< platform::is_same< layout::ColumnMajorInterleaved<4>, LayoutB >::value, -+ layout::ColumnMajor, -+ typename platform::conditional < platform::is_same< layout::RowMajorInterleaved<4>, LayoutB >::value, -+ layout::RowMajor, -+ LayoutB>::type -+ >::type; -+ -+ static constexpr bool use_dp4a = (platform::is_same< layout::ColumnMajorInterleaved<4>, LayoutA>::value || -+ platform::is_same< layout::RowMajorInterleaved<4>, LayoutA >::value) && -+ platform::is_same< ElementA, int8_t >::value && -+ platform::is_same< ElementB, int8_t >::value; -+ -+ using dp4a_type = typename platform::conditional< use_dp4a , int8_t, bool >::type; -+ -+ /// Thread-level matrix multiply accumulate operator -+ using ThreadMma = thread::Mma< -+ GemmShape< -+ Shape::kM / Policy::WarpShape::kRow, -+ Shape::kN / Policy::WarpShape::kColumn, -+ Policy::LaneMmaShape::kK>, -+ ElementA, -+ ThreadLayoutA, -+ ElementB, -+ ThreadLayoutB, -+ ElementC, -+ LayoutC, -+ arch::OpMultiplyAdd, -+ dp4a_type -+ >; -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ using ArchMmaOperator = typename ThreadMma::ArchMmaOperator; -+ -+ /// Indicates math operator -+ using MathOperator = typename ArchMmaOperator::Operator; -+ -+ /// Shape of the underlying instruction -+ using InstructionShape = GemmShape<1,1,use_dp4a ? 4 : 1>; -+ -+public: -+ -+ /// Iterates over the A operand in memory -+ using IteratorA = MmaSimtTileIterator< -+ MatrixShape, -+ Operand::kA, -+ ElementA, -+ LayoutA, -+ Policy, -+ PartitionsK, -+ Shape::kK -+ >; -+ -+ /// Storage for A tile -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Storage for transformed A tile -+ using TransformedFragmentA = FragmentA; -+ -+ /// Iterates over the B operand in memory -+ using IteratorB = MmaSimtTileIterator< -+ MatrixShape, -+ Operand::kB, -+ ElementB, -+ LayoutB, -+ Policy, -+ PartitionsK, -+ Shape::kK -+ >; -+ -+ /// Storage for B tile -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Storage for transformed A tile -+ using TransformedFragmentB = FragmentB; -+ -+ /// Iterates over the C operand in memory -+ using IteratorC = MmaSimtTileIterator< -+ MatrixShape, -+ Operand::kC, -+ ElementC, -+ LayoutC, -+ Policy -+ >; -+ -+ /// Storage for C tile -+ using FragmentC = typename ThreadMma::FragmentC; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_DEVICE -+ MmaSimt() {} -+ -+ /// Performs a warp-level matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &d, -+ FragmentA a, -+ FragmentB b, -+ FragmentC const &c, int group_idx = 0) const { -+ -+ ThreadMma mma; -+ -+ if (kTransformA == ComplexTransform::kConjugate) { -+ conjugate conj_a; -+ a = conj_a(a); -+ } -+ -+ if (kTransformB == ComplexTransform::kConjugate) { -+ conjugate conj_b; -+ b = conj_b(b); -+ } -+ -+ mma(d, a, b, c); -+ } -+ -+ /// Transform the mma operands to the required types -+ CUTLASS_DEVICE -+ void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, -+ FragmentA const &A, FragmentB const &B) const { -+ //TODO: Implement this -+ dst_A = A; -+ dst_B = B; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_simt_policy.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_simt_policy.h -new file mode 100644 -index 0000000..a0b0a75 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_simt_policy.h -@@ -0,0 +1,69 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Describes the lane policy used by warp-level matrix multiply operators targeting SIMT -+ instructions -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Describes the arrangement and configuration of per-lane operations in warp-level matrix multiply -+template < -+ typename WarpShape_, ///< shape of the warp in lanes (concept: MatrixShape) -+ typename LaneLayout_, ///< layout function of lanes -+ typename LaneMmaShape_ ///< size of each lane's thread-level matrix product (concept: GemmShape) -+> -+struct MmaSimtPolicy { -+ using WarpShape = WarpShape_; -+ using LaneLayout = LaneLayout_; -+ using LaneMmaShape = LaneMmaShape_; -+ using MmaShape = LaneMmaShape; -+ -+ /// Returns a layout functor mapping lane position in the warp to thread ID -+ CUTLASS_HOST_DEVICE -+ static LaneLayout get_lane_layout() { -+ return LaneLayout::packed({WarpShape::kRow, WarpShape::kColumn}); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_simt_tile_iterator.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_simt_tile_iterator.h -new file mode 100644 -index 0000000..53c1c36 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_simt_tile_iterator.h -@@ -0,0 +1,1890 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Describes the lane policy used by warp-level matrix multiply operators targeting SIMT -+ instructions -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+ -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma_simt_policy.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Iterates over operands to warp-level matrix multiply operations targeting SIMT instructions -+/// -+/// concept: MutableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Operand identity -+ Operand Operand, -+ /// Data type of A elements -+ typename Element_, -+ /// Layout of operand -+ typename Layout_, -+ /// Shape of the warp in units of thread (concept: MmaSimtPolicy) -+ typename Policy_, -+ /// Number of partitions along K dimension - used in sliced-K -+ int PartitionsK = 1, -+ /// Group Size along kPartition - used in sliced-K -+ int PartitionGroupSize = 1 -+> -+class MmaSimtTileIterator; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for A operands of column-major layouts -+/// -+/// Concept: MutableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of A elements -+ typename Element_, -+ /// Shape of the warp in units of thread (concept: MmaSimtPolicy) -+ typename Policy_, -+ /// Number of partitions along K dimension - used in sliced-K -+ int PartitionsK, -+ /// Group Size along kPartition - used in sliced-K -+ int PartitionGroupSize -+> -+class MmaSimtTileIterator { -+public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kA; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of policy -+ using Layout = layout::ColumnMajor; -+ -+ /// Decomposition of elements among threads -+ using Policy = Policy_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ // -+ // Derived quantities -+ // -+ -+ static_assert(!(Shape::kRow % Policy::WarpShape::kRow), -+ "The warp-level GEMM M size must be divisible by the number of threads arranged along the M dimension."); -+ -+ static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero."); -+ static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero."); -+ static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero."); -+ static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero."); -+ -+ /// Thread-level shape of a fragment -+ using ThreadShape = MatrixShape< -+ Shape::kRow / Policy::WarpShape::kRow, -+ Shape::kColumn -+ >; -+ -+ static_assert(!(ThreadShape::kRow % Policy::LaneMmaShape::kM), -+ "Thread-level GEMM must be divisible by Policy::LaneMmaShape."); -+ -+ /// Number of individual loads -+ using Iterations = MatrixShape< -+ ThreadShape::kRow / Policy::LaneMmaShape::kM, -+ ThreadShape::kColumn -+ >; -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array; -+ -+private: -+ -+ /// Internal reference -+ cutlass::TensorRef, layout::ColumnMajor> ref_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator( -+ TensorRef ref, -+ int lane_id -+ ) { -+ -+ // compute offset based on thread ID and lane layout -+ typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); -+ -+ MatrixCoord lane_offset = lane_layout.inverse(lane_id) * -+ MatrixCoord(Policy::LaneMmaShape::kM, 0); -+ -+ ref.add_coord_offset(lane_offset); -+ -+ ref_.reset( -+ reinterpret_cast *>(ref.data()), -+ ref.stride(0) / Policy::LaneMmaShape::kM); -+ } -+ -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator &add_pointer_offset(LongIndex offset) { -+ ref_.add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) { -+ -+ ref_.add_coord_offset({ -+ coord.row() * Shape::kRow / Policy::LaneMmaShape::kM, -+ coord.column() * Shape::kColumn}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator & operator++() { -+ -+ ref_.add_coord_offset({0, Shape::kColumn}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator & operator--() { -+ -+ ref_.add_coord_offset({0, -Shape::kColumn}); -+ -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. (vector loads) -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ Array *dst_ptr = -+ reinterpret_cast *>(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Iterations::kColumn; ++k) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Iterations::kRow; ++m) { -+ -+ // This logic has been replaced with calls to inline PTX to guarantee vectorization. -+ #if 0 -+ dst_ptr[m + k * Iterations::kRow] = -+ *(ref_.data() + ref_.offset({m * Policy::WarpShape::kRow, k}) + pointer_offset / Policy::LaneMmaShape::kM); -+ #endif -+ -+ auto ptr = ref_.data() + ref_.offset({m * Policy::WarpShape::kRow, k}) + pointer_offset / Policy::LaneMmaShape::kM; -+ arch::shared_load(dst_ptr[m + k * Iterations::kRow], ptr); -+ } -+ } -+ } -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { -+ -+ Array const *src_ptr = -+ reinterpret_cast *>(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Iterations::kN; ++k) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Iterations::kM; ++m) { -+ *(ref_.data() + ref_.offset(m * Policy::WarpShape::kM, k) + pointer_offset / Policy::LaneMmaShape::kM) = -+ src_ptr[m + k * Iterations::kM]; -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) const { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ // no operation here -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for A operands of row-major layouts -+/// -+/// Concept: MutableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of A elements -+ typename Element_, -+ /// Shape of the warp in units of thread (concept: MmaSimtPolicy) -+ typename Policy_, -+ /// Number of partitions along K dimension - used in sliced-K -+ int PartitionsK, -+ /// Group Size along kPartition - used in sliced-K -+ int PartitionGroupSize -+> -+class MmaSimtTileIterator { -+public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kA; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of policy -+ using Layout = layout::RowMajor; -+ -+ /// Decomposition of elements among threads -+ using Policy = Policy_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ // -+ // Derived quantities -+ // -+ -+ static_assert(!(Shape::kRow % Policy::WarpShape::kRow), -+ "The warp-level GEMM M size must be divisible by the number of threads arranged along the M dimension."); -+ -+ static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero."); -+ static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero."); -+ static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero."); -+ static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero."); -+ -+ /// Thread-level shape of a fragment -+ using ThreadShape = MatrixShape< -+ Shape::kRow / Policy::WarpShape::kRow, -+ Shape::kColumn -+ >; -+ -+ static_assert(!(ThreadShape::kRow % Policy::LaneMmaShape::kM), -+ "Thread-level GEMM must be divisible by Policy::LaneMmaShape."); -+ -+ /// Number of individual loads (scalar loads) -+ using Iterations = MatrixShape< -+ ThreadShape::kRow / Policy::LaneMmaShape::kM, -+ ThreadShape::kColumn -+ >; -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array; -+ -+private: -+ -+ /// Internal reference -+ cutlass::TensorRef ref_; -+ -+ /// Extent of tensor -+ MatrixCoord extent_; -+ -+ /// Origin -+ MatrixCoord origin_; -+ -+ /// Used to conditionally enable extents checking -+ bool divisible_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator() : divisible_(true) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator( -+ TensorRef ref, -+ int lane_id -+ ) : extent_(Shape::kRow, Shape::kColumn), divisible_ (true) { -+ -+ // compute offset based on thread ID and lane layout -+ typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); -+ -+ MatrixCoord lane_offset = lane_layout.inverse(lane_id) * -+ MatrixCoord(Policy::LaneMmaShape::kM, 0); -+ -+ origin_ = lane_offset; -+ -+ ref.add_coord_offset(lane_offset); -+ -+ ref_.reset(ref.data(), ref.stride(0)); -+ -+ } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator( -+ TensorRef ref, -+ TensorCoord extent, -+ int lane_id -+ ) : extent_(extent), divisible_ (false) { -+ -+ // compute offset based on thread ID and lane layout -+ typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); -+ -+ MatrixCoord lane_offset = lane_layout.inverse(lane_id) * -+ MatrixCoord(Policy::LaneMmaShape::kM, 0); -+ -+ origin_ = lane_offset; -+ -+ ref.add_coord_offset(lane_offset); -+ -+ ref_.reset(ref.data(), ref.stride(0)); -+ -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator &add_pointer_offset(LongIndex offset) { -+ ref_.add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) { -+ -+ TensorCoord coord_offset( -+ coord.row() * Shape::kRow, -+ coord.column() * Shape::kColumn); -+ -+ origin_ += coord_offset; -+ -+ ref_.add_coord_offset(coord_offset); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator & operator++() { -+ -+ ref_.add_coord_offset({0, Shape::kColumn}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator & operator--() { -+ -+ ref_.add_coord_offset({0, -Shape::kColumn}); -+ -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. (scalar loads) -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Iterations::kColumn; ++k) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Iterations::kRow; ++m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Policy::LaneMmaShape::kM; i++) { -+ -+ MatrixCoord offset(m * Policy::WarpShape::kRow * Policy::LaneMmaShape::kM + i, k); -+ -+ MatrixCoord access_coord = origin_ + offset; -+ -+ int frag_idx = m * Policy::LaneMmaShape::kM + i + k * Iterations::kRow; -+ -+ if (divisible_ || -+ (access_coord.row() < extent_.row() && access_coord.column() < extent_.column())) { -+ -+ frag[frag_idx] = *(ref_.data() + ref_.offset(offset) + pointer_offset); -+ } -+ else { -+ frag[frag_idx] = Element(); -+ } -+ } -+ } -+ } -+ } -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Iterations::kColumn; ++k) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Iterations::kRow; ++m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Policy::LaneMmaShape::kM; i++) { -+ -+ *(ref_.data() + ref_.offset(m * Policy::WarpShape::kM * Policy::LaneMmaShape::kM + i, k) + pointer_offset) = -+ frag[m * Policy::LaneMmaShape::kM + i + k * Iterations::kM]; -+ } -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) const { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ // no operation here -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for B operands of row-major layouts -+/// -+/// Concept: MutableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of A elements -+ typename Element_, -+ /// Shape of the warp in units of thread (concept: MmaSimtPolicy) -+ typename Policy_, -+ /// Number of partitions along K dimension -+ int PartitionsK, -+ /// Group Size along kPartition - used in sliced-K -+ int PartitionGroupSize -+> -+class MmaSimtTileIterator { -+public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kB; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of policy -+ using Layout = layout::RowMajor; -+ -+ /// Decomposition of elements among threads -+ using Policy = Policy_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ // -+ // Derived quantities -+ // -+ -+ static_assert(!(Shape::kColumn % Policy::WarpShape::kColumn), -+ "The warp-level GEMM N size must be divisible by the number of threads arranged along the N dimension."); -+ -+ static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero."); -+ static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero."); -+ static_assert(Policy::WarpShape::kColumn > 0, "Policy::WarpShape::kColumn must be greater than zero."); -+ static_assert(Shape::kColumn / Policy::WarpShape::kColumn > 0, "Shape::kColumn / Policy::WarpShape::kColumn must be greater than zero."); -+ -+ /// Thread-level shape of a fragment -+ using ThreadShape = MatrixShape< -+ Shape::kRow, -+ Shape::kColumn / Policy::WarpShape::kColumn -+ >; -+ -+ static_assert(!(ThreadShape::kColumn % Policy::LaneMmaShape::kN), -+ "Thread-level GEMM must be divisible by Policy::LaneMmaShape."); -+ -+ /// Number of individual loads -+ using Iterations = MatrixShape< -+ ThreadShape::kRow, -+ ThreadShape::kColumn / Policy::LaneMmaShape::kN -+ >; -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array; -+ -+protected: -+ -+ /// Internal reference -+ cutlass::TensorRef, layout::RowMajor> ref_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator( -+ TensorRef ref, -+ int lane_id -+ ) { -+ -+ // compute offset based on thread ID and lane layout -+ typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); -+ -+ MatrixCoord lane_offset = lane_layout.inverse(lane_id) * -+ MatrixCoord(0, Policy::LaneMmaShape::kN); -+ -+ ref.add_coord_offset(lane_offset); -+ -+ ref_.reset( -+ reinterpret_cast *>(ref.data()), -+ ref.stride(0) / Policy::LaneMmaShape::kN); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator &add_pointer_offset(LongIndex offset) { -+ ref_.add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) { -+ -+ ref_.add_coord_offset({ -+ coord.row() * Shape::kRow, -+ coord.column() * Shape::kColumn / Policy::LaneMmaShape::kN}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator & operator++() { -+ -+ ref_.add_coord_offset({Shape::kRow, 0}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator & operator--() { -+ -+ ref_.add_coord_offset({-Shape::kRow, 0}); -+ -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. (vector loads) -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ -+ Array *dst_ptr = -+ reinterpret_cast *>(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Iterations::kRow; ++k) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Iterations::kColumn; ++n) { -+ -+ #if 0 -+ dst_ptr[n + k * Iterations::kColumn] = -+ *(ref_.data() + ref_.offset({k, n * Policy::WarpShape::kColumn}) + pointer_offset / Policy::LaneMmaShape::kN); -+ #endif -+ -+ void const *ptr = ref_.data() + ref_.offset({k, n * Policy::WarpShape::kColumn}) + pointer_offset / Policy::LaneMmaShape::kN; -+ arch::shared_load(dst_ptr[n + k * Iterations::kColumn], ptr); -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { -+ -+ Array const *src_ptr = -+ reinterpret_cast *>(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Iterations::kM; ++k) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Iterations::kN; ++n) { -+ *(ref_.data() + ref_.offset({k, n * Policy::WarpShape::kN}) + pointer_offset / Policy::LaneMmaShape::kN) = -+ src_ptr[n + k * Iterations::kN]; -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag, Index pointer_offset) const { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ // no operation here -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for B operands of column-major layouts -+/// -+/// Concept: MutableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of A elements -+ typename Element_, -+ /// Shape of the warp in units of thread (concept: MmaSimtPolicy) -+ typename Policy_, -+ /// Number of partitions along K dimension -+ int PartitionsK, -+ /// Group Size along kPartition - used in sliced-K -+ int PartitionGroupSize -+> -+class MmaSimtTileIterator { -+public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kB; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of policy -+ using Layout = layout::ColumnMajor; -+ -+ /// Decomposition of elements among threads -+ using Policy = Policy_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ // -+ // Derived quantities -+ // -+ -+ static_assert(!(Shape::kColumn % Policy::WarpShape::kColumn), -+ "The warp-level GEMM N size must be divisible by the number of threads arranged along the N dimension."); -+ -+ static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero."); -+ static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero."); -+ static_assert(Policy::WarpShape::kColumn > 0, "Policy::WarpShape::kColumn must be greater than zero."); -+ static_assert(Shape::kColumn / Policy::WarpShape::kColumn > 0, "Shape::kColumn / Policy::WarpShape::kColumn must be greater than zero."); -+ -+ /// Thread-level shape of a fragment -+ using ThreadShape = MatrixShape< -+ Shape::kRow, -+ Shape::kColumn / Policy::WarpShape::kColumn -+ >; -+ -+ static_assert(!(ThreadShape::kColumn % Policy::LaneMmaShape::kN), -+ "Thread-level GEMM must be divisible by Policy::LaneMmaShape."); -+ -+ /// Number of individual loads -+ using Iterations = MatrixShape< -+ ThreadShape::kRow, -+ ThreadShape::kColumn / Policy::LaneMmaShape::kN -+ >; -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array; -+ -+private: -+ -+ /// Internal reference -+ cutlass::TensorRef ref_; -+ -+ /// Extent of tensor -+ MatrixCoord extent_; -+ -+ /// Origin -+ MatrixCoord origin_; -+ -+ /// Used to conditionally enable extents checking -+ bool divisible_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator(): divisible_(true) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator( -+ TensorRef ref, -+ int lane_id -+ ): extent_(Shape::kRow, Shape::kColumn), divisible_(true) { -+ -+ // compute offset based on thread ID and lane layout -+ typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); -+ -+ MatrixCoord lane_offset = lane_layout.inverse(lane_id) * -+ MatrixCoord(0, Policy::LaneMmaShape::kN); -+ -+ origin_ = lane_offset; -+ -+ ref.add_coord_offset(lane_offset); -+ -+ ref_.reset(ref.data(), ref.stride(0)); -+ } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator( -+ TensorRef ref, -+ TensorCoord extent, -+ int lane_id -+ ): extent_(extent), divisible_(false) { -+ -+ // compute offset based on thread ID and lane layout -+ typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); -+ -+ MatrixCoord lane_offset = lane_layout.inverse(lane_id) * -+ MatrixCoord(0, Policy::LaneMmaShape::kN); -+ -+ origin_ = lane_offset; -+ -+ ref.add_coord_offset(lane_offset); -+ -+ ref_.reset(ref.data(), ref.stride(0)); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator &add_pointer_offset(LongIndex offset) { -+ ref_.add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) { -+ -+ TensorCoord coord_offset( -+ coord.row() * Shape::kRow, -+ coord.column() * Shape::kColumn); -+ -+ origin_ += coord_offset; -+ -+ ref_.add_coord_offset(coord_offset); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator & operator++() { -+ -+ ref_.add_coord_offset({Shape::kRow, 0}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator & operator--() { -+ -+ ref_.add_coord_offset({-Shape::kRow, 0}); -+ -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. (scalar loads) -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Iterations::kRow; ++k) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Iterations::kColumn; ++n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Policy::LaneMmaShape::kN; ++i) { -+ -+ MatrixCoord offset(k, n * Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN + i); -+ -+ MatrixCoord access_coord = origin_ + offset; -+ -+ int frag_idx = n * Policy::LaneMmaShape::kN + i + k * Iterations::kColumn; -+ -+ if (divisible_ || -+ (access_coord.row() < extent_.row() && access_coord.column() < extent_.column())) { -+ -+ frag[frag_idx] = *(ref_.data() + ref_.offset(offset) + pointer_offset); -+ } -+ else { -+ frag[frag_idx] = Element(); -+ } -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { -+ -+ Array const *src_ptr = -+ reinterpret_cast *>(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Iterations::kM; ++k) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Iterations::kN; ++n) { -+ *(ref_.data() + ref_.offset({k, n * Policy::WarpShape::kN}) + pointer_offset / Policy::LaneMmaShape::kN) = -+ src_ptr[n + k * Iterations::kN]; -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag, Index pointer_offset) const { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ // no operation here -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for C operands of column-major layouts -+/// -+/// Concept: MutableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of A elements -+ typename Element_, -+ /// Shape of the warp in units of thread (concept: MmaSimtPolicy) -+ typename Policy_ -+> -+class MmaSimtTileIterator { -+public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kC; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of accumulators in memory -+ using Layout = layout::ColumnMajor; -+ -+ /// Decomposition of elements among threads -+ using Policy = Policy_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ // -+ // Derived quantities -+ // -+ -+ static_assert( -+ (!(Shape::kRow % Policy::WarpShape::kRow)) && (!(Shape::kColumn % Policy::WarpShape::kColumn)), -+ "Warp-level GEMM shape must be divisible by the arrangement of threads in the warp."); -+ -+ static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero."); -+ static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero."); -+ static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero."); -+ static_assert(Policy::WarpShape::kColumn > 0, "Policy::WarpShape::kColumn must be greater than zero."); -+ static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero."); -+ static_assert(Shape::kColumn / Policy::WarpShape::kColumn > 0, "Shape::kColumn / Policy::WarpShape::kColumn must be greater than zero."); -+ -+ /// Thraed-level shape of a fragment -+ using ThreadShape = MatrixShape< -+ Shape::kRow / Policy::WarpShape::kRow, -+ Shape::kColumn / Policy::WarpShape::kColumn -+ >; -+ -+ static_assert( -+ (!(ThreadShape::kRow % Policy::LaneMmaShape::kM)) && (!(ThreadShape::kColumn % Policy::LaneMmaShape::kN)), -+ "Warp-level GEMM shape must be divisible by the arrangement of threads in the warp."); -+ -+ /// Number of individual loads -+ using Iterations = MatrixShape< -+ ThreadShape::kRow / Policy::LaneMmaShape::kM, -+ ThreadShape::kColumn / Policy::LaneMmaShape::kN -+ >; -+ -+ using Delta = MatrixShape< -+ Policy::WarpShape::kRow * Policy::LaneMmaShape::kM, -+ Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN -+ >; -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array; -+ -+private: -+ -+ TensorRef ref_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ ref_(ref) { -+ -+ // compute offset based on thread ID and lane layout -+ typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); -+ -+ MatrixCoord lane_offset = lane_layout.inverse(lane_id) * -+ MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN); -+ -+ ref_.add_coord_offset(lane_offset); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator &add_pointer_offset(LongIndex offset) { -+ ref_.add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) { -+ -+ ref_.add_coord_offset({ -+ coord.row() * Shape::kRow, -+ coord.column() * Shape::kColumn}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator & operator++() { -+ -+ ref_.add_coord_offset({Shape::kRow, 0}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator & operator--() { -+ -+ ref_.add_coord_offset({-Shape::kRow, 0}); -+ -+ return *this; -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset( -+ Fragment &frag, ///< fragment to be loaded from memory -+ Index pointer_offset) const { ///< linear offset (in units of Element) when loading -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Iterations::kN; ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) { -+ -+ Array const *src_ptr = -+ reinterpret_cast const *>( -+ ref_.data() + pointer_offset + ref_.offset({0, mma_n * Delta::kN + n})); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Iterations::kM; ++mma_m) { -+ -+ Array *dst_ptr = -+ reinterpret_cast *>(&frag) + -+ mma_m + Iterations::kM * (n + mma_n * Policy::LaneMmaShape::kN); -+ -+ *dst_ptr = src_ptr[mma_m * Policy::WarpShape::kM]; -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) { -+ -+ Array *dst_ptr= -+ reinterpret_cast *>( -+ ref_.data() + pointer_offset + ref_.offset({0, mma_n * Delta::kColumn + n})); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { -+ -+ Array const *src_ptr = -+ reinterpret_cast const *>(&frag) + -+ mma_m + Iterations::kRow * (n + mma_n * Policy::LaneMmaShape::kN); -+ -+ dst_ptr[mma_m * Policy::WarpShape::kRow] = *src_ptr; -+ } -+ } -+ } -+ } -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) const { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for C operands of row-major layouts -+/// -+/// Concept: MutableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of A elements -+ typename Element_, -+ /// Shape of the warp in units of thread (concept: MmaSimtPolicy) -+ typename Policy_ -+> -+class MmaSimtTileIterator { -+public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kC; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of accumulators in memory -+ using Layout = layout::RowMajor; -+ -+ /// Decomposition of elements among threads -+ using Policy = Policy_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ // -+ // Derived quantities -+ // -+ -+ static_assert( -+ (!(Shape::kRow % Policy::WarpShape::kRow)) && (!(Shape::kColumn % Policy::WarpShape::kColumn)), -+ "Warp-level GEMM shape must be divisible by the arrangement of threads in the warp."); -+ -+ static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero."); -+ static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero."); -+ static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero."); -+ static_assert(Policy::WarpShape::kColumn > 0, "Policy::WarpShape::kColumn must be greater than zero."); -+ static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero."); -+ static_assert(Shape::kColumn / Policy::WarpShape::kColumn > 0, "Shape::kColumn / Policy::WarpShape::kColumn must be greater than zero."); -+ -+ /// Thraed-level shape of a fragment -+ using ThreadShape = MatrixShape< -+ Shape::kRow / Policy::WarpShape::kRow, -+ Shape::kColumn / Policy::WarpShape::kColumn -+ >; -+ -+ static_assert( -+ (!(ThreadShape::kRow % Policy::LaneMmaShape::kM)) && (!(ThreadShape::kColumn % Policy::LaneMmaShape::kN)), -+ "Warp-level GEMM shape must be divisible by the arrangement of threads in the warp."); -+ -+ /// Number of individual loads -+ using Iterations = MatrixShape< -+ ThreadShape::kRow / Policy::LaneMmaShape::kM, -+ ThreadShape::kColumn / Policy::LaneMmaShape::kN -+ >; -+ -+ using Delta = MatrixShape< -+ Policy::WarpShape::kRow * Policy::LaneMmaShape::kM, -+ Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN -+ >; -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array; -+ -+private: -+ -+ TensorRef ref_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ ref_(ref) { -+ -+ // compute offset based on thread ID and lane layout -+ typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); -+ -+ MatrixCoord lane_offset = lane_layout.inverse(lane_id) * -+ MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN); -+ -+ ref_.add_coord_offset(lane_offset); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator &add_pointer_offset(LongIndex offset) { -+ ref_.add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) { -+ -+ ref_.add_coord_offset({ -+ coord.row() * Shape::kRow, -+ coord.column() * Shape::kColumn}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator & operator++() { -+ -+ ref_.add_coord_offset({Shape::kRow, 0}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator & operator--() { -+ -+ ref_.add_coord_offset({-Shape::kRow, 0}); -+ -+ return *this; -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset( -+ Fragment &frag, ///< fragment to be loaded from memory -+ Index pointer_offset) const { ///< linear offset (in units of Element) when loading -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) { -+ -+ Array const *src_ptr = -+ reinterpret_cast const *>( -+ ref_.data() + pointer_offset + ref_.offset({mma_m * Delta::kRow + m, 0})); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) { -+ -+ Array *dst_ptr = -+ reinterpret_cast *>(&frag) + -+ mma_n + Iterations::kColumn * (m + mma_m * Policy::LaneMmaShape::kM); -+ -+ *dst_ptr = src_ptr[mma_n * Policy::WarpShape::kColumn]; -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) { -+ -+ Array *dst_ptr = -+ reinterpret_cast *>( -+ ref_.data() + pointer_offset + ref_.offset({mma_m * Delta::kRow + m, 0})); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) { -+ -+ Array const *src_ptr = -+ reinterpret_cast const *>(&frag) + -+ mma_n + Iterations::kColumn * (m + mma_m * Policy::LaneMmaShape::kM); -+ -+ dst_ptr[mma_n * Policy::WarpShape::kColumn] = *src_ptr; -+ } -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) const { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for A operands of column-major-K interleaved layouts -+/// -+/// Concept: MutableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of A elements -+ typename Element_, -+ /// Shape of the warp in units of thread (concept: MmaSimtPolicy) -+ typename Policy_, -+ /// Number of partitions along K dimension -+ int PartitionsK, -+ /// Number of KGroups per kPartition -+ int PartitionGroupSize -+> -+class MmaSimtTileIterator, Policy_, PartitionsK, PartitionGroupSize> { -+public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kA; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of policy -+ using Layout = layout::ColumnMajorInterleaved<4> ; -+ -+ /// Decomposition of elements among threads -+ using Policy = Policy_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Iterleave factor -+ static const int kInterleave = 4; -+ -+ /// Number of partitions along K dimension -+ static const int kPartitionsK = PartitionsK; -+ -+ /// Number of KGroups per kPartition -+ static const int kGroupPerTile = PartitionGroupSize / Shape::kColumn; -+ -+ // -+ // Derived quantities -+ // -+ -+ static_assert(!(Shape::kRow % Policy::WarpShape::kRow), -+ "The warp-level GEMM M size must be divisible by the number of threads arranged along the M dimension."); -+ -+ static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero."); -+ static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero."); -+ static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero."); -+ static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero."); -+ -+ /// Thread-level shape of a fragment -+ using ThreadShape = MatrixShape< -+ Shape::kRow / Policy::WarpShape::kRow, -+ Shape::kColumn -+ >; -+ -+ static_assert(!(ThreadShape::kRow % Policy::LaneMmaShape::kM) && !(ThreadShape::kColumn % Policy::LaneMmaShape::kK), -+ "Thread-level GEMM must be divisible by Policy::LaneMmaShape."); -+ -+ /// Number of individual loads -+ using Iterations = MatrixShape< -+ ThreadShape::kRow / Policy::LaneMmaShape::kM, -+ ThreadShape::kColumn / Policy::LaneMmaShape::kK -+ >; -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array; -+ -+private: -+ -+ /// Internal reference -+ cutlass::TensorRef, layout::ColumnMajorInterleaved<4>> ref_; -+ -+ /// group index within tile -+ int k_group_idx_; -+ -+public: -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator( -+ TensorRef ref, -+ int lane_id -+ ) { -+ -+ // compute offset based on thread ID and lane layout -+ typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); -+ -+ MatrixCoord lane_offset = lane_layout.inverse(lane_id) * -+ MatrixCoord(Policy::LaneMmaShape::kM, 0); -+ -+ ref.add_coord_offset(lane_offset); -+ -+ k_group_idx_ = 0; -+ ref_.reset(reinterpret_cast *>(ref.data()), ref.stride(0)/Policy::LaneMmaShape::kMK); -+ } -+ -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator &add_pointer_offset(LongIndex offset) { -+ ref_.add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) { -+ -+ ref_.add_coord_offset({ -+ coord.row() * Shape::kRow / Policy::LaneMmaShape::kMK, -+ coord.column() * Shape::kColumn}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator & operator++() { -+ -+ add_tile_offset({0, 1}); -+ -+ if (kPartitionsK > 1) { -+ ++k_group_idx_; -+ // Jump to next stage -+ if (k_group_idx_ == kGroupPerTile) { -+ k_group_idx_ = 0; -+ add_tile_offset({0, kGroupPerTile * (kPartitionsK-1)}); -+ } -+ } -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator & operator--() { -+ -+ ref_.add_coord_offset({0, -Shape::kColumn}); -+ -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ -+ Array *dst_ptr = -+ reinterpret_cast *>(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Iterations::kColumn; ++k) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Iterations::kRow; ++m) { -+ -+ dst_ptr[m + k * Iterations::kRow] = -+ *((ref_.data() + ref_.offset({m * Policy::WarpShape::kRow / kInterleave, -+ k*Policy::LaneMmaShape::kK}) + pointer_offset / Policy::LaneMmaShape::kM)); -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { -+ -+ Array const *src_ptr = -+ reinterpret_cast *>(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Iterations::kN; ++k) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Iterations::kM; ++m) { -+ *(ref_.data() + ref_.offset(m * Policy::WarpShape::kM, k) + pointer_offset / Policy::LaneMmaShape::kM) = -+ src_ptr[m + k * Iterations::kM]; -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) const { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ // no operation here -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization for B operands of row-major k-interleaved layouts -+/// -+/// Concept: MutableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of A elements -+ typename Element_, -+ /// Shape of the warp in units of thread (concept: MmaSimtPolicy) -+ typename Policy_, -+ /// Number of partitions along K dimension -+ int PartitionsK, -+ /// Number of KGroups per kPartition -+ int PartitionGroupSize -+> -+class MmaSimtTileIterator, Policy_, PartitionsK, PartitionGroupSize> { -+public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kB; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of policy -+ using Layout = layout::RowMajorInterleaved<4>; -+ -+ /// Decomposition of elements among threads -+ using Policy = Policy_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Interleave factor -+ static const int kInterleave = 4; -+ -+ /// Number of partitions along K dimension -+ static const int kPartitionsK = PartitionsK; -+ -+ /// Number of KGroups per kPartition -+ static const int kGroupPerTile = PartitionGroupSize / Shape::kRow; -+ -+ // -+ // Derived quantities -+ // -+ -+ static_assert(!(Shape::kColumn % Policy::WarpShape::kColumn), -+ "The warp-level GEMM N size must be divisible by the number of threads arranged along the N dimension."); -+ -+ static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero."); -+ static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero."); -+ static_assert(Policy::WarpShape::kColumn > 0, "Policy::WarpShape::kColumn must be greater than zero."); -+ static_assert(Shape::kColumn / Policy::WarpShape::kColumn > 0, "Shape::kColumn / Policy::WarpShape::kColumn must be greater than zero."); -+ -+ /// Thread-level shape of a fragment -+ using ThreadShape = MatrixShape< -+ Shape::kRow, -+ Shape::kColumn / Policy::WarpShape::kColumn -+ >; -+ -+ static_assert(!(ThreadShape::kColumn % Policy::LaneMmaShape::kN) && !(ThreadShape::kRow % Policy::LaneMmaShape::kK), -+ "Thread-level GEMM must be divisible by Policy::LaneMmaShape."); -+ -+ /// Number of individual loads -+ using Iterations = MatrixShape< -+ ThreadShape::kRow / Policy::LaneMmaShape::kK, -+ ThreadShape::kColumn / Policy::LaneMmaShape::kN -+ >; -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array; -+ -+ -+private: -+ -+ /// Internal reference -+ cutlass::TensorRef, layout::RowMajorInterleaved<4>> ref_; -+ -+ /// group index within tile -+ int k_group_idx_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator( -+ TensorRef ref, -+ int lane_id -+ ) { -+ -+ // compute offset based on thread ID and lane layout -+ typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); -+ -+ MatrixCoord lane_offset = lane_layout.inverse(lane_id) * -+ MatrixCoord(0, Policy::LaneMmaShape::kN); -+ -+ ref.add_coord_offset(lane_offset); -+ -+ k_group_idx_ = 0; -+ -+ ref_.reset( -+ reinterpret_cast *>(ref.data()), -+ ref.stride(0) / Policy::LaneMmaShape::kKN); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator &add_pointer_offset(LongIndex offset) { -+ ref_.add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) { -+ -+ ref_.add_coord_offset({ -+ coord.row() * Shape::kRow, -+ coord.column() * Shape::kColumn / Policy::LaneMmaShape::kKN}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator & operator++() { -+ -+ add_tile_offset({1, 0}); -+ -+ if (kPartitionsK > 1) { -+ ++k_group_idx_; -+ // Jump to next stage -+ if (k_group_idx_ == kGroupPerTile) { -+ k_group_idx_ = 0; -+ add_tile_offset({kGroupPerTile * (kPartitionsK-1), 0}); -+ } -+ } -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaSimtTileIterator & operator--() { -+ -+ ref_.add_coord_offset({-Shape::kRow, 0}); -+ -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ -+ Array *dst_ptr = -+ reinterpret_cast *>(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Iterations::kRow; ++k) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Iterations::kColumn; ++n) { -+ dst_ptr[n + k * Iterations::kColumn] = -+ *(ref_.data() + ref_.offset({k * Policy::LaneMmaShape::kK, -+ n * Policy::WarpShape::kColumn / kInterleave}) + pointer_offset / Policy::LaneMmaShape::kN); -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { -+ -+ Array const *src_ptr = -+ reinterpret_cast *>(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Iterations::kM; ++k) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Iterations::kN; ++n) { -+ *(ref_.data() + ref_.offset({k, n * Policy::WarpShape::kN}) + pointer_offset / Policy::LaneMmaShape::kN) = -+ src_ptr[n + k * Iterations::kN]; -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag, Index pointer_offset) const { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ // no operation here -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_sparse_tensor_op.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_sparse_tensor_op.h -new file mode 100644 -index 0000000..e049f4f ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_sparse_tensor_op.h -@@ -0,0 +1,339 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing warp-level matrix multiply-accumulate -+ operations targeting sparse Tensor Cores. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/platform/platform.h" -+ -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/arch/mma_sm75.h" -+#include "cutlass/arch/mma_sm80.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_policy.h" -+#include "cutlass/gemm/warp/mma_tensor_op.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sparse.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Data type of A elements -+ typename ElementA_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename ElementB_, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename ElementC_, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ typename Policy_, -+ /// Number of partitions along K dimension -+ int PartitionsK_ = 1, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor = false, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+class SparseMmaTensorOp { -+public: -+ /// Shape of warp-level matrix operation (concept: GemmShape) -+ using Shape = Shape_; -+ -+ /// Data type of multiplicand A -+ using ElementA = ElementA_; -+ -+ /// Layout of multiplicand A -+ using LayoutA = LayoutA_; -+ -+ /// Data type of multiplicand B -+ using ElementB = ElementB_; -+ -+ /// Layout of multiplicand B -+ using LayoutB = LayoutB_; -+ -+ /// Data type of accumulator matrix C -+ using ElementC = ElementC_; -+ -+ /// Layout of accumulator matrix C -+ using LayoutC = LayoutC_; -+ -+ /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) -+ using Policy = Policy_; -+ -+ /// Equivalant base dense mma -+ using Base = MmaTensorOp; -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ using ArchMmaOperator = typename Base::ArchMmaOperator; -+ -+ /// Indicates math operator -+ using MathOperator = typename ArchMmaOperator::Operator; -+ -+ /// Architecture tag from underlying instruction -+ using ArchTag = typename Base::ArchTag; -+ -+ /// Indicates class of matrix operator -+ using OperatorClass = typename Base::OperatorClass; -+ -+ /// Shape of underlying instruction -+ using InstructionShape = typename Base::InstructionShape; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = Base::kTransformA; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = Base::kTransformB; -+ -+ /// Number of threads participating in warp-level matrix product -+ static int const kThreadCount = 32; -+ -+ /// Number of partitions along K dimension -+ static int const kPartitionsK = PartitionsK_; -+ -+ /// Sparsity in Operand A -+ static int const kSparse = Policy::Operator::kSparse; -+ -+ /// Meta data size in bits -+ static int const kMetaSizeInBits = Policy::Operator::kMetaSizeInBits; -+ -+ /// Max ID2 -+ static int const kMaxID2 = Policy::Operator::kMaxID2; -+ -+ /// Data type of meta E that is moved at the same time -+ using ElementE = -+ typename cutlass::platform::conditional::type; -+ -+ /// Number of ElementA that is associated with one ElementE -+ static int const kElementsPerElementE = -+ 128 / cutlass::sizeof_bits::value; -+ -+ /// Meta data is essentially interleaved but mapped to ColumnMajor internally -+ static int const kInterleaved = 2; -+ -+ /// Layout of meta E -+ using LayoutE = cutlass::layout::ColumnMajor; -+ -+ public: -+ -+ /// Iterates over the A operand in memory -+ using IteratorA = MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, Operand::kA, ElementA, -+ LayoutA, -+ MatrixShape, -+ Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; -+ -+ /// Storage for A tile -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Storage for transformed A tile -+ using TransformedFragmentA = -+ Array; -+ -+ /// Iterates over the B operand in memory -+ using IteratorB = typename Base::IteratorB; -+ -+ /// Storage for B tile -+ using FragmentB = typename Base::FragmentB; -+ -+ /// Storage for transformed B tile -+ using TransformedFragmentB = typename Base::TransformedFragmentB; -+ -+ /// Iterates over the C operand in memory -+ using IteratorC = typename Base::IteratorC; -+ -+ /// Storage for C tile -+ using FragmentC = typename Base::FragmentC; -+ -+ /// Iterates over the E operand in memory -+ using IteratorE = SparseMmaTensorOpMetaTileIterator< -+ MatrixShape, -+ ElementE, LayoutE, -+ MatrixShape, -+ Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; -+ -+ /// Storage for E tile -+ using FragmentE = typename IteratorE::Fragment; -+ -+ /// Number of mma operations performed -+ using MmaIterations = typename Base::MmaIterations; -+ -+public: -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ ArchMmaOperator mma; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_DEVICE -+ SparseMmaTensorOp() {} -+ -+ /// Performs a warp-level matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &D, -+ TransformedFragmentA const &A, -+ TransformedFragmentB const &B, -+ FragmentC const &C, -+ FragmentE const &E -+ ) const { -+ -+ using MmaOperandA = typename Policy::Operator::FragmentA; -+ using MmaOperandB = typename Policy::Operator::FragmentB; -+ using MmaOperandC = typename Policy::Operator::FragmentC; -+ using MmaOperandE = typename Policy::Operator::FragmentE; -+ -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ -+ D = C; -+ -+ MmaOperandA const *ptr_A = reinterpret_cast(&A); -+ MmaOperandB const *ptr_B = reinterpret_cast(&B); -+ MmaOperandC *ptr_D = reinterpret_cast(&D); -+ MmaOperandE const *ptr_E = reinterpret_cast(&E); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < MmaIterations::kRow; ++m) { -+ -+ int id2 = m % kMaxID2; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; ++n) { -+ -+ int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); -+ -+ if (AccumulatorsInRowMajor) { // matrix B is reordered -+ mma( -+ ptr_D[n_serpentine + m * MmaIterations::kColumn], -+ ptr_A[m], -+ ptr_B[n_serpentine], -+ ptr_D[n_serpentine + m * MmaIterations::kColumn], -+ ptr_E[(m / kMaxID2)], -+ id2); -+ } else { -+ mma(ptr_D[m + n_serpentine * MmaIterations::kRow], -+ ptr_A[m], -+ ptr_B[n_serpentine], -+ ptr_D[m + n_serpentine * MmaIterations::kRow], -+ ptr_E[(m / kMaxID2)], -+ id2); -+ } -+ } -+ } -+ #else -+ assert(0); -+ #endif -+ } -+ -+ /// Transform the mma operands to the required types -+ CUTLASS_DEVICE -+ void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, -+ FragmentA const &A, FragmentB const &B) const { -+ -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ // -+ // Define conversions from source type to instruction type -+ // -+ FloatRoundStyle const kRoundA = -+ PreferredRoundingMode::kRound; -+ FloatRoundStyle const kRoundB = -+ PreferredRoundingMode::kRound; -+ detail::ConvertAndPack -+ convert_A; -+ NumericArrayConverter -+ convert_B; -+ Array const *ptr_A = -+ reinterpret_cast const *>(&A); -+ Array * -+ ptr_dst_A = reinterpret_cast *>(&dst_A); -+ -+ dst_B = convert_B(B); -+ -+ ptr_dst_A[0] = convert_A(ptr_A[0]); -+ ptr_dst_A[1] = convert_A(ptr_A[1]); -+ #else -+ assert(0); -+ #endif -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op.h -new file mode 100644 -index 0000000..3124618 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op.h -@@ -0,0 +1,431 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing warp-level matrix multiply-accumulate operations targeting -+ Tensor Cores. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/platform/platform.h" -+ -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/arch/mma_sm75.h" -+#include "cutlass/arch/mma_sm80.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_policy.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template -+struct ConvertAndPack { -+ -+ using Converter = NumericArrayConverter; -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &source) { -+ Converter converter; -+ -+ return converter(source); -+ } -+}; -+ -+template -+struct ConvertAndPack { -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &source) { -+ return source; -+ } -+}; -+ -+template -+struct ConvertAndPack { -+ -+ using Converter = NumericArrayConverter; -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &source) { -+ Converter converter; -+ -+ Array tmp; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ int idx = (((i << 1) & 2) | ((i >> 1) & 1) | (i & 0xfffffffc)); -+ tmp[i] = source[idx]; -+ } -+ -+ return converter(tmp); -+ } -+}; -+ -+template -+struct ConvertAndPack { -+ -+ using Converter = NumericArrayConverter; -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &source) { -+ Converter converter; -+ -+ Array tmp; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ int idx = (((i << 1) & 2) | ((i >> 1) & 1) | (i & 0xfffffffc)); -+ tmp[i] = source[idx]; -+ } -+ -+ return converter(tmp); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace detail -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Data type of A elements -+ typename ElementA_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename ElementB_, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename ElementC_, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ typename Policy_, -+ /// Number of partitions along K dimension -+ int PartitionsK_ = 1, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor = false, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+class MmaTensorOp { -+public: -+ /// Shape of warp-level matrix operation (concept: GemmShape) -+ using Shape = Shape_; -+ -+ /// Data type of multiplicand A -+ using ElementA = ElementA_; -+ -+ /// Layout of multiplicand A -+ using LayoutA = LayoutA_; -+ -+ /// Data type of multiplicand B -+ using ElementB = ElementB_; -+ -+ /// Layout of multiplicand B -+ using LayoutB = LayoutB_; -+ -+ /// Data type of accumulator matrix C -+ using ElementC = ElementC_; -+ -+ /// Layout of accumulator matrix C -+ using LayoutC = LayoutC_; -+ -+ /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) -+ using Policy = Policy_; -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ using ArchMmaOperator = typename Policy::Operator; -+ -+ /// Indicates math operator -+ using MathOperator = typename ArchMmaOperator::Operator; -+ -+ /// Architecture tag from underlying instruction -+ using ArchTag = typename ArchMmaOperator::ArchTag; -+ -+ /// Indicates class of matrix operator -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Shape of underlying instruction -+ using InstructionShape = typename ArchMmaOperator::Shape; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = ComplexTransform::kNone; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = ComplexTransform::kNone; -+ -+ /// Number of threads participating in warp-level matrix product -+ static int const kThreadCount = 32; -+ -+ /// Number of partitions along K dimension -+ static int const kPartitionsK = PartitionsK_; -+ -+public: -+ -+ /// Iterates over the A operand in memory -+ using IteratorA = MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, Operand::kA, ElementA, LayoutA, -+ MatrixShape, -+ Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; -+ -+ /// Storage for A tile -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Storage for transformed A tile -+ using TransformedFragmentA = -+ Array; -+ -+ /// Iterates over the B operand in memory -+ using IteratorB = MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, Operand::kB, ElementB, LayoutB, -+ MatrixShape, -+ Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; -+ -+ /// Storage for B tile -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Storage for transformed B tile -+ using TransformedFragmentB = -+ Array; -+ -+ /// Iterates over the C operand in memory -+ using IteratorC = MmaTensorOpAccumulatorTileIterator< -+ MatrixShape, ElementC, LayoutC, -+ typename ArchMmaOperator::Shape, typename Policy::OpDelta>; -+ -+ /// Storage for C tile -+ using FragmentC = typename IteratorC::Fragment; -+ -+ /// Number of mma operations performed -+ using MmaIterations = MatrixShape< -+ (Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, -+ (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN -+ >; -+ -+public: -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ ArchMmaOperator mma; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_DEVICE -+ MmaTensorOp() {} -+ -+ /// Performs a warp-level matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &D, -+ TransformedFragmentA const &A, -+ TransformedFragmentB const &B, -+ FragmentC const &C -+ ) const { -+ -+ using MmaOperandA = typename ArchMmaOperator::FragmentA; -+ using MmaOperandB = typename ArchMmaOperator::FragmentB; -+ using MmaOperandC = typename ArchMmaOperator::FragmentC; -+ -+ D = C; -+ -+ MmaOperandA const *ptr_A = reinterpret_cast(&A); -+ MmaOperandB const *ptr_B = reinterpret_cast(&B); -+ MmaOperandC *ptr_D = reinterpret_cast(&D); -+ -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) -+ // Serpentine visitation order maximizing reuse of Rb -+ // The visitation order is like -+ // _ -+ // | | | | -+ // | | | | -+ // |_| |_| -+ // -+ // Down Up Down Up -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; ++n) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < MmaIterations::kRow; ++m) { -+ -+ int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); -+ -+ if (AccumulatorsInRowMajor) { // matrix B is reordered -+ mma( -+ ptr_D[n + m_serpentine * MmaIterations::kColumn], -+ ptr_A[m_serpentine], -+ ptr_B[n], -+ ptr_D[n + m_serpentine * MmaIterations::kColumn]); -+ } else { -+ mma( -+ ptr_D[m_serpentine + n * MmaIterations::kRow], -+ ptr_A[m_serpentine], -+ ptr_B[n], -+ ptr_D[m_serpentine + n * MmaIterations::kRow]); -+ } -+ } -+ } -+ #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ // Serpentine visitation order maximizing reuse of Ra -+ // The visitation order is like -+ // _________ -+ // _________| -+ // |_________ -+ // __________| -+ // -+ // Right Left Right Left -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < MmaIterations::kRow; ++m) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; ++n) { -+ -+ int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); -+ -+ if (AccumulatorsInRowMajor) { // matrix B is reordered -+ mma( -+ ptr_D[n_serpentine + m * MmaIterations::kColumn], -+ ptr_A[m], -+ ptr_B[n_serpentine], -+ ptr_D[n_serpentine + m * MmaIterations::kColumn]); -+ } else { -+ mma(ptr_D[m + n_serpentine * MmaIterations::kRow], -+ ptr_A[m], -+ ptr_B[n_serpentine], -+ ptr_D[m + n_serpentine * MmaIterations::kRow]); -+ } -+ } -+ } -+ #else -+ assert(0); -+ #endif -+ } -+ -+ /// Transform the mma operands to the required types -+ CUTLASS_DEVICE -+ void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, -+ FragmentA const &A, FragmentB const &B) const { -+ -+ // -+ // Define conversions from source type to instruction type -+ // -+ FloatRoundStyle const kRoundA = -+ PreferredRoundingMode::kRound; -+ FloatRoundStyle const kRoundB = -+ PreferredRoundingMode::kRound; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) -+ detail::ConvertAndPack -+ convert_A; -+ NumericArrayConverter -+ convert_B; -+ Array const *ptr_B = -+ reinterpret_cast const *>(&B); -+ Array * -+ ptr_dst_B = reinterpret_cast *>(&dst_B); -+ -+ dst_A = convert_A(A); -+ -+ ptr_dst_B[0] = convert_B(ptr_B[0]); -+ ptr_dst_B[1] = convert_B(ptr_B[1]); -+ -+ #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ detail::ConvertAndPack -+ convert_A; -+ NumericArrayConverter -+ convert_B; -+ Array const *ptr_A = -+ reinterpret_cast const *>(&A); -+ Array * -+ ptr_dst_A = reinterpret_cast *>(&dst_A); -+ -+ dst_B = convert_B(B); -+ -+ ptr_dst_A[0] = convert_A(ptr_A[0]); -+ ptr_dst_A[1] = convert_A(ptr_A[1]); -+ #else -+ assert(0); -+ #endif -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include "cutlass/gemm/warp/mma_tensor_op_fast_f32.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h -new file mode 100644 -index 0000000..d17edc1 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h -@@ -0,0 +1,471 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Templates implementing warp-level matrix multiply-accumulate operations targeting -+ Tensor Cores. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/platform/platform.h" -+ -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/arch/mma_sm80.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_policy.h" -+#include "cutlass/gemm/warp/mma_tensor_op.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+enum class TensorFloat32Op { -+ k3xTF32, -+ k4xTF32 -+}; -+ -+template < -+ /// Floating-point rounding style -+ FloatRoundStyle RoundBigA_, -+ /// Floating-point rounding style -+ FloatRoundStyle RoundSmallA_, -+ /// Floating-point rounding style -+ FloatRoundStyle RoundBigB_ = RoundBigA_, -+ /// Floating-point rounding style -+ FloatRoundStyle RoundSmallB_ = RoundSmallA_, -+ /// Precision for TensorFloat32Op -+ // (k3xTF32: BigxBig, BigxSmall, SmallxBig) -+ // (k4xTF32: BigxBig, BigxSmall, SmallxBig, SmallxSmall) -+ TensorFloat32Op Precision_ = TensorFloat32Op::k3xTF32 -+ > -+struct FastF32 { -+ -+ static FloatRoundStyle const kRoundBigA = RoundBigA_; -+ static FloatRoundStyle const kRoundSmallA = RoundSmallA_; -+ static FloatRoundStyle const kRoundBigB = RoundBigB_; -+ static FloatRoundStyle const kRoundSmallB = RoundSmallB_; -+ static TensorFloat32Op const kPrecision = Precision_; -+}; -+ -+ -+namespace detail { -+ -+ template< -+ int N, -+ FloatRoundStyle RoundBig = FloatRoundStyle::round_toward_zero, -+ FloatRoundStyle RoundSmall = FloatRoundStyle::round_half_ulp_truncate -+ > -+ struct ConvertAndPackAccurateF32 { -+ -+ /// Rounding styles for big and small part -+ static FloatRoundStyle const kRoundBig = RoundBig; -+ static FloatRoundStyle const kRoundSmall = RoundSmall; -+ -+ /// Converter type -+ using Converter = NumericConverterFastF32; -+ -+ /// Source fragement -+ using SourceFragment = Array; -+ -+ /// Destination fragment -+ using DestinationFragment = Array; -+ -+ /// Converter Fragment holding two tfloat32_t elements for every float -+ using ConverterFragment = Array; -+ -+ /// Index in fargments for the big and small part -+ static int const kBigIndex = 0; -+ static int const kSmallIndex = 1; -+ -+ CUTLASS_HOST_DEVICE -+ void operator()(SourceFragment const &source, -+ DestinationFragment &dst_big, -+ DestinationFragment &dst_small) { -+ -+ Converter convert_; -+ ConverterFragment result_; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ // convert source to result fragment -+ result_ = convert_(source[i]); -+ -+ // store converted result fragments to destination fragment -+ dst_big[i] = result_[kBigIndex]; -+ dst_small[i] = result_[kSmallIndex]; -+ } -+ } -+ }; -+} // namespace detail -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Data type of A elements -+ typename ElementA_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename ElementB_, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename ElementC_, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ typename Policy_, -+ /// Number of partitions along K dimension -+ int PartitionsK_ = 1, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor = false, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+class MmaTensorOpFastF32; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for float*float+float => float using TF32 TensorOps -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ typename Policy_, -+ /// Number of partitions along K dimension -+ int PartitionsK_, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor, -+ /// Used for partial specialization -+ typename Enable -+> -+class MmaTensorOpFastF32< -+ Shape_, -+ float, LayoutA_, -+ float, LayoutB_, -+ float, LayoutC_, -+ Policy_, PartitionsK_, -+ AccumulatorsInRowMajor, Enable> { -+public: -+ /// Shape of warp-level matrix operation (concept: GemmShape) -+ using Shape = Shape_; -+ -+ /// Data type of multiplicand A -+ using ElementA = float; -+ -+ /// Layout of multiplicand A -+ using LayoutA = LayoutA_; -+ -+ /// Data type of multiplicand B -+ using ElementB = float; -+ -+ /// Layout of multiplicand B -+ using LayoutB = LayoutB_; -+ -+ /// Data type of accumulator matrix C -+ using ElementC = float; -+ -+ /// Layout of accumulator matrix C -+ using LayoutC = LayoutC_; -+ -+ /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) -+ using Policy = Policy_; -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ using ArchMmaOperator = typename Policy::Operator; -+ -+ /// Indicates math operator -+ using MathOperator = arch::OpMultiplyAddFastF32; -+ -+ /// Architecture tag from underlying instruction -+ using ArchTag = typename ArchMmaOperator::ArchTag; -+ -+ /// Indicates class of matrix operator -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Shape of underlying instruction -+ using InstructionShape = typename ArchMmaOperator::Shape; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = ComplexTransform::kNone; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = ComplexTransform::kNone; -+ -+ /// Number of threads participating in warp-level matrix product -+ static int const kThreadCount = 32; -+ -+ /// Number of partitions along K dimension -+ static int const kPartitionsK = PartitionsK_; -+ -+ /// Tune F32 to TF32 big small conversion for float operation -+ /// Different combination of big small conversin can cause different tradeoff -+ /// between speed and accuracy. Generally, use round_half_ulp_truncate can -+ /// improve the performance but hur the accuracy. -+ using MmaFastF32 = FastF32 < -+ FloatRoundStyle::round_toward_zero, // kRoundBigA -+ FloatRoundStyle::round_half_ulp_truncate, // kRoundSmallA -+ FloatRoundStyle::round_toward_zero, // kRoundBigB -+ FloatRoundStyle::round_half_ulp_truncate, // kRoundSmallB -+ TensorFloat32Op::k3xTF32 // Number of TF32 operations -+ >; -+ -+public: -+ -+ /// Iterates over the A operand in memory -+ using IteratorA = MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, -+ Operand::kA, -+ ElementA, -+ LayoutA, -+ MatrixShape, -+ Policy::OpDelta::kRow, -+ kThreadCount, -+ kPartitionsK -+ >; -+ -+ /// Storage for A tile -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Storage for transformed A tile -+ using TransformedFragmentA = -+ Array; -+ -+ /// Fragment bisecting big and small sections -+ using AccessTypeFragmentA = -+ Array; -+ -+ /// Iterates over the B operand in memory -+ using IteratorB = MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, -+ Operand::kB, -+ ElementB, -+ LayoutB, -+ MatrixShape, -+ Policy::OpDelta::kRow, -+ kThreadCount, -+ kPartitionsK -+ >; -+ -+ /// Storage for B tile -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Storage for transformed B tile -+ using TransformedFragmentB = -+ Array; -+ -+ /// Fragment bisecting big and small sections -+ using AccessTypeFragmentB = -+ Array; -+ -+ /// Index in fargments for the big and small part -+ static int const kBigIndex = 0; -+ static int const kSmallIndex = 1; -+ -+ /// Iterates over the C operand in memory -+ using IteratorC = MmaTensorOpAccumulatorTileIterator< -+ MatrixShape, ElementC, LayoutC, -+ typename ArchMmaOperator::Shape, typename Policy::OpDelta>; -+ -+ /// Storage for C tile -+ using FragmentC = typename IteratorC::Fragment; -+ -+ /// Number of mma operations performed -+ using MmaIterations = MatrixShape< -+ (Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, -+ (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN -+ >; -+ -+public: -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ ArchMmaOperator mma; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_DEVICE -+ MmaTensorOpFastF32() {} -+ -+ /// Performs a warp-level matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &D, -+ TransformedFragmentA const &A, -+ TransformedFragmentB const &B, -+ FragmentC const &C -+ ) const { -+ -+ AccessTypeFragmentA const *ptr_A = reinterpret_cast(&A); -+ AccessTypeFragmentB const *ptr_B = reinterpret_cast(&B); -+ -+ // -+ // Accumulate in place -+ // -+ D = C; -+ -+ mma_operator(D, ptr_A[kSmallIndex], ptr_B[kBigIndex], D); -+ -+ mma_operator(D, ptr_A[kBigIndex], ptr_B[kSmallIndex], D); -+ -+ mma_operator(D, ptr_A[kBigIndex], ptr_B[kBigIndex], D); -+ -+ if (MmaFastF32::kPrecision == TensorFloat32Op::k4xTF32) -+ mma_operator(D, ptr_A[kSmallIndex], ptr_B[kSmallIndex], D); -+ } -+ -+ /// Performs a warp-level matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void mma_operator( -+ FragmentC &D, -+ AccessTypeFragmentA const &A, -+ AccessTypeFragmentB const &B, -+ FragmentC const &C -+ ) const { -+ -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ -+ using MmaOperandA = typename ArchMmaOperator::FragmentA; -+ using MmaOperandB = typename ArchMmaOperator::FragmentB; -+ using MmaOperandC = typename ArchMmaOperator::FragmentC; -+ -+ MmaOperandA const *ptr_A = reinterpret_cast(&A); -+ MmaOperandB const *ptr_B = reinterpret_cast(&B); -+ MmaOperandC *ptr_D = reinterpret_cast(&D); -+ -+ // Serpentine visitation order maximizing reuse of Ra -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < MmaIterations::kRow; ++m) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; ++n) { -+ -+ // This allows to reuse of Rb when at serpentine turns -+ int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); -+ -+ if (AccumulatorsInRowMajor) { // matrix B is reordered -+ mma( -+ ptr_D[n_serpentine + m * MmaIterations::kColumn], -+ ptr_A[m], -+ ptr_B[n_serpentine], -+ ptr_D[n_serpentine + m * MmaIterations::kColumn]); -+ } else { -+ mma( -+ ptr_D[m + n_serpentine * MmaIterations::kRow], -+ ptr_A[m], -+ ptr_B[n_serpentine], -+ ptr_D[m + n_serpentine * MmaIterations::kRow]); -+ } -+ } // end n loop -+ } // end m loop -+ #else -+ assert(0); -+ #endif -+ } -+ -+ /// Transform the mma operands to the required types -+ CUTLASS_DEVICE -+ void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, -+ FragmentA const &A, FragmentB const &B) const { -+ -+ // -+ // Define conversions from source type to instruction type -+ // -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ -+ detail::ConvertAndPackAccurateF32< -+ FragmentA::kElements / 2, -+ MmaFastF32::kRoundBigA, -+ MmaFastF32::kRoundSmallA> convert_A; -+ -+ detail::ConvertAndPackAccurateF32< -+ FragmentB::kElements, -+ MmaFastF32::kRoundBigB, -+ MmaFastF32::kRoundSmallB> convert_B; -+ -+ Array *ptr_dst_B = -+ reinterpret_cast *>(&dst_B); -+ -+ convert_B(B, ptr_dst_B[0], ptr_dst_B[1]); -+ -+ Array *ptr_dst_A = -+ reinterpret_cast *>(&dst_A); -+ -+ Array const *ptr_A = -+ reinterpret_cast const *>(&A); -+ -+ convert_A(ptr_A[0], ptr_dst_A[0], ptr_dst_A[2]); -+ -+ convert_A(ptr_A[1], ptr_dst_A[1], ptr_dst_A[3]); -+ #else -+ assert(0); -+ #endif -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h -new file mode 100644 -index 0000000..aa2806d ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h -@@ -0,0 +1,528 @@ -+/*! \file -+ \brief This defines a "fragment" iterator for visiting the fragments of a warp tile -+ that participate in one warp-level mma operation. -+ -+ Typically, this is used to access the accumulator tile/fragement of a warp-level mma operation. -+ The accumulator tile is then partitioned into smaller tiles/fragments that can be fed into -+ next warp-level mma operation. -+ -+ This iterator is necessary to accomplish warp-level mma fusion where the accumulator tile is -+ reused as multiplicand tile for the next mma. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/numeric_conversion.h" -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Size of the accumulation tile shape (concept: MatrixShape) -+ typename AccumulatorShape_, -+ /// KBlocks columns to compute residual -+ int KBlocksColumn_, -+ /// Accumulator Element type -+ typename ElementAccumulator_, -+ /// Element type -+ typename Element_, -+ /// Layout of operand in memory -+ typename Layout_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Output operation on the fragment -+ typename OutputOp_> -+class MmaTensorOpFragmentIterator; -+ -+ -+// Partial specialization for col-major accumulator tile -+ -+template < -+ /// Shape of warp tile to load (concept: MatrixShape) -+ typename Shape_, -+ /// Shape of the warp accumulation tile (concept: MatrixShape) -+ typename AccumulatorShape_, -+ /// KBlocks columns to compute residual -+ int KBlocksColumn_, -+ /// Accumulator Element type -+ typename ElementAccumulator_, -+ /// Element type -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Output operation on fragment -+ typename OutputOp_> -+class MmaTensorOpFragmentIterator { -+ public: -+ -+ /// Shape of warp tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Shape of the warp accumulation tile (concept: MatrixShape) -+ using AccumulatorShape = AccumulatorShape_; -+ -+ /// KBlocks columns to compute residual -+ static int const kKBlockColumn = KBlocksColumn_; -+ -+ /// Accumulator Element type -+ using ElementAccumulator = ElementAccumulator_; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::ColumnMajor; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Output operation on fragment -+ using OutputOp = OutputOp_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ struct Policy { -+ static_assert( -+ !(Shape::kRow % InstructionShape::kM) && -+ !(Shape::kColumn % InstructionShape::kN), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ static_assert( -+ AccumulatorShape::kRow == Shape::kRow, -+ "Rows of Warp Accumulator must be the same as rows of warp"); -+ static_assert( -+ !(AccumulatorShape::kColumn % Shape::kColumn), -+ "Shape of Warp Accumulator must be divisible by warp shape."); -+ static_assert( -+ !(kKBlockColumn % Shape::kColumn), -+ "KBlock size must be divisible by warp shape."); -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = AccumulatorShape::kCount / Shape::kCount; -+ }; -+ -+private: -+ -+ static int const kElementsPerAccess = InstructionShape::kM * InstructionShape::kN / kThreads; -+ -+ /// Number of mma operations performed by a warp -+ using MmaIterations = MatrixShape; -+ /// Number of mma operations performed by the entire accumulator -+ using AccumulatorIterations = MatrixShape; -+ -+ /// Number of K iterations -+ static int const kKBlockIterations = (AccumulatorShape::kColumn + kKBlockColumn - 1) / kKBlockColumn; -+ static int const kResidualColumn = AccumulatorShape::kColumn - (kKBlockIterations - 1) * kKBlockColumn; -+ static int const kKBlockColumnIterations = kKBlockColumn / Shape::kColumn -+ * (AccumulatorShape::kRow / Shape::kRow); -+ static int const kResidualIndex = kResidualColumn / Shape::kColumn -+ * (AccumulatorShape::kRow / Shape::kRow); -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = Array; -+ -+ /// Accumulator Fragment object -+ using AccumulatorFragment = Array; -+ -+ /// Scale Bias Element Type -+ using ElementScaleBias = typename OutputOp::ElementCompute; -+ -+ /// Scale Bias Fragment object -+ using ScaleBiasFragment = Array; -+ -+ -+private: -+ -+ /// Internal access type -+ using AccessType = Array; -+ using FragmentAccessType = Array; -+ -+ using ScaleBiasAccessType = Array; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Accumulator tile -+ AccessType const *accumulators_; -+ -+ /// Internal index -+ int index_; -+ -+ /// Used to access residual tile first -+ bool is_residual_tile_; -+ -+public: -+ /// Constructs an iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpFragmentIterator(AccumulatorFragment const &accum) -+ : accumulators_(reinterpret_cast(&accum)), -+ index_(0), is_residual_tile_(true) {} -+ -+ /// Add offset -+ CUTLASS_HOST_DEVICE -+ void add_offset(int index_offset) { -+ index_ += index_offset; -+ if(is_residual_tile_ && index_ >= kKBlockColumnIterations) { -+ index_ = index_ - kKBlockColumnIterations + kResidualIndex; -+ is_residual_tile_ = false; -+ } -+ } -+ -+ /// Increments -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpFragmentIterator &operator++() { -+ add_offset(1); -+ return *this; -+ } -+ -+ /// Decrements -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpFragmentIterator &operator--() { -+ add_offset(-1); -+ return *this; -+ } -+ -+ /// Loads a fragment from the referenced part of the accumulator tile -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag, OutputOp output_op) const { -+ -+ if (output_op.is_source_needed()) //beta must be zero -+ assert(0); -+ -+ FragmentAccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ int index = index_ * MmaIterations::kCount; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; n++) { -+ for (int m = 0; m < MmaIterations::kRow; m++) { -+ int accumulator_access_offset = -+ n * AccumulatorIterations::kRow + m + index; -+ -+ frag_ptr[m * MmaIterations::kColumn + n].clear(); -+ if(!(is_residual_tile_ && index_ >= kResidualIndex)) -+ frag_ptr[m * MmaIterations::kColumn + n] = output_op(accumulators_[accumulator_access_offset]); -+ } -+ } -+ } -+ -+ /// Loads a fragment from the referenced part of the accumulator tile -+ /// Then apply per-channel scale and bias -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag, ScaleBiasFragment &scale, -+ ScaleBiasFragment &bias, OutputOp output_op) const { -+ -+ if (output_op.is_source_needed()) //beta must be zero -+ assert(0); -+ -+ FragmentAccessType *frag_ptr = reinterpret_cast(&frag); -+ ScaleBiasAccessType * scale_ptr = reinterpret_cast(&scale); -+ ScaleBiasAccessType * bias_ptr = reinterpret_cast(&bias); -+ -+ int index = index_ * MmaIterations::kCount; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; n++) { -+ for (int m = 0; m < MmaIterations::kRow; m++) { -+ int accumulator_access_offset = -+ n * AccumulatorIterations::kRow + m + index; -+ -+ frag_ptr[m * MmaIterations::kColumn + n].clear(); -+ if(!(is_residual_tile_ && index_ >= kResidualIndex)) -+ frag_ptr[m * MmaIterations::kColumn + n] = -+ output_op(accumulators_[accumulator_access_offset], -+ scale_ptr[n] /*scale*/, bias_ptr[n] /*bias*/); -+ } -+ } -+ } -+ -+ -+ -+}; -+ -+// Partial specialization for row-major accumulator tile -+ -+template < -+ /// Shape of warp tile to load (concept: MatrixShape) -+ typename Shape_, -+ /// Shape of the warp accumulation tile (concept: MatrixShape) -+ typename AccumulatorShape_, -+ /// KBlocks columns to compute residual -+ int KBlocksColumn_, -+ /// Accumulator Element type -+ typename ElementAccumulator_, -+ /// Element type -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Output operation on fragment -+ typename OutputOp_> -+class MmaTensorOpFragmentIterator { -+ public: -+ -+ /// Shape of warp tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Shape of the warp accumulation tile (concept: MatrixShape) -+ using AccumulatorShape = AccumulatorShape_; -+ -+ /// KBlocks columns to compute residual -+ static int const kKBlockColumn = KBlocksColumn_; -+ -+ /// Accumulator Element type -+ using ElementAccumulator = ElementAccumulator_; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::RowMajor; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Output operation on fragment -+ using OutputOp = OutputOp_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ struct Policy { -+ static_assert( -+ !(Shape::kRow % InstructionShape::kM) && -+ !(Shape::kColumn % InstructionShape::kN), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ static_assert( -+ AccumulatorShape::kRow == Shape::kRow, -+ "Rows of Warp Accumulator must be the same as rows of warp"); -+ static_assert( -+ !(AccumulatorShape::kColumn % Shape::kColumn), -+ "Shape of Warp Accumulator must be divisible by warp shape."); -+ static_assert( -+ !(kKBlockColumn % Shape::kColumn), -+ "KBlock size must be divisible by warp shape."); -+ -+ /// Number of times this iterator can be incremented -+ static int const kIterations = AccumulatorShape::kCount / Shape::kCount; -+ }; -+ -+private: -+ -+ static int const kRowsPerIteration = 8; -+ static int const kColumnsPerIteration = 16; -+ static int const kElementsPerIteration = kRowsPerIteration * InstructionShape::kN / kThreads; -+ static int const kElementsPerAccess = kRowsPerIteration * kColumnsPerIteration / kThreads; -+ static int const kIterationsPerAccess = kElementsPerAccess / kElementsPerIteration; -+ -+ // Number of iterations per actual instruction -+ static int const kIterationsPerInstruction = InstructionShape::kM / kRowsPerIteration; -+ -+ static int const kAccessStride = kIterationsPerInstruction; -+ -+ /// Number of mma operations performed by a warp -+ using MmaIterations = MatrixShape; -+ /// Number of mma operations performed by the entire accumulator -+ using AccumulatorIterations = MatrixShape; -+ -+ /// Number of Accesses in a warp -+ using AccessIterations = MatrixShape; -+ -+ /// Number of K iterations -+ static int const kKBlockIterations = (AccumulatorShape::kColumn + kKBlockColumn - 1) / kKBlockColumn; -+ static int const kResidualColumn = AccumulatorShape::kColumn - (kKBlockIterations - 1) * kKBlockColumn; -+ static int const kKBlockColumnIterations = kKBlockColumn / Shape::kColumn; -+ static int const kResidualIndex = kResidualColumn / Shape::kColumn; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ /// This is the fragment size produced by one access of the iterator. -+ using Fragment = Array; -+ -+ /// Accumulator Fragment object -+ using AccumulatorFragment = Array; -+ -+ /// Scale Bias Element Type -+ using ElementScaleBias = typename OutputOp::ElementCompute; -+ -+ /// Scale Bias Fragment object -+ using ScaleBiasFragment = Array; -+ -+ -+private: -+ -+ /// Internal access type -+ using AccessType = Array; -+ using FragmentAccessType = Array; -+ using ScaleBiasAccessType = Array; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Accumulator tile -+ AccessType const *accumulators_; -+ -+ /// Internal index -+ int index_; -+ -+ /// Used to access residual tile first -+ bool is_residual_tile_; -+ -+public: -+ /// Constructs an iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpFragmentIterator(AccumulatorFragment const &accum) -+ : accumulators_(reinterpret_cast(&accum)), -+ index_(0), is_residual_tile_(true) {} -+ -+ /// Add offset -+ CUTLASS_HOST_DEVICE -+ void add_offset(int index_offset) { -+ index_ += index_offset; -+ if(is_residual_tile_ && index_ >= kKBlockColumnIterations) { -+ index_ = index_ - kKBlockColumnIterations + kResidualIndex; -+ is_residual_tile_ = false; -+ } -+ } -+ -+ /// Increments -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpFragmentIterator &operator++() { -+ add_offset(1); -+ return *this; -+ } -+ -+ /// Decrements -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpFragmentIterator &operator--() { -+ add_offset(-1); -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_index(int idx) { -+ index_ = idx; -+ } -+ -+ /// Loads a fragment from the referenced part of the accumulator tile -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag, OutputOp output_op) const { -+ -+ if (output_op.is_source_needed()) //beta must be zero -+ assert(0); -+ -+ FragmentAccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ int index = index_ * AccessIterations::kCount; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < AccessIterations::kCount; i++) { -+ -+ int accumulator_access_offset = index / AccessIterations::kCount * (MmaIterations::kColumn * kIterationsPerInstruction) + -+ (index % AccessIterations::kCount) / (AccessIterations::kColumn * kIterationsPerInstruction) * -+ AccumulatorIterations::kColumn * kIterationsPerInstruction + -+ (index % (AccessIterations::kColumn * kIterationsPerInstruction)) / kIterationsPerInstruction * -+ (kIterationsPerInstruction * kIterationsPerAccess) + -+ (index % kIterationsPerInstruction); -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < kIterationsPerAccess; j++) { -+ -+ frag_ptr[i*kIterationsPerAccess + j].clear(); -+ if(!(is_residual_tile_ && index_ >= kResidualIndex)) -+ frag_ptr[i*kIterationsPerAccess + j] = output_op(accumulators_[accumulator_access_offset + j * kAccessStride]); -+ } -+ index++; -+ } -+ } -+ -+ /// Loads a fragment from the referenced part of the accumulator tile -+ /// Then apply per-channel scale and bias -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag, ScaleBiasFragment &scale, -+ ScaleBiasFragment & bias, OutputOp output_op) const { -+ -+ if (output_op.is_source_needed()) //beta must be zero -+ assert(0); -+ -+ FragmentAccessType *frag_ptr = reinterpret_cast(&frag); -+ ScaleBiasAccessType * scale_ptr = reinterpret_cast(&scale); -+ ScaleBiasAccessType * bias_ptr = reinterpret_cast(&bias); -+ -+ int index = index_ * AccessIterations::kCount; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < AccessIterations::kCount; i++) { -+ -+ int accumulator_access_offset = index / AccessIterations::kCount * (MmaIterations::kColumn * kIterationsPerInstruction) + -+ (index % AccessIterations::kCount) / (AccessIterations::kColumn * kIterationsPerInstruction) * -+ AccumulatorIterations::kColumn * kIterationsPerInstruction + -+ (index % (AccessIterations::kColumn * kIterationsPerInstruction)) / kIterationsPerInstruction * -+ (kIterationsPerInstruction * kIterationsPerAccess) + -+ (index % kIterationsPerInstruction); -+ -+ int scale_bias_offset = (index -+ % (kIterationsPerInstruction * AccessIterations::kColumn)) -+ * kIterationsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < kIterationsPerAccess; j++) { -+ -+ -+ frag_ptr[i*kIterationsPerAccess + j].clear(); -+ if(!(is_residual_tile_ && index_ >= kResidualIndex)) -+ frag_ptr[i*kIterationsPerAccess + j] = output_op( -+ accumulators_[accumulator_access_offset + j * kAccessStride], -+ scale_ptr[scale_bias_offset + j], bias_ptr[scale_bias_offset + j]); -+ } -+ index++; -+ } -+ } -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_policy.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_policy.h -new file mode 100644 -index 0000000..f73ede6 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_policy.h -@@ -0,0 +1,65 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Policy describing implementation details of warp-level GEMM targeting Tensor Cores. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/gemm/gemm.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Policy -+template < -+ typename Operator_, ///< hardware instruction(s) performing TensorOp (concept: arch::Mma) -+ typename OpDelta_ ///< distance between operations (concept: MatrixShape) -+> -+struct MmaTensorOpPolicy { -+ -+ using Operator = Operator_; ///< hardware instruction(s) performing TensorOp (concept: arch::Mma) -+ using OpDelta = OpDelta_; ///< distance between operations (concept: MatrixShape) -+ using MmaShape = typename Operator::Shape; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_sm70.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_sm70.h -new file mode 100644 -index 0000000..0a2449d ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_sm70.h -@@ -0,0 +1,280 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing warp-level matrix multiply-accumulate operations targeting -+ Tensor Cores. -+ -+ This is a work in progress. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/arch/mma.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_policy.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Data type of A elements -+ typename ElementA_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename ElementB_, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename ElementC_, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ typename Policy_, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+class MmaVoltaTensorOp { -+public: -+ /// Shape of warp-level matrix operation (concept: GemmShape) -+ using Shape = Shape_; -+ -+ /// Data type of multiplicand A -+ using ElementA = ElementA_; -+ -+ /// Layout of multiplicand A -+ using LayoutA = LayoutA_; -+ -+ /// Data type of multiplicand B -+ using ElementB = ElementB_; -+ -+ /// Layout of multiplicand B -+ using LayoutB = LayoutB_; -+ -+ /// Data type of accumulator matrix C -+ using ElementC = ElementC_; -+ -+ /// Layout of accumulator matrix C -+ using LayoutC = LayoutC_; -+ -+ /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) -+ using Policy = Policy_; -+ -+ /// Indicates class of matrix operator -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Architecture tag -+ using ArchTag = arch::Sm70; -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ using ArchMmaOperator = typename Policy::Operator; -+ -+ /// Indicates math operator -+ using MathOperator = typename ArchMmaOperator::Operator; -+ -+ /// Underlying instruction shape -+ using InstructionShape = typename ArchMmaOperator::Shape; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = ComplexTransform::kNone; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = ComplexTransform::kNone; -+ -+ /// Number of threads participating in warp-level matrix product -+ static int const kThreadCount = 32; -+ -+ /// interleaved 32x32 tiles -+ using InterleavedTileShape = GemmShape<32, 32, 4>; -+ -+ static_assert(!(Shape::kM % InterleavedTileShape::kM) && -+ !(Shape::kN % InterleavedTileShape::kN), -+ "Shape must be a multiple of InterleavedTileShape."); -+public: -+ -+ /// Iterates over the A operand in memory -+ using IteratorA = MmaVoltaTensorOpMultiplicandTileIterator< -+ MatrixShape, -+ Operand::kA, -+ ElementA, -+ LayoutA, -+ MatrixShape< -+ ArchMmaOperator::Shape::kM, -+ ArchMmaOperator::Shape::kK -+ >, -+ Policy::OpDelta::kRow, -+ kThreadCount -+ >; -+ -+ /// Storage for A tile -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Iterates over the B operand in memory -+ using IteratorB = MmaVoltaTensorOpMultiplicandTileIterator< -+ MatrixShape, -+ Operand::kB, -+ ElementB, -+ LayoutB, -+ MatrixShape< -+ ArchMmaOperator::Shape::kK, -+ ArchMmaOperator::Shape::kN -+ >, -+ Policy::OpDelta::kRow, -+ kThreadCount -+ >; -+ -+ /// Storage for B tile -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Iterates over the C operand in memory -+ using IteratorC = MmaVoltaTensorOpAccumulatorTileIterator< -+ MatrixShape, -+ ElementC, -+ LayoutC, -+ typename ArchMmaOperator::Shape, -+ typename Policy::OpDelta -+ >; -+ -+ /// Storage for C tile -+ using FragmentC = typename IteratorC::Fragment; -+ -+private: -+ -+ static_assert( -+ !(Shape::kM % ArchMmaOperator::Shape::kM) && -+ !(Shape::kN % ArchMmaOperator::Shape::kN), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ -+ /// Number of mma operations performed -+ using MmaIterations = MatrixShape< -+ InterleavedTileShape::kM / ArchMmaOperator::Shape::kM, -+ InterleavedTileShape::kN / ArchMmaOperator::Shape::kN -+ >; -+ using TileIterations = MatrixShape< -+ Shape::kM / InterleavedTileShape::kM, -+ Shape::kN / InterleavedTileShape::kN -+ >; -+ -+ // Whether matrix B is reordered -+ bool reorder_B_; -+ -+public: -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ ArchMmaOperator mma; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOp() {} -+ -+ /// Performs a warp-level matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &D, -+ FragmentA const &A, -+ FragmentB const &B, -+ FragmentC const &C) { -+ -+ using MmaOperandA = typename ArchMmaOperator::FragmentA; -+ using MmaOperandB = typename ArchMmaOperator::FragmentB; -+ using MmaOperandC = typename ArchMmaOperator::FragmentC; -+ -+ D = C; -+ -+ MmaOperandA const *ptr_A = reinterpret_cast(&A); -+ MmaOperandB const *ptr_B = reinterpret_cast(&B); -+ MmaOperandC *ptr_D = reinterpret_cast(&D); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int outer_col = 0; outer_col < TileIterations::kColumn; ++outer_col) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int inner_col = 0; inner_col < MmaIterations::kColumn; ++inner_col) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int outer_row = 0; outer_row < TileIterations::kRow; ++outer_row) { -+ CUTLASS_PRAGMA_UNROLL -+ -+ for (int inner_row = 0; inner_row < MmaIterations::kRow; ++inner_row) { -+ -+ int op_col = inner_col + MmaIterations::kColumn * outer_col; -+ -+ // Column-major serpentine sequence to maximize reuse of A operand. -+ int inner_row_serp = inner_row; -+ int outer_row_serp = outer_row; -+ if (op_col & 1) { -+ inner_row_serp = MmaIterations::kRow - inner_row - 1; -+ outer_row_serp = TileIterations::kRow - outer_row - 1; -+ } -+ int op_row = inner_row_serp + MmaIterations::kRow * outer_row_serp; -+ int op_idx = inner_row_serp + MmaIterations::kRow * -+ (inner_col + MmaIterations::kColumn * -+ (outer_row_serp + TileIterations::kRow * outer_col)); -+ mma( -+ ptr_D[op_idx], -+ ptr_A[op_row], -+ ptr_B[op_col], -+ ptr_D[op_idx]); -+ -+ } -+ } -+ } -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h -new file mode 100644 -index 0000000..5e4de60 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h -@@ -0,0 +1,362 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm80.h" -+ -+#include "cutlass/platform/platform.h" -+#include "cutlass/fast_math.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+ -+/// Tile access iterator -+/// Each iteration acess in the tile is -+/// used as multiplicand for one -+/// warp-level matrix multiplication -+template < -+ /// Size of the tile (concept: MatrixShape) -+ typename Shape_, -+ /// Operand identity -+ Operand Operand_, -+ /// Data type of A elements -+ typename Element_, -+ /// Layout of operand -+ typename Layout_, -+ /// Shape of one matrix production operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Delta between *MMA operations (in units of *MMA operations, concept: -+ /// MatrixShape) -+ int OpDelta_, -+ /// Number of threads participating in one matrix operation -+ int Threads = 32, -+ /// Enable Residual Support -+ bool EnableResidual = false, -+ /// Number of partitions along K dimension -+ int PartitionsK_ = 1 -+> -+class MmaTensorOpMultiplicandTileAccessIterator { -+ public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ /// Basic check -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = Layout_; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Number of elements accessed per Shared Memory load -+ static int const kElementsPerAccess = -+ (sizeof_bits::value >= 32 ? 1 : 32 / sizeof_bits::value); -+ -+ using InstructionCount = MatrixShape< -+ Shape::kRow / InstructionShape::kRow, -+ Shape::kColumn / InstructionShape::kColumn -+ >; -+ -+ static int const kIterations = (kOperand == Operand::kA) ? -+ InstructionCount::kColumn : InstructionCount::kRow; -+ -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array< -+ Element, -+ (kOperand == Operand::kA) ? -+ (Shape::kRow * InstructionShape::kColumn / kThreads) : -+ (Shape::kColumn * InstructionShape::kRow / kThreads) -+ >; -+ -+ /// Memory access type -+ using AccessType = AlignedArray; -+ -+private: -+ -+ /// Underlying tensor reference -+ TensorRef ref_; -+ -+ /// Extent of tensor -+ MatrixCoord extent_; -+ -+ /// Origin -+ MatrixCoord origin_; -+ -+ /// Used to load residual tile -+ bool is_residual_; -+ -+ /// residual offset of each thread -+ TensorCoord residual_offset_; -+ -+ /// Iterations in a tile -+ int iterations_; -+ -+public: -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileAccessIterator( -+ TensorRef const &ref, -+ TensorCoord extent, -+ int lane_id -+ ): ref_(ref), extent_(extent), is_residual_(false), iterations_(0) { -+ -+ if (kOperand == Operand::kA) { -+ origin_ = MatrixCoord(lane_id / 4, (lane_id % 4) * kElementsPerAccess); -+ } -+ else { -+ origin_ = MatrixCoord((lane_id % 4) * kElementsPerAccess, lane_id / 4); -+ } -+ -+ ref_.add_coord_offset(origin_); -+ -+ if(EnableResidual) { -+ // compute residual offset -+ if (kOperand == Operand::kA) { -+ typename TensorCoord::Index residual_size = -+ extent_.column() % Shape::kColumn; -+ if(residual_size) { -+ is_residual_ = true; -+ residual_offset_ = make_Coord(0, residual_size); -+ } -+ } -+ else { -+ typename TensorCoord::Index residual_size = -+ extent_.row() % Shape::kRow; -+ if(residual_size) { -+ is_residual_ = true; -+ residual_offset_ = make_Coord(residual_size, 0); -+ } -+ } -+ } -+ } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileAccessIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): MmaTensorOpMultiplicandTileAccessIterator(ref, -+ {Shape::kRow, Shape::kColumn}, lane_id) { -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileAccessIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); -+ origin_ += coord_offset; -+ -+ ref_.add_coord_offset(coord_offset); -+ -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ void advance() { -+ -+ if(EnableResidual && is_residual_) { -+ is_residual_ = false; -+ -+ origin_ += residual_offset_; -+ ref_.add_coord_offset(residual_offset_); -+ -+ } -+ -+ else { -+ if (kOperand == Operand::kA) { -+ add_tile_offset({0, 1}); -+ } -+ else { -+ add_tile_offset({1, 0}); -+ } -+ } -+ -+ iterations_ = 0; -+ } -+ -+ /// increase iterations in a tile -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileAccessIterator & operator++() { -+ -+ iterations_++; -+ -+ if(iterations_ >= kIterations) -+ advance(); -+ -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ int const kWarpShapeDivisibleInner = -+ (kOperand == Operand::kA ? InstructionShape::kColumn : InstructionShape::kRow); -+ -+ // Take advantage of Tensor Op's 8 x 4T access pattern -+ int const kAccessesInner = (kWarpShapeDivisibleInner / kElementsPerAccess) / 4; -+ -+ AccessType *access_ptr = reinterpret_cast(&frag); -+ -+ if (kOperand == Operand::kA) { -+ int const kTilesPerInstruction = InstructionShape::kRow / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int inst_m_idx = 0; inst_m_idx < InstructionCount::kRow; ++inst_m_idx) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int access_m_idx = 0; access_m_idx < kTilesPerInstruction; ++access_m_idx) { -+ int access_idx = -+ access_m_idx + kTilesPerInstruction * (inner_idx + kAccessesInner * inst_m_idx); -+ -+ MatrixCoord offset( -+ access_m_idx * 8 + inst_m_idx * InstructionShape::kRow, -+ inner_idx * 4 * kElementsPerAccess + iterations_ * InstructionShape::kColumn); -+ -+ MatrixCoord access_coord = origin_ + offset; -+ -+// if(access_coord.row() < extent_.row() && access_coord.column() < extent_.column()) { -+ -+ access_ptr[access_idx] = *reinterpret_cast( -+ ref_.data() + ref_.offset(offset)); -+// } -+// else { -+// AccessType zero; -+// zero.clear(); -+// access_ptr[access_idx] = zero; -+// } -+ } -+ } -+ } -+ } -+ else { -+ CUTLASS_PRAGMA_UNROLL -+ for (int inst_n_idx = 0; inst_n_idx < InstructionCount::kColumn; ++inst_n_idx) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) { -+ int access_idx = inner_idx + kAccessesInner * inst_n_idx; -+ -+ MatrixCoord offset( -+ inner_idx * 4 * kElementsPerAccess + iterations_ * InstructionShape::kRow, -+ inst_n_idx * 8); -+ -+ MatrixCoord access_coord = origin_ + offset; -+ -+// if(access_coord.row() < extent_.row() && access_coord.column() < extent_.column()) { -+ -+ access_ptr[access_idx] = *reinterpret_cast( -+ ref_.data() + ref_.offset(offset)); -+// } -+// else { -+// AccessType zero; -+// zero.clear(); -+// access_ptr[access_idx] = zero; -+// } -+ } -+ } -+ } -+ } -+ -+}; -+ -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h -new file mode 100644 -index 0000000..54f194f ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h -@@ -0,0 +1,3982 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -+ -+#include "cutlass/platform/platform.h" -+#include "cutlass/fast_math.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Operand identity -+ Operand Operand, -+ /// Data type of A elements -+ typename Element_, -+ /// Layout of operand -+ typename Layout_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Delta between *MMA operations (in units of *MMA operations, concept: -+ /// MatrixShape) -+ int OpDelta_, -+ /// Number of threads participating in one matrix operation -+ int Threads, -+ /// Number of partitions along K dimension -+ int PartitionsK_ = 1> -+class MmaTensorOpMultiplicandTileIterator; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to load from shared -+/// memory and therefore must be initialized with a TensorRef to shared memory. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: PitchLinearShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: PitchLinearShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::TensorOpMultiplicandCongruous::value, -+ 64>, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::TensorOpMultiplicandCongruous< -+ sizeof_bits::value, 64>; -+ -+ /// Shape of one matrix product operation (concept: GemmShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// Number of partitions along K dimension -+ static int const kPartitionsK = PartitionsK_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Long Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ struct Policy { -+ static_assert( -+ !(Shape::kContiguous % InstructionShape::kContiguous), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ -+ // Determine number of elements along outer dimension per individual LDSM op -+ static int const kLdsmOpOuter = Layout::kElementsPerAccess; -+ static int const kLdsmOpInner = 8; -+ -+ static_assert(!(Shape::kContiguous % kLdsmOpOuter), -+ "Shape of warp-level mma must be divisible by LDSM's fundamental tile size."); -+ -+ static_assert(!(Shape::kStrided % kLdsmOpInner), -+ "Shape of warp-level mma must be divisible by LDSM's fundamental tile size."); -+ -+ /// Shape of one individual LDSM instruction -+ static int const LdsmShapeStrided = -+ InstructionShape::kStrided / kLdsmOpInner; -+ static int const LdsmShapeContiguous = 4 / LdsmShapeStrided; -+ using LdsmShape = -+ layout::PitchLinearShape; -+ -+ /// Number and arrangement of LDSM instructions -+ using LdsmIterations = layout::PitchLinearShape< -+ Shape::kContiguous / Layout::kElementsPerAccess / LdsmShapeContiguous, -+ 1>; -+ -+ /// Number of groups for each tile -+ static int const kGroupsPerTile = -+ Shape::kStrided / InstructionShape::kStrided; -+ }; -+ -+private: -+ -+ /// Not working on this feature at the moment. -+ static_assert(kOpDelta == 1, -+ "Alternative arrangements not supported at present."); -+ -+ /// Number of internal pointers needed to reference shared memory -+ static int const kPointerCount = -+ Layout::TileShape::kContiguous / Policy::LdsmShape::kContiguous; -+ -+ /// Pointer type used for accesses -+ using AccessType = Array; -+ -+ /// Internal counter used to jump to next K partition -+ int k_group_idx_; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = -+ Array; -+ -+private: -+ -+ /// Layout object storing stride values -+ StrideIndex stride_; -+ -+ /// Shared memory base pointers - not advanced -+ AccessType const *pointer_[kPointerCount]; -+ -+ /// Byte offset incremented as iterator advances -+ Index byte_offset_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator(): stride_(0), byte_offset_(0) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ stride_(ref.stride(0) / Layout::kElementsPerAccess), byte_offset_(0), -+ k_group_idx_(0) { -+ -+ int quad_pair = (lane_id >> 3); -+ int quad_quad = (lane_id >> 4); -+ int lane_in_quad = (lane_id & 3); -+ int lane_in_quad_pair = (lane_id & 7); -+ int lane_in_quad_quad = (lane_id & 15); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPointerCount; ++i) { -+ int partition_contiguous_idx = -1; -+ int access_contiguous_idx = -1; -+ int access_strided_idx = -1; -+ -+ if (Policy::LdsmShape::kContiguous == 4) { -+ // Matrix multiply 1688 A/B -+ // Q0 Q1 Q2 Q3 (Q stands for 1 8x128bit block). -+ // Four blocks are next to each other in the contiguous dimension. -+ partition_contiguous_idx = ((lane_in_quad_pair >> 2) ^ i); -+ access_contiguous_idx = (quad_pair ^ lane_in_quad); -+ access_strided_idx = lane_in_quad_pair; -+ } -+ else if (Policy::LdsmShape::kContiguous == 2 && -+ kOperand == Operand::kA) { -+ // Matrix multiply 16816 A -+ // Q0 Q1 -+ // Q2 Q3 -+ partition_contiguous_idx = ((lane_in_quad_pair >> 2) ^ (i >> 1)); -+ access_contiguous_idx = -+ (((quad_pair & 1) + ((i & 1) << 1)) ^ lane_in_quad); -+ access_strided_idx = lane_in_quad_pair + (lane_id >> 4 << 3); -+ } else if (Policy::LdsmShape::kContiguous == 2 && -+ kOperand == Operand::kB) { -+ // Matrix multiply 16816 B -+ // Q0 Q2 -+ // Q1 Q3 -+ partition_contiguous_idx = ((lane_in_quad_pair >> 2) ^ (i >> 1)); -+ access_contiguous_idx = ((quad_quad + ((i & 1) << 1)) ^ lane_in_quad); -+ access_strided_idx = lane_in_quad_quad; -+ } else if (Policy::LdsmShape::kContiguous == 1) { -+ // Matrix multiply 16832.SP B -+ // Q0 -+ // Q1 -+ // Q2 -+ // Q3 -+ partition_contiguous_idx = ((lane_in_quad_pair >> 2) ^ (i >> 2)); -+ access_contiguous_idx = ((i & 3) ^ lane_in_quad); -+ access_strided_idx = lane_id; -+ } -+ -+ int access_contiguous = -+ partition_contiguous_idx * Layout::PartitionShape::kContiguous + -+ access_contiguous_idx; -+ -+ int access_strided = access_strided_idx; -+ -+ pointer_[i] = reinterpret_cast(ref.data()) + -+ access_contiguous + access_strided * stride_; -+ } -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ byte_offset_ += offset * sizeof(Element); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ int contiguous_offset = tile_offset.contiguous(); -+ if (Shape::kContiguous == -+ Layout::PartitionShape::kContiguous * Layout::kElementsPerAccess) { -+ if (tile_offset.contiguous() % 2) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPointerCount / 2; ++i) { -+ AccessType const *tmp_pointer = pointer_[i]; -+ pointer_[i] = pointer_[i + kPointerCount / 2]; -+ pointer_[i + kPointerCount / 2] = tmp_pointer; -+ } -+ } -+ contiguous_offset = (tile_offset.contiguous() >> 1) << 1; -+ } -+ -+ int offset = (tile_offset.strided() * InstructionShape::kStrided) * -+ stride_ * Layout::kElementsPerAccess + -+ contiguous_offset * Shape::kContiguous; -+ -+ add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator++() { -+ -+ add_tile_offset({0, 1}); -+ -+ if (kPartitionsK > 1) { -+ ++k_group_idx_; -+ // Jump to next stage -+ if (k_group_idx_ == Policy::kGroupsPerTile) { -+ k_group_idx_ = 0; -+ add_tile_offset( -+ {0, ((kPartitionsK - 1) * Policy::kGroupsPerTile)}); -+ } -+ } -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the opposite of the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator--() { -+ byte_offset_ -= stride_ * InstructionShape::kStrided * sizeof(Element) * -+ Layout::kElementsPerAccess; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ load_with_byte_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset in units of bytes -+ Index byte_offset) const { -+ -+ Array *fetch_ptr = -+ reinterpret_cast *>(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < Policy::LdsmIterations::kStrided; ++s) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < Policy::LdsmIterations::kContiguous; ++c) { -+ -+ int access_idx = c + s * Policy::LdsmIterations::kContiguous; -+ -+ AccessType const *source_ptr = -+ pointer_[c % kPointerCount] + -+ Layout::TileShape::kContiguous * (c / kPointerCount) + -+ Policy::kLdsmOpInner * Policy::LdsmShape::kStrided * s * stride_; -+ -+ char const *source_byte_ptr = reinterpret_cast(source_ptr) + byte_offset + byte_offset_; -+ -+ cutlass::arch::ldsm( -+ fetch_ptr[access_idx], -+ source_byte_ptr -+ ); -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ load_with_byte_offset(frag, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ load_with_byte_offset(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ Index pointer_offset = -+ tile_offset.contiguous() * Shape::kContiguous / Layout::kElementsPerAccess + -+ tile_offset.strided() * InstructionShape::kStrided * stride_; -+ -+ byte_offset += sizeof(AccessType) * pointer_offset; -+ -+ load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ // no op -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread MMA.TF32 NT TensorOps. It -+/// uses LDS.32 to load from shared memory and therefore must be initialized -+/// with a TensorRef to shared memory. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: PitchLinearShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: PitchLinearShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::TensorOpMultiplicandCongruous<32, 32>, InstructionShape_, -+ OpDelta_, 32, PartitionsK_> { -+ public: -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand == Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for " -+ "A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::TensorOpMultiplicandCongruous<32, 32>; -+ -+ /// Shape of one matrix product operation (concept: GemmShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: -+ /// MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// Number of partitions along K dimension -+ static int const kPartitionsK = PartitionsK_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Long Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ struct Policy { -+ static_assert( -+ !(Shape::kContiguous % InstructionShape::kContiguous), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ -+ // Determine number of elements along outer dimension per individual 32bit -+ // shared memory load op. Every one warp of 32bit shared memory load loads -+ // 8x4 elements -+ static int const kLdsOpInner = Layout::TileShape::kStrided; -+ static int const kLdsOpOuter = kThreads / kLdsOpInner; -+ -+ static_assert(!(Shape::kContiguous % kLdsOpOuter), -+ "Shape of warp-level mma must be divisible by 32bit " -+ "fundamental tile size."); -+ -+ static_assert(!(Shape::kStrided % kLdsOpInner), -+ "Shape of warp-level mma must be divisible by 32bit " -+ "fundamental tile size."); -+ -+ /// Number of 32 bit shared memory load instructions needed by one MMA instruction -+ /// 1688 A 2x2 -+ /// 1688 B 1x2 -+ /// 16816 B 1x4 -+ static int const LdsShapeContiguous = -+ InstructionShape::kContiguous / kLdsOpOuter; -+ static int const LdsShapeStrided = InstructionShape::kStrided / kLdsOpInner; -+ using LdsShape = -+ layout::PitchLinearShape; -+ -+ /// Number and arrangement of LDS instructions -+ using LdsIterations = layout::PitchLinearShape< -+ Shape::kContiguous / LdsShapeContiguous / kLdsOpOuter, 1>; -+ -+ /// Number of groups for each tile -+ static int const kGroupsPerTile = -+ Shape::kStrided / InstructionShape::kStrided; -+ }; -+ -+ private: -+ /// Not working on this feature at the moment. -+ static_assert(kOpDelta == 1, -+ "Alternative arrangements not supported at present."); -+ -+ /// Number of internal pointers needed to reference shared memory -+ static int const kPointerCount = Layout::TileShape::kContiguous * -+ Layout::kElementsPerAccess / -+ Policy::kLdsOpOuter; -+ -+ /// Vectorized access is not used -+ static int const kElementsPerAccess = 1; -+ -+ /// Pointer type used for accesses -+ using AccessType = Element; -+ -+ /// Internal counter used to jump to next K partition -+ int k_group_idx_; -+ -+ public: -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = -+ Array; -+ -+ private: -+ /// Layout object storing stride values -+ StrideIndex stride_; -+ -+ /// Shared memory base pointers - not advanced -+ AccessType const *pointer_[kPointerCount]; -+ -+ /// Byte offset incremented as iterator advances -+ Index byte_offset_; -+ -+ public: -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator() : stride_(0), byte_offset_(0) {} -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator(TensorRef const &ref, int lane_id) -+ : stride_(ref.stride(0)), byte_offset_(0), k_group_idx_(0) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPointerCount; ++i) { -+ int access_strided = lane_id % Policy::kLdsOpInner; -+ int access_contiguous = (lane_id / Policy::kLdsOpInner) + -+ (access_strided ^ i) * Policy::kLdsOpOuter; -+ -+ pointer_[i] = reinterpret_cast(ref.data()) + -+ access_contiguous + access_strided * stride_; -+ } -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ byte_offset_ += offset * sizeof(Element); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset( -+ TensorCoord const &tile_offset) { -+ int contiguous_offset = tile_offset.contiguous(); -+ if (Shape::kContiguous == -+ Layout::TileShape::kContiguous * Layout::kElementsPerAccess / 2) { -+ if (tile_offset.contiguous() % 2) { -+ // Matrix multiply 1688 pointer_[0] <=> pointer_[4] pointer_[1] <=> pointer_[5] -+ // pointer_[2] <=> pointer_[6] pointer_[3] <=> pointer_[7] -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPointerCount / 2; ++i) { -+ AccessType const *tmp_pointer = pointer_[i]; -+ pointer_[i] = pointer_[i + kPointerCount / 2]; -+ pointer_[i + kPointerCount / 2] = tmp_pointer; -+ } -+ } -+ contiguous_offset = (tile_offset.contiguous() >> 1) << 1; -+ } -+ -+ int offset = (tile_offset.strided() * InstructionShape::kStrided) * stride_ + -+ contiguous_offset * Shape::kContiguous; -+ -+ add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &operator++() { -+ add_tile_offset({0, 1}); -+ -+ if (kPartitionsK > 1) { -+ ++k_group_idx_; -+ // Jump to next stage -+ if (k_group_idx_ == Policy::kGroupsPerTile) { -+ k_group_idx_ = 0; -+ add_tile_offset( -+ {0, ((kPartitionsK - 1) * Policy::kGroupsPerTile)}); -+ } -+ } -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the opposite of the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &operator--() { -+ byte_offset_ -= stride_ * InstructionShape::kStrided * sizeof(Element) * -+ kElementsPerAccess; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &operator+=( -+ TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &operator-=( -+ TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { load_with_byte_offset(frag, 0); } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset in units of bytes -+ Index byte_offset) const { -+ Element *fetch_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < Policy::LdsIterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < Policy::LdsIterations::kContiguous; ++c) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int ss = 0; ss < Policy::LdsShape::kStrided; ++ss) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int cc = 0; cc < Policy::LdsShape::kContiguous; ++cc) { -+ int access_idx = -+ cc + (ss + (c + s * Policy::LdsIterations::kContiguous) * -+ Policy::LdsShape::kStrided) * -+ Policy::LdsShape::kContiguous; -+ int access_idx_contiguous = cc + c * Policy::LdsShape::kContiguous; -+ int access_idx_strided = -+ (ss + s * Policy::LdsShape::kStrided) * Policy::kLdsOpInner; -+ -+ AccessType const *source_ptr = -+ pointer_[access_idx_contiguous % kPointerCount] + -+ Layout::TileShape::kContiguous * Layout::kElementsPerAccess * -+ (access_idx_contiguous / kPointerCount) + -+ access_idx_strided * stride_; -+ -+ char const *source_byte_ptr = -+ reinterpret_cast(source_ptr) + byte_offset + -+ byte_offset_; -+ -+ fetch_ptr[access_idx] = -+ *reinterpret_cast(source_byte_ptr); -+ } -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ load_with_byte_offset(frag, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ load_with_byte_offset(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ Index pointer_offset = -+ tile_offset.contiguous() * Shape::kContiguous / -+ Layout::kElementsPerAccess + -+ tile_offset.strided() * InstructionShape::kStrided * stride_; -+ -+ byte_offset += sizeof(AccessType) * pointer_offset; -+ -+ load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ // no op -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to load from shared -+/// memory and therefore must be initialized with a TensorRef to shared memory. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(Element_))>, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA, -+ "MmaTensorOpMultiplicandIterator for ColumnMajor Congruous may " -+ "only be instantiated for A operand to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(Element_))>; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Long Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Underlying tile iterator implementation -+ using Base = MmaTensorOpMultiplicandTileIterator< -+ layout::PitchLinearShape, kOperand, Element, -+ layout::TensorOpMultiplicandCongruous::value, -+ int(128 / sizeof(Element_))>, -+ layout::PitchLinearShape, -+ kOpDelta, kThreads, PartitionsK_>; -+ -+ public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+private: -+ -+ /// Underlying tile iterator -+ Base iterator_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): iterator_({ref.data(), ref.stride()}, lane_id) { -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator++() { -+ -+ ++iterator_; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator--() { -+ -+ --iterator_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(PitchLinearCoord(tile_offset.row(), tile_offset.column())); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-PitchLinearCoord(tile_offset.row(), tile_offset.column())); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ iterator_.load(frag); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset( -+ frag, -+ {tile_offset.contiguous(), tile_offset.strided()}, -+ byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to load from shared -+/// memory and therefore must be initialized with a TensorRef to shared memory. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(Element_))>, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kB, -+ "MmaTensorOpMultiplicandIterator for RowMajor Congruous may " -+ "only be instantiated for B operand to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(Element_))>; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Underlying tile iterator implementation -+ using Base = MmaTensorOpMultiplicandTileIterator< -+ layout::PitchLinearShape, kOperand, Element, -+ layout::TensorOpMultiplicandCongruous::value, -+ int(128 / sizeof(Element_))>, -+ layout::PitchLinearShape, -+ kOpDelta, kThreads, PartitionsK_>; -+ -+ public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+private: -+ -+ /// Underlying tile iterator -+ Base iterator_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): iterator_({ref.data(), ref.stride()}, lane_id) { -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator++() { -+ -+ ++iterator_; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator--() { -+ -+ --iterator_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(PitchLinearCoord(tile_offset.column(), tile_offset.row())); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-PitchLinearCoord(tile_offset.column(), tile_offset.row())); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ iterator_.load(frag); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset( -+ frag, -+ {tile_offset.strided(), tile_offset.contiguous()}, -+ byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to -+/// load from shared memory and therefore must be initialized with a TensorRef -+/// to shared memory. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: PitchLinearShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: PitchLinearShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Element number when the layout crosses (in units of elements) -+ int Crosswise, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::TensorOpMultiplicandCrosswise::value, -+ Crosswise>, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand == Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for " -+ "A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Element number when the layout crosses -+ static int const kCrosswise = Crosswise; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::TensorOpMultiplicandCrosswise< -+ sizeof_bits::value, kCrosswise>; -+ -+ /// Shape of one matrix product operation (concept: GemmShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: -+ /// MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// Number of partitions along K dimension -+ static int const kPartitionsK = PartitionsK_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Long Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ struct Policy { -+ static_assert( -+ !(Shape::kContiguous % InstructionShape::kContiguous), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ -+ // Determine number of elements along outer dimension per individual LDSM op -+ static int const kLdsmOpOuter = Layout::kElementsPerAccess; -+ static int const kLdsmOpInner = 8; -+ -+ static_assert(!(Shape::kContiguous % kLdsmOpOuter), -+ "Shape of warp-level mma must be divisible by LDSM's " -+ "fundamental tile size."); -+ -+ static_assert(!(Shape::kStrided % kLdsmOpInner), -+ "Shape of warp-level mma must be divisible by LDSM's " -+ "fundamental tile size."); -+ -+ /// Shape of one individual LDSM instruction -+ static int const LdsmShapeContiguous = -+ InstructionShape::kContiguous / kLdsmOpOuter; -+ static int const LdsmShapeStrided = -+ ((4 / LdsmShapeContiguous * kLdsmOpInner) > Shape::kStrided) -+ ? (Shape::kStrided / kLdsmOpInner) -+ : (4 / LdsmShapeContiguous); -+ using LdsmShape = -+ layout::PitchLinearShape; -+ -+ /// Number and arrangement of LDSM instructions -+ using LdsmIterations = -+ layout::PitchLinearShape<1, Shape::kStrided / kLdsmOpInner / -+ LdsmShape::kStrided>; -+ -+ /// -+ static int const kGroupsPerTile = Layout::TileShape::kContiguous / -+ Layout::kFactor / LdsmShape::kContiguous; -+ }; -+ -+ private: -+ /// Not working on this feature at the moment. -+ static_assert(kOpDelta == 1, -+ "Alternative arrangements not supported at present."); -+ -+ /// Pointer type used for accesses -+ using AccessType = Array; -+ -+ public: -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array; -+ -+ private: -+ -+ /// Total number of sections. The memory is divided into stages. One stage -+ /// can store one tile. Stage is divided into sections. Interleaved layout -+ /// can have multiple sections in a stage. The rest layout only has one section -+ /// in a stage. -+ int sections_; -+ -+ /// Layout object storing stride values -+ StrideIndex stride_; -+ -+ /// Shared memory base pointers - not advanced -+ AccessType const *pointer_; -+ -+ /// Byte offset incremented as iterator advances -+ Index byte_offset_; -+ -+ /// Internal counter used to determine when to increment byte offset and when -+ /// to XOR it -+ int k_group_idx_; -+ -+ public: -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator() -+ : pointer_(nullptr), -+ sections_(0), -+ stride_(0), -+ byte_offset_(0), -+ k_group_idx_(0) {} -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator(TensorRef const &ref, int lane_id) -+ : pointer_(reinterpret_cast(ref.data())), -+ sections_(ref.stride(0) / kCrosswise), -+ // stride_ = kCrosswise x sections_ x kFactor -+ stride_(ref.stride(0) * Layout::kFactor / Layout::kElementsPerAccess), -+ byte_offset_(0), -+ k_group_idx_(0) { -+ // Warp level iterator at most use double buffer to hide latency. If there -+ // are more than 2 sections, every stage should have more than 1 section. -+ -+ // Turing silicon requires all 32 threads in a warp provide valid addresses -+ // even for LDSM.1 and LDSM.2 -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 750)) -+ lane_id = lane_id % (Policy::LdsmShape::kCount * Policy::kLdsmOpInner); -+#endif -+ -+ int quad_quad = (lane_id >> 4); -+ int quad_pair = (lane_id >> 3); -+ int lane_in_pair = (lane_id & 1); -+ int lane_in_quad = (lane_id & 3); -+ int lane_in_quad_pair = (lane_id & 7); -+ int lane_in_quad_quad = (lane_id & 15); -+ -+ int partition_contiguous_idx = -1; -+ int access_contiguous_idx = -1; -+ int access_strided_idx = -1; -+ -+ if (Layout::kFactor == 4) { -+ // Super Integer matrix multiply Interleaved-32 -+ -+ int factor_in_partition = -+ (Layout::PartitionShape::kContiguous * Layout::kFactor / -+ Layout::TileShape::kContiguous); -+ -+ if (Policy::LdsmShape::kStrided == Policy::LdsmShape::kCount) { -+ // Integer matrix multiply 8816 A/B -+ partition_contiguous_idx = lane_in_quad / factor_in_partition; -+ access_contiguous_idx = ((lane_in_pair * factor_in_partition) ^ -+ (lane_in_quad_quad / Layout::kFactor)); -+ access_strided_idx = lane_id / Layout::kFactor; -+ } -+ else if (Policy::LdsmShape::kStrided == -+ (Policy::LdsmShape::kCount / 2) && -+ kOperand == Operand::kA) { -+ // Integer matrix multiply 16832 A -+ partition_contiguous_idx = lane_in_quad / factor_in_partition; -+ access_strided_idx = lane_in_quad_quad / Layout::kFactor; -+ access_contiguous_idx = -+ ((lane_in_pair * factor_in_partition + quad_quad) ^ -+ access_strided_idx); -+ } -+ else if (Policy::LdsmShape::kStrided == -+ (Policy::LdsmShape::kCount / 2) && -+ kOperand == Operand::kB) { -+ // Integer matrix multiply 16832 B -+ partition_contiguous_idx = lane_in_quad / factor_in_partition; -+ access_strided_idx = lane_in_quad_pair / Layout::kFactor + quad_quad * 2; -+ access_contiguous_idx = -+ ((lane_in_pair * factor_in_partition + ((lane_id & 8) >> 3)) ^ -+ access_strided_idx); -+ } -+ } else if (Layout::kFactor == 2) { -+ // Super Matrix multiply kBlock = 32 -+ if (Policy::LdsmShape::kStrided == Policy::LdsmShape::kCount) { -+ // Matrix multiply 1688 A/B -+ // (Q stands for 1 8x128bit block). -+ // Q0 -+ // Q1 -+ // Q2 -+ // Q3 -+ // Four blocks are next to each other in the strided dimension. -+ partition_contiguous_idx = (lane_id % Layout::kFactor); -+ access_contiguous_idx = (lane_in_quad_pair / Layout::kFactor); -+ access_strided_idx = lane_id / Layout::kFactor; -+ } -+ else if (Policy::LdsmShape::kStrided == -+ (Policy::LdsmShape::kCount / 2) && -+ kOperand == Operand::kA) { -+ // Matrix multiply 16816|1688.TF32 A -+ // Q0 Q2 -+ // Q1 Q3 -+ partition_contiguous_idx = (lane_id % Layout::kFactor); -+ access_contiguous_idx = -+ (quad_quad ^ (lane_in_quad_pair / Layout::kFactor)); -+ access_strided_idx = (lane_in_quad_quad / Layout::kFactor); -+ } else if (Policy::LdsmShape::kStrided == -+ (Policy::LdsmShape::kCount / 2) && -+ kOperand == Operand::kB) { -+ // Matrix multiply 16816|1688.TF32 B -+ // Q0 Q1 -+ // Q2 Q3 -+ partition_contiguous_idx = (lane_id % Layout::kFactor); -+ access_contiguous_idx = -+ ((quad_pair & 1) ^ (lane_in_quad_pair / Layout::kFactor)); -+ access_strided_idx = -+ (lane_in_quad_pair + (lane_id >> 4 << 3)) / Layout::kFactor; -+ } -+ else if (Policy::LdsmShape::kContiguous == Policy::LdsmShape::kCount) { -+ // Matrix multiply 16832.SP B -+ // Q0 Q1 Q2 Q3 -+ partition_contiguous_idx = (lane_id % Layout::kFactor); -+ access_contiguous_idx = -+ (quad_pair ^ (lane_in_quad_pair / Layout::kFactor)); -+ access_strided_idx = lane_in_quad_pair / Layout::kFactor; -+ } -+ } else if (Layout::kFactor == 1) { -+ // Super Matrix multiply kBlock = 64 -+ if (Policy::LdsmShape::kStrided == Policy::LdsmShape::kCount) { -+ // Q0 -+ // Q1 -+ // Q2 -+ // Q3 -+ partition_contiguous_idx = (lane_in_quad_pair >> 2); -+ access_contiguous_idx = lane_in_quad; -+ access_strided_idx = lane_id; -+ } -+ else if (Policy::LdsmShape::kStrided == -+ (Policy::LdsmShape::kCount / 2) && -+ kOperand == Operand::kA) { -+ // Matrix multiply 16816|1688.TF32 A -+ // Q0 Q2 -+ // Q1 Q3 -+ partition_contiguous_idx = (lane_in_quad_pair >> 2); -+ access_contiguous_idx = (quad_quad ^ lane_in_quad); -+ access_strided_idx = lane_in_quad_quad; -+ } else if (Policy::LdsmShape::kStrided == -+ (Policy::LdsmShape::kCount / 2) && -+ kOperand == Operand::kB) { -+ // Matrix multiply 16816|1688.TF32 B -+ // Q0 Q1 -+ // Q2 Q3 -+ partition_contiguous_idx = (lane_in_quad_pair >> 2); -+ access_contiguous_idx = ((quad_pair & 1) ^ lane_in_quad); -+ access_strided_idx = lane_in_quad_pair + (lane_id >> 4 << 3); -+ } -+ else if (Policy::LdsmShape::kContiguous == Policy::LdsmShape::kCount) { -+ // Matrix multiply 16832.SP B -+ // Q0 Q1 Q2 Q3 -+ partition_contiguous_idx = (lane_in_quad_pair >> 2); -+ access_contiguous_idx = (quad_pair ^ lane_in_quad); -+ access_strided_idx = lane_in_quad_pair; -+ } -+ } -+ -+ int access_contiguous = -+ partition_contiguous_idx * Layout::PartitionShape::kContiguous + -+ access_contiguous_idx; -+ -+ int access_strided = access_strided_idx; -+ -+ byte_offset_ = (access_contiguous + access_strided * stride_) * -+ sizeof_bits::value * Layout::kElementsPerAccess / 8; -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ byte_offset_ += offset * sizeof_bits::value / 8; -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset( -+ TensorCoord const &tile_offset) { -+ int whole_tiles = tile_offset.contiguous() / Policy::kGroupsPerTile; -+ int k_groups_delta = tile_offset.contiguous() % Policy::kGroupsPerTile; -+ -+ byte_offset_ ^= k_groups_delta * sizeof_bits::value * -+ Layout::kElementsPerAccess * -+ Policy::LdsmShape::kContiguous / 8; -+ pointer_ += -+ tile_offset.strided() * stride_ * Shape::kStrided / Layout::kFactor + -+ whole_tiles * stride_ / sections_; -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset_negative( -+ TensorCoord const &tile_offset) { -+ -+ int whole_tiles = tile_offset.contiguous() / Policy::kGroupsPerTile; -+ int k_groups_delta = tile_offset.contiguous() % Policy::kGroupsPerTile; -+ if (k_groups_delta < 0) { -+ whole_tiles -= 1; -+ k_groups_delta += Policy::kGroupsPerTile; -+ } -+ -+ if ((Policy::kGroupsPerTile / kPartitionsK) >= 2) { -+ byte_offset_ ^= (k_groups_delta & 1) * Policy::LdsmShape::kContiguous * -+ sizeof_bits::value * -+ Layout::kElementsPerAccess / 8; -+ } -+ if ((Policy::kGroupsPerTile / kPartitionsK) >= 4) { -+ byte_offset_ ^= ((k_groups_delta + (k_group_idx_ & 1)) & 2) * -+ Policy::LdsmShape::kContiguous * -+ sizeof_bits::value * -+ Layout::kElementsPerAccess / 8; -+ } -+ if ((Policy::kGroupsPerTile / kPartitionsK) == 8) { -+ byte_offset_ ^= ((k_groups_delta + (k_group_idx_ & 3)) & 4) * -+ Policy::LdsmShape::kContiguous * -+ sizeof_bits::value * -+ Layout::kElementsPerAccess / 8; -+ } -+ -+ k_group_idx_ += k_groups_delta; -+ whole_tiles += k_group_idx_ / (Policy::kGroupsPerTile / kPartitionsK); -+ k_group_idx_ = k_group_idx_ % (Policy::kGroupsPerTile / kPartitionsK); -+ -+ pointer_ += -+ tile_offset.strided() * stride_ * Shape::kStrided / Layout::kFactor + -+ whole_tiles * stride_ / sections_; -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &operator++() { -+ -+ // Integer matrix multiply 16832 Interleaved-32 -+ // NONE -+ // Integer matrix multiply 16816 Interleaved-32 || Integer matrix multiply 16816 kblock=32 -+ -+ // Integer matrix multiply 8816 Interleaved-32 -+ // ^1 ^1 -+ // Matrix multiply 1684.TF32 kblock=16 || Integer matrix multiply 16816 kblock=64 -+ // Matrix multiply 1688 kblock=32 || Integer matrix multiply 8816 kblock=64 -+ // ^1 ^3 ^1 ^3 -+ // Matrix multiply 1688 kblock=64 -+ // ^1 ^3 ^1 ^7 ^1 ^3 ^1 ^7 -+ -+ // Matrix multiply 16816 kblock=32 | 1688.TF32 kblock=16 || Integer matrix multiply 16832 kblock=64 -+ // ^2 ^2 -+ // Matrix multiply 16816 kblock=64 | 1688.TF32 kblock=32 || Integer matrix multiply 16832 kblock=128 -+ // ^2 ^6 ^2 ^6 -+ -+ if ((Policy::kGroupsPerTile / kPartitionsK) > 1) { -+ int mask = ((Policy::kGroupsPerTile / kPartitionsK) == 8) -+ ? 3 -+ : (((Policy::kGroupsPerTile / kPartitionsK) == 4) ? 1 : 0); -+ -+ if (((k_group_idx_ & mask) % 2) == 0) -+ byte_offset_ ^= 1 * Policy::LdsmShape::kContiguous * -+ sizeof_bits::value * -+ Layout::kElementsPerAccess / 8; -+ else if ((k_group_idx_ & mask) == 1) -+ byte_offset_ ^= 3 * Policy::LdsmShape::kContiguous * -+ sizeof_bits::value * -+ Layout::kElementsPerAccess / 8; -+ else if ((k_group_idx_ & mask) == 3) -+ byte_offset_ ^= 7 * Policy::LdsmShape::kContiguous * -+ sizeof_bits::value * -+ Layout::kElementsPerAccess / 8; -+ } -+ -+ k_group_idx_++; -+ -+ if (k_group_idx_ == (Policy::kGroupsPerTile / kPartitionsK)) { -+ k_group_idx_ = 0; -+ add_tile_offset({Policy::kGroupsPerTile, 0}); -+ } -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &operator--() { assert(0); } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &operator+=( -+ TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &operator-=( -+ TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { load_with_byte_offset(frag, 0); } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset in units of bytes -+ Index byte_offset) const { -+ Array *fetch_ptr = -+ reinterpret_cast *>(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < Policy::LdsmIterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < Policy::LdsmIterations::kContiguous; ++c) { -+ int access_idx = c + s * Policy::LdsmIterations::kContiguous; -+ -+ AccessType const *source_ptr = -+ pointer_ + Policy::LdsmShape::kContiguous * c + -+ Policy::kLdsmOpInner / Layout::kFactor * -+ Policy::LdsmShape::kStrided * s * stride_; -+ -+ char const *source_byte_ptr = -+ reinterpret_cast(source_ptr) + byte_offset + -+ byte_offset_; -+ -+ cutlass::arch::ldsm( -+ fetch_ptr[access_idx], source_byte_ptr); -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ load_with_byte_offset(frag, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ load_with_byte_offset(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ Index pointer_offset = tile_offset.contiguous() * -+ InstructionShape::kContiguous / -+ Layout::kElementsPerAccess + -+ tile_offset.strided() * Shape::kStrided * stride_; -+ -+ byte_offset += sizeof_bits::value * pointer_offset / 8; -+ -+ load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ k_group_idx_ = k_group % (Policy::kGroupsPerTile / kPartitionsK); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to -+/// load from shared memory and therefore must be initialized with a TensorRef -+/// to shared memory. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Element number when the layout crosses (in units of elements) -+ int Crosswise, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Crosswise>, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kB, -+ "MmaTensorOpMultiplicandIterator for ColumnMajor Crosswise may " -+ "only be instantiated for B operand to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// KBlock size -+ static int const kCrosswise = Crosswise; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, kCrosswise>; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: -+ /// MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Underlying tile iterator implementation -+ using Base = MmaTensorOpMultiplicandTileIterator< -+ layout::PitchLinearShape, kOperand, Element, -+ layout::TensorOpMultiplicandCrosswise::value, -+ kCrosswise>, -+ layout::PitchLinearShape, -+ kOpDelta, kThreads, PartitionsK_>; -+ -+ public: -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+ private: -+ /// Underlying tile iterator -+ Base iterator_; -+ -+ public: -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator() {} -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator(TensorRef const &ref, int lane_id) -+ : iterator_({ref.data(), ref.stride()}, lane_id) {} -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset( -+ TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset_negative( -+ TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset_negative({tile_offset.row(), tile_offset.column()}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &operator++() { -+ ++iterator_; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &operator--() { -+ --iterator_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &operator+=( -+ TensorCoord const &tile_offset) { -+ add_tile_offset(PitchLinearCoord(tile_offset.row(), tile_offset.column())); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &operator-=( -+ TensorCoord const &tile_offset) { -+ add_tile_offset(-PitchLinearCoord(tile_offset.row(), tile_offset.column())); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { iterator_.load(frag); } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ // TODO -+ assert(0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ // TODO -+ assert(0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset( -+ frag, {tile_offset.contiguous(), tile_offset.strided()}, byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to -+/// load from shared memory and therefore must be initialized with a TensorRef -+/// to shared memory. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Element number when the layout crosses (in units of elements) -+ int Crosswise, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Crosswise>, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA, -+ "MmaTensorOpMultiplicandIterator for RowMajor Crosswise may " -+ "only be instantiated for A operand to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Element number when the layout crosses -+ static int const kCrosswise = Crosswise; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, kCrosswise>; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: -+ /// MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Underlying tile iterator implementation -+ using Base = MmaTensorOpMultiplicandTileIterator< -+ layout::PitchLinearShape, kOperand, Element, -+ layout::TensorOpMultiplicandCrosswise::value, -+ kCrosswise>, -+ layout::PitchLinearShape, -+ kOpDelta, kThreads, PartitionsK_>; -+ -+ public: -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+ private: -+ /// Underlying tile iterator -+ Base iterator_; -+ -+ public: -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator() {} -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator(TensorRef const &ref, int lane_id) -+ : iterator_({ref.data(), ref.stride()}, lane_id) {} -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset( -+ TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset_negative( -+ TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset_negative({tile_offset.column(), tile_offset.row()}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &operator++() { -+ ++iterator_; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &operator--() { -+ --iterator_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &operator+=( -+ TensorCoord const &tile_offset) { -+ add_tile_offset(PitchLinearCoord(tile_offset.column(), tile_offset.row())); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &operator-=( -+ TensorCoord const &tile_offset) { -+ add_tile_offset(-PitchLinearCoord(tile_offset.column(), tile_offset.row())); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { iterator_.load(frag); } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ // TODO -+ assert(0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ // TODO -+ assert(0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset( -+ frag, {tile_offset.strided(), tile_offset.contiguous()}, byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Element type -+ typename Element_, -+ /// Layout of operand in memory -+ typename Layout_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions, concept: MatrixShape) -+ typename OpDelta_> -+class MmaTensorOpAccumulatorTileIterator; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. It is used to load or store -+/// accumulators from memory and is agnostic to layout. It could be faster if it assumed row-major -+/// accumulator layout. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept | -+/// WriteableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Element type -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions, concept: MatrixShape) -+ typename OpDelta_> -+class MmaTensorOpAccumulatorTileIterator< -+ Shape_, Element_, cutlass::layout::RowMajor, InstructionShape_, OpDelta_> { -+ public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kC; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::RowMajor; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ using OpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ struct Policy { -+ static bool const kDivisible = -+ !(Shape::kRow % InstructionShape::kM) && -+ !(Shape::kColumn % InstructionShape::kN); -+ -+ static_assert(platform::is_same::value, -+ "Layouts must be defined for logical MatrixCoord coordinate space."); -+ -+ /// Number of mma operations performed -+ using MmaIterations = MatrixShape< -+ (Shape::kRow + InstructionShape::kM - 1) / InstructionShape::kM, -+ (Shape::kColumn + InstructionShape::kN - 1) / InstructionShape::kN -+ >; -+ }; -+ -+private: -+ -+ // Assume accumulator tile is an arrangement of 8-by-8 tiles replicated over the entire -+ // shape, with each quad mapped to one row and each thread mapped to 1/4 of the elements -+ // of that row. The accumulators within one row are assumed to be consecutive. -+ static int const kElementsPerAccess = InstructionShape::kN / 4; -+ static int const kRowsPerTile = 8; -+ static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array< -+ Element, -+ Policy::MmaIterations::kCount * InstructionShape::kMN / kThreads>; -+ -+private: -+ -+ /// Reference to output tensor -+ TensorRef ref_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ ref_(ref) { -+ -+ int quad = (lane_id >> 2); -+ int lane_in_quad = (lane_id & 3); -+ -+ MatrixCoord lane_offset(quad, lane_in_quad * kElementsPerAccess); -+ -+ ref_.add_coord_offset(lane_offset); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator &add_pointer_offset(LongIndex offset) { -+ ref_.add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ ref_.add_coord_offset(tile_offset * make_Coord(Shape::kRow, Shape::kColumn)); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator++() { -+ // deliberate no-op -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator--() { -+ // deliberate no-op -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ Fragment &frag, ///< fragment to load from the tensor -+ Index pointer_offset) const { ///< loads a tile with a linear offset -+ -+ TensorRef offset_ref(ref_); -+ offset_ref.add_pointer_offset(pointer_offset); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { -+ -+ int mma_accum_start = kAccumulatorRows * kElementsPerAccess * -+ (mma_n * Policy::MmaIterations::kRow + mma_m); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < kAccumulatorRows; ++row) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int col = 0; col < kElementsPerAccess; ++col) { -+ int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + -+ row * kRowsPerTile; -+ int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col; -+ -+ frag[mma_accum_start + row * kElementsPerAccess + col] = offset_ref.at({accum_m, accum_n}); -+ } -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ Fragment &frag, ///< fragment to load from the tensor -+ Index byte_offset) const { ///< loads a tile with a linear offset -+ -+ load_with_pointer_offset(byte_offset / sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ Fragment &frag, ///< fragment to load from the tensor -+ TensorCoord const &tile_offset) const { ///< loads a tile with a logical offset in units of whole tiles -+ -+ load(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ Fragment &frag, ///< fragment to load from the tensor -+ TensorCoord const &tile_offset, ///< loads a tile with a logical offset in units of whole tiles -+ Index pointer_offset) const { ///< loads a tile with a logical offset AND a pointer offset -+ -+ load_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) const { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory with additional pointer offset -+ CUTLASS_DEVICE -+ void store_with_pointer_offset( -+ Fragment const &frag, ///< fragment to store from the tensor -+ Index pointer_offset) const { ///< store a tile with a linear offset -+ -+ TensorRef offset_ref(ref_); -+ offset_ref.add_pointer_offset(pointer_offset); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { -+ -+ int mma_accum_start = kAccumulatorRows * kElementsPerAccess * -+ (mma_n * Policy::MmaIterations::kRow + mma_m); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < kAccumulatorRows; ++row) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int col = 0; col < kElementsPerAccess; ++col) { -+ int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + -+ row * kRowsPerTile; -+ int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col; -+ int idx = mma_accum_start + row * kElementsPerAccess + col; -+ -+ offset_ref.at({accum_m, accum_n}) = frag[idx]; -+ } -+ } -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory with additional pointer offset -+ CUTLASS_DEVICE -+ void store_with_byte_offset( -+ Fragment const &frag, ///< fragment to store from the tensor -+ Index byte_offset) const { ///< store a tile with a linear offset -+ -+ store_with_pointer_offset(byte_offset / sizeof(Element)); -+ } -+ -+ /// Stores a fragment to memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void store( -+ Fragment &frag, ///< fragment to store to the tensor -+ TensorCoord const &tile_offset) const { ///< stores a tile with a logical offset in units of whole tiles -+ -+ store(frag, tile_offset, 0); -+ } -+ -+ /// Stores a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void store( -+ /// fragment to store to the tensor -+ Fragment const &frag, -+ /// stores a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// stores a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ store_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. It is used to load or store -+/// accumulators from memory and is agnostic to layout. -+/// -+/// This iterator is not tested. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept | -+/// WriteableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Element type -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions, concept: MatrixShape) -+ typename OpDelta_> -+class MmaTensorOpAccumulatorTileIterator< -+ Shape_, Element_, cutlass::layout::AffineRankN<2>, InstructionShape_, OpDelta_> { -+ public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kC; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::RowMajor; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ using OpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ struct Policy { -+ static bool const kDivisible = -+ !(Shape::kRow % InstructionShape::kM) && -+ !(Shape::kColumn % InstructionShape::kN); -+ -+ static_assert(platform::is_same::value, -+ "Layouts must be defined for logical MatrixCoord coordinate space."); -+ -+ /// Number of mma operations performed -+ using MmaIterations = MatrixShape< -+ (Shape::kRow + InstructionShape::kM - 1) / InstructionShape::kM, -+ (Shape::kColumn + InstructionShape::kN - 1) / InstructionShape::kN -+ >; -+ }; -+ -+private: -+ -+ // Assume accumulator tile is an arrangement of 8-by-8 tiles replicated over the entire -+ // shape, with each quad mapped to one row and each thread mapped to 1/4 of the elements -+ // of that row. The accumulators within one row are assumed to be consecutive. -+ static int const kElementsPerAccess = InstructionShape::kN / 4; -+ static int const kRowsPerTile = 8; -+ static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array< -+ Element, -+ Policy::MmaIterations::kCount * InstructionShape::kMN / kThreads>; -+ -+private: -+ -+ /// Reference to output tensor -+ TensorRef ref_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ ref_(ref) { -+ -+ int quad = (lane_id >> 2); -+ int lane_in_quad = (lane_id & 3); -+ -+ MatrixCoord lane_offset(quad, lane_in_quad * kElementsPerAccess); -+ -+ ref_.add_coord_offset(lane_offset); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator &add_pointer_offset(LongIndex offset) { -+ ref_.add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ ref_.add_coord_offset(tile_offset * make_Coord(Shape::kRow, Shape::kColumn)); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator++() { -+ // deliberate no-op -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator--() { -+ // deliberate no-op -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ Fragment &frag, ///< fragment to load from the tensor -+ Index pointer_offset) const { ///< loads a tile with a linear offset -+ -+ TensorRef offset_ref(ref_); -+ offset_ref.add_pointer_offset(pointer_offset); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { -+ -+ int mma_accum_start = kAccumulatorRows * kElementsPerAccess * -+ (mma_n * Policy::MmaIterations::kRow + mma_m); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < kAccumulatorRows; ++row) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int col = 0; col < kElementsPerAccess; ++col) { -+ int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + -+ row * kRowsPerTile; -+ int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col; -+ -+ frag[mma_accum_start + row * kElementsPerAccess + col] = offset_ref.at({accum_m, accum_n}); -+ } -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ Fragment &frag, ///< fragment to load from the tensor -+ Index byte_offset) const { ///< loads a tile with a linear offset -+ -+ load_with_pointer_offset(byte_offset / sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ Fragment &frag, ///< fragment to load from the tensor -+ TensorCoord const &tile_offset) const { ///< loads a tile with a logical offset in units of whole tiles -+ -+ load(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ Fragment &frag, ///< fragment to load from the tensor -+ TensorCoord const &tile_offset, ///< loads a tile with a logical offset in units of whole tiles -+ Index pointer_offset) const { ///< loads a tile with a logical offset AND a pointer offset -+ -+ load_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) const { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory with additional pointer offset -+ CUTLASS_DEVICE -+ void store_with_pointer_offset( -+ Fragment const &frag, ///< fragment to store from the tensor -+ Index pointer_offset) const { ///< store a tile with a linear offset -+ -+ TensorRef offset_ref(ref_); -+ offset_ref.add_pointer_offset(pointer_offset); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { -+ -+ int mma_accum_start = kAccumulatorRows * kElementsPerAccess * -+ (mma_n * Policy::MmaIterations::kRow + mma_m); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < kAccumulatorRows; ++row) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int col = 0; col < kElementsPerAccess; ++col) { -+ int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + -+ row * kRowsPerTile; -+ int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col; -+ int idx = mma_accum_start + row * kElementsPerAccess + col; -+ -+ offset_ref.at({accum_m, accum_n}) = frag[idx]; -+ } -+ } -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory with additional pointer offset -+ CUTLASS_DEVICE -+ void store_with_byte_offset( -+ Fragment const &frag, ///< fragment to store from the tensor -+ Index byte_offset) const { ///< store a tile with a linear offset -+ -+ store_with_pointer_offset(byte_offset / sizeof(Element)); -+ } -+ -+ /// Stores a fragment to memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void store( -+ Fragment &frag, ///< fragment to store to the tensor -+ TensorCoord const &tile_offset) const { ///< stores a tile with a logical offset in units of whole tiles -+ -+ store(frag, tile_offset, 0); -+ } -+ -+ /// Stores a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void store( -+ /// fragment to store to the tensor -+ Fragment const &frag, -+ /// stores a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// stores a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ store_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. It is used to load or store -+/// accumulators from memory and is agnostic to layout. It could be faster if it assumed row-major -+/// accumulator layout. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept | -+/// WriteableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Element type -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions, concept: MatrixShape) -+ typename OpDelta_> -+class MmaTensorOpAccumulatorTileIterator { -+ public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kC; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::ColumnMajor; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ using OpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ struct Policy { -+ static bool const kDivisible = -+ !(Shape::kRow % InstructionShape::kM) && -+ !(Shape::kColumn % InstructionShape::kN); -+ -+ static_assert(platform::is_same::value, -+ "Layouts must be defined for logical MatrixCoord coordinate space."); -+ -+ /// Number of mma operations performed -+ using MmaIterations = MatrixShape< -+ (Shape::kRow + InstructionShape::kM - 1) / InstructionShape::kM, -+ (Shape::kColumn + InstructionShape::kN - 1) / InstructionShape::kN -+ >; -+ }; -+ -+private: -+ -+ // Assume accumulator tile is an arrangement of 8-by-8 tiles replicated over the entire -+ // shape, with each quad mapped to one row and each thread mapped to 1/4 of the elements -+ // of that row. The accumulators within one row are assumed to be consecutive. -+ static int const kElementsPerAccess = InstructionShape::kN / 4; -+ static int const kRowsPerTile = 8; -+ static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array; -+ -+private: -+ -+ /// Reference to output tensor -+ TensorRef ref_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ ref_(ref) { -+ -+ int quad = (lane_id >> 2); -+ int lane_in_quad = (lane_id & 3); -+ -+ MatrixCoord lane_offset(quad, lane_in_quad * kElementsPerAccess); -+ -+ ref_.add_coord_offset(lane_offset); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator &add_pointer_offset(LongIndex offset) { -+ ref_.add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ ref_.add_coord_offset(tile_offset * make_Coord(Shape::kRow, Shape::kColumn)); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator++() { -+ // deliberate no-op -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator--() { -+ // deliberate no-op -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ Fragment &frag, ///< fragment to load from the tensor -+ Index pointer_offset) const { ///< loads a tile with a linear offset -+ -+ TensorRef offset_ref(ref_); -+ offset_ref.add_pointer_offset(pointer_offset); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { -+ -+ int mma_accum_start = kAccumulatorRows * kElementsPerAccess * -+ (mma_n * Policy::MmaIterations::kRow + mma_m); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < kAccumulatorRows; ++row) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int col = 0; col < kElementsPerAccess; ++col) { -+ int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + -+ row * kRowsPerTile; -+ int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col; -+ int idx = mma_accum_start + row * kElementsPerAccess + col; -+ -+ frag[idx] = offset_ref.at({accum_m, accum_n}); -+ } -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ Fragment &frag, ///< fragment to load from the tensor -+ Index byte_offset) const { ///< loads a tile with a linear offset -+ -+ load_with_pointer_offset(byte_offset / sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ Fragment &frag, ///< fragment to load from the tensor -+ TensorCoord const &tile_offset) const { ///< loads a tile with a logical offset in units of whole tiles -+ -+ load(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ Fragment &frag, ///< fragment to load from the tensor -+ TensorCoord const &tile_offset, ///< loads a tile with a logical offset in units of whole tiles -+ Index pointer_offset) const { ///< loads a tile with a logical offset AND a pointer offset -+ -+ load_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) const { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory with additional pointer offset -+ CUTLASS_DEVICE -+ void store_with_pointer_offset( -+ Fragment const &frag, ///< fragment to store from the tensor -+ Index pointer_offset) const { ///< store a tile with a linear offset -+ -+ TensorRef offset_ref(ref_); -+ offset_ref.add_pointer_offset(pointer_offset); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { -+ -+ int mma_accum_start = kAccumulatorRows * kElementsPerAccess * -+ (mma_n * Policy::MmaIterations::kRow + mma_m); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < kAccumulatorRows; ++row) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int col = 0; col < kElementsPerAccess; ++col) { -+ int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + -+ row * kRowsPerTile; -+ int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col; -+ int idx = mma_accum_start + row * kElementsPerAccess + col; -+ -+ offset_ref.at({accum_m, accum_n}) = frag[idx]; -+ } -+ } -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory with additional pointer offset -+ CUTLASS_DEVICE -+ void store_with_byte_offset( -+ Fragment const &frag, ///< fragment to store from the tensor -+ Index byte_offset) const { ///< store a tile with a linear offset -+ -+ store_with_pointer_offset(byte_offset / sizeof(Element)); -+ } -+ -+ /// Stores a fragment to memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void store( -+ Fragment &frag, ///< fragment to store to the tensor -+ TensorCoord const &tile_offset) const { ///< stores a tile with a logical offset in units of whole tiles -+ -+ store(frag, tile_offset, 0); -+ } -+ -+ /// Stores a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void store( -+ /// fragment to store to the tensor -+ Fragment const &frag, -+ /// stores a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// stores a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ store_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. It is used to load or store -+/// accumulators from memory and is agnostic to layout. It could be faster if it assumed row-major -+/// accumulator layout. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept | -+/// WriteableRandomAccessContiguousTileIteratorConcept -+/// -+ -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Element typ -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions, concept: MatrixShape) -+ typename OpDelta_, -+ /// Interleaved N -+ int InterleavedN> -+class MmaTensorOpAccumulatorTileIterator< -+ Shape_, Element_, cutlass::layout::ColumnMajorInterleaved, -+ InstructionShape_, OpDelta_> { -+ public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kC; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::ColumnMajorInterleaved; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ using OpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ struct Policy { -+ static_assert( -+ !(Shape::kRow % InstructionShape::kM) && -+ !(Shape::kColumn % InstructionShape::kN), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ -+ static_assert(platform::is_same::value, -+ "Layouts must be defined for logical MatrixCoord coordinate space."); -+ -+ /// Number of mma operations performed -+ using MmaIterations = MatrixShape; -+ }; -+ -+private: -+ -+ static int const kElementsPerAccess = 2; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ using AccessType = Array; -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array; -+ -+private: -+ -+ /// Reference to output tensor -+ TensorRef ref_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ ref_(ref) { -+ -+ int quad = (lane_id >> 2); -+ int lane_in_quad = (lane_id & 3); -+ -+ MatrixCoord lane_offset(quad, lane_in_quad * kElementsPerAccess); -+ -+ ref_.add_coord_offset(lane_offset); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator &add_pointer_offset(LongIndex offset) { -+ ref_.add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ ref_.add_coord_offset(tile_offset * make_Coord(Shape::kRow, Shape::kColumn)); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator++() { -+ // deliberate no-op -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator--() { -+ // deliberate no-op -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ Fragment &frag, ///< fragment to load from the tensor -+ Index pointer_offset) const { ///< loads a tile with a linear offset -+ -+ TensorRef offset_ref(ref_); -+ offset_ref.add_pointer_offset(pointer_offset); -+ -+ AccessType* frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { -+ int accum_m = mma_m * InstructionShape::kM; -+ int accum_n = mma_n * InstructionShape::kN; -+ -+ int idx = mma_m + mma_n * Policy::MmaIterations::kRow; -+ -+ AccessType* access_ptr = reinterpret_cast(offset_ref.data() + -+ offset_ref.offset(TensorCoord(accum_m, accum_n))); -+ -+ frag_ptr[idx] = access_ptr[0]; -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ Fragment &frag, ///< fragment to load from the tensor -+ Index byte_offset) const { ///< loads a tile with a linear offset -+ -+ load_with_pointer_offset(byte_offset / sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ Fragment &frag, ///< fragment to load from the tensor -+ TensorCoord const &tile_offset) const { ///< loads a tile with a logical offset in units of whole tiles -+ -+ load(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ Fragment &frag, ///< fragment to load from the tensor -+ TensorCoord const &tile_offset, ///< loads a tile with a logical offset in units of whole tiles -+ Index pointer_offset) const { ///< loads a tile with a logical offset AND a pointer offset -+ -+ load_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) const { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory with additional pointer offset -+ CUTLASS_DEVICE -+ void store_with_pointer_offset( -+ Fragment const &frag, ///< fragment to store from the tensor -+ Index pointer_offset) const { ///< store a tile with a linear offset -+ -+ TensorRef offset_ref(ref_); -+ offset_ref.add_pointer_offset(pointer_offset); -+ -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { -+ int accum_m = mma_m * InstructionShape::kM; -+ int accum_n = mma_n * InstructionShape::kN; -+ -+ int idx = mma_m + mma_n * Policy::MmaIterations::kRow; -+ -+ AccessType* access_ptr = reinterpret_cast(offset_ref.data() + -+ offset_ref.offset(TensorCoord(accum_m, accum_n))); -+ -+ access_ptr[0] = frag_ptr[idx]; -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory with additional pointer offset -+ CUTLASS_DEVICE -+ void store_with_byte_offset( -+ Fragment const &frag, ///< fragment to store from the tensor -+ Index byte_offset) const { ///< store a tile with a linear offset -+ -+ store_with_pointer_offset(byte_offset / sizeof(Element)); -+ } -+ -+ /// Stores a fragment to memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void store( -+ Fragment &frag, ///< fragment to store to the tensor -+ TensorCoord const &tile_offset) const { ///< stores a tile with a logical offset in units of whole tiles -+ -+ store(frag, tile_offset, 0); -+ } -+ -+ /// Stores a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void store( -+ /// fragment to store to the tensor -+ Fragment const &frag, -+ /// stores a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// stores a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ store_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. It is used to load or store -+/// accumulators from memory and is agnostic to layout. It could be faster if it assumed row-major -+/// accumulator layout. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept | -+/// WriteableRandomAccessContiguousTileIteratorConcept -+/// -+ -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Element typ -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions, concept: MatrixShape) -+ typename OpDelta_, -+ /// Interleaved N -+ int InterleavedN> -+class MmaTensorOpAccumulatorTileIterator< -+ Shape_, Element_, cutlass::layout::TensorNCxHWx, -+ InstructionShape_, OpDelta_> { -+ public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kC; -+ -+ /// Element type -+ using Element = int8_t; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::TensorNCxHWx; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ using OpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Long Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ struct Policy { -+ static_assert( -+ !(Shape::kRow % InstructionShape::kM) && -+ !(Shape::kColumn % InstructionShape::kN), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ -+ /// Number of elements in strided dimension that each STG writes -+ static int const kStridedPerSTG = 8; -+ -+ /// Factor to calculate reorder index to pack accumulator. -+ static int const kPackedFactor = Shape::kColumn / 32; -+ -+ /// Number of mma operations performed -+ using MmaIterations = MatrixShape; -+ }; -+ -+private: -+ -+ static int const kElementsPerAccess = InterleavedN / 4; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ struct alignas((kElementsPerAccess * sizeof_bits::value / 8)) AccessType { -+ Array storage; -+ }; -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array; -+ -+private: -+ -+ /// Reference to output tensor -+ TensorRef ref_; -+ -+ /// Row offset index globally -+ LongIndex global_offset_row_; -+ -+ /// Column offset index globally -+ LongIndex global_offset_col_; -+ -+ /// Output tensor size -+ TensorCoord extent_; -+ -+ /// Alpha -+ float alpha_; -+ -+ /// Beta -+ float beta_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator( -+ TensorRef const &ref, -+ int const lane_id, -+ TensorCoord extent, -+ float alpha = 1.0f, -+ float beta = 0.0f -+ ): -+ ref_(ref), -+ extent_(extent), -+ alpha_(alpha), -+ beta_(beta) { -+ -+ int quad = (lane_id >> 2); -+ int lane_in_quad = (lane_id & 3); -+ -+ global_offset_row_ = quad; -+ -+ global_offset_col_ = lane_in_quad * kElementsPerAccess; -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator &add_pointer_offset(LongIndex offset) { -+ ref_.add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator &add_tile_offset(MatrixCoord const &tile_offset) { -+ -+ global_offset_row_ += tile_offset.row() * Shape::kRow; -+ -+ global_offset_col_ += tile_offset.column() * Shape::kColumn; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator++() { -+ // deliberate no-op -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator--() { -+ // deliberate no-op -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpAccumulatorTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ Fragment &frag, ///< fragment to load from the tensor -+ Index pointer_offset) const { ///< loads a tile with a linear offset -+ -+ TensorRef offset_ref(ref_); -+ offset_ref.add_pointer_offset(pointer_offset); -+ -+ AccessType* frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kN; ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kM; ++mma_m) { -+ int accum_m = mma_m * InstructionShape::kM; -+ int accum_n = mma_n * InstructionShape::kN; -+ -+ int idx = mma_m + mma_n * Policy::MmaIterations::kM; -+ -+ AccessType* access_ptr = reinterpret_cast(offset_ref.data() + -+ accum_m * offset_ref.stride(0) + accum_n); -+ -+ frag_ptr[idx] = access_ptr[0]; -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ Fragment &frag, ///< fragment to load from the tensor -+ Index byte_offset) const { ///< loads a tile with a linear offset -+ -+ load_with_pointer_offset(byte_offset / sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ Fragment &frag, ///< fragment to load from the tensor -+ TensorCoord const &tile_offset) const { ///< loads a tile with a logical offset in units of whole tiles -+ -+ load(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ Fragment &frag, ///< fragment to load from the tensor -+ TensorCoord const &tile_offset, ///< loads a tile with a logical offset in units of whole tiles -+ Index pointer_offset) const { ///< loads a tile with a logical offset AND a pointer offset -+ -+ load_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) const { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory with additional pointer offset -+ CUTLASS_DEVICE -+ void store_with_pointer_offset( -+ Fragment const &frag, ///< fragment to store from the tensor -+ Index pointer_offset) const { ///< store a tile with a linear offset -+ -+ TensorRef offset_ref(ref_); -+ offset_ref.add_pointer_offset(pointer_offset); -+ -+ Array output_frag_f; -+ Array output_frag; -+ -+ LongIndex pq = extent_.h() * extent_.w(); -+ -+ LongIndex extent_row = extent_.n() * pq; -+ LongIndex extent_col = extent_.c(); -+ -+ LongIndex k_major = (global_offset_col_ / InterleavedN) * pq; -+ Index k_minor = global_offset_col_ % InterleavedN; -+ LongIndex k_offset = k_major * InterleavedN + k_minor; -+ LongIndex k_offset_delta = pq * InterleavedN; -+ -+ LongIndex stride_n = pq * extent_.c(); -+ -+ Index n; -+ LongIndex pq_rem; -+ -+ unsigned int pq_mul, pq_shr; -+ find_divisor(pq_mul, pq_shr, pq); -+ -+ if(beta_ == 0.0f) { -+ CUTLASS_PRAGMA_UNROLL -+ for(int i = 0; i < frag.size(); ++i) { -+ output_frag_f[i] = frag[i]; -+ } -+ -+ if(InstructionShape::kM == Policy::kStridedPerSTG) { -+ CUTLASS_PRAGMA_UNROLL -+ for(int i = 0; i < frag.size(); ++i) { -+ output_frag[i] = (Element)(output_frag_f[i] * alpha_); -+ } -+ } else { -+ CUTLASS_PRAGMA_UNROLL -+ for(int i = 0; i < frag.size(); ++i) { -+ int map_i = (i / (16 * Policy::kPackedFactor)) * (16 * Policy::kPackedFactor) -+ + (i % (8 * Policy::kPackedFactor)) / 2 * 4 -+ + (i % (8 * Policy::kPackedFactor)) % 2 -+ + (i / (8 * Policy::kPackedFactor)) % 2 * 2; -+ output_frag[i] = (Element)(output_frag_f[map_i] * alpha_); -+ } -+ } -+ -+ AccessType const *frag_ptr = reinterpret_cast(&output_frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { -+ int accum_m = mma_m * Policy::kStridedPerSTG; -+ -+ fast_divmod(n, pq_rem, global_offset_row_ + accum_m, pq, pq_mul, pq_shr); -+ LongIndex offset_m = n * stride_n + k_offset + pq_rem * InterleavedN; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { -+ -+ int accum_n = mma_n * InterleavedN; -+ -+ int idx = mma_n + mma_m * Policy::MmaIterations::kColumn; -+ -+ if((global_offset_row_ + accum_m < extent_row) && (global_offset_col_ + accum_n < extent_col)) { -+ AccessType* access_ptr = reinterpret_cast(offset_ref.data() + -+ offset_m + mma_n * k_offset_delta); -+ -+ access_ptr[0] = frag_ptr[idx]; -+ } -+ } -+ } -+ } else { -+ if(InstructionShape::kM == Policy::kStridedPerSTG) { -+ CUTLASS_PRAGMA_UNROLL -+ for(int i = 0; i < frag.size(); ++i) { -+ output_frag_f[i] = frag[i]; -+ } -+ } else { -+ CUTLASS_PRAGMA_UNROLL -+ for(int i = 0; i < frag.size(); ++i) { -+ int map_i = (i / (16 * Policy::kPackedFactor)) * (16 * Policy::kPackedFactor) -+ + (i % (8 * Policy::kPackedFactor)) / 2 * 4 -+ + (i % (8 * Policy::kPackedFactor)) % 2 -+ + (i / (8 * Policy::kPackedFactor)) % 2 * 2; -+ output_frag_f[i] = frag[map_i]; -+ } -+ } -+ -+ AccessType const *frag_ptr = reinterpret_cast(&output_frag); -+ -+ Array ref_frag; -+ AccessType *ref_frag_ptr = reinterpret_cast(&ref_frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { -+ int accum_m = mma_m * Policy::kStridedPerSTG; -+ -+ fast_divmod(n, pq_rem, global_offset_row_ + accum_m, pq, pq_mul, pq_shr); -+ LongIndex offset_m = n * stride_n + k_offset + pq_rem * InterleavedN; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { -+ -+ int accum_n = mma_n * InterleavedN; -+ -+ int idx = mma_n + mma_m * Policy::MmaIterations::kColumn; -+ -+ if((global_offset_row_ + accum_m < extent_row) && (global_offset_col_ + accum_n < extent_col)) { -+ AccessType* access_ptr = reinterpret_cast(offset_ref.data() + -+ offset_m + mma_n * k_offset_delta); -+ -+ ref_frag_ptr[0] = access_ptr[0]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int i = 0; i < kElementsPerAccess; ++i) { -+ output_frag[idx * kElementsPerAccess + i] = Element(alpha_ * output_frag_f[idx * kElementsPerAccess + i] -+ + beta_ * ref_frag[i]); -+ } -+ -+ access_ptr[0] = frag_ptr[idx]; -+ } -+ } -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory with additional pointer offset -+ CUTLASS_DEVICE -+ void store_with_byte_offset( -+ Fragment const &frag, ///< fragment to store from the tensor -+ Index byte_offset) const { ///< store a tile with a linear offset -+ -+ store_with_pointer_offset(byte_offset / sizeof(Element)); -+ } -+ -+ /// Stores a fragment to memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void store( -+ Fragment &frag, ///< fragment to store to the tensor -+ TensorCoord const &tile_offset) const { ///< stores a tile with a logical offset in units of whole tiles -+ -+ store(frag, tile_offset, 0); -+ } -+ -+ /// Stores a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void store( -+ /// fragment to store to the tensor -+ Fragment const &frag, -+ /// stores a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// stores a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ store_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h -new file mode 100644 -index 0000000..bf192e6 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h -@@ -0,0 +1,3106 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm70.h" -+ -+#include "cutlass/platform/platform.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Operand identity -+ Operand Operand, -+ /// Data type of A elements -+ typename Element_, -+ /// Layout of operand -+ typename Layout_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Delta between *MMA operations (in units of *MMA operations, concept: -+ /// MatrixShape) -+ int OpDelta_, -+ /// Number of threads participating in one matrix operation -+ int Threads> -+class MmaVoltaTensorOpMultiplicandTileIterator; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: PitchLinearShape) -+ typename Shape_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: PitchLinearShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_> -+class MmaVoltaTensorOpMultiplicandTileIterator< -+ Shape_, Operand::kA, Element_, -+ cutlass::layout::VoltaTensorOpMultiplicandCongruous< -+ sizeof_bits::value>, -+ InstructionShape_, OpDelta_, 32> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kA; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::VoltaTensorOpMultiplicandCongruous::value>; -+ -+ /// Shape of one matrix product operation (concept: GemmShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Long Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ struct Policy { -+ static_assert( -+ !(Shape::kContiguous % InstructionShape::kContiguous), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ -+ // Shape of one individual LDS.128 -+ // TODO: 32 and 4 are hardcoded, 32-by-4 is logical shape -+ using LdsShape = layout::PitchLinearShape< -+ 32, -+ 4 -+ >; -+ -+ // LdsShapes are arranged in the strided direction in SMEM -+ using LdsIterations = layout::PitchLinearShape< -+ InstructionShape::kStrided / LdsShape::kStrided, -+ Shape::kContiguous / LdsShape::kContiguous -+ >; -+ }; -+ -+private: -+ -+ /// Not working on this feature at the moment. -+ static_assert(kOpDelta == 1, -+ "Alternative arrangements not supported at present."); -+ -+ /// Number of internal pointers needed to reference shared memory -+ static int const kPointerCount = 2; -+ -+ /// Pointer type used for accesses -+ using AccessType = AlignedArray; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array; -+ -+private: -+ -+ /// Layout object storing stride values -+ StrideIndex stride_; -+ -+ /// Shared memory base pointers - not advanced -+ AccessType const *pointer_[kPointerCount]; -+ -+ /// Byte offset incremented as iterator advances -+ Index byte_offset_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator(): stride_(0), byte_offset_(0) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ stride_(ref.stride(0) / Layout::kElementsPerAccess), byte_offset_(0) { -+ // swizzle patterns for operandA LDS are -+ // 1. (tid[4] << 3) | (tid[2:0] ^ tid[4]) -+ // 2. (tid[4] << 3) | (tid[2:0] ^ tid[4] ^ 0b10010) -+ -+ int vec_row = (lane_id >> 4); // tid[4] -+ int vec_col = ((lane_id & 4) >> 2); // tid[2] -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPointerCount; ++i) { -+ -+ if(i == 1) { -+ vec_row |= 2; -+ } -+ int access_contiguous_idx = (vec_col << 2) | ((lane_id & 3) ^ vec_row); -+ int access_contiguous = access_contiguous_idx; -+ -+ int access_strided = vec_row; -+ pointer_[i] = reinterpret_cast(ref.data()) + -+ access_contiguous + access_strided * stride_; -+ } -+ -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ byte_offset_ += offset * sizeof(Element); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ int contiguous_offset = tile_offset.contiguous(); -+ int strided_offset = tile_offset.strided(); -+ -+ // To support 32x32 tile size -+ if (Shape::kContiguous == Policy::LdsShape::kContiguous) { -+ if (contiguous_offset % 2) { -+ AccessType const *tmp_pointer = pointer_[0]; -+ pointer_[0] = pointer_[1]; -+ pointer_[1] = tmp_pointer; -+ } -+ contiguous_offset = contiguous_offset / 2 * 2; -+ } -+ -+ int offset = (strided_offset * InstructionShape::kStrided) * stride_ * -+ Layout::kElementsPerAccess + -+ contiguous_offset * Shape::kContiguous; -+ -+ add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator & operator++() { -+ byte_offset_ += stride_ * InstructionShape::kStrided * sizeof(Element) * -+ Layout::kElementsPerAccess; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator & operator--() { -+ byte_offset_ -= stride_ * InstructionShape::kStrided * sizeof(Element) * -+ Layout::kElementsPerAccess; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ load_with_byte_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset in units of bytes -+ Index byte_offset) const { -+ -+ AccessType * fetch_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < Policy::LdsIterations::kStrided; ++s) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < Policy::LdsIterations::kContiguous; ++c) { -+ -+ int access_idx = c + s * Policy::LdsIterations::kContiguous; -+ -+ AccessType const *source_ptr = pointer_[s & 1] + -+ Policy::LdsShape::kContiguous * c + -+ Policy::LdsShape::kStrided * (s / 2) * stride_; -+ -+ char const *source_byte_ptr = reinterpret_cast(source_ptr) + byte_offset + byte_offset_; -+ fetch_ptr[access_idx] = *(reinterpret_cast (source_byte_ptr)); -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ load_with_byte_offset(frag, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ load_with_byte_offset(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ Index pointer_offset = -+ tile_offset.contiguous() * Shape::kContiguous / -+ Layout::kElementsPerAccess + -+ tile_offset.strided() * InstructionShape::kStrided * stride_; -+ -+ byte_offset += sizeof(AccessType) * pointer_offset; -+ -+ load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ // no operation here -+ } -+}; -+ -+////////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: PitchLinearShape) -+ typename Shape_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: PitchLinearShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_> -+ -+class MmaVoltaTensorOpMultiplicandTileIterator< -+ Shape_, Operand::kB, Element_, -+ cutlass::layout::VoltaTensorOpMultiplicandBCongruous< -+ sizeof_bits::value>, -+ InstructionShape_, OpDelta_, 32> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kB; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::VoltaTensorOpMultiplicandBCongruous::value>; -+ -+ /// Shape of one matrix product operation (concept: GemmShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Long Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ struct Policy { -+ static_assert( -+ !(Shape::kContiguous % InstructionShape::kContiguous), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ -+ // Shape of one individual LDS -+ // TODO: remove hardcoded 32 and 4 -+ using LdsShape = layout::PitchLinearShape< -+ 32, -+ 4 -+ >; -+ -+ using LdsIterations = layout::PitchLinearShape< -+ Shape::kContiguous / LdsShape::kContiguous, -+ InstructionShape::kStrided / LdsShape::kStrided -+ >; -+ }; -+ -+private: -+ -+ /// Not working on this feature at the moment. -+ static_assert(kOpDelta == 1, -+ "Alternative arrangements not supported at present."); -+ -+ /// Pointer type used for accesses -+ using AccessType = AlignedArray; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile, needs on more time number of registers -+ using Fragment = Array; -+ -+private: -+ -+ /// Layout object storing stride values -+ StrideIndex stride_; -+ -+ /// Shared memory base pointers - not advanced -+ AccessType const *pointer_; -+ -+ /// Byte offset incremented as iterator advances -+ Index byte_offset_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator(): stride_(0), byte_offset_(0) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ stride_(ref.stride(0) / Layout::kElementsPerAccess), byte_offset_(0) { -+ -+ // swizzle pattern is (tid & (3 << 3) | (tid[1:0] ^ tid[4:3])) -+ int access_strided = (lane_id >> 3) & 0x3; -+ int access_contiguous = ((lane_id ^ (lane_id >> 3)) & 0x3); -+ -+ pointer_ = reinterpret_cast(ref.data()) + -+ access_contiguous + access_strided * stride_; -+ -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ byte_offset_ += offset * sizeof(Element); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ int contiguous_offset = tile_offset.contiguous(); -+ int strided_offset = tile_offset.strided(); -+ -+ int offset = (strided_offset * InstructionShape::kStrided) * stride_ * -+ Layout::kElementsPerAccess + -+ contiguous_offset * Shape::kContiguous; -+ -+ add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator & operator++() { -+ byte_offset_ += stride_ * InstructionShape::kStrided * sizeof(Element) * -+ Layout::kElementsPerAccess; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator & operator--() { -+ byte_offset_ += stride_ * InstructionShape::kStrided * sizeof(Element) * -+ Layout::kElementsPerAccess; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ load_with_byte_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset in units of bytes -+ Index byte_offset) const { -+ -+ AccessType * fetch_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < Policy::LdsIterations::kStrided; ++s) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < Policy::LdsIterations::kContiguous; ++c) { -+ -+ int access_idx = c + s * Policy::LdsIterations::kContiguous; -+ -+ AccessType const *source_ptr = pointer_ + -+ Policy::LdsShape::kContiguous / Layout::kElementsPerAccess * c + -+ Policy::LdsShape::kStrided * s * stride_; -+ -+ char const *source_byte_ptr = reinterpret_cast(source_ptr) + byte_offset + byte_offset_; -+ fetch_ptr[access_idx] = *(reinterpret_cast (source_byte_ptr)); -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ load_with_byte_offset(frag, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ load_with_byte_offset(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ Index pointer_offset = -+ tile_offset.contiguous() * Shape::kContiguous / -+ Layout::kElementsPerAccess + -+ tile_offset.strided() * InstructionShape::kStrided * stride_; -+ -+ byte_offset += sizeof(AccessType) * pointer_offset; -+ -+ load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ // no operation here -+ } -+}; -+ -+////////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to load from shared -+/// memory and therefore must be initialized with a TensorRef to shared memory. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_> -+class MmaVoltaTensorOpMultiplicandTileIterator< -+ Shape_, Operand::kA, Element_, -+ cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous< -+ sizeof_bits::value>, -+ InstructionShape_, OpDelta_, 32> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kA; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Underlying tile iterator implementation -+ using Base = MmaVoltaTensorOpMultiplicandTileIterator< -+ layout::PitchLinearShape, kOperand, Element, -+ layout::VoltaTensorOpMultiplicandCongruous::value>, -+ layout::PitchLinearShape, -+ kOpDelta, kThreads>; -+ -+ public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+private: -+ -+ /// Underlying tile iterator -+ Base iterator_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): iterator_({ref.data(), ref.stride()}, lane_id) { -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator & operator++() { -+ -+ ++iterator_; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator & operator--() { -+ -+ --iterator_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(PitchLinearCoord(tile_offset.row(), tile_offset.column())); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-PitchLinearCoord(tile_offset.row(), tile_offset.column())); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ iterator_.load(frag); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset( -+ frag, -+ {tile_offset.contiguous(), tile_offset.strided()}, -+ byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to load from shared -+/// memory and therefore must be initialized with a TensorRef to shared memory. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_> -+class MmaVoltaTensorOpMultiplicandTileIterator< -+ Shape_, Operand::kB, Element_, -+ cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous< -+ sizeof_bits::value>, -+ InstructionShape_, OpDelta_, 32> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kB; -+ -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Underlying tile iterator implementation -+ using Base = MmaVoltaTensorOpMultiplicandTileIterator< -+ layout::PitchLinearShape, kOperand, Element, -+ layout::VoltaTensorOpMultiplicandBCongruous::value>, -+ layout::PitchLinearShape, -+ kOpDelta, kThreads>; -+ -+ public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+private: -+ -+ /// Underlying tile iterator -+ Base iterator_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): iterator_({ref.data(), ref.stride()}, lane_id) { -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator & operator++() { -+ -+ ++iterator_; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator & operator--() { -+ -+ --iterator_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(PitchLinearCoord(tile_offset.column(), tile_offset.row())); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-PitchLinearCoord(tile_offset.column(), tile_offset.row())); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ iterator_.load(frag); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset( -+ frag, -+ {tile_offset.strided(), tile_offset.contiguous()}, -+ byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. It is used to load or store -+/// accumulators from memory and is agnostic to layout. It could be faster if it assumed row-major -+/// accumulator layout. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept | -+/// WriteableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of elements -+ typename Element_, -+ /// Layout of operand in memory -+ typename Layout_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions, concept: MatrixShape) -+ typename OpDelta_> -+class MmaVoltaTensorOpAccumulatorTileIterator { -+ public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kC; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = Layout_; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ using OpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ struct Policy { -+ -+ /// Volta Tensor Op uses 32x32 interleaved tile -+ using InterleavedTile = MatrixShape<32, 32>; -+ -+ static_assert(!(Shape::kRow % InterleavedTile::kRow) && !(Shape::kColumn % InterleavedTile::kColumn), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ -+ static_assert(platform::is_same::value, -+ "Layouts must be defined for logical MatrixCoord coordinate space."); -+ -+ /// Number of mma operations performed -+ using TileIterations = MatrixShape< -+ Shape::kRow / InterleavedTile::kRow, -+ Shape::kColumn / InterleavedTile::kColumn -+ >; -+ -+ using MmaIterations = -+ MatrixShape; -+ }; -+ -+private: -+ -+ // Assume accumulator tile is multipile interleaved 32x32 tile. -+ static int const kElementsPerPartial = 4; -+ using EleShapePerPatial = typename platform::conditional< -+ platform::is_same::value, -+ MatrixShape<2, 2>, -+ MatrixShape<1, 4> >::type; -+ static int const kElementsPerMma = 8; -+ static int const kAccumulatorPatials = 2; -+ using QuadShapePerPatialMma = MatrixShape<4, 4>; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array; -+ -+private: -+ -+ /// Reference to output tensor -+ TensorRef ref_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpAccumulatorTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpAccumulatorTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ ref_(ref) { -+ -+ int quad = (lane_id >> 2); -+ int lane_in_quad = (lane_id & 3); -+ int accum_m, accum_n; -+ -+ if (platform::is_same::value) { -+ // (quad[2],quad[0])+lane_in_quad[0] -+ accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1); -+ // (quad[1])+lane_in_quad[1] -+ accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials + -+ (lane_in_quad & 2); -+ } else { -+ accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + lane_in_quad; // (quad[2],quad[0]) -+ accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials; -+ } -+ MatrixCoord lane_offset(accum_m, accum_n); -+ -+ ref_.add_coord_offset(lane_offset); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpAccumulatorTileIterator &add_pointer_offset(LongIndex offset) { -+ ref_.add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpAccumulatorTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ ref_.add_coord_offset(tile_offset * make_Coord(Shape::kRow, Shape::kColumn)); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpAccumulatorTileIterator & operator++() { -+ // deliberate no-op -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpAccumulatorTileIterator & operator--() { -+ // deliberate no-op -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpAccumulatorTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpAccumulatorTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset( -+ Fragment &frag, ///< fragment to load from the tensor -+ Index pointer_offset) const { ///< loads a tile with a linear offset -+ -+ TensorRef offset_ref(ref_); -+ offset_ref.add_pointer_offset(pointer_offset); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; ++tile_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { -+ -+ int mma_accum_start = -+ (((tile_n * Policy::TileIterations::kRow + tile_m) * -+ Policy::MmaIterations::kColumn + mma_n) * -+ Policy::MmaIterations::kRow + mma_m) * -+ kElementsPerMma; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int p = 0; p < kAccumulatorPatials; ++p) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < EleShapePerPatial::kRow; ++m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < EleShapePerPatial::kColumn; ++n) { -+ int accum_m = tile_m * Policy::InterleavedTile::kRow + -+ mma_m * QuadShapePerPatialMma::kRow + m * 2; -+ int accum_n = tile_n * Policy::InterleavedTile::kColumn + -+ mma_n * QuadShapePerPatialMma::kColumn + -+ p * Policy::InterleavedTile::kColumn/2 + n; -+ int idx = mma_accum_start + p * kElementsPerPartial + -+ m * EleShapePerPatial::kColumn + n; -+ frag[idx] = offset_ref.at({accum_m, accum_n}); -+ } -+ } -+ } -+ } -+ } -+ } -+ } -+ } -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ Fragment &frag, ///< fragment to load from the tensor -+ Index byte_offset) const { ///< loads a tile with a linear offset -+ -+ load_with_pointer_offset(byte_offset / sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_HOST_DEVICE -+ void load( -+ Fragment &frag, ///< fragment to load from the tensor -+ TensorCoord const &tile_offset) const { ///< loads a tile with a logical offset in units of whole tiles -+ -+ load(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_HOST_DEVICE -+ void load( -+ Fragment &frag, ///< fragment to load from the tensor -+ TensorCoord const &tile_offset, ///< loads a tile with a logical offset in units of whole tiles -+ Index pointer_offset) const { ///< loads a tile with a logical offset AND a pointer offset -+ -+ load_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); -+ } -+ -+ /// Stores a fragment to memory -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) const { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory with additional pointer offset -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset( -+ Fragment const &frag, ///< fragment to store from the tensor -+ Index pointer_offset) const { ///< store a tile with a linear offset -+ -+ TensorRef offset_ref(ref_); -+ offset_ref.add_pointer_offset(pointer_offset); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; ++tile_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { -+ -+ int mma_accum_start = -+ (((tile_n * Policy::TileIterations::kRow + tile_m) * -+ Policy::MmaIterations::kColumn + mma_n) * -+ Policy::MmaIterations::kRow + mma_m) * -+ kElementsPerMma; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int p = 0; p < kAccumulatorPatials; ++p) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < EleShapePerPatial::kRow; ++m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < EleShapePerPatial::kColumn; ++n) { -+ int accum_m = tile_m * Policy::InterleavedTile::kRow + -+ mma_m * QuadShapePerPatialMma::kRow + m * 2; -+ int accum_n = tile_n * Policy::InterleavedTile::kColumn + -+ mma_n * QuadShapePerPatialMma::kColumn + -+ p * Policy::InterleavedTile::kColumn/2 + n; -+ int idx = mma_accum_start + p * kElementsPerPartial + -+ m * EleShapePerPatial::kColumn + n; -+ offset_ref.at({accum_m, accum_n}) = frag[idx]; -+ } -+ } -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory with additional pointer offset -+ CUTLASS_HOST_DEVICE -+ void store_with_byte_offset( -+ Fragment const &frag, ///< fragment to store from the tensor -+ Index byte_offset) const { ///< store a tile with a linear offset -+ -+ store_with_pointer_offset(byte_offset / sizeof(Element)); -+ } -+ -+ /// Stores a fragment to memory with logical offset in units of whole tiles. -+ CUTLASS_HOST_DEVICE -+ void store( -+ Fragment &frag, ///< fragment to store to the tensor -+ TensorCoord const &tile_offset) const { ///< stores a tile with a logical offset in units of whole tiles -+ -+ store(frag, tile_offset, 0); -+ } -+ -+ /// Stores a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_HOST_DEVICE -+ void store( -+ /// fragment to store to the tensor -+ Fragment const &frag, -+ /// stores a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// stores a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ store_with_pointer_offset(frag, ref_.offset(tile_offset) + pointer_offset); -+ } -+}; -+ -+/// This tile iterator is specialized for 32-thread TensorOps. It uses LDS to -+/// load from shared memory and therefore must be initialized with a TensorRef -+/// to shared memory. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: PitchLinearShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: PitchLinearShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// KBlock size (in units of elements) -+ int KBlock> -+class MmaVoltaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::VoltaTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, KBlock>, -+ InstructionShape_, OpDelta_, 32> { -+ public: -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand == Operand::kB, -+ "MmaVoltaTensorOpMultiplicandIterator may only be instantiated for " -+ "A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// KBlock size -+ static int const kKBlock = KBlock; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::VoltaTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, kKBlock>; -+ -+ /// Shape of one matrix product operation (concept: GemmShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: -+ /// MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Long Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ struct Policy { -+ -+ /// Shape of one individual LDS instruction -+ using LdsShape = layout::PitchLinearShape<1, 32>; -+ -+ /// Number and arrangement of LDSM instructions -+ using LdsIterations = layout::PitchLinearShape<1, Shape::kStrided / 32>; -+ -+ /// Using LDS.128 -+ static int const kElementsPerAccess = 8; -+ -+ /// Contiguous elements per line -+ static int const kContiguousElementsPerLine = 4; -+ }; -+ -+ private: -+ /// Not working on this feature at the moment. -+ static_assert(kOpDelta == 1, -+ "Alternative arrangements not supported at present."); -+ -+ /// Pointer type used for accesses -+ using AccessType = AlignedArray; -+ -+ public: -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = -+ Array; -+ -+ private: -+ -+ /// Layout object storing stride values -+ StrideIndex stride_; -+ -+ /// Shared memory base pointers - not advanced -+ AccessType const *pointer_; -+ -+ /// Byte offset incremented as iterator advances -+ Index byte_offset_; -+ -+ /// Crosswised elements are arranged in a SMEM line -+ /// in units of AccessType -+ Index line_size; -+ -+ /// Internal counter used to determine load addr offset -+ /// and when to swap higher 64bit with lower 64bit -+ int k_group_idx_; -+ -+ public: -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator() -+ : pointer_(nullptr), -+ stride_(0), -+ line_size(0), -+ byte_offset_(0), -+ k_group_idx_(0) {} -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator(TensorRef const &ref, int lane_id) -+ : pointer_(reinterpret_cast(ref.data())), -+ stride_(ref.stride(0) * Policy::kElementsPerAccess), -+ line_size((ref.stride(0) * Policy::kContiguousElementsPerLine) / -+ Policy::kElementsPerAccess), -+ k_group_idx_(0), -+ byte_offset_(0) { -+ -+ int quad = (lane_id / 4); -+ int lane_in_quad = (lane_id % 4); -+ int access_contiguous; -+ -+ if(kOperand == Operand::kA) { -+ -+ // swizzle id: tid[4]|tid[1:0]|(tid[2]^tid[4]) -+ access_contiguous = ((quad & 0x4) << 1) + ((lane_in_quad) << 1) + -+ ((quad & 0x1) ^ ((quad & 0x4) >> 2)); -+ } else { -+ -+ // swizzle id: tid[4]|tid[1:0]|tid[3] -+ access_contiguous = ((quad & 0x4) << 1) + (lane_in_quad << 1) + -+ ((quad & 0x2) >> 1 ^ ((quad & 0x4) >> 2)); -+ } -+ -+ byte_offset_ = access_contiguous * -+ sizeof(Element) * Policy::kElementsPerAccess; -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ byte_offset_ += offset * sizeof(Element); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &add_tile_offset( -+ TensorCoord const &tile_offset) { -+ -+ int contiguous_offset = tile_offset.contiguous(); -+ int strided_offset = tile_offset.strided(); -+ k_group_idx_ = 0; -+ -+ pointer_ += contiguous_offset * -+ (InstructionShape::kContiguous / -+ Policy::kContiguousElementsPerLine) * -+ line_size + -+ strided_offset * Shape::kStrided / 2; -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &operator++() { -+ k_group_idx_ = (k_group_idx_ + 1) % 8; -+ -+ if (k_group_idx_ == 4 || k_group_idx_ == 0) { -+ byte_offset_ ^= 1 * sizeof(Element) * Policy::kElementsPerAccess; -+ } -+ -+ pointer_ += line_size; -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &operator--() { assert(0); } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &operator+=( -+ TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &operator-=( -+ TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { load_with_byte_offset(frag, 0); } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset in units of bytes -+ Index byte_offset) const { -+ -+ AccessType * fetch_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < Policy::LdsIterations::kStrided; ++s) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < Policy::LdsIterations::kContiguous; ++c) { -+ -+ int access_idx = c + s * Policy::LdsIterations::kContiguous; -+ -+ AccessType const *source_ptr = pointer_ + -+ Policy::LdsShape::kContiguous * c * line_size + -+ Policy::LdsShape::kStrided * s / 2; -+ -+ char const *source_byte_ptr = reinterpret_cast(source_ptr) + byte_offset + byte_offset_; -+ fetch_ptr[access_idx] = *(reinterpret_cast (source_byte_ptr)); -+ -+ // swap higher 64bit and lower 64bit -+ if (k_group_idx_ & 0x2) { -+ uint64_t *low = reinterpret_cast(&frag) + access_idx * 2; -+ uint64_t *high = reinterpret_cast(&frag) + access_idx * 2 + 1; -+ uint64_t tmp = *low; -+ *low = *high; -+ *high = tmp; -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ load_with_byte_offset(frag, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ load_with_byte_offset(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ Index pointer_offset = tile_offset.contiguous() * -+ InstructionShape::kContiguous / -+ Policy::kElementsPerAccess + -+ tile_offset.strided() * Shape::kStrided * stride_; -+ -+ byte_offset += sizeof(AccessType) * pointer_offset; -+ -+ load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ k_group_idx_ = k_group; -+ } -+}; -+ -+/// This tile iterator is specialized for 32-thread TensorOps. It uses LDS to -+/// load from shared memory and therefore must be initialized with a TensorRef -+/// to shared memory. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// KBlock size (in units of elements) -+ int KBlock> -+class MmaVoltaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, KBlock>, -+ InstructionShape_, OpDelta_, 32> { -+ public: -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand == Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for " -+ "A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// KBlock size -+ static int const kKBlock = KBlock; -+ -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, kKBlock>; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: -+ /// MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Underlying tile iterator implementation -+ using Base = MmaVoltaTensorOpMultiplicandTileIterator< -+ layout::PitchLinearShape, kOperand, Element, -+ layout::VoltaTensorOpMultiplicandCrosswise::value, -+ kKBlock>, -+ layout::PitchLinearShape, -+ kOpDelta, kThreads>; -+ -+ public: -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+ private: -+ /// Underlying tile iterator -+ Base iterator_; -+ -+ public: -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator() {} -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator(TensorRef const &ref, int lane_id) -+ : iterator_({ref.data(), ref.stride()}, lane_id) {} -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &add_tile_offset( -+ TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &operator++() { -+ ++iterator_; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &operator--() { -+ --iterator_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &operator+=( -+ TensorCoord const &tile_offset) { -+ add_tile_offset(PitchLinearCoord(tile_offset.row(), tile_offset.column())); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &operator-=( -+ TensorCoord const &tile_offset) { -+ add_tile_offset(-PitchLinearCoord(tile_offset.row(), tile_offset.column())); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { iterator_.load(frag); } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ // TODO -+ assert(0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ // TODO -+ assert(0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset( -+ frag, {tile_offset.contiguous(), tile_offset.strided()}, byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. It uses LDS to -+/// load from shared memory and therefore must be initialized with a TensorRef -+/// to shared memory. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// KBlock size (in units of elements) -+ int KBlock> -+class MmaVoltaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, KBlock>, -+ InstructionShape_, OpDelta_, 32> { -+ public: -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand == Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for " -+ "A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// KBlock size -+ static int const kKBlock = KBlock; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, kKBlock>; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: -+ /// MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Underlying tile iterator implementation -+ using Base = MmaVoltaTensorOpMultiplicandTileIterator< -+ layout::PitchLinearShape, kOperand, Element, -+ layout::VoltaTensorOpMultiplicandCrosswise::value, -+ kKBlock>, -+ layout::PitchLinearShape, -+ kOpDelta, kThreads>; -+ -+ public: -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+ private: -+ /// Underlying tile iterator -+ Base iterator_; -+ -+ public: -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator() {} -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator(TensorRef const &ref, int lane_id) -+ : iterator_({ref.data(), ref.stride()}, lane_id) {} -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &add_tile_offset( -+ TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &operator++() { -+ ++iterator_; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &operator--() { -+ --iterator_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &operator+=( -+ TensorCoord const &tile_offset) { -+ add_tile_offset(PitchLinearCoord(tile_offset.column(), tile_offset.row())); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator &operator-=( -+ TensorCoord const &tile_offset) { -+ add_tile_offset(-PitchLinearCoord(tile_offset.column(), tile_offset.row())); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { iterator_.load(frag); } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ // TODO -+ assert(0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ // TODO -+ assert(0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset( -+ frag, {tile_offset.strided(), tile_offset.contiguous()}, byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for 'TN' arrangement -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Operand identity -+ Operand Operand_, -+ /// Data type of A elements -+ typename Element_, -+ /// Layout of matrix operand -+ typename Layout_, -+ /// Shape of one matrix production operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Delta between *MMA operations (in units of *MMA operations, concept: -+ /// MatrixShape) -+ int OpDelta_, -+ /// Number of threads participating in one matrix operation -+ int Threads = 32, -+ /// Number of partitions along K dimension -+ int PartitionsK_ = 1> -+class MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner { -+ public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ /// Basic check -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaVoltaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = Layout_; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Number of elements accessed per Shared Memory load -+ static int const kElementsPerAccess = 4; -+ -+private: -+ -+ static int const kInterleavedTileRows = 32; -+ static int const kInterleavedTileColumns = 32; -+ static int const kInstructionsPerTile = 2; -+ -+ /// Rounded up instruction counts -+ using TileCount = MatrixShape< -+ Shape::kRow / kInterleavedTileRows, -+ Shape::kColumn / kInterleavedTileColumns -+ >; -+ -+ using FragmentCount = MatrixShape< -+ TileCount::kRow * kInstructionsPerTile, -+ TileCount::kColumn * kInstructionsPerTile -+ >; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array< -+ Element, -+ (kOperand == Operand::kA ? FragmentCount::kRow : FragmentCount::kColumn) * kElementsPerAccess -+ >; -+ -+ /// Memory access type -+ using AccessType = AlignedArray; -+ -+private: -+ -+ /// Underlying tensor reference -+ TensorRef ref_; -+ -+ /// Extent of tensor -+ MatrixCoord extent_; -+ -+ /// Origin -+ MatrixCoord origin_; -+ -+ /// Used to conditionally enable extents checking -+ bool divisible_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner(): divisible_(true) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ ref_(ref), extent_(Shape::kRow, Shape::kColumn), divisible_(true) { -+ -+ int quad_id = lane_id / 4; -+ int lane_in_quad = (lane_id % 4); -+ -+ if (kOperand == Operand::kA) { -+ -+ int row_idx = ((quad_id & 1) + ((quad_id & 4) / 2)) * 4 * kInstructionsPerTile + lane_in_quad; -+ int col_idx = 0; -+ -+ origin_ = MatrixCoord(row_idx, col_idx); -+ } -+ else { -+ -+ int row_idx = 0; -+ int col_idx = (quad_id / 2) * 4 * kInstructionsPerTile + lane_in_quad; -+ -+ origin_ = MatrixCoord(row_idx, col_idx); -+ } -+ -+ ref_.add_coord_offset(origin_); -+ } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner( -+ TensorRef const &ref, -+ TensorCoord extent, -+ int lane_id -+ ): ref_(ref), extent_(extent), divisible_(false) { -+ -+ int quad_id = lane_id / 4; -+ int lane_in_quad = (lane_id % 4); -+ -+ if (kOperand == Operand::kA) { -+ -+ int row_idx = ((quad_id & 1) + ((quad_id & 4) / 2)) * 4 * kInstructionsPerTile + lane_in_quad; -+ int col_idx = 0; -+ -+ origin_ = MatrixCoord(row_idx, col_idx); -+ } -+ else { -+ -+ int row_idx = 0; -+ int col_idx = (quad_id / 2) * 4 * kInstructionsPerTile + lane_in_quad; -+ -+ origin_ = MatrixCoord(row_idx, col_idx); -+ } -+ -+ #if defined(__CUDA_ARCH__) -+ __syncthreads(); -+ #endif -+ -+ ref_.add_coord_offset(origin_); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner &add_pointer_offset(LongIndex offset) { -+ -+ ref_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); -+ origin_ += coord_offset; -+ -+ ref_.add_coord_offset(coord_offset); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner & operator++() { -+ -+ if (kOperand == Operand::kA) { -+ add_tile_offset({0, 1}); -+ } -+ else { -+ add_tile_offset({1, 0}); -+ } -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner & operator--() { -+ -+ if (kOperand == Operand::kA) { -+ add_tile_offset({0, -1}); -+ } -+ else { -+ add_tile_offset({-1, 0}); -+ } -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ AccessType const *access_ptr = reinterpret_cast(ref_.data()); -+ int ldm = ref_.stride()[0]; -+ -+ if (kOperand == Operand::kA) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int idx = 0; idx < FragmentCount::kRow; ++idx) { -+ -+ int tile_idx = idx / 2; -+ int quad_idx = idx % 2; -+ -+ int row_offset = tile_idx * kInterleavedTileRows + quad_idx * 4; -+ frag_ptr[idx] = access_ptr[row_offset * ldm / kElementsPerAccess]; -+ } -+ } -+ else { -+ CUTLASS_PRAGMA_UNROLL -+ for (int idx = 0; idx < FragmentCount::kColumn; ++idx) { -+ -+ int tile_idx = idx / 2; -+ int quad_idx = idx % 2; -+ -+ int col_offset = tile_idx * kInterleavedTileColumns + quad_idx * 4; -+ frag_ptr[idx] = access_ptr[col_offset * ldm / kElementsPerAccess]; -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ -+ load_with_pointer_offset(frag, byte_offset * 8 / sizeof_bits::value); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ -+ TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); -+ -+ load_with_pointer_offset(frag, ref_.offset(coord_offset)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ -+ TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); -+ -+ load_with_pointer_offset(frag, ref_.offset(coord_offset) + pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ -+ TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); -+ -+ load_with_pointer_offset(frag, ref_.offset(coord_offset) + byte_offset * 8 / sizeof_bits::value); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ // no operation -+ } -+}; -+ -+ -+/// Tile iterator specialized for 'NT' arrangement -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Operand identity -+ Operand Operand_, -+ /// Data type of A elements -+ typename Element_, -+ /// Layout of matrix operand -+ typename Layout_, -+ /// Shape of one matrix production operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Delta between *MMA operations (in units of *MMA operations, concept: -+ /// MatrixShape) -+ int OpDelta_, -+ /// Number of threads participating in one matrix operation -+ int Threads = 32, -+ /// Number of partitions along K dimension -+ int PartitionsK_ = 1> -+class MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter { -+ public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ /// Basic check -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaVoltaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = Layout_; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Number of elements accessed per Shared Memory load -+ static int const kElementsPerAccess = 4; -+ -+private: -+ -+ static int const kInterleavedTileRows = 32; -+ static int const kInterleavedTileColumns = 32; -+ static int const kInstructionsPerTile = 2; -+ -+ /// Rounded up instruction counts -+ using TileCount = MatrixShape< -+ Shape::kRow / kInterleavedTileRows, -+ Shape::kColumn / kInterleavedTileColumns -+ >; -+ -+ using FragmentCount = MatrixShape< -+ TileCount::kRow * kInstructionsPerTile, -+ TileCount::kColumn * kInstructionsPerTile -+ >; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array< -+ Element, -+ (kOperand == Operand::kA ? FragmentCount::kRow : FragmentCount::kColumn) * kElementsPerAccess -+ >; -+ -+ /// Memory access type -+ using AccessType = AlignedArray; -+ -+private: -+ -+ /// Underlying tensor reference -+ TensorRef ref_; -+ -+ /// Extent of tensor -+ MatrixCoord extent_; -+ -+ /// Origin -+ MatrixCoord origin_; -+ -+ /// Used to conditionally enable extents checking -+ bool divisible_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter(): divisible_(true) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ ref_(ref), extent_(Shape::kRow, Shape::kColumn), divisible_(true) { -+ -+ int quad_id = lane_id / 4; -+ int lane_in_quad = (lane_id % 4); -+ -+ if (kOperand == Operand::kA) { -+ -+ int row_idx = ((quad_id & 1) + ((quad_id & 4) / 2)) * 4 * kInstructionsPerTile; -+ int col_idx = lane_in_quad; -+ -+ origin_ = MatrixCoord(row_idx, col_idx); -+ } -+ else { -+ -+ int row_idx = lane_in_quad; -+ int col_idx = (quad_id / 2) * 4 * kInstructionsPerTile; -+ -+ origin_ = MatrixCoord(row_idx, col_idx); -+ } -+ -+ ref_.add_coord_offset(origin_); -+ } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter( -+ TensorRef const &ref, -+ TensorCoord extent, -+ int lane_id -+ ): ref_(ref), extent_(extent), divisible_(false) { -+ -+ int quad_id = lane_id / 4; -+ int lane_in_quad = (lane_id % 4); -+ -+ if (kOperand == Operand::kA) { -+ -+ int row_idx = ((quad_id & 1) + ((quad_id & 4) / 2)) * 4 * kInstructionsPerTile; -+ int col_idx = lane_in_quad; -+ -+ origin_ = MatrixCoord(row_idx, col_idx); -+ } -+ else { -+ -+ int row_idx = lane_in_quad; -+ int col_idx = (quad_id / 2) * 4 * kInstructionsPerTile; -+ -+ origin_ = MatrixCoord(row_idx, col_idx); -+ } -+ -+ #if defined(__CUDA_ARCH__) -+ __syncthreads(); -+ #endif -+ -+ ref_.add_coord_offset(origin_); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter &add_pointer_offset(LongIndex offset) { -+ -+ ref_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); -+ origin_ += coord_offset; -+ -+ ref_.add_coord_offset(coord_offset); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter & operator++() { -+ -+ if (kOperand == Operand::kA) { -+ add_tile_offset({0, 1}); -+ } -+ else { -+ add_tile_offset({1, 0}); -+ } -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter & operator--() { -+ -+ if (kOperand == Operand::kA) { -+ add_tile_offset({0, -1}); -+ } -+ else { -+ add_tile_offset({-1, 0}); -+ } -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ AccessType const *access_ptr = reinterpret_cast(ref_.data()); -+ int ldm = ref_.stride()[0]; -+ -+ if (kOperand == Operand::kA) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int idx = 0; idx < FragmentCount::kRow; ++idx) { -+ -+ int tile_idx = idx / 2; -+ int quad_idx = idx % 2; -+ -+ int row_offset = tile_idx * kInterleavedTileRows; -+ frag_ptr[idx] = access_ptr[row_offset / kElementsPerAccess + quad_idx]; -+ } -+ } -+ else { -+ CUTLASS_PRAGMA_UNROLL -+ for (int idx = 0; idx < FragmentCount::kColumn; ++idx) { -+ -+ int tile_idx = idx / 2; -+ int quad_idx = idx % 2; -+ -+ int col_offset = tile_idx * kInterleavedTileColumns; -+ frag_ptr[idx] = access_ptr[col_offset / kElementsPerAccess + quad_idx]; -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ -+ load_with_pointer_offset(frag, byte_offset * 8 / sizeof_bits::value); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ -+ TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); -+ -+ load_with_pointer_offset(frag, ref_.offset(coord_offset)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ -+ TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); -+ -+ load_with_pointer_offset(frag, ref_.offset(coord_offset) + pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ -+ TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); -+ -+ load_with_pointer_offset(frag, ref_.offset(coord_offset) + byte_offset * 8 / sizeof_bits::value); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ // no operation -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_> -+class MmaVoltaTensorOpMultiplicandTileIterator< -+ Shape_, -+ Operand::kA, -+ Element_, -+ cutlass::layout::RowMajor, -+ InstructionShape_, -+ OpDelta_, -+ 32 -+> : public MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner< -+ Shape_, Operand::kA, Element_, cutlass::layout::RowMajor, InstructionShape_, OpDelta_> { -+ -+public: -+ using Base = MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner< -+ Shape_, Operand::kA, Element_, cutlass::layout::RowMajor, InstructionShape_, OpDelta_> ; -+ -+ using TensorRef = typename Base::TensorRef; -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): Base(ref, lane_id) { } -+ -+}; -+ -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_> -+class MmaVoltaTensorOpMultiplicandTileIterator< -+ Shape_, -+ Operand::kA, -+ Element_, -+ cutlass::layout::ColumnMajor, -+ InstructionShape_, -+ OpDelta_, -+ 32 -+> : public MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter< -+ Shape_, Operand::kA, Element_, cutlass::layout::ColumnMajor, InstructionShape_, OpDelta_> { -+ -+public: -+ using Base = MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter< -+ Shape_, Operand::kA, Element_, cutlass::layout::ColumnMajor, InstructionShape_, OpDelta_> ; -+ -+ using TensorRef = typename Base::TensorRef; -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): Base(ref, lane_id) { } -+ -+}; -+ -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_> -+class MmaVoltaTensorOpMultiplicandTileIterator< -+ Shape_, Operand::kB, Element_, -+ cutlass::layout::ColumnMajor, -+ InstructionShape_, OpDelta_, 32 -+> : public MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner< -+ Shape_, Operand::kB, Element_, cutlass::layout::ColumnMajor, InstructionShape_, OpDelta_> { -+ -+public: -+ using Base = MmaVoltaTensorOpMultiplicandTileIteratorCanonicalInner< -+ Shape_, Operand::kB, Element_, cutlass::layout::ColumnMajor, InstructionShape_, OpDelta_>; -+ -+ using TensorRef = typename Base::TensorRef; -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): Base(ref, lane_id) { } -+}; -+ -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_> -+class MmaVoltaTensorOpMultiplicandTileIterator< -+ Shape_, Operand::kB, Element_, -+ cutlass::layout::RowMajor, -+ InstructionShape_, OpDelta_, 32 -+> : public MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter< -+ Shape_, Operand::kB, Element_, cutlass::layout::RowMajor, InstructionShape_, OpDelta_> { -+ -+public: -+ using Base = MmaVoltaTensorOpMultiplicandTileIteratorCanonicalOuter< -+ Shape_, Operand::kB, Element_, cutlass::layout::RowMajor, InstructionShape_, OpDelta_>; -+ -+ using TensorRef = typename Base::TensorRef; -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaVoltaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): Base(ref, lane_id) { } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h -new file mode 100644 -index 0000000..29cc3d9 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h -@@ -0,0 +1,2452 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm80.h" -+ -+#include "cutlass/platform/platform.h" -+#include "cutlass/fast_math.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for loading 128b vectors of 64b elements. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: PitchLinearShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: PitchLinearShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::TensorOpMultiplicandCongruous64b, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ static_assert(!(Shape::kContiguous % 16) && !(Shape::kStrided % 4), "Divisibility."); -+ -+ static_assert(sizeof_bits::value == 64, "This is specialized for 64b accesses."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::TensorOpMultiplicandCongruous64b; -+ -+ /// Shape of one matrix product operation (concept: GemmShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// Number of partitions along K dimension -+ static int const kPartitionsK = PartitionsK_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Long Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Load two elements per access -+ static int const kElementsPerAccess = 2; -+ -+ /// Policy defining internal details of tile iterator -+ struct Policy { -+ -+ /// Shape of one access -+ using Delta = layout::PitchLinearShape<8, 4>; -+ -+ /// Number of iterations to load -+ using Iterations = layout::PitchLinearShape< -+ Shape::kContiguous / kElementsPerAccess / Delta::kContiguous, -+ InstructionShape::kStrided / Delta::kStrided -+ >; -+ -+ }; -+ -+private: -+ -+ /// Not working on this feature at the moment. -+ static_assert(kOpDelta == 1, -+ "Alternative arrangements not supported at present."); -+ -+ /// Pointer type used for accesses -+ using AccessType = AlignedArray; -+ -+ /// Internal counter used to jump to next K partition -+ int k_group_idx_; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = -+ Array; -+ -+private: -+ -+ /// Layout object storing stride values -+ StrideIndex stride_; -+ -+ /// Shared memory base pointers - not advanced -+ AccessType const *pointer_; -+ -+ /// Byte offset incremented as iterator advances -+ Index byte_offset_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator(): stride_(0), byte_offset_(0) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ stride_(ref.stride(0) / kElementsPerAccess), byte_offset_(0), -+ k_group_idx_(0) { -+ -+ int access_strided = lane_id / Policy::Delta::kContiguous; -+ int access_contiguous = (lane_id % Policy::Delta::kContiguous) ^ access_strided; -+ -+ pointer_= reinterpret_cast(ref.data()) + -+ access_contiguous + access_strided * stride_; -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ byte_offset_ += offset * sizeof(Element); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ int offset = -+ (tile_offset.strided() * InstructionShape::kStrided) * stride_ * kElementsPerAccess + -+ tile_offset.contiguous() * Shape::kContiguous; -+ -+ add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator++() { -+ -+ add_tile_offset({0, 1}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the opposite of the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator--() { -+ -+ add_tile_offset({0, -1}); -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ load_with_byte_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset in units of bytes -+ Index byte_offset) const { -+ -+ AccessType *fetch_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < Policy::Iterations::kStrided; ++s) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < Policy::Iterations::kContiguous; ++c) { -+ -+ int access_idx = c + s * Policy::Iterations::kContiguous; -+ -+ AccessType const *source_ptr = pointer_ + -+ Policy::Delta::kContiguous * c + -+ Policy::Delta::kStrided * s * stride_; -+ -+ char const *source_byte_ptr = reinterpret_cast(source_ptr) + byte_offset + byte_offset_; -+ -+ AccessType const *source = reinterpret_cast(source_byte_ptr); -+ -+ fetch_ptr[access_idx] = *source; -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ -+ load_with_byte_offset(frag, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ -+ load_with_byte_offset(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ -+ load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ -+ Index pointer_offset = -+ tile_offset.contiguous() * Shape::kContiguous / Layout::kElementsPerAccess + -+ tile_offset.strided() * InstructionShape::kStrided * stride_; -+ -+ byte_offset += sizeof(AccessType) * pointer_offset; -+ -+ load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Underlying tile iterator implementation -+ using Base = MmaTensorOpMultiplicandTileIterator< -+ layout::PitchLinearShape, kOperand, Element, -+ layout::TensorOpMultiplicandCongruous64b, -+ layout::PitchLinearShape, -+ kOpDelta, kThreads, PartitionsK_>; -+ -+ public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+private: -+ -+ /// Underlying tile iterator -+ Base iterator_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): iterator_({ref.data(), ref.stride()}, lane_id) { -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator++() { -+ -+ ++iterator_; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator--() { -+ -+ --iterator_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(PitchLinearCoord(tile_offset.column(), tile_offset.row())); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-PitchLinearCoord(tile_offset.column(), tile_offset.row())); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ iterator_.load(frag); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset( -+ frag, -+ {tile_offset.strided(), tile_offset.contiguous()}, -+ byte_offset); -+ } -+ -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to load from shared -+/// memory and therefore must be initialized with a TensorRef to shared memory. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Underlying tile iterator implementation -+ using Base = MmaTensorOpMultiplicandTileIterator< -+ layout::PitchLinearShape, kOperand, Element, -+ layout::TensorOpMultiplicandCongruous64b, -+ layout::PitchLinearShape, -+ kOpDelta, kThreads, PartitionsK_>; -+ -+ public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+private: -+ -+ /// Underlying tile iterator -+ Base iterator_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): iterator_({ref.data(), ref.stride()}, lane_id) { -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator++() { -+ -+ ++iterator_; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator--() { -+ -+ --iterator_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(PitchLinearCoord(tile_offset.row(), tile_offset.column())); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-PitchLinearCoord(tile_offset.row(), tile_offset.column())); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ iterator_.load(frag); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset( -+ frag, -+ {tile_offset.contiguous(), tile_offset.strided()}, -+ byte_offset); -+ } -+ -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for loading 128b vectors of 64b elements. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: PitchLinearShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: PitchLinearShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::TensorOpMultiplicand64bCrosswise, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ static_assert(!(Shape::kContiguous % 4) && !(Shape::kStrided % 16), "Divisibility."); -+ -+ static_assert(sizeof_bits::value == 64, "This is specialized for 64b accesses."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::TensorOpMultiplicand64bCrosswise; -+ -+ /// Shape of one matrix product operation (concept: GemmShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// Number of partitions along K dimension -+ static int const kPartitionsK = PartitionsK_; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Long Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Load two elements per access -+ static int const kElementsPerAccess = 2; -+ -+ /// Policy defining internal details of tile iterator -+ struct Policy { -+ -+ /// Shape of one access -+ using Delta = layout::PitchLinearShape<4, 16>; -+ -+ /// Number of iterations to load -+ using Iterations = layout::PitchLinearShape< -+ InstructionShape::kContiguous / Delta::kContiguous, -+ Shape::kStrided / Delta::kStrided -+ >; -+ -+ }; -+ -+private: -+ -+ /// Not working on this feature at the moment. -+ static_assert(kOpDelta == 1, -+ "Alternative arrangements not supported at present."); -+ -+ /// Pointer type used for accesses -+ using AccessType = AlignedArray; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = -+ Array; -+ -+private: -+ -+ /// Layout object storing stride values -+ StrideIndex stride_; -+ -+ /// Shared memory base pointers - not advanced -+ AccessType const *pointer_; -+ -+ /// Byte offset incremented as iterator advances -+ Index byte_offset_; -+ -+ /// Internal counter for tracking K-group -+ Index k_group_idx_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator(): stride_(0), byte_offset_(0) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): -+ stride_(ref.stride(0) / kElementsPerAccess), byte_offset_(0), -+ k_group_idx_(0) { -+ -+ int access_strided = lane_id / 8; -+ int access_contiguous = (lane_id % 8); -+ -+ byte_offset_ = (access_contiguous + access_strided * stride_) * sizeof(AccessType); -+ -+ pointer_= reinterpret_cast(ref.data()); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ pointer_ += offset / kElementsPerAccess; -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ int offset = (tile_offset.contiguous() * InstructionShape::kContiguous) * -+ stride_ * kElementsPerAccess + -+ tile_offset.strided() * Shape::kStrided; -+ -+ add_pointer_offset(offset); -+ -+ int old_k_group_idx = k_group_idx_; -+ -+ k_group_idx_ += tile_offset.contiguous(); -+ -+ if ((k_group_idx_ & 2) ^ (old_k_group_idx & 2)) { -+ byte_offset_ ^= 0x40; -+ } -+ -+ return *this; -+ } -+ -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset_negative(TensorCoord const &tile_offset) { -+ -+ add_tile_offset(tile_offset); // TODO fix this if it becomes an issue during warp it reset -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator++() { -+ -+ pointer_ += stride_ * InstructionShape::kContiguous; -+ -+ if (k_group_idx_ & 0x1) { -+ // xor ptr -+ byte_offset_ ^= 0x40; -+ } -+ -+ ++k_group_idx_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ load_with_byte_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset in units of bytes -+ Index byte_offset) const { -+ -+ AccessType *fetch_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < Policy::Iterations::kContiguous; ++c) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < Policy::Iterations::kStrided; ++s) { -+ -+ int access_idx = c + s * Policy::Iterations::kContiguous; -+ -+ AccessType const *source_ptr = pointer_ + -+ Policy::Delta::kContiguous * c * stride_ + -+ Policy::Delta::kStrided * s / kElementsPerAccess; -+ -+ char const *source_byte_ptr = reinterpret_cast(source_ptr) + byte_offset + byte_offset_; -+ -+ AccessType const *source = reinterpret_cast(source_byte_ptr); -+ -+ fetch_ptr[access_idx] = *source; -+ } -+ } -+ -+ Element *exchange_ptr = reinterpret_cast(&frag); -+ -+ if (k_group_idx_ & 1) { -+ // exchange on 64b granularity -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Fragment::kElements; i += 2) { -+ Element tmp = exchange_ptr[i]; -+ exchange_ptr[i] = exchange_ptr[i + 1]; -+ exchange_ptr[i + 1] = tmp; -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ -+ load_with_byte_offset(frag, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ -+ load_with_byte_offset(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ -+ load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ Index pointer_offset = tile_offset.contiguous() * -+ InstructionShape::kContiguous / -+ Layout::kElementsPerAccess + -+ tile_offset.strided() * Shape::kStrided * stride_; -+ -+ byte_offset += sizeof(AccessType) * pointer_offset; -+ -+ load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ k_group_idx_ = k_group; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Underlying tile iterator implementation -+ using Base = MmaTensorOpMultiplicandTileIterator< -+ layout::PitchLinearShape, kOperand, Element, -+ layout::TensorOpMultiplicand64bCrosswise, -+ layout::PitchLinearShape, -+ kOpDelta, kThreads, PartitionsK_>; -+ -+ public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+private: -+ -+ /// Underlying tile iterator -+ Base iterator_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): iterator_({ref.data(), ref.stride()}, lane_id) { -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset_negative(TensorCoord const &tile_offset) { -+ -+ iterator_.add_tile_offset_negative({tile_offset.column(), tile_offset.row()}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator++() { -+ -+ ++iterator_; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator--() { -+ -+ --iterator_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(PitchLinearCoord(tile_offset.column(), tile_offset.row())); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-PitchLinearCoord(tile_offset.column(), tile_offset.row())); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ iterator_.load(frag); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset( -+ frag, -+ {tile_offset.strided(), tile_offset.contiguous()}, -+ byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Underlying tile iterator implementation -+ using Base = MmaTensorOpMultiplicandTileIterator< -+ layout::PitchLinearShape, kOperand, Element, -+ layout::TensorOpMultiplicand64bCrosswise, -+ layout::PitchLinearShape, -+ kOpDelta, kThreads, PartitionsK_>; -+ -+ public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+private: -+ -+ /// Underlying tile iterator -+ Base iterator_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): iterator_({ref.data(), ref.stride()}, lane_id) { -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset_negative(TensorCoord const &tile_offset) { -+ -+ iterator_.add_tile_offset_negative({tile_offset.row(), tile_offset.column()}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator++() { -+ -+ ++iterator_; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator--() { -+ -+ --iterator_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(PitchLinearCoord(tile_offset.row(), tile_offset.column())); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-PitchLinearCoord(tile_offset.row(), tile_offset.column())); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ iterator_.load(frag); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset( -+ frag, -+ {tile_offset.contiguous(), tile_offset.strided()}, -+ byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Tile iterator specialized for canonical matrix layouts -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Operand identity -+ Operand Operand_, -+ /// Data type of A elements -+ typename Element_, -+ /// Layout of operand -+ typename Layout_, -+ /// Shape of one matrix production operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Delta between *MMA operations (in units of *MMA operations, concept: -+ /// MatrixShape) -+ int OpDelta_, -+ /// Number of threads participating in one matrix operation -+ int Threads = 32, -+ /// Number of partitions along K dimension -+ int PartitionsK_ = 1> -+class MmaTensorOpMultiplicandTileIteratorCanonical { -+ public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ /// Basic check -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = Layout_; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Number of elements accessed per Shared Memory load -+ static int const kElementsPerAccess = -+ (sizeof_bits::value >= 32 ? 1 : 32 / sizeof_bits::value); -+ -+private: -+ -+ static int const kWarpShapeOuter = -+ (kOperand == Operand::kA ? Shape::kRow : Shape::kColumn); -+ -+ static int const kWarpShapeInner = -+ (kOperand == Operand::kA ? Shape::kColumn : Shape::kRow); -+ -+ -+ /// Rounded up instruction counts -+ using InstructionCount = MatrixShape< -+ Shape::kRow / InstructionShape::kRow, -+ Shape::kColumn / InstructionShape::kColumn -+ >; -+ -+ /// Rounded up tile dimensions -+ using WarpShapeDivisible = MatrixShape< -+ InstructionCount::kRow * InstructionShape::kRow, -+ InstructionCount::kColumn * InstructionShape::kColumn -+ >; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array< -+ Element, -+ WarpShapeDivisible::kRow * WarpShapeDivisible::kColumn / kThreads -+ >; -+ -+ /// Memory access type -+ using AccessType = AlignedArray; -+ -+private: -+ -+ /// Underlying tensor reference -+ TensorRef ref_; -+ -+ /// Extent of tensor -+ MatrixCoord extent_; -+ -+ /// Origin -+ MatrixCoord origin_; -+ -+ /// Used to conditionally enable extents checking -+ bool divisible_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIteratorCanonical(): divisible_(true) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIteratorCanonical( -+ TensorRef const &ref, -+ int lane_id -+ ): ref_(ref), extent_(Shape::kRow, Shape::kColumn), divisible_(true) { -+ -+ if (kOperand == Operand::kA) { -+ origin_ = MatrixCoord(lane_id / 4, (lane_id % 4) * kElementsPerAccess); -+ } -+ else { -+ origin_ = MatrixCoord((lane_id % 4) * kElementsPerAccess, lane_id / 4); -+ } -+ -+ ref_.add_coord_offset(origin_); -+ } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIteratorCanonical( -+ TensorRef const &ref, -+ TensorCoord extent, -+ int lane_id -+ ): ref_(ref), extent_(extent), divisible_(false) { -+ -+ if (kOperand == Operand::kA) { -+ origin_ = MatrixCoord(lane_id / 4, (lane_id % 4) * kElementsPerAccess); -+ } -+ else { -+ origin_ = MatrixCoord((lane_id % 4) * kElementsPerAccess, lane_id / 4); -+ } -+ -+ ref_.add_coord_offset(origin_); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIteratorCanonical &add_pointer_offset(LongIndex offset) { -+ -+ ref_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIteratorCanonical &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); -+ origin_ += coord_offset; -+ -+ ref_.add_coord_offset(coord_offset); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIteratorCanonical & operator++() { -+ -+ if (kOperand == Operand::kA) { -+ add_tile_offset({0, 1}); -+ } -+ else { -+ add_tile_offset({1, 0}); -+ } -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIteratorCanonical & operator--() { -+ -+ if (kOperand == Operand::kA) { -+ add_tile_offset({0, -1}); -+ } -+ else { -+ add_tile_offset({-1, 0}); -+ } -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIteratorCanonical & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIteratorCanonical & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ -+ int const kWarpShapeDivisibleInner = -+ (kOperand == Operand::kA ? WarpShapeDivisible::kColumn : WarpShapeDivisible::kRow); -+ -+ // Take advantage of Tensor Op's 8 x 4T access pattern -+ int const kAccessesInner = (kWarpShapeDivisibleInner / kElementsPerAccess) / 4; -+ -+ AccessType *access_ptr = reinterpret_cast(&frag); -+ -+ if (kOperand == Operand::kA) { -+ int const kTilesPerInstruction = InstructionShape::kRow / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int inst_m_idx = 0; inst_m_idx < InstructionCount::kRow; ++inst_m_idx) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int access_m_idx = 0; access_m_idx < kTilesPerInstruction; ++access_m_idx) { -+ int access_idx = -+ access_m_idx + kTilesPerInstruction * (inner_idx + kAccessesInner * inst_m_idx); -+ -+ MatrixCoord offset( -+ access_m_idx * 8 + inst_m_idx * InstructionShape::kRow, -+ inner_idx * 4 * kElementsPerAccess); -+ -+ MatrixCoord access_coord = origin_ + offset; -+ -+ if (divisible_ || -+ (access_coord.row() < extent_.row() && access_coord.column() < extent_.column())) { -+ -+ access_ptr[access_idx] = *reinterpret_cast( -+ ref_.data() + ref_.offset(offset)); -+ } -+ else { -+ AccessType zero; -+ zero.clear(); -+ access_ptr[access_idx] = zero; -+ } -+ } -+ } -+ } -+ } -+ else { -+ CUTLASS_PRAGMA_UNROLL -+ for (int inst_n_idx = 0; inst_n_idx < InstructionCount::kColumn; ++inst_n_idx) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) { -+ int access_idx = inner_idx + kAccessesInner * inst_n_idx; -+ -+ MatrixCoord offset( -+ inner_idx * 4 * kElementsPerAccess, -+ inst_n_idx * 8); -+ -+ MatrixCoord access_coord = origin_ + offset; -+ -+ if (divisible_ || -+ (access_coord.row() < extent_.row() && access_coord.column() < extent_.column())) { -+ -+ access_ptr[access_idx] = *reinterpret_cast( -+ ref_.data() + ref_.offset(offset)); -+ } -+ else { -+ AccessType zero; -+ zero.clear(); -+ access_ptr[access_idx] = zero; -+ } -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ -+ load_with_pointer_offset(frag, byte_offset * 8 / sizeof_bits::value); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ -+ TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); -+ -+ load_with_pointer_offset(frag, ref_.offset(coord_offset)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ -+ TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); -+ -+ load_with_pointer_offset(frag, ref_.offset(coord_offset) + pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ -+ TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); -+ -+ load_with_pointer_offset(frag, ref_.offset(coord_offset) + byte_offset * 8 / sizeof_bits::value); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ // no operation -+ } -+}; -+ -+/// Wrapper for ColumnMajor -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::ColumnMajor, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::ColumnMajor; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Underlying tile iterator implementation -+ using Base = MmaTensorOpMultiplicandTileIteratorCanonical< -+ Shape, kOperand, Element, -+ layout::ColumnMajor, -+ InstructionShape, -+ kOpDelta, kThreads, PartitionsK_>; -+ -+ public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+private: -+ -+ /// Underlying tile iterator -+ Base iterator_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): iterator_({ref.data(), ref.stride()}, lane_id) { -+ } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ TensorCoord const & extent, -+ int lane_id -+ ): iterator_({ref.data(), ref.stride()}, extent, lane_id) { -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator++() { -+ -+ ++iterator_; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator--() { -+ -+ --iterator_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(PitchLinearCoord(tile_offset.row(), tile_offset.column())); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-PitchLinearCoord(tile_offset.row(), tile_offset.column())); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ iterator_.load(frag); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset( -+ frag, -+ {tile_offset.contiguous(), tile_offset.strided()}, -+ byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+ -+/// Wrapper for RowMajor -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Identifies A or B multiplicand -+ Operand Operand_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Interval between adjacent *MMA instructions (in units of MMA -+ /// instructions) -+ int OpDelta_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class MmaTensorOpMultiplicandTileIterator< -+ Shape_, Operand_, Element_, -+ cutlass::layout::RowMajor, -+ InstructionShape_, OpDelta_, 32, PartitionsK_> { -+ public: -+ -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand_; -+ -+ static_assert(kOperand == Operand::kA || kOperand== Operand::kB, -+ "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::RowMajor; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Underlying tile iterator implementation -+ using Base = MmaTensorOpMultiplicandTileIteratorCanonical< -+ Shape, kOperand, Element, -+ layout::RowMajor, -+ InstructionShape, -+ kOpDelta, kThreads, PartitionsK_>; -+ -+ public: -+ -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+private: -+ -+ /// Underlying tile iterator -+ Base iterator_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): iterator_({ref.data(), ref.stride()}, lane_id) { -+ } -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator( -+ TensorRef const &ref, -+ TensorCoord const &extent, -+ int lane_id -+ ): iterator_({ref.data(), ref.stride()}, extent, lane_id) { -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator++() { -+ -+ ++iterator_; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator--() { -+ -+ --iterator_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(PitchLinearCoord(tile_offset.row(), tile_offset.column())); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-PitchLinearCoord(tile_offset.row(), tile_offset.column())); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ iterator_.load(frag); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ // TODO -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset( -+ frag, -+ {tile_offset.contiguous(), tile_offset.strided()}, -+ byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sparse.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sparse.h -new file mode 100644 -index 0000000..f7370a6 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sparse.h -@@ -0,0 +1,380 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines iterators to load sparse meta data used by warp-level matrix multiply operations -+ targeting Sparse Tensor Cores. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -+ -+#include "cutlass/platform/platform.h" -+#include "cutlass/fast_math.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of A elements -+ typename Element_, -+ /// Layout of operand -+ typename Layout_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Delta between *MMA operations (in units of *MMA operations, concept: -+ /// MatrixShape) -+ int OpDelta_, -+ /// Number of threads participating in one matrix operation -+ int Threads, -+ /// Number of partitions along K dimension -+ int PartitionsK_ = 1> -+class SparseMmaTensorOpMetaTileIterator { -+ public: -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = Layout_; -+ -+ /// Shape of one matrix product operation (concept: GemmShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: -+ /// MatrixShape) -+ static int const kOpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// Number of partitions along K dimension -+ static int const kPartitionsK = PartitionsK_; -+ -+ static int const kSparse = 2; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ struct Policy { -+ static_assert( -+ !(Shape::kColumn % InstructionShape::kColumn), -+ "Shape of warp-level Mma must be divisible by operator shape."); -+ -+ static int const kElementsPerAccess = 128 / sizeof_bits::value; -+ -+ // Determine number of elements along outer dimension per individual LDSM op -+ static int const kLdsmOpOuter = InstructionShape::kColumn; -+ static int const kLdsmOpInner = 8 * kElementsPerAccess / kLdsmOpOuter; -+ -+ static_assert(!(Shape::kColumn % kLdsmOpOuter), -+ "Shape of warp-level mma must be divisible by LDSM's " -+ "fundamental tile size."); -+ -+ static_assert(!(Shape::kRow % kLdsmOpInner), -+ "Shape of warp-level mma must be divisible by LDSM's " -+ "fundamental tile size."); -+ -+ /// Shape of one individual LDSM instruction -+ static int const LdsmShapeColumn = -+ InstructionShape::kColumn / kLdsmOpOuter; -+ static int const LdsmShapeRow = -+ ((4 / LdsmShapeColumn * kLdsmOpInner) > Shape::kRow) -+ ? (Shape::kRow / kLdsmOpInner) -+ : (4 / LdsmShapeColumn); -+ using LdsmShape = -+ layout::PitchLinearShape; -+ -+ /// Number and arrangement of LDSM instructions -+ using LdsmIterations = layout::PitchLinearShape< -+ Shape::kRow / kLdsmOpInner / LdsmShapeRow, -+ 1>; -+ -+ /// Number of groups for each tile -+ static int const kGroupsPerTile = -+ Shape::kColumn / InstructionShape::kColumn; -+ }; -+ -+ private: -+ /// Not working on this feature at the moment. -+ static_assert(kOpDelta == 1, -+ "Alternative arrangements not supported at present."); -+ -+ /// Pointer type used for accesses -+ using AccessType = Array; -+ -+ public: -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = -+ Array; -+ -+ private: -+ -+ /// Layout object storing stride values -+ Index stride_; -+ -+ /// Shared memory base pointers - not advanced -+ AccessType const *pointer_; -+ -+ /// Byte offset incremented as iterator advances -+ Index byte_offset_; -+ -+ /// Internal counter used to determine when to increment byte offset and when -+ /// to XOR it -+ int k_group_idx_; -+ -+ public: -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ SparseMmaTensorOpMetaTileIterator() -+ : pointer_(nullptr), -+ stride_(0), -+ byte_offset_(0), -+ k_group_idx_(0) {} -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ SparseMmaTensorOpMetaTileIterator(TensorRef const &ref, int lane_id) -+ : pointer_(reinterpret_cast(ref.data())), -+ stride_(ref.stride(0) / Policy::kElementsPerAccess), -+ byte_offset_(0), -+ k_group_idx_(0) { -+ -+ int access_contiguous = (lane_id % (Shape::kRow / Policy::kElementsPerAccess)); -+ int access_strided = (lane_id / (Shape::kRow / Policy::kElementsPerAccess)); -+ -+ byte_offset_ = (access_contiguous + access_strided * stride_) * -+ sizeof_bits::value * Policy::kElementsPerAccess / 8; -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_DEVICE -+ SparseMmaTensorOpMetaTileIterator &add_pointer_offset(LongIndex offset) { -+ byte_offset_ += offset * sizeof_bits::value / 8; -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_DEVICE -+ SparseMmaTensorOpMetaTileIterator &add_tile_offset( -+ TensorCoord const &tile_offset) { -+ int offset = tile_offset.row() * Shape::kRow + -+ tile_offset.column() * InstructionShape::kColumn * stride_ * -+ Policy::kElementsPerAccess; -+ -+ add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ SparseMmaTensorOpMetaTileIterator &operator++() { -+ add_tile_offset({0, 1}); -+ -+ if (kPartitionsK > 1) { -+ ++k_group_idx_; -+ // Jump to next stage -+ if (k_group_idx_ == Policy::kGroupsPerTile) { -+ k_group_idx_ = 0; -+ add_tile_offset( -+ {0, ((kPartitionsK - 1) * Policy::kGroupsPerTile)}); -+ } -+ } -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ SparseMmaTensorOpMetaTileIterator &operator--(){ -+ byte_offset_ -= stride_ * InstructionShape::kColumn * -+ sizeof_bits::value * Policy::kElementsPerAccess / -+ 8; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE SparseMmaTensorOpMetaTileIterator & -+ operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE -+ SparseMmaTensorOpMetaTileIterator &operator-=( -+ TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { load_with_byte_offset(frag, 0); } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset in units of bytes -+ Index byte_offset) const { -+ Array *fetch_ptr = -+ reinterpret_cast *>(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < Policy::LdsmIterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < Policy::LdsmIterations::kContiguous; ++c) { -+ -+ int access_idx = c + s * Policy::LdsmIterations::kContiguous; -+ -+ AccessType const *source_ptr = -+ pointer_ + -+ Policy::LdsmShape::kContiguous * Policy::kLdsmOpInner * c + -+ Policy::LdsmShape::kStrided * s * stride_; -+ -+ char const *source_byte_ptr = reinterpret_cast(source_ptr) + -+ byte_offset + byte_offset_; -+ -+ cutlass::arch::ldsm( -+ fetch_ptr[access_idx], source_byte_ptr); -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ load_with_byte_offset(frag, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ load_with_byte_offset(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ Index pointer_offset = -+ tile_offset.contiguous() * Shape::kRow / Layout::kElementsPerAccess + -+ tile_offset.strided() * InstructionShape::kColumn * stride_; -+ -+ byte_offset += sizeof(AccessType) * pointer_offset; -+ -+ load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ // no op -+ } -+}; -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_wmma.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_wmma.h -new file mode 100644 -index 0000000..d841d2b ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_wmma.h -@@ -0,0 +1,805 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. -+*/ -+ -+#pragma once -+ -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/arch/wmma.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+ -+#include "cutlass/wmma_array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -+ -+#include "cutlass/platform/platform.h" -+#include "cutlass/fast_math.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+//////////////////////////////////////////////////////////////////////////////// -+template < -+ ///< Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Operand identity (A or B) -+ Operand Operand, -+ /// Data type of operand -+ typename Element_, -+ /// Layout of operand -+ typename Layout_, -+ /// Delta between *MMA operations (in units of *WMMA operations, concept:MatrixShape) -+ int OpDelta_, -+ /// Number of threads participating in one matrix operation -+ int Threads, -+ /// Shape of the warp in units of thread (concept: MmaTensorOpPolicy) -+ typename Policy_> -+class MmaTensorOpWmmaMultiplicandTileIterator; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// This tile iterator is specialized for 32-thread WMMA operation. -+/// It uses nvcuda::wmma::load_matrix_sync to load from shared -+/// memory and therefore must be initialized with a TensorRef to shared memory. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+//////////////////////////////////////////////////////////////////////////////// -+template < -+ ///< Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of elements -+ typename Element_, -+ /// Layout of operand -+ typename Layout_, -+ /// Interval between adjacent *WMMA instructions (in units of WMMA instructions) -+ int OpDelta_, -+ /// Shape of the warp in units of thread (concept: MmaTensorOpPolicy) -+ typename Policy_> -+class MmaTensorOpWmmaMultiplicandTileIterator< -+ Shape_, Operand::kA, Element_, Layout_, -+ OpDelta_, 32, Policy_> { -+ public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kA; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = Layout_; -+ -+ /// Delta between *WMMA operations -+ static int const kOpDelta = OpDelta_; -+ -+ /// Wmma Operator information and operation delta -+ using Policy = Policy_; -+ -+ -+ // -+ // Derived quantities -+ // -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Stride Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Native Wmma shape for operand A (concept MatrixShape) -+ using WmmaShape = MatrixShape< -+ Policy::Operator::Shape::kM, -+ Policy::Operator::Shape::kK -+ >; -+ -+ /// Map cutlass dataype to nvcuda::wmma datatype -+ using WmmaDataType = typename cutlass::arch::CutlassToWmmaDataType::Type; -+ -+ /// Shape of individual WMMA load / stores for operand A -+ using Iterations = MatrixShape< -+ Shape::kRow / WmmaShape::kRow, -+ 1 -+ >; -+ -+ /// Fragment object holding a warps part -+ using Fragment = WmmaFragmentArray; -+ -+ -+ ////////////////////////////////////////////////////////////////////////////////////////////////////// -+ /// statically assert this specialization -+ ///////////////////////////////////////////////////////////////////////////////////////////////////// -+ /// This iterator is specalized for Operand A -+ static_assert(kOperand == Operand::kA, -+ "MmaTensorOpWmmaMultiplicandTileIterator may only be instantiated for A operands to warp-level Mma."); -+ -+ /// Supported memory layouts -+ static_assert( -+ platform::is_same::value || -+ platform::is_same::value, -+ "Supported list of memory layouts for WMMA are: RowMajor, ColumnMajor"); -+ -+ /// Not working on this feature at the moment. -+ static_assert(kOpDelta == 1, -+ "Alternative arrangements not supported at present."); -+ -+ ///////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+private: -+ -+ /// Shared memory base pointers - not advanced -+ char const *pointer_; -+ -+ /// Byte offset into shared memory - advanced -+ Index byte_offset_; -+ -+ /// Stride in units of number of elements -+ StrideIndex stride_; -+ -+ /// Layout of shared memory -+ Layout layout_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpWmmaMultiplicandTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ MmaTensorOpWmmaMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): pointer_(reinterpret_cast(ref.data())), byte_offset_(0), stride_(ref.stride(0)), layout_(ref.stride(0)) { -+ -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_DEVICE -+ MmaTensorOpWmmaMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ byte_offset_ += (offset * sizeof_bits::value) / 8; -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpWmmaMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ Index elements_offset = layout_({tile_offset.row() * Shape::kRow, tile_offset.column() * WmmaShape::kColumn}); -+ -+ byte_offset_ += (elements_offset * sizeof_bits::value) / 8; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ MmaTensorOpWmmaMultiplicandTileIterator & operator++() { -+ -+ Index elements_offset = layout_({0, WmmaShape::kColumn}); -+ -+ byte_offset_ += (elements_offset * sizeof_bits::value) / 8; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the opposite of the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpWmmaMultiplicandTileIterator & operator--() { -+ -+ Index elements_offset = layout_({0, WmmaShape::kColumn}); -+ -+ byte_offset_ -= (elements_offset * sizeof_bits::value) / 8; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpWmmaMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpWmmaMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load_with_byte_offset(Fragment &frag, Index byte_offset) const { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Iterations::kColumn; ++k) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Iterations::kRow; ++m) { -+ -+ Index load_byte_offset = layout_({m * WmmaShape::kRow, k * WmmaShape::kColumn}) * sizeof_bits::value / 8; -+ -+ const WmmaDataType *ptr = reinterpret_cast(pointer_ + byte_offset_ + load_byte_offset + byte_offset); -+ -+ nvcuda::wmma::load_matrix_sync(frag[m], ptr, stride_); -+ -+ } -+ } -+ } -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_byte_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store_with_byte_offset(Fragment const &frag, Index byte_offset) const { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Iterations::kColumn; ++k) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Iterations::kRow; ++m) { -+ -+ Index store_byte_offset = layout_({m * WmmaShape::kRow, k * WmmaShape::kColumn}) * sizeof_bits::value / 8; -+ -+ WmmaDataType *ptr = reinterpret_cast(pointer_ + byte_offset_ + store_byte_offset + byte_offset); -+ -+ nvcuda::wmma::store_matrix_sync(ptr, frag[m], stride_); -+ -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) const { -+ store_with_byte_offset(frag, 0); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ // no operation here -+ } -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// This tile iterator is specialized for 32-thread WMMA operation. -+/// It uses nvcuda::wmma::load_matrix_sync to load from shared -+/// memory and therefore must be initialized with a TensorRef to shared memory. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ ///< Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of elements -+ typename Element_, -+ /// Layout of operand -+ typename Layout_, -+ /// Interval between adjacent *WMMA instructions (in units of WMMA instructions) -+ int OpDelta_, -+ /// Shape of the warp in units of thread (concept: MmaTensorOpPolicy) -+ typename Policy_> -+class MmaTensorOpWmmaMultiplicandTileIterator< -+ Shape_, Operand::kB, Element_, Layout_, -+ OpDelta_, 32, Policy_> { -+ public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Operand tag -+ static Operand const kOperand = Operand::kB; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = Layout_; -+ -+ /// Delta between *WMMA operations -+ static int const kOpDelta = OpDelta_; -+ -+ /// Wmma Operator information and operation delta -+ using Policy = Policy_; -+ -+ -+ // -+ // Derived quantities -+ // -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Stride Index type -+ using StrideIndex = typename TensorRef::Layout::Stride::Index; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Native Wmma shape (concept MatrixShape) -+ using WmmaShape = MatrixShape< -+ Policy::Operator::Shape::kK, -+ Policy::Operator::Shape::kN -+ >; -+ -+ /// Map cutlass dataype to nvcuda::wmma datatype -+ using WmmaDataType = typename cutlass::arch::CutlassToWmmaDataType::Type; -+ -+ /// Shape of individual WMMA load / stores for operand B -+ using Iterations = MatrixShape< -+ 1, -+ Shape::kColumn / WmmaShape::kColumn -+ >; -+ -+ /// Fragment object holding a warps part -+ using Fragment = WmmaFragmentArray; -+ -+ -+ ////////////////////////////////////////////////////////////////////////////////////////////////////// -+ /// statically asserts this specialization -+ ///////////////////////////////////////////////////////////////////////////////////////////////////// -+ /// This iterator is specalized for Operand B -+ static_assert(kOperand == Operand::kB, -+ "MmaTensorOpWmmaMultiplicandTileIterator may only be instantiated for B operands to warp-level Mma."); -+ -+ /// Supported memory layouts -+ static_assert( -+ platform::is_same::value || -+ platform::is_same::value, -+ "Supported list of memory layouts for WMMA are: RowMajor, ColumnMajor"); -+ -+ /// Not working on this feature at the moment. -+ static_assert(kOpDelta == 1, -+ "Alternative arrangements not supported at present."); -+ -+ ///////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+private: -+ -+ /// Shared memory base pointers - not advanced -+ char const *pointer_; -+ -+ /// Byte offset into shared memory - advanced -+ Index byte_offset_; -+ -+ /// Stride in units of number of elements -+ StrideIndex stride_; -+ -+ /// Layout of shared memory -+ Layout layout_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpWmmaMultiplicandTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ MmaTensorOpWmmaMultiplicandTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): pointer_(reinterpret_cast(ref.data())), byte_offset_(0), stride_(ref.stride(0)), layout_(ref.stride(0)) { -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_DEVICE -+ MmaTensorOpWmmaMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { -+ -+ byte_offset_ += (offset * sizeof_bits::value) / 8; -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpWmmaMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ Index elements_offset = layout_({tile_offset.row() * WmmaShape::kRow, tile_offset.column() * Shape::kColumn}); -+ -+ byte_offset_ += (elements_offset * sizeof_bits::value) / 8; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ MmaTensorOpWmmaMultiplicandTileIterator & operator++() { -+ -+ Index elements_offset = layout_({WmmaShape::kRow, 0}); -+ -+ byte_offset_ += (elements_offset * sizeof_bits::value) / 8; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the opposite of the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpWmmaMultiplicandTileIterator & operator--() { -+ -+ Index elements_offset = layout_({WmmaShape::kRow, 0}); -+ -+ byte_offset_ -= (elements_offset * sizeof_bits::value) / 8; -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpWmmaMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpWmmaMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load_with_byte_offset(Fragment &frag, Index byte_offset) const { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Iterations::kRow; ++k) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Iterations::kColumn; ++n) { -+ -+ Index load_byte_offset = layout_({k * WmmaShape::kRow, n * WmmaShape::kColumn}) * sizeof_bits::value / 8; -+ -+ const WmmaDataType *ptr = reinterpret_cast(pointer_ + byte_offset_ + load_byte_offset + byte_offset); -+ -+ nvcuda::wmma::load_matrix_sync(frag[n], ptr, stride_); -+ } -+ } -+ } -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_byte_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store_with_byte_offset(Fragment const &frag, Index byte_offset) const { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < Iterations::kRow; ++k) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Iterations::kColumn; ++n) { -+ -+ Index store_byte_offset = layout_({k * WmmaShape::kRow, n * WmmaShape::kColumn}) * sizeof_bits::value / 8; -+ -+ WmmaDataType *ptr = reinterpret_cast(pointer_ + byte_offset_ + store_byte_offset + byte_offset); -+ -+ nvcuda::wmma::store_matrix_sync(ptr, frag[n], stride_); -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) const { -+ store_with_byte_offset(frag, 0); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ // no operation here -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+template < -+ ///< Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Element type -+ typename Element_, -+ /// Layout of operand in memory -+ typename Layout_, -+ /// Interval between adjacent *WMMA instructions (in units of WMMA instructions, concept: MatrixShape) -+ typename OpDelta_, -+ /// Shape of the warp in units of thread (concept: MmaTensorOpPolicy) -+ typename Policy_> -+class MmaTensorOpWmmaAccumulatorTileIterator; -+ -+//////////////////////////////////////////////////////////////////////////////// -+/// This tile iterator is specialized for 32-thread WMMA operation. -+/// It uses nvcuda::wmma::store_matrix_sync to load from shared -+/// memory and therefore must be initialized with a TensorRef to shared memory. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept | -+/// WriteableRandomAccessContiguousTileIteratorConcept -+/// -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ ///< Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of elements -+ typename Element_, -+ /// Layout of operand in memory -+ typename Layout_, -+ /// Interval between adjacent *WMMA instructions (in units of WMMA instructions) -+ typename OpDelta_, -+ /// Shape of the warp in units of thread (concept: MmaTensorOpPolicy) -+ typename Policy_> -+class MmaTensorOpWmmaAccumulatorTileIterator -+{ -+ public: -+ -+ /// Shape of tile to load (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = Layout_; -+ -+ /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) -+ using OpDelta = OpDelta_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// Wmma Operator information and operation delta -+ using Policy = Policy_; -+ -+ -+ // -+ // Derived quantities -+ // -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Native Wmma shape (concept MatrixShape) -+ using WmmaShape = MatrixShape< -+ Policy::Operator::Shape::kM, -+ Policy::Operator::Shape::kN -+ >; -+ -+ /// Map cutlass dataype to nvcuda::wmma datatype -+ using WmmaDataType = typename cutlass::arch::CutlassToWmmaDataType::Type; -+ -+ /// Map cutlass::layout to nvuda::wmma::layout_t enum -+ static nvcuda::wmma::layout_t const WmmaLayout = cutlass::arch::CutlassToWmmaLayout::value; -+ -+ /// Shape of individual WMMA load / stores for accumulator -+ using Iterations = MatrixShape< -+ Shape::kRow / WmmaShape::kRow, -+ Shape::kColumn / WmmaShape::kColumn -+ >; -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = WmmaFragmentArray; -+ -+ ////////////////////////////////////////////////////////////////////////////////////////////////////// -+ /// statically asserts this specialization -+ ///////////////////////////////////////////////////////////////////////////////////////////////////// -+ /// Supported layouts -+ static_assert( -+ platform::is_same::value || -+ platform::is_same::value, -+ "Supported list of memory layouts for WMMA are: RowMajor, ColumnMajor"); -+ -+private: -+ -+ /// Internal reference -+ cutlass::TensorRef ref_; -+ -+public: -+ -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpWmmaAccumulatorTileIterator() { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ MmaTensorOpWmmaAccumulatorTileIterator( -+ TensorRef const &ref, -+ int lane_id -+ ): ref_(ref) { } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_DEVICE -+ MmaTensorOpWmmaAccumulatorTileIterator &add_pointer_offset(LongIndex offset) { -+ ref_.add_pointer_offset(offset); -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpWmmaAccumulatorTileIterator &add_tile_offset(TensorCoord const &tile_offset) { -+ ref_.add_coord_offset({tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn}); -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ MmaTensorOpWmmaAccumulatorTileIterator & operator++() { -+ ref_.add_coord_offset({Shape::kRow, 0}); -+ return *this; -+ } -+ -+ /// Advances the iterator along the opposite of the advance dimension -+ CUTLASS_HOST_DEVICE -+ MmaTensorOpWmmaAccumulatorTileIterator & operator--() { -+ ref_.add_coord_offset({-Shape::kRow, 0}); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpWmmaAccumulatorTileIterator & operator+=(TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ MmaTensorOpWmmaAccumulatorTileIterator & operator-=(TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Iterations::kRow; ++m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Iterations::kColumn; ++n) { -+ -+ const WmmaDataType * ptr = reinterpret_cast (ref_.data() + ref_.offset({m * WmmaShape::kRow, n * WmmaShape::kColumn}) + pointer_offset); -+ -+ nvcuda::wmma::load_matrix_sync(frag[m * Iterations::kColumn + n], ptr, ref_.stride()[0], WmmaLayout); -+ -+ } -+ } -+ } -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < Iterations::kRow; ++m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < Iterations::kColumn; ++n) { -+ -+ WmmaDataType * ptr = reinterpret_cast (ref_.data() + ref_.offset({m * WmmaShape::kRow, n * WmmaShape::kColumn}) + pointer_offset); -+ -+ nvcuda::wmma::store_matrix_sync(ptr, frag[m * Iterations::kColumn + n], ref_.stride()[0], WmmaLayout); -+ } -+ } -+ } -+ -+ /// Stores a fragment to memory at the location pointed to by the iterator -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) const { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ // no operation here -+ } -+}; -+ -+ -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if defined(CUTLASS_ARCH_WMMA_ENABLED) -+ -+ -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_wmma.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_wmma.h -new file mode 100644 -index 0000000..c3954f3 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_tensor_op_wmma.h -@@ -0,0 +1,223 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing warp-level matrix multiply-accumulate operations targeting -+ Tensor Cores. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/arch/wmma.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+ -+#include "cutlass/wmma_array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/arch/mma_sm75.h" -+#include "cutlass/arch/mma_sm80.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_policy.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_wmma.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///< Structure to compute the matrix product targeting CUDA cores via WMMA. -+template < -+ ///< Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ ///< Data type of A elements -+ typename ElementA_, -+ ///< Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ ///< Data type of B elements -+ typename ElementB_, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ ///< Element type of C matrix -+ typename ElementC_, -+ ///< Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ ///< Policy describing warp-level Wmma operation (concept: MmaTensorOpPolicy) -+ typename Policy_, -+ ///< Number of partitions along K dimension -+ int PartitionsK_ = 1, -+ ///< Used for partial specialization -+ typename Enable = bool -+> -+class MmaTensorOpWmma { -+public: -+ ///< Shape of warp-level matrix operation (concept: GemmShape) -+ using Shape = Shape_; -+ -+ ///< Data type of multiplicand A -+ using ElementA = ElementA_; -+ -+ ///< Layout of multiplicand A -+ using LayoutA = LayoutA_; -+ -+ ///< Data type of multiplicand B -+ using ElementB = ElementB_; -+ -+ ///< Layout of multiplicand B -+ using LayoutB = LayoutB_; -+ -+ ///< Data type of accumulator matrix C -+ using ElementC = ElementC_; -+ -+ ///< Layout of accumulator matrix C -+ using LayoutC = LayoutC_; -+ -+ /// Shape of the warp in units of thread (concept: MmaTensorOpPolicy) -+ using Policy = Policy_; -+ -+ /// Underlying instruction shape -+ using InstructionShape = typename Policy::Operator::Shape; -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ using ArchMmaOperator = typename Policy::Operator; -+ -+ /// Indicates math operator -+ using MathOperator = typename ArchMmaOperator::Operator; -+ -+ /// Underlying architecture tag -+ using ArchTag = typename Policy::Operator::ArchTag; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = ComplexTransform::kNone; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = ComplexTransform::kNone; -+ -+ /// Indicates class of matrix operator -+ using OperatorClass = arch::OpClassWmmaTensorOp; -+ -+ /// Number of threads participating in warp-level matrix product -+ static int const kThreadCount = 32; -+ -+ /// Number of partitions along K dimension -+ static int const kPartitionsK = PartitionsK_; -+ -+public: -+ -+ /// Iterates over the A operand in memory -+ using IteratorA = MmaTensorOpWmmaMultiplicandTileIterator< -+ MatrixShape, Operand::kA, ElementA, LayoutA, -+ Policy::OpDelta::kRow, kThreadCount, Policy>; -+ -+ /// Storage for A tile -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Iterates over the B operand in memory -+ using IteratorB = MmaTensorOpWmmaMultiplicandTileIterator< -+ MatrixShape, Operand::kB, ElementB, LayoutB, -+ Policy::OpDelta::kRow, kThreadCount, Policy>; -+ -+ /// Storage for B tile -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Iterates over the C operand in memory -+ using IteratorC = MmaTensorOpWmmaAccumulatorTileIterator< -+ MatrixShape, ElementC, LayoutC, -+ typename Policy::OpDelta, Policy>; -+ -+ /// Storage for C tile -+ using FragmentC = typename IteratorC::Fragment; -+ -+private: -+ -+ static_assert( -+ !(Shape::kM % Policy::Operator::Shape::kM) && -+ !(Shape::kN % Policy::Operator::Shape::kN), -+ "Shape of warp-level Wmma must be divisible by operator shape (wmma native size)"); -+ -+ /// Number of wmma operations performed -+ using WmmaIterations = MatrixShape< -+ Shape::kM / Policy::Operator::Shape::kM, -+ Shape::kN / Policy::Operator::Shape::kN -+ >; -+ -+public: -+ -+ /// Underlying matrix multiply operator (concept: cutlass::arch::Wmma) -+ typename Policy::Operator wmma; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_DEVICE -+ MmaTensorOpWmma() {} -+ -+ /// Performs a warp-level matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &D, -+ FragmentA const &A, -+ FragmentB const &B, -+ FragmentC const &C) const { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < WmmaIterations::kColumn; ++n) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < WmmaIterations::kRow; ++m) { -+ -+ // accumulate wmma mma -+ wmma(D[m * WmmaIterations::kColumn + n], A[m], B[n], C[m * WmmaIterations::kColumn + n]); -+ } -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+#endif // if defined(CUTLASS_ARCH_WMMA_ENABLED) -+ -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/mma_with_reduction_tensor_op.h b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_with_reduction_tensor_op.h -new file mode 100644 -index 0000000..9957967 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/mma_with_reduction_tensor_op.h -@@ -0,0 +1,449 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing warp-level matrix multiply-accumulate operations targeting -+ Tensor Cores. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/platform/platform.h" -+ -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/arch/mma_sm75.h" -+#include "cutlass/arch/mma_sm80.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_policy.h" -+#include "cutlass/gemm/warp/mma_tensor_op.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape_, -+ /// Data type of A elements -+ typename ElementA_, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA_, -+ /// Data type of B elements -+ typename ElementB_, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB_, -+ /// Element type of C matrix -+ typename ElementC_, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC_, -+ /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) -+ typename Policy_, -+ /// Reduce operand A or B along K dimension -+ bool ReduceKForA_, -+ /// Number of partitions along K dimension -+ int PartitionsK_ = 1, -+ /// Store the accumulators in row major or column major. Row major is used -+ /// when output layout is interleaved. -+ bool AccumulatorsInRowMajor = false, -+ /// Used for partial specialization -+ typename Enable = bool -+> -+class MmaWithReductionTensorOp { -+public: -+ /// Shape of warp-level matrix operation (concept: GemmShape) -+ using Shape = Shape_; -+ -+ /// Data type of multiplicand A -+ using ElementA = ElementA_; -+ -+ /// Layout of multiplicand A -+ using LayoutA = LayoutA_; -+ -+ /// Data type of multiplicand B -+ using ElementB = ElementB_; -+ -+ /// Layout of multiplicand B -+ using LayoutB = LayoutB_; -+ -+ /// Data type of accumulator matrix C -+ using ElementC = ElementC_; -+ -+ /// Layout of accumulator matrix C -+ using LayoutC = LayoutC_; -+ -+ /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) -+ using Policy = Policy_; -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ using ArchMmaOperator = typename Policy::Operator; -+ -+ /// Indicates math operator -+ using MathOperator = typename ArchMmaOperator::Operator; -+ -+ /// Architecture tag from underlying instruction -+ using ArchTag = typename ArchMmaOperator::ArchTag; -+ -+ /// Indicates class of matrix operator -+ using OperatorClass = arch::OpClassTensorOp; -+ -+ /// Shape of underlying instruction -+ using InstructionShape = typename ArchMmaOperator::Shape; -+ -+ /// Complex transform on A operand -+ static ComplexTransform const kTransformA = ComplexTransform::kNone; -+ -+ /// Complex transform on B operand -+ static ComplexTransform const kTransformB = ComplexTransform::kNone; -+ -+ /// Number of threads participating in warp-level matrix product -+ static int const kThreadCount = 32; -+ -+ /// Number of partitions along K dimension -+ static int const kPartitionsK = PartitionsK_; -+ -+ static bool const kReduceKForA = ReduceKForA_; -+ -+ static_assert(platform::is_same::value || -+ platform::is_same::value, -+ "ElementA needs to be fp16 or bf16."); -+ -+ static_assert(platform::is_same::value || -+ platform::is_same::value, -+ "ElementB needs to be fp16 or bf16."); -+ -+ static_assert(platform::is_same>::value, -+ "Only supports 16x8x16 tensor core instruction."); -+ -+ static_assert(!AccumulatorsInRowMajor, -+ "Only calls tensor core instructions in column major."); -+ -+public: -+ -+ /// Iterates over the A operand in memory -+ using IteratorA = MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, Operand::kA, ElementA, LayoutA, -+ MatrixShape, -+ Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; -+ -+ /// Storage for A tile -+ using FragmentA = typename IteratorA::Fragment; -+ -+ /// Storage for transformed A tile -+ using TransformedFragmentA = -+ Array; -+ -+ /// Iterates over the B operand in memory -+ using IteratorB = MmaTensorOpMultiplicandTileIterator< -+ MatrixShape, Operand::kB, ElementB, LayoutB, -+ MatrixShape, -+ Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; -+ -+ /// Storage for B tile -+ using FragmentB = typename IteratorB::Fragment; -+ -+ /// Storage for transformed B tile -+ using TransformedFragmentB = -+ Array; -+ -+ /// Iterates over the C operand in memory -+ using IteratorC = MmaTensorOpAccumulatorTileIterator< -+ MatrixShape, ElementC, LayoutC, -+ typename ArchMmaOperator::Shape, typename Policy::OpDelta>; -+ -+ /// Storage for C tile -+ using FragmentC = typename IteratorC::Fragment; -+ -+ /// Number of mma operations performed -+ using MmaIterations = MatrixShape< -+ (Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, -+ (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN -+ >; -+ -+ using FragmentReduction = Array; -+ -+public: -+ -+ /// Underlying matrix multiply operator (concept: arch::Mma) -+ ArchMmaOperator mma; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_DEVICE -+ MmaWithReductionTensorOp() {} -+ -+ /// Performs a warp-level matrix multiply-accumulate operation -+ CUTLASS_DEVICE -+ void operator()( -+ FragmentC &D, -+ TransformedFragmentA const &A, -+ TransformedFragmentB const &B, -+ FragmentC const &C, -+ FragmentReduction &gemm_k_reduction -+ ) const { -+ -+ using MmaOperandA = typename ArchMmaOperator::FragmentA; -+ using MmaOperandB = typename ArchMmaOperator::FragmentB; -+ using MmaOperandC = typename ArchMmaOperator::FragmentC; -+ -+ D = C; -+ -+ [[maybe_unused]] MmaOperandA const *ptr_A = reinterpret_cast(&A); -+ [[maybe_unused]] MmaOperandB const *ptr_B = reinterpret_cast(&B); -+ [[maybe_unused]] MmaOperandC *ptr_D = reinterpret_cast(&D); -+ -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) -+ assert(0); -+ #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ // Serpentine visitation order maximizing reuse of Ra -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < MmaIterations::kRow; ++m) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < MmaIterations::kColumn; ++n) { -+ -+ int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); -+ -+ mma(ptr_D[m + n_serpentine * MmaIterations::kRow], -+ ptr_A[m], -+ ptr_B[n_serpentine], -+ ptr_D[m + n_serpentine * MmaIterations::kRow]); -+ -+ if (!kReduceKForA && m == 0) { -+ #if 0 -+ gemm_k_reduction[n_serpentine] += float(B[n_serpentine * 4]); -+ gemm_k_reduction[n_serpentine] += float(B[n_serpentine * 4 + 1]); -+ gemm_k_reduction[n_serpentine] += float(B[n_serpentine * 4 + 2]); -+ gemm_k_reduction[n_serpentine] += float(B[n_serpentine * 4 + 3]); -+ #else -+ uint32_t const *tmp = reinterpret_cast(&B); -+ -+ if (platform::is_same::value) { -+ asm volatile( -+ "{\n\t" -+ " .reg .f16 low, high;\n\t" -+ " .reg .f32 tmp;\n\t" -+ " mov.b32 {low, high}, %1;\n\t" -+ " cvt.f32.f16 tmp, low;\n\t" -+ " add.f32 %0, tmp, %0;\n\t" -+ " cvt.f32.f16 tmp, high;\n\t" -+ " add.f32 %0, tmp, %0;\n\t" -+ " mov.b32 {low, high}, %2;\n\t" -+ " cvt.f32.f16 tmp, low;\n\t" -+ " add.f32 %0, tmp, %0;\n\t" -+ " cvt.f32.f16 tmp, high;\n\t" -+ " add.f32 %0, tmp, %0;\n\t" -+ "}\n\t" -+ : "+f"(gemm_k_reduction[n_serpentine]) -+ : "r"(tmp[n_serpentine * 2]), "r"(tmp[n_serpentine * 2 + 1])); -+ } else if (platform::is_same::value) { -+ asm volatile( -+ "{\n\t" -+ " .reg .f32 tmp;\n\t" -+ " shl.b32 tmp, %1, 16;\n\t" -+ " add.f32 %0, tmp, %0;\n\t" -+ " and.b32 tmp, %1, 0xffff0000;\n\t" -+ " add.f32 %0, tmp, %0;\n\t" -+ " shl.b32 tmp, %2, 16;\n\t" -+ " add.f32 %0, tmp, %0;\n\t" -+ " and.b32 tmp, %2, 0xffff0000;\n\t" -+ " add.f32 %0, tmp, %0;\n\t" -+ "}\n\t" -+ : "+f"(gemm_k_reduction[n_serpentine]) -+ : "r"(tmp[n_serpentine * 2]), "r"(tmp[n_serpentine * 2 + 1])); -+ } else { -+ assert(0); -+ } -+ #endif -+ } -+ -+ if (kReduceKForA && (n == 0)) { -+ #if 0 -+ gemm_k_reduction[m * 2] += float(A[m * 8]); -+ gemm_k_reduction[m * 2] += float(A[m * 8 + 1]); -+ gemm_k_reduction[m * 2] += float(A[m * 8 + 4]); -+ gemm_k_reduction[m * 2] += float(A[m * 8 + 5]); -+ -+ gemm_k_reduction[m * 2 + 1] += float(A[m * 8 + 2]); -+ gemm_k_reduction[m * 2 + 1] += float(A[m * 8 + 3]); -+ gemm_k_reduction[m * 2 + 1] += float(A[m * 8 + 6]); -+ gemm_k_reduction[m * 2 + 1] += float(A[m * 8 + 7]); -+ #else -+ uint32_t const *tmp = reinterpret_cast(&A); -+ -+ if (platform::is_same::value) { -+ asm volatile( -+ "{\n\t" -+ " .reg .f16 low, high;\n\t" -+ " .reg .f32 tmp;\n\t" -+ " mov.b32 {low, high}, %2;\n\t" -+ " cvt.f32.f16 tmp, low;\n\t" -+ " add.f32 %0, tmp, %0;\n\t" -+ " cvt.f32.f16 tmp, high;\n\t" -+ " add.f32 %0, tmp, %0;\n\t" -+ " mov.b32 {low, high}, %3;\n\t" -+ " cvt.f32.f16 tmp, low;\n\t" -+ " add.f32 %1, tmp, %1;\n\t" -+ " cvt.f32.f16 tmp, high;\n\t" -+ " add.f32 %1, tmp, %1;\n\t" -+ " mov.b32 {low, high}, %4;\n\t" -+ " cvt.f32.f16 tmp, low;\n\t" -+ " add.f32 %0, tmp, %0;\n\t" -+ " cvt.f32.f16 tmp, high;\n\t" -+ " add.f32 %0, tmp, %0;\n\t" -+ " mov.b32 {low, high}, %5;\n\t" -+ " cvt.f32.f16 tmp, low;\n\t" -+ " add.f32 %1, tmp, %1;\n\t" -+ " cvt.f32.f16 tmp, high;\n\t" -+ " add.f32 %1, tmp, %1;\n\t" -+ "}\n\t" -+ : "+f"(gemm_k_reduction[m * 2]), "+f"(gemm_k_reduction[m * 2 + 1]) -+ : "r"(tmp[m * 4]), "r"(tmp[m * 4 + 1]),"r"(tmp[m * 4 + 2]), "r"(tmp[m * 4 + 3])); -+ -+ } else if (platform::is_same::value) { -+ -+ asm volatile( -+ "{\n\t" -+ " .reg .f32 tmp;\n\t" -+ " shl.b32 tmp, %2, 16;\n\t" -+ " add.f32 %0, tmp, %0;\n\t" -+ " and.b32 tmp, %2, 0xffff0000;\n\t" -+ " add.f32 %0, tmp, %0;\n\t" -+ " shl.b32 tmp, %3, 16;\n\t" -+ " add.f32 %1, tmp, %1;\n\t" -+ " and.b32 tmp, %3, 0xffff0000;\n\t" -+ " add.f32 %1, tmp, %1;\n\t" -+ " shl.b32 tmp, %4, 16;\n\t" -+ " add.f32 %0, tmp, %0;\n\t" -+ " and.b32 tmp, %4, 0xffff0000;\n\t" -+ " add.f32 %0, tmp, %0;\n\t" -+ " shl.b32 tmp, %5, 16;\n\t" -+ " add.f32 %1, tmp, %1;\n\t" -+ " and.b32 tmp, %5, 0xffff0000;\n\t" -+ " add.f32 %1, tmp, %1;\n\t" -+ "}\n\t" -+ : "+f"(gemm_k_reduction[m * 2]), "+f"(gemm_k_reduction[m * 2 + 1]) -+ : "r"(tmp[m * 4]), "r"(tmp[m * 4 + 1]),"r"(tmp[m * 4 + 2]), "r"(tmp[m * 4 + 3])); -+ -+ } else { -+ assert(0); -+ } -+ #endif -+ } -+ } -+ } -+ #else -+ assert(0); -+ #endif -+ } -+ -+ /// Transform the mma operands to the required types -+ CUTLASS_DEVICE -+ void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, -+ FragmentA const &A, FragmentB const &B) const { -+ -+ // -+ // Define conversions from source type to instruction type -+ // -+ FloatRoundStyle const kRoundA = -+ PreferredRoundingMode::kRound; -+ FloatRoundStyle const kRoundB = -+ PreferredRoundingMode::kRound; -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) -+ detail::ConvertAndPack -+ convert_A; -+ NumericArrayConverter -+ convert_B; -+ Array const *ptr_B = -+ reinterpret_cast const *>(&B); -+ Array * -+ ptr_dst_B = reinterpret_cast *>(&dst_B); -+ -+ dst_A = convert_A(A); -+ -+ ptr_dst_B[0] = convert_B(ptr_B[0]); -+ ptr_dst_B[1] = convert_B(ptr_B[1]); -+ -+ #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ detail::ConvertAndPack -+ convert_A; -+ NumericArrayConverter -+ convert_B; -+ Array const *ptr_A = -+ reinterpret_cast const *>(&A); -+ Array * -+ ptr_dst_A = reinterpret_cast *>(&dst_A); -+ -+ dst_B = convert_B(B); -+ -+ ptr_dst_A[0] = convert_A(ptr_A[0]); -+ ptr_dst_A[1] = convert_A(ptr_A[1]); -+ #else -+ assert(0); -+ #endif -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/scale_bias_tile_iterator.h b/3rdparty/cutlass/include/cutlass/gemm/warp/scale_bias_tile_iterator.h -new file mode 100644 -index 0000000..9c9b90b ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/scale_bias_tile_iterator.h -@@ -0,0 +1,574 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Defines iterators used by warp-level loading scale and bias vectors. -+ Every scale/bias data only needs to be loaded once for every channel. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -+ -+#include "cutlass/platform/platform.h" -+#include "cutlass/fast_math.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of A elements -+ typename Element_, -+ /// Layout of operand -+ typename Layout_, -+ /// Shape of one matrix production operation (concept: GemmShape) -+ typename InstructionShape_, -+ /// Policy of the details of LDSM shape and iterations -+ typename Policy_, -+ /// Number of threads participating in one matrix operation -+ int Threads, -+ /// Number of partitions along K dimension -+ int PartitionsK_ = 1> -+class ScaleBiasTileIterator; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to -+/// load from shared memory and therefore must be initialized with a TensorRef -+/// to shared memory. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: PitchLinearShape) -+ typename Shape_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: PitchLinearShape) -+ typename InstructionShape_, -+ /// Policy of the details of LDSM shape and iterations -+ typename Policy_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class ScaleBiasTileIterator { -+ public: -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::PitchLinear; -+ -+ /// Shape of one matrix product operation (concept: GemmShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// Number of partitions along K dimension -+ static int const kPartitionsK = PartitionsK_; -+ -+ /// Number of partitions along K dimension -+ static int const kElementsPerAccess = 128 / sizeof_bits::value; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ using Policy = Policy_; -+ -+ private: -+ -+ /// Pointer type used for accesses -+ using AccessType = Array; -+ -+ public: -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = Array; -+ -+ private: -+ -+ /// Shared memory base pointers - not advanced -+ AccessType const *pointer_; -+ -+ /// Byte offset incremented as iterator advances -+ Index byte_offset_; -+ -+ /// Internal counter used to determine when to increment byte offset and when -+ /// to XOR it -+ int k_group_idx_; -+ -+ public: -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ ScaleBiasTileIterator() -+ : pointer_(nullptr), -+ byte_offset_(0), -+ k_group_idx_(0) {} -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ ScaleBiasTileIterator(TensorRef const &ref_scale_bias, -+ int lane_id) -+ : byte_offset_(0), k_group_idx_(0) { -+ /// 16816 only -+ pointer_ = reinterpret_cast(ref_scale_bias.data()) + -+ ((lane_id >> 3) & 1) * Shape::kContiguous / kElementsPerAccess + -+ (lane_id >> 4); -+ } -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_DEVICE -+ ScaleBiasTileIterator &add_pointer_offset(LongIndex offset) { -+ byte_offset_ += offset * sizeof_bits::value / 8; -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_DEVICE -+ ScaleBiasTileIterator &add_tile_offset( -+ TensorCoord const &tile_offset) { -+ int whole_tiles = tile_offset.contiguous() / Policy::kGroupsPerTile; -+ int k_groups_delta = tile_offset.contiguous() % Policy::kGroupsPerTile; -+ -+ byte_offset_ += k_groups_delta * sizeof_bits::value * -+ kElementsPerAccess * Policy::LdsmShape::kContiguous / 8; -+ -+ // Multiply by 2 because scale and bias belonging to the same stage are next -+ // to each other in the shared memory. -+ pointer_ += (2 * whole_tiles * Shape::kContiguous / kElementsPerAccess); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ ScaleBiasTileIterator &operator++() { -+ byte_offset_ += Policy::LdsmShape::kContiguous * -+ sizeof_bits::value * kElementsPerAccess / 8; -+ -+ k_group_idx_++; -+ -+ if (k_group_idx_ == (Policy::kGroupsPerTile / kPartitionsK)) { -+ k_group_idx_ = 0; -+ byte_offset_ -= (Policy::kGroupsPerTile / kPartitionsK) * -+ Policy::LdsmShape::kContiguous * -+ sizeof_bits::value * kElementsPerAccess / 8; -+ add_tile_offset({Policy::kGroupsPerTile, 0}); -+ } -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ ScaleBiasTileIterator &operator--() { assert(0); } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE -+ ScaleBiasTileIterator &operator+=( -+ TensorCoord const &tile_offset) { -+ add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE -+ ScaleBiasTileIterator &operator-=( -+ TensorCoord const &tile_offset) { -+ add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { load_with_byte_offset(frag, 0); } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset in units of bytes -+ Index byte_offset) const { -+ Array *fetch_ptr = -+ reinterpret_cast *>(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < 1; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < Policy::LdsmIterations::kContiguous; ++c) { -+ int access_idx = c + s * Policy::LdsmIterations::kContiguous; -+ -+ AccessType const *source_ptr = -+ pointer_ + Policy::LdsmShape::kContiguous * c; -+ -+ char const *source_byte_ptr = -+ reinterpret_cast(source_ptr) + byte_offset + -+ byte_offset_; -+ -+ cutlass::arch::ldsm( -+ fetch_ptr[access_idx], source_byte_ptr); -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ load_with_byte_offset(frag, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ load_with_byte_offset(frag, tile_offset, 0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ Index pointer_offset = tile_offset.contiguous() * -+ InstructionShape::kContiguous / -+ kElementsPerAccess; -+ -+ byte_offset += sizeof_bits::value * pointer_offset / 8; -+ -+ load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ k_group_idx_ = k_group % (Policy::kGroupsPerTile / kPartitionsK); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// This tile iterator is specialized for 32-thread TensorOps. It uses LDSM to -+/// load from shared memory and therefore must be initialized with a TensorRef -+/// to shared memory. -+/// -+/// Satisfies: -+/// ReadableRandomAccessContiguousTileIteratorConcept -+/// -+template < -+ /// Size of the matrix to load (concept: MatrixShape) -+ typename Shape_, -+ /// Data type of elements -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ /// Policy of the details of LDSM shape and iterations -+ typename Policy_, -+ /// Number of partitions along K dimension -+ int PartitionsK_> -+class ScaleBiasTileIterator { -+ public: -+ /// Shape of tile to load (concept: PitchLinearShape) -+ using Shape = Shape_; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::RowMajor; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Internal structure of iterator - made public to enable introspection -+ using Policy = Policy_; -+ -+ /// Underlying tile iterator implementation -+ using Base = ScaleBiasTileIterator< -+ layout::PitchLinearShape, Element, -+ layout::PitchLinear, -+ layout::PitchLinearShape, -+ Policy, kThreads, PartitionsK_>; -+ -+ public: -+ // -+ // Derived quantities -+ // -+ -+ /// Fragment object holding a thread's part of a tile -+ using Fragment = typename Base::Fragment; -+ -+ private: -+ /// Underlying tile iterator -+ Base iterator_; -+ -+ public: -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ ScaleBiasTileIterator() {} -+ -+ /// Constructor from TensorRef -+ CUTLASS_HOST_DEVICE -+ ScaleBiasTileIterator(TensorRef const &ref_scale_bias, int lane_id) -+ : iterator_({ref_scale_bias.data(), ref_scale_bias.stride()}, lane_id) {} -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_HOST_DEVICE -+ ScaleBiasTileIterator &add_pointer_offset(LongIndex offset) { -+ iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ ScaleBiasTileIterator &add_tile_offset( -+ TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_DEVICE -+ ScaleBiasTileIterator &add_tile_offset_negative( -+ TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset_negative({tile_offset.column(), tile_offset.row()}); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ ScaleBiasTileIterator &operator++() { -+ ++iterator_; -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_HOST_DEVICE -+ ScaleBiasTileIterator &operator--() { -+ --iterator_; -+ -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE -+ ScaleBiasTileIterator &operator+=( -+ TensorCoord const &tile_offset) { -+ add_tile_offset(PitchLinearCoord(tile_offset.column(), tile_offset.row())); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of -+ ///< the tensor -+ CUTLASS_DEVICE -+ ScaleBiasTileIterator &operator-=( -+ TensorCoord const &tile_offset) { -+ add_tile_offset(-PitchLinearCoord(tile_offset.column(), tile_offset.row())); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { iterator_.load(frag); } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ // TODO -+ assert(0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ // TODO -+ assert(0); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ iterator_.load_with_byte_offset( -+ frag, {tile_offset.strided(), tile_offset.contiguous()}, byte_offset); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/softmax_scale_bias_transform.h b/3rdparty/cutlass/include/cutlass/gemm/warp/softmax_scale_bias_transform.h -new file mode 100644 -index 0000000..bf8efe9 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/softmax_scale_bias_transform.h -@@ -0,0 +1,117 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing warp-level per-channel softmax before -+ matrix multiply-accumulate operations targeting Tensor Cores. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/platform/platform.h" -+ -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/arch/mma_sm75.h" -+#include "cutlass/arch/mma_sm80.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_policy.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct SoftmaxScaleBiasTransform { -+ -+ using T = typename FragmentActivations::Element; -+ -+ static int const NumActivations = FragmentActivations::kElements; -+ static int const NumNormSum = FragmentNormSum::kElements; -+ static int const MmaElements = 2; -+ // One element has one scale and one bias -+ static int const MmaScaleBiasPair = 2; -+ // 16816 has 2 columns and 2 rows -+ static int const MmaCols = 2; -+ static int const MmaRows = 2; -+ -+ using MmaOperand = Array; -+ using NormSumOperand = Array<__half2, MmaScaleBiasPair>; -+ -+ CUTLASS_DEVICE -+ void transform(MmaOperand &activations, -+ NormSumOperand const &norm_sum) { -+ -+ __half2* packed_activations = reinterpret_cast<__half2*>(&activations); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < MmaElements / 2; ++i) { -+ __half2 out = ::h2exp(__hsub2(packed_activations[i], norm_sum[2*i])); -+ packed_activations[i] = __hmul2(out, norm_sum[2*i + 1]); -+ } -+ } -+ -+ CUTLASS_DEVICE -+ void operator()(FragmentActivations &activations, -+ FragmentNormSum const &norm_sum) { -+ MmaOperand *ptr_activations = reinterpret_cast(&activations); -+ NormSumOperand const *ptr_norm_sum = -+ reinterpret_cast(&norm_sum); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < (NumActivations / MmaElements); ++i) { -+ transform(ptr_activations[i], -+ ptr_norm_sum[i / (MmaCols * MmaRows) * MmaRows + i % MmaRows]); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/gemm/warp/tile_iterator_planar_complex.h b/3rdparty/cutlass/include/cutlass/gemm/warp/tile_iterator_planar_complex.h -new file mode 100644 -index 0000000..1633dd2 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/gemm/warp/tile_iterator_planar_complex.h -@@ -0,0 +1,250 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing warp-level matrix multiply-accumulate operations. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/array_planar_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class TileIteratorPlanarComplex { -+public: -+ -+ /// Underlying iterator over real-valued tiles -+ using TileIterator = TileIterator_; -+ -+ /// Underlying element type -+ using Element = typename TileIterator::Element; -+ -+ /// Underlying layout type -+ using Layout = typename TileIterator::Layout; -+ -+ /// TensorRef type for loading element from a tensor -+ using TensorRef = typename TileIterator::TensorRef; -+ -+ /// Index type -+ using Index = typename TensorRef::Index; -+ -+ /// Long Index type -+ using LongIndex = typename TensorRef::LongIndex; -+ -+ /// Coordinate for an element in the tensor -+ using TensorCoord = typename TensorRef::TensorCoord; -+ -+ /// Planar complex fragment -+ using Fragment = ArrayPlanarComplex; -+ -+public: -+ -+ /// Underlying tile iterator -+ TileIterator tile_iterator_; -+ -+ /// Offset (in units of bytes) to the imaginary part of the planar complex matrix -+ LongIndex imaginary_offset_; -+ -+public: -+ /// Default ctor constructs null iterator -+ CUTLASS_HOST_DEVICE -+ TileIteratorPlanarComplex(): imaginary_offset_(0) { } -+ -+ /// Constructor from TensorRef -+ CUTLASS_DEVICE -+ TileIteratorPlanarComplex( -+ TensorRef const &ref, -+ int lane_id, -+ LongIndex imaginary_offset -+ ): -+ tile_iterator_(ref, lane_id), -+ imaginary_offset_((imaginary_offset * sizeof_bits::value) / 8) { } -+ -+ -+ /// Adds a pointer offset to internal pointer(s) to advance through memory -+ CUTLASS_DEVICE -+ TileIteratorPlanarComplex &add_pointer_offset(LongIndex offset) { -+ -+ tile_iterator_.add_pointer_offset(offset); -+ -+ return *this; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_HOST_DEVICE -+ TileIteratorPlanarComplex &add_tile_offset(TensorCoord const &tile_offset) { -+ -+ tile_iterator_.add_tile_offset(tile_offset); -+ -+ return *this; -+ } -+ -+ /// Advances the iterator along the advance dimension -+ CUTLASS_DEVICE -+ TileIteratorPlanarComplex & operator++() { -+ ++tile_iterator_; -+ return *this; -+ } -+ -+ // -+ // WIP -+ // -+ -+ /// Advances the iterator along the opposite of the advance dimension -+ CUTLASS_HOST_DEVICE -+ TileIteratorPlanarComplex & operator--() { -+ --tile_iterator_; -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ TileIteratorPlanarComplex & operator+=(TensorCoord const &tile_offset) { -+ tile_iterator_.add_tile_offset(tile_offset); -+ return *this; -+ } -+ -+ ///< advances in units of whole tiles along the logical coordinate space of the tensor -+ CUTLASS_DEVICE -+ TileIteratorPlanarComplex & operator-=(TensorCoord const &tile_offset) { -+ tile_iterator_.add_tile_offset(-tile_offset); -+ return *this; -+ } -+ -+ /// Loads a fragment from memory at the location pointed to by the iterator. -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ tile_iterator_.load_with_byte_offset(frag.real, 0); -+ tile_iterator_.load_with_byte_offset(frag.imag, imaginary_offset_); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset in units of bytes -+ Index byte_offset) const { -+ -+ tile_iterator_.load_with_byte_offset(frag.real, byte_offset); -+ tile_iterator_.load_with_byte_offset(frag.imag, byte_offset + imaginary_offset_); -+ } -+ -+ /// Loads a fragment from memory with additional logical offset -+ CUTLASS_DEVICE -+ void load_with_pointer_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a linear offset -+ Index pointer_offset) const { -+ -+ Index byte_offset = (pointer_offset * sizeof_bits::value)/8; -+ -+ tile_iterator_.load_with_byte_offset(frag.real, byte_offset); -+ tile_iterator_.load_with_byte_offset(frag.imag, byte_offset + imaginary_offset_); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset) const { -+ -+ tile_iterator_.load_with_byte_offset(frag.real, tile_offset, 0); -+ tile_iterator_.load_with_byte_offset(frag.imag, tile_offset, imaginary_offset_); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index pointer_offset) const { -+ -+ Index byte_offset = (pointer_offset * sizeof_bits::value)/8; -+ -+ tile_iterator_.load_with_byte_offset(frag.real, tile_offset, byte_offset); -+ tile_iterator_.load_with_byte_offset(frag.real, tile_offset, byte_offset + imaginary_offset_); -+ } -+ -+ /// Loads a fragment from memory with logical offset in units of whole tiles. -+ CUTLASS_DEVICE -+ void load_with_byte_offset( -+ /// fragment to load from the tensor -+ Fragment &frag, -+ /// loads a tile with a logical offset in units of whole tiles -+ TensorCoord const &tile_offset, -+ /// loads a tile with a logical offset AND a pointer offset -+ Index byte_offset) const { -+ -+ tile_iterator_.load_with_byte_offset(frag.real, tile_offset, byte_offset); -+ tile_iterator_.load_with_byte_offset(frag.imag, tile_offset, byte_offset + imaginary_offset_); -+ } -+ -+ /// Notify the iterator which k-group it is currently pointing to. -+ /// -+ /// This does not advance the iterator. Rather, it overrides its internal -+ /// tracking with constant-valued k-group index to enable the compiler to -+ /// fold constants and achieve more efficient code. -+ /// -+ /// This is used by some nontrivial permuted layouts. -+ CUTLASS_DEVICE -+ void set_kgroup_index(int k_group) { -+ tile_iterator_.set_kgroup_index(k_group); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/half.h b/3rdparty/cutlass/include/cutlass/half.h -new file mode 100644 -index 0000000..8d90b26 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/half.h -@@ -0,0 +1,919 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief Defines a class for using IEEE half-precision floating-point types in host or -+ device code. -+*/ -+#pragma once -+ -+#ifndef CUTLASS_ENABLE_F16C -+#define CUTLASS_ENABLE_F16C 0 -+#endif -+ -+#if defined(__CUDACC_RTC__) -+ -+#include "cutlass/floating_point_nvrtc.h" -+ -+// F16C extensions are not meaningful when compiling for NVRTC which only accommodates device code. -+#undef CUTLASS_ENABLE_F16C -+#define CUTLASS_ENABLE_F16C 0 -+ -+#else -+// -+// Standard Library headers belong here to avoid conflicts with NVRTC. -+// -+#include -+#include -+#include -+#include -+#endif -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/float8.h" -+#include "cutlass/platform/platform.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Optionally target F16C extentions to accelerate half-precision conversion. -+#if !defined(__CUDA_ARCH__) && (CUTLASS_ENABLE_F16C) -+#if defined(_MSC_VER) -+ -+#include -+ -+#if defined(__i386__) || defined(__x86_64__) -+#include -+#endif -+ -+#define F16C_ROUND_NEAREST 0 -+ -+#if !defined(__CUDA_ARCH__) -+extern __inline float _cvtsh_ss (unsigned short __S) { -+ __m128i packed; -+ std::memcpy(&packed, &__S, sizeof(__S)); -+ -+ __m128 result = _mm_cvtph_ps(packed); -+ -+ float flt; -+ std::memcpy(&flt, &result, sizeof(flt)); -+ -+ return flt; -+} -+ -+__inline unsigned short _cvtss_sh (float __F, const int) { -+ __m128 packed; -+ std::memcpy(&packed, &__F, sizeof(__F)); -+ -+ __m128i result = _mm_cvtps_ph(packed, F16C_ROUND_NEAREST); -+ -+ unsigned short u; -+ std::memcpy(&u, &result, sizeof(u)); -+ -+ return u; -+} -+#endif -+ -+#else -+ -+// Linux -+#include -+ -+#if defined(__i386__) || defined(__x86_64__) -+#include -+#endif -+ -+#define F16C_ROUND_NEAREST (_MM_FROUND_TO_NEAREST_INT |_MM_FROUND_NO_EXC) -+ -+#endif // _MSC_VER -+ -+class CpuId { -+ -+ bool f16c_enabled; -+ -+ CpuId() { -+ #if defined(__i386__) || defined(__x86_64__) -+ #if defined(_MSC_VER) -+ int exx[4]; -+ -+ __cpuid (exx, 1); -+ f16c_enabled = exx[2] & 0x20000000; -+ -+ #else -+ // GCC / Clang -+ int eax, ebx, ecx, edx; -+ -+ __cpuid (1 , eax, ebx, ecx, edx); -+ f16c_enabled = ecx & 0x20000000; -+ #endif -+ #else -+ // Arm / PowerPC etc. -+ f16c_enabled = false; -+ #endif -+ } -+ -+public: -+ -+ bool is_f16c_supported() const { -+ return f16c_enabled; -+ } -+ -+ static const CpuId& instance() { -+ static CpuId cpu; -+ return cpu; -+ } -+}; -+#endif // !defined(__CUDA_ARCH__) && CUTLASS_ENABLE_F16C -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+namespace cutlass { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// IEEE half-precision floating-point type -+struct alignas(2) half_t { -+ -+ // -+ // Data members -+ // -+ -+ /// Storage type -+ uint16_t storage; -+ -+ // -+ // Static conversion operators -+ // -+ -+ /// Constructs from an unsigned short -+ CUTLASS_HOST_DEVICE -+ static half_t bitcast(uint16_t x) { -+ half_t h; -+ h.storage = x; -+ return h; -+ } -+ -+ /// FP32 -> FP16 conversion - rounds to nearest even -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 530) -+ // Avoid inlining in device code if no hardware support -+ __device__ __noinline__ -+ #else -+ CUTLASS_HOST_DEVICE -+ #endif -+ static half_t convert(float const& flt) { -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ return half_t(__float2half_rn(flt)); -+ #else -+ -+ #if !defined(__CUDA_ARCH__) && CUTLASS_ENABLE_F16C -+ if( CpuId::instance().is_f16c_supported() ) { -+ unsigned short u = _cvtss_sh(flt, F16C_ROUND_NEAREST); -+ return bitcast(u); -+ } -+ #endif -+ -+ // software implementation rounds toward nearest even -+ unsigned s; -+ -+ #if defined(__CUDA_ARCH__) -+ s = reinterpret_cast(flt); -+ #else -+ std::memcpy(&s, &flt, sizeof(s)); -+ #endif -+ -+ uint16_t sign = uint16_t((s >> 16) & 0x8000); -+ int16_t exp = uint16_t(((s >> 23) & 0xff) - 127); -+ int mantissa = s & 0x7fffff; -+ uint16_t u = 0; -+ -+ if ((s & 0x7fffffff) == 0) { -+ // sign-preserving zero -+ return bitcast(sign); -+ } -+ -+ if (exp > 15) { -+ if (exp == 128 && mantissa) { -+ // not a number -+ u = 0x7fff; -+ } else { -+ // overflow to infinity -+ u = sign | 0x7c00; -+ } -+ return bitcast(u); -+ } -+ -+ int sticky_bit = 0; -+ -+ if (exp >= -14) { -+ // normal fp32 to normal fp16 -+ exp = uint16_t(exp + uint16_t(15)); -+ u = uint16_t(((exp & 0x1f) << 10)); -+ u = uint16_t(u | (mantissa >> 13)); -+ } else { -+ // normal single-precision to subnormal half_t-precision representation -+ int rshift = (-14 - exp); -+ if (rshift < 32) { -+ mantissa |= (1 << 23); -+ -+ sticky_bit = ((mantissa & ((1 << rshift) - 1)) != 0); -+ -+ mantissa = (mantissa >> rshift); -+ u = (uint16_t(mantissa >> 13) & 0x3ff); -+ } else { -+ mantissa = 0; -+ u = 0; -+ } -+ } -+ -+ // round to nearest even -+ int round_bit = ((mantissa >> 12) & 1); -+ sticky_bit |= ((mantissa & ((1 << 12) - 1)) != 0); -+ -+ if ((round_bit && sticky_bit) || (round_bit && (u & 1))) { -+ u = uint16_t(u + 1); -+ } -+ -+ u |= sign; -+ -+ return bitcast(u); -+ #endif -+ } -+ -+ /// FP32 -> FP16 conversion - rounds to nearest even -+ CUTLASS_HOST_DEVICE -+ static half_t convert(int const& n) { -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ return half_t(__int2half_rn(n)); -+ #else -+ return convert(float(n)); -+ #endif -+ } -+ -+ /// FP32 -> FP16 conversion - rounds to nearest even -+ CUTLASS_HOST_DEVICE -+ static half_t convert(unsigned const& n) { -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ return half_t(__uint2half_rn(n)); -+ #else -+ return convert(float(n)); -+ #endif -+ } -+ -+ /// Converts a half-precision value stored as a uint16_t to a float -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 530) -+ // Avoid inlining in device code if no hardware support -+ __device__ __noinline__ -+ #else -+ CUTLASS_HOST_DEVICE -+ #endif -+ static float convert(half_t const& x) { -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ return __half2float(x.to_half()); -+ #else -+ -+ #if !defined(__CUDA_ARCH__) && CUTLASS_ENABLE_F16C -+ if( CpuId::instance().is_f16c_supported() ) { -+ unsigned short u = x.storage; -+ return _cvtsh_ss(u); -+ } -+ #endif -+ -+ uint16_t const &h = x.storage; -+ int sign = ((h >> 15) & 1); -+ int exp = ((h >> 10) & 0x1f); -+ int mantissa = (h & 0x3ff); -+ unsigned f = 0; -+ -+ if (exp > 0 && exp < 31) { -+ // normal -+ exp += 112; -+ f = (sign << 31) | (exp << 23) | (mantissa << 13); -+ } else if (exp == 0) { -+ if (mantissa) { -+ // subnormal -+ exp += 113; -+ while ((mantissa & (1 << 10)) == 0) { -+ mantissa <<= 1; -+ exp--; -+ } -+ mantissa &= 0x3ff; -+ f = (sign << 31) | (exp << 23) | (mantissa << 13); -+ } else { -+ // sign-preserving zero -+ f = (sign << 31); -+ } -+ } else if (exp == 31) { -+ if (mantissa) { -+ f = 0x7fffffff; // not a number -+ } else { -+ f = (0xff << 23) | (sign << 31); // inf -+ } -+ } -+ #if defined(__CUDA_ARCH__) -+ return reinterpret_cast(f); -+ #else -+ float flt; -+ std::memcpy(&flt, &f, sizeof(flt)); -+ return flt; -+ #endif -+ #endif -+ } -+ -+ // -+ // Methods -+ // -+ -+ /// Default constructor -+ half_t() = default; -+ -+ /// Reinterpret cast from CUDA's half type -+ CUTLASS_HOST_DEVICE -+ explicit half_t(half const & x) { -+ #if defined(__CUDA_ARCH__) -+ storage = reinterpret_cast(x); -+ #else -+ __half_raw raw(x); -+ std::memcpy(&storage, &raw.x, sizeof(storage)); -+ #endif -+ } -+ -+ /// Floating point conversion -+ CUTLASS_HOST_DEVICE -+ explicit half_t(float x) { -+ storage = convert(x).storage; -+ } -+ -+ /// Floating point conversion -+ CUTLASS_HOST_DEVICE -+ explicit half_t(double x): half_t(float(x)) { -+ -+ } -+ -+ /// float_e4m3_t conversion -+ CUTLASS_HOST_DEVICE -+ explicit half_t(float_e4m3_t x): half_t(float(x)) { -+ -+ } -+ -+ /// float_e5m2_t conversion -+ CUTLASS_HOST_DEVICE -+ explicit half_t(float_e5m2_t x): half_t(float(x)) { -+ -+ } -+ -+ /// Integer conversion - round to nearest even -+ CUTLASS_HOST_DEVICE -+ explicit half_t(int x) { -+ storage = convert(x).storage; -+ } -+ -+ /// Integer conversion - round toward zero -+ CUTLASS_HOST_DEVICE -+ explicit half_t(unsigned x) { -+ storage = convert(x).storage; -+ } -+ -+ /// Assignment -+ CUTLASS_HOST_DEVICE -+ half_t & operator=(half const &x) { -+ #if defined(__CUDA_ARCH__) -+ storage = reinterpret_cast(x); -+ #else -+ __half_raw raw(x); -+ std::memcpy(&storage, &raw.x, sizeof(storage)); -+ #endif -+ return *this; -+ } -+ -+ /// Converts to float -+ CUTLASS_HOST_DEVICE -+ operator float() const { -+ return convert(*this); -+ } -+ -+ /// Converts to float -+ CUTLASS_HOST_DEVICE -+ explicit operator double() const { -+ return double(convert(*this)); -+ } -+ -+ /// Converts to float -+ CUTLASS_HOST_DEVICE -+ explicit operator int() const { -+ return int(convert(*this)); -+ } -+ -+ /// Casts to bool -+ CUTLASS_HOST_DEVICE -+ explicit operator bool() const { -+ return (convert(*this) != 0.0f); -+ } -+ -+ /// Bitcasts to CUDA's half type -+ CUTLASS_HOST_DEVICE -+ half to_half() const { -+ #if defined(__CUDA_ARCH__) -+ return reinterpret_cast(storage); -+ #else -+ __half_raw raw; -+ std::memcpy(&raw.x, &storage, sizeof(raw.x)); -+ return half(raw); -+ #endif -+ } -+ -+ /// Accesses raw internal state -+ CUTLASS_HOST_DEVICE -+ uint16_t& raw() { -+ return storage; -+ } -+ -+ /// Accesses raw internal state -+ CUTLASS_HOST_DEVICE -+ uint16_t raw() const { -+ return storage; -+ } -+ -+ /// Returns the sign bit -+ CUTLASS_HOST_DEVICE -+ bool signbit() const { -+ return ((storage & 0x8000) != 0); -+ } -+ -+ /// Returns the biased exponent -+ CUTLASS_HOST_DEVICE -+ int exponent_biased() const { -+ return int((storage >> 10) & 0x1f); -+ } -+ -+ /// Returns the unbiased exponent -+ CUTLASS_HOST_DEVICE -+ int exponent() const { -+ return exponent_biased() - 15; -+ } -+ -+ /// Returns the mantissa -+ CUTLASS_HOST_DEVICE -+ int mantissa() const { -+ return int(storage & 0x3ff); -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+CUTLASS_HOST_DEVICE -+bool signbit(cutlass::half_t const& h) { -+ return ((h.raw() & 0x8000) != 0); -+} -+ -+CUTLASS_HOST_DEVICE -+cutlass::half_t abs(cutlass::half_t const& h) { -+ return cutlass::half_t::bitcast(h.raw() & 0x7fff); -+} -+ -+CUTLASS_HOST_DEVICE -+bool isnan(cutlass::half_t const& h) { -+ return (h.exponent_biased() == 0x1f) && h.mantissa(); -+} -+ -+CUTLASS_HOST_DEVICE -+bool isfinite(cutlass::half_t const& h) { -+ return (h.exponent_biased() != 0x1f); -+} -+ -+CUTLASS_HOST_DEVICE -+cutlass::half_t nanh(const char*) { -+ // NVIDIA canonical NaN -+ return cutlass::half_t::bitcast(0x7fff); -+} -+ -+CUTLASS_HOST_DEVICE -+bool isinf(cutlass::half_t const& h) { -+ return (h.exponent_biased() == 0x1f) && !h.mantissa(); -+} -+ -+CUTLASS_HOST_DEVICE -+bool isnormal(cutlass::half_t const& h) { -+ return h.exponent_biased() && h.exponent_biased() != 0x1f; -+} -+ -+CUTLASS_HOST_DEVICE -+int fpclassify(cutlass::half_t const& h) { -+ int exp = h.exponent_biased(); -+ int mantissa = h.mantissa(); -+ if (exp == 0x1f) { -+ if (mantissa) { -+ return FP_NAN; -+ } -+ else { -+ return FP_INFINITE; -+ } -+ } -+ else if (!exp) { -+ if (mantissa) { -+ return FP_SUBNORMAL; -+ } -+ else { -+ return FP_ZERO; -+ } -+ } -+ return FP_NORMAL; -+} -+ -+CUTLASS_HOST_DEVICE -+cutlass::half_t sqrt(cutlass::half_t const& h) { -+#if defined(__CUDACC_RTC__) -+ return cutlass::half_t(sqrtf(float(h))); -+#else -+ return cutlass::half_t(std::sqrt(float(h))); -+#endif -+} -+ -+CUTLASS_HOST_DEVICE -+half_t copysign(half_t const& a, half_t const& b) { -+ -+ uint16_t a_mag = (a.raw() & 0x7fff); -+ uint16_t b_sign = (b.raw() & 0x8000); -+ uint16_t result = (a_mag | b_sign); -+ -+ return half_t::bitcast(result); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Standard Library operations and definitions -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if !defined(__CUDACC_RTC__) -+namespace std { -+ -+/// Numeric limits -+template <> -+struct numeric_limits { -+ static bool const is_specialized = true; -+ static bool const is_signed = true; -+ static bool const is_integer = false; -+ static bool const is_exact = false; -+ static bool const has_infinity = true; -+ static bool const has_quiet_NaN = true; -+ static bool const has_signaling_NaN = false; -+ static std::float_denorm_style const has_denorm = std::denorm_present; -+ static bool const has_denorm_loss = true; -+ static std::float_round_style const round_style = std::round_to_nearest; -+ static bool const is_iec559 = true; -+ static bool const is_bounded = true; -+ static bool const is_modulo = false; -+ static int const digits = 10; -+ -+ /// Least positive value -+ static cutlass::half_t min() { return cutlass::half_t::bitcast(0x0001); } -+ -+ /// Minimum finite value -+ static cutlass::half_t lowest() { return cutlass::half_t::bitcast(0xfbff); } -+ -+ /// Maximum finite value -+ static cutlass::half_t max() { return cutlass::half_t::bitcast(0x7bff); } -+ -+ /// Returns smallest finite value -+ static cutlass::half_t epsilon() { return cutlass::half_t::bitcast(0x1800); } -+ -+ /// Returns maximum rounding error -+ static cutlass::half_t round_error() { return cutlass::half_t(0.5f); } -+ -+ /// Returns positive infinity value -+ static cutlass::half_t infinity() { return cutlass::half_t::bitcast(0x7c00); } -+ -+ /// Returns quiet NaN value -+ static cutlass::half_t quiet_NaN() { return cutlass::half_t::bitcast(0x7fff); } -+ -+ /// Returns signaling NaN value -+ static cutlass::half_t signaling_NaN() { return cutlass::half_t::bitcast(0x7fff); } -+ -+ /// Returns smallest positive subnormal value -+ static cutlass::half_t denorm_min() { return cutlass::half_t::bitcast(0x0001); } -+}; -+} // namespace std -+#endif -+ -+namespace platform { -+ -+/// std::numeric_limits -+template -+struct numeric_limits; -+ -+/// Numeric limits -+template <> -+struct numeric_limits { -+ static bool const is_specialized = true; -+ static bool const is_signed = true; -+ static bool const is_integer = false; -+ static bool const is_exact = false; -+ static bool const has_infinity = true; -+ static bool const has_quiet_NaN = true; -+ static bool const has_signaling_NaN = false; -+#if !defined(__CUDACC_RTC__) -+ static std::float_denorm_style const has_denorm = std::denorm_present; -+#endif -+ static bool const has_denorm_loss = true; -+#if !defined(__CUDACC_RTC__) -+ static std::float_round_style const round_style = std::round_to_nearest; -+#endif -+ static bool const is_iec559 = true; -+ static bool const is_bounded = true; -+ static bool const is_modulo = false; -+ static int const digits = 10; -+ -+ /// Least positive value -+ CUTLASS_HOST_DEVICE -+ static cutlass::half_t min() { return cutlass::half_t::bitcast(0x0001); } -+ -+ /// Minimum finite value -+ CUTLASS_HOST_DEVICE -+ static cutlass::half_t lowest() { return cutlass::half_t::bitcast(0xfbff); } -+ -+ /// Maximum finite value -+ CUTLASS_HOST_DEVICE -+ static cutlass::half_t max() { return cutlass::half_t::bitcast(0x7bff); } -+ -+ /// Returns smallest finite value -+ CUTLASS_HOST_DEVICE -+ static cutlass::half_t epsilon() { return cutlass::half_t::bitcast(0x1800); } -+ -+ /// Returns maximum rounding error -+ CUTLASS_HOST_DEVICE -+ static cutlass::half_t round_error() { return cutlass::half_t(0.5f); } -+ -+ /// Returns positive infinity value -+ CUTLASS_HOST_DEVICE -+ static cutlass::half_t infinity() { return cutlass::half_t::bitcast(0x7c00); } -+ -+ /// Returns quiet NaN value -+ CUTLASS_HOST_DEVICE -+ static cutlass::half_t quiet_NaN() { return cutlass::half_t::bitcast(0x7fff); } -+ -+ /// Returns signaling NaN value -+ CUTLASS_HOST_DEVICE -+ static cutlass::half_t signaling_NaN() { return cutlass::half_t::bitcast(0x7fff); } -+ -+ /// Returns smallest positive subnormal value -+ CUTLASS_HOST_DEVICE -+ static cutlass::half_t denorm_min() { return cutlass::half_t::bitcast(0x0001); } -+}; -+} // namespace platform -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Arithmetic operators -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+CUTLASS_HOST_DEVICE -+bool operator==(half_t const& lhs, half_t const& rhs) { -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ return __heq(lhs.to_half(), rhs.to_half()); -+#else -+ return float(lhs) == float(rhs); -+#endif -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator!=(half_t const& lhs, half_t const& rhs) { -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ return __hne(lhs.to_half(), rhs.to_half()); -+#else -+ return float(lhs) != float(rhs); -+#endif -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator<(half_t const& lhs, half_t const& rhs) { -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ return __hlt(lhs.to_half(), rhs.to_half()); -+#else -+ return float(lhs) < float(rhs); -+#endif -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator<=(half_t const& lhs, half_t const& rhs) { -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ return __hle(lhs.to_half(), rhs.to_half()); -+#else -+ return float(lhs) <= float(rhs); -+#endif -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator>(half_t const& lhs, half_t const& rhs) { -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ return __hgt(lhs.to_half(), rhs.to_half()); -+#else -+ return float(lhs) > float(rhs); -+#endif -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator>=(half_t const& lhs, half_t const& rhs) { -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ return __hge(lhs.to_half(), rhs.to_half()); -+#else -+ return float(lhs) >= float(rhs); -+#endif -+} -+ -+CUTLASS_HOST_DEVICE -+half_t operator+(half_t const& lhs, half_t const& rhs) { -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ return half_t(__hadd(lhs.to_half(), rhs.to_half())); -+#else -+ return half_t(float(lhs) + float(rhs)); -+#endif -+} -+ -+CUTLASS_HOST_DEVICE -+half_t operator-(half_t const& lhs) { -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ return half_t(__hneg(lhs.to_half())); -+#else -+ return half_t(-float(lhs)); -+#endif -+} -+ -+CUTLASS_HOST_DEVICE -+half_t operator-(half_t const& lhs, half_t const& rhs) { -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ return half_t(__hsub(lhs.to_half(), rhs.to_half())); -+#else -+ return half_t(float(lhs) - float(rhs)); -+#endif -+} -+ -+CUTLASS_HOST_DEVICE -+half_t operator*(half_t const& lhs, half_t const& rhs) { -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ return half_t(__hmul(lhs.to_half(), rhs.to_half())); -+#else -+ return half_t(float(lhs) * float(rhs)); -+#endif -+} -+ -+CUTLASS_HOST_DEVICE -+half_t operator/(half_t const& lhs, half_t const& rhs) { -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ return half_t(__hdiv(lhs.to_half(), rhs.to_half())); -+#else -+ return half_t(float(lhs) / float(rhs)); -+#endif -+} -+ -+CUTLASS_HOST_DEVICE -+half_t& operator+=(half_t & lhs, half_t const& rhs) { -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ lhs = half_t(__hadd(lhs.to_half(), rhs.to_half())); -+#else -+ lhs = half_t(float(lhs) + float(rhs)); -+#endif -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+half_t& operator-=(half_t & lhs, half_t const& rhs) { -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ lhs = half_t(__hsub(lhs.to_half(), rhs.to_half())); -+#else -+ lhs = half_t(float(lhs) - float(rhs)); -+#endif -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+half_t& operator*=(half_t & lhs, half_t const& rhs) { -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ lhs = half_t(__hmul(lhs.to_half(), rhs.to_half())); -+#else -+ lhs = half_t(float(lhs) * float(rhs)); -+#endif -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+half_t& operator/=(half_t & lhs, half_t const& rhs) { -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ lhs = half_t(__hdiv(lhs.to_half(), rhs.to_half())); -+#else -+ lhs = half_t(float(lhs) / float(rhs)); -+#endif -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+half_t& operator++(half_t & lhs) { -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ lhs = half_t(__hadd(lhs.to_half(), half_t(1.0f).to_half())); -+#else -+ float tmp(lhs); -+ ++tmp; -+ lhs = half_t(tmp); -+#endif -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+half_t& operator--(half_t & lhs) { -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ lhs = half_t(__hsub(lhs.to_half(), half_t(1.0f).to_half())); -+#else -+ float tmp(lhs); -+ --tmp; -+ lhs = half_t(tmp); -+#endif -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+half_t operator++(half_t & lhs, int) { -+ half_t ret(lhs); -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ lhs = half_t(__hadd(lhs.to_half(), half_t(1.0f).to_half())); -+#else -+ float tmp(lhs); -+ tmp++; -+ lhs = half_t(tmp); -+#endif -+ return ret; -+} -+ -+CUTLASS_HOST_DEVICE -+half_t operator--(half_t & lhs, int) { -+ half_t ret(lhs); -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ lhs = half_t(__hsub(lhs.to_half(), half_t(1.0f).to_half())); -+#else -+ float tmp(lhs); -+ tmp--; -+ lhs = half_t(tmp); -+#endif -+ return ret; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// User-defined literals -+// -+ -+CUTLASS_HOST_DEVICE -+cutlass::half_t operator "" _hf(long double x) { -+ return cutlass::half_t(float(x)); -+} -+ -+CUTLASS_HOST_DEVICE -+cutlass::half_t operator "" _hf(unsigned long long int x) { -+ return cutlass::half_t(int(x)); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/integer_subbyte.h b/3rdparty/cutlass/include/cutlass/integer_subbyte.h -new file mode 100644 -index 0000000..f02a7d3 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/integer_subbyte.h -@@ -0,0 +1,240 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief Defines a class for using integer types smaller than one byte in host or -+ device code. -+*/ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#include "cutlass/platform/platform.h" -+ -+namespace cutlass { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// 4-bit signed integer type -+template -+struct integer_subbyte { -+ -+ /// Number of bits -+ static int const kBits = Bits; -+ -+ /// Whether type is signed -+ static bool const kSigned = Signed; -+ -+ /// External type -+ using T = typename platform::conditional::type; -+ -+ /// Storage type -+ using Storage = uint8_t; -+ -+ /// Bitmask used to truncate from larger integers -+ static Storage const kMask = Storage((1 << kBits) - 1); -+ -+ // -+ // Data members -+ // -+ -+ Storage storage; -+ -+ // -+ // Methods -+ // -+ -+ /// No operation -+ integer_subbyte() = default; -+ -+ /// Conversion from integer type -+ CUTLASS_HOST_DEVICE -+ integer_subbyte(int value) -+ : storage(reinterpret_cast(value) & kMask) {} -+ -+ CUTLASS_HOST_DEVICE -+ integer_subbyte(unsigned value) -+ : storage(reinterpret_cast(value) & kMask) {} -+ -+ CUTLASS_HOST_DEVICE -+ integer_subbyte(double value) { -+ T tmp = static_cast(value); -+ storage = Storage(reinterpret_cast(tmp) & kMask); -+ } -+ -+ /// -+ CUTLASS_HOST_DEVICE -+ operator T() const { -+ if (kSigned) { -+ // Sign extend -+ if (storage & Storage(1 << (kBits - 1))) { -+ return T(storage) | ~T(kMask); -+ } -+ } -+ return T(storage); -+ } -+ -+ /// Equality -+ CUTLASS_HOST_DEVICE -+ bool operator==(integer_subbyte const &rhs) const { -+ return storage == rhs.storage; -+ } -+ -+ /// Inequality -+ CUTLASS_HOST_DEVICE -+ bool operator!=(integer_subbyte const &rhs) const { -+ return storage != rhs.storage; -+ } -+ -+ /// Less than or equal -+ CUTLASS_HOST_DEVICE -+ bool operator<=(integer_subbyte const &rhs) const { -+ if (kSigned) { -+ if (storage & (1 << (kBits - 1))) { -+ return !(rhs.storage < storage); -+ } -+ } -+ return storage < rhs.storage; -+ } -+ -+ /// Less than -+ CUTLASS_HOST_DEVICE -+ bool operator<(integer_subbyte const &rhs) const { -+ if (kSigned) { -+ if (storage & (1 << (kBits - 1))) { -+ return !(rhs.storage <= storage); -+ } -+ } -+ return storage < rhs.storage; -+ } -+ -+ /// Greater than or equal -+ CUTLASS_HOST_DEVICE -+ bool operator>=(integer_subbyte const &rhs) const { -+ return !(*this < rhs); -+ } -+ -+ /// Greater than -+ CUTLASS_HOST_DEVICE -+ bool operator>(integer_subbyte const &rhs) const { -+ return !(*this <= rhs); -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// 1-bit Unsigned integer type -+using uint1b_t = integer_subbyte<1, false>; -+ -+/// 2-bit Integer type -+using int2b_t = integer_subbyte<2, true>; -+ -+/// 2-bit Unsigned integer type -+using uint2b_t = integer_subbyte<2, false>; -+ -+/// 4-bit Integer type -+using int4b_t = integer_subbyte<4, true>; -+ -+/// 4-bit Unsigned integer type -+using uint4b_t = integer_subbyte<4, false>; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines the size of an element in bits - specialized for uint1b_t -+template <> -+struct sizeof_bits { -+ static int const value = 1; -+}; -+ -+/// Defines the size of an element in bits - specialized for int2b_t -+template <> -+struct sizeof_bits { -+ static int const value = 2; -+}; -+ -+/// Defines the size of an element in bits - specialized for uint2b_t -+template <> -+struct sizeof_bits { -+ static int const value = 2; -+}; -+ -+/// Defines the size of an element in bits - specialized for int4b_t -+template <> -+struct sizeof_bits { -+ static int const value = 4; -+}; -+ -+/// Defines the size of an element in bits - specialized for uint4b_t -+template <> -+struct sizeof_bits { -+ static int const value = 4; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace platform { -+ -+template <> -+struct numeric_limits { -+ CUTLASS_HOST_DEVICE -+ static cutlass::int4b_t const lowest() noexcept { return -8;} -+ CUTLASS_HOST_DEVICE -+ static cutlass::int4b_t const max() noexcept { return 7;} -+ static constexpr bool is_integer = true; -+}; -+ -+template <> -+struct numeric_limits { -+ CUTLASS_HOST_DEVICE -+ static cutlass::uint4b_t const lowest() noexcept { return 0;} -+ CUTLASS_HOST_DEVICE -+ static cutlass::uint4b_t const max() noexcept { return 15;} -+ static constexpr bool is_integer = true; -+}; -+ -+template <> -+struct numeric_limits { -+ CUTLASS_HOST_DEVICE -+ static cutlass::uint1b_t const lowest() noexcept { return 0;} -+ CUTLASS_HOST_DEVICE -+ static cutlass::uint1b_t const max() noexcept { return 1;} -+ static constexpr bool is_integer = true; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace platform -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/kernel_hardware_info.hpp b/3rdparty/cutlass/include/cutlass/kernel_hardware_info.hpp -new file mode 100644 -index 0000000..3ae0932 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/kernel_hardware_info.hpp -@@ -0,0 +1,71 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include "cuda_runtime.h" -+ -+#include "cutlass/trace.h" -+ -+namespace cutlass { -+ -+struct KernelHardwareInfo { -+ // -+ // Data members -+ // -+ int device_id = 0; -+ int sm_count = 0; -+ -+ // -+ // Methods -+ // -+ -+ static int -+ query_device_multiprocessor_count(int device_id = 0) { -+ cudaError_t result = cudaGetDevice(&device_id); -+ if (result != cudaSuccess) { -+ CUTLASS_TRACE_HOST( -+ " cudaGetDevice() returned error " -+ << cudaGetErrorString(result)); -+ return 0; -+ } -+ cudaDeviceProp properties; -+ result = cudaGetDeviceProperties(&properties, device_id); -+ if (result != cudaSuccess) { -+ CUTLASS_TRACE_HOST( -+ " cudaGetDeviceProperties() returned error " -+ << cudaGetErrorString(result)); -+ return 0; -+ } -+ return properties.multiProcessorCount; -+ } -+}; -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/kernel_launch.h b/3rdparty/cutlass/include/cutlass/kernel_launch.h -new file mode 100644 -index 0000000..c54f1fa ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/kernel_launch.h -@@ -0,0 +1,73 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines structures and helpers to launch CUDA kernels within CUTLASS. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+namespace cutlass { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure containing the basic launch configuration of a CUDA kernel. -+struct KernelLaunchConfiguration { -+ -+ /// CUDA grid dimensions -+ dim3 grid; -+ -+ /// CUDA threablock dimensions -+ dim3 block; -+ -+ /// Bytes of dynamically allocated SMEM in addition to static SMEM -+ size_t dynamic_smem; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a KernellaunchConfiguration object -+ CUTLASS_HOST_DEVICE -+ KernelLaunchConfiguration( -+ dim3 _grid = dim3(1,1,1), -+ dim3 _block = dim3(1,1,1), -+ size_t _dynamic_smem = 0 -+ ): -+ grid(_grid), -+ block(_block), -+ dynamic_smem(_dynamic_smem) { } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/layout/layout.h b/3rdparty/cutlass/include/cutlass/layout/layout.h -new file mode 100644 -index 0000000..6f638eb ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/layout/layout.h -@@ -0,0 +1,64 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines layout functions used by TensorRef and derived classes. -+ -+ Layout functions map logical coordinates to linear memory. They often require additional -+ data to describe strides between elements. -+ -+ Layout functions must implement all members in the public interface of IdentityTensorLayout<> -+ defined in cutlass/tensor_ref.h. -+*/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/vector.h" -+ -+#include "cutlass/layout/tensor_op_multiplicand_sm70.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace layout { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace layout -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/layout/matrix.h b/3rdparty/cutlass/include/cutlass/layout/matrix.h -new file mode 100644 -index 0000000..fe7a848 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/layout/matrix.h -@@ -0,0 +1,1371 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines layout functions used by TensorRef and derived classes. -+ -+ Layout functions map logical coordinates to linear memory. They often require additional -+ data to describe strides between elements. -+ -+ Layout functions must implement all members in the public interface of IdentityTensorLayout<> -+ defined in cutlass/tensor_ref.h. -+*/ -+#pragma once -+ -+#include "cute/layout.hpp" -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/pitch_linear_coord.h" -+ -+namespace cutlass { -+namespace layout { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Defines data layouts of various matrix formats usable by TensorRef and other classes. -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Mapping function for row-major matrices. -+class RowMajor { -+public: -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Stride data member -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ RowMajor(LongIndex ldm = 0): stride_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajor(Stride stride): stride_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static RowMajor packed(MatrixCoord const &extent) { -+ return RowMajor(extent.column()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (row, column) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(MatrixCoord const &coord) const { -+ return LongIndex(coord.row()) * LongIndex(stride_[0]) + coord.column(); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ MatrixCoord inverse(LongIndex offset) const { -+ return MatrixCoord(Index(offset / stride_[0]), Index(offset % stride_[0])); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index stride(int idx) const { -+ return stride_[idx]; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index & stride(int idx) { -+ return stride_[idx]; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(MatrixCoord const &extent) const { -+ return LongIndex(extent.row()) * LongIndex(stride_[0]); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ cute::Layout, cute::Stride > > -+ to_cute_layout(MatrixCoord const &extent) const { -+ return cute::Layout, cute::Stride > >{ -+ {extent[0], extent[1]}, -+ {stride(0), cute::Int<1>{}} -+ }; -+ } -+}; -+ -+/// Mapping function for column-major matrices. -+class ColumnMajor { -+public: -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Stride data member -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajor(LongIndex ldm = 0): stride_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajor(Stride stride): stride_(stride) { } -+ -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static ColumnMajor packed(MatrixCoord const &extent) { -+ return ColumnMajor(extent.row()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (row, column) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(MatrixCoord const &coord) const { -+ return LongIndex(coord.column()) * LongIndex(stride_[0]) + coord.row(); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ MatrixCoord inverse(LongIndex offset) const { -+ return MatrixCoord(Index(offset % stride_[0]), Index(offset / stride_[0])); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index stride(int idx) const { -+ return stride_[idx]; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index & stride(int idx) { -+ return stride_[idx]; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(MatrixCoord const &extent) const { -+ return LongIndex(extent.column()) * LongIndex(stride_[0]); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ cute::Layout, cute::Stride< cute::Int<1>, int64_t> > -+ to_cute_layout(MatrixCoord const &extent) const { -+ return cute::Layout, cute::Stride, int64_t> >{ -+ {extent[0], extent[1]}, -+ {cute::Int<1>{}, stride(0)} -+ }; -+ } -+}; -+ -+/// Mapping function for interleaved matrices. Matrix is structured -+/// as row-major arrangement of fixed-size columns. -+template -+struct RowMajorInterleaved { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ /// Size of interleaved columns -+ static int const kInterleave = Interleave; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Stride data member -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorInterleaved(LongIndex ldm = 0): stride_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorInterleaved(Stride stride): stride_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static RowMajorInterleaved packed(MatrixCoord const &extent) { -+ return RowMajorInterleaved(extent.column() * kInterleave); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (row, column) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(MatrixCoord const &coord) const { -+ Index row_major = coord.row() / kInterleave; -+ Index row_minor = coord.row() % kInterleave; -+ return LongIndex(row_major) * LongIndex(stride_[0]) + LongIndex(coord.column()) * kInterleave + row_minor; -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ MatrixCoord inverse(LongIndex offset) const { -+ -+ Index row_major = Index(offset / stride_[0]); -+ Index residual = Index(offset % stride_[0]); -+ -+ Index column = residual / kInterleave; -+ Index row_minor = residual % kInterleave; -+ -+ return MatrixCoord(row_major * kInterleave + row_minor, column); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index stride(int idx) const { -+ return stride_[idx]; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index & stride(int idx) { -+ return stride_[idx]; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(MatrixCoord const &extent) const { -+ return (extent.row() + kInterleave - 1) / kInterleave * stride_[0]; -+ } -+}; -+ -+/// Mapping function for interleaved matrices. Matrix is structured -+/// as column-major arrangement of fixed-size rows. -+template -+struct ColumnMajorInterleaved { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ /// Size of interleaved columns -+ static int const kInterleave = Interleave; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Stride data member -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorInterleaved(LongIndex ldm = 0): stride_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorInterleaved(Stride stride): stride_(stride) { } -+ -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static ColumnMajorInterleaved packed(MatrixCoord const &extent) { -+ return ColumnMajorInterleaved(extent.row() * kInterleave); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (row, column) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(MatrixCoord const &coord) const { -+ Index column_major = coord.column() / kInterleave; -+ Index column_minor = coord.column() % kInterleave; -+ return LongIndex(column_major) * LongIndex(stride_[0]) + LongIndex(coord.row()) * kInterleave + column_minor; -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ MatrixCoord inverse(LongIndex offset) const { -+ -+ Index column_major = Index(offset / stride_[0]); -+ Index residual = Index(offset % stride_[0]); -+ -+ Index row = residual / kInterleave; -+ Index column_minor = residual % kInterleave; -+ -+ return MatrixCoord(row, column_major * kInterleave + column_minor); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index stride(int idx) const { -+ return stride_[idx]; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index & stride(int idx) { -+ return stride_[idx]; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(MatrixCoord const &extent) const { -+ return (extent.column() + kInterleave - 1) / kInterleave * stride_[0]; -+ } -+}; -+ -+/// Enumerated type for canonical pitch-linear matrix layouts -+enum class Matrix { -+ kColumnMajor, ///< leading dimension refers to stride between columns; stride along rows is 1 -+ kRowMajor ///< leading dimension refers to stride between rows; stride along columns is 1 -+}; -+ -+/// Mapping function for scenario in which layout is row-major or column-major but this information -+/// is only available at runtime. -+struct ContiguousMatrix { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Stride data member -+ Stride stride_; -+ -+ /// Enumerated type indicating canonical matrix layout -+ Matrix layout_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ContiguousMatrix( -+ Index ldm = 0, -+ Matrix layout = Matrix::kColumnMajor -+ ): -+ stride_(ldm), layout_(layout) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static ContiguousMatrix packed( -+ MatrixCoord const &extent, -+ Matrix layout = Matrix::kColumnMajor) { -+ -+ Index ldm = 0; -+ if (layout == Matrix::kColumnMajor) { -+ ldm = extent.row(); -+ } -+ else if (layout == Matrix::kRowMajor) { -+ ldm = extent.column(); -+ } -+ return ContiguousMatrix(ldm, layout); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (row, column) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(MatrixCoord const &coord) const { -+ if (layout_ == Matrix::kColumnMajor) { -+ return coord.row() + coord.column() * stride_[0]; -+ } -+ else if (layout_ == Matrix::kRowMajor) { -+ return coord.row() * stride_[0] + coord.column(); -+ } -+ else { -+ // degenerate case -+ return 0; -+ } -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ MatrixCoord inverse(LongIndex offset) const { -+ // TODO -+ return MatrixCoord(0, 0); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index stride(int idx) const { -+ return stride_[idx]; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index & stride(int idx) { -+ return stride_[idx]; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(MatrixCoord const &extent) const { -+ if (layout_ == Matrix::kColumnMajor) { -+ return stride_[0] * extent.column(); -+ } -+ else if (layout_ == Matrix::kRowMajor) { -+ return stride_[0] * extent.row(); -+ } -+ else { -+ // degenerate case -+ return 0; -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Mapping function for scenario in which both rows and columns are separated by a stride. -+template -+struct AffineRankN { -+ -+ /// Logical rank of tensor -+ static int const kRank = Rank; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = kRank; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = Coord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Stride data member -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ AffineRankN( -+ Stride const &stride = Stride() -+ ): -+ stride_(stride) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ AffineRankN( -+ Coord const &stride_m, -+ Coord const &stride_n -+ ) { -+ -+ // Concatenate the strides -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kRank/2; ++m) { -+ stride_[m] = stride_m[m]; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kRank/2; ++n) { -+ stride_[n + kRank/2] = stride_n[n]; -+ } -+ } -+ -+ /// Ctor for N = 2 -+ CUTLASS_HOST_DEVICE -+ AffineRankN( -+ LongIndex const &stride_m, -+ LongIndex const &stride_n -+ ) { -+ stride_[0] = stride_m; -+ stride_[1] = stride_n; -+ } -+ -+ /// Ctor for N = 2 -+ CUTLASS_HOST_DEVICE -+ AffineRankN( -+ LongIndex const &stride -+ ) { -+ stride_[0] = stride; -+ stride_[1] = 1; -+ } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static AffineRankN packed(TensorCoord const &extent) { -+ -+ AffineRankN layout; -+ layout.stride_[kRank - 1] = 1; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = kRank - 1; i > 0; --i) { -+ layout.stride_[i - 1] = layout.stride_[i] * extent[i]; -+ } -+ -+ return layout; -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (row, column) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return dot(coord, stride_); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex offset) const { -+ // TODO -+ return TensorCoord(); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index stride(int idx) const { -+ return stride_[idx]; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index & stride(int idx) { -+ return stride_[idx]; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ int idx = stride_.max_dim_index(); -+ return extent[idx] * stride_[idx]; -+ } -+}; -+ -+/// Mapping function for scenario in which both rows and columns are separated by a stride. -+/// Row stride is smaller than column stride in AffineRank2ColumnMajor. -+struct AffineRank2ColumnMajor { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 2; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Stride data member -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ AffineRank2ColumnMajor( -+ Stride const &stride = Stride() -+ ): -+ stride_(stride) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ AffineRank2ColumnMajor( -+ LongIndex row_stride, ///< stride between elements in consecutive rows -+ LongIndex column_stride ///< stride between elements in consecutive columns -+ ) -+ { stride_[0] = row_stride; stride_[1] = column_stride;} -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ AffineRank2ColumnMajor( -+ LongIndex stride -+ ) -+ { stride_[0] = 1; stride_[1] = stride;} -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static AffineRank2ColumnMajor packed(MatrixCoord const &extent) { -+ return AffineRank2ColumnMajor(extent.column(), 1); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (row, column) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(MatrixCoord const &coord) const { -+ return dot(coord, stride_); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ MatrixCoord inverse(LongIndex offset) const { -+ // TODO -+ return MatrixCoord(0, 0); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index stride(int idx) const { -+ return stride_[idx]; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index & stride(int idx) { -+ return stride_[idx]; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(MatrixCoord const &extent) const { -+ return extent.column() * stride_[1]; -+ } -+}; -+ -+/// Mapping function for scenario in which both rows and columns are separated by a stride. -+/// Column stride is smaller than row stride in AffineRank2RowMajor. -+struct AffineRank2RowMajor { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 2; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Stride data member -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ AffineRank2RowMajor( -+ Stride const &stride = Stride() -+ ): -+ stride_(stride) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ AffineRank2RowMajor( -+ LongIndex row_stride, ///< stride between elements in consecutive rows -+ LongIndex column_stride ///< stride between elements in consecutive columns -+ ) { stride_[0] = row_stride; stride_[1] = column_stride;} -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ AffineRank2RowMajor( -+ LongIndex stride -+ ) { stride_[0] = stride; stride_[1] = 1;} -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static AffineRank2RowMajor packed(MatrixCoord const &extent) { -+ return AffineRank2RowMajor(extent.column(), 1); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (row, column) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(MatrixCoord const &coord) const { -+ return dot(coord, stride_); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ MatrixCoord inverse(LongIndex offset) const { -+ // TODO -+ return MatrixCoord(0, 0); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index stride(int idx) const { -+ return stride_[idx]; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index & stride(int idx) { -+ return stride_[idx]; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(MatrixCoord const &extent) const { -+ return extent.row() * stride_[0]; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Utility functions to convert stride_factor to the strides used by the Affine2 layout. -+// -+// stride_factor is the logical distance between two coorinates. -+// -+// All Coodinates used here are matrix coordinates. stride[0] and extent[0] are for the -+// rows. stride[1] and extent[1] are for the columns. -+template -+ struct Affine2Layout_Factory { -+ CUTLASS_HOST_DEVICE -+ static Affine2Layout layout_factory(cutlass::Coord<2> const &extent, typename Affine2Layout::Stride stride_factor) { -+ return Affine2Layout::packed(extent); -+ } -+}; -+ -+template <> -+struct Affine2Layout_Factory { -+CUTLASS_HOST_DEVICE -+static cutlass::layout::AffineRank2ColumnMajor layout_factory( -+ cutlass::Coord<2> const &extent, -+ typename cutlass::layout::AffineRank2ColumnMajor::Stride stride_factor) { -+ return cutlass::layout::AffineRank2ColumnMajor({ stride_factor[0], stride_factor[0] * stride_factor[1] * extent[0] }); -+ } -+}; -+ -+template <> -+struct Affine2Layout_Factory { -+CUTLASS_HOST_DEVICE -+static cutlass::layout::AffineRank2RowMajor layout_factory( -+ cutlass::Coord<2> const &extent, -+ typename cutlass::layout::AffineRank2RowMajor::Stride stride_factor) { -+ return cutlass::layout::AffineRank2RowMajor({ stride_factor[0] * stride_factor[1] * extent[1], stride_factor[1] }); -+ } -+}; -+ -+// The base layout cutlass::layout::AffineRankN<2> is similar to AffineRank2ColumnMajor -+template <> -+struct Affine2Layout_Factory> { -+CUTLASS_HOST_DEVICE -+static cutlass::layout::AffineRankN<2> layout_factory( -+ cutlass::Coord<2> const &extent, -+ typename cutlass::layout::AffineRankN<2>::Stride stride_factor) { -+ return cutlass::layout::AffineRankN<2>({ stride_factor[0], stride_factor[0] * stride_factor[1] * extent[0] }); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Mapping function for block-linear matrices. Matrix is structured -+/// as column-major arrangement of 2D tiles (that are column-major). -+template -+struct ColumnMajorBlockLinear { -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ /// Size of a block in rows -+ static int const kBlockRows = BlockRows; -+ -+ /// Size of a block in columns -+ static int const kBlockColumns = BlockColumns; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Stride data member -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorBlockLinear(Index ldm = 0): stride_(ldm) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static ColumnMajorBlockLinear packed(MatrixCoord const &extent) { -+ return ColumnMajorBlockLinear(extent.row() * kBlockRows * kBlockColumns); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (row, column) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(MatrixCoord const &coord) const { -+ return -+ (coord.row() % kBlockRows) + -+ (coord.column() % kBlockColumns) * kBlockRows + -+ (coord.row() / kBlockRows) * kBlockRows * kBlockColumns + -+ (coord.column() / kBlockColumns) * stride_[0]; -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ MatrixCoord inverse(LongIndex offset) const { -+ -+ // TODO -+ return MatrixCoord(0, 0); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index stride(int idx) const { -+ return stride_[idx]; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index & stride(int idx) { -+ return stride_[idx]; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(MatrixCoord const &extent) const { -+ return (extent.column() + kBlockColumns - 1) / kBlockColumns * stride_[0]; -+ } -+}; -+ -+/// Mapping function for block-linear matrices. Matrix is structured -+/// as row-major arrangement of 2D tiles (that are row-major) -+template -+struct RowMajorBlockLinear { -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ /// Size of a block in rows -+ static int const kBlockRows = BlockRows; -+ -+ /// Size of a block in columns -+ static int const kBlockColumns = BlockColumns; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Stride data member -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorBlockLinear(Index ldm = 0): stride_(ldm) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static RowMajorBlockLinear packed(MatrixCoord const &extent) { -+ return RowMajorBlockLinear(extent.column() * kBlockRows * kBlockColumns); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (row, column) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(MatrixCoord const &coord) const { -+ return -+ (coord.column() % kBlockColumns) + -+ (coord.row() % kBlockRows) * kBlockColumns + -+ (coord.column() / kBlockColumns) * kBlockRows * kBlockColumns + -+ (coord.row() / kBlockRows) * stride_[0]; -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ MatrixCoord inverse(LongIndex offset) const { -+ // TODO -+ return MatrixCoord(0, 0); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index stride(int idx) const { -+ return stride_[idx]; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index & stride(int idx) { -+ return stride_[idx]; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(MatrixCoord const &extent) const { -+ return (extent.row() + kBlockRows - 1) / kBlockRows * stride_[0]; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct GeneralMatrix { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 2; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+private: -+ // -+ // Data members -+ // -+ -+ Matrix layout_id_; -+ -+ /// Stride data member -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ GeneralMatrix(): layout_id_(Matrix::kColumnMajor), stride_(make_Coord(0, 1)) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ GeneralMatrix( -+ Matrix layout_id, -+ Index ldm, -+ Index interleave): layout_id_(layout_id), stride_(make_Coord(ldm, interleave)) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static GeneralMatrix packed( -+ MatrixCoord const &extent, -+ Matrix layout_id = Matrix::kColumnMajor, -+ Index interleave = 1) { -+ -+ Index c; -+ if (layout_id == Matrix::kRowMajor) { -+ c = extent.column(); -+ } -+ else { -+ c = extent.row(); -+ } -+ -+ Index ldm = c * interleave; -+ -+ return GeneralMatrix(layout_id, ldm, interleave); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (row, column) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(MatrixCoord const &coord) const { -+ Index c, s; -+ if (layout_id_ == Matrix::kRowMajor) { -+ c = coord.column(); -+ s = coord.row(); -+ } -+ else { -+ s = coord.column(); -+ c = coord.row(); -+ } -+ -+ Index v = s / stride_[1]; -+ Index residual = (s % stride_[1]); -+ -+ return LongIndex(c) * LongIndex(stride_[1]) + LongIndex(v) * LongIndex(stride_[0]) + residual; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix layout_id() const { -+ return layout_id_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix & layout_id() { -+ return layout_id_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index stride(int idx) const { -+ return stride_[idx]; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ typename Stride::Index & stride(int idx) { -+ return stride_[idx]; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(MatrixCoord const &extent) const { -+ Index s; -+ if (layout_id_ == Matrix::kRowMajor) { -+ s = extent.row(); -+ } -+ else { -+ s = extent.column(); -+ } -+ -+ Index v = Index((s + stride_[1] - 1) / stride_[1]); -+ return LongIndex(v) * LongIndex(stride_[0]); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines transposes of matrix layouts -+template -+struct LayoutTranspose; -+ -+/// Transpose of row-major is column-major -+template <> -+struct LayoutTranspose { -+ using type = layout::ColumnMajor; -+}; -+ -+/// Transpose of column-major is row-major -+template <> -+struct LayoutTranspose { -+ using type = layout::RowMajor; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace layout -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/layout/permute.h b/3rdparty/cutlass/include/cutlass/layout/permute.h -new file mode 100644 -index 0000000..693425b ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/layout/permute.h -@@ -0,0 +1,314 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines layout functions used by GEMM+permute path for common tensor or matrix formats. -+ -+ Like Layout functions, permute layout functions map logical coordinates to linear memory. They often require additional -+ data to describe strides between elements. -+ -+ Permute layout functions must implement all members in the interface of NoPermute<> defined in this file. Address offset -+ computation lies in operator() with private member variables {col_permute_, row_permute_ and stride_permute_} as new addresses after permute op. -+*/ -+#pragma once -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include "assert.h" -+#endif -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/coord.h" -+#include "cutlass/tensor_coord.h" -+ -+namespace cutlass { -+namespace layout { -+ -+class NoPermute { -+public: -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+private: -+ // -+ // Data members -+ // -+ -+ MatrixCoord extent_; -+ -+ Index stride_unit_; // sizeof(AccessType) / kElementsPerAccess in epilogue's predicated_tile_iterator -+ -+ Index stride_permute_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ NoPermute() { } -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ NoPermute(MatrixCoord extent, Index stride_init): extent_(extent) { } -+ -+ /// Computes the address offset after Permute Op in Bytes -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(MatrixCoord offset_init) { return 0; } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Defines permute layouts of various tensor formats. -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Permute layout function for 4-D permuted tensors with output matrix (dimension as [M, N]) reshaped -+/// as [M/D1, D1, D2, N/D2]. Then perform permute([0, 2, 1, 3]) on the corresponding output tensor. -+template -+class Tensor4DPermute0213 { -+public: -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+private: -+ // -+ // Data members -+ // -+ -+ MatrixCoord extent_; -+ -+ Index stride_permute_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ Tensor4DPermute0213() { } -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ Tensor4DPermute0213(MatrixCoord extent, Index stride_init): extent_(extent) { -+ -+ /// Update stride_permute with stride_init -+ stride_permute_ = stride_init / D2 * D1; // stride in Elements -+ -+ } -+ -+ /// Computes the address offset after Permute Op in Bytes -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(MatrixCoord offset_init) { -+ // Permute as torch.permute(X1, [0, 2, 1, 3]) -> 4D Tensor indices as [i,j,k,l], the dimension of X -+ // is [D0, D1, D2, D3], after permutation the dim of X1 is [D0, D2, D1, D3]. -+ assert(extent_.row() % D1 == 0); -+ assert(extent_.column() % D2 == 0); -+ -+ int D3 = extent_.column() / D2; -+ -+ Index col_init = offset_init.column(); -+ Index row_init = offset_init.row(); -+ -+ int l = col_init % D3; -+ int k = col_init / D3; -+ int j = row_init % D1; -+ int i = row_init / D1; -+ -+ // After the Permute Op -+ Index col_permute = l + j * D3; -+ Index row_permute = k + i * D2; -+ -+ return LongIndex(row_permute) * LongIndex(stride_permute_) + LongIndex(col_permute); -+ } -+ -+ /// Return D1 -+ CUTLASS_HOST_DEVICE -+ Index d1() const { -+ return D1; -+ } -+ -+ /// Return D2 -+ CUTLASS_HOST_DEVICE -+ Index d2() const { -+ return D2; -+ } -+}; -+ -+/// Permute layout function for 4-D permuted tensors for BMM with BMM output tensor (dimension as [B, M, N]) reshaped -+/// as [B/D1, D1, M, N]. Then perform permute([0, 2, 1, 3]) on the corresponding whole BMM output tensor. -+template -+class Tensor4DPermuteBMM0213 { -+public: -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+private: -+ // -+ // Data members -+ // -+ -+ MatrixCoord extent_; -+ -+ Index stride_permute_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ Tensor4DPermuteBMM0213() { } -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ Tensor4DPermuteBMM0213(MatrixCoord extent, Index stride_init): extent_(extent) { -+ -+ /// Update stride_permute with stride_init -+ stride_permute_ = stride_init * D1; // stride in Elements -+ -+ } -+ -+ /// Computes the address offset after Permute Op in Bytes -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(MatrixCoord offset_init) { -+ -+ // The batch index for BMM -+ Index BMM_batch_idx = blockIdx.z; -+ -+ // Permute as torch.permute(X1, [0, 2, 1, 3]) -> 4D Tensor indices as [i,j,k,l], the dimension of X -+ // is [D0, D1, D2, D3], after permutation the dim of X1 is [D0, D2, D1, D3]. -+ int D2 = extent_.row(); -+ int D3 = extent_.column(); -+ -+ Index col_init = offset_init.column(); -+ Index row_init = offset_init.row(); -+ -+ int l = col_init; -+ int k = row_init; -+ int j = BMM_batch_idx % D1; -+ int i = BMM_batch_idx / D1; -+ -+ // After the Permute Op -+ Index col_permute = l + j * D3; -+ Index row_permute = k + i * D2; -+ -+ return LongIndex(row_permute) * LongIndex(stride_permute_) + LongIndex(col_permute); -+ } -+ -+ /// Return D1 -+ CUTLASS_HOST_DEVICE -+ Index d1() const { -+ return D1; -+ } -+}; -+ -+/// Permute layout function for 5-D permuted tensors with output matrix (dimension as [M, N]) reshaped -+/// as [M/T1, T1, T2, T3, N/T2/T3]. Then perform permute([2, 0, 3, 1, 4]) on the corresponding output tensor. -+template -+class Tensor5DPermute20314 { -+public: -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+private: -+ // -+ // Data members -+ // -+ -+ MatrixCoord extent_; -+ -+ Index stride_permute_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ Tensor5DPermute20314() { } -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ Tensor5DPermute20314(MatrixCoord extent, Index stride_init): extent_(extent) { -+ -+ /// Update stride_permute with stride_init -+ stride_permute_ = stride_init / T2 * T1; // stride in Elements -+ -+ } -+ -+ /// Computes the address offset after Permute Op in Bytes -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(MatrixCoord offset_init) { -+ -+ // Permute as torch.permute(X1, [2, 0, 3, 1, 4]) -> 5D Tensor indices as [i,j,k,l,m], the dimension of X -+ // is [T0, T1, T2, T3, T4], after permutation the dim of X1 is [T2, T0, T3, T1, T4]. -+ int T0 = extent_.row() / T1; -+ int T4 = extent_.column() / T2 / T3; -+ -+ Index col_init = offset_init.column(); -+ Index row_init = offset_init.row(); -+ -+ int m = col_init % T4; -+ int l = int(col_init / T4) % T3; -+ int k = int(col_init / T4) / T3; -+ int j = row_init % T1; -+ int i = row_init / T1; -+ -+ // After the Permute Op -+ Index col_permute = m + j * T4 + l * T1 * T4; -+ Index row_permute = i + k * T0; -+ -+ return LongIndex(row_permute) * LongIndex(stride_permute_) + LongIndex(col_permute); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace layout -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/layout/pitch_linear.h b/3rdparty/cutlass/include/cutlass/layout/pitch_linear.h -new file mode 100644 -index 0000000..b49ab95 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/layout/pitch_linear.h -@@ -0,0 +1,148 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines layout functions used by TensorRef and derived classes for pitch-linear memory. -+*/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/coord.h" -+#include "cutlass/pitch_linear_coord.h" -+ -+namespace cutlass { -+namespace layout { -+ -+template -+ using PitchLinearShape = cutlass::PitchLinearShape < Contiguous, Strided >; -+ using PitchLinearCoord = PitchLinearCoord; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Mapping function for pitch-linear memory -+class PitchLinear { -+public: -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = PitchLinearCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Stride data member -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ PitchLinear(LongIndex ldm = 0): stride_(ldm) { } -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ PitchLinear(Stride _stride): stride_(_stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static PitchLinear packed(TensorCoord const &extent) { -+ return PitchLinear(extent.contiguous()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return LongIndex(coord.contiguous()) + LongIndex(coord.strided()) * LongIndex(stride_[0]); -+ } -+ -+ /// Returns the logical coordinate given an offset. -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex index) const { -+ return make_Coord( -+ TensorCoord::Index(index % stride_[0]), -+ TensorCoord::Index(index / stride_[0]) -+ ); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ LongIndex stride(int rank) const { -+ return stride_[rank]; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ LongIndex & stride(int rank) { -+ return stride_[rank]; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return extent.strided() * stride_[0]; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace layout -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/include/cutlass/layout/tensor.h b/3rdparty/cutlass/include/cutlass/layout/tensor.h -new file mode 100644 -index 0000000..29ac570 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/layout/tensor.h -@@ -0,0 +1,636 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines layout functions used by TensorRef and derived classes for common 4-D and 5-D -+ tensor formats. -+ -+ Layout functions map logical coordinates to linear memory. They often require additional -+ data to describe strides between elements. -+ -+ Layout functions must implement all members in the public interface of IdentityTensorLayout<> -+ defined in cutlass/tensor_ref.h. -+*/ -+#pragma once -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include "assert.h" -+#endif -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/coord.h" -+#include "cutlass/tensor_coord.h" -+ -+namespace cutlass { -+namespace layout { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Defines data layouts of various tensor formats usable by TensorRef and other classes. -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Mapping function for 4-D NHWC tensors. -+class TensorNHWC { -+public: -+ /// Logical rank of tensor -+ static int const kRank = 4; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 3; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate (n, h, w, c) -+ using TensorCoord = Tensor4DCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Stride data member - [stride_w, stride_h, stride_n] -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ TensorNHWC(Stride const &stride = Stride(0)): stride_(stride) { } -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ TensorNHWC( -+ typename Stride::Index stride_w, ///< number of elements between adjacent W coordinates -+ typename Stride::Index stride_h, ///< number of elements between adjacent H coordinates -+ typename Stride::Index stride_n ///< number of elements between adjacent N coordinates -+ ): -+ stride_(make_Coord(stride_w, stride_h, stride_n)) { } -+ -+ /// Constructor -+ // Once convolutions implement 64b stride this ctor can be deleted -+ CUTLASS_HOST_DEVICE -+ TensorNHWC(Coord const &stride): -+ stride_(make_Coord( -+ static_cast(stride[0]), -+ static_cast(stride[1]), -+ static_cast(stride[2])) -+ ) { } -+ -+ /// Helper returns a layout to a tightly packed NHWC tensor. -+ CUTLASS_HOST_DEVICE -+ static TensorNHWC packed(TensorCoord const &extent) { -+ return TensorNHWC( -+ make_Coord( -+ extent.c(), -+ extent.w() * extent.c(), -+ extent.h() * extent.w() * extent.c() -+ ) -+ ); -+ } -+ -+ /// Returns the offset of a coordinate (n, h, w, c) in linear memory. -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return coord.c() + -+ LongIndex(stride_[0] * coord.w()) + -+ LongIndex(stride_[1] * coord.h()) + -+ LongIndex(stride_[2] * coord.n()); -+ } -+ -+ /// Returns the offset of a pitchlinear coordinate in linear memory. -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(PitchLinearCoord coord) const { -+ return coord.contiguous() + LongIndex(coord.strided() * stride_[2]); -+ } -+ -+ /// Returns the logical coordinate (n, h, w, c) from a given offset in linear memory. -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex index) const { -+ -+ int n = 0, h = 0, w = 0, c = 0; -+ -+ #if defined(__CUDA_ARCH__) -+ int tmp = 0; -+ c = int(index % static_cast(stride_[0])); -+ -+ unsigned int hw_mul, hw_shr, w_mul, w_shr, c_mul, c_shr; -+ -+ find_divisor(hw_mul, hw_shr, stride_[2]); -+ find_divisor(w_mul, w_shr, stride_[1]); -+ find_divisor(c_mul, c_shr, stride_[0]); -+ -+ fast_divmod(n, tmp, index, int(stride_[2]), hw_mul, hw_shr); -+ fast_divmod(h, w, tmp, int(stride_[1]), w_mul, w_shr); -+ fast_divmod(w, tmp, w, int(stride_[0]), c_mul, c_shr); -+ #else -+ -+ n = int(index / stride_[2]); -+ LongIndex residual = index % stride_[2]; -+ -+ h = int(residual / stride_[1]); -+ residual = (residual % stride_[1]); -+ -+ w = int(residual / stride_[0]); -+ c = int(residual % stride_[0]); -+ -+ #endif -+ return TensorCoord(n, h, w, c); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ // it does not make sense if the extent is larger than stride -+ // and we could not rely on the capacity calculation in such cases -+ // we could move this checkers to debug code only -+ if ((extent.c() > stride_[0]) -+ || (extent.w() * stride_[0] > stride_[1]) -+ || (extent.h() * stride_[1] > stride_[2])) { -+ assert(0); -+ } -+ return extent.n() * stride_[2]; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Mapping function for 4-D NCHW tensors. -+class TensorNCHW { -+public: -+ /// Logical rank of tensor -+ static int const kRank = 4; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 3; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = Tensor4DCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Stride data member - [w, hw, chw] -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ TensorNCHW(Stride const &stride = Stride(0)): stride_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static TensorNCHW packed(TensorCoord const &extent) { -+ return TensorNCHW( -+ make_Coord( -+ extent.w(), -+ extent.w() * extent.h(), -+ extent.h() * extent.w() * extent.c() -+ ) -+ ); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return coord.w() + -+ LongIndex(stride_[0] * coord.h()) + -+ LongIndex(stride_[1] * coord.c()) + -+ LongIndex(stride_[2] * coord.n()); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return extent.n() * stride_[2]; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Mapping function for 4-D NC/xHWx tensors. -+template -+class TensorNCxHWx { -+public: -+ -+ /// Interleaving quantity -+ static int const kInterleave = Interleave; -+ -+ /// Logical rank of tensor -+ static int const kRank = 4; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 3; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = Tensor4DCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Stride data member - [Interleave x w, Interleave x wh, hwc] -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ TensorNCxHWx(Stride const &stride = Stride(0)): stride_(stride) { } -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ TensorNCxHWx( -+ typename Stride::Index stride_w, ///< number of elements between adjacent W coordinates -+ typename Stride::Index stride_h, ///< number of elements between adjacent H coordinates -+ typename Stride::Index stride_n ///< number of elements between adjacent N coordinates -+ ): -+ stride_(make_Coord(stride_w, stride_h, stride_n)) { } -+ -+ /// Constructor -+ // Once convolutions implement 64b stride this ctor can be deleted -+ CUTLASS_HOST_DEVICE -+ TensorNCxHWx(Coord const &stride): -+ stride_(make_Coord( -+ static_cast(stride[0]), -+ static_cast(stride[1]), -+ static_cast(stride[2])) -+ ) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static TensorNCxHWx packed(TensorCoord const &extent) { -+ return TensorNCxHWx( -+ make_Coord( -+ kInterleave * extent.w(), -+ kInterleave * extent.w() * extent.h(), -+ extent.h() * extent.w() * extent.c() -+ ) -+ ); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ -+ Index c_minor = (coord.c() % kInterleave); -+ Index c_major = (coord.c() / kInterleave); -+ -+ return c_minor + -+ LongIndex(kInterleave * coord.w()) + -+ LongIndex(stride_[0] * coord.h()) + -+ LongIndex(stride_[1] * c_major) + -+ LongIndex(stride_[2] * coord.n()); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return extent.n() * stride_[2]; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Mapping function for 4-D CxRSKx tensors. -+template -+class TensorCxRSKx { -+public: -+ -+ /// Interleaving quantity -+ static int const kInterleave = Interleave; -+ -+ /// Logical rank of tensor -+ static int const kRank = 4; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 3; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = Tensor4DCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Stride data member - [Interleave x n, Interleave x nw, Interleave x nwh] -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ TensorCxRSKx(Stride const &stride = Stride(0)): stride_(stride) { } -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ TensorCxRSKx( -+ typename Stride::Index stride_w, ///< number of elements between adjacent W coordinates -+ typename Stride::Index stride_h, ///< number of elements between adjacent H coordinates -+ typename Stride::Index stride_n ///< number of elements between adjacent N coordinates -+ ): -+ stride_(make_Coord(stride_w, stride_h, stride_n)) { } -+ -+ /// Constructor -+ // Once convolutions implement 64b stride this ctor can be deleted -+ CUTLASS_HOST_DEVICE -+ TensorCxRSKx(Coord const &stride): -+ stride_(make_Coord( -+ static_cast(stride[0]), -+ static_cast(stride[1]), -+ static_cast(stride[2])) -+ ) { } -+ -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static TensorCxRSKx packed(TensorCoord const &extent) { -+ return TensorCxRSKx( -+ make_Coord( -+ kInterleave * extent.n(), -+ kInterleave * extent.n() * extent.w(), -+ kInterleave * extent.n() * extent.w() * extent.h() -+ ) -+ ); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ -+ Index c_minor = (coord.c() % kInterleave); -+ Index c_major = (coord.c() / kInterleave); -+ -+ return c_minor + -+ LongIndex(kInterleave * coord.n()) + -+ LongIndex(stride_[0] * coord.w()) + -+ LongIndex(stride_[1] * coord.h()) + -+ LongIndex(stride_[2] * c_major); -+ } -+ -+ /// Returns the offset of a pitchlinear coordinate in linear memory. -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(PitchLinearCoord const &coord) const { -+ return (coord.contiguous() % kInterleave) + -+ LongIndex((coord.contiguous() / kInterleave) * stride_[2]) + -+ LongIndex(coord.strided() * kInterleave); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return (extent.c() / kInterleave * stride_[2]); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Mapping function for 5-D NDHWC tensors. -+class TensorNDHWC { -+public: -+ /// Logical rank of tensor -+ static int const kRank = 5; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 4; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate (n, d, h, w, c) -+ using TensorCoord = Tensor5DCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Stride data member - [c, wc, hwc, dhwc] -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ TensorNDHWC(Stride const &stride = Stride(0)): stride_(stride) { } -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ TensorNDHWC( -+ typename Stride::Index c, -+ typename Stride::Index wc, -+ typename Stride::Index hwc, -+ typename Stride::Index dhwc): -+ stride_(make_Coord(c, wc, hwc, dhwc)) { } -+ -+ /// Constructor -+ // Once convolutions implement 64b stride this ctor can be deleted -+ CUTLASS_HOST_DEVICE -+ TensorNDHWC(Coord const &stride): -+ stride_(make_Coord( -+ static_cast(stride[0]), -+ static_cast(stride[1]), -+ static_cast(stride[2]), -+ static_cast(stride[3])) -+ ) { } -+ -+ /// Helper returns a layout to a tightly packed NHWC tensor. -+ CUTLASS_HOST_DEVICE -+ static TensorNDHWC packed(TensorCoord const &extent) { -+ return TensorNDHWC( -+ make_Coord( -+ extent.c(), -+ extent.w() * extent.c(), -+ extent.h() * extent.w() * extent.c(), -+ extent.d() * extent.h() * extent.w() * extent.c() -+ ) -+ ); -+ } -+ -+ /// Returns the offset of a coordinate (n, d, h, w, c) in linear memory. -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return coord.c() + -+ LongIndex(stride_[0] * coord.w()) + -+ LongIndex(stride_[1] * coord.h()) + -+ LongIndex(stride_[2] * coord.d()) + -+ LongIndex(stride_[3] * coord.n()); -+ } -+ -+ /// Returns the offset of a pitchlinear coordinate in linear memory. -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(PitchLinearCoord coord) const { -+ return coord.contiguous() + LongIndex(coord.strided() * stride_[3]); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ // it does not make sense if the extent is larger than stride -+ // and we could not rely on the capacity calculation in such cases -+ // we could move this checkers to debug code only -+ if ((extent.c() > stride_[0]) -+ || (extent.w() * stride_[0] > stride_[1]) -+ || (extent.h() * stride_[1] > stride_[2]) -+ || (extent.d() * stride_[2] > stride_[3])) { -+ assert(0); -+ } -+ return extent.n() * stride_[3]; -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace layout -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm70.h b/3rdparty/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm70.h -new file mode 100644 -index 0000000..b127bff ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm70.h -@@ -0,0 +1,1044 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/coord.h" -+#include "cutlass/layout/pitch_linear.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace layout { -+ -+// template < -+// int ElementSize, -+// gemm::Operand Operand -+// > -+// struct VoltaTensorOpMultiplicandCongruous; -+ -+// template < -+// int ElementSize, -+// gemm::Operand Operand -+// > -+// struct ColumnMajorVoltaTensorOpMultiplicandCongruous; -+// template < -+// int ElementSize, -+// gemm::Operand Operand -+// > -+// struct RowMajorVoltaTensorOpMultiplicandCongruous; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template based on element size (in bits) - defined in terms of pitch-linear memory. -+template -+struct VoltaTensorOpMultiplicandCongruous { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = PitchLinearCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ /// This layout is optimized for 128b accesses -+ static int const kAccessSize = 128; -+ -+ /// Fundamental tile shape in units of vectors -+ using TileShape = PitchLinearShape<8, 4>; -+ -+ /// Fundamental partition shape in units of vectors -+ using PartitionShape = PitchLinearShape<8, 2>; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = ElementSize; -+ static int const kElementsPerAccess = kAccessSize / kElementSize; -+ -+ using PartitionCount = PitchLinearShape< -+ TileShape::kContiguous / PartitionShape::kContiguous, -+ TileShape::kStrided / PartitionShape::kStrided -+ >; -+ -+ using AccessCount = PitchLinearShape< -+ PartitionShape::kContiguous, -+ PartitionShape::kStrided -+ >; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Stride data member -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ VoltaTensorOpMultiplicandCongruous(Index ldm = 0): stride_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ VoltaTensorOpMultiplicandCongruous(Stride stride): stride_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static VoltaTensorOpMultiplicandCongruous packed(TensorCoord const &extent) { -+ return VoltaTensorOpMultiplicandCongruous(extent[0]); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ -+ // First, compute c and s of vector within source (in units of vector accesses) -+ int vec_contiguous_idx = coord.contiguous() / kElementsPerAccess; -+ int vec_strided_idx = coord.strided(); -+ -+ // Compute the fundamental tile being accessed -+ int tile_contiguous_idx = vec_contiguous_idx / TileShape::kContiguous; -+ int tile_strided_idx = vec_strided_idx / TileShape::kStrided; -+ -+ int tile_contiguous_residual = vec_contiguous_idx % TileShape::kContiguous; -+ int tile_strided_residual = vec_strided_idx % TileShape::kStrided; -+ -+ // Then swizzle in a tile -+ // Swizzle pattern is (tid[2:0] << 2)|(tid[4:3] ^ tid[2:1]) -+ int permuted_strided_within_tile = (tile_contiguous_residual >> 1); -+ int permuted_contiguous_within_tile = (tile_strided_residual ^ permuted_strided_within_tile) | -+ ((tile_contiguous_residual & 1) << 2); -+ // Compute final element location -+ int element_contiguous = (tile_contiguous_idx * TileShape::kContiguous + -+ permuted_contiguous_within_tile) * kElementsPerAccess + (coord.contiguous() % kElementsPerAccess); -+ -+ int element_strided = tile_strided_idx * TileShape::kStrided + permuted_strided_within_tile; -+ -+ return element_contiguous + element_strided * stride_[0]; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return extent[1] * stride_[0]; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template mapping a column-major view of pitch-linear memory to VoltaTensorOpMultiplicandCongruous -+template -+struct ColumnMajorVoltaTensorOpMultiplicandCongruous { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = VoltaTensorOpMultiplicandCongruous; -+ -+ /// This layout is optimized for 128b accesses -+ static int const kAccessSize = Base::kAccessSize; -+ using TileShape = typename Base::TileShape; -+ using PartitionShape = typename Base::PartitionShape; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = Base::kElementSize; -+ static int const kElementsPerAccess = Base::kElementsPerAccess; -+ using PartitionCount = typename Base::PartitionCount; -+ using AccessCount = typename Base::AccessCount; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorVoltaTensorOpMultiplicandCongruous(Index ldm = 0): layout_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorVoltaTensorOpMultiplicandCongruous(Stride stride): layout_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static ColumnMajorVoltaTensorOpMultiplicandCongruous packed(TensorCoord const &extent) { -+ return ColumnMajorVoltaTensorOpMultiplicandCongruous(extent.row()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(PitchLinearCoord(coord.row(), coord.column())); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex offset) const { -+ PitchLinearCoord coord = layout_.inverse(offset); -+ return MatrixCoord(coord.contiguous(), coord.strided()); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return layout_.stride(); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return layout_.stride(); -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(PitchLinearCoord(extent.row(), extent.column())); -+ } -+}; -+ -+/// Template mapping a row-major view of pitch-linear memory to VoltaTensorOpMultiplicandCongruous -+template -+struct RowMajorVoltaTensorOpMultiplicandCongruous { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = VoltaTensorOpMultiplicandCongruous; -+ -+ /// This layout is optimized for 128b accesses -+ static int const kAccessSize = Base::kAccessSize; -+ using TileShape = typename Base::TileShape; -+ using PartitionShape = typename Base::PartitionShape; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = Base::kElementSize; -+ static int const kElementsPerAccess = Base::kElementsPerAccess; -+ using PartitionCount = typename Base::PartitionCount; -+ using AccessCount = typename Base::AccessCount; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorVoltaTensorOpMultiplicandCongruous(Index ldm = 0): layout_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorVoltaTensorOpMultiplicandCongruous(Stride stride): layout_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static RowMajorVoltaTensorOpMultiplicandCongruous packed(TensorCoord const &extent) { -+ return RowMajorVoltaTensorOpMultiplicandCongruous(extent.column()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(PitchLinearCoord(coord.column(), coord.row())); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex offset) const { -+ PitchLinearCoord coord = layout_.inverse(offset); -+ return MatrixCoord(coord.strided(), coord.contiguous()); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return layout_.stride(); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return layout_.stride(); -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(PitchLinearCoord(extent.column(), extent.row())); -+ } -+}; -+ -+ -+/// Template based on element size (in bits) - defined in terms of pitch-linear memory. -+// template -+template -+struct VoltaTensorOpMultiplicandBCongruous { -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = PitchLinearCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ /// This layout is optimized for 128b accesses -+ static int const kAccessSize = 128; -+ -+ /// Fundamental tile shape in units of vectors -+ using TileShape = PitchLinearShape<8, 4>; -+ -+ /// Fundamental partition shape in units of vectors -+ using PartitionShape = PitchLinearShape<4, 4>; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = ElementSize; -+ static int const kElementsPerAccess = kAccessSize / kElementSize; -+ -+ using PartitionCount = PitchLinearShape< -+ TileShape::kContiguous / PartitionShape::kContiguous, -+ TileShape::kStrided / PartitionShape::kStrided -+ >; -+ -+ using AccessCount = PitchLinearShape< -+ PartitionShape::kContiguous, -+ PartitionShape::kStrided -+ >; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Stride data member -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ VoltaTensorOpMultiplicandBCongruous(Index ldm = 0): stride_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ VoltaTensorOpMultiplicandBCongruous(Stride stride): stride_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static VoltaTensorOpMultiplicandBCongruous packed(TensorCoord const &extent) { -+ return VoltaTensorOpMultiplicandBCongruous(extent[0]); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ -+ // First, compute c and s of vector within source (in units of vector accesses) -+ int vec_contiguous_idx = coord.contiguous() / kElementsPerAccess; -+ int vec_strided_idx = coord.strided(); -+ -+ // Compute the fundamental tile being accessed -+ int tile_contiguous_idx = vec_contiguous_idx / TileShape::kContiguous; -+ int tile_strided_idx = vec_strided_idx / TileShape::kStrided; -+ -+ int tile_contiguous_residual = vec_contiguous_idx % TileShape::kContiguous; -+ int tile_strided_residual = vec_strided_idx % TileShape::kStrided; -+ -+ // Then swizzle in a tile -+ // Swizzle pattern is (tid[1:0] << 3)|(tid & 0x4)|(tid[1:0]) -+ int permuted_strided_within_tile = (tile_contiguous_residual & 0x3); -+ int permuted_contiguous_within_tile = (tile_strided_residual ^ permuted_strided_within_tile) | -+ (tile_contiguous_residual & 0x4); -+ -+ // Compute final element location -+ int element_contiguous = (tile_contiguous_idx * TileShape::kContiguous + -+ permuted_contiguous_within_tile) * kElementsPerAccess + (coord.contiguous() % kElementsPerAccess); -+ -+ int element_strided = tile_strided_idx * TileShape::kStrided + permuted_strided_within_tile; -+ -+ return element_contiguous + element_strided * stride_[0]; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return extent[1] * stride_[0]; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template mapping a column-major view of pitch-linear memory to VoltaTensorOpMultiplicandCongruous -+template -+struct ColumnMajorVoltaTensorOpMultiplicandBCongruous { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = VoltaTensorOpMultiplicandBCongruous; -+ -+ /// This layout is optimized for 128b accesses -+ static int const kAccessSize = Base::kAccessSize; -+ using TileShape = typename Base::TileShape; -+ using PartitionShape = typename Base::PartitionShape; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = Base::kElementSize; -+ static int const kElementsPerAccess = Base::kElementsPerAccess; -+ using PartitionCount = typename Base::PartitionCount; -+ using AccessCount = typename Base::AccessCount; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorVoltaTensorOpMultiplicandBCongruous(Index ldm = 0): layout_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorVoltaTensorOpMultiplicandBCongruous(Stride stride): layout_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static ColumnMajorVoltaTensorOpMultiplicandBCongruous packed(TensorCoord const &extent) { -+ return ColumnMajorVoltaTensorOpMultiplicandBCongruous(extent.row()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(PitchLinearCoord(coord.row(), coord.column())); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex offset) const { -+ PitchLinearCoord coord = layout_.inverse(offset); -+ return MatrixCoord(coord.contiguous(), coord.strided()); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return layout_.stride(); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return layout_.stride(); -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(PitchLinearCoord(extent.row(), extent.column())); -+ } -+}; -+ -+/// Template mapping a row-major view of pitch-linear memory to VoltaTensorOpMultiplicandCongruous -+template -+struct RowMajorVoltaTensorOpMultiplicandBCongruous { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = VoltaTensorOpMultiplicandBCongruous; -+ -+ /// This layout is optimized for 128b accesses -+ static int const kAccessSize = Base::kAccessSize; -+ using TileShape = typename Base::TileShape; -+ using PartitionShape = typename Base::PartitionShape; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = Base::kElementSize; -+ static int const kElementsPerAccess = Base::kElementsPerAccess; -+ using PartitionCount = typename Base::PartitionCount; -+ using AccessCount = typename Base::AccessCount; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorVoltaTensorOpMultiplicandBCongruous(Index ldm = 0): layout_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorVoltaTensorOpMultiplicandBCongruous(Stride stride): layout_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static RowMajorVoltaTensorOpMultiplicandBCongruous packed(TensorCoord const &extent) { -+ return RowMajorVoltaTensorOpMultiplicandBCongruous(extent.column()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(PitchLinearCoord(coord.column(), coord.row())); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex offset) const { -+ PitchLinearCoord coord = layout_.inverse(offset); -+ return MatrixCoord(coord.strided(), coord.contiguous()); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return layout_.stride(); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return layout_.stride(); -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(PitchLinearCoord(extent.column(), extent.row())); -+ } -+}; -+ -+/// Template based on element size (in bits) - defined in terms of pitch-linear -+/// memory and KBlock size (in elements). -+template -+struct VoltaTensorOpMultiplicandCrosswise { -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = PitchLinearCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ /// This layout is optimized for 64b accesses -+ static int const kAccessSize = 64; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = ElementSize; -+ static int const kElementsPerAccess = kAccessSize / kElementSize; -+ static int const kKBlock = KBlock; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Stride data member. For GEMM, it equals to KBlock x stage. -+ Stride stride_; -+ public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ VoltaTensorOpMultiplicandCrosswise(Index ldm = 0) : stride_(ldm) {} -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ VoltaTensorOpMultiplicandCrosswise(Stride stride) : stride_(stride) {} -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static VoltaTensorOpMultiplicandCrosswise packed(TensorCoord const &extent) { -+ return VoltaTensorOpMultiplicandCrosswise(extent[1]); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ -+ // -+ // First, compute c and s of vector within source (in units of vector -+ // accesses) -+ // -+ int vec_contiguous_idx = coord.contiguous() / kElementsPerAccess; -+ int vec_strided_idx = coord.strided(); -+ -+ // -+ // Then swizzle -+ // The mapping is like this: -+ // id[1:0]|(id[3]^id[4])|id[2] -+ -+ int vec_strided_within_tile = vec_contiguous_idx & 0x7; -+ int permuted_vec_contiguous = -+ (vec_strided_idx & (~0xF)) + (vec_strided_idx & 0x3) * 4 + -+ (((vec_strided_idx >> 2) ^ ((vec_strided_idx & 0x10) >> 3)) & 0x3); -+ -+ permuted_vec_contiguous ^= ((vec_strided_within_tile >> 1) & 0x3); -+ -+ int permuted_vec_strided = vec_contiguous_idx; -+ -+ // -+ // Compute final element location -+ // -+ -+ int element_contiguous = permuted_vec_contiguous * kElementsPerAccess + -+ (coord.contiguous() % kElementsPerAccess); -+ -+ return element_contiguous + permuted_vec_strided * (stride_[0] * kElementsPerAccess); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { return stride_; } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride &stride() { return stride_; } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with -+ /// the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return extent[0] * stride_[0]; -+ } -+}; -+ -+/// Template mapping a column-major view of pitch-linear memory to -+/// VoltaTensorOpMultiplicandCrosswise -+template -+struct ColumnMajorVoltaTensorOpMultiplicandCrosswise { -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = VoltaTensorOpMultiplicandCrosswise; -+ -+ /// This layout is optimized for 64b accesses -+ static int const kAccessSize = Base::kAccessSize; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = Base::kElementSize; -+ static int const kElementsPerAccess = Base::kElementsPerAccess; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+ public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorVoltaTensorOpMultiplicandCrosswise(Index ldm = 0) : layout_(ldm) {} -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorVoltaTensorOpMultiplicandCrosswise(Stride stride) : layout_(stride) {} -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static ColumnMajorVoltaTensorOpMultiplicandCrosswise packed( -+ TensorCoord const &extent) { -+ return ColumnMajorVoltaTensorOpMultiplicandCrosswise(extent.column()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(PitchLinearCoord(coord.row(), coord.column())); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex offset) const { -+ PitchLinearCoord coord = layout_.inverse(offset); -+ return MatrixCoord(coord.contiguous(), coord.strided()); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { return layout_.stride(); } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride &stride() { return layout_.stride(); } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with -+ /// the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(PitchLinearCoord(extent.row(), extent.column())); -+ } -+}; -+ -+/// Template mapping a row-major view of pitch-linear memory to -+/// TensorOpMultiplicandCrosswise -+template -+struct RowMajorVoltaTensorOpMultiplicandCrosswise { -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = VoltaTensorOpMultiplicandCrosswise; -+ -+ /// This layout is optimized for 64b accesses -+ static int const kAccessSize = Base::kAccessSize; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = Base::kElementSize; -+ static int const kElementsPerAccess = Base::kElementsPerAccess; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+ public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorVoltaTensorOpMultiplicandCrosswise(Index ldm = 0) : layout_(ldm) {} -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorVoltaTensorOpMultiplicandCrosswise(Stride stride) : layout_(stride) {} -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static RowMajorVoltaTensorOpMultiplicandCrosswise packed( -+ TensorCoord const &extent) { -+ return RowMajorVoltaTensorOpMultiplicandCrosswise(extent.row()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(PitchLinearCoord(coord.column(), coord.row())); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex offset) const { -+ PitchLinearCoord coord = layout_.inverse(offset); -+ return MatrixCoord(coord.strided(), coord.contiguous()); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { return layout_.stride(); } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride &stride() { return layout_.stride(); } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with -+ /// the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(PitchLinearCoord(extent.column(), extent.row())); -+ } -+}; -+ -+} // namespace layout -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm75.h b/3rdparty/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm75.h -new file mode 100644 -index 0000000..14148b7 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm75.h -@@ -0,0 +1,1161 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/coord.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/layout/pitch_linear.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace layout { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template based on element size (in bits) - defined in terms of pitch-linear -+/// memory and Crosswise size (in elements). -+/// This one is the base class of all Ampere/Turing fp16/bf16/int8/int4/int1 -+/// tensor core kernels. tf32 TN uses this too. -+template -+struct TensorOpMultiplicand { -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = PitchLinearCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Static constants -+ // -+ -+ /// This layout is optimized for 128b accesses -+ static int const kAccessSize = 128; -+ -+ static int const kElementSize = ElementSize; -+ static int const kElementsPerAccess = kAccessSize / kElementSize; -+ static int const kCrosswise = Crosswise; -+ -+ /// Contiguous dimension of the tile shape matches one shared memory cache -+ /// line - 128B. For 128bit access size, it equals to 8 accesses. -+ static int const kTileShapeContiguous = 128 / (kAccessSize / 8); -+ -+ /// Number of kblocks to store PartitionShape::kContiguous Elements -+ static int const kFactor = -+ kTileShapeContiguous * kElementsPerAccess / kCrosswise; -+ -+ static_assert( -+ (kFactor > 0), -+ "kCrosswise should be no large than one shared memory cache line."); -+ -+ /// The strided dimension needs to be at least (WarpSize(32) / -+ /// kTileShapeContiguous) for a warp to access. To ensure conflict free -+ /// access, it also needs to be at least (kTileShapeContiguous / kFactor). -+ /// See comments below -+ static int const kTileShapeStride = -+ ((kTileShapeContiguous / kFactor) > (32 / kTileShapeContiguous)) -+ ? (kTileShapeContiguous / kFactor) -+ : (32 / kTileShapeContiguous); -+ -+ /// Fundamental tile shape in units of vectors to guarantee bank conflict free -+ /// shared memory load/store. -+ /// For kFactor = 1, TileShape = <8, 8> -+ /// For kFactor > 1, TileShape = <8, 4> -+ using TileShape = PitchLinearShape; -+ -+ /// Fundamental partition shape in units of vectors -+ using PartitionShape = PitchLinearShape<4, 4>; -+ -+ using PartitionCount = -+ PitchLinearShape; -+ -+ using AccessCount = -+ PitchLinearShape; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Stride data member. For GEMM, it equals to kCrosswise x stage. -+ Stride stride_; -+ -+ public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicand(Index ldm = 0) : stride_(ldm) {} -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicand(Stride stride) : stride_(stride) {} -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static TensorOpMultiplicand packed(TensorCoord const &extent) { -+ return TensorOpMultiplicand(extent[0]); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ // -+ // First, compute c and s of vector within source (in units of vector -+ // accesses) -+ // -+ -+ int vec_contiguous_idx = coord.contiguous() / kElementsPerAccess; -+ int vec_strided_idx = coord.strided() / kFactor; -+ -+ // Compute the fundamental tile being accessed -+ int tile_contiguous_idx = -+ vec_contiguous_idx / (TileShape::kContiguous / kFactor); -+ -+ int tile_contiguous_residual = -+ vec_contiguous_idx % (TileShape::kContiguous / kFactor) + -+ ((coord.strided() % kFactor) * (TileShape::kContiguous / kFactor)); -+ int tile_strided_residual = vec_strided_idx % TileShape::kStrided; -+ -+ // Compute the 'partition' within the fundamental tile -+ int partition_contiguous_idx = -+ tile_contiguous_residual / PartitionShape::kContiguous; -+ int partition_strided_idx = -+ tile_strided_residual / PartitionShape::kStrided; -+ -+ int partition_contiguous_residual = -+ tile_contiguous_residual % PartitionShape::kContiguous; -+ int partition_strided_residual = -+ tile_strided_residual % PartitionShape::kStrided; -+ -+ // -+ // Then swizzle -+ // -+ -+ int permuted_vec_contiguous_within_partition = -+ partition_contiguous_residual ^ (partition_strided_residual % 4); -+ -+ int permuted_partition_contiguous_within_tile = -+ partition_contiguous_idx ^ (partition_strided_idx % 2); -+ -+ // -+ // Compute final element location -+ // -+ -+ int element_contiguous = (tile_contiguous_idx * TileShape::kContiguous + -+ permuted_partition_contiguous_within_tile * -+ PartitionShape::kContiguous + -+ permuted_vec_contiguous_within_partition) * -+ kElementsPerAccess + -+ (coord.contiguous() % kElementsPerAccess); -+ -+ int element_strided = vec_strided_idx; -+ -+ return element_contiguous + element_strided * stride_[0] * kFactor; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { return stride_; } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride &stride() { return stride_; } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with -+ /// the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return extent[1] * stride_[0]; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template based on element size (in bits) - defined in terms of pitch-linear -+/// memory and Crosswise size (in elements). -+template -+struct TensorOpMultiplicandCongruous { -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = PitchLinearCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = TensorOpMultiplicand; -+ -+ /// This layout is optimized for 128b accesses -+ static int const kAccessSize = Base::kAccessSize; -+ using TileShape = typename Base::TileShape; -+ using PartitionShape = typename Base::PartitionShape; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = Base::kElementSize; -+ static int const kElementsPerAccess = Base::kElementsPerAccess; -+ using PartitionCount = typename Base::PartitionCount; -+ using AccessCount = typename Base::AccessCount; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+ public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicandCongruous(Index ldm = 0) : layout_(ldm) {} -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicandCongruous(Stride stride) : layout_(stride) {} -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static TensorOpMultiplicandCongruous packed(TensorCoord const &extent) { -+ return TensorOpMultiplicandCongruous(extent[0]); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(coord); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex offset) const { -+ PitchLinearCoord coord = layout_.inverse(offset); -+ return coord; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { return layout_.stride(); } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride &stride() { return layout_.stride(); } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with -+ /// the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(extent); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template based on element size (in bits) - defined in terms of pitch-linear -+/// memory and Crosswise size (in elements). -+/// This one is just for TF32 NT kernel. -+template -+struct TensorOpMultiplicandCongruous<32, Crosswise> { -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = PitchLinearCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ /// This layout is optimized for 128b accesses -+ static int const kAccessSize = 128; -+ -+ /// Fundamental tile shape in units of vectors -+ using TileShape = PitchLinearShape<8, 4>; -+ -+ /// Partitionshape is the same as TileShape for this layout -+ using PartitionShape = PitchLinearShape<8, 4>; -+ -+ using PartitionCount = -+ PitchLinearShape; -+ -+ using AccessCount = -+ PitchLinearShape; -+ -+ // -+ // Static constants -+ // -+ static int const kElementSize = 32; -+ static int const kElementsPerAccess = kAccessSize / kElementSize; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Stride data member. -+ Stride stride_; -+ -+ public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicandCongruous(Index ldm = 0) : stride_(ldm) {} -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicandCongruous(Stride stride) : stride_(stride) {} -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static TensorOpMultiplicandCongruous packed(TensorCoord const &extent) { -+ return TensorOpMultiplicandCongruous(extent[0]); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ int tc = coord.contiguous() / 32; -+ int ts = coord.strided() / 4; -+ -+ int c = (coord.contiguous() % 32) / kElementsPerAccess; -+ int s = coord.strided() % 4; -+ -+ LongIndex offset = (c ^ (2 * s)) * kElementsPerAccess + s * stride_[0] + -+ tc * 32 + ts * stride_[0] * 4 + coord.contiguous() % 4; -+ -+ return offset; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { return stride_; } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride &stride() { return stride_; } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with -+ /// the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return extent[1] * stride_[0]; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template mapping a column-major view of pitch-linear memory to -+/// TensorOpMultiplicand -+template -+struct ColumnMajorTensorOpMultiplicandCongruous { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = TensorOpMultiplicandCongruous; -+ -+ /// This layout is optimized for 128b accesses -+ static int const kAccessSize = Base::kAccessSize; -+ using TileShape = typename Base::TileShape; -+ using PartitionShape = typename Base::PartitionShape; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = Base::kElementSize; -+ static int const kElementsPerAccess = Base::kElementsPerAccess; -+ using PartitionCount = typename Base::PartitionCount; -+ using AccessCount = typename Base::AccessCount; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorTensorOpMultiplicandCongruous(Index ldm = 0): layout_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorTensorOpMultiplicandCongruous(Stride stride): layout_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static ColumnMajorTensorOpMultiplicandCongruous packed(TensorCoord const &extent) { -+ return ColumnMajorTensorOpMultiplicandCongruous(extent.row()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(PitchLinearCoord(coord.row(), coord.column())); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex offset) const { -+ PitchLinearCoord coord = layout_.inverse(offset); -+ return MatrixCoord(coord.contiguous(), coord.strided()); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return layout_.stride(); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return layout_.stride(); -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(PitchLinearCoord(extent.row(), extent.column())); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template mapping a row-major view of pitch-linear memory to -+/// TensorOpMultiplicand -+template -+struct RowMajorTensorOpMultiplicandCongruous { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = TensorOpMultiplicandCongruous; -+ -+ /// This layout is optimized for 128b accesses -+ static int const kAccessSize = Base::kAccessSize; -+ using TileShape = typename Base::TileShape; -+ using PartitionShape = typename Base::PartitionShape; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = Base::kElementSize; -+ static int const kElementsPerAccess = Base::kElementsPerAccess; -+ using PartitionCount = typename Base::PartitionCount; -+ using AccessCount = typename Base::AccessCount; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorTensorOpMultiplicandCongruous(Index ldm = 0): layout_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorTensorOpMultiplicandCongruous(Stride stride): layout_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static RowMajorTensorOpMultiplicandCongruous packed(TensorCoord const &extent) { -+ return RowMajorTensorOpMultiplicandCongruous(extent.column()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(PitchLinearCoord(coord.column(), coord.row())); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex offset) const { -+ PitchLinearCoord coord = layout_.inverse(offset); -+ return MatrixCoord(coord.strided(), coord.contiguous()); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return layout_.stride(); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return layout_.stride(); -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(PitchLinearCoord(extent.column(), extent.row())); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template based on element size (in bits) - defined in terms of pitch-linear -+/// memory and Crosswise size (in elements). -+template -+struct TensorOpMultiplicandCrosswise { -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = PitchLinearCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = TensorOpMultiplicand; -+ -+ /// This layout is optimized for 128b accesses -+ static int const kAccessSize = Base::kAccessSize; -+ using TileShape = typename Base::TileShape; -+ using PartitionShape = typename Base::PartitionShape; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = Base::kElementSize; -+ static int const kElementsPerAccess = Base::kElementsPerAccess; -+ static int const kCrosswise = Base::kCrosswise; -+ static int const kFactor = Base::kFactor; -+ using PartitionCount = typename Base::PartitionCount; -+ using AccessCount = typename Base::AccessCount; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+ public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicandCrosswise(Index ldm = 0) : layout_(ldm) {} -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicandCrosswise(Stride stride) : layout_(stride) {} -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static TensorOpMultiplicandCrosswise packed(TensorCoord const &extent) { -+ return TensorOpMultiplicandCrosswise(extent[0]); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(coord); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex offset) const { -+ PitchLinearCoord coord = layout_.inverse(offset); -+ return coord; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { return layout_.stride(); } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride &stride() { return layout_.stride(); } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with -+ /// the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(extent); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template mapping a column-major view of pitch-linear memory to -+/// TensorOpMultiplicandCrosswise -+template -+struct ColumnMajorTensorOpMultiplicandCrosswise { -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = TensorOpMultiplicandCrosswise; -+ -+ /// This layout is optimized for 128b accesses -+ static int const kAccessSize = Base::kAccessSize; -+ using TileShape = typename Base::TileShape; -+ using PartitionShape = typename Base::PartitionShape; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = Base::kElementSize; -+ static int const kElementsPerAccess = Base::kElementsPerAccess; -+ using PartitionCount = typename Base::PartitionCount; -+ using AccessCount = typename Base::AccessCount; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+ public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorTensorOpMultiplicandCrosswise(Index ldm = 0) : layout_(ldm) {} -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorTensorOpMultiplicandCrosswise(Stride stride) : layout_(stride) {} -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static ColumnMajorTensorOpMultiplicandCrosswise packed( -+ TensorCoord const &extent) { -+ return ColumnMajorTensorOpMultiplicandCrosswise(extent.row()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(PitchLinearCoord(coord.row(), coord.column())); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex offset) const { -+ PitchLinearCoord coord = layout_.inverse(offset); -+ return MatrixCoord(coord.contiguous(), coord.strided()); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { return layout_.stride(); } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride &stride() { return layout_.stride(); } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with -+ /// the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(PitchLinearCoord(extent.row(), extent.column())); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template mapping a row-major view of pitch-linear memory to -+/// TensorOpMultiplicandCrosswise -+template -+struct RowMajorTensorOpMultiplicandCrosswise { -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = TensorOpMultiplicandCrosswise; -+ -+ /// This layout is optimized for 128b accesses -+ static int const kAccessSize = Base::kAccessSize; -+ using TileShape = typename Base::TileShape; -+ using PartitionShape = typename Base::PartitionShape; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = Base::kElementSize; -+ static int const kElementsPerAccess = Base::kElementsPerAccess; -+ using PartitionCount = typename Base::PartitionCount; -+ using AccessCount = typename Base::AccessCount; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+ public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorTensorOpMultiplicandCrosswise(Index ldm = 0) : layout_(ldm) {} -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorTensorOpMultiplicandCrosswise(Stride stride) : layout_(stride) {} -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static RowMajorTensorOpMultiplicandCrosswise packed( -+ TensorCoord const &extent) { -+ return RowMajorTensorOpMultiplicandCrosswise(extent.column()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(PitchLinearCoord(coord.column(), coord.row())); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex offset) const { -+ PitchLinearCoord coord = layout_.inverse(offset); -+ return MatrixCoord(coord.strided(), coord.contiguous()); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { return layout_.stride(); } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride &stride() { return layout_.stride(); } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with -+ /// the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(PitchLinearCoord(extent.column(), extent.row())); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template based on element size (in bits) - defined in terms of pitch-linear memory. -+template -+struct TensorOpMultiplicandColumnMajorInterleaved { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = PitchLinearCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ /// This layout is optimized for 128b accesses -+ static int const kAccessSize = 128; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = ElementSize; -+ static int const kElementsPerAccess = kAccessSize / kElementSize; -+ -+ //static int const kThreadBlockStrided = ThreadBlockStrided; -+ static int const kInterleavedK = InterleavedK; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Stride data member -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicandColumnMajorInterleaved(Index ldm = 0): stride_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicandColumnMajorInterleaved(Stride stride): stride_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static TensorOpMultiplicandColumnMajorInterleaved packed(TensorCoord const &extent) { -+ return TensorOpMultiplicandColumnMajorInterleaved(extent[0] * kInterleavedK); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ int const rows_per_smem_cache_line = 128 / kInterleavedK; -+ -+ int row_id = coord.strided() / rows_per_smem_cache_line; -+ int col_id = (coord.strided() % rows_per_smem_cache_line) * kInterleavedK + coord.contiguous(); -+ -+ int access_block_id = col_id >> 4; -+ int swizzle_access_block_id = access_block_id ^ (row_id & 1); -+ -+ int swizzle_col_id = swizzle_access_block_id << 4; -+ -+ return row_id * 128 + swizzle_col_id; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return (extent[1] / kInterleavedK) * stride_[0]; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template based on element size (in bits) - defined in terms of pitch-linear memory. -+template -+struct TensorOpMultiplicandRowMajorInterleaved { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = PitchLinearCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ /// This layout is optimized for 128b accesses -+ static int const kAccessSize = 128; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = ElementSize; -+ static int const kElementsPerAccess = kAccessSize / kElementSize; -+ -+ //static int const kThreadBlockStrided = ThreadBlockStrided; -+ static int const kInterleavedK = InterleavedK; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Stride data member -+ Stride stride_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicandRowMajorInterleaved(Index ldm = 0): stride_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicandRowMajorInterleaved(Stride stride): stride_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static TensorOpMultiplicandRowMajorInterleaved packed(TensorCoord const &extent) { -+ return TensorOpMultiplicandRowMajorInterleaved(extent[1] * kInterleavedK); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ int const rows_per_smem_cache_line = 128 / kInterleavedK; -+ -+ int row_id = coord.strided() / rows_per_smem_cache_line; -+ int col_id = (coord.strided() % rows_per_smem_cache_line) * kInterleavedK + coord.contiguous(); -+ -+ int access_block_id = col_id >> 4; -+ int swizzle_access_block_id = access_block_id ^ (row_id & 1); -+ -+ int swizzle_col_id = swizzle_access_block_id << 4; -+ -+ return row_id * 128 + swizzle_col_id; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return (extent[0] / kInterleavedK) * stride_[0]; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace layout -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm80.h b/3rdparty/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm80.h -new file mode 100644 -index 0000000..f75c2a8 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm80.h -@@ -0,0 +1,1139 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief layouts needed by Ampere fp64 tensor core kernels. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace layout { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template based on element size (in bits) - defined in terms of pitch-linear -+/// memory and Crosswise size (in elements). -+struct TensorOpMultiplicandCongruous64b { -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = PitchLinearCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = 64; -+ static int const kElementsPerAccess = 1; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Stride data member. -+ Stride stride_; -+ -+ public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicandCongruous64b(Index ldm = 0) : stride_(ldm) {} -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicandCongruous64b(Stride stride) : stride_(stride) {} -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static TensorOpMultiplicandCongruous64b packed(TensorCoord const &extent) { -+ return TensorOpMultiplicandCongruous64b(extent[0]); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ -+ int tc = coord.contiguous() / 16; -+ int ts = coord.strided() / 4; -+ -+ int c = coord.contiguous() % 16; -+ int s = coord.strided() % 4; -+ -+ -+ int bank = ((((c & 1) * 4 + (c & 6) / 2)) ^ (s & 1)) * 2 + (c / 8); -+ int row = (c & 6) / 2; -+ -+ bank ^= ((s & 2) * 2); -+ -+ LongIndex offset = tc * 16 + bank + (ts * 4 + row) * stride_[0]; -+ -+ return offset; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { return stride_; } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride &stride() { return stride_; } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with -+ /// the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return extent[1] * stride_[0]; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex offset) const { -+ return TensorCoord(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template mapping a column-major view of pitch-linear memory to -+/// TensorOpMultiplicand -+struct ColumnMajorTensorOpMultiplicandCongruous64b { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = TensorOpMultiplicandCongruous64b; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorTensorOpMultiplicandCongruous64b(Index ldm = 0): layout_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorTensorOpMultiplicandCongruous64b(Stride stride): layout_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static ColumnMajorTensorOpMultiplicandCongruous64b packed(TensorCoord const &extent) { -+ return ColumnMajorTensorOpMultiplicandCongruous64b(extent.row()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(PitchLinearCoord(coord.row(), coord.column())); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex offset) const { -+ PitchLinearCoord coord = layout_.inverse(offset); -+ return MatrixCoord(coord.contiguous(), coord.strided()); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return layout_.stride(); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return layout_.stride(); -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(PitchLinearCoord(extent.row(), extent.column())); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template mapping a row-major view of pitch-linear memory to -+/// TensorOpMultiplicand -+struct RowMajorTensorOpMultiplicandCongruous64b { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = TensorOpMultiplicandCongruous64b; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorTensorOpMultiplicandCongruous64b(Index ldm = 0): layout_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorTensorOpMultiplicandCongruous64b(Stride stride): layout_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static RowMajorTensorOpMultiplicandCongruous64b packed(TensorCoord const &extent) { -+ return RowMajorTensorOpMultiplicandCongruous64b(extent.column()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(PitchLinearCoord(coord.column(), coord.row())); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex offset) const { -+ PitchLinearCoord coord = layout_.inverse(offset); -+ return MatrixCoord(coord.strided(), coord.contiguous()); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return layout_.stride(); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return layout_.stride(); -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(PitchLinearCoord(extent.column(), extent.row())); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template based on element size (in bits) - defined in terms of pitch-linear -+/// memory and Crosswise size (in elements). -+struct TensorOpMultiplicand64bCrosswise { -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = PitchLinearCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = 64; -+ static int const kElementsPerAccess = 1; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Stride data member. -+ Stride stride_; -+ -+ public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicand64bCrosswise(Index ldm = 0) : stride_(ldm) {} -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicand64bCrosswise(Stride stride) : stride_(stride) {} -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static TensorOpMultiplicand64bCrosswise packed(TensorCoord const &extent) { -+ return TensorOpMultiplicand64bCrosswise(extent[0]); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ -+ int tc = coord.contiguous() / 16; -+ int ts = coord.strided() / 16; -+ -+ int c = coord.contiguous() % 16; -+ int s = coord.strided() % 16; -+ -+ int k_group = c / 4; -+ int access_s = s / 2; -+ -+ int row = access_s % 4; -+ int bank = ((k_group & 2) << 2) ^ ((s % 2) << 3) + (c % 4) * 2 + (access_s / 4) ^ (k_group & 1); -+ -+ int smem_row = (k_group * 4 + row) + tc * 16; -+ int smem_col = ts * 16 + bank; -+ -+ LongIndex offset = smem_row * stride_[0] + smem_col; -+ -+ return offset; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { return stride_; } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride &stride() { return stride_; } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with -+ /// the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return extent[1] * stride_[0]; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template based on element size (in bits) - defined in terms of pitch-linear -+/// memory and Crosswise size (in elements). -+struct ColumnMajorTensorOpMultiplicand64bCrosswise { -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = TensorOpMultiplicand64bCrosswise; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorTensorOpMultiplicand64bCrosswise(Index ldm = 0): layout_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorTensorOpMultiplicand64bCrosswise(Stride stride): layout_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static ColumnMajorTensorOpMultiplicand64bCrosswise packed(TensorCoord const &extent) { -+ return ColumnMajorTensorOpMultiplicand64bCrosswise(extent.column()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(PitchLinearCoord(coord.row(), coord.column())); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return layout_.stride(); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return layout_.stride(); -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(PitchLinearCoord(extent.row(), extent.column())); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template based on element size (in bits) - defined in terms of pitch-linear -+/// memory and Crosswise size (in elements). -+struct RowMajorTensorOpMultiplicand64bCrosswise { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = TensorOpMultiplicand64bCrosswise; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorTensorOpMultiplicand64bCrosswise(Index ldm = 0): layout_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorTensorOpMultiplicand64bCrosswise(Stride stride): layout_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static RowMajorTensorOpMultiplicand64bCrosswise packed(TensorCoord const &extent) { -+ return RowMajorTensorOpMultiplicand64bCrosswise(extent.row()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(PitchLinearCoord(coord.column(), coord.row())); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return layout_.stride(); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return layout_.stride(); -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(PitchLinearCoord(extent.column(), extent.row())); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template based on element size (in bits) - defined in terms of pitch-linear -+/// memory and Crosswise size (in elements). -+struct TensorOpMultiplicandCongruous128b { -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = PitchLinearCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = 128; -+ static int const kElementsPerAccess = 1; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Stride data member. -+ Stride stride_; -+ -+ public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicandCongruous128b(Index ldm = 0) : stride_(ldm) {} -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicandCongruous128b(Stride stride) : stride_(stride) {} -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static TensorOpMultiplicandCongruous128b packed(TensorCoord const &extent) { -+ return TensorOpMultiplicandCongruous128b(extent[0]); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ -+ Index tc = coord.contiguous() / 8; -+ Index ts = coord.strided() / 4; -+ -+ Index c = coord.contiguous() % 8; -+ Index s = coord.strided() % 4; -+ -+ Index k_index = (c / 2); -+ -+ Index bank = (((c & 1) * 4) | (s ^ k_index)); -+ -+ LongIndex offset = tc * 8 + bank + (ts * 4 + k_index) * stride_[0]; -+ -+ return offset; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { return stride_; } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride &stride() { return stride_; } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with -+ /// the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return extent[1] * stride_[0]; -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex offset) const { -+ return TensorCoord(); -+ } -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template mapping a column-major view of pitch-linear memory to -+/// TensorOpMultiplicand -+struct ColumnMajorTensorOpMultiplicandCongruous128b { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = TensorOpMultiplicandCongruous128b; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorTensorOpMultiplicandCongruous128b(Index ldm = 0): layout_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorTensorOpMultiplicandCongruous128b(Stride stride): layout_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static ColumnMajorTensorOpMultiplicandCongruous128b packed(TensorCoord const &extent) { -+ return ColumnMajorTensorOpMultiplicandCongruous128b(extent.row()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(PitchLinearCoord(coord.row(), coord.column())); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex offset) const { -+ PitchLinearCoord coord = layout_.inverse(offset); -+ return MatrixCoord(coord.contiguous(), coord.strided()); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return layout_.stride(); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return layout_.stride(); -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(PitchLinearCoord(extent.row(), extent.column())); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template mapping a row-major view of pitch-linear memory to -+/// TensorOpMultiplicand -+struct RowMajorTensorOpMultiplicandCongruous128b { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = TensorOpMultiplicandCongruous128b; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorTensorOpMultiplicandCongruous128b(Index ldm = 0): layout_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorTensorOpMultiplicandCongruous128b(Stride stride): layout_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static RowMajorTensorOpMultiplicandCongruous128b packed(TensorCoord const &extent) { -+ return RowMajorTensorOpMultiplicandCongruous128b(extent.column()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(PitchLinearCoord(coord.column(), coord.row())); -+ } -+ -+ /// Inverse of layout function, mapping linear offset to logical coordinate -+ CUTLASS_HOST_DEVICE -+ TensorCoord inverse(LongIndex offset) const { -+ PitchLinearCoord coord = layout_.inverse(offset); -+ return MatrixCoord(coord.strided(), coord.contiguous()); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return layout_.stride(); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return layout_.stride(); -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(PitchLinearCoord(extent.column(), extent.row())); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template based on element size (in bits) - defined in terms of pitch-linear -+/// memory and Crosswise size (in elements). -+struct TensorOpMultiplicandCrosswise128x4 { -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = PitchLinearCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Static constants -+ // -+ -+ static int const kElementSize = 128; -+ static int const kElementsPerAccess = 1; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Stride data member. -+ Stride stride_; -+ -+ public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicandCrosswise128x4(Index ldm = 0) : stride_(ldm) {} -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorOpMultiplicandCrosswise128x4(Stride stride) : stride_(stride) {} -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static TensorOpMultiplicandCrosswise128x4 packed(TensorCoord const &extent) { -+ return TensorOpMultiplicandCrosswise128x4(extent[0]); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ -+ Index tc = coord.contiguous() / 8; -+ Index ts = coord.strided() / 8; -+ -+ Index c = coord.contiguous() % 8; -+ Index s = coord.strided() % 8; -+ -+ Index liq = c % 4; -+ -+ Index bank = liq + ((s & 1) * 4) ^ (c & 4); -+ -+ Index k_index = (c & 4) + (s / 4) * 2 + ((s & 2) / 2); -+ -+ LongIndex offset = (tc * 8 + k_index) * stride_[0] + ts * 8 + bank; -+ -+ return offset; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { return stride_; } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride &stride() { return stride_; } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with -+ /// the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return extent[1] * stride_[0]; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template mapping a column-major view of pitch-linear memory to -+/// TensorOpMultiplicand -+struct ColumnMajorTensorOpMultiplicandCrosswise128x4 { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = TensorOpMultiplicandCrosswise128x4; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorTensorOpMultiplicandCrosswise128x4(Index ldm = 0): layout_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ ColumnMajorTensorOpMultiplicandCrosswise128x4(Stride stride): layout_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static ColumnMajorTensorOpMultiplicandCrosswise128x4 packed(TensorCoord const &extent) { -+ return ColumnMajorTensorOpMultiplicandCrosswise128x4(extent.column()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(PitchLinearCoord(coord.row(), coord.column())); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return layout_.stride(); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return layout_.stride(); -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(PitchLinearCoord(extent.row(), extent.column())); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Template mapping a row-major view of pitch-linear memory to -+/// TensorOpMultiplicand -+struct RowMajorTensorOpMultiplicandCrosswise128x4 { -+ -+ /// Logical rank of tensor -+ static int const kRank = 2; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = MatrixCoord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+ // -+ // Invariants -+ // -+ -+ using Base = TensorOpMultiplicandCrosswise128x4; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ Base layout_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorTensorOpMultiplicandCrosswise128x4(Index ldm = 0): layout_(ldm) { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ RowMajorTensorOpMultiplicandCrosswise128x4(Stride stride): layout_(stride) { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static RowMajorTensorOpMultiplicandCrosswise128x4 packed(TensorCoord const &extent) { -+ return RowMajorTensorOpMultiplicandCrosswise128x4(extent.row()); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory. -+ /// Assumes coordinate has convention (contiguous, strided) -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return layout_(PitchLinearCoord(coord.column(), coord.row())); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return layout_.stride(); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return layout_.stride(); -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &extent) const { -+ return layout_.capacity(PitchLinearCoord(extent.column(), extent.row())); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace layout -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/layout/vector.h b/3rdparty/cutlass/include/cutlass/layout/vector.h -new file mode 100644 -index 0000000..e9ad6da ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/layout/vector.h -@@ -0,0 +1,104 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines layout functions used for rank=1 vectors. -+*/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/coord.h" -+ -+namespace cutlass { -+namespace layout { -+ -+/// Tensor layout for densely packed vectors. -+class PackedVectorLayout { -+public: -+ /// Logical rank of tensor -+ static int const kRank = 1; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = 1; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = Coord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+private: -+ -+ // -+ // No actual stride vector stored -+ // -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ PackedVectorLayout() { } -+ -+ /// Helper returns a layout to a tightly packed tensor -+ CUTLASS_HOST_DEVICE -+ static PackedVectorLayout packed(TensorCoord const &size) { -+ return PackedVectorLayout(); -+ } -+ -+ /// Returns the offset of a coordinate in linear memory -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(TensorCoord const &coord) const { -+ return coord[0]; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return make_Coord(1); -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &size) const { -+ return size[0]; -+ } -+}; -+ -+} // namespace layout -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/matrix.h b/3rdparty/cutlass/include/cutlass/matrix.h -new file mode 100644 -index 0000000..ba9ffbb ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/matrix.h -@@ -0,0 +1,14129 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* -+ \file -+ \brief Matrix classes with value semantics. -+*/ -+ -+#pragma once -+ -+#if !defined(__CUDACC_RTC__) -+#include -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/layout/matrix.h" -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Primary template with partial specializations to follow -+template struct Matrix; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// 1-by-2 matrix template class definition -+template -+struct Matrix { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Element data type -+ using Element = Element_; -+ -+ /// Number of rows in matrix -+ static int const kRows = 1; -+ -+ /// Number of columns in matrix -+ static int const kColumns = 2; -+ -+ /// Layout of matrix in underlying array -+ using Layout = layout::RowMajor; -+ -+ /// Number of elements in matrix -+ static int const kCount = 2; -+ -+ // -+ // Data members -+ // -+ -+ /// Elements of the matrix in row-major layout -+ Array data; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a zero matrix -+ CUTLASS_HOST_DEVICE -+ Matrix() { -+ data.clear(); -+ } -+ -+ /// Copy constructor for a 1-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix(Matrix const &rhs) { -+ data = rhs.data; -+ } -+ -+ /// Constucts a 1-by-2 matrix from scalar elements -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Element _0_0, Element _0_1 -+ ) { -+ -+ data[0] = _0_0; data[1] = _0_1; -+ } -+ -+ /// Constructs a matrix from a uniform element -+ CUTLASS_HOST_DEVICE -+ static Matrix uniform(Element s) { -+ Matrix m; -+ -+ m.data[0] = s; -+ m.data[1] = s; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from a uniform element 1 -+ CUTLASS_HOST_DEVICE -+ static Matrix ones() { -+ return uniform(Element(1)); -+ } -+ -+ /// Constructs a matrix from a uniform element 0 -+ CUTLASS_HOST_DEVICE -+ static Matrix zero() { -+ return Matrix(); -+ } -+ -+ /// Returns a transposed matrix -+ CUTLASS_HOST_DEVICE -+ Matrix transpose() const { -+ Matrix mt; -+ -+ mt.data[0] = data[0]; -+ mt.data[1] = data[1]; -+ -+ return mt; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(int i, int j) const { -+ return data[i * 1 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(int i, int j) { -+ return data[i * 1 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element &at(int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element at(int offset) const { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element operator[](Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & operator[](Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element & operator[](int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element operator[](int offset) const { -+ return data[offset]; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 2 + j + 0]; -+ m.data[1] = data[i * 2 + j + 1]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 2 + j + 0] = m.data[0]; -+ data[i * 2 + j + 1] = m.data[1]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix row(int i) const { -+ return slice_1x2(i, 0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_row(Matrix const &v, int i = 0) { -+ return set_slice_1x2(v, i, 0); -+ } -+ -+ /// Forms a 1-by-2 matrix by horizontally concatenating an Element with an Element -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Element lhs, Element rhs) { -+ return Matrix( -+ lhs, rhs); -+ } -+ -+ /// Concatenates this matrix with a an Element to form a 1-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Element rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 1-by-2 matrix to form a 1-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Matrix const & rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 1-by-2 matrix to form a 2-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Matrix const & rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 2-by-2 matrix to form a 3-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Matrix const & rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 3-by-2 matrix to form a 4-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Matrix const & rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Elementwise add operator (1-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix add(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] + rhs.data[0]; -+ result.data[1] = data[1] + rhs.data[1]; -+ -+ return result; -+ } -+ -+ /// Elementwise add operator (1-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix operator +(Matrix const &rhs) const { -+ return add(rhs); -+ } -+ -+ /// Elementwise add operator (1-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator +=(Matrix const &rhs) { -+ -+ data[0] += rhs.data[0]; -+ data[1] += rhs.data[1]; -+ -+ return *this; -+ } -+ -+ /// Elementwise subtract operator (1-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix subtract(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] - rhs.data[0]; -+ result.data[1] = data[1] - rhs.data[1]; -+ -+ return result; -+ } -+ -+ /// Elementwise subtract operator (1-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix operator -(Matrix const &rhs) const { -+ return subtract(rhs); -+ } -+ -+ /// Elementwise subtract operator (1-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator -=(Matrix const &rhs) { -+ -+ data[0] -= rhs.data[0]; -+ data[1] -= rhs.data[1]; -+ -+ return *this; -+ } -+ -+ /// Elementwise multiply operator (1-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * rhs.data[0]; -+ result.data[1] = data[1] * rhs.data[1]; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (1-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * s; -+ result.data[1] = data[1] * s; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (1-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix operator *(Element const &s) const { -+ return multiply(s); -+ } -+ -+ /// Scalar multiply operator (1-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator *=(Element const &s) { -+ -+ data[0] *= s; -+ data[1] *= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (1-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / rhs.data[0]; -+ result.data[1] = data[1] / rhs.data[1]; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (1-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / s; -+ result.data[1] = data[1] / s; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (1-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Element const &s) const { -+ return divide(s); -+ } -+ -+ /// Scalar divide operator (1-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Element const &s) { -+ -+ data[0] /= s; -+ data[1] /= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (1-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Matrix const &rhs) const { -+ return divide(rhs); -+ } -+ -+ /// Elementwise divide operator (1-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Matrix const &rhs) { -+ -+ data[0] /= rhs.data[0]; -+ data[1] /= rhs.data[1]; -+ -+ return *this; -+ } -+ -+ /// Negates each element of the matrix -+ CUTLASS_HOST_DEVICE -+ Matrix operator-() const { -+ Matrix m; -+ -+ m.data[0] = -m.data[0]; -+ m.data[1] = -m.data[1]; -+ -+ return m; -+ } -+ -+ /// Matrix product of size 1-by-1-by-2 -+ CUTLASS_HOST_DEVICE -+ Element product(Matrix const &rhs, Element accum = Element()) const { -+ -+ // k=0 -+ accum += data[0] * rhs.data[0]; -+ -+ // k=1 -+ accum += data[1] * rhs.data[1]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 1-by-1-by-2 -+ CUTLASS_HOST_DEVICE -+ Element operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 1-by-2-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[2]; -+ accum.data[1] += data[1] * rhs.data[3]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 1-by-2-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 1-by-2-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix & operator*=(Matrix const &rhs) { -+ *this = product(rhs); -+ return *this; -+ } -+ -+ /// Matrix product of size 1-by-3-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[3]; -+ accum.data[1] += data[1] * rhs.data[4]; -+ accum.data[2] += data[1] * rhs.data[5]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 1-by-3-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 1-by-4-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[0] * rhs.data[3]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[4]; -+ accum.data[1] += data[1] * rhs.data[5]; -+ accum.data[2] += data[1] * rhs.data[6]; -+ accum.data[3] += data[1] * rhs.data[7]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 1-by-4-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Dot product of vectors with extent 2 -+ CUTLASS_HOST_DEVICE -+ Element dot(Matrix const &rhs, Element accum = Element()) const { -+ -+ accum += data[0] * rhs.data[0]; -+ accum += data[1] * rhs.data[1]; -+ return accum; -+ } -+ -+ /// Dot product of vectors with extent 2 -+ CUTLASS_HOST_DEVICE -+ Element dot(Matrix const &rhs, Element accum = Element()) const { -+ -+ accum += data[0] * rhs.data[0]; -+ accum += data[1] * rhs.data[1]; -+ return accum; -+ } -+ -+ /// Returns the sum of elements -+ CUTLASS_HOST_DEVICE -+ Element sum(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[1]; -+ -+ return accum; -+ } -+ -+ /// Returns the sum of squared elements -+ CUTLASS_HOST_DEVICE -+ Element norm(Element accum = Element()) const { -+ -+ accum += data[0] * data[0]; -+ accum += data[1] * data[1]; -+ -+ return accum; -+ } -+ -+ /// Returns square root of the norm -+ CUTLASS_HOST_DEVICE -+ Element magnitude() const { -+ return fast_sqrt(norm()); -+ } -+ -+ /// Returns the sum of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Element trace(Element accum = Element()) const { -+ -+ accum += data[0]; -+ -+ return accum; -+ } -+ -+}; -+ -+/// Template alias for 1-by-2 matrix -+template -+using Matrix1x2 = Matrix; -+ -+ -+/// Free funciton to infer element type from template arguments -+template -+CUTLASS_HOST_DEVICE Matrix1x2 make_Matrix1x2( -+ Element _0_0, Element _0_1 -+) { -+ return Matrix1x2( -+ _0_0, _0_1 -+ ); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// 1-by-3 matrix template class definition -+template -+struct Matrix { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Element data type -+ using Element = Element_; -+ -+ /// Number of rows in matrix -+ static int const kRows = 1; -+ -+ /// Number of columns in matrix -+ static int const kColumns = 3; -+ -+ /// Layout of matrix in underlying array -+ using Layout = layout::RowMajor; -+ -+ /// Number of elements in matrix -+ static int const kCount = 3; -+ -+ // -+ // Data members -+ // -+ -+ /// Elements of the matrix in row-major layout -+ Array data; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a zero matrix -+ CUTLASS_HOST_DEVICE -+ Matrix() { -+ data.clear(); -+ } -+ -+ /// Copy constructor for a 1-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix(Matrix const &rhs) { -+ data = rhs.data; -+ } -+ -+ /// Constucts a 1-by-3 matrix from scalar elements -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Element _0_0, Element _0_1, Element _0_2 -+ ) { -+ -+ data[0] = _0_0; data[1] = _0_1; data[2] = _0_2; -+ } -+ -+ /// Constructs a matrix from a uniform element -+ CUTLASS_HOST_DEVICE -+ static Matrix uniform(Element s) { -+ Matrix m; -+ -+ m.data[0] = s; -+ m.data[1] = s; -+ m.data[2] = s; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from a uniform element 1 -+ CUTLASS_HOST_DEVICE -+ static Matrix ones() { -+ return uniform(Element(1)); -+ } -+ -+ /// Constructs a matrix from a uniform element 0 -+ CUTLASS_HOST_DEVICE -+ static Matrix zero() { -+ return Matrix(); -+ } -+ -+ /// Returns a transposed matrix -+ CUTLASS_HOST_DEVICE -+ Matrix transpose() const { -+ Matrix mt; -+ -+ mt.data[0] = data[0]; -+ mt.data[1] = data[1]; -+ mt.data[2] = data[2]; -+ -+ return mt; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(int i, int j) const { -+ return data[i * 1 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(int i, int j) { -+ return data[i * 1 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element &at(int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element at(int offset) const { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element operator[](Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & operator[](Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element & operator[](int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element operator[](int offset) const { -+ return data[offset]; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ m.data[2] = data[i * 3 + j + 2]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ data[i * 3 + j + 2] = m.data[2]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix row(int i) const { -+ return slice_1x3(i, 0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_row(Matrix const &v, int i = 0) { -+ return set_slice_1x3(v, i, 0); -+ } -+ -+ /// Forms a 1-by-3 matrix by horizontally concatenating an Element with a 1-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Element lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs, rhs.at(0, 0), rhs.at(0, 1)); -+ } -+ -+ /// Forms a 1-by-3 matrix by horizontally concatenating a 1-by-2 matrix with an Element -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Element rhs) { -+ return Matrix( -+ lhs.at(0, 0), lhs.at(0, 1), rhs); -+ } -+ -+ /// Concatenates this matrix with a an Element to form a 1-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Element rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 1-by-3 matrix to form a 2-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Matrix const & rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 2-by-3 matrix to form a 3-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Matrix const & rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 3-by-3 matrix to form a 4-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Matrix const & rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Elementwise add operator (1-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix add(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] + rhs.data[0]; -+ result.data[1] = data[1] + rhs.data[1]; -+ result.data[2] = data[2] + rhs.data[2]; -+ -+ return result; -+ } -+ -+ /// Elementwise add operator (1-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator +(Matrix const &rhs) const { -+ return add(rhs); -+ } -+ -+ /// Elementwise add operator (1-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator +=(Matrix const &rhs) { -+ -+ data[0] += rhs.data[0]; -+ data[1] += rhs.data[1]; -+ data[2] += rhs.data[2]; -+ -+ return *this; -+ } -+ -+ /// Elementwise subtract operator (1-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix subtract(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] - rhs.data[0]; -+ result.data[1] = data[1] - rhs.data[1]; -+ result.data[2] = data[2] - rhs.data[2]; -+ -+ return result; -+ } -+ -+ /// Elementwise subtract operator (1-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator -(Matrix const &rhs) const { -+ return subtract(rhs); -+ } -+ -+ /// Elementwise subtract operator (1-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator -=(Matrix const &rhs) { -+ -+ data[0] -= rhs.data[0]; -+ data[1] -= rhs.data[1]; -+ data[2] -= rhs.data[2]; -+ -+ return *this; -+ } -+ -+ /// Elementwise multiply operator (1-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * rhs.data[0]; -+ result.data[1] = data[1] * rhs.data[1]; -+ result.data[2] = data[2] * rhs.data[2]; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (1-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * s; -+ result.data[1] = data[1] * s; -+ result.data[2] = data[2] * s; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (1-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator *(Element const &s) const { -+ return multiply(s); -+ } -+ -+ /// Scalar multiply operator (1-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator *=(Element const &s) { -+ -+ data[0] *= s; -+ data[1] *= s; -+ data[2] *= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (1-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / rhs.data[0]; -+ result.data[1] = data[1] / rhs.data[1]; -+ result.data[2] = data[2] / rhs.data[2]; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (1-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / s; -+ result.data[1] = data[1] / s; -+ result.data[2] = data[2] / s; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (1-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Element const &s) const { -+ return divide(s); -+ } -+ -+ /// Scalar divide operator (1-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Element const &s) { -+ -+ data[0] /= s; -+ data[1] /= s; -+ data[2] /= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (1-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Matrix const &rhs) const { -+ return divide(rhs); -+ } -+ -+ /// Elementwise divide operator (1-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Matrix const &rhs) { -+ -+ data[0] /= rhs.data[0]; -+ data[1] /= rhs.data[1]; -+ data[2] /= rhs.data[2]; -+ -+ return *this; -+ } -+ -+ /// Negates each element of the matrix -+ CUTLASS_HOST_DEVICE -+ Matrix operator-() const { -+ Matrix m; -+ -+ m.data[0] = -m.data[0]; -+ m.data[1] = -m.data[1]; -+ m.data[2] = -m.data[2]; -+ -+ return m; -+ } -+ -+ /// Matrix product of size 1-by-1-by-3 -+ CUTLASS_HOST_DEVICE -+ Element product(Matrix const &rhs, Element accum = Element()) const { -+ -+ // k=0 -+ accum += data[0] * rhs.data[0]; -+ -+ // k=1 -+ accum += data[1] * rhs.data[1]; -+ -+ // k=2 -+ accum += data[2] * rhs.data[2]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 1-by-1-by-3 -+ CUTLASS_HOST_DEVICE -+ Element operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 1-by-2-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[2]; -+ accum.data[1] += data[1] * rhs.data[3]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[4]; -+ accum.data[1] += data[2] * rhs.data[5]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 1-by-2-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 1-by-3-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[3]; -+ accum.data[1] += data[1] * rhs.data[4]; -+ accum.data[2] += data[1] * rhs.data[5]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[6]; -+ accum.data[1] += data[2] * rhs.data[7]; -+ accum.data[2] += data[2] * rhs.data[8]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 1-by-3-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 1-by-3-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix & operator*=(Matrix const &rhs) { -+ *this = product(rhs); -+ return *this; -+ } -+ -+ /// Matrix product of size 1-by-4-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[0] * rhs.data[3]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[4]; -+ accum.data[1] += data[1] * rhs.data[5]; -+ accum.data[2] += data[1] * rhs.data[6]; -+ accum.data[3] += data[1] * rhs.data[7]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[8]; -+ accum.data[1] += data[2] * rhs.data[9]; -+ accum.data[2] += data[2] * rhs.data[10]; -+ accum.data[3] += data[2] * rhs.data[11]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 1-by-4-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Dot product of vectors with extent 3 -+ CUTLASS_HOST_DEVICE -+ Element dot(Matrix const &rhs, Element accum = Element()) const { -+ -+ accum += data[0] * rhs.data[0]; -+ accum += data[1] * rhs.data[1]; -+ accum += data[2] * rhs.data[2]; -+ return accum; -+ } -+ -+ /// Dot product of vectors with extent 3 -+ CUTLASS_HOST_DEVICE -+ Element dot(Matrix const &rhs, Element accum = Element()) const { -+ -+ accum += data[0] * rhs.data[0]; -+ accum += data[1] * rhs.data[1]; -+ accum += data[2] * rhs.data[2]; -+ return accum; -+ } -+ -+ /// Returns the sum of elements -+ CUTLASS_HOST_DEVICE -+ Element sum(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[1]; -+ accum += data[2]; -+ -+ return accum; -+ } -+ -+ /// Returns the sum of squared elements -+ CUTLASS_HOST_DEVICE -+ Element norm(Element accum = Element()) const { -+ -+ accum += data[0] * data[0]; -+ accum += data[1] * data[1]; -+ accum += data[2] * data[2]; -+ -+ return accum; -+ } -+ -+ /// Returns square root of the norm -+ CUTLASS_HOST_DEVICE -+ Element magnitude() const { -+ return fast_sqrt(norm()); -+ } -+ -+ /// Returns the sum of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Element trace(Element accum = Element()) const { -+ -+ accum += data[0]; -+ -+ return accum; -+ } -+ -+ /// Cross product -+ CUTLASS_HOST_DEVICE -+ Matrix cross(Matrix const &rhs) const { -+ return Matrix( -+ data[1] * rhs.data[2] - data[2] * rhs.data[1], -+ data[0] * rhs.data[2] - data[2] * rhs.data[1], -+ data[0] * rhs.data[1] - data[1] * rhs.data[0] -+ ); -+ } -+ -+}; -+ -+/// Template alias for 1-by-3 matrix -+template -+using Matrix1x3 = Matrix; -+ -+ -+/// Free funciton to infer element type from template arguments -+template -+CUTLASS_HOST_DEVICE Matrix1x3 make_Matrix1x3( -+ Element _0_0, Element _0_1, Element _0_2 -+) { -+ return Matrix1x3( -+ _0_0, _0_1, _0_2 -+ ); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// 1-by-4 matrix template class definition -+template -+struct Matrix { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Element data type -+ using Element = Element_; -+ -+ /// Number of rows in matrix -+ static int const kRows = 1; -+ -+ /// Number of columns in matrix -+ static int const kColumns = 4; -+ -+ /// Layout of matrix in underlying array -+ using Layout = layout::RowMajor; -+ -+ /// Number of elements in matrix -+ static int const kCount = 4; -+ -+ // -+ // Data members -+ // -+ -+ /// Elements of the matrix in row-major layout -+ Array data; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a zero matrix -+ CUTLASS_HOST_DEVICE -+ Matrix() { -+ data.clear(); -+ } -+ -+ /// Copy constructor for a 1-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix(Matrix const &rhs) { -+ data = rhs.data; -+ } -+ -+ /// Constucts a 1-by-4 matrix from scalar elements -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Element _0_0, Element _0_1, Element _0_2, Element _0_3 -+ ) { -+ -+ data[0] = _0_0; data[1] = _0_1; data[2] = _0_2; data[3] = _0_3; -+ } -+ -+ /// Constructs a matrix from a uniform element -+ CUTLASS_HOST_DEVICE -+ static Matrix uniform(Element s) { -+ Matrix m; -+ -+ m.data[0] = s; -+ m.data[1] = s; -+ m.data[2] = s; -+ m.data[3] = s; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from a uniform element 1 -+ CUTLASS_HOST_DEVICE -+ static Matrix ones() { -+ return uniform(Element(1)); -+ } -+ -+ /// Constructs a matrix from a uniform element 0 -+ CUTLASS_HOST_DEVICE -+ static Matrix zero() { -+ return Matrix(); -+ } -+ -+ /// Returns a transposed matrix -+ CUTLASS_HOST_DEVICE -+ Matrix transpose() const { -+ Matrix mt; -+ -+ mt.data[0] = data[0]; -+ mt.data[1] = data[1]; -+ mt.data[2] = data[2]; -+ mt.data[3] = data[3]; -+ -+ return mt; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(int i, int j) const { -+ return data[i * 1 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(int i, int j) { -+ return data[i * 1 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element &at(int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element at(int offset) const { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element operator[](Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & operator[](Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element & operator[](int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element operator[](int offset) const { -+ return data[offset]; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x4(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ m.data[3] = data[i * 4 + j + 3]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x4(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ data[i * 4 + j + 3] = m.data[3]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix row(int i) const { -+ return slice_1x4(i, 0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_row(Matrix const &v, int i = 0) { -+ return set_slice_1x4(v, i, 0); -+ } -+ -+ /// Forms a 1-by-4 matrix by horizontally concatenating an Element with a 1-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Element lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs, rhs.at(0, 0), rhs.at(0, 1), rhs.at(0, 2)); -+ } -+ -+ /// Forms a 1-by-4 matrix by horizontally concatenating a 1-by-2 matrix with a 1-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs.at(0, 0), lhs.at(0, 1), rhs.at(0, 0), rhs.at(0, 1)); -+ } -+ -+ /// Forms a 1-by-4 matrix by horizontally concatenating a 1-by-3 matrix with an Element -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Element rhs) { -+ return Matrix( -+ lhs.at(0, 0), lhs.at(0, 1), lhs.at(0, 2), rhs); -+ } -+ -+ /// Concatenates this matrix with a a 1-by-4 matrix to form a 2-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Matrix const & rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 2-by-4 matrix to form a 3-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Matrix const & rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 3-by-4 matrix to form a 4-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Matrix const & rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Elementwise add operator (1-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix add(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] + rhs.data[0]; -+ result.data[1] = data[1] + rhs.data[1]; -+ result.data[2] = data[2] + rhs.data[2]; -+ result.data[3] = data[3] + rhs.data[3]; -+ -+ return result; -+ } -+ -+ /// Elementwise add operator (1-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator +(Matrix const &rhs) const { -+ return add(rhs); -+ } -+ -+ /// Elementwise add operator (1-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator +=(Matrix const &rhs) { -+ -+ data[0] += rhs.data[0]; -+ data[1] += rhs.data[1]; -+ data[2] += rhs.data[2]; -+ data[3] += rhs.data[3]; -+ -+ return *this; -+ } -+ -+ /// Elementwise subtract operator (1-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix subtract(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] - rhs.data[0]; -+ result.data[1] = data[1] - rhs.data[1]; -+ result.data[2] = data[2] - rhs.data[2]; -+ result.data[3] = data[3] - rhs.data[3]; -+ -+ return result; -+ } -+ -+ /// Elementwise subtract operator (1-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator -(Matrix const &rhs) const { -+ return subtract(rhs); -+ } -+ -+ /// Elementwise subtract operator (1-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator -=(Matrix const &rhs) { -+ -+ data[0] -= rhs.data[0]; -+ data[1] -= rhs.data[1]; -+ data[2] -= rhs.data[2]; -+ data[3] -= rhs.data[3]; -+ -+ return *this; -+ } -+ -+ /// Elementwise multiply operator (1-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * rhs.data[0]; -+ result.data[1] = data[1] * rhs.data[1]; -+ result.data[2] = data[2] * rhs.data[2]; -+ result.data[3] = data[3] * rhs.data[3]; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (1-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * s; -+ result.data[1] = data[1] * s; -+ result.data[2] = data[2] * s; -+ result.data[3] = data[3] * s; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (1-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator *(Element const &s) const { -+ return multiply(s); -+ } -+ -+ /// Scalar multiply operator (1-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator *=(Element const &s) { -+ -+ data[0] *= s; -+ data[1] *= s; -+ data[2] *= s; -+ data[3] *= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (1-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / rhs.data[0]; -+ result.data[1] = data[1] / rhs.data[1]; -+ result.data[2] = data[2] / rhs.data[2]; -+ result.data[3] = data[3] / rhs.data[3]; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (1-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / s; -+ result.data[1] = data[1] / s; -+ result.data[2] = data[2] / s; -+ result.data[3] = data[3] / s; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (1-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Element const &s) const { -+ return divide(s); -+ } -+ -+ /// Scalar divide operator (1-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Element const &s) { -+ -+ data[0] /= s; -+ data[1] /= s; -+ data[2] /= s; -+ data[3] /= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (1-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Matrix const &rhs) const { -+ return divide(rhs); -+ } -+ -+ /// Elementwise divide operator (1-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Matrix const &rhs) { -+ -+ data[0] /= rhs.data[0]; -+ data[1] /= rhs.data[1]; -+ data[2] /= rhs.data[2]; -+ data[3] /= rhs.data[3]; -+ -+ return *this; -+ } -+ -+ /// Negates each element of the matrix -+ CUTLASS_HOST_DEVICE -+ Matrix operator-() const { -+ Matrix m; -+ -+ m.data[0] = -m.data[0]; -+ m.data[1] = -m.data[1]; -+ m.data[2] = -m.data[2]; -+ m.data[3] = -m.data[3]; -+ -+ return m; -+ } -+ -+ /// Matrix product of size 1-by-1-by-4 -+ CUTLASS_HOST_DEVICE -+ Element product(Matrix const &rhs, Element accum = Element()) const { -+ -+ // k=0 -+ accum += data[0] * rhs.data[0]; -+ -+ // k=1 -+ accum += data[1] * rhs.data[1]; -+ -+ // k=2 -+ accum += data[2] * rhs.data[2]; -+ -+ // k=3 -+ accum += data[3] * rhs.data[3]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 1-by-1-by-4 -+ CUTLASS_HOST_DEVICE -+ Element operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 1-by-2-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[2]; -+ accum.data[1] += data[1] * rhs.data[3]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[4]; -+ accum.data[1] += data[2] * rhs.data[5]; -+ -+ // k=3 -+ accum.data[0] += data[3] * rhs.data[6]; -+ accum.data[1] += data[3] * rhs.data[7]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 1-by-2-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 1-by-3-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[3]; -+ accum.data[1] += data[1] * rhs.data[4]; -+ accum.data[2] += data[1] * rhs.data[5]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[6]; -+ accum.data[1] += data[2] * rhs.data[7]; -+ accum.data[2] += data[2] * rhs.data[8]; -+ -+ // k=3 -+ accum.data[0] += data[3] * rhs.data[9]; -+ accum.data[1] += data[3] * rhs.data[10]; -+ accum.data[2] += data[3] * rhs.data[11]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 1-by-3-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 1-by-4-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[0] * rhs.data[3]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[4]; -+ accum.data[1] += data[1] * rhs.data[5]; -+ accum.data[2] += data[1] * rhs.data[6]; -+ accum.data[3] += data[1] * rhs.data[7]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[8]; -+ accum.data[1] += data[2] * rhs.data[9]; -+ accum.data[2] += data[2] * rhs.data[10]; -+ accum.data[3] += data[2] * rhs.data[11]; -+ -+ // k=3 -+ accum.data[0] += data[3] * rhs.data[12]; -+ accum.data[1] += data[3] * rhs.data[13]; -+ accum.data[2] += data[3] * rhs.data[14]; -+ accum.data[3] += data[3] * rhs.data[15]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 1-by-4-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 1-by-4-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix & operator*=(Matrix const &rhs) { -+ *this = product(rhs); -+ return *this; -+ } -+ -+ /// Dot product of vectors with extent 4 -+ CUTLASS_HOST_DEVICE -+ Element dot(Matrix const &rhs, Element accum = Element()) const { -+ -+ accum += data[0] * rhs.data[0]; -+ accum += data[1] * rhs.data[1]; -+ accum += data[2] * rhs.data[2]; -+ accum += data[3] * rhs.data[3]; -+ return accum; -+ } -+ -+ /// Dot product of vectors with extent 4 -+ CUTLASS_HOST_DEVICE -+ Element dot(Matrix const &rhs, Element accum = Element()) const { -+ -+ accum += data[0] * rhs.data[0]; -+ accum += data[1] * rhs.data[1]; -+ accum += data[2] * rhs.data[2]; -+ accum += data[3] * rhs.data[3]; -+ return accum; -+ } -+ -+ /// Returns the sum of elements -+ CUTLASS_HOST_DEVICE -+ Element sum(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[1]; -+ accum += data[2]; -+ accum += data[3]; -+ -+ return accum; -+ } -+ -+ /// Returns the sum of squared elements -+ CUTLASS_HOST_DEVICE -+ Element norm(Element accum = Element()) const { -+ -+ accum += data[0] * data[0]; -+ accum += data[1] * data[1]; -+ accum += data[2] * data[2]; -+ accum += data[3] * data[3]; -+ -+ return accum; -+ } -+ -+ /// Returns square root of the norm -+ CUTLASS_HOST_DEVICE -+ Element magnitude() const { -+ return fast_sqrt(norm()); -+ } -+ -+ /// Returns the sum of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Element trace(Element accum = Element()) const { -+ -+ accum += data[0]; -+ -+ return accum; -+ } -+ -+}; -+ -+/// Template alias for 1-by-4 matrix -+template -+using Matrix1x4 = Matrix; -+ -+ -+/// Free funciton to infer element type from template arguments -+template -+CUTLASS_HOST_DEVICE Matrix1x4 make_Matrix1x4( -+ Element _0_0, Element _0_1, Element _0_2, Element _0_3 -+) { -+ return Matrix1x4( -+ _0_0, _0_1, _0_2, _0_3 -+ ); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// 2-by-1 matrix template class definition -+template -+struct Matrix { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Element data type -+ using Element = Element_; -+ -+ /// Number of rows in matrix -+ static int const kRows = 2; -+ -+ /// Number of columns in matrix -+ static int const kColumns = 1; -+ -+ /// Layout of matrix in underlying array -+ using Layout = layout::RowMajor; -+ -+ /// Number of elements in matrix -+ static int const kCount = 2; -+ -+ // -+ // Data members -+ // -+ -+ /// Elements of the matrix in row-major layout -+ Array data; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a zero matrix -+ CUTLASS_HOST_DEVICE -+ Matrix() { -+ data.clear(); -+ } -+ -+ /// Copy constructor for a 2-by-1 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix(Matrix const &rhs) { -+ data = rhs.data; -+ } -+ -+ /// Constucts a 2-by-1 matrix from scalar elements -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Element _0_0, -+ Element _1_0 -+ ) { -+ -+ data[0] = _0_0; -+ data[1] = _1_0; -+ } -+ -+ /// Constructs a matrix from a uniform element -+ CUTLASS_HOST_DEVICE -+ static Matrix uniform(Element s) { -+ Matrix m; -+ -+ m.data[0] = s; -+ m.data[1] = s; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from a uniform element 1 -+ CUTLASS_HOST_DEVICE -+ static Matrix ones() { -+ return uniform(Element(1)); -+ } -+ -+ /// Constructs a matrix from a uniform element 0 -+ CUTLASS_HOST_DEVICE -+ static Matrix zero() { -+ return Matrix(); -+ } -+ -+ /// Returns a transposed matrix -+ CUTLASS_HOST_DEVICE -+ Matrix transpose() const { -+ Matrix mt; -+ -+ mt.data[0] = data[0]; -+ mt.data[1] = data[1]; -+ -+ return mt; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(int i, int j) const { -+ return data[i * 2 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(int i, int j) { -+ return data[i * 2 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element &at(int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element at(int offset) const { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element operator[](Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & operator[](Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element & operator[](int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element operator[](int offset) const { -+ return data[offset]; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 1 + j + 0]; -+ m.data[1] = data[i * 1 + j + 1]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 1 + j + 0] = m.data[0]; -+ data[i * 1 + j + 1] = m.data[1]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix column(int j) const { -+ return slice_2x1(0, j); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_column(Matrix const &v, int j =0) { -+ return set_slice_2x1(v, 0, j); -+ } -+ -+ /// Concatenates this matrix with a a 2-by-1 matrix to form a 2-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Matrix const & rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 2-by-2 matrix to form a 2-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Matrix const & rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 2-by-3 matrix to form a 2-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Matrix const & rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Forms a 2-by-1 matrix by vertically concatenating an Element with an Element -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Element upper, Element lower) { -+ return Matrix( -+ upper -+ , lower); -+ } -+ -+ /// Concatenates this matrix with a an Element to form a 3-by-1 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Element rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 2-by-1 matrix to form a 4-by-1 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Matrix const & rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Elementwise add operator (2-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix add(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] + rhs.data[0]; -+ -+ result.data[1] = data[1] + rhs.data[1]; -+ -+ return result; -+ } -+ -+ /// Elementwise add operator (2-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix operator +(Matrix const &rhs) const { -+ return add(rhs); -+ } -+ -+ /// Elementwise add operator (2-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator +=(Matrix const &rhs) { -+ -+ data[0] += rhs.data[0]; -+ -+ data[1] += rhs.data[1]; -+ -+ return *this; -+ } -+ -+ /// Elementwise subtract operator (2-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix subtract(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] - rhs.data[0]; -+ -+ result.data[1] = data[1] - rhs.data[1]; -+ -+ return result; -+ } -+ -+ /// Elementwise subtract operator (2-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix operator -(Matrix const &rhs) const { -+ return subtract(rhs); -+ } -+ -+ /// Elementwise subtract operator (2-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator -=(Matrix const &rhs) { -+ -+ data[0] -= rhs.data[0]; -+ -+ data[1] -= rhs.data[1]; -+ -+ return *this; -+ } -+ -+ /// Elementwise multiply operator (2-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * rhs.data[0]; -+ -+ result.data[1] = data[1] * rhs.data[1]; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (2-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * s; -+ -+ result.data[1] = data[1] * s; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (2-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix operator *(Element const &s) const { -+ return multiply(s); -+ } -+ -+ /// Scalar multiply operator (2-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator *=(Element const &s) { -+ -+ data[0] *= s; -+ -+ data[1] *= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (2-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / rhs.data[0]; -+ -+ result.data[1] = data[1] / rhs.data[1]; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (2-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / s; -+ -+ result.data[1] = data[1] / s; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (2-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Element const &s) const { -+ return divide(s); -+ } -+ -+ /// Scalar divide operator (2-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Element const &s) { -+ -+ data[0] /= s; -+ -+ data[1] /= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (2-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Matrix const &rhs) const { -+ return divide(rhs); -+ } -+ -+ /// Elementwise divide operator (2-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Matrix const &rhs) { -+ -+ data[0] /= rhs.data[0]; -+ -+ data[1] /= rhs.data[1]; -+ -+ return *this; -+ } -+ -+ /// Negates each element of the matrix -+ CUTLASS_HOST_DEVICE -+ Matrix operator-() const { -+ Matrix m; -+ -+ m.data[0] = -m.data[0]; -+ m.data[1] = -m.data[1]; -+ -+ return m; -+ } -+ -+ /// Matrix product of size 2-by-1-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[1] * rhs.data[0]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 2-by-1-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 2-by-1-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix & operator*=(Matrix const &rhs) { -+ *this = product(rhs); -+ return *this; -+ } -+ -+ /// Matrix product of size 2-by-2-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[1] * rhs.data[0]; -+ accum.data[3] += data[1] * rhs.data[1]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 2-by-2-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 2-by-3-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[1] * rhs.data[0]; -+ accum.data[4] += data[1] * rhs.data[1]; -+ accum.data[5] += data[1] * rhs.data[2]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 2-by-3-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 2-by-4-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[0] * rhs.data[3]; -+ accum.data[4] += data[1] * rhs.data[0]; -+ accum.data[5] += data[1] * rhs.data[1]; -+ accum.data[6] += data[1] * rhs.data[2]; -+ accum.data[7] += data[1] * rhs.data[3]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 2-by-4-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Dot product of vectors with extent 2 -+ CUTLASS_HOST_DEVICE -+ Element dot(Matrix const &rhs, Element accum = Element()) const { -+ -+ accum += data[0] * rhs.data[0]; -+ accum += data[1] * rhs.data[1]; -+ return accum; -+ } -+ -+ /// Dot product of vectors with extent 2 -+ CUTLASS_HOST_DEVICE -+ Element dot(Matrix const &rhs, Element accum = Element()) const { -+ -+ accum += data[0] * rhs.data[0]; -+ accum += data[1] * rhs.data[1]; -+ return accum; -+ } -+ -+ /// Returns the sum of elements -+ CUTLASS_HOST_DEVICE -+ Element sum(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[1]; -+ -+ return accum; -+ } -+ -+ /// Returns the sum of squared elements -+ CUTLASS_HOST_DEVICE -+ Element norm(Element accum = Element()) const { -+ -+ accum += data[0] * data[0]; -+ accum += data[1] * data[1]; -+ -+ return accum; -+ } -+ -+ /// Returns square root of the norm -+ CUTLASS_HOST_DEVICE -+ Element magnitude() const { -+ return fast_sqrt(norm()); -+ } -+ -+ /// Returns the sum of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Element trace(Element accum = Element()) const { -+ -+ accum += data[0]; -+ -+ return accum; -+ } -+ -+}; -+ -+/// Template alias for 2-by-1 matrix -+template -+using Matrix2x1 = Matrix; -+ -+ -+/// Free funciton to infer element type from template arguments -+template -+CUTLASS_HOST_DEVICE Matrix2x1 make_Matrix2x1( -+ Element _0_0, -+ Element _1_0 -+) { -+ return Matrix2x1( -+ _0_0, -+ _1_0 -+ ); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// 2-by-2 matrix template class definition -+template -+struct Matrix { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Element data type -+ using Element = Element_; -+ -+ /// Number of rows in matrix -+ static int const kRows = 2; -+ -+ /// Number of columns in matrix -+ static int const kColumns = 2; -+ -+ /// Layout of matrix in underlying array -+ using Layout = layout::RowMajor; -+ -+ /// Number of elements in matrix -+ static int const kCount = 4; -+ -+ // -+ // Data members -+ // -+ -+ /// Elements of the matrix in row-major layout -+ Array data; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a zero matrix -+ CUTLASS_HOST_DEVICE -+ Matrix() { -+ data.clear(); -+ } -+ -+ /// Copy constructor for a 2-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix(Matrix const &rhs) { -+ data = rhs.data; -+ } -+ -+ /// Constucts a 2-by-2 matrix from scalar elements -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Element _0_0, Element _0_1, -+ Element _1_0, Element _1_1 -+ ) { -+ -+ data[0] = _0_0; data[1] = _0_1; -+ data[2] = _1_0; data[3] = _1_1; -+ } -+ -+ /// Constucts a 2-by-2 matrix from row vectors -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Matrix const &row_0, -+ Matrix const &row_1 -+ ) { -+ data[0] = row_0.data[0]; -+ data[1] = row_0.data[1]; -+ data[2] = row_1.data[0]; -+ data[3] = row_1.data[1]; -+ } -+ -+ /// Static method to construct a 2-by-2 matrix from column vectors -+ CUTLASS_HOST_DEVICE -+ static Matrix from_columns( -+ Matrix const &column_0, -+ Matrix const &column_1 -+ ) { -+ Matrix result; -+ -+ result.data[0] = column_0.data[0]; -+ result.data[1] = column_1.data[0]; -+ result.data[2] = column_0.data[1]; -+ result.data[3] = column_1.data[1]; -+ return result; -+ } -+ -+ /// Constructs an identity matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix identity() { -+ Matrix m; -+ -+ m.data[0] = Element(1); -+ m.data[3] = Element(1); -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from a uniform element -+ CUTLASS_HOST_DEVICE -+ static Matrix uniform(Element s) { -+ Matrix m; -+ -+ m.data[0] = s; -+ m.data[1] = s; -+ m.data[2] = s; -+ m.data[3] = s; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from a uniform element 1 -+ CUTLASS_HOST_DEVICE -+ static Matrix ones() { -+ return uniform(Element(1)); -+ } -+ -+ /// Constructs a matrix from a uniform element 0 -+ CUTLASS_HOST_DEVICE -+ static Matrix zero() { -+ return Matrix(); -+ } -+ -+ /// Constructs a matrix from elements along its diagonal -+ CUTLASS_HOST_DEVICE -+ static Matrix from_diagonal(Matrix const &diag) { -+ Matrix m; -+ -+ m.data[0] = diag.data[0]; -+ m.data[3] = diag.data[1]; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from elements along its diagonal -+ CUTLASS_HOST_DEVICE -+ static Matrix from_diagonal(Matrix const &diag) { -+ Matrix m; -+ -+ m.data[0] = diag.data[0]; -+ m.data[3] = diag.data[1]; -+ -+ return m; -+ } -+ -+ /// Gets an array of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Matrix diagonal() const { -+ Matrix diag; -+ -+ diag.data[0] = data[0]; -+ diag.data[1] = data[3]; -+ -+ return diag; -+ } -+ -+ /// Returns a transposed matrix -+ CUTLASS_HOST_DEVICE -+ Matrix transpose() const { -+ Matrix mt; -+ -+ mt.data[0] = data[0]; -+ mt.data[2] = data[1]; -+ mt.data[1] = data[2]; -+ mt.data[3] = data[3]; -+ -+ return mt; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(int i, int j) const { -+ return data[i * 2 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(int i, int j) { -+ return data[i * 2 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element &at(int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element at(int offset) const { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element operator[](Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & operator[](Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element & operator[](int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element operator[](int offset) const { -+ return data[offset]; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 2 + j + 0]; -+ m.data[1] = data[i * 2 + j + 1]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 2 + j + 0] = m.data[0]; -+ data[i * 2 + j + 1] = m.data[1]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix row(int i) const { -+ return slice_1x2(i, 0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_row(Matrix const &v, int i = 0) { -+ return set_slice_1x2(v, i, 0); -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 2 + j + 0]; -+ m.data[1] = data[i * 2 + j + 2]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 2 + j + 0] = m.data[0]; -+ data[i * 2 + j + 2] = m.data[1]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix column(int j) const { -+ return slice_2x1(0, j); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_column(Matrix const &v, int j =0) { -+ return set_slice_2x1(v, 0, j); -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 2 + j + 0]; -+ m.data[1] = data[i * 2 + j + 1]; -+ m.data[2] = data[i * 2 + j + 2]; -+ m.data[3] = data[i * 2 + j + 3]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 2 + j + 0] = m.data[0]; -+ data[i * 2 + j + 1] = m.data[1]; -+ data[i * 2 + j + 2] = m.data[2]; -+ data[i * 2 + j + 3] = m.data[3]; -+ -+ return *this; -+ } -+ -+ /// Forms a 2-by-2 matrix by horizontally concatenating a 2-by-1 matrix with a 2-by-1 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs.at(0, 0), rhs.at(0, 0) -+ , lhs.at(1, 0), rhs.at(1, 0)); -+ } -+ -+ /// Concatenates this matrix with a a 2-by-1 matrix to form a 2-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Matrix const & rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 2-by-2 matrix to form a 2-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Matrix const & rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Forms a 2-by-2 matrix by vertically concatenating a 1-by-2 matrix with a 1-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Matrix const & lower) { -+ return Matrix( -+ upper.at(0, 0), upper.at(0, 1) -+ , lower.at(0, 0), lower.at(0, 1)); -+ } -+ -+ /// Concatenates this matrix with a a 1-by-2 matrix to form a 3-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Matrix const & rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 2-by-2 matrix to form a 4-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Matrix const & rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Forms a 2-by-2 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Element A, Element B, -+ Element C, Element D) { -+ return Matrix( -+ A, B -+ , C, D -+ ); -+ } -+ -+ /// Elementwise add operator (2-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix add(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] + rhs.data[0]; -+ result.data[1] = data[1] + rhs.data[1]; -+ -+ result.data[2] = data[2] + rhs.data[2]; -+ result.data[3] = data[3] + rhs.data[3]; -+ -+ return result; -+ } -+ -+ /// Elementwise add operator (2-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix operator +(Matrix const &rhs) const { -+ return add(rhs); -+ } -+ -+ /// Elementwise add operator (2-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator +=(Matrix const &rhs) { -+ -+ data[0] += rhs.data[0]; -+ data[1] += rhs.data[1]; -+ -+ data[2] += rhs.data[2]; -+ data[3] += rhs.data[3]; -+ -+ return *this; -+ } -+ -+ /// Elementwise subtract operator (2-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix subtract(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] - rhs.data[0]; -+ result.data[1] = data[1] - rhs.data[1]; -+ -+ result.data[2] = data[2] - rhs.data[2]; -+ result.data[3] = data[3] - rhs.data[3]; -+ -+ return result; -+ } -+ -+ /// Elementwise subtract operator (2-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix operator -(Matrix const &rhs) const { -+ return subtract(rhs); -+ } -+ -+ /// Elementwise subtract operator (2-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator -=(Matrix const &rhs) { -+ -+ data[0] -= rhs.data[0]; -+ data[1] -= rhs.data[1]; -+ -+ data[2] -= rhs.data[2]; -+ data[3] -= rhs.data[3]; -+ -+ return *this; -+ } -+ -+ /// Elementwise multiply operator (2-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * rhs.data[0]; -+ result.data[1] = data[1] * rhs.data[1]; -+ -+ result.data[2] = data[2] * rhs.data[2]; -+ result.data[3] = data[3] * rhs.data[3]; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (2-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * s; -+ result.data[1] = data[1] * s; -+ -+ result.data[2] = data[2] * s; -+ result.data[3] = data[3] * s; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (2-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix operator *(Element const &s) const { -+ return multiply(s); -+ } -+ -+ /// Scalar multiply operator (2-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator *=(Element const &s) { -+ -+ data[0] *= s; -+ data[1] *= s; -+ -+ data[2] *= s; -+ data[3] *= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (2-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / rhs.data[0]; -+ result.data[1] = data[1] / rhs.data[1]; -+ -+ result.data[2] = data[2] / rhs.data[2]; -+ result.data[3] = data[3] / rhs.data[3]; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (2-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / s; -+ result.data[1] = data[1] / s; -+ -+ result.data[2] = data[2] / s; -+ result.data[3] = data[3] / s; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (2-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Element const &s) const { -+ return divide(s); -+ } -+ -+ /// Scalar divide operator (2-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Element const &s) { -+ -+ data[0] /= s; -+ data[1] /= s; -+ -+ data[2] /= s; -+ data[3] /= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (2-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Matrix const &rhs) const { -+ return divide(rhs); -+ } -+ -+ /// Elementwise divide operator (2-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Matrix const &rhs) { -+ -+ data[0] /= rhs.data[0]; -+ data[1] /= rhs.data[1]; -+ -+ data[2] /= rhs.data[2]; -+ data[3] /= rhs.data[3]; -+ -+ return *this; -+ } -+ -+ /// Negates each element of the matrix -+ CUTLASS_HOST_DEVICE -+ Matrix operator-() const { -+ Matrix m; -+ -+ m.data[0] = -m.data[0]; -+ m.data[1] = -m.data[1]; -+ m.data[2] = -m.data[2]; -+ m.data[3] = -m.data[3]; -+ -+ return m; -+ } -+ -+ /// Matrix product of size 2-by-1-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[2] * rhs.data[0]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[1]; -+ accum.data[1] += data[3] * rhs.data[1]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 2-by-1-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 2-by-2-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[2] * rhs.data[0]; -+ accum.data[3] += data[2] * rhs.data[1]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[2]; -+ accum.data[1] += data[1] * rhs.data[3]; -+ accum.data[2] += data[3] * rhs.data[2]; -+ accum.data[3] += data[3] * rhs.data[3]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 2-by-2-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 2-by-2-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix & operator*=(Matrix const &rhs) { -+ *this = product(rhs); -+ return *this; -+ } -+ -+ /// Matrix product of size 2-by-3-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[2] * rhs.data[0]; -+ accum.data[4] += data[2] * rhs.data[1]; -+ accum.data[5] += data[2] * rhs.data[2]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[3]; -+ accum.data[1] += data[1] * rhs.data[4]; -+ accum.data[2] += data[1] * rhs.data[5]; -+ accum.data[3] += data[3] * rhs.data[3]; -+ accum.data[4] += data[3] * rhs.data[4]; -+ accum.data[5] += data[3] * rhs.data[5]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 2-by-3-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 2-by-4-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[0] * rhs.data[3]; -+ accum.data[4] += data[2] * rhs.data[0]; -+ accum.data[5] += data[2] * rhs.data[1]; -+ accum.data[6] += data[2] * rhs.data[2]; -+ accum.data[7] += data[2] * rhs.data[3]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[4]; -+ accum.data[1] += data[1] * rhs.data[5]; -+ accum.data[2] += data[1] * rhs.data[6]; -+ accum.data[3] += data[1] * rhs.data[7]; -+ accum.data[4] += data[3] * rhs.data[4]; -+ accum.data[5] += data[3] * rhs.data[5]; -+ accum.data[6] += data[3] * rhs.data[6]; -+ accum.data[7] += data[3] * rhs.data[7]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 2-by-4-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Returns the sum of elements -+ CUTLASS_HOST_DEVICE -+ Element sum(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[1]; -+ accum += data[2]; -+ accum += data[3]; -+ -+ return accum; -+ } -+ -+ /// Returns the sum of squared elements -+ CUTLASS_HOST_DEVICE -+ Element norm(Element accum = Element()) const { -+ -+ accum += data[0] * data[0]; -+ accum += data[1] * data[1]; -+ accum += data[2] * data[2]; -+ accum += data[3] * data[3]; -+ -+ return accum; -+ } -+ -+ /// Returns square root of the norm -+ CUTLASS_HOST_DEVICE -+ Element magnitude() const { -+ return fast_sqrt(norm()); -+ } -+ -+ /// Returns the sum of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Element trace(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[3]; -+ -+ return accum; -+ } -+ -+ /// Returns 2-by-2 rotation matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix rotation(Element theta) { -+ Element c = fast_cos(theta); -+ Element s = fast_sin(theta); -+ -+ return Matrix( -+ c, -s, -+ s, c -+ ); -+ } -+ -+ /// Computes the determinant of a 2-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ Element determinant(Element accum = Element()) const { -+ accum += data[0] * data[3] - data[1] * data[2]; -+ -+ return accum; -+ } -+ -+ /// Computes the inverse of a 2-by-2 matrix given -+ /// the matrix's determinant -+ CUTLASS_HOST_DEVICE -+ Matrix inverse(Element det) const { -+ return Matrix( -+ data[3], -data[1], -+ -data[2], data[0] -+ ) * (Element(1) / det); -+ } -+ -+ /// Computes the inverse of a 2-by-2 matrix. -+ CUTLASS_HOST_DEVICE -+ Matrix inverse() const { -+ return inverse(determinant()); -+ } -+ -+}; -+ -+/// Template alias for 2-by-2 matrix -+template -+using Matrix2x2 = Matrix; -+ -+ -+/// Free funciton to infer element type from template arguments -+template -+CUTLASS_HOST_DEVICE Matrix2x2 make_Matrix2x2( -+ Element _0_0, Element _0_1, -+ Element _1_0, Element _1_1 -+) { -+ return Matrix2x2( -+ _0_0, _0_1, -+ _1_0, _1_1 -+ ); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// 2-by-3 matrix template class definition -+template -+struct Matrix { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Element data type -+ using Element = Element_; -+ -+ /// Number of rows in matrix -+ static int const kRows = 2; -+ -+ /// Number of columns in matrix -+ static int const kColumns = 3; -+ -+ /// Layout of matrix in underlying array -+ using Layout = layout::RowMajor; -+ -+ /// Number of elements in matrix -+ static int const kCount = 6; -+ -+ // -+ // Data members -+ // -+ -+ /// Elements of the matrix in row-major layout -+ Array data; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a zero matrix -+ CUTLASS_HOST_DEVICE -+ Matrix() { -+ data.clear(); -+ } -+ -+ /// Copy constructor for a 2-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix(Matrix const &rhs) { -+ data = rhs.data; -+ } -+ -+ /// Constucts a 2-by-3 matrix from scalar elements -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Element _0_0, Element _0_1, Element _0_2, -+ Element _1_0, Element _1_1, Element _1_2 -+ ) { -+ -+ data[0] = _0_0; data[1] = _0_1; data[2] = _0_2; -+ data[3] = _1_0; data[4] = _1_1; data[5] = _1_2; -+ } -+ -+ /// Constucts a 2-by-3 matrix from row vectors -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Matrix const &row_0, -+ Matrix const &row_1 -+ ) { -+ data[0] = row_0.data[0]; -+ data[1] = row_0.data[1]; -+ data[2] = row_0.data[2]; -+ data[3] = row_1.data[0]; -+ data[4] = row_1.data[1]; -+ data[5] = row_1.data[2]; -+ } -+ -+ /// Static method to construct a 2-by-3 matrix from column vectors -+ CUTLASS_HOST_DEVICE -+ static Matrix from_columns( -+ Matrix const &column_0, -+ Matrix const &column_1, -+ Matrix const &column_2 -+ ) { -+ Matrix result; -+ -+ result.data[0] = column_0.data[0]; -+ result.data[1] = column_1.data[0]; -+ result.data[2] = column_2.data[0]; -+ result.data[3] = column_0.data[1]; -+ result.data[4] = column_1.data[1]; -+ result.data[5] = column_2.data[1]; -+ return result; -+ } -+ -+ /// Constructs a matrix from a uniform element -+ CUTLASS_HOST_DEVICE -+ static Matrix uniform(Element s) { -+ Matrix m; -+ -+ m.data[0] = s; -+ m.data[1] = s; -+ m.data[2] = s; -+ m.data[3] = s; -+ m.data[4] = s; -+ m.data[5] = s; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from a uniform element 1 -+ CUTLASS_HOST_DEVICE -+ static Matrix ones() { -+ return uniform(Element(1)); -+ } -+ -+ /// Constructs a matrix from a uniform element 0 -+ CUTLASS_HOST_DEVICE -+ static Matrix zero() { -+ return Matrix(); -+ } -+ -+ /// Constructs a matrix from elements along its diagonal -+ CUTLASS_HOST_DEVICE -+ static Matrix from_diagonal(Matrix const &diag) { -+ Matrix m; -+ -+ m.data[0] = diag.data[0]; -+ m.data[3] = diag.data[1]; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from elements along its diagonal -+ CUTLASS_HOST_DEVICE -+ static Matrix from_diagonal(Matrix const &diag) { -+ Matrix m; -+ -+ m.data[0] = diag.data[0]; -+ m.data[3] = diag.data[1]; -+ -+ return m; -+ } -+ -+ /// Gets an array of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Matrix diagonal() const { -+ Matrix diag; -+ -+ diag.data[0] = data[0]; -+ diag.data[1] = data[3]; -+ -+ return diag; -+ } -+ -+ /// Returns a transposed matrix -+ CUTLASS_HOST_DEVICE -+ Matrix transpose() const { -+ Matrix mt; -+ -+ mt.data[0] = data[0]; -+ mt.data[2] = data[1]; -+ mt.data[4] = data[2]; -+ mt.data[1] = data[3]; -+ mt.data[3] = data[4]; -+ mt.data[5] = data[5]; -+ -+ return mt; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(int i, int j) const { -+ return data[i * 2 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(int i, int j) { -+ return data[i * 2 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element &at(int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element at(int offset) const { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element operator[](Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & operator[](Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element & operator[](int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element operator[](int offset) const { -+ return data[offset]; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ m.data[2] = data[i * 3 + j + 2]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ data[i * 3 + j + 2] = m.data[2]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix row(int i) const { -+ return slice_1x3(i, 0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_row(Matrix const &v, int i = 0) { -+ return set_slice_1x3(v, i, 0); -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 3]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 3] = m.data[1]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix column(int j) const { -+ return slice_2x1(0, j); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_column(Matrix const &v, int j =0) { -+ return set_slice_2x1(v, 0, j); -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ m.data[2] = data[i * 3 + j + 3]; -+ m.data[3] = data[i * 3 + j + 4]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ data[i * 3 + j + 3] = m.data[2]; -+ data[i * 3 + j + 4] = m.data[3]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ m.data[2] = data[i * 3 + j + 2]; -+ m.data[3] = data[i * 3 + j + 3]; -+ m.data[4] = data[i * 3 + j + 4]; -+ m.data[5] = data[i * 3 + j + 5]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ data[i * 3 + j + 2] = m.data[2]; -+ data[i * 3 + j + 3] = m.data[3]; -+ data[i * 3 + j + 4] = m.data[4]; -+ data[i * 3 + j + 5] = m.data[5]; -+ -+ return *this; -+ } -+ -+ /// Forms a 2-by-3 matrix by horizontally concatenating a 2-by-1 matrix with a 2-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs.at(0, 0), rhs.at(0, 0), rhs.at(0, 1) -+ , lhs.at(1, 0), rhs.at(1, 0), rhs.at(1, 1)); -+ } -+ -+ /// Forms a 2-by-3 matrix by horizontally concatenating a 2-by-2 matrix with a 2-by-1 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs.at(0, 0), lhs.at(0, 1), rhs.at(0, 0) -+ , lhs.at(1, 0), lhs.at(1, 1), rhs.at(1, 0)); -+ } -+ -+ /// Concatenates this matrix with a a 2-by-1 matrix to form a 2-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Matrix const & rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Forms a 2-by-3 matrix by vertically concatenating a 1-by-3 matrix with a 1-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Matrix const & lower) { -+ return Matrix( -+ upper.at(0, 0), upper.at(0, 1), upper.at(0, 2) -+ , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2)); -+ } -+ -+ /// Concatenates this matrix with a a 1-by-3 matrix to form a 3-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Matrix const & rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 2-by-3 matrix to form a 4-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Matrix const & rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Forms a 2-by-3 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Element A, Matrix const & B, -+ Element C, Matrix const & D) { -+ return Matrix( -+ A, B.at(0, 0), B.at(0, 1) -+ , C, D.at(0, 0), D.at(0, 1) -+ ); -+ } -+ -+ /// Forms a 2-by-3 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Element B, -+ Matrix const & C, Element D) { -+ return Matrix( -+ A.at(0, 0), A.at(0, 1), B -+ , C.at(0, 0), C.at(0, 1), D -+ ); -+ } -+ -+ /// Elementwise add operator (2-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix add(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] + rhs.data[0]; -+ result.data[1] = data[1] + rhs.data[1]; -+ result.data[2] = data[2] + rhs.data[2]; -+ -+ result.data[3] = data[3] + rhs.data[3]; -+ result.data[4] = data[4] + rhs.data[4]; -+ result.data[5] = data[5] + rhs.data[5]; -+ -+ return result; -+ } -+ -+ /// Elementwise add operator (2-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator +(Matrix const &rhs) const { -+ return add(rhs); -+ } -+ -+ /// Elementwise add operator (2-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator +=(Matrix const &rhs) { -+ -+ data[0] += rhs.data[0]; -+ data[1] += rhs.data[1]; -+ data[2] += rhs.data[2]; -+ -+ data[3] += rhs.data[3]; -+ data[4] += rhs.data[4]; -+ data[5] += rhs.data[5]; -+ -+ return *this; -+ } -+ -+ /// Elementwise subtract operator (2-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix subtract(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] - rhs.data[0]; -+ result.data[1] = data[1] - rhs.data[1]; -+ result.data[2] = data[2] - rhs.data[2]; -+ -+ result.data[3] = data[3] - rhs.data[3]; -+ result.data[4] = data[4] - rhs.data[4]; -+ result.data[5] = data[5] - rhs.data[5]; -+ -+ return result; -+ } -+ -+ /// Elementwise subtract operator (2-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator -(Matrix const &rhs) const { -+ return subtract(rhs); -+ } -+ -+ /// Elementwise subtract operator (2-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator -=(Matrix const &rhs) { -+ -+ data[0] -= rhs.data[0]; -+ data[1] -= rhs.data[1]; -+ data[2] -= rhs.data[2]; -+ -+ data[3] -= rhs.data[3]; -+ data[4] -= rhs.data[4]; -+ data[5] -= rhs.data[5]; -+ -+ return *this; -+ } -+ -+ /// Elementwise multiply operator (2-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * rhs.data[0]; -+ result.data[1] = data[1] * rhs.data[1]; -+ result.data[2] = data[2] * rhs.data[2]; -+ -+ result.data[3] = data[3] * rhs.data[3]; -+ result.data[4] = data[4] * rhs.data[4]; -+ result.data[5] = data[5] * rhs.data[5]; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (2-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * s; -+ result.data[1] = data[1] * s; -+ result.data[2] = data[2] * s; -+ -+ result.data[3] = data[3] * s; -+ result.data[4] = data[4] * s; -+ result.data[5] = data[5] * s; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (2-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator *(Element const &s) const { -+ return multiply(s); -+ } -+ -+ /// Scalar multiply operator (2-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator *=(Element const &s) { -+ -+ data[0] *= s; -+ data[1] *= s; -+ data[2] *= s; -+ -+ data[3] *= s; -+ data[4] *= s; -+ data[5] *= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (2-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / rhs.data[0]; -+ result.data[1] = data[1] / rhs.data[1]; -+ result.data[2] = data[2] / rhs.data[2]; -+ -+ result.data[3] = data[3] / rhs.data[3]; -+ result.data[4] = data[4] / rhs.data[4]; -+ result.data[5] = data[5] / rhs.data[5]; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (2-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / s; -+ result.data[1] = data[1] / s; -+ result.data[2] = data[2] / s; -+ -+ result.data[3] = data[3] / s; -+ result.data[4] = data[4] / s; -+ result.data[5] = data[5] / s; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (2-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Element const &s) const { -+ return divide(s); -+ } -+ -+ /// Scalar divide operator (2-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Element const &s) { -+ -+ data[0] /= s; -+ data[1] /= s; -+ data[2] /= s; -+ -+ data[3] /= s; -+ data[4] /= s; -+ data[5] /= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (2-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Matrix const &rhs) const { -+ return divide(rhs); -+ } -+ -+ /// Elementwise divide operator (2-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Matrix const &rhs) { -+ -+ data[0] /= rhs.data[0]; -+ data[1] /= rhs.data[1]; -+ data[2] /= rhs.data[2]; -+ -+ data[3] /= rhs.data[3]; -+ data[4] /= rhs.data[4]; -+ data[5] /= rhs.data[5]; -+ -+ return *this; -+ } -+ -+ /// Negates each element of the matrix -+ CUTLASS_HOST_DEVICE -+ Matrix operator-() const { -+ Matrix m; -+ -+ m.data[0] = -m.data[0]; -+ m.data[1] = -m.data[1]; -+ m.data[2] = -m.data[2]; -+ m.data[3] = -m.data[3]; -+ m.data[4] = -m.data[4]; -+ m.data[5] = -m.data[5]; -+ -+ return m; -+ } -+ -+ /// Matrix product of size 2-by-1-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[3] * rhs.data[0]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[1]; -+ accum.data[1] += data[4] * rhs.data[1]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[2]; -+ accum.data[1] += data[5] * rhs.data[2]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 2-by-1-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 2-by-2-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[3] * rhs.data[0]; -+ accum.data[3] += data[3] * rhs.data[1]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[2]; -+ accum.data[1] += data[1] * rhs.data[3]; -+ accum.data[2] += data[4] * rhs.data[2]; -+ accum.data[3] += data[4] * rhs.data[3]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[4]; -+ accum.data[1] += data[2] * rhs.data[5]; -+ accum.data[2] += data[5] * rhs.data[4]; -+ accum.data[3] += data[5] * rhs.data[5]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 2-by-2-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 2-by-3-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[3] * rhs.data[0]; -+ accum.data[4] += data[3] * rhs.data[1]; -+ accum.data[5] += data[3] * rhs.data[2]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[3]; -+ accum.data[1] += data[1] * rhs.data[4]; -+ accum.data[2] += data[1] * rhs.data[5]; -+ accum.data[3] += data[4] * rhs.data[3]; -+ accum.data[4] += data[4] * rhs.data[4]; -+ accum.data[5] += data[4] * rhs.data[5]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[6]; -+ accum.data[1] += data[2] * rhs.data[7]; -+ accum.data[2] += data[2] * rhs.data[8]; -+ accum.data[3] += data[5] * rhs.data[6]; -+ accum.data[4] += data[5] * rhs.data[7]; -+ accum.data[5] += data[5] * rhs.data[8]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 2-by-3-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 2-by-3-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix & operator*=(Matrix const &rhs) { -+ *this = product(rhs); -+ return *this; -+ } -+ -+ /// Matrix product of size 2-by-4-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[0] * rhs.data[3]; -+ accum.data[4] += data[3] * rhs.data[0]; -+ accum.data[5] += data[3] * rhs.data[1]; -+ accum.data[6] += data[3] * rhs.data[2]; -+ accum.data[7] += data[3] * rhs.data[3]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[4]; -+ accum.data[1] += data[1] * rhs.data[5]; -+ accum.data[2] += data[1] * rhs.data[6]; -+ accum.data[3] += data[1] * rhs.data[7]; -+ accum.data[4] += data[4] * rhs.data[4]; -+ accum.data[5] += data[4] * rhs.data[5]; -+ accum.data[6] += data[4] * rhs.data[6]; -+ accum.data[7] += data[4] * rhs.data[7]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[8]; -+ accum.data[1] += data[2] * rhs.data[9]; -+ accum.data[2] += data[2] * rhs.data[10]; -+ accum.data[3] += data[2] * rhs.data[11]; -+ accum.data[4] += data[5] * rhs.data[8]; -+ accum.data[5] += data[5] * rhs.data[9]; -+ accum.data[6] += data[5] * rhs.data[10]; -+ accum.data[7] += data[5] * rhs.data[11]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 2-by-4-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Returns the sum of elements -+ CUTLASS_HOST_DEVICE -+ Element sum(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[1]; -+ accum += data[2]; -+ accum += data[3]; -+ accum += data[4]; -+ accum += data[5]; -+ -+ return accum; -+ } -+ -+ /// Returns the sum of squared elements -+ CUTLASS_HOST_DEVICE -+ Element norm(Element accum = Element()) const { -+ -+ accum += data[0] * data[0]; -+ accum += data[1] * data[1]; -+ accum += data[2] * data[2]; -+ accum += data[3] * data[3]; -+ accum += data[4] * data[4]; -+ accum += data[5] * data[5]; -+ -+ return accum; -+ } -+ -+ /// Returns square root of the norm -+ CUTLASS_HOST_DEVICE -+ Element magnitude() const { -+ return fast_sqrt(norm()); -+ } -+ -+ /// Returns the sum of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Element trace(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[4]; -+ -+ return accum; -+ } -+ -+}; -+ -+/// Template alias for 2-by-3 matrix -+template -+using Matrix2x3 = Matrix; -+ -+ -+/// Free funciton to infer element type from template arguments -+template -+CUTLASS_HOST_DEVICE Matrix2x3 make_Matrix2x3( -+ Element _0_0, Element _0_1, Element _0_2, -+ Element _1_0, Element _1_1, Element _1_2 -+) { -+ return Matrix2x3( -+ _0_0, _0_1, _0_2, -+ _1_0, _1_1, _1_2 -+ ); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// 2-by-4 matrix template class definition -+template -+struct Matrix { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Element data type -+ using Element = Element_; -+ -+ /// Number of rows in matrix -+ static int const kRows = 2; -+ -+ /// Number of columns in matrix -+ static int const kColumns = 4; -+ -+ /// Layout of matrix in underlying array -+ using Layout = layout::RowMajor; -+ -+ /// Number of elements in matrix -+ static int const kCount = 8; -+ -+ // -+ // Data members -+ // -+ -+ /// Elements of the matrix in row-major layout -+ Array data; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a zero matrix -+ CUTLASS_HOST_DEVICE -+ Matrix() { -+ data.clear(); -+ } -+ -+ /// Copy constructor for a 2-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix(Matrix const &rhs) { -+ data = rhs.data; -+ } -+ -+ /// Constucts a 2-by-4 matrix from scalar elements -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Element _0_0, Element _0_1, Element _0_2, Element _0_3, -+ Element _1_0, Element _1_1, Element _1_2, Element _1_3 -+ ) { -+ -+ data[0] = _0_0; data[1] = _0_1; data[2] = _0_2; data[3] = _0_3; -+ data[4] = _1_0; data[5] = _1_1; data[6] = _1_2; data[7] = _1_3; -+ } -+ -+ /// Constucts a 2-by-4 matrix from row vectors -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Matrix const &row_0, -+ Matrix const &row_1 -+ ) { -+ data[0] = row_0.data[0]; -+ data[1] = row_0.data[1]; -+ data[2] = row_0.data[2]; -+ data[3] = row_0.data[3]; -+ data[4] = row_1.data[0]; -+ data[5] = row_1.data[1]; -+ data[6] = row_1.data[2]; -+ data[7] = row_1.data[3]; -+ } -+ -+ /// Static method to construct a 2-by-4 matrix from column vectors -+ CUTLASS_HOST_DEVICE -+ static Matrix from_columns( -+ Matrix const &column_0, -+ Matrix const &column_1, -+ Matrix const &column_2, -+ Matrix const &column_3 -+ ) { -+ Matrix result; -+ -+ result.data[0] = column_0.data[0]; -+ result.data[1] = column_1.data[0]; -+ result.data[2] = column_2.data[0]; -+ result.data[3] = column_3.data[0]; -+ result.data[4] = column_0.data[1]; -+ result.data[5] = column_1.data[1]; -+ result.data[6] = column_2.data[1]; -+ result.data[7] = column_3.data[1]; -+ return result; -+ } -+ -+ /// Constructs a matrix from a uniform element -+ CUTLASS_HOST_DEVICE -+ static Matrix uniform(Element s) { -+ Matrix m; -+ -+ m.data[0] = s; -+ m.data[1] = s; -+ m.data[2] = s; -+ m.data[3] = s; -+ m.data[4] = s; -+ m.data[5] = s; -+ m.data[6] = s; -+ m.data[7] = s; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from a uniform element 1 -+ CUTLASS_HOST_DEVICE -+ static Matrix ones() { -+ return uniform(Element(1)); -+ } -+ -+ /// Constructs a matrix from a uniform element 0 -+ CUTLASS_HOST_DEVICE -+ static Matrix zero() { -+ return Matrix(); -+ } -+ -+ /// Constructs a matrix from elements along its diagonal -+ CUTLASS_HOST_DEVICE -+ static Matrix from_diagonal(Matrix const &diag) { -+ Matrix m; -+ -+ m.data[0] = diag.data[0]; -+ m.data[3] = diag.data[1]; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from elements along its diagonal -+ CUTLASS_HOST_DEVICE -+ static Matrix from_diagonal(Matrix const &diag) { -+ Matrix m; -+ -+ m.data[0] = diag.data[0]; -+ m.data[3] = diag.data[1]; -+ -+ return m; -+ } -+ -+ /// Gets an array of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Matrix diagonal() const { -+ Matrix diag; -+ -+ diag.data[0] = data[0]; -+ diag.data[1] = data[3]; -+ -+ return diag; -+ } -+ -+ /// Returns a transposed matrix -+ CUTLASS_HOST_DEVICE -+ Matrix transpose() const { -+ Matrix mt; -+ -+ mt.data[0] = data[0]; -+ mt.data[2] = data[1]; -+ mt.data[4] = data[2]; -+ mt.data[6] = data[3]; -+ mt.data[1] = data[4]; -+ mt.data[3] = data[5]; -+ mt.data[5] = data[6]; -+ mt.data[7] = data[7]; -+ -+ return mt; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(int i, int j) const { -+ return data[i * 2 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(int i, int j) { -+ return data[i * 2 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element &at(int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element at(int offset) const { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element operator[](Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & operator[](Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element & operator[](int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element operator[](int offset) const { -+ return data[offset]; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x4(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ m.data[3] = data[i * 4 + j + 3]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x4(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ data[i * 4 + j + 3] = m.data[3]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix row(int i) const { -+ return slice_1x4(i, 0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_row(Matrix const &v, int i = 0) { -+ return set_slice_1x4(v, i, 0); -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 4]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 4] = m.data[1]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix column(int j) const { -+ return slice_2x1(0, j); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_column(Matrix const &v, int j =0) { -+ return set_slice_2x1(v, 0, j); -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 4]; -+ m.data[3] = data[i * 4 + j + 5]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 4] = m.data[2]; -+ data[i * 4 + j + 5] = m.data[3]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ m.data[3] = data[i * 4 + j + 4]; -+ m.data[4] = data[i * 4 + j + 5]; -+ m.data[5] = data[i * 4 + j + 6]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ data[i * 4 + j + 4] = m.data[3]; -+ data[i * 4 + j + 5] = m.data[4]; -+ data[i * 4 + j + 6] = m.data[5]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x4(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ m.data[3] = data[i * 4 + j + 3]; -+ m.data[4] = data[i * 4 + j + 4]; -+ m.data[5] = data[i * 4 + j + 5]; -+ m.data[6] = data[i * 4 + j + 6]; -+ m.data[7] = data[i * 4 + j + 7]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x4(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ data[i * 4 + j + 3] = m.data[3]; -+ data[i * 4 + j + 4] = m.data[4]; -+ data[i * 4 + j + 5] = m.data[5]; -+ data[i * 4 + j + 6] = m.data[6]; -+ data[i * 4 + j + 7] = m.data[7]; -+ -+ return *this; -+ } -+ -+ /// Forms a 2-by-4 matrix by horizontally concatenating a 2-by-1 matrix with a 2-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs.at(0, 0), rhs.at(0, 0), rhs.at(0, 1), rhs.at(0, 2) -+ , lhs.at(1, 0), rhs.at(1, 0), rhs.at(1, 1), rhs.at(1, 2)); -+ } -+ -+ /// Forms a 2-by-4 matrix by horizontally concatenating a 2-by-2 matrix with a 2-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs.at(0, 0), lhs.at(0, 1), rhs.at(0, 0), rhs.at(0, 1) -+ , lhs.at(1, 0), lhs.at(1, 1), rhs.at(1, 0), rhs.at(1, 1)); -+ } -+ -+ /// Forms a 2-by-4 matrix by horizontally concatenating a 2-by-3 matrix with a 2-by-1 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs.at(0, 0), lhs.at(0, 1), lhs.at(0, 2), rhs.at(0, 0) -+ , lhs.at(1, 0), lhs.at(1, 1), lhs.at(1, 2), rhs.at(1, 0)); -+ } -+ -+ /// Forms a 2-by-4 matrix by vertically concatenating a 1-by-4 matrix with a 1-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Matrix const & lower) { -+ return Matrix( -+ upper.at(0, 0), upper.at(0, 1), upper.at(0, 2), upper.at(0, 3) -+ , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2), lower.at(0, 3)); -+ } -+ -+ /// Concatenates this matrix with a a 1-by-4 matrix to form a 3-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Matrix const & rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 2-by-4 matrix to form a 4-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Matrix const & rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Forms a 2-by-4 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Element A, Matrix const & B, -+ Element C, Matrix const & D) { -+ return Matrix( -+ A, B.at(0, 0), B.at(0, 1), B.at(0, 2) -+ , C, D.at(0, 0), D.at(0, 1), D.at(0, 2) -+ ); -+ } -+ -+ /// Forms a 2-by-4 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A.at(0, 0), A.at(0, 1), B.at(0, 0), B.at(0, 1) -+ , C.at(0, 0), C.at(0, 1), D.at(0, 0), D.at(0, 1) -+ ); -+ } -+ -+ /// Forms a 2-by-4 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Element B, -+ Matrix const & C, Element D) { -+ return Matrix( -+ A.at(0, 0), A.at(0, 1), A.at(0, 2), B -+ , C.at(0, 0), C.at(0, 1), C.at(0, 2), D -+ ); -+ } -+ -+ /// Elementwise add operator (2-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix add(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] + rhs.data[0]; -+ result.data[1] = data[1] + rhs.data[1]; -+ result.data[2] = data[2] + rhs.data[2]; -+ result.data[3] = data[3] + rhs.data[3]; -+ -+ result.data[4] = data[4] + rhs.data[4]; -+ result.data[5] = data[5] + rhs.data[5]; -+ result.data[6] = data[6] + rhs.data[6]; -+ result.data[7] = data[7] + rhs.data[7]; -+ -+ return result; -+ } -+ -+ /// Elementwise add operator (2-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator +(Matrix const &rhs) const { -+ return add(rhs); -+ } -+ -+ /// Elementwise add operator (2-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator +=(Matrix const &rhs) { -+ -+ data[0] += rhs.data[0]; -+ data[1] += rhs.data[1]; -+ data[2] += rhs.data[2]; -+ data[3] += rhs.data[3]; -+ -+ data[4] += rhs.data[4]; -+ data[5] += rhs.data[5]; -+ data[6] += rhs.data[6]; -+ data[7] += rhs.data[7]; -+ -+ return *this; -+ } -+ -+ /// Elementwise subtract operator (2-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix subtract(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] - rhs.data[0]; -+ result.data[1] = data[1] - rhs.data[1]; -+ result.data[2] = data[2] - rhs.data[2]; -+ result.data[3] = data[3] - rhs.data[3]; -+ -+ result.data[4] = data[4] - rhs.data[4]; -+ result.data[5] = data[5] - rhs.data[5]; -+ result.data[6] = data[6] - rhs.data[6]; -+ result.data[7] = data[7] - rhs.data[7]; -+ -+ return result; -+ } -+ -+ /// Elementwise subtract operator (2-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator -(Matrix const &rhs) const { -+ return subtract(rhs); -+ } -+ -+ /// Elementwise subtract operator (2-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator -=(Matrix const &rhs) { -+ -+ data[0] -= rhs.data[0]; -+ data[1] -= rhs.data[1]; -+ data[2] -= rhs.data[2]; -+ data[3] -= rhs.data[3]; -+ -+ data[4] -= rhs.data[4]; -+ data[5] -= rhs.data[5]; -+ data[6] -= rhs.data[6]; -+ data[7] -= rhs.data[7]; -+ -+ return *this; -+ } -+ -+ /// Elementwise multiply operator (2-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * rhs.data[0]; -+ result.data[1] = data[1] * rhs.data[1]; -+ result.data[2] = data[2] * rhs.data[2]; -+ result.data[3] = data[3] * rhs.data[3]; -+ -+ result.data[4] = data[4] * rhs.data[4]; -+ result.data[5] = data[5] * rhs.data[5]; -+ result.data[6] = data[6] * rhs.data[6]; -+ result.data[7] = data[7] * rhs.data[7]; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (2-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * s; -+ result.data[1] = data[1] * s; -+ result.data[2] = data[2] * s; -+ result.data[3] = data[3] * s; -+ -+ result.data[4] = data[4] * s; -+ result.data[5] = data[5] * s; -+ result.data[6] = data[6] * s; -+ result.data[7] = data[7] * s; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (2-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator *(Element const &s) const { -+ return multiply(s); -+ } -+ -+ /// Scalar multiply operator (2-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator *=(Element const &s) { -+ -+ data[0] *= s; -+ data[1] *= s; -+ data[2] *= s; -+ data[3] *= s; -+ -+ data[4] *= s; -+ data[5] *= s; -+ data[6] *= s; -+ data[7] *= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (2-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / rhs.data[0]; -+ result.data[1] = data[1] / rhs.data[1]; -+ result.data[2] = data[2] / rhs.data[2]; -+ result.data[3] = data[3] / rhs.data[3]; -+ -+ result.data[4] = data[4] / rhs.data[4]; -+ result.data[5] = data[5] / rhs.data[5]; -+ result.data[6] = data[6] / rhs.data[6]; -+ result.data[7] = data[7] / rhs.data[7]; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (2-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / s; -+ result.data[1] = data[1] / s; -+ result.data[2] = data[2] / s; -+ result.data[3] = data[3] / s; -+ -+ result.data[4] = data[4] / s; -+ result.data[5] = data[5] / s; -+ result.data[6] = data[6] / s; -+ result.data[7] = data[7] / s; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (2-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Element const &s) const { -+ return divide(s); -+ } -+ -+ /// Scalar divide operator (2-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Element const &s) { -+ -+ data[0] /= s; -+ data[1] /= s; -+ data[2] /= s; -+ data[3] /= s; -+ -+ data[4] /= s; -+ data[5] /= s; -+ data[6] /= s; -+ data[7] /= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (2-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Matrix const &rhs) const { -+ return divide(rhs); -+ } -+ -+ /// Elementwise divide operator (2-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Matrix const &rhs) { -+ -+ data[0] /= rhs.data[0]; -+ data[1] /= rhs.data[1]; -+ data[2] /= rhs.data[2]; -+ data[3] /= rhs.data[3]; -+ -+ data[4] /= rhs.data[4]; -+ data[5] /= rhs.data[5]; -+ data[6] /= rhs.data[6]; -+ data[7] /= rhs.data[7]; -+ -+ return *this; -+ } -+ -+ /// Negates each element of the matrix -+ CUTLASS_HOST_DEVICE -+ Matrix operator-() const { -+ Matrix m; -+ -+ m.data[0] = -m.data[0]; -+ m.data[1] = -m.data[1]; -+ m.data[2] = -m.data[2]; -+ m.data[3] = -m.data[3]; -+ m.data[4] = -m.data[4]; -+ m.data[5] = -m.data[5]; -+ m.data[6] = -m.data[6]; -+ m.data[7] = -m.data[7]; -+ -+ return m; -+ } -+ -+ /// Matrix product of size 2-by-1-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[4] * rhs.data[0]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[1]; -+ accum.data[1] += data[5] * rhs.data[1]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[2]; -+ accum.data[1] += data[6] * rhs.data[2]; -+ -+ // k=3 -+ accum.data[0] += data[3] * rhs.data[3]; -+ accum.data[1] += data[7] * rhs.data[3]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 2-by-1-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 2-by-2-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[4] * rhs.data[0]; -+ accum.data[3] += data[4] * rhs.data[1]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[2]; -+ accum.data[1] += data[1] * rhs.data[3]; -+ accum.data[2] += data[5] * rhs.data[2]; -+ accum.data[3] += data[5] * rhs.data[3]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[4]; -+ accum.data[1] += data[2] * rhs.data[5]; -+ accum.data[2] += data[6] * rhs.data[4]; -+ accum.data[3] += data[6] * rhs.data[5]; -+ -+ // k=3 -+ accum.data[0] += data[3] * rhs.data[6]; -+ accum.data[1] += data[3] * rhs.data[7]; -+ accum.data[2] += data[7] * rhs.data[6]; -+ accum.data[3] += data[7] * rhs.data[7]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 2-by-2-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 2-by-3-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[4] * rhs.data[0]; -+ accum.data[4] += data[4] * rhs.data[1]; -+ accum.data[5] += data[4] * rhs.data[2]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[3]; -+ accum.data[1] += data[1] * rhs.data[4]; -+ accum.data[2] += data[1] * rhs.data[5]; -+ accum.data[3] += data[5] * rhs.data[3]; -+ accum.data[4] += data[5] * rhs.data[4]; -+ accum.data[5] += data[5] * rhs.data[5]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[6]; -+ accum.data[1] += data[2] * rhs.data[7]; -+ accum.data[2] += data[2] * rhs.data[8]; -+ accum.data[3] += data[6] * rhs.data[6]; -+ accum.data[4] += data[6] * rhs.data[7]; -+ accum.data[5] += data[6] * rhs.data[8]; -+ -+ // k=3 -+ accum.data[0] += data[3] * rhs.data[9]; -+ accum.data[1] += data[3] * rhs.data[10]; -+ accum.data[2] += data[3] * rhs.data[11]; -+ accum.data[3] += data[7] * rhs.data[9]; -+ accum.data[4] += data[7] * rhs.data[10]; -+ accum.data[5] += data[7] * rhs.data[11]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 2-by-3-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 2-by-4-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[0] * rhs.data[3]; -+ accum.data[4] += data[4] * rhs.data[0]; -+ accum.data[5] += data[4] * rhs.data[1]; -+ accum.data[6] += data[4] * rhs.data[2]; -+ accum.data[7] += data[4] * rhs.data[3]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[4]; -+ accum.data[1] += data[1] * rhs.data[5]; -+ accum.data[2] += data[1] * rhs.data[6]; -+ accum.data[3] += data[1] * rhs.data[7]; -+ accum.data[4] += data[5] * rhs.data[4]; -+ accum.data[5] += data[5] * rhs.data[5]; -+ accum.data[6] += data[5] * rhs.data[6]; -+ accum.data[7] += data[5] * rhs.data[7]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[8]; -+ accum.data[1] += data[2] * rhs.data[9]; -+ accum.data[2] += data[2] * rhs.data[10]; -+ accum.data[3] += data[2] * rhs.data[11]; -+ accum.data[4] += data[6] * rhs.data[8]; -+ accum.data[5] += data[6] * rhs.data[9]; -+ accum.data[6] += data[6] * rhs.data[10]; -+ accum.data[7] += data[6] * rhs.data[11]; -+ -+ // k=3 -+ accum.data[0] += data[3] * rhs.data[12]; -+ accum.data[1] += data[3] * rhs.data[13]; -+ accum.data[2] += data[3] * rhs.data[14]; -+ accum.data[3] += data[3] * rhs.data[15]; -+ accum.data[4] += data[7] * rhs.data[12]; -+ accum.data[5] += data[7] * rhs.data[13]; -+ accum.data[6] += data[7] * rhs.data[14]; -+ accum.data[7] += data[7] * rhs.data[15]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 2-by-4-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 2-by-4-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix & operator*=(Matrix const &rhs) { -+ *this = product(rhs); -+ return *this; -+ } -+ -+ /// Returns the sum of elements -+ CUTLASS_HOST_DEVICE -+ Element sum(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[1]; -+ accum += data[2]; -+ accum += data[3]; -+ accum += data[4]; -+ accum += data[5]; -+ accum += data[6]; -+ accum += data[7]; -+ -+ return accum; -+ } -+ -+ /// Returns the sum of squared elements -+ CUTLASS_HOST_DEVICE -+ Element norm(Element accum = Element()) const { -+ -+ accum += data[0] * data[0]; -+ accum += data[1] * data[1]; -+ accum += data[2] * data[2]; -+ accum += data[3] * data[3]; -+ accum += data[4] * data[4]; -+ accum += data[5] * data[5]; -+ accum += data[6] * data[6]; -+ accum += data[7] * data[7]; -+ -+ return accum; -+ } -+ -+ /// Returns square root of the norm -+ CUTLASS_HOST_DEVICE -+ Element magnitude() const { -+ return fast_sqrt(norm()); -+ } -+ -+ /// Returns the sum of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Element trace(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[5]; -+ -+ return accum; -+ } -+ -+}; -+ -+/// Template alias for 2-by-4 matrix -+template -+using Matrix2x4 = Matrix; -+ -+ -+/// Free funciton to infer element type from template arguments -+template -+CUTLASS_HOST_DEVICE Matrix2x4 make_Matrix2x4( -+ Element _0_0, Element _0_1, Element _0_2, Element _0_3, -+ Element _1_0, Element _1_1, Element _1_2, Element _1_3 -+) { -+ return Matrix2x4( -+ _0_0, _0_1, _0_2, _0_3, -+ _1_0, _1_1, _1_2, _1_3 -+ ); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// 3-by-1 matrix template class definition -+template -+struct Matrix { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Element data type -+ using Element = Element_; -+ -+ /// Number of rows in matrix -+ static int const kRows = 3; -+ -+ /// Number of columns in matrix -+ static int const kColumns = 1; -+ -+ /// Layout of matrix in underlying array -+ using Layout = layout::RowMajor; -+ -+ /// Number of elements in matrix -+ static int const kCount = 3; -+ -+ // -+ // Data members -+ // -+ -+ /// Elements of the matrix in row-major layout -+ Array data; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a zero matrix -+ CUTLASS_HOST_DEVICE -+ Matrix() { -+ data.clear(); -+ } -+ -+ /// Copy constructor for a 3-by-1 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix(Matrix const &rhs) { -+ data = rhs.data; -+ } -+ -+ /// Constucts a 3-by-1 matrix from scalar elements -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Element _0_0, -+ Element _1_0, -+ Element _2_0 -+ ) { -+ -+ data[0] = _0_0; -+ data[1] = _1_0; -+ data[2] = _2_0; -+ } -+ -+ /// Constructs a matrix from a uniform element -+ CUTLASS_HOST_DEVICE -+ static Matrix uniform(Element s) { -+ Matrix m; -+ -+ m.data[0] = s; -+ m.data[1] = s; -+ m.data[2] = s; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from a uniform element 1 -+ CUTLASS_HOST_DEVICE -+ static Matrix ones() { -+ return uniform(Element(1)); -+ } -+ -+ /// Constructs a matrix from a uniform element 0 -+ CUTLASS_HOST_DEVICE -+ static Matrix zero() { -+ return Matrix(); -+ } -+ -+ /// Returns a transposed matrix -+ CUTLASS_HOST_DEVICE -+ Matrix transpose() const { -+ Matrix mt; -+ -+ mt.data[0] = data[0]; -+ mt.data[1] = data[1]; -+ mt.data[2] = data[2]; -+ -+ return mt; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(int i, int j) const { -+ return data[i * 3 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(int i, int j) { -+ return data[i * 3 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element &at(int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element at(int offset) const { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element operator[](Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & operator[](Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element & operator[](int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element operator[](int offset) const { -+ return data[offset]; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 1 + j + 0]; -+ m.data[1] = data[i * 1 + j + 1]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 1 + j + 0] = m.data[0]; -+ data[i * 1 + j + 1] = m.data[1]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 1 + j + 0]; -+ m.data[1] = data[i * 1 + j + 1]; -+ m.data[2] = data[i * 1 + j + 2]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 1 + j + 0] = m.data[0]; -+ data[i * 1 + j + 1] = m.data[1]; -+ data[i * 1 + j + 2] = m.data[2]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix column(int j) const { -+ return slice_3x1(0, j); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_column(Matrix const &v, int j =0) { -+ return set_slice_3x1(v, 0, j); -+ } -+ -+ /// Concatenates this matrix with a a 3-by-1 matrix to form a 3-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Matrix const & rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 3-by-2 matrix to form a 3-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Matrix const & rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 3-by-3 matrix to form a 3-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Matrix const & rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Forms a 3-by-1 matrix by vertically concatenating an Element with a 2-by-1 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Element upper, Matrix const & lower) { -+ return Matrix( -+ upper -+ , lower.at(0, 0) -+ , lower.at(1, 0)); -+ } -+ -+ /// Forms a 3-by-1 matrix by vertically concatenating a 2-by-1 matrix with an Element -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Element lower) { -+ return Matrix( -+ upper.at(0, 0) -+ , upper.at(1, 0) -+ , lower); -+ } -+ -+ /// Concatenates this matrix with a an Element to form a 4-by-1 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Element rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Elementwise add operator (3-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix add(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] + rhs.data[0]; -+ -+ result.data[1] = data[1] + rhs.data[1]; -+ -+ result.data[2] = data[2] + rhs.data[2]; -+ -+ return result; -+ } -+ -+ /// Elementwise add operator (3-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix operator +(Matrix const &rhs) const { -+ return add(rhs); -+ } -+ -+ /// Elementwise add operator (3-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator +=(Matrix const &rhs) { -+ -+ data[0] += rhs.data[0]; -+ -+ data[1] += rhs.data[1]; -+ -+ data[2] += rhs.data[2]; -+ -+ return *this; -+ } -+ -+ /// Elementwise subtract operator (3-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix subtract(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] - rhs.data[0]; -+ -+ result.data[1] = data[1] - rhs.data[1]; -+ -+ result.data[2] = data[2] - rhs.data[2]; -+ -+ return result; -+ } -+ -+ /// Elementwise subtract operator (3-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix operator -(Matrix const &rhs) const { -+ return subtract(rhs); -+ } -+ -+ /// Elementwise subtract operator (3-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator -=(Matrix const &rhs) { -+ -+ data[0] -= rhs.data[0]; -+ -+ data[1] -= rhs.data[1]; -+ -+ data[2] -= rhs.data[2]; -+ -+ return *this; -+ } -+ -+ /// Elementwise multiply operator (3-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * rhs.data[0]; -+ -+ result.data[1] = data[1] * rhs.data[1]; -+ -+ result.data[2] = data[2] * rhs.data[2]; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (3-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * s; -+ -+ result.data[1] = data[1] * s; -+ -+ result.data[2] = data[2] * s; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (3-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix operator *(Element const &s) const { -+ return multiply(s); -+ } -+ -+ /// Scalar multiply operator (3-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator *=(Element const &s) { -+ -+ data[0] *= s; -+ -+ data[1] *= s; -+ -+ data[2] *= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (3-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / rhs.data[0]; -+ -+ result.data[1] = data[1] / rhs.data[1]; -+ -+ result.data[2] = data[2] / rhs.data[2]; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (3-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / s; -+ -+ result.data[1] = data[1] / s; -+ -+ result.data[2] = data[2] / s; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (3-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Element const &s) const { -+ return divide(s); -+ } -+ -+ /// Scalar divide operator (3-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Element const &s) { -+ -+ data[0] /= s; -+ -+ data[1] /= s; -+ -+ data[2] /= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (3-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Matrix const &rhs) const { -+ return divide(rhs); -+ } -+ -+ /// Elementwise divide operator (3-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Matrix const &rhs) { -+ -+ data[0] /= rhs.data[0]; -+ -+ data[1] /= rhs.data[1]; -+ -+ data[2] /= rhs.data[2]; -+ -+ return *this; -+ } -+ -+ /// Negates each element of the matrix -+ CUTLASS_HOST_DEVICE -+ Matrix operator-() const { -+ Matrix m; -+ -+ m.data[0] = -m.data[0]; -+ m.data[1] = -m.data[1]; -+ m.data[2] = -m.data[2]; -+ -+ return m; -+ } -+ -+ /// Matrix product of size 3-by-1-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[1] * rhs.data[0]; -+ accum.data[2] += data[2] * rhs.data[0]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 3-by-1-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 3-by-1-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix & operator*=(Matrix const &rhs) { -+ *this = product(rhs); -+ return *this; -+ } -+ -+ /// Matrix product of size 3-by-2-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[1] * rhs.data[0]; -+ accum.data[3] += data[1] * rhs.data[1]; -+ accum.data[4] += data[2] * rhs.data[0]; -+ accum.data[5] += data[2] * rhs.data[1]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 3-by-2-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 3-by-3-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[1] * rhs.data[0]; -+ accum.data[4] += data[1] * rhs.data[1]; -+ accum.data[5] += data[1] * rhs.data[2]; -+ accum.data[6] += data[2] * rhs.data[0]; -+ accum.data[7] += data[2] * rhs.data[1]; -+ accum.data[8] += data[2] * rhs.data[2]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 3-by-3-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 3-by-4-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[0] * rhs.data[3]; -+ accum.data[4] += data[1] * rhs.data[0]; -+ accum.data[5] += data[1] * rhs.data[1]; -+ accum.data[6] += data[1] * rhs.data[2]; -+ accum.data[7] += data[1] * rhs.data[3]; -+ accum.data[8] += data[2] * rhs.data[0]; -+ accum.data[9] += data[2] * rhs.data[1]; -+ accum.data[10] += data[2] * rhs.data[2]; -+ accum.data[11] += data[2] * rhs.data[3]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 3-by-4-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Dot product of vectors with extent 3 -+ CUTLASS_HOST_DEVICE -+ Element dot(Matrix const &rhs, Element accum = Element()) const { -+ -+ accum += data[0] * rhs.data[0]; -+ accum += data[1] * rhs.data[1]; -+ accum += data[2] * rhs.data[2]; -+ return accum; -+ } -+ -+ /// Dot product of vectors with extent 3 -+ CUTLASS_HOST_DEVICE -+ Element dot(Matrix const &rhs, Element accum = Element()) const { -+ -+ accum += data[0] * rhs.data[0]; -+ accum += data[1] * rhs.data[1]; -+ accum += data[2] * rhs.data[2]; -+ return accum; -+ } -+ -+ /// Returns the sum of elements -+ CUTLASS_HOST_DEVICE -+ Element sum(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[1]; -+ accum += data[2]; -+ -+ return accum; -+ } -+ -+ /// Returns the sum of squared elements -+ CUTLASS_HOST_DEVICE -+ Element norm(Element accum = Element()) const { -+ -+ accum += data[0] * data[0]; -+ accum += data[1] * data[1]; -+ accum += data[2] * data[2]; -+ -+ return accum; -+ } -+ -+ /// Returns square root of the norm -+ CUTLASS_HOST_DEVICE -+ Element magnitude() const { -+ return fast_sqrt(norm()); -+ } -+ -+ /// Returns the sum of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Element trace(Element accum = Element()) const { -+ -+ accum += data[0]; -+ -+ return accum; -+ } -+ -+ /// Cross product -+ CUTLASS_HOST_DEVICE -+ Matrix cross(Matrix const &rhs) const { -+ return Matrix( -+ data[1] * rhs.data[2] - data[2] * rhs.data[1], -+ data[0] * rhs.data[2] - data[2] * rhs.data[1], -+ data[0] * rhs.data[1] - data[1] * rhs.data[0] -+ ); -+ } -+ -+}; -+ -+/// Template alias for 3-by-1 matrix -+template -+using Matrix3x1 = Matrix; -+ -+ -+/// Free funciton to infer element type from template arguments -+template -+CUTLASS_HOST_DEVICE Matrix3x1 make_Matrix3x1( -+ Element _0_0, -+ Element _1_0, -+ Element _2_0 -+) { -+ return Matrix3x1( -+ _0_0, -+ _1_0, -+ _2_0 -+ ); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// 3-by-2 matrix template class definition -+template -+struct Matrix { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Element data type -+ using Element = Element_; -+ -+ /// Number of rows in matrix -+ static int const kRows = 3; -+ -+ /// Number of columns in matrix -+ static int const kColumns = 2; -+ -+ /// Layout of matrix in underlying array -+ using Layout = layout::RowMajor; -+ -+ /// Number of elements in matrix -+ static int const kCount = 6; -+ -+ // -+ // Data members -+ // -+ -+ /// Elements of the matrix in row-major layout -+ Array data; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a zero matrix -+ CUTLASS_HOST_DEVICE -+ Matrix() { -+ data.clear(); -+ } -+ -+ /// Copy constructor for a 3-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix(Matrix const &rhs) { -+ data = rhs.data; -+ } -+ -+ /// Constucts a 3-by-2 matrix from scalar elements -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Element _0_0, Element _0_1, -+ Element _1_0, Element _1_1, -+ Element _2_0, Element _2_1 -+ ) { -+ -+ data[0] = _0_0; data[1] = _0_1; -+ data[2] = _1_0; data[3] = _1_1; -+ data[4] = _2_0; data[5] = _2_1; -+ } -+ -+ /// Constucts a 3-by-2 matrix from row vectors -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Matrix const &row_0, -+ Matrix const &row_1, -+ Matrix const &row_2 -+ ) { -+ data[0] = row_0.data[0]; -+ data[1] = row_0.data[1]; -+ data[2] = row_1.data[0]; -+ data[3] = row_1.data[1]; -+ data[4] = row_2.data[0]; -+ data[5] = row_2.data[1]; -+ } -+ -+ /// Static method to construct a 3-by-2 matrix from column vectors -+ CUTLASS_HOST_DEVICE -+ static Matrix from_columns( -+ Matrix const &column_0, -+ Matrix const &column_1 -+ ) { -+ Matrix result; -+ -+ result.data[0] = column_0.data[0]; -+ result.data[1] = column_1.data[0]; -+ result.data[2] = column_0.data[1]; -+ result.data[3] = column_1.data[1]; -+ result.data[4] = column_0.data[2]; -+ result.data[5] = column_1.data[2]; -+ return result; -+ } -+ -+ /// Constructs a matrix from a uniform element -+ CUTLASS_HOST_DEVICE -+ static Matrix uniform(Element s) { -+ Matrix m; -+ -+ m.data[0] = s; -+ m.data[1] = s; -+ m.data[2] = s; -+ m.data[3] = s; -+ m.data[4] = s; -+ m.data[5] = s; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from a uniform element 1 -+ CUTLASS_HOST_DEVICE -+ static Matrix ones() { -+ return uniform(Element(1)); -+ } -+ -+ /// Constructs a matrix from a uniform element 0 -+ CUTLASS_HOST_DEVICE -+ static Matrix zero() { -+ return Matrix(); -+ } -+ -+ /// Constructs a matrix from elements along its diagonal -+ CUTLASS_HOST_DEVICE -+ static Matrix from_diagonal(Matrix const &diag) { -+ Matrix m; -+ -+ m.data[0] = diag.data[0]; -+ m.data[4] = diag.data[1]; -+ m.data[8] = diag.data[2]; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from elements along its diagonal -+ CUTLASS_HOST_DEVICE -+ static Matrix from_diagonal(Matrix const &diag) { -+ Matrix m; -+ -+ m.data[0] = diag.data[0]; -+ m.data[4] = diag.data[1]; -+ m.data[8] = diag.data[2]; -+ -+ return m; -+ } -+ -+ /// Gets an array of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Matrix diagonal() const { -+ Matrix diag; -+ -+ diag.data[0] = data[0]; -+ diag.data[1] = data[4]; -+ diag.data[2] = data[8]; -+ -+ return diag; -+ } -+ -+ /// Returns a transposed matrix -+ CUTLASS_HOST_DEVICE -+ Matrix transpose() const { -+ Matrix mt; -+ -+ mt.data[0] = data[0]; -+ mt.data[3] = data[1]; -+ mt.data[1] = data[2]; -+ mt.data[4] = data[3]; -+ mt.data[2] = data[4]; -+ mt.data[5] = data[5]; -+ -+ return mt; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(int i, int j) const { -+ return data[i * 3 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(int i, int j) { -+ return data[i * 3 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element &at(int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element at(int offset) const { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element operator[](Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & operator[](Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element & operator[](int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element operator[](int offset) const { -+ return data[offset]; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 2 + j + 0]; -+ m.data[1] = data[i * 2 + j + 1]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 2 + j + 0] = m.data[0]; -+ data[i * 2 + j + 1] = m.data[1]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix row(int i) const { -+ return slice_1x2(i, 0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_row(Matrix const &v, int i = 0) { -+ return set_slice_1x2(v, i, 0); -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 2 + j + 0]; -+ m.data[1] = data[i * 2 + j + 2]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 2 + j + 0] = m.data[0]; -+ data[i * 2 + j + 2] = m.data[1]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 2 + j + 0]; -+ m.data[1] = data[i * 2 + j + 1]; -+ m.data[2] = data[i * 2 + j + 2]; -+ m.data[3] = data[i * 2 + j + 3]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 2 + j + 0] = m.data[0]; -+ data[i * 2 + j + 1] = m.data[1]; -+ data[i * 2 + j + 2] = m.data[2]; -+ data[i * 2 + j + 3] = m.data[3]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 2 + j + 0]; -+ m.data[1] = data[i * 2 + j + 2]; -+ m.data[2] = data[i * 2 + j + 4]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 2 + j + 0] = m.data[0]; -+ data[i * 2 + j + 2] = m.data[1]; -+ data[i * 2 + j + 4] = m.data[2]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix column(int j) const { -+ return slice_3x1(0, j); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_column(Matrix const &v, int j =0) { -+ return set_slice_3x1(v, 0, j); -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 2 + j + 0]; -+ m.data[1] = data[i * 2 + j + 1]; -+ m.data[2] = data[i * 2 + j + 2]; -+ m.data[3] = data[i * 2 + j + 3]; -+ m.data[4] = data[i * 2 + j + 4]; -+ m.data[5] = data[i * 2 + j + 5]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 2 + j + 0] = m.data[0]; -+ data[i * 2 + j + 1] = m.data[1]; -+ data[i * 2 + j + 2] = m.data[2]; -+ data[i * 2 + j + 3] = m.data[3]; -+ data[i * 2 + j + 4] = m.data[4]; -+ data[i * 2 + j + 5] = m.data[5]; -+ -+ return *this; -+ } -+ -+ /// Forms a 3-by-2 matrix by horizontally concatenating a 3-by-1 matrix with a 3-by-1 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs.at(0, 0), rhs.at(0, 0) -+ , lhs.at(1, 0), rhs.at(1, 0) -+ , lhs.at(2, 0), rhs.at(2, 0)); -+ } -+ -+ /// Concatenates this matrix with a a 3-by-1 matrix to form a 3-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Matrix const & rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 3-by-2 matrix to form a 3-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Matrix const & rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Forms a 3-by-2 matrix by vertically concatenating a 1-by-2 matrix with a 2-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Matrix const & lower) { -+ return Matrix( -+ upper.at(0, 0), upper.at(0, 1) -+ , lower.at(0, 0), lower.at(0, 1) -+ , lower.at(1, 0), lower.at(1, 1)); -+ } -+ -+ /// Forms a 3-by-2 matrix by vertically concatenating a 2-by-2 matrix with a 1-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Matrix const & lower) { -+ return Matrix( -+ upper.at(0, 0), upper.at(0, 1) -+ , upper.at(1, 0), upper.at(1, 1) -+ , lower.at(0, 0), lower.at(0, 1)); -+ } -+ -+ /// Concatenates this matrix with a a 1-by-2 matrix to form a 4-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Matrix const & rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Forms a 3-by-2 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Element A, Element B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A, B -+ , C.at(0, 0), D.at(0, 0) -+ , C.at(1, 0), D.at(1, 0) -+ ); -+ } -+ -+ /// Forms a 3-by-2 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Element C, Element D) { -+ return Matrix( -+ A.at(0, 0), B.at(0, 0) -+ , A.at(1, 0), B.at(1, 0) -+ , C, D -+ ); -+ } -+ -+ /// Elementwise add operator (3-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix add(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] + rhs.data[0]; -+ result.data[1] = data[1] + rhs.data[1]; -+ -+ result.data[2] = data[2] + rhs.data[2]; -+ result.data[3] = data[3] + rhs.data[3]; -+ -+ result.data[4] = data[4] + rhs.data[4]; -+ result.data[5] = data[5] + rhs.data[5]; -+ -+ return result; -+ } -+ -+ /// Elementwise add operator (3-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix operator +(Matrix const &rhs) const { -+ return add(rhs); -+ } -+ -+ /// Elementwise add operator (3-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator +=(Matrix const &rhs) { -+ -+ data[0] += rhs.data[0]; -+ data[1] += rhs.data[1]; -+ -+ data[2] += rhs.data[2]; -+ data[3] += rhs.data[3]; -+ -+ data[4] += rhs.data[4]; -+ data[5] += rhs.data[5]; -+ -+ return *this; -+ } -+ -+ /// Elementwise subtract operator (3-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix subtract(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] - rhs.data[0]; -+ result.data[1] = data[1] - rhs.data[1]; -+ -+ result.data[2] = data[2] - rhs.data[2]; -+ result.data[3] = data[3] - rhs.data[3]; -+ -+ result.data[4] = data[4] - rhs.data[4]; -+ result.data[5] = data[5] - rhs.data[5]; -+ -+ return result; -+ } -+ -+ /// Elementwise subtract operator (3-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix operator -(Matrix const &rhs) const { -+ return subtract(rhs); -+ } -+ -+ /// Elementwise subtract operator (3-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator -=(Matrix const &rhs) { -+ -+ data[0] -= rhs.data[0]; -+ data[1] -= rhs.data[1]; -+ -+ data[2] -= rhs.data[2]; -+ data[3] -= rhs.data[3]; -+ -+ data[4] -= rhs.data[4]; -+ data[5] -= rhs.data[5]; -+ -+ return *this; -+ } -+ -+ /// Elementwise multiply operator (3-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * rhs.data[0]; -+ result.data[1] = data[1] * rhs.data[1]; -+ -+ result.data[2] = data[2] * rhs.data[2]; -+ result.data[3] = data[3] * rhs.data[3]; -+ -+ result.data[4] = data[4] * rhs.data[4]; -+ result.data[5] = data[5] * rhs.data[5]; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (3-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * s; -+ result.data[1] = data[1] * s; -+ -+ result.data[2] = data[2] * s; -+ result.data[3] = data[3] * s; -+ -+ result.data[4] = data[4] * s; -+ result.data[5] = data[5] * s; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (3-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix operator *(Element const &s) const { -+ return multiply(s); -+ } -+ -+ /// Scalar multiply operator (3-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator *=(Element const &s) { -+ -+ data[0] *= s; -+ data[1] *= s; -+ -+ data[2] *= s; -+ data[3] *= s; -+ -+ data[4] *= s; -+ data[5] *= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (3-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / rhs.data[0]; -+ result.data[1] = data[1] / rhs.data[1]; -+ -+ result.data[2] = data[2] / rhs.data[2]; -+ result.data[3] = data[3] / rhs.data[3]; -+ -+ result.data[4] = data[4] / rhs.data[4]; -+ result.data[5] = data[5] / rhs.data[5]; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (3-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / s; -+ result.data[1] = data[1] / s; -+ -+ result.data[2] = data[2] / s; -+ result.data[3] = data[3] / s; -+ -+ result.data[4] = data[4] / s; -+ result.data[5] = data[5] / s; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (3-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Element const &s) const { -+ return divide(s); -+ } -+ -+ /// Scalar divide operator (3-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Element const &s) { -+ -+ data[0] /= s; -+ data[1] /= s; -+ -+ data[2] /= s; -+ data[3] /= s; -+ -+ data[4] /= s; -+ data[5] /= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (3-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Matrix const &rhs) const { -+ return divide(rhs); -+ } -+ -+ /// Elementwise divide operator (3-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Matrix const &rhs) { -+ -+ data[0] /= rhs.data[0]; -+ data[1] /= rhs.data[1]; -+ -+ data[2] /= rhs.data[2]; -+ data[3] /= rhs.data[3]; -+ -+ data[4] /= rhs.data[4]; -+ data[5] /= rhs.data[5]; -+ -+ return *this; -+ } -+ -+ /// Negates each element of the matrix -+ CUTLASS_HOST_DEVICE -+ Matrix operator-() const { -+ Matrix m; -+ -+ m.data[0] = -m.data[0]; -+ m.data[1] = -m.data[1]; -+ m.data[2] = -m.data[2]; -+ m.data[3] = -m.data[3]; -+ m.data[4] = -m.data[4]; -+ m.data[5] = -m.data[5]; -+ -+ return m; -+ } -+ -+ /// Matrix product of size 3-by-1-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[2] * rhs.data[0]; -+ accum.data[2] += data[4] * rhs.data[0]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[1]; -+ accum.data[1] += data[3] * rhs.data[1]; -+ accum.data[2] += data[5] * rhs.data[1]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 3-by-1-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 3-by-2-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[2] * rhs.data[0]; -+ accum.data[3] += data[2] * rhs.data[1]; -+ accum.data[4] += data[4] * rhs.data[0]; -+ accum.data[5] += data[4] * rhs.data[1]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[2]; -+ accum.data[1] += data[1] * rhs.data[3]; -+ accum.data[2] += data[3] * rhs.data[2]; -+ accum.data[3] += data[3] * rhs.data[3]; -+ accum.data[4] += data[5] * rhs.data[2]; -+ accum.data[5] += data[5] * rhs.data[3]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 3-by-2-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 3-by-2-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix & operator*=(Matrix const &rhs) { -+ *this = product(rhs); -+ return *this; -+ } -+ -+ /// Matrix product of size 3-by-3-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[2] * rhs.data[0]; -+ accum.data[4] += data[2] * rhs.data[1]; -+ accum.data[5] += data[2] * rhs.data[2]; -+ accum.data[6] += data[4] * rhs.data[0]; -+ accum.data[7] += data[4] * rhs.data[1]; -+ accum.data[8] += data[4] * rhs.data[2]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[3]; -+ accum.data[1] += data[1] * rhs.data[4]; -+ accum.data[2] += data[1] * rhs.data[5]; -+ accum.data[3] += data[3] * rhs.data[3]; -+ accum.data[4] += data[3] * rhs.data[4]; -+ accum.data[5] += data[3] * rhs.data[5]; -+ accum.data[6] += data[5] * rhs.data[3]; -+ accum.data[7] += data[5] * rhs.data[4]; -+ accum.data[8] += data[5] * rhs.data[5]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 3-by-3-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 3-by-4-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[0] * rhs.data[3]; -+ accum.data[4] += data[2] * rhs.data[0]; -+ accum.data[5] += data[2] * rhs.data[1]; -+ accum.data[6] += data[2] * rhs.data[2]; -+ accum.data[7] += data[2] * rhs.data[3]; -+ accum.data[8] += data[4] * rhs.data[0]; -+ accum.data[9] += data[4] * rhs.data[1]; -+ accum.data[10] += data[4] * rhs.data[2]; -+ accum.data[11] += data[4] * rhs.data[3]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[4]; -+ accum.data[1] += data[1] * rhs.data[5]; -+ accum.data[2] += data[1] * rhs.data[6]; -+ accum.data[3] += data[1] * rhs.data[7]; -+ accum.data[4] += data[3] * rhs.data[4]; -+ accum.data[5] += data[3] * rhs.data[5]; -+ accum.data[6] += data[3] * rhs.data[6]; -+ accum.data[7] += data[3] * rhs.data[7]; -+ accum.data[8] += data[5] * rhs.data[4]; -+ accum.data[9] += data[5] * rhs.data[5]; -+ accum.data[10] += data[5] * rhs.data[6]; -+ accum.data[11] += data[5] * rhs.data[7]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 3-by-4-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Returns the sum of elements -+ CUTLASS_HOST_DEVICE -+ Element sum(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[1]; -+ accum += data[2]; -+ accum += data[3]; -+ accum += data[4]; -+ accum += data[5]; -+ -+ return accum; -+ } -+ -+ /// Returns the sum of squared elements -+ CUTLASS_HOST_DEVICE -+ Element norm(Element accum = Element()) const { -+ -+ accum += data[0] * data[0]; -+ accum += data[1] * data[1]; -+ accum += data[2] * data[2]; -+ accum += data[3] * data[3]; -+ accum += data[4] * data[4]; -+ accum += data[5] * data[5]; -+ -+ return accum; -+ } -+ -+ /// Returns square root of the norm -+ CUTLASS_HOST_DEVICE -+ Element magnitude() const { -+ return fast_sqrt(norm()); -+ } -+ -+ /// Returns the sum of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Element trace(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[3]; -+ -+ return accum; -+ } -+ -+}; -+ -+/// Template alias for 3-by-2 matrix -+template -+using Matrix3x2 = Matrix; -+ -+ -+/// Free funciton to infer element type from template arguments -+template -+CUTLASS_HOST_DEVICE Matrix3x2 make_Matrix3x2( -+ Element _0_0, Element _0_1, -+ Element _1_0, Element _1_1, -+ Element _2_0, Element _2_1 -+) { -+ return Matrix3x2( -+ _0_0, _0_1, -+ _1_0, _1_1, -+ _2_0, _2_1 -+ ); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// 3-by-3 matrix template class definition -+template -+struct Matrix { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Element data type -+ using Element = Element_; -+ -+ /// Number of rows in matrix -+ static int const kRows = 3; -+ -+ /// Number of columns in matrix -+ static int const kColumns = 3; -+ -+ /// Layout of matrix in underlying array -+ using Layout = layout::RowMajor; -+ -+ /// Number of elements in matrix -+ static int const kCount = 9; -+ -+ // -+ // Data members -+ // -+ -+ /// Elements of the matrix in row-major layout -+ Array data; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a zero matrix -+ CUTLASS_HOST_DEVICE -+ Matrix() { -+ data.clear(); -+ } -+ -+ /// Copy constructor for a 3-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix(Matrix const &rhs) { -+ data = rhs.data; -+ } -+ -+ /// Constucts a 3-by-3 matrix from scalar elements -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Element _0_0, Element _0_1, Element _0_2, -+ Element _1_0, Element _1_1, Element _1_2, -+ Element _2_0, Element _2_1, Element _2_2 -+ ) { -+ -+ data[0] = _0_0; data[1] = _0_1; data[2] = _0_2; -+ data[3] = _1_0; data[4] = _1_1; data[5] = _1_2; -+ data[6] = _2_0; data[7] = _2_1; data[8] = _2_2; -+ } -+ -+ /// Constucts a 3-by-3 matrix from row vectors -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Matrix const &row_0, -+ Matrix const &row_1, -+ Matrix const &row_2 -+ ) { -+ data[0] = row_0.data[0]; -+ data[1] = row_0.data[1]; -+ data[2] = row_0.data[2]; -+ data[3] = row_1.data[0]; -+ data[4] = row_1.data[1]; -+ data[5] = row_1.data[2]; -+ data[6] = row_2.data[0]; -+ data[7] = row_2.data[1]; -+ data[8] = row_2.data[2]; -+ } -+ -+ /// Static method to construct a 3-by-3 matrix from column vectors -+ CUTLASS_HOST_DEVICE -+ static Matrix from_columns( -+ Matrix const &column_0, -+ Matrix const &column_1, -+ Matrix const &column_2 -+ ) { -+ Matrix result; -+ -+ result.data[0] = column_0.data[0]; -+ result.data[1] = column_1.data[0]; -+ result.data[2] = column_2.data[0]; -+ result.data[3] = column_0.data[1]; -+ result.data[4] = column_1.data[1]; -+ result.data[5] = column_2.data[1]; -+ result.data[6] = column_0.data[2]; -+ result.data[7] = column_1.data[2]; -+ result.data[8] = column_2.data[2]; -+ return result; -+ } -+ -+ /// Constructs an identity matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix identity() { -+ Matrix m; -+ -+ m.data[0] = Element(1); -+ m.data[4] = Element(1); -+ m.data[8] = Element(1); -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from a uniform element -+ CUTLASS_HOST_DEVICE -+ static Matrix uniform(Element s) { -+ Matrix m; -+ -+ m.data[0] = s; -+ m.data[1] = s; -+ m.data[2] = s; -+ m.data[3] = s; -+ m.data[4] = s; -+ m.data[5] = s; -+ m.data[6] = s; -+ m.data[7] = s; -+ m.data[8] = s; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from a uniform element 1 -+ CUTLASS_HOST_DEVICE -+ static Matrix ones() { -+ return uniform(Element(1)); -+ } -+ -+ /// Constructs a matrix from a uniform element 0 -+ CUTLASS_HOST_DEVICE -+ static Matrix zero() { -+ return Matrix(); -+ } -+ -+ /// Constructs a matrix from elements along its diagonal -+ CUTLASS_HOST_DEVICE -+ static Matrix from_diagonal(Matrix const &diag) { -+ Matrix m; -+ -+ m.data[0] = diag.data[0]; -+ m.data[4] = diag.data[1]; -+ m.data[8] = diag.data[2]; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from elements along its diagonal -+ CUTLASS_HOST_DEVICE -+ static Matrix from_diagonal(Matrix const &diag) { -+ Matrix m; -+ -+ m.data[0] = diag.data[0]; -+ m.data[4] = diag.data[1]; -+ m.data[8] = diag.data[2]; -+ -+ return m; -+ } -+ -+ /// Gets an array of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Matrix diagonal() const { -+ Matrix diag; -+ -+ diag.data[0] = data[0]; -+ diag.data[1] = data[4]; -+ diag.data[2] = data[8]; -+ -+ return diag; -+ } -+ -+ /// Returns a transposed matrix -+ CUTLASS_HOST_DEVICE -+ Matrix transpose() const { -+ Matrix mt; -+ -+ mt.data[0] = data[0]; -+ mt.data[3] = data[1]; -+ mt.data[6] = data[2]; -+ mt.data[1] = data[3]; -+ mt.data[4] = data[4]; -+ mt.data[7] = data[5]; -+ mt.data[2] = data[6]; -+ mt.data[5] = data[7]; -+ mt.data[8] = data[8]; -+ -+ return mt; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(int i, int j) const { -+ return data[i * 3 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(int i, int j) { -+ return data[i * 3 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element &at(int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element at(int offset) const { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element operator[](Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & operator[](Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element & operator[](int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element operator[](int offset) const { -+ return data[offset]; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ m.data[2] = data[i * 3 + j + 2]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ data[i * 3 + j + 2] = m.data[2]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix row(int i) const { -+ return slice_1x3(i, 0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_row(Matrix const &v, int i = 0) { -+ return set_slice_1x3(v, i, 0); -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 3]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 3] = m.data[1]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ m.data[2] = data[i * 3 + j + 3]; -+ m.data[3] = data[i * 3 + j + 4]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ data[i * 3 + j + 3] = m.data[2]; -+ data[i * 3 + j + 4] = m.data[3]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ m.data[2] = data[i * 3 + j + 2]; -+ m.data[3] = data[i * 3 + j + 3]; -+ m.data[4] = data[i * 3 + j + 4]; -+ m.data[5] = data[i * 3 + j + 5]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ data[i * 3 + j + 2] = m.data[2]; -+ data[i * 3 + j + 3] = m.data[3]; -+ data[i * 3 + j + 4] = m.data[4]; -+ data[i * 3 + j + 5] = m.data[5]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 3]; -+ m.data[2] = data[i * 3 + j + 6]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 3] = m.data[1]; -+ data[i * 3 + j + 6] = m.data[2]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix column(int j) const { -+ return slice_3x1(0, j); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_column(Matrix const &v, int j =0) { -+ return set_slice_3x1(v, 0, j); -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ m.data[2] = data[i * 3 + j + 3]; -+ m.data[3] = data[i * 3 + j + 4]; -+ m.data[4] = data[i * 3 + j + 6]; -+ m.data[5] = data[i * 3 + j + 7]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ data[i * 3 + j + 3] = m.data[2]; -+ data[i * 3 + j + 4] = m.data[3]; -+ data[i * 3 + j + 6] = m.data[4]; -+ data[i * 3 + j + 7] = m.data[5]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ m.data[2] = data[i * 3 + j + 2]; -+ m.data[3] = data[i * 3 + j + 3]; -+ m.data[4] = data[i * 3 + j + 4]; -+ m.data[5] = data[i * 3 + j + 5]; -+ m.data[6] = data[i * 3 + j + 6]; -+ m.data[7] = data[i * 3 + j + 7]; -+ m.data[8] = data[i * 3 + j + 8]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ data[i * 3 + j + 2] = m.data[2]; -+ data[i * 3 + j + 3] = m.data[3]; -+ data[i * 3 + j + 4] = m.data[4]; -+ data[i * 3 + j + 5] = m.data[5]; -+ data[i * 3 + j + 6] = m.data[6]; -+ data[i * 3 + j + 7] = m.data[7]; -+ data[i * 3 + j + 8] = m.data[8]; -+ -+ return *this; -+ } -+ -+ /// Forms a 3-by-3 matrix by horizontally concatenating a 3-by-1 matrix with a 3-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs.at(0, 0), rhs.at(0, 0), rhs.at(0, 1) -+ , lhs.at(1, 0), rhs.at(1, 0), rhs.at(1, 1) -+ , lhs.at(2, 0), rhs.at(2, 0), rhs.at(2, 1)); -+ } -+ -+ /// Forms a 3-by-3 matrix by horizontally concatenating a 3-by-2 matrix with a 3-by-1 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs.at(0, 0), lhs.at(0, 1), rhs.at(0, 0) -+ , lhs.at(1, 0), lhs.at(1, 1), rhs.at(1, 0) -+ , lhs.at(2, 0), lhs.at(2, 1), rhs.at(2, 0)); -+ } -+ -+ /// Concatenates this matrix with a a 3-by-1 matrix to form a 3-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Matrix const & rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Forms a 3-by-3 matrix by vertically concatenating a 1-by-3 matrix with a 2-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Matrix const & lower) { -+ return Matrix( -+ upper.at(0, 0), upper.at(0, 1), upper.at(0, 2) -+ , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2) -+ , lower.at(1, 0), lower.at(1, 1), lower.at(1, 2)); -+ } -+ -+ /// Forms a 3-by-3 matrix by vertically concatenating a 2-by-3 matrix with a 1-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Matrix const & lower) { -+ return Matrix( -+ upper.at(0, 0), upper.at(0, 1), upper.at(0, 2) -+ , upper.at(1, 0), upper.at(1, 1), upper.at(1, 2) -+ , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2)); -+ } -+ -+ /// Concatenates this matrix with a a 1-by-3 matrix to form a 4-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Matrix const & rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Forms a 3-by-3 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Element A, Matrix const & B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A, B.at(0, 0), B.at(0, 1) -+ , C.at(0, 0), D.at(0, 0), D.at(0, 1) -+ , C.at(1, 0), D.at(1, 0), D.at(1, 1) -+ ); -+ } -+ -+ /// Forms a 3-by-3 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Element B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A.at(0, 0), A.at(0, 1), B -+ , C.at(0, 0), C.at(0, 1), D.at(0, 0) -+ , C.at(1, 0), C.at(1, 1), D.at(1, 0) -+ ); -+ } -+ -+ /// Forms a 3-by-3 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Element C, Matrix const & D) { -+ return Matrix( -+ A.at(0, 0), B.at(0, 0), B.at(0, 1) -+ , A.at(1, 0), B.at(1, 0), B.at(1, 1) -+ , C, D.at(0, 0), D.at(0, 1) -+ ); -+ } -+ -+ /// Forms a 3-by-3 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Matrix const & C, Element D) { -+ return Matrix( -+ A.at(0, 0), A.at(0, 1), B.at(0, 0) -+ , A.at(1, 0), A.at(1, 1), B.at(1, 0) -+ , C.at(0, 0), C.at(0, 1), D -+ ); -+ } -+ -+ /// Elementwise add operator (3-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix add(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] + rhs.data[0]; -+ result.data[1] = data[1] + rhs.data[1]; -+ result.data[2] = data[2] + rhs.data[2]; -+ -+ result.data[3] = data[3] + rhs.data[3]; -+ result.data[4] = data[4] + rhs.data[4]; -+ result.data[5] = data[5] + rhs.data[5]; -+ -+ result.data[6] = data[6] + rhs.data[6]; -+ result.data[7] = data[7] + rhs.data[7]; -+ result.data[8] = data[8] + rhs.data[8]; -+ -+ return result; -+ } -+ -+ /// Elementwise add operator (3-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator +(Matrix const &rhs) const { -+ return add(rhs); -+ } -+ -+ /// Elementwise add operator (3-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator +=(Matrix const &rhs) { -+ -+ data[0] += rhs.data[0]; -+ data[1] += rhs.data[1]; -+ data[2] += rhs.data[2]; -+ -+ data[3] += rhs.data[3]; -+ data[4] += rhs.data[4]; -+ data[5] += rhs.data[5]; -+ -+ data[6] += rhs.data[6]; -+ data[7] += rhs.data[7]; -+ data[8] += rhs.data[8]; -+ -+ return *this; -+ } -+ -+ /// Elementwise subtract operator (3-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix subtract(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] - rhs.data[0]; -+ result.data[1] = data[1] - rhs.data[1]; -+ result.data[2] = data[2] - rhs.data[2]; -+ -+ result.data[3] = data[3] - rhs.data[3]; -+ result.data[4] = data[4] - rhs.data[4]; -+ result.data[5] = data[5] - rhs.data[5]; -+ -+ result.data[6] = data[6] - rhs.data[6]; -+ result.data[7] = data[7] - rhs.data[7]; -+ result.data[8] = data[8] - rhs.data[8]; -+ -+ return result; -+ } -+ -+ /// Elementwise subtract operator (3-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator -(Matrix const &rhs) const { -+ return subtract(rhs); -+ } -+ -+ /// Elementwise subtract operator (3-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator -=(Matrix const &rhs) { -+ -+ data[0] -= rhs.data[0]; -+ data[1] -= rhs.data[1]; -+ data[2] -= rhs.data[2]; -+ -+ data[3] -= rhs.data[3]; -+ data[4] -= rhs.data[4]; -+ data[5] -= rhs.data[5]; -+ -+ data[6] -= rhs.data[6]; -+ data[7] -= rhs.data[7]; -+ data[8] -= rhs.data[8]; -+ -+ return *this; -+ } -+ -+ /// Elementwise multiply operator (3-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * rhs.data[0]; -+ result.data[1] = data[1] * rhs.data[1]; -+ result.data[2] = data[2] * rhs.data[2]; -+ -+ result.data[3] = data[3] * rhs.data[3]; -+ result.data[4] = data[4] * rhs.data[4]; -+ result.data[5] = data[5] * rhs.data[5]; -+ -+ result.data[6] = data[6] * rhs.data[6]; -+ result.data[7] = data[7] * rhs.data[7]; -+ result.data[8] = data[8] * rhs.data[8]; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (3-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * s; -+ result.data[1] = data[1] * s; -+ result.data[2] = data[2] * s; -+ -+ result.data[3] = data[3] * s; -+ result.data[4] = data[4] * s; -+ result.data[5] = data[5] * s; -+ -+ result.data[6] = data[6] * s; -+ result.data[7] = data[7] * s; -+ result.data[8] = data[8] * s; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (3-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator *(Element const &s) const { -+ return multiply(s); -+ } -+ -+ /// Scalar multiply operator (3-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator *=(Element const &s) { -+ -+ data[0] *= s; -+ data[1] *= s; -+ data[2] *= s; -+ -+ data[3] *= s; -+ data[4] *= s; -+ data[5] *= s; -+ -+ data[6] *= s; -+ data[7] *= s; -+ data[8] *= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (3-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / rhs.data[0]; -+ result.data[1] = data[1] / rhs.data[1]; -+ result.data[2] = data[2] / rhs.data[2]; -+ -+ result.data[3] = data[3] / rhs.data[3]; -+ result.data[4] = data[4] / rhs.data[4]; -+ result.data[5] = data[5] / rhs.data[5]; -+ -+ result.data[6] = data[6] / rhs.data[6]; -+ result.data[7] = data[7] / rhs.data[7]; -+ result.data[8] = data[8] / rhs.data[8]; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (3-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / s; -+ result.data[1] = data[1] / s; -+ result.data[2] = data[2] / s; -+ -+ result.data[3] = data[3] / s; -+ result.data[4] = data[4] / s; -+ result.data[5] = data[5] / s; -+ -+ result.data[6] = data[6] / s; -+ result.data[7] = data[7] / s; -+ result.data[8] = data[8] / s; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (3-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Element const &s) const { -+ return divide(s); -+ } -+ -+ /// Scalar divide operator (3-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Element const &s) { -+ -+ data[0] /= s; -+ data[1] /= s; -+ data[2] /= s; -+ -+ data[3] /= s; -+ data[4] /= s; -+ data[5] /= s; -+ -+ data[6] /= s; -+ data[7] /= s; -+ data[8] /= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (3-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Matrix const &rhs) const { -+ return divide(rhs); -+ } -+ -+ /// Elementwise divide operator (3-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Matrix const &rhs) { -+ -+ data[0] /= rhs.data[0]; -+ data[1] /= rhs.data[1]; -+ data[2] /= rhs.data[2]; -+ -+ data[3] /= rhs.data[3]; -+ data[4] /= rhs.data[4]; -+ data[5] /= rhs.data[5]; -+ -+ data[6] /= rhs.data[6]; -+ data[7] /= rhs.data[7]; -+ data[8] /= rhs.data[8]; -+ -+ return *this; -+ } -+ -+ /// Negates each element of the matrix -+ CUTLASS_HOST_DEVICE -+ Matrix operator-() const { -+ Matrix m; -+ -+ m.data[0] = -m.data[0]; -+ m.data[1] = -m.data[1]; -+ m.data[2] = -m.data[2]; -+ m.data[3] = -m.data[3]; -+ m.data[4] = -m.data[4]; -+ m.data[5] = -m.data[5]; -+ m.data[6] = -m.data[6]; -+ m.data[7] = -m.data[7]; -+ m.data[8] = -m.data[8]; -+ -+ return m; -+ } -+ -+ /// Matrix product of size 3-by-1-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[3] * rhs.data[0]; -+ accum.data[2] += data[6] * rhs.data[0]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[1]; -+ accum.data[1] += data[4] * rhs.data[1]; -+ accum.data[2] += data[7] * rhs.data[1]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[2]; -+ accum.data[1] += data[5] * rhs.data[2]; -+ accum.data[2] += data[8] * rhs.data[2]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 3-by-1-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 3-by-2-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[3] * rhs.data[0]; -+ accum.data[3] += data[3] * rhs.data[1]; -+ accum.data[4] += data[6] * rhs.data[0]; -+ accum.data[5] += data[6] * rhs.data[1]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[2]; -+ accum.data[1] += data[1] * rhs.data[3]; -+ accum.data[2] += data[4] * rhs.data[2]; -+ accum.data[3] += data[4] * rhs.data[3]; -+ accum.data[4] += data[7] * rhs.data[2]; -+ accum.data[5] += data[7] * rhs.data[3]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[4]; -+ accum.data[1] += data[2] * rhs.data[5]; -+ accum.data[2] += data[5] * rhs.data[4]; -+ accum.data[3] += data[5] * rhs.data[5]; -+ accum.data[4] += data[8] * rhs.data[4]; -+ accum.data[5] += data[8] * rhs.data[5]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 3-by-2-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 3-by-3-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[3] * rhs.data[0]; -+ accum.data[4] += data[3] * rhs.data[1]; -+ accum.data[5] += data[3] * rhs.data[2]; -+ accum.data[6] += data[6] * rhs.data[0]; -+ accum.data[7] += data[6] * rhs.data[1]; -+ accum.data[8] += data[6] * rhs.data[2]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[3]; -+ accum.data[1] += data[1] * rhs.data[4]; -+ accum.data[2] += data[1] * rhs.data[5]; -+ accum.data[3] += data[4] * rhs.data[3]; -+ accum.data[4] += data[4] * rhs.data[4]; -+ accum.data[5] += data[4] * rhs.data[5]; -+ accum.data[6] += data[7] * rhs.data[3]; -+ accum.data[7] += data[7] * rhs.data[4]; -+ accum.data[8] += data[7] * rhs.data[5]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[6]; -+ accum.data[1] += data[2] * rhs.data[7]; -+ accum.data[2] += data[2] * rhs.data[8]; -+ accum.data[3] += data[5] * rhs.data[6]; -+ accum.data[4] += data[5] * rhs.data[7]; -+ accum.data[5] += data[5] * rhs.data[8]; -+ accum.data[6] += data[8] * rhs.data[6]; -+ accum.data[7] += data[8] * rhs.data[7]; -+ accum.data[8] += data[8] * rhs.data[8]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 3-by-3-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 3-by-3-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix & operator*=(Matrix const &rhs) { -+ *this = product(rhs); -+ return *this; -+ } -+ -+ /// Matrix product of size 3-by-4-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[0] * rhs.data[3]; -+ accum.data[4] += data[3] * rhs.data[0]; -+ accum.data[5] += data[3] * rhs.data[1]; -+ accum.data[6] += data[3] * rhs.data[2]; -+ accum.data[7] += data[3] * rhs.data[3]; -+ accum.data[8] += data[6] * rhs.data[0]; -+ accum.data[9] += data[6] * rhs.data[1]; -+ accum.data[10] += data[6] * rhs.data[2]; -+ accum.data[11] += data[6] * rhs.data[3]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[4]; -+ accum.data[1] += data[1] * rhs.data[5]; -+ accum.data[2] += data[1] * rhs.data[6]; -+ accum.data[3] += data[1] * rhs.data[7]; -+ accum.data[4] += data[4] * rhs.data[4]; -+ accum.data[5] += data[4] * rhs.data[5]; -+ accum.data[6] += data[4] * rhs.data[6]; -+ accum.data[7] += data[4] * rhs.data[7]; -+ accum.data[8] += data[7] * rhs.data[4]; -+ accum.data[9] += data[7] * rhs.data[5]; -+ accum.data[10] += data[7] * rhs.data[6]; -+ accum.data[11] += data[7] * rhs.data[7]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[8]; -+ accum.data[1] += data[2] * rhs.data[9]; -+ accum.data[2] += data[2] * rhs.data[10]; -+ accum.data[3] += data[2] * rhs.data[11]; -+ accum.data[4] += data[5] * rhs.data[8]; -+ accum.data[5] += data[5] * rhs.data[9]; -+ accum.data[6] += data[5] * rhs.data[10]; -+ accum.data[7] += data[5] * rhs.data[11]; -+ accum.data[8] += data[8] * rhs.data[8]; -+ accum.data[9] += data[8] * rhs.data[9]; -+ accum.data[10] += data[8] * rhs.data[10]; -+ accum.data[11] += data[8] * rhs.data[11]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 3-by-4-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Returns the sum of elements -+ CUTLASS_HOST_DEVICE -+ Element sum(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[1]; -+ accum += data[2]; -+ accum += data[3]; -+ accum += data[4]; -+ accum += data[5]; -+ accum += data[6]; -+ accum += data[7]; -+ accum += data[8]; -+ -+ return accum; -+ } -+ -+ /// Returns the sum of squared elements -+ CUTLASS_HOST_DEVICE -+ Element norm(Element accum = Element()) const { -+ -+ accum += data[0] * data[0]; -+ accum += data[1] * data[1]; -+ accum += data[2] * data[2]; -+ accum += data[3] * data[3]; -+ accum += data[4] * data[4]; -+ accum += data[5] * data[5]; -+ accum += data[6] * data[6]; -+ accum += data[7] * data[7]; -+ accum += data[8] * data[8]; -+ -+ return accum; -+ } -+ -+ /// Returns square root of the norm -+ CUTLASS_HOST_DEVICE -+ Element magnitude() const { -+ return fast_sqrt(norm()); -+ } -+ -+ /// Returns the sum of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Element trace(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[4]; -+ accum += data[8]; -+ -+ return accum; -+ } -+ -+ /// Returns 3-by-3 rotation matrix around the X axis -+ CUTLASS_HOST_DEVICE -+ static Matrix rotation_X(Element theta) { -+ Matrix m = identity(); -+ -+ Element c = fast_cos(theta); -+ Element s = fast_sin(theta); -+ -+ m.at(1, 1) = c; -+ m.at(1, 2) = -s; -+ m.at(2, 1) = s; -+ m.at(2, 2) = c; -+ -+ return m; -+ } -+ -+ /// Returns 3-by-3 rotation matrix around the Y axis -+ CUTLASS_HOST_DEVICE -+ static Matrix rotation_Y(Element theta) { -+ Matrix m = identity(); -+ -+ Element c = fast_cos(theta); -+ Element s = fast_sin(theta); -+ -+ m.at(0, 0) = c; -+ m.at(2, 0) = -s; -+ m.at(0, 2) = s; -+ m.at(2, 2) = c; -+ -+ return m; -+ } -+ -+ /// Returns 3-by-3 rotation matrix around the Z axis -+ CUTLASS_HOST_DEVICE -+ static Matrix rotation_Z(Element theta) { -+ Matrix m = Matrix::identity(); -+ -+ Element c = fast_cos(theta); -+ Element s = fast_sin(theta); -+ -+ m.at(0, 0) = c; -+ m.at(0, 1) = -s; -+ m.at(1, 0) = s; -+ m.at(1, 1) = c; -+ -+ return m; -+ } -+ -+ /// Returns a 3-by-3 rotation matrix around a unit-length axis -+ CUTLASS_HOST_DEVICE -+ static Matrix rotation(Element theta, Matrix const &u) { -+ Element x = u.data[0]; -+ Element y = u.data[1]; -+ Element z = u.data[2]; -+ -+ Element c = fast_cos(theta); -+ Element s = fast_sin(theta); -+ -+ Element one_minus_cos = Element(1) - fast_cos(theta); -+ -+ Matrix m; -+ -+ m.set_slice3x3({ -+ c + x * x * one_minus_cos, x * y * one_minus_cos - z * s, x * z * one_minus_cos + y * s, -+ y * x * one_minus_cos * z * s, c + y * y * one_minus_cos, y * z * one_minus_cos - x * s, -+ z * x * one_minus_cos - y * s, z * y * one_minus_cos + x * s, c + z * z * one_minus_cos -+ }); -+ -+ return m; -+ } -+ -+ /// Returns a 3-by-3 reflection about the plane specified by the -+ /// unit-length normal vector n_unit -+ CUTLASS_HOST_DEVICE -+ static Matrix reflection(Matrix const &n_unit) { -+ -+ Element a = n_unit.data[0]; -+ Element b = n_unit.data[1]; -+ Element c = n_unit.data[2]; -+ -+ Matrix m = Matrix::identity(); -+ -+ m.set_slice3x3({ -+ Element(1) - Element(2) * a * a, Element(-2) * a * b, Element(-2) * a * c, -+ Element(-2) * a * b, Element(1) - Element(2) * b * b, Element(-2) * b * c, -+ Element(-2) * a * c, Element(-2) * b * c, Element(1) - Element(2) * c * c -+ }); -+ -+ return m; -+ } -+ -+ /// Computes the determinant of a 3-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ Element determinant(Element accum = Element()) const { -+ -+ accum += at(0, 0) * Matrix({ at(1, 1), at(1, 2), at(2, 1), at(2, 2) }).determinant(); -+ accum -= at(0, 1) * Matrix({ at(1, 0), at(1, 2), at(2, 0), at(2, 2) }).determinant(); -+ accum += at(0, 2) * Matrix({ at(1, 0), at(1, 1), at(2, 0), at(2, 1) }).determinant(); -+ -+ return accum; -+ } -+ -+ /// Computes the inverse of a 3-by-3 matrix given -+ /// the matrix's determinant -+ CUTLASS_HOST_DEVICE -+ Matrix inverse(Element det) const { -+ return Matrix( -+ at(1, 1) * at(2, 2) - at(1, 2) * at(2, 1), -+ at(0, 2) * at(2, 1) - at(0, 1) * at(2, 2), -+ at(0, 1) * at(1, 2) - at(0, 2) * at(1, 1), -+ -+ at(1, 2) * at(2, 0) - at(1, 0) * at(2, 2), -+ at(0, 0) * at(2, 2) - at(0, 2) * at(2, 0), -+ at(0, 2) * at(1, 0) - at(0, 0) * at(1, 2), -+ -+ at(1, 0) * at(2, 1) - at(1, 1) * at(2, 0), -+ at(0, 1) * at(2, 0) - at(0, 0) * at(2, 1), -+ at(0, 0) * at(1, 1) - at(0, 1) * at(1, 0) -+ ) * (Element(1) / det); -+ } -+ /// Computes the inverse of a 3-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix inverse() const { -+ return inverse(determinant()); -+ } -+ -+}; -+ -+/// Template alias for 3-by-3 matrix -+template -+using Matrix3x3 = Matrix; -+ -+ -+/// Free funciton to infer element type from template arguments -+template -+CUTLASS_HOST_DEVICE Matrix3x3 make_Matrix3x3( -+ Element _0_0, Element _0_1, Element _0_2, -+ Element _1_0, Element _1_1, Element _1_2, -+ Element _2_0, Element _2_1, Element _2_2 -+) { -+ return Matrix3x3( -+ _0_0, _0_1, _0_2, -+ _1_0, _1_1, _1_2, -+ _2_0, _2_1, _2_2 -+ ); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// 3-by-4 matrix template class definition -+template -+struct Matrix { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Element data type -+ using Element = Element_; -+ -+ /// Number of rows in matrix -+ static int const kRows = 3; -+ -+ /// Number of columns in matrix -+ static int const kColumns = 4; -+ -+ /// Layout of matrix in underlying array -+ using Layout = layout::RowMajor; -+ -+ /// Number of elements in matrix -+ static int const kCount = 12; -+ -+ // -+ // Data members -+ // -+ -+ /// Elements of the matrix in row-major layout -+ Array data; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a zero matrix -+ CUTLASS_HOST_DEVICE -+ Matrix() { -+ data.clear(); -+ } -+ -+ /// Copy constructor for a 3-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix(Matrix const &rhs) { -+ data = rhs.data; -+ } -+ -+ /// Constucts a 3-by-4 matrix from scalar elements -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Element _0_0, Element _0_1, Element _0_2, Element _0_3, -+ Element _1_0, Element _1_1, Element _1_2, Element _1_3, -+ Element _2_0, Element _2_1, Element _2_2, Element _2_3 -+ ) { -+ -+ data[0] = _0_0; data[1] = _0_1; data[2] = _0_2; data[3] = _0_3; -+ data[4] = _1_0; data[5] = _1_1; data[6] = _1_2; data[7] = _1_3; -+ data[8] = _2_0; data[9] = _2_1; data[10] = _2_2; data[11] = _2_3; -+ } -+ -+ /// Constucts a 3-by-4 matrix from row vectors -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Matrix const &row_0, -+ Matrix const &row_1, -+ Matrix const &row_2 -+ ) { -+ data[0] = row_0.data[0]; -+ data[1] = row_0.data[1]; -+ data[2] = row_0.data[2]; -+ data[3] = row_0.data[3]; -+ data[4] = row_1.data[0]; -+ data[5] = row_1.data[1]; -+ data[6] = row_1.data[2]; -+ data[7] = row_1.data[3]; -+ data[8] = row_2.data[0]; -+ data[9] = row_2.data[1]; -+ data[10] = row_2.data[2]; -+ data[11] = row_2.data[3]; -+ } -+ -+ /// Static method to construct a 3-by-4 matrix from column vectors -+ CUTLASS_HOST_DEVICE -+ static Matrix from_columns( -+ Matrix const &column_0, -+ Matrix const &column_1, -+ Matrix const &column_2, -+ Matrix const &column_3 -+ ) { -+ Matrix result; -+ -+ result.data[0] = column_0.data[0]; -+ result.data[1] = column_1.data[0]; -+ result.data[2] = column_2.data[0]; -+ result.data[3] = column_3.data[0]; -+ result.data[4] = column_0.data[1]; -+ result.data[5] = column_1.data[1]; -+ result.data[6] = column_2.data[1]; -+ result.data[7] = column_3.data[1]; -+ result.data[8] = column_0.data[2]; -+ result.data[9] = column_1.data[2]; -+ result.data[10] = column_2.data[2]; -+ result.data[11] = column_3.data[2]; -+ return result; -+ } -+ -+ /// Constructs a matrix from a uniform element -+ CUTLASS_HOST_DEVICE -+ static Matrix uniform(Element s) { -+ Matrix m; -+ -+ m.data[0] = s; -+ m.data[1] = s; -+ m.data[2] = s; -+ m.data[3] = s; -+ m.data[4] = s; -+ m.data[5] = s; -+ m.data[6] = s; -+ m.data[7] = s; -+ m.data[8] = s; -+ m.data[9] = s; -+ m.data[10] = s; -+ m.data[11] = s; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from a uniform element 1 -+ CUTLASS_HOST_DEVICE -+ static Matrix ones() { -+ return uniform(Element(1)); -+ } -+ -+ /// Constructs a matrix from a uniform element 0 -+ CUTLASS_HOST_DEVICE -+ static Matrix zero() { -+ return Matrix(); -+ } -+ -+ /// Constructs a matrix from elements along its diagonal -+ CUTLASS_HOST_DEVICE -+ static Matrix from_diagonal(Matrix const &diag) { -+ Matrix m; -+ -+ m.data[0] = diag.data[0]; -+ m.data[4] = diag.data[1]; -+ m.data[8] = diag.data[2]; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from elements along its diagonal -+ CUTLASS_HOST_DEVICE -+ static Matrix from_diagonal(Matrix const &diag) { -+ Matrix m; -+ -+ m.data[0] = diag.data[0]; -+ m.data[4] = diag.data[1]; -+ m.data[8] = diag.data[2]; -+ -+ return m; -+ } -+ -+ /// Gets an array of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Matrix diagonal() const { -+ Matrix diag; -+ -+ diag.data[0] = data[0]; -+ diag.data[1] = data[4]; -+ diag.data[2] = data[8]; -+ -+ return diag; -+ } -+ -+ /// Returns a transposed matrix -+ CUTLASS_HOST_DEVICE -+ Matrix transpose() const { -+ Matrix mt; -+ -+ mt.data[0] = data[0]; -+ mt.data[3] = data[1]; -+ mt.data[6] = data[2]; -+ mt.data[9] = data[3]; -+ mt.data[1] = data[4]; -+ mt.data[4] = data[5]; -+ mt.data[7] = data[6]; -+ mt.data[10] = data[7]; -+ mt.data[2] = data[8]; -+ mt.data[5] = data[9]; -+ mt.data[8] = data[10]; -+ mt.data[11] = data[11]; -+ -+ return mt; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(int i, int j) const { -+ return data[i * 3 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(int i, int j) { -+ return data[i * 3 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element &at(int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element at(int offset) const { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element operator[](Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & operator[](Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element & operator[](int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element operator[](int offset) const { -+ return data[offset]; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x4(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ m.data[3] = data[i * 4 + j + 3]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x4(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ data[i * 4 + j + 3] = m.data[3]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix row(int i) const { -+ return slice_1x4(i, 0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_row(Matrix const &v, int i = 0) { -+ return set_slice_1x4(v, i, 0); -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 4]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 4] = m.data[1]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 4]; -+ m.data[3] = data[i * 4 + j + 5]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 4] = m.data[2]; -+ data[i * 4 + j + 5] = m.data[3]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ m.data[3] = data[i * 4 + j + 4]; -+ m.data[4] = data[i * 4 + j + 5]; -+ m.data[5] = data[i * 4 + j + 6]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ data[i * 4 + j + 4] = m.data[3]; -+ data[i * 4 + j + 5] = m.data[4]; -+ data[i * 4 + j + 6] = m.data[5]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x4(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ m.data[3] = data[i * 4 + j + 3]; -+ m.data[4] = data[i * 4 + j + 4]; -+ m.data[5] = data[i * 4 + j + 5]; -+ m.data[6] = data[i * 4 + j + 6]; -+ m.data[7] = data[i * 4 + j + 7]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x4(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ data[i * 4 + j + 3] = m.data[3]; -+ data[i * 4 + j + 4] = m.data[4]; -+ data[i * 4 + j + 5] = m.data[5]; -+ data[i * 4 + j + 6] = m.data[6]; -+ data[i * 4 + j + 7] = m.data[7]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 4]; -+ m.data[2] = data[i * 4 + j + 8]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 4] = m.data[1]; -+ data[i * 4 + j + 8] = m.data[2]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix column(int j) const { -+ return slice_3x1(0, j); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_column(Matrix const &v, int j =0) { -+ return set_slice_3x1(v, 0, j); -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 4]; -+ m.data[3] = data[i * 4 + j + 5]; -+ m.data[4] = data[i * 4 + j + 8]; -+ m.data[5] = data[i * 4 + j + 9]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 4] = m.data[2]; -+ data[i * 4 + j + 5] = m.data[3]; -+ data[i * 4 + j + 8] = m.data[4]; -+ data[i * 4 + j + 9] = m.data[5]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ m.data[3] = data[i * 4 + j + 4]; -+ m.data[4] = data[i * 4 + j + 5]; -+ m.data[5] = data[i * 4 + j + 6]; -+ m.data[6] = data[i * 4 + j + 8]; -+ m.data[7] = data[i * 4 + j + 9]; -+ m.data[8] = data[i * 4 + j + 10]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ data[i * 4 + j + 4] = m.data[3]; -+ data[i * 4 + j + 5] = m.data[4]; -+ data[i * 4 + j + 6] = m.data[5]; -+ data[i * 4 + j + 8] = m.data[6]; -+ data[i * 4 + j + 9] = m.data[7]; -+ data[i * 4 + j + 10] = m.data[8]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x4(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ m.data[3] = data[i * 4 + j + 3]; -+ m.data[4] = data[i * 4 + j + 4]; -+ m.data[5] = data[i * 4 + j + 5]; -+ m.data[6] = data[i * 4 + j + 6]; -+ m.data[7] = data[i * 4 + j + 7]; -+ m.data[8] = data[i * 4 + j + 8]; -+ m.data[9] = data[i * 4 + j + 9]; -+ m.data[10] = data[i * 4 + j + 10]; -+ m.data[11] = data[i * 4 + j + 11]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x4(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ data[i * 4 + j + 3] = m.data[3]; -+ data[i * 4 + j + 4] = m.data[4]; -+ data[i * 4 + j + 5] = m.data[5]; -+ data[i * 4 + j + 6] = m.data[6]; -+ data[i * 4 + j + 7] = m.data[7]; -+ data[i * 4 + j + 8] = m.data[8]; -+ data[i * 4 + j + 9] = m.data[9]; -+ data[i * 4 + j + 10] = m.data[10]; -+ data[i * 4 + j + 11] = m.data[11]; -+ -+ return *this; -+ } -+ -+ /// Forms a 3-by-4 matrix by horizontally concatenating a 3-by-1 matrix with a 3-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs.at(0, 0), rhs.at(0, 0), rhs.at(0, 1), rhs.at(0, 2) -+ , lhs.at(1, 0), rhs.at(1, 0), rhs.at(1, 1), rhs.at(1, 2) -+ , lhs.at(2, 0), rhs.at(2, 0), rhs.at(2, 1), rhs.at(2, 2)); -+ } -+ -+ /// Forms a 3-by-4 matrix by horizontally concatenating a 3-by-2 matrix with a 3-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs.at(0, 0), lhs.at(0, 1), rhs.at(0, 0), rhs.at(0, 1) -+ , lhs.at(1, 0), lhs.at(1, 1), rhs.at(1, 0), rhs.at(1, 1) -+ , lhs.at(2, 0), lhs.at(2, 1), rhs.at(2, 0), rhs.at(2, 1)); -+ } -+ -+ /// Forms a 3-by-4 matrix by horizontally concatenating a 3-by-3 matrix with a 3-by-1 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs.at(0, 0), lhs.at(0, 1), lhs.at(0, 2), rhs.at(0, 0) -+ , lhs.at(1, 0), lhs.at(1, 1), lhs.at(1, 2), rhs.at(1, 0) -+ , lhs.at(2, 0), lhs.at(2, 1), lhs.at(2, 2), rhs.at(2, 0)); -+ } -+ -+ /// Forms a 3-by-4 matrix by vertically concatenating a 1-by-4 matrix with a 2-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Matrix const & lower) { -+ return Matrix( -+ upper.at(0, 0), upper.at(0, 1), upper.at(0, 2), upper.at(0, 3) -+ , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2), lower.at(0, 3) -+ , lower.at(1, 0), lower.at(1, 1), lower.at(1, 2), lower.at(1, 3)); -+ } -+ -+ /// Forms a 3-by-4 matrix by vertically concatenating a 2-by-4 matrix with a 1-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Matrix const & lower) { -+ return Matrix( -+ upper.at(0, 0), upper.at(0, 1), upper.at(0, 2), upper.at(0, 3) -+ , upper.at(1, 0), upper.at(1, 1), upper.at(1, 2), upper.at(1, 3) -+ , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2), lower.at(0, 3)); -+ } -+ -+ /// Concatenates this matrix with a a 1-by-4 matrix to form a 4-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix vcat(Matrix const & rhs) const { -+ return Matrix::vcat(*this, rhs); -+ } -+ -+ /// Forms a 3-by-4 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Element A, Matrix const & B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A, B.at(0, 0), B.at(0, 1), B.at(0, 2) -+ , C.at(0, 0), D.at(0, 0), D.at(0, 1), D.at(0, 2) -+ , C.at(1, 0), D.at(1, 0), D.at(1, 1), D.at(1, 2) -+ ); -+ } -+ -+ /// Forms a 3-by-4 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A.at(0, 0), A.at(0, 1), B.at(0, 0), B.at(0, 1) -+ , C.at(0, 0), C.at(0, 1), D.at(0, 0), D.at(0, 1) -+ , C.at(1, 0), C.at(1, 1), D.at(1, 0), D.at(1, 1) -+ ); -+ } -+ -+ /// Forms a 3-by-4 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Element B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A.at(0, 0), A.at(0, 1), A.at(0, 2), B -+ , C.at(0, 0), C.at(0, 1), C.at(0, 2), D.at(0, 0) -+ , C.at(1, 0), C.at(1, 1), C.at(1, 2), D.at(1, 0) -+ ); -+ } -+ -+ /// Forms a 3-by-4 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Element C, Matrix const & D) { -+ return Matrix( -+ A.at(0, 0), B.at(0, 0), B.at(0, 1), B.at(0, 2) -+ , A.at(1, 0), B.at(1, 0), B.at(1, 1), B.at(1, 2) -+ , C, D.at(0, 0), D.at(0, 1), D.at(0, 2) -+ ); -+ } -+ -+ /// Forms a 3-by-4 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A.at(0, 0), A.at(0, 1), B.at(0, 0), B.at(0, 1) -+ , A.at(1, 0), A.at(1, 1), B.at(1, 0), B.at(1, 1) -+ , C.at(0, 0), C.at(0, 1), D.at(0, 0), D.at(0, 1) -+ ); -+ } -+ -+ /// Forms a 3-by-4 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Matrix const & C, Element D) { -+ return Matrix( -+ A.at(0, 0), A.at(0, 1), A.at(0, 2), B.at(0, 0) -+ , A.at(1, 0), A.at(1, 1), A.at(1, 2), B.at(1, 0) -+ , C.at(0, 0), C.at(0, 1), C.at(0, 2), D -+ ); -+ } -+ -+ /// Elementwise add operator (3-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix add(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] + rhs.data[0]; -+ result.data[1] = data[1] + rhs.data[1]; -+ result.data[2] = data[2] + rhs.data[2]; -+ result.data[3] = data[3] + rhs.data[3]; -+ -+ result.data[4] = data[4] + rhs.data[4]; -+ result.data[5] = data[5] + rhs.data[5]; -+ result.data[6] = data[6] + rhs.data[6]; -+ result.data[7] = data[7] + rhs.data[7]; -+ -+ result.data[8] = data[8] + rhs.data[8]; -+ result.data[9] = data[9] + rhs.data[9]; -+ result.data[10] = data[10] + rhs.data[10]; -+ result.data[11] = data[11] + rhs.data[11]; -+ -+ return result; -+ } -+ -+ /// Elementwise add operator (3-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator +(Matrix const &rhs) const { -+ return add(rhs); -+ } -+ -+ /// Elementwise add operator (3-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator +=(Matrix const &rhs) { -+ -+ data[0] += rhs.data[0]; -+ data[1] += rhs.data[1]; -+ data[2] += rhs.data[2]; -+ data[3] += rhs.data[3]; -+ -+ data[4] += rhs.data[4]; -+ data[5] += rhs.data[5]; -+ data[6] += rhs.data[6]; -+ data[7] += rhs.data[7]; -+ -+ data[8] += rhs.data[8]; -+ data[9] += rhs.data[9]; -+ data[10] += rhs.data[10]; -+ data[11] += rhs.data[11]; -+ -+ return *this; -+ } -+ -+ /// Elementwise subtract operator (3-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix subtract(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] - rhs.data[0]; -+ result.data[1] = data[1] - rhs.data[1]; -+ result.data[2] = data[2] - rhs.data[2]; -+ result.data[3] = data[3] - rhs.data[3]; -+ -+ result.data[4] = data[4] - rhs.data[4]; -+ result.data[5] = data[5] - rhs.data[5]; -+ result.data[6] = data[6] - rhs.data[6]; -+ result.data[7] = data[7] - rhs.data[7]; -+ -+ result.data[8] = data[8] - rhs.data[8]; -+ result.data[9] = data[9] - rhs.data[9]; -+ result.data[10] = data[10] - rhs.data[10]; -+ result.data[11] = data[11] - rhs.data[11]; -+ -+ return result; -+ } -+ -+ /// Elementwise subtract operator (3-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator -(Matrix const &rhs) const { -+ return subtract(rhs); -+ } -+ -+ /// Elementwise subtract operator (3-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator -=(Matrix const &rhs) { -+ -+ data[0] -= rhs.data[0]; -+ data[1] -= rhs.data[1]; -+ data[2] -= rhs.data[2]; -+ data[3] -= rhs.data[3]; -+ -+ data[4] -= rhs.data[4]; -+ data[5] -= rhs.data[5]; -+ data[6] -= rhs.data[6]; -+ data[7] -= rhs.data[7]; -+ -+ data[8] -= rhs.data[8]; -+ data[9] -= rhs.data[9]; -+ data[10] -= rhs.data[10]; -+ data[11] -= rhs.data[11]; -+ -+ return *this; -+ } -+ -+ /// Elementwise multiply operator (3-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * rhs.data[0]; -+ result.data[1] = data[1] * rhs.data[1]; -+ result.data[2] = data[2] * rhs.data[2]; -+ result.data[3] = data[3] * rhs.data[3]; -+ -+ result.data[4] = data[4] * rhs.data[4]; -+ result.data[5] = data[5] * rhs.data[5]; -+ result.data[6] = data[6] * rhs.data[6]; -+ result.data[7] = data[7] * rhs.data[7]; -+ -+ result.data[8] = data[8] * rhs.data[8]; -+ result.data[9] = data[9] * rhs.data[9]; -+ result.data[10] = data[10] * rhs.data[10]; -+ result.data[11] = data[11] * rhs.data[11]; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (3-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * s; -+ result.data[1] = data[1] * s; -+ result.data[2] = data[2] * s; -+ result.data[3] = data[3] * s; -+ -+ result.data[4] = data[4] * s; -+ result.data[5] = data[5] * s; -+ result.data[6] = data[6] * s; -+ result.data[7] = data[7] * s; -+ -+ result.data[8] = data[8] * s; -+ result.data[9] = data[9] * s; -+ result.data[10] = data[10] * s; -+ result.data[11] = data[11] * s; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (3-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator *(Element const &s) const { -+ return multiply(s); -+ } -+ -+ /// Scalar multiply operator (3-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator *=(Element const &s) { -+ -+ data[0] *= s; -+ data[1] *= s; -+ data[2] *= s; -+ data[3] *= s; -+ -+ data[4] *= s; -+ data[5] *= s; -+ data[6] *= s; -+ data[7] *= s; -+ -+ data[8] *= s; -+ data[9] *= s; -+ data[10] *= s; -+ data[11] *= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (3-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / rhs.data[0]; -+ result.data[1] = data[1] / rhs.data[1]; -+ result.data[2] = data[2] / rhs.data[2]; -+ result.data[3] = data[3] / rhs.data[3]; -+ -+ result.data[4] = data[4] / rhs.data[4]; -+ result.data[5] = data[5] / rhs.data[5]; -+ result.data[6] = data[6] / rhs.data[6]; -+ result.data[7] = data[7] / rhs.data[7]; -+ -+ result.data[8] = data[8] / rhs.data[8]; -+ result.data[9] = data[9] / rhs.data[9]; -+ result.data[10] = data[10] / rhs.data[10]; -+ result.data[11] = data[11] / rhs.data[11]; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (3-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / s; -+ result.data[1] = data[1] / s; -+ result.data[2] = data[2] / s; -+ result.data[3] = data[3] / s; -+ -+ result.data[4] = data[4] / s; -+ result.data[5] = data[5] / s; -+ result.data[6] = data[6] / s; -+ result.data[7] = data[7] / s; -+ -+ result.data[8] = data[8] / s; -+ result.data[9] = data[9] / s; -+ result.data[10] = data[10] / s; -+ result.data[11] = data[11] / s; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (3-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Element const &s) const { -+ return divide(s); -+ } -+ -+ /// Scalar divide operator (3-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Element const &s) { -+ -+ data[0] /= s; -+ data[1] /= s; -+ data[2] /= s; -+ data[3] /= s; -+ -+ data[4] /= s; -+ data[5] /= s; -+ data[6] /= s; -+ data[7] /= s; -+ -+ data[8] /= s; -+ data[9] /= s; -+ data[10] /= s; -+ data[11] /= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (3-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Matrix const &rhs) const { -+ return divide(rhs); -+ } -+ -+ /// Elementwise divide operator (3-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Matrix const &rhs) { -+ -+ data[0] /= rhs.data[0]; -+ data[1] /= rhs.data[1]; -+ data[2] /= rhs.data[2]; -+ data[3] /= rhs.data[3]; -+ -+ data[4] /= rhs.data[4]; -+ data[5] /= rhs.data[5]; -+ data[6] /= rhs.data[6]; -+ data[7] /= rhs.data[7]; -+ -+ data[8] /= rhs.data[8]; -+ data[9] /= rhs.data[9]; -+ data[10] /= rhs.data[10]; -+ data[11] /= rhs.data[11]; -+ -+ return *this; -+ } -+ -+ /// Negates each element of the matrix -+ CUTLASS_HOST_DEVICE -+ Matrix operator-() const { -+ Matrix m; -+ -+ m.data[0] = -m.data[0]; -+ m.data[1] = -m.data[1]; -+ m.data[2] = -m.data[2]; -+ m.data[3] = -m.data[3]; -+ m.data[4] = -m.data[4]; -+ m.data[5] = -m.data[5]; -+ m.data[6] = -m.data[6]; -+ m.data[7] = -m.data[7]; -+ m.data[8] = -m.data[8]; -+ m.data[9] = -m.data[9]; -+ m.data[10] = -m.data[10]; -+ m.data[11] = -m.data[11]; -+ -+ return m; -+ } -+ -+ /// Matrix product of size 3-by-1-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[4] * rhs.data[0]; -+ accum.data[2] += data[8] * rhs.data[0]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[1]; -+ accum.data[1] += data[5] * rhs.data[1]; -+ accum.data[2] += data[9] * rhs.data[1]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[2]; -+ accum.data[1] += data[6] * rhs.data[2]; -+ accum.data[2] += data[10] * rhs.data[2]; -+ -+ // k=3 -+ accum.data[0] += data[3] * rhs.data[3]; -+ accum.data[1] += data[7] * rhs.data[3]; -+ accum.data[2] += data[11] * rhs.data[3]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 3-by-1-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 3-by-2-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[4] * rhs.data[0]; -+ accum.data[3] += data[4] * rhs.data[1]; -+ accum.data[4] += data[8] * rhs.data[0]; -+ accum.data[5] += data[8] * rhs.data[1]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[2]; -+ accum.data[1] += data[1] * rhs.data[3]; -+ accum.data[2] += data[5] * rhs.data[2]; -+ accum.data[3] += data[5] * rhs.data[3]; -+ accum.data[4] += data[9] * rhs.data[2]; -+ accum.data[5] += data[9] * rhs.data[3]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[4]; -+ accum.data[1] += data[2] * rhs.data[5]; -+ accum.data[2] += data[6] * rhs.data[4]; -+ accum.data[3] += data[6] * rhs.data[5]; -+ accum.data[4] += data[10] * rhs.data[4]; -+ accum.data[5] += data[10] * rhs.data[5]; -+ -+ // k=3 -+ accum.data[0] += data[3] * rhs.data[6]; -+ accum.data[1] += data[3] * rhs.data[7]; -+ accum.data[2] += data[7] * rhs.data[6]; -+ accum.data[3] += data[7] * rhs.data[7]; -+ accum.data[4] += data[11] * rhs.data[6]; -+ accum.data[5] += data[11] * rhs.data[7]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 3-by-2-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 3-by-3-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[4] * rhs.data[0]; -+ accum.data[4] += data[4] * rhs.data[1]; -+ accum.data[5] += data[4] * rhs.data[2]; -+ accum.data[6] += data[8] * rhs.data[0]; -+ accum.data[7] += data[8] * rhs.data[1]; -+ accum.data[8] += data[8] * rhs.data[2]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[3]; -+ accum.data[1] += data[1] * rhs.data[4]; -+ accum.data[2] += data[1] * rhs.data[5]; -+ accum.data[3] += data[5] * rhs.data[3]; -+ accum.data[4] += data[5] * rhs.data[4]; -+ accum.data[5] += data[5] * rhs.data[5]; -+ accum.data[6] += data[9] * rhs.data[3]; -+ accum.data[7] += data[9] * rhs.data[4]; -+ accum.data[8] += data[9] * rhs.data[5]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[6]; -+ accum.data[1] += data[2] * rhs.data[7]; -+ accum.data[2] += data[2] * rhs.data[8]; -+ accum.data[3] += data[6] * rhs.data[6]; -+ accum.data[4] += data[6] * rhs.data[7]; -+ accum.data[5] += data[6] * rhs.data[8]; -+ accum.data[6] += data[10] * rhs.data[6]; -+ accum.data[7] += data[10] * rhs.data[7]; -+ accum.data[8] += data[10] * rhs.data[8]; -+ -+ // k=3 -+ accum.data[0] += data[3] * rhs.data[9]; -+ accum.data[1] += data[3] * rhs.data[10]; -+ accum.data[2] += data[3] * rhs.data[11]; -+ accum.data[3] += data[7] * rhs.data[9]; -+ accum.data[4] += data[7] * rhs.data[10]; -+ accum.data[5] += data[7] * rhs.data[11]; -+ accum.data[6] += data[11] * rhs.data[9]; -+ accum.data[7] += data[11] * rhs.data[10]; -+ accum.data[8] += data[11] * rhs.data[11]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 3-by-3-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 3-by-4-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[0] * rhs.data[3]; -+ accum.data[4] += data[4] * rhs.data[0]; -+ accum.data[5] += data[4] * rhs.data[1]; -+ accum.data[6] += data[4] * rhs.data[2]; -+ accum.data[7] += data[4] * rhs.data[3]; -+ accum.data[8] += data[8] * rhs.data[0]; -+ accum.data[9] += data[8] * rhs.data[1]; -+ accum.data[10] += data[8] * rhs.data[2]; -+ accum.data[11] += data[8] * rhs.data[3]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[4]; -+ accum.data[1] += data[1] * rhs.data[5]; -+ accum.data[2] += data[1] * rhs.data[6]; -+ accum.data[3] += data[1] * rhs.data[7]; -+ accum.data[4] += data[5] * rhs.data[4]; -+ accum.data[5] += data[5] * rhs.data[5]; -+ accum.data[6] += data[5] * rhs.data[6]; -+ accum.data[7] += data[5] * rhs.data[7]; -+ accum.data[8] += data[9] * rhs.data[4]; -+ accum.data[9] += data[9] * rhs.data[5]; -+ accum.data[10] += data[9] * rhs.data[6]; -+ accum.data[11] += data[9] * rhs.data[7]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[8]; -+ accum.data[1] += data[2] * rhs.data[9]; -+ accum.data[2] += data[2] * rhs.data[10]; -+ accum.data[3] += data[2] * rhs.data[11]; -+ accum.data[4] += data[6] * rhs.data[8]; -+ accum.data[5] += data[6] * rhs.data[9]; -+ accum.data[6] += data[6] * rhs.data[10]; -+ accum.data[7] += data[6] * rhs.data[11]; -+ accum.data[8] += data[10] * rhs.data[8]; -+ accum.data[9] += data[10] * rhs.data[9]; -+ accum.data[10] += data[10] * rhs.data[10]; -+ accum.data[11] += data[10] * rhs.data[11]; -+ -+ // k=3 -+ accum.data[0] += data[3] * rhs.data[12]; -+ accum.data[1] += data[3] * rhs.data[13]; -+ accum.data[2] += data[3] * rhs.data[14]; -+ accum.data[3] += data[3] * rhs.data[15]; -+ accum.data[4] += data[7] * rhs.data[12]; -+ accum.data[5] += data[7] * rhs.data[13]; -+ accum.data[6] += data[7] * rhs.data[14]; -+ accum.data[7] += data[7] * rhs.data[15]; -+ accum.data[8] += data[11] * rhs.data[12]; -+ accum.data[9] += data[11] * rhs.data[13]; -+ accum.data[10] += data[11] * rhs.data[14]; -+ accum.data[11] += data[11] * rhs.data[15]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 3-by-4-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 3-by-4-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix & operator*=(Matrix const &rhs) { -+ *this = product(rhs); -+ return *this; -+ } -+ -+ /// Returns the sum of elements -+ CUTLASS_HOST_DEVICE -+ Element sum(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[1]; -+ accum += data[2]; -+ accum += data[3]; -+ accum += data[4]; -+ accum += data[5]; -+ accum += data[6]; -+ accum += data[7]; -+ accum += data[8]; -+ accum += data[9]; -+ accum += data[10]; -+ accum += data[11]; -+ -+ return accum; -+ } -+ -+ /// Returns the sum of squared elements -+ CUTLASS_HOST_DEVICE -+ Element norm(Element accum = Element()) const { -+ -+ accum += data[0] * data[0]; -+ accum += data[1] * data[1]; -+ accum += data[2] * data[2]; -+ accum += data[3] * data[3]; -+ accum += data[4] * data[4]; -+ accum += data[5] * data[5]; -+ accum += data[6] * data[6]; -+ accum += data[7] * data[7]; -+ accum += data[8] * data[8]; -+ accum += data[9] * data[9]; -+ accum += data[10] * data[10]; -+ accum += data[11] * data[11]; -+ -+ return accum; -+ } -+ -+ /// Returns square root of the norm -+ CUTLASS_HOST_DEVICE -+ Element magnitude() const { -+ return fast_sqrt(norm()); -+ } -+ -+ /// Returns the sum of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Element trace(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[5]; -+ accum += data[10]; -+ -+ return accum; -+ } -+ -+}; -+ -+/// Template alias for 3-by-4 matrix -+template -+using Matrix3x4 = Matrix; -+ -+ -+/// Free funciton to infer element type from template arguments -+template -+CUTLASS_HOST_DEVICE Matrix3x4 make_Matrix3x4( -+ Element _0_0, Element _0_1, Element _0_2, Element _0_3, -+ Element _1_0, Element _1_1, Element _1_2, Element _1_3, -+ Element _2_0, Element _2_1, Element _2_2, Element _2_3 -+) { -+ return Matrix3x4( -+ _0_0, _0_1, _0_2, _0_3, -+ _1_0, _1_1, _1_2, _1_3, -+ _2_0, _2_1, _2_2, _2_3 -+ ); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// 4-by-1 matrix template class definition -+template -+struct Matrix { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Element data type -+ using Element = Element_; -+ -+ /// Number of rows in matrix -+ static int const kRows = 4; -+ -+ /// Number of columns in matrix -+ static int const kColumns = 1; -+ -+ /// Layout of matrix in underlying array -+ using Layout = layout::RowMajor; -+ -+ /// Number of elements in matrix -+ static int const kCount = 4; -+ -+ // -+ // Data members -+ // -+ -+ /// Elements of the matrix in row-major layout -+ Array data; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a zero matrix -+ CUTLASS_HOST_DEVICE -+ Matrix() { -+ data.clear(); -+ } -+ -+ /// Copy constructor for a 4-by-1 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix(Matrix const &rhs) { -+ data = rhs.data; -+ } -+ -+ /// Constucts a 4-by-1 matrix from scalar elements -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Element _0_0, -+ Element _1_0, -+ Element _2_0, -+ Element _3_0 -+ ) { -+ -+ data[0] = _0_0; -+ data[1] = _1_0; -+ data[2] = _2_0; -+ data[3] = _3_0; -+ } -+ -+ /// Constructs a matrix from a uniform element -+ CUTLASS_HOST_DEVICE -+ static Matrix uniform(Element s) { -+ Matrix m; -+ -+ m.data[0] = s; -+ m.data[1] = s; -+ m.data[2] = s; -+ m.data[3] = s; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from a uniform element 1 -+ CUTLASS_HOST_DEVICE -+ static Matrix ones() { -+ return uniform(Element(1)); -+ } -+ -+ /// Constructs a matrix from a uniform element 0 -+ CUTLASS_HOST_DEVICE -+ static Matrix zero() { -+ return Matrix(); -+ } -+ -+ /// Returns a transposed matrix -+ CUTLASS_HOST_DEVICE -+ Matrix transpose() const { -+ Matrix mt; -+ -+ mt.data[0] = data[0]; -+ mt.data[1] = data[1]; -+ mt.data[2] = data[2]; -+ mt.data[3] = data[3]; -+ -+ return mt; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(int i, int j) const { -+ return data[i * 4 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(int i, int j) { -+ return data[i * 4 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element &at(int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element at(int offset) const { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element operator[](Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & operator[](Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element & operator[](int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element operator[](int offset) const { -+ return data[offset]; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 1 + j + 0]; -+ m.data[1] = data[i * 1 + j + 1]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 1 + j + 0] = m.data[0]; -+ data[i * 1 + j + 1] = m.data[1]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 1 + j + 0]; -+ m.data[1] = data[i * 1 + j + 1]; -+ m.data[2] = data[i * 1 + j + 2]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 1 + j + 0] = m.data[0]; -+ data[i * 1 + j + 1] = m.data[1]; -+ data[i * 1 + j + 2] = m.data[2]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_4x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 1 + j + 0]; -+ m.data[1] = data[i * 1 + j + 1]; -+ m.data[2] = data[i * 1 + j + 2]; -+ m.data[3] = data[i * 1 + j + 3]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_4x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 1 + j + 0] = m.data[0]; -+ data[i * 1 + j + 1] = m.data[1]; -+ data[i * 1 + j + 2] = m.data[2]; -+ data[i * 1 + j + 3] = m.data[3]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix column(int j) const { -+ return slice_4x1(0, j); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_column(Matrix const &v, int j =0) { -+ return set_slice_4x1(v, 0, j); -+ } -+ -+ /// Concatenates this matrix with a a 4-by-1 matrix to form a 4-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Matrix const & rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 4-by-2 matrix to form a 4-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Matrix const & rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 4-by-3 matrix to form a 4-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Matrix const & rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Forms a 4-by-1 matrix by vertically concatenating an Element with a 3-by-1 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Element upper, Matrix const & lower) { -+ return Matrix( -+ upper -+ , lower.at(0, 0) -+ , lower.at(1, 0) -+ , lower.at(2, 0)); -+ } -+ -+ /// Forms a 4-by-1 matrix by vertically concatenating a 2-by-1 matrix with a 2-by-1 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Matrix const & lower) { -+ return Matrix( -+ upper.at(0, 0) -+ , upper.at(1, 0) -+ , lower.at(0, 0) -+ , lower.at(1, 0)); -+ } -+ -+ /// Forms a 4-by-1 matrix by vertically concatenating a 3-by-1 matrix with an Element -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Element lower) { -+ return Matrix( -+ upper.at(0, 0) -+ , upper.at(1, 0) -+ , upper.at(2, 0) -+ , lower); -+ } -+ -+ /// Elementwise add operator (4-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix add(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] + rhs.data[0]; -+ -+ result.data[1] = data[1] + rhs.data[1]; -+ -+ result.data[2] = data[2] + rhs.data[2]; -+ -+ result.data[3] = data[3] + rhs.data[3]; -+ -+ return result; -+ } -+ -+ /// Elementwise add operator (4-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix operator +(Matrix const &rhs) const { -+ return add(rhs); -+ } -+ -+ /// Elementwise add operator (4-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator +=(Matrix const &rhs) { -+ -+ data[0] += rhs.data[0]; -+ -+ data[1] += rhs.data[1]; -+ -+ data[2] += rhs.data[2]; -+ -+ data[3] += rhs.data[3]; -+ -+ return *this; -+ } -+ -+ /// Elementwise subtract operator (4-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix subtract(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] - rhs.data[0]; -+ -+ result.data[1] = data[1] - rhs.data[1]; -+ -+ result.data[2] = data[2] - rhs.data[2]; -+ -+ result.data[3] = data[3] - rhs.data[3]; -+ -+ return result; -+ } -+ -+ /// Elementwise subtract operator (4-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix operator -(Matrix const &rhs) const { -+ return subtract(rhs); -+ } -+ -+ /// Elementwise subtract operator (4-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator -=(Matrix const &rhs) { -+ -+ data[0] -= rhs.data[0]; -+ -+ data[1] -= rhs.data[1]; -+ -+ data[2] -= rhs.data[2]; -+ -+ data[3] -= rhs.data[3]; -+ -+ return *this; -+ } -+ -+ /// Elementwise multiply operator (4-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * rhs.data[0]; -+ -+ result.data[1] = data[1] * rhs.data[1]; -+ -+ result.data[2] = data[2] * rhs.data[2]; -+ -+ result.data[3] = data[3] * rhs.data[3]; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (4-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * s; -+ -+ result.data[1] = data[1] * s; -+ -+ result.data[2] = data[2] * s; -+ -+ result.data[3] = data[3] * s; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (4-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix operator *(Element const &s) const { -+ return multiply(s); -+ } -+ -+ /// Scalar multiply operator (4-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator *=(Element const &s) { -+ -+ data[0] *= s; -+ -+ data[1] *= s; -+ -+ data[2] *= s; -+ -+ data[3] *= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (4-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / rhs.data[0]; -+ -+ result.data[1] = data[1] / rhs.data[1]; -+ -+ result.data[2] = data[2] / rhs.data[2]; -+ -+ result.data[3] = data[3] / rhs.data[3]; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (4-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / s; -+ -+ result.data[1] = data[1] / s; -+ -+ result.data[2] = data[2] / s; -+ -+ result.data[3] = data[3] / s; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (4-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Element const &s) const { -+ return divide(s); -+ } -+ -+ /// Scalar divide operator (4-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Element const &s) { -+ -+ data[0] /= s; -+ -+ data[1] /= s; -+ -+ data[2] /= s; -+ -+ data[3] /= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (4-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Matrix const &rhs) const { -+ return divide(rhs); -+ } -+ -+ /// Elementwise divide operator (4-by-1) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Matrix const &rhs) { -+ -+ data[0] /= rhs.data[0]; -+ -+ data[1] /= rhs.data[1]; -+ -+ data[2] /= rhs.data[2]; -+ -+ data[3] /= rhs.data[3]; -+ -+ return *this; -+ } -+ -+ /// Negates each element of the matrix -+ CUTLASS_HOST_DEVICE -+ Matrix operator-() const { -+ Matrix m; -+ -+ m.data[0] = -m.data[0]; -+ m.data[1] = -m.data[1]; -+ m.data[2] = -m.data[2]; -+ m.data[3] = -m.data[3]; -+ -+ return m; -+ } -+ -+ /// Matrix product of size 4-by-1-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[1] * rhs.data[0]; -+ accum.data[2] += data[2] * rhs.data[0]; -+ accum.data[3] += data[3] * rhs.data[0]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 4-by-1-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 4-by-1-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix & operator*=(Matrix const &rhs) { -+ *this = product(rhs); -+ return *this; -+ } -+ -+ /// Matrix product of size 4-by-2-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[1] * rhs.data[0]; -+ accum.data[3] += data[1] * rhs.data[1]; -+ accum.data[4] += data[2] * rhs.data[0]; -+ accum.data[5] += data[2] * rhs.data[1]; -+ accum.data[6] += data[3] * rhs.data[0]; -+ accum.data[7] += data[3] * rhs.data[1]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 4-by-2-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 4-by-3-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[1] * rhs.data[0]; -+ accum.data[4] += data[1] * rhs.data[1]; -+ accum.data[5] += data[1] * rhs.data[2]; -+ accum.data[6] += data[2] * rhs.data[0]; -+ accum.data[7] += data[2] * rhs.data[1]; -+ accum.data[8] += data[2] * rhs.data[2]; -+ accum.data[9] += data[3] * rhs.data[0]; -+ accum.data[10] += data[3] * rhs.data[1]; -+ accum.data[11] += data[3] * rhs.data[2]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 4-by-3-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 4-by-4-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[0] * rhs.data[3]; -+ accum.data[4] += data[1] * rhs.data[0]; -+ accum.data[5] += data[1] * rhs.data[1]; -+ accum.data[6] += data[1] * rhs.data[2]; -+ accum.data[7] += data[1] * rhs.data[3]; -+ accum.data[8] += data[2] * rhs.data[0]; -+ accum.data[9] += data[2] * rhs.data[1]; -+ accum.data[10] += data[2] * rhs.data[2]; -+ accum.data[11] += data[2] * rhs.data[3]; -+ accum.data[12] += data[3] * rhs.data[0]; -+ accum.data[13] += data[3] * rhs.data[1]; -+ accum.data[14] += data[3] * rhs.data[2]; -+ accum.data[15] += data[3] * rhs.data[3]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 4-by-4-by-1 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Dot product of vectors with extent 4 -+ CUTLASS_HOST_DEVICE -+ Element dot(Matrix const &rhs, Element accum = Element()) const { -+ -+ accum += data[0] * rhs.data[0]; -+ accum += data[1] * rhs.data[1]; -+ accum += data[2] * rhs.data[2]; -+ accum += data[3] * rhs.data[3]; -+ return accum; -+ } -+ -+ /// Dot product of vectors with extent 4 -+ CUTLASS_HOST_DEVICE -+ Element dot(Matrix const &rhs, Element accum = Element()) const { -+ -+ accum += data[0] * rhs.data[0]; -+ accum += data[1] * rhs.data[1]; -+ accum += data[2] * rhs.data[2]; -+ accum += data[3] * rhs.data[3]; -+ return accum; -+ } -+ -+ /// Returns the sum of elements -+ CUTLASS_HOST_DEVICE -+ Element sum(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[1]; -+ accum += data[2]; -+ accum += data[3]; -+ -+ return accum; -+ } -+ -+ /// Returns the sum of squared elements -+ CUTLASS_HOST_DEVICE -+ Element norm(Element accum = Element()) const { -+ -+ accum += data[0] * data[0]; -+ accum += data[1] * data[1]; -+ accum += data[2] * data[2]; -+ accum += data[3] * data[3]; -+ -+ return accum; -+ } -+ -+ /// Returns square root of the norm -+ CUTLASS_HOST_DEVICE -+ Element magnitude() const { -+ return fast_sqrt(norm()); -+ } -+ -+ /// Returns the sum of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Element trace(Element accum = Element()) const { -+ -+ accum += data[0]; -+ -+ return accum; -+ } -+ -+}; -+ -+/// Template alias for 4-by-1 matrix -+template -+using Matrix4x1 = Matrix; -+ -+ -+/// Free funciton to infer element type from template arguments -+template -+CUTLASS_HOST_DEVICE Matrix4x1 make_Matrix4x1( -+ Element _0_0, -+ Element _1_0, -+ Element _2_0, -+ Element _3_0 -+) { -+ return Matrix4x1( -+ _0_0, -+ _1_0, -+ _2_0, -+ _3_0 -+ ); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// 4-by-2 matrix template class definition -+template -+struct Matrix { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Element data type -+ using Element = Element_; -+ -+ /// Number of rows in matrix -+ static int const kRows = 4; -+ -+ /// Number of columns in matrix -+ static int const kColumns = 2; -+ -+ /// Layout of matrix in underlying array -+ using Layout = layout::RowMajor; -+ -+ /// Number of elements in matrix -+ static int const kCount = 8; -+ -+ // -+ // Data members -+ // -+ -+ /// Elements of the matrix in row-major layout -+ Array data; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a zero matrix -+ CUTLASS_HOST_DEVICE -+ Matrix() { -+ data.clear(); -+ } -+ -+ /// Copy constructor for a 4-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix(Matrix const &rhs) { -+ data = rhs.data; -+ } -+ -+ /// Constucts a 4-by-2 matrix from scalar elements -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Element _0_0, Element _0_1, -+ Element _1_0, Element _1_1, -+ Element _2_0, Element _2_1, -+ Element _3_0, Element _3_1 -+ ) { -+ -+ data[0] = _0_0; data[1] = _0_1; -+ data[2] = _1_0; data[3] = _1_1; -+ data[4] = _2_0; data[5] = _2_1; -+ data[6] = _3_0; data[7] = _3_1; -+ } -+ -+ /// Constucts a 4-by-2 matrix from row vectors -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Matrix const &row_0, -+ Matrix const &row_1, -+ Matrix const &row_2, -+ Matrix const &row_3 -+ ) { -+ data[0] = row_0.data[0]; -+ data[1] = row_0.data[1]; -+ data[2] = row_1.data[0]; -+ data[3] = row_1.data[1]; -+ data[4] = row_2.data[0]; -+ data[5] = row_2.data[1]; -+ data[6] = row_3.data[0]; -+ data[7] = row_3.data[1]; -+ } -+ -+ /// Static method to construct a 4-by-2 matrix from column vectors -+ CUTLASS_HOST_DEVICE -+ static Matrix from_columns( -+ Matrix const &column_0, -+ Matrix const &column_1 -+ ) { -+ Matrix result; -+ -+ result.data[0] = column_0.data[0]; -+ result.data[1] = column_1.data[0]; -+ result.data[2] = column_0.data[1]; -+ result.data[3] = column_1.data[1]; -+ result.data[4] = column_0.data[2]; -+ result.data[5] = column_1.data[2]; -+ result.data[6] = column_0.data[3]; -+ result.data[7] = column_1.data[3]; -+ return result; -+ } -+ -+ /// Constructs a matrix from a uniform element -+ CUTLASS_HOST_DEVICE -+ static Matrix uniform(Element s) { -+ Matrix m; -+ -+ m.data[0] = s; -+ m.data[1] = s; -+ m.data[2] = s; -+ m.data[3] = s; -+ m.data[4] = s; -+ m.data[5] = s; -+ m.data[6] = s; -+ m.data[7] = s; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from a uniform element 1 -+ CUTLASS_HOST_DEVICE -+ static Matrix ones() { -+ return uniform(Element(1)); -+ } -+ -+ /// Constructs a matrix from a uniform element 0 -+ CUTLASS_HOST_DEVICE -+ static Matrix zero() { -+ return Matrix(); -+ } -+ -+ /// Constructs a matrix from elements along its diagonal -+ CUTLASS_HOST_DEVICE -+ static Matrix from_diagonal(Matrix const &diag) { -+ Matrix m; -+ -+ m.data[0] = diag.data[0]; -+ m.data[5] = diag.data[1]; -+ m.data[10] = diag.data[2]; -+ m.data[15] = diag.data[3]; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from elements along its diagonal -+ CUTLASS_HOST_DEVICE -+ static Matrix from_diagonal(Matrix const &diag) { -+ Matrix m; -+ -+ m.data[0] = diag.data[0]; -+ m.data[5] = diag.data[1]; -+ m.data[10] = diag.data[2]; -+ m.data[15] = diag.data[3]; -+ -+ return m; -+ } -+ -+ /// Gets an array of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Matrix diagonal() const { -+ Matrix diag; -+ -+ diag.data[0] = data[0]; -+ diag.data[1] = data[5]; -+ diag.data[2] = data[10]; -+ diag.data[3] = data[15]; -+ -+ return diag; -+ } -+ -+ /// Returns a transposed matrix -+ CUTLASS_HOST_DEVICE -+ Matrix transpose() const { -+ Matrix mt; -+ -+ mt.data[0] = data[0]; -+ mt.data[4] = data[1]; -+ mt.data[1] = data[2]; -+ mt.data[5] = data[3]; -+ mt.data[2] = data[4]; -+ mt.data[6] = data[5]; -+ mt.data[3] = data[6]; -+ mt.data[7] = data[7]; -+ -+ return mt; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(int i, int j) const { -+ return data[i * 4 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(int i, int j) { -+ return data[i * 4 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element &at(int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element at(int offset) const { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element operator[](Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & operator[](Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element & operator[](int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element operator[](int offset) const { -+ return data[offset]; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 2 + j + 0]; -+ m.data[1] = data[i * 2 + j + 1]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 2 + j + 0] = m.data[0]; -+ data[i * 2 + j + 1] = m.data[1]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix row(int i) const { -+ return slice_1x2(i, 0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_row(Matrix const &v, int i = 0) { -+ return set_slice_1x2(v, i, 0); -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 2 + j + 0]; -+ m.data[1] = data[i * 2 + j + 2]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 2 + j + 0] = m.data[0]; -+ data[i * 2 + j + 2] = m.data[1]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 2 + j + 0]; -+ m.data[1] = data[i * 2 + j + 1]; -+ m.data[2] = data[i * 2 + j + 2]; -+ m.data[3] = data[i * 2 + j + 3]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 2 + j + 0] = m.data[0]; -+ data[i * 2 + j + 1] = m.data[1]; -+ data[i * 2 + j + 2] = m.data[2]; -+ data[i * 2 + j + 3] = m.data[3]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 2 + j + 0]; -+ m.data[1] = data[i * 2 + j + 2]; -+ m.data[2] = data[i * 2 + j + 4]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 2 + j + 0] = m.data[0]; -+ data[i * 2 + j + 2] = m.data[1]; -+ data[i * 2 + j + 4] = m.data[2]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 2 + j + 0]; -+ m.data[1] = data[i * 2 + j + 1]; -+ m.data[2] = data[i * 2 + j + 2]; -+ m.data[3] = data[i * 2 + j + 3]; -+ m.data[4] = data[i * 2 + j + 4]; -+ m.data[5] = data[i * 2 + j + 5]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 2 + j + 0] = m.data[0]; -+ data[i * 2 + j + 1] = m.data[1]; -+ data[i * 2 + j + 2] = m.data[2]; -+ data[i * 2 + j + 3] = m.data[3]; -+ data[i * 2 + j + 4] = m.data[4]; -+ data[i * 2 + j + 5] = m.data[5]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_4x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 2 + j + 0]; -+ m.data[1] = data[i * 2 + j + 2]; -+ m.data[2] = data[i * 2 + j + 4]; -+ m.data[3] = data[i * 2 + j + 6]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_4x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 2 + j + 0] = m.data[0]; -+ data[i * 2 + j + 2] = m.data[1]; -+ data[i * 2 + j + 4] = m.data[2]; -+ data[i * 2 + j + 6] = m.data[3]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix column(int j) const { -+ return slice_4x1(0, j); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_column(Matrix const &v, int j =0) { -+ return set_slice_4x1(v, 0, j); -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_4x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 2 + j + 0]; -+ m.data[1] = data[i * 2 + j + 1]; -+ m.data[2] = data[i * 2 + j + 2]; -+ m.data[3] = data[i * 2 + j + 3]; -+ m.data[4] = data[i * 2 + j + 4]; -+ m.data[5] = data[i * 2 + j + 5]; -+ m.data[6] = data[i * 2 + j + 6]; -+ m.data[7] = data[i * 2 + j + 7]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_4x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 2 + j + 0] = m.data[0]; -+ data[i * 2 + j + 1] = m.data[1]; -+ data[i * 2 + j + 2] = m.data[2]; -+ data[i * 2 + j + 3] = m.data[3]; -+ data[i * 2 + j + 4] = m.data[4]; -+ data[i * 2 + j + 5] = m.data[5]; -+ data[i * 2 + j + 6] = m.data[6]; -+ data[i * 2 + j + 7] = m.data[7]; -+ -+ return *this; -+ } -+ -+ /// Forms a 4-by-2 matrix by horizontally concatenating a 4-by-1 matrix with a 4-by-1 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs.at(0, 0), rhs.at(0, 0) -+ , lhs.at(1, 0), rhs.at(1, 0) -+ , lhs.at(2, 0), rhs.at(2, 0) -+ , lhs.at(3, 0), rhs.at(3, 0)); -+ } -+ -+ /// Concatenates this matrix with a a 4-by-1 matrix to form a 4-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Matrix const & rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Concatenates this matrix with a a 4-by-2 matrix to form a 4-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Matrix const & rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Forms a 4-by-2 matrix by vertically concatenating a 1-by-2 matrix with a 3-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Matrix const & lower) { -+ return Matrix( -+ upper.at(0, 0), upper.at(0, 1) -+ , lower.at(0, 0), lower.at(0, 1) -+ , lower.at(1, 0), lower.at(1, 1) -+ , lower.at(2, 0), lower.at(2, 1)); -+ } -+ -+ /// Forms a 4-by-2 matrix by vertically concatenating a 2-by-2 matrix with a 2-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Matrix const & lower) { -+ return Matrix( -+ upper.at(0, 0), upper.at(0, 1) -+ , upper.at(1, 0), upper.at(1, 1) -+ , lower.at(0, 0), lower.at(0, 1) -+ , lower.at(1, 0), lower.at(1, 1)); -+ } -+ -+ /// Forms a 4-by-2 matrix by vertically concatenating a 3-by-2 matrix with a 1-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Matrix const & lower) { -+ return Matrix( -+ upper.at(0, 0), upper.at(0, 1) -+ , upper.at(1, 0), upper.at(1, 1) -+ , upper.at(2, 0), upper.at(2, 1) -+ , lower.at(0, 0), lower.at(0, 1)); -+ } -+ -+ /// Forms a 4-by-2 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Element A, Element B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A, B -+ , C.at(0, 0), D.at(0, 0) -+ , C.at(1, 0), D.at(1, 0) -+ , C.at(2, 0), D.at(2, 0) -+ ); -+ } -+ -+ /// Forms a 4-by-2 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A.at(0, 0), B.at(0, 0) -+ , A.at(1, 0), B.at(1, 0) -+ , C.at(0, 0), D.at(0, 0) -+ , C.at(1, 0), D.at(1, 0) -+ ); -+ } -+ -+ /// Forms a 4-by-2 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Element C, Element D) { -+ return Matrix( -+ A.at(0, 0), B.at(0, 0) -+ , A.at(1, 0), B.at(1, 0) -+ , A.at(2, 0), B.at(2, 0) -+ , C, D -+ ); -+ } -+ -+ /// Elementwise add operator (4-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix add(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] + rhs.data[0]; -+ result.data[1] = data[1] + rhs.data[1]; -+ -+ result.data[2] = data[2] + rhs.data[2]; -+ result.data[3] = data[3] + rhs.data[3]; -+ -+ result.data[4] = data[4] + rhs.data[4]; -+ result.data[5] = data[5] + rhs.data[5]; -+ -+ result.data[6] = data[6] + rhs.data[6]; -+ result.data[7] = data[7] + rhs.data[7]; -+ -+ return result; -+ } -+ -+ /// Elementwise add operator (4-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix operator +(Matrix const &rhs) const { -+ return add(rhs); -+ } -+ -+ /// Elementwise add operator (4-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator +=(Matrix const &rhs) { -+ -+ data[0] += rhs.data[0]; -+ data[1] += rhs.data[1]; -+ -+ data[2] += rhs.data[2]; -+ data[3] += rhs.data[3]; -+ -+ data[4] += rhs.data[4]; -+ data[5] += rhs.data[5]; -+ -+ data[6] += rhs.data[6]; -+ data[7] += rhs.data[7]; -+ -+ return *this; -+ } -+ -+ /// Elementwise subtract operator (4-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix subtract(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] - rhs.data[0]; -+ result.data[1] = data[1] - rhs.data[1]; -+ -+ result.data[2] = data[2] - rhs.data[2]; -+ result.data[3] = data[3] - rhs.data[3]; -+ -+ result.data[4] = data[4] - rhs.data[4]; -+ result.data[5] = data[5] - rhs.data[5]; -+ -+ result.data[6] = data[6] - rhs.data[6]; -+ result.data[7] = data[7] - rhs.data[7]; -+ -+ return result; -+ } -+ -+ /// Elementwise subtract operator (4-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix operator -(Matrix const &rhs) const { -+ return subtract(rhs); -+ } -+ -+ /// Elementwise subtract operator (4-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator -=(Matrix const &rhs) { -+ -+ data[0] -= rhs.data[0]; -+ data[1] -= rhs.data[1]; -+ -+ data[2] -= rhs.data[2]; -+ data[3] -= rhs.data[3]; -+ -+ data[4] -= rhs.data[4]; -+ data[5] -= rhs.data[5]; -+ -+ data[6] -= rhs.data[6]; -+ data[7] -= rhs.data[7]; -+ -+ return *this; -+ } -+ -+ /// Elementwise multiply operator (4-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * rhs.data[0]; -+ result.data[1] = data[1] * rhs.data[1]; -+ -+ result.data[2] = data[2] * rhs.data[2]; -+ result.data[3] = data[3] * rhs.data[3]; -+ -+ result.data[4] = data[4] * rhs.data[4]; -+ result.data[5] = data[5] * rhs.data[5]; -+ -+ result.data[6] = data[6] * rhs.data[6]; -+ result.data[7] = data[7] * rhs.data[7]; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (4-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * s; -+ result.data[1] = data[1] * s; -+ -+ result.data[2] = data[2] * s; -+ result.data[3] = data[3] * s; -+ -+ result.data[4] = data[4] * s; -+ result.data[5] = data[5] * s; -+ -+ result.data[6] = data[6] * s; -+ result.data[7] = data[7] * s; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (4-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix operator *(Element const &s) const { -+ return multiply(s); -+ } -+ -+ /// Scalar multiply operator (4-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator *=(Element const &s) { -+ -+ data[0] *= s; -+ data[1] *= s; -+ -+ data[2] *= s; -+ data[3] *= s; -+ -+ data[4] *= s; -+ data[5] *= s; -+ -+ data[6] *= s; -+ data[7] *= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (4-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / rhs.data[0]; -+ result.data[1] = data[1] / rhs.data[1]; -+ -+ result.data[2] = data[2] / rhs.data[2]; -+ result.data[3] = data[3] / rhs.data[3]; -+ -+ result.data[4] = data[4] / rhs.data[4]; -+ result.data[5] = data[5] / rhs.data[5]; -+ -+ result.data[6] = data[6] / rhs.data[6]; -+ result.data[7] = data[7] / rhs.data[7]; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (4-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / s; -+ result.data[1] = data[1] / s; -+ -+ result.data[2] = data[2] / s; -+ result.data[3] = data[3] / s; -+ -+ result.data[4] = data[4] / s; -+ result.data[5] = data[5] / s; -+ -+ result.data[6] = data[6] / s; -+ result.data[7] = data[7] / s; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (4-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Element const &s) const { -+ return divide(s); -+ } -+ -+ /// Scalar divide operator (4-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Element const &s) { -+ -+ data[0] /= s; -+ data[1] /= s; -+ -+ data[2] /= s; -+ data[3] /= s; -+ -+ data[4] /= s; -+ data[5] /= s; -+ -+ data[6] /= s; -+ data[7] /= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (4-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Matrix const &rhs) const { -+ return divide(rhs); -+ } -+ -+ /// Elementwise divide operator (4-by-2) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Matrix const &rhs) { -+ -+ data[0] /= rhs.data[0]; -+ data[1] /= rhs.data[1]; -+ -+ data[2] /= rhs.data[2]; -+ data[3] /= rhs.data[3]; -+ -+ data[4] /= rhs.data[4]; -+ data[5] /= rhs.data[5]; -+ -+ data[6] /= rhs.data[6]; -+ data[7] /= rhs.data[7]; -+ -+ return *this; -+ } -+ -+ /// Negates each element of the matrix -+ CUTLASS_HOST_DEVICE -+ Matrix operator-() const { -+ Matrix m; -+ -+ m.data[0] = -m.data[0]; -+ m.data[1] = -m.data[1]; -+ m.data[2] = -m.data[2]; -+ m.data[3] = -m.data[3]; -+ m.data[4] = -m.data[4]; -+ m.data[5] = -m.data[5]; -+ m.data[6] = -m.data[6]; -+ m.data[7] = -m.data[7]; -+ -+ return m; -+ } -+ -+ /// Matrix product of size 4-by-1-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[2] * rhs.data[0]; -+ accum.data[2] += data[4] * rhs.data[0]; -+ accum.data[3] += data[6] * rhs.data[0]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[1]; -+ accum.data[1] += data[3] * rhs.data[1]; -+ accum.data[2] += data[5] * rhs.data[1]; -+ accum.data[3] += data[7] * rhs.data[1]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 4-by-1-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 4-by-2-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[2] * rhs.data[0]; -+ accum.data[3] += data[2] * rhs.data[1]; -+ accum.data[4] += data[4] * rhs.data[0]; -+ accum.data[5] += data[4] * rhs.data[1]; -+ accum.data[6] += data[6] * rhs.data[0]; -+ accum.data[7] += data[6] * rhs.data[1]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[2]; -+ accum.data[1] += data[1] * rhs.data[3]; -+ accum.data[2] += data[3] * rhs.data[2]; -+ accum.data[3] += data[3] * rhs.data[3]; -+ accum.data[4] += data[5] * rhs.data[2]; -+ accum.data[5] += data[5] * rhs.data[3]; -+ accum.data[6] += data[7] * rhs.data[2]; -+ accum.data[7] += data[7] * rhs.data[3]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 4-by-2-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 4-by-2-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix & operator*=(Matrix const &rhs) { -+ *this = product(rhs); -+ return *this; -+ } -+ -+ /// Matrix product of size 4-by-3-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[2] * rhs.data[0]; -+ accum.data[4] += data[2] * rhs.data[1]; -+ accum.data[5] += data[2] * rhs.data[2]; -+ accum.data[6] += data[4] * rhs.data[0]; -+ accum.data[7] += data[4] * rhs.data[1]; -+ accum.data[8] += data[4] * rhs.data[2]; -+ accum.data[9] += data[6] * rhs.data[0]; -+ accum.data[10] += data[6] * rhs.data[1]; -+ accum.data[11] += data[6] * rhs.data[2]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[3]; -+ accum.data[1] += data[1] * rhs.data[4]; -+ accum.data[2] += data[1] * rhs.data[5]; -+ accum.data[3] += data[3] * rhs.data[3]; -+ accum.data[4] += data[3] * rhs.data[4]; -+ accum.data[5] += data[3] * rhs.data[5]; -+ accum.data[6] += data[5] * rhs.data[3]; -+ accum.data[7] += data[5] * rhs.data[4]; -+ accum.data[8] += data[5] * rhs.data[5]; -+ accum.data[9] += data[7] * rhs.data[3]; -+ accum.data[10] += data[7] * rhs.data[4]; -+ accum.data[11] += data[7] * rhs.data[5]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 4-by-3-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 4-by-4-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[0] * rhs.data[3]; -+ accum.data[4] += data[2] * rhs.data[0]; -+ accum.data[5] += data[2] * rhs.data[1]; -+ accum.data[6] += data[2] * rhs.data[2]; -+ accum.data[7] += data[2] * rhs.data[3]; -+ accum.data[8] += data[4] * rhs.data[0]; -+ accum.data[9] += data[4] * rhs.data[1]; -+ accum.data[10] += data[4] * rhs.data[2]; -+ accum.data[11] += data[4] * rhs.data[3]; -+ accum.data[12] += data[6] * rhs.data[0]; -+ accum.data[13] += data[6] * rhs.data[1]; -+ accum.data[14] += data[6] * rhs.data[2]; -+ accum.data[15] += data[6] * rhs.data[3]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[4]; -+ accum.data[1] += data[1] * rhs.data[5]; -+ accum.data[2] += data[1] * rhs.data[6]; -+ accum.data[3] += data[1] * rhs.data[7]; -+ accum.data[4] += data[3] * rhs.data[4]; -+ accum.data[5] += data[3] * rhs.data[5]; -+ accum.data[6] += data[3] * rhs.data[6]; -+ accum.data[7] += data[3] * rhs.data[7]; -+ accum.data[8] += data[5] * rhs.data[4]; -+ accum.data[9] += data[5] * rhs.data[5]; -+ accum.data[10] += data[5] * rhs.data[6]; -+ accum.data[11] += data[5] * rhs.data[7]; -+ accum.data[12] += data[7] * rhs.data[4]; -+ accum.data[13] += data[7] * rhs.data[5]; -+ accum.data[14] += data[7] * rhs.data[6]; -+ accum.data[15] += data[7] * rhs.data[7]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 4-by-4-by-2 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Returns the sum of elements -+ CUTLASS_HOST_DEVICE -+ Element sum(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[1]; -+ accum += data[2]; -+ accum += data[3]; -+ accum += data[4]; -+ accum += data[5]; -+ accum += data[6]; -+ accum += data[7]; -+ -+ return accum; -+ } -+ -+ /// Returns the sum of squared elements -+ CUTLASS_HOST_DEVICE -+ Element norm(Element accum = Element()) const { -+ -+ accum += data[0] * data[0]; -+ accum += data[1] * data[1]; -+ accum += data[2] * data[2]; -+ accum += data[3] * data[3]; -+ accum += data[4] * data[4]; -+ accum += data[5] * data[5]; -+ accum += data[6] * data[6]; -+ accum += data[7] * data[7]; -+ -+ return accum; -+ } -+ -+ /// Returns square root of the norm -+ CUTLASS_HOST_DEVICE -+ Element magnitude() const { -+ return fast_sqrt(norm()); -+ } -+ -+ /// Returns the sum of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Element trace(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[3]; -+ -+ return accum; -+ } -+ -+}; -+ -+/// Template alias for 4-by-2 matrix -+template -+using Matrix4x2 = Matrix; -+ -+ -+/// Free funciton to infer element type from template arguments -+template -+CUTLASS_HOST_DEVICE Matrix4x2 make_Matrix4x2( -+ Element _0_0, Element _0_1, -+ Element _1_0, Element _1_1, -+ Element _2_0, Element _2_1, -+ Element _3_0, Element _3_1 -+) { -+ return Matrix4x2( -+ _0_0, _0_1, -+ _1_0, _1_1, -+ _2_0, _2_1, -+ _3_0, _3_1 -+ ); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// 4-by-3 matrix template class definition -+template -+struct Matrix { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Element data type -+ using Element = Element_; -+ -+ /// Number of rows in matrix -+ static int const kRows = 4; -+ -+ /// Number of columns in matrix -+ static int const kColumns = 3; -+ -+ /// Layout of matrix in underlying array -+ using Layout = layout::RowMajor; -+ -+ /// Number of elements in matrix -+ static int const kCount = 12; -+ -+ // -+ // Data members -+ // -+ -+ /// Elements of the matrix in row-major layout -+ Array data; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a zero matrix -+ CUTLASS_HOST_DEVICE -+ Matrix() { -+ data.clear(); -+ } -+ -+ /// Copy constructor for a 4-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix(Matrix const &rhs) { -+ data = rhs.data; -+ } -+ -+ /// Constucts a 4-by-3 matrix from scalar elements -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Element _0_0, Element _0_1, Element _0_2, -+ Element _1_0, Element _1_1, Element _1_2, -+ Element _2_0, Element _2_1, Element _2_2, -+ Element _3_0, Element _3_1, Element _3_2 -+ ) { -+ -+ data[0] = _0_0; data[1] = _0_1; data[2] = _0_2; -+ data[3] = _1_0; data[4] = _1_1; data[5] = _1_2; -+ data[6] = _2_0; data[7] = _2_1; data[8] = _2_2; -+ data[9] = _3_0; data[10] = _3_1; data[11] = _3_2; -+ } -+ -+ /// Constucts a 4-by-3 matrix from row vectors -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Matrix const &row_0, -+ Matrix const &row_1, -+ Matrix const &row_2, -+ Matrix const &row_3 -+ ) { -+ data[0] = row_0.data[0]; -+ data[1] = row_0.data[1]; -+ data[2] = row_0.data[2]; -+ data[3] = row_1.data[0]; -+ data[4] = row_1.data[1]; -+ data[5] = row_1.data[2]; -+ data[6] = row_2.data[0]; -+ data[7] = row_2.data[1]; -+ data[8] = row_2.data[2]; -+ data[9] = row_3.data[0]; -+ data[10] = row_3.data[1]; -+ data[11] = row_3.data[2]; -+ } -+ -+ /// Static method to construct a 4-by-3 matrix from column vectors -+ CUTLASS_HOST_DEVICE -+ static Matrix from_columns( -+ Matrix const &column_0, -+ Matrix const &column_1, -+ Matrix const &column_2 -+ ) { -+ Matrix result; -+ -+ result.data[0] = column_0.data[0]; -+ result.data[1] = column_1.data[0]; -+ result.data[2] = column_2.data[0]; -+ result.data[3] = column_0.data[1]; -+ result.data[4] = column_1.data[1]; -+ result.data[5] = column_2.data[1]; -+ result.data[6] = column_0.data[2]; -+ result.data[7] = column_1.data[2]; -+ result.data[8] = column_2.data[2]; -+ result.data[9] = column_0.data[3]; -+ result.data[10] = column_1.data[3]; -+ result.data[11] = column_2.data[3]; -+ return result; -+ } -+ -+ /// Constructs a matrix from a uniform element -+ CUTLASS_HOST_DEVICE -+ static Matrix uniform(Element s) { -+ Matrix m; -+ -+ m.data[0] = s; -+ m.data[1] = s; -+ m.data[2] = s; -+ m.data[3] = s; -+ m.data[4] = s; -+ m.data[5] = s; -+ m.data[6] = s; -+ m.data[7] = s; -+ m.data[8] = s; -+ m.data[9] = s; -+ m.data[10] = s; -+ m.data[11] = s; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from a uniform element 1 -+ CUTLASS_HOST_DEVICE -+ static Matrix ones() { -+ return uniform(Element(1)); -+ } -+ -+ /// Constructs a matrix from a uniform element 0 -+ CUTLASS_HOST_DEVICE -+ static Matrix zero() { -+ return Matrix(); -+ } -+ -+ /// Constructs a matrix from elements along its diagonal -+ CUTLASS_HOST_DEVICE -+ static Matrix from_diagonal(Matrix const &diag) { -+ Matrix m; -+ -+ m.data[0] = diag.data[0]; -+ m.data[5] = diag.data[1]; -+ m.data[10] = diag.data[2]; -+ m.data[15] = diag.data[3]; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from elements along its diagonal -+ CUTLASS_HOST_DEVICE -+ static Matrix from_diagonal(Matrix const &diag) { -+ Matrix m; -+ -+ m.data[0] = diag.data[0]; -+ m.data[5] = diag.data[1]; -+ m.data[10] = diag.data[2]; -+ m.data[15] = diag.data[3]; -+ -+ return m; -+ } -+ -+ /// Gets an array of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Matrix diagonal() const { -+ Matrix diag; -+ -+ diag.data[0] = data[0]; -+ diag.data[1] = data[5]; -+ diag.data[2] = data[10]; -+ diag.data[3] = data[15]; -+ -+ return diag; -+ } -+ -+ /// Returns a transposed matrix -+ CUTLASS_HOST_DEVICE -+ Matrix transpose() const { -+ Matrix mt; -+ -+ mt.data[0] = data[0]; -+ mt.data[4] = data[1]; -+ mt.data[8] = data[2]; -+ mt.data[1] = data[3]; -+ mt.data[5] = data[4]; -+ mt.data[9] = data[5]; -+ mt.data[2] = data[6]; -+ mt.data[6] = data[7]; -+ mt.data[10] = data[8]; -+ mt.data[3] = data[9]; -+ mt.data[7] = data[10]; -+ mt.data[11] = data[11]; -+ -+ return mt; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(int i, int j) const { -+ return data[i * 4 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(int i, int j) { -+ return data[i * 4 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element &at(int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element at(int offset) const { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element operator[](Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & operator[](Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element & operator[](int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element operator[](int offset) const { -+ return data[offset]; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ m.data[2] = data[i * 3 + j + 2]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ data[i * 3 + j + 2] = m.data[2]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix row(int i) const { -+ return slice_1x3(i, 0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_row(Matrix const &v, int i = 0) { -+ return set_slice_1x3(v, i, 0); -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 3]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 3] = m.data[1]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ m.data[2] = data[i * 3 + j + 3]; -+ m.data[3] = data[i * 3 + j + 4]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ data[i * 3 + j + 3] = m.data[2]; -+ data[i * 3 + j + 4] = m.data[3]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ m.data[2] = data[i * 3 + j + 2]; -+ m.data[3] = data[i * 3 + j + 3]; -+ m.data[4] = data[i * 3 + j + 4]; -+ m.data[5] = data[i * 3 + j + 5]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ data[i * 3 + j + 2] = m.data[2]; -+ data[i * 3 + j + 3] = m.data[3]; -+ data[i * 3 + j + 4] = m.data[4]; -+ data[i * 3 + j + 5] = m.data[5]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 3]; -+ m.data[2] = data[i * 3 + j + 6]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 3] = m.data[1]; -+ data[i * 3 + j + 6] = m.data[2]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ m.data[2] = data[i * 3 + j + 3]; -+ m.data[3] = data[i * 3 + j + 4]; -+ m.data[4] = data[i * 3 + j + 6]; -+ m.data[5] = data[i * 3 + j + 7]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ data[i * 3 + j + 3] = m.data[2]; -+ data[i * 3 + j + 4] = m.data[3]; -+ data[i * 3 + j + 6] = m.data[4]; -+ data[i * 3 + j + 7] = m.data[5]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ m.data[2] = data[i * 3 + j + 2]; -+ m.data[3] = data[i * 3 + j + 3]; -+ m.data[4] = data[i * 3 + j + 4]; -+ m.data[5] = data[i * 3 + j + 5]; -+ m.data[6] = data[i * 3 + j + 6]; -+ m.data[7] = data[i * 3 + j + 7]; -+ m.data[8] = data[i * 3 + j + 8]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ data[i * 3 + j + 2] = m.data[2]; -+ data[i * 3 + j + 3] = m.data[3]; -+ data[i * 3 + j + 4] = m.data[4]; -+ data[i * 3 + j + 5] = m.data[5]; -+ data[i * 3 + j + 6] = m.data[6]; -+ data[i * 3 + j + 7] = m.data[7]; -+ data[i * 3 + j + 8] = m.data[8]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_4x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 3]; -+ m.data[2] = data[i * 3 + j + 6]; -+ m.data[3] = data[i * 3 + j + 9]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_4x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 3] = m.data[1]; -+ data[i * 3 + j + 6] = m.data[2]; -+ data[i * 3 + j + 9] = m.data[3]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix column(int j) const { -+ return slice_4x1(0, j); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_column(Matrix const &v, int j =0) { -+ return set_slice_4x1(v, 0, j); -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_4x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ m.data[2] = data[i * 3 + j + 3]; -+ m.data[3] = data[i * 3 + j + 4]; -+ m.data[4] = data[i * 3 + j + 6]; -+ m.data[5] = data[i * 3 + j + 7]; -+ m.data[6] = data[i * 3 + j + 9]; -+ m.data[7] = data[i * 3 + j + 10]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_4x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ data[i * 3 + j + 3] = m.data[2]; -+ data[i * 3 + j + 4] = m.data[3]; -+ data[i * 3 + j + 6] = m.data[4]; -+ data[i * 3 + j + 7] = m.data[5]; -+ data[i * 3 + j + 9] = m.data[6]; -+ data[i * 3 + j + 10] = m.data[7]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_4x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 3 + j + 0]; -+ m.data[1] = data[i * 3 + j + 1]; -+ m.data[2] = data[i * 3 + j + 2]; -+ m.data[3] = data[i * 3 + j + 3]; -+ m.data[4] = data[i * 3 + j + 4]; -+ m.data[5] = data[i * 3 + j + 5]; -+ m.data[6] = data[i * 3 + j + 6]; -+ m.data[7] = data[i * 3 + j + 7]; -+ m.data[8] = data[i * 3 + j + 8]; -+ m.data[9] = data[i * 3 + j + 9]; -+ m.data[10] = data[i * 3 + j + 10]; -+ m.data[11] = data[i * 3 + j + 11]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_4x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 3 + j + 0] = m.data[0]; -+ data[i * 3 + j + 1] = m.data[1]; -+ data[i * 3 + j + 2] = m.data[2]; -+ data[i * 3 + j + 3] = m.data[3]; -+ data[i * 3 + j + 4] = m.data[4]; -+ data[i * 3 + j + 5] = m.data[5]; -+ data[i * 3 + j + 6] = m.data[6]; -+ data[i * 3 + j + 7] = m.data[7]; -+ data[i * 3 + j + 8] = m.data[8]; -+ data[i * 3 + j + 9] = m.data[9]; -+ data[i * 3 + j + 10] = m.data[10]; -+ data[i * 3 + j + 11] = m.data[11]; -+ -+ return *this; -+ } -+ -+ /// Forms a 4-by-3 matrix by horizontally concatenating a 4-by-1 matrix with a 4-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs.at(0, 0), rhs.at(0, 0), rhs.at(0, 1) -+ , lhs.at(1, 0), rhs.at(1, 0), rhs.at(1, 1) -+ , lhs.at(2, 0), rhs.at(2, 0), rhs.at(2, 1) -+ , lhs.at(3, 0), rhs.at(3, 0), rhs.at(3, 1)); -+ } -+ -+ /// Forms a 4-by-3 matrix by horizontally concatenating a 4-by-2 matrix with a 4-by-1 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs.at(0, 0), lhs.at(0, 1), rhs.at(0, 0) -+ , lhs.at(1, 0), lhs.at(1, 1), rhs.at(1, 0) -+ , lhs.at(2, 0), lhs.at(2, 1), rhs.at(2, 0) -+ , lhs.at(3, 0), lhs.at(3, 1), rhs.at(3, 0)); -+ } -+ -+ /// Concatenates this matrix with a a 4-by-1 matrix to form a 4-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix hcat(Matrix const & rhs) const { -+ return Matrix::hcat(*this, rhs); -+ } -+ -+ /// Forms a 4-by-3 matrix by vertically concatenating a 1-by-3 matrix with a 3-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Matrix const & lower) { -+ return Matrix( -+ upper.at(0, 0), upper.at(0, 1), upper.at(0, 2) -+ , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2) -+ , lower.at(1, 0), lower.at(1, 1), lower.at(1, 2) -+ , lower.at(2, 0), lower.at(2, 1), lower.at(2, 2)); -+ } -+ -+ /// Forms a 4-by-3 matrix by vertically concatenating a 2-by-3 matrix with a 2-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Matrix const & lower) { -+ return Matrix( -+ upper.at(0, 0), upper.at(0, 1), upper.at(0, 2) -+ , upper.at(1, 0), upper.at(1, 1), upper.at(1, 2) -+ , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2) -+ , lower.at(1, 0), lower.at(1, 1), lower.at(1, 2)); -+ } -+ -+ /// Forms a 4-by-3 matrix by vertically concatenating a 3-by-3 matrix with a 1-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Matrix const & lower) { -+ return Matrix( -+ upper.at(0, 0), upper.at(0, 1), upper.at(0, 2) -+ , upper.at(1, 0), upper.at(1, 1), upper.at(1, 2) -+ , upper.at(2, 0), upper.at(2, 1), upper.at(2, 2) -+ , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2)); -+ } -+ -+ /// Forms a 4-by-3 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Element A, Matrix const & B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A, B.at(0, 0), B.at(0, 1) -+ , C.at(0, 0), D.at(0, 0), D.at(0, 1) -+ , C.at(1, 0), D.at(1, 0), D.at(1, 1) -+ , C.at(2, 0), D.at(2, 0), D.at(2, 1) -+ ); -+ } -+ -+ /// Forms a 4-by-3 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Element B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A.at(0, 0), A.at(0, 1), B -+ , C.at(0, 0), C.at(0, 1), D.at(0, 0) -+ , C.at(1, 0), C.at(1, 1), D.at(1, 0) -+ , C.at(2, 0), C.at(2, 1), D.at(2, 0) -+ ); -+ } -+ -+ /// Forms a 4-by-3 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A.at(0, 0), B.at(0, 0), B.at(0, 1) -+ , A.at(1, 0), B.at(1, 0), B.at(1, 1) -+ , C.at(0, 0), D.at(0, 0), D.at(0, 1) -+ , C.at(1, 0), D.at(1, 0), D.at(1, 1) -+ ); -+ } -+ -+ /// Forms a 4-by-3 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A.at(0, 0), A.at(0, 1), B.at(0, 0) -+ , A.at(1, 0), A.at(1, 1), B.at(1, 0) -+ , C.at(0, 0), C.at(0, 1), D.at(0, 0) -+ , C.at(1, 0), C.at(1, 1), D.at(1, 0) -+ ); -+ } -+ -+ /// Forms a 4-by-3 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Element C, Matrix const & D) { -+ return Matrix( -+ A.at(0, 0), B.at(0, 0), B.at(0, 1) -+ , A.at(1, 0), B.at(1, 0), B.at(1, 1) -+ , A.at(2, 0), B.at(2, 0), B.at(2, 1) -+ , C, D.at(0, 0), D.at(0, 1) -+ ); -+ } -+ -+ /// Forms a 4-by-3 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Matrix const & C, Element D) { -+ return Matrix( -+ A.at(0, 0), A.at(0, 1), B.at(0, 0) -+ , A.at(1, 0), A.at(1, 1), B.at(1, 0) -+ , A.at(2, 0), A.at(2, 1), B.at(2, 0) -+ , C.at(0, 0), C.at(0, 1), D -+ ); -+ } -+ -+ /// Elementwise add operator (4-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix add(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] + rhs.data[0]; -+ result.data[1] = data[1] + rhs.data[1]; -+ result.data[2] = data[2] + rhs.data[2]; -+ -+ result.data[3] = data[3] + rhs.data[3]; -+ result.data[4] = data[4] + rhs.data[4]; -+ result.data[5] = data[5] + rhs.data[5]; -+ -+ result.data[6] = data[6] + rhs.data[6]; -+ result.data[7] = data[7] + rhs.data[7]; -+ result.data[8] = data[8] + rhs.data[8]; -+ -+ result.data[9] = data[9] + rhs.data[9]; -+ result.data[10] = data[10] + rhs.data[10]; -+ result.data[11] = data[11] + rhs.data[11]; -+ -+ return result; -+ } -+ -+ /// Elementwise add operator (4-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator +(Matrix const &rhs) const { -+ return add(rhs); -+ } -+ -+ /// Elementwise add operator (4-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator +=(Matrix const &rhs) { -+ -+ data[0] += rhs.data[0]; -+ data[1] += rhs.data[1]; -+ data[2] += rhs.data[2]; -+ -+ data[3] += rhs.data[3]; -+ data[4] += rhs.data[4]; -+ data[5] += rhs.data[5]; -+ -+ data[6] += rhs.data[6]; -+ data[7] += rhs.data[7]; -+ data[8] += rhs.data[8]; -+ -+ data[9] += rhs.data[9]; -+ data[10] += rhs.data[10]; -+ data[11] += rhs.data[11]; -+ -+ return *this; -+ } -+ -+ /// Elementwise subtract operator (4-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix subtract(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] - rhs.data[0]; -+ result.data[1] = data[1] - rhs.data[1]; -+ result.data[2] = data[2] - rhs.data[2]; -+ -+ result.data[3] = data[3] - rhs.data[3]; -+ result.data[4] = data[4] - rhs.data[4]; -+ result.data[5] = data[5] - rhs.data[5]; -+ -+ result.data[6] = data[6] - rhs.data[6]; -+ result.data[7] = data[7] - rhs.data[7]; -+ result.data[8] = data[8] - rhs.data[8]; -+ -+ result.data[9] = data[9] - rhs.data[9]; -+ result.data[10] = data[10] - rhs.data[10]; -+ result.data[11] = data[11] - rhs.data[11]; -+ -+ return result; -+ } -+ -+ /// Elementwise subtract operator (4-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator -(Matrix const &rhs) const { -+ return subtract(rhs); -+ } -+ -+ /// Elementwise subtract operator (4-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator -=(Matrix const &rhs) { -+ -+ data[0] -= rhs.data[0]; -+ data[1] -= rhs.data[1]; -+ data[2] -= rhs.data[2]; -+ -+ data[3] -= rhs.data[3]; -+ data[4] -= rhs.data[4]; -+ data[5] -= rhs.data[5]; -+ -+ data[6] -= rhs.data[6]; -+ data[7] -= rhs.data[7]; -+ data[8] -= rhs.data[8]; -+ -+ data[9] -= rhs.data[9]; -+ data[10] -= rhs.data[10]; -+ data[11] -= rhs.data[11]; -+ -+ return *this; -+ } -+ -+ /// Elementwise multiply operator (4-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * rhs.data[0]; -+ result.data[1] = data[1] * rhs.data[1]; -+ result.data[2] = data[2] * rhs.data[2]; -+ -+ result.data[3] = data[3] * rhs.data[3]; -+ result.data[4] = data[4] * rhs.data[4]; -+ result.data[5] = data[5] * rhs.data[5]; -+ -+ result.data[6] = data[6] * rhs.data[6]; -+ result.data[7] = data[7] * rhs.data[7]; -+ result.data[8] = data[8] * rhs.data[8]; -+ -+ result.data[9] = data[9] * rhs.data[9]; -+ result.data[10] = data[10] * rhs.data[10]; -+ result.data[11] = data[11] * rhs.data[11]; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (4-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * s; -+ result.data[1] = data[1] * s; -+ result.data[2] = data[2] * s; -+ -+ result.data[3] = data[3] * s; -+ result.data[4] = data[4] * s; -+ result.data[5] = data[5] * s; -+ -+ result.data[6] = data[6] * s; -+ result.data[7] = data[7] * s; -+ result.data[8] = data[8] * s; -+ -+ result.data[9] = data[9] * s; -+ result.data[10] = data[10] * s; -+ result.data[11] = data[11] * s; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (4-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator *(Element const &s) const { -+ return multiply(s); -+ } -+ -+ /// Scalar multiply operator (4-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator *=(Element const &s) { -+ -+ data[0] *= s; -+ data[1] *= s; -+ data[2] *= s; -+ -+ data[3] *= s; -+ data[4] *= s; -+ data[5] *= s; -+ -+ data[6] *= s; -+ data[7] *= s; -+ data[8] *= s; -+ -+ data[9] *= s; -+ data[10] *= s; -+ data[11] *= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (4-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / rhs.data[0]; -+ result.data[1] = data[1] / rhs.data[1]; -+ result.data[2] = data[2] / rhs.data[2]; -+ -+ result.data[3] = data[3] / rhs.data[3]; -+ result.data[4] = data[4] / rhs.data[4]; -+ result.data[5] = data[5] / rhs.data[5]; -+ -+ result.data[6] = data[6] / rhs.data[6]; -+ result.data[7] = data[7] / rhs.data[7]; -+ result.data[8] = data[8] / rhs.data[8]; -+ -+ result.data[9] = data[9] / rhs.data[9]; -+ result.data[10] = data[10] / rhs.data[10]; -+ result.data[11] = data[11] / rhs.data[11]; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (4-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / s; -+ result.data[1] = data[1] / s; -+ result.data[2] = data[2] / s; -+ -+ result.data[3] = data[3] / s; -+ result.data[4] = data[4] / s; -+ result.data[5] = data[5] / s; -+ -+ result.data[6] = data[6] / s; -+ result.data[7] = data[7] / s; -+ result.data[8] = data[8] / s; -+ -+ result.data[9] = data[9] / s; -+ result.data[10] = data[10] / s; -+ result.data[11] = data[11] / s; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (4-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Element const &s) const { -+ return divide(s); -+ } -+ -+ /// Scalar divide operator (4-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Element const &s) { -+ -+ data[0] /= s; -+ data[1] /= s; -+ data[2] /= s; -+ -+ data[3] /= s; -+ data[4] /= s; -+ data[5] /= s; -+ -+ data[6] /= s; -+ data[7] /= s; -+ data[8] /= s; -+ -+ data[9] /= s; -+ data[10] /= s; -+ data[11] /= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (4-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Matrix const &rhs) const { -+ return divide(rhs); -+ } -+ -+ /// Elementwise divide operator (4-by-3) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Matrix const &rhs) { -+ -+ data[0] /= rhs.data[0]; -+ data[1] /= rhs.data[1]; -+ data[2] /= rhs.data[2]; -+ -+ data[3] /= rhs.data[3]; -+ data[4] /= rhs.data[4]; -+ data[5] /= rhs.data[5]; -+ -+ data[6] /= rhs.data[6]; -+ data[7] /= rhs.data[7]; -+ data[8] /= rhs.data[8]; -+ -+ data[9] /= rhs.data[9]; -+ data[10] /= rhs.data[10]; -+ data[11] /= rhs.data[11]; -+ -+ return *this; -+ } -+ -+ /// Negates each element of the matrix -+ CUTLASS_HOST_DEVICE -+ Matrix operator-() const { -+ Matrix m; -+ -+ m.data[0] = -m.data[0]; -+ m.data[1] = -m.data[1]; -+ m.data[2] = -m.data[2]; -+ m.data[3] = -m.data[3]; -+ m.data[4] = -m.data[4]; -+ m.data[5] = -m.data[5]; -+ m.data[6] = -m.data[6]; -+ m.data[7] = -m.data[7]; -+ m.data[8] = -m.data[8]; -+ m.data[9] = -m.data[9]; -+ m.data[10] = -m.data[10]; -+ m.data[11] = -m.data[11]; -+ -+ return m; -+ } -+ -+ /// Matrix product of size 4-by-1-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[3] * rhs.data[0]; -+ accum.data[2] += data[6] * rhs.data[0]; -+ accum.data[3] += data[9] * rhs.data[0]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[1]; -+ accum.data[1] += data[4] * rhs.data[1]; -+ accum.data[2] += data[7] * rhs.data[1]; -+ accum.data[3] += data[10] * rhs.data[1]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[2]; -+ accum.data[1] += data[5] * rhs.data[2]; -+ accum.data[2] += data[8] * rhs.data[2]; -+ accum.data[3] += data[11] * rhs.data[2]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 4-by-1-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 4-by-2-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[3] * rhs.data[0]; -+ accum.data[3] += data[3] * rhs.data[1]; -+ accum.data[4] += data[6] * rhs.data[0]; -+ accum.data[5] += data[6] * rhs.data[1]; -+ accum.data[6] += data[9] * rhs.data[0]; -+ accum.data[7] += data[9] * rhs.data[1]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[2]; -+ accum.data[1] += data[1] * rhs.data[3]; -+ accum.data[2] += data[4] * rhs.data[2]; -+ accum.data[3] += data[4] * rhs.data[3]; -+ accum.data[4] += data[7] * rhs.data[2]; -+ accum.data[5] += data[7] * rhs.data[3]; -+ accum.data[6] += data[10] * rhs.data[2]; -+ accum.data[7] += data[10] * rhs.data[3]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[4]; -+ accum.data[1] += data[2] * rhs.data[5]; -+ accum.data[2] += data[5] * rhs.data[4]; -+ accum.data[3] += data[5] * rhs.data[5]; -+ accum.data[4] += data[8] * rhs.data[4]; -+ accum.data[5] += data[8] * rhs.data[5]; -+ accum.data[6] += data[11] * rhs.data[4]; -+ accum.data[7] += data[11] * rhs.data[5]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 4-by-2-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 4-by-3-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[3] * rhs.data[0]; -+ accum.data[4] += data[3] * rhs.data[1]; -+ accum.data[5] += data[3] * rhs.data[2]; -+ accum.data[6] += data[6] * rhs.data[0]; -+ accum.data[7] += data[6] * rhs.data[1]; -+ accum.data[8] += data[6] * rhs.data[2]; -+ accum.data[9] += data[9] * rhs.data[0]; -+ accum.data[10] += data[9] * rhs.data[1]; -+ accum.data[11] += data[9] * rhs.data[2]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[3]; -+ accum.data[1] += data[1] * rhs.data[4]; -+ accum.data[2] += data[1] * rhs.data[5]; -+ accum.data[3] += data[4] * rhs.data[3]; -+ accum.data[4] += data[4] * rhs.data[4]; -+ accum.data[5] += data[4] * rhs.data[5]; -+ accum.data[6] += data[7] * rhs.data[3]; -+ accum.data[7] += data[7] * rhs.data[4]; -+ accum.data[8] += data[7] * rhs.data[5]; -+ accum.data[9] += data[10] * rhs.data[3]; -+ accum.data[10] += data[10] * rhs.data[4]; -+ accum.data[11] += data[10] * rhs.data[5]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[6]; -+ accum.data[1] += data[2] * rhs.data[7]; -+ accum.data[2] += data[2] * rhs.data[8]; -+ accum.data[3] += data[5] * rhs.data[6]; -+ accum.data[4] += data[5] * rhs.data[7]; -+ accum.data[5] += data[5] * rhs.data[8]; -+ accum.data[6] += data[8] * rhs.data[6]; -+ accum.data[7] += data[8] * rhs.data[7]; -+ accum.data[8] += data[8] * rhs.data[8]; -+ accum.data[9] += data[11] * rhs.data[6]; -+ accum.data[10] += data[11] * rhs.data[7]; -+ accum.data[11] += data[11] * rhs.data[8]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 4-by-3-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 4-by-3-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix & operator*=(Matrix const &rhs) { -+ *this = product(rhs); -+ return *this; -+ } -+ -+ /// Matrix product of size 4-by-4-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[0] * rhs.data[3]; -+ accum.data[4] += data[3] * rhs.data[0]; -+ accum.data[5] += data[3] * rhs.data[1]; -+ accum.data[6] += data[3] * rhs.data[2]; -+ accum.data[7] += data[3] * rhs.data[3]; -+ accum.data[8] += data[6] * rhs.data[0]; -+ accum.data[9] += data[6] * rhs.data[1]; -+ accum.data[10] += data[6] * rhs.data[2]; -+ accum.data[11] += data[6] * rhs.data[3]; -+ accum.data[12] += data[9] * rhs.data[0]; -+ accum.data[13] += data[9] * rhs.data[1]; -+ accum.data[14] += data[9] * rhs.data[2]; -+ accum.data[15] += data[9] * rhs.data[3]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[4]; -+ accum.data[1] += data[1] * rhs.data[5]; -+ accum.data[2] += data[1] * rhs.data[6]; -+ accum.data[3] += data[1] * rhs.data[7]; -+ accum.data[4] += data[4] * rhs.data[4]; -+ accum.data[5] += data[4] * rhs.data[5]; -+ accum.data[6] += data[4] * rhs.data[6]; -+ accum.data[7] += data[4] * rhs.data[7]; -+ accum.data[8] += data[7] * rhs.data[4]; -+ accum.data[9] += data[7] * rhs.data[5]; -+ accum.data[10] += data[7] * rhs.data[6]; -+ accum.data[11] += data[7] * rhs.data[7]; -+ accum.data[12] += data[10] * rhs.data[4]; -+ accum.data[13] += data[10] * rhs.data[5]; -+ accum.data[14] += data[10] * rhs.data[6]; -+ accum.data[15] += data[10] * rhs.data[7]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[8]; -+ accum.data[1] += data[2] * rhs.data[9]; -+ accum.data[2] += data[2] * rhs.data[10]; -+ accum.data[3] += data[2] * rhs.data[11]; -+ accum.data[4] += data[5] * rhs.data[8]; -+ accum.data[5] += data[5] * rhs.data[9]; -+ accum.data[6] += data[5] * rhs.data[10]; -+ accum.data[7] += data[5] * rhs.data[11]; -+ accum.data[8] += data[8] * rhs.data[8]; -+ accum.data[9] += data[8] * rhs.data[9]; -+ accum.data[10] += data[8] * rhs.data[10]; -+ accum.data[11] += data[8] * rhs.data[11]; -+ accum.data[12] += data[11] * rhs.data[8]; -+ accum.data[13] += data[11] * rhs.data[9]; -+ accum.data[14] += data[11] * rhs.data[10]; -+ accum.data[15] += data[11] * rhs.data[11]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 4-by-4-by-3 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Returns the sum of elements -+ CUTLASS_HOST_DEVICE -+ Element sum(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[1]; -+ accum += data[2]; -+ accum += data[3]; -+ accum += data[4]; -+ accum += data[5]; -+ accum += data[6]; -+ accum += data[7]; -+ accum += data[8]; -+ accum += data[9]; -+ accum += data[10]; -+ accum += data[11]; -+ -+ return accum; -+ } -+ -+ /// Returns the sum of squared elements -+ CUTLASS_HOST_DEVICE -+ Element norm(Element accum = Element()) const { -+ -+ accum += data[0] * data[0]; -+ accum += data[1] * data[1]; -+ accum += data[2] * data[2]; -+ accum += data[3] * data[3]; -+ accum += data[4] * data[4]; -+ accum += data[5] * data[5]; -+ accum += data[6] * data[6]; -+ accum += data[7] * data[7]; -+ accum += data[8] * data[8]; -+ accum += data[9] * data[9]; -+ accum += data[10] * data[10]; -+ accum += data[11] * data[11]; -+ -+ return accum; -+ } -+ -+ /// Returns square root of the norm -+ CUTLASS_HOST_DEVICE -+ Element magnitude() const { -+ return fast_sqrt(norm()); -+ } -+ -+ /// Returns the sum of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Element trace(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[4]; -+ accum += data[8]; -+ -+ return accum; -+ } -+ -+}; -+ -+/// Template alias for 4-by-3 matrix -+template -+using Matrix4x3 = Matrix; -+ -+ -+/// Free funciton to infer element type from template arguments -+template -+CUTLASS_HOST_DEVICE Matrix4x3 make_Matrix4x3( -+ Element _0_0, Element _0_1, Element _0_2, -+ Element _1_0, Element _1_1, Element _1_2, -+ Element _2_0, Element _2_1, Element _2_2, -+ Element _3_0, Element _3_1, Element _3_2 -+) { -+ return Matrix4x3( -+ _0_0, _0_1, _0_2, -+ _1_0, _1_1, _1_2, -+ _2_0, _2_1, _2_2, -+ _3_0, _3_1, _3_2 -+ ); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// 4-by-4 matrix template class definition -+template -+struct Matrix { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Element data type -+ using Element = Element_; -+ -+ /// Number of rows in matrix -+ static int const kRows = 4; -+ -+ /// Number of columns in matrix -+ static int const kColumns = 4; -+ -+ /// Layout of matrix in underlying array -+ using Layout = layout::RowMajor; -+ -+ /// Number of elements in matrix -+ static int const kCount = 16; -+ -+ // -+ // Data members -+ // -+ -+ /// Elements of the matrix in row-major layout -+ Array data; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a zero matrix -+ CUTLASS_HOST_DEVICE -+ Matrix() { -+ data.clear(); -+ } -+ -+ /// Copy constructor for a 4-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Matrix(Matrix const &rhs) { -+ data = rhs.data; -+ } -+ -+ /// Constucts a 4-by-4 matrix from scalar elements -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Element _0_0, Element _0_1, Element _0_2, Element _0_3, -+ Element _1_0, Element _1_1, Element _1_2, Element _1_3, -+ Element _2_0, Element _2_1, Element _2_2, Element _2_3, -+ Element _3_0, Element _3_1, Element _3_2, Element _3_3 -+ ) { -+ -+ data[0] = _0_0; data[1] = _0_1; data[2] = _0_2; data[3] = _0_3; -+ data[4] = _1_0; data[5] = _1_1; data[6] = _1_2; data[7] = _1_3; -+ data[8] = _2_0; data[9] = _2_1; data[10] = _2_2; data[11] = _2_3; -+ data[12] = _3_0; data[13] = _3_1; data[14] = _3_2; data[15] = _3_3; -+ } -+ -+ /// Constucts a 4-by-4 matrix from row vectors -+ CUTLASS_HOST_DEVICE -+ Matrix( -+ Matrix const &row_0, -+ Matrix const &row_1, -+ Matrix const &row_2, -+ Matrix const &row_3 -+ ) { -+ data[0] = row_0.data[0]; -+ data[1] = row_0.data[1]; -+ data[2] = row_0.data[2]; -+ data[3] = row_0.data[3]; -+ data[4] = row_1.data[0]; -+ data[5] = row_1.data[1]; -+ data[6] = row_1.data[2]; -+ data[7] = row_1.data[3]; -+ data[8] = row_2.data[0]; -+ data[9] = row_2.data[1]; -+ data[10] = row_2.data[2]; -+ data[11] = row_2.data[3]; -+ data[12] = row_3.data[0]; -+ data[13] = row_3.data[1]; -+ data[14] = row_3.data[2]; -+ data[15] = row_3.data[3]; -+ } -+ -+ /// Static method to construct a 4-by-4 matrix from column vectors -+ CUTLASS_HOST_DEVICE -+ static Matrix from_columns( -+ Matrix const &column_0, -+ Matrix const &column_1, -+ Matrix const &column_2, -+ Matrix const &column_3 -+ ) { -+ Matrix result; -+ -+ result.data[0] = column_0.data[0]; -+ result.data[1] = column_1.data[0]; -+ result.data[2] = column_2.data[0]; -+ result.data[3] = column_3.data[0]; -+ result.data[4] = column_0.data[1]; -+ result.data[5] = column_1.data[1]; -+ result.data[6] = column_2.data[1]; -+ result.data[7] = column_3.data[1]; -+ result.data[8] = column_0.data[2]; -+ result.data[9] = column_1.data[2]; -+ result.data[10] = column_2.data[2]; -+ result.data[11] = column_3.data[2]; -+ result.data[12] = column_0.data[3]; -+ result.data[13] = column_1.data[3]; -+ result.data[14] = column_2.data[3]; -+ result.data[15] = column_3.data[3]; -+ return result; -+ } -+ -+ /// Constructs an identity matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix identity() { -+ Matrix m; -+ -+ m.data[0] = Element(1); -+ m.data[5] = Element(1); -+ m.data[10] = Element(1); -+ m.data[15] = Element(1); -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from a uniform element -+ CUTLASS_HOST_DEVICE -+ static Matrix uniform(Element s) { -+ Matrix m; -+ -+ m.data[0] = s; -+ m.data[1] = s; -+ m.data[2] = s; -+ m.data[3] = s; -+ m.data[4] = s; -+ m.data[5] = s; -+ m.data[6] = s; -+ m.data[7] = s; -+ m.data[8] = s; -+ m.data[9] = s; -+ m.data[10] = s; -+ m.data[11] = s; -+ m.data[12] = s; -+ m.data[13] = s; -+ m.data[14] = s; -+ m.data[15] = s; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from a uniform element 1 -+ CUTLASS_HOST_DEVICE -+ static Matrix ones() { -+ return uniform(Element(1)); -+ } -+ -+ /// Constructs a matrix from a uniform element 0 -+ CUTLASS_HOST_DEVICE -+ static Matrix zero() { -+ return Matrix(); -+ } -+ -+ /// Constructs a matrix from elements along its diagonal -+ CUTLASS_HOST_DEVICE -+ static Matrix from_diagonal(Matrix const &diag) { -+ Matrix m; -+ -+ m.data[0] = diag.data[0]; -+ m.data[5] = diag.data[1]; -+ m.data[10] = diag.data[2]; -+ m.data[15] = diag.data[3]; -+ -+ return m; -+ } -+ -+ /// Constructs a matrix from elements along its diagonal -+ CUTLASS_HOST_DEVICE -+ static Matrix from_diagonal(Matrix const &diag) { -+ Matrix m; -+ -+ m.data[0] = diag.data[0]; -+ m.data[5] = diag.data[1]; -+ m.data[10] = diag.data[2]; -+ m.data[15] = diag.data[3]; -+ -+ return m; -+ } -+ -+ /// Gets an array of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Matrix diagonal() const { -+ Matrix diag; -+ -+ diag.data[0] = data[0]; -+ diag.data[1] = data[5]; -+ diag.data[2] = data[10]; -+ diag.data[3] = data[15]; -+ -+ return diag; -+ } -+ -+ /// Returns a transposed matrix -+ CUTLASS_HOST_DEVICE -+ Matrix transpose() const { -+ Matrix mt; -+ -+ mt.data[0] = data[0]; -+ mt.data[4] = data[1]; -+ mt.data[8] = data[2]; -+ mt.data[12] = data[3]; -+ mt.data[1] = data[4]; -+ mt.data[5] = data[5]; -+ mt.data[9] = data[6]; -+ mt.data[13] = data[7]; -+ mt.data[2] = data[8]; -+ mt.data[6] = data[9]; -+ mt.data[10] = data[10]; -+ mt.data[14] = data[11]; -+ mt.data[3] = data[12]; -+ mt.data[7] = data[13]; -+ mt.data[11] = data[14]; -+ mt.data[15] = data[15]; -+ -+ return mt; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(int i, int j) const { -+ return data[i * 4 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(int i, int j) { -+ return data[i * 4 + j]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element at(Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & at(Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element &at(int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element at(int offset) const { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element operator[](Coord<2> const &coord) const { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by coordinate -+ CUTLASS_HOST_DEVICE -+ Element & operator[](Coord<2> const &coord) { -+ return at(coord[0], coord[1]); -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element & operator[](int offset) { -+ return data[offset]; -+ } -+ -+ /// Accesses an element by offset -+ CUTLASS_HOST_DEVICE -+ Element operator[](int offset) const { -+ return data[offset]; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_1x4(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ m.data[3] = data[i * 4 + j + 3]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_1x4(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ data[i * 4 + j + 3] = m.data[3]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix row(int i) const { -+ return slice_1x4(i, 0); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_row(Matrix const &v, int i = 0) { -+ return set_slice_1x4(v, i, 0); -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 4]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 4] = m.data[1]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 4]; -+ m.data[3] = data[i * 4 + j + 5]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 4] = m.data[2]; -+ data[i * 4 + j + 5] = m.data[3]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ m.data[3] = data[i * 4 + j + 4]; -+ m.data[4] = data[i * 4 + j + 5]; -+ m.data[5] = data[i * 4 + j + 6]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ data[i * 4 + j + 4] = m.data[3]; -+ data[i * 4 + j + 5] = m.data[4]; -+ data[i * 4 + j + 6] = m.data[5]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_2x4(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ m.data[3] = data[i * 4 + j + 3]; -+ m.data[4] = data[i * 4 + j + 4]; -+ m.data[5] = data[i * 4 + j + 5]; -+ m.data[6] = data[i * 4 + j + 6]; -+ m.data[7] = data[i * 4 + j + 7]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_2x4(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ data[i * 4 + j + 3] = m.data[3]; -+ data[i * 4 + j + 4] = m.data[4]; -+ data[i * 4 + j + 5] = m.data[5]; -+ data[i * 4 + j + 6] = m.data[6]; -+ data[i * 4 + j + 7] = m.data[7]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 4]; -+ m.data[2] = data[i * 4 + j + 8]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 4] = m.data[1]; -+ data[i * 4 + j + 8] = m.data[2]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 4]; -+ m.data[3] = data[i * 4 + j + 5]; -+ m.data[4] = data[i * 4 + j + 8]; -+ m.data[5] = data[i * 4 + j + 9]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 4] = m.data[2]; -+ data[i * 4 + j + 5] = m.data[3]; -+ data[i * 4 + j + 8] = m.data[4]; -+ data[i * 4 + j + 9] = m.data[5]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ m.data[3] = data[i * 4 + j + 4]; -+ m.data[4] = data[i * 4 + j + 5]; -+ m.data[5] = data[i * 4 + j + 6]; -+ m.data[6] = data[i * 4 + j + 8]; -+ m.data[7] = data[i * 4 + j + 9]; -+ m.data[8] = data[i * 4 + j + 10]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ data[i * 4 + j + 4] = m.data[3]; -+ data[i * 4 + j + 5] = m.data[4]; -+ data[i * 4 + j + 6] = m.data[5]; -+ data[i * 4 + j + 8] = m.data[6]; -+ data[i * 4 + j + 9] = m.data[7]; -+ data[i * 4 + j + 10] = m.data[8]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_3x4(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ m.data[3] = data[i * 4 + j + 3]; -+ m.data[4] = data[i * 4 + j + 4]; -+ m.data[5] = data[i * 4 + j + 5]; -+ m.data[6] = data[i * 4 + j + 6]; -+ m.data[7] = data[i * 4 + j + 7]; -+ m.data[8] = data[i * 4 + j + 8]; -+ m.data[9] = data[i * 4 + j + 9]; -+ m.data[10] = data[i * 4 + j + 10]; -+ m.data[11] = data[i * 4 + j + 11]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_3x4(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ data[i * 4 + j + 3] = m.data[3]; -+ data[i * 4 + j + 4] = m.data[4]; -+ data[i * 4 + j + 5] = m.data[5]; -+ data[i * 4 + j + 6] = m.data[6]; -+ data[i * 4 + j + 7] = m.data[7]; -+ data[i * 4 + j + 8] = m.data[8]; -+ data[i * 4 + j + 9] = m.data[9]; -+ data[i * 4 + j + 10] = m.data[10]; -+ data[i * 4 + j + 11] = m.data[11]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_4x1(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 4]; -+ m.data[2] = data[i * 4 + j + 8]; -+ m.data[3] = data[i * 4 + j + 12]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_4x1(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 4] = m.data[1]; -+ data[i * 4 + j + 8] = m.data[2]; -+ data[i * 4 + j + 12] = m.data[3]; -+ -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix column(int j) const { -+ return slice_4x1(0, j); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Matrix &set_column(Matrix const &v, int j =0) { -+ return set_slice_4x1(v, 0, j); -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_4x2(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 4]; -+ m.data[3] = data[i * 4 + j + 5]; -+ m.data[4] = data[i * 4 + j + 8]; -+ m.data[5] = data[i * 4 + j + 9]; -+ m.data[6] = data[i * 4 + j + 12]; -+ m.data[7] = data[i * 4 + j + 13]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_4x2(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 4] = m.data[2]; -+ data[i * 4 + j + 5] = m.data[3]; -+ data[i * 4 + j + 8] = m.data[4]; -+ data[i * 4 + j + 9] = m.data[5]; -+ data[i * 4 + j + 12] = m.data[6]; -+ data[i * 4 + j + 13] = m.data[7]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_4x3(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ m.data[3] = data[i * 4 + j + 4]; -+ m.data[4] = data[i * 4 + j + 5]; -+ m.data[5] = data[i * 4 + j + 6]; -+ m.data[6] = data[i * 4 + j + 8]; -+ m.data[7] = data[i * 4 + j + 9]; -+ m.data[8] = data[i * 4 + j + 10]; -+ m.data[9] = data[i * 4 + j + 12]; -+ m.data[10] = data[i * 4 + j + 13]; -+ m.data[11] = data[i * 4 + j + 14]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_4x3(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ data[i * 4 + j + 4] = m.data[3]; -+ data[i * 4 + j + 5] = m.data[4]; -+ data[i * 4 + j + 6] = m.data[5]; -+ data[i * 4 + j + 8] = m.data[6]; -+ data[i * 4 + j + 9] = m.data[7]; -+ data[i * 4 + j + 10] = m.data[8]; -+ data[i * 4 + j + 12] = m.data[9]; -+ data[i * 4 + j + 13] = m.data[10]; -+ data[i * 4 + j + 14] = m.data[11]; -+ -+ return *this; -+ } -+ -+ /// Gets a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix slice_4x4(int i = 0, int j = 0) const { -+ Matrix m; -+ -+ m.data[0] = data[i * 4 + j + 0]; -+ m.data[1] = data[i * 4 + j + 1]; -+ m.data[2] = data[i * 4 + j + 2]; -+ m.data[3] = data[i * 4 + j + 3]; -+ m.data[4] = data[i * 4 + j + 4]; -+ m.data[5] = data[i * 4 + j + 5]; -+ m.data[6] = data[i * 4 + j + 6]; -+ m.data[7] = data[i * 4 + j + 7]; -+ m.data[8] = data[i * 4 + j + 8]; -+ m.data[9] = data[i * 4 + j + 9]; -+ m.data[10] = data[i * 4 + j + 10]; -+ m.data[11] = data[i * 4 + j + 11]; -+ m.data[12] = data[i * 4 + j + 12]; -+ m.data[13] = data[i * 4 + j + 13]; -+ m.data[14] = data[i * 4 + j + 14]; -+ m.data[15] = data[i * 4 + j + 15]; -+ -+ return m; -+ } -+ -+ /// Overwrites a submatrix with optional offset -+ CUTLASS_HOST_DEVICE -+ Matrix & set_slice_4x4(Matrix const &m, int i = 0, int j = 0) { -+ -+ data[i * 4 + j + 0] = m.data[0]; -+ data[i * 4 + j + 1] = m.data[1]; -+ data[i * 4 + j + 2] = m.data[2]; -+ data[i * 4 + j + 3] = m.data[3]; -+ data[i * 4 + j + 4] = m.data[4]; -+ data[i * 4 + j + 5] = m.data[5]; -+ data[i * 4 + j + 6] = m.data[6]; -+ data[i * 4 + j + 7] = m.data[7]; -+ data[i * 4 + j + 8] = m.data[8]; -+ data[i * 4 + j + 9] = m.data[9]; -+ data[i * 4 + j + 10] = m.data[10]; -+ data[i * 4 + j + 11] = m.data[11]; -+ data[i * 4 + j + 12] = m.data[12]; -+ data[i * 4 + j + 13] = m.data[13]; -+ data[i * 4 + j + 14] = m.data[14]; -+ data[i * 4 + j + 15] = m.data[15]; -+ -+ return *this; -+ } -+ -+ /// Forms a 4-by-4 matrix by horizontally concatenating a 4-by-1 matrix with a 4-by-3 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs.at(0, 0), rhs.at(0, 0), rhs.at(0, 1), rhs.at(0, 2) -+ , lhs.at(1, 0), rhs.at(1, 0), rhs.at(1, 1), rhs.at(1, 2) -+ , lhs.at(2, 0), rhs.at(2, 0), rhs.at(2, 1), rhs.at(2, 2) -+ , lhs.at(3, 0), rhs.at(3, 0), rhs.at(3, 1), rhs.at(3, 2)); -+ } -+ -+ /// Forms a 4-by-4 matrix by horizontally concatenating a 4-by-2 matrix with a 4-by-2 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs.at(0, 0), lhs.at(0, 1), rhs.at(0, 0), rhs.at(0, 1) -+ , lhs.at(1, 0), lhs.at(1, 1), rhs.at(1, 0), rhs.at(1, 1) -+ , lhs.at(2, 0), lhs.at(2, 1), rhs.at(2, 0), rhs.at(2, 1) -+ , lhs.at(3, 0), lhs.at(3, 1), rhs.at(3, 0), rhs.at(3, 1)); -+ } -+ -+ /// Forms a 4-by-4 matrix by horizontally concatenating a 4-by-3 matrix with a 4-by-1 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { -+ return Matrix( -+ lhs.at(0, 0), lhs.at(0, 1), lhs.at(0, 2), rhs.at(0, 0) -+ , lhs.at(1, 0), lhs.at(1, 1), lhs.at(1, 2), rhs.at(1, 0) -+ , lhs.at(2, 0), lhs.at(2, 1), lhs.at(2, 2), rhs.at(2, 0) -+ , lhs.at(3, 0), lhs.at(3, 1), lhs.at(3, 2), rhs.at(3, 0)); -+ } -+ -+ /// Forms a 4-by-4 matrix by vertically concatenating a 1-by-4 matrix with a 3-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Matrix const & lower) { -+ return Matrix( -+ upper.at(0, 0), upper.at(0, 1), upper.at(0, 2), upper.at(0, 3) -+ , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2), lower.at(0, 3) -+ , lower.at(1, 0), lower.at(1, 1), lower.at(1, 2), lower.at(1, 3) -+ , lower.at(2, 0), lower.at(2, 1), lower.at(2, 2), lower.at(2, 3)); -+ } -+ -+ /// Forms a 4-by-4 matrix by vertically concatenating a 2-by-4 matrix with a 2-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Matrix const & lower) { -+ return Matrix( -+ upper.at(0, 0), upper.at(0, 1), upper.at(0, 2), upper.at(0, 3) -+ , upper.at(1, 0), upper.at(1, 1), upper.at(1, 2), upper.at(1, 3) -+ , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2), lower.at(0, 3) -+ , lower.at(1, 0), lower.at(1, 1), lower.at(1, 2), lower.at(1, 3)); -+ } -+ -+ /// Forms a 4-by-4 matrix by vertically concatenating a 3-by-4 matrix with a 1-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ static Matrix vcat(Matrix const & upper, Matrix const & lower) { -+ return Matrix( -+ upper.at(0, 0), upper.at(0, 1), upper.at(0, 2), upper.at(0, 3) -+ , upper.at(1, 0), upper.at(1, 1), upper.at(1, 2), upper.at(1, 3) -+ , upper.at(2, 0), upper.at(2, 1), upper.at(2, 2), upper.at(2, 3) -+ , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2), lower.at(0, 3)); -+ } -+ -+ /// Forms a 4-by-4 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Element A, Matrix const & B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A, B.at(0, 0), B.at(0, 1), B.at(0, 2) -+ , C.at(0, 0), D.at(0, 0), D.at(0, 1), D.at(0, 2) -+ , C.at(1, 0), D.at(1, 0), D.at(1, 1), D.at(1, 2) -+ , C.at(2, 0), D.at(2, 0), D.at(2, 1), D.at(2, 2) -+ ); -+ } -+ -+ /// Forms a 4-by-4 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A.at(0, 0), A.at(0, 1), B.at(0, 0), B.at(0, 1) -+ , C.at(0, 0), C.at(0, 1), D.at(0, 0), D.at(0, 1) -+ , C.at(1, 0), C.at(1, 1), D.at(1, 0), D.at(1, 1) -+ , C.at(2, 0), C.at(2, 1), D.at(2, 0), D.at(2, 1) -+ ); -+ } -+ -+ /// Forms a 4-by-4 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Element B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A.at(0, 0), A.at(0, 1), A.at(0, 2), B -+ , C.at(0, 0), C.at(0, 1), C.at(0, 2), D.at(0, 0) -+ , C.at(1, 0), C.at(1, 1), C.at(1, 2), D.at(1, 0) -+ , C.at(2, 0), C.at(2, 1), C.at(2, 2), D.at(2, 0) -+ ); -+ } -+ -+ /// Forms a 4-by-4 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A.at(0, 0), B.at(0, 0), B.at(0, 1), B.at(0, 2) -+ , A.at(1, 0), B.at(1, 0), B.at(1, 1), B.at(1, 2) -+ , C.at(0, 0), D.at(0, 0), D.at(0, 1), D.at(0, 2) -+ , C.at(1, 0), D.at(1, 0), D.at(1, 1), D.at(1, 2) -+ ); -+ } -+ -+ /// Forms a 4-by-4 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A.at(0, 0), A.at(0, 1), B.at(0, 0), B.at(0, 1) -+ , A.at(1, 0), A.at(1, 1), B.at(1, 0), B.at(1, 1) -+ , C.at(0, 0), C.at(0, 1), D.at(0, 0), D.at(0, 1) -+ , C.at(1, 0), C.at(1, 1), D.at(1, 0), D.at(1, 1) -+ ); -+ } -+ -+ /// Forms a 4-by-4 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A.at(0, 0), A.at(0, 1), A.at(0, 2), B.at(0, 0) -+ , A.at(1, 0), A.at(1, 1), A.at(1, 2), B.at(1, 0) -+ , C.at(0, 0), C.at(0, 1), C.at(0, 2), D.at(0, 0) -+ , C.at(1, 0), C.at(1, 1), C.at(1, 2), D.at(1, 0) -+ ); -+ } -+ -+ /// Forms a 4-by-4 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Element C, Matrix const & D) { -+ return Matrix( -+ A.at(0, 0), B.at(0, 0), B.at(0, 1), B.at(0, 2) -+ , A.at(1, 0), B.at(1, 0), B.at(1, 1), B.at(1, 2) -+ , A.at(2, 0), B.at(2, 0), B.at(2, 1), B.at(2, 2) -+ , C, D.at(0, 0), D.at(0, 1), D.at(0, 2) -+ ); -+ } -+ -+ /// Forms a 4-by-4 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Matrix const & C, Matrix const & D) { -+ return Matrix( -+ A.at(0, 0), A.at(0, 1), B.at(0, 0), B.at(0, 1) -+ , A.at(1, 0), A.at(1, 1), B.at(1, 0), B.at(1, 1) -+ , A.at(2, 0), A.at(2, 1), B.at(2, 0), B.at(2, 1) -+ , C.at(0, 0), C.at(0, 1), D.at(0, 0), D.at(0, 1) -+ ); -+ } -+ -+ /// Forms a 4-by-4 matrix by concatenating four components -+ CUTLASS_HOST_DEVICE -+ static Matrix block( -+ Matrix const & A, Matrix const & B, -+ Matrix const & C, Element D) { -+ return Matrix( -+ A.at(0, 0), A.at(0, 1), A.at(0, 2), B.at(0, 0) -+ , A.at(1, 0), A.at(1, 1), A.at(1, 2), B.at(1, 0) -+ , A.at(2, 0), A.at(2, 1), A.at(2, 2), B.at(2, 0) -+ , C.at(0, 0), C.at(0, 1), C.at(0, 2), D -+ ); -+ } -+ -+ /// Elementwise add operator (4-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix add(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] + rhs.data[0]; -+ result.data[1] = data[1] + rhs.data[1]; -+ result.data[2] = data[2] + rhs.data[2]; -+ result.data[3] = data[3] + rhs.data[3]; -+ -+ result.data[4] = data[4] + rhs.data[4]; -+ result.data[5] = data[5] + rhs.data[5]; -+ result.data[6] = data[6] + rhs.data[6]; -+ result.data[7] = data[7] + rhs.data[7]; -+ -+ result.data[8] = data[8] + rhs.data[8]; -+ result.data[9] = data[9] + rhs.data[9]; -+ result.data[10] = data[10] + rhs.data[10]; -+ result.data[11] = data[11] + rhs.data[11]; -+ -+ result.data[12] = data[12] + rhs.data[12]; -+ result.data[13] = data[13] + rhs.data[13]; -+ result.data[14] = data[14] + rhs.data[14]; -+ result.data[15] = data[15] + rhs.data[15]; -+ -+ return result; -+ } -+ -+ /// Elementwise add operator (4-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator +(Matrix const &rhs) const { -+ return add(rhs); -+ } -+ -+ /// Elementwise add operator (4-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator +=(Matrix const &rhs) { -+ -+ data[0] += rhs.data[0]; -+ data[1] += rhs.data[1]; -+ data[2] += rhs.data[2]; -+ data[3] += rhs.data[3]; -+ -+ data[4] += rhs.data[4]; -+ data[5] += rhs.data[5]; -+ data[6] += rhs.data[6]; -+ data[7] += rhs.data[7]; -+ -+ data[8] += rhs.data[8]; -+ data[9] += rhs.data[9]; -+ data[10] += rhs.data[10]; -+ data[11] += rhs.data[11]; -+ -+ data[12] += rhs.data[12]; -+ data[13] += rhs.data[13]; -+ data[14] += rhs.data[14]; -+ data[15] += rhs.data[15]; -+ -+ return *this; -+ } -+ -+ /// Elementwise subtract operator (4-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix subtract(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] - rhs.data[0]; -+ result.data[1] = data[1] - rhs.data[1]; -+ result.data[2] = data[2] - rhs.data[2]; -+ result.data[3] = data[3] - rhs.data[3]; -+ -+ result.data[4] = data[4] - rhs.data[4]; -+ result.data[5] = data[5] - rhs.data[5]; -+ result.data[6] = data[6] - rhs.data[6]; -+ result.data[7] = data[7] - rhs.data[7]; -+ -+ result.data[8] = data[8] - rhs.data[8]; -+ result.data[9] = data[9] - rhs.data[9]; -+ result.data[10] = data[10] - rhs.data[10]; -+ result.data[11] = data[11] - rhs.data[11]; -+ -+ result.data[12] = data[12] - rhs.data[12]; -+ result.data[13] = data[13] - rhs.data[13]; -+ result.data[14] = data[14] - rhs.data[14]; -+ result.data[15] = data[15] - rhs.data[15]; -+ -+ return result; -+ } -+ -+ /// Elementwise subtract operator (4-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator -(Matrix const &rhs) const { -+ return subtract(rhs); -+ } -+ -+ /// Elementwise subtract operator (4-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator -=(Matrix const &rhs) { -+ -+ data[0] -= rhs.data[0]; -+ data[1] -= rhs.data[1]; -+ data[2] -= rhs.data[2]; -+ data[3] -= rhs.data[3]; -+ -+ data[4] -= rhs.data[4]; -+ data[5] -= rhs.data[5]; -+ data[6] -= rhs.data[6]; -+ data[7] -= rhs.data[7]; -+ -+ data[8] -= rhs.data[8]; -+ data[9] -= rhs.data[9]; -+ data[10] -= rhs.data[10]; -+ data[11] -= rhs.data[11]; -+ -+ data[12] -= rhs.data[12]; -+ data[13] -= rhs.data[13]; -+ data[14] -= rhs.data[14]; -+ data[15] -= rhs.data[15]; -+ -+ return *this; -+ } -+ -+ /// Elementwise multiply operator (4-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * rhs.data[0]; -+ result.data[1] = data[1] * rhs.data[1]; -+ result.data[2] = data[2] * rhs.data[2]; -+ result.data[3] = data[3] * rhs.data[3]; -+ -+ result.data[4] = data[4] * rhs.data[4]; -+ result.data[5] = data[5] * rhs.data[5]; -+ result.data[6] = data[6] * rhs.data[6]; -+ result.data[7] = data[7] * rhs.data[7]; -+ -+ result.data[8] = data[8] * rhs.data[8]; -+ result.data[9] = data[9] * rhs.data[9]; -+ result.data[10] = data[10] * rhs.data[10]; -+ result.data[11] = data[11] * rhs.data[11]; -+ -+ result.data[12] = data[12] * rhs.data[12]; -+ result.data[13] = data[13] * rhs.data[13]; -+ result.data[14] = data[14] * rhs.data[14]; -+ result.data[15] = data[15] * rhs.data[15]; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (4-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix multiply(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] * s; -+ result.data[1] = data[1] * s; -+ result.data[2] = data[2] * s; -+ result.data[3] = data[3] * s; -+ -+ result.data[4] = data[4] * s; -+ result.data[5] = data[5] * s; -+ result.data[6] = data[6] * s; -+ result.data[7] = data[7] * s; -+ -+ result.data[8] = data[8] * s; -+ result.data[9] = data[9] * s; -+ result.data[10] = data[10] * s; -+ result.data[11] = data[11] * s; -+ -+ result.data[12] = data[12] * s; -+ result.data[13] = data[13] * s; -+ result.data[14] = data[14] * s; -+ result.data[15] = data[15] * s; -+ -+ return result; -+ } -+ -+ /// Scalar multiply operator (4-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator *(Element const &s) const { -+ return multiply(s); -+ } -+ -+ /// Scalar multiply operator (4-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator *=(Element const &s) { -+ -+ data[0] *= s; -+ data[1] *= s; -+ data[2] *= s; -+ data[3] *= s; -+ -+ data[4] *= s; -+ data[5] *= s; -+ data[6] *= s; -+ data[7] *= s; -+ -+ data[8] *= s; -+ data[9] *= s; -+ data[10] *= s; -+ data[11] *= s; -+ -+ data[12] *= s; -+ data[13] *= s; -+ data[14] *= s; -+ data[15] *= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (4-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Matrix const &rhs) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / rhs.data[0]; -+ result.data[1] = data[1] / rhs.data[1]; -+ result.data[2] = data[2] / rhs.data[2]; -+ result.data[3] = data[3] / rhs.data[3]; -+ -+ result.data[4] = data[4] / rhs.data[4]; -+ result.data[5] = data[5] / rhs.data[5]; -+ result.data[6] = data[6] / rhs.data[6]; -+ result.data[7] = data[7] / rhs.data[7]; -+ -+ result.data[8] = data[8] / rhs.data[8]; -+ result.data[9] = data[9] / rhs.data[9]; -+ result.data[10] = data[10] / rhs.data[10]; -+ result.data[11] = data[11] / rhs.data[11]; -+ -+ result.data[12] = data[12] / rhs.data[12]; -+ result.data[13] = data[13] / rhs.data[13]; -+ result.data[14] = data[14] / rhs.data[14]; -+ result.data[15] = data[15] / rhs.data[15]; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (4-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix divide(Element const &s) const { -+ -+ Matrix result; -+ -+ result.data[0] = data[0] / s; -+ result.data[1] = data[1] / s; -+ result.data[2] = data[2] / s; -+ result.data[3] = data[3] / s; -+ -+ result.data[4] = data[4] / s; -+ result.data[5] = data[5] / s; -+ result.data[6] = data[6] / s; -+ result.data[7] = data[7] / s; -+ -+ result.data[8] = data[8] / s; -+ result.data[9] = data[9] / s; -+ result.data[10] = data[10] / s; -+ result.data[11] = data[11] / s; -+ -+ result.data[12] = data[12] / s; -+ result.data[13] = data[13] / s; -+ result.data[14] = data[14] / s; -+ result.data[15] = data[15] / s; -+ -+ return result; -+ } -+ -+ /// Scalar divide operator (4-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Element const &s) const { -+ return divide(s); -+ } -+ -+ /// Scalar divide operator (4-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Element const &s) { -+ -+ data[0] /= s; -+ data[1] /= s; -+ data[2] /= s; -+ data[3] /= s; -+ -+ data[4] /= s; -+ data[5] /= s; -+ data[6] /= s; -+ data[7] /= s; -+ -+ data[8] /= s; -+ data[9] /= s; -+ data[10] /= s; -+ data[11] /= s; -+ -+ data[12] /= s; -+ data[13] /= s; -+ data[14] /= s; -+ data[15] /= s; -+ -+ return *this; -+ } -+ -+ /// Elementwise divide operator (4-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix operator /(Matrix const &rhs) const { -+ return divide(rhs); -+ } -+ -+ /// Elementwise divide operator (4-by-4) -+ CUTLASS_HOST_DEVICE -+ Matrix & operator /=(Matrix const &rhs) { -+ -+ data[0] /= rhs.data[0]; -+ data[1] /= rhs.data[1]; -+ data[2] /= rhs.data[2]; -+ data[3] /= rhs.data[3]; -+ -+ data[4] /= rhs.data[4]; -+ data[5] /= rhs.data[5]; -+ data[6] /= rhs.data[6]; -+ data[7] /= rhs.data[7]; -+ -+ data[8] /= rhs.data[8]; -+ data[9] /= rhs.data[9]; -+ data[10] /= rhs.data[10]; -+ data[11] /= rhs.data[11]; -+ -+ data[12] /= rhs.data[12]; -+ data[13] /= rhs.data[13]; -+ data[14] /= rhs.data[14]; -+ data[15] /= rhs.data[15]; -+ -+ return *this; -+ } -+ -+ /// Negates each element of the matrix -+ CUTLASS_HOST_DEVICE -+ Matrix operator-() const { -+ Matrix m; -+ -+ m.data[0] = -m.data[0]; -+ m.data[1] = -m.data[1]; -+ m.data[2] = -m.data[2]; -+ m.data[3] = -m.data[3]; -+ m.data[4] = -m.data[4]; -+ m.data[5] = -m.data[5]; -+ m.data[6] = -m.data[6]; -+ m.data[7] = -m.data[7]; -+ m.data[8] = -m.data[8]; -+ m.data[9] = -m.data[9]; -+ m.data[10] = -m.data[10]; -+ m.data[11] = -m.data[11]; -+ m.data[12] = -m.data[12]; -+ m.data[13] = -m.data[13]; -+ m.data[14] = -m.data[14]; -+ m.data[15] = -m.data[15]; -+ -+ return m; -+ } -+ -+ /// Matrix product of size 4-by-1-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[4] * rhs.data[0]; -+ accum.data[2] += data[8] * rhs.data[0]; -+ accum.data[3] += data[12] * rhs.data[0]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[1]; -+ accum.data[1] += data[5] * rhs.data[1]; -+ accum.data[2] += data[9] * rhs.data[1]; -+ accum.data[3] += data[13] * rhs.data[1]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[2]; -+ accum.data[1] += data[6] * rhs.data[2]; -+ accum.data[2] += data[10] * rhs.data[2]; -+ accum.data[3] += data[14] * rhs.data[2]; -+ -+ // k=3 -+ accum.data[0] += data[3] * rhs.data[3]; -+ accum.data[1] += data[7] * rhs.data[3]; -+ accum.data[2] += data[11] * rhs.data[3]; -+ accum.data[3] += data[15] * rhs.data[3]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 4-by-1-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 4-by-2-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[4] * rhs.data[0]; -+ accum.data[3] += data[4] * rhs.data[1]; -+ accum.data[4] += data[8] * rhs.data[0]; -+ accum.data[5] += data[8] * rhs.data[1]; -+ accum.data[6] += data[12] * rhs.data[0]; -+ accum.data[7] += data[12] * rhs.data[1]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[2]; -+ accum.data[1] += data[1] * rhs.data[3]; -+ accum.data[2] += data[5] * rhs.data[2]; -+ accum.data[3] += data[5] * rhs.data[3]; -+ accum.data[4] += data[9] * rhs.data[2]; -+ accum.data[5] += data[9] * rhs.data[3]; -+ accum.data[6] += data[13] * rhs.data[2]; -+ accum.data[7] += data[13] * rhs.data[3]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[4]; -+ accum.data[1] += data[2] * rhs.data[5]; -+ accum.data[2] += data[6] * rhs.data[4]; -+ accum.data[3] += data[6] * rhs.data[5]; -+ accum.data[4] += data[10] * rhs.data[4]; -+ accum.data[5] += data[10] * rhs.data[5]; -+ accum.data[6] += data[14] * rhs.data[4]; -+ accum.data[7] += data[14] * rhs.data[5]; -+ -+ // k=3 -+ accum.data[0] += data[3] * rhs.data[6]; -+ accum.data[1] += data[3] * rhs.data[7]; -+ accum.data[2] += data[7] * rhs.data[6]; -+ accum.data[3] += data[7] * rhs.data[7]; -+ accum.data[4] += data[11] * rhs.data[6]; -+ accum.data[5] += data[11] * rhs.data[7]; -+ accum.data[6] += data[15] * rhs.data[6]; -+ accum.data[7] += data[15] * rhs.data[7]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 4-by-2-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 4-by-3-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[4] * rhs.data[0]; -+ accum.data[4] += data[4] * rhs.data[1]; -+ accum.data[5] += data[4] * rhs.data[2]; -+ accum.data[6] += data[8] * rhs.data[0]; -+ accum.data[7] += data[8] * rhs.data[1]; -+ accum.data[8] += data[8] * rhs.data[2]; -+ accum.data[9] += data[12] * rhs.data[0]; -+ accum.data[10] += data[12] * rhs.data[1]; -+ accum.data[11] += data[12] * rhs.data[2]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[3]; -+ accum.data[1] += data[1] * rhs.data[4]; -+ accum.data[2] += data[1] * rhs.data[5]; -+ accum.data[3] += data[5] * rhs.data[3]; -+ accum.data[4] += data[5] * rhs.data[4]; -+ accum.data[5] += data[5] * rhs.data[5]; -+ accum.data[6] += data[9] * rhs.data[3]; -+ accum.data[7] += data[9] * rhs.data[4]; -+ accum.data[8] += data[9] * rhs.data[5]; -+ accum.data[9] += data[13] * rhs.data[3]; -+ accum.data[10] += data[13] * rhs.data[4]; -+ accum.data[11] += data[13] * rhs.data[5]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[6]; -+ accum.data[1] += data[2] * rhs.data[7]; -+ accum.data[2] += data[2] * rhs.data[8]; -+ accum.data[3] += data[6] * rhs.data[6]; -+ accum.data[4] += data[6] * rhs.data[7]; -+ accum.data[5] += data[6] * rhs.data[8]; -+ accum.data[6] += data[10] * rhs.data[6]; -+ accum.data[7] += data[10] * rhs.data[7]; -+ accum.data[8] += data[10] * rhs.data[8]; -+ accum.data[9] += data[14] * rhs.data[6]; -+ accum.data[10] += data[14] * rhs.data[7]; -+ accum.data[11] += data[14] * rhs.data[8]; -+ -+ // k=3 -+ accum.data[0] += data[3] * rhs.data[9]; -+ accum.data[1] += data[3] * rhs.data[10]; -+ accum.data[2] += data[3] * rhs.data[11]; -+ accum.data[3] += data[7] * rhs.data[9]; -+ accum.data[4] += data[7] * rhs.data[10]; -+ accum.data[5] += data[7] * rhs.data[11]; -+ accum.data[6] += data[11] * rhs.data[9]; -+ accum.data[7] += data[11] * rhs.data[10]; -+ accum.data[8] += data[11] * rhs.data[11]; -+ accum.data[9] += data[15] * rhs.data[9]; -+ accum.data[10] += data[15] * rhs.data[10]; -+ accum.data[11] += data[15] * rhs.data[11]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 4-by-3-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 4-by-4-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix product( -+ Matrix const &rhs, -+ Matrix accum = Matrix() -+ ) const { -+ -+ // k=0 -+ accum.data[0] += data[0] * rhs.data[0]; -+ accum.data[1] += data[0] * rhs.data[1]; -+ accum.data[2] += data[0] * rhs.data[2]; -+ accum.data[3] += data[0] * rhs.data[3]; -+ accum.data[4] += data[4] * rhs.data[0]; -+ accum.data[5] += data[4] * rhs.data[1]; -+ accum.data[6] += data[4] * rhs.data[2]; -+ accum.data[7] += data[4] * rhs.data[3]; -+ accum.data[8] += data[8] * rhs.data[0]; -+ accum.data[9] += data[8] * rhs.data[1]; -+ accum.data[10] += data[8] * rhs.data[2]; -+ accum.data[11] += data[8] * rhs.data[3]; -+ accum.data[12] += data[12] * rhs.data[0]; -+ accum.data[13] += data[12] * rhs.data[1]; -+ accum.data[14] += data[12] * rhs.data[2]; -+ accum.data[15] += data[12] * rhs.data[3]; -+ -+ // k=1 -+ accum.data[0] += data[1] * rhs.data[4]; -+ accum.data[1] += data[1] * rhs.data[5]; -+ accum.data[2] += data[1] * rhs.data[6]; -+ accum.data[3] += data[1] * rhs.data[7]; -+ accum.data[4] += data[5] * rhs.data[4]; -+ accum.data[5] += data[5] * rhs.data[5]; -+ accum.data[6] += data[5] * rhs.data[6]; -+ accum.data[7] += data[5] * rhs.data[7]; -+ accum.data[8] += data[9] * rhs.data[4]; -+ accum.data[9] += data[9] * rhs.data[5]; -+ accum.data[10] += data[9] * rhs.data[6]; -+ accum.data[11] += data[9] * rhs.data[7]; -+ accum.data[12] += data[13] * rhs.data[4]; -+ accum.data[13] += data[13] * rhs.data[5]; -+ accum.data[14] += data[13] * rhs.data[6]; -+ accum.data[15] += data[13] * rhs.data[7]; -+ -+ // k=2 -+ accum.data[0] += data[2] * rhs.data[8]; -+ accum.data[1] += data[2] * rhs.data[9]; -+ accum.data[2] += data[2] * rhs.data[10]; -+ accum.data[3] += data[2] * rhs.data[11]; -+ accum.data[4] += data[6] * rhs.data[8]; -+ accum.data[5] += data[6] * rhs.data[9]; -+ accum.data[6] += data[6] * rhs.data[10]; -+ accum.data[7] += data[6] * rhs.data[11]; -+ accum.data[8] += data[10] * rhs.data[8]; -+ accum.data[9] += data[10] * rhs.data[9]; -+ accum.data[10] += data[10] * rhs.data[10]; -+ accum.data[11] += data[10] * rhs.data[11]; -+ accum.data[12] += data[14] * rhs.data[8]; -+ accum.data[13] += data[14] * rhs.data[9]; -+ accum.data[14] += data[14] * rhs.data[10]; -+ accum.data[15] += data[14] * rhs.data[11]; -+ -+ // k=3 -+ accum.data[0] += data[3] * rhs.data[12]; -+ accum.data[1] += data[3] * rhs.data[13]; -+ accum.data[2] += data[3] * rhs.data[14]; -+ accum.data[3] += data[3] * rhs.data[15]; -+ accum.data[4] += data[7] * rhs.data[12]; -+ accum.data[5] += data[7] * rhs.data[13]; -+ accum.data[6] += data[7] * rhs.data[14]; -+ accum.data[7] += data[7] * rhs.data[15]; -+ accum.data[8] += data[11] * rhs.data[12]; -+ accum.data[9] += data[11] * rhs.data[13]; -+ accum.data[10] += data[11] * rhs.data[14]; -+ accum.data[11] += data[11] * rhs.data[15]; -+ accum.data[12] += data[15] * rhs.data[12]; -+ accum.data[13] += data[15] * rhs.data[13]; -+ accum.data[14] += data[15] * rhs.data[14]; -+ accum.data[15] += data[15] * rhs.data[15]; -+ -+ return accum; -+ } -+ -+ /// Matrix product of size 4-by-4-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix operator*(Matrix const &rhs) const { -+ return product(rhs); -+ } -+ -+ /// Matrix product of size 4-by-4-by-4 -+ CUTLASS_HOST_DEVICE -+ Matrix & operator*=(Matrix const &rhs) { -+ *this = product(rhs); -+ return *this; -+ } -+ -+ /// Returns the sum of elements -+ CUTLASS_HOST_DEVICE -+ Element sum(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[1]; -+ accum += data[2]; -+ accum += data[3]; -+ accum += data[4]; -+ accum += data[5]; -+ accum += data[6]; -+ accum += data[7]; -+ accum += data[8]; -+ accum += data[9]; -+ accum += data[10]; -+ accum += data[11]; -+ accum += data[12]; -+ accum += data[13]; -+ accum += data[14]; -+ accum += data[15]; -+ -+ return accum; -+ } -+ -+ /// Returns the sum of squared elements -+ CUTLASS_HOST_DEVICE -+ Element norm(Element accum = Element()) const { -+ -+ accum += data[0] * data[0]; -+ accum += data[1] * data[1]; -+ accum += data[2] * data[2]; -+ accum += data[3] * data[3]; -+ accum += data[4] * data[4]; -+ accum += data[5] * data[5]; -+ accum += data[6] * data[6]; -+ accum += data[7] * data[7]; -+ accum += data[8] * data[8]; -+ accum += data[9] * data[9]; -+ accum += data[10] * data[10]; -+ accum += data[11] * data[11]; -+ accum += data[12] * data[12]; -+ accum += data[13] * data[13]; -+ accum += data[14] * data[14]; -+ accum += data[15] * data[15]; -+ -+ return accum; -+ } -+ -+ /// Returns square root of the norm -+ CUTLASS_HOST_DEVICE -+ Element magnitude() const { -+ return fast_sqrt(norm()); -+ } -+ -+ /// Returns the sum of diagonal elements -+ CUTLASS_HOST_DEVICE -+ Element trace(Element accum = Element()) const { -+ -+ accum += data[0]; -+ accum += data[5]; -+ accum += data[10]; -+ accum += data[15]; -+ -+ return accum; -+ } -+ -+ /// Returns 4-by-4 rotation matrix around the X axis -+ CUTLASS_HOST_DEVICE -+ static Matrix rotation_X(Element theta) { -+ Matrix m = identity(); -+ -+ Element c = fast_cos(theta); -+ Element s = fast_sin(theta); -+ -+ m.at(1, 1) = c; -+ m.at(1, 2) = -s; -+ m.at(2, 1) = s; -+ m.at(2, 2) = c; -+ -+ return m; -+ } -+ -+ /// Returns 4-by-4 rotation matrix around the Y axis -+ CUTLASS_HOST_DEVICE -+ static Matrix rotation_Y(Element theta) { -+ Matrix m = identity(); -+ -+ Element c = fast_cos(theta); -+ Element s = fast_sin(theta); -+ -+ m.at(0, 0) = c; -+ m.at(2, 0) = -s; -+ m.at(0, 2) = s; -+ m.at(2, 2) = c; -+ -+ return m; -+ } -+ -+ /// Returns 4-by-4 rotation matrix around the Z axis -+ CUTLASS_HOST_DEVICE -+ static Matrix rotation_Z(Element theta) { -+ Matrix m = Matrix::identity(); -+ -+ Element c = fast_cos(theta); -+ Element s = fast_sin(theta); -+ -+ m.at(0, 0) = c; -+ m.at(0, 1) = -s; -+ m.at(1, 0) = s; -+ m.at(1, 1) = c; -+ -+ return m; -+ } -+ -+ /// Returns a 4-by-4 rotation matrix around a unit-length axis -+ CUTLASS_HOST_DEVICE -+ static Matrix rotation(Element theta, Matrix const &u) { -+ Element x = u.data[0]; -+ Element y = u.data[1]; -+ Element z = u.data[2]; -+ -+ Element c = fast_cos(theta); -+ Element s = fast_sin(theta); -+ -+ Element one_minus_cos = Element(1) - fast_cos(theta); -+ -+ Matrix m; -+ -+ m.set_slice3x3({ -+ c + x * x * one_minus_cos, x * y * one_minus_cos - z * s, x * z * one_minus_cos + y * s, -+ y * x * one_minus_cos * z * s, c + y * y * one_minus_cos, y * z * one_minus_cos - x * s, -+ z * x * one_minus_cos - y * s, z * y * one_minus_cos + x * s, c + z * z * one_minus_cos -+ }); -+ -+ return m; -+ } -+ -+ /// Returns a 4-by-4 reflection about the plane specified by the -+ /// unit-length normal vector n_unit -+ CUTLASS_HOST_DEVICE -+ static Matrix reflection(Matrix const &n_unit) { -+ -+ Element a = n_unit.data[0]; -+ Element b = n_unit.data[1]; -+ Element c = n_unit.data[2]; -+ -+ Matrix m = Matrix::identity(); -+ -+ m.set_slice3x3({ -+ Element(1) - Element(2) * a * a, Element(-2) * a * b, Element(-2) * a * c, -+ Element(-2) * a * b, Element(1) - Element(2) * b * b, Element(-2) * b * c, -+ Element(-2) * a * c, Element(-2) * b * c, Element(1) - Element(2) * c * c -+ }); -+ -+ return m; -+ } -+ -+ /// Returns a perspective projection matrix typical of OpenGL applications -+ CUTLASS_HOST_DEVICE -+ static Matrix perspective(Element near_plane, Element far_plane, Element fovH, Element fovV) { -+ Element aspect = fovH / fovV; -+ Element f = Element(cos(fovV)) / Element(fovH); -+ Element Q = near_plane - far_plane; -+ -+ return Matrix( -+ f / aspect, 0, 0, 0, -+ 0, f, 0, 0, -+ 0, 0, (near_plane + far_plane) / Q, Element(2) * far_plane * near_plane / Q, -+ 0, 0, -1, 0 -+ ); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Matrix translation(Matrix const &v) { -+ return Matrix( -+ 1, 0, 0, v.data[0], -+ 0, 1, 0, v.data[1], -+ 0, 0, 1, v.data[2], -+ 0, 0, 0, 1 -+ ); -+ } -+ -+ /// Computes the determinant of a 4-by-4 matrix -+ CUTLASS_HOST_DEVICE -+ Element determinant(Element accum = Element()) const { -+ -+ accum += at(0, 0) * Matrix({ at(1, 1), at(1, 2), at(1, 3), at(2, 1), at(2, 2), at(2, 3), at(3, 1), at(3, 2), at(3, 3) }).determinant(); -+ accum -= at(0, 1) * Matrix({ at(1, 0), at(1, 2), at(1, 3), at(2, 0), at(2, 2), at(2, 3), at(3, 0), at(3, 2), at(3, 3) }).determinant(); -+ accum += at(0, 2) * Matrix({ at(1, 0), at(1, 1), at(1, 3), at(2, 0), at(2, 1), at(2, 3), at(3, 0), at(3, 1), at(3, 3) }).determinant(); -+ accum -= at(0, 3) * Matrix({ at(1, 0), at(1, 1), at(1, 2), at(2, 0), at(2, 1), at(2, 2), at(3, 0), at(3, 1), at(3, 2) }).determinant(); -+ -+ return accum; -+ } -+ -+ /// Computes the inverse of a 4-by-4 matrix (ignores the optional argument) -+ CUTLASS_HOST_DEVICE -+ Matrix inverse(Element ignore = 1) const { -+ Matrix B = slice_2x2(0, 2); -+ Matrix A = slice_2x2(0, 0); -+ Matrix C = slice_2x2(2, 0); -+ Matrix D = slice_2x2(2, 2); -+ -+ Matrix D_inv = D.inverse(); -+ -+ Matrix E = (A - B * D_inv * C).inverse(); -+ -+ return Matrix::block( -+ E, -E * B * D_inv, -+ -D_inv * C * E, D_inv + D_inv * C * E * B * D_inv -+ ); -+ } -+ -+}; -+ -+/// Template alias for 4-by-4 matrix -+template -+using Matrix4x4 = Matrix; -+ -+ -+/// Free funciton to infer element type from template arguments -+template -+CUTLASS_HOST_DEVICE Matrix4x4 make_Matrix4x4( -+ Element _0_0, Element _0_1, Element _0_2, Element _0_3, -+ Element _1_0, Element _1_1, Element _1_2, Element _1_3, -+ Element _2_0, Element _2_1, Element _2_2, Element _2_3, -+ Element _3_0, Element _3_1, Element _3_2, Element _3_3 -+) { -+ return Matrix4x4( -+ _0_0, _0_1, _0_2, _0_3, -+ _1_0, _1_1, _1_2, _1_3, -+ _2_0, _2_1, _2_2, _2_3, -+ _3_0, _3_1, _3_2, _3_3 -+ ); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Elementwise scalar multiplication -+template -+CUTLASS_HOST_DEVICE -+Matrix operator*(Element s, Matrix const &rhs) { -+ return rhs.multiply(s); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/matrix_coord.h b/3rdparty/cutlass/include/cutlass/matrix_coord.h -new file mode 100644 -index 0000000..1563575 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/matrix_coord.h -@@ -0,0 +1,164 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines a canonical coordinate for rank=2 matrices offering named indices. -+*/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/coord.h" -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// MatrixCoord wraps Coord<2, int> to provide a helper for accessing named dimensions. Classes -+/// expecting a coordinate in the rank=2 index space of a matrix should use MatrixCoord. -+struct MatrixCoord : public Coord<2, int> { -+ -+public: -+ -+ /// Integer-valued index -+ using Index = int; -+ -+ /// Base type is a Coord of rank=2 -+ using Base = Coord<2, Index>; -+ -+ /// LongIndex type -+ using LongIndex = typename Base::LongIndex; -+ -+private: -+ -+ /// Rows dimension -+ static int const kRow = 0; -+ -+ /// Columns dimension -+ static int const kColumn = 1; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ MatrixCoord() { } -+ -+ /// Constructs from Coord<2> -+ CUTLASS_HOST_DEVICE -+ MatrixCoord(Coord<2, Index> const &coord): Base(coord) { } -+ -+ /// Helper to construct from a row and column -+ CUTLASS_HOST_DEVICE -+ MatrixCoord(Index row, Index column): Base(make_Coord(row, column)) { } -+ -+ /// Helper to construct from a row and column, which are LongIndex based -+ CUTLASS_HOST_DEVICE -+ MatrixCoord(LongIndex row, LongIndex column): Base(make_Coord(Index(row), Index(column))) { } -+ -+ /// Returns the row of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index const & row() const { return this->at(kRow); } -+ -+ /// Returns the row of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index & row() { return this->at(kRow); } -+ -+ /// Returns the column of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index const & column() const { return this->at(kColumn); } -+ -+ /// Returns the column of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index & column() { return this->at(kColumn); } -+ -+ // -+ // Coord operators -+ // -+ -+ /// Element-wise addition -+ CUTLASS_HOST_DEVICE -+ MatrixCoord operator+(Base const& b) const { -+ return MatrixCoord(Base::operator+(b)); -+ } -+ -+ /// Element-wise subtraction -+ CUTLASS_HOST_DEVICE -+ MatrixCoord operator-(Base const& b) const { -+ return MatrixCoord(Base::operator-(b)); -+ } -+ -+ /// Element-wise multiplication -+ CUTLASS_HOST_DEVICE -+ MatrixCoord operator*(Base const& b) const { -+ return MatrixCoord(Base::operator*(b)); -+ } -+ -+ /// Element-wise division -+ CUTLASS_HOST_DEVICE -+ MatrixCoord operator/(Base const& b) const { -+ return MatrixCoord(Base::operator/(b)); -+ } -+ -+ /// In-place addition -+ CUTLASS_HOST_DEVICE -+ MatrixCoord& operator+=(Base const& b) { -+ Base::operator+=(b); -+ return *this; -+ } -+ -+ /// In-place subtraction -+ CUTLASS_HOST_DEVICE -+ MatrixCoord& operator-=(Base const& b) { -+ Base::operator-=(b); -+ return *this; -+ } -+ -+ /// In-place multiplication -+ CUTLASS_HOST_DEVICE -+ MatrixCoord& operator*=(Base const& b) { -+ Base::operator*=(b); -+ return *this; -+ } -+ -+ /// In-place division -+ CUTLASS_HOST_DEVICE -+ MatrixCoord& operator/=(Base const& b) { -+ Base::operator/=(b); -+ return *this; -+ } -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/matrix_shape.h b/3rdparty/cutlass/include/cutlass/matrix_shape.h -new file mode 100644 -index 0000000..deae47c ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/matrix_shape.h -@@ -0,0 +1,65 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines a Shape template for matrix tiles -+*/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/coord.h" -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Describes the size of a matrix tile -+template < -+ int Row_, ///< rows of a matrix -+ int Column_ ///< columns of a matrix -+> -+struct MatrixShape { -+ static int const kRow = Row_; ///< rows of a matrix -+ static int const kColumn = Column_; ///< columns of a matrix -+ static int const kCount = Row_ * Column_; ///< total number of elements in a matrix -+ -+ // -+ // Static member functions -+ // -+ -+ CUTLASS_HOST_DEVICE -+ static Coord<2> toCoord() { -+ return make_Coord(kRow, kColumn); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/numeric_conversion.h b/3rdparty/cutlass/include/cutlass/numeric_conversion.h -new file mode 100644 -index 0000000..3095cec ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/numeric_conversion.h -@@ -0,0 +1,2481 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief Boost-like numeric conversion operator for CUTLASS numeric types -+*/ -+#pragma once -+ -+#if !defined(__CUDACC_RTC__) -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/transform/thread/unary_op.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/half.h" -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Floating-point rounding style similare to Standard Library's formats but supporting -+/// additional rounding options. -+enum class FloatRoundStyle { -+ round_indeterminate, ///< rounding mode unknown -+ round_toward_zero, ///< round toward zero -+ round_to_nearest, ///< round to nearest even -+ round_toward_infinity, ///< round toward infinity -+ round_toward_neg_infinity, ///< round toward negative infinity -+ round_half_ulp_truncate, ///< add 0.5ulp to integer representation then round toward zero -+ round_half_ulp_trunc_dntz ///< like round_half_ulp_truncate, except denorms are rounded *toward* zero -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename T, -+ typename S, -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest -+> -+struct NumericConverter { -+ -+ using result_type = T; -+ using source_type = S; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & s) { -+ -+ return static_cast(s); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Partial specializations for float => int32_t -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(__CUDA_ARCH__) -+template <> -+struct NumericConverter { -+ -+ using result_type = int32_t; -+ using source_type = float; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const & s) { -+ -+ return __float2int_rn(s); -+ } -+ -+ CUTLASS_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+template <> -+struct NumericConverter { -+ -+ using result_type = int32_t; -+ using source_type = float; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const & s) { -+ -+ return __float2int_rz(s); -+ } -+ -+ CUTLASS_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+#elif !defined(__CUDACC_RTC__) -+ -+template <> -+struct NumericConverter { -+ -+ using result_type = int32_t; -+ using source_type = float; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; -+ -+ static result_type convert(source_type const & s) { -+ std::fesetround(FE_TONEAREST); -+ return (result_type)std::nearbyint(s); -+ } -+ -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+template <> -+struct NumericConverter { -+ -+ using result_type = int32_t; -+ using source_type = float; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; -+ -+ static result_type convert(source_type const & s) { -+ std::fesetround(FE_TOWARDZERO); -+ return (result_type)std::nearbyint(s); -+ } -+ -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Partial specializations for float => int8_t -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(__CUDA_ARCH__) -+template <> -+struct NumericConverter { -+ -+ using result_type = int8_t; -+ using source_type = float; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const & s) { -+ -+ int32_t intermediate; -+ asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(intermediate) : "f"(s)); -+ -+ return static_cast(intermediate); -+ } -+ -+ CUTLASS_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+template <> -+struct NumericConverter { -+ -+ using result_type = int8_t; -+ using source_type = float; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const & s) { -+ -+ int32_t intermediate; -+ asm volatile("cvt.rzi.sat.s8.f32 %0, %1;" : "=r"(intermediate) : "f"(s)); -+ -+ return static_cast(intermediate); -+ } -+ -+ CUTLASS_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+#elif !defined(__CUDACC_RTC__) -+ -+template <> -+struct NumericConverter { -+ -+ using result_type = int8_t; -+ using source_type = float; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; -+ -+ static result_type convert(source_type const & s) { -+ std::fesetround(FE_TONEAREST); -+ int32_t intermediate = (int32_t)std::nearbyint(s); -+ -+ // Low-end saturation -+ intermediate = std::max(intermediate, (int32_t)std::numeric_limits::lowest()); -+ -+ // High-end saturation -+ intermediate = std::min(intermediate, (int32_t)std::numeric_limits::max()); -+ -+ return static_cast(intermediate); -+ } -+ -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+template <> -+struct NumericConverter { -+ -+ using result_type = int8_t; -+ using source_type = float; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; -+ -+ static result_type convert(source_type const & s) { -+ std::fesetround(FE_TOWARDZERO); -+ int32_t intermediate = (int32_t)std::nearbyint(s); -+ -+ // Low-end saturation -+ intermediate = std::max(intermediate, (int32_t)std::numeric_limits::lowest()); -+ -+ // High-end saturation -+ intermediate = std::min(intermediate, (int32_t)std::numeric_limits::max()); -+ -+ return static_cast(intermediate); -+ } -+ -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for float <= half_t -+template -+struct NumericConverter { -+ -+ using result_type = T; -+ using source_type = T; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & s) { -+ -+ return s; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Partial specializations for float <=> half_t -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for float <= half_t -+template -+struct NumericConverter { -+ -+ using result_type = float; -+ using source_type = half_t; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & s) { -+ -+ result_type result = static_cast(s); -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Specialization for round-to-nearest -+template <> -+struct NumericConverter { -+ -+ using result_type = half_t; -+ using source_type = float; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & s) { -+ -+ result_type result = static_cast(s); -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Specialization for round-toward-zero -+template <> -+struct NumericConverter { -+ -+ using result_type = half_t; -+ using source_type = float; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; -+ -+ /// Round toward zero -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & flt) { -+ -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ return half_t(__float2half_rz(flt)); -+ #else -+ // software implementation rounds toward nearest even -+ unsigned const& s = reinterpret_cast(flt); -+ uint16_t sign = uint16_t((s >> 16) & 0x8000); -+ int16_t exp = uint16_t(((s >> 23) & 0xff) - 127); -+ int mantissa = s & 0x7fffff; -+ uint16_t u = 0; -+ -+ if ((s & 0x7fffffff) == 0) { -+ // sign-preserving zero -+ return half_t::bitcast(sign); -+ } -+ -+ if (exp > 15) { -+ if (exp == 128 && mantissa) { -+ // not a number -+ u = 0x7fff; -+ } else { -+ // overflow to infinity -+ u = sign | 0x7c00; -+ } -+ return half_t::bitcast(u); -+ } -+ -+ if (exp >= -14) { -+ // normal fp32 to normal fp16 -+ exp = uint16_t(exp + uint16_t(15)); -+ u = uint16_t(((exp & 0x1f) << 10)); -+ u = uint16_t(u | (mantissa >> 13)); -+ } else { -+ // normal single-precision to subnormal half_t-precision representation -+ int rshift = (-14 - exp); -+ if (rshift < 32) { -+ mantissa |= (1 << 23); -+ mantissa = (mantissa >> rshift); -+ u = (uint16_t(mantissa >> 13) & 0x3ff); -+ } else { -+ mantissa = 0; -+ u = 0; -+ } -+ } -+ -+ u |= sign; -+ -+ return half_t::bitcast(u); -+ -+ #endif // defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Partial specializations for float <=> bfloat16_t -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for float <= bfloat16_t -+template -+struct NumericConverter { -+ -+ using result_type = float; -+ using source_type = bfloat16_t; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & s) { -+ -+ return static_cast(s); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+template <> -+struct NumericConverter { -+ using result_type = bfloat16_t; -+ using source_type = float; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & s) { -+ return static_cast(s); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+template <> -+struct NumericConverter { -+ using result_type = bfloat16_t; -+ using source_type = float; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_half_ulp_truncate; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & s) { -+ uint32_t x32 = reinterpret_cast(s); -+ -+ #if defined(__CUDA_ARCH__) -+ if (::isfinite(s)) { -+ x32 += 0x8000; -+ } -+ #else -+ if (std::isfinite(s)) { -+ x32 += 0x8000; -+ } -+ #endif -+ -+ uint16_t x16 = uint16_t((x32 >> 16) & 0xffff); -+ return bfloat16_t::bitcast(x16); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+template <> -+struct NumericConverter { -+ using result_type = bfloat16_t; -+ using source_type = float; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & s) { -+ -+ uint32_t x32 = reinterpret_cast(s); -+ uint16_t x16 = uint16_t(x32 >> 16); -+ -+ return bfloat16_t::bitcast(x16); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Partial specializations for float <=> tfloat32_t -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for float <= tfloat32_t -+template -+struct NumericConverter { -+ -+ using result_type = float; -+ using source_type = tfloat32_t; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & s) { -+ -+ return static_cast(s); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+template <> -+struct NumericConverter { -+ using result_type = tfloat32_t; -+ using source_type = float; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & s) { -+ -+ unsigned storage = reinterpret_cast(s); -+ -+ if ((storage & 0x7f800000) != 0x7f800000) { -+ -+ bool mantissa_bit = ((storage & (1 << 13)) != 0); -+ bool round_bit = ((storage & (1 << 12)) != 0); -+ bool sticky_bit = ((storage & ((1 << 12) - 1)) != 0); -+ -+ if ((round_bit && sticky_bit) || (round_bit && mantissa_bit)) { -+ storage += uint32_t(1 << 13); -+ } -+ -+ // Note, the following is intentionally commented out. TF32 -+ // does not define the low order bits, so they may be left in -+ // an undefined state. -+ // -+ // By not truncating these bit explicitly, we avoid an extra logical -+ // operation. -+ // -+ // TF32 may be implicitly converted to float by performing this -+ // operation as needed. -+ // -+ // storage = (storage & ~0x1fff); -+ } -+ else if (storage & ~0xff800000) { -+ storage = 0x7fffffff; -+ } -+ -+ return tfloat32_t::bitcast(storage); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+template <> -+struct NumericConverter { -+ using result_type = tfloat32_t; -+ using source_type = float; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_half_ulp_truncate; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & s) { -+ return tfloat32_t::round_half_ulp_truncate(s); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// This rounding operation is similar to half_ulp_truncate except it rounds denorms toward zero. -+/// It avoids predicated code, though it requires a temporary register. -+template <> -+struct NumericConverter { -+ using result_type = tfloat32_t; -+ using source_type = float; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_half_ulp_trunc_dntz; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & s) { -+ -+ unsigned y = reinterpret_cast(s); -+ y = y & 0xff800000; -+ float d = reinterpret_cast(y); -+ float z = d / float(1 << 11) + s; -+ -+ return reinterpret_cast(z); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+template <> -+struct NumericConverter { -+ using result_type = tfloat32_t; -+ using source_type = float; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & s) { -+ uint32_t x = reinterpret_cast(s); -+ return tfloat32_t::bitcast(x & 0xffffe000); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Conversion operator for float to tfloat32_t big and small values -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ FloatRoundStyle RoundBig = FloatRoundStyle::round_toward_zero, -+ FloatRoundStyle RoundSmall = FloatRoundStyle::round_half_ulp_truncate -+> -+struct NumericConverterFastF32 { -+ -+ // result_type holds big tfloat32_t at idx(0) and small tfloat32_t at idx(1) -+ using result_type = Array; -+ -+ // source data type -+ using source_type = float; -+ -+ // rounding styles for big and small part -+ static FloatRoundStyle const kRoundBig = RoundBig; -+ static FloatRoundStyle const kRoundSmall = RoundSmall; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ result_type result; -+ NumericConverter convert_big_; -+ NumericConverter convert_small_; -+ -+ // convert and fill tfloat32_t big at idx 0 -+ result[0] = convert_big_(source); -+ -+ // convert and fill tfloat32_t small at idx 1 -+ result[1] = convert_small_(source - static_cast(result[0])); -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Conversion and Clamp operator for Integers -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename T, -+ typename S -+> -+struct NumericConverterClamp { -+ -+ using result_type = T; -+ using source_type = S; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & s) { -+ NumericConverter convert_op; -+ result_type const kClamp_max = platform::numeric_limits::max(); -+ result_type const kClamp_min = platform::numeric_limits::lowest(); -+ if (s < (source_type)kClamp_min) -+ return kClamp_min; -+ if (s > (source_type)kClamp_max) -+ return kClamp_max; -+ return convert_op(s); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Conversion operator for Array -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Conversion operator for Array -+template < -+ typename T, -+ typename S, -+ int N, -+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest, -+ typename Transform = cutlass::transform::thread::UnaryTransform::Identity -+> -+struct NumericArrayConverter { -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ static_assert(platform::is_same::value || -+ platform::is_same::value, -+ "Unary Operator not supported."); -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & s) { -+ -+ result_type result; -+ NumericConverter convert_; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ if( platform::is_same::value ) -+ { -+ result[i] = convert_(s[i]); -+ } else { // conjugate -+ result[i] = conj(convert_(s[i])); -+ } -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+template < -+ typename T, -+ int N, -+ FloatRoundStyle Round, -+ typename Transform -+> -+struct NumericArrayConverter { -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ static_assert(platform::is_same::value || -+ platform::is_same::value, -+ "Unary Operator not supported."); -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ if( platform::is_same::value ) -+ { -+ return s; -+ } else { -+ result_type result; -+ for (int i = 0; i < N; ++i) { -+ result[i] = conj(s[i]); -+ } -+ return result; -+ } -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Array <= Array, round to nearest -+template <> -+struct NumericArrayConverter { -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ Array result; -+ -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ reinterpret_cast<__half2 &>(result) = __float22half2_rn(reinterpret_cast(source)); -+ #else -+ NumericConverter convert_; -+ result[0] = convert_(source[0]); -+ result[1] = convert_(source[1]); -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array, round to nearest -+template -+struct NumericArrayConverter { -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ Array result; -+ -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -+ reinterpret_cast(result) = __half22float2(reinterpret_cast<__half2 const &>(source)); -+ #else -+ NumericConverter convert_; -+ result[0] = convert_(source[0]); -+ result[1] = convert_(source[1]); -+ #endif -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Array <= Array -+template < -+ int N, -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ NumericArrayConverter convert_vector_; -+ NumericConverter convert_element_; -+ -+ result_type result; -+ -+ Array *result_ptr = reinterpret_cast *>(&result); -+ Array const *source_ptr = reinterpret_cast const *>(&source); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = convert_vector_(source_ptr[i]); -+ } -+ -+ if (N % 2) { -+ result[N - 1] = convert_element_(source[N - 1]); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+ -+/// Partial specialization for Array <= Array -+template < -+ int N, -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ NumericArrayConverter convert_vector_; -+ NumericConverter convert_element_; -+ -+ result_type result; -+ -+ Array *result_ptr = reinterpret_cast *>(&result); -+ Array const *source_ptr = reinterpret_cast const *>(&source); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = convert_vector_(source_ptr[i]); -+ } -+ -+ if (N % 2) { -+ result[N - 1] = convert_element_(source[N - 1]); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Array <= Array, round to nearest -+template <> -+struct NumericArrayConverter { -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ unsigned d; -+ -+ asm("cvt.rn.bf16x2.f32 %0, %1, %2;\n" : "=r"(d) : "f"(source[1]), "f"(source[0]) ); -+ -+ return reinterpret_cast(d); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Array <= Array -+template < -+ int N, -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ NumericArrayConverter convert_vector_; -+ NumericConverter convert_element_; -+ -+ result_type result; -+ -+ Array *result_ptr = reinterpret_cast *>(&result); -+ Array const *source_ptr = reinterpret_cast const *>(&source); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 2; ++i) { -+ result_ptr[i] = convert_vector_(source_ptr[i]); -+ } -+ -+ if (N % 2) { -+ result[N - 1] = convert_element_(source[N - 1]); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+#endif // if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Conditional guards to enable partial specialization for packed integers -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && \ -+ ((__CUDACC_VER_MAJOR__ > 10) || \ -+ ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2))) -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ NumericConverter convert_element_; -+ -+ result_type result; -+ -+ result[0] = convert_element_(source[0]); -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ uint32_t tmp; -+ -+ asm volatile( -+ "cvt.pack.sat.s8.s32.b32 %0, %2, %1, 0;\n" -+ : "=r"(tmp) : "r"(source[0]), "r"(source[1])); -+ -+ uint16_t out = (tmp & 0xffff); -+ return reinterpret_cast(out); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ unsigned out; -+ -+ asm volatile( -+ "{ .reg .u32 r4;" -+ "cvt.pack.sat.s8.s32.b32 r4, %4, %3, 0;" -+ "cvt.pack.sat.s8.s32.b32 %0, %2, %1, r4;" -+ "}" -+ : "=r"(out) : "r"(source[0]), "r"(source[1]), "r"(source[2]), "r"(source[3])); -+ -+ return reinterpret_cast(out); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ int N, -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ static_assert(!(N % 4), "N must be multiple of 4."); -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ NumericArrayConverter convert_vector_; -+ -+ result_type result; -+ -+ Array *result_ptr = reinterpret_cast *>(&result); -+ Array const *source_ptr = reinterpret_cast const *>(&source); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 4; ++i) { -+ result_ptr[i] = convert_vector_(source_ptr[i]); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ NumericConverter convert_element_; -+ -+ result_type result; -+ -+ result[0] = convert_element_(source[0]); -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ uint32_t tmp; -+ -+ asm volatile( -+ "cvt.pack.sat.u8.s32.b32 %0, %2, %1, 0;\n" -+ : "=r"(tmp) : "r"(source[0]), "r"(source[1])); -+ -+ uint16_t out = (tmp & 0xffff); -+ return reinterpret_cast(out); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ unsigned out; -+ -+ asm volatile( -+ "{ .reg .u32 r4;" -+ "cvt.pack.sat.u8.s32.b32 r4, %4, %3, 0;" -+ "cvt.pack.sat.u8.s32.b32 %0, %2, %1, r4;" -+ "}" -+ : "=r"(out) : "r"(source[0]), "r"(source[1]), "r"(source[2]), "r"(source[3])); -+ -+ return reinterpret_cast(out); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ int N, -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ static_assert(!(N % 4), "N must be multiple of 4."); -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ NumericArrayConverter convert_vector_; -+ -+ result_type result; -+ -+ Array *result_ptr = reinterpret_cast *>(&result); -+ Array const *source_ptr = reinterpret_cast const *>(&source); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 4; ++i) { -+ result_ptr[i] = convert_vector_(source_ptr[i]); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Partial specializations for Array <=> Array -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ using result_element = float; -+ using source_element = float_e4m3_t; -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ uint32_t out_fp16[2]; -+ uint32_t const& src_packed = reinterpret_cast(source); -+ -+ asm volatile( \ -+ "{\n" \ -+ ".reg .b16 lo, hi;\n" \ -+ "mov.b32 {lo, hi}, %2;\n" \ -+ "cvt.rn.f16x2.e4m3x2 %0, lo;\n" \ -+ "cvt.rn.f16x2.e4m3x2 %1, hi;\n" \ -+ "}\n" : "=r"(out_fp16[0]), "=r"(out_fp16[1]) : "r"(src_packed)); -+ -+ float2 res0 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[0])); -+ float2 res1 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[1])); -+ -+ result_type out; -+ out[0] = res0.x; -+ out[1] = res0.y; -+ out[2] = res1.x; -+ out[3] = res1.y; -+ return out; -+ #else -+ result_type result; -+ NumericConverter converter; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 4; ++i) { -+ result[i] = converter(source[i]); -+ } -+ -+ return result; -+ #endif -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ using result_element = float_e4m3_t; -+ using source_element = float; -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ uint32_t out; -+ -+ asm volatile( \ -+ "{\n" \ -+ ".reg .b16 lo;\n" \ -+ ".reg .b16 hi;\n" \ -+ "cvt.rn.satfinite.e4m3x2.f32 lo, %2, %1;\n" \ -+ "cvt.rn.satfinite.e4m3x2.f32 hi, %4, %3;\n" \ -+ "mov.b32 %0, {lo, hi};\n" \ -+ "}" \ -+ : "=r"(out) : "f"(source[0]), "f"(source[1]), "f"(source[2]), "f"(source[3])); -+ -+ return reinterpret_cast(out); -+ #else -+ result_type result; -+ NumericConverter converter; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 4; ++i) { -+ result[i] = converter(source[i]); -+ } -+ -+ return result; -+ #endif -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ using result_element = float; -+ using source_element = float_e5m2_t; -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ uint32_t out_fp16[2]; -+ uint32_t const& src_packed = reinterpret_cast(source); -+ -+ asm volatile( \ -+ "{\n" \ -+ ".reg .b16 lo, hi;\n" \ -+ "mov.b32 {lo, hi}, %2;\n" \ -+ "cvt.rn.f16x2.e5m2x2 %0, lo;\n" \ -+ "cvt.rn.f16x2.e5m2x2 %1, hi;\n" \ -+ "}\n" : "=r"(out_fp16[0]), "=r"(out_fp16[1]) : "r"(src_packed)); -+ -+ float2 res0 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[0])); -+ float2 res1 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[1])); -+ -+ result_type out; -+ out[0] = res0.x; -+ out[1] = res0.y; -+ out[2] = res1.x; -+ out[3] = res1.y; -+ return out; -+ #else -+ result_type result; -+ NumericConverter converter; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 4; ++i) { -+ result[i] = converter(source[i]); -+ } -+ -+ return result; -+ #endif -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ using result_element = float_e5m2_t; -+ using source_element = float; -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ uint32_t out; -+ -+ asm volatile( \ -+ "{\n" \ -+ ".reg .b16 lo;\n" \ -+ ".reg .b16 hi;\n" \ -+ "cvt.rn.satfinite.e5m2x2.f32 lo, %2, %1;\n" \ -+ "cvt.rn.satfinite.e5m2x2.f32 hi, %4, %3;\n" \ -+ "mov.b32 %0, {lo, hi};\n" \ -+ "}" \ -+ : "=r"(out) : "f"(source[0]), "f"(source[1]), "f"(source[2]), "f"(source[3])); -+ -+ return reinterpret_cast(out); -+ #else -+ result_type result; -+ NumericConverter converter; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 4; ++i) { -+ result[i] = converter(source[i]); -+ } -+ -+ return result; -+ #endif -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Partial specializations for Array <=> Array -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ using result_element = half_t; -+ using source_element = float_e4m3_t; -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ uint32_t out[2]; -+ uint32_t const& src_packed = reinterpret_cast(source); -+ asm volatile( \ -+ "{\n" \ -+ ".reg .b16 lo, hi;\n" \ -+ "mov.b32 {lo, hi}, %2;\n" \ -+ "cvt.rn.f16x2.e4m3x2 %0, lo;\n" \ -+ "cvt.rn.f16x2.e4m3x2 %1, hi;\n" \ -+ "}\n" : "=r"(out[0]), "=r"(out[1]) : "r"(src_packed)); -+ return reinterpret_cast(out); -+ #else -+ result_type result; -+ NumericConverter converter; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 4; ++i) { -+ result[i] = converter(source[i]); -+ } -+ -+ return result; -+ #endif -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ using result_element = float_e4m3_t; -+ using source_element = half_t; -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ uint32_t out; -+ uint32_t const* src_packed = reinterpret_cast(&source); -+ -+ asm volatile( \ -+ "{\n" \ -+ ".reg .b16 lo;\n" \ -+ ".reg .b16 hi;\n" \ -+ "cvt.rn.satfinite.e4m3x2.f16x2 lo, %1;\n" \ -+ "cvt.rn.satfinite.e4m3x2.f16x2 hi, %2;\n" \ -+ "mov.b32 %0, {lo, hi};\n" \ -+ "}" \ -+ : "=r"(out) : "r"(src_packed[0]), "r"(src_packed[1])); -+ -+ return reinterpret_cast(out); -+ #else -+ result_type result; -+ NumericConverter converter; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 4; ++i) { -+ result[i] = converter(source[i]); -+ } -+ -+ return result; -+ #endif -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ using result_element = half_t; -+ using source_element = float_e5m2_t; -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ uint32_t out[2]; -+ uint32_t const& src_packed = reinterpret_cast(source); -+ asm volatile( \ -+ "{\n" \ -+ ".reg .b16 lo, hi;\n" \ -+ "mov.b32 {lo, hi}, %2;\n" \ -+ "cvt.rn.f16x2.e5m2x2 %0, lo;\n" \ -+ "cvt.rn.f16x2.e5m2x2 %1, hi;\n" \ -+ "}\n" : "=r"(out[0]), "=r"(out[1]) : "r"(src_packed)); -+ return reinterpret_cast(out); -+ #else -+ result_type result; -+ NumericConverter converter; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 4; ++i) { -+ result[i] = converter(source[i]); -+ } -+ -+ return result; -+ #endif -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ using result_element = float_e5m2_t; -+ using source_element = half_t; -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ uint32_t out; -+ uint32_t const* src_packed = reinterpret_cast(&source); -+ -+ asm volatile( \ -+ "{\n" \ -+ ".reg .b16 lo;\n" \ -+ ".reg .b16 hi;\n" \ -+ "cvt.rn.satfinite.e5m2x2.f16x2 lo, %1;\n" \ -+ "cvt.rn.satfinite.e5m2x2.f16x2 hi, %2;\n" \ -+ "mov.b32 %0, {lo, hi};\n" \ -+ "}" \ -+ : "=r"(out) : "r"(src_packed[0]), "r"(src_packed[1])); -+ -+ return reinterpret_cast(out); -+ #else -+ result_type result; -+ NumericConverter converter; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 4; ++i) { -+ result[i] = converter(source[i]); -+ } -+ -+ return result; -+ #endif -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Partial specializations for Array <=> Array -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ using result_element = bfloat16_t; -+ using source_element = float_e4m3_t; -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ // Convert f8 to float -+ NumericArrayConverter src2float; -+ Array tmp_floats = src2float(source); -+ -+ // Convert float to bf16 -+ result_type out; -+ Array* packed_tmp = reinterpret_cast*>(&tmp_floats); -+ Array* packed_out = reinterpret_cast*>(&out); -+ NumericArrayConverter float2result; -+ packed_out[0] = float2result(packed_tmp[0]); -+ packed_out[1] = float2result(packed_tmp[1]); -+ -+ return out; -+ #else -+ result_type result; -+ NumericConverter converter; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 4; ++i) { -+ result[i] = converter(source[i]); -+ } -+ -+ return result; -+ #endif -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ using result_element = float_e4m3_t; -+ using source_element = bfloat16_t; -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ // Convert bf16 to float -+ Array tmp; -+ Array* packed_tmp = reinterpret_cast*>(&tmp); -+ Array const* packed_source = reinterpret_cast const*>(&source); -+ NumericArrayConverter src2float; -+ packed_tmp[0] = src2float(packed_source[0]); -+ packed_tmp[1] = src2float(packed_source[1]); -+ -+ // Convert float to f8 -+ NumericArrayConverter float2result; -+ return float2result(tmp); -+ #else -+ result_type result; -+ NumericConverter converter; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 4; ++i) { -+ result[i] = converter(source[i]); -+ } -+ -+ return result; -+ #endif -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ using result_element = bfloat16_t; -+ using source_element = float_e5m2_t; -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ // Convert f8 to float -+ NumericArrayConverter src2float; -+ Array tmp_floats = src2float(source); -+ -+ // Convert float to bf16 -+ result_type out; -+ Array* packed_tmp = reinterpret_cast*>(&tmp_floats); -+ Array* packed_out = reinterpret_cast*>(&out); -+ NumericArrayConverter float2result; -+ packed_out[0] = float2result(packed_tmp[0]); -+ packed_out[1] = float2result(packed_tmp[1]); -+ -+ return out; -+ #else -+ result_type result; -+ NumericConverter converter; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 4; ++i) { -+ result[i] = converter(source[i]); -+ } -+ -+ return result; -+ #endif -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ using result_element = float_e5m2_t; -+ using source_element = bfloat16_t; -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ #if defined(CUDA_PTX_FP8_CVT_ENABLED) -+ // Convert bf16 to float -+ Array tmp; -+ Array* packed_tmp = reinterpret_cast*>(&tmp); -+ Array const* packed_source = reinterpret_cast const*>(&source); -+ NumericArrayConverter src2float; -+ packed_tmp[0] = src2float(packed_source[0]); -+ packed_tmp[1] = src2float(packed_source[1]); -+ -+ // Convert float to f8 -+ NumericArrayConverter float2result; -+ return float2result(tmp); -+ #else -+ result_type result; -+ NumericConverter converter; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 4; ++i) { -+ result[i] = converter(source[i]); -+ } -+ -+ return result; -+ #endif -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Partial specializations for Array <=> Array -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ using result_element = float_e4m3_t; -+ using source_element = float_e5m2_t; -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const & source) { -+ result_type result; -+ NumericConverter converter; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 4; ++i) { -+ result[i] = converter(source[i]); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ using result_element = float_e5m2_t; -+ using source_element = float_e4m3_t; -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const & source) { -+ result_type result; -+ NumericConverter converter; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 4; ++i) { -+ result[i] = converter(source[i]); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Partial specializations for: -+// Array <=> Array -+// Array <=> Array -+// -+// These are needed to avoid multiple-matching-template compilation errors (e.g., when -+// compiling float_e4m3_t <=> float_e4m3_t, which among T <= float_e4m3_t and float_e4m3_t <= T -+// should be used?) -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ using result_element = float_e4m3_t; -+ using source_element = float_e4m3_t; -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return s; -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ using result_element = float_e5m2_t; -+ using source_element = float_e5m2_t; -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return s; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Partial specialziations for: -+// Array <=> Array -+// Array <=> Array -+// using packed converter under the hood -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename T, -+ typename S, -+ int N, -+ FloatRoundStyle Round -+> -+struct PackedNumericArrayConverter { -+ using result_element = T; -+ using source_element = S; -+ -+ using result_type = Array; -+ using source_type = Array; -+ -+ static FloatRoundStyle const round_style = Round; -+ -+private: -+ using packed_result_type = Array; -+ using packed_source_type = Array; -+ -+public: -+ CUTLASS_DEVICE -+ static result_type convert(source_type const & source) { -+ result_type result; -+ packed_result_type* packed_result = reinterpret_cast(&result); -+ const packed_source_type* packed_source = reinterpret_cast(&source); -+ -+ NumericArrayConverter packed_converter; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 4; ++i) { -+ packed_result[i] = packed_converter(packed_source[i]); -+ } -+ -+ // Handle leftovers -+ NumericConverter converter; -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N % 4; ++i) { -+ int idx = ((N / 4) * 4) + i; -+ result[idx] = converter(source[idx]); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ typename T, -+ int N, -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter : -+ public PackedNumericArrayConverter {}; -+ -+/// Partial specialization for Array <= Array -+template < -+ typename T, -+ int N, -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter : -+ public PackedNumericArrayConverter {}; -+ -+/// Partial specialization for Array <= Array -+template < -+ typename S, -+ int N, -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter : -+ public PackedNumericArrayConverter {}; -+ -+/// Partial specialization for Array <= Array -+template < -+ typename S, -+ int N, -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter : -+ public PackedNumericArrayConverter {}; -+ -+/// Partial specialization for Array <= Array -+template < -+ int N, -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter : -+ public PackedNumericArrayConverter {}; -+ -+/// Partial specialization for Array <= Array -+template < -+ int N, -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter : -+ public PackedNumericArrayConverter {}; -+ -+/// Partial specialization for Array <= Array -+template < -+ int N, -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter : -+ public PackedNumericArrayConverter {}; -+ -+/// Partial specialization for Array <= Array -+template < -+ int N, -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter : -+ public PackedNumericArrayConverter {}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Array <= Array -+/// Conversion is performed with saturation regardless of setting of -+/// the `Round` template parameter. -+template < -+ int N, -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ // Convert float to int -+ Array temporary; -+ -+ NumericArrayConverter compute_converter; -+ temporary = compute_converter(source); -+ -+ // Convert to int to int8_t -+ NumericArrayConverter destination_converter; -+ return destination_converter(temporary); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && \ -+ ((__CUDACC_VER_MAJOR__ > 10) || \ -+ ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2))) -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ unsigned out; -+ -+ asm volatile( -+ "{ .reg .u32 r4;" -+ "cvt.pack.sat.s4.s32.b32 r4, %8, %7, 0;" -+ "cvt.pack.sat.s4.s32.b32 r4, %6, %5, r4;" -+ "cvt.pack.sat.s4.s32.b32 r4, %4, %3, r4;" -+ "cvt.pack.sat.s4.s32.b32 %0, %2, %1, r4;" -+ "}" -+ : "=r"(out) -+ : "r"(source[0]), "r"(source[1]), "r"(source[2]), "r"(source[3]), -+ "r"(source[4]), "r"(source[5]), "r"(source[6]), "r"(source[7])); -+ -+ return reinterpret_cast(out); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ int N, -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ static_assert(!(N % 8), "N must be multiple of 8."); -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ NumericArrayConverter convert_vector_; -+ -+ result_type result; -+ -+ Array *result_ptr = reinterpret_cast *>(&result); -+ Array const *source_ptr = reinterpret_cast const *>(&source); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 8; ++i) { -+ result_ptr[i] = convert_vector_(source_ptr[i]); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ unsigned out; -+ -+ asm volatile( -+ "{ .reg .u32 r4;" -+ "cvt.pack.sat.u4.s32.b32 r4, %8, %7, 0;" -+ "cvt.pack.sat.u4.s32.b32 r4, %6, %5, r4;" -+ "cvt.pack.sat.u4.s32.b32 r4, %4, %3, r4;" -+ "cvt.pack.sat.u4.s32.b32 %0, %2, %1, r4;" -+ "}" -+ : "=r"(out) -+ : "r"(source[0]), "r"(source[1]), "r"(source[2]), "r"(source[3]), -+ "r"(source[4]), "r"(source[5]), "r"(source[6]), "r"(source[7])); -+ -+ return reinterpret_cast(out); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+/// Partial specialization for Array <= Array -+template < -+ int N, -+ FloatRoundStyle Round -+> -+struct NumericArrayConverter { -+ static_assert(!(N % 8), "N must be multiple of 8."); -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_HOST_DEVICE -+ static result_type convert(source_type const & source) { -+ -+ NumericArrayConverter convert_vector_; -+ -+ result_type result; -+ -+ Array *result_ptr = reinterpret_cast *>(&result); -+ Array const *source_ptr = reinterpret_cast const *>(&source); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 8; ++i) { -+ result_ptr[i] = convert_vector_(source_ptr[i]); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(source_type const &s) { -+ return convert(s); -+ } -+}; -+ -+#endif // Conditional guards to enable partial specialization for packed integers -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// FastNumericArrayConverter only works when the source is within center range. -+/// Conversion operator for Array. See the comments before -+/// FastLinearCombinationClamp. -+template -+struct FastNumericArrayConverter { -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const &s) { -+ result_type result; -+ NumericArrayConverter convert_; -+ -+ return convert_(s); -+ } -+ -+ CUTLASS_DEVICE -+ result_type operator()(source_type const &s) { return convert(s); } -+}; -+ -+/// Partial specialization for Array <= Array -+template -+struct FastNumericArrayConverter { -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const &source) { -+ result_type result; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ int tmp = source[i] + 1262485504 /*0x4B400000*/; -+ result[i] = reinterpret_cast(tmp) - 12582912.0f; -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_DEVICE -+ result_type operator()(source_type const &s) { return convert(s); } -+}; -+ -+/// Partial specialization for Array <= Array -+template -+struct FastNumericArrayConverter { -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const &source) { -+ Array result; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < 4; ++i) { -+ float tmp = source[i] + 12582912.0f; -+ result[i] = reinterpret_cast(tmp); -+ } -+ -+ result[0] = __byte_perm(result[0], result[1], 0x40); -+ result[2] = __byte_perm(result[2], result[3], 0x40); -+ result[0] = __byte_perm(result[0], result[2], 0x5410); -+ -+ return reinterpret_cast(result[0]); -+ } -+ -+ CUTLASS_DEVICE -+ result_type operator()(source_type const &s) { return convert(s); } -+}; -+ -+/// Partial specialization for Array <= Array -+template -+struct FastNumericArrayConverter { -+ static_assert(!(N % 4), "N must be multiple of 4."); -+ -+ using result_type = Array; -+ using source_type = Array; -+ static FloatRoundStyle const round_style = Round; -+ -+ CUTLASS_DEVICE -+ static result_type convert(source_type const &source) { -+ FastNumericArrayConverter convert_vector_; -+ -+ result_type result; -+ -+ Array *result_ptr = -+ reinterpret_cast *>(&result); -+ Array const *source_ptr = -+ reinterpret_cast const *>(&source); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N / 4; ++i) { -+ result_ptr[i] = convert_vector_(source_ptr[i]); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_DEVICE -+ result_type operator()(source_type const &s) { return convert(s); } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines preferred rounding mode for a pair of types -+template -+struct PreferredRoundingMode { -+ static FloatRoundStyle const kRound = FloatRoundStyle::round_to_nearest; -+}; -+ -+/// Defines preferred rounding mode for a pair of types -+template <> -+struct PreferredRoundingMode { -+ static FloatRoundStyle const kRound = FloatRoundStyle::round_half_ulp_truncate; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Packs predicates into an array. -+template -+struct PackPredicates { -+ using result_type = Array; -+ -+ static_assert(!(N % 4), "Must pack predicates in a count that is a multiple of 4"); -+ -+ CUTLASS_HOST_DEVICE -+ result_type operator()(bool const predicates[]) { -+ -+ result_type packed; -+ packed.clear(); -+ -+ int const kWordSize = 8; -+ uint8_t *bytes = reinterpret_cast(packed.data()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ int word_idx = (i / kWordSize); -+ int bit_idx = (i % kWordSize); -+ -+ uint8_t mask = ((predicates[i] ? 1u : 0u) << bit_idx); -+ bytes[word_idx] = (bytes[word_idx] | mask); -+ } -+ return packed; -+ } -+}; -+ -+/// Packs predicates into an array -+template -+struct UnpackPredicates { -+ using result_type = Array; -+ -+ static_assert(!(N % 4), "Must unpack predicates in a count that is a multiple of 4"); -+ -+ CUTLASS_HOST_DEVICE -+ void operator()(bool predicates[], result_type const &packed) { -+ -+ int const kWordSize = 8; -+ uint8_t const *bytes = reinterpret_cast(packed.data()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ int word_idx = (i / kWordSize); -+ int bit_idx = (i % kWordSize); -+ -+ predicates[i] = bool((bytes[word_idx] >> bit_idx) & 0x1); -+ } -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/numeric_types.h b/3rdparty/cutlass/include/cutlass/numeric_types.h -new file mode 100644 -index 0000000..55555ec ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/numeric_types.h -@@ -0,0 +1,94 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief Top-level include for all CUTLASS numeric types. -+*/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines the size of an element in bits -+template -+struct sizeof_bits { -+ static int const value = int(sizeof(T) * 8); -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Definitions for 1-bit binary and 4-bit integer types -+// -+ -+/// 1-bit binary type -+using bin1_t = bool; -+ -+/// Defines the size of an element in bits - specialized for bin1_t -+template <> -+struct sizeof_bits { -+ static int const value = 1; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct index_sequence; -+ -+template -+struct index_sequence_helper : index_sequence_helper {}; -+ -+template -+struct index_sequence_helper<0, 0, Next...> { -+ using type = index_sequence<0, Next...>; -+}; -+ -+template -+using make_index_sequence = typename index_sequence_helper::type; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include "cutlass/integer_subbyte.h" -+ -+#include "cutlass/half.h" -+#include "cutlass/bfloat16.h" -+#include "cutlass/tfloat32.h" -+#include "cutlass/float8.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/include/cutlass/pipeline.hpp b/3rdparty/cutlass/include/cutlass/pipeline.hpp -new file mode 100644 -index 0000000..67538ae ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/pipeline.hpp -@@ -0,0 +1,529 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2011-2019, NVIDIA CORPORATION. All rights reserved. -+ * -+ * Redistribution and use in source and binary forms, with or without modification, are not permit- -+ * ted. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR -+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND -+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, -+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; -+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include "cute/numeric/integral_constant.hpp" -+#include "cute/arch/cluster_sm90.hpp" -+#include "cutlass/arch/barrier.h" -+#include "cutlass/gemm/dispatch_policy.hpp" -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+using namespace arch; -+using namespace cute; -+ -+// Circular Buffer Index + Associated Phase -+// Assumes only one operation possible - i.e., ++ -+template -+struct PipelineState { -+ -+ static constexpr uint32_t Stages = Stages_; -+ -+private: -+ int index_ = 0; -+ uint32_t phase_ = 0; -+ -+public: -+ CUTLASS_DEVICE -+ PipelineState(): index_{}, phase_{} {} -+ -+ CUTLASS_DEVICE -+ PipelineState(int index, uint32_t phase) -+ : index_(index) -+ , phase_(phase){} -+ -+ CUTLASS_DEVICE -+ int index() const { -+ return index_; -+ } -+ -+ CUTLASS_DEVICE -+ uint32_t phase() const { -+ return phase_; -+ } -+ -+ CUTLASS_DEVICE -+ void operator++() { -+ ++index_; -+ if (index_ == Stages) { -+ index_ = 0; -+ phase_ ^= 1; -+ } -+ } -+ -+ CUTLASS_DEVICE -+ PipelineState& operator=(const PipelineState& other) { -+ index_ = other.index(); -+ phase_ = other.phase(); -+ return *this; -+ } -+ -+ CUTLASS_DEVICE -+ PipelineState advance(uint32_t num_iterations) { -+ // Number of iterations cross over the stage boundary => flipped phase -+ if ((num_iterations < Stages) && (index_ + num_iterations) >= Stages ) { -+ phase_ ^= 1; -+ } -+ // How many times number of iterations cross over the stage boundary and -+ // end up on a odd number => flipped phase -+ if ((num_iterations >= Stages) && (((index_ + num_iterations) / Stages) % 2) == 1) { -+ phase_ ^= 1; -+ } -+ index_ = (index_ + num_iterations) % Stages; -+ return *this; -+ } -+ -+ CUTLASS_DEVICE -+ static PipelineState make_pipeline_state(PipelineState start_state, uint32_t num_iterations) { -+ return start_state.advance(num_iterations); -+ } -+}; -+ -+template -+CUTLASS_DEVICE -+PipelineState make_producer_start_state() -+{ -+ // Producer starts with an opposite phase as the buffer are initially empty -+ constexpr int InitialProducerStage = 0; -+ constexpr uint32_t InitialProducerPhase = 1; -+ return {InitialProducerStage, InitialProducerPhase}; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// TMA (producer) Async Pipeline class -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// Assumptions : Constructor is Visible Cluster-wide (as it needs a Cluster-Sync) -+// We have exactly one thread elected in the Producer as the "leader" -+// Currently, it is optional to elect a leader for the Consumers -+template -+class PipelineTmaAsync { -+public : -+ using ClusterShape = ClusterShape_; -+ using FullBarrier = ClusterTransactionBarrier; -+ using EmptyBarrier = ClusterBarrier; -+ using ValueType = FullBarrier::ValueType; -+ static constexpr uint32_t Stages = Stages_; -+ -+ struct SharedStorage { -+ FullBarrier full_barrier_[Stages]; -+ EmptyBarrier empty_barrier_[Stages]; -+ }; -+ -+ enum class ThreadCategory { -+ NonParticipant, -+ Producer, -+ Consumer, -+ ProducerConsumer -+ }; -+ -+ struct Params { -+ uint32_t transaction_bytes = 0; -+ ThreadCategory role = ThreadCategory::NonParticipant; -+ uint32_t is_leader = 0; -+ uint32_t num_consumers = 0; -+ }; -+ -+private : -+ // -+ // Data Members -+ // -+ uint32_t dst_blockid_ = 0; -+ uint32_t is_signalling_thread_ = 0; -+ FullBarrier *full_barrier_ptr_ = nullptr; -+ EmptyBarrier *empty_barrier_ptr_ = nullptr; -+ Params params_; -+ -+ // -+ // Methods -+ // -+ -+public: -+ // Constructor -+ CUTLASS_DEVICE -+ PipelineTmaAsync(SharedStorage& storage, Params params) -+ : params_(params) -+ , full_barrier_ptr_(&storage.full_barrier_[0]) -+ , empty_barrier_ptr_(&storage.empty_barrier_[0]) { -+ -+ int warp_idx = canonical_warp_idx(); -+ int lane_predicate = cute::elect_one_sync(); -+ auto cluster_shape = ClusterShape{}; -+ -+ if (warp_idx == 0 && lane_predicate == 1) { -+ // Barrier FULL init -+ for (int i = 0; i < Stages; ++i) { -+ full_barrier_ptr_[i].init(1); -+ } -+ -+ // Barrier EMPTY init -+ uint32_t const num_consumers = cute::size<0>(cluster_shape) + cute::size<1>(cluster_shape) - 1; -+ for (int i = 0; i < Stages; ++i) { -+ empty_barrier_ptr_[i].init(num_consumers); -+ } -+ } -+ -+ // Logic to optimally schedule Empty Arrives -+ // Goal : To divide SYNCS Empty Arrival duty equally amongst the Warp-Group (128 threads) -+ dim3 block_id = block_id_in_cluster(); -+ auto cluster_size = cute::size(cluster_shape); -+ static constexpr int MaxClusterSize = 16; -+ static_assert(cluster_size <= MaxClusterSize, "ERROR : Cluster size too large !" ); -+ -+ // STEP 1 : Use Cute Layout function to generate an optimal dst block-id (0-15) -+ if (params_.num_consumers == 128) { -+ int thread_idx = threadIdx.x % 128; -+ is_signalling_thread_ = (thread_idx % (128 / MaxClusterSize)) == 0; -+ auto layout = cute::composition(Swizzle<2,0,-2>{}, -+ Layout,Stride<_4, _1>>{}); -+ uint32_t thread_row = warp_idx % 4; -+ uint32_t thread_col = (thread_idx / 8) % 4; -+ dst_blockid_ = layout(thread_row, thread_col); -+ } -+ else if (params_.num_consumers == 32){ -+ int thread_idx = threadIdx.x % 32; -+ is_signalling_thread_ = (thread_idx % (32 / MaxClusterSize)) == 0; -+ auto layout = Layout,Stride<_4, _1>>{}; -+ uint32_t thread_row = thread_idx / 8; -+ uint32_t thread_col = (thread_idx % 8) / 2; -+ dst_blockid_ = layout(thread_row, thread_col); -+ } -+ else { -+ is_signalling_thread_ = 0; -+ } -+ -+ // STEP 2: Find if this dst block-id needs an arrival for this problem -+ is_signalling_thread_ &= dst_blockid_ < cluster_size; -+ is_signalling_thread_ &= is_same_row_or_col(dst_blockid_, block_id, cluster_shape); -+ -+ cutlass::arch::fence_barrier_init(); -+ } -+ -+ CUTLASS_DEVICE -+ void producer_acquire(uint32_t stage, uint32_t phase, uint32_t skip_wait = false) { -+ // 1. Wait for empty barrier to be ready -+ // 2. Set the transaction bytes set to occur on the Full barrier -+ uint32_t done = empty_barrier_ptr_[stage].test_wait(phase, (!skip_wait)); -+ if ((!done) && (!skip_wait)){ -+ empty_barrier_ptr_[stage].wait(phase); -+ } -+ -+ if (params_.is_leader) { -+ full_barrier_ptr_[stage].arrive_and_reset_bytes(params_.transaction_bytes); -+ } -+ -+ } -+ -+ CUTLASS_DEVICE -+ void producer_acquire(PipelineState state) { -+ producer_acquire(state.index(), state.phase()); -+ } -+ -+ // NOP for TMA based mainloop -+ CUTLASS_DEVICE -+ void producer_commit(uint32_t stage, uint32_t bytes) { -+ // Below code is used only for unit-testing (in the absennce of TMA commit) -+ #if CUTLASS_UNIT_TEST_PIPELINE -+ if (params_.is_leader) { -+ // STEP 1 : Commit to self -+ full_barrier_ptr_[stage].commit(bytes); -+ -+ // STEP 2 : Commit to other blocks in our cluster -+ auto cluster_shape = ClusterShape{}; -+ Layout block_layout_in_cluster = make_layout(cluster_shape); -+ dim3 local_block_id = cute::block_id_in_cluster(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int n = 0; n < size<1>(block_layout_in_cluster); ++n) { -+ uint32_t dst_block_id = block_layout_in_cluster(local_block_id.x,n,Int<0>{}); -+ full_barrier_ptr_[stage].commit(dst_block_id, bytes, n!=local_block_id.y); -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int m = 0; m < size<0>(block_layout_in_cluster); ++m) { -+ uint32_t dst_block_id = block_layout_in_cluster(m,local_block_id.y,Int<0>{}); -+ full_barrier_ptr_[stage].commit(dst_block_id, bytes, m!=local_block_id.x); -+ } -+ } -+ #endif -+ } -+ -+ CUTLASS_DEVICE -+ void producer_commit(PipelineState state, uint32_t bytes) { -+ producer_commit(state.index(), bytes); -+ } -+ -+ -+ // Wait for producer to commit transactions (done by TMA) -+ CUTLASS_DEVICE -+ void consumer_wait(uint32_t stage, uint32_t phase) { -+ uint32_t done = full_barrier_ptr_[stage].test_wait(phase); -+ if (!done){ -+ full_barrier_ptr_[stage].wait(phase); -+ } -+ } -+ -+ CUTLASS_DEVICE -+ void consumer_wait(PipelineState state) { -+ consumer_wait(state.index(), state.phase()); -+ } -+ -+ // Consumer signalling Producer of completion -+ // Ensures all blocks in the Same Row and Column get notifed. -+ CUTLASS_DEVICE -+ void consumer_release(uint32_t stage, uint32_t skip = false) { -+ empty_barrier_ptr_[stage].arrive(dst_blockid_, is_signalling_thread_ & (!skip)); -+ } -+ -+ CUTLASS_DEVICE -+ void consumer_release(PipelineState state) { -+ consumer_release(state.index()); -+ } -+ -+ CUTLASS_DEVICE -+ ValueType* producer_get_barrier(uint32_t stage) { -+ return reinterpret_cast(&full_barrier_ptr_[stage]); -+ } -+ -+ CUTLASS_DEVICE -+ bool is_same_row_or_col(int dst_block_id, dim3 block_id, ClusterShape cluster_shape) { -+ return ((dst_block_id % cute::size<0>(cluster_shape)) == block_id.x || -+ (dst_block_id / cute::size<0>(cluster_shape)) == block_id.y); -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Simple producer-consumer async Pipeline class -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// *Count Signifies the number of producers / consumers who will announce their completion -+ -+template -+class PipelineAsync { -+public : -+ using FullBarrier = ClusterBarrier; -+ using EmptyBarrier = ClusterBarrier; -+ using ProducerBarrierType = FullBarrier::ValueType; -+ static constexpr uint32_t Stages = Stages_; -+ -+ struct SharedStorage { -+ FullBarrier full_barrier_[Stages]; -+ EmptyBarrier empty_barrier_[Stages]; -+ }; -+ -+ enum class ThreadCategory { -+ NonParticipant, -+ Producer, -+ Consumer, -+ ProducerConsumer -+ }; -+ -+ struct Params { -+ ThreadCategory role = ThreadCategory::NonParticipant; -+ uint32_t producer_arv_count = 1; -+ uint32_t consumer_arv_count = 1; -+ uint32_t dst_blockid = cute::block_rank_in_cluster(); -+ }; -+ -+private: -+ // -+ // Data Members -+ // -+ Params params_; -+ FullBarrier *full_barrier_ptr_; -+ EmptyBarrier *empty_barrier_ptr_; -+ -+public: -+ -+ // Default assumption when only storage is passed is : -+ // => single producer, single consumer & they are in the same block (within the Cluster) -+ CUTLASS_DEVICE -+ PipelineAsync(SharedStorage& storage) -+ : PipelineAsync(storage, {}) {} -+ -+ CUTLASS_DEVICE -+ PipelineAsync( -+ SharedStorage& storage, -+ Params const& params) : -+ params_(params), -+ full_barrier_ptr_(&storage.full_barrier_[0]), -+ empty_barrier_ptr_(&storage.empty_barrier_[0]) { -+ -+ int warp_idx = canonical_warp_idx(); -+ int lane_predicate = cute::elect_one_sync(); -+ -+ // Barrier FULL, EMPTY init -+ // Init is done only by thread 0 of the block -+ if (warp_idx == 0 && lane_predicate == 1) { -+ for (int i = 0; i < Stages; ++i) { -+ full_barrier_ptr_[i].init(params.producer_arv_count); -+ empty_barrier_ptr_[i].init(params.consumer_arv_count); -+ } -+ } -+ -+ cutlass::arch::fence_barrier_init(); -+ } -+ -+ CUTLASS_DEVICE -+ void producer_acquire(uint32_t stage, uint32_t phase, uint32_t skip_wait = false) { -+ uint32_t done = empty_barrier_ptr_[stage].test_wait(phase, (!skip_wait)); -+ if ((!done) && (!skip_wait)){ -+ empty_barrier_ptr_[stage].wait(phase); -+ } -+ } -+ -+ CUTLASS_DEVICE -+ void producer_acquire(PipelineState state) { -+ producer_acquire(state.index(), state.phase()); -+ } -+ -+ CUTLASS_DEVICE -+ void producer_commit(uint32_t stage) { -+ full_barrier_ptr_[stage].arrive(); -+ } -+ -+ CUTLASS_DEVICE -+ void producer_commit(PipelineState state) { -+ producer_commit(state.index()); -+ } -+ -+ CUTLASS_DEVICE -+ void consumer_wait(uint32_t stage, uint32_t phase) { -+ uint32_t done = full_barrier_ptr_[stage].test_wait(phase); -+ if (!done){ -+ full_barrier_ptr_[stage].wait(phase); -+ } -+ } -+ -+ CUTLASS_DEVICE -+ void consumer_wait(PipelineState state) { -+ consumer_wait(state.index(), state.phase()); -+ } -+ -+ CUTLASS_DEVICE -+ void consumer_release(uint32_t stage, uint32_t skip = false) { -+ empty_barrier_ptr_[stage].arrive(params_.dst_blockid, (not skip)); -+ } -+ -+ CUTLASS_DEVICE -+ void consumer_release(PipelineState state) { -+ consumer_release(state.index()); -+ } -+ -+ CUTLASS_DEVICE -+ ProducerBarrierType* get_producer_barrier(uint32_t stage) { -+ return reinterpret_cast(&full_barrier_ptr_[stage]); -+ } -+}; -+ -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Barrier to ensure an Ordered Sequence between -+// SequenceLength number of groups (each with group_size participants) executing SequenceDepth Stages -+// i.e., for all i < j - only after id "i" arrives at a particular stage "m" -+// will the wait() for id "j" succeed for the same stage -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class OrderedSequenceBarrier { -+public : -+ using Barrier = ClusterBarrier; -+ -+ struct SharedStorage { -+ Barrier barrier_[SequenceDepth][SequenceLength]; -+ }; -+ -+ struct Params { -+ uint32_t group_id; -+ uint32_t group_size; -+ }; -+ -+private : -+ // -+ // Data Members -+ // -+ -+ // In future this Params object can be replaced easily with a CG object -+ Params params_; -+ Barrier *barrier_ptr_; -+ PipelineState stage_; -+ -+ static constexpr int Depth = SequenceDepth; -+ static constexpr int Length = SequenceLength; -+ -+public: -+ OrderedSequenceBarrier() = delete; -+ OrderedSequenceBarrier(const OrderedSequenceBarrier&) = delete; -+ OrderedSequenceBarrier(OrderedSequenceBarrier&&) = delete; -+ OrderedSequenceBarrier& operator=(const OrderedSequenceBarrier&) = delete; -+ OrderedSequenceBarrier& operator=(OrderedSequenceBarrier&&) = delete; -+ ~OrderedSequenceBarrier() = default; -+ -+ CUTLASS_DEVICE -+ OrderedSequenceBarrier(SharedStorage& storage, Params const& params) : -+ params_(params), -+ barrier_ptr_(&storage.barrier_[0][0]), -+ // Group 0 - starts with an opposite phase -+ stage_({0, params.group_id == 0}) { -+ -+ int warp_idx = canonical_warp_idx(); -+ int lane_predicate = cute::elect_one_sync(); -+ -+ // Barrier FULL, EMPTY init -+ // Init is done only by the one elected thread of the block -+ if (warp_idx == 0 && lane_predicate == 1) { -+ for (int d = 0; d < Depth; ++d) { -+ for (int l = 0; l < Length; ++l) { -+ barrier_ptr_[d * Length + l].init(params.group_size); -+ } -+ } -+ } -+ -+ cutlass::arch::fence_barrier_init(); -+ } -+ -+ // Wait on a stage to be unlocked -+ CUTLASS_DEVICE -+ void wait() { -+ get_barrier_for_current_stage(params_.group_id).wait(stage_.phase()); -+ } -+ -+ // Signal completion of Stage and move to the next stage -+ // (group_id) signals to (group_id+1) -+ CUTLASS_DEVICE -+ void arrive() { -+ int signalling_id = (params_.group_id + 1) % Length; -+ get_barrier_for_current_stage(signalling_id).arrive(); -+ ++stage_; -+ } -+ -+private: -+ -+ CUTLASS_DEVICE -+ Barrier& get_barrier_for_current_stage(int group_id) { -+ return barrier_ptr_[stage_.index() * Length + group_id]; -+ } -+}; -+ -+} // end namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/pitch_linear_coord.h b/3rdparty/cutlass/include/cutlass/pitch_linear_coord.h -new file mode 100644 -index 0000000..2cd7bfe ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/pitch_linear_coord.h -@@ -0,0 +1,181 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines layout functions used by TensorRef and derived classes for pitch-linear memory. -+*/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/coord.h" -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template defining a shape used by pitch-linear operators -+template < -+ int Contiguous, -+ int Strided -+> -+struct PitchLinearShape { -+ static int const kContiguous = Contiguous; -+ static int const kStrided = Strided; -+ static int const kCount = Contiguous * Strided; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Coordinate in pitch-linear space -+struct PitchLinearCoord : public Coord<2, int> { -+public: -+ -+ /// Integer-valued index -+ using Index = int; -+ -+ /// Base type is a Coord of rank=2 -+ using Base = Coord<2, Index>; -+ -+ /// Long integer type -+ using LongIndex = typename Base::LongIndex; -+ -+private: -+ -+ /// Rows dimension -+ static int const kContiguous = 0; -+ -+ /// Columns dimension -+ static int const kStrided = 1; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ PitchLinearCoord() { } -+ -+ /// Constructs from Coord<2> -+ CUTLASS_HOST_DEVICE -+ PitchLinearCoord(Coord<2, Index> const &coord): Base(coord) { } -+ -+ /// Helper to construct from a row and column -+ CUTLASS_HOST_DEVICE -+ PitchLinearCoord(Index contiguous_, Index strided_): Base(make_Coord(contiguous_, strided_)) { } -+ -+ /// Helper to construct from a row and column based on LongIndex -+ CUTLASS_HOST_DEVICE -+ PitchLinearCoord(LongIndex contiguous_, LongIndex strided_) -+ : Base(make_Coord(Index(contiguous_), Index(strided_))) { } -+ -+ /// Returns the contiguous dimension -+ CUTLASS_HOST_DEVICE -+ Index const & contiguous() const { return this->at(kContiguous); } -+ -+ /// Returns the contiguous dimension -+ CUTLASS_HOST_DEVICE -+ Index & contiguous() { return this->at(kContiguous); } -+ -+ /// Returns the column of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index const & strided() const { return this->at(kStrided); } -+ -+ /// Returns the column of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index & strided() { return this->at(kStrided); } -+ -+ // -+ // Coord operators -+ // -+ -+ /// Element-wise addition -+ CUTLASS_HOST_DEVICE -+ PitchLinearCoord operator+(Base const& b) const { -+ return PitchLinearCoord(Base::operator+(b)); -+ } -+ -+ /// Element-wise subtraction -+ CUTLASS_HOST_DEVICE -+ PitchLinearCoord operator-(Base const& b) const { -+ return PitchLinearCoord(Base::operator-(b)); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ PitchLinearCoord operator-() const { -+ return PitchLinearCoord(-at(0), -at(1)); -+ } -+ -+ /// Element-wise multiplication -+ CUTLASS_HOST_DEVICE -+ PitchLinearCoord operator*(Base const& b) const { -+ return PitchLinearCoord(Base::operator*(b)); -+ } -+ -+ /// Element-wise division -+ CUTLASS_HOST_DEVICE -+ PitchLinearCoord operator/(Base const& b) const { -+ return PitchLinearCoord(Base::operator/(b)); -+ } -+ -+ /// In-place addition -+ CUTLASS_HOST_DEVICE -+ PitchLinearCoord& operator+=(Base const& b) { -+ Base::operator+=(b); -+ return *this; -+ } -+ -+ /// In-place subtraction -+ CUTLASS_HOST_DEVICE -+ PitchLinearCoord& operator-=(Base const& b) { -+ Base::operator-=(b); -+ return *this; -+ } -+ -+ /// In-place multiplication -+ CUTLASS_HOST_DEVICE -+ PitchLinearCoord& operator*=(Base const& b) { -+ Base::operator*=(b); -+ return *this; -+ } -+ -+ /// In-place division -+ CUTLASS_HOST_DEVICE -+ PitchLinearCoord& operator/=(Base const& b) { -+ Base::operator/=(b); -+ return *this; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/include/cutlass/platform/platform.h b/3rdparty/cutlass/include/cutlass/platform/platform.h -new file mode 100644 -index 0000000..96bb8f6 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/platform/platform.h -@@ -0,0 +1,891 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+/** -+ * \file -+ * \brief C++ features that may be otherwise unimplemented for CUDA device functions. -+ * -+ * This file has three components: -+ * -+ * (1) Macros: -+ * - Empty macro defines for C++ keywords not supported by the current -+ * version of C++. These simply allow compilation to proceed (but do -+ * not provide the added semantics). -+ * - \p noexcept -+ * - \p constexpr -+ * - \p nullptr -+ * - \p static_assert -+ * -+ * - Macro functions that we need in constant expressions because the -+ * C++ equivalents require constexpr compiler support. These are -+ * prefixed with \p __NV_STD_* -+ * - \p __NV_STD_MAX -+ * - \p __NV_STD_MIN -+ * -+ * (2) Re-implementations of STL functions and types: -+ * - C++ features that need the \p __device__ annotation. These are -+ * placed into the \p platform namespace. -+ * - \p abs -+ * - \p plus -+ * - \p less -+ * - \p greater -+ * - \p min -+ * - \p max -+ * - \p methods on std::pair (==, !=, <, <=, >, >=, and make_pair()) -+ * -+ * (3) Stop-gap implementations of unsupported STL functions and types: -+ * - STL functions and types defined by C++ 11/14/17/etc. that are not -+ * provided by the current version of C++. These are placed into the -+ * \p platform namespace -+ * - \p integral_constant -+ * - \p nullptr_t -+ * - \p true_type -+ * - \p false_type -+ * - \p bool_constant -+ * - \p enable_if -+ * - \p conditional -+ * - \p is_same -+ * - \p is_base_of -+ * - \p remove_const -+ * - \p remove_volatile -+ * - \p remove_cv -+ * - \p is_volatile -+ * - \p is_pointer -+ * - \p is_void -+ * - \p is_integral -+ * - \p is_floating_point -+ * - \p is_arithmetic -+ * - \p is_fundamental -+ * - \p is_trivially_copyable -+ * - \p alignment_of -+ * - \p aligned_storage -+ * -+ * (4) Functions and types that are STL-like (but aren't in the STL): -+ * - \p TODO: min and max functors? -+ * -+ * The idea is that, as we drop support for older compilers, we can simply #define -+ * the \p __NV_STD_XYZ macros and \p platform namespace to alias their C++ -+ * counterparts (or trivially find-and-replace their occurrences in code text). -+ */ -+ -+//----------------------------------------------------------------------------- -+// Dependencies -+//----------------------------------------------------------------------------- -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#endif -+ -+#if !defined(__CUDACC_RTC__) -+//----------------------------------------------------------------------------- -+// Include STL files that platform provides functionality for -+//----------------------------------------------------------------------------- -+ -+#include // Minimum/maximum operations -+#include // nullptr_t -+#include // Arithmetic operations -+#include // For methods on std::pair -+#if (!defined(_MSC_VER) && (__cplusplus >= 201103L)) || (defined(_MSC_VER) && (_MS_VER >= 1500)) -+#include // For integral constants, conditional metaprogramming, and type traits -+#endif -+ -+#include "cutlass/cutlass.h" -+ -+#endif -+ -+//----------------------------------------------------------------------------- -+// OS -+//----------------------------------------------------------------------------- -+#if defined(WIN32) || defined(_WIN32) || defined(__WIN32) && !defined(__CYGWIN__) -+#define CUTLASS_OS_WINDOWS -+#endif -+ -+/****************************************************************************** -+ * Macros -+ ******************************************************************************/ -+//----------------------------------------------------------------------------- -+// Keywords -+//----------------------------------------------------------------------------- -+ -+/// noexcept, constexpr -+#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1900)) -+#ifndef noexcept -+#define noexcept -+#endif -+#ifndef constexpr -+#define constexpr -+#endif -+#endif -+ -+/// nullptr -+#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1310)) -+#ifndef nullptr -+#define nullptr 0 -+#endif -+#endif -+ -+/// static_assert -+#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1600)) -+#ifndef static_assert -+#define __platform_cat_(a, b) a##b -+#define __platform_cat(a, b) __platform_cat_(a, b) -+#define static_assert(__e, __m) typedef int __platform_cat(AsSeRt, __LINE__)[(__e) ? 1 : -1] -+#endif -+#endif -+ -+//----------------------------------------------------------------------------- -+// Functions -+//----------------------------------------------------------------------------- -+ -+/// Select maximum(a, b) -+#ifndef __NV_STD_MAX -+#define __NV_STD_MAX(a, b) (((b) > (a)) ? (b) : (a)) -+#endif -+ -+/// Select minimum(a, b) -+#ifndef __NV_STD_MIN -+#define __NV_STD_MIN(a, b) (((b) < (a)) ? (b) : (a)) -+#endif -+ -+/****************************************************************************** -+ * Re-implementations -+ ******************************************************************************/ -+namespace cutlass { -+namespace platform { -+ -+//----------------------------------------------------------------------------- -+// Abs operations -+//----------------------------------------------------------------------------- -+ -+#if defined(__CUDACC_RTC__) -+/// std::abs -+CUTLASS_HOST_DEVICE constexpr int abs(int a) { -+ return (a < 0) ? -a : a; -+} -+CUTLASS_HOST_DEVICE constexpr long long abs(long long a) { -+ return (a < 0) ? -a : a; -+} -+#else -+using std::abs; -+#endif -+ -+//----------------------------------------------------------------------------- -+// Minimum/maximum operations -+//----------------------------------------------------------------------------- -+ -+/// std::min -+template -+CUTLASS_HOST_DEVICE constexpr const T& min(const T& a, const T& b) { -+ return (b < a) ? b : a; -+} -+ -+/// std::max -+template -+CUTLASS_HOST_DEVICE constexpr const T& max(const T& a, const T& b) { -+ return (a < b) ? b : a; -+} -+ -+#if !defined(__CUDACC_RTC__) -+//----------------------------------------------------------------------------- -+// Methods on std::pair -+//----------------------------------------------------------------------------- -+ -+using std::pair; -+ -+template -+CUTLASS_HOST_DEVICE constexpr bool operator==(const pair& lhs, const pair& rhs) { -+ return (lhs.first == rhs.first) && (lhs.second == rhs.second); -+} -+ -+template -+CUTLASS_HOST_DEVICE constexpr bool operator!=(const pair& lhs, const pair& rhs) { -+ return (lhs.first != rhs.first) && (lhs.second != rhs.second); -+} -+ -+template -+CUTLASS_HOST_DEVICE constexpr bool operator<(const pair& lhs, const pair& rhs) { -+ return (lhs.first < rhs.first) ? true : (rhs.first < lhs.first) ? false -+ : (lhs.second < rhs.second); -+} -+ -+template -+CUTLASS_HOST_DEVICE constexpr bool operator<=(const pair& lhs, const pair& rhs) { -+ return !(rhs < lhs); -+} -+ -+template -+CUTLASS_HOST_DEVICE constexpr bool operator>(const pair& lhs, const pair& rhs) { -+ return (rhs < lhs); -+} -+ -+template -+CUTLASS_HOST_DEVICE constexpr bool operator>=(const pair& lhs, const pair& rhs) { -+ return !(lhs < rhs); -+} -+ -+template -+CUTLASS_HOST_DEVICE std::pair make_pair(T1 t, T2 u) { -+ std::pair retval; -+ retval.first = t; -+ retval.second = u; -+ return retval; -+} -+#endif -+ -+} // namespace platform -+ -+/****************************************************************************** -+ * Implementations of C++ 11/14/17/... STL features -+ ******************************************************************************/ -+ -+namespace platform { -+ -+//----------------------------------------------------------------------------- -+// Integral constant helper types -+//----------------------------------------------------------------------------- -+ -+#if defined(__CUDACC_RTC__) || (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1500)) -+ -+/// std::integral_constant -+template -+struct integral_constant; -+ -+/// std::integral_constant -+template -+struct integral_constant { -+ static const value_t value = V; -+ -+ typedef value_t value_type; -+ typedef integral_constant type; -+ -+ CUTLASS_HOST_DEVICE operator value_type() const { return value; } -+ -+ CUTLASS_HOST_DEVICE const value_type operator()() const { return value; } -+}; -+ -+#else -+ -+using std::integral_constant; -+using std::pair; -+ -+#endif -+ -+/// The type used as a compile-time boolean with true value. -+typedef integral_constant true_type; -+ -+/// The type used as a compile-time boolean with false value. -+typedef integral_constant false_type; -+ -+#if defined(__CUDACC_RTC__) || (!defined(_MSC_VER) && (__cplusplus <= 201402L)) || (defined(_MSC_VER) && (_MSC_VER < 1900)) -+ -+/// std::bool_constant -+template -+struct bool_constant : platform::integral_constant {}; -+ -+#else -+ -+using std::bool_constant; -+ -+#endif -+ -+#if defined(__CUDACC_RTC__) || (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1700)) -+ -+/// std::nullptr_t -+struct nullptr_t {}; -+ -+#else -+ -+using std::nullptr_t; -+ -+#endif -+ -+//----------------------------------------------------------------------------- -+// Conditional metaprogramming -+//----------------------------------------------------------------------------- -+ -+#if defined(__CUDACC_RTC__) || (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1600)) -+ -+/// std::enable_if (true specialization) -+template -+struct enable_if { -+ typedef T type; -+}; -+ -+/// std::enable_if (false specialization) -+template -+struct enable_if {}; -+ -+/// std::conditional (true specialization) -+template -+struct conditional { -+ typedef T type; -+}; -+ -+/// std::conditional (false specialization) -+template -+struct conditional { -+ typedef F type; -+}; -+ -+#else -+ -+using std::enable_if; -+using std::conditional; -+ -+#endif -+ -+//----------------------------------------------------------------------------- -+// Const/volatility specifiers -+//----------------------------------------------------------------------------- -+ -+#if defined(__CUDACC_RTC__) || (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1500)) -+ -+/// std::remove_const (non-const specialization) -+template -+struct remove_const { -+ typedef T type; -+}; -+ -+/// std::remove_const (const specialization) -+template -+struct remove_const { -+ typedef T type; -+}; -+ -+/// std::remove_volatile (non-volatile specialization) -+template -+struct remove_volatile { -+ typedef T type; -+}; -+ -+/// std::remove_volatile (volatile specialization) -+template -+struct remove_volatile { -+ typedef T type; -+}; -+ -+/// std::remove_cv -+template -+struct remove_cv { -+ typedef typename remove_volatile::type>::type type; -+}; -+ -+#else -+ -+using std::remove_const; -+using std::remove_volatile; -+using std::remove_cv; -+ -+#endif -+ -+//----------------------------------------------------------------------------- -+// Type relationships -+//----------------------------------------------------------------------------- -+ -+#if defined(__CUDACC_RTC__) || (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1500)) -+ -+/// std::is_same (false specialization) -+template -+struct is_same : false_type {}; -+ -+/// std::is_same (true specialization) -+template -+struct is_same : true_type {}; -+ -+/// Helper for std::is_base_of -+template -+struct is_base_of_helper { -+ typedef char (&yes)[1]; -+ typedef char (&no)[2]; -+ -+ template -+ struct dummy { -+ CUTLASS_HOST_DEVICE operator B*() const; -+ CUTLASS_HOST_DEVICE operator D*(); -+ }; -+ -+ template -+ CUTLASS_HOST_DEVICE static yes check(DerivedT*, T); -+ -+ CUTLASS_HOST_DEVICE static no check(BaseT*, int); -+ -+ static const bool value = sizeof(check(dummy(), int())) == sizeof(yes); -+}; -+ -+/// std::is_base_of -+template -+struct is_base_of -+ : integral_constant::type, -+ typename remove_cv::type>::value) || -+ (is_same::type, -+ typename remove_cv::type>::value)> {}; -+ -+#else -+ -+using std::is_same; -+using std::is_base_of; -+ -+#endif -+ -+//----------------------------------------------------------------------------- -+// Type properties -+//----------------------------------------------------------------------------- -+ -+#if defined(__CUDACC_RTC__) || (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1500)) -+ -+/// std::is_volatile -+template -+struct is_volatile : false_type {}; -+template -+struct is_volatile : true_type {}; -+ -+/// Helper for std::is_pointer (false specialization) -+template -+struct is_pointer_helper : false_type {}; -+ -+/// Helper for std::is_pointer (true specialization) -+template -+struct is_pointer_helper : true_type {}; -+ -+/// std::is_pointer -+template -+struct is_pointer : is_pointer_helper::type> {}; -+ -+/// std::is_void -+template -+struct is_void : is_same::type> {}; -+ -+/// std::is_integral -+template -+struct is_integral : false_type {}; -+template <> -+struct is_integral : true_type {}; -+template <> -+struct is_integral : true_type {}; -+template <> -+struct is_integral : true_type {}; -+template <> -+struct is_integral : true_type {}; -+template <> -+struct is_integral : true_type {}; -+template <> -+struct is_integral : true_type {}; -+template <> -+struct is_integral : true_type {}; -+template <> -+struct is_integral : true_type {}; -+template <> -+struct is_integral : true_type {}; -+template <> -+struct is_integral : true_type {}; -+template <> -+struct is_integral : true_type {}; -+template -+struct is_integral : is_integral {}; -+template -+struct is_integral : is_integral {}; -+template -+struct is_integral : is_integral {}; -+ -+/// std::is_floating_point -+template -+struct is_floating_point -+ : integral_constant::type>::value || -+ is_same::type>::value)> {}; -+ -+/// std::is_arithmetic -+template -+struct is_arithmetic -+ : integral_constant::value || is_floating_point::value)> {}; -+ -+/// std::is_fundamental -+template -+struct is_fundamental -+ : integral_constant::value || is_void::value || -+ is_same::type>::value)> {}; -+ -+#else -+ -+using std::is_volatile; -+using std::is_pointer; -+using std::is_void; -+using std::is_integral; -+using std::is_floating_point; -+using std::is_arithmetic; -+using std::is_fundamental; -+ -+#endif -+ -+#if defined(__CUDACC_RTC__) || (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1800)) || \ -+ (defined(__GNUG__) && (__GNUC__ < 5)) -+ -+/** -+ * std::is_trivially_copyable -+ * -+ * This implementation only evaluates true if T is fundamental or pointer -+ * -+ * Without help from partial template specializations provided by the user for -+ * a specific class or struct, this trait will never report that the specified -+ * class or struct is trivially-copyable ; this is always safe, -+ * if possibly sub-optimal. -+ */ -+template -+struct is_trivially_copyable -+ : integral_constant::value || is_pointer::value)> {}; -+ -+#else -+ -+using std::is_trivially_copyable; -+ -+#endif -+ -+//----------------------------------------------------------------------------- -+// bit_cast -+//----------------------------------------------------------------------------- -+ -+template< class To, class From > -+constexpr To CUTLASS_HOST_DEVICE bit_cast(const From& from ) noexcept; -+ -+template -+constexpr To CUTLASS_HOST_DEVICE bit_cast(const From& src) noexcept -+{ -+ static_assert(sizeof(To) == sizeof(From), "sizes must match"); -+ return reinterpret_cast(src); -+} -+ -+//----------------------------------------------------------------------------- -+// Alignment and layout utilities -+//----------------------------------------------------------------------------- -+ -+#if defined(__CUDACC_RTC__) || (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1500)) -+ -+/// std::alignment_of -+template -+struct alignment_of { -+ struct pad { -+ value_t val; -+ char byte; -+ }; -+ -+ enum { value = sizeof(pad) - sizeof(value_t) }; -+}; -+ -+#else -+ -+template -+struct alignment_of : std::alignment_of {}; -+ -+#endif -+ -+/* 16B specializations where 32-bit Win32 host compiler disagrees with device compiler */ -+template <> -+struct alignment_of { -+ enum { value = 16 }; -+}; -+template <> -+struct alignment_of { -+ enum { value = 16 }; -+}; -+template <> -+struct alignment_of { -+ enum { value = 16 }; -+}; -+template <> -+struct alignment_of { -+ enum { value = 16 }; -+}; -+template <> -+struct alignment_of { -+ enum { value = 16 }; -+}; -+template <> -+struct alignment_of { -+ enum { value = 16 }; -+}; -+template <> -+struct alignment_of { -+ enum { value = 16 }; -+}; -+template <> -+struct alignment_of { -+ enum { value = 16 }; -+}; -+template <> -+struct alignment_of { -+ enum { value = 16 }; -+}; -+template <> -+struct alignment_of { -+ enum { value = 16 }; -+}; -+template <> -+struct alignment_of { -+ enum { value = 16 }; -+}; -+ -+// Specializations for volatile/const qualified types -+template -+struct alignment_of : alignment_of {}; -+template -+struct alignment_of : alignment_of {}; -+template -+struct alignment_of : alignment_of {}; -+ -+#if defined(__CUDACC_RTC__) || (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1800)) -+ -+template -+struct aligned_chunk; -+template <> -+struct __align__(1) aligned_chunk<1> { -+ uint8_t buff; -+}; -+template <> -+struct __align__(2) aligned_chunk<2> { -+ uint16_t buff; -+}; -+template <> -+struct __align__(4) aligned_chunk<4> { -+ uint32_t buff; -+}; -+template <> -+struct __align__(8) aligned_chunk<8> { -+ uint32_t buff[2]; -+}; -+template <> -+struct __align__(16) aligned_chunk<16> { -+ uint32_t buff[4]; -+}; -+template <> -+struct __align__(32) aligned_chunk<32> { -+ uint32_t buff[8]; -+}; -+template <> -+struct __align__(64) aligned_chunk<64> { -+ uint32_t buff[16]; -+}; -+template <> -+struct __align__(128) aligned_chunk<128> { -+ uint32_t buff[32]; -+}; -+template <> -+struct __align__(256) aligned_chunk<256> { -+ uint32_t buff[64]; -+}; -+template <> -+struct __align__(512) aligned_chunk<512> { -+ uint32_t buff[128]; -+}; -+template <> -+struct __align__(1024) aligned_chunk<1024> { -+ uint32_t buff[256]; -+}; -+template <> -+struct __align__(2048) aligned_chunk<2048> { -+ uint32_t buff[512]; -+}; -+template <> -+struct __align__(4096) aligned_chunk<4096> { -+ uint32_t buff[1024]; -+}; -+ -+/// std::aligned_storage -+template -+struct aligned_storage { -+ typedef aligned_chunk type[Len / sizeof(aligned_chunk)]; -+}; -+ -+#else -+ -+using std::aligned_storage; -+ -+#endif -+ -+#if !defined(__CUDACC_RTC__) -+/// Default deleter -+template -+struct default_delete { -+ void operator()(T* ptr) const { delete ptr; } -+}; -+ -+/// Partial specialization for deleting array types -+template -+struct default_delete { -+ void operator()(T* ptr) const { delete[] ptr; } -+}; -+ -+/// std::unique_ptr -+template > -+class unique_ptr { -+ public: -+ typedef T* pointer; -+ typedef T element_type; -+ typedef Deleter deleter_type; -+ -+ private: -+ /// Pointer to memory -+ pointer _ptr; -+ -+ /// Deleter -+ deleter_type _deleter; -+ -+ public: -+ unique_ptr() : _ptr(nullptr) {} -+ unique_ptr(pointer p) : _ptr(p) {} -+ -+ ~unique_ptr() { -+ if (_ptr) { -+ _deleter(_ptr); -+ } -+ } -+ /// Returns a pointer to the managed object or nullptr if no object is owned. -+ pointer get() const noexcept { return _ptr; } -+ -+ /// Releases ownership of the managed object, if any -+ pointer release() noexcept { -+ pointer p(_ptr); -+ _ptr = nullptr; -+ return p; -+ } -+ -+ /// Replaces the managed object, deleting the old object. -+ void reset(pointer p = pointer()) noexcept { -+ pointer old_ptr = _ptr; -+ _ptr = p; -+ if (old_ptr != nullptr) { -+ get_deleter()(old_ptr); -+ } -+ } -+ -+ /// Swaps the managed objects with *this and another unique_ptr -+ void swap(unique_ptr& other) noexcept { std::swap(_ptr, other._ptr); } -+ -+ /// Returns the deleter object -+ Deleter& get_deleter() noexcept { return _deleter; } -+ -+ /// Returns the deleter object -+ Deleter const& get_deleter() const noexcept { return _deleter; } -+ -+ /// Checks whether an object is owned -+ operator bool() const noexcept { return _ptr != nullptr; } -+ -+ /// Dereferences the unique_ptr -+ T& operator*() const { return *_ptr; } -+ -+ /// Returns a pointer to the managed object -+ pointer operator->() const noexcept { return _ptr; } -+ -+ /// Array access to managed object -+ T& operator[](size_t i) const { return _ptr[i]; } -+}; -+ -+/// Specializes the swap algorithm -+template -+void swap(unique_ptr& lhs, unique_ptr& rhs) noexcept { -+ lhs.swap(rhs); -+} -+#endif -+ -+/// std::numeric_limits -+template -+struct numeric_limits; -+ -+template <> -+struct numeric_limits { -+ CUTLASS_HOST_DEVICE -+ static constexpr int32_t lowest() noexcept { return -2147483647 - 1;} -+ CUTLASS_HOST_DEVICE -+ static constexpr int32_t max() noexcept { return 2147483647;} -+ static constexpr bool is_integer = true; -+}; -+ -+template <> -+struct numeric_limits { -+ CUTLASS_HOST_DEVICE -+ static constexpr int16_t lowest() noexcept { return -32768;} -+ CUTLASS_HOST_DEVICE -+ static constexpr int16_t max() noexcept { return 32767;} -+ static constexpr bool is_integer = true; -+}; -+ -+template <> -+struct numeric_limits { -+ CUTLASS_HOST_DEVICE -+ static constexpr int8_t lowest() noexcept { return -128;} -+ CUTLASS_HOST_DEVICE -+ static constexpr int8_t max() noexcept { return 127;} -+ static constexpr bool is_integer = true; -+}; -+ -+ -+template <> -+struct numeric_limits { -+ CUTLASS_HOST_DEVICE -+ static constexpr uint32_t lowest() noexcept { return 0;} -+ CUTLASS_HOST_DEVICE -+ static constexpr uint32_t max() noexcept { return 4294967295U;} -+ static constexpr bool is_integer = true; -+}; -+ -+template <> -+struct numeric_limits { -+ CUTLASS_HOST_DEVICE -+ static constexpr uint16_t lowest() noexcept { return 0;} -+ CUTLASS_HOST_DEVICE -+ static constexpr uint16_t max() noexcept { return 65535U;} -+ static constexpr bool is_integer = true; -+}; -+ -+template <> -+struct numeric_limits { -+ CUTLASS_HOST_DEVICE -+ static constexpr uint8_t lowest() noexcept { return 0;} -+ CUTLASS_HOST_DEVICE -+ static constexpr uint8_t max() noexcept { return 255U;} -+ static constexpr bool is_integer = true; -+}; -+ -+template <> -+struct numeric_limits { -+ CUTLASS_HOST_DEVICE -+ static constexpr float infinity() noexcept { return bit_cast(0x7f800000);} -+ static constexpr bool is_integer = false; -+ static constexpr bool has_infinity = true; -+}; -+ -+} // namespace platform -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/predicate_vector.h b/3rdparty/cutlass/include/cutlass/predicate_vector.h -new file mode 100644 -index 0000000..d158225 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/predicate_vector.h -@@ -0,0 +1,524 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines container classes and iterators for managing a statically sized vector -+ of boolean predicates. -+*/ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#include -+#else -+#include -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/platform/platform.h" -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/*!@defgroup predicate_vector_concept Predicate Vector Concept -+@{ -+ -+Implementations of \ref predicate_vector_concept contain an ordered set of boolean predicates which -+may be used as conditionals in other device-side operations. Both random access and iterators -+offering sequential access are provided. -+ -+@par Predicate Vector -+ A \ref predicate_vector_concept satisfies the following expressions -+ - at(int idx) - returns the value of the indexed predicate -+ - set(int idx, bool value) - sets the value of the indexed predicate -+ - begin() - returns a \ref predicate_iterator_concept pointing to the first predicate -+ -+@} -+*/ -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/*!@defgroup predicate_iterator_concept Predicate Iterator Concept -+@{ -+ -+Implementations of \ref predicate_iterator_concept enables accessing and traversing elements of a -+bit vector. -+ -+@par Const Predicate Iterator -+ A const \ref predicate_iterator_concept satisfies the following expressions -+ - ++it increments the iterator to the next predicate -+ - *it returns the value of the currently pointed-to predicate -+ -+@par Mutable Predicate Iterator -+ A \ref predicate_iterator_concept that is non-const also satisfies the following expressions -+ - it.set(bool value) sets the value of the currently pointed-to predicate -+ -+@} -+*/ -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/*!@defgroup predicate_tile_adapter Predicate Tile Adapter Concept -+@{ -+ -+Implementations of \ref predicate_tile_adapter provide a mapping between a the elements of a \ref -+tile_traits_concept and a \ref predicate_vector_concept. -+ -+@par Predicate Tile Adapter -+ A \ref predicate_tile_adapter satisfies the following expressions -+ - at(int d, int h, int w, int c) - returns the value of a predicate corresponding to the -+ access (d, h, w, c) within the tile. -+ -+@} -+*/ -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Statically sized array of bits implementing @concept{predicate_vector_concept}. -+template < -+ /// Number of predicates conatined in predicate vector -+ int kPredicates_, -+ /// Number of predicates contained in each byte of internal storage -+ int kPredicatesPerByte_ = 4, -+ /// Location of first predicate within byte of internal storage -+ int kPredicateStart_ = 0> -+struct PredicateVector { -+ /// Number of bits stored by the PredicateVector -+ static int const kPredicates = kPredicates_; -+ -+ /// Number of bits stored within each byte of the predicate bit vector -+ static int const kPredicatesPerByte = kPredicatesPerByte_; -+ -+ /// First bit withing each byte containing predicates -+ static int const kPredicateStart = kPredicateStart_; -+ -+ // Make sure no one tries to put more than 8 bits in a byte :) -+ static_assert(kPredicatesPerByte <= 8, "kPredicatesPerByte must fit within an actual byte"); -+ // Make sure the "offsetted" bits fit in one byte. -+ static_assert(kPredicateStart + kPredicatesPerByte <= 8, -+ "The offsetted predicates must fit within an actual byte."); -+ -+ /// Storage type of individual elements -+ typedef uint32_t Storage; -+ -+ /// Number of bytes needed -+ static int const kBytes = (kPredicates + kPredicatesPerByte - 1) / kPredicatesPerByte; -+ -+ /// Number of storage elements needed -+ static int const kWordCount = (kBytes + int(sizeof(Storage)) - 1) / int(sizeof(Storage)); -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Words of bit vector -+ Storage storageData[kWordCount]; -+ -+ // -+ // Methods -+ // -+ -+ /// Computes the word and bit corresponding to a logical predicate index -+ CUTLASS_HOST_DEVICE void computeStorageOffset(int &word, int &bit, int idx) const { -+ CUTLASS_ASSERT(idx < kPredicates); -+ -+ int byte = (idx / kPredicatesPerByte); -+ int bit_offset = (idx % kPredicatesPerByte); -+ -+ word = byte / sizeof(Storage); -+ int byte_offset = (byte % sizeof(Storage)); -+ -+ bit = byte_offset * 8 + bit_offset + kPredicateStart; -+ } -+ -+ /// Accesses a given word with optional assertions -+ CUTLASS_HOST_DEVICE Storage &storage(int word) { -+ CUTLASS_ASSERT(word < kWordCount); -+ return storageData[word]; -+ } -+ -+ /// Accesses a given word with optional assertions -+ CUTLASS_HOST_DEVICE Storage const &storage(int word) const { -+ CUTLASS_ASSERT(word < kWordCount); -+ return storageData[word]; -+ } -+ -+ public: -+ // -+ // Iterator -+ // -+ -+ /** -+ * @brief An iterator implementing \ref predicate_iterator_concept enabling sequential -+ * read and write access to predicates. -+ * @concept{predicate_iterator_concept} -+ */ -+ class Iterator { -+ /// Reference to PredicateVector instance -+ PredicateVector &vec_; -+ -+ /// Index into PredicateVector -+ int bit_; -+ -+ public: -+ /// Copy constructor -+ CUTLASS_HOST_DEVICE -+ Iterator(Iterator const &it) : vec_(it.vec_), bit_(it.bit_) {} -+ -+ /// Constructs an iterator from a PredicateVector -+ CUTLASS_HOST_DEVICE -+ Iterator(PredicateVector &vec, int _start = 0) : vec_(vec), bit_(_start) {} -+ -+ /// Pre-increment -+ CUTLASS_HOST_DEVICE -+ Iterator &operator++() { -+ ++bit_; -+ return *this; -+ } -+ -+ /// Increment -+ CUTLASS_HOST_DEVICE -+ Iterator &operator+=(int offset) { -+ bit_ += offset; -+ return *this; -+ } -+ -+ /// Pre-decrement -+ CUTLASS_HOST_DEVICE -+ Iterator &operator--() { -+ --bit_; -+ return *this; -+ } -+ -+ /// Decrement -+ CUTLASS_HOST_DEVICE -+ Iterator &operator-=(int offset) { -+ bit_ -= offset; -+ return *this; -+ } -+ -+ /// Post-increment -+ CUTLASS_HOST_DEVICE -+ Iterator operator++(int) { -+ Iterator ret(*this); -+ ret.bit_++; -+ return ret; -+ } -+ -+ /// Post-decrement -+ CUTLASS_HOST_DEVICE -+ Iterator operator--(int) { -+ Iterator ret(*this); -+ ret.bit_--; -+ return ret; -+ } -+ -+ /// Iterator advances by some amount -+ CUTLASS_HOST_DEVICE -+ Iterator operator+(int offset) { -+ Iterator ret(*this); -+ ret.bit_ += offset; -+ return ret; -+ } -+ -+ /// Iterator recedes by some amount -+ CUTLASS_HOST_DEVICE -+ Iterator operator-(int offset) { -+ ConstIterator ret(*this); -+ ret.bit_ -= offset; -+ return ret; -+ } -+ -+ /// Returns true if iterators point to the same bit -+ CUTLASS_HOST_DEVICE -+ bool operator==(Iterator const &it) const { return bit_ == it.bit_; } -+ -+ /// Returns false if iterators point to the same bit -+ CUTLASS_HOST_DEVICE -+ bool operator!=(Iterator const &it) const { return bit_ != it.bit_; } -+ -+ /// Gets the bit at the pointed to location -+ CUTLASS_HOST_DEVICE -+ bool get() { return vec_.at(bit_); } -+ -+ /// Gets the bit at the pointed to location -+ CUTLASS_HOST_DEVICE -+ bool at() const { return vec_.at(bit_); } -+ -+ /// Dereferences iterator -+ CUTLASS_HOST_DEVICE -+ bool operator*() const { return at(); } -+ -+ /// Sets the bit at the pointed to location -+ CUTLASS_HOST_DEVICE -+ void set(bool value = true) { vec_.set(bit_, value); } -+ }; -+ -+ /** -+ * @brief An iterator implementing \ref predicate_iterator_concept enabling sequential -+ * read and write access to predicates. -+ * @concept{predicate_iterator_concept} -+ */ -+ class ConstIterator { -+ /// Reference to PredicateVector instance -+ PredicateVector const &vec_; -+ -+ /// Index into PredicateVector -+ int bit_; -+ -+ public: -+ /// Copy constructor -+ CUTLASS_HOST_DEVICE -+ ConstIterator(ConstIterator const &it) : vec_(it.vec_), bit_(it.bit_) {} -+ -+ /// Constructs an iterator from a PredicateVector -+ CUTLASS_HOST_DEVICE -+ ConstIterator(PredicateVector const &vec, int _start = 0) : vec_(vec), bit_(_start) {} -+ -+ /// Pre-increment -+ CUTLASS_HOST_DEVICE -+ ConstIterator &operator++() { -+ ++bit_; -+ return *this; -+ } -+ -+ /// Increment -+ CUTLASS_HOST_DEVICE -+ ConstIterator &operator+=(int offset) { -+ bit_ += offset; -+ return *this; -+ } -+ -+ /// Pre-decrement -+ CUTLASS_HOST_DEVICE -+ ConstIterator &operator--() { -+ --bit_; -+ return *this; -+ } -+ -+ /// Decrement -+ CUTLASS_HOST_DEVICE -+ ConstIterator &operator-=(int offset) { -+ bit_ -= offset; -+ return *this; -+ } -+ -+ /// Post-increment -+ CUTLASS_HOST_DEVICE -+ ConstIterator operator++(int) { -+ ConstIterator ret(*this); -+ ret.bit_++; -+ return ret; -+ } -+ -+ /// Post-decrement -+ CUTLASS_HOST_DEVICE -+ ConstIterator operator--(int) { -+ ConstIterator ret(*this); -+ ret.bit_--; -+ return ret; -+ } -+ -+ /// Iterator advances by some amount -+ CUTLASS_HOST_DEVICE -+ ConstIterator operator+(int offset) { -+ ConstIterator ret(*this); -+ ret.bit_ += offset; -+ return ret; -+ } -+ -+ /// Iterator recedes by some amount -+ CUTLASS_HOST_DEVICE -+ ConstIterator operator-(int offset) { -+ ConstIterator ret(*this); -+ ret.bit_ -= offset; -+ return ret; -+ } -+ -+ /// Returns true if iterators point to the same bit -+ CUTLASS_HOST_DEVICE -+ bool operator==(ConstIterator const &it) const { return bit_ == it.bit_; } -+ -+ /// Returns false if iterators point to the same bit -+ CUTLASS_HOST_DEVICE -+ bool operator!=(ConstIterator const &it) const { return bit_ != it.bit_; } -+ -+ /// Gets the bit at the pointed to location -+ CUTLASS_HOST_DEVICE -+ bool get() { return vec_.at(bit_); } -+ -+ /// Gets the bit at the pointed to location -+ CUTLASS_HOST_DEVICE -+ bool at() const { return vec_.at(bit_); } -+ -+ /// Dereferences iterator -+ CUTLASS_HOST_DEVICE -+ bool operator*() const { return at(); } -+ }; -+ -+ /// Iterator that always returns true -+ struct TrivialIterator { -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ TrivialIterator() {} -+ -+ /// Copy constructor -+ CUTLASS_HOST_DEVICE -+ TrivialIterator(Iterator const &it) {} -+ -+ /// Constructs an iterator from a PredicateVector -+ CUTLASS_HOST_DEVICE -+ TrivialIterator(PredicateVector const &_vec) {} -+ -+ /// Pre-increment -+ CUTLASS_HOST_DEVICE -+ TrivialIterator &operator++() { return *this; } -+ -+ /// Post-increment -+ CUTLASS_HOST_DEVICE -+ TrivialIterator operator++(int) { return *this; } -+ -+ /// Dereferences iterator -+ CUTLASS_HOST_DEVICE -+ bool operator*() const { return true; } -+ }; -+ -+ public: -+ // -+ // Methods -+ // -+ -+ /// Initialize the predicate vector -+ CUTLASS_HOST_DEVICE PredicateVector(bool value = true) { fill(value); } -+ -+ /// Fills all predicates with a given value -+ CUTLASS_HOST_DEVICE void fill(bool value = true) { -+ Storage item = (value ? ~Storage(0) : Storage(0)); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kWordCount; ++i) { -+ storage(i) = item; -+ } -+ } -+ -+ /// Clears all predicates -+ CUTLASS_HOST_DEVICE void clear() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kWordCount; ++i) { -+ storage(i) = 0; -+ } -+ } -+ -+ /// Sets all predicates to true -+ CUTLASS_HOST_DEVICE void enable() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kWordCount; ++i) { -+ storage(i) = ~Storage(0); -+ } -+ } -+ -+ /// Accesses a bit within the predicate vector. -+ CUTLASS_HOST_DEVICE bool operator[](int idx) const { return at(idx); } -+ -+ /// Accesses a bit within the predicate vector. -+ CUTLASS_HOST_DEVICE bool at(int idx) const { -+ int bit, word; -+ computeStorageOffset(word, bit, idx); -+ -+ return ((storage(word) >> bit) & 1); -+ } -+ -+ /// Set a bit within the predicate vector. -+ CUTLASS_HOST_DEVICE void set(int idx, bool value = true) { -+ int bit, word; -+ computeStorageOffset(word, bit, idx); -+ -+ Storage disable_mask = (~(Storage(1) << bit)); -+ Storage enable_mask = (Storage(value) << bit); -+ -+ storage(word) = ((storage(word) & disable_mask) | enable_mask); -+ } -+ -+ /// Computes the intersection of two identical predicate vectors. -+ CUTLASS_HOST_DEVICE PredicateVector &operator&=(PredicateVector const &predicates) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kWordCount; ++i) { -+ storage(i) = (storage(i) & predicates.storage(i)); -+ } -+ return *this; -+ } -+ -+ /// Computes the union of two identical predicate vectors. -+ CUTLASS_HOST_DEVICE PredicateVector &operator|=(PredicateVector const &predicates) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kWordCount; ++i) { -+ storage(i) = (storage(i) | predicates.storage(i)); -+ } -+ return *this; -+ } -+ -+ /// Returns true if entire predicate array is zero. -+ CUTLASS_HOST_DEVICE bool is_zero() const { -+ Storage mask(0); -+ for (int byte = 0; byte < sizeof(Storage); ++byte) { -+ Storage byte_mask = (((1 << kPredicatesPerByte) - 1) << kPredicateStart); -+ mask |= (byte_mask << (byte * 8)); -+ } -+ uint32_t result = 0; -+ for (int word = 0; word < kWordCount; ++word) { -+ result |= storage(word); -+ } -+ return result == 0; -+ } -+ -+ /// Returns an iterator to the start of the bit vector -+ CUTLASS_DEVICE -+ Iterator begin() { return Iterator(*this); } -+ -+ /// Returns an iterator -+ CUTLASS_DEVICE -+ Iterator end() { return Iterator(*this, kPredicates); } -+ -+ /// Returns a ConstIterator -+ CUTLASS_DEVICE -+ ConstIterator const_begin() const { return ConstIterator(*this); } -+ -+ /// Returns a ConstIterator -+ CUTLASS_DEVICE -+ ConstIterator const_end() const { return ConstIterator(*this, kPredicates); } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/quaternion.h b/3rdparty/cutlass/include/cutlass/quaternion.h -new file mode 100644 -index 0000000..1015be4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/quaternion.h -@@ -0,0 +1,753 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines a densely packed quaternion object intended for storing data in registers and -+ executing quaternion operations within a CUDA or host thread. -+*/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/functional.h" -+#include "cutlass/array.h" -+#include "cutlass/real.h" -+#include "cutlass/coord.h" -+#include "cutlass/matrix.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/layout/vector.h" -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Quaternion: xi + yj + zk + w -+template < -+ typename Element_ = float ///< element type -+> -+class Quaternion : public Array { -+public: -+ -+ /// Logical rank of tensor index space -+ static int const kRank = 1; -+ -+ /// Number of elements -+ static int const kExtent = 4; -+ -+ /// Base class is a four-element array -+ using Base = Array; -+ -+ /// Element type -+ using Element = typename Base::Element; -+ -+ /// Reference type to an element -+ using Reference = typename Base::reference; -+ -+ /// Index type -+ using Index = int; -+ -+ /// Quaternion storage - imaginary part -+ static int const kX = 0; -+ -+ /// Quaternion storage - imaginary part -+ static int const kY = 1; -+ -+ /// Quaternion storage - imaginary part -+ static int const kZ = 2; -+ -+ /// Quaternion storage - real part -+ static int const kW = 3; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a quaternion q = 0 -+ CUTLASS_HOST_DEVICE -+ Quaternion() { -+ Base::at(kX) = Element(); -+ Base::at(kY) = Element(); -+ Base::at(kZ) = Element(); -+ Base::at(kW) = Element(); -+ } -+ -+ /// Constructs a quaternion q = w + 0*i + 0*j + 0*k -+ CUTLASS_HOST_DEVICE -+ Quaternion( -+ Element w_ -+ ) { -+ Base::at(kX) = Element(); -+ Base::at(kY) = Element(); -+ Base::at(kZ) = Element(); -+ Base::at(kW) = w_; -+ } -+ -+ /// Constructs a quaternion q = w + x*i + y*j + z*k -+ CUTLASS_HOST_DEVICE -+ Quaternion( -+ Element x_, -+ Element y_, -+ Element z_, -+ Element w_ -+ ) { -+ Base::at(kX) = x_; -+ Base::at(kY) = y_; -+ Base::at(kZ) = z_; -+ Base::at(kW) = w_; -+ } -+ -+ /// Constructs a quaternion from a vector representing the imaginary part and a real number -+ CUTLASS_HOST_DEVICE -+ Quaternion( -+ Matrix3x1 const &imag_, -+ Element w_ = Element() -+ ) { -+ Base::at(kX) = imag_[0]; -+ Base::at(kY) = imag_[1]; -+ Base::at(kZ) = imag_[2]; -+ Base::at(kW) = w_; -+ } -+ -+ /// Returns a reference to the element at a given Coord -+ CUTLASS_HOST_DEVICE -+ Reference at(Index idx) const { -+ return Base::at(idx); -+ } -+ -+ /// Returns a reference to the element at a given Coord -+ CUTLASS_HOST_DEVICE -+ Reference at(Index idx) { -+ return Base::at(idx); -+ } -+ -+ /// Accesses the x element of the imaginary part of the quaternion -+ CUTLASS_HOST_DEVICE -+ Element x() const { -+ return Base::at(kX); -+ } -+ -+ /// Accesses the x element of the imaginary part of the quaternion -+ CUTLASS_HOST_DEVICE -+ Reference x() { -+ return Base::at(kX); -+ } -+ -+ /// Accesses the y element of the imaginary part of the quaternion -+ CUTLASS_HOST_DEVICE -+ Element y() const { -+ return Base::at(kY); -+ } -+ -+ /// Accesses the y element of the imaginary part of the quaternion -+ CUTLASS_HOST_DEVICE -+ Reference y() { -+ return Base::at(kY); -+ } -+ -+ /// Accesses the z element of the imaginary part of the quaternion -+ CUTLASS_HOST_DEVICE -+ Element z() const { -+ return Base::at(kZ); -+ } -+ -+ /// Accesses the z element of the imaginary part of the quaternion -+ CUTLASS_HOST_DEVICE -+ Reference z() { -+ return Base::at(kZ); -+ } -+ -+ /// Accesses the real part of the quaternion -+ CUTLASS_HOST_DEVICE -+ Element w() const { -+ return Base::at(kW); -+ } -+ -+ /// Accesses the real part of the quaternion -+ CUTLASS_HOST_DEVICE -+ Reference w() { -+ return Base::at(kW); -+ } -+ -+ /// Returns the pure imaginary part of the quaternion as a 3-vector -+ CUTLASS_HOST_DEVICE -+ Matrix3x1 pure() const { -+ return Matrix3x1(x(), y(), z()); -+ } -+ -+ /// Returns a quaternion representation of a spatial rotation given a unit-length axis and -+ /// a rotation in radians. -+ CUTLASS_HOST_DEVICE -+ static Quaternion rotation( -+ Matrix3x1 const &axis_unit, ///< axis of rotation (assumed to be unit length) -+ Element theta) { ///< angular rotation in radians -+ -+ Element s = fast_sin(theta / Element(2)); -+ -+ return Quaternion( -+ s * axis_unit[0], -+ s * axis_unit[1], -+ s * axis_unit[2], -+ fast_cos(theta / Element(2)) -+ ); -+ } -+ -+ /// Returns a quaternion representation of a spatial rotation represented as a -+ /// unit-length rotation axis (r_x, r_y, r_z) and an angular rotation in radians -+ CUTLASS_HOST_DEVICE -+ static Quaternion rotation( -+ Element r_x, -+ Element r_y, -+ Element r_z, -+ Element theta) { ///< angular rotation in radians -+ -+ return rotation({r_x, r_y, r_z}, theta); -+ } -+ -+ /// Geometric rotation of a 3-element vector -+ CUTLASS_HOST_DEVICE -+ Matrix3x1 rotate(Matrix3x1 const &rhs) const { -+ return (*this * Quaternion(rhs, 0) * reciprocal(*this)).pure(); -+ } -+ -+ /// Inverse rotation operation -+ CUTLASS_HOST_DEVICE -+ Matrix3x1 rotate_inv(Matrix3x1 const &rhs) const { -+ return (reciprocal(*this) * Quaternion(rhs, 0) * *this).pure(); -+ } -+ -+ /// Rotates a 3-vector assuming this is a unit quaternion (a spinor) -+ CUTLASS_HOST_DEVICE -+ Matrix3x1 spinor(Matrix3x1 const &rhs) const { -+ return (*this * Quaternion(rhs, 0) * conj(*this)).pure(); -+ } -+ -+ /// Inverse rotation of 3-vector assuming this is a unit quaternion (a spinor) -+ CUTLASS_HOST_DEVICE -+ Matrix3x1 spinor_inv(Matrix3x1 const &rhs) const { -+ return (conj(*this) * Quaternion(rhs, 0) * *this).pure(); -+ } -+ -+ /// In-place addition -+ template -+ CUTLASS_HOST_DEVICE -+ Quaternion &operator+=(Quaternion const &rhs) { -+ *this = (*this + rhs); -+ return *this; -+ } -+ -+ /// In-place subtraction -+ template -+ CUTLASS_HOST_DEVICE -+ Quaternion &operator-=(Quaternion const &rhs) { -+ *this = (*this - rhs); -+ return *this; -+ } -+ -+ /// In-place multiplication -+ template -+ CUTLASS_HOST_DEVICE -+ Quaternion &operator*=(Quaternion const &rhs) { -+ *this = (*this * rhs); -+ return *this; -+ } -+ -+ /// Scalar multiplication -+ template -+ CUTLASS_HOST_DEVICE -+ Quaternion &operator*=(Element s) { -+ *this = (*this * s); -+ return *this; -+ } -+ -+ /// In-place Division -+ template -+ CUTLASS_HOST_DEVICE -+ Quaternion &operator/=(Quaternion const &rhs) { -+ *this = (*this / rhs); -+ return *this; -+ } -+ -+ /// In-place Division -+ template -+ CUTLASS_HOST_DEVICE -+ Quaternion &operator/=(Element s) { -+ *this = (*this / s); -+ return *this; -+ } -+ -+ /// Computes a 3x3 rotation matrix (row-major representation) -+ CUTLASS_HOST_DEVICE -+ Matrix3x3 as_rotation_matrix_3x3() const { -+ Matrix3x3 m( -+ w() * w() + x() * x() - y() * y() - z() * z(), -+ 2 * x() * y() - 2 * w() * z(), -+ 2 * x() * z() + 2 * w() * y(), -+ -+ 2 * x() * y() + 2 * w() * z(), -+ w() * w() - x() * x() + y() * y() - z() * z(), -+ 2 * y() * z() - 2 * w() * x(), -+ -+ 2 * x() * z() - 2 * w() * y(), -+ 2 * y() * z() + 2 * w() * x(), -+ w() * w() - x() * x() - y() * y() + z() * z() -+ ); -+ return m; -+ } -+ -+ /// Computes a 4x4 rotation matrix (row-major representation) -+ CUTLASS_HOST_DEVICE -+ Matrix4x4 as_rotation_matrix_4x4() const { -+ Matrix4x4 m = Matrix4x4::identity(); -+ m.set_slice_3x3(as_rotation_matrix_3x3()); -+ return m; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Constructs a quaternion that is non-zero only in its real element. -+template -+CUTLASS_HOST_DEVICE -+Quaternion make_Quaternion( -+ Element w) { ///< real part -+ -+ return Quaternion(w); -+} -+ -+/// Constructs a quaternion from a vector and real -+template -+CUTLASS_HOST_DEVICE -+Quaternion make_Quaternion( -+ Matrix3x1 const &imag, ///< imaginary party as a vector -+ Element w) { ///< real part -+ -+ return Quaternion(imag, w); -+} -+ -+/// Constructs a quaternion from a unit-length rotation axis and a rotation -+/// angle in radians -+template -+CUTLASS_HOST_DEVICE -+Quaternion make_QuaternionRotation( -+ Matrix3x1 const &axis_unit, ///< rotation axis (unit-length) -+ Element w) { ///< rotation angle in radians -+ -+ return Quaternion::rotation(axis_unit, w); -+} -+ -+/// Constructs a quaternion q = xi + yj + zk + w -+template -+CUTLASS_HOST_DEVICE -+Quaternion make_Quaternion(Element x, Element y, Element z, Element w) { -+ return Quaternion(x, y, z, w); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Returns the real part of the quaternion number -+template -+CUTLASS_HOST_DEVICE -+Element const &real(Quaternion const &q) { -+ return q.w(); -+} -+ -+/// Returns the real part of the quaternion number -+template -+CUTLASS_HOST_DEVICE -+Element &real(Quaternion &q) { -+ return q.w(); -+} -+ -+/// Returns the magnitude of the quaternion number -+template -+CUTLASS_HOST_DEVICE -+Element abs(Quaternion const &q) { -+ return fast_sqrt(norm(q)); -+} -+ -+/// Quaternion conjugate -+template -+CUTLASS_HOST_DEVICE -+Quaternion conj(Quaternion const &q) { -+ return make_Quaternion( -+ -q.x(), -+ -q.y(), -+ -q.z(), -+ q.w() -+ ); -+} -+ -+/// Computes the squared magnitude of the quaternion -+template -+CUTLASS_HOST_DEVICE -+Element norm(Quaternion const &q) { -+ return q.x() * q.x() + q.y() * q.y() + q.z() * q.z() + q.w() * q.w(); -+} -+ -+/// Quaternion reciprocal -+template -+CUTLASS_HOST_DEVICE -+Quaternion reciprocal(Quaternion const &q) { -+ -+ Element nsq = norm(q); -+ -+ return make_Quaternion( -+ -q.x() / nsq, -+ -q.y() / nsq, -+ -q.z() / nsq, -+ q.w() / nsq -+ ); -+} -+ -+/// Returns a unit-length quaternion -+template -+CUTLASS_HOST_DEVICE -+Quaternion unit(Quaternion const &q) { -+ -+ Element rcp_mag = Element(1) / abs(q); -+ -+ return make_Quaternion( -+ q.x() * rcp_mag, -+ q.y() * rcp_mag, -+ q.z() * rcp_mag, -+ q.w() * rcp_mag -+ ); -+} -+ -+/// Quaternion exponential -+template -+CUTLASS_HOST_DEVICE -+Quaternion exp(Quaternion const &q) { -+ -+ Element exp_ = fast_exp(q.w()); -+ Element imag_norm = fast_sqrt(q.x() * q.x() + q.y() * q.y() + q.z() * q.z()); -+ Element sin_norm = fast_sin(imag_norm); -+ -+ return make_Quaternion( -+ exp_ * q.x() * sin_norm / imag_norm, -+ exp_ * q.y() * sin_norm / imag_norm, -+ exp_ * q.z() * sin_norm / imag_norm, -+ exp_ * fast_cos(imag_norm) -+ ); -+} -+ -+/// Quaternion natural logarithm -+template -+CUTLASS_HOST_DEVICE -+Quaternion log(Quaternion const &q) { -+ -+ Element v = fast_sqrt(q.x() * q.x() + q.y() * q.y() + q.z() * q.z()); -+ Element s = fast_acos(q.w() / abs(q)) / v; -+ -+ return make_Quaternion( -+ q.x() * s, -+ q.y() * s, -+ q.z() * s, -+ fast_log(q.w()) -+ ); -+} -+ -+/// Gets the rotation angle from a unit-length quaternion -+template -+CUTLASS_HOST_DEVICE -+Element get_rotation_angle(Quaternion const &q_unit) { -+ return fast_acos(q_unit.w()) * Element(2); -+} -+ -+/// Gets the rotation axis from a unit-length quaternion -+template -+CUTLASS_HOST_DEVICE -+Matrix3x1 get_rotation_axis(Quaternion const &q_unit) { -+ return q_unit.pure().unit(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Equality operator -+template -+CUTLASS_HOST_DEVICE -+bool operator==(Quaternion const &lhs, Quaternion const &rhs) { -+ return lhs.x() == rhs.x() && -+ lhs.y() == rhs.y() && -+ lhs.z() == rhs.z() && -+ lhs.w() == rhs.w(); -+} -+ -+/// Inequality operator -+template -+CUTLASS_HOST_DEVICE -+bool operator!=(Quaternion const &lhs, Quaternion const &rhs) { -+ return !(lhs == rhs); -+} -+ -+/// Quaternion scalar multiplication -+template -+CUTLASS_HOST_DEVICE -+Quaternion operator*(Quaternion q, Element s) { -+ return make_Quaternion( -+ q.x() * s, -+ q.y() * s, -+ q.z() * s, -+ q.w() * s -+ ); -+} -+ -+/// Quaternion scalar multiplication -+template -+CUTLASS_HOST_DEVICE -+Quaternion operator*(Element s, Quaternion const &q) { -+ return make_Quaternion( -+ s * q.x(), -+ s * q.y(), -+ s * q.z(), -+ s * q.w() -+ ); -+} -+ -+/// Quaternion scalar division -+template -+CUTLASS_HOST_DEVICE -+Quaternion operator/(Quaternion const &q, Element s) { -+ return make_Quaternion( -+ q.x() / s, -+ q.y() / s, -+ q.z() / s, -+ q.w() / s -+ ); -+} -+ -+/// Quaternion unary negation -+template -+CUTLASS_HOST_DEVICE -+Quaternion operator-(Quaternion const &q) { -+ return make_Quaternion( -+ -q.x(), -+ -q.y(), -+ -q.z(), -+ -q.w() -+ ); -+} -+ -+/// Quaternion addition -+template -+CUTLASS_HOST_DEVICE -+Quaternion operator+(Quaternion const &lhs, Quaternion const &rhs) { -+ return make_Quaternion( -+ lhs.x() + rhs.x(), -+ lhs.y() + rhs.y(), -+ lhs.z() + rhs.z(), -+ lhs.w() + rhs.w() -+ ); -+} -+ -+/// Quaternion subtraction -+template -+CUTLASS_HOST_DEVICE -+Quaternion operator-(Quaternion const &lhs, Quaternion const &rhs) { -+ return make_Quaternion( -+ lhs.x() - rhs.x(), -+ lhs.y() - rhs.y(), -+ lhs.z() - rhs.z(), -+ lhs.w() - rhs.w() -+ ); -+} -+ -+/// Quaternion product -+template -+CUTLASS_HOST_DEVICE -+Quaternion operator*(Quaternion const &lhs, Quaternion const &rhs) { -+ return make_Quaternion( -+ lhs.w() * rhs.x() + rhs.w() * lhs.x() + lhs.y() * rhs.z() - lhs.z() * rhs.y(), -+ lhs.w() * rhs.y() + rhs.w() * lhs.y() + lhs.z() * rhs.x() - lhs.x() * rhs.z(), -+ lhs.w() * rhs.z() + rhs.w() * lhs.z() + lhs.x() * rhs.y() - lhs.y() * rhs.x(), -+ lhs.w() * rhs.w() - lhs.x() * rhs.x() - lhs.y() * rhs.y() - lhs.z() * rhs.z() -+ ); -+} -+ -+/// Quaternion division -+template -+CUTLASS_HOST_DEVICE -+Quaternion operator/(Quaternion const &lhs, Quaternion const &rhs) { -+ return lhs * reciprocal(rhs); -+} -+ -+/// Quaternion scalar division -+template -+CUTLASS_HOST_DEVICE -+Quaternion operator/(Element s, Quaternion const &q) { -+ return s * reciprocal(q); -+} -+ -+/// Comparison -+template -+CUTLASS_HOST_DEVICE -+bool operator<(Quaternion const &lhs, Quaternion const &rhs) { -+ //TODO -+ return true; -+} -+ -+/// Rotates a 3-vector assuming this is a unit quaternion (a spinor). This avoids computing -+/// a reciprocal. -+template -+CUTLASS_HOST_DEVICE -+Matrix3x1 spinor_rotation( -+ Quaternion const &spinor, /// unit-length quaternion -+ Matrix3x1 const &rhs) { /// arbitrary 3-vector -+ -+ return (spinor * Quaternion(rhs, 0) * conj(spinor)).pure(); -+} -+ -+/// Inverse rotation of 3-vector assuming this is a unit quaternion (a spinor). This avoids computing -+/// a reciprocal. -+template -+CUTLASS_HOST_DEVICE -+Matrix3x1 spinor_rotation_inv( -+ Quaternion const &spinor, /// unit-length quaternion -+ Matrix3x1 const &rhs) { /// arbitrary 3-vector -+ -+ return (conj(spinor) * Quaternion(rhs, 0) * spinor).pure(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for Quaternion-valued type. -+template -+struct RealType< Quaternion > { -+ using Type = T; -+ -+ /// Number of elements -+ static int const kExtent = Quaternion::kExtent; -+ -+CUTLASS_HOST_DEVICE -+ static Quaternion from_real(double x) { -+ return Quaternion(static_cast(x)); -+ } -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+// Factories -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+CUTLASS_HOST_DEVICE -+cutlass::Quaternion from_real >(double r) { -+ return cutlass::Quaternion(half_t(r)); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+cutlass::Quaternion from_real >(double r) { -+ return cutlass::Quaternion(float(r)); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+cutlass::Quaternion from_real >(double r) { -+ return cutlass::Quaternion(r); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// functional.h numeric specializations -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct multiplies> { -+ CUTLASS_HOST_DEVICE -+ Quaternion operator()(Quaternion lhs, Quaternion const &rhs) const { -+ lhs = lhs * rhs; -+ return lhs; -+ } -+}; -+ -+/// Squares with optional conversion -+template -+struct magnitude_squared, Output> { -+ CUTLASS_HOST_DEVICE -+ Output operator()(Quaternion lhs) const { -+ multiplies mul_op; -+ -+ Output y_w = Output(lhs.w()); -+ Output y_x = Output(lhs.x()); -+ Output y_y = Output(lhs.y()); -+ Output y_z = Output(lhs.z()); -+ -+ return mul_op(y_w, y_w) + mul_op(y_x, y_x) + mul_op(y_y, y_y) + \ -+ mul_op(y_z, y_z); -+ } -+}; -+ -+template -+struct multiply_add, Quaternion, Quaternion> { -+ CUTLASS_HOST_DEVICE -+ Quaternion operator()( -+ Quaternion const &a, -+ Quaternion const &b, -+ Quaternion const &c) const { -+ -+ T x = c.x(); -+ T y = c.y(); -+ T z = c.z(); -+ T w = c.w(); -+ -+ x += a.w() * b.x(); -+ x += b.w() * a.x(); -+ x += a.y() * b.z(); -+ x += -a.z() * b.y(), -+ -+ y += a.w() * b.y(); -+ y += b.w() * a.y(); -+ y += a.z() * b.x(); -+ y += -a.x() * b.z(); -+ -+ z += a.w() * b.z(); -+ z += b.w() * a.z(); -+ z += a.x() * b.y(); -+ z += -a.y() * b.x(); -+ -+ w += a.w() * b.w(); -+ w += -a.x() * b.x(); -+ w += -a.y() * b.y(); -+ w += -a.z() * b.z(); -+ -+ return cutlass::make_Quaternion(x, y, z, w); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/include/cutlass/real.h b/3rdparty/cutlass/include/cutlass/real.h -new file mode 100644 -index 0000000..ed9018a ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/real.h -@@ -0,0 +1,61 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/** -+ \file -+ \brief This class provides helpers to support real<> and complex<> types in generic code. -+*/ -+ -+#pragma once -+ -+namespace cutlass { -+ -+/// Used to determine the real-valued underlying type of a numeric type T. -+template -+struct RealType { -+ using Type = T; -+ -+ /// Number of elements -+ static int const kExtent = 1; -+ -+CUTLASS_HOST_DEVICE -+ static T from_real(double x) { -+ return static_cast(x); -+ } -+}; -+ -+template -+CUTLASS_HOST_DEVICE -+static T from_real(double r) { -+ return T(r); -+} -+ -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/reduction/device/reduce_split_k.h b/3rdparty/cutlass/include/cutlass/reduction/device/reduce_split_k.h -new file mode 100644 -index 0000000..92e1f61 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/reduction/device/reduce_split_k.h -@@ -0,0 +1,223 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Kernel performing a reduction over densely packed tensors in global memory -+*/ -+ -+#pragma once -+ -+#include "cutlass/device_kernel.h" -+#include "cutlass/reduction/kernel/reduce_split_k.h" -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace reduction { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ReductionKernel_ -+> -+class ReduceSplitK { -+public: -+ using ReductionKernel = ReductionKernel_; -+ -+ using Shape = typename ReductionKernel::Shape; -+ using ReductionOp = typename ReductionKernel::ReductionOp; -+ using OutputOp = typename ReductionKernel::OutputOp; -+ -+ using ElementWorkspace = typename ReductionKernel::ElementWorkspace; -+ using ElementAccumulator = typename ReductionKernel::ElementAccumulator; -+ using ElementOutput = typename ReductionKernel::ElementOutput; -+ -+ using WorkspaceTensorRef = typename ReductionKernel::WorkspaceTensorRef; -+ using OutputTensorRef = typename ReductionKernel::OutputTensorRef; -+ -+ using StrideIndex = typename ReductionKernel::StrideIndex; -+ -+ /// Argument structure -+ struct Arguments { -+ -+ // -+ // Data members -+ // -+ -+ MatrixCoord problem_size; -+ int partitions; -+ size_t partition_stride; -+ WorkspaceTensorRef workspace; -+ OutputTensorRef destination; -+ OutputTensorRef source; -+ typename OutputOp::Params output; -+ typename ReductionOp::Params reduction; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Arguments() : -+ problem_size(0, 0), -+ partitions(1), -+ partition_stride(0) { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ MatrixCoord const & problem_size -+ ): -+ problem_size(problem_size) { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ MatrixCoord problem_size_, -+ int partitions_, -+ size_t partition_stride_, -+ WorkspaceTensorRef workspace_, -+ OutputTensorRef destination_, -+ OutputTensorRef source_, -+ typename OutputOp::Params output_ = typename OutputOp::Params(), -+ typename ReductionOp::Params reduction_ = typename ReductionOp::Params() -+ ): -+ problem_size(problem_size_), -+ partitions(partitions_), -+ partition_stride(partition_stride_), -+ workspace(workspace_), -+ destination(destination_), -+ source(source_), -+ output(output_), -+ reduction(reduction_) -+ { -+ -+ } -+ -+ }; -+ -+private: -+ /// Kernel parameters object -+ typename ReductionKernel::Params params_; -+ -+public: -+ /// Constructs Reduction SplitK -+ ReduceSplitK() { } -+ -+ /// Determines whether the ReduceSplitK can execute the given problem. -+ static Status can_implement(Arguments const &args) { -+ -+ return Status::kSuccess; -+ } -+ -+ /// Gets the workspace size -+ static size_t get_workspace_size(Arguments const &args) { -+ // needs no additional workspace -+ return 0; -+ } -+ -+ /// Initializes Reduction state from arguments. -+ Status initialize( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ // initialize the params structure from the arguments -+ params_ = typename ReductionKernel::Params( -+ args.problem_size, -+ args.partitions, -+ args.partition_stride, -+ args.workspace, -+ args.destination, -+ args.source, -+ args.output, -+ args.reduction -+ ); -+ -+ return Status::kSuccess; -+ -+ } -+ -+ /// Initializes Reduction kernel state from arguments. -+ Status update(Arguments const &args, void *workspace = nullptr) { -+ -+ // update the params structure from the arguments -+ params_.workspace.reset(args.workspace.non_const_ref().data()); -+ params_.destination.reset(args.destination.non_const_ref().data()); -+ params_.source.reset(args.source.non_const_ref().data()); -+ params_.output = args.output; -+ params_.reduction = args.reduction; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status run(cudaStream_t stream = nullptr) { -+ -+ // -+ // Launch reduction kernel -+ // -+ dim3 block = ReductionKernel::block_shape(); -+ dim3 grid = ReductionKernel::grid_shape(params_.problem_size); -+ -+ Kernel<<< grid, block, 0, stream >>>(params_); -+ -+ cudaError_t result = cudaGetLastError(); -+ -+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; -+ } -+ -+ -+ /// Runs the kernel using initialized state. -+ Status operator()(cudaStream_t stream = nullptr) { -+ return run(stream); -+ } -+ -+ /// Runs the kernel using initialized state. -+ Status operator()( -+ Arguments const &args, -+ void *workspace = nullptr, -+ cudaStream_t stream = nullptr) { -+ -+ Status status = initialize(args, workspace, stream); -+ -+ if (status == Status::kSuccess) { -+ status = run(stream); -+ } -+ -+ return status; -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace reduction -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/reduction/device/tensor_reduce.h b/3rdparty/cutlass/include/cutlass/reduction/device/tensor_reduce.h -new file mode 100644 -index 0000000..31d50f6 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/reduction/device/tensor_reduce.h -@@ -0,0 +1,264 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Kernel performing a reduction over one or more ranks of an affine tensor -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/reduction/device/tensor_reduce_affine_strided.h" -+#include "cutlass/reduction/device/tensor_reduce_affine_contiguous.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace reduction { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tensor reduction operator on specific CUTLASS layouts over exactly one index -+template < -+ typename ElementOutput_, -+ typename ElementSource_, -+ typename Layout_, -+ typename ReductionOp_, -+ int VectorLength_ = 1, -+ typename ElementCompute_ = ElementOutput_ -+> -+struct TensorReduction { -+ -+ using ElementOutput = ElementOutput_; -+ using ElementSource = ElementSource_; -+ using Layout = Layout_; -+ using ReductionOp = ReductionOp_; -+ static int const kVectorLength = VectorLength_; -+ using ElementCompute = ElementCompute_; -+ -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ /// Reduction operator -+ using ReductionDeviceStridedOperator = TensorReductionAffineStrided< -+ 4, 3, ElementOutput, ElementSource, ReductionOp, kVectorLength, ElementCompute -+ >; -+ -+ using ReductionDeviceContiguousOperator = TensorReductionAffineContiguous< -+ 4, 3, ElementOutput, ElementSource, ReductionOp, kVectorLength, ElementCompute -+ >; -+ -+ // -+ // Data members -+ // -+ -+ ReductionDeviceStridedOperator reduction_strided; -+ ReductionDeviceContiguousOperator reduction_contiguous; -+ int reduction_index; -+ -+ // -+ // Methods -+ // -+ -+ /// -+ TensorReduction( -+ TensorCoord extent, -+ int reduction_index_ -+ ): -+ reduction_index(reduction_index_) { -+ -+ Coord<4> extent_affine; -+ -+ switch (reduction_index) { -+ case 0: -+ extent_affine[0] = extent[1]; -+ extent_affine[1] = extent[2]; -+ extent_affine[2] = extent[0]; -+ extent_affine[3] = extent[3]; -+ break; -+ case 1: -+ extent_affine[0] = extent[0]; -+ extent_affine[1] = extent[2]; -+ extent_affine[2] = extent[1]; -+ extent_affine[3] = extent[3]; -+ break; -+ case 2: -+ extent_affine[0] = extent[0]; -+ extent_affine[1] = extent[1]; -+ extent_affine[2] = extent[2]; -+ extent_affine[3] = extent[3]; -+ break; -+ case 3: -+ extent_affine[0] = extent[0]; -+ extent_affine[1] = extent[1]; -+ extent_affine[2] = extent[2]; -+ extent_affine[3] = extent[3]; -+ break; -+ default: break; -+ } -+ -+ if (reduction_index == 3) { -+ reduction_contiguous = ReductionDeviceContiguousOperator(extent_affine); -+ } -+ else { -+ reduction_strided = ReductionDeviceStridedOperator(extent_affine); -+ } -+ } -+ -+ /// Simple check to verify the object is initialized correctly -+ bool good() const { -+ if (reduction_index == 3) { -+ return reduction_contiguous.good(); -+ } -+ return reduction_strided.good(); -+ } -+ -+ /// Size of one workspace -+ int64_t workspace_stride() const { -+ if (reduction_index == 3) { -+ return reduction_contiguous.workspace_stride(); -+ } -+ else { -+ return reduction_strided.workspace_stride(); -+ } -+ } -+ -+ /// Returns the size (in bytes) of a temporary workspace needed for reduction across CTAs -+ int64_t workspace_size() const { -+ if (reduction_index == 3) { -+ return reduction_contiguous.workspace_size(); -+ } -+ else { -+ return reduction_strided.workspace_size(); -+ } -+ } -+ -+ /// Helper to use overloaded function call operator -+ Status reduce( -+ TensorRef dst_ref, -+ TensorRef src_ref, -+ void *device_workspace_ptr = nullptr, -+ ElementCompute reduction_identity = ElementCompute(), -+ ReductionOp reduction_op = ReductionOp(), -+ cudaStream_t stream = nullptr) { -+ -+ int64_t src_stride[3]; -+ int64_t dst_stride[3]; -+ -+ switch (reduction_index) { -+ case 0: -+ src_stride[0] = src_ref.stride()[1]; -+ src_stride[1] = src_ref.stride()[0]; -+ src_stride[2] = src_ref.stride()[2]; -+ dst_stride[0] = dst_ref.stride()[1]; -+ dst_stride[1] = dst_ref.stride()[0]; -+ break; -+ case 1: -+ src_stride[0] = src_ref.stride()[2]; -+ src_stride[1] = src_ref.stride()[0]; -+ src_stride[2] = src_ref.stride()[1]; -+ dst_stride[0] = dst_ref.stride()[2]; -+ dst_stride[1] = dst_ref.stride()[0]; -+ break; -+ case 2: -+ src_stride[0] = src_ref.stride()[2]; -+ src_stride[1] = src_ref.stride()[1]; -+ src_stride[2] = src_ref.stride()[0]; -+ dst_stride[0] = dst_ref.stride()[2]; -+ dst_stride[1] = dst_ref.stride()[1]; -+ break; -+ case 3: -+ src_stride[0] = src_ref.stride()[2]; -+ src_stride[1] = src_ref.stride()[1]; -+ src_stride[2] = src_ref.stride()[0]; -+ -+ dst_stride[0] = dst_ref.stride()[2]; -+ dst_stride[1] = dst_ref.stride()[1]; -+ dst_stride[2] = dst_ref.stride()[0]; -+ -+ default: break; -+ } -+ -+ if (reduction_index == 3) { -+ return reduction_contiguous( -+ dst_ref.data(), -+ dst_stride, -+ src_ref.data(), -+ src_stride, -+ device_workspace_ptr, -+ reduction_identity, -+ reduction_op, -+ stream); -+ } -+ else { -+ return reduction_strided( -+ dst_ref.data(), -+ dst_stride, -+ src_ref.data(), -+ src_stride, -+ device_workspace_ptr, -+ reduction_identity, -+ reduction_op, -+ stream); -+ } -+ } -+ -+ Status operator()( -+ TensorRef dst_ref, -+ TensorRef src_ref, -+ void *device_workspace_ptr = nullptr, -+ ElementCompute reduction_identity = ElementCompute(), -+ ReductionOp reduction_op = ReductionOp(), -+ cudaStream_t stream = nullptr) { -+ -+ return reduce( -+ dst_ref, -+ src_ref, -+ device_workspace_ptr, -+ reduction_identity, -+ reduction_op, -+ stream); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace reduction -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_contiguous.h b/3rdparty/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_contiguous.h -new file mode 100644 -index 0000000..234a1c4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_contiguous.h -@@ -0,0 +1,373 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Kernel performing a reduction over one or more ranks of an affine tensor -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/reduction/kernel/tensor_reduce_affine_contiguous.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace reduction { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tensor reduction operator on layouts which are affine -+template < -+ int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) -+ int ReducedRank, ///< Rank of reduced tensor (e.g. ND => 2) -+ typename ElementOutput_, -+ typename ElementSource_, -+ typename ReductionOp_, -+ int VectorLength = 1, -+ typename ElementCompute_ = ElementOutput_, -+ int Threads = 256, ///< Number of participating threads -+ int BatchSize = 4 ///< Number of elements to load per batch -+> -+struct TensorReductionAffineContiguous { -+ -+ static int const kRank = Rank; -+ static int const kReducedRank = ReducedRank; -+ static int const kVectorLength = VectorLength; -+ static int const kInnerRank = kRank - kReducedRank; -+ static int const kThreads = Threads; -+ static int const kBatchSize = BatchSize; -+ -+ using ElementOutput = ElementOutput_; -+ using ElementSource = ElementSource_; -+ using ReductionOp = ReductionOp_; -+ using ElementCompute = ElementCompute_; -+ -+ // -+ // Data members -+ // -+ -+ /// Internal status field -+ Status status; -+ -+ /// Extent of tensor in source layout -+ Coord extent; -+ -+ /// Number of points in the outer index space -+ int64_t outer_count; -+ -+ /// Number of elements in the inner index space -+ int64_t inner_count; -+ -+ /// Number of workspaces needed -+ int workspace_count; -+ -+ /// CUDA Grid shape (.x => contiguous, .y => outer, .z => inner) -+ dim3 grid_shape; -+ -+ /// CUDA Threadblock shape (.x => contiguous, .y => outer, .z => inner) -+ dim3 threadblock_shape; -+ -+ /// CUDA grid shape for the final reduction step if needed -+ dim3 grid_final; -+ -+ /// CUDA threadblock shape for the final reduction step if needed -+ dim3 threadblock_final; -+ -+private: -+ // -+ // Methods -+ // -+ -+ /// Helper to reshape 'count' such that it is less than 2 x 'ext' -+ static int reshape_pow2(int ext, int count) { -+ if (ext > count) { -+ return 1; -+ } -+ int x = 1; -+ for (; count >= ext * 2; ) { -+ count >>= 1; -+ x <<= 1; -+ } -+ return x; -+ } -+ -+public: -+ -+ /// Default ctor -+ TensorReductionAffineContiguous(): -+ status(Status::kErrorInvalidProblem), -+ extent(), -+ outer_count(0), -+ inner_count(0), -+ workspace_count(0), -+ grid_shape(0, 0, 0), -+ threadblock_shape(0, 0, 0) { } -+ -+ /// Constructor -+ TensorReductionAffineContiguous( -+ Coord extent_, -+ int target_threadblock_count = 128 -+ ): -+ status(Status::kSuccess), -+ extent(extent_), -+ outer_count(0), -+ inner_count(0), -+ workspace_count(0) { -+ -+ // -+ // Plan the parallel mapping strategy. -+ // -+ -+ outer_count = 1; -+ inner_count = 1; -+ -+ // Compute number of elements in strided ranks -+ for (int p = 0; p < kReducedRank; ++p) { -+ outer_count *= extent[p]; -+ } -+ -+ for (int p = 0; p < kInnerRank; ++p) { -+ inner_count *= extent[kReducedRank + p]; -+ } -+ -+ int cta_count_x = 1; -+ int cta_count_y = 1; -+ int cta_count_z = 1; -+ -+ int cta_threads_x = kThreads; -+ int cta_threads_y = 1; -+ int cta_threads_z = 1; -+ -+ // Determine CTA shape -+ int64_t inner_vector_count = inner_count / kVectorLength; -+ -+ // Priority 1. Assign threadblocks to outer indices if possible -+ if (outer_count > target_threadblock_count) { -+ cta_count_x = 1; -+ cta_count_y = target_threadblock_count; -+ cta_count_z = 1; -+ } -+ else { -+ -+ cta_count_y = int(outer_count); -+ int remaining_ctas = target_threadblock_count / cta_count_y; -+ -+ // Priority 2. Assign inner dimensions to one CTA -+ if (inner_vector_count > cta_threads_x) { -+ int64_t cta_z_bound = inner_vector_count / cta_threads_x; -+ if (cta_z_bound > remaining_ctas) { -+ cta_count_z = remaining_ctas; -+ } -+ else { -+ cta_count_z = int(cta_z_bound); -+ } -+ } -+ else { -+ cta_threads_x = reshape_pow2(int(inner_vector_count), cta_threads_x); -+ cta_count_z = 1; -+ } -+ } -+ -+ grid_shape = dim3(cta_count_x, cta_count_y, cta_count_z); -+ threadblock_shape = dim3(cta_threads_x, cta_threads_y, cta_threads_z); -+ -+ workspace_count = (cta_count_z > 1 ? cta_count_z : 0); -+ -+ // Determine shape of final reduction kernel if needed -+ if (workspace_count) { -+ -+ int final_threads = kThreads; -+ int final_ctas = 1; -+ -+ if (outer_count > kThreads) { -+ final_ctas = int(outer_count + kThreads - 1) / kThreads; -+ } -+ else { -+ final_threads = int(outer_count); -+ } -+ -+ grid_final = dim3(final_ctas, 1, 1); -+ threadblock_final = dim3(final_threads, 1, 1); -+ } -+ else { -+ grid_final = dim3(0, 0, 0); -+ threadblock_final = dim3(0, 0, 0); -+ } -+ } -+ -+ /// Simple check to verify the object is initialized correctly -+ bool good() const { -+ return status == Status::kSuccess; -+ } -+ -+ /// Size (in bytes) of workspace elements which are densely packed together -+ int64_t workspace_stride() const { -+ -+ // Error condition -+ if (!good()) { -+ return 0; -+ } -+ -+ return outer_count * sizeof_bits::value / 8; -+ } -+ -+ /// Returns the size (in bytes) of a temporary workspace needed for reduction across CTAs -+ int64_t workspace_size() const { -+ -+ // Error condition -+ if (!good()) { -+ return 0; -+ } -+ -+ // No reduction across CTAs -+ if (grid_shape.z == 1) { -+ return 0; -+ } -+ -+ return workspace_stride() * grid_shape.z; -+ } -+ -+ /// Performs a reduction -+ Status reduce( -+ ElementOutput *dst_ptr, ///< Pointer to destination tensor -+ int64_t dst_stride[], ///< Stride vector (of length kReducedRank - 1) -+ ElementSource const *src_ptr, ///< Pointer to source tensor -+ int64_t src_stride[], ///< Stride vector (of length kRank - 1) -+ void *device_workspace_ptr = nullptr, ///< Device workspace -+ ElementCompute reduction_identity = ElementCompute(), ///< Reduction identity element -+ ReductionOp reduction_op = ReductionOp(), ///< Reduction operator -+ cudaStream_t stream = nullptr) { ///< CUDA Stream into which all kernels are launched -+ -+ // Initial status check -+ if (!good()) { -+ return status; -+ } -+ -+ // Guard against null workspace -+ if (workspace_count > 1 && device_workspace_ptr == nullptr) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ // Define reduction kernel -+ using ReductionKernel = kernel::TensorReductionAffineContiguous< -+ kRank, -+ kReducedRank, -+ ElementOutput, -+ ElementSource, -+ ReductionOp, -+ kVectorLength, -+ ElementCompute, -+ kThreads>; -+ -+ using FinalReductionKernel = kernel::TensorReductionAffineContiguousFinal< -+ kRank, -+ kReducedRank, -+ ElementOutput, -+ ElementSource, -+ ReductionOp, -+ kVectorLength, -+ ElementCompute, -+ kThreads>; -+ -+ using Params = typename ReductionKernel::Params; -+ -+ // Construct the parameters -+ Params params( -+ extent, -+ dst_ptr, -+ dst_stride, -+ src_ptr, -+ src_stride, -+ static_cast(device_workspace_ptr), -+ workspace_stride(), -+ workspace_count, -+ reduction_op, -+ reduction_identity); -+ -+ // Shared memory size -+ int shared_mem_bytes = sizeof(typename ReductionKernel::SharedStorage); -+ -+ // Launch the kernel -+ Kernel<<< grid_shape, threadblock_shape, shared_mem_bytes, stream >>>(params); -+ -+ // Check error condition -+ if (cudaPeekAtLastError() == cudaSuccess) { -+ status = Status::kSuccess; -+ } -+ else { -+ status = Status::kErrorInternal; -+ } -+ -+ // Final reduction kernel -+ if (workspace_count) { -+ Kernel<<< grid_final, threadblock_final, 0, stream >>>(params); -+ } -+ -+ // Check error condition -+ if (cudaPeekAtLastError() == cudaSuccess) { -+ status = Status::kSuccess; -+ } -+ else { -+ status = Status::kErrorInternal; -+ } -+ -+ return status; -+ } -+ -+ /// Helper to use overloaded function call operator -+ Status operator()( -+ ElementOutput *dst_ptr, ///< Pointer to destination tensor -+ int64_t dst_stride[], ///< Stride vector (of length kReducedRank - 1) -+ ElementSource const *src_ptr, ///< Pointer to source tensor -+ int64_t src_stride[], ///< Stride vector (of length kRank - 1) -+ void *device_workspace_ptr = nullptr, ///< Pointer to device workspace -+ ElementCompute reduction_identity = ElementCompute(), ///< Reduction identity element -+ ReductionOp reduction_op = ReductionOp(), ///< Reduction operator -+ cudaStream_t stream = nullptr) { ///< CUDA Stream into which all kernels are launched -+ -+ return reduce(dst_ptr, dst_stride, src_ptr, src_stride, device_workspace_ptr, reduction_identity, reduction_op, stream); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace reduction -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_strided.h b/3rdparty/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_strided.h -new file mode 100644 -index 0000000..e613934 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_strided.h -@@ -0,0 +1,361 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Kernel performing a reduction over one or more ranks of an affine tensor -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/reduction/kernel/tensor_reduce_affine_strided.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace reduction { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tensor reduction operator on layouts which are affine -+template < -+ int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) -+ int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2) -+ typename ElementOutput_, -+ typename ElementSource_, -+ typename ReductionOp_, -+ int VectorLength = 1, -+ typename ElementCompute_ = ElementOutput_, -+ int Threads = 256, ///< Number of participating threads -+ int BatchSize = 4 ///< Number of elements to load per batch -+> -+struct TensorReductionAffineStrided { -+ -+ static int const kRank = Rank; -+ static int const kReducedRank = ReducedRank; -+ static int const kVectorLength = VectorLength; -+ static int const kInnerRank = kRank - kReducedRank; -+ static int const kThreads = Threads; -+ static int const kBatchSize = BatchSize; -+ -+ using ElementOutput = ElementOutput_; -+ using ElementSource = ElementSource_; -+ using ReductionOp = ReductionOp_; -+ using ElementCompute = ElementCompute_; -+ -+ // -+ // Data members -+ // -+ -+ /// Internal status field -+ Status status; -+ -+ /// Extent of tensor in source layout -+ Coord extent; -+ -+ /// Number of points in the outer index space -+ int64_t outer_count; -+ -+ /// Number of elements in the inner index space -+ int64_t inner_count; -+ -+ /// Number of workspaces needed -+ int workspace_count; -+ -+ /// CUDA Grid shape (.x => contiguous, .y => outer, .z => inner) -+ dim3 grid_shape; -+ -+ /// CUDA Threadblock shape (.x => contiguous, .y => outer, .z => inner) -+ dim3 threadblock_shape; -+ -+ /// CUDA grid shape for the final reduction step if needed -+ dim3 grid_final; -+ -+ /// CUDA threadblock shape for the final reduction step if needed -+ dim3 threadblock_final; -+ -+private: -+ // -+ // Methods -+ // -+ -+ /// Helper to reshape 'count' such that it is less than 2 x 'ext' -+ static int reshape_pow2(int ext, int count) { -+ if (ext > count) { -+ return 1; -+ } -+ int x = 1; -+ for (; count >= ext * 2; ) { -+ count >>= 1; -+ x <<= 1; -+ } -+ return x; -+ } -+ -+public: -+ -+ /// Default ctor -+ TensorReductionAffineStrided(): -+ status(Status::kErrorInvalidProblem), -+ extent(), -+ outer_count(0), -+ inner_count(0), -+ workspace_count(0), -+ grid_shape(0, 0, 0), -+ threadblock_shape(0, 0, 0) { } -+ -+ /// Constructor -+ TensorReductionAffineStrided( -+ Coord extent_, -+ int target_threadblock_count = 128 -+ ): -+ status(Status::kSuccess), -+ extent(extent_), -+ outer_count(0), -+ inner_count(0), -+ workspace_count(0) { -+ -+ // -+ // Plan the parallel mapping strategy. -+ // -+ -+ outer_count = 1; -+ inner_count = 1; -+ -+ // Compute number of elements in strided ranks -+ for (int p = 0; p < kReducedRank - 1; ++p) { -+ outer_count *= extent[p]; -+ } -+ -+ for (int p = 0; p < kInnerRank; ++p) { -+ inner_count *= extent[kReducedRank + p - 1]; -+ } -+ -+ // Compute plan for the reduction -+ int extent_c = extent[kRank - 1]; -+ int vectors_c = (extent_c -1 + kVectorLength) / kVectorLength; -+ -+ // Determine CTA shape -+ int cta_width = kThreads * kVectorLength; -+ int cta_ways = reshape_pow2(extent_c, cta_width); -+ int cta_threads_x = kThreads / cta_ways; -+ -+ threadblock_shape = dim3(cta_threads_x, 1, std::min(cta_ways, 64)); -+ -+ // This leads to an error. -+ if (threadblock_shape.z > 1) { -+ if (threadblock_shape.y != 1) { -+ status = Status::kErrorInternal; -+ return; -+ } -+ } -+ -+ // Determine grid shape -+ int cta_count_x = (vectors_c + cta_threads_x - 1) / cta_threads_x; -+ int cta_count_y = std::max(1, target_threadblock_count / cta_count_x); -+ -+ // Limit the number of CTAs assigned to outer dimension -+ if (int64_t(cta_count_y * threadblock_shape.y) > outer_count) { -+ cta_count_y = int(outer_count + threadblock_shape.y - 1) / threadblock_shape.y; -+ } -+ -+ // Limit the number of CTAs assigned to inner dimension -+ int cta_count_z = std::max(1, target_threadblock_count / cta_count_y); -+ if (int64_t(cta_count_z * threadblock_shape.z) > inner_count) { -+ cta_count_z = int(inner_count + threadblock_shape.z - 1) / threadblock_shape.z; -+ } -+ -+ grid_shape = dim3(cta_count_x, cta_count_y, cta_count_z); -+ workspace_count = (cta_count_z > 1 ? cta_count_z : 0); -+ -+ // Determine shape of final reduction kernel if needed -+ grid_final = dim3(cta_count_x, int(outer_count)); -+ threadblock_final = dim3(cta_threads_x, 1, 1); -+ } -+ -+ /// Simple check to verify the object is initialized correctly -+ bool good() const { -+ return status == Status::kSuccess; -+ } -+ -+ /// Size of one CTA's workspace -+ int64_t workspace_stride() const { -+ -+ // Error condition -+ if (!good()) { -+ return 0; -+ } -+ -+ int vector_size_bytes = kVectorLength * sizeof_bits::value / 8; -+ -+ return extent[kRank - 1] * vector_size_bytes; -+ } -+ -+ /// Returns the size (in bytes) of a temporary workspace needed for reduction across CTAs -+ int64_t workspace_size() const { -+ -+ // Error condition -+ if (!good()) { -+ return 0; -+ } -+ -+ // No reduction across CTAs -+ if (grid_shape.z == 1) { -+ return 0; -+ } -+ -+ return workspace_stride() * outer_count * grid_shape.z; -+ } -+ -+ /// Performs a reduction -+ Status reduce( -+ ElementOutput *dst_ptr, ///< Pointer to destination tensor -+ int64_t dst_stride[], ///< Stride vector (of length kReducedRank - 1) -+ ElementSource const *src_ptr, ///< Pointer to source tensor -+ int64_t src_stride[], ///< Stride vector (of length kRank - 1) -+ void *device_workspace_ptr = nullptr, ///< Device workspace -+ ElementCompute reduction_identity = ElementCompute(), ///< Reduciton identity -+ ReductionOp reduction_op = ReductionOp(), ///< Reduction operator -+ cudaStream_t stream = nullptr) { ///< CUDA Stream into which all kernels are launched -+ -+ // Initial status check -+ if (!good()) { -+ return status; -+ } -+ -+ // Guard against null workspace -+ if (workspace_count > 1 && device_workspace_ptr == nullptr) { -+ return Status::kErrorWorkspaceNull; -+ } -+ -+ // Define reduction kernel -+ using ReductionKernel = kernel::TensorReductionAffineStrided< -+ kRank, -+ kReducedRank, -+ ElementOutput, -+ ElementSource, -+ ReductionOp, -+ kVectorLength, -+ ElementCompute, -+ kThreads>; -+ -+ using FinalReductionKernel = kernel::TensorReductionAffineStridedFinal< -+ kRank, -+ kReducedRank, -+ ElementOutput, -+ ElementSource, -+ ReductionOp, -+ kVectorLength, -+ ElementCompute, -+ kThreads>; -+ -+ using Params = typename ReductionKernel::Params; -+ -+ // Construct the parameters -+ Params params( -+ extent, -+ dst_ptr, -+ dst_stride, -+ src_ptr, -+ src_stride, -+ static_cast(device_workspace_ptr), -+ workspace_stride(), -+ workspace_count, -+ reduction_op, -+ reduction_identity); -+ -+ // Shared memory size -+ int shared_mem_bytes = sizeof(typename ReductionKernel::SharedStorage); -+ -+ // Launch the kernel -+ Kernel<<< grid_shape, threadblock_shape, shared_mem_bytes, stream >>>(params); -+ -+ // Check error condition -+ if (cudaPeekAtLastError() == cudaSuccess) { -+ status = Status::kSuccess; -+ } -+ else { -+ status = Status::kErrorInternal; -+ } -+ -+ // Final reduction kernel -+ if (workspace_count) { -+ -+ Kernel<<< grid_final, threadblock_final, 0, stream >>>(params); -+ -+ // Check error condition -+ if (cudaPeekAtLastError() == cudaSuccess) { -+ status = Status::kSuccess; -+ } -+ else { -+ status = Status::kErrorInternal; -+ } -+ } -+ -+ return status; -+ } -+ -+ /// Helper to use overloaded function call operator -+ Status operator()( -+ ElementOutput *dst_ptr, ///< Pointer to destination tensor -+ int64_t dst_stride[], ///< Stride vector (of length kReducedRank - 1) -+ ElementSource const *src_ptr, ///< Pointer to source tensor -+ int64_t src_stride[], ///< Stride vector (of length kRank - 1) -+ void *device_workspace_ptr = nullptr, ///< Pointer to device workspace -+ ElementCompute reduction_identity = ElementCompute(), ///< Reduciton identity -+ ReductionOp reduction_op = ReductionOp(), ///< Reduction operator -+ cudaStream_t stream = nullptr) { ///< CUDA Stream into which all kernels are launched -+ -+ return reduce( -+ dst_ptr, -+ dst_stride, -+ src_ptr, -+ src_stride, -+ device_workspace_ptr, -+ reduction_identity, -+ reduction_op, -+ stream); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace reduction -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/reduction/kernel/reduce_softmax_final.h b/3rdparty/cutlass/include/cutlass/reduction/kernel/reduce_softmax_final.h -new file mode 100644 -index 0000000..99e8aed ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/reduction/kernel/reduce_softmax_final.h -@@ -0,0 +1,267 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Kernel performing a final reduction for softmax -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/functional.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/arch/memory_sm75.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace reduction { -+namespace kernel { -+ -+template < -+ typename ElementNorm_, -+ typename ElementSum_, -+ typename ElementSoftmaxCompute_, -+ typename ThreadblockShape_, -+ bool GroupedProblem = false -+> -+class ApplySoftmaxFinalReduction { -+public: -+ -+ using ElementNorm = ElementNorm_; -+ using ElementSum = ElementSum_; -+ using ElementSoftmaxCompute = ElementSoftmaxCompute_; -+ using ThreadblockShape = ThreadblockShape_; -+ static const bool isGroupedProblem = GroupedProblem; -+ -+ // -+ // Arguments -+ // -+ -+ struct Arguments { -+ -+ cutlass::gemm::GemmCoord* problem_sizes; -+ cutlass::gemm::GemmCoord problem_size; -+ ElementNorm* block_Norm; -+ ElementSum* block_Sum; -+ int64_t* offset_Norm_Device; -+ int64_t* offset_Sum_Device; -+ int64_t batch_stride_Max; -+ int64_t batch_stride_Sum; -+ -+ // -+ // Methods -+ // -+ Arguments() { } -+ -+ // Non-grouped constructor without batching -+ Arguments( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementNorm* block_Norm, -+ ElementSum* block_Sum -+ ): -+ problem_size(problem_size), -+ block_Norm(block_Norm), -+ block_Sum(block_Sum), -+ problem_sizes(nullptr), -+ offset_Norm_Device(nullptr), -+ offset_Sum_Device(nullptr), -+ batch_stride_Max(0), -+ batch_stride_Sum(0) -+ { -+ -+ } -+ -+ // Non-grouped constructor with batching -+ Arguments( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementNorm* block_Norm, -+ ElementSum* block_Sum, -+ int64_t batch_stride_Max, -+ int64_t batch_stride_Sum -+ ): -+ problem_size(problem_size), -+ block_Norm(block_Norm), -+ block_Sum(block_Sum), -+ batch_stride_Max(batch_stride_Max), -+ batch_stride_Sum(batch_stride_Sum), -+ problem_sizes(nullptr), -+ offset_Norm_Device(nullptr), -+ offset_Sum_Device(nullptr) -+ { -+ -+ } -+ -+ -+ // Grouped constructor -+ Arguments( -+ cutlass::gemm::GemmCoord *problem_sizes, -+ ElementNorm* block_Norm, -+ ElementSum* block_Sum, -+ int64_t* offset_Norm_Device, -+ int64_t* offset_Sum_Device -+ ): -+ problem_sizes(problem_sizes), -+ problem_size(cutlass::gemm::GemmCoord(0, 0, 0)), -+ block_Norm(block_Norm), -+ block_Sum(block_Sum), -+ offset_Norm_Device(offset_Norm_Device), -+ offset_Sum_Device(offset_Sum_Device) -+ { -+ -+ } -+ }; -+ -+ struct SharedStorage { -+ -+ -+ }; -+ -+ // -+ // Params struct -+ // -+ -+ struct Params { -+ Arguments args; -+ -+ // -+ // Methods -+ // -+ Params() { } -+ -+ Params(Arguments const &args_): args(args_) { } -+ }; -+ -+private: -+ -+public: -+ -+ CUTLASS_DEVICE -+ ApplySoftmaxFinalReduction() { } -+ -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ apply(params, shared_storage); -+ } -+ -+private: -+ -+ /// Full reduction -+ CUTLASS_DEVICE -+ void apply(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ int tid = threadIdx.x; -+ int bid = blockIdx.x; -+ int bdim = blockDim.x; -+ -+ int block_batch = blockIdx.z; -+ -+ // defining three vars for a general reduction module -+ cutlass::gemm::GemmCoord problem_size = isGroupedProblem ? params.args.problem_sizes[bid] : params.args.problem_size; -+ int m_dim_in_loop = isGroupedProblem ? problem_size.m() : tid + bdim; -+ int access_offset = isGroupedProblem ? 0 : bid * bdim; -+ -+ if (!isGroupedProblem && access_offset + tid >= problem_size.m()) return; -+ -+ ElementNorm *curr_ptr_Max = isGroupedProblem ? \ -+ params.args.block_Norm + params.args.offset_Norm_Device[bid] : \ -+ params.args.block_Norm + block_batch * params.args.batch_stride_Max; -+ ElementSum *curr_ptr_Sum = isGroupedProblem ? \ -+ params.args.block_Sum + params.args.offset_Sum_Device[bid] : \ -+ params.args.block_Sum + block_batch * params.args.batch_stride_Sum; -+ -+ int threadblock_num = (problem_size.n() + ThreadblockShape::kN - 1) / ThreadblockShape::kN; -+ -+ using ConvertSumOutput = cutlass::NumericConverter; -+ using ConvertNormOutput = cutlass::NumericConverter; -+ -+ using ConvertSum = cutlass::NumericConverter; -+ using ConvertNorm = cutlass::NumericConverter; -+ -+ ConvertSum convert_sum; -+ ConvertNorm convert_norm; -+ -+ ConvertSumOutput convert_sum_output; -+ ConvertNormOutput convert_norm_output; -+ -+ uint32_t float_max_bits = 0xff7fffff; -+ float min_float = reinterpret_cast(float_max_bits); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int idx_m = tid; idx_m < m_dim_in_loop; idx_m += bdim) { -+ ElementNorm *access_n = curr_ptr_Max + idx_m + access_offset; -+ ElementSum *access_s = curr_ptr_Sum + idx_m + access_offset; -+ ElementNorm *access_n_bak = access_n; -+ ElementSum *access_s_bak = access_s; -+ ElementSoftmaxCompute max_val = ElementSoftmaxCompute(min_float); -+ ElementSoftmaxCompute sum_val = ElementSoftmaxCompute(0); -+ ElementNorm fetch_n; -+ ElementSum fetch_s; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int idx_n = 0; idx_n < threadblock_num; idx_n++) { -+ cutlass::arch::global_load(fetch_n, access_n, true); -+ max_val = cutlass::fast_max(max_val, convert_norm(fetch_n)); -+ access_n += problem_size.m(); -+ } -+ -+ access_n = access_n_bak; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int idx_n = 0; idx_n < threadblock_num; idx_n++) { -+ cutlass::arch::global_load(fetch_n, access_n, true); -+ cutlass::arch::global_load(fetch_s, access_s, true); -+ sum_val += convert_sum(fetch_s) * cutlass::fast_exp(convert_norm(fetch_n) - max_val); -+ access_n += problem_size.m(); -+ access_s += problem_size.m(); -+ } -+ -+ ElementSoftmaxCompute inv_sum = cutlass::constants::one() / sum_val; -+ -+ access_n = access_n_bak; -+ access_s = access_s_bak; -+ -+ access_n[0] = convert_norm_output(max_val); -+ access_s[0] = convert_sum_output(inv_sum); -+ } -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace reduction -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/reduction/kernel/reduce_split_k.h b/3rdparty/cutlass/include/cutlass/reduction/kernel/reduce_split_k.h -new file mode 100644 -index 0000000..96847e7 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/reduction/kernel/reduce_split_k.h -@@ -0,0 +1,248 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Kernel performing a reduction over densely packed tensors in global memory -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/functional.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_conversion.h" -+ -+#include "cutlass/layout/matrix.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace reduction { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape_, ///< shape of CTA (concept: MatrixShape) -+ typename OutputOp_ , ///< output operator (concept: epilogue::thread operator) -+ typename ReductionOp_, ///< reduction operator (concept: ReductionOperator) -+ int PartitionsPerStage = 4 ///< number of partitions to issue -+> -+class ReduceSplitK { -+public: -+ -+ using Shape = Shape_; -+ using ReductionOp = ReductionOp_; -+ using OutputOp = OutputOp_; -+ static int const kElementsPerAccess = OutputOp::kCount; -+ static int const kPartitionsPerStage = PartitionsPerStage; -+ -+ using ElementWorkspace = typename ReductionOp::Element; -+ using ElementAccumulator = typename ReductionOp::ElementAccumulator; -+ using ElementOutput = typename OutputOp::ElementOutput; -+ -+ using WorkspaceTensorRef = TensorRef; -+ using OutputTensorRef = TensorRef; -+ using StrideIndex = typename WorkspaceTensorRef::Layout::Stride::Index; -+ -+ using FragmentWorkspace = AlignedArray; -+ using FragmentAccumulator = Array; -+ using FragmentOutput = AlignedArray; -+ -+ // -+ // Types -+ // -+ -+ /// Params structure -+ struct Params { -+ -+ MatrixCoord problem_size; -+ int partitions; -+ size_t partition_stride; -+ WorkspaceTensorRef workspace; -+ OutputTensorRef destination; -+ OutputTensorRef source; -+ typename OutputOp::Params output; -+ typename ReductionOp::Params reduction; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ MatrixCoord problem_size_, -+ int partitions_, -+ size_t partition_stride_, -+ WorkspaceTensorRef workspace_, -+ OutputTensorRef destination_, -+ OutputTensorRef source_, -+ typename OutputOp::Params output_ = typename OutputOp::Params(), -+ typename ReductionOp::Params reduction_ = typename ReductionOp::Params() -+ ): -+ problem_size(problem_size_), -+ partitions(partitions_), -+ partition_stride(sizeof(FragmentWorkspace) * partition_stride_ / kElementsPerAccess), -+ workspace(workspace_), -+ destination(destination_), -+ source(source_), -+ output(output_), -+ reduction(reduction_) { -+ -+ } -+ }; -+ -+ struct SharedStorage { }; -+ -+ -+public: -+ -+ /// Computes the grid size given a chosen threadblock shape -+ CUTLASS_HOST_DEVICE -+ static dim3 grid_shape( -+ cutlass::MatrixCoord problem_size) { -+ -+ return dim3( -+ (problem_size.row() + Shape::kRow - 1) / Shape::kRow, -+ (problem_size.column() + Shape::kColumn - 1) / Shape::kColumn); -+ } -+ -+ /// Determines the threadblock shape -+ CUTLASS_HOST_DEVICE -+ static dim3 block_shape() { -+ return dim3(Shape::kColumn / kElementsPerAccess, Shape::kRow); -+ } -+ -+ /// Perform a reduction -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &storage) { -+ -+ // Determine CTA position -+ MatrixCoord thread_offset( -+ MatrixCoord::Index(int(blockIdx.x) * Shape::kRow + threadIdx.y), -+ MatrixCoord::Index(int(blockIdx.y) * Shape::kColumn + threadIdx.x * kElementsPerAccess) -+ ); -+ -+ // One guard conditional -+ if (!(thread_offset.row() < params.problem_size.row() && -+ thread_offset.column() < params.problem_size.column())) { -+ -+ return; -+ } -+ -+ -+ ReductionOp reduction_op(params.reduction); -+ -+ FragmentAccumulator accumulator; -+ -+ accumulator.clear(); -+ -+ // -+ // Load the first slice -+ // -+ -+ char const *workspace_ptr = -+ reinterpret_cast( -+ params.workspace.data() + params.workspace.offset(thread_offset)); -+ -+ FragmentWorkspace workspace_frag[kPartitionsPerStage]; -+ -+ // -+ // Construct the output operator -+ // -+ -+ OutputOp output_op(params.output); -+ -+ // -+ // Load and accumulate with a simple batched loading sequence. -+ // -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int k = 0; k < params.partitions; k += kPartitionsPerStage) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPartitionsPerStage; ++i) { -+ if (k + i < params.partitions) { -+ workspace_frag[i] = *reinterpret_cast(workspace_ptr); -+ workspace_ptr += params.partition_stride; -+ } -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPartitionsPerStage; ++i) { -+ if (k + i < params.partitions) { -+ accumulator = reduction_op(accumulator, workspace_frag[i]); -+ } -+ } -+ } -+ -+ // -+ // Conditionally load the source -+ // -+ -+ FragmentOutput source_frag; -+ -+ source_frag.clear(); -+ -+ FragmentOutput const *source_ptr = reinterpret_cast( -+ params.source.data() + params.source.offset(thread_offset)); -+ -+ if (output_op.is_source_needed()) { -+ reinterpret_cast(source_frag) = *source_ptr; -+ } -+ -+ // -+ // Compute the output -+ // -+ -+ typename OutputOp::FragmentOutput output_frag = output_op(accumulator, source_frag); -+ -+ // -+ // Store -+ // -+ -+ FragmentOutput *dest_ptr = reinterpret_cast( -+ params.destination.data() + params.destination.offset(thread_offset)); -+ -+ *dest_ptr = reinterpret_cast(output_frag); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace reduction -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_contiguous.h b/3rdparty/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_contiguous.h -new file mode 100644 -index 0000000..d139ed4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_contiguous.h -@@ -0,0 +1,606 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Kernel performing a reduction over one or more ranks of an affine tensor -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/reduction/thread/reduction_operators.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace reduction { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Parameters structure -+template < -+ int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) -+ int ReducedRank, ///< Rank of reduced tensor (i.e. number of outer ranks) -+ typename ElementOutput, ///< Data type of output tensor -+ typename ElementSource, ///< Data type of source tensor -+ typename ReductionOp, ///< Reduction operator -+ int VectorLength = 1, ///< Vector length for memory -+ typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation -+ int Threads = 256, ///< Number of participating threads -+ int BatchSize = 4 ///< Number of elements to load per batch -+> -+struct TensorReductionAffineContiguousParams { -+ -+ static int const kRank = Rank; -+ static int const kReducedRank = ReducedRank; -+ static int const kVectorLength = VectorLength; -+ static int const kInnerRank = kRank - kReducedRank; -+ static int const kThreads = Threads; -+ static int const kBatchSize = BatchSize; -+ -+ Coord extent; /// Extent of source tensor -+ FastDivmodU64 divmod[kRank - 1]; /// FastDivmod by each strided rank -+ int64_t dst_stride[kReducedRank]; /// stride (units of bytes) - I, J -+ int64_t src_stride[kRank - 1]; /// stride (units of bytes) - I, J, K -+ int64_t workspace_stride; /// stride (units of bytes) between workspace -+ int workspace_count; /// number of workspaces -+ -+ uint64_t inner_count; /// Number of elements in reduced index space -+ uint64_t outer_count; /// Number of elements in outer index space -+ -+ ElementOutput * destination; /// Pointer to output tensor of rank kReducedRank -+ ElementSource const * source; /// Poitner to source pointer of rank kRank -+ ReductionOp reduction_op; /// Reduction operator -+ ElementCompute reduction_identity; /// Identity element used by reduction operator -+ ElementCompute *device_workspace; /// Pointer to device workspace for inter-CTA reductions -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorReductionAffineContiguousParams() { -+ -+ } -+ -+ /// Ctor -+ TensorReductionAffineContiguousParams( -+ Coord extent_, ///< Extent of source tensor -+ ElementOutput * dst_ptr_, ///< Output tensor data -+ int64_t dst_stride_[], ///< Stride (units of elements) -+ ElementSource const * src_ptr_, ///< Source tensor data -+ int64_t src_stride_[], ///< Stride (units of elements) -+ ElementCompute *device_workspace_, ///< Pointer to device workspace for inter-CTA reductions -+ int64_t workspace_stride_, ///< Stride between workspaces -+ int workspace_count_, ///< Number of workspaces -+ ReductionOp reduction_op_, ///< Reduction operator -+ ElementCompute reduction_identity_ = ElementCompute() ///< Identity element used by reduction operator -+ ): -+ extent(extent_), -+ inner_count(1), -+ outer_count(1), -+ destination(dst_ptr_), -+ source(src_ptr_), -+ device_workspace(device_workspace_), -+ workspace_stride(workspace_stride_), -+ workspace_count(workspace_count_), -+ reduction_op(reduction_op_), -+ reduction_identity(reduction_identity_) { -+ -+ // Initialize divisors for fast div-mod -+ for (int p = 1; p < kRank; ++p) { -+ divmod[p - 1] = FastDivmodU64(uint64_t(extent[p])); -+ } -+ -+ int input_size_bits = sizeof_bits::value; -+ int output_size_bits = sizeof_bits::value; -+ -+ // Compute strides in units of bytes -+ for (int p = 0; p < kReducedRank; ++p) { -+ dst_stride[p] = dst_stride_[p] * output_size_bits / 8; -+ } -+ -+ for (int p = 0; p < kRank - 1; ++p) { -+ src_stride[p] = src_stride_[p] * input_size_bits / 8; -+ } -+ -+ // Compute number of elements in strided ranks -+ for (int p = 0; p < kReducedRank; ++p) { -+ outer_count *= uint64_t(extent[p]); -+ } -+ -+ for (int p = 0; p < kInnerRank; ++p) { -+ inner_count *= uint64_t(extent[kRank - 1 - p]); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Kernel to reduce a tensor with affine layout over a set of ranks *INCLUDING* the contiguous -+/// rank. This leads to favorable vectorized memory accesses over the contiguous rank. -+template < -+ int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) -+ int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2) -+ typename ElementOutput, ///< Data type of output tensor -+ typename ElementSource, ///< Data type of source tensor -+ typename ReductionOp, ///< Reduction operator -+ int VectorLength = 1, ///< Vector length for memory -+ typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation -+ int Threads = 256, ///< Number of participating threads -+ int BatchSize = 4 ///< Number of elements to load per batch -+> -+class TensorReductionAffineContiguous { -+public: -+ -+ static int const kRank = Rank; -+ static int const kReducedRank = ReducedRank; -+ static int const kVectorLength = VectorLength; -+ static int const kInnerRank = kRank - kReducedRank; -+ static int const kThreads = Threads; -+ static int const kBatchSize = BatchSize; -+ using ComputeFragment = Array; -+ using SourceFragment = AlignedArray; -+ using OutputFragment = AlignedArray; -+ -+ /// Shared memory allocation used for reduction within the CTA -+ struct SharedStorage { -+ Array workspace; -+ }; -+ -+ /// Parameters structure -+ using Params = TensorReductionAffineContiguousParams< -+ Rank, -+ ReducedRank, -+ ElementOutput, -+ ElementSource, -+ ReductionOp, -+ VectorLength, -+ ElementCompute, -+ Threads, -+ BatchSize -+ >; -+ -+private: -+ -+ /// Computes the coordinate and offset of a given linear index -+ CUTLASS_DEVICE -+ void compute_inner_coord_and_offset_( -+ Params const ¶ms, -+ Coord & coord, -+ int64_t &src_offset, -+ uint64_t linear_idx) const { -+ -+ // Decompose into a coordinate of rank -+ coord = CoordinateDecomposition(linear_idx, ¶ms.divmod[kRank - kInnerRank]); -+ -+ // Compute an offset using the souce stride -+ src_offset = 0; -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kInnerRank - 1; ++i) { -+ src_offset += coord[i] * params.src_stride[kReducedRank + i]; -+ } -+ src_offset += coord[kInnerRank - 1] * sizeof_bits::value / 8; -+ } -+ -+ /// Computes the coordinate and offset of a given linear index -+ CUTLASS_DEVICE -+ void compute_outer_coord_and_offset_( -+ Params const ¶ms, -+ Coord & coord, -+ int64_t &dst_offset, -+ int64_t &src_offset, -+ uint64_t linear_idx) const { -+ -+ // Decompose into coordinate of rank -+ coord = CoordinateDecomposition(linear_idx, params.divmod); -+ -+ // Compute offsets using destination and source strides -+ dst_offset = 0; -+ src_offset = 0; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kReducedRank; ++i) { -+ dst_offset += params.dst_stride[i] * coord[i]; -+ src_offset += params.src_stride[i] * coord[i]; -+ } -+ } -+ -+ /// Reduces over the reduction indices yielding a single element -+ CUTLASS_DEVICE -+ ElementCompute reduce_indices_( -+ Params const ¶ms, -+ ElementCompute *threadblock_workspace, -+ char const *src_byte_ptr, -+ int coord_c) { -+ -+ NumericArrayConverter convert_source; -+ ReductionOp reduction_op(params.reduction_op); -+ -+ // -+ // Early exit or initialize to identity element -+ // -+ if (!params.inner_count) { -+ return params.reduction_identity; -+ } -+ -+ ComputeFragment accumulator; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < accumulator.size(); ++i) { -+ accumulator[i] = params.reduction_identity; -+ } -+ -+ // Compute the coordinate of the first access -+ int64_t src_byte_offset = 0; -+ Coord coord; -+ -+ uint64_t linear_idx = (threadIdx.x + blockDim.x * threadIdx.z + blockDim.x * blockIdx.z * blockDim.z) * kVectorLength; -+ compute_inner_coord_and_offset_(params, coord, src_byte_offset, linear_idx); -+ -+ // Load the first vector -+ SourceFragment source_fragment[kBatchSize]; -+ -+ bool not_done = true; -+ -+ // Iterate over vectors in a linearized reduction index space -+ while (not_done) { -+ -+ bool guards[kBatchSize]; -+ -+ // Issue a batch of loads -+ CUTLASS_PRAGMA_UNROLL -+ for (int b = 0; b < kBatchSize; ++b) { -+ -+ if (linear_idx < params.inner_count) { -+ source_fragment[b] = *reinterpret_cast(src_byte_ptr + src_byte_offset); -+ guards[b] = true; -+ } -+ else { -+ guards[b] = false; -+ not_done = false; -+ } -+ -+ linear_idx += (blockDim.z * gridDim.z * blockDim.x) * kVectorLength; -+ compute_inner_coord_and_offset_(params, coord, src_byte_offset, linear_idx); -+ } -+ -+ // Perform a batch of reduction operations -+ CUTLASS_PRAGMA_UNROLL -+ for (int b = 0; b < kBatchSize; ++b) { -+ if (guards[b]) { -+ auto cvt = convert_source(source_fragment[b]); -+ -+ accumulator = cutlass::reduction::thread::detail::ApplyArrayOperator( -+ reduction_op, -+ accumulator, -+ cvt); -+ } -+ } -+ }; -+ -+ // -+ // Reduction of vectors to scalar -+ // -+ -+ ElementCompute reduced_accumulator = accumulator[0]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 1; i < kVectorLength; ++i) { -+ reduced_accumulator = reduction_op(reduced_accumulator, accumulator[i]); -+ } -+ -+ // -+ // Reduction within CTA across threadIdx.xz => threadIdx{.x = 0, .z = 0} -+ // -+ // This re-arranges data so threadIdx.y is effectively a row index and threadIdx.xz is a column -+ // -+ -+ int thread_count = blockDim.x * blockDim.z; -+ int thread_j = threadIdx.x + blockDim.x * threadIdx.z; -+ int thread_i = threadIdx.y; -+ -+ ElementCompute *frag_ptr = reinterpret_cast(threadblock_workspace) + thread_i * thread_count; -+ -+ frag_ptr[thread_j] = reduced_accumulator; -+ -+ // -+ // Reduce -+ // -+ CUTLASS_PRAGMA_NO_UNROLL -+ while (thread_count > 1) { -+ thread_count /= 2; -+ -+ __syncthreads(); -+ -+ if (thread_j < thread_count) { -+ ElementCompute other = frag_ptr[thread_j + thread_count]; -+ -+ reduced_accumulator = reduction_op(reduced_accumulator, other); -+ -+ frag_ptr[thread_j] = reduced_accumulator; -+ } -+ -+ __syncthreads(); -+ } -+ -+ -+ return reduced_accumulator; -+ } -+ -+public: -+ -+ /// Perform a reduction -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ int coord_c = (blockIdx.x * blockDim.x + threadIdx.x) * kVectorLength; -+ -+ char const * src_byte_ptr = reinterpret_cast(params.source); -+ char * dst_byte_ptr = nullptr; -+ -+ // If performing a reduction across CTAs, redirect output to device workspace -+ if (gridDim.z == 1) { -+ dst_byte_ptr = reinterpret_cast(params.destination); -+ } -+ else { -+ dst_byte_ptr = reinterpret_cast(params.device_workspace); -+ } -+ -+ uint64_t idx_linear = blockIdx.y * blockDim.y + threadIdx.y; -+ -+ // Use modulo division to compute location -+ Coord outer_coord; -+ int64_t dst_byte_offset; -+ int64_t src_byte_offset; -+ -+ compute_outer_coord_and_offset_( -+ params, -+ outer_coord, -+ dst_byte_offset, -+ src_byte_offset, -+ idx_linear); -+ -+ if (gridDim.z == 1) { -+ -+ /// Complete the reduction with no workspace -+ while (idx_linear < params.outer_count) { -+ -+ ElementCompute result = reduce_indices_( -+ params, -+ shared_storage.workspace.data(), -+ src_byte_ptr + src_byte_offset, -+ coord_c); -+ -+ // Store the result after possible final reduction within the CTA -+ if (threadIdx.z == 0 && threadIdx.x == 0) { -+ -+ // Convert to output type and store -+ NumericConverter convert_output; -+ ElementOutput cvt = convert_output(result); -+ -+ *reinterpret_cast(dst_byte_ptr + dst_byte_offset) = cvt; -+ } -+ -+ __syncthreads(); -+ -+ // Update indices and pointers -+ idx_linear += gridDim.y * blockDim.y; -+ -+ compute_outer_coord_and_offset_( -+ params, -+ outer_coord, -+ dst_byte_offset, -+ src_byte_offset, -+ idx_linear); -+ -+ } // while -+ } -+ else { -+ -+ /// Complete the reduction with workspace -+ while (idx_linear < params.outer_count) { -+ -+ ElementCompute result = reduce_indices_( -+ params, -+ shared_storage.workspace.data(), -+ src_byte_ptr + src_byte_offset, -+ coord_c); -+ -+ int64_t byte_offset = -+ blockIdx.z * params.workspace_stride + idx_linear * sizeof_bits::value / 8; -+ -+ // Store the result for final reduction -+ if (threadIdx.z == 0 && threadIdx.x == 0) { -+ *reinterpret_cast(dst_byte_ptr + byte_offset) = result; -+ } -+ -+ __syncthreads(); -+ -+ // Update indices and pointers -+ idx_linear += gridDim.y * blockDim.y; -+ -+ compute_outer_coord_and_offset_( -+ params, -+ outer_coord, -+ dst_byte_offset, -+ src_byte_offset, -+ idx_linear); -+ } // while -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Kernel to perform final reduction -+template < -+ int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) -+ int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2) -+ typename ElementOutput, ///< Data type of output tensor -+ typename ElementSource, ///< Data type of source tensor -+ typename ReductionOp, ///< Reduction operator -+ int VectorLength = 1, ///< Vector length for memory -+ typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation -+ int Threads = 256, ///< Number of participating threads -+ int BatchSize = 4 ///< Number of elements to load per batch -+> -+class TensorReductionAffineContiguousFinal { -+public: -+ -+ static int const kRank = Rank; -+ static int const kReducedRank = ReducedRank; -+ static int const kVectorLength = VectorLength; -+ static int const kInnerRank = kRank - kReducedRank; -+ static int const kThreads = Threads; -+ static int const kBatchSize = BatchSize; -+ -+ /// Shared memory -+ struct SharedStorage { }; -+ -+ /// Parameters structure -+ using Params = TensorReductionAffineContiguousParams< -+ Rank, -+ ReducedRank, -+ ElementOutput, -+ ElementSource, -+ ReductionOp, -+ VectorLength, -+ ElementCompute, -+ Threads, -+ BatchSize -+ >; -+ -+private: -+ -+ /// Computes the coordinate and offset of a given linear index -+ CUTLASS_DEVICE -+ void compute_outer_coord_and_offset_( -+ Params const ¶ms, -+ Coord & coord, -+ int64_t &dst_offset, -+ uint64_t linear_idx) const { -+ -+ // Decompose into coordinate of rank -+ coord = CoordinateDecomposition(linear_idx, params.divmod); -+ -+ // Compute offsets using destination and source strides -+ dst_offset = 0; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kReducedRank; ++i) { -+ dst_offset += params.dst_stride[i] * coord[i]; -+ } -+ } -+ -+ /// Reduces over the reduction indices -+ CUTLASS_DEVICE -+ ElementCompute reduce_indices_( -+ Params const ¶ms, -+ ElementCompute const *device_workspace) { -+ -+ ReductionOp reduction_op(params.reduction_op); -+ char const *src_byte_ptr = reinterpret_cast(device_workspace); -+ -+ // Accumulated output -+ ElementCompute accumulator = params.reduction_identity; -+ -+ for (int iter = 0; iter < params.workspace_count; ++iter) { -+ ElementCompute workspace_item = *reinterpret_cast(src_byte_ptr); -+ -+ accumulator = reduction_op(accumulator, workspace_item); -+ -+ src_byte_ptr += params.workspace_stride; -+ } -+ -+ return accumulator; -+ } -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Perform a reduction -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ uint64_t idx_linear = blockIdx.x * blockDim.x + threadIdx.x; -+ -+ char * dst_byte_ptr = reinterpret_cast(params.destination); -+ -+ // Use modulo division to compute location -+ Coord outer_coord; -+ int64_t dst_byte_offset; -+ -+ compute_outer_coord_and_offset_( -+ params, -+ outer_coord, -+ dst_byte_offset, -+ idx_linear); -+ -+ /// Complete the reduction -+ while (idx_linear < params.outer_count) { -+ -+ ElementCompute result = reduce_indices_(params, params.device_workspace + idx_linear); -+ -+ // Convert to output type and store -+ NumericConverter convert_output; -+ -+ *reinterpret_cast(dst_byte_ptr + dst_byte_offset) = convert_output(result); -+ -+ // Update indices and pointers -+ idx_linear += gridDim.x * blockDim.x; -+ -+ compute_outer_coord_and_offset_( -+ params, -+ outer_coord, -+ dst_byte_offset, -+ idx_linear); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace reduction -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_strided.h b/3rdparty/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_strided.h -new file mode 100644 -index 0000000..9d5b045 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_strided.h -@@ -0,0 +1,641 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Kernel performing a reduction over one or more ranks of an affine tensor -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/reduction/thread/reduction_operators.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace reduction { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace kernel { -+ -+/// Parameters structure -+template < -+ int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) -+ int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2) -+ typename ElementOutput, ///< Data type of output tensor -+ typename ElementSource, ///< Data type of source tensor -+ typename ReductionOp, ///< Reduction operator -+ int VectorLength = 1, ///< Vector length for memory -+ typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation -+ int Threads = 256, ///< Number of participating threads -+ int BatchSize = 4 ///< Number of elements to load per batch -+> -+struct TensorReductionAffineStridedParams { -+ -+ static int const kRank = Rank; -+ static int const kReducedRank = ReducedRank; -+ static int const kVectorLength = VectorLength; -+ static int const kInnerRank = kRank - kReducedRank; -+ static int const kThreads = Threads; -+ static int const kBatchSize = BatchSize; -+ -+ Coord extent; /// Extent of source tensor -+ FastDivmodU64 divmod[kRank - 1]; /// FastDivmod by each strided rank -+ int64_t dst_stride[kReducedRank - 1]; /// stride (units of bytes) - I, J -+ int64_t src_stride[kRank - 1]; /// stride (units of bytes) - I, J, K -+ int64_t workspace_stride; /// stride (units of bytes) between workspace -+ int64_t workspace_outer_stride; /// stride (units of bytes) between 'rows' of the workspace -+ int workspace_count; /// number of workspaces -+ -+ uint64_t inner_count; /// Number of elements in reduced index space -+ uint64_t outer_count; /// Number of elements in outer index space -+ -+ ElementOutput * destination; /// Pointer to output tensor of rank kReducedRank -+ ElementSource const * source; /// Poitner to source pointer of rank kRank -+ ReductionOp reduction_op; /// Reduction operator -+ ElementCompute reduction_identity; /// Identity element for reduction operator -+ ElementCompute *device_workspace; /// Pointer to device workspace for inter-CTA reductions -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ TensorReductionAffineStridedParams() { -+ -+ } -+ -+ /// Ctor -+ TensorReductionAffineStridedParams( -+ Coord extent_, ///< Extent of source tensor -+ ElementOutput * dst_ptr_, ///< Output tensor data -+ int64_t dst_stride_[], ///< Stride (units of elements) -+ ElementSource const * src_ptr_, ///< Source tensor data -+ int64_t src_stride_[], ///< Stride (units of elements) -+ ElementCompute *device_workspace_, ///< Pointer to device workspace for inter-CTA reductions -+ int64_t workspace_stride_, ///< Stride between workspaces -+ int workspace_count_, ///< Number of workspaces -+ ReductionOp reduction_op_, ///< Reduction operator -+ ElementCompute reduction_identity_ = ElementCompute() ///< Identity element for reduction operator -+ ): -+ extent(extent_), -+ inner_count(1), -+ outer_count(1), -+ destination(dst_ptr_), -+ source(src_ptr_), -+ device_workspace(device_workspace_), -+ workspace_outer_stride(0), -+ workspace_stride(workspace_stride_), -+ workspace_count(workspace_count_), -+ reduction_op(reduction_op_), -+ reduction_identity(reduction_identity_) { -+ -+ // Initialize divisors for fast div-mod -+ for (int p = 1; p < kRank; ++p) { -+ divmod[p - 1] = FastDivmodU64(uint64_t(extent[p])); -+ } -+ -+ int input_size_bits = sizeof_bits::value; -+ int output_size_bits = sizeof_bits::value; -+ -+ workspace_outer_stride = workspace_stride * workspace_count; -+ -+ // Compute strides in units of bytes -+ for (int p = 0; p < kReducedRank - 1; ++p) { -+ dst_stride[p] = dst_stride_[p] * output_size_bits / 8; -+ } -+ -+ for (int p = 0; p < kRank - 1; ++p) { -+ src_stride[p] = src_stride_[p] * input_size_bits / 8; -+ } -+ -+ // Compute number of elements in strided ranks -+ for (int p = 0; p < kReducedRank - 1; ++p) { -+ outer_count *= uint64_t(extent[p]); -+ } -+ -+ for (int p = 0; p < kInnerRank; ++p) { -+ inner_count *= uint64_t(extent[kReducedRank + p - 1]); -+ } -+ } -+}; -+ -+/// Kernel to reduce a tensor with affine layout over a set of ranks *EXCLUDING* the contiguous -+/// rank. This leads to favorable vectorized memory accesses over the contiguous rank. -+template < -+ int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) -+ int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2) -+ typename ElementOutput, ///< Data type of output tensor -+ typename ElementSource, ///< Data type of source tensor -+ typename ReductionOp, ///< Reduction operator -+ int VectorLength = 1, ///< Vector length for memory -+ typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation -+ int Threads = 256, ///< Number of participating threads -+ int BatchSize = 4 ///< Number of elements to load per batch -+> -+class TensorReductionAffineStrided { -+public: -+ -+ static int const kRank = Rank; -+ static int const kReducedRank = ReducedRank; -+ static int const kVectorLength = VectorLength; -+ static int const kInnerRank = kRank - kReducedRank; -+ static int const kThreads = Threads; -+ static int const kBatchSize = BatchSize; -+ using ComputeFragment = Array; -+ using SourceFragment = AlignedArray; -+ using OutputFragment = AlignedArray; -+ -+ /// Shared memory allocation used for reduction within the CTA -+ struct SharedStorage { -+ Array workspace; -+ }; -+ -+ /// Parameters structure -+ using Params = TensorReductionAffineStridedParams< -+ Rank, -+ ReducedRank, -+ ElementOutput, -+ ElementSource, -+ ReductionOp, -+ VectorLength, -+ ElementCompute, -+ Threads, -+ BatchSize -+ >; -+ -+private: -+ -+ /// Computes the coordinate and offset of a given linear index -+ CUTLASS_DEVICE -+ void compute_inner_coord_and_offset_( -+ Params const ¶ms, -+ Coord & coord, -+ int64_t &src_offset, -+ uint64_t linear_idx) const { -+ -+ // Decompose into coordinate -+ coord = CoordinateDecomposition(linear_idx, ¶ms.divmod[kReducedRank - 1]); -+ -+ // Compute linear offset -+ src_offset = 0; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kInnerRank; ++i) { -+ src_offset += params.src_stride[kReducedRank + i - 1] * coord[i]; -+ } -+ } -+ -+ /// Computes the coordinate and offset of a given linear index -+ CUTLASS_DEVICE -+ void compute_outer_coord_and_offset_( -+ Params const ¶ms, -+ Coord & coord, -+ int64_t &dst_offset, -+ int64_t &src_offset, -+ uint64_t linear_idx) const { -+ -+ // Decompose linear coordinate -+ coord = CoordinateDecomposition(linear_idx, params.divmod); -+ -+ // Compute offset into tensors -+ dst_offset = 0; -+ src_offset = 0; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kReducedRank - 1; ++i) { -+ dst_offset += params.dst_stride[i] * coord[i]; -+ src_offset += params.src_stride[i] * coord[i]; -+ } -+ } -+ -+ /// Reduces over the reduction indices -+ CUTLASS_DEVICE -+ ComputeFragment reduce_indices_( -+ Params const ¶ms, -+ ElementCompute *threadblock_workspace, -+ char const *src_byte_ptr) { -+ -+ NumericArrayConverter convert_source; -+ ReductionOp reduction_op(params.reduction_op); -+ -+ // Accumulated output -+ ComputeFragment identity_frag; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < identity_frag.size(); ++i) { -+ identity_frag[i] = params.reduction_identity; -+ } -+ -+ if (!params.inner_count) { -+ return identity_frag; -+ } -+ -+ ComputeFragment accumulator = identity_frag; -+ -+ // Compute the coordinate of the first access -+ int64_t src_byte_offset = 0; -+ Coord coord; -+ -+ uint64_t linear_idx = threadIdx.z + blockIdx.z * blockDim.z; -+ compute_inner_coord_and_offset_(params, coord, src_byte_offset, linear_idx); -+ -+ // Load the first vector -+ SourceFragment source_fragment[kBatchSize]; -+ -+ bool not_done = true; -+ -+ // Iterate over vectors in a linearized reduction index space -+ while (not_done) { -+ -+ bool guards[kBatchSize]; -+ -+ // Issue a batch of loads -+ CUTLASS_PRAGMA_UNROLL -+ for (int b = 0; b < kBatchSize; ++b) { -+ -+ if (linear_idx < params.inner_count) { -+ source_fragment[b] = *reinterpret_cast(src_byte_ptr + src_byte_offset); -+ guards[b] = true; -+ } -+ else { -+ guards[b] = false; -+ not_done = false; -+ } -+ -+ linear_idx += blockDim.z * gridDim.z; -+ compute_inner_coord_and_offset_(params, coord, src_byte_offset, linear_idx); -+ } -+ -+ // Perform a batch of reduction operations -+ CUTLASS_PRAGMA_UNROLL -+ for (int b = 0; b < kBatchSize; ++b) { -+ if (guards[b]) { -+ -+ auto cvt = convert_source(source_fragment[b]); -+ -+ accumulator = cutlass::reduction::thread::detail::ApplyArrayOperator( -+ reduction_op, -+ accumulator, -+ cvt); -+ } -+ } -+ }; -+ -+ // Optional reduction within a CTA -+ if (blockDim.z > 1) { -+ -+ // Linearized thread ID -+ int thread_idx = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z); -+ -+ // all threads store to workspace -+ ComputeFragment *frag_ptr = reinterpret_cast(threadblock_workspace); -+ -+ frag_ptr[thread_idx] = accumulator; -+ -+ __syncthreads(); -+ -+ if (threadIdx.z == 0) { -+ // Load all additional block indices -+ for (int z = 1; z < blockDim.z; ++z) { -+ ComputeFragment frag = frag_ptr[thread_idx + z * blockDim.x * blockDim.y]; -+ -+ accumulator = cutlass::reduction::thread::detail::ApplyArrayOperator( -+ reduction_op, -+ accumulator, -+ frag); -+ } -+ } -+ -+ __syncthreads(); -+ } -+ -+ return accumulator; -+ } -+ -+public: -+ -+ /// Perform a reduction -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ int coord_c = (blockIdx.x * blockDim.x + threadIdx.x) * kVectorLength; -+ -+ char const * src_byte_ptr = reinterpret_cast(params.source + coord_c); -+ char * dst_byte_ptr = nullptr; -+ -+ // If performing a reduction across CTAs, redirect output to device workspace -+ if (gridDim.z == 1) { -+ dst_byte_ptr = reinterpret_cast(params.destination + coord_c); -+ } -+ else { -+ dst_byte_ptr = reinterpret_cast(params.device_workspace + coord_c); -+ } -+ -+ // If the C index is out of bounds, exit -+ if (coord_c >= params.extent[kRank - 1]) { -+ return; -+ } -+ -+ int64_t idx_linear = blockIdx.y * blockDim.y + threadIdx.y; -+ -+ // Use modulo division to compute location -+ Coord outer_coord; -+ int64_t dst_byte_offset; -+ int64_t src_byte_offset; -+ -+ compute_outer_coord_and_offset_( -+ params, -+ outer_coord, -+ dst_byte_offset, -+ src_byte_offset, -+ idx_linear); -+ -+ if (gridDim.z == 1) { -+ -+ /// Complete the reduction with no workspace -+ while (idx_linear < params.outer_count) { -+ -+ ComputeFragment result; -+ -+ result = reduce_indices_( -+ params, -+ shared_storage.workspace.data(), -+ src_byte_ptr + src_byte_offset); -+ -+ // Store the result after possible final reduction within the CTA -+ if (threadIdx.z == 0) { -+ -+ // Convert to output type and store -+ NumericArrayConverter convert_output; -+ auto cvt = convert_output(result); -+ -+ *reinterpret_cast(dst_byte_ptr + dst_byte_offset) = -+ reinterpret_cast(cvt); -+ } -+ -+ // Update indices and pointers -+ idx_linear += gridDim.y * blockDim.y; -+ -+ compute_outer_coord_and_offset_( -+ params, -+ outer_coord, -+ dst_byte_offset, -+ src_byte_offset, -+ idx_linear); -+ -+ } // while -+ } -+ else { -+ -+ /// Complete the reduction with a device workspace -+ while (idx_linear < params.outer_count) { -+ -+ ComputeFragment result; -+ -+ result = reduce_indices_( -+ params, -+ shared_storage.workspace.data(), -+ src_byte_ptr + src_byte_offset); -+ -+ // Store the result after possible final reduction within the CTA -+ if (threadIdx.z == 0) { -+ -+ int64_t byte_offset = -+ blockIdx.z * params.workspace_stride + idx_linear * params.workspace_outer_stride; -+ -+ // No conversion - store in compute type -+ *reinterpret_cast(dst_byte_ptr + byte_offset) = -+ reinterpret_cast(result); -+ } -+ -+ // Update indices and pointers -+ idx_linear += gridDim.y * blockDim.y; -+ -+ compute_outer_coord_and_offset_( -+ params, -+ outer_coord, -+ dst_byte_offset, -+ src_byte_offset, -+ idx_linear); -+ -+ } // while (outer index) -+ } // if () -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Kernel to perform final reduction -+template < -+ int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) -+ int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2) -+ typename ElementOutput, ///< Data type of output tensor -+ typename ElementSource, ///< Data type of source tensor -+ typename ReductionOp, ///< Reduction operator -+ int VectorLength = 1, ///< Vector length for memory -+ typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation -+ int Threads = 256, ///< Number of participating threads -+ int BatchSize = 4 ///< Number of elements to load per batch -+> -+class TensorReductionAffineStridedFinal { -+public: -+ -+ static int const kRank = Rank; -+ static int const kReducedRank = ReducedRank; -+ static int const kVectorLength = VectorLength; -+ static int const kInnerRank = kRank - kReducedRank; -+ static int const kThreads = Threads; -+ static int const kBatchSize = BatchSize; -+ using ComputeFragment = Array; -+ using SourceFragment = AlignedArray; -+ using OutputFragment = AlignedArray; -+ -+ /// Shared memory -+ struct SharedStorage { }; -+ -+ /// Parameters structure -+ using Params = TensorReductionAffineStridedParams< -+ Rank, -+ ReducedRank, -+ ElementOutput, -+ ElementSource, -+ ReductionOp, -+ VectorLength, -+ ElementCompute, -+ Threads, -+ BatchSize -+ >; -+ -+private: -+ -+ /// Computes the coordinate and offset of a given linear index -+ CUTLASS_DEVICE -+ void compute_outer_coord_and_offset_( -+ Params const ¶ms, -+ Coord & coord, -+ int64_t &dst_offset, -+ uint64_t linear_idx) const { -+ -+ // Decompose linear index -+ coord = CoordinateDecomposition(linear_idx, params.divmod); -+ -+ // Compute tensor offset -+ dst_offset = 0; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kReducedRank - 1; ++i) { -+ dst_offset += params.dst_stride[i] * coord[i]; -+ } -+ } -+ -+ /// Reduces over the reduction indices -+ CUTLASS_DEVICE -+ ComputeFragment reduce_indices_( -+ Params const ¶ms, -+ char *src_byte_ptr) { -+ -+ ReductionOp reduction_op(params.reduction_op); -+ -+ // Accumulated output -+ ComputeFragment identity_frag; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < identity_frag.size(); ++i) { -+ identity_frag[i] = params.reduction_identity; -+ } -+ -+ ComputeFragment accumulator = identity_frag; -+ ComputeFragment workspace_fragments[kBatchSize]; -+ -+ // Partially unrolled loop -+ for (int idx = 0; idx < params.workspace_count; idx += kBatchSize) { -+ -+ // Issue a batch of loads -+ CUTLASS_PRAGMA_UNROLL -+ for (int b = 0; b < kBatchSize; ++b) { -+ if (idx + b < params.workspace_count) { -+ workspace_fragments[b] = -+ *reinterpret_cast(src_byte_ptr); -+ } -+ else { -+ workspace_fragments[b] = identity_frag; -+ } -+ src_byte_ptr += + params.workspace_stride; -+ } -+ -+ // Perform a reduction -+ CUTLASS_PRAGMA_UNROLL -+ for (int b = 0; b < kBatchSize; ++b) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kVectorLength; ++i) { -+ accumulator[i] = reduction_op(accumulator[i], workspace_fragments[b][i]); -+ } -+ } -+ } -+ -+ return accumulator; -+ } -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Perform a reduction -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ int coord_c = (blockIdx.x * blockDim.x + threadIdx.x) * kVectorLength; -+ -+ char * src_byte_ptr = reinterpret_cast(params.device_workspace + coord_c); -+ char * dst_byte_ptr = reinterpret_cast(params.destination + coord_c); -+ -+ // If the C index is out of bounds, exit -+ if (coord_c >= params.extent[kRank - 1]) { -+ return; -+ } -+ -+ int64_t idx_linear = blockIdx.y * blockDim.y + threadIdx.y; -+ -+ // Use modulo division to compute location -+ Coord outer_coord; -+ int64_t dst_byte_offset; -+ -+ compute_outer_coord_and_offset_( -+ params, -+ outer_coord, -+ dst_byte_offset, -+ idx_linear); -+ -+ /// Complete the reduction -+ while (idx_linear < params.outer_count) { -+ -+ int64_t src_byte_offset = idx_linear * params.workspace_outer_stride; -+ -+ ComputeFragment result = reduce_indices_( -+ params, -+ src_byte_ptr + src_byte_offset); -+ -+ // Convert to output type and store -+ NumericArrayConverter convert_output; -+ auto cvt = convert_output(result); -+ -+ *reinterpret_cast(dst_byte_ptr + dst_byte_offset) = -+ reinterpret_cast(cvt); -+ -+ // Update indices and pointers -+ idx_linear += gridDim.y * blockDim.y; -+ -+ compute_outer_coord_and_offset_( -+ params, -+ outer_coord, -+ dst_byte_offset, -+ idx_linear); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace reduction -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/reduction/thread/reduce.h b/3rdparty/cutlass/include/cutlass/reduction/thread/reduce.h -new file mode 100644 -index 0000000..4f6e180 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/reduction/thread/reduce.h -@@ -0,0 +1,234 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines basic thread level reduction with specializations for Array. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/half.h" -+#include "cutlass/functional.h" -+ -+namespace cutlass { -+namespace reduction { -+namespace thread { -+ -+/// Structure to compute the thread level reduction -+template -+struct Reduce; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial Specialization of Reduce for "plus" (a functional operator) -+template -+struct Reduce< plus, T > { -+ -+ CUTLASS_HOST_DEVICE -+ T operator()(T lhs, T const &rhs) const { -+ plus _op; -+ return _op(lhs, rhs); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization of Reduce for Array -+template -+struct Reduce < plus, Array> { -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &in) const { -+ -+ Array result; -+ Reduce< plus, T > scalar_reduce; -+ result.clear(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (auto i = 0; i < N; ++i) { -+ result[0] = scalar_reduce(result[0], in[i]); -+ } -+ -+ return result; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specializations of Reduce for Array -+template -+struct Reduce < plus, Array > { -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &input) { -+ -+ Array result; -+ -+ // If there is only 1 element - there is nothing to reduce -+ if( N ==1 ){ -+ -+ result[0] = input.front(); -+ -+ } else { -+ -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600) -+ -+ __half result_d; -+ Array const *in_ptr_half = reinterpret_cast const *>(&input); -+ Array const *in_ptr_half2 = reinterpret_cast const *>(&input); -+ __half2 const *x_in_half2 = reinterpret_cast<__half2 const *>(in_ptr_half2); -+ -+ // Set initial result = first half2, in case N==2 -+ __half2 tmp_result = x_in_half2[0]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 1; i < N/2; ++i) { -+ -+ tmp_result = __hadd2(x_in_half2[i], tmp_result); -+ -+ } -+ -+ result_d = __hadd(__low2half(tmp_result), __high2half(tmp_result)); -+ -+ // One final step is needed for odd "N" (to add the (N-1)th element) -+ if( N%2 ){ -+ -+ __half last_element; -+ Array tmp_last; -+ Array *tmp_last_ptr = &tmp_last; -+ tmp_last_ptr[0] = in_ptr_half[N-1]; -+ last_element = reinterpret_cast<__half const &>(tmp_last); -+ -+ result_d = __hadd(result_d, last_element); -+ -+ } -+ -+ Array *result_ptr = &result; -+ *result_ptr = reinterpret_cast &>(result_d); -+ -+ #else -+ -+ Reduce< plus, half_t > scalar_reduce; -+ result.clear(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (auto i = 0; i < N; ++i) { -+ -+ result[0] = scalar_reduce(result[0], input[i]); -+ -+ } -+ -+ #endif -+ } -+ -+ return result; -+ -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specializations of Reduce for AlignedArray -+template -+struct Reduce < plus, AlignedArray > { -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(AlignedArray const &input) { -+ -+ Array result; -+ -+ // If there is only 1 element - there is nothing to reduce -+ if( N ==1 ){ -+ -+ result[0] = input.front(); -+ -+ } else { -+ -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600) -+ -+ __half result_d; -+ AlignedArray const *in_ptr_half = reinterpret_cast const *>(&input); -+ AlignedArray const *in_ptr_half2 = reinterpret_cast const *>(&input); -+ __half2 const *x_in_half2 = reinterpret_cast<__half2 const *>(in_ptr_half2); -+ -+ // Set initial result = first half2, in case N==2 -+ __half2 tmp_result = x_in_half2[0]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 1; i < N/2; ++i) { -+ -+ tmp_result = __hadd2(x_in_half2[i], tmp_result); -+ -+ } -+ -+ result_d = __hadd(__low2half(tmp_result), __high2half(tmp_result)); -+ -+ // One final step is needed for odd "N" (to add the (N-1)th element) -+ if( N%2 ){ -+ -+ __half last_element; -+ AlignedArray tmp_last; -+ AlignedArray *tmp_last_ptr = &tmp_last; -+ tmp_last_ptr[0] = in_ptr_half[N-1]; -+ last_element = reinterpret_cast<__half const &>(tmp_last); -+ -+ result_d = __hadd(result_d, last_element); -+ -+ } -+ -+ Array *result_ptr = &result; -+ *result_ptr = reinterpret_cast &>(result_d); -+ -+ #else -+ -+ Reduce< plus, half_t > scalar_reduce; -+ result.clear(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (auto i = 0; i < N; ++i) { -+ -+ result[0] = scalar_reduce(result[0], input[i]); -+ -+ } -+ -+ #endif -+ } -+ -+ return result; -+ -+ } -+}; -+} -+} -+} -diff --git a/3rdparty/cutlass/include/cutlass/reduction/thread/reduction_operators.h b/3rdparty/cutlass/include/cutlass/reduction/thread/reduction_operators.h -new file mode 100644 -index 0000000..d54bcc0 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/reduction/thread/reduction_operators.h -@@ -0,0 +1,235 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Kernel performing a reduction over densely packed tensors in global memory -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/array.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace reduction { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Mixed-precision reduction -+template < -+ typename ElementAccumulator_, -+ typename Element_, -+ int Count = 1 -+> -+struct ReduceAdd { -+ -+ // -+ // Type definitions -+ // -+ -+ using ElementAccumulator = ElementAccumulator_; -+ using Element = Element_; -+ static int const kCount = Count; -+ -+ using FragmentAccumulator = cutlass::Array; -+ using FragmentElement = cutlass::Array; -+ -+ struct Params { }; -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters object -+ Params params; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ ReduceAdd(Params params_ = Params()): params(params_) { } -+ -+ /// Operator -+ CUTLASS_HOST_DEVICE -+ FragmentAccumulator operator()( -+ FragmentAccumulator accumulator, -+ FragmentElement element) const { -+ -+ plus op; -+ -+ NumericArrayConverter< -+ ElementAccumulator, -+ Element, -+ kCount, -+ PreferredRoundingMode::kRound> converter; -+ -+ return op(accumulator, converter(element)); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+/// Special handling for binary operators -+template -+struct VectorizeArrayOperation { -+ -+ using ValueType = Array; -+ -+ CUTLASS_HOST_DEVICE -+ ValueType operator()( -+ ReductionOp const &reduction_op, -+ ValueType const &lhs, -+ ValueType const &rhs) const { -+ -+ ValueType result; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ result[i] = reduction_op(lhs[i], rhs[i]); -+ } -+ -+ return result; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct ReduceArrayOperation { -+ -+ using ArrayType = Array; -+ -+ CUTLASS_HOST_DEVICE -+ Element operator()( -+ ReductionOp const &reduction_op, -+ ArrayType const &array) const { -+ -+ Element item = reduction_op(array[0], array[1]); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 2; i < N; ++i) { -+ item = reduction_op(item, array[i]); -+ } -+ -+ return item; -+ } -+}; -+ -+template -+struct ReduceArrayOperation, uint1b_t, N> { -+ -+ using ArrayType = Array; -+ -+ CUTLASS_HOST_DEVICE -+ uint1b_t operator()( -+ logical_and const &reduction_op, -+ ArrayType const &array) const { -+ -+ uint8_t const *ptr = reinterpret_cast(&array); -+ bool item = false; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int byte = 0; byte < (N + 7) / 8; ++byte) { -+ uint8_t bits = ptr[byte]; -+ item = (item || !bits); -+ } -+ -+ return uint1b_t(!item); -+ } -+}; -+ -+template -+struct ReduceArrayOperation, uint1b_t, N> { -+ -+ using ArrayType = Array; -+ -+ CUTLASS_HOST_DEVICE -+ uint1b_t operator()( -+ logical_and const &reduction_op, -+ ArrayType const &array) const { -+ -+ uint8_t const *ptr = reinterpret_cast(&array); -+ bool item = true; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int byte = 0; byte < (N + 7) / 8; ++byte) { -+ uint8_t bits = ptr[byte]; -+ item = (item || bits); -+ } -+ -+ return uint1b_t(item); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Helper function to infer template argument types -+template -+CUTLASS_HOST_DEVICE -+Array ApplyArrayOperator( -+ ReductionOp const &reduction_op, -+ Array const &lhs, -+ Array const &rhs) { -+ -+ VectorizeArrayOperation vectorize_op; -+ -+ return vectorize_op(reduction_op, lhs, rhs); -+} -+ -+/// Helper to reduce an array -+template -+Element ReduceArray(ReductionOp const &reduction_op, Array const &array) { -+ ReduceArrayOperation reduce_array_op; -+ -+ return reduce_array_op(reduction_op, array); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace detail -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace reduction -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/reduction/threadblock_swizzle.h b/3rdparty/cutlass/include/cutlass/reduction/threadblock_swizzle.h -new file mode 100644 -index 0000000..5dd6e44 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/reduction/threadblock_swizzle.h -@@ -0,0 +1,67 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+* -+**************************************************************************************************/ -+/*! \file -+\brief Defies functors for mapping blockIdx to partitions of the batched reduction computation. -+*/ -+#pragma once -+#include "cutlass/coord.h" -+ -+namespace cutlass { -+namespace reduction { -+struct DefaultBlockSwizzle { -+ /// Ctor -+ CUTLASS_HOST_DEVICE DefaultBlockSwizzle() {} -+ -+ /// Swizzle the block index. -+ CUTLASS_DEVICE dim3 swizzle() { return blockIdx; } -+ -+ /// -+ CUTLASS_HOST_DEVICE dim3 get_grid_layout(Coord<3> const &problem_size, -+ Coord<3> const &OutputTile) { -+ assert(OutputTile[0] == 1 && OutputTile[1] == 1); -+ assert((problem_size[0] * problem_size[1] * problem_size[2]) % OutputTile[2] == 0); -+ dim3 grid; -+ grid.x = problem_size[0] * problem_size[1] * problem_size[2] -+ / OutputTile[2] ; -+ return grid; -+ } -+ -+ /// -+ CUTLASS_DEVICE Coord<3> get_threadblock_offset(Coord<3> const &SubTile) { -+ assert(SubTile[0] == 1 && SubTile[1] == 1); -+ dim3 block = swizzle(); -+ Coord<3> threadblock_offset = -+ make_Coord(0, 0, block.x * SubTile[2]); -+ return threadblock_offset; -+ } -+}; -+} // namespace reduction -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/relatively_equal.h b/3rdparty/cutlass/include/cutlass/relatively_equal.h -new file mode 100644 -index 0000000..4736e28 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/relatively_equal.h -@@ -0,0 +1,203 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Performs comparison between two elements with support for floating-point comparisons. -+*/ -+ -+#pragma once -+ -+#include "numeric_types.h" -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+CUTLASS_HOST_DEVICE -+bool relatively_equal(T a, T b, T epsilon, T nonzero_floor); -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+// This floating-point comparison function implements the method described in -+// -+// https://floating-point-gui.de/errors/comparison/ -+// -+template -+CUTLASS_HOST_DEVICE -+bool relatively_equal_float(T a, T b, T epsilon, T nonzero_floor) { -+ -+ using std::abs; -+ -+ T abs_A = abs(a); -+ T abs_B = abs(b); -+ T diff = abs(a - b); -+ T zero = T(0); -+ -+ if (a == b) { -+ return true; -+ } -+ else if (a == zero || b == zero || diff < nonzero_floor) { -+ return diff < epsilon * nonzero_floor; -+ } -+ -+ return diff < epsilon * (abs_A + abs_B); -+} -+ -+} // namespace detail -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+CUTLASS_HOST_DEVICE -+bool relatively_equal(uint1b_t a, uint1b_t b, uint1b_t, uint1b_t) { -+ return (a == b); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+bool relatively_equal(int2b_t a, int2b_t b, int2b_t, int2b_t) { -+ return (a == b); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+bool relatively_equal(uint2b_t a, uint2b_t b, uint2b_t, uint2b_t) { -+ return (a == b); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+bool relatively_equal(int4b_t a, int4b_t b, int4b_t, int4b_t) { -+ return (a == b); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+bool relatively_equal(uint4b_t a, uint4b_t b, uint4b_t, uint4b_t) { -+ return (a == b); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+bool relatively_equal(int8_t a, int8_t b, int8_t, int8_t) { -+ return (a == b); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+bool relatively_equal(uint8_t a, uint8_t b, uint8_t, uint8_t) { -+ return (a == b); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+bool relatively_equal(int16_t a, int16_t b, int16_t, int16_t) { -+ return (a == b); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+bool relatively_equal(uint16_t a, uint16_t b, uint16_t, uint16_t) { -+ return (a == b); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+bool relatively_equal(int32_t a, int32_t b, int32_t, int32_t) { -+ return (a == b); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+bool relatively_equal(uint32_t a, uint32_t b, uint32_t, uint32_t) { -+ return (a == b); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+bool relatively_equal(int64_t a, int64_t b, int64_t, int64_t) { -+ return (a == b); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+bool relatively_equal(uint64_t a, uint64_t b, uint64_t, uint64_t) { -+ return (a == b); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+CUTLASS_HOST_DEVICE -+bool relatively_equal(half_t a, half_t b, half_t epsilon, half_t nonzero_floor) { -+ return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+bool relatively_equal( -+ bfloat16_t a, -+ bfloat16_t b, -+ bfloat16_t epsilon, -+ bfloat16_t nonzero_floor) { -+ -+ return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+bool relatively_equal( -+ tfloat32_t a, -+ tfloat32_t b, -+ tfloat32_t epsilon, -+ tfloat32_t nonzero_floor) { -+ -+ return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); -+} -+ -+template <> -+CUTLASS_HOST_DEVICE -+bool relatively_equal(float a, float b, float epsilon, float nonzero_floor) { -+ return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); -+} -+ -+ -+template <> -+CUTLASS_HOST_DEVICE -+bool relatively_equal(double a, double b, double epsilon, double nonzero_floor) { -+ return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/semaphore.h b/3rdparty/cutlass/include/cutlass/semaphore.h -new file mode 100644 -index 0000000..ed8a179 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/semaphore.h -@@ -0,0 +1,122 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Implementation of a CTA-wide semaphore for inter-CTA synchronization. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/array.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/matrix_shape.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// CTA-wide semaphore for inter-CTA synchronization. -+class Semaphore { -+public: -+ -+ int *lock; -+ bool wait_thread; -+ int state; -+ -+public: -+ -+ /// Implements a semaphore to wait for a flag to reach a given value -+ CUTLASS_HOST_DEVICE -+ Semaphore(int *lock_, int thread_id): -+ lock(lock_), -+ wait_thread(thread_id < 0 || thread_id == 0), -+ state(-1) { -+ -+ } -+ -+ /// Permit fetching the synchronization mechanism early -+ CUTLASS_DEVICE -+ void fetch() { -+ if (wait_thread) { -+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 -+ asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); -+ #else -+ asm volatile ("ld.global.cg.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); -+ #endif -+ } -+ } -+ -+ /// Gets the internal state -+ CUTLASS_DEVICE -+ int get_state() const { -+ return state; -+ } -+ -+ /// Waits until the semaphore is equal to the given value -+ CUTLASS_DEVICE -+ void wait(int status = 0) { -+#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) -+ while( __syncthreads_and(state != status) ) { -+ fetch(); -+ } -+ -+ __syncthreads(); -+#endif -+ } -+ -+ /// Updates the lock with the given result -+ CUTLASS_DEVICE -+ void release(int status = 0) { -+#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) -+ __syncthreads(); -+ -+ if (wait_thread) { -+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 -+ asm volatile ("st.global.release.gpu.b32 [%0], %1;\n" : : "l"(lock), "r"(status)); -+ #else -+ asm volatile ("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status)); -+ #endif -+ } -+#endif -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/subbyte_reference.h b/3rdparty/cutlass/include/cutlass/subbyte_reference.h -new file mode 100644 -index 0000000..58c460a ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/subbyte_reference.h -@@ -0,0 +1,637 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Provides a mechanism for packing and unpacking elements smaller than one byte -+*/ -+#pragma once -+ -+#include "cutlass/numeric_types.h" -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// This class provides a mechanism for packing and unpacking elements smaller than one byte. It -+/// assumes these sub-byte elements are packed in a traditional C++ numeric type. -+/// -+/// The intended application is to provide a mechanism to indirectly reference elements in -+/// memory or Array<> objects whose addresses cannot otherwise be taken since they are smaller -+/// than one byte. -+/// -+/// Supports basic pointer arithmetic: -+/// -+/// Example: -+/// -+/// int4b_t *ptr = ...; -+/// -+/// SubbyteReference ref = ptr; -+/// ref += 15; -+/// -+/// int4b_t x = ref; // load an int4b_t -+/// ref = x + 2_s4; // perform arithmetic on int4b_t and then store -+/// -+template < -+ typename Element_, /// CUTLASS numeric element type. -+ typename Storage_ = uint8_t /// Underlying storage type. Must be able to hold an integer -+ /// number of objects of type Element. -+> -+class ConstSubbyteReference { -+public: -+ -+ using Element = Element_; -+ using Storage = Storage_; -+ using StoragePointer = Storage const *; -+ -+ static_assert(sizeof_bits::value <= sizeof_bits::value, -+ "Size of Element must not be greater than Storage."); -+ -+ static_assert(!(sizeof_bits::value % sizeof_bits::value), -+ "Storage must be divisible by Element"); -+ -+private: -+ -+ ///! Number of elements per storage vector -+ int const kElementsPerVector = sizeof_bits::value / sizeof_bits::value; -+ -+ ///! Bit mask -+ Storage const kMask = -+ ((sizeof_bits::value < sizeof_bits::value) ? -+ (Storage(1) << sizeof_bits::value) - Storage(1) : -+ ~Storage(0)); -+ -+private: -+ -+ /// Pointer to array containing element -+ StoragePointer ptr_; -+ -+ /// Offset (in units of elements) from pointer. -+ /// -+ /// Invariant: must always be in range [0, kElementsPerVector) -+ int offset_; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ ConstSubbyteReference(): ptr_(nullptr), offset_(0) { } -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ ConstSubbyteReference( -+ Element const *ptr, /// pointer to memory -+ int64_t offset /// logical offset in units of Element -+ ): -+ ptr_(reinterpret_cast(ptr)), -+ offset_(0) { -+ -+ int64_t offset_in_vectors = offset / kElementsPerVector; -+ int64_t offset_in_elements = offset % kElementsPerVector; -+ -+ ptr_ += offset_in_vectors; -+ offset_ = int(offset_in_elements); -+ } -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ ConstSubbyteReference( -+ Element *ptr = nullptr -+ ): ConstSubbyteReference(ptr, 0) { } -+ -+ /// Gets storage pointer -+ CUTLASS_HOST_DEVICE -+ StoragePointer storage_pointer() const { -+ return ptr_; -+ } -+ -+ /// Gets element offset within storage vector -+ CUTLASS_HOST_DEVICE -+ int element_offset() const { -+ return offset_; -+ } -+ -+ /// Unpacks an element from memory -+ CUTLASS_HOST_DEVICE -+ Element get() const { -+ Storage item = Storage((*ptr_ >> (offset_ * sizeof_bits::value)) & kMask); -+ return reinterpret_cast(item); -+ } -+ -+ /// Unpacks an element from memory -+ CUTLASS_HOST_DEVICE -+ operator Element() const { -+ return get(); -+ } -+ -+ /// Adds an offset in units of elements to the reference -+ CUTLASS_HOST_DEVICE -+ ConstSubbyteReference &operator+=(int offset) { -+ -+ offset += offset_; -+ -+ int offset_in_vectors = offset / kElementsPerVector; -+ int offset_in_elements = offset % kElementsPerVector; -+ -+ ptr_ += offset_in_vectors; -+ offset_ = offset_in_elements; -+ -+ return *this; -+ } -+ -+ /// Adds an offset in units of elements to the reference -+ CUTLASS_HOST_DEVICE -+ ConstSubbyteReference &operator+=(long long offset) { -+ -+ offset += offset_; -+ -+ long long offset_in_vectors = offset / kElementsPerVector; -+ int offset_in_elements = int(offset % kElementsPerVector); -+ -+ ptr_ += offset_in_vectors; -+ offset_ = offset_in_elements; -+ -+ return *this; -+ } -+ -+ /// Adds an offset in units of elements to the reference -+ CUTLASS_HOST_DEVICE -+ ConstSubbyteReference &operator-=(int offset) { -+ -+ int offset_in_vectors = offset / kElementsPerVector; -+ int offset_in_elements = offset % kElementsPerVector; -+ -+ ptr_ -= offset_in_vectors; -+ offset_ -= offset_in_elements; -+ -+ if (offset_ < 0) { -+ offset_ += kElementsPerVector; -+ --ptr_; -+ } -+ -+ return *this; -+ } -+ -+ /// Adds an offset in units of elements to the reference -+ CUTLASS_HOST_DEVICE -+ ConstSubbyteReference &operator-=(long long offset) { -+ -+ long long offset_in_vectors = offset / kElementsPerVector; -+ int offset_in_elements = int(offset % kElementsPerVector); -+ -+ ptr_ -= offset_in_vectors; -+ offset_ -= offset_in_elements; -+ -+ if (offset_ < 0) { -+ offset_ += kElementsPerVector; -+ --ptr_; -+ } -+ -+ return *this; -+ } -+ -+ /// Returns a reference to an element with a given offset from the current reference -+ CUTLASS_HOST_DEVICE -+ ConstSubbyteReference operator+(int offset) const { -+ -+ ConstSubbyteReference ref(ptr_, offset_); -+ ref += offset; -+ -+ return ref; -+ } -+ -+ /// Returns a reference to an element with a given offset from the current reference -+ CUTLASS_HOST_DEVICE -+ ConstSubbyteReference operator+(long long offset) const { -+ -+ ConstSubbyteReference ref(ptr_, offset_); -+ ref += offset; -+ -+ return ref; -+ } -+ -+ /// Returns a reference to an element with a given offset from the current reference -+ CUTLASS_HOST_DEVICE -+ ConstSubbyteReference operator-(int offset) const { -+ -+ ConstSubbyteReference ref(ptr_, offset_); -+ ref -= offset; -+ -+ return ref; -+ } -+ -+ /// Returns a reference to an element with a given offset from the current reference -+ CUTLASS_HOST_DEVICE -+ ConstSubbyteReference operator-=(long long offset) const { -+ -+ ConstSubbyteReference ref(ptr_, offset_); -+ ref -= offset; -+ -+ return ref; -+ } -+ -+ /// Computes the difference in elements between references -+ CUTLASS_HOST_DEVICE -+ ptrdiff_t operator-(ConstSubbyteReference ref) const { -+ return (ptr_ - ref.ptr_) * kElementsPerVector + (offset_ - ref.offset_); -+ } -+ -+ /// Explicit cast to int -+ CUTLASS_HOST_DEVICE -+ explicit operator int() const { -+ return int(get()); -+ } -+ -+ /// Explicit cast to signed 64-bit integer -+ CUTLASS_HOST_DEVICE -+ explicit operator int64_t() const { -+ return int64_t(get()); -+ } -+ -+ /// Explicit cast to unsigned 64-bit integer -+ CUTLASS_HOST_DEVICE -+ explicit operator uint64_t() const { -+ return uint64_t(get()); -+ } -+ -+ /// Explicit cast to float -+ CUTLASS_HOST_DEVICE -+ explicit operator float() const { -+ return float(get()); -+ } -+ -+ /// Explicit cast to double -+ CUTLASS_HOST_DEVICE -+ explicit operator double() const { -+ return double(get()); -+ } -+}; -+ -+template < -+ typename Element_, /// CUTLASS numeric element type. -+ typename Storage_ = /// Underlying storage type. Must be able to hold an integer -+ /// number of objects of type Element. -+ -+#if defined(__CUDA_ARCH__) /// Default size depends on width of atomicCas() overloads. -+ #if (__CUDA_ARCH__ >= 700) /// -+ uint16_t -+ #else -+ uint32_t -+ #endif -+#else -+ uint8_t -+#endif -+> -+class SubbyteReference { -+public: -+ -+ using Element = Element_; -+ using Storage = Storage_; -+ using StoragePointer = Storage *; -+ -+ static_assert(sizeof_bits::value <= sizeof_bits::value, -+ "Size of Element must not be greater than Storage."); -+ -+ static_assert(!(sizeof_bits::value % sizeof_bits::value), -+ "Storage must be divisible by Element"); -+ -+private: -+ -+ ///! Number of elements per storage vector -+ int const kElementsPerVector = sizeof_bits::value / sizeof_bits::value; -+ -+ ///! Bit mask -+ Storage const kMask = -+ ((sizeof_bits::value < sizeof_bits::value) ? -+ (Storage(1) << sizeof_bits::value) - Storage(1) : -+ ~Storage(0)); -+ -+private: -+ -+ /// Pointer to array containing element -+ StoragePointer ptr_; -+ -+ /// Offset (in units of elements) from pointer. -+ /// -+ /// Invariant: must always be in range [0, kElementsPerVector) -+ int offset_; -+ -+public: -+ -+ CUTLASS_HOST_DEVICE -+ SubbyteReference(): ptr_(nullptr), offset_(0) { } -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ SubbyteReference( -+ Element *ptr, /// pointer to memory -+ int64_t offset /// logical offset in units of Element -+ ): -+ ptr_(reinterpret_cast(ptr)), -+ offset_(0) { -+ -+ int64_t offset_in_vectors = offset / kElementsPerVector; -+ int64_t offset_in_elements = offset % kElementsPerVector; -+ -+ ptr_ += offset_in_vectors; -+ offset_ = int(offset_in_elements); -+ } -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ SubbyteReference( -+ Element *ptr = nullptr -+ ): SubbyteReference(ptr, 0) { } -+ -+ /// Gets storage pointer -+ CUTLASS_HOST_DEVICE -+ StoragePointer storage_pointer() const { -+ return ptr_; -+ } -+ -+ /// Gets storage pointer -+ CUTLASS_HOST_DEVICE -+ Element * operator&() const { -+ return reinterpret_cast(ptr_); -+ } -+ -+ /// Gets element offset within storage vector -+ CUTLASS_HOST_DEVICE -+ int element_offset() const { -+ return offset_; -+ } -+ -+ /// Unpacks an element from memory -+ CUTLASS_HOST_DEVICE -+ Element get() const { -+ Storage item = Storage((*ptr_ >> (offset_ * sizeof_bits::value)) & kMask); -+ return reinterpret_cast(item); -+ } -+ -+ /// Stores an element to memory -+ CUTLASS_HOST_DEVICE -+ SubbyteReference & set(Element const &x) { -+ -+ Storage item = (reinterpret_cast(x) & kMask); -+ Storage kUpdateMask = Storage(~(kMask << (offset_ * cutlass::sizeof_bits::value))); -+ Storage new_bits = Storage(item << (offset_ * cutlass::sizeof_bits::value)); -+ -+#if defined(__CUDA_ARCH__) -+ -+ // -+ // Homebrew read-modify-write -+ // -+ Storage original; -+ Storage updated; -+ -+ do { -+ -+ original = (*ptr_); -+ -+ updated = Storage((original & kUpdateMask) | new_bits); -+ -+ original = atomicCAS(ptr_, original, updated); -+ -+ } while (updated != original); -+ -+#else -+ -+ Storage original = (*ptr_); -+ Storage updated = Storage((original & kUpdateMask) | new_bits); -+ *ptr_ = updated; -+ -+#endif -+ -+ return *this; -+ } -+ -+ //// -+ -+ /// Unpacks an element from memory -+ CUTLASS_HOST_DEVICE -+ operator Element() const { -+ return get(); -+ } -+ -+ /// Stores an element to memory -+ CUTLASS_HOST_DEVICE -+ SubbyteReference &operator=(Element const & x) { -+ return set(x); -+ } -+ -+ /// Stores an element to memory -+ CUTLASS_HOST_DEVICE -+ SubbyteReference &operator=(SubbyteReference const & x) { -+ return set(x.get()); -+ } -+ -+ /// Stores an element to memory -+ CUTLASS_HOST_DEVICE -+ SubbyteReference &operator=( -+ ConstSubbyteReference const &x) { -+ return set(x.get()); -+ } -+ -+ /// Adds an offset in units of elements to the reference -+ CUTLASS_HOST_DEVICE -+ SubbyteReference &operator+=(int offset) { -+ -+ offset += offset_; -+ -+ int offset_in_vectors = offset / kElementsPerVector; -+ int offset_in_elements = offset % kElementsPerVector; -+ -+ ptr_ += offset_in_vectors; -+ offset_ = offset_in_elements; -+ -+ return *this; -+ } -+ -+ /// Adds an offset in units of elements to the reference -+ CUTLASS_HOST_DEVICE -+ SubbyteReference &operator+=(long long offset) { -+ -+ offset += offset_; -+ -+ long long offset_in_vectors = offset / kElementsPerVector; -+ int offset_in_elements = int(offset % kElementsPerVector); -+ -+ ptr_ += offset_in_vectors; -+ offset_ = offset_in_elements; -+ -+ return *this; -+ } -+ -+ /// Adds an offset in units of elements to the reference -+ CUTLASS_HOST_DEVICE -+ SubbyteReference &operator-=(int offset) { -+ -+ int offset_in_vectors = offset / kElementsPerVector; -+ int offset_in_elements = offset % kElementsPerVector; -+ -+ ptr_ -= offset_in_vectors; -+ offset_ -= offset_in_elements; -+ -+ if (offset_ < 0) { -+ offset_ += kElementsPerVector; -+ --ptr_; -+ } -+ -+ return *this; -+ } -+ -+ /// Adds an offset in units of elements to the reference -+ CUTLASS_HOST_DEVICE -+ SubbyteReference &operator-=(long long offset) { -+ -+ long long offset_in_vectors = offset / kElementsPerVector; -+ int offset_in_elements = int(offset % kElementsPerVector); -+ -+ ptr_ -= offset_in_vectors; -+ offset_ -= offset_in_elements; -+ -+ if (offset_ < 0) { -+ offset_ += kElementsPerVector; -+ --ptr_; -+ } -+ -+ return *this; -+ } -+ -+ /// Returns a reference to an element with a given offset from the current reference -+ CUTLASS_HOST_DEVICE -+ SubbyteReference operator+(int offset) const { -+ -+ SubbyteReference ref(ptr_, offset_); -+ ref += offset; -+ -+ return ref; -+ } -+ -+ /// Returns a reference to an element with a given offset from the current reference -+ CUTLASS_HOST_DEVICE -+ SubbyteReference operator+(long long offset) const { -+ -+ SubbyteReference ref(ptr_, offset_); -+ ref += offset; -+ -+ return ref; -+ } -+ -+ /// Returns a reference to an element with a given offset from the current reference -+ CUTLASS_HOST_DEVICE -+ SubbyteReference operator-(int offset) const { -+ -+ SubbyteReference ref(ptr_, offset_); -+ ref -= offset; -+ -+ return ref; -+ } -+ -+ /// Returns a reference to an element with a given offset from the current reference -+ CUTLASS_HOST_DEVICE -+ SubbyteReference operator-=(long long offset) const { -+ -+ SubbyteReference ref(ptr_, offset_); -+ ref -= offset; -+ -+ return ref; -+ } -+ -+ /// Computes the difference in elements between references -+ CUTLASS_HOST_DEVICE -+ ptrdiff_t operator-(SubbyteReference ref) const { -+ return (ptr_ - ref.ptr_) * kElementsPerVector + (offset_ - ref.offset_); -+ } -+ -+ /// Explicit cast to int -+ CUTLASS_HOST_DEVICE -+ explicit operator int() const { -+ return int(get()); -+ } -+ -+ /// Explicit cast to signed 64-bit integer -+ CUTLASS_HOST_DEVICE -+ explicit operator int64_t() const { -+ return int64_t(get()); -+ } -+ -+ /// Explicit cast to unsigned 64-bit integer -+ CUTLASS_HOST_DEVICE -+ explicit operator uint64_t() const { -+ return uint64_t(get()); -+ } -+ -+ /// Explicit cast to float -+ CUTLASS_HOST_DEVICE -+ explicit operator float() const { -+ return float(get()); -+ } -+ -+ /// Explicit cast to double -+ CUTLASS_HOST_DEVICE -+ explicit operator double() const { -+ return double(get()); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template ::value < 8)> -+struct ReferenceFactory; -+ -+template -+struct ReferenceFactory { -+ CUTLASS_HOST_DEVICE -+ static Element &get(Element *ptr, int64_t offset) { -+ return ptr[offset]; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static Element const &get(Element const *ptr, int64_t offset) { -+ return ptr[offset]; -+ } -+}; -+ -+template -+struct ReferenceFactory { -+ CUTLASS_HOST_DEVICE -+ static SubbyteReference get(Element *ptr, int64_t offset) { -+ return SubbyteReference(ptr, offset); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static ConstSubbyteReference get(Element const *ptr, -+ int64_t offset) { -+ return ConstSubbyteReference(ptr, offset); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/tensor_coord.h b/3rdparty/cutlass/include/cutlass/tensor_coord.h -new file mode 100644 -index 0000000..d3a7b32 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/tensor_coord.h -@@ -0,0 +1,326 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines a canonical coordinate for rank=4 tensors offering named indices. -+*/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/coord.h" -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a canonical 4D coordinate used by tensor operations. -+struct Tensor4DCoord : public Coord<4> { -+ -+ /// Base class -+ using Base = Coord<4>; -+ -+ /// Index type -+ using Index = typename Base::Index; -+ -+ /// LongIndex type -+ using LongIndex = typename Base::LongIndex; -+ -+ /// Batch dimension -+ static int const kN = 0; -+ -+ /// Height dimension -+ static int const kH = 1; -+ -+ /// Width dimension -+ static int const kW = 2; -+ -+ /// Channels dimension -+ static int const kC = 3; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Tensor4DCoord() { } -+ -+ /// Constructs from Coord<4> -+ CUTLASS_HOST_DEVICE -+ Tensor4DCoord(Coord<4> const &coord): Base(coord) { } -+ -+ /// Helper to construct from N, H, W, and C. -+ CUTLASS_HOST_DEVICE -+ Tensor4DCoord(Index n, Index h, Index w, Index c): Base(make_Coord(n, h, w, c)) { } -+ -+ /// Helper to construct from N, H, W, and C, which are LongIndex type -+ CUTLASS_HOST_DEVICE -+ Tensor4DCoord(LongIndex n, LongIndex h, LongIndex w, LongIndex c) -+ : Base(make_Coord(Index(n), Index(h), Index(w), Index(c))) { } -+ -+ /// Returns the batch of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index const & n() const { return this->at(kN); } -+ -+ /// Returns the batch of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index & n() { return this->at(kN); } -+ -+ /// Returns the row of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index const & h() const { return this->at(kH); } -+ -+ /// Returns the row of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index & h() { return this->at(kH); } -+ -+ /// Returns the column of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index const & w() const { return this->at(kW); } -+ -+ /// Returns the column of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index & w() { return this->at(kW); } -+ -+ /// Returns the channel of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index const & c() const { return this->at(kC); } -+ -+ /// Returns the channel of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index & c() { return this->at(kC); } -+ -+ // -+ // Coord operators -+ // -+ -+ /// Element-wise addition -+ CUTLASS_HOST_DEVICE -+ Tensor4DCoord operator+(Base const& b) const { -+ return Tensor4DCoord(Base::operator+(b)); -+ } -+ -+ /// Element-wise subtraction -+ CUTLASS_HOST_DEVICE -+ Tensor4DCoord operator-(Base const& b) const { -+ return Tensor4DCoord(Base::operator-(b)); -+ } -+ -+ /// Element-wise multiplication -+ CUTLASS_HOST_DEVICE -+ Tensor4DCoord operator*(Base const& b) const { -+ return Tensor4DCoord(Base::operator*(b)); -+ } -+ -+ /// Element-wise division -+ CUTLASS_HOST_DEVICE -+ Tensor4DCoord operator/(Base const& b) const { -+ return Tensor4DCoord(Base::operator/(b)); -+ } -+ -+ /// In-place addition -+ CUTLASS_HOST_DEVICE -+ Tensor4DCoord& operator+=(Base const& b) { -+ Base::operator+=(b); -+ return *this; -+ } -+ -+ /// In-place subtraction -+ CUTLASS_HOST_DEVICE -+ Tensor4DCoord& operator-=(Base const& b) { -+ Base::operator-=(b); -+ return *this; -+ } -+ -+ /// In-place multiplication -+ CUTLASS_HOST_DEVICE -+ Tensor4DCoord& operator*=(Base const& b) { -+ Base::operator*=(b); -+ return *this; -+ } -+ -+ /// In-place division -+ CUTLASS_HOST_DEVICE -+ Tensor4DCoord& operator/=(Base const& b) { -+ Base::operator/=(b); -+ return *this; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a canonical 5D coordinate used by tensor operations. -+struct Tensor5DCoord : public Coord<5> { -+ -+ /// Base class -+ using Base = Coord<5>; -+ -+ /// Index type -+ using Index = typename Base::Index; -+ -+ /// LongIndex type -+ using LongIndex = typename Base::LongIndex; -+ -+ /// Batch dimension -+ static int const kN = 0; -+ -+ /// Depth dimension -+ static int const kD = 1; -+ -+ /// Height dimension -+ static int const kH = 2; -+ -+ /// Width dimension -+ static int const kW = 3; -+ -+ /// Channels dimension -+ static int const kC = 4; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Tensor5DCoord() { } -+ -+ /// Constructs from Coord<5> -+ CUTLASS_HOST_DEVICE -+ Tensor5DCoord(Coord<5> const &coord): Base(coord) { } -+ -+ /// Helper to construct from N, D, H, W, and C. -+ CUTLASS_HOST_DEVICE -+ Tensor5DCoord(Index n, Index d, Index h, Index w, Index c): Base(make_Coord(n, d, h, w, c)) { } -+ -+ /// Helper to construct from N, D, H, W, and C, which are LongIndex type -+ CUTLASS_HOST_DEVICE -+ Tensor5DCoord(LongIndex n, LongIndex d, LongIndex h, LongIndex w, LongIndex c) -+ : Base(make_Coord(Index(n), Index(d), Index(h), Index(w), Index(c))) { } -+ -+ /// Returns the batch of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index const & n() const { return this->at(kN); } -+ -+ /// Returns the batch of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index & n() { return this->at(kN); } -+ -+ /// Returns the batch of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index const & d() const { return this->at(kD); } -+ -+ /// Returns the batch of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index & d() { return this->at(kD); } -+ -+ /// Returns the row of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index const & h() const { return this->at(kH); } -+ -+ /// Returns the row of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index & h() { return this->at(kH); } -+ -+ /// Returns the column of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index const & w() const { return this->at(kW); } -+ -+ /// Returns the column of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index & w() { return this->at(kW); } -+ -+ /// Returns the channel of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index const & c() const { return this->at(kC); } -+ -+ /// Returns the channel of the coordinate -+ CUTLASS_HOST_DEVICE -+ Index & c() { return this->at(kC); } -+ -+ // -+ // Coord operators -+ // -+ -+ /// Element-wise addition -+ CUTLASS_HOST_DEVICE -+ Tensor5DCoord operator+(Base const& b) const { -+ return Tensor5DCoord(Base::operator+(b)); -+ } -+ -+ /// Element-wise subtraction -+ CUTLASS_HOST_DEVICE -+ Tensor5DCoord operator-(Base const& b) const { -+ return Tensor5DCoord(Base::operator-(b)); -+ } -+ -+ /// Element-wise multiplication -+ CUTLASS_HOST_DEVICE -+ Tensor5DCoord operator*(Base const& b) const { -+ return Tensor5DCoord(Base::operator*(b)); -+ } -+ -+ /// Element-wise division -+ CUTLASS_HOST_DEVICE -+ Tensor5DCoord operator/(Base const& b) const { -+ return Tensor5DCoord(Base::operator/(b)); -+ } -+ -+ /// In-place addition -+ CUTLASS_HOST_DEVICE -+ Tensor5DCoord& operator+=(Base const& b) { -+ Base::operator+=(b); -+ return *this; -+ } -+ -+ /// In-place subtraction -+ CUTLASS_HOST_DEVICE -+ Tensor5DCoord& operator-=(Base const& b) { -+ Base::operator-=(b); -+ return *this; -+ } -+ -+ /// In-place multiplication -+ CUTLASS_HOST_DEVICE -+ Tensor5DCoord& operator*=(Base const& b) { -+ Base::operator*=(b); -+ return *this; -+ } -+ -+ /// In-place division -+ CUTLASS_HOST_DEVICE -+ Tensor5DCoord& operator/=(Base const& b) { -+ Base::operator/=(b); -+ return *this; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/tensor_ref.h b/3rdparty/cutlass/include/cutlass/tensor_ref.h -new file mode 100644 -index 0000000..ce2505e ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/tensor_ref.h -@@ -0,0 +1,418 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines a structure containing strides, bounds, and a pointer to tensor data. -+*/ -+#pragma once -+ -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/coord.h" -+#include "cutlass/platform/platform.h" -+#include "cutlass/subbyte_reference.h" -+ -+namespace cutlass { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Default layout function from coordinates in a tensor's index space into the n-D array held -+/// in memory. -+/// -+/// All layout functions must define at least the members shown in IdentityTensorLayout<>. -+template -+class IdentityTensorLayout { -+public: -+ /// Logical rank of tensor -+ static int const kRank = Rank; -+ -+ /// Rank of stride vector -+ static int const kStrideRank = Rank; -+ -+ /// Index type used for coordinates -+ using Index = int32_t; -+ -+ /// Long index type used for offsets -+ using LongIndex = int64_t; -+ -+ /// Logical coordinate -+ using TensorCoord = Coord; -+ -+ /// Stride vector -+ using Stride = Coord; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Stride data member -+ Stride stride_; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ IdentityTensorLayout(Stride const &stride = Stride()): stride_(stride) { } -+ -+ /// Returns the offset of a coordinate in linear memory -+ CUTLASS_HOST_DEVICE -+ LongIndex operator()(Coord const &coord) const { -+ return coord.dot(stride_); -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return stride_; -+ } -+ -+ /// Returns the stride of the layout -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return stride_; -+ } -+ -+ /// Compute the number of contiguous elements needed to store a tensor with the given size -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity(TensorCoord const &size) const { -+ int idx = stride_.max_dim_index(); -+ return stride_[idx] * size[idx]; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/* \brief TensorRef is a template for objects pointing to the start of tensors of arbitrary rank -+ and layout within memory. A TensorRef combines a pointer and a Layout concept -+ -+ Examples: -+ -+ (These examples use helpers for matrix layouts defined in cutlass/layout/matrix.h) -+ -+ 1. Column-major matrix may be represented as a rank=2 tensor: -+ -+ TensorRef A(ptr_A, ldm); -+ -+ 2. Row-major matrix may be represented as a rank=2 tensor: -+ -+ TensorRef B(ptr_A, ldm); -+ -+ 3. An interleaved matrix may be represented as a rank=2 tensor: -+ -+ TensorRef > C; -+ -+ 4. A helper exists to define a TensorRef for a contiguous matrix whose layout -+ is not known at compile time. -+ -+ int ldm; // leading dimension -+ layout::Matrix kind; // Could be layout::Matrix::kRowMajor or layout::Matrix::kColumnMajor -+ -+ -+ TensorRef E(ptr_E, {ldm, kind}); -+ -+*/ -+template < -+ /// Data type of element stored within tensor (concept: NumericType) -+ typename Element_, -+ /// Defines a mapping from logical coordinate to linear memory (concept: Layout) -+ typename Layout_ -+> -+class TensorRef { -+ public: -+ /// Data type of individual access -+ using Element = Element_; -+ -+ /// Mapping function from logical coordinate to linear memory -+ using Layout = Layout_; -+ -+ /// Reference type to an element -+ using Reference = typename platform::conditional< -+ sizeof_bits::value >= 8, -+ Element &, -+ SubbyteReference -+ >::type; -+ -+ /// Logical rank of tensor index space -+ static int const kRank = Layout::kRank; -+ -+ /// Index type -+ using Index = typename Layout::Index; -+ -+ /// Long index used for pointer offsets -+ using LongIndex = typename Layout::LongIndex; -+ -+ /// Coordinate in logical tensor space -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ /// Layout's stride vector -+ using Stride = typename Layout::Stride; -+ -+ /// TensorRef to constant data -+ using ConstTensorRef = TensorRef< -+ typename platform::remove_const::type const, -+ Layout>; -+ -+ /// TensorRef to non-constant data -+ using NonConstTensorRef = TensorRef< -+ typename platform::remove_const::type, -+ Layout>; -+ -+ /// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a -+ /// scalar, but degenerate cases such as these are difficult to accommodate without -+ /// extensive C++ metaprogramming or support for zero-length arrays. -+ static_assert(kRank > 0, "Cannot define a zero-rank TensorRef"); -+ -+ private: -+ -+ /// Pointer -+ Element* ptr_; -+ -+ /// Layout object maps logical coordinates to linear offsets -+ Layout layout_; -+ -+ public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a TensorRef with a pointer and layout object. -+ CUTLASS_HOST_DEVICE -+ TensorRef(): ptr_(nullptr) { -+ -+ } -+ -+ /// Constructs a TensorRef with a pointer and layout object. -+ CUTLASS_HOST_DEVICE -+ TensorRef( -+ Element *ptr, ///< pointer to start of tensor -+ Layout const &layout ///< layout object containing stride and mapping function -+ ): -+ ptr_(ptr), layout_(layout) { -+ -+ } -+ -+ /// Converting constructor from TensorRef to non-constant data. -+ template -+ CUTLASS_HOST_DEVICE -+ TensorRef( -+ NonConstTensorRef const &ref, ///< TensorRef to non-const data -+ ///SFINAE trick to avoid creating a copy-constructor when Element_ is already non-const -+ _Magic magic = (typename platform::enable_if< ! platform::is_same >::value, _Magic>::type)0 -+ ): -+ ptr_(ref.data()), layout_(ref.layout()) { } -+ -+ /// Returns a reference to constant-valued tensor. -+ CUTLASS_HOST_DEVICE -+ ConstTensorRef const_ref() const { -+ return ConstTensorRef(ptr_, layout_); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ NonConstTensorRef non_const_ref() const { -+ return NonConstTensorRef(const_cast::type *>(ptr_), layout_); -+ } -+ -+ /// Updates only the pointer -+ CUTLASS_HOST_DEVICE -+ void reset(Element* ptr = nullptr) { -+ ptr_ = ptr; -+ } -+ -+ /// Updates the pointer and layout object -+ CUTLASS_HOST_DEVICE -+ void reset(Element* ptr, Layout const &layout) { -+ ptr_ = ptr; -+ layout_ = layout; -+ } -+ -+ /// Returns true if the TensorRef is non-null -+ CUTLASS_HOST_DEVICE -+ bool good() const { -+ return ptr_ != nullptr; -+ } -+ -+ /// Returns the pointer to referenced data -+ CUTLASS_HOST_DEVICE -+ Element * data() const { return ptr_; } -+ -+ /// Returns a reference to the element at a given linear index -+ CUTLASS_HOST_DEVICE -+ Reference data(LongIndex idx) const { -+ return ReferenceFactory::type, -+ (sizeof_bits::value < 8)>::get(ptr_, idx); -+ } -+ -+ /// Returns the layout object -+ CUTLASS_HOST_DEVICE -+ Layout & layout() { -+ return layout_; -+ } -+ -+ /// Returns the layout object -+ CUTLASS_HOST_DEVICE -+ Layout layout() const { -+ return layout_; -+ } -+ -+ /// Returns the layout object's stride vector -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return layout_.stride(); -+ } -+ -+ /// Returns the layout object's stride vector -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return layout_.stride(); -+ } -+ -+ /// Returns the layout object's stride in a given physical dimension -+ CUTLASS_HOST_DEVICE -+ typename Layout::Stride::Index stride(int dim) const { -+ return layout_.stride().at(dim); -+ } -+ -+ /// Returns the layout object's stride in a given physical dimension -+ CUTLASS_HOST_DEVICE -+ typename Layout::Stride::Index & stride(int dim) { -+ return layout_.stride().at(dim); -+ } -+ -+ /// Computes the offset of an index from the origin of the tensor -+ CUTLASS_HOST_DEVICE -+ LongIndex offset(TensorCoord const& coord) const { -+ return layout_(coord); -+ } -+ -+ /// Returns a reference to the element at a given Coord -+ CUTLASS_HOST_DEVICE -+ Reference at(TensorCoord const& coord) const { -+ return data(offset(coord)); -+ } -+ -+ /// Returns a reference to the element at a given Coord -+ CUTLASS_HOST_DEVICE -+ Reference operator[](TensorCoord const& coord) const { -+ return data(offset(coord)); -+ } -+ -+ /// Adds an offset to each pointer -+ CUTLASS_HOST_DEVICE -+ TensorRef & add_pointer_offset(LongIndex offset_) { -+ ptr_ += offset_; -+ return *this; -+ } -+ -+ /// Adds an offset to each pointer -+ CUTLASS_HOST_DEVICE -+ TensorRef & add_coord_offset(TensorCoord const &coord) { -+ add_pointer_offset(offset(coord)); -+ return *this; -+ } -+ -+ /// Returns a TensorRef offset by a given amount -+ CUTLASS_HOST_DEVICE -+ TensorRef operator+(TensorCoord const& b) const { -+ TensorRef result(*this); -+ result.add_coord_offset(b); -+ return result; -+ } -+ -+ /// Returns a TensorRef offset by a given amount -+ CUTLASS_HOST_DEVICE -+ TensorRef & operator+=(TensorCoord const& b) { -+ add_coord_offset(b); -+ return *this; -+ } -+ -+ /// Returns a TensorRef offset by a given amount -+ CUTLASS_HOST_DEVICE -+ TensorRef operator-(TensorCoord const& b) const { -+ TensorRef result(*this); -+ result.add_pointer_offset(-offset(b)); -+ return result; -+ } -+ -+ /// Returns a TensorRef offset by a given amount -+ CUTLASS_HOST_DEVICE -+ TensorRef & operator-=(TensorCoord const& b) { -+ add_pointer_offset(-offset(b)); -+ return *this; -+ } -+}; -+ -+/// Constructs a TensorRef, deducing types from arguments. -+template < -+ typename Element, -+ typename Layout -+> -+CUTLASS_HOST_DEVICE -+TensorRef make_TensorRef(Element *ptr, Layout const &layout) { -+ return TensorRef(ptr, layout); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Partial specializations to handle degenerate and sub-byte cases. -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Element, -+ typename Layout -+> -+CUTLASS_HOST_DEVICE -+bool TensorRef_aligned(TensorRef const &ref, int alignment) { -+ -+ int const kStrideRank = Layout::kStrideRank; -+ -+ if (reinterpret_cast(ref.data()) % alignment) { -+ return false; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kStrideRank; ++i) { -+ if (ref.stride(i) % alignment) { -+ return false; -+ } -+ } -+ -+ return true; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/tensor_ref_planar_complex.h b/3rdparty/cutlass/include/cutlass/tensor_ref_planar_complex.h -new file mode 100644 -index 0000000..a0131fd ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/tensor_ref_planar_complex.h -@@ -0,0 +1,374 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines a structure containing strides, bounds, and a pointer to tensor data. -+*/ -+#pragma once -+ -+#include -+#include "cutlass/cutlass.h" -+#include "cutlass/complex.h" -+#include "cutlass/tensor_ref.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct PlanarComplexReference { -+ -+ // -+ // Type definitions -+ // -+ -+ using Element = Element_; -+ using ComplexElement = complex; -+ -+ // -+ // Data members -+ // -+ -+ Element *real; -+ Element *imag; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ PlanarComplexReference( -+ Element *real_ = nullptr, -+ Element *imag_ = nullptr -+ ): -+ real(real_), imag(imag_) { } -+ -+ /// Loads the complex element -+ CUTLASS_HOST_DEVICE -+ operator complex() const { -+ return complex{*real, *imag}; -+ } -+ -+ /// Stores a complex element to the location pointed to by the reference -+ CUTLASS_HOST_DEVICE -+ PlanarComplexReference &operator=(complex const &rhs) { -+ *real = rhs.real(); -+ *imag = rhs.imag(); -+ return *this; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/* \brief TensorRef is a template for objects pointing to the start of tensors of arbitrary rank -+ and layout within memory. A TensorRef combines a pointer and a Layout concept -+ -+*/ -+template < -+ /// Data type of element stored within tensor (concept: NumericType) -+ typename Element_, -+ /// Defines a mapping from logical coordinate to linear memory (concept: Layout) -+ typename Layout_ -+> -+class TensorRefPlanarComplex { -+ public: -+ /// Data type of individual access -+ using Element = Element_; -+ -+ /// Complex element type -+ using ComplexElement = complex; -+ -+ /// Mapping function from logical coordinate to linear memory -+ using Layout = Layout_; -+ -+ static_assert(sizeof_bits::value >= 8, -+ "Planar complex not suitable for subbyte elements at this time"); -+ -+ /// Reference type to an element -+ using Reference = PlanarComplexReference; -+ -+ /// Logical rank of tensor index space -+ static int const kRank = Layout::kRank; -+ -+ /// Index type -+ using Index = typename Layout::Index; -+ -+ /// Long index used for pointer offsets -+ using LongIndex = typename Layout::LongIndex; -+ -+ /// Coordinate in logical tensor space -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ /// Layout's stride vector -+ using Stride = typename Layout::Stride; -+ -+ /// TensorRef to constant data -+ using ConstTensorRef = TensorRefPlanarComplex< -+ typename platform::remove_const::type const, -+ Layout>; -+ -+ /// TensorRef to non-constant data -+ using NonConstTensorRef = TensorRefPlanarComplex< -+ typename platform::remove_const::type, -+ Layout>; -+ -+ /// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a -+ /// scalar, but degenerate cases such as these are difficult to accommodate without -+ /// extensive C++ metaprogramming or support for zero-length arrays. -+ static_assert(kRank > 0, "Cannot define a zero-rank TensorRef"); -+ -+ private: -+ -+ /// Pointer -+ Element* ptr_; -+ -+ /// Layout object maps logical coordinates to linear offsets -+ Layout layout_; -+ -+ /// Offset to imaginary part -+ LongIndex imaginary_stride_; -+ -+ public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a TensorRef with a pointer and layout object. -+ CUTLASS_HOST_DEVICE -+ TensorRefPlanarComplex( -+ Element *ptr = nullptr, ///< pointer to start of tensor -+ Layout const &layout = Layout(), ///< layout object containing stride and mapping function -+ LongIndex imaginary_stride = 0 -+ ): -+ ptr_(ptr), layout_(layout), imaginary_stride_(imaginary_stride) { -+ -+ } -+ -+ /// Converting constructor from TensorRef to non-constant data. -+ CUTLASS_HOST_DEVICE -+ TensorRefPlanarComplex( -+ NonConstTensorRef const &ref ///< TensorRef to non-const data -+ ): -+ ptr_(ref.data()), layout_(ref.layout()), imaginary_stride_(ref.imaginary_stride_) { } -+ -+ /// Returns a reference to constant-valued tensor. -+ CUTLASS_HOST_DEVICE -+ ConstTensorRef const_ref() const { -+ return ConstTensorRef(ptr_, layout_, imaginary_stride_); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ NonConstTensorRef non_const_ref() const { -+ return NonConstTensorRef( -+ const_cast::type *>(ptr_), -+ layout_, -+ imaginary_stride_); -+ } -+ -+ /// Updates only the pointer -+ CUTLASS_HOST_DEVICE -+ void reset(Element* ptr = nullptr, LongIndex imaginary_stride = 0) { -+ ptr_ = ptr; -+ imaginary_stride_ = imaginary_stride; -+ } -+ -+ /// Updates the pointer and layout object -+ CUTLASS_HOST_DEVICE -+ void reset(Element* ptr, Layout const &layout, LongIndex imaginary_stride) { -+ ptr_ = ptr; -+ layout_ = layout; -+ imaginary_stride_ = imaginary_stride; -+ } -+ -+ /// Returns true if the TensorRef is non-null -+ CUTLASS_HOST_DEVICE -+ bool good() const { -+ return ptr_ != nullptr; -+ } -+ -+ /// Returns the pointer to referenced data -+ CUTLASS_HOST_DEVICE -+ Element * data() const { return ptr_; } -+ -+ /// Returns the pointer to referenced data -+ CUTLASS_HOST_DEVICE -+ Element * imaginary_data() const { return ptr_ + imaginary_stride_; } -+ -+ /// Returns a reference to the element at a given linear index -+ CUTLASS_HOST_DEVICE -+ Reference data(LongIndex idx) const { -+ return Reference(ptr_ + idx, ptr_ + idx + imaginary_stride_); -+ } -+ -+ /// Returns the layout object -+ CUTLASS_HOST_DEVICE -+ Layout & layout() { -+ return layout_; -+ } -+ -+ /// Returns the layout object -+ CUTLASS_HOST_DEVICE -+ Layout layout() const { -+ return layout_; -+ } -+ -+ /// Gets the stride to an imaginary element -+ LongIndex imaginary_stride() const { -+ return imaginary_stride_; -+ } -+ -+ /// Gets the stride to an imaginary element -+ LongIndex &imaginary_stride() { -+ return imaginary_stride_; -+ } -+ -+ /// Returns the layout object's stride vector -+ CUTLASS_HOST_DEVICE -+ Stride stride() const { -+ return layout_.stride(); -+ } -+ -+ /// Returns the layout object's stride vector -+ CUTLASS_HOST_DEVICE -+ Stride & stride() { -+ return layout_.stride(); -+ } -+ -+ /// Returns the layout object's stride in a given physical dimension -+ CUTLASS_HOST_DEVICE -+ Index stride(int dim) const { -+ return layout_.stride().at(dim); -+ } -+ -+ /// Returns the layout object's stride in a given physical dimension -+ CUTLASS_HOST_DEVICE -+ Index & stride(int dim) { -+ return layout_.stride().at(dim); -+ } -+ -+ /// Computes the offset of an index from the origin of the tensor -+ CUTLASS_HOST_DEVICE -+ LongIndex offset(TensorCoord const& coord) const { -+ return layout_(coord); -+ } -+ -+ /// Returns a reference to the element at a given Coord -+ CUTLASS_HOST_DEVICE -+ Reference at(TensorCoord const& coord) const { -+ return data(offset(coord)); -+ } -+ -+ /// Returns a reference to the element at a given Coord -+ CUTLASS_HOST_DEVICE -+ Reference operator[](TensorCoord const& coord) const { -+ return data(offset(coord)); -+ } -+ -+ /// Adds an offset to each pointer -+ CUTLASS_HOST_DEVICE -+ TensorRefPlanarComplex & add_pointer_offset(LongIndex offset_) { -+ ptr_ += offset_; -+ return *this; -+ } -+ -+ /// Adds an offset to each pointer -+ CUTLASS_HOST_DEVICE -+ TensorRefPlanarComplex & add_coord_offset(TensorCoord const &coord) { -+ add_pointer_offset(offset(coord)); -+ return *this; -+ } -+ -+ /// Returns a TensorRef offset by a given amount -+ CUTLASS_HOST_DEVICE -+ TensorRefPlanarComplex operator+(TensorCoord const& b) const { -+ TensorRefPlanarComplex result(*this); -+ result.add_coord_offset(b); -+ return result; -+ } -+ -+ /// Returns a TensorRef offset by a given amount -+ CUTLASS_HOST_DEVICE -+ TensorRefPlanarComplex & operator+=(TensorCoord const& b) { -+ add_coord_offset(b); -+ return *this; -+ } -+ -+ /// Returns a TensorRef offset by a given amount -+ CUTLASS_HOST_DEVICE -+ TensorRefPlanarComplex operator-(TensorCoord const& b) const { -+ TensorRefPlanarComplex result(*this); -+ result.add_pointer_offset(-offset(b)); -+ return result; -+ } -+ -+ /// Returns a TensorRef offset by a given amount -+ CUTLASS_HOST_DEVICE -+ TensorRefPlanarComplex & operator-=(TensorCoord const& b) { -+ add_pointer_offset(-offset(b)); -+ return *this; -+ } -+ -+ /// TensorRef to real-valued tensor -+ CUTLASS_HOST_DEVICE -+ cutlass::TensorRef ref_real() const { -+ return cutlass::TensorRef(data(), layout()); -+ } -+ -+ /// TensorRef to real-valued tensor -+ CUTLASS_HOST_DEVICE -+ cutlass::TensorRef ref_imag() const { -+ return cutlass::TensorRef(imaginary_data(), layout()); -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Constructs a TensorRef, deducing types from arguments. -+template < -+ typename Element, -+ typename Layout -+> -+CUTLASS_HOST_DEVICE -+TensorRefPlanarComplex make_TensorRefPlanarComplex( -+ Element *ptr, -+ Layout const &layout, -+ int64_t imaginary_stride) { -+ -+ return TensorRefPlanarComplex(ptr, layout, imaginary_stride); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/tensor_view.h b/3rdparty/cutlass/include/cutlass/tensor_view.h -new file mode 100644 -index 0000000..9a4d238 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/tensor_view.h -@@ -0,0 +1,297 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines a structure containing strides and a pointer to tensor data. -+ -+ TensorView is derived from TensorRef and contributes bounds to the tensor's index space. Thus, -+ it is a complete mathematical object and may be used in tensor algorithms. It is decoupled from -+ data storage and is therefore lightweight and may be embedded in larger tensor objects or -+ memory structures. -+ -+ See cutlass/tensor_ref.h for more details about the mapping of the logical tensor index space to -+ linear memory. -+*/ -+ -+#pragma once -+ -+#if !defined(__CUDACC_RTC__) -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/tensor_ref.h" -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Data type of element stored within tensor -+ typename Element_, -+ /// Maps a Coord in the logical tensor index space to the internal n-D array -+ typename Layout_ -+> -+class TensorView : public TensorRef { -+ public: -+ -+ /// Base tensor reference -+ using Base = cutlass::TensorRef; -+ -+ /// Mapping function from logical coordinate to internal n-D array -+ using Layout = Layout_; -+ -+ /// TensorRef pointing to constant memory -+ using ConstTensorRef = typename Base::ConstTensorRef; -+ -+ /// Underlying TensorRef type -+ using TensorRef = Base; -+ -+ /// Data type of individual access -+ using Element = Element_; -+ -+ /// Reference type to an element -+ using Reference = Element &; -+ -+ /// Logical rank of tensor index space -+ static int const kRank = Layout::kRank; -+ -+ /// Index type -+ using Index = typename Layout::Index; -+ -+ /// Long index used for pointer offsets -+ using LongIndex = typename Layout::LongIndex; -+ -+ /// Coordinate in logical tensor space -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ /// Coordinate in storage n-D array -+ using Stride = typename Layout::Stride; -+ -+ /// TensorView pointing to constant memory -+ using ConstTensorView = TensorView< -+ typename platform::remove_const::type const, -+ Layout>; -+ -+ /// TensorView pointing to non-constant memory -+ using NonConstTensorView = TensorView< -+ typename platform::remove_const::type, -+ Layout>; -+ -+ /// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a -+ /// scalar, but degenerate cases such as these are difficult to accommodate without -+ /// extensive C++ metaprogramming or support for zero-length arrays. -+ static_assert(kRank > 0, "Cannot define a zero-rank TensorRef"); -+ -+ private: -+ -+ /// View extent -+ TensorCoord extent_; -+ -+ public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a TensorView object -+ CUTLASS_HOST_DEVICE -+ TensorView() { } -+ -+ /// Constructs a TensorView object -+ CUTLASS_HOST_DEVICE -+ TensorView( -+ Element *ptr, ///< pointer to start of tensor -+ Layout const &layout, ///< layout object containing stride and mapping function -+ TensorCoord const &extent ///< size of the view in logical coordinates -+ ): -+ Base(ptr, layout), extent_(extent) { -+ -+ } -+ -+ /// Constructs a TensorView object -+ CUTLASS_HOST_DEVICE -+ TensorView( -+ TensorRef const &ref, ///< pointer and layout object referencing a tensor -+ TensorCoord const &extent ///< logical size of tensor -+ ): -+ Base(ref), extent_(extent) { -+ -+ } -+ -+ /// Converting constructor from TensorRef to non-constant data. -+ CUTLASS_HOST_DEVICE -+ TensorView( -+ NonConstTensorView const &view ///< TensorView to non-const data -+ ): -+ Base(view), extent_(view.extent_) { } -+ -+ /// Updates the pointer and layout object -+ CUTLASS_HOST_DEVICE -+ void reset(Element* ptr, Layout const &layout, TensorCoord const &extent) { -+ Base::reset(ptr, layout); -+ this->resize(extent); -+ } -+ -+ /// Updates the pointer -+ CUTLASS_HOST_DEVICE -+ void reset(Element* ptr) { -+ Base::reset(ptr); -+ } -+ -+ /// Changes the size of the view without affecting pointer or layout -+ CUTLASS_HOST_DEVICE -+ void resize(TensorCoord const &extent) { -+ this->extent_ = extent; -+ } -+ -+ /// Returns the extent of the view (the size along each logical dimension). -+ CUTLASS_HOST_DEVICE -+ TensorCoord const& extent() const { return extent_; } -+ -+ /// Returns the extent along a particular logical dimension. -+ CUTLASS_HOST_DEVICE -+ Index extent(int dim) const { return extent_.at(dim); } -+ -+ /// Returns the number of logical elements -+ CUTLASS_HOST_DEVICE -+ LongIndex size() const { -+ return extent_.product(); -+ } -+ -+ /// Determines whether a location is within a tensor -+ CUTLASS_HOST_DEVICE -+ bool contains(TensorCoord const& coord) const { -+ CUTLASS_PRAGMA_UNROLL -+ for (int dim = 0; dim < kRank; ++dim) { -+ if (!(coord[dim] >= 0 && coord[dim] < extent(dim))) { -+ return false; -+ } -+ } -+ return true; -+ } -+ -+ /// Returns a TensorRef pointing to the first element of the tensor. -+ CUTLASS_HOST_DEVICE -+ TensorRef ref() const { -+ return TensorRef(this->data(), this->layout()); -+ } -+ -+ /// Returns a TensorRef pointing to the first element of the tensor. -+ CUTLASS_HOST_DEVICE -+ ConstTensorRef const_ref() const { -+ return ConstTensorRef(this->data(), this->layout()); -+ } -+ -+ /// Returns a TensorView to const data -+ CUTLASS_HOST_DEVICE -+ ConstTensorView const_view() const { -+ return ConstTensorView(const_ref(), extent_); -+ } -+ -+ /// Returns a Tensor_view given location and size quantities -+ CUTLASS_HOST_DEVICE -+ TensorView subview( -+ TensorCoord extent, ///< extent of the resulting view -+ TensorCoord const& location = TensorCoord() ///< resulting view's origin within the old view -+ ) const { -+ -+ TensorView result(this->ref(), extent.clamp(extent_ - location)); -+ result.add_coord_offset(location); -+ return result; -+ } -+ -+ /// Returns the number of scalar elements needed to store tensor. -+ CUTLASS_HOST_DEVICE -+ size_t capacity() const { -+ return Base::layout().capacity(extent_); -+ } -+ -+ /// Returns a TensorView offset by a given amount -+ CUTLASS_HOST_DEVICE -+ TensorView operator+( -+ TensorCoord const& b ///< offset in the logical coordinate space of the tensor -+ ) const { -+ -+ TensorView result(*this); -+ result.add_pointer_offset(this->offset(b)); -+ return result; -+ } -+ -+ /// Returns a TensorRef offset by a given amount -+ CUTLASS_HOST_DEVICE -+ TensorView& operator+=( -+ TensorCoord const& b ///< offset in the logical coordinate space of the tensor -+ ) { -+ -+ this->add_pointer_offset(this->offset(b)); -+ return *this; -+ } -+ -+ /// Returns a TensorRef offset by a given amount -+ CUTLASS_HOST_DEVICE -+ TensorView operator-( -+ TensorCoord const& b ///< offset in the logical coordinate space of the tensor -+ ) const { -+ -+ TensorRef result(*this); -+ result.add_pointer_offset(-this->offset(b)); -+ return result; -+ } -+ -+ /// Returns a TensorRef offset by a given amount -+ CUTLASS_HOST_DEVICE -+ TensorView& operator-=( -+ TensorCoord const& b ///< offset in the logical coordinate space of the tensor -+ ) { -+ -+ this->add_pointer_offset(-this->offset(b)); -+ return *this; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Constructs a TensorRef, deducing types from arguments. -+template < -+ typename Element, -+ typename Layout -+> -+CUTLASS_HOST_DEVICE TensorView make_TensorView( -+ Element *ptr, -+ Layout const &layout, -+ typename Layout::TensorCoord const &extent) { -+ -+ return TensorView(ptr, layout, extent); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/tensor_view_planar_complex.h b/3rdparty/cutlass/include/cutlass/tensor_view_planar_complex.h -new file mode 100644 -index 0000000..6a66c6a ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/tensor_view_planar_complex.h -@@ -0,0 +1,301 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines a structure containing strides and a pointer to tensor data. -+ -+ TensorView is derived from TensorRef and contributes bounds to the tensor's index space. Thus, -+ it is a complete mathematical object and may be used in tensor algorithms. It is decoupled from -+ data storage and is therefore lightweight and may be embedded in larger tensor objects or -+ memory structures. -+ -+ See cutlass/tensor_ref.h for more details about the mapping of the logical tensor index space to -+ linear memory. -+*/ -+ -+#pragma once -+ -+#if !defined(__CUDACC_RTC__) -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/tensor_ref_planar_complex.h" -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Data type of element stored within tensor -+ typename Element_, -+ /// Maps a Coord in the logical tensor index space to the internal n-D array -+ typename Layout_ -+> -+class TensorViewPlanarComplex : public TensorRefPlanarComplex { -+ public: -+ -+ /// Base tensor reference -+ using Base = cutlass::TensorRefPlanarComplex; -+ -+ /// Mapping function from logical coordinate to internal n-D array -+ using Layout = Layout_; -+ -+ /// TensorRef pointing to constant memory -+ using ConstTensorRef = typename Base::ConstTensorRef; -+ -+ /// Underlying TensorRef type -+ using TensorRef = Base; -+ -+ /// Data type of individual access -+ using Element = Element_; -+ -+ /// Reference type to an element -+ using Reference = Element &; -+ -+ /// Logical rank of tensor index space -+ static int const kRank = Layout::kRank; -+ -+ /// Index type -+ using Index = typename Layout::Index; -+ -+ /// Long index used for pointer offsets -+ using LongIndex = typename Layout::LongIndex; -+ -+ /// Coordinate in logical tensor space -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ /// Coordinate in storage n-D array -+ using Stride = typename Layout::Stride; -+ -+ /// TensorView pointing to constant memory -+ using ConstTensorView = TensorViewPlanarComplex< -+ typename platform::remove_const::type const, -+ Layout>; -+ -+ /// TensorView pointing to non-constant memory -+ using NonConstTensorView = TensorViewPlanarComplex< -+ typename platform::remove_const::type, -+ Layout>; -+ -+ /// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a -+ /// scalar, but degenerate cases such as these are difficult to accommodate without -+ /// extensive C++ metaprogramming or support for zero-length arrays. -+ static_assert(kRank > 0, "Cannot define a zero-rank TensorRef"); -+ -+ private: -+ -+ /// View extent -+ TensorCoord extent_; -+ -+ public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a TensorView object -+ CUTLASS_HOST_DEVICE -+ TensorViewPlanarComplex(TensorCoord const &extent = TensorCoord()): extent_(extent) { -+ -+ } -+ -+ /// Constructs a TensorView object -+ CUTLASS_HOST_DEVICE -+ TensorViewPlanarComplex( -+ Element *ptr, ///< pointer to start of tensor -+ Layout const &layout, ///< layout object containing stride and mapping function -+ LongIndex imaginary_stride, ///< stride between real and imaginary part -+ TensorCoord const &extent ///< size of the view in logical coordinates -+ ): -+ Base(ptr, layout, imaginary_stride), extent_(extent) { -+ -+ } -+ -+ /// Constructs a TensorView object -+ CUTLASS_HOST_DEVICE -+ TensorViewPlanarComplex( -+ TensorRef const &ref, ///< pointer and layout object referencing a tensor -+ TensorCoord const &extent ///< logical size of tensor -+ ): -+ Base(ref), extent_(extent) { -+ -+ } -+ -+ /// Converting constructor from TensorRef to non-constant data. -+ CUTLASS_HOST_DEVICE -+ TensorViewPlanarComplex( -+ NonConstTensorView const &view ///< TensorView to non-const data -+ ): -+ Base(view), extent_(view.extent_) { } -+ -+ /// Updates the pointer and layout object -+ CUTLASS_HOST_DEVICE -+ void reset(Element* ptr, Layout const &layout, LongIndex imaginary_stride, TensorCoord size) { -+ Base::reset(ptr, layout, imaginary_stride); -+ this->resize(extent_); -+ } -+ -+ /// Changes the size of the view without affecting pointer or layout -+ CUTLASS_HOST_DEVICE -+ void resize(TensorCoord extent) { -+ this->extent_ = extent; -+ } -+ -+ /// Returns the extent of the view (the size along each logical dimension). -+ CUTLASS_HOST_DEVICE -+ TensorCoord const& extent() const { return extent_; } -+ -+ /// Returns the extent along a particular logical dimension. -+ CUTLASS_HOST_DEVICE -+ Index extent(int dim) const { return extent_.at(dim); } -+ -+ /// Determines whether a location is within a tensor -+ CUTLASS_HOST_DEVICE -+ bool contains(TensorCoord const& coord) const { -+ CUTLASS_PRAGMA_UNROLL -+ for (int dim = 0; dim < kRank; ++dim) { -+ if (!(coord[dim] >= 0 && coord[dim] < extent(dim))) { -+ return false; -+ } -+ } -+ return true; -+ } -+ -+ /// Returns a TensorRef pointing to the first element of the tensor. -+ CUTLASS_HOST_DEVICE -+ Base ref() const { -+ return Base(this->data(), this->layout(), this->imaginary_stride()); -+ } -+ -+ /// Returns a TensorRef pointing to the first element of the tensor. -+ CUTLASS_HOST_DEVICE -+ ConstTensorRef const_ref() const { -+ return ConstTensorRef(this->data(), this->layout()); -+ } -+ -+ /// Returns a TensorView to const data -+ CUTLASS_HOST_DEVICE -+ ConstTensorView const_view() const { -+ return ConstTensorView(const_ref(), extent_); -+ } -+ -+ /// Returns a Tensor_view given location and size quantities -+ CUTLASS_HOST_DEVICE -+ TensorViewPlanarComplex subview( -+ TensorCoord extent, ///< extent of the resulting view -+ TensorCoord const& location = TensorCoord() ///< resulting view's origin within the old view -+ ) const { -+ -+ TensorViewPlanarComplex result(this->ref(), extent.clamp(extent_ - location)); -+ result.add_coord_offset(location); -+ return result; -+ } -+ -+ /// Returns the number of scalar elements needed to store tensor. -+ CUTLASS_HOST_DEVICE -+ size_t capacity() const { -+ return Base::layout().capacity(extent_); -+ } -+ -+ /// Returns a TensorView offset by a given amount -+ CUTLASS_HOST_DEVICE -+ TensorViewPlanarComplex operator+( -+ TensorCoord const& b ///< offset in the logical coordinate space of the tensor -+ ) const { -+ -+ TensorViewPlanarComplex result(*this); -+ result.add_pointer_offset(this->offset(b)); -+ return result; -+ } -+ -+ /// Returns a TensorRef offset by a given amount -+ CUTLASS_HOST_DEVICE -+ TensorViewPlanarComplex& operator+=( -+ TensorCoord const& b ///< offset in the logical coordinate space of the tensor -+ ) { -+ -+ this->add_pointer_offset(this->offset(b)); -+ return *this; -+ } -+ -+ /// Returns a TensorRef offset by a given amount -+ CUTLASS_HOST_DEVICE -+ TensorViewPlanarComplex operator-( -+ TensorCoord const& b ///< offset in the logical coordinate space of the tensor -+ ) const { -+ -+ TensorRef result(*this); -+ result.add_pointer_offset(-this->offset(b)); -+ return result; -+ } -+ -+ /// Returns a TensorRef offset by a given amount -+ CUTLASS_HOST_DEVICE -+ TensorViewPlanarComplex& operator-=( -+ TensorCoord const& b ///< offset in the logical coordinate space of the tensor -+ ) { -+ -+ this->add_pointer_offset(-this->offset(b)); -+ return *this; -+ } -+ -+ /// TensorRef to real-valued tensor -+ CUTLASS_HOST_DEVICE -+ cutlass::TensorView view_real() const { -+ return cutlass::TensorView(this->data(), this->layout(), extent_); -+ } -+ -+ /// TensorRef to real-valued tensor -+ CUTLASS_HOST_DEVICE -+ cutlass::TensorView view_imag() const { -+ return cutlass::TensorView(this->imaginary_data(), this->layout(), extent_); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Constructs a TensorRef, deducing types from arguments. -+template < -+ typename Element, -+ typename Layout -+> -+CUTLASS_HOST_DEVICE TensorViewPlanarComplex make_TensorViewPlanarComplex( -+ Element *ptr, -+ Layout const &layout, -+ typename Layout::LongIndex imaginary_stride, -+ typename Layout::TensorCoord const &extent) { -+ -+ return TensorViewPlanarComplex(ptr, layout, imaginary_stride, extent); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/tfloat32.h b/3rdparty/cutlass/include/cutlass/tfloat32.h -new file mode 100644 -index 0000000..76e2bf9 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/tfloat32.h -@@ -0,0 +1,477 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief Defines a proxy class for storing Tensor Float 32 data type. -+*/ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include "cutlass/floating_point_nvrtc.h" -+#else -+#include -+#include -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+ -+namespace cutlass { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tensor Float 32 data type -+struct alignas(4) tfloat32_t { -+ -+ // -+ // Data members -+ // -+ -+ /// Storage type -+ uint32_t storage; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs from an unsigned int -+ CUTLASS_HOST_DEVICE -+ static tfloat32_t bitcast(uint32_t x) { -+ tfloat32_t h; -+ h.storage = x; -+ return h; -+ } -+ -+ /// Emulated rounding is fast in device code -+ CUTLASS_HOST_DEVICE -+ static tfloat32_t round_half_ulp_truncate(float const &s) { -+ uint32_t x = reinterpret_cast(s); -+ -+ #if defined(__CUDA_ARCH__) -+ if (::isfinite(s)) { -+ x += 0x1000u; -+ } -+ #else -+ if (std::isfinite(s)) { -+ x += 0x1000u; -+ } -+ #endif -+ -+ return tfloat32_t::bitcast(x); -+ } -+ -+ /// Default constructor -+ tfloat32_t() = default; -+ -+ /// Floating-point conversion - round toward nearest even -+ CUTLASS_HOST_DEVICE -+// explicit tfloat32_t(float x): storage(round_half_ulp_truncate(x).storage) { } -+ tfloat32_t(float x): storage(round_half_ulp_truncate(x).storage) { } -+ -+ /// Floating-point conversion - round toward nearest even -+ CUTLASS_HOST_DEVICE -+// explicit tfloat32_t(double x): tfloat32_t(float(x)) { -+ tfloat32_t(double x): tfloat32_t(float(x)) { -+ } -+ -+ /// Integer conversion - round toward zero -+ CUTLASS_HOST_DEVICE -+// explicit tfloat32_t(int x) { -+ tfloat32_t(int x) { -+ float flt = static_cast(x); -+ #if defined(__CUDA_ARCH__) -+ storage = reinterpret_cast(flt); -+ #else -+ std::memcpy(&storage, &flt, sizeof(storage)); -+ #endif -+ } -+ -+ /// Converts to float -+ CUTLASS_HOST_DEVICE -+ operator float() const { -+ -+ // Conversions to IEEE single-precision requires clearing dont-care bits -+ // of the mantissa. -+ unsigned bits = (storage & ~0x1fffu); -+ -+ #if defined(__CUDA_ARCH__) -+ return reinterpret_cast(bits); -+ #else -+ float flt; -+ std::memcpy(&flt, &bits, sizeof(flt)); -+ return flt; -+ #endif -+ } -+ -+ /// Converts to float -+ CUTLASS_HOST_DEVICE -+ explicit operator double() const { -+ return double(float(*this)); -+ } -+ -+ /// Converts to int -+ CUTLASS_HOST_DEVICE -+ explicit operator int() const { -+ return int(float(*this)); -+ } -+ -+ /// Casts to bool -+ CUTLASS_HOST_DEVICE -+ explicit operator bool() const { -+ return (float(*this) != 0.0f); -+ } -+ -+ /// Obtains raw bits -+ CUTLASS_HOST_DEVICE -+ uint32_t raw() const { -+ return storage; -+ } -+ -+ /// Returns the sign bit -+ CUTLASS_HOST_DEVICE -+ bool signbit() const { -+ return ((raw() & 0x80000000) != 0); -+ } -+ -+ /// Returns the biased exponent -+ CUTLASS_HOST_DEVICE -+ int exponent_biased() const { -+ return int((raw() >> 23) & 0x0ff); -+ } -+ -+ /// Returns the unbiased exponent -+ CUTLASS_HOST_DEVICE -+ int exponent() const { -+ return exponent_biased() - 127; -+ } -+ -+ /// Returns the mantissa -+ CUTLASS_HOST_DEVICE -+ int mantissa() const { -+ return int(raw() & 0x7fffff); -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+CUTLASS_HOST_DEVICE -+bool signbit(cutlass::tfloat32_t const& h) { -+ return h.signbit(); -+} -+ -+CUTLASS_HOST_DEVICE -+cutlass::tfloat32_t abs(cutlass::tfloat32_t const& h) { -+ return cutlass::tfloat32_t::bitcast(h.raw() & 0x7fffffff); -+} -+ -+CUTLASS_HOST_DEVICE -+bool isnan(cutlass::tfloat32_t const& h) { -+ return (h.exponent_biased() == 0x0ff) && h.mantissa(); -+} -+ -+CUTLASS_HOST_DEVICE -+bool isfinite(cutlass::tfloat32_t const& h) { -+ return (h.exponent_biased() != 0x0ff); -+} -+ -+CUTLASS_HOST_DEVICE -+cutlass::tfloat32_t nan_tf32(const char*) { -+ // NVIDIA canonical NaN -+ return cutlass::tfloat32_t::bitcast(0x7fffffff); -+} -+ -+CUTLASS_HOST_DEVICE -+bool isinf(cutlass::tfloat32_t const& h) { -+ return (h.exponent_biased() == 0x0ff) && !h.mantissa(); -+} -+ -+CUTLASS_HOST_DEVICE -+bool isnormal(cutlass::tfloat32_t const& h) { -+ return h.exponent_biased() && h.exponent_biased() != 0x0ff; -+} -+ -+CUTLASS_HOST_DEVICE -+int fpclassify(cutlass::tfloat32_t const& h) { -+ int exp = h.exponent_biased(); -+ int mantissa = h.mantissa(); -+ if (exp == 0x0ff) { -+ if (mantissa) { -+ return FP_NAN; -+ } -+ else { -+ return FP_INFINITE; -+ } -+ } -+ else if (!exp) { -+ if (mantissa) { -+ return FP_SUBNORMAL; -+ } -+ else { -+ return FP_ZERO; -+ } -+ } -+ return FP_NORMAL; -+} -+ -+CUTLASS_HOST_DEVICE -+cutlass::tfloat32_t sqrt(cutlass::tfloat32_t const& h) { -+#if defined(__CUDACC_RTC__) -+ return cutlass::tfloat32_t(sqrtf(float(h))); -+#else -+ return cutlass::tfloat32_t(std::sqrt(float(h))); -+#endif -+} -+ -+CUTLASS_HOST_DEVICE -+tfloat32_t copysign(tfloat32_t const& a, tfloat32_t const& b) { -+ -+ uint32_t a_mag = (reinterpret_cast(a) & 0x7fffffff); -+ uint32_t b_sign = (reinterpret_cast(b) & 0x80000000); -+ uint32_t result = (a_mag | b_sign); -+ -+ return reinterpret_cast(result); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Standard Library operations and definitions -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace std { -+ -+#if !defined(__CUDACC_RTC__) -+/// Numeric limits -+template <> -+struct numeric_limits { -+ static bool const is_specialized = true; -+ static bool const is_signed = true; -+ static bool const is_integer = false; -+ static bool const is_exact = false; -+ static bool const has_infinity = true; -+ static bool const has_quiet_NaN = true; -+ static bool const has_signaling_NaN = false; -+ static std::float_denorm_style const has_denorm = std::denorm_present; -+ static bool const has_denorm_loss = true; -+ static std::float_round_style const round_style = std::round_to_nearest; -+ static bool const is_iec559 = false; -+ static bool const is_bounded = true; -+ static bool const is_modulo = false; -+ static int const digits = 19; -+ -+ /// Least positive value -+ static cutlass::tfloat32_t min() { return cutlass::tfloat32_t::bitcast(0x01); } -+ -+ /// Minimum finite value -+ static cutlass::tfloat32_t lowest() { return cutlass::tfloat32_t::bitcast(0xff7fffff); } -+ -+ /// Maximum finite value -+ static cutlass::tfloat32_t max() { return cutlass::tfloat32_t::bitcast(0x7f7fffff); } -+ -+ /// Returns smallest finite value -+ static cutlass::tfloat32_t epsilon() { return cutlass::tfloat32_t::bitcast(0x1000); } -+ -+ /// Returns smallest finite value -+ static cutlass::tfloat32_t round_error() { return cutlass::tfloat32_t(0.5f); } -+ -+ /// Returns smallest finite value -+ static cutlass::tfloat32_t infinity() { return cutlass::tfloat32_t::bitcast(0x7f800000); } -+ -+ /// Returns smallest finite value -+ static cutlass::tfloat32_t quiet_NaN() { return cutlass::tfloat32_t::bitcast(0x7fffffff); } -+ -+ /// Returns smallest finite value -+ static cutlass::tfloat32_t signaling_NaN() { return cutlass::tfloat32_t::bitcast(0x7fffffff); } -+ -+ /// Returns smallest finite value -+ static cutlass::tfloat32_t denorm_min() { return cutlass::tfloat32_t::bitcast(0x1); } -+}; -+#endif -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace std -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Arithmetic operators -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+CUTLASS_HOST_DEVICE -+bool operator==(tfloat32_t const& lhs, tfloat32_t const& rhs) { -+ return float(lhs) == float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator!=(tfloat32_t const& lhs, tfloat32_t const& rhs) { -+ return float(lhs) != float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator<(tfloat32_t const& lhs, tfloat32_t const& rhs) { -+ return float(lhs) < float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator<=(tfloat32_t const& lhs, tfloat32_t const& rhs) { -+ return float(lhs) <= float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator>(tfloat32_t const& lhs, tfloat32_t const& rhs) { -+ return float(lhs) > float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+bool operator>=(tfloat32_t const& lhs, tfloat32_t const& rhs) { -+ return float(lhs) >= float(rhs); -+} -+ -+CUTLASS_HOST_DEVICE -+tfloat32_t operator+(tfloat32_t const& lhs, tfloat32_t const& rhs) { -+ return tfloat32_t(float(lhs) + float(rhs)); -+} -+ -+ -+CUTLASS_HOST_DEVICE -+tfloat32_t operator-(tfloat32_t const& lhs) { -+ union u_tff32 { -+ float val_f32; -+ tfloat32_t val_tf; -+ CUTLASS_HOST_DEVICE u_tff32() : val_f32(0) { } -+ }; -+ union u_tff32 x; x.val_f32 = -reinterpret_cast(lhs); -+ return x.val_tf; -+} -+ -+CUTLASS_HOST_DEVICE -+tfloat32_t operator-(tfloat32_t const& lhs, tfloat32_t const& rhs) { -+ return tfloat32_t(float(lhs) - float(rhs)); -+} -+ -+CUTLASS_HOST_DEVICE -+tfloat32_t operator*(tfloat32_t const& lhs, tfloat32_t const& rhs) { -+ return tfloat32_t(float(lhs) * float(rhs)); -+} -+ -+CUTLASS_HOST_DEVICE -+tfloat32_t operator/(tfloat32_t const& lhs, tfloat32_t const& rhs) { -+ return tfloat32_t(float(lhs) / float(rhs)); -+} -+ -+CUTLASS_HOST_DEVICE -+tfloat32_t& operator+=(tfloat32_t & lhs, tfloat32_t const& rhs) { -+ lhs = tfloat32_t(float(lhs) + float(rhs)); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+tfloat32_t& operator-=(tfloat32_t & lhs, tfloat32_t const& rhs) { -+ lhs = tfloat32_t(float(lhs) - float(rhs)); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+tfloat32_t& operator*=(tfloat32_t & lhs, tfloat32_t const& rhs) { -+ lhs = tfloat32_t(float(lhs) * float(rhs)); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+tfloat32_t& operator/=(tfloat32_t & lhs, tfloat32_t const& rhs) { -+ lhs = tfloat32_t(float(lhs) / float(rhs)); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+tfloat32_t& operator++(tfloat32_t & lhs) { -+ float tmp(lhs); -+ ++tmp; -+ lhs = tfloat32_t(tmp); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+tfloat32_t& operator--(tfloat32_t & lhs) { -+ float tmp(lhs); -+ --tmp; -+ lhs = tfloat32_t(tmp); -+ return lhs; -+} -+ -+CUTLASS_HOST_DEVICE -+tfloat32_t operator++(tfloat32_t & lhs, int) { -+ tfloat32_t ret(lhs); -+ float tmp(lhs); -+ tmp++; -+ lhs = tfloat32_t(tmp); -+ return ret; -+} -+ -+CUTLASS_HOST_DEVICE -+tfloat32_t operator--(tfloat32_t & lhs, int) { -+ tfloat32_t ret(lhs); -+ float tmp(lhs); -+ tmp--; -+ lhs = tfloat32_t(tmp); -+ return ret; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// User-defined literals -+// -+ -+CUTLASS_HOST_DEVICE -+cutlass::tfloat32_t operator "" _tf32(long double x) { -+ return cutlass::tfloat32_t(float(x)); -+} -+ -+CUTLASS_HOST_DEVICE -+cutlass::tfloat32_t operator "" _tf32(unsigned long long int x) { -+ return cutlass::tfloat32_t(int(x)); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/thread/matrix.h b/3rdparty/cutlass/include/cutlass/thread/matrix.h -new file mode 100644 -index 0000000..bc78cf8 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/thread/matrix.h -@@ -0,0 +1,199 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Defines a matrix object intended for storing data in registers and operations within -+ a CUDA thread. -+*/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/matrix_coord.h" -+ -+namespace cutlass { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Per-thread matrix object storing a packed matrix -+template < -+ typename Element, -+ int Rows, -+ int Columns, -+ typename Layout = layout::RowMajor -+> -+class Matrix : public Array { -+public: -+ -+ // Verify layout refers to a rank=2 matrix. -+ static_assert( -+ Layout::kRank == 2, -+ "Layout type must refer to a rank=2 matrix"); -+ -+ /// Base type -+ using Base = Array; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Number of rows -+ static int const kRows = Rows; -+ -+ /// Number of columns -+ static int const kColumns = Columns; -+ -+ /// Layout within the array -+ using Layout = Layout_; -+ -+ /// Reference type to an element -+ using Reference = Element &; -+ -+ /// Logical rank of tensor index space -+ static int const kRank = 2; -+ -+ /// Index type -+ using Index = typename Layout::Index; -+ -+ /// Long index used for pointer offsets -+ using LongIndex = typename Layout::LongIndex; -+ -+ /// Coordinate in logical tensor space -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ /// Stride type -+ using Stride = typename Layout::Stride; -+ -+ /// TensorRef to matrix object -+ using TensorRef = TensorRef; -+ -+ /// TensorRef to constant matrix object -+ using ConstTensorRef = typename TensorRef::ConstTensorRef; -+ -+ /// TensorRef to matrix object -+ using TensorView = TensorView; -+ -+ /// TensorRef to constant matrix object -+ using ConstTensorView = typename TensorView::ConstTensorView; -+ -+ /// Diagonal vector -+ using Diagonal = Vector; -+ -+private: -+ -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Returns the size of the object -+ CUTLASS_HOST_DEVICE -+ static MatrixCoord extent() { -+ return make_Coord(kRows, kColumns); -+ } -+ -+ /// Returns the layout object -+ CUTLASS_HOST_DEVICE -+ static Layout layout() { -+ return Layout::packed(extent()); -+ } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ Matrix() { } -+ -+ /// Ctor -+ CUTLASS_HOST_DEVICE -+ Matrix(Diagonal const &diag) { -+ // Todo - construct from diagonal -+ } -+ -+ /// Returns a TensorRef pointing to the first element of the tensor. -+ CUTLASS_HOST_DEVICE -+ TensorRef ref() { -+ return TensorRef(this->data(), layout()); -+ } -+ -+ /// Returns a TensorRef pointing to the first element of the tensor. -+ CUTLASS_HOST_DEVICE -+ ConstTensorRef const_ref() const { -+ return ConstTensorRef(this->data(), layout()); -+ } -+ -+ /// Returns a TensorRef pointing to the first element of the tensor. -+ CUTLASS_HOST_DEVICE -+ TensorView view() { -+ return TensorView(ref(), extent()); -+ } -+ -+ /// Returns a TensorView to const data -+ CUTLASS_HOST_DEVICE -+ ConstTensorView const_view() const { -+ return ConstTensorView(const_ref(), extent()); -+ } -+ -+ /// Returns a reference to the element at a given Coord -+ CUTLASS_HOST_DEVICE -+ Reference at(MatrixCoord const& coord) const { -+ typename Base::size_type offset_(layout().offset(coord)); -+ return Base::at(offset_); -+ } -+ -+ /// Returns the number of scalar elements needed to store tensor. -+ CUTLASS_HOST_DEVICE -+ LongIndex capacity() const { -+ return LongIndex(Base::size()); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Column vector defined as a matrix with exactly one column -+template < -+ typename Element, -+ int Rows, -+ typename Layout = layout::ColumnMajor -+> -+using ColumnVector = Matrix; -+ -+/// Row vector defined as a matrix with exactly one row -+template < -+ typename Element, -+ int Columns, -+ typename Layout = layout::RowMajor -+> -+using RowVector = Matrix; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/trace.h b/3rdparty/cutlass/include/cutlass/trace.h -new file mode 100644 -index 0000000..c77e7f4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/trace.h -@@ -0,0 +1,59 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Helpers for optionally tracing through code when debugging. -+ -+ This file is to be included after all other headers. -+*/ -+ -+#pragma once -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Tracing options -+#ifndef CUTLASS_DEBUG_TRACE_LEVEL -+#define CUTLASS_DEBUG_TRACE_LEVEL 0 -+#endif -+ -+#if CUTLASS_DEBUG_TRACE_LEVEL -+#include -+#include "cutlass/core_io.h" -+#if defined(__CUDA_ARCH__) -+#define CUTLASS_TRACE_HOST(x) -+#else -+#define CUTLASS_TRACE_HOST(x) { std::cout << __FILE__ << ":" << __LINE__ << " " << x << std::endl; } -+#endif -+#else -+#define CUTLASS_TRACE_HOST(x) -+#endif -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/include/cutlass/transform/pitch_linear_thread_map.h b/3rdparty/cutlass/include/cutlass/transform/pitch_linear_thread_map.h -new file mode 100644 -index 0000000..c084dd4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/pitch_linear_thread_map.h -@@ -0,0 +1,926 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing how threads are mapped to a given tile. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/pitch_linear.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Strip-mines a pitch-linear tile among a given number of threads, first along -+/// the contiguous dimension then along the strided dimension. -+/// -+/// The tile must be divisible by the thread count such that all threads may -+/// execute the same number of iterations with the same delta to exhaustively -+/// cover the tile. -+/// -+/// This class satisfies the "RegularThreadMapping" concept. -+/// -+/// This ThreadMap is used by SIMT kernels and operand E of the sparse tensor -+/// kernels. -+template < -+ typename Shape_, -+ int Threads, -+ int ElementsPerAccess = 1 -+> -+struct PitchLinearStripminedThreadMap { -+ -+ /// Tensor coordinate -+ using TensorCoord = layout::PitchLinearCoord; -+ -+ /// Tile shape -+ using Shape = Shape_; -+ -+ /// Number of threads total -+ static int const kThreads = Threads; -+ -+ /// Extract vector length from Layout -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ /// Shape of access by each thread -+ using ThreadAccessShape = layout::PitchLinearShape; -+ -+ /// Internal implementation details -+ struct Detail { -+ -+ static_assert(!(Shape::kContiguous % kElementsPerAccess), ""); -+ -+ /// Shape of the tile in units of vectors -+ using ShapeVec = layout::PitchLinearShape< -+ Shape::kContiguous / kElementsPerAccess, -+ Shape::kStrided -+ >; -+ -+ static_assert((Threads < ShapeVec::kContiguous && !(ShapeVec::kContiguous % kThreads)) || -+ (!(kThreads % ShapeVec::kContiguous)), -+ "Shape must be divisible by number of iterations of each thread."); -+ }; -+ -+ /// Number of iterations by each thread -+ using Iterations = typename platform::conditional< -+ Threads >= Detail::ShapeVec::kContiguous, -+ layout::PitchLinearShape< -+ 1, -+ // Redo the comparison here to work around divide by zero compiler -+ // error. The compiler evaluates both path of platform::conditional. -+ (Threads >= Detail::ShapeVec::kContiguous -+ ? (Detail::ShapeVec::kStrided + (kThreads / Detail::ShapeVec::kContiguous - 1)) / -+ (kThreads / Detail::ShapeVec::kContiguous) -+ : 0)>, -+ layout::PitchLinearShape>::type; -+ -+ -+ /// Interval between accesses along each dimension of the tensor's logical coordinate space -+ /// (in units of Elements) -+ using Delta = typename platform::conditional< -+ Threads >= Detail::ShapeVec::kContiguous, -+ layout::PitchLinearShape< -+ 1, -+ kThreads / Detail::ShapeVec::kContiguous -+ >, -+ layout::PitchLinearShape< -+ kThreads * kElementsPerAccess, -+ 1 -+ > -+ >::type; -+ -+ /// Shape of the tile in units of vectors -+ using StorageShape = typename platform::conditional< -+ Threads >= Detail::ShapeVec::kContiguous, -+ layout::PitchLinearShape, -+ layout::PitchLinearShape>::type; -+ -+ /// Maps thread ID to a coordinate offset within the tensor's logical coordinate space -+ /// (in units of Elements) -+ CUTLASS_HOST_DEVICE -+ static TensorCoord initial_offset(int thread_id) { -+ return TensorCoord( -+ (thread_id % Detail::ShapeVec::kContiguous) * kElementsPerAccess, -+ thread_id / Detail::ShapeVec::kContiguous); -+ } -+}; -+ -+/// This ThreadMap is used by GEMV -+template < -+ typename Shape, -+ int Threads, -+ int ElementsPerAccess = 1 -+> -+struct PitchLinearTilePolicyStripminedThreadContiguous -+{ -+ static_assert((Shape::kContiguous % (Threads * ElementsPerAccess)) == 0, -+ "Contiguous shape must divide number of threads"); -+ -+ using TensorCoord = layout::PitchLinearCoord; -+ -+ static int const kThreads = Threads; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ using Iterations = layout::PitchLinearShape< -+ Shape::kContiguous / (kThreads * kElementsPerAccess), -+ Shape::kStrided>; -+ -+ using Delta = layout::PitchLinearShape<1, 1>; -+ -+ CUTLASS_HOST_DEVICE -+ static TensorCoord initial_offset(int thread_id) -+ { -+ return TensorCoord(thread_id * Iterations::kContiguous * kElementsPerAccess, 0); -+ } -+}; -+ -+template < -+ typename Shape, -+ int Threads, -+ int ElementsPerAccess = 1 -+> -+struct PitchLinearTilePolicyStripminedThreadStrided -+{ -+ static_assert((Shape::kStrided % Threads == 0), -+ "Strided shape must divide number of threads"); -+ -+ using TensorCoord = layout::PitchLinearCoord; -+ -+ static int const kThreads = Threads; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ using Iterations = layout::PitchLinearShape< -+ Shape::kContiguous / kElementsPerAccess, -+ Shape::kStrided / kThreads>; -+ -+ using Delta = layout::PitchLinearShape<1, 1>; -+ -+ using ShapeVec = Shape; -+ -+ CUTLASS_HOST_DEVICE -+ static TensorCoord initial_offset(int thread_id) -+ { -+ -+ return TensorCoord(0, thread_id * Iterations::kStrided); -+ } -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Policy defining a warp-raked arrangement in which a shape is partitioned into contiguous -+/// elements. -+/// -+/// This ThreadMap is used by tensor core kernels. -+template < -+ typename Shape_, -+ int Threads, -+ typename WarpThreadArrangement_, -+ int ElementsPerAccess = 1 -+> -+struct PitchLinearWarpRakedThreadMap { -+ -+ /// Tensor coordinate -+ using TensorCoord = layout::PitchLinearCoord; -+ -+ /// Tile shape -+ using Shape = Shape_; -+ -+ /// Number of threads total -+ static int const kThreads = Threads; -+ -+ /// Extract vector length from Layout -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ /// Shape of access by each thread -+ using ThreadAccessShape = layout::PitchLinearShape; -+ -+ /// Internal details made public to facilitate introspection -+ struct Detail { -+ -+ /// Fixed arrangement of threads within a warp (units of threads). -+ using WarpThreadArrangement = WarpThreadArrangement_; -+ -+ /// Number of threads per warp -+ static int const kWarpSize = WarpThreadArrangement::kCount; -+ -+ /// Number of participating warps -+ static int const kWarpCount = kThreads / kWarpSize; -+ -+ static_assert( -+ !(Shape::kContiguous % kElementsPerAccess), -+ "Shape must be divisible by vector length."); -+ -+ /// Compute the 'shape' of the overall tile in units of vectors -+ using ShapeInAccesses = layout::PitchLinearShape< -+ Shape::kContiguous / kElementsPerAccess, -+ Shape::kStrided -+ >; -+ -+ static_assert( -+ !(ShapeInAccesses::kContiguous % WarpThreadArrangement::kContiguous), -+ "ShapeInAccesses must be divisible by WarpThreadArrangement."); -+ -+ static_assert( -+ !(ShapeInAccesses::kStrided % WarpThreadArrangement::kStrided), -+ "ShapeInAccesses must be divisible by WarpThreadArrangement."); -+ -+ // compute number of warp-level accesses total -+ using WarpAccessIterations = layout::PitchLinearShape< -+ ShapeInAccesses::kContiguous / WarpThreadArrangement::kContiguous, -+ ShapeInAccesses::kStrided / WarpThreadArrangement::kStrided -+ >; -+ -+ // Divide it into the number of warps, first partitioning the strided dimension then the -+ // contiguous. -+ static int const kWarpsStrided = -+ (WarpAccessIterations::kStrided >= kWarpCount -+ ? kWarpCount -+ : WarpAccessIterations::kStrided); -+ -+ static int const kWarpsContiguous = -+ (kWarpCount > WarpAccessIterations::kStrided -+ ? kWarpCount / kWarpsStrided -+ : 1); -+ -+ /// Arrangement of warps within a threadblock-scoped tile -+ using WarpArrangement = layout::PitchLinearShape< -+ kWarpsContiguous, kWarpsStrided -+ >; -+ }; -+ -+ ///< Iterations along each dimension (concept: PitchLinearShape) -+ using Iterations = layout::PitchLinearShape< -+ Detail::WarpAccessIterations::kContiguous / Detail::kWarpsContiguous, -+ Detail::WarpAccessIterations::kStrided / Detail::kWarpsStrided -+ >; -+ -+ static_assert(Iterations::kCount, -+ "Number of iterations must be non-zero"); -+ -+ ///< Delta betweeen accesses (units of elements, concept: PitchLinearShape) -+ using Delta = layout::PitchLinearShape< -+ Detail::WarpThreadArrangement::kContiguous * kElementsPerAccess, -+ Detail::WarpThreadArrangement::kStrided -+ >; -+ -+ /// Maps thread ID to a coordinate offset within the tensor's logical coordinate space -+ CUTLASS_HOST_DEVICE -+ static TensorCoord initial_offset(int thread_id) { -+ -+ int warp_id = (thread_id / Detail::kWarpSize); -+ int lane_id = (thread_id % Detail::kWarpSize); -+ -+ // -+ // compute warp-level offset -+ // -+ -+ // This is the shape of the entire area covered by a warp's memory access (in units of vectors) -+ layout::PitchLinearCoord warp_footprint{ -+ Detail::WarpThreadArrangement::kContiguous * Iterations::kContiguous, -+ Detail::WarpThreadArrangement::kStrided * Iterations::kStrided -+ }; -+ -+ // This is the offset of a specific warp (in units of vectors) -+ layout::PitchLinearCoord warp_offset{ -+ (warp_id % Detail::kWarpsContiguous), -+ (warp_id / Detail::kWarpsContiguous) -+ }; -+ -+ // This is the offset of a specific thread within a warp (units of vectors) -+ layout::PitchLinearCoord thread_offset_in_warp{ -+ lane_id % Detail::WarpThreadArrangement::kContiguous, -+ lane_id / Detail::WarpThreadArrangement::kContiguous -+ }; -+ -+ // This is the offset of a thread within a threadblock tile (units of vectors) -+ layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec = -+ warp_footprint * warp_offset + thread_offset_in_warp; -+ -+ // This is the offset of a thread within a threadblock tile (units of elements) -+ layout::PitchLinearCoord thread_offset_in_threadblock_tile_base{ -+ thread_offset_in_threadblock_tile_vec.contiguous() * kElementsPerAccess, -+ thread_offset_in_threadblock_tile_vec.strided() -+ }; -+ -+ return thread_offset_in_threadblock_tile_base; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Policy defining a warp-raked arrangement in which a shape is partitioned into contiguous -+/// elements. Warps are arranged based on a stride. -+/// -+/// This ThreadMap is used by tensor core kernels for NCxHWx layout. -+template < -+ typename Shape_, -+ int Threads, -+ typename WarpThreadArrangement_, -+ int ElementsPerAccess = 1 -+> -+struct PitchLinearStridedWarpRakedThreadMap { -+ -+ /// Tensor coordinate -+ using TensorCoord = layout::PitchLinearCoord; -+ -+ /// Tile shape -+ using Shape = Shape_; -+ -+ /// Number of threads total -+ static int const kThreads = Threads; -+ -+ using WarpThreadArrangement = WarpThreadArrangement_; -+ -+ /// Extract vector length from Layout -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ /// Base ThreadMap -+ using BaseThreadMap = PitchLinearWarpRakedThreadMap< -+ Shape, -+ kThreads, -+ WarpThreadArrangement, -+ kElementsPerAccess -+ >; -+ -+ /// Shape of access by each thread -+ using ThreadAccessShape = typename BaseThreadMap::ThreadAccessShape; -+ -+ -+ struct Detail { -+ -+ using WarpThreadArrangement = WarpThreadArrangement_; -+ -+ using WarpAccessIterations = typename BaseThreadMap::Detail::WarpAccessIterations; -+ -+ static int const kWarpSize = BaseThreadMap::Detail::kWarpSize; -+ -+ static int const kWarpCount = BaseThreadMap::Detail::kWarpCount; -+ -+ using ShapeInAccesses = typename BaseThreadMap::Detail::ShapeInAccesses; -+ -+ // Divide it into the number of warps, first partitioning the contiguous dimension then the -+ // stride. -+ static int const kWarpsContiguous = -+ (WarpAccessIterations::kContiguous >= kWarpCount -+ ? kWarpCount -+ : WarpAccessIterations::kContiguous); -+ -+ static int const kWarpsStrided = -+ (kWarpCount > WarpAccessIterations::kContiguous -+ ? kWarpCount / kWarpsContiguous -+ : 1); -+ -+ /// Arrangement of warps within a threadblock-scoped tile -+ using WarpArrangement = layout::PitchLinearShape< -+ kWarpsContiguous, kWarpsStrided -+ >; -+ -+ }; -+ -+ ///< Iterations along each dimension (concept: PitchLinearShape) -+ using Iterations = layout::PitchLinearShape< -+ Detail::WarpAccessIterations::kContiguous / Detail::kWarpsContiguous, -+ Detail::WarpAccessIterations::kStrided / Detail::kWarpsStrided -+ >; -+ -+ static_assert(Iterations::kCount, -+ "Number of iterations must be non-zero"); -+ -+ ///< Delta betweeen accesses (units of elements, concept: PitchLinearShape) -+ using Delta = typename BaseThreadMap::Delta; -+ -+ /// Maps thread ID to a coordinate offset within the tensor's logical coordinate space -+ CUTLASS_HOST_DEVICE -+ static TensorCoord initial_offset(int thread_id) { -+ -+ int warp_id = (thread_id / Detail::kWarpSize); -+ int lane_id = (thread_id % Detail::kWarpSize); -+ -+ // -+ // compute warp-level offset -+ // -+ -+ // This is the shape of the entire area covered by a warp's memory access (in units of vectors) -+ layout::PitchLinearCoord warp_footprint{ -+ Detail::WarpThreadArrangement::kContiguous * Iterations::kContiguous, -+ Detail::WarpThreadArrangement::kStrided * Iterations::kStrided -+ }; -+ -+ // This is the offset of a specific warp (in units of vectors) -+ layout::PitchLinearCoord warp_offset{ -+ (warp_id % Detail::kWarpsContiguous), -+ (warp_id / Detail::kWarpsContiguous) -+ }; -+ -+ // This is the offset of a specific thread within a warp (units of vectors) -+ layout::PitchLinearCoord thread_offset_in_warp{ -+ lane_id % Detail::WarpThreadArrangement::kContiguous, -+ lane_id / Detail::WarpThreadArrangement::kContiguous -+ }; -+ -+ // This is the offset of a thread within a threadblock tile (units of vectors) -+ layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec = -+ warp_footprint * warp_offset + thread_offset_in_warp; -+ -+ // This is the offset of a thread within a threadblock tile (units of elements) -+ layout::PitchLinearCoord thread_offset_in_threadblock_tile_base{ -+ thread_offset_in_threadblock_tile_vec.contiguous() * kElementsPerAccess, -+ thread_offset_in_threadblock_tile_vec.strided() -+ }; -+ -+ return thread_offset_in_threadblock_tile_base; -+ } -+ -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Transpose the existing ThreadMap. For example, interleaved layout is like -+/// congruous in the global memory and crosswise in the shared memory. We need -+/// to transpose the coordinates between two. -+ -+template -+struct TransposePitchLinearThreadMap { -+ /// Underlying ThreadMap -+ using ThreadMap = ThreadMap_; -+ -+ /// Tensor coordinate -+ using TensorCoord = typename ThreadMap::TensorCoord; -+ -+ /// Tile shape -+ using Shape = typename ThreadMap::Shape; -+ -+ /// Number of threads total -+ static int const kThreads = ThreadMap::kThreads; -+ -+ /// Extract vector length from Layout -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ -+ /// Shape of access by each thread -+ using ThreadAccessShape = layout::PitchLinearShape; -+ -+ /// Internal details made public to facilitate introspection -+ struct Detail { -+ /// Fixed arrangement of threads within a warp (units of threads). -+ using WarpThreadArrangement = WarpThreadArrangement_; -+ -+ /// Number of threads per warp -+ static int const kWarpSize = WarpThreadArrangement::kCount; -+ -+ /// Number of participating warps -+ static int const kWarpCount = kThreads / kWarpSize; -+ -+ static_assert(!(Shape::kContiguous % kElementsPerAccess), -+ "Shape must be divisible by vector length."); -+ -+ /// Arrangement of warps within a threadblock-scoped tile -+ using WarpArrangement = -+ layout::PitchLinearShape; -+ }; -+ -+ ///< Iterations along each dimension (concept: PitchLinearShape) -+ using Iterations = -+ layout::PitchLinearShape; -+ -+ static_assert(Iterations::kContiguous == 1, -+ "Contiguous iteration has to be one to reuse the same shared store function with those that don't need transpose"); -+ -+ static_assert(Iterations::kCount, "Number of iterations must be non-zero"); -+ -+ ///< Delta betweeen accesses (units of elements, concept: PitchLinearShape) -+ using Delta = -+ layout::PitchLinearShape; -+ -+ /// Maps thread ID to a coordinate offset within the tensor's logical -+ /// coordinate space Note this is slightly different from the one of -+ /// PitchLinearWarpRakedThreadMap. -+ CUTLASS_HOST_DEVICE -+ static TensorCoord initial_offset(int thread_id) { -+ -+ int warp_id = (thread_id / Detail::kWarpSize); -+ int lane_id = (thread_id % Detail::kWarpSize); -+ -+ // -+ // compute warp-level offset -+ // -+ -+ // This is the shape of the entire area covered by a warp's memory access -+ // (in units of vectors) -+ layout::PitchLinearCoord warp_footprint{ -+ Detail::WarpThreadArrangement::kContiguous * Iterations::kContiguous, -+ Detail::WarpThreadArrangement::kStrided * Iterations::kStrided}; -+ -+ // This is the offset of a specific warp (in units of vectors) -+ // Note the order of / and %. Also the 2nd operand is kStrided. -+ layout::PitchLinearCoord warp_offset{ -+ (warp_id / Detail::WarpArrangement::kStrided), -+ (warp_id % Detail::WarpArrangement::kStrided)}; -+ -+ // This is the offset of a specific thread within a warp (units of vectors) -+ layout::PitchLinearCoord thread_offset_in_warp{ -+ lane_id % Detail::WarpThreadArrangement::kContiguous, -+ lane_id / Detail::WarpThreadArrangement::kContiguous}; -+ -+ // This is the offset of a thread within a threadblock tile (units of -+ // vectors) -+ layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec = -+ warp_footprint * warp_offset + thread_offset_in_warp; -+ -+ // This is the offset of a thread within a threadblock tile (units of -+ // elements) -+ layout::PitchLinearCoord thread_offset_in_threadblock_tile_base{ -+ thread_offset_in_threadblock_tile_vec.contiguous() * kElementsPerAccess, -+ thread_offset_in_threadblock_tile_vec.strided()}; -+ -+ return thread_offset_in_threadblock_tile_base; -+ } -+}; -+ -+template -+struct TransposePitchLinearThreadMapSimt { -+ /// Underlying ThreadMap -+ using ThreadMap = ThreadMap_; -+ -+ /// Tensor coordinate -+ using TensorCoord = typename ThreadMap::TensorCoord; -+ -+ /// Tile shape -+ using Shape = typename ThreadMap::Shape; -+ -+ /// Number of threads total -+ static int const kThreads = ThreadMap::kThreads; -+ -+ /// Extract vector length from Layout -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ -+ static_assert(kElementsPerAccess == 1 , "Simt transpose requires elements per access to be 1"); -+ ///< Iterations along each dimension (concept: PitchLinearShape) -+ using Iterations = -+ layout::PitchLinearShape; -+ -+ static_assert(Iterations::kCount, "Number of iterations must be non-zero"); -+ -+ static_assert(Iterations::kStrided == 1, -+ "Strided iteration has to be one to reuse the same shared store function with those that don't need transpose"); -+ -+ /// Shape of access by each thread -+ using ThreadAccessShape = typename ThreadMap::ThreadAccessShape; -+ -+ ///< Delta betweeen accesses (units of elements, concept: PitchLinearShape) -+ using Delta = -+ layout::PitchLinearShape; -+ -+ -+ /// Maps thread ID to a coordinate offset within the tensor's logical -+ /// coordinate space Note this is slightly different from the one of -+ /// PitchLinearWarpRakedThreadMap. -+ CUTLASS_HOST_DEVICE -+ static TensorCoord initial_offset(int thread_id) { -+ -+ TensorCoord coord = ThreadMap::initial_offset(thread_id); -+ -+ return TensorCoord( -+ coord.strided(), -+ coord.contiguous() -+ ); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Policy defining a warp-striped arrangement. This partitions a tile into vectorized memory -+/// accesses performed by each warp then distributes warps across them. Warps are striped in the -+/// strided dimension and raked across the contiguous dimension. -+template < -+ typename Shape_, /// Overall shape to partition in units of elements -+ int Threads, /// Number of partiticipation threads -+ typename WarpThreadArrangement_, /// Describes the shape of one memory access per warp -+ int ElementsPerAccess = 1 /// Number of elements accessed by each thread per memory operation (i.e. vector size) -+> -+struct PitchLinearWarpStripedThreadMap { -+ -+ /// Tensor coordinate -+ using TensorCoord = layout::PitchLinearCoord; -+ -+ /// Tile shape -+ using Shape = Shape_; -+ -+ /// Number of threads total -+ static int const kThreads = Threads; -+ -+ /// Extract vector length from Layout -+ static int const kElementsPerAccess = ElementsPerAccess; -+ -+ /// Shape of access by each thread -+ using ThreadAccessShape = layout::PitchLinearShape; -+ -+ /// Internal details made public to facilitate introspection -+ struct Detail { -+ -+ /// Fixed arrangement of threads within a warp (units of threads). -+ using WarpThreadArrangement = WarpThreadArrangement_; -+ -+ /// Number of threads per warp -+ static int const kWarpSize = WarpThreadArrangement::kCount; -+ -+ /// Number of participating warps -+ static int const kWarpCount = kThreads / kWarpSize; -+ -+ static_assert( -+ !(Shape::kContiguous % kElementsPerAccess), -+ "Shape must be divisible by vector length."); -+ -+ /// Compute the 'shape' of the overall tile in units of vectors -+ using ShapeInAccesses = layout::PitchLinearShape< -+ Shape::kContiguous / kElementsPerAccess, -+ Shape::kStrided -+ >; -+ -+ // compute number of warp-level accesses total -+ using WarpAccessIterations = layout::PitchLinearShape< -+ ShapeInAccesses::kContiguous / WarpThreadArrangement::kContiguous, -+ ShapeInAccesses::kStrided / WarpThreadArrangement::kStrided -+ >; -+ -+ // Divide it into the number of warps, first partitioning the strided dimension then the -+ // contiguous. -+ static int const kWarpsStrided = -+ (WarpAccessIterations::kStrided >= kWarpCount -+ ? kWarpCount : (kWarpCount / WarpAccessIterations::kStrided)); -+ -+ static int const kWarpsContiguous = -+ (kWarpCount > WarpAccessIterations::kStrided ? -+ WarpAccessIterations::kContiguous / kWarpsStrided : 1); -+ -+ /// Arrangement of warps within a threadblock-scoped tile -+ using WarpArrangement = layout::PitchLinearShape< -+ kWarpsContiguous, kWarpsStrided -+ >; -+ }; -+ -+ ///< Iterations along each dimension (concept: PitchLinearShape) -+ using Iterations = layout::PitchLinearShape< -+ Detail::WarpAccessIterations::kContiguous / Detail::kWarpsContiguous, -+ Detail::WarpAccessIterations::kStrided / Detail::kWarpsStrided -+ >; -+ -+ static_assert(Iterations::kCount, -+ "Number of iterations must be non-zero"); -+ -+ ///< Delta betweeen accesses (units of elements, concept: PitchLinearShape) -+ using Delta = layout::PitchLinearShape< -+ Detail::WarpThreadArrangement::kContiguous * kElementsPerAccess, -+ Detail::WarpThreadArrangement::kStrided * Detail::WarpArrangement::kStrided -+ >; -+ -+ /// Maps thread ID to a coordinate offset within the tensor's logical coordinate space -+ CUTLASS_HOST_DEVICE -+ static TensorCoord initial_offset(int thread_id) { -+ -+ int warp_id = (thread_id / Detail::kWarpSize); -+ int lane_id = (thread_id % Detail::kWarpSize); -+ -+ // -+ // compute warp-level offset -+ // -+ -+ // This is the shape of the entire area covered by a warp's memory access (in units of vectors) -+ layout::PitchLinearCoord warp_footprint{ -+ Detail::WarpThreadArrangement::kContiguous * Iterations::kContiguous, -+ Detail::WarpThreadArrangement::kStrided -+ }; -+ -+ // This is the offset of a specific warp (in units of vectors) -+ layout::PitchLinearCoord warp_offset{ -+ (warp_id % Detail::kWarpsContiguous), -+ (warp_id / Detail::kWarpsContiguous) -+ }; -+ -+ // This is the offset of a specific thread within a warp (units of vectors) -+ layout::PitchLinearCoord thread_offset_in_warp{ -+ lane_id % Detail::WarpThreadArrangement::kContiguous, -+ lane_id / Detail::WarpThreadArrangement::kContiguous -+ }; -+ -+ // This is the offset of a thread within a threadblock tile (units of vectors) -+ layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec = -+ warp_footprint * warp_offset + thread_offset_in_warp; -+ -+ // This is the offset of a thread within a threadblock tile (units of elements) -+ layout::PitchLinearCoord thread_offset_in_threadblock_tile_base{ -+ thread_offset_in_threadblock_tile_vec.contiguous() * kElementsPerAccess, -+ thread_offset_in_threadblock_tile_vec.strided() -+ }; -+ -+ return thread_offset_in_threadblock_tile_base; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Strip-mines a pitch-linear tile among a given number of threads, first along the contiguous -+/// dimension then along the strided dimension, while each thread access a 2D thread-tile. -+/// -+/// The tile must be divisible by the thread count such that all threads may execute the same -+/// number of iterations with the same delta to exhaustively cover the tile. -+/// -+/// This class satisfies the "RegularThreadMapping" concept. -+template < -+ typename Shape_, -+ int Threads, -+ typename ThreadTileShape -+> -+struct PitchLinear2DThreadTileStripminedThreadMap; -+ -+ -+template < -+ typename Shape_, -+ int Threads -+> -+struct PitchLinear2DThreadTileStripminedThreadMap >{ -+ -+ /// Tensor coordinate -+ using TensorCoord = layout::PitchLinearCoord; -+ -+ /// Tile shape -+ using Shape = Shape_; -+ -+ /// Access Shape of each thread -+ using ThreadAccessShape = cutlass::layout::PitchLinearShape<4, 4>; -+ //using ThreadAccessShape = ThreadTileShape; -+ -+ /// Number of threads total -+ static int const kThreads = Threads; -+ -+ /// Extract length of each access from Layout -+ static int const kElementsPerAccess = ThreadAccessShape::kContiguous; -+ -+ static_assert(!(kElementsPerAccess % 4) , "kElementsPerAccess, needs to be multiple of 4 (32bits)"); -+ -+ /// Internal implementation details -+ struct Detail { -+ -+ static_assert(!(ThreadAccessShape::kContiguous % 4), "ThreadAccessShape, needs to be multiple of 4"); -+ -+ static_assert(!(Shape::kContiguous % ThreadAccessShape::kContiguous), ""); -+ -+ static_assert(!((Shape::kContiguous * Shape::kStrided) % (kThreads * ThreadAccessShape::kCount)), -+ "Shape must be divisible thread count * accesses per thread."); -+ -+ /// Shape of the tile in units of vectors -+ using ShapeVec = layout::PitchLinearShape< -+ Shape::kContiguous / ThreadAccessShape::kContiguous, -+ Shape::kStrided / ThreadAccessShape::kStrided -+ >; -+ -+ static_assert( -+ (Threads < ShapeVec::kContiguous && !(ShapeVec::kContiguous % kThreads)) || -+ (!(kThreads % ShapeVec::kContiguous) && !(ShapeVec::kStrided % (kThreads / ShapeVec::kContiguous))), -+ "Shape must be divisible by number of iterations of each thread." -+ ); -+ }; -+ -+ /// Number of iterations by each thread -+ using Iterations = typename platform::conditional< -+ Threads >= Detail::ShapeVec::kContiguous, -+ layout::PitchLinearShape< -+ 1, -+ // Redo the comparison here to work around divide by zero compiler -+ // error. The compiler evaluates both path of platform::conditional. -+ (Threads >= Detail::ShapeVec::kContiguous -+ ? Detail::ShapeVec::kStrided / -+ (kThreads / Detail::ShapeVec::kContiguous) -+ : 0)>, -+ layout::PitchLinearShape>::type; -+ -+ /// Interval between accesses along each dimension of the tensor's logical coordinate space -+ /// (in units of Elements) -+ using Delta = typename platform::conditional< -+ Threads >= Detail::ShapeVec::kContiguous, -+ layout::PitchLinearShape< -+ Shape::kContiguous, -+ kThreads * ThreadAccessShape::kStrided / Detail::ShapeVec::kContiguous -+ >, -+ layout::PitchLinearShape< -+ kThreads * ThreadAccessShape::kContiguous, -+ 1 -+ > -+ >::type; -+ -+ /// Maps thread ID to a coordinate offset within the tensor's logical coordinate space -+ /// (in units of Elements) -+ CUTLASS_HOST_DEVICE -+ static TensorCoord initial_offset(int thread_id) { -+ -+ return TensorCoord( -+ (thread_id % Detail::ShapeVec::kContiguous) * ThreadAccessShape::kContiguous, -+ (thread_id / Detail::ShapeVec::kContiguous) * ThreadAccessShape::kStrided); -+ } -+}; -+ -+/// Thread Mapping a 2D threadtiled mapping as a tranposed Pitchlinear2DThreadTile mapping -+template -+struct TransposePitchLinearThreadMap2DThreadTile { -+ /// Underlying ThreadMap -+ using ThreadMap = ThreadMap_; -+ -+ /// Tensor coordinate -+ using TensorCoord = typename ThreadMap::TensorCoord; -+ -+ /// Tile shape -+ using Shape = typename ThreadMap::Shape; -+ -+ /// Number of threads total -+ static int const kThreads = ThreadMap::kThreads; -+ -+ /// Extract vector length from Layout -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ -+ -+ static_assert(kElementsPerAccess > 1 , "Simt transpose requires elements per access to be 1"); -+ ///< Iterations along each dimension (concept: PitchLinearShape) -+ using Iterations = -+ layout::PitchLinearShape; -+ -+ static_assert(Iterations::kCount, "Number of iterations must be non-zero"); -+ -+ /// Shape of access by each thread -+ using ThreadAccessShape = typename ThreadMap::ThreadAccessShape; -+ -+ ///< Delta betweeen accesses (units of elements, concept: PitchLinearShape) -+ using Delta = -+ layout::PitchLinearShape; -+ -+ -+ /// Maps thread ID to a coordinate offset within the tensor's logical -+ /// coordinate space Note this is slightly different from the one of -+ /// PitchLinearWarpRakedThreadMap. -+ CUTLASS_HOST_DEVICE -+ static TensorCoord initial_offset(int thread_id) { -+ -+ TensorCoord coord = ThreadMap::initial_offset(thread_id); -+ return TensorCoord( -+ coord.strided(), -+ coord.contiguous() -+ ); -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace transform -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/transform/thread/transpose.h b/3rdparty/cutlass/include/cutlass/transform/thread/transpose.h -new file mode 100644 -index 0000000..b62b6bf ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/thread/transpose.h -@@ -0,0 +1,107 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Basic copy routines for tensor views -+*/ -+ -+#pragma once -+ -+namespace cutlass { -+namespace transform { -+namespace thread { -+ -+/// Transforms a fragment by doing a transpose -+template < -+ int ElementCount, -+ typename TransposeShape, -+ typename Element -+> struct Transpose; -+ -+/// Specialization for int8_t 4x4 transpose -+template -+struct Transpose , int8_t> { -+ -+ static const int kElementCount = ElementCount_; -+ using TransposeShape = layout::PitchLinearShape<4,4>; -+ using Element = int8_t; -+ using Fragment = cutlass::Array; -+ -+ static_assert(!(kElementCount % TransposeShape::kCount), "Shape needs to be multiple of 16 elements to do a 4x4 transpose"); -+ -+ CUTLASS_DEVICE -+ void transform(Fragment& dst, Fragment& src) { -+ -+ // Expose src/dst as int arrays. -+ int* src_int = reinterpret_cast(&src); -+ int* dst_int = reinterpret_cast(&dst); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kElementCount / TransposeShape::kCount; i++){ -+ -+ int const i0 = 4 * i + 0; -+ int const i1 = 4 * i + 1; -+ int const i2 = 4 * i + 2; -+ int const i3 = 4 * i + 3; -+ -+ int a0 = src_int[i0]; -+ int a1 = src_int[i1]; -+ int a2 = src_int[i2]; -+ int a3 = src_int[i3]; -+ -+ int b0, b1, b2, b3, c0; -+ b0 = __byte_perm(a0, a1, 0x0040); -+ c0 = __byte_perm(a2, a3, 0x0040); -+ b0 = __byte_perm(b0, c0, 0x5410); -+ -+ b1 = __byte_perm(a0, a1, 0x0051); -+ c0 = __byte_perm(a2, a3, 0x0051); -+ b1 = __byte_perm(b1, c0, 0x5410); -+ -+ b2 = __byte_perm(a0, a1, 0x0062); -+ c0 = __byte_perm(a2, a3, 0x0062); -+ b2 = __byte_perm(b2, c0, 0x5410); -+ -+ b3 = __byte_perm(a0, a1, 0x0073); -+ c0 = __byte_perm(a2, a3, 0x0073); -+ b3 = __byte_perm(b3, c0, 0x5410); -+ -+ dst_int[i0] = b0; -+ dst_int[i1] = b1; -+ dst_int[i2] = b2; -+ dst_int[i3] = b3; -+ } -+ } -+}; -+ -+} // namespace thread -+} // namespace layout -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/transform/thread/unary_op.h b/3rdparty/cutlass/include/cutlass/transform/thread/unary_op.h -new file mode 100644 -index 0000000..c50e75b ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/thread/unary_op.h -@@ -0,0 +1,105 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/complex.h" -+ -+namespace cutlass { -+namespace transform { -+namespace thread { -+ -+namespace UnaryTransform { -+ struct Identity; ///< None (i.e., identity) -+ struct Conjugate; ///< Complex conjugate -+} -+ -+/// Element-wise unary operator that transforms one element of a fragment at a time -+template< -+ typename FragmentIn, ///< Input Fragment -+ typename FragmentOut,///< Output Fragment -+ typename Transform> ///< Unary transform operator -+class UnaryOp -+{ -+ public: -+ CUTLASS_DEVICE -+ static FragmentOut execute(FragmentIn &in) -+ { -+ static_assert(FragmentIn::kElements == FragmentOut::kElements, "Number of elements must match."); -+ static_assert(platform::is_same::value || -+ platform::is_same::value, -+ "Unary Operator not supported."); -+ -+ FragmentOut out; -+ if (platform::is_same::value ) -+ { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i=0; i < FragmentIn::kElements; ++i){ -+ out[i] = static_cast(in[i]); -+ } -+ } -+ else if (platform::is_same::value ) -+ { -+ for (int i=0; i < FragmentIn::kElements; ++i){ -+ out[i] = conj(static_cast(in[i])); -+ } -+ } -+ return out; -+ } -+}; -+ -+template -+class UnaryOp -+{ -+ public: -+ CUTLASS_DEVICE -+ static FragmentIn execute(FragmentIn &in) -+ { -+ static_assert(platform::is_same::value || -+ platform::is_same::value, -+ "Unary Operator not supported."); -+ -+ if (platform::is_same::value ) -+ { -+ return in; -+ } -+ else if (platform::is_same::value ) -+ { -+ for(int i=0; i < FragmentIn::kElements; ++i){ -+ in[i] = conj(in[i]); -+ } -+ } -+ return in; -+ } -+ }; -+ } -+ } -+} -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/ell_iterator.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/ell_iterator.h -new file mode 100644 -index 0000000..0578123 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/ell_iterator.h -@@ -0,0 +1,199 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Ell iterator for matrix of indices (ellColInd matrix) -+*/ -+ -+#pragma once -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+namespace ell{ -+ -+constexpr unsigned int SmemPow = 8; -+constexpr unsigned int SmemStages = 2; -+constexpr unsigned int SmemSize = 1 << SmemPow; -+constexpr unsigned int SmemMask = (SmemSize*SmemStages-1); -+ -+class SharedStorage{ -+ public: -+ Array array; -+}; -+ -+class Iterator{ -+ public: -+ using Layout = layout::PitchLinear; -+ using LongIndex = typename Layout::LongIndex; -+ -+ private: -+ const int *gmem_col_idx_; -+ int *smem_col_idx_; -+ const int block_size_; -+ const int base_idx_; -+ const int k_shape_; -+ const int ell_increment_; -+ const int array_length_; -+ int col_idx_base_; -+ int residue_; -+ int counter_; -+ -+ int pow2_; -+ int residue_shape_; -+ -+ int smem_offset_; -+ int smem_stage_; -+ int gmem_offset_; -+ -+ int lane_; -+ -+ bool is_pow2_; -+ bool is_residue_tile_; -+ -+ public: -+ CUTLASS_DEVICE -+ void load_ell_indices(){ -+ for(int i=threadIdx.x; i= 0) ? gmem_col_idx : -1; -+ } -+ gmem_offset_ += SmemSize; -+ smem_stage_ ^= 1; -+ } -+ -+ CUTLASS_DEVICE -+ Iterator( -+ SharedStorage& shared_storage_base, -+ const int* col_idx, -+ const int& block_size, -+ const int& base_idx, -+ const int k_shape, -+ const int& problem_size_k, -+ const int& ell_stride, -+ const int& thread_idx) -+ : residue_(0), -+ counter_(0), -+ smem_offset_(0), -+ smem_stage_(0), -+ gmem_offset_(0), -+ block_size_(block_size), -+ base_idx_(base_idx), -+ k_shape_(k_shape), -+ ell_increment_(ell_stride * block_size), -+ array_length_((problem_size_k + block_size_ - 1) / block_size_), -+ residue_shape_(problem_size_k % k_shape_), -+ is_residue_tile_(residue_shape_ != 0), -+ smem_col_idx_(reinterpret_cast(&shared_storage_base.array)), -+ gmem_col_idx_(const_cast(col_idx)), -+ lane_(thread_idx % 32) { -+ -+ load_ell_indices(); -+ __syncthreads(); -+ -+ is_pow2_ = ((block_size_ & (block_size_ - 1)) == 0); -+ if( is_pow2_ && k_shape <= block_size_ ) lane_ = 0; -+ -+ col_idx_base_ = smem_col_idx_[(smem_offset_ + lane_) & SmemMask] * ell_increment_; -+ -+ pow2_ = 0; -+ while(block_size_ >> (pow2_ + 1)) ++pow2_; -+ } -+ -+ CUTLASS_DEVICE -+ int get_blocksize(){ -+ return block_size_; -+ } -+ -+ CUTLASS_DEVICE -+ Iterator &operator++(){ -+ if(is_residue_tile_){ -+ residue_ += residue_shape_; -+ is_residue_tile_ = false; -+ } else { -+ residue_ += k_shape_; -+ } -+ -+ if(residue_ < block_size_){ -+ return *this; -+ } -+ -+ if((array_length_ > SmemSize) && (((smem_offset_ >> SmemPow) & 1) != smem_stage_)) -+ load_ell_indices(); -+ -+ if(residue_ == block_size_){ -+ ++smem_offset_; -+ counter_ += ell_increment_; -+ residue_ = 0; -+ col_idx_base_ = smem_col_idx_[(smem_offset_ + lane_) & SmemMask] * ell_increment_ - counter_; -+ return *this; -+ } -+ -+ if(is_pow2_){ -+ smem_offset_ += residue_ >> pow2_; -+ counter_ += (residue_ >> pow2_) * ell_increment_; -+ residue_ = residue_ & ((1 << pow2_) - 1); -+ } -+ else { -+ smem_offset_ += residue_ / block_size_; -+ counter_ += (residue_ / block_size_) * ell_increment_; -+ residue_ %= block_size_; -+ } -+ -+ col_idx_base_ = smem_col_idx_[(smem_offset_ + lane_) & SmemMask] * ell_increment_ - counter_; -+ -+ return *this; -+ } -+ -+ CUTLASS_DEVICE -+ LongIndex get_offset(const int& idx) { -+ int num_jump_tiles; -+ if(is_pow2_) -+ num_jump_tiles = (idx + residue_) >> pow2_; -+ else -+ num_jump_tiles = (idx + residue_) / block_size_; -+ -+ int tmp = __shfl_sync(0xffffffff, col_idx_base_, num_jump_tiles); -+ return tmp - num_jump_tiles * ell_increment_; -+ } -+ -+ CUTLASS_DEVICE -+ LongIndex get_offset_fast() { -+ return col_idx_base_; -+ } -+}; -+ -+} -+} -+} -+} -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_access_iterator.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_access_iterator.h -new file mode 100644 -index 0000000..9eec17e ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_access_iterator.h -@@ -0,0 +1,1350 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Ell iterator for Blocked-Ell matrix (ellValue matrix) used with EllMmaMultistage -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// EllPredicatedTileAccessIterator -+/// -+template -+class EllPredicatedTileAccessIterator; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of EllPredicatedTileAccessIterator for pitch-linear data. -+/// -+template -+class EllPredicatedTileAccessIterator { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ static int const kPredicatesPerByte = 4; -+ static int const kPredicatesPerWord = 4 * kPredicatesPerByte; -+ -+ static int const kPredicateCount = ThreadMap::Iterations::kCount * kAccessesPerVector; -+ -+ /// Number of 32b words containing predicates -+ static int const kPredicateByteCount = -+ (kPredicateCount + kPredicatesPerByte - 1) / kPredicatesPerByte; -+ static int const kPredicateWordCount = (kPredicateByteCount + 3) / 4; -+ -+ static unsigned const kPredicateMask = (1u << kPredicatesPerByte) - 1u; -+ -+ static_assert(kPredicateWordCount <= 4, "Too many predicates."); -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = Array; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ public: -+ friend EllPredicatedTileAccessIterator; -+ -+ private: -+ /// stride of pitch-linear layout (units of Element) -+ LongIndex stride_; -+ /// amount (in byte) to increment pointer to move to next access along -+ /// strided dimension -+ LongIndex inc_strided_; -+ /// amount (in byte) to increment pointer from last access to first access -+ /// of next tile -+ LongIndex inc_next_; -+ /// amount (in byte) to increment pointer from first access of current tile -+ /// to first access of next tile -+ LongIndex inc_advance_; -+ -+ public: -+ -+ // Default ctor -+ CUTLASS_HOST_DEVICE -+ Params(): stride_(0), inc_strided_(0), inc_next_(0), inc_advance_(0) { } -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) : stride_(layout.stride(0)) { -+ inc_strided_ = (LongIndex(stride_) * ThreadMap::Delta::kStrided) * -+ sizeof_bits::value / 8; -+ -+ if (kAdvanceRank) { -+ // advance along strided dimension -+ inc_advance_ = -+ Shape::kStrided * LongIndex(stride_) * sizeof_bits::value / 8; -+ } else { -+ // advance along contiguous dimension -+ inc_advance_ = Shape::kContiguous * sizeof_bits::value / 8; -+ } -+ -+ inc_next_ = inc_advance_ - LongIndex(ThreadMap::Iterations::kStrided - 1) * -+ ThreadMap::Delta::kStrided * LongIndex(stride_) * -+ sizeof_bits::value / 8; -+ }; -+ }; -+ -+ private: -+ /// Internal pointer type permits fast address arithmetic -+ using BytePointer = char *; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Parameters object with precomputed internal state -+ Params const ¶ms_; -+ -+ /// Internal pointer to first access of tile -+ BytePointer pointer_; -+ -+ /// Guard predicates -+ uint32_t predicates_[kPredicateWordCount]; -+ -+ /// Size of tensor -+ TensorCoord extent_; -+ -+ /// Initial offset for each thread -+ TensorCoord thread_offset_; -+ -+ /// Offset to the first steady-state tile -+ TensorCoord residue_offset_; -+ -+ /// Initial offset to define ELL block -+ TensorCoord ell_offset_; -+ -+ /// Used for out-of-order visitation -+ bool is_residue_tile_; -+ -+ /// Iteration along vectors implied by the thread map -+ int iteration_vector_; -+ -+ /// Iteration in the contiguous dimension -+ int iteration_contiguous_; -+ -+ /// Iteration in the strided dimension -+ int iteration_strided_; -+ -+ public: -+ /// Computes predicates based on internally tracked per-thread offset. -+ CUTLASS_DEVICE -+ void compute_predicates_( -+ /// Extent of the matrix window -+ TensorCoord extent, -+ /// optionally, simplify predicate calculation during 'steady state' phase -+ bool is_steady_state = false) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ predicates_[i] = 0u; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int access_idx = 0; access_idx < ThreadMap::Iterations::kCount * kAccessesPerVector; ++access_idx) { -+ -+ int s = access_idx / (ThreadMap::Iterations::kContiguous * kAccessesPerVector); -+ -+ int access_residual = access_idx % (ThreadMap::Iterations::kContiguous * kAccessesPerVector); -+ -+ int c = access_residual / kAccessesPerVector; -+ int v = access_residual % kAccessesPerVector; -+ -+ TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous + v * AccessType::kElements, -+ s * ThreadMap::Delta::kStrided); -+ -+ TensorCoord coord = thread_offset_ + iteration_coord; -+ -+ bool guard; -+ -+ if (is_steady_state) { -+ if (kAdvanceRank == 0) { -+ guard = (coord.strided() < extent.strided()); -+ } else { -+ guard = (coord.contiguous() < extent.contiguous()); -+ } -+ } else { -+ guard = (coord.strided() < extent.strided() && -+ coord.contiguous() < extent.contiguous()); -+ } -+ -+ int pred_idx = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s); -+ -+ int word_idx = pred_idx / kPredicatesPerWord; -+ int residual = pred_idx % kPredicatesPerWord; -+ int byte_idx = residual / kPredicatesPerByte; -+ int bit_idx = residual % kPredicatesPerByte; -+ -+ predicates_[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx)); -+ -+ } -+ -+ } -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : params_(params), -+ pointer_(reinterpret_cast( -+ const_cast(pointer))), -+ extent_(extent), -+ is_residue_tile_(true) { -+ -+ TensorCoord residue_extent; -+ if (kAdvanceRank) { -+ -+ typename TensorCoord::Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.strided()) % Shape::kStrided; -+ if (!residue_size) { -+ residue_size = Shape::kStrided; -+ } -+ -+ residue_offset_ = make_Coord(0, residue_size); -+ residue_extent = make_Coord( -+ extent_.contiguous(), -+ min(threadblock_offset.strided() + residue_size, extent_.strided()) -+ ); -+ } else { -+ -+ typename TensorCoord::Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.contiguous()) % Shape::kContiguous; -+ if (!residue_size) { -+ residue_size = Shape::kContiguous; -+ } -+ -+ residue_offset_ = make_Coord(residue_size, 0); -+ -+ residue_extent = make_Coord( -+ min(extent_.contiguous(), threadblock_offset.contiguous() + residue_size), -+ extent_.strided() -+ ); -+ } -+ -+ // Per-thread offset in logical coordinates of tensor -+ ell_offset_ = ThreadMap::initial_offset(thread_id); -+ thread_offset_ = threadblock_offset + ThreadMap::initial_offset(thread_id); -+ -+ // update internal pointers -+ Layout layout(params_.stride_); -+ add_pointer_offset(layout(thread_offset_)); -+ -+ compute_predicates_(residue_extent, false); -+ -+ set_iteration_index(0); -+ } -+ -+ /// Construct a EllPredicatedTileAccessIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id) -+ : EllPredicatedTileAccessIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += sizeof_bits::value * pointer_offset / 8; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_DEVICE -+ void add_tile_offset( -+ TensorCoord const &tile_offset) { -+ if (is_residue_tile_) { -+ -+ thread_offset_ += residue_offset_; -+ -+ Layout layout(params_.stride_); -+ add_pointer_offset(layout(residue_offset_)); -+ -+ compute_predicates_(extent_, true); -+ -+ if (kAdvanceRank) { -+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided() - 1); -+ pointer_ += Shape::kContiguous * tile_offset.contiguous(); -+ } else { -+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous() - 1); -+ pointer_ += Shape::kStrided * tile_offset.strided(); -+ } -+ } else { -+ if (kAdvanceRank) { -+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided()); -+ pointer_ += Shape::kContiguous * tile_offset.contiguous(); -+ } else { -+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous()); -+ pointer_ += Shape::kStrided * tile_offset.strided(); -+ } -+ } -+ is_residue_tile_ = false; -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast( -+ pointer_ + -+ iteration_contiguous_ * (ThreadMap::Delta::kContiguous * sizeof_bits::value) / 8) + iteration_vector_; -+ } -+ -+ /// Returns a k_location -+ CUTLASS_HOST_DEVICE -+ int get_k() const { -+ if(kAdvanceRank){ //strided -+ return ell_offset_.strided() + iteration_strided_ * ThreadMap::Delta::kStrided; -+ }else{ -+ return ell_offset_.contiguous() + iteration_contiguous_ * ThreadMap::Delta::kContiguous + iteration_vector_ * AccessType::kElements; -+ } -+ } -+ -+ CUTLASS_HOST_DEVICE -+ int get_stride() const { -+ if(kAdvanceRank) -+ return params_.stride_; -+ else -+ return 1; -+ } -+ -+ /// Increment and return an instance to self. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator &operator++() { -+ -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ -+ iteration_vector_ = 0; -+ ++iteration_contiguous_; -+ -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ -+ // Enter here only if (iteration_contiguous_ == -+ // ThreadMap::Iteration::kContiguous) -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ pointer_ += params_.inc_strided_; -+ return *this; -+ } -+ -+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) -+ // which means we enter the next tile. -+ iteration_strided_ = 0; -+ -+ // advance to next tile -+ pointer_ += params_.inc_next_; -+ -+ // now return to start tile - if the iterator is subsequently advanced, this -+ // subtraction as well as the subsequent integer addition are both elided by -+ // the compiler. -+ pointer_ -= params_.inc_advance_; -+ -+ return *this; -+ } -+ -+ /// Increment and return an instance to self. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator operator++(int) { -+ EllPredicatedTileAccessIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ predicates_[i] = enable ? 0u : predicates_[i]; -+ } -+ -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ predicates_[i] = 0xffffffff; -+ } -+ -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ predicates_[i] = mask[i]; -+ } -+ -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ mask[i] = predicates_[i]; -+ } -+ } -+ -+ /// add mask for small tiles in ELL -+ CUTLASS_DEVICE -+ void ell_add_mask(int blocksize) { -+ -+ Mask mask; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ mask[i] = 0u; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int access_idx = 0; access_idx < ThreadMap::Iterations::kCount * kAccessesPerVector; ++access_idx) { -+ -+ int s = access_idx / (ThreadMap::Iterations::kContiguous * kAccessesPerVector); -+ -+ int access_residual = access_idx % (ThreadMap::Iterations::kContiguous * kAccessesPerVector); -+ -+ int c = access_residual / kAccessesPerVector; -+ int v = access_residual % kAccessesPerVector; -+ -+ TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous + v * AccessType::kElements, -+ s * ThreadMap::Delta::kStrided); -+ -+ TensorCoord coord = ell_offset_ + iteration_coord; -+ -+ bool guard; -+ -+ if (kAdvanceRank == 0) { -+ guard = (coord.strided() < blocksize); -+ } else { -+ guard = (coord.contiguous() < blocksize); -+ } -+ -+ int pred_idx = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s); -+ -+ int word_idx = pred_idx / kPredicatesPerWord; -+ int residual = pred_idx % kPredicatesPerWord; -+ int byte_idx = residual / kPredicatesPerByte; -+ int bit_idx = residual % kPredicatesPerByte; -+ -+ mask[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx)); -+ -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ mask[i] &= predicates_[i]; -+ } -+ set_mask(mask); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ -+ int pred_idx = -+ iteration_vector_ + kAccessesPerVector * (iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous); -+ -+ int word_idx = pred_idx / kPredicatesPerWord; -+ int residual = pred_idx % kPredicatesPerWord; -+ int byte_idx = residual / kPredicatesPerByte; -+ int bit_idx = residual % kPredicatesPerByte; -+ -+ bool pred = (predicates_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0; -+ return pred; -+ -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of EllPredicatedTileAccessIterator for pitch-linear data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class EllPredicatedTileAccessIterator { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = EllPredicatedTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessType>; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend EllPredicatedTileAccessIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) -+ : params_(layout::PitchLinear(layout.stride(0))){}; -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator( -+ ///< Precomputed parameters object -+ Params const ¶ms, -+ ///< Pointer to start of tensor -+ Pointer pointer, -+ ///< Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : iterator_(params.params_, pointer, -+ layout::PitchLinearCoord(extent.row(), extent.column()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.row(), -+ threadblock_offset.column())) {} -+ -+ /// Construct a EllPredicatedTileAccessIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : EllPredicatedTileAccessIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ int get_k() const { -+ return iterator_.get_k(); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ int get_stride() const { -+ return iterator_.get_stride(); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator operator++(int) { -+ EllPredicatedTileAccessIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { iterator_.get_mask(mask); } -+ -+ /// add mask for small tiles in ELL -+ CUTLASS_DEVICE -+ void ell_add_mask(int blocksize) { -+ iterator_.ell_add_mask(blocksize); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return iterator_.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of EllPredicatedTileAccessIterator for pitch-linear data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class EllPredicatedTileAccessIterator { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = EllPredicatedTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessType>; -+ -+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend EllPredicatedTileAccessIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) -+ : params_(layout::PitchLinear(layout.stride(0))){}; -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator( -+ ///< Precomputed parameters object -+ Params const ¶ms, -+ ///< Pointer to start of tensor -+ Pointer pointer, -+ ///< Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : iterator_(params.params_, pointer, -+ layout::PitchLinearCoord(extent.column(), extent.row()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.column(), -+ threadblock_offset.row())) {} -+ -+ /// Construct a EllPredicatedTileAccessIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : EllPredicatedTileAccessIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ int get_k() const { -+ return iterator_.get_k(); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ int get_stride() const { -+ return iterator_.get_stride(); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator operator++(int) { -+ EllPredicatedTileAccessIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { iterator_.get_mask(mask); } -+ -+ /// add mask for small tiles in ELL -+ CUTLASS_DEVICE -+ void ell_add_mask(int blocksize) { -+ iterator_.ell_add_mask(blocksize); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return iterator_.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of EllPredicatedTileAccessIterator for column-major interleaved data. -+/// It is mapped to the congruous layout. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+ -+template -+class EllPredicatedTileAccessIterator, -+ AdvanceRank, ThreadMap_, AccessType_> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ static int const kInterleavedK = InterleavedK; -+ using Layout = layout::ColumnMajorInterleaved; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = EllPredicatedTileAccessIterator< -+ layout::PitchLinearShape, -+ Element, layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, -+ AccessType>; -+ -+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend EllPredicatedTileAccessIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) -+ : params_(layout::PitchLinear(layout.stride(0))) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : iterator_(params.params_, pointer, -+ layout::PitchLinearCoord(extent.row() * kInterleavedK, -+ extent.column() / kInterleavedK), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.row() * kInterleavedK, -+ threadblock_offset.column() / kInterleavedK)) {} -+ -+ /// Construct a EllPredicatedTileAccessIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : EllPredicatedTileAccessIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ int get_k() const { -+ return iterator_.get_k(); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ int get_stride() const { -+ return iterator_.get_stride(); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator operator++(int) { -+ EllPredicatedTileAccessIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { iterator_.get_mask(mask); } -+ -+ /// add mask for small tiles in ELL -+ CUTLASS_DEVICE -+ void ell_add_mask(int blocksize) { -+ iterator_.ell_add_mask(blocksize); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { return iterator_.valid(); } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of EllPredicatedTileAccessIterator for row-major interleaved data. -+/// It is mapped to the congruous layout. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class EllPredicatedTileAccessIterator, -+ AdvanceRank, ThreadMap_, AccessType_> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ static int const kInterleavedK = InterleavedK; -+ using Layout = layout::RowMajorInterleaved; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = EllPredicatedTileAccessIterator< -+ layout::PitchLinearShape, -+ Element, layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, -+ AccessType>; -+ -+ -+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend EllPredicatedTileAccessIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) -+ : params_(layout::PitchLinear(layout.stride(0))) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : iterator_(params.params_, pointer, -+ layout::PitchLinearCoord(extent.column() * kInterleavedK, -+ extent.row() / kInterleavedK), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.column() * kInterleavedK, -+ threadblock_offset.row() / kInterleavedK)) {} -+ -+ /// Construct a EllPredicatedTileAccessIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : EllPredicatedTileAccessIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ int get_k() const { -+ return iterator_.get_k(); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ int get_stride() const { -+ return iterator_.get_stride(); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileAccessIterator operator++(int) { -+ EllPredicatedTileAccessIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { iterator_.get_mask(mask); } -+ -+ /// add mask for small tiles in ELL -+ CUTLASS_DEVICE -+ void ell_add_mask(int blocksize) { -+ iterator_.ell_add_mask(blocksize); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { return iterator_.valid(); } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_iterator.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_iterator.h -new file mode 100644 -index 0000000..f984733 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_iterator.h -@@ -0,0 +1,1315 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Ell iterator for Blocked-Ell matrix (ellValue matrix) used with EllMmaPipelined -+*/ -+ -+#pragma once -+ -+#include "cutlass/arch/memory.h" -+#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" -+ -+#include "cutlass/transform/threadblock/ell_predicated_tile_access_iterator.h" -+#include "cutlass/transform/threadblock/ell_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// EllPredicatedTileIterator -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+/// Regular tile iterator using a precomputed control structure to minimize register liveness -+/// and integer arithmetic. -+/// -+/// Layout is assumed to be invariant at the time the precomputed "Params" object is constructed. -+/// -+/// Base pointer and tensor extents may be specified at the time the iterator is constructed. -+/// Subsequently, they are assumed to be immutable. -+/// -+/// Adding a logical coordinate offset may be performed at the time the iterator is constructed. -+/// Subsequent additions to logical coordinate offset may be performed but are relatively expensive. -+/// -+/// Visitation order is intended to first visit a "residual" tile that may be partially full in -+/// both the advance dimension and the steady-state dimension. This is assumed to be the last -+/// tile in the iteration sequence. Advancing an iterator that has just been constructed moves to -+/// the first tile that is full in the advance dimension and recomputes predicates. Subsequent -+/// accesses may be performed without updating internal predicates and are efficient in terms of -+/// live register state and pointer arithmetic instructions. -+/// -+/// To be efficient, this assumes the iterator will be dereferenced and advanced at least once -+/// outside any looping structure to minimize integer arithmetic. -+/// -+/// Acceses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing -+/// the iterator. -+/// -+/// -+/// Example: -+/// -+/// An efficient pipeline structure may be constructed as follows: -+/// -+// template -+// __global__ void kernel( -+// typename Iterator::Params params, -+// typename Iterator::Element *ptr, -+// TensorCoord extent) { -+// -+// typename Iterator::Fragment fragment; -+// -+// TensorCoord threadblock_offset(0, 0); -+// -+// Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offsets); -+// -+// -+// fragment = *iter; // load "residue" tile first -+// ++iter; // advance to first "steady state" tile and update internal masks -+// -+// -+// #pragma unroll -+// for (int i = Remaining - 1; i >= 0; --i) { -+// -+// f(fragment); -+// -+// if (!i) { -+// iter.clear_mask(); // light-weight operation to clear masks - subsequent loads become NO-OPs. -+// } -+// -+// fragment = *iter; // load tile during "steady state" phase -+// ++iter; // advance to next tile - lightweight due to steady-state masks -+// } -+// } -+// -+// void host(TensorView view) { -+// -+// using Iterator = transform::threadblock::EllPredicatedTileIterator; -+// -+// typename Iterator::Params params(view.layout()); -+// -+// kernel(params, view.data()); -+// } -+/// -+/// -+template < -+ typename Shape, -+ typename Element, -+ typename Layout, -+ int AdvanceRank, -+ typename ThreadMap, -+ int AccessSize = ThreadMap::kElementsPerAccess -+> -+class EllPredicatedTileIterator; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of EllPredicatedTileIterator for pitch-linear data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class EllPredicatedTileIterator { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ /// Type used for internal memory accesses -+ using AccessType = AlignedArray::value / 8)>; -+ -+ /// Underlying iterator to compute the addresses -+ using TileAccessIterator = -+ EllPredicatedTileAccessIterator; -+ -+ static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename TileAccessIterator::Mask; -+ -+ /// Iterator for ELL storage -+ using EllIterator = typename cutlass::transform::threadblock::ell::Iterator; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ public: -+ friend EllPredicatedTileIterator; -+ -+ private: -+ /// Parameters object -+ typename TileAccessIterator::Params params_; -+ -+ public: -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) : params_(layout) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ }; -+ -+ private: -+ /// Internal pointer type permits fast address arithmetic -+ using BytePointer = char *; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Data member to the tile access iterator -+ TileAccessIterator address_iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : address_iterator_(params.params_, pointer, extent, thread_id, -+ threadblock_offset) {} -+ -+ /// Construct a EllPredicatedTileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : EllPredicatedTileIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ address_iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator &operator++() { -+ if (kAdvanceRank) -+ address_iterator_.add_tile_offset({0, 1}); -+ else -+ address_iterator_.add_tile_offset({1, 0}); -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator operator++(int) { -+ EllPredicatedTileIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Returns a stride -+ CUTLASS_HOST_DEVICE -+ int get_stride() const { return address_iterator_.get_stride(); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { address_iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { address_iterator_.get_mask(mask); } -+ -+ /// add mask for small tiles in ELL -+ CUTLASS_HOST_DEVICE -+ void ell_add_mask(int blocksize) { address_iterator_.ell_add_mask(blocksize); } -+ -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ load_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ -+ int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); -+ -+ address_iterator_.set_iteration_index(idx); -+ char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; -+ -+ AccessType const *access_ptr = reinterpret_cast(byte_ptr); -+ -+ cutlass::arch::global_load( -+ frag_ptr[idx], access_ptr, address_iterator_.valid()); -+ -+ ++address_iterator_; -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { load_with_byte_offset(frag, 0); } -+ -+ CUTLASS_DEVICE -+ void load_with_ell_index(Fragment &frag, EllIterator &ell_iter) { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ -+ int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); -+ address_iterator_.set_iteration_index(idx); -+ LongIndex ell_offset = 0; -+ -+ int k_offset = address_iterator_.get_k(); -+ ell_offset = ell_iter.get_offset(k_offset) * sizeof(Element); -+ -+ char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + ell_offset; -+ -+ AccessType const *access_ptr = reinterpret_cast(byte_ptr); -+ -+ bool is_valid = address_iterator_.valid(); -+ is_valid = is_valid && (ell_offset >= 0); -+ -+ cutlass::arch::global_load( -+ frag_ptr[idx], access_ptr, is_valid); -+ -+ ++address_iterator_; -+ } -+ } -+ } -+ } -+ -+ CUTLASS_DEVICE -+ void load_with_ell_index_fast(Fragment &frag, EllIterator &ell_iter) { -+ -+ LongIndex ell_offset = ell_iter.get_offset_fast() * sizeof(Element); -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ -+ int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); -+ -+ address_iterator_.set_iteration_index(idx); -+ char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + ell_offset; -+ -+ AccessType const *access_ptr = reinterpret_cast(byte_ptr); -+ -+ bool is_valid = address_iterator_.valid(); -+ is_valid = is_valid && (ell_offset >= 0); -+ -+ cutlass::arch::global_load( -+ frag_ptr[idx], access_ptr, is_valid); -+ -+ ++address_iterator_; -+ } -+ } -+ } -+ } -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ store_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { -+ address_iterator_.set_iteration_index(0); -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ -+ int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); -+ -+ char *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; -+ AccessType *access_ptr = reinterpret_cast(byte_ptr); -+ -+ if (address_iterator_.valid()) { -+ *access_ptr = frag_ptr[idx]; -+ } -+ ++address_iterator_; -+ } -+ } -+ } -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { store_with_byte_offset(frag, 0); } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of EllPredicatedTileIterator for pitch-linear data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int AccessSize -+> -+class EllPredicatedTileIterator { -+public: -+ -+ static_assert(AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = EllPredicatedTileIterator< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 0 : 1), -+ ThreadMap, -+ AccessSize -+ >; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Iterator for ELL storage -+ using EllIterator = typename cutlass::transform::threadblock::ell::Iterator; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ -+ friend EllPredicatedTileIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) { -+ -+ } -+ }; -+ -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+public: -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id, ///< ID of each participating thread -+ TensorCoord const &threadblock_offset ///< Initial offset of threadblock -+ ): -+ iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord(extent.row(), extent.column()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column()) -+ ) { } -+ -+ /// Construct a EllPredicatedTileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ): EllPredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator operator++(int) { -+ EllPredicatedTileIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Returns a stride -+ CUTLASS_HOST_DEVICE -+ int get_stride() const { return iterator_.get_stride(); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// add mask for small tiles in ELL -+ CUTLASS_HOST_DEVICE -+ void ell_add_mask(int blocksize) { -+ iterator_.ell_add_mask(blocksize); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ CUTLASS_DEVICE -+ void load_with_ell_index(Fragment &frag, EllIterator& ell_iter) { -+ iterator_.load_with_ell_index(frag, ell_iter); -+ } -+ -+ CUTLASS_DEVICE -+ void load_with_ell_index_fast(Fragment &frag, EllIterator& ell_iter) { -+ iterator_.load_with_ell_index_fast(frag, ell_iter); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { -+ iterator_.store_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of EllPredicatedTileIterator for pitch-linear data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int AccessSize -+> -+class EllPredicatedTileIterator { -+public: -+ -+ static_assert(AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = EllPredicatedTileIterator< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 1 : 0), -+ ThreadMap, -+ AccessSize -+ >; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Iterator for ELL storage -+ using EllIterator = typename cutlass::transform::threadblock::ell::Iterator; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ -+ friend EllPredicatedTileIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) { -+ -+ }; -+ }; -+ -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+public: -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id, ///< ID of each participating thread -+ TensorCoord const &threadblock_offset ///< Initial offset of threadblock -+ ): -+ iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord(extent.column(), extent.row()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row()) -+ ) { } -+ -+ /// Construct a EllPredicatedTileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ): EllPredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator operator++(int) { -+ EllPredicatedTileIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Returns a stride -+ CUTLASS_HOST_DEVICE -+ int get_stride() const { return iterator_.get_stride(); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// add mask for small tiles in ELL -+ CUTLASS_HOST_DEVICE -+ void ell_add_mask(int blocksize) { -+ iterator_.ell_add_mask(blocksize); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ CUTLASS_DEVICE -+ void load_with_ell_index(Fragment &frag, EllIterator& ell_iter) { -+ iterator_.load_with_ell_index(frag, ell_iter); -+ } -+ -+ CUTLASS_DEVICE -+ void load_with_ell_index_fast(Fragment &frag, EllIterator& ell_iter) { -+ iterator_.load_with_ell_index_fast(frag, ell_iter); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { -+ iterator_.store_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of EllPredicatedTileIterator for interleaved data. It is mapped -+/// to the congruous layout. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+ -+template -+class EllPredicatedTileIterator, -+ AdvanceRank, ThreadMap_, AccessSize> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ static int const kInterleavedK = InterleavedK; -+ using Layout = layout::ColumnMajorInterleaved; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = EllPredicatedTileIterator< -+ layout::PitchLinearShape, -+ Element, layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessSize>; -+ -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Iterator for ELL storage -+ using EllIterator = typename cutlass::transform::threadblock::ell::Iterator; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend EllPredicatedTileIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) -+ : params_(layout::PitchLinear(layout.stride(0))) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : iterator_(params.params_, pointer, -+ layout::PitchLinearCoord(extent.row() * kInterleavedK, -+ extent.column() / kInterleavedK), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.row() * kInterleavedK, -+ threadblock_offset.column() / kInterleavedK)) {} -+ -+ /// Construct a EllPredicatedTileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : EllPredicatedTileIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator operator++(int) { -+ EllPredicatedTileIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Returns a stride -+ CUTLASS_HOST_DEVICE -+ int get_stride() const { return iterator_.get_stride(); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { iterator_.get_mask(mask); } -+ -+ /// add mask for small tiles in ELL -+ CUTLASS_HOST_DEVICE -+ void ell_add_mask(int blocksize) { iterator_.ell_add_mask(blocksize); } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ CUTLASS_DEVICE -+ void load_with_ell_index(Fragment &frag, EllIterator& ell_iter) { -+ iterator_.load_with_ell_index(frag, ell_iter); -+ } -+ -+ CUTLASS_DEVICE -+ void load_with_ell_index_fast(Fragment &frag, EllIterator& ell_iter) { -+ iterator_.load_with_ell_index_fast(frag, ell_iter); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of EllPredicatedTileIterator for interleaved-32 data. It is -+/// mapped to the congruous layout. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class EllPredicatedTileIterator, -+ AdvanceRank, ThreadMap_, AccessSize> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ static int const kInterleavedK = InterleavedK; -+ using Layout = layout::RowMajorInterleaved; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = EllPredicatedTileIterator< -+ layout::PitchLinearShape, -+ Element, layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessSize>; -+ -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend EllPredicatedTileIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ CUTLASS_HOST_DEVICE -+ Params() {} -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) -+ : params_(layout::PitchLinear(layout.stride(0))) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : iterator_(params.params_, pointer, -+ layout::PitchLinearCoord(extent.column() * kInterleavedK, -+ extent.row() / kInterleavedK), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.column() * kInterleavedK, -+ threadblock_offset.row() / kInterleavedK)) {} -+ -+ /// Construct a EllPredicatedTileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : EllPredicatedTileIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ EllPredicatedTileIterator operator++(int) { -+ EllPredicatedTileIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Returns a stride -+ CUTLASS_HOST_DEVICE -+ int get_stride() const { return iterator_.get_stride(); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { iterator_.get_mask(mask); } -+ -+ /// add mask for small tiles in ELL -+ CUTLASS_HOST_DEVICE -+ void ell_add_mask(int blocksize) { iterator_.ell_add_mask(blocksize); } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_access_iterator.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_access_iterator.h -new file mode 100644 -index 0000000..61bed18 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_access_iterator.h -@@ -0,0 +1,375 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Templates calculating the address and predicates to the load of scale and bias vectors. -+ -+ This iterator uses masks to guard out-of-bounds accesses. -+ -+ It can be used to load the gamma and beta vectors of layernorm which is loop variant. -+ -+ A precomputed "Params" object minimizes the amount of state that must be -+ stored in registers, and integer addition is used to advance the pointer -+ through memory. -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/conv/threadblock/conv2d_params.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// PredicatedScaleBiasVectorAccessIterator -+/// -+template -+class PredicatedScaleBiasVectorAccessIterator; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator for fprop pitch-linear data. -+/// -+template -+class PredicatedScaleBiasVectorAccessIterator { -+ public: -+ -+ using ThreadblockShape = ThreadblockShape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ConstPointer = const Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ static int const kElementsPerAccess = 128 / sizeof_bits::value; -+ static int const kThreads = ThreadblockShape::kContiguous / kElementsPerAccess; -+ -+ using AccessType = AlignedArray; -+ -+ private: -+ /// Internal pointer type permits fast address arithmetic -+ using BytePointer = char *; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Internal pointer to first access of tile -+ BytePointer pointer_; -+ -+ TensorCoord thread_offset_; -+ -+ int problem_size_k_; -+ -+ /// Used for out-of-order visitation -+ bool is_residue_tile_; -+ -+ bool guard_; -+ -+ TensorCoord::Index residue_size_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIterator( -+ /// Extent of tensor -+ int problem_size_k, -+ /// Pointer to the start of the scale vector -+ ConstPointer scale_pointer, -+ /// Pointer to the start of the bias vector -+ ConstPointer bias_pointer, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset) { -+ pointer_ = (thread_id < kThreads) -+ ? reinterpret_cast( -+ const_cast(scale_pointer)) -+ : reinterpret_cast( -+ const_cast(bias_pointer)); -+ -+ // Per-thread offset in logical coordinates of tensor -+ int thread_base = (thread_id < kThreads) ? 0 : kThreads; -+ -+ problem_size_k_ = problem_size_k; -+ -+ is_residue_tile_ = true; -+ -+ residue_size_ = (problem_size_k_ - threadblock_offset.contiguous()) % ThreadblockShape::kContiguous; -+ -+ if (residue_size_ == 0) { -+ residue_size_ = ThreadblockShape::kContiguous; -+ } -+ -+ guard_ = ((thread_id - thread_base) * kElementsPerAccess) < residue_size_; -+ -+ thread_offset_ = -+ threadblock_offset + -+ TensorCoord((thread_id - thread_base) * kElementsPerAccess, 0); -+ -+ set_iteration_index(0); -+ } -+ -+ /// Construct a PredicatedTileAccessIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIterator( -+ /// Extent of tensor -+ int problem_size_k, -+ /// Pointer to start of scale vector -+ ConstPointer scale_pointer, -+ /// Pointer to start of scale vector -+ ConstPointer bias_pointer, -+ ///< ID of each participating thread -+ int thread_id) -+ : PredicatedScaleBiasVectorAccessIterator(problem_size_k, -+ scale_pointer, bias_pointer, -+ thread_id, make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) {} -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole threadblock tiles -+ CUTLASS_DEVICE -+ void add_tile_offset( -+ TensorCoord const &tile_offset) { -+ -+ guard_ = threadIdx.x < kThreads * 2; -+ -+ TensorCoord offset = is_residue_tile_ ? -+ TensorCoord(residue_size_ + ThreadblockShape::kContiguous * (tile_offset.contiguous() - 1), 0) -+ : TensorCoord(ThreadblockShape::kContiguous * tile_offset.contiguous(), 0); -+ -+ thread_offset_ = -+ thread_offset_ + -+ offset; -+ -+ is_residue_tile_ = false; -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ -+ return reinterpret_cast( -+ pointer_ + -+ (thread_offset_.contiguous() * sizeof_bits::value / 8)); -+ } -+ -+ /// Increment and return an instance to self. -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIterator &operator++() { -+ return *this; -+ } -+ -+ /// Increment and return an instance to self. -+ CUTLASS_DEVICE -+ PredicatedScaleBiasVectorAccessIterator operator++(int) { -+ PredicatedScaleBiasVectorAccessIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ guard_ &= (!enable); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return guard_; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator for row-major data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class PredicatedScaleBiasVectorAccessIterator { -+ public: -+ -+ using ThreadblockShape = ThreadblockShape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ConstPointer = const Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedScaleBiasVectorAccessIterator< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ static int const kElementsPerAccess = UnderlyingIterator::kElementsPerAccess; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIterator( -+ ///< Extent of tensor -+ int problem_size_k, -+ ///< Pointer to the start of the scale vector -+ ConstPointer scale_pointer, -+ ///< Pointer to the start of the bias vector -+ ConstPointer bias_pointer, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : iterator_(problem_size_k, scale_pointer, bias_pointer, -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.column(), -+ threadblock_offset.row())) {} -+ -+ /// Construct a PredicatedTileAccessIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIterator( -+ int problem_size_k, ///< Extent of tensor -+ ConstPointer scale_pointer, ///< Pointer to the start of the scale vector -+ ConstPointer bias_pointer, ///< Pointer to the start of the bias vector -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedScaleBiasVectorAccessIterator(problem_size_k, -+ scale_pointer, bias_pointer, -+ thread_id, make_Coord(0, 0)) {} -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// threadblock tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorAccessIterator operator++(int) { -+ PredicatedScaleBiasVectorAccessIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return iterator_.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h -new file mode 100644 -index 0000000..fb08930 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h -@@ -0,0 +1,328 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Templates calculating the address and predicates to the load of scale and bias vectors. -+ -+ This iterator uses masks to guard out-of-bounds accesses. -+ -+ This can be used to load var and mean vectors in layernorm which is loop invariant. -+ -+ A precomputed "Params" object minimizes the amount of state that must be -+ stored in registers, and integer addition is used to advance the pointer -+ through memory. -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// PredicatedScaleBiasVectorIterator -+/// -+template -+class PredicatedScaleBiasVectorIterator; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIterator for wgrad pitch-linear data. -+/// -+template -+class PredicatedScaleBiasVectorIterator { -+ public: -+ -+ using WarpShape = WarpShape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ConstPointer = const Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ static int const kElementsPerAccess = 1; -+ -+ using AccessType = AlignedArray; -+ -+ static int const kIterations = WarpShape::kContiguous / 8; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array<__half2, 2 * kIterations * kElementsPerAccess>; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Internal pointer to first access of tile -+ ConstPointer scale_pointer_; -+ ConstPointer bias_pointer_; -+ -+ /// Size of tensor -+ int problem_size_; -+ -+ int32_t thread_offset_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorIterator( -+ /// Extent of tensor -+ int problem_size, -+ /// Pointer to the start of the scale vector -+ ConstPointer scale_pointer, -+ /// Pointer to the start of the bias vector -+ ConstPointer bias_pointer, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : problem_size_(problem_size), -+ scale_pointer_(scale_pointer), -+ bias_pointer_(bias_pointer) { -+ -+ thread_offset_ = threadblock_offset.contiguous() + (thread_id % 32) / 4; -+ } -+ -+ /// Construct a PredicatedTileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorIterator( -+ /// Extent of tensor -+ int problem_size, -+ /// Pointer to start of scale vector -+ ConstPointer scale_pointer, -+ /// Pointer to start of scale vector -+ ConstPointer bias_pointer, -+ ///< ID of each participating thread -+ int thread_id) -+ : PredicatedScaleBiasVectorIterator(problem_size, -+ scale_pointer, bias_pointer, -+ thread_id, make_Coord(0, 0)) {} -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole warp tiles -+ CUTLASS_DEVICE -+ void add_tile_offset( -+ TensorCoord const &tile_offset) { -+ -+ thread_offset_ += (WarpShape::kContiguous * tile_offset.contiguous()); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ -+ frag.fill(__float2half2_rn(0.0f)); -+ __half2 *frag_ptr = reinterpret_cast<__half2 *>(&frag); -+ -+ // load scale -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < kIterations; ++c) { -+ -+ cutlass::arch::global_load< -+ __half, -+ sizeof(AccessType) -+ >( -+ frag_ptr[c * 2].x, -+ scale_pointer_ + thread_offset_ + c * 8, -+ (thread_offset_ + c * 8) < problem_size_ -+ ); -+ } -+ -+ // load bias -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < kIterations; ++c) { -+ -+ cutlass::arch::global_load< -+ __half, -+ sizeof(AccessType) -+ >( -+ frag_ptr[c * 2 + 1].x, -+ bias_pointer_ + thread_offset_ + c * 8, -+ (thread_offset_ + c * 8) < problem_size_ -+ ); -+ } -+ -+ // duplicate scale -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < kIterations; ++c) { -+ frag_ptr[c * 2].y = frag_ptr[c * 2].x; -+ } -+ -+ // duplicate bias -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < kIterations; ++c) { -+ frag_ptr[c * 2 + 1].y = frag_ptr[c * 2 + 1].x; -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIterator for row-major data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class PredicatedScaleBiasVectorIterator { -+ public: -+ -+ using WarpShape = WarpShape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ConstPointer = const Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedScaleBiasVectorIterator< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ static int const kElementsPerAccess = UnderlyingIterator::kElementsPerAccess; -+ using Fragment = typename UnderlyingIterator::Fragment; -+ -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorIterator( -+ ///< Extent of tensor -+ int problem_size, -+ ///< Pointer to the start of the scale vector -+ ConstPointer scale_pointer, -+ ///< Pointer to the start of the bias vector -+ ConstPointer bias_pointer, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : iterator_(problem_size, scale_pointer, bias_pointer, -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.column(), -+ threadblock_offset.row())) {} -+ -+ /// Construct a PredicatedTileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedScaleBiasVectorIterator( -+ int problem_size, ///< Extent of tensor -+ ConstPointer scale_pointer, ///< Pointer to the start of the scale vector -+ ConstPointer bias_pointer, ///< Pointer to the start of the bias vector -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedScaleBiasVectorIterator(problem_size, -+ scale_pointer, bias_pointer, -+ thread_id, make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// threadblock tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ iterator_.load(frag); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h -new file mode 100644 -index 0000000..29fa8af ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h -@@ -0,0 +1,2085 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates calculating the address and predicates to the load of tiles -+ from pitch-linear rank=2 tensors. -+ -+ This iterator uses masks to guard out-of-bounds accesses. The first tile this -+ iterator visits maybe partial, then the remaining tiles are complete. So, we -+ only need to compute the predicates twice, once before the first tile and -+ once for the remaining full tiles which can share the same predicates. -+ -+ A precomputed "Params" object minimizes the amount of state that must be -+ stored in registers, and integer addition is used to advance the pointer -+ through memory. -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// PredicatedTileAccessIteratorPredicates -+/// -+template -+class PredicatedTileAccessIteratorPredicates { -+ public: -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = Layout_; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ static int const kPredicatesPerByte = 4; -+ static int const kPredicatesPerWord = 4 * kPredicatesPerByte; -+ -+ static int const kPredicateCount = ThreadMap::Iterations::kCount * kAccessesPerVector; -+ -+ /// Number of 32b words containing predicates -+ static int const kPredicateByteCount = -+ (kPredicateCount + kPredicatesPerByte - 1) / kPredicatesPerByte; -+ static int const kPredicateWordCount = (kPredicateByteCount + 3) / 4; -+ -+ static unsigned const kPredicateMask = (1u << kPredicatesPerByte) - 1u; -+ -+ static_assert(kPredicateWordCount <= 4, "Too many predicates."); -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = Array; -+ -+// private: -+ /// Guard predicates -+ uint32_t predicates_[kPredicateWordCount]; -+ -+ /// Size of tensor -+ TensorCoord extent_; -+ -+ /// Initial offset for each thread -+ TensorCoord thread_offset_; -+ -+ /// Offset to the first steady-state tile -+ TensorCoord residue_offset_; -+ -+ /// Iteration along vectors implied by the thread map -+ int iteration_vector_; -+ -+ /// Iteration in the contiguous dimension -+ int iteration_contiguous_; -+ -+ /// Iteration in the strided dimension -+ int iteration_strided_; -+ -+ public: -+ /// Computes predicates based on internally tracked per-thread offset. -+ CUTLASS_DEVICE -+ void compute_predicates_( -+ /// Extent of the matrix window -+ TensorCoord extent, -+ /// optionally, simplify predicate calculation during 'steady state' phase -+ bool is_steady_state = false) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ predicates_[i] = 0u; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int access_idx = 0; access_idx < ThreadMap::Iterations::kCount * kAccessesPerVector; ++access_idx) { -+ -+ int s = access_idx / (ThreadMap::Iterations::kContiguous * kAccessesPerVector); -+ -+ int access_residual = access_idx % (ThreadMap::Iterations::kContiguous * kAccessesPerVector); -+ -+ int c = access_residual / kAccessesPerVector; -+ int v = access_residual % kAccessesPerVector; -+ -+ TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous + v * AccessType::kElements, -+ s * ThreadMap::Delta::kStrided); -+ -+ TensorCoord coord = thread_offset_ + iteration_coord; -+ -+ bool guard; -+ -+ if (is_steady_state) { -+ if (kAdvanceRank == 0) { -+ guard = (coord.strided() < extent.strided()); -+ } else { -+ guard = (coord.contiguous() < extent.contiguous()); -+ } -+ } else { -+ guard = (coord.strided() < extent.strided() && -+ coord.contiguous() < extent.contiguous()); -+ } -+ -+ int pred_idx = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s); -+ -+ int word_idx = pred_idx / kPredicatesPerWord; -+ int residual = pred_idx % kPredicatesPerWord; -+ int byte_idx = residual / kPredicatesPerByte; -+ int bit_idx = residual % kPredicatesPerByte; -+ -+ predicates_[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx)); -+ -+ } -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_predicates(int thread_id, TensorCoord const &threadblock_offset) { -+ -+ TensorCoord residue_extent; -+ if (kAdvanceRank) { -+ -+ typename TensorCoord::Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.strided()) % Shape::kStrided; -+ if (!residue_size) { -+ residue_size = Shape::kStrided; -+ } -+ -+ residue_offset_ = make_Coord(0, residue_size); -+ residue_extent = make_Coord( -+ extent_.contiguous(), -+ min(threadblock_offset.strided() + residue_size, extent_.strided()) -+ ); -+ } else { -+ -+ typename TensorCoord::Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.contiguous()) % Shape::kContiguous; -+ if (!residue_size) { -+ residue_size = Shape::kContiguous; -+ } -+ -+ residue_offset_ = make_Coord(residue_size, 0); -+ -+ residue_extent = make_Coord( -+ min(extent_.contiguous(), threadblock_offset.contiguous() + residue_size), -+ extent_.strided() -+ ); -+ } -+ -+ // Per-thread offset in logical coordinates of tensor -+ thread_offset_ = threadblock_offset + ThreadMap::initial_offset(thread_id); -+ -+ compute_predicates_(residue_extent, false); -+ -+ set_iteration_index(0); -+ } -+ -+ /// Default constructor -+ PredicatedTileAccessIteratorPredicates() = default; -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorPredicates( -+ /// Extent of tensor -+ TensorCoord extent) -+ : extent_(extent) { -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ -+ } -+ -+ /// Increment and return an instance to self. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorPredicates &operator++() { -+ -+ return *this; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ predicates_[i] = enable ? 0u : predicates_[i]; -+ } -+ -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ predicates_[i] = 0xffffffff; -+ } -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ predicates_[i] = mask[i]; -+ } -+ -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ mask[i] = predicates_[i]; -+ } -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ -+ -+ int pred_idx = -+ iteration_vector_ + kAccessesPerVector * (iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous); -+ -+ int word_idx = pred_idx / kPredicatesPerWord; -+ int residual = pred_idx % kPredicatesPerWord; -+ int byte_idx = residual / kPredicatesPerByte; -+ int bit_idx = residual % kPredicatesPerByte; -+ -+ bool pred = (predicates_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0; -+ return pred; -+ -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// PredicatedTileAccessIterator -+/// -+template -+class PredicatedTileAccessIterator; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator for pitch-linear data. -+/// -+template -+class PredicatedTileAccessIterator { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates< -+ Shape, Element, Layout, AdvanceRank, ThreadMap, AccessType>; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ using Mask = typename UnderlyingPredicates::Mask; -+ -+ /// Uses a non-template class -+ struct Params : PredicatedTileAccessIteratorParams { -+ -+ using Base = PredicatedTileAccessIteratorParams; -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) : -+ Base(layout.stride(0), -+ MakePredicatedTileAccessIteratorDesc()() -+ ) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Base const &base) : -+ Base(base) { } -+ }; -+ -+ private: -+ /// Internal pointer type permits fast address arithmetic -+ using BytePointer = char *; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ UnderlyingPredicates the_predicates; -+ -+ /// Parameters object with precomputed internal state -+ Params params_; -+ -+ /// Internal pointer to first access of tile -+ BytePointer pointer_; -+ -+ /// Used for out-of-order visitation -+ bool is_residue_tile_; -+ -+ /// Below is used when Gather is turned on. We need to record strided_offset -+ /// and contiguous_offset seperated to compute the offset by using -+ /// -+ /// offset = contiguous_offset + indices[strided_offset] -+ /// -+ -+ /// Gather indices -+ int const *indices_; -+ -+ Index gather_offset_strided; -+ -+ private: -+ /// Computes predicates based on internally tracked per-thread offset. -+ CUTLASS_DEVICE -+ void compute_predicates_( -+ /// Extent of the matrix window -+ TensorCoord extent, -+ /// optionally, simplify predicate calculation during 'steady state' phase -+ bool is_steady_state = false) { -+ the_predicates.compute_predicates_(extent, is_steady_state); -+ } -+ -+ public: -+ -+ /// Default constructor -+ PredicatedTileAccessIterator() = default; -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset, -+ /// Gather indices -+ int const *indices = nullptr) -+ : params_(params), -+ pointer_(reinterpret_cast( -+ const_cast(pointer))), -+ the_predicates(extent), -+ is_residue_tile_(true), -+ indices_(indices) { -+ -+ the_predicates.set_predicates(thread_id, threadblock_offset); -+ -+ // update internal pointers -+ Layout layout(params_.stride_); -+ -+ if (!Gather) { -+ add_pointer_offset(layout(the_predicates.thread_offset_)); -+ } else { -+ gather_offset_strided = the_predicates.thread_offset_.strided(); -+ add_pointer_offset(layout(make_Coord(the_predicates.thread_offset_.contiguous(), 0))); -+ } -+ } -+ -+ /// Construct a PredicatedTileAccessIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id) -+ : PredicatedTileAccessIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ the_predicates.set_iteration_index(index); -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += sizeof_bits::value * pointer_offset / 8; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_DEVICE -+ void add_tile_offset( -+ TensorCoord const &tile_offset) { -+ if (is_residue_tile_) { -+ -+ the_predicates.thread_offset_ += the_predicates.residue_offset_; -+ -+ the_predicates.compute_predicates_(the_predicates.extent_, true); -+ -+ Layout layout(params_.stride_); -+ -+ if (!Gather) { -+ add_pointer_offset(layout(the_predicates.residue_offset_)); -+ -+ if (kAdvanceRank) { -+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided() - 1); -+ pointer_ += Shape::kContiguous * tile_offset.contiguous(); -+ } else { -+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous() - 1); -+ pointer_ += Shape::kStrided * tile_offset.strided(); -+ } -+ } else { -+ gather_offset_strided = the_predicates.thread_offset_.strided(); -+ add_pointer_offset(layout(make_Coord(the_predicates.residue_offset_.contiguous(), 0))); -+ -+ if (kAdvanceRank) { -+ gather_offset_strided += Shape::kStrided * (tile_offset.strided() - 1); -+ add_pointer_offset(Shape::kContiguous * tile_offset.contiguous()); -+ } else { -+ add_pointer_offset(Shape::kContiguous * (tile_offset.contiguous() - 1)); -+ gather_offset_strided += Shape::kStrided * tile_offset.strided(); -+ } -+ } -+ } else { -+ if (!Gather) { -+ if (kAdvanceRank) { -+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided()); -+ pointer_ += Shape::kContiguous * tile_offset.contiguous(); -+ } else { -+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous()); -+ pointer_ += Shape::kStrided * tile_offset.strided(); -+ } -+ } else { -+ add_pointer_offset(Shape::kContiguous * tile_offset.contiguous()); -+ gather_offset_strided += Shape::kStrided * tile_offset.strided(); -+ } -+ } -+ -+ is_residue_tile_ = false; -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ if (Gather) { -+ assert(indices_); -+ -+ if (!valid()) { -+ return nullptr; -+ } -+ -+ LongIndex contiguous_offset = the_predicates.iteration_contiguous_ * (ThreadMap::Delta::kContiguous * sizeof_bits::value / 8) + the_predicates.iteration_vector_; -+ int strided_index = gather_offset_strided + the_predicates.iteration_strided_ * ThreadMap::Delta::kStrided; -+ -+ LongIndex strided_offset = indices_[strided_index] * LongIndex(params_.stride_) * sizeof_bits::value / 8; -+ -+ return reinterpret_cast(pointer_ + contiguous_offset + strided_offset); -+ } -+ -+ return reinterpret_cast( -+ pointer_ + -+ the_predicates.iteration_contiguous_ * (ThreadMap::Delta::kContiguous * sizeof_bits::value) / 8) + the_predicates.iteration_vector_; -+ } -+ -+ /// Increment and return an instance to self. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator &operator++() { -+ -+ the_predicates.operator++(); -+ -+ ++the_predicates.iteration_vector_; -+ if (the_predicates.iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ -+ the_predicates.iteration_vector_ = 0; -+ ++the_predicates.iteration_contiguous_; -+ -+ if (the_predicates.iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ -+ // Enter here only if (iteration_contiguous_ == -+ // ThreadMap::Iteration::kContiguous) -+ the_predicates.iteration_contiguous_ = 0; -+ ++the_predicates.iteration_strided_; -+ -+ if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ if (!Gather) { -+ pointer_ += params_.inc_strided_; -+ } -+ -+ return *this; -+ } -+ -+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) -+ // which means we enter the next tile. -+ the_predicates.iteration_strided_ = 0; -+ -+ if (!Gather) { -+ // advance to next tile -+ pointer_ += params_.inc_next_; -+ -+ // now return to start tile - if the iterator is subsequently advanced, this -+ // subtraction as well as the subsequent integer addition are both elided by -+ // the compiler. -+ pointer_ -= params_.inc_advance_; -+ } -+ -+ return *this; -+ } -+ -+ /// Increment and return an instance to self. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator operator++(int) { -+ PredicatedTileAccessIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ the_predicates.clear_mask(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ the_predicates.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { -+ the_predicates.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { -+ the_predicates.get_mask(mask); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() const { -+ return the_predicates.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator for column-major data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class PredicatedTileAccessIterator { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessType, Gather>; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileAccessIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) -+ : params_(layout::PitchLinear(layout.stride(0))){}; -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const &base) -+ : params_(base) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ -+ /// Default constructor -+ PredicatedTileAccessIterator() = default; -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator( -+ ///< Precomputed parameters object -+ Params const ¶ms, -+ ///< Pointer to start of tensor -+ Pointer pointer, -+ ///< Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const &threadblock_offset, -+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization -+ ) -+ : iterator_(params.params_, pointer, -+ layout::PitchLinearCoord(extent.row(), extent.column()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.row(), -+ threadblock_offset.column()), -+ indices) {} -+ -+ /// Construct a PredicatedTileAccessIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileAccessIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator operator++(int) { -+ PredicatedTileAccessIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { iterator_.get_mask(mask); } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return iterator_.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator for row-major data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class PredicatedTileAccessIterator { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessType, Gather>; -+ -+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileAccessIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) -+ : params_(layout::PitchLinear(layout.stride(0))){}; -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const &base) -+ : params_(base) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ -+ /// Default constructor -+ PredicatedTileAccessIterator() = default; -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator( -+ ///< Precomputed parameters object -+ Params const ¶ms, -+ ///< Pointer to start of tensor -+ Pointer pointer, -+ ///< Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const &threadblock_offset, -+ /// Gather indices -+ int const *indices = nullptr) -+ : iterator_(params.params_, pointer, -+ layout::PitchLinearCoord(extent.column(), extent.row()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.column(), -+ threadblock_offset.row()), -+ indices) {} -+ -+ /// Construct a PredicatedTileAccessIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileAccessIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator operator++(int) { -+ PredicatedTileAccessIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { iterator_.get_mask(mask); } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return iterator_.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator for affine rank 2 data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class PredicatedTileAccessIterator, -+ AdvanceRank, ThreadMap_, AccessType_, false> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::AffineRankN<2>; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates< -+ Shape, Element, layout::PitchLinear, AdvanceRank, ThreadMap, AccessType>; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingPredicates::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ public: -+ friend PredicatedTileAccessIterator; -+ -+ private: -+ /// stride of pitch-linear layout (units of Element) -+ Coord stride_; -+ /// amount (in byte) to increment pointer to move to next access along -+ /// contiguous dimension -+ LongIndex inc_contiguous_; -+ /// amount (in byte) to increment pointer from first access of current -+ /// contiguous dimension to first access of next one. -+ LongIndex inc_strided_; -+ /// amount (in byte) to increment pointer from last access of current -+ /// contiguous dimension to first access of next one. -+ LongIndex inc_next_strided_; -+ /// amount (in byte) to increment pointer from last access to first access -+ /// of next tile -+ LongIndex inc_next_; -+ /// amount (in byte) to increment pointer from first access of current tile -+ /// to first access of next tile -+ LongIndex inc_advance_; -+ -+ public: -+ -+ // Default ctor -+ CUTLASS_HOST_DEVICE -+ Params(): stride_(0), inc_contiguous_(0), inc_strided_(0), inc_next_(0), inc_advance_(0) { } -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) : stride_({layout.stride(0), layout.stride(1)}) { -+ inc_contiguous_ = (LongIndex(stride_[0]) * ThreadMap::Delta::kContiguous) * -+ sizeof_bits::value / 8; -+ -+ inc_strided_ = (LongIndex(stride_[1]) * ThreadMap::Delta::kStrided) * -+ sizeof_bits::value / 8; -+ -+ inc_next_strided_ = inc_strided_ - LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_; -+ -+ if (kAdvanceRank) { -+ // advance along strided dimension -+ inc_advance_ = -+ Shape::kStrided * LongIndex(stride_[1]) * sizeof_bits::value / 8; -+ } else { -+ // advance along contiguous dimension -+ inc_advance_ = Shape::kContiguous * stride_[0] * sizeof_bits::value / 8; -+ } -+ -+ inc_next_ = inc_advance_ - LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_ - LongIndex(ThreadMap::Iterations::kStrided - 1) * inc_strided_; -+ }; -+ }; -+ -+ private: -+ /// Internal pointer type permits fast address arithmetic -+ using BytePointer = char *; -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters object with precomputed internal state -+ Params params_; -+ -+ /// Internal pointer to first access of tile -+ BytePointer pointer_; -+ -+ UnderlyingPredicates the_predicates; -+ -+ /// Used for out-of-order visitation -+ bool is_residue_tile_; -+ -+ private: -+ /// Computes predicates based on internally tracked per-thread offset. -+ CUTLASS_DEVICE -+ void compute_predicates_( -+ /// Extent of the matrix window -+ TensorCoord extent, -+ /// optionally, simplify predicate calculation during 'steady state' phase -+ bool is_steady_state = false) { -+ the_predicates.compute_predicates_(extent, is_steady_state); -+ } -+ -+ public: -+ -+ /// Default constructor -+ PredicatedTileAccessIterator() = default; -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator( -+ ///< Precomputed parameters object -+ Params const ¶ms, -+ ///< Pointer to start of tensor -+ Pointer pointer, -+ ///< Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const &threadblock_offset, -+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization -+ ) -+ : params_(params), -+ pointer_(reinterpret_cast( -+ const_cast(pointer))), -+ the_predicates(extent), -+ is_residue_tile_(true) { -+ -+ the_predicates.set_predicates(thread_id, threadblock_offset); -+ -+ // update internal pointers -+ Layout layout(params_.stride_); -+ add_pointer_offset(layout(the_predicates.thread_offset_)); -+ } -+ -+ /// Construct a PredicatedTileAccessIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileAccessIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { the_predicates.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += sizeof_bits::value * pointer_offset / 8; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ if (is_residue_tile_) { -+ -+ the_predicates.thread_offset_ += the_predicates.residue_offset_; -+ -+ Layout layout(params_.stride_); -+ add_pointer_offset(layout(the_predicates.residue_offset_)); -+ -+ the_predicates.compute_predicates_(the_predicates.extent_, true); -+ -+ if (kAdvanceRank) { -+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset[1] - 1); -+ pointer_ += Shape::kContiguous * tile_offset[0]; -+ } else { -+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset[0] - 1); -+ pointer_ += Shape::kStrided * tile_offset[1]; -+ } -+ } else { -+ if (kAdvanceRank) { -+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset[1]); -+ pointer_ += Shape::kContiguous * tile_offset[0]; -+ } else { -+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset[0]); -+ pointer_ += Shape::kStrided * tile_offset[1]; -+ } -+ } -+ is_residue_tile_ = false; -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(pointer_) + the_predicates.iteration_vector_; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator &operator++() { -+ the_predicates.operator++(); -+ ++the_predicates.iteration_vector_; -+ if (the_predicates.iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ -+ the_predicates.iteration_vector_ = 0; -+ ++the_predicates.iteration_contiguous_; -+ -+ if (the_predicates.iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ pointer_ += params_.inc_contiguous_; -+ return *this; -+ } -+ -+ // Enter here only if (iteration_contiguous_ == -+ // ThreadMap::Iteration::kContiguous) -+ the_predicates.iteration_contiguous_ = 0; -+ ++the_predicates.iteration_strided_; -+ -+ if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ pointer_ += params_.inc_next_strided_; -+ return *this; -+ } -+ -+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) -+ // which means we enter the next tile. -+ the_predicates.iteration_strided_ = 0; -+ -+ // advance to next tile -+ pointer_ += params_.inc_next_; -+ -+ // now return to start tile - if the iterator is subsequently advanced, this -+ // subtraction as well as the subsequent integer addition are both elided by -+ // the compiler. -+ pointer_ -= params_.inc_advance_; -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator operator++(int) { -+ PredicatedTileAccessIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { the_predicates.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { the_predicates.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { the_predicates.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { the_predicates.get_mask(mask); } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return the_predicates.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator for affine rank 2 column-major data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class PredicatedTileAccessIterator { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::AffineRank2ColumnMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ // Map to the underlying AffineRankN<2> layout -+ using UnderlyingIterator = PredicatedTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::AffineRankN<2>, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessType>; -+ -+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileAccessIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Construct the Params object given an AffineRankN<2> tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) -+ : params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))){}; -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying AffineRankN<2> tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ -+ /// Default constructor -+ PredicatedTileAccessIterator() = default; -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator( -+ ///< Precomputed parameters object -+ Params const ¶ms, -+ ///< Pointer to start of tensor -+ Pointer pointer, -+ ///< Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const &threadblock_offset, -+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization -+ ) -+ : iterator_(params.params_, pointer, -+ layout::PitchLinearCoord(extent.row(), extent.column()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.row(), -+ threadblock_offset.column())) {} -+ -+ /// Construct a PredicatedTileAccessIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileAccessIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset(make_Coord(tile_offset.row(), tile_offset.column())); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator operator++(int) { -+ PredicatedTileAccessIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { iterator_.get_mask(mask); } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return iterator_.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator for affine rank-2 row-major data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class PredicatedTileAccessIterator { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::AffineRank2RowMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ // Map to the underlying AffineRankN<2> layout -+ using UnderlyingIterator = PredicatedTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::AffineRankN<2>, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessType>; -+ -+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileAccessIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Construct the Params object given an AffineRankN<2> tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) -+ : params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))){}; -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying AffineRankN<2> tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ -+ /// Default constructor -+ PredicatedTileAccessIterator() = default; -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator( -+ ///< Precomputed parameters object -+ Params const ¶ms, -+ ///< Pointer to start of tensor -+ Pointer pointer, -+ ///< Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const &threadblock_offset, -+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization -+ ) -+ : iterator_(params.params_, pointer, -+ layout::PitchLinearCoord(extent.column(), extent.row()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.column(), -+ threadblock_offset.row())) {} -+ -+ /// Construct a PredicatedTileAccessIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileAccessIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset(make_Coord(tile_offset.column(), tile_offset.row())); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator operator++(int) { -+ PredicatedTileAccessIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { iterator_.get_mask(mask); } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return iterator_.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator for column-major interleaved data. -+/// It is mapped to the congruous layout. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+ -+template -+class PredicatedTileAccessIterator, -+ AdvanceRank, ThreadMap_, AccessType_, false> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ static int const kInterleavedK = InterleavedK; -+ using Layout = layout::ColumnMajorInterleaved; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedTileAccessIterator< -+ layout::PitchLinearShape, -+ Element, layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, -+ AccessType>; -+ -+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileAccessIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) -+ : params_(layout::PitchLinear(layout.stride(0))) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const &base) -+ : params_(base) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ -+ /// Default constructor -+ PredicatedTileAccessIterator() = default; -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset, -+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization -+ ) -+ : iterator_(params.params_, pointer, -+ layout::PitchLinearCoord(extent.row() * kInterleavedK, -+ extent.column() / kInterleavedK), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.row() * kInterleavedK, -+ threadblock_offset.column() / kInterleavedK)) {} -+ -+ /// Construct a PredicatedTileAccessIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileAccessIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator operator++(int) { -+ PredicatedTileAccessIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { iterator_.get_mask(mask); } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { return iterator_.valid(); } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator for row-major interleaved data. -+// It is mapped to the congruous layout. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class PredicatedTileAccessIterator, -+ AdvanceRank, ThreadMap_, AccessType_, false> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ static int const kInterleavedK = InterleavedK; -+ using Layout = layout::RowMajorInterleaved; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedTileAccessIterator< -+ layout::PitchLinearShape, -+ Element, layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, -+ AccessType>; -+ -+ -+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileAccessIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) -+ : params_(layout::PitchLinear(layout.stride(0))) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const &base) -+ : params_(base) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ -+ /// Default constructor -+ PredicatedTileAccessIterator() = default; -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset, -+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization -+ ) -+ : iterator_(params.params_, pointer, -+ layout::PitchLinearCoord(extent.column() * kInterleavedK, -+ extent.row() / kInterleavedK), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.column() * kInterleavedK, -+ threadblock_offset.row() / kInterleavedK)) {} -+ -+ /// Construct a PredicatedTileAccessIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileAccessIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator operator++(int) { -+ PredicatedTileAccessIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { iterator_.get_mask(mask); } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { return iterator_.valid(); } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_2dthreadtile.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_2dthreadtile.h -new file mode 100644 -index 0000000..1ce5e39 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_2dthreadtile.h -@@ -0,0 +1,834 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates calculating the address and predicates to the load of tiles -+ from pitch-linear rank=2 tensors. -+ -+ This iterator uses masks to guard out-of-bounds accesses and visits the last -+ "residue" tile first, with the objective of minimizing predicate mask updates -+ during steady-state operation. -+ -+ A precomputed "Params" object minimizes the amount of state that must be -+ stored in registers, and integer addition is used to advance the pointer -+ through memory. -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// PredicatedTileAccessIterator2dThreadTile -+/// -+template -+class PredicatedTileAccessIterator2dThreadTile; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator2dThreadTile for pitch-linear data. -+/// -+template -+class PredicatedTileAccessIterator2dThreadTile { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using StrideIndex = typename Layout::Stride::Index; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ static int const kPredicatesPerByte = 4; -+ static int const kPredicatesPerWord = 4 * kPredicatesPerByte; -+ -+ /// Number of 32b words containing predicates -+ static int const kPredicateByteCount = (ThreadMap::Iterations::kCount * ThreadMap::ThreadAccessShape::kStrided + kPredicatesPerByte - 1) / kPredicatesPerByte; -+ static int const kPredicateWordCount = (kPredicateByteCount + 3) / 4; -+ -+ static unsigned const kPredicateMask = (1u << kPredicatesPerByte) - 1u; -+ -+ static_assert(kPredicateWordCount <= 4, "Too many predicates."); -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = Array; -+ -+ /// Uses a non-template class -+ struct Params : PredicatedTileAccessIteratorParams { -+ -+ public: -+ friend PredicatedTileAccessIterator2dThreadTile; -+ -+ using Base = PredicatedTileAccessIteratorParams; -+ -+ // Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) : -+ Base(layout.stride(0), -+ MakePredicatedTileAccessIteratorDesc()() -+ ) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Base const &base) : -+ Base(base) { } -+ }; -+ -+ -+ private: -+ /// Internal pointer type permits fast address arithmetic -+ using BytePointer = char *; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Parameters object with precomputed internal state -+ Params const ¶ms_; -+ -+ /// Internal pointer to first access of tile -+ BytePointer pointer_; -+ -+ /// Guard predicates -+ uint32_t predicates_[kPredicateWordCount]; -+ -+ /// Size of tensor -+ TensorCoord extent_; -+ -+ /// Initial offset for each thread -+ TensorCoord thread_offset_; -+ -+ /// Index of residue tile -+ int residue_tile_idx_; -+ -+ /// Used for out-of-order visitation -+ bool is_residue_tile_; -+ -+ /// Iteration in the contiguous dimension -+ int iteration_contiguous_; -+ -+ /// Iteration in the strided dimension -+ int iteration_strided_; -+ -+ /// Tracks iterations within the thread loop -+ int iteration_thread_; -+ -+ private: -+ /// Computes predicates based on internally tracked per-thread offset. -+ CUTLASS_HOST_DEVICE -+ void compute_predicates_( -+ /// optionally, simplify predicate calculation during 'steady state' phase -+ bool is_steady_state = false) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ predicates_[i] = 0u; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int ts = 0; ts < ThreadMap::ThreadAccessShape::kStrided; ts++) { -+ -+ TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous, -+ ts + s * ThreadMap::Delta::kStrided); -+ -+ TensorCoord coord = thread_offset_ + iteration_coord; -+ -+ bool guard; -+ -+ if (is_steady_state) { -+ if (kAdvanceRank == 0) { -+ guard = (coord.strided() < extent_.strided()); -+ } else { -+ guard = (coord.contiguous() < extent_.contiguous()); -+ } -+ } else { -+ guard = (coord.strided() < extent_.strided() && -+ coord.contiguous() < extent_.contiguous()); -+ } -+ -+ int pred_idx = ts + c * ThreadMap::ThreadAccessShape::kStrided + s * ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided; -+ int word_idx = pred_idx / kPredicatesPerWord; -+ int residual = pred_idx % kPredicatesPerWord; -+ int byte_idx = residual / kPredicatesPerByte; -+ int bit_idx = residual % kPredicatesPerByte; -+ -+ predicates_[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx)); -+ -+ } -+ } -+ } -+ -+ } -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator2dThreadTile( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : params_(params), -+ pointer_(reinterpret_cast( -+ const_cast(pointer))), -+ extent_(extent), -+ is_residue_tile_(true) { -+ -+ -+ TensorCoord residue_offset; -+ if (kAdvanceRank) { -+ residue_tile_idx_ = -+ (extent_[kAdvanceRank] - threadblock_offset[kAdvanceRank] - 1) / -+ Shape::kStrided; -+ residue_offset = make_Coord(0, residue_tile_idx_ * Shape::kStrided); -+ } else { -+ residue_tile_idx_ = -+ (extent_[kAdvanceRank] - threadblock_offset[kAdvanceRank] - 1) / -+ Shape::kContiguous; -+ residue_offset = make_Coord(residue_tile_idx_ * Shape::kContiguous, 0); -+ } -+ -+ // Per-thread offset in logical coordinates of tensor -+ thread_offset_ = threadblock_offset + residue_offset + -+ ThreadMap::initial_offset(thread_id); -+ -+ // update internal pointers -+ Layout layout(params_.stride_); -+ add_pointer_offset(layout(thread_offset_)); -+ -+ compute_predicates_(false); -+ -+ set_iteration_index(0); -+ } -+ -+ /// Construct a PredicatedTileAccessIterator2dThreadTile with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator2dThreadTile( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id) -+ : PredicatedTileAccessIterator2dThreadTile(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ -+ int residual = index % (ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided); -+ iteration_strided_ = index / (ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided); -+ -+ iteration_contiguous_ = residual / ThreadMap::ThreadAccessShape::kStrided; -+ iteration_thread_ = residual % ThreadMap::ThreadAccessShape::kStrided; -+ -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += int(sizeof(Element)) * pointer_offset; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_DEVICE -+ void add_tile_offset( -+ TensorCoord const &tile_offset) { -+ if (is_residue_tile_) { -+ TensorCoord residue_offset; -+ if (kAdvanceRank) { -+ residue_offset = TensorCoord(0, residue_tile_idx_ * Shape::kStrided); -+ } else { -+ residue_offset = TensorCoord(residue_tile_idx_ * Shape::kContiguous, 0); -+ } -+ -+ thread_offset_ -= residue_offset; -+ -+ Layout layout(params_.stride_); -+ add_pointer_offset(-layout(residue_offset)); -+ -+ compute_predicates_(true); -+ -+ if (kAdvanceRank) { -+ pointer_ += params_.inc_advance_ * (tile_offset.strided() - 1); -+ pointer_ += Shape::kContiguous * tile_offset.contiguous(); -+ } else { -+ pointer_ += params_.inc_advance_ * (tile_offset.contiguous() - 1); -+ pointer_ += Shape::kStrided * tile_offset.strided(); -+ } -+ } else { -+ if (kAdvanceRank) { -+ pointer_ += params_.inc_advance_ * tile_offset.strided(); -+ pointer_ += Shape::kContiguous * tile_offset.contiguous(); -+ } else { -+ pointer_ += params_.inc_advance_ * tile_offset.contiguous(); -+ pointer_ += Shape::kStrided * tile_offset.strided(); -+ } -+ } -+ is_residue_tile_ = false; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ -+ AccessType *ret_val = reinterpret_cast( -+ pointer_ + (iteration_thread_ * params_.stride_ + iteration_contiguous_ * ThreadMap::Delta::kContiguous) * int(sizeof(Element))); -+ -+ return ret_val; -+ } -+ -+ /// Increment and return an instance to self. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator2dThreadTile &operator++() { -+ -+ iteration_thread_++; -+ -+ if (iteration_thread_ < ThreadMap::ThreadAccessShape::kStrided) -+ return *this; -+ -+ iteration_thread_ = 0; -+ -+ ++iteration_contiguous_; -+ -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) -+ return *this; -+ -+ // Enter here only if (iteration_contiguous_ == -+ // ThreadMap::Iteration::kContiguous) -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ pointer_ += params_.inc_strided_; -+ return *this; -+ } -+ -+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) -+ // which means we enter the next tile. -+ iteration_strided_ = 0; -+ -+ // advance to next tile -+ pointer_ += params_.inc_next_; -+ -+ // now return to start tile - if the iterator is subsequently advanced, this -+ // subtraction as well as the subsequent integer addition are both elided by -+ // the compiler. -+ pointer_ -= params_.inc_advance_; -+ -+ return *this; -+ } -+ -+ /// Increment and return an instance to self. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator2dThreadTile operator++(int) { -+ PredicatedTileAccessIterator2dThreadTile self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ predicates_[i] = enable ? 0u : predicates_[i]; -+ } -+ -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ predicates_[i] = 0xffffffff; -+ } -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ predicates_[i] = mask[i]; -+ } -+ -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ mask[i] = predicates_[i]; -+ } -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ -+ int pred_idx = -+ iteration_thread_ + -+ iteration_contiguous_ * ThreadMap::ThreadAccessShape::kStrided + -+ iteration_strided_ * ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided; -+ -+ int word_idx = pred_idx / kPredicatesPerWord; -+ int residual = pred_idx % kPredicatesPerWord; -+ int byte_idx = residual / kPredicatesPerByte; -+ int bit_idx = residual % kPredicatesPerByte; -+ -+ bool pred = (predicates_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0; -+ -+ return pred; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator2dThreadTile for pitch-linear data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class PredicatedTileAccessIterator2dThreadTile { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedTileAccessIterator2dThreadTile< -+ layout::PitchLinearShape, Element, -+ layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessType>; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileAccessIterator2dThreadTile; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) -+ : params_(layout::PitchLinear(layout.stride(0))){} -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const &base) -+ : params_(base) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator2dThreadTile( -+ ///< Precomputed parameters object -+ Params const ¶ms, -+ ///< Pointer to start of tensor -+ Pointer pointer, -+ ///< Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : iterator_(params.params_, pointer, -+ layout::PitchLinearCoord(extent.row(), extent.column()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.row(), -+ threadblock_offset.column())) {} -+ -+ /// Construct a PredicatedTileAccessIterator2dThreadTile with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator2dThreadTile( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileAccessIterator2dThreadTile(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator2dThreadTile &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator2dThreadTile operator++(int) { -+ PredicatedTileAccessIterator2dThreadTile self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { iterator_.get_mask(mask); } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return iterator_.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator2dThreadTile for pitch-linear data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class PredicatedTileAccessIterator2dThreadTile { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedTileAccessIterator2dThreadTile< -+ layout::PitchLinearShape, Element, -+ layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessType>; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileAccessIterator2dThreadTile; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) -+ : params_(layout::PitchLinear(layout.stride(0))){} -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const &base) -+ : params_(base) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator2dThreadTile( -+ ///< Precomputed parameters object -+ Params const ¶ms, -+ ///< Pointer to start of tensor -+ Pointer pointer, -+ ///< Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : iterator_(params.params_, pointer, -+ layout::PitchLinearCoord(extent.column(), extent.row()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.column(), -+ threadblock_offset.row())) {} -+ -+ /// Construct a PredicatedTileAccessIterator2dThreadTile with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator2dThreadTile( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileAccessIterator2dThreadTile(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator2dThreadTile &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIterator2dThreadTile operator++(int) { -+ PredicatedTileAccessIterator2dThreadTile self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { iterator_.get_mask(mask); } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return iterator_.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_params.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_params.h -new file mode 100755 -index 0000000..cbabc4e ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_params.h -@@ -0,0 +1,289 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/pitch_linear.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Predicated tile access iterator descriptor object containing template dependent state -+struct PredicatedTileAccessIteratorDesc { -+ -+ int element_size_bits; -+ int advance_rank; -+ layout::PitchLinearCoord threadblock_shape; -+ layout::PitchLinearCoord threadmap_iterations; -+ layout::PitchLinearCoord threadmap_delta; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorDesc() { } -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorDesc( -+ int element_size_bits_, -+ int advance_rank_, -+ layout::PitchLinearCoord threadblock_shape_, -+ layout::PitchLinearCoord threadmap_iterations_, -+ layout::PitchLinearCoord threadmap_delta_ -+ ): -+ element_size_bits(element_size_bits_), -+ advance_rank(advance_rank_), -+ threadblock_shape(threadblock_shape_), -+ threadmap_iterations(threadmap_iterations_), -+ threadmap_delta(threadmap_delta_) -+ { -+ #if 0 -+ printf("PredicatedTileAccessIteratorDesc(%d, %d, {%d, %d}, {%d, %d}, {%d, %d}})\n", -+ element_size_bits, -+ advance_rank, -+ threadblock_shape.contiguous(), threadblock_shape.strided(), -+ threadmap_iterations.contiguous(), threadmap_iterations.strided(), -+ threadmap_delta.contiguous(), threadmap_delta.strided()); -+ #endif -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Helper template to construct an PredicatedTileAccessIteratorDesc from a template -+// dependent state -+template < -+ typename Shape, typename Element, typename Layout, -+ int AdvanceRank, typename ThreadMap> -+ struct MakePredicatedTileAccessIteratorDesc; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator for pitch-linear data. -+template < -+ typename Shape, typename Element, int AdvanceRank, -+ typename ThreadMap> -+struct MakePredicatedTileAccessIteratorDesc < -+ Shape, Element, layout::PitchLinear, AdvanceRank, ThreadMap> { -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorDesc operator()() { -+ -+ return PredicatedTileAccessIteratorDesc( -+ sizeof_bits::value, -+ AdvanceRank, -+ {Shape::kContiguous, Shape::kStrided}, -+ {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, -+ {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided} -+ ); -+} -+ -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator for column-major data. -+template < -+ typename Shape, typename Element, int AdvanceRank, -+ typename ThreadMap> -+struct MakePredicatedTileAccessIteratorDesc < -+ Shape, Element, layout::ColumnMajor, AdvanceRank, ThreadMap> { -+ -+ static int const kAdvanceRank = AdvanceRank; -+ -+ using UnderlyingMakeOperator = MakePredicatedTileAccessIteratorDesc< -+ layout::PitchLinearShape, Element, -+ layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap>; -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorDesc operator()() { -+ -+ return UnderlyingMakeOperator()(); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator for row-major data. -+template < -+ typename Shape, typename Element, int AdvanceRank, -+ typename ThreadMap> -+struct MakePredicatedTileAccessIteratorDesc < -+ Shape, Element, layout::RowMajor, AdvanceRank, ThreadMap> { -+ -+ static int const kAdvanceRank = AdvanceRank; -+ -+ using UnderlyingMakeOperator = MakePredicatedTileAccessIteratorDesc< -+ layout::PitchLinearShape, Element, -+ layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap>; -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorDesc operator()() { -+ -+ return UnderlyingMakeOperator()(); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator for column-major interleaved data. -+template < -+ typename Shape, typename Element, int AdvanceRank, -+ typename ThreadMap, int InterleavedK> -+struct MakePredicatedTileAccessIteratorDesc < -+ Shape, Element, layout::ColumnMajorInterleaved, AdvanceRank, ThreadMap> { -+ -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kInterleavedK = InterleavedK; -+ -+ using UnderlyingMakeOperator = MakePredicatedTileAccessIteratorDesc< -+ layout::PitchLinearShape, Element, -+ layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap>; -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorDesc operator()() { -+ -+ return UnderlyingMakeOperator()(); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIterator for roww-major interleaved data. -+template < -+ typename Shape, typename Element, int AdvanceRank, -+ typename ThreadMap, int InterleavedK> -+struct MakePredicatedTileAccessIteratorDesc < -+ Shape, Element, layout::RowMajorInterleaved, AdvanceRank, ThreadMap> { -+ -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kInterleavedK = InterleavedK; -+ -+ using UnderlyingMakeOperator = MakePredicatedTileAccessIteratorDesc< -+ layout::PitchLinearShape, Element, -+ layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap>; -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorDesc operator()() { -+ -+ return UnderlyingMakeOperator()(); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Parameters struct -+// -+ -+struct PredicatedTileAccessIteratorParams { -+ -+ using Index = int32_t; -+ using LongIndex = int64_t; -+ -+ // -+ // Data members -+ // -+ /// stride of pitch-linear layout (units of Element) -+ LongIndex stride_; -+ /// amount (in byte) to increment pointer to move to next access along -+ /// strided dimension -+ LongIndex inc_strided_; -+ /// amount (in byte) to increment pointer from last access to first access -+ /// of next tile -+ LongIndex inc_next_; -+ /// amount (in byte) to increment pointer from first access of current tile -+ /// to first access of next tile -+ LongIndex inc_advance_; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_HOST_DEVICE -+ Status initialize(LongIndex stride, PredicatedTileAccessIteratorDesc desc) { -+ -+ stride_ = stride; -+ -+ inc_strided_ = (LongIndex(stride_) * desc.threadmap_delta.strided()) * -+ desc.element_size_bits / 8; -+ -+ if (desc.advance_rank) { -+ // advance along strided dimension -+ inc_advance_ = -+ desc.threadblock_shape.strided() * LongIndex(stride_) * desc.element_size_bits / 8; -+ } else { -+ // advance along contiguous dimension -+ inc_advance_ = desc.threadblock_shape.contiguous() * desc.element_size_bits / 8; -+ } -+ -+ inc_next_ = inc_advance_ - LongIndex(desc.threadmap_iterations.strided() - 1) * -+ desc.threadmap_delta.strided() * LongIndex(stride_) * -+ desc.element_size_bits / 8; -+ -+ return Status::kSuccess; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Status initialize(Index stride, PredicatedTileAccessIteratorDesc desc) { -+ return initialize(LongIndex(stride), desc); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorParams() { -+ initialize(LongIndex(0), PredicatedTileAccessIteratorDesc()); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorParams(Index stride, PredicatedTileAccessIteratorDesc desc) { -+ initialize(stride, desc); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorParams(LongIndex stride, PredicatedTileAccessIteratorDesc desc) { -+ initialize(stride, desc); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_triangular_matrix.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_triangular_matrix.h -new file mode 100644 -index 0000000..d304b99 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_triangular_matrix.h -@@ -0,0 +1,892 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates calculating the address and predicates to the load of tiles -+ from pitch-linear rank=2 tensors. -+ -+ This iterator uses masks to guard out-of-bounds accesses and visits the last -+ "residue" tile first, with the objective of minimizing predicate mask updates -+ during steady-state operation. -+ -+ A precomputed "Params" object minimizes the amount of state that must be -+ stored in registers, and integer addition is used to advance the pointer -+ through memory. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/predicate_vector.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// PredicatedTileAccessIteratorTriangularMatrix -+/// -+template -+class PredicatedTileAccessIteratorTriangularMatrix; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIteratorTriangularMatrix for pitch-linear data. -+/// -+template -+class PredicatedTileAccessIteratorTriangularMatrix { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using StrideIndex = typename Layout::Stride::Index; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; -+ -+ using CompareOp = typename TrMatrixCompareOp::Type; -+ -+ static_assert( kFillMode == FillMode::kFull || -+ ((kFillMode == FillMode::kLower || kFillMode == FillMode::kUpper) && AccessType::kElements == 1), -+ "BLAS3 iterator for the triangular/symmetric matrix must use AccessType::kElements as 1"); -+ -+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), -+ "Vectors implied by the thread map must be divisible by the access type."); -+ -+ static int const kPredicatesPerByte = 4; -+ static int const kPredicatesPerWord = 4 * kPredicatesPerByte; -+ -+ static int const kPredicateCount = ThreadMap::Iterations::kCount * kAccessesPerVector; -+ -+ /// Number of 32b words containing predicates -+ static int const kPredicateByteCount = -+ (kPredicateCount + kPredicatesPerByte - 1) / kPredicatesPerByte; -+ static int const kPredicateWordCount = (kPredicateByteCount + 3) / 4; -+ -+ static unsigned const kPredicateMask = (1u << kPredicatesPerByte) - 1u; -+ -+ static_assert(kPredicateWordCount <= 4, "Too many predicates."); -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = Array; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ public: -+ friend PredicatedTileAccessIteratorTriangularMatrix; -+ -+ private: -+ /// stride of pitch-linear layout (units of Element) -+ StrideIndex stride_; -+ /// (true) pitch-linear layout is mapped to row-major matrix -+ /// (false) pitch-linear layout is mapped to column-major matrix -+ bool is_row_major_; -+ /// for vectorized access across the diagonal boundary guard condition is -+ /// checked for the element on the boundary -+ int access_diagonal_boundary_; -+ /// amount (in byte) to increment pointer to move to next access along -+ /// strided dimension -+ LongIndex inc_strided_; -+ /// amount (in byte) to increment pointer from last access to first access -+ /// of next tile -+ LongIndex inc_next_; -+ /// amount (in byte) to increment pointer from first access of current tile -+ /// to first access of next tile -+ LongIndex inc_advance_; -+ -+ public: -+ -+ // Default ctor -+ CUTLASS_HOST_DEVICE -+ Params(): stride_(0), inc_strided_(0), inc_next_(0), inc_advance_(0), is_row_major_(false), access_diagonal_boundary_(0) { } -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout, bool is_row_major, int access_diagonal_boundary) : -+ stride_(layout.stride(0)), is_row_major_(is_row_major), access_diagonal_boundary_(access_diagonal_boundary) { -+ -+ inc_strided_ = (LongIndex(stride_) * ThreadMap::Delta::kStrided) * -+ sizeof_bits::value / 8; -+ -+ if (kAdvanceRank) { -+ // advance along strided dimension -+ inc_advance_ = -+ Shape::kStrided * LongIndex(stride_) * sizeof_bits::value / 8; -+ } else { -+ // advance along contiguous dimension -+ inc_advance_ = Shape::kContiguous * sizeof_bits::value / 8; -+ } -+ -+ inc_next_ = inc_advance_ - LongIndex(ThreadMap::Iterations::kStrided - 1) * -+ ThreadMap::Delta::kStrided * LongIndex(stride_) * -+ sizeof_bits::value / 8; -+ -+ }; -+ -+ -+ }; -+ -+ private: -+ /// Internal pointer type permits fast address arithmetic -+ using BytePointer = char *; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Parameters object with precomputed internal state -+ Params const ¶ms_; -+ -+ /// Internal pointer to first access of tile -+ BytePointer pointer_; -+ -+ /// Guard predicates -+ uint32_t predicates_[kPredicateWordCount]; -+ -+ /// Track global memory addresses on the diagonal -+ /// To ignore imag part for diagonal elements of hermitian matrices -+ uint32_t predicates_onDiag_[kPredicateWordCount]; -+ -+ /// Size of tensor -+ TensorCoord extent_; -+ -+ /// Initial offset for each thread -+ TensorCoord thread_offset_; -+ -+ /// Iteration along vectors implied by the thread map -+ int iteration_vector_; -+ -+ /// Iteration in the contiguous dimension -+ int iteration_contiguous_; -+ -+ /// Iteration in the strided dimension -+ int iteration_strided_; -+ -+ private: -+ /// Computes predicates based on internally tracked per-thread offset. -+ CUTLASS_DEVICE -+ void compute_predicates_( -+ /// Extent of the matrix window -+ TensorCoord extent) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ predicates_[i] = 0u; -+ predicates_onDiag_[i] = 0u; -+ } -+ -+ CompareOp compare_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int access_idx = 0; access_idx < ThreadMap::Iterations::kCount * kAccessesPerVector; ++access_idx) { -+ -+ int s = access_idx / (ThreadMap::Iterations::kContiguous * kAccessesPerVector); -+ -+ int access_residual = access_idx % (ThreadMap::Iterations::kContiguous * kAccessesPerVector); -+ -+ int c = access_residual / kAccessesPerVector; -+ int v = access_residual % kAccessesPerVector; -+ -+ TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous + v * AccessType::kElements, -+ s * ThreadMap::Delta::kStrided); -+ -+ TensorCoord coord = thread_offset_ + iteration_coord; -+ -+ bool guard; -+ bool onDiag = false; -+ -+ guard = ((coord.strided() < extent.strided()) && -+ (coord.contiguous() < extent.contiguous())); -+ -+ -+ // guard access on the wrong side of the triagular matrix diagonal -+ if (kFillMode == FillMode::kLower || kFillMode == FillMode::kUpper) { -+ coord += TensorCoord{params_.access_diagonal_boundary_, 0}; -+ -+ bool triagular_guard_row_major = compare_op(coord.strided(), coord.contiguous()) | !params_.is_row_major_; -+ bool triagular_guard_col_major = compare_op(coord.contiguous(), coord.strided()) | params_.is_row_major_; -+ -+ guard = guard && triagular_guard_row_major && triagular_guard_col_major; -+ -+ if (kDiagType == DiagType::kUnit) { -+ onDiag = (guard && coord.strided() == coord.contiguous()) ? true : false; -+ } -+ } -+ -+ int pred_idx_onDiag = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s); -+ int word_idx_onDiag = pred_idx_onDiag / kPredicatesPerWord; -+ int residual_onDiag = pred_idx_onDiag % kPredicatesPerWord; -+ int byte_idx_onDiag = residual_onDiag / kPredicatesPerByte; -+ int bit_idx_onDiag = residual_onDiag % kPredicatesPerByte; -+ -+ predicates_onDiag_[word_idx_onDiag] |= (unsigned(onDiag) << (byte_idx_onDiag * 8 + bit_idx_onDiag)); -+ -+ int pred_idx = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s); -+ -+ int word_idx = pred_idx / kPredicatesPerWord; -+ int residual = pred_idx % kPredicatesPerWord; -+ int byte_idx = residual / kPredicatesPerByte; -+ int bit_idx = residual % kPredicatesPerByte; -+ -+ predicates_[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx)); -+ -+ } -+ -+ } -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorTriangularMatrix( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : params_(params), -+ pointer_(reinterpret_cast(const_cast(pointer))), -+ extent_(extent) { -+ -+ -+ // Per-thread offset in logical coordinates of tensor -+ thread_offset_ = threadblock_offset + ThreadMap::initial_offset(thread_id); -+ -+ // update internal pointers -+ Layout layout(params_.stride_); -+ add_pointer_offset(layout(thread_offset_)); -+ -+ compute_predicates_(extent_); -+ -+ set_iteration_index(0); -+ } -+ -+ /// Construct a PredicatedTileAccessIteratorTriangularMatrix with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorTriangularMatrix( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id) -+ : PredicatedTileAccessIteratorTriangularMatrix(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ -+ iteration_vector_ = index % kAccessesPerVector; -+ int residual_access = index / kAccessesPerVector; -+ -+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; -+ -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += sizeof_bits::value * pointer_offset / 8; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ -+ if (kAdvanceRank) { -+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided()); -+ pointer_ += Shape::kContiguous * tile_offset.contiguous(); -+ thread_offset_ += TensorCoord{0, Shape::kStrided * tile_offset.strided()}; -+ } else { -+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous()); -+ pointer_ += Shape::kStrided * tile_offset.strided(); -+ thread_offset_ += TensorCoord{Shape::kContiguous * tile_offset.contiguous(), 0}; -+ } -+ -+ compute_predicates_(extent_); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast( -+ pointer_ + -+ iteration_contiguous_ * (ThreadMap::Delta::kContiguous * sizeof_bits::value) / 8) + iteration_vector_; -+ } -+ -+ /// Increment and return an instance to self. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorTriangularMatrix &operator++() { -+ -+ ++iteration_vector_; -+ if (iteration_vector_ < kAccessesPerVector) { -+ return *this; -+ } -+ -+ iteration_vector_ = 0; -+ ++iteration_contiguous_; -+ -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { -+ return *this; -+ } -+ -+ // Enter here only if (iteration_contiguous_ == -+ // ThreadMap::Iteration::kContiguous) -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ pointer_ += params_.inc_strided_; -+ return *this; -+ } -+ -+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) -+ // which means we enter the next tile. -+ iteration_strided_ = 0; -+ -+ // advance to next tile -+ pointer_ += params_.inc_next_; -+ -+ // now return to start tile - if the iterator is subsequently advanced, this -+ // subtraction as well as the subsequent integer addition are both elided by -+ // the compiler. -+ pointer_ -= params_.inc_advance_; -+ -+ return *this; -+ } -+ -+ /// Increment and return an instance to self. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorTriangularMatrix operator++(int) { -+ PredicatedTileAccessIteratorTriangularMatrix self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ predicates_[i] = enable ? 0u : predicates_[i]; -+ } -+ -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ predicates_[i] = 0xffffffff; -+ } -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ predicates_[i] = mask[i]; -+ } -+ -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kPredicateWordCount; ++i) { -+ mask[i] = predicates_[i]; -+ } -+ } -+ -+ /// Return if the address in on the diagonal -+ CUTLASS_HOST_DEVICE -+ bool getOnDiag() { -+ int pred_idx = -+ iteration_vector_ + kAccessesPerVector * (iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous); -+ -+ int word_idx = pred_idx / kPredicatesPerWord; -+ int residual = pred_idx % kPredicatesPerWord; -+ int byte_idx = residual / kPredicatesPerByte; -+ int bit_idx = residual % kPredicatesPerByte; -+ -+ bool pred = (predicates_onDiag_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0; -+ return pred; -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ -+ -+ int pred_idx = -+ iteration_vector_ + kAccessesPerVector * (iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous); -+ -+ int word_idx = pred_idx / kPredicatesPerWord; -+ int residual = pred_idx % kPredicatesPerWord; -+ int byte_idx = residual / kPredicatesPerByte; -+ int bit_idx = residual % kPredicatesPerByte; -+ -+ bool pred = (predicates_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0; -+ return pred; -+ -+ -+ //return true; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIteratorTriangularMatrix for column-major data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class PredicatedTileAccessIteratorTriangularMatrix { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedTileAccessIteratorTriangularMatrix< -+ layout::PitchLinearShape, Element, -+ layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, -+ kSideMode, kFillMode, kDiagType, AccessType>; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; -+ -+ static int const kAccessDiagonalBoundary = -+ (kFillMode == FillMode::kLower) ? (AccessType::kElements - 1) : 0; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileAccessIteratorTriangularMatrix; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) -+ : params_(layout::PitchLinear(layout.stride(0)), false, kAccessDiagonalBoundary){}; -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorTriangularMatrix( -+ ///< Precomputed parameters object -+ Params const ¶ms, -+ ///< Pointer to start of tensor -+ Pointer pointer, -+ ///< Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : iterator_(params.params_, pointer, -+ layout::PitchLinearCoord(extent.row(), extent.column()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.row(), -+ threadblock_offset.column())) {} -+ -+ /// Construct a PredicatedTileAccessIteratorTriangularMatrix with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorTriangularMatrix( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileAccessIteratorTriangularMatrix(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorTriangularMatrix &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorTriangularMatrix operator++(int) { -+ PredicatedTileAccessIteratorTriangularMatrix self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { iterator_.get_mask(mask); } -+ -+ /// Return if the address in on the diagonal -+ CUTLASS_HOST_DEVICE -+ bool getOnDiag() { -+ return iterator_.getOnDiag(); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return iterator_.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileAccessIteratorTriangularMatrix for row-major data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class PredicatedTileAccessIteratorTriangularMatrix { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ using AccessType = AccessType_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedTileAccessIteratorTriangularMatrix< -+ layout::PitchLinearShape, Element, -+ layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, -+ kSideMode, kFillMode, kDiagType, AccessType>; -+ -+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; -+ -+ static int const kAccessDiagonalBoundary = -+ (kFillMode == FillMode::kUpper) ? (AccessType::kElements - 1) : 0; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileAccessIteratorTriangularMatrix; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) -+ : params_(layout::PitchLinear(layout.stride(0)), true, kAccessDiagonalBoundary){}; -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorTriangularMatrix( -+ ///< Precomputed parameters object -+ Params const ¶ms, -+ ///< Pointer to start of tensor -+ Pointer pointer, -+ ///< Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : iterator_(params.params_, pointer, -+ layout::PitchLinearCoord(extent.column(), extent.row()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.column(), -+ threadblock_offset.row())) {} -+ -+ /// Construct a PredicatedTileAccessIteratorTriangularMatrix with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorTriangularMatrix( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileAccessIteratorTriangularMatrix(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorTriangularMatrix &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileAccessIteratorTriangularMatrix operator++(int) { -+ PredicatedTileAccessIteratorTriangularMatrix self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { iterator_.get_mask(mask); } -+ -+ /// Return if the address in on the diagonal -+ CUTLASS_HOST_DEVICE -+ bool getOnDiag() { -+ return iterator_.getOnDiag(); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return iterator_.valid(); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator.h -new file mode 100644 -index 0000000..839d8f5 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator.h -@@ -0,0 +1,1880 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of tiles from pitch-linear rank=2 tensors. -+ -+ This iterator uses masks to guard out-of-bounds accesses. The first tile this -+ iterator visits maybe partial, then the remaining tiles are complete. So, we -+ only need to compute the predicates twice, once before the first tile and -+ once for the remaining full tiles which can share the same predicates. -+ -+ A precomputed "Params" object minimizes the amount of state that must be stored in registers, -+ and integer addition is used to advance the pointer through memory. -+*/ -+ -+#pragma once -+ -+#include "cutlass/arch/memory.h" -+#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// PredicatedTileIterator -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+/// Regular tile iterator using a precomputed control structure to minimize register liveness -+/// and integer arithmetic. -+/// -+/// Layout is assumed to be invariant at the time the precomputed "Params" object is constructed. -+/// -+/// Base pointer and tensor extents may be specified at the time the iterator is constructed. -+/// Subsequently, they are assumed to be immutable. -+/// -+/// Adding a logical coordinate offset may be performed at the time the iterator is constructed. -+/// Subsequent additions to logical coordinate offset may be performed but are relatively expensive. -+/// -+/// Visitation order is intended to first visit a "residual" tile that may be partially full in -+/// both the advance dimension and the steady-state dimension. This is assumed to be the last -+/// tile in the iteration sequence. Advancing an iterator that has just been constructed moves to -+/// the first tile that is full in the advance dimension and recomputes predicates. Subsequent -+/// accesses may be performed without updating internal predicates and are efficient in terms of -+/// live register state and pointer arithmetic instructions. -+/// -+/// To be efficient, this assumes the iterator will be dereferenced and advanced at least once -+/// outside any looping structure to minimize integer arithmetic. -+/// -+/// Acceses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing -+/// the iterator. -+/// -+/// -+/// Example: -+/// -+/// An efficient pipeline structure may be constructed as follows: -+/// -+// template -+// __global__ void kernel( -+// typename Iterator::Params params, -+// typename Iterator::Element *ptr, -+// TensorCoord extent) { -+// -+// typename Iterator::Fragment fragment; -+// -+// TensorCoord threadblock_offset(0, 0); -+// -+// Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offsets); -+// -+// -+// fragment = *iter; // load "residue" tile first -+// ++iter; // advance to first "steady state" tile and update internal masks -+// -+// -+// #pragma unroll -+// for (int i = Remaining - 1; i >= 0; --i) { -+// -+// f(fragment); -+// -+// if (!i) { -+// iter.clear_mask(); // light-weight operation to clear masks - subsequent loads become NO-OPs. -+// } -+// -+// fragment = *iter; // load tile during "steady state" phase -+// ++iter; // advance to next tile - lightweight due to steady-state masks -+// } -+// } -+// -+// void host(TensorView view) { -+// -+// using Iterator = transform::threadblock::PredicatedTileIterator; -+// -+// typename Iterator::Params params(view.layout()); -+// -+// kernel(params, view.data()); -+// } -+/// -+/// -+template < -+ typename Shape, -+ typename Element, -+ typename Layout, -+ int AdvanceRank, -+ typename ThreadMap, -+ int AccessSize = ThreadMap::kElementsPerAccess, -+ bool Gather = false -+> -+class PredicatedTileIterator; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIterator for pitch-linear data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class PredicatedTileIterator { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ /// Type used for internal memory accesses -+ using AccessType = AlignedArray::value / 8)>; -+ -+ /// Underlying iterator to compute the addresses -+ using TileAccessIterator = -+ PredicatedTileAccessIterator; -+ -+ static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename TileAccessIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ public: -+ using Base = typename TileAccessIterator::Params::Base; -+ -+ friend PredicatedTileIterator; -+ -+ private: -+ /// Parameters object -+ typename TileAccessIterator::Params params_; -+ -+ public: -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) : params_(layout) {} -+ -+ /// Default constructor -+ Params() = default; -+ -+ CUTLASS_HOST_DEVICE -+ Params(Base const &base) -+ : params_(base) {} -+ }; -+ -+ private: -+ /// Internal pointer type permits fast address arithmetic -+ using BytePointer = char *; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Data member to the tile access iterator -+ TileAccessIterator address_iterator_; -+ -+ public: -+ -+ /// Default constructor -+ PredicatedTileIterator() = default; -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset, -+ /// Gather indices -+ int const *indices = nullptr) -+ : address_iterator_(params.params_, pointer, extent, thread_id, -+ threadblock_offset, indices) {} -+ -+ /// Construct a PredicatedTileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ address_iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator &operator++() { -+ if (kAdvanceRank) -+ address_iterator_.add_tile_offset({0, 1}); -+ else -+ address_iterator_.add_tile_offset({1, 0}); -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator operator++(int) { -+ PredicatedTileIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { address_iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { address_iterator_.get_mask(mask); } -+ -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ load_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ -+ int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); -+ -+ address_iterator_.set_iteration_index(idx); -+ char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; -+ -+ AccessType const *access_ptr = reinterpret_cast(byte_ptr); -+ -+ cutlass::arch::global_load( -+ frag_ptr[idx], access_ptr, address_iterator_.valid()); -+ -+ ++address_iterator_; -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { load_with_byte_offset(frag, 0); } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ store_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { -+ address_iterator_.set_iteration_index(0); -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ -+ int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); -+ -+ char *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; -+ AccessType *access_ptr = reinterpret_cast(byte_ptr); -+ -+ if (address_iterator_.valid()) { -+ *access_ptr = frag_ptr[idx]; -+ } -+ ++address_iterator_; -+ } -+ } -+ } -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { store_with_byte_offset(frag, 0); } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIterator for pitch-linear data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int AccessSize, -+ bool Gather -+> -+class PredicatedTileIterator { -+public: -+ -+ static_assert(AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedTileIterator< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 0 : 1), -+ ThreadMap, -+ AccessSize, -+ Gather -+ >; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ -+ friend PredicatedTileIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) -+ {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const &base) -+ : params_(base) {} -+ }; -+ -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+public: -+ -+ /// Default constructor -+ PredicatedTileIterator() = default; -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id, ///< ID of each participating thread -+ TensorCoord const &threadblock_offset, ///< Initial offset of threadblock -+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization -+ ): -+ iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord(extent.row(), extent.column()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column()), -+ indices) -+ { } -+ -+ /// Construct a PredicatedTileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ): PredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator operator++(int) { -+ PredicatedTileIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { -+ iterator_.store_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIterator for pitch-linear data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int AccessSize, -+ bool Gather -+> -+class PredicatedTileIterator { -+public: -+ -+ static_assert(AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedTileIterator< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 1 : 0), -+ ThreadMap, -+ AccessSize, -+ Gather -+ >; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ -+ friend PredicatedTileIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const &base) -+ : params_(base) {} -+ -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+public: -+ -+ /// Default constructor -+ PredicatedTileIterator() = default; -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id, ///< ID of each participating thread -+ TensorCoord const &threadblock_offset, ///< Initial offset of threadblock -+ int const *indices = nullptr ///< Gather indices -+ ): -+ iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord(extent.column(), extent.row()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row()), -+ indices -+ ) { } -+ -+ /// Construct a PredicatedTileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ): PredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator operator++(int) { -+ PredicatedTileIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { -+ iterator_.store_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIterator for affine rank-2 data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class PredicatedTileIterator, AdvanceRank, -+ ThreadMap_, AccessSize, false> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::AffineRankN<2>; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ /// Type used for internal memory accesses -+ using AccessType = AlignedArray::value / 8)>; -+ -+ /// Underlying iterator to compute the addresses -+ using TileAccessIterator = -+ PredicatedTileAccessIterator; -+ -+ static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename TileAccessIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ public: -+ -+ friend PredicatedTileIterator; -+ -+ private: -+ /// Parameters object -+ typename TileAccessIterator::Params params_; -+ -+ public: -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) : params_(layout) {} -+ -+ /// Default constructor -+ Params() = default; -+ }; -+ -+ private: -+ /// Internal pointer type permits fast address arithmetic -+ using BytePointer = char *; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Data member to the tile access iterator -+ TileAccessIterator address_iterator_; -+ -+ public: -+ -+ /// Default constructor -+ PredicatedTileIterator() = default; -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset, -+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization -+ ) -+ : address_iterator_(params.params_, pointer, extent, thread_id, -+ threadblock_offset) {} -+ -+ /// Construct a PredicatedTileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ address_iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator &operator++() { -+ if (kAdvanceRank) -+ address_iterator_.add_tile_offset(make_Coord(0, 1)); -+ else -+ address_iterator_.add_tile_offset(make_Coord(1, 0)); -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator operator++(int) { -+ PredicatedTileIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { address_iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { address_iterator_.get_mask(mask); } -+ -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ load_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ -+ int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); -+ -+ address_iterator_.set_iteration_index(idx); -+ char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; -+ -+ AccessType const *access_ptr = reinterpret_cast(byte_ptr); -+ -+ cutlass::arch::global_load( -+ frag_ptr[idx], access_ptr, address_iterator_.valid()); -+ -+ ++address_iterator_; -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { load_with_byte_offset(frag, 0); } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ store_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { -+ address_iterator_.set_iteration_index(0); -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ -+ int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); -+ -+ char *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; -+ AccessType *access_ptr = reinterpret_cast(byte_ptr); -+ -+ if (address_iterator_.valid()) { -+ *access_ptr = frag_ptr[idx]; -+ } -+ ++address_iterator_; -+ } -+ } -+ } -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { store_with_byte_offset(frag, 0); } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIterator for affine rank 2 column-major data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int AccessSize -+> -+class PredicatedTileIterator { -+public: -+ -+ static_assert(AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::AffineRank2ColumnMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ // Map to the underlying AffineRankN<2> layout -+ using UnderlyingIterator = PredicatedTileIterator< -+ layout::PitchLinearShape, -+ Element, -+ layout::AffineRankN<2>, -+ (kAdvanceRank == 0 ? 0 : 1), -+ ThreadMap, -+ AccessSize -+ >; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ -+ friend PredicatedTileIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Construct the Params object given an AffineRankN<2> tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout): params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))) -+ {} -+ }; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Underlying AffineRankN<2> tile iterator -+ UnderlyingIterator iterator_; -+ -+public: -+ -+ /// Default constructor -+ PredicatedTileIterator() = default; -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id, ///< ID of each participating thread -+ TensorCoord const &threadblock_offset, ///< Initial offset of threadblock -+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization -+ ): -+ iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord(extent.row(), extent.column()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column()) -+ ) { } -+ -+ /// Construct a PredicatedTileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ): PredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator operator++(int) { -+ PredicatedTileIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { -+ iterator_.store_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIterator for affine rank 2 row-major data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int AccessSize -+> -+class PredicatedTileIterator { -+public: -+ -+ static_assert(AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::AffineRank2RowMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ // Map to the underlying AffineRankN<2> layout -+ using UnderlyingIterator = PredicatedTileIterator< -+ layout::PitchLinearShape, -+ Element, -+ layout::AffineRankN<2>, -+ (kAdvanceRank == 0 ? 1 : 0), -+ ThreadMap, -+ AccessSize -+ >; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ -+ friend PredicatedTileIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Construct the Params object given an AffineRankN<2> tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout): params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))) {} -+ }; -+ -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Underlying AffineRankN<2> tile iterator -+ UnderlyingIterator iterator_; -+ -+public: -+ -+ /// Default constructor -+ PredicatedTileIterator() = default; -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id, ///< ID of each participating thread -+ TensorCoord const &threadblock_offset, ///< Initial offset of threadblock -+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization -+ ): -+ iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord(extent.column(), extent.row()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row()) -+ ) { } -+ -+ /// Construct a PredicatedTileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ): PredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator operator++(int) { -+ PredicatedTileIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { -+ iterator_.store_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIterator for interleaved data. It is mapped -+/// to the congruous layout. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+ -+template -+class PredicatedTileIterator, -+ AdvanceRank, ThreadMap_, AccessSize, false> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ static int const kInterleavedK = InterleavedK; -+ using Layout = layout::ColumnMajorInterleaved; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedTileIterator< -+ layout::PitchLinearShape, -+ Element, layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessSize>; -+ -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) -+ : params_(layout::PitchLinear(layout.stride(0))) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const &base) -+ : params_(base) {} -+ -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ -+ /// Default constructor -+ PredicatedTileIterator() = default; -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset, -+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization -+ ) -+ : iterator_(params.params_, pointer, -+ layout::PitchLinearCoord(extent.row() * kInterleavedK, -+ extent.column() / kInterleavedK), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.row() * kInterleavedK, -+ threadblock_offset.column() / kInterleavedK)) {} -+ -+ /// Construct a PredicatedTileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator operator++(int) { -+ PredicatedTileIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { iterator_.get_mask(mask); } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIterator for interleaved-32 data. It is -+/// mapped to the congruous layout. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class PredicatedTileIterator, -+ AdvanceRank, ThreadMap_, AccessSize, false> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ static int const kInterleavedK = InterleavedK; -+ using Layout = layout::RowMajorInterleaved; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedTileIterator< -+ layout::PitchLinearShape, -+ Element, layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessSize>; -+ -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ friend PredicatedTileIterator; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ /// Default constructor -+ Params() = default; -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) -+ : params_(layout::PitchLinear(layout.stride(0))) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const &base) -+ : params_(base) {} -+ }; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ -+ /// Default constructor -+ PredicatedTileIterator() = default; -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset, -+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization -+ ) -+ : iterator_(params.params_, pointer, -+ layout::PitchLinearCoord(extent.column() * kInterleavedK, -+ extent.row() / kInterleavedK), -+ thread_id, -+ layout::PitchLinearCoord( -+ threadblock_offset.column() * kInterleavedK, -+ threadblock_offset.row() / kInterleavedK)) {} -+ -+ /// Construct a PredicatedTileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileIterator(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator operator++(int) { -+ PredicatedTileIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { iterator_.get_mask(mask); } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h -new file mode 100644 -index 0000000..0a685fc ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h -@@ -0,0 +1,787 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of tiles from pitch-linear rank=2 tensors. -+ -+ This iterator uses masks to guard out-of-bounds accesses and visits the last "residue" tile -+ first, with the objective of minimizing predicate mask updates during steady-state operation. -+ -+ A precomputed "Params" object minimizes the amount of state that must be stored in registers, -+ and integer addition is used to advance the pointer through memory. -+*/ -+ -+#pragma once -+ -+#include "cutlass/transform/threadblock/predicated_tile_access_iterator_2dthreadtile.h" -+#include "cutlass/transform/thread/transpose.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// PredicatedTileIterator2dThreadTile -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+/// Regular tile iterator using a precomputed control structure to minimize register liveness -+/// and integer arithmetic. -+/// -+/// Layout is assumed to be invariant at the time the precomputed "Params" object is constructed. -+/// -+/// Base pointer and tensor extents may be specified at the time the iterator is constructed. -+/// Subsequently, they are assumed to be immutable. -+/// -+/// Adding a logical coordinate offset may be performed at the time the iterator is constructed. -+/// Subsequent additions to logical coordinate offset may be performed but are relatively expensive. -+/// -+/// Vistitation order is intended to first visit a "residual" tile that may be partially full in -+/// both the advance dimension and the steady-state dimension. This is assumed to be the last -+/// tile in the iteration sequence. Advancing an iterator that has just been constructed moves to -+/// the first tile that is full in the advance dimension and recomputes predicates. Subsequent -+/// accesses may be performed without updating internal predicates and are efficient in terms of -+/// live register state and pointer arithmetic instructions. -+/// -+/// To be efficient, this assumes the iteraor will be dereferenced and advanced at least once -+/// outside any looping structure to minimize integer arithmetic. -+/// -+/// Acceses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing -+/// the iterator. -+/// -+/// -+/// Example: -+/// -+/// An efficient pipeline structure may be constructed as follows: -+/// -+// template -+// __global__ void kernel( -+// typename Iterator::Params params, -+// typename Iterator::Element *ptr, -+// TensorCoord extent) { -+// -+// typename Iterator::Fragment fragment; -+// -+// TensorCoord threadblock_offset(0, 0); -+// -+// Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offsets); -+// -+// -+// fragment = *iter; // load "residue" tile first -+// ++iter; // advance to first "steady state" tile and update internal masks -+// -+// -+// #pragma unroll -+// for (int i = Remaining - 1; i >= 0; --i) { -+// -+// f(fragment); -+// -+// if (!i) { -+// iter.clear_mask(); // light-weight operation to clear masks - subsequent loads become NO-OPs. -+// } -+// -+// fragment = *iter; // load tile during "steady state" phase -+// ++iter; // advance to next tile - lightweight due to steady-state masks -+// } -+// } -+// -+// void host(TensorView view) { -+// -+// using Iterator = transform::threadblock::PredicatedTileIterator2dThreadTile; -+// -+// typename Iterator::Params params(view.layout()); -+// -+// kernel(params, view.data()); -+// } -+/// -+/// -+template < -+ typename Shape, -+ typename Element, -+ typename Layout, -+ int AdvanceRank, -+ typename ThreadMap, -+ bool Transpose = false -+> -+class PredicatedTileIterator2dThreadTile; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIterator2dThreadTile for pitch-linear data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class PredicatedTileIterator2dThreadTile { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ /// Type used for internal memory accesses -+ /// extra set of parenthesis is needed for VS compiler -+ struct alignas((ThreadMap::kElementsPerAccess * sizeof_bits::value / -+ 8)) AccessType { -+ -+ Array storage; -+ -+ static int const kElements = ThreadMap::kElementsPerAccess; -+ }; -+ -+ /// Optinally this fragment can be 4x4 transposed -+ using Transform = thread::Transpose< ThreadMap::Iterations::kCount * ThreadMap::ThreadAccessShape::kCount , layout::PitchLinearShape<4,4>, Element>; -+ static bool const transpose = Transpose_; -+ -+ /// Underlying iterator to compute the addresses -+ using TileAccessIterator = -+ PredicatedTileAccessIterator2dThreadTile; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename TileAccessIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ public: -+ using Base = typename TileAccessIterator::Params::Base; -+ -+ friend PredicatedTileIterator2dThreadTile; -+ -+ private: -+ /// Parameters object -+ typename TileAccessIterator::Params params_; -+ -+ public: -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) : params_(layout) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Base const &base) -+ : params_(base) {} -+ }; -+ -+ private: -+ /// Internal pointer type permits fast address arithmetic -+ using BytePointer = char *; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Data member to the tile access iterator -+ TileAccessIterator address_iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator2dThreadTile( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset, -+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization -+ ) -+ : address_iterator_(params.params_, pointer, extent, thread_id, -+ threadblock_offset) {} -+ -+ /// Construct a PredicatedTileIterator2dThreadTile with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator2dThreadTile( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileIterator2dThreadTile(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ address_iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator2dThreadTile &operator++() { -+ if (kAdvanceRank) -+ address_iterator_.add_tile_offset({0, 1}); -+ else -+ address_iterator_.add_tile_offset({1, 0}); -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator2dThreadTile operator++(int) { -+ PredicatedTileIterator2dThreadTile self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { address_iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { address_iterator_.get_mask(mask); } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int ts = 0; ts < ThreadMap::ThreadAccessShape::kStrided; ts++){ -+ -+ int access_idx = ts + c * ThreadMap::ThreadAccessShape::kStrided + \ -+ s * ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided; -+ -+ address_iterator_.set_iteration_index(access_idx); -+ if (address_iterator_.valid()) { -+ -+ frag_ptr[access_idx] = -+ *(address_iterator_.get() + pointer_offset); -+ } -+ -+ ++address_iterator_; -+ } -+ } -+ } -+ -+ if (transpose) { -+ Transform t; -+ t.transform(frag, frag); -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int ts = 0; ts < ThreadMap::ThreadAccessShape::kStrided; ts++){ -+ -+ int access_idx = ts + c * ThreadMap::ThreadAccessShape::kStrided + \ -+ s * ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided; -+ -+ address_iterator_.set_iteration_index(access_idx); -+ if (address_iterator_.valid()) { -+ *(address_iterator_.get() + pointer_offset) = frag_ptr[access_idx]; -+ } -+ ++address_iterator_; -+ } -+ } -+ } -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIterator2dThreadTile for pitch-linear data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ bool Transpose_ -+> -+class PredicatedTileIterator2dThreadTile { -+public: -+ -+ static_assert(AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ static bool const Transpose = Transpose_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedTileIterator2dThreadTile< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 0 : 1), -+ ThreadMap, -+ Transpose -+ >; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ -+ friend PredicatedTileIterator2dThreadTile; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) {} -+ -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const &base) -+ : params_(base) {} -+ }; -+ -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+public: -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator2dThreadTile( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id, ///< ID of each participating thread -+ TensorCoord const &threadblock_offset, ///< Initial offset of threadblock -+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization -+ ): -+ iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord(extent.row(), extent.column()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column()) -+ ) { } -+ -+ /// Construct a PredicatedTileIterator2dThreadTile with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator2dThreadTile( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ): PredicatedTileIterator2dThreadTile(params, pointer, extent, thread_id, make_Coord(0, 0)) { } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator2dThreadTile &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator2dThreadTile operator++(int) { -+ PredicatedTileIterator2dThreadTile self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIterator2dThreadTile for pitch-linear data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ bool Transpose_ -+> -+class PredicatedTileIterator2dThreadTile { -+public: -+ -+ static_assert(AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ static bool const Transpose = Transpose_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedTileIterator2dThreadTile< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 1 : 0), -+ ThreadMap, -+ Transpose -+ >; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ -+ friend PredicatedTileIterator2dThreadTile; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(typename UnderlyingIterator::Params::Base const &base) -+ : params_(base) {} -+ }; -+ -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+public: -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator2dThreadTile( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id, ///< ID of each participating thread -+ TensorCoord const &threadblock_offset, ///< Initial offset of threadblock -+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization -+ ): -+ iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord(extent.column(), extent.row()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row()) -+ ) { } -+ -+ /// Construct a PredicatedTileIterator2dThreadTile with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator2dThreadTile( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ): PredicatedTileIterator2dThreadTile(params, pointer, extent, thread_id, make_Coord(0, 0)) { } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator2dThreadTile &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIterator2dThreadTile operator++(int) { -+ PredicatedTileIterator2dThreadTile self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_triangular_matrix.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_triangular_matrix.h -new file mode 100644 -index 0000000..b849ee7 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_triangular_matrix.h -@@ -0,0 +1,818 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of tiles from pitch-linear rank=2 tensors. -+ -+ This iterator uses masks to guard out-of-bounds accesses and visits the last "residue" tile -+ first, with the objective of minimizing predicate mask updates during steady-state operation. -+ -+ A precomputed "Params" object minimizes the amount of state that must be stored in registers, -+ and integer addition is used to advance the pointer through memory. -+*/ -+ -+#pragma once -+ -+#include "cutlass/arch/memory.h" -+#include "cutlass/transform/threadblock/predicated_tile_access_iterator_triangular_matrix.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// PredicatedTileIteratorTriangularMatrix -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+/// Regular tile iterator using a precomputed control structure to minimize register liveness -+/// and integer arithmetic. -+/// -+/// Layout is assumed to be invariant at the time the precomputed "Params" object is constructed. -+/// -+/// Base pointer and tensor extents may be specified at the time the iterator is constructed. -+/// Subsequently, they are assumed to be immutable. -+/// -+/// Adding a logical coordinate offset may be performed at the time the iterator is constructed. -+/// Subsequent additions to logical coordinate offset may be performed but are relatively expensive. -+/// -+/// Vistitation order is intended to first visit a "residual" tile that may be partially full in -+/// both the advance dimension and the steady-state dimension. This is assumed to be the last -+/// tile in the iteration sequence. Advancing an iterator that has just been constructed moves to -+/// the first tile that is full in the advance dimension and recomputes predicates. Subsequent -+/// accesses may be performed without updating internal predicates and are efficient in terms of -+/// live register state and pointer arithmetic instructions. -+/// -+/// To be efficient, this assumes the iteraor will be dereferenced and advanced at least once -+/// outside any looping structure to minimize integer arithmetic. -+/// -+/// Acceses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing -+/// the iterator. -+/// -+/// -+/// Example: -+/// -+/// An efficient pipeline structure may be constructed as follows: -+/// -+// template -+// __global__ void kernel( -+// typename Iterator::Params params, -+// typename Iterator::Element *ptr, -+// TensorCoord extent) { -+// -+// typename Iterator::Fragment fragment; -+// -+// TensorCoord threadblock_offset(0, 0); -+// -+// Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offsets); -+// -+// -+// fragment = *iter; // load "residue" tile first -+// ++iter; // advance to first "steady state" tile and update internal masks -+// -+// -+// #pragma unroll -+// for (int i = Remaining - 1; i >= 0; --i) { -+// -+// f(fragment); -+// -+// if (!i) { -+// iter.clear_mask(); // light-weight operation to clear masks - subsequent loads become NO-OPs. -+// } -+// -+// fragment = *iter; // load tile during "steady state" phase -+// ++iter; // advance to next tile - lightweight due to steady-state masks -+// } -+// } -+// -+// void host(TensorView view) { -+// -+// using Iterator = transform::threadblock::PredicatedTileIteratorTriangularMatrix; -+// -+// typename Iterator::Params params(view.layout()); -+// -+// kernel(params, view.data()); -+// } -+/// -+/// -+template < -+ typename Shape, -+ typename Element, -+ typename Layout, -+ int AdvanceRank, -+ typename ThreadMap, -+ SideMode kSideMode, -+ FillMode kFillMode, -+ DiagType kDiagType, -+ int AccessSize = ThreadMap::kElementsPerAccess -+> -+class PredicatedTileIteratorTriangularMatrix; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIteratorTriangularMatrix for pitch-linear data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template -+class PredicatedTileIteratorTriangularMatrix { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ /// Type used for internal memory accesses -+ using AccessType = AlignedArray::value / 8)>; -+ -+ /// Underlying iterator to compute the addresses -+ using TileAccessIterator = -+ PredicatedTileAccessIteratorTriangularMatrix; -+ -+ static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename TileAccessIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ public: -+ friend PredicatedTileIteratorTriangularMatrix; -+ -+ private: -+ /// Parameters object -+ typename TileAccessIterator::Params params_; -+ -+ public: -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout) : params_(layout) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ }; -+ -+ private: -+ /// Internal pointer type permits fast address arithmetic -+ using BytePointer = char *; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Data member to the tile access iterator -+ TileAccessIterator address_iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorTriangularMatrix( -+ /// Precomputed parameters object -+ Params const ¶ms, -+ /// Pointer to start of tensor -+ Pointer pointer, -+ /// Extent of tensor -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : address_iterator_(params.params_, pointer, extent, thread_id, -+ threadblock_offset) {} -+ -+ /// Construct a PredicatedTileIteratorTriangularMatrix with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorTriangularMatrix( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : PredicatedTileIteratorTriangularMatrix(params, pointer, extent, thread_id, -+ make_Coord(0, 0)) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ address_iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorTriangularMatrix &operator++() { -+ if (kAdvanceRank) -+ address_iterator_.add_tile_offset({0, 1}); -+ else -+ address_iterator_.add_tile_offset({1, 0}); -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorTriangularMatrix operator++(int) { -+ PredicatedTileIteratorTriangularMatrix self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { address_iterator_.enable_mask(); } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { address_iterator_.get_mask(mask); } -+ -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ load_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ -+ int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); -+ -+ address_iterator_.set_iteration_index(idx); -+ char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; -+ -+ AccessType const *access_ptr = reinterpret_cast(byte_ptr); -+ -+ cutlass::arch::global_load( -+ frag_ptr[idx], access_ptr, address_iterator_.valid()); -+ -+ ++address_iterator_; -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { load_with_byte_offset(frag, 0); } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ store_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { -+ address_iterator_.set_iteration_index(0); -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < kAccessesPerVector; ++v) { -+ -+ int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); -+ -+ char *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; -+ AccessType *access_ptr = reinterpret_cast(byte_ptr); -+ -+ if (address_iterator_.valid()) { -+ *access_ptr = frag_ptr[idx]; -+ } -+ ++address_iterator_; -+ } -+ } -+ } -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { store_with_byte_offset(frag, 0); } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIteratorTriangularMatrix for column-major data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ SideMode kSideMode, -+ FillMode kFillMode, -+ DiagType kDiagType, -+ int AccessSize -+> -+class PredicatedTileIteratorTriangularMatrix { -+public: -+ -+ static_assert(AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedTileIteratorTriangularMatrix< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 0 : 1), -+ ThreadMap, -+ kSideMode, -+ kFillMode, -+ kDiagType, -+ AccessSize -+ >; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ -+ friend PredicatedTileIteratorTriangularMatrix; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) { -+ -+ } -+ }; -+ -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+public: -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorTriangularMatrix( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id, ///< ID of each participating thread -+ TensorCoord const &threadblock_offset ///< Initial offset of threadblock -+ ): -+ iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord(extent.row(), extent.column()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column()) -+ ) { } -+ -+ /// Construct a PredicatedTileIteratorTriangularMatrix with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorTriangularMatrix( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ): PredicatedTileIteratorTriangularMatrix(params, pointer, extent, thread_id, make_Coord(0, 0)) { } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorTriangularMatrix &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorTriangularMatrix operator++(int) { -+ PredicatedTileIteratorTriangularMatrix self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { -+ iterator_.store_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedTileIteratorTriangularMatrix for row-major data. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept | -+/// MaskedTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ SideMode kSideMode, -+ FillMode kFillMode, -+ DiagType kDiagType, -+ int AccessSize -+> -+class PredicatedTileIteratorTriangularMatrix { -+public: -+ -+ static_assert(AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Pointer = Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedTileIteratorTriangularMatrix< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 1 : 0), -+ ThreadMap, -+ kSideMode, -+ kFillMode, -+ kDiagType, -+ AccessSize -+ >; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array; -+ -+ /// Predicate vector stores mask to guard accesses -+ using Mask = typename UnderlyingIterator::Mask; -+ -+ /// Parameters object is precomputed state and is host-constructible -+ class Params { -+ private: -+ -+ friend PredicatedTileIteratorTriangularMatrix; -+ -+ /// Parameters object -+ typename UnderlyingIterator::Params params_; -+ -+ public: -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ /// Construct the Params object given a pitch-linear tensor's layout -+ CUTLASS_HOST_DEVICE -+ Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) { -+ -+ }; -+ }; -+ -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+public: -+ -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorTriangularMatrix( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id, ///< ID of each participating thread -+ TensorCoord const &threadblock_offset ///< Initial offset of threadblock -+ ): -+ iterator_( -+ params.params_, -+ pointer, -+ layout::PitchLinearCoord(extent.column(), extent.row()), -+ thread_id, -+ layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row()) -+ ) { } -+ -+ /// Construct a PredicatedTileIteratorTriangularMatrix with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorTriangularMatrix( -+ Params const ¶ms, ///< Precomputed parameters object -+ Pointer pointer, ///< Pointer to start of tensor -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id ///< ID of each participating thread -+ ): PredicatedTileIteratorTriangularMatrix(params, pointer, extent, thread_id, make_Coord(0, 0)) { } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorTriangularMatrix &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the iterator's -+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls -+ /// are lightweight and must only update the internal pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedTileIteratorTriangularMatrix operator++(int) { -+ PredicatedTileIteratorTriangularMatrix self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void clear_mask(bool enable = true) { -+ iterator_.clear_mask(enable); -+ } -+ -+ /// Clears the predicate set efficiently -+ CUTLASS_HOST_DEVICE -+ void enable_mask() { -+ iterator_.enable_mask(); -+ } -+ -+ /// Sets the predicate mask, overriding value stored in predicate iterator -+ CUTLASS_HOST_DEVICE -+ void set_mask(Mask const &mask) { -+ iterator_.set_mask(mask); -+ } -+ -+ /// Gets the mask -+ CUTLASS_HOST_DEVICE -+ void get_mask(Mask &mask) { -+ iterator_.get_mask(mask); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { -+ iterator_.load_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { -+ iterator_.store_with_byte_offset(frag, byte_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_vector_access_iterator.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_vector_access_iterator.h -new file mode 100644 -index 0000000..4762175 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/predicated_vector_access_iterator.h -@@ -0,0 +1,417 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Templates implementing computing the addresses of loading small -+ vectors from the global memory. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/coord.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/tensor_ref.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// PredicatedVectorAccessIterator -+/// -+template < -+ /// Shape of the vector accessed by the entire threadblock -+ typename Shape, -+ /// Shape of the vector accessed by the warp -+ typename WarpShape, -+ /// Type of Element -+ typename Element, -+ /// Layout of the vector -+ typename Layout, -+ /// Number of elements for each access -+ int ElementsPerAccess, -+ /// Support residual tile -+ bool EnableResidualAccess = false -+> -+class PredicatedVectorAccessIterator; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Vector access iterator specialized for vectors, e.g. scale and bias -+/// Thread arrangements are for TensorOps -+/// -+template < -+ typename Shape_, -+ typename WarpShape_, -+ typename Element_, -+ int ElementsPerAccess, -+ bool EnableResidualAccess -+> -+class PredicatedVectorAccessIterator < -+ Shape_, -+ WarpShape_, -+ Element_, -+ layout::PitchLinear, -+ ElementsPerAccess, -+ EnableResidualAccess -+> { -+ public: -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ConstPointer = const Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+// static int const kElementsPerAccess = 128 / sizeof_bits::value; -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static int const kThreads = 32; -+ static int const kRowsPerIteration = 8; -+ static int const kThreadsPerRow = kThreads / kRowsPerIteration; -+ static int const kThreadsPerRowMask = 0x3; -+ static int const kIterations = WarpShape::kContiguous / (kThreadsPerRow * kElementsPerAccess); -+ static int const kWarpCountStrided = Shape::kStrided / WarpShape::kStrided; -+ -+ using AccessType = AlignedArray; -+ -+ private: -+ /// Internal pointer type permits fast address arithmetic -+ using BytePointer = char *; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Internal pointer to first access of tile -+ BytePointer pointer_; -+ -+ /// Extent of tensor -+ TensorCoord extent_; -+ -+ /// pointer offset of each thread -+ TensorCoord thread_offset_; -+ -+ /// iteration index -+ LongIndex iteration_; -+ -+ /// residual access -+ bool is_residual_; -+ -+ /// residual offset of each thread -+ TensorCoord residual_offset_; -+ -+ public: -+ /// Constructs a vector access iterator -+ CUTLASS_HOST_DEVICE -+ PredicatedVectorAccessIterator( -+ /// Pointer to the start of the vector -+ ConstPointer pointer, -+ /// Extent of vector -+ TensorCoord extent, -+ /// ID of each participating thread -+ int thread_id, -+ /// ID of each participating warp -+ int warp_id, -+ /// Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : pointer_(reinterpret_cast( -+ const_cast(pointer))), -+ extent_(extent), -+ is_residual_(false) { -+ -+ -+ int warp_offset = (warp_id / kWarpCountStrided) * WarpShape::kContiguous; -+ -+ // Per-thread offset in logical coordinates of tensor -+ -+ thread_offset_ = threadblock_offset + TensorCoord(warp_offset, 0) + -+ TensorCoord((thread_id & kThreadsPerRowMask) * kElementsPerAccess, 0); -+ -+ set_iteration_index(0); -+ -+ if(EnableResidualAccess) { -+ // compute residual offset -+ typename TensorCoord::Index residual_size = extent_.contiguous() % WarpShape::kContiguous; -+ if (residual_size) { -+ is_residual_ = true; -+ residual_offset_ = make_Coord(residual_size, 0); -+ } -+ } -+ } -+ -+ /// Construct a PredicatedVectorAccessIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedVectorAccessIterator( -+ /// Pointer to start of vector -+ ConstPointer pointer, -+ /// Extent of vector -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id, -+ /// ID of each participating warp -+ int warp_id) -+ : PredicatedVectorAccessIterator(pointer, extent, thread_id, warp_id, -+ make_Coord(0, 0)) {} -+ -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ iteration_ = index; -+ } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles -+ CUTLASS_DEVICE -+ void add_tile_offset( -+ TensorCoord const &tile_offset) { -+ -+ thread_offset_ = -+ thread_offset_ + -+ TensorCoord(WarpShape::kContiguous * tile_offset.contiguous(), 0); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ -+ return reinterpret_cast( -+ pointer_ + -+ ((thread_offset_.contiguous() + iteration_ * kThreadsPerRow * kElementsPerAccess) -+ * sizeof_bits::value / 8)); -+ } -+ -+ /// Increment and return an instance to self. -+ CUTLASS_HOST_DEVICE -+ PredicatedVectorAccessIterator &operator++() { -+ ++iteration_; -+ if(iteration_ >= kIterations) -+ iteration_ = 0; -+ -+ return *this; -+ } -+ -+ /// Increment and return an instance to self. -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ if(EnableResidualAccess && is_residual_) { -+ is_residual_ = false; -+ thread_offset_ += residual_offset_; -+ } -+ else -+ add_tile_offset(TensorCoord(1, 0)); -+ } -+ -+ /// Increment and return an instance to self. -+ CUTLASS_HOST_DEVICE -+ PredicatedVectorAccessIterator operator++(int) { -+ PredicatedVectorAccessIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return ((thread_offset_.contiguous() + -+ iteration_ * kThreadsPerRow * kElementsPerAccess) < extent_.contiguous()); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Specialization of PredicatedVectorAccessIterator for row-major data. -+/// -+template < -+ typename Shape_, -+ typename WarpShape_, -+ typename Element_, -+ int ElementsPerAccess, -+ bool EnableResidualAccess -+> -+class PredicatedVectorAccessIterator< -+ Shape_, -+ WarpShape_, -+ Element_, -+ layout::RowMajor, -+ ElementsPerAccess, -+ EnableResidualAccess -+> { -+ public: -+ -+ using Shape = Shape_; -+ using WarpShape = WarpShape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorView = TensorView; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ConstPointer = const Element *; -+ using NonConstPointer = typename platform::remove_const::type *; -+ -+ using UnderlyingIterator = PredicatedVectorAccessIterator< -+ layout::PitchLinearShape, -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear, -+ ElementsPerAccess, -+ EnableResidualAccess>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ static int const kElementsPerAccess = UnderlyingIterator::kElementsPerAccess; -+ static int const kRowsPerIteration = UnderlyingIterator::kRowsPerIteration; -+ static int const kThreads = UnderlyingIterator::kThreads; -+ static int const kIterations = UnderlyingIterator::kIterations; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Underlying pitch-linear tile iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Constructs a TileIterator from its precomputed state, threadblock offset, -+ /// and thread ID -+ CUTLASS_HOST_DEVICE -+ PredicatedVectorAccessIterator( -+ ///< Pointer to the start of the vector -+ ConstPointer pointer, -+ ///< Extent of tensor -+ TensorCoord extent, -+ ///< ID of each participating thread -+ int thread_id, -+ ///< ID of each participating warp -+ int warp_id, -+ ///< Initial offset of threadblock -+ TensorCoord const &threadblock_offset) -+ : iterator_(pointer, layout::PitchLinearCoord(extent.column(), extent.row()), -+ thread_id, warp_id, -+ layout::PitchLinearCoord(threadblock_offset.column(), -+ threadblock_offset.row())) {} -+ -+ /// Construct a PredicatedVectorAccessIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ PredicatedVectorAccessIterator( -+ ConstPointer pointer, ///< Pointer to the start of the vector -+ TensorCoord extent, ///< Extent of tensor -+ int thread_id, ///< ID of each participating thread -+ int warp_id ///< ID of each participating warp -+ ) -+ : PredicatedVectorAccessIterator(pointer, extent, thread_id, warp_id, -+ make_Coord(0, 0)) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Advances an iterator along logical dimensions of matrix in units of whole -+ /// tiles -+ CUTLASS_HOST_DEVICE -+ void add_tile_offset(TensorCoord const &tile_offset) { -+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedVectorAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ /// -+ /// The first time this method is called, predicates are updated, and the -+ /// iterator's internal pointer is reverted to the first "steady state" tile. -+ /// Subsequent calls are lightweight and must only update the internal -+ /// pointer. -+ CUTLASS_HOST_DEVICE -+ PredicatedVectorAccessIterator operator++(int) { -+ PredicatedVectorAccessIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Increment and return an instance to self. -+ CUTLASS_HOST_DEVICE -+ void advance() { -+ iterator_.advance(); -+ } -+ -+ /// Returns whether access is valid or not -+ CUTLASS_HOST_DEVICE -+ bool valid() { -+ return iterator_.valid(); -+ } -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h -new file mode 100644 -index 0000000..1de3e65 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h -@@ -0,0 +1,253 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Templates implementing computing the addresses of storing of small -+ scale and bias vectors in the shared memory. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/tensor_ref.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// RegularScaleBiasVectorAccessIterator -+/// -+template -+class RegularScaleBiasVectorAccessIterator; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for congruous arrangements for TensorOps -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularScaleBiasVectorAccessIterator { -+ public: -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ /// Element type per access -+ static int const kElementsPerAccess = 128 / sizeof_bits::value; -+ static int const kThreads = Shape::kContiguous / kElementsPerAccess; -+ using AccessType = Array; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Internal pointer -+ AccessType *pointer_; -+ -+ /// Internal byte offset -+ Index byte_offset_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularScaleBiasVectorAccessIterator( -+ TensorRef scale_bias_ref, ///< Pointer to the start of the scale and bias -+ ///< vector -+ int thread_id ///< ID of each participating thread -+ ) -+ : byte_offset_(0) { -+ // Per-thread offset in logical coordinates of tensor -+ int thread_offset = thread_id * kElementsPerAccess; -+ -+ // initialize pointer -+ pointer_ = -+ reinterpret_cast(scale_bias_ref.data() + thread_offset); -+ -+ set_iteration_index(0); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ byte_offset_ += pointer_offset * sizeof(Element); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_DEVICE -+ AccessType *get() const { -+ -+ char *access_byte_ptr = -+ reinterpret_cast(pointer_); -+ -+ return reinterpret_cast(access_byte_ptr + byte_offset_); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularScaleBiasVectorAccessIterator &operator++() { return *this; } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularScaleBiasVectorAccessIterator operator++(int) { -+ RegularScaleBiasVectorAccessIterator prev(*this); -+ this->operator++(); -+ -+ return prev; -+ } -+ -+ /// Adds a tile offset in the unit of tile. -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ // Multiply by 2 because we store scale and bias belong to the same stage -+ // next to each other. -+ add_pointer_offset(coord.contiguous() * Shape::kContiguous * 2); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for row major layouts -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularScaleBiasVectorAccessIterator< -+ Shape_, Element_, -+ layout::RowMajor> { -+ public: -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularScaleBiasVectorAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::PitchLinear>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ private: -+ -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularScaleBiasVectorAccessIterator( -+ TensorRef scale_bias_ref, ///< Pointer to the start of the scale and bias -+ ///< vector -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({scale_bias_ref.data(), scale_bias_ref.stride()}, thread_id) { -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.column(), coord.row()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularScaleBiasVectorAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularScaleBiasVectorAccessIterator operator++(int) { -+ RegularScaleBiasVectorAccessIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator.h -new file mode 100644 -index 0000000..a3e30c2 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator.h -@@ -0,0 +1,58 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing the address computation of storing of tiles -+ from pitch-linear rank=2 tensors. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template ::value* ThreadMap::kElementsPerAccess / 8> -+class RegularTileAccessIterator; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h -new file mode 100644 -index 0000000..bba9f66 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h -@@ -0,0 +1,408 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing computing the addresses of storing of tiles -+ from pitch-linear rank=2 tensors. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/tensor_ref.h" -+ -+#include "cutlass/transform/threadblock/regular_tile_access_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for congruous arrangements for TensorOps -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator< -+ Shape_, Element_, -+ layout::PitchLinear, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using StrideIndex = typename Layout::Stride::Index; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Element type per access -+ using AccessType = Array; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Stride value -+ StrideIndex stride_; -+ -+ /// Internal pointer to first access of tile -+ AccessType *pointer_; -+ -+ /// Internal byte offset -+ Index byte_offset_; -+ -+ /// Iteration in the contiguous dimension -+ int iteration_contiguous_; -+ -+ /// Iteration in the strided dimension -+ int iteration_strided_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : stride_(ref.stride(0) / ThreadMap::kElementsPerAccess), -+ byte_offset_(0) { -+ -+ layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); -+ -+ // initialize pointer -+ pointer_ = reinterpret_cast(ref.data() + ref.offset(thread_offset_base)); -+ -+ set_iteration_index(0); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ byte_offset_ += pointer_offset * sizeof(Element); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_DEVICE -+ AccessType *get() const { -+ -+ AccessType *access_ptr = pointer_; -+ -+ int access_offset = iteration_strided_ * ThreadMap::Delta::kStrided * stride_ + -+ iteration_contiguous_ * ThreadMap::Delta::kContiguous / -+ ThreadMap::kElementsPerAccess; -+ -+ char *access_byte_ptr = -+ reinterpret_cast(access_ptr + access_offset); -+ -+ return reinterpret_cast(access_byte_ptr + byte_offset_); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iteration_contiguous_; -+ -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) -+ return *this; -+ -+ // Enter here only if (iteration_contiguous_ == -+ // ThreadMap::Iteration::kContiguous) -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ -+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) -+ // which means we enter the next tile. -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ RegularTileAccessIterator prev(*this); -+ this->operator++(); -+ -+ return prev; -+ } -+ -+ /// Adds a tile offset in the unit of tile. -+ /// In GEMM/Conv implementation, this is used to move in the k dimension in the shared memory. -+ /// Below layouts are the shared memory layouts. Current SM50 SIMT kernels only use col major A and row major B. -+ /// For row major A operand, k dimension is contiguous dimension; -+ /// For col major A operand, k dimension is strided dimension; -+ /// For row major B operand, k dimension is strided dimension; -+ /// For col major B operand, k dimension is contiguous dimension. -+ /// Below two classes map col/row major to the pitch linear coordinates used -+ /// in this base class. -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ add_pointer_offset(coord.contiguous() * Shape::kContiguous + -+ coord.strided() * Shape::kStrided * stride_ * -+ ThreadMap::kElementsPerAccess); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for column major layouts -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator< -+ Shape_, Element_, -+ layout::ColumnMajor, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 0 : 1), -+ ThreadMap_>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ private: -+ -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.row(), coord.column()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ RegularTileAccessIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for row major layouts -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator< -+ Shape_, Element_, -+ layout::RowMajor, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 1 : 0), -+ ThreadMap_>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ private: -+ -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.column(), coord.row()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ RegularTileAccessIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear_direct_conv.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear_direct_conv.h -new file mode 100644 -index 0000000..938b419 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear_direct_conv.h -@@ -0,0 +1,587 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing computing the addresses of storing of tiles -+ from pitch-linear rank=2 tensors. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/tensor_ref.h" -+ -+#include "cutlass/transform/threadblock/regular_tile_access_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template ::value* ThreadMap::kElementsPerAccess / 8 -+ > -+class RegularTileAccessIteratorDirectConv; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for congruous arrangements for TensorOps with dynamic_iterations OFF -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIteratorDirectConv< -+ Shape_, Element_, -+ layout::PitchLinear, -+ AdvanceRank, ThreadMap_, false, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using StrideIndex = typename Layout::Stride::Index; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Element type per access -+ using AccessType = Array; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Stride value -+ StrideIndex stride_; -+ -+ /// Internal pointer to first access of tile -+ AccessType *pointer_; -+ -+ /// Internal byte offset -+ Index byte_offset_; -+ -+ /// Iteration in the contiguous dimension -+ int iteration_contiguous_; -+ -+ /// Iteration in the strided dimension -+ int iteration_strided_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIteratorDirectConv(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : stride_(ref.stride(0) / ThreadMap::kElementsPerAccess), -+ byte_offset_(0) { -+ -+ layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); -+ -+ // initialize pointer -+ pointer_ = reinterpret_cast(ref.data() + ref.offset(thread_offset_base)); -+ -+ set_iteration_index(0); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_num(int num) { -+ //Do nothing -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ byte_offset_ += pointer_offset * sizeof(Element); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_DEVICE -+ AccessType *get() const { -+ -+ AccessType *access_ptr = pointer_; -+ -+ int access_offset = iteration_strided_ * ThreadMap::Delta::kStrided * stride_ + -+ iteration_contiguous_ * ThreadMap::Delta::kContiguous / -+ ThreadMap::kElementsPerAccess; -+ -+ char *access_byte_ptr = -+ reinterpret_cast(access_ptr + access_offset); -+ -+ return reinterpret_cast(access_byte_ptr + byte_offset_); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIteratorDirectConv &operator++() { -+ ++iteration_contiguous_; -+ -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) -+ return *this; -+ -+ // Enter here only if (iteration_contiguous_ == -+ // ThreadMap::Iteration::kContiguous) -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ -+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) -+ // which means we enter the next tile. -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIteratorDirectConv operator++(int) { -+ RegularTileAccessIteratorDirectConv prev(*this); -+ this->operator++(); -+ -+ return prev; -+ } -+ -+ /// Adds a tile offset in the unit of tile. -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ add_pointer_offset(coord.contiguous() * Shape::kContiguous + -+ coord.strided() * ThreadMap::Iterations::kStrided * -+ ThreadMap::Delta::kStrided * stride_ * ThreadMap::kElementsPerAccess); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for congruous arrangements for TensorOps with dynamic_iterations ON -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIteratorDirectConv< -+ Shape_, Element_, -+ layout::PitchLinear, -+ AdvanceRank, ThreadMap_,true, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using StrideIndex = typename Layout::Stride::Index; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Element type per access -+ using AccessType = Array; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Stride value -+ StrideIndex stride_; -+ -+ /// Internal pointer to first access of tile -+ AccessType *pointer_; -+ -+ /// Internal byte offset -+ Index byte_offset_; -+ -+ /// Iteration in the contiguous dimension -+ int iteration_contiguous_; -+ -+ /// Iteration in the strided dimension -+ int iteration_strided_; -+ -+ /// Total iterattions in the strided dimension: Dynamic value -+ int total_iteration_strided_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIteratorDirectConv(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : stride_(ref.stride(0) / ThreadMap::kElementsPerAccess), -+ byte_offset_(0) { -+ -+ layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); -+ -+ // initialize pointer -+ pointer_ = reinterpret_cast(ref.data() + ref.offset(thread_offset_base)); -+ -+ set_iteration_index(0); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_num(int num) { -+ total_iteration_strided_ = num; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ byte_offset_ += pointer_offset * sizeof(Element); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_DEVICE -+ AccessType *get() const { -+ -+ AccessType *access_ptr = pointer_; -+ -+ int access_offset = iteration_strided_ * ThreadMap::Delta::kStrided * stride_ + -+ iteration_contiguous_ * ThreadMap::Delta::kContiguous / -+ ThreadMap::kElementsPerAccess; -+ -+ char *access_byte_ptr = -+ reinterpret_cast(access_ptr + access_offset); -+ -+ return reinterpret_cast(access_byte_ptr + byte_offset_); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIteratorDirectConv &operator++() { -+ ++iteration_contiguous_; -+ -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) -+ return *this; -+ -+ // Enter here only if (iteration_contiguous_ == -+ // ThreadMap::Iteration::kContiguous) -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ -+ if (iteration_strided_ < total_iteration_strided_) { -+ return *this; -+ } -+ -+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) -+ // which means we enter the next tile. -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIteratorDirectConv operator++(int) { -+ RegularTileAccessIteratorDirectConv prev(*this); -+ this->operator++(); -+ -+ return prev; -+ } -+ -+ /// Adds a tile offset in the unit of tile. -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ add_pointer_offset(coord.contiguous() * Shape::kContiguous + -+ coord.strided() * total_iteration_strided_ * ThreadMap::Delta::kStrided * stride_ * -+ ThreadMap::kElementsPerAccess); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for column major layouts -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIteratorDirectConv< -+ Shape_, Element_, -+ layout::ColumnMajor, -+ AdvanceRank, ThreadMap_, Dynamic_iterations , Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileAccessIteratorDirectConv< -+ layout::PitchLinearShape, Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 0 : 1), -+ ThreadMap_, -+ Dynamic_iterations>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ private: -+ -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIteratorDirectConv(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_num(int num) { -+ iterator_.set_iteration_num(num); -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.row(), coord.column()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIteratorDirectConv &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIteratorDirectConv operator++(int) { -+ RegularTileAccessIteratorDirectConv prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for row major layouts -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIteratorDirectConv< -+ Shape_, Element_, -+ layout::RowMajor, -+ AdvanceRank, ThreadMap_, Dynamic_iterations, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileAccessIteratorDirectConv< -+ layout::PitchLinearShape, Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 1 : 0), -+ ThreadMap_, -+ Dynamic_iterations>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ private: -+ -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIteratorDirectConv(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_num(int num) { -+ iterator_.set_iteration_num(num); -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.column(), coord.row()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIteratorDirectConv &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIteratorDirectConv operator++(int) { -+ RegularTileAccessIteratorDirectConv prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h -new file mode 100644 -index 0000000..c16daff ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h -@@ -0,0 +1,820 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing computing the addresses of storing of tiles -+ from pitch-linear rank=2 tensors. -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/transform/threadblock/regular_tile_access_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for congruous arrangements for TensorOps -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator< -+ Shape_, Element_, -+ layout::TensorOpMultiplicandCongruous::value, -+ int(128 / sizeof(Element_))>, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = -+ layout::TensorOpMultiplicandCongruous::value, -+ int(128 / sizeof(Element_))>; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using StrideIndex = typename Layout::Stride::Index; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Internal details made public to facilitate introspection -+ struct Detail { -+ /// This iterator is specialized for an access size that is 128 bits in -+ /// length. -+ static int const kAccessSizeInBits = 128; -+ -+ static_assert(sizeof_bits::value * -+ ThreadMap::kElementsPerAccess == -+ kAccessSizeInBits, -+ "This iterator requires a policy whose access size is 128bs"); -+ -+ ///< Number of pointers -+ static int const kPointerCount = -+ (ThreadMap::Iterations::kStrided > 1 ? 2 : 1); -+ }; -+ -+ /// Element type per access -+ using AccessType = Array; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Stride value -+ StrideIndex stride_; -+ -+ /// Internal pointer to first access of tile -+ AccessType *pointer_[Detail::kPointerCount]; -+ -+ /// Internal byte offset -+ Index byte_offset_; -+ -+ /// Iteration in the contiguous dimension -+ int iteration_contiguous_; -+ -+ /// Iteration in the strided dimension -+ int iteration_strided_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : stride_(ref.stride(0) / Layout::kElementsPerAccess), -+ byte_offset_(0) { -+ layout::PitchLinearCoord thread_offset_base = -+ ThreadMap::initial_offset(thread_id); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Detail::kPointerCount; ++i) { -+ // This is the offset of a thread within a threadblock tile for a specific -+ // pointer (units of elements) -+ layout::PitchLinearCoord thread_offset_in_threadblock_tile = -+ thread_offset_base + -+ layout::PitchLinearCoord{ -+ 0, ThreadMap::Detail::WarpThreadArrangement::kStrided * i}; -+ -+ // initialize pointer -+ pointer_[i] = reinterpret_cast( -+ ref.data() + ref.offset(thread_offset_in_threadblock_tile)); -+ } -+ -+ set_iteration_index(0); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ byte_offset_ += pointer_offset * sizeof(Element); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ AccessType *access_ptr = pointer_[iteration_strided_ & 1]; -+ int stride_idx = (iteration_strided_ & ~1); -+ -+ int access_offset = stride_idx * ThreadMap::Delta::kStrided * stride_ + -+ iteration_contiguous_ * ThreadMap::Delta::kContiguous / -+ ThreadMap::kElementsPerAccess; -+ -+ char *access_byte_ptr = -+ reinterpret_cast(access_ptr + access_offset); -+ return reinterpret_cast(access_byte_ptr + byte_offset_); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iteration_contiguous_; -+ -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) -+ return *this; -+ -+ // Enter here only if (iteration_contiguous_ == -+ // ThreadMap::Iteration::kContiguous) -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ -+ // Enter here only if (iteration_strided_ == ThreadMap::Iteration::kStrided) -+ // which means we enter the next tile. -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ RegularTileAccessIterator prev(*this); -+ this->operator++(); -+ -+ return prev; -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ add_pointer_offset(coord.contiguous() * Shape::kContiguous + -+ coord.strided() * Shape::kStrided * stride_ * -+ Layout::kElementsPerAccess); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for column-major congruous TensorOp formats. -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator< -+ Shape_, Element_, -+ layout::ColumnMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(Element_))>, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for column-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(Element_))>; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::TensorOpMultiplicandCongruous::value, -+ int(128 / sizeof(Element_))>, -+ (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ private: -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.row(), coord.column()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ RegularTileAccessIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for row-major congruous TensorOp formats. -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator< -+ Shape_, Element_, -+ layout::RowMajorTensorOpMultiplicandCongruous::value, -+ int(128 / sizeof(Element_))>, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for row-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(Element_))>; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::TensorOpMultiplicandCongruous::value, -+ int(128 / sizeof(Element_))>, -+ (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ private: -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.column(), coord.row()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ RegularTileAccessIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for crosswise arrangements for TensorOps -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator::value, Crosswise>, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = -+ layout::TensorOpMultiplicandCrosswise::value, -+ Crosswise>; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ static int const kCrosswise = Crosswise; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using StrideIndex = typename Layout::Stride::Index; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ static_assert(!(ThreadMap::Delta::kContiguous % kCrosswise), -+ "kCrosswise is the smallest unit in the contiguous dimension " -+ "for shared memory swizzling."); -+ -+ /// Internal details made public to facilitate introspection -+ struct Detail { -+ /// This iterator is specialized for an access size that is 128 bits in -+ /// length. -+ static int const kAccessSizeInBits = 128; -+ -+ static_assert(sizeof_bits::value * -+ ThreadMap::kElementsPerAccess == -+ kAccessSizeInBits, -+ "This iterator requires a policy whose access size is 128bs"); -+ -+ /// Number of pointers -+ /// -+ /// Note:TN kblock32 layouts only needs 1 pointer, but strangely -+ /// reducing pointer count hurts perfomrnace -+ static int const kPointerCount = -+ (ThreadMap::Iterations::kStrided > 1 ? 2 : 1); -+ }; -+ -+ /// Element type per access -+ using AccessType = Array; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Total number of sections. The memory is divided into stages. One stage -+ /// can store one tile. Stage is divided into sections. Interleaved layout -+ /// can have multiple sections in a stage. The rest layout only has one section -+ /// in a stage. -+ int sections_; -+ -+ /// Sections that a stage has -+ int sections_per_stage_; -+ -+ /// Stride value -+ StrideIndex stride_; -+ -+ /// Internal pointer to first access of tile -+ AccessType *pointer_[Detail::kPointerCount]; -+ -+ /// Internal byte offset -+ Index byte_offset_; -+ -+ /// Iteration in the contiguous dimension -+ int iteration_contiguous_; -+ -+ /// Iteration in the strided dimension -+ int iteration_strided_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : sections_(ref.stride(0) / kCrosswise), -+ sections_per_stage_(Shape::kContiguous / kCrosswise), -+ // stride_ = kCrosswise x sections_ x kFactor -+ stride_(ref.stride(0) * Layout::kFactor / Layout::kElementsPerAccess), -+ byte_offset_(0) { -+ layout::PitchLinearCoord thread_offset_base = -+ ThreadMap::initial_offset(thread_id); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Detail::kPointerCount; ++i) { -+ // This is the offset of a thread within a threadblock tile for a specific -+ // pointer (units of elements) -+ layout::PitchLinearCoord thread_offset_in_threadblock_tile = -+ thread_offset_base + -+ layout::PitchLinearCoord{ -+ 0, ThreadMap::Detail::WarpThreadArrangement::kStrided * i}; -+ // initialize pointer -+ pointer_[i] = reinterpret_cast(ref.data()) + -+ ref.offset(thread_offset_in_threadblock_tile) / -+ Layout::kElementsPerAccess; -+ } -+ -+ set_iteration_index(0); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ byte_offset_ += pointer_offset * sizeof_bits::value / 8; -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ AccessType *access_ptr = pointer_[iteration_strided_ & 1]; -+ int stride_idx = (iteration_strided_ & ~1); -+ -+ int access_offset = -+ stride_idx * ThreadMap::Delta::kStrided * stride_ / Layout::kFactor + -+ // kCrosswise elements in the contiguous dimension would span to a -+ // shared memory cache line. -+ iteration_contiguous_ * (ThreadMap::Delta::kContiguous / kCrosswise) * -+ Layout::TileShape::kContiguous; -+ char *access_byte_ptr = -+ reinterpret_cast(access_ptr + access_offset); -+ return reinterpret_cast(access_byte_ptr + byte_offset_); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iteration_contiguous_; -+ -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) -+ return *this; -+ -+ // Enter here only if (iteration_contiguous_ == -+ // ThreadMap::Iteration::kContiguous) -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ -+ // Enter here only if (iteration_strided_ == ThreadMap::Iteration::kStrided) -+ // which means we enter the next section. -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ RegularTileAccessIterator prev(*this); -+ this->operator++(); -+ -+ return prev; -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ add_pointer_offset(coord.contiguous() * sections_per_stage_ * stride_ * -+ ThreadMap::kElementsPerAccess / sections_ + -+ coord.strided() * Shape::kStrided * stride_ * -+ Layout::kElementsPerAccess / Layout::kFactor); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for column-major crosswise TensorOp formats. -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator< -+ Shape_, Element_, -+ layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Crosswise>, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for column-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Crosswise>; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::TensorOpMultiplicandCrosswise::value, -+ Crosswise>, -+ (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ private: -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.row(), coord.column()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ RegularTileAccessIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for row-major crosswise TensorOp formats. -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator::value, Crosswise>, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for row-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Crosswise>; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::TensorOpMultiplicandCrosswise::value, -+ Crosswise>, -+ (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ private: -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.column(), coord.row()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ RegularTileAccessIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h -new file mode 100644 -index 0000000..2b116d0 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h -@@ -0,0 +1,1532 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing computing the addresses of storing of tiles -+ from pitch-linear rank=2 tensors. -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm75.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm80.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/transform/threadblock/regular_tile_access_iterator.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for congruous arrangements for TensorOps -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator< -+ Shape_, Element_, -+ layout::TensorOpMultiplicandCongruous64b, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorOpMultiplicandCongruous64b; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using StrideIndex = typename Layout::Stride::Index; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ static_assert(ThreadMap::kThreads / 32 > 1, -+ "This tile iterator requires at least two warps."); -+ -+ /// Internal details made public to facilitate introspection -+ struct Detail { -+ /// This iterator is specialized for an access size that is 128 bits in -+ /// length. -+ static int const kAccessSizeInBits = 64; -+ -+ static_assert(sizeof_bits::value * -+ ThreadMap::kElementsPerAccess == -+ kAccessSizeInBits, -+ "This iterator requires a policy whose access size is 64b"); -+ -+ ///< Number of pointers -+ static int const kPointerCount = 1; -+ }; -+ -+ /// Element type per access -+ using AccessType = Array; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Stride value -+ StrideIndex stride_; -+ -+ /// Internal pointer to first access of tile -+ AccessType *pointer_; -+ -+ /// Internal byte offset -+ Index byte_offset_; -+ -+ /// Iteration in the contiguous dimension -+ int iteration_contiguous_; -+ -+ /// Iteration in the strided dimension -+ int iteration_strided_; -+ -+ public: -+ -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator( -+ TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ): -+ stride_(ref.stride(0) / Layout::kElementsPerAccess), -+ byte_offset_(0) { -+ -+ layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); -+ -+ // This is the offset of a thread within a threadblock tile for a specific -+ // pointer (units of elements) -+ layout::PitchLinearCoord thread_offset_in_threadblock_tile = thread_offset_base; -+ -+ // initialize pointer -+ pointer_ = reinterpret_cast(ref.data() + ref.offset(thread_offset_in_threadblock_tile)); -+ -+ set_iteration_index(0); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ -+ byte_offset_ += pointer_offset * sizeof(Element); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ -+ AccessType *access_ptr = pointer_; -+ -+ int access_offset = iteration_strided_ * ThreadMap::Delta::kStrided * stride_ + -+ iteration_contiguous_ * ThreadMap::Delta::kContiguous / -+ ThreadMap::kElementsPerAccess; -+ -+ char *access_byte_ptr = -+ reinterpret_cast(access_ptr + access_offset); -+ -+ return reinterpret_cast(access_byte_ptr + byte_offset_); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iteration_contiguous_; -+ -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) -+ return *this; -+ -+ // Enter here only if (iteration_contiguous_ == -+ // ThreadMap::Iteration::kContiguous) -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ -+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) -+ // which means we enter the next tile. -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ -+ RegularTileAccessIterator prev(*this); -+ -+ this->operator++(); -+ -+ return prev; -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ -+ add_pointer_offset( -+ coord.contiguous() * Shape::kContiguous + -+ coord.strided() * Shape::kStrided * stride_ * Layout::kElementsPerAccess); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for column-major congruous TensorOp formats. -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator< -+ Shape_, Element_, -+ layout::ColumnMajorTensorOpMultiplicandCongruous64b, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for column-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::TensorOpMultiplicandCongruous64b, -+ (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ private: -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.row(), coord.column()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ RegularTileAccessIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for row-major congruous TensorOp formats. -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for row-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajorTensorOpMultiplicandCongruous64b; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::TensorOpMultiplicandCongruous64b, -+ (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ private: -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.column(), coord.row()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ RegularTileAccessIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for crosswise arrangements for TensorOps -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator< -+ Shape_, Element_, -+ layout::TensorOpMultiplicand64bCrosswise, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorOpMultiplicand64bCrosswise; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using StrideIndex = typename Layout::Stride::Index; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ static_assert(ThreadMap::kThreads / 32 > 1, -+ "This tile iterator requires at least two warps."); -+ -+ /// Internal details made public to facilitate introspection -+ struct Detail { -+ /// This iterator is specialized for an access size that is 128 bits in -+ /// length. -+ static int const kAccessSizeInBits = 64; -+ -+ static_assert(sizeof_bits::value * -+ ThreadMap::kElementsPerAccess == -+ kAccessSizeInBits, -+ "This iterator requires a policy whose access size is 64b"); -+ -+ ///< Number of pointers - two pointers are needed if making more than 4 iterations along -+ ///< strided dimension -+ static int const kPointerCount = (ThreadMap::Iterations::kStrided > 4 ? 2 : 1); -+ }; -+ -+ /// Element type per access -+ using AccessType = Array; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Stride value -+ StrideIndex stride_; -+ -+ /// Internal pointer to first access of tile -+ AccessType *pointer_; -+ -+ /// Internal byte offset -+ Index byte_offset_[Detail::kPointerCount]; -+ -+ /// Iteration in the contiguous dimension -+ int iteration_contiguous_; -+ -+ /// Iteration in the strided dimension -+ int iteration_strided_; -+ -+ public: -+ -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_DEVICE -+ RegularTileAccessIterator( -+ TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ): -+ stride_(ref.stride(0) / ThreadMap::kElementsPerAccess) { -+ -+ layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); -+ -+ // This is the offset of a thread within a threadblock tile for a specific -+ // pointer (units of elements) -+ layout::PitchLinearCoord thread_offset_in_threadblock_tile = thread_offset_base; -+ -+ // initialize pointer -+ pointer_ = reinterpret_cast(ref.data()); -+ -+ byte_offset_[0] = ref.offset(thread_offset_in_threadblock_tile) * sizeof(Element); -+ -+ if (Detail::kPointerCount == 2) { -+ byte_offset_[1] = byte_offset_[0] ^ 8; -+ } -+ -+ set_iteration_index(0); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ -+ pointer_ += pointer_offset / ThreadMap::kElementsPerAccess; -+ } -+ -+ /// Returns a pointer -+ CUTLASS_DEVICE -+ AccessType *get() const { -+ -+ // Map the logical contiguous and strided access to the internal swizzled structure. -+ int uniform_offset = (iteration_strided_ & 0x3) * stride_ + (iteration_strided_ >> 3) * 16 + stride_ * ThreadMap::Delta::kContiguous * iteration_contiguous_; -+ -+ char *access_byte_ptr = reinterpret_cast(pointer_ + uniform_offset); -+ -+ int byte_offset; -+ -+ // This iterator may require two byte offsets if it must load more than 8 rows (or 2 iterations) -+ // in the strided dimension -+ if (Detail::kPointerCount == 2 && (iteration_strided_ & 0x4)) { -+ byte_offset = byte_offset_[1]; -+ } -+ else { -+ byte_offset = byte_offset_[0]; -+ } -+ -+ return reinterpret_cast(access_byte_ptr + byte_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iteration_contiguous_; -+ -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) -+ return *this; -+ -+ // Enter here only if (iteration_contiguous_ == -+ // ThreadMap::Iteration::kContiguous) -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ -+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) -+ // which means we enter the next tile. -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ -+ RegularTileAccessIterator prev(*this); -+ -+ this->operator++(); -+ -+ return prev; -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ -+ add_pointer_offset(coord.strided() * Shape::kStrided + coord.contiguous() * Shape::kContiguous * stride_); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for column-major crosswise TensorOp formats. -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator< -+ Shape_, Element_, -+ layout::ColumnMajorTensorOpMultiplicand64bCrosswise, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for column-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajorTensorOpMultiplicand64bCrosswise; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::TensorOpMultiplicand64bCrosswise, -+ (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ private: -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.row(), coord.column()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ RegularTileAccessIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for row-major crosswise TensorOp formats. -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for row-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajorTensorOpMultiplicand64bCrosswise; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::TensorOpMultiplicand64bCrosswise, -+ (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ private: -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.column(), coord.row()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ RegularTileAccessIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for congruous arrangements for TensorOps -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator< -+ Shape_, Element_, -+ layout::TensorOpMultiplicandCongruous128b, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorOpMultiplicandCongruous128b; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using StrideIndex = typename Layout::Stride::Index; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ static_assert(ThreadMap::kThreads / 32 > 1, -+ "This tile iterator requires at least two warps."); -+ -+ /// Internal details made public to facilitate introspection -+ struct Detail { -+ /// This iterator is specialized for an access size that is 128 bits in -+ /// length. -+ static int const kAccessSizeInBits = 128; -+ -+ static_assert(sizeof_bits::value * -+ ThreadMap::kElementsPerAccess == -+ kAccessSizeInBits, -+ "This iterator requires a policy whose access size is 128b"); -+ -+ ///< Number of pointers -+ static int const kPointerCount = 1; -+ }; -+ -+ /// Element type per access -+ using AccessType = Array; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Stride value -+ StrideIndex stride_; -+ -+ /// Internal pointer to first access of tile -+ AccessType *pointer_; -+ -+ /// Internal byte offset -+ Index byte_offset_; -+ -+ /// Iteration in the contiguous dimension -+ int iteration_contiguous_; -+ -+ /// Iteration in the strided dimension -+ int iteration_strided_; -+ -+ public: -+ -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator( -+ TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ): -+ stride_(ref.stride(0) / Layout::kElementsPerAccess), -+ byte_offset_(0) { -+ -+ layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); -+ -+ // This is the offset of a thread within a threadblock tile for a specific -+ // pointer (units of elements) -+ layout::PitchLinearCoord thread_offset_in_threadblock_tile = thread_offset_base; -+ -+ // initialize pointer -+ pointer_ = reinterpret_cast(ref.data() + ref.offset(thread_offset_in_threadblock_tile)); -+ -+ set_iteration_index(0); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ -+ byte_offset_ += pointer_offset * sizeof(Element); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ -+ AccessType *access_ptr = pointer_; -+ -+ int access_offset = iteration_strided_ * ThreadMap::Delta::kStrided * stride_ + -+ iteration_contiguous_ * ThreadMap::Delta::kContiguous / -+ ThreadMap::kElementsPerAccess; -+ -+ char *access_byte_ptr = -+ reinterpret_cast(access_ptr + access_offset); -+ -+ return reinterpret_cast(access_byte_ptr + byte_offset_); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iteration_contiguous_; -+ -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) -+ return *this; -+ -+ // Enter here only if (iteration_contiguous_ == -+ // ThreadMap::Iteration::kContiguous) -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ -+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) -+ // which means we enter the next tile. -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ -+ RegularTileAccessIterator prev(*this); -+ -+ this->operator++(); -+ -+ return prev; -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ -+ add_pointer_offset( -+ coord.contiguous() * Shape::kContiguous + -+ coord.strided() * Shape::kStrided * stride_ * Layout::kElementsPerAccess); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for column-major congruous TensorOp formats. -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator< -+ Shape_, Element_, -+ layout::ColumnMajorTensorOpMultiplicandCongruous128b, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for column-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::TensorOpMultiplicandCongruous128b, -+ (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ private: -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.row(), coord.column()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ RegularTileAccessIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for row-major congruous TensorOp formats. -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for row-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajorTensorOpMultiplicandCongruous128b; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::TensorOpMultiplicandCongruous128b, -+ (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ private: -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator( -+ TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ): -+ iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.column(), coord.row()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ RegularTileAccessIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for congruous arrangements for TensorOps -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator< -+ Shape_, Element_, -+ layout::TensorOpMultiplicandCrosswise128x4, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::TensorOpMultiplicandCrosswise128x4; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using StrideIndex = typename Layout::Stride::Index; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ static_assert(ThreadMap::kThreads / 32 > 1, -+ "This tile iterator requires at least two warps."); -+ -+ /// Internal details made public to facilitate introspection -+ struct Detail { -+ /// This iterator is specialized for an access size that is 128 bits in -+ /// length. -+ static int const kAccessSizeInBits = 128; -+ -+ static_assert(sizeof_bits::value * -+ ThreadMap::kElementsPerAccess == -+ kAccessSizeInBits, -+ "This iterator requires a policy whose access size is 128b"); -+ -+ ///< Number of pointers -+ static int const kPointerCount = 1; -+ }; -+ -+ -+ static_assert(!(ThreadMap::Iterations::kStrided % 2), "This iterator requires at least two iterations along the strided dimension"); -+ -+ /// Element type per access -+ using AccessType = Array; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Stride value -+ StrideIndex stride_; -+ -+ /// Internal pointer to first access of tile -+ AccessType *pointer_; -+ -+ /// Internal byte offset -+ Index byte_offset_; -+ -+ /// Iteration in the contiguous dimension -+ int iteration_contiguous_; -+ -+ /// Iteration in the strided dimension -+ int iteration_strided_; -+ -+ public: -+ -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_DEVICE -+ RegularTileAccessIterator( -+ TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ): -+ stride_(ref.stride(0) / Layout::kElementsPerAccess), -+ byte_offset_(0) { -+ -+ layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); -+ -+ // This is the offset of a thread within a threadblock tile for a specific -+ // pointer (units of elements) -+ layout::PitchLinearCoord thread_offset_in_threadblock_tile = thread_offset_base; -+ -+ // initialize pointer -+ pointer_ = reinterpret_cast(ref.data() + ref.offset(thread_offset_in_threadblock_tile)); -+ -+ set_iteration_index(0); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ -+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; -+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ -+ byte_offset_ += pointer_offset * sizeof(Element); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ -+ AccessType *access_ptr = pointer_; -+ -+ int offset_c = (iteration_contiguous_ * ThreadMap::Delta::kContiguous + (iteration_strided_ & 1) * 2); -+ int offset_s = (iteration_strided_ / 2) * 8; -+ -+ int access_offset = offset_c * stride_ + offset_s; -+ -+ char *access_byte_ptr = -+ reinterpret_cast(access_ptr + access_offset); -+ -+ return reinterpret_cast(access_byte_ptr + byte_offset_); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iteration_contiguous_; -+ -+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) -+ return *this; -+ -+ // Enter here only if (iteration_contiguous_ == -+ // ThreadMap::Iteration::kContiguous) -+ iteration_contiguous_ = 0; -+ ++iteration_strided_; -+ -+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) { -+ return *this; -+ } -+ -+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) -+ // which means we enter the next tile. -+ iteration_strided_ = 0; -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ -+ RegularTileAccessIterator prev(*this); -+ -+ this->operator++(); -+ -+ return prev; -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ -+ add_pointer_offset( -+ coord.contiguous() * Shape::kContiguous * stride_ + -+ coord.strided() * Shape::kStrided * Layout::kElementsPerAccess); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for column-major congruous TensorOp formats. -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator< -+ Shape_, Element_, -+ layout::ColumnMajorTensorOpMultiplicandCrosswise128x4, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for column-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::TensorOpMultiplicandCrosswise128x4, -+ (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ private: -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.row(), coord.column()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ RegularTileAccessIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for row-major congruous TensorOp formats. -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileAccessIterator { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for row-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajorTensorOpMultiplicandCrosswise128x4; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileAccessIterator< -+ layout::PitchLinearShape, Element, -+ layout::TensorOpMultiplicandCrosswise128x4, -+ (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; -+ -+ using AccessType = typename UnderlyingIterator::AccessType; -+ -+ private: -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator( -+ TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ): -+ iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return reinterpret_cast(iterator_.get()); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.column(), coord.row()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileAccessIterator operator++(int) { -+ RegularTileAccessIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator.h -new file mode 100644 -index 0000000..26d7da7 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator.h -@@ -0,0 +1,62 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing storing of tiles from pitch-linear rank=2 tensors. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape, -+ typename Element, -+ typename Layout, -+ int AdvanceRank, -+ typename ThreadMap, -+ int Alignment = sizeof_bits::value * ThreadMap::kElementsPerAccess / 8 -+> -+class RegularTileIterator; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h -new file mode 100644 -index 0000000..f761cdd ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h -@@ -0,0 +1,552 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of tiles from pitch-linear rank=2 tensors. -+ -+ This iterator uses masks to guard out-of-bounds accesses and visits the last "residue" tile -+ first, with the objective of minimizing predicate mask updates during steady-state operation. -+ -+ A precomputed "Params" object minimizes the amount of state that must be stored in registers, -+ and integer addition is used to advance the pointer through memory. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+ -+#include "regular_tile_iterator.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Regular tile iterator specialized for pitch-linear. This one is used by 2-stage SIMT kernels -+/// and sparse tensor core meta data. -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int Alignment -+> -+class RegularTileIterator { -+public: -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using StrideIndex = typename Layout::Stride::Index; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Fragment = Array; -+ -+ using AccessType = AlignedArray; -+ -+ static_assert(kAdvanceRank == 0 || kAdvanceRank == 1, -+ "Advance rank may only be along the contiguous or strided dimensions."); -+ -+private: -+ -+ // -+ // Types -+ // -+ -+ // -+ // Data members -+ // -+ -+ /// Pointer to memory -+ uint8_t *pointer_; -+ -+ /// Stride quantity -+ StrideIndex stride_; -+ -+ /// Amount to increment pointer along strided dimension -+ Index increment_strided_; -+ -+ /// Amount to advance pointer between tiles -+ Index increment_advance_; -+ -+public: -+ -+ CUTLASS_DEVICE -+ RegularTileIterator(): pointer_(nullptr), increment_strided_(0), increment_advance_(0) { } -+ -+ CUTLASS_DEVICE -+ RegularTileIterator( -+ TensorRef const &ref, -+ int thread_idx -+ ): -+ pointer_(reinterpret_cast(ref.data()) + (ref.offset(ThreadMap::initial_offset(thread_idx)) * sizeof_bits::value / 8)) { -+ -+ stride_ = ref.stride()[0]; -+ increment_strided_ = (ref.stride()[0] * sizeof_bits::value) * ThreadMap::Delta::kStrided / 8; -+ -+ increment_advance_ = -+ (kAdvanceRank == 0 ? -+ Shape::kContiguous * sizeof_bits::value / 8 : -+ Shape::kStrided * (ref.stride()[0] * sizeof_bits::value / 8)); -+ } -+ -+ /// Loads a fragment -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ uint8_t const *byte_pointer = pointer_ + pointer_offset * sizeof_bits::value / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ -+ AccessType const *access_ptr = reinterpret_cast(byte_pointer); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ int idx = c + s * ThreadMap::Iterations::kContiguous; -+ frag_ptr[idx] = access_ptr[c * ThreadMap::Delta::kContiguous / -+ ThreadMap::kElementsPerAccess]; -+ } -+ -+ if (s + 1 < ThreadMap::Iterations::kStrided) { -+ byte_pointer += increment_strided_; -+ } -+ } -+ } -+ -+ /// Loads a fragment -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag, TensorCoord const & tile_offset) { -+ load_with_pointer_offset( -+ frag, -+ tile_offset.contiguous() * Shape::kContiguous / ThreadMap::kElementsPerAccess + -+ tile_offset.strided() * Shape::kStrided * stride_ -+ ); -+ } -+ -+ /// Loads a fragment -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ uint8_t *byte_pointer = pointer_ + pointer_offset * sizeof_bits::value / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ -+ AccessType *access_ptr = reinterpret_cast(byte_pointer); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ int idx = c + s * ThreadMap::Iterations::kContiguous; -+ access_ptr[c * ThreadMap::Delta::kContiguous / -+ ThreadMap::kElementsPerAccess] = frag_ptr[idx]; -+ } -+ -+ if (s + 1 < ThreadMap::Iterations::kStrided) { -+ byte_pointer += increment_strided_; -+ } -+ } -+ } -+ -+ /// Stores a fragment -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag, TensorCoord const & tile_offset) { -+ store_with_pointer_offset( -+ frag, -+ tile_offset.contiguous() * Shape::kContiguous + tile_offset.strided() * Shape::kStrided * stride_ -+ ); -+ } -+ -+ /// Stores a fragment -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Advances the pointer -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ pointer_ += increment_advance_; -+ return *this; -+ } -+ -+ /// Advances the pointer -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator--() { -+ pointer_ -= increment_advance_; -+ return *this; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset; -+ } -+ -+ /// Adds a tile offset in the unit of tile. -+ /// In GEMM/Conv implementation, this is used to move in the k dimension in the shared memory. -+ /// Below layouts are the shared memory layouts. Current SM50 SIMT kernels only use col major A and row major B. -+ /// For row major A operand, k dimension is contiguous dimension; -+ /// For col major A operand, k dimension is strided dimension; -+ /// For row major B operand, k dimension is strided dimension; -+ /// For col major B operand, k dimension is contiguous dimension. -+ /// Below two classes map col/row major to the pitch linear coordinates used -+ /// in this base class. -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ int offset = sizeof_bits::value * -+ (coord.contiguous() * Shape::kContiguous + coord.strided() * Shape::kStrided * stride_) / 8; -+ add_pointer_offset(offset); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+#if 0 -+ AccessType *access_ptr = pointer_[iteration_strided_ & 1]; -+ int stride_idx = (iteration_strided_ & ~1); -+ -+ int access_offset = stride_idx * ThreadMap::Delta::kStrided * stride_ + -+ iteration_contiguous_ * ThreadMap::Delta::kContiguous / -+ ThreadMap::kElementsPerAccess; -+ -+ char *access_byte_ptr = -+ reinterpret_cast(access_ptr + access_offset); -+ return reinterpret_cast(access_byte_ptr + byte_offset_); -+#endif -+ return reinterpret_cast(pointer_); -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Regular tile iterator specialized for pitch-linear -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int Alignment -+> -+class RegularTileIterator { -+public: -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Fragment = Array; -+ -+ using Underlying = RegularTileIterator< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 1 : 0), -+ ThreadMap, -+ kAlignment -+ >; -+ -+ using AccessType = typename Underlying::AccessType; -+ -+ static_assert(kAdvanceRank == 0 || kAdvanceRank == 1, -+ "Advance rank may only be along the row or column dimensions."); -+ -+private: -+ -+ Underlying iterator_; -+ -+public: -+ -+ CUTLASS_DEVICE -+ RegularTileIterator() { } -+ -+ CUTLASS_DEVICE -+ RegularTileIterator( -+ TensorRef const &ref, -+ int thread_idx -+ ): -+ iterator_({ref.data(), ref.stride()}, thread_idx) { -+ -+ } -+ -+ /// Loads a fragment -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag, TensorCoord const & tile_offset) { -+ iterator_.load_with_pointer_offset(frag, {tile_offset.column(), tile_offset.row()}); -+ } -+ -+ /// Loads a fragment -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) { -+ iterator_.load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Stores a fragment -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag, TensorCoord const & tile_offset) { -+ iterator_.store_with_pointer_offset(frag, {tile_offset.column(), tile_offset.row()}); -+ } -+ -+ /// Stores a fragment -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) { -+ iterator_.store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Advances the pointer -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances the pointer -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator--() { -+ --iterator_; -+ return *this; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.column(), coord.row()}); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return iterator_.get(); -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Regular tile iterator specialized for pitch-linear -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int Alignment -+> -+class RegularTileIterator { -+public: -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajor; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Fragment = Array; -+ -+ using Underlying = RegularTileIterator< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 0 : 1), -+ ThreadMap -+ >; -+ -+ using AccessType = typename Underlying::AccessType; -+ -+ static_assert(kAdvanceRank == 0 || kAdvanceRank == 1, -+ "Advance rank may only be along the row or column dimensions."); -+ -+private: -+ -+ Underlying iterator_; -+ -+public: -+ -+ CUTLASS_DEVICE -+ RegularTileIterator() { } -+ -+ CUTLASS_DEVICE -+ RegularTileIterator( -+ TensorRef const &ref, -+ int thread_idx -+ ): -+ iterator_({ref.data(), ref.stride()}, thread_idx) { -+ -+ } -+ -+ /// Loads a fragment -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag, TensorCoord const & tile_offset) { -+ iterator_.load_with_pointer_offset(frag, {tile_offset.row(), tile_offset.column()}); -+ } -+ -+ /// Loads a fragment -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) { -+ iterator_.load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Stores a fragment -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag, TensorCoord const & tile_offset) { -+ iterator_.store_with_pointer_offset(frag, {tile_offset.row(), tile_offset.column()}); -+ } -+ -+ /// Stores a fragment -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) { -+ iterator_.store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Advances the pointer -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances the pointer -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator--() { -+ --iterator_; -+ return *this; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.row(), coord.column()}); -+ } -+ -+ /// Overrides the internal iteration index -+ CUTLASS_HOST_DEVICE -+ void set_iteration_index(int index) { -+ } -+ -+ /// Returns a pointer -+ CUTLASS_HOST_DEVICE -+ AccessType *get() const { -+ return iterator_.get(); -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear_2dthreadtile.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear_2dthreadtile.h -new file mode 100644 -index 0000000..a954eb4 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear_2dthreadtile.h -@@ -0,0 +1,509 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of tiles from pitch-linear rank=2 tensors. -+ -+ This iterator uses masks to guard out-of-bounds accesses and visits the last "residue" tile -+ first, with the objective of minimizing predicate mask updates during steady-state operation. -+ -+ A precomputed "Params" object minimizes the amount of state that must be stored in registers, -+ and integer addition is used to advance the pointer through memory. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/pitch_linear.h" -+ -+#include "regular_tile_iterator.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template < -+ typename Shape, -+ typename Element, -+ typename Layout, -+ int AdvanceRank, -+ typename ThreadMap, -+ int Alignment = sizeof_bits::value * ThreadMap::kElementsPerAccess / 8 -+> -+class RegularTileIterator2dThreadTile; -+ -+ -+/// Regular tile iterator specialized for pitch-linear + 2d thread-tiled threadmapping -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int Alignment -+> -+class RegularTileIterator2dThreadTile { -+public: -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::PitchLinear; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using StrideIndex = typename Layout::Stride::Index; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Fragment = Array; -+ -+ static_assert(kAdvanceRank == 0 || kAdvanceRank == 1, -+ "Advance rank may only be along the contiguous or strided dimensions."); -+ -+private: -+ -+ // -+ // Types -+ // -+ -+ using AccessType = AlignedArray; -+ -+ // -+ // Data members -+ // -+ -+ /// Pointer to memory -+ uint8_t *pointer_; -+ -+ /// Stride quantity -+ StrideIndex stride_; -+ -+ /// Amount to increment pointer along strided dimension -+ LongIndex increment_strided_; -+ -+ /// Amount to advance pointer between tiles -+ LongIndex increment_advance_; -+ -+public: -+ -+ CUTLASS_DEVICE -+ RegularTileIterator2dThreadTile(): pointer_(nullptr), increment_strided_(0), increment_advance_(0) { } -+ -+ CUTLASS_DEVICE -+ RegularTileIterator2dThreadTile( -+ TensorRef const &ref, -+ int thread_idx, -+ int interleave -+ ){ -+ -+ TensorCoord t = ThreadMap::initial_offset(thread_idx); -+ long int offset = t[0] * interleave + t[1] * ref.stride()[0]/interleave; -+ pointer_ = reinterpret_cast(ref.data() + offset); -+ -+ stride_ = ref.stride()[0] / interleave; -+ increment_strided_ = (ref.stride()[0] * sizeof_bits::value / 8) * ThreadMap::Delta::kStrided / interleave; -+ -+ increment_advance_ = -+ (kAdvanceRank == 0 ? -+ Shape::kContiguous * sizeof_bits::value / 8 : -+ Shape::kStrided * (ref.stride()[0] * sizeof_bits::value / 8) / interleave); -+ } -+ -+ /// Loads a fragment -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ uint8_t const *byte_pointer = pointer_ + pointer_offset * sizeof_bits::value / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ -+ AccessType const *access_ptr = reinterpret_cast(byte_pointer); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ int idx = c + s * ThreadMap::Iterations::kContiguous; -+ frag_ptr[idx] = access_ptr[c * ThreadMap::Delta::kContiguous / ThreadMap::ThreadAccessShape::kStrided]; -+ } -+ -+ if (s + 1 < ThreadMap::Iterations::kStrided) { -+ byte_pointer += increment_strided_; -+ } -+ } -+ } -+ -+ /// Loads a fragment -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag, TensorCoord const & tile_offset) { -+ load_with_pointer_offset( -+ frag, -+ tile_offset.contiguous() * Shape::kContiguous / ThreadMap::kElementsPerAccess + -+ tile_offset.strided() * Shape::kStrided * stride_ -+ ); -+ } -+ -+ /// Loads a fragment -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ uint8_t *byte_pointer = pointer_ + pointer_offset * sizeof_bits::value / 8; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ -+ AccessType *access_ptr = reinterpret_cast(byte_pointer); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ int idx = c + s * ThreadMap::Iterations::kContiguous; -+ access_ptr[c * ThreadMap::Delta::kContiguous / ThreadMap::ThreadAccessShape::kStrided] = frag_ptr[idx]; -+ } -+ -+ if (s + 1 < ThreadMap::Iterations::kStrided) { -+ byte_pointer += increment_strided_; -+ } -+ } -+ } -+ -+ /// Stores a fragment -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag, TensorCoord const & tile_offset) { -+ store_with_pointer_offset( -+ frag, -+ tile_offset.contiguous() * Shape::kContiguous + tile_offset.strided() * Shape::kStrided * stride_ -+ ); -+ } -+ -+ /// Stores a fragment -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Advances the pointer -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator2dThreadTile &operator++() { -+ pointer_ += increment_advance_; -+ return *this; -+ } -+ -+ /// Advances the pointer -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator2dThreadTile &operator--() { -+ pointer_ -= increment_advance_; -+ return *this; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ pointer_ += pointer_offset; -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ int offset = sizeof_bits::value * -+ (coord.contiguous() * Shape::kContiguous + coord.strided() * Shape::kStrided * stride_) / 8; -+ add_pointer_offset(offset); -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Regular tile iterator specialized for interleaved layout + 2d thread-tiled threadmapping -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int Alignment -+> -+class RegularTileIterator2dThreadTile, AdvanceRank, ThreadMap_, Alignment> { -+public: -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajorInterleaved<4>; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Fragment = Array; -+ -+ using Underlying = RegularTileIterator2dThreadTile< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 1 : 0), -+ ThreadMap, -+ kAlignment -+ >; -+ -+ static_assert(kAdvanceRank == 0 || kAdvanceRank == 1, -+ "Advance rank may only be along the row or column dimensions."); -+ -+private: -+ -+ Underlying iterator_; -+ -+public: -+ -+ CUTLASS_DEVICE -+ RegularTileIterator2dThreadTile() { } -+ -+ CUTLASS_DEVICE -+ RegularTileIterator2dThreadTile( -+ TensorRef const &ref, -+ int thread_idx -+ ): -+ iterator_({ref.data(), ref.stride()}, thread_idx, 4) { -+ -+ } -+ -+ /// Loads a fragment -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag, TensorCoord const & tile_offset) { -+ iterator_.load_with_pointer_offset(frag, {tile_offset.column(), tile_offset.row()}); -+ } -+ -+ /// Loads a fragment -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) { -+ iterator_.load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Stores a fragment -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag, TensorCoord const & tile_offset) { -+ iterator_.store_with_pointer_offset(frag, {tile_offset.column(), tile_offset.row()}); -+ } -+ -+ /// Stores a fragment -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) { -+ iterator_.store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Advances the pointer -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator2dThreadTile &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances the pointer -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator2dThreadTile &operator--() { -+ --iterator_; -+ return *this; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.column(), coord.row()}); -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Regular tile iterator specialized for interleaved layout + 2d thread-tiled threadmapping -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int Alignment -+> -+class RegularTileIterator2dThreadTile, AdvanceRank, ThreadMap_, Alignment> { -+public: -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajorInterleaved<4>; -+ static int const kAdvanceRank = AdvanceRank; -+ using ThreadMap = ThreadMap_; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using Fragment = Array; -+ using PitchLinearThreadMap = PitchLinearStripminedThreadMap< layout::PitchLinearShape, -+ ThreadMap::kThreads, ThreadMap::ThreadAccessShape::kCount >; -+ -+ -+ using Underlying = RegularTileIterator2dThreadTile< -+ layout::PitchLinearShape, -+ Element, -+ layout::PitchLinear, -+ (kAdvanceRank == 0 ? 0 : 1), -+ ThreadMap -+ >; -+ -+ static_assert(kAdvanceRank == 0 || kAdvanceRank == 1, -+ "Advance rank may only be along the row or column dimensions."); -+ -+private: -+ -+ Underlying iterator_; -+ -+public: -+ -+ CUTLASS_DEVICE -+ RegularTileIterator2dThreadTile() { } -+ -+ CUTLASS_DEVICE -+ RegularTileIterator2dThreadTile( -+ TensorRef const &ref, -+ int thread_idx -+ ): -+ iterator_({ref.data(), ref.stride()}, thread_idx, 4) { -+ -+ } -+ -+ /// Loads a fragment -+ CUTLASS_HOST_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag, TensorCoord const & tile_offset) { -+ iterator_.load_with_pointer_offset(frag, {tile_offset.row(), tile_offset.column()}); -+ } -+ -+ /// Loads a fragment -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) { -+ iterator_.load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Stores a fragment -+ CUTLASS_HOST_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Stores a fragment -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag, TensorCoord const & tile_offset) { -+ iterator_.store_with_pointer_offset(frag, {tile_offset.row(), tile_offset.column()}); -+ } -+ -+ /// Stores a fragment -+ CUTLASS_HOST_DEVICE -+ void store(Fragment const &frag) { -+ iterator_.store_with_pointer_offset(frag, 0); -+ } -+ -+ /// Advances the pointer -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator2dThreadTile &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances the pointer -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator2dThreadTile &operator--() { -+ --iterator_; -+ return *this; -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.row(), coord.column()}); -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op.h -new file mode 100644 -index 0000000..8ea0efa ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op.h -@@ -0,0 +1,1107 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing storing of tiles from pitch-linear rank=2 tensors. -+*/ -+ -+#pragma once -+ -+#include "cutlass/transform/threadblock/regular_tile_iterator.h" -+#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for congruous arrangements for TensorOps -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileIterator< -+ Shape_, Element_, -+ layout::TensorOpMultiplicandCongruous::value, -+ int(128 / sizeof(Element_))>, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ -+ static_assert(AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = -+ layout::TensorOpMultiplicandCongruous::value, -+ int(128 / sizeof(Element))>; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Internal details made public to facilitate introspection -+ struct Detail { -+ -+ /// This iterator is specialized for an access size that is 128 bits in length. -+ static int const kAccessSizeInBits = 128; -+ -+ static_assert( -+ sizeof_bits::value * ThreadMap::kElementsPerAccess == kAccessSizeInBits, -+ "This iterator requires a policy whose access size is 128bs"); -+ }; -+ -+private: -+ -+ /// Element type per access -+ using AccessType = Array; -+ -+public: -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = Array; -+ -+ /// Underlying iterator to compute the addresses -+ using TileAccessIterator = RegularTileAccessIterator; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Data member to the tile access iterator -+ TileAccessIterator address_iterator_; -+ -+public: -+ -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : address_iterator_(ref, thread_id) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ address_iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ address_iterator_.add_tile_offset({0, 1}); -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator operator++(int) { -+ RegularTileIterator prev(*this); -+ this->operator++(); -+ -+ return prev; -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ address_iterator_.add_tile_offset(coord); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ load_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_byte_offset(Fragment &frag, Index byte_offset) { -+ address_iterator_.set_iteration_index(0); -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ int access_idx = c + s * ThreadMap::Iterations::kContiguous; -+ -+ char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; -+ AccessType const *access_ptr = reinterpret_cast(byte_ptr); -+ -+ frag_ptr[access_idx] = *access_ptr; -+ ++address_iterator_; -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ store_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const &frag, Index byte_offset) { -+ address_iterator_.set_iteration_index(0); -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ int access_idx = c + s * ThreadMap::Iterations::kContiguous; -+ -+ char *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; -+ AccessType *access_ptr = reinterpret_cast(byte_ptr); -+ -+ *access_ptr = frag_ptr[access_idx]; -+ ++address_iterator_; -+ } -+ } -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ store_with_byte_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for column-major congruous TensorOp formats. -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileIterator< -+ Shape_, Element_, -+ layout::ColumnMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(Element_))>, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ -+ static_assert(AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for column-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(Element))>; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileIterator< -+ layout::PitchLinearShape, Element, -+ layout::TensorOpMultiplicandCongruous::value, -+ int(128 / sizeof(Element))>, -+ (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; -+ -+ public: -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = Array; -+ -+private: -+ -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+public: -+ -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator( -+ TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ): iterator_({ref.data(), ref.stride()}, thread_id) { -+ -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.row(), coord.column()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator operator++(int) { -+ RegularTileIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset( -+ Fragment const &frag, -+ Index pointer_offset) { -+ -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for row-major congruous TensorOp formats. -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileIterator< -+ Shape_, Element_, -+ layout::RowMajorTensorOpMultiplicandCongruous::value, -+ int(128 / sizeof(Element_))>, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ -+ static_assert(AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for row-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajorTensorOpMultiplicandCongruous< -+ sizeof_bits::value, int(128 / sizeof(Element))>; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileIterator< -+ layout::PitchLinearShape, Element, -+ layout::TensorOpMultiplicandCongruous::value, -+ int(128 / sizeof(Element))>, -+ (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; -+ -+ public: -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = Array; -+ -+private: -+ -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+public: -+ -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator( -+ TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ): iterator_({ref.data(), ref.stride()}, thread_id) { -+ -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.column(), coord.row()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator operator++(int) { -+ -+ RegularTileIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset( -+ Fragment const &frag, -+ Index pointer_offset) { -+ -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for crosswise arrangements for TensorOps -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileIterator::value, Crosswise>, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = -+ layout::TensorOpMultiplicandCrosswise::value, -+ Crosswise>; -+ -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Internal details made public to facilitate introspection -+ struct Detail { -+ /// This iterator is specialized for an access size that is 128 bits in -+ /// length. -+ static int const kAccessSizeInBits = 128; -+ -+ static_assert(sizeof_bits::value * ThreadMap::kElementsPerAccess == -+ kAccessSizeInBits, -+ "This iterator requires a policy whose access size is 128bs"); -+ }; -+ -+ private: -+ /// Element type per access -+ using AccessType = Array; -+ -+ public: -+ /// Fragment object to be loaded or stored -+ using Fragment = -+ Array; -+ -+ /// Underlying iterator to compute the addresses -+ using TileAccessIterator = RegularTileAccessIterator; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Data member to the tile access iterator -+ TileAccessIterator address_iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : address_iterator_(ref, thread_id) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ address_iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ address_iterator_.add_tile_offset({1, 0}); -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator operator++(int) { -+ RegularTileIterator prev(*this); -+ this->operator++(); -+ -+ return prev; -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ address_iterator_.add_tile_offset(coord); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ address_iterator_.set_iteration_index(0); -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ int access_idx = c + s * ThreadMap::Iterations::kContiguous; -+ frag_ptr[access_idx] = *(address_iterator_.get() + pointer_offset); -+ ++address_iterator_; -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ store_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); -+ } -+ -+ CUTLASS_DEVICE -+ void store_with_byte_offset(Fragment const &frag, Index byte_offset) { -+ address_iterator_.set_iteration_index(0); -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ int access_idx = c + s * ThreadMap::Iterations::kContiguous; -+ -+ char *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; -+ AccessType *access_ptr = reinterpret_cast(byte_ptr); -+ -+ *access_ptr = frag_ptr[access_idx]; -+ ++address_iterator_; -+ } -+ } -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for column-major crosswise TensorOp formats. -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileIterator::value, Crosswise>, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for column-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Crosswise>; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileIterator< -+ layout::PitchLinearShape, Element, -+ layout::TensorOpMultiplicandCrosswise::value, -+ Crosswise>, -+ (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; -+ -+ public: -+ /// Fragment object to be loaded or stored -+ using Fragment = Array; -+ -+ private: -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.row(), coord.column()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator operator++(int) { -+ RegularTileIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for row-major crosswise TensorOp formats. -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileIterator::value, Crosswise>, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for row-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajorTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Crosswise>; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileIterator< -+ layout::PitchLinearShape, Element, -+ layout::TensorOpMultiplicandCrosswise::value, -+ Crosswise>, -+ (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; -+ -+ public: -+ /// Fragment object to be loaded or stored -+ using Fragment = Array; -+ -+ private: -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.column(), coord.row()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator operator++(int) { -+ RegularTileIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for k interleaved arrangements for TensorOps -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template -+class RegularTileIterator< -+ Shape_, Element_, -+ layout::TensorOpMultiplicandRowMajorInterleaved::value, -+ InterleavedK>, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = -+ layout::TensorOpMultiplicandRowMajorInterleaved::value, -+ InterleavedK>; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Internal details made public to facilitate introspection -+ struct Detail { -+ /// This iterator is specialized for an access size that is 128 bits in -+ /// length. -+ static int const kAccessSizeInBits = 128; -+ -+ static_assert(sizeof_bits::value * ThreadMap::kElementsPerAccess == -+ kAccessSizeInBits, -+ "This iterator requires a policy whose access size is 128bs"); -+ }; -+ -+ private: -+ -+ /// Element type per access -+ using AccessType = Array; -+ -+ public: -+ /// Fragment object to be loaded or stored -+ using Fragment = -+ Array; -+ -+ /// Underlying iterator to compute the addresses -+ using TileAccessIterator = RegularTileAccessIterator; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// Data member to the tile access iterator -+ TileAccessIterator address_iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : address_iterator_(ref, thread_id) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ address_iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ address_iterator_.add_pointer_offset(Shape::kCount); -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator operator++(int) { -+ RegularTileIterator prev(*this); -+ this->operator++(); -+ -+ return prev; -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ address_iterator_.add_pointer_offset(coord.contiguous() * Shape::kCount); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ address_iterator_.set_iteration_index(0); -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ int access_idx = c + s * ThreadMap::Iterations::kContiguous; -+ frag_ptr[access_idx] = *(address_iterator_.get() + pointer_offset); -+ ++address_iterator_; -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ int access_idx = c + s * ThreadMap::Iterations::kContiguous; -+ *(address_iterator_.get() + pointer_offset) = frag_ptr[access_idx]; -+ ++address_iterator_; -+ } -+ } -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for k interleaved arrangements for TensorOps -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+ -+template -+class RegularTileIterator< -+ Shape_, Element_, -+ layout::TensorOpMultiplicandColumnMajorInterleaved::value, -+ InterleavedK>, -+ AdvanceRank, ThreadMap_, Alignment> { -+ -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = -+ layout::TensorOpMultiplicandColumnMajorInterleaved::value, -+ InterleavedK>; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileIterator< -+ cutlass::MatrixShape, -+ Element, -+ layout::TensorOpMultiplicandRowMajorInterleaved::value, InterleavedK>, -+ (kAdvanceRank == 1 ? 0 : 1), -+ ThreadMap -+ >; -+ -+ public: -+ /// Fragment object to be loaded or stored -+ using Fragment = Array; -+ -+ private: -+ -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator operator++(int) { -+ RegularTileIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.strided(), coord.contiguous()}); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op_sm70.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op_sm70.h -new file mode 100644 -index 0000000..883faa5 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op_sm70.h -@@ -0,0 +1,1460 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Templates implementing loading of tiles from pitch-linear rank=2 tensors. -+ -+ This iterator uses masks to guard out-of-bounds accesses and visits the last "residue" tile -+ first, with the objective of minimizing predicate mask updates during steady-state operation. -+ -+ A precomputed "Params" object minimizes the amount of state that must be stored in registers, -+ and integer addition is used to advance the pointer through memory. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/layout/tensor_op_multiplicand_sm70.h" -+ -+#include "cutlass/transform/threadblock/regular_tile_iterator.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile iterator specialized for congruous arrangements for TensorOps -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int Alignment -+> -+class RegularTileIterator< -+ Shape_, -+ Element_, -+ layout::VoltaTensorOpMultiplicandCongruous::value>, -+ AdvanceRank, -+ ThreadMap_, -+ Alignment> { -+public: -+ -+ static_assert(AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::VoltaTensorOpMultiplicandCongruous::value>; -+ static int const kAdvanceRank = AdvanceRank; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using StrideIndex = typename Layout::Stride::Index; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Internal details made public to facilitate introspection -+ struct Detail { -+ -+ /// This iterator is specialized for an access size that is 128 bits in length. -+ static int const kAccessSizeInBits = 128; -+ -+ static_assert( -+ sizeof_bits::value * ThreadMap::kElementsPerAccess == kAccessSizeInBits, -+ "This iterator requires a policy whose access size is 128bs"); -+ -+ ///< Number of pointers -+ static int const kPointerCount = (ThreadMap::Iterations::kStrided > 1 ? 2 : 1); -+ }; -+ -+ -+private: -+ -+ /// Element type per access -+ using AccessType = Array; -+ -+public: -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = Array; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Stride value -+ StrideIndex stride_; -+ -+ /// Internal pointer to first access of tile -+ AccessType * pointer_[Detail::kPointerCount]; -+ -+ /// Internal byte offset -+ Index byte_offset_; -+ -+public: -+ -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator( -+ TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ): stride_(ref.stride(0) / Layout::kElementsPerAccess), byte_offset_(0) { -+ -+ layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Detail::kPointerCount; ++i) { -+ -+ // This is the offset of a thread within a threadblock tile for a specific pointer -+ // (units of elements) -+ layout::PitchLinearCoord thread_offset_in_threadblock_tile = -+ thread_offset_base + layout::PitchLinearCoord{0, ThreadMap::Detail::WarpThreadArrangement::kStrided * i}; -+ -+ // initialize pointer -+ pointer_[i] = reinterpret_cast(ref.data() + ref.offset(thread_offset_in_threadblock_tile)); -+ } -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ -+ byte_offset_ += pointer_offset * sizeof(Element); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ -+ add_pointer_offset((kAdvanceRank ? Shape::kStrided * stride_ * Layout::kElementsPerAccess : Shape::kContiguous)); -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator operator++(int) { -+ -+ RegularTileIterator prev(*this); -+ this->operator++(); -+ -+ return prev; -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ add_pointer_offset( -+ coord.contiguous() * Shape::kContiguous / ThreadMap::kElementsPerAccess + -+ coord.strided() * Shape::kStrided * stride_ * Layout::kElementsPerAccess -+ ); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ Index vec_pointer_offset = pointer_offset / ThreadMap::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ -+ AccessType *access_ptr = pointer_[s & 1]; -+ int stride_idx = (s & ~1); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ int access_offset = stride_idx * ThreadMap::Delta::kStrided * stride_ + -+ c * ThreadMap::Delta::kContiguous / ThreadMap::kElementsPerAccess + -+ vec_pointer_offset; -+ -+ int access_idx = c + s * ThreadMap::Iterations::kContiguous; -+ -+ char const *access_byte_ptr = reinterpret_cast(access_ptr + access_offset); -+ -+ frag_ptr[access_idx] = *reinterpret_cast(access_byte_ptr + byte_offset_); -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset( -+ Fragment const &frag, -+ Index pointer_offset) { -+ -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ Index vec_pointer_offset = pointer_offset / ThreadMap::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ -+ AccessType *access_ptr = pointer_[s & 1]; -+ int stride_idx = (s & ~1); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ int access_offset = stride_idx * ThreadMap::Delta::kStrided * stride_ + -+ c * ThreadMap::Delta::kContiguous / ThreadMap::kElementsPerAccess + -+ vec_pointer_offset; -+ -+ int access_idx = c + s * ThreadMap::Iterations::kContiguous; -+ -+ char *access_byte_ptr = reinterpret_cast(access_ptr + access_offset); -+ -+ *reinterpret_cast(access_byte_ptr + byte_offset_) = frag_ptr[access_idx]; -+ } -+ } -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Tile Iterator specialized for column-major congruous TensorOp formats. -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int Alignment -+> -+class RegularTileIterator< -+ Shape_, -+ Element_, -+ layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>, -+ AdvanceRank, -+ ThreadMap_, -+ Alignment> { -+public: -+ -+ static_assert(AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for column-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ static int const kAdvanceRank = AdvanceRank; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileIterator< -+ layout::PitchLinearShape, -+ Element, -+ layout::VoltaTensorOpMultiplicandCongruous::value>, -+ (kAdvanceRank == 0 ? 0 : 1), -+ ThreadMap_>; -+ -+public: -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = Array; -+ -+private: -+ -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+public: -+ -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator( -+ TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ): iterator_({ref.data(), ref.stride()}, thread_id) { -+ -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.row(), coord.column()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator operator++(int) { -+ -+ RegularTileIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset( -+ Fragment const &frag, -+ Index pointer_offset) { -+ -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for row-major congruous TensorOp formats. -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int Alignment -+> -+class RegularTileIterator< -+ Shape_, -+ Element_, -+ layout::RowMajorVoltaTensorOpMultiplicandCongruous::value>, -+ AdvanceRank, -+ ThreadMap_, -+ Alignment> { -+public: -+ -+ static_assert(AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for row-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajorVoltaTensorOpMultiplicandCongruous::value>; -+ static int const kAdvanceRank = AdvanceRank; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileIterator< -+ layout::PitchLinearShape, -+ Element, -+ layout::VoltaTensorOpMultiplicandCongruous::value>, -+ (kAdvanceRank == 0 ? 1 : 0), -+ ThreadMap_>; -+ -+public: -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = Array; -+ -+private: -+ -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+public: -+ -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator( -+ TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ): iterator_({ref.data(), ref.stride()}, thread_id) { -+ -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.column(), coord.row()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator operator++(int) { -+ -+ RegularTileIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset( -+ Fragment const &frag, -+ Index pointer_offset) { -+ -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+/// Tile iterator specialized for congruous arrangements for TensorOps -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int Alignment -+> -+class RegularTileIterator< -+ Shape_, -+ Element_, -+ layout::VoltaTensorOpMultiplicandBCongruous::value>, -+ AdvanceRank, -+ ThreadMap_, -+ Alignment> { -+public: -+ -+ static_assert(AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::VoltaTensorOpMultiplicandBCongruous::value>; -+ static int const kAdvanceRank = AdvanceRank; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ using StrideIndex = typename Layout::Stride::Index; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Internal details made public to facilitate introspection -+ struct Detail { -+ -+ /// This iterator is specialized for an access size that is 128 bits in length. -+ static int const kAccessSizeInBits = 128; -+ -+ static_assert( -+ sizeof_bits::value * ThreadMap::kElementsPerAccess == kAccessSizeInBits, -+ "This iterator requires a policy whose access size is 128bs"); -+ -+ ///< Number of pointers -+ static int const kPointerCount = (ThreadMap::Iterations::kStrided > 1 ? 2 : 1); -+ }; -+ -+ -+private: -+ -+ /// Element type per access -+ using AccessType = Array; -+ -+public: -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = Array; -+ -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Stride value -+ StrideIndex stride_; -+ -+ /// Internal pointer to first access of tile -+ AccessType * pointer_[Detail::kPointerCount]; -+ -+ /// Internal byte offset -+ Index byte_offset_; -+ -+public: -+ -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator( -+ TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ): stride_(ref.stride(0) / Layout::kElementsPerAccess), byte_offset_(0) { -+ -+ layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Detail::kPointerCount; ++i) { -+ -+ // This is the offset of a thread within a threadblock tile for a specific pointer -+ // (units of elements) -+ layout::PitchLinearCoord thread_offset_in_threadblock_tile = -+ thread_offset_base + layout::PitchLinearCoord{0, ThreadMap::Detail::WarpThreadArrangement::kStrided * i}; -+ -+ // initialize pointer -+ pointer_[i] = reinterpret_cast(ref.data() + ref.offset(thread_offset_in_threadblock_tile)); -+ } -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ -+ byte_offset_ += pointer_offset * sizeof(Element); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ -+ add_pointer_offset((kAdvanceRank ? Shape::kStrided * stride_ * Layout::kElementsPerAccess : Shape::kContiguous)); -+ -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator operator++(int) { -+ -+ RegularTileIterator prev(*this); -+ this->operator++(); -+ -+ return prev; -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ add_pointer_offset( -+ coord.contiguous() * Shape::kContiguous / ThreadMap::kElementsPerAccess + -+ coord.strided() * Shape::kStrided * stride_ * Layout::kElementsPerAccess -+ ); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ Index vec_pointer_offset = pointer_offset / ThreadMap::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ -+ AccessType *access_ptr = pointer_[s & 1]; -+ int stride_idx = (s & ~1); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ int access_offset = stride_idx * ThreadMap::Delta::kStrided * stride_ + -+ c * ThreadMap::Delta::kContiguous / ThreadMap::kElementsPerAccess + -+ vec_pointer_offset; -+ -+ int access_idx = c + s * ThreadMap::Iterations::kContiguous; -+ -+ char const *access_byte_ptr = reinterpret_cast(access_ptr + access_offset); -+ -+ frag_ptr[access_idx] = *reinterpret_cast(access_byte_ptr + byte_offset_); -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset( -+ Fragment const &frag, -+ Index pointer_offset) { -+ -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ Index vec_pointer_offset = pointer_offset / ThreadMap::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ -+ AccessType *access_ptr = pointer_[s & 1]; -+ int stride_idx = (s & ~1); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ int access_offset = stride_idx * ThreadMap::Delta::kStrided * stride_ + -+ c * ThreadMap::Delta::kContiguous / ThreadMap::kElementsPerAccess + -+ vec_pointer_offset; -+ -+ int access_idx = c + s * ThreadMap::Iterations::kContiguous; -+ -+ char *access_byte_ptr = reinterpret_cast(access_ptr + access_offset); -+ -+ *reinterpret_cast(access_byte_ptr + byte_offset_) = frag_ptr[access_idx]; -+ } -+ } -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for column-major congruous TensorOp formats. -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int Alignment -+> -+class RegularTileIterator< -+ Shape_, -+ Element_, -+ layout::ColumnMajorVoltaTensorOpMultiplicandBCongruous::value>, -+ AdvanceRank, -+ ThreadMap_, -+ Alignment> { -+public: -+ -+ static_assert(AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for column-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ static int const kAdvanceRank = AdvanceRank; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileIterator< -+ layout::PitchLinearShape, -+ Element, -+ layout::VoltaTensorOpMultiplicandBCongruous::value>, -+ (kAdvanceRank == 0 ? 0 : 1), -+ ThreadMap_>; -+ -+public: -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = Array; -+ -+private: -+ -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+public: -+ -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator( -+ TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ): iterator_({ref.data(), ref.stride()}, thread_id) { -+ -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.row(), coord.column()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator operator++(int) { -+ -+ RegularTileIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset( -+ Fragment const &frag, -+ Index pointer_offset) { -+ -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for row-major congruous TensorOp formats. -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int Alignment -+> -+class RegularTileIterator< -+ Shape_, -+ Element_, -+ layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>, -+ AdvanceRank, -+ ThreadMap_, -+ Alignment> { -+public: -+ -+ static_assert(AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for row-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ static int const kAdvanceRank = AdvanceRank; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileIterator< -+ layout::PitchLinearShape, -+ Element, -+ layout::VoltaTensorOpMultiplicandBCongruous::value>, -+ (kAdvanceRank == 0 ? 1 : 0), -+ ThreadMap_>; -+ -+public: -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = Array; -+ -+private: -+ -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+public: -+ -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator( -+ TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ): iterator_({ref.data(), ref.stride()}, thread_id) { -+ -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.column(), coord.row()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator operator++(int) { -+ -+ RegularTileIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset( -+ Fragment const &frag, -+ Index pointer_offset) { -+ -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { -+ store_with_pointer_offset(frag, 0); -+ } -+}; -+ -+ -+/// Tile iterator specialized for crosswise arrangements for TensorOps. -+/// -+/// Volta TN SMEM layout is a little diffrent: -+/// Crosseised elements will be stored in a line, while contiguous elements -+/// sre stored in line-by-line. -+/// Padding is used to reduce SMEM bank conflicts. -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int Alignment -+> -+class RegularTileIterator< -+ Shape_, Element_, -+ layout::VoltaTensorOpMultiplicandCrosswise::value, -+ Shape_::kContiguous>, -+ AdvanceRank, ThreadMap_, Alignment> { -+ -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for pitch-linear iterator may along advance along the " -+ "contiguous(rank=0) or strided(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = -+ layout::VoltaTensorOpMultiplicandCrosswise::value, -+ Shape::kContiguous>; -+ static int const kAdvanceRank = AdvanceRank; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Internal details made public to facilitate introspection -+ struct Detail { -+ -+ ///< Number of pointers -+ static int const kPointerCount = (ThreadMap::Iterations::kStrided > 1 ? 2 : 1); -+ -+ /// Iterations for the kElementsPerAccess of ThreadMap -+ static int const kIterarionsPerAccess = -+ ThreadMap::kElementsPerAccess / Layout::kElementsPerAccess; -+ -+ /// Contiguous elements per line -+ static int const kContiguousElementsPerLine = 4; -+ }; -+ -+ private: -+ /// Element type per access -+ using AccessType = Array; -+ -+ public: -+ /// Fragment object to be loaded or stored -+ using Fragment = -+ Array; -+ -+ private: -+ // -+ // Data members -+ // -+ -+ /// The crosswised elements will be stored in a line. -+ /// line_size is size of crosswised dimention plus padding. -+ /// in units of AccessType -+ Index line_size; -+ -+ /// Internal pointer to first access of tile -+ AccessType *pointer_[Detail::kPointerCount]; -+ -+ /// Internal byte offset -+ Index byte_offset_; -+ -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : line_size(ref.stride(0) * Detail::kContiguousElementsPerLine / Layout::kElementsPerAccess), -+ byte_offset_(0) { -+ -+ layout::PitchLinearCoord thread_offset_base = -+ ThreadMap::initial_offset(thread_id); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Detail::kPointerCount; ++i) { -+ // This is the offset of a thread within a threadblock tile for a specific -+ // pointer (units of elements) -+ layout::PitchLinearCoord thread_offset_in_threadblock_tile = -+ thread_offset_base + -+ layout::PitchLinearCoord{ -+ 0, ThreadMap::Detail::WarpThreadArrangement::kStrided * i}; -+ -+ // initialize pointer -+ pointer_[i] = reinterpret_cast( -+ ref.data() + ref.offset(thread_offset_in_threadblock_tile)); -+ } -+ } -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ byte_offset_ += pointer_offset * sizeof(Element); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ // (Shape::kContiguous/Layout::kElementsPerAccess)* -+ // line_size * Layout::kElementsPerAccess -+ add_pointer_offset(Shape::kContiguous * line_size); -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator operator++(int) { -+ RegularTileIterator prev(*this); -+ this->operator++(); -+ -+ return prev; -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ add_pointer_offset((coord.contiguous() * (Shape::kContiguous / Layout::kElementsPerAccess) * -+ line_size + coord.strided() * Shape::kStrided) * -+ Layout::kElementsPerAccess); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ Index vec_pointer_offset = pointer_offset / Layout::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ AccessType *access_ptr = pointer_[(s & 1) ^ (s / 2)]; -+ -+ access_ptr += 16 * (s / 2); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int i = 0; i < Detail::kIterarionsPerAccess; ++i) { -+ -+ int access_offset = -+ c * ThreadMap::Delta::kContiguous / Detail::kContiguousElementsPerLine * line_size + -+ vec_pointer_offset + i * line_size; -+ -+ int access_idx = (c + s * ThreadMap::Iterations::kContiguous) * -+ Detail::kIterarionsPerAccess + i; -+ -+ char const *access_byte_ptr = reinterpret_cast(access_ptr + access_offset); -+ -+ frag_ptr[access_idx] = *reinterpret_cast( -+ access_byte_ptr + byte_offset_); -+ } -+ } -+ } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ AccessType const *frag_ptr = reinterpret_cast(&frag); -+ -+ Index vec_pointer_offset = pointer_offset / Layout::kElementsPerAccess; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { -+ -+ AccessType *access_ptr = pointer_[(s & 1) ^ ((s >> 1) & 1)]; -+ -+ access_ptr += 16 * (s / 2) + vec_pointer_offset; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { -+ CUTLASS_PRAGMA_UNROLL -+ for(int i = 0; i < Detail::kIterarionsPerAccess; ++i) { -+ -+ int access_offset = -+ c * ThreadMap::Delta::kContiguous / Detail::kContiguousElementsPerLine * line_size + i * line_size; -+ -+ int access_idx = (c + s * ThreadMap::Iterations::kContiguous) * -+ Detail::kIterarionsPerAccess + i; -+ -+ char *access_byte_ptr = reinterpret_cast(access_ptr + access_offset); -+ -+ *reinterpret_cast(access_byte_ptr + byte_offset_) = -+ frag_ptr[access_idx]; -+ } -+ } -+ } -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for column-major crosswise TensorOp formats. -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int Alignment -+> -+class RegularTileIterator::value, Shape_::kRow>, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for column-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::ColumnMajorVoltaTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Shape::kRow>; -+ static int const kAdvanceRank = AdvanceRank; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileIterator< -+ layout::PitchLinearShape, Element, -+ layout::VoltaTensorOpMultiplicandCrosswise::value, -+ Shape::kRow>, -+ (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; -+ -+ public: -+ /// Fragment object to be loaded or stored -+ using Fragment = Array; -+ -+ private: -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.row(), coord.column()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator operator++(int) { -+ RegularTileIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tile Iterator specialized for row-major crosswise TensorOp formats. -+/// -+/// -+/// Satisfies: ForwardTileIteratorConcept | -+/// ReadableContiguousTileIteratorConcept | -+/// WriteableContiguousTileIteratorConcept -+/// -+template < -+ typename Shape_, -+ typename Element_, -+ int AdvanceRank, -+ typename ThreadMap_, -+ int Alignment -+> -+class RegularTileIterator::value, Shape_::kColumn>, -+ AdvanceRank, ThreadMap_, Alignment> { -+ public: -+ static_assert( -+ AdvanceRank == 0 || AdvanceRank == 1, -+ "Specialization for row-major iterator may along advance along the " -+ "columns(rank=0) or rows(rank=1) dimension."); -+ -+ using Shape = Shape_; -+ using Element = Element_; -+ using Layout = layout::RowMajorVoltaTensorOpMultiplicandCrosswise< -+ sizeof_bits::value, Shape::kColumn>; -+ static int const kAdvanceRank = AdvanceRank; -+ static int const kAlignment = Alignment; -+ -+ using Index = typename Layout::Index; -+ using LongIndex = typename Layout::LongIndex; -+ -+ using TensorRef = TensorRef; -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ using ThreadMap = ThreadMap_; -+ -+ /// Underlying iterator type -+ using UnderlyingIterator = RegularTileIterator< -+ layout::PitchLinearShape, Element, -+ layout::VoltaTensorOpMultiplicandCrosswise::value, -+ Shape::kColumn>, -+ (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; -+ -+ public: -+ /// Fragment object to be loaded or stored -+ using Fragment = Array; -+ -+ private: -+ /// Underlying iterator -+ UnderlyingIterator iterator_; -+ -+ public: -+ /// Construct a TileIterator with zero threadblock offset -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor -+ int thread_id ///< ID of each participating thread -+ ) -+ : iterator_({ref.data(), ref.stride()}, thread_id) {} -+ -+ /// Adds a pointer offset in units of Element -+ CUTLASS_HOST_DEVICE -+ void add_pointer_offset(LongIndex pointer_offset) { -+ iterator_.add_pointer_offset(pointer_offset); -+ } -+ -+ /// Adds a tile offset -+ CUTLASS_DEVICE -+ void add_tile_offset(TensorCoord const &coord) { -+ iterator_.add_tile_offset({coord.column(), coord.row()}); -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator &operator++() { -+ ++iterator_; -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ RegularTileIterator operator++(int) { -+ RegularTileIterator prev(*this); -+ ++iterator_; -+ -+ return prev; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ iterator_.load_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { -+ iterator_.store_with_pointer_offset(frag, pointer_offset); -+ } -+ -+ /// Store a fragment to memory -+ CUTLASS_DEVICE -+ void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -diff --git a/3rdparty/cutlass/include/cutlass/transform/threadblock/vector_iterator.h b/3rdparty/cutlass/include/cutlass/transform/threadblock/vector_iterator.h -new file mode 100644 -index 0000000..8536a32 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/threadblock/vector_iterator.h -@@ -0,0 +1,149 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Template wraps the vector access iterator concept to load whole vector from tensors in -+ memory. This is typically used for per-channel scale and bias in convolution kernels. -+*/ -+ -+#pragma once -+ -+#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace transform { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class VectorIterator { -+public: -+ using VectorAccessIterator = VectorAccessIterator_; -+ -+ using Shape = typename VectorAccessIterator::Shape; -+ using Element = typename VectorAccessIterator::Element; -+ using Layout = typename VectorAccessIterator::Layout; -+ using TensorCoord = typename Layout::TensorCoord; -+ using AccessType = typename VectorAccessIterator::AccessType; -+ using TensorRef = typename VectorAccessIterator::TensorRef; -+ using Index = typename VectorAccessIterator::Index; -+ using LongIndex = typename VectorAccessIterator::LongIndex; -+ -+ static int const kElementsPerAccess = VectorAccessIterator::kElementsPerAccess; -+ static int const kRowsPerIteration = VectorAccessIterator::kRowsPerIteration; -+ static int const kThreads = VectorAccessIterator::kThreads; -+ static int const kIterations = VectorAccessIterator::kIterations; -+ -+ /// Fragment object to be loaded or stored -+ using Fragment = cutlass::Array< -+ Element, kElementsPerAccess * kIterations>; -+ -+private: -+ -+ /// Internal state -+ VectorAccessIterator vector_access_iterator_; -+ -+public: -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ VectorIterator( -+ Element const *ptr, -+ TensorCoord extent, -+ int thread_idx, -+ int warp_idx, -+ MatrixCoord const &threadblock_offset = MatrixCoord() -+ ): -+ vector_access_iterator_(ptr, extent, thread_idx, warp_idx, threadblock_offset) { } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ VectorIterator &operator++() { -+ vector_access_iterator_.advance(); -+ return *this; -+ } -+ -+ /// Advances to the next tile in memory. -+ CUTLASS_HOST_DEVICE -+ VectorIterator operator++(int) { -+ VectorIterator self(*this); -+ operator++(); -+ return self; -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { -+ -+ frag.clear(); -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < kIterations; ++c) { -+ -+ cutlass::arch::global_load< -+ AccessType, -+ sizeof(AccessType) -+ >( -+ frag_ptr[c], -+ vector_access_iterator_.get() + pointer_offset, -+ vector_access_iterator_.valid() -+ ); -+ -+ ++vector_access_iterator_; -+ } -+// } -+ } -+ -+ /// Loads a fragment from memory -+ CUTLASS_DEVICE -+ void load(Fragment &frag) { -+ vector_access_iterator_.set_iteration_index(0); -+ load_with_pointer_offset(frag, 0); -+ } -+ -+ CUTLASS_DEVICE -+ void advance() { -+ vector_access_iterator_.advance(); -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace transform -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/include/cutlass/transform/warp/vector_fragment_iterator.h b/3rdparty/cutlass/include/cutlass/transform/warp/vector_fragment_iterator.h -new file mode 100644 -index 0000000..5b5baba ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/transform/warp/vector_fragment_iterator.h -@@ -0,0 +1,283 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+ -+/*! \file -+ \brief This defines a "fragment" iterator for visiting the fragments of a warp vector -+ that participate in one warp-level mma operation. -+ -+ Typically, this is used to access the scale/bias fragement of a warp-level mma operation. -+ The scale/bias vector is then partitioned into smaller fragments that can be fed into -+ next warp-level mma operation. -+ -+ This iterator is necessary to accomplish warp-level mma fusion where the scale/bias vector is -+ applied to the multiplicand for the next mma. -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/numeric_conversion.h" -+ -+namespace cutlass { -+namespace transform { -+namespace warp { -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ /// Size of the input fragment tile shape (concept: MatrixShape) -+ typename Shape_, -+ /// Element type -+ typename Element_, -+ /// Layout of operand in memory -+ typename Layout_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ //// Number of elements per access when loading fragment -+ int ElementsPerAccess> -+class VectorFragmentIterator; -+ -+ -+// Partial specialization for PitchLinear layout tile -+ -+template < -+ /// Size of the input fragment vector shape (concept: MatrixShape) -+ typename Shape_, -+ /// Element type -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ //// Number of elements per access when loading fragment -+ int ElementsPerAccess> -+class VectorFragmentIterator { -+ public: -+ -+ /// Size of the input threadblock tile shape (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::PitchLinear; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Number of participating threads -+ static int const kThreads = 32; -+ -+ static int const kElementsPerAccess = ElementsPerAccess; -+ static int const kRowsPerIteration = 8; -+ static int const kColumnsPerAccess = 8; -+ static int const kElementsPerIteration = kRowsPerIteration * InstructionShape::kK / kThreads; -+ static int const kAccessPerIteration = kElementsPerIteration / kElementsPerAccess; -+ -+ /// Number of iterations -+ using Iterations = MatrixShape; -+ -+public: -+ -+ // -+ // Derived quantities -+ // -+ // All fragments have kElementsPerAccess scale followed by bias -+ -+ /// Fragment object holding a thread's part of a tile -+ /// This is the fragment size produced by one iteration of the iterator. -+ using Fragment = Array; -+ -+ /// Input threadblock fragment tile -+ using ThreadblockFragment = Array; -+ -+private: -+ -+ /// Internal access type -+ using AccessType = Array; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Input threadblock fragment tile -+ AccessType const *iterator_; -+ -+ /// Internal index -+ int index_; -+ -+public: -+ /// Constructs an iterator -+ CUTLASS_HOST_DEVICE -+ VectorFragmentIterator(ThreadblockFragment const &threadblock_frag) -+ : iterator_(reinterpret_cast(&threadblock_frag)), -+ index_(0) {} -+ -+ /// Add offset -+ CUTLASS_HOST_DEVICE -+ void add_offset(int index_offset) { -+ index_ += index_offset; -+ -+ if(index_ >= Iterations::kColumn) -+ index_ = 0; -+ } -+ -+ /// Increments -+ CUTLASS_HOST_DEVICE -+ VectorFragmentIterator &operator++() { -+ add_offset(1); -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_index(int idx) { -+ index_ = idx; -+ } -+ -+ /// Loads a fragment from the referenced part of the accumulator tile -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ -+ AccessType *frag_ptr = reinterpret_cast(&frag); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int r = 0; r < Iterations::kRow; r++) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kAccessPerIteration; i++) { -+ -+ frag_ptr[i * Iterations::kRow + r].clear(); -+ frag_ptr[i * Iterations::kRow + r] = iterator_[index_ * kAccessPerIteration + i]; -+ } -+ } -+ } -+ -+}; -+ -+// Partial specialization for Row-Major layout tile -+ -+template < -+ /// Size of the input fragment tile shape (concept: MatrixShape) -+ typename Shape_, -+ /// Element type -+ typename Element_, -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ typename InstructionShape_, -+ //// Number of elements per access when loading fragment -+ int ElementsPerAccess> -+class VectorFragmentIterator { -+ public: -+ -+ /// Size of the input threadblock tile shape (concept: MatrixShape) -+ using Shape = Shape_; -+ -+ /// Element type -+ using Element = Element_; -+ -+ /// Layout of source tile -+ using Layout = cutlass::layout::RowMajor; -+ -+ /// Shape of one matrix product operation (concept: MatrixShape) -+ using InstructionShape = InstructionShape_; -+ -+ /// Underlying iterator -+ using Base = VectorFragmentIterator< -+ layout::PitchLinearShape, Element, -+ layout::PitchLinear, InstructionShape, ElementsPerAccess>; -+ -+ -+ public: -+ -+ // -+ // Derived quantities -+ // -+ /// Fragment object holding a thread's part of a tile -+ /// This is the fragment size produced by one iteration of the iterator. -+ using Fragment = typename Base::Fragment; -+ -+ /// Input threadblock fragment tile -+ using ThreadblockFragment = typename Base::ThreadblockFragment; -+ -+ private: -+ /// Underlying iterator -+ Base iterator_; -+ -+public: -+ /// Constructs an iterator -+ CUTLASS_HOST_DEVICE -+ VectorFragmentIterator(ThreadblockFragment const &threadblock_frag) -+ : iterator_(threadblock_frag) {} -+ -+ /// Add offset -+ CUTLASS_HOST_DEVICE -+ void add_offset(int index_offset) { -+ iterator_.add_offset(index_offset); -+ } -+ -+ /// Increments -+ CUTLASS_HOST_DEVICE -+ VectorFragmentIterator &operator++() { -+ add_offset(1); -+ return *this; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void set_index(int idx) { -+ iterator_.set_index(idx); -+ } -+ -+ /// Loads a fragment from the referenced part of the accumulator tile -+ CUTLASS_HOST_DEVICE -+ void load(Fragment &frag) const { -+ iterator_.load(frag); -+ } -+ -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace conv -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/uint128.h b/3rdparty/cutlass/include/cutlass/uint128.h -new file mode 100644 -index 0000000..38d5b4d ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/uint128.h -@@ -0,0 +1,272 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief Defines an unsigned 128b integer with several operators to support 64-bit integer division. -+*/ -+ -+#pragma once -+ -+#if defined(__CUDACC_RTC__) -+#include -+#else -+#include -+#include -+#include -+#include -+#include -+#endif -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Optionally enable GCC's built-in type -+#if defined(__x86_64) && !defined(__CUDA_ARCH__) && defined(__GNUC__) -+#define CUTLASS_UINT128_NATIVE -+#elif defined(_MSC_VER) && defined(_M_AMD64) && !defined(__CUDA_ARCH__) -+#define CUTLASS_INT128_ARITHMETIC -+#include -+#if _MSC_VER >= 1920 -+#define CUTLASS_INT128_ARITHMETIC_DIV -+#include -+#endif -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///! Unsigned 128b integer type -+struct uint128_t { -+ -+ /// Size of one part of the uint's storage in bits -+ static constexpr int kPartSize = sizeof_bits::value; -+ -+ struct hilo { -+ uint64_t lo; -+ uint64_t hi; -+ -+ hilo() = default; -+ -+ CUTLASS_HOST_DEVICE hilo(uint64_t lo_, uint64_t hi_):lo(lo_), hi(hi_) {} -+ }; -+ -+ // Use a union to store either low and high parts or, if present, a built-in 128b integer type. -+ union { -+ struct hilo hilo_; -+ -+ #if defined(CUTLASS_UINT128_NATIVE) -+ unsigned __int128 native; -+ #endif // defined(CUTLASS_UINT128_NATIVE) -+ }; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ uint128_t() = default; -+ -+ /// Constructor from uint64 -+ CUTLASS_HOST_DEVICE -+ uint128_t(uint64_t lo_): hilo_(lo_, 0) { } -+ -+ /// Constructor from two 64b unsigned integers -+ CUTLASS_HOST_DEVICE -+ uint128_t(uint64_t lo_, uint64_t hi_): hilo_(lo_, hi_) { -+ -+ } -+ -+ /// Optional constructor from native value -+ #if defined(CUTLASS_UINT128_NATIVE) -+ uint128_t(unsigned __int128 value): native(value) { } -+ #endif -+ -+ /// Lossily cast to uint64 -+ CUTLASS_HOST_DEVICE -+ explicit operator uint64_t() const { -+ return hilo_.lo; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ static void exception() { -+#if defined(__CUDA_ARCH__) -+ asm volatile (" brkpt;\n"); -+#else -+ // throw std::runtime_error("Not yet implemented."); -+ abort(); -+#endif -+ } -+ -+ /// Add -+ CUTLASS_HOST_DEVICE -+ uint128_t operator+(uint128_t const &rhs) const { -+ uint128_t y; -+#if defined(CUTLASS_UINT128_NATIVE) -+ y.native = native + rhs.native; -+#else -+ y.hilo_.lo = hilo_.lo + rhs.hilo_.lo; -+ y.hilo_.hi = hilo_.hi + rhs.hilo_.hi + (!y.hilo_.lo && (rhs.hilo_.lo)); -+#endif -+ return y; -+ } -+ -+ /// Subtract -+ CUTLASS_HOST_DEVICE -+ uint128_t operator-(uint128_t const &rhs) const { -+ uint128_t y; -+#if defined(CUTLASS_UINT128_NATIVE) -+ y.native = native - rhs.native; -+#else -+ y.hilo_.lo = hilo_.lo - rhs.hilo_.lo; -+ y.hilo_.hi = hilo_.hi - rhs.hilo_.hi - (rhs.hilo_.lo && y.hilo_.lo > hilo_.lo); -+#endif -+ return y; -+ } -+ -+ /// Multiply by unsigned 64b integer yielding 128b integer -+ CUTLASS_HOST_DEVICE -+ uint128_t operator*(uint64_t const &rhs) const { -+ uint128_t y{}; -+#if defined(CUTLASS_UINT128_NATIVE) -+ y.native = native * rhs; -+#elif defined(CUTLASS_INT128_ARITHMETIC) -+ // Multiply by the low part -+ y.hilo_.lo = _umul128(hilo_.lo, rhs, &y.hilo_.hi); -+ -+ // Add the high part and ignore the overflow -+ uint64_t overflow; -+ y.hilo_.hi += _umul128(hilo_.hi, rhs, &overflow); -+#else -+ // TODO - not implemented -+ CUTLASS_UNUSED(rhs); -+ exception(); -+#endif -+ return y; -+ } -+ -+ /// Divide 128b operation by 64b operation yielding a 64b quotient -+ CUTLASS_HOST_DEVICE -+ uint64_t operator/(uint64_t const &divisor) const { -+ uint64_t quotient = 0; -+#if defined(CUTLASS_UINT128_NATIVE) -+ quotient = uint64_t(native / divisor); -+#elif defined(CUTLASS_INT128_ARITHMETIC_DIV) -+ // implemented using MSVC's arithmetic intrinsics -+ uint64_t remainder = 0; -+ quotient = _udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); -+#else -+ // TODO - not implemented -+ CUTLASS_UNUSED(divisor); -+ exception(); -+#endif -+ return quotient; -+ } -+ -+ /// Divide 128b operation by 64b operation yielding a 64b quotient -+ CUTLASS_HOST_DEVICE -+ uint64_t operator%(uint64_t const &divisor) const { -+ uint64_t remainder = 0; -+#if defined(CUTLASS_UINT128_NATIVE) -+ remainder = uint64_t(native % divisor); -+#elif defined(CUTLASS_INT128_ARITHMETIC_DIV) -+ // implemented using MSVC's arithmetic intrinsics -+ (void)_udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); -+#else -+ // TODO - not implemented -+ CUTLASS_UNUSED(divisor); -+ exception(); -+#endif -+ return remainder; -+ } -+ -+ /// Computes the quotient and remainder in a single method. -+ CUTLASS_HOST_DEVICE -+ uint64_t divmod(uint64_t &remainder, uint64_t divisor) const { -+ uint64_t quotient = 0; -+#if defined(CUTLASS_UINT128_NATIVE) -+ quotient = uint64_t(native / divisor); -+ remainder = uint64_t(native % divisor); -+#elif defined(CUTLASS_INT128_ARITHMETIC_DIV) -+ // implemented using MSVC's arithmetic intrinsics -+ quotient = _udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); -+#else -+ // TODO - not implemented -+ CUTLASS_UNUSED(remainder); -+ CUTLASS_UNUSED(divisor); -+ exception(); -+#endif -+ return quotient; -+ } -+ -+ /// Left-shifts a 128b unsigned integer -+ CUTLASS_HOST_DEVICE -+ uint128_t operator<<(int sh) const { -+ if (sh == 0) { -+ return *this; -+ } -+ else if (sh >= kPartSize) { -+ return uint128_t(0, hilo_.lo << (sh - kPartSize)); -+ } -+ else { -+ return uint128_t( -+ (hilo_.lo << sh), -+ (hilo_.hi << sh) | uint64_t(hilo_.lo >> (kPartSize - sh)) -+ ); -+ } -+ } -+ -+ /// Right-shifts a 128b unsigned integer -+ CUTLASS_HOST_DEVICE -+ uint128_t operator>>(int sh) const { -+ if (sh == 0) { -+ return *this; -+ } -+ else if (sh >= kPartSize) { -+ return uint128_t((hilo_.hi >> (sh - kPartSize)), 0); -+ } -+ else { -+ return uint128_t( -+ (hilo_.lo >> sh) | (hilo_.hi << (kPartSize - sh)), -+ (hilo_.hi >> sh) -+ ); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/include/cutlass/wmma_array.h b/3rdparty/cutlass/include/cutlass/wmma_array.h -new file mode 100644 -index 0000000..4a074b6 ---- /dev/null -+++ b/3rdparty/cutlass/include/cutlass/wmma_array.h -@@ -0,0 +1,93 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Statically sized array of elements that accommodates all CUTLASS-supported numeric types -+ and is safe to use in a union. -+*/ -+ -+#pragma once -+ -+#include "cutlass/arch/wmma.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_ENABLED) -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/functional.h" -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Wmma array type (WmmaFragmentArray holds elements of of type nvcuda::wmma::fragment) -+template < -+ /// Element type -+ typename T, -+ /// Number of elements in the array -+ int N -+> -+class WmmaFragmentArray: public Array { -+public: -+ -+ /// Efficient clear method (override Array::clear()) -+ CUTLASS_HOST_DEVICE -+ void clear() -+ { -+ for(int i = 0; i < Array::kElements; i++) -+ { -+ nvcuda::wmma::fill_fragment((*this)[i], (typename T::element_type)0); -+ } -+ } -+ -+ CUTLASS_HOST_DEVICE -+ WmmaFragmentArray& operator+=(const WmmaFragmentArray& rhs) -+ { -+ using element_type = typename T::element_type; -+ plus add; -+ -+ for (int i = 0; i < Array::kElements; i++) -+ { -+ (*this)[i] = add((*this)[i], rhs[i]); -+ } -+ -+ return *this; -+ } -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if defined(CUTLASS_ARCH_WMMA_ENABLED) -+ -diff --git a/3rdparty/cutlass/test/unit/common/cutlass_unit_test.h b/3rdparty/cutlass/test/unit/common/cutlass_unit_test.h -new file mode 100644 -index 0000000..8843e40 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/common/cutlass_unit_test.h -@@ -0,0 +1,102 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+#pragma warning (disable : 4068 ) /* disable unknown pragma warnings for vistual studio */ -+ -+#pragma nv_diag_suppress boolean_controlling_expr_is_constant -+#include -+#pragma nv_diag_warning boolean_controlling_expr_is_constant -+#pragma warning( disable : 4503) -+ -+#include -+#include -+ -+#include -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Gets a CUDA device -+cudaDeviceProp GetCudaDevice(); -+ -+/// Prints device properties -+std::ostream &operator<<(std::ostream &out, cudaDeviceProp const &device); -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Sets flags for Unit test -+void FilterArchitecture(); -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Reads environment variable `CUTLASS_UNIT_TEST_PROBLEM_COUNT` to control the number and order -+// of problem sizes run by CUTLASS unit tests -+int CutlassUnitTestProblemCount(); -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// active test macro -+#define CUTLASS_TEST_LEVEL_ACTIVE(LEVEL,NAME_STATIC,NAME_DYNAMIC,...) \ -+ TEST(NAME_STATIC,L##LEVEL##_##NAME_DYNAMIC) __VA_ARGS__ -+ -+// disabled test macro -+#define CUTLASS_TEST_LEVEL_DISABLED(LEVEL,NAME_STATIC,NAME_DYNAMIC,...) \ -+ TEST(NAME_STATIC,DISABLED_L##LEVEL##_##NAME_DYNAMIC) {} -+ -+#if CUTLASS_TEST_LEVEL == 0 -+#define CUTLASS_TEST_L0(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(0,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) -+#define CUTLASS_TEST_L1(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_DISABLED(1,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) -+#define CUTLASS_TEST_L2(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_DISABLED(2,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) -+#elif CUTLASS_TEST_LEVEL == 1 -+#define CUTLASS_TEST_L0(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(0,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) -+#define CUTLASS_TEST_L1(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(1,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) -+#define CUTLASS_TEST_L2(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_DISABLED(2,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) -+#else -+#define CUTLASS_TEST_L0(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(0,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) -+#define CUTLASS_TEST_L1(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(1,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) -+#define CUTLASS_TEST_L2(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(2,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) -+#endif -+ -+#if !defined(CUTLASS_TEST_UNIT_ENABLE_WARNINGS) -+#define CUTLASS_TEST_UNIT_ENABLE_WARNINGS false -+#endif -+ -+#if (__CUDACC_VER_MAJOR__ >= 12) -+ #define CUDA_12_0_SM90_FEATURES_SUPPORTED true -+#else -+ #define CUDA_12_0_SM90_FEATURES_SUPPORTED false -+#endif -+ -+#include -+#include -+#include -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/conv/device/cache_testbed_output.h b/3rdparty/cutlass/test/unit/conv/device/cache_testbed_output.h -new file mode 100644 -index 0000000..29be434 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/cache_testbed_output.h -@@ -0,0 +1,797 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Helper to construct cached name for -+*/ -+#pragma once -+ -+#include -+#include -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+ -+#include "cutlass/conv/conv3d_problem_size.h" -+#include "cutlass/core_io.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#ifndef CUTLASS_TEST_ENABLE_CACHED_RESULTS -+#define CUTLASS_TEST_ENABLE_CACHED_RESULTS false -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace conv { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Result of a test -+struct CachedTestKey { -+ -+ std::string op; ///< Concatenated string representation of operation performed -+ std::string problem; ///< Concatenated string representation of problem description -+ std::string types; ///< Concatenated string representation of operand types -+ uint32_t A; ///< Hashed result of tensor A -+ uint32_t B; ///< Hashed result of tensor B -+ uint32_t C; ///< Hashed result of tensor C -+ -+ // -+ // Methods -+ // -+ inline CachedTestKey(): A(), B(), C() { } -+ -+ inline CachedTestKey( -+ std::string op, ///< Concatenated string representation of operation performed -+ std::string problem, ///< Concatenated string representation of problem description -+ std::string types, ///< Concatenated string representation of operand types -+ uint32_t A, ///< Hashed result of tensor A -+ uint32_t B, ///< Hashed result of tensor B -+ uint32_t C ///< Hashed result of tensor C -+ ): -+ op(op), problem(problem), types(types), A(A), B(B), C(C) -+ { } -+ -+ /// Checks for equality of the problem -+ bool operator==(CachedTestKey const &rhs) const { -+ return op == rhs.op && problem == rhs.problem && types == rhs.types && A == rhs.A && B == rhs.B && C == rhs.C; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+inline std::istream &operator>>(std::istream &in, CachedTestKey &result) { -+ -+ in >> result.op; -+ in >> result.problem; -+ in >> result.types; -+ in >> result.A; -+ in >> result.B; -+ in >> result.C; -+ -+ return in; -+} -+ -+inline std::ostream &operator<<(std::ostream &out, CachedTestKey const &result) { -+ -+ out << result.op << " "; -+ out << result.problem << " "; -+ out << result.types << " "; -+ out << result.A << " "; -+ out << result.B << " "; -+ out << result.C << " "; -+ -+ return out; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct CachedTestResult { -+ uint32_t D; -+ -+ // -+ // Methods -+ // -+ -+ CachedTestResult(): D() { } -+ -+ CachedTestResult(uint32_t D): D(D) { } -+ -+ operator bool() const { -+ return bool(D); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+inline std::istream &operator>>(std::istream &in, CachedTestResult &result) { -+ in >> result.D; -+ return in; -+} -+ -+inline std::ostream &operator<<(std::ostream &out, CachedTestResult const &result) { -+ out << result.D; -+ return out; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct CachedTestResultListing { -+ -+ std::list> results; -+ -+ // -+ // Methods -+ // -+ -+ inline CachedTestResultListing(std::string const &path) { -+ std::ifstream file(path); -+ -+ while (file.good()) { -+ CachedTestKey key; -+ file >> key; -+ -+ CachedTestResult result; -+ file >> result; -+ -+ if (result) { -+ results.push_back(std::make_pair(key, result)); -+ } -+ } -+ } -+ -+ /// Returns the cached result -+ std::pair find(CachedTestKey const &rhs) const { -+ for (auto const & result : results) { -+ if (result.first == rhs) { -+ return std::make_pair(true, result.second); -+ } -+ } -+ return std::make_pair(false, CachedTestResult()); -+ } -+ -+ /// Appends an entry -+ void append(CachedTestKey const &key, CachedTestResult const &result) { -+ if (result) { -+ results.push_back(std::make_pair(key, result)); -+ } -+ } -+ -+ /// Writes the entire listing to a file -+ bool write(std::string const &path) { -+ std::ofstream file(path); -+ if (!file.good()) { -+ return false; -+ } -+ -+ for (auto const &result : results) { -+ file << result.first << result.second << std::endl; -+ } -+ -+ return true; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct ScalarEncoder { -+ Element scalar; -+ -+ ScalarEncoder(Element s): scalar(s) { } -+ -+ std::string str() const { -+ std::stringstream ss; -+ Element s = scalar; -+ if (s < Element()) { -+ s = -s; -+ ss << "n"; -+ } -+ ss << s; -+ return ss.str(); -+ } -+}; -+ -+template -+ScalarEncoder EncodeScalar(Element a) { -+ return ScalarEncoder(a); -+} -+ -+template -+struct ScalarEncoder> { -+ cutlass::complex scalar; -+ -+ ScalarEncoder(cutlass::complex s): scalar(s) { } -+ -+ std::string str() const { -+ std::stringstream ss; -+ ss << EncodeScalar(scalar.real()) << "_" << EncodeScalar(scalar.imag()) << "i"; -+ return ss.str(); -+ } -+}; -+ -+template -+std::ostream &operator<<(std::ostream &out, ScalarEncoder const &scalar) { -+ out << scalar.str(); -+ return out; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+inline char const *EncodeOperator(cutlass::conv::Operator conv_op) { -+ switch (conv_op) { -+ case cutlass::conv::Operator::kFprop: return "fprop"; -+ case cutlass::conv::Operator::kDgrad: return "dgrad"; -+ case cutlass::conv::Operator::kWgrad: return "wgrad"; -+ } -+ return "conv_unknown"; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Encode GemmCoord (Gemm problem size) -+inline std::ostream &EncodeProblemSize( -+ std::ostream &out, -+ cutlass::gemm::GemmCoord const &problem) { -+ -+ out << problem.m() << "x" << problem.n() << "x" << problem.k() << "_"; -+ -+ return out; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// Encode Conv2dProblemSize -+inline std::ostream &EncodeProblemSize( -+ std::ostream &out, -+ cutlass::conv::Conv2dProblemSize const &problem) { -+ -+ out << problem.N << "x" << problem.H << "x" << problem.W << "x" << problem.C << "_" -+ << problem.P << "x" << problem.Q << "_" << problem.K << "x" << problem.R << "x" << problem.S << "_"; -+ -+ out << "pad_h" << problem.pad_h << "w" << problem.pad_w << "_"; -+ out << "stride_h" << problem.stride_h << "w" << problem.stride_w << "_"; -+ out << "dil_h" << problem.dilation_h << "w" << problem.dilation_w << "_"; -+ -+ switch (problem.mode) { -+ case cutlass::conv::Mode::kCrossCorrelation: -+ out << "corr"; -+ break; -+ case cutlass::conv::Mode::kConvolution: -+ out << "conv"; -+ break; -+ } -+ -+ return out; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Encode Conv3dProblemSize -+inline std::ostream &EncodeProblemSize( -+ std::ostream &out, -+ cutlass::conv::Conv3dProblemSize const &problem) { -+ -+ out << problem.N << "x" << problem.D << "x" << problem.H << "x" << problem.W << "x" << problem.C << "_" -+ << problem.Z << problem.P << "x" << problem.Q << "_" << problem.K << "x" << problem.R << "x" << problem.S << "_"; -+ -+ out << "pad_d" << problem.pad_h << "h" << problem.pad_h << "w" << problem.pad_w << "_"; -+ out << "stride_d" << problem.stride_d << "h" << problem.stride_h << "w" << problem.stride_w << "_"; -+ out << "dil_d" << problem.dilation_d << "h" << problem.dilation_h << "w" << problem.dilation_w << "_"; -+ -+ switch (problem.mode) { -+ case cutlass::conv::Mode::kCrossCorrelation: -+ out << "corr"; -+ break; -+ case cutlass::conv::Mode::kConvolution: -+ out << "conv"; -+ break; -+ } -+ -+ return out; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+inline std::string ElementTypeName() { -+ return std::string(typeid(Element).name()); -+} -+ -+template <> -+inline std::string ElementTypeName() { -+ return "h"; -+} -+ -+template <> -+inline std::string ElementTypeName>() { -+ return "ch"; -+} -+ -+template <> -+inline std::string ElementTypeName() { -+ return "bf16"; -+} -+ -+template <> -+inline std::string ElementTypeName>() { -+ return "cbf16"; -+} -+ -+template <> -+inline std::string ElementTypeName() { -+ return "tf32"; -+} -+ -+template <> -+inline std::string ElementTypeName>() { -+ return "ctf32"; -+} -+ -+template <> -+inline std::string ElementTypeName>() { -+ return "c"; -+} -+ -+template <> -+inline std::string ElementTypeName>() { -+ return "z"; -+} -+ -+template <> -+inline std::string ElementTypeName>() { -+ return "q"; -+} -+ -+template <> -+inline std::string ElementTypeName() { -+ return "s8"; -+} -+ -+template <> -+inline std::string ElementTypeName() { -+ return "u8"; -+} -+ -+template <> -+inline std::string ElementTypeName() { -+ return "s4"; -+} -+ -+template <> -+inline std::string ElementTypeName() { -+ return "u4"; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+inline std::string LayoutTypeName() { -+ return std::string(typeid(Layout).name()); -+} -+ -+template <> -+inline std::string LayoutTypeName() { -+ return "n"; -+} -+ -+template <> -+inline std::string LayoutTypeName() { -+ return "t"; -+} -+ -+template <> -+inline std::string LayoutTypeName() { -+ return "nhwc"; -+} -+ -+template <> -+inline std::string LayoutTypeName>() { -+ return "nc32hw32"; -+} -+ -+template <> -+inline std::string LayoutTypeName>() { -+ return "nc64hw64"; -+} -+ -+template <> -+inline std::string LayoutTypeName>() { -+ return "c32rsk32"; -+} -+ -+template <> -+inline std::string LayoutTypeName>() { -+ return "c64rsk64"; -+} -+ -+template <> -+inline std::string LayoutTypeName() { -+ return "ndhwc"; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+inline std::string TensorTypeName() { -+ std::stringstream ss; -+ ss << ElementTypeName() << LayoutTypeName(); -+ return ss.str(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Hash function on a byte array -+struct CRC32 { -+ -+ uint32_t table[256]; -+ -+ // -+ // Methods -+ // -+ -+ CRC32() { -+ -+ uint32_t rem; -+ int i, j; -+ -+ for (i = 0; i < 256; i++) { -+ rem = i; -+ for (j = 0; j < 8; j++) { -+ if (rem & 1) { -+ rem >>= 1; -+ rem ^= 0xedb88320; -+ } else -+ rem >>= 1; -+ } -+ table[i] = rem; -+ } -+ } -+ -+ /// Computes the CRC of an array of bytes -+ uint32_t operator()(void const *start, size_t length, uint32_t crc = uint32_t()) const { -+ uint8_t const *p = static_cast(start); -+ uint8_t const *q = static_cast(start) + length; -+ -+ crc = ~crc; -+ -+ for (; p != q; ++p) { -+ uint8_t octet = *p; -+ crc = (crc >> 8) ^ table[(crc & 0xff) ^ octet]; -+ } -+ -+ return ~crc; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Element, typename Layout -+> -+uint32_t TensorHash( -+ cutlass::TensorView view, -+ CRC32 const &hash = CRC32(), -+ uint32_t crc = uint32_t() -+) { -+ -+ return hash(view.data(), view.capacity() * cutlass::sizeof_bits::value / 8, crc); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA, typename LayoutA, -+ typename ElementB, typename LayoutB, -+ typename ElementC, typename LayoutC, -+ typename ElementAccumulator, -+ typename ElementCompute -+> -+inline std::ostream &EncodeTypes( -+ std::ostream &out -+) { -+ -+ out << TensorTypeName() << "_" -+ << TensorTypeName() << "_" -+ << TensorTypeName() << "_" -+ << ElementTypeName() << "_" -+ << ElementTypeName(); -+ -+ return out; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA, typename LayoutA, -+ typename ElementB, typename LayoutB, -+ typename ElementC, typename LayoutC, -+ typename ElementAccumulator, -+ typename ElementCompute -+> -+inline CachedTestKey CreateCachedGemmTestKey( -+ cutlass::gemm::GemmCoord const &problem, -+ ElementCompute alpha, -+ ElementCompute beta, -+ cutlass::TensorView A, -+ cutlass::TensorView B, -+ cutlass::TensorView C -+) { -+ -+ CachedTestKey key; -+ -+ // Encode gemm operator and problem sizes -+ key.op = "gemm"; -+ -+ std::stringstream ss_problem; -+ EncodeProblemSize(ss_problem, problem); -+ ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); -+ key.problem = ss_problem.str(); -+ -+ // Encode problem data types -+ std::stringstream ss_types; -+ EncodeTypes< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementAccumulator, -+ ElementCompute>(ss_types); -+ key.types = ss_types.str(); -+ -+ // Encode hash for problem data -+ CRC32 crc_hash; -+ key.A = TensorHash(A, crc_hash); -+ key.B = TensorHash(B, crc_hash); -+ key.C = TensorHash(C, crc_hash); -+ -+ return key; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+template < -+ typename ElementA, typename LayoutA, -+ typename ElementB, typename LayoutB, -+ typename ElementC, typename LayoutC, -+ typename ElementAccumulator, -+ typename ElementCompute -+> -+inline CachedTestKey CreateCachedConv2dTestKey( -+ -+ cutlass::conv::Operator conv_operator, -+ cutlass::conv::Conv2dProblemSize const &problem, -+ ElementCompute alpha, -+ ElementCompute beta, -+ cutlass::TensorView A, -+ cutlass::TensorView B, -+ cutlass::TensorView C -+) { -+ -+ CachedTestKey key; -+ -+ // Encode conv2d operator and problem sizes -+ key.op = "conv2d"; -+ -+ std::stringstream ss_problem; -+ ss_problem << EncodeOperator(conv_operator) << "_"; -+ EncodeProblemSize(ss_problem, problem); -+ ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); -+ -+ key.problem = ss_problem.str(); -+ -+ // Encode problem data types -+ std::stringstream ss_types; -+ EncodeTypes< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementAccumulator, -+ ElementCompute>(ss_types); -+ key.types = ss_types.str(); -+ -+ // Encode hash for problem data -+ CRC32 crc_hash; -+ -+ key.A = TensorHash(A, crc_hash); -+ key.B = TensorHash(B, crc_hash); -+ key.C = TensorHash(C, crc_hash); -+ -+ return key; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA, typename LayoutA, -+ typename ElementB, typename LayoutB, -+ typename ElementC, typename LayoutC, -+ typename ElementAccumulator, -+ typename ElementCompute -+> -+inline CachedTestKey CreateCachedConv2dWithBroadcastTestKey( -+ -+ cutlass::conv::Operator conv_operator, -+ cutlass::conv::Conv2dProblemSize const &problem, -+ ElementCompute alpha, -+ ElementCompute beta, -+ cutlass::TensorView A, -+ cutlass::TensorView B, -+ cutlass::TensorView C -+) { -+ -+ CachedTestKey key; -+ -+ // Encode conv2d operator and problem sizes -+ key.op = "conv2d_with_broadcast"; -+ -+ std::stringstream ss_problem; -+ ss_problem << EncodeOperator(conv_operator) << "_"; -+ EncodeProblemSize(ss_problem, problem); -+ ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); -+ -+ key.problem = ss_problem.str(); -+ -+ // Encode problem data types -+ std::stringstream ss_types; -+ EncodeTypes< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementAccumulator, -+ ElementCompute>(ss_types); -+ key.types = ss_types.str(); -+ -+ // Encode hash for problem data -+ CRC32 crc_hash; -+ -+ key.A = TensorHash(A, crc_hash); -+ key.B = TensorHash(B, crc_hash); -+ key.C = TensorHash(C, crc_hash); -+ -+ return key; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA, typename LayoutA, -+ typename ElementB, typename LayoutB, -+ typename ElementC, typename LayoutC, -+ typename ElementAccumulator, -+ typename ElementCompute -+> -+inline CachedTestKey CreateCachedConv2dWithReductionTestKey( -+ -+ cutlass::conv::Operator conv_operator, -+ cutlass::conv::Conv2dProblemSize const &problem, -+ ElementCompute alpha, -+ ElementCompute beta, -+ cutlass::TensorView A, -+ cutlass::TensorView B, -+ cutlass::TensorView C -+) { -+ -+ CachedTestKey key; -+ -+ // Encode conv2d operator and problem sizes -+ key.op = "conv2d_with_reduction"; -+ -+ std::stringstream ss_problem; -+ ss_problem << EncodeOperator(conv_operator) << "_"; -+ EncodeProblemSize(ss_problem, problem); -+ ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); -+ -+ key.problem = ss_problem.str(); -+ -+ // Encode problem data types -+ std::stringstream ss_types; -+ EncodeTypes< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementAccumulator, -+ ElementCompute>(ss_types); -+ key.types = ss_types.str(); -+ -+ // Encode hash for problem data -+ CRC32 crc_hash; -+ -+ key.A = TensorHash(A, crc_hash); -+ key.B = TensorHash(B, crc_hash); -+ key.C = TensorHash(C, crc_hash); -+ -+ return key; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA, typename LayoutA, -+ typename ElementB, typename LayoutB, -+ typename ElementC, typename LayoutC, -+ typename ElementAccumulator, -+ typename ElementCompute -+> -+inline CachedTestKey CreateCachedConv3dTestKey( -+ cutlass::conv::Operator conv_operator, -+ cutlass::conv::Conv3dProblemSize const &problem, -+ ElementCompute alpha, -+ ElementCompute beta, -+ cutlass::TensorView A, -+ cutlass::TensorView B, -+ cutlass::TensorView C -+) { -+ -+ CachedTestKey key; -+ -+ // Encode conv3d operator and problem sizes -+ key.op = "conv3d"; -+ -+ std::stringstream ss_problem; -+ -+ ss_problem << EncodeOperator(conv_operator) << "_"; -+ EncodeProblemSize(ss_problem, problem); -+ ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); -+ -+ key.problem = ss_problem.str(); -+ -+ // Encode problem data types -+ std::stringstream ss_types; -+ EncodeTypes< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementAccumulator, -+ ElementCompute>(ss_types); -+ key.types = ss_types.str(); -+ -+ // Encode problem data -+ CRC32 crc_hash; -+ key.A = TensorHash(A, crc_hash); -+ key.B = TensorHash(B, crc_hash); -+ key.C = TensorHash(C, crc_hash); -+ -+ return key; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // nammespace conv -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu -new file mode 100644 -index 0000000..cbabe42 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu -@@ -0,0 +1,136 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_dgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM50_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, -+ 64x64_8x2_32x64x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::complex; -+ using ElementB = cutlass::complex; -+ using ElementC = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ using ElementCompute = cutlass::complex; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::conv::IteratorAlgorithm::kAnalytic, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM50_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, -+ 32x64_8x2_32x64x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::complex; -+ using ElementB = cutlass::complex; -+ using ElementC = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ using ElementCompute = cutlass::complex; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu -new file mode 100644 -index 0000000..08e3abd ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu -@@ -0,0 +1,141 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_dgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, -+ 128x128_8x4_64x32x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::complex; -+ using ElementB = cutlass::complex; -+ using ElementC = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ using ElementCompute = cutlass::complex; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::conv::IteratorAlgorithm::kAnalytic, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, -+ 128x128_8x4_64x32x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::complex; -+ using ElementB = cutlass::complex; -+ using ElementC = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ using ElementCompute = cutlass::complex; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu -new file mode 100644 -index 0000000..eaade32 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu -@@ -0,0 +1,298 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_dgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align4, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 4, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic, -+ cutlass::conv::StrideSupport::kUnity, -+ 4, -+ 4 -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 12}, // input size (NHWC) -+ {8, 3, 3, 12}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align4, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 4, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kUnity, -+ 4, -+ 4 -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 12}, // input size (NHWC) -+ {8, 3, 3, 12}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align2, -+ 128x64_64x3_64x32x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 2, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kUnity, -+ 2, -+ 2 -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {35, 100, 50, 64}, // input size (NHWC) -+ {22, 1, 1, 64}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..55d9525 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu -@@ -0,0 +1,126 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_dgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM70_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu -new file mode 100644 -index 0000000..62836ab ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu -@@ -0,0 +1,238 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_dgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride_align2, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 4, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic, -+ cutlass::conv::StrideSupport::kUnity, -+ 2, -+ 2 -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 12}, // input size (NHWC) -+ {8, 3, 3, 12}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride_align2, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 4, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kUnity, -+ 2, -+ 2 -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 12}, // input size (NHWC) -+ {8, 3, 3, 12}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..0891f80 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu -@@ -0,0 +1,209 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_dgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride, -+ 128x128_32x4_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride, -+ 128x128_64x4_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu -new file mode 100644 -index 0000000..845e86b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu -@@ -0,0 +1,141 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_dgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, -+ 128x128_8x4_32x64x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = float; -+ using ElementB = float; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, -+ 128x128_8x4_64x32x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = float; -+ using ElementB = float; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..3a7b380 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu -@@ -0,0 +1,130 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_dgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::tfloat32_t; -+ using ElementB = cutlass::tfloat32_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::tfloat32_t; -+ using ElementB = cutlass::tfloat32_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..42e85be ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu -@@ -0,0 +1,303 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+std::vector Conv2dFewChannelProblemSizes(int channels) { -+ -+ std::vector problems; -+ -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 8, 8, channels}, // input size (NHWC) -+ {16, 3, 3, channels}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 16, 16, channels}, // input size (NHWC) -+ {16, 3, 3, channels}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 16, 16, channels}, // input size (NHWC) -+ {16, 7, 7, channels}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 224, 224, channels}, // input size (NHWC) -+ {32, 7, 7, channels}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 224, 224, channels}, // input size (NHWC) -+ {64, 7, 7, channels}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 224, 224, channels}, // input size (NHWC) -+ {64, 5, 5, channels}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 224, 224, channels}, // input size (NHWC) -+ {64, 5, 5, channels}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ return problems; -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Few_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_8, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ int const kChannelCount = 8; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kFewChannels, -+ cutlass::conv::StrideSupport::kStrided, -+ kChannelCount, -+ kChannelCount -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestSpecificConv2d( -+ Conv2dFewChannelProblemSizes(kChannelCount))); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Few_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_4, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ int const kChannelCount = 4; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kFewChannels, -+ cutlass::conv::StrideSupport::kStrided, -+ kChannelCount, -+ kChannelCount -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestSpecificConv2d( -+ Conv2dFewChannelProblemSizes(kChannelCount))); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Few_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_2, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ int const kChannelCount = 2; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kFewChannels, -+ cutlass::conv::StrideSupport::kStrided, -+ kChannelCount, -+ kChannelCount -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestSpecificConv2d( -+ Conv2dFewChannelProblemSizes(2 * kChannelCount))); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Few_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_1, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ int const kChannelCount = 1; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kFewChannels, -+ cutlass::conv::StrideSupport::kStrided, -+ kChannelCount, -+ kChannelCount -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestSpecificConv2d( -+ Conv2dFewChannelProblemSizes(3 * kChannelCount))); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..e6f676b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu -@@ -0,0 +1,240 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+std::vector Conv2dFixedChannelProblemSizes(int channels) { -+ -+ std::vector problems; -+ -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 8, 8, channels}, // input size (NHWC) -+ {16, 3, 3, channels}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 224, 224, channels}, // input size (NHWC) -+ {32, 7, 7, channels}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 224, 224, channels}, // input size (NHWC) -+ {64, 7, 7, channels}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 224, 224, channels}, // input size (NHWC) -+ {64, 5, 5, channels}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 224, 224, channels}, // input size (NHWC) -+ {64, 5, 5, channels}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ return problems; -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Fixed_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_8, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ int const kChannelCount = 8; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kFixedChannels, -+ cutlass::conv::StrideSupport::kStrided, -+ kChannelCount, -+ kChannelCount -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d( -+ Conv2dFixedChannelProblemSizes(kChannelCount))); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Fixed_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_4, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ int const kChannelCount = 4; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kFixedChannels, -+ cutlass::conv::StrideSupport::kStrided, -+ kChannelCount, -+ kChannelCount -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d( -+ Conv2dFixedChannelProblemSizes(kChannelCount))); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Fixed_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_2, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ int const kChannelCount = 2; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kFixedChannels, -+ cutlass::conv::StrideSupport::kStrided, -+ kChannelCount, -+ kChannelCount -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d( -+ Conv2dFixedChannelProblemSizes(kChannelCount))); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu -new file mode 100644 -index 0000000..f892d33 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu -@@ -0,0 +1,138 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Device_Conv2d_Fprop_Analytic_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, -+ 32x64_8x2_32x32x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::complex; -+ using ElementB = cutlass::complex; -+ using ElementC = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ using ElementCompute = cutlass::complex; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>, -+ 2, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Device_Conv2d_Fprop_Analytic_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, -+ 32x128_8x2_16x64x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::complex; -+ using ElementB = cutlass::complex; -+ using ElementC = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ using ElementCompute = cutlass::complex; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ cutlass::gemm::GemmShape<32, 128, 8>, -+ cutlass::gemm::GemmShape<16, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>, -+ 2, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu -new file mode 100644 -index 0000000..e320c77 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu -@@ -0,0 +1,137 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, -+ 128x128_8x4_64x32x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::complex; -+ using ElementB = cutlass::complex; -+ using ElementC = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ using ElementCompute = cutlass::complex; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, -+ 128x128_8x5_64x32x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::complex; -+ using ElementB = cutlass::complex; -+ using ElementC = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ using ElementCompute = cutlass::complex; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 5, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu -new file mode 100644 -index 0000000..64d40d8 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu -@@ -0,0 +1,134 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM60_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_simt_f16, -+ 128x128_8x2_64x64x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm60, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM60_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_simt_f16, -+ 128x128_8x2_64x64x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm60, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu -new file mode 100644 -index 0000000..af476db ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu -@@ -0,0 +1,350 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align2, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 8, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic, -+ cutlass::conv::StrideSupport::kStrided, -+ 2, -+ 2 -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 12}, // input size (NHWC) -+ {8, 3, 3, 12}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 14}, // input size (NHWC) -+ {8, 3, 3, 14}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 23, 56, 98}, // input size (NHWC) -+ {128, 3, 3, 98}, // filter size (KRSC) -+ {4, 0, 5, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align2, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 8, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kStrided, -+ 2, -+ 2 -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 12}, // input size (NHWC) -+ {8, 3, 3, 12}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 14}, // input size (NHWC) -+ {8, 3, 3, 14}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 23, 56, 98}, // input size (NHWC) -+ {128, 3, 3, 98}, // filter size (KRSC) -+ {4, 0, 5, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align4, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 8, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kStrided, -+ 4, -+ 4 -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 12}, // input size (NHWC) -+ {8, 3, 3, 12}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 28}, // input size (NHWC) -+ {8, 3, 3, 28}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 23, 56, 100}, // input size (NHWC) -+ {128, 3, 3, 100}, // filter size (KRSC) -+ {4, 0, 5, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..b681dd2 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu -@@ -0,0 +1,132 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..a848192 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu -@@ -0,0 +1,127 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) -+ -+TEST(SM70_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM70_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu -new file mode 100644 -index 0000000..a3e96e2 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu -@@ -0,0 +1,293 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align2, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 8, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kStrided, -+ 2, -+ 2 -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 12}, // input size (NHWC) -+ {8, 3, 3, 12}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 8, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kStrided, -+ 4, -+ 4 -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 12}, // input size (NHWC) -+ {8, 3, 3, 12}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 8, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic, -+ cutlass::conv::StrideSupport::kStrided, -+ 4, -+ 4 -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 12}, // input size (NHWC) -+ {8, 3, 3, 12}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..7b68a68 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu -@@ -0,0 +1,130 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#if 0 -+TEST(SM80_Device_Conv2d_Fprop_Precomputed_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+#endif -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm50.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm50.cu -new file mode 100644 -index 0000000..8f8eb88 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm50.cu -@@ -0,0 +1,88 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, -+ 128x128_8x2_64x64x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = float; -+ using ElementB = float; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu -new file mode 100644 -index 0000000..6b4fe2e ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu -@@ -0,0 +1,137 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, -+ 128x128_8x4_32x64x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = float; -+ using ElementB = float; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, -+ 128x128_8x4_64x32x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = float; -+ using ElementB = float; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_qf32nhwc_qf32nhwc_qf32nhwc_simt_f32_sm50.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_qf32nhwc_qf32nhwc_qf32nhwc_simt_f32_sm50.cu -new file mode 100755 -index 0000000..2ac1dfd ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_qf32nhwc_qf32nhwc_qf32nhwc_simt_f32_sm50.cu -@@ -0,0 +1,227 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Device_Conv2d_Fprop_Analytic_ImplicitGemm_qf32nhwc_qf32nhwc_qf32nhwc_simt_f32, -+ 16x32_8x2_16x16x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::Quaternion; -+ using ElementB = cutlass::Quaternion; -+ using ElementC = cutlass::Quaternion; -+ using ElementAccumulator = cutlass::Quaternion; -+ using ElementCompute = cutlass::Quaternion; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ cutlass::gemm::GemmShape<16, 32, 8>, -+ cutlass::gemm::GemmShape<16, 16, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Device_Conv2d_Fprop_Analytic_ImplicitGemm_qf32nhwc_qf32nhwc_qf32nhwc_simt_f32, -+ 16x64_8x2_8x32x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::Quaternion; -+ using ElementB = cutlass::Quaternion; -+ using ElementC = cutlass::Quaternion; -+ using ElementAccumulator = cutlass::Quaternion; -+ using ElementCompute = cutlass::Quaternion; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ cutlass::gemm::GemmShape<16, 64, 8>, -+ cutlass::gemm::GemmShape<8, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM50_Device_Conv2d_Fprop_Analytic_ImplicitGemm_qf32nhwc_qf32nhwc_qf32nhwc_simt_f32, -+ 32x32_8x2_16x16x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::Quaternion; -+ using ElementB = cutlass::Quaternion; -+ using ElementC = cutlass::Quaternion; -+ using ElementAccumulator = cutlass::Quaternion; -+ using ElementCompute = cutlass::Quaternion; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<16, 16, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM50_Device_Conv2d_Fprop_Optimized_ImplicitGemm_qf32nhwc_qf32nhwc_qf32nhwc_simt_f32, -+ 16x32_8x2_16x16x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::Quaternion; -+ using ElementB = cutlass::Quaternion; -+ using ElementC = cutlass::Quaternion; -+ using ElementAccumulator = cutlass::Quaternion; -+ using ElementCompute = cutlass::Quaternion; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ cutlass::gemm::GemmShape<16, 32, 8>, -+ cutlass::gemm::GemmShape<16, 16, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..0f794f2 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32_sm75.cu -@@ -0,0 +1,526 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed_interleaved.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+ -+TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 128x128_128x2_64x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 256x128_128x2_64x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 128x256_128x2_64x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 256x64_128x2_64x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 64x256_128x2_64x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 64x128_128x2_32x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 128x128_128x2_64x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 256x128_128x2_64x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 128x256_128x2_64x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 256x64_128x2_64x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 64x256_128x2_64x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 64x128_128x2_32x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32_sm80.cu -new file mode 100644 -index 0000000..9af25ab ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32_sm80.cu -@@ -0,0 +1,527 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed_interleaved.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 128x128_128x3_64x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 256x128_128x3_64x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 128x256_128x3_64x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 256x64_128x3_64x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 64x256_128x3_64x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 64x128_128x4_32x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 4, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 128x128_128x3_64x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 256x128_128x3_64x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 128x256_128x3_64x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 256x64_128x3_64x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 64x256_128x3_64x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s4ncxhwx_s4cxrskx_s4ncxhwx_tensor_op_s32, -+ 64x128_128x4_32x64x128) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<64>, -+ ElementB, cutlass::layout::TensorCxRSKx<64>, -+ ElementC, cutlass::layout::TensorNCxHWx<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 4, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4nhwc_s4nhwc_s32nhwc_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4nhwc_s4nhwc_s32nhwc_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..096e44f ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4nhwc_s4nhwc_s32nhwc_tensor_op_s32_sm75.cu -@@ -0,0 +1,125 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s4nhwc_s4nhwc_s32nhwc_tensor_op_s32, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s4nhwc_s4nhwc_s32nhwc_tensor_op_s32, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4nhwc_s4nhwc_s32nhwc_tensor_op_s32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4nhwc_s4nhwc_s32nhwc_tensor_op_s32_sm80.cu -new file mode 100644 -index 0000000..d285a2d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4nhwc_s4nhwc_s32nhwc_tensor_op_s32_sm80.cu -@@ -0,0 +1,127 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s4nhwc_s4nhwc_s32nhwc_tensor_op_s32, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s4nhwc_s4nhwc_s32nhwc_tensor_op_s32, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..9d77fb1 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32_sm75.cu -@@ -0,0 +1,685 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed_interleaved.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 128x128_64x2_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 256x128_64x2_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 128x256_64x2_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 256x64_64x2_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 64x256_64x2_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 128x64_64x2_64x32x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 64x128_64x2_32x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 64x64_64x2_32x32x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 128x128_64x2_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 256x128_64x2_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 128x256_64x2_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 256x64_64x2_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 64x256_64x2_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 128x64_64x2_64x32x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 64x128_64x2_32x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 64x64_64x2_32x32x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32_sm80.cu -new file mode 100644 -index 0000000..120ce06 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32_sm80.cu -@@ -0,0 +1,686 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed_interleaved.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 256x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 128x256_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 256x64_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 64x256_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 128x64_64x4_64x32x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 4, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 64x128_64x4_32x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 4, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 64x64_64x6_32x32x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 6, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 256x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 128x256_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 256x64_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 64x256_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 128x64_64x4_64x32x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 4, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 64x128_64x4_32x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 4, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s8ncxhwx_s8cxrskx_s8ncxhwx_tensor_op_s32, -+ 64x64_64x6_32x32x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNCxHWx<32>, -+ ElementB, cutlass::layout::TensorCxRSKx<32>, -+ ElementC, cutlass::layout::TensorNCxHWx<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementC, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, -+ 6, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE((test::conv::device::TestAllInterleavedConv2d())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..d15f5c9 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32_sm75.cu -@@ -0,0 +1,125 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+TEST(SM75_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32, -+ 128x128_64x2_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32, -+ 128x128_64x2_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32_sm80.cu -new file mode 100644 -index 0000000..e192d65 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32_sm80.cu -@@ -0,0 +1,126 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAddSaturate, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..d15a435 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu -@@ -0,0 +1,142 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::tfloat32_t; -+ using ElementB = cutlass::tfloat32_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 8, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align2, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::tfloat32_t; -+ using ElementB = cutlass::tfloat32_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 8, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kStrided, -+ 2, -+ 2 -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 12}, // input size (NHWC) -+ {8, 3, 3, 12}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_with_broadcast_sm70.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_with_broadcast_sm70.cu -new file mode 100644 -index 0000000..19dc3c9 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_with_broadcast_sm70.cu -@@ -0,0 +1,127 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" -+#include "cutlass/epilogue/thread/linear_combination_residual_block.h" -+#include "cutlass/epilogue/thread/activation.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_with_broadcast_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) -+ -+// Test residual block fusion: UnaryOp(BinaryOp(ActivationOp(Conv2d(X) + bias), residual)) -+// LinearCombinationResidualBlock does not support the split-k mode unless ActivationOp is Identity. -+// This is because the activation needs to be applied to the fully accumulated output of the Conv2d op, -+// which only the last thread block would have an access to, before applying BinaryOp. -+// The epilogue functor in the last thread block would have to be given three inputs, namely -+// partial outputs, bias, and residual, but this is not supported in the current interface. -+// Set TestSplitK = false to skip split-k tests with non-trivial ActivationOp. -+template < -+ typename ElementAccumulator, -+ template class ActivationOp, -+ template class BinaryOp, -+ template class UnaryOp, -+ bool TestSplitK = false -+> -+void TestResidaulBlock() { -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementD = ElementC; -+ using ElementCompute = ElementAccumulator; -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationResidualBlock< -+ ElementD, -+ ElementAccumulator, -+ ElementCompute, -+ ElementC, -+ 8, -+ ActivationOp, -+ BinaryOp, -+ UnaryOp -+ >; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFpropWithBroadcast< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ struct ReferenceOp { -+ using OutputOp = typename Conv2dFprop::EpilogueOutputOp; -+ using ElementZ = typename OutputOp::ElementZ; -+ -+ ActivationOp activation; -+ BinaryOp binary_op; -+ UnaryOp unary_op; -+ -+ void operator()(ElementZ &Z, ElementZ&, ElementCompute conv2d, ElementCompute residual) { -+ Z = ElementZ(unary_op(binary_op(activation(conv2d), residual))); -+ } -+ }; -+ -+ bool passed = test::conv::device::TestAllConv2dWithBroadcast(); -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM70_Device_Conv2d_Fprop_With_Residual_Block_Plus_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x2_64x64x32) { -+ // Resnet -+ TestResidaulBlock(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_MMA_SM70_SUPPORTED -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_with_broadcast_sm75.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_with_broadcast_sm75.cu -new file mode 100644 -index 0000000..17c77be ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_with_broadcast_sm75.cu -@@ -0,0 +1,177 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" -+#include "cutlass/epilogue/thread/linear_combination_residual_block.h" -+#include "cutlass/epilogue/thread/activation.h" -+#include "cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_with_broadcast_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+TEST(SM75_Device_Conv2d_Fprop_With_Broadcast_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasElementwise< -+ cutlass::half_t, -+ float, -+ float, -+ cutlass::half_t, -+ cutlass::half_t, -+ 8, -+ cutlass::epilogue::thread::ReLu -+ >; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFpropWithBroadcast< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2dWithBroadcast()); -+} -+ -+// Test residual block fusion: UnaryOp(BinaryOp(ActivationOp(Conv2d(X) + bias), residual)) -+// LinearCombinationResidualBlock does not support the split-k mode unless ActivationOp is Identity. -+// This is because the activation needs to be applied to the fully accumulated output of the Conv2d op, -+// which only the last thread block would have an access to, before applying BinaryOp. -+// The epilogue functor in the last thread block would have to be given three inputs, namely -+// partial outputs, bias, and residual, but this is not supported in the current interface. -+// Set TestSplitK = false to skip split-k tests with non-trivial ActivationOp. -+template < -+ typename ElementAccumulator, -+ template class ActivationOp, -+ template class BinaryOp, -+ template class UnaryOp, -+ bool TestSplitK = true -+> -+void TestResidaulBlock() { -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementD = ElementC; -+ using ElementCompute = ElementAccumulator; -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationResidualBlock< -+ ElementD, -+ ElementAccumulator, -+ ElementCompute, -+ ElementC, -+ 8, -+ ActivationOp, -+ BinaryOp, -+ UnaryOp -+ >; -+ -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFpropWithBroadcast< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ struct ReferenceOp { -+ using OutputOp = typename Conv2dFprop::EpilogueOutputOp; -+ using ElementZ = typename OutputOp::ElementZ; -+ -+ ActivationOp activation; -+ BinaryOp binary_op; -+ UnaryOp unary_op; -+ -+ void operator()(ElementZ &Z, ElementZ&, ElementCompute conv2d, ElementCompute residual) { -+ Z = ElementZ(unary_op(binary_op(activation(conv2d), residual))); -+ } -+ }; -+ -+ bool passed = test::conv::device::TestAllConv2dWithBroadcast(); -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM75_Device_Conv2d_Fprop_With_Residual_Block_Plus_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x2_64x64x32) { -+ // Resnet -+ TestResidaulBlock(); -+} -+ -+TEST(SM75_Device_Conv2d_Fprop_With_Residual_Block_Multiply_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x2_64x64x32) { -+ // EfficientNet V2 -+ // Do not run split-K tests since the activation op is not Identity. -+ TestResidaulBlock(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_with_reduction_sm75.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_with_reduction_sm75.cu -new file mode 100644 -index 0000000..dc56278 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_fprop_with_reduction_sm75.cu -@@ -0,0 +1,94 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/epilogue/thread/linear_combination_with_elementwise.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_with_reduction_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+TEST(SM75_Device_Conv2d_Fprop_With_Reduction_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationWithElementwise< -+ float, -+ float, -+ cutlass::half_t, -+ cutlass::half_t, -+ 8 -+ >; -+ -+ /// Device-level Conv2d instance -+ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFpropWithReduction< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ EpilogueOutputOp, -+ cutlass::plus, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2dWithReduction()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_problems.h b/3rdparty/cutlass/test/unit/conv/device/conv2d_problems.h -new file mode 100644 -index 0000000..5d1fbdc ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_problems.h -@@ -0,0 +1,860 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Implicit GEMM testbed sizes for Conv2d problem -+*/ -+#pragma once -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+ -+namespace test { -+namespace conv { -+namespace device { -+ -+using Conv2dProblemVector = std::vector; -+ -+// -+// Structures to prune items from Conv2dProblemVector -+// -+// Specification template for pruning items for convolution problem lists -+template struct Specification -+{ -+ virtual ~Specification() = default; -+ virtual bool is_satisfied(T item) const = 0; -+}; -+ -+// input size (NHWC) specification -+struct InputSizeSpecification : Specification -+{ -+ cutlass::Tensor4DCoord input_size; -+ -+ InputSizeSpecification(cutlass::Tensor4DCoord input_size_) : input_size(input_size_) {} -+ -+ bool is_satisfied(cutlass::conv::Conv2dProblemSize item) const override { -+ return ((input_size.n() == item.N) && (input_size.h() == item.H) && (input_size.w() == item.W) && (input_size.c() == item.C)); -+ } -+}; -+ -+// stride (stride_h, stride_w) specification -+struct StrideSpecification : Specification -+{ -+ cutlass::MatrixCoord stride; -+ -+ StrideSpecification(cutlass::MatrixCoord stride_) : stride(stride_) {} -+ -+ bool is_satisfied(cutlass::conv::Conv2dProblemSize item) const override { -+ return ((stride.row() == item.stride_h) && (stride.column() == item.stride_h)); -+ } -+}; -+ -+// channel (C,K) specification, must be multiple of minimum channel -+struct ChannelDivisibilitySpecification : Specification -+{ -+ int channel_multiple; -+ -+ ChannelDivisibilitySpecification(int channel_multiple_) : channel_multiple(channel_multiple_) {} -+ -+ bool is_satisfied(cutlass::conv::Conv2dProblemSize item) const override { -+ return ((item.K % channel_multiple == 0) && (item.C % channel_multiple == 0)); -+ } -+}; -+ -+// -+// Pruning function for items from Conv2dProblemVector based on a Specification -+// -+inline Conv2dProblemVector prune(Conv2dProblemVector const &items, -+ Specification const &spec) -+{ -+ Conv2dProblemVector pruned_list; -+ -+ for (auto& p : items) -+ if (spec.is_satisfied(p)) -+ pruned_list.push_back(p); -+ return pruned_list; -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////// -+/// Structure TestbedConv2dProblemSizes initializes and holds conv default and -+/// important network sizes -+//////////////////////////////////////////////////////////////////////////// -+struct TestbedConv2dProblemSizes { -+ -+ // -+ // Data members -+ // -+ int minimum_channel_size; -+ -+ Conv2dProblemVector conv2d_default_sizes; -+ Conv2dProblemVector conv2d_rigorous_sizes; -+ Conv2dProblemVector conv2d_resnet50_sizes; -+ Conv2dProblemVector conv2d_resnet50_sizes_perf; -+ -+ // -+ // Methods -+ // -+ /// Default ctor -+ TestbedConv2dProblemSizes(int minimum_channel_size_ = 64): minimum_channel_size (minimum_channel_size_) { -+ initialize_conv2d_default_sizes(); -+ initialize_conv2d_rigorous_sizes(); -+ initialize_conv2d_resnet50_sizes(conv2d_resnet50_sizes, 1 /*batch-size*/); -+ -+ initialize_conv2d_resnet50_sizes(conv2d_resnet50_sizes_perf, 34 /*batch-size*/); -+ filter_all(); -+ } -+ -+ /// Eliminates some illegal cases -+ void filter_all() { -+ -+ Conv2dProblemVector *problems_vectors[] = { -+ &conv2d_default_sizes, -+ &conv2d_rigorous_sizes, -+ &conv2d_resnet50_sizes, -+ &conv2d_resnet50_sizes_perf -+ }; -+ -+ for (Conv2dProblemVector *problems : problems_vectors) { -+ Conv2dProblemVector filtered; -+ -+ for (cutlass::conv::Conv2dProblemSize const & problem : *problems) { -+ if (!(problem.C % minimum_channel_size)) { -+ filtered.push_back(problem); -+ } -+ } -+ -+ *problems = filtered; -+ } -+ } -+ -+ // Add a few standard convolution problem sizes -+ void initialize_conv2d_default_sizes() { -+ -+ //////////////////////////////////////////////////////////////////////////////////////////// -+ // Small input size x stride (1,1) -+ // C < CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64} -+ //////////////////////////////////////////////////////////////////////////////////////////// -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 1, 1, minimum_channel_size}, // input size (NHWC) -+ {8, 1, 1, minimum_channel_size}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 1, 8, minimum_channel_size}, // input size (NHWC) -+ {8, 1, 3, minimum_channel_size}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 7, 8, minimum_channel_size}, // input size (NHWC) -+ {8, 3, 3, minimum_channel_size}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 7, 9, minimum_channel_size}, // input size (NHWC) -+ {8, 4, 4, minimum_channel_size}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {2, 7, 9, minimum_channel_size}, // input size (NHWC) -+ {8, 5, 5, minimum_channel_size}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {3, 7, 9, minimum_channel_size}, // input size (NHWC) -+ {8, 6, 5, minimum_channel_size}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {3, 7, 9, minimum_channel_size}, // input size (NHWC) -+ {8, 6, 6, minimum_channel_size}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {3, 7, 9, minimum_channel_size}, // input size (NHWC) -+ {8, 7, 7, minimum_channel_size}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ //////////////////////////////////////////////////////////////////////////////////////////// -+ // Small input size x stride (2,2) -+ // C < CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64} -+ //////////////////////////////////////////////////////////////////////////////////////////// -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 11, 7, minimum_channel_size}, // input size (NHWC) -+ {8, 1, 1, minimum_channel_size}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 11, 7, minimum_channel_size}, // input size (NHWC) -+ {8, 3, 3, minimum_channel_size}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 13, 11, minimum_channel_size}, // input size (NHWC) -+ {8, 1, 1, minimum_channel_size}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 17, 19, minimum_channel_size}, // input size (NHWC) -+ {16, 2, 2, minimum_channel_size}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 23, 5, minimum_channel_size}, // input size (NHWC) -+ {16, 3, 3, minimum_channel_size}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 13, 17, 8}, // input size (NHWC) -+ {24, 3, 3, 8}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 23, 21, 8}, // input size (NHWC) -+ {24, 3, 3, 8}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 20, 24, 8}, // input size (NHWC) -+ {40, 3, 3, 8}, // filter size (KRSC) -+ {3, 3, 3, 3}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ //////////////////////////////////////////////////////////////////////////////////// -+ // Medium input size (1x16x16x128), filter size (1x1, 2x2, 3x3, 5x5), stride (1, 1) -+ //////////////////////////////////////////////////////////////////////////////////// -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 15, 19, 160}, // input size (NHWC) -+ {224, 1, 1, 160}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 19, 37, 160}, // input size (NHWC) -+ {224, 3, 3, 160}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 16, 16, 160}, // input size (NHWC) -+ {224, 2, 3, 160}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 23, 21, 128}, // input size (NHWC) -+ {224, 3, 3, 128}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 29, 37, 160}, // input size (NHWC) -+ {224, 5, 5, 160}, // filter size (KRSC) -+ {2, 2, 2, 2}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ //////////////////////////////////////////////////////////////////////////////////// -+ // C > CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64} -+ //////////////////////////////////////////////////////////////////////////////////// -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 15, 19, 32 + minimum_channel_size}, // input size (NHWC) -+ {96, 3, 3, 32 + minimum_channel_size}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 16, 24, 64 + minimum_channel_size}, // input size (NHWC) -+ {96, 3, 3, 64 + minimum_channel_size}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ //////////////////////////////////////////////////////////////////////////////////// -+ // Medium input size, filter size (1x1, 3,x3, 5x5, 7x7), stride (2, 2) -+ //////////////////////////////////////////////////////////////////////////////////// -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 13, 16, 288}, // input size (NHWC) -+ {160, 5, 5, 288}, // filter size (KRSC) -+ {2, 2, 2, 2}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 55, 51, 256}, // input size (NHWC) -+ {512, 1, 1, 256}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 71, 80, 32}, // input size (NHWC) -+ {64, 5, 5, 32}, // filter size (KRSC) -+ {2, 2, 2, 2}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 224, 224, 8}, // input size (NHWC) -+ {64, 7, 7, 8}, // filter size (KRSC) -+ {3, 3, 3, 3}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ //////////////////////////////////////////////////////////////////////////////////// -+ // Medium input size stride (3, 3), filter (3, 3), non-default padding -+ //////////////////////////////////////////////////////////////////////////////////// -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 27, 23, 256}, // input size (NHWC) -+ {512, 3, 3, 256}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ //////////////////////////////////////////////////////////////////////////////////// -+ // Medium input size padding > stride, asymmetric filter, padding and striding -+ //////////////////////////////////////////////////////////////////////////////////// -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 27, 31, 256}, // input size (NHWC) -+ {512, 3, 3, 256}, // filter size (KRSC) -+ {5, 5, 7, 7}, // padding (pad_h, _, pad_w, _) -+ {3, 4}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 27, 35, 256}, // input size (NHWC) -+ {512, 7, 5, 256}, // filter size (KRSC) -+ {11, 11, 7, 7}, // padding (pad_h, _, pad_w, _) -+ {3, 5}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ //////////////////////////////////////////////////////////////////////////////////// -+ // Medium input size *mixed* stride (1, 2) and (2, 1), -+ // filter (3, 3), default padding -+ //////////////////////////////////////////////////////////////////////////////////// -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 27, 27, 256}, // input size (NHWC) -+ {512, 3, 3, 256}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 27, 27, 256}, // input size (NHWC) -+ {512, 3, 3, 256}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ ///////////////////////////////////////////////////////////////////////////// -+ // Additional input size -+ ///////////////////////////////////////////////////////////////////////////// -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {3, 28, 28, 256}, // input size (NHWC) -+ {256, 2, 2, 256}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 32, 32, 16}, // input size (NHWC) -+ {32, 3, 3, 16}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {6, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {32, 24, 32, 32}, // input size (NHWC) -+ {32, 1, 2, 32}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {4, 4, 5, 128}, // input size (NHWC) -+ {256, 3, 6, 128}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ {4, 3, 3, 256} // output size (NPQK) -+ )); -+ -+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {4, 2, 3, 256}, // input size (NHWC) -+ {328, 3, 5, 256}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ {4, 1, 1, 328} // output size (NPQK) -+ )); -+ } -+ -+ -+ // Add a few large and rigorous convolution problem sizes -+ void initialize_conv2d_rigorous_sizes() { -+ -+#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED -+ conv2d_rigorous_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 124, 224, 96}, // input size (NHWC) -+ {24, 7, 7, 96}, // filter size (KRSC) -+ {1, 229, 129, 32} // output size (NPQK) -+ )); -+ -+ conv2d_rigorous_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 233, 35, 48}, // input size (NHWC) -+ {24, 7, 5, 48}, // filter size (KRSC) -+ {1, 233, 35, 24} // output size (NPQK) -+ )); -+ -+#endif -+ -+ } -+ -+ -+ // Add resent50 layers to unit testing sizes -+ void initialize_conv2d_resnet50_sizes(Conv2dProblemVector &conv2d_problem_vector, int batch_size = 1){ -+ -+#if 0 // Resnet50 first layer (layer_id = 0) with channel = 3 is not supported in cutlass -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ [1, 224, 224, 3], // input size (NHWC) -+ [64, 7, 7, 3], // filter size (KRSC) -+ [3, 3, 3, 3], // padding (pad_h, _, pad_w, _) -+ [2, 2], // stride (stride_h, stride_w) -+ [1, 1], // dilation (dilation_h, dilation_w) -+ )); -+#endif -+ -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ {batch_size, 56, 56, 64}, // input size (NHWC) -+ {256, 1, 1, 64}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ {batch_size, 56, 56, 64}, // input size (NHWC) -+ {64, 1, 1, 64}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ {batch_size, 56, 56, 64}, // input size (NHWC) -+ {64, 3, 3, 64}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ {batch_size, 56, 56, 256}, // input size (NHWC) -+ {64, 1, 1, 256}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ {batch_size, 56, 56, 256}, // input size (NHWC) -+ {512, 1, 1, 256}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ {batch_size, 56, 56, 256}, // input size (NHWC) -+ {128, 1, 1, 256}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ {batch_size, 28, 28, 128}, // input size (NHWC) -+ {128, 3, 3, 128}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ {batch_size, 28, 28, 128}, // input size (NHWC) -+ {512, 1, 1, 128}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ {batch_size, 28, 28, 512}, // input size (NHWC) -+ {128, 1, 1, 512}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ {batch_size, 28, 28, 512}, // input size (NHWC) -+ {1024, 1, 1, 512}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ {batch_size, 28, 28, 512}, // input size (NHWC) -+ {256, 1, 1, 512}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ {batch_size, 14, 14, 256}, // input size (NHWC) -+ {256, 3, 3, 256}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ {batch_size, 14, 14, 256}, // input size (NHWC) -+ {1024, 1, 1, 256}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ {batch_size, 14, 14, 1024}, // input size (NHWC) -+ {256, 1, 1, 1024}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ {batch_size, 14, 14, 1024}, // input size (NHWC) -+ {2048, 1, 1, 1024}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ {batch_size, 14, 14, 1024}, // input size (NHWC) -+ {512, 1, 1, 1024}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ {batch_size, 7, 7, 512}, // input size (NHWC) -+ {512, 3, 3, 512}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ {batch_size, 7, 7, 512}, // input size (NHWC) -+ {2048, 1, 1, 512}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( -+ {batch_size, 7, 7, 2048}, // input size (NHWC) -+ {512, 1, 1, 2048}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ } -+ -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////// -+/// Structure TestbedGroupConv2dProblemSizes initializes and holds group conv default and -+/// important network sizes -+//////////////////////////////////////////////////////////////////////////// -+struct TestbedGroupConv2dProblemSizes { -+ -+ // -+ // Data members -+ // -+ int threadblock_n; -+ int threadblock_k; -+ int minimum_channel_size; -+ -+ Conv2dProblemVector default_single_group_sizes; -+ Conv2dProblemVector default_multiple_group_sizes; -+ -+ // -+ // Methods -+ // -+ /// Default ctor -+ TestbedGroupConv2dProblemSizes( -+ int threadblock_n_, -+ int threadblock_k_, -+ int minimum_channel_size_ = 64) -+ : threadblock_n (threadblock_n_), -+ threadblock_k (threadblock_k_), -+ minimum_channel_size (minimum_channel_size_) { -+ initialize_group_conv2d_default_sizes(); -+ filter_all(); -+ } -+ -+ /// Eliminates some illegal cases -+ void filter_all() { -+ -+ Conv2dProblemVector *problems_vectors[] = { -+ &default_single_group_sizes, -+ &default_multiple_group_sizes -+ }; -+ -+ for (Conv2dProblemVector *problems : problems_vectors) { -+ Conv2dProblemVector filtered; -+ -+ for (cutlass::conv::Conv2dProblemSize const & problem : *problems) { -+ if (!((problem.C / problem.groups) % minimum_channel_size)) { -+ filtered.push_back(problem); -+ } -+ } -+ -+ *problems = filtered; -+ } -+ } -+ -+ // Add a few standard convolution problem sizes -+ void initialize_group_conv2d_default_sizes() { -+ -+ //////////////////////////////////////////////////////////////////////////////////// -+ // One group calculated by one or multiple CTAs: k_per_group % CTA::N = 0 -+ // One CTA calculates a single group -+ //////////////////////////////////////////////////////////////////////////////////// -+ -+ for (int cta_per_group_k = 1; cta_per_group_k < 4; ++cta_per_group_k) { -+ // groups = 2, 3, 4 -+ for (int groups = 2; groups < 5; ++groups) { -+ -+ int conv_k = cta_per_group_k * threadblock_n * groups; -+ default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 8, 8, threadblock_k * 2 * groups}, // input size (NHWC) -+ {conv_k, 3, 3, threadblock_k * 2}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, -+ 1, // split_k_slices -+ groups // groups -+ )); -+ -+ } // loop groups -+ } // loop cta_per_group_k -+ -+ // Partial gemm_k: k_per_group == CTA::N && channels_per_group < CTA::K -+ default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 8, 8, threadblock_k}, // input size (NHWC) -+ {threadblock_n * 2, 3, 3, threadblock_k / 2}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, -+ 1, // split_k_slices -+ 2 // groups -+ )); -+ -+ // Larger problem sizes -+ -+ default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 56, 56, 696}, // input size (NHWC) -+ {768, 3, 3, 232}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, -+ 1, // split_k_slices -+ 3 // groups -+ )); -+ default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 14, 14, 1392}, // input size (NHWC) -+ {1536, 3, 3, 232}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, -+ 1, // split_k_slices -+ 3 // groups -+ )); -+ -+ //////////////////////////////////////////////////////////////////////////////////// -+ // One CTA calculate multiple groups: CTA::N % k_per_group = 0 -+ //////////////////////////////////////////////////////////////////////////////////// -+ -+ // 2 groups per CTA -+ default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 8, 8, threadblock_k * 4}, // input size (NHWC) -+ {threadblock_n, 3, 3, threadblock_k * 2}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, -+ 1, // split_k_slices -+ 2 // groups -+ )); -+ -+ // 2 groups per CTA and partial gemm_k -+ default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 8, 8, threadblock_k}, // input size (NHWC) -+ {threadblock_n, 3, 3, threadblock_k / 2}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, -+ 1, // split_k_slices -+ 2 // groups -+ )); -+ -+ // 4 groups per CTA -+ default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 8, 8, threadblock_k * 8}, // input size (NHWC) -+ {threadblock_n / 2, 3, 3, threadblock_k * 2}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, -+ 1, // split_k_slices -+ 4 // groups -+ )); -+ -+ // 4 groups per CTA and partial gemm_k -+ default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 8, 8, threadblock_k * 2}, // input size (NHWC) -+ {threadblock_n / 2, 3, 3, threadblock_k / 2}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, -+ 1, // split_k_slices -+ 4 // groups -+ )); -+ } -+ -+}; -+ -+ -+} // namespace device -+} // namespace conv -+} // namespace test -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..a910d61 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu -@@ -0,0 +1,370 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_dgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Strided Dgrad (Analytic) -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Strided_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic, -+ cutlass::conv::StrideSupport::kStrided -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ -+// run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 8}, // input size (NHWC) -+ {8, 3, 3, 8}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Strided_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x256_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic, -+ cutlass::conv::StrideSupport::kStrided -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Strided_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x256_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic, -+ cutlass::conv::StrideSupport::kStrided -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Strided_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 4, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic, -+ cutlass::conv::StrideSupport::kStrided, -+ 4, -+ 4 -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 12}, // input size (NHWC) -+ {8, 3, 3, 12}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Strided Dgrad (Optimized) -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Strided_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kStrided -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 56, 56, 8}, // input size (NHWC) -+ {8, 1, 1, 8}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 55, 55, 8}, // input size (NHWC) -+ {8, 1, 1, 8}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Strided_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 4, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kStrided, -+ 4, -+ 4 -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 56, 56, 12}, // input size (NHWC) -+ {8, 1, 1, 12}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 55, 55, 12}, // input size (NHWC) -+ {8, 1, 1, 12}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..b607a8a ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu -@@ -0,0 +1,112 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_dgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Strided_Dgrad_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align4, -+ 64x64_32x5_32x32x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::tfloat32_t; -+ using ElementB = cutlass::tfloat32_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 4, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<>, -+ 5, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic, -+ cutlass::conv::StrideSupport::kStrided, -+ 4, -+ 4 -+ >::Kernel; -+ -+ using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 1, 1, 16}, // input size (NHWC) -+ {8, 3, 3, 16}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 1, 1, 16}, // input size (NHWC) -+ {8, 3, 3, 16}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_testbed.h b/3rdparty/cutlass/test/unit/conv/device/conv2d_testbed.h -new file mode 100644 -index 0000000..582b433 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_testbed.h -@@ -0,0 +1,806 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Implicit GEMM testbed -+*/ -+#pragma once -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+#include "cutlass/reduction/device/reduce_split_k.h" -+#include "cutlass/reduction/thread/reduction_operators.h" -+ -+#include "conv2d_problems.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/device/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+ -+#include "cutlass/util/reference/host/convolution.h" -+#include "cutlass/util/reference/device/convolution.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cache_testbed_output.h" -+ -+namespace test { -+namespace conv { -+namespace device { -+ -+template -+class TestbedConv2d { -+public: -+ -+ using ElementA = typename Conv2d::ElementA; -+ using LayoutA = typename Conv2d::LayoutA; -+ using ElementB = typename Conv2d::ElementB; -+ using LayoutB = typename Conv2d::LayoutB; -+ using ElementC = typename Conv2d::ElementC; -+ using LayoutC = typename Conv2d::LayoutC; -+ using ElementAccumulator = typename Conv2d::ElementAccumulator; -+ using ElementCompute = typename Conv2d::ElementCompute; -+ using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; -+ -+ static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; -+ -+ /// Reduction kernel -+ using ReductionOp = cutlass::reduction::thread::ReduceAdd< -+ ElementAccumulator, -+ typename EpilogueOutputOp::ElementAccumulator, -+ EpilogueOutputOp::kCount -+ >; -+ -+ using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< -+ cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, -+ EpilogueOutputOp, -+ ReductionOp -+ >; -+ -+ using ReductionDevice = cutlass::reduction::device::ReduceSplitK; -+ using ReductionStrideIndex = typename ReductionDevice::StrideIndex; -+ -+public: -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D_computed; -+ cutlass::HostTensor tensor_D_reference; -+ -+ int tested_problem_count; -+ -+public: -+ -+ TestbedConv2d( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_), tested_problem_count(0) { -+ -+ } -+ -+ /// Helper to initialize a tensor view -+ template -+ void initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ int scope; -+ int bits = cutlass::sizeof_bits::value; -+ -+ if (bits <= 8) { -+ scope = 2; -+ } -+ else if (bits == 16) { -+ if (cutlass::sizeof_bits::value <= 16) { -+ scope = 3; -+ } -+ else { -+ scope = 5; -+ } -+ } -+ else { -+ scope = 8; -+ } -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope, -scope, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); -+ } -+ else { -+ } -+ } -+ -+ void initialize( -+ cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) { -+ -+ tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); -+ tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); -+ tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ -+ initialize_tensor(tensor_A.host_view(), init_A, seed); -+ initialize_tensor(tensor_B.host_view(), init_B, seed * 17); -+ initialize_tensor(tensor_C.host_view(), init_C, seed * 39); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D_computed.sync_device(); -+ tensor_D_reference.sync_device(); -+ } -+ -+ bool sufficient() const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename Conv2d::UnderlyingKernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::conv::Conv2dProblemSize const &problem_size, -+ cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, -+ ElementCompute alpha = ElementCompute(1), -+ ElementCompute beta = ElementCompute(0)) { -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ -+ // increment tested problem count run by the testbed -+ tested_problem_count++; -+ -+#if 0 // display conv2d problem size for debugging -+ std::cout << problem_size << std::endl -+ << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl -+ << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl -+ << std::endl; -+#endif -+ -+ initialize(problem_size); -+ -+ // configure the operator -+ Conv2d conv2d_op; -+ -+ typename Conv2d::Arguments conv2d_args( -+ problem_size, -+ tensor_A.device_ref(), -+ tensor_B.device_ref(), -+ tensor_C.device_ref(), -+ tensor_D_computed.device_ref(), -+ {alpha, beta}, -+ split_k_mode -+ ); -+ -+ // find workspace requirement for parallel split-k reduction -+ size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); -+ -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = conv2d_op.initialize(conv2d_args, workspace.get()); -+ -+ if (status != cutlass::Status::kSuccess) { -+ cudaError_t error = cudaGetLastError(); -+ std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; -+ return true; -+ } -+ -+ // conv2d operation with parallel split-k-mode -+ if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { -+ -+ // conv2d output is written to workspace in global memory -+ conv2d_args.ref_D.reset(reinterpret_cast(workspace.get())); -+ // accumulate mma for each cta in k-dimension (1.0 * A * B) -+ conv2d_args.output_op = {ElementCompute(1), ElementCompute(0)}; -+ // update conv2d operator arguments -+ status = conv2d_op.update(conv2d_args, workspace.get()); -+ } -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ if (status != cutlass::Status::kSuccess) { -+ return false; -+ } -+ -+ // run conv2d operator -+ status = conv2d_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ if (status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to run." << std::endl; -+ return false; -+ } -+ -+ -+ if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { -+ -+ // configure parallel reduction operator -+ ReductionDevice reduction_op; -+ -+ typename ReductionDevice::Arguments reduction_args( -+ cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(), -+ problem_size.split_k_slices, -+ cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size), -+ { -+ reinterpret_cast (workspace.get()), -+ ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) -+ }, -+ { -+ tensor_D_computed.device_data(), -+ ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) -+ }, -+ { -+ tensor_C.device_data(), -+ ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) -+ }, -+ // apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C -+ {alpha, beta} -+ ); -+ -+ status = reduction_op.initialize(reduction_args, nullptr); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ if (status != cutlass::Status::kSuccess) { -+ return false; -+ } -+ -+ // run prallel reduction kernel -+ status = reduction_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ if (status != cutlass::Status::kSuccess) { -+ return false; -+ } -+ } -+ bool passed = false; -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ EXPECT_EQ(result, cudaSuccess) << " device reference error: " -+ << cudaGetErrorString(result); -+ -+ tensor_D_computed.sync_host(); -+ -+ // -+ // Reference check - support caching results -+ // -+ -+ CachedTestKey cached_test_key = CreateCachedConv2dTestKey< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementAccumulator, -+ ElementCompute -+ >( -+ kConvolutionalOperator, -+ problem_size, -+ alpha, -+ beta, -+ tensor_A.host_view(), -+ tensor_B.host_view(), -+ tensor_C.host_view() -+ ); -+ -+ // -+ // Look for the cached key -+ // -+ -+ bool cached_result_loaded = false; -+ CachedTestResult cached_test_result; -+ -+ std::string conv2d_result_cache_name = -+ std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt"; -+ -+ if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { -+ -+ CachedTestResultListing cached_results(conv2d_result_cache_name); -+ -+ auto cached = cached_results.find(cached_test_key); -+ -+ cached_result_loaded = cached.first; -+ if (cached_result_loaded) { -+ cached_test_result = cached.second; -+ } -+ } -+ -+ if (!cached_result_loaded) { -+ -+#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED -+ -+ cutlass::reference::device::Conv2d< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementCompute, -+ ElementAccumulator -+ >( -+ kConvolutionalOperator, -+ problem_size, -+ tensor_A.device_ref(), -+ tensor_B.device_ref(), -+ tensor_C.device_ref(), -+ tensor_D_reference.device_ref(), -+ alpha, -+ beta); -+ -+ // sync host (copy device data to host) for dumping error output in case of mismatches -+ tensor_D_reference.sync_host(); -+ -+#else -+ -+ cutlass::reference::host::Conv2d< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementCompute, -+ ElementAccumulator -+ >( -+ kConvolutionalOperator, -+ problem_size, -+ tensor_A.host_ref(), -+ tensor_B.host_ref(), -+ tensor_C.host_ref(), -+ tensor_D_reference.host_ref(), -+ alpha, -+ beta); -+ -+#endif -+ -+ if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { -+ -+ cached_test_result.D = TensorHash(tensor_D_reference.host_view()); -+ -+ CachedTestResultListing cached_results(conv2d_result_cache_name); -+ -+ cached_results.append(cached_test_key, cached_test_result); -+ cached_results.write(conv2d_result_cache_name); -+ } -+ } // if (!cached_result_loaded) -+ -+ uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view()); -+ -+ if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { -+ passed = (tensor_D_hash == cached_test_result.D); -+ -+ EXPECT_EQ(tensor_D_hash, cached_test_result.D) -+ << "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n"; -+ } -+ else { -+ -+ passed = cutlass::reference::host::TensorEquals( -+ tensor_D_computed.host_view(), -+ tensor_D_reference.host_view()); -+ } -+ -+ EXPECT_TRUE(passed); -+ -+ std::stringstream ss_problem_size_text; -+ ss_problem_size_text << "nhwc_" -+ << problem_size.N << "x" -+ << problem_size.H << "x" -+ << problem_size.W << "x" -+ << problem_size.C -+ << "_krsc_" -+ << problem_size.K << "x" -+ << problem_size.R << "x" -+ << problem_size.S << "x" -+ << problem_size.C -+ << "_padding_" -+ << problem_size.pad_h << "x" -+ << problem_size.pad_w -+ << "_stride_" -+ << problem_size.stride_h << "x" -+ << problem_size.stride_w -+ << "_dilation_" -+ << problem_size.dilation_h << "x" -+ << problem_size.dilation_w << "_" -+ << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_"); -+ -+ if (!passed) { -+ std::stringstream fname; -+ -+ fname << "error_Conv2d_ImplicitGemm_device_" -+ << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") -+ << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : -+ (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_")) -+ << ss_problem_size_text.str() -+ << Conv2d::ThreadblockShape::kM << "x" -+ << Conv2d::ThreadblockShape::kN << "x" -+ << Conv2d::ThreadblockShape::kK << "_" -+ << Conv2d::WarpShape::kM << "x" -+ << Conv2d::WarpShape::kN << "x" -+ << Conv2d::WarpShape::kK << ".txt"; -+ -+ std::cout << fname.str() << std::endl; -+ -+ std::ofstream results(fname.str()); -+ -+ results << problem_size << std::endl; -+ -+ results -+ << "\nA:\n" << tensor_A.host_view() << "\n" -+ << "\nB:\n" << tensor_B.host_view() << "\n" -+ << "\nC:\n" << tensor_C.host_view() << "\n"; -+ -+ results << "\nD reference (hash: " << cached_test_result.D << ")\n"; -+ -+ if (!cached_result_loaded) { -+ results -+ << tensor_D_reference.host_view() << "\n"; -+ } -+ -+ results -+ << "\nD computed (hash: " << tensor_D_hash << ")\n" -+ << tensor_D_computed.host_view() << "\n"; -+ -+ } -+ -+ return passed; -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+bool TestSpecificConv2d( -+ const Conv2dProblemVector & problem_sizes) { -+ -+ bool passed = true; -+ -+ // -+ // Testbed object -+ // -+ -+ TestbedConv2d testbed; -+ -+ // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) -+ for(auto conv_problem : problem_sizes) { -+ -+ // -+ // Test -+ // -+ -+ // test mode = xcross -+ passed = testbed.run( -+ conv_problem, -+ cutlass::conv::SplitKMode::kSerial); -+ -+ if (!passed) { -+ return false; -+ } -+ -+ // test mode = convolution -+ passed = testbed.run( -+ conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), -+ cutlass::conv::SplitKMode::kSerial); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ -+ return true; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////////////// -+// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference -+// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes -+// Additionaly, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes -+// (conv_blacklist_sizes) -+///////////////////////////////////////////////////////////////////////////////////////////////////////////// -+template -+bool TestAllConv2d( -+ const Conv2dProblemVector & conv_test_sizes = Conv2dProblemVector(), -+ const Conv2dProblemVector & conv_blacklist_sizes = Conv2dProblemVector()) { -+ -+ bool passed = true; -+ -+ // -+ // Testbed object -+ // -+ -+ TestbedConv2d testbed; -+ -+ // -+ // Get conv problem sizes to run conv operator -+ // -+ TestbedConv2dProblemSizes conv_problems(128/cutlass::sizeof_bits::value); -+ -+ // Vector of conv2d problem sizes to avoid duplicate runs -+ Conv2dProblemVector conv_tested_sizes; -+ -+ // Vectors of Conv2dProblemVector (lenient/easiest to rigorous problem sizes) -+ std::vector problem_vectors = { -+ conv_test_sizes, // run user specified sizes -+ conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes -+ //conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes -+#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED -+ conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled -+#endif -+ }; -+ -+ // Flatten 2D problem_vectors into a 1D problem_sizes -+ std::vector problem_sizes; -+ for (auto problem_vector : problem_vectors) { -+ for(auto conv_problem : problem_vector) { -+ problem_sizes.push_back(conv_problem); -+ } -+ } -+ -+ // If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set reverse the order (rigorous to lenient) -+ // run the most rigorous problem size first -+ if (CutlassUnitTestProblemCount()) { -+ std::reverse(problem_sizes.begin(), problem_sizes.end()); -+ } -+ -+ // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) -+ for(auto conv_problem : problem_sizes) { -+ -+ // Skip blacklist and avoid duplicate problem sizes -+ if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || -+ std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { -+ continue; -+ } -+ -+ // -+ // Procedurally disable certain cases -+ // -+ -+ // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} -+ if ((ImplicitGemm::kConvolutionalOperator == -+ cutlass::conv::Operator::kDgrad) && -+ (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == -+ cutlass::conv::StrideSupport::kUnity)) { -+ if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { -+ continue; -+ } -+ } -+ -+ // Fixed channels algorithm requires channel count to match access size -+ if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm == -+ cutlass::conv::IteratorAlgorithm::kFixedChannels) { -+ if (conv_problem.C != ImplicitGemm::UnderlyingKernel::Mma::IteratorA::AccessType::kElements) { -+ continue; -+ } -+ } -+ -+ // Few channels algorithm requires channel count to match access size -+ if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm == -+ cutlass::conv::IteratorAlgorithm::kFewChannels) { -+ if (conv_problem.C % ImplicitGemm::UnderlyingKernel::Mma::IteratorA::AccessType::kElements) { -+ continue; -+ } -+ } -+ -+ // CUTLASS DGRAD's *strided* stride specialization supports all stride {stride_h, stride_w} -+ // Although strided dgrad works for all stride combinations, we are only going -+ // to run strided dgrad for non-unity strides -+ if ((ImplicitGemm::kConvolutionalOperator == -+ cutlass::conv::Operator::kDgrad) && -+ (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == -+ cutlass::conv::StrideSupport::kStrided)) { -+ if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { -+ continue; -+ } -+ } -+ -+ // -+ // Test -+ // -+ // push back tested problem size to avoid re-running duplicates -+ conv_tested_sizes.push_back(conv_problem); -+ -+ // test mode = xcross -+ passed = testbed.run( -+ conv_problem, -+ cutlass::conv::SplitKMode::kSerial); -+ -+ if (!passed) { -+ return false; -+ } -+ -+ // test mode = convolution -+ passed = testbed.run( -+ conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), -+ cutlass::conv::SplitKMode::kSerial); -+ -+ if (!passed) { -+ return false; -+ } -+ -+ // If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set reduce the the number of tested problem counts -+ if (CutlassUnitTestProblemCount() && -+ testbed.tested_problem_count > CutlassUnitTestProblemCount()) { -+ return true; -+ } -+ } -+ -+ // Small-channels convolution can't run here. -+ if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm == -+ cutlass::conv::IteratorAlgorithm::kFixedChannels) { -+ -+ return true; -+ } -+ -+ // Small-channels convolution can't run here. -+ if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm == -+ cutlass::conv::IteratorAlgorithm::kFewChannels) { -+ -+ return true; -+ } -+ -+ // CUTLASS DGRAD's *strided* specialization does not support split-k mode -+ if ((ImplicitGemm::kConvolutionalOperator == -+ cutlass::conv::Operator::kDgrad) && -+ (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == -+ cutlass::conv::StrideSupport::kStrided)) { -+ -+ passed = testbed.run( -+ cutlass::conv::Conv2dProblemSize( -+ {1, 56, 56, 8}, // input size (NHWC) -+ {8, 1, 1, 8}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1}), // dilation (dilation_h, dilation_w) -+ cutlass::conv::SplitKMode::kSerial, -+ cutlass::from_real(2.0), -+ cutlass::from_real(2.0)); -+ -+ if (!passed) { -+ return false; -+ } -+ -+ return passed; -+ } -+ // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for -+ // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters -+ // which are abolutely neccessary to catch functional bugs. The below code does provide option to sweep -+ // alpha and beta for local testing, but only runs one value for alpha and beta. -+ cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size ( -+ {1, 17, 11, 288}, // input size (NHWC) -+ {160, 3, 3, 288}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ ); -+ -+ cutlass::conv::SplitKMode split_k_modes [] = { -+ cutlass::conv::SplitKMode::kSerial, -+ cutlass::conv::SplitKMode::kParallel, -+ }; -+ -+ int split_k_slices[] = { -+ 1, 2, 3, 4, 201 -+ }; -+ -+ double problem_alpha[] = { -+ 2.0 -+ }; -+ -+ double problem_beta[] = { -+ 2.0 -+ }; -+ -+ for (auto split_k_mode : split_k_modes) { -+ for (auto split_k_slice : split_k_slices) { -+ for (auto alpha : problem_alpha) { -+ for (auto beta : problem_beta) { -+ -+ passed = testbed.run( -+ conv2d_split_k_test_size.reset_split_k_slices(split_k_slice), -+ split_k_mode, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta)); -+ -+ if (!passed) { -+ return false; -+ } -+ -+ // If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set reduce the the number of tested problem counts -+ if (CutlassUnitTestProblemCount() && -+ testbed.tested_problem_count > CutlassUnitTestProblemCount()) { -+ return true; -+ } -+ } -+ } -+ } -+ } -+ -+ return passed; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace conv -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_testbed_interleaved.h b/3rdparty/cutlass/test/unit/conv/device/conv2d_testbed_interleaved.h -new file mode 100644 -index 0000000..79f00d1 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_testbed_interleaved.h -@@ -0,0 +1,665 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Implicit GEMM testbed -+*/ -+#pragma once -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+#include "cutlass/reduction/device/reduce_split_k.h" -+#include "cutlass/reduction/thread/reduction_operators.h" -+ -+#include "conv2d_problems.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/device/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/host_reorder.h" -+ -+#include "cutlass/util/reference/host/convolution.h" -+#include "cutlass/util/reference/device/convolution.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cache_testbed_output.h" -+ -+namespace test { -+namespace conv { -+namespace device { -+ -+template -+class InterleavedTestbedConv2d { -+public: -+ -+ using ElementA = typename Conv2d::ElementA; -+ using LayoutA = typename Conv2d::LayoutA; -+ using ElementB = typename Conv2d::ElementB; -+ using LayoutB = typename Conv2d::LayoutB; -+ using ElementC = typename Conv2d::ElementC; -+ using LayoutC = typename Conv2d::LayoutC; -+ using ElementAccumulator = typename Conv2d::ElementAccumulator; -+ using ElementCompute = typename Conv2d::ElementCompute; -+ using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; -+ -+ static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; -+ -+ /// Reduction kernel -+ using ReductionOp = cutlass::reduction::thread::ReduceAdd< -+ ElementAccumulator, -+ typename EpilogueOutputOp::ElementAccumulator, -+ EpilogueOutputOp::kCount -+ >; -+ -+ using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< -+ cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, -+ EpilogueOutputOp, -+ ReductionOp -+ >; -+ -+ using ReductionDevice = cutlass::reduction::device::ReduceSplitK; -+ using ReductionStrideIndex = typename ReductionDevice::StrideIndex; -+ -+public: -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_B_reordered; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D_computed; -+ cutlass::HostTensor tensor_D_reference; -+ -+public: -+ -+ InterleavedTestbedConv2d( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { -+ -+ } -+ -+ /// Helper to initialize a tensor view -+ template -+ void initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ int scope; -+ int bits = cutlass::sizeof_bits::value; -+ -+ if (bits <= 8) { -+ scope = 2; -+ } -+ else if (bits == 16) { -+ scope = 3; -+ } -+ else { -+ scope = 8; -+ } -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope, -scope, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); -+ } -+ else { -+ } -+ } -+ -+ void initialize( -+ cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) { -+ -+ tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); -+ tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); -+ tensor_B_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); -+ tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ -+ initialize_tensor(tensor_A.host_view(), init_A, seed); -+ initialize_tensor(tensor_B.host_view(), init_B, seed * 17); -+ initialize_tensor(tensor_C.host_view(), init_C, seed * 39); -+ -+ cutlass::reorder_convK( -+ tensor_B_reordered.host_ref(), tensor_B.host_ref(), implicit_gemm_problem_size(kConvolutionalOperator, problem_size)); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_B_reordered.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D_computed.sync_device(); -+ tensor_D_reference.sync_device(); -+ } -+ -+ bool sufficient() const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename Conv2d::UnderlyingKernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerMultiprocessor < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::conv::Conv2dProblemSize const &problem_size, -+ cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, -+ ElementCompute alpha = ElementCompute(1), -+ ElementCompute beta = ElementCompute(0)) { -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ -+#if 0 //display conv2d problem size for debugging -+ std::cout << problem_size << std::endl -+ << "alpha, beta: (" << float(alpha) << ", " << float(beta) << ")" << std::endl -+ << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl -+ << std::endl; -+#endif -+ -+ initialize(problem_size); -+ -+ // configure the operator -+ Conv2d conv2d_op; -+ -+ typename Conv2d::Arguments conv2d_args( -+ problem_size, -+ tensor_A.device_ref(), -+ tensor_B_reordered.device_ref(), -+ tensor_C.device_ref(), -+ tensor_D_computed.device_ref(), -+ {alpha, beta}, -+ split_k_mode -+ ); -+ -+ // find workspace requirement for parallel split-k reduction -+ size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); -+ -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = conv2d_op.initialize(conv2d_args, workspace.get()); -+ -+ // conv2d operation with parallel split-k-mode -+ if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { -+ -+ // conv2d output is written to workspace in global memory -+ conv2d_args.ref_D.reset(reinterpret_cast(workspace.get())); -+ // accumulate mma for each cta in k-dimension (1.0 * A * B) -+ conv2d_args.output_op = {ElementCompute(1), ElementCompute(0)}; -+ // update conv2d operator arguments -+ status = conv2d_op.update(conv2d_args, workspace.get()); -+ } -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ if (status != cutlass::Status::kSuccess) { -+ return false; -+ } -+ -+ // run conv2d operator -+ status = conv2d_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ if (status != cutlass::Status::kSuccess) { -+ return false; -+ } -+ -+ if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { -+ -+ // configure parallel reduction operator -+ ReductionDevice reduction_op; -+ -+ typename ReductionDevice::Arguments reduction_args( -+ cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(), -+ problem_size.split_k_slices, -+ cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size), -+ { -+ reinterpret_cast (workspace.get()), -+ ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) -+ }, -+ { -+ tensor_D_computed.device_data(), -+ ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) -+ }, -+ { -+ tensor_C.device_data(), -+ ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) -+ }, -+ // apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C -+ {alpha, beta} -+ ); -+ -+ status = reduction_op.initialize(reduction_args, nullptr); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ if (status != cutlass::Status::kSuccess) { -+ return false; -+ } -+ -+ // run prallel reduction kernel -+ status = reduction_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ if (status != cutlass::Status::kSuccess) { -+ return false; -+ } -+ } -+ bool passed = false; -+ -+ tensor_D_computed.sync_host(); -+ -+ // -+ // Reference check - support caching results -+ // -+ -+ CachedTestKey cached_test_key = CreateCachedConv2dTestKey< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementAccumulator, -+ ElementCompute -+ >( -+ kConvolutionalOperator, -+ problem_size, -+ alpha, -+ beta, -+ tensor_A.host_view(), -+ tensor_B.host_view(), -+ tensor_C.host_view() -+ ); -+ -+ // -+ // Look for the cached key -+ // -+ -+ bool cached_result_loaded = false; -+ CachedTestResult cached_test_result; -+ -+ std::string conv2d_result_cache_name = -+ std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt"; -+ -+ if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { -+ -+ CachedTestResultListing cached_results(conv2d_result_cache_name); -+ -+ auto cached = cached_results.find(cached_test_key); -+ -+ cached_result_loaded = cached.first; -+ if (cached_result_loaded) { -+ cached_test_result = cached.second; -+ } -+ } -+ -+ if (!cached_result_loaded) { -+ -+#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED -+ -+ cutlass::reference::device::Conv2d< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ cutlass::NumericConverterClamp -+ >( -+ kConvolutionalOperator, -+ problem_size, -+ tensor_A.device_ref(), -+ tensor_B.device_ref(), -+ tensor_C.device_ref(), -+ tensor_D_reference.device_ref(), -+ alpha, -+ beta); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ EXPECT_EQ(result, cudaSuccess) << " device reference error: " -+ << cudaGetErrorString(result); -+ -+ // sync host (copy device data to host) for dumping error output in case of mismatches -+ tensor_D_reference.sync_host(); -+ -+#else -+ -+ cutlass::reference::host::Conv2d< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ cutlass::NumericConverterClamp -+ >( -+ kConvolutionalOperator, -+ problem_size, -+ tensor_A.host_ref(), -+ tensor_B.host_ref(), -+ tensor_C.host_ref(), -+ tensor_D_reference.host_ref(), -+ alpha, -+ beta); -+ -+#endif -+ -+ if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { -+ -+ cached_test_result.D = TensorHash(tensor_D_reference.host_view()); -+ -+ CachedTestResultListing cached_results(conv2d_result_cache_name); -+ -+ cached_results.append(cached_test_key, cached_test_result); -+ cached_results.write(conv2d_result_cache_name); -+ } -+ } // if (!cached_result_loaded) -+ -+ uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view()); -+ -+ if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { -+ passed = (tensor_D_hash == cached_test_result.D); -+ -+ EXPECT_EQ(tensor_D_hash, cached_test_result.D) -+ << "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n"; -+ } -+ else { -+ -+ passed = cutlass::reference::host::TensorEquals( -+ tensor_D_computed.host_view(), -+ tensor_D_reference.host_view()); -+ } -+ -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ std::stringstream fname; -+ -+ fname << "error_Conv2d_ImplicitGemm_device_" -+ << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") -+ << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : -+ (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_")) -+ << "ncxhwx_" -+ << problem_size.N << "x" -+ << problem_size.H << "x" -+ << problem_size.W << "x" -+ << problem_size.C -+ << "_cxrskx_" -+ << problem_size.K << "x" -+ << problem_size.R << "x" -+ << problem_size.S << "x" -+ << problem_size.C -+ << "_padding_" -+ << problem_size.pad_h << "x" -+ << problem_size.pad_w -+ << "_stride_" -+ << problem_size.stride_h << "x" -+ << problem_size.stride_w -+ << "_dilation_" -+ << problem_size.dilation_h << "x" -+ << problem_size.dilation_w << "_" -+ << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_") -+ << Conv2d::ThreadblockShape::kM << "x" -+ << Conv2d::ThreadblockShape::kN << "x" -+ << Conv2d::ThreadblockShape::kK << "_" -+ << Conv2d::WarpShape::kM << "x" -+ << Conv2d::WarpShape::kN << "x" -+ << Conv2d::WarpShape::kK << ".txt"; -+ -+ std::cout << fname.str() << std::endl; -+ -+ std::ofstream results(fname.str()); -+ -+ results << problem_size << std::endl; -+ -+ results -+ << "\nA:\n" << tensor_A.host_view() << "\n" -+ << "\nB:\n" << tensor_B.host_view() << "\n" -+ << "\nC:\n" << tensor_C.host_view() << "\n"; -+ -+ results << "\nD reference (hash: " << cached_test_result.D << ")\n"; -+ -+ if (!cached_result_loaded) { -+ results -+ << tensor_D_reference.host_view() << "\n"; -+ } -+ -+ results -+ << "\nD computed (hash: " << tensor_D_hash << ")\n" -+ << tensor_D_computed.host_view() << "\n"; -+ -+ } -+ -+ return passed; -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////////////// -+// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference -+// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes -+// Additionaly, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes -+// (conv_blacklist_sizes) -+///////////////////////////////////////////////////////////////////////////////////////////////////////////// -+template -+bool TestAllInterleavedConv2d( -+ const Conv2dProblemVector & conv_test_sizes = Conv2dProblemVector(), -+ const Conv2dProblemVector & conv_blacklist_sizes = Conv2dProblemVector()) { -+ -+ bool passed = true; -+ -+ // -+ // Testbed object -+ // -+ -+ InterleavedTestbedConv2d testbed; -+ -+ // -+ // Get conv problem sizes to run conv operator -+ // -+ TestbedConv2dProblemSizes conv_problems(InterleavedK); // minimum channel size must be multiple of InterleavedK for interleaved layout -+ -+ // Vector of conv2d problem sizes to avoid duplicate runs -+ Conv2dProblemVector conv_tested_sizes; -+ -+ Conv2dProblemVector const *problem_vectors[] = { -+ &conv_test_sizes, // run user specified sizes -+ &conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes -+ &conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes -+#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED -+ &conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled -+#endif -+ }; -+ -+ // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) -+ for (Conv2dProblemVector const * problem_vector : problem_vectors) { -+ -+ ChannelDivisibilitySpecification channel_spec(InterleavedK); //input and output channels must be multiple of InterleavedK -+ auto pruned_problem_vector = prune(*problem_vector, channel_spec); -+ -+ // Run conv testbed on default convolution sizes -+ for(auto conv_problem : pruned_problem_vector) { -+ -+ // Skip blacklist and avoid duplicate problem sizes -+ if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || -+ std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { -+ continue; -+ } -+ -+ // -+ // Procedurally disable certain cases -+ // -+ -+ // CUTLASS DGRAD's unity stride specialization only support stride {1, 1} -+ if ((ImplicitGemm::kConvolutionalOperator == -+ cutlass::conv::Operator::kDgrad) && -+ (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == -+ cutlass::conv::StrideSupport::kUnity)) { -+ if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { -+ continue; -+ } -+ } -+ -+ // -+ // Test -+ // -+ // push back tested problem size to avoid re-running duplicates -+ conv_tested_sizes.push_back(conv_problem); -+ -+ // test mode = xcross -+ passed = testbed.run( -+ conv_problem, -+ cutlass::conv::SplitKMode::kSerial); -+ -+ if (!passed) { -+ return false; -+ } -+ -+ // test mode = convolution -+ passed = testbed.run( -+ conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), -+ cutlass::conv::SplitKMode::kSerial); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ -+#if 0 -+ // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for -+ // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters -+ // which are abolutely neccessary to catch functional bugs. The below code does provide option to sweep -+ // alpha and beta for local testing, but only runs one value for alpha and beta. -+ cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size ( -+ {1, 17, 11, 288}, // input size (NHWC) -+ {160, 3, 3, 288}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ ); -+ -+ cutlass::conv::SplitKMode split_k_modes [] = { -+ cutlass::conv::SplitKMode::kSerial, -+ cutlass::conv::SplitKMode::kParallel, -+ }; -+ -+ int split_k_slices[] = { -+ 1, 2, 3, 4, 201 -+ }; -+ -+ double problem_alpha[] = { -+ 2.0 -+ }; -+ -+ double problem_beta[] = { -+ 2.0 -+ }; -+ -+ for (auto split_k_mode : split_k_modes) { -+ for (auto split_k_slice : split_k_slices) { -+ for (auto alpha : problem_alpha) { -+ for (auto beta : problem_beta) { -+ -+ passed = testbed.run( -+ conv2d_split_k_test_size.reset_split_k_slices(split_k_slice), -+ split_k_mode, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta)); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+#endif -+ -+ return passed; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace conv -+} // namespace test -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu -new file mode 100644 -index 0000000..4fbdf98 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu -@@ -0,0 +1,133 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_wgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM50_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, -+ 64x64_8x2_32x32x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::complex; -+ using ElementB = cutlass::complex; -+ using ElementC = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ using ElementCompute = cutlass::complex; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM50_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, -+ 32x64_8x2_32x64x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::complex; -+ using ElementB = cutlass::complex; -+ using ElementC = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ using ElementCompute = cutlass::complex; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu -new file mode 100644 -index 0000000..c8d6bde ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu -@@ -0,0 +1,138 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_wgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, -+ 128x128_8x4_32x64x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::complex; -+ using ElementB = cutlass::complex; -+ using ElementC = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ using ElementCompute = cutlass::complex; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, -+ 128x128_8x4_64x32x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::complex; -+ using ElementB = cutlass::complex; -+ using ElementC = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ using ElementCompute = cutlass::complex; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu -new file mode 100644 -index 0000000..8932187 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu -@@ -0,0 +1,128 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_wgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ /// Device-level Conv2d instance -+ using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+ -+TEST(SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, -+ 128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ /// Device-level Conv2d instance -+ using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..23c749a ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm70.cu -@@ -0,0 +1,84 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_wgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) -+ -+TEST(SM70_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+ >::Kernel; -+ -+ using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM70_SUPPORTED -+ -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu -new file mode 100644 -index 0000000..a07c9b4 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm75.cu -@@ -0,0 +1,195 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_wgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+ >::Kernel; -+ -+ using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 4, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic, -+ cutlass::conv::StrideSupport::kStrided, -+ 4, -+ 4 -+ >::Kernel; -+ -+ using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 12}, // input size (NHWC) -+ {8, 3, 3, 12}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 4, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kStrided, -+ 4, -+ 4 -+ >::Kernel; -+ -+ using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 12}, // input size (NHWC) -+ {8, 3, 3, 12}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED -+ -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..9c81b48 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu -@@ -0,0 +1,274 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_wgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd -+ >::Kernel; -+ -+ using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, -+ 64x256_32x4_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32 >, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 4, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic, -+ cutlass::conv::StrideSupport::kStrided, -+ 4, -+ 4 -+ >::Kernel; -+ -+ using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 12}, // input size (NHWC) -+ {8, 3, 3, 12}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 4, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kStrided, -+ 4, -+ 4 -+ >::Kernel; -+ -+ using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 4, 4, 12}, // input size (NHWC) -+ {8, 3, 3, 12}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {3, 3}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu -new file mode 100644 -index 0000000..3c6cbf4 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu -@@ -0,0 +1,137 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_wgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, -+ 128x128_8x4_64x32x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = float; -+ using ElementB = float; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, -+ 128x128_8x4_64x32x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = float; -+ using ElementB = float; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ -+ /// Device-level Conv2d instance -+ using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..991e1e5 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu -@@ -0,0 +1,143 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv2d_wgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::tfloat32_t; -+ using ElementB = cutlass::tfloat32_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd -+ >::Kernel; -+ -+ using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align1, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::tfloat32_t; -+ using ElementB = cutlass::tfloat32_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 32 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kUnity, -+ 1, -+ 1 -+ >::Kernel; -+ -+ using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ test::conv::device::Conv2dProblemVector problem_size_list; -+ -+ // run specific problem size in the unit test first -+ problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 8, 8, 1}, // input size (NHWC) -+ {1, 3, 3, 1}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ )); -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_with_broadcast_testbed.h b/3rdparty/cutlass/test/unit/conv/device/conv2d_with_broadcast_testbed.h -new file mode 100644 -index 0000000..117fef0 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_with_broadcast_testbed.h -@@ -0,0 +1,686 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Implicit GEMM for fused epilogue broadcast testbed -+ -+ Parallel split-k is not tested because we can just use regular conv kernel -+ when we need to use parallel-splitk. Broadcast can happen in the reduction -+ kernel. -+*/ -+#pragma once -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+#include "cutlass/reduction/device/reduce_split_k.h" -+#include "cutlass/reduction/thread/reduction_operators.h" -+ -+#include "conv2d_problems.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/device/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+ -+#include "cutlass/util/reference/host/convolution.h" -+#include "cutlass/util/reference/device/convolution.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cache_testbed_output.h" -+ -+namespace test { -+namespace conv { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct Conv2dWithBroadcastReferenceOp { -+ -+ using OutputOp = typename Conv2d::EpilogueOutputOp; -+ -+ using ElementCompute = typename OutputOp::ElementCompute; -+ using ElementZ = typename OutputOp::ElementZ; -+ using ElementT = typename OutputOp::ElementT; -+ -+ typename OutputOp::BinaryOp binary_op; -+ typename OutputOp::ElementwiseOp elementwise_op; -+ -+ Conv2dWithBroadcastReferenceOp() { } -+ -+ void operator()(ElementZ &Z, ElementT &T, ElementCompute conv2d, ElementCompute bias) { -+ ElementCompute t_full = binary_op(conv2d, bias); -+ T = ElementT(t_full); -+ -+ ElementCompute z_full = elementwise_op(t_full); -+ Z = ElementZ(z_full); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Fused testbed -+// -+// Y = CONV(AB, C) -+// -+// T[n, p, q, k] = ReductionOp(Y[n, p, q, k], Broadcast[k]) -+// -+// Z[n, p, q, k] = Elementwise(T[n, p, q, k]) -+// -+ -+template < -+ typename Conv2d, -+ typename ReferenceOp, -+ bool AddBroadcastFirst = false -+> -+class TestbedConv2dWithBroadcast { -+public: -+ -+ using ElementA = typename Conv2d::ElementA; -+ using LayoutA = typename Conv2d::LayoutA; -+ using ElementB = typename Conv2d::ElementB; -+ using LayoutB = typename Conv2d::LayoutB; -+ using ElementC = typename Conv2d::ElementC; -+ using LayoutC = typename Conv2d::LayoutC; -+ using ElementAccumulator = typename Conv2d::ElementAccumulator; -+ using ElementCompute = typename Conv2d::ElementCompute; -+ using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; -+ using ElementZ = typename EpilogueOutputOp::ElementZ; -+ using ElementT = typename EpilogueOutputOp::ElementT; -+ -+ static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; -+ static const bool kAddBroadcastFirst = AddBroadcastFirst; -+ static const bool kStoreT = EpilogueOutputOp::kStoreT; -+ -+public: -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_C_reference; -+ cutlass::HostTensor tensor_Z_computed; -+ cutlass::HostTensor tensor_Z_reference; -+ cutlass::HostTensor tensor_T_computed; -+ cutlass::HostTensor tensor_T_reference; -+ cutlass::HostTensor tensor_Y_reference; -+ cutlass::HostTensor tensor_Broadcast; // Input Broadcast -+ -+public: -+ -+ TestbedConv2dWithBroadcast( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { -+ -+ } -+ -+ /// Helper to initialize a tensor view -+ template -+ void initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ int scope; -+ int bits = cutlass::sizeof_bits::value; -+ -+ if (bits <= 8) { -+ scope = 2; -+ } -+ else if (bits == 16) { -+ if (cutlass::sizeof_bits::value <= 16) { -+ scope = 3; -+ } -+ else { -+ scope = 5; -+ } -+ } -+ else { -+ scope = 8; -+ } -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope, -scope, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); -+ } -+ else { -+ } -+ } -+ -+ void initialize( -+ cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) { -+ -+ tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); -+ tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); -+ tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ tensor_C_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ tensor_Z_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ tensor_Z_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ tensor_T_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ tensor_T_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ tensor_Y_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ tensor_Broadcast.resize({ -+ 1, -+ 1, -+ 1, -+ implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c(), -+ }); -+ -+ initialize_tensor(tensor_A.host_view(), init_A, seed); -+ initialize_tensor(tensor_B.host_view(), init_B, seed * 17); -+ initialize_tensor(tensor_C.host_view(), init_C, seed * 39); -+ initialize_tensor(tensor_Broadcast.host_view(), init_C, seed * 39); -+ -+ for (int n = 0; n < tensor_C_reference.extent().n(); ++n) { -+ for (int p = 0; p < tensor_C_reference.extent().h(); ++p) { -+ for (int q = 0; q < tensor_C_reference.extent().w(); ++q) { -+ for (int k = 0; k < tensor_C_reference.extent().c(); ++k) { -+ tensor_C_reference.at({n, p, q, k}) = ElementAccumulator(tensor_C.at({n, p, q, k})); -+ } -+ } -+ } -+ } -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_Broadcast.sync_device(); -+ tensor_C_reference.sync_device(); -+ tensor_Z_computed.sync_device(); -+ tensor_Z_reference.sync_device(); -+ tensor_T_computed.sync_device(); -+ tensor_T_reference.sync_device(); -+ tensor_Y_reference.sync_device(); -+ } -+ -+ bool sufficient() const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename Conv2d::UnderlyingKernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::conv::Conv2dProblemSize const &problem_size, -+ cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, -+ ElementCompute alpha = ElementCompute(1), -+ ElementCompute beta = ElementCompute(1)) { -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ -+#if 0 //display conv2d problem size for debugging -+ std::cout << problem_size << std::endl -+ << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl -+ << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl -+ << std::endl; -+#endif -+ -+ initialize(problem_size); -+ -+ // configure the operator -+ Conv2d conv2d_op; -+ typename Conv2d::Arguments conv2d_args( -+ problem_size, -+ tensor_A.device_ref(), -+ tensor_B.device_ref(), -+ tensor_C.device_ref(), -+ tensor_Z_computed.device_ref(), -+ {alpha, beta}, -+ split_k_mode, -+ tensor_Broadcast.device_data(), -+ kStoreT ? tensor_T_computed.device_data() : nullptr, -+ 0, // This must be zero -+ implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c() -+ ); -+ -+ // initialize the kernel -+ size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); -+ -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = conv2d_op.initialize(conv2d_args, workspace.get()); -+ -+ if (status != cutlass::Status::kSuccess) { -+ cudaError_t error = cudaGetLastError(); -+ std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; -+ return true; -+ } -+ -+ // run conv2d operator -+ status = conv2d_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ if (status != cutlass::Status::kSuccess) { -+ return false; -+ } -+ -+ bool passed = false; -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ EXPECT_EQ(result, cudaSuccess) << " device reference error: " -+ << cudaGetErrorString(result); -+ -+ tensor_T_computed.sync_host(); -+ tensor_Z_computed.sync_host(); -+ -+ // -+ // Reference check -+ // -+ -+ // When kAddBroadcastFirst is true, add bias on the host -+ ElementCompute beta_ref = kAddBroadcastFirst ? ElementCompute(0) : beta; -+ -+#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED -+ -+ cutlass::reference::device::Conv2d< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementAccumulator, -+ LayoutC, -+ ElementAccumulator, -+ ElementAccumulator -+ >( -+ kConvolutionalOperator, -+ problem_size, -+ tensor_A.device_ref(), -+ tensor_B.device_ref(), -+ tensor_C_reference.device_ref(), -+ tensor_Y_reference.device_ref(), -+ alpha, -+ beta_ref); -+ -+ // sync host (copy device data to host) for dumping error output in case of mismatches -+ tensor_Y_reference.sync_host(); -+ -+#else -+ -+ cutlass::reference::host::Conv2d< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementAccumulator, -+ LayoutC, -+ ElementAccumulator, -+ ElementAccumulator -+ >( -+ kConvolutionalOperator, -+ problem_size, -+ tensor_A.host_ref(), -+ tensor_B.host_ref(), -+ tensor_C_reference.host_ref(), -+ tensor_Y_reference.host_ref(), -+ alpha, -+ beta_ref); -+ -+#endif -+ ReferenceOp reference_op; -+ -+ // compute tensor Z and tensor T -+ for (int n = 0; n < problem_size.N; ++n) { -+ for (int p = 0; p < problem_size.P; ++p) { -+ for (int q = 0; q < problem_size.Q; ++q) { -+ for (int k = 0; k < problem_size.K; ++k) { -+ -+ ElementZ z; -+ ElementT t; -+ -+ ElementCompute accum = tensor_Y_reference.at({n, p, q, k}); -+ ElementCompute bias = ElementCompute(tensor_Broadcast.at({0, 0, 0, k})); -+ -+ -+ if (kAddBroadcastFirst) { -+ reference_op(z, t, accum + bias, -+ beta * ElementCompute(tensor_C_reference.at({n, p, q, k}))); -+ } else { -+ reference_op(z, t, accum, bias); -+ } -+ -+ tensor_Z_reference.at({n, p, q, k}) = z; -+ tensor_T_reference.at({n, p, q, k}) = t; -+ } -+ } -+ } -+ } -+ -+ if (kStoreT) { -+ passed = cutlass::reference::host::TensorEquals( -+ tensor_T_computed.host_view(), -+ tensor_T_reference.host_view()); -+ -+ EXPECT_TRUE(passed); -+ } -+ -+ passed = cutlass::reference::host::TensorEquals( -+ tensor_Z_computed.host_view(), -+ tensor_Z_reference.host_view()); -+ -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ std::stringstream fname; -+ -+ fname << "error_Conv2d_ImplicitGemm_device_" -+ << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") -+ << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : -+ (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_")) -+ << "nhwc_" -+ << problem_size.N << "x" -+ << problem_size.H << "x" -+ << problem_size.W << "x" -+ << problem_size.C -+ << "_krsc_" -+ << problem_size.K << "x" -+ << problem_size.R << "x" -+ << problem_size.S << "x" -+ << problem_size.C -+ << "_padding_" -+ << problem_size.pad_h << "x" -+ << problem_size.pad_w -+ << "_stride_" -+ << problem_size.stride_h << "x" -+ << problem_size.stride_w -+ << "_dilation_" -+ << problem_size.dilation_h << "x" -+ << problem_size.dilation_w << "_" -+ << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_") -+ << Conv2d::ThreadblockShape::kM << "x" -+ << Conv2d::ThreadblockShape::kN << "x" -+ << Conv2d::ThreadblockShape::kK << "_" -+ << Conv2d::WarpShape::kM << "x" -+ << Conv2d::WarpShape::kN << "x" -+ << Conv2d::WarpShape::kK << ".txt"; -+ -+ std::cout << fname.str() << std::endl; -+ -+ std::ofstream results(fname.str()); -+ -+ results << problem_size << std::endl; -+ -+ results -+ << "\nA:\n" << tensor_A.host_view() << "\n" -+ << "\nB:\n" << tensor_B.host_view() << "\n" -+ << "\nC:\n" << tensor_C.host_view() << "\n" -+ << "\nBroadcast:\n" << tensor_Broadcast.host_view() << "\n" -+ << "\nY reference:\n" << tensor_Y_reference.host_view() << "\n" -+ << "\nT reference:\n" << tensor_T_reference.host_view() << "\n" -+ << "\nT computed:\n" << tensor_T_computed.host_view() << "\n" -+ << "\nZ reference:\n" << tensor_Z_reference.host_view() << "\n" -+ << "\nZ computed:\n" << tensor_Z_computed.host_view() << "\n"; -+ } -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////////////// -+// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference -+// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes -+// Additionaly, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes -+// (conv_blacklist_sizes) -+///////////////////////////////////////////////////////////////////////////////////////////////////////////// -+template , -+ bool AddBroadcastFirst = false, -+ bool TestSplitK = true -+> -+bool TestAllConv2dWithBroadcast( -+ const Conv2dProblemVector &conv_test_sizes = Conv2dProblemVector(), -+ const Conv2dProblemVector &conv_blacklist_sizes = Conv2dProblemVector()) { -+ -+ bool passed = true; -+ -+ // -+ // Testbed object -+ // -+ -+ TestbedConv2dWithBroadcast testbed; -+ -+ // -+ // Get conv problem sizes to run conv operator -+ // -+ TestbedConv2dProblemSizes conv_problems(128/cutlass::sizeof_bits::value); -+ -+ // Vector of conv2d problem sizes to avoid duplicate runs -+ Conv2dProblemVector conv_tested_sizes; -+ -+ Conv2dProblemVector const *problem_vectors[] = { -+ &conv_test_sizes, // run user specified sizes -+ &conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes -+ &conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes -+#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED -+ &conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled -+#endif -+ }; -+ -+ // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) -+ for (Conv2dProblemVector const * problem_vector : problem_vectors) { -+ -+ // Run conv testbed on default convolution sizes -+ for(auto conv_problem : *problem_vector) { -+ -+ // Skip blacklist and avoid duplicate problem sizes -+ if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || -+ std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { -+ continue; -+ } -+ -+ // -+ // Procedurally disable certain cases -+ // -+ -+ // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} -+ if ((ImplicitGemm::kConvolutionalOperator == -+ cutlass::conv::Operator::kDgrad) && -+ (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == -+ cutlass::conv::StrideSupport::kUnity)) { -+ if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { -+ continue; -+ } -+ } -+ -+#if 0 // relax restrictions on analytic strided dgrad -+ // CUTLASS DGRAD's *strided* specialization only support stride >= {2, 2} -+ if ((ImplicitGemm::kConvolutionalOperator == -+ cutlass::conv::Operator::kDgrad) && -+ (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == -+ cutlass::conv::StrideSupport::kStrided)) { -+ if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { -+ continue; -+ } -+ } -+#endif -+ -+ // -+ // Test -+ // -+ // push back tested problem size to avoid re-running duplicates -+ conv_tested_sizes.push_back(conv_problem); -+ -+ // test mode = xcross -+ passed = testbed.run( -+ conv_problem, -+ cutlass::conv::SplitKMode::kSerial); -+ -+ if (!passed) { -+ return false; -+ } -+ -+ // test mode = convolution -+ passed = testbed.run( -+ conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), -+ cutlass::conv::SplitKMode::kSerial); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ -+ // CUTLASS DGRAD's *strided* specialization does not support split-k mode -+ if ((ImplicitGemm::kConvolutionalOperator == -+ cutlass::conv::Operator::kDgrad) && -+ (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == -+ cutlass::conv::StrideSupport::kStrided)) { -+ -+ passed = testbed.run( -+ cutlass::conv::Conv2dProblemSize( -+ {1, 56, 56, 8}, // input size (NHWC) -+ {8, 1, 1, 8}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1}), // dilation (dilation_h, dilation_w) -+ cutlass::conv::SplitKMode::kSerial, -+ cutlass::from_real(2.0), -+ cutlass::from_real(2.0)); -+ -+ if (!passed) { -+ return false; -+ } -+ -+ return passed; -+ } -+ -+ if (!TestSplitK) -+ return passed; -+ -+ // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for -+ // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters -+ // which are abolutely neccessary to catch functional bugs. The below code does provide option to sweep -+ // alpha and beta for local testing, but only runs one value for alpha and beta. -+ cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size ( -+ {1, 17, 11, 288}, // input size (NHWC) -+ {160, 3, 3, 288}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ ); -+ -+ cutlass::conv::SplitKMode split_k_modes [] = { -+ cutlass::conv::SplitKMode::kSerial -+ }; -+ -+ int split_k_slices[] = { -+ 1, 2, 3, 4, 201 -+ }; -+ -+ double problem_alpha[] = { -+ 2.0 -+ }; -+ -+ double problem_beta[] = { -+ 2.0 -+ }; -+ -+ for (auto split_k_mode : split_k_modes) { -+ for (auto split_k_slice : split_k_slices) { -+ for (auto alpha : problem_alpha) { -+ for (auto beta : problem_beta) { -+ -+ passed = testbed.run( -+ conv2d_split_k_test_size.reset_split_k_slices(split_k_slice), -+ split_k_mode, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta)); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ -+ return passed; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace conv -+} // namespace test -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv2d_with_reduction_testbed.h b/3rdparty/cutlass/test/unit/conv/device/conv2d_with_reduction_testbed.h -new file mode 100644 -index 0000000..4064648 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv2d_with_reduction_testbed.h -@@ -0,0 +1,643 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Implicit GEMM testbed -+*/ -+#pragma once -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+#include "cutlass/reduction/device/tensor_reduce.h" -+#include "cutlass/reduction/device/reduce_split_k.h" -+#include "cutlass/reduction/thread/reduction_operators.h" -+ -+#include "conv2d_problems.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/device/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+ -+#include "cutlass/util/reference/host/convolution.h" -+#include "cutlass/util/reference/device/convolution.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cache_testbed_output.h" -+ -+namespace test { -+namespace conv { -+namespace device { -+ -+template -+class TestbedConv2dWithReduction { -+public: -+ -+ using ElementA = typename Conv2d::ElementA; -+ using LayoutA = typename Conv2d::LayoutA; -+ using ElementB = typename Conv2d::ElementB; -+ using LayoutB = typename Conv2d::LayoutB; -+ using ElementC = typename Conv2d::ElementC; -+ using LayoutC = typename Conv2d::LayoutC; -+ using ElementAccumulator = typename Conv2d::ElementAccumulator; -+ using ElementCompute = typename Conv2d::ElementCompute; -+ using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; -+ using ElementT = typename EpilogueOutputOp::ElementTensor; -+ -+ static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; -+ -+public: -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ -+ cutlass::HostTensor tensor_Reduction; -+ cutlass::HostTensor tensor_Tensor; -+ cutlass::HostTensor tensor_Final_Reduction; -+ -+ cutlass::HostTensor tensor_D_computed; -+ cutlass::HostTensor tensor_D_reference; -+ -+public: -+ -+ TestbedConv2dWithReduction( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { -+ -+ } -+ -+ /// Helper to initialize a tensor view -+ template -+ void initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ int scope = 2; -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope, -scope, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); -+ } -+ else { -+ } -+ } -+ -+ void initialize( -+ cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) { -+ -+ tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); -+ tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); -+ tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ -+ tensor_Reduction.resize({ -+ 1, -+ 1, -+ (problem_size.N * problem_size.P * problem_size.Q - 1 + Conv2d::ThreadblockShape::kM) / Conv2d::ThreadblockShape::kM, -+ (problem_size.K) -+ }); -+ -+ tensor_Final_Reduction.resize({ -+ 1, -+ 1, -+ 1, -+ (problem_size.K) -+ }); -+ -+ tensor_Tensor.resize({(problem_size.N * problem_size.P * problem_size.Q), problem_size.K}); -+ -+ tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ -+ initialize_tensor(tensor_A.host_view(), init_A, seed); -+ initialize_tensor(tensor_B.host_view(), init_B, seed * 17); -+ initialize_tensor(tensor_C.host_view(), init_C, seed * 39); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D_computed.sync_device(); -+ tensor_D_reference.sync_device(); -+ } -+ -+ bool sufficient() const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename Conv2d::UnderlyingKernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::conv::Conv2dProblemSize const &problem_size, -+ cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, -+ ElementCompute alpha = ElementCompute(1), -+ ElementCompute beta = ElementCompute(0)) { -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ -+#if 0 //display conv2d problem size for debugging -+ std::cout << problem_size << std::endl -+ << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl -+ << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl -+ << std::endl; -+#endif -+ -+ initialize(problem_size); -+ -+ // configure the operator -+ Conv2d conv2d_op; -+ -+ typename Conv2d::Arguments conv2d_args( -+ problem_size, -+ tensor_A.device_ref(), -+ tensor_B.device_ref(), -+ tensor_C.device_ref(), -+ tensor_D_computed.device_ref(), -+ {alpha, beta}, -+ split_k_mode, -+ tensor_Reduction.device_data(), -+ tensor_Tensor.device_data(), -+ static_cast(tensor_Reduction.stride()[0]), -+ static_cast(tensor_Tensor.stride()[0]) -+ ); -+ -+ // find workspace requirement for parallel split-k reduction -+ size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); -+ -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = conv2d_op.initialize(conv2d_args, workspace.get()); -+ -+ if (status != cutlass::Status::kSuccess) { -+ cudaError_t error = cudaGetLastError(); -+ std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; -+ return true; -+ } -+ -+ // conv2d operation with parallel split-k-mode -+ if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { -+ -+ // conv2d output is written to workspace in global memory -+ conv2d_args.ref_D.reset(reinterpret_cast(workspace.get())); -+ // accumulate mma for each cta in k-dimension (1.0 * A * B) -+ conv2d_args.output_op = {ElementCompute(1), ElementCompute(0)}; -+ // update conv2d operator arguments -+ status = conv2d_op.update(conv2d_args, workspace.get()); -+ } -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ if (status != cutlass::Status::kSuccess) { -+ return false; -+ } -+ -+ // run conv2d operator -+ status = conv2d_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ if (status != cutlass::Status::kSuccess) { -+ return false; -+ } -+ -+ bool passed = false; -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ EXPECT_EQ(result, cudaSuccess) << " device reference error: " -+ << cudaGetErrorString(result); -+ -+ // Final reduction over the partial reduction tensor -+ using Functor = cutlass::plus; -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementAccumulator, -+ ElementAccumulator, -+ LayoutC, -+ Functor, -+ 8, -+ ElementAccumulator -+ >; -+ -+ TensorReduction reduction(tensor_Reduction.extent(), 2); -+ -+ cutlass::DeviceAllocation reduction_device_workspace(reduction.workspace_size()); -+ -+ status = reduction.reduce( -+ tensor_Final_Reduction.device_ref(), -+ tensor_Reduction.device_ref(), -+ reduction_device_workspace.get(), -+ ElementAccumulator()); -+ -+ EXPECT_EQ(status, cutlass::Status::kSuccess); -+ EXPECT_EQ(cudaDeviceSynchronize(), cudaSuccess); -+ -+ // -+ // Reference check -+ // -+ -+ tensor_D_computed.sync_host(); -+ -+#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED -+ -+ cutlass::reference::device::Conv2d< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementCompute, -+ ElementAccumulator -+ >( -+ kConvolutionalOperator, -+ problem_size, -+ tensor_A.device_ref(), -+ tensor_B.device_ref(), -+ tensor_C.device_ref(), -+ tensor_D_reference.device_ref(), -+ alpha, -+ beta); -+ -+ // sync host (copy device data to host) for dumping error output in case of mismatches -+ tensor_D_reference.sync_host(); -+ -+#else -+ -+ cutlass::reference::host::Conv2d< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementCompute, -+ ElementAccumulator -+ >( -+ kConvolutionalOperator, -+ problem_size, -+ tensor_A.host_ref(), -+ tensor_B.host_ref(), -+ tensor_C.host_ref(), -+ tensor_D_reference.host_ref(), -+ alpha, -+ beta); -+ -+#endif -+ -+ passed = cutlass::reference::host::TensorEquals( -+ tensor_D_computed.host_view(), -+ tensor_D_reference.host_view()); -+ -+ EXPECT_TRUE(passed); -+ -+ // -+ // Reference check on reduction results -+ // -+ -+ tensor_Reduction.sync_host(); -+ tensor_Final_Reduction.sync_host(); -+ -+ // compute backwards for reduction results -+ cutlass::HostTensor reference_Reduction; -+ reference_Reduction.resize({ -+ 1, -+ 1, -+ 1, -+ (problem_size.K) -+ }); -+ -+ for (int k = 0; k < problem_size.K; ++k) { -+ ElementAccumulator reduced_value = ElementAccumulator(); -+ for (int n = 0; n < problem_size.N; ++n) { -+ for (int p = 0; p < problem_size.P; ++p) { -+ for (int q = 0; q < problem_size.Q; ++q) { -+ reduced_value += tensor_D_reference.at({n, p, q, k}); -+ } -+ } -+ } -+ reference_Reduction.at({0, 0, 0, k}) = reduced_value; -+ } -+ -+ passed = cutlass::reference::host::TensorEquals( -+ tensor_Final_Reduction.host_view(), -+ reference_Reduction.host_view() -+ ); -+ -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ std::stringstream fname; -+ -+ fname << "error_Conv2d_ImplicitGemm_device_" -+ << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") -+ << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : -+ (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_")) -+ << "nhwc_" -+ << problem_size.N << "x" -+ << problem_size.H << "x" -+ << problem_size.W << "x" -+ << problem_size.C -+ << "_krsc_" -+ << problem_size.K << "x" -+ << problem_size.R << "x" -+ << problem_size.S << "x" -+ << problem_size.C -+ << "_padding_" -+ << problem_size.pad_h << "x" -+ << problem_size.pad_w -+ << "_stride_" -+ << problem_size.stride_h << "x" -+ << problem_size.stride_w -+ << "_dilation_" -+ << problem_size.dilation_h << "x" -+ << problem_size.dilation_w << "_" -+ << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_") -+ << Conv2d::ThreadblockShape::kM << "x" -+ << Conv2d::ThreadblockShape::kN << "x" -+ << Conv2d::ThreadblockShape::kK << "_" -+ << Conv2d::WarpShape::kM << "x" -+ << Conv2d::WarpShape::kN << "x" -+ << Conv2d::WarpShape::kK << ".txt"; -+ -+ std::cout << fname.str() << std::endl; -+ -+ std::ofstream results(fname.str()); -+ -+ results << problem_size << std::endl; -+ -+ results -+ << "\nA:\n" << tensor_A.host_view() << "\n" -+ << "\nB:\n" << tensor_B.host_view() << "\n" -+ << "\nC:\n" << tensor_C.host_view() << "\n" -+ << "\nD reference:\n" << tensor_D_reference.host_view() << "\n" -+ << "\nD computed:\n" << tensor_D_computed.host_view() << "\n" -+ << "\nreduction reference:\n" << reference_Reduction.host_view() << "\n" -+ << "\nreduction computed:\n" << tensor_Reduction.host_view() << "\n"; -+ } -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////////////// -+// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference -+// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes -+// Additionaly, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes -+// (conv_blacklist_sizes) -+///////////////////////////////////////////////////////////////////////////////////////////////////////////// -+template -+bool TestAllConv2dWithReduction( -+ const Conv2dProblemVector & conv_test_sizes = Conv2dProblemVector(), -+ const Conv2dProblemVector & conv_blacklist_sizes = Conv2dProblemVector()) { -+ -+ bool passed = true; -+ -+ // -+ // Testbed object -+ // -+ -+ TestbedConv2dWithReduction testbed; -+ -+ // -+ // Get conv problem sizes to run conv operator -+ // -+ TestbedConv2dProblemSizes conv_problems(128/cutlass::sizeof_bits::value); -+ -+ // Vector of conv2d problem sizes to avoid duplicate runs -+ Conv2dProblemVector conv_tested_sizes; -+ -+ Conv2dProblemVector const *problem_vectors[] = { -+ &conv_test_sizes, // run user specified sizes -+ &conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes -+ &conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes -+#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED -+ &conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled -+#endif -+ }; -+ -+ // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) -+ for (Conv2dProblemVector const * problem_vector : problem_vectors) { -+ -+ // Run conv testbed on default convolution sizes -+ for(auto conv_problem : *problem_vector) { -+ -+ // Skip blacklist and avoid duplicate problem sizes -+ if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || -+ std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { -+ continue; -+ } -+ -+ // -+ // Procedurally disable certain cases -+ // -+ -+ // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} -+ if ((ImplicitGemm::kConvolutionalOperator == -+ cutlass::conv::Operator::kDgrad) && -+ (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == -+ cutlass::conv::StrideSupport::kUnity)) { -+ if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { -+ continue; -+ } -+ } -+ -+#if 0 // relax restrictions on analytic strided dgrad -+ // CUTLASS DGRAD's *strided* specialization only support stride >= {2, 2} -+ if ((ImplicitGemm::kConvolutionalOperator == -+ cutlass::conv::Operator::kDgrad) && -+ (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == -+ cutlass::conv::StrideSupport::kStrided)) { -+ if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { -+ continue; -+ } -+ } -+#endif -+ -+ // -+ // Test -+ // -+ // push back tested problem size to avoid re-running duplicates -+ conv_tested_sizes.push_back(conv_problem); -+ -+ // test mode = xcross -+ passed = testbed.run( -+ conv_problem, -+ cutlass::conv::SplitKMode::kSerial); -+ -+ if (!passed) { -+ return false; -+ } -+ -+ // test mode = convolution -+ passed = testbed.run( -+ conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), -+ cutlass::conv::SplitKMode::kSerial); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ -+ // CUTLASS DGRAD's *strided* specialization does not support split-k mode -+ if ((ImplicitGemm::kConvolutionalOperator == -+ cutlass::conv::Operator::kDgrad) && -+ (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == -+ cutlass::conv::StrideSupport::kStrided)) { -+ -+ passed = testbed.run( -+ cutlass::conv::Conv2dProblemSize( -+ {1, 56, 56, 8}, // input size (NHWC) -+ {8, 1, 1, 8}, // filter size (KRSC) -+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1}), // dilation (dilation_h, dilation_w) -+ cutlass::conv::SplitKMode::kSerial, -+ cutlass::from_real(2.0), -+ cutlass::from_real(2.0)); -+ -+ if (!passed) { -+ return false; -+ } -+ -+ return passed; -+ } -+ -+ // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for -+ // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters -+ // which are abolutely neccessary to catch functional bugs. The below code does provide option to sweep -+ // alpha and beta for local testing, but only runs one value for alpha and beta. -+ cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size ( -+ {1, 17, 11, 288}, // input size (NHWC) -+ {160, 3, 3, 288}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1} // dilation (dilation_h, dilation_w) -+ ); -+ -+ // Parallel SplitK is not tested. -+ cutlass::conv::SplitKMode split_k_modes [] = { -+ cutlass::conv::SplitKMode::kSerial, -+ }; -+ -+ int split_k_slices[] = { -+ 1, 2, 3, 4, 201 -+ }; -+ -+ double problem_alpha[] = { -+ 2.0 -+ }; -+ -+ double problem_beta[] = { -+ 2.0 -+ }; -+ -+ for (auto split_k_mode : split_k_modes) { -+ for (auto split_k_slice : split_k_slices) { -+ for (auto alpha : problem_alpha) { -+ for (auto beta : problem_beta) { -+ -+ passed = testbed.run( -+ conv2d_split_k_test_size.reset_split_k_slices(split_k_slice), -+ split_k_mode, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta)); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ -+ return passed; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace conv -+} // namespace test -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv3d_dgrad_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv3d_dgrad_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..909a1df ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv3d_dgrad_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu -@@ -0,0 +1,126 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv3d_dgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv3d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+TEST(SM80_Device_Conv3d_Dgrad_Analytic_ImplicitGemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32, -+ 128x128_32x4_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using Conv3dDgradKernel = typename cutlass::conv::kernel::DefaultConv3dDgrad< -+ ElementA, cutlass::layout::TensorNDHWC, -+ ElementB, cutlass::layout::TensorNDHWC, -+ ElementC, cutlass::layout::TensorNDHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAdd -+ >::Kernel; -+ -+ using Conv3dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv3d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv3d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+ -+TEST(SM80_Device_Conv3d_Dgrad_Optimized_ImplicitGemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32, -+ 128x128_32x4_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using Conv3dDgradKernel = typename cutlass::conv::kernel::DefaultConv3dDgrad< -+ ElementA, cutlass::layout::TensorNDHWC, -+ ElementB, cutlass::layout::TensorNDHWC, -+ ElementC, cutlass::layout::TensorNDHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv3dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv3d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv3d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED -+ -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv3d_dgrad_implicit_gemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv3d_dgrad_implicit_gemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..6864bc4 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv3d_dgrad_implicit_gemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32_sm80.cu -@@ -0,0 +1,128 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv3d_dgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv3d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv3d_Dgrad_Analytic_ImplicitGemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::tfloat32_t; -+ using ElementB = cutlass::tfloat32_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv3dDgradKernel = typename cutlass::conv::kernel::DefaultConv3dDgrad< -+ ElementA, cutlass::layout::TensorNDHWC, -+ ElementB, cutlass::layout::TensorNDHWC, -+ ElementC, cutlass::layout::TensorNDHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv3dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv3d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv3d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv3d_Dgrad_Optimized_ImplicitGemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::tfloat32_t; -+ using ElementB = cutlass::tfloat32_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv3dDgradKernel = typename cutlass::conv::kernel::DefaultConv3dDgrad< -+ ElementA, cutlass::layout::TensorNDHWC, -+ ElementB, cutlass::layout::TensorNDHWC, -+ ElementC, cutlass::layout::TensorNDHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized, -+ cutlass::conv::StrideSupport::kUnity -+ >::Kernel; -+ -+ using Conv3dDgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv3d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv3d()); -+} -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv3d_fprop_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm75.cu b/3rdparty/cutlass/test/unit/conv/device/conv3d_fprop_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm75.cu -new file mode 100644 -index 0000000..7484e8d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv3d_fprop_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm75.cu -@@ -0,0 +1,86 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv3d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv3d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Conv3d_Fprop_Analytic_ImplicitGemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv3dFpropKernel = typename cutlass::conv::kernel::DefaultConv3dFprop< -+ ElementA, cutlass::layout::TensorNDHWC, -+ ElementB, cutlass::layout::TensorNDHWC, -+ ElementC, cutlass::layout::TensorNDHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+ >::Kernel; -+ -+ using Conv3dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv3d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv3d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv3d_fprop_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv3d_fprop_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..24990ff ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv3d_fprop_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu -@@ -0,0 +1,165 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv3d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv3d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+TEST(SM80_Device_Conv3d_Fprop_Analytic_ImplicitGemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32, -+ 128x128_32x4_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using Conv3dFpropKernel = typename cutlass::conv::kernel::DefaultConv3dFprop< -+ ElementA, cutlass::layout::TensorNDHWC, -+ ElementB, cutlass::layout::TensorNDHWC, -+ ElementC, cutlass::layout::TensorNDHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAdd -+ >::Kernel; -+ -+ using Conv3dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv3d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv3d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+ -+TEST(SM80_Device_Conv3d_Fprop_Optimized_ImplicitGemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32, -+ 128x128_32x4_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using Conv3dFpropKernel = typename cutlass::conv::kernel::DefaultConv3dFprop< -+ ElementA, cutlass::layout::TensorNDHWC, -+ ElementB, cutlass::layout::TensorNDHWC, -+ ElementC, cutlass::layout::TensorNDHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv3dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv3d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv3d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv3d_Fprop_Optimized_ImplicitGemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32, -+ 64x256_32x4_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using Conv3dFpropKernel = typename cutlass::conv::kernel::DefaultConv3dFprop< -+ ElementA, cutlass::layout::TensorNDHWC, -+ ElementB, cutlass::layout::TensorNDHWC, -+ ElementC, cutlass::layout::TensorNDHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv3dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv3d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv3d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED -+ -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv3d_fprop_implicit_gemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv3d_fprop_implicit_gemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..723e15e ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv3d_fprop_implicit_gemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32_sm80.cu -@@ -0,0 +1,127 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv3d_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv3d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv3d_Fprop_Analytic_ImplicitGemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::tfloat32_t; -+ using ElementB = cutlass::tfloat32_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv3dFpropKernel = typename cutlass::conv::kernel::DefaultConv3dFprop< -+ ElementA, cutlass::layout::TensorNDHWC, -+ ElementB, cutlass::layout::TensorNDHWC, -+ ElementC, cutlass::layout::TensorNDHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd -+ >::Kernel; -+ -+ using Conv3dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv3d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv3d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv3d_Fprop_Optimized_ImplicitGemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::tfloat32_t; -+ using ElementB = cutlass::tfloat32_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv3dFpropKernel = typename cutlass::conv::kernel::DefaultConv3dFprop< -+ ElementA, cutlass::layout::TensorNDHWC, -+ ElementB, cutlass::layout::TensorNDHWC, -+ ElementC, cutlass::layout::TensorNDHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv3dFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv3d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv3d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv3d_problems.h b/3rdparty/cutlass/test/unit/conv/device/conv3d_problems.h -new file mode 100644 -index 0000000..3c0512e ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv3d_problems.h -@@ -0,0 +1,271 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Implicit GEMM testbed sizes for Conv2d problem -+*/ -+#pragma once -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/layout/pitch_linear.h" -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+ -+namespace test { -+namespace conv { -+namespace device { -+ -+using Conv3dProblemVector = std::vector; -+ -+//////////////////////////////////////////////////////////////////////////// -+/// Structure TestbedConv3dProblemSizes initializes and holds conv default and -+/// important network sizes -+//////////////////////////////////////////////////////////////////////////// -+struct TestbedConv3dProblemSizes { -+ -+ // -+ // Data members -+ // -+ int minimum_channel_size; -+ Conv3dProblemVector conv3d_default_sizes; -+ Conv3dProblemVector conv3d_vnet_medical_sizes; -+ -+ // -+ // Methods -+ // -+ /// Default ctor -+ TestbedConv3dProblemSizes(int minimum_channel_size_ = 64): minimum_channel_size (minimum_channel_size_) { -+ -+ initialize_conv3d_default_sizes(); -+ initialize_conv3d_vnet_medical_sizes(conv3d_vnet_medical_sizes, 1 /*batch-size*/); -+ -+ filter_all(); -+ } -+ -+ /// Eliminates some illegal cases -+ void filter_all() { -+ -+ Conv3dProblemVector *problems_vectors[] = { -+ &conv3d_default_sizes, -+ &conv3d_vnet_medical_sizes -+ }; -+ -+ for (Conv3dProblemVector *problems : problems_vectors) { -+ Conv3dProblemVector filtered; -+ -+ for (cutlass::conv::Conv3dProblemSize const & problem : *problems) { -+ if (!(problem.C % minimum_channel_size)) { -+ filtered.push_back(problem); -+ } -+ } -+ -+ *problems = filtered; -+ } -+ } -+ -+ // Add a few standard convolution problem sizes -+ void initialize_conv3d_default_sizes() { -+ -+ conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( -+ {1, 1, 3, 3, minimum_channel_size}, // input size (NDHWC) -+ {8, 1, 1, 1, minimum_channel_size}, // filter size (KTRSC) -+ cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) -+ cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) -+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) -+ )); -+ -+ conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( -+ {1, 1, 1, 8, minimum_channel_size}, // input size (NDHWC) -+ {8, 1, 1, 3, minimum_channel_size}, // filter size (KTRSC) -+ cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) -+ cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) -+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) -+ )); -+ -+ conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( -+ {1, 8, 8, 8, minimum_channel_size}, // input size (NDHWC) -+ {8, 3, 3, 3, minimum_channel_size}, // filter size (KTRSC) -+ cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) -+ cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) -+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) -+ )); -+ -+ conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( -+ {1, 16, 16, 16, minimum_channel_size}, // input size (NDHWC) -+ {8, 3, 3, 3, minimum_channel_size}, // filter size (KTRSC) -+ cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) -+ cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) -+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) -+ )); -+ -+ conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( -+ {1, 1, 15, 19, 160}, // input size (NDHWC) -+ {224, 1, 3, 6, 160}, // filter size (KTRSC) -+ cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) -+ cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) -+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) -+ )); -+ -+ conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( -+ {1, 2, 1, 1, minimum_channel_size}, // input size (NDHWC) -+ {8, 2, 1, 1, minimum_channel_size}, // filter size (KTRSC) -+ cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) -+ cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) -+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) -+ )); -+ -+ conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( -+ {1, 1, 7, 7, minimum_channel_size}, // input size (NDHWC) -+ {16, 1, 3, 3, minimum_channel_size}, // filter size (KTRSC) -+ cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) -+ cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) -+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) -+ )); -+ -+ -+ conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( -+ {1, 11, 15, 19, 64}, // input size (NDHWC) -+ {32, 4, 3, 6, 64}, // filter size (KTRSC) -+ cutlass::Coord<3>({2, 1, 3}), // padding (pad_d, pad_h, pad_w) -+ cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) -+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) -+ )); -+ } -+ -+ // Add vnet layers to unit testing sizes -+ void initialize_conv3d_vnet_medical_sizes(Conv3dProblemVector &conv3d_problem_vector, int batch_size = 1) { -+ -+ conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( -+ {batch_size, 32, 32, 32, 16}, // input size (NDHWC) -+ {32, 2, 2, 2, 16}, // filter size (KTRSC) -+ cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) -+ cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w) -+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) -+ )); -+ -+ -+ conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( -+ {batch_size, 16, 16, 16, 32}, // input size (NDHWC) -+ {32, 3, 3, 3, 32}, // filter size (KTRSC) -+ cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) -+ cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) -+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) -+ )); -+ -+ -+ conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( -+ {batch_size, 16, 16, 16, 32}, // input size (NDHWC) -+ {64, 2, 2, 2, 32}, // filter size (KTRSC) -+ cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) -+ cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w) -+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) -+ )); -+ -+ -+ conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( -+ {batch_size, 8, 8, 8, 64}, // input size (NDHWC) -+ {64, 3, 3, 3, 64}, // filter size (KTRSC) -+ cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) -+ cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) -+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) -+ )); -+ -+ -+ conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( -+ {batch_size, 8, 8, 8, 64}, // input size (NDHWC) -+ {128, 2, 2, 2, 64}, // filter size (KTRSC) -+ cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) -+ cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w) -+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) -+ )); -+ -+ -+ conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( -+ {batch_size, 4, 4, 4, 128}, // input size (NDHWC) -+ {128, 3, 3, 3, 128}, // filter size (KTRSC) -+ cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) -+ cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) -+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) -+ )); -+ -+ -+ conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( -+ {batch_size, 8, 8, 8, 128}, // input size (NDHWC) -+ {128, 3, 3, 3, 128}, // filter size (KTRSC) -+ cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) -+ cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) -+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) -+ )); -+ -+ -+ conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( -+ {batch_size, 16, 16, 16, 64}, // input size (NDHWC) -+ {64, 3, 3, 3, 64}, // filter size (KTRSC) -+ cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) -+ cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) -+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) -+ )); -+ -+ -+ conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( -+ {batch_size, 32, 32, 32, 16}, // input size (NDHWC) -+ {64, 2, 2, 2, 16}, // filter size (KTRSC) -+ cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) -+ cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w) -+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) -+ )); -+ -+ -+ conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( -+ {batch_size, 16, 16, 16, 32}, // input size (NDHWC) -+ {128, 2, 2, 2, 32}, // filter size (KTRSC) -+ cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) -+ cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w) -+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) -+ )); -+ -+ } -+ -+}; -+ -+} // namespace device -+} // namespace conv -+} // namespace test -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv3d_testbed.h b/3rdparty/cutlass/test/unit/conv/device/conv3d_testbed.h -new file mode 100644 -index 0000000..a5fa186 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv3d_testbed.h -@@ -0,0 +1,669 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Implicit GEMM testbed -+*/ -+#pragma once -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+#include "cutlass/reduction/device/reduce_split_k.h" -+#include "cutlass/reduction/thread/reduction_operators.h" -+ -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+#include "cutlass/util/reference/host/convolution.h" -+ -+#include "cutlass/util/reference/host/tensor_compare.h" -+ -+#include "cutlass/util/reference/device/convolution.h" -+#include "cutlass/util/reference/device/tensor_compare.h" -+ -+#include "conv3d_problems.h" -+#include "cutlass/core_io.h" -+ -+#include "cache_testbed_output.h" -+ -+namespace test { -+namespace conv { -+namespace device { -+ -+template -+class TestbedConv3d { -+public: -+ -+ using ElementA = typename Conv3d::ElementA; -+ using LayoutA = typename Conv3d::LayoutA; -+ using ElementB = typename Conv3d::ElementB; -+ using LayoutB = typename Conv3d::LayoutB; -+ using ElementC = typename Conv3d::ElementC; -+ using LayoutC = typename Conv3d::LayoutC; -+ using ElementAccumulator = typename Conv3d::ElementAccumulator; -+ using ElementCompute = typename Conv3d::ElementCompute; -+ using EpilogueOutputOp = typename Conv3d::EpilogueOutputOp; -+ -+ static cutlass::conv::Operator const kConvolutionalOperator = Conv3d::kConvolutionalOperator; -+ -+ /// Reduction kernel -+ using ReductionOp = cutlass::reduction::thread::ReduceAdd< -+ ElementAccumulator, -+ typename EpilogueOutputOp::ElementAccumulator, -+ EpilogueOutputOp::kCount -+ >; -+ -+ using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< -+ cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, -+ EpilogueOutputOp, -+ ReductionOp -+ >; -+ -+ using ReductionDevice = cutlass::reduction::device::ReduceSplitK; -+ using ReductionStrideIndex = typename ReductionDevice::StrideIndex; -+ -+public: -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D_computed; -+ cutlass::HostTensor tensor_D_reference; -+ -+public: -+ -+ TestbedConv3d( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { -+ -+ } -+ -+ /// Helper to initialize a tensor view -+ template -+ void initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ int scope; -+ int bits = cutlass::sizeof_bits::value; -+ -+ if (bits <= 8) { -+ scope = 2; -+ } -+ else if (bits == 16) { -+ scope = 4; -+ } -+ else { -+ scope = 8; -+ } -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope, -scope, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); -+ } -+ else { -+ } -+ } -+ -+ void initialize( -+ cutlass::conv::Conv3dProblemSize const &problem_size, uint64_t seed = 2019) { -+ -+ tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); -+ tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); -+ tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ -+ initialize_tensor(tensor_A.host_view(), init_A, seed); -+ initialize_tensor(tensor_B.host_view(), init_B, seed * 17); -+ initialize_tensor(tensor_C.host_view(), init_C, seed * 39); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D_computed.sync_device(); -+ tensor_D_reference.sync_device(); -+ } -+ -+ bool sufficient() const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename Conv3d::UnderlyingKernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ -+ /// Executes one test -+ bool run( -+ cutlass::conv::Conv3dProblemSize const &problem_size, -+ cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, -+ ElementCompute alpha = ElementCompute(1), -+ ElementCompute beta = ElementCompute()) { -+ -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ -+#if 0 //display conv2d problem size for debugging -+ std::cout << problem_size << std::endl -+ << "alpha, beta: (" << float(alpha) << ", " << float(beta) << ")" << std::endl -+ << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl -+ << std::endl; -+#endif -+ -+ initialize(problem_size); -+ -+ // configure the operator -+ Conv3d conv3d_op; -+ -+ typename Conv3d::Arguments conv3d_args( -+ problem_size, -+ tensor_A.device_ref(), -+ tensor_B.device_ref(), -+ tensor_C.device_ref(), -+ tensor_D_computed.device_ref(), -+ {alpha, beta}, -+ split_k_mode -+ ); -+ -+ // find workspace requirement for parallel split-k reduction -+ size_t workspace_size = Conv3d::get_workspace_size(conv3d_args); -+ -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = conv3d_op.initialize(conv3d_args, workspace.get()); -+ -+ if (status != cutlass::Status::kSuccess) { -+ cudaError_t error = cudaGetLastError(); -+ std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; -+ return true; -+ } -+ -+ // conv3d operation with parallel split-k-mode -+ if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { -+ -+ // conv3d output is written to workspace in global memory -+ conv3d_args.ref_D.reset(reinterpret_cast(workspace.get())); -+ // accumulate mma for each cta in k-dimension (1.0 * A * B) -+ conv3d_args.output_op = {1.0, 0.0}; -+ // update conv3d operator arguments -+ status = conv3d_op.update(conv3d_args, workspace.get()); -+ } -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ if (status != cutlass::Status::kSuccess) { -+ return false; -+ } -+ -+ // run conv3d operator -+ status = conv3d_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ if (status != cutlass::Status::kSuccess) { -+ return false; -+ } -+ -+ if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { -+ -+ // configure parallel reduction operator -+ ReductionDevice reduction_op; -+ -+ typename ReductionDevice::Arguments reduction_args( -+ cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(), -+ problem_size.split_k_slices, -+ cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size), -+ { -+ reinterpret_cast (workspace.get()), -+ ReductionStrideIndex(tensor_C.stride()[Conv3d::UnderlyingKernel::kTensorCStrideIdx]) -+ }, -+ { -+ tensor_D_computed.device_data(), -+ ReductionStrideIndex(tensor_C.stride()[Conv3d::UnderlyingKernel::kTensorCStrideIdx]) -+ }, -+ { -+ tensor_C.device_data(), -+ ReductionStrideIndex(tensor_C.stride()[Conv3d::UnderlyingKernel::kTensorCStrideIdx]) -+ }, -+ // apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C -+ {alpha, beta} -+ ); -+ -+ status = reduction_op.initialize(reduction_args, nullptr); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ if (status != cutlass::Status::kSuccess) { -+ return false; -+ } -+ -+ // run prallel reduction kernel -+ status = reduction_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ if (status != cutlass::Status::kSuccess) { -+ return false; -+ } -+ } -+ bool passed = false; -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ EXPECT_EQ(result, cudaSuccess) << " device reference error: " -+ << cudaGetErrorString(result); -+ -+ tensor_D_computed.sync_host(); -+ -+ // -+ // Reference check - support caching results -+ // -+ -+ CachedTestKey cached_test_key = CreateCachedConv3dTestKey< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementAccumulator, -+ ElementCompute -+ >( -+ kConvolutionalOperator, -+ problem_size, -+ alpha, -+ beta, -+ tensor_A.host_view(), -+ tensor_B.host_view(), -+ tensor_C.host_view() -+ ); -+ -+ // -+ // Look for the cached key -+ // -+ -+ bool cached_result_loaded = false; -+ CachedTestResult cached_test_result; -+ -+ std::string conv2d_result_cache_name = -+ std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt"; -+ -+ if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { -+ -+ CachedTestResultListing cached_results(conv2d_result_cache_name); -+ -+ auto cached = cached_results.find(cached_test_key); -+ -+ cached_result_loaded = cached.first; -+ if (cached_result_loaded) { -+ cached_test_result = cached.second; -+ } -+ } -+ -+ if (!cached_result_loaded) { -+ -+#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED -+ -+ cutlass::reference::device::Conv3d< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ ElementCompute -+ >( -+ kConvolutionalOperator, -+ problem_size, -+ tensor_A.device_ref(), -+ tensor_B.device_ref(), -+ tensor_C.device_ref(), -+ tensor_D_reference.device_ref(), -+ alpha, -+ beta -+ ); -+ -+ // sync host (copy device data to host) for dumping error output in case of mismatches -+ tensor_D_reference.sync_host(); -+ -+#else -+ cutlass::reference::host::Conv3d< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ ElementCompute -+ >( -+ kConvolutionalOperator, -+ problem_size, -+ tensor_A.host_ref(), -+ tensor_B.host_ref(), -+ tensor_C.host_ref(), -+ tensor_D_reference.host_ref(), -+ alpha, -+ beta -+ ); -+#endif -+ -+ if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { -+ -+ cached_test_result.D = TensorHash(tensor_D_reference.host_view()); -+ -+ CachedTestResultListing cached_results(conv2d_result_cache_name); -+ -+ cached_results.append(cached_test_key, cached_test_result); -+ cached_results.write(conv2d_result_cache_name); -+ } -+ } // if (!cached_result_loaded) -+ -+ uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view()); -+ -+ if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { -+ passed = (tensor_D_hash == cached_test_result.D); -+ -+ EXPECT_EQ(tensor_D_hash, cached_test_result.D) -+ << "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n"; -+ } -+ else { -+ -+ passed = cutlass::reference::host::TensorEquals( -+ tensor_D_computed.host_view(), -+ tensor_D_reference.host_view()); -+ } -+ -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ std::stringstream fname; -+ -+ fname << "error_Conv3d_ImplicitGemm_device_" -+ << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") -+ << (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : -+ (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_")) -+ << "ndhwc_" -+ << problem_size.N << "x" -+ << problem_size.D << "x" -+ << problem_size.H << "x" -+ << problem_size.W << "x" -+ << problem_size.C -+ << "_ktrsc_" -+ << problem_size.K << "x" -+ << problem_size.T << "x" -+ << problem_size.R << "x" -+ << problem_size.S << "x" -+ << problem_size.C -+ << "_padding_" -+ << problem_size.pad_d << "x" -+ << problem_size.pad_h << "x" -+ << problem_size.pad_w -+ << "_stride_" -+ << problem_size.stride_d << "x" -+ << problem_size.stride_h << "x" -+ << problem_size.stride_w -+ << "_dilation_" -+ << problem_size.dilation_d << "x" -+ << problem_size.dilation_h << "x" -+ << problem_size.dilation_w << "_" -+ << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_") -+ << Conv3d::ThreadblockShape::kM << "x" -+ << Conv3d::ThreadblockShape::kN << "x" -+ << Conv3d::ThreadblockShape::kK << "_" -+ << Conv3d::WarpShape::kM << "x" -+ << Conv3d::WarpShape::kN << "x" -+ << Conv3d::WarpShape::kK << ".txt"; -+ -+ std::cout << fname.str() << std::endl; -+ -+ std::ofstream results(fname.str()); -+ -+ results << problem_size << std::endl; -+ -+ results -+ << "\nA:\n" << tensor_A.host_view() << "\n" -+ << "\nB:\n" << tensor_B.host_view() << "\n" -+ << "\nC:\n" << tensor_C.host_view() << "\n"; -+ -+ -+ results << "\nD reference (hash: " << cached_test_result.D << ")\n"; -+ -+ if (!cached_result_loaded) { -+ results -+ << tensor_D_reference.host_view() << "\n"; -+ } -+ -+ results -+ << "\nD computed (hash: " << tensor_D_hash << ")\n" -+ << tensor_D_computed.host_view() << "\n"; -+ -+ } -+ -+ return passed; -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////////////// -+// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference -+// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes -+// Additionaly, each conv3d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes -+// (conv_blacklist_sizes) -+///////////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+bool TestAllConv3d( -+ const Conv3dProblemVector & conv_test_sizes = Conv3dProblemVector(), -+ const Conv3dProblemVector & conv_blacklist_sizes = Conv3dProblemVector()) { -+ -+ bool passed = true; -+ -+ // -+ // Testbed object -+ // -+ -+ //TestbedConv3d testbed(cutlass::Distribution::Sequential, cutlass::Distribution::Sequential, cutlass::Distribution::Sequential); -+ TestbedConv3d testbed; -+ -+ // -+ // Get conv problem sizes to run conv operator -+ // -+ TestbedConv3dProblemSizes conv3d_problems(128/cutlass::sizeof_bits::value); -+ -+ // Vector of conv3d problem sizes to avoid duplicate runs -+ Conv3dProblemVector conv_tested_sizes; -+ -+ Conv3dProblemVector const *problem_vectors[] = { -+ &conv3d_problems.conv3d_default_sizes, -+ &conv3d_problems.conv3d_vnet_medical_sizes, -+ &conv_test_sizes -+ }; -+ -+ // Sweep conv3d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) -+ for (Conv3dProblemVector const * problem_vector : problem_vectors) { -+ -+ // Run conv testbed on default convolution sizes -+ for(auto conv_problem : *problem_vector) { -+ -+ // Skip blacklist and avoid duplicate problem sizes -+ if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || -+ std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { -+ continue; -+ } -+ -+ // -+ // Procedurally disable certain cases -+ // -+ -+ // CUTLASS DGRAD's unity stride specialization only support stride {1, 1, 1} -+ if ((ImplicitGemm::kConvolutionalOperator == -+ cutlass::conv::Operator::kDgrad) && -+ ((ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == -+ cutlass::conv::StrideSupport::kUnity) || -+ (ImplicitGemm::UnderlyingKernel::Mma::IteratorB::kStrideSupport == -+ cutlass::conv::StrideSupport::kUnity))) { -+ if (!((conv_problem.stride_d == 1) && -+ (conv_problem.stride_h == 1) && -+ (conv_problem.stride_w == 1)) -+ ) { -+ continue; -+ } -+ } -+ -+ // -+ // Test -+ // -+ // push back tested problem size to avoid re-running duplicates -+ conv_tested_sizes.push_back(conv_problem); -+ -+ // test mode = xcross -+ passed = testbed.run( -+ conv_problem, -+ cutlass::conv::SplitKMode::kSerial); -+ -+ if (!passed) { -+ return false; -+ } -+ -+ // test mode = convolution -+ passed = testbed.run( -+ conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), -+ cutlass::conv::SplitKMode::kSerial); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ -+ // Sweep split-k-slice using serial reduction with non-unity alpha and non-zero beta for -+ // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters -+ // which are abolutely neccessary to catch functional bugs. The below code does provide option to sweep -+ // alpha and beta for local testing, but only runs one value for alpha and beta. -+ cutlass::conv::Conv3dProblemSize conv3d_split_k_test_size ( -+ {1, 8, 8, 8, 32}, // input size (NDHWC) -+ {32, 3, 3, 3, 32}, // filter size (KTRSC) -+ cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) -+ cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) -+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) -+ ); -+ -+ cutlass::conv::SplitKMode split_k_modes [] = { -+ cutlass::conv::SplitKMode::kSerial, -+ cutlass::conv::SplitKMode::kParallel -+ }; -+ -+ int split_k_slices[] = { -+ 1, 2, 3, 4, 201 -+ }; -+ -+ double problem_alpha[] = { -+ 2.0 -+ }; -+ -+ double problem_beta[] = { -+ 2.0 -+ }; -+ -+ for (auto split_k_mode : split_k_modes) { -+ for (auto split_k_slice : split_k_slices) { -+ for (auto alpha : problem_alpha) { -+ for (auto beta : problem_beta) { -+ -+ passed = testbed.run( -+ conv3d_split_k_test_size.reset_split_k_slices(split_k_slice), -+ split_k_mode, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta)); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ -+ return passed; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace conv -+} // namespace test -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv3d_wgrad_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm75.cu b/3rdparty/cutlass/test/unit/conv/device/conv3d_wgrad_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm75.cu -new file mode 100644 -index 0000000..4da6f71 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv3d_wgrad_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm75.cu -@@ -0,0 +1,84 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv3d_wgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv3d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+TEST(SM75_Device_Conv3d_Wgrad_Analytic_ImplicitGemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32, -+ 128x128_32x2_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using Conv3dWgradKernel = typename cutlass::conv::kernel::DefaultConv3dWgrad< -+ ElementA, cutlass::layout::TensorNDHWC, -+ ElementB, cutlass::layout::TensorNDHWC, -+ ElementC, cutlass::layout::TensorNDHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+ >::Kernel; -+ -+ using Conv3dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv3d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv3d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED -+ -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv3d_wgrad_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv3d_wgrad_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..9d4f228 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv3d_wgrad_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu -@@ -0,0 +1,165 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv3d_wgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv3d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+TEST(SM80_Device_Conv3d_Wgrad_Analytic_ImplicitGemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32, -+ 128x128_32x4_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using Conv3dWgradKernel = typename cutlass::conv::kernel::DefaultConv3dWgrad< -+ ElementA, cutlass::layout::TensorNDHWC, -+ ElementB, cutlass::layout::TensorNDHWC, -+ ElementC, cutlass::layout::TensorNDHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAdd -+ >::Kernel; -+ -+ using Conv3dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv3d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv3d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+ -+TEST(SM80_Device_Conv3d_Wgrad_Optimized_ImplicitGemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32, -+ 128x128_32x4_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using Conv3dWgradKernel = typename cutlass::conv::kernel::DefaultConv3dWgrad< -+ ElementA, cutlass::layout::TensorNDHWC, -+ ElementB, cutlass::layout::TensorNDHWC, -+ ElementC, cutlass::layout::TensorNDHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv3dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv3d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv3d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv3d_Wgrad_Optimized_ImplicitGemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32, -+ 64x256_32x4_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ using Conv3dWgradKernel = typename cutlass::conv::kernel::DefaultConv3dWgrad< -+ ElementA, cutlass::layout::TensorNDHWC, -+ ElementB, cutlass::layout::TensorNDHWC, -+ ElementC, cutlass::layout::TensorNDHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv3dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv3d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv3d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM75_SUPPORTED -+ -diff --git a/3rdparty/cutlass/test/unit/conv/device/conv3d_wgrad_implicit_gemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/conv3d_wgrad_implicit_gemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..abcb58b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/conv3d_wgrad_implicit_gemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32_sm80.cu -@@ -0,0 +1,126 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/conv/kernel/default_conv3d_wgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv3d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv3d_Wgrad_Analytic_ImplicitGemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::tfloat32_t; -+ using ElementB = cutlass::tfloat32_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv3dWgradKernel = typename cutlass::conv::kernel::DefaultConv3dWgrad< -+ ElementA, cutlass::layout::TensorNDHWC, -+ ElementB, cutlass::layout::TensorNDHWC, -+ ElementC, cutlass::layout::TensorNDHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd -+ >::Kernel; -+ -+ using Conv3dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv3d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv3d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Conv3d_Wgrad_Optimized_ImplicitGemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32, -+ 128x128_32x3_64x64x32) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::tfloat32_t; -+ using ElementB = cutlass::tfloat32_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ /// Device-level Conv2d instance -+ using Conv3dWgradKernel = typename cutlass::conv::kernel::DefaultConv3dWgrad< -+ ElementA, cutlass::layout::TensorNDHWC, -+ ElementB, cutlass::layout::TensorNDHWC, -+ ElementC, cutlass::layout::TensorNDHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv3dWgrad = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv3d instance -+ EXPECT_TRUE(test::conv::device::TestAllConv3d()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/conv/device/depthwise_conv2d_direct_conv_testbed.h b/3rdparty/cutlass/test/unit/conv/device/depthwise_conv2d_direct_conv_testbed.h -new file mode 100644 -index 0000000..1c2506c ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/depthwise_conv2d_direct_conv_testbed.h -@@ -0,0 +1,473 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Depthwise Direct Conv testbed -+*/ -+#pragma once -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cache_testbed_output.h" -+#include "conv2d_problems.h" -+#include "cutlass/conv/device/direct_convolution.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/device/convolution.h" -+#include "cutlass/util/reference/device/tensor_compare.h" -+#include "cutlass/util/reference/host/convolution.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+namespace test { -+namespace conv { -+namespace device { -+ -+template -+class TestbedDepthwiseDirectConv2d { -+ public: -+ -+ using ElementA = typename Conv2d::ElementA; -+ using LayoutA = typename Conv2d::LayoutA; -+ using ElementB = typename Conv2d::ElementB; -+ using LayoutB = typename Conv2d::LayoutB; -+ using ElementC = typename Conv2d::ElementC; -+ using LayoutC = typename Conv2d::LayoutC; -+ using ElementAccumulator = typename Conv2d::ElementAccumulator; -+ using ElementCompute = typename Conv2d::ElementCompute; -+ using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; -+ -+ static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; -+ -+ public: -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_reordered_B; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D_computed; -+ cutlass::HostTensor tensor_D_reference; -+ -+ int tested_problem_count; -+ -+ public: -+ TestbedDepthwiseDirectConv2d(cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080) -+ : init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_), tested_problem_count(0) {} -+ -+ /// Helper to initialize a tensor view -+ template -+ void initialize_tensor(cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ int scope; -+ int bits = cutlass::sizeof_bits::value; -+ -+ if (bits <= 8) { -+ scope = 2; -+ } else if (bits == 16) { -+ if (cutlass::sizeof_bits::value <= 16) { -+ scope = 3; -+ } else { -+ scope = 5; -+ } -+ } else { -+ scope = 8; -+ } -+ cutlass::reference::host::TensorFillRandomUniform(view, seed, scope, -scope, 0); -+ } else if (dist_kind == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(view); -+ -+ } else if (dist_kind == cutlass::Distribution::Gaussian) { -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } else if (dist_kind == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); -+ } else { -+ } -+ } -+ -+ void initialize(cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) { -+ tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); -+ tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); -+ tensor_reordered_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); -+ tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); -+ -+ initialize_tensor(tensor_A.host_view(), init_A, seed); -+ initialize_tensor(tensor_B.host_view(), init_B, seed * 17); -+ initialize_tensor(tensor_reordered_B.host_view(), init_B, seed * 17); -+ initialize_tensor(tensor_C.host_view(), init_C, seed * 39); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_reordered_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D_computed.sync_device(); -+ tensor_D_reference.sync_device(); -+ } -+ -+ bool sufficient(int smem_size) const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Executes one test -+ bool run(cutlass::conv::Conv2dProblemSize const &problem_size, -+ cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, -+ ElementCompute alpha = ElementCompute(1.5), -+ ElementCompute beta = ElementCompute(1)) { -+ // increment tested problem count run by the testbed -+ tested_problem_count++; -+ -+#if 0 // display conv2d problem size for debugging -+ std::cout << problem_size << std::endl -+ << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl -+ << "split_k_mode: " -+ << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") -+ << std::endl -+ << std::endl; -+#endif -+ -+ initialize(problem_size); -+ -+ // configure the operator -+ Conv2d conv2d_op; -+ -+ typename Conv2d::Arguments conv2d_args(problem_size, -+ tensor_A.device_ref(), -+ tensor_B.device_ref(), -+ tensor_C.device_ref(), -+ tensor_D_computed.device_ref(), -+ {alpha, beta}, -+ tensor_reordered_B.device_ref(), -+ split_k_mode); -+ -+ // find workspace requirement for parallel split-k reduction -+ size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); -+ -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = conv2d_op.can_implement(problem_size); -+ -+ if (status != cutlass::Status::kSuccess) { -+ cudaError_t error = cudaGetLastError(); -+ std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; -+ return true; -+ } -+ -+ status = conv2d_op.initialize(conv2d_args, workspace.get()); -+ -+ if (status != cutlass::Status::kSuccess) { -+ cudaError_t error = cudaGetLastError(); -+ std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; -+ return true; -+ } -+ -+ if (!sufficient(conv2d_op.get_smem_size())) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ -+ // run conv2d operator -+ status = conv2d_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ if (status != cutlass::Status::kSuccess) { -+ std::cerr << "Failed to run." << std::endl; -+ return false; -+ } -+ -+ bool passed = false; -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ EXPECT_EQ(result, cudaSuccess) << " device reference error: " << cudaGetErrorString(result); -+ -+ tensor_D_computed.sync_host(); -+ -+ // -+ // Reference check - support caching results -+ // -+ -+ CachedTestKey cached_test_key = -+ CreateCachedConv2dTestKey(kConvolutionalOperator, -+ problem_size, -+ alpha, -+ beta, -+ tensor_A.host_view(), -+ tensor_B.host_view(), -+ tensor_C.host_view()); -+ -+ // -+ // Look for the cached key -+ // -+ -+ bool cached_result_loaded = false; -+ CachedTestResult cached_test_result; -+ -+ std::string conv2d_result_cache_name = -+ std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt"; -+ -+ if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { -+ -+ CachedTestResultListing cached_results(conv2d_result_cache_name); -+ -+ auto cached = cached_results.find(cached_test_key); -+ -+ cached_result_loaded = cached.first; -+ if (cached_result_loaded) { -+ cached_test_result = cached.second; -+ } -+ } -+ -+ if (!cached_result_loaded) { -+#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED -+ -+ cutlass::reference::device::Conv2d(kConvolutionalOperator, -+ problem_size, -+ tensor_A.device_ref(), -+ tensor_B.device_ref(), -+ tensor_C.device_ref(), -+ tensor_D_reference.device_ref(), -+ alpha, -+ beta); -+ -+ // sync host (copy device data to host) for dumping error output in case of mismatches -+ tensor_D_reference.sync_host(); -+ -+#else -+ -+ cutlass::reference::host::Conv2d(kConvolutionalOperator, -+ problem_size, -+ tensor_A.host_ref(), -+ tensor_B.host_ref(), -+ tensor_C.host_ref(), -+ tensor_D_reference.host_ref(), -+ alpha, -+ beta); -+ -+#endif -+ -+ if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { -+ -+ cached_test_result.D = TensorHash(tensor_D_reference.host_view()); -+ -+ CachedTestResultListing cached_results(conv2d_result_cache_name); -+ -+ cached_results.append(cached_test_key, cached_test_result); -+ cached_results.write(conv2d_result_cache_name); -+ } -+ } // if (!cached_result_loaded) -+ -+ uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view()); -+ -+ if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { -+ passed = (tensor_D_hash == cached_test_result.D); -+ -+ EXPECT_EQ(tensor_D_hash, cached_test_result.D) -+ << "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n"; -+ } -+ else { -+ -+ passed = cutlass::reference::host::TensorEquals( -+ tensor_D_computed.host_view(), -+ tensor_D_reference.host_view()); -+ } -+ -+ EXPECT_TRUE(passed); -+ -+ std::stringstream ss_problem_size_text; -+ ss_problem_size_text << "nhwc_" -+ << problem_size.N << "x" -+ << problem_size.H << "x" -+ << problem_size.W << "x" -+ << problem_size.C -+ << "_krsc_" -+ << problem_size.K << "x" -+ << problem_size.R << "x" -+ << problem_size.S << "x" -+ << problem_size.C -+ << "_padding_" -+ << problem_size.pad_h << "x" -+ << problem_size.pad_w -+ << "_stride_" -+ << problem_size.stride_h << "x" -+ << problem_size.stride_w -+ << "_dilation_" -+ << problem_size.dilation_h << "x" -+ << problem_size.dilation_w << "_" -+ << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_"); -+ -+ if (!passed) { -+ std::stringstream fname; -+ -+ fname << "error_Conv2d_DirectConv_device_" -+ << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") -+ << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : -+ (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_")) -+ << ss_problem_size_text.str() -+ << Conv2d::ThreadblockShape::kM << "x" -+ << Conv2d::ThreadblockShape::kN << "x" -+ << Conv2d::ThreadblockShape::kK << "_" -+ << Conv2d::WarpShape::kM << "x" -+ << Conv2d::WarpShape::kN << "x" -+ << Conv2d::WarpShape::kK << ".txt"; -+ -+ std::cout << fname.str() << std::endl; -+ -+ std::ofstream results(fname.str()); -+ -+ results << problem_size << std::endl; -+ -+ results -+ << "\nA:\n" << tensor_A.host_view() << "\n" -+ << "\nB:\n" << tensor_B.host_view() << "\n" -+ << "\nC:\n" << tensor_C.host_view() << "\n"; -+ -+ results << "\nD reference (hash: " << cached_test_result.D << ")\n"; -+ -+ if (!cached_result_loaded) { -+ results -+ << tensor_D_reference.host_view() << "\n"; -+ } -+ -+ results -+ << "\nD computed (hash: " << tensor_D_hash << ")\n" -+ << tensor_D_computed.host_view() << "\n"; -+ -+ } -+ -+ return passed; -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+bool TestSpecificDepthwiseDirectConv2d(const Conv2dProblemVector &problem_sizes) { -+ bool passed = true; -+ -+ // -+ // Testbed object -+ // -+ TestbedDepthwiseDirectConv2d testbed; -+ -+ // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) -+ for (auto conv_problem : problem_sizes) { -+ // -+ // Test -+ // -+ -+ // test mode = xcross -+ passed = testbed.run( -+ conv_problem, -+ cutlass::conv::SplitKMode::kSerial); -+ -+ if (!passed) { -+ return false; -+ } -+ -+ // test mode = convolution -+ passed = testbed.run( -+ conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), -+ cutlass::conv::SplitKMode::kSerial); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ -+ return true; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace conv -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu b/3rdparty/cutlass/test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu -new file mode 100644 -index 0000000..8efc73e ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu -@@ -0,0 +1,429 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Depthwise Direct Conv interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_depthwise_fprop.h" -+#include "cutlass/conv/device/direct_convolution.h" -+ -+#include "conv2d_testbed.h" -+#include "depthwise_conv2d_direct_conv_testbed.h" -+ -+std::vector DepthwiseFpropProblemSizes_filter3x3() { -+ std::vector problems; -+ -+ for (int channels = 16; channels <= 512; channels += 16) { -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 8, 8, channels}, // input size (NHWC) -+ {channels, 3, 3, 1}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, // Convolution mode -+ 16, // split_k_slices -+ channels // groups -+ )); -+ -+ // if(channels == 512 || channels == 16*14) -+ -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 16, 16, channels}, // input size (NHWC) -+ {channels, 3, 3, 1}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {2, 2}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, // Convolution mode -+ 16, // split_k_slices -+ channels // groups -+ )); -+ } -+ -+ return problems; -+} -+ -+std::vector DepthwiseFpropProblemSizes_filter5x5() { -+ std::vector problems; -+ -+ for (int channels = 16; channels < 256; channels += 16) { -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 16, 16, channels}, // input size (NHWC) -+ {channels, 5, 5, 1}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, // Convolution mode -+ 16, // split_k_slices -+ channels // groups -+ )); -+ -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 112, 112, channels}, // input size (NHWC) -+ {channels, 5, 5, 1}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, // Convolution mode -+ 16, // split_k_slices -+ channels // groups -+ )); -+ -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 112, 112, channels}, // input size (NHWC) -+ {channels, 5, 5, 1}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {2, 2}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, // Convolution mode -+ 16, // split_k_slices -+ channels // groups -+ )); -+ } -+ -+ return problems; -+} -+ -+std::vector DepthwiseFpropProblemSizes_filter5x37() { -+ std::vector problems; -+ -+ for (int channels = 16; channels < 256; channels += 16) { -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 128, 128, channels}, // input size (NHWC) -+ {channels, 5, 37, 1}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, // Convolution mode -+ 108, // split_k_slices -+ channels // groups -+ )); -+ } -+ -+ return problems; -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST( -+ SM60_Device_Depthwise_conv2d_Fprop_Direct_Conv_Optimized_f16nhwc_f16nhwc_f16nhwc_simt_f16, -+ 64x32_4_8x32_3x3) { -+ -+ using ElementInputA = cutlass::half_t; -+ using ElementInputB = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementComputeEpilogue = cutlass::half_t; -+ -+ using LayoutInputA = cutlass::layout::TensorNHWC; -+ using LayoutInputB = cutlass::layout::TensorNHWC; -+ using LayoutOutput = cutlass::layout::TensorNHWC; -+ -+ // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU -+ // SM -+ using MMAOp = cutlass::arch::OpClassSimt; -+ -+ // This code section describes CUDA SM architecture number -+ using SmArch = cutlass::arch::Sm60; -+ -+ // This code section describes the groups a thread block will compute -+ constexpr int groups_per_cta = 32; -+ -+ // This code section describes the output tile a thread block will compute -+ using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, groups_per_cta>; -+ -+ // This code section describes the filter shape -+ using FilterShape = cutlass::MatrixShape<3, 3>; -+ -+ // Threadblock tile shape -+ using ThreadblockShape = -+ cutlass::gemm::GemmShape; -+ -+ // This code section describes tile size a warp will computes -+ using WarpShape = cutlass::gemm::GemmShape<8, groups_per_cta, FilterShape::kCount>; -+ -+ // This code section describes the size of MMA op -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ -+ // This code section describes how threadblocks are scheduled on GPU -+ using SwizzleThreadBlock = -+ cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle< -+ 1, -+ ThreadBlockOutputShape::kN, -+ ThreadBlockOutputShape::kH, -+ ThreadBlockOutputShape::kW>; -+ -+ // Number of pipelines you want to use -+ constexpr int NumStages = 4; -+ -+ // This code section describe iterator algorithm selected is Analytic or Optimized -+ static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = -+ cutlass::conv::IteratorAlgorithm::kOptimized; -+ -+ constexpr int kEpilogueElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ // This code section describes the epilogue part of the kernel, we use default value -+ using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // Data type of output matrix. -+ kEpilogueElementsPerAccess, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. -+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue, // Data type for alpha/beta in linear combination -+ cutlass::epilogue::thread::ScaleType::Default>; -+ -+ using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ ThreadBlockOutputShape, -+ FilterShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ IteratorAlgorithm, -+ cutlass::conv::StrideSupport::kStrided>::Kernel; -+ -+ using Direct2dConv = cutlass::conv::device::DirectConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestSpecificDepthwiseDirectConv2d( -+ DepthwiseFpropProblemSizes_filter3x3())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST( -+ SM60_Device_Depthwise_conv2d_Fprop_Direct_Conv_Optimized_f16nhwc_f16nhwc_f16nhwc_simt_f16, -+ 64x64_3_16x64_5x5) { -+ -+ using ElementInputA = cutlass::half_t; -+ using ElementInputB = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementComputeEpilogue = cutlass::half_t; -+ -+ using LayoutInputA = cutlass::layout::TensorNHWC; -+ using LayoutInputB = cutlass::layout::TensorNHWC; -+ using LayoutOutput = cutlass::layout::TensorNHWC; -+ -+ // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU -+ // SM -+ using MMAOp = cutlass::arch::OpClassSimt; -+ -+ // This code section describes CUDA SM architecture number -+ using SmArch = cutlass::arch::Sm60; -+ -+ // This code section describes the groups a thread block will compute -+ constexpr int groups_per_cta = 64; -+ -+ // This code section describes the output tile a thread block will compute -+ using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, groups_per_cta>; -+ -+ // This code section describes the filter shape -+ using FilterShape = cutlass::MatrixShape<5, 5>; -+ -+ // Threadblock tile shape -+ using ThreadblockShape = -+ cutlass::gemm::GemmShape; -+ -+ // This code section describes tile size a warp will computes -+ using WarpShape = cutlass::gemm::GemmShape<16, groups_per_cta, FilterShape::kCount>; -+ -+ // This code section describes the size of MMA op -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ -+ // This code section describes how threadblocks are scheduled on GPU -+ using SwizzleThreadBlock = -+ cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle< -+ 1, -+ ThreadBlockOutputShape::kN, -+ ThreadBlockOutputShape::kH, -+ ThreadBlockOutputShape::kW>; -+ -+ // Number of pipelines you want to use -+ constexpr int NumStages = 3; -+ -+ // This code section describe iterator algorithm selected is Analytic or Optimized -+ static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = -+ cutlass::conv::IteratorAlgorithm::kOptimized; -+ -+ constexpr int kEpilogueElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ // This code section describes the epilogue part of the kernel, we use default value -+ using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // Data type of output matrix. -+ kEpilogueElementsPerAccess, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. -+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue, // Data type for alpha/beta in linear combination -+ cutlass::epilogue::thread::ScaleType::Default>; -+ -+ using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ ThreadBlockOutputShape, -+ FilterShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ IteratorAlgorithm, -+ cutlass::conv::StrideSupport::kStrided>::Kernel; -+ -+ using Direct2dConv = cutlass::conv::device::DirectConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestSpecificDepthwiseDirectConv2d( -+ DepthwiseFpropProblemSizes_filter5x5())); -+} -+ -+#if 0 -+//////////////////////////////////////////////////////////////////////////////// -+TEST( -+ SM60_Device_Depthwise_conv2d_Fprop_Direct_Conv_Optimized_f16nhwc_f16nhwc_f16nhwc_simt_f16, -+ 64x32_3_16x32_5x37) { -+ -+ using ElementInputA = cutlass::half_t; -+ using ElementInputB = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementComputeEpilogue = cutlass::half_t; -+ -+ using LayoutInputA = cutlass::layout::TensorNHWC; -+ using LayoutInputB = cutlass::layout::TensorNHWC; -+ using LayoutOutput = cutlass::layout::TensorNHWC; -+ -+ // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU -+ // SM -+ using MMAOp = cutlass::arch::OpClassSimt; -+ -+ // This code section describes CUDA SM architecture number -+ using SmArch = cutlass::arch::Sm60; -+ -+ // This code section describes the groups a thread block will compute -+ constexpr int groups_per_cta = 32; -+ -+ // This code section describes the output tile a thread block will compute -+ using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, groups_per_cta>; -+ -+ // This code section describes the filter shape -+ using FilterShape = cutlass::MatrixShape<5, 37>; -+ -+ // Threadblock tile shape -+ using ThreadblockShape = -+ cutlass::gemm::GemmShape; -+ -+ // This code section describes tile size a warp will computes -+ using WarpShape = cutlass::gemm::GemmShape<16, groups_per_cta, FilterShape::kCount>; -+ -+ // This code section describes the size of MMA op -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ -+ // This code section describes how threadblocks are scheduled on GPU -+ using SwizzleThreadBlock = -+ cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle< -+ 1, -+ ThreadBlockOutputShape::kN, -+ ThreadBlockOutputShape::kH, -+ ThreadBlockOutputShape::kW>; -+ -+ // Number of pipelines you want to use -+ constexpr int NumStages = 2; -+ -+ // This code section describe iterator algorithm selected is Analytic or Optimized -+ static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = -+ cutlass::conv::IteratorAlgorithm::kOptimized; -+ -+ constexpr int kEpilogueElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ // This code section describes the epilogue part of the kernel, we use default value -+ using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // Data type of output matrix. -+ kEpilogueElementsPerAccess, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. -+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue, // Data type for alpha/beta in linear combination -+ cutlass::epilogue::thread::ScaleType::Default>; -+ -+ using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ ThreadBlockOutputShape, -+ FilterShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ IteratorAlgorithm, -+ cutlass::conv::StrideSupport::kStrided>::Kernel; -+ -+ using Direct2dConv = cutlass::conv::device::DirectConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestSpecificDepthwiseDirectConv2d( -+ DepthwiseFpropProblemSizes_filter5x37())); -+} -+#endif -+ -diff --git a/3rdparty/cutlass/test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_fixed_stride_dilation_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu b/3rdparty/cutlass/test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_fixed_stride_dilation_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu -new file mode 100644 -index 0000000..00bbafa ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_fixed_stride_dilation_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu -@@ -0,0 +1,522 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Depthwise Direct Conv interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_depthwise_fprop.h" -+#include "cutlass/conv/device/direct_convolution.h" -+ -+#include "conv2d_testbed.h" -+#include "depthwise_conv2d_direct_conv_testbed.h" -+ -+std::vector DepthwiseFpropProblemSizes_filter3x3_stride1x1_dilation1x1() { -+ std::vector problems; -+ -+ for (int channels = 16; channels <= 512; channels += 16) { -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 8, 8, channels}, // input size (NHWC) -+ {channels, 3, 3, 1}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, // Convolution mode -+ 16, // split_k_slices -+ channels // groups -+ )); -+ } -+ return problems; -+} -+ -+std::vector DepthwiseFpropProblemSizes_filter3x3_stride2x2_dilation2x2() { -+ std::vector problems; -+ for (int channels = 16; channels <= 512; channels += 16) { -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 16, 16, channels}, // input size (NHWC) -+ {channels, 3, 3, 1}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {2, 2}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, // Convolution mode -+ 16, // split_k_slices -+ channels // groups -+ )); -+ } -+ -+ return problems; -+} -+ -+std::vector DepthwiseFpropProblemSizes_filter5x5_stride1x1_dilation1x1() { -+ std::vector problems; -+ -+ for (int channels = 16; channels < 256; channels += 16) { -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 16, 16, channels}, // input size (NHWC) -+ {channels, 5, 5, 1}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, // Convolution mode -+ 16, // split_k_slices -+ channels // groups -+ )); -+ } -+ -+ return problems; -+ -+} -+ -+std::vector DepthwiseFpropProblemSizes_filter5x5_stride2x2_dilation2x2() { -+ std::vector problems; -+ for (int channels = 16; channels < 256; channels += 16) { -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 112, 112, channels}, // input size (NHWC) -+ {channels, 5, 5, 1}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {2, 2}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, // Convolution mode -+ 16, // split_k_slices -+ channels // groups -+ )); -+ } -+ -+ return problems; -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST( -+ SM60_Device_Depthwise_conv2d_Fprop_Direct_Conv_FixedStrideDilation_f16nhwc_f16nhwc_f16nhwc_simt_f16, -+ 64x32_4_8x32_Filter3x3_Stride1x1_Dilation1x1) { -+ -+ using ElementInputA = cutlass::half_t; -+ using ElementInputB = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementComputeEpilogue = cutlass::half_t; -+ -+ using LayoutInputA = cutlass::layout::TensorNHWC; -+ using LayoutInputB = cutlass::layout::TensorNHWC; -+ using LayoutOutput = cutlass::layout::TensorNHWC; -+ -+ // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU -+ // SM -+ using MMAOp = cutlass::arch::OpClassSimt; -+ -+ // This code section describes CUDA SM architecture number -+ using SmArch = cutlass::arch::Sm60; -+ -+ // This code section describes the groups a thread block will compute -+ constexpr int groups_per_cta = 32; -+ -+ // This code section describes the output tile a thread block will compute -+ using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, groups_per_cta>; -+ -+ // This code section describes the filter shape -+ using FilterShape = cutlass::MatrixShape<3, 3>; -+ -+ // Threadblock tile shape -+ using ThreadblockShape = -+ cutlass::gemm::GemmShape; -+ -+ // This code section describes tile size a warp will computes -+ using WarpShape = cutlass::gemm::GemmShape<8, groups_per_cta, FilterShape::kCount>; -+ -+ // This code section describes the size of MMA op -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ -+ // This code section describes how threadblocks are scheduled on GPU -+ using SwizzleThreadBlock = -+ cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle< -+ 1, -+ ThreadBlockOutputShape::kN, -+ ThreadBlockOutputShape::kH, -+ ThreadBlockOutputShape::kW>; -+ -+ // Number of pipelines you want to use -+ constexpr int NumStages = 4; -+ -+ // This code section describe iterator algorithm selected is Analytic or Optimized -+ static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = -+ cutlass::conv::IteratorAlgorithm::kFixedStrideDilation; -+ using StrideShape = cutlass::MatrixShape<1, 1>; -+ using DilationShape = cutlass::MatrixShape<1, 1>; -+ -+ constexpr int kEpilogueElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ // This code section describes the epilogue part of the kernel, we use default value -+ using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // Data type of output matrix. -+ kEpilogueElementsPerAccess, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. -+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue, // Data type for alpha/beta in linear combination -+ cutlass::epilogue::thread::ScaleType::Default>; -+ -+ using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ ThreadBlockOutputShape, -+ FilterShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ IteratorAlgorithm, -+ cutlass::conv::StrideSupport::kFixed, -+ StrideShape, -+ DilationShape>::Kernel; -+ -+ using Direct2dConv = cutlass::conv::device::DirectConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestSpecificDepthwiseDirectConv2d( -+ DepthwiseFpropProblemSizes_filter3x3_stride1x1_dilation1x1())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST( -+ SM60_Device_Depthwise_conv2d_Fprop_Direct_Conv_FixedStrideDilation_f16nhwc_f16nhwc_f16nhwc_simt_f16, -+ 64x32_4_8x32_Filter3x3_Stride2x2_Dilation2x2) { -+ -+ using ElementInputA = cutlass::half_t; -+ using ElementInputB = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementComputeEpilogue = cutlass::half_t; -+ -+ using LayoutInputA = cutlass::layout::TensorNHWC; -+ using LayoutInputB = cutlass::layout::TensorNHWC; -+ using LayoutOutput = cutlass::layout::TensorNHWC; -+ -+ // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU -+ // SM -+ using MMAOp = cutlass::arch::OpClassSimt; -+ -+ // This code section describes CUDA SM architecture number -+ using SmArch = cutlass::arch::Sm60; -+ -+ // This code section describes the groups a thread block will compute -+ constexpr int groups_per_cta = 32; -+ -+ // This code section describes the output tile a thread block will compute -+ using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, groups_per_cta>; -+ -+ // This code section describes the filter shape -+ using FilterShape = cutlass::MatrixShape<3, 3>; -+ -+ // Threadblock tile shape -+ using ThreadblockShape = -+ cutlass::gemm::GemmShape; -+ -+ // This code section describes tile size a warp will computes -+ using WarpShape = cutlass::gemm::GemmShape<8, groups_per_cta, FilterShape::kCount>; -+ -+ // This code section describes the size of MMA op -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ -+ // This code section describes how threadblocks are scheduled on GPU -+ using SwizzleThreadBlock = -+ cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle< -+ 1, -+ ThreadBlockOutputShape::kN, -+ ThreadBlockOutputShape::kH, -+ ThreadBlockOutputShape::kW>; -+ -+ // Number of pipelines you want to use -+ constexpr int NumStages = 4; -+ -+ // This code section describe iterator algorithm selected is Analytic or Optimized -+ static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = -+ cutlass::conv::IteratorAlgorithm::kFixedStrideDilation; -+ using StrideShape = cutlass::MatrixShape<2, 2>; -+ using DilationShape = cutlass::MatrixShape<2, 2>; -+ -+ constexpr int kEpilogueElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ // This code section describes the epilogue part of the kernel, we use default value -+ using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // Data type of output matrix. -+ kEpilogueElementsPerAccess, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. -+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue, // Data type for alpha/beta in linear combination -+ cutlass::epilogue::thread::ScaleType::Default>; -+ -+ using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ ThreadBlockOutputShape, -+ FilterShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ IteratorAlgorithm, -+ cutlass::conv::StrideSupport::kFixed, -+ StrideShape, -+ DilationShape>::Kernel; -+ -+ using Direct2dConv = cutlass::conv::device::DirectConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestSpecificDepthwiseDirectConv2d( -+ DepthwiseFpropProblemSizes_filter3x3_stride2x2_dilation2x2())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST( -+ SM60_Device_Depthwise_conv2d_Fprop_Direct_Conv_FixedStrideDilation_f16nhwc_f16nhwc_f16nhwc_simt_f16, -+ 64x64_3_16x64_Filter5x5_Stride1x1_Dilation1x1) { -+ -+ using ElementInputA = cutlass::half_t; -+ using ElementInputB = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementComputeEpilogue = cutlass::half_t; -+ -+ using LayoutInputA = cutlass::layout::TensorNHWC; -+ using LayoutInputB = cutlass::layout::TensorNHWC; -+ using LayoutOutput = cutlass::layout::TensorNHWC; -+ -+ // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU -+ // SM -+ using MMAOp = cutlass::arch::OpClassSimt; -+ -+ // This code section describes CUDA SM architecture number -+ using SmArch = cutlass::arch::Sm60; -+ -+ // This code section describes the groups a thread block will compute -+ constexpr int groups_per_cta = 64; -+ -+ // This code section describes the output tile a thread block will compute -+ using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, groups_per_cta>; -+ -+ // This code section describes the filter shape -+ using FilterShape = cutlass::MatrixShape<5, 5>; -+ -+ // Threadblock tile shape -+ using ThreadblockShape = -+ cutlass::gemm::GemmShape; -+ -+ // This code section describes tile size a warp will computes -+ using WarpShape = cutlass::gemm::GemmShape<16, groups_per_cta, FilterShape::kCount>; -+ -+ // This code section describes the size of MMA op -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ -+ // This code section describes how threadblocks are scheduled on GPU -+ using SwizzleThreadBlock = -+ cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle< -+ 1, -+ ThreadBlockOutputShape::kN, -+ ThreadBlockOutputShape::kH, -+ ThreadBlockOutputShape::kW>; -+ -+ // Number of pipelines you want to use -+ constexpr int NumStages = 3; -+ -+ // This code section describe iterator algorithm selected is Analytic or Optimized -+ static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = -+ cutlass::conv::IteratorAlgorithm::kFixedStrideDilation; -+ using StrideShape = cutlass::MatrixShape<1, 1>; -+ using DilationShape = cutlass::MatrixShape<1, 1>; -+ -+ constexpr int kEpilogueElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ // This code section describes the epilogue part of the kernel, we use default value -+ using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // Data type of output matrix. -+ kEpilogueElementsPerAccess, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. -+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue, // Data type for alpha/beta in linear combination -+ cutlass::epilogue::thread::ScaleType::Default>; -+ -+ using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ ThreadBlockOutputShape, -+ FilterShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ IteratorAlgorithm, -+ cutlass::conv::StrideSupport::kFixed, -+ StrideShape, -+ DilationShape>::Kernel; -+ -+ using Direct2dConv = cutlass::conv::device::DirectConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestSpecificDepthwiseDirectConv2d( -+ DepthwiseFpropProblemSizes_filter5x5_stride1x1_dilation1x1())); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST( -+ SM60_Device_Depthwise_conv2d_Fprop_Direct_Conv_FixedStrideDilation_f16nhwc_f16nhwc_f16nhwc_simt_f16, -+ 64x64_3_16x64_Filter5x5_Stride2x2_Dilation2x2) { -+ -+ using ElementInputA = cutlass::half_t; -+ using ElementInputB = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementComputeEpilogue = cutlass::half_t; -+ -+ using LayoutInputA = cutlass::layout::TensorNHWC; -+ using LayoutInputB = cutlass::layout::TensorNHWC; -+ using LayoutOutput = cutlass::layout::TensorNHWC; -+ -+ // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU -+ // SM -+ using MMAOp = cutlass::arch::OpClassSimt; -+ -+ // This code section describes CUDA SM architecture number -+ using SmArch = cutlass::arch::Sm60; -+ -+ // This code section describes the groups a thread block will compute -+ constexpr int groups_per_cta = 32; -+ -+ // This code section describes the output tile a thread block will compute -+ using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, groups_per_cta>; -+ -+ // This code section describes the filter shape -+ using FilterShape = cutlass::MatrixShape<5, 5>; -+ -+ // Threadblock tile shape -+ using ThreadblockShape = -+ cutlass::gemm::GemmShape; -+ -+ // This code section describes tile size a warp will computes -+ using WarpShape = cutlass::gemm::GemmShape<16, groups_per_cta, FilterShape::kCount>; -+ -+ // This code section describes the size of MMA op -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ -+ // This code section describes how threadblocks are scheduled on GPU -+ using SwizzleThreadBlock = -+ cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle< -+ 1, -+ ThreadBlockOutputShape::kN, -+ ThreadBlockOutputShape::kH, -+ ThreadBlockOutputShape::kW>; -+ -+ // Number of pipelines you want to use -+ constexpr int NumStages = 3; -+ -+ // This code section describe iterator algorithm selected is Analytic or Optimized -+ static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = -+ cutlass::conv::IteratorAlgorithm::kFixedStrideDilation; -+ using StrideShape = cutlass::MatrixShape<2, 2>; -+ using DilationShape = cutlass::MatrixShape<2, 2>; -+ -+ constexpr int kEpilogueElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ // This code section describes the epilogue part of the kernel, we use default value -+ using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, // Data type of output matrix. -+ kEpilogueElementsPerAccess, // The number of elements per vectorized. -+ // memory access. This becomes the vector width of -+ // math instructions in the epilogue too. -+ ElementAccumulator, // Data type of accumulator -+ ElementComputeEpilogue, // Data type for alpha/beta in linear combination -+ cutlass::epilogue::thread::ScaleType::Default>; -+ -+ using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop< -+ ElementInputA, -+ LayoutInputA, -+ ElementInputB, -+ LayoutInputB, -+ ElementOutput, -+ LayoutOutput, -+ ElementAccumulator, -+ MMAOp, -+ SmArch, -+ ThreadblockShape, -+ ThreadBlockOutputShape, -+ FilterShape, -+ WarpShape, -+ InstructionShape, -+ EpilogueOp, -+ SwizzleThreadBlock, -+ NumStages, -+ cutlass::arch::OpMultiplyAdd, -+ IteratorAlgorithm, -+ cutlass::conv::StrideSupport::kFixed, -+ StrideShape, -+ DilationShape>::Kernel; -+ -+ using Direct2dConv = cutlass::conv::device::DirectConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestSpecificDepthwiseDirectConv2d( -+ DepthwiseFpropProblemSizes_filter5x5_stride2x2_dilation2x2())); -+} -diff --git a/3rdparty/cutlass/test/unit/conv/device/depthwise_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu b/3rdparty/cutlass/test/unit/conv/device/depthwise_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu -new file mode 100644 -index 0000000..3c9cf10 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/depthwise_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu -@@ -0,0 +1,221 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for Depthwise Direct Conv interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_depthwise_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+ -+std::vector DepthwiseFpropProblemSizes() { -+ -+std::vector problems; -+ -+for ( int channels = 16; channels < 256 ; channels+=16){ -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 8, 8, channels}, // input size (NHWC) -+ {channels, 3, 3, 1}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, // Convolution mode -+ 1, // split_k_slices -+ channels // groups -+ )); -+ -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 16, 16, channels}, // input size (NHWC) -+ {channels, 3, 3, 1}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {2, 2}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, // Convolution mode -+ 1, // split_k_slices -+ channels // groups -+ )); -+ -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 16, 16, channels}, // input size (NHWC) -+ {channels, 7, 7, 1}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, // Convolution mode -+ 1, // split_k_slices -+ channels // groups -+ )); -+ -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 112, 112, channels}, // input size (NHWC) -+ {channels, 7, 7, 1}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, // Convolution mode -+ 1, // split_k_slices -+ channels // groups -+ )); -+ -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 112, 112, channels}, // input size (NHWC) -+ {channels, 7, 7, 1}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {2, 2} , // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, // Convolution mode -+ 1, // split_k_slices -+ channels // groups -+ )); -+ -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 112, 112, channels}, // input size (NHWC) -+ {channels, 5, 5, 1}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {1, 1}, // stride (stride_h, stride_w) -+ {1, 1}, // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, // Convolution mode -+ 1, // split_k_slices -+ channels // groups -+ )); -+ -+ problems.push_back(cutlass::conv::Conv2dProblemSize( -+ {1, 112, 112, channels}, // input size (NHWC) -+ {channels, 5, 5, 1}, // filter size (KRSC) -+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) -+ {2, 2}, // stride (stride_h, stride_w) -+ {2, 2} , // dilation (dilation_h, dilation_w) -+ cutlass::conv::Mode::kCrossCorrelation, // Convolution mode -+ 1, // split_k_slices -+ channels // groups -+ )); -+} -+ -+return problems; -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM60_Device_Depthwise_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_simt_f16, -+ 128x128_8x2_64x64x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ -+ /// Device-level depthwiseFpropKernel instance -+ using depthwiseFpropKernel = typename cutlass::conv::kernel::DefaultDepthwiseFprop< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm60, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using DepthwiseFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestSpecificConv2d( -+ DepthwiseFpropProblemSizes())); -+ -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM60_Device_Depthwise_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_simt_f16, -+ 64x64_8x2_32x32x8) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ -+ /// Device-level depthwiseFpropKernel instance -+ using depthwiseFpropKernel = typename cutlass::conv::kernel::DefaultDepthwiseFprop< -+ ElementA, -+ cutlass::layout::TensorNHWC, -+ ElementB, -+ cutlass::layout::TensorNHWC, -+ ElementC, -+ cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm60, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using DepthwiseFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run all unit test sizes with device-level Conv2d instance -+ EXPECT_TRUE(test::conv::device::TestSpecificConv2d( -+ DepthwiseFpropProblemSizes())); -+ -+} -diff --git a/3rdparty/cutlass/test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..acf073f ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu -@@ -0,0 +1,395 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Implicit GEMM interface -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+ -+#include "cutlass/conv/kernel/default_conv2d_group_fprop.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "conv2d_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Group_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32, -+ SingleGroupPerCTA_128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ /// Device-level Conv2d instance -+ using Conv2dGroupFpropKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::GroupMode::kSingleGroup, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dGroupFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run group conv unit test sizes with device-level Conv2d instance -+ test::conv::device::TestbedGroupConv2dProblemSizes problem_sizes( -+ ThreadblockShape::kN, ThreadblockShape::kK, -+ 128/cutlass::sizeof_bits::value -+ ); -+ EXPECT_TRUE(test::conv::device::TestSpecificConv2d(problem_sizes.default_single_group_sizes)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Group_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32, -+ SingleGroupPerCTA_64x64_64x3_32x32x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ /// Device-level Conv2d instance -+ using Conv2dGroupFpropKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::GroupMode::kSingleGroup, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dGroupFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run group conv unit test sizes with device-level Conv2d instance -+ test::conv::device::TestbedGroupConv2dProblemSizes problem_sizes( -+ ThreadblockShape::kN, ThreadblockShape::kK, -+ 128/cutlass::sizeof_bits::value -+ ); -+ EXPECT_TRUE(test::conv::device::TestSpecificConv2d(problem_sizes.default_single_group_sizes)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Group_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32, -+ MultipleGroupPerCTA_128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ /// Device-level Conv2d instance -+ using Conv2dGroupFpropKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::GroupMode::kMultipleGroup, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dGroupFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run group conv unit test sizes with device-level Conv2d instance -+ test::conv::device::TestbedGroupConv2dProblemSizes problem_sizes( -+ ThreadblockShape::kN, ThreadblockShape::kK, -+ 128/cutlass::sizeof_bits::value -+ ); -+ EXPECT_TRUE(test::conv::device::TestSpecificConv2d(problem_sizes.default_multiple_group_sizes)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Group_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32, -+ MutipleGroupPerCTA_64x64_64x3_32x32x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ /// Device-level Conv2d instance -+ using Conv2dGroupFpropKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::GroupMode::kMultipleGroup, -+ cutlass::conv::IteratorAlgorithm::kAnalytic -+ >::Kernel; -+ -+ using Conv2dGroupFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run group conv unit test sizes with device-level Conv2d instance -+ test::conv::device::TestbedGroupConv2dProblemSizes problem_sizes( -+ ThreadblockShape::kN, ThreadblockShape::kK, -+ 128/cutlass::sizeof_bits::value -+ ); -+ EXPECT_TRUE(test::conv::device::TestSpecificConv2d(problem_sizes.default_multiple_group_sizes)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Conv2d_Group_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32, -+ SingleGroupPerCTA_128x128_64x3_64x64x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ /// Device-level Conv2d instance -+ using Conv2dGroupFpropKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::GroupMode::kSingleGroup, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dGroupFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run group conv unit test sizes with device-level Conv2d instance -+ test::conv::device::TestbedGroupConv2dProblemSizes problem_sizes( -+ ThreadblockShape::kN, ThreadblockShape::kK, -+ 128/cutlass::sizeof_bits::value -+ ); -+ EXPECT_TRUE(test::conv::device::TestSpecificConv2d(problem_sizes.default_single_group_sizes)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+// Optimized multistage singleGroup kernel -+TEST(SM80_Device_Conv2d_Group_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32, -+ SingleGroupPerCTA_64x64_64x3_32x32x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ /// Device-level Conv2d instance -+ using Conv2dGroupFpropKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::GroupMode::kSingleGroup, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dGroupFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run group conv unit test sizes with device-level Conv2d instance -+ test::conv::device::TestbedGroupConv2dProblemSizes problem_sizes( -+ ThreadblockShape::kN, ThreadblockShape::kK, -+ 128/cutlass::sizeof_bits::value -+ ); -+ EXPECT_TRUE(test::conv::device::TestSpecificConv2d(problem_sizes.default_single_group_sizes)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+// Optimized 2 stage singleGroup kernel -+TEST(SM80_Device_Conv2d_Group_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32, -+ SingleGroupPerCTA_64x64_64x2_32x32x64) { -+ -+ /// Conv operation element types for the Gemm equivalent (ImplicitGemm) -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ /// Device-level Conv2d instance -+ using Conv2dGroupFpropKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop< -+ ElementA, cutlass::layout::TensorNHWC, -+ ElementB, cutlass::layout::TensorNHWC, -+ ElementC, cutlass::layout::TensorNHWC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape, -+ WarpShape, -+ InstructionShape, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::conv::GroupMode::kSingleGroup, -+ cutlass::conv::IteratorAlgorithm::kOptimized -+ >::Kernel; -+ -+ using Conv2dGroupFprop = cutlass::conv::device::ImplicitGemmConvolution; -+ -+ /// Run group conv unit test sizes with device-level Conv2d instance -+ test::conv::device::TestbedGroupConv2dProblemSizes problem_sizes( -+ ThreadblockShape::kN, ThreadblockShape::kK, -+ 128/cutlass::sizeof_bits::value -+ ); -+ EXPECT_TRUE(test::conv::device::TestSpecificConv2d(problem_sizes.default_single_group_sizes)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/core/array.cu b/3rdparty/cutlass/test/unit/core/array.cu -new file mode 100644 -index 0000000..910d1af ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/core/array.cu -@@ -0,0 +1,261 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Statically sized array of elements that accommodates all CUTLASS-supported numeric types -+ and is safe to use in a union. -+*/ -+ -+#include "../common/cutlass_unit_test.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/util/device_memory.h" -+#pragma warning( disable : 4800) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace core { -+ -+/// Each thread clears its array and writes to global memory. No PRMT instructions should -+/// be generated if Array is a multiple of 32 bits. -+template -+__global__ void test_array_clear(cutlass::Array *ptr) { -+ -+ cutlass::Array storage; -+ -+ storage.clear(); -+ -+ ptr[threadIdx.x] = storage; -+} -+ -+/// Each thread writes its thread index into the elements of its array and then writes the result -+/// to global memory. -+template -+__global__ void test_array_threadid(cutlass::Array *ptr) { -+ -+ cutlass::Array storage; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ storage.at(i) = T(int(threadIdx.x)); -+ } -+ -+ ptr[threadIdx.x] = storage; -+} -+ -+/// Each thread writes its thread index into the elements of its array and then writes the result -+/// to global memory. -+template -+__global__ void test_array_sequence(cutlass::Array *ptr) { -+ -+ cutlass::Array storage; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < N; ++i) { -+ storage.at(i) = T(i); -+ } -+ -+ ptr[threadIdx.x] = storage; -+} -+ -+} // namespace core -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class TestArray { -+public: -+ -+ // -+ // Data members -+ // -+ -+ /// Number of threads -+ int const kThreads = 32; -+ -+ typedef cutlass::Array ArrayTy; -+ -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ TestArray() { -+ -+ } -+ -+ /// Runs the test -+ void run() { -+ -+ /// Device memory containing output -+ cutlass::device_memory::allocation< ArrayTy > output(kThreads); -+ std::vector< ArrayTy > output_host(kThreads); -+ -+ dim3 grid(1,1); -+ dim3 block(kThreads, 1, 1); -+ -+ test::core::test_array_clear<<< grid, block >>>(output.get()); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ ASSERT_EQ(result, cudaSuccess) << "CUDA error: " << cudaGetErrorString(result); -+ -+ // -+ // Verify contains all zeros -+ // -+ -+ cutlass::device_memory::copy_to_host(output_host.data(), output.get(), kThreads); -+ -+ result = cudaGetLastError(); -+ ASSERT_EQ(result, cudaSuccess) << "CUDA error: " << cudaGetErrorString(result); -+ -+ char const *ptr_host = reinterpret_cast(output_host.data()); -+ for (int i = 0; i < sizeof(ArrayTy) * kThreads; ++i) { -+ EXPECT_FALSE(ptr_host[i]); -+ } -+ -+ // -+ // Verify each element contains the low bits of the thread Id -+ // -+ -+ test::core::test_array_threadid<<< grid, block >>>(output.get()); -+ -+ result = cudaDeviceSynchronize(); -+ ASSERT_EQ(result, cudaSuccess) << "CUDA error: " << cudaGetErrorString(result); -+ -+ cutlass::device_memory::copy_to_host(output_host.data(), output.get(), kThreads); -+ -+ result = cudaGetLastError(); -+ ASSERT_EQ(result, cudaSuccess) << "CUDA error: " << cudaGetErrorString(result); -+ -+ for (int i = 0; i < kThreads; ++i) { -+ T tid = T(i); -+ -+ ArrayTy thread = output_host.at(i); -+ -+ // Element-wise access -+ for (int j = 0; j < N; ++j) { -+ EXPECT_TRUE(tid == thread[j]); -+ } -+ -+ // Iterator access -+ for (auto it = thread.begin(); it != thread.end(); ++it) { -+ EXPECT_TRUE(tid == *it); -+ } -+ -+ // Range-based for -+ for (auto const & x : thread) { -+ EXPECT_TRUE(tid == x); -+ } -+ } -+ -+ // -+ // Verify each element -+ // -+ -+ test::core::test_array_sequence<<< grid, block >>>(output.get()); -+ -+ result = cudaDeviceSynchronize(); -+ ASSERT_EQ(result, cudaSuccess) << "CUDA error: " << cudaGetErrorString(result); -+ -+ cutlass::device_memory::copy_to_host(output_host.data(), output.get(), kThreads); -+ -+ result = cudaGetLastError(); -+ ASSERT_EQ(result, cudaSuccess) << "CUDA error: " << cudaGetErrorString(result); -+ -+ for (int i = 0; i < kThreads; ++i) { -+ -+ ArrayTy thread = output_host.at(i); -+ -+ // Element-wise access -+ for (int j = 0; j < N; ++j) { -+ T got = T(j); -+ EXPECT_TRUE(got == thread[j]); -+ } -+ -+ // Iterator access -+ int j = 0; -+ for (auto it = thread.begin(); it != thread.end(); ++it, ++j) { -+ T got = T(j); -+ EXPECT_TRUE(got == *it); -+ } -+ -+ // Range-based for -+ j = 0; -+ for (auto const & x : thread) { -+ T got = T(j); -+ EXPECT_TRUE(got == x); -+ ++j; -+ } -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(Array, Int8x16) { -+ TestArray().run(); -+} -+ -+TEST(Array, Int32x4) { -+ TestArray().run(); -+} -+ -+#if __CUDA_ARCH__ >= 520 -+TEST(Array, Float16x8) { -+ TestArray().run(); -+} -+#endif -+ -+TEST(Array, FloatBF16x8) { -+ TestArray().run(); -+} -+ -+TEST(Array, FloatTF32x4) { -+ TestArray().run(); -+} -+ -+TEST(Array, Float32x4) { -+ TestArray().run(); -+} -+ -+TEST(Array, Int4x32) { -+ TestArray().run(); -+} -+ -+TEST(Array, Uint4x32) { -+ TestArray().run(); -+} -+ -+TEST(Array, Bin1x128) { -+ TestArray().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/core/bfloat16.cu b/3rdparty/cutlass/test/unit/core/bfloat16.cu -new file mode 100644 -index 0000000..6227250 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/core/bfloat16.cu -@@ -0,0 +1,218 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Statically sized array of elements that accommodates all CUTLASS-supported numeric types -+ and is safe to use in a union. -+*/ -+ -+#include "../common/cutlass_unit_test.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/core_io.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/util/device_memory.h" -+#include "cutlass/util/host_tensor.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+__global__ void convert_bf16_f32(cutlass::bfloat16_t *output, float const *input, int N) { -+ int tid = threadIdx.x + blockIdx.x * blockDim.x; -+ if (tid < N) { -+ output[tid] = static_cast(input[tid]); -+ } -+} -+ -+__global__ void convert_and_pack_bf16(cutlass::bfloat16_t *output, float const *input, int N) { -+ int tid = threadIdx.x + blockIdx.x * blockDim.x; -+ if (tid * 2 < N) { -+ -+ cutlass::NumericArrayConverter convert; -+ -+ cutlass::Array *dst_ptr = -+ reinterpret_cast *>(output + tid * 2); -+ -+ cutlass::Array const *src_ptr = -+ reinterpret_cast const *>(input + tid * 2); -+ -+ *dst_ptr = convert(*src_ptr); -+ } -+} -+ -+TEST(bfloat16_t, device_conversion) { -+ using T = cutlass::bfloat16_t; -+ using S = float; -+ -+ int const N = 256; -+ -+ cutlass::HostTensor destination({N, 1}); -+ cutlass::HostTensor source({N, 1}); -+ -+ for (int i = 0; i < N; ++i) { -+ source.at({i, 0}) = float(i - 128); -+ destination.at({i, 0}) = T(0); -+ } -+ -+ source.sync_device(); -+ destination.sync_device(); -+ -+ convert_bf16_f32<<< dim3(1,1), dim3(N, 1) >>>(destination.device_data(), source.device_data(), N); -+ -+ ASSERT_EQ(cudaGetLastError(), cudaSuccess) << "Kernel launch error."; -+ -+ destination.sync_host(); -+ -+ int errors = 0; -+ for (int i = 0; i < N; ++i) { -+ T got = destination.at({i, 0}); -+ S expected = source.at({i, 0}); -+ -+ if (S(got) != expected) { -+ ++errors; -+ if (errors < 10) { -+ std::cerr << "Basic conversion error - [" << i << "] - got " << got << ", expected " << expected << "\n"; -+ } -+ } -+ -+ destination.at({i, 0}) = T(0); -+ } -+ -+ destination.sync_device(); -+ -+ convert_and_pack_bf16<<< dim3(1,1), dim3(N, 1) >>>(destination.device_data(), source.device_data(), N); -+ -+ ASSERT_EQ(cudaGetLastError(), cudaSuccess) << "Kernel launch error."; -+ -+ destination.sync_host(); -+ -+ for (int i = 0; i < N; ++i) { -+ T got = destination.at({i, 0}); -+ S expected = source.at({i, 0}); -+ -+ if (S(got) != expected) { -+ ++errors; -+ if (errors < 10) { -+ std::cerr << "Convert and pack error - [" << i << "] - got " << got << ", expected " << expected << "\n"; -+ } -+ } -+ } -+ -+ EXPECT_EQ(errors, 0); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Host -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(bfloat16_t, host_conversion) { -+ for (int i = -128; i < 128; ++i) { -+ float f = static_cast(i); -+ -+ cutlass::bfloat16_t x = static_cast(i); -+ cutlass::bfloat16_t y = static_cast(f); -+ -+ EXPECT_TRUE(static_cast(x) == i); -+ EXPECT_TRUE(static_cast(y) == f); -+ } -+ -+ // Try out default-ctor (zero initialization of primitive proxy type) -+ EXPECT_TRUE(cutlass::bfloat16_t() == 0.0_bf16); -+ -+ // Try out user-defined literals -+ EXPECT_TRUE(cutlass::bfloat16_t(7) == 7_bf16); -+ EXPECT_TRUE(7 == static_cast(7_bf16)); -+} -+ -+TEST(bfloat16_t, host_arithmetic) { -+ -+ for (int i = -100; i < 100; ++i) { -+ for (int j = -100; j < 100; ++j) { -+ -+ cutlass::bfloat16_t x = static_cast(i); -+ cutlass::bfloat16_t y = static_cast(j); -+ -+ EXPECT_TRUE(static_cast(x + y) == (i + j)); -+ } -+ } -+} -+ -+TEST(bfloat16_t, host_round) { -+ -+ struct { -+ uint32_t f32_bits; -+ uint16_t expected; -+ } tests[] = { -+ {0x40040000, 0x4004}, // M=0, R=0, S=0 => rtz -+ {0x40048000, 0x4004}, // M=0, R=1, S=0 => rtz -+ {0x40040001, 0x4004}, // M=0, R=1, S=1 => +inf -+ {0x4004c000, 0x4005}, // M=0, R=1, S=1 => +inf -+ {0x4004a000, 0x4005}, // M=0, R=1, S=1 => +inf -+ {0x40050000, 0x4005}, // M=1, R=0, S=0 => rtz -+ {0x40054000, 0x4005}, // M=1, R=0, S=1 => rtz -+ {0x40058000, 0x4006}, // M=1, R=1, S=0 => +inf -+ {0x40058001, 0x4006}, // M=1, R=1, S=1 => +inf -+ {0x7f800000, 0x7f80}, // +inf -+ {0xff800000, 0xff80}, // -inf -+ {0x7fffffff, 0x7fff}, // canonical NaN -+ {0x7ff00001, 0x7fff}, // NaN -> canonical NaN -+ {0xfff00010, 0x7fff}, // Nan -> canonical NaN -+ {0, 0} -+ }; -+ -+ bool running = true; -+ for (int i = 0; running; ++i) { -+ -+ float f32 = reinterpret_cast(tests[i].f32_bits); -+ -+ cutlass::bfloat16_t bf16 = cutlass::bfloat16_t(f32); -+ -+ bool passed = (tests[i].expected == bf16.raw()); -+ -+ EXPECT_TRUE(passed) -+ << "Error - convert(f32: 0x" << std::hex << tests[i].f32_bits -+ << ") -> 0x" << std::hex << tests[i].expected << "\ngot: 0x" << std::hex << bf16.raw(); -+ -+ if (!tests[i].f32_bits) { -+ running = false; -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Device -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/core/complex.cu b/3rdparty/cutlass/test/unit/core/complex.cu -new file mode 100644 -index 0000000..2962f5a ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/core/complex.cu -@@ -0,0 +1,195 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief CUTLASS host-device template for complex numbers supporting all CUTLASS numeric types. -+*/ -+ -+// Standard Library's std::complex used for reference checking -+#include -+ -+#include "../common/cutlass_unit_test.h" -+ -+#include "cutlass/complex.h" -+#include "cutlass/constants.h" -+#include "cutlass/numeric_conversion.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(complex, f64_to_f32_conversion) { -+ -+ cutlass::complex source = {1.5, -1.25}; -+ -+ cutlass::complex dest = cutlass::complex(source); // explicit conversion -+ -+ EXPECT_TRUE(source.real() == 1.5 && source.imag() == -1.25 && -+ dest.real() == 1.5f && dest.imag() == -1.25f); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(complex, f32_to_f64_conversion) { -+ -+ cutlass::complex source = {-1.5f, 1.25f}; -+ -+ cutlass::complex dest = source; // implicit conversion -+ -+ EXPECT_TRUE(source.real() == -1.5f && source.imag() == 1.25f && -+ dest.real() == -1.5 && dest.imag() == 1.25); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(complex, s32_to_f64_conversion) { -+ -+ cutlass::complex source = {-2, 1}; -+ -+ cutlass::complex dest = source; // implicit conversion -+ -+ EXPECT_TRUE(source.real() == -2 && source.imag() == 1 && -+ dest.real() == -2 && dest.imag() == 1); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(complex, f16_to_f32_conversion) { -+ -+ cutlass::complex source = {1.5_hf, -1.25_hf}; -+ -+ cutlass::complex dest = cutlass::complex(source); // explicit conversion -+ -+ EXPECT_TRUE(source.real() == 1.5_hf && source.imag() == -1.25_hf && -+ dest.real() == 1.5f && dest.imag() == -1.25f); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(complex, exp_f32) { -+ -+ cutlass::complex Z[] = { -+ {1, 1}, -+ {2 , cutlass::constants::pi()/2.0f }, -+ {0.5f, cutlass::constants::pi() }, -+ {0.25f, cutlass::constants::pi()*3/4.0f }, -+ {0, 0}, -+ }; -+ -+ cutlass::complex Expected[] = { -+ {1.4686939399158851, 2.2873552871788423}, -+ {4.524491950137825e-16, 7.38905609893065}, -+ {-1.6487212707001282, 2.019101226849069e-16}, -+ {-0.9079430793557842, 0.9079430793557843}, -+ {1, 0} -+ }; -+ -+ double tolerance = 0.00001; -+ -+ for (int i = 0; cutlass::real(Z[i]); ++i) { -+ double e_r = cutlass::real(Expected[i]); -+ double e_i = cutlass::real(Expected[i]); -+ -+ cutlass::complex got = cutlass::exp(Z[i]); -+ float g_r = cutlass::real(got); -+ float g_i = cutlass::real(got); -+ -+ EXPECT_TRUE( -+ std::abs(g_r - e_r) < tolerance && std::abs(g_i - e_i) < tolerance -+ ) << "Expected(" << Expected[i] << "), Got(" << got << ")"; -+ } -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+ -+ /// Thorough testing for basic complex math operators. Uses std::complex as a reference. -+ template -+ struct ComplexOperators { -+ ComplexOperators() { -+ for (int ar = -N; ar <= N; ++ar) { -+ for (int ai = -N; ai <= N; ++ai) { -+ for (int br = -N; br <= N; ++br) { -+ for (int bi = -N; bi <= N; ++bi) { -+ -+ cutlass::complex Ae(T(ar) / T(M), T(ai) / T(M)); -+ cutlass::complex Be(T(br) / T(M), T(bi) / T(M)); -+ -+ std::complex Ar(T(ar) / T(M), T(ai) / T(M)); -+ std::complex Br(T(br) / T(M), T(bi) / T(M)); -+ -+ cutlass::complex add_e = Ae + Be; -+ cutlass::complex sub_e = Ae - Be; -+ cutlass::complex mul_e = Ae * Be; -+ -+ std::complex add_r = (Ar + Br); -+ std::complex sub_r = (Ar - Br); -+ std::complex mul_r = (Ar * Br); -+ -+ EXPECT_EQ(real(add_e), real(add_r)); -+ EXPECT_EQ(imag(add_e), imag(add_r)); -+ -+ EXPECT_EQ(real(sub_e), real(sub_r)); -+ EXPECT_EQ(imag(sub_e), imag(sub_r)); -+ -+ EXPECT_EQ(real(mul_e), real(mul_r)); -+ EXPECT_EQ(imag(mul_e), imag(mul_r)); -+ -+ if (!(br == 0 && bi == 0)) { -+ -+ cutlass::complex div_e = Ae / Be; -+ std::complex div_r = Ar / Br; -+ -+ T const kRange = T(0.001); -+ -+ EXPECT_NEAR(real(div_e), real(div_r), kRange); -+ EXPECT_NEAR(imag(div_e), imag(div_r), kRange); -+ } -+ } -+ } -+ } -+ } -+ } -+ }; -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(complex, host_float) { -+ test::ComplexOperators test; -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(complex, host_double) { -+ test::ComplexOperators test; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/core/float8.cu b/3rdparty/cutlass/test/unit/core/float8.cu -new file mode 100644 -index 0000000..b685838 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/core/float8.cu -@@ -0,0 +1,103 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for basic float8 functionality -+*/ -+ -+#include "../common/cutlass_unit_test.h" -+ -+#include "cutlass/numeric_types.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(float_e4m3_t, host_conversion) { -+ for (int i = -8; i < 8; ++i) { -+ float f = static_cast(i); -+ -+ cutlass::float_e4m3_t x = static_cast(i); -+ cutlass::float_e4m3_t y = static_cast(f); -+ -+ EXPECT_TRUE(static_cast(x) == i); -+ EXPECT_TRUE(static_cast(y) == f); -+ } -+ -+ // Try out default-ctor (zero initialization of primitive proxy type) -+ EXPECT_TRUE(cutlass::float_e4m3_t() == 0.0_fe4m3); -+ -+ // Try out user-defined literals -+ EXPECT_TRUE(cutlass::float_e4m3_t(7) == 7_fe4m3); -+ EXPECT_TRUE(7 == static_cast(7_fe4m3)); -+} -+ -+TEST(float_e5m2_t, host_conversion) { -+ for (int i = -8; i < 8; ++i) { -+ float f = static_cast(i); -+ -+ cutlass::float_e5m2_t x = static_cast(i); -+ cutlass::float_e5m2_t y = static_cast(f); -+ -+ EXPECT_TRUE(static_cast(x) == i); -+ EXPECT_TRUE(static_cast(y) == f); -+ } -+ -+ // Try out default-ctor (zero initialization of primitive proxy type) -+ EXPECT_TRUE(cutlass::float_e5m2_t() == 0.0_fe5m2); -+ -+ // Try out user-defined literals -+ EXPECT_TRUE(cutlass::float_e5m2_t(7) == 7_fe5m2); -+ EXPECT_TRUE(7 == static_cast(7_fe5m2)); -+} -+ -+TEST(float_e4m3_t, host_arithmetic) { -+ for (int i = -4; i < 4; ++i) { -+ for (int j = -4; j < 4; ++j) { -+ -+ cutlass::float_e4m3_t x = static_cast(i); -+ cutlass::float_e4m3_t y = static_cast(j); -+ -+ EXPECT_TRUE(static_cast(x + y) == (i + j)); -+ } -+ } -+} -+ -+TEST(float_e5m2_t, host_arithmetic) { -+ for (int i = -4; i < 4; ++i) { -+ for (int j = -4; j < 4; ++j) { -+ -+ cutlass::float_e5m2_t x = static_cast(i); -+ cutlass::float_e5m2_t y = static_cast(j); -+ -+ EXPECT_TRUE(static_cast(x + y) == (i + j)); -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/core/functional.cu b/3rdparty/cutlass/test/unit/core/functional.cu -new file mode 100644 -index 0000000..bd76bc0 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/core/functional.cu -@@ -0,0 +1,494 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for functional operators. -+*/ -+ -+#include "../common/cutlass_unit_test.h" -+ -+#include "cutlass/functional.h" -+#include "cutlass/core_io.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/util/host_tensor.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace core { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Conversion template -+template -+__global__ void unary_operator(Element *d, Element const *a) { -+ -+ Operator op; -+ -+ *d = op(*a); -+} -+ -+/// Conversion template -+template -+__global__ void binary_operator(Element *d, Element const *a, Element const *b, int Iterations = 1) { -+ -+ Operator op; -+ -+ Element a_x = *a; -+ Element b_x = *b; -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int i = 0; i < Iterations; ++i) { -+ b_x = op(a_x, b_x); -+ } -+ -+ *d = b_x; -+} -+ -+/// Conversion template -+template -+__global__ void trinary_operator( -+ Element *d, -+ Element const *a, -+ Element const *b, -+ Element const *c, -+ int Iterations = 1) { -+ -+ Operator op; -+ -+ Element a_x = a[blockIdx.x]; -+ Element b_x = b[blockIdx.x]; -+ Element c_x = c[blockIdx.x]; -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int i = 0; i < Iterations; ++i) { -+ c_x = op(a_x, b_x, c_x); -+ } -+ -+ d[blockIdx.x] = c_x; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace core -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+void Functional_plus_f16xN() { -+ -+ using Element = cutlass::Array; -+ using Operator = cutlass::plus; -+ -+ using Tensor = cutlass::HostTensor; -+ -+ Tensor D({1, kN}); -+ Tensor A({1, kN}); -+ Tensor B({1, kN}); -+ Tensor C({1, kN}); -+ -+ for (int i = 0; i < kN; ++i) { -+ A.host_data()[i] = cutlass::half_t((i * 2 + 1) % 5); -+ B.host_data()[i] = cutlass::half_t((i * 4 + 8) % 7); -+ D.host_data()[i] = cutlass::half_t(0); -+ } -+ -+ D.sync_device(); -+ A.sync_device(); -+ B.sync_device(); -+ -+ test::core::kernel::binary_operator<<< dim3(1,1), dim3(1,1) >>>( -+ reinterpret_cast(D.device_data()), -+ reinterpret_cast(A.device_data()), -+ reinterpret_cast(B.device_data()) -+ ); -+ -+ D.sync_host(); -+ -+ bool some_d_nonzero = false; -+ -+ for (int i = 0; i < kN; ++i) { -+ float a = float(A.host_data()[i]); -+ float b = float(B.host_data()[i]); -+ float d = float(D.host_data()[i]); -+ -+ EXPECT_TRUE(d == (a + b)); -+ -+ if (d != 0) { -+ some_d_nonzero = true; -+ } -+ } -+ -+ EXPECT_TRUE(some_d_nonzero); -+} -+ -+TEST(Functional, plus_f16x16) { -+ Functional_plus_f16xN<16>(); -+} -+ -+TEST(Functional, plus_f16x17) { -+ Functional_plus_f16xN<17>(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+void Functional_minus_f16xN() { -+ -+ using Element = cutlass::Array; -+ using Operator = cutlass::minus; -+ -+ using Tensor = cutlass::HostTensor; -+ -+ Tensor D({1, kN}); -+ Tensor A({1, kN}); -+ Tensor B({1, kN}); -+ Tensor C({1, kN}); -+ -+ for (int i = 0; i < kN; ++i) { -+ A.host_data()[i] = cutlass::half_t((i * 2 + 1) % 5); -+ B.host_data()[i] = cutlass::half_t((i * 4 + 8) % 7); -+ D.host_data()[i] = cutlass::half_t(0); -+ } -+ -+ D.sync_device(); -+ A.sync_device(); -+ B.sync_device(); -+ -+ test::core::kernel::binary_operator<<< dim3(1,1), dim3(1,1) >>>( -+ reinterpret_cast(D.device_data()), -+ reinterpret_cast(A.device_data()), -+ reinterpret_cast(B.device_data()) -+ ); -+ -+ D.sync_host(); -+ -+ bool some_d_nonzero = false; -+ -+ for (int i = 0; i < kN; ++i) { -+ float a = float(A.host_data()[i]); -+ float b = float(B.host_data()[i]); -+ float d = float(D.host_data()[i]); -+ -+ EXPECT_TRUE(d == (a - b)); -+ -+ if (d != 0) { -+ some_d_nonzero = true; -+ } -+ } -+ -+ EXPECT_TRUE(some_d_nonzero); -+} -+ -+TEST(Functional, minus_f16x16) { -+ Functional_minus_f16xN<16>(); -+} -+ -+TEST(Functional, minus_f16x17) { -+ Functional_minus_f16xN<17>(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+void Functional_multiplies_f16xN() { -+ -+ using Element = cutlass::Array; -+ using Operator = cutlass::multiplies; -+ -+ using Tensor = cutlass::HostTensor; -+ -+ Tensor D({1, kN}); -+ Tensor A({1, kN}); -+ Tensor B({1, kN}); -+ Tensor C({1, kN}); -+ -+ for (int i = 0; i < kN; ++i) { -+ A.host_data()[i] = cutlass::half_t((i * 2 + 1) % 5); -+ B.host_data()[i] = cutlass::half_t((i * 4 + 8) % 7); -+ D.host_data()[i] = cutlass::half_t(0); -+ } -+ -+ D.sync_device(); -+ A.sync_device(); -+ B.sync_device(); -+ -+ test::core::kernel::binary_operator<<< dim3(1,1), dim3(1,1) >>>( -+ reinterpret_cast(D.device_data()), -+ reinterpret_cast(A.device_data()), -+ reinterpret_cast(B.device_data()) -+ ); -+ -+ D.sync_host(); -+ -+ bool some_d_nonzero = false; -+ -+ for (int i = 0; i < kN; ++i) { -+ float a = float(A.host_data()[i]); -+ float b = float(B.host_data()[i]); -+ float d = float(D.host_data()[i]); -+ -+ EXPECT_TRUE(d == (a * b)); -+ -+ if (d != 0) { -+ some_d_nonzero = true; -+ } -+ } -+ -+ EXPECT_TRUE(some_d_nonzero); -+} -+ -+TEST(Functional, multiplies_f16x16) { -+ -+ Functional_multiplies_f16xN<16>(); -+} -+ -+TEST(Functional, multiplies_f16x17) { -+ -+ Functional_multiplies_f16xN<17>(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+void Functional_divides_f16xN() { -+ -+ using Element = cutlass::Array; -+ using Operator = cutlass::divides; -+ -+ using Tensor = cutlass::HostTensor; -+ -+ Tensor D({1, kN}); -+ Tensor A({1, kN}); -+ Tensor B({1, kN}); -+ Tensor C({1, kN}); -+ -+ for (int i = 0; i < kN; ++i) { -+ A.host_data()[i] = cutlass::half_t((i * 2 + 1) % 5); -+ B.host_data()[i] = cutlass::half_t((i * 4 + 8) % 7); -+ D.host_data()[i] = cutlass::half_t(0); -+ } -+ -+ D.sync_device(); -+ A.sync_device(); -+ B.sync_device(); -+ -+ test::core::kernel::binary_operator<<< dim3(1,1), dim3(1,1) >>>( -+ reinterpret_cast(D.device_data()), -+ reinterpret_cast(A.device_data()), -+ reinterpret_cast(B.device_data()) -+ ); -+ -+ D.sync_host(); -+ -+ bool some_d_nonzero = false; -+ -+ for (int i = 0; i < kN; ++i) { -+ float a = float(A.host_data()[i]); -+ float b = float(B.host_data()[i]); -+ float d = float(D.host_data()[i]); -+ -+ float expected = a / b; -+ -+ float const kThreshold = 0.0005f; -+ -+ if (std::isnan(expected)) { -+ EXPECT_TRUE(std::isnan(d)); -+ } -+ else if (std::isinf(expected)) { -+ EXPECT_TRUE(std::isinf(d)); -+ } -+ else { -+ EXPECT_TRUE(std::abs(d - expected) < kThreshold) -+ << "Got: " << d << " = " << a << " / " << b << ", expected: " << (a / b); -+ } -+ -+ if (d != 0) { -+ some_d_nonzero = true; -+ } -+ } -+ -+ EXPECT_TRUE(some_d_nonzero); -+} -+ -+TEST(Functional, divides_f16x16) { -+ -+ Functional_divides_f16xN<16>(); -+} -+ -+TEST(Functional, divides_f16x17) { -+ -+ Functional_divides_f16xN<17>(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+void Functional_multiply_add_TxN() { -+ -+ using Element = cutlass::Array; -+ using Operator = cutlass::multiply_add; -+ -+ using Tensor = cutlass::HostTensor; -+ -+ Tensor D({1, kN}); -+ Tensor A({1, kN}); -+ Tensor B({1, kN}); -+ Tensor C({1, kN}); -+ -+ for (int i = 0; i < kN; ++i) { -+ A.host_data()[i] = T((i * 2 + 1) % 5); -+ B.host_data()[i] = T((i * 4 + 8) % 7); -+ C.host_data()[i] = T((i * 3 + 11) % 11); -+ D.host_data()[i] = T(0); -+ } -+ -+ D.sync_device(); -+ A.sync_device(); -+ B.sync_device(); -+ C.sync_device(); -+ -+ test::core::kernel::trinary_operator<<< dim3(1,1), dim3(1,1) >>>( -+ reinterpret_cast(D.device_data()), -+ reinterpret_cast(A.device_data()), -+ reinterpret_cast(B.device_data()), -+ reinterpret_cast(C.device_data()) -+ ); -+ -+ D.sync_host(); -+ -+ bool some_d_nonzero = false; -+ -+ for (int i = 0; i < kN; ++i) { -+ float a = float(A.host_data()[i]); -+ float b = float(B.host_data()[i]); -+ float c = float(C.host_data()[i]); -+ float d = float(D.host_data()[i]); -+ -+ EXPECT_TRUE(d == (a * b + c)); -+ -+ if (d != 0) { -+ some_d_nonzero = true; -+ } -+ } -+ -+ EXPECT_TRUE(some_d_nonzero); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Functional, multiply_add_f16x16) { -+ Functional_multiply_add_TxN(); -+} -+ -+TEST(Functional, multiply_add_f16x17) { -+ Functional_multiply_add_TxN(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Functional, multiply_add_bf16x16) { -+ Functional_multiply_add_TxN(); -+} -+ -+TEST(Functional, multiply_add_bf16x17) { -+ Functional_multiply_add_TxN(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+cutlass::Quaternion random_quaternion(int range) { -+ return cutlass::Quaternion{ -+ T((rand() % range * 2) - range), -+ T((rand() % range * 2) - range), -+ T((rand() % range * 2) - range), -+ T((rand() % range * 2) - range) -+ }; -+} -+ -+template -+void Functional_multiply_add_QuaternionT() { -+ -+ using Element = cutlass::Quaternion; -+ using Operator = cutlass::multiply_add; -+ using HostTensor = cutlass::HostTensor; -+ -+ int const kM = 128; -+ int const kRange = 8; -+ -+ HostTensor A({kM, 1}); -+ HostTensor B({kM, 1}); -+ HostTensor C({kM, 1}); -+ HostTensor D({kM, 1}); -+ -+ srand(2021); -+ -+ for (int m = 0; m < kM; ++m) { -+ A.at({m, 0}) = random_quaternion(kRange); -+ B.at({m, 0}) = random_quaternion(kRange); -+ C.at({m, 0}) = random_quaternion(kRange); -+ } -+ -+ A.sync_device(); -+ B.sync_device(); -+ C.sync_device(); -+ D.sync_device(); -+ -+ test::core::kernel::trinary_operator<<< dim3(kM,1), dim3(1,1) >>>( -+ D.device_data(), -+ A.device_data(), -+ B.device_data(), -+ C.device_data() -+ ); -+ -+ D.sync_host(); -+ -+ for (int m = 0; m < kM; ++m) { -+ -+ Element a = A.at({m, 0}); -+ Element b = B.at({m, 0}); -+ Element c = C.at({m, 0}); -+ Element got = D.at({m, 0}); -+ Element expected = a * b + c; -+ -+ EXPECT_TRUE(got == expected); -+ } -+} -+ -+TEST(Functional, multiply_add_quaternion_f32) { -+ Functional_multiply_add_QuaternionT(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/core/half.cu b/3rdparty/cutlass/test/unit/core/half.cu -new file mode 100644 -index 0000000..27d0872 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/core/half.cu -@@ -0,0 +1,90 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Statically sized array of elements that accommodates all CUTLASS-supported numeric types -+ and is safe to use in a union. -+*/ -+ -+#include "../common/cutlass_unit_test.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/util/device_memory.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Host -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(half_t, host_conversion) { -+ for (int i = -1024; i < 1024; ++i) { -+ float f = static_cast(i); -+ -+ cutlass::half_t x = static_cast(i); -+ cutlass::half_t y = static_cast(f); -+ -+ EXPECT_TRUE(static_cast(x) == i); -+ EXPECT_TRUE(static_cast(y) == f); -+ } -+ -+ // Try out default-ctor (zero initialization of primitive proxy type) -+ EXPECT_TRUE(cutlass::half_t() == 0.0_hf); -+ -+ // Try out user-defined literals -+ EXPECT_TRUE(cutlass::half_t(7) == 7_hf); -+ EXPECT_TRUE(7 == static_cast(7_hf)); -+} -+ -+TEST(half_t, host_arithmetic) { -+ -+ for (int i = -100; i < 100; ++i) { -+ for (int j = -100; j < 100; ++j) { -+ -+ cutlass::half_t x = static_cast(i); -+ cutlass::half_t y = static_cast(j); -+ -+ EXPECT_TRUE(static_cast(x + y) == (i + j)); -+ } -+ } -+ -+ for (int i = -6; i < 6; ++i) { -+ for (int j = -6; j < 6; ++j) { -+ -+ cutlass::half_t x = static_cast(i); -+ cutlass::half_t y = static_cast(j); -+ -+ EXPECT_TRUE(static_cast(x * y) == (i * j)); -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/core/matrix.cu b/3rdparty/cutlass/test/unit/core/matrix.cu -new file mode 100644 -index 0000000..334521c ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/core/matrix.cu -@@ -0,0 +1,205 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ \brief Unit tests for the small matrix class. -+*/ -+ -+#include -+ -+#include "../common/cutlass_unit_test.h" -+ -+#include "cutlass/matrix.h" -+#include "cutlass/core_io.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Matrix, elementwise_add) { -+ -+ using Matrix4x4 = cutlass::Matrix4x4; -+ -+ Matrix4x4 A = { -+ 1, 2, 3, 4, -+ 5, 6, 7, 8, -+ 9, 10, 11, 12, -+ 13, 14, 15, 16 -+ }; -+ -+ Matrix4x4 B = A.transpose(); -+ -+ Matrix4x4 C = A.add(B * 2.125f); -+ -+ bool passed = true; -+ -+ for (int i = 0; i < 4; ++i) { -+ for (int j = 0; j < 4; ++j) { -+ float got = C.at(i, j); -+ float expected = A.at(i, j) + A.at(j, i) * 2.125f; -+ if (got != expected) { -+ passed = false; -+ } -+ } -+ } -+ EXPECT_TRUE(passed); -+ if (!passed) { -+ std::cout << "A:\n" << A << "\n\nB:\n" << B << "\n\nC:\n" << C << std::endl; -+ } -+} -+ -+TEST(Matrix, elementwise_multiply) { -+ -+ using Matrix4x4 = cutlass::Matrix4x4; -+ -+ Matrix4x4 A = { -+ 1, 2, 3, 4, -+ 5, 6, 7, 8, -+ 9, 10, 11, 12, -+ 13, 14, 15, 16 -+ }; -+ -+ Matrix4x4 B = A.transpose(); -+ -+ Matrix4x4 C = A.multiply(B); -+ -+ bool passed = true; -+ -+ for (int i = 0; i < 4; ++i) { -+ for (int j = 0; j < 4; ++j) { -+ float got = C.at(i, j); -+ float expected = A.at(i, j) * A.at(j, i); -+ if (got != expected) { -+ passed = false; -+ } -+ } -+ } -+ EXPECT_TRUE(passed); -+ if (!passed) { -+ std::cout << "A:\n" << A << "\n\nB:\n" << B << "\n\nC:\n" << C << std::endl; -+ } -+} -+ -+TEST(Matrix, product_4x4_overloads) { -+ -+ using Matrix4x4 = cutlass::Matrix4x4; -+ -+ Matrix4x4 A = { -+ 1, 2, 3, 4, -+ 5, 6, 7, 8, -+ 9, 10, 11, 12, -+ 13, 14, 15, 16 -+ }; -+ -+ Matrix4x4 B = { -+ -1, -2, 0, 4, -+ 1, 2, 1, 1, -+ 3, 2, 1, 1, -+ 1, 0, 8, 2 -+ }; -+ -+ Matrix4x4 C = Matrix4x4::identity(); -+ -+ Matrix4x4 D = A * B + C; -+ -+ bool passed = true; -+ -+ for (int i = 0; i < 4; ++i) { -+ for (int j = 0; j < 4; ++j) { -+ float got = D.at(i, j); -+ float expected = (i == j ? 1.0f : 0); -+ for (int k = 0; k < 4; ++k) { -+ expected += A.at(i, k) * B.at(k, j); -+ } -+ if (got != expected) { -+ passed = false; -+ } -+ } -+ } -+ -+ EXPECT_TRUE(passed); -+ if (!passed) { -+ std::cout << "A:\n" << A << "\n\nB:\n" << B << "\n\nC:\n" << C << "\n\nD:\n" << D << std::endl; -+ } -+} -+ -+ -+TEST(Matrix, product_4x4) { -+ -+ using Matrix4x4 = cutlass::Matrix4x4; -+ -+ Matrix4x4 A = { -+ 1, 2, 3, 4, -+ 5, 6, 7, 8, -+ 9, 10, 11, 12, -+ 13, 14, 15, 16 -+ }; -+ -+ Matrix4x4 B = { -+ -1, -2, 0, 4, -+ 1, 2, 1, 1, -+ 3, 2, 1, 1, -+ 1, 0, 8, 2 -+ }; -+ -+ Matrix4x4 C = Matrix4x4::identity(); -+ -+ // Compute product with optional source accumulator -+ Matrix4x4 D = A.product(B, C); -+ -+ bool passed = true; -+ -+ for (int i = 0; i < 4; ++i) { -+ for (int j = 0; j < 4; ++j) { -+ float got = D.at(i, j); -+ float expected = (i == j ? 1.0f : 0.0f); -+ for (int k = 0; k < 4; ++k) { -+ expected += A.at(i, k) * B.at(k, j); -+ } -+ if (got != expected) { -+ passed = false; -+ } -+ } -+ } -+ -+ EXPECT_TRUE(passed); -+ if (!passed) { -+ std::cout << "A:\n" << A << "\n\nB:\n" << B << "\n\nC:\n" << C << "\n\nD:\n" << D << std::endl; -+ } -+ -+ for (int i = 0; i < 4; ++i) { -+ for (int j = 0; j < 4; ++j) { -+ float c = (i == j ? 1.0f : 0.0f); -+ EXPECT_TRUE(A.row(i).dot(B.column(j)) + c == D.at(i, j)); -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/core/matrix_coord.cu b/3rdparty/cutlass/test/unit/core/matrix_coord.cu -new file mode 100644 -index 0000000..c703769 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/core/matrix_coord.cu -@@ -0,0 +1,227 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+* -+**************************************************************************************************/ -+/*! \file -+\brief unit tests for matrix_coord -+*/ -+ -+#include "../common/cutlass_unit_test.h" -+ -+#include "cutlass/matrix_coord.h" -+///////////////////////////////////////////////////////////////////////////////////////////////// -+namespace test { -+namespace core { -+ -+ void test_matrix_coord(cutlass::MatrixCoord::Index row, cutlass::MatrixCoord::Index column) { -+ cutlass::MatrixCoord matrix_coord(row, column); -+ -+ EXPECT_EQ(matrix_coord.row(), row); -+ EXPECT_EQ(matrix_coord.column(), column); -+ } -+ -+ void test_matrix_coord_operator_addition() { -+ cutlass::MatrixCoord::Index row_a = 13; -+ cutlass::MatrixCoord::Index column_a = 42; -+ cutlass::MatrixCoord::Index row_b = 20; -+ cutlass::MatrixCoord::Index column_b = 15; -+ -+ cutlass::MatrixCoord matrix_coord_a(row_a, column_a); -+ cutlass::MatrixCoord matrix_coord_b(row_b, column_b); -+ -+ auto matrix_coord_c = matrix_coord_a + matrix_coord_b; -+ -+ EXPECT_EQ(matrix_coord_c.row(), row_a + row_b); -+ EXPECT_EQ(matrix_coord_c.column(), column_a + column_b); -+ } -+ -+ void test_matrix_coord_operator_subtraction() { -+ cutlass::MatrixCoord::Index row_a = 13; -+ cutlass::MatrixCoord::Index column_a = 42; -+ cutlass::MatrixCoord::Index row_b = 20; -+ cutlass::MatrixCoord::Index column_b = 15; -+ -+ cutlass::MatrixCoord matrix_coord_a(row_a, column_a); -+ cutlass::MatrixCoord matrix_coord_b(row_b, column_b); -+ -+ auto matrix_coord_c = matrix_coord_a - matrix_coord_b; -+ -+ EXPECT_EQ(matrix_coord_c.row(), row_a - row_b); -+ EXPECT_EQ(matrix_coord_c.column(), column_a - column_b); -+ } -+ -+ void test_matrix_coord_operator_multiply() { -+ cutlass::MatrixCoord::Index row_a = 13; -+ cutlass::MatrixCoord::Index column_a = 42; -+ cutlass::MatrixCoord::Index row_b = 20; -+ cutlass::MatrixCoord::Index column_b = 15; -+ -+ cutlass::MatrixCoord matrix_coord_a(row_a, column_a); -+ cutlass::MatrixCoord matrix_coord_b(row_b, column_b); -+ -+ auto matrix_coord_c = matrix_coord_a * matrix_coord_b; -+ -+ EXPECT_EQ(matrix_coord_c.row(), row_a * row_b); -+ EXPECT_EQ(matrix_coord_c.column(), column_a * column_b); -+ } -+ -+ void test_matrix_coord_operator_division() { -+ cutlass::MatrixCoord::Index row_a = 13; -+ cutlass::MatrixCoord::Index column_a = 42; -+ cutlass::MatrixCoord::Index row_b = 20; -+ cutlass::MatrixCoord::Index column_b = 15; -+ -+ cutlass::MatrixCoord matrix_coord_a(row_a, column_a); -+ cutlass::MatrixCoord matrix_coord_b(row_b, column_b); -+ -+ auto matrix_coord_c = matrix_coord_a / matrix_coord_b; -+ -+ EXPECT_EQ(matrix_coord_c.row(), row_a / row_b); -+ EXPECT_EQ(matrix_coord_c.column(), column_a / column_b); -+ } -+ -+ void test_matrix_coord_operator_addition_assignment() { -+ cutlass::MatrixCoord::Index row_a = 13; -+ cutlass::MatrixCoord::Index column_a = 42; -+ cutlass::MatrixCoord::Index row_b = 20; -+ cutlass::MatrixCoord::Index column_b = 15; -+ -+ cutlass::MatrixCoord matrix_coord_a(row_a, column_a); -+ cutlass::MatrixCoord matrix_coord_b(row_b, column_b); -+ -+ matrix_coord_a += matrix_coord_b; -+ -+ EXPECT_EQ(matrix_coord_a.row(), row_a + row_b); -+ EXPECT_EQ(matrix_coord_a.column(), column_a + column_b); -+ } -+ -+ void test_matrix_coord_operator_subtraction_assignment() { -+ cutlass::MatrixCoord::Index row_a = 13; -+ cutlass::MatrixCoord::Index column_a = 42; -+ cutlass::MatrixCoord::Index row_b = 20; -+ cutlass::MatrixCoord::Index column_b = 15; -+ -+ cutlass::MatrixCoord matrix_coord_a(row_a, column_a); -+ cutlass::MatrixCoord matrix_coord_b(row_b, column_b); -+ -+ matrix_coord_a -= matrix_coord_b; -+ -+ EXPECT_EQ(matrix_coord_a.row(), row_a - row_b); -+ EXPECT_EQ(matrix_coord_a.column(), column_a - column_b); -+ } -+ -+ void test_matrix_coord_operator_multiply_assignment() { -+ cutlass::MatrixCoord::Index row_a = 13; -+ cutlass::MatrixCoord::Index column_a = 42; -+ cutlass::MatrixCoord::Index row_b = 20; -+ cutlass::MatrixCoord::Index column_b = 15; -+ -+ cutlass::MatrixCoord matrix_coord_a(row_a, column_a); -+ cutlass::MatrixCoord matrix_coord_b(row_b, column_b); -+ -+ matrix_coord_a *= matrix_coord_b; -+ -+ EXPECT_EQ(matrix_coord_a.row(), row_a * row_b); -+ EXPECT_EQ(matrix_coord_a.column(), column_a * column_b); -+ } -+ -+ void test_matrix_coord_operator_division_assignment() { -+ cutlass::MatrixCoord::Index row_a = 13; -+ cutlass::MatrixCoord::Index column_a = 42; -+ cutlass::MatrixCoord::Index row_b = 20; -+ cutlass::MatrixCoord::Index column_b = 15; -+ -+ cutlass::MatrixCoord matrix_coord_a(row_a, column_a); -+ cutlass::MatrixCoord matrix_coord_b(row_b, column_b); -+ -+ matrix_coord_a /= matrix_coord_b; -+ -+ EXPECT_EQ(matrix_coord_a.row(), row_a / row_b); -+ EXPECT_EQ(matrix_coord_a.column(), column_a / column_b); -+ } -+} -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Matrix_Coord, basic_row12_column24) { -+ cutlass::MatrixCoord::Index row = 12; -+ cutlass::MatrixCoord::Index column = 24; -+ test::core::test_matrix_coord(row, column); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Matrix_Coord, basic_operator_addition) { -+ test::core::test_matrix_coord_operator_addition(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Matrix_Coord, basic_operator_subtraction) { -+ test::core::test_matrix_coord_operator_subtraction(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Matrix_Coord, basic_operator_multiply) { -+ test::core::test_matrix_coord_operator_multiply(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Matrix_Coord, basic_operator_division) { -+ test::core::test_matrix_coord_operator_division(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Matrix_Coord, basic_operator_addition_assignment) { -+ test::core::test_matrix_coord_operator_addition_assignment(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Matrix_Coord, basic_operator_subtraction_assignment) { -+ test::core::test_matrix_coord_operator_subtraction_assignment(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Matrix_Coord, basic_operator_multiply_assignment) { -+ test::core::test_matrix_coord_operator_multiply_assignment(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Matrix_Coord, basic_operator_division_assignment) { -+ test::core::test_matrix_coord_operator_division_assignment(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/core/numeric_conversion.cu b/3rdparty/cutlass/test/unit/core/numeric_conversion.cu -new file mode 100644 -index 0000000..8d7a296 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/core/numeric_conversion.cu -@@ -0,0 +1,331 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for conversion operators. -+*/ -+ -+#include "../common/cutlass_unit_test.h" -+ -+#include "cutlass/numeric_conversion.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/util/host_tensor.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace core { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Simple conversion function -+template -+__global__ void convert( -+ cutlass::Array *destination, -+ cutlass::Array const *source) { -+ -+ cutlass::NumericArrayConverter convert; -+ -+ *destination = convert(*source); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+void run_test() { -+ const int kN = Count; -+ -+ dim3 grid(1, 1); -+ dim3 block(1, 1); -+ -+ cutlass::HostTensor destination({1, kN}); -+ cutlass::HostTensor source({1, kN}); -+ -+ for (int i = 0; i < kN; ++i) { -+ source.host_data()[i] = Source(i % 4); -+ } -+ -+ source.sync_device(); -+ -+ convert<<< grid, block >>>( -+ reinterpret_cast *>(destination.device_data()), -+ reinterpret_cast const *>(source.device_data()) -+ ); -+ -+ destination.sync_host(); -+ -+ for (int i = 0; i < kN; ++i) { -+ EXPECT_TRUE(float(destination.host_data()[i]) == float(source.host_data()[i])); -+ } -+} -+ -+} // namespace kernel -+} // namespace core -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(NumericConversion, f32_to_f16_rn) { -+ int const kN = 1; -+ using Source = float; -+ using Destination = cutlass::half_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, f32x8_to_f16x8_rn) { -+ int const kN = 8; -+ using Source = float; -+ using Destination = cutlass::half_t; -+ test::core::kernel::run_test(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(NumericConversion, f16_to_f32_rn) { -+ int const kN = 1; -+ using Source = cutlass::half_t; -+ using Destination = float; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, f16x8_to_f32x8_rn) { -+ int const kN = 8; -+ using Source = cutlass::half_t; -+ using Destination = float; -+ test::core::kernel::run_test(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(NumericConversion, f32_to_fe4m3_rn) { -+ int const kN = 1; -+ using Source = float; -+ using Destination = cutlass::float_e4m3_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, f32_to_fe4m3_rn_array) { -+ int const kN = 27; -+ using Source = float; -+ using Destination = cutlass::float_e4m3_t; -+ -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, f32_to_fe5m2_rn) { -+ int const kN = 1; -+ using Source = float; -+ using Destination = cutlass::float_e5m2_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, f32_to_fe5m2_rn_array) { -+ int const kN = 27; -+ using Source = float; -+ using Destination = cutlass::float_e5m2_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, f16_to_fe4m3_rn) { -+ int const kN = 1; -+ using Source = cutlass::half_t; -+ using Destination = cutlass::float_e4m3_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, f16_to_fe4m3_rn_array) { -+ int const kN = 27; -+ using Source = cutlass::half_t; -+ using Destination = cutlass::float_e4m3_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, f16_to_fe5m2_rn) { -+ int const kN = 1; -+ using Source = cutlass::half_t; -+ using Destination = cutlass::float_e5m2_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, f16_to_fe5m2_rn_array) { -+ int const kN = 27; -+ using Source = cutlass::half_t; -+ using Destination = cutlass::float_e5m2_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, bf16_to_fe4m3_rn) { -+ int const kN = 1; -+ using Source = cutlass::bfloat16_t; -+ using Destination = cutlass::float_e4m3_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, bf16_to_fe4m3_rn_array) { -+ int const kN = 27; -+ using Source = cutlass::bfloat16_t; -+ using Destination = cutlass::float_e4m3_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, bf16_to_fe5m2_rn) { -+ int const kN = 1; -+ using Source = cutlass::bfloat16_t; -+ using Destination = cutlass::float_e5m2_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, bf16_to_fe5m2_rn_array) { -+ int const kN = 27; -+ using Source = cutlass::bfloat16_t; -+ using Destination = cutlass::float_e5m2_t; -+ test::core::kernel::run_test(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(NumericConversion, fe4m3_to_fe5m2_rn) { -+ int const kN = 1; -+ using Source = cutlass::float_e4m3_t; -+ using Destination = cutlass::float_e5m2_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, fe4m3_to_fe5m2_array) { -+ int const kN = 27; -+ using Source = cutlass::float_e4m3_t; -+ using Destination = cutlass::float_e5m2_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, fe5m2_to_fe4m3_rn) { -+ int const kN = 1; -+ using Source = cutlass::float_e5m2_t; -+ using Destination = cutlass::float_e4m3_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, fe5m2_to_fe4m3_array) { -+ int const kN = 27; -+ using Source = cutlass::float_e5m2_t; -+ using Destination = cutlass::float_e4m3_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, fe4m3_to_f32_rn) { -+ int const kN = 1; -+ using Source = cutlass::float_e4m3_t; -+ using Destination = float; -+ test::core::kernel::run_test(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(NumericConversion, f32x8_to_s8x8_rn) { -+ -+ int const kN = 8; -+ using Source = float; -+ using Destination = int8_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, fe4m3_to_f32_array) { -+ int const kN = 27; -+ using Source = cutlass::float_e4m3_t; -+ using Destination = float; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, fe5m2_to_f32_array) { -+ int const kN = 27; -+ using Source = cutlass::float_e5m2_t; -+ using Destination = float; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, fe4m3_to_f16_rn) { -+ int const kN = 1; -+ using Source = cutlass::float_e4m3_t; -+ using Destination = cutlass::half_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, fe4m3_to_f16_array) { -+ int const kN = 27; -+ using Source = cutlass::float_e4m3_t; -+ using Destination = cutlass::half_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, fe5m2_to_f16_rn) { -+ int const kN = 1; -+ using Source = cutlass::float_e5m2_t; -+ using Destination = cutlass::half_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, fe5m2_to_f16_array) { -+ int const kN = 27; -+ using Source = cutlass::float_e5m2_t; -+ using Destination = cutlass::half_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, fe4m3_to_bf16_rn) { -+ int const kN = 1; -+ using Source = cutlass::float_e4m3_t; -+ using Destination = cutlass::bfloat16_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, fe4m3_to_bf16_array) { -+ int const kN = 27; -+ using Source = cutlass::float_e4m3_t; -+ using Destination = cutlass::bfloat16_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, fe5m2_to_bf16_rn) { -+ int const kN = 1; -+ using Source = cutlass::float_e5m2_t; -+ using Destination = cutlass::bfloat16_t; -+ test::core::kernel::run_test(); -+} -+ -+TEST(NumericConversion, fe5m2_to_bf16_array) { -+ int const kN = 27; -+ using Source = cutlass::float_e5m2_t; -+ using Destination = cutlass::bfloat16_t; -+ test::core::kernel::run_test(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/core/predicate_vector.cu b/3rdparty/cutlass/test/unit/core/predicate_vector.cu -new file mode 100644 -index 0000000..5db96c9 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/core/predicate_vector.cu -@@ -0,0 +1,249 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#include -+ -+#include "../common/cutlass_unit_test.h" -+ -+#include "cutlass/predicate_vector.h" -+#include "cutlass/util/host_tensor.h" -+ -+namespace test { -+ -+template -+__global__ void load_predicates(unsigned *output, unsigned const *input) { -+ -+ PredicateVector predicates; -+ -+ int const word_count = (PredicateVector::kPredicates + 31) / 32; -+ -+ int i = 0; -+ for (int word_idx = 0; word_idx < word_count; ++word_idx) { -+ unsigned word = input[word_idx]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int bit = 0; bit < sizeof(unsigned) * 8; ++bit) { -+ bool pred = ((word >> bit) & 1); -+ predicates.set(i, pred); -+ -+ if (predicates.at(i) != pred) { -+ printf("ERROR - cannot read back predicate\n"); -+ } -+ ++i; -+ } -+ } -+ -+ -+ __syncthreads(); -+ -+ i = 0; -+ for (int word_idx = 0; word_idx < word_count; ++word_idx) { -+ -+ unsigned result = 0; -+ for (int bit = 0; bit < sizeof(unsigned) * 8; ++bit) { -+ bool pred = predicates.at(i ++); -+ result |= (unsigned(pred) << bit); -+ } -+ output[word_idx] = result; -+ } -+} -+} -+ -+TEST(PredicateVector, Basic) { -+ -+ static int const Bits = 32; -+ static int const Words = (Bits + 31) / 32; -+ -+ typedef cutlass::PredicateVector PredicateVector; -+ -+ cutlass::HostTensor > output; -+ cutlass::HostTensor> input; -+ -+ output.reserve(Words); -+ input.reserve(Words); -+ -+ // some arbitrary test bits -+ unsigned values[] = { -+ 0xdeadbeef, -+ 0xa0070032, -+ 0x9076d001, -+ 0x00000000, -+ 0xabdfc0ad -+ }; -+ -+ for (int test = 0; test < 5; ++test) { -+ -+ input.host_data(0) = values[test]; -+ output.host_data(0) = 0; -+ -+ input.sync_device(); -+ output.sync_device(); -+ -+ test::load_predicates<<< -+ dim3(1,1,1), dim3(1,1,1) -+ >>>( -+ output.device_data(), -+ input.device_data() -+ ); -+ -+ output.sync_host(); -+ -+ for (int word = 0; word < Words; ++word) { -+ EXPECT_EQ(input.host_data(word), output.host_data(word)) -+ << "Expected: 0x" << std::hex << input.host_data(word) -+ << ", got: 0x" << output.host_data(word) -+ << std::dec; -+ } -+ } -+} -+ -+TEST(PredicateVector, Count) { -+ -+ { -+ typedef cutlass::PredicateVector<4, 8> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 1) -+ << "PredicateVector<4, 8> word count: " << int(PredicateVector::kWordCount); -+ } -+ -+ { -+ typedef cutlass::PredicateVector<4, 4> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 1) -+ << "PredicateVector<4, 4> word count: " << int(PredicateVector::kWordCount); -+ } -+ -+ { -+ typedef cutlass::PredicateVector<4, 2> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 1) -+ << "PredicateVector<4, 2> word count: " << int(PredicateVector::kWordCount); -+ } -+ -+ { -+ typedef cutlass::PredicateVector<4, 1> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 1) -+ << "PredicateVector<4, 1> word count: " << int(PredicateVector::kWordCount); -+ } -+ -+ { -+ typedef cutlass::PredicateVector<8, 8> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 1) -+ << "PredicateVector<8, 8> word count: " << int(PredicateVector::kWordCount); -+ } -+ -+ { -+ typedef cutlass::PredicateVector<8, 4> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 1) -+ << "PredicateVector<8, 4> word count: " << int(PredicateVector::kWordCount); -+ } -+ -+ { -+ typedef cutlass::PredicateVector<8, 2> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 1) -+ << "PredicateVector<8, 2> word count: " << int(PredicateVector::kWordCount); -+ } -+ -+ { -+ typedef cutlass::PredicateVector<8, 1> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 2) -+ << "PredicateVector<8, 1> word count: " << int(PredicateVector::kWordCount); -+ } -+ -+ { -+ typedef cutlass::PredicateVector<16, 8> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 1) -+ << "PredicateVector<16, 8> word count: " << int(PredicateVector::kWordCount); -+ } -+ -+ { -+ typedef cutlass::PredicateVector<16, 4> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 1) -+ << "PredicateVector<16, 4> word count: " << int(PredicateVector::kWordCount); -+ } -+ -+ { -+ typedef cutlass::PredicateVector<16, 2> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 2) -+ << "PredicateVector<16, 2> word count: " << int(PredicateVector::kWordCount); -+ } -+ -+ { -+ typedef cutlass::PredicateVector<16, 1> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 4) -+ << "PredicateVector<16, 1> word count: " << int(PredicateVector::kWordCount); -+ } -+ -+ { -+ typedef cutlass::PredicateVector<32, 8> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 1) -+ << "PredicateVector<32, 8> word count: " << int(PredicateVector::kWordCount); -+ } -+ -+ { -+ typedef cutlass::PredicateVector<32, 4> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 2) -+ << "PredicateVector<32, 4> word count: " << int(PredicateVector::kWordCount); -+ } -+ -+ { -+ typedef cutlass::PredicateVector<32, 2> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 4) -+ << "PredicateVector<32, 2> word count: " << int(PredicateVector::kWordCount); -+ } -+ -+ { -+ typedef cutlass::PredicateVector<32, 1> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 8) -+ << "PredicateVector<32, 1> word count: " << int(PredicateVector::kWordCount); -+ } -+ -+ { -+ typedef cutlass::PredicateVector<64, 8> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 2) -+ << "PredicateVector<64, 8> word count: " << int(PredicateVector::kWordCount); -+ } -+ -+ { -+ typedef cutlass::PredicateVector<64, 4> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 4) -+ << "PredicateVector<64, 4> word count: " << int(PredicateVector::kWordCount); -+ } -+ -+ { -+ typedef cutlass::PredicateVector<64, 2> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 8) -+ << "PredicateVector<64, 2> word count: " << int(PredicateVector::kWordCount); -+ } -+ -+ { -+ typedef cutlass::PredicateVector<64, 1> PredicateVector; -+ EXPECT_EQ(int(PredicateVector::kWordCount), 16) -+ << "PredicateVector<64, 1> word count: " << int(PredicateVector::kWordCount); -+ } -+} -diff --git a/3rdparty/cutlass/test/unit/core/quaternion.cu b/3rdparty/cutlass/test/unit/core/quaternion.cu -new file mode 100644 -index 0000000..400ea6a ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/core/quaternion.cu -@@ -0,0 +1,168 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for the CUTLASS Quaternion template class. -+*/ -+ -+#include "../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/core_io.h" -+#include "cutlass/quaternion.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/constants.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+static float const half_pi = cutlass::constants::half_pi(); -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Quaternion, add_f32) { -+ -+ cutlass::Quaternion q0(1, 1, 1, 1); -+ cutlass::Quaternion q1(0, 0, 0, 2); -+ -+ cutlass::Quaternion q2 = q0 + q1; -+ -+ EXPECT_TRUE( -+ q2.x() == 1 && -+ q2.y() == 1 && -+ q2.z() == 1 && -+ q2.w() == 3 -+ ); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Quaternion, rotation) { -+ -+ cutlass::Matrix3x1 x(1.0f, 0.0f, 0.0f); -+ cutlass::Quaternion q = cutlass::Quaternion::rotation(0, 0, 1, half_pi) * 2.0f; -+ cutlass::Matrix3x1 v = q.rotate(x); -+ -+ float epsilon = 0.001f; -+ -+ EXPECT_TRUE( -+ std::abs(v.at(0)) < epsilon && -+ std::abs(v.at(1)) > (1 - epsilon) && -+ std::abs(v.at(2)) < epsilon -+ ); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Quaternion, rotation_inv) { -+ -+ cutlass::Matrix3x1 x(1.0f, 0.0f, 0.0f); -+ cutlass::Quaternion q = cutlass::Quaternion::rotation(0, 0, 1, half_pi) * 2.0f; -+ cutlass::Matrix3x1 v = q.rotate(x); -+ -+ float epsilon = 0.001f; -+ -+ EXPECT_TRUE( -+ std::abs(v.at(0)) < epsilon && -+ std::abs(-v.at(1)) > (1 - epsilon) && -+ std::abs(v.at(2)) < epsilon -+ ); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Quaternion, spinor_rotation) { -+ -+ cutlass::Matrix3x1 x(1.0f, 0.0f, 0.0f); -+ cutlass::Quaternion q = cutlass::Quaternion::rotation(0, 0, 1, half_pi); -+ cutlass::Matrix3x1 v = cutlass::spinor_rotation(q, x); -+ -+ float epsilon = 0.001f; -+ -+ EXPECT_TRUE( -+ std::abs(v.at(0)) < epsilon && -+ std::abs(v.at(1)) > (1 - epsilon) && -+ std::abs(v.at(2)) < epsilon -+ ); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Quaternion, spinor_rotation_inv) { -+ -+ cutlass::Matrix3x1 x(1.0f, 0.0f, 0.0f); -+ cutlass::Quaternion q = cutlass::Quaternion::rotation(0, 0, 1, half_pi); -+ cutlass::Matrix3x1 v = cutlass::spinor_rotation_inv(q, x); -+ -+ float epsilon = 0.001f; -+ -+ EXPECT_TRUE( -+ std::abs(v.at(0)) < epsilon && -+ std::abs(-v.at(1)) > (1 - epsilon) && -+ std::abs(v.at(2)) < epsilon -+ ); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Quaternion, as_rotation_matrix3x3) { -+ -+ cutlass::Matrix3x1 x(1.0f, 0.0f, 0.0f); -+ cutlass::Quaternion q = cutlass::Quaternion::rotation(0, 0, 1, half_pi); -+ cutlass::Matrix3x1 v = q.as_rotation_matrix_3x3().product(x); -+ -+ float epsilon = 0.001f; -+ -+ EXPECT_TRUE( -+ std::abs(v.at(0)) < epsilon && -+ std::abs(v.at(1)) > (1 - epsilon) && -+ std::abs(v.at(2)) < epsilon -+ ); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Quaternion, as_rotation_matrix4x4) { -+ -+ cutlass::Matrix4x1 x(1.0f, 0.0f, 0.0f, 1.0f); -+ cutlass::Quaternion q = cutlass::Quaternion::rotation(0, 0, 1, half_pi); -+ cutlass::Matrix4x1 v = q.as_rotation_matrix_4x4().product(x); -+ -+ float epsilon = 0.001f; -+ -+ EXPECT_TRUE( -+ std::abs(v.at(0)) < epsilon && -+ std::abs(v.at(1)) > (1 - epsilon) && -+ std::abs(v.at(2)) < epsilon && -+ std::abs(v.at(3)) > (1 - epsilon) -+ ); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/core/tensor_ref.cu b/3rdparty/cutlass/test/unit/core/tensor_ref.cu -new file mode 100644 -index 0000000..a4c46fa ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/core/tensor_ref.cu -@@ -0,0 +1,224 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#include "../common/cutlass_unit_test.h" -+ -+#include "cutlass/tensor_ref.h" -+#include "cutlass/layout/matrix.h" -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(TensorRef, basic_rank2) { -+ int const M = 8; -+ int const N = 16; -+ -+ int matrix_data[M * N] = {0}; -+ -+ cutlass::TensorRef< -+ int, -+ cutlass::IdentityTensorLayout<2> > matrix_ref(matrix_data, cutlass::make_Coord(N, 1)); -+ -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; ++n) { -+ matrix_ref.at(cutlass::make_Coord(m, n)) = m * N + n; -+ } -+ } -+ -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; ++n) { -+ EXPECT_EQ(matrix_data[m * N + n], int(m * N + n)); -+ } -+ } -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(TensorRef, rank2_column_major) { -+ int const M = 8; -+ int const N = 8; -+ -+ int matrix_data[M * N]; -+ -+ cutlass::TensorRef ref(matrix_data, M); -+ -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; ++n) { -+ ref.at(cutlass::make_Coord(m, n)) = m * N + n; -+ } -+ } -+ -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; ++n) { -+ EXPECT_EQ(matrix_data[m + n * M], int(m * N + n)); -+ } -+ } -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(TensorRef, rank2_row_major) { -+ int const M = 8; -+ int const N = 16; -+ -+ int matrix_data[M * N] = { 0 }; -+ -+ cutlass::TensorRef ref(matrix_data, N); -+ -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; ++n) { -+ ref.at(cutlass::make_Coord(m, n)) = m * N + n; -+ } -+ } -+ -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; ++n) { -+ EXPECT_EQ(matrix_data[m * N + n], int(m * N + n)); -+ } -+ } -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(TensorRef, rank2_contiguous_dynamic) { -+ int const M = 8; -+ int const N = 16; -+ -+ typedef cutlass::TensorRef ContiguousTensorRef; -+ -+ cutlass::layout::Matrix layouts[] = { -+ cutlass::layout::Matrix::kColumnMajor, -+ cutlass::layout::Matrix::kRowMajor -+ }; -+ -+ for (int i = 0; i < 2; ++i) { -+ -+ int matrix_data[M * N] = { 0 }; -+ -+ int row_stride; -+ int col_stride; -+ -+ if (layouts[i] == cutlass::layout::Matrix::kColumnMajor) { -+ row_stride = 1; -+ col_stride = M; -+ } -+ else { -+ row_stride = N; -+ col_stride = 1; -+ } -+ -+ // Use helper to determine stride vector from leading dimension -+ ContiguousTensorRef ref( -+ matrix_data, -+ cutlass::layout::ContiguousMatrix::packed(cutlass::make_Coord(M, N), layouts[i])); -+ -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; ++n) { -+ ref.at(cutlass::make_Coord(m, n)) = m * N + n; -+ } -+ } -+ -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; ++n) { -+ EXPECT_EQ(matrix_data[m * row_stride + n * col_stride], int(m * N + n)); -+ } -+ } -+ } -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(TensorRef, rank2_column_major_interleaved) { -+ int const M = 16; -+ int const N = 16; -+ int const kInterleave = 4; -+ -+ int matrix_data[M * N] = {0}; -+ -+ // Define the Layout for a column-major interleaved matrix format -+ using Layout = cutlass::layout::ColumnMajorInterleaved; -+ -+ // Construct a TensorRef -+ cutlass::TensorRef< -+ int, -+ Layout> ref(matrix_data, Layout::packed(cutlass::make_Coord(M, N))); -+ -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; ++n) { -+ ref.at(cutlass::make_Coord(m, n)) = m + n * M; -+ } -+ } -+ -+ // Verify -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; n += kInterleave) { -+ for (int i = 0; i < kInterleave; ++i) { -+ EXPECT_EQ(matrix_data[m * kInterleave + n * M + i], int(m + (n + i) * M)); -+ } -+ } -+ } -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(TensorRef, rank2_row_major_interleaved) { -+ int const M = 16; -+ int const N = 16; -+ int const kInterleave = 4; -+ -+ int matrix_data[M * N] = {0}; -+ -+ // Define the Layout for a row-major interleaved matrix format -+ using Layout = cutlass::layout::RowMajorInterleaved; -+ -+ // Construct a TensorRef -+ cutlass::TensorRef< -+ int, -+ Layout> ref(matrix_data, Layout::packed(cutlass::make_Coord(M, N))); -+ -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; ++n) { -+ ref.at(cutlass::make_Coord(m, n)) = m + n * M; -+ } -+ } -+ -+ // Verify -+ for (int m = 0; m < M; m += kInterleave) { -+ for (int n = 0; n < N; ++n) { -+ for (int i = 0; i < kInterleave; ++i) { -+ EXPECT_EQ(matrix_data[m * N + i + n * kInterleave], int((m + i) + n * M)); -+ } -+ } -+ } -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/core/tensor_view.cu b/3rdparty/cutlass/test/unit/core/tensor_view.cu -new file mode 100644 -index 0000000..26f1a70 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/core/tensor_view.cu -@@ -0,0 +1,289 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#include "../common/cutlass_unit_test.h" -+ -+#include "cutlass/tensor_view.h" -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/host_tensor.h" -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(TensorView, rank2_contiguous_dynamic) { -+ int const M = 8; -+ int const N = 16; -+ -+ typedef cutlass::TensorView ContiguousTensorView; -+ -+ cutlass::layout::Matrix layouts[] = { -+ cutlass::layout::Matrix::kColumnMajor, -+ cutlass::layout::Matrix::kRowMajor -+ }; -+ -+ cutlass::Coord<2> bounds = cutlass::make_Coord(M - 2, N - 2); -+ -+ for (int i = 0; i < 2; ++i) { -+ -+ int matrix_data[M * N] = { 0 }; -+ -+ int row_stride; -+ int col_stride; -+ -+ if (layouts[i] == cutlass::layout::Matrix::kColumnMajor) { -+ row_stride = 1; -+ col_stride = M; -+ } -+ else { -+ row_stride = N; -+ col_stride = 1; -+ } -+ -+ // Use helper to determine stride vector from leading dimension -+ ContiguousTensorView view( -+ matrix_data, -+ cutlass::layout::ContiguousMatrix::packed(cutlass::make_Coord(M, N), layouts[i]), -+ bounds); -+ -+ ASSERT_TRUE(view.good()); -+ -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; ++n) { -+ cutlass::Coord<2> coord = cutlass::make_Coord(m, n); -+ if (view.contains(coord)) { -+ view.at(coord) = m * N + n; -+ } -+ } -+ } -+ -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; ++n) { -+ int expected = 0; -+ if (m < bounds[0] && n < bounds[1]) { -+ expected = int(m * N + n); -+ } -+ EXPECT_EQ(matrix_data[m * row_stride + n * col_stride], expected); -+ } -+ } -+ } -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Uncomment the following line to observe output from printing TensorView objects -+// -+ -+// #define OBSERVE_TENSORVIEW_IO // uncomment to enable printing -+ -+#ifdef OBSERVE_TENSORVIEW_IO -+ -+// This test construct a TensorView of rank=2 with matrix layouts known at runtime. This -+// uses TensorRefMapFunc classes defined in cutlass/matrix_traits.h to define the mapping -+// from logical tensor indices to storage in memory. -+// -+// Helpers in tools/util/tensor_view_io.h print both the logical TensorView and the -+// linear memory of the tensor. -+TEST(TensorView, contiguous) { -+ -+ int const M = 8; -+ int const N = 16; -+ -+ typedef cutlass::TensorView< -+ int32_t, -+ cutlass::layout::ContiguousLayout> ContiguousTensorView; -+ -+ cutlass::layout::Matrix layouts[] = { -+ cutlass::layout::Matrix::kColumnMajor, -+ cutlass::layout::Matrix::kRowMajor -+ }; -+ -+ cutlass::Coord<2> bounds = cutlass::make_Coord(M, N); -+ -+ for (int i = 0; i < 2; ++i) { -+ -+ int matrix_data[M * N] = { 0 }; -+ -+ int ldm; -+ int row_stride; -+ int col_stride; -+ -+ if (layouts[i] == cutlass::layout::Matrix::kColumnMajor) { -+ row_stride = 1; -+ col_stride = M; -+ ldm = col_stride; -+ } -+ else { -+ row_stride = N; -+ col_stride = 1; -+ ldm = row_stride; -+ } -+ -+ // Use helper to determine stride vector from leading dimension -+ ContiguousTensorView view( -+ matrix_data, -+ cutlass::layout::ContiguousLayout::stride(layouts[i], ldm), -+ bounds); -+ -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; ++n) { -+ cutlass::Coord<2> coord = cutlass::make_Coord(m, n); -+ if (view.contains(coord)) { -+ view.at(coord) = m * N + n; -+ } -+ } -+ } -+ -+ std::cout << "---------\n"; -+ std::cout << (layouts[i] == cutlass::layout::Matrix::kColumnMajor ? -+ "Column-major:" : "Row-major:") << "\n\n"; -+ -+ std::cout << "Logical view:\n"; -+ std::cout.width(4); -+ std::cout << view << "\n" << std::endl; // Print TensorView object. -+ -+ std::cout << "Linear memory:"; -+ for (int idx = 0; idx < view.capacity(); ++idx) { -+ if (!(idx % (layouts[i] == cutlass::layout::Matrix::kColumnMajor ? M : N))) { -+ std::cout << std::endl; -+ } -+ std::cout << std::setw(4) << view.at(idx) << " "; -+ } -+ -+ std::cout << "\n" << std::endl; -+ } -+} -+ -+// This test is similar to the previous except it uses a column-major, interleaved data -+// layout. The test prints both the logical representation (a typical column-major matrix) -+// and a representation of linear memory. -+// -+// Note, the interleave=4 structure implies that every four consecutive elements in the -+// same row shall be adjacent in memory followed by the next row. -+TEST(TensorView, rank2_column_major_interleaved) { -+ int const M = 16; -+ int const N = 16; -+ int const kInterleave = 4; -+ -+ int matrix_data[M * N] = {0}; -+ -+ cutlass::Coord<2> bounds = cutlass::make_Coord(M, N); -+ -+ // Define the TensorRefMapFunc for a column-major interleaved matrix format -+ typedef cutlass::layout::ColumnMajorInterleaved TensorRefMapFunc; -+ -+ // Define a TensorView of rank=2 using the column-major interleaved mapping function -+ typedef cutlass::TensorView< -+ int, -+ TensorRefMapFunc> InterleavedTensorView; -+ -+ InterleavedTensorView view( -+ matrix_data, -+ TensorRefMapFunc::stride(M), -+ bounds); -+ -+ // Initialize -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; ++n) { -+ view.at(cutlass::make_Coord(m, n)) = m + n * M; -+ } -+ } -+ -+ // Print logical view -+ std::cout << "Column-major, interleave=" << kInterleave << " (logical view):\n"; -+ -+ std::cout << std::setw(4) << view << "\n" << std::endl; -+ -+ // Now define a linear view of the same data in memory -+ typedef cutlass::TensorView LinearTensorView; -+ -+ LinearTensorView linear_view(matrix_data, cutlass::make_Coord(N), bounds); -+ -+ std::cout << "Linear view in memory:\n"; -+ std::cout << std::setw(4) << linear_view << std::endl; -+} -+ -+#endif -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(TensorView, int4) { -+ -+ int const M = 4; -+ int const N = 8; -+ -+ using T = cutlass::int4b_t; -+ -+ cutlass::HostTensor tensor({M, N}); -+ -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; ++n) { -+ T x = T(n ^ m); // some simple hash -+ tensor.host_view().at({m, n}) = x; -+ } -+ } -+ -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; ++n) { -+ int x = (n ^ m); // some simple hash -+ EXPECT_TRUE(int(tensor.host_view().at({m, n})) == x); -+ } -+ } -+ -+ EXPECT_EQ(tensor.size(), M * N); -+} -+ -+TEST(TensorView, uint4) { -+ -+ int const M = 4; -+ int const N = 8; -+ -+ using T = cutlass::uint4b_t; -+ -+ cutlass::HostTensor tensor({M, N}); -+ -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; ++n) { -+ T x = T(n ^ m); // some simple hash -+ tensor.host_view().at({m, n}) = x; -+ } -+ } -+ -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; ++n) { -+ int x = (n ^ m); // some simple hash -+ EXPECT_TRUE(int(tensor.host_view().at({m, n})) == x); -+ } -+ } -+ -+ EXPECT_EQ(tensor.size(), M * N); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/core/tfloat32.cu b/3rdparty/cutlass/test/unit/core/tfloat32.cu -new file mode 100644 -index 0000000..aff50cf ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/core/tfloat32.cu -@@ -0,0 +1,206 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Statically sized array of elements that accommodates all CUTLASS-supported numeric types -+ and is safe to use in a union. -+*/ -+ -+#include "../common/cutlass_unit_test.h" -+ -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/util/device_memory.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Host -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(tfloat32_t, host_conversion) { -+ for (int i = -1024; i < 1024; ++i) { -+ float f = static_cast(i); -+ -+ cutlass::tfloat32_t x = static_cast(i); -+ cutlass::tfloat32_t y = static_cast(f); -+ -+ EXPECT_TRUE(static_cast(x) == i); -+ EXPECT_TRUE(static_cast(y) == f); -+ } -+ -+ // Try out default-ctor (zero initialization of primitive proxy type) -+ EXPECT_TRUE(cutlass::tfloat32_t() == 0.0_tf32); -+ -+ // Try out user-defined literals -+ EXPECT_TRUE(cutlass::tfloat32_t(7) == 7_tf32); -+ EXPECT_TRUE(7 == static_cast(7_tf32)); -+} -+ -+TEST(tfloat32_t, host_arithmetic) { -+ -+ for (int i = -100; i < 100; ++i) { -+ for (int j = -100; j < 100; ++j) { -+ -+ cutlass::tfloat32_t x = static_cast(i); -+ cutlass::tfloat32_t y = static_cast(j); -+ -+ EXPECT_TRUE(static_cast(x + y) == (i + j)); -+ } -+ } -+} -+ -+TEST(tfloat32_t, host_round_nearest) { -+ -+ struct { -+ uint32_t f32_bits; -+ uint32_t expected; -+ } tests[] = { -+ {0x40000000, 0x40000000}, // M=0, R=0, S=0 => rtz -+ {0x40001000, 0x40000000}, // M=0, R=1, S=0 => rtz -+ {0x40000001, 0x40000000}, // M=0, R=0, S=1 => rtz -+ {0x40001001, 0x40002000}, // M=0, R=1, S=1 => +inf -+ {0x40002000, 0x40002000}, // M=1, R=0, S=0 => rtz -+ {0x40002001, 0x40002000}, // M=1, R=0, S=1 => rtz -+ {0x40003000, 0x40004000}, // M=1, R=1, S=0 => +inf -+ {0x40003001, 0x40004000}, // M=1, R=1, S=1 => +inf -+ {0x7f800000, 0x7f800000}, // +inf -+ {0xff800000, 0xff800000}, // -inf -+ {0x7fffffff, 0x7fffffff}, // canonical NaN to canonical NaN -+ {0x7f800001, 0x7fffffff}, // NaN to canonical NaN -+ {0xff800001, 0x7fffffff}, // NaN to canonical NaN -+ {0, 0} -+ }; -+ -+ bool running = true; -+ for (int i = 0; running; ++i) { -+ -+ float f32 = reinterpret_cast(tests[i].f32_bits); -+ -+ cutlass::NumericConverter< -+ cutlass::tfloat32_t, -+ float, -+ cutlass::FloatRoundStyle::round_to_nearest> converter; -+ -+ cutlass::tfloat32_t tf32 = converter(f32); -+ -+ // note, we must explicitly truncate the low-order bits since they are not defined in TF32. -+ if (cutlass::isfinite(tf32)) { -+ tf32.storage &= 0xffffe000; -+ } -+ -+ bool passed = (tests[i].expected == tf32.raw()); -+ -+ EXPECT_TRUE(passed) -+ << "Error - convert(f32: 0x" << std::hex << tests[i].f32_bits -+ << ") -> 0x" << std::hex << tests[i].expected << "\ngot: 0x" << std::hex << tf32.raw(); -+ -+ if (!tests[i].f32_bits) { -+ running = false; -+ } -+ } -+} -+ -+namespace test { -+namespace core { -+ -+__global__ void convert_tf32_half_ulp(cutlass::tfloat32_t *out, float const *in) { -+ -+ cutlass::NumericConverter< -+ cutlass::tfloat32_t, -+ float, -+ cutlass::FloatRoundStyle::round_half_ulp_truncate> convert; -+ -+ *out = convert(*in); -+} -+ -+} -+} -+ -+ -+TEST(tfloat32_t, host_round_half_ulp) { -+ -+ struct { -+ uint32_t f32_bits; -+ uint32_t expected; -+ } tests[] = { -+ {0x40001fff, 0x40002000}, -+ {0x40000000, 0x40000000}, // M=0, R=0, S=0 => rtz -+ {0x40001000, 0x40002000}, // M=0, R=1, S=0 => rtz - this difers from RNE -+ {0x40000001, 0x40000000}, // M=0, R=0, S=1 => rtz -+ {0x40001001, 0x40002000}, // M=0, R=1, S=1 => +inf -+ {0x40002000, 0x40002000}, // M=1, R=0, S=0 => rtz -+ {0x40002001, 0x40002000}, // M=1, R=0, S=1 => rtz -+ {0x40003000, 0x40004000}, // M=1, R=1, S=0 => +inf -+ {0x40003001, 0x40004000}, // M=1, R=1, S=1 => +inf -+ {0x7f800000, 0x7f800000}, // +inf -+ {0xff800000, 0xff800000}, // -inf -+ {0x7fffffff, 0x7fffffff}, // canonical NaN to canonical NaN -+ {0x7f800001, 0x7f800001}, // NaN to NaN -+ {0xff800001, 0xff800001}, // NaN to NaN -+ {0, 0} -+ }; -+ -+ cutlass::NumericConverter< -+ cutlass::tfloat32_t, -+ float, -+ cutlass::FloatRoundStyle::round_half_ulp_truncate> convert; -+ -+ bool running = true; -+ for (int i = 0; running; ++i) { -+ -+ float f32 = reinterpret_cast(tests[i].f32_bits); -+ -+ cutlass::tfloat32_t tf32 = convert(f32); -+ -+ // note, for this test, we must explicitly truncate the low-order bits since they are not -+ // defined in TF32. -+ if (cutlass::isfinite(tf32)) { -+ tf32.storage &= 0xffffe000; -+ } -+ -+ bool passed = (tests[i].expected == tf32.raw()); -+ -+ EXPECT_TRUE(passed) -+ << "Error - convert(f32: 0x" << std::hex << tests[i].f32_bits -+ << ") -> 0x" << std::hex << tests[i].expected << "\ngot: 0x" << std::hex << tf32.raw(); -+ -+ if (!tests[i].f32_bits) { -+ running = false; -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Device -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/cute/ampere/cp_async.cu b/3rdparty/cutlass/test/unit/cute/ampere/cp_async.cu -new file mode 100644 -index 0000000..7a80a51 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/cute/ampere/cp_async.cu -@@ -0,0 +1,104 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#include "cutlass_unit_test.h" -+ -+#include -+#include -+#include -+#include -+#include -+#include -+ -+#include -+#include -+ -+#include -+ -+using namespace cute; -+ -+__global__ void -+test(double const* g_in, double* g_out) -+{ -+ extern __shared__ double smem[]; -+ -+ smem[threadIdx.x] = g_in[threadIdx.x]; -+ -+ __syncthreads(); -+ -+ g_out[threadIdx.x] = 2 * smem[threadIdx.x]; -+} -+ -+__global__ void -+test2(double const* g_in, double* g_out) -+{ -+ using namespace cute; -+ -+ extern __shared__ double smem[]; -+ -+ auto s_tensor = make_tensor(make_smem_ptr(smem + threadIdx.x), Int<1>{}); -+ auto g_tensor = make_tensor(make_gmem_ptr(g_in + threadIdx.x), Int<1>{}); -+ -+ copy(g_tensor, s_tensor); -+ -+ cp_async_fence(); -+ cp_async_wait<0>(); -+ __syncthreads(); -+ -+ g_out[threadIdx.x] = 2 * smem[threadIdx.x]; -+} -+ -+TEST(SM80_CuTe_Ampere, CpAsync) -+{ -+ constexpr int count = 32; -+ thrust::host_vector h_in(count); -+ for (int i = 0; i < count; ++i) { -+ h_in[i] = double(i); -+ } -+ -+ thrust::device_vector d_in(h_in); -+ -+ thrust::device_vector d_out(count, -1); -+ test<<<1, count, sizeof(double) * count>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data())); -+ thrust::host_vector h_result = d_out; -+ -+ thrust::device_vector d_out_cp_async(count, -2); -+ test2<<<1, count, sizeof(double) * count>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out_cp_async.data())); -+ thrust::host_vector h_result_cp_async = d_out_cp_async; -+ -+ for (int i = 0; i < count; ++i) { -+ EXPECT_EQ(h_result[i], h_result_cp_async[i]); -+ } -+} -diff --git a/3rdparty/cutlass/test/unit/cute/ampere/ldsm.cu b/3rdparty/cutlass/test/unit/cute/ampere/ldsm.cu -new file mode 100644 -index 0000000..15ec44b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/cute/ampere/ldsm.cu -@@ -0,0 +1,431 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#include "cutlass_unit_test.h" -+ -+#include -+ -+#include -+#include -+ -+#include -+ -+#include -+ -+ -+using namespace cute; -+ -+template -+__global__ void -+ldsm_test_device(uint16_t* g_in, uint16_t* g_out) -+{ -+ constexpr int count = sizeof(T) / 4; -+ int tid = threadIdx.x; -+ int stride = blockDim.x; -+ -+ // load input gmem -> smem -+ __shared__ uint32_t smem[32 * count]; -+ for (int i = 0; i < count; ++i) { -+ smem[tid + (stride * i)] = reinterpret_cast(g_in)[tid + (stride * i)]; -+ } -+ -+ __syncthreads(); -+ -+ uint32_t reg[count]; -+ for (int i = 0; i < count; ++i) { -+ reg[i] = 0; -+ } -+ -+ // load smem -> rmem using LDSM -+ uint128_t* smem_ptr = reinterpret_cast(smem) + tid; -+ T* rmem_ptr = reinterpret_cast(reg); -+ cute::copy_ldsm(smem_ptr, rmem_ptr); -+ -+ // store output rmem -> gmem -+ for (int i = 0; i < count; ++i) { -+ reinterpret_cast(g_out)[tid + (stride * i)] = reg[i]; -+ } -+} -+ -+template -+__global__ void -+ldsm_test_device_cute(uint16_t* g_in, uint16_t* g_out, -+ TiledCopy tiled_copy, SmemLayout smem_layout) -+{ -+ using namespace cute; -+ -+ __shared__ uint16_t smem[size(smem_layout)]; -+ -+ auto t_g_in = make_tensor(make_gmem_ptr(g_in), smem_layout); -+ auto t_g_out = make_tensor(make_gmem_ptr(g_out), smem_layout); -+ auto t_smem = make_tensor(make_smem_ptr(smem), smem_layout); -+ -+ int tid = threadIdx.x; -+ -+ // Load input gmem -> smem -+ for (int i = tid; i < size(t_smem); i += size(tiled_copy)) { -+ t_smem(i) = t_g_in(i); -+ } -+ -+ __syncthreads(); -+ -+ auto thr_copy = tiled_copy.get_thread_slice(tid); -+ -+ auto tXsX = thr_copy.partition_S(t_smem); // (V,M,N) -+ auto tXgX = thr_copy.partition_D(t_g_out); // (V,M,N) -+ -+ auto tXrX = make_tensor(shape(tXgX)); // (V,M,N) -+ clear(tXrX); // Just to make sure -+ -+/* -+ if (thread0()) { -+ print("tXsX: " ); print(tXsX.layout()); print("\n"); -+ print("tXgX: " ); print(tXgX.layout()); print("\n"); -+ print("tXrX: " ); print(tXrX.layout()); print("\n"); -+ } -+*/ -+ -+ // Copy smem -> rmem via tiled_copy (LDSM, LDS) -+ copy(tiled_copy, tXsX, tXrX); -+ -+ // Output rmem -> gmem -+ copy(tXrX, tXgX); -+} -+ -+ -+TEST(SM80_CuTe_Ampere, Ldsm) -+{ -+ constexpr int count = 1024; -+ -+ thrust::host_vector h_in(count); -+ for (int i = 0; i < count; ++i) { -+ h_in[i] = uint16_t(i); -+ } -+ thrust::device_vector d_in = h_in; -+ -+ // -+ // LDSM 1x (32b) -+ // -+ -+ { -+ thrust::device_vector d_out(count); -+ ldsm_test_device<<<1, 32>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data())); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < 32; ++i) { -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("LDSM 1x ldsm_test_device SUCCESS\n"); -+ } -+ -+ // -+ // LDSM 2x (64b) -+ // -+ -+ { -+ thrust::device_vector d_out(count); -+ ldsm_test_device<<<1, 32>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data())); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < 64; ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("LDSM 2x ldsm_test_device SUCCESS\n"); -+ } -+ -+ // -+ // LDSM 4x (128b) -+ // -+ -+ { -+ thrust::device_vector d_out(count); -+ ldsm_test_device<<<1, 32>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data())); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < 128; ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("LDSM 4x ldsm_test_device SUCCESS\n"); -+ } -+ -+ // -+ // CuTe LDSM -+ // -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout>, -+ Stride< _2,Stride<_1,_64>>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom{}, -+ Layout>{}, -+ Layout>{}); -+ -+ ldsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x8 interleaved U32x1_LDSM_N SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout>, -+ Stride< _2,Stride<_1,_64>>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom{}, -+ Layout>{}, -+ Layout>{}); -+ -+ ldsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x8 interleaved U32x2_LDSM_N SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout>, -+ Stride< _2,Stride<_1,_64>>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom{}, -+ Layout>{}, -+ Layout>{}); -+ -+ ldsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x8 interleaved U32x4_LDSM_N SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout>, -+ Stride< _2,Stride<_1,_64>>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom, uint16_t>{}, -+ Layout>{}, -+ Layout>{}); -+ -+ ldsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i] , h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x8 interleaved LDS.U16 SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout, -+ Stride< _1,_32>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom{}, -+ Layout>{}, -+ Layout>{}); -+ -+ ldsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x32 U32x1_LDSM_N SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout, -+ Stride< _1,_32>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom{}, -+ Layout>{}, -+ Layout>{}); -+ -+ ldsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x32 U32x2_LDSM_N SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout, -+ Stride< _1,_32>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom{}, -+ Layout>{}, -+ Layout>{}); -+ -+ ldsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x32 U32x4_LDSM_N SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout, -+ Stride< _1,_32>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom, uint16_t>{}, -+ Layout>{}, -+ Layout>{}); -+ -+ ldsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x32 LDS.U16 SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout, -+ Stride<_32, _1>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom{}, -+ Layout>{}, -+ Layout>{}); -+ -+ ldsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x32 U16x2_LDSM_T SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout, -+ Stride<_32, _1>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom{}, -+ Layout>{}, -+ Layout>{}); -+ -+ ldsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x32 U16x4_LDSM_T SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout, -+ Stride<_32, _1>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom{}, -+ Layout>{}, -+ Layout>{}); -+ -+ ldsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x32 U16x8_LDSM_T SUCCESS\n"); -+ } -+ -+ CUTLASS_TRACE_HOST("PASS"); -+} -diff --git a/3rdparty/cutlass/test/unit/cute/hopper/stsm.cu b/3rdparty/cutlass/test/unit/cute/hopper/stsm.cu -new file mode 100644 -index 0000000..ffc8aa7 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/cute/hopper/stsm.cu -@@ -0,0 +1,426 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#include "cutlass_unit_test.h" -+ -+#include -+ -+#include -+#include -+ -+#include -+#include -+ -+using namespace cute; -+ -+template -+__global__ void -+stsm_test_device(uint16_t* g_in, uint16_t* g_out) -+{ -+ constexpr int count = sizeof(T) / 4; -+ int tid = threadIdx.x; -+ int stride = blockDim.x; -+ -+ // load input gmem -> rmem -+ uint32_t reg[count]; -+ for (int i = 0; i < (sizeof(T) / 4); i++) { -+ reg[i] = reinterpret_cast(g_in)[tid + (stride * i)]; -+ } -+ -+ __shared__ uint32_t smem[32 * count]; -+ -+ // load rmem -> smem using STSM -+ uint128_t* smem_ptr = reinterpret_cast(smem) + tid; -+ T* rmem_ptr = reinterpret_cast(reg); -+ cute::copy_stsm(rmem_ptr, smem_ptr); -+ -+ __syncthreads(); -+ -+ // store output smem -> gmem -+ for (int i = 0; i < (sizeof(T) / 4); i++) { -+ reinterpret_cast(g_out)[tid + (stride * i)] = smem[tid + (stride * i)]; -+ } -+} -+ -+template -+__global__ void -+stsm_test_device_cute(uint16_t* g_in, uint16_t* g_out, -+ TiledCopy tiled_copy, SmemLayout smem_layout) -+{ -+ using namespace cute; -+ -+ __shared__ uint16_t smem[size(smem_layout)]; -+ -+ Tensor t_g_in = make_tensor(make_gmem_ptr(g_in), smem_layout); -+ Tensor t_g_out = make_tensor(make_gmem_ptr(g_out), smem_layout); -+ Tensor t_smem = make_tensor(make_smem_ptr(smem), smem_layout); -+ -+ int tid = threadIdx.x; -+ -+ auto thr_copy = tiled_copy.get_thread_slice(tid); -+ -+ Tensor tXgX = thr_copy.partition_S(t_g_in); // (V,M,N) -+ Tensor tXsX = thr_copy.partition_D(t_smem); // (V,M,N) -+ -+ Tensor tXrX = make_tensor(shape(tXgX)); // (V,M,N) -+ clear(tXrX); // Just to make sure -+ -+/* -+ if (thread0()) { -+ print("tXsX: " ); print(tXsX.layout()); print("\n"); -+ print("tXgX: " ); print(tXgX.layout()); print("\n"); -+ print("tXrX: " ); print(tXrX.layout()); print("\n"); -+ } -+*/ -+ -+ // Load input gmem -> rmem -+ copy(tXgX, tXrX); -+ -+ // Copy rmem -> smem via tiled_copy (STSM, STS) -+ copy(tiled_copy, tXrX, tXsX); -+ -+ // Output smem -> gmem -+ for (int i = tid; i < size(t_smem); i += size(tiled_copy)) { -+ t_g_out(i) = t_smem(i); -+ } -+} -+ -+#if CUDA_12_0_SM90_FEATURES_SUPPORTED -+TEST(SM90_CuTe_Hopper, Stsm) -+{ -+ constexpr int count = 1024; -+ -+ thrust::host_vector h_in(count); -+ for (int i = 0; i < count; ++i) { -+ h_in[i] = uint16_t(i); -+ } -+ thrust::device_vector d_in = h_in; -+ -+ // -+ // STSM 1x (32b) -+ // -+ -+ { -+ thrust::device_vector d_out(count); -+ stsm_test_device<<<1, 32>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data())); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < 32; ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("STSM 1x stsm_test_device SUCCESS\n"); -+ } -+ -+ // -+ // STSM 2x (64b) -+ // -+ -+ { -+ thrust::device_vector d_out(count); -+ stsm_test_device<<<1, 32>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data())); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < 64; ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("STSM 2x stsm_test_device SUCCESS\n"); -+ } -+ -+ // -+ // STSM 4x (128b) -+ // -+ -+ { -+ thrust::device_vector d_out(count); -+ stsm_test_device<<<1, 32>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data())); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < 128; ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("STSM 4x stsm_test_device SUCCESS\n"); -+ } -+ -+ // -+ // CuTe STSM -+ // -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout>, -+ Stride< _2,Stride<_1,_64>>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom{}, -+ Layout>{}, -+ Layout>{}); -+ -+ stsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x8 interleaved U32x1_STSM_N SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout>, -+ Stride< _2,Stride<_1,_64>>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom{}, -+ Layout>{}, -+ Layout>{}); -+ -+ stsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x8 interleaved U32x2_STSM_N SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout>, -+ Stride< _2,Stride<_1,_64>>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom{}, -+ Layout>{}, -+ Layout>{}); -+ -+ stsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x8 interleaved U32x4_STSM_N SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout>, -+ Stride< _2,Stride<_1,_64>>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom, uint16_t>{}, -+ Layout>{}, -+ Layout>{}); -+ -+ stsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x8 interleaved STS.U16 SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout, -+ Stride< _1,_32>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom{}, -+ Layout>{}, -+ Layout>{}); -+ -+ stsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x32 U32x1_STSM_N SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout, -+ Stride< _1,_32>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom{}, -+ Layout>{}, -+ Layout>{}); -+ -+ stsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x32 U32x2_STSM_N SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout, -+ Stride< _1,_32>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom{}, -+ Layout>{}, -+ Layout>{}); -+ -+ stsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x32 U32x4_STSM_N SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout, -+ Stride< _1,_32>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom, uint16_t>{}, -+ Layout>{}, -+ Layout>{}); -+ -+ stsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x32 STS.U16 SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout, -+ Stride<_32, _1>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom{}, -+ Layout>{}, -+ Layout>{}); -+ -+ stsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x32 U16x2_STSM_T SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout, -+ Stride<_32, _1>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom{}, -+ Layout>{}, -+ Layout>{}); -+ -+ stsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x32 U16x4_STSM_T SUCCESS\n"); -+ } -+ -+ { -+ thrust::device_vector d_out(count); -+ -+ auto smem_layout = Layout, -+ Stride<_32, _1>>{}; -+ auto tiled_copy = make_tiled_copy(Copy_Atom{}, -+ Layout>{}, -+ Layout>{}); -+ -+ stsm_test_device_cute<<<1, int(size(tiled_copy))>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tiled_copy, -+ smem_layout); -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe 32x32 U16x8_STSM_T SUCCESS\n"); -+ } -+ -+ CUTLASS_TRACE_HOST("PASS"); -+} -+#endif -diff --git a/3rdparty/cutlass/test/unit/cute/hopper/tma_load.cu b/3rdparty/cutlass/test/unit/cute/hopper/tma_load.cu -new file mode 100644 -index 0000000..24f17fc ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/cute/hopper/tma_load.cu -@@ -0,0 +1,495 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#include "cutlass_unit_test.h" -+ -+#include -+ -+#include -+#include -+ -+#include -+ -+using namespace cute; -+ -+template -+struct SharedStorage -+{ -+ cute::array_aligned> smem; -+ cute::uint64_t tma_load_mbar[1]; -+}; -+ -+// __grid_constant__ was introduced in CUDA 11.7. -+#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 7))) -+# define CUTE_GRID_CONSTANT_SUPPORTED -+#endif -+ -+// __grid_constant__ can be enabled only on SM70+ -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)) -+# define CUTE_GRID_CONSTANT_ENABLED -+#endif -+ -+#if ! defined(CUTE_GRID_CONSTANT) -+# if defined(CUTE_GRID_CONSTANT_SUPPORTED) && defined(CUTE_GRID_CONSTANT_ENABLED) -+# define CUTE_GRID_CONSTANT __grid_constant__ -+# else -+# define CUTE_GRID_CONSTANT -+# endif -+#endif -+ -+#if CUDA_12_0_SM90_FEATURES_SUPPORTED -+template -+__global__ void -+tma_test_device_cute(T const* g_in, T* g_out, -+ CUTE_GRID_CONSTANT TiledCopy const tma, -+ GmemLayout gmem_layout, SmemLayout smem_layout) -+{ -+ assert(product_each(shape(gmem_layout)) == product_each(smem_layout.shape())); -+ -+ // Use Shared Storage structure to allocate and distribute aligned SMEM addresses -+ extern __shared__ char shared_memory[]; -+ using SharedStorage = SharedStorage; -+ SharedStorage& shared_storage = *reinterpret_cast(shared_memory); -+ -+ // Shared memory barriers use 64bits in SMEM for synchronization -+ uint64_t* tma_load_mbar = shared_storage.tma_load_mbar; -+ // Construct SMEM tensor -+ Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.data()), smem_layout); -+ -+#if 0 -+ -+ // -+ // Read in trivially -+ // -+ -+ Tensor gA_in = make_tensor(make_gmem_ptr(g_in), gmem_layout); -+ -+ // Input gmem -> smem -+ for (int i = threadIdx.x; i < size(sA); i += blockDim.x) { -+ sA(i) = gA_in(i); -+ } -+ __syncthreads(); -+ -+#else -+ -+ // TMA requires special handling of strides to deal with coord codomain mapping -+ // Represent the full tensors -- get these from TMA -+ Tensor gA = tma.get_tma_tensor(shape(gmem_layout)); -+ -+ // -+ // Prepare the TMA_LOAD -+ // -+ -+ auto cta_tma = tma.get_slice(Int<0>{}); // CTA slice -+ -+ Tensor tAgA = cta_tma.partition_S(gA); // (TMA,TMA_M,TMA_N) -+ Tensor tAsA = cta_tma.partition_D(sA); // (TMA,TMA_M,TMA_N) -+ -+#if 0 -+ if (thread0()) { -+ print(" gA: "); print(gA.data()); print(" o "); print(gA.layout()); print("\n"); -+ print("tAgA: "); print(tAgA.data()); print(" o "); print(tAgA.layout()); print("\n"); -+ print(" sA: "); print(sA.data()); print(" o "); print(sA.layout()); print("\n"); -+ print("tAsA: "); print(tAsA.data()); print(" o "); print(tAsA.layout()); print("\n"); -+ } -+#endif -+ -+ // -+ // Perform the TMA_LOAD -+ // -+ -+ // Group the TMA_M and TMA_N modes -+ Tensor tAgA_2 = group_modes<1,rank(tAgA)>(tAgA); // (TMA,Rest) -+ Tensor tAsA_TR = group_modes<1,rank(tAsA)>(tAsA); // (TMA,Rest) -+ static_assert(size<1>(tAsA_TR) == 1); -+ Tensor tAsA_2 = tAsA_TR(_,0); -+ -+ // Loop over the TMA stages, using smem as our buffer -+ for (int stage = 0; stage < size<1>(tAgA_2); ++stage) -+ { -+ // Set the bytes transferred in this TMA transaction (may involve multiple issues) -+ constexpr int kTmaTransactionBytes = size(sA) * sizeof(T); -+ -+ if (threadIdx.x == 0) -+ { -+ /// Initialize shared memory barrier -+ tma_load_mbar[0] = 0; -+ cute::initialize_barrier(tma_load_mbar[0], 1 /*numThreads*/); -+ cute::set_barrier_transaction_bytes(tma_load_mbar[0], kTmaTransactionBytes); -+ -+ copy(tma.with(tma_load_mbar[0]), tAgA_2(_,stage), tAsA_2); -+ } -+ __syncthreads(); -+ -+ /// Wait on the shared memory barrier until the phase bit flips from kPhaseBit value -+ constexpr int kPhaseBit = 0; -+ cute::wait_barrier(tma_load_mbar[0], kPhaseBit); -+ -+ #endif -+ -+ // -+ // Write out trivially -+ // -+ -+ Tensor gA_out = make_tensor(make_gmem_ptr(g_out), gmem_layout); -+ // Do the same slicing and grouping as sA -+ Tensor tAgA_out = cta_tma.partition_D(gA_out); // (TMA,TMA_M,TMA_N) -+ Tensor tAgA_2_out = group_modes<1,rank(tAgA_out)>(tAgA_out); // (TMA,Rest) -+ -+ // Output smem -> gmem -+ for (int i = threadIdx.x; i < size(tAsA_2); i += blockDim.x) { -+ tAgA_2_out(i,stage) = tAsA_2(i); -+ } -+ __syncthreads(); -+ } -+} -+ -+TEST(SM90_CuTe_Hopper, Tma_load_32x32_Col) -+{ -+ using T = half_t; -+ Layout smem_layout = Layout, Stride<_1,_32>>{}; -+ Layout gmem_layout = smem_layout; -+ -+ thrust::host_vector h_in(size(gmem_layout)); -+ for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } -+ thrust::device_vector d_in = h_in; -+ thrust::device_vector d_out(h_in.size(), T(-1)); -+ -+ Tensor gA = make_tensor(d_in.data().get(), gmem_layout); -+ auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout); -+ //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ tma_test_device_cute<<<1, 128, smem_size>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tma, -+ gmem_layout, -+ smem_layout); -+ -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe TMA_LOAD 32x32 ColMajor SUCCESS\n"); -+} -+ -+TEST(SM90_CuTe_Hopper, Tma_load_32x32_Row) -+{ -+ using T = half_t; -+ Layout smem_layout = Layout, Stride<_32,_1>>{}; -+ Layout gmem_layout = smem_layout; -+ -+ thrust::host_vector h_in(size(gmem_layout)); -+ for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } -+ thrust::device_vector d_in = h_in; -+ thrust::device_vector d_out(h_in.size(), T(-1)); -+ -+ Tensor gA = make_tensor(d_in.data().get(), gmem_layout); -+ auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout); -+ //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ tma_test_device_cute<<<1, 128, smem_size>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tma, -+ gmem_layout, -+ smem_layout); -+ -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe TMA_LOAD 32x32 RowMajor SUCCESS\n"); -+} -+ -+TEST(SM90_CuTe_Hopper, Tma_load_GMMA_SW128_MN) -+{ -+ using T = half_t; -+ auto smem_layout = GMMA::Layout_MN_SW128_Atom{}; -+ Layout gmem_layout = make_layout(make_shape(size<0>(smem_layout), size<1>(smem_layout)), GenColMajor{}); -+ -+ thrust::host_vector h_in(size(gmem_layout)); -+ for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } -+ thrust::device_vector d_in = h_in; -+ thrust::device_vector d_out(h_in.size(), T(-1)); -+ -+ Tensor gA = make_tensor(d_in.data().get(), gmem_layout); -+ auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout); -+ //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ tma_test_device_cute<<<1, 128, smem_size>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tma, -+ gmem_layout, -+ smem_layout); -+ -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe TMA_LOAD GMMA::Layout_MN_SW128_Atom SUCCESS\n"); -+} -+ -+TEST(SM90_CuTe_Hopper, Tma_load_GMMA_SW128_K) -+{ -+ using T = half_t; -+ auto smem_layout = GMMA::Layout_K_SW128_Atom{}; -+ Layout gmem_layout = make_layout(make_shape(size<0>(smem_layout), size<1>(smem_layout)), GenRowMajor{}); -+ -+ thrust::host_vector h_in(size(gmem_layout)); -+ for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } -+ thrust::device_vector d_in = h_in; -+ thrust::device_vector d_out(h_in.size(), T(-1)); -+ -+ Tensor gA = make_tensor(d_in.data().get(), gmem_layout); -+ auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout); -+ //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ tma_test_device_cute<<<1, 128, smem_size>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tma, -+ gmem_layout, -+ smem_layout); -+ -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe TMA_LOAD GMMA::Layout_K_SW128_Atom SUCCESS\n"); -+} -+ -+TEST(SM90_CuTe_Hopper, Tma_load_GMMA_SW128_MN_Multi) -+{ -+ using T = half_t; -+ auto smem_layout = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, Shape,Int<128>>{}); -+ Layout gmem_layout = make_layout(make_shape(size<0>(smem_layout), size<1>(smem_layout)), GenColMajor{}); -+ -+ thrust::host_vector h_in(size(gmem_layout)); -+ for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } -+ thrust::device_vector d_in = h_in; -+ thrust::device_vector d_out(h_in.size(), T(-1)); -+ -+ Tensor gA = make_tensor(d_in.data().get(), gmem_layout); -+ auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout); -+ //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ tma_test_device_cute<<<1, 128, smem_size>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tma, -+ gmem_layout, -+ smem_layout); -+ -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe TMA_LOAD GMMA::Layout_MN_SW128_Atom Multi SUCCESS\n"); -+} -+ -+TEST(SM90_CuTe_Hopper, Tma_load_GMMA_SW128_MN_Multi2) -+{ -+ using T = half_t; -+ // Tile the GMMA::Layout atom in the K-mode first, then the M-mode to get a bigger box size -+ auto smem_layout = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, Shape,Int<128>>{}, Step<_2,_1>{}); -+ Layout gmem_layout = make_layout(make_shape(size<0>(smem_layout), size<1>(smem_layout)), GenColMajor{}); -+ -+ thrust::host_vector h_in(size(gmem_layout)); -+ for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } -+ thrust::device_vector d_in = h_in; -+ thrust::device_vector d_out(h_in.size(), T(-1)); -+ -+ Tensor gA = make_tensor(d_in.data().get(), gmem_layout); -+ auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout); -+ //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ tma_test_device_cute<<<1, 128, smem_size>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tma, -+ gmem_layout, -+ smem_layout); -+ -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe TMA_LOAD GMMA::Layout_MN_SW128_Atom Multi SUCCESS\n"); -+} -+ -+TEST(SM90_CuTe_Hopper, Tma_load_GMMA_SW128_MN_Multi_Dyn) -+{ -+ using T = half_t; -+ auto smem_layout = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, Shape,Int<128>>{}, Step<_2,_1>{}); -+ Layout gmem_layout = make_layout(make_shape(128, 128), GenColMajor{}); -+ -+ thrust::host_vector h_in(size(gmem_layout)); -+ for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } -+ thrust::device_vector d_in = h_in; -+ thrust::device_vector d_out(h_in.size(), T(-1)); -+ -+ Tensor gA = make_tensor(d_in.data().get(), gmem_layout); -+ auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout); -+ //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ tma_test_device_cute<<<1, 128, smem_size>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tma, -+ gmem_layout, -+ smem_layout); -+ -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe TMA_LOAD GMMA::Layout_MN_SW128_Atom Multi SUCCESS\n"); -+} -+ -+TEST(SM90_CuTe_Hopper, Tma_load_32x32_Multimode) -+{ -+ using T = half_t; -+ auto smem_layout = Layout, Stride<_32,_1>>{}; -+ Layout gmem_layout = make_layout(make_shape(make_shape(8,4), 32), GenRowMajor{}); -+ -+ //auto smem_layout = Layout>{}; -+ //Layout gmem_layout = make_layout(make_shape(make_shape(8,4), 32), GenColMajor{}); -+ -+ thrust::host_vector h_in(size(gmem_layout)); -+ for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } -+ thrust::device_vector d_in = h_in; -+ thrust::device_vector d_out(h_in.size(), T(-1)); -+ -+ Tensor gA = make_tensor(d_in.data().get(), gmem_layout); -+ auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout); -+ //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ tma_test_device_cute<<<1, 128, smem_size>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tma, -+ gmem_layout, -+ smem_layout); -+ -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe TMA_LOAD GMMA::Layout_MN_SW128_Atom Multi SUCCESS\n"); -+} -+ -+TEST(SM90_CuTe_Hopper, Tma_load_Tensor_blocking) -+{ -+ using T = half_t; -+ auto gmem_layout = make_shape(make_shape(336,40),make_shape(32,656)); // GMEM -+ auto cta_tile = make_shape(make_shape(_16{},_8{}),make_shape(_32{},_2{})); // GMEM Tiling: -+ // Take 16-elem from m0, 8-elem from m1, -+ // Take 32-elem from k0, 2-elem from k1 -+ auto smem_layout = make_layout(cta_tile); // Col-Major SMEM -+ -+ thrust::host_vector h_in(size(gmem_layout)); -+ for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } -+ thrust::device_vector d_in = h_in; -+ thrust::device_vector d_out(h_in.size(), T(-1)); -+ -+ Tensor gA = make_tensor(d_in.data().get(), gmem_layout); -+ auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout, cta_tile, Int<1>{}); -+ //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ tma_test_device_cute<<<1, 128, smem_size>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tma, -+ gmem_layout, -+ smem_layout); -+ -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe TMA_LOAD Tensor blocking SUCCESS\n"); -+} -+ -+TEST(SM90_CuTe_Hopper, Tma_load_Tensor_blocking_2) -+{ -+ using T = half_t; -+ auto gmem_layout = make_shape(make_shape(32,40),make_shape(make_shape(8,8),656)); // GMEM -+ auto cta_tile = make_shape(_128{},make_shape(_32{},_2{})); // GMEM Tiling: -+ // Take 128-elem from m: m0 must divide 128, -+ // m-last may be predicated -+ // Take 32-elem from k0, 2-elem from k1 -+ auto smem_layout = make_layout(cta_tile); // Col-Major SMEM -+ -+ thrust::host_vector h_in(size(gmem_layout)); -+ for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } -+ thrust::device_vector d_in = h_in; -+ thrust::device_vector d_out(h_in.size(), T(-1)); -+ -+ Tensor gA = make_tensor(d_in.data().get(), gmem_layout); -+ auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout, cta_tile, Int<1>{}); -+ //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ tma_test_device_cute<<<1, 128, smem_size>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tma, -+ gmem_layout, -+ smem_layout); -+ -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe TMA_LOAD Tensor blocking 2 SUCCESS\n"); -+} -+#endif -diff --git a/3rdparty/cutlass/test/unit/cute/hopper/tma_store.cu b/3rdparty/cutlass/test/unit/cute/hopper/tma_store.cu -new file mode 100644 -index 0000000..448b7f9 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/cute/hopper/tma_store.cu -@@ -0,0 +1,384 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#include "cutlass_unit_test.h" -+ -+#include -+ -+#include -+#include -+ -+#include -+ -+using namespace cute; -+ -+template -+struct SharedStorage -+{ -+ cute::array_aligned> smem; -+}; -+ -+// __grid_constant__ was introduced in CUDA 11.7. -+#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 7))) -+# define CUTE_GRID_CONSTANT_SUPPORTED -+#endif -+ -+// __grid_constant__ can be enabled only on SM70+ -+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)) -+# define CUTE_GRID_CONSTANT_ENABLED -+#endif -+ -+#if ! defined(CUTE_GRID_CONSTANT) -+# if defined(CUTE_GRID_CONSTANT_SUPPORTED) && defined(CUTE_GRID_CONSTANT_ENABLED) -+# define CUTE_GRID_CONSTANT __grid_constant__ -+# else -+# define CUTE_GRID_CONSTANT -+# endif -+#endif -+ -+#if CUDA_12_0_SM90_FEATURES_SUPPORTED -+template -+__global__ void -+tma_test_device_cute(T const* g_in, T* g_out, -+ CUTE_GRID_CONSTANT TiledCopy const tma, -+ GmemLayout gmem_layout, SmemLayout smem_layout) -+{ -+ // Use Shared Storage structure to allocate and distribute aligned SMEM addresses -+ extern __shared__ char shared_memory[]; -+ using SharedStorage = SharedStorage; -+ SharedStorage& shared_storage = *reinterpret_cast(shared_memory); -+ // Construct SMEM tensor -+ Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.data()), smem_layout); -+ -+ // -+ // Read in trivially -+ // -+ -+ Tensor gA_in = make_tensor(make_gmem_ptr(g_in), gmem_layout); -+ -+ // Input gmem -> smem -+ for (int i = threadIdx.x; i < size(sA); i += blockDim.x) { -+ sA(i) = gA_in(i); -+ } -+ -+ __syncthreads(); -+ -+#if 0 -+ -+ // -+ // Write out trivially -+ // -+ -+ Tensor gA_out = make_tensor(make_gmem_ptr(g_out), gmem_layout); -+ -+ // Output smem -> gmem -+ for (int i = threadIdx.x; i < size(sA); i += blockDim.x) { -+ gA_out(i) = sA(i); -+ } -+ -+#else -+ -+ // TMA requires special handling of strides to deal with coord codomain mapping -+ // Represent the full tensors -- get these from TMA -+ Tensor gA = tma.get_tma_tensor(shape(gmem_layout)); -+ -+ // -+ // Prepare the TMA_STORE -+ // -+ -+ auto cta_tma = tma.get_slice(Int<0>{}); // CTA slice -+ -+ Tensor tAsA = cta_tma.partition_S(sA); -+ Tensor tAgA = cta_tma.partition_D(gA); -+ -+ // -+ // Perform the TMA_STORE -+ // -+ -+ if (threadIdx.x == 0) { -+ copy(tma, tAsA, tAgA); -+ } -+ -+#endif -+} -+ -+TEST(SM90_CuTe_Hopper, Tma_Store_32x32_Col) -+{ -+ using T = half_t; -+ Layout smem_layout = Layout, Stride<_1,_32>>{}; -+ Layout gmem_layout = smem_layout; -+ -+ thrust::host_vector h_in(size(smem_layout)); -+ for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } -+ thrust::device_vector d_in = h_in; -+ thrust::device_vector d_out(h_in.size(), T(-1)); -+ -+ Tensor gA = make_tensor(d_out.data().get(), gmem_layout); -+ auto tma = make_tma_copy(SM90_TMA_STORE{}, gA, smem_layout); -+ //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ tma_test_device_cute<<<1, 128, smem_size>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tma, -+ gmem_layout, -+ smem_layout); -+ -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe TMA_STORE 32x32 ColMajor SUCCESS\n"); -+} -+ -+TEST(SM90_CuTe_Hopper, Tma_Store_32x32_Row) -+{ -+ using T = half_t; -+ Layout smem_layout = Layout, Stride<_32,_1>>{}; -+ Layout gmem_layout = smem_layout; -+ -+ thrust::host_vector h_in(size(smem_layout)); -+ for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } -+ thrust::device_vector d_in = h_in; -+ thrust::device_vector d_out(h_in.size(), T(-1)); -+ -+ Tensor gA = make_tensor(d_out.data().get(), gmem_layout); -+ auto tma = make_tma_copy(SM90_TMA_STORE{}, gA, smem_layout); -+ //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ tma_test_device_cute<<<1, 128, smem_size>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tma, -+ gmem_layout, -+ smem_layout); -+ -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe TMA_STORE 32x32 RowMajor SUCCESS\n"); -+} -+ -+TEST(SM90_CuTe_Hopper, Tma_Store_GMMA_SW128_MN) -+{ -+ using T = half_t; -+ auto smem_layout = GMMA::Layout_MN_SW128_Atom{}; -+ Layout gmem_layout = make_layout(make_shape(size<0>(smem_layout), size<1>(smem_layout)), GenColMajor{}); -+ -+ thrust::host_vector h_in(size(smem_layout)); -+ for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } -+ thrust::device_vector d_in = h_in; -+ thrust::device_vector d_out(h_in.size(), T(-1)); -+ -+ Tensor gA = make_tensor(d_out.data().get(), gmem_layout); -+ auto tma = make_tma_copy(SM90_TMA_STORE{}, gA, smem_layout); -+ //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ tma_test_device_cute<<<1, 128, smem_size>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tma, -+ gmem_layout, -+ smem_layout); -+ -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe TMA_STORE GMMA::Layout_MN_SW128_Atom SUCCESS\n"); -+} -+ -+TEST(SM90_CuTe_Hopper, Tma_Store_GMMA_SW128_K) -+{ -+ using T = half_t; -+ auto smem_layout = GMMA::Layout_K_SW128_Atom{}; -+ Layout gmem_layout = make_layout(make_shape(size<0>(smem_layout), size<1>(smem_layout)), GenRowMajor{}); -+ -+ thrust::host_vector h_in(size(smem_layout)); -+ for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } -+ thrust::device_vector d_in = h_in; -+ thrust::device_vector d_out(h_in.size(), T(-1)); -+ -+ Tensor gA = make_tensor(d_out.data().get(), gmem_layout); -+ auto tma = make_tma_copy(SM90_TMA_STORE{}, gA, smem_layout); -+ //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ tma_test_device_cute<<<1, 128, smem_size>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tma, -+ gmem_layout, -+ smem_layout); -+ -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe TMA_STORE GMMA::Layout_K_SW128_Atom SUCCESS\n"); -+} -+ -+TEST(SM90_CuTe_Hopper, Tma_Store_GMMA_SW128_MN_Multi) -+{ -+ using T = half_t; -+ auto smem_layout = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, Shape,Int<128>>{}); -+ Layout gmem_layout = make_layout(make_shape(size<0>(smem_layout), size<1>(smem_layout)), GenColMajor{}); -+ -+ thrust::host_vector h_in(size(smem_layout)); -+ for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } -+ thrust::device_vector d_in = h_in; -+ thrust::device_vector d_out(h_in.size(), T(-1)); -+ -+ Tensor gA = make_tensor(d_out.data().get(), gmem_layout); -+ auto tma = make_tma_copy(SM90_TMA_STORE{}, gA, smem_layout); -+ //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ tma_test_device_cute<<<1, 128, smem_size>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tma, -+ gmem_layout, -+ smem_layout); -+ -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe TMA_STORE GMMA::Layout_MN_SW128_Atom Multi SUCCESS\n"); -+} -+ -+TEST(SM90_CuTe_Hopper, Tma_Store_GMMA_SW128_MN_Multi2) -+{ -+ using T = half_t; -+ // Tile the GMMA::Layout atom in the K-mode first, then the M-mode to get a bigger box size -+ auto smem_layout = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, Shape,Int<128>>{}, Step<_2,_1>{}); -+ Layout gmem_layout = make_layout(make_shape(size<0>(smem_layout), size<1>(smem_layout)), GenColMajor{}); -+ -+ thrust::host_vector h_in(size(smem_layout)); -+ for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } -+ thrust::device_vector d_in = h_in; -+ thrust::device_vector d_out(h_in.size(), T(-1)); -+ -+ Tensor gA = make_tensor(d_out.data().get(), gmem_layout); -+ auto tma = make_tma_copy(SM90_TMA_STORE{}, gA, smem_layout); -+ //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ tma_test_device_cute<<<1, 128, smem_size>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tma, -+ gmem_layout, -+ smem_layout); -+ -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe TMA_STORE GMMA::Layout_MN_SW128_Atom Multi SUCCESS\n"); -+} -+ -+TEST(SM90_CuTe_Hopper, Tma_Store_GMMA_SW128_MN_Multi_Dyn) -+{ -+ using T = half_t; -+ auto smem_layout = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, Shape,Int<128>>{}, Step<_2,_1>{}); -+ Layout gmem_layout = make_layout(make_shape(128, 128), GenColMajor{}); -+ -+ thrust::host_vector h_in(size(smem_layout)); -+ for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } -+ thrust::device_vector d_in = h_in; -+ thrust::device_vector d_out(h_in.size(), T(-1)); -+ -+ Tensor gA = make_tensor(d_out.data().get(), gmem_layout); -+ auto tma = make_tma_copy(SM90_TMA_STORE{}, gA, smem_layout); -+ //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ tma_test_device_cute<<<1, 128, smem_size>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tma, -+ gmem_layout, -+ smem_layout); -+ -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe TMA_STORE GMMA::Layout_MN_SW128_Atom Multi SUCCESS\n"); -+} -+ -+TEST(SM90_CuTe_Hopper, Tma_Store_32x32_Multimode) -+{ -+ using T = half_t; -+ auto smem_layout = Layout, Stride<_32,_1>>{}; -+ Layout gmem_layout = make_layout(make_shape(make_shape(8,4), 32), GenRowMajor{}); -+ -+ //auto smem_layout = Layout>{}; -+ //Layout gmem_layout = make_layout(make_shape(make_shape(8,4), 32), GenColMajor{}); -+ -+ thrust::host_vector h_in(size(smem_layout)); -+ for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } -+ thrust::device_vector d_in = h_in; -+ thrust::device_vector d_out(h_in.size(), T(-1)); -+ -+ Tensor gA = make_tensor(d_out.data().get(), gmem_layout); -+ auto tma = make_tma_copy(SM90_TMA_STORE{}, gA, smem_layout); -+ //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ tma_test_device_cute<<<1, 128, smem_size>>>( -+ thrust::raw_pointer_cast(d_in.data()), -+ thrust::raw_pointer_cast(d_out.data()), -+ tma, -+ gmem_layout, -+ smem_layout); -+ -+ thrust::host_vector h_out = d_out; -+ for (int i = 0; i < size(smem_layout); ++i) { -+ //printf("%d %d\n", int(h_in[i]), int(h_out[i])); -+ EXPECT_EQ(h_out[i], h_in[i]); -+ } -+ CUTLASS_TRACE_HOST("CuTe TMA_STORE GMMA::Layout_MN_SW128_Atom Multi SUCCESS\n"); -+} -+#endif -diff --git a/3rdparty/cutlass/test/unit/cute/layout/layout_operator.cu b/3rdparty/cutlass/test/unit/cute/layout/layout_operator.cu -new file mode 100644 -index 0000000..6c44f5a ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/cute/layout/layout_operator.cu -@@ -0,0 +1,136 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Unit tests Generic CuTe Layouts -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/layout.h" -+#include "cutlass/matrix_coord.h" -+ -+// Cute includes -+#include -+#include -+ -+using namespace cutlass; -+using namespace cute; -+ -+namespace test { -+namespace layout { -+ -+template -+ struct Testbed { -+ -+ -+ Testbed() {} -+ -+ bool run() { -+ GenericLayout generic_layout; -+ Layout layout = Layout::packed({size<0>(generic_layout), size<1>(generic_layout)}); -+ -+ for (int m = 0; m < size<0>(generic_layout); m++) { -+ for (int n = 0; n < size<1>(generic_layout); n++) { -+ if (generic_layout(m, n) != layout({m, n})) return false; -+ } -+ } -+ -+ return true; -+ } -+ }; -+ -+} -+} -+ -+////////////////////////////////////////////////////////////////////////// -+// Test Generic CuTe Layouts -+////////////////////////////////////////////////////////////////////////// -+ -+/// Canonical Layouts -+ -+TEST(GenericLayout, ColumnMajor) { -+ using GenericLayout = cute::Layout, Stride<_1, _8>>; -+ using Layout = cutlass::layout::ColumnMajor; -+ -+ test::layout::Testbed testbed; -+ -+ EXPECT_TRUE(testbed.run()); -+} -+////////////////////////////////////////////////////////////////////////// -+ -+TEST(GenericLayout, RowMajor) { -+ using GenericLayout = cute::Layout, Stride<_4, _1>>; -+ using Layout = cutlass::layout::RowMajor; -+ -+ test::layout::Testbed testbed; -+ -+ EXPECT_TRUE(testbed.run()); -+} -+////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Swizzle Shared Memory layouts -+ -+TEST(GenericLayout, RowMajorTensorOpMultiplicandCrosswise) { -+ -+ using GenericLayout = decltype( -+ composition( -+ Swizzle<3,3,3>{}, -+ Layout, Stride<_64, _1>>{}) -+ ); -+ -+ using Layout = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ test::layout::Testbed testbed; -+ -+ EXPECT_TRUE(testbed.run()); -+} -+////////////////////////////////////////////////////////////////////////// -+ -+TEST(GenericLayout, ColumnMajorTensorOpMultiplicandCongruous) { -+ -+ using GenericLayout = decltype( -+ composition( -+ Swizzle<3,3,4>{}, -+ Layout>{}) -+ ); -+ -+ using Layout = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ -+ -+ test::layout::Testbed testbed; -+ -+ EXPECT_TRUE(testbed.run()); -+} -+////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/epilogue/thread/activation.cu b/3rdparty/cutlass/test/unit/epilogue/thread/activation.cu -new file mode 100644 -index 0000000..9241ea2 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/epilogue/thread/activation.cu -@@ -0,0 +1,321 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/layout/layout.h" -+#include "cutlass/epilogue/thread/activation.h" -+ -+#include "cutlass/util/host_tensor.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+__global__ void test_Epilogue_thread_activation(T *out, T *in) { -+ -+ cutlass::Array *vec_out = reinterpret_cast *>(out); -+ cutlass::Array *vec_in = reinterpret_cast *>(in); -+ -+ Func func; -+ vec_out[threadIdx.x] = func(vec_in[threadIdx.x]); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Reference -+// -+ -+static double GELU_golden_input[] = { -+ 1.587425827980, 1.157652974129, 0.750432848930, -0.965980410576, -+ -0.388184845448, 0.014422321692, 0.353164494038, 1.354383468628, -+ 0.167588576674, 0.272798538208, -0.377032428980, 1.923444747925, -+ 0.308164477348, -0.341318070889, 0.278338819742, -0.292668998241, -+ -1.051743745804, -0.814175724983, 0.112737402320, 1.262938618660, -+ -1.582363605499, 0.722016870975, 1.053453564644, -0.659764587879, -+ 0.734917521477, 0.091274201870, 0.604461073875, -0.219043627381, -+ -0.136795744300, 0.960650205612, -1.805408835411, 0.091029644012, -+ -1.023343324661, 0.147713735700, -0.499895423651, 1.351878166199, -+ -1.631091356277, -0.336171895266, -1.612408638000, 0.090832948685, -+ -0.658132910728, -0.326727777719, -1.986387014389, 0.787685871124, -+ -1.015677452087, -0.225094825029, 0.876752018929, 0.744826257229, -+ 0.870290279388, -0.757595360279, 1.510331749916, 0.750012576580, -+ 0.906444966793, -0.915759027004, 1.260277032852, -0.158465340734, -+ -0.109191477299, -0.817102134228, 0.391305118799, -0.524910449982, -+ 0.351349592209, 0.801979541779, 0.446691334248, -0.741077482700, -+ 1.205966711044, -0.910210072994, 0.945986449718, 0.784096539021, -+ 1.670521497726, 0.344931513071, -0.301411420107, 0.309870749712, -+ -0.879704594612, -1.951189517975, -0.805817663670, -0.661812782288, -+ -0.505914270878, -1.836273789406, -0.381845980883, -0.554707705975, -+ -0.375447630882, -0.516645610332, 0.509586095810, 1.087131023407, -+ 2.664817094803, -1.558295488358, -0.076461032033, -0.504621028900, -+ 1.327111959457, -1.819981694221, 1.350415468216, -2.074112653732, -+ 1.501431345940, -1.339013576508, 0.162817999721, -1.473457217216, -+ 0.357770472765, 0.188413277268, 1.601302266121, -0.653882205486, -+ 0.856162548065, 0.763102591038, -0.526283502579, 0.581961452961, -+ 0.089969776571, 1.968745589256, 0.545802056789, -1.168786048889, -+ 1.206663012505, -0.109096683562, -1.223938226700, 0.744599223137, -+ -1.779406785965, 0.766436159611, -0.579044401646, -1.002057313919, -+ -0.715845823288, -0.562508940697, 0.886768460274, 2.327786445618, -+ -0.148763969541, -0.918884515762, -0.367678701878, -1.105021238327, -+ -0.461237311363, 0.158228352666, -0.254040330648, 1.427477598190, -+ 0.277530491352, 0.046293262392, -0.535557329655, -1.486695051193, -+ -0.953706681728, -1.040495038033, -0.314667612314, 0.348172843456, -+ 0.522773325443, 0.025960063562, -0.482472360134, 1.993084549904, -+ -0.253064930439, -0.012146313675, -2.166327714920, 0.398040622473, -+ -0.022238900885, -0.443580865860, -0.898376941681, -0.571689844131, -+ 1.666979670525, -0.831176340580, -0.671057403088, 0.481970995665, -+ -1.096243023872, -1.493894338608, 0.596651911736, -0.229505166411, -+ 1.165976166725, 0.905094027519, 0.049716457725, -1.362933635712, -+ -0.366948783398, 1.461613893509, -0.718411505222, 0.895385026932, -+ -0.763122260571, 1.329716682434, 1.366570711136, -0.086544901133, -+ 0.059739742428, 0.940766513348, -0.272854357958, -1.738811373711, -+ -0.361239165068, 0.696977972984, 1.288442254066, 1.264815807343, -+ -0.573566436768, -1.141678214073, 0.081865988672, -0.886228799820, -+ -0.236933603883, 1.050115466118, -0.538952171803, 0.651773929596, -+ -0.220034509897, -1.198960781097, 1.247478365898, -0.053529661149, -+ 0.639809548855, 1.672434806824, 0.511088073254, -1.179364681244, -+ -0.730427742004, 0.157630980015, 0.389369845390, -0.925578773022, -+ -0.093250080943, -0.391062080860, 0.852983593941, 1.868778109550, -+ -1.198786258698, 0.604997038841, -1.482687234879, -2.469333171844, -+ 0.718807697296, -0.559609353542, 2.187228441238, -2.927527904510, -+ 0.148535788059, -0.097280368209, 0.674131810665, -1.137645959854, -+ 0.792729616165, -1.166317462921, -0.498791724443, 1.675866723061, -+ -0.137909621000, -0.653263568878, -2.281216144562, 0.296096831560, -+ 2.002410173416, 1.083609819412, 0.933580815792, -1.504760265350, -+ 2.185185909271, 0.286121010780, -1.035485863686, -0.216372340918, -+ -0.274334043264, -0.849510788918, -1.397169828415, -0.407644748688, -+ 0.159476816654, -0.170650705695, 0.335193097591, -0.156852483749, -+ 0.036168430001, 0.858105242252, -1.086121797562, 0.404813349247, -+ -0.481496721506, -0.389882832766, 0.020690204576, -0.772020936012, -+ -0.758921504021, 0.323482036591, 0.115715265274, -0.811228036880, -+ -0.882436633110, 0.176811277866, 1.678015947342, 0.379081040621, -+ -0.842976212502, 0.346952259541, -0.545828759670, 1.632800459862 -+}; -+ -+static double GELU_golden_output[] = { -+ 1.498199582100, 1.014679551125, 0.580462038517, -0.161344811320, -+ -0.135453075171, 0.007294139825, 0.225325092673, 1.235459089279, -+ 0.094946734607, 0.165724009275, -0.133120641112, 1.871103763580, -+ 0.191376730800, -0.125069886446, 0.169681981206, -0.112644664943, -+ -0.154036879539, -0.169163048267, 0.061428427696, 1.132469892502, -+ -0.089851818979, 0.552240371704, 0.899579226971, -0.168043658137, -+ 0.565008401871, 0.048956073821, 0.439583092928, -0.090532489121, -+ -0.060955654830, 0.798911273479, -0.064101703465, 0.048816055059, -+ -0.156645998359, 0.082529976964, -0.154254898429, 1.232632875443, -+ -0.083896033466, -0.123835846782, -0.086161509156, 0.048703473061, -+ -0.167972877622, -0.121522113681, -0.046670529991, 0.617986679077, -+ -0.157319813967, -0.092503339052, 0.709896743298, 0.574865520000, -+ 0.703132867813, -0.169963955879, 1.411436080933, 0.580042064190, -+ 0.741154611111, -0.164741978049, 1.129479527473, -0.069256491959, -+ -0.049848672003, -0.169087052345, 0.255214750767, -0.157380074263, -+ 0.223928079009, 0.632535398006, 0.300378054380, -0.169946283102, -+ 1.068588852882, -0.165071934462, 0.783203184605, 0.614346146584, -+ 1.591325283051, 0.219006344676, -0.115003645420, 0.192637458444, -+ -0.166712537408, -0.049788996577, -0.169361919165, -0.168130636215, -+ -0.155041679740, -0.060888241976, -0.134137839079, -0.160614117980, -+ -0.132782235742, -0.156389534473, 0.354075312614, 0.936574816704, -+ 2.654553413391, -0.092845752835, -0.035900454968, -0.154874503613, -+ 1.204704761505, -0.062572605908, 1.230982899666, -0.039479542524, -+ 1.401402950287, -0.120890334249, 0.091938301921, -0.103604510427, -+ 0.228880971670, 0.108285568655, 1.513783097267, -0.167782157660, -+ 0.688394129276, 0.593158841133, -0.157540664077, 0.418839782476, -+ 0.048209801316, 1.920528769493, 0.386099845171, -0.141709372401, -+ 1.069367766380, -0.049809500575, -0.135230198503, 0.574639260769, -+ -0.066881760955, 0.596510827541, -0.162873372436, -0.158483341336, -+ -0.169686436653, -0.161375194788, 0.720409095287, 2.304597616196, -+ -0.065585561097, -0.164551988244, -0.131098195910, -0.148708447814, -+ -0.148663327098, 0.089060656726, -0.101548098028, 1.317959904671, -+ 0.169103100896, 0.024001283571, -0.158595800400, -0.101909510791, -+ -0.162240833044, -0.155090972781, -0.118474565446, 0.221488356590, -+ 0.365645468235, 0.013248858973, -0.151851043105, 1.946992278099, -+ -0.101253561676, -0.006014300976, -0.032804865390, 0.260597169399, -+ -0.010922161862, -0.145792976022, -0.165743649006, -0.162226170301, -+ 1.587365984917, -0.168676435947, -0.168497130275, 0.330191940069, -+ -0.149622067809, -0.100989677012, 0.432351946831, -0.093922272325, -+ 1.023946166039, 0.739726305008, 0.025843897834, -0.117827951908, -+ -0.130937814713, 1.356489539146, -0.169726014137, 0.729478538036, -+ -0.169943705201, 1.207641005516, 1.249209761620, -0.040288090706, -+ 0.031292784959, 0.777626037598, -0.107090584934, -0.071350336075, -+ -0.129670530558, 0.527676224709, 1.161149263382, 1.134579420090, -+ -0.162394225597, -0.144757837057, 0.043603736907, -0.166386902332, -+ -0.096278958023, 0.895924389362, -0.158969298005, 0.484089732170, -+ -0.090857118368, -0.138206124306, 1.115107178688, -0.025622237474, -+ 0.472724437714, 1.593463659286, 0.355387806892, -0.140493586659, -+ -0.169871479273, 0.088687323034, 0.253673940897, -0.164135158062, -+ -0.043161027133, -0.136040985584, 0.685087263584, 1.811169505119, -+ -0.138226687908, 0.440080583096, -0.102422207594, -0.016713079065, -+ 0.549075841904, -0.161096408963, 2.155813455582, -0.005001218989, -+ 0.083037458360, -0.044870752841, 0.505522191525, -0.145202502608, -+ 0.623111069202, -0.141991063952, -0.154108211398, 1.597298502922, -+ -0.061391282827, -0.167753636837, -0.025704355910, 0.182520583272, -+ 1.957115054131, 0.932696640491, 0.769961357117, -0.099604383111, -+ 2.153636932373, 0.175279796124, -0.155551761389, -0.089653611183, -+ -0.107515335083, -0.168032020330, -0.113423995674, -0.139319628477, -+ 0.089841812849, -0.073763631284, 0.211594089866, -0.068651281297, -+ 0.018605981022, 0.690416753292, -0.150658726692, 0.266040354967, -+ -0.151710823178, -0.135800719261, 0.010515870526, -0.169883996248, -+ -0.169960290194, 0.202769815922, 0.063187584281, -0.169236257672, -+ -0.166577890515, 0.100812792778, 1.599699616432, 0.245525524020, -+ -0.168275654316, 0.220552831888, -0.159705042839, 1.549110531807 -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Epilogue_thread_gelu_taylor, device_f32) { -+ -+ int const kN = 256; -+ int const kV = 4; -+ -+ using Element = float; -+ using Func = cutlass::epilogue::thread::GELU_taylor>; -+ -+ double tolerance = 0.005; -+ -+ // -+ // Construct workspace -+ // -+ cutlass::HostTensor tensor_Destination({1, kN}); -+ cutlass::HostTensor tensor_Source({1, kN}); -+ -+ for (int i = 0; i < kN; ++i) { -+ tensor_Source.host_data(i) = Element(GELU_golden_input[i]); -+ } -+ -+ tensor_Destination.sync_device(); -+ tensor_Source.sync_device(); -+ -+ // -+ // Launch the kernel -+ // -+ dim3 grid(1,1,1); -+ dim3 block(kN / kV, 1, 1); -+ -+ test_Epilogue_thread_activation<<< grid, block >>>( -+ tensor_Destination.device_data(), -+ tensor_Source.device_data()); -+ -+ tensor_Destination.sync_host(); -+ -+ // -+ // Verify -+ // -+ -+ for (int i = 0; i < kN; ++i) { -+ Element input = Element(GELU_golden_input[i]); -+ Element got = tensor_Destination.host_data(i); -+ Element expected = Element(GELU_golden_output[i]); -+ -+ double rel_error = (double(got) - double(expected)) / double(expected); -+ -+ double tolerance_override = tolerance; -+ -+ switch (i) { -+ case 142: tolerance_override = 0.008; break; -+ case 203: tolerance_override = 0.03; break; -+ case 207: tolerance_override = 0.09; break; -+ case 218: tolerance_override = 0.013; break; -+ } -+ -+ EXPECT_LT(std::abs(rel_error), tolerance_override) -+ << "Input[" << i << "]: " << input << ", Got: " << got << ", expected: " << expected; -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Epilogue_thread_gelu_taylor, device_f16) { -+ -+ int const kN = 256; -+ int const kV = 8; -+ -+ using Element = cutlass::half_t; -+ using Func = cutlass::epilogue::thread::GELU_taylor>; -+ -+ double tolerance = 0.005; -+ -+ // -+ // Construct workspace -+ // -+ cutlass::HostTensor tensor_Destination({1, kN}); -+ cutlass::HostTensor tensor_Source({1, kN}); -+ -+ for (int i = 0; i < kN; ++i) { -+ tensor_Source.host_data(i) = Element(GELU_golden_input[i]); -+ } -+ -+ tensor_Destination.sync_device(); -+ tensor_Source.sync_device(); -+ -+ // -+ // Launch the kernel -+ // -+ dim3 grid(1,1,1); -+ dim3 block(kN / kV, 1, 1); -+ -+ test_Epilogue_thread_activation<<< grid, block >>>( -+ tensor_Destination.device_data(), -+ tensor_Source.device_data()); -+ -+ tensor_Destination.sync_host(); -+ -+ // -+ // Verify -+ // -+ -+ for (int i = 0; i < kN; ++i) { -+ Element input = Element(GELU_golden_input[i]); -+ Element got = tensor_Destination.host_data(i); -+ Element expected = Element(GELU_golden_output[i]); -+ -+ double rel_error = (double(got) - double(expected)) / double(expected); -+ -+ double tolerance_override = tolerance; -+ -+ switch (i) { -+ case 36: tolerance_override = 0.006; break; -+ case 77: tolerance_override = 0.009; break; -+ case 95: tolerance_override = 0.008; break; -+ case 112: tolerance_override = 0.007; break; -+ case 171: tolerance_override = 0.006; break; -+ case 203: tolerance_override = 0.03; break; -+ case 207: tolerance_override = 0.15; break; -+ } -+ -+ EXPECT_LT(std::abs(rel_error), tolerance_override) -+ << "Input[" << i << "]: " << input << ", Got: " << got << ", expected: " << expected; -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/epilogue/thread/linear_combination.cu b/3rdparty/cutlass/test/unit/epilogue/thread/linear_combination.cu -new file mode 100644 -index 0000000..548924e ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/epilogue/thread/linear_combination.cu -@@ -0,0 +1,205 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/epilogue/thread/linear_combination_gelu.h" -+#include "cutlass/epilogue/thread/activation.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Epilogue_thread_linear_combination, device_side_f16_f32_value) { -+ -+ using Element = float; -+ using ElementOutput = cutlass::half_t; -+ int const kCount = 8; -+ -+ using LinearCombination = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kCount, -+ Element, -+ Element>; -+ -+ Element alpha = Element(2); -+ Element beta = Element(1); -+ -+ typename LinearCombination::Params params(alpha, beta); -+ -+ LinearCombination linear_combination_op(params); -+ -+ cutlass::Array source; -+ cutlass::Array accum; -+ -+ for (int i = 0; i < kCount; ++i) { -+ accum[i] = Element(i * 2); -+ source[i] = ElementOutput((i * 7 % 9) - 4); -+ } -+ -+ cutlass::Array destination = linear_combination_op(accum, source); -+ -+ for (int i = 0; i < kCount; ++i) { -+ -+ ElementOutput expected = ElementOutput( -+ alpha * accum[i] + -+ beta * Element(ElementOutput(source[i])) -+ ); -+ -+ ElementOutput got = destination[i]; -+ -+ EXPECT_TRUE(expected == got); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Epilogue_thread_linear_combination, device_side_f16_f32_ptr) { -+ -+ using Element = float; -+ using ElementOutput = cutlass::half_t; -+ int const kCount = 8; -+ -+ using LinearCombination = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kCount, -+ Element, -+ Element>; -+ -+ Element alpha = Element(2); -+ Element beta = Element(1); -+ -+ typename LinearCombination::Params params(&alpha, &beta); -+ -+ LinearCombination linear_combination_op(params); -+ -+ cutlass::Array source; -+ cutlass::Array accum; -+ -+ for (int i = 0; i < kCount; ++i) { -+ accum[i] = Element(i * 2); -+ source[i] = ElementOutput((i * 7 % 9) - 4); -+ } -+ -+ cutlass::Array destination = linear_combination_op(accum, source); -+ -+ for (int i = 0; i < kCount; ++i) { -+ -+ ElementOutput expected = ElementOutput( -+ alpha * accum[i] + -+ beta * Element(ElementOutput(source[i])) -+ ); -+ -+ ElementOutput got = destination[i]; -+ -+ EXPECT_TRUE(expected == got); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+TEST(Epilogue_thread_linear_combination_gelu, device_side_f16_f16_ptr) { -+ -+ using Element = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ int const kCount = 8; -+ -+ using LinearCombinationGELU = cutlass::epilogue::thread::LinearCombinationGELU< -+ ElementOutput, -+ kCount, -+ Element, -+ Element>; -+ -+ Element alpha = Element(1); -+ Element beta = Element(0); -+ -+ typename LinearCombinationGELU::Params params(&alpha, &beta); -+ -+ LinearCombinationGELU linear_combination_op(params); -+ -+ cutlass::Array accum; -+ -+ for (int i = 0; i < kCount; ++i) { -+ accum[i] = Element((float)i * 0.3f); -+ } -+ -+ cutlass::Array destination = linear_combination_op(accum, accum); -+ cutlass::epilogue::thread::GELU gelu_func; -+ -+ for (int i = 0; i < kCount; ++i) { -+ ElementOutput expected = gelu_func(accum[i]); -+ ElementOutput got = destination[i]; -+ EXPECT_TRUE(expected == got); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Epilogue_thread_linear_combination_gelu_taylor, device_side_f16_f16_ptr) { -+ -+ using Element = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ int const kCount = 8; -+ -+ using LinearCombinationGELU = cutlass::epilogue::thread::LinearCombinationGELU< -+ ElementOutput, -+ kCount, -+ Element, -+ Element>; -+ -+ Element alpha = Element(1); -+ Element beta = Element(0); -+ -+ typename LinearCombinationGELU::Params params(&alpha, &beta); -+ -+ LinearCombinationGELU linear_combination_op(params); -+ -+ cutlass::Array accum; -+ -+ for (int i = 0; i < kCount; ++i) { -+ accum[i] = Element((float)i * 0.3f); -+ } -+ -+ cutlass::Array destination = linear_combination_op(accum, accum); -+ cutlass::epilogue::thread::GELU gelu_func; -+ -+ for (int i = 0; i < kCount; ++i) { -+ ElementOutput expected = gelu_func(accum[i]); -+ ElementOutput got = destination[i]; -+ EXPECT_TRUE(expected == got); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/epilogue/thread/linear_combination_planar_complex.cu b/3rdparty/cutlass/test/unit/epilogue/thread/linear_combination_planar_complex.cu -new file mode 100644 -index 0000000..cc027e0 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/epilogue/thread/linear_combination_planar_complex.cu -@@ -0,0 +1,286 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace epilogue { -+namespace thread { -+ -+using FunctorPlanarComplexF32F32 = cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ float, -+ 4, -+ float, -+ float>; -+ -+__global__ void epilogue_thread_functor_planar_complex_f32_f32( -+ float *output_ptr, -+ float const *accum_ptr, -+ float const *source_ptr, -+ typename FunctorPlanarComplexF32F32::Params params) { -+ -+ FunctorPlanarComplexF32F32 linear_combination_op(params); -+ -+ auto accum = *reinterpret_cast const *>(accum_ptr); -+ auto source = *reinterpret_cast const *>(source_ptr); -+ -+ *reinterpret_cast*>(output_ptr) = linear_combination_op(accum, source); -+} -+ -+} -+} -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Epilogue_thread_linear_combination_planar_complex, f32) { -+ -+ using Element = float; -+ using ElementOutput = float; -+ int const kCount = 4; -+ -+ using Functor = cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ ElementOutput, -+ kCount, -+ Element, -+ Element>; -+ -+ cutlass::complex alpha(Element(2), Element(1)); -+ cutlass::complex beta(Element(1), Element(-1)); -+ -+ typename Functor::Params params(alpha, beta); -+ -+ Functor linear_combination_op(params); -+ -+ cutlass::ArrayPlanarComplex source; -+ cutlass::ArrayPlanarComplex accum; -+ -+ // Define arbitrary inputs -+ for (int i = 0; i < kCount; ++i) { -+ accum.real[i] = Element(i * 2); -+ accum.imag[i] = Element((i * 3 % 6) - 3); -+ source.real[i] = ElementOutput((i * 7 % 9) - 4); -+ source.imag[i] = ElementOutput(((i * 5 + 2) % 9) - 4); -+ } -+ -+ cutlass::ArrayPlanarComplex destination = linear_combination_op(accum, source); -+ -+ // Verify each result -+ for (int i = 0; i < kCount; ++i) { -+ -+ cutlass::complex expected = alpha * cutlass::complex(accum.real[i], accum.imag[i]) + -+ beta * cutlass::complex(Element(source.real[i]), Element(source.imag[i])); -+ -+ cutlass::complex got(destination.real[i], destination.imag[i]); -+ -+ EXPECT_TRUE(ElementOutput(expected.real()) == got.real()); -+ EXPECT_TRUE(ElementOutput(expected.imag()) == got.imag()); -+ EXPECT_TRUE(expected.real() != Element(0) || expected.imag() != Element(0)); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace epilogue { -+namespace thread { -+ -+using FunctorPlanarComplexF16F32 = cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ cutlass::half_t, -+ 4, -+ float, -+ float>; -+ -+__global__ void epilogue_thread_functor_planar_complex_f16_f32( -+ cutlass::half_t *output_ptr, -+ float const *accum_ptr, -+ cutlass::half_t const *source_ptr, -+ typename FunctorPlanarComplexF16F32::Params params, -+ int N) { -+ -+ FunctorPlanarComplexF16F32 linear_combination_op(params); -+ -+ -+ auto accum = *reinterpret_cast const *>(accum_ptr); -+ auto source = *reinterpret_cast const *>(source_ptr); -+ -+ #pragma unroll 1 -+ for (int n = 0; n < N; ++n) { -+ source = linear_combination_op(accum, source); -+ } -+ -+ *reinterpret_cast*>(output_ptr) = source; -+} -+ -+} -+} -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Epilogue_thread_linear_combination_planar_complex, f16_f32) { -+ -+ using Element = float; -+ using ElementOutput = cutlass::half_t; -+ int const kCount = 4; -+ -+ using Functor = cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ ElementOutput, -+ kCount, -+ Element, -+ Element>; -+ -+ cutlass::complex alpha(Element(2), Element(1)); -+ cutlass::complex beta(Element(1), Element(-1)); -+ -+ typename Functor::Params params(alpha, beta); -+ -+ Functor linear_combination_op(params); -+ -+ cutlass::ArrayPlanarComplex source; -+ cutlass::ArrayPlanarComplex accum; -+ -+ // Define arbitrary inputs -+ for (int i = 0; i < kCount; ++i) { -+ accum.real[i] = Element(i * 2); -+ accum.imag[i] = Element((i * 3 % 6) - 3); -+ source.real[i] = ElementOutput((i * 7 % 9) - 4); -+ source.imag[i] = ElementOutput(((i * 5 + 2) % 9) - 4); -+ } -+ -+ cutlass::ArrayPlanarComplex destination = linear_combination_op(accum, source); -+ -+ // Verify each result -+ for (int i = 0; i < kCount; ++i) { -+ -+ cutlass::complex expected = alpha * cutlass::complex(accum.real[i], accum.imag[i]) + -+ beta * cutlass::complex(Element(source.real[i]), Element(source.imag[i])); -+ -+ cutlass::complex got(destination.real[i], destination.imag[i]); -+ -+ EXPECT_TRUE(ElementOutput(expected.real()) == got.real()); -+ EXPECT_TRUE(ElementOutput(expected.imag()) == got.imag()); -+ EXPECT_TRUE(expected.real() != Element(0) || expected.imag() != Element(0)); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace epilogue { -+namespace thread { -+ -+using FunctorPlanarComplexF16F16 = cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ cutlass::half_t, -+ 4, -+ cutlass::half_t, -+ cutlass::half_t>; -+ -+__global__ void epilogue_thread_functor_planar_complex_f16_f16( -+ cutlass::half_t *output_ptr, -+ cutlass::half_t const *accum_ptr, -+ cutlass::half_t const *source_ptr, -+ typename FunctorPlanarComplexF16F16::Params params, -+ int N) { -+ -+ FunctorPlanarComplexF16F16 linear_combination_op(params); -+ -+ auto accum = *reinterpret_cast const *>(accum_ptr); -+ auto source = *reinterpret_cast const *>(source_ptr); -+ -+ #pragma unroll 1 -+ for (int n = 0; n < N; ++n) { -+ source = linear_combination_op(accum, source); -+ } -+ -+ *reinterpret_cast*>(output_ptr) = source; -+} -+ -+} -+} -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Epilogue_thread_linear_combination_planar_complex, f16_f16) { -+ -+ using Element = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ int const kCount = 8; -+ -+ using Functor = cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ ElementOutput, -+ kCount, -+ Element, -+ Element>; -+ -+ cutlass::complex alpha(Element(2), Element(1)); -+ cutlass::complex beta(Element(1), Element(-1)); -+ -+ typename Functor::Params params(alpha, beta); -+ -+ Functor linear_combination_op(params); -+ -+ cutlass::ArrayPlanarComplex source; -+ cutlass::ArrayPlanarComplex accum; -+ -+ // Define arbitrary inputs -+ for (int i = 0; i < kCount; ++i) { -+ accum.real[i] = Element(i * 2); -+ accum.imag[i] = Element((i * 3 % 6) - 3); -+ source.real[i] = ElementOutput((i * 7 % 9) - 4); -+ source.imag[i] = ElementOutput(((i * 5 + 2) % 9) - 4); -+ } -+ -+ cutlass::ArrayPlanarComplex destination = linear_combination_op(accum, source); -+ -+ // Verify each result -+ for (int i = 0; i < kCount; ++i) { -+ -+ cutlass::complex expected = alpha * cutlass::complex(accum.real[i], accum.imag[i]) + -+ beta * cutlass::complex(Element(source.real[i]), Element(source.imag[i])); -+ -+ cutlass::complex got(destination.real[i], destination.imag[i]); -+ -+ EXPECT_TRUE(ElementOutput(expected.real()) == got.real()); -+ EXPECT_TRUE(ElementOutput(expected.imag()) == got.imag()); -+ EXPECT_TRUE(expected.real() != Element(0) || expected.imag() != Element(0)); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_planar_complex.cu b/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_planar_complex.cu -new file mode 100644 -index 0000000..341e009 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_planar_complex.cu -@@ -0,0 +1,510 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+ -+#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" -+ -+// Tensor Op -+#include "cutlass/gemm/warp/default_mma_tensor_op.h" -+ -+// Volta Tensor Op -+#include "cutlass/gemm/warp/mma_tensor_op_sm70.h" -+#include "cutlass/epilogue/warp/fragment_iterator_volta_tensor_op.h" -+ -+// Simt -+#include "cutlass/gemm/warp/mma_simt.h" -+#include "cutlass/gemm/warp/mma_simt_policy.h" -+ -+// Epilogue components -+ -+#include "cutlass/epilogue/threadblock/default_epilogue_planar_complex.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+#include "testbed_planar_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Epilogue_threadblock_epilogue, planar_complex_f32_f32_tensor_op_64x64_32x32x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, -+ InstructionShape, -+ Element, LayoutA, -+ Element, LayoutB, -+ ElementAccumulator, cutlass::layout::RowMajor -+ >::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpiloguePlanarComplex< -+ Shape, -+ WarpMmaTensorOp, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpiloguePlanarComplexTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Epilogue_threadblock_epilogue, planar_complex_f16_f32_tensor_op_64x64_32x32x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, -+ InstructionShape, -+ Element, LayoutA, -+ Element, LayoutB, -+ ElementAccumulator, cutlass::layout::RowMajor -+ >::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpiloguePlanarComplex< -+ Shape, -+ WarpMmaTensorOp, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpiloguePlanarComplexTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Epilogue_threadblock_epilogue, planar_complex_f16_f16_tensor_op_64x64_32x32x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, -+ InstructionShape, -+ Element, LayoutA, -+ Element, LayoutB, -+ ElementAccumulator, cutlass::layout::RowMajor -+ >::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpiloguePlanarComplex< -+ Shape, -+ WarpMmaTensorOp, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpiloguePlanarComplexTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Epilogue_threadblock_epilogue, planar_complex_f32_f32_volta_tensor_op_64x64_32x32x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using Element = cutlass::half_t; -+ -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementAccumulator, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpiloguePlanarComplex< -+ Shape, -+ WarpMmaTensorOp, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpiloguePlanarComplexTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Epilogue_threadblock_epilogue, planar_complex_simt_f32_64x64_32x32x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 1; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using Element = float; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpiloguePlanarComplex< -+ Shape, -+ WarpMmaSimt, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpiloguePlanarComplexTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Epilogue_threadblock_epilogue, planar_complex_simt_f64_64x64_16x32x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ using ElementCompute = double; -+ int const kElementsPerAccess = 1; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using Element = double; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpiloguePlanarComplex< -+ Shape, -+ WarpMmaSimt, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpiloguePlanarComplexTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_simt.cu b/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_simt.cu -new file mode 100644 -index 0000000..5bd1ddf ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_simt.cu -@@ -0,0 +1,1172 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/complex.h" -+#include "cutlass/quaternion.h" -+ -+#include "cutlass/gemm/warp/mma_simt.h" -+#include "cutlass/gemm/warp/mma_simt_policy.h" -+ -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Real-valued single precision tests -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Epilogue_threadblock_epilogue, simt_f32_32x64_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using Element = float; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Epilogue_threadblock_epilogue, simt_f32_32x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using Element = float; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Epilogue_threadblock_epilogue, simt_f32_64x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using Element = float; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Epilogue_threadblock_epilogue, simt_f32_128x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using Element = float; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Real-valued double precision tests -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Epilogue_threadblock_epilogue, simt_f64_32x64_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Element = double; -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ using ElementCompute = double; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM50_Epilogue_threadblock_epilogue, simt_f64_32x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Element = double; -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ using ElementCompute = double; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM50_Epilogue_threadblock_epilogue, simt_f64_64x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Element = double; -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ using ElementCompute = double; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM50_Epilogue_threadblock_epilogue, simt_f64_128x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Element = double; -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ using ElementCompute = double; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Complex-valued single-precision -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Epilogue_threadblock_epilogue, simt_complex_f32_32x64_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Element = cutlass::complex; -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Epilogue_threadblock_epilogue, simt_complex_f32_32x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Element = cutlass::complex; -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM50_Epilogue_threadblock_epilogue, simt_complex_f32_128x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Element = cutlass::complex; -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Complex-valued double-precision -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Epilogue_threadblock_epilogue, simt_complex_f64_32x64_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Element = cutlass::complex; -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<1, 1, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Epilogue_threadblock_epilogue, simt_complex_f64_32x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Element = cutlass::complex; -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<1, 1, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM50_Epilogue_threadblock_epilogue, simt_complex_f64_128x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Element = cutlass::complex; -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<1, 1, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Quaternion-valued single-precision -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Epilogue_threadblock_epilogue, simt_quaternion_f32_32x64_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Element = cutlass::Quaternion; -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -diff --git a/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_simt_sm60.cu b/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_simt_sm60.cu -new file mode 100644 -index 0000000..36ed25b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_simt_sm60.cu -@@ -0,0 +1,489 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+ -+#include "cutlass/gemm/warp/mma_simt.h" -+#include "cutlass/gemm/warp/mma_simt_policy.h" -+ -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Real-valued half precision tests -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM60_Epilogue_threadblock_epilogue, simt_f16_32x64_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Element = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM60_Epilogue_threadblock_epilogue, simt_f16_64x64_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Element = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<8, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM60_Epilogue_threadblock_epilogue, simt_f16_64x128_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Element = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<8, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM60_Epilogue_threadblock_epilogue, simt_f16_128x128_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Element = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<8, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM60_Epilogue_threadblock_epilogue, simt_f16_128x256_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Element = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<8, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM60_Epilogue_threadblock_epilogue, simt_f16_256x128_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Element = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<256, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using ElementOutput = Element; -+ using ElementAccumulator = Element; -+ using ElementCompute = Element; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<8, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_simt_sm61.cu b/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_simt_sm61.cu -new file mode 100644 -index 0000000..ff17915 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_simt_sm61.cu -@@ -0,0 +1,1120 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+ -+#include "cutlass/gemm/warp/mma_simt.h" -+#include "cutlass/gemm/warp/mma_simt_policy.h" -+ -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/epilogue/thread/linear_combination_clamp.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Real-valued Integer tests -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM61_Epilogue_threadblock_epilogue, simt_i32_32x64_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int; -+ using ElementOutput = int; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM61_Epilogue_threadblock_epilogue, simt_i32_32x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int; -+ using ElementOutput = int; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM61_Epilogue_threadblock_epilogue, simt_i32_64x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int; -+ using ElementOutput = int; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM61_Epilogue_threadblock_epilogue, simt_i32_128x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int; -+ using ElementOutput = int; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM61_Epilogue_threadblock_epilogue, simt_i32_128x64_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int; -+ using ElementOutput = int; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Real-valued Integer - single-precision float output -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM61_Epilogue_threadblock_epilogue, simt_f32_i32_32x64_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int; -+ using ElementOutput = float; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM61_Epilogue_threadblock_epilogue, simt_f32_i32_32x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int; -+ using ElementOutput = float; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM61_Epilogue_threadblock_epilogue, simt_f32_i32_64x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int; -+ using ElementOutput = float; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM61_Epilogue_threadblock_epilogue, simt_f32_i32_128x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int; -+ using ElementOutput = float; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM61_Epilogue_threadblock_epilogue, simt_f32_i32_128x64_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int; -+ using ElementOutput = float; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Real-valued Integer tests - mixed-precision with clamping -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM61_Epilogue_threadblock_epilogue, simt_i8_i32_32x64_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int; -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM61_Epilogue_threadblock_epilogue, simt_i8_i32_32x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int; -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM61_Epilogue_threadblock_epilogue, simt_i8_i32_64x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int; -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM61_Epilogue_threadblock_epilogue, simt_i8_i32_128x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int; -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM61_Epilogue_threadblock_epilogue, simt_i8_i32_128x64_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int; -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ -+ int const kElementsPerAccess = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaSimt = cutlass::gemm::warp::MmaSimt< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ > -+ >; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< -+ Shape, -+ WarpMmaSimt, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_tensor_op.cu b/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_tensor_op.cu -new file mode 100644 -index 0000000..cdeb188 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_tensor_op.cu -@@ -0,0 +1,3076 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+ -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/epilogue/thread/linear_combination_clamp.h" -+#include "cutlass/gemm/warp/default_mma_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, s4_tensor_op_64x64_64x64x32) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 32 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using Element = ElementOutput; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM75_Epilogue_threadblock_epilogue, s4_tensor_op_64x64_32x32x32) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 32 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using Element = ElementOutput; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM75_Epilogue_threadblock_epilogue, s4_tensor_op_128x128_64x64x32) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 64 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using Element = ElementOutput; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM75_Epilogue_threadblock_epilogue, s4_tensor_op_128x64_64x32x32) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 32 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using Element = ElementOutput; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM75_Epilogue_threadblock_epilogue, s4_tensor_op_64x128_32x64x32) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 64 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using Element = ElementOutput; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM75_Epilogue_threadblock_epilogue, s4_tensor_op_32x128_32x64x32) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 64 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using Element = ElementOutput; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM75_Epilogue_threadblock_epilogue, s4_tensor_op_128x32_64x32x32) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 32 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 32, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using Element = ElementOutput; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+ -+TEST(SM75_Epilogue_threadblock_epilogue, s4_tensor_op_256x128_64x64x32) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 64 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<256, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using Element = ElementOutput; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+ -+TEST(SM75_Epilogue_threadblock_epilogue, s4_tensor_op_128x256_64x64x32) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 32 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 256, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using Element = ElementOutput; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_64x64_64x64x16) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ using Element = ElementOutput; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_64x64_32x3216) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 64 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ using Element = ElementOutput; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_128x128_64x64x16) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ using Element = ElementOutput; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_64x128_64x64x16) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ using Element = ElementOutput; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_128x64_64x32x16) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 64 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ using Element = ElementOutput; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_64x128_32x64x16) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ using Element = ElementOutput; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_32x128_32x64x16) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ using Element = ElementOutput; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM75_Epilogue_threadblock_epilogue, s8_tensor_op_128x32_64x32x16) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 64 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ using Element = ElementOutput; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementAccumulator, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, tensor_op_64x64_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, tensor_op_128x128_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, tensor_op_128x256_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, tensor_op_256x128_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<256, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, tensor_op_32x32_32x32x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, tensor_op_64x64_32x32x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, tensor_op_64x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, tensor_op_128x64_64x32x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Mixed precision tests -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, mixed_f16_f32_tensor_op_64x64_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, mixed_f16_f32_tensor_op_128x128_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, mixed_f16_f32_tensor_op_128x256_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, mixed_f16_f32_tensor_op_256x128_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<256, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, mixed_f16_f32_tensor_op_32x32_32x32x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, mixed_f16_f32_tensor_op_64x64_32x32x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, mixed_f16_f32_tensor_op_64x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, mixed_f16_f32_tensor_op_128x64_64x32x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// F16 acumulation -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+TEST(SM75_Epilogue_threadblock_epilogue, f16_tensor_op_64x64_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, f16_tensor_op_128x128_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, f16_tensor_op_128x256_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, f16_tensor_op_256x128_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<256, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, f16_tensor_op_32x32_32x32x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, f16_tensor_op_64x64_32x32x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, f16_tensor_op_64x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, f16_tensor_op_128x64_64x32x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Epilogue_threadblock_epilogue, f64_tensor_op_64x64_32x32x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ using ElementCompute = double; -+ int const kElementsPerAccess = 1; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ using Element = double; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Epilogue_threadblock_epilogue, f64_tensor_op_128x64_64x32x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ using ElementCompute = double; -+ int const kElementsPerAccess = 1; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ using Element = double; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Epilogue_threadblock_epilogue, f64_tensor_op_64x128_32x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ using ElementCompute = double; -+ int const kElementsPerAccess = 1; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ using Element = double; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Epilogue_threadblock_epilogue, f64_tensor_op_128x128_32x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ using ElementCompute = double; -+ int const kElementsPerAccess = 1; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ using Element = double; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, vec1_mixed_f16_f32_tensor_op_128x128_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 1; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, vec1_mixed_f16_f32_tensor_op_128x256_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 1; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+ -+TEST(SM75_Epilogue_threadblock_epilogue, vec1_tensor_op_128x128_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 1; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_threadblock_epilogue, vec1_tensor_op_128x256_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 1; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_volta_tensor_op.cu b/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_volta_tensor_op.cu -new file mode 100644 -index 0000000..62c86c8 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_volta_tensor_op.cu -@@ -0,0 +1,2893 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+ -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_sm70.h" -+#include "cutlass/epilogue/warp/fragment_iterator_volta_tensor_op.h" -+ -+#include "cutlass/epilogue/threadblock/default_thread_map_volta_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f16_volta_tensor_op_64x64_32x32x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f16_volta_tensor_op_128x64_64x32x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<128, 64, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f16_volta_tensor_op_64x128_32x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<64, 128, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f16_volta_tensor_op_64x64_64x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f16_volta_tensor_op_64x128_64x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<64, 128, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f16_volta_tensor_op_128x64_64x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<128, 64, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f16_volta_tensor_op_128x128_64x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f16_volta_tensor_op_128x256_64x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<128, 256, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f16_volta_tensor_op_256x128_64x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<256, 128, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Mixed: F32 accumulation -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f16_f32_volta_tensor_op_64x64_64x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f16_f32_volta_tensor_op_128x256_64x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<128, 256, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f16_f32_volta_tensor_op_256x128_64x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<256, 128, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f16_f32_volta_tensor_op_128x128_64x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f16_f32_volta_tensor_op_64x64_32x32x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f16_f32_volta_tensor_op_64x128_32x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<64, 128, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f16_f32_volta_tensor_op_128x64_64x32x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<128, 64, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// F32 accumulation, F32 output -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f32_volta_tensor_op_64x64_64x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = float; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f32_volta_tensor_op_64x128_64x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<64, 128, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = float; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f32_volta_tensor_op_128x64_64x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<128, 64, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = float; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f32_volta_tensor_op_128x128_64x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = float; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f32_volta_tensor_op_128x256_64x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<128, 256, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = float; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f32_volta_tensor_op_256x128_64x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<256, 128, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = float; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f32_volta_tensor_op_64x64_32x32x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = float; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f32_volta_tensor_op_128x64_64x32x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<128, 64, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = float; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f32_volta_tensor_op_64x128_32x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<64, 128, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = float; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// This works -+TEST(SM70_Epilogue_threadblock_epilogue, vec8_f16_f32_volta_tensor_op_64x64_32x32x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 8; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+// This works -+TEST(SM70_Epilogue_threadblock_epilogue, vec2_f16_f32_volta_tensor_op_64x64_32x32x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 2; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// This fails -+TEST(SM70_Epilogue_threadblock_epilogue, vec1_f16_f32_volta_tensor_op_64x64_32x32x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 1; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Epilogue_threadblock_epilogue, vec1_f32_volta_tensor_op_128x128_64x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = float; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 1; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM70_Epilogue_threadblock_epilogue, vec1_f16_f32_volta_tensor_op_128x128_64x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 1; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+TEST(SM70_Epilogue_threadblock_epilogue, vec1_f16_f32_volta_tensor_op_128x256_64x64x4) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using Shape = cutlass::gemm::GemmShape<128, 256, 4>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = ElementC; -+ using ElementCompute = ElementC; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using WarpMmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ WarpShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ -+ int const kPartitionsK = 1; -+ int const kElementsPerAccess = 1; -+ -+ using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapVoltaTensorOp< -+ Shape, -+ WarpShape, -+ kPartitionsK, -+ ElementC, -+ kElementsPerAccess, -+ ElementAccumulator>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_with_reduction_tensor_op.cu b/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_with_reduction_tensor_op.cu -new file mode 100644 -index 0000000..1932765 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_with_reduction_tensor_op.cu -@@ -0,0 +1,879 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+ -+#include "cutlass/epilogue/thread/linear_combination_drelu.h" -+#include "cutlass/gemm/warp/default_mma_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_with_reduction.h" -+#include "cutlass/epilogue/threadblock/epilogue_with_reduction.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+#include "epilogue_with_reduction_testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// Disable selected tests on CUDA 11.1 -+// -+// -+#define ENABLE_BLOCKED_TESTS (!(__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ == 1)) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_with_reduction_threadblock, f16_tensor_op_64x64_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< -+ ElementAccumulator, -+ ElementAccumulator, -+ ElementOutput, -+ ElementOutput, -+ kElementsPerAccess -+ >; -+ -+ using ReductionOp = cutlass::plus; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ ElementOutput, -+ OutputOp, -+ ReductionOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueWithReductionTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_with_reduction_threadblock, f32_tensor_op_64x64_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< -+ ElementAccumulator, -+ ElementAccumulator, -+ ElementOutput, -+ ElementOutput, -+ kElementsPerAccess -+ >; -+ -+ using ReductionOp = cutlass::plus; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ ElementOutput, -+ OutputOp, -+ ReductionOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueWithReductionTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_with_reduction_threadblock, f32_tensor_op_128x128_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< -+ ElementAccumulator, -+ ElementAccumulator, -+ ElementOutput, -+ ElementOutput, -+ kElementsPerAccess -+ >; -+ -+ using ReductionOp = cutlass::plus; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ ElementOutput, -+ OutputOp, -+ ReductionOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueWithReductionTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_with_reduction_threadblock, f16_tensor_op_128x128_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< -+ ElementAccumulator, -+ ElementAccumulator, -+ ElementOutput, -+ ElementOutput, -+ kElementsPerAccess -+ >; -+ -+ using ReductionOp = cutlass::plus; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ ElementOutput, -+ OutputOp, -+ ReductionOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueWithReductionTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_with_reduction_threadblock, f32_tensor_op_128x64_64x32x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< -+ ElementAccumulator, -+ ElementAccumulator, -+ ElementOutput, -+ ElementOutput, -+ kElementsPerAccess -+ >; -+ -+ using ReductionOp = cutlass::plus; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ ElementOutput, -+ OutputOp, -+ ReductionOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueWithReductionTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if ENABLE_BLOCKED_TESTS -+ -+TEST(SM75_Epilogue_with_reduction_threadblock, f16_tensor_op_128x64_64x32x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< -+ ElementAccumulator, -+ ElementAccumulator, -+ ElementOutput, -+ ElementOutput, -+ kElementsPerAccess -+ >; -+ -+ using ReductionOp = cutlass::plus; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ ElementOutput, -+ OutputOp, -+ ReductionOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueWithReductionTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_with_reduction_threadblock, f32_tensor_op_64x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< -+ ElementAccumulator, -+ ElementAccumulator, -+ ElementOutput, -+ ElementOutput, -+ kElementsPerAccess -+ >; -+ -+ using ReductionOp = cutlass::plus; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ ElementOutput, -+ OutputOp, -+ ReductionOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueWithReductionTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_with_reduction_threadblock, f16_tensor_op_64x128_32x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< -+ ElementAccumulator, -+ ElementAccumulator, -+ ElementOutput, -+ ElementOutput, -+ kElementsPerAccess -+ >; -+ -+ using ReductionOp = cutlass::plus; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ ElementOutput, -+ OutputOp, -+ ReductionOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueWithReductionTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_with_reduction_threadblock, f32_tensor_op_128x256_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< -+ ElementAccumulator, -+ ElementAccumulator, -+ ElementOutput, -+ ElementOutput, -+ kElementsPerAccess -+ >; -+ -+ using ReductionOp = cutlass::plus; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ ElementOutput, -+ OutputOp, -+ ReductionOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueWithReductionTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_with_reduction_threadblock, f16_tensor_op_128x256_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<128, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< -+ ElementAccumulator, -+ ElementAccumulator, -+ ElementOutput, -+ ElementOutput, -+ kElementsPerAccess -+ >; -+ -+ using ReductionOp = cutlass::plus; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ ElementOutput, -+ OutputOp, -+ ReductionOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueWithReductionTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_with_reduction_threadblock, f32_tensor_op_256x128_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<256, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< -+ ElementAccumulator, -+ ElementAccumulator, -+ ElementOutput, -+ ElementOutput, -+ kElementsPerAccess -+ >; -+ -+ using ReductionOp = cutlass::plus; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ ElementOutput, -+ OutputOp, -+ ReductionOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueWithReductionTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_with_reduction_threadblock, f16_tensor_op_256x128_64x64x8) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ using ElementCompute = float; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<256, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ WarpShape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< -+ ElementAccumulator, -+ ElementAccumulator, -+ ElementOutput, -+ ElementOutput, -+ kElementsPerAccess -+ >; -+ -+ using ReductionOp = cutlass::plus; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ ElementOutput, -+ OutputOp, -+ ReductionOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueWithReductionTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_with_reduction_testbed.h b/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_with_reduction_testbed.h -new file mode 100644 -index 0000000..c0e6fcc ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_with_reduction_testbed.h -@@ -0,0 +1,435 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ -+ \brief Unit tests for epilogues -+*/ -+#pragma once -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+#include "cutlass/complex.h" -+ -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace kernel { -+ -+template -+__global__ void epilogue_with_reduction_threadblock( -+ typename Epilogue::ElementVector *ptr_Reduction, -+ typename Epilogue::OutputTileIterator::Params params_D, -+ typename Epilogue::OutputTileIterator::Element *ptr_D, -+ typename Epilogue::OutputTileIterator::Params params_C, -+ typename Epilogue::OutputTileIterator::Element *ptr_C, -+ typename Epilogue::TensorTileIterator::Params params_Tensor, -+ typename Epilogue::TensorTileIterator::Element *ptr_Tensor, -+ typename Epilogue::OutputOp::Params params_output_op, -+ cutlass::MatrixCoord problem_size, -+ cutlass::TensorRef< -+ typename Epilogue::WarpMmaOperator::ElementC, -+ typename Epilogue::WarpMmaOperator::LayoutC> accumulator_ref, -+ int epilogue_count = 1) { -+ -+ __shared__ typename Epilogue::SharedStorage shared_storage; -+ -+ int thread_idx = threadIdx.x; -+ int warp_idx = threadIdx.x / 32; -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Construct the epilogue -+ // -+ -+ // Tile iterator writing to output tile -+ typename Epilogue::OutputTileIterator iterator_D( -+ params_D, -+ ptr_D, -+ problem_size, -+ thread_idx -+ ); -+ -+ // Tile iterator writing to output tile -+ typename Epilogue::OutputTileIterator iterator_C( -+ params_C, -+ ptr_C, -+ problem_size, -+ thread_idx -+ ); -+ -+ // Tile iterator writing to output tile -+ typename Epilogue::TensorTileIterator iterator_T( -+ params_Tensor, -+ ptr_Tensor, -+ problem_size, -+ thread_idx -+ ); -+ -+ // Epilogue operator -+ Epilogue epilogue( -+ shared_storage, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // -+ // Initialize the accumulators -+ // -+ -+ int warp_mn = warp_idx % (Epilogue::WarpCount::kM * Epilogue::WarpCount::kN); -+ int warp_m = warp_mn % Epilogue::WarpCount::kM; -+ int warp_n = warp_mn / Epilogue::WarpCount::kM; -+ -+ accumulator_ref.add_coord_offset({ -+ warp_m * Epilogue::WarpMmaOperator::Shape::kM, -+ warp_n * Epilogue::WarpMmaOperator::Shape::kN}); -+ -+ typename Epilogue::WarpMmaOperator::IteratorC accumulator_iterator(accumulator_ref, lane_idx); -+ -+ typename Epilogue::AccumulatorTile accumulators; -+ -+ accumulators.clear(); -+ accumulator_iterator.load(accumulators); -+ -+#if 0 -+ // For debugging, enable this block of code to fill each accumulator element with its -+ // source thread ID. -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < accumulators.size(); ++i) { -+ typename Epilogue::WarpMmaOperator::ElementC x(threadIdx.x); -+ //typename Epilogue::WarpMmaOperator::ElementC x(i); -+ accumulators[i] = x; -+ } -+ -+ /* -+ #pragma unroll 1 -+ for (int tid = 0; tid < 32; ++tid) { -+ if (tid == thread_idx) { -+ printf("\nT%d: ", thread_idx); -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < accumulators.size(); ++i) { -+ printf("%d ", int(accumulators[i])); -+ } -+ } -+ } -+ -+ if (thread_idx == 0) { -+ printf("\n\n"); -+ } -+ */ -+ -+ __syncthreads(); -+ -+#endif -+ -+ // -+ // Perform the epilogue operation -+ // -+ -+ typename Epilogue::OutputOp output_op(params_output_op); -+ -+ // Place the epilogue in a loop -+ for (int iter = 0; iter < epilogue_count; ++iter) { -+ epilogue(output_op, ptr_Reduction, iterator_D, accumulators, iterator_C, iterator_T); -+ } -+} -+ -+} // namespace kernel -+} // namespace test -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Epilogue_ -+> -+class EpilogueWithReductionTestbed { -+public: -+ -+ using Epilogue = Epilogue_; -+ using ElementAccumulator = typename Epilogue::ElementAccumulator; -+ using ElementCompute = typename Epilogue::OutputOp::ElementCompute; -+ using ElementTensor = typename Epilogue::TensorTileIterator::Element; -+ using ElementOutput = typename Epilogue::ElementOutput; -+ using OutputOpParams = typename Epilogue::OutputOp::Params; -+ -+public: -+ -+ // -+ // Data members -+ // -+ -+ cutlass::MatrixCoord quantized_size; -+ cutlass::HostTensor accumulator_tensor; -+ cutlass::HostTensor source_tensor; -+ cutlass::HostTensor output_tensor; -+ cutlass::HostTensor additional_tensor; -+ cutlass::HostTensor reduction_tensor; -+ -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ EpilogueWithReductionTestbed(): -+ quantized_size(Epilogue::Shape::kM, Epilogue::Shape::kN), -+ accumulator_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), -+ source_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), -+ output_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), -+ additional_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), -+ reduction_tensor({1, Epilogue::Shape::kN}) { -+ -+ // -+ // Initialize problem space -+ // -+ -+ uint64_t seed = 2019; -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ accumulator_tensor.host_view(), -+ seed, -+ 20, -+ -20, -+ 0); -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ source_tensor.host_view(), -+ seed + 2018, -+ 20, -+ -20, -+ 0); -+ -+ cutlass::reference::host::TensorFill(additional_tensor.host_view(), ElementTensor(1)); -+ } -+ -+ bool run_all() { -+ -+ /* -+ double alpha_values[] = {1, 0, 2.25}; -+ double beta_values[] = {0, 1, -1.25}; -+ -+ // Test runtime explodes if we tried to test every case exhaustively. This tests the full -+ // output tile and several smaller sizes to stress predication. -+ for (int m_idx = 0; m_idx < 3; ++m_idx) { -+ for (int n_idx = 0; n_idx < 3; ++n_idx) { -+ -+ int m = quantized_size.row() - m_idx * 3; -+ int n = quantized_size.column() - n_idx * Epilogue::kElementsPerAccess; -+ -+ for (double const &alpha : alpha_values) { -+ for (double const &beta : beta_values) { -+ -+ bool passed = run({m, n}, {cutlass::from_real(alpha), cutlass::from_real(beta)}); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ return true; -+ */ -+ -+ double alpha = 1; -+ double beta = 0; -+ -+ return run( -+ {quantized_size.row(), quantized_size.column()}, -+ {cutlass::from_real(alpha), cutlass::from_real(beta)}); -+ } -+ -+ /// Runs the test -+ bool run( -+ cutlass::MatrixCoord problem_size, -+ OutputOpParams output_params) { -+ -+ // -+ // Initialize problem space -+ // -+ -+ ElementOutput default_output = ElementOutput(-127); -+ ElementAccumulator default_reduction = ElementAccumulator(); -+ -+ cutlass::reference::host::TensorFill(output_tensor.host_view(), default_output); -+ cutlass::reference::host::TensorFill(reduction_tensor.host_view(), default_reduction); -+ -+ accumulator_tensor.sync_device(); -+ output_tensor.sync_device(); -+ source_tensor.sync_device(); -+ additional_tensor.sync_device(); -+ reduction_tensor.sync_device(); -+ -+ // -+ // Initialize epilogue parameters -+ // -+ -+ typename Epilogue::OutputTileIterator::Params params_D(output_tensor.device_ref().layout()); -+ typename Epilogue::OutputTileIterator::Params params_C(source_tensor.device_ref().layout()); -+ typename Epilogue::TensorTileIterator::Params params_T(additional_tensor.device_ref().layout()); -+ -+ // -+ // Launch kernel -+ // -+ -+ dim3 grid(1, 1); -+ dim3 block(Epilogue::WarpCount::kCount * 32, 1); -+ -+ test::kernel::epilogue_with_reduction_threadblock<<< grid, block >>>( -+ reduction_tensor.device_data(), -+ params_D, -+ output_tensor.device_data(), -+ params_C, -+ source_tensor.device_data(), -+ params_T, -+ additional_tensor.device_data(), -+ output_params, -+ problem_size, -+ accumulator_tensor.device_view()); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Kernel error: " << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ -+ // -+ // Verify results -+ // -+ output_tensor.sync_host(); -+ reduction_tensor.sync_host(); -+ -+ int errors = 0; -+ int const kMaxErrors = 5; -+ -+ // -+ // The output has two parts: -+ // - GEMM tensor epilogue in canonical layout -+ // - partial reduction in canonical row-major layout -+ // -+ -+ // Verify the GEMM tensor output -+ for (int r = 0; errors < kMaxErrors && r < quantized_size.row(); ++r) { -+ for (int c = 0; errors < kMaxErrors && c < quantized_size.column(); ++c) { -+ -+ cutlass::MatrixCoord coord{r, c}; -+ ElementOutput got = output_tensor.at(coord); -+ -+ ElementOutput expected; -+ if (coord.row() < problem_size.row() && coord.column() < problem_size.column()) { -+ -+ expected = ElementOutput(output_params.alpha * ElementCompute(accumulator_tensor.at(coord)) + -+ output_params.beta * ElementCompute(source_tensor.at(coord))); -+ } -+ else { -+ expected = default_output; -+ } -+ -+ if (expected != got) { -+ -+ using OutputIO = cutlass::ScalarIO; -+ -+ EXPECT_TRUE(false) -+ << "-------\n" -+ << "Error - output element (" << coord << ") - expected: " -+ << OutputIO(expected) -+ << ", got: " << OutputIO(got) << std::endl; -+ -+ ++errors; -+ } -+ } -+ } -+ -+ // Verify the partial reduction -+ for (int c = 0; c < quantized_size.column(); ++c) { -+ -+ ElementAccumulator reduction_acc = ElementAccumulator(); -+ -+ for (int r = 0; r < quantized_size.row(); ++r) { -+ reduction_acc += accumulator_tensor.at({r, c}); -+ } -+ -+ ElementAccumulator expected = default_reduction; -+ ElementAccumulator got = reduction_tensor.at({0, c}); -+ -+ if (c < problem_size.column()) { -+ expected = reduction_acc; -+ } -+ else { -+ expected = default_reduction; -+ } -+ -+ if (expected != got) { -+ -+ using OutputIO = cutlass::ScalarIO; -+ -+ EXPECT_TRUE(false) -+ << "-------\n" -+ << "Error - reduction element (" << c << ") - expected: " -+ << OutputIO(expected) -+ << ", got: " << OutputIO(got) << std::endl; -+ } -+ } -+ -+ // -+ // Report results on error -+ // -+ -+ if (errors) { -+ std::stringstream ss; -+ ss -+ << "output_tensor_op_" << Epilogue::Shape::kM << "x" << Epilogue::Shape::kN << "_" -+ << Epilogue::WarpTileIterator::WarpShape::kM << "x" -+ << Epilogue::WarpTileIterator::WarpShape::kN -+ << "_slice_" << Epilogue::WarpCount::kK << ".csv"; -+ -+ std::ofstream output_file(ss.str()); -+ output_file << output_tensor.host_view(); -+ } -+ -+ return !errors; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_wmma_tensor_op_sm70.cu b/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_wmma_tensor_op_sm70.cu -new file mode 100644 -index 0000000..bc835f2 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/epilogue/threadblock/epilogue_wmma_tensor_op_sm70.cu -@@ -0,0 +1,264 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+ -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/gemm/warp/default_mma_wmma_tensor_op.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+#include "testbed.h" -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// F16 acumulation -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Epilogue_threadblock_epilogue, f16_wmma_tensor_op_64x64_64x64x16) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWmmaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+ -+} -+ -+TEST(SM70_Epilogue_threadblock_epilogue, f16_wmma_tensor_op_64x128_64x64x16) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementCompute = cutlass::half_t; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWmmaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// F32 acumulation and F32 output -+// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Epilogue_threadblock_epilogue, f32_wmma_tensor_op_64x64_64x64x16) { -+ -+ // -+ // Define the warp-level matrix multiply -+ // -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ using ElementCompute = cutlass::half_t; -+ int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ int const kPartitionsK = 1; -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = ElementAccumulator; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WarpMmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ // -+ // Output operator -+ // -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kElementsPerAccess, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ // -+ // Define the epilogue -+ // -+ -+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWmmaTensorOp< -+ Shape, -+ WarpMmaTensorOp, -+ kPartitionsK, -+ OutputOp, -+ kElementsPerAccess -+ >::Epilogue; -+ -+ // -+ // Instantiate epilogue -+ // -+ -+ EpilogueTestbed testbed; -+ -+ bool passed = testbed.run_all(); -+ -+ EXPECT_TRUE(passed); -+ -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+#endif //CUTLASS_ARCH_WMMA_ENABLED -diff --git a/3rdparty/cutlass/test/unit/epilogue/threadblock/output_tile_threadmap.cu b/3rdparty/cutlass/test/unit/epilogue/threadblock/output_tile_threadmap.cu -new file mode 100644 -index 0000000..7874363 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/epilogue/threadblock/output_tile_threadmap.cu -@@ -0,0 +1,549 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+#include "cutlass/platform/platform.h" -+ -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h" -+ -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Prototype algorithm for partitioning a 4D space across warps to achieve several performance -+/// objectives: -+/// -+/// - coalesced memory accesses in units of 128 Byte lines -+/// - minimal address arithmetic -+/// - minimal predicate calculations -+/// -+struct OutputTileThreadMapExpr { -+ -+ struct Shape { -+ int column; -+ int row; -+ int group; -+ int cluster; -+ -+ Shape(int col = 1, int r = 1, int g = 1, int c = 1): -+ column(col), row(r), group(g), cluster(c) { } -+ }; -+ -+ int const kWarpSize = 32; -+ int const kMemoryAccessSize = 256; // size in bytes of the preferred memory access size -+ -+ // -+ // Data members -+ // -+ -+ Shape shape; -+ Shape count; -+ int threads; -+ int warp_count; -+ int elements_per_access; -+ int element_size; -+ -+ Shape iterations; -+ Shape delta; -+ Shape warp_partitions; -+ -+ int access_width_in_vectors; -+ int access_rows; -+ -+ // -+ // Methods -+ // -+ -+ OutputTileThreadMapExpr( -+ Shape shape_, -+ Shape count_, -+ int threads_, -+ int elements_per_access_, -+ int element_size_ -+ ): -+ shape(shape_), -+ count(count_), -+ threads(threads_), -+ warp_count(threads_ / kWarpSize), -+ elements_per_access(elements_per_access_), -+ element_size(element_size_) { -+ -+ int warps_remaining = warp_count; -+ -+ // clusters -+ if (shape.cluster > warp_count) { -+ iterations.cluster = shape.cluster / warp_count; -+ delta.cluster = shape.row * count.row * shape.group * count.group * shape.cluster / iterations.cluster; -+ warps_remaining = 1; -+ warp_partitions.cluster = warp_count; -+ } -+ else { -+ iterations.cluster = 1; -+ delta.cluster = 1; -+ warps_remaining = warp_count / shape.cluster; -+ warp_partitions.cluster = warps_remaining; -+ } -+ -+ // group size -+ if (shape.group > warps_remaining) { -+ iterations.group = shape.group / warps_remaining; -+ delta.group = shape.row * count.row * shape.group / iterations.group; -+ warps_remaining = 1; -+ warp_partitions.group = warps_remaining; -+ } -+ else { -+ iterations.group = 1; -+ delta.group = 1; -+ warps_remaining = warps_remaining / shape.group; -+ warp_partitions.group = warps_remaining; -+ } -+ -+ // Number of rows in a group -+ if (shape.row > warps_remaining) { -+ -+ // We must cover this shape within a warp -+ int shape_row = shape.row / warps_remaining; -+ int shape_width_vectors = shape.column / elements_per_access; -+ -+ // We would still like to minimize the number of strided increments. We can accomplish this -+ // by arranging the memory instructions as 2D, 128B wide accesses. -+ -+ int target_memory_access_width = kMemoryAccessSize / (elements_per_access * element_size / 8); -+ int target_rows_per_access = kWarpSize / target_memory_access_width; -+ -+ if (target_rows_per_access > shape_row) { -+ access_rows = shape_row; -+ access_width_in_vectors = kWarpSize / access_rows; -+ } -+ else { -+ -+ access_width_in_vectors = cutlass::platform::min( -+ shape_width_vectors, -+ cutlass::platform::min(kWarpSize, kMemoryAccessSize / (elements_per_access * element_size / 8))); -+ -+ access_rows = cutlass::platform::min(shape_row, kWarpSize / access_width_in_vectors); -+ } -+ -+ iterations.row = shape_row / access_rows; -+ delta.row = access_rows; -+ -+ iterations.column = shape_width_vectors / access_width_in_vectors; -+ delta.column = access_width_in_vectors * elements_per_access; -+ -+ warp_partitions.column = 1; -+ warp_partitions.row = 1; -+ } -+ else { -+ iterations.row = 1; -+ delta.row = 1; -+ iterations.column = (shape.column / elements_per_access) / kWarpSize; -+ delta.column = kWarpSize * elements_per_access; -+ -+ access_width_in_vectors = kWarpSize; -+ access_rows = 1; -+ -+ warp_partitions.row = 1; -+ warp_partitions.column = warps_remaining; -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+std::ostream & operator<<(std::ostream &out, OutputTileThreadMapExpr::Shape const &shape) { -+ out << "col: " << shape.column << ", r: " << shape.row << ", g: " << shape.group << ", c: " << shape.cluster; -+ return out; -+} -+ -+std::ostream & operator<<(std::ostream &out, OutputTileThreadMapExpr const &map) { -+ out -+ << " shape(" << map.shape << ")\n" -+ << " count(" << map.count << ")\n" -+ << " iterations(" << map.iterations << ")\n" -+ << " delta(" << map.delta << ")\n" -+ << " warps(" << map.warp_partitions << ")\n" -+ << " access(width: " << map.access_width_in_vectors -+ << ", rows: " << map.access_rows -+ << ") x v" << map.elements_per_access -+ << ".b" << map.element_size << "\n"; -+ -+ return out; -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Shape, -+ typename Count, -+ int Threads, -+ int ElementsPerAccess, -+ int ElementSize -+> -+struct ThreadMapTestbed { -+ ThreadMapTestbed() { -+ OutputTileThreadMapExpr map( -+ { Shape::kColumn, Shape::kRow, Shape::kGroup, Shape::kCluster }, -+ { Count::kColumn, Count::kRow, Count::kGroup, Count::kCluster }, -+ Threads, -+ ElementsPerAccess, -+ ElementSize -+ ); -+ -+ using ThreadMap = cutlass::epilogue::threadblock::OutputTileOptimalThreadMap< -+ Shape, -+ Count, -+ Threads, -+ ElementsPerAccess, -+ ElementSize -+ >; -+ -+ using CompactThreadmap = typename ThreadMap::CompactedThreadMap; -+ -+ bool const kVerbose = false; -+ -+ if (kVerbose) { -+ -+ std::cout << map << std::endl; -+ -+ std::cout << "ThreadMap::warps remaining:\n" -+ << " for groups: " << ThreadMap::Detail::kWarpsRemainingForGroups << "\n" -+ << " for rows: " << ThreadMap::Detail::kWarpsRemainingForRows << "\n"; -+ -+ std::cout << "ThreadMap::Access:\n" -+ << " width: " << ThreadMap::Detail::kAccessWidth << "\n" -+ << " rows: " << ThreadMap::Detail::kAccessRows << "\n"; -+ -+ std::cout << "ThreadMap::RowArrangement::Iterations:\n" -+ << " row: " << int(ThreadMap::Detail::RowArrangement::kIterationsRow) << "\n"; -+ } -+ -+ EXPECT_EQ(int(ThreadMap::Delta::kCluster), map.delta.cluster); -+ EXPECT_EQ(int(ThreadMap::Delta::kGroup), map.delta.group); -+ EXPECT_EQ(int(ThreadMap::Delta::kRow), map.delta.row); -+ EXPECT_EQ(int(ThreadMap::Delta::kColumn), map.delta.column); -+ -+ EXPECT_EQ(int(ThreadMap::Iterations::kCluster), map.iterations.cluster); -+ EXPECT_EQ(int(ThreadMap::Iterations::kGroup), map.iterations.group); -+ EXPECT_EQ(int(ThreadMap::Iterations::kRow), map.iterations.row); -+ EXPECT_EQ(int(ThreadMap::Iterations::kColumn), map.iterations.column); -+ -+ if (kVerbose) { -+ std::cout << "Iterations(col: " << ThreadMap::Iterations::kColumn -+ << ", r: " << ThreadMap::Iterations::kRow -+ << ", g: " << ThreadMap::Iterations::kGroup -+ << ", c: " << ThreadMap::Iterations::kCluster << ")\n"; -+ -+ std::cout << "Delta(col: " << ThreadMap::Delta::kColumn -+ << ", r: " << ThreadMap::Delta::kRow -+ << ", g: " << ThreadMap::Delta::kGroup -+ << ", c: " << ThreadMap::Delta::kCluster << ")\n"; -+ -+ for (int tid = 0; tid < Threads; ++tid) { -+ auto output_coord = ThreadMap::initial_offset(tid); -+ auto source_coord = CompactThreadmap::initial_offset(tid); -+ -+ std::cout << "T" << tid << " - output: " << output_coord << ", source: " << source_coord << "\n"; -+ } -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(ThreadMap, f16_tensor_op_64x64_64x64x8) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<64, 8, 1, 1, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 8, 1, 1, 1>; -+ int const kThreads = 32; -+ int const kElementsPerAccess = 8; -+ int const kElementSize = 16; -+ -+ ThreadMapTestbed(); -+} -+ -+ -+TEST(ThreadMap, f16_tensor_op_128x128_64x64x8) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<128, 8, 2, 1, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 8, 1, 1, 1>; -+ int const kThreads = 128; -+ int const kElementsPerAccess = 8; -+ int const kElementSize = 16; -+ -+ ThreadMapTestbed(); -+} -+ -+TEST(ThreadMap, f16_tensor_op_256x128_64x64x8) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<128, 8, 4, 1, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 8, 2, 1, 1>; -+ int const kThreads = 256; -+ int const kElementsPerAccess = 8; -+ int const kElementSize = 16; -+ -+ ThreadMapTestbed(); -+} -+ -+TEST(ThreadMap, f16_tensor_op_128x256_64x64x8) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<256, 8, 2, 1, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 8, 2, 1, 1>; -+ int const kThreads = 256; -+ int const kElementsPerAccess = 8; -+ int const kElementSize = 16; -+ -+ ThreadMapTestbed(); -+} -+ -+TEST(ThreadMap, f16_tensor_op_128x64_64x32x8) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<64, 8, 2, 1, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 8, 2, 1, 1>; -+ int const kThreads = 128; -+ int const kElementsPerAccess = 8; -+ int const kElementSize = 16; -+ -+ ThreadMapTestbed(); -+} -+ -+TEST(ThreadMap, f16_tensor_op_64x128_128x64x8) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<128, 8, 1, 1, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 8, 2, 1, 1>; -+ int const kThreads = 128; -+ int const kElementsPerAccess = 8; -+ int const kElementSize = 16; -+ -+ ThreadMapTestbed(); -+} -+ -+TEST(ThreadMap, f32_tensor_op_64x64_64x64x8) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<64, 8, 1, 1, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 8, 1, 1, 1>; -+ int const kThreads = 32; -+ int const kElementsPerAccess = 4; -+ int const kElementSize = 32; -+ -+ ThreadMapTestbed(); -+} -+ -+TEST(ThreadMap, f32_tensor_op_128x128_64x64x8) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<128, 8, 2, 1, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 8, 2, 1, 1>; -+ int const kThreads = 128; -+ int const kElementsPerAccess = 4; -+ int const kElementSize = 32; -+ -+ ThreadMapTestbed(); -+} -+ -+TEST(ThreadMap, f32_tensor_op_256x128_64x64x8) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<128, 8, 4, 1, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 8, 2, 1, 1>; -+ int const kThreads = 256; -+ int const kElementsPerAccess = 4; -+ int const kElementSize = 32; -+ -+ ThreadMapTestbed(); -+} -+ -+TEST(ThreadMap, f32_tensor_op_128x256_64x64x8) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<256, 8, 2, 1, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 8, 2, 1, 1>; -+ int const kThreads = 256; -+ int const kElementsPerAccess = 4; -+ int const kElementSize = 32; -+ -+ ThreadMapTestbed(); -+} -+ -+TEST(ThreadMap, f32_tensor_op_128x64_64x32x8) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<64, 8, 2, 1, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 8, 2, 1, 1>; -+ int const kThreads = 128; -+ int const kElementsPerAccess = 4; -+ int const kElementSize = 32; -+ -+ ThreadMapTestbed(); -+} -+ -+TEST(ThreadMap, f32_tensor_op_64x128_128x64x8) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<128, 8, 1, 1, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 8, 2, 1, 1>; -+ int const kThreads = 128; -+ int const kElementsPerAccess = 4; -+ int const kElementSize = 32; -+ -+ ThreadMapTestbed(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(ThreadMap, f32_volta_tensor_op_64x64_64x64x8) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<64, 2, 4, 1, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1, 1>; -+ int const kThreads = 32; -+ int const kElementsPerAccess = 4; -+ int const kElementSize = 32; -+ -+ ThreadMapTestbed(); -+} -+ -+TEST(ThreadMap, f32_volta_tensor_op_64x128_64x64x8) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<128, 2, 4, 1, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1, 1>; -+ int const kThreads = 64; -+ int const kElementsPerAccess = 4; -+ int const kElementSize = 32; -+ -+ ThreadMapTestbed(); -+} -+ -+TEST(ThreadMap, f32_volta_tensor_op_128x64_64x64x8) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<64, 2, 4, 2, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1, 1>; -+ int const kThreads = 64; -+ int const kElementsPerAccess = 4; -+ int const kElementSize = 32; -+ -+ ThreadMapTestbed(); -+} -+ -+TEST(ThreadMap, f32_volta_tensor_op_128x64_64x32x8) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<64, 2, 4, 2, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1, 1>; -+ int const kThreads = 128; -+ int const kElementsPerAccess = 4; -+ int const kElementSize = 32; -+ -+ ThreadMapTestbed(); -+} -+ -+TEST(ThreadMap, f32_volta_tensor_op_128x128_64x64x8) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<128, 2, 4, 2, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1, 1>; -+ int const kThreads = 128; -+ int const kElementsPerAccess = 4; -+ int const kElementSize = 32; -+ -+ ThreadMapTestbed(); -+} -+ -+TEST(ThreadMap, f32_volta_tensor_op_128x256_64x64x8) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<256, 2, 4, 2, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1, 1>; -+ int const kThreads = 256; -+ int const kElementsPerAccess = 4; -+ int const kElementSize = 32; -+ -+ ThreadMapTestbed(); -+} -+ -+TEST(ThreadMap, f32_volta_tensor_op_256x128_64x64x8) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<128, 2, 4, 4, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1, 1>; -+ int const kThreads = 256; -+ int const kElementsPerAccess = 4; -+ int const kElementSize = 32; -+ -+ ThreadMapTestbed(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(ThreadMap, simt_32x64_32x64x1) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<64, 1, 4, 1, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1, 1>; -+ int const kThreads = 32; -+ int const kElementsPerAccess = 1; -+ int const kElementSize = 32; -+ -+ ThreadMapTestbed(); -+} -+ -+TEST(ThreadMap, simt_32x128_32x64x1) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<128, 1, 4, 1, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1, 1>; -+ int const kThreads = 64; -+ int const kElementsPerAccess = 1; -+ int const kElementSize = 32; -+ -+ ThreadMapTestbed(); -+} -+ -+TEST(ThreadMap, simt_64x128_32x64x1) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<128, 1, 4, 2, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1, 1>; -+ int const kThreads = 128; -+ int const kElementsPerAccess = 1; -+ int const kElementSize = 32; -+ -+ ThreadMapTestbed(); -+} -+ -+TEST(ThreadMap, simt_128x128_32x64x1) { -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape<128, 1, 4, 4, 1>; -+ using Count = cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1, 1>; -+ int const kThreads = 256; -+ int const kElementsPerAccess = 1; -+ int const kElementSize = 32; -+ -+ ThreadMapTestbed(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/epilogue/threadblock/predicated_tile_iterator.cu b/3rdparty/cutlass/test/unit/epilogue/threadblock/predicated_tile_iterator.cu -new file mode 100644 -index 0000000..287b51a ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/epilogue/threadblock/predicated_tile_iterator.cu -@@ -0,0 +1,1125 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+ -+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -+#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h" -+ -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+__global__ void kernel_store_iterator( -+ typename TileIterator::Params params, -+ typename TileIterator::TensorRef ref, -+ cutlass::MatrixCoord extent) { -+ -+ TileIterator iterator(params, ref.data(), extent, threadIdx.x, {0, 0}); -+ -+ typename TileIterator::Fragment fragment; -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int iter = 0; iter < TileIterator::ThreadMap::Count::kTile; ++iter) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < TileIterator::Fragment::kElements; ++i) { -+ typename TileIterator::Element tidx(iter + 1); -+ fragment[i] = tidx; -+ } -+ -+ iterator.store(fragment); -+ -+ ++iterator; -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} -+} -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+static bool verify_footprint(cutlass::TensorView view, cutlass::MatrixCoord extent) { -+ -+ for (int r = 0; r < view.extent().row(); ++r) { -+ for (int c = 0; c < view.extent().column(); ++c) { -+ -+ cutlass::MatrixCoord coord{r, c}; -+ bool within = coord < extent; -+ if (within) { -+ if (view.at(coord) == T(0)) { -+ return false; -+ } -+ } -+ else { -+ if (view.at(coord) != T(0)) { -+ return false; -+ } -+ } -+ } -+ } -+ -+ return true; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(PredicatedTileIterator, tensor_op_64x64x32_64x64x8) { -+ -+ using Layout = cutlass::layout::RowMajor; -+ using Element = int; -+ -+ static int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ static int const kThreads = 32; -+ -+ // -+ // The following tests were used to develop the OutputTileOptimalThreadMap -+ // metaprogram. The definitions in the disabled blocks of code in this and -+ // the following tests are hand-written quantities. They are expected to -+ // match what is defined in the ThreadMap. -+ // -+ -+ #if 1 -+ using ThreadMap = cutlass::epilogue::threadblock::OutputTileOptimalThreadMap < -+ cutlass::epilogue::threadblock::OutputTileShape<64, 8, 1, 1, 1>, -+ cutlass::epilogue::threadblock::OutputTileShape<1, 8, 1, 1, 8>, -+ kThreads, -+ kElementsPerAccess, -+ cutlass::sizeof_bits::value -+ >; -+ -+ #else -+ using InternalThreadMap = cutlass::transform::PitchLinearStripminedThreadMap< -+ cutlass::layout::PitchLinearShape<64, 64>, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape< -+ 64, // column -+ 8, // row -+ 1, // group -+ 1, // cluster -+ 1 // iterations -+ >; -+ -+ using Iterations = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 4, // row -+ 1, // group -+ 1, // cluster -+ 1 // iterations -+ >; -+ -+ using Delta = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 2, // row -+ 1, // group -+ 1, // cluster -+ 1 // iterations -+ >; -+ -+ using Count = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 8, // row -+ 1, // group -+ 1, // cluster -+ 8 // iterations -+ >; -+ -+ using ThreadMap = cutlass::epilogue::threadblock::OutputTileThreadMap< -+ InternalThreadMap, -+ Shape, -+ Iterations, -+ Delta, -+ Count -+ >; -+ #endif -+ -+ using PredicatedTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ ThreadMap, -+ Element -+ >; -+ -+ // -+ // Initialize workspace -+ // -+ cutlass::MatrixCoord tensor_extent{64, 64}; -+ cutlass::MatrixCoord output_extent{62, 56}; -+ -+ // -+ // Configure parameters -+ // -+ -+ cutlass::HostTensor host_tensor(tensor_extent); -+ -+ typename PredicatedTileIterator::Params iterator_params(host_tensor.layout()); -+ -+ host_tensor.sync_device(); -+ -+ // -+ // Launch kernel -+ // -+ -+ dim3 grid(1,1); -+ dim3 block(kThreads, 1); -+ -+ test::epilogue::threadblock::kernel_store_iterator<<< grid, block >>>( -+ iterator_params, host_tensor.device_ref(), output_extent); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ ASSERT_EQ(result, cudaSuccess) << cudaGetErrorString(result); -+ -+ // -+ // Verify results -+ // -+ -+ host_tensor.sync_host(); -+ -+ bool passed = verify_footprint(host_tensor.host_view(), output_extent); -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ std::ofstream output("tensor_op_64x64x32_64x64x8.csv"); -+ output << host_tensor.host_view(); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(PredicatedTileIterator, tensor_op_128x64x32_64x64x8) { -+ -+ using Layout = cutlass::layout::RowMajor; -+ using Element = int; -+ -+ static int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ static int const kThreads = 64; -+ -+ #if 1 -+ -+ using ThreadMap = cutlass::epilogue::threadblock::OutputTileOptimalThreadMap < -+ cutlass::epilogue::threadblock::OutputTileShape<128, 8, 2, 1, 1>, -+ cutlass::epilogue::threadblock::OutputTileShape<1, 8, 2, 1, 8>, -+ kThreads, -+ kElementsPerAccess, -+ cutlass::sizeof_bits::value -+ >; -+ -+ #else -+ using InternalThreadMap = cutlass::transform::PitchLinearStripminedThreadMap< -+ cutlass::layout::PitchLinearShape<64, 128>, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape< -+ 64, // column -+ 8, // row -+ 2, // group -+ 1, // cluster -+ 8 // tile -+ >; -+ -+ using Iterations = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 2, // row -+ 2, // group -+ 1, // cluster -+ 1 // iterations -+ >; -+ -+ using Delta = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 4, // row -+ 64, // group -+ 1, // cluster -+ 1 // tile -+ >; -+ -+ using Count = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 8, // row -+ 1, // group -+ 1, // cluster -+ 8 // iterations -+ >; -+ -+ using ThreadMap = cutlass::epilogue::threadblock::OutputTileThreadMap< -+ InternalThreadMap, -+ Shape, -+ Iterations, -+ Delta, -+ Count -+ >; -+ #endif -+ -+ using PredicatedTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ ThreadMap, -+ Element -+ >; -+ -+ // -+ // Initialize workspace -+ // -+ cutlass::MatrixCoord tensor_extent{128, 64}; -+ cutlass::MatrixCoord output_extent{125, 56}; -+ -+ // -+ // Configure parameters -+ // -+ -+ cutlass::HostTensor host_tensor(tensor_extent); -+ -+ typename PredicatedTileIterator::Params iterator_params(host_tensor.layout()); -+ -+ host_tensor.sync_device(); -+ -+ // -+ // Launch kernel -+ // -+ -+ dim3 grid(1,1); -+ dim3 block(kThreads, 1); -+ -+ test::epilogue::threadblock::kernel_store_iterator<<< grid, block >>>( -+ iterator_params, host_tensor.device_ref(), output_extent); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ ASSERT_EQ(result, cudaSuccess) << cudaGetErrorString(result); -+ -+ // -+ // Verify results -+ // -+ -+ host_tensor.sync_host(); -+ -+ bool passed = verify_footprint(host_tensor.host_view(), output_extent); -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ std::ofstream output("tensor_op_128x64x32_64x64x8.csv"); -+ output << host_tensor.host_view(); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(PredicatedTileIterator, tensor_op_128x256x32_64x64x8) { -+ -+ using Layout = cutlass::layout::RowMajor; -+ using Element = int; -+ -+ static int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ static int const kThreads = 256; -+ -+ #if 1 -+ -+ using ThreadMap = cutlass::epilogue::threadblock::OutputTileOptimalThreadMap < -+ cutlass::epilogue::threadblock::OutputTileShape<256, 8, 2, 1, 1>, -+ cutlass::epilogue::threadblock::OutputTileShape<1, 8, 2, 1, 8>, -+ kThreads, -+ kElementsPerAccess, -+ cutlass::sizeof_bits::value -+ >; -+ -+ #else -+ using InternalThreadMap = cutlass::transform::PitchLinearStripminedThreadMap< -+ cutlass::layout::PitchLinearShape<256, 128>, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape< -+ 256, // column -+ 8, // row -+ 2, // group -+ 1, // cluster -+ 8 // tile -+ >; -+ -+ using Iterations = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 2, // row -+ 2, // group -+ 1, // cluster -+ 1 // iterations -+ >; -+ -+ using Delta = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 4, // row -+ 64, // group -+ 1, // cluster -+ 1 // tile -+ >; -+ -+ using Count = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 8, // row -+ 1, // group -+ 1, // cluster -+ 8 // iterations -+ >; -+ -+ using ThreadMap = cutlass::epilogue::threadblock::OutputTileThreadMap< -+ InternalThreadMap, -+ Shape, -+ Iterations, -+ Delta, -+ Count -+ >; -+ #endif -+ -+ using PredicatedTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ ThreadMap, -+ Element -+ >; -+ -+ // -+ // Initialize workspace -+ // -+ cutlass::MatrixCoord tensor_extent{128, 256}; -+ cutlass::MatrixCoord output_extent{123, 252}; -+ -+ // -+ // Configure parameters -+ // -+ -+ cutlass::HostTensor host_tensor(tensor_extent); -+ -+ typename PredicatedTileIterator::Params iterator_params(host_tensor.layout()); -+ -+ host_tensor.sync_device(); -+ -+ // -+ // Launch kernel -+ // -+ -+ dim3 grid(1,1); -+ dim3 block(kThreads, 1); -+ -+ test::epilogue::threadblock::kernel_store_iterator<<< grid, block >>>( -+ iterator_params, host_tensor.device_ref(), output_extent); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ ASSERT_EQ(result, cudaSuccess) << cudaGetErrorString(result); -+ -+ // -+ // Verify results -+ // -+ -+ host_tensor.sync_host(); -+ -+ bool passed = verify_footprint(host_tensor.host_view(), output_extent); -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ std::ofstream output("tensor_op_128x256x32_64x64x8.csv"); -+ output << host_tensor.host_view(); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(PredicatedTileIterator, volta_tensor_op_64x64x32_64x64x4) { -+ -+ using Layout = cutlass::layout::RowMajor; -+ using Element = int; -+ -+ static int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ static int const kThreads = 32; -+ -+#if 1 -+ -+ using ThreadMap = cutlass::epilogue::threadblock::OutputTileOptimalThreadMap < -+ cutlass::epilogue::threadblock::OutputTileShape<64, 2, 4, 1, 1>, -+ cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1, 8>, -+ kThreads, -+ kElementsPerAccess, -+ cutlass::sizeof_bits::value -+ >; -+ -+#else -+ using InternalThreadMap = cutlass::transform::PitchLinearStripminedThreadMap< -+ cutlass::layout::PitchLinearShape<64, 8>, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape< -+ 64, // column -+ 2, // row -+ 4, // group -+ 1, // cluster -+ 8 // iterations -+ >; -+ -+ using Iterations = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 1, // row -+ 4, // group -+ 1, // cluster -+ 1 // iterations -+ >; -+ -+ using Delta = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 1, // row -+ 8, // group -+ 1, // cluster -+ 1 // iterations -+ >; -+ -+ using Count = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 4, // row -+ 2, // group -+ 1, // cluster -+ 8 // iterations -+ >; -+ -+ using ThreadMap = cutlass::epilogue::threadblock::OutputTileThreadMap< -+ InternalThreadMap, -+ Shape, -+ Iterations, -+ Delta, -+ Count -+ >; -+#endif -+ -+ using PredicatedTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ ThreadMap, -+ Element -+ >; -+ -+ // -+ // Initialize workspace -+ // -+ cutlass::MatrixCoord tensor_extent{64, 64}; -+ cutlass::MatrixCoord output_extent{62, 56}; -+ -+ // -+ // Configure parameters -+ // -+ -+ cutlass::HostTensor host_tensor(tensor_extent); -+ -+ typename PredicatedTileIterator::Params iterator_params(host_tensor.layout()); -+ -+ host_tensor.sync_device(); -+ -+ // -+ // Launch kernel -+ // -+ -+ dim3 grid(1,1); -+ dim3 block(kThreads, 1); -+ -+ test::epilogue::threadblock::kernel_store_iterator<<< grid, block >>>( -+ iterator_params, host_tensor.device_ref(), output_extent); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ ASSERT_EQ(result, cudaSuccess) << cudaGetErrorString(result); -+ -+ // -+ // Verify results -+ // -+ host_tensor.sync_host(); -+ -+ bool passed = verify_footprint(host_tensor.host_view(), output_extent); -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ std::ofstream output("volta_tensor_op_64x64x32_64x64x4.csv"); -+ output << host_tensor.host_view(); -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(PredicatedTileIterator, volta_tensor_op_64x128x32_32x64x4) { -+ -+ using Layout = cutlass::layout::RowMajor; -+ using Element = int; -+ -+ static int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ static int const kThreads = 128; -+ -+#if 1 -+ -+ using ThreadMap = cutlass::epilogue::threadblock::OutputTileOptimalThreadMap < -+ cutlass::epilogue::threadblock::OutputTileShape<128, 2, 4, 1, 1>, -+ cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1, 8>, -+ kThreads, -+ kElementsPerAccess, -+ cutlass::sizeof_bits::value -+ >; -+ -+#else -+ using InternalThreadMap = cutlass::transform::PitchLinearStripminedThreadMap< -+ cutlass::layout::PitchLinearShape<128, 8>, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape< -+ 128, // column -+ 2, // row -+ 2, // group -+ 2, // cluster -+ 8 // iterations -+ >; -+ -+ using Iterations = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 1, // row -+ 1, // group -+ 2, // cluster -+ 1 // iterations -+ >; -+ -+ using Delta = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 1, // row -+ 8, // group -+ 32, // cluster -+ 1 // iterations -+ >; -+ -+ using Count = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 4, // row -+ 4, // group -+ 1, // cluster -+ 8 // iterations -+ >; -+ -+ using ThreadMap = cutlass::epilogue::threadblock::OutputTileThreadMap< -+ InternalThreadMap, -+ Shape, -+ Iterations, -+ Delta, -+ Count -+ >; -+#endif -+ -+ using PredicatedTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ ThreadMap, -+ Element -+ >; -+ -+ // -+ // Initialize workspace -+ // -+ cutlass::MatrixCoord tensor_extent{64, 128}; -+ cutlass::MatrixCoord output_extent{57, 124}; -+ -+ // -+ // Configure parameters -+ // -+ -+ cutlass::HostTensor host_tensor(tensor_extent); -+ -+ typename PredicatedTileIterator::Params iterator_params(host_tensor.layout()); -+ -+ host_tensor.sync_device(); -+ -+ // -+ // Launch kernel -+ // -+ -+ dim3 grid(1,1); -+ dim3 block(kThreads, 1); -+ -+ test::epilogue::threadblock::kernel_store_iterator<<< grid, block >>>( -+ iterator_params, host_tensor.device_ref(), output_extent); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ ASSERT_EQ(result, cudaSuccess) << cudaGetErrorString(result); -+ -+ // -+ // Verify results -+ // -+ -+ host_tensor.sync_host(); -+ -+ bool passed = verify_footprint(host_tensor.host_view(), output_extent); -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ std::ofstream output("volta_tensor_op_64x128x32_32x64x4.csv"); -+ output << host_tensor.host_view(); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(PredicatedTileIterator, volta_tensor_op_128x256x32_64x64x4) { -+ -+ using Layout = cutlass::layout::RowMajor; -+ using Element = int; -+ -+ static int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ static int const kThreads = 256; -+ -+#if 1 -+ -+ using ThreadMap = cutlass::epilogue::threadblock::OutputTileOptimalThreadMap < -+ cutlass::epilogue::threadblock::OutputTileShape<256, 2, 4, 2, 1>, -+ cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1, 8>, -+ kThreads, -+ kElementsPerAccess, -+ cutlass::sizeof_bits::value -+ >; -+ -+#else -+ using InternalThreadMap = cutlass::transform::PitchLinearStripminedThreadMap< -+ cutlass::layout::PitchLinearShape<256, 16>, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape< -+ 256, // column -+ 2, // row -+ 4, // group -+ 2, // cluster -+ 8 // iterations -+ >; -+ -+ using Iterations = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 1, // row -+ 2, // group -+ 2, // cluster -+ 1 // iterations -+ >; -+ -+ using Delta = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 1, // row -+ 16, // group -+ 64, // cluster -+ 1 // iterations -+ >; -+ -+ using Count = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 4, // row -+ 2, // group -+ 1, // cluster -+ 8 // iterations -+ >; -+ -+ using ThreadMap = cutlass::epilogue::threadblock::OutputTileThreadMap< -+ InternalThreadMap, -+ Shape, -+ Iterations, -+ Delta, -+ Count -+ >; -+#endif -+ -+ using PredicatedTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ ThreadMap, -+ Element -+ >; -+ -+ // -+ // Initialize workspace -+ // -+ cutlass::MatrixCoord tensor_extent{128, 256}; -+ cutlass::MatrixCoord output_extent{128, 256}; -+ -+ // -+ // Configure parameters -+ // -+ -+ cutlass::HostTensor host_tensor(tensor_extent); -+ -+ typename PredicatedTileIterator::Params iterator_params(host_tensor.layout()); -+ -+ host_tensor.sync_device(); -+ -+ // -+ // Launch kernel -+ // -+ -+ dim3 grid(1,1); -+ dim3 block(kThreads, 1); -+ -+ test::epilogue::threadblock::kernel_store_iterator<<< grid, block >>>( -+ iterator_params, host_tensor.device_ref(), output_extent); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ ASSERT_EQ(result, cudaSuccess) << cudaGetErrorString(result); -+ -+ // -+ // Verify results -+ // -+ -+ host_tensor.sync_host(); -+ -+ bool passed = verify_footprint(host_tensor.host_view(), output_extent); -+ EXPECT_TRUE(passed); -+ -+ if (!passed || true) { -+ std::ofstream output("volta_tensor_op_128x256x32_64x64x4.csv"); -+ output << host_tensor.host_view(); -+ } -+} -+ -+ -+TEST(PredicatedTileIterator, volta_tensor_op_256x128x32_64x64x4) { -+ -+ using Layout = cutlass::layout::RowMajor; -+ using Element = int; -+ -+ static int const kElementsPerAccess = 128 / cutlass::sizeof_bits::value; -+ static int const kThreads = 256; -+ -+ using ThreadMap = cutlass::epilogue::threadblock::OutputTileOptimalThreadMap < -+ cutlass::epilogue::threadblock::OutputTileShape<128, 2, 4, 4, 1>, -+ cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1, 8>, -+ kThreads, -+ kElementsPerAccess, -+ cutlass::sizeof_bits::value -+ >; -+ -+ using PredicatedTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ ThreadMap, -+ Element -+ >; -+ -+ // -+ // Initialize workspace -+ // -+ cutlass::MatrixCoord tensor_extent{ 256, 128 }; -+ cutlass::MatrixCoord output_extent{ 256, 128 }; -+ -+ // -+ // Configure parameters -+ // -+ -+ cutlass::HostTensor host_tensor(tensor_extent); -+ -+ typename PredicatedTileIterator::Params iterator_params(host_tensor.layout()); -+ -+ host_tensor.sync_device(); -+ -+ // -+ // Launch kernel -+ // -+ -+ dim3 grid(1, 1); -+ dim3 block(kThreads, 1); -+ -+ test::epilogue::threadblock::kernel_store_iterator <<< grid, block >>>( -+ iterator_params, host_tensor.device_ref(), output_extent); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ ASSERT_EQ(result, cudaSuccess) << cudaGetErrorString(result); -+ -+ // -+ // Verify results -+ // -+ -+ host_tensor.sync_host(); -+ -+ bool passed = verify_footprint(host_tensor.host_view(), output_extent); -+ EXPECT_TRUE(passed); -+ -+ if (!passed || true) { -+ std::ofstream output("volta_tensor_op_256x128x32_64x64x4.csv"); -+ output << host_tensor.host_view(); -+ } -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(PredicatedTileIterator, simt_32x64x8_32x64x1) { -+ -+ using Layout = cutlass::layout::RowMajor; -+ using Element = int; -+ -+ static int const kElementsPerAccess = 32 / cutlass::sizeof_bits::value; -+ static int const kThreads = 32; -+ -+#if 1 -+ -+ using ThreadMap = cutlass::epilogue::threadblock::OutputTileOptimalThreadMap < -+ cutlass::epilogue::threadblock::OutputTileShape<64, 1, 4, 1, 1>, -+ cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1, 8>, -+ kThreads, -+ kElementsPerAccess, -+ cutlass::sizeof_bits::value -+ >; -+ -+#else -+ using InternalThreadMap = cutlass::transform::PitchLinearStripminedThreadMap< -+ cutlass::layout::PitchLinearShape<64, 4>, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape< -+ 64, // column -+ 1, // row -+ 4, // group -+ 1, // cluster -+ 1 // iterations -+ >; -+ -+ using Iterations = cutlass::epilogue::threadblock::OutputTileShape< -+ 2, // column -+ 1, // row -+ 4, // group -+ 1, // cluster -+ 1 // iterations -+ >; -+ -+ using Delta = cutlass::epilogue::threadblock::OutputTileShape< -+ 32, // column -+ 1, // row -+ 4, // group -+ 16, // cluster -+ 1 // iterations -+ >; -+ -+ using Count = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 4, // row -+ 2, // group -+ 1, // cluster -+ 8 // iterations -+ >; -+ -+ using ThreadMap = cutlass::epilogue::threadblock::OutputTileThreadMap< -+ InternalThreadMap, -+ Shape, -+ Iterations, -+ Delta, -+ Count -+ >; -+#endif -+ -+ using PredicatedTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ ThreadMap, -+ Element -+ >; -+ -+ // -+ // Initialize workspace -+ // -+ cutlass::MatrixCoord tensor_extent{32, 64}; -+ cutlass::MatrixCoord output_extent{27, 63}; -+ -+ // -+ // Configure parameters -+ // -+ -+ cutlass::HostTensor host_tensor(tensor_extent); -+ -+ typename PredicatedTileIterator::Params iterator_params(host_tensor.layout()); -+ -+ host_tensor.sync_device(); -+ -+ // -+ // Launch kernel -+ // -+ -+ dim3 grid(1,1); -+ dim3 block(kThreads, 1); -+ -+ test::epilogue::threadblock::kernel_store_iterator<<< grid, block >>>( -+ iterator_params, host_tensor.device_ref(), output_extent); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ ASSERT_EQ(result, cudaSuccess) << cudaGetErrorString(result); -+ -+ // -+ // Verify results -+ // -+ -+ host_tensor.sync_host(); -+ -+ bool passed = verify_footprint(host_tensor.host_view(), output_extent); -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ std::ofstream output("simt_32x64x8_32x64x1.csv"); -+ output << host_tensor.host_view(); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(PredicatedTileIterator, simt_128x128x8_32x64x1) { -+ -+ using Layout = cutlass::layout::RowMajor; -+ using Element = int; -+ -+ static int const kElementsPerAccess = 32 / cutlass::sizeof_bits::value; -+ static int const kThreads = 256; -+ -+#if 1 -+ -+ using ThreadMap = cutlass::epilogue::threadblock::OutputTileOptimalThreadMap < -+ cutlass::epilogue::threadblock::OutputTileShape<128, 1, 4, 4, 1>, -+ cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1, 8>, -+ kThreads, -+ kElementsPerAccess, -+ cutlass::sizeof_bits::value -+ >; -+ -+#else -+ using InternalThreadMap = cutlass::transform::PitchLinearStripminedThreadMap< -+ cutlass::layout::PitchLinearShape<128, 16>, -+ kThreads, -+ kElementsPerAccess -+ >; -+ -+ using Shape = cutlass::epilogue::threadblock::OutputTileShape< -+ 128, // column -+ 1, // row -+ 4, // group -+ 4, // cluster -+ 1 // iterations -+ >; -+ -+ using Iterations = cutlass::epilogue::threadblock::OutputTileShape< -+ 2, // column -+ 1, // row -+ 2, // group -+ 4, // cluster -+ 1 // iterations -+ >; -+ -+ using Delta = cutlass::epilogue::threadblock::OutputTileShape< -+ 32, // column -+ 1, // row -+ 8, // group -+ 32, // cluster -+ 1 // iterations -+ >; -+ -+ using Count = cutlass::epilogue::threadblock::OutputTileShape< -+ 1, // column -+ 4, // row -+ 2, // group -+ 1, // cluster -+ 8 // iterations -+ >; -+ -+ using ThreadMap = cutlass::epilogue::threadblock::OutputTileThreadMap< -+ InternalThreadMap, -+ Shape, -+ Iterations, -+ Delta, -+ Count -+ >; -+#endif -+ -+ using PredicatedTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< -+ ThreadMap, -+ Element -+ >; -+ -+ // -+ // Initialize workspace -+ // -+ cutlass::MatrixCoord tensor_extent{128, 128}; -+ cutlass::MatrixCoord output_extent{123, 121}; -+ -+ // -+ // Configure parameters -+ // -+ -+ cutlass::HostTensor host_tensor(tensor_extent); -+ -+ typename PredicatedTileIterator::Params iterator_params(host_tensor.layout()); -+ -+ host_tensor.sync_device(); -+ -+ // -+ // Launch kernel -+ // -+ -+ dim3 grid(1,1); -+ dim3 block(kThreads, 1); -+ -+ test::epilogue::threadblock::kernel_store_iterator<<< grid, block >>>( -+ iterator_params, host_tensor.device_ref(), output_extent); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ ASSERT_EQ(result, cudaSuccess) << cudaGetErrorString(result); -+ -+ // -+ // Verify results -+ // -+ -+ host_tensor.sync_host(); -+ -+ bool passed = verify_footprint(host_tensor.host_view(), output_extent); -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ std::ofstream output("simt_128x128x8_32x64x1.csv"); -+ output << host_tensor.host_view(); -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/epilogue/threadblock/testbed.h b/3rdparty/cutlass/test/unit/epilogue/threadblock/testbed.h -new file mode 100644 -index 0000000..c2982c3 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/epilogue/threadblock/testbed.h -@@ -0,0 +1,371 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for epilogues -+*/ -+#pragma once -+ -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+#include "cutlass/complex.h" -+#include "cutlass/quaternion.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace kernel { -+ -+template -+__global__ void epilogue_threadblock( -+ typename Epilogue::OutputTileIterator::Params params_D, -+ typename Epilogue::OutputTileIterator::Element *ptr_D, -+ typename Epilogue::OutputTileIterator::Params params_C, -+ typename Epilogue::OutputTileIterator::Element *ptr_C, -+ typename Epilogue::OutputOp::Params params_output_op, -+ cutlass::MatrixCoord problem_size, -+ cutlass::TensorRef< -+ typename Epilogue::WarpMmaOperator::ElementC, -+ typename Epilogue::WarpMmaOperator::LayoutC> accumulator_ref, -+ int epilogue_count = 1) { -+ -+ __shared__ typename Epilogue::SharedStorage shared_storage; -+ -+ int thread_idx = threadIdx.x; -+ int warp_idx = threadIdx.x / 32; -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Construct the epilogue -+ // -+ -+ // Tile iterator writing to output tile -+ typename Epilogue::OutputTileIterator iterator_D( -+ params_D, -+ ptr_D, -+ problem_size, -+ thread_idx -+ ); -+ -+ // Tile iterator writing to output tile -+ typename Epilogue::OutputTileIterator iterator_C( -+ params_C, -+ ptr_C, -+ problem_size, -+ thread_idx -+ ); -+ -+ // Epilogue operator -+ Epilogue epilogue( -+ shared_storage, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // -+ // Initialize the accumulators -+ // -+ -+ int warp_mn = warp_idx % (Epilogue::WarpCount::kM * Epilogue::WarpCount::kN); -+ int warp_m = warp_mn % Epilogue::WarpCount::kM; -+ int warp_n = warp_mn / Epilogue::WarpCount::kM; -+ -+ accumulator_ref.add_coord_offset({ -+ warp_m * Epilogue::WarpMmaOperator::Shape::kM, -+ warp_n * Epilogue::WarpMmaOperator::Shape::kN}); -+ -+ typename Epilogue::WarpMmaOperator::IteratorC accumulator_iterator(accumulator_ref, lane_idx); -+ -+ typename Epilogue::AccumulatorTile accumulators; -+ -+ accumulators.clear(); -+ accumulator_iterator.load(accumulators); -+ -+#if 0 -+ // For debugging, enable this block of code to fill each accumulator element with its -+ // source thread ID. -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < accumulators.size(); ++i) { -+ typename Epilogue::WarpMmaOperator::ElementC x(threadIdx.x); -+ //typename Epilogue::WarpMmaOperator::ElementC x(i); -+ accumulators[i] = x; -+ } -+ -+ /* -+ #pragma unroll 1 -+ for (int tid = 0; tid < 32; ++tid) { -+ if (tid == thread_idx) { -+ printf("\nT%d: ", thread_idx); -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < accumulators.size(); ++i) { -+ printf("%d ", int(accumulators[i])); -+ } -+ } -+ } -+ -+ if (thread_idx == 0) { -+ printf("\n\n"); -+ } -+ */ -+ -+ __syncthreads(); -+ -+#endif -+ -+ // -+ // Perform the epilogue operation -+ // -+ -+ typename Epilogue::OutputOp output_op(params_output_op); -+ -+ // Place the epilogue in a loop -+ for (int iter = 0; iter < epilogue_count; ++iter) { -+ epilogue(output_op, iterator_D, accumulators, iterator_C); -+ } -+} -+ -+} // namespace kernel -+} // namespace test -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Epilogue_ -+> -+class EpilogueTestbed { -+public: -+ -+ using Epilogue = Epilogue_; -+ using ElementAccumulator = typename Epilogue::ElementAccumulator; -+ using ElementCompute = typename Epilogue::OutputOp::ElementCompute; -+ using ElementOutput = typename Epilogue::ElementOutput; -+ using OutputOpParams = typename Epilogue::OutputOp::Params; -+ -+public: -+ -+ // -+ // Data members -+ // -+ -+ cutlass::MatrixCoord quantized_size; -+ cutlass::HostTensor accumulator_tensor; -+ cutlass::HostTensor source_tensor; -+ cutlass::HostTensor output_tensor; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ EpilogueTestbed(): -+ quantized_size(Epilogue::Shape::kM, Epilogue::Shape::kN), -+ accumulator_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), -+ source_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), -+ output_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}) { -+ -+ // -+ // Initialize problem space -+ // -+ -+ uint64_t seed = 2019; -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ accumulator_tensor.host_view(), -+ seed, -+ 20, -+ -20, -+ 0); -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ source_tensor.host_view(), -+ seed + 2018, -+ 20, -+ -20, -+ 0); -+ } -+ -+ bool run_all() { -+ -+ double alpha_values[] = {1, 0, 2.25}; -+ double beta_values[] = {0, 1, -1.25}; -+ -+ // Test runtime explodes if we tried to test every case exhaustively. This tests the full -+ // output tile and several smaller sizes to stress predication. -+ for (int m_idx = 0; m_idx < 3; ++m_idx) { -+ for (int n_idx = 0; n_idx < 3; ++n_idx) { -+ -+ int m = quantized_size.row() - m_idx * 3; -+ int n = quantized_size.column() - n_idx * Epilogue::kElementsPerAccess; -+ -+ for (double const &alpha : alpha_values) { -+ for (double const &beta : beta_values) { -+ -+ bool passed = run({m, n}, {cutlass::from_real(alpha), cutlass::from_real(beta)}); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ -+ return true; -+ } -+ -+ /// Runs the test -+ bool run( -+ cutlass::MatrixCoord problem_size, -+ OutputOpParams output_params) { -+ -+ // -+ // Initialize problem space -+ // -+ -+ ElementOutput default_output = ElementOutput(-127); -+ cutlass::reference::host::TensorFill(output_tensor.host_view(), default_output); -+ -+ accumulator_tensor.sync_device(); -+ output_tensor.sync_device(); -+ source_tensor.sync_device(); -+ -+ // -+ // Initialize epilogue parameters -+ // -+ -+ typename Epilogue::OutputTileIterator::Params params_D(output_tensor.device_ref().layout()); -+ typename Epilogue::OutputTileIterator::Params params_C(source_tensor.device_ref().layout()); -+ -+ // -+ // Launch kernel -+ // -+ -+ dim3 grid(1, 1); -+ dim3 block(Epilogue::WarpCount::kCount * 32, 1); -+ -+ test::kernel::epilogue_threadblock<<< grid, block >>>( -+ params_D, -+ output_tensor.device_data(), -+ params_C, -+ source_tensor.device_data(), -+ output_params, -+ problem_size, -+ accumulator_tensor.device_view()); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Kernel error: " << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ -+ // -+ // Verify results -+ // -+ output_tensor.sync_host(); -+ -+ int errors = 0; -+ int const kMaxErrors = 5; -+ -+ for (int r = 0; errors < kMaxErrors && r < quantized_size.row(); ++r) { -+ for (int c = 0; errors < kMaxErrors && c < quantized_size.column(); ++c) { -+ -+ cutlass::MatrixCoord coord{r, c}; -+ ElementOutput got = output_tensor.at(coord); -+ -+ ElementOutput expected; -+ if (coord.row() < problem_size.row() && coord.column() < problem_size.column()) { -+ ElementCompute intermediate = -+ output_params.alpha * ElementCompute(accumulator_tensor.at(coord)) + -+ output_params.beta * ElementCompute(source_tensor.at(coord)); -+ -+ if (std::numeric_limits::is_integer -+ && !std::numeric_limits::is_integer) { -+ std::fesetround(FE_TONEAREST); -+ expected = ElementOutput(std::nearbyint(float(cutlass::real(intermediate)))); -+ } else { -+ expected = ElementOutput(intermediate); -+ } -+ } else { -+ expected = default_output; -+ } -+ -+ if (expected != got) { -+ -+ using OutputIO = cutlass::ScalarIO; -+ -+ EXPECT_TRUE(false) -+ << "-------\n" -+ << "Error - output element (" << coord << ") - expected: " -+ << OutputIO(expected) -+ << ", got: " << OutputIO(got) -+ << ", accum: " << (accumulator_tensor.at(coord)) -+ << ", source: " << OutputIO(source_tensor.at(coord)) -+ << ", alpha: " << (output_params.alpha) -+ << ", beta: " << (output_params.beta) << "\n"; -+ -+ ++errors; -+ } -+ } -+ } -+ -+ // -+ // Report results on error -+ // -+ -+ if (errors) { -+ std::stringstream ss; -+ ss -+ << "output_tensor_op_" << Epilogue::Shape::kM << "x" << Epilogue::Shape::kN << "_" -+ << Epilogue::WarpTileIterator::WarpShape::kM << "x" -+ << Epilogue::WarpTileIterator::WarpShape::kN -+ << "_slice_" << Epilogue::WarpCount::kK << ".csv"; -+ -+ std::ofstream output_file(ss.str()); -+ output_file << output_tensor.host_view(); -+ } -+ -+ return !errors; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/epilogue/threadblock/testbed_planar_complex.h b/3rdparty/cutlass/test/unit/epilogue/threadblock/testbed_planar_complex.h -new file mode 100644 -index 0000000..68da6f4 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/epilogue/threadblock/testbed_planar_complex.h -@@ -0,0 +1,394 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for epilogues -+*/ -+#pragma once -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+#include "cutlass/complex.h" -+ -+#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" -+ -+#include "cutlass/util/host_tensor_planar_complex.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace kernel { -+ -+template -+__global__ void epilogue_planar_complex_threadblock( -+ typename Epilogue::OutputTileIterator::Params params_D, -+ typename Epilogue::OutputTileIterator::Element *ptr_D, -+ int64_t imaginary_stride_D, -+ typename Epilogue::OutputTileIterator::Params params_C, -+ typename Epilogue::OutputTileIterator::Element *ptr_C, -+ int64_t imaginary_stride_C, -+ typename Epilogue::OutputOp::Params params_output_op, -+ cutlass::MatrixCoord problem_size, -+ cutlass::TensorRef< -+ typename Epilogue::WarpMmaOperator::ElementC, -+ typename Epilogue::WarpMmaOperator::LayoutC> accumulator_ref, -+ int64_t imaginary_stride_accum, -+ int epilogue_count = 1) { -+ -+ __shared__ typename Epilogue::SharedStorage shared_storage; -+ -+ int thread_idx = threadIdx.x; -+ int warp_idx = threadIdx.x / 32; -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Construct the epilogue -+ // -+ -+ // Tile iterator writing to output tile -+ typename Epilogue::OutputTileIterator iterator_D_real( -+ params_D, -+ ptr_D, -+ problem_size, -+ thread_idx -+ ); -+ -+ typename Epilogue::OutputTileIterator iterator_D_imag( -+ params_D, -+ ptr_D + imaginary_stride_D, -+ problem_size, -+ thread_idx -+ ); -+ -+ // Tile iterator writing to output tile -+ typename Epilogue::OutputTileIterator iterator_C_real( -+ params_C, -+ ptr_C, -+ problem_size, -+ thread_idx -+ ); -+ -+ typename Epilogue::OutputTileIterator iterator_C_imag( -+ params_C, -+ ptr_C + imaginary_stride_C, -+ problem_size, -+ thread_idx -+ ); -+ -+ // Epilogue operator -+ Epilogue epilogue( -+ shared_storage, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // -+ // Initialize the accumulators -+ // -+ -+ int warp_mn = warp_idx % (Epilogue::WarpCount::kM * Epilogue::WarpCount::kN); -+ int warp_m = warp_mn % Epilogue::WarpCount::kM; -+ int warp_n = warp_mn / Epilogue::WarpCount::kM; -+ -+ accumulator_ref.add_coord_offset({ -+ warp_m * Epilogue::WarpMmaOperator::Shape::kM, -+ warp_n * Epilogue::WarpMmaOperator::Shape::kN}); -+ -+ // -+ // Load accumulators -+ // -+ -+ typename Epilogue::WarpMmaOperator::IteratorC accumulator_iterator(accumulator_ref, lane_idx); -+ -+ typename Epilogue::AccumulatorTile accumulators; -+ -+ accumulators.clear(); -+ -+ accumulator_iterator.load(accumulators.real); -+ accumulator_iterator.load_with_pointer_offset(accumulators.imag, imaginary_stride_accum); -+ -+ // -+ // Perform the epilogue operation -+ // -+ -+ typename Epilogue::OutputOp output_op(params_output_op); -+ -+ // Place the epilogue in a loop so assembly is clearly visible -+ for (int iter = 0; iter < epilogue_count; ++iter) { -+ epilogue( -+ output_op, -+ iterator_D_real, -+ iterator_D_imag, -+ accumulators, -+ iterator_C_real, -+ iterator_C_imag); -+ } -+} -+ -+} // namespace kernel -+} // namespace test -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Epilogue_ -+> -+class EpiloguePlanarComplexTestbed { -+public: -+ -+ using Epilogue = Epilogue_; -+ using ElementAccumulator = typename Epilogue::ElementAccumulator; -+ using ElementCompute = typename Epilogue::OutputOp::ElementCompute; -+ using ElementOutput = typename Epilogue::ElementOutput; -+ using OutputOpParams = typename Epilogue::OutputOp::Params; -+ -+ using ComplexElementOutput = cutlass::complex; -+ using ComplexElementAccumulator = cutlass::complex; -+ using ComplexElementCompute = cutlass::complex; -+ -+public: -+ -+ // -+ // Data members -+ // -+ -+ cutlass::MatrixCoord quantized_size; -+ cutlass::HostTensorPlanarComplex accumulator_tensor; -+ cutlass::HostTensorPlanarComplex source_tensor; -+ cutlass::HostTensorPlanarComplex output_tensor; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ EpiloguePlanarComplexTestbed(): -+ quantized_size(Epilogue::Shape::kM, Epilogue::Shape::kN), -+ accumulator_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), -+ source_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), -+ output_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}) { -+ -+ // -+ // Initialize problem space -+ // -+ -+ #if 1 -+ uint64_t seed = 2019; -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ accumulator_tensor.host_view(), -+ seed, -+ 20, -+ -20, -+ 0); -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ source_tensor.host_view(), -+ seed + 2018, -+ 20, -+ -20, -+ 0); -+ #else -+ -+ cutlass::reference::host::BlockFillSequential(accumulator_tensor.host_data(), accumulator_tensor.capacity()); -+ -+ #endif -+ } -+ -+ bool run_all() { -+ -+ cutlass::complex alpha_values[3]; -+ -+ alpha_values[0] = cutlass::complex(1, 0); -+ alpha_values[1] = cutlass::complex(0, 0); -+ alpha_values[2] = cutlass::complex(2.25f, -0.5f); -+ -+ cutlass::complex beta_values[3]; -+ -+ beta_values[0] = cutlass::complex(0, 0); -+ beta_values[1] = cutlass::complex(1, 0); -+ beta_values[2] = cutlass::complex(0.5f, -2.25f); -+ -+ // Test runtime explodes if we tried to test every case exhaustively. This tests the full -+ // output tile and several smaller sizes to stress predication. -+ for (int m_idx = 0; m_idx < 3; ++m_idx) { -+ for (int n_idx = 0; n_idx < 3; ++n_idx) { -+ -+ cutlass::MatrixCoord problem_size( -+ quantized_size.row() - m_idx * 3, -+ quantized_size.column() - n_idx * Epilogue::kElementsPerAccess -+ ); -+ -+ for (auto const &alpha : alpha_values) { -+ for (auto const &beta : beta_values) { -+ -+ bool passed = run(problem_size, {alpha, beta}); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ -+ return true; -+ } -+ -+ /// Runs the test -+ bool run( -+ cutlass::MatrixCoord problem_size, -+ OutputOpParams output_params) { -+ -+ // -+ // Initialize problem space -+ // -+ -+ ComplexElementOutput default_output = ComplexElementOutput(ElementOutput(-127), ElementOutput(-101)); -+ -+ cutlass::reference::host::TensorFill(output_tensor.host_view(), default_output); -+ -+ accumulator_tensor.sync_device(); -+ output_tensor.sync_device(); -+ source_tensor.sync_device(); -+ -+ // -+ // Initialize epilogue parameters -+ // -+ -+ typename Epilogue::OutputTileIterator::Params params_D(output_tensor.layout()); -+ typename Epilogue::OutputTileIterator::Params params_C(source_tensor.layout()); -+ -+ // -+ // Launch kernel -+ // -+ -+ dim3 grid(1, 1); -+ dim3 block(Epilogue::WarpCount::kCount * 32, 1); -+ -+ test::kernel::epilogue_planar_complex_threadblock<<< grid, block >>>( -+ params_D, -+ output_tensor.device_data(), -+ output_tensor.imaginary_stride(), -+ params_C, -+ source_tensor.device_data(), -+ source_tensor.imaginary_stride(), -+ output_params, -+ problem_size, -+ accumulator_tensor.device_view_real(), -+ accumulator_tensor.imaginary_stride() -+ ); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Kernel error: " << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ -+ // -+ // Verify results -+ // -+ output_tensor.sync_host(); -+ -+ int errors = 0; -+ int const kMaxErrors = 5; -+ -+ for (int r = 0; errors < kMaxErrors && r < quantized_size.row(); ++r) { -+ for (int c = 0; errors < kMaxErrors && c < quantized_size.column(); ++c) { -+ -+ cutlass::MatrixCoord coord{r, c}; -+ ComplexElementOutput got = output_tensor.at(coord); -+ -+ ComplexElementOutput expected = default_output; -+ -+ if (coord.row() < problem_size.row() && coord.column() < problem_size.column()) { -+ -+ ComplexElementOutput src = source_tensor.at(coord); -+ -+ ComplexElementCompute tmp = -+ output_params.alpha * ComplexElementCompute(accumulator_tensor.at(coord)) + -+ output_params.beta * ComplexElementCompute(src.real(), src.imag()); -+ -+ expected = ComplexElementOutput(ElementOutput(tmp.real()), ElementOutput(tmp.imag())); -+ } -+ -+ if (expected != got) { -+ -+ using OutputIO = cutlass::ScalarIO; -+ -+ EXPECT_TRUE(false) -+ << "-------\n" -+ << "Error - output element (" << coord << ") - expected: " -+ << OutputIO(expected) -+ << ", got: " << OutputIO(got) << std::endl; -+ -+ ++errors; -+ } -+ } -+ } -+ -+ // -+ // Report results on error -+ // -+ -+ if (errors) { -+ -+ -+ std::cout << "Incorrect result for problem(" -+ << problem_size.row() << ", " -+ << problem_size.column() << ") for alpha: " << output_params.alpha << ", beta: " << output_params.beta << std::endl; -+ -+ std::stringstream ss; -+ ss -+ << "output_tensor_op_" << Epilogue::Shape::kM << "x" << Epilogue::Shape::kN << "_" -+ << Epilogue::WarpTileIterator::WarpShape::kM << "x" -+ << Epilogue::WarpTileIterator::WarpShape::kN -+ << "_slice_" << Epilogue::WarpCount::kK << ".csv"; -+ -+ std::ofstream output_file(ss.str()); -+ output_file << output_tensor.host_view(); -+ -+ std::cout << "Wrote workspace to '" << ss.str() << "'" << std::endl; -+ } -+ -+ return !errors; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/epilogue/warp/fragment_iterator_tensor_op.cu b/3rdparty/cutlass/test/unit/epilogue/warp/fragment_iterator_tensor_op.cu -new file mode 100644 -index 0000000..e7e15ce ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/epilogue/warp/fragment_iterator_tensor_op.cu -@@ -0,0 +1,194 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+#include "cutlass/gemm/warp/default_mma_tensor_op.h" -+ -+#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Epilogue_warp_FragmentIterator, mma_f32_64x64x8) { -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ using FragmentIterator = cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ Shape, -+ typename MmaTensorOp::Policy::Operator::Shape, -+ typename MmaTensorOp::Policy::Operator::ElementC, -+ typename MmaTensorOp::Policy::Operator::FragmentC, -+ cutlass::layout::RowMajor -+ >; -+ -+ // This test just prints things. -+ #if 0 -+ typename MmaTensorOp::FragmentC accum; -+ -+ std::cout << "Native accumulators:\n"; -+ -+ for (int i = 0; i < MmaTensorOp::FragmentC::kElements; ++i) { -+ accum[i] = ElementC(i); -+ -+ std::cout << accum[i] << " "; -+ if (i && !((i + 1) % 4)) { -+ std::cout << "\n"; -+ } -+ } -+ -+ std::cout << std::endl; -+ -+ std::cout << "FragmentIterator::Policy = { \n" -+ << " kAccessesPerInstruction: " << FragmentIterator::Policy::kIterationsPerInstruction << "\n" -+ << " kAccumulatorRowStride: " << FragmentIterator::Policy::kAccumulatorRowStride << "\n" -+ << " kAccumulatorColumnStride: " << FragmentIterator::Policy::kAccumulatorColumnStride << "\n" -+ << " kIterations: " << FragmentIterator::Policy::kIterations << "\n" -+ << " }" << std::endl; -+ -+ FragmentIterator fragment_iterator(accum); -+ -+ for (int iter = 0; iter < FragmentIterator::kIterations; ++iter) { -+ -+ typename FragmentIterator::Fragment frag; -+ -+ fragment_iterator.load(frag); -+ -+ std::cout << "Iteration " << iter << ":\n"; -+ -+ for (int i = 0; i < FragmentIterator::Fragment::kElements; ++i) { -+ std::cout << frag[i] << " "; -+ } -+ -+ std::cout << std::endl; -+ -+ ++fragment_iterator; -+ } -+ #endif -+} -+ -+TEST(SM75_Epilogue_warp_FragmentIterator, mma_f16_64x64x8) { -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ using FragmentIterator = cutlass::epilogue::warp::FragmentIteratorTensorOp< -+ Shape, -+ typename MmaTensorOp::Policy::Operator::Shape, -+ typename MmaTensorOp::Policy::Operator::ElementC, -+ typename MmaTensorOp::Policy::Operator::FragmentC, -+ cutlass::layout::RowMajor -+ >; -+ -+ // This test just prints things. -+ #if 0 -+ typename MmaTensorOp::FragmentC accum; -+ -+ std::cout << "Native accumulators:\n"; -+ -+ for (int i = 0; i < MmaTensorOp::FragmentC::kElements; ++i) { -+ accum[i] = ElementC(i); -+ -+ std::cout << (float)accum[i] << " "; -+ if (i && !((i + 1) % 4)) { -+ std::cout << "\n"; -+ } -+ } -+ -+ std::cout << std::endl; -+ -+ std::cout << "FragmentIterator::Policy = { \n" -+ << " kAccessesPerInstruction: " << FragmentIterator::Policy::kIterationsPerInstruction << "\n" -+ << " kAccumulatorRowStride: " << FragmentIterator::Policy::kAccumulatorRowStride << "\n" -+ << " kAccumulatorColumnStride: " << FragmentIterator::Policy::kAccumulatorColumnStride << "\n" -+ << " kIterations: " << FragmentIterator::Policy::kIterations << "\n" -+ << " }" << std::endl; -+ -+ FragmentIterator fragment_iterator(accum); -+ -+ for (int iter = 0; iter < FragmentIterator::kIterations; ++iter) { -+ -+ typename FragmentIterator::Fragment frag; -+ -+ fragment_iterator.load(frag); -+ -+ std::cout << "Iteration " << iter << ":\n"; -+ -+ for (int i = 0; i < FragmentIterator::Fragment::kElements; ++i) { -+ std::cout << (float)frag[i] << " "; -+ } -+ -+ std::cout << std::endl; -+ -+ ++fragment_iterator; -+ } -+ #endif -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/epilogue/warp/fragment_iterator_volta_tensor_op.cu b/3rdparty/cutlass/test/unit/epilogue/warp/fragment_iterator_volta_tensor_op.cu -new file mode 100644 -index 0000000..ffb670f ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/epilogue/warp/fragment_iterator_volta_tensor_op.cu -@@ -0,0 +1,216 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+#include "cutlass/gemm/warp/mma_tensor_op_sm70.h" -+#include "cutlass/epilogue/warp/fragment_iterator_volta_tensor_op.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Epilogue_warp_FragmentIterator, mma_f16_64x64x4) { -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using MmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ Shape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ cutlass::HostTensor accumulator_tensor({Shape::kM, Shape::kN}); -+ -+ cutlass::reference::host::TensorFill(accumulator_tensor.host_view(), ElementC(-1)); -+ -+ for (int tid = 0; tid < 1; ++tid) { -+ typename MmaTensorOp::IteratorC::Fragment accumulator_tile; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < accumulator_tile.size(); ++i) { -+ accumulator_tile[i] = ElementC(i); -+ } -+ -+ using FragmentIterator = cutlass::epilogue::warp::FragmentIteratorVoltaTensorOp< -+ cutlass::gemm::GemmShape<64, 64, 4>, -+ cutlass::gemm::GemmShape<32, 32, 4>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >; -+ -+ FragmentIterator frag_iterator(accumulator_tile); -+ -+ typename FragmentIterator::Fragment frag; -+ -+ for (int iter = 0; iter < FragmentIterator::kIterations; ++iter) { -+ frag_iterator.load(frag); -+ ++frag_iterator; -+ -+ #if 0 -+ std::cout << "T" << tid << ": "; -+ for (int i = 0; i < frag.size(); ++i) { -+ std::cout << " " << frag[i]; -+ } -+ std::cout << std::endl; -+ #endif -+ } -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Epilogue_warp_FragmentIterator, mma_f32_64x64x4) { -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using MmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ Shape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ cutlass::HostTensor accumulator_tensor({Shape::kM, Shape::kN}); -+ -+ cutlass::reference::host::TensorFill(accumulator_tensor.host_view(), ElementC(-1)); -+ -+ for (int tid = 0; tid < 1; ++tid) { -+ typename MmaTensorOp::IteratorC::Fragment accumulator_tile; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < accumulator_tile.size(); ++i) { -+ accumulator_tile[i] = ElementC(i); -+ } -+ -+ typename MmaTensorOp::IteratorC iterator_C(accumulator_tensor.host_ref(), tid); -+ iterator_C.store(accumulator_tile); -+ } -+ -+ /* -+ std::ofstream output("volta_mma_f32_64x64x4.csv"); -+ output << accumulator_tensor.host_view() << std::endl; -+ */ -+ -+ for (int tid = 0; tid < 1; ++tid) { -+ typename MmaTensorOp::IteratorC::Fragment accumulator_tile; -+ -+ using FragmentIterator = cutlass::epilogue::warp::FragmentIteratorVoltaTensorOp< -+ cutlass::gemm::GemmShape<64, 64, 4>, -+ cutlass::gemm::GemmShape<32, 32, 4>, -+ ElementC, -+ LayoutC -+ >; -+ -+ FragmentIterator frag_iterator(accumulator_tile); -+ -+ for (int iter = 0; iter < FragmentIterator::kIterations; ++iter) { -+ -+ typename FragmentIterator::Fragment frag; -+ frag_iterator.load(frag); -+ ++frag_iterator; -+ -+ #if 0 -+ std::cout << "Iteration: " << iter << " - T" << tid << ": "; -+ -+ for (int i = 0; i < frag.size(); ++i) { -+ std::cout << " " << frag[i]; -+ } -+ -+ std::cout << std::endl; -+ #endif -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/epilogue/warp/fragment_iterator_wmma_tensor_op.cu b/3rdparty/cutlass/test/unit/epilogue/warp/fragment_iterator_wmma_tensor_op.cu -new file mode 100644 -index 0000000..fe3f47a ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/epilogue/warp/fragment_iterator_wmma_tensor_op.cu -@@ -0,0 +1,186 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_wmma.h" -+ -+#include "cutlass/epilogue/warp/fragment_iterator_wmma_tensor_op.h" -+#include "cutlass/epilogue/warp/tile_iterator_wmma_tensor_op.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Epilogue_warp_FragmentIterator, wmma_f16_64x64x16) { -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Wmma< -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using MmaTensorOp = cutlass::gemm::warp::MmaTensorOpWmma< -+ Shape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ Policy -+ >; -+ -+ using FragmentIterator = cutlass::epilogue::warp::FragmentIteratorWmmaTensorOp< -+ Shape, -+ typename MmaTensorOp::Policy::Operator::Shape, -+ typename MmaTensorOp::Policy::Operator::ElementC, -+ typename MmaTensorOp::Policy::Operator::FragmentC, -+ cutlass::layout::RowMajor -+ >; -+ -+ #if 0 -+ // -+ // Enable this code block to print comments for debugging. -+ // -+ -+ std::cout << "FragmentIterator::Policy = { \n" -+ << " OperatorCount: (" << FragmentIterator::Policy::OperatorCount::kRow <<", "<; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Wmma< -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using MmaTensorOp = cutlass::gemm::warp::MmaTensorOpWmma< -+ Shape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ Policy -+ >; -+ -+ using FragmentIterator = cutlass::epilogue::warp::FragmentIteratorWmmaTensorOp< -+ Shape, -+ typename MmaTensorOp::Policy::Operator::Shape, -+ typename MmaTensorOp::Policy::Operator::ElementC, -+ typename MmaTensorOp::Policy::Operator::FragmentC, -+ cutlass::layout::RowMajor -+ >; -+ -+ #if 0 -+ // -+ // Enable this code block to print comments for debugging. -+ // -+ std::cout << "FragmentIterator::Policy = { \n" -+ << " OperatorCount: (" << FragmentIterator::Policy::OperatorCount::kRow <<", "< -+struct DefaultGemmConfigurationToCutlass3Types { -+ static_assert(sizeof(ElementA) == 0, "No valid DefaultGemmConfigurationToCutlass3Types configuration exists."); -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template -+struct DefaultGemm_TensorOpSm80_OperandA; -+ -+template -+struct DefaultGemm_TensorOpSm80_OperandB; -+ -+// -+// F16: 128-by-128-by-64 -+// -+ -+/// Operand A - Row-major (K-Major) -+template <> -+struct DefaultGemm_TensorOpSm80_OperandA -+{ -+ // Smem -+ using SmemLayoutAtom = decltype( -+ composition(Swizzle<3,3,3>{}, -+ Layout, -+ Stride<_64, _1>>{})); -+ using SmemCopyAtom = Copy_Atom; -+ -+ // Gmem -+ using GmemTiledCopy = decltype( -+ make_tiled_copy(Copy_Atom, half_t>{}, -+ Layout, -+ Stride< _8,_1>>{}, -+ Layout>{})); -+}; -+ -+/// Operand A - Column-major (M-major) -+template -+struct DefaultGemm_TensorOpSm80_OperandA -+{ -+ // Smem -+ using SmemLayoutAtom = decltype( -+ composition(Swizzle<3,3,3>{}, -+ Layout, -+ Stride< _1,_64>>{})); -+ using SmemCopyAtom = Copy_Atom; -+ -+ // Gmem -+ using GmemTiledCopy = decltype( -+ make_tiled_copy(Copy_Atom, half_t>{}, -+ Layout, -+ Stride< _1,_16>>{}, -+ Layout>{})); -+}; -+ -+// Because the F32F16 TiledMMA is A-B symmetric, we can reuse the DefaultOperands -+ -+// Operand B - Column-Major (K-major) -+template -+struct DefaultGemm_TensorOpSm80_OperandB -+ : DefaultGemm_TensorOpSm80_OperandA -+{}; -+ -+// Operand B - Row-Major (N-major) -+template -+struct DefaultGemm_TensorOpSm80_OperandB -+ : DefaultGemm_TensorOpSm80_OperandA -+{}; -+ -+// -+// F16: 128-by-128-by-32 (small k-block) -+// -+ -+/// Operand A - Row-major (K-Major) -+template <> -+struct DefaultGemm_TensorOpSm80_OperandA -+{ -+ // Smem -+ using SmemLayoutAtom = decltype( -+ composition(Swizzle<2,3,3>{}, -+ Layout, -+ Stride<_32, _1>>{})); -+ using SmemCopyAtom = Copy_Atom; -+ -+ // Gmem -+ using GmemTiledCopy = decltype( -+ make_tiled_copy(Copy_Atom, half_t>{}, -+ Layout, -+ Stride< _4,_1>>{}, -+ Layout>{})); -+}; -+ -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+// Ampere MMA F32F16 -+template -+struct DefaultGemmConfigurationToCutlass3Types< -+ arch::OpClassTensorOp, arch::Sm80, -+ half_t, LayoutA, -+ half_t, LayoutB, -+ float, LayoutC, -+ float> -+{ -+ using TileShape = Shape<_128, _128, _32>; -+ static constexpr int ThreadCount = 128; -+ using DispatchPolicy = MainloopSm80CpAsync<3>; -+ using TiledMma = TiledMMA< -+ MMA_Atom, -+ Layout>, // 2x2x1 thread group -+ Layout>>; // 1x2x1 value group for 16x16x16 MMA and LDSM -+ -+ // A -+ static constexpr int kAlignmentA = 8; -+ using DefaultOperandA = detail::DefaultGemm_TensorOpSm80_OperandA< -+ half_t, LayoutA, kAlignmentA, 32>; -+ using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; // M, K -+ using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom; -+ using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; -+ -+ // B -+ static constexpr int kAlignmentB = 8; -+ using DefaultOperandB = detail::DefaultGemm_TensorOpSm80_OperandB< -+ half_t, LayoutB, kAlignmentB, 32>; -+ using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; // N, K -+ using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom; -+ using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; -+ -+ // Mainloop -+ using CollectiveMainloop = collective::CollectiveMma< -+ DispatchPolicy, TileShape, -+ half_t, TagToStrideA_t, -+ half_t, TagToStrideB_t, -+ TiledMma, -+ GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A -+ GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B -+ >; -+ -+ // Epilogue -+ using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< -+ TagToStrideC_t, -+ TagToStrideC_t, -+ epilogue::thread::LinearCombination>; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+// -+// TF32: 128-by-128-by-kblock (kBlock = 16, 32) -+// -+ -+/// Operand A - Row-major (K-major) (kBlock = 32) -+template <> -+struct DefaultGemm_TensorOpSm80_OperandA -+{ -+ // Smem -+ using SmemLayoutAtom = decltype( -+ composition(Swizzle<3,2,3>{}, -+ Layout, -+ Stride<_32, _1>>{})); -+ using SmemCopyAtom = Copy_Atom; -+ -+ // Gmem -+ using GmemTiledCopy = decltype( -+ make_tiled_copy(Copy_Atom, tfloat32_t>{}, -+ Layout, -+ Stride< _8,_1>>{}, -+ Layout>{})); -+}; -+ -+/// Operand A - Row-major (K-major) (kBlock = 16) -+template <> -+struct DefaultGemm_TensorOpSm80_OperandA -+{ -+ // Smem -+ using SmemLayoutAtom = decltype( -+ composition(Swizzle<2,2,3>{}, -+ Layout, -+ Stride<_16, _1>>{})); -+ using SmemCopyAtom = Copy_Atom; -+ // Gmem -+ using GmemTiledCopy = decltype( -+ make_tiled_copy(Copy_Atom, tfloat32_t>{}, -+ Layout, -+ Stride< _4,_1>>{}, -+ Layout>{})); -+}; -+ -+/// Operand A - Column-major (M-major) -+template -+struct DefaultGemm_TensorOpSm80_OperandA -+{ -+ // Smem -+ using SmemLayoutAtom = decltype( -+ composition(Swizzle<3,2,3>{}, -+ Layout, -+ Stride< _1,_32>>{})); -+ using SmemCopyAtom = Copy_Atom, tfloat32_t>; -+ // Gmem -+ using GmemTiledCopy = decltype( -+ make_tiled_copy(Copy_Atom, tfloat32_t>{}, -+ Layout, -+ Stride< _1,_16>>{}, -+ Layout>{})); -+}; -+ -+// Because the TF32 TiledMMA is A-B symmetric, we can reuse the DefaultOperands -+ -+// Operand B - Column-Major (K-major) -+template -+struct DefaultGemm_TensorOpSm80_OperandB -+ : DefaultGemm_TensorOpSm80_OperandA -+{}; -+ -+// Operand B - Row-Major (N-major) -+template -+struct DefaultGemm_TensorOpSm80_OperandB -+ : DefaultGemm_TensorOpSm80_OperandA -+{}; -+ -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+// Ampere MMA F32TF32 -+template -+struct DefaultGemmConfigurationToCutlass3Types< -+ arch::OpClassTensorOp, arch::Sm80, -+ tfloat32_t, LayoutA, -+ tfloat32_t, LayoutB, -+ float, LayoutC, -+ float> -+{ -+ using TileShape = Shape<_128, _128, _32>; -+ static constexpr int ThreadCount = 128; -+ using DispatchPolicy = MainloopSm80CpAsync<3>; -+ using TiledMma = TiledMMA< -+ MMA_Atom, -+ Layout, Stride<_2, _1, _1>>, // 2x2x1 thread group -+ Layout>>; // 1x2x1 value group for 16x16x8 and LDSM -+ -+ // A -+ static constexpr int kAlignmentA = 4; -+ using DefaultOperandA = detail::DefaultGemm_TensorOpSm80_OperandA< -+ tfloat32_t, LayoutA, kAlignmentA, 32>; -+ using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; // M, K -+ using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom; -+ using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; -+ -+ // B -+ static constexpr int kAlignmentB = 4; -+ using DefaultOperandB = detail::DefaultGemm_TensorOpSm80_OperandB< -+ tfloat32_t, LayoutB, kAlignmentB, 32>; -+ using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; // N, K -+ using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom; -+ using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; -+ -+ // Mainloop -+ using CollectiveMainloop = collective::CollectiveMma< -+ DispatchPolicy, TileShape, -+ tfloat32_t, TagToStrideA_t, -+ tfloat32_t, TagToStrideB_t, -+ TiledMma, -+ GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A -+ GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B -+ >; -+ -+ // Epilogue -+ using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< -+ TagToStrideC_t, -+ TagToStrideC_t, -+ epilogue::thread::LinearCombination>; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+template -+struct DefaultGemmConfigurationToCutlass3Types< -+ arch::OpClassTensorOp, arch::Sm80, -+ int8_t, cutlass::layout::RowMajor, -+ int8_t, cutlass::layout::ColumnMajor, -+ int32_t, LayoutC, -+ int32_t> -+{ -+ using TileShape = Shape<_128, _128, _64>; -+ static constexpr int ThreadCount = 128; -+ using DispatchPolicy = MainloopSm80CpAsync<3>; -+ using TiledMma = TiledMMA< -+ MMA_Atom, -+ Layout>, // 2x2x1 thread group -+ Layout>>; // 1x2x1 value group for 16x16x32 and LDSM -+ -+ // A (M,K) K-major -+ using SmemLayoutAtomA = decltype( -+ composition( -+ Swizzle<2,4,3>{}, -+ Layout, -+ Stride<_64, _1>>{})); -+ static constexpr int kAlignmentA = 16; -+ using GmemTiledCopyA = decltype( -+ make_tiled_copy(Copy_Atom, int8_t>{}, -+ Layout, -+ Stride< _4,_1>>{}, -+ Layout>>{})); -+ // LDS.32- or LDSM-based copy atom -+ // using SmemCopyAtomA = Copy_Atom; -+ using SmemCopyAtomA = Copy_Atom; // LDSM works -+ -+ // B (N,K) K-major -+ using SmemLayoutAtomB = decltype( -+ composition( -+ Swizzle<2,4,3>{}, -+ Layout, -+ Stride<_64, _1>>{})); -+ static constexpr int kAlignmentB = 16; -+ using GmemTiledCopyB = decltype( -+ make_tiled_copy(Copy_Atom, int8_t>{}, -+ Layout, -+ Stride< _4,_1>>{}, -+ Layout>>{})); -+ -+ // LDS.32- or LDSM-based copy atom -+ // using SmemCopyAtomB = Copy_Atom; -+ using SmemCopyAtomB = Copy_Atom; // LDSM works -+ -+ // Mainloop -+ using CollectiveMainloop = collective::CollectiveMma< -+ DispatchPolicy, TileShape, -+ int8_t, TagToStrideA_t, -+ int8_t, TagToStrideB_t, -+ TiledMma, -+ GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A -+ GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B -+ >; -+ -+ using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< -+ TagToStrideC_t, -+ TagToStrideC_t, -+ epilogue::thread::LinearCombination>; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+//////////////////////////// SIMT TWO STAGE /////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template -+struct DefaultGemm_Simt_OperandA; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct DefaultGemm_Simt_OperandA -+{ -+ using SmemLayoutAtom = Layout, -+ Stride< _1,_128>>; -+ -+ using SmemCopyAtom = Copy_Atom; -+ -+ using GmemTiledCopy = decltype( -+ make_tiled_copy(Copy_Atom, Element>{}, -+ Layout, -+ Stride< _1,_32>>{}, -+ Layout>{})); -+}; -+ -+template -+struct DefaultGemm_Simt_OperandA -+{ -+ using SmemLayoutAtom = Layout, -+ Stride< _1,Int<128 + 4>>>; // Padded -+ -+ using SmemCopyAtom = Copy_Atom; -+ -+ using GmemTiledCopy = decltype( -+ make_tiled_copy(Copy_Atom, Element>{}, -+ Layout, -+ Stride< _8, _1>>{}, -+ Layout>{})); -+ -+}; -+ -+template -+struct DefaultGemm_Simt_OperandB; -+ -+template -+struct DefaultGemm_Simt_OperandB -+ : DefaultGemm_Simt_OperandA {}; -+ -+template -+struct DefaultGemm_Simt_OperandB -+ : DefaultGemm_Simt_OperandA {}; -+ -+} // end namespace detail -+ -+// SIMT Two Stage -+template < -+ class ArchTag, -+ class ElementA, class LayoutA, -+ class ElementB, class LayoutB, -+ class ElementC, class LayoutC, -+ class ElementAccumulator> -+struct DefaultGemmConfigurationToCutlass3Types< -+ arch::OpClassSimt, ArchTag, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementAccumulator> -+{ -+ using TileShape = Shape<_128, _128, _8>; -+ static constexpr int ThreadCount = 256; -+ using DispatchPolicy = MainloopSm70TwoStage; -+ using TiledMma = TiledMMA< -+ MMA_Atom>, -+ Layout>>; -+ -+ // A -+ static constexpr int kAlignmentA = 1; -+ using DefaultOperandA = detail::DefaultGemm_Simt_OperandA; -+ using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; -+ using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom; -+ using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; -+ -+ // B -+ static constexpr int kAlignmentB = 1; -+ using DefaultOperandB = detail::DefaultGemm_Simt_OperandB; -+ using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; -+ using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom; -+ using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; -+ -+ // Mainloop -+ using CollectiveMainloop = collective::CollectiveMma< -+ DispatchPolicy, TileShape, -+ ElementA, TagToStrideA_t, -+ ElementB, TagToStrideB_t, -+ TiledMma, -+ GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A -+ GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B -+ >; -+ -+ // Epilogue -+ using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< -+ TagToStrideC_t, -+ TagToStrideC_t, -+ epilogue::thread::LinearCombination>; -+}; -+ -+ -+// -+// DP4A - int8 Proof-of-concept -+// -+ -+// SIMT Two Stage TN - idp4a -+template < -+ class ArchTag, -+ class ElementC, class LayoutC> -+struct DefaultGemmConfigurationToCutlass3Types< -+ arch::OpClassSimt, ArchTag, -+ int8_t, cutlass::layout::RowMajor, -+ int8_t, cutlass::layout::ColumnMajor, -+ ElementC, LayoutC, -+ int32_t> -+{ -+ using TileShape = Shape<_128, _128, _32>; -+ static constexpr int ThreadCount = 256; -+ using DispatchPolicy = MainloopSm70TwoStage; -+ // NOTE: permuting MMA M mode lets us generate 128b smem loads (LDS.128) but has worst case bank conflicts -+ using TiledMma = TiledMMA< -+ MMA_Atom, -+ Layout>>; // Tile of atoms (threads) -+ -+ // A (M,K) K-major -+ using ElementA = int8_t; -+ // 40% from regular M and N major layout -+ // using SmemLayoutAtomA = Layout, -+ // Stride< _1,_128>>; -+ // 80% from interleaved layouts -+ using SmemLayoutAtomA = Layout>, -+ Stride< _4, Stride<_1,_512>>>; -+ -+ using SmemCopyAtomA = Copy_Atom; -+ static constexpr int kAlignmentA = 4; -+ using GmemTiledCopyA = decltype( -+ make_tiled_copy(Copy_Atom, ElementA>{}, -+ Layout, -+ Stride< _8,_1>>{}, -+ Layout>{})); -+ -+ // B (N,K) K-major -+ using ElementB = int8_t; -+ // 40% from regular M and N major layout -+ // using SmemLayoutAtomB = Layout, -+ // Stride< _1,_128>>; -+ // 80% from interleaved layouts -+ using SmemLayoutAtomB = Layout>, -+ Stride< _4, Stride<_1,_512>>>; -+ -+ using SmemCopyAtomB = Copy_Atom; -+ static constexpr int kAlignmentB = 4; -+ using GmemTiledCopyB = decltype( -+ make_tiled_copy(Copy_Atom, ElementB>{}, -+ Layout, -+ Stride< _8,_1>>{}, -+ Layout>{})); -+ -+ // Mainloop -+ using CollectiveMainloop = collective::CollectiveMma< -+ DispatchPolicy, TileShape, -+ ElementA, TagToStrideA_t, -+ ElementB, TagToStrideB_t, -+ TiledMma, -+ GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A -+ GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B -+ >; -+ -+ // Epilogue -+ using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< -+ TagToStrideC_t, -+ TagToStrideC_t, -+ epilogue::thread::LinearCombination>; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+// SIMT Two Stage NN - idp4a -+template < -+ class ArchTag, -+ class ElementC, class LayoutC> -+struct DefaultGemmConfigurationToCutlass3Types< -+ arch::OpClassSimt, ArchTag, -+ int8_t, cutlass::layout::ColumnMajor, -+ int8_t, cutlass::layout::ColumnMajor, -+ ElementC, LayoutC, -+ int32_t> -+{ -+ using TileShape = Shape<_128, _128, _32>; -+ static constexpr int ThreadCount = 256; -+ -+ using DispatchPolicy = MainloopSm70TwoStage; -+ -+ using TiledMma = TiledMMA< -+ MMA_Atom, -+ Layout>>; -+ -+ // A (M,K) M-major -+ using ElementA = int8_t; -+ using SmemLayoutAtomA = Layout>, -+ Stride< _4, Stride<_1,_512>>>; -+ using SmemCopyAtomA = Copy_Atom; -+ static constexpr int kAlignmentA = 1; -+ using GmemTiledCopyA = decltype( -+ make_tiled_copy(Copy_Atom, ElementA>{}, -+ Layout, -+ Stride< _1,_32>>{}, -+ Layout>{})); -+ -+ // B (N,K) K-major -+ using ElementB = int8_t; -+ using SmemLayoutAtomB = Layout>, -+ Stride< _4, Stride<_1,_512>>>; -+ using SmemCopyAtomB = Copy_Atom; -+ static constexpr int kAlignmentB = 4; -+ using GmemTiledCopyB = decltype( -+ make_tiled_copy(Copy_Atom, ElementB>{}, -+ Layout, -+ Stride< _8,_1>>{}, -+ Layout>{})); -+ -+ // Mainloop -+ using CollectiveMainloop = collective::CollectiveMma< -+ DispatchPolicy, TileShape, -+ ElementA, TagToStrideA_t, -+ ElementB, TagToStrideB_t, -+ TiledMma, -+ GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A -+ GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B -+ >; -+ -+ // Epilouge -+ using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< -+ TagToStrideC_t, -+ TagToStrideC_t, -+ epilogue::thread::LinearCombination>; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+// SIMT Two Stage NT - idp4a -+template < -+ class ArchTag, -+ class ElementC, class LayoutC> -+struct DefaultGemmConfigurationToCutlass3Types< -+ arch::OpClassSimt, ArchTag, -+ int8_t, cutlass::layout::ColumnMajor, -+ int8_t, cutlass::layout::RowMajor, -+ ElementC, LayoutC, -+ int32_t> -+{ -+ using TileShape = Shape<_128, _128, _32>; -+ static constexpr int ThreadCount = 256; -+ using DispatchPolicy = MainloopSm70TwoStage; -+ using TiledMma = TiledMMA< -+ MMA_Atom, -+ Layout>>; -+ -+ // A (M,K) M-major -+ using ElementA = int8_t; -+ using SmemLayoutAtomA = Layout>, -+ Stride< _4, Stride<_1,_512>>>; -+ using SmemCopyAtomA = Copy_Atom; -+ static constexpr int kAlignmentA = 1; -+ using GmemTiledCopyA = decltype( -+ make_tiled_copy(Copy_Atom, ElementA>{}, -+ Layout, -+ Stride< _1,_32>>{}, -+ Layout>{})); -+ -+ // B (N,K) N-major -+ using ElementB = int8_t; -+ using SmemLayoutAtomB = Layout>, -+ Stride< _4, Stride<_1,_512>>>; -+ using SmemCopyAtomB = Copy_Atom; -+ static constexpr int kAlignmentB = 1; -+ using GmemTiledCopyB = decltype( -+ make_tiled_copy(Copy_Atom, ElementB>{}, -+ Layout, -+ Stride< _1,_32>>{}, -+ Layout>{})); -+ -+ // Mainloop -+ using CollectiveMainloop = collective::CollectiveMma< -+ DispatchPolicy, TileShape, -+ ElementA, TagToStrideA_t, -+ ElementB, TagToStrideB_t, -+ TiledMma, -+ GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A -+ GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B -+ >; -+ -+ // Epilogue -+ using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< -+ TagToStrideC_t, -+ TagToStrideC_t, -+ epilogue::thread::LinearCombination>; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+// SIMT Two Stage TT - idp4a -+template < -+ class ArchTag, -+ class ElementC, class LayoutC> -+struct DefaultGemmConfigurationToCutlass3Types< -+ arch::OpClassSimt, ArchTag, -+ int8_t, cutlass::layout::RowMajor, -+ int8_t, cutlass::layout::RowMajor, -+ ElementC, LayoutC, -+ int32_t> -+{ -+ using TileShape = Shape<_128, _128, _32>; -+ static constexpr int ThreadCount = 256; -+ using DispatchPolicy = MainloopSm70TwoStage; -+ using TiledMma = TiledMMA< -+ MMA_Atom, -+ Layout>>; -+ -+ // A (M,K) K-major -+ using ElementA = int8_t; -+ using SmemLayoutAtomA = Layout>, -+ Stride< _4, Stride<_1,_512>>>; -+ using SmemCopyAtomA = Copy_Atom; -+ static constexpr int kAlignmentA = 4; -+ using GmemTiledCopyA = decltype( -+ make_tiled_copy(Copy_Atom, ElementA>{}, -+ Layout, -+ Stride< _8,_1>>{}, -+ Layout>{})); -+ -+ // B (N,K) N-major -+ using ElementB = int8_t; -+ using SmemLayoutAtomB = Layout>, -+ Stride< _4, Stride<_1,_512>>>; -+ using SmemCopyAtomB = Copy_Atom; -+ static constexpr int kAlignmentB = 1; -+ using GmemTiledCopyB = decltype( -+ make_tiled_copy(Copy_Atom, ElementB>{}, -+ Layout, -+ Stride< _1,_32>>{}, -+ Layout>{})); -+ -+ // Mainloop -+ using CollectiveMainloop = collective::CollectiveMma< -+ DispatchPolicy, TileShape, -+ ElementA, TagToStrideA_t, -+ ElementB, TagToStrideB_t, -+ TiledMma, -+ GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A -+ GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B -+ >; -+ -+ // Epilogue -+ using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< -+ TagToStrideC_t, -+ TagToStrideC_t, -+ epilogue::thread::LinearCombination>; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+/////////////////////////// SIMT MULTI STAGE ////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+// SIMT Multi Stage NT -+template < -+ class ElementA, -+ class ElementB, -+ class ElementC, class LayoutC, -+ class ElementAccumulator> -+struct DefaultGemmConfigurationToCutlass3Types< -+ arch::OpClassSimt, arch::Sm80, -+ ElementA, cutlass::layout::ColumnMajor, -+ ElementB, cutlass::layout::RowMajor, -+ ElementC, LayoutC, -+ ElementAccumulator> -+{ -+ using TileShape = Shape<_128, _128, _16>; -+ static constexpr int ThreadCount = 256; -+ using DispatchPolicy = MainloopSm80CpAsync<3>; -+ using TiledMma = TiledMMA< -+ MMA_Atom>, -+ Layout>, -+ Layout>, -+ Tile,Layout<_2,_16>,Underscore>>; -+ -+ // A (M,K) M-major -+ using SmemLayoutAtomA = Layout>; -+ using SmemCopyAtomA = Copy_Atom; -+ static constexpr int kAlignmentA = 2; -+ using AlignmentTypeA = cute::uint_byte_t(sizeof(ElementA)) * kAlignmentA>; -+ using GmemTiledCopyA = decltype( -+ make_tiled_copy(Copy_Atom, ElementA>{}, -+ Layout>{}, -+ Layout>{})); -+ -+ // B (N,K) N-major -+ using SmemLayoutAtomB = Layout>; -+ using SmemCopyAtomB = Copy_Atom; -+ static constexpr int kAlignmentB = 2; -+ using AlignmentTypeB = cute::uint_byte_t(sizeof(ElementB)) * kAlignmentB>; -+ using GmemTiledCopyB = decltype( -+ make_tiled_copy(Copy_Atom, ElementB>{}, -+ Layout>{}, -+ Layout>{})); -+ -+ // Mainloop -+ using CollectiveMainloop = collective::CollectiveMma< -+ DispatchPolicy, TileShape, -+ ElementA, TagToStrideA_t, -+ ElementB, TagToStrideB_t, -+ TiledMma, -+ GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A -+ GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B -+ >; -+ -+ // Epilogue -+ using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< -+ TagToStrideC_t, -+ TagToStrideC_t, -+ epilogue::thread::LinearCombination>; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+// SIMT Multi Stage TN -+template < -+ class ElementA, -+ class ElementB, -+ class ElementC, class LayoutC, -+ class ElementAccumulator> -+struct DefaultGemmConfigurationToCutlass3Types< -+ arch::OpClassSimt, arch::Sm80, -+ ElementA, cutlass::layout::RowMajor, -+ ElementB, cutlass::layout::ColumnMajor, -+ ElementC, LayoutC, -+ ElementAccumulator> -+{ -+ using TileShape = Shape<_128, _128, _16>; -+ static constexpr int ThreadCount = 256; -+ using DispatchPolicy = MainloopSm80CpAsync<3>; -+ using TiledMma = TiledMMA< -+ MMA_Atom>, -+ Layout>>; -+ -+ // A (M,K) K-major -+ using SmemLayoutAtomA = Layout, -+ Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentA -+ using SmemCopyAtomA = Copy_Atom; -+ static constexpr int kAlignmentA = 1; -+ using GmemTiledCopyA = decltype( -+ make_tiled_copy(Copy_Atom, ElementA>{}, -+ Layout, -+ Stride<_16, _1>>{})); -+ -+ // B (N,K) K-major -+ using SmemLayoutAtomB = Layout, -+ Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentB -+ using SmemCopyAtomB = Copy_Atom; -+ static constexpr int kAlignmentB = 1; -+ using GmemTiledCopyB = decltype( -+ make_tiled_copy(Copy_Atom, ElementB>{}, -+ Layout, -+ Stride<_16, _1>>{})); -+ -+ // Mainloop -+ using CollectiveMainloop = collective::CollectiveMma< -+ DispatchPolicy, TileShape, -+ ElementA, TagToStrideA_t, -+ ElementB, TagToStrideB_t, -+ TiledMma, -+ GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A -+ GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B -+ >; -+ -+ // Epilogue -+ using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< -+ TagToStrideC_t, -+ TagToStrideC_t, -+ epilogue::thread::LinearCombination>; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+// SIMT Multi Stage NN -+template < -+ class ElementA, -+ class ElementB, -+ class ElementC, class LayoutC, -+ class ElementAccumulator> -+struct DefaultGemmConfigurationToCutlass3Types< -+ arch::OpClassSimt, arch::Sm80, -+ ElementA, cutlass::layout::ColumnMajor, -+ ElementB, cutlass::layout::ColumnMajor, -+ ElementC, LayoutC, -+ ElementAccumulator> -+{ -+ using TileShape = Shape<_128, _128, _16>; -+ static constexpr int ThreadCount = 256; -+ using DispatchPolicy = MainloopSm80CpAsync<3>; -+ using TiledMma = TiledMMA< -+ MMA_Atom>, -+ Layout>, -+ Layout>, -+ Tile,Underscore,Underscore>>; -+ -+ // A (M,K) M-major -+ using SmemLayoutAtomA = Layout>; -+ using SmemCopyAtomA = Copy_Atom; -+ static constexpr int kAlignmentA = 2; -+ using AlignmentTypeA = cute::uint_byte_t(sizeof(ElementA)) * kAlignmentA>; -+ using GmemTiledCopyA = decltype( -+ make_tiled_copy(Copy_Atom, ElementA>{}, -+ Layout>{}, -+ Layout>{})); -+ -+ // B (N,K) K-major -+ using SmemLayoutAtomB = Layout, -+ Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentB -+ using SmemCopyAtomB = Copy_Atom; -+ static constexpr int kAlignmentB = 1; -+ using GmemTiledCopyB = decltype( -+ make_tiled_copy(Copy_Atom, ElementB>{}, -+ Layout, -+ Stride<_16, _1>>{})); -+ -+ // Mainloop -+ using CollectiveMainloop = collective::CollectiveMma< -+ DispatchPolicy, TileShape, -+ ElementA, TagToStrideA_t, -+ ElementB, TagToStrideB_t, -+ TiledMma, -+ GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A -+ GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B -+ >; -+ -+ // Epilogue -+ using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< -+ TagToStrideC_t, -+ TagToStrideC_t, -+ epilogue::thread::LinearCombination>; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+// SIMT Multi Stage TT -+template < -+ class ElementA, -+ class ElementB, -+ class ElementC, class LayoutC, -+ class ElementAccumulator> -+struct DefaultGemmConfigurationToCutlass3Types< -+ arch::OpClassSimt, arch::Sm80, -+ ElementA, cutlass::layout::RowMajor, -+ ElementB, cutlass::layout::RowMajor, -+ ElementC, LayoutC, -+ ElementAccumulator> -+{ -+ using TileShape = Shape<_128, _128, _16>; -+ static constexpr int ThreadCount = 256; -+ using DispatchPolicy = MainloopSm80CpAsync<3>; -+ using TiledMma = TiledMMA< -+ MMA_Atom>, -+ Layout>, -+ Layout>, -+ Tile,Underscore>>; -+ -+ // A (M,K) K-major -+ using SmemLayoutAtomA = Layout, -+ Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentA -+ using SmemCopyAtomA = Copy_Atom; -+ static constexpr int kAlignmentA = 1; -+ using GmemTiledCopyA = decltype( -+ make_tiled_copy(Copy_Atom, ElementA>{}, -+ Layout, -+ Stride<_16, _1>>{})); -+ -+ // B (N,K) N-major -+ using SmemLayoutAtomB = Layout>; -+ using SmemCopyAtomB = Copy_Atom; -+ static constexpr int kAlignmentB = 2; -+ using AlignmentTypeB = cute::uint_byte_t(sizeof(ElementB)) * kAlignmentB>; -+ using GmemTiledCopyB = decltype( -+ make_tiled_copy(Copy_Atom, ElementB>{}, -+ Layout>{}, -+ Layout>{})); -+ -+ // Mainloop -+ using CollectiveMainloop = collective::CollectiveMma< -+ DispatchPolicy, TileShape, -+ ElementA, TagToStrideA_t, -+ ElementB, TagToStrideB_t, -+ TiledMma, -+ GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A -+ GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B -+ >; -+ -+ // Epilogue -+ using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< -+ TagToStrideC_t, -+ TagToStrideC_t, -+ epilogue::thread::LinearCombination>; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+// Ampere fp64 MMA TN (K-Major A and K-Major B) -+template <> -+struct DefaultGemmConfigurationToCutlass3Types< -+ arch::OpClassTensorOp, arch::Sm80, -+ double, cutlass::layout::RowMajor, -+ double, cutlass::layout::ColumnMajor, -+ double, cutlass::layout::ColumnMajor, -+ double> -+{ -+ using TileShape = Shape<_128, _64, _16>; -+ static constexpr int ThreadCount = 128; -+ using DispatchPolicy = MainloopSm80CpAsync<3>; -+ using TiledMma = TiledMMA< -+ MMA_Atom, // Atom -+ Layout>, // Atom layout -+ Layout>, // Val layout -+ Tile,Layout<_2,_16>,Underscore>>; // Mode permutations -+ -+ // A (M,K) K-Major -+ using SmemLayoutAtomA = decltype( -+ composition(SwizzleXor<2,0,2>{}, -+ Layout, -+ Stride<_1, _4>>{})); // M, K -+ using SmemCopyAtomA = Copy_Atom; -+ static constexpr int kAlignmentA = 1; -+ using GmemTiledCopyA = decltype( -+ make_tiled_copy(Copy_Atom, double>{}, // CopyAtom -+ Layout, -+ Stride<_16, _1>>{}, // ThrLayout for CopyAtom -+ Layout>{})); // Value layout: 1x1 doubles -+ -+ // B (N,K) K-Major -+ using SmemLayoutAtomB = decltype( -+ composition(SwizzleXor<2,0,2>{}, -+ Layout, -+ Stride<_1, _4>>{})); // N, K -+ using SmemCopyAtomB = Copy_Atom; -+ static constexpr int kAlignmentB = 1; -+ using GmemTiledCopyB = decltype( -+ make_tiled_copy(Copy_Atom, double>{}, // CopyAtom -+ Layout, -+ Stride<_16, _1>>{}, // ThrLayout for CopyAtom -+ Layout>{})); // Value layout: 1x1 doubles -+ -+ // Mainloop -+ using CollectiveMainloop = collective::CollectiveMma< -+ DispatchPolicy, TileShape, -+ double, TagToStrideA_t, -+ double, TagToStrideB_t, -+ TiledMma, -+ GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A -+ GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B -+ >; -+ -+ // Epilogue -+ using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< -+ TagToStrideC_t, -+ TagToStrideC_t, -+ epilogue::thread::LinearCombination>; -+ -+/* -+ using EpilogueOutputOp = epilogue::collective::Epilogue< -+ epilogue::thread::LinearCombination, -+ Layout, -+ Stride< _1,_64>>, // SMEM layout -+ Copy_Atom,double>, // R2S with tiled_mma layout -+ decltype(make_tiled_copy(Copy_Atom,double>{},// S2R -+ Layout, -+ Stride< _1,_16>>{}, // Thread layout -+ Layout>{})), // Value layout -+ Copy_Atom,double> // R2G with S2R_dst layout -+ >; -+*/ -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+// Ampere fp64 MMA NN (M-Major A and K-Major B) -+template <> -+struct DefaultGemmConfigurationToCutlass3Types< -+ arch::OpClassTensorOp, arch::Sm80, -+ double, cutlass::layout::ColumnMajor, -+ double, cutlass::layout::ColumnMajor, -+ double, cutlass::layout::ColumnMajor, -+ double> -+{ -+ using TileShape = Shape<_128, _64, _16>; -+ static constexpr int ThreadCount = 128; -+ using DispatchPolicy = MainloopSm80CpAsync<3>; -+ using TiledMma = TiledMMA< -+ MMA_Atom, // Atom -+ Layout>, // Atom layout -+ Layout>, // Val layout -+ Tile,Layout<_2,_16>,Underscore>>; // Mode permutations -+ -+ // A (M,K) M-Major -+ using SmemLayoutAtomA = decltype( -+ composition(SwizzleXor<2,2,0>{}, -+ Layout, -+ Stride< _1,_16>>{})); // M, K -+ using SmemCopyAtomA = Copy_Atom; -+ static constexpr int kAlignmentA = 2; -+ using GmemTiledCopyA = decltype( -+ make_tiled_copy(Copy_Atom, double>{}, // CopyAtom -+ Layout, -+ Stride< _1,_16>>{}, // ThrLayout for CopyAtom -+ Layout>{})); // Value layout: 2x1 doubles -+ -+ // B (N,K) K-Major -+ using SmemLayoutAtomB = decltype( -+ composition(SwizzleXor<2,0,2>{}, -+ Layout, -+ Stride<_1, _4>>{}));// N, K -+ using SmemCopyAtomB = Copy_Atom; -+ static constexpr int kAlignmentB = 1; -+ using GmemTiledCopyB = decltype( -+ make_tiled_copy(Copy_Atom, double>{}, // CopyAtom -+ Layout, -+ Stride<_16, _1>>{}, // ThrLayout for CopyAtom -+ Layout>{})); // Value layout: 1x1 doubles -+ -+ // Mainloop -+ using CollectiveMainloop = collective::CollectiveMma< -+ DispatchPolicy, TileShape, -+ double, TagToStrideA_t, -+ double, TagToStrideB_t, -+ TiledMma, -+ GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A -+ GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B -+ >; -+ -+ // Epilogue -+ using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< -+ TagToStrideC_t, -+ TagToStrideC_t, -+ epilogue::thread::LinearCombination>; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+// Ampere fp64 MMA NT (M-Major A and N-Major B) -+template <> -+struct DefaultGemmConfigurationToCutlass3Types< -+ arch::OpClassTensorOp, arch::Sm80, -+ double, cutlass::layout::ColumnMajor, -+ double, cutlass::layout::RowMajor, -+ double, cutlass::layout::ColumnMajor, -+ double> -+{ -+ using TileShape = Shape<_128, _64, _16>; -+ static constexpr int ThreadCount = 128; -+ using DispatchPolicy = MainloopSm80CpAsync<3>; -+ using TiledMma = TiledMMA< -+ MMA_Atom, // Atom -+ Layout>, // Atom layout -+ Layout>, // Val layout -+ Tile,Layout<_2,_16>,Underscore>>; // Mode permutations -+ -+ // A (M,K) M-Major -+ using SmemLayoutAtomA = decltype( -+ composition(SwizzleXor<2,2,0>{}, -+ Layout, -+ Stride< _1,_16>>{})); // M, K -+ using SmemCopyAtomA = Copy_Atom; -+ static constexpr int kAlignmentA = 2; -+ using GmemTiledCopyA = decltype( -+ make_tiled_copy(Copy_Atom, double>{}, // CopyAtom -+ Layout, -+ Stride< _1,_16>>{}, // ThrLayout for CopyAtom -+ Layout>{})); // Value layout: 2x1 doubles -+ -+ // B (N,K) N-Major -+ using SmemLayoutAtomB = decltype( -+ composition(SwizzleXor<2,2,0>{}, -+ Layout, -+ Stride< _1,_16>>{})); // N, K -+ using SmemCopyAtomB = Copy_Atom; -+ static constexpr int kAlignmentB = 2; -+ using GmemTiledCopyB = decltype( -+ make_tiled_copy(Copy_Atom, double>{}, // CopyAtom -+ Layout, -+ Stride< _1,_16>>{}, // ThrLayout for CopyAtom -+ Layout>{})); // Value layout: 2x1 doubles -+ -+ // Mainloop -+ using CollectiveMainloop = collective::CollectiveMma< -+ DispatchPolicy, TileShape, -+ double, TagToStrideA_t, -+ double, TagToStrideB_t, -+ TiledMma, -+ GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A -+ GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B -+ >; -+ -+ // Epilogue -+ using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< -+ TagToStrideC_t, -+ TagToStrideC_t, -+ epilogue::thread::LinearCombination>; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+// Ampere fp64 MMA TT (K-Major A and N-Major B) -+template <> -+struct DefaultGemmConfigurationToCutlass3Types< -+ arch::OpClassTensorOp, arch::Sm80, -+ double, cutlass::layout::RowMajor, -+ double, cutlass::layout::RowMajor, -+ double, cutlass::layout::ColumnMajor, -+ double> -+{ -+ using TileShape = Shape<_128, _64, _16>; -+ static constexpr int ThreadCount = 128; -+ using DispatchPolicy = MainloopSm80CpAsync<3>; -+ using TiledMma = TiledMMA< -+ MMA_Atom, // Atom -+ Layout>, // Atom layout -+ Layout>, // Val layout -+ Tile,Layout<_2,_16>,Underscore>>; // Mode permutations -+ -+ // A (M,K) K-Major -+ using SmemLayoutAtomA = decltype( -+ composition(SwizzleXor<2,0,2>{}, -+ Layout, -+ Stride<_1, _4>>{})); // M, K -+ using SmemCopyAtomA = Copy_Atom; -+ static constexpr int kAlignmentA = 1; -+ using GmemTiledCopyA = decltype( -+ make_tiled_copy(Copy_Atom, double>{}, // CopyAtom -+ Layout, -+ Stride<_16, _1>>{}, // ThrLayout for CopyAtom -+ Layout>{})); // Value layout: 1x1 doubles -+ -+ // B (N,K) N-Major -+ using SmemLayoutAtomB = decltype( -+ composition(SwizzleXor<2,2,0>{}, -+ Layout, -+ Stride< _1,_16>>{})); // N, K -+ using SmemCopyAtomB = Copy_Atom; -+ static constexpr int kAlignmentB = 2; -+ using GmemTiledCopyB = decltype( -+ make_tiled_copy(Copy_Atom, double>{}, // CopyAtom -+ Layout, -+ Stride< _1,_16>>{}, // ThrLayout for CopyAtom -+ Layout>{})); // Value layout: 2x1 doubles -+ -+ // Mainloop -+ using CollectiveMainloop = collective::CollectiveMma< -+ DispatchPolicy, TileShape, -+ double, TagToStrideA_t, -+ double, TagToStrideB_t, -+ TiledMma, -+ GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A -+ GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B -+ >; -+ -+ // Epilogue -+ using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< -+ TagToStrideC_t, -+ TagToStrideC_t, -+ epilogue::thread::LinearCombination>; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+// Hopper fp64 MMA TN -+template <> -+struct DefaultGemmConfigurationToCutlass3Types< -+ arch::OpClassTensorOp, arch::Sm90, -+ double, cutlass::layout::RowMajor, -+ double, cutlass::layout::ColumnMajor, -+ double, cutlass::layout::ColumnMajor, -+ double> -+{ -+ using TileShape = Shape<_128, _64, _16>; -+ static constexpr int ThreadCount = 128; -+ using DispatchPolicy = MainloopSm80CpAsync<3>; -+ using TiledMma = TiledMMA< -+ MMA_Atom, -+ Layout>>; -+ -+ // A (M,K) K-major -+ using SmemLayoutAtomA = decltype( -+ make_ordered_layout(Shape<_128,_16>{}, -+ Step < _2, _1>{})); // M, K -+ using SmemCopyAtomA = Copy_Atom; -+ static constexpr int kAlignmentA = 2; -+ using GmemTiledCopyA = decltype( -+ make_tiled_copy(Copy_Atom, double>{}, -+ Layout, -+ Stride< _8,_1>>{}, -+ Layout>{})); -+ -+ // B (N,K) K-major -+ using SmemLayoutAtomB = decltype( -+ make_ordered_layout(Shape<_64,_16>{}, -+ Step < _2, _1>{})); // N, K -+ using SmemCopyAtomB = Copy_Atom; -+ static constexpr int kAlignmentB = 2; -+ using GmemTiledCopyB = decltype( -+ make_tiled_copy(Copy_Atom, double>{}, -+ Layout, -+ Stride< _8,_1>>{}, -+ Layout>{})); -+ -+ // Mainloop -+ using CollectiveMainloop = collective::CollectiveMma< -+ DispatchPolicy, TileShape, -+ double, TagToStrideA_t, -+ double, TagToStrideB_t, -+ TiledMma, -+ GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A -+ GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B -+ >; -+ -+ // Epilogue -+ using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< -+ TagToStrideC_t, -+ TagToStrideC_t, -+ epilogue::thread::LinearCombination>; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace cutlass -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..45c1d80 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm75.cu -@@ -0,0 +1,233 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 128x256x512_64x64x512) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 256x128x512_64x64x512) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 128x128x512_64x64x512) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 64x256x512_64x64x512) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 256, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 256x64x512_64x64x512) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 64, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 64x128x512_32x64x512) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 512>, -+ cutlass::gemm::GemmShape<32, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 128x64x512_64x32x512) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 512>, -+ cutlass::gemm::GemmShape<64, 32, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 64x64x512_32x32x512) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<32, 32, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm80.cu -new file mode 100644 -index 0000000..67bcb85 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm80.cu -@@ -0,0 +1,379 @@ -+/************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x256x1024_64x64x1024) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 1024>, -+ cutlass::gemm::GemmShape<64, 64, 1024>, -+ cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 256x128x1024_64x64x1024) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 1024>, -+ cutlass::gemm::GemmShape<64, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x128x1024_64x64x1024) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 1024>, -+ cutlass::gemm::GemmShape<64, 64, 1024>, -+ cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 256x64x1024_64x64x1024) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 1024>, -+ cutlass::gemm::GemmShape<64, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x256x1024_64x64x1024) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 1024>, -+ cutlass::gemm::GemmShape<64, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x128x1024_32x64x1024) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 1024>, -+ cutlass::gemm::GemmShape<32, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x64x1024_64x32x1024) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 1024>, -+ cutlass::gemm::GemmShape<64, 32, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x64x1024_32x32x1024) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 1024>, -+ cutlass::gemm::GemmShape<32, 32, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x256x512_64x64x512) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 256x128x512_64x64x512) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x128x512_64x64x512) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 256x64x512_64x64x512) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x256x512_64x64x512) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x128x512_32x64x512) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 512>, -+ cutlass::gemm::GemmShape<32, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x64x512_64x32x512) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 512>, -+ cutlass::gemm::GemmShape<64, 32, 512>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x64x512_32x32x512) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<32, 32, 512>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_b1t_b1n_s32n_wmma_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_b1t_b1n_s32n_wmma_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..6c8ab54 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_b1t_b1n_s32n_wmma_tensor_op_s32_sm75.cu -@@ -0,0 +1,243 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+////// WMMA Instruction Shape = 8x8x128, DataType/Instruction = b1 ^ b1 + s32 => s32 ///////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 128x256x512_64x64x512_8x8x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, -+ cutlass::layout::RowMajor, -+ cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, 128, 128, false, -+ cutlass::arch::OpXorPopc -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 256x128x512_64x64x512_8x8x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, -+ cutlass::layout::RowMajor, -+ cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, 128, 128, false, -+ cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 128x128x512_64x64x512_8x8x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, -+ cutlass::layout::RowMajor, -+ cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, 128, 128, false, -+ cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 64x128x512_32x64x512_8x8x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, -+ cutlass::layout::RowMajor, -+ cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 512>, -+ cutlass::gemm::GemmShape<32, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, 128, 128, false, -+ cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 128x64x512_64x32x512_8x8x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, -+ cutlass::layout::RowMajor, -+ cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 512>, -+ cutlass::gemm::GemmShape<64, 32, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, 128, 128, false, -+ cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 64x64x512_32x32x512_8x8x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, -+ cutlass::layout::RowMajor, -+ cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<32, 32, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, 128, 128, false, -+ cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+#endif //CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_b1t_b1n_s32t_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_b1t_b1n_s32t_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..445fa88 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_b1t_b1n_s32t_tensor_op_s32_sm75.cu -@@ -0,0 +1,232 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 128x256x512_64x64x512) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 256x128x512_64x64x512) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 128x128x512_64x64x512) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 64x256x512_64x64x512) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 256, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 256x64x512_64x64x512) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 64, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 64x128x512_32x64x512) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 512>, -+ cutlass::gemm::GemmShape<32, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 128x64x512_64x32x512) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 512>, -+ cutlass::gemm::GemmShape<64, 32, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 64x64x512_32x32x512) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<32, 32, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_b1t_b1n_s32t_tensor_op_s32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_b1t_b1n_s32t_tensor_op_s32_sm80.cu -new file mode 100644 -index 0000000..c819148 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_b1t_b1n_s32t_tensor_op_s32_sm80.cu -@@ -0,0 +1,380 @@ -+/************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 128x256x1024_64x64x1024, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 1024>, -+ cutlass::gemm::GemmShape<64, 64, 1024>, -+ cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 256x128x1024_64x64x1024, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 1024>, -+ cutlass::gemm::GemmShape<64, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 128x128x1024_64x64x1024, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 1024>, -+ cutlass::gemm::GemmShape<64, 64, 1024>, -+ cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 256x64x1024_64x64x1024, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 1024>, -+ cutlass::gemm::GemmShape<64, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 64x256x1024_64x64x1024, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 1024>, -+ cutlass::gemm::GemmShape<64, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 64x128x1024_32x64x1024, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 1024>, -+ cutlass::gemm::GemmShape<32, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 128x64x1024_64x32x1024, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 1024>, -+ cutlass::gemm::GemmShape<64, 32, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 64x64x1024_32x32x1024, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 1024>, -+ cutlass::gemm::GemmShape<32, 32, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 128x256x512_64x64x512, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 256x128x512_64x64x512, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 128x128x512_64x64x512, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 256x64x512_64x64x512, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 64x256x512_64x64x512, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 64x128x512_32x64x512, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 512>, -+ cutlass::gemm::GemmShape<32, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 128x64x512_64x32x512, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 512>, -+ cutlass::gemm::GemmShape<64, 32, 512>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_XOR_b1t_b1n_s32t_tensor_op_s32, 64x64x512_32x32x512, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<32, 32, 512>, cutlass::gemm::GemmShape<16, 8, 256>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6, 128, 128, -+ false, cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_b1t_b1n_s32t_wmma_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_b1t_b1n_s32t_wmma_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..755661f ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_b1t_b1n_s32t_wmma_tensor_op_s32_sm75.cu -@@ -0,0 +1,243 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+////// WMMA Instruction Shape = 8x8x128, DataType/Instruction = b1 ^ b1 + s32 => s32 ///////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 128x256x512_64x64x512_8x8x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, -+ cutlass::layout::RowMajor, -+ cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, 128, 128, false, -+ cutlass::arch::OpXorPopc -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 256x128x512_64x64x512_8x8x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, -+ cutlass::layout::RowMajor, -+ cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, 128, 128, false, -+ cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 128x128x512_64x64x512_8x8x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, -+ cutlass::layout::RowMajor, -+ cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, 128, 128, false, -+ cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 64x128x512_32x64x512_8x8x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, -+ cutlass::layout::RowMajor, -+ cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 512>, -+ cutlass::gemm::GemmShape<32, 64, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, 128, 128, false, -+ cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 128x64x512_64x32x512_8x8x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, -+ cutlass::layout::RowMajor, -+ cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 512>, -+ cutlass::gemm::GemmShape<64, 32, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, 128, 128, false, -+ cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 64x64x512_32x32x512_8x8x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::uint1b_t, -+ cutlass::layout::RowMajor, -+ cutlass::uint1b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<32, 32, 512>, -+ cutlass::gemm::GemmShape<8, 8, 128>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, 128, 128, false, -+ cutlass::arch::OpXorPopc>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+#endif //CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_bf16n_bf16n_f32t_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_bf16n_bf16n_f32t_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..a25d8aa ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_bf16n_bf16n_f32t_tensor_op_f32_sm80.cu -@@ -0,0 +1,359 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 128x256x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, -+ cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 256x128x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, -+ cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 128x128x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, -+ cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 256x64x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, -+ cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 64x256x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, -+ cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 64x128x64_32x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, -+ cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 128x64x64_64x32x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, -+ cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 64x64x64_32x32x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, -+ cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 128x256x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, -+ cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 256x128x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, -+ cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 128x128x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, -+ cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 256x64x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, -+ cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 64x256x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, -+ cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 64x128x32_32x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, -+ cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 128x64x32_64x32x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, -+ cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32, 64x64x32_32x32x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, -+ cutlass::bfloat16_t, cutlass::layout::ColumnMajor, ElementOutput, -+ cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_bf16t_bf16t_bf16t_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_bf16t_bf16t_bf16t_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..3dfd4f1 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_bf16t_bf16t_bf16t_tensor_op_f32_sm80.cu -@@ -0,0 +1,343 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x256x64_64x64x64) { -+ using ElementOutput = cutlass::bfloat16_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 256x128x64_64x64x64) { -+ using ElementOutput = cutlass::bfloat16_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x128x64_64x64x64) { -+ using ElementOutput = cutlass::bfloat16_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 256x64x64_64x64x64) { -+ using ElementOutput = cutlass::bfloat16_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x256x64_64x64x64) { -+ using ElementOutput = cutlass::bfloat16_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x128x64_32x64x64) { -+ using ElementOutput = cutlass::bfloat16_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x64x64_64x32x64) { -+ using ElementOutput = cutlass::bfloat16_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x64x64_32x32x64) { -+ using ElementOutput = cutlass::bfloat16_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x256x32_64x64x32) { -+ using ElementOutput = cutlass::bfloat16_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 256x128x32_64x64x32) { -+ using ElementOutput = cutlass::bfloat16_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x128x32_64x64x32) { -+ using ElementOutput = cutlass::bfloat16_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 256x64x32_64x64x32) { -+ using ElementOutput = cutlass::bfloat16_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x256x32_64x64x32) { -+ using ElementOutput = cutlass::bfloat16_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x128x32_32x64x32) { -+ using ElementOutput = cutlass::bfloat16_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x64x32_64x32x32) { -+ using ElementOutput = cutlass::bfloat16_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x64x32_32x32x32) { -+ using ElementOutput = cutlass::bfloat16_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32_sm80.cu -new file mode 100644 -index 0000000..00c64b7 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32_sm80.cu -@@ -0,0 +1,259 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_complex.h" -+ -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// Operands data type: complex -+// Rounding: float -> tfloat32_t (half_ulp_truncate) -+// Instruction operand data type: tfloat32_t (real part) and tfloat32_t (imaginary part) -+// Math instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 -+// Instruction output/accumulation data type: f32 (real part) and f32 (imaginary part) -+// Output data type: complex -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32, 32x32x16_16x16x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32, 64x64x16_16x32x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32, 64x64x16_32x32x16) { -+ -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32, 128x64x16_64x32x16) { -+ -+ using Element = cutlass::complex;; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32, 64x128x16_32x64x16) { -+ -+ using Element = cutlass::complex;; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32, 128x128x16_32x64x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32_sm80.cu -new file mode 100644 -index 0000000..146e2ec ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32_sm80.cu -@@ -0,0 +1,258 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_complex.h" -+ -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// Operands data type: complex -+// Rounding: float -> tfloat32_t (round to nearest) -+// Instruction operand data type: tfloat32_t (real part) and tfloat32_t (imaginary part) -+// Math instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 -+// Instruction output/accumulation data type: f32 (real part) and f32 (imaginary part) -+// Output data type: complex -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32, 32x32x16_16x16x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32, 64x64x16_16x32x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32, 64x64x16_32x32x16) { -+ -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32, 128x64x16_64x32x16) { -+ -+ using Element = cutlass::complex;; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32, 64x128x16_32x64x16) { -+ -+ using Element = cutlass::complex;; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32, 128x128x16_32x64x16) { -+ -+ using Element = cutlass::complex;; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm80.cu -new file mode 100644 -index 0000000..9164326 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm80.cu -@@ -0,0 +1,198 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_complex.h" -+ -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian, 32x32x16_16x16x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian, 32x32x8_16x16x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<16, 16, 8>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian, 64x64x16_16x32x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian, 64x64x8_16x32x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<16, 32, 8>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm90.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm90.cu -new file mode 100644 -index 0000000..cc94303 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm90.cu -@@ -0,0 +1,198 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface with Hopper FP64 -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_complex.h" -+ -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian, 32x32x16_16x16x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian, 32x32x8_16x16x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<16, 16, 8>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian, 64x64x16_16x32x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian, 64x64x8_16x32x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<16, 32, 8>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm80.cu -new file mode 100644 -index 0000000..d93f3fb ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm80.cu -@@ -0,0 +1,252 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_complex.h" -+ -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 32x32x8_16x16x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<16, 16, 8>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 64x64x16_16x32x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 64x64x8_16x32x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<16, 32, 8>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 64x64x8_32x32x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm90.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm90.cu -new file mode 100644 -index 0000000..e2931b0 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm90.cu -@@ -0,0 +1,252 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface with Hopper FP64 -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_complex.h" -+ -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 32x32x8_16x16x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<16, 16, 8>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 64x64x16_16x32x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 64x64x8_16x32x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<16, 32, 8>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 64x64x8_32x32x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian_sm80.cu -new file mode 100644 -index 0000000..60ec7a8 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian_sm80.cu -@@ -0,0 +1,197 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_complex.h" -+ -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian, 32x32x8_16x16x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<16, 16, 8>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian, 64x64x8_32x16x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<32, 16, 8>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian, 32x32x16_16x16x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian, 64x64x16_32x16x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian_sm90.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian_sm90.cu -new file mode 100644 -index 0000000..eb011e4 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian_sm90.cu -@@ -0,0 +1,197 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface with Hopper FP64 -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_complex.h" -+ -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian, 32x32x8_16x16x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<16, 16, 8>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian, 64x64x8_32x16x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<32, 16, 8>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian, 32x32x16_16x16x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian, 64x64x16_32x16x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_sm80.cu -new file mode 100644 -index 0000000..d9d171b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_sm80.cu -@@ -0,0 +1,305 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_complex.h" -+ -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 32x32x8_16x16x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<16, 16, 8>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 64x64x8_32x32x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 64x128x8_32x32x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 8>, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 128x64x8_32x32x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 8>, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 64x128x16_32x32x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM80_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 128x64x16_32x32x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_sm90.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_sm90.cu -new file mode 100644 -index 0000000..c0333e7 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_sm90.cu -@@ -0,0 +1,305 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface with Hopper FP64 -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_complex.h" -+ -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_complex.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 32x32x8_16x16x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<16, 16, 8>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 64x64x8_32x32x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 64x128x8_32x32x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 128, 8>, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 128x64x8_32x32x8) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 128, 8>, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 64x128x16_32x32x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 128x64x16_32x32x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16n_direct_store_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16n_direct_store_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..e98764e ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16n_direct_store_tensor_op_f32_sm80.cu -@@ -0,0 +1,114 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/kernel/gemm_universal.h" -+#include "cutlass/gemm/device/gemm_universal.h" -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_universal.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#include "cutlass/epilogue/threadblock/epilogue_direct_store.h" -+#include "cutlass/epilogue/threadblock/default_epilogue_direct_store.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmUniversal_DirectStore_f16n_f16t_f32n_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ // Define the GEMM kernel -+ using GemmBase = cutlass::gemm::device::GemmUniversal< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 4, // This is the vector size of the epilogue. -+ ElementAccumulator, -+ ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3, -+ 8, -+ 8 -+ >; -+ -+ // Define the direct store epilogue -+ using EpilogueDirectStore = typename cutlass::epilogue::threadblock::DefaultEpilogueDirectStore< -+ typename GemmBase::GemmKernel::Epilogue -+ >::Epilogue; -+ -+ // Define a new kernel -+ using Kernel = cutlass::gemm::kernel::GemmUniversal< -+ typename GemmBase::GemmKernel::Mma, -+ EpilogueDirectStore, -+ typename GemmBase::GemmKernel::ThreadblockSwizzle -+ >; -+ -+ // Define the adaptor -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16n_wmma_tensor_op_f16_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16n_wmma_tensor_op_f16_sm70.cu -new file mode 100644 -index 0000000..4fc49a0 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16n_wmma_tensor_op_f16_sm70.cu -@@ -0,0 +1,157 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_16x16x16) { -+ // single cta, two warps horizontally two waprs vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16n_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16n_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16n_wmma_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16n_wmma_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..f4912ee ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16n_wmma_tensor_op_f32_sm70.cu -@@ -0,0 +1,154 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F32=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16n_f16n_wmma_tensor_op_f32, 128x128x32_64x64x32_16x16x16) { -+ // single cta, two warps horizontally two waprs vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16n_f16n_wmma_tensor_op_f32, 64x64x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16n_f16n_wmma_tensor_op_f32, 64x64x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sm75.cu -new file mode 100644 -index 0000000..2808f9d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sm75.cu -@@ -0,0 +1,307 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x256x32_64x64x32_brief) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 32> -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x128x32_64x64x32_brief) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32> -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x128x32_32x64x32_brief) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32> -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..274d41f ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sm80.cu -@@ -0,0 +1,344 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x256x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 256x128x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x128x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 256x64x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x256x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x128x64_32x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x64x64_64x32x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x64x64_32x32x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x256x32_64x64x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 256x128x32_64x64x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x128x32_64x64x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 256x64x32_64x64x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x256x32_64x64x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x128x32_32x64x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x64x32_64x32x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x64x32_32x32x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sparse_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sparse_sm80.cu -new file mode 100644 -index 0000000..c7c894b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sparse_sm80.cu -@@ -0,0 +1,272 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_sparse.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_sparse.h" -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x256x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 256x128x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x128x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 256x64x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x256x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x128x64_32x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x64x64_64x32x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x64x64_32x32x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x128x128_64x64x128) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 256x64x128_64x64x128) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x64x128_64x32x128) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x64x128_32x32x128) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16t_volta_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16t_volta_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..9bdf56d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16t_volta_tensor_op_f32_sm70.cu -@@ -0,0 +1,274 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_volta_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_volta_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_volta_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_volta_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_volta_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_volta_tensor_op_f32, 64x64x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_volta_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16t_wmma_tensor_op_f16_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16t_wmma_tensor_op_f16_sm70.cu -new file mode 100644 -index 0000000..411e95b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16t_wmma_tensor_op_f16_sm70.cu -@@ -0,0 +1,404 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 64x128x32_64x64x32_16x16x16) { -+ // single cta, two warps horizontally -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 128x64x32_64x64x32_16x16x16) { -+ // single cta, two warps vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_16x16x16) { -+ // single cta, two warps horizontally two waprs vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 128x256x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 256x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 128x64x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 64x128x32_32x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 64x64x32_32x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16t_wmma_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16t_wmma_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..c99f75a ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f16t_wmma_tensor_op_f32_sm70.cu -@@ -0,0 +1,403 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F32=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 64x128x32_64x64x32_16x16x16) { -+ // single cta, two warps horizontally -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 128x64x32_64x64x32_16x16x16) { -+ // single cta, two warps vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 128x128x32_64x64x32_16x16x16) { -+ // single cta, two warps horizontally two waprs vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 128x256x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 256x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 128x64x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 64x128x32_32x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16n_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32n_tensor_op_f32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32n_tensor_op_f32_sm75.cu -new file mode 100644 -index 0000000..6bf8ae5 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32n_tensor_op_f32_sm75.cu -@@ -0,0 +1,307 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x256x32_64x64x32_brief) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 32> -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x128x32_64x64x32_brief) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32> -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x128x32_32x64x32_brief) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32> -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32n_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32n_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..5c398c3 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32n_tensor_op_f32_sm80.cu -@@ -0,0 +1,343 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x256x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 256x128x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x128x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 256x64x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x256x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x128x64_32x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x64x64_64x32x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x64x64_32x32x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x256x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 256x128x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x128x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 256x64x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x256x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x128x32_32x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 128x64x32_64x32x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32, 64x64x32_32x32x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32n_wmma_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32n_wmma_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..4ff878b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32n_wmma_tensor_op_f32_sm70.cu -@@ -0,0 +1,159 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f32n_wmma_tensor_op_f32, 256x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16n_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16n_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sm75.cu -new file mode 100644 -index 0000000..a123d29 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sm75.cu -@@ -0,0 +1,307 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x256x32_64x64x32_brief) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 32> -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x128x32_64x64x32_brief) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32> -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x128x32_32x64x32_brief) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32> -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..4811c9d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sm80.cu -@@ -0,0 +1,346 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x256x64_64x64x64, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 256x128x64_64x64x64, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x128x64_64x64x64, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 256x64x64_64x64x64, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x256x64_64x64x64, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x128x64_32x64x64, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x64x64_64x32x64, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x64x64_32x32x64, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x256x32_64x64x32, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 256x128x32_64x64x32, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x128x32_64x64x32, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 256x64x32_64x64x32, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x256x32_64x64x32, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x128x32_32x64x32, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x64x32_64x32x32, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x64x32_32x32x32, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu -new file mode 100644 -index 0000000..4c7ce79 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu -@@ -0,0 +1,273 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_sparse.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_sparse.h" -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x256x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 256x128x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x128x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 256x64x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x256x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x128x64_32x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x64x64_64x32x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x64x64_32x32x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x128x128_64x64x128) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 256x64x128_64x64x128) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x64x128_64x32x128) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x64x128_32x32x128) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32t_volta_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32t_volta_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..bba2b6d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32t_volta_tensor_op_f32_sm70.cu -@@ -0,0 +1,274 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f32t_volta_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f32t_volta_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f32t_volta_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f32t_volta_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f32t_volta_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f32t_volta_tensor_op_f32, 64x64x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f32t_volta_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32t_wmma_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32t_wmma_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..350d7e9 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16n_f32t_wmma_tensor_op_f32_sm70.cu -@@ -0,0 +1,344 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 128x256x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 256x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 128x64x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 64x128x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16n_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16n_wmma_tensor_op_f16_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16n_wmma_tensor_op_f16_sm70.cu -new file mode 100644 -index 0000000..eff38fd ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16n_wmma_tensor_op_f16_sm70.cu -@@ -0,0 +1,157 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_16x16x16) { -+ // single cta, two warps horizontally two waprs vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16t_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16t_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16n_wmma_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16n_wmma_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..ddd426d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16n_wmma_tensor_op_f32_sm70.cu -@@ -0,0 +1,155 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F32=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16t_f16n_wmma_tensor_op_f32, 128x128x32_64x64x32_16x16x16) { -+ // single cta, two warps horizontally two waprs vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16t_f16n_wmma_tensor_op_f32, 64x64x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16t_f16n_wmma_tensor_op_f32, 64x64x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_slicedk_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_slicedk_sm75.cu -new file mode 100644 -index 0000000..7392cf9 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_slicedk_sm75.cu -@@ -0,0 +1,88 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_f16n_f16t_f16t_tensor_op_f16_sliced_k, 64x64x64_64x32x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if (CUTLASS_ARCH_MMA_SM75_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_slicedk_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_slicedk_sm80.cu -new file mode 100644 -index 0000000..468e698 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_slicedk_sm80.cu -@@ -0,0 +1,88 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16_sliced_k, 128x64x64_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sm75.cu -new file mode 100644 -index 0000000..5fd2fb7 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sm75.cu -@@ -0,0 +1,243 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x256x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 256x128x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x128x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x128x32_32x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x64x32_64x32x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x64x32_32x32x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sm80.cu -new file mode 100644 -index 0000000..90fd6d0 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sm80.cu -@@ -0,0 +1,344 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x256x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 256x128x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x128x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 256x64x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x256x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64> , -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x128x64_32x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x64x64_64x32x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x64x64_32x32x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x256x32_64x64x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 256x128x32_64x64x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x128x32_64x64x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 256x64x32_64x64x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x256x32_64x64x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x128x32_32x64x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x64x32_64x32x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x64x32_32x32x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sparse_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sparse_sm80.cu -new file mode 100644 -index 0000000..ebe3acb ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sparse_sm80.cu -@@ -0,0 +1,271 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_sparse.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_sparse.h" -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x256x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 256x128x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x128x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 256x64x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x256x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64> , -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x128x64_32x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x64x64_64x32x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x64x64_32x32x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x128x128_64x64x128) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 256x64x128_64x64x128) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x64x128_64x32x128) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x64x128_32x32x128) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..a94f4ac ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f32_sm80.cu -@@ -0,0 +1,81 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/device/gemm_universal.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmUniversal_f16n_f16t_f32t_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle, 10>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_volta_tensor_op_f16_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_volta_tensor_op_f16_sm70.cu -new file mode 100644 -index 0000000..969f54b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_volta_tensor_op_f16_sm70.cu -@@ -0,0 +1,267 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f16t_volta_tensor_op_f16, 128x256x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f16t_volta_tensor_op_f16, 256x128x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f16t_volta_tensor_op_f16, 128x128x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f16t_volta_tensor_op_f16, 128x64x32_64x32x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f16t_volta_tensor_op_f16, 64x128x32_32x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f16t_volta_tensor_op_f16, 64x64x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f16t_volta_tensor_op_f16, 64x64x32_32x32x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_wmma_tensor_op_f16_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_wmma_tensor_op_f16_sm70.cu -new file mode 100644 -index 0000000..ca8eeeb ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_wmma_tensor_op_f16_sm70.cu -@@ -0,0 +1,405 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 64x128x32_64x64x32_16x16x16) { -+ // single cta, two warps horizontally -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 128x64x32_64x64x32_16x16x16) { -+ // single cta, two warps vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_16x16x16) { -+ // single cta, two warps horizontally two waprs vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 128x256x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 256x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 128x64x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 64x128x32_32x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 64x64x32_32x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_wmma_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_wmma_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..8b15b0d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f16t_wmma_tensor_op_f32_sm70.cu -@@ -0,0 +1,87 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F32=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16t_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f32n_wmma_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f32n_wmma_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..fe7e1b0 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f32n_wmma_tensor_op_f32_sm70.cu -@@ -0,0 +1,159 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16t_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16t_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sm75.cu -new file mode 100644 -index 0000000..3b3d293 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sm75.cu -@@ -0,0 +1,243 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..7387b99 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sm80.cu -@@ -0,0 +1,384 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+#include "testbed_universal.h" -+ -+#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x256x64_64x64x64, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x128x64_64x64x64, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x128x64_64x64x64, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x128x64_64x64x64_sk, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversal< -+ cutlass::half_t, cutlass::layout::ColumnMajor, -+ cutlass::half_t, cutlass::layout::RowMajor, -+ ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32n_tensor_op_f32, 128x128x64_64x64x64_sk, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversal< -+ cutlass::half_t, cutlass::layout::ColumnMajor, -+ cutlass::half_t, cutlass::layout::RowMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x64x64_64x64x64, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x256x64_64x64x64, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x128x64_32x64x64, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x64x64_64x32x64, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x64x64_32x32x64, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x256x32_64x64x32, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x128x32_64x64x32, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x128x32_64x64x32, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x64x32_64x64x32, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x256x32_64x64x32, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x128x32_32x64x32, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x64x32_64x32x32, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x64x32_32x32x32, { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sparse_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sparse_sm80.cu -new file mode 100644 -index 0000000..60a4760 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sparse_sm80.cu -@@ -0,0 +1,272 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_sparse.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_sparse.h" -+ -+#if (CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x256x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x128x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x128x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x64x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x256x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x128x64_32x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x64x64_64x32x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x64x64_32x32x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x128x128_64x64x128) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x64x128_64x64x128) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x64x128_64x32x128) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x64x128_32x32x128) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f32t_volta_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f32t_volta_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..02c4ddf ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f32t_volta_tensor_op_f32_sm70.cu -@@ -0,0 +1,267 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f32t_volta_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f32t_volta_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f32t_volta_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f32t_volta_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f32t_volta_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f32t_volta_tensor_op_f32, 64x64x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f32t_volta_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if (CUTLASS_ENABLE_TENSOR_CORE_MMA) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f32t_wmma_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f32t_wmma_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..ef37420 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16n_f16t_f32t_wmma_tensor_op_f32_sm70.cu -@@ -0,0 +1,344 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 128x256x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 256x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 128x64x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 64x128x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16n_f16t_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16_sm70.cu -new file mode 100644 -index 0000000..cf6d5dd ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16_sm70.cu -@@ -0,0 +1,321 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 128x256x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 128x64x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 64x128x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 64x64x32_32x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 64x64x64_32x32x64_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 128x128x64_64x32x64_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 128x128x32_64x32x32_32x8x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f16n_singlestage_wmma_tensor_op_f16, 128x128x32_64x32x32_8x32x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16n_wmma_tensor_op_f16_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16n_wmma_tensor_op_f16_sm70.cu -new file mode 100644 -index 0000000..9156d8e ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16n_wmma_tensor_op_f16_sm70.cu -@@ -0,0 +1,157 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_16x16x16) { -+ // single cta, two warps horizontally two waprs vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16n_wmma_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16n_wmma_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..8bfed3f ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16n_wmma_tensor_op_f32_sm70.cu -@@ -0,0 +1,155 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F32=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f16n_wmma_tensor_op_f32, 128x128x32_64x64x16_16x16x16) { -+ // single cta, two warps horizontally two waprs vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f16n_wmma_tensor_op_f32, 64x64x32_64x64x16_32x8x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f16n_wmma_tensor_op_f32, 64x64x32_64x64x16_8x32x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16_sm70.cu -new file mode 100644 -index 0000000..9e5973a ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16_sm70.cu -@@ -0,0 +1,321 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 128x256x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 128x64x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 64x128x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 64x64x32_32x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 64x64x64_32x32x64_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 128x128x64_64x32x64_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 128x128x32_64x32x32_32x8x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_singlestage_wmma_tensor_op_f16, 128x128x32_64x32x32_8x32x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_broadcast_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_broadcast_sm80.cu -new file mode 100644 -index 0000000..b69a304 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_broadcast_sm80.cu -@@ -0,0 +1,440 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for GEMM + broadcast interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/functional.h" -+ -+#include "cutlass/gemm/kernel/default_gemm_with_broadcast.h" -+#include "cutlass/gemm/device/gemm_universal.h" -+#include "cutlass/gemm/device/gemm_universal_with_broadcast.h" -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+ -+#include "cutlass/epilogue/thread/activation.h" -+#include "cutlass/epilogue/thread/linear_combination_bias_relu.h" -+#include "cutlass/epilogue/thread/linear_combination_residual_block.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_elementwise.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+template -+struct TestbedUtils { -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A; // Input A -+ cutlass::HostTensor tensor_B; // Input B -+ cutlass::HostTensor tensor_C; // Input C -+ cutlass::HostTensor tensor_D1; // Input D -+ cutlass::HostTensor tensor_D2; // Input D -+ cutlass::HostTensor tensor_Y1; // Input Y -+ cutlass::HostTensor tensor_Y2; // Input Y -+ cutlass::HostTensor tensor_Y_ref; -+ -+ // -+ // Methods -+ // -+ -+ TestbedUtils( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ double scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope_max, scope_min, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::AllZeros) { -+ cutlass::reference::host::TensorFill(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else { -+ // TODO: Implement the rest -+ EXPECT_TRUE(false) << "Not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Initializes data structures -+ void initialize(cutlass::gemm::GemmCoord problem_size) { -+ // -+ // Allocate the GEMM workspace -+ // -+ -+ tensor_A.resize(problem_size.mk()); -+ tensor_B.resize(problem_size.kn()); -+ tensor_C.resize({1, problem_size.n()}); -+ tensor_D1.resize(problem_size.mn()); -+ tensor_D2.resize(problem_size.mn()); -+ tensor_Y1.resize(problem_size.mn()); -+ tensor_Y2.resize(problem_size.mn()); -+ tensor_Y_ref.resize(problem_size.mn()); -+ -+ EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); -+ EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); -+ EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); -+ -+ // Initialize D data to smaller data range. This helps avoid large roundoff errors. -+ int d_scope_min = -2; -+ int d_scope_max = 2; -+ cutlass::reference::host::TensorFillRandomUniform(tensor_D1.host_view(), seed + 2016, d_scope_max, d_scope_min, 0); -+ cutlass::reference::host::TensorFillRandomUniform(tensor_D2.host_view(), seed + 2015, d_scope_max, d_scope_min, 0); -+ -+ EXPECT_TRUE(initialize_tensor(tensor_Y1.host_view(), cutlass::Distribution::AllZeros, 0)); -+ EXPECT_TRUE(initialize_tensor(tensor_Y2.host_view(), cutlass::Distribution::AllZeros, 0)); -+ EXPECT_TRUE(initialize_tensor(tensor_Y_ref.host_view(), cutlass::Distribution::AllZeros, 0)); -+ -+ // It is possible to randomly initialize to all zeros, so override this with non-zeros -+ // in the upper left corner of each operand. -+ tensor_A.host_view().at({0, 0}) = GemmElement(1); -+ tensor_B.host_view().at({0, 0}) = GemmElement(1); -+ tensor_C.host_view().at({0, 0}) = GemmElement(1); -+ tensor_D1.host_view().at({0, 0}) = GemmElement(1); -+ tensor_D2.host_view().at({0, 0}) = GemmElement(1); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D1.sync_device(); -+ tensor_D2.sync_device(); -+ } -+ -+ /// Compares computed reference with device reference and outputs to a file if incorrect -+ bool compare_reference( -+ cutlass::gemm::GemmCoord problem_size, cutlass::HostTensor& tensor_Y_ref, cutlass::HostTensor& tensor_Y) { -+ -+ tensor_Y_ref.sync_host(); -+ tensor_Y.sync_host(); -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D2.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Y_ref.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Y.host_view()), 0); -+ -+ bool passed = true; -+ float norm_diff = 0; -+ -+ norm_diff = cutlass::reference::host::TensorNormDiff(tensor_Y_ref.host_view(), tensor_Y.host_view(), float()); -+ passed = (norm_diff <= 0.1f); -+ EXPECT_LT(norm_diff, 0.1f) << " tensor_Y is incorrect"; -+ -+ -+ if (!passed) { -+ std::ofstream file("errors_testbed_gemm_broadcast_new.txt"); -+ -+ -+ file -+ << "problem: " << problem_size << "\n\n"; -+ -+ file -+ << "capacity: \n" -+ << "A: " << tensor_A.capacity() -+ << "\nB: " << tensor_B.capacity() -+ << "\nC: " << tensor_C.capacity() -+ << "\nD1: " << tensor_D1.capacity() -+ << "\nD2: " << tensor_D2.capacity() -+ << "\nY: " << tensor_Y.capacity() -+ << "\n\n" -+ << "\nY_ref: " << tensor_Y_ref.capacity() -+ << "\n\n"; -+ file -+ << "A =\n" << tensor_A.host_view() -+ << "\n\nB =\n" << tensor_B.host_view() -+ << "\n\nC =\n" << tensor_C.host_view() -+ << "\n\nD1 =\n" << tensor_D1.host_view() -+ << "\n\nD2 =\n" << tensor_D2.host_view() -+ << "\n\nY =\n" << tensor_Y.host_view() -+ << "\n\nY_ref =\n" << tensor_Y_ref.host_view(); -+ } -+ -+ return passed; -+ } -+}; -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+TEST(SM80_Device_GemmWithBroadcast_f16t_f16n_f16t_tensor_op_f16, 128x128_32x3_64x64x32_16x8x16) { -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using OpClass = cutlass::arch::OpClassTensorOp; -+ using ArchTag = cutlass::arch::Sm80; -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle; -+ const int kStages = 3; -+ -+ const int batch_count = 1; -+ const cutlass::half_t alpha(1); -+ const cutlass::half_t beta(1); -+ -+ const int M = 1024; -+ const int K = 10240; -+ const int N = 512; -+ cutlass::gemm::GemmCoord problem{M, N, K}; -+ -+ const int batch_stride_A = 0; -+ const int batch_stride_B = 0; -+ const int batch_stride_C1 = 0; -+ const int batch_stride_C2 = 0; -+ const int batch_stride_D = 0; -+ const int batch_stride_Vector = 0; -+ const int batch_stride_Tensor = 0; -+ -+ const int64_t lda = LayoutA::packed({problem.m(), problem.k()}).stride(0); -+ const int64_t ldb = LayoutB::packed({problem.k(), problem.n()}).stride(0); -+ const int64_t ldc1 = LayoutC::packed({problem.m(), problem.n()}).stride(0); -+ const int64_t ldc2 = LayoutC::packed({problem.m(), problem.n()}).stride(0); -+ const int64_t ldd = LayoutC::packed({problem.m(), problem.n()}).stride(0); -+ const int64_t ldv = 0; -+ const int64_t ldt = 0; -+ -+ TestbedUtils utils; -+ utils.initialize(problem); -+ -+ // -+ // Create reference Gemm -+ // -+ using GemmRef = cutlass::gemm::device::GemmUniversal< -+ ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, ElementAccumulator, -+ OpClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ ThreadblockSwizzle, kStages>; -+ -+ typename GemmRef::Arguments args_ref{ -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ problem, -+ batch_count, -+ {alpha, beta}, -+ utils.tensor_A.device_data(), -+ utils.tensor_B.device_data(), -+ utils.tensor_C.device_data(), -+ utils.tensor_Y_ref.device_data(), -+ batch_stride_A, -+ batch_stride_B, -+ batch_stride_C1, -+ batch_stride_D, -+ lda, -+ ldb, -+ ldv, -+ ldd, -+ }; -+ -+ GemmRef gemm_op_ref; -+ size_t workspace_size_ref = GemmRef::get_workspace_size(args_ref); -+ cutlass::device_memory::allocation workspace_ref(workspace_size_ref); -+ cutlass::Status status = gemm_op_ref.initialize(args_ref, workspace_ref.get()); -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status); -+ -+ status = gemm_op_ref(); -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status); -+ -+ // -+ // Create GemmWithBroadcast from single source -+ // -+ using GemmSingle = cutlass::gemm::device::GemmUniversalWithBroadcast< -+ ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, ElementAccumulator, -+ OpClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationResidualBlock< -+ ElementOutput, ElementAccumulator, ElementAccumulator, -+ ElementAccumulator, 128 / cutlass::sizeof_bits::value, -+ cutlass::epilogue::thread::Identity, cutlass::multiplies, cutlass::epilogue::thread::Identity>, -+ ThreadblockSwizzle, kStages>; -+ -+ typename GemmSingle::Arguments args_single{ -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ problem, -+ batch_count, -+ {alpha, beta}, -+ utils.tensor_A.device_data(), -+ utils.tensor_B.device_data(), -+ utils.tensor_D1.device_data(), -+ utils.tensor_Y1.device_data(), -+ utils.tensor_C.device_data(), -+ /* ptr_Tensor = */ nullptr, -+ batch_stride_A, -+ batch_stride_B, -+ batch_stride_C1, -+ batch_stride_D, -+ batch_stride_Vector, -+ batch_stride_Tensor, -+ lda, -+ ldb, -+ ldc1, -+ ldd, -+ ldv, -+ ldt -+ }; -+ -+ GemmSingle gemm_op_single; -+ size_t workspace_size_single = GemmSingle::get_workspace_size(args_single); -+ cutlass::device_memory::allocation workspace_single(workspace_size_single); -+ status = gemm_op_single.initialize(args_single, workspace_single.get()); -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status); -+ -+ status = gemm_op_single(); -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status); -+ -+ // Compute the broadcast on the reference previously computed and compare results -+ utils.tensor_Y_ref.sync_host(); -+ cutlass::reference::host::TensorMul(utils.tensor_Y_ref.host_view(), utils.tensor_D1.host_view()); -+ utils.tensor_Y_ref.sync_device(); -+ utils.compare_reference(problem, utils.tensor_Y_ref, utils.tensor_Y1); -+ -+ // -+ // Create GemmWithBroadcast from two sources -+ // -+ using GemmDouble = cutlass::gemm::device::GemmUniversalWithBroadcast< -+ ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, ElementAccumulator, -+ OpClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, -+ cutlass::epilogue::thread::LinearCombinationResidualBlock< -+ ElementOutput, ElementAccumulator, ElementAccumulator, -+ ElementAccumulator, 128 / cutlass::sizeof_bits::value, -+ cutlass::epilogue::thread::Identity, cutlass::multiplies, cutlass::epilogue::thread::Identity, cutlass::plus>, -+ ThreadblockSwizzle, kStages>; -+ -+ typename GemmDouble::Arguments args_double{ -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ problem, -+ batch_count, -+ {alpha, beta}, -+ utils.tensor_A.device_data(), -+ utils.tensor_B.device_data(), -+ utils.tensor_D1.device_data(), -+ utils.tensor_D2.device_data(), -+ utils.tensor_Y2.device_data(), -+ utils.tensor_C.device_data(), -+ /* ptr_Tensor = */ nullptr, -+ batch_stride_A, -+ batch_stride_B, -+ batch_stride_C1, -+ batch_stride_C2, -+ batch_stride_D, -+ batch_stride_Vector, -+ batch_stride_Tensor, -+ lda, -+ ldb, -+ ldc1, -+ ldc2, -+ ldd, -+ ldv, -+ ldt -+ }; -+ -+ GemmDouble gemm_op_double; -+ size_t workspace_size_double = GemmDouble::get_workspace_size(args_double); -+ cutlass::device_memory::allocation workspace_double(workspace_size_double); -+ status = gemm_op_double.initialize(args_double, workspace_double.get()); -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status); -+ -+ status = gemm_op_double(); -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status); -+ -+ // Compute the broadcast on the reference previously computed and compare results -+ utils.tensor_Y_ref.sync_host(); -+ cutlass::reference::host::TensorAdd(utils.tensor_Y_ref.host_view(), utils.tensor_D2.host_view()); -+ utils.tensor_Y_ref.sync_device(); -+ utils.compare_reference(problem, utils.tensor_Y_ref, utils.tensor_Y2); -+} -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_slicedk_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_slicedk_sm75.cu -new file mode 100644 -index 0000000..f595fd6 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_slicedk_sm75.cu -@@ -0,0 +1,88 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_f16t_f16n_f16t_tensor_op_f16_sliced_k, 64x64x64_64x32x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if (CUTLASS_ARCH_MMA_SM75_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_slicedk_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_slicedk_sm80.cu -new file mode 100644 -index 0000000..0881964 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_slicedk_sm80.cu -@@ -0,0 +1,89 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+ -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16_sliced_k, 128x64x64_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sm75.cu -new file mode 100644 -index 0000000..c3b3094 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sm75.cu -@@ -0,0 +1,242 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x256x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 256x128x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x128x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x128x32_32x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x64x32_64x32x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x64x32_32x32x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sm80.cu -new file mode 100644 -index 0000000..0343f0b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sm80.cu -@@ -0,0 +1,345 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x256x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 256x128x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x128x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 256x64x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x256x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x128x64_32x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x64x64_64x32x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x64x64_32x32x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x256x32_64x64x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 256x128x32_64x64x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x128x32_64x64x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 256x64x32_64x64x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x256x32_64x64x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x128x32_32x64x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x64x32_64x32x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x64x32_32x32x32) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sparse_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sparse_sm80.cu -new file mode 100644 -index 0000000..2c9402f ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sparse_sm80.cu -@@ -0,0 +1,273 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_sparse.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_sparse.h" -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x256x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 256x128x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x128x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 256x64x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x256x64_64x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x128x64_32x64x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x64x64_64x32x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x64x64_32x32x64) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x128x128_64x64x128) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 256x64x128_64x64x128) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x64x128_64x32x128) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x64x128_32x32x128) { -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_volta_tensor_op_f16_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_volta_tensor_op_f16_sm70.cu -new file mode 100644 -index 0000000..f986113 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_volta_tensor_op_f16_sm70.cu -@@ -0,0 +1,274 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_volta_tensor_op_f16, 128x256x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_volta_tensor_op_f16, 256x128x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_volta_tensor_op_f16, 128x128x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_volta_tensor_op_f16, 128x64x32_64x32x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_volta_tensor_op_f16, 64x128x32_32x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_volta_tensor_op_f16, 64x64x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_volta_tensor_op_f16, 64x64x32_32x32x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if (CUTLASS_ENABLE_TENSOR_CORE_MMA) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_wmma_tensor_op_f16_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_wmma_tensor_op_f16_sm70.cu -new file mode 100644 -index 0000000..be966e2 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_wmma_tensor_op_f16_sm70.cu -@@ -0,0 +1,405 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 64x128x32_64x64x32_16x16x16) { -+ // single cta, two warps horizontally -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 128x64x32_64x64x32_16x16x16) { -+ // single cta, two warps vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_16x16x16) { -+ // single cta, two warps horizontally two waprs vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 128x256x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 256x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 128x64x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 64x128x32_32x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 64x64x32_32x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_wmma_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_wmma_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..afd4fbf ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f16t_wmma_tensor_op_f32_sm70.cu -@@ -0,0 +1,402 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F32=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 64x128x32_64x64x32_16x16x16) { -+ // single cta, two warps horizontally -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 128x64x32_64x64x32_16x16x16) { -+ // single cta, two warps vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 128x128x32_64x64x32_16x16x16) { -+ // single cta, two warps horizontally two waprs vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 128x256x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 256x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 128x64x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 64x128x32_32x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32n_wmma_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32n_wmma_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..5f6b77d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32n_wmma_tensor_op_f32_sm70.cu -@@ -0,0 +1,158 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..fab5576 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32_sm70.cu -@@ -0,0 +1,226 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32, 128x64x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32, 64x128x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32, 128x128x32_64x32x32_32x8x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32, 128x128x32_64x32x32_8x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ static const int kStages = 1; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sm75.cu -new file mode 100644 -index 0000000..5aa81be ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sm75.cu -@@ -0,0 +1,243 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..708b4df ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sm80.cu -@@ -0,0 +1,344 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x256x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x128x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x128x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x64x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x256x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x128x64_32x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x64x64_64x32x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x64x64_32x32x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x256x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x128x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x128x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x64x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x256x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x128x32_32x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x64x32_64x32x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x64x32_32x32x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sparse_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sparse_sm80.cu -new file mode 100644 -index 0000000..e1d9381 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sparse_sm80.cu -@@ -0,0 +1,271 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_sparse.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_sparse.h" -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x256x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x128x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x128x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x64x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x256x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x128x64_32x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x64x64_64x32x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x64x64_32x32x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x128x128_64x64x128) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x64x128_64x64x128) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x64x128_64x32x128) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x64x128_32x32x128) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32t_volta_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32t_volta_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..32f4923 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32t_volta_tensor_op_f32_sm70.cu -@@ -0,0 +1,274 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_volta_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_volta_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_volta_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_volta_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_volta_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_volta_tensor_op_f32, 64x64x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_volta_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32t_wmma_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32t_wmma_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..70298cf ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16n_f32t_wmma_tensor_op_f32_sm70.cu -@@ -0,0 +1,344 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 128x256x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 256x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 128x64x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 64x128x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16n_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f16n_wmma_tensor_op_f16_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f16n_wmma_tensor_op_f16_sm70.cu -new file mode 100644 -index 0000000..129d4a0 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f16n_wmma_tensor_op_f16_sm70.cu -@@ -0,0 +1,157 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_16x16x16) { -+ // single cta, two warps horizontally two waprs vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16t_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16t_f16n_wmma_tensor_op_f16, 128x128x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f16n_wmma_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f16n_wmma_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..6fcdcee ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f16n_wmma_tensor_op_f32_sm70.cu -@@ -0,0 +1,155 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F32=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16t_f16n_wmma_tensor_op_f32, 128x128x32_64x64x32_16x16x16) { -+ // single cta, two warps horizontally two waprs vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16t_f16n_wmma_tensor_op_f32, 64x64x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16t_f16n_wmma_tensor_op_f32, 64x64x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f16t_wmma_tensor_op_f16_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f16t_wmma_tensor_op_f16_sm70.cu -new file mode 100644 -index 0000000..71fea92 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f16t_wmma_tensor_op_f16_sm70.cu -@@ -0,0 +1,405 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f16, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f16, 64x128x32_64x64x32_16x16x16) { -+ // single cta, two warps horizontally -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f16, 128x64x32_64x64x32_16x16x16) { -+ // single cta, two warps vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_16x16x16) { -+ // single cta, two warps horizontally two waprs vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f16, 128x256x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f16, 256x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f16, 128x64x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f16, 64x128x32_32x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f16, 64x64x32_32x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f16, 128x128x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f16t_wmma_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f16t_wmma_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..fc980ac ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f16t_wmma_tensor_op_f32_sm70.cu -@@ -0,0 +1,403 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F32=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x128x32_64x64x32_16x16x16) { -+ // single cta, two warps horizontally -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 128x64x32_64x64x32_16x16x16) { -+ // single cta, two warps vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 128x128x32_64x64x32_16x16x16) { -+ // single cta, two warps horizontally two waprs vertically -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 128x256x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 256x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 128x64x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x128x32_32x64x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F16=>F16 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16t_f16t_wmma_tensor_op_f32, 64x64x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32n_tensor_op_f32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32n_tensor_op_f32_sm75.cu -new file mode 100644 -index 0000000..9612d76 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32n_tensor_op_f32_sm75.cu -@@ -0,0 +1,243 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32n_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32n_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..f89b076 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32n_tensor_op_f32_sm80.cu -@@ -0,0 +1,344 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 128x256x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 256x128x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 128x128x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 256x64x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 64x256x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 64x128x64_32x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 128x64x64_64x32x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 64x64x64_32x32x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 128x256x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 256x128x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 128x128x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 256x64x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 64x256x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 64x128x32_32x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 128x64x32_64x32x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32, 64x64x32_32x32x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32n_wmma_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32n_wmma_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..6a32a33 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32n_wmma_tensor_op_f32_sm70.cu -@@ -0,0 +1,156 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16t_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16t_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16t_f32n_wmma_tensor_op_f32, 128x128x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32t_tensor_op_f32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32t_tensor_op_f32_sm75.cu -new file mode 100644 -index 0000000..d21ee87 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32t_tensor_op_f32_sm75.cu -@@ -0,0 +1,243 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32t_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32t_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..5a3d6d0 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32t_tensor_op_f32_sm80.cu -@@ -0,0 +1,344 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x256x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 256x128x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x128x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 256x64x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 64x256x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 64x128x64_32x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x64x64_64x32x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 64x64x64_32x32x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x256x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 256x128x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x128x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 256x64x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 64x256x32_64x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 64x128x32_32x64x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x64x32_64x32x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32, 64x64x32_32x32x32) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32t_tensor_op_f32_sparse_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32t_tensor_op_f32_sparse_sm80.cu -new file mode 100644 -index 0000000..3bafd4d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32t_tensor_op_f32_sparse_sm80.cu -@@ -0,0 +1,199 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_sparse.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_sparse.h" -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x256x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16t_f32t_tensor_op_f32, 256x128x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x128x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16t_f32t_tensor_op_f32, 256x64x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16t_f32t_tensor_op_f32, 64x256x64_64x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16t_f32t_tensor_op_f32, 64x128x64_32x64x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x64x64_64x32x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f16t_f16t_f32t_tensor_op_f32, 64x64x64_32x32x64) { -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, -+ cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32t_volta_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32t_volta_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..b2a2397 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32t_volta_tensor_op_f32_sm70.cu -@@ -0,0 +1,243 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f32t_volta_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f32t_volta_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f32t_volta_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f32t_volta_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f32t_volta_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f32t_volta_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if (CUTLASS_ENABLE_TENSOR_CORE_MMA) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32t_wmma_tensor_op_f32_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32t_wmma_tensor_op_f32_sm70.cu -new file mode 100644 -index 0000000..4572356 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f16t_f16t_f32t_wmma_tensor_op_f32_sm70.cu -@@ -0,0 +1,344 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 16x16x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16t_f32t_wmma_tensor_op_f32, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f32t_wmma_tensor_op_f32, 128x256x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f32t_wmma_tensor_op_f32, 256x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f32t_wmma_tensor_op_f32, 128x64x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f32t_wmma_tensor_op_f32, 64x128x32_64x32x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM70_Device_Gemm_f16t_f16t_f32t_wmma_tensor_op_f32, 64x64x32_32x32x32_16x16x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 32x8x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16t_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_32x8x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x32x16, DataType/Instruction = F16*F16+F32=>F32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_Device_Gemm_f16t_f16t_f32t_wmma_tensor_op_f32, 128x128x32_64x64x32_8x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f32n_f32n_f32t_tensor_op_bf16_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f32n_f32n_f32t_tensor_op_bf16_f32_sm80.cu -new file mode 100644 -index 0000000..c83f7a7 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f32n_f32n_f32t_tensor_op_bf16_f32_sm80.cu -@@ -0,0 +1,93 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface using BF16. -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/arch/mma.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f32t_f32n_f32t_tensor_op_bf16_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 4, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastBF16 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f32n_f32n_f32t_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f32n_f32n_f32t_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..bd71b19 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f32n_f32n_f32t_tensor_op_f32_sm80.cu -@@ -0,0 +1,88 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f32n_f32n_f32t_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f32n_f32n_f32t_tensor_op_f32_sparse_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f32n_f32n_f32t_tensor_op_f32_sparse_sm80.cu -new file mode 100644 -index 0000000..895f175 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f32n_f32n_f32t_tensor_op_f32_sparse_sm80.cu -@@ -0,0 +1,429 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_sparse.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed_sparse.h" -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 256x64x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 64x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 10 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 128x128x64_64x64x64) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 256x64x64_64x64x64) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 128x64x64_64x32x64) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 64x64x64_32x32x64) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f32n_f32t_f32t_tensor_op_f32_sparse_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f32n_f32t_f32t_tensor_op_f32_sparse_sm80.cu -new file mode 100644 -index 0000000..c37d48c ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f32n_f32t_f32t_tensor_op_f32_sparse_sm80.cu -@@ -0,0 +1,429 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_sparse.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed_sparse.h" -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 256x64x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 64x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 10 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 128x128x64_64x64x64) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 256x64x64_64x64x64) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 128x64x64_64x32x64) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 64x64x64_32x32x64) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f32t_f32n_f32t_tensor_op_f32_sparse_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f32t_f32n_f32t_tensor_op_f32_sparse_sm80.cu -new file mode 100644 -index 0000000..43bb129 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f32t_f32n_f32t_tensor_op_f32_sparse_sm80.cu -@@ -0,0 +1,428 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_sparse.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed_sparse.h" -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 256x64x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 64x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 10 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 128x128x64_64x64x64) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 256x64x64_64x64x64) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 128x64x64_64x32x64) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 64x64x64_32x32x64) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f32t_f32t_f32t_tensor_op_f32_sparse_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f32t_f32t_f32t_tensor_op_f32_sparse_sm80.cu -new file mode 100644 -index 0000000..4500835 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f32t_f32t_f32t_tensor_op_f32_sparse_sm80.cu -@@ -0,0 +1,429 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_sparse.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed_sparse.h" -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 256x64x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 64x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 10 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 128x128x64_64x64x64) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 256x64x64_64x64x64) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 128x64x64_64x32x64) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 64x64x64_32x32x64) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm80.cu -new file mode 100644 -index 0000000..d733d8d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm80.cu -@@ -0,0 +1,258 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 64x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f64an_f64at_f64at_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using LayoutA = cutlass::layout::AffineRank2ColumnMajor; -+ using LayoutB = cutlass::layout::AffineRank2RowMajor; -+ using LayoutC = cutlass::layout::AffineRankN<2>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ LayoutA, -+ double, -+ LayoutB, -+ ElementOutput, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ typename LayoutA::Stride::Index stride_factor_A[] = {3, 4}; -+ typename LayoutB::Stride::Index stride_factor_B[] = {5, 6}; -+ typename LayoutC::Stride::Index stride_factor_C[] = {7, 8}; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm(stride_factor_A, stride_factor_B, stride_factor_C)); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu -new file mode 100644 -index 0000000..62cb15d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu -@@ -0,0 +1,223 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface with Hopper FP64 -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 32x32x16_16x16x16_16x8x4) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ using ElementCompute = double; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 64x64x16_32x32x16_16x8x4) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ using ElementCompute = double; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 128x64x16_64x32x16_16x8x4) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ using ElementCompute = double; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 64x128x16_32x64x16_16x8x4) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ using ElementCompute = double; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 128x128x16_32x64x16_16x8x4) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ using ElementCompute = double; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f64t_f64n_f64t_tensor_op_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f64t_f64n_f64t_tensor_op_f64_sm80.cu -new file mode 100644 -index 0000000..8961ab7 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f64t_f64n_f64t_tensor_op_f64_sm80.cu -@@ -0,0 +1,259 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f64t_f64n_f64t_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f64t_f64n_f64t_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f64t_f64n_f64t_tensor_op_f64, 64x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f64t_f64n_f64t_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f64t_f64n_f64t_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f64at_f64an_f64at_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using LayoutA = cutlass::layout::AffineRank2RowMajor; -+ using LayoutB = cutlass::layout::AffineRank2ColumnMajor; -+ using LayoutC = cutlass::layout::AffineRankN<2>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ LayoutA, -+ double, -+ LayoutB, -+ ElementOutput, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ typename LayoutA::Stride::Index stride_factor_A[] = {3, 4}; -+ typename LayoutB::Stride::Index stride_factor_B[] = {5, 6}; -+ typename LayoutC::Stride::Index stride_factor_C[] = {7, 8}; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm(stride_factor_A, stride_factor_B, stride_factor_C)); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_f64t_f64n_f64t_tensor_op_f64_sm90.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_f64t_f64n_f64t_tensor_op_f64_sm90.cu -new file mode 100644 -index 0000000..881d81c ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_f64t_f64n_f64t_tensor_op_f64_sm90.cu -@@ -0,0 +1,223 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface with Hopper FP64 -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f64t_f64n_f64t_tensor_op_f64, 32x32x16_16x16x16_16x8x4) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ using ElementCompute = double; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f64t_f64n_f64t_tensor_op_f64, 64x64x16_32x32x16_16x8x4) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ using ElementCompute = double; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f64t_f64n_f64t_tensor_op_f64, 64x128x16_32x64x16_16x8x4) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ using ElementCompute = double; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f64t_f64n_f64t_tensor_op_f64, 128x64x16_64x32x16_16x8x4) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ using ElementCompute = double; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f64t_f64n_f64t_tensor_op_f64, 128x128x16_32x64x16_16x8x4) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ using ElementCompute = double; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if (CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_grouped_scheduler_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_grouped_scheduler_sm80.cu -new file mode 100644 -index 0000000..4fed1dc ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_grouped_scheduler_sm80.cu -@@ -0,0 +1,222 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for grouped GEMM problem visitors -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/gemm_grouped.h" -+#include "cutlass/gemm/kernel/default_gemm_grouped.h" -+#include "cutlass/gemm/device/gemm_grouped.h" -+ -+#include "testbed_grouped_scheduler.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Run a series of tests on the testbed -+template -+void run_tests() { -+ for (int scale_factor : {8, 16, 32, 64}) { -+ for (int threadblock_count : {54, 108, 216, 324, 432}) { -+ for (int problems : {1, 27, 180, 300}) { -+ Testbed testbed; -+ testbed.run(problems, threadblock_count, scale_factor); -+ } -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGroupedScheduler_p128_t128, 64x64x32) { -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ static int const kNumPrefetch = 128; -+ static int const kThreadCount = 128; -+ static bool const kTranspose = false; -+ -+ using Testbed = test::gemm::device::TestbedGroupedGemmScheduler< -+ ThreadblockShape, -+ kNumPrefetch, -+ kThreadCount, -+ kTranspose, -+ // List of GroupScheduleModes to compare. List must contain at least two. -+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; -+ run_tests(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGroupedScheduler_p128_t128_transpose, 64x64x32) { -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ static int const kNumPrefetch = 128; -+ static int const kThreadCount = 128; -+ static bool const kTranspose = true; -+ -+ using Testbed = test::gemm::device::TestbedGroupedGemmScheduler< -+ ThreadblockShape, -+ kNumPrefetch, -+ kThreadCount, -+ kTranspose, -+ // List of GroupScheduleModes to compare. List must contain at least two. -+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; -+ run_tests(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGroupedScheduler_p256_t256, 64x64x32) { -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ static int const kNumPrefetch = 256; -+ static int const kThreadCount = 256; -+ static bool const kTranspose = false; -+ -+ using Testbed = test::gemm::device::TestbedGroupedGemmScheduler< -+ ThreadblockShape, -+ kNumPrefetch, -+ kThreadCount, -+ kTranspose, -+ // List of GroupScheduleModes to compare. List must contain at least two. -+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; -+ run_tests(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGroupedScheduler_p256_t128, 64x64x32) { -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ static int const kNumPrefetch = 256; -+ static int const kThreadCount = 128; -+ static bool const kTranspose = false; -+ -+ using Testbed = test::gemm::device::TestbedGroupedGemmScheduler< -+ ThreadblockShape, -+ kNumPrefetch, -+ kThreadCount, -+ kTranspose, -+ // List of GroupScheduleModes to compare. List must contain at least two. -+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; -+ run_tests(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGroupedScheduler_p256_t256, 64x32x32) { -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ static int const kNumPrefetch = 256; -+ static int const kThreadCount = 256; -+ static bool const kTranspose = false; -+ -+ using Testbed = test::gemm::device::TestbedGroupedGemmScheduler< -+ ThreadblockShape, -+ kNumPrefetch, -+ kThreadCount, -+ kTranspose, -+ // List of GroupScheduleModes to compare. List must contain at least two. -+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; -+ run_tests(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGroupedScheduler_p256_t256_transpose, 64x32x32) { -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ static int const kNumPrefetch = 256; -+ static int const kThreadCount = 256; -+ static bool const kTranspose = true; -+ -+ using Testbed = test::gemm::device::TestbedGroupedGemmScheduler< -+ ThreadblockShape, -+ kNumPrefetch, -+ kThreadCount, -+ kTranspose, -+ // List of GroupScheduleModes to compare. List must contain at least two. -+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; -+ run_tests(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGroupedScheduler_p256_t256, 32x64x32) { -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ static int const kNumPrefetch = 256; -+ static int const kThreadCount = 256; -+ static bool const kTranspose = false; -+ -+ using Testbed = test::gemm::device::TestbedGroupedGemmScheduler< -+ ThreadblockShape, -+ kNumPrefetch, -+ kThreadCount, -+ kTranspose, -+ // List of GroupScheduleModes to compare. List must contain at least two. -+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; -+ run_tests(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGroupedScheduler_p256_t256_transpose, 32x64x32) { -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ static int const kNumPrefetch = 256; -+ static int const kThreadCount = 256; -+ static bool const kTranspose = true; -+ -+ using Testbed = test::gemm::device::TestbedGroupedGemmScheduler< -+ ThreadblockShape, -+ kNumPrefetch, -+ kThreadCount, -+ kTranspose, -+ // List of GroupScheduleModes to compare. List must contain at least two. -+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; -+ run_tests(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_grouped_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_grouped_sm80.cu -new file mode 100644 -index 0000000..3fa3519 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_grouped_sm80.cu -@@ -0,0 +1,859 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/gemm_grouped.h" -+#include "cutlass/gemm/kernel/default_gemm_grouped.h" -+#include "cutlass/gemm/device/gemm_grouped.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_grouped.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Visitor class to abstract away the algorithm for iterating over tiles. -+// -+// This is the prototype. We will delete this when the efficient kernel is -+// available. -+struct GemmGroupedProblemVisitor { -+ -+ struct Params { -+ cutlass::gemm::GemmCoord const *problem_sizes; -+ int32_t problem_count; -+ int64_t const *tile_count; -+ }; -+ -+ struct SharedStorage { -+ // -+ // Nothing for now. As an optimization step, we could consider parallel -+ // argmin or prefix sums across the block. -+ // -+ }; -+ -+ // -+ // Data members -+ // -+ -+ SharedStorage &shared_storage; -+ Params const ¶ms; -+ cutlass::MatrixCoord threadblock_shape; -+ -+ int64_t tile_idx; -+ int64_t tile_count_sum; -+ int64_t problem_tile_start; -+ int32_t problem_idx; -+ -+ // -+ // Methods -+ // -+ CUTLASS_DEVICE -+ GemmGroupedProblemVisitor( -+ SharedStorage &shared_storage_, -+ Params const ¶ms_, -+ cutlass::MatrixCoord threadblock_shape_, -+ int32_t block_idx -+ ): -+ shared_storage(shared_storage_), -+ params(params_), -+ threadblock_shape(threadblock_shape_), -+ tile_idx(block_idx), -+ tile_count_sum(0), -+ problem_idx(0) -+ { -+ -+ cutlass::gemm::GemmCoord problem = params.problem_sizes[problem_idx]; -+ -+ cutlass::gemm::GemmCoord grid = grid_shape(problem); -+ -+ problem_tile_start = 0; -+ tile_count_sum = grid.m() * grid.n(); -+ } -+ -+ /// Get the grid shape -+ CUTLASS_HOST_DEVICE -+ static cutlass::gemm::GemmCoord grid_shape( -+ cutlass::gemm::GemmCoord const &problem, -+ cutlass::MatrixCoord const & block_shape) { -+ -+ return cutlass::gemm::GemmCoord( -+ ((problem.m() - 1 + block_shape.row()) / block_shape.row()), -+ ((problem.n() - 1 + block_shape.column()) / block_shape.column()), -+ 1); -+ } -+ -+ /// Get the grid shape -+ CUTLASS_DEVICE -+ cutlass::gemm::GemmCoord grid_shape(cutlass::gemm::GemmCoord const &problem) const { -+ return grid_shape(problem, threadblock_shape); -+ } -+ -+ /// Returns true if there is a tile to compute -+ CUTLASS_DEVICE -+ bool next_tile() { -+ -+ if (tile_idx < tile_count_sum) { -+ return true; -+ } -+ -+ do { -+ ++problem_idx; -+ -+ if (problem_idx >= params.problem_count) { -+ return false; -+ } -+ -+ cutlass::gemm::GemmCoord problem = params.problem_sizes[problem_idx]; -+ cutlass::gemm::GemmCoord grid = grid_shape(problem); -+ -+ int64_t tile_count = grid.m() * grid.n(); -+ -+ problem_tile_start = tile_count_sum; -+ tile_count_sum += tile_count; -+ -+ } while (tile_count_sum <= tile_idx); -+ -+ return true; -+ } -+ -+ /// Gets the global tile index -+ CUTLASS_HOST_DEVICE -+ int64_t tile_index() const { -+ return tile_idx; -+ } -+ -+ /// Gets the index of the problem -+ CUTLASS_HOST_DEVICE -+ int32_t problem_index() const { -+ return problem_idx; -+ } -+ -+ /// Returns the problem size for the current problem -+ CUTLASS_HOST_DEVICE -+ cutlass::gemm::GemmCoord problem_size() const { -+ return params.problem_sizes[problem_idx]; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ int64_t threadblock_idx() const { -+ return tile_idx - problem_tile_start; -+ } -+ -+ CUTLASS_DEVICE -+ void advance(int32_t grid_size) { -+ tile_idx += grid_size; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+__global__ void GroupedBatchedKernel(GemmGroupedProblemVisitor::Params params) { -+ -+ __shared__ GemmGroupedProblemVisitor::SharedStorage shared_storage; -+ -+ GemmGroupedProblemVisitor problem_visitor( -+ shared_storage, -+ params, -+ {ThreadblockShapeM, ThreadblockShapeN}, -+ blockIdx.x); -+ -+ while (problem_visitor.next_tile()) { -+ -+ cutlass::gemm::GemmCoord problem_size = problem_visitor.problem_size(); -+ int64_t threadblock_idx = problem_visitor.threadblock_idx(); -+ -+ cutlass::gemm::GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); -+ -+ int threadblock_tile_m_idx = int(threadblock_idx / grid_shape.n()); -+ int threadblock_tile_n_idx = int(threadblock_idx % grid_shape.n()); -+ -+ // -+ // Do the MMA -+ // -+ -+ if (threadIdx.x == 0) { -+ #if 0 -+ printf("Block %d - tile: %lld, problem %d, threadblock_idx: %lld, threadblock(m: %d, n: %d)\n", -+ blockIdx.x, -+ problem_visitor.tile_index(), -+ problem_visitor.problem_index(), -+ threadblock_idx, -+ threadblock_tile_m_idx, -+ threadblock_tile_n_idx); -+ #endif -+ } -+ -+ // Next tile -+ problem_visitor.advance(gridDim.x); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGrouped_scheduler, 64x64x32_32x32x32) { -+ -+ int32_t problem_count = 16; -+ -+ int const kThreadblockShapeM = 64; -+ int const kThreadblockShapeN = 64; -+ -+ std::vector problem_sizes(problem_count); -+ std::vector tile_counts(problem_count); -+ -+ // construct a few problems of random sizes -+ srand(1921); -+ for (int32_t i = 0; i < problem_count; ++i) { -+ problem_sizes.at(i) = cutlass::gemm::GemmCoord( -+ 8 * (rand() % 48) + 64, -+ 8 * (rand() % 48) + 64, -+ 8 * (rand() % 48) + 64); -+ } -+ -+ // compute prefix sum -+ int64_t tile_count = 0; -+ -+ for (int32_t i = 0; i < problem_count; ++i) { -+ -+ cutlass::gemm::GemmCoord grid_shape = GemmGroupedProblemVisitor::grid_shape( -+ problem_sizes.at(i), {kThreadblockShapeM, kThreadblockShapeN}); -+ -+ int32_t problem_tile_count = (grid_shape.m() * grid_shape.n()); -+ -+ int64_t tile_start = tile_count; -+ -+ tile_count += problem_tile_count; -+ tile_counts.at(i) = tile_count; -+ -+ if (false) { -+ std::cout << "Problem " << i << " size(" -+ << problem_sizes.at(i).m() << "-by-" << problem_sizes.at(i).n() -+ << ") - tiles: " << problem_tile_count << ", grid(" << grid_shape.m() << ", " << grid_shape.n() -+ << "), tiles[" << tile_start << ", " << tile_count << ")" << std::endl; -+ } -+ } -+ -+ // Copy to device memory -+ cutlass::DeviceAllocation problem_sizes_device(problem_count); -+ cutlass::DeviceAllocation tile_counts_device(problem_count); -+ -+ problem_sizes_device.copy_from_host(problem_sizes.data()); -+ tile_counts_device.copy_from_host(tile_counts.data()); -+ -+ GemmGroupedProblemVisitor::Params params; -+ params.problem_sizes = problem_sizes_device.get(); -+ params.problem_count = problem_count; -+ params.tile_count = tile_counts_device.get(); -+ -+ // Launch the kernel -+ dim3 grid(108, 1, 1); -+ dim3 block(128, 1, 1); -+ -+ GroupedBatchedKernel<<< grid, block >>>(params); -+ -+ // wait -+ cudaDeviceSynchronize(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGrouped_f16n_f16t_f32n_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3>::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmGrouped; -+ -+ // -+ // Test -+ // -+ -+ test::gemm::device::TestbedGrouped testbed; -+ -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGrouped_f16n_f16t_f32t_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ ElementOutput, cutlass::layout::RowMajor, // row major -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3>::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmGrouped; -+ -+ // -+ // Test -+ // -+ -+ test::gemm::device::TestbedGrouped testbed; -+ -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGrouped_f16t_f16n_f32n_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 4>::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmGrouped; -+ -+ // -+ // Test -+ // -+ -+ test::gemm::device::TestbedGrouped testbed; -+ -+ bool passed = testbed.run(27); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGrouped_f16t_f16n_f32t_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 4>::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmGrouped; -+ -+ // -+ // Test -+ // -+ -+ test::gemm::device::TestbedGrouped testbed; -+ -+ bool passed = testbed.run(27); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGrouped_f64t_f64t_f64n_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementInput = double; -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< -+ ElementInput, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 1, -+ ElementInput, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 1, -+ ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 1, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 4>::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmGrouped; -+ -+ // -+ // Test -+ // -+ -+ test::gemm::device::TestbedGrouped testbed; -+ -+ bool passed = testbed.run(27); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGrouped_f32t_f32t_f32n_simt_f32, 128x128x8_64x32x1) { -+ -+ using ElementInput = float; -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< -+ ElementInput, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 1, -+ ElementInput, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 1, -+ ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 1, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3>::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmGrouped; -+ -+ // -+ // Test -+ // -+ -+ test::gemm::device::TestbedGrouped testbed; -+ -+ bool passed = testbed.run(27); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGrouped_f32t_f32t_f32t_simt_f32, 128x128x8_64x32x1) { -+ -+ using ElementInput = float; -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< -+ ElementInput, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 1, -+ ElementInput, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 1, -+ ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 1, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3>::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmGrouped; -+ -+ // -+ // Test -+ // -+ -+ test::gemm::device::TestbedGrouped testbed; -+ -+ bool passed = testbed.run(27); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGrouped_f32t_f32t_f32n_simt_f32, 128x64x8_64x32x1) { -+ -+ using ElementInput = float; -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< -+ ElementInput, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 1, -+ ElementInput, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 1, -+ ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 8>, -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 1, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3>::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmGrouped; -+ -+ // -+ // Test -+ // -+ -+ test::gemm::device::TestbedGrouped testbed; -+ -+ bool passed = testbed.run(27); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGrouped_f32t_f32t_f32t_simt_f32, 128x64x8_64x32x1) { -+ -+ using ElementInput = float; -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< -+ ElementInput, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 1, -+ ElementInput, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 1, -+ ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 8>, -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 1, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3>::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmGrouped; -+ -+ // -+ // Test -+ // -+ -+ test::gemm::device::TestbedGrouped testbed; -+ -+ bool passed = testbed.run(27); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGrouped_cf32n_cf32n_cf32n_tensorop_f32, 64x64x16_32x32x16) { -+ -+ using ElementInput = cutlass::complex; -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< -+ ElementInput, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 1, -+ ElementInput, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 1, -+ ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 1, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3, -+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::arch::OpMultiplyAddComplex>::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmGrouped; -+ -+ // -+ // Test -+ // -+ -+ test::gemm::device::TestbedGrouped testbed; -+ -+ bool passed = testbed.run(27); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGrouped_cf32c_cf32t_cf32n_tensorop_f32, 64x64x16_32x32x16) { -+ -+ using ElementInput = cutlass::complex; -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< -+ ElementInput, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kConjugate, -+ 1, -+ ElementInput, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kConjugate, -+ 1, -+ ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 1, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3, -+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::arch::OpMultiplyAddComplex>::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmGrouped; -+ -+ // -+ // Test -+ // -+ -+ test::gemm::device::TestbedGrouped testbed; -+ -+ bool passed = testbed.run(27); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGrouped_cf32c_cf32t_cf32t_tensorop_f32, 64x64x16_32x32x16) { -+ -+ using ElementInput = cutlass::complex; -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< -+ ElementInput, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kConjugate, -+ 1, -+ ElementInput, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kConjugate, -+ 1, -+ ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 1, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3, -+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::arch::OpMultiplyAddComplex>::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmGrouped; -+ -+ // -+ // Test -+ // -+ -+ test::gemm::device::TestbedGrouped testbed; -+ -+ bool passed = testbed.run(27); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmGrouped_cf32t_cf32h_cf32n_tensorop_f32, 64x64x16_16x16x16) { -+ -+ using ElementInput = cutlass::complex; -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< -+ ElementInput, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 1, -+ ElementInput, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kConjugate, -+ 1, -+ ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 1, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3, -+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::arch::OpMultiplyAddComplex>::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmGrouped; -+ -+ // -+ // Test -+ // -+ -+ test::gemm::device::TestbedGrouped testbed; -+ -+ bool passed = testbed.run(27); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_planar_complex_f16_f16_f32_tensor_op_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_planar_complex_f16_f16_f32_tensor_op_sm70.cu -new file mode 100644 -index 0000000..83e7cfa ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_planar_complex_f16_f16_f32_tensor_op_sm70.cu -@@ -0,0 +1,353 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-level GEMM API for Planar Complex. -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/kernel/default_gemm_planar_complex_universal.h" -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+ -+#include "testbed_planar_complex.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+using gemm_planar_complex_s884_tn_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ float, -+ 4, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+>::GemmKernel; -+ -+struct gemm_planar_complex_s884_tn : gemm_planar_complex_s884_tn_base { -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_GemmPlanarComplex_f16t_f16n_f32n_tensor_op_f32_884, 64x64x32_32x32x32) { -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+using gemm_planar_complex_s884_nt_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ float, -+ 4, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+>::GemmKernel; -+ -+struct gemm_planar_complex_s884_nt : gemm_planar_complex_s884_nt_base { -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_GemmPlanarComplex_f16n_f16t_f32n_tensor_op_f32_884, 64x64x32_32x32x32) { -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+using gemm_planar_complex_s884_nn_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ float, -+ 4, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+>::GemmKernel; -+ -+struct gemm_planar_complex_s884_nn : gemm_planar_complex_s884_nn_base { -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_GemmPlanarComplex_f16n_f16n_f32n_tensor_op_f32_884, 128x64x32_32x32x32) { -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+using gemm_planar_complex_f16_s884_f16_nn_128x64_32x2_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ cutlass::half_t, -+ 8, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+>::GemmKernel; -+ -+struct gemm_planar_complex_f16_s884_f16_nn_128x64_32x2 : gemm_planar_complex_f16_s884_f16_nn_128x64_32x2_base { -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_GemmPlanarComplex_f16n_f16n_f16n_tensor_op_f32_884, 128x64x32_32x32x32) { -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+using gemm_planar_complex_f16_s884_f16_nn_64x128_32x2_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ cutlass::half_t, -+ 8, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+>::GemmKernel; -+ -+struct gemm_planar_complex_f16_s884_f16_nn_64x128_32x2 : gemm_planar_complex_f16_s884_f16_nn_64x128_32x2_base { -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_GemmPlanarComplex_f16n_f16n_f16n_tensor_op_f32_884, 64x128x32_32x32x32) { -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+using gemm_planar_complex_f16_s884_f16_tt_128x64_32x2_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ cutlass::half_t, -+ 8, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+>::GemmKernel; -+ -+struct gemm_planar_complex_f16_s884_f16_tt_128x64_32x2 : gemm_planar_complex_f16_s884_f16_tt_128x64_32x2_base { -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_GemmPlanarComplex_f16t_f16t_f16n_tensor_op_f32_884, 128x64x32_32x32x32) { -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+using gemm_planar_complex_f16_s884_f16_tt_64x128_32x2_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ cutlass::half_t, -+ 8, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+>::GemmKernel; -+ -+struct gemm_planar_complex_f16_s884_f16_tt_64x128_32x2 : gemm_planar_complex_f16_s884_f16_tt_64x128_32x2_base { -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_GemmPlanarComplex_f16t_f16t_f16n_tensor_op_f32_884, 64x128x32_32x32x32) { -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_planar_complex_f16_f16_f32_tensor_op_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_planar_complex_f16_f16_f32_tensor_op_sm75.cu -new file mode 100644 -index 0000000..1f702ab ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_planar_complex_f16_f16_f32_tensor_op_sm75.cu -@@ -0,0 +1,223 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-level GEMM API for Planar Complex. -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/kernel/default_gemm_planar_complex_universal.h" -+#include "cutlass/gemm/device/gemm_universal_base.h" -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+ -+#include "testbed_planar_complex.h" -+ -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+using gemm_planar_complex_s1688_tn_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ float, -+ 4, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+>::GemmKernel; -+ -+struct gemm_planar_complex_s1688_tn : gemm_planar_complex_s1688_tn_base { -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_GemmPlanarComplex_f16t_f16n_f32n_tensor_op_f32_1688, 64x64x32_32x32x32) { -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+using gemm_planar_complex_s1688_hc_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kConjugate, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kConjugate, -+ 8, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ float, -+ 4, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+>::GemmKernel; -+ -+struct gemm_planar_complex_s1688_hc : gemm_planar_complex_s1688_hc_base { -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_GemmPlanarComplex_f16h_f16c_f32n_tensor_op_f32_1688, 64x64x32_32x32x32) { -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+using gemm_planar_complex_s1688_nt_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ float, -+ 4, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+>::GemmKernel; -+ -+struct gemm_planar_complex_s1688_nt : gemm_planar_complex_s1688_nt_base { -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_GemmPlanarComplex_f16n_f16t_f32n_tensor_op_f32_1688, 64x64x32_32x32x32) { -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+using gemm_planar_complex_s1688_ch_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kConjugate, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kConjugate, -+ 8, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ float, -+ 4, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+>::GemmKernel; -+ -+struct gemm_planar_complex_s1688_ch : gemm_planar_complex_s1688_ch_base { -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_GemmPlanarComplex_f16c_f16h_f32n_tensor_op_f32_1688, 64x64x32_32x32x32) { -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_planar_complex_f16_f16_f32_tensor_op_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_planar_complex_f16_f16_f32_tensor_op_sm80.cu -new file mode 100644 -index 0000000..beed868 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_planar_complex_f16_f16_f32_tensor_op_sm80.cu -@@ -0,0 +1,393 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-level GEMM API for Planar Complex. -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/kernel/default_gemm_planar_complex_universal.h" -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+ -+#include "testbed_planar_complex.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+using gemm_planar_complex_s16816_tn_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ float, -+ 4, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd -+>::GemmKernel; -+ -+struct gemm_planar_complex_s16816_tn : gemm_planar_complex_s16816_tn_base { -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmPlanarComplex_f16t_f16n_f32n_tensor_op_f32_16816, 64x64x32_32x32x32) { -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+using gemm_planar_complex_f16_s16816_tn_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ float, -+ 4, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd -+>::GemmKernel; -+ -+struct gemm_planar_complex_f16_s16816_tn : gemm_planar_complex_f16_s16816_tn_base { -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmPlanarComplex_f16t_f16n_f16n_tensor_op_f32_16816, 64x64x32_32x32x32) { -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+using gemm_planar_complex_s16816_hc_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kConjugate, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kConjugate, -+ 8, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ float, -+ 4, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd -+>::GemmKernel; -+ -+struct gemm_planar_complex_s16816_hc : gemm_planar_complex_s16816_hc_base { -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmPlanarComplex_f16h_f16c_f32n_tensor_op_f32_16816, 64x64x32_32x32x32) { -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+using gemm_planar_complex_f16_s16816_hc_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kConjugate, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kConjugate, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ float, -+ 4, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd -+>::GemmKernel; -+ -+struct gemm_planar_complex_f16_s16816_hc : gemm_planar_complex_f16_s16816_hc_base { -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmPlanarComplex_f16h_f16c_f16n_tensor_op_f32_16816, 64x64x32_32x32x32) { -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+using gemm_planar_complex_s16816_nt_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ float, -+ 4, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd -+>::GemmKernel; -+ -+struct gemm_planar_complex_s16816_nt : gemm_planar_complex_s16816_nt_base { -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmPlanarComplex_f16n_f16t_f32n_tensor_op_f32_16816, 64x64x32_32x32x32) { -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+using gemm_planar_complex_f16_s16816_nt_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ float, -+ 4, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd -+>::GemmKernel; -+ -+struct gemm_planar_complex_f16_s16816_nt : gemm_planar_complex_f16_s16816_nt_base { -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmPlanarComplex_f16n_f16t_f16n_tensor_op_f32_16816, 64x64x32_32x32x32) { -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+using gemm_planar_complex_s16816_ch_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kConjugate, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kConjugate, -+ 8, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ float, -+ 4, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd -+>::GemmKernel; -+ -+struct gemm_planar_complex_s16816_ch : gemm_planar_complex_s16816_ch_base { -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmPlanarComplex_f16c_f16h_f32n_tensor_op_f32_16816, 64x64x32_32x32x32) { -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+using gemm_planar_complex_cf16_s16816_ch_base = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::ComplexTransform::kConjugate, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kConjugate, -+ 8, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationPlanarComplex< -+ float, -+ 4, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::arch::OpMultiplyAdd -+>::GemmKernel; -+ -+struct gemm_planar_complex_cf16_s16816_ch : gemm_planar_complex_cf16_s16816_ch_base { -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmPlanarComplex_f16c_f16h_f16n_tensor_op_f32_16816, 64x64x32_32x32x32) { -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmPlanarComplex()); -+} -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s4n_s4t_s4n_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4n_s4t_s4n_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..f8505ed ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4n_s4t_s4n_tensor_op_s32_sm75.cu -@@ -0,0 +1,199 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed_interleaved.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_s4n_s4t_s4n_tensor_op_s32, 64x128x128_32x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajorInterleaved<64>, -+ cutlass::int4b_t, -+ cutlass::layout::RowMajorInterleaved<64>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ test::gemm::device::InterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_s4n_s4t_s4n_tensor_op_s32, 128x128x128_64x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajorInterleaved<64>, -+ cutlass::int4b_t, -+ cutlass::layout::RowMajorInterleaved<64>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ test::gemm::device::InterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_s4n_s4t_s4n_tensor_op_s32, 256x128x128_64x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajorInterleaved<64>, -+ cutlass::int4b_t, -+ cutlass::layout::RowMajorInterleaved<64>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ test::gemm::device::InterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_s4n_s4t_s4n_tensor_op_s32, 128x256x128_64x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajorInterleaved<64>, -+ cutlass::int4b_t, -+ cutlass::layout::RowMajorInterleaved<64>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ test::gemm::device::InterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s4n_s4t_s4n_tensor_op_s32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4n_s4t_s4n_tensor_op_s32_sm80.cu -new file mode 100644 -index 0000000..0d95c50 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4n_s4t_s4n_tensor_op_s32_sm80.cu -@@ -0,0 +1,215 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "multistage_testbed_interleaved.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_s4n_s4t_s4n_tensor_op_s32, 64x128x128_32x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajorInterleaved<64>, -+ cutlass::int4b_t, -+ cutlass::layout::RowMajorInterleaved<64>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 32, -+ 32, -+ false, -+ cutlass::arch::OpMultiplyAddSaturate -+ >; -+ -+ test::gemm::device::MultistageInterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_s4n_s4t_s4n_tensor_op_s32, 128x128x128_64x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajorInterleaved<64>, -+ cutlass::int4b_t, -+ cutlass::layout::RowMajorInterleaved<64>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 32, -+ 32, -+ false, -+ cutlass::arch::OpMultiplyAddSaturate -+ >; -+ -+ test::gemm::device::MultistageInterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_s4n_s4t_s4n_tensor_op_s32, 256x128x128_64x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajorInterleaved<64>, -+ cutlass::int4b_t, -+ cutlass::layout::RowMajorInterleaved<64>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 32, -+ 32, -+ false, -+ cutlass::arch::OpMultiplyAddSaturate -+ >; -+ -+ test::gemm::device::MultistageInterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_s4n_s4t_s4n_tensor_op_s32, 128x256x128_64x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajorInterleaved<64>, -+ cutlass::int4b_t, -+ cutlass::layout::RowMajorInterleaved<64>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<64>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 32, -+ 32, -+ false, -+ cutlass::arch::OpMultiplyAddSaturate -+ >; -+ -+ test::gemm::device::MultistageInterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32n_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32n_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..fbb576f ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32n_tensor_op_s32_sm75.cu -@@ -0,0 +1,249 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 128x256x128_64x64x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 256x128x128_64x64x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 128x128x128_64x64x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 64x128x128_32x64x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 128x64x128_64x32x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 64x64x128_32x32x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32n_tensor_op_s32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32n_tensor_op_s32_sm80.cu -new file mode 100644 -index 0000000..0c028e0 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32n_tensor_op_s32_sm80.cu -@@ -0,0 +1,360 @@ -+/************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 128x256x256_64x64x256) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 256x128x256_64x64x256) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 128x128x256_64x64x256) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 256x64x256_64x64x256) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 64x256x256_64x64x256) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 64x128x256_32x64x256) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 256>, -+ cutlass::gemm::GemmShape<32, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 128x64x256_64x32x256) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 256>, -+ cutlass::gemm::GemmShape<64, 32, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 64x64x256_32x32x256) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 256>, -+ cutlass::gemm::GemmShape<32, 32, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 128x256x128_64x64x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 256x128x128_64x64x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 128x128x128_64x64x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 256x64x128_64x64x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 64x256x128_64x64x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 64x128x128_32x64x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 128x64x128_64x32x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s4t_s4n_s32n_tensor_op_s32, 64x64x128_32x32x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif //#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32n_wmma_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32n_wmma_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..23dc8eb ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32n_wmma_tensor_op_s32_sm75.cu -@@ -0,0 +1,248 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x8x32, DataType/Instruction = s4 * s4 + s32 => s32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM75_Device_Gemm_s4t_s4n_s32n_wmma_tensor_op_s32, 128x256x128_64x64x128_8x8x32) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32n_wmma_tensor_op_s32, 256x128x128_64x64x128_8x8x32) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32n_wmma_tensor_op_s32, 128x128x128_64x64x128_8x8x32) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32n_wmma_tensor_op_s32, 64x128x128_32x64x128_8x8x32) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32n_wmma_tensor_op_s32, 128x64x128_64x32x128_8x8x32) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32n_wmma_tensor_op_s32, 64x64x128_32x32x128_8x8x32) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+#endif //CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..4016558 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm75.cu -@@ -0,0 +1,249 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x256x128_64x64x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 256x128x128_64x64x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x128x128_64x64x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x128x128_32x64x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x64x128_64x32x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x64x128_32x32x128) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm80.cu -new file mode 100644 -index 0000000..d962249 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm80.cu -@@ -0,0 +1,363 @@ -+/************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x256x256_64x64x256, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 256x128x256_64x64x256, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x128x256_64x64x256, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 256x64x256_64x64x256, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x256x256_64x64x256, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x128x256_32x64x256, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 256>, -+ cutlass::gemm::GemmShape<32, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x64x256_64x32x256, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 256>, -+ cutlass::gemm::GemmShape<64, 32, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x64x256_32x32x256, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 256>, -+ cutlass::gemm::GemmShape<32, 32, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x256x128_64x64x128, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 256x128x128_64x64x128, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x128x128_64x64x128, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 256x64x128_64x64x128, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x256x128_64x64x128, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x256x128_32x64x128, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x64x128_64x32x128, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x64x128_32x32x128, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu -new file mode 100644 -index 0000000..f9903f3 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu -@@ -0,0 +1,267 @@ -+/************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_sparse.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_sparse.h" -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x256x256_64x64x256) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 128>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 256x128x256_64x64x256) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 128>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x128x256_64x64x256) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, -+ cutlass::gemm::GemmShape<16, 8, 128>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 256x64x256_64x64x256) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 128>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x256x256_64x64x256) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 128>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x128x256_32x64x256) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 256>, -+ cutlass::gemm::GemmShape<32, 64, 256>, cutlass::gemm::GemmShape<16, 8, 128>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x64x256_64x32x256) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 256>, -+ cutlass::gemm::GemmShape<64, 32, 256>, cutlass::gemm::GemmShape<16, 8, 128>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x64x256_32x32x256) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 256>, -+ cutlass::gemm::GemmShape<32, 32, 256>, cutlass::gemm::GemmShape<16, 8, 128>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x128x512_64x64x512) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 512>, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<16, 8, 128>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x64x512_64x32x512) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 512>, -+ cutlass::gemm::GemmShape<64, 32, 512>, cutlass::gemm::GemmShape<16, 8, 128>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x64x512_32x32x512) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 512>, -+ cutlass::gemm::GemmShape<32, 32, 512>, cutlass::gemm::GemmShape<16, 8, 128>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32t_wmma_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32t_wmma_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..3df1180 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s32t_wmma_tensor_op_s32_sm75.cu -@@ -0,0 +1,247 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////// WMMA Instruction Shape = 8x8x32, DataType/Instruction = s4 * s4 + s32 => s32 ////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM75_Device_Gemm_s4t_s4n_s32t_wmma_tensor_op_s32, 128x256x128_64x64x128_8x8x32) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32t_wmma_tensor_op_s32, 256x128x128_64x64x128_8x8x32) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32t_wmma_tensor_op_s32, 128x128x128_64x64x128_8x8x32) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32t_wmma_tensor_op_s32, 64x128x128_32x64x128_8x8x32) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32t_wmma_tensor_op_s32, 128x64x128_64x32x128_8x8x32) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s32t_wmma_tensor_op_s32, 64x64x128_32x32x128_8x8x32) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+#endif //CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s4n_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s4n_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..09a502b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s4n_tensor_op_s32_sm75.cu -@@ -0,0 +1,312 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 128x256x128_64x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 256x128x128_64x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 128x128x128_64x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 64x256x128_64x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 256x64x128_64x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 32 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 64x128x128_32x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 128x64x128_64x32x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 32 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 64x64x128_32x32x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 32 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s4n_tensor_op_s32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s4n_tensor_op_s32_sm80.cu -new file mode 100644 -index 0000000..7b002d5 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s4n_tensor_op_s32_sm80.cu -@@ -0,0 +1,374 @@ -+/************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "multistage_testbed.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 128x256x256_64x64x256, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 256x128x256_64x64x256, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 128x128x256_64x64x256, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 256x64x256_64x64x256, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 32 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 64x256x256_64x64x256, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 64x128x256_32x64x256, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 256>, -+ cutlass::gemm::GemmShape<32, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 128x64x256_64x32x256, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 256>, -+ cutlass::gemm::GemmShape<64, 32, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 32 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 64x64x256_32x32x256, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 256>, -+ cutlass::gemm::GemmShape<32, 32, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 32 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 128x256x128_64x64x128, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 256x128x128_64x64x128, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 128x128x128_64x64x128, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 256x64x128_64x64x128, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 32 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 64x256x128_64x64x128, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 64x128x128_32x64x128, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 128x64x128_64x32x128, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 32 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 64x64x128_32x32x128, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 32 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // #if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..525677a ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu -@@ -0,0 +1,312 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 128x256x128_64x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 256x128x128_64x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 128x128x128_64x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 64x256x128_64x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 256x64x128_64x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 32 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 64x128x128_32x64x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 128x64x128_64x32x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 32 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 64x64x128_32x32x128) { -+ -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, -+ cutlass::layout::RowMajor, -+ cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, -+ cutlass::gemm::GemmShape<8, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 32 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm80.cu -new file mode 100644 -index 0000000..ccaaabf ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm80.cu -@@ -0,0 +1,374 @@ -+/************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "multistage_testbed.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 128x256x256_64x64x256, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 256x128x256_64x64x256, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 128x128x256_64x64x256, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 256x64x256_64x64x256, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 32 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 64x256x256_64x64x256, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 64x128x256_32x64x256, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 256>, -+ cutlass::gemm::GemmShape<32, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 128x64x256_64x32x256, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 256>, -+ cutlass::gemm::GemmShape<64, 32, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 32 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 64x64x256_32x32x256, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 256>, -+ cutlass::gemm::GemmShape<32, 32, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 32 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 128x256x128_64x64x128, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 256x128x128_64x64x128, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 128x128x128_64x64x128, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 256x64x128_64x64x128, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 32 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 64x256x128_64x64x128, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 64x128x128_32x64x128, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 128x64x128_64x32x128, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 32 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 64x64x128_32x32x128, { -+ using ElementOutput = cutlass::int4b_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 32 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // #if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s8n_s8t_s8n_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8n_s8t_s8n_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..98b74e1 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8n_s8t_s8n_tensor_op_s32_sm75.cu -@@ -0,0 +1,293 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed_interleaved.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 32x64x64_16x32x64) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<16, 32, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ test::gemm::device::InterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 64x64x64_32x32x64) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ test::gemm::device::InterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 128x64x64_64x32x64) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ test::gemm::device::InterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 64x128x64_32x64x64) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ test::gemm::device::InterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 128x128x64_64x64x64) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ test::gemm::device::InterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 256x128x64_64x64x64) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ test::gemm::device::InterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 128x256x64_64x64x64) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ test::gemm::device::InterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s8n_s8t_s8n_tensor_op_s32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8n_s8t_s8n_tensor_op_s32_sm80.cu -new file mode 100644 -index 0000000..5dbb51e ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8n_s8t_s8n_tensor_op_s32_sm80.cu -@@ -0,0 +1,358 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "multistage_testbed_interleaved.h" -+ -+#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 64x64x64_32x32x64) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6, -+ 16, -+ 16, -+ false, -+ cutlass::arch::OpMultiplyAddSaturate -+ >; -+ -+ test::gemm::device::MultistageInterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 128x64x64_64x32x64) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 16, -+ 16, -+ false, -+ cutlass::arch::OpMultiplyAddSaturate -+ >; -+ -+ test::gemm::device::MultistageInterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 64x128x64_32x64x64) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 16, -+ 16, -+ false, -+ cutlass::arch::OpMultiplyAddSaturate -+ >; -+ -+ test::gemm::device::MultistageInterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 128x128x64_64x64x64) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 16, -+ 16, -+ false, -+ cutlass::arch::OpMultiplyAddSaturate -+ >; -+ -+ test::gemm::device::MultistageInterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 256x128x64_64x64x64) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 16, -+ 16, -+ false, -+ cutlass::arch::OpMultiplyAddSaturate -+ >; -+ -+ test::gemm::device::MultistageInterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 64x256x64_64x64x64) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 16, -+ 16, -+ false, -+ cutlass::arch::OpMultiplyAddSaturate -+ >; -+ -+ test::gemm::device::MultistageInterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 256x64x64_64x64x64) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 16, -+ 16, -+ false, -+ cutlass::arch::OpMultiplyAddSaturate -+ >; -+ -+ test::gemm::device::MultistageInterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_s8n_s8t_s8n_tensor_op_s32, 128x256x64_64x64x64) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<32>, -+ ElementOutput, -+ cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 16, -+ 16, -+ false, -+ cutlass::arch::OpMultiplyAddSaturate -+ >; -+ -+ test::gemm::device::MultistageInterleavedTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32n_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32n_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..7a7b9df ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32n_tensor_op_s32_sm75.cu -@@ -0,0 +1,249 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 128x256x64_64x64x64) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 256x128x64_64x64x64) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 128x128x64_64x64x64) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 64x128x64_32x64x64) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 128x64x64_64x32x64) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 64x64x64_32x32x64) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32n_tensor_op_s32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32n_tensor_op_s32_sm80.cu -new file mode 100644 -index 0000000..4ce7d70 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32n_tensor_op_s32_sm80.cu -@@ -0,0 +1,361 @@ -+/************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 128x256x128_64x64x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 256x128x128_64x64x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 128x128x128_64x64x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 256x64x128_64x64x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 64x256x128_64x64x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 64x128x128_32x64x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 128x64x128_64x32x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 64x64x128_32x32x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 128x256x64_64x64x64) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 256x128x64_64x64x64) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 128x128x64_64x64x64) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 256x64x64_64x64x64) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 64x256x64_64x64x64) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 64x128x64_32x64x64) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 128x64x64_64x32x64) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 64x64x64_32x32x64) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32n_wmma_tensor_op_s32_sm72.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32n_wmma_tensor_op_s32_sm72.cu -new file mode 100644 -index 0000000..251e138 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32n_wmma_tensor_op_s32_sm72.cu -@@ -0,0 +1,151 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM72_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+//////////////// WMMA Size = 16x16x16, DataType/Instruction = s8*s8+s32=>s32 ////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM75_Device_Gemm_s8t_s8n_s32n_wmma_tensor_op_s32, 128x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s8t_s8n_s32n_wmma_tensor_op_s32, 64x128x64_32x32x64_16x16x16) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+//////////////// WMMA Size = 8x32x16, DataType/Instruction = s8*s8+s32=>s32 ////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM75_Device_Gemm_s8t_s8n_s32n_wmma_tensor_op_s32, 64x128x64_32x64x64_8x32x16) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+#endif //CUTLASS_ARCH_WMMA_SM72_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..2c4ca98 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sm75.cu -@@ -0,0 +1,249 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x256x64_64x64x64) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 256x128x64_64x64x64) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x128x64_64x64x64) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x128x64_32x64x64) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x64x64_64x32x64) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x64x64_32x32x64) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sm80.cu -new file mode 100644 -index 0000000..9ff2bcc ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sm80.cu -@@ -0,0 +1,361 @@ -+/************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x256x128_64x64x128, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 256x128x128_64x64x128, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x128x128_64x64x128, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 256x64x128_64x64x128, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x256x128_64x64x128, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x128x128_32x64x128, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x64x128_64x32x128, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x64x128_32x32x128, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x256x64_64x64x64, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 256x128x64_64x64x64, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x128x64_64x64x64, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 256x64x64_64x64x64, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x256x64_64x64x64, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x128x64_32x64x64, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x64x64_64x32x64, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x64x64_32x32x64, { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sparse_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sparse_sm80.cu -new file mode 100644 -index 0000000..c2f46a2 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sparse_sm80.cu -@@ -0,0 +1,269 @@ -+/************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_sparse.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_sparse.h" -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x256x128_64x64x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 256x128x128_64x64x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x128x128_64x64x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 256x64x128_64x64x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x256x128_64x64x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x128x128_32x64x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x64x128_64x32x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x64x128_32x32x128) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x128x256_64x64x256) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 256>, -+ cutlass::gemm::GemmShape<64, 64, 256>, -+ cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x64x256_64x32x256) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 256>, -+ cutlass::gemm::GemmShape<64, 32, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x64x256_32x32x256) { -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = int32_t; -+ -+ using Gemm = cutlass::gemm::device::SparseGemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 256>, -+ cutlass::gemm::GemmShape<32, 32, 256>, cutlass::gemm::GemmShape<16, 8, 64>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); -+} -+ -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -+ -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32t_wmma_tensor_op_s32_sm72.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32t_wmma_tensor_op_s32_sm72.cu -new file mode 100644 -index 0000000..08ae460 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s32t_wmma_tensor_op_s32_sm72.cu -@@ -0,0 +1,186 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM72_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+//////////////// WMMA Size = 16x16x16, DataType/Instruction = s8*s8+s32=>s32 ////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM75_Device_Gemm_s8t_s8n_s32t_wmma_tensor_op_s32, 128x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+TEST(SM75_Device_Gemm_s8t_s8n_s32t_wmma_tensor_op_s32, 64x128x64_32x32x64_16x16x16) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+//////////////// WMMA Size = 32x8x16, DataType/Instruction = s8*s8+s32=>s32 ////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM75_Device_Gemm_s8t_s8n_s32t_wmma_tensor_op_s32, 64x128x64_32x64x64_32x8x16) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+//////////////// WMMA Size = 8x32x16, DataType/Instruction = s8*s8+s32=>s32 ////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM75_Device_Gemm_s8t_s8n_s32t_wmma_tensor_op_s32, 64x128x64_32x64x64_8x32x16) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+#endif //CUTLASS_ARCH_WMMA_SM72_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s8n_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s8n_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..cb7401f ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s8n_tensor_op_s32_sm75.cu -@@ -0,0 +1,198 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 128x256x64_64x64x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+ -+} ) -+ -+CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 256x128x64_64x64x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 128x128x64_64x64x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+ -+} ) -+ -+CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 64x128x64_32x64x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+ -+} ) -+ -+CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 128x64x64_64x32x64, { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+ -+} ) -+ -+CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 64x64x64_32x32x64, { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+ -+} ) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s8n_tensor_op_s32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s8n_tensor_op_s32_sm80.cu -new file mode 100644 -index 0000000..18daa2b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s8n_tensor_op_s32_sm80.cu -@@ -0,0 +1,374 @@ -+/************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "multistage_testbed.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 128x256x128_64x64x128, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 256x128x128_64x64x128, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 128x128x128_64x64x128, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 256x64x128_64x64x128, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 64x256x128_64x64x128, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 64x128x128_32x64x128, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 128x64x128_64x32x128, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 64x64x128_32x32x128, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 128x256x64_64x64x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 256x128x64_64x64x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 128x128x64_64x64x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 256x64x64_64x64x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 64x256x64_64x64x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 64x128x64_32x64x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 128x64x64_64x32x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 64x64x64_32x32x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s8n_wmma_tensor_op_s32_sm72.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s8n_wmma_tensor_op_s32_sm72.cu -new file mode 100644 -index 0000000..e0ffc83 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s8n_wmma_tensor_op_s32_sm72.cu -@@ -0,0 +1,177 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM72_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+//////////////// WMMA Size = 16x16x16, DataType/Instruction = s8*s8+s32=>s32 ////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM75_Device_Gemm_s8t_s8n_s8n_wmma_tensor_op_s32, 128x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_s8t_s8n_s8n_wmma_tensor_op_s32, 64x128x64_32x32x64_16x16x16) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+//////////////// WMMA Size = 32x8x16, DataType/Instruction = s8*s8+s32=>s32 ////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM75_Device_Gemm_s8t_s8n_s8n_wmma_tensor_op_s32, 64x128x64_32x64x64_32x8x16) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+//////////////// WMMA Size = 8x32x16, DataType/Instruction = s8*s8+s32=>s32 ////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM75_Device_Gemm_s8t_s8n_s8n_wmma_tensor_op_s32, 64x128x64_32x64x64_8x32x16) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+#endif //CUTLASS_ARCH_WMMA_SM72_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm75.cu -new file mode 100644 -index 0000000..a13a6eb ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm75.cu -@@ -0,0 +1,172 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 128x256x64_64x64x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 256x128x64_64x64x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 128x128x64_64x64x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+ -+} ) -+ -+CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 64x128x64_32x64x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 128x64x64_64x32x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2>; -+ -+ test::gemm::device::Testbed testbed; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 64x64x64_32x32x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<8, 8, 16>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementCompute>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2>; -+ -+ test::gemm::device::Testbed testbed; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm80.cu -new file mode 100644 -index 0000000..9af1d01 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm80.cu -@@ -0,0 +1,374 @@ -+/************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "multistage_testbed.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 128x256x128_64x64x128, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 256x128x128_64x64x128, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 128x128x128_64x64x128, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 256x64x128_64x64x128, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 64x256x128_64x64x128, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 128>, -+ cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 64x128x128_32x64x128, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 128>, -+ cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 128x64x128_64x32x128, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 128>, -+ cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 64x64x128_32x32x128, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 128x256x64_64x64x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 256x128x64_64x64x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 128x128x64_64x64x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, -+ cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 256x64x64_64x64x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 64x256x64_64x64x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 64>, -+ cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 64x128x64_32x64x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 128 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 128x64x64_64x32x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 64>, -+ cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 64x64x64_32x32x64, { -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, -+ ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, 64 / cutlass::sizeof_bits::value>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; -+ -+ test::gemm::device::MultistageTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // #if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s8t_wmma_tensor_op_s32_sm72.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s8t_wmma_tensor_op_s32_sm72.cu -new file mode 100644 -index 0000000..12ab891 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_s8t_s8n_s8t_wmma_tensor_op_s32_sm72.cu -@@ -0,0 +1,178 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM72_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+//////////////// WMMA Size = 16x16x16, DataType/Instruction = s8*s8+s32=>s8 ////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM75_Device_Gemm_s8t_s8n_s8t_wmma_tensor_op_s32, 128x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+TEST(SM75_Device_Gemm_s8t_s8n_s8t_wmma_tensor_op_s32, 64x128x64_32x32x64_16x16x16) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+//////////////// WMMA Size = 32x8x16, DataType/Instruction = s8*s8+s32=>s32 ////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM75_Device_Gemm_s8t_s8n_s8t_wmma_tensor_op_s32, 64x128x64_32x64x64_32x8x16) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+//////////////// WMMA Size = 8x32x16, DataType/Instruction = s8*s8+s32=>s32 ////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM75_Device_Gemm_s8t_s8n_s8t_wmma_tensor_op_s32, 64x128x64_32x64x64_8x32x16) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<8, 32, 16>, -+ cutlass::epilogue::thread::FastLinearCombinationClamp< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+#endif //CUTLASS_ARCH_WMMA_SM72_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_splitk_serial_tensor_op_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_splitk_serial_tensor_op_sm75.cu -new file mode 100644 -index 0000000..30f0bb7 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_splitk_serial_tensor_op_sm75.cu -@@ -0,0 +1,114 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_GemmSplitKSerial_f16n_f16n_f16t_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ static const int kStages = 2; -+ -+ static const int kAlignmentA = cutlass::gemm::device::DefaultGemmConfiguration< -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ElementA, -+ ElementB, -+ ElementOutput, -+ ElementAccumulator>::kAlignmentA; -+ -+ static const int kAlignmentB = cutlass::gemm::device::DefaultGemmConfiguration< -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ ElementA, -+ ElementB, -+ ElementOutput, -+ ElementAccumulator>::kAlignmentB; -+ -+ static const bool kSplitKSerial = true; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombinationRelu< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ kStages, -+ kAlignmentA, -+ kAlignmentB, -+ kSplitKSerial -+ >; -+ -+ bool result = test::gemm::device::TestAllGemm(); -+ EXPECT_TRUE(result); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_splitk_simt_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_splitk_simt_sm50.cu -new file mode 100644 -index 0000000..1b4eba2 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_splitk_simt_sm50.cu -@@ -0,0 +1,146 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_splitk_parallel.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed_splitk.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Device_GemmSplitKParallel_f32n_f32t_f32t_simt_f32, 128x128x8) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+TEST(SM50_Device_GemmSplitKParallel_f32n_f32n_f32n_simt_f32, 128x128x8) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Device_GemmSplitKParallel_f64n_f64n_f64t_simt_f64, 64x128x8) { -+ -+ using Element = double; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ cutlass::gemm::GemmShape<64, 128, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+TEST(SM50_Device_GemmSplitKParallel_f64t_f64t_f64n_simt_f64, 64x64x8) { -+ -+ using Element = double; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_splitk_tensor_op_sm70.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_splitk_tensor_op_sm70.cu -new file mode 100644 -index 0000000..5f15749 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_splitk_tensor_op_sm70.cu -@@ -0,0 +1,199 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_splitk_parallel.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed_splitk.h" -+ -+// These operators are assert(0) unless extended PTX is used. -+#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_GemmSplitK_f16n_f16t_f32t_tensor_op_f32, 64x64x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+TEST(SM70_Device_GemmSplitK_f16n_f16t_f16t_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+TEST(SM70_Device_GemmSplitK_f16n_f16t_f16t_tensor_op_f16, 64x128x32_32x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_GemmSplitK_f16t_f16n_f32t_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+TEST(SM70_Device_GemmSplitK_f16t_f16n_f16t_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+ -+TEST(SM70_Device_GemmSplitK_f16t_f16n_f16t_tensor_op_f16, 64x128x32_32x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_splitk_tensor_op_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_splitk_tensor_op_sm75.cu -new file mode 100644 -index 0000000..3d8e8db ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_splitk_tensor_op_sm75.cu -@@ -0,0 +1,336 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_splitk_parallel.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed_splitk.h" -+ -+// These operators are assert(0) unless extended PTX is used. -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_GemmSplitKParallel_f16n_f16t_f32t_tensor_op_f32, 64x64x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+TEST(SM75_Device_GemmSplitKParallel_f16n_f16t_f32n_tensor_op_f32, 64x64x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+TEST(SM75_Device_GemmSplitKParallel_f16n_f16t_f16t_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+TEST(SM75_Device_GemmSplitKParallel_f16n_f16t_f16n_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+TEST(SM75_Device_GemmSplitKParallel_f16n_f16t_f16t_tensor_op_f16, 64x128x32_32x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+TEST(SM75_Device_GemmSplitKParallel_f16n_f16t_f16n_tensor_op_f16, 64x128x32_32x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_GemmSplitKParallel_f16t_f16n_f32t_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+TEST(SM75_Device_GemmSplitKParallel_f16t_f16n_f32n_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+TEST(SM75_Device_GemmSplitKParallel_f16t_f16n_f16t_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+TEST(SM75_Device_GemmSplitKParallel_f16t_f16n_f16n_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+TEST(SM75_Device_GemmSplitKParallel_f16t_f16n_f16t_tensor_op_f16, 64x128x32_32x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+TEST(SM75_Device_GemmSplitKParallel_f16t_f16n_f16n_tensor_op_f16, 64x128x32_32x64x32) { -+ -+ using ElementOutput = cutlass::half_t; -+ using ElementAccumulator = cutlass::half_t; -+ -+ using Gemm = cutlass::gemm::device::GemmSplitKParallel< -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8> -+ >; -+ -+ test::gemm::device::TestAllGemmSplitK(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_testbed_3x.hpp b/3rdparty/cutlass/test/unit/gemm/device/gemm_testbed_3x.hpp -new file mode 100644 -index 0000000..24a9e24 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_testbed_3x.hpp -@@ -0,0 +1,717 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/packed_stride.hpp" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/gett.hpp" -+ -+#include "testbed_utils.h" -+ -+#include "cutlass/kernel_hardware_info.hpp" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cute/int_tuple.hpp" -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail{ -+ -+template -+struct TestbedImpl { -+ // Kernel data types -+ using ElementA = typename Gemm::GemmKernel::ElementA; -+ using StrideA = typename Gemm::GemmKernel::StrideA; -+ using ElementB = typename Gemm::GemmKernel::ElementB; -+ using StrideB = typename Gemm::GemmKernel::StrideB; -+ using ElementC = typename Gemm::GemmKernel::ElementC; -+ using StrideC = typename Gemm::GemmKernel::StrideC; -+ using ElementD = typename Gemm::GemmKernel::ElementD; -+ using StrideD = typename Gemm::GemmKernel::StrideD; -+ using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; -+ using ElementCompute = typename Gemm::GemmKernel::CollectiveEpilogue::ElementCompute; -+ using ElementScalar = typename Gemm::GemmKernel::CollectiveEpilogue::ElementScalar; -+ using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; -+ -+ static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); -+ static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); -+ -+ // Looks at Cute Stride to check Row / Column Major -+ template -+ static constexpr bool is_row_or_col_major(){ -+ int stride_0 = int(cute::size<0>(Stride{})); -+ int stride_1 = int(cute::size<1>(Stride{})); -+ int depth = cute::depth(Stride{}); -+ return ((stride_0 == 1) || (stride_1 == 1)) && (depth == 1); -+ } -+ -+ // Note: this limitation comes from testbed / not the library -+ static_assert(is_row_or_col_major(), -+ "ERROR : A Layout is neither Row / Column Major)"); -+ static_assert(is_row_or_col_major(), -+ "ERROR : B Layout is neither Row / Column Major)"); -+ static_assert(is_row_or_col_major(), -+ "ERROR : C Layout is neither Row / Column Major)"); -+ static_assert(is_row_or_col_major(), -+ "ERROR : D Layout is neither Row / Column Major)"); -+ -+ // Deduce Cutlass Layouts (RowMajor & ColumnMajor) -+ using LayoutTagA = decltype(cutlass::gemm::detail::stride_to_layout_tag_A()); -+ using LayoutTagB = decltype(cutlass::gemm::detail::stride_to_layout_tag_B()); -+ using LayoutTagC = decltype(cutlass::gemm::detail::stride_to_layout_tag_A()); -+ using LayoutTagD = decltype(cutlass::gemm::detail::stride_to_layout_tag_A()); -+ using LayoutTagPackedVector = cutlass::layout::PackedVectorLayout; -+ -+ /// Initialization -+ StrideA stride_a; -+ StrideB stride_b; -+ StrideC stride_c; -+ StrideD stride_d; -+ typename LayoutTagA::Stride stride_factor_A; -+ typename LayoutTagB::Stride stride_factor_B; -+ typename LayoutTagC::Stride stride_factor_C; -+ typename LayoutTagD::Stride stride_factor_D; -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint64_t seed; -+ static constexpr uint64_t kDefaultSeed = 4096; -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D; -+ cutlass::HostTensor reference_D; -+ uint32_t sm_count; -+ -+ // Used to force multi-wave tests for persistent kernel schedules -+ constexpr static int MaxSmCount = 16; -+ -+ // -+ // Methods -+ // -+ -+ TestbedImpl( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = kDefaultSeed -+ ): -+ stride_factor_A(typename LayoutTagA::Stride()), -+ stride_factor_B(typename LayoutTagB::Stride()), -+ stride_factor_C(typename LayoutTagC::Stride()), -+ stride_factor_D(typename LayoutTagD::Stride()), -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } -+ -+ TestbedImpl( -+ typename LayoutTagA::Stride stride_factor_A_, -+ typename LayoutTagB::Stride stride_factor_B_, -+ typename LayoutTagC::Stride stride_factor_C_, -+ typename LayoutTagD::Stride stride_factor_D_, -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = kDefaultSeed -+ ): -+ stride_factor_A(stride_factor_A_), -+ stride_factor_B(stride_factor_B_), -+ stride_factor_C(stride_factor_C_), -+ stride_factor_D(stride_factor_D_), -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ double scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } -+ else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } -+ else if (bits_output == 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } -+ else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope_max, scope_min, 0); -+ } -+ -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ -+ else { -+ EXPECT_TRUE(false) << "Not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Initializes data structures -+ void initialize(ProblemShapeType problem_size) { -+ // -+ // Allocate the GEMM workspace -+ // -+ auto problem_shape_MNKL = cute::append<4>(problem_size, 1); -+ auto M = cute::size<0>(problem_shape_MNKL); -+ auto N = cute::size<1>(problem_shape_MNKL); -+ auto K = cute::size<2>(problem_shape_MNKL); -+ auto L = cute::size<3>(problem_shape_MNKL); -+ -+ stride_a = make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); -+ stride_b = make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); -+ stride_c = make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); -+ stride_d = make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); -+ -+ // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode -+ auto a_coord = cutlass::make_Coord(M * L, K); -+ auto c_coord = cutlass::make_Coord(M * L, N); -+ // Cutlass has Row/Col major refers to MxK times KxN matrix product, -+ // so the HostTensorB should be treated as KxN in "coord"'s view -+ auto b_coord = cutlass::make_Coord(K, N * L); -+ -+ -+ tensor_A.resize(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A)); -+ tensor_B.resize(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B)); -+ tensor_C.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_C)); -+ tensor_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D)); -+ reference_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D), false); -+ -+ EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2022)); -+ EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2021)); -+ EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2020)); -+ -+ // It is possible to randomly initialize to all zeros, so override this with non-zeros -+ // in the upper left corner of each operand. -+ tensor_A.host_view().at({0, 0}) = ElementA(1); -+ tensor_B.host_view().at({0, 0}) = ElementB(1); -+ tensor_C.host_view().at(cutlass::make_Coord(0, 0)) = ElementC(1); -+ -+ cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D.sync_device(); -+ } -+ -+ /// Compares computed reference with device reference and outputs to a file if incorrect -+ bool compare_reference( -+ cute::Shape problem_shape_MNKL, -+ ElementScalar alpha, -+ ElementScalar beta -+ ) { -+ auto [M, N, K, L] = problem_shape_MNKL; -+ -+ tensor_D.sync_host(); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); -+ -+ if (tensor_D.size() > 1) { -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); -+ } -+ -+ if (reference_D.size() > 1) { -+ EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); -+ } -+ -+ bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); -+ -+ EXPECT_TRUE(passed); -+ if (!passed) { -+ std::stringstream fname; -+ fname << "error_Gemm_device_" -+ << M << "x" << N << "x" << K << "x" << L << "_" -+ << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_" -+ << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_" -+ << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt"; -+ -+ std::ofstream file(fname.str()); -+ file -+ << "problem: " << ' ' << M << "x" << N << "x" << K << ", Batch count = " << L -+ << ", alpha: " << float(alpha) << ", beta: " << float(beta) << "\n\n"; -+ -+ file -+ << "A =\n" << tensor_A.host_view() -+ << "\nB =\n" << tensor_B.host_view() -+ << "\nC =\n" << tensor_C.host_view() -+ << "\n\nReference =\n" << reference_D.host_view() -+ << "\n\nComputed =\n" << tensor_D.host_view(); -+ } -+ -+ return passed; -+ } -+ -+ /// Verifies the result is a GEMM -+ bool verify( -+ ProblemShapeType problem_size, -+ ElementScalar alpha, -+ ElementScalar beta -+ ) { -+ auto problem_shape_MNKL = cute::append<4>(problem_size, 1); -+ auto M = cute::size<0>(problem_shape_MNKL); -+ auto N = cute::size<1>(problem_shape_MNKL); -+ auto K = cute::size<2>(problem_shape_MNKL); -+ auto L = cute::size<3>(problem_shape_MNKL); -+ -+ auto A = cute::make_tensor(tensor_A.host_data(), -+ cute::make_layout(cute::make_shape(M, K, L), stride_a)); -+ auto B = cute::make_tensor(tensor_B.host_data(), -+ cute::make_layout(cute::make_shape(N, K, L), stride_b)); -+ auto C = cute::make_tensor(tensor_C.host_data(), -+ cute::make_layout(cute::make_shape(M, N, L), stride_c)); -+ auto D = cute::make_tensor(reference_D.host_data(), -+ cute::make_layout(cute::make_shape(M, N, L), stride_d)); -+ cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; -+ -+ cutlass::reference::host::GettEpilogueParams< -+ ElementScalar, -+ ElementAccumulator, -+ ElementCompute, -+ decltype(C), -+ decltype(D) -+ > -+ epilogue_params{ -+ alpha, beta, -+ C, D -+ }; -+ -+ cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); -+ -+ return compare_reference( -+ problem_shape_MNKL, alpha, beta -+ ); -+ } -+ -+ /// Determine if the CUDA device is sufficient to run the kernel -+ bool sufficient() { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = Gemm::GemmKernel::SharedStorageSize; -+ -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ cudaDeviceProp properties; -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ this->sm_count = properties.multiProcessorCount; -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ bool profile( -+ ProblemShapeType problem_size, -+ int iterations, -+ Gemm& gemm_op, -+ typename Gemm::Arguments& arguments, -+ cutlass::device_memory::allocation& workspace) { -+ int M = cute::size<0>(problem_size); -+ int N = cute::size<1>(problem_size); -+ int K = cute::size<2>(problem_size); -+ int L = 1; -+ if constexpr(cute::rank(ProblemShapeType{}) == 4) { -+ L = cute::size<3>(problem_size); -+ } -+ -+ -+ cutlass::Status status; -+ // -+ // Run the GEMM -+ // -+ cudaError_t result; -+ -+ for (int iter = 0; iter < iterations; ++iter) { -+ status = gemm_op(arguments, workspace.get()); -+ if (status != cutlass::Status::kSuccess) { -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ return false; -+ } -+ } -+ -+ result = cudaDeviceSynchronize(); -+ if (result != cudaSuccess) { -+ EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Executes one test -+ bool run( -+ ProblemShapeType problem_size, -+ ElementScalar alpha = ElementScalar(1), -+ ElementScalar beta = ElementScalar(0), -+ bool profiling = false, -+ int iterations = 20 -+ ) { -+ // Fail test if insufficient CUDA device -+ if (!sufficient()) { -+ std::cout << "Test failed due to insufficient CUDA device." << std::endl; -+ return false; -+ } -+ -+ this->initialize(problem_size); -+ -+ // -+ // Initialize the GEMM operator -+ // -+ -+ typename Gemm::Arguments arguments; -+ cutlass::KernelHardwareInfo hw_info; -+ hw_info.device_id = 0; -+ if (not profiling) { -+ this->sm_count = min(MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); -+ hw_info.sm_count = this->sm_count; -+ } -+ else { -+ this->sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); -+ hw_info.sm_count = this->sm_count; -+ } -+ -+ // DefaultEpilogue -+ arguments = typename Gemm::Arguments{ -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ problem_size, -+ tensor_A.device_data(), -+ stride_a, -+ tensor_B.device_data(), -+ stride_b, -+ {tensor_C.device_data(), stride_c, tensor_D.device_data(), stride_d, {alpha, beta}}, -+ hw_info -+ }; -+ Gemm gemm_op; -+ -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = gemm_op.can_implement(arguments); -+ -+ if (status != cutlass::Status::kSuccess) { -+ cudaError_t error = cudaGetLastError(); -+ std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; -+ return true; -+ } -+ -+ // -+ // Run the GEMM -+ // -+ -+ if (profiling) { -+ return profile(problem_size, iterations, gemm_op, arguments, workspace); -+ } -+ else { -+ cudaError_t result; -+ status = gemm_op.initialize(arguments, workspace.get()); -+ status = gemm_op.run(); -+ result = cudaDeviceSynchronize(); -+ if (result != cudaSuccess) { -+ EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; -+ return false; -+ } -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Verify -+ // -+ bool passed = this->verify( -+ problem_size, alpha, beta -+ ); -+ if (!passed) { -+ std::cout << "Error : Failed : with alpha: " << float(alpha) << ", beta: " << float(beta) -+ << "\n"; -+ } -+ -+ return passed; -+ } -+ } -+}; -+ -+} // namespace detail -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct Testbed { -+ -+ using TestBedImplementation = typename detail::TestbedImpl; -+ -+ using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; -+ using ElementCompute = typename Gemm::GemmKernel::CollectiveEpilogue::ElementCompute; -+ using ElementScalar = typename Gemm::GemmKernel::CollectiveEpilogue::ElementScalar; -+ using LayoutTagA = typename TestBedImplementation::LayoutTagA; -+ using LayoutTagB = typename TestBedImplementation::LayoutTagB; -+ using LayoutTagC = typename TestBedImplementation::LayoutTagC; -+ using LayoutTagD = typename TestBedImplementation::LayoutTagD; -+ -+ // Detail Implementation -+ TestBedImplementation impl_; -+ -+ // -+ // Methods -+ // -+ Testbed( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = TestBedImplementation::kDefaultSeed) -+ : impl_(init_A_, init_B_, init_C_, seed_) {} -+ -+ Testbed( -+ typename LayoutTagA::Stride stride_factor_A_, -+ typename LayoutTagB::Stride stride_factor_B_, -+ typename LayoutTagC::Stride stride_factor_C_, -+ typename LayoutTagD::Stride stride_factor_D_, -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = TestBedImplementation::kDefaultSeed) -+ : impl_(stride_factor_A_, -+ stride_factor_B_, -+ stride_factor_C_, -+ stride_factor_D_, -+ init_A_, -+ init_B_, -+ init_C_, -+ seed_) {} -+ -+ /// Executes one test -+ bool run( -+ typename TestBedImplementation::ProblemShapeType problem_size, -+ ElementScalar alpha = ElementScalar(1), -+ ElementScalar beta = ElementScalar(0), -+ bool profiling = false, -+ int iterations = 20 -+ ) { -+ return impl_.run( -+ problem_size, alpha, beta, profiling, iterations -+ ); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+bool TestAll() { -+ using ElementScalar = typename Gemm::GemmKernel::CollectiveEpilogue::ElementScalar; -+ using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; -+ -+ int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); -+ std::vector problem_size_m = {max_alignment, 512 - 3 * max_alignment}; -+ std::vector problem_size_n = {max_alignment, 512 - 2 * max_alignment}; -+ -+ if constexpr (std::is_same_v) { -+ problem_size_m.push_back(768); -+ problem_size_n.push_back(768); -+ } -+ -+ constexpr int Stages = Gemm::GemmKernel::DispatchPolicy::Stages; -+ constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{}); -+ -+ std::vector problem_size_k = {max_alignment, TileShapeK * (Stages + 1) - max_alignment}; -+ -+ Testbed testbed; -+ bool passed = true; -+ -+ for (int m : problem_size_m) { -+ for (int n : problem_size_n) { -+ for (int k : problem_size_k) { -+ ProblemShapeType problem_size; -+ if constexpr (cute::rank(ProblemShapeType{}) == 4) { -+ problem_size = ProblemShapeType{m, n, k, /* l */ 1}; -+ } -+ else { -+ problem_size = ProblemShapeType{m, n, k}; -+ } -+ -+ passed = testbed.run( -+ problem_size, -+ cutlass::from_real(1), -+ cutlass::from_real(0) -+ ); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ -+ // if we do support batched GEMM, just run one test on it to save on test time -+ if constexpr (cute::rank(ProblemShapeType{}) == 4) { -+ auto problem_size = ProblemShapeType{256 + max_alignment, 256 + max_alignment, 160 + max_alignment, /* l */ 3}; -+ passed = testbed.run( -+ problem_size, -+ cutlass::from_real(1), -+ cutlass::from_real(0) -+ ); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ -+ return passed; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template -+bool TestGemmPerf(int iterations = 20) { -+ using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; -+ using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; -+ using ElementScalar = ElementAccumulator; -+ bool passed = true; -+ -+ std::vector problem_size_m = { 4608 }; -+ std::vector problem_size_n = { 4608 }; -+ std::vector problem_size_k = { 8192 }; -+ -+ Testbed testbed; -+ -+ for (int m : problem_size_m) { -+ for (int n : problem_size_n) { -+ for (int k : problem_size_k) { -+ ProblemShapeType problem_size; -+ if constexpr (cute::rank(ProblemShapeType{}) == 4) { -+ problem_size = ProblemShapeType{m, n, k, /* l */ 1}; -+ } -+ else { -+ problem_size = ProblemShapeType{m, n, k}; -+ } -+ -+ passed = testbed.run( -+ problem_size, -+ cutlass::from_real(1), -+ cutlass::from_real(0), -+ true, -+ iterations -+ ); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ -+ -+ // if we do support batched GEMM, just run it once -+ if constexpr (cute::rank(ProblemShapeType{}) == 4) { -+ auto problem_size = ProblemShapeType{problem_size_m[0], problem_size_n[0], problem_size_k[0], /* l */ 4}; -+ passed = testbed.run( -+ problem_size, -+ cutlass::from_real(1), -+ cutlass::from_real(0), -+ true, -+ iterations -+ ); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ -+ return passed; -+} -+ -+ -+} // namespace device -+} // namespace gemm -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_tf32n_tf32n_f32t_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_tf32n_tf32n_f32t_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..1c33d50 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_tf32n_tf32n_f32t_tensor_op_f32_sm80.cu -@@ -0,0 +1,555 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 64x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 256x64x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 128x256x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 256x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 64x256x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 256x64x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 128x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 64x128x16_32x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 128x64x16_64x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 10 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_tf32n_tf32t_f32t_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_tf32n_tf32t_f32t_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..375e7b9 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_tf32n_tf32t_f32t_tensor_op_f32_sm80.cu -@@ -0,0 +1,555 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 64x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 256x64x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 128x256x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 256x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 64x256x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 256x64x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 128x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 64x128x16_32x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 128x64x16_64x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 10 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_tf32t_tf32n_f32t_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_tf32t_tf32n_f32t_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..3353368 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_tf32t_tf32n_f32t_tensor_op_f32_sm80.cu -@@ -0,0 +1,493 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 256x64x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 128x256x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 256x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 256x64x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 128x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 64x128x16_32x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 128x64x16_64x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 10 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_tf32t_tf32t_f32t_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_tf32t_tf32t_f32t_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..3b243b2 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_tf32t_tf32t_f32t_tensor_op_f32_sm80.cu -@@ -0,0 +1,556 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 64x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 256x64x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<64, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 128x256x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 256x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 64x256x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 256x64x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 128x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 64x128x16_32x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 128x64x16_64x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 10 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_u8t_u8n_s32t_wmma_tensor_op_s32_sm72.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_u8t_u8n_s32t_wmma_tensor_op_s32_sm72.cu -new file mode 100644 -index 0000000..a39f29d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_u8t_u8n_s32t_wmma_tensor_op_s32_sm72.cu -@@ -0,0 +1,185 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM72_ENABLED -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+//////////////// WMMA Size = 16x16x16, DataType/Instruction = u8*u8+s32=>s32 ////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM75_Device_Gemm_u8t_u8n_s32t_wmma_tensor_op_s32, 128x128x32_64x64x32_16x16x16) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ uint8_t, -+ cutlass::layout::RowMajor, -+ uint8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM75_Device_Gemm_u8t_u8n_s32t_wmma_tensor_op_s32, 64x128x64_32x32x64_16x16x16) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ uint8_t, -+ cutlass::layout::RowMajor, -+ uint8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+//////////////// WMMA Size = 32x8x16, DataType/Instruction = u8*u8+s32=>s32 ////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM75_Device_Gemm_u8t_u8n_s32t_wmma_tensor_op_s32, 64x128x64_32x64x64_32x8x16) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ uint8_t, -+ cutlass::layout::RowMajor, -+ uint8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+//////////////// WMMA Size = 8x32x16, DataType/Instruction = u8*u8+s32=>s32 ////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM75_Device_Gemm_u8t_u8n_s32t_wmma_tensor_op_s32, 64x128x64_32x64x64_8x32x16) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ uint8_t, -+ cutlass::layout::RowMajor, -+ uint8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassWmmaTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 128, 64>, -+ cutlass::gemm::GemmShape<32, 64, 64>, -+ cutlass::gemm::GemmShape<32, 8, 16>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+#endif //CUTLASS_ARCH_WMMA_SM72_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_universal_cf32n_cf32n_cf32n_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_universal_cf32n_cf32n_cf32n_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..96981a2 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_universal_cf32n_cf32n_cf32n_tensor_op_f32_sm80.cu -@@ -0,0 +1,199 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/device/gemm_universal.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_universal.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmUniversal_cf32n_cf32t_cf32n_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversal< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3, -+ 1, -+ 1, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmUniversal_cf32n_cf32h_cf32n_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversal< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3, -+ 1, -+ 1, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kConjugate -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmUniversal_cf32h_cf32t_cf32n_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversal< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3, -+ 1, -+ 1, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); -+} -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmUniversal_cf32h_cf32c_cf32n_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversal< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3, -+ 1, -+ 1, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kConjugate -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_universal_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_universal_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm80.cu -new file mode 100644 -index 0000000..29103d0 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_universal_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm80.cu -@@ -0,0 +1,200 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/device/gemm_universal.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_universal.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmUniversal_cf64n_cf64t_cf64n_tensor_op_f64_gaussian, 64x64x32_32x32x32) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversal< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3, -+ 1, -+ 1, -+ cutlass::arch::OpMultiplyAddGaussianComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmUniversal_cf64n_cf64h_cf64n_tensor_op_f64_gaussian, 64x64x32_32x32x32) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversal< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3, -+ 1, -+ 1, -+ cutlass::arch::OpMultiplyAddGaussianComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kConjugate -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmUniversal_cf64h_cf64t_cf64n_tensor_op_f64_gaussian, 64x32x32_32x16x32) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversal< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<32, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3, -+ 1, -+ 1, -+ cutlass::arch::OpMultiplyAddGaussianComplex, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmUniversal_cf64h_cf64c_cf64n_tensor_op_f64_gaussian, 64x64x32_32x16x32) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversal< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3, -+ 1, -+ 1, -+ cutlass::arch::OpMultiplyAddGaussianComplex, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kConjugate -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_universal_cf64n_cf64t_cf64t_tensor_op_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_universal_cf64n_cf64t_cf64t_tensor_op_f64_sm80.cu -new file mode 100644 -index 0000000..b82c5e5 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_universal_cf64n_cf64t_cf64t_tensor_op_f64_sm80.cu -@@ -0,0 +1,200 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/device/gemm_universal.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_universal.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmUniversal_cf64n_cf64t_cf64n_tensor_op_f64, 64x64x32_32x32x32) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversal< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3, -+ 1, -+ 1, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmUniversal_cf64n_cf64h_cf64n_tensor_op_f64, 64x64x32_32x32x32) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversal< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3, -+ 1, -+ 1, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kConjugate -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmUniversal_cf64h_cf64t_cf64n_tensor_op_f64, 64x64x32_32x32x32) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversal< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3, -+ 1, -+ 1, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmUniversal_cf64h_cf64c_cf64n_tensor_op_f64, 64x64x32_32x32x32) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversal< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 3, -+ 1, -+ 1, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kConjugate -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_universal_f16n_f16t_f32n_tensor_op_f32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_universal_f16n_f16t_f32n_tensor_op_f32_sm75.cu -new file mode 100644 -index 0000000..72c3d5d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_universal_f16n_f16t_f32n_tensor_op_f32_sm75.cu -@@ -0,0 +1,117 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/device/gemm_universal.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_universal.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_GemmUniversal_f16n_f16t_f32n_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversal< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 2>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); -+} -+ -+ -+TEST(SM75_Device_GemmUniversal_f16n_f16t_f32n_tensor_op_f32, 64x64x32_32x32x32_updated_batch_count) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversal< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, cutlass::layout::ColumnMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 2, -+ 1, -+ 1>; -+ -+ EXPECT_TRUE(test::gemm::device::TestGemmUniversal( -+ {128, 128, 2}, -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ 15)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_universal_f16n_f16t_f32t_tensor_op_f32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_universal_f16n_f16t_f32t_tensor_op_f32_sm75.cu -new file mode 100644 -index 0000000..771573b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_universal_f16n_f16t_f32t_tensor_op_f32_sm75.cu -@@ -0,0 +1,115 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/device/gemm_universal.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_universal.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_GemmUniversal_f16n_f16t_f32t_tensor_op_f32, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversal< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 2>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); -+} -+ -+TEST(SM75_Device_GemmUniversal_f16n_f16t_f32t_tensor_op_f32, 64x64x32_32x32x32_updated_batch_count) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversal< -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, cutlass::layout::RowMajor, -+ ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, ElementAccumulator>, -+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, -+ 2, -+ 1, -+ 1>; -+ -+ EXPECT_TRUE(test::gemm::device::TestGemmUniversal( -+ {128, 128, 2}, -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ 15)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu -new file mode 100644 -index 0000000..8a1884d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu -@@ -0,0 +1,464 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/functional.h" -+ -+#include "cutlass/gemm/kernel/default_gemm_with_broadcast.h" -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+ -+#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" -+#include "cutlass/epilogue/thread/linear_combination_bias_relu.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed_gemm_with_broadcast.h" -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes: -+/// -+/// Z = GEMM+Bias+ReLu -+/// T = Relu conditional -+/// -+template -+struct GemmWithBiasReluReferenceOp { -+ -+ using OutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; -+ -+ using ElementCompute = typename OutputOp::ElementCompute; -+ using ElementZ = typename OutputOp::ElementZ; -+ using ElementT = typename OutputOp::ElementT; -+ -+ typename OutputOp::BinaryOp binary_op; -+ typename OutputOp::ElementwiseOp elementwise_op; -+ -+ GemmWithBiasReluReferenceOp() { } -+ -+ void operator()(ElementZ &Z, ElementT &T, ElementCompute gemm, ElementCompute bias) { -+ -+ ElementCompute kThreshold = ElementCompute(); -+ -+ ElementCompute z_full = binary_op(gemm, bias); -+ -+ bool conditional = (z_full >= kThreshold); -+ -+ if (!conditional) { -+ z_full = kThreshold; -+ } -+ -+ Z = ElementZ(z_full); -+ T = ElementT(conditional); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_GemmWithBroadcast_GELU_f16n_f16n_f16n_tensor_op_f32, 128x128x32_64x64x8) { -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasElementwise< -+ cutlass::half_t, -+ float, -+ float, -+ cutlass::half_t, -+ cutlass::half_t, -+ 8, -+ cutlass::epilogue::thread::GELU_taylor -+ >; -+ -+ using GemmKernel = -+ typename cutlass::gemm::kernel::DefaultGemmWithBroadcast< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand -+ cutlass::half_t, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+ >::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ test::gemm::device::TestAllGemmWithBroadcast(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_GemmWithBroadcast_GELU_f16n_f16n_f16n_tensor_op_f32, 128x128x32_64x64x8) { -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasElementwise< -+ cutlass::half_t, -+ float, -+ float, -+ cutlass::half_t, -+ cutlass::half_t, -+ 8, -+ cutlass::epilogue::thread::GELU_taylor -+ >; -+ -+ using GemmKernel = -+ typename cutlass::gemm::kernel::DefaultGemmWithBroadcast< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand -+ cutlass::half_t, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+ >::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ test::gemm::device::TestAllGemmWithBroadcast(); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_GemmWithBroadcast_RELU_f16n_f16n_f16n_tensor_op_f32, 128x128x32_64x64x8) { -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasRelu< -+ cutlass::half_t, -+ float, -+ float, -+ cutlass::half_t, -+ 8, -+ true -+ >; -+ -+ using GemmKernel = -+ typename cutlass::gemm::kernel::DefaultGemmWithBroadcast< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand -+ cutlass::half_t, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+ >::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ test::gemm::device::TestAllGemmWithBroadcast >(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_GemmWithBroadcast_RELU_f16n_f16n_f16n_tensor_op_f32, 128x128x32_64x64x8) { -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasRelu< -+ cutlass::half_t, -+ float, -+ float, -+ cutlass::half_t, -+ 8, -+ true -+ >; -+ -+ using GemmKernel = -+ typename cutlass::gemm::kernel::DefaultGemmWithBroadcast< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand -+ cutlass::half_t, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+ >::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ test::gemm::device::TestAllGemmWithBroadcast >(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if defiend(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmWithBroadcast_GELU_f16n_f16n_f16n_tensor_op_f32, 128x128_32x5_64x64x32_16x8x16) { -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasElementwise< -+ cutlass::half_t, -+ float, -+ float, -+ cutlass::half_t, -+ cutlass::half_t, -+ 8, -+ cutlass::epilogue::thread::GELU_taylor -+ >; -+ -+ using GemmKernel = -+ typename cutlass::gemm::kernel::DefaultGemmWithBroadcast< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand -+ cutlass::half_t, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 5, -+ cutlass::arch::OpMultiplyAdd -+ >::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ test::gemm::device::TestAllGemmWithBroadcast(); -+} -+ -+TEST(SM80_Device_GemmWithBroadcast_RELU_f16n_f16n_f16n_tensor_op_f32, 128x128_32x5_64x64x32_16x8x16) { -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasRelu< -+ cutlass::half_t, -+ float, -+ float, -+ cutlass::half_t, -+ 8, -+ true -+ >; -+ -+ using GemmKernel = -+ typename cutlass::gemm::kernel::DefaultGemmWithBroadcast< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand -+ cutlass::half_t, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 5, -+ cutlass::arch::OpMultiplyAdd -+ >::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ test::gemm::device::TestAllGemmWithBroadcast>(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmWithBroadcast_GELU_f16n_f16n_f16n_tensor_op_f32, 128x128_32x4_64x64x32_16x8x16) { -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasElementwise< -+ cutlass::half_t, -+ float, -+ float, -+ cutlass::half_t, -+ cutlass::half_t, -+ 8, -+ cutlass::epilogue::thread::GELU_taylor -+ >; -+ -+ using GemmKernel = -+ typename cutlass::gemm::kernel::DefaultGemmWithBroadcast< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand -+ cutlass::half_t, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 4, -+ cutlass::arch::OpMultiplyAdd -+ >::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ test::gemm::device::TestAllGemmWithBroadcast(); -+} -+ -+TEST(SM80_Device_GemmWithBroadcast_RELU_f16n_f16n_f16n_tensor_op_f32, 128x128_32x4_64x64x32_16x8x16) { -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasRelu< -+ cutlass::half_t, -+ float, -+ float, -+ cutlass::half_t, -+ 8, -+ true -+ >; -+ -+ using GemmKernel = -+ typename cutlass::gemm::kernel::DefaultGemmWithBroadcast< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand -+ cutlass::half_t, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 4, -+ cutlass::arch::OpMultiplyAdd -+ >::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ test::gemm::device::TestAllGemmWithBroadcast>(); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmWithBroadcast_GELU_f16n_f16n_f16n_tensor_op_f32, 128x128_32x3_64x64x32_16x8x16) { -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasElementwise< -+ cutlass::half_t, -+ float, -+ float, -+ cutlass::half_t, -+ cutlass::half_t, -+ 8, -+ cutlass::epilogue::thread::GELU_taylor -+ >; -+ -+ using GemmKernel = -+ typename cutlass::gemm::kernel::DefaultGemmWithBroadcast< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand -+ cutlass::half_t, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ cutlass::arch::OpMultiplyAdd -+ >::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ test::gemm::device::TestAllGemmWithBroadcast(); -+} -+ -+TEST(SM80_Device_GemmWithBroadcast_RELU_f16n_f16n_f16n_tensor_op_f32, 128x128_32x3_64x64x32_16x8x16) { -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasRelu< -+ cutlass::half_t, -+ float, -+ float, -+ cutlass::half_t, -+ 8, -+ true -+ >; -+ -+ using GemmKernel = -+ typename cutlass::gemm::kernel::DefaultGemmWithBroadcast< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand -+ cutlass::half_t, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ cutlass::arch::OpMultiplyAdd -+ >::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ test::gemm::device::TestAllGemmWithBroadcast >(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu -new file mode 100644 -index 0000000..15eca4b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu -@@ -0,0 +1,384 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/functional.h" -+ -+#include "cutlass/gemm/kernel/default_gemm_with_reduction.h" -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+ -+#include "cutlass/epilogue/thread/linear_combination_drelu.h" -+#include "cutlass/epilogue/thread/linear_combination_dgelu.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed_gemm_with_reduction.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct dReluLambda { -+ float operator()(float d_y, float t) { -+ if (t <= 0) { -+ d_y = 0; -+ } -+ return d_y; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_GemmWithReduction_dReLU_bGrad_f16n_f16n_f16n_tensor_op_f32, 128x128x32_64x64x8) { -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< -+ float, -+ float, -+ cutlass::half_t, -+ cutlass::half_t, -+ 8 -+ >; -+ -+ using GemmKernel = -+ typename cutlass::gemm::kernel::DefaultGemmWithReduction< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand -+ cutlass::half_t, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ EpilogueOutputOp, -+ cutlass::plus, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+ >::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ using ReferenceOp = test::gemm::device::GemmWithReductionReference< -+ Gemm, -+ dReluLambda -+ >; -+ -+ test::gemm::device::TestGemmWithReduction( -+ {520, 264, 96}, -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ 2, -+ float(1.25), -+ float(2.25) -+ ); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_GemmWithReduction_dReLU_bGrad_f16n_f16n_f16n_tensor_op_f32, 256x128x32_64x64x8) { -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< -+ float, -+ float, -+ cutlass::half_t, -+ cutlass::half_t, -+ 8 -+ >; -+ -+ using GemmKernel = -+ typename cutlass::gemm::kernel::DefaultGemmWithReduction< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand -+ cutlass::half_t, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ EpilogueOutputOp, -+ cutlass::plus, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+ >::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ using ReferenceOp = test::gemm::device::GemmWithReductionReference< -+ Gemm, -+ dReluLambda -+ >; -+ -+ test::gemm::device::TestGemmWithReduction( -+ {520, 264, 96}, -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ 1, -+ float(1.25), -+ float(2.25) -+ ); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_GemmWithReduction_dReLU_bGrad_f16n_f16n_f16n_tensor_op_f32, 128x128x32_64x64x8) { -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< -+ float, -+ float, -+ cutlass::half_t, -+ cutlass::half_t, -+ 8 -+ >; -+ -+ using GemmKernel = -+ typename cutlass::gemm::kernel::DefaultGemmWithReduction< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand -+ cutlass::half_t, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ EpilogueOutputOp, -+ cutlass::plus, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+ >::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ using ReferenceOp = test::gemm::device::GemmWithReductionReference< -+ Gemm, -+ dReluLambda -+ >; -+ -+ test::gemm::device::TestGemmWithReduction( -+ {520, 264, 96}, -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ 2, -+ float(1.25), -+ float(2.25) -+ ); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_GemmWithReduction_dReLU_bGrad_f16n_f16n_f16n_tensor_op_f32, 256x128x32_64x64x8) { -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< -+ float, -+ float, -+ cutlass::half_t, -+ cutlass::half_t, -+ 8 -+ >; -+ -+ using GemmKernel = -+ typename cutlass::gemm::kernel::DefaultGemmWithReduction< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand -+ cutlass::half_t, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ EpilogueOutputOp, -+ cutlass::plus, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+ >::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ using ReferenceOp = test::gemm::device::GemmWithReductionReference< -+ Gemm, -+ dReluLambda -+ >; -+ -+ test::gemm::device::TestGemmWithReduction( -+ {520, 264, 96}, -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ 1, -+ float(1.25), -+ float(2.25) -+ ); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+template -+struct Gemm_dReLU_packed_bits_reference_op { -+ using ElementAccumulator = typename Gemm::ElementAccumulator; -+ using ElementCompute = typename Gemm::GemmKernel::Epilogue::ElementCompute; -+ using ElementC = typename Gemm::ElementC; -+ using ElementT = typename Gemm::GemmKernel::Epilogue::ElementTensor; -+ -+ // -+ // Methods -+ // -+ -+ Gemm_dReLU_packed_bits_reference_op() { } -+ -+ ElementCompute operator()( -+ ElementAccumulator d_y, -+ ElementT t) const { -+ -+ ElementCompute result = ElementCompute(d_y); -+ -+ bool cond = bool(t); -+ if (!cond) { -+ result = ElementCompute(); -+ } -+ -+ return result; -+ } -+}; -+ -+} // namespace device -+} // namespace gemm -+} // namespace test -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_Device_GemmWithReduction_dReLU_conditional_bits_bGrad_f16n_f16n_f16n_tensor_op_f32, 128x128x32_64x64x8) { -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationDReluConditionalBits< -+ float, -+ float, -+ cutlass::half_t, -+ 8 -+ >; -+ -+ using GemmKernel = -+ typename cutlass::gemm::kernel::DefaultGemmWithReduction< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand -+ cutlass::half_t, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm75, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ EpilogueOutputOp, -+ cutlass::plus, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+ >::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ using ReferenceOp = test::gemm::device::Gemm_dReLU_packed_bits_reference_op; -+ -+ test::gemm::device::TestGemmWithReduction( -+ {520, 264, 96}, -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ 2, -+ float(1.25), -+ float(2.25) -+ ); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_Device_GemmWithReduction_dReLU_conditional_bits_bGrad_f16n_f16n_f16n_tensor_op_f32, 128x128x32_64x64x8) { -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationDReluConditionalBits< -+ float, -+ float, -+ cutlass::half_t, -+ 8 -+ >; -+ -+ using GemmKernel = -+ typename cutlass::gemm::kernel::DefaultGemmWithReduction< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand -+ cutlass::half_t, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm70, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ EpilogueOutputOp, -+ cutlass::plus, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 2, -+ cutlass::arch::OpMultiplyAdd -+ >::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ using ReferenceOp = test::gemm::device::Gemm_dReLU_packed_bits_reference_op; -+ -+ test::gemm::device::TestGemmWithReduction( -+ {520, 264, 96}, -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ 2, -+ float(1.25), -+ float(2.25) -+ ); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if defiend(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemm_with_reduction_f16t_f16n_f16n_tensorop_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/gemm_with_reduction_f16t_f16n_f16n_tensorop_f32_sm80.cu -new file mode 100644 -index 0000000..3e04929 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemm_with_reduction_f16t_f16n_f16n_tensorop_f32_sm80.cu -@@ -0,0 +1,118 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/functional.h" -+ -+#include "cutlass/gemm/kernel/default_gemm_with_reduction.h" -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+ -+#include "cutlass/epilogue/thread/linear_combination_drelu.h" -+#include "cutlass/epilogue/thread/linear_combination_dgelu.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed_gemm_with_reduction.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct dReluLambda { -+ float operator()(float d_y, float t) { -+ if (t <= 0) { -+ d_y = 0; -+ } -+ return d_y; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_GemmWithReduction_dReLU_bGrad_f16t_f16n_f16n_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationDRelu< -+ float, -+ float, -+ cutlass::half_t, -+ cutlass::half_t, -+ 8 -+ >; -+ -+ using GemmKernel = -+ typename cutlass::gemm::kernel::DefaultGemmWithReduction< -+ cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, // transposed B operand -+ cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, 8, // transposed A operand -+ cutlass::half_t, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 16>, -+ EpilogueOutputOp, -+ cutlass::plus, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 5, -+ cutlass::arch::OpMultiplyAdd -+ >::GemmKernel; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ using ReferenceOp = test::gemm::device::GemmWithReductionReference< -+ Gemm, -+ dReluLambda -+ >; -+ -+ test::gemm::device::TestGemmWithReduction( -+ {136, 6920, 512}, -+ cutlass::gemm::GemmUniversalMode::kGemm -+ ); -+} -+ -+#endif // if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/gemv.cu b/3rdparty/cutlass/test/unit/gemm/device/gemv.cu -new file mode 100644 -index 0000000..fe68e0e ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/gemv.cu -@@ -0,0 +1,444 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Tests for device-wide GEMV interface -+*/ -+ -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/kernel/gemv.h" -+#include "cutlass/gemm/device/gemv.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/gemm_complex.h" -+ -+#include "testbed_utils.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace gemm { -+ -+template -+class TestbedGemv { -+public: -+ -+ using ElementA = typename Gemv::ElementA; -+ using LayoutA = typename Gemv::LayoutA; -+ using ElementB = typename Gemv::ElementB; -+ using ElementC = typename Gemv::ElementC; -+ -+ using ElementAccumulator = typename Gemv::ElementAccumulator; -+ using ElementCompute = typename Gemv::EpilogueOutputOp::ElementCompute; -+ -+ using LayoutV = cutlass::layout::RowMajor; -+ -+private: -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D; -+ cutlass::HostTensor reference_D; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ TestbedGemv( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ double scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope_max, scope_min, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else { -+ // TODO: Implement the rest -+ EXPECT_TRUE(false) << "Not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Initializes data structures -+ void initialize( -+ cutlass::MatrixCoord problem_size -+ ) { -+ -+ // -+ // Allocate the GEMM workspace -+ // -+ -+ tensor_A.resize(problem_size); -+ tensor_B.resize({problem_size.column(), 1}); -+ tensor_C.resize({problem_size.row(), 1}); -+ tensor_D.resize({problem_size.row(), 1}); -+ reference_D.resize({problem_size.row(), 1}, false); -+ -+ EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); -+ EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); -+ EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); -+ -+ // It is possible to randomly initialize to all zeros, so override this with non-zeros -+ // in the upper left corner of each operand. -+ tensor_A.host_view().at({0, 0}) = typename Gemv::ElementA(1); -+ tensor_B.host_view().at({0, 0}) = typename Gemv::ElementB(1); -+ tensor_C.host_view().at({0, 0}) = typename Gemv::ElementC(1); -+ -+ cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D.sync_device(); -+ } -+ -+ /// Compares computed reference with device reference and outputs to a file if incorrect -+ bool compare_reference( -+ cutlass::MatrixCoord problem_size, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ tensor_D.sync_host(); -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); -+ -+ bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); -+ -+ EXPECT_TRUE(passed) << " mismatched reference"; -+ -+ if (!passed) { -+ -+ std::ofstream file("testbed_universal_errors.txt"); -+ -+ file -+ << "problem: " << problem_size -+ << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; -+ -+ file -+ << "A =\n" << tensor_A.host_view() -+ << "\nB =\n" << tensor_B.host_view() -+ << "\nC =\n" << tensor_C.host_view() -+ << "\n\nReference =\n" << reference_D.host_view() -+ << "\nComputed =\n" << tensor_D.host_view(); -+ } -+ -+ return passed; -+ } -+ -+ /// Verifies the result is a GEMM -+ bool verify( -+ cutlass::MatrixCoord problem_size, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ // -+ // Verify -+ // -+ -+ cutlass::reference::host::GemmComplex< -+ typename Gemv::ElementA, typename Gemv::LayoutA, -+ typename Gemv::ElementB, LayoutV, -+ typename Gemv::ElementC, LayoutV, -+ ElementCompute, ElementAccumulator -+ >( -+ {problem_size.row(), 1, problem_size.column()}, -+ alpha, -+ tensor_A.host_ref(), -+ Gemv::kTransformA, -+ tensor_B.host_ref(), -+ Gemv::kTransformB, -+ beta, -+ tensor_C.host_ref(), -+ reference_D.host_ref(), -+ ElementAccumulator(0) -+ ); -+ -+ return compare_reference(problem_size, alpha, beta); -+ } -+ -+ /// Runs one problem size -+ bool run( -+ cutlass::MatrixCoord problem_size, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ this->initialize(problem_size); -+ -+ // -+ // Initialize the GEMM operator -+ // -+ -+ typename Gemv::Arguments arguments{ -+ problem_size, -+ {alpha, beta}, -+ tensor_A.device_ref(), -+ tensor_B.device_data(), -+ tensor_C.device_data(), -+ tensor_D.device_data(), -+ tensor_B.layout().stride(0), -+ tensor_C.layout().stride(0), -+ tensor_D.layout().stride(0) -+ }; -+ -+ Gemv gemm_op; -+ -+ size_t workspace_size = Gemv::get_workspace_size(arguments); -+ -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Run the GEMM -+ // -+ -+ status = gemm_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Verify -+ // -+ -+ bool passed = this->verify(problem_size, alpha, beta); -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+bool TestAllGemv() { -+ -+ using ElementCompute = typename Gemv::EpilogueOutputOp::ElementCompute; -+ -+ int M[] = { -+ 8, 48, 192, 520 -+ }; -+ -+ int K[] = { -+ 8, 192, 528 -+ }; -+ -+ double Alpha[] = { -+ 1, 1.25 -+ }; -+ -+ double Beta[] = { -+ 0, 1, 1.25 -+ }; -+ -+ for (int m : M) { -+ for (int k : K) { -+ for (double alpha : Alpha) { -+ for (double beta : Beta) { -+ -+ TestbedGemv testbed; -+ -+ if (!testbed.run({m, k}, ElementCompute(alpha), ElementCompute(beta))) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ -+ return true; -+} -+ -+} // namespace gemm -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Device_Gemv_f32n_f32_f32_simt_f32, Simple) { -+ -+ using ElementOutput = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = float; -+ -+ using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator>; -+ -+ using Gemv = cutlass::gemm::device::Gemv< -+ cutlass::gemm::kernel::Gemv< -+ ElementOutput, // Element A -+ LayoutA, // Layout A -+ ElementOutput, // Element B -+ ElementOutput, // Element C -+ ElementAccumulator, // Element Accumulator -+ EpilogueOp // Output operator -+ > -+ >; -+ -+ -+ EXPECT_TRUE(test::gemm::TestAllGemv()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Device_Gemv_f16n_f16_f32_simt_f32, Simple) { -+ -+ using ElementInput = cutlass::half_t; -+ using ElementOutput = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = float; -+ -+ using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator>; -+ -+ using Gemv = cutlass::gemm::device::Gemv< -+ cutlass::gemm::kernel::Gemv< -+ ElementInput, // Element A -+ LayoutA, // Layout A -+ ElementInput, // Element B -+ ElementOutput, // Element C -+ ElementAccumulator, // Element Accumulator -+ EpilogueOp // Output operator -+ > -+ >; -+ -+ -+ EXPECT_TRUE(test::gemm::TestAllGemv()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Device_Gemv_f16n_f16_f16_simt_f32, Simple) { -+ -+ using ElementInput = cutlass::half_t; -+ using ElementOutput = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = float; -+ -+ using EpilogueOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator>; -+ -+ using Gemv = cutlass::gemm::device::Gemv< -+ cutlass::gemm::kernel::Gemv< -+ ElementInput, // Element A -+ LayoutA, // Layout A -+ ElementInput, // Element B -+ ElementOutput, // Element C -+ ElementAccumulator, // Element Accumulator -+ EpilogueOp // Output operator -+ > -+ >; -+ -+ -+ EXPECT_TRUE(test::gemm::TestAllGemv()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/hemm_cf32h_cf32n_tensor_op_f32_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/hemm_cf32h_cf32n_tensor_op_f32_ls_sm80.cu -new file mode 100644 -index 0000000..e09bf17 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/hemm_cf32h_cf32n_tensor_op_f32_ls_sm80.cu -@@ -0,0 +1,175 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide HEMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf32h_cf32n_ls_l_tensor_op_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf32h_cf32n_ls_u_tensor_op_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf32h_cf32n_ls_u_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/hemm_cf32h_cf32n_tensor_op_f32_rs_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/hemm_cf32h_cf32n_tensor_op_f32_rs_sm80.cu -new file mode 100644 -index 0000000..45fbebd ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/hemm_cf32h_cf32n_tensor_op_f32_rs_sm80.cu -@@ -0,0 +1,175 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide HEMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf32h_cf32n_rs_l_tensor_op_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf32h_cf32n_rs_u_tensor_op_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf32h_cf32n_rs_u_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/hemm_cf32h_cf32n_tensor_op_fast_f32_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/hemm_cf32h_cf32n_tensor_op_fast_f32_ls_sm80.cu -new file mode 100644 -index 0000000..def5edd ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/hemm_cf32h_cf32n_tensor_op_fast_f32_ls_sm80.cu -@@ -0,0 +1,175 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide HEMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf32h_cf32n_ls_l_tensor_op_fast_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf32h_cf32n_ls_u_tensor_op_fast_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf32h_cf32n_ls_u_tensor_op_fast_f32, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/hemm_cf32h_cf32n_tensor_op_fast_f32_rs_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/hemm_cf32h_cf32n_tensor_op_fast_f32_rs_sm80.cu -new file mode 100644 -index 0000000..ebf9055 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/hemm_cf32h_cf32n_tensor_op_fast_f32_rs_sm80.cu -@@ -0,0 +1,175 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide HEMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf32h_cf32n_rs_l_tensor_op_fast_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf32h_cf32n_rs_u_tensor_op_fast_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf32h_cf32n_rs_u_tensor_op_fast_f32, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu b/3rdparty/cutlass/test/unit/gemm/device/hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu -new file mode 100644 -index 0000000..9a11549 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu -@@ -0,0 +1,135 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide HEMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Hemm_cf64h_cf64n_ls_l_tensor_op_f64_gaussian, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddGaussianComplex, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Hemm_cf64h_cf64n_rs_u_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/hemm_cf64h_cf64n_cf64n_tensor_op_ls_f64_gaussian_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/hemm_cf64h_cf64n_cf64n_tensor_op_ls_f64_gaussian_sm80.cu -new file mode 100644 -index 0000000..882bbf2 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/hemm_cf64h_cf64n_cf64n_tensor_op_ls_f64_gaussian_sm80.cu -@@ -0,0 +1,175 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide HEMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf64h_cf64n_ls_l_tensor_op_f64_gaussian, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddGaussianComplex, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf64h_cf64n_ls_u_tensor_op_f64_gaussian, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddGaussianComplex, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf64h_cf64n_ls_u_tensor_op_f64_gaussian, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddGaussianComplex, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/hemm_cf64h_cf64n_cf64n_tensor_op_ls_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/hemm_cf64h_cf64n_cf64n_tensor_op_ls_f64_sm80.cu -new file mode 100644 -index 0000000..4b4b166 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/hemm_cf64h_cf64n_cf64n_tensor_op_ls_f64_sm80.cu -@@ -0,0 +1,175 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide HEMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf64h_cf64n_ls_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf64h_cf64n_ls_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf64h_cf64n_ls_u_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/hemm_cf64h_cf64n_cf64n_tensor_op_rs_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/hemm_cf64h_cf64n_cf64n_tensor_op_rs_f64_sm80.cu -new file mode 100644 -index 0000000..d6d1690 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/hemm_cf64h_cf64n_cf64n_tensor_op_rs_f64_sm80.cu -@@ -0,0 +1,175 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide HEMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf64h_cf64n_rs_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf64h_cf64n_rs_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Hemm_cf64h_cf64n_rs_u_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Hemm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/her2k_cf32h_cf32n_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/her2k_cf32h_cf32n_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..1764e32 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/her2k_cf32h_cf32n_tensor_op_f32_sm80.cu -@@ -0,0 +1,149 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide HER2K interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2k_cf32n_cf32n_l_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KHermitianUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2k_cf32h_cf32n_l_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::RowMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KHermitianUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/her2k_cf32h_cf32n_tensor_op_fast_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/her2k_cf32h_cf32n_tensor_op_fast_f32_sm80.cu -new file mode 100644 -index 0000000..926ed0a ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/her2k_cf32h_cf32n_tensor_op_fast_f32_sm80.cu -@@ -0,0 +1,149 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide HER2K interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2k_cf32n_cf32n_l_tensor_op_fast_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KHermitianUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2k_cf32h_cf32n_l_tensor_op_fast_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::RowMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KHermitianUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/her2k_cf64_cf64_tensor_op_f64_sm90.cu b/3rdparty/cutlass/test/unit/gemm/device/her2k_cf64_cf64_tensor_op_f64_sm90.cu -new file mode 100644 -index 0000000..fbc4efd ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/her2k_cf64_cf64_tensor_op_f64_sm90.cu -@@ -0,0 +1,149 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide HER2K interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Her2k_cf64n_cf64n_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KHermitianUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Her2k_cf64c_cf64n_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::RowMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KHermitianUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/her2k_cf64h_cf64n_tensor_op_f64_grouped_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/her2k_cf64h_cf64n_tensor_op_f64_grouped_sm80.cu -new file mode 100644 -index 0000000..4697598 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/her2k_cf64h_cf64n_tensor_op_f64_grouped_sm80.cu -@@ -0,0 +1,310 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for grouped Rank2K interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/rank_2k_grouped.h" -+#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" -+#include "cutlass/gemm/device/rank_2k_grouped.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_grouped_rank_2k.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+// NOTE: HER2K requires that LayoutA == LayoutB, and that LayoutC == ColumnMajor -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2KGrouped_cf64h_cf64n_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kConjugate, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kConjugate, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2KGrouped_cf64h_cf64n_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kConjugate, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kConjugate, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2KGrouped_cf64h_cf64n_l_tensor_op_f64, 32x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kConjugate, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kConjugate, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2KGrouped_cf64h_cf64n_l_tensor_op_f64, 64x32x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kConjugate, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kConjugate, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2KGrouped_cf64h_cf64n_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kConjugate, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kConjugate, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2KGrouped_cf64h_cf64n_u_tensor_op_f64, 32x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kConjugate, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kConjugate, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2KGrouped_cf64h_cf64n_u_tensor_op_f64, 64x32x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kConjugate, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kConjugate, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/her2k_cf64n_cf64n_tensor_op_f64_grouped_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/her2k_cf64n_cf64n_tensor_op_f64_grouped_sm80.cu -new file mode 100644 -index 0000000..c7dca8c ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/her2k_cf64n_cf64n_tensor_op_f64_grouped_sm80.cu -@@ -0,0 +1,310 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for grouped Rank2K interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/rank_2k_grouped.h" -+#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" -+#include "cutlass/gemm/device/rank_2k_grouped.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_grouped_rank_2k.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+// NOTE: HER2K requires that LayoutA == LayoutB, and that LayoutC == ColumnMajor -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2KGrouped_cf64n_cf64n_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2KGrouped_cf64n_cf64n_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2KGrouped_cf64n_cf64n_l_tensor_op_f64, 64x32x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2KGrouped_cf64n_cf64n_l_tensor_op_f64, 32x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2KGrouped_cf64n_cf64n_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2KGrouped_cf64n_cf64n_u_tensor_op_f64, 64x32x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2KGrouped_cf64n_cf64n_u_tensor_op_f64, 32x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kHermitian>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/her2k_cf64n_cf64n_tensor_op_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/her2k_cf64n_cf64n_tensor_op_f64_sm80.cu -new file mode 100644 -index 0000000..717d3f9 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/her2k_cf64n_cf64n_tensor_op_f64_sm80.cu -@@ -0,0 +1,149 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide HER2K interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2k_cf64n_cf64n_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KHermitianUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2k_cf64h_cf64n_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::RowMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KHermitianUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/her2k_cf64n_cf64t_tensor_op_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/her2k_cf64n_cf64t_tensor_op_f64_sm80.cu -new file mode 100644 -index 0000000..3a65931 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/her2k_cf64n_cf64t_tensor_op_f64_sm80.cu -@@ -0,0 +1,201 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide HER2K interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#if 0 // HER2K with RowMajor output is not supported -+TEST(SM80_Device_Her2k_cf64n_cf64t_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ false, // IsBetaZero -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KHermitianUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2k_cf64c_cf64t_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ false, // IsBetaZero -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KHermitianUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Her2k_cf64h_cf64t_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::RowMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ false, // IsBetaZero -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KHermitianUniversal()); -+} -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/herk_cf32h_cf32n_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/herk_cf32h_cf32n_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..a6503d1 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/herk_cf32h_cf32n_tensor_op_f32_sm80.cu -@@ -0,0 +1,219 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide HERK interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank_k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// HERK operator on CUBLAS_OP_N (column-major) input layouts -+TEST(SM80_Device_Herk_cf32n_cf32n_l_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// HERK operator on CUBLAS_OP_N (column-major) input layouts -+TEST(SM80_Device_Herk_cf32n_cf32n_u_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// HERK operator on CUBLAS_OP_C (row-major + conj) input layouts -+TEST(SM80_Device_Herk_cf32h_cf32n_l_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// HERK operator on CUBLAS_OP_C (row-major + conj) input layouts -+TEST(SM80_Device_Herk_cf32h_cf32n_u_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/herk_cf32h_cf32n_tensor_op_fast_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/herk_cf32h_cf32n_tensor_op_fast_f32_sm80.cu -new file mode 100644 -index 0000000..56ef601 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/herk_cf32h_cf32n_tensor_op_fast_f32_sm80.cu -@@ -0,0 +1,219 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide HERK interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank_k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// HERK operator on CUBLAS_OP_N (column-major) input layouts -+TEST(SM80_Device_Herk_cf32n_cf32n_l_tensor_op_fast_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// HERK operator on CUBLAS_OP_N (column-major) input layouts -+TEST(SM80_Device_Herk_cf32n_cf32n_u_tensor_op_fast_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// HERK operator on CUBLAS_OP_C (row-major + conj) input layouts -+TEST(SM80_Device_Herk_cf32h_cf32n_l_tensor_op_fast_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// HERK operator on CUBLAS_OP_C (row-major + conj) input layouts -+TEST(SM80_Device_Herk_cf32h_cf32n_u_tensor_op_fast_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/herk_cf64_cf64_tensor_op_f64_sm90.cu b/3rdparty/cutlass/test/unit/gemm/device/herk_cf64_cf64_tensor_op_f64_sm90.cu -new file mode 100644 -index 0000000..114a20c ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/herk_cf64_cf64_tensor_op_f64_sm90.cu -@@ -0,0 +1,93 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide HERK interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank_k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// HERK operator on CUBLAS_OP_C (row-major + conj) input layouts -+TEST(SM90_Device_Herk_cf64h_cf64n_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/herk_cf64h_cf64n_tensor_op_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/herk_cf64h_cf64n_tensor_op_f64_sm80.cu -new file mode 100644 -index 0000000..71d8a9c ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/herk_cf64h_cf64n_tensor_op_f64_sm80.cu -@@ -0,0 +1,175 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide HERK interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank_k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// HERK operator on CUBLAS_OP_N (column-major) input layouts -+TEST(SM80_Device_Herk_cf64n_cf64n_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// HERK operator on CUBLAS_OP_N (column-major) input layouts -+TEST(SM80_Device_Herk_cf64n_cf64n_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// HERK operator on CUBLAS_OP_C (row-major + conj) input layouts -+TEST(SM80_Device_Herk_cf64h_cf64n_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::BlasMode::kHermitian -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/multistage_testbed.h b/3rdparty/cutlass/test/unit/gemm/device/multistage_testbed.h -new file mode 100644 -index 0000000..681e051 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/multistage_testbed.h -@@ -0,0 +1,301 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_utils.h" -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MultistageTestbed { -+ -+ using ElementA = typename Gemm::ElementA; -+ using ElementB = typename Gemm::ElementB; -+ using ElementC = typename Gemm::ElementC; -+ -+ using ElementAccumulator = typename Gemm::ElementAccumulator; -+ using ElementCompute = -+ typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint64_t seed; -+ -+ // -+ // Methods -+ // -+ -+ MultistageTestbed( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080) -+ : init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) {} -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor(cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, uint64_t seed) { -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ int scope = (cutlass::sizeof_bits::value == 8) ? 2 : 8; -+ cutlass::reference::host::TensorFillRandomUniform(view, seed, scope, -+ -scope, 0); -+ } else if (dist_kind == cutlass::Distribution::Gaussian) { -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5, -1); -+ } else if (dist_kind == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(view); -+ } else if (dist_kind == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(view.data(), -+ view.capacity()); -+ } else { -+ // TODO: Implement the rest -+ EXPECT_TRUE(false) << "Not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Waives test if CUDA device is insufficient -+ bool sufficient() const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename Gemm::GemmKernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Executes one test -+ bool run(cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha = ElementCompute(1), -+ ElementCompute beta = ElementCompute(0)) { -+ -+ // Waives test if CUDA device is insufficient -+ if (!sufficient()) { -+ return true; -+ } -+ -+ // -+ // Allocate the GEMM workspace -+ // -+ -+ cutlass::HostTensor -+ tensor_A(problem_size.mk()); -+ -+ cutlass::HostTensor -+ tensor_B(problem_size.kn()); -+ -+ cutlass::HostTensor -+ tensor_C(problem_size.mn()); -+ -+ cutlass::HostTensor -+ tensor_D(problem_size.mn()); -+ -+ cutlass::HostTensor -+ reference_D(problem_size.mn(), false); -+ -+ EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); -+ EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); -+ EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); -+ -+ cutlass::reference::host::TensorCopy(reference_D.host_view(), -+ tensor_C.host_view()); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D.sync_device(); -+ -+ // -+ // Initialize the GEMM operator -+ // -+ -+ typename Gemm::Arguments arguments{ -+ problem_size, tensor_A.device_ref(), tensor_B.device_ref(), -+ tensor_C.device_ref(), tensor_D.device_ref(), {alpha, beta}}; -+ -+ Gemm gemm_op; -+ -+ cutlass::Status status = gemm_op.initialize(arguments); -+ -+ if (status != cutlass::Status::kSuccess) { -+ cudaError_t error = cudaGetLastError(); -+ std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; -+ return true; -+ } -+ -+ // -+ // Run the GEMM -+ // -+ -+ status = gemm_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ -+ // -+ // Verify -+ // -+ -+ cutlass::reference::host::Gemm< -+ typename Gemm::ElementA, typename Gemm::LayoutA, -+ typename Gemm::ElementB, typename Gemm::LayoutB, -+ typename Gemm::ElementC, typename Gemm::LayoutC, ElementCompute, -+ ElementAccumulator, typename Gemm::Operator> -+ reference_gemm; -+ -+ reference_gemm( -+ problem_size, alpha, tensor_A.host_ref(), tensor_B.host_ref(), beta, -+ reference_D.host_ref(), ElementAccumulator(0)); -+ -+ tensor_D.sync_host(); -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ reference_D.host_view(), tensor_D.host_view()); -+ -+ EXPECT_TRUE(passed); -+ if (!passed) { -+ std::stringstream fname; -+ -+ fname << "error_Gemm_device_" << problem_size.m() << "x" -+ << problem_size.n() << "x" << problem_size.k() << "_" -+ << Gemm::ThreadblockShape::kM << "x" << Gemm::ThreadblockShape::kN -+ << "x" << Gemm::ThreadblockShape::kK << "_" << Gemm::WarpShape::kM -+ << "x" << Gemm::WarpShape::kN << "x" << Gemm::WarpShape::kK -+ << ".txt"; -+ -+ std::ofstream file(fname.str()); -+ -+ file << "problem: " << problem_size << ", alpha: " << alpha -+ << ", beta: " << beta << "\n\n"; -+ -+ file << "A =\n" -+ << tensor_A.host_view() << "\nB =\n" -+ << tensor_B.host_view() << "\nC =\n" -+ << tensor_C.host_view() << "\n\nReference =\n" -+ << reference_D.host_view() << "\nComputed =\n" -+ << tensor_D.host_view(); -+ } -+ -+ return passed; -+ } -+ -+ /// Runs a set of problem sizes -+ bool run_all() { -+ bool passed = true; -+ -+ int problem_size_m[] = {16, 528}; -+ -+ int problem_size_n[] = {16, 528}; -+ -+ int problem_size_k[] = {Gemm::InstructionShape::kK, -+ Gemm::ThreadblockShape::kK * Gemm::kStages + -+ Gemm::InstructionShape::kK}; -+ -+ double problem_alpha[] = {1.0}; -+ -+ // TODO Try non zero beta value after multistaged epilogue is implemented -+ double problem_beta[] = {0.0}; -+ -+ for (int m : problem_size_m) { -+ for (int n : problem_size_n) { -+ for (int k : problem_size_k) { -+ for (double alpha : problem_alpha) { -+ for (double beta : problem_beta) { -+ passed = -+ run({m, n, k}, ElementCompute(alpha), ElementCompute(beta)); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ return true; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace test -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/multistage_testbed_interleaved.h b/3rdparty/cutlass/test/unit/gemm/device/multistage_testbed_interleaved.h -new file mode 100644 -index 0000000..5f33206 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/multistage_testbed_interleaved.h -@@ -0,0 +1,349 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/host_reorder.h" -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct MultistageInterleavedTestbed { -+ -+ using ElementA = typename Gemm::ElementA; -+ using ElementB = typename Gemm::ElementB; -+ using ElementC = typename Gemm::ElementC; -+ using ElementAccumulator = typename Gemm::ElementAccumulator; -+ using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint64_t seed; -+ -+ // -+ // Methods -+ // -+ -+ MultistageInterleavedTestbed( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, 2, -2, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else { -+ // TODO: Implement the rest -+ EXPECT_TRUE(false) << "Not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Returns true if the CUDA device is sufficient to execute the kernel. -+ bool sufficient() const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename Gemm::GemmKernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerMultiprocessor < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha = ElementCompute(1), -+ ElementCompute beta = ElementCompute(0)) { -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ -+ // -+ // Allocate the GEMM workspace -+ // -+ -+ cutlass::HostTensor< -+ typename Gemm::ElementA, -+ typename Gemm::LayoutA> tensor_A(problem_size.mk()); -+ -+ cutlass::HostTensor< -+ typename Gemm::ElementB, -+ typename Gemm::LayoutB> tensor_B(problem_size.kn()); -+ -+ cutlass::HostTensor< -+ typename Gemm::ElementB, -+ typename Gemm::LayoutB> tensor_B_reordered(problem_size.kn()); -+ -+ cutlass::HostTensor< -+ typename Gemm::ElementC, -+ typename Gemm::LayoutC> tensor_C(problem_size.mn()); -+ -+ cutlass::HostTensor< -+ typename Gemm::ElementC, -+ typename Gemm::LayoutC> tensor_D(problem_size.mn()); -+ -+ cutlass::HostTensor< -+ typename Gemm::ElementC, -+ typename Gemm::LayoutC> reference_D(problem_size.mn(), false); -+ -+ EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); -+ EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); -+ EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); -+ -+ cutlass::reorder_column( -+ tensor_B_reordered.host_ref(), tensor_B.host_ref(), problem_size); -+ -+ cutlass::reference::host::TensorCopy( -+ reference_D.host_view(), -+ tensor_C.host_view()); -+ -+ tensor_A.sync_device(); -+ tensor_B_reordered.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D.sync_device(); -+ -+ // -+ // Initialize the GEMM operator -+ // -+ -+ typename Gemm::Arguments arguments{ -+ problem_size, -+ tensor_A.device_ref(), -+ tensor_B_reordered.device_ref(), -+ tensor_C.device_ref(), -+ tensor_D.device_ref(), -+ {alpha, beta} -+ }; -+ -+ Gemm gemm_op; -+ -+ cutlass::Status status = gemm_op.initialize(arguments); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ -+ // -+ // Run the GEMM -+ // -+ -+ status = gemm_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ -+ // -+ // Verify -+ // -+ -+ cutlass::reference::host::Gemm< -+ typename Gemm::ElementA, typename Gemm::LayoutA, -+ typename Gemm::ElementB, typename Gemm::LayoutB, -+ typename Gemm::ElementC, typename Gemm::LayoutC, ElementCompute, -+ ElementAccumulator, typename Gemm::Operator> -+ reference_gemm; -+ -+ reference_gemm( -+ problem_size, -+ alpha, -+ tensor_A.host_ref(), -+ tensor_B.host_ref(), -+ beta, -+ reference_D.host_ref(), -+ ElementAccumulator(0) -+ ); -+ -+ tensor_D.sync_host(); -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ reference_D.host_view(), -+ tensor_D.host_view()); -+ -+ EXPECT_TRUE(passed); -+ if (!passed) { -+ -+ std::stringstream fname; -+ -+ fname << "error_Gemm_device_" -+ << problem_size.m() << "x" -+ << problem_size.n() << "x" -+ << problem_size.k() << "_" -+ << Gemm::ThreadblockShape::kM << "x" -+ << Gemm::ThreadblockShape::kN << "x" -+ << Gemm::ThreadblockShape::kK << "_" -+ << Gemm::WarpShape::kM << "x" -+ << Gemm::WarpShape::kN << "x" -+ << Gemm::WarpShape::kK << ".txt"; -+ -+ std::ofstream file(fname.str()); -+ -+ file -+ << "problem: " << problem_size -+ << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; -+ -+ file -+ << "A =\n" << tensor_A.host_view() -+ << "\nB =\n" << tensor_B.host_view() -+ << "\nB_reordered =\n" << tensor_B_reordered.host_view() -+ << "\nC =\n" << tensor_C.host_view() -+ << "\n\nReference =\n" << reference_D.host_view() -+ << "\nComputed =\n" << tensor_D.host_view(); -+ } -+ -+ return passed; -+ } -+ -+ /// Runs a set of problem sizes -+ bool run_all() { -+ bool passed = true; -+ -+ int problem_size_m[] = { -+ InterleavedK, 512 + InterleavedK -+ }; -+ -+ int problem_size_n[] = { -+ InterleavedK, 512 + InterleavedK -+ }; -+ -+ int problem_size_k[] = { -+ InterleavedK, Gemm::ThreadblockShape::kK * Gemm::kStages + InterleavedK -+ }; -+ -+ double problem_alpha[] = { -+ 1.0 -+ }; -+ -+ double problem_beta[] = { -+ 0.0 -+ }; -+ -+ for (int m : problem_size_m) { -+ for (int n : problem_size_n) { -+ for (int k : problem_size_k) { -+ for (double alpha : problem_alpha) { -+ for (double beta : problem_beta) { -+ -+ passed = run( -+ {m, n, k}, -+ ElementCompute(alpha), -+ ElementCompute(beta) -+ ); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ return true; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace test -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/rank_2k_grouped_scheduler_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/rank_2k_grouped_scheduler_sm80.cu -new file mode 100644 -index 0000000..021182e ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/rank_2k_grouped_scheduler_sm80.cu -@@ -0,0 +1,234 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for grouped Rank2K problem visitors -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/gemm_grouped.h" -+#include "cutlass/gemm/kernel/default_gemm_grouped.h" -+#include "cutlass/gemm/device/gemm_grouped.h" -+ -+#include "testbed_grouped_rank_2k_scheduler.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Run a series of tests on the testbed -+template -+void run_tests(bool skip_tile_check=false) { -+ for (int scale_factor : {8, 16, 32, 64}) { -+ for (int threadblock_count : {54, 108, 216, 324, 432}) { -+ for (int problems : {1, 27, 180, 300}) { -+ Testbed testbed(skip_tile_check); -+ testbed.run(problems, threadblock_count, scale_factor); -+ } -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Rank2KGroupedScheduler_p128_t128_l, 64x64x32) { -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ static int const kNumPrefetch = 128; -+ static int const kThreadCount = 128; -+ static cutlass::FillMode const kFillModeC = cutlass::FillMode::kLower; -+ -+ using Testbed = test::gemm::device::TestbedGroupedRank2KScheduler< -+ ThreadblockShape, -+ kNumPrefetch, -+ kThreadCount, -+ kFillModeC, -+ // List of GroupScheduleModes to compare. List must contain at least two. -+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; -+ run_tests(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Rank2KGroupedScheduler_p128_t128_u, 64x64x32) { -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ static int const kNumPrefetch = 128; -+ static int const kThreadCount = 128; -+ static cutlass::FillMode const kFillModeC = cutlass::FillMode::kUpper; -+ -+ using Testbed = test::gemm::device::TestbedGroupedRank2KScheduler< -+ ThreadblockShape, -+ kNumPrefetch, -+ kThreadCount, -+ kFillModeC, -+ // List of GroupScheduleModes to compare. List must contain at least two. -+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; -+ run_tests(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Rank2KGroupedScheduler_p256_t256_l, 64x64x32) { -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ static int const kNumPrefetch = 256; -+ static int const kThreadCount = 256; -+ static cutlass::FillMode const kFillModeC = cutlass::FillMode::kLower; -+ -+ using Testbed = test::gemm::device::TestbedGroupedRank2KScheduler< -+ ThreadblockShape, -+ kNumPrefetch, -+ kThreadCount, -+ kFillModeC, -+ // List of GroupScheduleModes to compare. List must contain at least two. -+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; -+ run_tests(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Rank2KGroupedScheduler_p256_t128_l, 64x64x32) { -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ static int const kNumPrefetch = 256; -+ static int const kThreadCount = 128; -+ static cutlass::FillMode const kFillModeC = cutlass::FillMode::kLower; -+ -+ using Testbed = test::gemm::device::TestbedGroupedRank2KScheduler< -+ ThreadblockShape, -+ kNumPrefetch, -+ kThreadCount, -+ kFillModeC, -+ // List of GroupScheduleModes to compare. List must contain at least two. -+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; -+ run_tests(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Rank2KGroupedScheduler_p256_t256_l, 64x32x32) { -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ static int const kNumPrefetch = 256; -+ static int const kThreadCount = 256; -+ static cutlass::FillMode const kFillModeC = cutlass::FillMode::kLower; -+ -+ using Testbed = test::gemm::device::TestbedGroupedRank2KScheduler< -+ ThreadblockShape, -+ kNumPrefetch, -+ kThreadCount, -+ kFillModeC, -+ // List of GroupScheduleModes to compare. List must contain at least two. -+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; -+ -+ // Skip individual tile check for the non-square SYR2K versions. We still -+ // compare the problem visitors with one another -+ run_tests(/*skip_tile_check=*/true); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Rank2KGroupedScheduler_p256_t256_u, 64x32x32) { -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ static int const kNumPrefetch = 256; -+ static int const kThreadCount = 256; -+ static cutlass::FillMode const kFillModeC = cutlass::FillMode::kUpper; -+ -+ using Testbed = test::gemm::device::TestbedGroupedRank2KScheduler< -+ ThreadblockShape, -+ kNumPrefetch, -+ kThreadCount, -+ kFillModeC, -+ // List of GroupScheduleModes to compare. List must contain at least two. -+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; -+ -+ // Skip individual tile check for the non-square SYR2K versions. We still -+ // compare the problem visitors with one another -+ run_tests(/*skip_tile_check=*/true); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Rank2KGroupedScheduler_p256_t256_l, 32x64x32) { -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ static int const kNumPrefetch = 256; -+ static int const kThreadCount = 256; -+ static cutlass::FillMode const kFillModeC = cutlass::FillMode::kLower; -+ -+ using Testbed = test::gemm::device::TestbedGroupedRank2KScheduler< -+ ThreadblockShape, -+ kNumPrefetch, -+ kThreadCount, -+ kFillModeC, -+ // List of GroupScheduleModes to compare. List must contain at least two. -+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; -+ -+ // Skip individual tile check for the non-square SYR2K versions. We still -+ // compare the problem visitors with one another -+ run_tests(/*skip_tile_check=*/true); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Rank2KGroupedScheduler_p256_t256_u, 32x64x32) { -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ static int const kNumPrefetch = 256; -+ static int const kThreadCount = 256; -+ static cutlass::FillMode const kFillModeC = cutlass::FillMode::kUpper; -+ -+ using Testbed = test::gemm::device::TestbedGroupedRank2KScheduler< -+ ThreadblockShape, -+ kNumPrefetch, -+ kThreadCount, -+ kFillModeC, -+ // List of GroupScheduleModes to compare. List must contain at least two. -+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, -+ cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; -+ -+ // Skip individual tile check for the non-square SYR2K versions. We still -+ // compare the problem visitors with one another -+ run_tests(/*skip_tile_check=*/true); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_cgemm_nn_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_cgemm_nn_sm50.cu -new file mode 100644 -index 0000000..51632bb ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_cgemm_nn_sm50.cu -@@ -0,0 +1,1131 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nn, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nn, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nn, 16x64x8_16x64x1_4x8_4x8_1x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nn, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nn, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nn, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nn, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nn, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nn, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nn, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nn, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nn, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_cgemm_nn, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nn, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nn, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nn, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nn, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nn, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nn, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nn, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nn, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nn, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nn, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nn, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nn, 32x256x8_16x64x1_4x8_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nn, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_cgemm_nn, 64x128x8_32x32x1_8x4_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nn, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nn, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nn, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nn, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nn, 128x64x8_32x32x1_8x4_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nn, 256x32x8_64x16x1_8x4_8x4_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nn, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nn, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nn, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_cgemm_nt_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_cgemm_nt_sm50.cu -new file mode 100644 -index 0000000..512fcbc ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_cgemm_nt_sm50.cu -@@ -0,0 +1,1311 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nt, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nt, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nt, 16x64x8_16x64x1_4x8_4x8_1x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nt, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nt, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nt, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nt, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_cgemm_nt, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nt, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nt, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nt, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nt, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 128 x 16 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 16x128x16_8x32x1_2x4_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nt, 32x256x8_16x64x1_4x8_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_cgemm_nt, 64x128x8_32x32x1_8x4_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nt, 128x64x8_32x32x1_8x4_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nt, 256x32x8_64x16x1_8x4_8x4_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 128 x 16 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 32x128x16_8x32x1_2x4_4x8_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_nt, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 128x32x16_32x8x1_4x2_8x4_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_nt, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_cgemm_nt_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_cgemm_nt_sm80.cu -new file mode 100644 -index 0000000..805937a ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_cgemm_nt_sm80.cu -@@ -0,0 +1,265 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_complex.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_complex.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_simt_cf32, 32x64x8_32x64x1) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_simt_cf32, 64x64x8_32x64x1) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_simt_cf32, 128x128x8_32x64x1) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kConjugate -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_simt_cf32, 64x128x8_32x64x1) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kConjugate -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_simt_cf32, 128x64x8_32x64x1) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 8>, -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+ -+TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_simt_cf32, 128x128x8_64x64x1) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_simt_cf32, 128x256x8_64x64x1) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 8>, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kConjugate -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_cgemm_tn_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_cgemm_tn_sm50.cu -new file mode 100644 -index 0000000..7405802 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_cgemm_tn_sm50.cu -@@ -0,0 +1,1131 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tn, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tn, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tn, 16x64x8_16x64x1_4x8_4x8_1x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tn, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tn, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tn, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tn, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tn, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tn, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tn, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tn, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tn, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_cgemm_tn, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tn, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tn, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tn, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tn, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tn, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tn, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tn, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tn, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tn, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tn, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tn, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tn, 32x256x8_16x64x1_4x8_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tn, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_cgemm_tn, 64x128x8_32x32x1_8x4_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tn, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tn, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tn, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tn, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tn, 128x64x8_32x32x1_8x4_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tn, 256x32x8_64x16x1_8x4_8x4_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tn, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tn, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tn, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_cgemm_tn_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_cgemm_tn_sm80.cu -new file mode 100644 -index 0000000..cfb3764 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_cgemm_tn_sm80.cu -@@ -0,0 +1,269 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm_complex.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_complex.h" -+ -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_simt_cf32, 32x64x8_32x64x1) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<32, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_simt_cf32, 64x64x8_32x64x1) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kConjugate -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_simt_cf32, 128x128x8_32x64x1) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_simt_cf32, 64x128x8_32x64x1) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kConjugate -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_simt_cf32, 128x64x8_64x32x1) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 8>, -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kConjugate -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_simt_cf32, 128x128x8_64x64x1) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_simt_cf32, 128x256x8_64x64x1) { -+ -+ using Element = cutlass::complex; -+ -+ using Gemm = cutlass::gemm::device::GemmComplex< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 8>, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kConjugate -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_cgemm_tt_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_cgemm_tt_sm50.cu -new file mode 100644 -index 0000000..3c232f1 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_cgemm_tt_sm50.cu -@@ -0,0 +1,1130 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tt, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tt, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tt, 16x64x8_16x64x1_4x8_4x8_1x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tt, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tt, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tt, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tt, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tt, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tt, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tt, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tt, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tt, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_cgemm_tt, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tt, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tt, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tt, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tt, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tt, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tt, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tt, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tt, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tt, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tt, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tt, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tt, 32x256x8_16x64x1_4x8_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tt, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_cgemm_tt, 64x128x8_32x32x1_8x4_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tt, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tt, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tt, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tt, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tt, 128x64x8_32x32x1_8x4_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tt, 256x32x8_64x16x1_8x4_8x4_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tt, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_cgemm_tt, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_cgemm_tt, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_dgemm_nn_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_dgemm_nn_sm50.cu -new file mode 100644 -index 0000000..f65fd01 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_dgemm_nn_sm50.cu -@@ -0,0 +1,991 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nn, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nn, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nn, 16x64x8_16x64x1_4x8_4x8_1x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_dgemm_nn, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nn, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nn, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nn, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nn, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nn, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nn, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nn, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nn, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nn, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nn, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nn, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nn, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nn, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nn, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nn, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_dgemm_nn, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_dgemm_affin2_nn, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using LayoutA = cutlass::layout::AffineRank2ColumnMajor; -+ using LayoutB = cutlass::layout::AffineRank2ColumnMajor; -+ using LayoutC = cutlass::layout::AffineRankN<2>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, LayoutA, -+ precision, LayoutB, -+ precision, LayoutC, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ -+ typename LayoutA::Stride::Index stride_factor_A[] = {3, 4}; -+ typename LayoutB::Stride::Index stride_factor_B[] = {5, 6}; -+ typename LayoutC::Stride::Index stride_factor_C[] = {7, 8}; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm(stride_factor_A, stride_factor_B, stride_factor_C)); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nn, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nn, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nn, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nn, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nn, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nn, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nn, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nn, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nn, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nn, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_dgemm_nt_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_dgemm_nt_sm50.cu -new file mode 100644 -index 0000000..5ffbbfa ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_dgemm_nt_sm50.cu -@@ -0,0 +1,1170 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nt, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nt, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nt, 16x64x8_16x64x1_4x8_4x8_1x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_dgemm_nt, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nt, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nt, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nt, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nt, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nt, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nt, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_dgemm_nt, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_dgemm_affine2_nt, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using LayoutA = cutlass::layout::AffineRank2ColumnMajor; -+ using LayoutB = cutlass::layout::AffineRank2RowMajor; -+ using LayoutC = cutlass::layout::AffineRankN<2>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, LayoutA, -+ precision, LayoutB, -+ precision, LayoutC, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ -+ typename LayoutA::Stride::Index stride_factor_A[] = {3, 4}; -+ typename LayoutB::Stride::Index stride_factor_B[] = {5, 6}; -+ typename LayoutC::Stride::Index stride_factor_C[] = {7, 8}; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm(stride_factor_A, stride_factor_B, stride_factor_C)); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nt, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 128 x 16 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 16x128x16_8x32x1_2x4_4x8_2x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_nt, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 128 x 16 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 32x128x16_8x32x1_2x4_4x8_4x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_dgemm_nt, 128x32x16_32x8x1_4x2_8x4_4x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_dgemm_tn_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_dgemm_tn_sm50.cu -new file mode 100644 -index 0000000..9205761 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_dgemm_tn_sm50.cu -@@ -0,0 +1,991 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tn, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tn, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tn, 16x64x8_16x64x1_4x8_4x8_1x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_dgemm_tn, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tn, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tn, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tn, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tn, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tn, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tn, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tn, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tn, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tn, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tn, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tn, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tn, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tn, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tn, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tn, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_dgemm_tn, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_dgemm_affine2_tn, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using LayoutA = cutlass::layout::AffineRank2RowMajor; -+ using LayoutB = cutlass::layout::AffineRank2ColumnMajor; -+ using LayoutC = cutlass::layout::AffineRankN<2>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, LayoutA, -+ precision, LayoutB, -+ precision, LayoutC, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ -+ typename LayoutA::Stride::Index stride_factor_A[] = {3, 4}; -+ typename LayoutB::Stride::Index stride_factor_B[] = {5, 6}; -+ typename LayoutC::Stride::Index stride_factor_C[] = {7, 8}; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm(stride_factor_A, stride_factor_B, stride_factor_C)); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tn, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tn, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tn, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tn, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tn, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tn, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tn, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tn, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tn, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tn, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_dgemm_tt_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_dgemm_tt_sm50.cu -new file mode 100644 -index 0000000..b635978 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_dgemm_tt_sm50.cu -@@ -0,0 +1,991 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tt, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tt, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tt, 16x64x8_16x64x1_4x8_4x8_1x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_dgemm_tt, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_dgemm_affine2_tt, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using LayoutA = cutlass::layout::AffineRank2ColumnMajor; -+ using LayoutB = cutlass::layout::AffineRank2ColumnMajor; -+ using LayoutC = cutlass::layout::AffineRankN<2>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, LayoutA, -+ precision, LayoutB, -+ precision, LayoutC, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ -+ typename LayoutA::Stride::Index stride_factor_A[] = {3, 4}; -+ typename LayoutB::Stride::Index stride_factor_B[] = {5, 6}; -+ typename LayoutC::Stride::Index stride_factor_C[] = {7, 8}; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm(stride_factor_A, stride_factor_B, stride_factor_C)); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tt, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tt, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tt, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tt, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tt, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tt, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tt, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tt, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tt, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tt, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tt, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tt, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tt, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tt, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tt, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_dgemm_tt, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tt, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tt, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tt, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tt, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_dgemm_tt, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tt, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tt, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tt, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tt, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_dgemm_tt, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = double; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_f8gemm_tn_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_f8gemm_tn_sm50.cu -new file mode 100644 -index 0000000..a10a604 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_f8gemm_tn_sm50.cu -@@ -0,0 +1,87 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#if (__CUDACC_VER_MAJOR__ > 11) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) -+ -+TEST(SM50_Device_Gemm_fe4m3t_fe4m3n_fe4m3t_simt_f32, 32x64x8_32x64x1) { -+ -+ using ElementA = cutlass::float_e4m3_t; -+ using ElementB = cutlass::float_e4m3_t; -+ using ElementC = cutlass::float_e4m3_t; -+ using ElementAccumulator = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ ElementA, -+ cutlass::layout::RowMajor, -+ ElementB, -+ cutlass::layout::ColumnMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementC>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<> -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} -+ -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_hgemm_nn_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_hgemm_nn_sm50.cu -new file mode 100644 -index 0000000..b399303 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_hgemm_nn_sm50.cu -@@ -0,0 +1,2181 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 16x64x8_16x64x1_4x8_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 32x64x8_32x64x1_8x8_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 16 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 32x128x8_32x128x1_8x16_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_hgemm_nn, 64x32x8_64x32x1_8x8_8x4_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 64x64x8_64x64x1_16x8_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 1 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 128x32x8_128x32x1_16x8_8x4_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x128x8_32x64x1_8x8_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 16 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 32x256x8_32x128x1_8x16_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x64x8_64x32x1_8x8_8x4_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_hgemm_nn, 64x128x8_64x64x1_16x8_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x64x8_32x64x1_8x8_4x8_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 1 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 128x32x8_64x32x1_8x8_8x4_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 128x64x8_64x64x1_16x8_4x8_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 1 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 256x32x8_128x32x1_16x8_8x4_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 64x128x8_32x64x1_8x8_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 16 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 64x256x8_32x128x1_8x16_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 128x64x8_64x32x1_8x8_8x4_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 128x128x8_64x64x1_16x8_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 256x64x8_128x32x1_16x8_8x4_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 128 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 16x128x16_8x32x1_2x4_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x256x8_16x64x1_4x8_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x128x8_32x32x1_8x4_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x256x8_32x64x1_8x8_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 128x128x8_64x32x1_8x8_8x4_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 128x256x8_64x64x1_16x8_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 128x64x8_32x32x1_8x4_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 128x128x8_32x64x1_8x8_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 256x32x8_64x16x1_8x4_8x4_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 256x64x8_64x32x1_8x8_8x4_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_hgemm_nn, 256x128x8_64x64x1_16x8_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 128 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 32x128x16_8x32x1_2x4_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 64x256x8_16x64x1_4x8_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 128x32x16_32x8x1_4x2_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 128x128x8_32x32x1_8x4_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 128x256x8_32x64x1_8x8_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nn, 256x64x8_64x16x1_8x4_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 256 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nn, 256x128x8_64x32x1_8x8_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_hgemm_nt_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_hgemm_nt_sm50.cu -new file mode 100644 -index 0000000..d414a7b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_hgemm_nt_sm50.cu -@@ -0,0 +1,2181 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 16x64x8_16x64x1_4x8_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 32x64x8_32x64x1_8x8_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 16 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 32x128x8_32x128x1_8x16_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_hgemm_nt, 64x32x8_64x32x1_8x8_8x4_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 64x64x8_64x64x1_16x8_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 1 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 128x32x8_128x32x1_16x8_8x4_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x128x8_32x64x1_8x8_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 16 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 32x256x8_32x128x1_8x16_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x64x8_64x32x1_8x8_8x4_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_hgemm_nt, 64x128x8_64x64x1_16x8_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x64x8_32x64x1_8x8_4x8_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 1 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 128x32x8_64x32x1_8x8_8x4_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 128x64x8_64x64x1_16x8_4x8_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 1 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 256x32x8_128x32x1_16x8_8x4_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 64x128x8_32x64x1_8x8_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 16 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 64x256x8_32x128x1_8x16_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 128x64x8_64x32x1_8x8_8x4_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 128x128x8_64x64x1_16x8_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 256x64x8_128x32x1_16x8_8x4_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 128 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 16x128x16_8x32x1_2x4_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x256x8_16x64x1_4x8_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x128x8_32x32x1_8x4_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x256x8_32x64x1_8x8_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 128x128x8_64x32x1_8x8_8x4_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 128x256x8_64x64x1_16x8_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 128x64x8_32x32x1_8x4_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 128x128x8_32x64x1_8x8_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 256x32x8_64x16x1_8x4_8x4_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 256x64x8_64x32x1_8x8_8x4_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_hgemm_nt, 256x128x8_64x64x1_16x8_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 128 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 32x128x16_8x32x1_2x4_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 64x256x8_16x64x1_4x8_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 128x32x16_32x8x1_4x2_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 128x128x8_32x32x1_8x4_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 128x256x8_32x64x1_8x8_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_nt, 256x64x8_64x16x1_8x4_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 256 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_nt, 256x128x8_64x32x1_8x8_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_hgemm_tn_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_hgemm_tn_sm50.cu -new file mode 100644 -index 0000000..2891c97 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_hgemm_tn_sm50.cu -@@ -0,0 +1,2181 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 16x64x8_16x64x1_4x8_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 32x64x8_32x64x1_8x8_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 16 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 32x128x8_32x128x1_8x16_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_hgemm_tn, 64x32x8_64x32x1_8x8_8x4_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 64x64x8_64x64x1_16x8_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 1 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 128x32x8_128x32x1_16x8_8x4_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x128x8_32x64x1_8x8_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 16 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 32x256x8_32x128x1_8x16_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x64x8_64x32x1_8x8_8x4_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_hgemm_tn, 64x128x8_64x64x1_16x8_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x64x8_32x64x1_8x8_4x8_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 1 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 128x32x8_64x32x1_8x8_8x4_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 128x64x8_64x64x1_16x8_4x8_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 1 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 256x32x8_128x32x1_16x8_8x4_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 64x128x8_32x64x1_8x8_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 16 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 64x256x8_32x128x1_8x16_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 128x64x8_64x32x1_8x8_8x4_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 128x128x8_64x64x1_16x8_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 256x64x8_128x32x1_16x8_8x4_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 128 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 16x128x16_8x32x1_2x4_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x256x8_16x64x1_4x8_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x128x8_32x32x1_8x4_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x256x8_32x64x1_8x8_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 128x128x8_64x32x1_8x8_8x4_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 128x256x8_64x64x1_16x8_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 128x64x8_32x32x1_8x4_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 128x128x8_32x64x1_8x8_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 256x32x8_64x16x1_8x4_8x4_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 256x64x8_64x32x1_8x8_8x4_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_hgemm_tn, 256x128x8_64x64x1_16x8_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 128 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 32x128x16_8x32x1_2x4_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 64x256x8_16x64x1_4x8_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 128x32x16_32x8x1_4x2_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 128x128x8_32x32x1_8x4_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 128x256x8_32x64x1_8x8_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tn, 256x64x8_64x16x1_8x4_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 256 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tn, 256x128x8_64x32x1_8x8_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_hgemm_tt_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_hgemm_tt_sm50.cu -new file mode 100644 -index 0000000..c9eb576 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_hgemm_tt_sm50.cu -@@ -0,0 +1,2181 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 16x64x8_16x64x1_4x8_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 32x64x8_32x64x1_8x8_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 16 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 32x128x8_32x128x1_8x16_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_hgemm_tt, 64x32x8_64x32x1_8x8_8x4_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 64x64x8_64x64x1_16x8_4x8_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 1 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 128x32x8_128x32x1_16x8_8x4_1x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x128x8_32x64x1_8x8_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 16 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 32x256x8_32x128x1_8x16_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x64x8_64x32x1_8x8_8x4_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_hgemm_tt, 64x128x8_64x64x1_16x8_4x8_1x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x64x8_32x64x1_8x8_4x8_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 1 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 128x32x8_64x32x1_8x8_8x4_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 128x64x8_64x64x1_16x8_4x8_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 1 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 256x32x8_128x32x1_16x8_8x4_2x1, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 64x128x8_32x64x1_8x8_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 16 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 64x256x8_32x128x1_8x16_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 128x64x8_64x32x1_8x8_8x4_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 128x128x8_64x64x1_16x8_4x8_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 256x64x8_128x32x1_16x8_8x4_2x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 128 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 16x128x16_8x32x1_2x4_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x256x8_16x64x1_4x8_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x128x8_32x32x1_8x4_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x256x8_32x64x1_8x8_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 128x128x8_64x32x1_8x8_8x4_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 128x256x8_64x64x1_16x8_4x8_2x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 128x64x8_32x32x1_8x4_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 128x128x8_32x64x1_8x8_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 256x32x8_64x16x1_8x4_8x4_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 256x64x8_64x32x1_8x8_8x4_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 16 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_hgemm_tt, 256x128x8_64x64x1_16x8_4x8_4x2, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 128 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 32x128x16_8x32x1_2x4_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 64x256x8_16x64x1_4x8_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 128x32x16_32x8x1_4x2_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 128x128x8_32x32x1_8x4_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 128x256x8_32x64x1_8x8_4x8_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_hgemm_tt, 256x64x8_64x16x1_8x4_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 256 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_hgemm_tt, 256x128x8_64x32x1_8x8_8x4_4x4, { -+ using precision = cutlass::half_t; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_igemm_nn_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_igemm_nn_sm50.cu -new file mode 100644 -index 0000000..5292e59 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_igemm_nn_sm50.cu -@@ -0,0 +1,1701 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nn, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_igemm_nn, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nn, 16x64x8_16x64x1_4x8_4x8_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nn, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nn, 32x64x8_32x64x1_8x8_4x8_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nn, 64x32x8_64x32x1_8x8_8x4_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nn, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nn, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nn, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nn, 32x128x8_32x64x1_8x8_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nn, 64x64x8_64x32x1_8x8_8x4_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 64x64x8_32x64x1_8x8_4x8_2x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 1 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nn, 128x32x8_64x32x1_8x8_8x4_2x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nn, 64x128x8_32x64x1_8x8_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_igemm_nn, 128x64x8_64x32x1_8x8_8x4_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nn, 32x256x8_16x64x1_4x8_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 64x128x8_32x32x1_8x4_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nn, 64x256x8_32x64x1_8x8_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nn, 128x128x8_64x32x1_8x8_8x4_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nn, 128x64x8_32x32x1_8x4_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 128x128x8_32x64x1_8x8_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nn, 256x32x8_64x16x1_8x4_8x4_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nn, 256x64x8_64x32x1_8x8_8x4_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 64x256x8_16x64x1_4x8_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 128x32x16_32x8x1_4x2_8x4_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 128x128x8_32x32x1_8x4_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nn, 256x64x8_64x16x1_8x4_8x4_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_igemm_nt_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_igemm_nt_sm50.cu -new file mode 100644 -index 0000000..64391a4 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_igemm_nt_sm50.cu -@@ -0,0 +1,1761 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nt, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_igemm_nt, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nt, 16x64x8_16x64x1_4x8_4x8_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nt, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nt, 32x64x8_32x64x1_8x8_4x8_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nt, 64x32x8_64x32x1_8x8_8x4_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nt, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nt, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nt, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nt, 32x128x8_32x64x1_8x8_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nt, 64x64x8_64x32x1_8x8_8x4_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 64x64x8_32x64x1_8x8_4x8_2x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 1 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nt, 128x32x8_64x32x1_8x8_8x4_2x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nt, 64x128x8_32x64x1_8x8_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_igemm_nt, 128x64x8_64x32x1_8x8_8x4_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 128 x 16 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 16x128x16_8x32x1_2x4_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nt, 32x256x8_16x64x1_4x8_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 64x128x8_32x32x1_8x4_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nt, 64x256x8_32x64x1_8x8_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nt, 128x128x8_64x32x1_8x8_8x4_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nt, 128x64x8_32x32x1_8x4_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 128x128x8_32x64x1_8x8_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nt, 256x32x8_64x16x1_8x4_8x4_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_nt, 256x64x8_64x32x1_8x8_8x4_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 128 x 16 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 32x128x16_8x32x1_2x4_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 64x256x8_16x64x1_4x8_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 128x32x16_32x8x1_4x2_8x4_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 128x128x8_32x32x1_8x4_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_nt, 256x64x8_64x16x1_8x4_8x4_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_igemm_tn_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_igemm_tn_sm50.cu -new file mode 100644 -index 0000000..9e6c841 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_igemm_tn_sm50.cu -@@ -0,0 +1,1671 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tn, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_igemm_tn, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tn, 16x64x8_16x64x1_4x8_4x8_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tn, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tn, 32x64x8_32x64x1_8x8_4x8_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tn, 64x32x8_64x32x1_8x8_8x4_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tn, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tn, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tn, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tn, 32x128x8_32x64x1_8x8_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tn, 64x64x8_64x32x1_8x8_8x4_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 64x64x8_32x64x1_8x8_4x8_2x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 1 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tn, 128x32x8_64x32x1_8x8_8x4_2x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tn, 64x128x8_32x64x1_8x8_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_igemm_tn, 128x64x8_64x32x1_8x8_8x4_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tn, 32x256x8_16x64x1_4x8_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 64x128x8_32x32x1_8x4_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tn, 64x256x8_32x64x1_8x8_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tn, 128x128x8_64x32x1_8x8_8x4_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tn, 128x64x8_32x32x1_8x4_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 128x128x8_32x64x1_8x8_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tn, 256x32x8_64x16x1_8x4_8x4_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tn, 256x64x8_64x32x1_8x8_8x4_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 64x256x8_16x64x1_4x8_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 128x128x8_32x32x1_8x4_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tn, 256x64x8_64x16x1_8x4_8x4_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_igemm_tt_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_igemm_tt_sm50.cu -new file mode 100644 -index 0000000..87c7976 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_igemm_tt_sm50.cu -@@ -0,0 +1,1731 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tt, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_igemm_tt, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tt, 16x64x8_16x64x1_4x8_4x8_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tt, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tt, 32x64x8_32x64x1_8x8_4x8_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tt, 64x32x8_64x32x1_8x8_8x4_1x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tt, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tt, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tt, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tt, 32x128x8_32x64x1_8x8_4x8_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tt, 64x64x8_64x32x1_8x8_8x4_1x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 64x64x8_32x64x1_8x8_4x8_2x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 1 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tt, 128x32x8_64x32x1_8x8_8x4_2x1, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tt, 64x128x8_32x64x1_8x8_4x8_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_igemm_tt, 128x64x8_64x32x1_8x8_8x4_2x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 128 x 16 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 16x128x16_8x32x1_2x4_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tt, 32x256x8_16x64x1_4x8_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 64x128x8_32x32x1_8x4_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tt, 64x256x8_32x64x1_8x8_4x8_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tt, 128x128x8_64x32x1_8x8_8x4_2x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tt, 128x64x8_32x32x1_8x4_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 128x128x8_32x64x1_8x8_4x8_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tt, 256x32x8_64x16x1_8x4_8x4_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_igemm_tt, 256x64x8_64x32x1_8x8_8x4_4x2, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 128 x 16 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 32x128x16_8x32x1_2x4_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 64x256x8_16x64x1_4x8_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 128x128x8_32x32x1_8x4_4x8_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_igemm_tt, 256x64x8_64x16x1_8x4_8x4_4x4, { -+ using precision = int; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_int8_igemm_sm61.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_int8_igemm_sm61.cu -new file mode 100644 -index 0000000..22729f4 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_int8_igemm_sm61.cu -@@ -0,0 +1,161 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#define N cutlass::layout::ColumnMajor -+#define T cutlass::layout::RowMajor -+ -+#define RUN_GEMM(X, Y) \ -+ using ElementOutput = int8_t; \ -+ using ElementAccumulator = int32_t; \ -+ using ElementCompute = float; \ -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \ -+ using Gemm = cutlass::gemm::device::Gemm< \ -+ int8_t, \ -+ X, \ -+ int8_t, \ -+ Y, \ -+ ElementOutput, \ -+ cutlass::layout::RowMajor, \ -+ int32_t, \ -+ cutlass::arch::OpClassSimt, \ -+ cutlass::arch::Sm61, \ -+ ThreadBlockShape, \ -+ WarpShape, \ -+ InstructionShape, \ -+ cutlass::epilogue::thread::LinearCombinationClamp< \ -+ ElementOutput, \ -+ 1, \ -+ ElementAccumulator, \ -+ ElementCompute \ -+ >, \ -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, \ -+ 2 \ -+ >; \ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM61_Device_Gemm_s8n_s8t_simt_op_dp4a, 64x64x16_64x64x4) { -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ RUN_GEMM(N, T) -+} -+ -+TEST(SM61_Device_Gemm_s8n_s8t_simt_op_dp4a, 256x128x64_64x64x4) { -+ using ThreadBlockShape = cutlass::gemm::GemmShape<256, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ RUN_GEMM(N, T) -+} -+ -+TEST(SM61_Device_Gemm_s8n_s8t_simt_op_dp4a, 256x256x16_128x64x4) { -+ using ThreadBlockShape = cutlass::gemm::GemmShape<256, 256, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<128, 64, 16>; -+ RUN_GEMM(N, T) -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM61_Device_Gemm_s8t_s8n_simt_op_dp4a, 64x64x16_64x64x4) { -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ RUN_GEMM(T, N) -+} -+ -+TEST(SM61_Device_Gemm_s8t_s8n_simt_op_dp4a, 256x128x64_64x64x4) { -+ using ThreadBlockShape = cutlass::gemm::GemmShape<256, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ RUN_GEMM(T, N) -+} -+ -+TEST(SM61_Device_Gemm_s8t_s8n_simt_op_dp4a, 256x256x16_128x64x4) { -+ using ThreadBlockShape = cutlass::gemm::GemmShape<256, 256, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<128, 64, 16>; -+ RUN_GEMM(T, N) -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM61_Device_Gemm_s8n_s8n_simt_op_dp4a, 64x64x16_64x64x4) { -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ RUN_GEMM(N, N) -+} -+ -+TEST(SM61_Device_Gemm_s8n_s8n_simt_op_dp4a, 256x128x64_64x64x4) { -+ using ThreadBlockShape = cutlass::gemm::GemmShape<256, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ RUN_GEMM(N, N) -+} -+ -+TEST(SM61_Device_Gemm_s8n_s8n_simt_op_dp4a, 256x256x16_128x64x4) { -+ using ThreadBlockShape = cutlass::gemm::GemmShape<256, 256, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<128, 64, 16>; -+ RUN_GEMM(N, N) -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM61_Device_Gemm_s8t_s8t_simt_op_dp4a, 64x64x16_64x64x4) { -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ RUN_GEMM(T, T) -+} -+ -+TEST(SM61_Device_Gemm_s8t_s8t_simt_op_dp4a, 256x128x64_64x64x4) { -+ using ThreadBlockShape = cutlass::gemm::GemmShape<256, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ RUN_GEMM(T, T) -+} -+ -+TEST(SM61_Device_Gemm_s8t_s8t_simt_op_dp4a, 256x256x16_128x64x4) { -+ using ThreadBlockShape = cutlass::gemm::GemmShape<256, 256, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<128, 64, 16>; -+ RUN_GEMM(T, T) -+} -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_int8_igemm_sm61_perf.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_int8_igemm_sm61_perf.cu -new file mode 100644 -index 0000000..10a61fc ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_int8_igemm_sm61_perf.cu -@@ -0,0 +1,195 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////// -+// NT -+///////////////////////////////////// -+ -+TEST(SM61_Device_Gemm_s8n_s8t_simt_op_dp4a_perf, 128x256x32_64x64x8) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ int8_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ int32_t, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm61, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<1, 1, 4>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestGemmPerf()); -+} -+ -+///////////////////////////////////// -+// TT -+///////////////////////////////////// -+ -+TEST(SM61_Device_Gemm_s8t_s8t_simt_op_dp4a_perf, 128x256x32_64x64x8) { -+ -+ using ElementOutput = int32_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ int32_t, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm61, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<1, 1, 4>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestGemmPerf()); -+} -+ -+///////////////////////////////////// -+// NN -+///////////////////////////////////// -+ -+TEST(SM61_Device_Gemm_s8n_s8n_simt_op_dp4a_perf, 128x256x32_64x64x8) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ int32_t, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm61, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<1, 1, 4>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestGemmPerf()); -+} -+ -+///////////////////////////////////// -+// TN -+///////////////////////////////////// -+ -+TEST(SM61_Device_Gemm_s8t_s8n_simt_op_dp4a_perf, 128x256x32_64x64x8) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ int32_t, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm61, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<1, 1, 4>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestGemmPerf()); -+} -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_int8_igemm_sm61_sliced_k.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_int8_igemm_sm61_sliced_k.cu -new file mode 100644 -index 0000000..bd0a1f8 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_int8_igemm_sm61_sliced_k.cu -@@ -0,0 +1,307 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM61_Device_Gemm_s8n_s8t_simt_op_dp4a_sliced_k, 32x32x128_32x32x4) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ int8_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ int32_t, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm61, -+ cutlass::gemm::GemmShape<32, 32, 128>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<1, 1, 4>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM61_Device_Gemm_s8n_s8t_simt_op_dp4a_sliced_k, 32x64x128_32x32x4) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ int8_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ int32_t, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm61, -+ cutlass::gemm::GemmShape<32, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<1, 1, 4>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM61_Device_Gemm_s8t_s8n_simt_op_dp4a_sliced_k, 32x32x128_32x32x4) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ int32_t, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm61, -+ cutlass::gemm::GemmShape<32, 32, 128>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<1, 1, 4>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM61_Device_Gemm_s8t_s8n_simt_op_dp4a_sliced_k, 32x64x128_32x32x4) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ int32_t, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm61, -+ cutlass::gemm::GemmShape<32, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<1, 1, 4>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM61_Device_Gemm_s8t_s8t_simt_op_dp4a_sliced_k, 32x32x128_32x32x4) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ int32_t, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm61, -+ cutlass::gemm::GemmShape<32, 32, 128>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<1, 1, 4>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM61_Device_Gemm_s8t_s8t_simt_op_dp4a_sliced_k, 32x64x128_32x32x4) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ int32_t, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm61, -+ cutlass::gemm::GemmShape<32, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<1, 1, 4>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM61_Device_Gemm_s8n_s8n_simt_op_dp4a_sliced_k, 32x32x128_32x32x4) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ int32_t, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm61, -+ cutlass::gemm::GemmShape<32, 32, 128>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<1, 1, 4>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM61_Device_Gemm_s8n_s8n_simt_op_dp4a_sliced_k, 32x64x128_32x32x4) { -+ -+ using ElementOutput = int8_t; -+ using ElementAccumulator = int32_t; -+ using ElementCompute = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ int32_t, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm61, -+ cutlass::gemm::GemmShape<32, 64, 128>, -+ cutlass::gemm::GemmShape<32, 32, 64>, -+ cutlass::gemm::GemmShape<1, 1, 4>, -+ cutlass::epilogue::thread::LinearCombinationClamp< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementCompute -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_qgemm_nn_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_qgemm_nn_sm50.cu -new file mode 100644 -index 0000000..10889bb ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_qgemm_nn_sm50.cu -@@ -0,0 +1,861 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nn, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_qgemm_nn, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nn, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nn, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nn, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nn, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nn, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nn, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nn, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nn, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nn, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_qgemm_nn, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nn, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_qgemm_nn, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nn, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nn, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nn, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nn, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nn, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nn, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nn, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nn, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_qgemm_nn, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_qgemm_nn, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nn, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nn, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nn, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_qgemm_nt_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_qgemm_nt_sm50.cu -new file mode 100644 -index 0000000..f3d0a78 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_qgemm_nt_sm50.cu -@@ -0,0 +1,861 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nt, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_qgemm_nt, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nt, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nt, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nt, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nt, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nt, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nt, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nt, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nt, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nt, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_qgemm_nt, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nt, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_qgemm_nt, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nt, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nt, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nt, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nt, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nt, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nt, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nt, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nt, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_qgemm_nt, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_qgemm_nt, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_nt, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nt, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_nt, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_qgemm_tn_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_qgemm_tn_sm50.cu -new file mode 100644 -index 0000000..ed0f74d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_qgemm_tn_sm50.cu -@@ -0,0 +1,861 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tn, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_qgemm_tn, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tn, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tn, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tn, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tn, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tn, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tn, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tn, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tn, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tn, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_qgemm_tn, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tn, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_qgemm_tn, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tn, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tn, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tn, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tn, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tn, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tn, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tn, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tn, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_qgemm_tn, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_qgemm_tn, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tn, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tn, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tn, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_qgemm_tt_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_qgemm_tt_sm50.cu -new file mode 100644 -index 0000000..c8127c5 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_qgemm_tt_sm50.cu -@@ -0,0 +1,861 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tt, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_qgemm_tt, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tt, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tt, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tt, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tt, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tt, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tt, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tt, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tt, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tt, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_qgemm_tt, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tt, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_qgemm_tt, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tt, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tt, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tt, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tt, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tt, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tt, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tt, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tt, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_qgemm_tt, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_qgemm_tt, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_qgemm_tt, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tt, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_qgemm_tt, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = cutlass::Quaternion; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_sgemm_nn_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_sgemm_nn_sm50.cu -new file mode 100644 -index 0000000..f48e9e1 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_sgemm_nn_sm50.cu -@@ -0,0 +1,1740 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nn, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nn, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nn, 16x64x8_16x64x1_4x8_4x8_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nn, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nn, 32x64x8_32x64x1_8x8_4x8_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nn, 64x32x8_64x32x1_8x8_8x4_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nn, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nn, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nn, 32x128x8_32x64x1_8x8_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_sgemm_nn, 64x64x8_64x32x1_8x8_8x4_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nn, 64x64x8_32x64x1_8x8_4x8_2x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 1 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nn, 128x32x8_64x32x1_8x8_8x4_2x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nn, 64x128x8_32x64x1_8x8_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nn, 128x64x8_64x32x1_8x8_8x4_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nn, 32x256x8_16x64x1_4x8_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 64x128x8_32x32x1_8x4_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nn, 64x256x8_32x64x1_8x8_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_sgemm_nn, 128x128x8_64x32x1_8x8_8x4_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_sgemm_affine2_nn, 128x128x8_64x32x1_8x8_8x4_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using LayoutA = cutlass::layout::AffineRank2ColumnMajor; -+ using LayoutB = cutlass::layout::AffineRank2ColumnMajor; -+ using LayoutC = cutlass::layout::AffineRankN<2>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, LayoutA, -+ precision, LayoutB, -+ precision, LayoutC, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ -+ typename LayoutA::Stride::Index stride_factor_A[] = {3, 4}; -+ typename LayoutB::Stride::Index stride_factor_B[] = {5, 6}; -+ typename LayoutC::Stride::Index stride_factor_C[] = {7, 8}; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm(stride_factor_A, stride_factor_B, stride_factor_C)); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 128x64x8_32x32x1_8x4_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nn, 128x128x8_32x64x1_8x8_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nn, 256x32x8_64x16x1_8x4_8x4_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nn, 256x64x8_64x32x1_8x8_8x4_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 64x256x8_16x64x1_4x8_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 128x32x16_32x8x1_4x2_8x4_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 128x128x8_32x32x1_8x4_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nn, 256x64x8_64x16x1_8x4_8x4_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_sgemm_nt_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_sgemm_nt_sm50.cu -new file mode 100644 -index 0000000..69058bb ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_sgemm_nt_sm50.cu -@@ -0,0 +1,1800 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nt, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nt, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nt, 16x64x8_16x64x1_4x8_4x8_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nt, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nt, 32x64x8_32x64x1_8x8_4x8_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nt, 64x32x8_64x32x1_8x8_8x4_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nt, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nt, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nt, 32x128x8_32x64x1_8x8_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_sgemm_nt, 64x64x8_64x32x1_8x8_8x4_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nt, 64x64x8_32x64x1_8x8_4x8_2x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 1 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nt, 128x32x8_64x32x1_8x8_8x4_2x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nt, 64x128x8_32x64x1_8x8_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nt, 128x64x8_64x32x1_8x8_8x4_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 128 x 16 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 16x128x16_8x32x1_2x4_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nt, 32x256x8_16x64x1_4x8_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 64x128x8_32x32x1_8x4_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nt, 64x256x8_32x64x1_8x8_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_sgemm_nt, 128x128x8_64x32x1_8x8_8x4_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_sgemm_affine2_nt, 128x128x8_64x32x1_8x8_8x4_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using LayoutA = cutlass::layout::AffineRank2ColumnMajor; -+ using LayoutB = cutlass::layout::AffineRank2RowMajor; -+ using LayoutC = cutlass::layout::AffineRankN<2>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, LayoutA, -+ precision, LayoutB, -+ precision, LayoutC, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ -+ typename LayoutA::Stride::Index stride_factor_A[] = {3, 4}; -+ typename LayoutB::Stride::Index stride_factor_B[] = {5, 6}; -+ typename LayoutC::Stride::Index stride_factor_C[] = {7, 8}; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm(stride_factor_A, stride_factor_B, stride_factor_C)); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 128x64x8_32x32x1_8x4_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nt, 128x128x8_32x64x1_8x8_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nt, 256x32x8_64x16x1_8x4_8x4_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_nt, 256x64x8_64x32x1_8x8_8x4_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 128 x 16 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 32x128x16_8x32x1_2x4_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 64x256x8_16x64x1_4x8_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 128x32x16_32x8x1_4x2_8x4_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 128x128x8_32x32x1_8x4_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_nt, 256x64x8_64x16x1_8x4_8x4_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_sgemm_nt_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_sgemm_nt_sm80.cu -new file mode 100644 -index 0000000..fda68e5 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_sgemm_nt_sm80.cu -@@ -0,0 +1,296 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f32n_f32t_f32t_simt_f32, 32x64x8_32x64x1) { -+ -+ using Element = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f32n_f32t_f32t_simt_f32, 64x64x8_32x64x1) { -+ -+ using Element = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f32n_f32t_f32t_simt_f32, 128x128x8_32x64x1) { -+ -+ using Element = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f32an_f32at_f32at_simt_f32, 128x128x8_32x64x1) { -+ -+ using Element = float; -+ using LayoutA = cutlass::layout::AffineRank2ColumnMajor; -+ using LayoutB = cutlass::layout::AffineRank2RowMajor; -+ using LayoutC = cutlass::layout::AffineRankN<2>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ typename LayoutA::Stride::Index stride_factor_A[] = {3, 4}; -+ typename LayoutB::Stride::Index stride_factor_B[] = {5, 6}; -+ typename LayoutC::Stride::Index stride_factor_C[] = {7, 8}; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm(stride_factor_A, stride_factor_B, stride_factor_C )); -+ -+} -+ -+TEST(SM80_Device_Gemm_f32n_f32t_f32t_simt_f32, 64x128x8_32x64x1) { -+ -+ using Element = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f32n_f32t_f32t_simt_f32, 128x64x8_32x64x1) { -+ -+ using Element = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 8>, -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+ -+TEST(SM80_Device_Gemm_f32n_f32t_f32t_simt_f32, 128x128x8_64x64x1) { -+ -+ using Element = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f32n_f32t_f32t_simt_f32, 128x256x8_64x64x1) { -+ -+ using Element = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 8>, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_sgemm_tn_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_sgemm_tn_sm50.cu -new file mode 100644 -index 0000000..b67aa23 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_sgemm_tn_sm50.cu -@@ -0,0 +1,1710 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tn, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tn, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tn, 16x64x8_16x64x1_4x8_4x8_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tn, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tn, 32x64x8_32x64x1_8x8_4x8_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tn, 64x32x8_64x32x1_8x8_8x4_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tn, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tn, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tn, 32x128x8_32x64x1_8x8_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_sgemm_tn, 64x64x8_64x32x1_8x8_8x4_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tn, 64x64x8_32x64x1_8x8_4x8_2x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 1 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tn, 128x32x8_64x32x1_8x8_8x4_2x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tn, 64x128x8_32x64x1_8x8_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tn, 128x64x8_64x32x1_8x8_8x4_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tn, 32x256x8_16x64x1_4x8_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 64x128x8_32x32x1_8x4_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tn, 64x256x8_32x64x1_8x8_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_sgemm_tn, 128x128x8_64x32x1_8x8_8x4_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_sgemm_affine2_tn, 128x128x8_64x32x1_8x8_8x4_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using LayoutA = cutlass::layout::AffineRank2RowMajor; -+ using LayoutB = cutlass::layout::AffineRank2ColumnMajor; -+ using LayoutC = cutlass::layout::AffineRankN<2>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, LayoutA, -+ precision, LayoutB, -+ precision, LayoutC, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ -+ typename LayoutA::Stride::Index stride_factor_A[] = {3, 4}; -+ typename LayoutB::Stride::Index stride_factor_B[] = {5, 6}; -+ typename LayoutC::Stride::Index stride_factor_C[] = {7, 8}; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm(stride_factor_A, stride_factor_B, stride_factor_C)); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 128x64x8_32x32x1_8x4_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tn, 128x128x8_32x64x1_8x8_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tn, 256x32x8_64x16x1_8x4_8x4_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tn, 256x64x8_64x32x1_8x8_8x4_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 64x256x8_16x64x1_4x8_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 128x128x8_32x32x1_8x4_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tn, 256x64x8_64x16x1_8x4_8x4_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_sgemm_tn_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_sgemm_tn_sm80.cu -new file mode 100644 -index 0000000..202c5a1 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_sgemm_tn_sm80.cu -@@ -0,0 +1,296 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed.h" -+ -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+//////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Gemm_f32t_f32n_f32t_simt_f32, 32x64x8_32x64x1) { -+ -+ using Element = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f32t_f32n_f32t_simt_f32, 64x64x8_32x64x1) { -+ -+ using Element = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f32t_f32n_f32t_simt_f32, 128x128x8_32x64x1) { -+ -+ using Element = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f32at_f32an_f32t_simt_f32, 128x128x8_32x64x1) { -+ -+ using Element = float; -+ using LayoutA = cutlass::layout::AffineRank2RowMajor; -+ using LayoutB = cutlass::layout::AffineRank2ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ Element, -+ LayoutC, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ typename LayoutA::Stride::Index stride_factor_A[] = {3, 4}; -+ typename LayoutB::Stride::Index stride_factor_B[] = {5, 6}; -+ typename LayoutC::Stride::Index stride_factor_C[] = {1}; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm( stride_factor_A, stride_factor_B, stride_factor_C )); -+} -+ -+TEST(SM80_Device_Gemm_f32t_f32n_f32t_simt_f32, 64x128x8_32x64x1) { -+ -+ using Element = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 8>, -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f32t_f32n_f32t_simt_f32, 128x64x8_64x32x1) { -+ -+ using Element = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 8>, -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f32t_f32n_f32t_simt_f32, 128x128x8_64x64x1) { -+ -+ using Element = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 8>, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+TEST(SM80_Device_Gemm_f32t_f32n_f32t_simt_f32, 128x256x8_64x64x1) { -+ -+ using Element = float; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::layout::ColumnMajor, -+ Element, -+ cutlass::layout::RowMajor, -+ Element, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 8>, -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ cutlass::gemm::GemmShape<1, 1, 1>, -+ cutlass::epilogue::thread::LinearCombination< -+ Element, -+ 1, -+ Element, -+ Element>, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_sgemm_tt_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_sgemm_tt_sm50.cu -new file mode 100644 -index 0000000..82b0773 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_sgemm_tt_sm50.cu -@@ -0,0 +1,1770 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tt, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tt, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tt, 16x64x8_16x64x1_4x8_4x8_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tt, 32x32x8_32x32x1_8x4_4x8_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tt, 32x64x8_32x64x1_8x8_4x8_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tt, 64x32x8_64x32x1_8x8_8x4_1x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tt, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tt, 16x128x8_16x64x1_4x8_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 32x64x8_32x32x1_8x4_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tt, 32x128x8_32x64x1_8x8_4x8_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_sgemm_tt, 64x64x8_64x32x1_8x8_8x4_1x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 64x32x8_32x32x1_8x4_4x8_2x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tt, 64x64x8_32x64x1_8x8_4x8_2x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 1 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tt, 128x32x8_64x32x1_8x8_8x4_2x1, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 32x128x8_16x64x1_4x8_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 64x64x8_32x32x1_8x4_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tt, 64x128x8_32x64x1_8x8_4x8_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 128x32x8_64x16x1_8x4_8x4_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tt, 128x64x8_64x32x1_8x8_8x4_2x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 128 x 16 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 16x128x16_8x32x1_2x4_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tt, 32x256x8_16x64x1_4x8_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 64x128x8_32x32x1_8x4_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tt, 64x256x8_32x64x1_8x8_4x8_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_sgemm_tt, 128x128x8_64x32x1_8x8_8x4_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L0(SM50_device_sgemm_affine2_tt, 128x128x8_64x32x1_8x8_8x4_2x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using LayoutA = cutlass::layout::AffineRank2ColumnMajor; -+ using LayoutB = cutlass::layout::AffineRank2ColumnMajor; -+ using LayoutC = cutlass::layout::AffineRankN<2>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, LayoutA, -+ precision, LayoutB, -+ precision, LayoutC, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ -+ typename LayoutA::Stride::Index stride_factor_A[] = {3, 4}; -+ typename LayoutB::Stride::Index stride_factor_B[] = {5, 6}; -+ typename LayoutC::Stride::Index stride_factor_C[] = {7, 8}; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllGemm(stride_factor_A, stride_factor_B, stride_factor_C)); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 128x64x8_32x32x1_8x4_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tt, 128x128x8_32x64x1_8x8_4x8_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tt, 256x32x8_64x16x1_8x4_8x4_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 8 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_sgemm_tt, 256x64x8_64x32x1_8x8_8x4_4x2, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 128 x 16 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 32x128x16_8x32x1_2x4_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 64x128x8_16x32x1_4x4_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 8 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 256 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 64x256x8_16x64x1_4x8_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 128x64x8_32x16x1_4x4_8x4_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 128 x 128 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 128x128x8_32x32x1_8x4_4x8_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 8 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 256 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_sgemm_tt, 256x64x8_64x16x1_8x4_8x4_4x4, { -+ using precision = float; -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_zgemm_nn_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_zgemm_nn_sm50.cu -new file mode 100644 -index 0000000..fb268af ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_zgemm_nn_sm50.cu -@@ -0,0 +1,801 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nn, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_zgemm_nn, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nn, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nn, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nn, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nn, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nn, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nn, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nn, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nn, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nn, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_zgemm_nn, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nn, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_zgemm_nn, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nn, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nn, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nn, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nn, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nn, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nn, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nn, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nn, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_zgemm_nn, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_zgemm_nn, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nn, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_zgemm_nt_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_zgemm_nt_sm50.cu -new file mode 100644 -index 0000000..0b1312a ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_zgemm_nt_sm50.cu -@@ -0,0 +1,801 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nt, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_zgemm_nt, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nt, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nt, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nt, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nt, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nt, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nt, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nt, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nt, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nt, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_zgemm_nt, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nt, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_zgemm_nt, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nt, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nt, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nt, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nt, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nt, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nt, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nt, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_nt, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_zgemm_nt, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_zgemm_nt, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_nt, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_zgemm_tn_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_zgemm_tn_sm50.cu -new file mode 100644 -index 0000000..28dbb9b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_zgemm_tn_sm50.cu -@@ -0,0 +1,801 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_tn, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_zgemm_tn, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tn, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_tn, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_tn, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_tn, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_tn, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tn, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tn, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tn, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tn, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_zgemm_tn, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_tn, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_zgemm_tn, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tn, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_tn, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_tn, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_tn, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tn, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tn, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tn, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_tn, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_zgemm_tn, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_zgemm_tn, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tn, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::ColumnMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/simt_zgemm_tt_sm50.cu b/3rdparty/cutlass/test/unit/gemm/device/simt_zgemm_tt_sm50.cu -new file mode 100644 -index 0000000..079e756 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/simt_zgemm_tt_sm50.cu -@@ -0,0 +1,801 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/numeric_types.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_tt, 8x32x8_8x32x1_2x4_4x8_1x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 1 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L0(SM50_device_zgemm_tt, 16x32x8_16x32x1_4x4_4x8_1x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tt, 8x32x8_8x16x1_2x2_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 8 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_tt, 8x64x8_8x32x1_2x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<8, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_tt, 16x32x8_16x16x1_4x2_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 1 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_tt, 16x64x8_16x32x1_4x4_4x8_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 1 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_tt, 32x32x8_32x16x1_4x4_8x4_1x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 1 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tt, 32x32x8_16x32x1_4x4_4x8_2x1, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tt, 16x32x8_8x16x1_2x2_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 16 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tt, 16x64x8_8x32x1_2x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tt, 32x32x8_16x16x1_4x2_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 2 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L0(SM50_device_zgemm_tt, 32x64x8_16x32x1_4x4_4x8_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_tt, 64x32x8_32x16x1_4x4_8x4_2x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 16 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_zgemm_tt, 16x64x16_8x16x1_2x2_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tt, 32x32x8_16x8x1_2x2_8x4_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_tt, 32x64x8_16x16x1_4x2_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 2 x 4 -+// Threadblock: 32 x 128 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_tt, 32x128x8_16x32x1_4x4_4x8_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 2 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_tt, 64x64x8_32x16x1_4x4_8x4_2x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 32 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tt, 32x32x8_8x16x1_2x2_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 32 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tt, 64x32x8_16x16x1_4x2_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 2 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tt, 64x64x8_16x32x1_4x4_4x8_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 4 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 2 -+// Threadblock: 128 x 32 x 8 -+CUTLASS_TEST_L1(SM50_device_zgemm_tt, 128x32x8_32x16x1_4x4_8x4_4x2, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 32 x 64 x 16 -+CUTLASS_TEST_L2(SM50_device_zgemm_tt, 32x64x16_8x16x1_2x2_4x8_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<8, 16, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 2 x 2 -+// Threads / Warp: 8 x 4 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 32 x 16 -+CUTLASS_TEST_L2(SM50_device_zgemm_tt, 64x32x16_16x8x1_2x2_8x4_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -+//////////////////////////////////////////////////////////////////////////////// -+// Elements / Thread: 4 x 2 -+// Threads / Warp: 4 x 8 -+// Warps / Block: 4 x 4 -+// Threadblock: 64 x 64 x 8 -+CUTLASS_TEST_L2(SM50_device_zgemm_tt, 64x64x8_16x16x1_4x2_4x8_4x4, { -+ using precision = cutlass::complex; -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 8>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 8>; -+ -+ static int const kEpilogueElementsPerAccess = 1; -+ using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ precision, kEpilogueElementsPerAccess, precision, precision>; -+ -+ using Gemm = cutlass::gemm::device::Gemm< -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, cutlass::layout::RowMajor, -+ precision, -+ cutlass::arch::OpClassSimt, -+ cutlass::arch::Sm50, -+ ThreadblockShape, WarpShape, InstructionShape, -+ EpilogueOutputOp, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 2 // Stages -+ >; -+ EXPECT_TRUE(test::gemm::device::TestAllGemm()); -+} ) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm50_gemm_f32_f32_f32_simt.cu b/3rdparty/cutlass/test/unit/gemm/device/sm50_gemm_f32_f32_f32_simt.cu -new file mode 100644 -index 0000000..f7a18bc ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm50_gemm_f32_f32_f32_simt.cu -@@ -0,0 +1,135 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "default_gemm_configuration.hpp" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+using namespace cute; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Device_Gemm_f32n_f32n_f32n_simt_f32, 128x128x64_64x64x64) { -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassSimt, cutlass::arch::Sm50, -+ float, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::ColumnMajor, -+ float>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Device_Gemm_f32n_f32t_f32n_simt_f32, 128x128x64_64x64x64) { -+ -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassSimt, cutlass::arch::Sm50, -+ float, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::RowMajor, -+ float, cutlass::layout::ColumnMajor, -+ float>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Device_Gemm_f32t_f32n_f32n_simt_f32, 128x128x64_64x64x64) { -+ -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassSimt, cutlass::arch::Sm50, -+ float, cutlass::layout::RowMajor, -+ float, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::ColumnMajor, -+ float>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Device_Gemm_f32t_f32t_f32n_simt_f32, 128x128x64_64x64x64) { -+ -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassSimt, cutlass::arch::Sm50, -+ float, cutlass::layout::RowMajor, -+ float, cutlass::layout::RowMajor, -+ float, cutlass::layout::ColumnMajor, -+ float>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm50_gemm_f64_f64_f64_simt.cu b/3rdparty/cutlass/test/unit/gemm/device/sm50_gemm_f64_f64_f64_simt.cu -new file mode 100644 -index 0000000..421072f ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm50_gemm_f64_f64_f64_simt.cu -@@ -0,0 +1,134 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "default_gemm_configuration.hpp" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+using namespace cute; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Device_Gemm_f64n_f64n_f64n_simt_f64, 128x128x64_64x64x64) { -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassSimt, cutlass::arch::Sm50, -+ double, cutlass::layout::ColumnMajor, -+ double, cutlass::layout::ColumnMajor, -+ double, cutlass::layout::ColumnMajor, -+ double>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM50_Device_Gemm_f64n_f64t_f64n_simt_f64, 128x128x64_64x64x64) { -+ -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassSimt, cutlass::arch::Sm50, -+ double, cutlass::layout::ColumnMajor, -+ double, cutlass::layout::RowMajor, -+ double, cutlass::layout::ColumnMajor, -+ double>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Device_Gemm_f64t_f64n_f64n_simt_f64, 128x128x64_64x64x64) { -+ -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassSimt, cutlass::arch::Sm50, -+ double, cutlass::layout::RowMajor, -+ double, cutlass::layout::ColumnMajor, -+ double, cutlass::layout::ColumnMajor, -+ double>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Device_Gemm_f64t_f64t_f64n_simt_f64, 128x128x64_64x64x64) { -+ -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassSimt, cutlass::arch::Sm50, -+ double, cutlass::layout::RowMajor, -+ double, cutlass::layout::RowMajor, -+ double, cutlass::layout::ColumnMajor, -+ double>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm61_gemm_s8_s8_s32_simt.cu b/3rdparty/cutlass/test/unit/gemm/device/sm61_gemm_s8_s8_s32_simt.cu -new file mode 100644 -index 0000000..ba6456b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm61_gemm_s8_s8_s32_simt.cu -@@ -0,0 +1,136 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "default_gemm_configuration.hpp" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+using namespace cute; -+ -+//#if defined(CUTLASS_ARCH_MMA_SM61_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM61_Device_Gemm_s8n_s8n_s32n_simt_s32, 128x128x64_64x64x64) { -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassSimt, cutlass::arch::Sm50, -+ int8_t, cutlass::layout::ColumnMajor, -+ int8_t, cutlass::layout::ColumnMajor, -+ int32_t, cutlass::layout::ColumnMajor, -+ int32_t>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM61_Device_Gemm_s8n_s8t_s32n_simt_s32, 128x128x64_64x64x64) { -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassSimt, cutlass::arch::Sm50, -+ int8_t, cutlass::layout::ColumnMajor, -+ int8_t, cutlass::layout::RowMajor, -+ int32_t, cutlass::layout::ColumnMajor, -+ int32_t>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM61_Device_Gemm_s8t_s8n_s32n_simt_s32, 128x128x64_64x64x64) { -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassSimt, cutlass::arch::Sm50, -+ int8_t, cutlass::layout::RowMajor, -+ int8_t, cutlass::layout::ColumnMajor, -+ int32_t, cutlass::layout::ColumnMajor, -+ int32_t>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM61_Device_Gemm_s8t_s8t_s32n_simt_s32, 128x128x64_64x64x64) { -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassSimt, cutlass::arch::Sm50, -+ int8_t, cutlass::layout::RowMajor, -+ int8_t, cutlass::layout::RowMajor, -+ int32_t, cutlass::layout::ColumnMajor, -+ int32_t>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+//#endif // #if defined(CUTLASS_ARCH_MMA_SM61_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm80_gemm_f16_f16_f32_tensor_op_f32.cu b/3rdparty/cutlass/test/unit/gemm/device/sm80_gemm_f16_f16_f32_tensor_op_f32.cu -new file mode 100644 -index 0000000..40f7cdb ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm80_gemm_f16_f16_f32_tensor_op_f32.cu -@@ -0,0 +1,136 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "default_gemm_configuration.hpp" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+using namespace cute; -+ -+//#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#if 1 -+TEST(SM80_Device_Gemm_f16t_f16n_f32t_tensor_op_f32_3x, 128x128x32_64x64x32) { -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::half_t, cutlass::layout::RowMajor, -+ cutlass::half_t, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::RowMajor, -+ float>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+#endif -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#if 1 -+TEST(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32_3x, 128x128x32_64x64x32) { -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::half_t, cutlass::layout::ColumnMajor, -+ cutlass::half_t, cutlass::layout::RowMajor, -+ float, cutlass::layout::RowMajor, -+ float>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32_3x, 128x128x32_64x64x32) { -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::half_t, cutlass::layout::ColumnMajor, -+ cutlass::half_t, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::RowMajor, -+ float>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32_3x, 128x128x32_64x64x32) { -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::half_t, cutlass::layout::RowMajor, -+ cutlass::half_t, cutlass::layout::RowMajor, -+ float, cutlass::layout::RowMajor, -+ float>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+#endif -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+//#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm80_gemm_f32_f32_f32_simt.cu b/3rdparty/cutlass/test/unit/gemm/device/sm80_gemm_f32_f32_f32_simt.cu -new file mode 100644 -index 0000000..a7c6b52 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm80_gemm_f32_f32_f32_simt.cu -@@ -0,0 +1,135 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "default_gemm_configuration.hpp" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+using namespace cute; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f32n_f32n_f32n_simt_f32, 128x128x64_64x64x64) { -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassSimt, cutlass::arch::Sm80, -+ float, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::ColumnMajor, -+ float>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f32n_f32t_f32n_simt_f32, 128x128x64_64x64x64) { -+ -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassSimt, cutlass::arch::Sm80, -+ float, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::RowMajor, -+ float, cutlass::layout::ColumnMajor, -+ float>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f32t_f32n_f32n_simt_f32, 128x128x64_64x64x64) { -+ -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassSimt, cutlass::arch::Sm80, -+ float, cutlass::layout::RowMajor, -+ float, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::ColumnMajor, -+ float>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f32t_f32t_f32n_simt_f32, 128x128x64_64x64x64) { -+ -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassSimt, cutlass::arch::Sm80, -+ float, cutlass::layout::RowMajor, -+ float, cutlass::layout::RowMajor, -+ float, cutlass::layout::ColumnMajor, -+ float>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm80_gemm_f64_f64_f64_simt.cu b/3rdparty/cutlass/test/unit/gemm/device/sm80_gemm_f64_f64_f64_simt.cu -new file mode 100644 -index 0000000..274b30c ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm80_gemm_f64_f64_f64_simt.cu -@@ -0,0 +1,134 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "default_gemm_configuration.hpp" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+using namespace cute; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f64n_f64n_f64n_simt_f64, 128x128x64_64x64x64) { -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassSimt, cutlass::arch::Sm80, -+ double, cutlass::layout::ColumnMajor, -+ double, cutlass::layout::ColumnMajor, -+ double, cutlass::layout::ColumnMajor, -+ double>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Gemm_f64n_f64t_f64n_simt_f64, 128x128x64_64x64x64) { -+ -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassSimt, cutlass::arch::Sm80, -+ double, cutlass::layout::ColumnMajor, -+ double, cutlass::layout::RowMajor, -+ double, cutlass::layout::ColumnMajor, -+ double>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f64t_f64n_f64n_simt_f64, 128x128x64_64x64x64) { -+ -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassSimt, cutlass::arch::Sm80, -+ double, cutlass::layout::RowMajor, -+ double, cutlass::layout::ColumnMajor, -+ double, cutlass::layout::ColumnMajor, -+ double>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f64t_f64t_f64n_simt_f64, 128x128x64_64x64x64) { -+ -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassSimt, cutlass::arch::Sm80, -+ double, cutlass::layout::RowMajor, -+ double, cutlass::layout::RowMajor, -+ double, cutlass::layout::ColumnMajor, -+ double>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm80_gemm_f64_f64_f64_tensor_op_f64.cu b/3rdparty/cutlass/test/unit/gemm/device/sm80_gemm_f64_f64_f64_tensor_op_f64.cu -new file mode 100644 -index 0000000..e53a8e8 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm80_gemm_f64_f64_f64_tensor_op_f64.cu -@@ -0,0 +1,98 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "default_gemm_configuration.hpp" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+using namespace cute; -+ -+//#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f64n_f64t_f64n_tensor_op_f64, 128x128x64_64x64x64) { -+ -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ double, cutlass::layout::ColumnMajor, -+ double, cutlass::layout::ColumnMajor, -+ double, cutlass::layout::ColumnMajor, -+ double>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_f64t_f64n_f64n_tensor_op_f64, 128x128x64_64x64x64) { -+ -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ double, cutlass::layout::RowMajor, -+ double, cutlass::layout::ColumnMajor, -+ double, cutlass::layout::ColumnMajor, -+ double>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// #endif -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm80_gemm_s8_s8_s32_tensor_op.cu b/3rdparty/cutlass/test/unit/gemm/device/sm80_gemm_s8_s8_s32_tensor_op.cu -new file mode 100644 -index 0000000..d53cf54 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm80_gemm_s8_s8_s32_tensor_op.cu -@@ -0,0 +1,94 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "default_gemm_configuration.hpp" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+using namespace cute; -+ -+//#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(DISABLED_SM80_Device_Gemm_s8n_s8n_s32n_tensor_op_s32, 128x128x32_64x64x64) { -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(DISABLED_SM80_Device_Gemm_s8n_s8t_s32n_tensor_op_s32, 128x128x32_64x64x64) { -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 128x128x32_64x64x64) { -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ int8_t, cutlass::layout::RowMajor, -+ int8_t, cutlass::layout::ColumnMajor, -+ int32_t, cutlass::layout::ColumnMajor, -+ int32_t>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(DISABLED_SM80_Device_Gemm_s8t_s8t_s32n_tensor_op_s32, 128x128x32_64x64x64) { -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+//#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm80_gemm_tf32_tf32_f32_tensor_op_f32.cu b/3rdparty/cutlass/test/unit/gemm/device/sm80_gemm_tf32_tf32_f32_tensor_op_f32.cu -new file mode 100644 -index 0000000..14654c7 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm80_gemm_tf32_tf32_f32_tensor_op_f32.cu -@@ -0,0 +1,135 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "default_gemm_configuration.hpp" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+using namespace cute; -+ -+ -+//#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_tf32n_tf32n_f32n_tensor_op_f32, 128x128x32_64x64x64) { -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::tfloat32_t, cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::RowMajor, -+ float>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_tf32n_tf32t_f32n_tensor_op_f32, 128x128x32_64x64x64) { -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::tfloat32_t, cutlass::layout::ColumnMajor, -+ cutlass::tfloat32_t, cutlass::layout::RowMajor, -+ float, cutlass::layout::RowMajor, -+ float>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Gemm_tf32t_tf32n_f32n_tensor_op_f32, 128x128x32_64x64x64) { -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::tfloat32_t, cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::RowMajor, -+ float>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM80_Device_Gemm_tf32t_tf32t_f32n_tensor_op_f32, 128x128x32_64x64x64) { -+ using Config = cutlass::gemm::device::DefaultGemmConfigurationToCutlass3Types< -+ cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, -+ cutlass::tfloat32_t, cutlass::layout::RowMajor, -+ cutlass::tfloat32_t, cutlass::layout::RowMajor, -+ float, cutlass::layout::RowMajor, -+ float>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ Config::CollectiveMainloop, -+ Config::CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+//#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32.cu b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32.cu -new file mode 100644 -index 0000000..9fbbd86 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32.cu -@@ -0,0 +1,188 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "cutlass/gemm/kernel/gemm_universal.hpp" -+#include "cutlass/gemm/collective/collective_builder.hpp" -+#include "cutlass/epilogue/collective/default_epilogue.hpp" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+using namespace cute; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_bf16t_bf16t_bf16n_align8_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::bfloat16_t, LayoutA, 8, -+ cutlass::bfloat16_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align4_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::bfloat16_t, LayoutA, 4, -+ cutlass::bfloat16_t, LayoutB, 4, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_bf16n_bf16t_bf16n_align2_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::bfloat16_t, LayoutA, 2, -+ cutlass::bfloat16_t, LayoutB, 2, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_bf16n_bf16n_bf16n_align8_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::bfloat16_t, LayoutA, 8, -+ cutlass::bfloat16_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_tensor_op_f32.cu b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_tensor_op_f32.cu -new file mode 100644 -index 0000000..d3983e4 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_tensor_op_f32.cu -@@ -0,0 +1,187 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "cutlass/gemm/kernel/gemm_universal.hpp" -+#include "cutlass/gemm/collective/collective_builder.hpp" -+#include "cutlass/epilogue/collective/default_epilogue.hpp" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+using namespace cute; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_bf16t_bf16t_bf16n_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::bfloat16_t, LayoutA, 8, -+ cutlass::bfloat16_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::bfloat16_t, LayoutA, 8, -+ cutlass::bfloat16_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_bf16n_bf16t_bf16n_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::bfloat16_t, LayoutA, 8, -+ cutlass::bfloat16_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_bf16n_bf16n_bf16n_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::bfloat16_t, LayoutA, 8, -+ cutlass::bfloat16_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_f16_f16_f16_alignx_tensor_op.cu b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_f16_f16_f16_alignx_tensor_op.cu -new file mode 100644 -index 0000000..0ee526b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_f16_f16_f16_alignx_tensor_op.cu -@@ -0,0 +1,449 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "cutlass/gemm/kernel/gemm_universal.hpp" -+#include "cutlass/gemm/collective/collective_builder.hpp" -+#include "cutlass/epilogue/collective/default_epilogue.hpp" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+using namespace cute; -+ -+/////////////////////////////////////////////////////////////////////////////// -+///////////////////////////////////// TT ////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f16n_align8_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelMultistage -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f16n_align4_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 4, -+ cutlass::half_t, LayoutB, 4, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f16n_align2_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 2, -+ cutlass::half_t, LayoutB, 2, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+///////////////////////////////////// TN ////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16n_align8_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelMultistage -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16n_align4_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 4, -+ cutlass::half_t, LayoutB, 4, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16n_align2_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 2, -+ cutlass::half_t, LayoutB, 2, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+///////////////////////////////////// NT ////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16t_f16n_align8_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelMultistage -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16t_f16n_align4_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 4, -+ cutlass::half_t, LayoutB, 4, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16t_f16n_align2_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 2, -+ cutlass::half_t, LayoutB, 2, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+///////////////////////////////////// NN ////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16n_f16n_align8_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelMultistage -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16n_f16n_align4_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 4, -+ cutlass::half_t, LayoutB, 4, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16n_f16n_align2_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 2, -+ cutlass::half_t, LayoutB, 2, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op.cu b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op.cu -new file mode 100644 -index 0000000..4fea99a ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op.cu -@@ -0,0 +1,1077 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "cutlass/gemm/kernel/gemm_universal.hpp" -+#include "cutlass/gemm/collective/collective_builder.hpp" -+#include "cutlass/epilogue/collective/epilogue.hpp" -+#include "cutlass/epilogue/collective/default_epilogue.hpp" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+using namespace cute; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f16n_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f16n_tensor_op_gmma_f32, 128x128x32) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_128,_128,_32>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f16n_tensor_op_gmma_f32, 64x64x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_64,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32, 128x128x32) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_128,_128,_32>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32, 64x64x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_64,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16t_f16n_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16n_f16t_f16n_tensor_op_gmma_f32, 128x128x32) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_128,_128,_32>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16n_f16t_f16n_tensor_op_gmma_f32, 64x64x64) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_64,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16n_f16n_tensor_op_gmma_f32, 64x128x64) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16n_f16n_f16n_tensor_op_gmma_f32, 128x128x32) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_128,_128,_32>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16n_f16n_f16n_tensor_op_gmma_f32, 64x64x64) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_64,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f16n_tensor_op_gmma_f16, 64x128x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ cutlass::half_t, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f16n_tensor_op_gmma_f16, 128x128x32) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ cutlass::half_t, -+ Shape<_128,_128,_32>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f16n_tensor_op_gmma_f16, 64x64x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ cutlass::half_t, -+ Shape<_64,_64,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f16, 64x128x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ cutlass::half_t, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f16, 128x128x32) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ cutlass::half_t, -+ Shape<_128,_128,_32>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f16, 64x64x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ cutlass::half_t, -+ Shape<_64,_64,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16t_f16n_tensor_op_gmma_f16, 64x128x64) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ cutlass::half_t, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16n_f16t_f16n_tensor_op_gmma_f16, 128x128x32) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ cutlass::half_t, -+ Shape<_128,_128,_32>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16n_f16t_f16n_tensor_op_gmma_f16, 64x64x64) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ cutlass::half_t, -+ Shape<_64,_64,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16n_f16n_tensor_op_gmma_f16, 64x128x64) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ cutlass::half_t, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16n_f16n_f16n_tensor_op_gmma_f16, 128x128x32) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ cutlass::half_t, -+ Shape<_128,_128,_32>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16n_f16n_f16n_tensor_op_gmma_f16, 64x64x64) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ cutlass::half_t, -+ Shape<_64,_64,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f16_Epilogue, 64x128x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ cutlass::half_t, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::Epilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination, -+ ComposedLayout, smem_ptr_flag_bits::value>, Layout,Stride<_1,_64>>>, -+ Copy_Atom, -+ TiledCopy,Layout,Stride<_8,_1>>,Shape<_64,_16>>, -+ Copy_Atom>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f16_Epilogue, 128x64x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ cutlass::half_t, -+ Shape<_128,_64,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::Epilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination, -+ ComposedLayout, smem_ptr_flag_bits::value>, Layout,_64>,Stride,_64>>>, -+ Copy_Atom, -+ TiledCopy,Layout,Stride<_8,_1>>,Shape<_128,_8>>, -+ Copy_Atom>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f16_Epilogue, 64x128x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ cutlass::half_t, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::Epilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination, -+ ComposedLayout, smem_ptr_flag_bits::value>, Layout>,Stride<_64,Stride<_1,_4096>>>>, -+ Copy_Atom, -+ TiledCopy,Layout,Stride<_8,_1>>,Shape<_8,_128>>, -+ Copy_Atom>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f16_Epilogue, 128x64x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ cutlass::half_t, -+ Shape<_128,_64,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::Epilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination, -+ ComposedLayout, smem_ptr_flag_bits::value>, Layout,Stride<_64,_1>>>, -+ Copy_Atom, -+ TiledCopy,Layout,Stride<_8,_1>>,Shape<_16,_64>>, -+ Copy_Atom>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32_Epilogue, 64x128x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::Epilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination, -+ ComposedLayout, smem_ptr_flag_bits::value>, Layout,Stride<_1,_64>>>, -+ Copy_Atom, -+ TiledCopy,Layout,Stride<_8,_1>>,Shape<_64,_16>>, -+ Copy_Atom>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32_Epilogue, 128x64x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_128,_64,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::Epilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination, -+ ComposedLayout, smem_ptr_flag_bits::value>, Layout,_64>,Stride,_64>>>, -+ Copy_Atom, -+ TiledCopy,Layout,Stride<_8,_1>>,Shape<_128,_8>>, -+ Copy_Atom>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f32_Epilogue, 64x128x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::Epilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination, -+ ComposedLayout, smem_ptr_flag_bits::value>, Layout>,Stride<_64,Stride<_1,_4096>>>>, -+ Copy_Atom, -+ TiledCopy,Layout,Stride<_8,_1>>,Shape<_8,_128>>, -+ Copy_Atom>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f32_Epilogue, 128x64x64) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_128,_64,_64>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::Epilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination, -+ ComposedLayout, smem_ptr_flag_bits::value>, Layout,Stride<_64,_1>>>, -+ Copy_Atom, -+ TiledCopy,Layout,Stride<_8,_1>>,Shape<_16,_64>>, -+ Copy_Atom>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_unspecialized.cu b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_unspecialized.cu -new file mode 100644 -index 0000000..1646632 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_unspecialized.cu -@@ -0,0 +1,582 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "cutlass/gemm/kernel/gemm_universal.hpp" -+#include "cutlass/gemm/collective/collective_builder.hpp" -+#include "cutlass/epilogue/collective/default_epilogue.hpp" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+using namespace cute; -+ -+/////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////// Cluster 2x2x1 //////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_2x2x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_2,_2,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTma -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_2x2x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_2,_2,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTma -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_2x2x1) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_2,_2,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTma -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_2x2x1) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_2,_2,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTma -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////// Cluster 4x1x1 //////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_4x1x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_4,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTma -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_4x1x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_4,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTma -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_4x1x1) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_4,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTma -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_4x1x1) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_4,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTma -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+ -+/////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////// Cluster 1x4x1 //////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_1x4x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_4,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTma -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_1x4x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_4,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTma -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_1x4x1) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_4,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTma -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_1x4x1) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_4,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTma -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+ -+/////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////// Cluster 2x4x1 //////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_2x4x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_2,_4,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTma -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_2x4x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_2,_4,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTma -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_2x4x1) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_2,_4,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTma -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64_2x4x1) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_2,_4,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTma -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized.cu b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized.cu -new file mode 100644 -index 0000000..378315d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized.cu -@@ -0,0 +1,582 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "cutlass/gemm/kernel/gemm_universal.hpp" -+#include "cutlass/gemm/collective/collective_builder.hpp" -+#include "cutlass/epilogue/collective/default_epilogue.hpp" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+using namespace cute; -+ -+/////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////// Cluster 2x2x1 //////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_2x2x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_2,_2,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecialized -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_2x2x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_2,_2,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecialized -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_2x2x1) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_2,_2,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecialized -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_2x2x1) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_2,_2,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecialized -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////// Cluster 4x1x1 //////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_4x1x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_4,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecialized -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_4x1x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_4,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecialized -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_4x1x1) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_4,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecialized -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_4x1x1) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_4,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecialized -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+ -+/////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////// Cluster 1x4x1 //////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_1x4x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_4,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecialized -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_1x4x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_4,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecialized -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_1x4x1) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_4,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecialized -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_1x4x1) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_1,_4,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecialized -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+ -+/////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////// Cluster 2x4x1 //////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////// -+ -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_2x4x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_2,_4,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecialized -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_2x4x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_2,_4,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecialized -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_2x4x1) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_2,_4,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecialized -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x64_2x4x1) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::half_t, LayoutA, 8, -+ cutlass::half_t, LayoutB, 8, -+ float, -+ Shape<_64,_128,_64>, Shape<_2,_4,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecialized -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_persistent.cu b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_persistent.cu -new file mode 100644 -index 0000000..c7d814b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_persistent.cu -@@ -0,0 +1,1018 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "cutlass/gemm/kernel/gemm_universal.hpp" -+#include "cutlass/gemm/collective/collective_builder.hpp" -+#include "cutlass/epilogue/collective/epilogue.hpp" -+#include "cutlass/epilogue/collective/default_epilogue.hpp" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+using namespace cute; -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_1x1x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_64,_128,_64>; -+ using ClusterShape_MNK = Shape<_1,_1,_1>; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_2x1x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_64,_128,_64>; -+ using ClusterShape_MNK = Shape<_2,_1,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_1x2x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_64,_128,_64>; -+ using ClusterShape_MNK = Shape<_1,_2,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_2x2x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_64,_128,_64>; -+ using ClusterShape_MNK = Shape<_2,_2,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_4x1x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_64,_128,_64>; -+ using ClusterShape_MNK = Shape<_4,_1,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_1x4x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_64,_128,_64>; -+ using ClusterShape_MNK = Shape<_1,_4,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_2x4x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_64,_128,_64>; -+ using ClusterShape_MNK = Shape<_2,_4,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_4x4x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_64,_128,_64>; -+ using ClusterShape_MNK = Shape<_4,_4,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_1x1x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_128,_128,_64>; -+ using ClusterShape_MNK = Shape<_1,_1,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_2x1x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_128,_128,_64>; -+ using ClusterShape_MNK = Shape<_2,_1,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_1x2x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_128,_128,_64>; -+ using ClusterShape_MNK = Shape<_1,_2,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_2x2x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_128,_128,_64>; -+ using ClusterShape_MNK = Shape<_2,_2,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_4x1x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_128,_128,_64>; -+ using ClusterShape_MNK = Shape<_4,_1,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_1x4x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_128,_128,_64>; -+ using ClusterShape_MNK = Shape<_1,_4,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_2x4x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_128,_128,_64>; -+ using ClusterShape_MNK = Shape<_2,_4,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_4x4x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_128,_128,_64>; -+ using ClusterShape_MNK = Shape<_4,_4,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f16_persistent_Epilogue, 64x128x64_2x2x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_64,_128,_64>; -+ using ClusterShape_MNK = Shape<_2,_2,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using PreSwizzleLayout = Layout,Stride<_1,_64>>; -+ using TileShapeS2R = Shape<_64,_16>; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::Epilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination, -+ ComposedLayout, smem_ptr_flag_bits>, PreSwizzleLayout>, -+ Copy_Atom, -+ TiledCopy,Layout,Stride<_8,_1>>,TileShapeS2R>, -+ Copy_Atom>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f16_persistent_Epilogue, 128x64x64_2x2x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_128,_64,_64>; -+ using ClusterShape_MNK = Shape<_2,_2,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using PreSwizzleLayout = Layout,_64>,Stride,_64>>; -+ using TileShapeS2R = Shape<_128,_8>; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::Epilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination, -+ ComposedLayout, smem_ptr_flag_bits>, PreSwizzleLayout>, -+ Copy_Atom, -+ TiledCopy,Layout,Stride<_8,_1>>,TileShapeS2R>, -+ Copy_Atom>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f16_persistent_Epilogue, 64x128x64_2x2x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ using TileShape_MNK = Shape<_64,_128,_64>; -+ using ClusterShape_MNK = Shape<_2,_2,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using PreSwizzleLayout = Layout>,Stride<_64,Stride<_1,_4096>>>; -+ using TileShapeS2R = Shape<_8,_128>; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::Epilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination, -+ ComposedLayout, smem_ptr_flag_bits>, PreSwizzleLayout>, -+ Copy_Atom, -+ TiledCopy,Layout,Stride<_8,_1>>,TileShapeS2R>, -+ Copy_Atom>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f16_persistent_Epilogue, 128x64x64_2x2x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ using TileShape_MNK = Shape<_128,_64,_64>; -+ using ClusterShape_MNK = Shape<_2,_2,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using PreSwizzleLayout = Layout,Stride<_64,_1>>; -+ using TileShapeS2R = Shape<_16,_64>; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::Epilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination, -+ ComposedLayout, smem_ptr_flag_bits>, PreSwizzleLayout>, -+ Copy_Atom, -+ TiledCopy,Layout,Stride<_8,_1>>,TileShapeS2R>, -+ Copy_Atom>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32_persistent_Epilogue, 64x128x64_2x2x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = float; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_64,_128,_64>; -+ using ClusterShape_MNK = Shape<_2,_2,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using PreSwizzleLayout = Layout,Stride<_1,_64>>; -+ using TileShapeS2R = Shape<_64,_16>; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::Epilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination, -+ ComposedLayout, smem_ptr_flag_bits>, PreSwizzleLayout>, -+ Copy_Atom, -+ TiledCopy,Layout,Stride<_8,_1>>,TileShapeS2R>, -+ Copy_Atom>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32_persistent_Epilogue, 128x64x64_2x2x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = float; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using TileShape_MNK = Shape<_128,_64,_64>; -+ using ClusterShape_MNK = Shape<_2,_2,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using PreSwizzleLayout = Layout,_64>,Stride,_64>>; -+ using TileShapeS2R = Shape<_128,_8>; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::Epilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination, -+ ComposedLayout, smem_ptr_flag_bits>, PreSwizzleLayout>, -+ Copy_Atom, -+ TiledCopy,Layout,Stride<_8,_1>>,TileShapeS2R>, -+ Copy_Atom>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f32_persistent_Epilogue, 64x128x64_2x2x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = float; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ using TileShape_MNK = Shape<_64,_128,_64>; -+ using ClusterShape_MNK = Shape<_2,_2,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using PreSwizzleLayout = Layout>,Stride<_64,Stride<_1,_4096>>>; -+ using TileShapeS2R = Shape<_8,_128>; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::Epilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination, -+ ComposedLayout, smem_ptr_flag_bits>, PreSwizzleLayout>, -+ Copy_Atom, -+ TiledCopy,Layout,Stride<_8,_1>>,TileShapeS2R>, -+ Copy_Atom>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f32_persistent_Epilogue, 128x64x64_2x2x1) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = float; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ using TileShape_MNK = Shape<_128,_64,_64>; -+ using ClusterShape_MNK = Shape<_2,_2,_1>; -+ using StageCountType = cutlass::gemm::collective::StageCountAuto; -+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; -+ -+ using PreSwizzleLayout = Layout,Stride<_64,_1>>; -+ using TileShapeS2R = Shape<_16,_64>; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::Epilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination, -+ ComposedLayout, smem_ptr_flag_bits>, PreSwizzleLayout>, -+ Copy_Atom, -+ TiledCopy,Layout,Stride<_8,_1>>,TileShapeS2R>, -+ Copy_Atom>; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementAccumulator, -+ TileShape_MNK, ClusterShape_MNK, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelTmaWarpSpecializedPersistent -+ >::CollectiveOp; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32.cu b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32.cu -new file mode 100644 -index 0000000..b4edaf6 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32.cu -@@ -0,0 +1,86 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -+ * -+ * Redistribution and use in source and binary forms, with or without modification, are permitted -+ * provided that the following conditions are met: -+ * * Redistributions of source code must retain the above copyright notice, this list of -+ * conditions and the following disclaimer. -+ * * Redistributions in binary form must reproduce the above copyright notice, this list of -+ * conditions and the following disclaimer in the documentation and/or other materials -+ * provided with the distribution. -+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used -+ * to endorse or promote products derived from this software without specific prior written -+ * permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR -+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND -+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, -+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; -+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "cutlass/gemm/kernel/gemm_universal.hpp" -+#include "cutlass/gemm/collective/collective_builder.hpp" -+#include "cutlass/epilogue/collective/default_epilogue.hpp" -+#include "cutlass/epilogue/collective/default_transposed_epilogue.hpp" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+using namespace cute; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_f32t_f32n_f32n_tensor_op_gmma_f32, 64x128x32_1x2x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ float, LayoutA, 4, -+ float, LayoutB, 4, -+ float, -+ Shape<_64,_128,_128>, Shape<_1,_2,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveMainloop, -+ CollectiveEpilogue -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_s8_s8_s8_alignx_tensor_op_s32.cu b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_s8_s8_s8_alignx_tensor_op_s32.cu -new file mode 100644 -index 0000000..5d30e96 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_s8_s8_s8_alignx_tensor_op_s32.cu -@@ -0,0 +1,152 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "cutlass/gemm/kernel/gemm_universal.hpp" -+#include "cutlass/gemm/collective/collective_builder.hpp" -+#include "cutlass/epilogue/collective/default_epilogue.hpp" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+using namespace cute; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_s8t_s8n_s8n_align8_tensor_op_gmma_s32, 64x128x128) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ int8_t, LayoutA, 8, -+ int8_t, LayoutB, 8, -+ int32_t, -+ Shape<_64,_128,_128>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_s8t_s8n_s8n_align16_tensor_op_gmma_s32, 128x128x128) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ int8_t, LayoutA, 16, -+ int8_t, LayoutB, 16, -+ int32_t, -+ Shape<_128,_128,_128>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelMultistage -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_s8t_s8n_s8n_align4_tensor_op_gmma_s32, 128x64x128) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ int8_t, LayoutA, 4, -+ int8_t, LayoutB, 4, -+ int32_t, -+ Shape<_128,_64,_128>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_s8_s8_s8_tensor_op_s32.cu b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_s8_s8_s8_tensor_op_s32.cu -new file mode 100644 -index 0000000..f0762a9 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_s8_s8_s8_tensor_op_s32.cu -@@ -0,0 +1,243 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "cutlass/gemm/kernel/gemm_universal.hpp" -+#include "cutlass/gemm/collective/collective_builder.hpp" -+#include "cutlass/epilogue/collective/default_epilogue.hpp" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+using namespace cute; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_s8t_s8n_s8n_tensor_op_gmma_s32, 64x128x128) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ int8_t, LayoutA, 16, -+ int8_t, LayoutB, 16, -+ int32_t, -+ Shape<_64,_128,_128>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_s8t_s8n_s8n_tensor_op_gmma_s32, 64x128x128_1x2x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ int8_t, LayoutA, 16, -+ int8_t, LayoutB, 16, -+ int32_t, -+ Shape<_64,_128,_128>, Shape<_1,_2,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_s8t_s8n_s8n_tensor_op_gmma_s32, 128x128x128) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ int8_t, LayoutA, 16, -+ int8_t, LayoutB, 16, -+ int32_t, -+ Shape<_128,_128,_128>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_s8t_s8n_s8n_tensor_op_gmma_s32, 128x128x128_1x2x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ int8_t, LayoutA, 16, -+ int8_t, LayoutB, 16, -+ int32_t, -+ Shape<_128,_128,_128>, Shape<_1,_2,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_s8t_s8n_s8n_tensor_op_gmma_s32, 128x128x128_2x1x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ int8_t, LayoutA, 16, -+ int8_t, LayoutB, 16, -+ int32_t, -+ Shape<_128,_128,_128>, Shape<_2,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_s8t_s8n_s8n_tensor_op_gmma_s32, 128x128x128_2x2x1) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ int8_t, LayoutA, 16, -+ int8_t, LayoutB, 16, -+ int32_t, -+ Shape<_128,_128,_128>, Shape<_2,_2,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32.cu b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32.cu -new file mode 100644 -index 0000000..e95772f ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32.cu -@@ -0,0 +1,151 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "cutlass/gemm/kernel/gemm_universal.hpp" -+#include "cutlass/gemm/collective/collective_builder.hpp" -+#include "cutlass/epilogue/collective/default_epilogue.hpp" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+using namespace cute; -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align4_tensor_op_gmma_f32, 64x128x32) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ tfloat32_t, LayoutA, 4, -+ tfloat32_t, LayoutB, 4, -+ float, -+ Shape<_64,_128,_32>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::KernelMultistage -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align2_tensor_op_gmma_f32, 64x64x32) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::tfloat32_t, LayoutA, 2, -+ cutlass::tfloat32_t, LayoutB, 2, -+ float, -+ Shape<_64,_64,_32>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align1_tensor_op_gmma_f32, 128x64x32) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::tfloat32_t, LayoutA, 1, -+ cutlass::tfloat32_t, LayoutB, 1, -+ float, -+ Shape<_128,_64,_32>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32.cu b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32.cu -new file mode 100644 -index 0000000..ce570a2 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32.cu -@@ -0,0 +1,185 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cute/tensor.hpp" -+#include "cute/atom/mma_atom.hpp" -+ -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "cutlass/gemm/kernel/gemm_universal.hpp" -+#include "cutlass/gemm/collective/collective_builder.hpp" -+#include "cutlass/epilogue/collective/default_epilogue.hpp" -+#include "cutlass/epilogue/thread/linear_combination.h" -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "gemm_testbed_3x.hpp" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -+ -+using namespace cute; -+ -+TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_tensor_op_gmma_f32, 64x128x32) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::tfloat32_t, LayoutA, 4, -+ cutlass::tfloat32_t, LayoutB, 4, -+ float, -+ Shape<_64,_128,_32>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_tf32n_tf32n_f32n_tensor_op_gmma_f32, 64x128x32) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::tfloat32_t, LayoutA, 1, -+ cutlass::tfloat32_t, LayoutB, 4, -+ float, -+ Shape<_64,_128,_32>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_tf32n_tf32t_f32n_tensor_op_gmma_f32, 64x128x32) { -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::tfloat32_t, LayoutA, 1, -+ cutlass::tfloat32_t, LayoutB, 1, -+ float, -+ Shape<_64,_128,_32>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Gemm_tf32t_tf32t_f32n_tensor_op_gmma_f32, 64x128x32) { -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< -+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, -+ cutlass::tfloat32_t, LayoutA, 4, -+ cutlass::tfloat32_t, LayoutB, 1, -+ float, -+ Shape<_64,_128,_32>, Shape<_1,_1,_1>, -+ cutlass::gemm::collective::StageCountAuto, -+ cutlass::gemm::collective::KernelScheduleAuto -+ >::CollectiveOp; -+ -+ using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::gemm::TagToStrideC_t, -+ cutlass::epilogue::thread::LinearCombination>; -+ -+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal< -+ Shape, -+ CollectiveOp, -+ EpilogueOp -+ >; -+ -+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -+ EXPECT_TRUE(test::gemm::device::TestAll()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+ -+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_cf32n_cf32n_tensor_op_f32_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_cf32n_cf32n_tensor_op_f32_ls_sm80.cu -new file mode 100644 -index 0000000..d386a7e ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_cf32n_cf32n_tensor_op_f32_ls_sm80.cu -@@ -0,0 +1,172 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf32n_cf32n_ls_l_tensor_op_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf32n_cf32n_ls_u_tensor_op_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf32n_cf32n_ls_u_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_cf32n_cf32n_tensor_op_f32_rs_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_cf32n_cf32n_tensor_op_f32_rs_sm80.cu -new file mode 100644 -index 0000000..07f8564 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_cf32n_cf32n_tensor_op_f32_rs_sm80.cu -@@ -0,0 +1,172 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf32n_cf32n_rs_l_tensor_op_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf32n_cf32n_rs_u_tensor_op_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf32n_cf32n_rs_u_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_cf32n_cf32n_tensor_op_fast_f32_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_cf32n_cf32n_tensor_op_fast_f32_ls_sm80.cu -new file mode 100644 -index 0000000..3ad96e4 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_cf32n_cf32n_tensor_op_fast_f32_ls_sm80.cu -@@ -0,0 +1,172 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf32n_cf32n_ls_l_tensor_op_fast_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplexFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf32n_cf32n_ls_u_tensor_op_fast_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplexFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf32n_cf32n_ls_u_tensor_op_fast_f32, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplexFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_cf32n_cf32n_tensor_op_fast_f32_rs_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_cf32n_cf32n_tensor_op_fast_f32_rs_sm80.cu -new file mode 100644 -index 0000000..4eb8b7a ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_cf32n_cf32n_tensor_op_fast_f32_rs_sm80.cu -@@ -0,0 +1,172 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf32n_cf32n_rs_l_tensor_op_fast_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplexFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf32n_cf32n_rs_u_tensor_op_fast_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplexFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf32n_cf32n_rs_u_tensor_op_fast_f32, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplexFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_cf64_cf64_cf64_tensor_op_f64_sm90.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_cf64_cf64_cf64_tensor_op_f64_sm90.cu -new file mode 100644 -index 0000000..a13f744 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_cf64_cf64_cf64_tensor_op_f64_sm90.cu -@@ -0,0 +1,133 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Symm_cf64n_cf64n_ls_l_tensor_op_f64_gaussian, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Symm_cf64n_cf64n_rs_u_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_cf64n_cf64n_cf64n_tensor_op_ls_f64_gaussian_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_cf64n_cf64n_cf64n_tensor_op_ls_f64_gaussian_sm80.cu -new file mode 100644 -index 0000000..fa2f574 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_cf64n_cf64n_cf64n_tensor_op_ls_f64_gaussian_sm80.cu -@@ -0,0 +1,172 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf64n_cf64n_ls_l_tensor_op_f64_gaussian, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf64n_cf64n_ls_u_tensor_op_f64_gaussian, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf64n_cf64n_ls_u_tensor_op_f64_gaussian, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_cf64n_cf64n_cf64n_tensor_op_ls_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_cf64n_cf64n_cf64n_tensor_op_ls_f64_sm80.cu -new file mode 100644 -index 0000000..3dd4edd ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_cf64n_cf64n_cf64n_tensor_op_ls_f64_sm80.cu -@@ -0,0 +1,172 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf64n_cf64n_ls_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf64n_cf64n_ls_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf64n_cf64n_ls_u_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_cf64n_cf64n_cf64n_tensor_op_rs_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_cf64n_cf64n_cf64n_tensor_op_rs_f64_sm80.cu -new file mode 100644 -index 0000000..af810d4 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_cf64n_cf64n_cf64n_tensor_op_rs_f64_sm80.cu -@@ -0,0 +1,172 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf64n_cf64n_rs_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf64n_cf64n_rs_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_cf64n_cf64n_rs_u_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu -new file mode 100644 -index 0000000..6cdc04d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu -@@ -0,0 +1,489 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+////////////////////////////////////////////Test name////////////////////////////////////////////////// -+// -+// SM80_Device_Symm_{ElementA/B}{LayoutA/B}_{ElementC}{LayoutC}_{SideMode}_{FillMode}\ -+// _tensor_op_{ElementAccumulator}_align{AlignmentA}_align{AlignmentB} -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Symm_f32n_f32n_ls_l_tensor_op_fast_f32_align1_align1, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_l_tensor_op_fast_f32_align1_align4, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_l_tensor_op_fast_f32_align1_align4, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_l_tensor_op_fast_f32_align1_align4, 256x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_l_tensor_op_fast_f32_align1_align4, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_l_tensor_op_fast_f32_align1_align4, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_u_tensor_op_fast_f32_align1_align4, 64x64x16_32x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 10, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_u_tensor_op_fast_f32_align1_align4, 128x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_u_tensor_op_fast_f32_align1_align4, 256x128x16_128x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_u_tensor_op_fast_f32_align1_align4, 128x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_u_tensor_op_fast_f32_align1_align4, 256x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_f32n_f32n_tensor_op_fast_f32_rs_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_f32n_f32n_tensor_op_fast_f32_rs_sm80.cu -new file mode 100644 -index 0000000..1ae9cf9 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_f32n_f32n_tensor_op_fast_f32_rs_sm80.cu -@@ -0,0 +1,276 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+////////////////////////////////////////////Test name////////////////////////////////////////////////// -+// -+// SM80_Device_Symm_{ElementA/B}{LayoutA/B}_{ElementC}{LayoutC}_{SideMode}_{FillMode}\ -+// _tensor_op_{ElementAccumulator}_align{AlignmentA}_align{AlignmentB} -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_rs_u_tensor_op_fast_f32_align1_align1, 64x128x32_32x64x32) { -+ -+using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_rs_u_tensor_op_fast_f32_align1_align1, 128x64x32_32x64x32) { -+ -+using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_rs_l_tensor_op_fast_f32_align1_align1, 64x128x32_32x64x32) { -+ -+using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+TEST(SM80_Device_Symm_f32n_f32n_rs_u_tensor_op_fast_f32_align1_align4, 64x128x32_32x64x32) { -+ -+using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_rs_u_tensor_op_fast_f32_align1_align4, 128x64x32_32x64x32) { -+ -+using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_rs_l_tensor_op_fast_f32_align1_align4, 64x128x32_32x64x32) { -+ -+using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_f32t_f32t_tensor_op_fast_f32_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_f32t_f32t_tensor_op_fast_f32_ls_sm80.cu -new file mode 100644 -index 0000000..28be9a6 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_f32t_f32t_tensor_op_fast_f32_ls_sm80.cu -@@ -0,0 +1,489 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+////////////////////////////////////////////Test name////////////////////////////////////////////////// -+// -+// SM80_Device_Symm_{ElementA/B}{LayoutA/B}_{ElementC}{LayoutC}_{SideMode}_{FillMode}\ -+// _tensor_op_{ElementAccumulator}_align{AlignmentA}_align{AlignmentB} -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Symm_f32t_f32t_ls_l_tensor_op_fast_f32_align1_align1, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_l_tensor_op_fast_f32_align1_align4, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_l_tensor_op_fast_f32_align1_align4, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_l_tensor_op_fast_f32_align1_align4, 256x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_l_tensor_op_fast_f32_align1_align4, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_l_tensor_op_fast_f32_align1_align4, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_u_tensor_op_fast_f32_align1_align4, 64x64x16_32x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 10, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_u_tensor_op_fast_f32_align1_align4, 128x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_u_tensor_op_fast_f32_align1_align4, 256x128x16_128x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_u_tensor_op_fast_f32_align1_align4, 128x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_u_tensor_op_fast_f32_align1_align4, 256x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_f64_f64_tensor_op_f64_sm90.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_f64_f64_tensor_op_f64_sm90.cu -new file mode 100644 -index 0000000..1feb2d6 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_f64_f64_tensor_op_f64_sm90.cu -@@ -0,0 +1,135 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYMM interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Symm_f64n_f64n_rs_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Symm_f64t_f64t_ls_l_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_f64n_f64n_tensor_op_f64_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_f64n_f64n_tensor_op_f64_ls_sm80.cu -new file mode 100644 -index 0000000..abb5020 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_f64n_f64n_tensor_op_f64_ls_sm80.cu -@@ -0,0 +1,258 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYMM interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64n_ls_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64n_ls_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64n_ls_l_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64n_ls_u_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64n_ls_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_f64n_f64n_tensor_op_f64_rs_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_f64n_f64n_tensor_op_f64_rs_sm80.cu -new file mode 100644 -index 0000000..57c4f79 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_f64n_f64n_tensor_op_f64_rs_sm80.cu -@@ -0,0 +1,258 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYMM interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64n_rs_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64n_rs_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64n_rs_l_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64n_rs_u_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64n_rs_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_f64n_f64t_tensor_op_f64_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_f64n_f64t_tensor_op_f64_ls_sm80.cu -new file mode 100644 -index 0000000..6c82d76 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_f64n_f64t_tensor_op_f64_ls_sm80.cu -@@ -0,0 +1,258 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYMM interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64t_ls_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64t_ls_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64t_ls_l_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64t_ls_l_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64t_ls_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_f64n_f64t_tensor_op_f64_rs_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_f64n_f64t_tensor_op_f64_rs_sm80.cu -new file mode 100644 -index 0000000..a7a44f6 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_f64n_f64t_tensor_op_f64_rs_sm80.cu -@@ -0,0 +1,258 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYMM interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64t_rs_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64t_rs_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64t_rs_l_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64t_rs_l_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64n_f64t_rs_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_f64t_f64n_tensor_op_f64_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_f64t_f64n_tensor_op_f64_ls_sm80.cu -new file mode 100644 -index 0000000..64f4078 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_f64t_f64n_tensor_op_f64_ls_sm80.cu -@@ -0,0 +1,258 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYMM interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64n_ls_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64n_ls_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64n_ls_l_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64n_ls_u_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64n_ls_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_f64t_f64n_tensor_op_f64_rs_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_f64t_f64n_tensor_op_f64_rs_sm80.cu -new file mode 100644 -index 0000000..21cf9fd ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_f64t_f64n_tensor_op_f64_rs_sm80.cu -@@ -0,0 +1,258 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYMM interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64n_rs_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64n_rs_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64n_rs_l_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64n_rs_u_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64n_rs_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_f64t_f64t_tensor_op_f64_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_f64t_f64t_tensor_op_f64_ls_sm80.cu -new file mode 100644 -index 0000000..9fdd1a0 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_f64t_f64t_tensor_op_f64_ls_sm80.cu -@@ -0,0 +1,258 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYMM interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64t_ls_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64t_ls_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64t_ls_l_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64t_ls_l_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64t_ls_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_f64t_f64t_tensor_op_f64_rs_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_f64t_f64t_tensor_op_f64_rs_sm80.cu -new file mode 100644 -index 0000000..fce589d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_f64t_f64t_tensor_op_f64_rs_sm80.cu -@@ -0,0 +1,258 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYMM interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64t_rs_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64t_rs_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64t_rs_l_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64t_rs_l_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f64t_f64t_rs_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ ElementA, -+ LayoutA, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_tf32n_f32n_tensor_op_f32_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_tf32n_f32n_tensor_op_f32_ls_sm80.cu -new file mode 100644 -index 0000000..7e6e4b4 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_tf32n_f32n_tensor_op_f32_ls_sm80.cu -@@ -0,0 +1,489 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+////////////////////////////////////////////Test name////////////////////////////////////////////////// -+// -+// SM80_Device_Symm_{ElementA/B}{LayoutA/B}_{ElementC}{LayoutC}_{SideMode}_{FillMode}\ -+// _tensor_op_{ElementAccumulator}_align{AlignmentA}_align{AlignmentB} -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Symm_f32n_f32n_ls_l_tensor_op_f32_align1_align1, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_l_tensor_op_f32_align1_align4, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_l_tensor_op_f32_align1_align4, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_l_tensor_op_f32_align1_align4, 256x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_l_tensor_op_f32_align1_align4, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_l_tensor_op_f32_align1_align4, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_u_tensor_op_f32_align1_align4, 64x64x16_32x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 10, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_u_tensor_op_f32_align1_align4, 128x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_u_tensor_op_f32_align1_align4, 256x128x16_128x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_u_tensor_op_f32_align1_align4, 128x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_ls_u_tensor_op_f32_align1_align4, 256x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_tf32n_f32n_tensor_op_f32_rs_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_tf32n_f32n_tensor_op_f32_rs_sm80.cu -new file mode 100644 -index 0000000..cb9bf2e ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_tf32n_f32n_tensor_op_f32_rs_sm80.cu -@@ -0,0 +1,276 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+////////////////////////////////////////////Test name////////////////////////////////////////////////// -+// -+// SM80_Device_Symm_{ElementA/B}{LayoutA/B}_{ElementC}{LayoutC}_{SideMode}_{FillMode}\ -+// _tensor_op_{ElementAccumulator}_align{AlignmentA}_align{AlignmentB} -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_rs_u_tensor_op_f32_align1_align1, 64x128x32_32x64x32) { -+ -+using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAdd -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_rs_u_tensor_op_f32_align1_align1, 128x64x32_32x64x32) { -+ -+using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAdd -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_rs_l_tensor_op_f32_align1_align1, 64x128x32_32x64x32) { -+ -+using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAdd -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+TEST(SM80_Device_Symm_f32n_f32n_rs_u_tensor_op_f32_align1_align4, 64x128x32_32x64x32) { -+ -+using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_rs_u_tensor_op_f32_align1_align4, 128x64x32_32x64x32) { -+ -+using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32n_f32n_rs_l_tensor_op_f32_align1_align4, 64x128x32_32x64x32) { -+ -+using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/symm_tf32t_f32t_tensor_op_f32_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/symm_tf32t_f32t_tensor_op_f32_ls_sm80.cu -new file mode 100644 -index 0000000..a80084c ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/symm_tf32t_f32t_tensor_op_f32_ls_sm80.cu -@@ -0,0 +1,489 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/symm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_symm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+////////////////////////////////////////////Test name////////////////////////////////////////////////// -+// -+// SM80_Device_Symm_{ElementA/B}{LayoutA/B}_{ElementC}{LayoutC}_{SideMode}_{FillMode}\ -+// _tensor_op_{ElementAccumulator}_align{AlignmentA}_align{AlignmentB} -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Symm_f32t_f32t_ls_l_tensor_op_f32_align1_align1, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_l_tensor_op_f32_align1_align4, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_l_tensor_op_f32_align1_align4, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_l_tensor_op_f32_align1_align4, 256x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_l_tensor_op_f32_align1_align4, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_l_tensor_op_f32_align1_align4, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_u_tensor_op_f32_align1_align4, 64x64x16_32x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 10, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_u_tensor_op_f32_align1_align4, 128x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_u_tensor_op_f32_align1_align4, 256x128x16_128x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_u_tensor_op_f32_align1_align4, 128x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Symm_f32t_f32t_ls_u_tensor_op_f32_align1_align4, 256x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Symm = cutlass::gemm::device::Symm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf32n_cf32n_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf32n_cf32n_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..218fcd6 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf32n_cf32n_tensor_op_f32_sm80.cu -@@ -0,0 +1,150 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_cf32n_cf32n_l_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_cf32n_cf32n_u_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf32n_cf32n_tensor_op_fast_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf32n_cf32n_tensor_op_fast_f32_sm80.cu -new file mode 100644 -index 0000000..b559945 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf32n_cf32n_tensor_op_fast_f32_sm80.cu -@@ -0,0 +1,150 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_cf32n_cf32n_l_tensor_op_fast_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_cf32n_cf32n_u_tensor_op_fast_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf32n_cf32t_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf32n_cf32t_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..7090a0a ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf32n_cf32t_tensor_op_f32_sm80.cu -@@ -0,0 +1,150 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_cf32n_cf32t_l_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_cf32n_cf32t_u_tensor_op_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf32n_cf32t_tensor_op_fast_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf32n_cf32t_tensor_op_fast_f32_sm80.cu -new file mode 100644 -index 0000000..0c6efb1 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf32n_cf32t_tensor_op_fast_f32_sm80.cu -@@ -0,0 +1,150 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_cf32n_cf32t_l_tensor_op_fast_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_cf32n_cf32t_u_tensor_op_fast_f32, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64_cf64_tensor_op_f64_sm90.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64_cf64_tensor_op_f64_sm90.cu -new file mode 100644 -index 0000000..76d19f6 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64_cf64_tensor_op_f64_sm90.cu -@@ -0,0 +1,150 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Syr2k_cf64n_cf64n_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Syr2k_cf64n_cf64t_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64n_cf64n_tensor_op_f64_grouped_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64n_cf64n_tensor_op_f64_grouped_sm80.cu -new file mode 100644 -index 0000000..cea1691 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64n_cf64n_tensor_op_f64_grouped_sm80.cu -@@ -0,0 +1,308 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for grouped Rank2K interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/rank_2k_grouped.h" -+#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" -+#include "cutlass/gemm/device/rank_2k_grouped.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_grouped_rank_2k.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_cf64n_cf64n_l_tensor_op_cf64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_cf64n_cf64n_l_tensor_op_cf64, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_cf64n_cf64n_l_tensor_op_cf64, 64x32x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_cf64n_cf64n_l_tensor_op_cf64, 32x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_cf64n_cf64n_u_tensor_op_cf64, 32x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_cf64n_cf64n_u_tensor_op_cf64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_cf64n_cf64n_u_tensor_op_cf64, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64n_cf64n_tensor_op_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64n_cf64n_tensor_op_f64_sm80.cu -new file mode 100644 -index 0000000..3f7b03a ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64n_cf64n_tensor_op_f64_sm80.cu -@@ -0,0 +1,150 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_cf64n_cf64n_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_cf64n_cf64n_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64n_cf64t_tensor_op_f64_grouped_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64n_cf64t_tensor_op_f64_grouped_sm80.cu -new file mode 100644 -index 0000000..b3e2e27 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64n_cf64t_tensor_op_f64_grouped_sm80.cu -@@ -0,0 +1,168 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for grouped Rank2K interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/rank_2k_grouped.h" -+#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" -+#include "cutlass/gemm/device/rank_2k_grouped.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_grouped_rank_2k.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_cf64n_cf64t_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_cf64n_cf64t_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_cf64n_cf64t_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64n_cf64t_tensor_op_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64n_cf64t_tensor_op_f64_sm80.cu -new file mode 100644 -index 0000000..30dc4ba ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64n_cf64t_tensor_op_f64_sm80.cu -@@ -0,0 +1,150 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_cf64n_cf64t_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_cf64n_cf64t_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ 1, // AlignmentB -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64t_cf64n_tensor_op_f64_grouped_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64t_cf64n_tensor_op_f64_grouped_sm80.cu -new file mode 100644 -index 0000000..75ade1f ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64t_cf64n_tensor_op_f64_grouped_sm80.cu -@@ -0,0 +1,168 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for grouped Rank2K interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/rank_2k_grouped.h" -+#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" -+#include "cutlass/gemm/device/rank_2k_grouped.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_grouped_rank_2k.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_cf64n_cf64n_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_cf64n_cf64n_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_cf64n_cf64n_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64t_cf64t_tensor_op_f64_grouped_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64t_cf64t_tensor_op_f64_grouped_sm80.cu -new file mode 100644 -index 0000000..71e794d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_cf64t_cf64t_tensor_op_f64_grouped_sm80.cu -@@ -0,0 +1,168 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for grouped Rank2K interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/rank_2k_grouped.h" -+#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" -+#include "cutlass/gemm/device/rank_2k_grouped.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_grouped_rank_2k.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_cf64t_cf64t_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_cf64t_cf64t_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_cf64t_cf64t_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::complex; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_f32n_f32n_tensor_op_fast_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f32n_f32n_tensor_op_fast_f32_sm80.cu -new file mode 100644 -index 0000000..e310ac8 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f32n_f32n_tensor_op_fast_f32_sm80.cu -@@ -0,0 +1,132 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_f32n_f32n_l_tensor_op_fast_f32, 128x256x32_64x64x32) { -+ -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = float; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_f32n_f32n_u_tensor_op_fast_f32, 128x256x32_64x64x32) { -+ -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = float; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_f32t_f32n_tensor_op_fast_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f32t_f32n_tensor_op_fast_f32_sm80.cu -new file mode 100644 -index 0000000..e24a150 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f32t_f32n_tensor_op_fast_f32_sm80.cu -@@ -0,0 +1,133 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_f32t_f32n_l_tensor_op_fast_f32, 128x256x32_64x64x32) { -+ -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = float; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_f32t_f32n_u_tensor_op_fast_f32, 128x256x32_64x64x32) { -+ -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = float; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64_f64_tensor_op_f64_sm90.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64_f64_tensor_op_f64_sm90.cu -new file mode 100644 -index 0000000..f7aa84d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64_f64_tensor_op_f64_sm90.cu -@@ -0,0 +1,134 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Syr2k_f64n_f64n_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Syr2k_f64t_f64n_l_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64n_f64n_tensor_op_f64_grouped_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64n_f64n_tensor_op_f64_grouped_sm80.cu -new file mode 100644 -index 0000000..24f832d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64n_f64n_tensor_op_f64_grouped_sm80.cu -@@ -0,0 +1,483 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for grouped Rank2K interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/rank_2k_grouped.h" -+#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" -+#include "cutlass/gemm/device/rank_2k_grouped.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_grouped_rank_2k.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64n_f64n_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64n_f64n_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64n_f64n_l_tensor_op_f64, 64x32x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64n_f64n_l_tensor_op_f64, 32x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64n_f64n_l_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64n_f64n_l_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64n_f64n_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64n_f64n_u_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64n_f64n_u_tensor_op_f64, 64x32x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64n_f64n_u_tensor_op_f64, 32x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64n_f64n_u_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64n_f64n_u_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64n_f64n_tensor_op_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64n_f64n_tensor_op_f64_sm80.cu -new file mode 100644 -index 0000000..9cf9173 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64n_f64n_tensor_op_f64_sm80.cu -@@ -0,0 +1,253 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_f64n_f64n_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_f64n_f64n_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_f64n_f64n_l_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_f64n_f64n_l_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_f64n_f64n_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64n_f64t_tensor_op_f64_grouped_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64n_f64t_tensor_op_f64_grouped_sm80.cu -new file mode 100644 -index 0000000..e7b165f ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64n_f64t_tensor_op_f64_grouped_sm80.cu -@@ -0,0 +1,273 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for grouped Rank2K interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/rank_2k_grouped.h" -+#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" -+#include "cutlass/gemm/device/rank_2k_grouped.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_grouped_rank_2k.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64n_f64t_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64n_f64t_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64n_f64t_l_tensor_op_f64, 64x32x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64n_f64t_l_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64n_f64t_l_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64n_f64t_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64n_f64t_tensor_op_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64n_f64t_tensor_op_f64_sm80.cu -new file mode 100644 -index 0000000..e3fb6ee ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64n_f64t_tensor_op_f64_sm80.cu -@@ -0,0 +1,253 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_f64n_f64t_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_f64n_f64t_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_f64n_f64t_l_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_f64n_f64t_l_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_f64n_f64t_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64t_f64n_tensor_op_f64_grouped_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64t_f64n_tensor_op_f64_grouped_sm80.cu -new file mode 100644 -index 0000000..b53b710 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64t_f64n_tensor_op_f64_grouped_sm80.cu -@@ -0,0 +1,308 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for grouped Rank2K interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/rank_2k_grouped.h" -+#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" -+#include "cutlass/gemm/device/rank_2k_grouped.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_grouped_rank_2k.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64t_f64n_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64t_f64n_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64t_f64n_l_tensor_op_f64, 64x32x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64t_f64n_l_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64t_f64n_l_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64t_f64n_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64t_f64n_u_tensor_op_f64, 64x32x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64t_f64n_tensor_op_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64t_f64n_tensor_op_f64_sm80.cu -new file mode 100644 -index 0000000..f720f88 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64t_f64n_tensor_op_f64_sm80.cu -@@ -0,0 +1,253 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_f64t_f64n_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_f64t_f64n_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_f64t_f64n_l_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_f64t_f64n_l_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_f64t_f64n_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64t_f64t_tensor_op_f64_grouped_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64t_f64t_tensor_op_f64_grouped_sm80.cu -new file mode 100644 -index 0000000..f9292e7 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_f64t_f64t_tensor_op_f64_grouped_sm80.cu -@@ -0,0 +1,308 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for grouped Rank2K interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/rank_2k_grouped.h" -+#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" -+#include "cutlass/gemm/device/rank_2k_grouped.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_grouped_rank_2k.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64t_f64t_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64t_f64t_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64t_f64t_l_tensor_op_f64, 32x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64t_f64t_l_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64t_f64t_l_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64t_f64t_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2kGrouped_f64t_f64t_u_tensor_op_f64, 32x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< -+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, -+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, -+ ElementC, LayoutC, cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ cutlass::arch::OpMultiplyAdd, -+ cutlass::BlasMode::kSymmetric>::Rank2Kkernel; -+ -+ using Rank2K = cutlass::gemm::device::Rank2KGrouped; -+ -+ test::gemm::device::TestbedGrouped testbed; -+ bool passed = testbed.run(24); -+ EXPECT_TRUE(passed); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_tf32n_f32n_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_tf32n_f32n_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..c6bb3b1 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_tf32n_f32n_tensor_op_f32_sm80.cu -@@ -0,0 +1,132 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_tf32n_f32n_l_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = float; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_tf32n_f32n_u_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = float; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syr2k_tf32t_f32n_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syr2k_tf32t_f32n_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..25e62fe ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syr2k_tf32t_f32n_tensor_op_f32_sm80.cu -@@ -0,0 +1,133 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank2k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_tf32t_f32n_l_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = float; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syr2k_tf32t_f32n_u_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = float; -+ -+ using Rank2K = cutlass::gemm::device::Rank2K< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syrk_cf32n_cf32n_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syrk_cf32n_cf32n_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..dcd963b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syrk_cf32n_cf32n_tensor_op_f32_sm80.cu -@@ -0,0 +1,137 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank_k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_cf32n_cf32n_l_tensor_op_f32, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_cf32n_cf32n_u_tensor_op_f32, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syrk_cf32n_cf32n_tensor_op_fast_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syrk_cf32n_cf32n_tensor_op_fast_f32_sm80.cu -new file mode 100644 -index 0000000..007faad ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syrk_cf32n_cf32n_tensor_op_fast_f32_sm80.cu -@@ -0,0 +1,137 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank_k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_cf32n_cf32n_l_tensor_op_fast_f32, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_cf32n_cf32n_u_tensor_op_fast_f32, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syrk_cf32n_cf32t_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syrk_cf32n_cf32t_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..5d90211 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syrk_cf32n_cf32t_tensor_op_f32_sm80.cu -@@ -0,0 +1,137 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank_k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_cf32n_cf32t_l_tensor_op_f32, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_cf32n_cf32t_u_tensor_op_f32, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syrk_cf32n_cf32t_tensor_op_fast_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syrk_cf32n_cf32t_tensor_op_fast_f32_sm80.cu -new file mode 100644 -index 0000000..85301bc ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syrk_cf32n_cf32t_tensor_op_fast_f32_sm80.cu -@@ -0,0 +1,137 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank_k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_cf32n_cf32t_l_tensor_op_fast_f32, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_cf32n_cf32t_u_tensor_op_fast_f32, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syrk_cf64_cf64_tensor_op_f64_sm90.cu b/3rdparty/cutlass/test/unit/gemm/device/syrk_cf64_cf64_tensor_op_f64_sm90.cu -new file mode 100644 -index 0000000..98da67d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syrk_cf64_cf64_tensor_op_f64_sm90.cu -@@ -0,0 +1,136 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank_k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Syrk_cf64n_cf64n_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Syrk_cf64n_cf64t_l_tensor_op_f64_gaussian, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddGaussianComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syrk_cf64n_cf64n_tensor_op_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syrk_cf64n_cf64n_tensor_op_f64_sm80.cu -new file mode 100644 -index 0000000..3888116 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syrk_cf64n_cf64n_tensor_op_f64_sm80.cu -@@ -0,0 +1,136 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank_k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_cf64n_cf64n_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_cf64n_cf64n_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syrk_cf64n_cf64t_tensor_op_f64_gaussian_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syrk_cf64n_cf64t_tensor_op_f64_gaussian_sm80.cu -new file mode 100644 -index 0000000..b826f05 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syrk_cf64n_cf64t_tensor_op_f64_gaussian_sm80.cu -@@ -0,0 +1,95 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank_k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_cf64n_cf64t_l_tensor_op_f64_gaussian, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddGaussianComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syrk_cf64n_cf64t_tensor_op_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syrk_cf64n_cf64t_tensor_op_f64_sm80.cu -new file mode 100644 -index 0000000..2c455e7 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syrk_cf64n_cf64t_tensor_op_f64_sm80.cu -@@ -0,0 +1,136 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank_k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_cf64n_cf64t_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_cf64n_cf64t_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = cutlass::complex; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = cutlass::complex; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = cutlass::complex; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, // kStages -+ 1, // AlignmentA -+ false, // SplitKSerial -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone, -+ cutlass::BlasMode::kSymmetric -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syrk_f32n_f32t_tensor_op_fast_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syrk_f32n_f32t_tensor_op_fast_f32_sm80.cu -new file mode 100644 -index 0000000..8f4e9f9 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syrk_f32n_f32t_tensor_op_fast_f32_sm80.cu -@@ -0,0 +1,541 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank_k_universal.h" -+ -+#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32n_f32t_l_tensor_op_fast_f32, 128x256x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32n_f32t_l_tensor_op_fast_f32, 256x128x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32n_f32t_l_tensor_op_fast_f32, 64x256x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32n_f32t_l_tensor_op_fast_f32, 256x64x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32n_f32t_l_tensor_op_fast_f32, 128x128x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+TEST(SM80_Device_Syrk_f32n_f32t_l_tensor_op_fast_f32, 64x128x32_32x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+TEST(SM80_Device_Syrk_f32n_f32t_l_tensor_op_fast_f32, 128x64x32_64x32x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32n_f32t_l_tensor_op_fast_f32, 128x128x16_64x64x16) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32n_f32t_l_tensor_op_fast_f32, 64x128x16_32x64x16) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32n_f32t_u_tensor_op_fast_f32, 128x256x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32n_f32t_u_tensor_op_fast_f32, 256x128x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32n_f32t_u_tensor_op_fast_f32, 64x256x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32n_f32t_u_tensor_op_fast_f32, 256x64x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32n_f32t_u_tensor_op_fast_f32, 128x128x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syrk_f32t_f32t_tensor_op_fast_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syrk_f32t_f32t_tensor_op_fast_f32_sm80.cu -new file mode 100644 -index 0000000..4dbd5b0 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syrk_f32t_f32t_tensor_op_fast_f32_sm80.cu -@@ -0,0 +1,541 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank_k_universal.h" -+ -+#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32t_f32t_l_tensor_op_fast_f32, 128x256x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32t_f32t_l_tensor_op_fast_f32, 256x128x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32t_f32t_l_tensor_op_fast_f32, 64x256x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32t_f32t_l_tensor_op_fast_f32, 256x64x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32t_f32t_l_tensor_op_fast_f32, 128x128x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+TEST(SM80_Device_Syrk_f32t_f32t_l_tensor_op_fast_f32, 64x128x32_32x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+TEST(SM80_Device_Syrk_f32t_f32t_l_tensor_op_fast_f32, 128x64x32_64x32x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32t_f32t_l_tensor_op_fast_f32, 128x128x16_64x64x16) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32t_f32t_l_tensor_op_fast_f32, 64x128x16_32x64x16) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32t_f32t_u_tensor_op_fast_f32, 128x256x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32t_f32t_u_tensor_op_fast_f32, 256x128x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32t_f32t_u_tensor_op_fast_f32, 64x256x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32t_f32t_u_tensor_op_fast_f32, 256x64x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f32t_f32t_u_tensor_op_fast_f32, 128x128x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syrk_f64_f64_tensor_op_f64_sm90.cu b/3rdparty/cutlass/test/unit/gemm/device/syrk_f64_f64_tensor_op_f64_sm90.cu -new file mode 100644 -index 0000000..8fe7627 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syrk_f64_f64_tensor_op_f64_sm90.cu -@@ -0,0 +1,126 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank_k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Syrk_f64n_f64t_l_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Syrk_f64t_f64n_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syrk_f64n_f64t_tensor_op_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syrk_f64n_f64t_tensor_op_f64_sm80.cu -new file mode 100644 -index 0000000..62d29af ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syrk_f64n_f64t_tensor_op_f64_sm80.cu -@@ -0,0 +1,237 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank_k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f64n_f64t_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f64n_f64t_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f64n_f64t_l_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f64n_f64t_l_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f64n_f64t_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = double; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+ -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syrk_f64t_f64n_tensor_op_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syrk_f64t_f64n_tensor_op_f64_sm80.cu -new file mode 100644 -index 0000000..0ad9dbb ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syrk_f64t_f64n_tensor_op_f64_sm80.cu -@@ -0,0 +1,301 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank_k_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f64t_f64n_l_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f64t_f64n_l_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f64t_f64n_l_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f64t_f64n_l_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f64t_f64n_u_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f64t_f64n_u_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_f64t_f64n_u_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ using ElementAccumulator = double; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syrk_tf32n_f32t_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syrk_tf32n_f32t_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..ba96ad5 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syrk_tf32n_f32t_tensor_op_f32_sm80.cu -@@ -0,0 +1,541 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank_k_universal.h" -+ -+#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32n_f32t_l_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32n_f32t_l_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32n_f32t_l_tensor_op_f32, 64x256x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32n_f32t_l_tensor_op_f32, 256x64x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32n_f32t_l_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+TEST(SM80_Device_Syrk_tf32n_f32t_l_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+TEST(SM80_Device_Syrk_tf32n_f32t_l_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32n_f32t_l_tensor_op_f32, 128x128x16_64x64x16) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32n_f32t_l_tensor_op_f32, 64x128x16_32x64x16) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32n_f32t_u_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32n_f32t_u_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32n_f32t_u_tensor_op_f32, 64x256x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32n_f32t_u_tensor_op_f32, 256x64x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32n_f32t_u_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/syrk_tf32t_f32t_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/syrk_tf32t_f32t_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..a1466d6 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/syrk_tf32t_f32t_tensor_op_f32_sm80.cu -@@ -0,0 +1,541 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide SYRK interface -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_rank_k_universal.h" -+ -+#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32t_f32t_l_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32t_f32t_l_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32t_f32t_l_tensor_op_f32, 64x256x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32t_f32t_l_tensor_op_f32, 256x64x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32t_f32t_l_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+TEST(SM80_Device_Syrk_tf32t_f32t_l_tensor_op_f32, 64x128x32_32x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+TEST(SM80_Device_Syrk_tf32t_f32t_l_tensor_op_f32, 128x64x32_64x32x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32t_f32t_l_tensor_op_f32, 128x128x16_64x64x16) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32t_f32t_l_tensor_op_f32, 64x128x16_32x64x16) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kLower, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 6 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32t_f32t_u_tensor_op_f32, 128x256x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32t_f32t_u_tensor_op_f32, 256x128x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32t_f32t_u_tensor_op_f32, 64x256x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32t_f32t_u_tensor_op_f32, 256x64x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 64, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Syrk_tf32t_f32t_u_tensor_op_f32, 128x128x32_64x64x32) { -+ -+ using ElementA = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ using ElementAccumulator = float; -+ -+ using RankK = cutlass::gemm::device::RankK< -+ ElementA, -+ LayoutA, -+ ElementC, -+ LayoutC, -+ cutlass::FillMode::kUpper, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementC, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/testbed.h b/3rdparty/cutlass/test/unit/gemm/device/testbed.h -new file mode 100644 -index 0000000..dc21f41 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/testbed.h -@@ -0,0 +1,600 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed_utils.h" -+#include "testbed_universal.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct Testbed { -+ -+ using ElementA = typename Gemm::ElementA; -+ using ElementB = typename Gemm::ElementB; -+ using ElementC = typename Gemm::ElementC; -+ using ElementAccumulator = typename Gemm::ElementAccumulator; -+ using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; -+ -+ /// Initialization -+ typename Gemm::LayoutA::Stride stride_factor_A; -+ typename Gemm::LayoutB::Stride stride_factor_B; -+ typename Gemm::LayoutC::Stride stride_factor_C; -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D; -+ cutlass::HostTensor reference_D; -+ -+ // -+ // Methods -+ // -+ -+ Testbed( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ stride_factor_A(typename Gemm::LayoutA::Stride()), -+ stride_factor_B(typename Gemm::LayoutB::Stride()), -+ stride_factor_C(typename Gemm::LayoutC::Stride()), -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } -+ -+ Testbed( -+ typename Gemm::LayoutA::Stride stride_factor_A_, -+ typename Gemm::LayoutB::Stride stride_factor_B_, -+ typename Gemm::LayoutC::Stride stride_factor_C_, -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ stride_factor_A(stride_factor_A_), -+ stride_factor_B(stride_factor_B_), -+ stride_factor_C(stride_factor_C_), -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ double scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope_max, scope_min, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else { -+ // TODO: Implement the rest -+ EXPECT_TRUE(false) << "Not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Initializes data structures -+ void initialize(cutlass::gemm::GemmCoord problem_size) { -+ // -+ // Allocate the GEMM workspace -+ // -+ -+ tensor_A.resize(problem_size.mk(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mk(), stride_factor_A)); -+ tensor_B.resize(problem_size.kn(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.kn(), stride_factor_B)); -+ tensor_C.resize(problem_size.mn(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mn(), stride_factor_C)); -+ tensor_D.resize(problem_size.mn(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mn(), stride_factor_C)); -+ reference_D.resize(problem_size.mn(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mn(), stride_factor_C), false); -+ -+ EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); -+ EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); -+ EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); -+ -+ // It is possible to randomly initialize to all zeros, so override this with non-zeros -+ // in the upper left corner of each operand. -+ tensor_A.host_view().at({0, 0}) = typename Gemm::ElementA(1); -+ tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1); -+ tensor_C.host_view().at(cutlass::make_Coord(0, 0)) = typename Gemm::ElementC(1); -+ -+ cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D.sync_device(); -+ } -+ -+ /// Compares computed reference with device reference and outputs to a file if incorrect -+ bool compare_reference( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ tensor_D.sync_host(); -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); -+ -+ if (tensor_D.size() > 1) -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); -+ -+ if (reference_D.size() > 1) -+ EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); -+ -+ bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); -+ -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ -+ std::stringstream fname; -+ -+ fname << "error_Gemm_device_" -+ << problem_size.m() << "x" -+ << problem_size.n() << "x" -+ << problem_size.k() << "_" -+ << Gemm::ThreadblockShape::kM << "x" -+ << Gemm::ThreadblockShape::kN << "x" -+ << Gemm::ThreadblockShape::kK << "_" -+ << Gemm::WarpShape::kM << "x" -+ << Gemm::WarpShape::kN << "x" -+ << Gemm::WarpShape::kK << ".txt"; -+ -+ std::ofstream file(fname.str()); -+ -+ file -+ << "problem: " << problem_size -+ << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; -+ -+ file -+ << "A =\n" << tensor_A.host_view() -+ << "\nB =\n" << tensor_B.host_view() -+ << "\nC =\n" << tensor_C.host_view() -+ << "\n\nReference =\n" << reference_D.host_view() -+ << "\nComputed =\n" << tensor_D.host_view(); -+ } -+ -+ return passed; -+ } -+ -+ /// Verifies the result is a GEMM -+ bool verify( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ // -+ // Verify -+ // -+ -+ cutlass::reference::host::Gemm< -+ typename Gemm::ElementA, typename Gemm::LayoutA, -+ typename Gemm::ElementB, typename Gemm::LayoutB, -+ typename Gemm::ElementC, typename Gemm::LayoutC, ElementCompute, -+ ElementAccumulator, typename Gemm::Operator> -+ reference_gemm; -+ -+ reference_gemm( -+ problem_size, -+ alpha, -+ tensor_A.host_ref(), -+ tensor_B.host_ref(), -+ beta, -+ reference_D.host_ref(), -+ ElementAccumulator(0) -+ ); -+ -+ if (Relu) { -+ for (int i = 0; i < problem_size.m(); ++i) { -+ for (int j = 0; j < problem_size.n(); ++j) { -+ reference_D.at(cutlass::MatrixCoord(i, j)) = -+ ((ElementCompute)reference_D.at(cutlass::MatrixCoord(i, j)) < (ElementCompute)0) -+ ? (typename Gemm::ElementC)0 -+ : reference_D.at(cutlass::MatrixCoord(i, j)); -+ } -+ } -+ } -+ -+ return compare_reference(problem_size, alpha, beta); -+ } -+ -+ /// Determine if the CUDA device is sufficient to run the kernel -+ bool sufficient() const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename Gemm::GemmKernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ -+ /// Executes one test -+ bool run( -+ cutlass::gemm::GemmCoord problem_size, -+ int split_k_slices = 1, -+ ElementCompute alpha = ElementCompute(1), -+ ElementCompute beta = ElementCompute(0)) -+ { -+/* -+ std::cout << "\n-----------------------\n"; -+ std::cout << "problem size: " << problem_size << "\n"; -+ std::cout << "split_k_slices: " << split_k_slices << "\n"; -+ std::cout << "alpha: " << alpha << "\n"; -+ std::cout << "beta: " << beta << "\n"; -+ std::cout << "-----------------------\n\n"; -+*/ -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ -+ this->initialize(problem_size); -+ -+ // -+ // Initialize the GEMM operator -+ // -+ -+ typename Gemm::Arguments arguments{ -+ problem_size, -+ tensor_A.device_ref(), -+ tensor_B.device_ref(), -+ tensor_C.device_ref(), -+ tensor_D.device_ref(), -+ {alpha, beta}, -+ split_k_slices -+ }; -+ -+ Gemm gemm_op; -+ -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); -+ -+ if (status != cutlass::Status::kSuccess) { -+ cudaError_t error = cudaGetLastError(); -+ std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; -+ return true; -+ } -+ -+ // -+ // Run the GEMM -+ // -+ -+ status = gemm_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Verify -+ // -+ -+ bool passed = this->verify(problem_size, alpha, beta); -+ -+ if (!passed) { -+ std::cout << "Error with split_k_slices = " << split_k_slices << ", alpha: " << alpha << std::endl; -+ } -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+bool TestAllGemmBasic( -+ const typename Gemm::LayoutA::Stride& stride_factor_A = typename Gemm::LayoutA::Stride(), -+ const typename Gemm::LayoutB::Stride& stride_factor_B = typename Gemm::LayoutB::Stride(), -+ const typename Gemm::LayoutC::Stride& stride_factor_C = typename Gemm::LayoutC::Stride()) { -+ bool passed = true; -+ -+ int const kMinimumOperandElementSize = -+ std::min( -+ int(cutlass::sizeof_bits::value), -+ int(cutlass::sizeof_bits::value)); -+ -+ int const kAlignment = cutlass::platform::is_same< -+ typename Gemm::OperatorClass, -+ cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; -+ -+ // int8_t gemm alignment constraints -+ int const kAlignmentM = cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value ? 4 : kAlignment; -+ -+ int const kAlignmentN = cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value ? 4 : kAlignment; -+ -+ int const kAlignmentK = cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value && -+ (cutlass::platform::is_same::value || -+ cutlass::platform::is_same::value) ? 4 : kAlignment; -+ -+ int problem_size_m[] = {kAlignmentM, 512 - 3 * kAlignmentM}; -+ -+ int problem_size_n[] = {kAlignmentN, 512 - 2 * kAlignmentN}; -+ -+ int problem_size_k[] = { -+ kAlignmentK, Gemm::ThreadblockShape::kK * (Gemm::kStages + 1) - kAlignmentK}; -+ -+ int split_k_slices[] = { -+ 1, 2, 3 -+ }; -+ -+ double problem_alpha[] = { -+ 1 -+ }; -+ -+ double problem_beta[] = { -+ 2.0 -+ }; -+ -+ Testbed testbed(stride_factor_A, stride_factor_B, stride_factor_C); -+ -+ using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; -+ -+ for (int m : problem_size_m) { -+ for (int n : problem_size_n) { -+ for (int k : problem_size_k) { -+ for (int split_k : split_k_slices) { -+ -+ if (!Gemm::kSplitKSerial && split_k > 1) { -+ continue; -+ } -+ -+ if (split_k > 1 && k / Gemm::ThreadblockShape::kK < split_k) { -+ continue; -+ } -+ -+ for (auto alpha : problem_alpha) { -+ for (auto beta : problem_beta) { -+ -+ cutlass::gemm::GemmCoord problem_size(m, n, k); -+ passed = testbed.run( -+ problem_size, -+ split_k, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta) -+ ); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ return passed; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+bool TestAllGemm( -+ const typename Gemm::LayoutA::Stride& stride_factor_A, -+ const typename Gemm::LayoutB::Stride& stride_factor_B = typename Gemm::LayoutB::Stride(), -+ const typename Gemm::LayoutC::Stride& stride_factor_C = typename Gemm::LayoutC::Stride()) -+{ -+ // Test basic GEMM with non-default stride factors -+ return TestAllGemmBasic(stride_factor_A, stride_factor_B, stride_factor_C); -+} -+ -+template -+bool TestAllGemm() -+{ -+#ifdef NDEBUG -+ // Non-debug builds also test basic GEMM with default stride factors -+ if (!TestAllGemmBasic()) { -+ return false; -+ } -+#endif // NDEBUG -+ -+ // Test universal GEMM -+#if 0 -+ // Define the universal kernel -+ using UniversalKernel = cutlass::gemm::kernel::GemmUniversal< -+ typename Gemm::GemmKernel::Mma, // Mma -+ typename Gemm::GemmKernel::Epilogue, // Epilogue -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<> // ThreadblockSwizzle -+ >; -+#else -+ // Define the streamk universal kernel -+ using UniversalKernel = cutlass::gemm::kernel::GemmUniversalStreamk< -+ typename Gemm::GemmKernel::Mma, // Mma -+ typename Gemm::GemmKernel::Epilogue, // Epilogue -+ cutlass::gemm::threadblock::ThreadblockSwizzleStreamK // ThreadblockSwizzle -+ >; -+#endif -+ -+ // Define the universal adaptor -+ using UniversalGemm = cutlass::gemm::device::GemmUniversalAdapter; -+ -+ // Test universal GEMM -+ return TestAllGemmUniversal(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template -+bool TestGemmPerf(int iterations = 1) { -+ bool passed = true; -+ -+ int problem_size_m[] = { 2048 }; -+ -+ int problem_size_n[] = { 4352 }; -+ -+ int problem_size_k[] = { 4096 }; -+ -+ int split_k_slices[] = { 1 }; -+ double problem_alpha[] = { 1 }; -+ double problem_beta[] = { 0.0 }; -+ -+ Testbed testbed; -+ -+ using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; -+ -+ for (int m : problem_size_m) { -+ for (int n : problem_size_n) { -+ for (int k : problem_size_k) { -+ for (int split_k : split_k_slices) { -+ -+ if (!Gemm::kSplitKSerial && split_k > 1) { -+ continue; -+ } -+ -+ for (auto alpha : problem_alpha) { -+ for (auto beta : problem_beta) { -+ -+ cutlass::gemm::GemmCoord problem_size(m, n, k); -+ -+ for (int i = 0; i < iterations; i++){ -+ passed = testbed.run( -+ problem_size, -+ split_k, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta) -+ ); -+ } -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ return passed; -+} -+ -+} // namespace device -+} // namespace gemm -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/testbed_complex.h b/3rdparty/cutlass/test/unit/gemm/device/testbed_complex.h -new file mode 100644 -index 0000000..244bc06 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/testbed_complex.h -@@ -0,0 +1,294 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/gemm_complex.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct TestbedComplex : public Testbed { -+ -+ using Base = Testbed; -+ using ElementA = typename Gemm::ElementA; -+ using ElementB = typename Gemm::ElementB; -+ using ElementC = typename Gemm::ElementC; -+ using ElementAccumulator = typename Gemm::ElementAccumulator; -+ using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; -+ -+ -+ // -+ // Methods -+ // -+ -+ TestbedComplex( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ Base(init_A_, init_B_, init_C_, seed_) { } -+ -+ -+ /// Verifies the result is a GEMM -+ bool verify( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ // -+ // Verify -+ // -+ -+ cutlass::reference::host::GemmComplex( -+ problem_size, -+ alpha, -+ this->tensor_A.host_ref(), -+ Gemm::kTransformA, -+ this->tensor_B.host_ref(), -+ Gemm::kTransformB, -+ beta, -+ this->tensor_C.host_ref(), -+ this->reference_D.host_ref(), -+ ElementAccumulator(0) -+ ); -+ -+ return this->compare_reference(problem_size, alpha, beta); -+ } -+ -+ /// Returns true if the CUDA device is sufficient to execute the kernel. -+ bool sufficient() const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename Gemm::GemmKernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::gemm::GemmCoord problem_size, -+ int split_k_slices = 1, -+ ElementCompute alpha = ElementCompute(1), -+ ElementCompute beta = ElementCompute(0)) { -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ -+ // -+ // Initialize workspace -+ // -+ -+ this->initialize(problem_size); -+ -+ -+ // -+ // Initialize the GEMM operator -+ // -+ -+ typename Gemm::Arguments arguments{ -+ problem_size, -+ this->tensor_A.device_ref(), -+ this->tensor_B.device_ref(), -+ this->tensor_C.device_ref(), -+ this->tensor_D.device_ref(), -+ {alpha, beta}, -+ split_k_slices -+ }; -+ -+ Gemm gemm_op; -+ -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Run the GEMM -+ // -+ -+ status = gemm_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Verify -+ // -+ -+ bool passed = this->verify(problem_size, alpha, beta); -+ -+ if (!passed) { -+ std::cout << "Error with split_k_slices = " << split_k_slices << ", alpha: " << alpha << std::endl; -+ } -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+bool TestAllGemmComplex() { -+ bool passed = true; -+ -+ using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; -+ -+ int const kMinimumOperandElementSize = -+ std::min( -+ int(cutlass::sizeof_bits::value), -+ int(cutlass::sizeof_bits::value)); -+ -+ int const kAlignment = -+ cutlass::platform::is_same< -+ typename Gemm::OperatorClass, -+ cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; -+ -+ int problem_size_m[] = { -+ kAlignment, 512 - 3*kAlignment -+ }; -+ -+ int problem_size_n[] = { -+ kAlignment, 512 - 2*kAlignment -+ }; -+ -+ int problem_size_k[] = { -+ kAlignment, 128 - kAlignment -+ }; -+ -+ int split_k_slices[] = { -+ 1, 2, 3 -+ }; -+ -+ double problem_alpha[] = { -+ 1 -+ }; -+ -+ double problem_beta[] = { -+ 2.0 -+ }; -+ -+ TestbedComplex testbed; -+ -+ for (int m : problem_size_m) { -+ for (int n : problem_size_n) { -+ for (int k : problem_size_k) { -+ for (int split_k : split_k_slices) { -+ -+ if (!Gemm::kSplitKSerial && split_k > 1) { -+ continue; -+ } -+ -+ for (auto alpha : problem_alpha) { -+ for (auto beta : problem_beta) { -+ -+ cutlass::gemm::GemmCoord problem_size(m, n, k); -+ -+ passed = testbed.run( -+ problem_size, -+ split_k, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta) -+ ); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ return passed; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/testbed_gemm_with_broadcast.h b/3rdparty/cutlass/test/unit/gemm/device/testbed_gemm_with_broadcast.h -new file mode 100644 -index 0000000..10d5d3f ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/testbed_gemm_with_broadcast.h -@@ -0,0 +1,657 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/gemm_complex.h" -+ -+#include "testbed_utils.h" -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct GemmWithBroadcastReferenceOp { -+ -+ using OutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; -+ -+ using ElementCompute = typename OutputOp::ElementCompute; -+ using ElementZ = typename OutputOp::ElementZ; -+ using ElementT = typename OutputOp::ElementT; -+ -+ typename OutputOp::BinaryOp binary_op; -+ typename OutputOp::ElementwiseOp elementwise_op; -+ -+ GemmWithBroadcastReferenceOp() { } -+ -+ void operator()(ElementZ &Z, ElementT &T, ElementCompute gemm, ElementCompute bias) { -+ -+ ElementCompute t_full = binary_op(gemm, bias); -+ T = ElementT(t_full); -+ -+ ElementCompute z_full = elementwise_op(t_full); -+ Z = ElementZ(z_full); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Fused testbed -+// -+// Y = GEMM(AB, C) -+// -+// T[i, j] = BinaryOp(Y[i, j], Broadcast[i]) -+// -+// Z[i, j] = Elementwise(T[i, j]) -+// -+ -+template < -+ typename Gemm, -+ typename ReferenceOp = GemmWithBroadcastReferenceOp -+> -+struct TestbedGemmWithBroadcast { -+ -+ using ElementA = typename Gemm::ElementA; -+ using ElementB = typename Gemm::ElementB; -+ using OutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; -+ using ElementC = typename Gemm::ElementC; -+ using ElementAccumulator = typename Gemm::ElementAccumulator; -+ using ElementCOmpute = typename OutputOp::ElementCompute; -+ using ElementZ = typename OutputOp::ElementZ; -+ using ElementT = typename OutputOp::ElementT; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A; // Input A -+ cutlass::HostTensor tensor_B; // Input B -+ cutlass::HostTensor tensor_C; // Input C -+ cutlass::HostTensor tensor_Broadcast; // Input Broadcast -+ -+ cutlass::HostTensor tensor_Z; -+ cutlass::HostTensor tensor_T; -+ -+ cutlass::HostTensor tensor_C_ref; -+ cutlass::HostTensor tensor_Y_ref; -+ cutlass::HostTensor tensor_Z_ref; -+ cutlass::HostTensor tensor_T_ref; -+ -+ -+ // -+ // Methods -+ // -+ -+ TestbedGemmWithBroadcast( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ double scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope_max, scope_min, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else { -+ // TODO: Implement the rest -+ EXPECT_TRUE(false) << "Not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Initializes data structures -+ void initialize(cutlass::gemm::GemmCoord problem_size) { -+ // -+ // Allocate the GEMM workspace -+ // -+ -+ tensor_A.resize(problem_size.mk()); -+ tensor_B.resize(problem_size.kn()); -+ tensor_C.resize(problem_size.mn()); -+ tensor_Z.resize(problem_size.mn()); -+ tensor_T.resize(problem_size.mn()); -+ tensor_Broadcast.resize({ -+ problem_size.m(), -+ 1 -+ }); -+ -+ tensor_C_ref.resize(problem_size.mn()); -+ tensor_Y_ref.resize(problem_size.mn()); -+ tensor_Z_ref.resize(problem_size.mn()); -+ tensor_T_ref.resize(problem_size.mn()); -+ -+ EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); -+ EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); -+ EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); -+ EXPECT_TRUE(initialize_tensor(tensor_Broadcast.host_view(), init_C, seed + 2020)); -+ -+ // It is possible to randomly initialize to all zeros, so override this with non-zeros -+ // in the upper left corner of each operand. -+ tensor_A.host_view().at({0, 0}) = typename Gemm::ElementA(1); -+ tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1); -+ tensor_C.host_view().at({0, 0}) = typename Gemm::ElementC(1); -+ -+ for (int m = 0; m < tensor_C_ref.extent().row(); ++m) { -+ for (int n = 0; n < tensor_C_ref.extent().column(); ++n) { -+ tensor_C_ref.at({m, n}) = ElementAccumulator(tensor_C.at({m, n})); -+ } -+ } -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_Broadcast.sync_device(); -+ -+ tensor_Z.sync_device(); -+ tensor_T.sync_device(); -+ } -+ -+ /// Compares computed reference with device reference and outputs to a file if incorrect -+ bool compare_reference( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementAccumulator alpha, -+ ElementAccumulator beta) { -+ -+ tensor_Z.sync_host(); -+ tensor_T.sync_host(); -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Z.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_T.host_view()), 0); -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Z_ref.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_T_ref.host_view()), 0); -+ -+ bool passed = true; -+ float norm_diff = 0; -+ -+ if (OutputOp::kStoreZ) { -+ norm_diff = cutlass::reference::host::TensorNormDiff(tensor_Z_ref.host_view(), tensor_Z.host_view(), float()); -+ passed = (norm_diff <= 0.1f); -+ EXPECT_LT(norm_diff, 0.1f) << " tensor_Z is incorrect"; -+ } -+ -+ if (OutputOp::kStoreT) { -+ -+ norm_diff = cutlass::reference::host::TensorNormDiff(tensor_T_ref.host_view(), tensor_T.host_view(), float()); -+ passed = (passed && (norm_diff <= 0.1f)); -+ -+ EXPECT_LT(norm_diff, 0.1f) << " tensor_T is incorrect"; -+ } -+ -+ -+ if (!passed) { -+ -+ /* -+ std::stringstream fname; -+ -+ fname << "error_Gemm_device_" -+ << problem_size.m() << "x" -+ << problem_size.n() << "x" -+ << problem_size.k() << "_" -+ << Gemm::ThreadblockShape::kM << "x" -+ << Gemm::ThreadblockShape::kN << "x" -+ << Gemm::ThreadblockShape::kK << "_" -+ << Gemm::WarpShape::kM << "x" -+ << Gemm::WarpShape::kN << "x" -+ << Gemm::WarpShape::kK << ".txt"; -+ -+ std::ofstream file(fname.str()); -+ */ -+ -+ std::ofstream file("errors_testbed_gemm_with_broadcast.txt"); -+ -+ -+ file -+ << "problem: " << problem_size -+ << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; -+ -+ file -+ << "A =\n" << tensor_A.host_view() -+ << "\nB =\n" << tensor_B.host_view() -+ << "\nC =\n" << tensor_C.host_view() -+ << "\nZ =\n" << tensor_Z.host_view() -+ << "\nT =\n" << tensor_T.host_view() -+ << "\n\n" -+ << "\nY_ref =\n" << tensor_Y_ref.host_view() -+ << "\nZ_ref =\n" << tensor_Z_ref.host_view() -+ << "\nT_ref =\n" << tensor_T_ref.host_view(); -+ } -+ -+ return passed; -+ } -+ -+ /// Verifies the result is a GEMM -+ bool verify( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementAccumulator alpha, -+ ElementAccumulator beta) { -+ -+ // -+ // Verify -+ // -+ -+ cutlass::reference::host::GemmComplex< -+ typename Gemm::ElementA, typename Gemm::LayoutA, -+ typename Gemm::ElementB, typename Gemm::LayoutB, -+ ElementAccumulator, typename Gemm::LayoutC, -+ ElementAccumulator, ElementAccumulator -+ >( -+ problem_size, -+ alpha, -+ tensor_A.host_ref(), -+ Gemm::kTransformA, -+ tensor_B.host_ref(), -+ Gemm::kTransformB, -+ beta, -+ tensor_C_ref.host_ref(), -+ tensor_Y_ref.host_ref(), -+ ElementAccumulator(0) -+ ); -+ -+ using ElementC = typename Gemm::ElementC; -+ -+ ReferenceOp reference_op; -+ -+ // compute tensor Z and tensor T -+ for (int m = 0; m < problem_size.m(); ++m) { -+ for (int n = 0; n < problem_size.n(); ++n) { -+ -+ ElementZ z; -+ ElementT t; -+ -+ reference_op(z, t, tensor_Y_ref.at({m, n}), tensor_Broadcast.at({m, 0})); -+ -+ tensor_Z_ref.at({m, n}) = z; -+ tensor_T_ref.at({m, n}) = t; -+ } -+ } -+ -+ return compare_reference(problem_size, alpha, beta); -+ } -+ -+ /// Returns true if the CUDA device is sufficient to execute the kernel. -+ bool sufficient() const { -+ -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename Gemm::GemmKernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::gemm::GemmUniversalMode mode, -+ cutlass::gemm::GemmCoord problem_size, -+ int batch_count = 1, -+ ElementAccumulator alpha = ElementAccumulator(1), -+ ElementAccumulator beta = ElementAccumulator(0)) { -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ -+ this->initialize(problem_size); -+ -+ // -+ // Initialize the GEMM operator -+ // -+ -+ typename Gemm::Arguments arguments{ -+ mode, -+ problem_size, -+ batch_count, -+ {alpha, beta}, -+ tensor_A.device_data(), -+ tensor_B.device_data(), -+ tensor_C.device_data(), -+ tensor_Z.device_data(), -+ tensor_Broadcast.device_data(), -+ tensor_T.device_data(), -+ problem_size.m() * problem_size.k(), -+ problem_size.n() * problem_size.k(), -+ problem_size.m() * problem_size.n(), -+ problem_size.m() * problem_size.n(), -+ problem_size.m(), -+ problem_size.m() * problem_size.n(), -+ tensor_A.layout().stride(0), -+ tensor_B.layout().stride(0), -+ tensor_C.layout().stride(0), -+ tensor_Z.layout().stride(0), -+ 0, // This must be zero -+ tensor_T.layout().stride(0), -+ }; -+ -+ Gemm gemm_op; -+ -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); -+ -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Run the GEMM -+ // -+ -+ status = gemm_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Verify -+ // -+ -+ bool passed = true; -+ -+ passed = this->verify(problem_size, alpha, beta); -+ -+ if (!passed) { -+ std::cout << "Failed with batch_count/split_k_slices = " << batch_count << std::endl; -+ } -+ -+ // -+ // Profile -+ // -+ -+ #if 0 // profiling disabled for now. -+ -+ int const kWorkspaces = 100; -+ -+ cutlass::DeviceAllocation profiling_tensor_A(tensor_A.capacity() * kWorkspaces); -+ cutlass::DeviceAllocation profiling_tensor_B(tensor_B.capacity() * kWorkspaces); -+ cutlass::DeviceAllocation profiling_tensor_C(tensor_C.capacity() * kWorkspaces); -+ cutlass::DeviceAllocation profiling_tensor_Broadcast(tensor_Broadcast.capacity() * kWorkspaces); -+ cutlass::DeviceAllocation profiling_tensor_Z(tensor_Z.capacity() * kWorkspaces); -+ cutlass::DeviceAllocation profiling_tensor_T(tensor_T.capacity() * kWorkspaces); -+ -+ cudaEvent_t events[2]; -+ for (auto & event : events) { -+ cudaError_t result = cudaEventCreate(&event); -+ if (result != cudaSuccess) { -+ EXPECT_EQ(result, cudaSuccess) << " cudaEventCreate() failed with error " << cudaGetErrorString(result); -+ return false; -+ break; -+ } -+ } -+ -+ int const kWarmupIterations = 5; -+ int const kProfilingIterations = 100; -+ -+ for (int i = 0; i < kWarmupIterations; ++i) { -+ status = gemm_op(); -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ } -+ -+ -+ cudaError_t result = cudaEventRecord(events[0]); -+ EXPECT_EQ(result, cudaSuccess); -+ -+ for (int i = 0; i < kProfilingIterations; ++i) { -+ -+ typename Gemm::Arguments arguments{ -+ mode, -+ problem_size, -+ batch_count, -+ {alpha, beta}, -+ profiling_tensor_A.get() + tensor_A.capacity() * (i % kWorkspaces), -+ profiling_tensor_B.get() + tensor_B.capacity() * (i % kWorkspaces), -+ profiling_tensor_C.get() + tensor_C.capacity() * (i % kWorkspaces), -+ profiling_tensor_Z.get() + tensor_Z.capacity() * (i % kWorkspaces), -+ profiling_tensor_Broadcast.get() + tensor_Broadcast.capacity() * (i % kWorkspaces), -+ profiling_tensor_T.get() + tensor_T.capacity() * (i % kWorkspaces), -+ problem_size.m() * problem_size.k(), -+ problem_size.n() * problem_size.k(), -+ problem_size.m() * problem_size.n(), -+ problem_size.m() * problem_size.n(), -+ problem_size.m(), -+ problem_size.m() * problem_size.n(), -+ tensor_A.layout().stride(0), -+ tensor_B.layout().stride(0), -+ tensor_C.layout().stride(0), -+ tensor_Z.layout().stride(0), -+ 0, // This must be zero -+ tensor_T.layout().stride(0), -+ }; -+ -+ gemm_op.initialize(arguments, workspace.get()); -+ status = gemm_op(); -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ } -+ -+ result = cudaEventRecord(events[1]); -+ EXPECT_EQ(result, cudaSuccess); -+ -+ result = cudaDeviceSynchronize(); -+ EXPECT_EQ(result, cudaSuccess); -+ -+ float elapsed_time = 0; -+ result = cudaEventElapsedTime(&elapsed_time, events[0], events[1]); -+ EXPECT_EQ(result, cudaSuccess); -+ -+ double average_time = double(elapsed_time) / double(kProfilingIterations); -+ -+ std::cout << problem_size << ": " << average_time << " ms" << std::endl; -+ -+ for (auto & event : events) { -+ cudaEventDestroy(event); -+ } -+ #endif -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Gemm, -+ typename ReferenceOp = GemmWithBroadcastReferenceOp -+> -+bool TestGemmWithBroadcast( -+ cutlass::gemm::GemmCoord const & problem_size, -+ cutlass::gemm::GemmUniversalMode mode, -+ int batch_count, -+ double alpha = 1.0, -+ double beta = 2.0) { -+ -+ bool passed = true; -+ -+ TestbedGemmWithBroadcast testbed; -+ -+ using ElementAccumulator = typename Gemm::ElementAccumulator; -+ -+ passed = testbed.run( -+ mode, -+ problem_size, -+ batch_count, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta) -+ ); -+ -+ return passed; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Gemm, -+ typename ReferenceOp = GemmWithBroadcastReferenceOp -+> -+bool TestAllGemmWithBroadcast() { -+ -+ int M_problems[] = {8, 136, 264, 520}; -+ int N_problems[] = {8, 136, 264, 520}; -+ int K_problems[] = {8, 136, 264, 520}; -+ double alpha_problems[] = {1.25, 2.25}; -+ double beta_problems[] = {0, 1, 2.0}; -+ -+ bool passed = true; -+ -+ for (int M : M_problems) { -+ for (int N : N_problems) { -+ for (int K : K_problems) { -+ for (double alpha : alpha_problems) { -+ for (double beta : beta_problems) { -+ -+ TestbedGemmWithBroadcast testbed; -+ -+ using ElementAccumulator = typename Gemm::ElementAccumulator; -+ -+ passed = testbed.run( -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ {M, N, K}, -+ 1, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta) -+ ); -+ -+ EXPECT_TRUE(passed) -+ << "M: " << M << ", N: " << N << ", K: " << K << ", alpha: " << alpha << ", beta: " << beta; -+ -+ if (!passed) { -+ -+ return passed; -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ return passed; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/testbed_gemm_with_reduction.h b/3rdparty/cutlass/test/unit/gemm/device/testbed_gemm_with_reduction.h -new file mode 100644 -index 0000000..6f220b1 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/testbed_gemm_with_reduction.h -@@ -0,0 +1,589 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/gemm_complex.h" -+ -+#include "testbed_utils.h" -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct GemmWithReductionReference { -+ -+ using ElementAccumulator = typename Gemm::ElementAccumulator; -+ using ElementCompute = typename Gemm::GemmKernel::Epilogue::ElementCompute; -+ using ElementC = typename Gemm::ElementC; -+ using ElementT = typename Gemm::GemmKernel::Epilogue::ElementTensor; -+ // -+ // Data members -+ // -+ -+ BinaryOp binary_op; -+ -+ // -+ // Methods -+ // -+ -+ GemmWithReductionReference() { } -+ -+ ElementCompute operator()( -+ ElementAccumulator d_y, -+ ElementT t) { -+ -+ return binary_op(ElementCompute(d_y), ElementCompute(t)); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Gemm, -+ typename ReferenceOp -+> -+struct TestbedGemmWithReduction { -+ -+ using ElementA = typename Gemm::ElementA; -+ using ElementB = typename Gemm::ElementB; -+ using ElementC = typename Gemm::ElementC; -+ using ElementAccumulator = typename Gemm::ElementAccumulator; -+ using ElementT = typename Gemm::GemmKernel::Epilogue::ElementTensor; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D; -+ cutlass::HostTensor tensor_Reduction; -+ cutlass::HostTensor tensor_Tensor; -+ cutlass::HostTensor tensor_C_ref; -+ cutlass::HostTensor reference_d_Y; -+ cutlass::HostTensor reference_D; -+ cutlass::HostTensor reference_Reduction; -+ -+ // -+ // Methods -+ // -+ -+ TestbedGemmWithReduction( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ double scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope_max, scope_min, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ for (int m = 0; m < view.extent().row(); ++m) { -+ for (int n = 0; n < view.extent().column(); ++n) { -+ //view.at({m, n}) = Element(float(((idx ++) % 17) - 8)); -+ view.at({m, n}) = (n == 0 ? Element(m) : Element()); -+ -+ } -+ } -+ } -+ else { -+ // TODO: Implement the rest -+ EXPECT_TRUE(false) << "Not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Initializes data structures -+ void initialize(cutlass::gemm::GemmCoord problem_size) { -+ // -+ // Allocate the GEMM workspace -+ // -+ -+ tensor_A.resize(problem_size.mk()); -+ tensor_B.resize(problem_size.kn()); -+ tensor_C.resize(problem_size.mn()); -+ tensor_D.resize(problem_size.mn()); -+ -+ tensor_Reduction.resize({ -+ problem_size.m(), -+ (problem_size.n() - 1 + Gemm::ThreadblockShape::kN) / Gemm::ThreadblockShape::kN -+ }); -+ -+ tensor_Tensor.resize(problem_size.mn()); -+ reference_D.resize(problem_size.mn(), false); -+ reference_d_Y.resize(problem_size.mn(), false); -+ tensor_C_ref.resize(problem_size.mn(), false); -+ reference_Reduction.resize({problem_size.m(), 1}, false); -+ -+ EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); -+ EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); -+ EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); -+ EXPECT_TRUE(initialize_tensor(tensor_Tensor.host_view(), init_C, seed + 2020)); -+ -+ // It is possible to randomly initialize to all zeros, so override this with non-zeros -+ // in the upper left corner of each operand. -+ tensor_A.host_view().at({0, 0}) = typename Gemm::ElementA(1); -+ tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1); -+ tensor_C.host_view().at({0, 0}) = typename Gemm::ElementC(1); -+ -+ for (int m = 0; m < tensor_C_ref.extent().row(); ++m) { -+ for (int n = 0; n < tensor_C_ref.extent().column(); ++n) { -+ tensor_C_ref.at({m, n}) = ElementAccumulator(tensor_C.at({m, n})); -+ } -+ } -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D.sync_device(); -+ tensor_Reduction.sync_device(); -+ tensor_Tensor.sync_device(); -+ } -+ -+ /// Compares computed reference with device reference and outputs to a file if incorrect -+ bool compare_reference( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementAccumulator alpha, -+ ElementAccumulator beta) { -+ -+ tensor_Reduction.sync_host(); -+ tensor_D.sync_host(); -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Reduction.host_view()), 0); -+ -+ bool passed = true; -+ for (int m = 0; m < tensor_Reduction.extent().row(); ++m) { -+ -+ ElementAccumulator reduced_value = ElementAccumulator(); -+ for (int j = 0; j < tensor_Reduction.extent().column(); ++j) { -+ reduced_value += tensor_Reduction.at({m, j}); -+ } -+ -+ if (reduced_value != reference_Reduction.at({m, 0})) { -+ std::cout << "Error in bias[" << m << "] - Expected: " << reference_Reduction.at({m, 0}) << ", got: " << reduced_value << std::endl; -+ passed = false; -+ break; -+ } -+ } -+ EXPECT_TRUE(passed) << "Reduction is incorect."; -+ -+ if (!cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view())) { -+ EXPECT_TRUE(false) << " mismatched reference"; -+ passed = false; -+ } -+ -+ if (!passed) { -+ -+ /* -+ std::stringstream fname; -+ -+ fname << "error_Gemm_device_" -+ << problem_size.m() << "x" -+ << problem_size.n() << "x" -+ << problem_size.k() << "_" -+ << Gemm::ThreadblockShape::kM << "x" -+ << Gemm::ThreadblockShape::kN << "x" -+ << Gemm::ThreadblockShape::kK << "_" -+ << Gemm::WarpShape::kM << "x" -+ << Gemm::WarpShape::kN << "x" -+ << Gemm::WarpShape::kK << ".txt"; -+ -+ std::ofstream file(fname.str()); -+ */ -+ -+ std::ofstream file("testbed_universal_errors_sm70.txt"); -+ -+ file -+ << "problem: " << problem_size -+ << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; -+ -+ file -+ << "A =\n" << tensor_A.host_view() -+ << "\nB =\n" << tensor_B.host_view() -+ << "\nC =\n" << tensor_C.host_view() -+ << "\nT = \n" << tensor_Tensor.host_view() -+ << "\n\nReference =\n" << reference_D.host_view() -+ << "\nComputed =\n" << tensor_D.host_view() -+ << "\n\nReduction =\n" << tensor_Reduction.host_view() << "\n" -+ << "\nReference reduction =\n" << reference_Reduction.host_view() << "\n"; -+ } -+ -+ return passed; -+ } -+ -+ /// Verifies the result is a GEMM -+ bool verify( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementAccumulator alpha, -+ ElementAccumulator beta) { -+ -+ // -+ // Verify -+ // -+ -+ cutlass::reference::host::GemmComplex< -+ typename Gemm::ElementA, typename Gemm::LayoutA, -+ typename Gemm::ElementB, typename Gemm::LayoutB, -+ ElementAccumulator, typename Gemm::LayoutC, -+ ElementAccumulator, ElementAccumulator -+ >( -+ problem_size, -+ alpha, -+ tensor_A.host_ref(), -+ Gemm::kTransformA, -+ tensor_B.host_ref(), -+ Gemm::kTransformB, -+ beta, -+ tensor_C_ref.host_ref(), -+ reference_d_Y.host_ref(), -+ ElementAccumulator(0) -+ ); -+ -+ using ElementC = typename Gemm::ElementC; -+ -+ ReferenceOp reference_op; -+ -+ // compute backwards -+ for (int m = 0; m < problem_size.m(); ++m) { -+ ElementAccumulator reduced_value = ElementAccumulator(); -+ for (int n = 0; n < problem_size.n(); ++n) { -+ ElementAccumulator d_full = reference_op(reference_d_Y.at({m, n}), tensor_Tensor.at({m, n})); -+ reduced_value += d_full; -+ reference_D.at({m, n}) = ElementC(d_full); -+ } -+ reference_Reduction.at({m, 0}) = reduced_value; -+ } -+ -+ return compare_reference(problem_size, alpha, beta); -+ } -+ -+ /// Returns true if the CUDA device is sufficient to execute the kernel. -+ bool sufficient() const { -+ -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename Gemm::GemmKernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::gemm::GemmUniversalMode mode, -+ cutlass::gemm::GemmCoord problem_size, -+ int batch_count = 1, -+ ElementAccumulator alpha = ElementAccumulator(1), -+ ElementAccumulator beta = ElementAccumulator(0)) { -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ -+ this->initialize(problem_size); -+ -+ // -+ // Initialize the GEMM operator -+ // -+ -+ typename Gemm::Arguments arguments{ -+ mode, -+ problem_size, -+ batch_count, -+ {alpha, beta}, -+ tensor_A.device_data(), -+ tensor_B.device_data(), -+ tensor_C.device_data(), -+ tensor_D.device_data(), -+ tensor_Reduction.device_data(), -+ tensor_Tensor.device_data(), -+ problem_size.m() * problem_size.k(), -+ problem_size.n() * problem_size.k(), -+ problem_size.m() * problem_size.n(), -+ problem_size.m() * problem_size.n(), -+ problem_size.m(), -+ problem_size.m() * problem_size.n(), -+ tensor_A.layout().stride(0), -+ tensor_B.layout().stride(0), -+ tensor_C.layout().stride(0), -+ tensor_D.layout().stride(0), -+ tensor_Reduction.layout().stride(0), -+ tensor_Tensor.layout().stride(0), -+ }; -+ -+ Gemm gemm_op; -+ -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); -+ -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Run the GEMM -+ // -+ -+ status = gemm_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Verify -+ // -+ -+ bool passed = this->verify(problem_size, alpha, beta); -+ -+ if (!passed) { -+ std::cout << "Failed with batch_count/split_k_slices = " << batch_count << std::endl; -+ } -+ -+ // -+ // Profile -+ // -+ -+ #if 0 // profiling disabled for now. -+ -+ int const kWorkspaces = 100; -+ -+ cutlass::DeviceAllocation profiling_tensor_A(tensor_A.capacity() * kWorkspaces); -+ cutlass::DeviceAllocation profiling_tensor_B(tensor_B.capacity() * kWorkspaces); -+ cutlass::DeviceAllocation profiling_tensor_C(tensor_C.capacity() * kWorkspaces); -+ cutlass::DeviceAllocation profiling_tensor_D(tensor_D.capacity() * kWorkspaces); -+ cutlass::DeviceAllocation profiling_tensor_Reduction(tensor_Reduction.capacity() * kWorkspaces); -+ cutlass::DeviceAllocation profiling_tensor_Tensor(tensor_Tensor.capacity() * kWorkspaces); -+ -+ cudaEvent_t events[2]; -+ for (auto & event : events) { -+ cudaError_t result = cudaEventCreate(&event); -+ if (result != cudaSuccess) { -+ EXPECT_EQ(result, cudaSuccess) << " cudaEventCreate() failed with error " << cudaGetErrorString(result); -+ return false; -+ break; -+ } -+ } -+ -+ int const kWarmupIterations = 5; -+ int const kProfilingIterations = 100; -+ -+ for (int i = 0; i < kWarmupIterations; ++i) { -+ status = gemm_op(); -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ } -+ -+ -+ cudaError_t result = cudaEventRecord(events[0]); -+ EXPECT_EQ(result, cudaSuccess); -+ -+ for (int i = 0; i < kProfilingIterations; ++i) { -+ -+ typename Gemm::Arguments arguments{ -+ mode, -+ problem_size, -+ batch_count, -+ {alpha, beta}, -+ profiling_tensor_A.get() + tensor_A.capacity() * (i % kWorkspaces), -+ profiling_tensor_B.get() + tensor_B.capacity() * (i % kWorkspaces), -+ profiling_tensor_C.get() + tensor_C.capacity() * (i % kWorkspaces), -+ profiling_tensor_D.get() + tensor_D.capacity() * (i % kWorkspaces), -+ profiling_tensor_Reduction.get() + tensor_Reduction.capacity() * (i % kWorkspaces), -+ profiling_tensor_Tensor.get() + tensor_Tensor.capacity() * (i % kWorkspaces), -+ problem_size.m() * problem_size.k(), -+ problem_size.n() * problem_size.k(), -+ problem_size.m() * problem_size.n(), -+ problem_size.m() * problem_size.n(), -+ problem_size.m(), -+ problem_size.m() * problem_size.n(), -+ tensor_A.layout().stride(0), -+ tensor_B.layout().stride(0), -+ tensor_C.layout().stride(0), -+ tensor_D.layout().stride(0), -+ tensor_Reduction.layout().stride(0), -+ tensor_Tensor.layout().stride(0), -+ }; -+ -+ gemm_op.initialize(arguments, workspace.get()); -+ status = gemm_op(); -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ } -+ -+ result = cudaEventRecord(events[1]); -+ EXPECT_EQ(result, cudaSuccess); -+ -+ result = cudaDeviceSynchronize(); -+ EXPECT_EQ(result, cudaSuccess); -+ -+ float elapsed_time = 0; -+ result = cudaEventElapsedTime(&elapsed_time, events[0], events[1]); -+ EXPECT_EQ(result, cudaSuccess); -+ -+ double average_time = double(elapsed_time) / double(kProfilingIterations); -+ -+ std::cout << problem_size << ": " << average_time << " ms" << std::endl; -+ -+ for (auto & event : events) { -+ cudaEventDestroy(event); -+ } -+ #endif -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template -+bool TestGemmWithReduction( -+ cutlass::gemm::GemmCoord const & problem_size, -+ cutlass::gemm::GemmUniversalMode mode, -+ int batch_count = 1, -+ double alpha = 1.0, -+ double beta = 2.0) { -+ -+ bool passed = true; -+ -+ TestbedGemmWithReduction testbed; -+ -+ using ElementAccumulator = typename Gemm::ElementAccumulator; -+ -+ passed = testbed.run( -+ mode, -+ problem_size, -+ batch_count, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta) -+ ); -+ -+ return passed; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/testbed_grouped.h b/3rdparty/cutlass/test/unit/gemm/device/testbed_grouped.h -new file mode 100644 -index 0000000..c5ee3ce ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/testbed_grouped.h -@@ -0,0 +1,501 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+ -+*/ -+ -+#pragma once -+ -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/gemm_grouped.h" -+#include "cutlass/gemm/kernel/default_gemm_grouped.h" -+#include "cutlass/gemm/device/gemm_grouped.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct TestbedGrouped { -+ -+ // -+ // Type definitions -+ // -+ -+ using ElementA = typename Gemm::ElementA; -+ using ElementB = typename Gemm::ElementB; -+ using ElementC = typename Gemm::ElementC; -+ using ElementAccumulator = typename Gemm::ElementAccumulator; -+ -+ using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; -+ using ElementCompute = typename EpilogueOutputOp::ElementCompute; -+ -+ using LayoutA = typename Gemm::LayoutA; -+ using LayoutB = typename Gemm::LayoutB; -+ using LayoutC = typename Gemm::LayoutC; -+ -+ using MatrixCoord = typename LayoutC::TensorCoord; -+ -+ // -+ // Data members -+ // -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint32_t seed; -+ -+ int problem_count; -+ -+ std::vector problem_sizes_host; -+ cutlass::DeviceAllocation problem_sizes_device; -+ -+ std::vector offset_A; -+ std::vector offset_B; -+ std::vector offset_C; -+ std::vector offset_D; -+ -+ std::vector lda_host; -+ std::vector ldb_host; -+ std::vector ldc_host; -+ std::vector ldd_host; -+ -+ cutlass::DeviceAllocation lda; -+ cutlass::DeviceAllocation ldb; -+ cutlass::DeviceAllocation ldc; -+ cutlass::DeviceAllocation ldd; -+ -+ cutlass::DeviceAllocation block_A; -+ cutlass::DeviceAllocation block_B; -+ cutlass::DeviceAllocation block_C; -+ cutlass::DeviceAllocation block_D; -+ -+ cutlass::DeviceAllocation ptr_A; -+ cutlass::DeviceAllocation ptr_B; -+ cutlass::DeviceAllocation ptr_C; -+ cutlass::DeviceAllocation ptr_D; -+ -+ // -+ // Methods -+ // -+ -+ TestbedGrouped( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint32_t seed_ = 3080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint32_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ double scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ if (cutlass::sizeof_bits::value <= 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } -+ else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope_max, scope_min, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else { -+ // no fill - remain zero -+ } -+ -+ return true; -+ } -+ -+ /// Initializes data structures -+ void initialize() { -+ -+ // -+ // Choose random problem sizes -+ // -+ -+ // construct a few problems of random sizes -+ srand(seed); -+ -+ int64_t total_elements_A = 0; -+ int64_t total_elements_B = 0; -+ int64_t total_elements_C = 0; -+ int64_t total_elements_D = 0; -+ -+ -+ lda_host.resize(problem_count); -+ ldb_host.resize(problem_count); -+ ldc_host.resize(problem_count); -+ ldd_host.resize(problem_count); -+ -+ problem_sizes_host.clear(); -+ problem_sizes_host.resize(problem_count); -+ -+ for (int32_t i = 0; i < problem_count; ++i) { -+ -+ cutlass::gemm::GemmCoord problem( -+ 8 * (rand() % 64) + 24, -+ 8 * (rand() % 64) + 24, -+ 8 * (rand() % 64) + 24); -+ -+ if (!i) { -+ problem = cutlass::gemm::GemmCoord(48, 16, 8); -+ } -+ -+ problem_sizes_host.at(i) = problem; -+ -+ // std::cout << "Problem[" << i << "]: " << problem << std::endl; -+ -+ lda_host.at(i) = LayoutA::packed({problem.m(), problem.k()}).stride(0); -+ ldb_host.at(i) = LayoutB::packed({problem.k(), problem.n()}).stride(0); -+ ldc_host.at(i) = LayoutC::packed({problem.m(), problem.n()}).stride(0); -+ ldd_host.at(i) = LayoutC::packed({problem.m(), problem.n()}).stride(0); -+ -+ offset_A.push_back(total_elements_A); -+ offset_B.push_back(total_elements_B); -+ offset_C.push_back(total_elements_C); -+ offset_D.push_back(total_elements_D); -+ -+ int64_t elements_A = problem.m() * problem.k(); -+ int64_t elements_B = problem.k() * problem.n(); -+ int64_t elements_C = problem.m() * problem.n(); -+ int64_t elements_D = problem.m() * problem.n(); -+ -+ total_elements_A += elements_A; -+ total_elements_B += elements_B; -+ total_elements_C += elements_C; -+ total_elements_D += elements_D; -+ -+ // Random strides between problems? -+ } -+ -+ problem_sizes_device.reset(problem_count); -+ problem_sizes_device.copy_from_host(problem_sizes_host.data()); -+ -+ lda.reset(problem_count); -+ ldb.reset(problem_count); -+ ldc.reset(problem_count); -+ ldd.reset(problem_count); -+ -+ lda.copy_from_host(lda_host.data()); -+ ldb.copy_from_host(ldb_host.data()); -+ ldc.copy_from_host(ldc_host.data()); -+ ldd.copy_from_host(ldd_host.data()); -+ -+ // -+ // Assign pointers -+ // -+ -+ block_A.reset(total_elements_A); -+ block_B.reset(total_elements_B); -+ block_C.reset(total_elements_C); -+ block_D.reset(total_elements_D); -+ -+ std::vector ptr_A_host(problem_count); -+ std::vector ptr_B_host(problem_count); -+ std::vector ptr_C_host(problem_count); -+ std::vector ptr_D_host(problem_count); -+ -+ for (int32_t i = 0; i < problem_count; ++i) { -+ ptr_A_host.at(i) = block_A.get() + offset_A.at(i); -+ ptr_B_host.at(i) = block_B.get() + offset_B.at(i); -+ ptr_C_host.at(i) = block_C.get() + offset_C.at(i); -+ ptr_D_host.at(i) = block_D.get() + offset_D.at(i); -+ } -+ -+ ptr_A.reset(problem_count); -+ ptr_A.copy_from_host(ptr_A_host.data()); -+ -+ ptr_B.reset(problem_count); -+ ptr_B.copy_from_host(ptr_B_host.data()); -+ -+ ptr_C.reset(problem_count); -+ ptr_C.copy_from_host(ptr_C_host.data()); -+ -+ ptr_D.reset(problem_count); -+ ptr_D.copy_from_host(ptr_D_host.data()); -+ -+ // -+ // Initialize the problems of the workspace -+ // -+ -+ for (int32_t i = 0; i < problem_count; ++i) { -+ cutlass::gemm::GemmCoord problem = problem_sizes_host.at(i); -+ -+ LayoutA layout_A(lda_host.at(i)); -+ LayoutB layout_B(ldb_host.at(i)); -+ LayoutC layout_C(ldc_host.at(i)); -+ LayoutC layout_D(ldd_host.at(i)); -+ -+ MatrixCoord extent_A{problem.m(), problem.k()}; -+ MatrixCoord extent_B{problem.k(), problem.n()}; -+ MatrixCoord extent_C{problem.m(), problem.n()}; -+ -+ std::vector matrix_A(layout_A.capacity(extent_A)); -+ std::vector matrix_B(layout_B.capacity(extent_B)); -+ std::vector matrix_C(layout_C.capacity(extent_C)); -+ std::vector matrix_D(layout_D.capacity(extent_C)); -+ -+ initialize_tensor(cutlass::TensorView(matrix_A.data(), layout_A, extent_A), init_A, seed * 2021); -+ initialize_tensor(cutlass::TensorView(matrix_B.data(), layout_B, extent_B), init_B, seed * 2022); -+ initialize_tensor(cutlass::TensorView(matrix_C.data(), layout_C, extent_C), init_C, seed * 2023); -+ -+ cutlass::device_memory::copy_to_device(ptr_A_host.at(i), matrix_A.data(), matrix_A.size()); -+ cutlass::device_memory::copy_to_device(ptr_B_host.at(i), matrix_B.data(), matrix_B.size()); -+ cutlass::device_memory::copy_to_device(ptr_C_host.at(i), matrix_C.data(), matrix_C.size()); -+ cutlass::device_memory::copy_to_device(ptr_D_host.at(i), matrix_D.data(), matrix_D.size()); -+ } -+ } -+ -+ /// Verifies the result is a GEMM -+ bool verify( -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ bool passed = true; -+ -+ for (int32_t i = 0; i < problem_count; ++i) { -+ cutlass::gemm::GemmCoord problem = problem_sizes_host.at(i); -+ -+ LayoutA layout_A(lda_host.at(i)); -+ LayoutB layout_B(ldb_host.at(i)); -+ LayoutC layout_C(ldc_host.at(i)); -+ LayoutC layout_D(ldd_host.at(i)); -+ -+ MatrixCoord extent_A{problem.m(), problem.k()}; -+ MatrixCoord extent_B{problem.k(), problem.n()}; -+ MatrixCoord extent_C{problem.m(), problem.n()}; -+ -+ std::vector matrix_A(layout_A.capacity(extent_A)); -+ std::vector matrix_B(layout_B.capacity(extent_B)); -+ std::vector matrix_C(layout_C.capacity(extent_C)); -+ std::vector matrix_D(layout_D.capacity(extent_C)); -+ std::vector matrix_Ref(layout_D.capacity(extent_C)); -+ -+ cutlass::device_memory::copy_to_host(matrix_A.data(), block_A.get() + offset_A.at(i), matrix_A.size()); -+ cutlass::device_memory::copy_to_host(matrix_B.data(), block_B.get() + offset_B.at(i), matrix_B.size()); -+ cutlass::device_memory::copy_to_host(matrix_C.data(), block_C.get() + offset_C.at(i), matrix_C.size()); -+ cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + offset_D.at(i), matrix_D.size()); -+ -+ cutlass::TensorView view_A(matrix_A.data(), layout_A, extent_A); -+ cutlass::TensorView view_B(matrix_B.data(), layout_B, extent_B); -+ cutlass::TensorView view_C(matrix_C.data(), layout_C, extent_C); -+ cutlass::TensorView view_D(matrix_D.data(), layout_D, extent_C); -+ cutlass::TensorView view_Ref(matrix_Ref.data(), layout_D, extent_C); -+ -+ // Reference GEMM -+ cutlass::reference::host::GemmComplex< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, ElementAccumulator -+ >( -+ problem, -+ alpha, -+ view_A, -+ Gemm::kTransformA, -+ view_B, -+ Gemm::kTransformB, -+ beta, -+ view_C, -+ view_Ref, -+ ElementAccumulator(0) -+ ); -+ -+ // Ensure that no input or output is entirely zero -+ EXPECT_GT(cutlass::reference::host::TensorNorm(view_A), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(view_B), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(view_C), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(view_D), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(view_Ref), 0); -+ -+ // Compare against reference -+ passed = cutlass::reference::host::TensorEquals(view_D, view_Ref); -+ -+ if (!passed) { -+ std::ofstream file("testbed_grouped_errors.txt"); -+ -+ file -+ << "problem: " << problem << " [group: " << i << "]\n" -+ << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; -+ -+ file -+ << "A =\n" << view_A -+ << "\nB =\n" << view_B -+ << "\nC =\n" << view_C -+ << "\n\nReference =\n" << view_Ref -+ << "\nComputed =\n" << view_D; -+ -+ return passed; -+ } -+ } -+ -+ return passed; -+ } -+ -+ /// Executes one test -+ bool run( -+ int problem_count, -+ ElementCompute alpha = ElementCompute(1), -+ ElementCompute beta = ElementCompute(0)) { -+ -+ this->problem_count = problem_count; -+ -+ // Initialize the problem -+ initialize(); -+ -+ int threadblock_count = Gemm::sufficient(problem_sizes_host.data(), problem_count); -+ -+ // Early exit -+ if (!threadblock_count) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device resources." << std::endl; -+ } -+ return true; -+ } -+ -+ // Configure the GEMM arguments -+ typename EpilogueOutputOp::Params epilogue_op(alpha, beta); -+ -+ // Configure GEMM arguments -+ typename Gemm::Arguments args( -+ problem_sizes_device.get(), -+ problem_count, -+ threadblock_count, -+ epilogue_op, -+ ptr_A.get(), -+ ptr_B.get(), -+ ptr_C.get(), -+ ptr_D.get(), -+ lda.get(), -+ ldb.get(), -+ ldc.get(), -+ ldd.get(), -+ problem_sizes_host.data() -+ ); -+ -+ // Initialize the GEMM object -+ Gemm gemm; -+ -+ size_t workspace_size = gemm.get_workspace_size(args); -+ cutlass::DeviceAllocation workspace(workspace_size); -+ -+ cutlass::Status status = gemm.initialize(args, workspace.get()); -+ -+ if (status != cutlass::Status::kSuccess) { -+ return false; -+ } -+ -+ // Run the GEMM object -+ status = gemm.run(); -+ -+ if (status != cutlass::Status::kSuccess) { -+ return false; -+ } -+ -+ // Wait for completion -+ cudaError_t result = cudaDeviceSynchronize(); -+ -+ EXPECT_EQ(result, cudaSuccess) -+ << "Kernel execution error: " << cudaGetErrorString(result); -+ -+ if (result != cudaSuccess) { -+ return false; -+ } -+ -+ // Verify correctness -+ return verify(alpha, beta); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // device -+} // gemm -+} // test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k.h b/3rdparty/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k.h -new file mode 100644 -index 0000000..7b212ae ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k.h -@@ -0,0 +1,502 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for grouped Rank2K interface -+ -+*/ -+ -+#pragma once -+ -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/device_kernel.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/rank_2k_grouped.h" -+#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" -+#include "cutlass/gemm/device/rank_2k_grouped.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/rank_2k_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct TestbedGrouped { -+ -+ // -+ // Type definitions -+ // -+ -+ using ElementA = typename Rank2K::ElementA; -+ using ElementB = typename Rank2K::ElementB; -+ using ElementC = typename Rank2K::ElementC; -+ using ElementAccumulator = typename Rank2K::ElementAccumulator; -+ -+ using EpilogueOutputOp = typename Rank2K::EpilogueOutputOp; -+ using ElementCompute = typename EpilogueOutputOp::ElementCompute; -+ -+ using LayoutA = typename Rank2K::LayoutA; -+ using LayoutB = typename Rank2K::LayoutB; -+ using LayoutC = typename Rank2K::LayoutC; -+ -+ using MatrixCoord = typename LayoutC::TensorCoord; -+ -+ // -+ // Data members -+ // -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint32_t seed; -+ -+ int problem_count; -+ -+ std::vector problem_sizes_host; -+ cutlass::DeviceAllocation problem_sizes_device; -+ -+ std::vector offset_A; -+ std::vector offset_B; -+ std::vector offset_C; -+ std::vector offset_D; -+ -+ std::vector lda_host; -+ std::vector ldb_host; -+ std::vector ldc_host; -+ std::vector ldd_host; -+ -+ cutlass::DeviceAllocation lda; -+ cutlass::DeviceAllocation ldb; -+ cutlass::DeviceAllocation ldc; -+ cutlass::DeviceAllocation ldd; -+ -+ cutlass::DeviceAllocation block_A; -+ cutlass::DeviceAllocation block_B; -+ cutlass::DeviceAllocation block_C; -+ cutlass::DeviceAllocation block_D; -+ -+ cutlass::DeviceAllocation ptr_A; -+ cutlass::DeviceAllocation ptr_B; -+ cutlass::DeviceAllocation ptr_C; -+ cutlass::DeviceAllocation ptr_D; -+ -+ // -+ // Methods -+ // -+ -+ TestbedGrouped( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint32_t seed_ = 3080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint32_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ double scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ if (cutlass::sizeof_bits::value <= 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } -+ else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope_max, scope_min, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else { -+ // no fill - remain zero -+ } -+ -+ return true; -+ } -+ -+ /// Initializes data structures -+ void initialize() { -+ -+ // -+ // Choose random problem sizes -+ // -+ -+ // construct a few problems of random sizes -+ srand(seed); -+ -+ int64_t total_elements_A = 0; -+ int64_t total_elements_B = 0; -+ int64_t total_elements_C = 0; -+ int64_t total_elements_D = 0; -+ -+ -+ lda_host.resize(problem_count); -+ ldb_host.resize(problem_count); -+ ldc_host.resize(problem_count); -+ ldd_host.resize(problem_count); -+ -+ problem_sizes_host.clear(); -+ problem_sizes_host.resize(problem_count); -+ -+ for (int32_t i = 0; i < problem_count; ++i) { -+ -+ auto N = 8 * (rand() % 64) + 24; -+ auto K = 8 * (rand() % 64) + 24; -+ cutlass::gemm::GemmCoord problem(N, N, K); -+ -+ if (!i) { -+ problem = cutlass::gemm::GemmCoord(16, 16, 8); -+ } -+ -+ problem_sizes_host.at(i) = problem; -+ -+ lda_host.at(i) = LayoutA::packed({problem.n(), problem.k()}).stride(0); -+ ldb_host.at(i) = LayoutB::packed({problem.n(), problem.k()}).stride(0); -+ ldc_host.at(i) = LayoutC::packed({problem.n(), problem.n()}).stride(0); -+ ldd_host.at(i) = LayoutC::packed({problem.n(), problem.n()}).stride(0); -+ -+ offset_A.push_back(total_elements_A); -+ offset_B.push_back(total_elements_B); -+ offset_C.push_back(total_elements_C); -+ offset_D.push_back(total_elements_D); -+ -+ int64_t elements_A = problem.n() * problem.k(); -+ int64_t elements_B = problem.n() * problem.k(); -+ int64_t elements_C = problem.n() * problem.n(); -+ int64_t elements_D = problem.n() * problem.n(); -+ -+ total_elements_A += elements_A; -+ total_elements_B += elements_B; -+ total_elements_C += elements_C; -+ total_elements_D += elements_D; -+ -+ // Random strides between problems? -+ } -+ -+ problem_sizes_device.reset(problem_count); -+ problem_sizes_device.copy_from_host(problem_sizes_host.data()); -+ -+ lda.reset(problem_count); -+ ldb.reset(problem_count); -+ ldc.reset(problem_count); -+ ldd.reset(problem_count); -+ -+ lda.copy_from_host(lda_host.data()); -+ ldb.copy_from_host(ldb_host.data()); -+ ldc.copy_from_host(ldc_host.data()); -+ ldd.copy_from_host(ldd_host.data()); -+ -+ // -+ // Assign pointers -+ // -+ -+ block_A.reset(total_elements_A); -+ block_B.reset(total_elements_B); -+ block_C.reset(total_elements_C); -+ block_D.reset(total_elements_D); -+ -+ std::vector ptr_A_host(problem_count); -+ std::vector ptr_B_host(problem_count); -+ std::vector ptr_C_host(problem_count); -+ std::vector ptr_D_host(problem_count); -+ -+ for (int32_t i = 0; i < problem_count; ++i) { -+ ptr_A_host.at(i) = block_A.get() + offset_A.at(i); -+ ptr_B_host.at(i) = block_B.get() + offset_B.at(i); -+ ptr_C_host.at(i) = block_C.get() + offset_C.at(i); -+ ptr_D_host.at(i) = block_D.get() + offset_D.at(i); -+ } -+ -+ ptr_A.reset(problem_count); -+ ptr_A.copy_from_host(ptr_A_host.data()); -+ -+ ptr_B.reset(problem_count); -+ ptr_B.copy_from_host(ptr_B_host.data()); -+ -+ ptr_C.reset(problem_count); -+ ptr_C.copy_from_host(ptr_C_host.data()); -+ -+ ptr_D.reset(problem_count); -+ ptr_D.copy_from_host(ptr_D_host.data()); -+ -+ // -+ // Initialize the problems of the workspace -+ // -+ -+ for (int32_t i = 0; i < problem_count; ++i) { -+ cutlass::gemm::GemmCoord problem = problem_sizes_host.at(i); -+ -+ LayoutA layout_A(lda_host.at(i)); -+ LayoutB layout_B(ldb_host.at(i)); -+ LayoutC layout_C(ldc_host.at(i)); -+ LayoutC layout_D(ldd_host.at(i)); -+ -+ MatrixCoord extent_A{problem.n(), problem.k()}; -+ MatrixCoord extent_B{problem.n(), problem.k()}; -+ MatrixCoord extent_C{problem.n(), problem.n()}; -+ -+ std::vector matrix_A(layout_A.capacity(extent_A)); -+ std::vector matrix_B(layout_B.capacity(extent_B)); -+ std::vector matrix_C(layout_C.capacity(extent_C)); -+ std::vector matrix_D(layout_D.capacity(extent_C)); -+ -+ initialize_tensor(cutlass::TensorView(matrix_A.data(), layout_A, extent_A), init_A, seed * 2021); -+ initialize_tensor(cutlass::TensorView(matrix_B.data(), layout_B, extent_B), init_B, seed * 2022); -+ initialize_tensor(cutlass::TensorView(matrix_C.data(), layout_C, extent_C), init_C, seed * 2023); -+ -+ cutlass::device_memory::copy_to_device(ptr_A_host.at(i), matrix_A.data(), matrix_A.size()); -+ cutlass::device_memory::copy_to_device(ptr_B_host.at(i), matrix_B.data(), matrix_B.size()); -+ cutlass::device_memory::copy_to_device(ptr_C_host.at(i), matrix_C.data(), matrix_C.size()); -+ cutlass::device_memory::copy_to_device(ptr_D_host.at(i), matrix_D.data(), matrix_D.size()); -+ } -+ } -+ -+ /// Verifies the result is a Rank2K -+ bool verify( -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ bool passed = true; -+ -+ for (int32_t i = 0; i < problem_count; ++i) { -+ cutlass::gemm::GemmCoord problem = problem_sizes_host.at(i); -+ -+ LayoutA layout_A(lda_host.at(i)); -+ LayoutB layout_B(ldb_host.at(i)); -+ LayoutC layout_C(ldc_host.at(i)); -+ LayoutC layout_D(ldd_host.at(i)); -+ -+ MatrixCoord extent_A{problem.n(), problem.k()}; -+ MatrixCoord extent_B{problem.n(), problem.k()}; -+ MatrixCoord extent_C{problem.n(), problem.n()}; -+ -+ std::vector matrix_A(layout_A.capacity(extent_A)); -+ std::vector matrix_B(layout_B.capacity(extent_B)); -+ std::vector matrix_C(layout_C.capacity(extent_C)); -+ std::vector matrix_D(layout_D.capacity(extent_C)); -+ std::vector matrix_Ref(layout_D.capacity(extent_C)); -+ -+ cutlass::device_memory::copy_to_host(matrix_A.data(), block_A.get() + offset_A.at(i), matrix_A.size()); -+ cutlass::device_memory::copy_to_host(matrix_B.data(), block_B.get() + offset_B.at(i), matrix_B.size()); -+ cutlass::device_memory::copy_to_host(matrix_C.data(), block_C.get() + offset_C.at(i), matrix_C.size()); -+ cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + offset_D.at(i), matrix_D.size()); -+ -+ cutlass::TensorView view_A(matrix_A.data(), layout_A, extent_A); -+ cutlass::TensorView view_B(matrix_B.data(), layout_B, extent_B); -+ cutlass::TensorView view_C(matrix_C.data(), layout_C, extent_C); -+ cutlass::TensorView view_D(matrix_D.data(), layout_D, extent_C); -+ cutlass::TensorView view_Ref(matrix_Ref.data(), layout_D, extent_C); -+ -+ // Reference Rank2K -+ cutlass::reference::host::Rank2KComplex< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, ElementAccumulator -+ >( -+ problem, -+ alpha, -+ view_A, -+ Rank2K::kTransformA, -+ view_B, -+ Rank2K::kTransformB, -+ beta, -+ view_C, -+ view_Ref, -+ ElementAccumulator(0), -+ Rank2K::kFillModeC, -+ Rank2K::kBlasMode -+ ); -+ -+ // Ensure that no input or output is entirely zero -+ EXPECT_GT(cutlass::reference::host::TensorNorm(view_A), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(view_B), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(view_C), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(view_D), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(view_Ref), 0); -+ -+ // Compare against reference -+ passed = cutlass::reference::host::TensorEquals(view_D, view_Ref); -+ -+ if (!passed) { -+ std::ofstream file("testbed_grouped_errors.txt"); -+ -+ file -+ << "problem: " << problem << " [group: " << i << "]\n" -+ << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; -+ -+ file -+ << "A =\n" << view_A -+ << "\nB =\n" << view_B -+ << "\nC =\n" << view_C -+ << "\n\nReference =\n" << view_Ref -+ << "\nComputed =\n" << view_D; -+ -+ return passed; -+ } -+ } -+ -+ return passed; -+ } -+ -+ /// Executes one test -+ bool run( -+ int problem_count, -+ ElementCompute alpha = ElementCompute(1), -+ ElementCompute beta = ElementCompute(0)) { -+ -+ this->problem_count = problem_count; -+ -+ // Initialize the problem -+ initialize(); -+ -+ int threadblock_count = Rank2K::sufficient(problem_sizes_host.data(), problem_count); -+ -+ // Early exit -+ if (!threadblock_count) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device resources." << std::endl; -+ } -+ return true; -+ } -+ -+ // Configure the Rank2K arguments -+ typename EpilogueOutputOp::Params epilogue_op(alpha, beta); -+ -+ // Configure Rank2K arguments -+ typename Rank2K::Arguments args( -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ problem_sizes_device.get(), -+ problem_count, -+ threadblock_count, -+ epilogue_op, -+ ptr_A.get(), -+ ptr_B.get(), -+ ptr_C.get(), -+ ptr_D.get(), -+ lda.get(), -+ ldb.get(), -+ ldc.get(), -+ ldd.get(), -+ problem_sizes_host.data() -+ ); -+ -+ // Initialize the Rank2K object -+ Rank2K rank2k; -+ -+ size_t workspace_size = rank2k.get_workspace_size(args); -+ cutlass::DeviceAllocation workspace(workspace_size); -+ -+ cutlass::Status status = rank2k.initialize(args, workspace.get()); -+ -+ if (status != cutlass::Status::kSuccess) { -+ return false; -+ } -+ -+ // Run the Rank2K object -+ status = rank2k.run(); -+ -+ if (status != cutlass::Status::kSuccess) { -+ return false; -+ } -+ -+ // Wait for completion -+ cudaError_t result = cudaDeviceSynchronize(); -+ -+ EXPECT_EQ(result, cudaSuccess) -+ << "Kernel execution error: " << cudaGetErrorString(result); -+ -+ if (result != cudaSuccess) { -+ return false; -+ } -+ -+ // Verify correctness -+ return verify(alpha, beta); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // device -+} // gemm -+} // test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k_scheduler.h b/3rdparty/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k_scheduler.h -new file mode 100644 -index 0000000..af588d3 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k_scheduler.h -@@ -0,0 +1,461 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for grouped Rank2K problem visitors -+*/ -+ -+#pragma once -+ -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h" -+#include "cutlass/util/device_memory.h" -+#include "cutlass/device_kernel.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// Use simple problem visitor as a baseline -+template -+struct BaselineProblemVisitor : public cutlass::gemm::kernel::BaseGroupedProblemVisitor { -+ using Base = cutlass::gemm::kernel::BaseGroupedProblemVisitor; -+ using Params = typename Base::Params; -+ static int const kThreadCount = ThreadCount; -+ static cutlass::FillMode const kFillModeC = FillModeC; -+ -+ struct SharedStorage {}; -+ -+ int32_t tile_count_sum; -+ SharedStorage &shared_storage; -+ -+ // -+ // Methods -+ // -+ CUTLASS_DEVICE -+ BaselineProblemVisitor( -+ Params const ¶ms_, -+ SharedStorage &shared_storage_, -+ int32_t block_idx -+ ): Base(params_, block_idx), -+ shared_storage(shared_storage_) -+ { -+ cutlass::gemm::GemmCoord problem = this->problem_size(); -+ cutlass::gemm::GemmCoord grid = this->grid_shape(problem); -+ tile_count_sum = this->tile_count(grid); -+ } -+ -+ CUTLASS_DEVICE -+ bool next_tile() { -+ if (this->tile_idx < tile_count_sum) { -+ return true; -+ } -+ -+ do { -+ ++this->problem_idx; -+ -+ if (this->problem_idx >= this->params.problem_count) { -+ return false; -+ } -+ -+ cutlass::gemm::GemmCoord problem = this->problem_size(); -+ cutlass::gemm::GemmCoord grid = this->grid_shape(problem); -+ -+ this->problem_tile_start = tile_count_sum; -+ tile_count_sum += this->tile_count(grid); -+ -+ } while (tile_count_sum <= this->tile_idx); -+ -+ return true; -+ } -+ -+ static size_t get_workspace_size(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, -+ int32_t problem_count, -+ int32_t block_count) { -+ return 0; -+ } -+ -+ static void host_precompute(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, -+ int32_t problem_count, -+ int32_t block_count, -+ void* host_workspace_ptr) {} -+ -+ CUTLASS_DEVICE -+ cutlass::gemm::GemmCoord threadblock_offset(int32_t threadblock_id) const { -+ int32_t macro_id = threadblock_id / ProblemSizeHelper::OffsetHelper::kThreadblockSkewRatio; -+ int32_t macro_row = ceil(cutlass::fast_sqrt((2*macro_id) + 2.25) - 0.5) - 1; -+ int32_t macro_col = macro_id - (((macro_row+1) * macro_row)/2); -+ -+ if (FillModeC == cutlass::FillMode::kUpper) { -+ cutlass::swap(macro_row, macro_col); -+ } -+ -+ int32_t row = ProblemSizeHelper::OffsetHelper::macro_row_to_row(macro_row, threadblock_id); -+ int32_t col = ProblemSizeHelper::OffsetHelper::macro_col_to_col(macro_col, threadblock_id); -+ -+ return cutlass::gemm::GemmCoord(row, col, 0); -+ } -+}; -+ -+template -+struct ProblemVisitorKernel { -+ struct SharedStorage { -+ typename ProblemVisitor::SharedStorage problem_visitor; -+ }; -+ -+ struct Params { -+ typename ProblemVisitor::Params problem_visitor_params; -+ int32_t* visited_problems_ptr; -+ int32_t* visited_tiles_ptr; -+ int32_t visits_per_block; -+ -+ Params(): -+ visited_problems_ptr(nullptr), -+ visited_tiles_ptr(nullptr), -+ visits_per_block(0) {} -+ -+ Params(typename ProblemVisitor::Params problem_visitor_params_, -+ int32_t* visited_problems_ptr_, -+ int32_t* visited_tiles_ptr_, -+ int32_t visits_per_block_): -+ problem_visitor_params(problem_visitor_params_), -+ visited_problems_ptr(visited_problems_ptr_), -+ visited_tiles_ptr(visited_tiles_ptr_), -+ visits_per_block(visits_per_block_) {} -+ }; -+ -+ CUTLASS_DEVICE -+ void operator()(const Params& params, SharedStorage &shared_storage) { -+ int32_t store_offset = params.visits_per_block * blockIdx.x; -+ ProblemVisitor problem_visitor(params.problem_visitor_params, -+ shared_storage.problem_visitor, -+ blockIdx.x); -+ -+ while (problem_visitor.next_tile()) { -+ cutlass::gemm::GemmCoord problem_size = problem_visitor.problem_size(); -+ int32_t problem_idx = problem_visitor.problem_index(); -+ int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); -+ -+ cutlass::gemm::GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); -+ cutlass::gemm::GemmCoord tile_offset = problem_visitor.threadblock_offset(threadblock_idx); -+ -+ problem_visitor.advance(gridDim.x); -+ -+ // -+ // Early exit conditions -+ // 1) Out of range -+ // 2) Upper-triangular block in lower-triangular problem -+ // 3) Lower-triangular block in upper-triangular problem -+ // -+ -+ if (grid_shape.m() <= tile_offset.m() || -+ grid_shape.n() <= tile_offset.n()) { -+ continue; -+ } -+ -+ if (ProblemVisitor::kFillModeC == cutlass::FillMode::kLower && -+ (tile_offset.m() + 1) * ProblemVisitor::ThreadblockShape::kM <= tile_offset.n() * ProblemVisitor::ThreadblockShape::kN) { -+ continue; -+ } -+ -+ if (ProblemVisitor::kFillModeC == cutlass::FillMode::kUpper && -+ tile_offset.m() * ProblemVisitor::ThreadblockShape::kM >= (tile_offset.n() + 1) * ProblemVisitor::ThreadblockShape::kN) { -+ continue; -+ } -+ -+ if (threadIdx.x == 0) { -+ params.visited_problems_ptr[store_offset] = problem_idx; -+ params.visited_tiles_ptr[store_offset] = threadblock_idx; -+ ++store_offset; -+ } -+ } -+ } -+}; -+ -+template -+struct ProblemVisitorRunner { -+ using BaseKernel = ProblemVisitorKernel; -+ using Params = typename BaseKernel::Params; -+ -+ Params params; -+ std::vector host_problem_sizes; -+ int32_t problem_count; -+ int32_t threadblock_count; -+ int32_t visits_per_block; -+ cutlass::DeviceAllocation visited_problems; -+ cutlass::DeviceAllocation visited_tiles; -+ cutlass::DeviceAllocation device_problem_sizes; -+ cutlass::DeviceAllocation workspace; -+ std::vector host_visited_problems; -+ std::vector host_visited_tiles; -+ -+ ProblemVisitorRunner(const std::vector& host_problem_sizes_, -+ int32_t threadblock_count_): -+ host_problem_sizes(host_problem_sizes_), -+ problem_count(int32_t(host_problem_sizes_.size())), -+ threadblock_count(threadblock_count_) {} -+ -+ /// Initializes GEMM state from arguments. -+ cutlass::Status initialize() { -+ size_t workspace_bytes = ProblemVisitor::get_workspace_size( -+ host_problem_sizes.data(), -+ problem_count, -+ threadblock_count); -+ -+ workspace.reset(workspace_bytes); -+ std::vector host_workspace(workspace_bytes); -+ -+ int32_t tile_count = ProblemVisitor::group_tile_count(host_problem_sizes.data(), problem_count); -+ -+ ProblemVisitor::host_precompute(host_problem_sizes.data(), problem_count, -+ threadblock_count, host_workspace.data()); -+ -+ workspace.copy_from_host(host_workspace.data(), workspace_bytes); -+ -+ device_problem_sizes.reset(problem_count); -+ device_problem_sizes.copy_from_host(host_problem_sizes.data(), problem_count); -+ -+ visits_per_block = (tile_count - 1 + threadblock_count) / threadblock_count; -+ int32_t total_visits = visits_per_block * threadblock_count; -+ -+ visited_problems.reset(total_visits); -+ visited_tiles.reset(total_visits); -+ host_visited_problems.resize(total_visits); -+ host_visited_tiles.resize(total_visits); -+ -+ cudaError_t result = cudaMemset(visited_problems.get(), -1, sizeof(int32_t) * total_visits); -+ if (result != cudaSuccess) { -+ return cutlass::Status::kErrorInternal; -+ } -+ -+ result = cudaMemset(visited_tiles.get(), -1, sizeof(int32_t) * total_visits); -+ if (result != cudaSuccess) { -+ return cutlass::Status::kErrorInternal; -+ } -+ -+ typename ProblemVisitor::Params pv_params(device_problem_sizes.get(), problem_count, workspace.get(), tile_count); -+ params = Params(pv_params, visited_problems.get(), visited_tiles.get(), visits_per_block); -+ -+ return cutlass::Status::kSuccess; -+ } -+ -+ bool verify() { -+ // Sort by problem size and then by threadblock_idx -+ std::vector indices(host_visited_problems.size()); -+ std::iota(indices.begin(), indices.end(), 0); -+ -+ std::stable_sort(indices.begin(), indices.end(), -+ [&](int32_t i1, int32_t i2) { -+ if (host_visited_problems[i1] == host_visited_problems[i2]) { -+ return host_visited_tiles[i1] < host_visited_tiles[i2]; -+ } -+ return host_visited_problems[i1] < host_visited_problems[i2]; -+ }); -+ -+ int32_t idx = 0; -+ -+ // Skip any entries that were not visited -+ while (host_visited_problems[indices[idx]] == -1) { -+ ++idx; -+ } -+ -+ // Check that each problem visited has the tiles we expect -+ for (int32_t problem_idx = 0; problem_idx < problem_count; ++problem_idx) { -+ auto problem = host_problem_sizes[problem_idx]; -+ ProblemVisitor::possibly_transpose_problem(problem); -+ int32_t problem_tiles = ProblemVisitor::tile_count(ProblemVisitor::grid_shape(problem)); -+ for (int i = 0; i < problem_tiles; ++i) { -+ EXPECT_EQ(problem_idx, host_visited_problems[indices[idx]]); -+ EXPECT_EQ(i, host_visited_tiles[indices[idx]]); -+ ++idx; -+ } -+ } -+ -+ return true; -+ } -+ -+ bool run(bool skip_tile_check=false, cudaStream_t stream = nullptr) { -+ cutlass::Status status = initialize(); -+ if (status != cutlass::Status::kSuccess) { -+ std::cerr << "Initialization failed" << std::endl; -+ return false; -+ } -+ -+ dim3 grid(threadblock_count, 1, 1); -+ dim3 block(ProblemVisitor::kThreadCount, 1, 1); -+ int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); -+ -+ cutlass::Kernel<<>>(params); -+ -+ cudaError_t result = cudaGetLastError(); -+ if (result != cudaSuccess) { -+ std::cerr << "grid launch failed with error " << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ -+ result = cudaDeviceSynchronize(); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaDeviceSynchronize failed with error " << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ -+ visited_problems.copy_to_host(host_visited_problems.data()); -+ visited_tiles.copy_to_host(host_visited_tiles.data()); -+ -+ if (skip_tile_check) { -+ return true; -+ } -+ -+ return verify(); -+ } -+}; -+ -+template -+struct TestbedGroupedRank2KScheduler { -+ -+ using BaselinePV = BaselineProblemVisitor, -+ ThreadblockShape, -+ PrefetchTileCount, -+ ThreadCount, -+ FillModeC>; -+ -+ // -+ // Data members -+ // -+ -+ // Whether to skip checking that the tiles are visited as expected. This is useful -+ // in cases where ThreadblockShape::kM != ThreadblockShape::kN, for which the grouped -+ // Rank2K scheduler may assign out-of-bounds tiles that will cause a threadblock to -+ // exit early, but which are difficult to detect in tests without reimplementing -+ // this functionality. -+ bool skip_tile_check; -+ uint32_t seed; -+ int problem_count; -+ int threadblock_count; -+ std::vector problem_sizes_host; -+ -+ // -+ // Methods -+ // -+ -+ TestbedGroupedRank2KScheduler(bool skip_tile_check_=false, uint32_t seed_ = 3080): -+ skip_tile_check(skip_tile_check_), seed(seed_) { srand(seed); } -+ -+ /// Initializes data structures -+ void initialize(int32_t scale_factor) { -+ -+ // -+ // Choose random problem sizes -+ // -+ -+ problem_sizes_host.clear(); -+ problem_sizes_host.resize(problem_count); -+ -+ for (int32_t i = 0; i < problem_count; ++i) { -+ int n = scale_factor * (rand() % 64) + 24; -+ -+ cutlass::gemm::GemmCoord problem( -+ n, -+ n, -+ scale_factor * (rand() % 64) + 24); -+ -+ problem_sizes_host.at(i) = problem; -+ } -+ } -+ -+ template -+ void compare_visitors(const ProblemVisitorRunner& baseline_runner) { -+ using PV = cutlass::gemm::kernel::Rank2KGroupedProblemVisitor< -+ ThreadblockShape, -+ GroupScheduleMode_, -+ PrefetchTileCount, -+ ThreadCount, -+ FillModeC>; -+ ProblemVisitorRunner runner(problem_sizes_host, threadblock_count); -+ EXPECT_TRUE(runner.run(skip_tile_check)); -+ -+ // Check that this problem visitor visits the same problems and tiles as the baseline -+ EXPECT_EQ(baseline_runner.host_visited_problems, runner.host_visited_problems); -+ EXPECT_EQ(baseline_runner.host_visited_tiles, runner.host_visited_tiles); -+ } -+ -+ template -+ void compare_visitors(const ProblemVisitorRunner& baseline_runner) { -+ // Compare the next visitor with the baseline visitor -+ compare_visitors(baseline_runner); -+ -+ // Recurse to compare the next visitors -+ compare_visitors(baseline_runner); -+ } -+ -+ /// Executes the test on all scheduler modes -+ void run(int problem_count, int threadblock_count, int scale_factor=8) { -+ -+ this->problem_count = problem_count; -+ this->threadblock_count = threadblock_count; -+ -+ // Initialize the problem -+ initialize(scale_factor); -+ -+ // Run the baseline visitor to which we will compare all other visitors -+ ProblemVisitorRunner baseline_runner(problem_sizes_host, threadblock_count); -+ EXPECT_TRUE(baseline_runner.run(skip_tile_check)); -+ -+ compare_visitors(baseline_runner); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // device -+} // gemm -+} // test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/testbed_grouped_scheduler.h b/3rdparty/cutlass/test/unit/gemm/device/testbed_grouped_scheduler.h -new file mode 100644 -index 0000000..00d83b6 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/testbed_grouped_scheduler.h -@@ -0,0 +1,407 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for grouped GEMM problem visitors -+*/ -+ -+#pragma once -+ -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" -+#include "cutlass/gemm/kernel/grouped_problem_visitor.h" -+#include "cutlass/util/device_memory.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Use simple problem visitor as a baseline -+template -+struct BaselineProblemVisitor : public cutlass::gemm::kernel::BaseGroupedProblemVisitor { -+ using Base = cutlass::gemm::kernel::BaseGroupedProblemVisitor; -+ using Params = typename Base::Params; -+ static int const kThreadCount = ThreadCount; -+ -+ struct SharedStorage {}; -+ -+ int32_t tile_count_sum; -+ SharedStorage &shared_storage; -+ -+ // -+ // Methods -+ // -+ CUTLASS_DEVICE -+ BaselineProblemVisitor( -+ Params const ¶ms_, -+ SharedStorage &shared_storage_, -+ int32_t block_idx -+ ): Base(params_, block_idx), -+ shared_storage(shared_storage_) -+ { -+ cutlass::gemm::GemmCoord problem = this->problem_size(); -+ cutlass::gemm::GemmCoord grid = this->grid_shape(problem); -+ tile_count_sum = this->tile_count(grid); -+ } -+ -+ CUTLASS_DEVICE -+ bool next_tile() { -+ if (this->tile_idx < tile_count_sum) { -+ return true; -+ } -+ -+ do { -+ ++this->problem_idx; -+ -+ if (this->problem_idx >= this->params.problem_count) { -+ return false; -+ } -+ -+ cutlass::gemm::GemmCoord problem = this->problem_size(); -+ cutlass::gemm::GemmCoord grid = this->grid_shape(problem); -+ -+ this->problem_tile_start = tile_count_sum; -+ tile_count_sum += this->tile_count(grid); -+ -+ } while (tile_count_sum <= this->tile_idx); -+ -+ return true; -+ } -+ -+ static size_t get_workspace_size(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, -+ int32_t problem_count, -+ int32_t block_count) { -+ return 0; -+ } -+ -+ static void host_precompute(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, -+ int32_t problem_count, -+ int32_t block_count, -+ void* host_workspace_ptr) {} -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct ProblemVisitorKernel { -+ struct SharedStorage { -+ typename ProblemVisitor::SharedStorage problem_visitor; -+ }; -+ -+ struct Params { -+ typename ProblemVisitor::Params problem_visitor_params; -+ int32_t* visited_problems_ptr; -+ int32_t* visited_tiles_ptr; -+ int32_t visits_per_block; -+ -+ Params(): -+ visited_problems_ptr(nullptr), -+ visited_tiles_ptr(nullptr), -+ visits_per_block(0) {} -+ -+ Params(typename ProblemVisitor::Params problem_visitor_params_, -+ int32_t* visited_problems_ptr_, -+ int32_t* visited_tiles_ptr_, -+ int32_t visits_per_block_): -+ problem_visitor_params(problem_visitor_params_), -+ visited_problems_ptr(visited_problems_ptr_), -+ visited_tiles_ptr(visited_tiles_ptr_), -+ visits_per_block(visits_per_block_) {} -+ }; -+ -+ CUTLASS_DEVICE -+ void operator()(const Params& params, SharedStorage &shared_storage) { -+ int32_t store_offset = params.visits_per_block * blockIdx.x; -+ ProblemVisitor problem_visitor(params.problem_visitor_params, -+ shared_storage.problem_visitor, -+ blockIdx.x); -+ -+ while (problem_visitor.next_tile()) { -+ int32_t problem_idx = problem_visitor.problem_index(); -+ int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); -+ -+ if (threadIdx.x == 0) { -+ params.visited_problems_ptr[store_offset] = problem_idx; -+ params.visited_tiles_ptr[store_offset] = threadblock_idx; -+ ++store_offset; -+ } -+ problem_visitor.advance(gridDim.x); -+ } -+ } -+}; -+ -+template -+struct ProblemVisitorRunner { -+ using BaseKernel = ProblemVisitorKernel; -+ using Params = typename BaseKernel::Params; -+ -+ Params params; -+ std::vector host_problem_sizes; -+ int32_t problem_count; -+ int32_t threadblock_count; -+ int32_t visits_per_block; -+ cutlass::DeviceAllocation visited_problems; -+ cutlass::DeviceAllocation visited_tiles; -+ cutlass::DeviceAllocation device_problem_sizes; -+ cutlass::DeviceAllocation workspace; -+ std::vector host_visited_problems; -+ std::vector host_visited_tiles; -+ -+ ProblemVisitorRunner(const std::vector& host_problem_sizes_, -+ int32_t threadblock_count_): -+ host_problem_sizes(host_problem_sizes_), -+ problem_count(int32_t(host_problem_sizes_.size())), -+ threadblock_count(threadblock_count_) {} -+ -+ /// Initializes GEMM state from arguments. -+ cutlass::Status initialize() { -+ size_t workspace_bytes = ProblemVisitor::get_workspace_size( -+ host_problem_sizes.data(), -+ problem_count, -+ threadblock_count); -+ -+ workspace.reset(workspace_bytes); -+ std::vector host_workspace(workspace_bytes); -+ -+ int32_t tile_count = ProblemVisitor::group_tile_count(host_problem_sizes.data(), problem_count); -+ -+ ProblemVisitor::host_precompute(host_problem_sizes.data(), problem_count, -+ threadblock_count, host_workspace.data()); -+ -+ workspace.copy_from_host(host_workspace.data(), workspace_bytes); -+ -+ device_problem_sizes.reset(problem_count); -+ device_problem_sizes.copy_from_host(host_problem_sizes.data(), problem_count); -+ -+ visits_per_block = (tile_count - 1 + threadblock_count) / threadblock_count; -+ int32_t total_visits = visits_per_block * threadblock_count; -+ -+ visited_problems.reset(total_visits); -+ visited_tiles.reset(total_visits); -+ host_visited_problems.resize(total_visits); -+ host_visited_tiles.resize(total_visits); -+ -+ cudaError_t result = cudaMemset(visited_problems.get(), -1, sizeof(int32_t) * total_visits); -+ if (result != cudaSuccess) { -+ return cutlass::Status::kErrorInternal; -+ } -+ -+ result = cudaMemset(visited_tiles.get(), -1, sizeof(int32_t) * total_visits); -+ if (result != cudaSuccess) { -+ return cutlass::Status::kErrorInternal; -+ } -+ -+ typename ProblemVisitor::Params pv_params(device_problem_sizes.get(), problem_count, workspace.get(), tile_count); -+ params = Params(pv_params, visited_problems.get(), visited_tiles.get(), visits_per_block); -+ -+ return cutlass::Status::kSuccess; -+ } -+ -+ bool verify() { -+ // Sort by problem size and then by threadblock_idx -+ std::vector indices(host_visited_problems.size()); -+ std::iota(indices.begin(), indices.end(), 0); -+ -+ std::stable_sort(indices.begin(), indices.end(), -+ [&](int32_t i1, int32_t i2) { -+ if (host_visited_problems[i1] == host_visited_problems[i2]) { -+ return host_visited_tiles[i1] < host_visited_tiles[i2]; -+ } -+ return host_visited_problems[i1] < host_visited_problems[i2]; -+ }); -+ -+ int32_t idx = 0; -+ -+ // Skip any entries that were not visited -+ while (host_visited_problems[indices[idx]] == -1) { -+ ++idx; -+ } -+ -+ // Check that each problem visited has the tiles we expect -+ for (int32_t problem_idx = 0; problem_idx < problem_count; ++problem_idx) { -+ auto problem = host_problem_sizes[problem_idx]; -+ ProblemVisitor::possibly_transpose_problem(problem); -+ int32_t problem_tiles = ProblemVisitor::tile_count(ProblemVisitor::grid_shape(problem)); -+ for (int i = 0; i < problem_tiles; ++i) { -+ EXPECT_EQ(problem_idx, host_visited_problems[indices[idx]]); -+ EXPECT_EQ(i, host_visited_tiles[indices[idx]]); -+ ++idx; -+ } -+ } -+ -+ return true; -+ } -+ -+ bool run(cudaStream_t stream = nullptr) { -+ cutlass::Status status = initialize(); -+ if (status != cutlass::Status::kSuccess) { -+ std::cerr << "Initialization failed" << std::endl; -+ return false; -+ } -+ -+ dim3 grid(threadblock_count, 1, 1); -+ dim3 block(ProblemVisitor::kThreadCount, 1, 1); -+ int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); -+ -+ cutlass::Kernel<<>>(params); -+ -+ cudaError_t result = cudaGetLastError(); -+ if (result != cudaSuccess) { -+ std::cerr << "grid launch failed with error " << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ -+ result = cudaDeviceSynchronize(); -+ if (result != cudaSuccess) { -+ std::cerr << "cudaDeviceSynchronize failed with error " << cudaGetErrorString(result) << std::endl; -+ return false; -+ } -+ -+ visited_problems.copy_to_host(host_visited_problems.data()); -+ visited_tiles.copy_to_host(host_visited_tiles.data()); -+ -+ return verify(); -+ } -+}; -+ -+template -+struct TestbedGroupedGemmScheduler { -+ -+ using PSHelper = cutlass::gemm::kernel::detail::GemmGroupedProblemSizeHelper; -+ using BaselinePV = BaselineProblemVisitor; -+ -+ // -+ // Data members -+ // -+ uint32_t seed; -+ int problem_count; -+ int threadblock_count; -+ std::vector problem_sizes_host; -+ -+ // -+ // Methods -+ // -+ -+ TestbedGroupedGemmScheduler(uint32_t seed_ = 3080): -+ seed(seed_) { srand(seed); } -+ -+ /// Initializes data structures -+ void initialize(int32_t scale_factor) { -+ -+ // -+ // Choose random problem sizes -+ // -+ -+ problem_sizes_host.clear(); -+ problem_sizes_host.resize(problem_count); -+ -+ for (int32_t i = 0; i < problem_count; ++i) { -+ -+ cutlass::gemm::GemmCoord problem( -+ scale_factor * (rand() % 64) + 24, -+ scale_factor * (rand() % 64) + 24, -+ scale_factor * (rand() % 64) + 24); -+ -+ problem_sizes_host.at(i) = problem; -+ } -+ } -+ -+ template -+ void compare_visitors(const ProblemVisitorRunner& baseline_runner) { -+ using PV = cutlass::gemm::kernel::GemmGroupedProblemVisitor< -+ ThreadblockShape, -+ GroupScheduleMode_, -+ PrefetchTileCount, -+ ThreadCount, -+ Transpose>; -+ ProblemVisitorRunner runner(problem_sizes_host, threadblock_count); -+ EXPECT_TRUE(runner.run()); -+ -+ // Check that this problem visitor visits the same problems and tiles as the baseline -+ EXPECT_EQ(baseline_runner.host_visited_problems, runner.host_visited_problems); -+ EXPECT_EQ(baseline_runner.host_visited_tiles, runner.host_visited_tiles); -+ } -+ -+ template -+ void compare_visitors(const ProblemVisitorRunner& baseline_runner) { -+ // Compare the next visitor with the baseline visitor -+ compare_visitors(baseline_runner); -+ -+ // Recurse to compare the next visitors -+ compare_visitors(baseline_runner); -+ } -+ -+ /// Executes the test on all scheduler modes -+ void run(int problem_count, int threadblock_count, int scale_factor=8) { -+ -+ this->problem_count = problem_count; -+ this->threadblock_count = threadblock_count; -+ -+ // Initialize the problem -+ initialize(scale_factor); -+ -+ // Run the baseline visitor to which we will compare all other visitors -+ ProblemVisitorRunner baseline_runner(problem_sizes_host, threadblock_count); -+ EXPECT_TRUE(baseline_runner.run()); -+ -+ compare_visitors(baseline_runner); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // device -+} // gemm -+} // test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/testbed_interleaved.h b/3rdparty/cutlass/test/unit/gemm/device/testbed_interleaved.h -new file mode 100644 -index 0000000..b54a4b6 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/testbed_interleaved.h -@@ -0,0 +1,347 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/host_reorder.h" -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct InterleavedTestbed { -+ -+ using ElementA = typename Gemm::ElementA; -+ using ElementB = typename Gemm::ElementB; -+ using ElementC = typename Gemm::ElementC; -+ using ElementAccumulator = typename Gemm::ElementAccumulator; -+ using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint64_t seed; -+ -+ // -+ // Methods -+ // -+ -+ InterleavedTestbed( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, 2, -2, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else { -+ // TODO: Implement the rest -+ EXPECT_TRUE(false) << "Not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Waives test if CUDA device is insufficient -+ bool sufficient() const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename Gemm::GemmKernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha = ElementCompute(1), -+ ElementCompute beta = ElementCompute(0)) { -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ -+ // -+ // Allocate the GEMM workspace -+ // -+ -+ cutlass::HostTensor< -+ typename Gemm::ElementA, -+ typename Gemm::LayoutA> tensor_A(problem_size.mk()); -+ -+ cutlass::HostTensor< -+ typename Gemm::ElementB, -+ typename Gemm::LayoutB> tensor_B(problem_size.kn()); -+ -+ cutlass::HostTensor< -+ typename Gemm::ElementB, -+ typename Gemm::LayoutB> tensor_B_reordered(problem_size.kn()); -+ -+ cutlass::HostTensor< -+ typename Gemm::ElementC, -+ typename Gemm::LayoutC> tensor_C(problem_size.mn()); -+ -+ cutlass::HostTensor< -+ typename Gemm::ElementC, -+ typename Gemm::LayoutC> tensor_D(problem_size.mn()); -+ -+ cutlass::HostTensor< -+ typename Gemm::ElementC, -+ typename Gemm::LayoutC> reference_D(problem_size.mn(), false); -+ -+ EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); -+ EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); -+ EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); -+ -+ cutlass::reorder_column( -+ tensor_B_reordered.host_ref(), tensor_B.host_ref(), problem_size); -+ -+ cutlass::reference::host::TensorCopy( -+ reference_D.host_view(), -+ tensor_C.host_view()); -+ -+ tensor_A.sync_device(); -+ tensor_B_reordered.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D.sync_device(); -+ -+ // -+ // Initialize the GEMM operator -+ // -+ -+ typename Gemm::Arguments arguments{ -+ problem_size, -+ tensor_A.device_ref(), -+ tensor_B_reordered.device_ref(), -+ tensor_C.device_ref(), -+ tensor_D.device_ref(), -+ {alpha, beta} -+ }; -+ -+ Gemm gemm_op; -+ -+ cutlass::Status status = gemm_op.initialize(arguments); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ -+ // -+ // Run the GEMM -+ // -+ -+ status = gemm_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ -+ // -+ // Verify -+ // -+ -+ cutlass::reference::host::Gemm< -+ typename Gemm::ElementA, typename Gemm::LayoutA, -+ typename Gemm::ElementB, typename Gemm::LayoutB, -+ typename Gemm::ElementC, typename Gemm::LayoutC, ElementCompute, -+ ElementAccumulator, typename Gemm::Operator> -+ reference_gemm; -+ -+ reference_gemm( -+ problem_size, -+ alpha, -+ tensor_A.host_ref(), -+ tensor_B.host_ref(), -+ beta, -+ reference_D.host_ref(), -+ ElementAccumulator(0) -+ ); -+ -+ tensor_D.sync_host(); -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ reference_D.host_view(), -+ tensor_D.host_view()); -+ -+ EXPECT_TRUE(passed); -+ if (!passed) { -+ -+ std::stringstream fname; -+ -+ fname << "error_Gemm_device_" -+ << problem_size.m() << "x" -+ << problem_size.n() << "x" -+ << problem_size.k() << "_" -+ << Gemm::ThreadblockShape::kM << "x" -+ << Gemm::ThreadblockShape::kN << "x" -+ << Gemm::ThreadblockShape::kK << "_" -+ << Gemm::WarpShape::kM << "x" -+ << Gemm::WarpShape::kN << "x" -+ << Gemm::WarpShape::kK << ".txt"; -+ -+ std::ofstream file(fname.str()); -+ -+ file -+ << "problem: " << problem_size -+ << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; -+ -+ file -+ << "A =\n" << tensor_A.host_view() -+ << "\nB =\n" << tensor_B.host_view() -+ << "\nB_reordered =\n" << tensor_B_reordered.host_view() -+ << "\nC =\n" << tensor_C.host_view() -+ << "\n\nReference =\n" << reference_D.host_view() -+ << "\nComputed =\n" << tensor_D.host_view(); -+ } -+ -+ return passed; -+ } -+ -+ /// Runs a set of problem sizes -+ bool run_all() { -+ bool passed = true; -+ -+ int problem_size_m[] = { -+ InterleavedK, 256 + InterleavedK, 512 + InterleavedK -+ }; -+ -+ int problem_size_n[] = { -+ InterleavedK, 256 + InterleavedK, 512 + InterleavedK -+ }; -+ -+ int problem_size_k[] = { -+ InterleavedK, 256 + InterleavedK, 512 + InterleavedK -+ }; -+ -+ double problem_alpha[] = { -+ 1.0 -+ }; -+ -+ double problem_beta[] = { -+ 2.0 -+ }; -+ -+ for (int m : problem_size_m) { -+ for (int n : problem_size_n) { -+ for (int k : problem_size_k) { -+ for (double alpha : problem_alpha) { -+ for (double beta : problem_beta) { -+ -+ passed = run( -+ {m, n, k}, -+ ElementCompute(alpha), -+ ElementCompute(beta) -+ ); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ return true; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace test -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/testbed_planar_complex.h b/3rdparty/cutlass/test/unit/gemm/device/testbed_planar_complex.h -new file mode 100644 -index 0000000..a721cc8 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/testbed_planar_complex.h -@@ -0,0 +1,326 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/gemm_planar_complex.h" -+#include "cutlass/util/host_tensor_planar_complex.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template -+class TestbedPlanarComplex { -+public: -+ -+ using ElementA = typename Gemm::ElementA; -+ using LayoutA = typename Gemm::LayoutA; -+ using ElementB = typename Gemm::ElementB; -+ using LayoutB = typename Gemm::LayoutB; -+ using ElementC = typename Gemm::ElementC; -+ using LayoutC = typename Gemm::LayoutC; -+ using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; -+ using ElementAccumulator = typename Gemm::ElementAccumulator; -+ -+ // -+ // Data members -+ // -+ -+ cutlass::gemm::GemmCoord problem_size; -+ cutlass::HostTensorPlanarComplex tensor_A; -+ cutlass::HostTensorPlanarComplex tensor_B; -+ cutlass::HostTensorPlanarComplex tensor_C; -+ cutlass::HostTensorPlanarComplex tensor_D; -+ cutlass::HostTensorPlanarComplex tensor_D_ref; -+ -+ // -+ // Methods -+ // -+ -+ TestbedPlanarComplex(cutlass::gemm::GemmCoord const & problem_size): problem_size(problem_size) { -+ -+ tensor_A.reset({problem_size.m(), problem_size.k()}); -+ tensor_B.reset({problem_size.k(), problem_size.n()}); -+ tensor_C.reset({problem_size.m(), problem_size.n()}); -+ tensor_D.reset({problem_size.m(), problem_size.n()}); -+ tensor_D_ref.reset({problem_size.m(), problem_size.n()}, false); -+ } -+ -+ void initialize() { -+ -+ uint64_t seed = 1073; -+ -+ int scope_max = 8; -+ int scope_min = -8; -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_A.host_view(), seed, scope_max, scope_min, 0); -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_B.host_view(), seed * 2019, scope_max, scope_min, 0); -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_C.host_view(), seed * 2020, scope_max, scope_min, 0); -+ -+ cutlass::reference::host::TensorFill(tensor_D.host_view(), cutlass::complex()); -+ cutlass::reference::host::TensorFill(tensor_D_ref.host_view(), cutlass::complex()); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D.sync_device(); -+ } -+ -+ /// Returns true if the CUDA device is sufficient to execute the kernel. -+ bool sufficient() const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename Gemm::GemmKernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ bool run( -+ cutlass::complex alpha = {1, 0}, -+ cutlass::complex beta = {0, 0}) { -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ -+ initialize(); -+ -+ int batch_count = 1; -+ -+ ElementA *ptr_A = tensor_A.device_data(); -+ ElementB *ptr_B = tensor_B.device_data(); -+ ElementC *ptr_C = tensor_C.device_data(); -+ ElementC *ptr_D = tensor_D.device_data(); -+ -+ typename LayoutA::Stride::Index lda = tensor_A.layout().stride(0); -+ typename LayoutB::Stride::Index ldb = tensor_B.layout().stride(0); -+ typename LayoutC::Stride::Index ldc = tensor_C.layout().stride(0); -+ typename LayoutC::Stride::Index ldd = tensor_D.layout().stride(0); -+ -+ int64_t imag_stride_A = tensor_A.imaginary_stride(); -+ int64_t imag_stride_B = tensor_B.imaginary_stride(); -+ int64_t imag_stride_C = tensor_C.imaginary_stride(); -+ int64_t imag_stride_D = tensor_D.imaginary_stride(); -+ -+ // -+ // Launch device kernel -+ // -+ -+ Gemm gemm_op; -+ -+ typename Gemm::Arguments args{ -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ problem_size, -+ batch_count, -+ {alpha, beta}, -+ ptr_A, -+ ptr_A + imag_stride_A, -+ ptr_B, -+ ptr_B + imag_stride_B, -+ ptr_C, -+ ptr_C + imag_stride_C, -+ ptr_D, -+ ptr_D + imag_stride_D, -+ lda, -+ lda, -+ ldb, -+ ldb, -+ ldc, -+ ldc, -+ ldd, -+ ldd -+ }; -+ -+ cutlass::Status status = gemm_op(args); -+ -+ EXPECT_EQ(status, cutlass::Status::kSuccess); -+ -+ cudaError_t error = cudaDeviceSynchronize(); -+ -+ tensor_D.sync_host(); -+ -+ // -+ // Compute reference -+ // -+ -+ cutlass::reference::host::GemmPlanarComplex< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementAccumulator -+ >( -+ problem_size, -+ alpha, -+ tensor_A.host_ref(), -+ Gemm::kTransformA, -+ tensor_B.host_ref(), -+ Gemm::kTransformB, -+ beta, -+ tensor_C.host_ref(), -+ tensor_D_ref.host_ref() -+ ); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_D.host_view(), -+ tensor_D_ref.host_view() -+ ); -+ -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ std::ofstream output("gemm_planar_complex.txt"); -+ -+ output -+ << "A:\n" << tensor_A.host_view() << "\n" -+ << "B:\n" << tensor_B.host_view() << "\n" -+ << "C:\n" << tensor_C.host_view() << "\n" -+ << "Reference:\n" -+ << tensor_D_ref.host_view() << "\n" -+ << "Computed:\n" -+ << tensor_D.host_view() << "\n"; -+ } -+ -+ return passed; -+ } -+}; -+ -+template -+bool TestOneGemmPlanarComplex(cutlass::gemm::GemmCoord problem_size) { -+ -+ TestbedPlanarComplex testbed(problem_size); -+ -+ return testbed.run(); -+} -+ -+template -+bool TestAllGemmPlanarComplex() { -+ -+ int M[] = { -+ 16, 64, 72, 144, 264, 520, -+ }; -+ -+ int N[] = { -+ 16, 64, 72, 144, 248, 264, 520 -+ }; -+ -+ int K[] = { -+ 8, 64, 72, 96, 264, 520 -+ }; -+ -+ using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; -+ -+ cutlass::complex alpha_values[] = { -+ {ElementCompute(1.25), ElementCompute(-0.5)} -+ }; -+ -+ cutlass::complex beta_values[] = { -+ {ElementCompute(-2.25), ElementCompute(1.5)} -+ }; -+ -+ for (int m : M) { -+ for (int n : N) { -+ for (int k : K) { -+ -+ test::gemm::device::TestbedPlanarComplex testbed({m, n, k}); -+ -+ for (auto const &alpha : alpha_values) { -+ for (auto const &beta : beta_values) { -+ -+ bool passed = testbed.run(alpha, beta); -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ return true; -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/testbed_rank2k_universal.h b/3rdparty/cutlass/test/unit/gemm/device/testbed_rank2k_universal.h -new file mode 100644 -index 0000000..29f3989 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/testbed_rank2k_universal.h -@@ -0,0 +1,641 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Rank 2k update interface -+ -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/error_metrics.h" -+#include "cutlass/util/reference/host/rank_2k.h" -+#include "cutlass/util/reference/host/rank_2k_complex.h" -+ -+#include "testbed_utils.h" -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct TestbedRank2KUniversal { -+ -+ using ElementA = typename Rank2K::ElementA; -+ using ElementB = typename Rank2K::ElementB; -+ using ElementC = typename Rank2K::ElementC; -+ using ElementAccumulator = typename Rank2K::ElementAccumulator; -+ using ElementCompute = typename Rank2K::Rank2Kkernel::Epilogue::OutputOp::ElementCompute; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D; -+ cutlass::HostTensor reference_D; -+ -+ // -+ // Methods -+ // -+ -+ TestbedRank2KUniversal( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed, -+ int mantissa_in_bits) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ double scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope_max, scope_min, mantissa_in_bits); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5, mantissa_in_bits); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else { -+ -+ EXPECT_TRUE(false) << "Input distribution not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_symmetric_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed, -+ int mantissa_in_bits) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ double scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::host::TensorFillSymmetricRandomUniform( -+ view, seed, Rank2K::kFillModeC, scope_max, scope_min, mantissa_in_bits); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillSymmetricRandomGaussian( -+ view, seed, Rank2K::kFillModeC, 0, 0.5, mantissa_in_bits); -+ } -+ else { -+ -+ EXPECT_TRUE(false) << "Input distribution (symmetric tensor) not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ /// Initializes data structures -+ void initialize(cutlass::gemm::GemmCoord problem_size) { -+ // -+ // Allocate the Rank2K workspace -+ // -+ -+ tensor_A.resize(problem_size.mk()); -+ tensor_B.resize(problem_size.mk()); -+ tensor_C.resize(problem_size.mn()); -+ tensor_D.resize(problem_size.mn()); -+ reference_D.resize(problem_size.mn(), false); -+ -+ EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019, cutlass::MantissaInBits::bits)); -+ EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018, cutlass::MantissaInBits::bits)); -+ EXPECT_TRUE(initialize_symmetric_tensor(tensor_C.host_view(), init_C, seed + 2017, cutlass::MantissaInBits::bits)); -+ -+ // It is possible to randomly initialize to all zeros, so override this with non-zeros -+ // in the upper left corner of each operand. -+ tensor_A.host_view().at({0, 0}) = typename Rank2K::ElementA(1); -+ tensor_B.host_view().at({0, 0}) = typename Rank2K::ElementB(1); -+ tensor_C.host_view().at({0, 0}) = typename Rank2K::ElementC(1); -+ -+ cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D.sync_device(); -+ } -+ -+ /// Compares computed reference with device reference and outputs to a file if incorrect -+ bool compare_reference( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ tensor_D.sync_host(); -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); -+ -+ if (tensor_D.size() > 1) -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); -+ -+ if (reference_D.size() > 1) -+ EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); -+ -+ double l2_norm = cutlass::reference::host::TensorRelativeErrorMetric(reference_D.host_view(), tensor_D.host_view()); -+ -+ bool passed = l2_norm < cutlass::MantissaInBits::error; -+ -+ return passed; -+ } -+ -+ /// Verifies the result is a Rank2K -+ bool verify( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ // -+ // Verify -+ // -+ cutlass::reference::host::Rank2KComplex< -+ typename Rank2K::ElementA, typename Rank2K::LayoutA, -+ typename Rank2K::ElementB, typename Rank2K::LayoutB, -+ typename Rank2K::ElementC, typename Rank2K::LayoutC, -+ ElementCompute, ElementAccumulator -+ >( -+ problem_size, -+ alpha, -+ tensor_A.host_ref(), -+ Rank2K::kTransformA, -+ tensor_B.host_ref(), -+ Rank2K::kTransformB, -+ beta, -+ tensor_C.host_ref(), -+ reference_D.host_ref(), -+ ElementAccumulator(0), -+ Rank2K::kFillModeC, -+ Rank2K::kBlasMode -+ ); -+ -+ return compare_reference(problem_size, alpha, beta); -+ } -+ -+ /// Returns true if the CUDA device is sufficient to execute the kernel. -+ bool sufficient() const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename Rank2K::Rank2Kkernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ return true; -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::gemm::GemmUniversalMode mode, -+ cutlass::gemm::GemmCoord problem_size, -+ int batch_count = 1, -+ ElementCompute alpha = ElementCompute(1), -+ ElementCompute beta = ElementCompute(0)) { -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ -+#if 0 -+ std::cout << "[TestbedRank2KUniversal::run()] problem(m, n, k): " << problem_size -+ << " alpha: " << ElementCompute(alpha) -+ << " beta: " << ElementCompute(beta) << std::endl; -+#endif -+ -+ this->initialize(problem_size); -+ -+ // -+ // Initialize the Rank2K operator -+ // -+ -+ typename Rank2K::Arguments arguments{ -+ mode, -+ problem_size, -+ batch_count, -+ {alpha, beta}, -+ tensor_A.device_data(), -+ tensor_B.device_data(), -+ tensor_C.device_data(), -+ tensor_D.device_data(), -+ problem_size.n() * problem_size.k(), -+ problem_size.n() * problem_size.k(), -+ problem_size.m() * problem_size.n(), -+ problem_size.m() * problem_size.n(), -+ tensor_A.layout().stride(0), -+ tensor_B.layout().stride(0), -+ tensor_C.layout().stride(0), -+ tensor_D.layout().stride(0) -+ }; -+ -+ Rank2K rank2k_op; -+ -+ size_t workspace_size = Rank2K::get_workspace_size(arguments); -+ -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = rank2k_op.initialize(arguments, workspace.get()); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Run the Rank2K -+ // -+ -+ status = rank2k_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Verify -+ // -+ -+ bool passed = this->verify(problem_size, alpha, beta); -+ -+ //if (true) { -+ if (!passed) { -+ std::stringstream fname; -+ -+ fname << "error_Rank2k_device_" -+ << "fill_mode_c_" -+ << (Rank2K::kFillModeC == cutlass::FillMode::kLower ? "lower_" : -+ (Rank2K::kFillModeC == cutlass::FillMode::kUpper ? "upper_" : "invalid_")) -+ << "mnk_" -+ << problem_size.m() << "x" -+ << problem_size.n() << "x" -+ << problem_size.k() << "_" -+ << Rank2K::ThreadblockShape::kM << "x" -+ << Rank2K::ThreadblockShape::kN << "x" -+ << Rank2K::ThreadblockShape::kK << "_" -+ << Rank2K::WarpShape::kM << "x" -+ << Rank2K::WarpShape::kN << "x" -+ << Rank2K::WarpShape::kK << ".txt"; -+ -+ std::cout << fname.str() << std::endl; -+ -+ std::ofstream results(fname.str()); -+ -+ results << problem_size << std::endl; -+ -+ results -+ << "\nA:\n" << tensor_A.host_view() << "\n" -+ << "\nB:\n" << tensor_B.host_view() << "\n" -+ << "\nC:\n" << tensor_C.host_view() << "\n" -+ << "\nD reference:\n" << reference_D.host_view() << "\n" -+ << "\nD computed:\n" << tensor_D.host_view() << "\n"; -+ -+ } -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template -+bool TestRank2kUniversal( -+ cutlass::gemm::GemmCoord const & problem_size, -+ cutlass::gemm::GemmUniversalMode mode, -+ int batch_count, -+ double alpha = 1.0, -+ double beta = 2.0) { -+ -+ bool passed = true; -+ -+ TestbedRank2KUniversal testbed; -+ -+ using ElementCompute = typename Rank2K::EpilogueOutputOp::ElementCompute; -+ -+ passed = testbed.run( -+ mode, -+ problem_size, -+ batch_count, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta) -+ ); -+ -+ return passed; -+} -+ -+template -+bool TestAllRank2KUniversal() { -+ bool passed = true; -+ -+ -+ int const kMinimumOperandElementSize = int(cutlass::sizeof_bits::value); -+ -+ int const kAlignment = cutlass::platform::is_same< -+ typename Rank2K::OperatorClass, -+ cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; -+ -+ // int8_t gemm alignment constraints -+ int const kAlignmentM = cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value ? 4 : kAlignment; -+ -+ int const kAlignmentN = kAlignmentM; -+ -+ int const kAlignmentK = cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value -+ ? 4 : kAlignment; -+ -+ cutlass::gemm::GemmUniversalMode modes[] = { -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ }; -+ -+ int problem_size_n[] = { -+ kAlignmentN, 512 - 2*kAlignmentN -+ }; -+ -+ int problem_size_k[] = { -+ kAlignmentK, -+ Rank2K::ThreadblockShape::kK * Rank2K::kStages - kAlignmentK, -+ Rank2K::ThreadblockShape::kK * Rank2K::kStages * 3 - kAlignmentK -+ }; -+ -+ int batch_counts[] = { // may be interpretted as batch count or split-K slices -+ 1 // Just running one batch for now (removing 2, 3, 5, 7) -+ }; -+ -+ double problem_alpha[] = { -+ 1.0, 3.25 -+ }; -+ -+ double problem_beta[] = { -+ 0.0, 2.15 -+ }; -+ -+ using ElementCompute = typename Rank2K::EpilogueOutputOp::ElementCompute; -+ -+ for (cutlass::gemm::GemmUniversalMode mode : modes) { -+ for (int n : problem_size_n) { -+ for (int k : problem_size_k) { -+ for (int batch_count : batch_counts) { -+ -+ for (auto alpha : problem_alpha) { -+ for (auto beta : problem_beta) { -+ -+ if (mode == cutlass::gemm::GemmUniversalMode::kGemm || -+ mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { -+ -+ // skip very small K problems -+ //if (k / batch_count < 2 * Rank2K::ThreadblockShape::kK) { -+ // continue; -+ //} -+ } -+ -+ cutlass::gemm::GemmCoord problem_size(n, n, k); -+ -+ TestbedRank2KUniversal testbed; -+ -+ passed = testbed.run( -+ mode, -+ problem_size, -+ batch_count, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta) -+ ); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ return passed; -+} -+ -+template -+bool TestAllRank2KHermitianUniversal() { -+ bool passed = true; -+ -+ using ElementCompute = typename Rank2K::EpilogueOutputOp::ElementCompute; -+ using ElementAccumulator = typename Rank2K::ElementAccumulator; -+ -+ int const kMinimumOperandElementSize = int(cutlass::sizeof_bits::value); -+ -+ int const kAlignment = cutlass::platform::is_same< -+ typename Rank2K::OperatorClass, -+ cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; -+ -+ // int8_t gemm alignment constraints -+ int const kAlignmentM = cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value ? 4 : kAlignment; -+ -+ int const kAlignmentN = kAlignmentM; -+ -+ int const kAlignmentK = cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value -+ ? 4 : kAlignment; -+ -+ cutlass::gemm::GemmUniversalMode modes[] = { -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ }; -+ -+ int problem_size_n[] = { -+ kAlignmentN, 512 - 2*kAlignmentN -+ }; -+ -+ int problem_size_k[] = { -+ kAlignmentK, -+ Rank2K::ThreadblockShape::kK * Rank2K::kStages - kAlignmentK, -+ Rank2K::ThreadblockShape::kK * Rank2K::kStages * 3 - kAlignmentK -+ }; -+ -+ int batch_counts[] = { // may be interpretted as batch count or split-K slices -+ 1 // Just running one batch for now (removing 2, 3, 5, 7) -+ }; -+ -+ /* Complex alpha for HER2K */ -+ ElementAccumulator problem_alpha[] = { -+ {1.0}, -+ {1.25, 3.25}, -+ {-0.25, -2.25} -+ }; -+ -+ ElementAccumulator problem_beta[] = { -+ 0.0, -2.25 -+ }; -+ -+ for (cutlass::gemm::GemmUniversalMode mode : modes) { -+ for (int n : problem_size_n) { -+ for (int k : problem_size_k) { -+ for (int batch_count : batch_counts) { -+ -+ for (auto alpha : problem_alpha) { -+ for (auto beta : problem_beta) { -+ -+ if (mode == cutlass::gemm::GemmUniversalMode::kGemm || -+ mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { -+ -+ // skip very small K problems -+ //if (k / batch_count < 2 * Rank2K::ThreadblockShape::kK) { -+ // continue; -+ //} -+ } -+ -+ cutlass::gemm::GemmCoord problem_size(n, n, k); -+ -+ TestbedRank2KUniversal testbed; -+ -+ passed = testbed.run( -+ mode, -+ problem_size, -+ batch_count, -+ alpha, -+ beta -+ ); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ return passed; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/testbed_rank_k_universal.h b/3rdparty/cutlass/test/unit/gemm/device/testbed_rank_k_universal.h -new file mode 100644 -index 0000000..7c403ad ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/testbed_rank_k_universal.h -@@ -0,0 +1,511 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Rank 2k update interface -+ -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/error_metrics.h" -+#include "cutlass/util/reference/host/rank_k_complex.h" -+ -+#include "testbed_utils.h" -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct TestbedRank2KUniversal { -+ -+ using ElementA = typename RankK::ElementA; -+ using ElementC = typename RankK::ElementC; -+ using ElementAccumulator = typename RankK::ElementAccumulator; -+ using ElementCompute = typename RankK::RankKkernel::Epilogue::OutputOp::ElementCompute; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_C; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D; -+ cutlass::HostTensor reference_D; -+ -+ // -+ // Methods -+ // -+ -+ TestbedRank2KUniversal( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_C(init_C_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed, -+ int mantissa_in_bits) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ double scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope_max, scope_min, mantissa_in_bits); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5, mantissa_in_bits); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else { -+ -+ EXPECT_TRUE(false) << "Input distribution not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_symmetric_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed, -+ int mantissa_in_bits) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ double scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::host::TensorFillSymmetricRandomUniform( -+ view, seed, RankK::kFillModeC, scope_max, scope_min, mantissa_in_bits); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillSymmetricRandomGaussian( -+ view, seed, RankK::kFillModeC, 0, 0.5, mantissa_in_bits); -+ } -+ else { -+ -+ EXPECT_TRUE(false) << "Input distribution (symmetric tensor) not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ /// Initializes data structures -+ void initialize(cutlass::gemm::GemmCoord problem_size) { -+ // -+ // Allocate the RankK workspace -+ // -+ -+ tensor_A.resize(problem_size.mk()); -+ tensor_C.resize(problem_size.mn()); -+ tensor_D.resize(problem_size.mn()); -+ reference_D.resize(problem_size.mn(), false); -+ -+ EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019, cutlass::MantissaInBits::bits)); -+ EXPECT_TRUE(initialize_symmetric_tensor(tensor_C.host_view(), init_C, seed + 2017, cutlass::MantissaInBits::bits)); -+ -+ // It is possible to randomly initialize to all zeros, so override this with non-zeros -+ // in the upper left corner of each operand. -+ tensor_A.host_view().at({0, 0}) = typename RankK::ElementA(1); -+ tensor_C.host_view().at({0, 0}) = typename RankK::ElementC(1); -+ -+ cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); -+ -+ tensor_A.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D.sync_device(); -+ } -+ -+ /// Compares computed reference with device reference and outputs to a file if incorrect -+ bool compare_reference( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ tensor_D.sync_host(); -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); -+ -+ if (tensor_D.size() > 1) -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); -+ -+ if (reference_D.size() > 1) -+ EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); -+ -+ double l2_norm = cutlass::reference::host::TensorRelativeErrorMetric(reference_D.host_view(), tensor_D.host_view()); -+ -+ bool passed = l2_norm < cutlass::MantissaInBits::error; -+ -+ return passed; -+ } -+ -+ /// Verifies the result is a RankK -+ bool verify( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ // -+ // Verify -+ // -+ cutlass::reference::host::Rank2KComplex< -+ typename RankK::ElementA, typename RankK::LayoutA, -+ typename RankK::ElementC, typename RankK::LayoutC, -+ ElementCompute, ElementAccumulator -+ >( -+ problem_size, -+ alpha, -+ tensor_A.host_ref(), -+ RankK::kTransformA, -+ beta, -+ tensor_C.host_ref(), -+ reference_D.host_ref(), -+ ElementAccumulator(0), -+ RankK::kFillModeC, -+ RankK::kBlasMode -+ ); -+ -+ return compare_reference(problem_size, alpha, beta); -+ } -+ -+ /// Returns true if the CUDA device is sufficient to execute the kernel. -+ bool sufficient() const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename RankK::RankKkernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::gemm::GemmUniversalMode mode, -+ cutlass::gemm::GemmCoord problem_size, -+ int batch_count = 1, -+ ElementCompute alpha = ElementCompute(1), -+ ElementCompute beta = ElementCompute(0)) { -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ -+#if 0 -+ std::cout << "[TestbedRankKUniversal::run()] problem(m, n, k): " << problem_size -+ << " alpha: " << ElementCompute(alpha) -+ << " beta: " << ElementCompute(beta) << std::endl; -+#endif -+ -+ this->initialize(problem_size); -+ -+ // -+ // Initialize the RankK operator -+ // -+ -+ typename RankK::Arguments arguments{ -+ mode, -+ problem_size, -+ batch_count, -+ {alpha, beta}, -+ tensor_A.device_data(), -+ tensor_C.device_data(), -+ tensor_D.device_data(), -+ problem_size.n() * problem_size.k(), -+ problem_size.m() * problem_size.n(), -+ problem_size.m() * problem_size.n(), -+ tensor_A.layout().stride(0), -+ tensor_C.layout().stride(0), -+ tensor_D.layout().stride(0) -+ }; -+ -+ RankK rank2k_op; -+ -+ size_t workspace_size = RankK::get_workspace_size(arguments); -+ -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = rank2k_op.initialize(arguments, workspace.get()); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Run the RankK -+ // -+ -+ status = rank2k_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Verify -+ // -+ -+ bool passed = this->verify(problem_size, alpha, beta); -+ -+ //if (true) { -+ if (!passed) { -+ std::stringstream fname; -+ -+ fname << "error_RankK_device_" -+ << "fill_mode_c_" -+ << (RankK::kFillModeC == cutlass::FillMode::kLower ? "lower_" : -+ (RankK::kFillModeC == cutlass::FillMode::kUpper ? "upper_" : "invalid_")) -+ << "mnk_" -+ << problem_size.m() << "x" -+ << problem_size.n() << "x" -+ << problem_size.k() << "_" -+ << RankK::ThreadblockShape::kM << "x" -+ << RankK::ThreadblockShape::kN << "x" -+ << RankK::ThreadblockShape::kK << "_" -+ << RankK::WarpShape::kM << "x" -+ << RankK::WarpShape::kN << "x" -+ << RankK::WarpShape::kK << ".txt"; -+ -+ std::cout << fname.str() << std::endl; -+ -+ std::ofstream results(fname.str()); -+ -+ results << problem_size << std::endl; -+ -+ results -+ << "\nA:\n" << tensor_A.host_view() << "\n" -+ << "\nC:\n" << tensor_C.host_view() << "\n" -+ << "\nD reference:\n" << reference_D.host_view() << "\n" -+ << "\nD computed:\n" << tensor_D.host_view() << "\n"; -+ -+ } -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template -+bool TestRank2kUniversal( -+ cutlass::gemm::GemmCoord const & problem_size, -+ cutlass::gemm::GemmUniversalMode mode, -+ int batch_count, -+ double alpha = 1.0, -+ double beta = 2.0) { -+ -+ bool passed = true; -+ -+ TestbedRank2KUniversal testbed; -+ -+ using ElementCompute = typename RankK::EpilogueOutputOp::ElementCompute; -+ -+ passed = testbed.run( -+ mode, -+ problem_size, -+ batch_count, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta) -+ ); -+ -+ return passed; -+} -+ -+template -+bool TestAllRankKUniversal() { -+ bool passed = true; -+ -+ -+ int const kMinimumOperandElementSize = int(cutlass::sizeof_bits::value); -+ int const kAlignmentN = 128 / kMinimumOperandElementSize; -+ int const kAlignmentK = 128 / kMinimumOperandElementSize; -+ -+ cutlass::gemm::GemmUniversalMode modes[] = { -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ }; -+ -+ int problem_size_n[] = { -+ kAlignmentN, 512 - 2*kAlignmentN -+ }; -+ -+ int problem_size_k[] = { -+ kAlignmentK, -+ RankK::ThreadblockShape::kK * RankK::kStages - kAlignmentK, -+ RankK::ThreadblockShape::kK * RankK::kStages * 3 - kAlignmentK -+ }; -+ -+ int batch_counts[] = { // may be interpretted as batch count or split-K slices -+ 1 // Just running one batch for now (removing 2, 3, 5, 7) -+ }; -+ -+ double problem_alpha[] = { -+ 1.0 -+ }; -+ -+ double problem_beta[] = { -+ 2.0 -+ }; -+ -+ -+ using ElementCompute = typename RankK::EpilogueOutputOp::ElementCompute; -+ -+ for (cutlass::gemm::GemmUniversalMode mode : modes) { -+ for (int n : problem_size_n) { -+ for (int k : problem_size_k) { -+ for (int batch_count : batch_counts) { -+ -+ for (auto alpha : problem_alpha) { -+ for (auto beta : problem_beta) { -+ -+ if (mode == cutlass::gemm::GemmUniversalMode::kGemm || -+ mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { -+ } -+ -+ cutlass::gemm::GemmCoord problem_size(n, n, k); -+ -+ TestbedRank2KUniversal testbed; -+ -+ passed = testbed.run( -+ mode, -+ problem_size, -+ batch_count, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta) -+ ); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ return passed; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/testbed_sanity.h b/3rdparty/cutlass/test/unit/gemm/device/testbed_sanity.h -new file mode 100644 -index 0000000..73c0c5c ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/testbed_sanity.h -@@ -0,0 +1,238 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/core_io.h" -+ -+#include "testbed.h" -+ -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// -+// List of Gemm internal paramters this testbed supports user verification -+// -+enum class ParameterID { -+ -+ // Threadblock-level parameters -+ kSmemASize, -+ kSmemBSize, -+ -+ // Warp-level parameters -+ kWarpFragmentASize, -+ kWarpFragmentBSize, -+ kWarpFragmentCSize, -+ kInvalid -+}; -+ -+struct Reference { -+ ParameterID parameter_id; -+ -+ union { -+ int value; -+ -+ struct { -+ int m, n, k; -+ } gemm_shape; -+ -+ struct { -+ int row, column; -+ } matrix_shape; -+ }; -+ -+ std::string error_msg; -+ -+ Reference( -+ ParameterID parameter_id_, -+ int value_=-1, -+ std::string const &error_msg_="") : parameter_id(parameter_id_), value(value_), error_msg(error_msg_) {} -+}; -+ -+ -+template -+struct TestbedSanity { -+ -+ // -+ // Type definitions (All Gemm types top down) -+ // -+ -+ // Unpacking Gemm types in the following order -+ // Kernel-level > Threadblock-level > Warp-level > Instruction-level -+ -+ // kernel-level cutlass Gemm -+ using GemmKernel = typename Gemm::GemmKernel; -+ -+ // -+ // Threadblock-level gemm types -+ // -+ using MmaThreadBlock = typename GemmKernel::Mma; -+ -+ // Threadblock-level gemm shape covering one stage -+ using ThreadblockShape = typename MmaThreadBlock::Shape; -+ -+ // Shared memory size covering all stages -+ using SmemShapeA = typename MmaThreadBlock::Base::SharedStorage::ShapeA; -+ using SmemPaddingA = typename MmaThreadBlock::Policy::SmemPaddingA; -+ using SmemShapeB = typename MmaThreadBlock::Base::SharedStorage::ShapeB; -+ using SmemPaddingB = typename MmaThreadBlock::Policy::SmemPaddingB; -+ -+ -+ /// Number of stages -+ static int const kStages = MmaThreadBlock::Base::kStages; -+ -+ /// Number of warp-level GEMM oeprations -+ static int const kWarpGemmIterations = MmaThreadBlock::kWarpGemmIterations; -+ -+ -+ // -+ // Warp-level gemm types -+ // -+ -+ // Warp-level gemm operator -+ using MmaWarp = typename MmaThreadBlock::Operator; -+ -+ // Warp-level gemm shape covering all kgroups -+ using WarpShape = typename MmaWarp::Shape; -+ -+ // Warp-level framents holding operands A & B operand and destination C -+ using WarpFragmentA = typename MmaWarp::FragmentA; -+ using WarpFragmentB = typename MmaWarp::FragmentB; -+ using WarpFragmentC = typename MmaWarp::FragmentC; -+ -+ // -+ // Instruction-level gemm types -+ // -+ -+ // Instruction-level gemm operator -+ using MmaInstruction = typename MmaWarp::Policy::Operator; -+ -+ // Instruction shape -+ using InstructionShape = typename MmaInstruction::Shape; -+ -+ // Instruction-level framents holding operands A & B operand and destination C -+ using InstructionFragmentA = typename MmaInstruction::FragmentA; -+ using InstructionFragmentB = typename MmaInstruction::FragmentB; -+ using InstructionFragmentC = typename MmaInstruction::FragmentC; -+ -+ // -+ // Testbed types -+ // -+ -+ // Vector of values holding user provided reference -+ using ReferenceVector = std::vector; -+ -+ // -+ // Data members -+ // -+ ReferenceVector references; -+ -+ // -+ // Methods -+ // -+ -+ TestbedSanity(ReferenceVector const &references_ = ReferenceVector()) : references(references_){ } -+ -+ // verify all parameter in ReferenceVector -+ bool verify() { -+ for(auto ref : references) -+ verify_parameter(ref); -+ return true; -+ } -+ -+ // verify parameter of type Reference -+ void verify_parameter(Reference const& ref) { -+ switch(ref.parameter_id) { -+ case ParameterID::kWarpFragmentASize : EXPECT_TRUE(WarpFragmentA::kElements == ref.value) << *this; break; -+ case ParameterID::kWarpFragmentBSize : EXPECT_TRUE(WarpFragmentB::kElements == ref.value) << *this; break; -+ case ParameterID::kWarpFragmentCSize : EXPECT_TRUE(WarpFragmentC::kElements == ref.value) << *this; break; -+ } -+ } -+ -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -+// Overload output operators for TesbedSanity -+/////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -+template -+std::ostream & operator<<(std::ostream &out, TestbedSanity const &test) { -+ -+ -+ out << "Gemm internal parameters" << std::endl -+ << " Threadblock-level parameters:" << std::endl -+ << " ThreadblockShape = " << typename TestbedSanity::ThreadblockShape() << std::endl -+ << " kStages = " << TestbedSanity::kStages << std::endl -+ << " kWarpGemmIterations = "<< TestbedSanity::kWarpGemmIterations << std::endl -+ <<" Shared memory sizes:" << std::endl -+ <<" SmemPaddingA = " << typename TestbedSanity::SmemPaddingA() << std::endl -+ <<" SmemPaddingB = " << typename TestbedSanity::SmemPaddingB() << std::endl -+ <<" SmemShapeA = " << typename TestbedSanity::SmemShapeA() << std::endl -+ <<" SmemShapeB = " << typename TestbedSanity::SmemShapeB() << std::endl -+ <<" Warp-level parameters" << std::endl -+ <<" WarpShape = " << typename TestbedSanity::WarpShape() << std::endl -+ <<" Fragment sizes:" << std::endl -+ <<" WarpFragmentA::kElements = " << TestbedSanity::WarpFragmentA::kElements << std::endl -+ <<" WarpFragmentB::kElements = " << TestbedSanity::WarpFragmentB::kElements << std::endl -+ <<" WarpFragmentC::kElements = " << TestbedSanity::WarpFragmentC::kElements << std::endl -+ <<" Instruction-level parameters" << std::endl -+ <<" InstructionShape = " << typename TestbedSanity::InstructionShape() << std::endl -+ <<" Fragment sizes:" << std::endl -+ <<" InstructionFragmentA::kElements = " << TestbedSanity::InstructionFragmentA::kElements << std::endl -+ <<" InstructionFragmentB::kElements = " << TestbedSanity::InstructionFragmentB::kElements << std::endl -+ <<" InstructionFragmentC::kElements = " << TestbedSanity::InstructionFragmentC::kElements << std::endl; -+ -+ return out; -+} -+ -+} // namespace device -+} // namespace gemm -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/testbed_sparse.h b/3rdparty/cutlass/test/unit/gemm/device/testbed_sparse.h -new file mode 100644 -index 0000000..56f3e5e ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/testbed_sparse.h -@@ -0,0 +1,488 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+ -+ Testbed for sparse operations not to be released for CUDA 11.0 GA. Expected release is 11.1. -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/host_reorder.h" -+#include "cutlass/util/host_uncompress.h" -+ -+#include "testbed_utils.h" -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct SparseTestbed { -+ -+ using ElementA = typename Gemm::ElementA; -+ using ElementB = typename Gemm::ElementB; -+ using ElementC = typename Gemm::ElementC; -+ using ElementAccumulator = typename Gemm::ElementAccumulator; -+ using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; -+ -+ static int const kSparse = Gemm::GemmKernel::kSparse; -+ static int const kMetaSizeInBits = Gemm::GemmKernel::kMetaSizeInBits; -+ static int const kMaxID2 = Gemm::GemmKernel::kMaxID2; -+ static int const kElementsPerElementE = Gemm::GemmKernel::kElementsPerElementE; -+ -+ using ElementE = typename Gemm::GemmKernel::ElementE; -+ using LayoutE = cutlass::layout::RowMajor; -+ using ReorderedLayoutE = typename Gemm::GemmKernel::LayoutE; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ cutlass::Distribution::Kind init_E; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_A_uncompressed; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D; -+ cutlass::HostTensor reference_D; -+ cutlass::HostTensor tensor_E; -+ cutlass::HostTensor tensor_E_reordered; -+ -+ // -+ // Methods -+ // -+ -+ SparseTestbed( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_E_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080) -+ : init_A(init_A_), -+ init_B(init_B_), -+ init_C(init_C_), -+ init_E(init_E_), -+ seed(seed_) {} -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ double scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope_max, scope_min, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else { -+ // TODO: Implement the rest -+ EXPECT_TRUE(false) << "Not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Initializes data structures -+ void initialize(cutlass::gemm::GemmCoord problem_size) { -+ // -+ // Allocate the GEMM workspace -+ // -+ tensor_A.resize(cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse)); -+ tensor_A_uncompressed.resize(problem_size.mk()); -+ tensor_B.resize(problem_size.kn()); -+ tensor_C.resize(problem_size.mn()); -+ tensor_D.resize(problem_size.mn()); -+ reference_D.resize(problem_size.mn(), false); -+ tensor_E.resize(cutlass::make_Coord( -+ problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE)); -+ tensor_E_reordered.resize(cutlass::make_Coord( -+ problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE)); -+ -+ EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); -+ EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); -+ EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); -+ -+ if (init_E == cutlass::Distribution::Uniform) { -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomSparseMeta( -+ tensor_E.host_view(), seed, kMetaSizeInBits); -+ } else if (init_E == cutlass::Distribution::Identity) { -+ uint32_t content = (kMaxID2 == 1) ? 0x44444444 : 0x4444; -+ cutlass::reference::host::TensorFill(tensor_E.host_view(), -+ (ElementE)(content)); -+ } else { -+ // TODO: Implement the rest -+ EXPECT_TRUE(false); -+ } -+ -+ cutlass::reorder_meta(tensor_E_reordered.host_ref(), tensor_E.host_ref(), -+ {problem_size.m(), problem_size.n(), -+ problem_size.k() / kSparse / kElementsPerElementE}); -+ -+ // It is possible to randomly initialize to all zeros, so override this with non-zeros -+ // in the upper left corner of each operand. -+ tensor_A.host_view().at({0, 0}) = typename Gemm::ElementA(1); -+ tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1); -+ tensor_C.host_view().at({0, 0}) = typename Gemm::ElementC(1); -+ -+ cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D.sync_device(); -+ tensor_E_reordered.sync_device(); -+ } -+ -+ /// Compares computed reference with device reference and outputs to a file if incorrect -+ bool compare_reference( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ tensor_D.sync_host(); -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); -+ -+ if (tensor_D.size() > 1) -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); -+ -+ if (reference_D.size() > 1) -+ EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); -+ -+ bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); -+ -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ -+ std::stringstream fname; -+ -+ fname << "error_Gemm_device_" -+ << problem_size.m() << "x" -+ << problem_size.n() << "x" -+ << problem_size.k() << "_" -+ << Gemm::ThreadblockShape::kM << "x" -+ << Gemm::ThreadblockShape::kN << "x" -+ << Gemm::ThreadblockShape::kK << "_" -+ << Gemm::WarpShape::kM << "x" -+ << Gemm::WarpShape::kN << "x" -+ << Gemm::WarpShape::kK << ".txt"; -+ -+ std::ofstream file(fname.str()); -+ -+ file -+ << "problem: " << problem_size -+ << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; -+ -+ file -+ << "A =\n" << tensor_A.host_view() -+ << "\nB =\n" << tensor_B.host_view() -+ << "\nC =\n" << tensor_C.host_view() -+ << "\nE =\n" << tensor_E.host_view() -+ << "\n\nReference =\n" << reference_D.host_view() -+ << "\nComputed =\n" << tensor_D.host_view(); -+ } -+ -+ return passed; -+ } -+ -+ /// Verifies the result is a GEMM -+ bool verify( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ // -+ // Verify -+ // -+ -+ cutlass::uncompress(tensor_A_uncompressed.host_ref(), tensor_A.host_ref(), -+ tensor_E.host_ref(), problem_size.m(), problem_size.k()); -+ -+ cutlass::reference::host::Gemm< -+ typename Gemm::ElementA, typename Gemm::LayoutA, -+ typename Gemm::ElementB, typename Gemm::LayoutB, -+ typename Gemm::ElementC, typename Gemm::LayoutC, -+ ElementCompute, -+ ElementAccumulator, typename Gemm::Operator> -+ reference_gemm; -+ -+ reference_gemm( -+ problem_size, -+ alpha, -+ tensor_A_uncompressed.host_ref(), -+ tensor_B.host_ref(), -+ beta, -+ reference_D.host_ref(), -+ ElementAccumulator(0) -+ ); -+ -+ return compare_reference(problem_size, alpha, beta); -+ } -+ -+ /// Returns true if the CUDA device is sufficient to execute the kernel. -+ bool sufficient() const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename Gemm::GemmKernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::gemm::GemmCoord problem_size, -+ int split_k_slices = 1, -+ ElementCompute alpha = ElementCompute(1), -+ ElementCompute beta = ElementCompute(0)) { -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ -+ this->initialize(problem_size); -+ -+ // -+ // Initialize the GEMM operator -+ // -+ -+ typename Gemm::Arguments arguments{ -+ problem_size, -+ tensor_A.device_ref(), -+ tensor_B.device_ref(), -+ tensor_C.device_ref(), -+ tensor_D.device_ref(), -+ tensor_E_reordered.device_ref(), -+ {alpha, beta}, -+ split_k_slices -+ }; -+ -+ Gemm gemm_op; -+ -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); -+ -+ // This failure is likely due to insufficient device capabilities. Waive the test. -+ if (status != cutlass::Status::kSuccess) { -+ return true; -+ } -+ -+ // -+ // Run the GEMM -+ // -+ -+ status = gemm_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Verify -+ // -+ -+ bool passed = this->verify(problem_size, alpha, beta); -+ -+ if (!passed) { -+ std::cout << "Error with split_k_slices = " << split_k_slices << ", alpha: " << alpha << std::endl; -+ } -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+bool TestAllSparseGemm() { -+ bool passed = true; -+ -+ int const kMinimumOperandElementSize = -+ std::min( -+ int(cutlass::sizeof_bits::value), -+ int(cutlass::sizeof_bits::value)); -+ -+ // M dimension has to be multiple of 32 (sparse float) or 16 (sparse int) -+ // because of the reordering of operand E -+ int const kAlignmentM = std::max(((sizeof(typename Gemm::ElementE) == 2) ? 32 : 16), -+ kMinimumOperandElementSize); -+ -+ int const kAlignmentN = 128 / kMinimumOperandElementSize; -+ -+ int problem_size_m[] = {kAlignmentM, 512 - 3 * kAlignmentM}; -+ -+ int problem_size_n[] = {kAlignmentN, 512 - 2 * kAlignmentN}; -+ -+ int problem_size_k[] = {Gemm::ThreadblockShape::kK, -+ Gemm::ThreadblockShape::kK * (Gemm::kStages + 1)}; -+ -+ int split_k_slices[] = { -+ 1, 2, 3 -+ }; -+ -+ double problem_alpha[] = { -+ 1 -+ }; -+ -+ double problem_beta[] = { -+ 2.0 -+ }; -+ -+ SparseTestbed testbed; -+ -+ using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; -+ -+ for (int m : problem_size_m) { -+ for (int n : problem_size_n) { -+ for (int k : problem_size_k) { -+ for (int split_k : split_k_slices) { -+ -+ if (!Gemm::kSplitKSerial && split_k > 1) { -+ continue; -+ } -+ -+ if (split_k > 1 && k / Gemm::ThreadblockShape::kK < split_k) { -+ continue; -+ } -+ -+ for (auto alpha : problem_alpha) { -+ for (auto beta : problem_beta) { -+ -+ cutlass::gemm::GemmCoord problem_size(m, n, k); -+ -+ passed = testbed.run( -+ problem_size, -+ split_k, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta) -+ ); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ return passed; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/testbed_splitk.h b/3rdparty/cutlass/test/unit/gemm/device/testbed_splitk.h -new file mode 100644 -index 0000000..73dda7e ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/testbed_splitk.h -@@ -0,0 +1,218 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#pragma once -+ -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "testbed.h" -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct TestbedSplitK : public Testbed { -+ -+ using Base = Testbed; -+ -+ using ElementCompute = typename Base::ElementCompute; -+ -+ // -+ // Methods -+ // -+ -+ TestbedSplitK( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ Base(init_A_, init_B_, init_C_, seed_) { } -+ -+ /// Returns true if the CUDA device is sufficient to execute the kernel. -+ bool sufficient() const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename Gemm::GemmKernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::gemm::GemmCoord problem_size, -+ int split_k_slices, -+ ElementCompute alpha = ElementCompute(1), -+ ElementCompute beta = ElementCompute(0)) { -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ -+ this->initialize(problem_size); -+ -+ // -+ // Initialize the GEMM operator -+ // -+ -+ typename Gemm::Arguments arguments{ -+ problem_size, -+ this->tensor_A.device_ref(), -+ this->tensor_B.device_ref(), -+ this->tensor_C.device_ref(), -+ this->tensor_D.device_ref(), -+ {alpha, beta}, -+ split_k_slices -+ }; -+ -+ Gemm gemm_op; -+ -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ -+ // -+ // Run the GEMM -+ // -+ -+ status = gemm_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess); -+ -+ // -+ // Verify -+ // -+ -+ return this->verify(problem_size, alpha, beta); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+bool TestAllGemmSplitK() { -+ bool passed = true; -+ -+ cutlass::gemm::GemmCoord problem_sizes[] = { -+ {8, 8, 2048}, -+ {8, 8, 2056}, -+ {264, 72, 520}, -+ {264, 520, 120}, -+ {264, 520, 264} -+ }; -+ -+ int split_k_slices[] = { -+ 1, 2, 4, 5, 7 -+ }; -+ -+ double problem_alpha[] = { -+ 0.5 -+ }; -+ -+ double problem_beta[] = { -+ 2.0 -+ }; -+ -+ using Testbed = TestbedSplitK; -+ using ElementCompute = typename Testbed::ElementCompute; -+ -+ Testbed testbed; -+ -+ for (auto problem_size : problem_sizes) { -+ for (int split_k_count : split_k_slices) { -+ for (double alpha : problem_alpha) { -+ for (double beta : problem_beta) { -+ -+ passed = testbed.run( -+ problem_size, -+ split_k_count, -+ ElementCompute(alpha), -+ ElementCompute(beta) -+ ); -+ -+ if (!passed) { -+ std::cout << "Failed on size " << problem_size << " with split_k_count " << split_k_count << std::endl; -+ return false; -+ } -+ } -+ } -+ } -+ } -+ -+ EXPECT_TRUE(passed); -+ -+ return passed; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/testbed_symm_universal.h b/3rdparty/cutlass/test/unit/gemm/device/testbed_symm_universal.h -new file mode 100644 -index 0000000..1050a2e ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/testbed_symm_universal.h -@@ -0,0 +1,592 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide Symm update interface -+ -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/error_metrics.h" -+#include "cutlass/util/reference/host/symm.h" -+#include "cutlass/util/reference/host/symm_complex.h" -+ -+#include "testbed_utils.h" -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct TestbedSymmUniversal { -+ -+ using ElementA = typename Symm::ElementA; -+ using ElementB = typename Symm::ElementB; -+ using ElementC = typename Symm::ElementC; -+ using ElementAccumulator = typename Symm::ElementAccumulator; -+ using ElementCompute = typename Symm::SymmKernel::Epilogue::OutputOp::ElementCompute; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D; -+ cutlass::HostTensor reference_D; -+ -+ // -+ // Methods -+ // -+ -+ TestbedSymmUniversal( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed, -+ int mantissa_in_bits) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ double scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope_max, scope_min, mantissa_in_bits); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5, mantissa_in_bits); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else { -+ -+ EXPECT_TRUE(false) << "Input distribution not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_symmetric_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed, -+ int mantissa_in_bits) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ double scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::host::TensorFillSymmetricRandomUniform( -+ view, seed, Symm::kFillModeA, scope_max, scope_min, mantissa_in_bits); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillSymmetricRandomGaussian( -+ view, seed, Symm::kFillModeA, 0, 0.5, mantissa_in_bits); -+ } -+ else { -+ -+ EXPECT_TRUE(false) << "Input distribution (symmetric tensor) not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ /// Initializes data structures -+ void initialize(cutlass::gemm::GemmCoord problem_size) { -+ // -+ // Allocate the Symm workspace -+ // -+ -+ if (Symm::kSideModeA == cutlass::SideMode::kLeft) { -+ tensor_A.resize(cutlass::make_Coord(problem_size.m(),problem_size.m())); -+ } -+ else if (Symm::kSideModeA == cutlass::SideMode::kRight) { -+ tensor_A.resize(cutlass::make_Coord(problem_size.n(),problem_size.n())); -+ } -+ -+ tensor_B.resize(problem_size.mn()); -+ tensor_C.resize(problem_size.mn()); -+ tensor_D.resize(problem_size.mn()); -+ reference_D.resize(problem_size.mn(), false); -+ -+ EXPECT_TRUE(initialize_symmetric_tensor(tensor_A.host_view(), init_A, seed + 2019, cutlass::MantissaInBits::bits)); -+ EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018, cutlass::MantissaInBits::bits)); -+ EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017, cutlass::MantissaInBits::bits)); -+ -+ // It is possible to randomly initialize to all zeros, so override this with non-zeros -+ // in the upper left corner of each operand. -+ tensor_A.host_view().at({0, 0}) = typename Symm::ElementA(1); -+ tensor_B.host_view().at({0, 0}) = typename Symm::ElementB(1); -+ tensor_C.host_view().at({0, 0}) = typename Symm::ElementC(1); -+ -+ cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D.sync_device(); -+ } -+ -+ /// Compares computed reference with device reference and outputs to a file if incorrect -+ bool compare_reference( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ tensor_D.sync_host(); -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); -+ -+ if (tensor_D.size() > 1) -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); -+ -+ if (reference_D.size() > 1) -+ EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); -+ -+ double l2_norm = cutlass::reference::host::TensorRelativeErrorMetric(reference_D.host_view(), tensor_D.host_view()); -+ -+ bool passed = l2_norm < cutlass::MantissaInBits::error; -+ -+ return passed; -+ } -+ -+ /// Verifies the result is a Symm -+ bool verify( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ // -+ // Verify -+ // -+ -+ using HostReference = typename cutlass::platform::conditional< -+ (cutlass::platform::is_same -+ >::value || -+ cutlass::platform::is_same -+ >::value -+ ), -+ cutlass::reference::host::SymmComplex< -+ typename Symm::ElementA, typename Symm::LayoutA, -+ Symm::kSideModeA, Symm::kFillModeA, -+ typename Symm::ElementB, typename Symm::LayoutB, -+ typename Symm::ElementC, typename Symm::LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ Symm::kBlasMode>, -+ cutlass::reference::host::Symm< -+ typename Symm::ElementA, typename Symm::LayoutA, -+ Symm::kSideModeA, Symm::kFillModeA, -+ typename Symm::ElementB, typename Symm::LayoutB, -+ typename Symm::ElementC, typename Symm::LayoutC, -+ ElementCompute, -+ ElementAccumulator> -+ >::type; -+ -+ -+ HostReference reference_symm; -+ -+ reference_symm( -+ problem_size, -+ alpha, -+ tensor_A.host_ref(), -+ tensor_B.host_ref(), -+ beta, -+ tensor_C.host_ref(), -+ reference_D.host_ref(), -+ ElementAccumulator(0) -+ ); -+ -+ return compare_reference(problem_size, alpha, beta); -+ } -+ -+ /// Returns true if the CUDA device is sufficient to execute the kernel. -+ bool sufficient() const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename Symm::SymmKernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::gemm::GemmUniversalMode mode, -+ cutlass::gemm::GemmCoord problem_size, -+ int batch_count = 1, -+ ElementCompute alpha = ElementCompute(1), -+ ElementCompute beta = ElementCompute(0)) { -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ -+#if 0 -+ std::cout << "[TestbedSymmUniversal::run()] problem(m, n, k): " << problem_size -+ << " alpha: " << ElementCompute(alpha) -+ << " beta: " << ElementCompute(beta) << std::endl; -+#endif -+ -+ this->initialize(problem_size); -+ -+ // -+ // Initialize the Symm operator -+ // -+ -+ int batch_stride_A; -+ if (Symm::kSideModeA == cutlass::SideMode::kLeft) -+ batch_stride_A = problem_size.m()*problem_size.m(); -+ if (Symm::kSideModeA == cutlass::SideMode::kRight) -+ batch_stride_A = problem_size.n()*problem_size.n(); -+ -+ typename Symm::Arguments arguments{ -+ mode, -+ problem_size, -+ batch_count, -+ {alpha, beta}, -+ tensor_A.device_data(), -+ tensor_B.device_data(), -+ tensor_C.device_data(), -+ tensor_D.device_data(), -+ batch_stride_A, -+ problem_size.m() * problem_size.n(), -+ problem_size.m() * problem_size.n(), -+ problem_size.m() * problem_size.n(), -+ tensor_A.layout().stride(0), -+ tensor_B.layout().stride(0), -+ tensor_C.layout().stride(0), -+ tensor_D.layout().stride(0) -+ }; -+ -+ Symm symm_op; -+ -+ size_t workspace_size = Symm::get_workspace_size(arguments); -+ -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = symm_op.initialize(arguments, workspace.get()); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Run the Symm -+ // -+ -+ status = symm_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Verify -+ // -+ -+ bool passed = this->verify(problem_size, alpha, beta); -+ -+ //if (true) { -+ if (!passed) { -+ std::stringstream fname; -+ -+ fname << "error_" -+ << (Symm::kBlasMode == cutlass::BlasMode::kSymmetric ? "symm_" : "hemm_" ) -+ << "device_" -+ << "fill_mode_a_" -+ << (Symm::kSideModeA == cutlass::SideMode::kLeft ? "leftside_" : -+ (Symm::kSideModeA == cutlass::SideMode::kRight ? "rightside_" : "invalid_")) -+ << (Symm::kFillModeA == cutlass::FillMode::kLower ? "lower_" : -+ (Symm::kFillModeA == cutlass::FillMode::kUpper ? "upper_" : "invalid_")) -+ << "mnk_" -+ << problem_size.m() << "x" -+ << problem_size.n() << "x" -+ << problem_size.k() << "_" -+ << Symm::ThreadblockShape::kM << "x" -+ << Symm::ThreadblockShape::kN << "x" -+ << Symm::ThreadblockShape::kK << "_" -+ << Symm::WarpShape::kM << "x" -+ << Symm::WarpShape::kN << "x" -+ << Symm::WarpShape::kK << ".txt"; -+ -+ std::cout << fname.str() << std::endl; -+ -+ std::ofstream results(fname.str()); -+ -+ results << problem_size << std::endl; -+ -+ results -+ << "alpha: " << ElementCompute(alpha) << "\n" -+ << "beta: " << ElementCompute(beta) << "\n" -+ << "\nA:\n" << tensor_A.host_view() << "\n" -+ << "\nB:\n" << tensor_B.host_view() << "\n" -+ << "\nC:\n" << tensor_C.host_view() << "\n" -+ << "\nD reference:\n" << reference_D.host_view() << "\n" -+ << "\nD computed:\n" << tensor_D.host_view() << "\n"; -+ -+ } -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template -+bool TestsymmUniversal( -+ cutlass::gemm::GemmCoord const & problem_size, -+ cutlass::gemm::GemmUniversalMode mode, -+ int batch_count, -+ double alpha = 1.0, -+ double beta = 2.0) { -+ -+ bool passed = true; -+ -+ TestbedSymmUniversal testbed; -+ -+ using ElementCompute = typename Symm::EpilogueOutputOp::ElementCompute; -+ -+ passed = testbed.run( -+ mode, -+ problem_size, -+ batch_count, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta) -+ ); -+ -+ return passed; -+} -+ -+template -+bool TestAllSymmUniversal() { -+ bool passed = true; -+ -+ -+ int const kMinimumOperandElementSize = int(cutlass::sizeof_bits::value); -+ -+ int const kAlignment = cutlass::platform::is_same< -+ typename Symm::OperatorClass, -+ cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; -+ -+ // int8_t gemm alignment constraints -+ int const kAlignmentM = cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value ? 4 : kAlignment; -+ -+ int const kAlignmentN = kAlignmentM; -+ -+ int const kAlignmentK = cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value -+ ? 4 : kAlignment; -+ -+ cutlass::gemm::GemmUniversalMode modes[] = { -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ }; -+ -+ int problem_size_m[] = { -+ kAlignmentK, -+ Symm::ThreadblockShape::kK * Symm::kStages - kAlignmentK, -+ Symm::ThreadblockShape::kK * Symm::kStages * 3 - kAlignmentK -+ }; -+ -+ int problem_size_n[] = { -+ kAlignmentN, 512 - 2*kAlignmentN -+ }; -+ -+ int batch_counts[] = { // may be interpretted as batch count or split-K slices -+ 1 // Just running one batch for now (removing 2, 3, 5, 7) -+ }; -+ -+ double problem_alpha[] = { -+ 1.0, 3.0 -+ }; -+ -+ double problem_beta[] = { -+ 0, 2.0 -+ }; -+ -+ -+ using ElementCompute = typename Symm::EpilogueOutputOp::ElementCompute; -+ -+ for (cutlass::gemm::GemmUniversalMode mode : modes) { -+ for (int m : problem_size_m) { -+ for (int n : problem_size_n) { -+ for (int batch_count : batch_counts) { -+ -+ for (auto alpha : problem_alpha) { -+ for (auto beta : problem_beta) { -+ -+ int k = 0; -+ if (Symm::kSideModeA == cutlass::SideMode::kLeft) -+ k = m; -+ else if (Symm::kSideModeA == cutlass::SideMode::kRight) -+ k = n; -+ -+ if (mode == cutlass::gemm::GemmUniversalMode::kGemm || -+ mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { -+ -+ #if 0 -+ // skip very small K problems -+ if (k / batch_count < 2 * Symm::ThreadblockShape::kK) { -+ continue; -+ } -+ #endif -+ } -+ -+ cutlass::gemm::GemmCoord problem_size(m, n, k); -+ -+ TestbedSymmUniversal testbed; -+ -+ passed = testbed.run( -+ mode, -+ problem_size, -+ batch_count, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta) -+ ); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ return passed; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/testbed_trmm_universal.h b/3rdparty/cutlass/test/unit/gemm/device/testbed_trmm_universal.h -new file mode 100644 -index 0000000..db40eff ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/testbed_trmm_universal.h -@@ -0,0 +1,609 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/error_metrics.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/trmm_complex.h" -+#include "cutlass/core_io.h" -+ -+#include "testbed_utils.h" -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct TestbedTrmmUniversal { -+ -+ using ElementA = typename Trmm::ElementA; -+ using ElementB = typename Trmm::ElementB; -+ using ElementC = typename Trmm::ElementC; -+ using ElementAccumulator = typename Trmm::ElementAccumulator; -+ using ElementCompute = typename Trmm::TrmmKernel::Epilogue::OutputOp::ElementCompute; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_D; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_D; -+ cutlass::HostTensor reference_D; -+ -+ // -+ // Methods -+ // -+ -+ TestbedTrmmUniversal( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_D_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_D(init_D_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed, -+ int mantissa_in_bits) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ double scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope_max, scope_min, mantissa_in_bits); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5, mantissa_in_bits); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else { -+ // TODO: Implement the rest -+ EXPECT_TRUE(false) << "Not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_symmetric_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed, -+ int mantissa_in_bits) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ double scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::host::TensorFillSymmetricRandomUniform( -+ view, seed, Trmm::kFillMode, scope_max, scope_min, mantissa_in_bits); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillSymmetricRandomGaussian( -+ view, seed, Trmm::kFillMode, 0, 0.5, mantissa_in_bits); -+ } -+ else { -+ // TODO: Implement the rest -+ EXPECT_TRUE(false) << "Not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Helper to initialize a tensor view (pad diagonal fill with zeros for up to alignment on wrong side of diagonal) -+ template -+ bool initialize_pad_diagonal_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed, -+ int alignment) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ double scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::host::TensorFillPadDiagonalRandomUniform( -+ view, seed, Trmm::kFillMode, scope_max, scope_min, 0, alignment); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ EXPECT_TRUE(false) << "Gaussian distribution for pad diagonal not implemented"; -+ } -+ else { -+ // TODO: Implement the rest -+ EXPECT_TRUE(false) << "Not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Initializes data structures -+ void initialize(cutlass::gemm::GemmCoord problem_size) { -+ // -+ // Allocate the TRMM workspace -+ // -+ -+ if (Trmm::kSideMode == cutlass::SideMode::kLeft) { -+ tensor_A.resize(cutlass::make_Coord(problem_size.m(),problem_size.m())); -+ } -+ else if (Trmm::kSideMode == cutlass::SideMode::kRight) { -+ tensor_A.resize(cutlass::make_Coord(problem_size.n(),problem_size.n())); -+ } -+ -+ tensor_B.resize(problem_size.mn()); -+ tensor_D.resize(problem_size.mn()); -+ reference_D.resize(problem_size.mn(), false); -+ -+ //EXPECT_TRUE(initialize_symmetric_tensor(tensor_A.host_view(), init_A, seed + 2017)); -+ //EXPECT_TRUE(initialize_pad_diagonal_tensor(tensor_A.host_view(), init_A, seed + 2017, Trmm::kAlignmentA)); -+ EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2017, cutlass::MantissaInBits::bits)); -+ EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2019, cutlass::MantissaInBits::bits)); -+ -+ // It is possible to randomly initialize to all zeros, so override this with non-zeros -+ // in the upper left corner of each operand. -+ tensor_A.host_view().at({0, 0}) = typename Trmm::ElementA(1); -+ tensor_B.host_view().at({0, 0}) = typename Trmm::ElementB(1); -+ -+ cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_D.host_view()); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_D.sync_device(); -+ } -+ -+ /// Compares computed reference with device reference and outputs to a file if incorrect -+ bool compare_reference( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha) { -+ -+ tensor_D.sync_host(); -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); -+ -+ if (tensor_D.size() > 1) -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); -+ -+ if (reference_D.size() > 1) -+ EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); -+ -+ double l2_norm = cutlass::reference::host::TensorRelativeErrorMetric(reference_D.host_view(), tensor_D.host_view()); -+ -+ bool passed = l2_norm < cutlass::MantissaInBits::error; -+ -+ return passed; -+ } -+ -+ /// Verifies the result is a TRMM -+ bool verify( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha) { -+ -+ // -+ // Verify -+ // -+ -+ using HostReference = typename cutlass::platform::conditional< -+ (cutlass::platform::is_same -+ >::value || -+ cutlass::platform::is_same -+ >::value -+ ), -+ cutlass::reference::host::TrmmComplex< -+ typename Trmm::ElementA, typename Trmm::LayoutA, -+ Trmm::kTransformA, -+ Trmm::kSideMode, Trmm::kFillMode, Trmm::kDiagType, -+ typename Trmm::ElementB, typename Trmm::LayoutB, -+ Trmm::kTransformB, -+ typename Trmm::ElementC, typename Trmm::LayoutC, -+ ElementCompute, -+ ElementAccumulator>, -+ cutlass::reference::host::Trmm< -+ typename Trmm::ElementA, typename Trmm::LayoutA, -+ Trmm::kSideMode, Trmm::kFillMode, Trmm::kDiagType, -+ typename Trmm::ElementB, typename Trmm::LayoutB, -+ typename Trmm::ElementC, typename Trmm::LayoutC, -+ ElementCompute, -+ ElementAccumulator> -+ >::type; -+ -+ -+ HostReference reference_trmm; -+ -+ reference_trmm( -+ problem_size, -+ alpha, -+ tensor_A.host_ref(), -+ tensor_B.host_ref(), -+ reference_D.host_ref(), -+ ElementAccumulator(0) -+ ); -+ -+ return compare_reference(problem_size, alpha); -+ } -+ -+ /// Returns true if the CUDA device is sufficient to execute the kernel. -+ bool sufficient() const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename Trmm::TrmmKernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::gemm::GemmUniversalMode mode, -+ cutlass::gemm::GemmCoord problem_size, -+ int batch_count = 1, -+ ElementCompute alpha = ElementCompute(1)) { -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ -+#if 0 -+ std::cout << "[TestbedTrmmUniversal::run()] problem(m, n, k): " << problem_size -+ << " alpha: " << ElementCompute(alpha) << std::endl; -+#endif -+ -+ this->initialize(problem_size); -+ -+ // -+ // Initialize the TRMM operator -+ // -+ -+ int batch_stride_A; -+ if (Trmm::kSideMode == cutlass::SideMode::kLeft) -+ batch_stride_A = problem_size.m()*problem_size.m(); -+ if (Trmm::kSideMode == cutlass::SideMode::kRight) -+ batch_stride_A = problem_size.n()*problem_size.n(); -+ -+ typename Trmm::Arguments arguments{ -+ mode, -+ problem_size, -+ batch_count, -+ {alpha}, -+ tensor_A.device_data(), -+ tensor_B.device_data(), -+ tensor_D.device_data(), -+ batch_stride_A, -+ problem_size.m() * problem_size.n(), -+ problem_size.m() * problem_size.n(), -+ tensor_A.layout().stride(0), -+ tensor_B.layout().stride(0), -+ tensor_D.layout().stride(0) -+ }; -+ -+ Trmm trmm_op; -+ -+ size_t workspace_size = Trmm::get_workspace_size(arguments); -+ -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = trmm_op.initialize(arguments, workspace.get()); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Run the TRMM -+ // -+ -+ status = trmm_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Verify -+ // -+ bool passed = this->verify(problem_size, alpha); -+ -+ if (!passed) { -+ std::stringstream fname; -+ -+ fname << "error_Trmm_device_" -+ << "fill_mode_" -+ << (Trmm::kFillMode == cutlass::FillMode::kLower ? "lower_" : -+ (Trmm::kFillMode == cutlass::FillMode::kUpper ? "upper_" : "invalid_")) -+ << "side_mode_" -+ << (Trmm::kSideMode == cutlass::SideMode::kLeft ? "left_" : -+ (Trmm::kSideMode == cutlass::SideMode::kRight ? "right_" : "invalid_")) -+ << "mnk_" -+ << problem_size.m() << "x" -+ << problem_size.n() << "x" -+ << problem_size.k() << "_" -+ << Trmm::ThreadblockShape::kM << "x" -+ << Trmm::ThreadblockShape::kN << "x" -+ << Trmm::ThreadblockShape::kK << "_" -+ << Trmm::WarpShape::kM << "x" -+ << Trmm::WarpShape::kN << "x" -+ << Trmm::WarpShape::kK << ".txt"; -+ -+ std::cout << fname.str() << std::endl; -+ -+ std::ofstream results(fname.str()); -+ -+ results << problem_size << std::endl; -+ -+ results -+ << "\nA:\n" << tensor_A.host_view() << "\n" -+ << "\nB:\n" << tensor_B.host_view() << "\n" -+ << "\nD reference:\n" << reference_D.host_view() << "\n" -+ << "\nD computed:\n" << tensor_D.host_view() << "\n"; -+ } -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template -+bool TestTrmmUniversal( -+ cutlass::gemm::GemmCoord const & problem_size, -+ cutlass::gemm::GemmUniversalMode mode, -+ int batch_count, -+ double alpha = 1.0) { -+ -+ bool passed = true; -+ -+ TestbedTrmmUniversal testbed; -+ -+ using ElementCompute = typename Trmm::EpilogueOutputOp::ElementCompute; -+ -+ passed = testbed.run( -+ mode, -+ problem_size, -+ batch_count, -+ cutlass::from_real(alpha) -+ ); -+ -+ return passed; -+} -+ -+template -+bool TestAllTrmmUniversal() { -+ bool passed = true; -+ -+ int const kMinimumOperandElementSize = int(cutlass::sizeof_bits::value); -+ -+ int const kAlignment = cutlass::platform::is_same< -+ typename Trmm::OperatorClass, -+ cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; -+ -+ // int8_t gemm alignment constraints -+ int const kAlignmentM = cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value ? 4 : kAlignment; -+ -+ int const kAlignmentN = kAlignmentM; -+ -+ int const kAlignmentK = cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value -+ ? 4 : kAlignment; -+ -+ cutlass::gemm::GemmUniversalMode modes[] = { -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ }; -+ -+ int problem_size_m[] = { -+ kAlignmentK, -+ Trmm::ThreadblockShape::kK * Trmm::kStages - kAlignmentK, -+ Trmm::ThreadblockShape::kK * Trmm::kStages * 3 - kAlignmentK -+ }; -+ -+ int problem_size_n[] = { -+ kAlignmentN, 512 - 2*kAlignmentN -+ }; -+ -+ int batch_counts[] = { // may be interpretted as batch count or split-K slices -+ 1 // Just running one batch for now (removing 2, 3, 5, 7) -+ }; -+ -+ double problem_alpha[] = { -+ 1.0, 2.0 -+ }; -+ -+ using ElementCompute = typename Trmm::EpilogueOutputOp::ElementCompute; -+ -+ for (cutlass::gemm::GemmUniversalMode mode : modes) { -+ for (int m : problem_size_m) { -+ for (int n : problem_size_n) { -+ for (int batch_count : batch_counts) { -+ for (auto alpha : problem_alpha) { -+ -+ int k = 0; -+ if (Trmm::kSideMode == cutlass::SideMode::kLeft) -+ k = m; -+ else if (Trmm::kSideMode == cutlass::SideMode::kRight) -+ k = n; -+ -+ if (mode == cutlass::gemm::GemmUniversalMode::kGemm || -+ mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { -+ -+#if 0 -+ // skip very small K problems -+ if (k / batch_count < 2 * Trmm::ThreadblockShape::kK) { -+ continue; -+ } -+#endif -+ } -+ -+ cutlass::gemm::GemmCoord problem_size(m, n, k); -+ -+ TestbedTrmmUniversal testbed; -+ -+ passed = testbed.run( -+ mode, -+ problem_size, -+ batch_count, -+ cutlass::from_real(alpha) -+ ); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ return passed; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/device/testbed_universal.h b/3rdparty/cutlass/test/unit/gemm/device/testbed_universal.h -new file mode 100644 -index 0000000..615e9c5 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/testbed_universal.h -@@ -0,0 +1,547 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/gemm_complex.h" -+ -+#include "testbed_utils.h" -+ -+namespace test { -+namespace gemm { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct TestbedUniversal { -+ -+ using ElementA = typename Gemm::ElementA; -+ using ElementB = typename Gemm::ElementB; -+ using ElementC = typename Gemm::ElementC; -+ using ElementAccumulator = typename Gemm::ElementAccumulator; -+ using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; -+ -+ /// Initialization -+ cutlass::Distribution::Kind init_A; -+ cutlass::Distribution::Kind init_B; -+ cutlass::Distribution::Kind init_C; -+ uint64_t seed; -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D; -+ cutlass::HostTensor reference_D; -+ -+ // -+ // Methods -+ // -+ -+ TestbedUniversal( -+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, -+ uint64_t seed_ = 2080 -+ ): -+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor( -+ cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ -+ double scope_max, scope_min; -+ int bits_input = cutlass::sizeof_bits::value; -+ int bits_output = cutlass::sizeof_bits::value; -+ -+ if (bits_input == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } else if (bits_input <= 8) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (bits_output == 16) { -+ scope_max = 5; -+ scope_min = -5; -+ } else { -+ scope_max = 8; -+ scope_min = -8; -+ } -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ view, seed, scope_max, scope_min, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Identity) { -+ -+ cutlass::reference::host::TensorFillIdentity(view); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); -+ } -+ else if (dist_kind == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential( -+ view.data(), view.capacity()); -+ } -+ else { -+ // TODO: Implement the rest -+ EXPECT_TRUE(false) << "Not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Initializes data structures -+ void initialize(cutlass::gemm::GemmCoord problem_size) { -+ // -+ // Allocate the GEMM workspace -+ // -+ -+ tensor_A.resize(problem_size.mk()); -+ tensor_B.resize(problem_size.kn()); -+ tensor_C.resize(problem_size.mn()); -+ tensor_D.resize(problem_size.mn()); -+ reference_D.resize(problem_size.mn(), false); -+ -+ EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); -+ EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); -+ EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); -+ -+ // It is possible to randomly initialize to all zeros, so override this with non-zeros -+ // in the upper left corner of each operand. -+ cutlass::Coord<2> origin(0); -+ tensor_A.host_view().at(origin) = typename Gemm::ElementA(1); -+ tensor_B.host_view().at(origin) = typename Gemm::ElementB(1); -+ tensor_C.host_view().at(origin) = typename Gemm::ElementC(1); -+ -+ cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D.sync_device(); -+ } -+ -+ /// Compares computed reference with device reference and outputs to a file if incorrect -+ bool compare_reference( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ tensor_D.sync_host(); -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); -+ -+ bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); -+ -+ EXPECT_TRUE(passed) << " mismatched reference"; -+ -+ if (!passed) { -+ -+ /* -+ std::stringstream fname; -+ -+ fname << "error_Gemm_device_" -+ << problem_size.m() << "x" -+ << problem_size.n() << "x" -+ << problem_size.k() << "_" -+ << Gemm::ThreadblockShape::kM << "x" -+ << Gemm::ThreadblockShape::kN << "x" -+ << Gemm::ThreadblockShape::kK << "_" -+ << Gemm::WarpShape::kM << "x" -+ << Gemm::WarpShape::kN << "x" -+ << Gemm::WarpShape::kK << ".txt"; -+ -+ std::ofstream file(fname.str()); -+ */ -+ -+ std::ofstream file("testbed_universal_errors.txt"); -+ -+ file -+ << "problem: " << problem_size -+ << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; -+ -+ file -+ << "A =\n" << tensor_A.host_view() -+ << "\nB =\n" << tensor_B.host_view() -+ << "\nC =\n" << tensor_C.host_view() -+ << "\n\nReference =\n" << reference_D.host_view() -+ << "\nComputed =\n" << tensor_D.host_view(); -+ } -+ -+ return passed; -+ } -+ -+ /// Verifies the result is a GEMM -+ bool verify( -+ cutlass::gemm::GemmCoord problem_size, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ // -+ // Verify -+ // -+ -+ cutlass::reference::host::GemmComplex< -+ typename Gemm::ElementA, typename Gemm::LayoutA, -+ typename Gemm::ElementB, typename Gemm::LayoutB, -+ typename Gemm::ElementC, typename Gemm::LayoutC, -+ ElementCompute, ElementAccumulator -+ >( -+ problem_size, -+ alpha, -+ tensor_A.host_ref(), -+ Gemm::kTransformA, -+ tensor_B.host_ref(), -+ Gemm::kTransformB, -+ beta, -+ tensor_C.host_ref(), -+ reference_D.host_ref(), -+ ElementAccumulator(0) -+ ); -+ -+ if (Relu) { -+ for (int i = 0; i < problem_size.m(); ++i) { -+ for (int j = 0; j < problem_size.n(); ++j) { -+ reference_D.at(cutlass::MatrixCoord(i, j)) = -+ ((ElementCompute)reference_D.at(cutlass::MatrixCoord(i, j)) < (ElementCompute)0) -+ ? (typename Gemm::ElementC)0 -+ : reference_D.at(cutlass::MatrixCoord(i, j)); -+ } -+ } -+ } -+ -+ return compare_reference(problem_size, alpha, beta); -+ } -+ -+ /// Returns true if the CUDA device is sufficient to execute the kernel. -+ bool sufficient() const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ int smem_size = int(sizeof(typename Gemm::GemmKernel::SharedStorage)); -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.sharedMemPerBlockOptin < smem_size) { -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Executes one test -+ bool run( -+ cutlass::gemm::GemmUniversalMode mode, -+ cutlass::gemm::GemmCoord problem_size, -+ int batch_count = 1, -+ ElementCompute alpha = ElementCompute(1), -+ ElementCompute beta = ElementCompute(0)) -+ { -+/* -+ std::cout << "\n-----------------------\n"; -+ std::cout << "mode: " << (int) mode << "\n"; -+ std::cout << "problem size: " << problem_size << "\n"; -+ std::cout << "batch_count: " << batch_count << "\n"; -+ std::cout << "alpha: " << alpha << "\n"; -+ std::cout << "beta: " << beta << "\n"; -+ std::cout << "-----------------------\n\n"; -+*/ -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ -+ this->initialize(problem_size); -+ -+ // -+ // Initialize the GEMM operator -+ // -+ -+ typename Gemm::Arguments arguments{ -+ mode, -+ problem_size, -+ batch_count, -+ {alpha, beta}, -+ tensor_A.device_data(), -+ tensor_B.device_data(), -+ tensor_C.device_data(), -+ tensor_D.device_data(), -+ problem_size.m() * problem_size.k(), -+ problem_size.n() * problem_size.k(), -+ problem_size.m() * problem_size.n(), -+ problem_size.m() * problem_size.n(), -+ tensor_A.layout().stride(0), -+ tensor_B.layout().stride(0), -+ tensor_C.layout().stride(0), -+ tensor_D.layout().stride(0) -+ }; -+ -+ Gemm gemm_op; -+ -+ size_t workspace_size = Gemm::get_workspace_size(arguments); -+ -+ cutlass::device_memory::allocation workspace(workspace_size); -+ -+ cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Run the GEMM -+ // -+ -+ status = gemm_op(); -+ -+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); -+ -+ // -+ // Verify -+ // -+ -+ bool passed = this->verify(problem_size, alpha, beta); -+ -+ if (!passed) { -+ std::cout << "Failed with batch_count/split_k_slices = " << batch_count << std::endl; -+ } -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+template -+bool TestGemmUniversal( -+ cutlass::gemm::GemmCoord const & problem_size, -+ cutlass::gemm::GemmUniversalMode mode, -+ int batch_count, -+ double alpha = 1.0, -+ double beta = 2.0) { -+ -+ bool passed = true; -+ -+ TestbedUniversal testbed; -+ -+ using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; -+ -+ passed = testbed.run( -+ mode, -+ problem_size, -+ batch_count, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta) -+ ); -+ -+ return passed; -+} -+ -+template -+bool TestAllGemmUniversal() { -+ bool passed = true; -+ -+ -+ int const kMinimumOperandElementSize = -+ std::min( -+ int(cutlass::sizeof_bits::value), -+ int(cutlass::sizeof_bits::value)); -+ -+ int const kAlignment = cutlass::platform::is_same< -+ typename Gemm::OperatorClass, -+ cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; -+ -+ // int8_t gemm alignment constraints -+ int const kAlignmentM = cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value ? 4 : kAlignment; -+ -+ int const kAlignmentN = cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value ? 4 : kAlignment; -+ -+ int const kAlignmentK = cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value && -+ (cutlass::platform::is_same::value || -+ cutlass::platform::is_same::value) ? 4 : kAlignment; -+ -+ -+ -+ cutlass::gemm::GemmUniversalMode modes[] = { -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ }; -+ -+ int problem_size_m[] = { -+ kAlignmentM, 512 - 3*kAlignmentM -+ }; -+ -+ int problem_size_n[] = { -+ kAlignmentN, 512 - 2*kAlignmentN -+ }; -+ -+ int problem_size_k[] = { -+ kAlignmentK, -+ Gemm::ThreadblockShape::kK * Gemm::kStages - kAlignmentK, -+ Gemm::ThreadblockShape::kK * Gemm::kStages * 3 - kAlignmentK -+ }; -+ -+ int batch_counts[] = { // may be interpretted as batch count or split-K slices -+ 1, 2, 3, 5, 7 -+ }; -+ -+ double problem_alpha[] = { -+ 1 -+ }; -+ -+ double problem_beta[] = { -+ 2.0 -+ }; -+ -+ -+ using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; -+ -+ for (cutlass::gemm::GemmUniversalMode mode : modes) { -+ for (int m : problem_size_m) { -+ for (int n : problem_size_n) { -+ for (int k : problem_size_k) { -+ for (int batch_count : batch_counts) { -+ -+ for (auto alpha : problem_alpha) { -+ for (auto beta : problem_beta) { -+ -+ if (mode == cutlass::gemm::GemmUniversalMode::kGemm || -+ mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { -+ -+ // skip very small K problems -+ if (k / batch_count < 2 * Gemm::ThreadblockShape::kK) { -+ continue; -+ } -+ } -+ -+ cutlass::gemm::GemmCoord problem_size(m, n, k); -+ -+ TestbedUniversal testbed; -+ -+ passed = testbed.run( -+ mode, -+ problem_size, -+ batch_count, -+ cutlass::from_real(alpha), -+ cutlass::from_real(beta) -+ ); -+ -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ /* -+ // large problem with high coverage -+ for (int split_k_slices = 1; split_k_slices <= 3; ++split_k_slices) { -+ TestbedUniversal testbed; -+ -+ cutlass::gemm::GemmCoord problem_size(72, 56, 8192); -+ -+ passed = testbed.run( -+ cutlass::gemm::GemmUniversalMode::kGemm, -+ problem_size, -+ split_k_slices, -+ cutlass::from_real(1.0), -+ cutlass::from_real(2.0) -+ ); -+ -+ if (!passed) { -+ break; -+ } -+ } -+ */ -+ -+ return passed; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace gemm -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/device/testbed_utils.h b/3rdparty/cutlass/test/unit/gemm/device/testbed_utils.h -new file mode 100644 -index 0000000..e47ecda ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/testbed_utils.h -@@ -0,0 +1,53 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+inline char const *to_string(cutlass::Status status) { -+ -+ switch (status) { -+ case cutlass::Status::kSuccess: return "kSuccess"; -+ case cutlass::Status::kErrorMisalignedOperand: return "kErrorMisalignedOperand"; -+ case cutlass::Status::kErrorInvalidLayout: return "kErrorInvalidLayout"; -+ case cutlass::Status::kErrorInvalidProblem: return "kErrorInvalidProblem"; -+ case cutlass::Status::kErrorNotSupported: return "kErrorNotSupported"; -+ case cutlass::Status::kErrorWorkspaceNull: return "kErrorWorkspaceNull"; -+ case cutlass::Status::kErrorInternal: return "kErrorInternal"; -+ case cutlass::Status::kInvalid: return "kInvalid"; -+ default: break; -+ } -+ return "invalid"; -+} -diff --git a/3rdparty/cutlass/test/unit/gemm/device/trmm_cf32n_cf32n_cf32t_tensor_op_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/trmm_cf32n_cf32n_cf32t_tensor_op_f32_sm80.cu -new file mode 100644 -index 0000000..1e31402 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/trmm_cf32n_cf32n_cf32t_tensor_op_f32_sm80.cu -@@ -0,0 +1,305 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_trmm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf32n_cf32t_cf32t_ls_l_nu_tensor_op_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ cutlass::complex, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf32n_cf32t_cf32t_ls_u_nu_tensor_op_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ cutlass::complex, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf32n_cf32t_cf32t_rs_u_nu_tensor_op_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ cutlass::complex, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf32n_cf32t_cf32t_ls_l_un_tensor_op_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kUnit, -+ cutlass::complex, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf32n_cf32t_cf32t_ls_u_un_tensor_op_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kUnit, -+ cutlass::complex, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf32n_cf32t_cf32t_rs_u_un_tensor_op_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kUnit, -+ cutlass::complex, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/trmm_cf32n_cf32n_cf32t_tensor_op_fast_f32_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/trmm_cf32n_cf32n_cf32t_tensor_op_fast_f32_sm80.cu -new file mode 100644 -index 0000000..8dc41a4 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/trmm_cf32n_cf32n_cf32t_tensor_op_fast_f32_sm80.cu -@@ -0,0 +1,305 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_trmm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf32n_cf32t_cf32t_ls_l_nu_tensor_op_fast_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ cutlass::complex, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf32n_cf32t_cf32t_ls_u_nu_tensor_op_fast_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ cutlass::complex, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf32n_cf32t_cf32t_rs_u_nu_tensor_op_fast_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ cutlass::complex, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf32n_cf32t_cf32t_ls_l_un_tensor_op_fast_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kUnit, -+ cutlass::complex, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf32n_cf32t_cf32t_ls_u_un_tensor_op_fast_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kUnit, -+ cutlass::complex, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf32n_cf32t_cf32t_rs_u_un_tensor_op_fast_f32, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kUnit, -+ cutlass::complex, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplexFastF32, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/trmm_cf64_cf64_cf64_tensor_op_f64_sm90.cu b/3rdparty/cutlass/test/unit/gemm/device/trmm_cf64_cf64_cf64_tensor_op_f64_sm90.cu -new file mode 100644 -index 0000000..437bed5 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/trmm_cf64_cf64_cf64_tensor_op_f64_sm90.cu -@@ -0,0 +1,137 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_trmm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Trmm_cf64n_cf64n_cf64t_ls_u_nu_tensor_op_f64_gaussian, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddGaussianComplex, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Trmm_cf64h_cf64n_cf64t_ls_u_nu_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kConjugate -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/trmm_cf64n_cf64n_cf64t_tensor_op_f64_gaussian_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/trmm_cf64n_cf64n_cf64t_tensor_op_f64_gaussian_sm80.cu -new file mode 100644 -index 0000000..b26a8d2 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/trmm_cf64n_cf64n_cf64t_tensor_op_f64_gaussian_sm80.cu -@@ -0,0 +1,137 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_trmm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf64n_cf64n_cf64t_ls_u_nu_tensor_op_f64_gaussian, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddGaussianComplex, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf64h_cf64n_cf64t_ls_u_nu_tensor_op_f64_gaussian, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddGaussianComplex, -+ cutlass::ComplexTransform::kConjugate -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/trmm_cf64n_cf64n_cf64t_tensor_op_f64_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/trmm_cf64n_cf64n_cf64t_tensor_op_f64_sm80.cu -new file mode 100644 -index 0000000..2db3d2c ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/trmm_cf64n_cf64n_cf64t_tensor_op_f64_sm80.cu -@@ -0,0 +1,301 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_trmm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf64n_cf64n_cf64t_ls_l_nu_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf64n_cf64n_cf64t_ls_u_nu_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf64h_cf64n_cf64t_ls_u_nu_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kConjugate -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf64n_cf64n_cf64t_ls_l_un_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kUnit, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf64n_cf64n_cf64t_ls_u_un_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kUnit, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kNone -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_cf64h_cf64n_cf64t_ls_u_un_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kUnit, -+ cutlass::complex, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddComplex, -+ cutlass::ComplexTransform::kConjugate -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/trmm_f32n_f32t_f32t_tensor_op_fast_f32_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/trmm_f32n_f32t_f32t_tensor_op_fast_f32_ls_sm80.cu -new file mode 100644 -index 0000000..d8ad244 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/trmm_f32n_f32t_f32t_tensor_op_fast_f32_ls_sm80.cu -@@ -0,0 +1,500 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_trmm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+////////////////////////////////////////////Test name////////////////////////////////////////////////// -+// -+// SM80_Device_Trmm_{ElementA}{LayoutA}_{ElementB}{LayoutB}_{ElementC}{LayoutC}_{SideMode}_{FillMode}\ -+// _{DiagType}_tensor_op_{ElementAccumulator}_align{AlignmentA}_align{AlignmentB} -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Trmm_f32n_f32t_f32t_ls_l_nu_tensor_op_fast_f32_align1_align1, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32n_f32t_f32t_ls_l_nu_tensor_op_fast_f32_align1_align4, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32n_f32t_f32t_ls_l_nu_tensor_op_fast_f32_align1_align4, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32n_f32t_f32t_ls_l_nu_tensor_op_fast_f32_align1_align4, 256x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32n_f32t_f32t_ls_l_nu_tensor_op_fast_f32_align1_align4, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32n_f32t_f32t_ls_l_nu_tensor_op_fast_f32_align1_align4, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32n_f32t_f32t_ls_u_nu_tensor_op_fast_f32_align1_align4, 64x64x16_32x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 10, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32n_f32t_f32t_ls_u_nu_tensor_op_fast_f32_align1_align4, 128x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32n_f32t_f32t_ls_u_nu_tensor_op_fast_f32_align1_align4, 256x128x16_128x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32n_f32t_f32t_ls_u_nu_tensor_op_fast_f32_align1_align4, 128x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32n_f32t_f32t_ls_u_nu_tensor_op_fast_f32_align1_align4, 256x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/trmm_f32n_f32t_f32t_tensor_op_fast_f32_rs_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/trmm_f32n_f32t_f32t_tensor_op_fast_f32_rs_sm80.cu -new file mode 100644 -index 0000000..a9ed921 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/trmm_f32n_f32t_f32t_tensor_op_fast_f32_rs_sm80.cu -@@ -0,0 +1,252 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_trmm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+////////////////////////////////////////////Test name////////////////////////////////////////////////// -+// -+// SM80_Device_Trmm_{ElementA}{LayoutA}_{ElementB}{LayoutB}_{ElementC}{LayoutC}_{SideMode}_{FillMode}\ -+// _{DiagType}_tensor_op_{ElementAccumulator}_align{AlignmentA}_align{AlignmentB} -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32n_f32t_f32t_rs_u_nu_tensor_op_fast_f32_align1_align1, 64x128x32_32x64x32) { -+ -+using Trmm = cutlass::gemm::device::Trmm< -+ float, cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, cutlass::FillMode::kUpper, cutlass::DiagType::kNonUnit, -+ float, cutlass::layout::RowMajor, -+ float, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32n_f32t_f32t_rs_u_nu_tensor_op_fast_f32_align1_align1, 128x64x32_32x64x32) { -+ -+using Trmm = cutlass::gemm::device::Trmm< -+ float, cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, cutlass::FillMode::kUpper, cutlass::DiagType::kNonUnit, -+ float, cutlass::layout::RowMajor, -+ float, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32n_f32t_f32t_rs_l_nu_tensor_op_fast_f32_align1_align1, 64x128x32_32x64x32) { -+ -+using Trmm = cutlass::gemm::device::Trmm< -+ float, cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, cutlass::FillMode::kLower, cutlass::DiagType::kNonUnit, -+ float, cutlass::layout::RowMajor, -+ float, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+TEST(SM80_Device_Trmm_f32n_f32t_f32t_rs_u_nu_tensor_op_fast_f32_align1_align4, 64x128x32_32x64x32) { -+ -+using Trmm = cutlass::gemm::device::Trmm< -+ float, cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, cutlass::FillMode::kUpper, cutlass::DiagType::kNonUnit, -+ float, cutlass::layout::RowMajor, -+ float, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32n_f32t_f32t_rs_u_nu_tensor_op_fast_f32_align1_align4, 128x64x32_32x64x32) { -+ -+using Trmm = cutlass::gemm::device::Trmm< -+ float, cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, cutlass::FillMode::kUpper, cutlass::DiagType::kNonUnit, -+ float, cutlass::layout::RowMajor, -+ float, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32n_f32t_f32t_rs_l_nu_tensor_op_fast_f32_align1_align4, 64x128x32_32x64x32) { -+ -+using Trmm = cutlass::gemm::device::Trmm< -+ float, cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, cutlass::FillMode::kLower, cutlass::DiagType::kNonUnit, -+ float, cutlass::layout::RowMajor, -+ float, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/trmm_f32t_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/trmm_f32t_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu -new file mode 100644 -index 0000000..56c6396 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/trmm_f32t_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu -@@ -0,0 +1,449 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_trmm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32t_f32n_f32n_ls_l_nu_tensor_op_fast_f32_align1_align1, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, cutlass::FillMode::kLower, cutlass::DiagType::kNonUnit, -+ float, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::ColumnMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32t_f32n_f32n_ls_l_nu_tensor_op_fast_f32_align1_align4, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Trmm_f32t_f32n_f32n_ls_l_nu_tensor_op_fast_f32_align1_align4, 256x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32t_f32n_f32n_ls_l_nu_tensor_op_fast_f32_align1_align4, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32t_f32n_f32n_ls_l_nu_tensor_op_fast_f32_align1_align4, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32t_f32n_f32n_ls_u_nu_tensor_op_fast_f32_align1_align4, 64x64x16_32x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 10, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32t_f32n_f32n_ls_u_nu_tensor_op_fast_f32_align1_align4, 128x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32t_f32n_f32n_ls_u_nu_tensor_op_fast_f32_align1_align4, 256x128x16_128x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32t_f32n_f32n_ls_u_nu_tensor_op_fast_f32_align1_align4, 128x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32t_f32n_f32n_ls_u_nu_tensor_op_fast_f32_align1_align4, 256x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/trmm_f32t_f32n_f32t_tensor_op_fast_f32_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/trmm_f32t_f32n_f32t_tensor_op_fast_f32_ls_sm80.cu -new file mode 100644 -index 0000000..9217ebd ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/trmm_f32t_f32n_f32t_tensor_op_fast_f32_ls_sm80.cu -@@ -0,0 +1,458 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_trmm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+////////////////////////////////////////////Test name////////////////////////////////////////////////// -+// -+// SM80_Device_Trmm_{ElementA}{LayoutA}_{ElementB}{LayoutB}_{ElementC}{LayoutC}_{SideMode}_{FillMode}\ -+// _{DiagType}_tensor_op_{ElementAccumulator}_align{AlignmentA}_align{AlignmentB} -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Trmm_f32t_f32n_f32t_ls_l_un_tensor_op_fast_f32_align1_align1, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32t_f32n_f32t_ls_l_un_tensor_op_fast_f32_align1_align4, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Trmm_f32t_f32n_f32t_ls_l_un_tensor_op_fast_f32_align1_align4, 256x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32t_f32n_f32t_ls_l_un_tensor_op_fast_f32_align1_align4, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32t_f32n_f32t_ls_l_nu_tensor_op_fast_f32_align1_align4, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32t_f32n_f32t_ls_u_nu_tensor_op_fast_f32_align1_align4, 64x64x16_32x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 10, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32t_f32n_f32t_ls_u_un_tensor_op_fast_f32_align1_align4, 128x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32t_f32n_f32t_ls_u_nu_tensor_op_fast_f32_align1_align4, 256x128x16_128x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32t_f32n_f32t_ls_u_nu_tensor_op_fast_f32_align1_align4, 128x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f32t_f32n_f32t_ls_u_nu_tensor_op_fast_f32_align1_align4, 256x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAddFastF32 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/trmm_f64_f64_f64_tensor_op_f64_sm90.cu b/3rdparty/cutlass/test/unit/gemm/device/trmm_f64_f64_f64_tensor_op_f64_sm90.cu -new file mode 100644 -index 0000000..5339bc5 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/trmm_f64_f64_f64_tensor_op_f64_sm90.cu -@@ -0,0 +1,127 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_trmm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Trmm_f64n_f64n_f64t_rs_l_nu_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_Device_Trmm_f64t_f64t_f64n_rs_l_nu_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm90, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/trmm_f64n_f64n_f64t_tensor_op_f64_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/trmm_f64n_f64n_f64t_tensor_op_f64_ls_sm80.cu -new file mode 100644 -index 0000000..0dd9064 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/trmm_f64n_f64n_f64t_tensor_op_f64_ls_sm80.cu -@@ -0,0 +1,414 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_trmm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_ls_l_nu_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_ls_l_nu_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_ls_l_nu_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_ls_l_nu_tensor_op_f64, 64x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_ls_l_nu_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_ls_u_nu_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_ls_u_nu_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_ls_u_nu_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_ls_u_nu_tensor_op_f64, 64x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_ls_u_nu_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/trmm_f64n_f64n_f64t_tensor_op_f64_rs_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/trmm_f64n_f64n_f64t_tensor_op_f64_rs_sm80.cu -new file mode 100644 -index 0000000..f00e50c ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/trmm_f64n_f64n_f64t_tensor_op_f64_rs_sm80.cu -@@ -0,0 +1,415 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_trmm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_rs_l_nu_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_rs_l_nu_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_rs_l_nu_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_rs_l_nu_tensor_op_f64, 64x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_rs_l_nu_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_rs_u_nu_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_rs_u_nu_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_rs_u_nu_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_rs_u_nu_tensor_op_f64, 64x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64n_f64t_rs_u_nu_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/trmm_f64n_f64t_f64t_tensor_op_f64_rs_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/trmm_f64n_f64t_f64t_tensor_op_f64_rs_sm80.cu -new file mode 100644 -index 0000000..98f2a57 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/trmm_f64n_f64t_f64t_tensor_op_f64_rs_sm80.cu -@@ -0,0 +1,415 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_trmm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64t_f64t_rs_l_un_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64t_f64t_rs_l_un_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64t_f64t_rs_l_un_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64t_f64t_rs_l_un_tensor_op_f64, 64x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64t_f64t_rs_l_nu_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64t_f64t_rs_u_un_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64t_f64t_rs_u_nu_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64t_f64t_rs_u_nu_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64t_f64t_rs_u_nu_tensor_op_f64, 64x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64n_f64t_f64t_rs_u_nu_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/trmm_f64t_f64t_f64n_tensor_op_f64_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/trmm_f64t_f64t_f64n_tensor_op_f64_ls_sm80.cu -new file mode 100644 -index 0000000..bb4443d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/trmm_f64t_f64t_f64n_tensor_op_f64_ls_sm80.cu -@@ -0,0 +1,414 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_trmm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_ls_l_nu_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_ls_l_nu_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_ls_l_nu_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_ls_l_nu_tensor_op_f64, 64x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_ls_l_nu_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_ls_u_nu_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_ls_u_nu_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_ls_u_nu_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_ls_u_nu_tensor_op_f64, 64x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_ls_u_nu_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/trmm_f64t_f64t_f64n_tensor_op_f64_rs_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/trmm_f64t_f64t_f64n_tensor_op_f64_rs_sm80.cu -new file mode 100644 -index 0000000..dd07d78 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/trmm_f64t_f64t_f64n_tensor_op_f64_rs_sm80.cu -@@ -0,0 +1,415 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_trmm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_rs_l_nu_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_rs_l_nu_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_rs_l_nu_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_rs_l_nu_tensor_op_f64, 64x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_rs_l_nu_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_rs_u_nu_tensor_op_f64, 32x32x16_16x16x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 16, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_rs_u_nu_tensor_op_f64, 64x64x16_32x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_rs_u_nu_tensor_op_f64, 128x64x16_64x32x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<64, 32, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_rs_u_nu_tensor_op_f64, 64x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_f64t_f64t_f64n_rs_u_nu_tensor_op_f64, 128x128x16_32x64x16) { -+ -+ using ElementOutput = double; -+ using ElementAccumulator = double; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ double, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kRight, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ double, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<32, 64, 16>, -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3 -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/trmm_tf32n_tf32t_f32t_tensor_op_f32_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/trmm_tf32n_tf32t_f32t_tensor_op_f32_ls_sm80.cu -new file mode 100644 -index 0000000..97106d8 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/trmm_tf32n_tf32t_f32t_tensor_op_f32_ls_sm80.cu -@@ -0,0 +1,500 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_trmm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+////////////////////////////////////////////Test name////////////////////////////////////////////////// -+// -+// SM80_Device_Trmm_{ElementA}{LayoutA}_{ElementB}{LayoutB}_{ElementC}{LayoutC}_{SideMode}_{FillMode}\ -+// _{DiagType}_tensor_op_{ElementAccumulator}_align{AlignmentA}_align{AlignmentB} -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Trmm_tf32n_tf32t_f32t_ls_l_nu_tensor_op_f32_align1_align1, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32n_tf32t_f32t_ls_l_nu_tensor_op_f32_align1_align4, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32n_tf32t_f32t_ls_l_nu_tensor_op_f32_align1_align4, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32n_tf32t_f32t_ls_l_nu_tensor_op_f32_align1_align4, 256x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32n_tf32t_f32t_ls_l_nu_tensor_op_f32_align1_align4, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32n_tf32t_f32t_ls_l_nu_tensor_op_f32_align1_align4, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32n_tf32t_f32t_ls_u_nu_tensor_op_f32_align1_align4, 64x64x16_32x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 10, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32n_tf32t_f32t_ls_u_nu_tensor_op_f32_align1_align4, 128x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32n_tf32t_f32t_ls_u_nu_tensor_op_f32_align1_align4, 256x128x16_128x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32n_tf32t_f32t_ls_u_nu_tensor_op_f32_align1_align4, 128x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32n_tf32t_f32t_ls_u_nu_tensor_op_f32_align1_align4, 256x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::RowMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/trmm_tf32n_tf32t_f32t_tensor_op_f32_rs_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/trmm_tf32n_tf32t_f32t_tensor_op_f32_rs_sm80.cu -new file mode 100644 -index 0000000..723af64 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/trmm_tf32n_tf32t_f32t_tensor_op_f32_rs_sm80.cu -@@ -0,0 +1,252 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_trmm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+////////////////////////////////////////////Test name////////////////////////////////////////////////// -+// -+// SM80_Device_Trmm_{ElementA}{LayoutA}_{ElementB}{LayoutB}_{ElementC}{LayoutC}_{SideMode}_{FillMode}\ -+// _{DiagType}_tensor_op_{ElementAccumulator}_align{AlignmentA}_align{AlignmentB} -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32n_tf32t_f32t_rs_u_nu_tensor_op_f32_align1_align1, 64x128x32_32x64x32) { -+ -+using Trmm = cutlass::gemm::device::Trmm< -+ float, cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, cutlass::FillMode::kUpper, cutlass::DiagType::kNonUnit, -+ float, cutlass::layout::RowMajor, -+ float, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAdd -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32n_tf32t_f32t_rs_u_nu_tensor_op_f32_align1_align1, 128x64x32_32x64x32) { -+ -+using Trmm = cutlass::gemm::device::Trmm< -+ float, cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, cutlass::FillMode::kUpper, cutlass::DiagType::kNonUnit, -+ float, cutlass::layout::RowMajor, -+ float, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAdd -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32n_tf32t_f32t_rs_l_nu_tensor_op_f32_align1_align1, 64x128x32_32x64x32) { -+ -+using Trmm = cutlass::gemm::device::Trmm< -+ float, cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, cutlass::FillMode::kLower, cutlass::DiagType::kNonUnit, -+ float, cutlass::layout::RowMajor, -+ float, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAdd -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+TEST(SM80_Device_Trmm_tf32n_tf32t_f32t_rs_u_nu_tensor_op_f32_align1_align4, 64x128x32_32x64x32) { -+ -+using Trmm = cutlass::gemm::device::Trmm< -+ float, cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, cutlass::FillMode::kUpper, cutlass::DiagType::kNonUnit, -+ float, cutlass::layout::RowMajor, -+ float, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32n_tf32t_f32t_rs_u_nu_tensor_op_f32_align1_align4, 128x64x32_32x64x32) { -+ -+using Trmm = cutlass::gemm::device::Trmm< -+ float, cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, cutlass::FillMode::kUpper, cutlass::DiagType::kNonUnit, -+ float, cutlass::layout::RowMajor, -+ float, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 64, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32n_tf32t_f32t_rs_l_nu_tensor_op_f32_align1_align4, 64x128x32_32x64x32) { -+ -+using Trmm = cutlass::gemm::device::Trmm< -+ float, cutlass::layout::ColumnMajor, -+ cutlass::SideMode::kRight, cutlass::FillMode::kLower, cutlass::DiagType::kNonUnit, -+ float, cutlass::layout::RowMajor, -+ float, cutlass::layout::RowMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+>; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/trmm_tf32t_tf32n_f32n_tensor_op_f32_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/trmm_tf32t_tf32n_f32n_tensor_op_f32_ls_sm80.cu -new file mode 100644 -index 0000000..ebb427b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/trmm_tf32t_tf32n_f32n_tensor_op_f32_ls_sm80.cu -@@ -0,0 +1,449 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_trmm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32n_ls_l_nu_tensor_op_f32_align1_align1, 64x128x32_32x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, cutlass::FillMode::kLower, cutlass::DiagType::kNonUnit, -+ float, cutlass::layout::ColumnMajor, -+ float, cutlass::layout::ColumnMajor, -+ float, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 128, 32>, -+ cutlass::gemm::GemmShape<32, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ float, -+ 1, -+ float, -+ float -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, -+ 3, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32n_ls_l_nu_tensor_op_f32_align1_align4, 128x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32n_ls_l_nu_tensor_op_f32_align1_align4, 256x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32n_ls_l_nu_tensor_op_f32_align1_align4, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32n_ls_l_nu_tensor_op_f32_align1_align4, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32n_ls_u_nu_tensor_op_f32_align1_align4, 64x64x16_32x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 10, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32n_ls_u_nu_tensor_op_f32_align1_align4, 128x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32n_ls_u_nu_tensor_op_f32_align1_align4, 256x128x16_128x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32n_ls_u_nu_tensor_op_f32_align1_align4, 128x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32n_ls_u_nu_tensor_op_f32_align1_align4, 256x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::ColumnMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/device/trmm_tf32t_tf32n_f32t_tensor_op_f32_ls_sm80.cu b/3rdparty/cutlass/test/unit/gemm/device/trmm_tf32t_tf32n_f32t_tensor_op_f32_ls_sm80.cu -new file mode 100644 -index 0000000..ba3f7f3 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/device/trmm_tf32t_tf32n_f32t_tensor_op_f32_ls_sm80.cu -@@ -0,0 +1,458 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide TRMM interface -+ -+ -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/blas3.h" -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/trmm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "testbed_trmm_universal.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+////////////////////////////////////////////Test name////////////////////////////////////////////////// -+// -+// SM80_Device_Trmm_{ElementA}{LayoutA}_{ElementB}{LayoutB}_{ElementC}{LayoutC}_{SideMode}_{FillMode}\ -+// _{DiagType}_tensor_op_{ElementAccumulator}_align{AlignmentA}_align{AlignmentB} -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32t_ls_l_un_tensor_op_f32_align1_align1, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 1, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 1, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32t_ls_l_un_tensor_op_f32_align1_align4, 64x64x32_32x32x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<32, 32, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32t_ls_l_un_tensor_op_f32_align1_align4, 256x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32t_ls_l_un_tensor_op_f32_align1_align4, 128x256x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32t_ls_l_nu_tensor_op_f32_align1_align4, 256x128x32_64x64x32) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kLower, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 32>, -+ cutlass::gemm::GemmShape<64, 64, 32>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32t_ls_u_nu_tensor_op_f32_align1_align4, 64x64x16_32x32x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<32, 32, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 10, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32t_ls_u_un_tensor_op_f32_align1_align4, 128x128x16_64x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 128, 16>, -+ cutlass::gemm::GemmShape<64, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32t_ls_u_nu_tensor_op_f32_align1_align4, 256x128x16_128x64x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 128, 16>, -+ cutlass::gemm::GemmShape<128, 64, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 4, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32t_ls_u_nu_tensor_op_f32_align1_align4, 128x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<128, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_Device_Trmm_tf32t_tf32n_f32t_ls_u_nu_tensor_op_f32_align1_align4, 256x256x16_64x128x16) { -+ -+ using ElementOutput = float; -+ using ElementAccumulator = float; -+ -+ using Trmm = cutlass::gemm::device::Trmm< -+ float, -+ cutlass::layout::RowMajor, -+ cutlass::SideMode::kLeft, -+ cutlass::FillMode::kUpper, -+ cutlass::DiagType::kNonUnit, -+ float, -+ cutlass::layout::ColumnMajor, -+ ElementOutput, -+ cutlass::layout::RowMajor, -+ ElementAccumulator, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ cutlass::gemm::GemmShape<256, 256, 16>, -+ cutlass::gemm::GemmShape<64, 128, 16>, -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementAccumulator -+ >, -+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, -+ 3, -+ 1, -+ 4, -+ false, -+ cutlass::arch::OpMultiplyAdd -+ >; -+ -+ EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/kernel/batched_gemv.cu b/3rdparty/cutlass/test/unit/gemm/kernel/batched_gemv.cu -new file mode 100755 -index 0000000..4e06485 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/kernel/batched_gemv.cu -@@ -0,0 +1,1082 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#include "testbed_gemv.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_batched_gemv, 1x64x64x1_1x64x4x1_1x4x4x1_rcr_fp32_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 1); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 1; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ float, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x64x64x4_1x64x4x2_1x4x4x2_rcr_fp32_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 4); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 2; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ float, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x256x256x64_1x64x4x8_1x4x4x8_rcr_fp32_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 256, 256, 64); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 8; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ float, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x7x256x4096_1x8x4x64_1x1x4x64_rcr_fp32_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 7, 256, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ float, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x64x27x4096_1x8x1x64_1x1x1x64_rcr_alpha_fp32_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 27, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 1>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ float, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ kBatchTileSize>(problem_size, -0.5f); -+} -+ -+TEST(SM50_batched_gemv, 1x64x27x4096_1x8x1x64_1x1x1x64_rcr_alpha_beta_fp32_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 27, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 1>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ float, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ kBatchTileSize>(problem_size, 4.5f, -0.5f); -+} -+ -+TEST(SM50_batched_gemv, 1x64x24x4096_1x8x4x64_1x1x4x64_rcr_alpha_beta_fp16_fp16) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 24, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ kBatchTileSize>(problem_size, cutlass::half_t(4.5f), cutlass::half_t(-0.5f)); -+} -+ -+/// -+ -+TEST(SM50_batched_gemv, 1x64x64x1_1x64x4x1_1x4x4x1_rcr_fp16_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 1); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 1; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x64x64x4_1x64x4x2_1x4x4x2_rcr_fp16_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 4); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 2; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x256x256x64_1x64x4x8_1x4x4x8_rcr_fp16_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 256, 256, 64); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 8; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x7x256x4096_1x8x4x64_1x1x4x64_rcr_fp16_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 7, 256, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+/// -+ -+TEST(SM50_batched_gemv, 1x64x64x1_1x64x4x1_1x4x4x1_rcr_fp16_fp16) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 1); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 1; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x64x64x4_1x64x4x2_1x4x4x2_rcr_fp16_fp16) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 4); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 2; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x256x256x64_1x64x4x8_1x4x4x8_rcr_fp16_fp16) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 256, 256, 64); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 8; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x7x256x4096_1x8x4x64_1x1x4x64_rcr_fp16_fp16) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 7, 256, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+/// -+ -+TEST(SM50_batched_gemv, 1x64x64x1_1x64x4x1_1x4x4x1_rcr_i8_i32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 1); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 1; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ int8_t, int32_t, int32_t, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x64x64x4_1x64x4x2_1x4x4x2_rcr_i8_i32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 4); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 2; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ int8_t, int32_t, int32_t, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x256x256x64_1x64x4x8_1x4x4x8_rcr_i8_i32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 256, 256, 64); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 8; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ int8_t, int32_t, int32_t, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x7x256x4096_1x8x4x64_1x1x4x64_rcr_i8_i32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 7, 256, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ int8_t, int32_t, int32_t, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+///////////// -+ -+TEST(SM50_batched_gemv, 1x64x64x1_1x64x4x1_1x4x4x1_crc_fp32_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 1); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 1; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ float, float, float, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x64x64x4_1x64x4x2_1x4x4x2_crc_fp32_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 4); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 2; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ float, float, float, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x256x256x64_1x64x4x8_1x4x4x8_crc_fp32_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 256, 256, 64); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 8; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ float, float, float, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x7x256x4096_1x8x4x64_1x1x4x64_crc_fp32_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 7, 256, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ float, float, float, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+/// -+ -+TEST(SM50_batched_gemv, 1x64x64x1_1x64x4x1_1x4x4x1_crc_fp16_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 1); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 1; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, float, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x64x64x4_1x64x4x2_1x4x4x2_crc_fp16_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 4); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 2; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, float, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x256x256x64_1x64x4x8_1x4x4x8_crc_fp16_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 256, 256, 64); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 8; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, float, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x7x256x4096_1x8x4x64_1x1x4x64_crc_fp16_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 7, 256, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, float, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+/// -+ -+TEST(SM50_batched_gemv, 1x64x64x1_1x64x4x1_1x4x4x1_crc_fp16_fp16) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 1); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 1; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x64x64x4_1x64x4x2_1x4x4x2_crc_fp16_fp16) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 4); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 2; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x256x256x64_1x64x4x8_1x4x4x8_crc_fp16_fp16) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 256, 256, 64); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 8; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x7x256x4096_1x8x4x64_1x1x4x64_crc_fp16_fp16) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 7, 256, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+/// -+ -+TEST(SM50_batched_gemv, 1x64x64x1_1x64x4x1_1x4x4x1_crc_i8_i32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 1); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 1; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ int8_t, int32_t, int32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x64x64x4_1x64x4x2_1x4x4x2_crc_i8_i32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 4); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 2; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ int8_t, int32_t, int32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x256x256x64_1x64x4x8_1x4x4x8_crc_i8_i32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 256, 256, 64); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 8; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ int8_t, int32_t, int32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x7x256x4096_1x8x4x64_1x1x4x64_crc_i8_i32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 7, 256, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ int8_t, int32_t, int32_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x64x27x4096_1x8x1x64_1x1x1x64_crc_alpha_fp32_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 27, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 1>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ float, float, float, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size, -0.5f); -+} -+ -+TEST(SM50_batched_gemv, 1x64x27x4096_1x8x1x64_1x1x1x64_crc_alpha_beta_fp32_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 27, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 1>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ float, float, float, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size, 4.5f, -0.5f); -+} -+ -+TEST(SM50_batched_gemv, 1x64x24x4096_1x8x4x64_1x1x4x64_crc_alpha_beta_fp16_fp16) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 24, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size, cutlass::half_t(4.5f), cutlass::half_t(-0.5f)); -+} -+ -+///////////// -+ -+TEST(SM50_batched_gemv, 1x64x64x1_1x64x4x1_1x4x4x1_rcc_fp32_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 1); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 1; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ float, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x64x64x4_1x64x4x2_1x4x4x2_rcc_fp32_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 4); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 2; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ float, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x256x256x64_1x64x4x8_1x4x4x8_rcc_fp32_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 256, 256, 64); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 8; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ float, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x7x256x4096_1x8x4x64_1x1x4x64_rcc_fp32_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 7, 256, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ float, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+/// -+ -+TEST(SM50_batched_gemv, 1x64x64x1_1x64x4x1_1x4x4x1_rcc_fp16_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 1); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 1; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x64x64x4_1x64x4x2_1x4x4x2_rcc_fp16_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 4); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 2; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x256x256x64_1x64x4x8_1x4x4x8_rcc_fp16_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 256, 256, 64); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 8; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x7x256x4096_1x8x4x64_1x1x4x64_rcc_fp16_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 7, 256, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+/// -+ -+TEST(SM50_batched_gemv, 1x64x64x1_1x64x4x1_1x4x4x1_rcc_fp16_fp16) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 1); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 1; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x64x64x4_1x64x4x2_1x4x4x2_rcc_fp16_fp16) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 4); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 2; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x256x256x64_1x64x4x8_1x4x4x8_rcc_fp16_fp16) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 256, 256, 64); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 8; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x7x256x4096_1x8x4x64_1x1x4x64_rcc_fp16_fp16) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 7, 256, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+/// -+ -+TEST(SM50_batched_gemv, 1x64x64x1_1x64x4x1_1x4x4x1_rcc_i8_i32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 1); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 1; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ int8_t, int32_t, int32_t, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x64x64x4_1x64x4x2_1x4x4x2_rcc_i8_i32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 64, 4); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 2; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ int8_t, int32_t, int32_t, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x256x256x64_1x64x4x8_1x4x4x8_rcc_i8_i32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 256, 256, 64); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; -+ static int const kBatchTileSize = 8; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ int8_t, int32_t, int32_t, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x7x256x4096_1x8x4x64_1x1x4x64_rcc_i8_i32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 7, 256, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ int8_t, int32_t, int32_t, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size); -+} -+ -+TEST(SM50_batched_gemv, 1x64x27x4096_1x8x1x64_1x1x1x64_rcc_alpha_fp32_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 27, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 1>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ float, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size, -0.5f); -+} -+ -+TEST(SM50_batched_gemv, 1x64x27x4096_1x8x1x64_1x1x1x64_rcc_alpha_beta_fp32_fp32) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 27, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 1>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ float, float, float, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size, 4.5f, -0.5f); -+} -+ -+TEST(SM50_batched_gemv, 1x64x24x4096_1x8x4x64_1x1x4x64_rcc_alpha_beta_fp16_fp16) -+{ -+ cutlass::gemm::BatchedGemmCoord problem_size(1, 64, 24, 4096); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<1, 8, 4>; -+ using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>; -+ static int const kBatchTileSize = 64; -+ -+ test::gemm::kernel::batched_gemv_kernel_test< -+ ThreadBlockShape, -+ ThreadShape, -+ cutlass::half_t, float, cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::layout::ColumnMajor, -+ cutlass::layout::ColumnMajor, -+ kBatchTileSize>(problem_size, cutlass::half_t(4.5f), cutlass::half_t(-0.5f)); -+} -diff --git a/3rdparty/cutlass/test/unit/gemm/kernel/testbed_gemv.h b/3rdparty/cutlass/test/unit/gemm/kernel/testbed_gemv.h -new file mode 100755 -index 0000000..dc551ef ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/kernel/testbed_gemv.h -@@ -0,0 +1,358 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/tensor_ref.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "cutlass/gemm/kernel/default_gemv.h" -+#include "cutlass/gemm/kernel/gemv_batched_strided.h" -+ -+namespace test { -+namespace gemm { -+namespace kernel { -+ -+template -+void batched_gemv_kernel_test(cutlass::gemm::BatchedGemmCoord problem_size, -+ ElementCD_ alpha = ElementCD_(1), -+ ElementCD_ beta = ElementCD_(0), -+ bool perf_test = false, -+ int perf_test_iter = 1) -+{ -+ using ThreadBlockShape = ThreadBlockShape_; -+ using ThreadShape = ThreadShape_; -+ using ElementA = ElementAB_; -+ using LayoutA = LayoutA_; -+ using ElementB = ElementAB_; -+ using LayoutB = LayoutB_; -+ using ElementAccumulator = ElementCD_; -+ using ElementCD = ElementCD_; -+ using LayoutCD = LayoutCD_; -+ -+ using GemvKernel = cutlass::gemm::kernel::DefaultGemv; -+ -+ using ThreadBlockGemv = typename GemvKernel::ThreadBlockGemv; -+ using ThreadBlockSwizzle = typename GemvKernel::ThreadBlockSwizzle; -+ -+ if (DEBUG) -+ { -+ problem_size = cutlass::gemm::BatchedGemmCoord( -+ problem_size.m(), problem_size.n(), problem_size.k(), 1); -+ } -+ -+ // Create host tensors that will be the backing store for the batches -+ // Note that no device memory is initially allocated -+ cutlass::HostTensor matrix_A({problem_size.m(), problem_size.k()}, false); -+ cutlass::HostTensor matrix_B({problem_size.k(), problem_size.n()}, false); -+ cutlass::HostTensor matrix_C_computed({problem_size.m(), problem_size.n()}, false); -+ cutlass::HostTensor matrix_C_reference({problem_size.m(), problem_size.n()}, false); -+ -+ // Reserve memory for the batch of tensors -+ matrix_A.reserve(problem_size.m()*problem_size.k()*problem_size.batch()); -+ matrix_B.reserve(problem_size.n()*problem_size.k()*problem_size.batch()); -+ matrix_C_computed.reserve(problem_size.m()*problem_size.n()*problem_size.batch()); -+ matrix_C_reference.reserve(problem_size.m()*problem_size.n()*problem_size.batch(), false); -+ -+ // Fill eatch tensor batch -+ const int seed = 9876; -+ for (int b = 0; b < problem_size.batch(); b++) -+ { -+ if(DEBUG) -+ { -+ cutlass::reference::host::BlockFillSequential( -+ matrix_A.host_data_ptr_offset(b*matrix_A.capacity()), matrix_A.capacity()); -+ cutlass::reference::host::BlockFillSequential( -+ matrix_B.host_data_ptr_offset(b*matrix_B.capacity()), matrix_B.capacity()); -+ } -+ else -+ { -+ cutlass::reference::host::TensorFillRandomUniform( -+ matrix_A.host_view(b*matrix_A.capacity()), -+ seed + 1660, -+ 8, -+ -8, -+ 0 -+ ); -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ matrix_B.host_view(b*matrix_B.capacity()), -+ seed + 1880, -+ 8, -+ -8, -+ 0 -+ ); -+ } -+ -+ cutlass::reference::host::TensorFill(matrix_C_computed.host_view(b*matrix_C_computed.capacity())); -+ cutlass::reference::host::TensorFill(matrix_C_reference.host_view(b*matrix_C_reference.capacity())); -+ } -+ -+ matrix_A.sync_device(); -+ matrix_B.sync_device(); -+ matrix_C_computed.sync_device(); -+ -+ ThreadBlockSwizzle swizzle; -+ -+ cutlass::gemm::BatchedGemmCoord tiled_size{ThreadBlockShape::kM, -+ ThreadBlockShape::kN, -+ problem_size.k(), // no split-k -+ DEBUG ? 1 : THREAD_B }; -+ -+ cutlass::gemm::BatchedGemmCoord tiled_shape = swizzle.get_tiled_shape(problem_size, tiled_size); -+ -+ #if 0 -+ printf("tiled_size = %d %d %d %d\n", tiled_size.m(), tiled_size.n(), tiled_size.k(), tiled_size.batch()); -+ printf("tiled_shape = %d %d %d %d\n", tiled_shape.m(), tiled_shape.n(), tiled_shape.k(), tiled_shape.batch()); -+ #endif -+ -+ // No split-k -+ EXPECT_EQ(tiled_size.k(), problem_size.k()); -+ -+ dim3 grid = swizzle.get_grid_shape(tiled_shape); -+ dim3 block(tiled_size.n() / ThreadShape::kN, tiled_size.batch(), tiled_size.k() / problem_size.k()); -+ -+ // Some sanity checks -+ EXPECT_TRUE( block.x*block.y*block.z <= 1024 ); -+ EXPECT_TRUE( block.x <= 1024 ); -+ EXPECT_TRUE( block.y <= 1024 ); -+ EXPECT_TRUE( block.z <= 64 ); -+ -+ #if 0 -+ printf("grid dim = %d, %d, %d\n", grid.x, grid.y, grid.z); -+ printf("block dim = %d, %d, %d\n", block.x, block.y, block.z); -+ #endif -+ -+ cudaError_t result; -+ cudaEvent_t start_event, end_event; -+ -+ for (int iter = 0; iter < (perf_test ? (perf_test_iter+1) : 1); ++iter) -+ { -+ if (perf_test && iter == 1) -+ { -+ result = cudaEventCreate(&start_event); -+ EXPECT_EQ(result, cudaSuccess); -+ -+ result = cudaEventCreate(&end_event); -+ EXPECT_EQ(result, cudaSuccess); -+ -+ result = cudaEventRecord(start_event); -+ EXPECT_EQ(result, cudaSuccess); -+ } -+ -+ if (beta == ElementCD(0)) -+ { -+ if (alpha == ElementCD(1)) -+ { -+ cutlass::gemm::kernel::GemvBatchedStrided<<< grid, block >>>( -+ problem_size, -+ matrix_A.device_ref(), -+ matrix_A.capacity(), -+ matrix_B.device_ref(), -+ matrix_B.capacity(), -+ matrix_C_computed.device_ref(), -+ matrix_C_computed.capacity() -+ ); -+ } -+ else -+ { -+ cutlass::gemm::kernel::GemvBatchedStrided<<< grid, block >>>( -+ problem_size, -+ alpha, -+ matrix_A.device_ref(), -+ matrix_A.capacity(), -+ matrix_B.device_ref(), -+ matrix_B.capacity(), -+ matrix_C_computed.device_ref(), -+ matrix_C_computed.capacity() -+ ); -+ } -+ } -+ else -+ { -+ cutlass::gemm::kernel::GemvBatchedStrided<<< grid, block >>>( -+ problem_size, -+ alpha, -+ beta, -+ matrix_A.device_ref(), -+ matrix_A.capacity(), -+ matrix_B.device_ref(), -+ matrix_B.capacity(), -+ matrix_C_computed.device_ref(), -+ matrix_C_computed.capacity(), -+ matrix_C_computed.device_ref(), -+ matrix_C_computed.capacity() -+ ); -+ } -+ -+ if (iter == 0) -+ { -+ result = cudaGetLastError(); -+ EXPECT_EQ(result, cudaSuccess) << " kernel error: " << cudaGetErrorString(result); -+ } -+ } -+ -+ if (perf_test) -+ { -+ result = cudaEventRecord(end_event); -+ EXPECT_EQ(result, cudaSuccess); -+ } -+ -+ result = cudaDeviceSynchronize(); -+ EXPECT_EQ(result, cudaSuccess) << " kernel error: " << cudaGetErrorString(result); -+ -+ if (perf_test) -+ { -+ float ms; -+ result = cudaEventElapsedTime(&ms, start_event, end_event); -+ EXPECT_EQ(result, cudaSuccess); -+ -+ double flops = (double(problem_size.m()) * -+ double(problem_size.n()) * -+ double(problem_size.k()) * -+ double(problem_size.batch()) * 2); // 2 for MAC -+ -+ double read_bytes = double(problem_size.batch()) * (sizeof(ElementA)*double(problem_size.m())*double(problem_size.k()) + -+ sizeof(ElementB)*double(problem_size.k())*double(problem_size.n())); -+ -+ double write_bytes = double(problem_size.batch()) * (sizeof(ElementCD)*double(problem_size.m())*double(problem_size.n())); -+ -+ double avg_runtime = double(ms) / perf_test_iter; -+ double gflops_per_sec = flops / 1.0e6 / avg_runtime; -+ double read_bandwidth = read_bytes / 1.0e6 / avg_runtime; -+ double write_bandwidth = write_bytes / 1.0e6 / avg_runtime; -+ -+ std::cout << "\n\nProblem size: " -+ << problem_size.m() -+ << " x " << problem_size.n() -+ << " x " << problem_size.k() -+ << " x " << problem_size.batch() -+ << std::endl; -+ -+ std::cout << " GFLOPs: " << gflops_per_sec << std::endl; -+ std::cout << "BW (R/W): " << read_bandwidth << " / " << write_bandwidth << " GB/sec" << std::endl; -+ std::cout << " Runtime: " << avg_runtime << " ms" << std::endl; -+ } -+ else -+ { -+ matrix_C_computed.sync_host(); -+ -+ // Compute the batched gemms -+ for (int b = 0; b < problem_size.batch(); b++) -+ { -+ cutlass::reference::host::Gemm -+ reference_gemm; -+ -+ reference_gemm( -+ problem_size.mnk(), alpha, -+ matrix_A.host_ref(b * matrix_A.capacity()), -+ matrix_B.host_ref(b * matrix_B.capacity()), beta, -+ matrix_C_reference.host_ref(b * matrix_C_computed.capacity())); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ matrix_C_computed.host_view(b * matrix_C_computed.capacity()), -+ matrix_C_reference.host_view(b * matrix_C_reference.capacity())); -+ -+ EXPECT_TRUE(passed) -+ //<< "A:\n" << matrix_A.host_view() << "\n" -+ //<< "B:\n" << matrix_B.host_view() << "\n" -+ << "Batch: " << b << "\n" -+ << "Reference:\n" -+ << matrix_C_reference.host_view(b * matrix_C_reference.capacity()) -+ << "\n" -+ << "Computed:\n" -+ << matrix_C_computed.host_view(b * matrix_C_computed.capacity()) -+ << "\n"; -+ } -+ } -+} -+ -+template -+void batched_gemv_kernel_perf_test(cutlass::gemm::BatchedGemmCoord problem_size, -+ ElementCD_ alpha = ElementCD_(1), -+ ElementCD_ beta = ElementCD_(0), -+ int iter = 50) -+{ -+ batched_gemv_kernel_test(problem_size, alpha, beta, true, iter); -+} -+ -+} // namespace threadblock -+} // namespace kernel -+} // namespace test -diff --git a/3rdparty/cutlass/test/unit/gemm/thread/gemm_sm50.cu b/3rdparty/cutlass/test/unit/gemm/thread/gemm_sm50.cu -new file mode 100644 -index 0000000..1ac6ea5 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/thread/gemm_sm50.cu -@@ -0,0 +1,175 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/gemm/thread/mma.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Sgemm_thread, col_row_3x4x2) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<3, 4, 2>, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM50_Sgemm_thread, col_row_4x4x2) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<4, 4, 2>, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM50_Sgemm_thread, row_col_4x4x2) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<4, 4, 2>, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM50_Sgemm_thread, col_row_4x5x3) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<4, 5, 3>, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM50_Sgemm_thread, col_row) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<8, 8, 1>, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM50_Sgemm_thread, row_col) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<8, 8, 1>, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM50_Sgemm_thread, col_col) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<8, 8, 1>, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM50_Sgemm_thread, row_row) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<8, 8, 1>, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Dgemm_thread, col_row) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<8, 8, 1>, -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM50_Dgemm_thread, row_col) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<8, 8, 1>, -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/thread/gemm_sm60.cu b/3rdparty/cutlass/test/unit/gemm/thread/gemm_sm60.cu -new file mode 100644 -index 0000000..23099b2 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/thread/gemm_sm60.cu -@@ -0,0 +1,499 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/gemm/thread/mma.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Compute capability SM60 -+// -+ -+TEST(SM60_Hgemm_thread, col_row_col_1x1x16) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<1, 1, 16>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, row_col_row_1x1x16) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<1, 1, 16>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, row_col_col_1x3x8) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<1, 3, 8>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, row_row_row_7x8x3) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<7, 8, 3>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, row_col_row_7x8x3) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<7, 8, 3>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, col_row_row_7x8x3) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<7, 8, 3>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, col_col_row_7x8x3) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<7, 8, 3>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, row_row_row_7x8x4) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<7, 8, 4>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, row_col_row_7x8x4) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<7, 8, 4>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, col_row_row_7x8x4) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<7, 8, 4>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, col_col_row_7x8x4) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<7, 8, 4>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, row_row_col_16x3x3) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<16, 3, 3>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, row_col_col_16x3x3) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<16, 3, 3>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, col_row_col_16x3x3) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<16, 3, 3>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, col_col_col_16x3x3) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<16, 3, 3>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, row_row_col_16x3x4) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<16, 3, 4>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, row_col_col_16x3x4) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<16, 3, 4>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, col_row_col_16x3x4) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<16, 3, 4>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, col_col_col_16x3x4) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<16, 3, 4>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, row_row_row_16x8x3) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<16, 8, 3>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, row_row_col_16x8x3) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<16, 8, 3>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, row_col_row_16x8x3) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<16, 8, 3>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+}TEST(SM60_Hgemm_thread, row_col_col_16x8x3) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<16, 8, 3>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, col_row_row_16x8x3) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<16, 8, 3>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, col_row_col_16x8x3) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<16, 8, 3>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, col_col_row_16x8x3) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<16, 8, 3>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, col_col_col_16x8x3) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<16, 8, 3>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, row_row_row_16x8x4) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, row_row_col_16x8x4) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, row_col_row_16x8x4) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, row_col_col_16x8x4) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, col_row_row_16x8x4) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, col_row_col_16x8x4) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, col_col_row_16x8x4) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_Hgemm_thread, col_col_col_16x8x4) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/thread/gemm_sm61.cu b/3rdparty/cutlass/test/unit/gemm/thread/gemm_sm61.cu -new file mode 100644 -index 0000000..68f9110 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/thread/gemm_sm61.cu -@@ -0,0 +1,87 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/gemm/thread/mma.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Compute capability SM61 -+// -+ -+TEST(SM61_Igemm_thread, col_row_1x1x4) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<1, 1, 4>, -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ int32_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM61_Igemm_thread, col_row_2x3x4) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<2, 3, 4>, -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ int32_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM61_Igemm_thread, col_row_8x8x4) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<8, 8, 4>, -+ int8_t, -+ cutlass::layout::RowMajor, -+ int8_t, -+ cutlass::layout::ColumnMajor, -+ int32_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/thread/host/gemm_sm60_host.cu b/3rdparty/cutlass/test/unit/gemm/thread/host/gemm_sm60_host.cu -new file mode 100644 -index 0000000..5b1b5da ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/thread/host/gemm_sm60_host.cu -@@ -0,0 +1,176 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../../common/cutlass_unit_test.h" -+ -+#include "cutlass/gemm/thread/mma.h" -+ -+#include "testbed_host.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Compute capability SM60 -+// -+ -+TEST(SM60_host_Hgemm_thread, col_row_col_1x1x16) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<1, 1, 16>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM60_host_Hgemm_thread, row_col_row_1x1x16) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<1, 1, 16>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_host_Hgemm_thread, row_row_row_2x2x2) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<2, 2, 2>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_host_Hgemm_thread, row_row_col_2x2x2) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<2, 2, 2>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM60_host_Hgemm_thread, row_col_row_2x2x2) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<2, 2, 2>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_host_Hgemm_thread, row_col_col_2x2x2) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<2, 2, 2>, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM60_host_Hgemm_thread, col_row_row_2x2x2) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<2, 2, 2>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_host_Hgemm_thread, col_row_col_2x2x2) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<2, 2, 2>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM60_host_Hgemm_thread, col_col_row_2x2x2) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<2, 2, 2>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor -+ >().run(); -+} -+ -+TEST(SM60_host_Hgemm_thread, col_col_col_2x2x2) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<2, 2, 2>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/thread/host/testbed_host.h b/3rdparty/cutlass/test/unit/gemm/thread/host/testbed_host.h -new file mode 100644 -index 0000000..bd78947 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/thread/host/testbed_host.h -@@ -0,0 +1,232 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#pragma once -+ -+#include "cutlass/gemm/thread/mma.h" -+#include "cutlass/layout/vector.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+namespace test { -+namespace gemm { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Thread-level matrix multiply-accumulate -+template -+void kernel( -+ typename Mma::ElementC *D, -+ typename Mma::ElementA const *A, -+ typename Mma::ElementB const *B, -+ typename Mma::ElementC const *C) { -+ -+ auto ptr_D = reinterpret_cast *>(D); -+ auto ptr_A = reinterpret_cast const *>(A); -+ auto ptr_B = reinterpret_cast const *>(B); -+ auto ptr_C = reinterpret_cast const *>(C); -+ -+ Mma mma; -+ -+ auto a = *ptr_A; -+ auto b = *ptr_B; -+ auto c = *ptr_C; -+ -+ using Btype = typename Mma::ElementB; -+ cutlass::Array d; -+ -+ mma(d, a, b, c); -+ -+ *ptr_D = d; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape, -+ /// Data type of A elements -+ typename ElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Data type of B elements -+ typename ElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Element type of C matrix -+ typename ElementC, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC -+> -+struct Testbed { -+ -+ /// Thread-level matrix multiply-accumulate operator -+ using Mma = cutlass::gemm::thread::Mma< -+ Shape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC -+ >; -+ -+ // -+ // Data members -+ // -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D_computed; -+ cutlass::HostTensor tensor_D_reference; -+ -+ // -+ // Methods -+ // -+ -+ /// Allocates workspace in device memory -+ Testbed() { -+ -+ tensor_A.reset(cutlass::make_Coord(Shape::kM, Shape::kK), false); -+ tensor_B.reset(cutlass::make_Coord(Shape::kK, Shape::kN), false); -+ tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); -+ tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); -+ tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); -+ } -+ -+ /// Runs the test -+ bool run() { -+ -+ // -+ // initialize device memory -+ // -+ -+ cutlass::reference::host::detail::RandomUniformFunc< ElementA > tfill_rand_func( -+ 0, // seed -+ 10, // max -+ 0, // min -+ 0); // bits after decimal -+ -+ cutlass::reference::host::detail::TensorFillRandomUniformFunc< ElementA, LayoutA > tfill_rand( -+ tensor_A.host_view(), -+ tfill_rand_func); -+ -+ for (auto i=0; i< Shape::kM; i++) -+ for (auto j=0; j< Shape::kK; j++) -+ tfill_rand(cutlass::make_Coord(i,j)); -+ -+ cutlass::reference::host::BlockFillSequential( -+ tensor_B.host_data(), -+ tensor_B.capacity(), -+ ElementB(1), -+ ElementB(2) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_C.host_view(), -+ ElementC(0) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D_computed.host_view(), -+ ElementC(0) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D_reference.host_view(), -+ ElementC(0) -+ ); -+ -+ -+ // Host side call -+ kernel( -+ tensor_D_computed.host_data(), -+ tensor_A.host_data(), -+ tensor_B.host_data(), -+ tensor_C.host_data()); -+ -+ // -+ // Reference implementation -+ // -+ -+ cutlass::reference::host::Gemm -+ reference_gemm; -+ -+ reference_gemm( -+ {Shape::kM, Shape::kN, Shape::kK}, -+ ElementC(1), -+ tensor_A.host_ref(), -+ tensor_B.host_ref(), -+ ElementC(0), -+ tensor_D_reference.host_ref() -+ ); -+ -+ // -+ // Verify equivalence -+ // -+ -+ // compare -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_D_computed.host_view(), -+ tensor_D_reference.host_view() -+ ); -+ -+ EXPECT_TRUE(passed) -+ << "A:\n" << tensor_A.host_view() << "\n\n" -+ << "B:\n" << tensor_B.host_view() << "\n\n" -+ << "C:\n" << tensor_C.host_view() << "\n\n" -+ << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" -+ << "Computed:\n" << tensor_D_computed.host_view() << std::endl; -+ -+ -+ return passed; -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace gemm -+} // namespace test -diff --git a/3rdparty/cutlass/test/unit/gemm/thread/testbed.h b/3rdparty/cutlass/test/unit/gemm/thread/testbed.h -new file mode 100644 -index 0000000..c5ad60f ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/thread/testbed.h -@@ -0,0 +1,236 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#pragma once -+ -+#include "cutlass/gemm/thread/mma.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+namespace test { -+namespace gemm { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Thread-level matrix multiply-accumulate -+template -+__global__ void kernel( -+ typename Mma::ElementC *D, -+ typename Mma::ElementA const *A, -+ typename Mma::ElementB const *B, -+ typename Mma::ElementC const *C) { -+ -+ auto ptr_D = reinterpret_cast *>(D); -+ auto ptr_A = reinterpret_cast const *>(A); -+ auto ptr_B = reinterpret_cast const *>(B); -+ auto ptr_C = reinterpret_cast const *>(C); -+ -+ Mma mma; -+ -+ auto a = *ptr_A; -+ auto b = *ptr_B; -+ auto c = *ptr_C; -+ -+ cutlass::Array d; -+ -+ mma(d, a, b, c); -+ -+ *ptr_D = d; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape, -+ /// Data type of A elements -+ typename ElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Data type of B elements -+ typename ElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Element type of C matrix -+ typename ElementC, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC -+> -+struct Testbed { -+ -+ /// Thread-level matrix multiply-accumulate operator -+ using Mma = cutlass::gemm::thread::Mma< -+ Shape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC -+ >; -+ -+ // -+ // Data members -+ // -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D_computed; -+ cutlass::HostTensor tensor_D_reference; -+ -+ // -+ // Methods -+ // -+ -+ /// Allocates workspace in device memory -+ Testbed() { -+ -+ tensor_A.reset(cutlass::make_Coord(Shape::kM, Shape::kK)); -+ tensor_B.reset(cutlass::make_Coord(Shape::kK, Shape::kN)); -+ tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); -+ tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); -+ tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); -+ } -+ -+ /// Runs the test -+ bool run() { -+ -+ // -+ // initialize device memory -+ // -+ -+ cutlass::reference::host::BlockFillSequential( -+ tensor_A.host_data(), -+ tensor_A.capacity() -+ ); -+ -+ cutlass::reference::host::BlockFillSequential( -+ tensor_B.host_data(), -+ tensor_B.capacity(), -+ ElementB(1), -+ ElementB(2) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_C.host_view(), -+ ElementC(0) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D_computed.host_view(), -+ ElementC(0) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D_reference.host_view(), -+ ElementC(0) -+ ); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D_computed.sync_device(); -+ -+ // launch kernel -+ kernel<<< dim3(1, 1), dim3(1, 1, 1) >>>( -+ tensor_D_computed.device_data(), -+ tensor_A.device_data(), -+ tensor_B.device_data(), -+ tensor_C.device_data()); -+ -+ // verify no errors -+ cudaError_t result = cudaDeviceSynchronize(); -+ -+ EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); -+ if (result != cudaSuccess) { -+ return false; -+ } -+ -+ tensor_D_computed.sync_host(); -+ -+ // -+ // Reference implementation -+ // -+ -+ //tensor_D_reference.fill(tensor_C.host_view()); -+ -+ cutlass::reference::host::Gemm -+ reference_gemm; -+ -+ reference_gemm( -+ {Shape::kM, Shape::kN, Shape::kK}, -+ ElementC(1), -+ tensor_A.host_ref(), -+ tensor_B.host_ref(), -+ ElementC(0), -+ tensor_D_reference.host_ref() -+ ); -+ -+ // -+ // Verify equivalence -+ // -+ -+ // compare -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_D_computed.host_view(), -+ tensor_D_reference.host_view() -+ ); -+ -+ EXPECT_TRUE(passed) -+ << "A:\n" << tensor_A.host_view() << "\n\n" -+ << "B:\n" << tensor_B.host_view() << "\n\n" -+ << "C:\n" << tensor_C.host_view() << "\n\n" -+ << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" -+ << "Computed:\n" << tensor_D_computed.host_view() << std::endl; -+ -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace gemm -+} // namespace test -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/batched_gemv.cu b/3rdparty/cutlass/test/unit/gemm/threadblock/batched_gemv.cu -new file mode 100644 -index 0000000..28b49f4 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/batched_gemv.cu -@@ -0,0 +1,646 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for threadblock level GEMV -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/tensor_ref.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "cutlass/gemm/threadblock/gemv.h" -+#include "cutlass/gemm/threadblock/default_gemv_core.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+__global__ void batched_gemv_threadblock_test_kernel( -+ cutlass::gemm::GemmCoord problem_size, -+ LongIndex stride_a, -+ LongIndex stride_b, -+ LongIndex stride_c, -+ RefA ref_A, -+ RefB ref_B, -+ RefC ref_C -+ ) { -+ -+ typename Gemv::IteratorA::TensorCoord threadblock_offset_A(0, 0); -+ typename Gemv::IteratorB::TensorCoord threadblock_offset_B(0, 0); -+ typename Gemv::IteratorB::TensorCoord threadblock_offset_C(0, 0); -+ -+ // Move to the right batches for these threads -+ ref_A.add_pointer_offset(threadIdx.y * stride_a); -+ ref_B.add_pointer_offset(threadIdx.y * stride_b); -+ ref_C.add_pointer_offset(threadIdx.y * stride_c); -+ -+ // Construct iterators to A and B operands -+ typename Gemv::IteratorA::Params params_A(ref_A.layout()); -+ typename Gemv::IteratorA iterator_A(params_A, ref_A.data(), { problem_size.m(), problem_size.k() }, 0, threadblock_offset_A); -+ typename Gemv::IteratorB::Params params_B(ref_B.layout()); -+ typename Gemv::IteratorB iterator_B(params_B, ref_B.data(), { problem_size.k(), problem_size.n() }, threadIdx.x, threadblock_offset_B); -+ -+ Gemv gemv; -+ -+ typename Gemv::FragmentC accum; -+ accum.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ gemv(problem_size, accum, iterator_A, iterator_B, accum); -+ -+ // IteratorC is PitchLinear<> assumes n() contiguous -+ typename Gemv::IteratorC::Params params_C(ref_C.layout()); -+ typename Gemv::IteratorC iterator_C(params_C, ref_C.data(), { problem_size.m(), problem_size.n() }, threadIdx.x, threadblock_offset_C); -+ iterator_C.store(accum); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+void batched_gemv_threadblock_test(cutlass::gemm::GemmCoord problem_size, int num_batch) -+{ -+ using Shape = Shape_; -+ using ElementA = ElementAB_; -+ using LayoutA = LayoutA_; -+ using ElementB = ElementAB_; -+ using LayoutB = LayoutB_; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using ThreadShape = cutlass::gemm::GemmShape<1, THREAD_N, THREAD_K>; -+ -+ using Core = typename cutlass::gemm::threadblock::DefaultGemvCore< -+ Shape, -+ ThreadShape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC -+ >; -+ -+ if (DEBUG) -+ { -+ num_batch = 1; -+ } -+ -+ using Mma = cutlass::gemm::threadblock::Gemv; -+ -+ // Create host tensors that will be the backing store for the batches -+ // Note that no device memory is initially allocated -+ cutlass::HostTensor matrix_A({problem_size.m(), problem_size.k()}, false); -+ cutlass::HostTensor matrix_B({problem_size.k(), problem_size.n()}, false); -+ cutlass::HostTensor matrix_C_computed({problem_size.m(), problem_size.n()}, false); -+ cutlass::HostTensor matrix_C_reference({problem_size.m(), problem_size.n()}, false); -+ -+ // Reserve memory for the batch of tensors -+ matrix_A.reserve(problem_size.m()*problem_size.k()*num_batch); -+ matrix_B.reserve(problem_size.n()*problem_size.k()*num_batch); -+ matrix_C_computed.reserve(problem_size.m()*problem_size.n()*num_batch); -+ matrix_C_reference.reserve(problem_size.m()*problem_size.n()*num_batch, false); -+ -+ // Fill eatch tensor batch -+ const int seed = 6834; -+ for (int b = 0; b < num_batch; b++) -+ { -+ if(DEBUG) -+ { -+ cutlass::reference::host::BlockFillSequential( -+ matrix_A.host_data_ptr_offset(b*matrix_A.capacity()), matrix_A.capacity()); -+ cutlass::reference::host::BlockFillSequential( -+ matrix_B.host_data_ptr_offset(b*matrix_B.capacity()), matrix_B.capacity()); -+ } -+ else -+ { -+ cutlass::reference::host::TensorFillRandomUniform( -+ matrix_A.host_view(b*matrix_A.capacity()), -+ seed + 1660, -+ 8, -+ -8, -+ 0 -+ ); -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ matrix_B.host_view(b*matrix_B.capacity()), -+ seed + 1880, -+ 8, -+ -8, -+ 0 -+ ); -+ } -+ -+ cutlass::reference::host::TensorFill(matrix_C_computed.host_view(b*matrix_C_computed.capacity())); -+ cutlass::reference::host::TensorFill(matrix_C_reference.host_view(b*matrix_C_reference.capacity())); -+ } -+ -+ matrix_A.sync_device(); -+ matrix_B.sync_device(); -+ matrix_C_computed.sync_device(); -+ -+ dim3 grid(1, 1); // only 1 CTA is used -+ dim3 block(Shape::kN / THREAD_N, num_batch, 1); -+ -+ #if 0 -+ printf("block dim = %d x %d\n", block.x, block.y); -+ #endif -+ -+ // Some sanity checks -+ EXPECT_TRUE( problem_size.n() % THREAD_N == 0 ); -+ EXPECT_TRUE( block.x*block.y <= MAX_THREADS_PER_BLOCK ); -+ -+ test::gemm::threadblock::batched_gemv_threadblock_test_kernel<<< grid, block >>>( -+ problem_size, -+ matrix_A.capacity(), -+ matrix_B.capacity(), -+ matrix_C_computed.capacity(), -+ matrix_A.device_ref(), -+ matrix_B.device_ref(), -+ matrix_C_computed.device_ref() -+ ); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ EXPECT_EQ(result, cudaSuccess) << " kernel error: " << cudaGetErrorString(result); -+ -+ matrix_C_computed.sync_host(); -+ -+ // Compute the batched gemms -+ for (int b = 0; b < num_batch; b++) -+ { -+ -+ cutlass::reference::host::Gemm reference_gemm; -+ -+ reference_gemm( -+ problem_size.mnk(), -+ ElementC(1), -+ matrix_A.host_ref(b*matrix_A.capacity()), -+ matrix_B.host_ref(b*matrix_B.capacity()), -+ ElementC(0), -+ matrix_C_reference.host_ref(b*matrix_C_computed.capacity()) -+ ); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ matrix_C_computed.host_view(b*matrix_C_computed.capacity()), -+ matrix_C_reference.host_view(b*matrix_C_reference.capacity())); -+ -+ EXPECT_TRUE(passed) -+ //<< "A:\n" << matrix_A.host_view() << "\n" -+ //<< "B:\n" << matrix_B.host_view() << "\n" -+ << "Batch: " << b << "\n" -+ << "Reference:\n" << matrix_C_reference.host_view(b*matrix_C_reference.capacity()) << "\n" -+ << "Computed:\n" << matrix_C_computed.host_view(b*matrix_C_computed.capacity()) << "\n"; -+ } -+} -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// A: ColumnMajor -+// B: RowMajor -+// C: ColumnMajor -+ -+TEST(SM50_batched_gemv_threadblock, 4x1x64x64_crc_fp32_fp32_2N_2K) { -+ -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 64, 64); -+ const int num_batch = 4; -+ const int THREAD_N = 2; -+ const int THREAD_K = 2; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 64, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 5x1x128x128_crc_fp32_fp32_4N_4K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 128, 128); -+ const int num_batch = 5; -+ const int THREAD_N = 4; -+ const int THREAD_K = 4; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 128, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 16x1x17x64_crc_fp32_fp32_1N_4K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 17, 64); -+ const int num_batch = 16; -+ const int THREAD_N = 1; -+ const int THREAD_K = 4; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 32, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 4x1x64x64_crc_fp16_fp32_2N_2K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 64, 64); -+ const int num_batch = 4; -+ const int THREAD_N = 2; -+ const int THREAD_K = 2; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 64, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 4x1x64x64_crc_fp16_fp32_2N_8K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 64, 64); -+ const int num_batch = 4; -+ const int THREAD_N = 2; -+ const int THREAD_K = 8; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 64, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 16x1x17x64_crc_fp16_fp32_1N_4K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 17, 64); -+ const int num_batch = 16; -+ const int THREAD_N = 1; -+ const int THREAD_K = 4; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 32, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 4x1x64x64_crc_i8_i32_2N_4K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 64, 64); -+ const int num_batch = 4; -+ const int THREAD_N = 2; -+ const int THREAD_K = 4; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 128, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 16x1x17x64_crc_i8_i32_1N_4K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 17, 64); -+ const int num_batch = 16; -+ const int THREAD_N = 1; -+ const int THREAD_K = 4; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 32, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+// A: RowMajor -+// B: ColumnMajor -+// C: RowMajor -+ -+TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcr_fp32_fp32_2N_2K) { -+ -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 64, 64); -+ const int num_batch = 4; -+ const int THREAD_N = 2; -+ const int THREAD_K = 2; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 64, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 5x1x128x128_rcr_fp32_fp32_4N_4K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 128, 128); -+ const int num_batch = 5; -+ const int THREAD_N = 4; -+ const int THREAD_K = 4; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 128, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 16x1x17x64_rcr_fp32_fp32_1N_4K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 17, 64); -+ const int num_batch = 16; -+ const int THREAD_N = 1; -+ const int THREAD_K = 4; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 32, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcr_fp16_fp32_2N_2K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 64, 64); -+ const int num_batch = 4; -+ const int THREAD_N = 2; -+ const int THREAD_K = 2; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 64, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcr_fp16_fp32_2N_8K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 64, 64); -+ const int num_batch = 4; -+ const int THREAD_N = 2; -+ const int THREAD_K = 8; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 64, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 16x1x17x64_rcr_fp16_fp32_1N_4K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 17, 64); -+ const int num_batch = 16; -+ const int THREAD_N = 1; -+ const int THREAD_K = 4; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 32, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcr_i8_i32_2N_4K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 64, 64); -+ const int num_batch = 4; -+ const int THREAD_N = 2; -+ const int THREAD_K = 4; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 128, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 16x1x17x64_rcr_i8_i32_1N_4K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 17, 64); -+ const int num_batch = 16; -+ const int THREAD_N = 1; -+ const int THREAD_K = 4; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 32, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+// A: RowMajor -+// B: ColumnMajor -+// C: ColumnMajor -+ -+TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcc_fp32_fp32_2N_2K) { -+ -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 64, 64); -+ const int num_batch = 4; -+ const int THREAD_N = 2; -+ const int THREAD_K = 2; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 64, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 5x1x128x128_rcc_fp32_fp32_4N_4K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 128, 128); -+ const int num_batch = 5; -+ const int THREAD_N = 4; -+ const int THREAD_K = 4; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 128, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 16x1x17x64_rcc_fp32_fp32_1N_4K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 17, 64); -+ const int num_batch = 16; -+ const int THREAD_N = 1; -+ const int THREAD_K = 4; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 32, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcc_fp16_fp32_2N_2K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 64, 64); -+ const int num_batch = 4; -+ const int THREAD_N = 2; -+ const int THREAD_K = 2; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 64, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcc_fp16_fp32_2N_8K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 64, 64); -+ const int num_batch = 4; -+ const int THREAD_N = 2; -+ const int THREAD_K = 8; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 64, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 16x1x17x64_rcc_fp16_fp32_1N_4K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 17, 64); -+ const int num_batch = 16; -+ const int THREAD_N = 1; -+ const int THREAD_K = 4; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 32, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcc_i8_i32_2N_4K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 64, 64); -+ const int num_batch = 4; -+ const int THREAD_N = 2; -+ const int THREAD_K = 4; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 128, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -+ -+TEST(SM50_batched_gemv_threadblock, 16x1x17x64_rcc_i8_i32_1N_4K) { -+ using namespace test::gemm::threadblock; -+ cutlass::gemm::GemmCoord problem_size(1, 17, 64); -+ const int num_batch = 16; -+ const int THREAD_N = 1; -+ const int THREAD_K = 4; -+ -+ using Shape = cutlass::gemm::GemmShape<1, 32, THREAD_K>; -+ batched_gemv_threadblock_test(problem_size, num_batch); -+} -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/epilogue_workspace.cu b/3rdparty/cutlass/test/unit/gemm/threadblock/epilogue_workspace.cu -new file mode 100644 -index 0000000..7e08723 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/epilogue_workspace.cu -@@ -0,0 +1,130 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/epilogue/epilogue_workspace.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Kernel computes accumulator data and stores it out -+template -+__global__ void kernel_epilogue_workspace(typename Epilogue::Params params) { -+ -+ __shared__ typename Epilogue::SharedStorage shared_storage; -+ -+ int warp_id = threadIdx.y; -+ int lane_id = threadIdx.x; -+ -+ Epilogue epilogue(params, shared_storage, warp_id, lane_id); -+ -+ // -+ // Initialize accumulator tile -+ // -+ typename Epilogue::FragmentC accum; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Epilogue::FragmentC::kElements; ++i) { -+ accum[i] = Element(warp_id * blockDim.x + lane_id); -+ } -+ -+ // -+ // Efficient epilogue -+ // -+ -+ cutlass::GemmCoord tb_tile_coord{blockIdx.x, blockIdx.y, 0}; -+ -+ cutlass::GemmCoord problem_size = -+ tb_tile_coord * -+ cutlass::GemmCoord{Epilogue::Shape::kM, Epilogue::Shape::kN, 1}; -+ -+ // Store accumulators -+ epilogue( -+ problem_size, -+ tb_tile_coord, -+ accum); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_epilogue_workspace, tensor_op_128x128_64x64) { -+ -+ // -+ // Define an instance of the epilogue and see if it works -+ // -+ static int const kWarpCount = 4; -+ static int const kWarpSize = 32; -+ -+ using Shape = cutlass::MatrixShape<128, 128>; -+ using FragmentC = cutlass::Array; -+ -+ using Epilogue = cutlass::gemm::threadblock::EpilogueWorkspace< -+ Shape, -+ kWarpCount, -+ FragmentC -+ >; -+ -+ typename Epilogue::Params params( -+ -+ ); -+ -+ // Launch the kernel -+ dim3 grid(1,1); -+ dim3 block(kWarpSize, kWarpCount); -+ -+ test::gemm::threadblock::kernel_epilogue_workspace<<< grid, block >>>( -+ params -+ ); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ EXPECT_EQ(result, cudaSuccess) << "Kernel launch error - " << cudaGetErrorString(result); -+ -+ // -+ // -+ // -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/mma_multistage.cu b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_multistage.cu -new file mode 100644 -index 0000000..8025637 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_multistage.cu -@@ -0,0 +1,3835 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Unit tests for threadblock-level GEMM -+*/ -+ -+#include "mma_multistage_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ tensor_op_64x64x64_64x64x64_16x8x16_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ tensor_op_128x64x64_64x32x64_16x8x16_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ tensor_op_64x128x64_32x64x64_16x8x16_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ tensor_op_128x128x64_64x64x64_16x8x16_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ multicta_256x256x384_128x128x64_64x64x64_16x8x16_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 384); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ multicta_512x256x384_256x128x64_64x64x64_16x8x16_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 384); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ tensor_op_64x64x32_64x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ tensor_op_128x64x32_64x32x32_16x8x16_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ tensor_op_64x128x32_32x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ tensor_op_128x128x32_64x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 384); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ multicta_256x256x384_128x128x32_64x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 384); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ multicta_512x256x768_256x128x32_64x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 768); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ tensor_op_64x64x32_64x64x32_16x8x8_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ tensor_op_128x64x32_64x32x32_16x8x8_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ tensor_op_64x128x32_32x64x32_16x8x8_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ tensor_op_128x128x32_64x64x32_16x8x8_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ multicta_256x256x192_128x128x32_64x64x32_16x8x8_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 192); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ multicta_512x256x384_256x128x32_64x64x32_16x8x8_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 192); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ tensor_op_64x64x16_64x64x16_16x8x8_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ tensor_op_128x64x16_64x32x16_16x8x8_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ tensor_op_64x128x16_32x64x16_16x8x8_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ tensor_op_128x128x16_64x64x16_16x8x8_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ multicta_256x256x192_128x128x16_64x64x16_16x8x8_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 192); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ multicta_512x256x384_256x128x16_64x64x16_16x8x8_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 384); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x64_64x64x64_16x8x16_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x64_32x32x64_16x8x16_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x64x64_64x32x64_16x8x16_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x128x64_32x64x64_16x8x16_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x128x64_64x64x64_16x8x16_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 384); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_256x256x384_128x128x64_64x64x64_16x8x16_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 384); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_512x256x768_256x128x64_64x64x64_16x8x16_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 768); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x32_64x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x32_32x32x32_16x8x16_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x64x32_64x32x32_16x8x16_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x128x32_32x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x128x32_64x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 384); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_256x256x384_128x128x32_64x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 384); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_512x256x768_256x128x32_64x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 768); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x32_64x64x32_16x8x8_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x32_32x32x32_16x8x8_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x64x32_64x32x32_16x8x8_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x128x32_32x64x32_16x8x8_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x128x32_64x64x32_16x8x8_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_256x256x192_128x128x32_64x64x32_16x8x8_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 192); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_512x256x192_256x128x32_64x64x32_16x8x8_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 192); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x16_64x64x16_16x8x8_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x16_32x32x16_16x8x8_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x64x16_64x32x16_16x8x8_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x128x16_32x64x16_16x8x8_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x128x16_64x64x16_16x8x8_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_256x256x192_128x128x16_64x64x16_16x8x8_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 192); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_512x256x192_256x128x16_64x64x16_16x8x8_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 192); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x128_64x64x128_16x8x32_3stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x128_32x32x128_16x8x32_3stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x64x128_64x32x128_16x8x32_3stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x128x128_32x64x128_16x8x32_3stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x128x128_64x64x128_16x8x32_3stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_256x256x768_128x128x128_64x64x128_16x8x32_3stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 768); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_512x256x768_256x128x128_64x64x128_16x8x32_3stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 768); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x64_64x64x64_16x8x32_4stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x64_32x32x64_16x8x32_4stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x64x64_64x32x64_16x8x32_4stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x128x64_32x64x64_16x8x32_4stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x128x64_64x64x64_16x8x32_4stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_256x256x768_128x128x64_64x64x64_16x8x32_4stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 768); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_512x256x768_256x128x64_64x64x64_16x8x32_4stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 768); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x256_64x64x256_16x8x64_3stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x256_32x32x256_16x8x64_3stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x64x256_64x32x256_16x8x64_3stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x256x256_32x64x256_16x8x64_3stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x256x256_64x64x256_16x8x64_3stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_256x256x1536_128x256x256_64x64x256_16x8x64_3stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 1536); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_512x256x1536_256x256x256_64x64x256_16x8x64_3stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 1536); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x128_64x64x128_16x8x64_4stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x128_32x32x128_16x8x64_4stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x64x128_64x32x128_16x8x64_4stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x256x128_32x64x128_16x8x64_4stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x256x128_64x64x128_16x8x64_4stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_256x256x1536_128x256x128_64x64x128_16x8x64_4stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 1536); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_512x256x1536_256x256x128_64x64x128_16x8x64_4stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 1536); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x1024_64x64x1024_16x8x256_3stage) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 4096); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 1024>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 1024>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x1024_32x32x1024_16x8x256_3stage) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 4096); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 1024>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 1024>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x64x1024_64x32x1024_16x8x256_3stage) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 4096); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 1024>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 1024>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x1024x1024_32x64x1024_16x8x256_3stage) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 4096); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 1024>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 1024>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x1024x1024_64x64x1024_16x8x256_3stage) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 4096); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 1024>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 1024>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_256x256x6144_128x1024x1024_64x64x1024_16x8x256_3stage) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 6144); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 1024>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 1024>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_512x256x6144_256x1024x1024_64x64x1024_16x8x256_3stage) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 6144); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 1024>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 1024>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x512_64x64x512_16x8x256_4stage) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 4096); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x64x512_32x32x512_16x8x256_4stage) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 4096); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x64x512_64x32x512_16x8x256_4stage) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 4096); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_64x128x512_32x64x512_16x8x256_4stage) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 4096); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ tensor_op_128x128x512_64x64x512_16x8x256_4stage) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 4096); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_256x256x6144_128x128x512_64x64x512_16x8x256_4stage) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 6144); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_512x256x6144_256x128x512_64x64x512_16x8x256_4stage) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 6144); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ tensor_op_64x64x16_32x64x16_8x8x4_3stage) { -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 16); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 2, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k()) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ tensor_op_128x128x16_32x64x16_8x8x4_3stage) { -+ using ElementA = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 64); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k()) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_interleaved, -+ tensor_op_64x128x64_32x64x64_16x8x32_3stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<32>; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<32>; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_interleaved, -+ tensor_op_128x128x64_64x64x64_16x8x32_3stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<32>; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<32>; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_interleaved, -+ multicta_256x256x384_128x128x64_64x64x64_16x8x32_3stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<32>; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<32>; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 384); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_interleaved, -+ multicta_512x256x384_256x128x64_64x64x64_16x8x32_3stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<32>; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<32>; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 384); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_interleaved, -+ tensor_op_64x128x128_32x64x128_16x8x64_3stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<64>; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<64>; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_interleaved, -+ tensor_op_128x128x128_64x64x128_16x8x64_3stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<64>; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<64>; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_interleaved, -+ multicta_256x256x768_128x128x128_64x64x128_16x8x64_3stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<64>; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<64>; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 768); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_interleaved, -+ multicta_512x256x1536_256x128x128_64x64x128_16x8x64_3stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<64>; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<64>; -+ using ElementC = int; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 1536); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise_f64, -+ tensor_op_32x32x16_16x16x16_8x8x4_4stage) { -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 32, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k()) -+ .run(grid, block); -+} -+ -+TEST(SM80_gemm_threadblock_crosswise_f64, -+ tensor_op_64x64x16_32x32x16_8x8x4_4stage) { -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k()) -+ .run(grid, block); -+} -+ -+TEST(SM80_gemm_threadblock_crosswise_f64, -+ tensor_op_64x128x16_32x64x16_8x8x4_4stage) { -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k()) -+ .run(grid, block); -+} -+ -+TEST(SM80_gemm_threadblock_crosswise_f64, -+ tensor_op_128x64x16_64x32x16_8x8x4_4stage) { -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k()) -+ .run(grid, block); -+} -+ -+TEST(SM80_gemm_threadblock_crosswise_f64, -+ tensor_op_128x128x16_32x64x16_8x8x4_3stage) { -+ using ElementA = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = double; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = double; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k()) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/mma_multistage_slicedk.cu b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_multistage_slicedk.cu -new file mode 100644 -index 0000000..7418732 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_multistage_slicedk.cu -@@ -0,0 +1,111 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Unit tests for CTA-level GEMM specifically for sliced-k kernels (SM_61 and SM_75) -+*/ -+ -+#include "mma_multistage_testbed_slicedk.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// Tensor Op GEMM for SM_80 -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous_sliced, tensor_op_128x64x256_tb128x64x64_warp64x64x32_16x8x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+TEST(SM80_gemm_threadblock_crosswise_sliced, tensor_op_128x64x256_tb128x64x64_warp64x64x32_16x8x16) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/mma_multistage_sparse.cu b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_multistage_sparse.cu -new file mode 100644 -index 0000000..4bb98cd ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_multistage_sparse.cu -@@ -0,0 +1,2703 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Unit tests for threadblock-level GEMM -+*/ -+ -+#include "mma_multistage_sparse_testbed.h" -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ tensor_op_64x64x64_64x64x64_16x8x32_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ tensor_op_64x64x64_32x32x64_16x8x32_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ tensor_op_128x64x64_64x32x64_16x8x32_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ tensor_op_64x128x64_32x64x64_16x8x32_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ tensor_op_128x128x64_64x64x64_16x8x32_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ multicta_256x256x768_128x128x64_64x64x64_16x8x32_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 768); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ multicta_512x256x768_256x128x64_64x64x64_16x8x32_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 768); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x64x64_64x64x64_16x8x32_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x64x64_32x32x64_16x8x32_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_128x64x64_64x32x64_16x8x32_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x128x64_32x64x64_16x8x32_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_128x128x64_64x64x64_16x8x32_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ multicta_256x256x768_128x128x64_64x64x64_16x8x32_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 768); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ multicta_512x256x768_256x128x64_64x64x64_16x8x32_4stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 768); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ tensor_op_64x64x128_64x64x128_16x8x32_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ tensor_op_128x64x128_64x32x128_16x8x32_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ tensor_op_64x128x128_32x64x128_16x8x32_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ tensor_op_128x128x128_64x32x128_16x8x32_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ multicta_256x256x768_128x128x128_64x32x128_16x8x32_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 768); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x64x128_64x64x128_16x8x32_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x64x128_32x32x128_16x8x32_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_128x64x128_64x32x128_16x8x32_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x128x128_32x64x128_16x8x32_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_128x128x128_64x32x128_16x8x32_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 512); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ multicta_256x256x768_128x128x128_64x32x128_16x8x32_3stage) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 768); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ tensor_op_64x64x32_64x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ tensor_op_64x64x32_32x32x32_16x8x16_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ tensor_op_128x64x32_64x32x32_16x8x16_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ tensor_op_64x128x32_32x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ tensor_op_128x128x32_64x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ multicta_256x256x384_128x128x32_64x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 384); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ multicta_512x256x384_256x128x32_64x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 384); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::SparseTestbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x64x32_64x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::SparseTestbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x64x32_32x32x32_16x8x16_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_128x64x32_64x32x32_16x8x16_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x128x32_32x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_128x128x32_64x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ multicta_256x256x384_128x128x32_64x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 384); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ multicta_512x256x384_256x128x32_64x64x32_16x8x16_4stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 384); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::SparseTestbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ tensor_op_64x64x64_64x64x64_16x8x16_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ tensor_op_128x64x64_64x32x64_16x8x16_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ tensor_op_64x128x64_32x64x64_16x8x16_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ tensor_op_128x128x64_64x32x64_16x8x16_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_congruous, -+ multicta_256x256x384_128x128x64_64x32x64_16x8x16_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 384); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x64x64_64x64x64_16x8x16_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x64x64_32x32x64_16x8x16_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_128x64x64_64x32x64_16x8x16_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x128x64_32x64x64_16x8x16_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_128x128x64_64x32x64_16x8x16_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ multicta_256x256x384_128x128x64_64x32x64_16x8x16_3stage) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 384); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x64x128_64x64x128_16x8x64_4stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x64x128_32x32x128_16x8x64_4stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_128x64x128_64x32x128_16x8x64_4stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x128x128_32x64x128_16x8x64_4stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_128x128x128_64x64x128_16x8x64_4stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ multicta_256x256x1536_128x128x128_64x64x128_16x8x64_4stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 1536); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ multicta_512x256x1536_256x128x128_64x64x128_16x8x64_4stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 1536); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x64x256_64x64x256_16x8x64_3stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x64x256_32x32x256_16x8x64_3stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_128x64x256_64x32x256_16x8x64_3stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x128x256_32x64x256_16x8x64_3stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_128x128x256_64x32x256_16x8x64_3stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 1024); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ multicta_256x256x1536_128x128x256_64x32x256_16x8x64_3stage) { -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 1536); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x64x256_64x64x256_16x8x128_4stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 2048); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x64x256_32x32x256_16x8x128_4stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 2048); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_128x64x256_64x32x256_16x8x128_4stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 2048); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x128x256_32x64x256_16x8x128_4stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 2048); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_128x128x256_64x64x256_16x8x128_4stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 2048); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ multicta_256x256x3072_128x128x256_64x64x256_16x8x128_4stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 3072); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ multicta_512x256x3072_256x128x256_64x64x256_16x8x128_4stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 3072); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 256>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 4; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x64x512_64x64x512_16x8x128_3stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 2048); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x64x512_32x32x512_16x8x128_3stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 2048); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_128x64x512_64x32x512_16x8x128_3stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 2048); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_64x128x512_32x64x512_16x8x128_3stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 2048); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ tensor_op_128x128x512_64x32x512_16x8x128_3stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 2048); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_sparse_gemm_threadblock_crosswise, -+ multicta_256x256x3072_128x128x512_64x32x512_16x8x128_3stage) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 3072); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, -+ Stages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::SparseTestbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h -new file mode 100644 -index 0000000..6e14745 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h -@@ -0,0 +1,438 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit testbed for kernel-level GEMM -+*/ -+ -+#pragma once -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/array.h" -+#include "cutlass/core_io.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/host_reorder.h" -+#include "cutlass/util/host_uncompress.h" -+ -+namespace test { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template -+__global__ void kernel_multistage_mma_sparse(cutlass::gemm::GemmCoord problem_size, -+ typename Mma::IteratorA::Params params_A, -+ typename Mma::IteratorA::TensorRef ref_A, -+ typename Mma::IteratorB::Params params_B, -+ typename Mma::IteratorB::TensorRef ref_B, -+ typename Mma::ElementC *ptr_C, -+ typename Mma::LayoutC::Stride::Index ldc, -+ typename Mma::IteratorE::Params params_E, -+ typename Mma::IteratorE::TensorRef ref_E) { -+ // Shared storage needed by threadblock-scoped matrix multiply- -+ // Dynamic shared memory base pointer -+ extern __shared__ int GemmSharedStorageBase[]; -+ -+ // Declare pointer to dynamic shared memory. -+ typename Mma::SharedStorage *shared_storage = -+ reinterpret_cast(GemmSharedStorageBase); -+ -+ // Compute threadblock location -+ cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), -+ 0}; -+ -+ cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, -+ tb_tile_offset.k() / Mma::kSparse}; -+ -+ cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), -+ tb_tile_offset.n() * Mma::Shape::kN}; -+ -+ cutlass::MatrixCoord tb_offset_E{tb_tile_offset.m() * Mma::Shape::kM, -+ tb_tile_offset.k() / Mma::kSparse}; -+ -+ // Compute position within threadblock -+ int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A(params_A, ref_A.data(), -+ {problem_size.m(), problem_size.k() / Mma::kSparse}, -+ tb_thread_id, tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B(params_B, ref_B.data(), -+ {problem_size.k(), problem_size.n()}, -+ tb_thread_id, tb_offset_B); -+ -+ typename Mma::IteratorE iterator_E( -+ params_E, ref_E.data(), -+ {problem_size.m(), -+ problem_size.k() / Mma::kSparse / Mma::kElementsPerElementE}, -+ tb_thread_id, tb_offset_E); -+ -+ int warp_id = __shfl_sync(0xffffffff, threadIdx.y, 0); -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(*shared_storage, tb_thread_id, warp_id, threadIdx.x); -+ -+ typename Mma::FragmentC accum; -+ -+ accum.clear(); -+ -+ int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma(gemm_k_iterations, accum, iterator_A, iterator_B, iterator_E, accum); -+ -+ // Output results -+ typename Mma::Operator::IteratorC iterator_C({ptr_C, ldc}, threadIdx.x); -+ -+ iterator_C.add_tile_offset( -+ {(tb_tile_offset.m() * Mma::WarpCount::kM) + -+ (warp_id % Mma::WarpCount::kM), -+ (tb_tile_offset.n() * Mma::WarpCount::kN) + -+ (warp_id / Mma::WarpCount::kM)}); -+ -+ iterator_C.store(accum); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product -+template < -+ /// Threadblock-level matrix multiply-accumulate -+ typename MmaCore_> -+struct SparseTestbed { -+ /// Threadblock-level GEMM implementation -+ using MmaCore = MmaCore_; -+ using ThreadblockShape = typename MmaCore::Shape; -+ using WarpShape = typename MmaCore::WarpShape; -+ using InstructionShape = typename MmaCore::InstructionShape; -+ using ElementA = typename MmaCore::ElementA; -+ using LayoutA = typename MmaCore::LayoutA; -+ using ElementB = typename MmaCore::ElementB; -+ using LayoutB = typename MmaCore::LayoutB; -+ using ElementC = typename MmaCore::ElementC; -+ using LayoutC = typename MmaCore::LayoutC; -+ using ElementE = typename MmaCore::ElementE; -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using ThreadMapE = typename MmaCore::IteratorThreadMapE; -+ using AccessTypeA = cutlass::Array; -+ using AccessTypeB = cutlass::Array; -+ using AccessTypeE = cutlass::Array; -+ static int const Stages = MmaCore::kStages; -+ static cutlass::arch::CacheOperation::Kind const CacheOpA = -+ MmaCore::kCacheOpA; -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ MmaCore::kCacheOpB; -+ static cutlass::arch::CacheOperation::Kind const CacheOpE = -+ MmaCore::kCacheOpE; -+ -+ static int const Sparse = MmaCore::kSparse; -+ static int const MetaSizeInBits = MmaCore::kMetaSizeInBits; -+ static int const MaxID2 = MmaCore::kMaxID2; -+ -+ using LayoutE = cutlass::layout::RowMajor; -+ using ReorderedLayoutE = typename MmaCore::GmemLayoutE; -+ -+ static int const ElementsPerElementE = MmaCore::kElementsPerElementE; -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; -+ -+ // Define iterators over tiles from the E operand -+ using IteratorE = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementE, ReorderedLayoutE, 1, ThreadMapE, AccessTypeE>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using Mma = cutlass::gemm::threadblock::SparseMmaMultistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ CacheOpA, IteratorB, typename MmaCore::SmemIteratorB, CacheOpB, ElementC, -+ LayoutC, IteratorE, typename MmaCore::SmemIteratorE, CacheOpE, -+ typename MmaCore::MmaPolicy, Stages>; -+ -+ // -+ // Data members -+ // -+ -+ cutlass::HostTensor matrix_A; -+ cutlass::HostTensor matrix_A_uncompressed; -+ cutlass::HostTensor matrix_B; -+ cutlass::HostTensor matrix_C_computed; -+ cutlass::HostTensor matrix_C_reference; -+ cutlass::HostTensor matrix_E; -+ cutlass::HostTensor matrix_E_reordered; -+ -+ cutlass::gemm::GemmCoord problem_size; -+ float alpha, beta; -+ -+ // -+ // Methods -+ // -+ -+ /// Allocates workspace in device memory -+ SparseTestbed(int m, int n, int k, float alpha_ = float(1), float beta_ = float(0)) -+ : problem_size(m, n, k), alpha(alpha_), beta(beta_) { -+ matrix_A.reset(cutlass::make_Coord(m, k / Sparse)); -+ matrix_A_uncompressed.reset(cutlass::make_Coord(m, k)); -+ matrix_B.reset(cutlass::make_Coord(k, n)); -+ matrix_C_computed.reset(cutlass::make_Coord(m, n)); -+ matrix_C_reference.reset(cutlass::make_Coord(m, n), false); -+ matrix_E.reset(cutlass::make_Coord(m, k / Sparse / ElementsPerElementE)); -+ matrix_E_reordered.reset( -+ cutlass::make_Coord(m, k / Sparse / ElementsPerElementE)); -+ } -+ -+ /// Returns true if the CUDA device is sufficient to execute the kernel. -+ bool sufficient() const { -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ return true; -+ } -+ -+ /// Runs the test -+ bool run( -+ dim3 grid, dim3 block, -+ cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_E = cutlass::Distribution::Uniform) { -+ -+ // Waive the test -+ if (!sufficient()) { -+ return true; -+ } -+ -+ // -+ // initialize device memory -+ // -+ -+ if (init_A == cutlass::Distribution::Uniform) { -+ -+ int scope_max = 8; -+ int scope_min = -8; -+ -+ if (cutlass::sizeof_bits::value == 4) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (cutlass::sizeof_bits::value == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } -+ -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform( -+ matrix_A.host_view(), seed, scope_max, scope_min, 0); -+ } else if (init_A == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), -+ matrix_A.capacity()); -+ } else if (init_A == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ if (init_B == cutlass::Distribution::Uniform) { -+ -+ int scope_max = 8; -+ int scope_min = -8; -+ -+ if (cutlass::sizeof_bits::value == 4) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (cutlass::sizeof_bits::value == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } -+ -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform( -+ matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); -+ } else if (init_B == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), -+ matrix_B.capacity()); -+ } else if (init_B == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ cutlass::reference::host::TensorFill(matrix_C_computed.host_view()); -+ -+ cutlass::reference::host::TensorFill(matrix_C_reference.host_view()); -+ -+ if (init_E == cutlass::Distribution::Uniform) { -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomSparseMeta( -+ matrix_E.host_view(), seed, MetaSizeInBits); -+ } else if (init_E == cutlass::Distribution::Identity) { -+ uint32_t content = (MaxID2 == 1) ? 0x44444444 : 0x4444; -+ cutlass::reference::host::TensorFill(matrix_E.host_view(), -+ (ElementE)(content)); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ cutlass::reorder_meta(matrix_E_reordered.host_ref(), matrix_E.host_ref(), -+ {problem_size.m(), problem_size.n(), -+ problem_size.k() / Sparse / ElementsPerElementE}); -+ -+ matrix_A.sync_device(); -+ matrix_B.sync_device(); -+ matrix_C_computed.sync_device(); -+ matrix_E_reordered.sync_device(); -+ -+ typename IteratorA::Params params_A(matrix_A.layout()); -+ typename IteratorB::Params params_B(matrix_B.layout()); -+ typename IteratorE::Params params_E(matrix_E_reordered.layout()); -+ -+ cudaError_t result; -+ -+ int smem_size = int(sizeof(typename Mma::SharedStorage)); -+ if (smem_size >= (48 << 10)) { -+ result = cudaFuncSetAttribute( -+ test::gemm::threadblock::kernel_multistage_mma_sparse, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); -+ -+ if (result != cudaSuccess) { -+ return true; -+ } -+ -+ result = cudaFuncSetAttribute( -+ test::gemm::threadblock::kernel_multistage_mma_sparse, -+ cudaFuncAttributePreferredSharedMemoryCarveout, 100); -+ -+ if (result != cudaSuccess) { -+ return true; -+ } -+ } -+ -+ test::gemm::threadblock::kernel_multistage_mma_sparse -+ <<>>( -+ problem_size, params_A, matrix_A.device_ref(), params_B, -+ matrix_B.device_ref(), matrix_C_computed.device_data(), -+ matrix_C_computed.layout().stride(0), params_E, -+ matrix_E_reordered.device_ref()); -+ -+ // -+ // Check error code -+ // -+ -+ result = cudaDeviceSynchronize(); -+ EXPECT_EQ(result, cudaSuccess) -+ << " kernel error: " << cudaGetErrorString(result); -+ -+ matrix_C_computed.sync_host(); -+ -+ cutlass::uncompress(matrix_A_uncompressed.host_ref(), matrix_A.host_ref(), -+ matrix_E.host_ref(), problem_size.m(), -+ problem_size.k()); -+ -+ cutlass::reference::host::Gemm -+ reference_gemm; -+ -+ reference_gemm(problem_size, ElementC(alpha), -+ matrix_A_uncompressed.host_view(), matrix_B.host_view(), -+ ElementC(beta), matrix_C_reference.host_view()); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ matrix_C_computed.host_view(), matrix_C_reference.host_view()); -+ -+ EXPECT_TRUE(passed); -+ -+ if (!passed && CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ -+ std::cout -+ << __FILE__ << ":" << __LINE__ << " " -+ << "A:\n" << matrix_A.host_view() << "\n" -+ << "B:\n" << matrix_B.host_view() << "\n" -+ << "E:\n" << matrix_E.host_view() << "\n" -+ << "Reference:\n" -+ << matrix_C_reference.host_view() << "\n" -+ << "Computed:\n" -+ << matrix_C_computed.host_view() << "\n"; -+ } -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(matrix_C_reference.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(matrix_C_computed.host_view()), 0); -+ -+ return passed; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace test -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed.h b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed.h -new file mode 100644 -index 0000000..1e859b6 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed.h -@@ -0,0 +1,374 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit testbed for kernel-level GEMM -+*/ -+ -+#pragma once -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/array.h" -+#include "cutlass/core_io.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+namespace test { -+namespace gemm { -+namespace threadblock { -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+template -+__global__ void kernel_multistage_mma(cutlass::gemm::GemmCoord problem_size, -+ typename Mma::IteratorA::Params params_A, -+ typename Mma::IteratorA::TensorRef ref_A, -+ typename Mma::IteratorB::Params params_B, -+ typename Mma::IteratorB::TensorRef ref_B, -+ typename Mma::ElementC *ptr_C, -+ typename Mma::LayoutC::Stride::Index ldc) { -+ // Shared storage needed by threadblock-scoped matrix multiply-accumulate -+ -+ // Dynamic shared memory base pointer -+ extern __shared__ int GemmSharedStorageBase[]; -+ -+ // Declare pointer to dynamic shared memory. -+ typename Mma::SharedStorage *shared_storage = -+ reinterpret_cast(GemmSharedStorageBase); -+ -+ // Compute threadblock location -+ cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), -+ 0}; -+ -+ cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, -+ tb_tile_offset.k()}; -+ -+ cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), -+ tb_tile_offset.n() * Mma::Shape::kN}; -+ -+ // Compute position within threadblock -+ int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A(params_A, ref_A.data(), -+ {problem_size.m(), problem_size.k()}, -+ tb_thread_id, tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B(params_B, ref_B.data(), -+ {problem_size.k(), problem_size.n()}, -+ tb_thread_id, tb_offset_B); -+ -+ int warp_id = __shfl_sync(0xffffffff, threadIdx.y, 0); -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(*shared_storage, tb_thread_id, warp_id, threadIdx.x); -+ -+ typename Mma::FragmentC accum; -+ -+ accum.clear(); -+ -+ int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); -+ -+ // Output results -+ typename Mma::Operator::IteratorC iterator_C({ptr_C, ldc}, threadIdx.x); -+ -+ iterator_C.add_tile_offset( -+ {(tb_tile_offset.m() * Mma::WarpCount::kM) + -+ (warp_id % Mma::WarpCount::kM), -+ (tb_tile_offset.n() * Mma::WarpCount::kN) + -+ (warp_id / Mma::WarpCount::kM)}); -+ -+ iterator_C.store(accum); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product -+template < -+ /// Threadblock-level matrix multiply-accumulate -+ typename MmaCore_> -+struct Testbed { -+ /// Threadblock-level GEMM implementation -+ using MmaCore = MmaCore_; -+ using ThreadblockShape = typename MmaCore::Shape; -+ using WarpShape = typename MmaCore::WarpShape; -+ using InstructionShape = typename MmaCore::InstructionShape; -+ using ElementA = typename MmaCore::ElementA; -+ using LayoutA = typename MmaCore::LayoutA; -+ using ElementB = typename MmaCore::ElementB; -+ using LayoutB = typename MmaCore::LayoutB; -+ using ElementC = typename MmaCore::ElementC; -+ using LayoutC = typename MmaCore::LayoutC; -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeA = cutlass::Array; -+ using AccessTypeB = cutlass::Array; -+ static int const Stages = MmaCore::kStages; -+ static cutlass::arch::CacheOperation::Kind const CacheOpA = -+ MmaCore::kCacheOpA; -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ MmaCore::kCacheOpB; -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using Mma = cutlass::gemm::threadblock::MmaMultistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ CacheOpA, IteratorB, typename MmaCore::SmemIteratorB, CacheOpB, ElementC, -+ LayoutC, typename MmaCore::MmaPolicy, Stages>; -+ -+ // -+ // Data members -+ // -+ -+ cutlass::HostTensor matrix_A; -+ cutlass::HostTensor matrix_B; -+ cutlass::HostTensor matrix_C_computed; -+ cutlass::HostTensor matrix_C_reference; -+ -+ cutlass::gemm::GemmCoord problem_size; -+ float alpha, beta; -+ -+ // -+ // Methods -+ // -+ -+ /// Allocates workspace in device memory -+ Testbed(int m, int n, int k, float alpha_ = float(1), float beta_ = float(0)) -+ : problem_size(m, n, k), alpha(alpha_), beta(beta_) { -+ matrix_A.reset(cutlass::make_Coord(m, k)); -+ matrix_B.reset(cutlass::make_Coord(k, n)); -+ matrix_C_computed.reset(cutlass::make_Coord(m, n)); -+ matrix_C_reference.reset(cutlass::make_Coord(m, n), false); -+ } -+ -+ /// Returns true if the CUDA device is sufficient to execute the kernel. -+ bool sufficient() const { -+ -+ // -+ // Determine SMEM requirements and waive if not satisfied -+ // -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ return true; -+ } -+ -+ /// Runs the test -+ bool run( -+ dim3 grid, dim3 block, -+ cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { -+ -+ if (!sufficient()) { -+ return true; -+ } -+ -+ // -+ // initialize device memory -+ // -+ -+ if (init_A == cutlass::Distribution::Uniform) { -+ -+ int scope_max = 8; -+ int scope_min = -8; -+ -+ if (cutlass::sizeof_bits::value == 4) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (cutlass::sizeof_bits::value == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } -+ -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform( -+ matrix_A.host_view(), seed, scope_max, scope_min, 0); -+ } else if (init_A == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), -+ matrix_A.capacity()); -+ } else if (init_A == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ if (init_B == cutlass::Distribution::Uniform) { -+ -+ int scope_max = 8; -+ int scope_min = -8; -+ -+ if (cutlass::sizeof_bits::value == 4) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (cutlass::sizeof_bits::value == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } -+ -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform( -+ matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); -+ } else if (init_B == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), -+ matrix_B.capacity()); -+ } else if (init_B == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ cutlass::reference::host::TensorFill(matrix_C_computed.host_view()); -+ -+ cutlass::reference::host::TensorFill(matrix_C_reference.host_view()); -+ -+ matrix_A.sync_device(); -+ matrix_B.sync_device(); -+ matrix_C_computed.sync_device(); -+ -+ typename IteratorA::Params params_A(matrix_A.layout()); -+ typename IteratorB::Params params_B(matrix_B.layout()); -+ -+ cudaError_t result; -+ -+ int smem_size = int(sizeof(typename Mma::SharedStorage)); -+ if (smem_size >= (48 << 10)) { -+ result = cudaFuncSetAttribute( -+ test::gemm::threadblock::kernel_multistage_mma, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); -+ -+ if (result != cudaSuccess) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ -+ result = cudaFuncSetAttribute( -+ test::gemm::threadblock::kernel_multistage_mma, -+ cudaFuncAttributePreferredSharedMemoryCarveout, 100); -+ -+ if (result != cudaSuccess) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ } -+ -+ test::gemm::threadblock::kernel_multistage_mma -+ <<>>( -+ problem_size, params_A, matrix_A.device_ref(), params_B, -+ matrix_B.device_ref(), matrix_C_computed.device_data(), -+ matrix_C_computed.layout().stride(0)); -+ -+ // -+ // Check error code -+ // -+ -+ result = cudaDeviceSynchronize(); -+ EXPECT_EQ(result, cudaSuccess) -+ << " kernel error: " << cudaGetErrorString(result); -+ -+ matrix_C_computed.sync_host(); -+ -+ cutlass::reference::host::Gemm reference_gemm; -+ -+ reference_gemm( -+ problem_size, ElementC(alpha), matrix_A.host_view(), -+ matrix_B.host_view(), ElementC(beta), matrix_C_reference.host_view()); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ matrix_C_computed.host_view(), matrix_C_reference.host_view()); -+ -+ EXPECT_TRUE(passed); -+ -+ if (!passed && CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cout -+ << __FILE__ << ":" << __LINE__ << " " -+ << "A:\n" << matrix_A.host_view() << "\n" -+ << "B:\n" << matrix_B.host_view() << "\n" -+ << "Reference:\n" -+ << matrix_C_reference.host_view() << "\n" -+ << "Computed:\n" -+ << matrix_C_computed.host_view() << "\n"; -+ } -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(matrix_C_reference.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(matrix_C_computed.host_view()), 0); -+ -+ return passed; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace test -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed_slicedk.h b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed_slicedk.h -new file mode 100644 -index 0000000..a47a300 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed_slicedk.h -@@ -0,0 +1,389 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Unit testbed for kernel-level GEMM -+*/ -+ -+#pragma once -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/vector.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -+#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/platform/platform.h" -+ -+namespace test { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+__global__ void kernel_multistage_mma(cutlass::gemm::GemmCoord problem_size, -+ typename Mma::IteratorA::Params params_A, -+ typename Mma::IteratorA::TensorRef ref_A, -+ typename Mma::IteratorB::Params params_B, -+ typename Mma::IteratorB::TensorRef ref_B, -+ typename Mma::ElementC **ptr_C, -+ typename Mma::LayoutC::Stride::Index ldc) { -+ // Shared storage needed by threadblock-scoped matrix multiply-accumulate -+ -+ // Dynamic shared memory base pointer -+ extern __shared__ int GemmSharedStorageBase[]; -+ -+ // Declare pointer to dynamic shared memory. -+ typename Mma::SharedStorage *shared_storage = -+ reinterpret_cast(GemmSharedStorageBase); -+ -+ // Compute threadblock location -+ cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), -+ 0}; -+ -+ cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, -+ tb_tile_offset.k()}; -+ -+ cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), -+ tb_tile_offset.n() * Mma::Shape::kN}; -+ -+ // Compute position within threadblock -+ int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A(params_A, ref_A.data(), -+ {problem_size.m(), problem_size.k()}, -+ tb_thread_id, tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B(params_B, ref_B.data(), -+ {problem_size.k(), problem_size.n()}, -+ tb_thread_id, tb_offset_B); -+ -+ int warp_id = __shfl_sync(0xffffffff, threadIdx.y, 0); -+ int lane_id = threadIdx.x; -+ -+ int partitionsK_idx = warp_id / (Mma::WarpCount::kM * Mma::WarpCount::kN); -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(*shared_storage, tb_thread_id, warp_id, threadIdx.x); -+ -+ typename Mma::FragmentC accum; -+ -+ accum.clear(); -+ -+ int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); -+ -+ // Output results -+ typename Mma::Operator::IteratorC iterator_C({ptr_C[partitionsK_idx], ldc}, lane_id); -+ -+ int warp_idx_mn = warp_id % (Mma::WarpCount::kM * Mma::WarpCount::kN); -+ iterator_C.add_tile_offset( -+ {(tb_tile_offset.m() * Mma::WarpCount::kM) + -+ (warp_idx_mn % Mma::WarpCount::kM), -+ (tb_tile_offset.n() * Mma::WarpCount::kN) + -+ (warp_idx_mn / Mma::WarpCount::kM)}); -+ -+ iterator_C.store(accum); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product -+template < -+ /// Threadblock-level matrix multiply-accumulate -+ typename MmaCore_> -+struct Testbed { -+ /// Threadblock-level GEMM implementation -+ using MmaCore = MmaCore_; -+ using ThreadblockShape = typename MmaCore::Shape; -+ using WarpShape = typename MmaCore::WarpShape; -+ using InstructionShape = typename MmaCore::InstructionShape; -+ using ElementA = typename MmaCore::ElementA; -+ using LayoutA = typename MmaCore::LayoutA; -+ using ElementB = typename MmaCore::ElementB; -+ using LayoutB = typename MmaCore::LayoutB; -+ using ElementC = typename MmaCore::ElementC; -+ using LayoutC = typename MmaCore::LayoutC; -+ using ThreadMapA = typename MmaCore::IteratorThreadMapA; -+ using ThreadMapB = typename MmaCore::IteratorThreadMapB; -+ using AccessTypeA = cutlass::Array; -+ using AccessTypeB = cutlass::Array; -+ static int const Stages = MmaCore::kStages; -+ static cutlass::arch::CacheOperation::Kind const CacheOpA = -+ MmaCore::kCacheOpA; -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ MmaCore::kCacheOpB; -+ -+ // Define iterators over tiles from the A operand -+ using IteratorA = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = -+ cutlass::transform::threadblock::PredicatedTileAccessIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using Mma = cutlass::gemm::threadblock::MmaMultistage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, CacheOpA, -+ IteratorB, typename MmaCore::SmemIteratorB, CacheOpB, ElementC, LayoutC, -+ typename MmaCore::MmaPolicy, Stages>; -+ -+ static int const kPartitionsK = MmaCore::MmaPolicy::kPartitionsK; -+ -+ // -+ // Data members -+ // -+ -+ cutlass::HostTensor matrix_A; -+ cutlass::HostTensor matrix_B; -+ cutlass::HostTensor matrix_C_computed[kPartitionsK]; -+ cutlass::HostTensor matrix_C_reference; -+ cutlass::HostTensor matrix_C_pointers; -+ -+ cutlass::gemm::GemmCoord problem_size; -+ float alpha, beta; -+ -+ // -+ // Methods -+ // -+ -+ /// Allocates workspace in device memory -+ Testbed(int m, int n, int k, float alpha_ = float(1), float beta_ = float(0)) -+ : problem_size(m, n, k), alpha(alpha_), beta(beta_) { -+ matrix_A.reset(cutlass::make_Coord(m, k)); -+ matrix_B.reset(cutlass::make_Coord(k, n)); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int k = 0; k < kPartitionsK; k++) -+ matrix_C_computed[k].reset(cutlass::make_Coord(m, n)); -+ -+ matrix_C_reference.reset(cutlass::make_Coord(m, n), false); -+ matrix_C_pointers.reset(cutlass::Coord<1>(kPartitionsK)); -+ } -+ -+ /// Runs the test -+ bool run( -+ dim3 grid, dim3 block, -+ cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { -+ // -+ // initialize device memory -+ // -+ -+ if (init_A == cutlass::Distribution::Uniform) { -+ -+ int scope_max = 8; -+ int scope_min = -8; -+ -+ if (cutlass::sizeof_bits::value == 4) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (cutlass::sizeof_bits::value == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } -+ -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform( -+ matrix_A.host_view(), seed, scope_max, scope_min, 0); -+ } else if (init_A == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), -+ matrix_A.capacity()); -+ } else if (init_A == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ if (init_B == cutlass::Distribution::Uniform) { -+ -+ int scope_max = 8; -+ int scope_min = -8; -+ -+ if (cutlass::sizeof_bits::value == 4) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (cutlass::sizeof_bits::value == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } -+ -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform( -+ matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); -+ } else if (init_B == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), -+ matrix_B.capacity()); -+ } else if (init_B == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int k = 0; k < kPartitionsK; k++) -+ cutlass::reference::host::TensorFill(matrix_C_computed[k].host_view()); -+ -+ cutlass::reference::host::TensorFill(matrix_C_reference.host_view()); -+ -+ matrix_A.sync_device(); -+ matrix_B.sync_device(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int k = 0; k < kPartitionsK; k++) -+ matrix_C_computed[k].sync_device(); -+ -+ typename IteratorA::Params params_A(matrix_A.layout()); -+ typename IteratorB::Params params_B(matrix_B.layout()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int k = 0; k < kPartitionsK; k++) -+ matrix_C_pointers.at(cutlass::Coord<1>(k)) = matrix_C_computed[k].device_data(); -+ -+ matrix_C_pointers.sync_device(); -+ -+ cudaError_t result; -+ -+ int smem_size = int(sizeof(typename Mma::SharedStorage)); -+ if (smem_size >= (48 << 10)) { -+ result = cudaFuncSetAttribute( -+ test::gemm::threadblock::kernel_multistage_mma, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); -+ -+ EXPECT_EQ(result, cudaSuccess) -+ << " cudaFuncSetAttribute " -+ "cudaFuncAttributeMaxDynamicSharedMemorySize error: " -+ << cudaGetErrorString(result); -+ -+ result = cudaFuncSetAttribute( -+ test::gemm::threadblock::kernel_multistage_mma, -+ cudaFuncAttributePreferredSharedMemoryCarveout, 100); -+ -+ EXPECT_EQ(result, cudaSuccess) -+ << " cudaFuncSetAttribute " -+ "cudaFuncAttributePreferredSharedMemoryCarveout error: " -+ << cudaGetErrorString(result); -+ } -+ -+ test::gemm::threadblock::kernel_multistage_mma<<>>( -+ problem_size, params_A, matrix_A.device_ref(), params_B, -+ matrix_B.device_ref(), matrix_C_pointers.device_data(), -+ matrix_C_computed[0].layout().stride(0)); -+ -+ // -+ // Check error code -+ // -+ -+ result = cudaDeviceSynchronize(); -+ EXPECT_EQ(result, cudaSuccess) -+ << " kernel error: " << cudaGetErrorString(result); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int k = 0; k < kPartitionsK; k++) -+ matrix_C_computed[k].sync_host(); -+ -+ // TODO: this is temporary. it will be removed after slicing can de -+ // reduction -+ // -+ // Reduce matrix_C_computed -+ // -+ CUTLASS_PRAGMA_UNROLL -+ for(int k = 1; k < kPartitionsK; k++) { -+ CUTLASS_PRAGMA_UNROLL -+ for(int m = 0; m < matrix_C_computed[0].extent().row(); m++){ -+ CUTLASS_PRAGMA_UNROLL -+ for(int n = 0; n < matrix_C_computed[0].extent().column(); n++){ -+ matrix_C_computed[0].at({m, n}) += matrix_C_computed[k].at({m, n}); -+ } -+ } -+ } -+ -+ cutlass::reference::host::Gemm -+ reference_gemm; -+ -+ reference_gemm( -+ problem_size, ElementC(alpha), matrix_A.host_view(), -+ matrix_B.host_view(), ElementC(beta), matrix_C_reference.host_view()); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ matrix_C_computed[0].host_view(), matrix_C_reference.host_view()); -+ -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ std::ofstream output("mma_multistage_testbed_errors.txt"); -+ -+ output -+ << "A:\n" << matrix_A.host_view() << "\n" -+ << "B:\n" << matrix_B.host_view() << "\n" -+ << "Reference:\n" -+ << matrix_C_reference.host_view() << "\n" -+ << "Computed:\n" -+ << matrix_C_computed[0].host_view() << "\n"; -+ } -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace test -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_simt.cu b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_simt.cu -new file mode 100644 -index 0000000..506ca09 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_simt.cu -@@ -0,0 +1,1022 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "mma_pipelined_testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// sgemm_NT -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_sgemm, sgemm_nt_32x64x8_32x64x1) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<32, 64, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ float, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ float, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ float, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass, -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 64, 48); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM50_sgemm, sgemm_nt_64x64x8_32x64x1) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<64, 64, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ float, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ float, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ float, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 48); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 2, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM50_sgemm, sgemm_nt_32x128x8_32x64x1) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<32, 128, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ float, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ float, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ float, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 128, 48); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 2, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM50_sgemm, sgemm_nt_64x128x8_32x64x1) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<64, 128, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ float, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ float, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ float, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 16); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM50_sgemm, sgemm_nt_128x128x8_32x64x1) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<128, 128, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ float, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ float, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ float, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 48); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// dgemm_NN -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_dgemm, dgemm_nt_32x64x8_32x64x1) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<32, 64, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ double, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ double, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ double, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 64, 48); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM50_dgemm, dgemm_nt_64x64x8_32x64x1) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<64, 64, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ double, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ double, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ double, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 48); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 2, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM50_dgemm, dgemm_nt_32x128x8_32x64x1) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<32, 128, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ double, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ double, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ double, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 128, 48); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 2, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM50_dgemm, dgemm_nt_64x128x8_32x64x1) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<64, 128, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ double, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ double, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ double, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 16); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM50_dgemm, dgemm_nt_128x128x8_32x64x1) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<128, 128, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ double, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ double, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ double, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 48); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// igemm_NN -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_igemm, igemm_nt_32x64x8_32x64x1) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<32, 64, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ int, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ int, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 64, 48); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM50_igemm, igemm_nt_64x64x8_32x64x1) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<64, 64, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ int, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ int, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 48); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 2, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM50_igemm, igemm_nt_32x128x8_32x64x1) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<32, 128, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ int, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ int, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 128, 48); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 2, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM50_igemm, igemm_nt_64x128x8_32x64x1) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<64, 128, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ int, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ int, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 16); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM50_igemm, igemm_nt_128x128x8_32x64x1) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<128, 128, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ int, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ int, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 48); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// hgemm_NN -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_hgemm, hgemm_nt_32x64x8_32x64x1) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<32, 64, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ cutlass::half_t, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ cutlass::half_t, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ cutlass::half_t, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 64, 48); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM50_hgemm, hgemm_nt_64x64x8_32x64x1) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<64, 64, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ cutlass::half_t, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ cutlass::half_t, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ cutlass::half_t, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 48); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 2, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM50_hgemm, hgemm_nt_32x128x8_32x64x1) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<32, 128, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ cutlass::half_t, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ cutlass::half_t, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ cutlass::half_t, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 128, 48); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 2, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM50_hgemm, hgemm_nt_64x128x8_32x64x1) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<64, 128, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ cutlass::half_t, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ cutlass::half_t, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ cutlass::half_t, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 16); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM50_hgemm, hgemm_nt_128x128x8_32x64x1) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<128, 128, 8>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 1>, // InstructionShape, -+ cutlass::half_t, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ cutlass::half_t, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ cutlass::half_t, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 48); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// igemm_NT DP4A -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM61_igemm, igemm_int8_nt_64x64x16_64x64x4) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<64, 64, 16>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<64, 64, 16>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 32); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM61_igemm, igemm_int8_nt_64x64x32_64x64x4) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<64, 64, 32>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<64, 64, 32>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 4096); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM61_igemm, igemm_int8_nt_64x64x16_64x64x8) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<64, 64, 16>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<64, 64, 16>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 32); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM61_igemm, igemm_int8_nt_128x64x16_64x64x8) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<128, 64, 16>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<64, 64, 16>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 32); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 2, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM61_igemm, igemm_int8_nt_128x128x16_64x64x8) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<128, 128, 16>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<64, 64, 16>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 32); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM61_igemm, igemm_int8_nt_256x128x16_64x64x8) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<256, 256, 16>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<128, 64, 16>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 32); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM61_igemm, igemm_int8_nt_128x256x64_64x64x16) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<128, 256, 64>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<64, 64, 64>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 256, 64); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM61_igemm, igemm_int8_nt_256x128x64_64x64x16) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<256, 128, 64>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<64, 64, 64>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 128, 64); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM61_igemm, igemm_int8_tn_64x64x16_64x64x4) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<64, 64, 16>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<64, 64, 16>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::RowMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::ColumnMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 32); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM61_igemm, igemm_int8_tn_64x64x32_64x64x4) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<64, 64, 32>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<64, 64, 32>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::RowMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::ColumnMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 4096); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM61_igemm, igemm_int8_tn_64x64x16_64x64x8) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<64, 64, 16>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<64, 64, 16>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::RowMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::ColumnMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 32); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+TEST(SM61_igemm, igemm_int8_tn_128x64x16_64x64x8) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<128, 64, 16>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<64, 64, 16>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::RowMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::ColumnMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 32); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 2, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM61_igemm, igemm_int8_tn_128x128x16_64x64x8) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<128, 128, 16>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<64, 64, 16>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::RowMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::ColumnMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 32); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM61_igemm, igemm_int8_tn_256x128x16_64x64x8) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<256, 256, 16>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<128, 64, 16>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::RowMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::ColumnMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 32); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM61_igemm, igemm_int8_tn_128x256x64_64x64x16) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<128, 256, 64>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<64, 64, 64>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::RowMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::ColumnMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 256, 64); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM61_igemm, igemm_int8_tn_256x128x64_64x64x16) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<256, 128, 64>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<64, 64, 64>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::RowMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::ColumnMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 128, 64); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM61_igemm, igemm_int8_nn_64x64x16_64x64x4) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<64, 64, 16>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<64, 64, 16>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::ColumnMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2, // Stages, -+ cutlass::arch::OpMultiplyAdd // Operator, -+ >; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 32); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_slicedk.cu b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_slicedk.cu -new file mode 100644 -index 0000000..af1d61d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_slicedk.cu -@@ -0,0 +1,186 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Unit tests for CTA-level GEMM specifically for sliced-k kernels (SM_61 and SM_75) -+*/ -+ -+#include "mma_pipelined_testbed_slicedk.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// igemm_NT DP4A -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM61_igemm_sliced_k, igemm_int8_nt_32x32x128_32x32x4) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<32, 32, 128>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 32, 32>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2>; // Stages, -+ -+ cutlass::gemm::GemmCoord problem_size(32, 32, 128); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+TEST(SM61_igemm_sliced_k_big, igemm_int8_nt_32x32x128_32x32x4_bigk) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<32, 32, 128>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 32, 32>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2>; // Stages, -+ -+ cutlass::gemm::GemmCoord problem_size(32, 32, 1024); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+ -+TEST(SM61_igemm_sliced_k, igemm_int8_nt_32x64x128_32x32x4) { -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ cutlass::gemm::GemmShape<32, 64, 128>, // ThreadblockShape, -+ cutlass::gemm::GemmShape<32, 32, 64>, // WarpShape, -+ cutlass::gemm::GemmShape<1, 1, 4>, // InstructionShape, -+ int8_t, // ElementA, -+ cutlass::layout::ColumnMajor, // LayoutA, -+ int8_t, // ElementB, -+ cutlass::layout::RowMajor, // LayoutB, -+ int, // ElementC, -+ cutlass::layout::RowMajor, // LayoutC, -+ cutlass::arch::OpClassSimt, // OpClass -+ 2>; // Stages, -+ -+ cutlass::gemm::GemmCoord problem_size(32, 64, 256); -+ float alpha = 1.f; -+ float beta = 0.0f; -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ test::gemm::threadblock::Testbed( -+ problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) -+ .run(grid, block, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform); -+} -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// Tensor Op GEMM for SM_75 -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_congruous_sliced, tensor_op_64x64x256_tb64x64x64_warp64x32x32_16x8x8) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, 2, -+ cutlass::arch::OpMultiplyAdd>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+TEST(SM75_gemm_threadblock_crosswise_sliced, tensor_op_64x64x256_tb64x64x64_warp64x32x32_16x8x8) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 256); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, 2, -+ cutlass::arch::OpMultiplyAdd>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_sm70.cu b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_sm70.cu -new file mode 100644 -index 0000000..3374263 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_sm70.cu -@@ -0,0 +1,498 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "mma_pipelined_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_gemm_threadblock_congruous, tensor_op_64x64x32_64x64x32_8x8x4) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_gemm_threadblock_congruous, tensor_op_128x128x32_64x64x32_8x8x4) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_gemm_threadblock_congruous, tensor_op_64x64x32_32x32x32_8x8x4) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_gemm_threadblock_congruous, tensor_op_128x64x32_64x32x32_8x8x4) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_gemm_threadblock_congruous, tensor_op_128x64x64_64x32x64_8x8x4) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using OperatorShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, OperatorShape, ElementA, LayoutA, ElementB, -+ LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_gemm_threadblock_congruous, tensor_op_64x128x32_32x64x32_8x8x4) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_gemm_threadblock_congruous, tensor_op_256x128x32_32x64x32_8x8x4) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_gemm_threadblock_crosswise, tensor_op_64x64x32_64x64x32_8x8x4) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_gemm_threadblock_crosswise, tensor_op_128x128x32_64x64x32_8x8x4) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_gemm_threadblock_crosswise, tensor_op_256x128x32_64x64x32_8x8x4) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_gemm_threadblock_crosswise, tensor_op_64x64x32_32x32x32_8x8x4) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_gemm_threadblock_crosswise, tensor_op_128x64x32_64x32x32_8x8x4) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_gemm_threadblock_crosswise, tensor_op_128x64x64_64x32x64_8x8x4) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using OperatorShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, OperatorShape, ElementA, LayoutA, ElementB, -+ LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_gemm_threadblock_crosswise, tensor_op_64x128x32_32x64x32_8x8x4) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+#endif // CUTLASS_ARCH_MMA_SM70_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_sm75.cu b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_sm75.cu -new file mode 100644 -index 0000000..3f17387 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_sm75.cu -@@ -0,0 +1,2129 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for threadblock-level GEMM -+*/ -+ -+#include "mma_pipelined_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_congruous, tensor_op_64x64x32_64x64x32_16x8x8) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_congruous, tensor_op_128x64x32_64x32x32_16x8x8) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_congruous, tensor_op_64x128x32_32x64x32_16x8x8) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_congruous, tensor_op_128x128x32_64x64x32_16x8x8) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_congruous, -+ multicta_256x256x96_128x128x32_64x64x32_16x8x8) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 96); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_congruous, -+ multicta_512x256x384_256x128x32_64x64x32_16x8x8) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 384); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_64x64x32_64x64x32_16x8x8) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_32x32x32_16x16x32_16x8x8) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 32, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_32x64x32_16x32x32_16x8x8) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_64x32x32_32x16x32_16x8x8) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 32, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_64x64x32_32x32x32_16x8x8) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_128x64x32_64x32x32_16x8x8) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_64x128x32_32x64x32_16x8x8) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_128x128x32_64x64x32_16x8x8) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 96); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, -+ multicta_256x256x96_128x128x32_64x64x32_16x8x8) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 96); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, -+ multicta_512x256x384_256x128x32_64x64x32_16x8x8) { -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 384); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_interleaved, tensor_op_32x32x64_16x16x64_8x8x16) { -+ using ElementA = uint8_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<32>; -+ using ElementB = uint8_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<32>; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 32, 256); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_interleaved, tensor_op_64x32x64_32x16x64_8x8x16) { -+ using ElementA = uint8_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<32>; -+ using ElementB = uint8_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<32>; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 32, 256); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_interleaved, tensor_op_32x64x64_16x32x64_8x8x16) { -+ using ElementA = uint8_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<32>; -+ using ElementB = uint8_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<32>; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 64, 256); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<32, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_interleaved, tensor_op_64x64x64_32x32x64_8x8x16) { -+ using ElementA = uint8_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<32>; -+ using ElementB = uint8_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<32>; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 256); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_interleaved, tensor_op_128x64x64_64x32x64_8x8x16) { -+ using ElementA = uint8_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<32>; -+ using ElementB = uint8_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<32>; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 256); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<128, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore component -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_interleaved, tensor_op_64x128x64_32x64x64_8x8x16) { -+ using ElementA = uint8_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<32>; -+ using ElementB = uint8_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<32>; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 256); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_interleaved, tensor_op_128x128x64_64x64x64_8x8x16) { -+ using ElementA = uint8_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<32>; -+ using ElementB = uint8_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<32>; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 256); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_interleaved, -+ multicta_256x256x192_128x128x64_64x64x64_8x8x16) { -+ using ElementA = uint8_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<32>; -+ using ElementB = uint8_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<32>; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 192); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_interleaved, -+ multicta_512x256x768_256x128x64_64x64x64_8x8x16) { -+ using ElementA = uint8_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<32>; -+ using ElementB = uint8_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<32>; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 768); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<256, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_64x64x64_64x64x64_8x8x16) { -+ using ElementA = uint8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = uint8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 256); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_32x32x64_16x16x64_8x8x16) { -+ using ElementA = uint8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = uint8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 32, 256); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_64x32x64_32x16x64_8x8x16) { -+ using ElementA = uint8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = uint8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 32, 256); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_32x64x64_16x32x64_8x8x16) { -+ using ElementA = uint8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = uint8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 64, 256); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<32, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_64x64x64_32x32x64_8x8x16) { -+ using ElementA = uint8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = uint8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 256); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_128x64x64_64x32x64_8x8x16) { -+ using ElementA = uint8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = uint8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 256); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<128, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore component -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_64x128x64_32x64x64_8x8x16) { -+ using ElementA = uint8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = uint8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 256); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_128x128x64_64x64x64_8x8x16) { -+ using ElementA = uint8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = uint8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 256); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, -+ multicta_256x256x192_128x128x64_64x64x64_8x8x16) { -+ using ElementA = uint8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = uint8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 192); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, -+ multicta_512x256x768_256x128x64_64x64x64_8x8x16) { -+ using ElementA = uint8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = uint8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 768); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<256, 128, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_64x64x128_64x64x128_8x8x32) { -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 512); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_32x32x128_16x16x128_8x8x32) { -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 32, 512); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<32, 32, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_64x32x128_32x16x128_8x8x32) { -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 32, 512); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_32x64x128_16x32x128_8x8x32) { -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 64, 512); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<32, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_64x64x128_32x32x128_8x8x32) { -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 512); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_128x64x128_64x32x128_8x8x32) { -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 512); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<128, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore component -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_64x128x128_32x64x128_8x8x32) { -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 512); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_128x128x128_64x64x128_8x8x32) { -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 512); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, -+ multicta_256x256x384_128x128x128_64x64x128_8x8x32) { -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 384); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, -+ multicta_512x256x1536_256x128x128_64x64x128_8x8x32) { -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 1536); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<256, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_interleaved, tensor_op_32x32x128_16x16x128_8x8x32) { -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<64>; -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<64>; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 32, 512); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<32, 32, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_interleaved, tensor_op_64x32x128_32x16x128_8x8x32) { -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<64>; -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<64>; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 32, 512); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_interleaved, tensor_op_32x64x128_16x32x128_8x8x32) { -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<64>; -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<64>; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 64, 512); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<32, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_interleaved, tensor_op_64x64x128_32x32x128_8x8x32) { -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<64>; -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<64>; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 512); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_interleaved, tensor_op_128x64x128_64x32x128_8x8x32) { -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<64>; -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<64>; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 512); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<128, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore component -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_interleaved, tensor_op_64x128x128_32x64x128_8x8x32) { -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<64>; -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<64>; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 512); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_interleaved, tensor_op_128x128x128_64x64x128_8x8x32) { -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<64>; -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<64>; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 512); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_interleaved, -+ multicta_256x256x384_128x128x128_64x64x128_8x8x32) { -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<64>; -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<64>; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 384); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_interleaved, -+ multicta_512x256x1536_256x128x128_64x64x128_8x8x32) { -+ using ElementA = cutlass::uint4b_t; -+ using LayoutA = cutlass::layout::ColumnMajorInterleaved<64>; -+ using ElementB = cutlass::uint4b_t; -+ using LayoutB = cutlass::layout::RowMajorInterleaved<64>; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 1536); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<256, 128, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_64x64x512_64x64x512_8x8x128) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 2048); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, 2, -+ cutlass::arch::OpXorPopc>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_32x32x512_16x16x512_8x8x128) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 32, 2048); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<32, 32, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, 2, -+ cutlass::arch::OpXorPopc>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_64x32x512_32x16x512_8x8x128) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 32, 2048); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 32, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, 2, -+ cutlass::arch::OpXorPopc>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_32x64x512_16x32x512_8x8x128) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 64, 2048); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<32, 64, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, 2, -+ cutlass::arch::OpXorPopc>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_64x64x512_32x32x512_8x8x128) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 2048); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, 2, -+ cutlass::arch::OpXorPopc>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_128x64x512_64x32x512_8x8x128) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 2048); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<128, 64, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore component -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, 2, -+ cutlass::arch::OpXorPopc>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_64x128x512_32x64x512_8x8x128) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 2048); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 128, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, 2, -+ cutlass::arch::OpXorPopc>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, tensor_op_128x128x512_64x64x512_8x8x128) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 2048); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, 2, -+ cutlass::arch::OpXorPopc>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, -+ multicta_256x256x1536_128x128x512_64x64x512_8x8x128) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 1536); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, 2, -+ cutlass::arch::OpXorPopc>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_crosswise, -+ multicta_512x256x6144_256x128x512_64x64x512_8x8x128) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 6144); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<256, 128, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, 2, -+ cutlass::arch::OpXorPopc>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_sm80.cu b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_sm80.cu -new file mode 100644 -index 0000000..6e91c01 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_sm80.cu -@@ -0,0 +1,569 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for threadblock-level GEMM -+*/ -+ -+#include "mma_pipelined_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, tensor_op_64x64x16_64x64x16_16x8x4) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 64); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, tensor_op_128x64x16_64x32x16_16x8x4) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 64); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, tensor_op_64x128x16_32x64x16_16x8x4) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 64); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, tensor_op_128x128x16_64x64x16_16x8x4) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 64); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ multicta_256x256x96_128x128x16_64x64x16_16x8x4) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 96); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_congruous, -+ multicta_512x256x192_256x128x16_64x64x16_16x8x4) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 192); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, tensor_op_64x64x16_64x64x16_16x8x4) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 64); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, tensor_op_32x32x16_16x16x16_16x8x4) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 32, 64); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, tensor_op_32x64x16_16x32x16_16x8x4) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(32, 64, 64); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, tensor_op_64x32x16_32x16x16_16x8x4) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 32, 64); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, tensor_op_64x64x16_32x32x16_16x8x4) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 64); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, tensor_op_128x64x16_64x32x16_16x8x4) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 64, 64); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, tensor_op_64x128x16_32x64x16_16x8x4) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 128, 64); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, tensor_op_128x128x16_64x64x16_16x8x4) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 48); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_256x256x48_128x128x16_64x64x16_16x8x4) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 48); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_crosswise, -+ multicta_512x256x192_256x128x16_64x64x16_16x8x4) { -+ using ElementA = cutlass::tfloat32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::tfloat32_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(512, 256, 192); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 16>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 8, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed.h b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed.h -new file mode 100644 -index 0000000..6f36b53 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed.h -@@ -0,0 +1,355 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit testbed for kernel-level GEMM -+*/ -+ -+#pragma once -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/vector.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/platform/platform.h" -+ -+namespace test { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+__global__ void kernel_mma(cutlass::gemm::GemmCoord problem_size, -+ typename Mma::IteratorA::Params params_A, -+ typename Mma::IteratorA::TensorRef ref_A, -+ typename Mma::IteratorB::Params params_B, -+ typename Mma::IteratorB::TensorRef ref_B, -+ typename Mma::ElementC *ptr_C, -+ typename Mma::LayoutC::Stride::Index ldc) { -+ // Shared storage needed by threadblock-scoped matrix multiply-accumulate -+ __shared__ typename Mma::SharedStorage shared_storage; -+ -+ // Compute threadblock location -+ cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), -+ 0}; -+ -+ cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, -+ tb_tile_offset.k()}; -+ -+ cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), -+ tb_tile_offset.n() * Mma::Shape::kN}; -+ -+ // Compute position within threadblock -+ int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A(params_A, ref_A.data(), -+ {problem_size.m(), problem_size.k()}, -+ tb_thread_id, tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B(params_B, ref_B.data(), -+ {problem_size.k(), problem_size.n()}, -+ tb_thread_id, tb_offset_B); -+ -+ int warp_id = threadIdx.y; -+ int lane_id = threadIdx.x; -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage, tb_thread_id, warp_id, threadIdx.x); -+ -+ typename Mma::FragmentC accum; -+ -+ accum.clear(); -+ -+ int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); -+ -+ // Output results -+ typename Mma::Operator::IteratorC iterator_C({ptr_C, ldc}, lane_id); -+ -+ iterator_C.add_tile_offset( -+ {(tb_tile_offset.m() * Mma::WarpCount::kM) + -+ (warp_id % Mma::WarpCount::kM), -+ (tb_tile_offset.n() * Mma::WarpCount::kN) + -+ (warp_id / Mma::WarpCount::kM)}); -+ -+ iterator_C.store(accum); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product -+template < -+ /// Threadblock-level matrix multiply-accumulate -+ typename MmaCore_, -+ /// Number of stages -+ int Stages = 2> -+struct Testbed { -+ /// Threadblock-level GEMM implementation -+ using MmaCore = MmaCore_; -+ using ThreadblockShape = typename MmaCore::Shape; -+ using WarpShape = typename MmaCore::WarpShape; -+ using InstructionShape = typename MmaCore::InstructionShape; -+ using ElementA = typename MmaCore::ElementA; -+ using LayoutA = typename MmaCore::LayoutA; -+ using ElementB = typename MmaCore::ElementB; -+ using LayoutB = typename MmaCore::LayoutB; -+ using ElementC = typename MmaCore::ElementC; -+ using LayoutC = typename MmaCore::LayoutC; -+ static const int kStages = Stages; -+ -+ // Define iterators over tiles from the A operand -+ static const bool use_idp4a = cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value; -+ -+ static const bool transposeA = cutlass::platform::is_same< LayoutA, cutlass::layout::ColumnMajor >::value; -+ static const bool transposeB = cutlass::platform::is_same< LayoutB, cutlass::layout::RowMajor >::value; -+ -+ using IteratorA = typename cutlass::platform::conditional< use_idp4a, -+ cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, transposeA> , -+ -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA> -+ >::type; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = typename cutlass::platform::conditional< use_idp4a, -+ cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, transposeB> , -+ -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB> -+ >::type; -+ -+ // Define MmaPipeline Single Stage -+ using MmaPipelineSingleStage = cutlass::gemm::threadblock::MmaSingleStage< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ IteratorB, typename MmaCore::SmemIteratorB, ElementC, LayoutC, -+ typename MmaCore::MmaPolicy>; -+ -+ // Define MmaPipeline Two Stages -+ using MmaPipelineTwoStages = cutlass::gemm::threadblock::MmaPipelined< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ IteratorB, typename MmaCore::SmemIteratorB, ElementC, LayoutC, -+ typename MmaCore::MmaPolicy>; -+ -+ // Define the threadblock-scoped pipelined matrix multiply (Select between Single vs. Two stages) -+ using Mma = typename cutlass::platform::conditional<(kStages==1), MmaPipelineSingleStage, MmaPipelineTwoStages>::type; -+ // -+ // Data members -+ // -+ -+ cutlass::HostTensor matrix_A; -+ cutlass::HostTensor matrix_B; -+ cutlass::HostTensor matrix_C_computed; -+ cutlass::HostTensor matrix_C_reference; -+ -+ cutlass::gemm::GemmCoord problem_size; -+ float alpha, beta; -+ -+ // -+ // Methods -+ // -+ -+ /// Allocates workspace in device memory -+ Testbed(int m, int n, int k, float alpha_, float beta_) -+ : problem_size(m, n, k), alpha(alpha_), beta(beta_) { -+ matrix_A.reset(cutlass::make_Coord(m, k)); -+ matrix_B.reset(cutlass::make_Coord(k, n)); -+ matrix_C_computed.reset(cutlass::make_Coord(m, n)); -+ matrix_C_reference.reset(cutlass::make_Coord(m, n), false); -+ } -+ -+ bool sufficient() { -+ return true; -+ } -+ -+ /// Runs the test -+ bool run( -+ dim3 grid, dim3 block, -+ cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { -+ -+ // Waive test if insufficient CUDA device -+ if (!sufficient()) { -+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { -+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl; -+ } -+ return true; -+ } -+ -+ -+ // -+ // initialize device memory -+ // -+ -+ if (init_A == cutlass::Distribution::Uniform) { -+ -+ int scope_max = 8; -+ int scope_min = -8; -+ -+ if (cutlass::sizeof_bits::value == 4) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (cutlass::sizeof_bits::value == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } -+ -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform( -+ matrix_A.host_view(), seed, scope_max, scope_min, 0); -+ } else if (init_A == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), -+ matrix_A.capacity()); -+ } else if (init_A == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ if (init_B == cutlass::Distribution::Uniform) { -+ -+ int scope_max = 8; -+ int scope_min = -8; -+ -+ if (cutlass::sizeof_bits::value == 4) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (cutlass::sizeof_bits::value == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } -+ -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform( -+ matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); -+ } else if (init_B == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), -+ matrix_B.capacity()); -+ } else if (init_B == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ cutlass::reference::host::TensorFill(matrix_C_computed.host_view()); -+ -+ cutlass::reference::host::TensorFill(matrix_C_reference.host_view()); -+ -+ matrix_A.sync_device(); -+ matrix_B.sync_device(); -+ matrix_C_computed.sync_device(); -+ -+ typename IteratorA::Params params_A(matrix_A.layout()); -+ typename IteratorB::Params params_B(matrix_B.layout()); -+ -+ test::gemm::threadblock::kernel_mma<<>>( -+ problem_size, params_A, matrix_A.device_ref(), params_B, -+ matrix_B.device_ref(), matrix_C_computed.device_data(), -+ matrix_C_computed.layout().stride(0)); -+ -+ // -+ // Check error code -+ // -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ EXPECT_EQ(result, cudaSuccess) -+ << " kernel error: " << cudaGetErrorString(result) << " on device " << GetCudaDevice(); -+ -+ matrix_C_computed.sync_host(); -+ -+ cutlass::reference::host::Gemm -+ reference_gemm; -+ -+ reference_gemm( -+ problem_size, ElementC(alpha), matrix_A.host_view(), -+ matrix_B.host_view(), ElementC(beta), matrix_C_reference.host_view()); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ matrix_C_computed.host_view(), matrix_C_reference.host_view()); -+ -+ EXPECT_TRUE(passed) << "Failed on device " << GetCudaDevice(); -+ -+ if (!passed) { -+ std::ofstream output("mma_pipelined_testbed_errors.txt"); -+ -+ output -+ << "A:\n" << matrix_A.host_view() << "\n" -+ << "B:\n" << matrix_B.host_view() << "\n" -+ << "Reference:\n" -+ << matrix_C_reference.host_view() << "\n" -+ << "Computed:\n" -+ << matrix_C_computed.host_view() << "\n"; -+ } -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace test -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed_slicedk.h b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed_slicedk.h -new file mode 100644 -index 0000000..9e8d351 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed_slicedk.h -@@ -0,0 +1,372 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Unit testbed for kernel-level GEMM -+*/ -+ -+#pragma once -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/vector.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" -+#include "cutlass/cutlass.h" -+#include "cutlass/platform/platform.h" -+ -+namespace test { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+__global__ void kernel_mma(cutlass::gemm::GemmCoord problem_size, -+ typename Mma::IteratorA::Params params_A, -+ typename Mma::IteratorA::TensorRef ref_A, -+ typename Mma::IteratorB::Params params_B, -+ typename Mma::IteratorB::TensorRef ref_B, -+ typename Mma::ElementC **ptr_C, -+ typename Mma::LayoutC::Stride::Index ldc) { -+ // Shared storage needed by threadblock-scoped matrix multiply-accumulate -+ __shared__ typename Mma::SharedStorage shared_storage; -+ -+ // Compute threadblock location -+ cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), -+ 0}; -+ -+ cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, -+ tb_tile_offset.k()}; -+ -+ cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), -+ tb_tile_offset.n() * Mma::Shape::kN}; -+ -+ // Compute position within threadblock -+ int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A(params_A, ref_A.data(), -+ {problem_size.m(), problem_size.k()}, -+ tb_thread_id, tb_offset_A); -+ -+ typename Mma::IteratorB iterator_B(params_B, ref_B.data(), -+ {problem_size.k(), problem_size.n()}, -+ tb_thread_id, tb_offset_B); -+ -+ int warp_id = threadIdx.y; -+ int lane_id = threadIdx.x; -+ -+ int partitionsK_idx = warp_id / (Mma::WarpCount::kM * Mma::WarpCount::kN); -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage, tb_thread_id, warp_id, threadIdx.x); -+ -+ typename Mma::FragmentC accum; -+ -+ accum.clear(); -+ -+ int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); -+ -+ // Output results -+ typename Mma::Operator::IteratorC iterator_C({ptr_C[partitionsK_idx], ldc}, lane_id); -+ -+ -+ int warp_idx_mn = warp_id % (Mma::WarpCount::kM * Mma::WarpCount::kN); -+ iterator_C.add_tile_offset( -+ {(tb_tile_offset.m() * Mma::WarpCount::kM) + -+ (warp_idx_mn % Mma::WarpCount::kM), -+ (tb_tile_offset.n() * Mma::WarpCount::kN) + -+ (warp_idx_mn / Mma::WarpCount::kM)}); -+ -+ iterator_C.store(accum); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product -+template < -+ /// Threadblock-level matrix multiply-accumulate -+ typename MmaCore_> -+struct Testbed { -+ /// Threadblock-level GEMM implementation -+ using MmaCore = MmaCore_; -+ using ThreadblockShape = typename MmaCore::Shape; -+ using WarpShape = typename MmaCore::WarpShape; -+ using InstructionShape = typename MmaCore::InstructionShape; -+ using ElementA = typename MmaCore::ElementA; -+ using LayoutA = typename MmaCore::LayoutA; -+ using ElementB = typename MmaCore::ElementB; -+ using LayoutB = typename MmaCore::LayoutB; -+ using ElementC = typename MmaCore::ElementC; -+ using LayoutC = typename MmaCore::LayoutC; -+ -+ // Define iterators over tiles from the A operand -+ static const bool use_idp4a = cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value && -+ cutlass::platform::is_same::value; -+ -+ static const bool transposeA = cutlass::platform::is_same< LayoutA, cutlass::layout::ColumnMajor >::value; -+ static const bool transposeB = cutlass::platform::is_same< LayoutB, cutlass::layout::RowMajor >::value; -+ -+ using IteratorA = typename cutlass::platform::conditional< use_idp4a, -+ cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, transposeA> , -+ -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA> -+ >::type; -+ -+ // Define iterators over tiles from the B operand -+ using IteratorB = typename cutlass::platform::conditional< use_idp4a, -+ cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, transposeB> , -+ -+ cutlass::transform::threadblock::PredicatedTileIterator< -+ cutlass::MatrixShape, -+ ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB> -+ >::type; -+ -+ // Define the threadblock-scoped pipelined matrix multiply -+ using Mma = cutlass::gemm::threadblock::MmaPipelined< -+ typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, -+ IteratorB, typename MmaCore::SmemIteratorB, ElementC, LayoutC, -+ typename MmaCore::MmaPolicy>; -+ -+ static int const kPartitionsK = MmaCore::MmaPolicy::kPartitionsK; -+ -+ // -+ // Data members -+ // -+ -+ cutlass::HostTensor matrix_A; -+ cutlass::HostTensor matrix_B; -+ cutlass::HostTensor matrix_C_computed[kPartitionsK]; -+ cutlass::HostTensor matrix_C_reference; -+ cutlass::HostTensor matrix_C_pointers; -+ -+ cutlass::gemm::GemmCoord problem_size; -+ float alpha, beta; -+ -+ // -+ // Methods -+ // -+ -+ /// Allocates workspace in device memory -+ Testbed(int m, int n, int k, float alpha_, float beta_) -+ : problem_size(m, n, k), alpha(alpha_), beta(beta_) { -+ matrix_A.reset(cutlass::make_Coord(m, k)); -+ matrix_B.reset(cutlass::make_Coord(k, n)); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int k = 0; k < kPartitionsK; k++) -+ matrix_C_computed[k].reset(cutlass::make_Coord(m, n)); -+ -+ matrix_C_reference.reset(cutlass::make_Coord(m, n), false); -+ matrix_C_pointers.reset(cutlass::Coord<1>(kPartitionsK)); -+ } -+ -+ /// Runs the test -+ bool run( -+ dim3 grid, dim3 block, -+ cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { -+ // -+ // initialize device memory -+ // -+ -+ if (init_A == cutlass::Distribution::Uniform) { -+ -+ int scope_max = 8; -+ int scope_min = -8; -+ -+ if (cutlass::sizeof_bits::value == 4) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (cutlass::sizeof_bits::value == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } -+ -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform( -+ matrix_A.host_view(), seed, scope_max, scope_min, 0); -+ } else if (init_A == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), -+ matrix_A.capacity()); -+ } else if (init_A == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ if (init_B == cutlass::Distribution::Uniform) { -+ -+ int scope_max = 8; -+ int scope_min = -8; -+ -+ if (cutlass::sizeof_bits::value == 4) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (cutlass::sizeof_bits::value == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } -+ -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform( -+ matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); -+ } else if (init_B == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), -+ matrix_B.capacity()); -+ } else if (init_B == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int k = 0; k < kPartitionsK; k++) -+ cutlass::reference::host::TensorFill(matrix_C_computed[k].host_view()); -+ -+ cutlass::reference::host::TensorFill(matrix_C_reference.host_view()); -+ -+ matrix_A.sync_device(); -+ matrix_B.sync_device(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int k = 0; k < kPartitionsK; k++) -+ matrix_C_computed[k].sync_device(); -+ -+ typename IteratorA::Params params_A(matrix_A.layout()); -+ typename IteratorB::Params params_B(matrix_B.layout()); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int k = 0; k < kPartitionsK; k++) -+ matrix_C_pointers.at(cutlass::Coord<1>(k)) = matrix_C_computed[k].device_data(); -+ -+ matrix_C_pointers.sync_device(); -+ -+ test::gemm::threadblock::kernel_mma<<>>( -+ problem_size, params_A, matrix_A.device_ref(), params_B, -+ matrix_B.device_ref(), matrix_C_pointers.device_data(), -+ matrix_C_computed[0].layout().stride(0)); -+ -+ // -+ // Check error code -+ // -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ EXPECT_EQ(result, cudaSuccess) -+ << " kernel error: " << cudaGetErrorString(result); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for(int k = 0; k < kPartitionsK; k++) -+ matrix_C_computed[k].sync_host(); -+ -+ // TODO: this is temporary. it will be removed after slicing can de -+ // reduction -+ // -+ // Reduce matrix_C_computed -+ // -+ CUTLASS_PRAGMA_UNROLL -+ for(int k = 1; k < kPartitionsK; k++) { -+ CUTLASS_PRAGMA_UNROLL -+ for(int m = 0; m < matrix_C_computed[0].extent().row(); m++){ -+ CUTLASS_PRAGMA_UNROLL -+ for(int n = 0; n < matrix_C_computed[0].extent().column(); n++){ -+ matrix_C_computed[0].at({m, n}) += matrix_C_computed[k].at({m, n}); -+ } -+ } -+ } -+ -+ cutlass::reference::host::Gemm -+ reference_gemm; -+ -+ reference_gemm( -+ problem_size, ElementC(alpha), matrix_A.host_view(), -+ matrix_B.host_view(), ElementC(beta), matrix_C_reference.host_view()); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ matrix_C_computed[0].host_view(), matrix_C_reference.host_view()); -+ -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ std::ofstream output("mma_pipelined_testbed_errors.txt"); -+ -+ output -+ << "A:\n" << matrix_A.host_view() << "\n" -+ << "B:\n" << matrix_B.host_view() << "\n" -+ << "Reference:\n" -+ << matrix_C_reference.host_view() << "\n" -+ << "Computed:\n" -+ << matrix_C_computed[0].host_view() << "\n"; -+ } -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace test -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_wmma_sm70.cu b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_wmma_sm70.cu -new file mode 100644 -index 0000000..28dc1c8 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_wmma_sm70.cu -@@ -0,0 +1,766 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include "mma_pipelined_testbed.h" -+#include "cutlass/gemm/threadblock/default_mma_core_wmma.h" -+ -+/// All tests use double-buffered (kStages=2) mma pipeline for the gemm mainloop -+/// Test name format: SM[arch]_gemm_threadblock_wmma_tensor_op_[alayout]_[blayout]_[clayout]_[dtype].[threadblock_shape]_[warp_shape] -+ -+//////////////// [START] Verifying all layouts {N,T}x{N,T}=>{N,T} for WMMA 16x16x16 [START] ////////////////////// -+ -+/////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.col.m16n16k16.f16.f16 (wmma native size 16x16x16) -+//////////////////////////////////////////////////////////// -+ -+// tests for {N,T}x{N,T}=>{T} -+TEST(SM70_gemm_threadblock_wmma_tensor_op_row_col_row_f16, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 32); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.col.row.m16n16k16.f16.f16 (wmma native size 16x16x16) -+/////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_gemm_threadblock_wmma_tensor_op_col_row_row_f16, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 32); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+TEST(SM70_gemm_threadblock_wmma_tensor_op_col_row_row_f16, 128x128x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 64); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+ -+/////////////////////////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.row.m16n16k16.f16.f16 (wmma native size 16x16x16) -+/////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_gemm_threadblock_wmma_tensor_op_row_row_row_f16, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 32); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+TEST(SM70_gemm_threadblock_wmma_tensor_op_row_row_row_f16, 128x128x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 96); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.col.col.m16n16k16.f16.f16 (wmma native size 16x16x16) -+/////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_gemm_threadblock_wmma_tensor_op_col_col_row_f16, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 32); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+TEST(SM70_gemm_threadblock_wmma_tensor_op_col_col_row_f16, 128x128x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 96); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+// tests for {N,T}x{N,T}=>{N} -+/////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.col.m16n16k16.f16.f16 (wmma native size 16x16x16) -+//////////////////////////////////////////////////////////// -+TEST(SM70_gemm_threadblock_wmma_tensor_op_row_col_col_f16, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 32); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.col.row.m16n16k16.f16.f16 (wmma native size 16x16x16) -+/////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_gemm_threadblock_wmma_tensor_op_col_row_col_f16, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 32); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+ -+/////////////////////////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.row.m16n16k16.f16.f16 (wmma native size 16x16x16) -+/////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_gemm_threadblock_wmma_tensor_op_row_row_col_f16, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 32); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+ -+/////////////////////////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.col.col.m16n16k16.f16.f16 (wmma native size 16x16x16) -+/////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_gemm_threadblock_wmma_tensor_op_col_col_col_f16, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 32); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+//////////////// [END] Verifying all layouts {N,T}x{N,T}=>{N,T} for WMMA 16x16x16 [END] ////////////////////// -+ -+TEST(SM70_gemm_threadblock_wmma_tensor_op_row_col_row_f16, 128x128x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 64); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+ -+TEST(SM70_gemm_threadblock_wmma_tensor_op_row_col_row_f16, multicta_256x256x96_128x128x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 96); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.col.m32n8k16.f16.f16 (wmma native size 32x8x16) -+/////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_gemm_threadblock_wmma_tensor_op_row_col_row_f16, 64x64x32_64x64x32_32x8x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+////////////////////////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.col.m8n32k16.f16.f16 (wmma native size 8x32x16) -+////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_gemm_threadblock_wmma_tensor_op_row_col_row_f16, 64x64x32_64x64x32_8x32x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+////////////////////////////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.col.m16n16k16.f32.f32 (wmma native size 16x16x16) -+////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_gemm_threadblock_wmma_tensor_op_row_col_row_f32, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+TEST(SM70_gemm_threadblock_wmma_tensor_op_row_col_row_f32, 128x128x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+TEST(SM70_gemm_threadblock_wmma_tensor_op_row_col_row_f32, multicta_256x256x96_128x128x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 96); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+/////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.col.m32n8k16.f32.f32 (wmma native size 32x8x16) -+//////////////////////////////////////////////////////////// -+TEST(SM70_gemm_threadblock_wmma_tensor_op_row_col_row_f32, 64x64x32_64x64x32_32x8x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+///////////////////////////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.col.m8n32k16.f32.f32 (wmma native size 8x32x16) -+///////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_gemm_threadblock_wmma_tensor_op_row_col_row_f32, 64x64x32_64x64x32_8x32x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_wmma_sm75.cu b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_wmma_sm75.cu -new file mode 100644 -index 0000000..12fae1f ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_pipelined_wmma_sm75.cu -@@ -0,0 +1,337 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM75_ENABLED -+#include "mma_pipelined_testbed.h" -+#include "cutlass/gemm/threadblock/default_mma_core_wmma.h" -+ -+/// All tests use double-buffered (kStages=2) mma pipeline for the gemm mainloop -+/// Test name format: SM[arch]_gemm_threadblock_wmma_tensor_op_[alayout]_[blayout]_[clayout]_[atype].[threadblock_shape]_[warp_shape]_[instruction_shape] -+ -+///////////////////////////////////////////////////////////////////////// -+/// Integer (s8 and u8) WMMA threadblock level tests ///// -+///////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_INTEGER_MATRIX_MULTIPLY_ENABLED) -+TEST(SM75_gemm_threadblock_wmma_tensor_op_row_col_row_s8, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+TEST(SM75_gemm_threadblock_wmma_tensor_op_row_col_row_s8, 64x64x64_64x64x64_16x16x16) { -+ -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+ -+TEST(SM75_gemm_threadblock_wmma_tensor_op_col_row_row_s8, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+TEST(SM75_gemm_threadblock_wmma_tensor_op_col_row_row_s8, 64x64x64_64x64x64_16x16x16) { -+ -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+#endif //CUTLASS_ARCH_INTEGER_MATRIX_MULTIPLY_ENABLED -+ -+ -+//////////////////////////////////////////////////////////////////////// -+/// SUBBYTE (s4 and b1) WMMA threadblock level tests //// -+/////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED) -+ -+TEST(SM75_gemm_threadblock_wmma_tensor_op_row_col_row_s4, 64x64x128_64x64x128_8x8x32) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+ -+TEST(SM75_gemm_threadblock_wmma_tensor_op_row_col_col_s4, 64x64x64_64x64x64_8x8x32) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 64); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+TEST(SM75_gemm_threadblock_wmma_tensor_op_row_col_row_b1, 64x64x512_64x64x512_8x8x128) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 2048); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages, -+ cutlass::arch::OpXorPopc>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+TEST(SM75_gemm_threadblock_wmma_tensor_op_row_col_col_b1, 64x64x512_64x64x512_8x8x128) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ static const int kStages = 2; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 2048); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages, -+ cutlass::arch::OpXorPopc>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+#endif //CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED -+ -+#endif //CUTLASS_ARCH_WMMA_SM75_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/mma_planar_complex_sm80.cu b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_planar_complex_sm80.cu -new file mode 100644 -index 0000000..0b6dc11 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_planar_complex_sm80.cu -@@ -0,0 +1,79 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for threadblock-level GEMM -+*/ -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/gemm/threadblock/default_mma_planar_complex_multistage.h" -+ -+#include "mma_planar_complex_testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_gemm_threadblock_planar_complex_congruous, tensor_op_64x64x32_64x64x32_16x8x16_3stage) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 8); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ int const Stages = 3; -+ -+ // Define the MmaCore components -+ using Mma = typename cutlass::gemm::threadblock::DefaultMmaPlanarComplexMultistage< -+ ElementA, LayoutA, 8, -+ ElementB, LayoutB, 8, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassTensorOp, -+ cutlass::arch::Sm80, -+ ThreadblockShape, WarpShape, InstructionShape, -+ Stages>::ThreadblockMma; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, Mma::WarpCount::kCount, 1); -+ -+ test::gemm::threadblock::TestbedPlanarComplex(problem_size.m(), problem_size.n(), -+ problem_size.k()) -+ .run(grid, block); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/mma_planar_complex_testbed.h b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_planar_complex_testbed.h -new file mode 100644 -index 0000000..b33abdb ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_planar_complex_testbed.h -@@ -0,0 +1,352 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit testbed for kernel-level GEMM -+*/ -+ -+#pragma once -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/platform/platform.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/vector.h" -+#include "cutlass/numeric_types.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor_planar_complex.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/gemm_planar_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace gemm { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+__global__ void kernel_mma_planar_complex( -+ cutlass::gemm::GemmCoord problem_size, -+ typename Mma::IteratorA::Params params_A, -+ typename Mma::IteratorA::Element *ptr_A, -+ int64_t imaginary_stride_A, -+ typename Mma::IteratorB::Params params_B, -+ typename Mma::IteratorB::Element *ptr_B, -+ int64_t imaginary_stride_B, -+ typename Mma::ElementC *ptr_C, -+ typename Mma::LayoutC::Stride::Index ldc, int64_t imaginary_stride_C) { -+ -+ // Shared storage needed by threadblock-scoped matrix multiply-accumulate -+ __shared__ typename Mma::SharedStorage shared_storage; -+ -+ // Compute threadblock location -+ cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), -+ 0}; -+ -+ cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, -+ tb_tile_offset.k()}; -+ -+ cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), -+ tb_tile_offset.n() * Mma::Shape::kN}; -+ -+ // Compute position within threadblock -+ int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; -+ -+ // Construct iterators to A operand -+ typename Mma::IteratorA iterator_A_real(params_A, ptr_A, -+ {problem_size.m(), problem_size.k()}, -+ tb_thread_id, tb_offset_A); -+ -+ typename Mma::IteratorA iterator_A_imag(params_A, ptr_A + imaginary_stride_A, -+ {problem_size.m(), problem_size.k()}, -+ tb_thread_id, tb_offset_A); -+ -+ // Construct iterators to B operand -+ typename Mma::IteratorB iterator_B_real(params_B, ptr_B, -+ {problem_size.k(), problem_size.n()}, -+ tb_thread_id, tb_offset_B); -+ -+ typename Mma::IteratorB iterator_B_imag(params_B, ptr_B + imaginary_stride_B, -+ {problem_size.k(), problem_size.n()}, -+ tb_thread_id, tb_offset_B); -+ -+ int warp_id = threadIdx.y; -+ int lane_id = threadIdx.x; -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage, tb_thread_id, warp_id, threadIdx.x); -+ -+ typename Mma::FragmentC accum; -+ -+ accum.clear(); -+ -+ int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma(gemm_k_iterations, accum, iterator_A_real, iterator_A_imag, iterator_B_real, iterator_B_imag, accum); -+ -+ // Output results -+ typename Mma::Operator::IteratorC iterator_C({ptr_C, ldc}, lane_id); -+ -+ iterator_C.add_tile_offset( -+ {(tb_tile_offset.m() * Mma::WarpCount::kM) + -+ (warp_id % Mma::WarpCount::kM), -+ (tb_tile_offset.n() * Mma::WarpCount::kN) + -+ (warp_id / Mma::WarpCount::kM)}); -+ -+ iterator_C.store(accum.real); -+ -+ iterator_C.store_with_pointer_offset(accum.imag, imaginary_stride_C); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product -+template < -+ /// Threadblock-level matrix multiply-accumulate -+ typename Mma_> -+struct TestbedPlanarComplex { -+ -+ using Mma = Mma_; -+ using ThreadblockShape = typename Mma::Shape; -+ using IteratorA = typename Mma::IteratorA; -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using IteratorB = typename Mma::IteratorB; -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using ElementC = typename Mma::ElementC; -+ using ElementAccumulator = typename Mma::ElementC; -+ using LayoutC = typename Mma::LayoutC; -+ using ThreadMapA = typename Mma::IteratorA::ThreadMap; -+ using ThreadMapB = typename Mma::IteratorB::ThreadMap; -+ using AccessTypeA = cutlass::Array; -+ using AccessTypeB = cutlass::Array; -+ static int const Stages = Mma::kStages; -+ static cutlass::arch::CacheOperation::Kind const CacheOpA = -+ Mma::kCacheOpA; -+ static cutlass::arch::CacheOperation::Kind const CacheOpB = -+ Mma::kCacheOpB; -+ -+ // -+ // Data members -+ // -+ -+ cutlass::HostTensorPlanarComplex matrix_A; -+ cutlass::HostTensorPlanarComplex matrix_B; -+ cutlass::HostTensorPlanarComplex matrix_C_computed; -+ cutlass::HostTensorPlanarComplex matrix_C_reference; -+ -+ cutlass::gemm::GemmCoord problem_size; -+ -+ // -+ // Methods -+ // -+ -+ /// Allocates workspace in device memory -+ TestbedPlanarComplex(int m, int n, int k) -+ : problem_size(m, n, k) { -+ -+ matrix_A.reset(cutlass::make_Coord(m, k)); -+ matrix_B.reset(cutlass::make_Coord(k, n)); -+ matrix_C_computed.reset(cutlass::make_Coord(m, n)); -+ matrix_C_reference.reset(cutlass::make_Coord(m, n), false); -+ } -+ -+ /// Runs the test -+ bool run( -+ dim3 grid, dim3 block, -+ cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { -+ -+ // -+ // initialize device memory -+ // -+ -+ if (init_A == cutlass::Distribution::Uniform) { -+ -+ int scope_max = 8; -+ int scope_min = -8; -+ -+ if (cutlass::sizeof_bits::value == 4) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (cutlass::sizeof_bits::value == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } -+ -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform( -+ matrix_A.host_view(), seed, scope_max, scope_min, 0); -+ -+ } else if (init_A == cutlass::Distribution::Sequential) { -+ -+ for (int i = 0; i < matrix_A.capacity() * 2; ++i) { -+ matrix_A.host_data()[i] = cutlass::half_t(float(i % 5) - 2); -+ } -+ /* -+ cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), -+ matrix_A.capacity() * 2); -+ */ -+ } else if (init_A == cutlass::Distribution::Identity) { -+ //cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ if (init_B == cutlass::Distribution::Uniform) { -+ -+ -+ int scope_max = 8; -+ int scope_min = -8; -+ -+ if (cutlass::sizeof_bits::value == 4) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (cutlass::sizeof_bits::value == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } -+ -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform( -+ matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); -+ -+ -+ } else if (init_B == cutlass::Distribution::Sequential) { -+ -+ cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), -+ matrix_B.capacity() * 2); -+ -+ for (int i = 0; i < matrix_B.capacity() * 2; ++i) { -+ matrix_B.host_data()[i] = cutlass::half_t(float((i + 3) % 5) - 2); -+ } -+ -+ -+ } else if (init_B == cutlass::Distribution::Identity) { -+ -+ //cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); -+ -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ matrix_A.sync_device(); -+ matrix_B.sync_device(); -+ matrix_C_computed.sync_device(); -+ -+ typename IteratorA::Params params_A(matrix_A.layout()); -+ typename IteratorB::Params params_B(matrix_B.layout()); -+ -+ test::gemm::threadblock::kernel_mma_planar_complex<<>>( -+ problem_size, -+ params_A, -+ matrix_A.device_data(), -+ matrix_A.imaginary_stride(), -+ params_B, -+ matrix_B.device_data(), -+ matrix_B.imaginary_stride(), -+ matrix_C_computed.device_data(), -+ matrix_C_computed.layout().stride(0), -+ matrix_C_computed.imaginary_stride() -+ ); -+ -+ -+ // -+ // Check error code -+ // -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ EXPECT_EQ(result, cudaSuccess) -+ << " kernel error: " << cudaGetErrorString(result); -+ -+ matrix_C_computed.sync_host(); -+ -+ cutlass::reference::host::GemmPlanarComplex< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementAccumulator -+ >( -+ problem_size, -+ cutlass::complex(ElementAccumulator(1)), -+ matrix_A.host_ref(), -+ Mma::kTransformA, -+ matrix_B.host_ref(), -+ Mma::kTransformB, -+ cutlass::complex(ElementAccumulator(0)), -+ matrix_C_reference.host_ref(), -+ matrix_C_reference.host_ref() -+ ); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ matrix_C_computed.host_view(), -+ matrix_C_reference.host_view() -+ ); -+ -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ std::ofstream output("mma_pipelined_testbed_errors.txt"); -+ -+ output -+ << "A:\n" << matrix_A.host_view() << "\n" -+ << "B:\n" << matrix_B.host_view() << "\n" -+ << "Reference:\n" -+ << matrix_C_reference.host_view() << "\n" -+ << "Computed:\n" -+ << matrix_C_computed.host_view() << "\n"; -+ } -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace test -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/mma_singlestage_wmma_sm70.cu b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_singlestage_wmma_sm70.cu -new file mode 100644 -index 0000000..06a3ebb ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_singlestage_wmma_sm70.cu -@@ -0,0 +1,417 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED -+#include "mma_pipelined_testbed.h" -+#include "cutlass/gemm/threadblock/default_mma_core_wmma.h" -+ -+/// All tests use single staged (kStages=1) mma pipeline for the gemm mainloop -+/// Test name format: SM[arch]_gemm_threadblock_singlestage_wmma_[alayout]_[blayout]_[clayout]_[dtype].[threadblock_shape]_[warp_shape] -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////////// -+/// WMMA Floating point (f16 accumulation) - Single stage - Threadblock level tests //// -+/////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_gemm_threadblock_singlestage_wmma_tensor_op_row_col_row_f16, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 1; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 32); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+TEST(SM70_gemm_threadblock_singlestage_wmma_tensor_op_row_col_row_f16, 128x128x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 1; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 64); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+ -+TEST(SM70_gemm_threadblock_singlestage_wmma_tensor_op_row_col_row_f16, multicta_256x256x96_128x128x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 1; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 96); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+/////////////////////////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.col.m32n8k16.f16.f16 (wmma native size 32x8x16) -+/////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_gemm_threadblock_singlestage_wmma_tensor_op_row_col_row_f16, 64x64x32_64x64x32_32x8x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 1; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+////////////////////////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.col.m8n32k16.f16.f16 (wmma native size 8x32x16) -+////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_gemm_threadblock_singlestage_wmma_tensor_op_row_col_row_f16, 64x64x32_64x64x32_8x32x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = cutlass::half_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 1; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////////// -+/// WMMA Floating point (f32 accumulation) - Single stage - Threadblock level tests //// -+/////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+////////////////////////////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.col.m16n16k16.f32.f32 (wmma native size 16x16x16) -+////////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_gemm_threadblock_singlestage_wmma_tensor_op_row_col_row_f32, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 1; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+TEST(SM70_gemm_threadblock_singlestage_wmma_tensor_op_row_col_row_f32, 128x128x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 1; -+ -+ cutlass::gemm::GemmCoord problem_size(128, 128, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+TEST(SM70_gemm_threadblock_singlestage_wmma_tensor_op_row_col_row_f32, multicta_256x256x96_128x128x32_64x64x32_16x16x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 1; -+ -+ cutlass::gemm::GemmCoord problem_size(256, 256, 96); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(2, 2); -+ dim3 block(32, 4, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+/////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.col.m32n8k16.f32.f32 (wmma native size 32x8x16) -+//////////////////////////////////////////////////////////// -+TEST(SM70_gemm_threadblock_singlestage_wmma_tensor_op_row_col_row_f32, 64x64x32_64x64x32_32x8x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 1; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+///////////////////////////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.col.m8n32k16.f32.f32 (wmma native size 8x32x16) -+///////////////////////////////////////////////////////////////////////////////// -+TEST(SM70_gemm_threadblock_singlestage_wmma_tensor_op_row_col_row_f32, 64x64x32_64x64x32_8x32x16) { -+ -+ using ElementA = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::half_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = float; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 1; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, -+ ElementB, LayoutB, ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/threadblock/mma_singlestage_wmma_sm75.cu b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_singlestage_wmma_sm75.cu -new file mode 100644 -index 0000000..1b24ebe ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/threadblock/mma_singlestage_wmma_sm75.cu -@@ -0,0 +1,337 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#ifdef CUTLASS_ARCH_WMMA_SM75_ENABLED -+#include "mma_pipelined_testbed.h" -+#include "cutlass/gemm/threadblock/default_mma_core_wmma.h" -+ -+/// All tests use single staged (kStages=1) mma pipeline for the gemm mainloop -+/// Test name format: SM[arch]_gemm_threadblock_singlestage_wmma_tensor_op_[alayout]_[blayout]_[clayout]_[atype].[threadblock_shape]_[warp_shape]_[instruction_shape] -+ -+///////////////////////////////////////////////////////////////////////// -+/// Integer (s8 and u8) WMMA threadblock level tests //// -+///////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_ARCH_INTEGER_MATRIX_MULTIPLY_ENABLED) -+TEST(SM75_gemm_threadblock_singlestage_wmma_tensor_op_row_col_row_s8, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 1; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+TEST(SM75_gemm_threadblock_singlestage_wmma_tensor_op_row_col_row_s8, 64x64x64_64x64x64_16x16x16) { -+ -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 1; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+ -+TEST(SM75_gemm_threadblock_singlestage_wmma_tensor_op_col_row_row_s8, 64x64x32_64x64x32_16x16x16) { -+ -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 1; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+TEST(SM75_gemm_threadblock_singlestage_wmma_tensor_op_col_row_row_s8, 64x64x64_64x64x64_16x16x16) { -+ -+ using ElementA = int8_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using ElementB = int8_t; -+ using LayoutB = cutlass::layout::RowMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 1; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ -+ float alpha = 1.f; -+ float beta = 0.0f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadblockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+#endif //CUTLASS_ARCH_INTEGER_MATRIX_MULTIPLY_ENABLED -+ -+ -+//////////////////////////////////////////////////////////////////////// -+/// SUBBYTE (s4 and b1) WMMA threadblock level tests //// -+/////////////////////////////////////////////////////////////////////// -+ -+#if defined(CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED) -+ -+TEST(SM75_gemm_threadblock_singlestage_wmma_tensor_op_row_col_row_s4, 64x64x128_64x64x128_8x8x32) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 1; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 128); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+ -+TEST(SM75_gemm_threadblock_singlestage_wmma_tensor_op_row_col_col_s4, 64x64x64_64x64x64_8x8x32) { -+ using ElementA = cutlass::int4b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::int4b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ static const int kStages = 1; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 64); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+TEST(SM75_gemm_threadblock_singlestage_wmma_tensor_op_row_col_row_b1, 64x64x512_64x64x512_8x8x128) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::RowMajor; -+ static const int kStages = 1; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 2048); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages, -+ cutlass::arch::OpXorPopc>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+ -+TEST(SM75_gemm_threadblock_singlestage_wmma_tensor_op_row_col_col_b1, 64x64x512_64x64x512_8x8x128) { -+ using ElementA = cutlass::uint1b_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using ElementB = cutlass::uint1b_t; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using ElementC = int32_t; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ static const int kStages = 1; -+ -+ cutlass::gemm::GemmCoord problem_size(64, 64, 2048); -+ -+ using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ -+ float alpha = 1.f; -+ float beta = 0.f; -+ -+ // Define the MmaCore components -+ using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< -+ ThreadBlockShape, WarpShape, InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpClassWmmaTensorOp, kStages, -+ cutlass::arch::OpXorPopc>; -+ -+ dim3 grid(1, 1); -+ dim3 block(32, 1, 1); -+ -+ test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), -+ problem_size.k(), alpha, beta) -+ .run(grid, block); -+} -+#endif //CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED -+ -+#endif //CUTLASS_ARCH_WMMA_SM75_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/warp/gemm_complex_sm80.cu b/3rdparty/cutlass/test/unit/gemm/warp/gemm_complex_sm80.cu -new file mode 100644 -index 0000000..28410ad ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/warp/gemm_complex_sm80.cu -@@ -0,0 +1,698 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "cutlass/cutlass.h" -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+ -+#include "cutlass/gemm/warp/default_mma_complex_tensor_op.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+// complex * complex => complex -+// Input data type: complex -+// Math instruction: mma.sync.aligned.m8n8k4.f64.f64.f64.f64 -+// Output data type: complex -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_warp_gemm_complex_tensor_op_f64, 8x8x4_8x8x4_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<8, 8, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f64, 16x16x4_8x8x4_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<16, 16, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f64, 16x32x4_8x8x4_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<16, 32, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f64, 32x16x4_8x8x4_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 16, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f64, 32x32x4_8x8x4_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f64, 32x32x4_8x8x4_nh) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kConjugate -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f64, 32x32x4_8x8x4_ct) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kNone -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f64, 8x8x4_8x8x4_tn) { -+ -+ using Shape = cutlass::gemm::GemmShape<8, 8, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise128x4; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f64, 16x16x4_8x8x4_tn) { -+ -+ using Shape = cutlass::gemm::GemmShape<16, 16, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise128x4; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// complex * complex => complex -+// Input data type: complex -+// Math instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 -+// Output data type: complex -+// Shared memory layout: Congrous -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f32, 16x16x8_16x8x8_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<16, 16, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TransformedTestbedComplex< -+ MmaTensorOp, cutlass::gemm::GemmShape<16, 16, 8> >() -+ .run(); -+} -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f32, 16x16x16_16x8x8_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TransformedTestbedComplex< -+ MmaTensorOp, cutlass::gemm::GemmShape<16, 16, 16> >() -+ .run(); -+} -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f32, 16x32x8_16x8x8_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<16, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TransformedTestbedComplex< -+ MmaTensorOp, cutlass::gemm::GemmShape<16, 32, 8> >() -+ .run(); -+} -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f32, 32x16x8_16x16x8_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 16, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TransformedTestbedComplex< -+ MmaTensorOp, cutlass::gemm::GemmShape<32, 16, 8> >() -+ .run(); -+} -+ -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f32, 32x32x8_16x8x8_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TransformedTestbedComplex< -+ MmaTensorOp, cutlass::gemm::GemmShape<32, 32, 8> >() -+ .run(); -+} -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f32, 32x32x8_16x8x8_nh) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kConjugate -+ >::Type; -+ -+ test::gemm::warp::TransformedTestbedComplex< -+ MmaTensorOp, cutlass::gemm::GemmShape<32, 32, 8> >() -+ .run(); -+} -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f32, 32x32x8_16x8x8_ct) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kNone -+ >::Type; -+ -+ test::gemm::warp::TransformedTestbedComplex< -+ MmaTensorOp, cutlass::gemm::GemmShape<32, 32, 8> >() -+ .run(); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// complex * complex => complex -+// Input data type: complex -+// Math instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 -+// Output data type: complex -+// Shared memory layout: Crosswise -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM80_warp_gemm_complex_tensor_op_f32, 16x16x8_16x8x8_tn) { -+ -+ using Shape = cutlass::gemm::GemmShape<16, 16, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TransformedTestbedComplex< -+ MmaTensorOp, cutlass::gemm::GemmShape<16, 16, 8> >() -+ .run(); -+} -+ -+// TEST FAILS crosswise complex TN mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 test fails for k = 2*8 = 16 -+TEST(SM80_warp_gemm_complex_tensor_op_f32, 16x16x16_16x8x8_tn) { -+ -+ using Shape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TransformedTestbedComplex< -+ MmaTensorOp, cutlass::gemm::GemmShape<16, 16, 16> >() -+ .run(); -+} -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f32, 32x32x8_16x8x8_tn) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TransformedTestbedComplex< -+ MmaTensorOp, cutlass::gemm::GemmShape<32, 32, 8> >() -+ .run(); -+} -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f32, 32x64x8_16x8x8_tn) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 64, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TransformedTestbedComplex< -+ MmaTensorOp, cutlass::gemm::GemmShape<32, 64, 8> >() -+ .run(); -+} -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f32, 64x32x8_16x8x8_tn) { -+ -+ using Shape = cutlass::gemm::GemmShape<64, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TransformedTestbedComplex< -+ MmaTensorOp, cutlass::gemm::GemmShape<64, 32, 8> >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f64, 32x32x8_8x8x4_tn) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TransformedTestbedComplex< -+ MmaTensorOp, cutlass::gemm::GemmShape<32, 32, 8> >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_complex_tensor_op_f64, 32x32x8_8x8x4_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TransformedTestbedComplex< -+ MmaTensorOp, cutlass::gemm::GemmShape<32, 32, 8> >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/warp/gemm_complex_sm90.cu b/3rdparty/cutlass/test/unit/gemm/warp/gemm_complex_sm90.cu -new file mode 100644 -index 0000000..38bdfa6 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/warp/gemm_complex_sm90.cu -@@ -0,0 +1,334 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ -+ \brief Unit tests for thread-level GEMM with Hopper FP64 -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+ -+#include "cutlass/gemm/warp/default_mma_complex_tensor_op.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+TEST(SM90_warp_gemm_complex_tensor_op_f64, 16x8x4_16x8x4_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<16, 8, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM90_warp_gemm_complex_tensor_op_f64, 16x16x4_16x8x4_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<16, 16, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM90_warp_gemm_complex_tensor_op_f64, 16x32x4_16x8x4_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<16, 32, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM90_warp_gemm_complex_tensor_op_f64, 32x16x4_16x8x4_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 16, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM90_warp_gemm_complex_tensor_op_f64, 32x32x4_16x8x4_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM90_warp_gemm_complex_tensor_op_f64, 32x32x4_16x8x4_nh) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kConjugate -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM90_warp_gemm_complex_tensor_op_f64, 32x32x4_16x8x4_ct) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kNone -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM90_warp_gemm_complex_tensor_op_f64, 16x8x4_16x8x4_tn) { -+ -+ using Shape = cutlass::gemm::GemmShape<16, 8, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise128x4; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM90_warp_gemm_complex_tensor_op_f64, 16x16x4_16x8x4_tn) { -+ -+ using Shape = cutlass::gemm::GemmShape<16, 16, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise128x4; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM90_warp_gemm_complex_tensor_op_f64, 32x32x16_16x8x4_tn) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise128x4; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex().run(); -+} -+ -+TEST(SM90_warp_gemm_complex_tensor_op_f64, 64x64x4_16x8x4_tn) { -+ -+ using Shape = cutlass::gemm::GemmShape<64, 32, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise128x4; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex().run(); -+} -+ -+#endif // if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -diff --git a/3rdparty/cutlass/test/unit/gemm/warp/gemm_gaussian_complex_sm80.cu b/3rdparty/cutlass/test/unit/gemm/warp/gemm_gaussian_complex_sm80.cu -new file mode 100644 -index 0000000..e6f71ce ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/warp/gemm_gaussian_complex_sm80.cu -@@ -0,0 +1,287 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "cutlass/cutlass.h" -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+ -+#include "cutlass/gemm/warp/default_mma_complex_tensor_op.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_gaussian_complex_tensor_op, 8x8x4_8x8x4_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<8, 8, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM80_warp_gemm_gaussian_complex_tensor_op, 16x16x4_8x8x4_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<16, 16, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+ -+TEST(SM80_warp_gemm_gaussian_complex_tensor_op, 16x32x4_8x8x4_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<16, 32, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM80_warp_gemm_gaussian_complex_tensor_op, 32x16x4_8x8x4_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 16, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM80_warp_gemm_gaussian_complex_tensor_op, 32x32x4_8x8x4_nt) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM80_warp_gemm_gaussian_complex_tensor_op, 32x32x4_8x8x4_nh) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+TEST(SM80_warp_gemm_gaussian_complex_tensor_op, 32x32x4_8x8x4_ct) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kConjugate, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_gaussian_complex_tensor_op, 16x16x4_8x8x4_tn) { -+ -+ using Shape = cutlass::gemm::GemmShape<16, 16, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ -+ using Element = cutlass::complex; -+ using ElementC = cutlass::complex; -+ -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise128x4; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< -+ Shape, -+ InstructionShape, -+ Element, -+ LayoutA, -+ Element, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::ComplexTransform::kNone, -+ cutlass::ComplexTransform::kNone, -+ cutlass::arch::OpMultiplyAddGaussianComplex -+ >::Type; -+ -+ test::gemm::warp::TestbedComplex >().run(); -+} -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm50.cu b/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm50.cu -new file mode 100644 -index 0000000..5a9e2e2 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm50.cu -@@ -0,0 +1,654 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/complex.h" -+#include "cutlass/quaternion.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma_simt.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// NT SMEM layout -+TEST(SM50_warp_gemm_f32_col_row_col, 32x16x1_4x4x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<32, 16, 8>, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+// TN SMEM layout -+TEST(SM50_warp_gemm_f32_row_col_col, 32x16x1_4x4x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<32, 16, 8>, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+// TT SMEM layout -+TEST(SM50_warp_gemm_f32_row_row_col, 32x16x1_4x4x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<32, 16, 8>, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+// NN SMEM layout -+TEST(SM50_warp_gemm_f32_col_col_col, 32x16x1_4x4x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<32, 16, 8>, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// NT SMEM layout -+TEST(SM50_warp_gemm_f32_col_row_row, 16x32x1_4x4x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<16, 32, 8>, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+// TN SMEM layout -+TEST(SM50_warp_gemm_f32_row_col_row, 16x32x1_4x4x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<16, 32, 8>, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// NT SMEM layout -+TEST(SM50_warp_gemm_f32_col_row_col, 32x16x1_2x2x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<32, 16, 8>, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+ -+TEST(SM50_warp_gemm_f32_col_row_row, 32x16x1_2x2x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<32, 16, 8>, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+ -+// TN SMEM layout -+TEST(SM50_warp_gemm_f32_row_col_col, 32x16x1_2x2x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<32, 16, 8>, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+ -+TEST(SM50_warp_gemm_f32_row_col_row, 32x16x1_2x2x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<32, 16, 8>, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// NT SMEM layout -+TEST(SM50_warp_gemm_f32_col_row_col, 32x64x1_4x4x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+ -+TEST(SM50_warp_gemm_f32_col_row_row, 32x64x1_4x4x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+ -+// TN SMEM layout -+TEST(SM50_warp_gemm_f32_row_col_col, 32x64x1_4x4x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+ -+TEST(SM50_warp_gemm_f32_row_col_row, 32x64x1_4x4x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<4, 8>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<32, 64, 8>, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_warp_gemm_complex_f32_col_row_col, 64x32x1_2x2x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 1> -+ >; -+ -+ using complex_f32_t = cutlass::complex; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ complex_f32_t, -+ cutlass::layout::ColumnMajor, -+ complex_f32_t, -+ cutlass::layout::RowMajor, -+ complex_f32_t, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+ -+TEST(SM50_warp_gemm_complex_f32_col_row_row, 64x32x1_2x2x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 1> -+ >; -+ -+ using complex_f32_t = cutlass::complex; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ complex_f32_t, -+ cutlass::layout::ColumnMajor, -+ complex_f32_t, -+ cutlass::layout::RowMajor, -+ complex_f32_t, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_warp_gemm_f64_col_row_col, 8x4x1_1x1x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<1, 1, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<8, 4, 8>, -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+ -+TEST(SM50_warp_gemm_f64_col_row_row, 8x4x1_1x1x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<1, 1, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<8, 4, 8>, -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_warp_gemm_f64_col_row_col, 32x16x1_2x2x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<32, 16, 8>, -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+ -+TEST(SM50_warp_gemm_f64_col_row_row, 32x16x1_2x2x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<32, 16, 8>, -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_warp_gemm_f64_col_row_col, 64x32x1_2x2x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+ -+TEST(SM50_warp_gemm_f64_col_row_row, 64x32x1_2x2x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_warp_gemm_complex_f64_col_row_col, 32x16x1_1x1x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<1, 1, 1> -+ >; -+ -+ using complex_f64_t = cutlass::complex; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<64, 16, 8>, -+ complex_f64_t, -+ cutlass::layout::ColumnMajor, -+ complex_f64_t, -+ cutlass::layout::RowMajor, -+ complex_f64_t, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+ -+TEST(SM50_warp_gemm_complex_f64_col_row_row, 32x16x1_1x1x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::RowMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<1, 1, 1> -+ >; -+ -+ using complex_f64_t = cutlass::complex; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<64, 16, 8>, -+ complex_f64_t, -+ cutlass::layout::ColumnMajor, -+ complex_f64_t, -+ cutlass::layout::RowMajor, -+ complex_f64_t, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_warp_gemm_quaternion_f32_col_row_col, 16x8x8_1x1x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<1, 1, 1> -+ >; -+ -+ using quaternion_f32_t = cutlass::Quaternion; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ quaternion_f32_t, -+ cutlass::layout::ColumnMajor, -+ quaternion_f32_t, -+ cutlass::layout::RowMajor, -+ quaternion_f32_t, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_warp_gemm_quaternion_f32_col_row_row, 16x8x8_1x1x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<1, 1, 1> -+ >; -+ -+ using quaternion_f32_t = cutlass::Quaternion; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ quaternion_f32_t, -+ cutlass::layout::ColumnMajor, -+ quaternion_f32_t, -+ cutlass::layout::RowMajor, -+ quaternion_f32_t, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed>().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm60.cu b/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm60.cu -new file mode 100644 -index 0000000..03ba3ea ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm60.cu -@@ -0,0 +1,140 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma_simt.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM60_warp_gemm_f16_col_row, 8x4x1_1x1x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<1, 1, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<8, 4, 8>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM60_warp_gemm_f16_col_row, 16x8x1_2x2x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<16, 8, 8>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM60_warp_gemm_f16_col_row, 32x16x1_4x4x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<32, 16, 8>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM60_warp_gemm_f16_col_row, 64x16x1_8x4x1) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<8, 8, 1> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<64, 32, 8>, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ cutlass::half_t, -+ cutlass::layout::RowMajor, -+ cutlass::half_t, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm61.cu b/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm61.cu -new file mode 100644 -index 0000000..c042b5b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm61.cu -@@ -0,0 +1,198 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/warp/mma_simt.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM61_warp_gemm_int8_col_row, col_row_8x4x8_1x1x4) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<1, 1, 4> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<8, 4, 8>, -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<4>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<4>, -+ int, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+TEST(SM61_warp_gemm_int8_col_row, col_row_8x4x4_1x1x4) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<1, 1, 4> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<8, 4, 8>, -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<4>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<4>, -+ int, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+TEST(SM61_warp_gemm_int8_col_row, col_row_16x4x4_2x1x4) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 1, 4> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<16, 4, 4>, -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<4>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<4>, -+ int, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+TEST(SM61_warp_gemm_int8_col_row, col_row_16x4x4_2x2x4) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<2, 2, 4> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<16, 8, 4>, -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<4>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<4>, -+ int, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+TEST(SM61_warp_gemm_int8_col_row, col_row_32x16x4_4x4x4) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 4> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<32, 16, 16>, -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<4>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<4>, -+ int, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+ -+TEST(SM61_warp_gemm_int8_col_row, col_row_128x64x4_16x16x4) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<16, 16, 4> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<128, 64, 4>, -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<4>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<4>, -+ int, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+TEST(SM61_warp_gemm_int8_col_row, col_row_64x64x4_4x4x4) { -+ -+ using Policy = cutlass::gemm::warp::MmaSimtPolicy< -+ cutlass::MatrixShape<8, 4>, -+ cutlass::layout::ColumnMajorInterleaved<2>, -+ cutlass::gemm::GemmShape<4, 4, 4> -+ >; -+ -+ using Mma = cutlass::gemm::warp::MmaSimt< -+ cutlass::gemm::GemmShape<64, 64, 8>, -+ int8_t, -+ cutlass::layout::ColumnMajorInterleaved<4>, -+ int8_t, -+ cutlass::layout::RowMajorInterleaved<4>, -+ int, -+ cutlass::layout::ColumnMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm70.cu b/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm70.cu -new file mode 100644 -index 0000000..6785ddb ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm70.cu -@@ -0,0 +1,295 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+ -+#include "cutlass/gemm/warp/mma_tensor_op_sm70.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_warp_gemm_tensor_op_congruous, 128x128x16_64x64x16_16x16x4) { -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using MmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ Shape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+TEST(SM70_warp_gemm_tensor_op_congruous, 128x64x4_64x64x4_16x16x4) { -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using MmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ Shape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+TEST(SM70_warp_gemm_tensor_op_congruous, 128x128x4_32x32x4_16x16x4) { -+ -+ using Shape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; -+ using LayoutB = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::ColumnMajor, -+ ElementB, -+ cutlass::layout::RowMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using MmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ Shape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+TEST(SM70_warp_gemm_tensor_op_crosswise, 64x64x32_64x64x32_16x16x4) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::RowMajor, -+ ElementB, -+ cutlass::layout::ColumnMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using MmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ Shape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM70_warp_gemm_volta_tensor_op_canonical_f32_row_col, 64x64x16_64x64x4_8x8x4) { -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ cutlass::layout::RowMajor, -+ ElementB, -+ cutlass::layout::ColumnMajor, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using MmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ Shape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+TEST(SM70_warp_gemm_volta_tensor_op_canonical_f32_col_row, 64x64x16_64x64x4_8x8x4) { -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ -+ using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< -+ cutlass::arch::Mma< -+ cutlass::gemm::GemmShape<16, 16, 4>, -+ 32, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ cutlass::arch::OpMultiplyAdd -+ >, -+ cutlass::MatrixShape<1, 1> -+ >; -+ -+ using MmaTensorOp = cutlass::gemm::warp::MmaVoltaTensorOp< -+ Shape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ cutlass::layout::RowMajor, -+ Policy -+ >; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#endif // CUTLASS_ARCH_MMA_SM70_SUPPORTED -diff --git a/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm75.cu b/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm75.cu -new file mode 100644 -index 0000000..43f185d ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm75.cu -@@ -0,0 +1,860 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+ -+#include "cutlass/gemm/warp/default_mma_tensor_op.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_congruous_f16, 128x128x8_32x128x8_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<32, 128, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_congruous_f16, 128x128x32_64x64x32_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_congruous_f16, 128x128x32_32x32x32_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_f16, 128x128x32_64x64x32_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_f16, 128x128x32_64x32x32_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_f16, 128x128x32_32x32x32_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_f16, 128x128x32_32x16x32_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_f16, 128x128x32_16x16x32_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_f16, 128x128x64_64x64x64_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_f16, 128x128x64_64x32x64_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_f16, 128x128x64_32x32x64_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_f16, 128x128x64_32x16x64_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_f16, 128x128x64_16x16x64_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_i8, 128x128x64_64x64x64_8x8x16) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_i8, 128x128x64_64x32x64_8x8x16) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_i8, 128x128x64_32x32x64_8x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_i8, 128x128x64_32x16x64_8x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_i8, 128x128x64_16x16x64_8x8x16) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_interleaved_i8, 128x128x64_64x64x64_8x8x16) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_interleaved_i8, 128x128x64_64x32x64_8x8x16) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_interleaved_i8, 128x128x64_32x32x64_8x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_interleaved_i8, 128x128x64_32x16x64_8x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_interleaved_i8, 128x128x64_16x16x64_8x8x16) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_i4, 128x128x128_64x64x128_8x8x32) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_i4, 128x128x128_64x32x128_8x8x32) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_i4, 128x128x128_32x32x128_8x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_i4, 128x128x128_32x16x128_8x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_i4, 128x128x128_16x16x128_8x8x32) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_interleaved_i4, 128x128x128_64x64x128_8x8x32) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_interleaved_i4, 128x128x128_64x32x128_8x8x32) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_interleaved_i4, 128x128x128_32x32x128_8x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_interleaved_i4, 128x128x128_32x16x128_8x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_interleaved_i4, 128x128x128_16x16x128_8x8x32) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_b1, 128x128x512_64x64x512_8x8x128) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ using Element = cutlass::uint1b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpXorPopc>::Type; -+ -+ test::gemm::warp::Testbed, -+ cutlass::arch::OpXorPopc>() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_b1, 128x128x512_64x32x512_8x8x128) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ using Element = cutlass::uint1b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpXorPopc>::Type; -+ -+ test::gemm::warp::Testbed, -+ cutlass::arch::OpXorPopc>() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_b1, 128x128x512_32x32x512_8x8x128) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ using Element = cutlass::uint1b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpXorPopc>::Type; -+ -+ test::gemm::warp::Testbed, -+ cutlass::arch::OpXorPopc>() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_b1, 128x128x512_32x16x512_8x8x128) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ using Element = cutlass::uint1b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpXorPopc>::Type; -+ -+ test::gemm::warp::Testbed, -+ cutlass::arch::OpXorPopc>() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_warp_gemm_tensor_op_crosswise_b1, 128x128x512_16x16x512_8x8x128) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ using Element = cutlass::uint1b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpXorPopc>::Type; -+ -+ test::gemm::warp::Testbed, -+ cutlass::arch::OpXorPopc>() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif -diff --git a/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm80.cu b/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm80.cu -new file mode 100644 -index 0000000..54a0248 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm80.cu -@@ -0,0 +1,1865 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+ -+#include "cutlass/gemm/warp/default_mma_tensor_op.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_f16, 128x128x32_64x64x32_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_f16, 128x128x32_64x32x32_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_f16, 128x128x32_32x32x32_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_f16, 128x128x32_32x16x32_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_f16, 128x128x32_16x16x32_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_f16, 128x128x64_64x64x64_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_f16, 128x128x64_64x32x64_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_f16, 128x128x64_32x32x64_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_f16, 128x128x64_32x16x64_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_f16, 128x128x64_16x16x64_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_tf32, 128x128x16_64x64x16_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 16>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 16>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_tf32, 128x128x16_64x32x16_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 16>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 16>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_tf32, 128x128x16_32x32x16_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 16>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 16>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_tf32, 128x128x16_32x16x16_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 16>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 16>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_tf32, 128x128x16_16x16x16_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 16>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 16>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_tf32, 128x128x32_64x64x32_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_tf32, 128x128x32_64x32x32_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_tf32, 128x128x32_32x32x32_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_tf32, 128x128x32_32x16x32_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_tf32, 128x128x32_16x16x32_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_congruous_f16, 128x128x32_64x64x32_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_congruous_f16, 128x128x32_32x32x32_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_congruous_f16, 128x128x64_64x64x64_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_congruous_f16, 128x128x64_32x32x64_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_congruous_tf32, 128x128x16_64x64x16_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_congruous_tf32, 128x128x16_32x32x16_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_congruous_tf32, 128x128x32_64x64x32_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_congruous_tf32, 128x128x32_32x32x32_16x8x8) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_tn, tf32_round_128x128x32_64x64x32_16x8x8) { -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = float; -+ using ElementC = float; -+ -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::TransformTestbed >() -+ .run(); -+} -+ -+TEST(SM80_warp_gemm_tensor_op_nt, tf32_round_128x128x32_64x64x32_16x8x8) { -+ -+ using Shape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = float; -+ using ElementC = float; -+ -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::TransformTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x64_16x16x64_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x64_32x16x64_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x64_32x32x64_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x64_64x32x64_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x64_64x64x64_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x64_16x16x64_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x64_32x16x64_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x64_32x32x64_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x64_64x32x64_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x64_64x64x64_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i8, 128x128x64_64x64x64_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i8, 128x128x64_64x32x64_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i8, 128x128x64_32x32x64_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i8, 128x128x64_32x16x64_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i8, 128x128x64_16x16x64_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i8, 128x128x128_64x64x128_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i8, 128x128x128_64x32x128_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i8, 128x128x128_32x32x128_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i8, 128x128x128_32x16x128_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i8, 128x128x128_16x16x128_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = int8_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i4, 128x128x128_64x64x128_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i4, 128x128x128_64x32x128_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i4, 128x128x128_32x32x128_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i4, 128x128x128_32x16x128_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i4, 128x128x128_16x16x128_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i4, 128x128x256_64x64x256_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i4, 128x128x256_64x32x256_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i4, 128x128x256_32x32x256_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i4, 128x128x256_32x16x256_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_i4, 128x128x256_16x16x256_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_b1, 128x128x512_64x64x512_16x8x256) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ using Element = cutlass::uint1b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_b1, 128x128x512_64x32x512_16x8x256) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ using Element = cutlass::uint1b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_b1, 128x128x512_32x32x512_16x8x256) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ using Element = cutlass::uint1b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_b1, 128x128x512_32x16x512_16x8x256) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ using Element = cutlass::uint1b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_b1, 128x128x512_16x16x512_16x8x256) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ using Element = cutlass::uint1b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 512>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_b1, 128x128x1024_64x64x1024_16x8x256) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 1024>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ using Element = cutlass::uint1b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 1024>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 1024>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_b1, 128x128x1024_64x32x1024_16x8x256) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 1024>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ using Element = cutlass::uint1b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 1024>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 1024>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_b1, 128x128x1024_32x32x1024_16x8x256) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 1024>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ using Element = cutlass::uint1b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 1024>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 1024>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_b1, 128x128x1024_32x16x1024_16x8x256) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 1024>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ using Element = cutlass::uint1b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 1024>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 1024>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_b1, 128x128x1024_16x16x1024_16x8x256) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 1024>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; -+ using Element = cutlass::uint1b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 1024>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 1024>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_congruous_f64, 16x16x4_16x16x4_8x8x4) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ using Element = double; -+ using ElementC = double; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_congruous_f64, 32x16x4_32x16x4_8x8x4) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ using Element = double; -+ using ElementC = double; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_congruous_f64, 32x32x4_32x32x4_8x8x4) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ using Element = double; -+ using ElementC = double; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_congruous_f64, 32x64x4_32x64x4_8x8x4) { -+ using Shape = cutlass::gemm::GemmShape<32, 64, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ using Element = double; -+ using ElementC = double; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_f64, 16x16x16_16x16x16_8x8x4) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ using Element = double; -+ using ElementC = double; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_f64, 32x32x16_32x32x16_8x8x4) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ using Element = double; -+ using ElementC = double; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_f64, 64x32x16_64x32x16_8x8x4) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ using Element = double; -+ using ElementC = double; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_crosswise_f64, 32x64x16_32x64x16_8x8x4) { -+ using Shape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ using Element = double; -+ using ElementC = double; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x128_16x16x128_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x128_32x16x128_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x128_32x32x128_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x128_64x32x128_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x128_64x64x128_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_canonical_f64_row_col, 32x32x8_64x32x8_8x8x4) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ using Element = double; -+ using ElementC = double; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_canonical_f64_col_row, 32x32x8_64x32x8_8x8x4) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; -+ using Element = double; -+ using ElementC = double; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_canonical_tf32_row_col, 32x32x8_64x32x8_8x8x4) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_tensor_op_canonical_tf32_col_row, 32x32x8_64x32x8_8x8x4) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 8>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -+ -+ -diff --git a/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm90.cu b/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm90.cu -new file mode 100644 -index 0000000..f417a41 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/warp/gemm_sm90.cu -@@ -0,0 +1,206 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ -+ \brief Unit tests for thread-level GEMM with Hopper FP64 -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+ -+#include "cutlass/gemm/warp/default_mma_tensor_op.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -+ -+TEST(SM90_warp_gemm_tensor_op_congruous_f64, 16x16x4_16x16x4_16x8x4) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ using Element = double; -+ using ElementC = double; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_warp_gemm_tensor_op_congruous_f64, 32x16x4_32x16x4_16x8x4) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ using Element = double; -+ using ElementC = double; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_warp_gemm_tensor_op_congruous_f64, 32x32x4_32x32x4_16x8x4) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ using Element = double; -+ using ElementC = double; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_warp_gemm_tensor_op_congruous_f64, 32x64x4_32x64x4_16x8x4) { -+ using Shape = cutlass::gemm::GemmShape<32, 64, 4>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ using Element = double; -+ using ElementC = double; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_warp_gemm_tensor_op_crosswise_f64, 16x16x16_16x16x16_16x8x4) { -+ using Shape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ using Element = double; -+ using ElementC = double; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_warp_gemm_tensor_op_crosswise_f64, 32x32x16_32x32x16_16x8x4) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ using Element = double; -+ using ElementC = double; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_warp_gemm_tensor_op_crosswise_f64, 64x32x16_64x32x16_16x8x4) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ using Element = double; -+ using ElementC = double; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM90_warp_gemm_tensor_op_crosswise_f64, 32x64x16_32x64x16_16x8x4) { -+ using Shape = cutlass::gemm::GemmShape<32, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; -+ using Element = double; -+ using ElementC = double; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; -+ -+ test::gemm::warp::Testbed >() -+ .run(); -+} -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) -diff --git a/3rdparty/cutlass/test/unit/gemm/warp/gemm_sparse_sm80.cu b/3rdparty/cutlass/test/unit/gemm/warp/gemm_sparse_sm80.cu -new file mode 100644 -index 0000000..af87ee7 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/warp/gemm_sparse_sm80.cu -@@ -0,0 +1,1107 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+ -+#include "cutlass/gemm/warp/default_mma_sparse_tensor_op.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_f16, 128x128x64_64x64x64_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_f16, 128x128x64_64x32x64_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_f16, 128x128x64_32x64x64_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_f16, 128x128x64_32x32x64_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_f16, 128x128x64_32x16x64_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_f16, 128x64x128_64x32x128_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_f16, 64x128x128_32x64x128_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_f16, 64x64x128_32x32x128_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_f16, 64x32x128_32x16x128_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_congruous_f16, 128x128x64_64x64x64_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_congruous_f16, 128x128x64_64x32x64_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_congruous_f16, 128x128x64_32x64x64_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_congruous_f16, 128x128x64_32x32x64_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_congruous_f16, 128x64x128_64x32x128_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_congruous_f16, 64x128x128_32x64x128_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_congruous_f16, 64x64x128_32x32x128_16x8x32) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -+ using Element = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s8, 128x128x128_64x64x128_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = int8_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s8, 128x128x128_64x32x128_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = int8_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s8, 128x128x128_32x64x128_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<32, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = int8_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s8, 128x128x128_32x32x128_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = int8_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s8, 128x128x128_32x16x128_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = int8_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 64>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s8, 128x64x256_64x32x256_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = int8_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s8, 64x128x256_32x64x256_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<32, 64, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = int8_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s8, 64x64x256_32x32x256_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = int8_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s8, 64x32x256_32x16x256_16x8x64) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; -+ using Element = int8_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_tf32, 128x128x32_64x64x32_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 16>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_tf32, 128x128x32_64x32x32_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 16>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_tf32, 128x128x32_32x64x32_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 16>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_tf32, 128x128x32_32x32x32_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 16>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_tf32, 128x128x32_32x16x32_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 16>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_tf32, 128x64x256_64x32x256_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_tf32, 64x128x64_32x64x64_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_tf32, 64x64x64_32x32x64_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_tf32, 64x32x64_32x16x64_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_congruous_tf32, 128x128x32_64x64x32_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_congruous_tf32, 128x128x32_64x32x32_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_congruous_tf32, 128x128x32_32x64x32_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_congruous_tf32, 128x128x32_32x32x32_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_congruous_tf32, 128x64x64_64x32x64_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_congruous_tf32, 64x128x64_32x64x64_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 64, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_congruous_tf32, 64x64x64_32x32x64_16x8x16) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -+ using Element = cutlass::tfloat32_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 32>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s4, 128x128x256_64x64x256_16x8x128) { -+ using Shape = cutlass::gemm::GemmShape<64, 64, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s4, 128x128x256_64x32x256_16x8x128) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s4, 128x128x256_32x64x256_16x8x128) { -+ using Shape = cutlass::gemm::GemmShape<32, 64, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s4, 128x128x256_32x32x256_16x8x128) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s4, 128x128x256_32x16x256_16x8x128) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 256>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 128>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s4, 128x64x512_64x32x512_16x8x128) { -+ using Shape = cutlass::gemm::GemmShape<64, 32, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s4, 64x128x512_32x64x512_16x8x128) { -+ using Shape = cutlass::gemm::GemmShape<32, 64, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s4, 64x64x512_32x32x512_16x8x128) { -+ using Shape = cutlass::gemm::GemmShape<32, 32, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s4, 64x32x512_32x16x512_16x8x128) { -+ using Shape = cutlass::gemm::GemmShape<32, 16, 512>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; -+ using Element = cutlass::int4b_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< -+ cutlass::sizeof_bits::value, 256>; -+ -+ using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< -+ Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, -+ cutlass::layout::RowMajor>::Type; -+ -+ test::gemm::warp::SparseTestbed >() -+ .run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+#endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) -diff --git a/3rdparty/cutlass/test/unit/gemm/warp/testbed.h b/3rdparty/cutlass/test/unit/gemm/warp/testbed.h -new file mode 100644 -index 0000000..3487aa0 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/warp/testbed.h -@@ -0,0 +1,1554 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/subbyte_reference.h" -+#include "cutlass/platform/platform.h" -+#include "cutlass/arch/arch.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/distribution.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/gemm_complex.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/host_reorder.h" -+#include "cutlass/util/host_uncompress.h" -+ -+namespace test { -+namespace gemm { -+namespace warp { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test kernel -+template -+__global__ void kernel( -+ typename Mma::ElementC *output_C, -+ typename Mma::ElementA const *input_A, -+ typename Mma::ElementB const *input_B, -+ typename Mma::ElementC const *input_C, -+ int iterations = 1) { -+ -+ // Use AlignedBuffer to store trivially copyable objects in unions and __shared__ buffers. -+ __shared__ cutlass::AlignedBuffer< -+ typename Mma::ElementA, ThreadblockShape::kM * ThreadblockShape::kK> smem_buffer_A; -+ -+ __shared__ cutlass::AlignedBuffer< -+ typename Mma::ElementB, ThreadblockShape::kN * ThreadblockShape::kK> smem_buffer_B; -+ -+ if (threadIdx.x == 0) { -+ typename Mma::ElementA *smem_ptr_A = smem_buffer_A.data(); -+ #pragma unroll 1 -+ for (int i = 0; i < smem_buffer_A.size(); ++i) { -+ cutlass::ReferenceFactory::get(smem_ptr_A, i) = -+ cutlass::ReferenceFactory::type>::get(input_A, i); -+ } -+ -+ typename Mma::ElementB *smem_ptr_B = smem_buffer_B.data(); -+ #pragma unroll 1 -+ for (int i = 0; i < smem_buffer_B.size(); ++i) { -+ cutlass::ReferenceFactory::get(smem_ptr_B, i) = -+ cutlass::ReferenceFactory::type>::get(input_B, i); -+ } -+ } -+ -+ __syncthreads(); -+ -+ // -+ // Construct warp-level matrix product -+ // -+ -+ using FragmentA = typename Mma::FragmentA; -+ using FragmentB = typename Mma::FragmentB; -+ using FragmentC = typename Mma::FragmentC; -+ -+ typename Mma::LayoutA layout_A = Mma::LayoutA::packed({ThreadblockShape::kM, ThreadblockShape::kK}); -+ typename Mma::LayoutB layout_B = Mma::LayoutB::packed({ThreadblockShape::kK, ThreadblockShape::kN}); -+ typename Mma::LayoutC layout_C = Mma::LayoutC::packed({Mma::Shape::kM, Mma::Shape::kN}); -+ -+ typename Mma::IteratorA iter_A({smem_buffer_A.data(), layout_A}, cutlass::arch::LaneId()); -+ -+ typename Mma::IteratorB iter_B({smem_buffer_B.data(), layout_B}, cutlass::arch::LaneId()); -+ -+ FragmentA frag_A; -+ FragmentB frag_B; -+ -+ FragmentC accum; -+ -+ Mma mma; -+ -+ accum.clear(); -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int iter = 0; iter < iterations; ++iter) { // place in loop that is not unrolled -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < ThreadblockShape::kK; -+ k += Mma::Policy::MmaShape::kK) { -+ iter_A.load(frag_A); -+ iter_B.load(frag_B); -+ -+ ++iter_A; -+ ++iter_B; -+ -+ mma(accum, frag_A, frag_B, accum); -+ } -+ } -+ -+ typename Mma::IteratorC iter_C({output_C, layout_C}, cutlass::arch::LaneId()); -+ -+ iter_C.store(accum); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product -+template < -+ /// Warp-level matrix multiply-accumulate -+ typename Mma_, -+ /// Size of threadblock-scoped shape used to store SMEM -+ typename ThreadblockShape_, -+ /// The inner product operation performed by GEMM -+ typename Operator_ = cutlass::arch::OpMultiplyAdd -+> -+struct Testbed { -+ -+ /// Thread-level matrix multiply-accumulate operator -+ using Mma = Mma_; -+ using ThreadblockShape = ThreadblockShape_; -+ using Operator = Operator_; -+ -+ using Shape = typename Mma::Shape; -+ using ElementA = typename Mma::ElementA; -+ using LayoutA = typename Mma::LayoutA; -+ using ElementB = typename Mma::ElementB; -+ using LayoutB = typename Mma::LayoutB; -+ using ElementC = typename Mma::ElementC; -+ using LayoutC = typename Mma::LayoutC; -+ -+ // -+ // Data members -+ // -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D_computed; -+ cutlass::HostTensor tensor_D_reference; -+ -+ // -+ // Methods -+ // -+ -+ /// Allocates workspace in device memory -+ Testbed() { -+ -+ tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); -+ tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); -+ tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); -+ tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); -+ tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); -+ } -+ -+ /// Returns true if the CUDA device is sufficient to execute the kernel. -+ bool sufficient() const { -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.major == 9) { -+ // NVIDIA Hopper drops support for several data types -+ if ( -+ cutlass::sizeof_bits::value < 8 || -+ cutlass::sizeof_bits::value < 8 || -+ cutlass::sizeof_bits::value < 8) { -+ -+ return false; -+ } -+ } -+ -+ return true; -+ } -+ -+ -+ /// Runs the test -+ bool run( -+ cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { -+ -+ if (!sufficient()) { -+ return true; -+ } -+ -+ // -+ // initialize device memory -+ // -+ -+ if (init_A == cutlass::Distribution::Uniform) { -+ int scope_max = 8; -+ int scope_min = -8; -+ -+ if (cutlass::sizeof_bits::value == 4) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (cutlass::sizeof_bits::value == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } -+ -+ uint64_t seed = 7; -+ -+ cutlass::reference::host::BlockFillRandomUniform(tensor_A.host_data(), -+ tensor_A.capacity(), seed, scope_max, scope_min, 0); -+ -+ } else if (init_A == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), -+ tensor_A.capacity()); -+ } else if (init_A == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ if (init_B == cutlass::Distribution::Uniform) { -+ int scope_max = 8; -+ int scope_min = -8; -+ -+ if (cutlass::sizeof_bits::value == 4) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (cutlass::sizeof_bits::value == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } -+ -+ uint64_t seed = 7; -+ -+ cutlass::reference::host::BlockFillRandomUniform(tensor_B.host_data(), -+ tensor_B.capacity(), seed, scope_max, scope_min, 0); -+ -+ } else if (init_B == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), -+ tensor_B.capacity()); -+ } else if (init_B == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ cutlass::reference::host::TensorFill( -+ tensor_C.host_view(), -+ ElementC(0) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D_computed.host_view(), -+ ElementC(0) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D_reference.host_view(), -+ ElementC(0) -+ ); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D_computed.sync_device(); -+ -+ // launch kernel -+ kernel<<< dim3(1, 1), dim3(32, 1, 1) >>>( -+ tensor_D_computed.device_data(), -+ tensor_A.device_data(), -+ tensor_B.device_data(), -+ tensor_C.device_data()); -+ -+ // verify no errors -+ cudaError_t result = cudaDeviceSynchronize(); -+ -+ EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); -+ if (result != cudaSuccess) { -+ return false; -+ } -+ -+ tensor_D_computed.sync_host(); -+ -+ // -+ // Reference implementation -+ // -+ -+ cutlass::reference::host::Gemm -+ reference_gemm; -+ -+ reference_gemm( -+ {Shape::kM, Shape::kN, ThreadblockShape::kK}, -+ ElementC(1), -+ tensor_A.host_ref(), -+ tensor_B.host_ref(), -+ ElementC(0), -+ tensor_D_reference.host_ref() -+ ); -+ -+ // -+ // Verify equivalence -+ // -+ -+ // compare -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_D_computed.host_view(), -+ tensor_D_reference.host_view() -+ ); -+ -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ -+ cutlass::TensorView tensor_A_physical( -+ tensor_A.host_data(), -+ tensor_A.stride()[0], -+ tensor_A.extent()); -+ -+ cutlass::TensorView tensor_B_physical( -+ tensor_B.host_data(), -+ tensor_B.stride()[0], -+ tensor_B.extent()); -+ -+ std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; -+ std::cout -+ << "A:\n" << tensor_A.host_view() << "\n\n" -+ << "A(physical - stride: " << tensor_A.stride()[0] -+ << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; -+ -+ std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; -+ std::cout -+ << "B:\n" << tensor_B.host_view() << "\n\n" -+ << "B(physical - stride: " << tensor_B.stride()[0] -+ << ", extent: " << tensor_B.extent() << "):\n" << tensor_B_physical << "\n\n"; -+ -+ std::cout -+ << "C:\n" << tensor_C.host_view() << "\n\n" -+ << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" -+ << "Computed:\n" << tensor_D_computed.host_view() << std::endl; -+ } -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product -+template < -+ /// Warp-level matrix multiply-accumulate -+ typename Mma_, -+ /// Size of threadblock-scoped shape used to store SMEM -+ typename ThreadblockShape_ -+> -+struct TestbedComplex { -+ -+ /// Thread-level matrix multiply-accumulate operator -+ using Mma = Mma_; -+ using ThreadblockShape = ThreadblockShape_; -+ -+ using Shape = typename Mma::Shape; -+ using ElementA = typename Mma::ElementA; -+ using LayoutA = typename Mma::LayoutA; -+ using ElementB = typename Mma::ElementB; -+ using LayoutB = typename Mma::LayoutB; -+ using ElementC = typename Mma::ElementC; -+ using LayoutC = typename Mma::LayoutC; -+ -+ // -+ // Data members -+ // -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D_computed; -+ cutlass::HostTensor tensor_D_reference; -+ -+ // -+ // Methods -+ // -+ -+ /// Allocates workspace in device memory -+ TestbedComplex() { -+ -+ tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); -+ tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); -+ tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); -+ tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); -+ tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); -+ } -+ -+ /// Returns true if the CUDA device is sufficient to execute the kernel. -+ bool sufficient() const { -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.major == 9) { -+ // NVIDIA Hopper drops support for several data types -+ if ( -+ cutlass::sizeof_bits::value < 8 || -+ cutlass::sizeof_bits::value < 8 || -+ cutlass::sizeof_bits::value < 8) { -+ -+ return false; -+ } -+ } -+ -+ return true; -+ } -+ -+ /// Runs the test -+ bool run( -+ cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { -+ -+ if (!sufficient()) { -+ return true; -+ } -+ -+ // -+ // initialize device memory -+ // -+ -+ if (init_A == cutlass::Distribution::Uniform) { -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform(tensor_A.host_view(), -+ seed, 8, -8, 0); -+ } else if (init_A == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), -+ tensor_A.capacity()); -+ } else if (init_A == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ if (init_B == cutlass::Distribution::Uniform) { -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform(tensor_B.host_view(), -+ seed + 16, 8, -8, 0); -+ } else if (init_B == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), -+ tensor_B.capacity()); -+ } else if (init_B == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ cutlass::reference::host::TensorFill( -+ tensor_C.host_view(), -+ ElementC(0) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D_computed.host_view(), -+ ElementC(0) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D_reference.host_view(), -+ ElementC(0) -+ ); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D_computed.sync_device(); -+ -+ // launch kernel -+ kernel<<< dim3(1, 1), dim3(32, 1, 1) >>>( -+ tensor_D_computed.device_data(), -+ tensor_A.device_data(), -+ tensor_B.device_data(), -+ tensor_C.device_data()); -+ -+ // verify no errors -+ cudaError_t result = cudaDeviceSynchronize(); -+ -+ EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); -+ if (result != cudaSuccess) { -+ return false; -+ } -+ -+ tensor_D_computed.sync_host(); -+ -+ // -+ // Reference implementation -+ // -+ -+ cutlass::reference::host::GemmComplex( -+ {Shape::kM, Shape::kN, ThreadblockShape::kK}, -+ ElementC(1), -+ tensor_A.host_ref(), -+ Mma::kTransformA, -+ tensor_B.host_ref(), -+ Mma::kTransformB, -+ ElementC(0), -+ tensor_C.host_ref(), -+ tensor_D_reference.host_ref() -+ ); -+ -+ // -+ // Verify equivalence -+ // -+ -+ // compare -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_D_computed.host_view(), -+ tensor_D_reference.host_view() -+ ); -+ -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ -+ cutlass::TensorView tensor_A_physical( -+ tensor_A.host_data(), -+ tensor_A.stride()[0], -+ tensor_A.extent()); -+ -+ cutlass::TensorView tensor_B_physical( -+ tensor_B.host_data(), -+ tensor_B.stride()[0], -+ tensor_B.extent()); -+ -+ std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; -+ std::cout -+ << "A:\n" << tensor_A.host_view() << "\n\n" -+ << "A(physical - stride: " << tensor_A.stride()[0] << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; -+ -+ std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; -+ std::cout -+ << "B:\n" << tensor_B.host_view() << "\n\n" -+ << "B(physical - stride: " << tensor_B.stride()[0] << ", extent: " << tensor_B.extent() <<"):\n" << tensor_B_physical << "\n\n"; -+ -+ std::cout -+ << "C:\n" << tensor_C.host_view() << "\n\n" -+ << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" -+ << "Computed:\n" << tensor_D_computed.host_view() << std::endl; -+ } -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test kernel -+template -+__global__ void kernel_transform( -+ typename Mma::ElementC *output_C, -+ typename Mma::ElementA const *input_A, -+ typename Mma::ElementB const *input_B, -+ typename Mma::ElementC const *input_C, -+ int iterations = 1) { -+ -+ // Use AlignedBuffer to store trivially copyable objects in unions and __shared__ buffers. -+ __shared__ cutlass::AlignedBuffer< -+ typename Mma::ElementA, ThreadblockShape::kM * ThreadblockShape::kK> smem_buffer_A; -+ -+ __shared__ cutlass::AlignedBuffer< -+ typename Mma::ElementB, ThreadblockShape::kN * ThreadblockShape::kK> smem_buffer_B; -+ -+ if (threadIdx.x == 0) { -+ typename Mma::ElementA *smem_ptr_A = smem_buffer_A.data(); -+ #pragma unroll 1 -+ for (int i = 0; i < smem_buffer_A.size(); ++i) { -+ cutlass::ReferenceFactory::get(smem_ptr_A, i) = -+ cutlass::ReferenceFactory::type>::get(input_A, i); -+ } -+ -+ typename Mma::ElementB *smem_ptr_B = smem_buffer_B.data(); -+ #pragma unroll 1 -+ for (int i = 0; i < smem_buffer_B.size(); ++i) { -+ cutlass::ReferenceFactory::get(smem_ptr_B, i) = -+ cutlass::ReferenceFactory::type>::get(input_B, i); -+ } -+ } -+ -+ __syncthreads(); -+ -+ // -+ // Construct warp-level matrix product -+ // -+ -+ using FragmentA = typename Mma::FragmentA; -+ using FragmentB = typename Mma::FragmentB; -+ using FragmentC = typename Mma::FragmentC; -+ -+ using TransformedFragmentA = typename Mma::TransformedFragmentA; -+ using TransformedFragmentB = typename Mma::TransformedFragmentB; -+ -+ typename Mma::LayoutA layout_A = Mma::LayoutA::packed({ThreadblockShape::kM, ThreadblockShape::kK}); -+ typename Mma::LayoutB layout_B = Mma::LayoutB::packed({ThreadblockShape::kK, ThreadblockShape::kN}); -+ typename Mma::LayoutC layout_C = Mma::LayoutC::packed({Mma::Shape::kM, Mma::Shape::kN}); -+ -+ typename Mma::IteratorA iter_A({smem_buffer_A.data(), layout_A}, cutlass::arch::LaneId()); -+ -+ typename Mma::IteratorB iter_B({smem_buffer_B.data(), layout_B}, cutlass::arch::LaneId()); -+ -+ FragmentA loaded_frag_A; -+ FragmentB loaded_frag_B; -+ TransformedFragmentA transformed_frag_A; -+ TransformedFragmentB transformed_frag_B; -+ -+ FragmentC accum; -+ -+ Mma mma; -+ -+ accum.clear(); -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int iter = 0; iter < iterations; ++iter) { // place in loop that is not unrolled -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < ThreadblockShape::kK; -+ k += Mma::Policy::MmaShape::kK) { -+ iter_A.load(loaded_frag_A); -+ iter_B.load(loaded_frag_B); -+ -+ ++iter_A; -+ ++iter_B; -+ -+ mma.transform(transformed_frag_A, transformed_frag_B, loaded_frag_A, -+ loaded_frag_B); -+ -+ mma(accum, transformed_frag_A, transformed_frag_B, accum); -+ } -+ } -+ -+ typename Mma::IteratorC iter_C({output_C, layout_C}, cutlass::arch::LaneId()); -+ -+ iter_C.store(accum); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product -+template < -+ /// Warp-level matrix multiply-accumulate -+ typename Mma_, -+ /// Size of threadblock-scoped shape used to store SMEM -+ typename ThreadblockShape_, -+ /// The innter product operation performed by GEMM -+ typename Operator_ = cutlass::arch::OpMultiplyAdd -+> -+struct TransformTestbed { -+ -+ /// Thread-level matrix multiply-accumulate operator -+ using Mma = Mma_; -+ using ThreadblockShape = ThreadblockShape_; -+ using Operator = Operator_; -+ -+ using Shape = typename Mma::Shape; -+ using ElementA = typename Mma::ElementA; -+ using LayoutA = typename Mma::LayoutA; -+ using ElementB = typename Mma::ElementB; -+ using LayoutB = typename Mma::LayoutB; -+ using ElementC = typename Mma::ElementC; -+ using LayoutC = typename Mma::LayoutC; -+ -+ // -+ // Data members -+ // -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D_computed; -+ cutlass::HostTensor tensor_D_reference; -+ -+ // -+ // Methods -+ // -+ -+ /// Allocates workspace in device memory -+ TransformTestbed() { -+ -+ tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); -+ tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); -+ tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); -+ tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); -+ tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); -+ } -+ -+ /// Returns true if the CUDA device is sufficient to execute the kernel. -+ bool sufficient() const { -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.major == 9) { -+ // NVIDIA Hopper drops support for several data types -+ if ( -+ cutlass::sizeof_bits::value < 8 || -+ cutlass::sizeof_bits::value < 8 || -+ cutlass::sizeof_bits::value < 8) { -+ -+ return false; -+ } -+ } -+ -+ return true; -+ } -+ -+ /// Runs the test -+ bool run( -+ cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { -+ -+ if (!sufficient()) { -+ return true; -+ } -+ -+ // -+ // initialize device memory -+ // -+ -+ if (init_A == cutlass::Distribution::Uniform) { -+ int scope_max = 8; -+ int scope_min = -8; -+ -+ if (cutlass::sizeof_bits::value == 4) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (cutlass::sizeof_bits::value == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } -+ -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_A.host_view(), seed, scope_max, scope_min, 0); -+ } else if (init_A == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), -+ tensor_A.capacity()); -+ } else if (init_A == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ if (init_B == cutlass::Distribution::Uniform) { -+ int scope_max = 8; -+ int scope_min = -8; -+ -+ if (cutlass::sizeof_bits::value == 4) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (cutlass::sizeof_bits::value == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } -+ -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_B.host_view(), seed + 16, scope_max, scope_min, 0); -+ } else if (init_B == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), -+ tensor_B.capacity()); -+ } else if (init_B == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ cutlass::reference::host::TensorFill( -+ tensor_C.host_view(), -+ ElementC(0) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D_computed.host_view(), -+ ElementC(0) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D_reference.host_view(), -+ ElementC(0) -+ ); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D_computed.sync_device(); -+ -+ // launch kernel -+ kernel_transform<<>>( -+ tensor_D_computed.device_data(), tensor_A.device_data(), -+ tensor_B.device_data(), tensor_C.device_data()); -+ -+ // verify no errors -+ cudaError_t result = cudaDeviceSynchronize(); -+ -+ EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); -+ if (result != cudaSuccess) { -+ return false; -+ } -+ -+ tensor_D_computed.sync_host(); -+ -+ // -+ // Reference implementation -+ // -+ -+ cutlass::reference::host::Gemm -+ reference_gemm; -+ -+ reference_gemm( -+ {Shape::kM, Shape::kN, ThreadblockShape::kK}, -+ ElementC(1), -+ tensor_A.host_ref(), -+ tensor_B.host_ref(), -+ ElementC(0), -+ tensor_D_reference.host_ref() -+ ); -+ -+ // -+ // Verify equivalence -+ // -+ -+ // compare -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_D_computed.host_view(), -+ tensor_D_reference.host_view() -+ ); -+ -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ -+ cutlass::TensorView tensor_A_physical( -+ tensor_A.host_data(), -+ tensor_A.stride()[0], -+ tensor_A.extent()); -+ -+ cutlass::TensorView tensor_B_physical( -+ tensor_B.host_data(), -+ tensor_B.stride()[0], -+ tensor_B.extent()); -+ -+ std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; -+ std::cout -+ << "A:\n" << tensor_A.host_view() << "\n\n" -+ << "A(physical - stride: " << tensor_A.stride()[0] << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; -+ -+ std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; -+ std::cout -+ << "B:\n" << tensor_B.host_view() << "\n\n" -+ << "B(physical - stride: " << tensor_B.stride()[0] << ", extent: " << tensor_B.extent() << "):\n" << tensor_B_physical << "\n\n"; -+ -+ std::cout -+ << "C:\n" << tensor_C.host_view() << "\n\n" -+ << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" -+ << "Computed:\n" << tensor_D_computed.host_view() << std::endl; -+ } -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product -+template < -+ /// Warp-level matrix multiply-accumulate -+ typename Mma_, -+ /// Size of threadblock-scoped shape used to store SMEM -+ typename ThreadblockShape_ -+> -+struct TransformedTestbedComplex { -+ -+ /// Thread-level matrix multiply-accumulate operator -+ using Mma = Mma_; -+ using ThreadblockShape = ThreadblockShape_; -+ -+ using Shape = typename Mma::Shape; -+ using ElementA = typename Mma::ElementA; -+ using LayoutA = typename Mma::LayoutA; -+ using ElementB = typename Mma::ElementB; -+ using LayoutB = typename Mma::LayoutB; -+ using ElementC = typename Mma::ElementC; -+ using LayoutC = typename Mma::LayoutC; -+ -+ // -+ // Data members -+ // -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D_computed; -+ cutlass::HostTensor tensor_D_reference; -+ -+ // -+ // Methods -+ // -+ -+ /// Allocates workspace in device memory -+ TransformedTestbedComplex() { -+ -+ tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); -+ tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); -+ tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); -+ tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); -+ tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); -+ } -+ -+ /// Returns true if the CUDA device is sufficient to execute the kernel. -+ bool sufficient() const { -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.major == 9) { -+ // NVIDIA Hopper drops support for several data types -+ if ( -+ cutlass::sizeof_bits::value < 8 || -+ cutlass::sizeof_bits::value < 8 || -+ cutlass::sizeof_bits::value < 8) { -+ -+ return false; -+ } -+ } -+ -+ return true; -+ } -+ -+ /// Runs the test -+ bool run( -+ cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { -+ -+ if (!sufficient()) { -+ return true; -+ } -+ -+ // -+ // initialize device memory -+ // -+ -+ if (init_A == cutlass::Distribution::Uniform) { -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform(tensor_A.host_view(), -+ seed, 8, -8, 0); -+ } else if (init_A == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), -+ tensor_A.capacity()); -+ } else if (init_A == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ if (init_B == cutlass::Distribution::Uniform) { -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform(tensor_B.host_view(), -+ seed + 16, 8, -8, 0); -+ } else if (init_B == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), -+ tensor_B.capacity()); -+ } else if (init_B == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ cutlass::reference::host::TensorFill( -+ tensor_C.host_view(), -+ ElementC(0) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D_computed.host_view(), -+ ElementC(0) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D_reference.host_view(), -+ ElementC(0) -+ ); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D_computed.sync_device(); -+ -+ // launch kernel -+ kernel_transform<<< dim3(1, 1), dim3(32, 1, 1) >>>( -+ tensor_D_computed.device_data(), -+ tensor_A.device_data(), -+ tensor_B.device_data(), -+ tensor_C.device_data()); -+ -+ // verify no errors -+ cudaError_t result = cudaDeviceSynchronize(); -+ -+ EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); -+ if (result != cudaSuccess) { -+ return false; -+ } -+ -+ tensor_D_computed.sync_host(); -+ -+ // -+ // Reference implementation -+ // -+ -+ cutlass::reference::host::GemmComplex( -+ {Shape::kM, Shape::kN, ThreadblockShape::kK}, -+ ElementC(1), -+ tensor_A.host_ref(), -+ Mma::kTransformA, -+ tensor_B.host_ref(), -+ Mma::kTransformB, -+ ElementC(0), -+ tensor_C.host_ref(), -+ tensor_D_reference.host_ref() -+ ); -+ -+ // -+ // Verify equivalence -+ // -+ -+ // compare -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_D_computed.host_view(), -+ tensor_D_reference.host_view() -+ ); -+ -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ -+ cutlass::TensorView tensor_A_physical( -+ tensor_A.host_data(), -+ tensor_A.stride()[0], -+ tensor_A.extent()); -+ -+ cutlass::TensorView tensor_B_physical( -+ tensor_B.host_data(), -+ tensor_B.stride()[0], -+ tensor_B.extent()); -+ -+ std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; -+ std::cout -+ << "A:\n" << tensor_A.host_view() << "\n\n" -+ << "A(physical - stride: " << tensor_A.stride()[0] << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; -+ -+ std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; -+ std::cout -+ << "B:\n" << tensor_B.host_view() << "\n\n" -+ << "B(physical - stride: " << tensor_B.stride()[0] << ", extent: " << tensor_B.extent() <<"):\n" << tensor_B_physical << "\n\n"; -+ -+ std::cout -+ << "C:\n" << tensor_C.host_view() << "\n\n" -+ << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" -+ << "Computed:\n" << tensor_D_computed.host_view() << std::endl; -+ } -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test kernel -+template -+__global__ void sparse_kernel( -+ typename Mma::ElementC *output_C, -+ typename Mma::ElementA const *input_A, -+ typename Mma::ElementB const *input_B, -+ typename Mma::ElementC const *input_C, -+ typename Mma::ElementE const *input_E, -+ int iterations = 1) { -+ -+ // Use AlignedBuffer to store trivially copyable objects in unions and __shared__ buffers. -+ __shared__ cutlass::AlignedBuffer -+ smem_buffer_A; -+ -+ __shared__ cutlass::AlignedBuffer< -+ typename Mma::ElementB, ThreadblockShape::kN * ThreadblockShape::kK> smem_buffer_B; -+ -+ __shared__ cutlass::AlignedBuffer< -+ typename Mma::ElementE, Mma::Shape::kM * Mma::Shape::kK / -+ Mma::kSparse / Mma::kElementsPerElementE> -+ smem_buffer_E; -+ -+ __syncthreads(); -+ -+ if (threadIdx.x == 0) { -+ typename Mma::ElementA *smem_ptr_A = smem_buffer_A.data(); -+ #pragma unroll 1 -+ for (int i = 0; i < smem_buffer_A.size(); ++i) { -+ cutlass::ReferenceFactory::get(smem_ptr_A, i) = -+ cutlass::ReferenceFactory::type>::get(input_A, i); -+ } -+ -+ typename Mma::ElementB *smem_ptr_B = smem_buffer_B.data(); -+ #pragma unroll 1 -+ for (int i = 0; i < smem_buffer_B.size(); ++i) { -+ cutlass::ReferenceFactory::get(smem_ptr_B, i) = -+ cutlass::ReferenceFactory::type>::get(input_B, i); -+ } -+ -+ typename Mma::ElementE *smem_ptr_E = smem_buffer_E.data(); -+ #pragma unroll 1 -+ for (int i = 0; i < smem_buffer_E.size(); ++i) { -+ cutlass::ReferenceFactory::get(smem_ptr_E, i) = -+ cutlass::ReferenceFactory::type>::get(input_E, i); -+ } -+ } -+ -+ __syncthreads(); -+ -+ // -+ // Construct warp-level matrix product -+ // -+ -+ using FragmentA = typename Mma::FragmentA; -+ using FragmentB = typename Mma::FragmentB; -+ using FragmentC = typename Mma::FragmentC; -+ using FragmentE = typename Mma::FragmentE; -+ -+ typename Mma::LayoutA layout_A = Mma::LayoutA::packed( -+ {ThreadblockShape::kM, ThreadblockShape::kK / Mma::kSparse}); -+ typename Mma::LayoutB layout_B = -+ Mma::LayoutB::packed({ThreadblockShape::kK, ThreadblockShape::kN}); -+ typename Mma::LayoutC layout_C = Mma::LayoutC::packed({Mma::Shape::kM, Mma::Shape::kN}); -+ typename Mma::LayoutE layout_E = -+ Mma::LayoutE::packed({Mma::Shape::kM * Mma::kInterleaved, -+ Mma::Shape::kK / Mma::kSparse / -+ Mma::kElementsPerElementE / Mma::kInterleaved}); -+ -+ typename Mma::IteratorA iter_A({smem_buffer_A.data(), layout_A}, cutlass::arch::LaneId()); -+ -+ typename Mma::IteratorB iter_B({smem_buffer_B.data(), layout_B}, cutlass::arch::LaneId()); -+ -+ typename Mma::IteratorE iter_E({smem_buffer_E.data(), layout_E}, cutlass::arch::LaneId()); -+ -+ FragmentA frag_A; -+ FragmentB frag_B; -+ -+ FragmentC accum; -+ -+ FragmentE frag_E; -+ -+ Mma mma; -+ -+ accum.clear(); -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int iter = 0; iter < iterations; ++iter) { // place in loop that is not unrolled -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0; k < ThreadblockShape::kK; -+ k += Mma::Policy::MmaShape::kK) { -+ iter_A.load(frag_A); -+ iter_B.load(frag_B); -+ iter_E.load(frag_E); -+ -+ ++iter_A; -+ ++iter_B; -+ ++iter_E; -+ -+ mma(accum, frag_A, frag_B, accum, frag_E); -+ } -+ } -+ -+ typename Mma::IteratorC iter_C({output_C, layout_C}, cutlass::arch::LaneId()); -+ -+ iter_C.store(accum); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the matrix product -+template < -+ /// Warp-level matrix multiply-accumulate -+ typename Mma_, -+ /// Size of threadblock-scoped shape used to store SMEM -+ typename ThreadblockShape_, -+ /// The innter product operation performed by GEMM -+ typename Operator_ = cutlass::arch::OpMultiplyAdd -+> -+struct SparseTestbed { -+ -+ /// Thread-level matrix multiply-accumulate operator -+ using Mma = Mma_; -+ using ThreadblockShape = ThreadblockShape_; -+ using Operator = Operator_; -+ -+ using Shape = typename Mma::Shape; -+ using ElementA = typename Mma::ElementA; -+ using LayoutA = typename Mma::LayoutA; -+ using ElementB = typename Mma::ElementB; -+ using LayoutB = typename Mma::LayoutB; -+ using ElementC = typename Mma::ElementC; -+ using LayoutC = typename Mma::LayoutC; -+ -+ static int const Sparse = Mma::kSparse; -+ static int const MetaSizeInBits = Mma::kMetaSizeInBits; -+ static int const MaxID2 = Mma::kMaxID2; -+ static int const Interleaved = Mma::kInterleaved; -+ -+ using ElementE = typename Mma::ElementE; -+ -+ static int const ElementsPerElementE = Mma::kElementsPerElementE; -+ -+ using LayoutE = cutlass::layout::RowMajor; -+ using ReorderedLayoutE = -+ cutlass::layout::ColumnMajorInterleaved; -+ -+ // -+ // Data members -+ // -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_A_uncompressed; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D_computed; -+ cutlass::HostTensor tensor_D_reference; -+ cutlass::HostTensor tensor_E; -+ cutlass::HostTensor tensor_E_reordered; -+ -+ // -+ // Methods -+ // -+ -+ /// Allocates workspace in device memory -+ SparseTestbed() { -+ -+ tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, -+ ThreadblockShape::kK / Sparse)); -+ tensor_A_uncompressed.reset( -+ cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); -+ tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); -+ tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); -+ tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); -+ tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); -+ tensor_E.reset(cutlass::make_Coord( -+ Shape::kM, Shape::kK / Sparse / ElementsPerElementE)); -+ tensor_E_reordered.reset(cutlass::make_Coord( -+ Shape::kM, Shape::kK / Sparse / ElementsPerElementE)); -+ } -+ -+ /// Returns true if the CUDA device is sufficient to execute the kernel. -+ bool sufficient() const { -+ -+ cudaDeviceProp properties; -+ int device_idx; -+ cudaError_t result = cudaGetDevice(&device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() API call failed."); -+ } -+ -+ result = cudaGetDeviceProperties(&properties, device_idx); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ if (properties.major == 9) { -+ // NVIDIA Hopper drops support for several data types -+ if ( -+ cutlass::sizeof_bits::value < 8 || -+ cutlass::sizeof_bits::value < 8 || -+ cutlass::sizeof_bits::value < 8) { -+ -+ return false; -+ } -+ } -+ -+ return true; -+ } -+ -+ /// Runs the test -+ bool run( -+ cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind init_E = cutlass::Distribution::Uniform) { -+ -+ if (!sufficient()) { -+ return true; -+ } -+ -+ // -+ // initialize device memory -+ // -+ -+ if (init_A == cutlass::Distribution::Uniform) { -+ int scope_max = 8; -+ int scope_min = -8; -+ -+ if (cutlass::sizeof_bits::value == 4) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (cutlass::sizeof_bits::value == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } -+ -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_A.host_view(), seed, scope_max, scope_min, 0); -+ } else if (init_A == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), -+ tensor_A.capacity()); -+ } else if (init_A == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ if (init_B == cutlass::Distribution::Uniform) { -+ int scope_max = 8; -+ int scope_min = -8; -+ -+ if (cutlass::sizeof_bits::value == 4) { -+ scope_max = 2; -+ scope_min = -2; -+ } else if (cutlass::sizeof_bits::value == 1) { -+ scope_max = 2; -+ scope_min = 0; -+ } -+ -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomUniform( -+ tensor_B.host_view(), seed + 16, scope_max, scope_min, 0); -+ } else if (init_B == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), -+ tensor_B.capacity()); -+ } else if (init_B == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ cutlass::reference::host::TensorFill( -+ tensor_C.host_view(), -+ ElementC(0) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D_computed.host_view(), -+ ElementC(0) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D_reference.host_view(), -+ ElementC(0) -+ ); -+ -+ if (init_E == cutlass::Distribution::Uniform) { -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFillRandomSparseMeta( -+ tensor_E.host_view(), seed, MetaSizeInBits); -+ } else if (init_E == cutlass::Distribution::Identity) { -+ uint32_t content = (MaxID2 == 1) ? 0x44444444 : 0x4444; -+ cutlass::reference::host::TensorFill(tensor_E.host_view(), -+ (ElementE)(content)); -+ } else { -+ // TODO: Implement the rest -+ return false; -+ } -+ -+ cutlass::reorder_meta( -+ tensor_E_reordered.host_ref(), tensor_E.host_ref(), -+ {Shape::kM, Shape::kN, Shape::kK / Sparse / ElementsPerElementE}); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D_computed.sync_device(); -+ tensor_E_reordered.sync_device(); -+ -+ // launch kernel -+ sparse_kernel<<< dim3(1, 1), dim3(32, 1, 1) >>>( -+ tensor_D_computed.device_data(), -+ tensor_A.device_data(), -+ tensor_B.device_data(), -+ tensor_C.device_data(), -+ tensor_E_reordered.device_data()); -+ -+ // verify no errors -+ cudaError_t result = cudaDeviceSynchronize(); -+ -+ EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); -+ if (result != cudaSuccess) { -+ return false; -+ } -+ -+ tensor_D_computed.sync_host(); -+ -+ // -+ // Reference implementation -+ // -+ cutlass::uncompress(tensor_A_uncompressed.host_ref(), tensor_A.host_ref(), -+ tensor_E.host_ref(), Shape::kM, Shape::kK); -+ -+ cutlass::reference::host::Gemm -+ reference_gemm; -+ -+ reference_gemm( -+ {Shape::kM, Shape::kN, ThreadblockShape::kK}, -+ ElementC(1), -+ tensor_A_uncompressed.host_ref(), -+ tensor_B.host_ref(), -+ ElementC(0), -+ tensor_D_reference.host_ref() -+ ); -+ -+ // -+ // Verify equivalence -+ // -+ -+ // compare -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_D_computed.host_view(), -+ tensor_D_reference.host_view() -+ ); -+ -+ EXPECT_TRUE(passed); -+ -+ if (!passed) { -+ std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; -+ std::cout << "A:\n" << tensor_A.host_view() << "\n\n"; -+ -+ std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; -+ std::cout << "B:\n" << tensor_B.host_view() << "\n\n"; -+ -+ std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; -+ std::cout << "E:\n" << tensor_E.host_view() << "\n\n"; -+ -+ std::cout -+ << "C:\n" << tensor_C.host_view() << "\n\n" -+ << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" -+ << "Computed:\n" << tensor_D_computed.host_view() << "\n"; -+ } -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace warp -+} // namespace gemm -+} // namespace test -diff --git a/3rdparty/cutlass/test/unit/gemm/warp/wmma_sm70.cu b/3rdparty/cutlass/test/unit/gemm/warp/wmma_sm70.cu -new file mode 100644 -index 0000000..f2d6762 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/warp/wmma_sm70.cu -@@ -0,0 +1,688 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ -+ \brief Unit tests for warp-level wmma gemm -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_SM70_ENABLED) -+ -+#include "../../common/cutlass_unit_test.h" -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+ -+#include "cutlass/gemm/warp/default_mma_wmma_tensor_op.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+/// Test name format: SM[arch]_warp_wmma_[alayout]_[blayout]_[clayout]_[dtype].[threadblock_shape]_[warp_shape] -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////////////////// -+///////////////////////////////////////////// f16 accumulation point wmma.mma ////////////////////////////////// -+//////////////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+//////////////// [START] Verifying all layouts {N,T}x{N,T}=>{N,T} for WMMA 16x16x16 [START] ////////////////////// -+ -+//////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.col.m16n16k16.f16.f16 -+//////////////////////////////////////////////////////////// -+ -+// 4 tests for {N,T}x{N,T}=>{T} -+TEST(SM70_warp_wmma_row_col_row_f16, 16x16x16_16x16x16_16x16x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+//////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.col.row.m16n16k16.f16.f16 -+//////////////////////////////////////////////////////////// -+TEST(SM70_warp_wmma_col_row_row_f16, 16x16x16_16x16x16_16x16x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+//////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.row.m16n16k16.f16.f16 -+//////////////////////////////////////////////////////////// -+TEST(SM70_warp_wmma_row_row_row_f16, 16x16x16_16x16x16_16x16x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+ -+//////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.col.row.m16n16k16.f16.f16 -+//////////////////////////////////////////////////////////// -+TEST(SM70_warp_wmma_col_col_row_f16, 16x16x16_16x16x16_16x16x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+// 4 tests for {N,T}x{N,T}=>{N} -+TEST(SM70_warp_wmma_row_col_col_f16, 16x16x16_16x16x16_16x16x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+//////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.col.row.m16n16k16.f16.f16 -+//////////////////////////////////////////////////////////// -+TEST(SM70_warp_wmma_col_row_col_f16, 16x16x16_16x16x16_16x16x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+//////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.row.m16n16k16.f16.f16 -+//////////////////////////////////////////////////////////// -+TEST(SM70_warp_wmma_row_row_col_f16, 16x16x16_16x16x16_16x16x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+ -+//////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.col.row.m16n16k16.f16.f16 -+//////////////////////////////////////////////////////////// -+TEST(SM70_warp_wmma_col_col_col_f16, 16x16x16_16x16x16_16x16x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::ColumnMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+/////////// [END] Verifying all layouts {N,T}x{N,T}=>{N,T} for WMMA 16x16x16 [END] /////////////////////////// -+ -+ -+ -+TEST(SM70_warp_wmma_row_col_row_f16, 64x64x16_64x64x16_16x16x16) { -+ -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+ -+TEST(SM70_warp_wmma_row_col_row_f16, 64x64x32_64x64x32_16x16x16) { -+ -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+ -+TEST(SM70_warp_wmma_row_col_row_f16, 64x64x32_64x32x32_16x16x16) { -+ -+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+TEST(SM70_warp_wmma_row_col_row_f16, 64x64x32_32x64x32_16x16x16) { -+ -+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+TEST(SM70_warp_wmma_row_col_row_f16, 64x64x32_32x32x32_16x16x16) { -+ -+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+TEST(SM70_warp_wmma_row_col_row_f16, 128x128x16_64x64x16_16x16x16) { -+ // Even though the test launches 128x128x16 CTA tile this test only verfies one warp -+ // , i.e., warp_0 of size 64x64x16 out of the four warps required to cover the CTA -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+ -+//////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.col.m32n8k16.f16.f16 -+//////////////////////////////////////////////////////////// -+TEST(SM70_warp_wmma_row_col_row_f16, 32x8x16_32x8x16_32x8x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+//////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.col.m8n32k16.f16.f16 -+//////////////////////////////////////////////////////////// -+TEST(SM70_warp_wmma_row_col_row_f16, 8x32x16_8x32x16_32x8x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+ -+//////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.col.row.m8n32k16.f16.f16 -+//////////////////////////////////////////////////////////// -+TEST(SM70_warp_wmma_col_row_row_f16, 8x32x16_8x32x16_8x32x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+TEST(SM70_warp_wmma_col_row_row_f16, 32x8x16_32x8x16_32x8x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = cutlass::half_t; -+ using LayoutA = cutlass::layout::ColumnMajor; -+ using LayoutB = cutlass::layout::RowMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////////////////// -+///////////////////////////////////////////// f32 accumulation point wmma.mma ////////////////////////////////// -+//////////////////////////////////////////////////////////////////////////////////////////////////////////////// -+///////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.col.m16n16k16.f32.f32 -+//////////////////////////////////////////////////////////// -+TEST(SM70_warp_wmma_row_col_row_f32, 16x16x16_16x16x16_16x16x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+TEST(SM70_warp_wmma_row_col_row_f32, 64x64x16_64x64x16_16x16x16) { -+ -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+TEST(SM70_warp_wmma_row_col_row_f32, 64x64x32_64x64x32_16x16x16) { -+ -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+ -+TEST(SM70_warp_wmma_row_col_row_f32, 128x128x16_64x64x16_16x16x16) { -+ // Even though the test launches 128x128x16 CTA tile this test only verfies one warp -+ // , i.e., warp_0 of size 64x64x16 out of the four warps required to cover the CTA -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+ -+///////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.col.m32n8k16.f32.f32 -+//////////////////////////////////////////////////////////// -+TEST(SM70_warp_wmma_row_col_row_f32, 32x8x16_32x8x16_32x8x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+ -+///////////////////////////////////////////////////////////// -+/// wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype -+/// wmma.mma.sync.aligned.row.col.m8n32k16.f32.f32 -+//////////////////////////////////////////////////////////// -+TEST(SM70_warp_wmma_row_col_row_f32, 8x32x16_8x32x16_8x32x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ using ElementA = cutlass::half_t; -+ using ElementB = cutlass::half_t; -+ using ElementC = float; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, -+ LayoutA, -+ ElementB, LayoutB, -+ ElementC, -+ LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+#endif //CUTLASS_ARCH_WMMA_SM70_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/warp/wmma_sm72.cu b/3rdparty/cutlass/test/unit/gemm/warp/wmma_sm72.cu -new file mode 100644 -index 0000000..8b56220 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/warp/wmma_sm72.cu -@@ -0,0 +1,185 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ -+ \brief Unit tests for thread-level GEMM -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_SM72_ENABLED) -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+ -+#include "cutlass/gemm/warp/default_mma_wmma_tensor_op.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////////////////// -+///////////////////////////////////////////// Integer wmma.mma //////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////////////////// -+// TODO: FIXME SM75 should SM72, but the compilation breaks as SM72 shows up and runs on VOLTA -+TEST(SM75_warp_wmma_row_col_s8, 16x16x16_16x16x16_16x16x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+TEST(SM75_warp_wmma_row_col_s8, 32x8x16_32x8x16_32x8x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+TEST(SM75_warp_wmma_row_col_s8, 8x32x16_8x32x16_8x32x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ using ElementA = int8_t; -+ using ElementB = int8_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+TEST(SM75_warp_wmma_row_col_u8, 16x16x16_16x16x16_16x16x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<16, 16, 16>; -+ using ElementA = uint8_t; -+ using ElementB = uint8_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+TEST(SM75_warp_wmma_row_col_u8, 32x8x16_32x8x16_32x8x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<32, 8, 16>; -+ using ElementA = uint8_t; -+ using ElementB = uint8_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+ -+TEST(SM75_warp_wmma_row_col_u8, 8x32x16_8x32x16_8x32x16) { -+ // Threadblock and warp with just one native WMMA operation (most basic unit test) -+ using WarpShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 32, 16>; -+ using ElementA = uint8_t; -+ using ElementB = uint8_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+} -+#endif //CUTLASS_ARCH_WMMA_SM72_ENABLED -diff --git a/3rdparty/cutlass/test/unit/gemm/warp/wmma_sm75.cu b/3rdparty/cutlass/test/unit/gemm/warp/wmma_sm75.cu -new file mode 100644 -index 0000000..ebc0f3b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/gemm/warp/wmma_sm75.cu -@@ -0,0 +1,170 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ -+ \brief Unit tests for thread-level GEMM -+*/ -+#include "cutlass/arch/wmma.h" -+ -+#if defined(CUTLASS_ARCH_WMMA_SM75_ENABLED) -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/aligned_buffer.h" -+#include "cutlass/half.h" -+ -+#include "cutlass/gemm/warp/default_mma_wmma_tensor_op.h" -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include "testbed.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////////////////// -+///////////////////////////////////////////// SUBBYTE wmma.mma //////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(SM75_warp_wmma_row_col_s4, 64x64x32_8x8x32_8x8x32) { -+ -+ using WarpShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+ -+} -+ -+TEST(SM75_warp_wmma_row_col_s4, 64x64x32_64x64x32_8x8x32) { -+ -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+ -+} -+ -+TEST(SM75_warp_wmma_row_col_s4, 64x64x64_8x8x64_8x8x32) { -+ -+ using WarpShape = cutlass::gemm::GemmShape<8, 8, 64>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; -+ using ElementA = cutlass::int4b_t; -+ using ElementB = cutlass::int4b_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC>::Type; -+ -+ test::gemm::warp::Testbed >().run(); -+ -+} -+ -+TEST(SM75_warp_wmma_row_col_b1, 64x64x128_8x8x128_8x8x128) { -+ -+ using WarpShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ using ElementA = cutlass::uint1b_t; -+ using ElementB = cutlass::uint1b_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpXorPopc>::Type; -+ -+ test::gemm::warp::Testbed, cutlass::arch::OpXorPopc>().run(); -+ -+} -+ -+TEST(SM75_warp_wmma_row_col_b1, 64x64x128_64x64x128_8x8x128) { -+ -+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; -+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; -+ using ElementA = cutlass::uint1b_t; -+ using ElementB = cutlass::uint1b_t; -+ using ElementC = int32_t; -+ using LayoutA = cutlass::layout::RowMajor; -+ using LayoutB = cutlass::layout::ColumnMajor; -+ using LayoutC = cutlass::layout::RowMajor; -+ -+ using WmmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOpWmma< -+ WarpShape, -+ InstructionShape, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ cutlass::arch::OpXorPopc>::Type; -+ -+ test::gemm::warp::Testbed, cutlass::arch::OpXorPopc>().run(); -+ -+} -+#endif //CUTLASS_ARCH_WMMA_SM75_ENABLED -diff --git a/3rdparty/cutlass/test/unit/layout/matrix.cu b/3rdparty/cutlass/test/unit/layout/matrix.cu -new file mode 100644 -index 0000000..c603ced ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/layout/matrix.cu -@@ -0,0 +1,151 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+* -+**************************************************************************************************/ -+/*! \file -+\brief unit tests for matrix layout -+*/ -+ -+#include "../common/cutlass_unit_test.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/matrix_coord.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+namespace test { -+namespace layout { -+ void test_row_major_layout(int row_size, int column_size, int ldm) { -+ cutlass::layout::RowMajor row_major(ldm); -+ -+ // test pointer offset -+ for (int row_idx = 0; row_idx < row_size; row_idx++) { -+ for (int column_idx = 0; column_idx < column_size; column_idx++) { -+ cutlass::MatrixCoord matrix_coord(row_idx, column_idx); -+ auto ptr_offset = row_major(matrix_coord); -+ decltype(ptr_offset) reference_offset = row_idx * ldm + column_idx; -+ EXPECT_EQ(ptr_offset, reference_offset); -+ } -+ } -+ -+ // test stride -+ EXPECT_EQ(row_major.stride()[0], ldm); -+ -+ // test capacity -+ auto capacity = row_major.capacity(cutlass::MatrixCoord(row_size, column_size)); -+ decltype(capacity) reference_capacity = row_size * ldm; -+ EXPECT_EQ(capacity, reference_capacity); -+ -+ // test packed -+ auto packed = row_major.packed(cutlass::MatrixCoord(row_size, column_size)); -+ // the packed matrix's stride is the same with column size -+ EXPECT_EQ(packed.stride()[0], column_size); -+ } -+ -+ void test_column_major_layout(int row_size, int column_size, int ldm) { -+ cutlass::layout::ColumnMajor column_major(ldm); -+ -+ // test pointer offset -+ for (int row_idx = 0; row_idx < row_size; row_idx++) { -+ for (int column_idx = 0; column_idx < column_size; column_idx++) { -+ cutlass::MatrixCoord matrix_coord(row_idx, column_idx); -+ auto ptr_offset = column_major(matrix_coord); -+ decltype(ptr_offset) reference_offset = row_idx + column_idx * ldm; -+ EXPECT_EQ(ptr_offset, reference_offset); -+ } -+ } -+ -+ // test stride -+ EXPECT_EQ(column_major.stride()[0], ldm); -+ -+ // test capacity -+ auto capacity = column_major.capacity(cutlass::MatrixCoord(row_size, column_size)); -+ decltype(capacity) reference_capacity = column_size * ldm; -+ EXPECT_EQ(capacity, reference_capacity); -+ -+ // test packed -+ auto packed = column_major.packed(cutlass::MatrixCoord(row_size, column_size)); -+ // the packed matrix's stride is the same with row size -+ EXPECT_EQ(packed.stride()[0], row_size); -+ } -+ -+} // namespace layout -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Layout_Matrix, row_major_32_53) { -+ int const row_size = 32; -+ int const column_size = 53; -+ int const ldm = 55; -+ test::layout::test_row_major_layout(row_size, column_size, ldm); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Layout_Matrix, column_major_32_53) { -+ int const row_size = 32; -+ int const column_size = 53; -+ int const ldm = 55; -+ test::layout::test_column_major_layout(row_size, column_size, ldm); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Layout_Matrix, general_matrix) { -+ -+ int M = 16; -+ int N = 16; -+ int interleave = 4; -+ -+ cutlass::layout::GeneralMatrix::TensorCoord extent = {M, N}; -+ -+ cutlass::layout::GeneralMatrix layout = -+ cutlass::layout::GeneralMatrix::packed( -+ extent, cutlass::layout::Matrix::kColumnMajor, interleave); -+ -+ cutlass::HostTensor tensor(extent); -+ -+ for (int m = 0; m < M; ++m) { -+ for (int n = 0; n < N; ++n) { -+ tensor.host_data(m * N + n) = m * N + n; -+ } -+ } -+ -+ cutlass::TensorView canonical({tensor.host_data(), layout}, extent); -+ -+ // Uncomment this to view -+ // -+ //std::cout << canonical << std::endl; -+ // -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/layout/tensor.cu b/3rdparty/cutlass/test/unit/layout/tensor.cu -new file mode 100644 -index 0000000..253f0c0 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/layout/tensor.cu -@@ -0,0 +1,153 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+* -+**************************************************************************************************/ -+/*! \file -+\brief unit tests for tensor layout -+*/ -+ -+#include "../common/cutlass_unit_test.h" -+ -+#include "cutlass/layout/tensor.h" -+#include "cutlass/tensor_coord.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+namespace test { -+namespace layout { -+ void test_NHWC_layout(int n_size, int h_size, int w_size, int c_size) { -+ int ldc = c_size + 1; -+ int ldw = ldc * (w_size + 2); -+ int ldh = ldw * (h_size + 3); -+ -+ cutlass::layout::TensorNHWC::Stride tensor_stride({ ldc, ldw, ldh }); -+ -+ cutlass::layout::TensorNHWC tensor_nhwc(tensor_stride); -+ -+ // test pointer offset -+ for (int n_idx = 0; n_idx < n_size; n_idx++) { -+ for (int h_idx = 0; h_idx < h_size; h_idx++) { -+ for (int w_idx = 0; w_idx < w_size; w_idx++) { -+ for (int c_idx = 0; c_idx < c_size; c_idx++) { -+ cutlass::Tensor4DCoord tensor_coord(n_idx, h_idx, w_idx, c_idx); -+ auto ptr_offset = tensor_nhwc(tensor_coord); -+ decltype(ptr_offset) reference_offset = c_idx + -+ w_idx * ldc + -+ h_idx * ldw + -+ n_idx * ldh; -+ EXPECT_EQ(ptr_offset, reference_offset); -+ } -+ } -+ } -+ } -+ -+ // test stride -+ auto stride = tensor_nhwc.stride(); -+ EXPECT_EQ(stride, tensor_stride); -+ -+ // test capacity -+ auto capacity = tensor_nhwc.capacity(cutlass::Tensor4DCoord(n_size, h_size, w_size, c_size)); -+ decltype(capacity) referece_capacity = ldh * n_size; -+ EXPECT_EQ(capacity, referece_capacity); -+ -+ // test packed -+ auto packed_tensor_layout = tensor_nhwc.packed(cutlass::Tensor4DCoord(n_size, h_size, w_size, c_size)); -+ auto packed_stride = packed_tensor_layout.stride(); -+ EXPECT_EQ(packed_stride, cutlass::layout::TensorNHWC::Stride({ c_size, w_size * c_size, h_size * w_size * c_size })); -+ } -+ -+ -+ void test_NCHW_layout(int n_size, int c_size, int h_size, int w_size) { -+ int ldw = w_size + 1; -+ int ldh = ldw * (h_size + 2); -+ int ldc = ldh * (c_size + 1); -+ -+ cutlass::layout::TensorNCHW::Stride tensor_stride({ ldw, ldh, ldc }); -+ -+ cutlass::layout::TensorNCHW tensor_nchw(tensor_stride); -+ -+ // test pointer offset -+ for (int n_idx = 0; n_idx < n_size; n_idx++) { -+ for (int c_idx = 0; c_idx < c_size; c_idx++) { -+ for (int h_idx = 0; h_idx < w_size; h_idx++) { -+ for (int w_idx = 0; w_idx < c_size; w_idx++) { -+ // tensor4DCoord is always created in nhwc order -+ cutlass::Tensor4DCoord tensor_coord(n_idx, h_idx, w_idx, c_idx); -+ auto ptr_offset = tensor_nchw(tensor_coord); -+ decltype(ptr_offset) reference_offset = w_idx + -+ h_idx * ldw + -+ c_idx * ldh + -+ n_idx * ldc; -+ EXPECT_EQ(ptr_offset, reference_offset); -+ } -+ } -+ } -+ } -+ -+ // test stride -+ auto stride = tensor_nchw.stride(); -+ EXPECT_EQ(stride, tensor_stride); -+ -+ // test capacity -+ auto capacity = tensor_nchw.capacity(cutlass::Tensor4DCoord(n_size, h_size, w_size, c_size)); -+ decltype(capacity) referece_capacity = ldc * n_size; -+ EXPECT_EQ(capacity, referece_capacity); -+ -+ // test packed -+ auto packed_tensor_layout = tensor_nchw.packed(cutlass::Tensor4DCoord(n_size, h_size, w_size, c_size)); -+ auto packed_stride = packed_tensor_layout.stride(); -+ EXPECT_EQ(packed_stride, cutlass::layout::TensorNHWC::Stride({ w_size, w_size * h_size, w_size * h_size * c_size })); -+ } -+} // namespace layout -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Layout_Tensor, NHWC_32_12_10_14) { -+ int n_size = 32; -+ int h_size = 12; -+ int w_size = 10; -+ int c_size = 14; -+ test::layout::test_NHWC_layout(n_size, h_size, w_size, c_size); -+ -+} -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Layout_Tensor, NCHW_32_12_10_14) { -+ int n_size = 32; -+ int c_size = 12; -+ int h_size = 10; -+ int w_size = 14; -+ test::layout::test_NCHW_layout(n_size, c_size, h_size, w_size); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/layout/tensor_nhwc.cu b/3rdparty/cutlass/test/unit/layout/tensor_nhwc.cu -new file mode 100644 -index 0000000..e0f6b5b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/layout/tensor_nhwc.cu -@@ -0,0 +1,214 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+* -+**************************************************************************************************/ -+/*! \file -+\brief unit tests for NHWC tensor layout -+*/ -+ -+#include "../common/cutlass_unit_test.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/util/device_memory.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+namespace test { -+namespace layout { -+ -+ void test_nhwc_layout(int n_size, int h_size, int w_size, int c_size) { -+ int ldc = c_size + 1; -+ int ldw = ldc * (w_size + 2); -+ int ldh = ldw * (h_size + 3); -+ -+ typedef cutlass::layout::TensorNHWC Tensor; -+ -+ Tensor::Stride tensor_stride({ ldc, ldw, ldh }); -+ Tensor tensor_nhw_packed_c(tensor_stride); -+ -+ // test pointer offset -+ for (int n_idx = 0; n_idx < n_size; n_idx++) { -+ for (int p_idx = 0; p_idx < h_size; p_idx++) { -+ for (int q_idx = 0; q_idx < w_size; q_idx++) { -+ for (int c_idx = 0; c_idx < c_size; c_idx++) { -+ cutlass::Tensor4DCoord tensor_coord(n_idx, p_idx, q_idx, c_idx); -+ auto ptr_offset = tensor_nhw_packed_c(tensor_coord); -+ decltype(ptr_offset) reference_offset = c_idx + -+ q_idx * ldc + -+ p_idx * ldw + -+ n_idx * ldh; -+ EXPECT_EQ(ptr_offset, reference_offset); -+ } -+ } -+ } -+ } -+ -+ // test stride -+ auto stride = tensor_nhw_packed_c.stride(); -+ EXPECT_EQ(stride, tensor_stride); -+ -+ // test capacity -+ auto capacity = tensor_nhw_packed_c.capacity( -+ cutlass::Tensor4DCoord(n_size, h_size, w_size, c_size)); -+ decltype(capacity) referece_capacity = ldh * n_size; -+ EXPECT_EQ(capacity, referece_capacity); -+ -+ } -+ -+ __global__ void test_nhwc_inverse( -+ int *output, int n_size, int h_size, int w_size, int c_size) { -+ int ldc = c_size; -+ int ldw = ldc * w_size; -+ int ldh = ldw * h_size; -+ -+ typedef cutlass::layout::TensorNHWC Tensor; -+ -+ Tensor::Stride tensor_stride({ ldc, ldw, ldh }); -+ Tensor tensor_nhw_packed_c(tensor_stride); -+ -+ for (int n_idx = 0; n_idx < n_size; n_idx++) { -+ for (int p_idx = 0; p_idx < h_size; p_idx++) { -+ for (int q_idx = 0; q_idx < w_size; q_idx++) { -+ cutlass::Tensor4DCoord tensor_coord(n_idx, p_idx, q_idx, threadIdx.x); -+ int ptr_offset = tensor_nhw_packed_c(tensor_coord); -+ cutlass::Tensor4DCoord inv_coord = tensor_nhw_packed_c.inverse(ptr_offset); -+ output[ptr_offset] = tensor_nhw_packed_c(inv_coord); -+ } -+ } -+ } -+ } -+ -+ class TestTensorNHWC { -+ public: -+ -+ // -+ // Data members -+ // -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ TestTensorNHWC() { -+ -+ } -+ -+ /// Runs the test -+ void run(int n_size, int h_size, int w_size, int c_size) { -+ -+ size_t size = n_size * h_size * w_size * c_size; -+ -+ /// Device memory containing output -+ cutlass::device_memory::allocation< int > output(size); -+ int *output_host = (int *)malloc(sizeof(int) * size); -+ -+ dim3 grid(1,1); -+ dim3 block(c_size, 1, 1); -+ -+ test::layout::test_nhwc_inverse<<< grid, block >>>(output.get(), -+ n_size, h_size, w_size, c_size); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ ASSERT_EQ(result, cudaSuccess) << "CUDA error: " << cudaGetErrorString(result); -+ -+ // -+ // Verify output -+ // -+ -+ cutlass::device_memory::copy_to_host(output_host, output.get(), size); -+ -+ result = cudaGetLastError(); -+ ASSERT_EQ(result, cudaSuccess) << "CUDA error: " << cudaGetErrorString(result); -+ -+ for (int n_idx = 0; n_idx < n_size; n_idx++) { -+ for (int p_idx = 0; p_idx < h_size; p_idx++) { -+ for (int q_idx = 0; q_idx < w_size; q_idx++) { -+ for (int c_idx = 0; c_idx < c_size; c_idx++) { -+ int reference_offset = c_idx + -+ q_idx * c_size + -+ p_idx * (c_size * w_size) + -+ n_idx * (c_size * w_size * h_size); -+ EXPECT_EQ(output_host[reference_offset], reference_offset); -+ } -+ } -+ } -+ } -+ } -+}; -+ -+ -+} // namespace layout -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Layout_TensorNHWC, NHWC_1_16_8_32) { -+ int n_size = 1; -+ int h_size = 16; -+ int w_size = 8; -+ int c_size = 32; -+ test::layout::test_nhwc_layout(n_size, h_size, w_size, c_size); -+ test::layout::TestTensorNHWC test_nhwc; -+ test_nhwc.run(n_size, h_size, w_size, c_size); -+ -+} -+ -+TEST(Layout_TensorNHWC, NHWC_2_16_8_32) { -+ int n_size = 2; -+ int h_size = 16; -+ int w_size = 8; -+ int c_size = 32; -+ test::layout::test_nhwc_layout(n_size, h_size, w_size, c_size); -+ test::layout::TestTensorNHWC test_nhwc; -+ test_nhwc.run(n_size, h_size, w_size, c_size); -+} -+ -+TEST(Layout_TensorNHWC, NHWC_2_16_8_128) { -+ int n_size = 2; -+ int h_size = 16; -+ int w_size = 8; -+ int c_size = 128; -+ test::layout::test_nhwc_layout(n_size, h_size, w_size, c_size); -+ test::layout::TestTensorNHWC test_nhwc; -+ test_nhwc.run(n_size, h_size, w_size, c_size); -+ -+} -+ -+TEST(Layout_TensorNHWC, NHWC_4_8_16_128) { -+ int n_size = 4; -+ int h_size = 8; -+ int w_size = 16; -+ int c_size = 128; -+ test::layout::test_nhwc_layout(n_size, h_size, w_size, c_size); -+ test::layout::TestTensorNHWC test_nhwc; -+ test_nhwc.run(n_size, h_size, w_size, c_size); -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/nvrtc/cutlass/nvrtc/environment.h b/3rdparty/cutlass/test/unit/nvrtc/cutlass/nvrtc/environment.h -new file mode 100644 -index 0000000..94f3c78 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/nvrtc/cutlass/nvrtc/environment.h -@@ -0,0 +1,43 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+#include "cutlass/cutlass.h" -+ -+namespace cutlass { -+namespace nvrtc { -+ -+extern char const *kCutlassHeaders[]; -+extern char const *kCutlassHeaderNames[]; -+extern size_t const kCutlassHeaderCount; -+} // namespace nvrtc -+} // namespace cutlass -diff --git a/3rdparty/cutlass/test/unit/nvrtc/kernel/thread/testbed_kernel.h b/3rdparty/cutlass/test/unit/nvrtc/kernel/thread/testbed_kernel.h -new file mode 100644 -index 0000000..c2d9cde ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/nvrtc/kernel/thread/testbed_kernel.h -@@ -0,0 +1,76 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#pragma once -+ -+#include "cutlass/array.h" -+ -+namespace test { -+namespace nvrtc { -+namespace kernel { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Thread-level matrix multiply-accumulate -+template -+__global__ void testbed_kernel( -+ typename Mma::ElementC *D, -+ typename Mma::ElementA const *A, -+ typename Mma::ElementB const *B, -+ typename Mma::ElementC const *C) { -+ -+ auto ptr_D = reinterpret_cast *>(D); -+ auto ptr_A = reinterpret_cast const *>(A); -+ auto ptr_B = reinterpret_cast const *>(B); -+ auto ptr_C = reinterpret_cast const *>(C); -+ -+ Mma mma; -+ -+ auto a = *ptr_A; -+ auto b = *ptr_B; -+ auto c = *ptr_C; -+ -+ cutlass::Array d; -+ -+ mma(d, a, b, c); -+ -+ *ptr_D = d; -+} -+ -+} -+} -+} -+} -+ -diff --git a/3rdparty/cutlass/test/unit/nvrtc/stdlib/assert.h b/3rdparty/cutlass/test/unit/nvrtc/stdlib/assert.h -new file mode 100644 -index 0000000..e69de29 -diff --git a/3rdparty/cutlass/test/unit/nvrtc/stdlib/stdint.h b/3rdparty/cutlass/test/unit/nvrtc/stdlib/stdint.h -new file mode 100644 -index 0000000..f6033de ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/nvrtc/stdlib/stdint.h -@@ -0,0 +1,129 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+typedef char int8_t; -+typedef unsigned char uint8_t; -+typedef short int16_t; -+typedef unsigned short uint16_t; -+typedef int int32_t; -+typedef unsigned int uint32_t; -+typedef long long int int64_t; -+typedef unsigned long long int uint64_t; -+ -+#if defined __x86_64__ && !defined __ILP32__ -+# define __WORDSIZE 64 -+#else -+# define __WORDSIZE 32 -+#endif -+ -+ -+/* Small types. */ -+ -+/* Signed. */ -+typedef signed char int_least8_t; -+typedef short int int_least16_t; -+typedef int int_least32_t; -+#if __WORDSIZE == 64 -+typedef long int int_least64_t; -+#else -+__extension__ -+typedef long long int int_least64_t; -+#endif -+ -+/* Unsigned. */ -+typedef unsigned char uint_least8_t; -+typedef unsigned short int uint_least16_t; -+typedef unsigned int uint_least32_t; -+#if __WORDSIZE == 64 -+typedef unsigned long int uint_least64_t; -+#else -+__extension__ -+typedef unsigned long long int uint_least64_t; -+#endif -+ -+ -+/* Fast types. */ -+ -+/* Signed. */ -+typedef signed char int_fast8_t; -+#if __WORDSIZE == 64 -+typedef long int int_fast16_t; -+typedef long int int_fast32_t; -+typedef long int int_fast64_t; -+#else -+typedef int int_fast16_t; -+typedef int int_fast32_t; -+__extension__ -+typedef long long int int_fast64_t; -+#endif -+ -+/* Unsigned. */ -+typedef unsigned char uint_fast8_t; -+#if __WORDSIZE == 64 -+typedef unsigned long int uint_fast16_t; -+typedef unsigned long int uint_fast32_t; -+typedef unsigned long int uint_fast64_t; -+#else -+typedef unsigned int uint_fast16_t; -+typedef unsigned int uint_fast32_t; -+__extension__ -+typedef unsigned long long int uint_fast64_t; -+#endif -+ -+/* Types for `void *' pointers. */ -+#if __WORDSIZE == 64 -+# ifndef __intptr_t_defined -+typedef long int intptr_t; -+# define __intptr_t_defined -+# endif -+typedef unsigned long int uintptr_t; -+#else -+# ifndef __intptr_t_defined -+typedef int intptr_t; -+# define __intptr_t_defined -+# endif -+typedef unsigned int uintptr_t; -+#endif -+ -+ -+/* Largest integral types. */ -+#if __WORDSIZE == 64 -+typedef long int intmax_t; -+typedef unsigned long int uintmax_t; -+#else -+__extension__ -+typedef long long int intmax_t; -+__extension__ -+typedef unsigned long long int uintmax_t; -+#endif -+ -diff --git a/3rdparty/cutlass/test/unit/nvrtc/thread/gemm_nvrtc.cu b/3rdparty/cutlass/test/unit/nvrtc/thread/gemm_nvrtc.cu -new file mode 100644 -index 0000000..8b9b8bb ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/nvrtc/thread/gemm_nvrtc.cu -@@ -0,0 +1,203 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/gemm/thread/mma.h" -+ -+#include "testbed.h" -+ -+#if 0 -+int main() { -+ nvrtc::thread::Testbed< -+ cutlass::gemm::GemmShape<3, 4, 2>, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ >().run("cutlass::gemm::thread::Mma, float, cutlass::layout::ColumnMajor, float, cutlass::layout::RowMajor, float, cutlass::layout::ColumnMajor >"); -+ return 0; -+} -+#endif -+ -+TEST(SM50_Sgemm_thread_nvrtc, DISABLED_col_row_3x4x2) { -+ -+ test::nvrtc::thread::Testbed< -+ cutlass::gemm::GemmShape<3, 4, 2>, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ >().run("cutlass::gemm::thread::Mma, float, cutlass::layout::ColumnMajor, float, cutlass::layout::RowMajor, float, cutlass::layout::ColumnMajor >"); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#if 0 -+TEST(SM50_Sgemm_thread, col_row_3x4x2) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<3, 4, 2>, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM50_Sgemm_thread, col_row_4x4x2) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<4, 4, 2>, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM50_Sgemm_thread, row_col_4x4x2) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<4, 4, 2>, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM50_Sgemm_thread, col_row_4x5x3) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<4, 5, 3>, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM50_Sgemm_thread, col_row) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<8, 8, 1>, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM50_Sgemm_thread, row_col) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<8, 8, 1>, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM50_Sgemm_thread, col_col) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<8, 8, 1>, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM50_Sgemm_thread, row_row) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<8, 8, 1>, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::RowMajor, -+ float, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM50_Dgemm_thread, col_row) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<8, 8, 1>, -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+ -+TEST(SM50_Dgemm_thread, row_col) { -+ -+ test::gemm::thread::Testbed< -+ cutlass::gemm::GemmShape<8, 8, 1>, -+ double, -+ cutlass::layout::RowMajor, -+ double, -+ cutlass::layout::ColumnMajor, -+ double, -+ cutlass::layout::ColumnMajor -+ >().run(); -+} -+#endif -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/nvrtc/thread/testbed.h b/3rdparty/cutlass/test/unit/nvrtc/thread/testbed.h -new file mode 100644 -index 0000000..378be81 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/nvrtc/thread/testbed.h -@@ -0,0 +1,323 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level GEMM -+*/ -+ -+#pragma once -+ -+#include -+ -+#include "cutlass/gemm/thread/mma.h" -+#include "../kernel/thread/testbed_kernel.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+#include -+#include -+#include "../cutlass/nvrtc/environment.h" -+#include -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace nvrtc { -+namespace thread { -+ -+/// Structure to compute the matrix product -+template < -+ /// Size of the Gemm problem - concept: gemm::GemmShape<> -+ typename Shape, -+ /// Data type of A elements -+ typename ElementA, -+ /// Layout of A matrix (concept: MatrixLayout) -+ typename LayoutA, -+ /// Data type of B elements -+ typename ElementB, -+ /// Layout of B matrix (concept: MatrixLayout) -+ typename LayoutB, -+ /// Element type of C matrix -+ typename ElementC, -+ /// Layout of C matrix (concept: MatrixLayout) -+ typename LayoutC -+> -+struct Testbed { -+ -+ /// Thread-level matrix multiply-accumulate operator -+ using Mma = cutlass::gemm::thread::Mma< -+ Shape, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC -+ >; -+ -+ // -+ // Data members -+ // -+ -+ cutlass::HostTensor tensor_A; -+ cutlass::HostTensor tensor_B; -+ cutlass::HostTensor tensor_C; -+ cutlass::HostTensor tensor_D_computed; -+ cutlass::HostTensor tensor_D_reference; -+ -+ // -+ // Methods -+ // -+ -+ /// Allocates workspace in device memory -+ Testbed() { -+ -+ tensor_A.reset(cutlass::make_Coord(Shape::kM, Shape::kK)); -+ tensor_B.reset(cutlass::make_Coord(Shape::kK, Shape::kN)); -+ tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); -+ tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); -+ tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); -+ } -+ -+ static inline bool check_nvrtc_error(nvrtcResult error) { -+ if (error != NVRTC_SUCCESS) { -+ std::cerr << "failed to compile "; -+ return false; -+ } -+ return true; -+ } -+ -+ /// Runs the test -+ bool run(std::string const &gemm_traits) { -+ -+ // -+ // initialize device memory -+ // -+ -+ cutlass::reference::host::BlockFillSequential( -+ tensor_A.host_data(), -+ tensor_A.capacity() -+ ); -+ -+ cutlass::reference::host::BlockFillSequential( -+ tensor_B.host_data(), -+ tensor_B.capacity(), -+ ElementB(1), -+ ElementB(2) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_C.host_view(), -+ ElementC(0) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D_computed.host_view(), -+ ElementC(0) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ tensor_D_reference.host_view(), -+ ElementC(0) -+ ); -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ tensor_C.sync_device(); -+ tensor_D_computed.sync_device(); -+ -+#if 0 -+ // launch kernel -+ cutlass::gemm::kernel::testbed_kernel<<< dim3(1, 1), dim3(1, 1, 1) >>>( -+ tensor_D_computed.device_data(), -+ tensor_A.device_data(), -+ tensor_B.device_data(), -+ tensor_C.device_data()); -+ -+#else -+ // Instantiate gemm_kernel -+ nvrtcResult result_nvrtc; -+ nvrtcProgram program; -+ static char const *src = -+ "#include \"cutlass/gemm/thread/mma.h\"\n" -+ "#include \"cutlass/gemm/gemm.h\"\n" -+ "#include \"cutlass/layout/matrix.h\"\n" -+ "#include \"unit/nvrtc/kernel/thread/testbed_kernel.h\"\n" -+ ; -+ -+ std::string type_name; -+#if 0 -+ // TODO Ideally we'd use nvrtcGetTypeName to determine the type, but it cannot resolve enum symbol names -+ // As altername solution we might want to implement to_string() to get the traits string. -+ nvrtcGetTypeName(&type_name); -+#else -+ type_name = gemm_traits; -+#endif -+ -+ result_nvrtc = nvrtcCreateProgram(&program, -+ src, -+ NULL, -+ (int)cutlass::nvrtc::kCutlassHeaderCount, -+ cutlass::nvrtc::kCutlassHeaders, -+ cutlass::nvrtc::kCutlassHeaderNames); -+ check_nvrtc_error(result_nvrtc); -+ -+ std::string gemm_kernel_instantiation = -+ "test::nvrtc::kernel::thread::testbed_kernel< " + type_name + " >"; -+ nvrtcAddNameExpression(program, gemm_kernel_instantiation.c_str()); -+ -+ const char *opts[] = {"--gpu-architecture=compute_75", -+ "--std=c++11", -+ "--include-path=/usr/local/cuda-10.1/include"}; -+ -+ result_nvrtc = nvrtcCompileProgram(program, 3, opts); -+ if (result_nvrtc != NVRTC_SUCCESS) { -+ size_t logSize; -+ nvrtcGetProgramLogSize(program, &logSize); -+ std::vector log(logSize); -+ nvrtcGetProgramLog(program, log.data()); -+ std::cout << "Compile log:" << std::endl << log.data() << std::endl; -+ } -+ if (!check_nvrtc_error(result_nvrtc)) { -+ assert(0); -+ } -+ -+ // The lowered name is the name of the template instantiation in the generated PTX code. -+ char const *gemm_kernel_lowered_name; -+ nvrtcGetLoweredName(program, gemm_kernel_instantiation.c_str(), &gemm_kernel_lowered_name); -+ if (!check_nvrtc_error(result_nvrtc)) { -+ assert(0); -+ } -+ -+ // Query the size of the genereated PTX so that we can allocate storage and retrieve it afterwards -+ size_t ptx_size; -+ result_nvrtc = nvrtcGetPTXSize(program, &ptx_size); -+ if (!check_nvrtc_error(result_nvrtc)) { -+ assert(0); -+ } -+ -+ std::vector ptx(ptx_size); -+ result_nvrtc = nvrtcGetPTX(program, ptx.data()); -+ if (!check_nvrtc_error(result_nvrtc)) { -+ assert(0); -+ } -+ -+ // we do not need the nvrtc program anymore -+ //nvrtcDestroyProgram(&program); -+ -+ CUmodule module; -+ CUresult result_cuda; -+ result_cuda = cuModuleLoadDataEx(&module, ptx.data(), 0, 0, 0); -+ if (result_cuda != CUDA_SUCCESS) { -+ assert(0); -+ } -+ -+ CUfunction kernel; -+ result_cuda = cuModuleGetFunction(&kernel, module, gemm_kernel_lowered_name); -+ if (result_cuda != CUDA_SUCCESS) { -+ assert(0); -+ } -+ -+ void* d_a = (void*)tensor_A.device_data(); -+ void* d_b = (void*)tensor_B.device_data(); -+ void* d_c = (void*)tensor_C.device_data(); -+ void* d_d = (void*)tensor_D_computed.device_data(); -+ void* args[] = { &d_d, &d_a, &d_b, &d_c }; -+ -+ // CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void** kernelParams, void** extra -+ result_cuda = cuLaunchKernel(kernel, 1, 1, 1, 1, 1, 1, 0, 0 /*cudaStreamDefault*/, args, 0); -+ if (result_cuda != CUDA_SUCCESS) { -+ assert(0); -+ } else { -+} -+#endif -+ -+ // verify no errors -+ cudaError_t result = cudaDeviceSynchronize(); -+ -+ if (result != cudaSuccess) { -+ std::cout << "CUDA ERROR: " << cudaGetErrorString(result); -+ return false; -+ } -+ -+ tensor_D_computed.sync_host(); -+ -+ // -+ // Reference implementation -+ // -+ -+ //tensor_D_reference.fill(tensor_C.host_view()); -+ -+ cutlass::reference::host::Gemm reference_gemm; -+ -+ reference_gemm( -+ {Shape::kM, Shape::kN, Shape::kK}, -+ ElementC(1), -+ tensor_A.host_ref(), -+ tensor_B.host_ref(), -+ ElementC(0), -+ tensor_D_reference.host_ref() -+ ); -+ -+ // -+ // Verify equivalence -+ // -+ -+ // compare -+ bool passed = cutlass::reference::host::TensorEquals( -+ tensor_D_computed.host_view(), -+ tensor_D_reference.host_view() -+ ); -+ -+ if(!passed) std::cout -+ << "A:\n" << tensor_A.host_view() << "\n\n" -+ << "B:\n" << tensor_B.host_view() << "\n\n" -+ << "C:\n" << tensor_C.host_view() << "\n\n" -+ << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" -+ << "Computed:\n" << tensor_D_computed.host_view() << std::endl; -+ -+ std::cout << "passed " << passed << std::endl; -+ -+ return passed; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace nvrtc -+} // namespace test -diff --git a/3rdparty/cutlass/test/unit/pipeline/pipeline_async.cu b/3rdparty/cutlass/test/unit/pipeline/pipeline_async.cu -new file mode 100644 -index 0000000..d2adad6 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/pipeline/pipeline_async.cu -@@ -0,0 +1,468 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Unit test for the PipelineAsync class -+*/ -+ -+#define KERNEL_DBG_TRACE false -+ -+#include "../common/cutlass_unit_test.h" -+#include -+#include -+ -+#include -+#include -+ -+#include -+#include -+ -+#include "cutlass/core_io.h" -+ -+#include "cutlass/util/print_error.hpp" -+#include "cutlass/util/GPU_Clock.hpp" -+ -+#include "testbed.h" -+#include "cutlass/pipeline.hpp" -+#include "cutlass/arch/barrier.h" -+#include "cute/arch/cluster_sm90.hpp" -+ -+using namespace cute; -+ -+//////////////////// KERNEL ///////////////////////// -+ -+template -+struct SharedStorage -+{ -+ typename cutlass::PipelineAsync::SharedStorage storage; -+}; -+ -+// Goal of this kernel is to complete deadlock-free -+// Simple 1 producer warp, one consumer warp scenario -+template -+__global__ static -+void pipeline_async_basic_device(uint32_t const num_iterations) -+{ -+ -+ extern __shared__ char shared_memory[]; -+ using MainloopPipeline = typename cutlass::PipelineAsync; -+ using PipelineState = typename cutlass::PipelineState; -+ -+ using SharedStorage = SharedStorage; -+ SharedStorage& shared_storage = *reinterpret_cast(shared_memory); -+ -+ -+ auto cta_layout = Layout{}; // (m,n) -> cta_id -+ -+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); -+ int lane_predicate = cute::elect_one_sync(); -+ dim3 block_id_in_cluster = cute::block_id_in_cluster(); -+ auto cluster_shape = ClusterShape{}; -+ -+ // This example showcases 2 producer 1 consumer example -+ typename MainloopPipeline::Params params; -+ params.producer_arv_count = 2; -+ params.consumer_arv_count = 1; -+ MainloopPipeline pipeline(shared_storage.storage, params); -+ -+ // Ensure All CTAs in Cluster have completed init before issuing commits -+ cute::cluster_arrive_relaxed(); -+ cute::cluster_wait(); -+ __syncthreads(); -+ -+ if (lane_predicate) { -+ // Producer Warps -+ if (warp_idx==0 || warp_idx==1) { -+ -+ int prologue_iterations = min(NumStages, num_iterations); -+ for ( int i = 0; i < prologue_iterations; ++i) { -+ // Can also specify stage to commit directly -+ pipeline.producer_commit(i); -+ } -+ -+ int mainloop_iterations = num_iterations - prologue_iterations; -+ -+ // Only the mainloop needs a PipelineState because this is where we start "waiting" (acquiring) -+ PipelineState smem_pipe_write; -+ -+ for ( ; mainloop_iterations > 0; --mainloop_iterations) { -+ pipeline.producer_acquire(smem_pipe_write); -+ pipeline.producer_commit(smem_pipe_write); -+ ++smem_pipe_write; -+ } -+ } -+ else { -+ PipelineState smem_pipe_read; -+ for (int iter=0 ; iter < num_iterations; ++iter) { -+ pipeline.consumer_wait(smem_pipe_read); -+ pipeline.consumer_release(smem_pipe_read.index()); -+ ++smem_pipe_read; -+ } -+ } -+ } -+ -+ // To make sure remote SMEM doesn't get destroyed -+ cute::cluster_arrive(); -+ cute::cluster_wait(); -+} -+///////////////////////////////////////////////////// -+ -+template -+struct PipelineTest { -+ -+ // -+ // Data members -+ // -+ static constexpr uint32_t Stages = Stages_; -+ static constexpr uint32_t kBlockSize = 96; -+ using ClusterShape = ClusterShape_; -+ -+ // -+ // Methods -+ // -+ -+ // Ctor -+ PipelineTest() = default; -+ -+ -+ // Run CuTe GEMM kernel -+ cudaError_t run(uint32_t const kNumIters, -+ cudaStream_t stream = nullptr) { -+ -+ // Pipeline (multistage pipeline) -+ auto num_stages = Int{}; -+ -+ auto cluster_shape = Shape, Int, _1>{}; -+ -+ // -+ // Configure and launch -+ // -+ int iterations = 2; -+ cudaError_t result; -+ -+ for (int iter = 0; iter < iterations; ++iter) { -+ -+ // Define the tiled MMA layout (static, 4warps) -+ using MainloopPipeline = typename cutlass::PipelineAsync; -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ -+ result = cudaFuncSetAttribute( -+ pipeline_async_basic_device, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ // Launch a single Cluster, with 128 thread per CTA -+ dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), 1); -+ dim3 dimGrid(size<0>(cluster_shape), size<1>(cluster_shape), 1); -+ dim3 dimBlock(kBlockSize,1,1); -+ -+ const void* kernel = (const void*)pipeline_async_basic_device; -+ int iters = kNumIters; -+ void* kernel_params[] = {reinterpret_cast(&iters)}; -+ cutlass::ClusterLauncher::launch(dimGrid, dimCluster, dimBlock, smem_size, stream, kernel, kernel_params); -+ -+ } // profiling loop ends -+ -+ result = cudaDeviceSynchronize(); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Error: cudaDeviceSynchronize() failed" << std::endl; -+ return result; -+ } -+ -+ return cudaSuccess; -+ } -+ -+}; -+ -+#if CUDA_12_0_SM90_FEATURES_SUPPORTED -+TEST(SM90_Verify_PipelineAsync, Cluster1x1_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster1x1_Stage5) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ static constexpr uint32_t Stages = 5; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster1x1_Stage10) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ static constexpr uint32_t Stages = 10; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster2x2_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 2, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster2x2_Stage5) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 2, 1>; -+ static constexpr uint32_t Stages = 5; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster2x2_Stage10) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 2, 1>; -+ static constexpr uint32_t Stages = 10; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster1x2_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 2, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster1x2_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 2, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster1x2_Stage10) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 2, 1>; -+ static constexpr uint32_t Stages = 10; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster2x1_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 1, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster2x1_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 1, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster4x1_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 1, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster4x1_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 1, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster1x4_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 4, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster1x4_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 4, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster2x4_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 4, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster2x4_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 4, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster4x2_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 2, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster4x2_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 2, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster4x4_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster4x4_Stage3) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; -+ static constexpr uint32_t Stages = 3; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster4x4_Stage4) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; -+ static constexpr uint32_t Stages = 4; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster4x4_Stage5) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; -+ static constexpr uint32_t Stages = 5; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster4x4_Stage6) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; -+ static constexpr uint32_t Stages = 6; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster4x4_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster4x4_Stage8) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; -+ static constexpr uint32_t Stages = 8; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster4x4_Stage9) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; -+ static constexpr uint32_t Stages = 9; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster4x4_Stage10) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; -+ static constexpr uint32_t Stages = 10; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineAsync, Cluster4x4_Stage11) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; -+ static constexpr uint32_t Stages = 11; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+#endif -diff --git a/3rdparty/cutlass/test/unit/pipeline/pipeline_tma_async.cu b/3rdparty/cutlass/test/unit/pipeline/pipeline_tma_async.cu -new file mode 100644 -index 0000000..90e0ca3 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/pipeline/pipeline_tma_async.cu -@@ -0,0 +1,469 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Unit test for the PipelineTmaAsync class -+*/ -+ -+ -+#define KERNEL_DBG_TRACE false -+ -+#include "../common/cutlass_unit_test.h" -+#include -+#include -+ -+#include -+#include -+ -+#include -+#include -+ -+#include "cutlass/core_io.h" -+ -+#include "cutlass/util/print_error.hpp" -+#include "cutlass/util/GPU_Clock.hpp" -+ -+#include "testbed.h" -+#include "cutlass/pipeline.hpp" -+#include "cutlass/arch/barrier.h" -+#include "cute/arch/cluster_sm90.hpp" -+ -+using namespace cute; -+ -+//////////////////// KERNEL ///////////////////////// -+ -+template -+struct SharedStorage -+{ -+ typename cutlass::PipelineTmaAsync::SharedStorage storage; -+}; -+ -+// Goal of this kernel is to complete deadlock-free -+template -+__global__ static -+void pipeline_device(uint32_t const NumIterations) -+{ -+ -+ extern __shared__ char shared_memory[]; -+ using DispatchPolicy = cutlass::gemm::MainloopSm90TmaGmma; -+ using MainloopPipeline = cutlass::PipelineTmaAsync; -+ using PipelineState = cutlass::PipelineState; -+ -+ using SharedStorage = SharedStorage; -+ SharedStorage& shared_storage = *reinterpret_cast(shared_memory); -+ -+ auto cta_layout = Layout{}; // (m,n) -> cta_id -+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); -+ int warp_group_thread_idx = threadIdx.x % 128; -+ dim3 block_id_in_cluster = cute::block_id_in_cluster(); -+ -+ auto cluster_shape = ClusterShape{}; -+ -+ // #Producers = #RowsInCluster + #ColsInCluster - 1 -+ uint32_t const NumProducers = cute::size<0>(cluster_shape) + cute::size<1>(cluster_shape) - 1; -+ uint32_t const TmaTransactionBytes = sizeof(uint32_t) * NumProducers; -+ uint32_t const per_cta_bytes = sizeof(uint32_t); -+ -+ // mbarrier.init -+ typename MainloopPipeline::Params params; -+ params.transaction_bytes = TmaTransactionBytes; -+ params.role = MainloopPipeline::ThreadCategory::ProducerConsumer; -+ params.is_leader = warp_group_thread_idx == 0; -+ params.num_consumers = 128; -+ -+ MainloopPipeline pipeline(shared_storage.storage, params); -+ -+ __syncthreads(); -+ -+ // Ensure All CTAs in Cluster have completed init before issuing commits -+ cute::cluster_arrive_relaxed(); -+ cute::cluster_wait(); -+ -+ // Total number of gemm_k_iterations -+ auto mma_k_iterations = NumIterations; -+ auto tma_k_iterations = NumIterations; -+ -+ PipelineState smem_pipe_read; -+ // For the DMA (prologue) - we start with an opposite phase - since we skip all waits -+ // i.e., we know that the buffer is indeed empty -+ PipelineState smem_pipe_write = cutlass::make_producer_start_state(); -+ PipelineState smem_pipe_release; -+ int K_TILE_MMAS = 1; -+ -+ int lane_predicate = cute::elect_one_sync(); -+ int k_pipe_tma_prologue = min(NumStages, tma_k_iterations); -+ -+ // DMA Prologue (Loads) -+ CUTLASS_PRAGMA_UNROLL -+ for(int i = 0; i < k_pipe_tma_prologue; ++i) { -+ pipeline.producer_acquire(smem_pipe_write); -+ // cp.async.bulk.tensor would typically happen here -+ pipeline.producer_commit(smem_pipe_write.index(), per_cta_bytes); -+ ++smem_pipe_write; -+ } -+ tma_k_iterations -= k_pipe_tma_prologue; -+ -+ // MMA Prologue (Compute) - modeling inflight MMAs -+ for (int iter = 0; iter < K_TILE_MMAS; ++iter) -+ { -+ pipeline.consumer_wait(smem_pipe_read); -+ warpgroup_arrive(); -+ // GMMA would typically happen here -+ -+ ++smem_pipe_read; -+ } -+ -+ mma_k_iterations -= K_TILE_MMAS; -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int iter = 0; iter < mma_k_iterations; ++iter) -+ { -+ pipeline.consumer_wait(smem_pipe_read); -+ -+ warpgroup_arrive(); -+ // GMMA would typically happen here -+ -+ pipeline.consumer_release(smem_pipe_release); -+ -+ if (lane_predicate && (warp_idx == 0) && (tma_k_iterations > 0)) { -+ pipeline.producer_acquire(smem_pipe_write); -+ // cp.async.bulk.tensor would typically happen here -+ pipeline.producer_commit(smem_pipe_write.index(), per_cta_bytes); -+ ++smem_pipe_write; -+ --tma_k_iterations; -+ } -+ -+ // next read stage -+ ++smem_pipe_read; -+ ++smem_pipe_release; -+ } -+ -+ // To make sure remote SMEM doesn't get destoryed -+ cute::cluster_arrive(); -+ cute::cluster_wait(); -+} -+///////////////////////////////////////////////////// -+ -+/// Device NT GMMA + TMA specialized -+template -+struct PipelineTest { -+ -+ // -+ // Data members -+ // -+ static constexpr uint32_t Stages = Stages_; -+ static constexpr uint32_t kBlockSize = 128; -+ using ClusterShape = ClusterShape_; -+ -+ // -+ // Methods -+ // -+ -+ // Ctor -+ PipelineTest(){}; -+ -+ -+ // Run CuTe GEMM kernel -+ cudaError_t run(uint32_t const kNumIters, -+ cudaStream_t stream = 0) { -+ -+ float elapsed_ms = 0.0f; -+ // Pipeline (multistage pipeline) -+ auto num_stages = Int{}; -+ -+ auto cluster_shape = Shape, Int, _1>{}; -+ -+ // -+ // Configure and launch -+ // -+ int iterations = 1; -+ cudaEvent_t events[2]; -+ cudaError_t result; -+ -+ for (cudaEvent_t & event : events) { -+ result = cudaEventCreate(&event); -+ if (result != cudaSuccess) { -+ std::cerr << "Error: Failed to create event."; -+ return result; -+ } -+ } -+ -+ result = cudaEventRecord(events[0]); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Error: Failed to record start event."; -+ return result; -+ } -+ -+ for (int iter = 0; iter < iterations; ++iter) { -+ -+ // Define the tiled MMA layout (static, 4warps) -+ using DispatchPolicy = cutlass::gemm::MainloopSm90TmaGmma; -+ using MainloopPipeline = typename cutlass::PipelineTmaAsync; -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ -+ result = cudaFuncSetAttribute( -+ pipeline_device, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ // Launch a single Cluster, with 128 thread per CTA -+ dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), 1); -+ dim3 dimGrid(size<0>(cluster_shape), size<1>(cluster_shape), 1); -+ dim3 dimBlock(kBlockSize,1,1); -+ -+ const void* kernel = (const void*)pipeline_device; -+ int iters = kNumIters; -+ void* kernel_params[] = {reinterpret_cast(&iters)}; -+ cutlass::ClusterLauncher::launch(dimGrid, dimCluster, dimBlock, smem_size, stream, kernel, kernel_params); -+ -+ } // profiling loop ends -+ -+ result = cudaEventRecord(events[1]); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Error: Failed to record stop event."; -+ return result; -+ } -+ -+ result = cudaDeviceSynchronize(); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Error: cudaDeviceSynchronize() failed" << std::endl; -+ return result; -+ } -+ -+ result = cudaEventElapsedTime(&elapsed_ms, events[0], events[1]); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to create event."; -+ return result; -+ } -+ -+ for (cudaEvent_t & event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ return cudaSuccess; -+ } -+}; -+ -+#if CUDA_12_0_SM90_FEATURES_SUPPORTED -+TEST(SM90_Verify_PipelineTmaAsync, Cluster1x1_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster1x1_Stage5) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ static constexpr uint32_t Stages = 5; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster1x1_Stage10) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ static constexpr uint32_t Stages = 10; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster2x2_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 2, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster2x2_Stage5) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 2, 1>; -+ static constexpr uint32_t Stages = 5; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster2x2_Stage10) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 2, 1>; -+ static constexpr uint32_t Stages = 10; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster4x4_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster4x4_Stage10) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; -+ static constexpr uint32_t Stages = 10; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster1x2_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 2, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster1x2_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 2, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster1x2_Stage10) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 2, 1>; -+ static constexpr uint32_t Stages = 10; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster2x1_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 1, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster2x1_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 1, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster4x1_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 1, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster4x1_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 1, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster1x4_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 4, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster1x4_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 4, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster2x4_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 4, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster2x4_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 4, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster4x2_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 2, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync, Cluster4x2_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 2, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+#endif -diff --git a/3rdparty/cutlass/test/unit/pipeline/pipeline_tma_async_warp_specialized.cu b/3rdparty/cutlass/test/unit/pipeline/pipeline_tma_async_warp_specialized.cu -new file mode 100644 -index 0000000..f0d6a79 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/pipeline/pipeline_tma_async_warp_specialized.cu -@@ -0,0 +1,525 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Unit test for the PipelineTmaAsync class as it would be used in a Warp specialized loop -+*/ -+ -+#define KERNEL_DBG_TRACE false -+ -+#include "../common/cutlass_unit_test.h" -+#include -+#include -+ -+#include -+#include -+ -+#include -+#include -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/print_error.hpp" -+#include "cutlass/util/GPU_Clock.hpp" -+ -+#include "testbed.h" -+#include "cutlass/pipeline.hpp" -+#include "cutlass/arch/barrier.h" -+#include "cute/arch/cluster_sm90.hpp" -+#include "cutlass/arch/barrier.h" -+#include "cutlass/arch/reg_reconfig.h" -+ -+ -+using namespace cute; -+using namespace cutlass; -+ -+//////////////////// KERNEL ///////////////////////// -+ -+template -+struct SharedStorage -+{ -+ typename cutlass::PipelineTmaAsync::SharedStorage storage ; -+}; -+ -+struct KernelParams -+{ -+ uint32_t num_iterations; -+ int* data_ptr; -+}; -+ -+// Goal of this kernel is to complete deadlock-free -+template -+__launch_bounds__(384, 1) -+__global__ static -+void pipeline_device(KernelParams const kernel_params) -+{ -+ extern __shared__ char shared_memory[]; -+ using MainloopPipeline = typename cutlass::PipelineTmaAsync; -+ using PipelineState = typename cutlass::PipelineState; -+ -+ using SharedStorage = SharedStorage; -+ SharedStorage& shared_storage = *reinterpret_cast(shared_memory); -+ -+ auto cta_layout = Layout{}; // (m,n) -> cta_id -+ int warp_group_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); -+ int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); -+ int warp_group_thread_idx = threadIdx.x % 128; -+ dim3 block_id_in_cluster = cute::block_id_in_cluster(); -+ -+ auto cluster_shape = ClusterShape{}; -+ -+ // #Producers = #RowsInCluster + #ColsInCluster - 1 -+ uint32_t const NumProducers = cute::size<0>(cluster_shape) + cute::size<1>(cluster_shape) - 1; -+ uint32_t const TmaTransactionBytes = static_cast(sizeof(uint32_t) * NumProducers); -+ uint32_t const per_cta_bytes = sizeof(uint32_t); -+ -+ // mbarrier.init -+ typename MainloopPipeline::Params params; -+ params.transaction_bytes = TmaTransactionBytes; -+ if (warp_group_idx == 0) { -+ params.role = MainloopPipeline::ThreadCategory::Producer; -+ } -+ else { -+ params.role = MainloopPipeline::ThreadCategory::Consumer; -+ } -+ params.is_leader = warp_group_thread_idx == 0; -+ params.num_consumers = 128; -+ -+ MainloopPipeline pipeline(shared_storage.storage, params); -+ -+ __syncthreads(); -+ -+ // Ensure All CTAs in Cluster have completed init before issuing commits -+ cute::cluster_arrive_relaxed(); -+ cute::cluster_wait(); -+ -+ -+ // Producer WarpGroup -+ if (warp_group_idx == 0) { -+ cutlass::arch::warpgroup_reg_alloc<232>(); -+ -+ int lane_predicate = cute::elect_one_sync(); -+ if (warp_idx_in_warpgroup == 0 && lane_predicate) { -+ -+ int tma_k_prologue = min(Stages, kernel_params.num_iterations); -+ -+ // Simulating Prologue TMA Loads -+ // For the DMA (prologue) - we start with an opposite phase - since we skip all waits -+ // i.e., we know that the buffer is indeed empty -+ PipelineState smem_pipe_write = make_producer_start_state(); -+ CUTLASS_PRAGMA_UNROLL -+ for(int i = 0; i < tma_k_prologue; ++i) { -+ pipeline.producer_acquire(smem_pipe_write); -+ // Simulating cp.async.bulk.tensor behavior -+ pipeline.producer_commit(smem_pipe_write.index(), per_cta_bytes); -+ ++smem_pipe_write; -+ } -+ int tma_k_iter = kernel_params.num_iterations - tma_k_prologue; -+ -+ // Simulating Mainloop TMA Loads -+ CUTE_NO_UNROLL -+ for ( ; tma_k_iter > 0; --tma_k_iter) { -+ -+ pipeline.producer_acquire(smem_pipe_write); -+ -+ // Simulating cp.async.bulk.tensor behavior -+ pipeline.producer_commit(smem_pipe_write.index(), per_cta_bytes); -+ -+ // Advance write stage -+ ++smem_pipe_write; -+ } -+ -+ // Tail Loop -+ // Handles the case where we never enter the mainloop -+ PipelineState tail = tma_k_prologue == Stages ? smem_pipe_write : PipelineState{}; -+ for ( int i = 0; i < tma_k_prologue; ++i) { -+ pipeline.producer_acquire(tail); -+ ++tail; -+ } -+ } -+ // Consumer WarpGroup -+ } else if(warp_group_idx == 1) { -+ cutlass::arch::warpgroup_reg_alloc<232>(); -+ -+ PipelineState smem_pipe_read; -+ PipelineState smem_pipe_release; -+ -+ // simulates accumulators + extra reg. pressure -+ int arr[168]; -+ -+ // Init Shared Memory read stages & PhaseBit -+ static constexpr uint32_t K_PIPE_MMAS = 1; -+ static_assert( K_PIPE_MMAS < Stages, "ERROR : Too many MMAs in flight"); -+ -+ // Total number of gemm iterations -+ auto gemm_k_iterations = kernel_params.num_iterations; -+ -+ // Simulating Prologue MMAs -+ int mma_k_prologue = min(K_PIPE_MMAS, gemm_k_iterations); -+ CUTLASS_PRAGMA_UNROLL -+ for (int iter = 0; iter < mma_k_prologue; ++iter) { -+ pipeline.consumer_wait(smem_pipe_read); -+ -+ warpgroup_arrive(); -+ // GMMA would typically happen here -+ -+ ++smem_pipe_read; -+ } -+ gemm_k_iterations -= mma_k_prologue; -+ -+ // Simulating Mainloop MMAs -+ CUTLASS_PRAGMA_NO_UNROLL -+ for ( ; gemm_k_iterations > 0; --gemm_k_iterations) { -+ -+ /// Wait on the smem_pipe_read stage / phase -+ pipeline.consumer_wait(smem_pipe_read); -+ -+ warpgroup_arrive(); -+ // GMMA would typically happen here -+ -+ // Dummy op - which will never happen -+ // But simulates high register usage. -+ CUTE_UNROLL -+ for(int i = 0; i < 168; ++i){ -+ if (threadIdx.x > 256){ -+ arr[i] += kernel_params.data_ptr[i]; -+ } -+ } -+ -+ pipeline.consumer_release(smem_pipe_release); -+ -+ // Advance stages -+ ++smem_pipe_read; -+ ++smem_pipe_release; -+ } -+ -+ // Dummy op - which will never happen -+ CUTE_UNROLL -+ for(int i = 0; i < 168; ++i){ -+ if (threadIdx.x > 256){ -+ kernel_params.data_ptr[i] = arr[i]; -+ } -+ } -+ -+ // Tail Loop -+ for (int i = 0; i < K_PIPE_MMAS; ++i){ -+ pipeline.consumer_release(smem_pipe_release); -+ ++smem_pipe_release; -+ } -+ -+ // Warp-Group #2 -+ } else { -+ cutlass::arch::warpgroup_reg_dealloc<40>(); -+ } -+} -+///////////////////////////////////////////////////// -+ -+/// Device NT GMMA + TMA specialized -+template -+struct PipelineTest { -+ -+ // -+ // Data members -+ // -+ static constexpr uint32_t Stages = Stages_; -+ static constexpr uint32_t kBlockSize = 128 * 3; -+ using ClusterShape = ClusterShape_; -+ -+ // -+ // Methods -+ // -+ -+ // Ctor -+ PipelineTest(){}; -+ -+ // Run CuTe GEMM kernel -+ cudaError_t run(uint32_t const kNumIters, -+ cudaStream_t stream = 0) { -+ -+ float elapsed_ms = 0.0f; -+ // Pipeline (multistage pipeline) -+ auto num_stages = Int{}; -+ auto cluster_shape = Shape, Int, _1>{}; -+ -+ // -+ // Configure and launch -+ // -+ int iterations = 1; -+ cudaEvent_t events[2]; -+ cudaError_t result; -+ -+ for (cudaEvent_t & event : events) { -+ result = cudaEventCreate(&event); -+ if (result != cudaSuccess) { -+ std::cerr << "Error: Failed to create event."; -+ return result; -+ } -+ } -+ -+ result = cudaEventRecord(events[0]); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Error: Failed to record start event."; -+ return result; -+ } -+ -+ for (int iter = 0; iter < iterations; ++iter) { -+ -+ using MainloopPipeline = typename cutlass::PipelineTmaAsync; -+ -+ int smem_size = int(sizeof(SharedStorage)); -+ -+ result = cudaFuncSetAttribute( -+ pipeline_device, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ // Launch a single Cluster, with kBlockSize threads per CTA -+ dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), 1); -+ dim3 dimGrid(size<0>(cluster_shape), size<1>(cluster_shape), 1); -+ dim3 dimBlock(kBlockSize,1,1); -+ -+ const void* kernel = (const void*)pipeline_device; -+ KernelParams params{kNumIters, nullptr}; -+ void* kernel_params[] = {reinterpret_cast(¶ms)}; -+ cutlass::ClusterLauncher::launch(dimGrid, dimCluster, dimBlock, smem_size, stream, kernel, kernel_params); -+ -+ } -+ -+ result = cudaEventRecord(events[1]); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Error: Failed to record stop event."; -+ return result; -+ } -+ -+ result = cudaDeviceSynchronize(); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Error: cudaDeviceSynchronize() failed" << std::endl; -+ return result; -+ } -+ -+ result = cudaEventElapsedTime(&elapsed_ms, events[0], events[1]); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to create event."; -+ return result; -+ } -+ -+ for (cudaEvent_t & event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ return cudaSuccess; -+ } -+}; -+ -+#if CUDA_12_0_SM90_FEATURES_SUPPORTED -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster1x1_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster1x1_Stage5) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ static constexpr uint32_t Stages = 5; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster1x1_Stage10) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ static constexpr uint32_t Stages = 10; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster2x2_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 2, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster2x2_Stage5) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 2, 1>; -+ static constexpr uint32_t Stages = 5; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster2x2_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 2, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster4x4_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster4x4_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster2x1_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 1, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster2x1_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 1, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster1x2_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 2, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster1x2_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 2, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster4x1_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 1, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster4x1_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 1, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster1x4_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 4, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster1x4_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 4, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster2x4_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 4, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster2x4_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 4, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster4x2_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 2, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS, Cluster4x2_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 2, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+#endif -diff --git a/3rdparty/cutlass/test/unit/pipeline/pipeline_tma_async_warp_specialized_persistent.cu b/3rdparty/cutlass/test/unit/pipeline/pipeline_tma_async_warp_specialized_persistent.cu -new file mode 100644 -index 0000000..4b6a3b1 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/pipeline/pipeline_tma_async_warp_specialized_persistent.cu -@@ -0,0 +1,585 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Unit test for the PipelineTmaAsync class used in a WarpSpecialized Persistent loop -+*/ -+ -+#define KERNEL_DBG_TRACE false -+ -+#include "../common/cutlass_unit_test.h" -+#include -+#include -+ -+#include -+#include -+ -+#include -+#include -+ -+#include "cutlass/core_io.h" -+#include "cutlass/util/print_error.hpp" -+#include "cutlass/util/GPU_Clock.hpp" -+ -+#include "testbed.h" -+#include "cutlass/pipeline.hpp" -+#include "cutlass/arch/barrier.h" -+#include "cute/arch/cluster_sm90.hpp" -+#include "cutlass/arch/barrier.h" -+#include "cutlass/arch/reg_reconfig.h" -+ -+ -+using namespace cute; -+using namespace cutlass; -+ -+//////////////////// KERNEL ///////////////////////// -+ -+template -+struct SharedStorage -+{ -+ typename cutlass::PipelineTmaAsync::SharedStorage pipeline_storage; -+ typename PingPongBarrier::SharedStorage pingpong_storage; -+}; -+ -+template -+struct CollectiveSimulation { -+ using MainloopPipeline = typename cutlass::PipelineTmaAsync; -+ using PipelineState = typename cutlass::PipelineState; -+ -+ CUTLASS_DEVICE -+ static void -+ dma_wg_simulation(MainloopPipeline pipeline, PipelineState tile_start_state_pipe, -+ uint32_t const num_iterations) { -+ uint32_t const per_cta_bytes = sizeof(uint32_t); -+ int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); -+ int lane_predicate = cute::elect_one_sync(); -+ if (warp_idx_in_warpgroup==0 && lane_predicate) { -+ -+ int tma_k_prologue = min(Stages, num_iterations); -+ -+ // Simulating Prologue TMA Loads -+ CUTLASS_PRAGMA_UNROLL -+ for(int i = 0; i < tma_k_prologue; ++i) { -+ pipeline.producer_acquire(tile_start_state_pipe); -+ // Simulating cp.async.bulk.tensor behavior -+ pipeline.producer_commit(tile_start_state_pipe.index(), per_cta_bytes); -+ ++tile_start_state_pipe; -+ } -+ int tma_k_iter = num_iterations - tma_k_prologue; -+ -+ PipelineState wr_pipe = tile_start_state_pipe; -+ // Simulating Mainloop TMA Loads -+ CUTE_NO_UNROLL -+ for ( ; tma_k_iter > 0; --tma_k_iter){ -+ -+ pipeline.producer_acquire(wr_pipe); -+ -+ // Simulating cp.async.bulk.tensor behavior -+ pipeline.producer_commit(wr_pipe.index(), per_cta_bytes); -+ -+ // Advance write stage -+ ++wr_pipe; -+ } -+ } -+ } -+ -+ CUTLASS_DEVICE -+ static void -+ math_wg_simulation(MainloopPipeline pipeline, PipelineState tile_start_state_pipe, -+ uint32_t const num_iterations, int* data_ptr) { -+ PipelineState rd_pipe = tile_start_state_pipe; -+ PipelineState release_pipe = rd_pipe; -+ -+ // simulates accumulators + extra reg. pressure -+ int arr[168]; -+ -+ // Init Shared Memory read stages & PhaseBit -+ static constexpr uint32_t K_PIPE_MMAS = 1; -+ static_assert( K_PIPE_MMAS < Stages, "ERROR : Too many MMAs in flight"); -+ -+ // Total number of gemm iterations -+ auto gemm_k_iterations = num_iterations; -+ -+ // Simulating Prologue MMAs -+ int mma_k_prologue = min(K_PIPE_MMAS, gemm_k_iterations); -+ CUTLASS_PRAGMA_UNROLL -+ for (int iter = 0; iter < mma_k_prologue; ++iter) { -+ pipeline.consumer_wait(rd_pipe); -+ -+ warpgroup_arrive(); -+ // GMMA would typically happen here -+ -+ ++rd_pipe; -+ } -+ gemm_k_iterations -= mma_k_prologue; -+ -+ // Simulating Mainloop MMAs -+ CUTLASS_PRAGMA_NO_UNROLL -+ for ( ; gemm_k_iterations > 0; --gemm_k_iterations) { -+ -+ /// Wait on the rd_pipe stage / phase -+ pipeline.consumer_wait(rd_pipe); -+ -+ warpgroup_arrive(); -+ // GMMA would typically happen here -+ -+ // Dummy op - which will never happen -+ // But simulates high register usage. -+ CUTE_UNROLL -+ for(int i = 0; i < 168; ++i){ -+ if (threadIdx.x > 384){ -+ arr[i] += data_ptr[i]; -+ } -+ } -+ -+ pipeline.consumer_release(release_pipe); -+ -+ // Advance stages -+ ++rd_pipe; -+ ++release_pipe; -+ } -+ -+ // Dummy op - which will never happen -+ CUTE_UNROLL -+ for(int i = 0; i < 168; ++i){ -+ if (threadIdx.x > 384){ -+ data_ptr[i] = arr[i]; -+ } -+ } -+ -+ // Tail Loop -+ for (int i = 0; i < K_PIPE_MMAS; ++i){ -+ pipeline.consumer_release(release_pipe); -+ ++release_pipe; -+ } -+ -+ } -+}; -+ -+struct KernelParams -+{ -+ uint32_t num_iterations; -+ int tiles_per_cluster; -+ int* data_ptr; -+}; -+ -+// Goal of this kernel is to complete deadlock-free -+template -+__launch_bounds__(384, 1) -+__global__ static -+void pipeline_device(KernelParams params) -+{ -+ extern __shared__ char shared_memory[]; -+ using DispatchPolicy = cutlass::gemm::MainloopSm90TmaGmmaWarpSpecialized; -+ using MainloopPipeline = typename cutlass::PipelineTmaAsync; -+ using PipelineState = typename cutlass::PipelineState; -+ -+ /* One for Mainloop and one for Epilogue */ -+ constexpr int StagesPerMathWarpGroup = 2; -+ constexpr int MathWarpGroupCountPersistent = 2; -+ using PingPongBarrier = typename cutlass::OrderedSequenceBarrier; -+ -+ using SharedStorage = SharedStorage; -+ SharedStorage& shared_storage = *reinterpret_cast(shared_memory); -+ -+ auto cta_layout = Layout{}; // (m,n) -> cta_id -+ int warp_group_idx = __shfl_sync(0xffffffff, threadIdx.x / NumThreadsPerWarpGroup, 0); -+ int warp_group_thread_idx = threadIdx.x % NumThreadsPerWarpGroup; -+ dim3 block_id_in_cluster = cute::block_id_in_cluster(); -+ -+ auto cluster_shape = ClusterShape{}; -+ -+ // #Producers = #RowsInCluster + #ColsInCluster - 1 -+ uint32_t const NumProducers = cute::size<0>(cluster_shape) + cute::size<1>(cluster_shape) - 1; -+ uint32_t const TmaTransactionBytes = static_cast(sizeof(uint32_t) * NumProducers); -+ -+ // mbarrier.init -+ typename MainloopPipeline::Params pipeline_params; -+ pipeline_params.transaction_bytes = TmaTransactionBytes; -+ if (warp_group_idx == 0) { -+ pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; -+ } -+ else { -+ pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; -+ } -+ pipeline_params.is_leader = warp_group_thread_idx == 0; -+ pipeline_params.num_consumers = NumThreadsPerWarpGroup; -+ -+ MainloopPipeline pipeline(shared_storage.pipeline_storage, pipeline_params); -+ PipelineState tile_start_state_pipe; -+ -+ int tiles_per_cluster = params.tiles_per_cluster; -+ -+ /* Offset pipeline start state for Math WG 2 */ -+ if (warp_group_idx == 2) { -+ // Update pipeline state for next persistent tile -+ tile_start_state_pipe.advance(params.num_iterations); -+ tiles_per_cluster--; -+ } -+ -+ typename PingPongBarrier::Params pingpong_params; -+ pingpong_params.group_id = warp_group_idx - 1; // Since DMA Warp Group Idx 0 will not participate -+ pingpong_params.group_size = NumThreadsPerWarpGroup; // Number of threads / participants in a group -+ PingPongBarrier math_wg_barrier(shared_storage.pingpong_storage, pingpong_params); -+ -+ __syncthreads(); -+ -+ // Ensure All CTAs in Cluster have completed init before issuing commits -+ cute::cluster_arrive_relaxed(); -+ cute::cluster_wait(); -+ -+ // Producer/DMA WarpGroup -+ if (warp_group_idx == 0) { -+ cutlass::arch::warpgroup_reg_dealloc<40>(); -+ // For the DMA (prologue) - we start with an opposite phase - since we skip all waits -+ // i.e., we know that the buffer is indeed empty -+ PipelineState tile_prologue_state_pipe = make_producer_start_state(); -+ while (tiles_per_cluster > 0) { -+ CollectiveSimulation::dma_wg_simulation(pipeline, tile_prologue_state_pipe, params.num_iterations); -+ // Update pipeline state for next persistent tile -+ tile_prologue_state_pipe.advance(params.num_iterations); -+ tiles_per_cluster--; -+ } -+ } -+ // Math WarpGropups -+ if(warp_group_idx == 1 || warp_group_idx == 2) { -+ cutlass::arch::warpgroup_reg_alloc<232>(); -+ while (tiles_per_cluster > 0) { -+ // MMA -+ math_wg_barrier.wait(); -+ CollectiveSimulation::math_wg_simulation(pipeline, tile_start_state_pipe, params.num_iterations, params.data_ptr); -+ math_wg_barrier.arrive(); -+ // Epilogue -+ math_wg_barrier.wait(); -+ // Simulates long running stage -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) -+ __nanosleep(100000); -+ #endif -+ math_wg_barrier.arrive(); -+ // Update pipeline state for next persistent tile -+ tile_start_state_pipe.advance(params.num_iterations * 2); -+ tiles_per_cluster -= 2; -+ } -+ } -+ -+ // Makes sure remote SMEM doesn't get destroyed -+ cute::cluster_arrive_relaxed(); -+ cute::cluster_wait(); -+} -+///////////////////////////////////////////////////// -+ -+/// Device NT GMMA + TMA specialized -+template -+struct PipelineTest { -+ -+ // -+ // Data members -+ // -+ static constexpr uint32_t Stages = Stages_; -+ static constexpr uint32_t kBlockSize = 128 * 3; -+ using ClusterShape = ClusterShape_; -+ -+ // -+ // Methods -+ // -+ -+ // Run CuTe GEMM kernel -+ cudaError_t run(uint32_t const kNumIters, -+ cudaStream_t stream = 0) { -+ -+ float elapsed_ms = 0.0f; -+ // Pipeline (multistage pipeline) -+ auto num_stages = Int{}; -+ auto cluster_shape = Shape, Int, _1>{}; -+ -+ // -+ // Configure and launch -+ // -+ int iterations = 1; -+ cudaEvent_t events[2]; -+ cudaError_t result; -+ -+ for (cudaEvent_t & event : events) { -+ result = cudaEventCreate(&event); -+ if (result != cudaSuccess) { -+ std::cerr << "Error: Failed to create event."; -+ return result; -+ } -+ } -+ -+ result = cudaEventRecord(events[0]); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Error: Failed to record start event."; -+ return result; -+ } -+ -+ for (int iter = 0; iter < iterations; ++iter) { -+ -+ using MainloopPipeline = typename cutlass::PipelineTmaAsync; -+ -+ constexpr int StagesPerMathWarpGroup = 2; -+ constexpr int MathWarpGroupCountPersistent = 2; -+ int smem_size = int(sizeof(SharedStorage>)); -+ -+ result = cudaFuncSetAttribute( -+ pipeline_device, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ // Launch a single Cluster, with kBlockSize threads per CTA -+ dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), 1); -+ dim3 dimGrid(size<0>(cluster_shape), size<1>(cluster_shape), 1); -+ dim3 dimBlock(kBlockSize,1,1); -+ -+ int tiles_per_cluster = (kNumIters % 10) + 1; -+ printf("Persistent version: Tiles per Cluster = %d\n", tiles_per_cluster); -+ -+ const void* kernel = (const void*)pipeline_device; -+ KernelParams params{kNumIters, tiles_per_cluster, nullptr}; -+ void *kernel_params[] = {¶ms}; -+ cutlass::ClusterLauncher::launch(dimGrid, dimCluster, dimBlock, smem_size, stream, kernel, kernel_params); -+ -+ } -+ -+ result = cudaEventRecord(events[1]); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Error: Failed to record stop event."; -+ return result; -+ } -+ -+ result = cudaDeviceSynchronize(); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Error: cudaDeviceSynchronize() failed" << std::endl; -+ return result; -+ } -+ -+ result = cudaEventElapsedTime(&elapsed_ms, events[0], events[1]); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to create event."; -+ return result; -+ } -+ -+ for (cudaEvent_t & event : events) { -+ (void)cudaEventDestroy(event); -+ } -+ -+ return cudaSuccess; -+ } -+}; -+ -+#if CUDA_12_0_SM90_FEATURES_SUPPORTED -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster1x1_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster1x1_Stage5) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ static constexpr uint32_t Stages = 5; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster1x1_Stage10) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 1, 1>; -+ static constexpr uint32_t Stages = 10; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster2x2_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 2, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster2x2_Stage5) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 2, 1>; -+ static constexpr uint32_t Stages = 5; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster2x2_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 2, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster4x4_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster4x4_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster2x1_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 1, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster2x1_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 1, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster1x2_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 2, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster1x2_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 2, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster4x1_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 1, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster4x1_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 1, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster1x4_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 4, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster1x4_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<1, 4, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster2x4_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 4, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster2x4_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<2, 4, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster4x2_Stage2) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 2, 1>; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_PipelineTmaAsync_WS_Persistent, Cluster4x2_Stage7) { -+ Options options; -+ using ClusterShape = cutlass::gemm::GemmShape<4, 2, 1>; -+ static constexpr uint32_t Stages = 7; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+#endif -diff --git a/3rdparty/cutlass/test/unit/pipeline/sequence_barrier.cu b/3rdparty/cutlass/test/unit/pipeline/sequence_barrier.cu -new file mode 100644 -index 0000000..f426ca0 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/pipeline/sequence_barrier.cu -@@ -0,0 +1,226 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Unit test for the OrderedSequenceBarrier class -+*/ -+ -+#include "../common/cutlass_unit_test.h" -+#include -+#include -+ -+#include -+#include -+ -+#include -+#include -+ -+#include "cutlass/core_io.h" -+ -+#include "cutlass/util/print_error.hpp" -+#include "cutlass/util/GPU_Clock.hpp" -+ -+#include "testbed.h" -+#include "cutlass/pipeline.hpp" -+#include "cutlass/arch/barrier.h" -+#include "cute/arch/cluster_sm90.hpp" -+ -+using namespace cute; -+ -+//////////////////// KERNEL ///////////////////////// -+ -+template -+struct SharedStorage -+{ -+ typename OrderedSequencer::SharedStorage storage; -+}; -+ -+// Goal of this kernel is to complete deadlock-free -+template -+__global__ static -+void ordered_sequence_device(uint32_t const num_iterations) -+{ -+ -+ extern __shared__ char shared_memory[]; -+ using SequenceBarrier = typename cutlass::OrderedSequenceBarrier; -+ using SmemStorage = SharedStorage; -+ -+ SmemStorage& shared_storage = *reinterpret_cast(shared_memory); -+ -+ int group_idx = threadIdx.x / ThreadsPerGroup; -+ -+ typename SequenceBarrier::Params params; -+ params.group_id = group_idx; // sequence ID -+ params.group_size = ThreadsPerGroup; // Number of threads / participants in a group -+ -+ SequenceBarrier barrier(shared_storage.storage, params); -+ -+ // Ensure All CTAs in Cluster have completed init before issuing commits -+ __syncthreads(); -+ cute::cluster_arrive_relaxed(); -+ cute::cluster_wait(); -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int i = 0; i < num_iterations; ++i){ -+ -+ barrier.wait(); -+ // STAGE 1 CODE... -+ #ifndef NDEBUG -+ int thread_idx_in_group = threadIdx.x % ThreadsPerGroup; -+ if (thread_idx_in_group == 0) { -+ printf("STAGE 0 : Group_IDX : %d, id = %d, iter = %d, tidx = %d\n", group_idx, params.id, i, threadIdx.x); -+ } -+ #endif -+ // Simulates long running stage -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) -+ __nanosleep(100000); -+ #endif -+ barrier.arrive(); -+ -+ barrier.wait(); -+ // STAGE 2 CODE... -+ #ifndef NDEBUG -+ if (thread_idx_in_group == 0) { -+ printf("STAGE 1 : Group_IDX : %d, id = %d, iter = %d, tidx = %d\n", group_idx, params.id, i, threadIdx.x); -+ } -+ #endif -+ // Simulates long running stage -+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) -+ __nanosleep(100000); -+ #endif -+ barrier.arrive(); -+ } -+ -+ // To make sure remote SMEM doesn't get destroyed -+ cute::cluster_arrive(); -+ cute::cluster_wait(); -+} -+///////////////////////////////////////////////////// -+ -+template -+struct PipelineTest { -+ -+ // -+ // Data members -+ // -+ static constexpr uint32_t ThreadsPerGroup = 128; -+ static constexpr uint32_t BlockSize = GroupCount_ * ThreadsPerGroup; -+ static constexpr uint32_t Stages = Stages_; -+ static constexpr uint32_t GroupCount = GroupCount_; -+ using SequenceBarrier = typename cutlass::OrderedSequenceBarrier; -+ using SmemStorage = SharedStorage; -+ -+ // -+ // Methods -+ // -+ -+ // Run CuTe GEMM kernel -+ cudaError_t run(uint32_t const kNumIters, -+ cudaStream_t stream = nullptr) { -+ -+ // Pipeline (multistage pipeline) -+ auto cluster_shape = Shape<_1, _1, _1>{}; -+ -+ // -+ // Configure and launch -+ // -+ int iterations = 1; -+ cudaError_t result; -+ -+ for (int iter = 0; iter < iterations; ++iter) { -+ -+ int smem_size = int(sizeof(SmemStorage)); -+ -+ result = cudaFuncSetAttribute( -+ ordered_sequence_device, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, -+ smem_size); -+ -+ // Launch a single Cluster, with 128 thread per CTA -+ dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), size<2>(cluster_shape)); -+ dim3 dimGrid(size<0>(cluster_shape), size<1>(cluster_shape), 1); -+ dim3 dimBlock(BlockSize,1,1); -+ -+ const void* kernel = (const void*)ordered_sequence_device; -+ int iters = kNumIters; -+ void* kernel_params[] = {reinterpret_cast(&iters)}; -+ cutlass::ClusterLauncher::launch(dimGrid, dimCluster, dimBlock, smem_size, stream, kernel, kernel_params); -+ -+ } // profiling loop ends -+ -+ result = cudaDeviceSynchronize(); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Error: cudaDeviceSynchronize() failed" << std::endl; -+ return result; -+ } -+ -+ return cudaSuccess; -+ } -+}; -+ -+#if CUDA_12_0_SM90_FEATURES_SUPPORTED -+TEST(SM90_Verify_OrderedSequence, Depth_2_Length_2) { -+ Options options; -+ static constexpr uint32_t GroupCount = 2; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_OrderedSequence, Depth_2_Length_3) { -+ Options options; -+ static constexpr uint32_t GroupCount = 3; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_OrderedSequence, Depth_2_Length_4) { -+ Options options; -+ static constexpr uint32_t GroupCount = 4; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+ -+TEST(SM90_Verify_OrderedSequence, Depth_2_Length_5) { -+ Options options; -+ static constexpr uint32_t GroupCount = 5; -+ static constexpr uint32_t Stages = 2; -+ using Test = PipelineTest; -+ Testbed testbed(options); -+ EXPECT_TRUE(testbed.verification()); -+} -+#endif -diff --git a/3rdparty/cutlass/test/unit/pipeline/testbed.h b/3rdparty/cutlass/test/unit/pipeline/testbed.h -new file mode 100644 -index 0000000..b809e74 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/pipeline/testbed.h -@@ -0,0 +1,145 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Common Testbed file shared by Pipeline unit tests -+*/ -+ -+#include -+#include -+#include -+#include -+ -+#include "cutlass/util/command_line.h" -+#include "../common/cutlass_unit_test.h" -+ -+#if CUDA_12_0_SM90_FEATURES_SUPPORTED -+ #define CUTLASS_UNIT_TEST_PIPELINE true -+#else -+ #define CUTLASS_UNIT_TEST_PIPELINE false -+#endif -+ -+// Command line test options -+struct Options { -+ // -+ // Data Members -+ // -+ bool help; -+ bool verification_enabled; -+ int SM_count; -+ int clock_MHz; -+ -+ // -+ // Methods -+ // -+ Options(): -+ help(false), -+ verification_enabled(true), -+ SM_count(116), -+ clock_MHz(1477) -+ { } -+ -+ void parse(int argc, char const **args) { -+ cutlass::CommandLine cmd(argc, args); -+ -+ if (cmd.check_cmd_line_flag("help")) { -+ help = true; -+ } -+ -+ cmd.get_cmd_line_argument("verification-enabled", verification_enabled, true); -+ cmd.get_cmd_line_argument("sm-count", SM_count, 116); -+ cmd.get_cmd_line_argument("clock", clock_MHz, 1477); -+ } -+ -+ /// Prints the usage statement. -+ std::ostream & print_usage(std::ostream &out) const { -+ -+ out << "Options:\n\n" -+ << " --help If specified, displays this usage statement.\n\n" -+ << " --verification-enabled= Enable/Disable verification\n" -+ << " --sm-count= Number of SMs on the chip\n" -+ << " --clock= Locked clock value in Mhz\n"; -+ -+ return out; -+ } -+}; -+ -+// -+// Testbed -+// -+ -+template -+struct Testbed { -+private: -+ // Commandline options -+ Options options; -+ -+ void run_test(uint32_t const kNumIters) { -+ -+ // Run CuTe Gemm -+ Pipeline pipeline; -+ -+ cudaError_t result = pipeline.run(kNumIters); -+ -+ CUTE_CHECK_LAST(); -+ } -+ -+ -+public: -+ Testbed(Options const &options_) : options(options_) { -+ int device_id = 0; -+ cudaDeviceProp device_prop; -+ CUTE_CHECK_ERROR(cudaSetDevice(device_id)); -+ CUTE_CHECK_ERROR(cudaGetDeviceProperties(&device_prop, device_id)); -+ -+ if (device_prop.major < 1) { -+ fprintf(stderr, "Device does not support CUDA.\n"); -+ exit(1); -+ } -+ } -+ -+ /// Run verification Gemm problem sizes -+ bool verification() { -+ -+ std::array kNumIters; -+ -+ for (int i = 0; i < kNumIters.size(); ++i) { -+ kNumIters[i] = (rand() % 1000) + 1; -+ } -+ -+ for (int n : kNumIters) { -+ std::cout << "Stages = " << Pipeline::Stages << " kNumIters = " << n << "\n"; -+ run_test(n); -+ } -+ -+ return true; -+ } -+}; -diff --git a/3rdparty/cutlass/test/unit/reduction/device/tensor_reduce_contiguous.cu b/3rdparty/cutlass/test/unit/reduction/device/tensor_reduce_contiguous.cu -new file mode 100644 -index 0000000..c582eb5 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/reduction/device/tensor_reduce_contiguous.cu -@@ -0,0 +1,476 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for TensorReduce family of device-wide operators -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/complex.h" -+#include "cutlass/reduction/thread/reduction_operators.h" -+#include "cutlass/reduction/device/tensor_reduce.h" -+ -+#include "cutlass/functional.h" -+#include "cutlass/layout/tensor.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/device/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// This reduces the C dimension, transforming an NHWC tensor into NHWC with C=1. -+template -+bool TestAllReduction_NHWC_reduce_c(ElementCompute reduction_identity = ElementCompute()) { -+ -+ using Layout = typename TensorReduction::Layout; -+ using ElementOutput = typename TensorReduction::ElementOutput; -+ using ElementSource = typename TensorReduction::ElementSource; -+ -+ int const kV = TensorReduction::kVectorLength; -+ -+ int const N_indices[] = {3, 13}; -+ int const H_indices[] = {5, 17}; -+ int const W_indices[] = {7, 19}; -+ int const C_indices[] = {2049, 2048, 2047, 384, 64, 48, 32, 24, 16, 12, 8, 6, 4, 3, 2, 1}; -+ -+ for (int N : N_indices) { -+ for (int H : H_indices) { -+ for (int W : W_indices) { -+ for (int Cx : C_indices) { -+ -+ int C = Cx * kV; -+ -+ cutlass::HostTensor src_tensor({N, H, W, C}); -+ cutlass::HostTensor dst_tensor({N, H, W, 1}); -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ src_tensor.host_view(), 17, 10, -10, 0); -+ -+ dst_tensor.sync_device(); -+ src_tensor.sync_device(); -+ -+ // Execute a tensor reduction over rank 3 (the 'C' dimension is reduced; NHWC => NHW) -+ TensorReduction reduction(src_tensor.extent(), 3); -+ -+ cutlass::DeviceAllocation device_workspace(reduction.workspace_size()); -+ -+ cutlass::Status status = reduction.reduce( -+ dst_tensor.device_ref(), -+ src_tensor.device_ref(), -+ device_workspace.get(), -+ reduction_identity -+ ); -+ -+ EXPECT_EQ(status, cutlass::Status::kSuccess); -+ EXPECT_EQ(cudaDeviceSynchronize(), cudaSuccess); -+ -+ dst_tensor.sync_host(); -+ -+ typename TensorReduction::ReductionOp reduction_op; -+ -+ // -+ // Reference check -+ // -+ for (int n = 0; n < src_tensor.extent().n(); ++n) { -+ for (int h = 0; h < src_tensor.extent().h(); ++h) { -+ for (int w = 0; w < src_tensor.extent().w(); ++w) { -+ -+ ElementCompute c_accum = reduction_identity; -+ -+ for (int c = 0; c < src_tensor.extent().c(); ++c) { -+ c_accum = reduction_op(c_accum, ElementCompute(src_tensor.at({n, h, w, c}))); -+ } -+ -+ ElementCompute got = ElementCompute(dst_tensor.at({n, h, w, 0})); -+ -+ bool equal = (c_accum == got); -+ -+ EXPECT_TRUE(equal); -+ if (!equal) { -+ -+ std::cerr -+ << "Error at location (" << n << ", " << h << ", " << w << ", 0)" << std::endl; -+ -+ std::cerr -+ << " expected: " << c_accum << std::endl -+ << " got: " << got << std::endl; -+ -+ std::cerr -+ << "Problem: " << src_tensor.extent() << " -> " -+ << dst_tensor.extent() << std::endl; -+ -+ std::cerr -+ << " Grid: " << reduction.reduction_strided.grid_shape -+ << "\n Block: " << reduction.reduction_strided.threadblock_shape << std::endl -+ << " FInal: " << reduction.reduction_strided.grid_final -+ << "\n Block: " << reduction.reduction_strided.threadblock_final << "\n"; -+ -+ return false; -+ } -+ -+ } //w -+ } // h -+ } // n -+ -+ // -+ // Next problem -+ // -+ -+ } // C -+ } // W -+ } // H -+ } // N -+ -+ return true; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHW -+TEST(Reduction_TensorReduce, nhwc_reduce_c_f32x1) { -+ -+ using Layout = cutlass::layout::TensorNHWC; -+ using ElementOutput = float; -+ using ElementSource = float; -+ using ElementCompute = float; -+ int const kV = 1; -+ -+ // Define the functor -+ using Functor = cutlass::plus; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_c()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHW -+TEST(Reduction_TensorReduce, nhwc_reduce_c_f32x1_f16x1) { -+ -+ using Layout = cutlass::layout::TensorNHWC; -+ using ElementOutput = float; -+ using ElementSource = cutlass::half_t; -+ using ElementCompute = float; -+ int const kV = 1; -+ -+ // Define the functor -+ using Functor = cutlass::plus; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_c()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHW -+TEST(Reduction_TensorReduce, nhwc_reduce_c_f32x2) { -+ -+ using Layout = cutlass::layout::TensorNHWC; -+ using ElementOutput = float; -+ using ElementSource = float; -+ using ElementCompute = float; -+ int const kV = 2; -+ -+ // Define the functor -+ using Functor = cutlass::plus; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_c()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHW -+TEST(Reduction_TensorReduce, nhwc_reduce_c_f32x2_f16x2) { -+ -+ using Layout = cutlass::layout::TensorNHWC; -+ using ElementOutput = float; -+ using ElementSource = cutlass::half_t; -+ using ElementCompute = float; -+ int const kV = 2; -+ -+ // Define the functor -+ using Functor = cutlass::plus; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_c()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHW -+TEST(Reduction_TensorReduce, nhwc_reduce_c_f32x4) { -+ -+ using Layout = cutlass::layout::TensorNHWC; -+ using ElementOutput = float; -+ using ElementSource = float; -+ using ElementCompute = float; -+ int const kV = 4; -+ -+ // Define the functor -+ using Functor = cutlass::plus; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_c()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHW -+TEST(Reduction_TensorReduce, nhwc_reduce_c_f32x4_f16x4) { -+ -+ using Layout = cutlass::layout::TensorNHWC; -+ using ElementOutput = float; -+ using ElementSource = cutlass::half_t; -+ using ElementCompute = float; -+ int const kV = 4; -+ -+ // Define the functor -+ using Functor = cutlass::plus; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_c()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHW -+TEST(Reduction_TensorReduce, nhwc_maximum_c_f32x4) { -+ -+ using Layout = cutlass::layout::TensorNHWC; -+ using ElementOutput = float; -+ using ElementSource = float; -+ using ElementCompute = float; -+ int const kV = 4; -+ -+ // Define the functor -+ using Functor = cutlass::maximum; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_c( -std::numeric_limits::max() )); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHW -+TEST(Reduction_TensorReduce, nhwc_minimum_c_f32x4) { -+ -+ using Layout = cutlass::layout::TensorNHWC; -+ using ElementOutput = float; -+ using ElementSource = float; -+ using ElementCompute = float; -+ int const kV = 4; -+ -+ // Define the functor -+ using Functor = cutlass::minimum; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_c( std::numeric_limits::max() )); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHW -+TEST(Reduction_TensorReduce, nhwc_ANY_c_s32) { -+ -+ using Layout = cutlass::layout::TensorNHWC; -+ using ElementOutput = int; -+ using ElementSource = int; -+ using ElementCompute = int; -+ int const kV = 1; -+ -+ // Define the functor -+ using Functor = cutlass::logical_or; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_c( ElementCompute(0) )); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHW -+TEST(Reduction_TensorReduce, nhwc_ALL_c_s32) { -+ -+ using Layout = cutlass::layout::TensorNHWC; -+ using ElementOutput = int; -+ using ElementSource = int; -+ using ElementCompute = int; -+ int const kV = 1; -+ -+ // Define the functor -+ using Functor = cutlass::logical_and; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_c( ElementCompute(1) )); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHW -+TEST(Reduction_TensorReduce, nhwc_ANY_c_f32) { -+ -+ using Layout = cutlass::layout::TensorNHWC; -+ using ElementOutput = float; -+ using ElementSource = float; -+ using ElementCompute = float; -+ int const kV = 1; -+ -+ // Define the functor -+ using Functor = cutlass::logical_or; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_c( ElementCompute(0) )); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHW -+TEST(Reduction_TensorReduce, nhwc_ALL_c_f32) { -+ -+ using Layout = cutlass::layout::TensorNHWC; -+ using ElementOutput = float; -+ using ElementSource = float; -+ using ElementCompute = float; -+ int const kV = 1; -+ -+ // Define the functor -+ using Functor = cutlass::logical_and; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_c( ElementCompute(1) )); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/reduction/device/tensor_reduce_strided.cu b/3rdparty/cutlass/test/unit/reduction/device/tensor_reduce_strided.cu -new file mode 100644 -index 0000000..7e9ccc3 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/reduction/device/tensor_reduce_strided.cu -@@ -0,0 +1,523 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for TensorReduce family of device-wide operators -+*/ -+ -+#include -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/complex.h" -+#include "cutlass/reduction/thread/reduction_operators.h" -+#include "cutlass/reduction/device/tensor_reduce.h" -+ -+#include "cutlass/functional.h" -+#include "cutlass/layout/tensor.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/device/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// This reduces the W dimension, transforming an NHWC tensor into NHWC with W=1. -+template < -+ typename TensorReduction, -+ typename ElementCompute = typename TensorReduction::ElementCompute -+> -+bool TestAllReduction_NHWC_reduce_w(ElementCompute reduction_identity = ElementCompute()) { -+ -+ using Layout = typename TensorReduction::Layout; -+ using ElementOutput = typename TensorReduction::ElementOutput; -+ using ElementSource = typename TensorReduction::ElementSource; -+ -+ int const kV = TensorReduction::kVectorLength; -+ -+ int const N_indices[] = {1, 2, 5, 10}; -+ int const H_indices[] = {1, 3, 9 }; -+ int const W_indices[] = {1, 5, 19, 40, 224}; -+ int const C_indices[] = { -+ kV, -+ 2 * kV, -+ 5 * kV, -+ 9 * kV, -+ 17 * kV, -+ 39 * kV, -+ 257 * kV, -+ kV * 760 -+ }; -+ -+ using Element = int; -+ -+ for (int N : N_indices) { -+ for (int H : H_indices) { -+ for (int W : W_indices) { -+ for (int C : C_indices) { -+ -+ cutlass::HostTensor src_tensor({N, H, W, C}); -+ cutlass::HostTensor dst_tensor({N, H, 1, C}); -+ -+ cutlass::reference::host::TensorFillRandomUniform( -+ src_tensor.host_view(), 17, 10, -10, 0); -+ -+ cutlass::reference::host::BlockFillSequential( -+ dst_tensor.host_data(), dst_tensor.capacity()); -+ -+ dst_tensor.sync_device(); -+ src_tensor.sync_device(); -+ -+ // Execute a tensor reduction over rank 2 (the 'W' dimension is reduced; NHWC => NHC) -+ TensorReduction reduction(src_tensor.extent(), 2); -+ -+ cutlass::DeviceAllocation device_workspace(reduction.workspace_size()); -+ -+ cutlass::Status status = reduction.reduce( -+ dst_tensor.device_ref(), -+ src_tensor.device_ref(), -+ device_workspace.get(), -+ reduction_identity -+ ); -+ -+ EXPECT_EQ(status, cutlass::Status::kSuccess); -+ EXPECT_EQ(cudaDeviceSynchronize(), cudaSuccess); -+ // Reference check -+ dst_tensor.sync_host(); -+ -+ typename TensorReduction::ReductionOp reduction_op; -+ -+ for (int n = 0; n < src_tensor.extent().n(); ++n) { -+ for (int h = 0; h < src_tensor.extent().h(); ++h) { -+ for (int c = 0; c < src_tensor.extent().c(); ++c) { -+ -+ ElementCompute w_accum = reduction_identity; -+ -+ for (int w = 0; w < src_tensor.extent().w(); ++w) { -+ w_accum = reduction_op(w_accum, ElementCompute(src_tensor.at({n, h, w, c}))); -+ } -+ -+ ElementCompute got = ElementCompute(dst_tensor.at({n, h, 0, c})); -+ -+ bool equal = (w_accum == got); -+ -+ EXPECT_TRUE(equal); -+ if (!equal) { -+ -+ std::cerr -+ << "Error at location (" << n << ", " << h << ", 0, " << c << ")" << std::endl; -+ -+ std::cerr -+ << " expected: " << w_accum << std::endl -+ << " got: " << got << std::endl; -+ -+ std::cerr -+ << "Problem: " << src_tensor.extent() << " -> " -+ << dst_tensor.extent() << std::endl; -+ -+ std::cerr -+ << " Grid: " << reduction.reduction_strided.grid_shape -+ << "\n Block: " << reduction.reduction_strided.threadblock_shape << std::endl -+ << " Final: " << reduction.reduction_strided.grid_final -+ << "\n Block: " << reduction.reduction_strided.threadblock_final << "\n"; -+ -+ return false; -+ } -+ } -+ } -+ } -+ } -+ } -+ } -+ } -+ -+ return true; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHC -+TEST(Reduction_TensorReduce, nhwc_reduce_w_f32x8_f16x8) { -+ -+ int const kV = 8; -+ using ElementOutput = float; -+ using ElementSource = cutlass::half_t; -+ using ElementCompute = float; -+ using Layout = cutlass::layout::TensorNHWC; -+ -+ // Define the functor -+ using Functor = cutlass::plus; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_w()); -+} -+ -+/// Test tensor reduction from NHWC to NHC -+TEST(Reduction_TensorReduce, nhwc_reduce_w_f32x2_f16x2) { -+ -+ int const kV = 2; -+ using ElementOutput = float; -+ using ElementSource = cutlass::half_t; -+ using ElementCompute = float; -+ using Layout = cutlass::layout::TensorNHWC; -+ -+ // Define the functor -+ using Functor = cutlass::plus; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_w()); -+} -+ -+/// Test tensor reduction from NHWC to NHC -+TEST(Reduction_TensorReduce, nhwc_reduce_w_f32x1_f16x1) { -+ -+ int const kV = 1; -+ using ElementOutput = float; -+ using ElementSource = cutlass::half_t; -+ using ElementCompute = float; -+ using Layout = cutlass::layout::TensorNHWC; -+ -+ // Define the functor -+ using Functor = cutlass::plus; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_w()); -+} -+ -+/// Test tensor reduction from NHWC to NHC -+TEST(Reduction_TensorReduce, nhwc_reduce_w_s32x4) { -+ -+ int const kV = 4; -+ using Element = int; -+ using Layout = cutlass::layout::TensorNHWC; -+ -+ // Define the functor -+ using Functor = cutlass::plus; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ Element, -+ Element, -+ Layout, -+ Functor, -+ kV, -+ Element -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_w()); -+} -+ -+/// Test tensor reduction from NHWC to NHC -+TEST(Reduction_TensorReduce, nhwc_reduce_w_cf32) { -+ -+ int const kV = 1; -+ using ElementOutput = cutlass::complex; -+ using ElementSource = cutlass::complex; -+ using ElementCompute = cutlass::complex; -+ using Layout = cutlass::layout::TensorNHWC; -+ -+ // Define the functor -+ using Functor = cutlass::plus; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_w()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHC -+TEST(Reduction_TensorReduce, nhwc_maximum_w_cf32) { -+ -+ int const kV = 1; -+ using ElementOutput = float; -+ using ElementSource = float; -+ using ElementCompute = float; -+ using Layout = cutlass::layout::TensorNHWC; -+ -+ // Define the functor -+ using Functor = cutlass::maximum; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_w( -std::numeric_limits::max() )); -+} -+ -+/// Test tensor reduction from NHWC to NHC -+TEST(Reduction_TensorReduce, nhwc_minimum_w_cf32) { -+ -+ int const kV = 1; -+ using ElementOutput = float; -+ using ElementSource = float; -+ using ElementCompute = float; -+ using Layout = cutlass::layout::TensorNHWC; -+ -+ // Define the functor -+ using Functor = cutlass::minimum; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_w(std::numeric_limits::max())); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHC -+TEST(Reduction_TensorReduce, nhwc_XOR_w_u32) { -+ -+ int const kV = 1; -+ using ElementOutput = int; -+ using ElementSource = int; -+ using ElementCompute = int; -+ using Layout = cutlass::layout::TensorNHWC; -+ -+ // Define the functor -+ using Functor = cutlass::bit_xor; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_w()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHC -+TEST(Reduction_TensorReduce, nhwc_AND_w_s32) { -+ -+ int const kV = 1; -+ using ElementOutput = unsigned; -+ using ElementSource = unsigned; -+ using ElementCompute = unsigned; -+ using Layout = cutlass::layout::TensorNHWC; -+ -+ // Define the functor -+ using Functor = cutlass::bit_and; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_w(0xffffffff)); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHC -+TEST(Reduction_TensorReduce, nhwc_OR_w_u32) { -+ -+ int const kV = 1; -+ using ElementOutput = int; -+ using ElementSource = int; -+ using ElementCompute = int; -+ using Layout = cutlass::layout::TensorNHWC; -+ -+ // Define the functor -+ using Functor = cutlass::bit_or; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_w()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHC -+TEST(Reduction_TensorReduce, nhwc_ANY_w_s32) { -+ -+ int const kV = 1; -+ using ElementOutput = int; -+ using ElementSource = int; -+ using ElementCompute = int; -+ using Layout = cutlass::layout::TensorNHWC; -+ -+ // Define the functor -+ using Functor = cutlass::logical_or; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_w(ElementCompute(0))); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHC -+TEST(Reduction_TensorReduce, nhwc_ALL_w_s32) { -+ -+ int const kV = 1; -+ using ElementOutput = int; -+ using ElementSource = int; -+ using ElementCompute = int; -+ using Layout = cutlass::layout::TensorNHWC; -+ -+ // Define the functor -+ using Functor = cutlass::logical_and; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_w(ElementCompute(1))); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHC -+TEST(Reduction_TensorReduce, nhwc_ANY_w_f32) { -+ -+ int const kV = 1; -+ using ElementOutput = float; -+ using ElementSource = float; -+ using ElementCompute = float; -+ using Layout = cutlass::layout::TensorNHWC; -+ -+ // Define the functor -+ using Functor = cutlass::logical_or; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_w(ElementCompute(0))); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Test tensor reduction from NHWC to NHC -+TEST(Reduction_TensorReduce, nhwc_ALL_w_f32) { -+ -+ int const kV = 1; -+ using ElementOutput = float; -+ using ElementSource = float; -+ using ElementCompute = float; -+ using Layout = cutlass::layout::TensorNHWC; -+ -+ // Define the functor -+ using Functor = cutlass::logical_and; -+ -+ using TensorReduction = cutlass::reduction::device::TensorReduction< -+ ElementOutput, -+ ElementSource, -+ Layout, -+ Functor, -+ kV, -+ ElementCompute -+ >; -+ -+ EXPECT_TRUE(TestAllReduction_NHWC_reduce_w(ElementCompute(1))); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/reduction/kernel/reduce_splitk.cu b/3rdparty/cutlass/test/unit/reduction/kernel/reduce_splitk.cu -new file mode 100644 -index 0000000..6a990f4 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/reduction/kernel/reduce_splitk.cu -@@ -0,0 +1,389 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests for device-wide GEMM interface -+*/ -+ -+#include -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/reduction/kernel/reduce_split_k.h" -+#include "cutlass/reduction/thread/reduction_operators.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace reduction { -+ -+template -+__global__ void kernel_reduce_splitk(typename ReductionKernel::Params params) { -+ -+ __shared__ typename ReductionKernel::SharedStorage shared_storage; -+ -+ ReductionKernel reduction_op; -+ -+ reduction_op(params, shared_storage); -+} -+ -+template -+class ReduceSplitKTestbed { -+public: -+ -+ using ElementAccumulator = typename ReductionKernel::ElementAccumulator; -+ using ElementWorkspace = typename ReductionKernel::ElementWorkspace; -+ using ElementOutput = typename ReductionKernel::ElementOutput; -+ using Layout = cutlass::layout::RowMajor; -+ -+public: -+ -+ cutlass::Distribution::Kind distribution_workspace; -+ cutlass::Distribution::Kind distribution_source; -+ uint64_t seed; -+ -+public: -+ -+ /// Ctor -+ ReduceSplitKTestbed( -+ cutlass::Distribution::Kind distribution_workspace = cutlass::Distribution::Uniform, -+ cutlass::Distribution::Kind distribution_source = cutlass::Distribution::Uniform, -+ uint64_t seed = 2019 -+ ): -+ distribution_workspace(distribution_workspace), -+ distribution_source(distribution_source), -+ seed(seed) { -+ -+ } -+ -+ /// Helper to initialize a tensor view -+ template -+ bool initialize_tensor(cutlass::TensorView view, -+ cutlass::Distribution::Kind dist_kind, -+ uint64_t seed) { -+ -+ if (dist_kind == cutlass::Distribution::Uniform) { -+ cutlass::reference::host::TensorFillRandomUniform(view, seed, 8, -8, 0); -+ } -+ else if (dist_kind == cutlass::Distribution::Gaussian) { -+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5, -1); -+ } else if (dist_kind == cutlass::Distribution::Identity) { -+ cutlass::reference::host::TensorFillIdentity(view); -+ } else if (dist_kind == cutlass::Distribution::Sequential) { -+ cutlass::reference::host::BlockFillSequential(view.data(), -+ view.capacity()); -+ } else { -+ // TODO: Implement the rest -+ EXPECT_TRUE(false) << "Not implemented"; -+ return false; -+ } -+ -+ return true; -+ } -+ -+ /// Runs a single problem size -+ bool run( -+ cutlass::MatrixCoord problem_size, -+ int partitions, -+ ElementAccumulator alpha = 1, -+ ElementAccumulator beta = 0) { -+ -+ cutlass::HostTensor workspace({ -+ problem_size.row() * partitions, -+ problem_size.column() -+ }); -+ -+ cutlass::HostTensor source(problem_size); -+ cutlass::HostTensor destination(problem_size); -+ cutlass::HostTensor destination_reference(problem_size, false); -+ -+ // -+ // Initialize -+ // -+ initialize_tensor(workspace.host_view(), distribution_workspace, seed); -+ initialize_tensor(source.host_view(), distribution_source, seed + 23); -+ -+ cutlass::reference::host::TensorFill(destination.host_view()); -+ -+ workspace.sync_device(); -+ source.sync_device(); -+ destination.sync_device(); -+ -+ // -+ // Launch reduction kernel -+ // -+ -+ dim3 block = ReductionKernel::block_shape(); -+ dim3 grid = ReductionKernel::grid_shape(problem_size); -+ -+ typename ReductionKernel::Params params( -+ problem_size, -+ partitions, -+ problem_size.row() * problem_size.column(), -+ workspace.device_ref(), -+ destination.device_ref(), -+ source.device_ref(), -+ {alpha, beta} -+ ); -+ -+ test::reduction::kernel_reduce_splitk<<< grid, block >>>(params); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ -+ EXPECT_EQ(result, cudaSuccess) -+ << "CUDA error: " << cudaGetErrorString(result); -+ -+ destination.sync_host(); -+ -+ // -+ // Compute reference -+ // -+ -+ for (int m = 0; m < problem_size.row(); ++m) { -+ for (int n = 0; n < problem_size.column(); ++n) { -+ -+ ElementAccumulator accum = 0; -+ -+ for (int k = 0; k < partitions; ++k) { -+ accum += ElementAccumulator(workspace.at({m + k * problem_size.row(), n})); -+ } -+ -+ ElementAccumulator c = ElementAccumulator(source.at({m, n})); -+ -+ destination_reference.at({m, n}) = ElementOutput(accum * alpha + beta * c); -+ } -+ } -+ -+ // -+ // Compare -+ // -+ -+ EXPECT_GT(cutlass::reference::host::TensorNorm(destination.host_view()), 0); -+ EXPECT_GT(cutlass::reference::host::TensorNorm(destination_reference.host_view()), 0); -+ -+ bool passed = cutlass::reference::host::TensorEquals( -+ destination.host_view(), destination_reference.host_view()); -+ -+ EXPECT_TRUE(passed) -+ << "Workspace =\n" << workspace.host_view() << "\n\n" -+ << "\n" -+ << "Reference =\n" << destination_reference.host_view() << "\n\n" -+ << "Computed =\n" << destination.host_view() << "\n"; -+ -+ return passed; -+ } -+ -+ /// Runs through a variety of test cases -+ bool run_all() { -+ -+ cutlass::MatrixCoord problem_sizes[] = { -+ {8, 8}, -+ {136, 72}, -+ {248, 232}, -+ }; -+ -+ int partition_counts[] = { -+ 1,3,4,5,11 -+ }; -+ -+ bool passed = false; -+ -+ for (cutlass::MatrixCoord problem : problem_sizes) { -+ for (int partitions : partition_counts) { -+ passed = run(problem, partitions); -+ if (!passed) { -+ return false; -+ } -+ } -+ } -+ -+ return passed; -+ } -+}; -+ -+} // namespace reduction -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Strictly F32 data -+// -+TEST(Reduction_ReduceSplitK, f32_f32_f32_1_1x32) { -+ -+ using ElementWorkspace = float; -+ using ElementAccumulator = float; -+ using ElementOutput = float; -+ int const kN = 1; -+ using Shape = cutlass::MatrixShape<1, 32>; -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kN, -+ ElementAccumulator, -+ ElementAccumulator -+ >; -+ -+ using ReductionOp = cutlass::reduction::thread::ReduceAdd< -+ ElementAccumulator, -+ ElementWorkspace, -+ kN -+ >; -+ -+ using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< -+ Shape, -+ OutputOp, -+ ReductionOp -+ >; -+ -+ test::reduction::ReduceSplitKTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Vectorized access -+// -+TEST(Reduction_ReduceSplitK, f32_f32_f32_2_4x64) { -+ -+ using ElementWorkspace = float; -+ using ElementAccumulator = float; -+ using ElementOutput = float; -+ int const kN = 2; -+ using Shape = cutlass::MatrixShape<4, 64>; -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kN, -+ ElementAccumulator, -+ ElementAccumulator -+ >; -+ -+ using ReductionOp = cutlass::reduction::thread::ReduceAdd< -+ ElementAccumulator, -+ ElementWorkspace, -+ kN -+ >; -+ -+ using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< -+ Shape, -+ OutputOp, -+ ReductionOp -+ >; -+ -+ test::reduction::ReduceSplitKTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Vectorized access -+// -+TEST(Reduction_ReduceSplitK, f32_f32_f16_2_4x64) { -+ -+ using ElementWorkspace = float; -+ using ElementAccumulator = float; -+ using ElementOutput = cutlass::half_t; -+ int const kN = 2; -+ using Shape = cutlass::MatrixShape<4, 64>; -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kN, -+ ElementAccumulator, -+ ElementAccumulator -+ >; -+ -+ using ReductionOp = cutlass::reduction::thread::ReduceAdd< -+ ElementAccumulator, -+ ElementWorkspace, -+ kN -+ >; -+ -+ using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< -+ Shape, -+ OutputOp, -+ ReductionOp -+ >; -+ -+ test::reduction::ReduceSplitKTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Vectorized access -+// -+TEST(Reduction_ReduceSplitK, f32_f32_f16_8_4x64) { -+ -+ using ElementWorkspace = float; -+ using ElementAccumulator = float; -+ using ElementOutput = cutlass::half_t; -+ int const kN = 8; -+ using Shape = cutlass::MatrixShape<4, 64>; -+ -+ using OutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ kN, -+ ElementAccumulator, -+ ElementAccumulator -+ >; -+ -+ using ReductionOp = cutlass::reduction::thread::ReduceAdd< -+ ElementAccumulator, -+ ElementWorkspace, -+ kN -+ >; -+ -+ using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< -+ Shape, -+ OutputOp, -+ ReductionOp -+ >; -+ -+ test::reduction::ReduceSplitKTestbed testbed; -+ -+ EXPECT_TRUE(testbed.run_all()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/reduction/kernel/reduce_splitk_testbed.h b/3rdparty/cutlass/test/unit/reduction/kernel/reduce_splitk_testbed.h -new file mode 100644 -index 0000000..78c720a ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/reduction/kernel/reduce_splitk_testbed.h -@@ -0,0 +1,45 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level Reduction -+*/ -+ -+#pragma once -+ -+#include "cutlass/reduction/thread/reduce.h" -+ -+#include "cutlass/layout/vector.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -diff --git a/3rdparty/cutlass/test/unit/reduction/thread/reduction_thread.cu b/3rdparty/cutlass/test/unit/reduction/thread/reduction_thread.cu -new file mode 100644 -index 0000000..be92fea ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/reduction/thread/reduction_thread.cu -@@ -0,0 +1,100 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level Reduction -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "testbed.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+ -+TEST(Reduce_thread_device, Reduce_half_t_1) { -+ -+ test::reduction::thread::Testbed_reduce_device< -+ cutlass::half_t, -+ 1 -+ >().run(); -+} -+ -+TEST(Reduce_thread_device, Reduce_half_t_16) { -+ -+ test::reduction::thread::Testbed_reduce_device< -+ cutlass::half_t, -+ 16 -+ >().run(); -+} -+ -+TEST(Reduce_thread_device, Reduce_half_t_31) { -+ -+ test::reduction::thread::Testbed_reduce_device< -+ cutlass::half_t, -+ 31 -+ >().run(); -+} -+ -+ -+TEST(Reduce_thread_host, Reduce_float_1) { -+ -+ test::reduction::thread::Testbed_reduce_host< -+ float, -+ 1 -+ >().run(); -+} -+ -+TEST(Reduce_thread_host, Reduce_float_16) { -+ -+ test::reduction::thread::Testbed_reduce_host< -+ float, -+ 16 -+ >().run(); -+ -+} -+ -+TEST(Reduce_thread_host, Reduce_half_t_1) { -+ -+ test::reduction::thread::Testbed_reduce_host< -+ cutlass::half_t, -+ 1 -+ >().run(); -+} -+ -+TEST(Reduce_thread_host, Reduce_half_t_16) { -+ -+ test::reduction::thread::Testbed_reduce_host< -+ cutlass::half_t, -+ 16 -+ >().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/reduction/thread/testbed.h b/3rdparty/cutlass/test/unit/reduction/thread/testbed.h -new file mode 100644 -index 0000000..e0e38ed ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/reduction/thread/testbed.h -@@ -0,0 +1,242 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Unit tests for thread-level Reduction -+*/ -+ -+#pragma once -+ -+#include "cutlass/reduction/thread/reduce.h" -+ -+#include "cutlass/layout/vector.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/reference/host/tensor_copy.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+ -+namespace test { -+namespace reduction { -+namespace thread { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure to compute the reduction -+template < -+ /// Data type of elements -+ typename Element, -+ /// Number of elements -+ int N -+> -+struct Testbed_reduce_host { -+ -+ /// Thread-level reduction operator -+ using Reduce = cutlass::reduction::thread::Reduce< -+ cutlass::plus, -+ cutlass::Array -+ >; -+ -+ // -+ // Data members -+ // -+ -+ cutlass::Array tensor_in; -+ cutlass::Array reduced_tensor_computed; -+ cutlass::Array reduced_tensor_reference; -+ -+ // -+ // Methods -+ // -+ -+ /// Allocates workspace in device memory -+ Testbed_reduce_host() { -+ tensor_in.clear(); -+ reduced_tensor_computed.clear(); -+ reduced_tensor_reference.clear(); -+ } -+ -+ /// Runs the test -+ bool run() { -+ -+ // -+ // initialize memory -+ // -+ -+ for(int i = 0; i < N; i++) -+ tensor_in.at(i) = Element(i); -+ -+ -+ Reduce reduce; -+ -+ cutlass::Array *out_ptr = &reduced_tensor_computed; -+ out_ptr[0] = reduce(tensor_in); -+ -+ // -+ // Reference implementation -+ // -+ Element e(0); -+ for (int i = 0; i < N; i++) -+ e = e + Element(i); -+ -+ reduced_tensor_reference.at(0) = e; -+ -+ // -+ // Verify equivalence -+ // -+ -+ // compare -+ bool passed = reduced_tensor_reference[0] == reduced_tensor_computed[0]; -+ -+ EXPECT_TRUE(passed) -+ << "Expected = " << float(reduced_tensor_reference.at(0)) << "\n\n" -+ << "Actual = " << float(reduced_tensor_computed.at(0)) << "\n\n" -+ << std::endl; -+ -+ return passed; -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Thread-level reduction kernel -+template -+__global__ void kernel_reduce(Element const *array_in, Element *result) { -+ -+ /// Thread-level reduction operator -+ using Reduce = cutlass::reduction::thread::Reduce< -+ cutlass::plus, -+ cutlass::Array -+ >; -+ -+ Reduce reduce; -+ -+ auto ptr_in = reinterpret_cast const *>(array_in); -+ auto result_ptr = reinterpret_cast *>(result); -+ auto in = *ptr_in; -+ result_ptr[0] = reduce(in); -+} -+ -+ -+/// Structure to compute the reduction -+template < -+ /// Data type of elements -+ typename Element, -+ /// Number of elements -+ int N -+> -+struct Testbed_reduce_device { -+ -+ using Layout = cutlass::layout::PackedVectorLayout; -+ -+ // -+ // Data members -+ // -+ -+ cutlass::HostTensor tensor_in; -+ cutlass::HostTensor reduced_tensor_computed; -+ cutlass::HostTensor reduced_tensor_reference; -+ -+ // -+ // Methods -+ // -+ -+ /// Allocates workspace in device memory -+ Testbed_reduce_device() { -+ -+ tensor_in.reset(cutlass::make_Coord(N), true); -+ reduced_tensor_computed.reset(cutlass::make_Coord(1), true); -+ reduced_tensor_reference.reset(cutlass::make_Coord(1), true); -+ } -+ -+ -+ /// Runs the test -+ bool run() { -+ -+ // -+ // initialize memory -+ // -+ -+ cutlass::reference::host::TensorFill( -+ tensor_in.host_view(), -+ Element(1) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ reduced_tensor_computed.host_view(), -+ Element(0) -+ ); -+ -+ cutlass::reference::host::TensorFill( -+ reduced_tensor_reference.host_view(), -+ Element(N) -+ ); -+ -+ tensor_in.sync_device(); -+ reduced_tensor_computed.sync_device(); -+ reduced_tensor_reference.sync_device(); -+ -+ /// call the kernel -+ kernel_reduce<<< dim3(1, 1), dim3(1, 1, 1) >>> ( -+ tensor_in.device_data(), -+ reduced_tensor_computed.device_data() -+ ); -+ -+ // verify no errors -+ cudaError_t result = cudaDeviceSynchronize(); -+ -+ EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); -+ if (result != cudaSuccess) { -+ return false; -+ } -+ -+ // Copy back results -+ reduced_tensor_computed.sync_host(); -+ -+ // Verify equivalence -+ bool passed = cutlass::reference::host::TensorEquals( -+ reduced_tensor_computed.host_view(), -+ reduced_tensor_reference.host_view() -+ ); -+ -+ EXPECT_TRUE(passed) -+ << "Expected = " << reduced_tensor_reference.host_view() << "\n\n" -+ << "Actual = " << reduced_tensor_computed.host_view() << "\n\n" -+ << std::endl; -+ -+ return passed; -+ } -+}; -+ -+} // namespace thread -+} // namespace reduction -+} // namespace test -diff --git a/3rdparty/cutlass/test/unit/transform/threadblock/predicated_tile_iterator.cu b/3rdparty/cutlass/test/unit/transform/threadblock/predicated_tile_iterator.cu -new file mode 100644 -index 0000000..e30986b ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/transform/threadblock/predicated_tile_iterator.cu -@@ -0,0 +1,798 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Tests cutlass::transform::threadblock::PredicatedTileIterator -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -+#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" -+ -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace transform { -+namespace threadblock { -+namespace kernel { -+ -+/// Copy with an iterator -+template -+__global__ void copy( -+ typename Iterator::Params dst_params, -+ typename Iterator::Element *dst_pointer, -+ typename Iterator::Params src_params, -+ typename Iterator::Element *src_pointer, -+ cutlass::Coord<2> extent) { -+ -+ Iterator dst_iterator(dst_params, dst_pointer, extent, threadIdx.x); -+ Iterator src_iterator(src_params, src_pointer, extent, threadIdx.x); -+ -+ int iterations = (extent[1] + Iterator::Shape::kStrided - 1) / Iterator::Shape::kStrided; -+ -+ typename Iterator::Fragment frag; -+ -+ for(int i = 0; i < frag.size(); i++) -+ frag[i] = 0; -+ -+ src_iterator.load(frag); -+ dst_iterator.store(frag); -+ -+ ++dst_iterator; -+ ++src_iterator; -+ -+ for (; iterations > 1; --iterations) { -+ -+ src_iterator.load(frag); -+ dst_iterator.store(frag); -+ -+ ++dst_iterator; -+ ++src_iterator; -+ } -+} -+ -+} // namespace kernel -+} // namespace threadblock -+} // namespace transform -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Transform_threadblock_PredicatedTileIterator, PitchLinear_Stripmined) { -+ -+ using Shape = cutlass::layout::PitchLinearShape<64, 4>; -+ using Layout = cutlass::layout::PitchLinear; -+ using Element = int; -+ static int const kThreads = 32; -+ -+ using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap; -+ -+ using Iterator = cutlass::transform::threadblock::PredicatedTileIterator< -+ Shape, Element, Layout, 1, ThreadMap -+ >; -+ -+ cutlass::Coord<2> copy_extent = cutlass::make_Coord(57, 35); -+ cutlass::Coord<2> alloc_extent = cutlass::make_Coord(64, 35); -+ -+ cutlass::HostTensor src_tensor(alloc_extent); -+ cutlass::HostTensor dst_tensor(alloc_extent); -+ -+ Element oob_value = Element(-1); -+ cutlass::reference::host::TensorFill(dst_tensor.host_view(), oob_value); -+ cutlass::reference::host::BlockFillSequential(src_tensor.host_data(), src_tensor.capacity()); -+ -+ dst_tensor.sync_device(); -+ src_tensor.sync_device(); -+ -+ typename Iterator::Params dst_params(dst_tensor.layout()); -+ typename Iterator::Params src_params(src_tensor.layout()); -+ -+ dim3 block(kThreads, 1); -+ dim3 grid(1, 1); -+ -+ test::transform::threadblock::kernel::copy<<< grid, block >>>( -+ dst_params, -+ dst_tensor.device_data(), -+ src_params, -+ src_tensor.device_data(), -+ copy_extent -+ ); -+ -+ cudaError_t result = cudaGetLastError(); -+ EXPECT_EQ(result, cudaSuccess) << " - CUDA error: " << cudaGetErrorString(result); -+ -+ dst_tensor.sync_host(); -+ -+ for (int s = 0; s < alloc_extent[1]; ++s) { -+ for (int c = 0; c < alloc_extent[0]; ++c) { -+ -+ Element expected = Element(0); -+ -+ if (c < copy_extent[0] && s < copy_extent[1]) { -+ expected = src_tensor.at({c, s}); -+ } -+ else { -+ expected = oob_value; -+ } -+ -+ Element got = dst_tensor.at({c, s}); -+ bool equal = (expected == got); -+ -+ EXPECT_EQ(expected, got) -+ << "Source:\n" << src_tensor.host_view() << "\n\n" -+ << "Destination:\n" << dst_tensor.host_view() << "\n"; -+ -+ if (!equal) { -+ return; -+ } -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+TEST(Transform_threadblock_PredicatedTileIterator, PitchLinear_Stripmined_2dtile_128x4) { -+ -+ using Shape = cutlass::layout::PitchLinearShape<128, 4>; -+ using ThreadTileShape = cutlass::layout::PitchLinearShape<4, 4>; -+ using Layout = cutlass::layout::PitchLinear; -+ using Element = int8_t; -+ static int const kThreads = 32; -+ -+ using ThreadMap = cutlass::transform::PitchLinear2DThreadTileStripminedThreadMap; -+ -+ using Iterator = cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< -+ Shape, Element, Layout, 1, ThreadMap, false -+ >; -+ -+ cutlass::Coord<2> copy_extent = cutlass::make_Coord(128, 4); -+ cutlass::Coord<2> alloc_extent = cutlass::make_Coord(128, 4); -+ -+ cutlass::HostTensor src_tensor(alloc_extent); -+ cutlass::HostTensor dst_tensor(alloc_extent); -+ -+ Element oob_value = Element(-1); -+ cutlass::reference::host::TensorFill(dst_tensor.host_view(), oob_value); -+ cutlass::reference::host::BlockFillSequential(src_tensor.host_data(), src_tensor.capacity()); -+ -+ dst_tensor.sync_device(); -+ src_tensor.sync_device(); -+ -+ typename Iterator::Params dst_params(dst_tensor.layout()); -+ typename Iterator::Params src_params(src_tensor.layout()); -+ -+ dim3 block(kThreads, 1); -+ dim3 grid(1, 1); -+ -+ test::transform::threadblock::kernel::copy<<< grid, block >>>( -+ dst_params, -+ dst_tensor.device_data(), -+ src_params, -+ src_tensor.device_data(), -+ copy_extent -+ ); -+ -+ cudaError_t result = cudaGetLastError(); -+ EXPECT_EQ(result, cudaSuccess) << " - CUDA error: " << cudaGetErrorString(result); -+ -+ dst_tensor.sync_host(); -+ -+ for (int s = 0; s < alloc_extent[1]; ++s) { -+ for (int c = 0; c < alloc_extent[0]; ++c) { -+ -+ Element expected = Element(0); -+ -+ if (c < copy_extent[0] && s < copy_extent[1]) { -+ expected = src_tensor.at({c, s}); -+ } -+ else { -+ expected = oob_value; -+ } -+ -+ Element got = dst_tensor.at({c, s}); -+ bool equal = (expected == got); -+ -+ EXPECT_EQ(expected, got) -+ << "Source:\n" << src_tensor.host_view() << "\n\n" -+ << "Destination:\n" << dst_tensor.host_view() << "\n"; -+ -+ if (!equal) { -+ return; -+ } -+ } -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Transform_threadblock_PredicatedTileIterator, PitchLinear_Stripmined_2dtile_128x64) { -+ -+ using Shape = cutlass::layout::PitchLinearShape<128, 64>; -+ using ThreadTileShape = cutlass::layout::PitchLinearShape<4, 4>; -+ using Layout = cutlass::layout::PitchLinear; -+ using Element = int8_t; -+ static int const kThreads = 32; -+ -+ using ThreadMap = cutlass::transform::PitchLinear2DThreadTileStripminedThreadMap; -+ -+ using Iterator = cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< -+ Shape, Element, Layout, 1, ThreadMap -+ >; -+ -+ cutlass::Coord<2> copy_extent = cutlass::make_Coord(128, 64); -+ cutlass::Coord<2> alloc_extent = cutlass::make_Coord(128, 64); -+ -+ cutlass::HostTensor src_tensor(alloc_extent); -+ cutlass::HostTensor dst_tensor(alloc_extent); -+ -+ Element oob_value = Element(-1); -+ cutlass::reference::host::TensorFill(dst_tensor.host_view(), oob_value); -+ cutlass::reference::host::BlockFillSequential(src_tensor.host_data(), src_tensor.capacity()); -+ -+ dst_tensor.sync_device(); -+ src_tensor.sync_device(); -+ -+ typename Iterator::Params dst_params(dst_tensor.layout()); -+ typename Iterator::Params src_params(src_tensor.layout()); -+ -+ dim3 block(kThreads, 1); -+ dim3 grid(1, 1); -+ -+ test::transform::threadblock::kernel::copy<<< grid, block >>>( -+ dst_params, -+ dst_tensor.device_data(), -+ src_params, -+ src_tensor.device_data(), -+ copy_extent -+ ); -+ -+ cudaError_t result = cudaGetLastError(); -+ EXPECT_EQ(result, cudaSuccess) << " - CUDA error: " << cudaGetErrorString(result); -+ -+ dst_tensor.sync_host(); -+ -+ for (int s = 0; s < alloc_extent[1]; ++s) { -+ for (int c = 0; c < alloc_extent[0]; ++c) { -+ -+ Element expected = Element(0); -+ -+ if (c < copy_extent[0] && s < copy_extent[1]) { -+ expected = src_tensor.at({c, s}); -+ } -+ else { -+ expected = oob_value; -+ } -+ -+ Element got = dst_tensor.at({c, s}); -+ bool equal = (expected == got); -+ -+ EXPECT_EQ(expected, got) -+ << "Source:\n" << src_tensor.host_view() << "\n\n" -+ << "Destination:\n" << dst_tensor.host_view() << "\n"; -+ -+ if (!equal) { -+ return; -+ } -+ } -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Transform_threadblock_PredicatedTileIterator, PitchLinear_Stripmined_2dtile_64x64) { -+ -+ using Shape = cutlass::layout::PitchLinearShape<64, 64>; -+ using ThreadTileShape = cutlass::layout::PitchLinearShape<4, 4>; -+ using Layout = cutlass::layout::PitchLinear; -+ using Element = int8_t; -+ static int const kThreads = 32; -+ -+ using ThreadMap = cutlass::transform::PitchLinear2DThreadTileStripminedThreadMap; -+ -+ using Iterator = cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< -+ Shape, Element, Layout, 1, ThreadMap -+ >; -+ -+ cutlass::Coord<2> copy_extent = cutlass::make_Coord(64, 64); -+ cutlass::Coord<2> alloc_extent = cutlass::make_Coord(64, 64); -+ -+ cutlass::HostTensor src_tensor(alloc_extent); -+ cutlass::HostTensor dst_tensor(alloc_extent); -+ -+ Element oob_value = Element(-1); -+ cutlass::reference::host::TensorFill(dst_tensor.host_view(), oob_value); -+ cutlass::reference::host::BlockFillSequential(src_tensor.host_data(), src_tensor.capacity()); -+ -+ dst_tensor.sync_device(); -+ src_tensor.sync_device(); -+ -+ typename Iterator::Params dst_params(dst_tensor.layout()); -+ typename Iterator::Params src_params(src_tensor.layout()); -+ -+ dim3 block(kThreads, 1); -+ dim3 grid(1, 1); -+ -+ test::transform::threadblock::kernel::copy<<< grid, block >>>( -+ dst_params, -+ dst_tensor.device_data(), -+ src_params, -+ src_tensor.device_data(), -+ copy_extent -+ ); -+ -+ cudaError_t result = cudaGetLastError(); -+ EXPECT_EQ(result, cudaSuccess) << " - CUDA error: " << cudaGetErrorString(result); -+ -+ dst_tensor.sync_host(); -+ -+ for (int s = 0; s < alloc_extent[1]; ++s) { -+ for (int c = 0; c < alloc_extent[0]; ++c) { -+ -+ Element expected = Element(0); -+ -+ if (c < copy_extent[0] && s < copy_extent[1]) { -+ expected = src_tensor.at({c, s}); -+ } -+ else { -+ expected = oob_value; -+ } -+ -+ Element got = dst_tensor.at({c, s}); -+ bool equal = (expected == got); -+ -+ EXPECT_EQ(expected, got) -+ << "Source:\n" << src_tensor.host_view() << "\n\n" -+ << "Destination:\n" << dst_tensor.host_view() << "\n"; -+ -+ if (!equal) { -+ return; -+ } -+ } -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Transform_threadblock_PredicatedTileIterator, PitchLinear_Stripmined_2dtile_64x8) { -+ -+ using Shape = cutlass::layout::PitchLinearShape<64, 8>; -+ using ThreadTileShape = cutlass::layout::PitchLinearShape<4, 4>; -+ using Layout = cutlass::layout::PitchLinear; -+ using Element = int8_t; -+ static int const kThreads = 32; -+ -+ using ThreadMap = cutlass::transform::PitchLinear2DThreadTileStripminedThreadMap; -+ -+ using Iterator = cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< -+ Shape, Element, Layout, 1, ThreadMap -+ >; -+ -+ cutlass::Coord<2> copy_extent = cutlass::make_Coord(32, 8); -+ cutlass::Coord<2> alloc_extent = cutlass::make_Coord(64, 8); -+ -+ cutlass::HostTensor src_tensor(alloc_extent); -+ cutlass::HostTensor dst_tensor(alloc_extent); -+ -+ Element oob_value = Element(-1); -+ cutlass::reference::host::TensorFill(dst_tensor.host_view(), oob_value); -+ cutlass::reference::host::BlockFillSequential(src_tensor.host_data(), src_tensor.capacity()); -+ -+ dst_tensor.sync_device(); -+ src_tensor.sync_device(); -+ -+ typename Iterator::Params dst_params(dst_tensor.layout()); -+ typename Iterator::Params src_params(src_tensor.layout()); -+ -+ dim3 block(kThreads, 1); -+ dim3 grid(1, 1); -+ -+ test::transform::threadblock::kernel::copy<<< grid, block >>>( -+ dst_params, -+ dst_tensor.device_data(), -+ src_params, -+ src_tensor.device_data(), -+ copy_extent -+ ); -+ -+ cudaError_t result = cudaGetLastError(); -+ EXPECT_EQ(result, cudaSuccess) << " - CUDA error: " << cudaGetErrorString(result); -+ -+ dst_tensor.sync_host(); -+ -+ for (int s = 0; s < alloc_extent[1]; ++s) { -+ for (int c = 0; c < alloc_extent[0]; ++c) { -+ -+ Element expected = Element(0); -+ -+ if (c < copy_extent[0] && s < copy_extent[1]) { -+ expected = src_tensor.at({c, s}); -+ } -+ else { -+ expected = oob_value; -+ } -+ -+ Element got = dst_tensor.at({c, s}); -+ bool equal = (expected == got); -+ -+ EXPECT_EQ(expected, got) -+ << "Source:\n" << src_tensor.host_view() << "\n\n" -+ << "Destination:\n" << dst_tensor.host_view() << "\n"; -+ -+ if (!equal) { -+ return; -+ } -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Transform_threadblock_PredicatedTileIterator, PitchLinear_Stripmined_2dtile_64x32_transpose4x4) { -+ -+ using Shape = cutlass::layout::PitchLinearShape<64, 8>; -+ using ThreadTileShape = cutlass::layout::PitchLinearShape<4, 4>; -+ using Layout = cutlass::layout::PitchLinear; -+ using Element = int8_t; -+ static int const kThreads = 32; -+ -+ using ThreadMap = cutlass::transform::PitchLinear2DThreadTileStripminedThreadMap; -+ -+ using Iterator = cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< -+ Shape, Element, Layout, 1, ThreadMap, true -+ >; -+ -+ cutlass::Coord<2> copy_extent = cutlass::make_Coord(64, 32); -+ cutlass::Coord<2> alloc_extent = cutlass::make_Coord(64, 32); -+ -+ cutlass::HostTensor src_tensor(alloc_extent); -+ cutlass::HostTensor dst_tensor(alloc_extent); -+ -+ Element oob_value = Element(-1); -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFill(dst_tensor.host_view(), oob_value); -+ cutlass::reference::host::TensorFillRandomUniform(src_tensor.host_view(), seed, 8, -8, 0); -+ -+ dst_tensor.sync_device(); -+ src_tensor.sync_device(); -+ -+ typename Iterator::Params dst_params(dst_tensor.layout()); -+ typename Iterator::Params src_params(src_tensor.layout()); -+ -+ dim3 block(kThreads, 1); -+ dim3 grid(1, 1); -+ -+ test::transform::threadblock::kernel::copy<<< grid, block >>>( -+ dst_params, -+ dst_tensor.device_data(), -+ src_params, -+ src_tensor.device_data(), -+ copy_extent -+ ); -+ -+ cudaError_t result = cudaGetLastError(); -+ EXPECT_EQ(result, cudaSuccess) << " - CUDA error: " << cudaGetErrorString(result); -+ -+ dst_tensor.sync_host(); -+ -+ for (int s = 0; s < alloc_extent[1]/4; ++s) { -+ for (int c = 0; c < alloc_extent[0]/4; ++c) { -+ for (int s1 = 0; s1 < 4; s1++){ -+ for(int c1 = 0; c1 < 4; c1++){ -+ Element expected = Element(0); -+ -+ int l_c = c * 4 + c1; -+ int l_s = s * 4 + s1; -+ -+ int l_tc = c * 4 + s1; -+ int l_ts = s * 4 + c1; -+ -+ if (l_c < copy_extent[0] && l_s < copy_extent[1]) { -+ expected = src_tensor.at({l_c, l_s}); -+ } -+ else { -+ expected = oob_value; -+ } -+ -+ Element got = dst_tensor.at({l_tc, l_ts}); -+ bool equal = (expected == got); -+ -+ EXPECT_EQ(expected, got) -+ << "Source:\n" << src_tensor.host_view() << "\n\n" -+ << "Destination:\n" << dst_tensor.host_view() << "\n"; -+ -+ if (!equal) { -+ return; -+ } -+ } -+ } -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Transform_threadblock_PredicatedTileIterator, PitchLinear_Stripmined_2dtile_64x29_transpose4x4) { -+ -+ using Shape = cutlass::layout::PitchLinearShape<64, 8>; -+ using ThreadTileShape = cutlass::layout::PitchLinearShape<4, 4>; -+ using Layout = cutlass::layout::PitchLinear; -+ using Element = int8_t; -+ static int const kThreads = 32; -+ -+ using ThreadMap = cutlass::transform::PitchLinear2DThreadTileStripminedThreadMap; -+ -+ using Iterator = cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< -+ Shape, Element, Layout, 1, ThreadMap, true -+ >; -+ -+ cutlass::Coord<2> copy_extent = cutlass::make_Coord(64, 29); -+ cutlass::Coord<2> alloc_extent = cutlass::make_Coord(64, 29); -+ -+ cutlass::HostTensor src_tensor(alloc_extent); -+ cutlass::HostTensor dst_tensor(alloc_extent); -+ -+ Element oob_value = Element(-1); -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFill(dst_tensor.host_view(), oob_value); -+ cutlass::reference::host::TensorFillRandomUniform(src_tensor.host_view(), seed, 8, -8, 0); -+ -+ dst_tensor.sync_device(); -+ src_tensor.sync_device(); -+ -+ typename Iterator::Params dst_params(dst_tensor.layout()); -+ typename Iterator::Params src_params(src_tensor.layout()); -+ -+ dim3 block(kThreads, 1); -+ dim3 grid(1, 1); -+ -+ test::transform::threadblock::kernel::copy<<< grid, block >>>( -+ dst_params, -+ dst_tensor.device_data(), -+ src_params, -+ src_tensor.device_data(), -+ copy_extent -+ ); -+ -+ cudaError_t result = cudaGetLastError(); -+ EXPECT_EQ(result, cudaSuccess) << " - CUDA error: " << cudaGetErrorString(result); -+ -+ dst_tensor.sync_host(); -+ -+ for (int s = 0; s < alloc_extent[1]/4; ++s) { -+ for (int c = 0; c < alloc_extent[0]/4; ++c) { -+ for (int s1 = 0; s1 < 4; s1++){ -+ for(int c1 = 0; c1 < 4; c1++){ -+ Element expected = Element(0); -+ -+ int l_c = c * 4 + c1; -+ int l_s = s * 4 + s1; -+ -+ int l_tc = c * 4 + s1; -+ int l_ts = s * 4 + c1; -+ -+ if (l_c < copy_extent[0] && l_s < copy_extent[1]) { -+ expected = src_tensor.at({l_c, l_s}); -+ } -+ else { -+ expected = oob_value; -+ } -+ -+ Element got = dst_tensor.at({l_tc, l_ts}); -+ bool equal = (expected == got); -+ -+ EXPECT_EQ(expected, got) -+ << "Source:\n" << src_tensor.host_view() << "\n\n" -+ << "Destination:\n" << dst_tensor.host_view() << "\n"; -+ -+ if (!equal) { -+ return; -+ } -+ } -+ } -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Transform_threadblock_PredicatedTileIterator, PitchLinear_Stripmined_2dtile_120x4_transpose4x4) { -+ -+ using Shape = cutlass::layout::PitchLinearShape<128, 4>; -+ using ThreadTileShape = cutlass::layout::PitchLinearShape<4, 4>; -+ using Layout = cutlass::layout::PitchLinear; -+ using Element = int8_t; -+ static int const kThreads = 32; -+ -+ using ThreadMap = cutlass::transform::PitchLinear2DThreadTileStripminedThreadMap; -+ -+ using Iterator = cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< -+ Shape, Element, Layout, 1, ThreadMap, true -+ >; -+ -+ cutlass::Coord<2> copy_extent = cutlass::make_Coord(120, 4); -+ cutlass::Coord<2> alloc_extent = cutlass::make_Coord(120, 4); -+ -+ cutlass::HostTensor src_tensor(alloc_extent); -+ cutlass::HostTensor dst_tensor(alloc_extent); -+ -+ Element oob_value = Element(-1); -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFill(dst_tensor.host_view(), oob_value); -+ cutlass::reference::host::TensorFillRandomUniform(src_tensor.host_view(), seed, 8, -8, 0); -+ -+ dst_tensor.sync_device(); -+ src_tensor.sync_device(); -+ -+ typename Iterator::Params dst_params(dst_tensor.layout()); -+ typename Iterator::Params src_params(src_tensor.layout()); -+ -+ dim3 block(kThreads, 1); -+ dim3 grid(1, 1); -+ -+ test::transform::threadblock::kernel::copy<<< grid, block >>>( -+ dst_params, -+ dst_tensor.device_data(), -+ src_params, -+ src_tensor.device_data(), -+ copy_extent -+ ); -+ -+ cudaError_t result = cudaGetLastError(); -+ EXPECT_EQ(result, cudaSuccess) << " - CUDA error: " << cudaGetErrorString(result); -+ -+ dst_tensor.sync_host(); -+ -+ for (int s = 0; s < alloc_extent[1]/4; ++s) { -+ for (int c = 0; c < alloc_extent[0]/4; ++c) { -+ for (int s1 = 0; s1 < 4; s1++){ -+ for(int c1 = 0; c1 < 4; c1++){ -+ Element expected = Element(0); -+ -+ int l_c = c * 4 + c1; -+ int l_s = s * 4 + s1; -+ -+ int l_tc = c * 4 + s1; -+ int l_ts = s * 4 + c1; -+ -+ if (l_c < copy_extent[0] && l_s < copy_extent[1]) { -+ expected = src_tensor.at({l_c, l_s}); -+ } -+ else { -+ expected = oob_value; -+ } -+ -+ Element got = dst_tensor.at({l_tc, l_ts}); -+ bool equal = (expected == got); -+ -+ EXPECT_EQ(expected, got) -+ << "Source:\n" << src_tensor.host_view() << "\n\n" -+ << "Destination:\n" << dst_tensor.host_view() << "\n"; -+ -+ if (!equal) { -+ return; -+ } -+ } -+ } -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(Transform_threadblock_PredicatedTileIterator, PitchLinear_Stripmined_2dtile_48x29_transpose4x4) { -+ -+ using Shape = cutlass::layout::PitchLinearShape<64, 8>; -+ using ThreadTileShape = cutlass::layout::PitchLinearShape<4, 4>; -+ using Layout = cutlass::layout::PitchLinear; -+ using Element = int8_t; -+ static int const kThreads = 32; -+ -+ using ThreadMap = cutlass::transform::PitchLinear2DThreadTileStripminedThreadMap; -+ -+ using Iterator = cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< -+ Shape, Element, Layout, 1, ThreadMap, true -+ >; -+ -+ cutlass::Coord<2> copy_extent = cutlass::make_Coord(48, 29); -+ cutlass::Coord<2> alloc_extent = cutlass::make_Coord(48, 29); -+ -+ cutlass::HostTensor src_tensor(alloc_extent); -+ cutlass::HostTensor dst_tensor(alloc_extent); -+ -+ Element oob_value = Element(-1); -+ uint64_t seed = 7; -+ cutlass::reference::host::TensorFill(dst_tensor.host_view(), oob_value); -+ cutlass::reference::host::TensorFillRandomUniform(src_tensor.host_view(), seed, 8, -8, 0); -+ -+ dst_tensor.sync_device(); -+ src_tensor.sync_device(); -+ -+ typename Iterator::Params dst_params(dst_tensor.layout()); -+ typename Iterator::Params src_params(src_tensor.layout()); -+ -+ dim3 block(kThreads, 1); -+ dim3 grid(1, 1); -+ -+ test::transform::threadblock::kernel::copy<<< grid, block >>>( -+ dst_params, -+ dst_tensor.device_data(), -+ src_params, -+ src_tensor.device_data(), -+ copy_extent -+ ); -+ -+ cudaError_t result = cudaGetLastError(); -+ EXPECT_EQ(result, cudaSuccess) << " - CUDA error: " << cudaGetErrorString(result); -+ -+ dst_tensor.sync_host(); -+ -+ for (int s = 0; s < alloc_extent[1]/4; ++s) { -+ for (int c = 0; c < alloc_extent[0]/4; ++c) { -+ for (int s1 = 0; s1 < 4; s1++){ -+ for(int c1 = 0; c1 < 4; c1++){ -+ Element expected = Element(0); -+ -+ int l_c = c * 4 + c1; -+ int l_s = s * 4 + s1; -+ -+ int l_tc = c * 4 + s1; -+ int l_ts = s * 4 + c1; -+ -+ if (l_c < copy_extent[0] && l_s < copy_extent[1]) { -+ expected = src_tensor.at({l_c, l_s}); -+ } -+ else { -+ expected = oob_value; -+ } -+ -+ Element got = dst_tensor.at({l_tc, l_ts}); -+ bool equal = (expected == got); -+ -+ EXPECT_EQ(expected, got) -+ << "Source:\n" << src_tensor.host_view() << "\n\n" -+ << "Destination:\n" << dst_tensor.host_view() << "\n"; -+ if (!equal) { -+ return; -+ } -+ } -+ } -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/transform/threadblock/regular_tile_iterator_tensor_op.cu b/3rdparty/cutlass/test/unit/transform/threadblock/regular_tile_iterator_tensor_op.cu -new file mode 100644 -index 0000000..c5ad3e9 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/transform/threadblock/regular_tile_iterator_tensor_op.cu -@@ -0,0 +1,289 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief -+*/ -+ -+#include "../../common/cutlass_unit_test.h" -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/core_io.h" -+#include "cutlass/layout/pitch_linear.h" -+ -+#include "cutlass/transform/pitch_linear_thread_map.h" -+#include "cutlass/transform/threadblock/regular_tile_iterator_tensor_op.h" -+ -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace test { -+namespace gemm { -+namespace threadblock { -+ -+/// -+template -+__global__ void kernel_gemm_threadblock_tensor_op_multiplicand_store( -+ typename Iterator::TensorRef ref_output, -+ typename Iterator::Element *input) { -+ -+ // Construct fragment -+ typename Iterator::Fragment frag; -+ -+ frag.clear(); -+ -+ // each thread loads a fragment -+ using AccessType = cutlass::Array; -+ -+ int const kElementsPerAccess = Iterator::ThreadMap::kElementsPerAccess; -+ int stride = Iterator::Shape::kContiguous; -+ -+ int warp_id = (threadIdx.x / 32); -+ int lane_id = (threadIdx.x % 32); -+ -+ input += (lane_id % 8) * kElementsPerAccess + (lane_id / 8) * stride; -+ -+ input += (warp_id * Iterator::Shape::kStrided / Iterator::ThreadMap::Detail::kWarpCount) * stride; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int s = 0; s < Iterator::ThreadMap::Iterations::kStrided; ++s) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int c = 0; c < Iterator::ThreadMap::Iterations::kContiguous; ++c) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int v = 0; v < Iterator::ThreadMap::kElementsPerAccess; ++v) { -+ frag[v + Iterator::ThreadMap::kElementsPerAccess * (c + s * Iterator::ThreadMap::Iterations::kContiguous)] = -+ input[v + c * 64 + s * Iterator::ThreadMap::Delta::kStrided * stride]; -+ } -+ } -+ } -+ -+ // Use iterator to store results -+ Iterator iter(ref_output, threadIdx.x); -+ iter.store(frag); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Simple test environment -+template < -+ typename Shape_, -+ int WarpCount -+> -+class MultiplicandTileIteratorTestbed { -+public: -+ -+ // -+ // Define iterator -+ // -+ -+ using Shape = Shape_; -+ using Element = cutlass::half_t; -+ using Layout = cutlass::layout::TensorOpMultiplicandCongruous< -+ cutlass::sizeof_bits::value, 64>; -+ static int const kAdvanceRank = 1; -+ static int const kThreads = 32 * WarpCount; -+ -+ using ThreadMap = cutlass::transform::PitchLinearWarpRakedThreadMap< -+ Shape, -+ kThreads, -+ cutlass::layout::PitchLinearShape<8, 4>, -+ 128 / cutlass::sizeof_bits::value -+ >; -+ -+ using Iterator = cutlass::transform::threadblock::RegularTileIterator< -+ Shape, Element, Layout, kAdvanceRank, ThreadMap -+ >; -+ -+public: -+ -+ // -+ // Members -+ // -+ -+ cutlass::HostTensor destination_tensor; -+ cutlass::HostTensor source_tensor; -+ -+ -+public: -+ -+ MultiplicandTileIteratorTestbed(): -+ destination_tensor({Shape::kContiguous, Shape::kStrided}), -+ source_tensor({Shape::kContiguous, Shape::kStrided}) { -+ -+ } -+ -+ bool run() { -+ -+ cutlass::reference::host::BlockFillSequential( -+ source_tensor.host_data(), -+ source_tensor.capacity() -+ ); -+ -+ cutlass::reference::host::BlockFillSequential( -+ destination_tensor.host_data(), -+ destination_tensor.capacity(), -+ Element(0), -+ Element(0) -+ ); -+ -+ // -+ // Launch kernel -+ // -+ -+ dim3 grid(1,1); -+ dim3 block(kThreads, 1); -+ -+ destination_tensor.sync_device(); -+ source_tensor.sync_device(); -+ -+ test::gemm::threadblock::kernel_gemm_threadblock_tensor_op_multiplicand_store<<< -+ grid, block -+ >>>( -+ destination_tensor.device_ref(), -+ source_tensor.device_data() -+ ); -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ EXPECT_EQ(result, cudaSuccess) << " - CUDA ERROR: " << cudaGetErrorString(result); -+ -+ destination_tensor.sync_host(); -+ -+ // -+ // Verify -+ // -+ -+ // Verify that its contents match the destination -+ int errors = 0; -+ for (int s = 0; s < Shape::kStrided; ++s) { -+ for (int c = 0; c < Shape::kContiguous; ++c) { -+ -+ if (errors >= 10) { -+ break; -+ } -+ -+ Element expected = source_tensor.at({c, s}); -+ Element got = destination_tensor.at({c, s}); -+ -+ bool passed = (expected == got); -+ if (!passed) { -+ ++errors; -+ } -+ } -+ } -+ -+ EXPECT_EQ(errors, 0) -+ << source_tensor.host_view() << "\n\n" << destination_tensor.host_view() << std::endl; -+ -+ return !errors; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace gemm -+} // namespace test -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_tensor_op_multplicand_iterator_congruous_16b, 64x8_w1) { -+ -+ test::gemm::threadblock::MultiplicandTileIteratorTestbed< -+ cutlass::layout::PitchLinearShape<64, 8>, 1>().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_tensor_op_multplicand_iterator_congruous_16b, 64x16_w1) { -+ -+ test::gemm::threadblock::MultiplicandTileIteratorTestbed< -+ cutlass::layout::PitchLinearShape<64, 16>, 1>().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_tensor_op_multplicand_iterator_congruous_16b, 64x16_w2) { -+ -+ test::gemm::threadblock::MultiplicandTileIteratorTestbed< -+ cutlass::layout::PitchLinearShape<64, 16>, 2>().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_tensor_op_multplicand_iterator_congruous_16b, 128x8_w1) { -+ -+ test::gemm::threadblock::MultiplicandTileIteratorTestbed< -+ cutlass::layout::PitchLinearShape<128, 8>, 1>().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_tensor_op_multplicand_iterator_congruous_16b, 64x32_w4) { -+ -+ test::gemm::threadblock::MultiplicandTileIteratorTestbed< -+ cutlass::layout::PitchLinearShape<64, 32>, 4>().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_tensor_op_multplicand_iterator_congruous_16b, 128x32_w1) { -+ -+ test::gemm::threadblock::MultiplicandTileIteratorTestbed< -+ cutlass::layout::PitchLinearShape<128, 32>, 1>().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_tensor_op_multplicand_iterator_congruous_16b, 128x32_w4) { -+ -+ test::gemm::threadblock::MultiplicandTileIteratorTestbed< -+ cutlass::layout::PitchLinearShape<128, 32>, 4>().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_tensor_op_multplicand_iterator_congruous_16b, 256x32_w4) { -+ -+ test::gemm::threadblock::MultiplicandTileIteratorTestbed< -+ cutlass::layout::PitchLinearShape<256, 32>, 4>().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_gemm_threadblock_tensor_op_multplicand_iterator_congruous_16b, 256x32_w8) { -+ -+ test::gemm::threadblock::MultiplicandTileIteratorTestbed< -+ cutlass::layout::PitchLinearShape<256, 32>, 8>().run(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/test/unit/util/cutlass_test_levels.cu b/3rdparty/cutlass/test/unit/util/cutlass_test_levels.cu -new file mode 100644 -index 0000000..3879783 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/util/cutlass_test_levels.cu -@@ -0,0 +1,77 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#include -+ -+#include "../common/cutlass_unit_test.h" -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(SM75_CUTLASS_TEST, level_not_specified) { -+ -+ EXPECT_TRUE(true); -+} -+ -+TEST(SM80_CUTLASS_TEST, level_not_specified) { -+ -+ EXPECT_TRUE(true); -+} -+ -+CUTLASS_TEST_L0(SM75_CUTLASS_TEST, level0, { -+ -+ EXPECT_TRUE(true); -+}) -+ -+CUTLASS_TEST_L1(SM75_CUTLASS_TEST, level1, { -+ -+ EXPECT_TRUE(true); -+}) -+ -+CUTLASS_TEST_L2(SM75_CUTLASS_TEST, level2, { -+ -+ EXPECT_TRUE(true); -+}) -+ -+CUTLASS_TEST_L0(SM80_CUTLASS_TEST, level0, { -+ -+ EXPECT_TRUE(true); -+}) -+ -+CUTLASS_TEST_L1(SM80_CUTLASS_TEST, level1, { -+ -+ EXPECT_TRUE(true); -+}) -+ -+CUTLASS_TEST_L2(SM80_CUTLASS_TEST, level2, { -+ -+ EXPECT_TRUE(true); -+}) -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/test/unit/util/tensor_reduce.cu b/3rdparty/cutlass/test/unit/util/tensor_reduce.cu -new file mode 100644 -index 0000000..c71d080 ---- /dev/null -+++ b/3rdparty/cutlass/test/unit/util/tensor_reduce.cu -@@ -0,0 +1,244 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#include -+ -+#include "../common/cutlass_unit_test.h" -+ -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+ -+#include "cutlass/util/reference/device/tensor_reduce.h" -+#include "cutlass/util/reference/host/tensor_norm.h" -+#include "cutlass/util/host_tensor.h" -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+TEST(TensorReduce, norm_rowmajor_f32) { -+ -+ int const kM = 129; -+ int const kN = 91; -+ -+ cutlass::HostTensor tensor({kM, kN}); -+ -+ for (int m = 0; m < kM; ++m) { -+ for (int n = 0; n < kN; ++n) { -+ -+ float x = float(((m * kN + m + 7) % 8) - 4); -+ -+ tensor.at({m, n}) = x; -+ } -+ } -+ -+ tensor.sync_device(); -+ -+ double device_norm = cutlass::reference::device::TensorNorm(tensor.device_view(), double()); -+ double host_norm = cutlass::reference::host::TensorNorm(tensor.host_view(), double()); -+ -+ EXPECT_TRUE(std::abs(host_norm - device_norm) < 0.001); -+} -+ -+TEST(TensorReduce, norm_nhwc_f32) { -+ -+ int const kN = 19; -+ int const kH = 18; -+ int const kW = 17; -+ int const kC = 16; -+ -+ cutlass::HostTensor tensor({kN, kH, kW, kC}); -+ -+ int idx = 0; -+ -+ double computed_norm = double(); -+ -+ for (int n = 0; n < kN; ++n) { -+ for (int h = 0; h < kH; ++h) { -+ for (int w = 0; w < kW; ++w) { -+ for (int c = 0; c < kC; ++c, ++idx) { -+ -+ float x = float(((idx + 7) % 8) - 4); -+ -+ computed_norm += double(x) * double(x); -+ -+ tensor.at({n, h, w, c}) = x; -+ } -+ } -+ } -+ } -+ -+ computed_norm = std::sqrt(computed_norm); -+ -+ tensor.sync_device(); -+ -+ double device_norm = cutlass::reference::device::TensorNorm(tensor.device_view(), double()); -+ double host_norm = cutlass::reference::host::TensorNorm(tensor.host_view(), double()); -+ -+ EXPECT_TRUE(std::abs(host_norm - device_norm) < 0.001 && std::abs(computed_norm - host_norm) < 0.001) -+ << "computed norm: " << computed_norm << "\n" -+ << " host norm: " << host_norm << "\n" -+ << "device norm: " << device_norm << "\n"; -+} -+ -+TEST(TensorReduce, norm_nhwc_f16) { -+ -+ int const kN = 69; -+ int const kH = 68; -+ int const kW = 67; -+ int const kC = 66; -+ -+ cutlass::HostTensor tensor({kN, kH, kW, kC}); -+ -+ int idx = 0; -+ -+ double computed_norm = double(); -+ -+ for (int n = 0; n < kN; ++n) { -+ for (int h = 0; h < kH; ++h) { -+ for (int w = 0; w < kW; ++w) { -+ for (int c = 0; c < kC; ++c, ++idx) { -+ -+ float x = float(((idx + 7) % 8) - 4); -+ computed_norm += double(x) * double(x); -+ -+ tensor.at({n, h, w, c}) = cutlass::half_t(x); -+ } -+ } -+ } -+ } -+ -+ computed_norm = std::sqrt(computed_norm); -+ -+ tensor.sync_device(); -+ -+ double device_norm = cutlass::reference::device::TensorNorm(tensor.device_view(), double()); -+ double host_norm = cutlass::reference::host::TensorNorm(tensor.host_view(), double()); -+ -+ EXPECT_TRUE(std::abs(host_norm - device_norm) < 0.001 && std::abs(computed_norm - host_norm) < 0.001) -+ << "computed norm: " << computed_norm << "\n" -+ << " host norm: " << host_norm << "\n" -+ << "device norm: " << device_norm << "\n"; -+} -+ -+TEST(TensorReduce, norm_diff_nhwc_f32) { -+ -+ int const kN = 59; -+ int const kH = 24; -+ int const kW = 57; -+ int const kC = 78; -+ -+ using Layout = cutlass::layout::TensorNHWC; -+ -+ cutlass::HostTensor tensor_A({kN, kH, kW, kC}); -+ cutlass::HostTensor tensor_B({kN, kH, kW, kC}); -+ -+ -+ int idx = 0; -+ -+ double sum_sq_diff = 0; -+ -+ for (int n = 0; n < kN; ++n) { -+ for (int h = 0; h < kH; ++h) { -+ for (int w = 0; w < kW; ++w) { -+ for (int c = 0; c < kC; ++c, ++idx) { -+ -+ float a = float(((idx * 5 + 7) % 8) - 4); -+ float b = float(((idx * 3 + 7) % 8) - 4); -+ -+ sum_sq_diff += double(a - b) * double(a - b); -+ -+ tensor_A.at({n, h, w, c}) = a; -+ tensor_B.at({n, h, w, c}) = b; -+ } -+ } -+ } -+ } -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ -+ double device_norm = cutlass::reference::device::TensorNormDiff( -+ tensor_A.device_view(), tensor_B.device_view(), double()); -+ -+ double host_norm = std::sqrt(sum_sq_diff); -+ -+ EXPECT_TRUE(std::abs(host_norm - device_norm) < 0.001f) -+ << " host norm: " << host_norm << "\n" -+ << "device norm: " << device_norm; -+} -+ -+ -+TEST(TensorReduce, norm_diff_nhwc_f16) { -+ -+ int const kN = 59; -+ int const kH = 24; -+ int const kW = 57; -+ int const kC = 78; -+ -+ using Layout = cutlass::layout::TensorNHWC; -+ -+ cutlass::HostTensor tensor_A({kN, kH, kW, kC}); -+ cutlass::HostTensor tensor_B({kN, kH, kW, kC}); -+ -+ int idx = 0; -+ -+ double sum_sq_diff = 0; -+ -+ for (int n = 0; n < kN; ++n) { -+ for (int h = 0; h < kH; ++h) { -+ for (int w = 0; w < kW; ++w) { -+ for (int c = 0; c < kC; ++c, ++idx) { -+ -+ float a = float(((idx * 5 + 7) % 8) - 4); -+ float b = float(((idx * 3 + 7) % 8) - 4); -+ -+ sum_sq_diff += double(a - b) * double(a - b); -+ -+ tensor_A.at({n, h, w, c}) = cutlass::half_t(a); -+ tensor_B.at({n, h, w, c}) = cutlass::half_t(b); -+ } -+ } -+ } -+ } -+ -+ tensor_A.sync_device(); -+ tensor_B.sync_device(); -+ -+ double device_norm = cutlass::reference::device::TensorNormDiff( -+ tensor_A.device_view(), tensor_B.device_view(), double()); -+ -+ double host_norm = std::sqrt(sum_sq_diff); -+ -+ EXPECT_TRUE(std::abs(host_norm - device_norm) < 0.001f) -+ << " host norm: " << host_norm << "\n" -+ << "device norm: " << device_norm; -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/library/include/cutlass/library/arch_mappings.h b/3rdparty/cutlass/tools/library/include/cutlass/library/arch_mappings.h -new file mode 100644 -index 0000000..0d6790e ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/include/cutlass/library/arch_mappings.h -@@ -0,0 +1,110 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ -+ \brief CUTLASS Library is an object-oriented approach to managing operations implemented by CUTLASS. -+ -+ Generally, -+ -+ description - compile-time constant parameters used to instantiate an operation -+ -+ configuration - runtime parameters with computationally expensive initialization -+ -+ arguments - runtime parameters that may be passed to an initialized operation with low -+ computational overhead -+*/ -+ -+#pragma once -+ -+#include "cutlass/arch/mma.h" -+#include "cutlass/arch/arch.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template struct ArchMap; -+ -+template <> struct ArchMap { -+ static int const kMin = 50; -+ static int const kMax = 1024; -+}; -+ -+template <> struct ArchMap { -+ static int const kMin = 60; -+ static int const kMax = 1024; -+}; -+ -+template <> struct ArchMap { -+ static int const kMin = 61; -+ static int const kMax = 1024; -+}; -+ -+template <> struct ArchMap { -+ static int const kMin = 70; -+ static int const kMax = 1024; -+}; -+ -+template <> struct ArchMap { -+ static int const kMin = 70; -+ static int const kMax = 75; -+}; -+ -+template struct ArchMap { -+ static int const kMin = 75; -+ static int const kMax = 1024; -+}; -+ -+template struct ArchMap { -+ static int const kMin = 80; -+ static int const kMax = 1024; -+}; -+ -+template struct ArchMap { -+ static int const kMin = 86; -+ static int const kMax = 1024; -+}; -+ -+template struct ArchMap { -+ static int const kMin = 90; -+ static int const kMax = 1024; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/include/cutlass/library/handle.h b/3rdparty/cutlass/tools/library/include/cutlass/library/handle.h -new file mode 100644 -index 0000000..8125989 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/include/cutlass/library/handle.h -@@ -0,0 +1,355 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief BLAS-like handle used to launch operations on the CUDA device. -+*/ -+ -+#pragma once -+ -+#include -+#include "cutlass/library/library.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Handle object -+class Handle { -+private: -+ -+ /// Host workspace -+ static int const kHostWorkspaceSize = (4 << 10); -+ -+ /// Provider of operations -+ Provider provider_; -+ -+ /// CUDA device properties -+ cudaDeviceProp device_; -+ -+ /// CUDA stream -+ cudaStream_t stream_; -+ -+ /// Device workspace -+ void *workspace_; -+ -+ /// Size of device workspace in bytes -+ size_t workspace_size_; -+ -+ /// Indicates whether scalars are host or device pointers -+ ScalarPointerMode scalar_pointer_mode_; -+ -+ /// Pointer to the most recently executed operation -+ Operation const *last_operation_; -+ -+public: -+ -+ /// Constructor -+ Handle(cudaStream_t stream = nullptr, size_t workspace_size = (4<<20)); -+ -+ /// Destructor -+ ~Handle(); -+ -+ /// Move constructor -+ Handle(Handle && handle); -+ -+ /// Move assignment operator -+ Handle &operator=(Handle && handle); -+ -+ // -+ // Persistent state accessors -+ // -+ -+ /// Returns compute capability of the selected device -+ int compute_capability() const; -+ -+ /// Sets the current CUDA stream -+ void set_stream(cudaStream_t stream); -+ -+ /// Gets the current CUDA stream -+ cudaStream_t get_stream() const; -+ -+ /// Gets the current provider -+ Provider get_provider() const; -+ -+ /// Sets the provider of operations -+ void set_provider(Provider provider); -+ -+ /// Gets the device workspace size -+ size_t get_workspace_size() const; -+ -+ /// Gets a pointer to the device workspace allocation in Global Memory -+ void *get_workspace() const; -+ -+ /// Sets the size of device workspace, invalidating calls to get_device_workspace() -+ void set_workspace_size(size_t bytes); -+ -+ /// Gets the scalar pointer mode -+ ScalarPointerMode get_scalar_pointer_mode() const; -+ -+ /// Sets the scalar pointer mode -+ void set_scalar_pointer_mode(ScalarPointerMode mode); -+ -+ /// Gets the most recently executed operation -+ Operation const *get_last_operation() const; -+ -+ // -+ // Computations -+ // -+ -+ /// Executes a GEMM computation: D <= alpha * A*B + beta * C -+ Status gemm( -+ -+ int M, /// GEMM M dimension -+ int N, /// GEMM N dimension -+ int K, /// GEMM K dimension -+ -+ NumericTypeID element_compute, /// Data type of internal accumulation -+ -+ NumericTypeID element_scalar, /// Data type of alpha/beta scalars -+ -+ void const *alpha, /// Pointer to alpha scalar -+ -+ NumericTypeID element_A, /// Data type of A matrix elements -+ LayoutTypeID layout_A, /// Layout of A matrix -+ ComplexTransform transform_A, /// Complex transformation applied to A matrix - ignored for real-valued matrices -+ -+ void const * ptr_A, /// Pointer to A matrix in Global Memory -+ int64_t lda, /// Leading dimension of A matrix -+ -+ NumericTypeID element_B, /// Data type of B matrix elements -+ LayoutTypeID layout_B, /// Layout of B matrix -+ ComplexTransform transform_B, /// Complex transformation applied to B matrix - ignored for real-valued matrices -+ -+ void const * ptr_B, /// Pointer to B matrix in Global Memory -+ int64_t ldb, /// Leading dimension of B matrix -+ -+ void const * beta, /// Pointer to beta scalar -+ -+ NumericTypeID element_C, /// Data type of C and D matrices -+ -+ void const * ptr_C, /// Pointer to C matrix -+ int64_t ldc, /// Leading dimension of C matrix -+ -+ void * ptr_D, /// Pointer to D matrix -+ int64_t ldd /// Leading dimension of D matrix -+ ); -+ -+ /// Executes a GEMM computation: D <= alpha * A*B + beta * C. -+ // -+ // Supports batched-strided, batched array or split-K serial or split-K parallel. -+ // -+ Status gemm_universal( -+ -+ GemmUniversalMode mode, /// indicates the mode in which the kUniversal GEMM is launched -+ -+ int M, /// GEMM M dimension -+ int N, /// GEMM N dimension -+ int K, /// GEMM K dimension -+ -+ NumericTypeID element_compute, /// Data type of internal accumulation -+ -+ NumericTypeID element_scalar, /// Data type of alpha/beta scalars -+ -+ void const *alpha, /// Pointer to alpha scalar -+ -+ NumericTypeID element_A, /// Data type of A matrix elements -+ LayoutTypeID layout_A, /// Layout of A matrix -+ ComplexTransform transform_A, /// Complex transformation applied to A matrix - ignored for real-valued matrices -+ -+ void const * ptr_A, /// Pointer to A matrix in Global Memory -+ int64_t lda, /// Leading dimension of A matrix -+ -+ NumericTypeID element_B, /// Data type of B matrix elements -+ LayoutTypeID layout_B, /// Layout of B matrix -+ ComplexTransform transform_B, /// Complex transformation applied to B matrix - ignored for real-valued matrices -+ -+ void const * ptr_B, /// Pointer to B matrix in Global Memory -+ int64_t ldb, /// Leading dimension of B matrix -+ -+ void const * beta, /// Pointer to beta scalar -+ -+ NumericTypeID element_C, /// Data type of C and D matrices -+ -+ void const * ptr_C, /// Pointer to C matrix -+ int64_t ldc, /// Leading dimension of C matrix -+ -+ void * ptr_D, /// Pointer to D matrix -+ int64_t ldd, /// Leading dimension of D matrix -+ -+ int batch_count = 1, /// Batch count or number of split-K slices -+ -+ int64_t batch_stride_A = 0, /// Batch stride of A operand -+ int64_t batch_stride_B = 0, /// Batch stride of B operand -+ int64_t batch_stride_C = 0, /// Batch stride of C operand -+ int64_t batch_stride_D = 0 /// Batch stride of D operand -+ ); -+ -+ /// Planar complex GEMM -+ /// -+ /// Note, all data types are the real-valued base types used by the planar-complex GEMM kernel. -+ /// -+ Status gemm_planar_complex( -+ -+ int M, /// GEMM M dimension -+ int N, /// GEMM N dimension -+ int K, /// GEMM K dimension -+ -+ NumericTypeID element_compute, /// Data type of internal accumulation -+ -+ NumericTypeID element_scalar, /// Data type of alpha/beta scalars -+ -+ void const *alpha, /// Pointer to alpha scalar -+ -+ NumericTypeID element_A, /// Data type of A matrix elements -+ LayoutTypeID layout_A, /// Layout of A matrix -+ ComplexTransform transform_A, /// Complex transformation applied to A matrix -+ -+ void const * ptr_A_real, /// Pointer to real part of A matrix -+ void const * ptr_A_imag, /// Pointer to imaginary part of A matrix -+ int64_t lda_real, /// Leading dimension of real part of A matrix -+ int64_t lda_imag, /// Leading dimension of imaginary part of A matrix -+ -+ NumericTypeID element_B, /// Data type of B matrix elements -+ LayoutTypeID layout_B, /// Layout of B matrix -+ ComplexTransform transform_B, /// Complex transformation applied to B matrix -+ -+ void const * ptr_B_real, /// Pointer to real part of B matrix -+ void const * ptr_B_imag, /// Pointer to imaginary part of B matrix -+ int64_t ldb_real, /// Leading dimension of real part of B matrix -+ int64_t ldb_imag, /// Leading dimension of imaginary part of B matrix -+ -+ void const * beta, /// Pointer to beta scalar -+ -+ NumericTypeID element_C, /// Data type of C and D matrix -+ -+ void const * ptr_C_real, /// Pointer to real part of C matrix -+ void const * ptr_C_imag, /// Pointer to imaginary part of C matrix -+ int64_t ldc_real, /// Leading dimension of real part of C matrix -+ int64_t ldc_imag, /// Leading dimension of imaginary part of C matrix -+ -+ void * ptr_D_real, /// Pointer to real part of D matrix -+ void * ptr_D_imag, /// Pointer to imaginary part of D matrix -+ int64_t ldd_real, /// Leading dimension of real part of D matrix -+ int64_t ldd_imag, /// Leading dimension of imaginary part of D matrix -+ -+ int batch_count = 1, /// Number of batched GEMMs to execute -+ -+ int64_t batch_stride_A_real = 0, -+ int64_t batch_stride_A_imag = 0, -+ -+ int64_t batch_stride_B_real = 0, -+ int64_t batch_stride_B_imag = 0, -+ -+ int64_t batch_stride_C_real = 0, -+ int64_t batch_stride_C_imag = 0, -+ -+ int64_t batch_stride_D_real = 0, -+ int64_t batch_stride_D_imag = 0 -+ ); -+ -+ /// Planar complex GEMM loading pointers from arrays in global memory -+ Status gemm_planar_complex_array( -+ -+ int expected_M, /// Expected GEMM M dimension (used for sizing CUDA grid) -+ int expected_N, /// Expected GEMM N dimension (used for sizing CUDA grid) -+ int expected_K, /// Expected GEMM K dimension -+ int batch_count, /// Number of independent GEMM computations to execute -+ -+ int const *M, /// Array containing the GEMM M dimension for each batch index -+ int const *N, /// Array containing the GEMM N dimension for each batch index -+ int const *K, /// Array containing the GEMM K dimension for each batch index -+ -+ NumericTypeID element_compute, /// Data type of internal accumulation -+ -+ NumericTypeID element_scalar, /// Data type of alpha/beta scalars -+ -+ void const *alpha, /// Pointer to alpha scalar -+ -+ NumericTypeID element_A, /// Data type of A matrix elements -+ LayoutTypeID layout_A, /// Layout of A matrix -+ ComplexTransform transform_A, /// Complex transformation applied to A matrix -+ -+ void const * const * ptr_A_real, /// Pointer to array containing pointers to real part of A matrices -+ void const * const * ptr_A_imag, /// Pointer to array containing pointers to imaginary part of A matrices -+ -+ int64_t lda_real, /// Leading dimension of real part of A matrix -+ int64_t lda_imag, /// Leading dimension of imaginary part of A matrix -+ -+ NumericTypeID element_B, /// Data type of B matrix elements -+ LayoutTypeID layout_B, /// Layout of B matrix -+ ComplexTransform transform_B, /// Complex transformation applied to B matrix -+ -+ void const * const * ptr_B_real, /// Pointer to array containing pointers to real part of B matrices -+ void const * const * ptr_B_imag, /// Pointer to array containing pointers to imaginary part of B matrices -+ -+ int64_t ldb_real, /// Leading dimension of real part of B matrix -+ int64_t ldb_imag, /// Leading dimension of imaginary part of B matrix -+ -+ void const * beta, /// Pointer to beta scalar -+ -+ NumericTypeID element_C, /// Data type of C and D matrix -+ -+ void const * const * ptr_C_real, /// Pointer to array containing pointers to real part of C matrices -+ void const * const * ptr_C_imag, /// Pointer to array containing poitners to imaginary part of C matrices -+ -+ int64_t ldc_real, /// Leading dimension of real part of C matrix -+ int64_t ldc_imag, /// Leading dimension of imaginary part of C matrix -+ -+ void * const * ptr_D_real, /// Pointer to array containing pointers to real part of D matrices -+ void * const * ptr_D_imag, /// Pointer to array containing poitners to imaginary part of D matrices -+ -+ int64_t ldd_real, /// Leading dimension of real part of D matrix -+ int64_t ldd_imag /// Leading dimension of imaginary part of D matrix -+ ); -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Unique pointer storing the handle -+using HandlePtr = std::unique_ptr; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Finds conv2d operation instances with Conv2d::ElementC = Reduction::ElementWorkspace -+Operation const* find_conv_operation_for_parallel_reduction(Operation const *operation); -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Finds gemm operation instances with ElementC = Reduction::ElementWorkspace -+Operation const* find_gemm_operation_for_parallel_reduction(Operation const *operation); -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/library/include/cutlass/library/library.h b/3rdparty/cutlass/tools/library/include/cutlass/library/library.h -new file mode 100644 -index 0000000..6bb3f79 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/include/cutlass/library/library.h -@@ -0,0 +1,1537 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ -+ \brief CUTLASS Library is an object-oriented approach to managing operations implemented by CUTLASS. -+ -+ Generally, -+ -+ description - compile-time constant parameters used to instantiate an operation -+ -+ configuration - runtime parameters with computationally expensive initialization -+ -+ arguments - runtime parameters that may be passed to an initialized operation with low -+ computational overhead -+*/ -+ -+#pragma once -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include -+#include -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/tensor_coord.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/blas3.h" -+ -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Layout type identifier -+enum class LayoutTypeID { -+ kUnknown, -+ kColumnMajor, -+ kRowMajor, -+ kColumnMajorInterleavedK2, -+ kRowMajorInterleavedK2, -+ kColumnMajorInterleavedK4, -+ kRowMajorInterleavedK4, -+ kColumnMajorInterleavedK16, -+ kRowMajorInterleavedK16, -+ kColumnMajorInterleavedK32, -+ kRowMajorInterleavedK32, -+ kColumnMajorInterleavedK64, -+ kRowMajorInterleavedK64, -+ kTensorNCHW, -+ kTensorNCDHW, -+ kTensorNHWC, -+ kTensorNDHWC, -+ kTensorNC32HW32, -+ kTensorC32RSK32, -+ kTensorNC64HW64, -+ kTensorC64RSK64, -+ kInvalid -+}; -+ -+/// Numeric data type -+enum class NumericTypeID { -+ kUnknown, -+ kVoid, -+ kB1, -+ kU2, -+ kU4, -+ kU8, -+ kU16, -+ kU32, -+ kU64, -+ kS2, -+ kS4, -+ kS8, -+ kS16, -+ kS32, -+ kS64, -+ kF16, -+ kBF16, -+ kTF32, -+ kF32, -+ kF64, -+ kCF16, -+ kCBF16, -+ kCF32, -+ kCTF32, -+ kCF64, -+ kCS2, -+ kCS4, -+ kCS8, -+ kCS16, -+ kCS32, -+ kCS64, -+ kCU2, -+ kCU4, -+ kCU8, -+ kCU16, -+ kCU32, -+ kCU64, -+ kInvalid -+}; -+ -+/// Enumerated type describing a transformation on a complex value. -+enum class ComplexTransform { -+ kNone, -+ kConjugate, -+ kInvalid -+}; -+ -+/// Providers -+enum class Provider { -+ kNone, -+ kCUTLASS, -+ kReferenceHost, -+ kReferenceDevice, -+ kCUBLAS, -+ kCUDNN, -+ kInvalid -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Enumeration indicating the kind of operation -+enum class OperationKind { -+ kGemm, -+ kRankK, -+ kRank2K, -+ kTrmm, -+ kSymm, -+ kConv2d, -+ kConv3d, -+ kEqGemm, -+ kSparseGemm, -+ kReduction, -+ kInvalid -+}; -+ -+/// Enumeration indicating whether scalars are in host or device memory -+enum class ScalarPointerMode { -+ kHost, -+ kDevice, -+ kInvalid -+}; -+ -+/// Describes how reductions are performed across threadblocks -+enum class SplitKMode { -+ kNone, -+ kSerial, -+ kParallel, -+ kParallelSerial, -+ kInvalid -+}; -+ -+/// Indicates the classificaition of the math instruction -+enum class OpcodeClassID { -+ kSimt, -+ kTensorOp, -+ kWmmaTensorOp, -+ kSparseTensorOp, -+ kInvalid -+}; -+ -+enum class MathOperationID { -+ kAdd, -+ kMultiplyAdd, -+ kMultiplyAddSaturate, -+ kMultiplyAddFastBF16, -+ kMultiplyAddFastF16, -+ kMultiplyAddFastF32, -+ kMultiplyAddComplex, -+ kMultiplyAddComplexFastF32, -+ kMultiplyAddGaussianComplex, -+ kXorPopc, -+ kInvalid -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Enumeration indicating what kind of GEMM operation to perform -+enum class GemmKind { -+ kGemm, -+ kSparse, -+ kUniversal, -+ kPlanarComplex, -+ kPlanarComplexArray, -+ kGrouped, -+ kInvalid -+}; -+ -+/// Mode of Universal GEMM -+using GemmUniversalMode = cutlass::gemm::GemmUniversalMode; -+ -+/// Enumeration indicating what kind of RankK update operation to perform -+enum class RankKKind { -+ kUniversal, -+ kInvalid -+}; -+ -+/// Enumeration indicating what kind of TRMM operation to perform -+enum class TrmmKind { -+ kUniversal, -+ kInvalid -+}; -+ -+/// Enumeration indicating what kind of SYMM/HEMM operation to perform -+enum class SymmKind { -+ kUniversal, -+ kInvalid -+}; -+ -+/// Enumeration indicating what kind of Conv2d operation to perform -+enum class ConvKind { -+ kUnknown, -+ kFprop, -+ kDgrad, -+ kWgrad, -+ kInvalid -+}; -+ -+enum class ConvModeID { -+ kCrossCorrelation, -+ kConvolution, -+ kInvalid -+}; -+ -+// Iterator algorithm enum in order of general performance-efficiency -+enum class IteratorAlgorithmID { -+ kNone, -+ kAnalytic, -+ kOptimized, -+ kFixedChannels, -+ kFewChannels, -+ kInvalid -+}; -+ -+ -+enum class EpilogueKind { -+ kUnknown, -+ kConversion, -+ kLinearCombination, -+ kLinearCombinationClamp, -+ kLinearCombinationPlanarComplex, -+ kLinearCombinationRelu, -+ kLinearCombinationSigmoid, -+ kInvalid -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct MathInstructionDescription { -+ -+ /// Shape of the target math instruction -+ cutlass::gemm::GemmCoord instruction_shape; -+ -+ /// Describes the data type of the internal accumulator -+ NumericTypeID element_accumulator; -+ -+ /// Classification of math instruction -+ OpcodeClassID opcode_class; -+ -+ /// Type of math operation performed -+ MathOperationID math_operation; -+ -+ // -+ // Methods -+ // -+ -+ MathInstructionDescription( -+ cutlass::gemm::GemmCoord instruction_shape = cutlass::gemm::GemmCoord(), -+ NumericTypeID element_accumulator = NumericTypeID::kInvalid, -+ OpcodeClassID opcode_class = OpcodeClassID::kInvalid, -+ MathOperationID math_operation = MathOperationID::kMultiplyAdd -+ ): -+ instruction_shape(instruction_shape), -+ element_accumulator(element_accumulator), -+ opcode_class(opcode_class), -+ math_operation(math_operation) {} -+ -+ // Equality operator -+ inline -+ bool operator==(MathInstructionDescription const& rhs) const{ -+ return ( -+ (instruction_shape == rhs.instruction_shape) && -+ (element_accumulator == rhs.element_accumulator) && -+ (opcode_class == rhs.opcode_class) && -+ (math_operation == rhs.math_operation)); -+ } -+ -+ // Inequality operator -+ inline -+ bool operator!=(MathInstructionDescription const& rhs) const { -+ return !(*this == rhs); -+ } -+ -+}; -+ -+/// Structure describing the tiled structure of a GEMM-like computation -+struct TileDescription { -+ -+ /// Describes the shape of a threadblock (in elements) -+ cutlass::gemm::GemmCoord threadblock_shape; -+ -+ /// Describes the number of pipeline stages in the threadblock-scoped mainloop -+ int threadblock_stages; -+ -+ /// Number of warps in each logical dimension -+ cutlass::gemm::GemmCoord warp_count; -+ -+ /// Core math instruction -+ MathInstructionDescription math_instruction; -+ -+ /// Minimum compute capability (e.g. 70, 75) of a device eligible to run the operation. -+ int minimum_compute_capability; -+ -+ /// Minimum compute capability (e.g. 70, 75) of a device eligible to run the operation. -+ int maximum_compute_capability; -+ -+ /// Describes the shape of a cluster (in blocks) -+ cutlass::gemm::GemmCoord cluster_shape; -+ -+ // -+ // Methods -+ // -+ -+ TileDescription( -+ cutlass::gemm::GemmCoord threadblock_shape = cutlass::gemm::GemmCoord(), -+ int threadblock_stages = 0, -+ cutlass::gemm::GemmCoord warp_count = cutlass::gemm::GemmCoord(), -+ MathInstructionDescription math_instruction = MathInstructionDescription(), -+ int minimum_compute_capability = 0, -+ int maximum_compute_capability = 0, -+ cutlass::gemm::GemmCoord cluster_shape = cutlass::gemm::GemmCoord(1,1,1) -+ ): -+ threadblock_shape(threadblock_shape), -+ threadblock_stages(threadblock_stages), -+ warp_count(warp_count), -+ math_instruction(math_instruction), -+ minimum_compute_capability(minimum_compute_capability), -+ maximum_compute_capability(maximum_compute_capability), -+ cluster_shape(cluster_shape) { } -+ -+ // Equality operator -+ inline -+ bool operator==(TileDescription const& rhs) const{ -+ return ( -+ (threadblock_shape == rhs.threadblock_shape) && -+ (threadblock_stages == rhs.threadblock_stages) && -+ (warp_count == rhs.warp_count) && -+ (math_instruction == rhs.math_instruction) && -+ (minimum_compute_capability == rhs.minimum_compute_capability) && -+ (maximum_compute_capability == rhs.maximum_compute_capability)); -+ } -+ -+ // Inequality operator -+ inline -+ bool operator!=(TileDescription const& rhs) const { -+ return !(*this == rhs); -+ } -+}; -+ -+/// High-level description of an operation -+struct OperationDescription { -+ -+ /// Unique identifier describing the operation -+ char const * name; -+ -+ /// Operation provider -+ Provider provider; -+ -+ /// Kind of operation -+ OperationKind kind; -+ -+ /// Describes the tiled structure of a GEMM-like computation -+ TileDescription tile_description; -+ -+ // -+ // Methods -+ // -+ OperationDescription( -+ char const * name = "unknown", -+ Provider Provider = Provider::kInvalid, -+ OperationKind kind = OperationKind::kInvalid, -+ TileDescription const & tile_description = TileDescription() -+ ): -+ name(name), kind(kind), tile_description(tile_description) { } -+}; -+ -+/// Structure describing the properties of a tensor -+struct TensorDescription { -+ -+ /// Numeric type of an individual element -+ NumericTypeID element; -+ -+ /// Enumerant identifying the layout function for the tensor -+ LayoutTypeID layout; -+ -+ /// Alignment restriction on pointers, strides, and extents -+ int alignment; -+ -+ /// log2() of the maximum extent of each dimension -+ int log_extent_range; -+ -+ /// log2() of the maximum value each relevant stride may have -+ int log_stride_range; -+ -+ // -+ // Methods -+ // -+ -+ TensorDescription( -+ NumericTypeID element = NumericTypeID::kInvalid, -+ LayoutTypeID layout = LayoutTypeID::kInvalid, -+ int alignment = 1, -+ int log_extent_range = 24, -+ int log_stride_range = 24 -+ ): -+ element(element), -+ layout(layout), -+ alignment(alignment), -+ log_extent_range(log_extent_range), -+ log_stride_range(log_stride_range) { } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Description of all GEMM computations -+struct GemmDescription : public OperationDescription { -+ -+ /// Indicates the kind of GEMM performed -+ GemmKind gemm_kind; -+ -+ /// Describes the A operand -+ TensorDescription A; -+ -+ /// Describes the B operand -+ TensorDescription B; -+ -+ /// Describes the source and destination matrices -+ TensorDescription C; -+ -+ /// Describes the sparse meta matrices -+ TensorDescription E; -+ -+ /// Describes the data type of the scalars passed to the epilogue -+ NumericTypeID element_epilogue; -+ -+ /// Describes the structure of parallel reductions -+ SplitKMode split_k_mode; -+ -+ /// Transformation on A operand -+ ComplexTransform transform_A; -+ -+ /// Transformation on B operand -+ ComplexTransform transform_B; -+ -+ // -+ // Methods -+ // -+ -+ GemmDescription( -+ GemmKind gemm_kind = GemmKind::kGemm, -+ TensorDescription const &A = TensorDescription(), -+ TensorDescription const &B = TensorDescription(), -+ TensorDescription const &C = TensorDescription(), -+ NumericTypeID element_epilogue = NumericTypeID::kInvalid, -+ SplitKMode split_k_mode = SplitKMode::kNone, -+ ComplexTransform transform_A = ComplexTransform::kNone, -+ ComplexTransform transform_B = ComplexTransform::kNone -+ ): -+ gemm_kind(gemm_kind), -+ A(A), -+ B(B), -+ C(C), -+ element_epilogue(element_epilogue), -+ split_k_mode(split_k_mode), -+ transform_A(transform_A), -+ transform_B(transform_B) {} -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Desciprion for structured sparse GEMMs. -+struct SparseGemmDescription : public GemmDescription { -+ -+ /// Description structure for structured sparse GEMM -+ SparseGemmDescription( -+ GemmKind gemm_kind = GemmKind::kGemm, -+ TensorDescription const &A = TensorDescription(), -+ TensorDescription const &B = TensorDescription(), -+ TensorDescription const &C = TensorDescription(), -+ TensorDescription const &E = TensorDescription(), -+ NumericTypeID element_epilogue = NumericTypeID::kInvalid, -+ SplitKMode split_k_mode = SplitKMode::kNone, -+ ComplexTransform transform_A = ComplexTransform::kNone, -+ ComplexTransform transform_B = ComplexTransform::kNone -+ ): -+ GemmDescription(gemm_kind, A, B, C, element_epilogue, split_k_mode, transform_A, transform_B) -+ {this->E = E;} -+}; -+ -+/// Description of all Reduction operations -+struct ReductionDescription : public OperationDescription { -+ -+ /// Describes the data type of workspace -+ NumericTypeID element_workspace; -+ -+ /// Describes the data type of final output -+ NumericTypeID element_output; -+ -+ /// Describes the data type of the scalars passed to the epilogue -+ NumericTypeID element_epilogue; -+}; -+ -+/// Description of all Rank K update computations (SYRK, HERK, SYR2K, HER2K) -+struct RankKDescription : public OperationDescription { -+ -+ /// Indicates which device template is used (universal or regular) -+ RankKKind rank_k_kind; -+ -+ /// Number of rank update (rank k or rank 2k) -+ int num_ranks; -+ -+ /// Describes the A operand -+ TensorDescription A; -+ -+ /// Describes the B operand (used only for SYR2K and HER2K) -+ TensorDescription B; -+ -+ /// Describes the source and destination matrices -+ TensorDescription C; -+ -+ /// Describes the fill mode for matrix C -+ FillMode fill_mode; -+ -+ /// Describes the blas mode (symmetric/hermitian) -+ BlasMode blas_mode; -+ -+ /// Describes the data type of the scalars passed to the epilogue -+ NumericTypeID element_epilogue; -+ -+ /// Describes the structure of parallel reductions -+ SplitKMode split_k_mode; -+ -+ /// Transformation on A operand -+ ComplexTransform transform_A; -+ -+ /// Transformation on B operand -+ ComplexTransform transform_B; -+ -+ // -+ // Methods -+ // -+ -+ RankKDescription( -+ RankKKind rank_k_kind = RankKKind::kUniversal, -+ int num_ranks = 1, -+ TensorDescription const &A = TensorDescription(), -+ TensorDescription const &B = TensorDescription(), -+ TensorDescription const &C = TensorDescription(), -+ FillMode fill_mode = FillMode::kInvalid, -+ BlasMode blas_mode = BlasMode::kInvalid, -+ NumericTypeID element_epilogue = NumericTypeID::kInvalid, -+ SplitKMode split_k_mode = SplitKMode::kNone, -+ ComplexTransform transform_A = ComplexTransform::kNone, -+ ComplexTransform transform_B = ComplexTransform::kNone -+ ): -+ rank_k_kind(rank_k_kind), -+ num_ranks(num_ranks), -+ A(A), -+ B(B), -+ C(C), -+ fill_mode(fill_mode), -+ blas_mode(blas_mode), -+ element_epilogue(element_epilogue), -+ split_k_mode(split_k_mode), -+ transform_A(transform_A), -+ transform_B(transform_B) {} -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Description of all TRMM computations -+struct TrmmDescription : public OperationDescription { -+ -+ /// Indicates the kind of TRMM performed -+ TrmmKind trmm_kind; -+ -+ /// Describes the A operand -+ TensorDescription A; -+ -+ /// Describes the side mode for matrix A -+ SideMode side_mode; -+ -+ /// Describes the fill mode for matrix A -+ FillMode fill_mode; -+ -+ /// Describes the diag type for matrix A -+ DiagType diag_type; -+ -+ /// Describes the B operand -+ TensorDescription B; -+ -+ /// Describes the source and destination matrices -+ TensorDescription D; -+ -+ /// Describes the data type of the scalars passed to the epilogue -+ NumericTypeID element_epilogue; -+ -+ /// Describes the structure of parallel reductions -+ SplitKMode split_k_mode; -+ -+ /// Transformation on A operand -+ ComplexTransform transform_A; -+ -+ // -+ // Methods -+ // -+ -+ TrmmDescription( -+ TrmmKind trmm_kind = TrmmKind::kUniversal, -+ TensorDescription const &A = TensorDescription(), -+ SideMode side_mode = SideMode::kInvalid, -+ FillMode fill_mode = FillMode::kInvalid, -+ DiagType diag_type = DiagType::kInvalid, -+ TensorDescription const &B = TensorDescription(), -+ TensorDescription const &D = TensorDescription(), -+ NumericTypeID element_epilogue = NumericTypeID::kInvalid, -+ SplitKMode split_k_mode = SplitKMode::kNone, -+ ComplexTransform transform_A = ComplexTransform::kNone -+ ): -+ trmm_kind(trmm_kind), -+ A(A), -+ side_mode(side_mode), -+ fill_mode(fill_mode), -+ diag_type(diag_type), -+ B(B), -+ D(D), -+ element_epilogue(element_epilogue), -+ split_k_mode(split_k_mode), -+ transform_A(transform_A) {} -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Description of all SYMM/HEMM update computations -+struct SymmDescription : public OperationDescription { -+ -+ /// Indicates which device template is used (universal or regular) -+ SymmKind symm_kind; -+ -+ /// Describes the A operand -+ TensorDescription A; -+ -+ /// Describes the B operand -+ TensorDescription B; -+ -+ /// Describes the source and destination matrices -+ TensorDescription C; -+ -+ /// Describes the side mode for matrix A -+ SideMode side_mode; -+ -+ /// Describes the fill mode for matrix A -+ FillMode fill_mode; -+ -+ /// Describes the blas mode (symmetric/hermitian) -+ BlasMode blas_mode; -+ -+ /// Describes the data type of the scalars passed to the epilogue -+ NumericTypeID element_epilogue; -+ -+ /// Describes the structure of parallel reductions -+ SplitKMode split_k_mode; -+ -+ /// Transformation on A operand -+ ComplexTransform transform_A; -+ -+ /// Transformation on B operand -+ ComplexTransform transform_B; -+ -+ // -+ // Methods -+ // -+ -+ SymmDescription( -+ SymmKind symm_kind = SymmKind::kUniversal, -+ TensorDescription const &A = TensorDescription(), -+ TensorDescription const &B = TensorDescription(), -+ TensorDescription const &C = TensorDescription(), -+ SideMode side_mode = SideMode::kInvalid, -+ FillMode fill_mode = FillMode::kInvalid, -+ BlasMode blas_mode = BlasMode::kInvalid, -+ NumericTypeID element_epilogue = NumericTypeID::kInvalid, -+ SplitKMode split_k_mode = SplitKMode::kNone, -+ ComplexTransform transform_A = ComplexTransform::kNone, -+ ComplexTransform transform_B = ComplexTransform::kNone -+ ): -+ symm_kind(symm_kind), -+ A(A), -+ B(B), -+ C(C), -+ side_mode(side_mode), -+ fill_mode(fill_mode), -+ blas_mode(blas_mode), -+ element_epilogue(element_epilogue), -+ split_k_mode(split_k_mode), -+ transform_A(transform_A), -+ transform_B(transform_B) {} -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Description of all Conv2d operations -+struct ConvDescription : public OperationDescription { -+ /// Describes the convolution dimension support (2D or 3D) -+ int conv_dim; -+ -+ /// Describes the kind of convolution -+ ConvKind conv_kind; -+ -+ /// Describes the type of iterator algorithm (analytic or precomputed) -+ IteratorAlgorithmID iterator_algorithm; -+ -+ /// Describes the A operand -+ TensorDescription A; -+ -+ /// Describes the B operand -+ TensorDescription B; -+ -+ /// Describes the C operand -+ TensorDescription C; -+ -+ /// Describes the data type of the scalars passed to the epilogue -+ NumericTypeID element_epilogue; -+ -+ // -+ // Methods -+ // -+ // Returns Activation TensorDescription -+ TensorDescription activation() const { -+ switch(conv_kind) { -+ case library::ConvKind::kFprop : return A; -+ case library::ConvKind::kDgrad : return C; -+ case library::ConvKind::kWgrad : return B; -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns Filter TensorDescription -+ TensorDescription filter() const { -+ switch(conv_kind) { -+ case library::ConvKind::kFprop : return B; -+ case library::ConvKind::kDgrad : return B; -+ case library::ConvKind::kWgrad : return C; -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns Output TensorDescription -+ TensorDescription output() const { -+ switch(conv_kind) { -+ case library::ConvKind::kFprop : return C; -+ case library::ConvKind::kDgrad : return A; -+ case library::ConvKind::kWgrad : return A; -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Base class for all operations -+class Operation { -+public: -+ -+ virtual ~Operation() { } -+ -+ virtual OperationDescription const & description() const = 0; -+ -+ virtual Status can_implement( -+ void const *configuration, -+ void const *arguments) const = 0; -+ -+ virtual uint64_t get_host_workspace_size( -+ void const *configuration) const = 0; -+ -+ virtual uint64_t get_device_workspace_size( -+ void const *configuration, -+ void const *arguments = nullptr) const = 0; -+ -+ virtual Status initialize( -+ void const *configuration, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const = 0; -+ -+ virtual Status run( -+ void const *arguments, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const = 0; -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Configuration for basic GEMM operations -+// -+// OperationKind: Gemm -+// GemmKind: Gemm -+// -+struct GemmConfiguration { -+ -+ /// GEMM problem size -+ gemm::GemmCoord problem_size; -+ -+ /// Leading dimension of A matrix -+ int64_t lda; -+ -+ /// Leading dimension of B matrix -+ int64_t ldb; -+ -+ /// Leading dimension of C matrix -+ int64_t ldc; -+ -+ /// Leading dimension of D matrix -+ int64_t ldd; -+ -+ /// Number of partitions of K dimension -+ int split_k_slices; -+}; -+ -+/// Arguments for GEMM -+struct GemmArguments { -+ -+ /// Pointer to A matrix -+ void const *A; -+ -+ /// Pointer to B matrix -+ void const *B; -+ -+ /// Pointer to C matrix -+ void const *C; -+ -+ /// Pointer to D matrix -+ void *D; -+ -+ /// Host or device pointer to alpha scalar -+ void const *alpha; -+ -+ /// Host or device pointer to beta scalar -+ void const *beta; -+ -+ /// Enumerant indicating whether alpha/beta point to host or device memory -+ ScalarPointerMode pointer_mode; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Configuration for batched GEMM in which multiple matrix products are computed -+// -+// OperationKind: Gemm -+// GemmKind: Batched -+ -+struct GemmBatchedConfiguration { -+ -+ /// GEMM problem size -+ gemm::GemmCoord problem_size; -+ -+ /// Leading dimension of A matrix -+ int64_t lda; -+ -+ /// Leading dimension of B matrix -+ int64_t ldb; -+ -+ /// Leading dimension of C matrix -+ int64_t ldc; -+ -+ /// Leading dimension of D matrix -+ int64_t ldd; -+ -+ /// Stride between instances of the A matrix in memory -+ int64_t batch_stride_A; -+ -+ /// Stride between instances of the B matrix in memory -+ int64_t batch_stride_B; -+ -+ /// Stride between instances of the C matrix in memory -+ int64_t batch_stride_C; -+ -+ /// Stride between instances of the D matrix in memory -+ int64_t batch_stride_D; -+ -+ /// Number of GEMMs in batch -+ int batch_count; -+}; -+ -+/// Arguments to batched GEMM -+using GemmBatchedArguments = GemmArguments; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Configuration for batched GEMM in which multiple matrix products are computed -+// -+// OperationKind: Gemm -+// GemmKind: Array -+ -+struct GemmArrayConfiguration { -+ -+ gemm::GemmCoord problem_size; -+ -+ /// Leading dimension of A matrix -+ int64_t lda; -+ -+ /// Leading dimension of B matrix -+ int64_t ldb; -+ -+ /// Leading dimension of C matrix -+ int64_t ldc; -+ -+ /// Leading dimension of D matrix -+ int64_t ldd; -+ -+ int batch_count; -+}; -+ -+/// Arguments for GEMM - used by all the GEMM operations -+struct GemmArrayArguments { -+ void const * const *A; -+ void const * const *B; -+ void const * const *C; -+ void * const *D; -+ void const *alpha; -+ void const *beta; -+ ScalarPointerMode pointer_mode; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Universal GEMM supporting multiple split-K modes, multiple batched modes, real and complex -+// -+// OperationKind: Gemm -+// GemmKind: Universal -+ -+struct GemmUniversalConfiguration { -+ -+ GemmUniversalMode mode; -+ gemm::GemmCoord problem_size; -+ int batch_count; -+ -+ int64_t lda; -+ int64_t ldb; -+ int64_t ldc; -+ int64_t ldd; -+}; -+ -+struct GemmUniversalArguments { -+ // NOTE: these are replicated for 3.0 interfaces -+ gemm::GemmCoord problem_size; -+ int batch_count; -+ -+ void const *A; -+ void const *B; -+ void const *C; -+ void *D; -+ -+ void const *alpha; -+ void const *beta; -+ ScalarPointerMode pointer_mode; -+ -+ // NOTE: these are replicated for 3.0 interfaces -+ int64_t lda; -+ int64_t ldb; -+ int64_t ldc; -+ int64_t ldd; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C; -+ int64_t batch_stride_D; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Complex valued GEMM in which real and imaginary parts are separated by a stride -+// -+// OperationKind: Gemm -+// GemmKind: Planar complex -+ -+struct GemmPlanarComplexConfiguration { -+ -+ GemmUniversalMode mode; -+ gemm::GemmCoord problem_size; -+ int batch_count; -+ -+ int64_t lda_real; -+ int64_t lda_imag; -+ -+ int64_t ldb_real; -+ int64_t ldb_imag; -+ -+ int64_t ldc_real; -+ int64_t ldc_imag; -+ -+ int64_t ldd_real; -+ int64_t ldd_imag; -+}; -+ -+/// Arguments for planar complex GEMMs -+struct GemmPlanarComplexArguments { -+ -+ void const *A_real; -+ void const *A_imag; -+ -+ void const *B_real; -+ void const *B_imag; -+ -+ void const *C_real; -+ void const *C_imag; -+ -+ void *D_real; -+ void *D_imag; -+ -+ void const *alpha; -+ void const *beta; -+ ScalarPointerMode pointer_mode; -+ -+ int64_t batch_stride_A_real; -+ int64_t batch_stride_A_imag; -+ -+ int64_t batch_stride_B_real; -+ int64_t batch_stride_B_imag; -+ -+ int64_t batch_stride_C_real; -+ int64_t batch_stride_C_imag; -+ -+ int64_t batch_stride_D_real; -+ int64_t batch_stride_D_imag; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// This is a special form of planar complex which loads pointers and problem size -+/// from memory. -+struct GemmPlanarComplexArrayConfiguration { -+ -+ gemm::GemmCoord problem_size; -+ int batch_count; -+ -+ int64_t lda_real; -+ int64_t lda_imag; -+ -+ int64_t ldb_real; -+ int64_t ldb_imag; -+ -+ int64_t ldc_real; -+ int64_t ldc_imag; -+ -+ int64_t ldd_real; -+ int64_t ldd_imag; -+}; -+ -+/// Arguments for planar complex GEMMs -+struct GemmPlanarComplexArrayArguments { -+ -+ int const *M; -+ int const *N; -+ int const *K; -+ -+ void const * const * A_real; -+ void const * const * A_imag; -+ void const * const * B_real; -+ void const * const * B_imag; -+ void const * const * C_real; -+ void const * const * C_imag; -+ void * const * D_real; -+ void * const * D_imag; -+ -+ void const * alpha; -+ void const * beta; -+ ScalarPointerMode pointer_mode; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Grouped GEMM supporting -+// -+// OperationKind: Gemm -+// GemmKind: Grouped -+ -+struct GemmGroupedConfiguration { -+ -+ int problem_count; -+ int threadblock_count; -+ -+}; -+ -+struct GemmGroupedArguments { -+ -+ gemm::GemmCoord *problem_sizes; -+ -+ void * ptr_A; -+ void * ptr_B; -+ void * ptr_C; -+ void * ptr_D; -+ -+ int64_t *lda; -+ int64_t *ldb; -+ int64_t *ldc; -+ int64_t *ldd; -+ -+ void const *alpha; -+ void const *beta; -+ ScalarPointerMode pointer_mode; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// OperationKind: kSparseGemm -+// -+ -+/// Computes GEMM assumine one of the inputs has 2:4 structured sparsity. -+struct SparseGemmConfiguration { -+ -+ GemmUniversalMode mode; -+ gemm::GemmCoord problem_size; -+ int batch_count; /// number of sparse matrix products in batch -+ -+ int64_t lda; /// leading dimension of A operand -+ int64_t ldb; /// leading dimension of B operand -+ int64_t ldc; /// leading dimension of C operand -+ int64_t ldd; /// leading dimension of D operand -+ int64_t lde; /// leading dimension of E operand (metadata matrix) -+ -+ int64_t batch_stride_A; // stride between matrices -+ int64_t batch_stride_B; // stride between matrices -+ int64_t batch_stride_C; // stride between matrices -+ int64_t batch_stride_D; // stride between matrices -+ int64_t batch_stride_E; // stride between matrices -+}; -+ -+/// Arguments for sparse GEMMs -+struct SparseGemmArguments { -+ -+ void const *A; /// pointer to A matrix -+ void const *B; /// pointer to B matrix -+ void const *C; /// pointer to C matrix -+ void *D; /// pointer to D matrix -+ void const *E; /// pointer to E matric (metadata) -+ -+ void const *alpha; /// pointer to alpha scalar -+ void const *beta; /// pointer to beta scalar -+ ScalarPointerMode pointer_mode; /// enumerant indicating whether alpha/beta pointers are host -+ /// or device pointers. -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Configuration for basic Rank K update operations -+// -+// OperationKind: (Syrk, Herk, Syr2k, Her2k) -+// RankKKind: Universal -+// -+struct RankKConfiguration { -+ -+ /// SYRK problem size -+ gemm::GemmCoord problem_size; -+ -+ /// Leading dimension of A matrix -+ int64_t lda; -+ -+ /// Leading dimension of B matrix -+ int64_t ldb; -+ -+ /// Leading dimension of C matrix -+ int64_t ldc; -+ -+ /// Leading dimension of D matrix -+ int64_t ldd; -+ -+ /// Batch Count -+ int batch_count; -+}; -+ -+/// Arguments for (Syrk, Herk, Syr2k, Her2k) -+struct RankKArguments { -+ -+ /// Pointer to A matrix -+ void const *A; -+ -+ /// Pointer to B matrix (used only for Syr2k and Her2k) -+ void const *B; -+ -+ /// Pointer to C matrix -+ void const *C; -+ -+ /// Pointer to D matrix -+ void *D; -+ -+ /// Host or device pointer to alpha scalar -+ void const *alpha; -+ -+ /// Host or device pointer to beta scalar -+ void const *beta; -+ -+ /// Enumerant indicating whether alpha/beta point to host or device memory -+ ScalarPointerMode pointer_mode; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C; -+ int64_t batch_stride_D; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Configuration for basic TRMM operations -+// -+// OperationKind: Trmm -+// TrmmKind: Universal -+// -+struct TrmmConfiguration { -+ -+ /// TRMM problem size -+ gemm::GemmCoord problem_size; -+ -+ /// Leading dimension of A matrix -+ int64_t lda; -+ -+ /// Leading dimension of B matrix -+ int64_t ldb; -+ -+ /// Leading dimension of D matrix -+ int64_t ldd; -+ -+ /// Batch Count -+ int batch_count; -+}; -+ -+/// Arguments for TRMM -+struct TrmmArguments { -+ -+ /// Pointer to A matrix -+ void const *A; -+ -+ /// Pointer to B matrix -+ void const *B; -+ -+ /// Pointer to D matrix -+ void *D; -+ -+ /// Host or device pointer to alpha scalar -+ void const *alpha; -+ -+ /// Host or device pointer to beta scalar -+ void const *beta; -+ -+ /// Enumerant indicating whether alpha/beta point to host or device memory -+ ScalarPointerMode pointer_mode; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_D; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Configuration for basic SYMM/HEMM update operations -+// -+// OperationKind: (Symm, Hemm) -+// SymmKind: Universal -+// -+struct SymmConfiguration { -+ -+ /// SYMM/HEMM problem size -+ gemm::GemmCoord problem_size; -+ -+ /// Leading dimension of A matrix -+ int64_t lda; -+ -+ /// Leading dimension of B matrix -+ int64_t ldb; -+ -+ /// Leading dimension of C matrix -+ int64_t ldc; -+ -+ /// Leading dimension of D matrix -+ int64_t ldd; -+ -+ /// Batch Count -+ int batch_count; -+}; -+ -+/// Arguments for (Symm, Hemm) -+struct SymmArguments { -+ -+ /// Pointer to A matrix -+ void const *A; -+ -+ /// Pointer to B matrix -+ void const *B; -+ -+ /// Pointer to C matrix -+ void const *C; -+ -+ /// Pointer to D matrix -+ void *D; -+ -+ /// Host or device pointer to alpha scalar -+ void const *alpha; -+ -+ /// Host or device pointer to beta scalar -+ void const *beta; -+ -+ /// Enumerant indicating whether alpha/beta point to host or device memory -+ ScalarPointerMode pointer_mode; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C; -+ int64_t batch_stride_D; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Two dimensional convolution -+// -+// OperationKind: Conv2d -+// -+struct Conv2dConfiguration { -+ -+ conv::SplitKMode split_k_mode; -+ -+ /// Conv2d problem size -+ // contains strictly conv2d size (N,H,W,C,K,R,S,P,Q,padding,stride,dilation,mode) -+ // also includes (split_k_slices, groups) -+ conv::Conv2dProblemSize problem_size; -+ -+ // stride of operand A -+ std::vector stride_a; -+ -+ // stride of operand B -+ std::vector stride_b; -+ -+ // stride of operand C -+ std::vector stride_c; -+}; -+ -+ -+/// Three dimensional convolution -+// -+// OperationKind: Conv3d -+// -+struct Conv3dConfiguration { -+ -+ conv::SplitKMode split_k_mode; -+ -+ /// Conv2d problem size -+ // contains strictly conv2d size (N,D,H,W,C,K,T,R,S,Z,P,Q,padding,stride,dilation,mode) -+ // also includes (split_k_slices, groups) -+ conv::Conv3dProblemSize problem_size; -+ -+ /// Layout object for activations tensor -+ layout::TensorNDHWC layout_activations; -+ -+ /// Layout object for filters tensor -+ layout::TensorNDHWC layout_filters; -+ -+ /// Layout object for source tensor -+ layout::TensorNDHWC layout_source; -+ -+ /// Layout object for output tensor -+ layout::TensorNDHWC layout_output; -+ -+ // -+ // Methods -+ // -+ -+ // Mapping functions (A,B,C -> activation,filter,output) -+ layout::TensorNDHWC layout_a(library::ConvKind const &conv_kind) const { -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: return layout_activations; -+ case library::ConvKind::kDgrad: return layout_output; -+ case library::ConvKind::kWgrad: return layout_output; -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ layout::TensorNDHWC layout_b(library::ConvKind const &conv_kind) const { -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: return layout_filters; -+ case library::ConvKind::kDgrad: return layout_filters; -+ case library::ConvKind::kWgrad: return layout_activations; -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ layout::TensorNDHWC layout_c(library::ConvKind const &conv_kind) const { -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: return layout_output; -+ case library::ConvKind::kDgrad: return layout_activations; -+ case library::ConvKind::kWgrad: return layout_filters; -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+}; -+ -+/// Arguments for CONV -+struct ConvArguments { -+ -+ ///////////////////////////////////////////////////////// -+ /// ImplicitGemm matrices A, B, C, D -+ ///////////////////////////////////////////////////////// -+ /// pointer to implicit gemm matrix A -+ void const *A; -+ -+ /// pointer to implicit gemm matrix B -+ void const *B; -+ -+ /// pointer to reordered matrix B -+ void const *reordered_B; -+ -+ /// pointer to implicit gemm matrix C -+ void const *C; -+ -+ /// pointer to implicit gemm desitination matrix D -+ void *D; -+ -+ /// Host or device pointer to alpha scalar -+ void const *alpha; -+ -+ /// Host or device pointer to beta scalar -+ void const *beta; -+ -+ /// Enumerant indicating whether alpha/beta point to host or device memory -+ ScalarPointerMode pointer_mode; -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Configuration for Reduction operations -+// -+// OperationKind: Reduction -+// -+struct ReductionConfiguration { -+ -+ /// Redcution problem size -+ MatrixCoord problem_size; -+ -+ /// Number of partitions to reduce -+ int partitions; -+ -+ /// Number of lements between each partition -+ int64_t partition_stride; -+ -+ /// leading dimension of 'w'orksace operand -+ int64_t ldw; -+ -+ /// leading dimension of 's'ource operand -+ int64_t lds; -+ -+ /// leading dimension of 'd'estination operand -+ int64_t ldd; -+}; -+ -+/// Arguments for Reduction -+struct ReductionArguments { -+ -+ /// Pointer to workspace matrix -+ void const *workspace; -+ -+ /// Pointer to source matrix -+ void const *source; -+ -+ /// Pointer to destination matrix -+ void *destination; -+ -+ /// pointer to reference matrix -+ void *reference; -+ -+ /// Host or device pointer to alpha scalar -+ void const *alpha; -+ -+ /// Host or device pointer to beta scalar -+ void const *beta; -+ -+ /// Enumerant indicating whether alpha/beta point to host or device memory -+ ScalarPointerMode pointer_mode; -+}; -+ -+} // namespace library -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/include/cutlass/library/manifest.h b/3rdparty/cutlass/tools/library/include/cutlass/library/manifest.h -new file mode 100644 -index 0000000..abce958 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/include/cutlass/library/manifest.h -@@ -0,0 +1,110 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Manifest of CUTLASS Library -+ -+ This is the root of the data structure containing CUTLASS objects -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include "library.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// Forward declaration -+class Manifest; -+ -+// init and insert all cutlass gemm operations in manifest object (procedurally generated using generator.py) -+void initialize_all(Manifest &manifest); -+ -+// init and insert all reduction op in manifest object (manually instantiated in library/reduction) -+void initialize_all_reduction_op(Manifest &manifest); -+ -+///////////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// List of operations -+using OperationVector = std::vector>; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Manifest of CUTLASS Library -+class Manifest { -+private: -+ -+ /// Operation provider -+ Provider provider_; -+ -+ /// Global list of operations -+ OperationVector operations_; -+ -+public: -+ Manifest (Provider provider = library::Provider::kCUTLASS) : provider_(provider) { } -+ -+ /// Top-level initialization -+ Status initialize(); -+ -+ /// Used for initialization -+ void reserve(size_t operation_count); -+ -+ /// Graceful shutdown -+ Status release(); -+ -+ /// Appends an operation and takes ownership -+ void append(Operation *operation_ptr); -+ -+ /// Returns an iterator to the first operation -+ OperationVector const &operations() const; -+ -+ /// Returns a const iterator -+ OperationVector::const_iterator begin() const; -+ -+ /// Returns a const iterator -+ OperationVector::const_iterator end() const; -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/include/cutlass/library/operation_table.h b/3rdparty/cutlass/tools/library/include/cutlass/library/operation_table.h -new file mode 100644 -index 0000000..037703f ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/include/cutlass/library/operation_table.h -@@ -0,0 +1,508 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* -+ \file -+ \brief Defines a data structure in which a set of functionally equivalent library::Operation -+ instances may be queried. -+*/ -+ -+#pragma once -+#include -+#include -+#include -+#include -+ -+#include "cutlass/library/library.h" -+#include "cutlass/library/manifest.h" -+#include "cutlass/library/util.h" -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// Data Structures for Gemm Functional Maps -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tuple uniquely identifying Gemm functional behavior -+struct GemmFunctionalKey { -+ -+ Provider provider; -+ GemmKind gemm_kind; -+ NumericTypeID element_compute; -+ NumericTypeID element_scalar; -+ NumericTypeID element_A; -+ LayoutTypeID layout_A; -+ ComplexTransform transform_A; -+ NumericTypeID element_B; -+ LayoutTypeID layout_B; -+ ComplexTransform transform_B; -+ NumericTypeID element_C; -+ -+ // -+ // Methods -+ // -+ -+ inline -+ GemmFunctionalKey( -+ Provider provider, -+ GemmKind gemm_kind = GemmKind::kGemm, -+ NumericTypeID element_compute = NumericTypeID::kF32, -+ NumericTypeID element_scalar = NumericTypeID::kF32, -+ NumericTypeID element_A = NumericTypeID::kF16, -+ LayoutTypeID layout_A = LayoutTypeID::kColumnMajor, -+ ComplexTransform transform_A = ComplexTransform::kNone, -+ NumericTypeID element_B = NumericTypeID::kF16, -+ LayoutTypeID layout_B = LayoutTypeID::kColumnMajor, -+ ComplexTransform transform_B = ComplexTransform::kNone, -+ NumericTypeID element_C = NumericTypeID::kF16 -+ ): -+ provider(provider), -+ gemm_kind(gemm_kind), -+ element_compute(element_compute), -+ element_scalar(element_scalar), -+ element_A(element_A), -+ layout_A(layout_A), -+ transform_A(transform_A), -+ element_B(element_B), -+ layout_B(layout_B), -+ transform_B(transform_B), -+ element_C(element_C) -+ { } -+ -+ inline -+ bool operator==(GemmFunctionalKey const &rhs) const { -+ return -+ (provider == rhs.provider) && -+ (gemm_kind == rhs.gemm_kind) && -+ (element_compute == rhs.element_compute) && -+ (element_scalar == rhs.element_scalar) && -+ (element_A == rhs.element_A) && -+ (layout_A == rhs.layout_A) && -+ (transform_A == rhs.transform_A) && -+ (element_B == rhs.element_B) && -+ (layout_B == rhs.layout_B) && -+ (transform_B == rhs.transform_B) && -+ (element_C == rhs.element_C); -+ } -+ -+ inline -+ bool operator!=(GemmFunctionalKey const &rhs) const { -+ return !(*this == rhs); -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+inline -+std::ostream & operator<<(std::ostream &out, cutlass::library::GemmFunctionalKey const &k) { -+ -+ out << "{\n" -+ << " provider: " << to_string(k.provider) << "\n" -+ << " gemm_kind: " << to_string(k.gemm_kind) << "\n" -+ << " element_compute: " << to_string(k.element_compute) << "\n" -+ << " element_scalar: " << to_string(k.element_scalar) << "\n" -+ << " element_A: " << to_string(k.element_A) << "\n" -+ << " layout_A: " << to_string(k.layout_A) << "\n" -+ << " transform_A: " << to_string(k.transform_A) << "\n" -+ << " element_B: " << to_string(k.element_B) << "\n" -+ << " layout_B: " << to_string(k.layout_B) << "\n" -+ << " transform_B: " << to_string(k.transform_B) << "\n" -+ << " element_C: " << to_string(k.element_C) << "\n" -+ << "}"; -+ -+ return out; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Hash function for GemmFunctionalKey -+struct GemmFunctionalKeyHasher { -+ using IntHash = std::hash; -+ -+ inline -+ static size_t rotl(size_t key, int shl) { -+ return (key << shl) | (key >> (sizeof(key)*8 - shl)); -+ } -+ -+ inline -+ size_t operator()(GemmFunctionalKey const &key) const { -+ IntHash hash; -+ -+ return -+ rotl(hash(int(key.provider)), 1) ^ -+ rotl(hash(int(key.gemm_kind)), 2) ^ -+ rotl(hash(int(key.element_compute)), 3) ^ -+ rotl(hash(int(key.element_scalar)), 4) ^ -+ rotl(hash(int(key.element_A)), 5) ^ -+ rotl(hash(int(key.layout_A)), 6) ^ -+ rotl(hash(int(key.transform_A)), 7) ^ -+ rotl(hash(int(key.element_B)), 8) ^ -+ rotl(hash(int(key.layout_B)), 9) ^ -+ rotl(hash(int(key.transform_B)), 10) ^ -+ rotl(hash(int(key.element_C)), 11); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Establishes a partial ordering to search for GEMM operators -+struct GemmPreferenceKey { -+ -+ int compute_capability; -+ int alignment; -+ -+ // -+ // Methods -+ // -+ -+ GemmPreferenceKey(): compute_capability(), alignment() { } -+ -+ GemmPreferenceKey(int cc, int alignment): compute_capability(cc), alignment(alignment) { } -+ -+ bool operator<(GemmPreferenceKey const &rhs) const { -+ return (compute_capability < rhs.compute_capability) || -+ ((compute_capability == rhs.compute_capability) && (alignment < rhs.alignment)); -+ } -+ -+ bool operator==(GemmPreferenceKey const &rhs) const { -+ return compute_capability == rhs.compute_capability; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Maps minimum compute capability onto a vector of possible operations -+using GemmOperationVectorMap = std::map< -+ GemmPreferenceKey, -+ std::vector -+>; -+ -+/// Maps a GemmFunctionalKey onto a vector of Operation * objects expected to be of kind kGemm -+using GemmOperationFunctionalMap = std::unordered_map< -+ GemmFunctionalKey, -+ GemmOperationVectorMap, -+ GemmFunctionalKeyHasher -+>; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// Data Structures for Conv Functional Maps -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Tuple uniquely identifying conv2d functional behavior -+struct ConvFunctionalKey { -+ library::Provider provider; -+ library::ConvKind conv_kind; -+ library::NumericTypeID element_A; -+ library::LayoutTypeID layout_A; -+ library::NumericTypeID element_B; -+ library::LayoutTypeID layout_B; -+ library::NumericTypeID element_C; -+ library::LayoutTypeID layout_C; -+ library::NumericTypeID element_accumulator; -+ library::NumericTypeID element_compute; -+ -+ -+ // -+ // Methods -+ // -+ -+ inline -+ ConvFunctionalKey( -+ library::Provider provider = library::Provider::kInvalid, -+ library::ConvKind conv_kind = library::ConvKind::kFprop, -+ library::NumericTypeID element_A = library::NumericTypeID::kF16, -+ library::LayoutTypeID layout_A = library::LayoutTypeID::kTensorNHWC, -+ library::NumericTypeID element_B = library::NumericTypeID::kF16, -+ library::LayoutTypeID layout_B = library::LayoutTypeID::kTensorNHWC, -+ library::NumericTypeID element_C = library::NumericTypeID::kF16, -+ library::LayoutTypeID layout_C = library::LayoutTypeID::kTensorNHWC, -+ library::NumericTypeID element_accumulator = library::NumericTypeID::kF32, -+ library::NumericTypeID element_compute = library::NumericTypeID::kF32 -+ ): -+ provider(provider), -+ conv_kind(conv_kind), -+ element_A(element_A), -+ layout_A(layout_A), -+ element_B(element_B), -+ layout_B(layout_B), -+ element_C(element_C), -+ layout_C(layout_C), -+ element_accumulator(element_accumulator), -+ element_compute(element_compute) -+ { } -+ -+ inline -+ bool operator==(ConvFunctionalKey const &rhs) const { -+ return -+ (provider == rhs.provider) && -+ (conv_kind == rhs.conv_kind) && -+ (element_A == rhs.element_A) && -+ (layout_A == rhs.layout_A) && -+ (element_B == rhs.element_B) && -+ (layout_B == rhs.layout_B) && -+ (element_C == rhs.element_C) && -+ (layout_C == rhs.layout_C) && -+ (element_accumulator == rhs.element_accumulator) && -+ (element_compute == rhs.element_compute); -+ } -+ -+ inline -+ bool operator!=(ConvFunctionalKey const &rhs) const { -+ return !(*this == rhs); -+ } -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+inline -+std::ostream& operator<< (std::ostream& out, const cutlass::library::ConvFunctionalKey& key) { -+ out << "{\n" -+ << "provider: " << to_string(key.provider) << std::endl -+ << "conv_kind: " << to_string(key.conv_kind) << std::endl -+ << "element_A: " << to_string(key.element_A) << std::endl -+ << "layout_A: " << to_string(key.layout_A) << std::endl -+ << "element_B: " << to_string(key.element_B) << std::endl -+ << "layout_B: " << to_string(key.layout_B) << std::endl -+ << "element_C: " << to_string(key.element_C) << std::endl -+ << "layout_C: " << to_string(key.layout_C) << std::endl -+ << "element_accumulator: " << to_string(key.element_accumulator) << std::endl -+ << "element_compute: " << to_string(key.element_compute) << std::endl -+ << "}"; -+ -+ return out; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+struct ConvFunctionalKeyHasher { -+ using IntHash = std::hash; -+ -+ inline -+ static size_t rotl(size_t key, int shl) { -+ return (key << shl) | (key >> (sizeof(key)*8 - shl)); -+ } -+ -+ inline -+ size_t operator()(ConvFunctionalKey const &key) const { -+ IntHash hash; -+ -+ return -+ rotl(hash(int(key.provider)), 1) ^ -+ rotl(hash(int(key.conv_kind)), 2) ^ -+ rotl(hash(int(key.element_A)), 3) ^ -+ rotl(hash(int(key.layout_A)), 4) ^ -+ rotl(hash(int(key.element_B)), 5) ^ -+ rotl(hash(int(key.layout_B)), 6) ^ -+ rotl(hash(int(key.element_C)), 7) ^ -+ rotl(hash(int(key.layout_C)), 8) ^ -+ rotl(hash(int(key.element_accumulator)), 9) ^ -+ rotl(hash(int(key.element_compute)), 10); -+ } -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Establishes a partial ordering to search for Conv2d operators -+struct ConvPreferenceKey { -+ -+ int compute_capability; -+ IteratorAlgorithmID iterator_algorithm; -+ -+ -+ // -+ // Methods -+ // -+ -+ ConvPreferenceKey(): compute_capability(), iterator_algorithm() { } -+ -+ ConvPreferenceKey(int cc, IteratorAlgorithmID iterator_algorithm): -+ compute_capability(cc), iterator_algorithm(iterator_algorithm) { } -+ -+ bool operator<(ConvPreferenceKey const &rhs) const { -+ return (compute_capability < rhs.compute_capability) || -+ ((compute_capability == rhs.compute_capability) && (iterator_algorithm < rhs.iterator_algorithm)); -+ } -+ -+ bool operator==(ConvPreferenceKey const &rhs) const { -+ return (compute_capability == rhs.compute_capability) && -+ (iterator_algorithm == rhs.iterator_algorithm); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Maps minimum compute capability onto a vector of possible operations -+using ConvOperationVectorMap = std::map< -+ ConvPreferenceKey, -+ std::vector -+>; -+ -+/// Maps a GemmFunctionalKey onto a vector of Operation * objects expected to be of kind kGemm -+using ConvOperationFunctionalMap = std::unordered_map< -+ ConvFunctionalKey, -+ ConvOperationVectorMap, -+ ConvFunctionalKeyHasher -+>; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Tuple uniquely identifying conv2d functional behavior -+struct ReductionFunctionalKey { -+ library::Provider provider; -+ library::NumericTypeID element_workspace; -+ library::NumericTypeID element_accumulator; -+ library::NumericTypeID element_output; -+ library::NumericTypeID element_compute; -+ library::MathOperationID reduce_math_op; -+ library::EpilogueKind epilogue_math_op; -+ -+ -+ // -+ // Methods -+ // -+ -+ inline -+ ReductionFunctionalKey( -+ library::Provider provider = library::Provider::kInvalid, -+ library::NumericTypeID element_workspace = library::NumericTypeID::kF16, -+ library::NumericTypeID element_accumulator = library::NumericTypeID::kF32, -+ library::NumericTypeID element_output = library::NumericTypeID::kF16, -+ library::NumericTypeID element_compute = library::NumericTypeID::kF32, -+ library::MathOperationID reduce_math_op = library::MathOperationID::kAdd, -+ library::EpilogueKind epilogue_math_op = library::EpilogueKind::kLinearCombination -+ ): -+ provider(provider), -+ element_workspace(element_workspace), -+ element_accumulator(element_accumulator), -+ element_output(element_output), -+ element_compute(element_compute), -+ reduce_math_op(reduce_math_op), -+ epilogue_math_op(epilogue_math_op) -+ { } -+ -+ inline -+ bool operator==(ReductionFunctionalKey const &rhs) const { -+ return -+ (provider == rhs.provider) && -+ (element_workspace == rhs.element_workspace) && -+ (element_accumulator == rhs.element_accumulator) && -+ (element_output == rhs.element_output) && -+ (element_compute == rhs.element_compute) && -+ (reduce_math_op == rhs.reduce_math_op) && -+ (epilogue_math_op == rhs.epilogue_math_op); -+ } -+ -+ inline -+ bool operator!=(ReductionFunctionalKey const &rhs) const { -+ return !(*this == rhs); -+ } -+}; -+ -+ -+struct ReductionFunctionalKeyHasher { -+ using IntHash = std::hash; -+ -+ inline -+ static size_t rotl(size_t key, int shl) { -+ return (key << shl) | (key >> (sizeof(key)*8 - shl)); -+ } -+ -+ inline -+ size_t operator()(ReductionFunctionalKey const &key) const { -+ IntHash hash; -+ -+ return -+ rotl(hash(int(key.provider)), 1) ^ -+ rotl(hash(int(key.element_workspace)), 2) ^ -+ rotl(hash(int(key.element_accumulator)), 3) ^ -+ rotl(hash(int(key.element_output)), 4) ^ -+ rotl(hash(int(key.element_compute)), 5) ^ -+ rotl(hash(int(key.reduce_math_op)), 6) ^ -+ rotl(hash(int(key.epilogue_math_op)), 7); -+ } -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+inline -+std::ostream& operator<< (std::ostream& out, const ReductionFunctionalKey& key) { -+ out << "{\n" -+ << "provider: " << library::to_string(key.provider) << std::endl -+ << "element_workspace : " << library::to_string(key.element_workspace) << std::endl -+ << "element_accumulator : " << library::to_string(key.element_accumulator) << std::endl -+ << "element_output : " << library::to_string(key.element_output) << std::endl -+ << "element_compute : " << library::to_string(key.element_compute) << std::endl -+ << "}"; -+ return out; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// ReductionOperationFunctionalMap has NO preference key and a single instance per functional key -+// i.e. only one tile size configuration per functional key -+using ReductionOperationFunctionalMap = std::unordered_map< -+ ReductionFunctionalKey, -+ library::Operation const *, -+ ReductionFunctionalKeyHasher -+>; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Table of cutlass::library::Operation instances -+class OperationTable { -+public: -+ -+ /// Map of all operations of type kGemm -+ // provider (kCUTLASS) -+ GemmOperationFunctionalMap gemm_operations; -+ -+ /// Map of all operations of type kConv2d -+ // provider (kCUTLASS, kReferenceHost, kReferenceDevice) -+ ConvOperationFunctionalMap conv2d_operations; -+ -+ /// Map of all operations of type kConv3d -+ // provider (kCUTLASS, kReferenceHost, kReferenceDevice) -+ ConvOperationFunctionalMap conv3d_operations; -+ -+ /// Map of all operations of type kConv2d -+ // provider (kCUTLASS) -+ ReductionOperationFunctionalMap reduction_operations; -+ -+public: -+ -+ void append(Manifest const &manifest); -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+std::ostream & operator<<(std::ostream &out, cutlass::library::GemmFunctionalKey const &k); -diff --git a/3rdparty/cutlass/tools/library/include/cutlass/library/singleton.h b/3rdparty/cutlass/tools/library/include/cutlass/library/singleton.h -new file mode 100644 -index 0000000..e0bd959 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/include/cutlass/library/singleton.h -@@ -0,0 +1,68 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include "cutlass/library/library.h" -+#include "cutlass/library/manifest.h" -+#include "cutlass/library/operation_table.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Singleton instance stores a Manifest and Operation table -+class Singleton { -+public: -+ -+ /// Manifest object -+ Manifest manifest; -+ -+ /// Operation table referencing the Manifest -+ OperationTable operation_table; -+ -+public: -+ -+ Singleton(); -+ -+ static Singleton const &get(); -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/library/include/cutlass/library/util.h b/3rdparty/cutlass/tools/library/include/cutlass/library/util.h -new file mode 100644 -index 0000000..517c6e9 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/include/cutlass/library/util.h -@@ -0,0 +1,197 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ -+ \brief Utilities accompanying the CUTLASS library for interacting with Library types. -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/library/library.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Lexical cast from string -+template T from_string(std::string const &); -+ -+/// Converts a Provider enumerant to a string -+char const *to_string(Provider provider, bool pretty = false); -+ -+/// Parses a Provider enumerant from a string -+template <> Provider from_string(std::string const &str); -+ -+/// Converts a GemmKind enumerant to a string -+char const *to_string(GemmKind type, bool pretty = false); -+ -+/// Converts a RankKKind enumerant to a string -+char const *to_string(RankKKind type, bool pretty = false); -+ -+/// Converts a TrmmKind enumerant to a string -+char const *to_string(TrmmKind type, bool pretty = false); -+ -+/// Converts a SymmKind enumerant to a string -+char const *to_string(SymmKind type, bool pretty = false); -+ -+/// Converts a SideMode enumerant to a string -+char const *to_string(SideMode type, bool pretty = false); -+ -+/// Converts a FillMode enumerant to a string -+char const *to_string(FillMode type, bool pretty = false); -+ -+/// Converts a BlasMode enumerant to a string -+char const *to_string(BlasMode type, bool pretty = false); -+ -+/// Converts a DiagType enumerant to a string -+char const *to_string(DiagType type, bool pretty = false); -+ -+/// Converts a NumericType enumerant to a string -+char const *to_string(OperationKind type, bool pretty = false); -+ -+/// Parses a NumericType enumerant from a string -+template <> OperationKind from_string(std::string const &str); -+ -+/// Converts a NumericType enumerant to a string -+char const *to_string(NumericTypeID type, bool pretty = false); -+ -+/// Parses a NumericType enumerant from a string -+template <> NumericTypeID from_string(std::string const &str); -+ -+/// Returns the size of a data type in bits -+int sizeof_bits(NumericTypeID type); -+ -+/// Returns true if the numeric type is a complex data type or false if real-valued. -+bool is_complex_type(NumericTypeID type); -+ -+/// Returns the real-valued type underlying a type (only different from 'type' if complex) -+NumericTypeID get_real_type(NumericTypeID type); -+ -+/// Returns true if numeric type is integer -+bool is_integer_type(NumericTypeID type); -+ -+/// Returns true if numeric type is signed -+bool is_signed_type(NumericTypeID type); -+ -+/// Returns true if numeric type is a signed integer -+bool is_signed_integer(NumericTypeID type); -+ -+/// returns true if numeric type is an unsigned integer -+bool is_unsigned_integer(NumericTypeID type); -+ -+/// Returns true if numeric type is floating-point type -+bool is_float_type(NumericTypeID type); -+ -+/// To string method for cutlass::Status -+char const *to_string(Status status, bool pretty = false); -+ -+/// Converts a LayoutTypeID enumerant to a string -+char const *to_string(LayoutTypeID layout, bool pretty = false); -+ -+/// Parses a LayoutType enumerant from a string -+template <> LayoutTypeID from_string(std::string const &str); -+ -+/// Returns the rank of a layout's stride base on the LayoutTypeID -+int get_layout_stride_rank(LayoutTypeID layout_id); -+ -+/// Converts a OpcodeClassID enumerant to a string -+char const *to_string(OpcodeClassID type, bool pretty = false); -+ -+/// Converts a OpcodeClassID enumerant from a string -+template <> -+OpcodeClassID from_string(std::string const &str); -+ -+/// Converts a ComplexTransform enumerant to a string -+char const *to_string(ComplexTransform type, bool pretty = false); -+ -+/// Converts a ComplexTransform enumerant from a string -+template <> -+ComplexTransform from_string(std::string const &str); -+ -+ -+/// Converts a SplitKMode enumerant to a string -+char const *to_string(SplitKMode split_k_mode, bool pretty = false); -+ -+/// Converts a SplitKMode enumerant from a string -+template <> -+SplitKMode from_string(std::string const &str); -+ -+/// Converts a ConvModeID enumerant to a string -+char const *to_string(ConvModeID type, bool pretty = false); -+ -+/// Converts a ConvModeID enumerant from a string -+template <> -+ConvModeID from_string(std::string const &str); -+ -+/// Converts a IteratorAlgorithmID enumerant to a string -+char const *to_string(IteratorAlgorithmID type, bool pretty = false); -+ -+/// Converts a IteratorAlgorithmID enumerant from a string -+template <> -+IteratorAlgorithmID from_string(std::string const &str); -+ -+/// Converts a ConvKind enumerant to a string -+char const *to_string(ConvKind type, bool pretty = false); -+ -+/// Converts a ConvKind enumerant from a string -+template <> -+ConvKind from_string(std::string const &str); -+ -+/// Lexical cast from int64_t to string -+std::string lexical_cast(int64_t int_value); -+ -+/// Lexical cast a string to a byte array. Returns true if cast is successful or false if invalid. -+bool lexical_cast(std::vector &bytes, NumericTypeID type, std::string const &str); -+ -+/// Lexical cast TO a string FROM a byte array. Returns true if cast is successful or false if invalid. -+std::string lexical_cast(std::vector &bytes, NumericTypeID type); -+ -+/// Casts from a signed int64 to the destination type. Returns true if successful. -+bool cast_from_int64(std::vector &bytes, NumericTypeID type, int64_t src); -+ -+/// Casts from an unsigned int64 to the destination type. Returns true if successful. -+bool cast_from_uint64(std::vector &bytes, NumericTypeID type, uint64_t src); -+ -+/// Casts from a real value represented as a double to the destination type. Returns true if successful. -+bool cast_from_double(std::vector &bytes, NumericTypeID type, double src); -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/compiler.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/compiler.h -new file mode 100644 -index 0000000..b8e60bc ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/compiler.h -@@ -0,0 +1,75 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief In-memory compiled artifact cache -+*/ -+ -+#include -+#include -+#include -+ -+ -+namespace py = pybind11; -+ -+namespace cutlass { -+ -+struct CompileCache { -+public: -+ CompileCache() = default; -+ ~CompileCache() = default; -+ -+ using Cache = std::unordered_map; -+ -+ /// Check if the kernel has already been compiled -+ py::object at(const std::string &kernel) { -+ auto item = cache_.find(kernel); -+ -+ if (item != cache_.end()) { -+ return item->second; -+ } -+ return py::none(); -+ } -+ -+ /// Insert a new compiled kernel for new configuration -+ void insert(const std::string &kernel, const py::object &compiled_kernel){ -+ cache_.emplace(kernel, compiled_kernel); -+ } -+ -+ const int64_t size() const { return cache_.size(); } -+ -+ /// Clear the cache -+ void clear() { cache_.clear(); } -+ -+private: -+ Cache cache_; -+}; -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/arch.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/arch.h -new file mode 100644 -index 0000000..21f9771 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/arch.h -@@ -0,0 +1,59 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Bind opcode classes to python -+*/ -+#pragma once -+#include -+#include -+ -+#include "cutlass/arch/mma.h" -+ -+namespace py = pybind11; -+ -+namespace cutlass { -+enum class OpcodeClass { -+ kSimt, kTensorOp, kWmmaTensorOp, kSparseTensorOp -+}; -+} -+ -+void bind_opcode(py::module &m) { -+ py::enum_(m, "OpClass", -+ R"pbdoc(classification of math operators)pbdoc") -+ .value("Simt", cutlass::OpcodeClass::kSimt, -+ R"pbdoc(Tag classifying math operators as thread-level operations)pbdoc") -+ .value("TensorOp", cutlass::OpcodeClass::kTensorOp, -+ R"pbdoc(Tag classifing operators as Tensor Core operations)pbdoc") -+ .value("WmmaTensorOp", cutlass::OpcodeClass::kWmmaTensorOp, -+ R"pbdoc(Tag classifing operators as WMMA Tensor Core operations)pbdoc") -+ .value("SparseTensorOp", cutlass::OpcodeClass::kSparseTensorOp, -+ R"pbdoc(Tag classifing operators as sparseTensor Core operations)pbdoc"); -+} -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/conv/conv_problem_size.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/conv/conv_problem_size.h -new file mode 100644 -index 0000000..ab4a067 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/conv/conv_problem_size.h -@@ -0,0 +1,102 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Bind Convolution problem sizes to python -+*/ -+#pragma once -+#include -+#include -+ -+#include "cutlass/conv/conv2d_problem_size.h" -+ -+namespace py = pybind11; -+ -+void bind_conv_problem_size(py::module &m) { -+ // -+ // Conv2d Problem Size: -+ // include/cutlass/conv/conv2d_problem_sizd.h -+ // -+ py::class_(m, "Conv2dProblemSize") -+ // constructors -+ .def(py::init()) -+ .def(py::init()) -+ // attribute accessors -+ .def_readwrite("N", &cutlass::conv::Conv2dProblemSize::N) -+ .def_readwrite("H", &cutlass::conv::Conv2dProblemSize::H) -+ .def_readwrite("W", &cutlass::conv::Conv2dProblemSize::W) -+ .def_readwrite("C", &cutlass::conv::Conv2dProblemSize::C) -+ .def_readwrite("P", &cutlass::conv::Conv2dProblemSize::P) -+ .def_readwrite("Q", &cutlass::conv::Conv2dProblemSize::Q) -+ .def_readwrite("K", &cutlass::conv::Conv2dProblemSize::K) -+ .def_readwrite("R", &cutlass::conv::Conv2dProblemSize::R) -+ .def_readwrite("S", &cutlass::conv::Conv2dProblemSize::S) -+ .def_readwrite("pad_h", &cutlass::conv::Conv2dProblemSize::pad_h) -+ .def_readwrite("pad_w", &cutlass::conv::Conv2dProblemSize::pad_w) -+ .def_readwrite("stride_h", &cutlass::conv::Conv2dProblemSize::stride_h) -+ .def_readwrite("stride_w", &cutlass::conv::Conv2dProblemSize::stride_w) -+ .def_readwrite("dilation_h", &cutlass::conv::Conv2dProblemSize::dilation_h) -+ .def_readwrite("dilation_w", &cutlass::conv::Conv2dProblemSize::dilation_w) -+ .def_readwrite("mode", &cutlass::conv::Conv2dProblemSize::mode) -+ .def_readwrite("split_k_slices", &cutlass::conv::Conv2dProblemSize::split_k_slices) -+ .def_readwrite("groups", &cutlass::conv::Conv2dProblemSize::groups) -+ // functions -+ .def("reset_split_k_slices", &cutlass::conv::Conv2dProblemSize::reset_split_k_slices) -+ .def("activation_extent", &cutlass::conv::Conv2dProblemSize::activation_extent) -+ .def("filter_extent", &cutlass::conv::Conv2dProblemSize::filter_extent) -+ .def("output_extent", &cutlass::conv::Conv2dProblemSize::output_extent) -+ .def("activation_size", &cutlass::conv::Conv2dProblemSize::activation_size) -+ .def("filter_size", &cutlass::conv::Conv2dProblemSize::filter_size) -+ .def("output_size", &cutlass::conv::Conv2dProblemSize::output_size); -+ -+ // Get tensor size -+ m.def("implicit_gemm_tensor_a_size", py::overload_cast(&cutlass::conv::implicit_gemm_tensor_a_size)); -+ m.def("implicit_gemm_tensor_b_size", py::overload_cast(&cutlass::conv::implicit_gemm_tensor_b_size)); -+ m.def("implicit_gemm_tensor_c_size", py::overload_cast(&cutlass::conv::implicit_gemm_tensor_c_size)); -+ -+ // Get tensor extent -+ m.def("implicit_gemm_tensor_a_extent", -+ py::overload_cast< -+ cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize& -+ >(&cutlass::conv::implicit_gemm_tensor_a_extent)); -+ -+ m.def("implicit_gemm_tensor_b_extent", -+ py::overload_cast< -+ cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize& -+ >(&cutlass::conv::implicit_gemm_tensor_b_extent)); -+ -+ m.def("implicit_gemm_tensor_c_extent", -+ py::overload_cast< -+ cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize& -+ >(&cutlass::conv::implicit_gemm_tensor_c_extent)); -+ -+ m.def("implicit_gemm_problem_size", py::overload_cast(&cutlass::conv::implicit_gemm_problem_size)); -+ -+} -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/conv/convolution.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/conv/convolution.h -new file mode 100644 -index 0000000..36126ec ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/conv/convolution.h -@@ -0,0 +1,91 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Bind convolution related enum types to python -+*/ -+#pragma once -+#include -+#include -+ -+#include "conv_problem_size.h" -+#include "host.h" -+#include "cutlass/conv/convolution.h" -+ -+namespace py = pybind11; -+ -+void bind_convolution(py::module &m) { -+ // -+ // Enumerate types -+ // cutlass/include/cutlass/conv/convolution.h -+ // -+ -+ /// Convolutional operator -+ py::enum_(m, "Operator", R"pbdoc(Convolutional operator)pbdoc") -+ .value("fprop", cutlass::conv::Operator::kFprop, "Forward propagation") -+ .value("dgrad", cutlass::conv::Operator::kDgrad, "Activation grad") -+ .value("wgrad", cutlass::conv::Operator::kWgrad, "Weight grad"); -+ -+ /// Distinguishes convolution from cross correlation -+ py::enum_(m, "Mode") -+ .value("cross_correlation", cutlass::conv::Mode::kCrossCorrelation) -+ .value("convolution", cutlass::conv::Mode::kConvolution); -+ -+ /// Selects among several implementation variants trading off performance with simplicity -+ py::enum_(m, "IteratorAlgorithm", -+ R"pbdoc(Selects among several implementation variants trading off performance with simplicity)pbdoc") -+ .value("analytic", cutlass::conv::IteratorAlgorithm::kAnalytic, R"pbdoc(functionally correct in all cases but lower performance)pbdoc") -+ .value("optimized", cutlass::conv::IteratorAlgorithm::kOptimized, R"pbdoc(optimized for R <= 32, S <= 32 and unity-stride dgrad)pbdoc") -+ .value("fixed_channels", cutlass::conv::IteratorAlgorithm::kFixedChannels, R"pbdoc(Analytic algorithm optimized for fixed channel count (C == AccessSize))pbdoc") -+ .value("few_channels", cutlass::conv::IteratorAlgorithm::kFewChannels, R"pbdoc(Analytic algorithm optimized for few channels (C divisible by AccessSize))pbdoc"); -+ -+ /// Distinguishes among partial specializations that accelerate certain problems where convolution -+ /// stride is unit. -+ py::enum_(m, "StrideSupport", -+ R"pbdoc(Distinguishes among partial specializations that accelerate certain problems where convolution -+ stride is unit.)pbdoc") -+ .value("strided", cutlass::conv::StrideSupport::kStrided, R"pbdoc(arbitrary convolution stride)pbdoc") -+ .value("unity", cutlass::conv::StrideSupport::kUnity, R"pbdoc(unit convolution stride)pbdoc"); -+ -+ /// Identifies split-K mode -+ py::enum_(m, "SplitKMode") -+ .value("None", cutlass::conv::SplitKMode::kNone) -+ .value("Serial", cutlass::conv::SplitKMode::kSerial) -+ .value("Parallel", cutlass::conv::SplitKMode::kParallel); -+ -+ // Conv problem sizes -+ bind_conv_problem_size(m); -+ -+ // -+ // host helper functions -+ // -+ py::module_ host_submodule = m.def_submodule("host"); -+ bind_conv_host_helper(host_submodule); -+} -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/conv/host.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/conv/host.h -new file mode 100644 -index 0000000..7a33251 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/conv/host.h -@@ -0,0 +1,54 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Bind conv host helpers to python -+*/ -+#pragma once -+#include -+#include -+ -+#include "cutlass/util/host_reorder.h" -+#include "cutlass/layout/tensor.h" -+ -+namespace py = pybind11; -+ -+ -+void bind_conv_host_helper(py::module &m) { -+ -+ /// reorder operand B for interleaved layout -+ m.def("reorder_convK", []( -+ cutlass::TensorRef> dest, -+ cutlass::TensorRef> src, -+ cutlass::conv::Operator conv_op, const cutlass::conv::Conv2dProblemSize & problem_size) { -+ cutlass::gemm::GemmCoord implicit_problem_size = cutlass::conv::implicit_gemm_problem_size(conv_op, problem_size); -+ cutlass::reorder_convK<32>(dest, src, implicit_problem_size); -+ }); -+} -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_generic.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_generic.h -new file mode 100644 -index 0000000..6b33f9a ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_generic.h -@@ -0,0 +1,222 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this layernormware without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ -+ \brief A generic wrapper around an epilogue visitor operation -+*/ -+ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/gemm/kernel/gemm_transpose_operands.h" -+#include "cutlass/gemm/kernel/default_gemm.h" -+#include "cutlass/gemm/kernel/default_gemm_complex.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h" -+ -+#include "epilogue_visitor_op/visitor_op_linear_combination.h" -+#include "epilogue_visitor_op/visitor_op_tensor_input.h" -+#include "epilogue_visitor_op/visitor_op_accumulator.h" -+#include "epilogue_visitor_op/visitor_op_row_broadcast.h" -+#include "epilogue_visitor_op/visitor_op_tensor_output.h" -+#include "epilogue_visitor_op/visitor_op_column_reduction.h" -+#include "epilogue_visitor_op/visitor_op_row_reduction.h" -+#include "epilogue_visitor_op/visitor_op_column_broadcast.h" -+#include "epilogue_visitor_op/visitor_op_unary.h" -+#include "epilogue_visitor_op/visitor_op_binary.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Generic Epilogue Visitor. -+template < -+ typename OutputOp_ -+> -+class EpilogueVisitorGeneric { -+public: -+ -+ using OutputOp = OutputOp_; -+ using AccumulatorAccessType = typename OutputOp::AccumulatorAccessType; -+ static int const kElementsPerAccess = OutputOp::kElementsPerAccess; -+ using ElementOutput = typename OutputOp::ElementOutput; -+ using OutputTileIterator = typename OutputOp::OutputTileIterator; -+ -+ static int const kIterations = OutputTileIterator::kIterations; -+ -+ /// -+ /// End Epilogue Tree -+ /// -+ -+ /// Additional SMEM bufer is not required in the broadcast epilogue visitor -+ struct SharedStorage { -+ -+ typename OutputOp::SharedStorage output_smem; -+ CUTLASS_HOST_DEVICE -+ SharedStorage() { } -+ }; -+ -+public: -+ -+ /// Argument structure -+ struct Arguments { -+ typename OutputOp::Arguments output_op_args; -+ // -+ // Methods -+ // -+ Arguments() { } -+ -+ Arguments( -+ typename OutputOp::Arguments output_op_args -+ ): -+ output_op_args(output_op_args) -+ { -+ -+ } -+ }; -+ -+ struct Params { -+ typename OutputOp::Params output_op_params; -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args): -+ output_op_params(args.output_op_args) -+ { -+ -+ } -+ }; -+ -+ -+ -+private: -+ -+ OutputOp output_op; -+ -+public: -+ -+ /// Constructor -+ CUTLASS_DEVICE -+ EpilogueVisitorGeneric( -+ Params const ¶ms, ///< Parameters routed to the epilogue -+ SharedStorage &shared_storage, ///< Shared storage needed by the functors here -+ MatrixCoord threadblock_offset, -+ gemm::GemmCoord threadblock_tile_offset, -+ int thread_idx, -+ MatrixCoord problem_size -+ ): -+ output_op(params.output_op_params, shared_storage.output_smem, thread_idx, threadblock_offset, problem_size) -+ { } -+ -+ /// Helper to indicate split-K behavior -+ CUTLASS_DEVICE -+ void set_k_partition( -+ int split_k_index, ///< Index of this threadblock within split-K partitioned scheme -+ int split_k_slices) { ///< Total number of split-K slices -+ -+ } -+ -+ /// Called to set the batch index -+ CUTLASS_DEVICE -+ void set_batch_index(int batch_idx) { -+ output_op.set_batch_index(batch_idx); -+ } -+ -+ /// Called at the start of the epilogue just before iterating over accumulator slices -+ CUTLASS_DEVICE -+ void begin_epilogue() { -+ output_op.begin_epilogue(); -+ } -+ -+ /// Called at the start of one step before starting accumulator exchange -+ CUTLASS_DEVICE -+ void begin_step(int step_idx) { -+ output_op.begin_step(step_idx); -+ } -+ -+ /// Called at the start of a row -+ CUTLASS_DEVICE -+ void begin_row(int row_idx) { -+ output_op.begin_row(row_idx); -+ } -+ -+ /// Called after accumulators have been exchanged for each accumulator vector -+ CUTLASS_DEVICE -+ void visit( -+ int iter_idx, -+ int row_idx, -+ int column_idx, -+ int frag_idx, -+ AccumulatorAccessType const &accum) { -+ output_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum); -+ } -+ -+ /// Called at the start of a row -+ CUTLASS_DEVICE -+ void end_row(int row_idx) { -+ output_op.end_row(row_idx); -+ -+ } -+ -+ /// Called after all accumulator elements have been visited -+ CUTLASS_DEVICE -+ void end_step(int step_idx) { -+ output_op.end_step(step_idx); -+ } -+ -+ /// Called after all steps have been completed -+ CUTLASS_DEVICE -+ void end_epilogue() { -+ output_op.end_epilogue(); -+ } -+ -+}; -+ -+//////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace threadblock -+} // namespace epilogue -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/binary_ops.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/binary_ops.h -new file mode 100644 -index 0000000..f64066a ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/binary_ops.h -@@ -0,0 +1,84 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this layernormware without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ -+ \brief A file contains the binary ops -+*/ -+ -+#pragma once -+#include "cutlass/cutlass.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Scalar multiplication -+template -+struct VectorAdd { -+ -+ struct Arguments { -+ int tmp; -+ -+ CUTLASS_HOST_DEVICE -+ Arguments():tmp(0){ } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments(int tmp): tmp(tmp) { } -+ }; -+ -+ struct Params { -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args) { } -+ }; -+ -+ CUTLASS_HOST_DEVICE -+ VectorAdd( -+ Params const ¶ms -+ ) { } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &lhs, Array const &rhs) const { -+ cutlass::plus> add_op; -+ return add_op(lhs, rhs); -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/unary_ops.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/unary_ops.h -new file mode 100644 -index 0000000..9952a52 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/unary_ops.h -@@ -0,0 +1,233 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this layernormware without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ -+ \brief A file contains the unary ops -+*/ -+ -+#pragma once -+#include "cutlass/cutlass.h" -+#include "cutlass/epilogue/thread/activation.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Scalar multiplication -+template -+struct Mult { -+ -+ struct Arguments { -+ T alpha; -+ -+ CUTLASS_HOST_DEVICE -+ Arguments():alpha(T(1.0)){ } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments(T alpha): alpha(alpha) { } -+ }; -+ -+ struct Params { -+ T alpha; ///< scales accumulators -+ -+ CUTLASS_HOST_DEVICE -+ Params():alpha(T(1.0)){ } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args): alpha(args.alpha) { } -+ }; -+ -+ T alpha_; -+ -+ CUTLASS_HOST_DEVICE -+ Mult( -+ Params const ¶ms -+ ): -+ alpha_(params.alpha) -+ { } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &source) const { -+ cutlass::multiplies> multiply_op; -+ return multiply_op(source, alpha_); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool guard() { -+ return alpha_ != T(0); -+ } -+ -+}; -+ -+ -+/// ReLU -+template -+struct ReLUVisitor { -+ struct Arguments { -+ T threshold; -+ -+ CUTLASS_HOST_DEVICE -+ Arguments():threshold(T(0.0)) { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments(T threshold): threshold(threshold) { } -+ }; -+ -+ struct Params { -+ T threshold; -+ -+ CUTLASS_HOST_DEVICE -+ Params():threshold(T(0.0)) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args): threshold(args.threshold) { } -+ }; -+ -+ T threshold_; -+ -+ CUTLASS_HOST_DEVICE -+ ReLUVisitor(Params const ¶ms): -+ threshold_(params.threshold) { } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &frag) const { -+ maximum> mx; -+ return mx(frag, threshold_); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool guard() { -+ return true; -+ } -+}; -+ -+/// leakyReLU -+template -+struct LeakyReLUVisitor { -+ struct Arguments { -+ T leaky_alpha; -+ -+ CUTLASS_HOST_DEVICE -+ Arguments():leaky_alpha(T(0.0)) { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments(T leaky_alpha): leaky_alpha(leaky_alpha) { } -+ }; -+ -+ struct Params { -+ T leaky_alpha; -+ -+ CUTLASS_HOST_DEVICE -+ Params():leaky_alpha(T(0.0)) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args): leaky_alpha(args.leaky_alpha) { } -+ }; -+ -+ T leaky_alpha_; -+ -+ CUTLASS_HOST_DEVICE -+ LeakyReLUVisitor(Params const ¶ms): -+ leaky_alpha_(params.leaky_alpha) { } -+ -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &frag) const { -+ cutlass::epilogue::thread::LeakyReLU> leaky_op; -+ return leaky_op(frag, leaky_alpha_); -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool guard() { -+ return true; -+ } -+ -+}; -+ -+/// Tanh -+template -+struct TanhVisitor { -+ /// Argument -+ struct Arguments { -+ // a placeholder argument to ensure correctness of ctypes -+ int tmp; -+ -+ CUTLASS_HOST_DEVICE -+ Arguments(): tmp(0) { }; -+ -+ CUTLASS_HOST_DEVICE -+ Arguments(int tmp): tmp(tmp) { }; -+ }; -+ -+ /// Param -+ struct Params { -+ CUTLASS_HOST_DEVICE -+ Params(){ }; -+ Params(Arguments const &args) { } -+ }; -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ TanhVisitor(Params const ¶ms) { } -+ -+ // scalar operator -+ CUTLASS_HOST_DEVICE -+ T tanh_op(T const &scalar) const { -+ return fast_tanh(scalar); -+ } -+ -+ /// vector operator -+ CUTLASS_HOST_DEVICE -+ Array operator()(Array const &frag) const { -+ Array y; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i=0; i < N; ++i) { -+ y[i] = tanh_op(frag[i]); -+ } -+ -+ return y; -+ } -+ -+ CUTLASS_HOST_DEVICE -+ bool guard() { -+ return true; -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_accumulator.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_accumulator.h -new file mode 100644 -index 0000000..2072cfa ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_accumulator.h -@@ -0,0 +1,148 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this layernormware without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ -+ \brief A file contains the epilogue visitor Op with accumulator -+*/ -+ -+#pragma once -+#include "cutlass/cutlass.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Epilogue Visitor operator for the following Computation -+/// -+/// ElementAccumulator accum; -+/// return accum; -+/// -+/// It can only be the leaf node of the epilogue tree -+ -+template < -+ typename ElementAccumulator_, ///< Data type of the Accumulator -+ int kElementsPerAccess_ ///< Number of elements computed per operation -+> -+class VisitorOpAccumulator{ -+public: -+ using ElementAccumulator = ElementAccumulator_; -+ static int const kElementsPerAccess = kElementsPerAccess_; -+ -+ /// Fragment type for Accumulator -+ using AccumulatorAccessType = Array; -+ -+ /// Fragment type returned by this visitor -+ using VisitAccessType = AccumulatorAccessType; -+ -+ /// SMEM buffer class required in the epilogue visitor -+ struct SharedStorage { -+ CUTLASS_HOST_DEVICE -+ SharedStorage() {} -+ }; -+ -+ /// Host-constructable Arguments structure -+ struct Arguments { -+ // Note: it is strange that ctypes will return issue with empty arguments -+ int tmp; -+ -+ CUTLASS_HOST_DEVICE -+ Arguments() { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments(int tmp): tmp(tmp) { } -+ }; -+ -+ /// Parameter structure -+ struct Params { -+ -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args) { } -+ }; -+ -+public: -+ /// Constructs the function object -+ CUTLASS_HOST_DEVICE -+ VisitorOpAccumulator( -+ Params const ¶ms, -+ SharedStorage &shared_storage, -+ int thread_idx, -+ MatrixCoord threadblock_offset, -+ MatrixCoord problem_size -+ ) { } -+ -+ CUTLASS_DEVICE -+ void set_batch_index(int batch_idx) { } -+ -+ CUTLASS_DEVICE -+ void begin_epilogue() { } -+ -+ CUTLASS_DEVICE -+ void begin_step(int step_idx) { } -+ -+ CUTLASS_DEVICE -+ void begin_row(int row_idx) { } -+ -+ CUTLASS_DEVICE -+ VisitAccessType visit( -+ int iter_idx, -+ int row_idx, -+ int column_idx, -+ int frag_idx, -+ AccumulatorAccessType const &accum -+ ) { -+ return accum; -+ } -+ -+ CUTLASS_DEVICE -+ void end_row(int row_idx) { } -+ -+ CUTLASS_DEVICE -+ void end_step(int step_idx) { } -+ -+ CUTLASS_DEVICE -+ void end_epilogue() { } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_binary.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_binary.h -new file mode 100644 -index 0000000..d9fa445 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_binary.h -@@ -0,0 +1,245 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this layernormware without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ -+ \brief A file contains the epilogue visitor Op with Binary op -+*/ -+ -+#pragma once -+#include "cutlass/cutlass.h" -+#include "binary_ops.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Epilogue Visitor operator for the following computation: -+/// -+/// ElementCompute alpha; -+/// ElementCompute beta; -+/// ElementCompute C = BinaryOp(alpha * ElementCompute(Visitor_A), beta * ElementCompute(Visitor_B) -+/// Return C; -+/// -+template < -+ typename ElementAccumulator_, ///< Data type of the Accumulator -+ typename ElementCompute_, ///< Data type used to compute linear combination -+ int kElementsPerAccess_, ///< Number of elements computed per operation -+ typename VisitorA_, ///< Child node A -+ typename VisitorB_, ///< Child node B -+ template typename BinaryOp_ -+> -+class VisitorOpBinary{ -+public: -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementCompute_; -+ static int const kElementsPerAccess = kElementsPerAccess_; -+ -+ using VisitorA = VisitorA_; -+ using VisitorB = VisitorB_; -+ -+ /// Fragment type returned from VisitorA.visit -+ using VisitAccessTypeA = typename VisitorA::VisitAccessType; -+ using ElementA = typename VisitAccessTypeA::Element; -+ -+ /// Fragment type returned from VisitorB.visit -+ using VisitAccessTypeB = typename VisitorB::VisitAccessType; -+ using ElementB = typename VisitAccessTypeB::Element; -+ -+ /// Fragment type returned by this visitor -+ using VisitAccessType = Array; -+ -+ /// Fragment type of accumulator -+ using AccumulatorAccessType = Array; -+ -+ using BinaryOp = BinaryOp_; -+ -+ static_assert(kElementsPerAccess==VisitAccessTypeA::kElements, "kElementsPerAccess mismatches with Visitor A"); -+ static_assert(kElementsPerAccess==VisitAccessTypeB::kElements, "kElementsPerAccess misnatches with Visitor B"); -+ -+ /// SMEM buffer class required in the epilogue visitor -+ struct SharedStorage { -+ typename VisitorA::SharedStorage storage_a; -+ typename VisitorB::SharedStorage storage_b; -+ -+ CUTLASS_HOST_DEVICE -+ SharedStorage() {} -+ }; -+ -+ -+ /// Host-constructable Arguments structure -+ struct Arguments { -+ typename BinaryOp::Arguments binary_arg; -+ typename VisitorA::Arguments visitor_a_arg; ///< Argument type for visitor_a -+ typename VisitorB::Arguments visitor_b_arg; ///< Argument type for visitor_b -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Arguments():binary_arg() { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ typename BinaryOp::Arguments binary_arg, -+ typename VisitorA::Arguments visitor_a_arg, -+ typename VisitorB::Arguments visitor_b_arg -+ ): -+ binary_arg(binary_arg), -+ visitor_a_arg(visitor_a_arg), -+ visitor_b_arg(visitor_b_arg) -+ { } -+ }; -+ -+ /// Parameter structure -+ struct Params { -+ typename BinaryOp::Params binary_param; -+ typename VisitorA::Params visitor_a_param; ///< Argument type for visitor_a -+ typename VisitorB::Params visitor_b_param; ///< Argument type for visitor_b -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args): -+ binary_param(args.binary_arg), -+ visitor_a_param(args.visitor_a_arg), -+ visitor_b_param(args.visitor_b_arg) -+ { } -+ }; -+ -+private: -+ // -+ // Data members -+ // -+ -+ BinaryOp binary_op; -+ -+ VisitorA visitor_a_op; -+ VisitorB visitor_b_op; -+ -+public: -+ -+ /// Constructs the function object -+ CUTLASS_HOST_DEVICE -+ VisitorOpBinary( -+ Params const ¶ms, -+ SharedStorage &shared_storage, -+ int thread_idx, -+ MatrixCoord threadblock_offset, -+ MatrixCoord problem_size -+ ): -+ binary_op(params.binary_param), -+ visitor_a_op(params.visitor_a_param, shared_storage.storage_a, thread_idx, threadblock_offset, problem_size), -+ visitor_b_op(params.visitor_b_param, shared_storage.storage_b, thread_idx, threadblock_offset, problem_size) -+ { } -+ -+ -+ CUTLASS_DEVICE -+ void begin_epilogue() { -+ visitor_a_op.begin_epilogue(); -+ visitor_b_op.begin_epilogue(); -+ } -+ -+ CUTLASS_DEVICE -+ void set_batch_index(int batch_idx) { -+ visitor_a_op.set_batch_index(batch_idx); -+ visitor_b_op.set_batch_index(batch_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void begin_step(int step_idx) { -+ visitor_a_op.begin_step(step_idx); -+ visitor_b_op.begin_step(step_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void begin_row(int row_idx) { -+ visitor_a_op.begin_row(row_idx); -+ visitor_b_op.begin_row(row_idx); -+ } -+ -+ CUTLASS_DEVICE -+ VisitAccessType visit( -+ int iter_idx, -+ int row_idx, -+ int column_idx, -+ int frag_idx, -+ AccumulatorAccessType const &accum -+ ) { -+ /// Get result from visitor A and visitor B -+ VisitAccessTypeA result_A = visitor_a_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum); -+ VisitAccessTypeB result_B = visitor_b_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum); -+ -+ /// Type conversion -+ NumericArrayConverter source_converter_A; -+ NumericArrayConverter source_converter_B; -+ -+ return binary_op( -+ source_converter_A(result_A), -+ source_converter_B(result_B) -+ ); -+ } -+ -+ CUTLASS_DEVICE -+ void end_row(int row_idx) { -+ visitor_a_op.end_row(row_idx); -+ visitor_b_op.end_row(row_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void end_step(int step_idx) { -+ visitor_a_op.end_step(step_idx); -+ visitor_b_op.end_step(step_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void end_epilogue() { -+ visitor_a_op.end_epilogue(); -+ visitor_b_op.end_epilogue(); -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_column_broadcast.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_column_broadcast.h -new file mode 100644 -index 0000000..6dcb32b ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_column_broadcast.h -@@ -0,0 +1,250 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this layernormware without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ -+ \brief A file contains the epilogue visitor Op with broadcasting vector to all columns -+*/ -+ -+#pragma once -+#include "cutlass/cutlass.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Epilogue Visitor operator for the following computation: -+/// -+/// ElementVector T[i][j] <- device-memory Td[i] -+/// -+/// It can only be a leaf node in the epilogue tree -+template < -+ typename ElementAccumulator_, ///< Data type of the Accumulator -+ typename ElementFragment_, ///< Data type used to cache vector in register -+ typename InputTileIterator_ ///< Tile iterator type to read the broadcasted tensor -+> -+class VisitorOpColumnBroadcast { -+public: -+ using InputTileIterator = InputTileIterator_; -+ -+ static int const kElementsPerAccess = InputTileIterator::kElementsPerAccess; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementVector = typename InputTileIterator::Element; -+ using ElementFragment = ElementFragment_; -+ -+ using VisitAccessType = Array; -+ -+ /// Thread map used by input tile iterators -+ using ThreadMap = typename InputTileIterator::ThreadMap; -+ -+ /// Fragment object used to store the broadcast values -+ using BroadcastFragment = Array< -+ ElementFragment, kElementsPerAccess>; -+ -+ /// Fragment type of accumulator -+ using AccumulatorAccessType = Array; -+ -+ /// Used for the broadcast -+ struct BroadcastDetail { -+ /// Number of threads per warp -+ static int const kWarpSize = 32; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ -+ /// Number of distinct scalar column indices handled by each thread -+ static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess; -+ -+ /// Number of distinct scalar row indices handled by each thread -+ static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn; -+ -+ /// Number of threads per threadblock -+ static int const kThreadCount = ThreadMap::kThreads; -+ -+ /// Number of distinct threads per row of output tile -+ static int const kThreadsPerRow = (InputTileIterator::Shape::kN / kColumnsPerThread); -+ -+ /// Number of distinct threads which must be reduced during the final reduction phase within the threadblock. -+ static int const kThreadRows = kThreadCount / kThreadsPerRow; -+ -+ // /// Number of iterations (accesses) the threadblock takes to reduce a row -+ // static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount); -+ }; -+ -+ // using ComputeFragmentType = Array; -+ -+ struct SharedStorage { -+ CUTLASS_HOST_DEVICE -+ SharedStorage() { } -+ }; -+ -+ /// Host-constructable Argument structure -+ struct Arguments { -+ ElementVector *broadcast_ptr; ///< Pointer to the additional tensor operand -+ int64_t batch_stride; -+ -+ /// Methods -+ CUTLASS_HOST_DEVICE -+ Arguments(): -+ broadcast_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ ElementVector *broadcast_ptr, -+ int64_t batch_stride -+ ): -+ broadcast_ptr(broadcast_ptr), -+ batch_stride(batch_stride) { } -+ }; -+ -+ /// Param structure -+ struct Params { -+ ElementVector *broadcast_ptr; ///< Pointer to the additional tensor operand -+ int64_t batch_stride; -+ -+ /// Method -+ CUTLASS_HOST_DEVICE -+ Params(): -+ broadcast_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args): -+ broadcast_ptr(args.broadcast_ptr), -+ batch_stride(args.batch_stride) { } -+ }; -+ -+private: -+ ElementVector *broadcast_ptr; -+ BroadcastFragment broadcast_fragment; ///< Array holds the loaded broadcast fragment -+ MatrixCoord threadblock_offset_; -+ int thread_idx_; -+ MatrixCoord problem_size; -+ -+ int thread_start_row_; -+ int state_[3]; -+ int thread_offset_row_; -+ -+ int64_t batch_stride_; -+ -+public: -+ /// Constructs the function object -+ CUTLASS_HOST_DEVICE -+ VisitorOpColumnBroadcast( -+ Params const ¶ms, -+ SharedStorage &shared_storage, -+ int thread_idx, -+ MatrixCoord threadblock_offset, -+ MatrixCoord problem_size -+ ): -+ broadcast_ptr(params.broadcast_ptr), -+ threadblock_offset_(threadblock_offset), -+ thread_idx_(thread_idx), -+ problem_size(problem_size), -+ thread_start_row_(ThreadMap::initial_offset(thread_idx).row() + threadblock_offset.row()), -+ batch_stride_(params.batch_stride) -+ { -+ state_[0] = state_[1] = state_[2] = 0; -+ } -+ -+ CUTLASS_DEVICE -+ void set_batch_index(int batch_idx) { -+ broadcast_ptr += batch_idx * batch_stride_; -+ } -+ -+ CUTLASS_DEVICE -+ void begin_epilogue() { } -+ -+ CUTLASS_DEVICE -+ void begin_step(int step_idx) {} -+ -+ CUTLASS_DEVICE -+ void begin_row(int row_idx) {} -+ -+ CUTLASS_DEVICE -+ VisitAccessType visit( -+ int iter_idx, -+ int row_idx, -+ int column_idx, -+ int frag_idx, -+ AccumulatorAccessType const &accum -+ ) { -+ // get pointer -+ thread_offset_row_ = thread_start_row_ + ThreadMap::iteration_offset(frag_idx).row(); -+ -+ ElementFragment broadcast_data = ElementFragment(*(broadcast_ptr + thread_offset_row_)); -+ -+ broadcast_fragment.fill(broadcast_data); -+ -+ return broadcast_fragment; -+ } -+ -+ CUTLASS_DEVICE -+ void end_row(int row_idx) { } -+ -+ CUTLASS_DEVICE -+ void end_step(int step_idx) { -+ // run operator ++ -+ ++state_[0]; -+ -+ thread_start_row_ += ThreadMap::Shape::kRow; -+ if (state_[0] == ThreadMap::Count::kRow) { -+ state_[0] = 0; -+ ++state_[1]; -+ thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * -+ ThreadMap::Shape::kRow * ThreadMap::Count::kRow; -+ -+ if (state_[1] == ThreadMap::Count::kGroup) { -+ state_[1] = 0; -+ ++state_[2]; -+ thread_start_row_ += ThreadMap::Count::kGroup * -+ ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; -+ -+ if (state_[2] == ThreadMap::Count::kCluster) { -+ state_[2] = 0; -+ } -+ } -+ } -+ } -+ -+ CUTLASS_DEVICE -+ void end_epilogue() { } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_column_reduction.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_column_reduction.h -new file mode 100644 -index 0000000..624d7e6 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_column_reduction.h -@@ -0,0 +1,341 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this layernormware without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ -+ \brief A file contains the epilogue visitor Op with reduction over columns in CTA -+*/ -+ -+#pragma once -+#include "cutlass/cutlass.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Epilogue Visitor operator for the following computation: -+/// -+/// ElementReductionAccumulator R[j] = \sum_i ElementReductionAccumulator(T[i][j]) -+/// device memory <- ElementReduction(R[j]) -+/// -+template < -+ typename ThreadblockShape_, /// Threadblock shape -+ typename ElementAccumulator_, ///< Data type of the Accumulator -+ typename ElementReduction_, ///< Data type of the output reduction in device memory -+ typename ElementReductionAccumulator_ , ///< Data type to accumulate reduction in smem and register -+ typename OutputTileIterator_, ///< Tile Iterator type -+ typename Visitor_ ///< preceeding visitor op -+> -+class VisitorOpColumnReduction { -+public: -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementReductionAccumulator = ElementReductionAccumulator_; -+ using ElementReduction = ElementReduction_; -+ using OutputTileIterator = OutputTileIterator_; -+ using ThreadblockShape = ThreadblockShape_; -+ using Visitor = Visitor_; -+ -+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; -+ -+ using ReductionOp = cutlass::plus>; -+ using ReductionOpScalar = cutlass::plus; -+ using ElementOutput = typename OutputTileIterator::Element; -+ -+ -+ -+ /// Fragment type returned from Visitor -+ using VisitAccessTypeVisitor = typename Visitor::VisitAccessType; -+ using ElementVisitor = typename VisitAccessTypeVisitor::Element; -+ -+ using VisitAccessType = VisitAccessTypeVisitor; -+ -+ /// Fragment type of accumulator -+ using AccumulatorAccessType = Array; -+ -+ /// Fragment type of redcution -+ using ReductionAccumulatorAccessType = Array; -+ -+ /// Thread map used by output tile iterators -+ using ThreadMap = typename OutputTileIterator::ThreadMap; -+ /// Used for the reduction -+ struct ReductionDetail { -+ -+ /// Number of threads per warp -+ static int const kWarpSize = 32; -+ -+ /// Number of distinct scalar column indices handled by each thread -+ static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess; -+ -+ /// Number of distinct scalar row indices handled by each thread -+ static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn; -+ -+ /// Number of threads per threadblock -+ static int const kThreadCount = ThreadMap::kThreads; -+ -+ /// Number of distinct threads per row of output tile -+ static int const kThreadsPerRow = ThreadblockShape::kN / kColumnsPerThread; -+ -+ /// Number of distinct threads which must be reduced during the final reduction phase within the threadblock -+ static int const kThreadRows = kThreadCount / kThreadsPerRow; -+ -+ /// Number of iterations (accesses) the threadblock takes to reduce a row -+ static int const kThreadAccessesPerRow = const_max(1, (ThreadblockShape::kN + kThreadCount - 1) / kThreadCount); -+ -+ using StorageShape = MatrixShape< -+ kThreadRows, -+ ThreadblockShape::kN -+ >; -+ }; -+ -+ using ReductionFragment = Array; -+ -+ /// Shared storage -+ struct SharedStorage { -+ typename Visitor::SharedStorage storage_visitor; -+ AlignedArray reduction; -+ -+ CUTLASS_HOST_DEVICE -+ SharedStorage() {} -+ }; -+ -+ /// Host-constructable Argument structure -+ struct Arguments { -+ ElementReduction *reduction_ptr; ///< Pointer to the reduction tensor in device memory -+ int64_t batch_stride; -+ typename Visitor::Arguments visitor_arg; ///< Argument type of visitor -+ -+ /// Method -+ CUTLASS_HOST_DEVICE -+ Arguments(): reduction_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ ElementReduction *reduction_ptr, -+ int64_t batch_stride, -+ typename Visitor::Arguments visitor_arg -+ ): -+ reduction_ptr(reduction_ptr), -+ batch_stride(batch_stride), -+ visitor_arg(visitor_arg) -+ { } -+ }; -+ -+ /// Param structure -+ struct Params { -+ ElementReduction *reduction_ptr; ///< Pointer to the reduction tensor in device memory -+ int64_t batch_stride; -+ typename Visitor::Params visitor_param; ///< Argument type of visitor -+ -+ /// Method -+ CUTLASS_HOST_DEVICE -+ Params(): reduction_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args): -+ reduction_ptr(args.reduction_ptr), -+ batch_stride(args.batch_stride), -+ visitor_param(args.visitor_arg) -+ { } -+ }; -+ -+private: -+ ElementReduction *reduction_output_ptr_; ///< Pointer to the reduction tensor in device memory -+ ElementReductionAccumulator *reduction_smem_ptr_; ///< Pointer to the partial reductions in shared memory -+ ReductionFragment reduction_fragment; ///< register fragments that hold the partial reduction -+ Visitor visitor_; ///< visitor -+ int thread_idx_; -+ MatrixCoord threadblock_offset; -+ MatrixCoord problem_size_; -+ int64_t batch_stride_; -+ -+public: -+ -+ /// Constructs the function object -+ CUTLASS_HOST_DEVICE -+ VisitorOpColumnReduction( -+ Params const ¶ms, -+ SharedStorage &shared_storage, -+ int thread_idx, -+ MatrixCoord threadblock_offset, -+ MatrixCoord problem_size -+ ): -+ visitor_(params.visitor_param, shared_storage.storage_visitor, -+ thread_idx, threadblock_offset, problem_size), -+ reduction_smem_ptr_(shared_storage.reduction.data()), -+ reduction_output_ptr_(params.reduction_ptr), -+ thread_idx_(thread_idx), -+ threadblock_offset(threadblock_offset), -+ problem_size_(problem_size), -+ batch_stride_(params.batch_stride) -+ { } -+ -+ CUTLASS_DEVICE -+ void set_batch_index(int batch_idx) { -+ reduction_output_ptr_ += batch_idx * batch_stride_; -+ visitor_.set_batch_index(batch_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void begin_epilogue() { -+ visitor_.begin_epilogue(); -+ -+ // clear the reduction fragment -+ reduction_fragment.clear(); -+ } -+ -+ CUTLASS_DEVICE -+ void begin_step(int step_idx) { -+ visitor_.begin_step(step_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void begin_row(int row_idx) { -+ visitor_.begin_row(row_idx); -+ } -+ -+ CUTLASS_DEVICE -+ VisitAccessType visit( -+ int iter_idx, -+ int row_idx, -+ int column_idx, -+ int frag_idx, -+ AccumulatorAccessType const &accum -+ ) { -+ /// Get result from visitor -+ VisitAccessTypeVisitor result = visitor_.visit(iter_idx, row_idx, column_idx, frag_idx, accum); -+ -+ NumericArrayConverter reduction_converter; -+ ReductionOp reduction_op; -+ ReductionAccumulatorAccessType* reduction_fragment_ = reinterpret_cast(&reduction_fragment); -+ reduction_fragment_[column_idx] = reduction_op(reduction_fragment_[column_idx], reduction_converter(result)); -+ -+ return result; -+ } -+ -+ CUTLASS_DEVICE -+ void end_row(int row_idx) { -+ visitor_.end_row(row_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void end_step(int step_idx) { -+ visitor_.end_step(step_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void end_epilogue() { -+ visitor_.end_epilogue(); -+ // -+ // Store the partially reduced value to SMEM -+ // -+ -+ // Guard against uses of the existing SMEM tile -+ __syncthreads(); -+ -+ using AccessType = AlignedArray; -+ -+ // -+ // Determine a compact thread arrangement to store to SMEM -+ // -+ -+ MatrixCoord thread_offset( -+ thread_idx_ / ReductionDetail::kThreadsPerRow, -+ (thread_idx_ % ReductionDetail::kThreadsPerRow) * ThreadMap::kElementsPerAccess -+ ); -+ -+ // -+ // Each thread store its fragment to a SMEM -+ // -+ AccessType *aligned_reduction_ptr = reinterpret_cast( -+ &reduction_smem_ptr_[thread_offset.row() * ThreadblockShape::kN + thread_offset.column()] -+ ); -+ -+ AccessType const *frag_ptr = reinterpret_cast( -+ &reduction_fragment -+ ); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { -+ int col_idx = column * ThreadMap::Delta::kColumn / ThreadMap::kElementsPerAccess; -+ -+ aligned_reduction_ptr[col_idx] = frag_ptr[column]; -+ } -+ -+ __syncthreads(); -+ -+ // -+ // Now, threads are assigned several columns of the output. The fetch over all rows from -+ // the compacted SMEM tile and perform a reduction. -+ // -+ -+ NumericConverter output_converter; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < ReductionDetail::kThreadAccessesPerRow; ++j) { -+ int column_idx = thread_idx_ + j * ReductionDetail::kThreadCount; -+ -+ ReductionOpScalar reduction_op; -+ ElementReductionAccumulator reduction_element = ElementReductionAccumulator(); -+ -+ int output_column_idx = threadblock_offset.column() + column_idx; -+ -+ if (column_idx < ThreadblockShape::kN && output_column_idx < problem_size_.column()) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int row = 0; row < ReductionDetail::kThreadRows; ++row) { -+ if (row) { -+ auto frag = reduction_smem_ptr_[row * ThreadblockShape::kN + column_idx]; -+ reduction_element = reduction_op(reduction_element, frag); -+ } -+ else { -+ -+ reduction_element = reduction_smem_ptr_[column_idx]; -+ } -+ } -+ -+ // Store -+ reduction_output_ptr_[column_idx + threadblock_offset.column() + threadblock_offset.row() / ThreadblockShape::kM * problem_size_.column()] = output_converter(reduction_element); -+ } -+ } -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_linear_combination.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_linear_combination.h -new file mode 100644 -index 0000000..1e2b8e6 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_linear_combination.h -@@ -0,0 +1,266 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this layernormware without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ -+ \brief A file contains the epilogue visitor Op with Linear Combination -+*/ -+ -+#pragma once -+#include "cutlass/cutlass.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Epilogue Visitor operator for the following computation: -+/// -+/// ElementCompute alpha; -+/// ElementCompute beta; -+/// ElementCompute C = BinaryOp(alpha * ElementCompute(Visitor_A), beta * ElementCompute(Visitor_B) -+/// Return C; -+/// -+template < -+ typename ElementAccumulator_, ///< Data type of the Accumulator -+ typename ElementCompute_, ///< Data type used to compute linear combination -+ int kElementsPerAccess_, ///< Number of elements computed per operation -+ typename VisitorA_, ///< Child node A -+ typename VisitorB_ ///< Child node B -+> -+class VisitorOpLinearCombination{ -+public: -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementCompute_; -+ static int const kElementsPerAccess = kElementsPerAccess_; -+ -+ using VisitorA = VisitorA_; -+ using VisitorB = VisitorB_; -+ -+ /// Fragment type returned from VisitorA.visit -+ using VisitAccessTypeA = typename VisitorA::VisitAccessType; -+ using ElementA = typename VisitAccessTypeA::Element; -+ -+ /// Fragment type returned from VisitorB.visit -+ using VisitAccessTypeB = typename VisitorB::VisitAccessType; -+ using ElementB = typename VisitAccessTypeB::Element; -+ -+ /// Fragment type returned by this visitor -+ using VisitAccessType = Array; -+ -+ /// Fragment type of accumulator -+ using AccumulatorAccessType = Array; -+ -+ /// Combination Op -+ using CombinationOp = cutlass::plus; -+ -+ static_assert(kElementsPerAccess==VisitAccessTypeA::kElements, "kElementsPerAccess mismatches with Visitor A"); -+ static_assert(kElementsPerAccess==VisitAccessTypeB::kElements, "kElementsPerAccess misnatches with Visitor B"); -+ -+ /// SMEM buffer class required in the epilogue visitor -+ struct SharedStorage { -+ typename VisitorA::SharedStorage storage_a; -+ typename VisitorB::SharedStorage storage_b; -+ -+ CUTLASS_HOST_DEVICE -+ SharedStorage() {} -+ }; -+ -+ -+ /// Host-constructable Arguments structure -+ struct Arguments { -+ ElementCompute alpha; ///< scales accumulators -+ ElementCompute beta; ///< scales source tensor -+ typename VisitorA::Arguments visitor_a_arg; ///< Argument type for visitor_a -+ typename VisitorB::Arguments visitor_b_arg; ///< Argument type for visitor_b -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Arguments(): -+ alpha(ElementCompute(1)), -+ beta(ElementCompute(0)) -+ { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ ElementCompute alpha, -+ ElementCompute beta, -+ typename VisitorA::Arguments visitor_a_arg, -+ typename VisitorB::Arguments visitor_b_arg -+ ): -+ alpha(alpha), -+ beta(beta), -+ visitor_a_arg(visitor_a_arg), -+ visitor_b_arg(visitor_b_arg) -+ { } -+ }; -+ -+ /// Parameter structure -+ struct Params { -+ ElementCompute alpha; ///< scales accumulators -+ ElementCompute beta; ///< scales source tensor -+ typename VisitorA::Params visitor_a_param; ///< Argument type for visitor_a -+ typename VisitorB::Params visitor_b_param; ///< Argument type for visitor_b -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args): -+ alpha(args.alpha), -+ beta(args.beta), -+ visitor_a_param(args.visitor_a_arg), -+ visitor_b_param(args.visitor_b_arg) -+ { } -+ }; -+ -+private: -+ // -+ // Data members -+ // -+ -+ ElementCompute alpha_; -+ ElementCompute beta_; -+ -+ VisitorA visitor_a_op; -+ VisitorB visitor_b_op; -+ -+public: -+ -+ /// Constructs the function object -+ CUTLASS_HOST_DEVICE -+ VisitorOpLinearCombination( -+ Params const ¶ms, -+ SharedStorage &shared_storage, -+ int thread_idx, -+ MatrixCoord threadblock_offset, -+ MatrixCoord problem_size -+ ): -+ alpha_(params.alpha), -+ beta_(params.beta), -+ visitor_a_op(params.visitor_a_param, shared_storage.storage_a, thread_idx, threadblock_offset, problem_size), -+ visitor_b_op(params.visitor_b_param, shared_storage.storage_b, thread_idx, threadblock_offset, problem_size) -+ { } -+ -+ -+ CUTLASS_DEVICE -+ void begin_epilogue() { -+ if (alpha_ != ElementCompute(0)) visitor_a_op.begin_epilogue(); -+ if (beta_ != ElementCompute(0)) visitor_b_op.begin_epilogue(); -+ } -+ -+ CUTLASS_DEVICE -+ void begin_step(int step_idx) { -+ if (alpha_ != ElementCompute(0)) visitor_a_op.begin_step(step_idx); -+ if (beta_ != ElementCompute(0)) visitor_b_op.begin_step(step_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void begin_row(int row_idx) { -+ if (alpha_ != ElementCompute(0)) visitor_a_op.begin_row(row_idx); -+ if (beta_ != ElementCompute(0)) visitor_b_op.begin_row(row_idx); -+ } -+ -+ CUTLASS_DEVICE -+ VisitAccessType visit( -+ int iter_idx, -+ int row_idx, -+ int column_idx, -+ int frag_idx, -+ AccumulatorAccessType const &accum -+ ) { -+ /// Get result from visitor A and visitor B -+ VisitAccessTypeA result_A; -+ VisitAccessTypeB result_B; -+ -+ if (alpha_ != ElementCompute(0)) { -+ result_A = visitor_a_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum); -+ } else { -+ // Fill the result A with zeros -+ result_A.clear(); -+ } -+ -+ if (beta_ != ElementCompute(0)) { -+ result_B = visitor_b_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum); -+ } else { -+ // Fill the result B with zeros -+ result_B.clear(); -+ } -+ -+ /// Type conversion -+ NumericArrayConverter source_converter_A; -+ NumericArrayConverter source_converter_B; -+ -+ CombinationOp combination_op; -+ -+ cutlass::multiplies multiply_op; -+ -+ return combination_op( -+ multiply_op(alpha_, source_converter_A(result_A)), -+ multiply_op(beta_, source_converter_B(result_B)) -+ ); -+ } -+ -+ CUTLASS_DEVICE -+ void end_row(int row_idx) { -+ if (alpha_ != ElementCompute(0)) visitor_a_op.end_row(row_idx); -+ if (beta_ != ElementCompute(0)) visitor_b_op.end_row(row_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void end_step(int step_idx) { -+ if (alpha_ != ElementCompute(0)) visitor_a_op.end_step(step_idx); -+ if (beta_ != ElementCompute(0)) visitor_b_op.end_step(step_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void end_epilogue() { -+ if (alpha_ != ElementCompute(0)) visitor_a_op.end_epilogue(); -+ if (beta_ != ElementCompute(0)) visitor_b_op.end_epilogue(); -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_row_broadcast.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_row_broadcast.h -new file mode 100644 -index 0000000..dc7bfa2 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_row_broadcast.h -@@ -0,0 +1,258 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this layernormware without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ -+ \brief A file contains the epilogue visitor Op with broadcasting vector to all rows -+*/ -+ -+#pragma once -+#include "cutlass/cutlass.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Epilogue Visitor operator for the following computation: -+/// -+/// ElementVector T[i][j] <- device-memory Td[j] -+/// -+/// It can only be a leaf node in the epilogue tree -+template < -+ typename ElementAccumulator_, ///< Data type of the Accumulator -+ typename ElementFragment_, ///< Data type used to cache vector in register -+ typename InputTileIterator_ ///< Tile iterator type to read the broadcasted tensor -+> -+class VisitorOpRowBroadcast { -+public: -+ using InputTileIterator = InputTileIterator_; -+ -+ static int const kElementsPerAccess = InputTileIterator::kElementsPerAccess; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementVector = typename InputTileIterator::Element; -+ using ElementFragment = ElementFragment_; -+ -+ using VisitAccessType = Array; -+ -+ /// Thread map used by input tile iterators -+ using ThreadMap = typename InputTileIterator::ThreadMap; -+ -+ /// Fragment object used to store the broadcast values -+ using BroadcastFragment = Array< -+ ElementFragment, -+ ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess>; -+ -+ /// Fragment type of accumulator -+ using AccumulatorAccessType = Array; -+ -+ /// Used for the broadcast -+ struct BroadcastDetail { -+ /// Number of threads per warp -+ static int const kWarpSize = 32; -+ -+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; -+ -+ /// Number of distinct scalar column indices handled by each thread -+ static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess; -+ -+ /// Number of distinct scalar row indices handled by each thread -+ static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn; -+ -+ /// Number of threads per threadblock -+ static int const kThreadCount = ThreadMap::kThreads; -+ -+ /// Number of distinct threads per row of output tile -+ static int const kThreadsPerRow = (InputTileIterator::Shape::kN / kColumnsPerThread); -+ -+ /// Number of distinct threads which must be reduced during the final reduction phase within the threadblock. -+ static int const kThreadRows = kThreadCount / kThreadsPerRow; -+ -+ // /// Number of iterations (accesses) the threadblock takes to reduce a row -+ // static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount); -+ }; -+ -+ // using ComputeFragmentType = Array; -+ -+ struct SharedStorage { -+ CUTLASS_HOST_DEVICE -+ SharedStorage() { } -+ }; -+ -+ /// Host-constructable Argument structure -+ struct Arguments { -+ ElementVector *broadcast_ptr; ///< Pointer to the additional tensor operand -+ int64_t batch_stride; -+ -+ /// Methods -+ CUTLASS_HOST_DEVICE -+ Arguments(): -+ broadcast_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ ElementVector *broadcast_ptr, -+ int64_t batch_stride -+ ): -+ broadcast_ptr(broadcast_ptr), -+ batch_stride(batch_stride) { } -+ }; -+ -+ /// Param structure -+ struct Params { -+ ElementVector *broadcast_ptr; ///< Pointer to the additional tensor operand -+ int64_t batch_stride; -+ -+ /// Method -+ CUTLASS_HOST_DEVICE -+ Params(): -+ broadcast_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args): -+ broadcast_ptr(args.broadcast_ptr), -+ batch_stride(args.batch_stride) { } -+ }; -+ -+private: -+ ElementVector *broadcast_ptr; -+ BroadcastFragment broadcast_fragment; ///< Array holds the loaded broadcast fragment -+ MatrixCoord threadblock_offset_; -+ int thread_idx_; -+ MatrixCoord problem_size; -+ int64_t batch_stride_; -+ -+public: -+ /// Constructs the function object -+ CUTLASS_HOST_DEVICE -+ VisitorOpRowBroadcast( -+ Params const ¶ms, -+ SharedStorage &shared_storage, -+ int thread_idx, -+ MatrixCoord threadblock_offset, -+ MatrixCoord problem_size -+ ): -+ broadcast_ptr(params.broadcast_ptr + threadblock_offset.column()), -+ threadblock_offset_(threadblock_offset), -+ thread_idx_(thread_idx), -+ problem_size(problem_size), -+ batch_stride_(params.batch_stride) { } -+ -+ CUTLASS_DEVICE -+ void set_batch_index(int batch_idx) { -+ broadcast_ptr += batch_idx * batch_stride_; -+ } -+ -+ CUTLASS_DEVICE -+ void begin_epilogue() { -+ // load broadcast fragment -+ load_broadcast_fragment_(); -+ } -+ -+ CUTLASS_DEVICE -+ void begin_step(int step_idx) {} -+ -+ CUTLASS_DEVICE -+ void begin_row(int row_idx) {} -+ -+ CUTLASS_DEVICE -+ VisitAccessType visit( -+ int iter_idx, -+ int row_idx, -+ int column_idx, -+ int frag_idx, -+ AccumulatorAccessType const &accum -+ ) { -+ VisitAccessType* broadcast_fragment_ = reinterpret_cast(&broadcast_fragment); -+ return broadcast_fragment_[column_idx]; -+ } -+ -+ CUTLASS_DEVICE -+ void end_row(int row_idx) { } -+ -+ CUTLASS_DEVICE -+ void end_step(int step_idx) { } -+ -+ CUTLASS_DEVICE -+ void end_epilogue() { } -+ -+private: -+ -+ CUTLASS_DEVICE -+ void load_broadcast_fragment_() { -+ -+ broadcast_fragment.clear(); -+ -+ // If no pointer is supplied, set with all zeros and avoid memory accesses -+ if (!broadcast_ptr) { -+ return; -+ } -+ -+ int thread_initial_column = ThreadMap::initial_offset(thread_idx_).column(); -+ -+ int thread_column_idx = threadblock_offset_.column() + thread_initial_column; -+ broadcast_ptr += thread_initial_column; -+ -+ NumericArrayConverter converter; -+ using AccessType = AlignedArray; -+ using AccessFragmentType = Array; -+ -+ AccessFragmentType *frag_ptr = reinterpret_cast(&broadcast_fragment); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < ThreadMap::Iterations::kColumn; ++j) { -+ -+ AccessType loaded; -+ -+ loaded.clear(); -+ -+ if (thread_column_idx < problem_size.column()) { -+ loaded = *reinterpret_cast(broadcast_ptr); -+ } -+ -+ AccessFragmentType cvt = converter(loaded); -+ frag_ptr[j] = cvt; -+ -+ thread_column_idx += ThreadMap::Delta::kColumn; -+ broadcast_ptr += ThreadMap::Delta::kColumn; -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_row_reduction.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_row_reduction.h -new file mode 100644 -index 0000000..27b03f8 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_row_reduction.h -@@ -0,0 +1,319 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this layernormware without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ -+ \brief A file contains the epilogue visitor Op with reduction over rows in CTA -+*/ -+ -+#pragma once -+#include "cutlass/cutlass.h" -+#include "stdio.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Epilogue Visitor operator for the following computation: -+/// -+/// ElementReductionAccumulator R[i] = \sum_i ElementReductionAccumulator(T[i][j]) -+/// device memory <- ElementReduction(R[i]) -+/// -+template < -+ typename ThreadblockShape_, /// Threadblock shape -+ typename ElementAccumulator_, ///< Data type of the Accumulator -+ typename ElementReduction_, ///< Data type of the output reduction in device memory -+ typename ElementReductionAccumulator_ , ///< Data type to accumulate reduction in smem and register -+ typename OutputTileIterator_, ///< Tile Iterator type -+ typename Visitor_ ///< preceeding visitor op -+> -+class VisitorOpRowReduction { -+public: -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementReductionAccumulator = ElementReductionAccumulator_; -+ using ElementReduction = ElementReduction_; -+ using OutputTileIterator = OutputTileIterator_; -+ using ThreadblockShape = ThreadblockShape_; -+ using Visitor = Visitor_; -+ -+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; -+ -+ using ReductionOp = cutlass::plus>; -+ using ReductionOpScalar = cutlass::plus; -+ using ElementOutput = typename OutputTileIterator::Element; -+ -+ /// Fragment type returned from Visitor -+ using VisitAccessTypeVisitor = typename Visitor::VisitAccessType; -+ using ElementVisitor = typename VisitAccessTypeVisitor::Element; -+ -+ using VisitAccessType = VisitAccessTypeVisitor; -+ -+ /// Fragment type of accumulator -+ using AccumulatorAccessType = Array; -+ -+ /// Fragment type of redcution -+ using ReductionAccumulatorAccessType = Array; -+ -+ /// Thread map used by output tile iterators -+ using ThreadMap = typename OutputTileIterator::ThreadMap; -+ /// Used for the reduction -+ struct ReductionDetail { -+ -+ /// Number of threads per warp -+ static int const kWarpSize = 32; -+ -+ /// Number of distinct scalar column indices handled by each thread -+ static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess; -+ -+ /// Number of distinct scalar row indices handled by each thread -+ static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn; -+ -+ /// Number of threads per threadblock -+ static int const kThreadCount = ThreadMap::kThreads; -+ -+ /// Number of distinct threads per row of output tile -+ static int const kThreadsPerRow = ThreadblockShape::kN / kColumnsPerThread; -+ -+ /// Half number of threads per row used for cross-thread reduction -+ static int const kHalfThreadsPerRow = (kThreadsPerRow >> 1); -+ -+ /// Number of distinct threads which must be reduced during the final reduction phase within the threadblock -+ static int const kThreadRows = kThreadCount / kThreadsPerRow; -+ }; -+ -+ /// Shared storage -+ struct SharedStorage { -+ typename Visitor::SharedStorage storage_visitor; -+ CUTLASS_HOST_DEVICE -+ SharedStorage() { } -+ }; -+ -+ /// Host-constructable Argument structure -+ struct Arguments { -+ ElementReduction *reduction_ptr; ///< Pointer to the reduction tensor in device memory -+ int64_t batch_stride; -+ typename Visitor::Arguments visitor_arg; ///< Argument type of visitor -+ -+ /// Method -+ CUTLASS_HOST_DEVICE -+ Arguments(): reduction_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ ElementReduction *reduction_ptr, -+ int64_t batch_stride, -+ typename Visitor::Arguments visitor_arg -+ ): -+ reduction_ptr(reduction_ptr), -+ batch_stride(batch_stride), -+ visitor_arg(visitor_arg) -+ { } -+ }; -+ -+ /// Param structure -+ struct Params { -+ ElementReduction *reduction_ptr; ///< Pointer to the reduction tensor in device memory -+ int64_t batch_stride; -+ typename Visitor::Params visitor_param; ///< Argument type of visitor -+ -+ /// Method -+ CUTLASS_HOST_DEVICE -+ Params(): reduction_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args): -+ reduction_ptr(args.reduction_ptr), -+ batch_stride(args.batch_stride), -+ visitor_param(args.visitor_arg) -+ { } -+ }; -+ -+private: -+ ElementReduction *reduction_output_ptr_; ///< Pointer to the reduction tensor in device memory -+ ElementReductionAccumulator reduction_accum; -+ Visitor visitor_; ///< visitor -+ int thread_idx_; -+ MatrixCoord threadblock_offset; -+ MatrixCoord problem_size_; -+ -+ int thread_start_row_; /// used to identify -+ int state_[3]; /// used to track row iterator -+ int thread_offset_row_; -+ int64_t batch_stride_; -+public: -+ /// Constructs the function object -+ CUTLASS_HOST_DEVICE -+ VisitorOpRowReduction( -+ Params const ¶ms, -+ SharedStorage &shared_storage, -+ int thread_idx, -+ MatrixCoord threadblock_offset, -+ MatrixCoord problem_size -+ ): -+ visitor_(params.visitor_param, shared_storage.storage_visitor, -+ thread_idx, threadblock_offset, problem_size), -+ reduction_output_ptr_(params.reduction_ptr), -+ thread_idx_(thread_idx), -+ threadblock_offset(threadblock_offset), -+ problem_size_(problem_size), -+ thread_start_row_(ThreadMap::initial_offset(thread_idx).row() + threadblock_offset.row()), -+ batch_stride_(params.batch_stride) -+ { -+ state_[0] = state_[1] = state_[2] = 0; -+ } -+ -+ CUTLASS_DEVICE -+ void set_batch_index(int batch_idx) { -+ reduction_output_ptr_ += batch_idx * batch_stride_; -+ visitor_.set_batch_index(batch_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void begin_epilogue() { -+ visitor_.begin_epilogue(); -+ } -+ -+ CUTLASS_DEVICE -+ void begin_step(int step_idx) { -+ visitor_.begin_step(step_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void begin_row(int row_idx) { -+ visitor_.begin_row(row_idx); -+ -+ reduction_accum = ElementReductionAccumulator(0); -+ } -+ -+ CUTLASS_DEVICE -+ VisitAccessType visit( -+ int iter_idx, -+ int row_idx, -+ int column_idx, -+ int frag_idx, -+ AccumulatorAccessType const &accum -+ ) { -+ /// Get result from visitor -+ VisitAccessTypeVisitor result = visitor_.visit(iter_idx, row_idx, column_idx, frag_idx, accum); -+ -+ thread_offset_row_ = thread_start_row_ + ThreadMap::iteration_offset(frag_idx).row(); -+ -+ ReductionOpScalar reduction_op; -+ -+ ElementReductionAccumulator reduction_accum_ = reduction(result); -+ -+ // After performing the in-thread reduction, we then perform cross-thread / in-warp reduction -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = ReductionDetail::kHalfThreadsPerRow; i > 0; i >>= 1) { -+ reduction_accum_ = reduction_op(reduction_accum_, __shfl_xor_sync(0xFFFFFFFF, reduction_accum_, i)); -+ } -+ reduction_accum = reduction_op(reduction_accum, reduction_accum_); -+ -+ return result; -+ } -+ -+ CUTLASS_DEVICE -+ void end_row(int row_idx) { -+ visitor_.end_row(row_idx); -+ NumericConverter output_converter; -+ -+ bool is_write_thread = (thread_offset_row_ < problem_size_.row() && (thread_idx_ % ReductionDetail::kThreadsPerRow) == 0); -+ int row_offset = thread_offset_row_ + threadblock_offset.column() / ThreadblockShape::kN * problem_size_.row(); -+ -+ ElementReduction *curr_ptr_reduction = reduction_output_ptr_ + row_offset; -+ -+ arch::global_store( -+ output_converter(reduction_accum), -+ (void *)curr_ptr_reduction, -+ is_write_thread); -+ } -+ -+ CUTLASS_DEVICE -+ void end_step(int step_idx) { -+ visitor_.end_step(step_idx); -+ -+ // run operator ++ -+ ++state_[0]; -+ -+ thread_start_row_ += ThreadMap::Shape::kRow; -+ if (state_[0] == ThreadMap::Count::kRow) { -+ state_[0] = 0; -+ ++state_[1]; -+ thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * -+ ThreadMap::Shape::kRow * ThreadMap::Count::kRow; -+ -+ if (state_[1] == ThreadMap::Count::kGroup) { -+ state_[1] = 0; -+ ++state_[2]; -+ thread_start_row_ += ThreadMap::Count::kGroup * -+ ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; -+ -+ if (state_[2] == ThreadMap::Count::kCluster) { -+ state_[2] = 0; -+ } -+ } -+ } -+ -+ } -+ -+ CUTLASS_DEVICE -+ void end_epilogue() { -+ visitor_.end_epilogue(); -+ } -+ -+private: -+ -+ CUTLASS_DEVICE -+ ElementReductionAccumulator reduction(VisitAccessTypeVisitor const& result) { -+ ElementReductionAccumulator sum_ = ElementReductionAccumulator(0); -+ -+ ReductionOpScalar reduction_op; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < VisitAccessTypeVisitor::kElements; ++i) { -+ sum_ = reduction_op(sum_, result[i]); -+ } -+ -+ return sum_; -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_tensor_input.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_tensor_input.h -new file mode 100644 -index 0000000..d2eac4f ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_tensor_input.h -@@ -0,0 +1,188 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this layernormware without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ -+ \brief A file contains the epilogue visitor Op with Tensor Output -+*/ -+ -+#pragma once -+#include "cutlass/cutlass.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Epilogue Visitor operator for the following computation: -+/// -+/// ElementInput C <- device memory -+/// -+/// It can only be a leaf node in the epilogue tree -+template < -+ typename ElementAccumulator_, ///< Data type of the Accumulator -+ typename InputTileIterator_ ///< Tile iterator type to read the tensor -+> -+class VisitorOpTensorInput { -+public: -+ using ElementAccumulator = ElementAccumulator_; -+ using InputTileIterator = InputTileIterator_; -+ -+ static int const kElementsPerAccess = InputTileIterator::kElementsPerAccess; -+ using ElementInput = typename InputTileIterator::Element; -+ -+ using VisitAccessType = Array; -+ -+ /// Fragment type of accumulator -+ using AccumulatorAccessType = Array; -+ -+ struct SharedStorage { -+ CUTLASS_HOST_DEVICE -+ SharedStorage() { } -+ }; -+ -+ /// Host-constructable Argument structure -+ struct Arguments { -+ ElementInput *input_ptr; ///< Pointer to the input tensor in device memory -+ int ldt; ///< Leading dimension of the input tensor operand -+ int64_t batch_stride; ///< batch stride for batched GEMM -+ -+ /// Methods -+ CUTLASS_HOST_DEVICE -+ Arguments(): input_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ ElementInput *input_ptr, -+ int ldt, int64_t batch_stride -+ ): -+ input_ptr(input_ptr), -+ ldt(ldt), -+ batch_stride(batch_stride) -+ { } -+ }; -+ -+ /// Param structure -+ struct Params { -+ typename InputTileIterator::Params params_input; -+ ElementInput *input_ptr; -+ int64_t batch_stride; -+ -+ /// Method -+ CUTLASS_HOST_DEVICE -+ Params(): -+ input_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args): -+ params_input(args.ldt), -+ input_ptr(args.input_ptr), -+ batch_stride(args.batch_stride) -+ { } -+ }; -+ -+private: -+ InputTileIterator iterator_T_; -+ typename InputTileIterator::Fragment fragment_T_; -+ MatrixCoord problem_size; -+ int64_t batch_stride_; -+ -+public: -+ /// Constructs the function object -+ CUTLASS_HOST_DEVICE -+ VisitorOpTensorInput( -+ Params const ¶ms, -+ SharedStorage &shared_storage, -+ int thread_idx, -+ MatrixCoord threadblock_offset, -+ MatrixCoord problem_size -+ ): -+ iterator_T_( -+ InputTileIterator( -+ params.params_input, -+ params.input_ptr, -+ problem_size, -+ thread_idx, -+ threadblock_offset -+ ) -+ ), -+ problem_size(problem_size), -+ batch_stride_(params.batch_stride) { } -+ -+ CUTLASS_DEVICE -+ void set_batch_index(int batch_idx) { -+ iterator_T_.add_pointer_offset(batch_idx * batch_stride_); -+ } -+ -+ CUTLASS_DEVICE -+ void begin_epilogue() { } -+ -+ CUTLASS_DEVICE -+ void begin_step(int step_idx) { -+ fragment_T_.clear(); -+ iterator_T_.load(fragment_T_); -+ ++iterator_T_; -+ } -+ -+ CUTLASS_DEVICE -+ void begin_row(int row_idx) { } -+ -+ CUTLASS_DEVICE -+ VisitAccessType visit( -+ int iter_idx, -+ int row_idx, -+ int column_idx, -+ int frag_idx, -+ AccumulatorAccessType const &accum -+ ) { -+ VisitAccessType source = reinterpret_cast(&fragment_T_)[frag_idx]; -+ return source; -+ } -+ -+ CUTLASS_DEVICE -+ void end_row(int row_idx) { } -+ -+ CUTLASS_DEVICE -+ void end_step(int step_idx) { } -+ -+ CUTLASS_DEVICE -+ void end_epilogue() { } -+}; -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_tensor_output.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_tensor_output.h -new file mode 100644 -index 0000000..407611a ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_tensor_output.h -@@ -0,0 +1,240 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this layernormware without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ -+ \brief A file contains the epilogue visitor Op with Tensor Output -+*/ -+ -+#pragma once -+#include "cutlass/cutlass.h" -+#include "stdio.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Epilogue Visitor operator for the following computation: -+/// -+/// ElementOutput T = ElementOutput(Visitor) -+/// T-> device memory -+/// -+template < -+ typename ElementAccumulator_, ///< Data type of the Accumulator -+ typename OutputTileIterator_, ///< Tile iterator type to write the tensor -+ typename Visitor_ ///< Child visitor that produces the output tensor -+> -+class VisitorOpTensorOutput { -+public: -+ using ElementAccumulator = ElementAccumulator_; -+ using OutputTileIterator = OutputTileIterator_; -+ -+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; -+ using ElementOutput = typename OutputTileIterator::Element; -+ -+ using Visitor = Visitor_; -+ -+ /// Fragment type returned from Visitor -+ using VisitAccessTypeVisitor = typename Visitor::VisitAccessType; -+ using ElementVisitor = typename VisitAccessTypeVisitor::Element; -+ -+ using VisitAccessType = VisitAccessTypeVisitor; -+ -+ /// Fragment type of accumulator -+ using AccumulatorAccessType = Array; -+ -+ /// Fragment type of output -+ using OutputAccessType = Array; -+ -+ static_assert(kElementsPerAccess==VisitAccessTypeVisitor::kElements, "kElementsPerAccess mismatches with Visitor"); -+ -+ struct SharedStorage { -+ typename Visitor::SharedStorage storage_visitor; -+ -+ CUTLASS_HOST_DEVICE -+ SharedStorage() { } -+ }; -+ -+ /// Host-constructable Argument structure -+ struct Arguments { -+ ElementOutput *output_ptr; ///< Pointer to the output tensor in device memory -+ int ldt; ///< Leading dimension of the output tensor operand -+ int64_t batch_stride; ///< batch stride -+ typename Visitor::Arguments visitor_arg; ///< Argument type of visitor -+ -+ /// Methods -+ CUTLASS_HOST_DEVICE -+ Arguments(): output_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ ElementOutput *output_ptr, -+ int ldt, -+ int64_t batch_stride, -+ typename Visitor::Arguments visitor_arg -+ ): -+ output_ptr(output_ptr), -+ ldt(ldt), -+ batch_stride(batch_stride), -+ visitor_arg(visitor_arg) -+ { } -+ }; -+ -+ /// Param structure -+ struct Params { -+ typename OutputTileIterator::Params params_output; -+ ElementOutput *output_ptr; -+ int64_t batch_stride; -+ typename Visitor::Params visitor_param; -+ -+ /// Method -+ CUTLASS_HOST_DEVICE -+ Params(): -+ output_ptr(nullptr) { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args): -+ params_output(args.ldt), -+ output_ptr(args.output_ptr), -+ batch_stride(args.batch_stride), -+ visitor_param(args.visitor_arg) -+ { } -+ }; -+ -+private: -+ OutputTileIterator iterator_T_; -+ typename OutputTileIterator::Fragment fragment_T_; -+ MatrixCoord problem_size; -+ Visitor visitor_; -+ int64_t batch_stride_; -+ -+public: -+ -+ /// Constructs the function object -+ CUTLASS_HOST_DEVICE -+ VisitorOpTensorOutput( -+ Params const ¶ms, -+ SharedStorage &shared_storage, -+ int thread_idx, -+ MatrixCoord threadblock_offset, -+ MatrixCoord problem_size -+ ): -+ visitor_(params.visitor_param, shared_storage.storage_visitor, thread_idx, threadblock_offset, problem_size), -+ iterator_T_( -+ OutputTileIterator( -+ params.params_output, -+ params.output_ptr, -+ problem_size, -+ thread_idx, -+ threadblock_offset -+ ) -+ ), -+ problem_size(problem_size), -+ batch_stride_(params.batch_stride) { } -+ -+ CUTLASS_DEVICE -+ void set_batch_index(int batch_idx) { -+ iterator_T_.add_pointer_offset(batch_idx * batch_stride_); -+ visitor_.set_batch_index(batch_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void begin_epilogue() { -+ visitor_.begin_epilogue(); -+ } -+ -+ CUTLASS_DEVICE -+ void begin_step(int step_idx) { -+ fragment_T_.clear(); -+ visitor_.begin_step(step_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void begin_row(int row_idx) { -+ visitor_.begin_row(row_idx); -+ } -+ -+ CUTLASS_DEVICE -+ VisitAccessType visit( -+ int iter_idx, -+ int row_idx, -+ int column_idx, -+ int frag_idx, -+ AccumulatorAccessType const &accum -+ ) { -+ /// Get result from visitor -+ VisitAccessTypeVisitor result = visitor_.visit(iter_idx, row_idx, column_idx, frag_idx, accum); -+ -+ // Column guard -+ MatrixCoord thread_offset_ = iterator_T_.thread_start() + OutputTileIterator::ThreadMap::iteration_offset(frag_idx); -+ bool column_guard = (thread_offset_.column() < problem_size.column()); -+ -+ if (column_guard) { -+ NumericArrayConverter output_converter; -+ OutputAccessType &output = reinterpret_cast(&fragment_T_)[frag_idx]; -+ output = output_converter(result); -+ } -+ -+ return result; -+ } -+ -+ CUTLASS_DEVICE -+ void end_row(int row_idx) { -+ visitor_.end_row(row_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void end_step(int step_idx) { -+ visitor_.end_step(step_idx); -+ iterator_T_.store(fragment_T_); -+ ++iterator_T_; -+ } -+ -+ CUTLASS_DEVICE -+ void end_epilogue() { -+ visitor_.end_epilogue(); -+ } -+ -+}; -+ -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_unary.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_unary.h -new file mode 100644 -index 0000000..c80543e ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_unary.h -@@ -0,0 +1,226 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this layernormware without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ -+ \brief A file contains the epilogue visitor Op with Unary operation -+*/ -+ -+#pragma once -+#include "cutlass/cutlass.h" -+#include "unary_ops.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Epilogue Visitor operator for the following computation: -+/// -+/// ElementCompute alpha; -+/// ElementCompute beta; -+/// ElementCompute C = UnaryOp(ElementCompute(Visitor)) -+/// Return C; -+/// -+template < -+ typename ElementAccumulator_, ///< Data type of the Accumulator -+ typename ElementCompute_, ///< Data type used to compute linear combination -+ int kElementsPerAccess_, ///< Number of elements computed per operation -+ typename Visitor_, ///< Child node -+ template typename UnaryOp_ -+> -+class VisitorOpUnary{ -+public: -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementCompute_; -+ static int const kElementsPerAccess = kElementsPerAccess_; -+ -+ using Visitor = Visitor_; -+ -+ /// Fragment type returned from Visitor.visit -+ using VisitAccessTypeVisitor = typename Visitor::VisitAccessType; -+ using ElementVisit = typename VisitAccessTypeVisitor::Element; -+ -+ /// Fragment type returned by this visitor -+ using VisitAccessType = Array; -+ -+ /// Fragment type of accumulator -+ using AccumulatorAccessType = Array; -+ -+ /// Combination Op -+ using UnaryOp = UnaryOp_; -+ -+ static_assert(kElementsPerAccess==VisitAccessTypeVisitor::kElements, "kElementsPerAccess mismatches with Visitor"); -+ -+ /// SMEM buffer class required in the epilogue visitor -+ struct SharedStorage { -+ typename Visitor::SharedStorage storage_visitor; -+ -+ CUTLASS_HOST_DEVICE -+ SharedStorage() {} -+ }; -+ -+ -+ /// Host-constructable Arguments structure -+ struct Arguments { -+ typename UnaryOp::Arguments unary_arg; -+ typename Visitor::Arguments visitor_arg; ///< Argument type for visitor -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Arguments():unary_arg() { } -+ -+ CUTLASS_HOST_DEVICE -+ Arguments( -+ typename UnaryOp::Arguments unary_arg, -+ typename Visitor::Arguments visitor_arg -+ ): -+ unary_arg(unary_arg), -+ visitor_arg(visitor_arg) -+ { } -+ }; -+ -+ /// Parameter structure -+ struct Params { -+ typename UnaryOp::Params unary_param; -+ typename Visitor::Params visitor_param; ///< Argument type for visitor -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Params():unary_param() { } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args): -+ unary_param(args.unary_arg), -+ visitor_param(args.visitor_arg) -+ { } -+ }; -+ -+private: -+ // -+ // Data members -+ // -+ UnaryOp unary_op; -+ -+ Visitor visitor_op; -+ -+public: -+ -+ /// Constructs the function object -+ CUTLASS_HOST_DEVICE -+ VisitorOpUnary( -+ Params const ¶ms, -+ SharedStorage &shared_storage, -+ int thread_idx, -+ MatrixCoord threadblock_offset, -+ MatrixCoord problem_size -+ ): -+ unary_op(params.unary_param), -+ visitor_op(params.visitor_param, shared_storage.storage_visitor, thread_idx, threadblock_offset, problem_size) -+ { } -+ -+ CUTLASS_DEVICE -+ void set_batch_index(int batch_idx) { -+ visitor_op.set_batch_index(batch_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void begin_epilogue() { -+ if (unary_op.guard()) visitor_op.begin_epilogue(); -+ } -+ -+ CUTLASS_DEVICE -+ void begin_step(int step_idx) { -+ if (unary_op.guard()) visitor_op.begin_step(step_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void begin_row(int row_idx) { -+ if (unary_op.guard()) visitor_op.begin_row(row_idx); -+ } -+ -+ CUTLASS_DEVICE -+ VisitAccessType visit( -+ int iter_idx, -+ int row_idx, -+ int column_idx, -+ int frag_idx, -+ AccumulatorAccessType const &accum -+ ) { -+ /// Get result from visitor A and visitor B -+ VisitAccessTypeVisitor result; -+ -+ if (unary_op.guard()){ -+ result = visitor_op.visit(iter_idx, row_idx, column_idx, frag_idx, accum); -+ } else { -+ result.clear(); -+ } -+ -+ /// Type conversion -+ NumericArrayConverter source_converter; -+ -+ cutlass::multiplies multiply_op; -+ -+ return unary_op(source_converter(result)); -+ } -+ -+ CUTLASS_DEVICE -+ void end_row(int row_idx) { -+ if (unary_op.guard()) visitor_op.end_row(row_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void end_step(int step_idx) { -+ if (unary_op.guard()) visitor_op.end_step(step_idx); -+ } -+ -+ CUTLASS_DEVICE -+ void end_epilogue() { -+ if (unary_op.guard()) visitor_op.end_epilogue(); -+ } -+}; -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_with_layernorm.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_with_layernorm.h -new file mode 100644 -index 0000000..54936ff ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_with_layernorm.h -@@ -0,0 +1,480 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this layernormware without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Epilogue visitor type used for partial computation of a layernorm operation -+ -+ GemmLayernorm example = GEMM0 with partial reduction fused in epilogue (EpilogueVisitorLayerNorm) -+ + lightweight full reduction kernel (ApplyFinalReduction) -+ + GEMM1 with elementwise operations fused in mainloop (GemmLayernormMainloopFusion) -+*/ -+ -+#pragma once -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/arch/memory.h" -+#include "cutlass/arch/memory_sm75.h" -+#include "cutlass/gemm/kernel/gemm_transpose_operands.h" -+#include "cutlass/gemm/kernel/default_gemm.h" -+#include "cutlass/gemm/kernel/default_gemm_complex.h" -+#include "cutlass/gemm/device/default_gemm_configuration.h" -+#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+namespace cutlass { -+namespace epilogue { -+namespace threadblock { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ThreadblockShape_, -+ int ThreadCount, -+ typename OutputTileIterator_, -+ typename AccumulatorTile_, -+ typename ElementAccumulator_, -+ typename ElementVariance_, -+ typename ElementMean_, -+ typename ElementLayernormCompute_, -+ typename ElementwiseFunctor_, -+ bool IsShiftedVariance_ = false -+> -+class EpilogueVisitorLayerNorm { -+public: -+ -+ using ElementVariance = ElementVariance_; -+ using ElementMean = ElementMean_; -+ using ElementLayernormCompute = ElementLayernormCompute_; -+ -+ using AccumulatorTile = AccumulatorTile_; -+ -+ using ThreadblockShape = ThreadblockShape_; -+ static int const kThreadCount = ThreadCount; -+ -+ using OutputTileIterator = OutputTileIterator_; -+ using ElementwiseFunctor = ElementwiseFunctor_; -+ -+ static int const kIterations = OutputTileIterator::kIterations; -+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; -+ static int const kRowIterations = OutputTileIterator::ThreadMap::Iterations::kRow; -+ -+ static int const kThreads = OutputTileIterator::ThreadMap::kThreads; -+ -+ static bool const kIsShiftedVariance = IsShiftedVariance_; -+ -+ using ElementOutput = typename OutputTileIterator::Element; -+ -+ static int const kDeltaRow = OutputTileIterator::ThreadMap::Delta::kRow; -+ -+ /// Array type used in Shift-K Layernorm -+ static int const kRowAccessCount = kIterations * kRowIterations; -+ -+ using ConvertedShiftFragment = Array; -+ -+ // Conducts manual transpose externally (already supported) for column major -+ using LayoutOutput = cutlass::layout::RowMajor; -+ -+ using ElementAccumulator = ElementAccumulator_; -+ -+ using AccumulatorFragment = Array; -+ using LayernormFragment = Array; -+ using OutputVector = Array; -+ using TensorRefD = TensorRef; -+ -+ static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth; -+ static int const kThreadsInColumn = kThreads / kThreadsPerRow; -+ static int const kHalfThreadsPerRow = (kThreadsPerRow >> 1); -+ -+ /// Argument structure -+ struct Arguments { -+ -+ typename ElementwiseFunctor::Params elementwise; -+ ElementVariance *ptr_Variance; -+ ElementMean *ptr_Mean; -+ ElementOutput *ptr_Shifted_K; -+ MatrixCoord extent; -+ -+ // -+ // Methods -+ // -+ Arguments(): -+ ptr_Variance(nullptr), -+ ptr_Mean(nullptr), -+ ptr_Shifted_K(nullptr) -+ { -+ -+ } -+ -+ Arguments( -+ typename ElementwiseFunctor::Params elementwise_, -+ ElementVariance *ptr_Variance, -+ ElementMean *ptr_Mean_, -+ ElementOutput *ptr_Shifted_K_ = nullptr, -+ MatrixCoord extent = MatrixCoord(0, 0) -+ ): -+ elementwise(elementwise_), -+ ptr_Variance(ptr_Variance), -+ ptr_Mean(ptr_Mean_), -+ ptr_Shifted_K(ptr_Shifted_K_), -+ extent(extent) -+ { -+ -+ } -+ }; -+ -+ struct Params { -+ -+ typename ElementwiseFunctor::Params elementwise; -+ ElementVariance *ptr_Variance; -+ ElementMean *ptr_Mean; -+ ElementOutput *ptr_Shifted_K; -+ MatrixCoord extent; -+ -+ // -+ // Methods -+ // -+ CUTLASS_HOST_DEVICE -+ Params(): -+ ptr_Variance(nullptr), -+ ptr_Mean(nullptr) -+ { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ Params(Arguments const &args): -+ elementwise(args.elementwise), -+ ptr_Variance(args.ptr_Variance), -+ ptr_Mean(args.ptr_Mean), -+ ptr_Shifted_K(args.ptr_Shifted_K), -+ extent(args.extent) -+ { -+ -+ } -+ }; -+ -+ /// Shared storage -+ struct SharedStorage { -+ -+ }; -+ -+private: -+ -+ Params const & params_; -+ SharedStorage & shared_storage_; -+ MatrixCoord extent_; -+ ElementwiseFunctor elementwise_; -+ -+ OutputTileIterator iterator_C_; -+ OutputTileIterator iterator_D_; -+ typename OutputTileIterator::Fragment fragment_C_; -+ typename OutputTileIterator::Fragment fragment_D_; -+ -+ ElementAccumulator alpha_; -+ ElementAccumulator beta_; -+ ConvertedShiftFragment shift_k_frag_; -+ -+ ElementLayernormCompute accum_sum_square_; -+ ElementLayernormCompute accum_sum_element_; -+ int thread_idx_; -+ -+ MatrixCoord thread_offset_; -+ -+ gemm::GemmCoord threadblock_tile_offset_; -+ -+public: -+ -+ CUTLASS_DEVICE -+ EpilogueVisitorLayerNorm( -+ Params const ¶ms, ///< Parameters routed to the epilogue -+ SharedStorage &shared_storage, ///< Shared storage needed by the functors here -+ MatrixCoord threadblock_offset, -+ gemm::GemmCoord threadblock_tile_offset, -+ int thread_idx, -+ OutputTileIterator destination_iterator, ///< Tile iterator for destination -+ OutputTileIterator source_iterator ///< Threadblock tile coordinate in GEMMM -+ ): -+ params_(params), -+ shared_storage_(shared_storage), -+ elementwise_(params.elementwise), -+ extent_(params.extent), -+ iterator_C_(source_iterator), -+ iterator_D_(destination_iterator), -+ threadblock_tile_offset_(threadblock_tile_offset), -+ thread_idx_(thread_idx) -+ { -+ alpha_ = (params.elementwise.alpha_ptr ? *params.elementwise.alpha_ptr : params.elementwise.alpha); -+ beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta); -+ -+ if (beta_ == ElementAccumulator()) { -+ iterator_C_.clear_mask(); -+ } -+ } -+ -+ /// Helper to indicate split-K behavior -+ CUTLASS_DEVICE -+ void set_k_partition( -+ int split_k_index, ///< Index of this threadblock within split-K partitioned scheme -+ int split_k_slices) { ///< Total number of split-K slices -+ -+ } -+ -+ /// Called to set the batch index -+ CUTLASS_DEVICE -+ void set_batch_index(int batch_idx) { -+ -+ } -+ -+ /// Called at the start of the epilogue just before iterating over accumulator slices -+ CUTLASS_DEVICE -+ void begin_epilogue() { -+ -+ // If shift-K feature is enabled, we load shift-k fragment -+ // at the very beginning of an epilogue -+ if (kIsShiftedVariance && params_.ptr_Shifted_K != nullptr) { -+ shift_k_frag_.clear(); -+ int thread_offset_row_base = iterator_D_.thread_start_row(); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int iter_idx = 0; iter_idx < kIterations; ++iter_idx) { -+ int step_offset = iter_idx * OutputTileIterator::Shape::kRow; -+ CUTLASS_PRAGMA_UNROLL -+ for (int rid = 0; rid < kRowIterations; ++rid) { -+ int row_step_offset = rid * kDeltaRow; -+ int row_offset = thread_offset_row_base + step_offset + row_step_offset; -+ bool is_load = (row_offset < extent_.row()); -+ shift_k_frag_[iter_idx * kRowIterations + rid] = load_shift_k_(row_offset, is_load); -+ } -+ -+ } -+ -+ } -+ -+ } -+ -+ /// Called at the start of one step before starting accumulator exchange -+ CUTLASS_DEVICE -+ void begin_step(int step_idx) { -+ fragment_D_.clear(); -+ -+ if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { -+ fragment_C_.clear(); -+ iterator_C_.load(fragment_C_); -+ ++iterator_C_; -+ } -+ } -+ -+ /// Called at the start of a row -+ CUTLASS_DEVICE -+ void begin_row(int row_idx) { -+ /// set the accumulator to 0 -+ accum_sum_element_ = ElementLayernormCompute(0); -+ accum_sum_square_ = ElementLayernormCompute(0); -+ } -+ -+ /// Called after accumulators have been exchanged for each accumulator vector -+ CUTLASS_DEVICE -+ void visit( -+ int iter_idx, -+ int row_idx, -+ int column_idx, -+ int frag_idx, -+ AccumulatorFragment const &accum) { -+ -+ using Mul = cutlass::multiplies; -+ using Minus = cutlass::minus; -+ using Exp = cutlass::fast_exp_op; -+ -+ Minus minus; -+ Mul mul; -+ Exp exponential; -+ -+ LayernormFragment result; -+ -+ thread_offset_ = -+ iterator_D_.thread_start() + -+ OutputTileIterator::ThreadMap::iteration_offset(frag_idx); -+ -+ NumericArrayConverter source_converter; -+ OutputVector &source_vector = reinterpret_cast(&fragment_C_)[frag_idx]; -+ -+ bool column_guard = (thread_offset_.column() < extent_.column()); -+ -+ if (elementwise_.kScale == cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { -+ result = source_converter(elementwise_(accum)); -+ }else{ -+ result = source_converter(elementwise_(accum, source_vector)); -+ } -+ -+ -+ ElementLayernormCompute inv_scalar = cutlass::constants::one() / ElementLayernormCompute(extent_.column()); -+ -+ // Fragment is cleared for non-reachable columns so no need to check against column guard -+ ElementLayernormCompute accum_sum_element_tmp = element_sum_accumulator_(result); -+ -+ // Square sum is different. Non-reachable columns should've been computed for shift-k -+ // Otherwise we will incorrectly have some extra k^2 added into square sum. -+ ElementLayernormCompute accum_sum_square_tmp = ElementLayernormCompute(0); -+ -+ if (column_guard) { -+ accum_sum_square_tmp = (kIsShiftedVariance) ? \ -+ square_sum_accumulator_(result, shift_k_frag_[iter_idx * kRowIterations + row_idx]) : \ -+ square_sum_accumulator_(result); -+ } -+ -+ accum_sum_element_tmp *= inv_scalar; -+ accum_sum_square_tmp *= inv_scalar; -+ -+ // After performing the in-thread reduction, we then perform cross-thread / in-warp reduction -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = kHalfThreadsPerRow; i > 0; i >>= 1) { -+ accum_sum_element_tmp += __shfl_xor_sync(0xFFFFFFFF, accum_sum_element_tmp, i); -+ accum_sum_square_tmp += __shfl_xor_sync(0xFFFFFFFF, accum_sum_square_tmp, i); -+ } -+ accum_sum_element_ += accum_sum_element_tmp; -+ accum_sum_square_ += accum_sum_square_tmp; -+ -+ // Convert to the output -+ NumericArrayConverter output_converter; -+ OutputVector &output = reinterpret_cast(&fragment_D_)[frag_idx]; -+ output = output_converter(result); -+ } -+ -+ /// Called at the start of a row -+ CUTLASS_DEVICE -+ void end_row(int row_idx) { -+ -+ using ConvertVarianceOutput = cutlass::NumericConverter; -+ using ConvertMeanOutput = cutlass::NumericConverter; -+ -+ ConvertVarianceOutput convert_variance_output; -+ ConvertMeanOutput convert_mean_output; -+ -+ bool is_write_thread = (thread_offset_.row() < extent_.row() && (threadIdx.x % kThreadsPerRow) == 0); -+ int row_offset = thread_offset_.row() + threadblock_tile_offset_.n() * extent_.row(); -+ -+ ElementVariance *curr_ptr_sum_square = params_.ptr_Variance + row_offset; -+ ElementMean *curr_ptr_element_sum = params_.ptr_Mean + row_offset; -+ -+ arch::global_store( -+ convert_variance_output(accum_sum_square_), -+ (void *)curr_ptr_sum_square, -+ is_write_thread); -+ -+ arch::global_store( -+ convert_mean_output(accum_sum_element_), -+ (void *)curr_ptr_element_sum, -+ is_write_thread); -+ } -+ -+ /// Called after all accumulator elements have been visited -+ CUTLASS_DEVICE -+ void end_step(int step_idx) { -+ -+ iterator_D_.store(fragment_D_); -+ ++iterator_D_; -+ } -+ -+ /// Called after all steps have been completed -+ CUTLASS_DEVICE -+ void end_epilogue() { -+ -+ } -+ -+private: -+ -+ CUTLASS_DEVICE -+ ElementLayernormCompute load_shift_k_(int row_offset, bool is_load) { -+ using ConvertShiftK = cutlass::NumericConverter; -+ ConvertShiftK convert_shift_k; -+ ElementOutput shift_k_val; -+ -+ // Computes the address to load shift_k element -+ ElementOutput *curr_ptr_shift_k = params_.ptr_Shifted_K + row_offset; -+ // Conditionally loads from global memory -+ arch::global_load(shift_k_val, (void *)curr_ptr_shift_k, is_load); -+ // Converts data type to return -+ ElementLayernormCompute converted_shift_k_val = convert_shift_k(shift_k_val); -+ -+ return converted_shift_k_val; -+ } -+ -+ CUTLASS_DEVICE -+ ElementLayernormCompute square_sum_accumulator_(LayernormFragment const &accum) { -+ ElementLayernormCompute sum_ = ElementLayernormCompute(0); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < LayernormFragment::kElements; ++i) { -+ auto accum_ = accum[i]; -+ sum_ += accum_ * accum_; -+ } -+ -+ return sum_; -+ } -+ -+ CUTLASS_DEVICE -+ ElementLayernormCompute square_sum_accumulator_(LayernormFragment const &accum, ElementLayernormCompute shift_k_val) { -+ ElementLayernormCompute sum_ = ElementLayernormCompute(0); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < LayernormFragment::kElements; ++i) { -+ auto accum_ = accum[i] - shift_k_val; -+ sum_ += accum_ * accum_; -+ } -+ -+ return sum_; -+ } -+ -+ CUTLASS_DEVICE -+ ElementLayernormCompute element_sum_accumulator_(LayernormFragment const &accum) { -+ ElementLayernormCompute sum_ = ElementLayernormCompute(0); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < LayernormFragment::kElements; ++i) { -+ sum_ += accum[i]; -+ } -+ -+ return sum_; -+ } -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/gemm/gemm.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/gemm/gemm.h -new file mode 100644 -index 0000000..36987b5 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/gemm/gemm.h -@@ -0,0 +1,77 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Bind gemm related enum types to python -+*/ -+#pragma once -+#include -+#include -+ -+#include "cutlass/gemm/gemm.h" -+#include "host.h" -+ -+namespace py = pybind11; -+ -+void bind_gemm(py::module &m) { -+ // -+ // Enumerate types -+ // cutlass/gemm/gemm.h -+ -+ py::enum_(m, "Mode") -+ .value("Gemm", cutlass::gemm::GemmUniversalMode::kGemm, "Ordinary GEMM & GEMM Split-K serial") -+ .value("GemmSplitKParallel", cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel, "GEMM Split-K parallel") -+ .value("Batched", cutlass::gemm::GemmUniversalMode::kBatched, "Batched GEMM") -+ .value("Array", cutlass::gemm::GemmUniversalMode::kArray) -+ .value("Invalid", cutlass::gemm::GemmUniversalMode::kInvalid); -+ -+ /// GemmCoord is a structure that specifies a location within the coordiate space of a GEMM problem -+ py::class_(m, "GemmCoord") -+ .def(py::init()) -+ .def("m", py::overload_cast<>(&cutlass::gemm::GemmCoord::m)) -+ .def("n", py::overload_cast<>(&cutlass::gemm::GemmCoord::n)) -+ .def("k", py::overload_cast<>(&cutlass::gemm::GemmCoord::k)) -+ // get tensor coords -+ .def("mk", -+ [](const cutlass::gemm::GemmCoord & problem_size) { -+ return cutlass::MatrixCoord(problem_size.mk()); -+ }) -+ .def("kn", -+ [](const cutlass::gemm::GemmCoord & problem_size) { -+ return cutlass::MatrixCoord(problem_size.kn()); -+ }) -+ .def("mn", -+ [](const cutlass::gemm::GemmCoord & problem_size) { -+ return cutlass::MatrixCoord(problem_size.mn()); -+ }); -+ -+ py::module_ host_submodule = m.def_submodule("host"); -+ bind_gemm_host_helper(host_submodule); -+} -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/gemm/gemm_universal_with_visitor.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/gemm/gemm_universal_with_visitor.h -new file mode 100644 -index 0000000..64b65a0 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/gemm/gemm_universal_with_visitor.h -@@ -0,0 +1,628 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/fast_math.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/gemm/kernel/params_universal_base.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/semaphore.h" -+ -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/trace.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace gemm { -+namespace kernel { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate -+ typename Epilogue_, ///! Epilogue -+ typename ThreadblockSwizzle_ ///! Threadblock swizzling function -+> -+struct GemmUniversalwithEpilogueVisitor { -+public: -+ -+ using Mma = Mma_; -+ using Epilogue = Epilogue_; -+ using EpilogueVisitor = typename Epilogue::Visitor; -+ using ThreadblockSwizzle = ThreadblockSwizzle_; -+ -+ using ElementA = typename Mma::IteratorA::Element; -+ using LayoutA = typename Mma::IteratorA::Layout; -+ using ElementB = typename Mma::IteratorB::Element; -+ using LayoutB = typename Mma::IteratorB::Layout; -+ using ElementC = typename EpilogueVisitor::ElementOutput; -+ using LayoutC = typename EpilogueVisitor::OutputTileIterator::Layout; -+ -+ static ComplexTransform const kTransformA = Mma::kTransformA; -+ static ComplexTransform const kTransformB = Mma::kTransformB; -+ using Operator = typename Mma::Operator; -+ -+ using OperatorClass = typename Mma::Operator::OperatorClass; -+ using ThreadblockShape = typename Mma::Shape; -+ using WarpShape = typename Mma::Operator::Shape; -+ using InstructionShape = typename Mma::Policy::Operator::InstructionShape; -+ using ArchTag = typename Mma::ArchTag; -+ -+ static int const kStages = Mma::kStages; -+ static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess; -+ -+ /// Warp count (concept: GemmShape) -+ using WarpCount = typename Mma::WarpCount; -+ static int const kThreadCount = 32 * WarpCount::kCount; -+ -+ /// Split-K preserves splits that are 128b aligned -+ static int const kSplitKAlignment = const_max( -+ 128 / sizeof_bits::value, -+ 128 / sizeof_bits::value -+ ); -+ -+ // -+ // Structures -+ // -+ -+ /// Argument structure -+ struct Arguments : UniversalArgumentsBase { -+ -+ // -+ // Data members -+ // -+ -+ typename EpilogueVisitor::Arguments epilogue_visitor; -+ -+ void const * ptr_A; -+ void const * ptr_B; -+ void const * ptr_C; -+ void * ptr_D; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C; -+ -+ typename LayoutA::Stride stride_a; -+ typename LayoutB::Stride stride_b; -+ typename LayoutC::Stride stride_c; -+ typename LayoutC::Stride stride_d; -+ -+ typename LayoutA::Stride::LongIndex lda; -+ typename LayoutB::Stride::LongIndex ldb; -+ typename LayoutC::Stride::LongIndex ldc; -+ typename LayoutC::Stride::LongIndex ldd; -+ -+ int const * ptr_gather_A_indices; -+ int const * ptr_gather_B_indices; -+ int const * ptr_scatter_D_indices; -+ -+ // -+ // Methods -+ // -+ -+ Arguments(): -+ ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr), -+ ptr_gather_A_indices(nullptr), -+ ptr_gather_B_indices(nullptr), -+ ptr_scatter_D_indices(nullptr) {} -+ -+ /// constructs an arguments structure -+ Arguments( -+ GemmUniversalMode mode, -+ GemmCoord problem_size, -+ int batch_count, -+ typename EpilogueVisitor::Arguments epilogue_visitor, -+ void const * ptr_A, -+ void const * ptr_B, -+ void const * ptr_C, -+ void * ptr_D, -+ int64_t batch_stride_A, -+ int64_t batch_stride_B, -+ int64_t batch_stride_C, -+ int64_t batch_stride_D, -+ typename LayoutA::Stride stride_a, -+ typename LayoutB::Stride stride_b, -+ typename LayoutC::Stride stride_c, -+ typename LayoutC::Stride stride_d, -+ int const *ptr_gather_A_indices = nullptr, -+ int const *ptr_gather_B_indices = nullptr, -+ int const *ptr_scatter_D_indices = nullptr -+ ): -+ UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), -+ epilogue_visitor(epilogue_visitor), -+ ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), -+ batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), -+ stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d), -+ ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices), -+ ptr_scatter_D_indices(ptr_scatter_D_indices) { -+ lda = 0; -+ ldb = 0; -+ ldc = 0; -+ ldd = 0; -+ CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size); -+ } -+ -+ /// constructs an arguments structure -+ Arguments( -+ GemmUniversalMode mode, -+ GemmCoord problem_size, -+ int batch_count, -+ typename EpilogueVisitor::Arguments epilogue_visitor, -+ void const * ptr_A, -+ void const * ptr_B, -+ void const * ptr_C, -+ void * ptr_D, -+ int64_t batch_stride_A, -+ int64_t batch_stride_B, -+ int64_t batch_stride_C, -+ int64_t batch_stride_D, -+ typename LayoutA::Stride::LongIndex lda, -+ typename LayoutB::Stride::LongIndex ldb, -+ typename LayoutC::Stride::LongIndex ldc, -+ typename LayoutC::Stride::LongIndex ldd, -+ int const *ptr_gather_A_indices = nullptr, -+ int const *ptr_gather_B_indices = nullptr, -+ int const *ptr_scatter_D_indices = nullptr -+ ): -+ UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), -+ epilogue_visitor(epilogue_visitor), -+ ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), -+ batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), -+ lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), -+ ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices), -+ ptr_scatter_D_indices(ptr_scatter_D_indices) { -+ stride_a = make_Coord(lda); -+ stride_b = make_Coord(ldb); -+ stride_c = make_Coord(ldc); -+ stride_d = make_Coord(ldd); -+ CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size); -+ } -+ -+ /// Returns arguments for the transposed problem -+ Arguments transposed_problem() const { -+ Arguments args(*this); -+ -+ std::swap(args.problem_size.m(), args.problem_size.n()); -+ std::swap(args.ptr_A, args.ptr_B); -+ std::swap(args.lda, args.ldb); -+ std::swap(args.stride_a, args.stride_b); -+ std::swap(args.batch_stride_A, args.batch_stride_B); -+ std::swap(args.ptr_gather_A_indices, args.ptr_gather_B_indices); -+ -+ return args; -+ } -+ }; -+ -+ // -+ // Structure for precomputing values in host memory and passing to kernels -+ // -+ -+ /// Parameters structure -+ struct Params : UniversalParamsBase< -+ ThreadblockSwizzle, -+ ThreadblockShape, -+ ElementA, -+ ElementB, -+ ElementC> { -+ -+ using ParamsBase = UniversalParamsBase< -+ ThreadblockSwizzle, -+ ThreadblockShape, -+ ElementA, -+ ElementB, -+ ElementC>; -+ -+ typename Mma::IteratorA::Params params_A; -+ typename Mma::IteratorB::Params params_B; -+ typename EpilogueVisitor::OutputTileIterator::Params params_C; -+ typename EpilogueVisitor::OutputTileIterator::Params params_D; -+ -+ typename EpilogueVisitor::Params epilogue_visitor; -+ -+ void * ptr_A; -+ void * ptr_B; -+ void * ptr_C; -+ void * ptr_D; -+ -+ int64_t batch_stride_A; -+ int64_t batch_stride_B; -+ int64_t batch_stride_C; -+ -+ int * ptr_gather_A_indices; -+ int * ptr_gather_B_indices; -+ int * ptr_scatter_D_indices; -+ -+ int *semaphore; -+ -+ // -+ // Methods -+ // -+ -+ /// Default constructor -+ Params() = default; -+ -+ CUTLASS_HOST_DEVICE -+ Params( -+ Arguments const &args, -+ int device_sms, -+ int sm_occupancy -+ ): -+ ParamsBase(args, device_sms, sm_occupancy), -+ params_A(args.lda ? make_Coord_with_padding(args.lda) : args.stride_a), -+ params_B(args.ldb ? make_Coord_with_padding(args.ldb) : args.stride_b), -+ params_C(args.ldc ? make_Coord_with_padding(args.ldc) : args.stride_c), -+ params_D(args.ldd ? make_Coord_with_padding(args.ldd) : args.stride_d), -+ epilogue_visitor(args.epilogue_visitor), -+ ptr_A(const_cast(args.ptr_A)), -+ ptr_B(const_cast(args.ptr_B)), -+ ptr_C(const_cast(args.ptr_C)), -+ ptr_D(args.ptr_D), -+ batch_stride_A(args.batch_stride_A), -+ batch_stride_B(args.batch_stride_B), -+ batch_stride_C(args.batch_stride_C), -+ ptr_gather_A_indices(const_cast(args.ptr_gather_A_indices)), -+ ptr_gather_B_indices(const_cast(args.ptr_gather_B_indices)), -+ ptr_scatter_D_indices(const_cast(args.ptr_scatter_D_indices)) { -+ -+ } -+ -+ CUTLASS_HOST_DEVICE -+ void update( -+ Arguments const &args, -+ void *workspace = nullptr) { -+ -+ ptr_A = const_cast(args.ptr_A); -+ ptr_B = const_cast(args.ptr_B); -+ ptr_C = const_cast(args.ptr_C); -+ ptr_D = args.ptr_D; -+ -+ ptr_gather_A_indices = const_cast(args.ptr_gather_A_indices); -+ ptr_gather_B_indices = const_cast(args.ptr_gather_B_indices); -+ ptr_scatter_D_indices = const_cast(args.ptr_scatter_D_indices); -+ -+ batch_stride_A = args.batch_stride_A; -+ batch_stride_B = args.batch_stride_B; -+ batch_stride_C = args.batch_stride_C; -+ -+ epilogue_visitor = args.epilogue_visitor; -+ -+ semaphore = static_cast(workspace); -+ CUTLASS_TRACE_HOST("GemmUniversal::Params::update()"); -+ } -+ }; -+ -+ /// Shared memory storage structure -+ union SharedStorage { -+ typename Mma::SharedStorage main_loop; -+ typename Epilogue::SharedStorage epilogue; -+ typename EpilogueVisitor::SharedStorage visitor; -+ }; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_DEVICE -+ GemmUniversalwithEpilogueVisitor() { } -+ -+ /// Determines whether kernel satisfies alignment -+ static Status can_implement( -+ cutlass::gemm::GemmCoord const & problem_size) { -+ -+ CUTLASS_TRACE_HOST("GemmUniversalwithEpilogueVisitor::can_implement()"); -+ -+ static int const kAlignmentA = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 64 -+ : Mma::IteratorA::AccessType::kElements; -+ static int const kAlignmentB = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 64 -+ : Mma::IteratorB::AccessType::kElements; -+ static int const kAlignmentC = (platform::is_same>::value) -+ ? 32 -+ : (platform::is_same>::value) -+ ? 64 -+ : Epilogue::OutputTileIterator::kElementsPerAccess; -+ -+ bool isAMisaligned = false; -+ bool isBMisaligned = false; -+ bool isCMisaligned = false; -+ -+ if (platform::is_same::value) { -+ isAMisaligned = problem_size.k() % kAlignmentA; -+ } else if (platform::is_same::value) { -+ isAMisaligned = problem_size.m() % kAlignmentA; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isAMisaligned = problem_size.k() % kAlignmentA; -+ } -+ -+ if (platform::is_same::value) { -+ isBMisaligned = problem_size.n() % kAlignmentB; -+ } else if (platform::is_same::value) { -+ isBMisaligned = problem_size.k() % kAlignmentB; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isBMisaligned = problem_size.k() % kAlignmentB; -+ } -+ -+ if (platform::is_same::value) { -+ isCMisaligned = problem_size.n() % kAlignmentC; -+ } else if (platform::is_same::value) { -+ isCMisaligned = problem_size.m() % kAlignmentC; -+ } else if (platform::is_same>::value -+ || platform::is_same>::value) { -+ isCMisaligned = problem_size.n() % kAlignmentC; -+ } -+ -+ if (isAMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (isBMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ if (isCMisaligned) { -+ CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); -+ return Status::kErrorMisalignedOperand; -+ } -+ -+ CUTLASS_TRACE_HOST(" returning kSuccess"); -+ -+ return Status::kSuccess; -+ } -+ -+ static Status can_implement(Arguments const &args) { -+ return can_implement(args.problem_size); -+ } -+ -+ /// Executes one GEMM -+ CUTLASS_DEVICE -+ void operator()(Params const ¶ms, SharedStorage &shared_storage) { -+ -+ // Compute threadblock location -+ ThreadblockSwizzle threadblock_swizzle; -+ -+ cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ // Early exit if CTA is out of range -+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || -+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { -+ -+ return; -+ } -+ -+ int offset_k = 0; -+ int problem_size_k = params.problem_size.k(); -+ -+ ElementA *ptr_A = static_cast(params.ptr_A); -+ ElementB *ptr_B = static_cast(params.ptr_B); -+ -+ // -+ // Fetch pointers based on mode. -+ // -+ if (params.mode == GemmUniversalMode::kGemm || -+ params.mode == GemmUniversalMode::kGemmSplitKParallel) { -+ -+ if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { -+ -+ problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; -+ } -+ -+ offset_k = threadblock_tile_offset.k() * params.gemm_k_size; -+ } -+ else if (params.mode == GemmUniversalMode::kBatched) { -+ ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; -+ ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; -+ } -+ else if (params.mode == GemmUniversalMode::kArray) { -+ ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; -+ ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; -+ } -+ -+ __syncthreads(); -+ -+ // Compute initial location in logical coordinates -+ cutlass::MatrixCoord tb_offset_A{ -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ offset_k, -+ }; -+ -+ cutlass::MatrixCoord tb_offset_B{ -+ offset_k, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ }; -+ -+ // Compute position within threadblock -+ int thread_idx = threadIdx.x; -+ -+ // Construct iterators to A and B operands -+ typename Mma::IteratorA iterator_A( -+ params.params_A, -+ ptr_A, -+ {params.problem_size.m(), problem_size_k}, -+ thread_idx, -+ tb_offset_A, -+ params.ptr_gather_A_indices); -+ -+ typename Mma::IteratorB iterator_B( -+ params.params_B, -+ ptr_B, -+ {problem_size_k, params.problem_size.n()}, -+ thread_idx, -+ tb_offset_B, -+ params.ptr_gather_B_indices); -+ -+ // Broadcast the warp_id computed by lane 0 to ensure dependent code -+ // is compiled as warp-uniform. -+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); -+ -+ int lane_idx = threadIdx.x % 32; -+ -+ // -+ // Main loop -+ // -+ -+ // Construct thread-scoped matrix multiply -+ Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); -+ -+ typename Mma::FragmentC accumulators; -+ -+ accumulators.clear(); -+ -+ // Compute threadblock-scoped matrix multiply-add -+ int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; -+ -+ // Compute threadblock-scoped matrix multiply-add -+ mma( -+ gemm_k_iterations, -+ accumulators, -+ iterator_A, -+ iterator_B, -+ accumulators); -+ -+ // -+ // Epilogue -+ // -+ -+ // EpilogueOutputOp output_op(params.output_op); -+ -+ // -+ // Masked tile iterators constructed from members -+ // -+ -+ threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); -+ -+ //assume identity swizzle -+ MatrixCoord threadblock_offset( -+ threadblock_tile_offset.m() * Mma::Shape::kM, -+ threadblock_tile_offset.n() * Mma::Shape::kN -+ ); -+ -+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); -+ -+ ElementC *ptr_C = static_cast(params.ptr_C); -+ ElementC *ptr_D = static_cast(params.ptr_D); -+ -+ // -+ // Fetch pointers based on mode. -+ // -+ -+ // Construct the semaphore. -+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); -+ -+ // Tile iterator loading from source tensor. -+ -+ EpilogueVisitor epilogue_visitor( -+ params.epilogue_visitor, -+ shared_storage.visitor, -+ threadblock_offset, -+ threadblock_tile_offset, -+ thread_idx, -+ params.problem_size.mn() -+ ); -+ -+ if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) { -+ epilogue_visitor.set_batch_index(threadblock_tile_offset.k()); -+ } -+ -+ Epilogue epilogue( -+ shared_storage.epilogue, -+ thread_idx, -+ warp_idx, -+ lane_idx); -+ -+ // Wait on the semaphore - this latency may have been covered by iterator construction -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor. -+ semaphore.wait(threadblock_tile_offset.k()); -+ } -+ -+ -+ // Execute the epilogue operator to update the destination tensor. -+ epilogue(epilogue_visitor, accumulators); -+ -+ // -+ // Release the semaphore -+ // -+ -+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { -+ -+ int lock = 0; -+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { -+ -+ // The final threadblock resets the semaphore for subsequent grids. -+ lock = 0; -+ } -+ else { -+ // Otherwise, the semaphore is incremented -+ lock = threadblock_tile_offset.k() + 1; -+ } -+ -+ semaphore.release(lock); -+ } -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace gemm -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/gemm/host.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/gemm/host.h -new file mode 100644 -index 0000000..3a6a587 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/gemm/host.h -@@ -0,0 +1,47 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Bind gemm host helpers to python -+*/ -+#pragma once -+#include -+#include -+ -+#include "cutlass/util/host_reorder.h" -+#include "cutlass/layout/tensor.h" -+ -+namespace py = pybind11; -+ -+ -+void bind_gemm_host_helper(py::module &m) { -+ m.def("reorder_column", &cutlass::reorder_column<32, int8_t, cutlass::layout::RowMajorInterleaved<32>>); -+ m.def("reorder_column", &cutlass::reorder_column<32, int8_t, cutlass::layout::ColumnMajorInterleaved<32>>); -+} -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/layout/layout.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/layout/layout.h -new file mode 100644 -index 0000000..5968bc0 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/layout/layout.h -@@ -0,0 +1,47 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Bind CUTLASS layouts to python -+*/ -+#pragma once -+#include -+#include -+ -+#include "tensor.h" -+#include "matrix.h" -+ -+ -+namespace py = pybind11; -+ -+void bind_layout(py::module &m) { -+ bind_tensor_layout(m); -+ bind_matrix_layout(m); -+} -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/layout/matrix.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/layout/matrix.h -new file mode 100644 -index 0000000..f19e04e ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/layout/matrix.h -@@ -0,0 +1,87 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Bind Matrix layouts to python -+*/ -+#pragma once -+#include -+#include -+ -+#include "cutlass/layout/matrix.h" -+ -+namespace py = pybind11; -+ -+void bind_matrix_layout(py::module &m) { -+ // -+ // Matrix layouts -+ // cutlass/layout/matrix.h -+ // -+ -+ py::class_(m, "RowMajor", R"pbdoc( -+ Mapping function for row-major matrices. -+ )pbdoc") -+ .def_static("packed", &cutlass::layout::RowMajor::packed, -+ py::arg("extent"), -+ R"pbdoc(Helper returns a layout to a tightly packed tensor)pbdoc") -+ .def("stride", [](const cutlass::layout::RowMajor & layout){ -+ return layout.stride().at(0); -+ }, R"pbdoc(Returns the stride of the layout)pbdoc"); -+ -+ py::class_(m, "ColumnMajor", R"pbdoc( -+ Mapping function for column-major matrices. -+ )pbdoc") -+ .def_static("packed", &cutlass::layout::ColumnMajor::packed, -+ py::arg("extent"), -+ R"pbdoc(Helper returns a layout to a tightly packed tensor)pbdoc" ) -+ .def("stride", [](const cutlass::layout::ColumnMajor & layout){ -+ return layout.stride().at(0); -+ }, R"pbdoc(Returns the stride of the layout)pbdoc"); -+ -+ py::class_>(m, "RowMajorInterleaved32", -+ R"pbdoc(Mapping function for interleaved matrices. Matrix is structured -+ as row-major arrangement of fixed-size columns 32)pbdoc") -+ .def_static("packed", &cutlass::layout::RowMajorInterleaved<32>::packed, -+ py::arg("extent"), -+ R"pbdoc(Helper returns a layout to a tightly packed tensor)pbdoc") -+ .def("stride", [](const cutlass::layout::RowMajorInterleaved<32> & layout){ -+ return layout.stride().at(0); -+ }, R"pbdoc(Returns the stride of the layout)pbdoc"); -+ -+ py::class_>(m, "ColumnMajorInterleaved32", -+ R"pbdoc(Mapping function for interleaved matrices. Matrix is structured -+ as column-major arrangement of fixed-size rows 32)pbdoc") -+ .def_static("packed", &cutlass::layout::ColumnMajorInterleaved<32>::packed, -+ py::arg("extent"), -+ R"pbdoc(Helper returns a layout to a tightly packed tensor)pbdoc") -+ .def("stride", [](const cutlass::layout::ColumnMajorInterleaved<32> & layout){ -+ return layout.stride().at(0); -+ }, R"pbdoc(Returns the stride of the layout)pbdoc"); -+} -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/layout/tensor.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/layout/tensor.h -new file mode 100644 -index 0000000..5edb100 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/layout/tensor.h -@@ -0,0 +1,74 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Bind Tensor layouts to python -+*/ -+#pragma once -+#include -+#include -+ -+#include "cutlass/layout/tensor.h" -+ -+namespace py = pybind11; -+ -+void bind_tensor_layout(py::module &m) { -+ // -+ // Tensor layouts -+ // cutlass/include/cutlass/layout/tensor.h -+ // -+ -+ /// Mapping function for 4-D NHWC tensors. -+ py::class_(m, "TensorNHWC", -+ R"pbdoc(Mapping function for 4-D NHWC tensors)pbdoc") -+ .def_static("packed", &cutlass::layout::TensorNHWC::packed, -+ py::arg("extent"), -+ R"pbdoc(Helper returns a layout to a tightly packed NHWC tensor)pbdoc") -+ .def("stride", py::overload_cast<>(&cutlass::layout::TensorNHWC::stride), -+ R"pbdoc(Returns the stride of the layout)pbdoc"); -+ -+ /// Mapping function for 4-D NC/xHWx tensors. -+ py::class_>(m, "TensorNC32HW32", -+ R"pbdoc(Mapping function for 4-D NC/32HW32 tensors)pbdoc") -+ .def_static("packed", &cutlass::layout::TensorNCxHWx<32>::packed, -+ py::arg("extent"), -+ R"pbdoc(Helper returns a layout to a tightly packed tensor)pbdoc") -+ .def("stride", py::overload_cast<>(&cutlass::layout::TensorNCxHWx<32>::stride), -+ R"pbdoc(Returns the stride of the layout)pbdoc"); -+ -+ /// Mapping function for 4-D CxRSKx tensors. -+ py::class_>(m, "TensorC32RSK32", -+ R"pbdoc(Mapping function for 4-D C32RSK32 tensors)pbdoc") -+ .def_static("packed", &cutlass::layout::TensorCxRSKx<32>::packed, -+ py::arg("extent"), -+ R"pbdoc(Helper returns a layout to a tightly packed tensor)pbdoc") -+ .def("stride", py::overload_cast<>(&cutlass::layout::TensorCxRSKx<32>::stride), -+ R"pbdoc(Returns the stride of the layout)pbdoc"); -+} -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/swizzling.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/swizzling.h -new file mode 100644 -index 0000000..43991e4 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/swizzling.h -@@ -0,0 +1,159 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Bind threadblock swizzling to python -+*/ -+#pragma once -+#include -+#include -+ -+#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -+#include "cutlass/conv/threadblock/threadblock_swizzle.h" -+ -+#include -+#include -+ -+namespace py = pybind11; -+ -+std::string demangle(const char* mangled_name) { -+ std::size_t len = 0; -+ int status = 0; -+ std::unique_ptr ptr( -+ __cxxabiv1::__cxa_demangle(mangled_name, nullptr, &len, &status)); -+ return ptr.get(); -+} -+ -+template -+void bind_identity_swizzle(py::module & m, std::string name) { -+ py::class_(m, name.c_str(), -+ R"pbdoc(Threadblock swizzling function for GEMMs)pbdoc") -+ .def(py::init<>()) -+ .def("get_tiled_shape", -+ py::overload_cast( -+ &T::get_tiled_shape, py::const_ -+ ), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"), -+ R"pbdoc(Returns the shape of the problem in units of logical tiles -+ -+ :param problem_size: gemm(M, N, K) -+ :type problem_size: :class:`cutlass.gemm.GemmCoord` -+ )pbdoc") -+ .def("get_tiled_shape", -+ py::overload_cast( -+ &T::get_tiled_shape, py::const_ -+ ), py::arg("conv_operator"), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"), -+ R"pbdoc(Returns the shape of the problem in units of logical tiles -+ -+ :param problem_size: Implicit gemm problem size conv_operator(NPQK, NHWC, KRSC) -+ :type problem_size: :class:`cutlass.gemm.GemmCoord`) -+ )pbdoc") -+ .def("get_tiled_shape", -+ py::overload_cast( -+ &T::get_tiled_shape, py::const_ -+ ), py::arg("conv_operator"), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"), -+ R"pbdoc(Returns the shape of the problem in units of logical tiles -+ -+ :param problem_size: Implicit gemm problem size conv_operator(NZPQK, NDHWC, KTRSC) -+ :type problem_size: :class:`cutlass.gemm.GemmCoord`) -+ )pbdoc") -+ .def("get_grid_shape", &T::get_grid_shape, -+ py::arg("tiled_shape"), -+ R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc") -+ .def("tag", [](const T & swizzle){ -+ return demangle(typeid(T).name()); -+ }, R"pbdoc(Returns the c++ name of the swizzling for code emittion)pbdoc"); -+} -+ -+template -+void bind_swizzle(py::module & m, std::string name, std::string doc) { -+ py::class_(m, name.c_str(), doc.c_str()) -+ .def(py::init<>()) -+ .def("get_tiled_shape", -+ py::overload_cast( -+ &T::get_tiled_shape, py::const_ -+ ), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"), -+ R"pbdoc(Returns the shape of the problem in units of logical tiles -+ -+ :param problem_size: gemm(M, N, K) -+ :type problem_size: :class:`cutlass.gemm.GemmCoord` -+ )pbdoc") -+ .def("get_grid_shape", &T::get_grid_shape, -+ py::arg("tiled_shape"), -+ R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc") -+ .def("tag", [](const T & swizzle){ -+ return demangle(typeid(T).name()); -+ }, R"pbdoc(Returns the c++ name of the swizzling for code emittion)pbdoc"); -+} -+ -+template -+void bind_dgrad_swizzle(py::module & m, std::string name) { -+ py::class_(m, name.c_str(), -+ R"pbdoc(Threadblock swizzling function for strided dgrad convolution)pbdoc") -+ .def(py::init<>()) -+ .def("get_tiled_shape", -+ py::overload_cast( -+ &T::get_tiled_shape, py::const_ -+ ), py::arg("conv_operator"), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"), -+ R"pbdoc(Returns the shape of the problem in units of logical tiles -+ -+ :param problem_size: Implicit gemm problem size conv_operator(NPQK, NHWC, KRSC) -+ :type problem_size: :class:`cutlass.gemm.GemmCoord`) -+ )pbdoc") -+ .def("get_grid_shape", [](const T & swizzle, cutlass::gemm::GemmCoord tiled_shape) { -+ return dim3(tiled_shape.m(), tiled_shape.n(), tiled_shape.k()); -+ }, py::arg("tiled_shape"), -+ R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc") -+ .def("tag", [](const T & swizzle){ -+ return demangle(typeid(T).name()); -+ }, R"pbdoc(Returns the c++ name of the swizzling for code emittion)pbdoc"); -+} -+ -+void bind_threadblock_swizzle(py::module &m) { -+ -+ py::class_(m, "dim3", -+ R"pbdoc(A int3 type xyz contains three integers)pbdoc") -+ .def(py::init(), -+ py::arg("x"), py::arg("y"), py::arg("z")) -+ .def_readwrite("x", &dim3::x, R"pbdoc(get value x)pbdoc") -+ .def_readwrite("y", &dim3::y, R"pbdoc(get value y)pbdoc") -+ .def_readwrite("z", &dim3::z, R"pbdoc(get value z)pbdoc"); -+ -+ bind_identity_swizzle>(m, "IdentitySwizzle1"); -+ bind_identity_swizzle>(m, "IdentitySwizzle2"); -+ bind_identity_swizzle>(m, "IdentitySwizzle4"); -+ bind_identity_swizzle>(m, "IdentitySwizzle8"); -+ -+ bind_swizzle(m, "HorizontalSwizzle", R"pbdoc(Threadblock swizzling function for GEMMs)pbdoc"); -+ bind_swizzle(m, "BatchedIdentitySwizzle", R"pbdoc(Threadblock swizzling function for batched GEMMs)pbdoc"); -+ -+ bind_dgrad_swizzle>(m, "StridedDgradIdentitySwizzle1"); -+ bind_dgrad_swizzle>(m, "StridedDgradIdentitySwizzle4"); -+ bind_dgrad_swizzle(m, "StridedDgradHorizontalSwizzle"); -+} -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/tensor_coord.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/tensor_coord.h -new file mode 100644 -index 0000000..547df07 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/tensor_coord.h -@@ -0,0 +1,78 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Bind Tensor Coord to python -+*/ -+#pragma once -+#include -+#include -+ -+#include "cutlass/tensor_coord.h" -+ -+namespace py = pybind11; -+ -+void bind_tensor_coord(py::module &m) { -+ // -+ // Tensor Coords -+ // cutlass/include/cutlass/tensor_coord.h -+ // -+ -+ /// Defines a canonical 4D coordinate used by tensor operations. -+ py::class_(m, "Tensor4DCoord", -+ R"pbdoc(Defines a canonical 4D coordinate used by tensor operations)pbdoc") -+ .def(py::init(), -+ py::arg("n"), py::arg("h"), py::arg("w"), py::arg("c"), -+ R"pbdoc(Helper to construct from N, H, W, and C)pbdoc") -+ .def("at", py::overload_cast(&cutlass::Tensor4DCoord::at), -+ py::arg("dim"), -+ R"pbdoc(Gets the index of a given Coord element)pbdoc") -+ .def("size", [](const cutlass::Tensor4DCoord & coord) { -+ return coord.at(0) * coord.at(1) * coord.at(2) * coord.at(3);}, -+ R"pbdoc(The size of the tensor coord)pbdoc"); -+ -+ py::class_>(m, "Tensor3DCoord", -+ R"pbdoc(Defines a canonical 3D coordinate used by tensor operations)pbdoc") -+ .def("at", py::overload_cast(&cutlass::Coord<3>::at), -+ py::arg("dim"), -+ R"pbdoc(Gets the index of a given Coord element)pbdoc"); -+ -+ // Matrix Size -+ py::class_(m, "MatrixCoord", -+ R"pbdoc(MatrixCoord wraps Coord<2, int> to provide a helper for accessing named dimensions. Classes -+ expecting a coordinate in the rank=2 index space of a matrix should use MatrixCoord.)pbdoc") -+ .def(py::init(), -+ py::arg("row"), py::arg("column"), R"pbdoc(Helper to construct from a row and column)pbdoc") -+ .def("row", py::overload_cast<>(&cutlass::MatrixCoord::row), -+ R"pbdoc(Returns the row of the coordinate)pbdoc") -+ .def("column", py::overload_cast<>(&cutlass::MatrixCoord::column), -+ R"pbdoc(Returns the column of the coordinate)pbdoc"); -+ -+} -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/tensor_ref_view.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/tensor_ref_view.h -new file mode 100644 -index 0000000..09a4add ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/tensor_ref_view.h -@@ -0,0 +1,102 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSE -+#include -+ -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "types.h" -+ -+ -+template -+void bind_tensor_ref_view(py::module &m, std::string name) { -+ py::class_>(m, ("TensorRef" + name).c_str()) -+ .def("__init__", [](cutlass::TensorRef& tensor_ref, int64_t address, const L& layout_ ) { -+ T* ptr = reinterpret_cast< T*>(address); -+ new (&tensor_ref) cutlass::TensorRef(ptr, layout_); -+ }) -+ .def("data", [](cutlass::TensorRef& tensor_ref) { -+ T* ptr = tensor_ref.data(); -+ return int64_t(ptr); -+ }) -+ .def("layout", py::overload_cast<>(&cutlass::TensorRef::layout)); -+ -+ m.def("get_tensor_ref", [](int64_t address, TF data, const L& layout_) { -+ T* ptr = reinterpret_cast(address); -+ cutlass::TensorRef tensor_ref = cutlass::TensorRef(ptr, layout_); -+ return tensor_ref; -+ }); -+ -+ py::class_>(m, ("TensorView" + name).c_str()) -+ .def(py::init&, const typename L::TensorCoord &>()); -+} -+ -+ -+void bind_tensor_refs_and_views(py::module &m) { -+ -+ /// float -+ bind_tensor_ref_view(m, "F32RowMajor"); -+ bind_tensor_ref_view(m, "F32ColumnMajor"); -+ bind_tensor_ref_view(m, "F32NHWC"); -+ -+ /// double -+ bind_tensor_ref_view(m, "F64RowMajor"); -+ bind_tensor_ref_view(m, "F64ColumnMajor"); -+ bind_tensor_ref_view(m, "F64NHWC"); -+ -+ // half_t -+ bind_tensor_ref_view(m, "F16RowMajor"); -+ bind_tensor_ref_view(m, "F16ColumnMajor"); -+ bind_tensor_ref_view(m, "F16NHWC"); -+ -+ // bfloat16 -+ bind_tensor_ref_view(m, "BF16RowMajor"); -+ bind_tensor_ref_view(m, "BF16ColumnMajor"); -+ bind_tensor_ref_view(m, "BF16NHWC"); -+ -+ // int8_t -+ bind_tensor_ref_view, cutlass::int8>(m, "S8RowMajorInterleaved32"); -+ bind_tensor_ref_view, cutlass::int8>(m, "S8ColumnMajorInterleaved32"); -+ bind_tensor_ref_view(m, "S8RowMajor"); -+ bind_tensor_ref_view(m, "S8ColumnMajor"); -+ bind_tensor_ref_view(m, "S8NHWC"); -+ bind_tensor_ref_view, cutlass::int8>(m, "S8NC32HW32"); -+ bind_tensor_ref_view, cutlass::int8>(m, "S8C32RSK32"); -+ -+ // int32_t -+ bind_tensor_ref_view(m, "S32RowMajor"); -+ bind_tensor_ref_view(m, "S32ColumnMajor"); -+ bind_tensor_ref_view(m, "S32NHWC"); -+} -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/types.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/types.h -new file mode 100644 -index 0000000..da16696 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/include/types.h -@@ -0,0 +1,146 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Bind CUTLASS types to python -+*/ -+#pragma once -+#include -+#include -+ -+#include "cutlass/half.h" -+ -+ -+namespace py = pybind11; -+ -+namespace cutlass { -+ -+/// IEEE 32-bit signed integer -+struct alignas(1) int8 { -+ int8_t storage; -+ explicit int8(int x) { -+ storage = int8_t(x); -+ } -+ explicit int8(float x) { -+ storage = int8_t(x); -+ } -+ -+ int8_t c_value(){return storage;} -+}; -+ -+/// IEEE 32-bit signed integer -+struct alignas(4) int32 { -+ int storage; -+ explicit int32(int x) { -+ storage = x; -+ } -+ explicit int32(float x) { -+ storage = int(x); -+ } -+ -+ int c_value(){return storage;} -+}; -+/// IEEE single-precision floating-point type -+struct alignas(4) float32 { -+ float storage; -+ explicit float32(float x) { -+ storage = x; -+ } -+ explicit float32(int x) { -+ storage = float(x); -+ } -+ float c_value(){return storage;} -+}; -+/// IEEE double-precision floating-point type -+struct alignas(4) float64 { -+ double storage; -+ explicit float64(float x) { -+ storage = double(x); -+ } -+ explicit float64(int x) { -+ storage = double(x); -+ } -+ double c_value(){return storage;} -+}; -+} -+ -+void bind_cutlass_types(py::module &m) { -+ -+ // s8 -+ py::class_(m, "int8") -+ .def(py::init()) -+ .def(py::init()) -+ .def_readwrite("storage", &cutlass::int8::storage) -+ .def("value", &cutlass::int8::c_value); -+ -+ // s32 -+ py::class_(m, "int32") -+ .def(py::init()) -+ .def(py::init()) -+ .def_readwrite("storage", &cutlass::int32::storage) -+ .def("value", &cutlass::int32::c_value); -+ -+ // f16 -+ py::class_(m, "float16") -+ .def(py::init()) -+ .def(py::init()) -+ .def(py::init()) -+ .def(py::init()) -+ .def_readwrite("storage", &cutlass::half_t::storage) -+ .def("value", [](const cutlass::half_t& value) {return value;}); -+ -+ // bf16 -+ py::class_(m, "bfloat16") -+ .def(py::init()) -+ .def(py::init()) -+ .def_readwrite("storage", &cutlass::bfloat16_t::storage) -+ .def("value", [](const cutlass::bfloat16_t& value) {return value;}); -+ -+ // f32 -+ py::class_(m, "float32") -+ .def(py::init()) -+ .def(py::init()) -+ .def_readwrite("storage", &cutlass::float32::storage) -+ .def("value", &cutlass::float32::c_value); -+ -+ // tf32 -+ py::class_(m, "tfloat32") -+ .def(py::init()) -+ .def(py::init()) -+ .def_readwrite("storage", &cutlass::tfloat32_t::storage) -+ .def("value", [](const cutlass::tfloat32_t& value) {return value;}); -+ -+ // f64 -+ py::class_(m, "float64") -+ .def(py::init()) -+ .def(py::init()) -+ .def_readwrite("storage", &cutlass::float64::storage) -+ .def("value", &cutlass::float64::c_value); -+} -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/library.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/library.h -new file mode 100644 -index 0000000..5d46f69 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/library.h -@@ -0,0 +1,32 @@ -+#include -+ -+namespace cutlass { -+ -+/// ENUM class for datatypes -+enum class DataType { -+ kB1, kU2, kU4, kU8, -+ kU16, kU32, kU64, kS2, -+ kS4, kS8, kS16, kS32, -+ kS64, kF16, kBF16, kF32, -+ kTF32, kF64, kCF16, kCBF16, -+ kCF32, kCTF32, kCF64, kCS2, -+ kCS4, kCS8, kCS16, kCS32, -+ kCS64, kCU2, kCU4, kCU8, -+ kCU16, kCU32, kCU64, kInvalid -+}; -+ -+/// ENUM class for LayoutTypes -+enum class LayoutType { -+ kColumnMajor, kRowMajor, -+ kColumnMajorInterleaved2, kRowMajorInterleaved2, -+ kColumnMajorInterleaved32, kRowMajorInterleaved32, -+ kColumnMajorInterleaved64, kRowMajorInterleaved64, -+ kTensorNHWC, kTensorNDHWC, kTensorNCHW, kTensorNGHWC, -+ kTensorNC32HW32, kTensorNC64HW64, kTensorC32RSK32, -+ kTensorC64RSK64 -+}; -+ -+/// ENUM class for opcode class -+ -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/test/conv/conv_problems.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/test/conv/conv_problems.h -new file mode 100644 -index 0000000..f2c8ec8 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/test/conv/conv_problems.h -@@ -0,0 +1,54 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Bind convolution problems to python -+*/ -+#pragma once -+#include -+#include -+ -+ -+#include "unit/conv/device/conv2d_problems.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+ -+namespace py = pybind11; -+ -+PYBIND11_MAKE_OPAQUE(std::vector); -+ -+void bind_conv_problem_size_test(py::module &m) { -+ -+ py::bind_vector>(m, "Conv2dProblemVector") -+ .def("size", &std::vector::size); -+ // Get Conv2d problem sizes -+ py::class_(m, "TestbedConv2dProblemSizes") -+ .def(py::init()) -+ .def_readonly("conv2d_default_sizes", &test::conv::device::TestbedConv2dProblemSizes::conv2d_default_sizes); -+} -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/test/conv/convolution.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/test/conv/convolution.h -new file mode 100644 -index 0000000..dd97d28 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/test/conv/convolution.h -@@ -0,0 +1,49 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Bind convolution related types to python -+*/ -+#pragma once -+#include -+#include -+ -+#include "conv_problems.h" -+#include "host.h" -+ -+namespace py = pybind11; -+ -+void bind_convolution_test(py::module &m) { -+ // Conv problem sizes -+ bind_conv_problem_size_test(m); -+ -+ py::module_ host_submodule = m.def_submodule("host"); -+ bind_conv_host_references(host_submodule); -+} -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/test/conv/host.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/test/conv/host.h -new file mode 100644 -index 0000000..ca15ce6 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/test/conv/host.h -@@ -0,0 +1,180 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Bind Convolution host test helpers to python -+*/ -+#pragma once -+#include -+#include -+#include "unit/conv/device/cache_testbed_output.h" -+ -+ -+#include "cutlass/util/reference/host/convolution.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+ -+namespace py = pybind11; -+ -+ -+template -+void bind_conv2d_host(py::module &m) { -+ m.def("conv2d", \ -+ &cutlass::reference::host::Conv2d< \ -+ Ta, La, Tb, Lb, Tc, Lc, Te, Tacc>); -+ -+ m.def("CreateCachedConv2dTestKey", &test::conv::device::CreateCachedConv2dTestKey); -+} -+ -+template -+void bind_conv2d_host_sat(py::module &m) { -+ m.def("conv2d", \ -+ &cutlass::reference::host::Conv2d< \ -+ Ta, La, Tb, Lb, Tc, Lc, Te, Tacc, cutlass::NumericConverterClamp>); -+ -+ m.def("CreateCachedConv2dTestKey", &test::conv::device::CreateCachedConv2dTestKey); -+} -+ -+template -+void bind_conv2d_host_nhwc(py::module &m) { -+ bind_conv2d_host< -+ Ta, cutlass::layout::TensorNHWC, -+ Tb, cutlass::layout::TensorNHWC, -+ Tc, cutlass::layout::TensorNHWC, -+ Tacc, Te>(m); -+} -+ -+template -+void bind_conv2d_host_nc32hw32(py::module &m) { -+ bind_conv2d_host_sat< -+ Ta, cutlass::layout::TensorNCxHWx<32>, -+ Tb, cutlass::layout::TensorCxRSKx<32>, -+ Tc, cutlass::layout::TensorNCxHWx<32>, -+ Tacc, Te>(m); -+} -+ -+ -+template -+void bind_tensor_equals(py::module &m) { -+ m.def("equals", py::overload_cast< -+ const cutlass::TensorView&, const cutlass::TensorView&>( -+ &cutlass::reference::host::TensorEquals -+ )); -+} -+ -+#define BIND_TENSOR_HASH(Element, Layout) { \ -+ m.def("TensorHash", &test::conv::device::TensorHash, py::arg("view"), py::arg("hash") = test::conv::device::CRC32(), py::arg("crc")=uint32_t()); \ -+} -+ -+void bind_conv_host_references(py::module &m) { -+ // -+ // Conv2d reference on host -+ // tools/util/include/cutlass/util/reference/host/convolution.h -+ -+ /// double -+ bind_conv2d_host_nhwc(m); -+ /// float -+ bind_conv2d_host_nhwc(m); -+ /// half -+ bind_conv2d_host_nhwc(m); -+ bind_conv2d_host_nhwc(m); -+ bind_conv2d_host_nhwc(m); -+ bind_conv2d_host_nhwc(m); -+ bind_conv2d_host_nhwc(m); -+ bind_conv2d_host_nhwc(m); -+ bind_conv2d_host_nhwc(m); -+ bind_conv2d_host_nhwc(m); -+ /// bfloat16 -+ bind_conv2d_host_nhwc(m); -+ bind_conv2d_host_nhwc(m); -+ bind_conv2d_host_nhwc(m); -+ bind_conv2d_host_nhwc(m); -+ /// s8 -+ bind_conv2d_host_nhwc(m); -+ bind_conv2d_host_nhwc(m); -+ bind_conv2d_host_nhwc(m); -+ bind_conv2d_host_nhwc(m); -+ bind_conv2d_host_nhwc(m); -+ bind_conv2d_host_nhwc(m); -+ bind_conv2d_host_nhwc(m); -+ bind_conv2d_host_nhwc(m); -+ -+ bind_conv2d_host_nc32hw32(m); -+ bind_conv2d_host_nc32hw32(m); -+ bind_conv2d_host_nc32hw32(m); -+ bind_conv2d_host_nc32hw32(m); -+ bind_conv2d_host_nc32hw32(m); -+ bind_conv2d_host_nc32hw32(m); -+ bind_conv2d_host_nc32hw32(m); -+ bind_conv2d_host_nc32hw32(m); -+ -+ // -+ // Compare whether two tensors are equal -+ // -+ /// double -+ bind_tensor_equals(m); -+ /// float -+ bind_tensor_equals(m); -+ /// half -+ bind_tensor_equals(m); -+ /// bfloat16 -+ bind_tensor_equals(m); -+ /// s32 -+ bind_tensor_equals(m); -+ bind_tensor_equals>(m); -+ /// s8 -+ bind_tensor_equals(m); -+ bind_tensor_equals>(m); -+ -+ /// Cache -+ py::class_(m, "CachedTestKey") -+ .def(py::init<>()) -+ .def(py::init()); -+ -+ py::class_(m, "CachedTestResult") -+ .def(py::init<>()) -+ .def(py::init()) -+ .def_readwrite("D", &test::conv::device::CachedTestResult::D); -+ -+ py::class_(m, "CachedTestResultListing") -+ .def(py::init()) -+ .def("find", &test::conv::device::CachedTestResultListing::find) -+ .def("append", &test::conv::device::CachedTestResultListing::append) -+ .def("write", &test::conv::device::CachedTestResultListing::write); -+ -+ py::class_(m, "CRC32") -+ .def(py::init<>()); -+ -+ BIND_TENSOR_HASH(double, cutlass::layout::TensorNHWC) -+ BIND_TENSOR_HASH(float, cutlass::layout::TensorNHWC); -+ BIND_TENSOR_HASH(cutlass::half_t, cutlass::layout::TensorNHWC); -+ BIND_TENSOR_HASH(cutlass::bfloat16_t, cutlass::layout::TensorNHWC); -+ BIND_TENSOR_HASH(int32_t, cutlass::layout::TensorNHWC); -+ BIND_TENSOR_HASH(int8_t, cutlass::layout::TensorNCxHWx<32>); -+} -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/test/gemm/gemm.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/test/gemm/gemm.h -new file mode 100644 -index 0000000..749d8d9 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/test/gemm/gemm.h -@@ -0,0 +1,45 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Bind gemm test to python -+*/ -+#pragma once -+#include -+#include -+ -+#include "host.h" -+ -+namespace py = pybind11; -+ -+void bind_gemm_test(py::module &m) { -+ py::module_ host_submodule = m.def_submodule("host"); -+ bind_gemm_host_reference(host_submodule); -+} -diff --git a/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/test/gemm/host.h b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/test/gemm/host.h -new file mode 100644 -index 0000000..c6aeee8 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/scripts/pycutlass/src/cpp/test/gemm/host.h -@@ -0,0 +1,431 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Bind gemm test host functions to python -+*/ -+#pragma once -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/util/reference/host/gemm.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/util/host_reorder.h" -+ -+#include "cutlass/functional.h" -+ -+namespace py = pybind11; -+ -+ -+template< -+ typename ElementA, typename LayoutA, -+ typename ElementB, typename LayoutB, -+ typename ElementC, typename LayoutC, -+ typename AccumulatorType, typename ComputeType, -+ typename InnerProductOp> -+void bind_host_gemm_saturate(py::module &m) { -+ m.def("gemm_saturate", py::overload_cast< -+ cutlass::gemm::GemmCoord, ComputeType, -+ cutlass::TensorRef, -+ cutlass::TensorRef, -+ ComputeType, -+ cutlass::TensorRef, -+ cutlass::TensorRef, -+ AccumulatorType>( -+ &cutlass::reference::host::compute_gemm< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ComputeType, -+ AccumulatorType, -+ InnerProductOp, -+ cutlass::NumericConverterClamp> -+ )); -+} -+ -+template< -+ typename ElementA, typename LayoutA, -+ typename ElementB, typename LayoutB, -+ typename ElementC, typename LayoutC, -+ typename AccumulatorType, typename ComputeType, -+ typename InnerProductOp> -+void bind_host_gemm(py::module &m) { -+ m.def("gemm", py::overload_cast< -+ cutlass::gemm::GemmCoord, ComputeType, -+ cutlass::TensorRef, -+ cutlass::TensorRef, -+ ComputeType, -+ cutlass::TensorRef, -+ cutlass::TensorRef, -+ AccumulatorType>( -+ &cutlass::reference::host::compute_gemm< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ComputeType, -+ AccumulatorType, -+ InnerProductOp, -+ cutlass::NumericConverter> -+ )); -+} -+ -+ -+template< -+ typename ElementA, typename ElementB, typename ElementC, -+ typename AccumulatorType, typename ComputeType> -+void bind_host_gemm_multiply_add(py::module &m) { -+ bind_host_gemm< -+ ElementA, cutlass::layout::RowMajor, -+ ElementB, cutlass::layout::RowMajor, -+ ElementC, cutlass::layout::RowMajor, -+ ComputeType, AccumulatorType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm< -+ ElementA, cutlass::layout::ColumnMajor, -+ ElementB, cutlass::layout::RowMajor, -+ ElementC, cutlass::layout::RowMajor, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm< -+ ElementA, cutlass::layout::RowMajor, -+ ElementB, cutlass::layout::ColumnMajor, -+ ElementC, cutlass::layout::RowMajor, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm< -+ ElementA, cutlass::layout::RowMajor, -+ ElementB, cutlass::layout::RowMajor, -+ ElementC, cutlass::layout::ColumnMajor, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm< -+ ElementA, cutlass::layout::RowMajor, -+ ElementB, cutlass::layout::ColumnMajor, -+ ElementC, cutlass::layout::ColumnMajor, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm< -+ ElementA, cutlass::layout::ColumnMajor, -+ ElementB, cutlass::layout::RowMajor, -+ ElementC, cutlass::layout::ColumnMajor, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm< -+ ElementA, cutlass::layout::ColumnMajor, -+ ElementB, cutlass::layout::ColumnMajor, -+ ElementC, cutlass::layout::RowMajor, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm< -+ ElementA, cutlass::layout::ColumnMajor, -+ ElementB, cutlass::layout::ColumnMajor, -+ ElementC, cutlass::layout::ColumnMajor, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+} -+ -+template< -+ typename ElementA, typename ElementB, typename ElementC, -+ typename AccumulatorType, typename ComputeType> -+void bind_host_gemm_multiply_add_saturate(py::module &m) { -+ bind_host_gemm_saturate< -+ ElementA, cutlass::layout::RowMajor, -+ ElementB, cutlass::layout::RowMajor, -+ ElementC, cutlass::layout::RowMajor, -+ ComputeType, AccumulatorType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm_saturate< -+ ElementA, cutlass::layout::ColumnMajor, -+ ElementB, cutlass::layout::RowMajor, -+ ElementC, cutlass::layout::RowMajor, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm_saturate< -+ ElementA, cutlass::layout::RowMajor, -+ ElementB, cutlass::layout::ColumnMajor, -+ ElementC, cutlass::layout::RowMajor, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm_saturate< -+ ElementA, cutlass::layout::RowMajor, -+ ElementB, cutlass::layout::RowMajor, -+ ElementC, cutlass::layout::ColumnMajor, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm_saturate< -+ ElementA, cutlass::layout::RowMajor, -+ ElementB, cutlass::layout::ColumnMajor, -+ ElementC, cutlass::layout::ColumnMajor, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm_saturate< -+ ElementA, cutlass::layout::ColumnMajor, -+ ElementB, cutlass::layout::RowMajor, -+ ElementC, cutlass::layout::ColumnMajor, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm_saturate< -+ ElementA, cutlass::layout::ColumnMajor, -+ ElementB, cutlass::layout::ColumnMajor, -+ ElementC, cutlass::layout::RowMajor, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm_saturate< -+ ElementA, cutlass::layout::ColumnMajor, -+ ElementB, cutlass::layout::ColumnMajor, -+ ElementC, cutlass::layout::ColumnMajor, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+} -+ -+ -+template< -+ typename ElementA, typename ElementB, typename ElementC, -+ typename AccumulatorType, typename ComputeType> -+void bind_host_gemm_multiply_add_interleaved(py::module &m) { -+ bind_host_gemm< -+ ElementA, cutlass::layout::RowMajorInterleaved<32>, -+ ElementB, cutlass::layout::RowMajorInterleaved<32>, -+ ElementC, cutlass::layout::RowMajorInterleaved<32>, -+ ComputeType, AccumulatorType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm< -+ ElementA, cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementB, cutlass::layout::RowMajorInterleaved<32>, -+ ElementC, cutlass::layout::RowMajorInterleaved<32>, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm< -+ ElementA, cutlass::layout::RowMajorInterleaved<32>, -+ ElementB, cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementC, cutlass::layout::RowMajorInterleaved<32>, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm< -+ ElementA, cutlass::layout::RowMajorInterleaved<32>, -+ ElementB, cutlass::layout::RowMajorInterleaved<32>, -+ ElementC, cutlass::layout::ColumnMajorInterleaved<32>, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm< -+ ElementA, cutlass::layout::RowMajorInterleaved<32>, -+ ElementB, cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementC, cutlass::layout::ColumnMajorInterleaved<32>, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm< -+ ElementA, cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementB, cutlass::layout::RowMajorInterleaved<32>, -+ ElementC, cutlass::layout::ColumnMajorInterleaved<32>, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm< -+ ElementA, cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementB, cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementC, cutlass::layout::RowMajorInterleaved<32>, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm< -+ ElementA, cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementB, cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementC, cutlass::layout::ColumnMajorInterleaved<32>, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+} -+ -+template< -+ typename ElementA, typename ElementB, typename ElementC, -+ typename AccumulatorType, typename ComputeType> -+void bind_host_gemm_multiply_add_saturate_interleaved(py::module &m) { -+ bind_host_gemm_saturate< -+ ElementA, cutlass::layout::RowMajorInterleaved<32>, -+ ElementB, cutlass::layout::RowMajorInterleaved<32>, -+ ElementC, cutlass::layout::RowMajorInterleaved<32>, -+ ComputeType, AccumulatorType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm_saturate< -+ ElementA, cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementB, cutlass::layout::RowMajorInterleaved<32>, -+ ElementC, cutlass::layout::RowMajorInterleaved<32>, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm_saturate< -+ ElementA, cutlass::layout::RowMajorInterleaved<32>, -+ ElementB, cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementC, cutlass::layout::RowMajorInterleaved<32>, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm_saturate< -+ ElementA, cutlass::layout::RowMajorInterleaved<32>, -+ ElementB, cutlass::layout::RowMajorInterleaved<32>, -+ ElementC, cutlass::layout::ColumnMajorInterleaved<32>, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm_saturate< -+ ElementA, cutlass::layout::RowMajorInterleaved<32>, -+ ElementB, cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementC, cutlass::layout::ColumnMajorInterleaved<32>, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm_saturate< -+ ElementA, cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementB, cutlass::layout::RowMajorInterleaved<32>, -+ ElementC, cutlass::layout::ColumnMajorInterleaved<32>, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm_saturate< -+ ElementA, cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementB, cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementC, cutlass::layout::RowMajorInterleaved<32>, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+ -+ bind_host_gemm_saturate< -+ ElementA, cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementB, cutlass::layout::ColumnMajorInterleaved<32>, -+ ElementC, cutlass::layout::ColumnMajorInterleaved<32>, -+ AccumulatorType, ComputeType, -+ cutlass::multiply_add>(m); -+} -+ -+#define BIND_TENSOR_EQUAL(Element, Layout) { \ -+ m.def("equals", py::overload_cast< \ -+ const cutlass::TensorView&, const cutlass::TensorView&>( \ -+ &cutlass::reference::host::TensorEquals)); \ -+} -+ -+void bind_gemm_host_reference(py::module &m) { -+ -+ /// double -+ bind_host_gemm_multiply_add(m); -+ /// float -+ bind_host_gemm_multiply_add(m); -+ /// half_t -+ bind_host_gemm_multiply_add(m); -+ bind_host_gemm_multiply_add(m); -+ bind_host_gemm_multiply_add(m); -+ bind_host_gemm_multiply_add(m); -+ /// bfloat16 -+ bind_host_gemm_multiply_add(m); -+ bind_host_gemm_multiply_add(m); -+ -+ /// s8 -+ bind_host_gemm_multiply_add(m); -+ bind_host_gemm_multiply_add(m); -+ bind_host_gemm_multiply_add(m); -+ bind_host_gemm_multiply_add(m); -+ bind_host_gemm_multiply_add(m); -+ bind_host_gemm_multiply_add(m); -+ bind_host_gemm_multiply_add(m); -+ bind_host_gemm_multiply_add(m); -+ -+ bind_host_gemm_multiply_add_interleaved(m); -+ bind_host_gemm_multiply_add_interleaved(m); -+ bind_host_gemm_multiply_add_interleaved(m); -+ bind_host_gemm_multiply_add_interleaved(m); -+ bind_host_gemm_multiply_add_interleaved(m); -+ bind_host_gemm_multiply_add_interleaved(m); -+ bind_host_gemm_multiply_add_interleaved(m); -+ bind_host_gemm_multiply_add_interleaved(m); -+ -+ bind_host_gemm_multiply_add_saturate(m); -+ bind_host_gemm_multiply_add_saturate(m); -+ bind_host_gemm_multiply_add_saturate(m); -+ bind_host_gemm_multiply_add_saturate(m); -+ bind_host_gemm_multiply_add_saturate(m); -+ bind_host_gemm_multiply_add_saturate(m); -+ bind_host_gemm_multiply_add_saturate(m); -+ bind_host_gemm_multiply_add_saturate(m); -+ -+ bind_host_gemm_multiply_add_saturate_interleaved(m); -+ bind_host_gemm_multiply_add_saturate_interleaved(m); -+ bind_host_gemm_multiply_add_saturate_interleaved(m); -+ bind_host_gemm_multiply_add_saturate_interleaved(m); -+ bind_host_gemm_multiply_add_saturate_interleaved(m); -+ bind_host_gemm_multiply_add_saturate_interleaved(m); -+ bind_host_gemm_multiply_add_saturate_interleaved(m); -+ bind_host_gemm_multiply_add_saturate_interleaved(m); -+ -+ // float -+ BIND_TENSOR_EQUAL(float, cutlass::layout::RowMajor); -+ BIND_TENSOR_EQUAL(float, cutlass::layout::ColumnMajor); -+ -+ // double -+ BIND_TENSOR_EQUAL(double, cutlass::layout::RowMajor); -+ BIND_TENSOR_EQUAL(double, cutlass::layout::ColumnMajor); -+ -+ // half_t -+ BIND_TENSOR_EQUAL(cutlass::half_t, cutlass::layout::RowMajor); -+ BIND_TENSOR_EQUAL(cutlass::half_t, cutlass::layout::ColumnMajor); -+ -+ // bfloat16 -+ BIND_TENSOR_EQUAL(cutlass::bfloat16_t, cutlass::layout::RowMajor); -+ BIND_TENSOR_EQUAL(cutlass::bfloat16_t, cutlass::layout::ColumnMajor); -+ -+ // int32_t -+ BIND_TENSOR_EQUAL(int32_t, cutlass::layout::RowMajor); -+ BIND_TENSOR_EQUAL(int32_t, cutlass::layout::ColumnMajor); -+ -+ // int8_t -+ BIND_TENSOR_EQUAL(int8_t, cutlass::layout::RowMajor); -+ BIND_TENSOR_EQUAL(int8_t, cutlass::layout::ColumnMajor); -+ BIND_TENSOR_EQUAL(int8_t, cutlass::layout::RowMajorInterleaved<32>); -+ BIND_TENSOR_EQUAL(int8_t, cutlass::layout::ColumnMajorInterleaved<32>); -+ -+ -+} -diff --git a/3rdparty/cutlass/tools/library/src/conv2d_operation.h b/3rdparty/cutlass/tools/library/src/conv2d_operation.h -new file mode 100644 -index 0000000..5d06e72 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/conv2d_operation.h -@@ -0,0 +1,642 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines operations for all CONV operation kinds in CUTLASS Library. -+*/ -+ -+#pragma once -+#include -+#include "cutlass/cutlass.h" -+#include "cutlass/conv/kernel/default_conv2d_fprop.h" -+#include "cutlass/conv/kernel/default_conv2d_group_fprop.h" -+#include "cutlass/conv/kernel/default_depthwise_fprop.h" -+#include "cutlass/conv/kernel/default_conv2d_dgrad.h" -+#include "cutlass/conv/kernel/default_conv2d_wgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+#include "cutlass/conv/device/direct_convolution.h" -+ -+#include "cutlass/library/library.h" -+#include "library_internal.h" -+#include "cutlass/util/host_tensor.h" -+ -+#include "cutlass/util/reference/host/convolution.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/core_io.h" -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class Conv2dOperationBase : public Operation { -+public: -+ -+ using Operator = Operator_; -+ -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = Operator::kIteratorAlgorithm; -+ static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+protected: -+ -+ /// -+ ConvDescription description_; -+ -+public: -+ -+ /// Constructor -+ Conv2dOperationBase(char const *name = "unknown_conv2d") { -+ -+ description_.name = name; -+ description_.provider = Provider::kCUTLASS; -+ description_.kind = OperationKind::kConv2d; -+ description_.conv_dim = Operator::kConvDim; -+ -+ description_.iterator_algorithm = IteratorAlgorithmMap::kId; -+ -+ description_.tile_description.threadblock_shape = make_Coord( -+ Operator::ThreadblockShape::kM, -+ Operator::ThreadblockShape::kN, -+ Operator::ThreadblockShape::kK); -+ -+ description_.tile_description.threadblock_stages = Operator::kStages; -+ -+ description_.tile_description.warp_count = make_Coord( -+ Operator::UnderlyingKernel::WarpCount::kM, -+ Operator::UnderlyingKernel::WarpCount::kN, -+ Operator::UnderlyingKernel::WarpCount::kK); -+ -+ description_.tile_description.math_instruction.instruction_shape = make_Coord( -+ Operator::InstructionShape::kM, -+ Operator::InstructionShape::kN, -+ Operator::InstructionShape::kK); -+ -+ description_.tile_description.math_instruction.element_accumulator = -+ NumericTypeMap::kId; -+ -+ description_.tile_description.math_instruction.opcode_class = -+ OpcodeClassMap::kId; -+ -+ description_.tile_description.math_instruction.math_operation = -+ MathOperationMap::kId; -+ -+ description_.tile_description.minimum_compute_capability = -+ ArchMap::kMin; -+ -+ description_.tile_description.maximum_compute_capability = -+ ArchMap::kMax; -+ -+ description_.A = make_TensorDescription(); -+ description_.B = make_TensorDescription(); -+ description_.C = make_TensorDescription(); -+ description_.element_epilogue = NumericTypeMap::kId; -+ -+ // TODO: Add split k mode Serial and parallel to convolutions -+ // description_.split_k_mode = Operator::kSplitK ? SplitKMode::kSerial : SplitKMode::kNone; -+ -+ } -+ -+ /// Returns the description of the GEMM operation -+ virtual OperationDescription const & description() const { -+ return description_; -+ } -+}; -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Conv2d library operation class for cutlass profiler -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+template -+class Conv2dOperation : public Conv2dOperationBase { -+public: -+ -+ using Operator = Operator_; -+ -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+public: -+ /// Constructor -+ Conv2dOperation(char const *name = "unknown_conv2d_fprop") : Conv2dOperationBase(name) { -+ this->description_.conv_kind = ConvKindMap::kId; -+ } -+ -+protected: -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status construct_arguments_( -+ OperatorArguments &operator_args, -+ Conv2dConfiguration const *configuration) { -+ -+ -+ operator_args.problem_size = configuration->problem_size; -+ -+ operator_args.ref_A = -+ { -+ nullptr, -+ LayoutA::packed(implicit_gemm_tensor_a_extent(kConvolutionalOperator, configuration->problem_size)) -+ }; -+ -+ operator_args.ref_B = -+ { -+ nullptr, -+ LayoutB::packed(implicit_gemm_tensor_b_extent(kConvolutionalOperator, configuration->problem_size)) -+ }; -+ -+ operator_args.ref_C = -+ { -+ nullptr, -+ LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) -+ }; -+ -+ operator_args.ref_D = -+ { -+ nullptr, -+ LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) -+ }; -+ -+ operator_args.split_k_mode = configuration->split_k_mode; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status update_arguments_( -+ OperatorArguments &operator_args, -+ ConvArguments const *arguments) { -+ -+ if (arguments->pointer_mode == ScalarPointerMode::kHost) { -+ typename Operator::EpilogueOutputOp::Params params( -+ *static_cast(arguments->alpha), -+ *static_cast(arguments->beta) -+ ); -+ operator_args.output_op = params; -+ } -+ else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ -+ typename Operator::EpilogueOutputOp::Params params( -+ static_cast(arguments->alpha), -+ static_cast(arguments->beta) -+ ); -+ operator_args.output_op = params; -+ } -+ else { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ operator_args.ref_A.reset(static_cast(const_cast(arguments->A))); -+ operator_args.ref_B.reset(static_cast(const_cast(arguments->B))); -+ operator_args.ref_C.reset(static_cast(const_cast(arguments->C))); -+ operator_args.ref_D.reset(static_cast(const_cast(arguments->D))); -+ -+ return Status::kSuccess; -+ } -+ -+public: -+ -+ /// Returns success if the operation can proceed -+ virtual Status can_implement( -+ void const *configuration_ptr, -+ void const *arguments_ptr) const { -+ -+ Conv2dConfiguration const *configuration = -+ static_cast(configuration_ptr); -+ -+ ConvArguments const *arguments = -+ static_cast(arguments_ptr); -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_(args, configuration); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = update_arguments_(args, arguments); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Operator::can_implement(args); -+ -+ } -+ -+ /// Gets the host-side workspace -+ virtual uint64_t get_host_workspace_size( -+ void const *configuration) const { -+ -+ return sizeof(Operator); -+ } -+ -+ /// Gets the device-side workspace -+ virtual uint64_t get_device_workspace_size( -+ void const *configuration_ptr, -+ void const *arguments_ptr = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return 0; -+ } -+ -+ return Operator::get_workspace_size(args); -+ } -+ -+ /// Initializes the workspace -+ virtual Status initialize( -+ void const *configuration_ptr, -+ void *host_workspace, -+ void *device_workspace, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = new (host_workspace) Operator; -+ //std::cout << "initialize library::Conv2dOperation" << std::endl; -+ //print_operator_args(args); -+ return op->initialize(args, device_workspace, stream); -+ -+ } -+ -+ /// Runs the kernel -+ virtual Status run( -+ void const *arguments_ptr, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = update_arguments_( -+ args, -+ static_cast(arguments_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = static_cast(host_workspace); -+ -+ status = op->update(args, device_workspace); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ //std::cout << "run library::Conv2dOperation" << std::endl; -+ //print_operator_args(args); -+ return op->run(stream); -+ } -+ -+ /// Call print_operator_args from the Conv2dOperation::initialize() -+ // to dump arguments passed on to cutlass operator for debugging -+ void print_operator_args(OperatorArguments &operator_args) const { -+ std::cout << "Conv2dOperation::OperatorArguments" << std::endl -+ << " problem_size:" << std::endl -+ << operator_args.problem_size << std::endl -+ << " split_k_mode: " -+ << (operator_args.split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial" : "parallel") << std::endl -+ << " epilouge (alpha, beta): " -+ << operator_args.output_op.alpha << ", " -+ << operator_args.output_op.beta << std::endl -+ << " ref_A (ptr, {stride}): " -+ << operator_args.ref_A.data() << ", {" -+ << operator_args.ref_A.stride(0) << ", " -+ << operator_args.ref_A.stride(1) << ", " -+ << operator_args.ref_A.stride(2) << "}" << std::endl -+ << " ref_B (ptr, {stride}): " -+ << operator_args.ref_B.data() << ", {" -+ << operator_args.ref_B.stride(0) << ", " -+ << operator_args.ref_B.stride(1) << ", " -+ << operator_args.ref_B.stride(2) << "}" << std::endl -+ << " ref_C (ptr, {stride}): " -+ << operator_args.ref_C.data() << ", {" -+ << operator_args.ref_C.stride(0) << ", " -+ << operator_args.ref_C.stride(1) << ", " -+ << operator_args.ref_C.stride(2) << "}" << std::endl -+ << " ref_D (ptr, {stride}): " -+ << operator_args.ref_D.data() << ", {" -+ << operator_args.ref_D.stride(0) << ", " -+ << operator_args.ref_D.stride(1) << ", " -+ << operator_args.ref_D.stride(2) << "}" << std::endl; -+ } -+}; -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// DirectConv2d library operation class for cutlass profiler -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class DirectConv2dOperation : public Conv2dOperation { -+public: -+ -+ using Operator = Operator_; -+ using Base = Conv2dOperation; -+ -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+public: -+ /// Constructor -+ DirectConv2dOperation(char const *name = "unknown_direct)conv2d_fprop") : Conv2dOperation(name) { -+ this->description_.conv_kind = ConvKindMap::kId; -+ } -+ -+protected: -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status construct_arguments_( -+ OperatorArguments &operator_args, -+ Conv2dConfiguration const *configuration) { -+ -+ -+ operator_args.problem_size = configuration->problem_size; -+ -+ operator_args.ref_A = -+ { -+ nullptr, -+ LayoutA::packed(implicit_gemm_tensor_a_extent(kConvolutionalOperator, configuration->problem_size)) -+ }; -+ -+ operator_args.ref_B = -+ { -+ nullptr, -+ LayoutB::packed(implicit_gemm_tensor_b_extent(kConvolutionalOperator, configuration->problem_size)) -+ }; -+ -+ operator_args.ref_reordered_B = -+ { -+ nullptr, -+ LayoutB::packed(implicit_gemm_tensor_b_extent(kConvolutionalOperator, configuration->problem_size)) -+ }; -+ -+ operator_args.ref_C = -+ { -+ nullptr, -+ LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) -+ }; -+ -+ operator_args.ref_D = -+ { -+ nullptr, -+ LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) -+ }; -+ -+ operator_args.split_k_mode = configuration->split_k_mode; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status update_arguments_( -+ OperatorArguments &operator_args, -+ ConvArguments const *arguments) { -+ -+ if (arguments->pointer_mode == ScalarPointerMode::kHost) { -+ typename Operator::EpilogueOutputOp::Params params( -+ *static_cast(arguments->alpha), -+ *static_cast(arguments->beta) -+ ); -+ operator_args.output_op = params; -+ } -+ else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ -+ typename Operator::EpilogueOutputOp::Params params( -+ static_cast(arguments->alpha), -+ static_cast(arguments->beta) -+ ); -+ operator_args.output_op = params; -+ } -+ else { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ operator_args.ref_A.reset(static_cast(const_cast(arguments->A))); -+ operator_args.ref_B.reset(static_cast(const_cast(arguments->B))); -+ operator_args.ref_C.reset(static_cast(const_cast(arguments->C))); -+ operator_args.ref_D.reset(static_cast(const_cast(arguments->D))); -+ operator_args.ref_reordered_B.reset(static_cast(const_cast(arguments->reordered_B))); -+ -+ return Status::kSuccess; -+ } -+ -+public: -+ -+ /// Returns success if the operation can proceed -+ virtual Status can_implement( -+ void const *configuration_ptr, -+ void const *arguments_ptr) const { -+ -+ Conv2dConfiguration const *configuration = -+ static_cast(configuration_ptr); -+ -+ ConvArguments const *arguments = -+ static_cast(arguments_ptr); -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_(args, configuration); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = update_arguments_(args, arguments); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Operator::can_implement(args); -+ -+ } -+ -+ /// Gets the host-side workspace -+ virtual uint64_t get_host_workspace_size( -+ void const *configuration) const { -+ -+ return sizeof(Operator); -+ } -+ -+ /// Gets the device-side workspace -+ virtual uint64_t get_device_workspace_size( -+ void const *configuration_ptr, -+ void const *arguments_ptr = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return 0; -+ } -+ -+ return Operator::get_workspace_size(args); -+ } -+ -+ /// Initializes the workspace -+ virtual Status initialize( -+ void const *configuration_ptr, -+ void *host_workspace, -+ void *device_workspace, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = new (host_workspace) Operator; -+ //std::cout << "initialize library::Conv2dOperation" << std::endl; -+ //print_operator_args(args); -+ return op->initialize(args, device_workspace, stream); -+ -+ } -+ -+ /// Runs the kernel -+ virtual Status run( -+ void const *arguments_ptr, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = update_arguments_( -+ args, -+ static_cast(arguments_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = static_cast(host_workspace); -+ -+ status = op->update(args, device_workspace); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ //std::cout << "run library::Conv2dOperation" << std::endl; -+ //print_operator_args(args); -+ return op->run(stream); -+ } -+ -+ /// Call print_operator_args from the Conv2dOperation::initialize() -+ // to dump arguments passed on to cutlass operator for debugging -+ void print_operator_args(OperatorArguments &operator_args) const { -+ std::cout << "Conv2dOperation::OperatorArguments" << std::endl -+ << " problem_size:" << std::endl -+ << operator_args.problem_size << std::endl -+ << " split_k_mode: " -+ << (operator_args.split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial" : "parallel") << std::endl -+ << " epilouge (alpha, beta): " -+ << operator_args.output_op.alpha << ", " -+ << operator_args.output_op.beta << std::endl -+ << " ref_A (ptr, {stride}): " -+ << operator_args.ref_A.data() << ", {" -+ << operator_args.ref_A.stride(0) << ", " -+ << operator_args.ref_A.stride(1) << ", " -+ << operator_args.ref_A.stride(2) << "}" << std::endl -+ << " ref_B (ptr, {stride}): " -+ << operator_args.ref_B.data() << ", {" -+ << operator_args.ref_B.stride(0) << ", " -+ << operator_args.ref_B.stride(1) << ", " -+ << operator_args.ref_B.stride(2) << "}" << std::endl -+ << " ref_C (ptr, {stride}): " -+ << operator_args.ref_C.data() << ", {" -+ << operator_args.ref_C.stride(0) << ", " -+ << operator_args.ref_C.stride(1) << ", " -+ << operator_args.ref_C.stride(2) << "}" << std::endl -+ << " ref_D (ptr, {stride}): " -+ << operator_args.ref_D.data() << ", {" -+ << operator_args.ref_D.stride(0) << ", " -+ << operator_args.ref_D.stride(1) << ", " -+ << operator_args.ref_D.stride(2) << "}" << std::endl; -+ } -+}; -+ -+} // namespace library -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/src/conv3d_operation.h b/3rdparty/cutlass/tools/library/src/conv3d_operation.h -new file mode 100644 -index 0000000..0e2a1c6 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/conv3d_operation.h -@@ -0,0 +1,385 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines operations for all CONV operation kinds in CUTLASS Library. -+*/ -+ -+#pragma once -+#include -+#include "cutlass/cutlass.h" -+#include "cutlass/conv/kernel/default_conv3d_fprop.h" -+#include "cutlass/conv/kernel/default_conv3d_dgrad.h" -+#include "cutlass/conv/kernel/default_conv3d_wgrad.h" -+#include "cutlass/conv/device/implicit_gemm_convolution.h" -+ -+#include "cutlass/library/library.h" -+#include "library_internal.h" -+#include "cutlass/util/host_tensor.h" -+ -+#include "cutlass/util/reference/host/convolution.h" -+#include "cutlass/util/reference/host/tensor_compare.h" -+#include "cutlass/core_io.h" -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class Conv3dOperationBase : public Operation { -+public: -+ -+ using Operator = Operator_; -+ -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = Operator::kIteratorAlgorithm; -+ static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+protected: -+ -+ /// -+ ConvDescription description_; -+ -+public: -+ -+ /// Constructor -+ Conv3dOperationBase(char const *name = "unknown_conv3d") { -+ -+ description_.name = name; -+ description_.provider = Provider::kCUTLASS; -+ description_.kind = OperationKind::kConv3d; -+ description_.conv_dim = Operator::kConvDim; -+ -+ description_.iterator_algorithm = IteratorAlgorithmMap::kId; -+ -+ description_.tile_description.threadblock_shape = make_Coord( -+ Operator::ThreadblockShape::kM, -+ Operator::ThreadblockShape::kN, -+ Operator::ThreadblockShape::kK); -+ -+ description_.tile_description.threadblock_stages = Operator::kStages; -+ -+ description_.tile_description.warp_count = make_Coord( -+ Operator::UnderlyingKernel::WarpCount::kM, -+ Operator::UnderlyingKernel::WarpCount::kN, -+ Operator::UnderlyingKernel::WarpCount::kK); -+ -+ description_.tile_description.math_instruction.instruction_shape = make_Coord( -+ Operator::InstructionShape::kM, -+ Operator::InstructionShape::kN, -+ Operator::InstructionShape::kK); -+ -+ description_.tile_description.math_instruction.element_accumulator = -+ NumericTypeMap::kId; -+ -+ description_.tile_description.math_instruction.opcode_class = -+ OpcodeClassMap::kId; -+ -+ description_.tile_description.minimum_compute_capability = -+ ArchMap::kMin; -+ -+ description_.tile_description.maximum_compute_capability = -+ ArchMap::kMax; -+ -+ description_.A = make_TensorDescription(); -+ description_.B = make_TensorDescription(); -+ description_.C = make_TensorDescription(); -+ description_.element_epilogue = NumericTypeMap::kId; -+ -+ } -+ -+ /// Returns the description of the GEMM operation -+ virtual OperationDescription const & description() const { -+ return description_; -+ } -+}; -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Conv2d library operation class for cutlass profiler -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+template -+class Conv3dOperation : public Conv3dOperationBase { -+public: -+ -+ using Operator = Operator_; -+ -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+public: -+ /// Constructor -+ Conv3dOperation(char const *name = "unknown_conv3d_fprop") : Conv3dOperationBase(name) { -+ this->description_.conv_kind = ConvKindMap::kId; -+ } -+ -+protected: -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status construct_arguments_( -+ OperatorArguments &operator_args, -+ Conv3dConfiguration const *configuration) { -+ -+ -+ operator_args.problem_size = configuration->problem_size; -+ -+ operator_args.ref_A = -+ { -+ nullptr, -+ LayoutA::packed(implicit_gemm_tensor_a_extent(kConvolutionalOperator, configuration->problem_size)) -+ }; -+ -+ operator_args.ref_B = -+ { -+ nullptr, -+ LayoutB::packed(implicit_gemm_tensor_b_extent(kConvolutionalOperator, configuration->problem_size)) -+ }; -+ -+ operator_args.ref_C = -+ { -+ nullptr, -+ LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) -+ }; -+ -+ operator_args.ref_D = -+ { -+ nullptr, -+ LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) -+ }; -+ -+ operator_args.split_k_mode = configuration->split_k_mode; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status update_arguments_( -+ OperatorArguments &operator_args, -+ ConvArguments const *arguments) { -+ -+ if (arguments->pointer_mode == ScalarPointerMode::kHost) { -+ typename Operator::EpilogueOutputOp::Params params( -+ *static_cast(arguments->alpha), -+ *static_cast(arguments->beta) -+ ); -+ operator_args.output_op = params; -+ } -+ else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ -+ typename Operator::EpilogueOutputOp::Params params( -+ static_cast(arguments->alpha), -+ static_cast(arguments->beta) -+ ); -+ operator_args.output_op = params; -+ } -+ else { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ operator_args.ref_A.reset(static_cast(const_cast(arguments->A))); -+ operator_args.ref_B.reset(static_cast(const_cast(arguments->B))); -+ operator_args.ref_C.reset(static_cast(const_cast(arguments->C))); -+ operator_args.ref_D.reset(static_cast(const_cast(arguments->D))); -+ -+ return Status::kSuccess; -+ } -+ -+public: -+ -+ /// Returns success if the operation can proceed -+ virtual Status can_implement( -+ void const *configuration_ptr, -+ void const *arguments_ptr) const { -+ -+ Conv3dConfiguration const *configuration = -+ static_cast(configuration_ptr); -+ -+ ConvArguments const *arguments = -+ static_cast(arguments_ptr); -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_(args, configuration); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = update_arguments_(args, arguments); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Operator::can_implement(args); -+ -+ } -+ -+ /// Gets the host-side workspace -+ virtual uint64_t get_host_workspace_size( -+ void const *configuration) const { -+ -+ return sizeof(Operator); -+ } -+ -+ /// Gets the device-side workspace -+ virtual uint64_t get_device_workspace_size( -+ void const *configuration_ptr, -+ void const *arguments_ptr = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return 0; -+ } -+ -+ return Operator::get_workspace_size(args); -+ } -+ -+ /// Initializes the workspace -+ virtual Status initialize( -+ void const *configuration_ptr, -+ void *host_workspace, -+ void *device_workspace, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = new (host_workspace) Operator; -+ //std::cout << "initialize library::Conv3dOperation" << std::endl; -+ //print_operator_args(args); -+ return op->initialize(args, device_workspace, stream); -+ -+ } -+ -+ /// Runs the kernel -+ virtual Status run( -+ void const *arguments_ptr, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = update_arguments_( -+ args, -+ static_cast(arguments_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = static_cast(host_workspace); -+ -+ status = op->update(args, device_workspace); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ //std::cout << "run library::Conv3dOperation" << std::endl; -+ //print_operator_args(args); -+ return op->run(stream); -+ } -+ -+ /// Call print_operator_args from the Conv3dOperation::initialize() -+ // to dump arguments passed on to cutlass operator for debugging -+ void print_operator_args(OperatorArguments &operator_args) const { -+ std::cout << "Conv3dOperation::OperatorArguments" << std::endl -+ << " problem_size: " -+ << operator_args.problem_size << std::endl -+ << " split_k_mode: " -+ << (operator_args.split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial" : "parallel") << std::endl -+ << " epilouge (alpha, beta): " -+ << operator_args.output_op.alpha << ", " -+ << operator_args.output_op.beta << std::endl -+ << " ref_A (ptr, {stride}): " -+ << operator_args.ref_A.data() << ", {" -+ << operator_args.ref_A.stride(0) << ", " -+ << operator_args.ref_A.stride(1) << ", " -+ << operator_args.ref_A.stride(2) << ", " -+ << operator_args.ref_A.stride(3) << "}" << std::endl -+ << " ref_B (ptr, {stride}): " -+ << operator_args.ref_B.data() << ", {" -+ << operator_args.ref_B.stride(0) << ", " -+ << operator_args.ref_B.stride(1) << ", " -+ << operator_args.ref_B.stride(2) << ", " -+ << operator_args.ref_B.stride(3) << "}" << std::endl -+ << " ref_C (ptr, {stride}): " -+ << operator_args.ref_C.data() << ", {" -+ << operator_args.ref_C.stride(0) << ", " -+ << operator_args.ref_C.stride(1) << ", " -+ << operator_args.ref_C.stride(2) << ", " -+ << operator_args.ref_C.stride(3) << "}" << std::endl -+ << " ref_D (ptr, {stride}): " -+ << operator_args.ref_D.data() << ", {" -+ << operator_args.ref_D.stride(0) << ", " -+ << operator_args.ref_D.stride(1) << ", " -+ << operator_args.ref_D.stride(2) << ", " -+ << operator_args.ref_D.stride(3) << "}" << std::endl; -+ } -+}; -+ -+} // namespace library -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/src/gemm_operation.h b/3rdparty/cutlass/tools/library/src/gemm_operation.h -new file mode 100644 -index 0000000..ab5704b ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/gemm_operation.h -@@ -0,0 +1,1356 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines operations for all GEMM operation kinds in CUTLASS Library. -+*/ -+ -+#pragma once -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/device/gemm.h" -+#include "cutlass/gemm/device/gemm_sparse.h" -+#include "cutlass/gemm/device/gemm_complex.h" -+#include "cutlass/gemm/device/gemm_batched.h" -+#include "cutlass/gemm/device/gemm_array.h" -+#include "cutlass/gemm/device/gemm_universal_adapter.h" -+#include "cutlass/gemm/kernel/default_gemm_universal.h" -+#include "cutlass/gemm/kernel/default_gemm_planar_complex_universal.h" -+ -+#include "cutlass/library/library.h" -+#include "library_internal.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class GemmOperationBase : public Operation { -+public: -+ using Operator = Operator_; -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ // assuming all tensors use same type for StrideIndex -+ using StrideIndex = typename Operator::LayoutA::Index; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+protected: -+ -+ /// -+ GemmDescription description_; -+ -+public: -+ -+ /// Constructor -+ GemmOperationBase(char const *name = "unknown_gemm") { -+ -+ description_.name = name; -+ description_.provider = Provider::kCUTLASS; -+ description_.kind = OperationKind::kGemm; -+ description_.gemm_kind = GemmKind::kGemm; -+ -+ description_.tile_description.threadblock_shape = make_Coord( -+ Operator::ThreadblockShape::kM, -+ Operator::ThreadblockShape::kN, -+ Operator::ThreadblockShape::kK); -+ -+ description_.tile_description.threadblock_stages = Operator::kStages; -+ -+ description_.tile_description.warp_count = make_Coord( -+ Operator::GemmKernel::WarpCount::kM, -+ Operator::GemmKernel::WarpCount::kN, -+ Operator::GemmKernel::WarpCount::kK); -+ -+ description_.tile_description.math_instruction.instruction_shape = make_Coord( -+ Operator::InstructionShape::kM, -+ Operator::InstructionShape::kN, -+ Operator::InstructionShape::kK); -+ -+ description_.tile_description.math_instruction.element_accumulator = -+ NumericTypeMap::kId; -+ -+ description_.tile_description.math_instruction.opcode_class = -+ OpcodeClassMap::kId; -+ -+ description_.tile_description.math_instruction.math_operation = -+ MathOperationMap::kId; -+ -+ description_.tile_description.minimum_compute_capability = -+ ArchMap::kMin; -+ -+ description_.tile_description.maximum_compute_capability = -+ ArchMap::kMax; -+ -+ description_.A = make_TensorDescription(Operator::kAlignmentA); -+ description_.B = make_TensorDescription(Operator::kAlignmentB); -+ description_.C = make_TensorDescription(Operator::kAlignmentC); -+ description_.element_epilogue = NumericTypeMap::kId; -+ -+ description_.split_k_mode = SplitKMode::kNone; -+ description_.transform_A = ComplexTransformMap::kId; -+ description_.transform_B = ComplexTransformMap::kId; -+ } -+ -+ /// Returns the description of the GEMM operation -+ virtual OperationDescription const & description() const { -+ return description_; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class GemmOperation : public GemmOperationBase { -+public: -+ -+ using Operator = Operator_; -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ using OperatorArguments = typename Operator::Arguments; -+ -+public: -+ -+ /// Constructor -+ GemmOperation(char const *name = "unknown_gemm"): GemmOperationBase(name) { -+ -+ this->description_.gemm_kind = GemmKind::kGemm; -+ } -+ -+protected: -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status construct_arguments_( -+ OperatorArguments &operator_args, -+ GemmConfiguration const *configuration) { -+ -+ operator_args.problem_size = configuration->problem_size; -+ -+ operator_args.ref_A = {nullptr, configuration->lda}; -+ operator_args.ref_B = {nullptr, configuration->ldb}; -+ operator_args.ref_C = {nullptr, configuration->ldc}; -+ operator_args.ref_D = {nullptr, configuration->ldd}; -+ -+ operator_args.split_k_slices = configuration->split_k_slices; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status update_arguments_( -+ OperatorArguments &operator_args, -+ GemmArguments const *arguments) { -+ -+ if (arguments->pointer_mode == ScalarPointerMode::kHost) { -+ typename Operator::EpilogueOutputOp::Params params( -+ *static_cast(arguments->alpha), -+ *static_cast(arguments->beta) -+ ); -+ operator_args.epilogue = params; -+ } -+ else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ -+ typename Operator::EpilogueOutputOp::Params params( -+ static_cast(arguments->alpha), -+ static_cast(arguments->beta) -+ ); -+ operator_args.epilogue = params; -+ } -+ else { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ operator_args.ref_A.reset(static_cast(arguments->A)); -+ operator_args.ref_B.reset(static_cast(arguments->B)); -+ operator_args.ref_C.reset(static_cast(arguments->C)); -+ operator_args.ref_D.reset(static_cast(arguments->D)); -+ -+ return Status::kSuccess; -+ } -+ -+public: -+ -+ /// Returns success if the operation can proceed -+ virtual Status can_implement( -+ void const *configuration_ptr, -+ void const *arguments_ptr) const { -+ -+ GemmConfiguration const *configuration = -+ static_cast(configuration_ptr); -+ -+ GemmArguments const *arguments = -+ static_cast(arguments_ptr); -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_(args, configuration); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = update_arguments_(args, arguments); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Operator::can_implement(args); -+ } -+ -+ /// Gets the host-side workspace -+ virtual uint64_t get_host_workspace_size( -+ void const *configuration) const { -+ -+ return sizeof(Operator); -+ } -+ -+ /// Gets the device-side workspace -+ virtual uint64_t get_device_workspace_size( -+ void const *configuration_ptr, -+ void const *arguments_ptr = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return 0; -+ } -+ -+ return Operator::get_workspace_size(args); -+ } -+ -+ /// Initializes the workspace -+ virtual Status initialize( -+ void const *configuration_ptr, -+ void *host_workspace, -+ void *device_workspace, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = new (host_workspace) Operator; -+ -+ return op->initialize(args, device_workspace, stream); -+ } -+ -+ /// Runs the kernel -+ virtual Status run( -+ void const *arguments_ptr, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = update_arguments_( -+ args, -+ static_cast(arguments_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = static_cast(host_workspace); -+ -+ status = op->update(args); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return op->run(stream); -+ } -+ -+ void print_operator_args(OperatorArguments &operator_args) const { -+#if 0 -+ std::cout << "GemmOperation::OperatorArguments" << std::endl; -+ std::cout << " problem_size: " << operator_args.problem_size.m() << ", "<< operator_args.problem_size.n() << "," << operator_args.problem_size.k() << std::endl; -+ std::cout << " alpha: " << operator_args.epilogue.alpha << std::endl; -+ std::cout << " alpha_ptr: " << operator_args.epilogue.alpha_ptr << std::endl; -+ std::cout << " beta: " << operator_args.epilogue.beta << std::endl; -+ std::cout << " beta_ptr: " << operator_args.epilogue.beta_ptr << std::endl; -+ std::cout << " ref_A.data(): " << operator_args.ref_A.data() << std::endl; -+ std::cout << " ref_A.stride: " << operator_args.ref_A.stride(0) << std::endl; -+ std::cout << " ref_B.data(): " << operator_args.ref_B.data() << std::endl; -+ std::cout << " ref_B.stride: " << operator_args.ref_B.stride(0) << std::endl; -+ std::cout << " ref_C.data(): " << operator_args.ref_C.data() << std::endl; -+ std::cout << " ref_C.stride: " << operator_args.ref_C.stride(0) << std::endl; -+#endif -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class GemmSparseOperation : public GemmOperationBase { -+public: -+ -+ using Operator = Operator_; -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ using ElementE = typename Operator::ElementE; -+ using LayoutE = typename Operator::LayoutE; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+public: -+ -+ /// Constructor -+ GemmSparseOperation(char const *name = "unknown_gemm"): GemmOperationBase(name) { -+ -+ this->description_.kind = OperationKind::kSparseGemm; -+ this->description_.gemm_kind = GemmKind::kSparse; -+ this->description_.E = make_TensorDescription(Operator::kAlignmentE); -+ } -+ -+protected: -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status construct_arguments_( -+ OperatorArguments &operator_args, -+ SparseGemmConfiguration const *configuration) { -+ -+ operator_args.problem_size = configuration->problem_size; -+ operator_args.ref_A = {nullptr, configuration->lda}; -+ operator_args.ref_B = {nullptr, configuration->ldb}; -+ operator_args.ref_C = {nullptr, configuration->ldc}; -+ operator_args.ref_D = {nullptr, configuration->ldd}; -+ operator_args.ref_E = {nullptr, configuration->lde}; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status update_arguments_( -+ OperatorArguments &operator_args, -+ SparseGemmArguments const *arguments) { -+ -+ if (arguments->pointer_mode == ScalarPointerMode::kHost) { -+ typename Operator::EpilogueOutputOp::Params params( -+ *static_cast(arguments->alpha), -+ *static_cast(arguments->beta) -+ ); -+ operator_args.epilogue = params; -+ } -+ else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ -+ typename Operator::EpilogueOutputOp::Params params( -+ static_cast(arguments->alpha), -+ static_cast(arguments->beta) -+ ); -+ operator_args.epilogue = params; -+ } -+ else { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ operator_args.ref_A.reset(static_cast(arguments->A)); -+ operator_args.ref_B.reset(static_cast(arguments->B)); -+ operator_args.ref_C.reset(static_cast(arguments->C)); -+ operator_args.ref_D.reset(static_cast(arguments->D)); -+ operator_args.ref_E.reset(static_cast(arguments->E)); -+ -+ return Status::kSuccess; -+ } -+ -+public: -+ -+ /// Returns success if the operation can proceed -+ virtual Status can_implement( -+ void const *configuration_ptr, -+ void const *arguments_ptr) const { -+ -+ SparseGemmConfiguration const *configuration = -+ static_cast(configuration_ptr); -+ -+ SparseGemmArguments const *arguments = -+ static_cast(arguments_ptr); -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_(args, configuration); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = update_arguments_(args, arguments); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Operator::can_implement(args); -+ } -+ -+ /// Gets the host-side workspace -+ virtual uint64_t get_host_workspace_size( -+ void const *configuration) const { -+ -+ return sizeof(Operator); -+ } -+ -+ /// Gets the device-side workspace -+ virtual uint64_t get_device_workspace_size( -+ void const *configuration_ptr, -+ void const *arguments_ptr = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return 0; -+ } -+ -+ return Operator::get_workspace_size(args); -+ } -+ -+ /// Initializes the workspace -+ virtual Status initialize( -+ void const *configuration_ptr, -+ void *host_workspace, -+ void *device_workspace, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = new (host_workspace) Operator; -+ -+ return op->initialize(args, device_workspace, stream); -+ } -+ -+ /// Runs the kernel -+ virtual Status run( -+ void const *arguments_ptr, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = update_arguments_( -+ args, -+ static_cast(arguments_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = static_cast(host_workspace); -+ -+ status = op->update(args); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return op->run(stream); -+ } -+ -+ void print_operator_args(OperatorArguments &operator_args) const { -+#if 0 -+ std::cout << "GemmOperation::OperatorArguments" << std::endl; -+ std::cout << " problem_size: " << operator_args.problem_size.m() << ", "<< operator_args.problem_size.n() << "," << operator_args.problem_size.k() << std::endl; -+ std::cout << " alpha: " << operator_args.epilogue.alpha << std::endl; -+ std::cout << " alpha_ptr: " << operator_args.epilogue.alpha_ptr << std::endl; -+ std::cout << " beta: " << operator_args.epilogue.beta << std::endl; -+ std::cout << " beta_ptr: " << operator_args.epilogue.beta_ptr << std::endl; -+ std::cout << " ref_A.data(): " << operator_args.ref_A.data() << std::endl; -+ std::cout << " ref_A.stride: " << operator_args.ref_A.stride(0) << std::endl; -+ std::cout << " ref_B.data(): " << operator_args.ref_B.data() << std::endl; -+ std::cout << " ref_B.stride: " << operator_args.ref_B.stride(0) << std::endl; -+ std::cout << " ref_C.data(): " << operator_args.ref_C.data() << std::endl; -+ std::cout << " ref_C.stride: " << operator_args.ref_C.stride(0) << std::endl; -+#endif -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class GemmUniversalOperation : public GemmOperationBase { -+public: -+ -+ using Operator = Operator_; -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+public: -+ -+ /// Constructor -+ GemmUniversalOperation(char const *name = "unknown_gemm"): -+ GemmOperationBase(name) { -+ -+ this->description_.gemm_kind = GemmKind::kUniversal; -+ } -+ -+protected: -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status construct_arguments_( -+ OperatorArguments &operator_args, -+ GemmUniversalConfiguration const *configuration) { -+ -+ operator_args.mode = configuration->mode; -+ -+ operator_args.problem_size = configuration->problem_size; -+ operator_args.batch_count = configuration->batch_count; -+ -+ operator_args.lda = (configuration->lda); -+ operator_args.ldb = (configuration->ldb); -+ operator_args.ldc = (configuration->ldc); -+ operator_args.ldd = (configuration->ldd); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status update_arguments_( -+ OperatorArguments &operator_args, -+ GemmUniversalArguments const *arguments) { -+ -+ if (arguments->pointer_mode == ScalarPointerMode::kHost) { -+ typename Operator::EpilogueOutputOp::Params params( -+ *static_cast(arguments->alpha), -+ *static_cast(arguments->beta) -+ ); -+ operator_args.epilogue = params; -+ } -+ else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ -+ typename Operator::EpilogueOutputOp::Params params( -+ static_cast(arguments->alpha), -+ static_cast(arguments->beta) -+ ); -+ operator_args.epilogue = params; -+ } -+ else { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ // update arguments -+ operator_args.ptr_A = arguments->A; -+ operator_args.ptr_B = arguments->B; -+ operator_args.ptr_C = arguments->C; -+ operator_args.ptr_D = arguments->D; -+ -+ operator_args.batch_stride_A = arguments->batch_stride_A; -+ operator_args.batch_stride_B = arguments->batch_stride_B; -+ operator_args.batch_stride_C = arguments->batch_stride_C; -+ operator_args.batch_stride_D = arguments->batch_stride_D; -+ -+ return Status::kSuccess; -+ } -+ -+public: -+ -+ /// Returns success if the operation can proceed -+ virtual Status can_implement( -+ void const *configuration_ptr, -+ void const *arguments_ptr) const { -+ -+ GemmUniversalConfiguration const *configuration = -+ static_cast(configuration_ptr); -+ -+ GemmUniversalArguments const *arguments = -+ static_cast(arguments_ptr); -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_(args, configuration); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = update_arguments_(args, arguments); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Operator::can_implement(args); -+ } -+ -+ /// Gets the host-side workspace -+ virtual uint64_t get_host_workspace_size( -+ void const *configuration) const { -+ -+ return sizeof(Operator); -+ } -+ -+ /// Gets the device-side workspace -+ virtual uint64_t get_device_workspace_size( -+ void const *configuration_ptr, -+ void const *arguments_ptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return 0; -+ } -+ -+ status = update_arguments_( -+ args, -+ static_cast(arguments_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return 0; -+ } -+ -+ uint64_t size = Operator::get_workspace_size(args); -+ -+ return size; -+ } -+ -+ /// Initializes the workspace -+ virtual Status initialize( -+ void const *configuration_ptr, -+ void *host_workspace, -+ void *device_workspace, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = new (host_workspace) Operator; -+ -+ status = op->initialize(args, device_workspace, stream); -+ -+ return status; -+ } -+ -+ /// Runs the kernel -+ virtual Status run( -+ void const *arguments_ptr, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = update_arguments_( -+ args, -+ static_cast(arguments_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = static_cast(host_workspace); -+ -+ status = op->update(args); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = op->run(stream); -+ -+ return status; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class GemmPlanarComplexOperation : public GemmOperationBase { -+public: -+ -+ using Operator = Operator_; -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+public: -+ -+ /// Constructor -+ GemmPlanarComplexOperation(char const *name = "unknown_gemm"): GemmOperationBase(name) { -+ -+ this->description_.gemm_kind = GemmKind::kPlanarComplex; -+ } -+ -+protected: -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status construct_arguments_( -+ OperatorArguments &operator_args, -+ GemmPlanarComplexConfiguration const *configuration) { -+ -+ operator_args.mode = cutlass::gemm::GemmUniversalMode::kBatched; -+ operator_args.problem_size = configuration->problem_size; -+ operator_args.batch_count = configuration->batch_count; -+ -+ -+ operator_args.lda_real = configuration->lda_real; -+ operator_args.lda_imag = configuration->lda_imag; -+ operator_args.ldb_real = configuration->ldb_real; -+ operator_args.ldb_imag = configuration->ldb_imag; -+ operator_args.ldc_real = configuration->ldc_real; -+ operator_args.ldc_imag = configuration->ldc_imag; -+ operator_args.ldd_real = configuration->ldd_real; -+ operator_args.ldd_imag = configuration->ldd_imag; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status update_arguments_( -+ OperatorArguments &operator_args, -+ GemmPlanarComplexArguments const *arguments) { -+ -+ if (arguments->pointer_mode == ScalarPointerMode::kHost) { -+ typename Operator::EpilogueOutputOp::Params params( -+ *static_cast const *>(arguments->alpha), -+ *static_cast const *>(arguments->beta) -+ ); -+ operator_args.epilogue = params; -+ } -+ else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ -+ typename Operator::EpilogueOutputOp::Params params( -+ static_cast const *>(arguments->alpha), -+ static_cast const *>(arguments->beta) -+ ); -+ operator_args.epilogue = params; -+ } -+ else { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ // update arguments -+ operator_args.ptr_A_real = arguments->A_real; -+ operator_args.ptr_A_imag = arguments->A_imag; -+ operator_args.ptr_B_real = arguments->B_real; -+ operator_args.ptr_B_imag = arguments->B_imag; -+ operator_args.ptr_C_real = arguments->C_real; -+ operator_args.ptr_C_imag = arguments->C_imag; -+ operator_args.ptr_D_real = arguments->D_real; -+ operator_args.ptr_D_imag = arguments->D_imag; -+ -+ operator_args.batch_stride_A = arguments->batch_stride_A_real; -+ operator_args.batch_stride_A_imag = arguments->batch_stride_A_imag; -+ operator_args.batch_stride_B = arguments->batch_stride_B_real; -+ operator_args.batch_stride_B_imag = arguments->batch_stride_B_imag; -+ operator_args.batch_stride_C = arguments->batch_stride_C_real; -+ operator_args.batch_stride_C_imag = arguments->batch_stride_C_imag; -+ operator_args.batch_stride_D = arguments->batch_stride_D_real; -+ operator_args.batch_stride_D_imag = arguments->batch_stride_D_imag; -+ -+ return Status::kSuccess; -+ } -+ -+public: -+ -+ /// Returns success if the operation can proceed -+ virtual Status can_implement( -+ void const *configuration_ptr, -+ void const *arguments_ptr) const { -+ -+ GemmPlanarComplexConfiguration const *configuration = -+ static_cast(configuration_ptr); -+ -+ GemmPlanarComplexArguments const *arguments = -+ static_cast(arguments_ptr); -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_(args, configuration); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = update_arguments_(args, arguments); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Operator::can_implement(args); -+ } -+ -+ /// Gets the host-side workspace -+ virtual uint64_t get_host_workspace_size( -+ void const *configuration) const { -+ -+ return sizeof(Operator); -+ } -+ -+ /// Gets the device-side workspace -+ virtual uint64_t get_device_workspace_size( -+ void const *configuration_ptr, -+ void const *arguments_ptr = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return 0; -+ } -+ -+ uint64_t size = Operator::get_workspace_size(args); -+ -+ return size; -+ } -+ -+ /// Initializes the workspace -+ virtual Status initialize( -+ void const *configuration_ptr, -+ void *host_workspace, -+ void *device_workspace, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = new (host_workspace) Operator; -+ -+ status = op->initialize(args, device_workspace, stream); -+ -+ return status; -+ } -+ -+ /// Runs the kernel -+ virtual Status run( -+ void const *arguments_ptr, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = update_arguments_( -+ args, -+ static_cast(arguments_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = static_cast(host_workspace); -+ -+ status = op->update(args); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = op->run(stream); -+ -+ return status; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class GemmPlanarComplexArrayOperation : public GemmOperationBase { -+public: -+ -+ using Operator = Operator_; -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+public: -+ -+ /// Constructor -+ GemmPlanarComplexArrayOperation(char const *name = "unknown_gemm"): GemmOperationBase(name) { -+ -+ this->description_.gemm_kind = GemmKind::kPlanarComplexArray; -+ } -+ -+protected: -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status construct_arguments_( -+ OperatorArguments &operator_args, -+ GemmPlanarComplexArrayConfiguration const *configuration) { -+ -+ operator_args.mode = cutlass::gemm::GemmUniversalMode::kArray; -+ operator_args.problem_size = configuration->problem_size; -+ operator_args.batch_count = configuration->batch_count; -+ -+ operator_args.lda_real = configuration->lda_real; -+ operator_args.lda_imag = configuration->lda_imag; -+ operator_args.ldb_real = configuration->ldb_real; -+ operator_args.ldb_imag = configuration->ldb_imag; -+ operator_args.ldc_real = configuration->ldc_real; -+ operator_args.ldc_imag = configuration->ldc_imag; -+ operator_args.ldd_real = configuration->ldd_real; -+ operator_args.ldd_imag = configuration->ldd_imag; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status update_arguments_( -+ OperatorArguments &operator_args, -+ GemmPlanarComplexArrayArguments const *arguments) { -+ -+ if (arguments->pointer_mode == ScalarPointerMode::kHost) { -+ typename Operator::EpilogueOutputOp::Params params( -+ *static_cast const *>(arguments->alpha), -+ *static_cast const *>(arguments->beta) -+ ); -+ operator_args.epilogue = params; -+ } -+ else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ -+ typename Operator::EpilogueOutputOp::Params params( -+ static_cast const *>(arguments->alpha), -+ static_cast const *>(arguments->beta) -+ ); -+ operator_args.epilogue = params; -+ } -+ else { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ // update arguments -+ operator_args.ptr_A_real = arguments->A_real; -+ operator_args.ptr_A_imag = arguments->A_imag; -+ operator_args.ptr_B_real = arguments->B_real; -+ operator_args.ptr_B_imag = arguments->B_imag; -+ operator_args.ptr_C_real = arguments->C_real; -+ operator_args.ptr_C_imag = arguments->C_imag; -+ operator_args.ptr_D_real = arguments->D_real; -+ operator_args.ptr_D_imag = arguments->D_imag; -+ -+ operator_args.ptr_M = arguments->M; -+ operator_args.ptr_N = arguments->N; -+ operator_args.ptr_K = arguments->K; -+ -+ return Status::kSuccess; -+ } -+ -+public: -+ -+ /// Returns success if the operation can proceed -+ virtual Status can_implement( -+ void const *configuration_ptr, -+ void const *arguments_ptr) const { -+ -+ GemmPlanarComplexArrayConfiguration const *configuration = -+ static_cast(configuration_ptr); -+ -+ GemmPlanarComplexArrayArguments const *arguments = -+ static_cast(arguments_ptr); -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_(args, configuration); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = update_arguments_(args, arguments); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Operator::can_implement(args); -+ } -+ -+ /// Gets the host-side workspace -+ virtual uint64_t get_host_workspace_size( -+ void const *configuration) const { -+ -+ return sizeof(Operator); -+ } -+ -+ /// Gets the device-side workspace -+ virtual uint64_t get_device_workspace_size( -+ void const *configuration_ptr, -+ void const *arguments_ptr = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return 0; -+ } -+ -+ uint64_t size = Operator::get_workspace_size(args); -+ -+ return size; -+ } -+ -+ /// Initializes the workspace -+ virtual Status initialize( -+ void const *configuration_ptr, -+ void *host_workspace, -+ void *device_workspace, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = new (host_workspace) Operator; -+ -+ status = op->initialize(args, device_workspace, stream); -+ -+ return status; -+ } -+ -+ /// Runs the kernel -+ virtual Status run( -+ void const *arguments_ptr, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = update_arguments_( -+ args, -+ static_cast(arguments_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = static_cast(host_workspace); -+ -+ status = op->update(args); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = op->run(stream); -+ -+ return status; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class GemmGroupedOperation : public GemmOperationBase { -+public: -+ -+ using Operator = Operator_; -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+public: -+ -+ /// Constructor -+ GemmGroupedOperation(char const *name = "unknown_gemm"): -+ GemmOperationBase(name) { -+ -+ this->description_.gemm_kind = GemmKind::kGrouped; -+ } -+ -+protected: -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status construct_arguments_( -+ OperatorArguments &op_args, -+ GemmGroupedConfiguration const *config) { -+ -+ op_args.problem_count = config->problem_count; -+ op_args.threadblock_count = config->threadblock_count; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status update_arguments_( -+ OperatorArguments &op_args, -+ GemmGroupedArguments const *arguments) { -+ -+ if (arguments->pointer_mode == ScalarPointerMode::kHost) { -+ -+ typename Operator::EpilogueOutputOp::Params params( -+ *static_cast(arguments->alpha), -+ *static_cast(arguments->beta) -+ ); -+ -+ op_args.output_op = params; -+ } -+ else if (arguments->pointer_mode == ScalarPointerMode::kDevice) { -+ -+ typename Operator::EpilogueOutputOp::Params params( -+ static_cast(arguments->alpha), -+ static_cast(arguments->beta) -+ ); -+ -+ op_args.output_op = params; -+ } -+ else { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ op_args.problem_sizes = arguments->problem_sizes; -+ -+ op_args.ptr_A = static_cast(arguments->ptr_A); -+ op_args.ptr_B = static_cast(arguments->ptr_B); -+ op_args.ptr_C = static_cast(arguments->ptr_C); -+ op_args.ptr_D = static_cast(arguments->ptr_D); -+ -+ op_args.lda = arguments->lda; -+ op_args.ldb = arguments->ldb; -+ op_args.ldc = arguments->ldc; -+ op_args.ldd = arguments->ldd; -+ -+ return Status::kSuccess; -+ } -+ -+public: -+ -+ /// Returns success if the operation can proceed -+ virtual Status can_implement( -+ void const *configuration_ptr, -+ void const *arguments_ptr) const { -+ -+ GemmGroupedConfiguration const *configuration = -+ static_cast(configuration_ptr); -+ -+ GemmGroupedArguments const *arguments = -+ static_cast(arguments_ptr); -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_(args, configuration); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = update_arguments_(args, arguments); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Operator::can_implement(args); -+ } -+ -+ /// Gets the host-side workspace -+ virtual uint64_t get_host_workspace_size( -+ void const *configuration) const { -+ -+ return sizeof(Operator); -+ } -+ -+ /// Gets the device-side workspace -+ virtual uint64_t get_device_workspace_size( -+ void const *configuration_ptr, -+ void const *arguments_ptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return 0; -+ } -+ -+ status = update_arguments_( -+ args, -+ static_cast(arguments_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return 0; -+ } -+ -+ uint64_t size = Operator::get_workspace_size(args); -+ -+ return size; -+ } -+ -+ /// Initializes the workspace -+ virtual Status initialize( -+ void const *configuration_ptr, -+ void *host_workspace, -+ void *device_workspace, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = new (host_workspace) Operator; -+ -+ status = op->initialize(args, device_workspace, stream); -+ -+ return status; -+ } -+ -+ /// Runs the kernel -+ virtual Status run( -+ void const *arguments_ptr, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = update_arguments_( -+ args, -+ static_cast(arguments_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = static_cast(host_workspace); -+ -+ status = op->update(args); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = op->run(stream); -+ -+ return status; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/src/gemm_operation_3x.hpp b/3rdparty/cutlass/tools/library/src/gemm_operation_3x.hpp -new file mode 100644 -index 0000000..895de5b ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/gemm_operation_3x.hpp -@@ -0,0 +1,292 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines operations for all GEMM operation kinds in CUTLASS Library. -+*/ -+ -+#pragma once -+#include "cutlass/cutlass.h" -+#include "cutlass/kernel_hardware_info.hpp" -+#include "cutlass/library/library.h" -+#include "library_internal.h" -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass::library { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class GemmOperation3xBase : public Operation { -+public: -+ using Operator = Operator_; -+ using OperatorArguments = typename Operator::Arguments; -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ // assuming all tensors use same type for StrideIndex -+ using StrideIndex = typename Operator::LayoutA::Index; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::CollectiveEpilogue::ElementCompute; -+ -+private: -+ -+ GemmDescription description_; -+ -+public: -+ -+ /// Constructor -+ GemmOperation3xBase(char const *name = "unknown_gemm", GemmKind gemm_kind_ = GemmKind::kGemm) { -+ -+ description_.name = name; -+ description_.provider = Provider::kCUTLASS; -+ description_.kind = OperationKind::kGemm; -+ description_.gemm_kind = gemm_kind_; -+ -+ description_.tile_description.threadblock_shape = make_Coord( -+ Operator::ThreadblockShape::kM, -+ Operator::ThreadblockShape::kN, -+ Operator::ThreadblockShape::kK); -+ -+ if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) { -+ description_.tile_description.cluster_shape = make_Coord( -+ Operator::ClusterShape::kM, -+ Operator::ClusterShape::kN, -+ Operator::ClusterShape::kK); -+ } -+ -+ description_.tile_description.threadblock_stages = Operator::kStages; -+ -+ description_.tile_description.warp_count = make_Coord( -+ Operator::WarpCount::kM, -+ Operator::WarpCount::kN, -+ Operator::WarpCount::kK); -+ -+ description_.tile_description.math_instruction.instruction_shape = make_Coord( -+ Operator::InstructionShape::kM, -+ Operator::InstructionShape::kN, -+ Operator::InstructionShape::kK); -+ -+ description_.tile_description.math_instruction.element_accumulator = -+ NumericTypeMap::kId; -+ -+ description_.tile_description.math_instruction.opcode_class = -+ OpcodeClassMap::kId; -+ -+ description_.tile_description.math_instruction.math_operation = -+ MathOperationMap::kId; -+ -+ description_.tile_description.minimum_compute_capability = -+ ArchMap::kMin; -+ -+ description_.tile_description.maximum_compute_capability = -+ ArchMap::kMax; -+ -+ description_.A = make_TensorDescription(Operator::kAlignmentA); -+ description_.B = make_TensorDescription(Operator::kAlignmentB); -+ description_.C = make_TensorDescription(Operator::kAlignmentC); -+ description_.element_epilogue = NumericTypeMap::kId; -+ -+ description_.split_k_mode = SplitKMode::kNone; -+ description_.transform_A = ComplexTransformMap::kId; -+ description_.transform_B = ComplexTransformMap::kId; -+ } -+ -+ /// Returns the description of the GEMM operation -+ virtual OperationDescription const & description() const { -+ return description_; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class GemmUniversal3xOperation : public GemmOperation3xBase { -+public: -+ -+ using Operator = Operator_; -+ using OperatorArguments = typename Operator::Arguments; -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ -+ using CollectiveMainloop = typename Operator::CollectiveMainloop; -+ using CollectiveEpilogue = typename Operator::CollectiveEpilogue; -+ using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; -+ -+public: -+ -+ /// Constructor -+ GemmUniversal3xOperation(char const *name = "unknown_gemm"): -+ GemmOperation3xBase(name, GemmKind::kUniversal) { -+ } -+ -+protected: -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status construct_arguments_( -+ OperatorArguments &operator_args, GemmUniversalConfiguration const *configuration) { -+ // NOTE: GemmUniversalConfiguration does not contain problem shapes or batch strides -+ // Do nothing here and construct kernel arguments in update_arguments_ instead -+ // We also cannot construct TMA descriptors without all the arguments available -+ -+ if (operator_args.hw_info.sm_count <= 0) { -+ operator_args.hw_info.sm_count = KernelHardwareInfo::query_device_multiprocessor_count(); -+ } -+ operator_args.mode = configuration->mode; -+ return Status::kSuccess; -+ } -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status update_arguments_( -+ OperatorArguments &operator_args, GemmUniversalArguments const *arguments) { -+ if (arguments->pointer_mode == ScalarPointerMode::kHost) { -+ typename ThreadEpilogueOp::Params params( -+ *static_cast(arguments->alpha), -+ *static_cast(arguments->beta)); -+ operator_args.epilogue_params.thread_params = params; -+ } -+ else if (arguments->pointer_mode == ScalarPointerMode::kDevice) { -+ typename ThreadEpilogueOp::Params params( -+ static_cast(arguments->alpha), -+ static_cast(arguments->beta)); -+ operator_args.epilogue_params.thread_params = params; -+ } -+ else { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ // TODO: type erase Arguments structure in 3.0 GEMM -+ operator_args.problem_shape = cute::make_shape( -+ arguments->problem_size.m(), -+ arguments->problem_size.n(), -+ arguments->problem_size.k(), -+ arguments->batch_count); -+ -+ // update arguments -+ operator_args.ptr_A = static_cast(arguments->A); -+ operator_args.ptr_B = static_cast(arguments->B); -+ operator_args.epilogue_params.ptr_C = static_cast(arguments->C); -+ operator_args.epilogue_params.ptr_D = static_cast(arguments->D); -+ -+ operator_args.dA = cute::make_int_tuple_from( -+ arguments->lda, arguments->batch_stride_A); -+ operator_args.dB = cute::make_int_tuple_from( -+ arguments->ldb, arguments->batch_stride_B); -+ operator_args.epilogue_params.dC = cute::make_int_tuple_from( -+ arguments->ldc, arguments->batch_stride_C); -+ operator_args.epilogue_params.dD = operator_args.epilogue_params.dC; -+ -+ return Status::kSuccess; -+ } -+ -+public: -+ -+ /// Returns success if the operation can proceed -+ Status can_implement( -+ void const *configuration_ptr, void const *arguments_ptr) const override { -+ -+ GemmUniversalArguments const *arguments = -+ static_cast(arguments_ptr); -+ -+ OperatorArguments args; -+ auto status = update_arguments_(args, arguments); -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Operator::can_implement(args); -+ } -+ -+ /// Gets the host-side workspace -+ uint64_t get_host_workspace_size(void const *configuration) const override { -+ return sizeof(Operator); -+ } -+ -+ /// Gets the device-side workspace -+ uint64_t get_device_workspace_size( -+ void const *configuration_ptr,void const *arguments_ptr) const override { -+ -+ OperatorArguments args; -+ auto status = update_arguments_( -+ args, static_cast(arguments_ptr)); -+ if (status != Status::kSuccess) { -+ return 0; -+ } -+ -+ uint64_t size = Operator::get_workspace_size(args); -+ return size; -+ } -+ -+ /// Initializes the workspace -+ Status initialize( -+ void const *configuration_ptr, -+ void *host_workspace, -+ void *device_workspace, -+ cudaStream_t stream = nullptr) const override { -+ Operator *op = new (host_workspace) Operator; -+ return Status::kSuccess; -+ } -+ -+ /// Runs the kernel -+ Status run( -+ void const *arguments_ptr, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const override { -+ -+ OperatorArguments args; -+ Status status = update_arguments_(args, static_cast(arguments_ptr)); -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = static_cast(host_workspace); -+ // We need to call initialize() since we have to rebuild TMA desc for every new set of args -+ status = op->run(args, device_workspace, stream); -+ return status; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass::library -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/src/handle.cu b/3rdparty/cutlass/tools/library/src/handle.cu -new file mode 100644 -index 0000000..fdfe251 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/handle.cu -@@ -0,0 +1,1172 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief CUTLASS Library handle. -+*/ -+#include -+#include -+#include -+ -+#include "cutlass/library/handle.h" -+#include "cutlass/library/singleton.h" -+#include "cutlass/library/util.h" -+ -+namespace cutlass { -+namespace library { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Constructor -+Handle::Handle( -+ cudaStream_t stream, -+ size_t workspace_size -+): -+ provider_(Provider::kCUTLASS), -+ stream_(stream), -+ workspace_(nullptr), -+ workspace_size_(0), -+ scalar_pointer_mode_(ScalarPointerMode::kHost), -+ last_operation_(nullptr) { -+ -+ int device_idx = -1; -+ -+ cudaError_t error = cudaGetDevice(&device_idx); -+ if (error != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() failed"); -+ } -+ -+ error = cudaGetDeviceProperties(&device_, device_idx); -+ if (error != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed"); -+ } -+ -+ set_workspace_size(workspace_size); -+ -+ Singleton::get(); -+} -+ -+/// Destructor -+Handle::~Handle() { -+ if (workspace_) { -+ -+ if (workspace_) { -+ cudaFree(workspace_); -+ } -+ -+ workspace_ = nullptr; -+ workspace_size_ = 0; -+ } -+} -+ -+/// Move constructor -+Handle::Handle(Handle && handle) { -+ device_ = handle.device_; -+ workspace_size_ = handle.workspace_size_; -+ workspace_ = handle.workspace_; -+ stream_ = handle.stream_; -+ scalar_pointer_mode_ = handle.scalar_pointer_mode_; -+ -+ handle.workspace_ = nullptr; -+ handle.workspace_size_ = 0; -+} -+ -+/// Move assignment operator -+Handle & Handle::operator=(Handle && handle) { -+ -+ provider_ = handle.provider_; -+ device_ = handle.device_; -+ workspace_size_ = handle.workspace_size_; -+ workspace_ = handle.workspace_; -+ stream_ = handle.stream_; -+ scalar_pointer_mode_ = handle.scalar_pointer_mode_; -+ -+ handle.workspace_ = nullptr; -+ handle.workspace_size_ = 0; -+ -+ return *this; -+} -+ -+int Handle::compute_capability() const { -+ return device_.major * 10 + device_.minor; -+} -+ -+/// Sets the current CUDA stream -+void Handle::set_stream(cudaStream_t stream) { -+ stream_ = stream; -+} -+ -+/// Gets the current CUDA stream -+cudaStream_t Handle::get_stream() const { -+ return stream_; -+} -+ -+/// Gets the current provider -+Provider Handle::get_provider() const { -+ return provider_; -+} -+ -+/// Sets the provider of operations -+void Handle::set_provider(Provider provider) { -+ provider_ = provider; -+} -+ -+/// Gets the device workspace size -+size_t Handle::get_workspace_size() const { -+ return workspace_size_; -+} -+ -+/// Gets a pointer to the device workspace allocation in Global Memory -+void *Handle::get_workspace() const { -+ return workspace_; -+} -+ -+/// Sets the size of device workspace, invalidating previous calls to get_device_workspace() -+void Handle::set_workspace_size(size_t bytes) { -+ if (bytes != workspace_size_) { -+ -+ if (workspace_) { -+ cudaFree(workspace_); -+ } -+ -+ workspace_ = nullptr; -+ workspace_size_ = bytes; -+ -+ if (workspace_size_) { -+ -+ cudaError_t error = cudaMalloc((void **)&workspace_, workspace_size_); -+ -+ if (error != cudaSuccess) { -+ throw std::runtime_error("Failed to allocate workspace"); -+ } -+ } -+ } -+ -+ if (workspace_) { -+ cudaError_t error = cudaMemset(workspace_, 0, workspace_size_); -+ -+ if (error != cudaSuccess) { -+ throw std::runtime_error("Failed to clear workspace"); -+ } -+ } -+} -+ -+/// Gets the scalar pointer mode -+ScalarPointerMode Handle::get_scalar_pointer_mode() const { -+ return scalar_pointer_mode_; -+} -+ -+/// Sets the scalar pointer mode -+void Handle::set_scalar_pointer_mode(ScalarPointerMode mode) { -+ scalar_pointer_mode_ = mode; -+} -+ -+/// Gets the last operation -+Operation const *Handle::get_last_operation() const { -+ return last_operation_; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Returns the maximum required alignment for each operator -+static int maximum_alignment_requirement(GemmDescription const &desc) { -+ return std::max( -+ std::max(desc.A.alignment, desc.B.alignment), desc.C.alignment); -+} -+ -+/// Returns the largest alignment (in units of elements) the problem satisfies, starting from a -+/// given upper limit. -+static int gemm_problem_alignment( -+ int M, -+ int N, -+ int K, -+ NumericTypeID element_A, -+ void const *ptr_A, -+ int64_t lda, -+ int64_t batch_stride_A, -+ NumericTypeID element_B, -+ void const *ptr_B, -+ int64_t ldb, -+ int64_t batch_stride_B, -+ NumericTypeID element_C, -+ void const * ptr_C, -+ int64_t ldc, -+ int64_t batch_stride_C, -+ void const * ptr_D, -+ int64_t ldd, -+ int64_t batch_stride_D, -+ int max_alignment_in_bytes = 16 -+) { -+ -+ void const *pointers[] = { -+ ptr_A, ptr_B, ptr_C, ptr_D -+ }; -+ -+ int64_t extents[] = { -+ M, N, K, lda, ldb, ldc, ldd, batch_stride_A, batch_stride_B, batch_stride_C, batch_stride_D -+ }; -+ -+ NumericTypeID elements[] = { -+ element_A, element_B, element_C -+ }; -+ -+ for (; max_alignment_in_bytes > 0; max_alignment_in_bytes /= 2) { -+ -+ bool satisfied = true; -+ -+ // Can pointers satisfy this? -+ for (void const *ptr : pointers) { -+ std::uintptr_t int_ptr = reinterpret_cast(ptr); -+ -+ if (int_ptr % max_alignment_in_bytes) { -+ satisfied = false; -+ break; -+ } -+ } -+ -+ if (!satisfied) { -+ continue; -+ } -+ -+ // Compute the maximum alignment based on element data types -+ int max_element_alignment = 0; -+ -+ for (NumericTypeID type_id : elements) { -+ int element_alignment = max_alignment_in_bytes * 8 / library::sizeof_bits(type_id); -+ max_element_alignment = std::max(max_element_alignment, element_alignment); -+ } -+ -+ // Can the problem size and leading dimensions satisfy this? -+ for (int64_t extent : extents) { -+ if (extent % max_element_alignment) { -+ satisfied = false; -+ break; -+ } -+ } -+ -+ if (!satisfied) { -+ continue; -+ } -+ -+ // Yes -+ return max_element_alignment; -+ } -+ -+ // No alignment satisfies this problem -+ return 0; -+} -+ -+/// Find the best kernel in descending order of preference. -+static Operation const * find_gemm_operation( -+ GemmOperationFunctionalMap::const_iterator operators_it, -+ GemmPreferenceKey const preference_key) { -+ -+ auto cc_it = operators_it->second.upper_bound(preference_key); -+ -+ if (cc_it == operators_it->second.begin()) { -+ return nullptr; -+ } -+ -+ Operation const *operation = nullptr; -+ -+ // Search in descending order of compute capability -+ do { -+ --cc_it; -+ -+ // Search tile sizes in order, for now. -+ for (auto const * op : cc_it->second) { -+ -+ GemmDescription const &desc = static_cast(op->description()); -+ -+ int min_cc = desc.tile_description.minimum_compute_capability; -+ int max_cc = desc.tile_description.maximum_compute_capability; -+ -+ int op_alignment = maximum_alignment_requirement(desc); -+ -+ if ((min_cc <= preference_key.compute_capability) && -+ (preference_key.compute_capability <= max_cc) && -+ (op_alignment <= preference_key.alignment)) { -+ -+ operation = op; -+ break; -+ } -+ } -+ } while (!operation && cc_it != operators_it->second.begin()); -+ -+ return operation; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Executes a GEMM computation: D <= alpha * A*B + beta * C -+Status Handle::gemm( -+ -+ int M, /// GEMM M dimension -+ int N, /// GEMM N dimension -+ int K, /// GEMM K dimension -+ -+ NumericTypeID element_compute, /// Data type of internal accumulation -+ -+ NumericTypeID element_scalar, /// Data type of alpha/beta scalars -+ -+ void const *alpha, /// Pointer to alpha scalar -+ -+ NumericTypeID element_A, /// Data type of A matrix elements -+ LayoutTypeID layout_A, /// Layout of A matrix -+ ComplexTransform transform_A, /// Complex transformation applied to A matrix - ignored for real-valued matrices -+ -+ void const * ptr_A, /// Pointer to A matrix in Global Memory -+ int64_t lda, /// Leading dimension of A matrix -+ -+ NumericTypeID element_B, /// Data type of B matrix elements -+ LayoutTypeID layout_B, /// Layout of B matrix -+ ComplexTransform transform_B, /// Complex transformation applied to B matrix - ignored for real-valued matrices -+ -+ void const * ptr_B, /// Pointer to B matrix in Global Memory -+ int64_t ldb, /// Leading dimension of B matrix -+ -+ void const * beta, /// Pointer to beta scalar -+ -+ NumericTypeID element_C, /// Data type of C and D matrices -+ -+ void const * ptr_C, /// Pointer to C matrix -+ int64_t ldc, /// Leading dimension of C matrix -+ -+ void * ptr_D, /// Pointer to D matrix -+ int64_t ldd /// Leading dimension of D matrix -+) { -+ -+ // -+ // Find the operation -+ // -+ -+ GemmFunctionalKey key( -+ provider_, -+ GemmKind::kGemm, -+ element_compute, -+ element_scalar, -+ element_A, -+ layout_A, -+ transform_A, -+ element_B, -+ layout_B, -+ transform_B, -+ element_C -+ ); -+ -+ auto operators_it = Singleton::get().operation_table.gemm_operations.find(key); -+ -+ if (operators_it == Singleton::get().operation_table.gemm_operations.end()) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ if (operators_it->second.empty()) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ // -+ // Compute the largest alignment restriction the kernel can satisfy. -+ // -+ -+ // Maximum alignment expectation among all kernels (in units of bytes) -+ int const kMaximumAlignmentSize = 16; -+ -+ int alignment = gemm_problem_alignment( -+ M, N, K, -+ element_A, ptr_A, lda, 0, -+ element_B, ptr_B, ldb, 0, -+ element_C, ptr_C, ldc, 0, -+ ptr_D, ldd, 0, kMaximumAlignmentSize -+ ); -+ -+ // -+ // Find the best kernel in descending order of preference. -+ // -+ -+ GemmPreferenceKey preference_key(compute_capability(), alignment); -+ -+ Operation const *operation = find_gemm_operation(operators_it, preference_key); -+ -+ if (!operation) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ last_operation_ = operation; -+ -+ // -+ // Configure operation -+ // -+ -+ GemmConfiguration configuration{ -+ {M, N, K}, -+ lda, -+ ldb, -+ ldc, -+ ldd, -+ 1 -+ }; -+ -+ // Query host work space size -+ uint64_t host_workspace_size_needed = operation->get_host_workspace_size(&configuration); -+ -+ if (uint64_t(kHostWorkspaceSize) < host_workspace_size_needed) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ char host_workspace[kHostWorkspaceSize]; -+ -+ // Query device workspace size -+ uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration); -+ -+ if (uint64_t(workspace_size_) < device_workspace_size_needed) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ // Initialize host and device workspaces -+ Status status = operation->initialize( -+ &configuration, -+ host_workspace, -+ workspace_, -+ stream_); -+ -+ if (status != cutlass::Status::kSuccess) { -+ return status; -+ } -+ -+ // Run the operator -+ GemmArguments arguments{ -+ ptr_A, -+ ptr_B, -+ ptr_C, -+ ptr_D, -+ alpha, -+ beta, -+ scalar_pointer_mode_ -+ }; -+ -+ return operation->run(&arguments, host_workspace, workspace_, stream_); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Executes a GEMM computation: D <= alpha * A*B + beta * C. -+// -+// Supports batched-strided, batched array or split-K serial or split-K parallel. -+// -+Status Handle::gemm_universal( -+ -+ GemmUniversalMode mode, /// indicates the mode in which the kUniversal GEMM is launched -+ -+ int M, /// GEMM M dimension -+ int N, /// GEMM N dimension -+ int K, /// GEMM K dimension -+ -+ NumericTypeID element_compute, /// Data type of internal accumulation -+ -+ NumericTypeID element_scalar, /// Data type of alpha/beta scalars -+ -+ void const *alpha, /// Pointer to alpha scalar -+ -+ NumericTypeID element_A, /// Data type of A matrix elements -+ LayoutTypeID layout_A, /// Layout of A matrix -+ ComplexTransform transform_A, /// Complex transformation applied to A matrix - ignored for real-valued matrices -+ -+ void const * ptr_A, /// Pointer to A matrix in Global Memory -+ int64_t lda, /// Leading dimension of A matrix -+ -+ NumericTypeID element_B, /// Data type of B matrix elements -+ LayoutTypeID layout_B, /// Layout of B matrix -+ ComplexTransform transform_B, /// Complex transformation applied to B matrix - ignored for real-valued matrices -+ -+ void const * ptr_B, /// Pointer to B matrix in Global Memory -+ int64_t ldb, /// Leading dimension of B matrix -+ -+ void const * beta, /// Pointer to beta scalar -+ -+ NumericTypeID element_C, /// Data type of C and D matrices -+ -+ void const * ptr_C, /// Pointer to C matrix -+ int64_t ldc, /// Leading dimension of C matrix -+ -+ void * ptr_D, /// Pointer to D matrix -+ int64_t ldd, /// Leading dimension of D matrix -+ -+ int batch_count, /// Batch count or number of split-K slices -+ -+ int64_t batch_stride_A, /// Batch stride of A operand -+ int64_t batch_stride_B, /// Batch stride of B operand -+ int64_t batch_stride_C, /// Batch stride of C operand -+ int64_t batch_stride_D /// Batch stride of D operand -+) { -+ -+ // -+ // Find the operation -+ // -+ -+ GemmFunctionalKey key( -+ provider_, -+ GemmKind::kUniversal, -+ element_compute, -+ element_scalar, -+ element_A, -+ layout_A, -+ transform_A, -+ element_B, -+ layout_B, -+ transform_B, -+ element_C -+ ); -+ -+ auto operators_it = Singleton::get().operation_table.gemm_operations.find(key); -+ -+ if (operators_it == Singleton::get().operation_table.gemm_operations.end()) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ if (operators_it->second.empty()) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ // -+ // Compute the largest alignment restriction the kernel can satisfy. -+ // -+ -+ // Maximum alignment expectation among all kernels (in units of bytes) -+ int const kMaximumAlignmentSize = 16; -+ -+ void const *ptr_A_check = ptr_A; -+ void const *ptr_B_check = ptr_B; -+ void const *ptr_C_check = ptr_C; -+ void * ptr_D_check = ptr_D; -+ -+ // Ignore alignment of pointers to pointers. We can't check this from the host, -+ // as each batch index has its own pointer in device memory. -+ if (mode == GemmUniversalMode::kArray) { -+ ptr_A_check = nullptr; -+ ptr_B_check = nullptr; -+ ptr_C_check = nullptr; -+ ptr_D_check = nullptr; -+ } -+ -+ int alignment = gemm_problem_alignment( -+ M, N, K, -+ element_A, ptr_A_check, lda, 0, -+ element_B, ptr_B_check, ldb, 0, -+ element_C, ptr_C_check, ldc, 0, -+ ptr_D_check, ldd, 0, kMaximumAlignmentSize -+ ); -+ -+ // -+ // Find the best kernel in descending order of preference. -+ // -+ -+ GemmPreferenceKey preference_key(compute_capability(), alignment); -+ -+ Operation const *operation = find_gemm_operation(operators_it, preference_key); -+ -+ if (!operation) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ last_operation_ = operation; -+ -+ // -+ // Configure operation -+ // -+ -+ GemmUniversalConfiguration configuration{ -+ mode, -+ {M, N, K}, -+ batch_count, -+ lda, -+ ldb, -+ ldc, -+ ldd -+ }; -+ -+ // Query host work space size -+ uint64_t host_workspace_size_needed = operation->get_host_workspace_size(&configuration); -+ -+ if (uint64_t(kHostWorkspaceSize) < host_workspace_size_needed) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ char host_workspace[kHostWorkspaceSize]; -+ -+ GemmUniversalArguments arguments{ -+ {M, N, K}, -+ batch_count, -+ ptr_A, -+ ptr_B, -+ ptr_C, -+ ptr_D, -+ alpha, -+ beta, -+ scalar_pointer_mode_, -+ lda, -+ ldb, -+ ldc, -+ ldd, -+ batch_stride_A, -+ batch_stride_B, -+ batch_stride_C, -+ batch_stride_D -+ }; -+ -+ // Query device workspace size -+ uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration, &arguments); -+ -+ if (uint64_t(workspace_size_) < device_workspace_size_needed) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ // Initialize host and device workspaces -+ Status status = operation->initialize( -+ &configuration, -+ host_workspace, -+ workspace_, -+ stream_); -+ -+ if (status != cutlass::Status::kSuccess) { -+ return status; -+ } -+ -+ // Run the operator -+ -+ return operation->run(&arguments, host_workspace, workspace_, stream_); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Planar complex GEMM -+Status Handle::gemm_planar_complex( -+ -+ int M, /// GEMM M dimension -+ int N, /// GEMM N dimension -+ int K, /// GEMM K dimension -+ -+ NumericTypeID element_compute, /// Data type of internal accumulation -+ -+ NumericTypeID element_scalar, /// Data type of alpha/beta scalars -+ -+ void const *alpha, /// Pointer to alpha scalar -+ -+ NumericTypeID element_A, /// Data type of A matrix elements -+ LayoutTypeID layout_A, /// Layout of A matrix -+ ComplexTransform transform_A, /// Complex transformation applied to A matrix -+ -+ void const * ptr_A_real, /// Pointer to real part of A matrix -+ void const * ptr_A_imag, /// Pointer to imaginary part of A matrix -+ int64_t lda_real, /// Leading dimension of real part of A matrix -+ int64_t lda_imag, /// Leading dimension of imaginary part of A matrix -+ -+ NumericTypeID element_B, /// Data type of B matrix elements -+ LayoutTypeID layout_B, /// Layout of B matrix -+ ComplexTransform transform_B, /// Complex transformation applied to B matrix -+ -+ void const * ptr_B_real, /// Pointer to real part of B matrix -+ void const * ptr_B_imag, /// Pointer to imaginary part of B matrix -+ int64_t ldb_real, /// Leading dimension of real part of B matrix -+ int64_t ldb_imag, /// Leading dimension of imaginary part of B matrix -+ -+ void const * beta, /// Pointer to beta scalar -+ -+ NumericTypeID element_C, /// Data type of C and D matrix -+ -+ void const * ptr_C_real, /// Pointer to real part of C matrix -+ void const * ptr_C_imag, /// Pointer to imaginary part of C matrix -+ int64_t ldc_real, /// Leading dimension of real part of C matrix -+ int64_t ldc_imag, /// Leading dimension of imaginary part of C matrix -+ -+ void * ptr_D_real, /// Pointer to real part of D matrix -+ void * ptr_D_imag, /// Pointer to imaginary part of D matrix -+ int64_t ldd_real, /// Leading dimension of real part of D matrix -+ int64_t ldd_imag, /// Leading dimension of imaginary part of D matrix -+ -+ int batch_count, /// Number of batched GEMMs to execute -+ -+ int64_t batch_stride_A_real, -+ int64_t batch_stride_A_imag, -+ -+ int64_t batch_stride_B_real, -+ int64_t batch_stride_B_imag, -+ -+ int64_t batch_stride_C_real, -+ int64_t batch_stride_C_imag, -+ -+ int64_t batch_stride_D_real, -+ int64_t batch_stride_D_imag -+) { -+ -+ // -+ // Find the operation -+ // -+ -+ GemmFunctionalKey key( -+ provider_, -+ GemmKind::kPlanarComplex, -+ element_compute, -+ element_scalar, -+ element_A, -+ layout_A, -+ transform_A, -+ element_B, -+ layout_B, -+ transform_B, -+ element_C -+ ); -+ -+ auto operators_it = Singleton::get().operation_table.gemm_operations.find(key); -+ -+ if (operators_it == Singleton::get().operation_table.gemm_operations.end()) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ if (operators_it->second.empty()) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ // -+ // Compute the largest alignment restriction the kernel can satisfy. -+ // -+ -+ // Maximum alignment expectation among all kernels (in units of bytes) -+ int const kMaximumAlignmentSize = 16; -+ -+ int alignment = std::max( -+ gemm_problem_alignment( -+ M, N, K, -+ element_A, ptr_A_real, lda_real, batch_stride_A_real, -+ element_B, ptr_B_real, ldb_real, batch_stride_B_real, -+ element_C, ptr_C_real, ldc_real, batch_stride_C_real, -+ ptr_D_real, ldd_real, batch_stride_D_real, kMaximumAlignmentSize -+ ), -+ gemm_problem_alignment( -+ M, N, K, -+ element_A, ptr_A_imag, lda_imag, batch_stride_A_imag, -+ element_B, ptr_B_imag, ldb_imag, batch_stride_B_imag, -+ element_C, ptr_C_imag, ldc_imag, batch_stride_C_imag, -+ ptr_D_imag, ldd_imag, batch_stride_D_imag, kMaximumAlignmentSize -+ ) -+ ); -+ -+ // -+ // Find the best kernel in descending order of preference. -+ // -+ -+ GemmPreferenceKey preference_key(compute_capability(), alignment); -+ -+ Operation const *operation = find_gemm_operation(operators_it, preference_key); -+ -+ if (!operation) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ last_operation_ = operation; -+ -+ // -+ // Configure operation -+ // -+ -+ GemmPlanarComplexConfiguration configuration{ -+ GemmUniversalMode::kBatched, -+ {M, N, K}, -+ batch_count, -+ lda_real, -+ lda_imag, -+ ldb_real, -+ ldb_imag, -+ ldc_real, -+ ldc_imag, -+ ldd_real, -+ ldd_imag -+ }; -+ -+ // Query host work space size -+ uint64_t host_workspace_size_needed = operation->get_host_workspace_size(&configuration); -+ -+ if (uint64_t(kHostWorkspaceSize) < host_workspace_size_needed) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ char host_workspace[kHostWorkspaceSize]; -+ -+ // Query device workspace size -+ uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration); -+ -+ if (uint64_t(workspace_size_) < device_workspace_size_needed) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ // Initialize host and device workspaces -+ Status status = operation->initialize( -+ &configuration, -+ host_workspace, -+ workspace_, -+ stream_); -+ -+ if (status != cutlass::Status::kSuccess) { -+ return status; -+ } -+ -+ // Run the operator -+ GemmPlanarComplexArguments arguments{ -+ ptr_A_real, -+ ptr_A_imag, -+ ptr_B_real, -+ ptr_B_imag, -+ ptr_C_real, -+ ptr_C_imag, -+ ptr_D_real, -+ ptr_D_imag, -+ alpha, -+ beta, -+ scalar_pointer_mode_, -+ batch_stride_A_real, -+ batch_stride_A_imag, -+ batch_stride_B_real, -+ batch_stride_B_imag, -+ batch_stride_C_real, -+ batch_stride_C_imag, -+ batch_stride_D_real, -+ batch_stride_D_imag -+ }; -+ -+ return operation->run(&arguments, host_workspace, workspace_, stream_); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Planar complex batched GEMM loading pointers from arrays in global memory -+Status Handle::gemm_planar_complex_array( -+ -+ int expected_M, /// Expected GEMM M dimension (used for sizing CUDA grid) -+ int expected_N, /// Expected GEMM N dimension (used for sizing CUDA grid) -+ int expected_K, /// Expected GEMM K dimension -+ int batch_count, /// Number of independent GEMM computations to execute -+ -+ int const *M, /// Array containing the GEMM M dimension for each batch index -+ int const *N, /// Array containing the GEMM N dimension for each batch index -+ int const *K, /// Array containing the GEMM K dimension for each batch index -+ -+ NumericTypeID element_compute, /// Data type of internal accumulation -+ -+ NumericTypeID element_scalar, /// Data type of alpha/beta scalars -+ -+ void const *alpha, /// Pointer to alpha scalar -+ -+ NumericTypeID element_A, /// Data type of A matrix elements -+ LayoutTypeID layout_A, /// Layout of A matrix -+ ComplexTransform transform_A, /// Complex transformation applied to A matrix -+ -+ void const * const * ptr_A_real, /// Pointer to array containing pointers to real part of A matrices -+ void const * const * ptr_A_imag, /// Pointer to array containing pointers to imaginary part of A matrices -+ -+ int64_t lda_real, /// Leading dimension of real part of A matrix -+ int64_t lda_imag, /// Leading dimension of imaginary part of A matrix -+ -+ NumericTypeID element_B, /// Data type of B matrix elements -+ LayoutTypeID layout_B, /// Layout of B matrix -+ ComplexTransform transform_B, /// Complex transformation applied to B matrix -+ -+ void const * const * ptr_B_real, /// Pointer to array containing pointers to real part of B matrices -+ void const * const * ptr_B_imag, /// Pointer to array containing pointers to imaginary part of B matrices -+ -+ int64_t ldb_real, /// Leading dimension of real part of B matrix -+ int64_t ldb_imag, /// Leading dimension of imaginary part of B matrix -+ -+ void const * beta, /// Pointer to beta scalar -+ -+ NumericTypeID element_C, /// Data type of C and D matrix -+ -+ void const * const * ptr_C_real, /// Pointer to array containing pointers to real part of C matrices -+ void const * const * ptr_C_imag, /// Pointer to array containing poitners to imaginary part of C matrices -+ -+ int64_t ldc_real, /// Leading dimension of real part of C matrix -+ int64_t ldc_imag, /// Leading dimension of imaginary part of C matrix -+ -+ void * const * ptr_D_real, /// Pointer to array containing pointers to real part of D matrices -+ void * const * ptr_D_imag, /// Pointer to array containing poitners to imaginary part of D matrices -+ -+ int64_t ldd_real, /// Leading dimension of real part of D matrix -+ int64_t ldd_imag /// Leading dimension of imaginary part of D matrix -+) { -+ -+ // -+ // Find the operation -+ // -+ -+ GemmFunctionalKey key( -+ provider_, -+ GemmKind::kPlanarComplexArray, -+ element_compute, -+ element_scalar, -+ element_A, -+ layout_A, -+ transform_A, -+ element_B, -+ layout_B, -+ transform_B, -+ element_C -+ ); -+ -+ auto operators_it = Singleton::get().operation_table.gemm_operations.find(key); -+ -+ if (operators_it == Singleton::get().operation_table.gemm_operations.end()) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ if (operators_it->second.empty()) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ // -+ // Compute the largest alignment restriction the kernel can satisfy. -+ // -+ -+ // Maximum alignment expectation among all kernels (in units of bytes) -+ int const kMaximumAlignmentSize = 16; -+ -+ int alignment = std::max( -+ gemm_problem_alignment( -+ expected_M, expected_N, expected_K, -+ element_A, nullptr, lda_real, 0, -+ element_B, nullptr, ldb_real, 0, -+ element_C, nullptr, ldc_real, 0, -+ nullptr, ldd_real, 0, kMaximumAlignmentSize -+ ), -+ gemm_problem_alignment( -+ expected_M, expected_N, expected_K, -+ element_A, nullptr, lda_imag, 0, -+ element_B, nullptr, ldb_imag, 0, -+ element_C, nullptr, ldc_imag, 0, -+ nullptr, ldd_imag, 0, kMaximumAlignmentSize -+ ) -+ ); -+ -+ // -+ // Find the best kernel in descending order of preference. -+ // -+ -+ GemmPreferenceKey preference_key(compute_capability(), alignment); -+ -+ Operation const *operation = find_gemm_operation(operators_it, preference_key); -+ -+ if (!operation) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ last_operation_ = operation; -+ -+ // -+ // Configure operation -+ // -+ -+ GemmPlanarComplexArrayConfiguration configuration{ -+ {expected_M, expected_N, expected_K}, -+ batch_count, -+ lda_real, -+ lda_imag, -+ ldb_real, -+ ldb_imag, -+ ldc_real, -+ ldc_imag, -+ ldd_real, -+ ldd_imag -+ }; -+ -+ // Query host work space size -+ uint64_t host_workspace_size_needed = operation->get_host_workspace_size(&configuration); -+ -+ if (uint64_t(kHostWorkspaceSize) < host_workspace_size_needed) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ char host_workspace[kHostWorkspaceSize]; -+ -+ // Query device workspace size -+ uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration); -+ -+ if (uint64_t(workspace_size_) < device_workspace_size_needed) { -+ return cutlass::Status::kErrorNotSupported; -+ } -+ -+ // Initialize host and device workspaces -+ Status status = operation->initialize( -+ &configuration, -+ host_workspace, -+ workspace_, -+ stream_); -+ -+ if (status != cutlass::Status::kSuccess) { -+ return status; -+ } -+ -+ // Run the operator -+ GemmPlanarComplexArrayArguments arguments{ -+ M, N, K, -+ ptr_A_real, -+ ptr_A_imag, -+ ptr_B_real, -+ ptr_B_imag, -+ ptr_C_real, -+ ptr_C_imag, -+ ptr_D_real, -+ ptr_D_imag, -+ alpha, -+ beta, -+ scalar_pointer_mode_ -+ }; -+ -+ return operation->run(&arguments, host_workspace, workspace_, stream_); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Finds conv operation instances with Conv::ElementC = Reduction::ElementWorkspace -+Operation const* find_conv_operation_for_parallel_reduction(Operation const *operation) { -+ -+ ConvDescription const &conv_desc = -+ static_cast(operation->description()); -+ -+ // if the curren conv operation accumulator and output data type match return operation -+ if(conv_desc.tile_description.math_instruction.element_accumulator == conv_desc.C.element) { -+ return operation; -+ } -+ -+ // find conv operation to match conv output and reduction workspace data type -+ ConvFunctionalKey key( -+ library::Provider::kCUTLASS, -+ conv_desc.conv_kind, -+ conv_desc.A.element, -+ conv_desc.A.layout, -+ conv_desc.B.element, -+ conv_desc.B.layout, -+ conv_desc.tile_description.math_instruction.element_accumulator, -+ conv_desc.C.layout, -+ conv_desc.tile_description.math_instruction.element_accumulator, -+ conv_desc.element_epilogue); -+ -+ // conv operation table for conv2d or conv3d -+ auto conv_operations = (conv_desc.kind == OperationKind::kConv2d) ? -+ Singleton::get().operation_table.conv2d_operations : -+ Singleton::get().operation_table.conv3d_operations; -+ -+ // find ConvFunctionalKey in convolution operation table -+ auto operators_it = conv_operations.find(key); -+ -+ if (operators_it == conv_operations.end()) { -+ return nullptr; -+ } -+ -+ if (operators_it->second.empty()) { -+ return nullptr; -+ } -+ -+ // conv operation for same compute capability and iterator algorithm -+ ConvPreferenceKey preference_key( -+ conv_desc.tile_description.minimum_compute_capability, -+ conv_desc.iterator_algorithm); -+ -+ auto it = operators_it->second.find(preference_key); -+ -+ if(it == operators_it->second.end()) { -+ return nullptr; -+ } -+ -+ // return matching conv opertion (same tile sizes and instruction) -+ for (auto op : it->second) { -+ if (op->description().tile_description == operation->description().tile_description) { -+ return op; -+ } -+ } -+ -+ return nullptr; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Finds gemm operation instances with Gemm::ElementC = Reduction::ElementWorkspace -+Operation const* find_gemm_operation_for_parallel_reduction(Operation const *operation) { -+ -+ GemmDescription const &gemm_desc = -+ static_cast(operation->description()); -+ -+ // if the curren gemm operation accumulator and output data type match return operation -+ if(gemm_desc.tile_description.math_instruction.element_accumulator == gemm_desc.C.element) { -+ return operation; -+ } -+ -+ // find gemm operation to match gemm output and reduction workspace data type -+ GemmFunctionalKey key( -+ library::Provider::kCUTLASS, -+ gemm_desc.gemm_kind, -+ gemm_desc.tile_description.math_instruction.element_accumulator, -+ gemm_desc.element_epilogue, -+ gemm_desc.A.element, -+ gemm_desc.A.layout, -+ gemm_desc.transform_A, -+ gemm_desc.B.element, -+ gemm_desc.B.layout, -+ gemm_desc.transform_B, -+ gemm_desc.tile_description.math_instruction.element_accumulator); -+ -+ // gemm operation table -+ auto gemm_operations = Singleton::get().operation_table.gemm_operations; -+ -+ // find ConvFunctionalKey in gemm operation table -+ auto operators_it = gemm_operations.find(key); -+ -+ if (operators_it == gemm_operations.end()) { -+ return nullptr; -+ } -+ -+ if (operators_it->second.empty()) { -+ return nullptr; -+ } -+ -+ // A and B uses the same alignment in the generator.py -+ int alignment = gemm_desc.A.alignment; -+ -+ // gemm operation for same compute capability and iterator algorithm -+ GemmPreferenceKey preference_key( -+ gemm_desc.tile_description.minimum_compute_capability, -+ alignment); -+ -+ return find_gemm_operation(operators_it, preference_key); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/src/library_internal.h b/3rdparty/cutlass/tools/library/src/library_internal.h -new file mode 100644 -index 0000000..e9739e3 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/library_internal.h -@@ -0,0 +1,356 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! -+ \file -+ -+ \brief CUTLASS Library is an object-oriented approach to managing operations implemented by CUTLASS. -+ -+ Generally, -+ -+ description - compile-time constant parameters used to instantiate an operation -+ -+ configuration - runtime parameters with computationally expensive initialization -+ -+ arguments - runtime parameters that may be passed to an initialized operation with low -+ computational overhead -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/complex.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/arch/arch.h" -+#include "cutlass/arch/mma.h" -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/library/library.h" -+#include "cutlass/library/arch_mappings.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template struct NumericTypeMap; -+ -+template <> struct NumericTypeMap { -+ static NumericTypeID const kId = NumericTypeID::kB1; -+}; -+ -+template <> struct NumericTypeMap { -+ static NumericTypeID const kId = NumericTypeID::kS4; -+}; -+ -+template <> struct NumericTypeMap { -+ static NumericTypeID const kId = NumericTypeID::kS8; -+}; -+ -+template <> struct NumericTypeMap { -+ static NumericTypeID const kId = NumericTypeID::kS16; -+}; -+ -+template <> struct NumericTypeMap { -+ static NumericTypeID const kId = NumericTypeID::kS32; -+}; -+ -+template <> struct NumericTypeMap { -+ static NumericTypeID const kId = NumericTypeID::kS64; -+}; -+ -+template <> struct NumericTypeMap { -+ static NumericTypeID const kId = NumericTypeID::kU4; -+}; -+ -+template <> struct NumericTypeMap { -+ static NumericTypeID const kId = NumericTypeID::kU8; -+}; -+ -+template <> struct NumericTypeMap { -+ static NumericTypeID const kId = NumericTypeID::kU16; -+}; -+ -+template <> struct NumericTypeMap { -+ static NumericTypeID const kId = NumericTypeID::kU32; -+}; -+ -+template <> struct NumericTypeMap { -+ static NumericTypeID const kId = NumericTypeID::kU64; -+}; -+ -+template <> struct NumericTypeMap { -+ static NumericTypeID const kId = NumericTypeID::kF16; -+}; -+ -+template <> struct NumericTypeMap { -+ static NumericTypeID const kId = NumericTypeID::kF32; -+}; -+ -+template <> struct NumericTypeMap { -+ static NumericTypeID const kId = NumericTypeID::kF64; -+}; -+ -+template <> struct NumericTypeMap > { -+ static NumericTypeID const kId = NumericTypeID::kCF16; -+}; -+ -+template <> struct NumericTypeMap > { -+ static NumericTypeID const kId = NumericTypeID::kCF32; -+}; -+ -+template <> struct NumericTypeMap > { -+ static NumericTypeID const kId = NumericTypeID::kCF64; -+}; -+ -+template <> struct NumericTypeMap { -+ static NumericTypeID const kId = NumericTypeID::kBF16; -+}; -+ -+template <> struct NumericTypeMap { -+ static NumericTypeID const kId = NumericTypeID::kTF32; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template struct MathOperationMap { -+ static MathOperationID const kId = MathOperationID::kInvalid; -+}; -+ -+template <> struct MathOperationMap { -+ static MathOperationID const kId = MathOperationID::kMultiplyAdd; -+}; -+ -+template <> struct MathOperationMap { -+ static MathOperationID const kId = MathOperationID::kMultiplyAddFastBF16; -+}; -+ -+template <> struct MathOperationMap { -+ static MathOperationID const kId = MathOperationID::kMultiplyAddFastF16; -+}; -+ -+template <> struct MathOperationMap { -+ static MathOperationID const kId = MathOperationID::kMultiplyAddSaturate; -+}; -+ -+template <> struct MathOperationMap { -+ static MathOperationID const kId = MathOperationID::kMultiplyAddComplex; -+}; -+ -+template <> struct MathOperationMap { -+ static MathOperationID const kId = MathOperationID::kMultiplyAddGaussianComplex; -+}; -+ -+template <> struct MathOperationMap { -+ static MathOperationID const kId = MathOperationID::kXorPopc; -+}; -+ -+ -+template <> struct MathOperationMap { -+ static MathOperationID const kId = MathOperationID::kMultiplyAddFastF32; -+}; -+ -+template <> struct MathOperationMap { -+ static MathOperationID const kId = MathOperationID::kMultiplyAddComplexFastF32; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template struct LayoutMap; -+ -+template <> struct LayoutMap { -+ static LayoutTypeID const kId = LayoutTypeID::kColumnMajor; -+}; -+ -+template <> struct LayoutMap { -+ static LayoutTypeID const kId = LayoutTypeID::kRowMajor; -+}; -+ -+template <> struct LayoutMap> { -+ static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK2; -+}; -+ -+template <> struct LayoutMap> { -+ static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK2; -+}; -+ -+template <> struct LayoutMap> { -+ static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK4; -+}; -+ -+template <> struct LayoutMap> { -+ static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK4; -+}; -+ -+template <> struct LayoutMap> { -+ static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK16; -+}; -+ -+template <> struct LayoutMap> { -+ static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK16; -+}; -+ -+template <> struct LayoutMap> { -+ static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK32; -+}; -+ -+template <> struct LayoutMap> { -+ static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK32; -+}; -+ -+template <> struct LayoutMap> { -+ static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK64; -+}; -+ -+template <> struct LayoutMap> { -+ static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK64; -+}; -+ -+template <> struct LayoutMap { -+ static LayoutTypeID const kId = LayoutTypeID::kTensorNHWC; -+}; -+ -+template <> struct LayoutMap { -+ static LayoutTypeID const kId = LayoutTypeID::kTensorNDHWC; -+}; -+ -+template <> struct LayoutMap> { -+ static LayoutTypeID const kId = LayoutTypeID::kTensorNC32HW32; -+}; -+ -+template <> struct LayoutMap> { -+ static LayoutTypeID const kId = LayoutTypeID::kTensorNC64HW64; -+}; -+ -+template <> struct LayoutMap> { -+ static LayoutTypeID const kId = LayoutTypeID::kTensorC32RSK32; -+}; -+ -+template <> struct LayoutMap> { -+ static LayoutTypeID const kId = LayoutTypeID::kTensorC64RSK64; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template struct OpcodeClassMap; -+ -+template <> struct OpcodeClassMap { -+ static OpcodeClassID const kId = OpcodeClassID::kSimt; -+}; -+ -+template <> struct OpcodeClassMap { -+ static OpcodeClassID const kId = OpcodeClassID::kTensorOp; -+}; -+ -+template <> struct OpcodeClassMap { -+ static OpcodeClassID const kId = OpcodeClassID::kWmmaTensorOp; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template struct ComplexTransformMap; -+ -+template <> struct ComplexTransformMap { -+ static cutlass::library::ComplexTransform const kId = cutlass::library::ComplexTransform::kNone; -+}; -+ -+template <> struct ComplexTransformMap { -+ static cutlass::library::ComplexTransform const kId = cutlass::library::ComplexTransform::kConjugate; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template struct ConvModeMap; -+ -+template <> struct ConvModeMap { -+ static ConvModeID const kId = ConvModeID::kCrossCorrelation; -+}; -+ -+template <> struct ConvModeMap { -+ static ConvModeID const kId = ConvModeID::kConvolution; -+}; -+ -+ -+template struct ConvKindMap; -+ -+template <> struct ConvKindMap { -+ static ConvKind const kId = ConvKind::kFprop; -+}; -+ -+template <> struct ConvKindMap { -+ static ConvKind const kId = ConvKind::kDgrad; -+}; -+ -+template <> struct ConvKindMap { -+ static ConvKind const kId = ConvKind::kWgrad; -+}; -+ -+ -+template struct IteratorAlgorithmMap; -+ -+template <> struct IteratorAlgorithmMap { -+ static IteratorAlgorithmID const kId = IteratorAlgorithmID::kAnalytic; -+}; -+ -+template <> struct IteratorAlgorithmMap { -+ static IteratorAlgorithmID const kId = IteratorAlgorithmID::kOptimized; -+}; -+ -+template <> struct IteratorAlgorithmMap { -+ static IteratorAlgorithmID const kId = IteratorAlgorithmID::kFixedChannels; -+}; -+ -+template <> struct IteratorAlgorithmMap { -+ static IteratorAlgorithmID const kId = IteratorAlgorithmID::kFewChannels; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+TensorDescription make_TensorDescription(int alignment = 1) { -+ TensorDescription desc; -+ -+ desc.element = NumericTypeMap::kId; -+ desc.layout = LayoutMap::kId; -+ desc.alignment = alignment; -+ desc.log_extent_range = int(sizeof(typename Layout::TensorCoord::Index) - 1) * 8; -+ desc.log_stride_range = int(sizeof(typename Layout::Stride::Index) - 1) * 8; -+ -+ return desc; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/src/operation_table.cu b/3rdparty/cutlass/tools/library/src/operation_table.cu -new file mode 100644 -index 0000000..d3799c3 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/operation_table.cu -@@ -0,0 +1,143 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* -+ \file -+ \brief Defines a data structure in which a set of functionally equivalent library::Operation -+ instances may be queried. -+*/ -+ -+#include "cutlass/library/operation_table.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+void OperationTable::append(Manifest const &manifest) { -+ -+ // Insert operations into appropriate data structure -+ for (auto const & operation : manifest) { -+ -+ OperationDescription const &desc = operation->description(); -+ -+ // insert all gemm operation into operation table -+ if (desc.kind == OperationKind::kGemm) { -+ GemmDescription const &gemm_desc = static_cast(desc); -+ -+ -+ GemmFunctionalKey functional_key( -+ gemm_desc.provider, -+ gemm_desc.gemm_kind, -+ gemm_desc.tile_description.math_instruction.element_accumulator, -+ gemm_desc.element_epilogue, -+ gemm_desc.A.element, -+ gemm_desc.A.layout, -+ gemm_desc.transform_A, -+ gemm_desc.B.element, -+ gemm_desc.B.layout, -+ gemm_desc.transform_B, -+ gemm_desc.C.element -+ ); -+ -+ Operation const *op = operation.get(); -+ -+ int cc = gemm_desc.tile_description.minimum_compute_capability; -+ -+ int alignment = std::max(std::max( -+ gemm_desc.A.alignment, gemm_desc.B.alignment), gemm_desc.C.alignment); -+ -+ GemmPreferenceKey preference_key(cc, alignment); -+ -+ gemm_operations[functional_key][preference_key].push_back(op); -+ } -+ -+ // insert all conv2d or conv3d operation into operation table -+ if (desc.kind == OperationKind::kConv2d || desc.kind == OperationKind::kConv3d) { -+ auto &conv_desc = static_cast(desc); -+ -+ ConvFunctionalKey functional_key( -+ conv_desc.provider, -+ conv_desc.conv_kind, -+ conv_desc.A.element, -+ conv_desc.A.layout, -+ conv_desc.B.element, -+ conv_desc.B.layout, -+ conv_desc.C.element, -+ conv_desc.C.layout, -+ conv_desc.tile_description.math_instruction.element_accumulator, -+ conv_desc.element_epilogue -+ ); -+ -+ Operation const *op = operation.get(); -+ -+ int cc = conv_desc.tile_description.minimum_compute_capability; -+ -+ ConvPreferenceKey preference_key(cc, conv_desc.iterator_algorithm); -+ -+ // insert conv operation to conv2d_operations or conv3d_operations map -+ (desc.kind == OperationKind::kConv2d) ? -+ conv2d_operations[functional_key][preference_key].push_back(op) : -+ conv3d_operations[functional_key][preference_key].push_back(op); -+ } -+ -+ // insert all reduction operation into operation table -+ if (desc.kind == OperationKind::kReduction) { -+ auto &reduce_desc = static_cast(desc); -+ -+ ReductionFunctionalKey functional_key( -+ reduce_desc.provider, -+ reduce_desc.element_workspace, -+ reduce_desc.tile_description.math_instruction.element_accumulator, -+ reduce_desc.element_output, -+ reduce_desc.element_epilogue, -+ library::MathOperationID::kAdd, -+ library::EpilogueKind::kLinearCombination -+ ); -+ -+ Operation const *op = operation.get(); -+ -+ reduction_operations[functional_key] = op; -+ -+ } -+ -+ } -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/library/src/rank_2k_operation.h b/3rdparty/cutlass/tools/library/src/rank_2k_operation.h -new file mode 100644 -index 0000000..d6e0dca ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/rank_2k_operation.h -@@ -0,0 +1,373 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines operations for all Rank 2K operation kinds (Syr2k, Her2k) -+ in CUTLASS Library. -+ -+ -+*/ -+ -+#pragma once -+#include -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/device/rank_2k.h" -+#include "cutlass/gemm/kernel/default_rank_2k_universal.h" -+ -+#include "cutlass/library/library.h" -+#include "library_internal.h" -+#include "cutlass/core_io.h" -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class Rank2KOperationBase : public Operation { -+public: -+ using Operator = Operator_; -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ static BlasMode const kBlasMode = Operator::kBlasMode; -+ static int const kUpdateRank = Operator::kUpdateRank; -+ static FillMode const kFillModeC = Operator::kFillModeC; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+protected: -+ -+ /// -+ RankKDescription description_; -+ -+public: -+ -+ /// Constructor -+ Rank2KOperationBase(char const *name = "unknown_rank_k") { -+ -+ description_.name = name; -+ description_.provider = Provider::kCUTLASS; -+ description_.rank_k_kind = RankKKind::kUniversal; -+ description_.fill_mode = kFillModeC; -+ description_.blas_mode = kBlasMode; -+ description_.num_ranks = kUpdateRank; -+ -+ description_.kind = OperationKind::kRank2K; -+ -+ description_.tile_description.threadblock_shape = make_Coord( -+ Operator::ThreadblockShape::kM, -+ Operator::ThreadblockShape::kN, -+ Operator::ThreadblockShape::kK); -+ -+ description_.tile_description.threadblock_stages = Operator::kStages; -+ -+ description_.tile_description.warp_count = make_Coord( -+ Operator::Rank2Kkernel::WarpCount::kM, -+ Operator::Rank2Kkernel::WarpCount::kN, -+ Operator::Rank2Kkernel::WarpCount::kK); -+ -+ description_.tile_description.math_instruction.instruction_shape = make_Coord( -+ Operator::InstructionShape::kM, -+ Operator::InstructionShape::kN, -+ Operator::InstructionShape::kK); -+ -+ description_.tile_description.math_instruction.element_accumulator = -+ NumericTypeMap::kId; -+ -+ description_.tile_description.math_instruction.opcode_class = -+ OpcodeClassMap::kId; -+ -+ description_.tile_description.math_instruction.math_operation = -+ MathOperationMap::kId; -+ -+ description_.tile_description.minimum_compute_capability = -+ ArchMap::kMin; -+ -+ description_.tile_description.maximum_compute_capability = -+ ArchMap::kMax; -+ -+ description_.A = make_TensorDescription(Operator::kAlignmentA); -+ description_.B = make_TensorDescription(Operator::kAlignmentB); -+ description_.C = make_TensorDescription(Operator::kAlignmentC); -+ description_.element_epilogue = NumericTypeMap::kId; -+ -+ description_.split_k_mode = SplitKMode::kNone; -+ description_.transform_A = ComplexTransformMap::kId; -+ description_.transform_B = ComplexTransformMap::kId; -+ } -+ -+ /// Returns the description of the SYRK operation -+ virtual OperationDescription const & description() const { -+ return description_; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class Rank2KOperation : public Rank2KOperationBase { -+public: -+ -+ using Operator = Operator_; -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ -+ static BlasMode const kBlasMode = Operator::kBlasMode; -+ static int const kUpdateRank = Operator::kUpdateRank; -+ static FillMode const kFillModeC = Operator::kFillModeC; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+public: -+ -+ /// Constructor -+ Rank2KOperation(char const *name = "unknown_rank_2k"): -+ Rank2KOperationBase(name) { -+ -+ this->description_.rank_k_kind = RankKKind::kUniversal; -+ } -+ -+protected: -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status construct_arguments_( -+ OperatorArguments &operator_args, -+ RankKConfiguration const *configuration) { -+ -+ //operator_args.mode = configuration->mode; -+ -+ operator_args.problem_size = configuration->problem_size; -+ operator_args.batch_count = configuration->batch_count; -+ -+ operator_args.lda = int(configuration->lda); -+ operator_args.ldb = int(configuration->ldb); -+ operator_args.ldc = int(configuration->ldc); -+ operator_args.ldd = int(configuration->ldd); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status update_arguments_( -+ OperatorArguments &operator_args, -+ RankKArguments const *arguments) { -+ -+ if (arguments->pointer_mode == ScalarPointerMode::kHost) { -+ typename Operator::EpilogueOutputOp::Params params( -+ *static_cast(arguments->alpha), -+ *static_cast(arguments->beta) -+ ); -+ operator_args.epilogue = params; -+ } -+ else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ -+ typename Operator::EpilogueOutputOp::Params params( -+ static_cast(arguments->alpha), -+ static_cast(arguments->beta) -+ ); -+ operator_args.epilogue = params; -+ } -+ else { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ // update arguments -+ operator_args.ptr_A = arguments->A; -+ operator_args.ptr_B = arguments->B; -+ operator_args.ptr_C = arguments->C; -+ operator_args.ptr_D = arguments->D; -+ -+ operator_args.batch_stride_A = arguments->batch_stride_A; -+ operator_args.batch_stride_B = arguments->batch_stride_B; -+ operator_args.batch_stride_C = arguments->batch_stride_C; -+ operator_args.batch_stride_D = arguments->batch_stride_D; -+ -+ return Status::kSuccess; -+ } -+ -+public: -+ -+ /// Returns success if the operation can proceed -+ virtual Status can_implement( -+ void const *configuration_ptr, -+ void const *arguments_ptr) const { -+ -+ RankKConfiguration const *configuration = -+ static_cast(configuration_ptr); -+ -+ RankKArguments const *arguments = -+ static_cast(arguments_ptr); -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_(args, configuration); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = update_arguments_(args, arguments); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Operator::can_implement(args); -+ } -+ -+ /// Gets the host-side workspace -+ virtual uint64_t get_host_workspace_size( -+ void const *configuration) const { -+ -+ return sizeof(Operator); -+ } -+ -+ /// Gets the device-side workspace -+ virtual uint64_t get_device_workspace_size( -+ void const *configuration_ptr, -+ void const *arguments_ptr = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return 0; -+ } -+ -+ uint64_t size = Operator::get_workspace_size(args); -+ -+ return size; -+ } -+ -+ /// Initializes the workspace -+ virtual Status initialize( -+ void const *configuration_ptr, -+ void *host_workspace, -+ void *device_workspace, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = new (host_workspace) Operator; -+ -+ //std::cout << "initialize() library::Rank2KOperation" << std::endl; -+ //print_operator_args(args); -+ status = op->initialize(args, device_workspace, stream); -+ -+ return status; -+ } -+ -+ /// Runs the kernel -+ virtual Status run( -+ void const *arguments_ptr, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = update_arguments_( -+ args, -+ static_cast(arguments_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = static_cast(host_workspace); -+ -+ status = op->update(args, device_workspace); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ //std::cout << "run() library::Rank2KOperation" << std::endl; -+ //print_operator_args(args); -+ status = op->run(stream); -+ -+ return status; -+ } -+ -+ /// Call print_operator_args from the Conv2dOperation::initialize() -+ // to dump arguments passed on to cutlass operator for debugging -+ void print_operator_args(OperatorArguments &operator_args) const { -+ std::cout << "Rank2KOperation::OperatorArguments" << std::endl -+ << " problem_size:" << std::endl -+ << operator_args.problem_size << std::endl -+ << " epilouge (alpha, beta): " -+ << operator_args.epilogue.alpha << ", " -+ << operator_args.epilogue.beta << std::endl -+ << " ref_A (ptr, {stride}): " -+ << operator_args.ptr_A << ", {" -+ << operator_args.lda << "}" << std::endl -+ << " ref_B (ptr, {stride}): " -+ << operator_args.ptr_B << ", {" -+ << operator_args.ldb << "}" << std::endl -+ << " ref_C (ptr, {stride}): " -+ << operator_args.ptr_C << ", {" -+ << operator_args.ldc << "}" << std::endl -+ << " ref_D (ptr, {stride}): " -+ << operator_args.ptr_D << ", {" -+ << operator_args.ldd << "}" << std::endl; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/src/rank_k_operation.h b/3rdparty/cutlass/tools/library/src/rank_k_operation.h -new file mode 100644 -index 0000000..2eb7a2d ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/rank_k_operation.h -@@ -0,0 +1,344 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines operations for all Rank K operation kinds (Syrk, Herk) -+ in CUTLASS Library. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/device/rank_k.h" -+#include "cutlass/gemm/kernel/default_rank_k_universal.h" -+ -+#include "cutlass/library/library.h" -+#include "library_internal.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class RankKOperationBase : public Operation { -+public: -+ using Operator = Operator_; -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementA; -+ using LayoutB = typename Operator::LayoutA; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ static BlasMode const kBlasMode = Operator::kBlasMode; -+ static int const kUpdateRank = Operator::kUpdateRank; -+ static FillMode const kFillModeC = Operator::kFillModeC; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+protected: -+ -+ /// -+ RankKDescription description_; -+ -+public: -+ -+ /// Constructor -+ RankKOperationBase(char const *name = "unknown_rank_k") { -+ -+ description_.name = name; -+ description_.provider = Provider::kCUTLASS; -+ description_.rank_k_kind = RankKKind::kUniversal; -+ description_.fill_mode = kFillModeC; -+ description_.blas_mode = kBlasMode; -+ description_.num_ranks = kUpdateRank; -+ -+ description_.kind = OperationKind::kRankK; -+ -+ description_.tile_description.threadblock_shape = make_Coord( -+ Operator::ThreadblockShape::kM, -+ Operator::ThreadblockShape::kN, -+ Operator::ThreadblockShape::kK); -+ -+ description_.tile_description.threadblock_stages = Operator::kStages; -+ -+ description_.tile_description.warp_count = make_Coord( -+ Operator::RankKkernel::WarpCount::kM, -+ Operator::RankKkernel::WarpCount::kN, -+ Operator::RankKkernel::WarpCount::kK); -+ -+ description_.tile_description.math_instruction.instruction_shape = make_Coord( -+ Operator::InstructionShape::kM, -+ Operator::InstructionShape::kN, -+ Operator::InstructionShape::kK); -+ -+ description_.tile_description.math_instruction.element_accumulator = -+ NumericTypeMap::kId; -+ -+ description_.tile_description.math_instruction.opcode_class = -+ OpcodeClassMap::kId; -+ -+ description_.tile_description.math_instruction.math_operation = -+ MathOperationMap::kId; -+ -+ description_.tile_description.minimum_compute_capability = -+ ArchMap::kMin; -+ -+ description_.tile_description.maximum_compute_capability = -+ ArchMap::kMax; -+ -+ description_.A = make_TensorDescription(Operator::kAlignmentA); -+ description_.B = make_TensorDescription(Operator::kAlignmentA); -+ description_.C = make_TensorDescription(Operator::kAlignmentC); -+ description_.element_epilogue = NumericTypeMap::kId; -+ -+ description_.split_k_mode = SplitKMode::kNone; -+ description_.transform_A = ComplexTransformMap::kId; -+ description_.transform_B = ComplexTransformMap::kId; -+ } -+ -+ /// Returns the description of the SYRK operation -+ virtual OperationDescription const & description() const { -+ return description_; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class RankKOperation : public RankKOperationBase { -+public: -+ -+ using Operator = Operator_; -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementA; -+ using LayoutB = typename Operator::LayoutA; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ -+ static BlasMode const kBlasMode = Operator::kBlasMode; -+ static int const kUpdateRank = Operator::kUpdateRank; -+ static FillMode const kFillModeC = Operator::kFillModeC; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+public: -+ -+ /// Constructor -+ RankKOperation(char const *name = "unknown_rank_k"): -+ RankKOperationBase(name) { -+ -+ this->description_.rank_k_kind = RankKKind::kUniversal; -+ } -+ -+protected: -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status construct_arguments_( -+ OperatorArguments &operator_args, -+ RankKConfiguration const *configuration) { -+ -+ //operator_args.mode = configuration->mode; -+ -+ operator_args.problem_size = configuration->problem_size; -+ operator_args.batch_count = configuration->batch_count; -+ -+ operator_args.lda = int(configuration->lda); -+ operator_args.ldb = int(configuration->lda); -+ operator_args.ldc = int(configuration->ldc); -+ operator_args.ldd = int(configuration->ldd); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status update_arguments_( -+ OperatorArguments &operator_args, -+ RankKArguments const *arguments) { -+ -+ if (arguments->pointer_mode == ScalarPointerMode::kHost) { -+ typename Operator::EpilogueOutputOp::Params params( -+ *static_cast(arguments->alpha), -+ *static_cast(arguments->beta) -+ ); -+ operator_args.epilogue = params; -+ } -+ else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ -+ typename Operator::EpilogueOutputOp::Params params( -+ static_cast(arguments->alpha), -+ static_cast(arguments->beta) -+ ); -+ operator_args.epilogue = params; -+ } -+ else { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ // update arguments -+ operator_args.ptr_A = arguments->A; -+ operator_args.ptr_C = arguments->C; -+ operator_args.ptr_D = arguments->D; -+ -+ operator_args.batch_stride_A = arguments->batch_stride_A; -+ operator_args.batch_stride_C = arguments->batch_stride_C; -+ operator_args.batch_stride_D = arguments->batch_stride_D; -+ -+ return Status::kSuccess; -+ } -+ -+public: -+ -+ /// Returns success if the operation can proceed -+ virtual Status can_implement( -+ void const *configuration_ptr, -+ void const *arguments_ptr) const { -+ -+ RankKConfiguration const *configuration = -+ static_cast(configuration_ptr); -+ -+ RankKArguments const *arguments = -+ static_cast(arguments_ptr); -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_(args, configuration); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = update_arguments_(args, arguments); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Operator::can_implement(args); -+ } -+ -+ /// Gets the host-side workspace -+ virtual uint64_t get_host_workspace_size( -+ void const *configuration) const { -+ -+ return sizeof(Operator); -+ } -+ -+ /// Gets the device-side workspace -+ virtual uint64_t get_device_workspace_size( -+ void const *configuration_ptr, -+ void const *arguments_ptr = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return 0; -+ } -+ -+ uint64_t size = Operator::get_workspace_size(args); -+ -+ return size; -+ } -+ -+ /// Initializes the workspace -+ virtual Status initialize( -+ void const *configuration_ptr, -+ void *host_workspace, -+ void *device_workspace, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = new (host_workspace) Operator; -+ -+ status = op->initialize(args, device_workspace, stream); -+ -+ return status; -+ } -+ -+ /// Runs the kernel -+ virtual Status run( -+ void const *arguments_ptr, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = update_arguments_( -+ args, -+ static_cast(arguments_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = static_cast(host_workspace); -+ -+ status = op->update(args, device_workspace); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = op->run(stream); -+ -+ return status; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/src/reduction/init_reduction_operations.cu b/3rdparty/cutlass/tools/library/src/reduction/init_reduction_operations.cu -new file mode 100644 -index 0000000..b0f1695 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/reduction/init_reduction_operations.cu -@@ -0,0 +1,65 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Initialize operations for reduction operation in CUTLASS Library. -+ -+*/ -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/library/library.h" -+#include "cutlass/library/manifest.h" -+ -+namespace cutlass { -+namespace library { -+/////////////////////////////////////////////////////////////////////////////////////////////// -+// CUTLASS Reduction Instances // -+/////////////////////////////////////////////////////////////////////////////////////////////// -+void initialize_reduce_add_linear_combination_f32_f32_f16(Manifest &manifest); -+void initialize_reduce_add_linear_combination_f32_f32_f32(Manifest &manifest); -+void initialize_reduce_add_linear_combination_f64_f64_f64(Manifest &manifest); -+void initialize_reduce_add_linear_combination_cf32_cf32_cf32(Manifest &manifest); -+ -+// -+// Entry point to construct operations -+// -+void initialize_all_reduction_op(Manifest &manifest) { -+ -+ initialize_reduce_add_linear_combination_f32_f32_f16(manifest); -+ initialize_reduce_add_linear_combination_f32_f32_f32(manifest); -+ initialize_reduce_add_linear_combination_f64_f64_f64(manifest); -+ initialize_reduce_add_linear_combination_cf32_cf32_cf32(manifest); -+ -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/library/src/reduction/reduction_device.cu b/3rdparty/cutlass/tools/library/src/reduction/reduction_device.cu -new file mode 100644 -index 0000000..2eb6ab7 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/reduction/reduction_device.cu -@@ -0,0 +1,184 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines operations for reduction operation in CUTLASS Library. -+*/ -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/library/library.h" -+#include "cutlass/library/manifest.h" -+ -+#include "reduction_operation.h" -+ -+namespace cutlass { -+namespace library { -+ -+// naming convention initialize_reduce_[ReductionOp]_[EpilogueOp]_[ElementWorkspace]_[ElementAccumulator]_[ElementOutput] -+ -+void initialize_reduce_add_linear_combination_f32_f32_f16(Manifest &manifest) { -+ -+ using ElementWorkspace = float; -+ using ElementAccumulator = float; -+ using ElementOutput = cutlass::half_t; -+ using ElementCompute = float; -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ using ReductionOp = cutlass::reduction::thread::ReduceAdd< -+ ElementAccumulator, -+ typename EpilogueOutputOp::ElementAccumulator, -+ EpilogueOutputOp::kCount -+ >; -+ -+ using Operation_reduce_add_linear_combination_f32_f32_f16 = cutlass::reduction::device::ReduceSplitK< -+ cutlass::reduction::kernel::ReduceSplitK< -+ cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, -+ EpilogueOutputOp, -+ ReductionOp -+ > -+ >; -+ -+ manifest.append(new ReductionOperation< -+ Operation_reduce_add_linear_combination_f32_f32_f16>( -+ "reduce_add_linear_combination_f32_f32_f16" -+ )); -+} -+ -+ -+void initialize_reduce_add_linear_combination_f32_f32_f32(Manifest &manifest) { -+ -+ using ElementWorkspace = float; -+ using ElementAccumulator = float; -+ using ElementOutput = float; -+ using ElementCompute = float; -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ using ReductionOp = cutlass::reduction::thread::ReduceAdd< -+ ElementAccumulator, -+ typename EpilogueOutputOp::ElementAccumulator, -+ EpilogueOutputOp::kCount -+ >; -+ -+ using Operation_reduce_add_linear_combination_f32_f32_f32 = cutlass::reduction::device::ReduceSplitK< -+ cutlass::reduction::kernel::ReduceSplitK< -+ cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, -+ EpilogueOutputOp, -+ ReductionOp -+ > -+ >; -+ -+ manifest.append(new ReductionOperation< -+ Operation_reduce_add_linear_combination_f32_f32_f32>( -+ "reduce_add_linear_combination_f32_f32_f32" -+ )); -+} -+ -+void initialize_reduce_add_linear_combination_f64_f64_f64(Manifest &manifest) { -+ -+ using ElementWorkspace = double; -+ using ElementAccumulator = double; -+ using ElementOutput = double; -+ using ElementCompute = double; -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ using ReductionOp = cutlass::reduction::thread::ReduceAdd< -+ ElementAccumulator, -+ typename EpilogueOutputOp::ElementAccumulator, -+ EpilogueOutputOp::kCount -+ >; -+ -+ using Operation_reduce_add_linear_combination_f64_f64_f64 = cutlass::reduction::device::ReduceSplitK< -+ cutlass::reduction::kernel::ReduceSplitK< -+ cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, -+ EpilogueOutputOp, -+ ReductionOp -+ > -+ >; -+ -+ manifest.append(new ReductionOperation< -+ Operation_reduce_add_linear_combination_f64_f64_f64>( -+ "reduce_add_linear_combination_f64_f64_f64" -+ )); -+} -+ -+void initialize_reduce_add_linear_combination_cf32_cf32_cf32(Manifest &manifest) { -+ -+ using ElementWorkspace = cutlass::complex; -+ using ElementAccumulator = cutlass::complex; -+ using ElementOutput = cutlass::complex; -+ using ElementCompute = cutlass::complex; -+ -+ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< -+ ElementOutput, -+ 128 / cutlass::sizeof_bits::value, -+ ElementAccumulator, -+ ElementCompute -+ >; -+ -+ using ReductionOp = cutlass::reduction::thread::ReduceAdd< -+ ElementAccumulator, -+ typename EpilogueOutputOp::ElementAccumulator, -+ EpilogueOutputOp::kCount -+ >; -+ -+ using Operation_reduce_add_linear_combination_cf32_cf32_cf32 = cutlass::reduction::device::ReduceSplitK< -+ cutlass::reduction::kernel::ReduceSplitK< -+ cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, -+ EpilogueOutputOp, -+ ReductionOp -+ > -+ >; -+ -+ manifest.append(new ReductionOperation< -+ Operation_reduce_add_linear_combination_cf32_cf32_cf32>( -+ "reduce_add_linear_combination_cf32_cf32_cf32" -+ )); -+} -+ -+} -+} -diff --git a/3rdparty/cutlass/tools/library/src/reduction/reduction_operation.h b/3rdparty/cutlass/tools/library/src/reduction/reduction_operation.h -new file mode 100644 -index 0000000..846ca02 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/reduction/reduction_operation.h -@@ -0,0 +1,290 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines operations for reduction operation in CUTLASS Library. -+*/ -+ -+#pragma once -+#include -+#include "cutlass/cutlass.h" -+#include "cutlass/epilogue/thread/linear_combination.h" -+#include "cutlass/epilogue/thread/linear_combination_clamp.h" -+#include "cutlass/reduction/thread/reduction_operators.h" -+#include "cutlass/reduction/device/reduce_split_k.h" -+ -+#include "cutlass/library/library.h" -+#include "library_internal.h" -+#include "cutlass/core_io.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class ReductionOperation : public Operation { -+public: -+ using Operator = Operator_; -+ -+ using ElementWorkspace = typename Operator::ElementWorkspace; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementOutput = typename Operator::ElementOutput; -+ -+ using ElementCompute = typename Operator::OutputOp::ElementCompute; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+protected: -+ -+ /// -+ ReductionDescription description_; -+ -+public: -+ -+ /// Constructor -+ ReductionOperation(char const *name = "unknown_reduction") { -+ -+ description_.name = name; -+ description_.provider = Provider::kCUTLASS; -+ description_.kind = OperationKind::kReduction; -+ -+ description_.tile_description.threadblock_shape = make_Coord(Operator::Shape::kRow, Operator::Shape::kColumn, 1); -+ -+ description_.tile_description.math_instruction.instruction_shape = make_Coord(1, 1, 1); -+ description_.tile_description.math_instruction.element_accumulator = NumericTypeMap::kId; -+ description_.tile_description.math_instruction.opcode_class = OpcodeClassID::kSimt; -+ description_.tile_description.math_instruction.math_operation = MathOperationID::kAdd; -+ -+ description_.tile_description.minimum_compute_capability = 50; -+ description_.tile_description.maximum_compute_capability = 1024; -+ -+ description_.element_workspace = NumericTypeMap::kId; -+ description_.element_output = NumericTypeMap::kId; -+ description_.element_epilogue = NumericTypeMap::kId; -+ -+ } -+ -+ /// Returns the description of the Reduction operation -+ virtual OperationDescription const & description() const { -+ return description_; -+ } -+ -+ -+protected: -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status construct_arguments_( -+ OperatorArguments &operator_args, -+ ReductionConfiguration const *configuration) { -+ -+ operator_args.problem_size = configuration->problem_size; -+ operator_args.partitions = configuration->partitions; -+ operator_args.partition_stride = configuration->partition_stride; -+ -+ operator_args.workspace = {nullptr, int(configuration->ldw)}; -+ operator_args.source = {nullptr, int(configuration->lds)}; -+ operator_args.destination = {nullptr, int(configuration->ldd)}; -+ -+ return Status::kSuccess; -+ } -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status update_arguments_( -+ OperatorArguments &operator_args, -+ ReductionArguments const *arguments) { -+ -+ if (arguments->pointer_mode == ScalarPointerMode::kHost) { -+ typename Operator::OutputOp::Params params( -+ *static_cast(arguments->alpha), -+ *static_cast(arguments->beta) -+ ); -+ operator_args.output = params; -+ } -+ else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ -+ typename Operator::OutputOp::Params params( -+ static_cast(arguments->alpha), -+ static_cast(arguments->beta) -+ ); -+ operator_args.output = params; -+ } -+ else { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ operator_args.workspace.reset(static_cast(const_cast(arguments->workspace))); -+ operator_args.source.reset(static_cast(const_cast(arguments->source))); -+ operator_args.destination.reset(static_cast(const_cast(arguments->destination))); -+ -+ return Status::kSuccess; -+ } -+ -+public: -+ -+ /// Returns success if the operation can proceed -+ virtual Status can_implement( -+ void const *configuration_ptr, -+ void const *arguments_ptr) const { -+ -+ ReductionConfiguration const *configuration = -+ static_cast(configuration_ptr); -+ -+ ReductionArguments const *arguments = -+ static_cast(arguments_ptr); -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_(args, configuration); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = update_arguments_(args, arguments); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Operator::can_implement(args); -+ } -+ -+ /// Gets the host-side workspace -+ virtual uint64_t get_host_workspace_size( -+ void const *configuration) const { -+ -+ return sizeof(Operator); -+ } -+ -+ /// Gets the device-side workspace -+ virtual uint64_t get_device_workspace_size( -+ void const *configuration_ptr, -+ void const *arguments_ptr = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return 0; -+ } -+ -+ return Operator::get_workspace_size(args); -+ } -+ -+ /// Initializes the workspace -+ virtual Status initialize( -+ void const *configuration_ptr, -+ void *host_workspace, -+ void *device_workspace, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = new (host_workspace) Operator; -+ //std::cout << "initialize library::Reduction" << std::endl; -+ //print_operator_args(args); -+ return op->initialize(args, device_workspace, stream); -+ } -+ -+ /// Runs the kernel -+ virtual Status run( -+ void const *arguments_ptr, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = update_arguments_( -+ args, -+ static_cast(arguments_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = static_cast(host_workspace); -+ -+ status = op->update(args, device_workspace); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ //std::cout << "run library::Reduction" << std::endl; -+ //print_operator_args(args); -+ return op->run(stream); -+ } -+ -+ /// Call print_operator_args from the Reduction::initialize() -+ // to dump arguments passed on to cutlass operator for debugging -+ void print_operator_args(OperatorArguments &operator_args) const { -+ std::cout << "Reduction::OperatorArguments" << std::endl -+ << " problem_size: " -+ << operator_args.problem_size << std::endl -+ << " partitions: " -+ << operator_args.partitions << std::endl -+ << " partition_stride: " -+ << operator_args.partition_stride << std::endl -+ << " epilouge (alpha, beta): " -+ << operator_args.output.alpha << ", " -+ << operator_args.output.beta << std::endl -+ << " workspace (ptr, stride): " -+ << operator_args.workspace.data() << ", " -+ << operator_args.workspace.stride(0) << std::endl -+ << " source (ptr, stride): " -+ << operator_args.source.data() << ", " -+ << operator_args.source.stride(0) << std::endl -+ << " destination (ptr, stride): " -+ << operator_args.destination.data() << ", " -+ << operator_args.destination.stride(0) << std::endl; -+ } -+}; -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/src/reference/conv2d.cu b/3rdparty/cutlass/tools/library/src/reference/conv2d.cu -new file mode 100644 -index 0000000..715e3b0 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/reference/conv2d.cu -@@ -0,0 +1,229 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief -+ -+*/ -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/library/library.h" -+#include "cutlass/library/manifest.h" -+ -+#include "conv_reference_operation.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+void initialize_conv2d_reference_operations(Manifest &manifest) { -+ -+ make_conv_all< -+ 2, -+ cutlass::half_t, cutlass::layout::TensorNHWC, -+ cutlass::half_t, cutlass::layout::TensorNHWC, -+ cutlass::half_t, cutlass::layout::TensorNHWC, -+ cutlass::half_t, -+ cutlass::half_t -+ >(manifest); -+ -+ make_conv_all< -+ 2, -+ cutlass::half_t, cutlass::layout::TensorNHWC, -+ cutlass::half_t, cutlass::layout::TensorNHWC, -+ cutlass::half_t, cutlass::layout::TensorNHWC, -+ float, -+ float -+ >(manifest); -+ -+ make_conv_all< -+ 2, -+ cutlass::half_t, cutlass::layout::TensorNHWC, -+ cutlass::half_t, cutlass::layout::TensorNHWC, -+ float, cutlass::layout::TensorNHWC, -+ float, -+ float -+ >(manifest); -+ -+ make_conv_all< -+ 2, -+ cutlass::bfloat16_t, cutlass::layout::TensorNHWC, -+ cutlass::bfloat16_t, cutlass::layout::TensorNHWC, -+ cutlass::bfloat16_t, cutlass::layout::TensorNHWC, -+ float, -+ float -+ >(manifest); -+ -+ make_conv_all< -+ 2, -+ cutlass::bfloat16_t, cutlass::layout::TensorNHWC, -+ cutlass::bfloat16_t, cutlass::layout::TensorNHWC, -+ float, cutlass::layout::TensorNHWC, -+ float, -+ float -+ >(manifest); -+ -+ make_conv_all< -+ 2, -+ cutlass::tfloat32_t, cutlass::layout::TensorNHWC, -+ cutlass::tfloat32_t, cutlass::layout::TensorNHWC, -+ cutlass::tfloat32_t, cutlass::layout::TensorNHWC, -+ float, -+ float -+ >(manifest); -+ -+ make_conv_all< -+ 2, -+ cutlass::tfloat32_t, cutlass::layout::TensorNHWC, -+ cutlass::tfloat32_t, cutlass::layout::TensorNHWC, -+ float, cutlass::layout::TensorNHWC, -+ float, -+ float -+ >(manifest); -+ -+ make_conv_all< -+ 2, -+ float, cutlass::layout::TensorNHWC, -+ float, cutlass::layout::TensorNHWC, -+ float, cutlass::layout::TensorNHWC, -+ float, -+ float -+ >(manifest); -+ -+ make_conv_all< -+ 2, -+ cutlass::complex, cutlass::layout::TensorNHWC, -+ cutlass::complex, cutlass::layout::TensorNHWC, -+ cutlass::complex, cutlass::layout::TensorNHWC, -+ cutlass::complex, -+ cutlass::complex -+ >(manifest); -+ -+ make_conv_fprop< -+ 2, -+ int8_t, cutlass::layout::TensorNHWC, -+ int8_t, cutlass::layout::TensorNHWC, -+ int32_t, cutlass::layout::TensorNHWC, -+ int32_t, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_conv_fprop< -+ 2, -+ int8_t, cutlass::layout::TensorNHWC, -+ int8_t, cutlass::layout::TensorNHWC, -+ int8_t, cutlass::layout::TensorNHWC, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_conv_fprop< -+ 2, -+ uint8_t, cutlass::layout::TensorNHWC, -+ uint8_t, cutlass::layout::TensorNHWC, -+ uint8_t, cutlass::layout::TensorNHWC, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_conv_fprop< -+ 2, -+ uint8_t, cutlass::layout::TensorNHWC, -+ uint8_t, cutlass::layout::TensorNHWC, -+ int32_t, cutlass::layout::TensorNHWC, -+ int32_t, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_conv_fprop< -+ 2, -+ uint8_t, cutlass::layout::TensorNHWC, -+ uint8_t, cutlass::layout::TensorNHWC, -+ int8_t, cutlass::layout::TensorNHWC, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_conv_fprop< -+ 2, -+ cutlass::int4b_t, cutlass::layout::TensorNHWC, -+ cutlass::int4b_t, cutlass::layout::TensorNHWC, -+ int32_t, cutlass::layout::TensorNHWC, -+ int32_t, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_conv_fprop< -+ 2, -+ cutlass::int4b_t, cutlass::layout::TensorNHWC, -+ cutlass::int4b_t, cutlass::layout::TensorNHWC, -+ cutlass::int4b_t, cutlass::layout::TensorNHWC, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_conv_fprop< -+ 2, -+ cutlass::uint4b_t, cutlass::layout::TensorNHWC, -+ cutlass::uint4b_t, cutlass::layout::TensorNHWC, -+ int32_t, cutlass::layout::TensorNHWC, -+ int32_t, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_conv_fprop< -+ 2, -+ cutlass::uint4b_t, cutlass::layout::TensorNHWC, -+ cutlass::uint4b_t, cutlass::layout::TensorNHWC, -+ cutlass::uint4b_t, cutlass::layout::TensorNHWC, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/library/src/reference/conv3d.cu b/3rdparty/cutlass/tools/library/src/reference/conv3d.cu -new file mode 100644 -index 0000000..a0f9069 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/reference/conv3d.cu -@@ -0,0 +1,209 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief -+*/ -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/library/library.h" -+#include "cutlass/library/manifest.h" -+ -+#include "conv_reference_operation.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+void initialize_conv3d_reference_operations(Manifest &manifest) { -+ -+ make_conv_all< -+ 3, -+ cutlass::half_t, cutlass::layout::TensorNDHWC, -+ cutlass::half_t, cutlass::layout::TensorNDHWC, -+ cutlass::half_t, cutlass::layout::TensorNDHWC, -+ cutlass::half_t, -+ cutlass::half_t -+ >(manifest); -+ -+ make_conv_all< -+ 3, -+ cutlass::half_t, cutlass::layout::TensorNDHWC, -+ cutlass::half_t, cutlass::layout::TensorNDHWC, -+ cutlass::half_t, cutlass::layout::TensorNDHWC, -+ float, -+ float -+ >(manifest); -+ -+ make_conv_all< -+ 3, -+ cutlass::half_t, cutlass::layout::TensorNDHWC, -+ cutlass::half_t, cutlass::layout::TensorNDHWC, -+ float, cutlass::layout::TensorNDHWC, -+ float, -+ float -+ >(manifest); -+ -+ make_conv_all< -+ 3, -+ cutlass::bfloat16_t, cutlass::layout::TensorNDHWC, -+ cutlass::bfloat16_t, cutlass::layout::TensorNDHWC, -+ cutlass::bfloat16_t, cutlass::layout::TensorNDHWC, -+ float, -+ float -+ >(manifest); -+ -+ make_conv_all< -+ 3, -+ cutlass::bfloat16_t, cutlass::layout::TensorNDHWC, -+ cutlass::bfloat16_t, cutlass::layout::TensorNDHWC, -+ float, cutlass::layout::TensorNDHWC, -+ float, -+ float -+ >(manifest); -+ -+ make_conv_all< -+ 3, -+ cutlass::tfloat32_t, cutlass::layout::TensorNDHWC, -+ cutlass::tfloat32_t, cutlass::layout::TensorNDHWC, -+ cutlass::tfloat32_t, cutlass::layout::TensorNDHWC, -+ float, -+ float -+ >(manifest); -+ -+ make_conv_all< -+ 3, -+ cutlass::tfloat32_t, cutlass::layout::TensorNDHWC, -+ cutlass::tfloat32_t, cutlass::layout::TensorNDHWC, -+ float, cutlass::layout::TensorNDHWC, -+ float, -+ float -+ >(manifest); -+ -+ make_conv_all< -+ 3, -+ float, cutlass::layout::TensorNDHWC, -+ float, cutlass::layout::TensorNDHWC, -+ float, cutlass::layout::TensorNDHWC, -+ float, -+ float -+ >(manifest); -+ -+ make_conv_fprop< -+ 3, -+ int8_t, cutlass::layout::TensorNDHWC, -+ int8_t, cutlass::layout::TensorNDHWC, -+ int32_t, cutlass::layout::TensorNDHWC, -+ int32_t, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_conv_fprop< -+ 3, -+ int8_t, cutlass::layout::TensorNDHWC, -+ int8_t, cutlass::layout::TensorNDHWC, -+ int8_t, cutlass::layout::TensorNDHWC, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_conv_fprop< -+ 3, -+ uint8_t, cutlass::layout::TensorNDHWC, -+ uint8_t, cutlass::layout::TensorNDHWC, -+ int32_t, cutlass::layout::TensorNDHWC, -+ int32_t, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_conv_fprop< -+ 3, -+ uint8_t, cutlass::layout::TensorNDHWC, -+ uint8_t, cutlass::layout::TensorNDHWC, -+ int8_t, cutlass::layout::TensorNDHWC, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_conv_fprop< -+ 3, -+ cutlass::int4b_t, cutlass::layout::TensorNDHWC, -+ cutlass::int4b_t, cutlass::layout::TensorNDHWC, -+ int32_t, cutlass::layout::TensorNDHWC, -+ int32_t, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_conv_fprop< -+ 3, -+ cutlass::int4b_t, cutlass::layout::TensorNDHWC, -+ cutlass::int4b_t, cutlass::layout::TensorNDHWC, -+ cutlass::int4b_t, cutlass::layout::TensorNDHWC, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_conv_fprop< -+ 3, -+ cutlass::uint4b_t, cutlass::layout::TensorNDHWC, -+ cutlass::uint4b_t, cutlass::layout::TensorNDHWC, -+ int32_t, cutlass::layout::TensorNDHWC, -+ int32_t, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_conv_fprop< -+ 3, -+ cutlass::uint4b_t, cutlass::layout::TensorNDHWC, -+ cutlass::uint4b_t, cutlass::layout::TensorNDHWC, -+ cutlass::uint4b_t, cutlass::layout::TensorNDHWC, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/library/src/reference/conv_reference_operation.h b/3rdparty/cutlass/tools/library/src/reference/conv_reference_operation.h -new file mode 100644 -index 0000000..3a294a2 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/reference/conv_reference_operation.h -@@ -0,0 +1,632 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines operations for all CONV operation kinds in CUTLASS Library -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/library/library.h" -+#include "cutlass/library/manifest.h" -+#include "cutlass/library/util.h" -+#include "library_internal.h" -+ -+#include "cutlass/util/reference/host/convolution.h" -+#include "cutlass/util/reference/device/convolution.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template < -+ Provider kProvider, -+ conv::Operator ConvolutionalOperator, -+ int ConvDim, -+ typename ElementA_, -+ typename LayoutA_, -+ typename ElementB_, -+ typename LayoutB_, -+ typename ElementC_, -+ typename LayoutC_, -+ typename ElementCompute_, -+ typename ElementAccumulator_ = ElementCompute_, -+ typename ConvertOp_ = NumericConverter, -+ typename InnerProductOp_ = multiply_add -+> -+struct ConvReferenceDispatcher; -+ -+/// Dispatcher for Conv2d (partially specialied for kConvDim == 2) -+template < -+ Provider kProvider, -+ conv::Operator kConvolutionalOperator, -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator, -+ typename ConvertOp, -+ typename InnerProductOp -+> -+struct ConvReferenceDispatcher< -+ kProvider, -+ kConvolutionalOperator, -+ 2, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, -+ InnerProductOp> { -+ -+ static Status dispatch( -+ void const *configuration, -+ ElementA *ptr_A, -+ ElementB *ptr_B, -+ ElementC *ptr_C, -+ ElementC *ptr_D, -+ ElementCompute alpha, -+ ElementCompute beta, -+ cudaStream_t stream = nullptr -+ ) { -+ -+ Conv2dConfiguration const &config = -+ *static_cast(configuration); -+ -+ // TODO: make below code more general. It is fixed for NHWC now. -+ layout::TensorNHWC layout_a; -+ layout::TensorNHWC layout_b; -+ layout::TensorNHWC layout_c; -+ -+ layout_a.stride() = -+ make_Coord(int32_t(config.stride_a[0]), -+ int32_t(config.stride_a[1]), -+ int32_t(config.stride_a[2])); -+ -+ layout_b.stride() = -+ make_Coord(int32_t(config.stride_b[0]), -+ int32_t(config.stride_b[1]), -+ int32_t(config.stride_b[2])); -+ -+ layout_c.stride() = -+ make_Coord(int32_t(config.stride_c[0]), -+ int32_t(config.stride_c[1]), -+ int32_t(config.stride_c[2])); -+ -+ if (kProvider == Provider::kReferenceHost) { -+ -+ cutlass::reference::host::Conv2d< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC , -+ LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, -+ InnerProductOp -+ >( -+ kConvolutionalOperator, -+ config.problem_size, -+ {ptr_A, layout_a}, -+ {ptr_B, layout_b}, -+ {ptr_C, layout_c}, -+ {ptr_D, layout_c}, -+ alpha, -+ beta -+ ); -+ -+ return Status::kSuccess; -+ } -+ else if (kProvider == Provider::kReferenceDevice) { -+ return cutlass::reference::device::Conv2d< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, -+ InnerProductOp -+ >( -+ kConvolutionalOperator, -+ config.problem_size, -+ {ptr_A, layout_a}, -+ {ptr_B, layout_b}, -+ {ptr_C, layout_c}, -+ {ptr_D, layout_c}, -+ alpha, -+ beta, -+ stream -+ ); -+ } -+ return Status::kErrorNotSupported; -+ } -+}; -+ -+/// Dispatcher for Conv3d (partially specialized for kConvDim == 3) -+template < -+ Provider kProvider, -+ conv::Operator kConvolutionalOperator, -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator, -+ typename ConvertOp, -+ typename InnerProductOp -+> -+struct ConvReferenceDispatcher< -+ kProvider, -+ kConvolutionalOperator, -+ 3, -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, -+ InnerProductOp> { -+ -+ static Status dispatch( -+ void const *configuration, -+ ElementA *ptr_A, -+ ElementB *ptr_B, -+ ElementC *ptr_C, -+ ElementC *ptr_D, -+ ElementCompute alpha, -+ ElementCompute beta, -+ cudaStream_t stream = nullptr -+ ) { -+ -+ Conv3dConfiguration const &config = -+ *static_cast(configuration); -+ -+ ConvKind const conv_kind = ConvKindMap::kId; -+ -+ if (kProvider == Provider::kReferenceHost) { -+ cutlass::reference::host::Conv3d< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC , -+ LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, -+ InnerProductOp -+ >( -+ kConvolutionalOperator, -+ config.problem_size, -+ {ptr_A, config.layout_a(conv_kind)}, -+ {ptr_B, config.layout_b(conv_kind)}, -+ {ptr_C, config.layout_c(conv_kind)}, -+ {ptr_D, config.layout_c(conv_kind)}, -+ alpha, -+ beta -+ ); -+ -+ return Status::kSuccess; -+ } -+ else if (kProvider == Provider::kReferenceDevice) { -+ return cutlass::reference::device::Conv3d< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, -+ InnerProductOp -+ >( -+ kConvolutionalOperator, -+ config.problem_size, -+ {ptr_A, config.layout_a(conv_kind)}, -+ {ptr_B, config.layout_b(conv_kind)}, -+ {ptr_C, config.layout_c(conv_kind)}, -+ {ptr_D, config.layout_c(conv_kind)}, -+ alpha, -+ beta, -+ stream -+ ); -+ } -+ return Status::kErrorNotSupported; -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ Provider Provider_, -+ conv::Operator ConvolutionalOperator, -+ int ConvDim, -+ typename ElementA_, -+ typename LayoutA_, -+ typename ElementB_, -+ typename LayoutB_, -+ typename ElementC_, -+ typename LayoutC_, -+ typename ElementCompute_, -+ typename ElementAccumulator_ = ElementCompute_, -+ typename ConvertOp_ = NumericConverter, -+ typename InnerProductOp_ = multiply_add -+> -+class ConvReferenceOperation : public Operation { -+public: -+ static Provider const kProvider = Provider_; -+ static conv::Operator const kConvolutionalOperator = ConvolutionalOperator; -+ static int const kConvDim = ConvDim; -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using ElementCompute = ElementCompute_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ConvertOp = ConvertOp_; -+ using InnerProductOp = InnerProductOp_; -+ -+protected: -+ -+ /// Storage for the name string -+ std::string name_; -+ -+ /// -+ ConvDescription description_; -+ -+public: -+ -+ /// Constructor -+ ConvReferenceOperation() { -+ -+ // Basic information -+ description_.provider = kProvider; -+ description_.kind = (kConvDim == 2 ? OperationKind::kConv2d : OperationKind::kConv3d); -+ description_.conv_kind = ConvKindMap::kId; -+ description_.conv_dim = kConvDim; -+ -+ // Tensor description -+ description_.A = make_TensorDescription(); -+ description_.B = make_TensorDescription(); -+ description_.C = make_TensorDescription(); -+ -+ // Epilogue compute and accumulator type description -+ description_.element_epilogue = NumericTypeMap::kId; -+ -+ description_.tile_description.math_instruction.element_accumulator = -+ NumericTypeMap::kId; -+ -+ // Iterator algorithm for convolution reference -+ description_.iterator_algorithm = IteratorAlgorithmID::kNone; -+ -+ // Compute capability for convolution reference -+ description_.tile_description.minimum_compute_capability = -+ (kProvider == Provider::kReferenceDevice ? 50 : 0); -+ -+ description_.tile_description.maximum_compute_capability = 1024; -+ -+ // Procedural name -+ std::stringstream ss; -+ -+ ss << "conv" << kConvDim << "d_" << to_string(description_.conv_kind) -+ << "_reference_" << to_string(description_.provider) -+ << "_" << to_string(description_.A.element) << to_string(description_.A.layout) -+ << "_" << to_string(description_.B.element) << to_string(description_.B.layout) -+ << "_" << to_string(description_.C.element) << to_string(description_.C.layout) -+ << "_" << to_string(description_.tile_description.math_instruction.element_accumulator); -+ -+ name_ = ss.str(); -+ -+ description_.name = name_.c_str(); -+ -+ // Epilogue compute and accumulator type description -+ description_.element_epilogue = NumericTypeMap::kId; -+ -+ description_.tile_description.math_instruction.element_accumulator = -+ NumericTypeMap::kId; -+ } -+ -+ /// Returns the description of the GEMM operation -+ virtual OperationDescription const & description() const { -+ return description_; -+ } -+ -+ virtual Status can_implement( -+ void const *configuration, -+ void const *arguments) const { -+ -+ return Status::kSuccess; -+ } -+ -+ virtual uint64_t get_host_workspace_size( -+ void const *configuration) const { -+ -+ switch (kConvDim) { -+ case 2: -+ return sizeof(Conv2dConfiguration); -+ case 3: -+ return sizeof(Conv3dConfiguration); -+ default: -+ break; -+ } -+ -+ return 0; -+ } -+ -+ virtual uint64_t get_device_workspace_size( -+ void const *configuration, -+ void const *arguments = nullptr) const { -+ -+ return 0; -+ } -+ -+ virtual Status initialize( -+ void const *configuration, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const { -+ -+ std::memcpy(host_workspace, configuration, get_host_workspace_size(configuration)); -+ -+ return Status::kSuccess; -+ } -+ -+ virtual Status run( -+ void const *arguments, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const { -+ -+ ConvArguments const &args = *static_cast(arguments); -+ -+ ElementCompute alpha; -+ ElementCompute beta; -+ -+ alpha = *static_cast(args.alpha); -+ beta = *static_cast(args.beta); -+ -+ // TODO - respect pointer mode -+ -+ // Invoke 2D or 3D convolution -+ return detail::ConvReferenceDispatcher< -+ kProvider, -+ kConvolutionalOperator, -+ kConvDim, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, -+ InnerProductOp -+ >::dispatch( -+ host_workspace, -+ static_cast(const_cast(args.A)), -+ static_cast(const_cast(args.B)), -+ static_cast(const_cast(args.C)), -+ static_cast(args.D), -+ alpha, -+ beta, -+ stream -+ ); -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Constructs Fprop reference operators. -+template < -+ int kConvDim, -+ typename ElementA_, -+ typename LayoutA_, -+ typename ElementB_, -+ typename LayoutB_, -+ typename ElementC_, -+ typename LayoutC_, -+ typename ElementCompute_, -+ typename ElementAccumulator_ = ElementCompute_, -+ typename ConvertOp_ = NumericConverter, -+ typename InnerProductOp_ = multiply_add -+> -+void make_conv_fprop(Manifest &manifest) { -+ -+ manifest.append(new ConvReferenceOperation< -+ Provider::kReferenceHost, -+ conv::Operator::kFprop, -+ kConvDim, -+ ElementA_, LayoutA_, -+ ElementB_, LayoutB_, -+ ElementC_, LayoutC_, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >); -+ -+ manifest.append(new ConvReferenceOperation< -+ Provider::kReferenceDevice, -+ conv::Operator::kFprop, -+ kConvDim, -+ ElementA_, LayoutA_, -+ ElementB_, LayoutB_, -+ ElementC_, LayoutC_, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >); -+} -+ -+/// Constructs Dgrad and Wgrad reference operators. -+template < -+ int kConvDim, -+ typename ElementA_, -+ typename LayoutA_, -+ typename ElementB_, -+ typename LayoutB_, -+ typename ElementC_, -+ typename LayoutC_, -+ typename ElementCompute_, -+ typename ElementAccumulator_ = ElementCompute_, -+ typename ConvertOp_ = NumericConverter, -+ typename InnerProductOp_ = multiply_add -+> -+void make_conv_backwards(Manifest &manifest) { -+ -+ manifest.append(new ConvReferenceOperation< -+ Provider::kReferenceHost, -+ conv::Operator::kDgrad, -+ kConvDim, -+ ElementA_, LayoutA_, -+ ElementB_, LayoutB_, -+ ElementC_, LayoutC_, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >); -+ -+ manifest.append(new ConvReferenceOperation< -+ Provider::kReferenceDevice, -+ conv::Operator::kDgrad, -+ kConvDim, -+ ElementA_, LayoutA_, -+ ElementB_, LayoutB_, -+ ElementC_, LayoutC_, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >); -+ -+ manifest.append(new ConvReferenceOperation< -+ Provider::kReferenceHost, -+ conv::Operator::kWgrad, -+ kConvDim, -+ ElementA_, LayoutA_, -+ ElementB_, LayoutB_, -+ ElementC_, LayoutC_, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >); -+ -+ manifest.append(new ConvReferenceOperation< -+ Provider::kReferenceDevice, -+ conv::Operator::kWgrad, -+ kConvDim, -+ ElementA_, LayoutA_, -+ ElementB_, LayoutB_, -+ ElementC_, LayoutC_, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >); -+} -+ -+/// Six operators for the price of one. -+template < -+ int kConvDim, -+ typename ElementA_, -+ typename LayoutA_, -+ typename ElementB_, -+ typename LayoutB_, -+ typename ElementC_, -+ typename LayoutC_, -+ typename ElementCompute_, -+ typename ElementAccumulator_ = ElementCompute_, -+ typename ConvertOp_ = NumericConverter, -+ typename InnerProductOp_ = multiply_add -+> -+void make_conv_all(Manifest &manifest) { -+ -+ make_conv_fprop< -+ kConvDim, -+ ElementA_, LayoutA_, -+ ElementB_, LayoutB_, -+ ElementC_, LayoutC_, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >(manifest); -+ -+ make_conv_backwards< -+ kConvDim, -+ ElementA_, LayoutA_, -+ ElementB_, LayoutB_, -+ ElementC_, LayoutC_, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >(manifest); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/library/src/reference/gemm.cu b/3rdparty/cutlass/tools/library/src/reference/gemm.cu -new file mode 100644 -index 0000000..890772e ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/reference/gemm.cu -@@ -0,0 +1,341 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Instantiates GEMM reference implementations. -+*/ -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/library/library.h" -+#include "cutlass/library/manifest.h" -+ -+#include "gemm_reference_operation.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+void initialize_gemm_reference_operations(Manifest &manifest) { -+ -+ make_gemm_real_canonical_layouts< -+ float, // ElementA -+ float, // ElementB -+ float, // ElementC -+ float, // ElementScalar -+ float // ElementAccumulator -+ >(manifest); -+ -+ make_gemm_real_canonical_layouts< -+ tfloat32_t, -+ tfloat32_t, -+ float, -+ float, -+ float -+ >(manifest); -+ -+ make_gemm_real_canonical_layouts< -+ tfloat32_t, -+ tfloat32_t, -+ tfloat32_t, -+ float, -+ float -+ >(manifest); -+ -+ make_gemm_real_canonical_layouts< -+ half_t, -+ half_t, -+ half_t, -+ float, -+ float -+ >(manifest); -+ -+ make_gemm_real_canonical_layouts< -+ half_t, -+ half_t, -+ half_t, -+ half_t, -+ half_t -+ >(manifest); -+ -+ make_gemm_real_canonical_layouts< -+ half_t, -+ half_t, -+ float, -+ float, -+ float -+ >(manifest); -+ -+ make_gemm_real_canonical_layouts< -+ bfloat16_t, -+ bfloat16_t, -+ bfloat16_t, -+ float, -+ float -+ >(manifest); -+ -+ make_gemm_real_canonical_layouts< -+ bfloat16_t, -+ bfloat16_t, -+ float, -+ float, -+ float -+ >(manifest); -+ -+ make_gemm_real_canonical_layouts< -+ double, -+ double, -+ double, -+ double, -+ double -+ >(manifest); -+ -+ // -+ // Integer-valued GEMMs -+ // -+ -+ make_gemm_real_canonical_layouts< -+ int8_t, -+ int8_t, -+ int32_t, -+ int32_t, -+ int32_t -+ >(manifest); -+ -+ make_gemm_real_canonical_layouts< -+ int8_t, -+ int8_t, -+ int8_t, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_gemm_real_canonical_layouts< -+ int8_t, -+ int8_t, -+ int32_t, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_gemm_real_canonical_layouts< -+ uint8_t, -+ uint8_t, -+ int32_t, -+ int32_t, -+ int32_t -+ >(manifest); -+ -+ make_gemm_real_canonical_layouts< -+ uint8_t, -+ uint8_t, -+ int8_t, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_gemm_real_canonical_layouts< -+ uint8_t, -+ uint8_t, -+ int32_t, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_gemm_interleaved_layouts< -+ 32, -+ int8_t, -+ int8_t, -+ int32_t, -+ int32_t, -+ int32_t -+ >(manifest); -+ -+ make_gemm_interleaved_layouts< -+ 32, -+ int8_t, -+ int8_t, -+ int32_t, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_gemm_interleaved_layouts< -+ 32, -+ int8_t, -+ int8_t, -+ int8_t, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_gemm_interleaved_layouts< -+ 32, -+ uint8_t, -+ uint8_t, -+ int32_t, -+ int32_t, -+ int32_t -+ >(manifest); -+ -+ make_gemm_interleaved_layouts< -+ 32, -+ uint8_t, -+ uint8_t, -+ int32_t, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_gemm_interleaved_layouts< -+ 32, -+ uint8_t, -+ uint8_t, -+ uint8_t, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_gemm_interleaved_layouts< -+ 32, -+ uint8_t, -+ uint8_t, -+ int8_t, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_gemm_interleaved_layouts< -+ 64, -+ int4b_t, -+ int4b_t, -+ int32_t, -+ int32_t, -+ int32_t -+ >(manifest); -+ -+ make_gemm_interleaved_layouts< -+ 64, -+ int4b_t, -+ int4b_t, -+ int32_t, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_gemm_interleaved_layouts< -+ 64, -+ int4b_t, -+ int4b_t, -+ int4b_t, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_gemm_interleaved_layouts< -+ 64, -+ uint4b_t, -+ uint4b_t, -+ int32_t, -+ int32_t, -+ int32_t -+ >(manifest); -+ -+ make_gemm_interleaved_layouts< -+ 64, -+ uint4b_t, -+ uint4b_t, -+ int32_t, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_gemm_interleaved_layouts< -+ 64, -+ uint4b_t, -+ uint4b_t, -+ uint4b_t, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ make_gemm_interleaved_layouts< -+ 64, -+ uint4b_t, -+ uint4b_t, -+ int4b_t, -+ float, -+ int32_t, -+ NumericConverterClamp -+ >(manifest); -+ -+ // -+ // Complex-valued GEMMs -+ // -+ -+ make_gemm_complex_canonical_layouts< -+ complex, -+ complex, -+ complex, -+ complex, -+ complex -+ >(manifest); -+ -+ make_gemm_complex_canonical_layouts< -+ complex, -+ complex, -+ complex, -+ complex, -+ complex -+ >(manifest); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/library/src/reference/gemm_reference_operation.h b/3rdparty/cutlass/tools/library/src/reference/gemm_reference_operation.h -new file mode 100644 -index 0000000..5d4d150 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/reference/gemm_reference_operation.h -@@ -0,0 +1,473 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines reference operations for GEMM operation kinds in CUTLASS Library -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/library/library.h" -+#include "cutlass/library/manifest.h" -+#include "cutlass/library/util.h" -+#include "library_internal.h" -+ -+#include "cutlass/util/reference/host/gemm_complex.h" -+#include "cutlass/util/reference/device/gemm_complex.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ Provider Provider_, -+ typename ElementA_, -+ typename LayoutA_, -+ cutlass::ComplexTransform TransformA, -+ typename ElementB_, -+ typename LayoutB_, -+ cutlass::ComplexTransform TransformB, -+ typename ElementC_, -+ typename LayoutC_, -+ typename ElementCompute_, -+ typename ElementAccumulator_ = ElementCompute_, -+ typename ConvertOp_ = NumericConverter, -+ typename InnerProductOp_ = multiply_add -+> -+class GemmReferenceOperation : public Operation { -+public: -+ static Provider const kProvider = Provider_; -+ -+ using ElementA = ElementA_; -+ using LayoutA = LayoutA_; -+ using TensorRefA = TensorRef; -+ static cutlass::ComplexTransform const kTransformA = TransformA; -+ using ElementB = ElementB_; -+ using LayoutB = LayoutB_; -+ using TensorRefB = TensorRef; -+ static cutlass::ComplexTransform const kTransformB = TransformB; -+ using ElementC = ElementC_; -+ using LayoutC = LayoutC_; -+ using TensorRefC = TensorRef; -+ using ElementCompute = ElementCompute_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ConvertOp = ConvertOp_; -+ using InnerProductOp = InnerProductOp_; -+ -+protected: -+ -+ /// Storage for the name string -+ std::string name_; -+ -+ /// -+ GemmDescription description_; -+ -+public: -+ -+ /// Constructor -+ GemmReferenceOperation() { -+ -+ // Basic information -+ description_.provider = kProvider; -+ description_.kind = OperationKind::kGemm; -+ description_.gemm_kind = GemmKind::kUniversal; -+ -+ // Tensor description -+ description_.A = make_TensorDescription(); -+ description_.transform_A = ComplexTransformMap::kId; -+ description_.B = make_TensorDescription(); -+ description_.transform_B = ComplexTransformMap::kId; -+ description_.C = make_TensorDescription(); -+ -+ // Epilogue compute and accumulator type description -+ description_.element_epilogue = NumericTypeMap::kId; -+ -+ description_.tile_description.math_instruction.element_accumulator = -+ NumericTypeMap::kId; -+ -+ // Compute capability for gemm reference -+ description_.tile_description.minimum_compute_capability = -+ (kProvider == Provider::kReferenceDevice ? 50 : 0); -+ -+ description_.tile_description.maximum_compute_capability = 1024; -+ -+ // Procedural name -+ std::stringstream ss; -+ -+ ss << "gemm" -+ << "_reference_" << to_string(description_.provider) -+ << "_" << to_string(description_.A.element) << to_string(description_.A.layout) -+ << "_" << to_string(description_.B.element) << to_string(description_.B.layout) -+ << "_" << to_string(description_.C.element) << to_string(description_.C.layout) -+ << "_" << to_string(description_.tile_description.math_instruction.element_accumulator); -+ -+ name_ = ss.str(); -+ -+ description_.name = name_.c_str(); -+ -+ // Epilogue compute and accumulator type description -+ description_.element_epilogue = NumericTypeMap::kId; -+ -+ description_.tile_description.math_instruction.element_accumulator = -+ NumericTypeMap::kId; -+ } -+ -+ /// Returns the description of the GEMM operation -+ virtual OperationDescription const & description() const { -+ return description_; -+ } -+ -+ virtual Status can_implement( -+ void const *configuration, -+ void const *arguments) const { -+ -+ return Status::kSuccess; -+ } -+ -+ virtual uint64_t get_host_workspace_size( -+ void const *configuration) const { -+ -+ return sizeof(GemmUniversalConfiguration); -+ } -+ -+ virtual uint64_t get_device_workspace_size( -+ void const *configuration, -+ void const *arguments = nullptr) const { -+ -+ return 0; -+ } -+ -+ virtual Status initialize( -+ void const *configuration, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const { -+ -+ std::memcpy(host_workspace, configuration, get_host_workspace_size(configuration)); -+ -+ return Status::kSuccess; -+ } -+ -+ virtual Status run( -+ void const *arguments, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const { -+ -+ GemmUniversalConfiguration const &config = *static_cast(host_workspace); -+ GemmUniversalArguments const &args = *static_cast(arguments); -+ -+ TensorRefA ref_A{static_cast(const_cast(args.A)), LayoutA(int(config.lda))}; -+ TensorRefB ref_B{static_cast(const_cast(args.B)), LayoutB(int(config.ldb))}; -+ TensorRefC ref_C{static_cast(const_cast(args.C)), LayoutC(int(config.ldc))}; -+ TensorRefC ref_D{static_cast(args.D), LayoutC(int(config.ldd))}; -+ -+ if (kProvider == Provider::kReferenceHost) { -+ -+ cutlass::reference::host::GemmComplex< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, -+ InnerProductOp -+ >( -+ config.problem_size, -+ *static_cast(args.alpha), -+ ref_A, -+ kTransformA, -+ ref_B, -+ kTransformB, -+ *static_cast(args.beta), -+ ref_C, -+ ref_D, -+ ElementAccumulator(), -+ ((config.mode == library::GemmUniversalMode::kBatched) ? config.batch_count : 1), -+ args.batch_stride_A, -+ args.batch_stride_B, -+ args.batch_stride_C, -+ args.batch_stride_D -+ ); -+ -+ return Status::kSuccess; -+ } -+ else if (kProvider == Provider::kReferenceDevice) { -+ -+ cutlass::reference::device::GemmComplex< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, -+ InnerProductOp -+ >( -+ config.problem_size, -+ *static_cast(args.alpha), -+ ref_A, -+ kTransformA, -+ ref_B, -+ kTransformB, -+ *static_cast(args.beta), -+ ref_C, -+ ref_D, -+ ElementAccumulator(), -+ ((config.mode == library::GemmUniversalMode::kBatched) ? config.batch_count : 1), -+ args.batch_stride_A, -+ args.batch_stride_B, -+ args.batch_stride_C, -+ args.batch_stride_D -+ ); -+ -+ return Status::kSuccess; -+ } -+ -+ return Status::kErrorNotSupported; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA_, -+ typename LayoutA_, -+ cutlass::ComplexTransform TransformA, -+ typename ElementB_, -+ typename LayoutB_, -+ cutlass::ComplexTransform TransformB, -+ typename ElementC_, -+ typename LayoutC_, -+ typename ElementCompute_, -+ typename ElementAccumulator_ = ElementCompute_, -+ typename ConvertOp_ = NumericConverter, -+ typename InnerProductOp_ = multiply_add -+> -+void make_gemm(Manifest &manifest) { -+ -+ manifest.append(new GemmReferenceOperation< -+ Provider::kReferenceHost, -+ ElementA_, LayoutA_, TransformA, -+ ElementB_, LayoutB_, TransformB, -+ ElementC_, LayoutC_, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >); -+ -+ manifest.append(new GemmReferenceOperation< -+ Provider::kReferenceDevice, -+ ElementA_, LayoutA_, TransformA, -+ ElementB_, LayoutB_, TransformB, -+ ElementC_, LayoutC_, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >); -+} -+ -+/// Helper to create NN, NT, TN, and TT GEMM layouts. -+template < -+ typename ElementA_, cutlass::ComplexTransform TransformA, -+ typename ElementB_, cutlass::ComplexTransform TransformB, -+ typename ElementC_, -+ typename ElementCompute_, -+ typename ElementAccumulator_ = ElementCompute_, -+ typename ConvertOp_ = NumericConverter, -+ typename InnerProductOp_ = multiply_add -+> -+void make_gemm_canonical_layouts(Manifest &manifest) { -+ -+ make_gemm< -+ ElementA_, cutlass::layout::ColumnMajor, TransformA, -+ ElementB_, cutlass::layout::ColumnMajor, TransformB, -+ ElementC_, cutlass::layout::ColumnMajor, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >(manifest); -+ -+ make_gemm< -+ ElementA_, cutlass::layout::ColumnMajor, TransformA, -+ ElementB_, cutlass::layout::RowMajor, TransformB, -+ ElementC_, cutlass::layout::ColumnMajor, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >(manifest); -+ -+ make_gemm< -+ ElementA_, cutlass::layout::RowMajor, TransformA, -+ ElementB_, cutlass::layout::ColumnMajor, TransformB, -+ ElementC_, cutlass::layout::ColumnMajor, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >(manifest); -+ -+ make_gemm< -+ ElementA_, cutlass::layout::RowMajor, TransformA, -+ ElementB_, cutlass::layout::RowMajor, TransformB, -+ ElementC_, cutlass::layout::ColumnMajor, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >(manifest); -+} -+ -+ -+/// Helper to create TN and interleaved layouts GEMM layouts. -+template < -+ int InterleaveK, -+ typename ElementA_, -+ typename ElementB_, -+ typename ElementC_, -+ typename ElementCompute_, -+ typename ElementAccumulator_ = ElementCompute_, -+ typename ConvertOp_ = NumericConverter, -+ typename InnerProductOp_ = multiply_add -+> -+void make_gemm_interleaved_layouts(Manifest &manifest) { -+ -+ make_gemm< -+ ElementA_, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, -+ ElementB_, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, -+ ElementC_, cutlass::layout::ColumnMajor, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >(manifest); -+ -+} -+ -+/// Helper to real-valued GEMM with canonical layouts -+template < -+ typename ElementA_, -+ typename ElementB_, -+ typename ElementC_, -+ typename ElementCompute_, -+ typename ElementAccumulator_ = ElementCompute_, -+ typename ConvertOp_ = NumericConverter, -+ typename InnerProductOp_ = multiply_add -+> -+void make_gemm_real_canonical_layouts(Manifest &manifest) { -+ make_gemm_canonical_layouts< -+ ElementA_, cutlass::ComplexTransform::kNone, -+ ElementB_, cutlass::ComplexTransform::kNone, -+ ElementC_, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >(manifest); -+} -+ -+// Helper to create all complex transformation permutations -+template < -+ typename ElementA_, -+ typename ElementB_, -+ typename ElementC_, -+ typename ElementCompute_, -+ typename ElementAccumulator_ = ElementCompute_, -+ typename ConvertOp_ = NumericConverter, -+ typename InnerProductOp_ = multiply_add -+> -+void make_gemm_complex_canonical_layouts(Manifest &manifest) { -+ -+ make_gemm_canonical_layouts< -+ ElementA_, cutlass::ComplexTransform::kNone, -+ ElementB_, cutlass::ComplexTransform::kNone, -+ ElementC_, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >(manifest); -+ -+ make_gemm_canonical_layouts< -+ ElementA_, cutlass::ComplexTransform::kConjugate, -+ ElementB_, cutlass::ComplexTransform::kConjugate, -+ ElementC_, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >(manifest); -+ -+ make_gemm_canonical_layouts< -+ ElementA_, cutlass::ComplexTransform::kNone, -+ ElementB_, cutlass::ComplexTransform::kConjugate, -+ ElementC_, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >(manifest); -+ -+ make_gemm_canonical_layouts< -+ ElementA_, cutlass::ComplexTransform::kConjugate, -+ ElementB_, cutlass::ComplexTransform::kNone, -+ ElementC_, -+ ElementCompute_, -+ ElementAccumulator_, -+ ConvertOp_, -+ InnerProductOp_ -+ >(manifest); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/library/src/reference/initialize_reference_operations.cu b/3rdparty/cutlass/tools/library/src/reference/initialize_reference_operations.cu -new file mode 100644 -index 0000000..b63367e ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/reference/initialize_reference_operations.cu -@@ -0,0 +1,63 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief -+ -+*/ -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/library/library.h" -+#include "cutlass/library/manifest.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+void initialize_gemm_reference_operations(Manifest &manifest); -+void initialize_conv2d_reference_operations(Manifest &manifest); -+void initialize_conv3d_reference_operations(Manifest &manifest); -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+void initialize_reference_operations(Manifest &manifest) { -+ initialize_conv2d_reference_operations(manifest); -+ initialize_conv3d_reference_operations(manifest); -+ initialize_gemm_reference_operations(manifest); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/library/src/singleton.cu b/3rdparty/cutlass/tools/library/src/singleton.cu -new file mode 100644 -index 0000000..2315448 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/singleton.cu -@@ -0,0 +1,62 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#include -+#include "cutlass/library/library.h" -+#include "cutlass/library/manifest.h" -+#include "cutlass/library/operation_table.h" -+#include "cutlass/library/singleton.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+Singleton::Singleton() { -+ -+ manifest.initialize(); -+ -+ operation_table.append(manifest); -+} -+ -+Singleton const & Singleton::get() { -+ static Singleton instance; -+ return instance; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/src/symm_operation.h b/3rdparty/cutlass/tools/library/src/symm_operation.h -new file mode 100644 -index 0000000..d7554ed ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/symm_operation.h -@@ -0,0 +1,379 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines operations for all Symm operation kinds (Symm, Hemm) -+ in CUTLASS Library. -+ -+ -+*/ -+ -+#pragma once -+#include -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/device/symm.h" -+#include "cutlass/gemm/kernel/default_symm_universal.h" -+ -+#include "cutlass/library/library.h" -+#include "library_internal.h" -+#include "cutlass/core_io.h" -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class SymmOperationBase : public Operation { -+public: -+ using Operator = Operator_; -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ static BlasMode const kBlasMode = Operator::kBlasMode; -+ static SideMode const kSideModeA = Operator::kSideModeA; -+ static FillMode const kFillModeA = Operator::kFillModeA; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+protected: -+ -+ /// -+ SymmDescription description_; -+ -+public: -+ -+ /// Constructor -+ SymmOperationBase(char const *name = "unknown_symm") { -+ -+ description_.name = name; -+ description_.provider = Provider::kCUTLASS; -+ description_.symm_kind = SymmKind::kUniversal; -+ description_.side_mode = kSideModeA; -+ description_.fill_mode = kFillModeA; -+ description_.blas_mode = kBlasMode; -+ -+ description_.kind = OperationKind::kSymm; -+ -+ description_.tile_description.threadblock_shape = make_Coord( -+ Operator::ThreadblockShape::kM, -+ Operator::ThreadblockShape::kN, -+ Operator::ThreadblockShape::kK); -+ -+ description_.tile_description.threadblock_stages = Operator::kStages; -+ -+ description_.tile_description.warp_count = make_Coord( -+ Operator::SymmKernel::WarpCount::kM, -+ Operator::SymmKernel::WarpCount::kN, -+ Operator::SymmKernel::WarpCount::kK); -+ -+ description_.tile_description.math_instruction.instruction_shape = make_Coord( -+ Operator::InstructionShape::kM, -+ Operator::InstructionShape::kN, -+ Operator::InstructionShape::kK); -+ -+ description_.tile_description.math_instruction.element_accumulator = -+ NumericTypeMap::kId; -+ -+ description_.tile_description.math_instruction.opcode_class = -+ OpcodeClassMap::kId; -+ -+ description_.tile_description.math_instruction.math_operation = -+ MathOperationMap::kId; -+ -+ description_.tile_description.minimum_compute_capability = -+ ArchMap::kMin; -+ -+ description_.tile_description.maximum_compute_capability = -+ ArchMap::kMax; -+ -+ description_.A = make_TensorDescription(Operator::kAlignmentA); -+ description_.B = make_TensorDescription(Operator::kAlignmentB); -+ description_.C = make_TensorDescription(Operator::kAlignmentC); -+ description_.element_epilogue = NumericTypeMap::kId; -+ -+ description_.split_k_mode = SplitKMode::kNone; -+ } -+ -+ /// Returns the description of the SYMM operation -+ virtual OperationDescription const & description() const { -+ return description_; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class SymmOperation : public SymmOperationBase { -+public: -+ -+ using Operator = Operator_; -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ -+ static BlasMode const kBlasMode = Operator::kBlasMode; -+ static SideMode const kSideModeA = Operator::kSideModeA; -+ static FillMode const kFillModeA = Operator::kFillModeA; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+public: -+ -+ /// Constructor -+ SymmOperation(char const *name = "unknown_symm"): -+ SymmOperationBase(name) { -+ -+ this->description_.symm_kind = SymmKind::kUniversal; -+ } -+ -+protected: -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status construct_arguments_( -+ OperatorArguments &operator_args, -+ SymmConfiguration const *configuration) { -+ -+ //operator_args.mode = configuration->mode; -+ -+ operator_args.problem_size = configuration->problem_size; -+ operator_args.batch_count = configuration->batch_count; -+ -+ operator_args.lda = int(configuration->lda); -+ operator_args.ldb = int(configuration->ldb); -+ operator_args.ldc = int(configuration->ldc); -+ operator_args.ldd = int(configuration->ldd); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status update_arguments_( -+ OperatorArguments &operator_args, -+ SymmArguments const *arguments) { -+ -+ if (arguments->pointer_mode == ScalarPointerMode::kHost) { -+ typename Operator::EpilogueOutputOp::Params params( -+ *static_cast(arguments->alpha), -+ *static_cast(arguments->beta) -+ ); -+ operator_args.epilogue = params; -+ } -+ else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ -+ typename Operator::EpilogueOutputOp::Params params( -+ static_cast(arguments->alpha), -+ static_cast(arguments->beta) -+ ); -+ operator_args.epilogue = params; -+ } -+ else { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ // update arguments -+ operator_args.ptr_A = arguments->A; -+ operator_args.ptr_B = arguments->B; -+ operator_args.ptr_C = arguments->C; -+ operator_args.ptr_D = arguments->D; -+ -+ operator_args.batch_stride_A = arguments->batch_stride_A; -+ operator_args.batch_stride_B = arguments->batch_stride_B; -+ operator_args.batch_stride_C = arguments->batch_stride_C; -+ operator_args.batch_stride_D = arguments->batch_stride_D; -+ -+ return Status::kSuccess; -+ } -+ -+public: -+ -+ /// Returns success if the operation can proceed -+ virtual Status can_implement( -+ void const *configuration_ptr, -+ void const *arguments_ptr) const { -+ -+ SymmConfiguration const *configuration = -+ static_cast(configuration_ptr); -+ -+ SymmArguments const *arguments = -+ static_cast(arguments_ptr); -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_(args, configuration); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = update_arguments_(args, arguments); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Operator::can_implement(args); -+ } -+ -+ /// Gets the host-side workspace -+ virtual uint64_t get_host_workspace_size( -+ void const *configuration) const { -+ -+ return sizeof(Operator); -+ } -+ -+ /// Gets the device-side workspace -+ virtual uint64_t get_device_workspace_size( -+ void const *configuration_ptr, -+ void const *arguments_ptr = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return 0; -+ } -+ -+ uint64_t size = Operator::get_workspace_size(args); -+ -+ return size; -+ } -+ -+ /// Initializes the workspace -+ virtual Status initialize( -+ void const *configuration_ptr, -+ void *host_workspace, -+ void *device_workspace, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = new (host_workspace) Operator; -+ -+ //std::cout << "initialize() library::SymmOperation" << std::endl; -+ //print_operator_args(args); -+ status = op->initialize(args, device_workspace, stream); -+ -+ return status; -+ } -+ -+ /// Runs the kernel -+ virtual Status run( -+ void const *arguments_ptr, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = update_arguments_( -+ args, -+ static_cast(arguments_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = static_cast(host_workspace); -+ -+ bool need_swapped_matrices = (kSideModeA == SideMode::kLeft && -+ std::is_same::value) || -+ (kSideModeA == SideMode::kRight && -+ std::is_same::value); -+ if (need_swapped_matrices) { -+ status = op->update(args.swapped_matrices(), device_workspace); -+ } else { -+ status = op->update(args, device_workspace); -+ } -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ //std::cout << "run() library::SymmOperation" << std::endl; -+ //print_operator_args(args); -+ status = op->run(stream); -+ -+ return status; -+ } -+ -+ /// Call print_operator_args from the Conv2dOperation::initialize() -+ // to dump arguments passed on to cutlass operator for debugging -+ void print_operator_args(OperatorArguments &operator_args) const { -+ std::cout << "SymmOperation::OperatorArguments" << std::endl -+ << " problem_size:" << std::endl -+ << operator_args.problem_size << std::endl -+ << " epilouge (alpha, beta): " -+ << operator_args.epilogue.alpha << ", " -+ << operator_args.epilogue.beta << std::endl -+ << " ref_A (ptr, {stride}): " -+ << operator_args.ptr_A << ", {" -+ << operator_args.lda << "}" << std::endl -+ << " ref_B (ptr, {stride}): " -+ << operator_args.ptr_B << ", {" -+ << operator_args.ldb << "}" << std::endl -+ << " ref_C (ptr, {stride}): " -+ << operator_args.ptr_C << ", {" -+ << operator_args.ldc << "}" << std::endl -+ << " ref_D (ptr, {stride}): " -+ << operator_args.ptr_D << ", {" -+ << operator_args.ldd << "}" << std::endl; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/src/trmm_operation.h b/3rdparty/cutlass/tools/library/src/trmm_operation.h -new file mode 100644 -index 0000000..55f4fa6 ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/trmm_operation.h -@@ -0,0 +1,346 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines operations for all TRMM operation kinds in CUTLASS Library. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/gemm/device/trmm.h" -+#include "cutlass/gemm/kernel/default_trmm_universal.h" -+#include "cutlass/gemm/kernel/trmm_universal.h" -+ -+#include "cutlass/library/library.h" -+#include "library_internal.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace library { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class TrmmOperationBase : public Operation { -+public: -+ using Operator = Operator_; -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ static SideMode const kSideMode = Operator::kSideMode; -+ static FillMode const kFillMode = Operator::kFillMode; -+ static DiagType const kDiagType = Operator::kDiagType; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+protected: -+ -+ /// -+ TrmmDescription description_; -+ -+public: -+ -+ /// Constructor -+ TrmmOperationBase(char const *name = "unknown_trmm") { -+ -+ description_.name = name; -+ description_.provider = Provider::kCUTLASS; -+ description_.kind = OperationKind::kTrmm; -+ description_.trmm_kind = TrmmKind::kUniversal; -+ description_.side_mode = kSideMode; -+ description_.fill_mode = kFillMode; -+ description_.diag_type = kDiagType; -+ -+ description_.tile_description.threadblock_shape = make_Coord( -+ Operator::ThreadblockShape::kM, -+ Operator::ThreadblockShape::kN, -+ Operator::ThreadblockShape::kK); -+ -+ description_.tile_description.threadblock_stages = Operator::kStages; -+ -+ description_.tile_description.warp_count = make_Coord( -+ Operator::TrmmKernel::WarpCount::kM, -+ Operator::TrmmKernel::WarpCount::kN, -+ Operator::TrmmKernel::WarpCount::kK); -+ -+ description_.tile_description.math_instruction.instruction_shape = make_Coord( -+ Operator::InstructionShape::kM, -+ Operator::InstructionShape::kN, -+ Operator::InstructionShape::kK); -+ -+ description_.tile_description.math_instruction.element_accumulator = -+ NumericTypeMap::kId; -+ -+ description_.tile_description.math_instruction.opcode_class = -+ OpcodeClassMap::kId; -+ -+ description_.tile_description.math_instruction.math_operation = -+ MathOperationMap::kId; -+ -+ description_.tile_description.minimum_compute_capability = -+ ArchMap::kMin; -+ -+ description_.tile_description.maximum_compute_capability = -+ ArchMap::kMax; -+ -+ description_.A = make_TensorDescription(Operator::kAlignmentA); -+ description_.B = make_TensorDescription(Operator::kAlignmentB); -+ description_.D = make_TensorDescription(Operator::kAlignmentC); -+ description_.element_epilogue = NumericTypeMap::kId; -+ -+ description_.split_k_mode = SplitKMode::kNone; -+ description_.transform_A = ComplexTransformMap::kId; -+ } -+ -+ /// Returns the description of the TRMM operation -+ virtual OperationDescription const & description() const { -+ return description_; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class TrmmOperation : public TrmmOperationBase { -+public: -+ -+ using Operator = Operator_; -+ using ElementA = typename Operator::ElementA; -+ using LayoutA = typename Operator::LayoutA; -+ static SideMode const kSideMode = Operator::kSideMode; -+ static FillMode const kFillMode = Operator::kFillMode; -+ static DiagType const kDiagType = Operator::kDiagType; -+ using ElementB = typename Operator::ElementB; -+ using LayoutB = typename Operator::LayoutB; -+ using ElementC = typename Operator::ElementC; -+ using LayoutC = typename Operator::LayoutC; -+ using ElementAccumulator = typename Operator::ElementAccumulator; -+ using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -+ -+ using OperatorArguments = typename Operator::Arguments; -+ -+public: -+ -+ /// Constructor -+ TrmmOperation(char const *name = "unknown_trmm"): -+ TrmmOperationBase(name) { -+ -+ this->description_.trmm_kind = TrmmKind::kUniversal; -+ } -+ -+protected: -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status construct_arguments_( -+ OperatorArguments &operator_args, -+ TrmmConfiguration const *configuration) { -+ -+ //operator_args.mode = configuration->mode; -+ -+ operator_args.problem_size = configuration->problem_size; -+ operator_args.batch_count = configuration->batch_count; -+ -+ operator_args.lda = int(configuration->lda); -+ operator_args.ldb = int(configuration->ldb); -+ operator_args.ldd = int(configuration->ldd); -+ -+ return Status::kSuccess; -+ } -+ -+ /// Constructs the arguments structure given the configuration and arguments -+ static Status update_arguments_( -+ OperatorArguments &operator_args, -+ TrmmArguments const *arguments) { -+ -+ if (arguments->pointer_mode == ScalarPointerMode::kHost) { -+ typename Operator::EpilogueOutputOp::Params params( -+ *static_cast(arguments->alpha), -+ *static_cast(arguments->beta) -+ ); -+ operator_args.epilogue = params; -+ } -+ else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ -+ typename Operator::EpilogueOutputOp::Params params( -+ static_cast(arguments->alpha), -+ static_cast(arguments->beta) -+ ); -+ operator_args.epilogue = params; -+ } -+ else { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ // update arguments -+ operator_args.ptr_A = arguments->A; -+ operator_args.ptr_B = arguments->B; -+ operator_args.batch_stride_A = arguments->batch_stride_A; -+ operator_args.batch_stride_B = arguments->batch_stride_B; -+ operator_args.ptr_D = arguments->D; -+ operator_args.batch_stride_D = arguments->batch_stride_D; -+ -+ return Status::kSuccess; -+ } -+ -+public: -+ -+ /// Returns success if the operation can proceed -+ virtual Status can_implement( -+ void const *configuration_ptr, -+ void const *arguments_ptr) const { -+ -+ TrmmConfiguration const *configuration = -+ static_cast(configuration_ptr); -+ -+ TrmmArguments const *arguments = -+ static_cast(arguments_ptr); -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_(args, configuration); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = update_arguments_(args, arguments); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ return Operator::can_implement(args); -+ } -+ -+ /// Gets the host-side workspace -+ virtual uint64_t get_host_workspace_size( -+ void const *configuration) const { -+ -+ return sizeof(Operator); -+ } -+ -+ /// Gets the device-side workspace -+ virtual uint64_t get_device_workspace_size( -+ void const *configuration_ptr, -+ void const *arguments_ptr = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return 0; -+ } -+ -+ uint64_t size = Operator::get_workspace_size(args); -+ -+ return size; -+ } -+ -+ /// Initializes the workspace -+ virtual Status initialize( -+ void const *configuration_ptr, -+ void *host_workspace, -+ void *device_workspace, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = construct_arguments_( -+ args, -+ static_cast(configuration_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = new (host_workspace) Operator; -+ -+ status = op->initialize(args, device_workspace, stream); -+ -+ return status; -+ } -+ -+ /// Runs the kernel -+ virtual Status run( -+ void const *arguments_ptr, -+ void *host_workspace, -+ void *device_workspace = nullptr, -+ cudaStream_t stream = nullptr) const { -+ -+ OperatorArguments args; -+ -+ Status status = update_arguments_( -+ args, -+ static_cast(arguments_ptr)); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ Operator *op = static_cast(host_workspace); -+ -+ bool need_swapped_matrices = (kSideMode == SideMode::kLeft && -+ std::is_same::value) || -+ (kSideMode == SideMode::kRight && -+ std::is_same::value); -+ if (need_swapped_matrices) { -+ status = op->update(args.swapped_matrices(), device_workspace); -+ } else { -+ status = op->update(args, device_workspace); -+ } -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ status = op->run(stream); -+ -+ return status; -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/library/src/util.cu b/3rdparty/cutlass/tools/library/src/util.cu -new file mode 100644 -index 0000000..a4e234a ---- /dev/null -+++ b/3rdparty/cutlass/tools/library/src/util.cu -@@ -0,0 +1,1599 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#include -+#include -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/complex.h" -+#include "cutlass/blas3.h" -+ -+#include "cutlass/layout/matrix.h" -+ -+#include "cutlass/library/library.h" -+#include "cutlass/library/util.h" -+ -+namespace cutlass { -+namespace library { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+static struct { -+ char const *text; -+ char const *pretty; -+ Provider enumerant; -+} -+Provider_enumerants[] = { -+ {"none", "None", Provider::kNone}, -+ {"cutlass", "CUTLASS", Provider::kCUTLASS}, -+ {"host", "reference_host", Provider::kReferenceHost}, -+ {"device", "reference_device", Provider::kReferenceDevice}, -+ {"cublas", "cuBLAS", Provider::kCUBLAS}, -+ {"cudnn", "cuDNN", Provider::kCUDNN}, -+}; -+ -+/// Converts a Provider enumerant to a string -+char const *to_string(Provider provider, bool pretty) { -+ -+ for (auto const & possible : Provider_enumerants) { -+ if (provider == possible.enumerant) { -+ if (pretty) { -+ return possible.pretty; -+ } -+ else { -+ return possible.text; -+ } -+ } -+ } -+ -+ return pretty ? "Invalid" : "invalid"; -+} -+ -+/// Parses a Provider enumerant from a string -+template <> -+Provider from_string(std::string const &str) { -+ -+ for (auto const & possible : Provider_enumerants) { -+ if ((str.compare(possible.text) == 0) || -+ (str.compare(possible.pretty) == 0)) { -+ return possible.enumerant; -+ } -+ } -+ -+ return Provider::kInvalid; -+} -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+static struct { -+ char const *text; -+ char const *pretty; -+ GemmKind enumerant; -+} -+GemmKind_enumerants[] = { -+ {"gemm", "", GemmKind::kGemm}, -+ {"spgemm", "", GemmKind::kSparse}, -+ {"universal", "", GemmKind::kUniversal}, -+ {"planar_complex", "", GemmKind::kPlanarComplex}, -+ {"planar_complex_array", "", GemmKind::kPlanarComplexArray}, -+ {"grouped", "", GemmKind::kGrouped}, -+}; -+ -+/// Converts a GemmKind enumerant to a string -+char const *to_string(GemmKind type, bool pretty) { -+ -+ for (auto const & possible : GemmKind_enumerants) { -+ if (type == possible.enumerant) { -+ if (pretty) { -+ return possible.pretty; -+ } -+ else { -+ return possible.text; -+ } -+ } -+ } -+ -+ return pretty ? "Invalid" : "invalid"; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+static struct { -+ char const *text; -+ char const *pretty; -+ RankKKind enumerant; -+} -+RankKKind_enumerants[] = { -+ {"universal", "", RankKKind::kUniversal}, -+}; -+ -+/// Converts a SyrkKind enumerant to a string -+char const *to_string(RankKKind type, bool pretty) { -+ -+ for (auto const & possible :RankKKind_enumerants) { -+ if (type == possible.enumerant) { -+ if (pretty) { -+ return possible.pretty; -+ } -+ else { -+ return possible.text; -+ } -+ } -+ } -+ -+ return pretty ? "Invalid" : "invalid"; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+static struct { -+ char const *text; -+ char const *pretty; -+ TrmmKind enumerant; -+} -+TrmmKind_enumerants[] = { -+ {"universal", "", TrmmKind::kUniversal}, -+}; -+ -+/// Converts a TrmmKind enumerant to a string -+char const *to_string(TrmmKind type, bool pretty) { -+ -+ for (auto const & possible :TrmmKind_enumerants) { -+ if (type == possible.enumerant) { -+ if (pretty) { -+ return possible.pretty; -+ } -+ else { -+ return possible.text; -+ } -+ } -+ } -+ -+ return pretty ? "Invalid" : "invalid"; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+static struct { -+ char const *text; -+ char const *pretty; -+ SymmKind enumerant; -+} -+SymmKind_enumerants[] = { -+ {"universal", "", SymmKind::kUniversal}, -+}; -+ -+/// Converts a SymmKind enumerant to a string -+char const *to_string(SymmKind type, bool pretty) { -+ -+ for (auto const & possible :SymmKind_enumerants) { -+ if (type == possible.enumerant) { -+ if (pretty) { -+ return possible.pretty; -+ } -+ else { -+ return possible.text; -+ } -+ } -+ } -+ -+ return pretty ? "Invalid" : "invalid"; -+} -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+static struct { -+ char const *text; -+ char const *pretty; -+ SideMode enumerant; -+} -+SideMode_enumerants[] = { -+ {"left", "Left", SideMode::kLeft}, -+ {"right", "Right", SideMode::kRight} -+}; -+ -+/// Converts a SideMode enumerant to a string -+char const *to_string(SideMode type, bool pretty) { -+ -+ for (auto const & possible :SideMode_enumerants) { -+ if (type == possible.enumerant) { -+ if (pretty) { -+ return possible.pretty; -+ } -+ else { -+ return possible.text; -+ } -+ } -+ } -+ -+ return pretty ? "Invalid" : "invalid"; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+static struct { -+ char const *text; -+ char const *pretty; -+ FillMode enumerant; -+} -+FillMode_enumerants[] = { -+ {"lower", "Lower", FillMode::kLower}, -+ {"upper", "Upper", FillMode::kUpper} -+}; -+ -+/// Converts a FillMode enumerant to a string -+char const *to_string(FillMode type, bool pretty) { -+ -+ for (auto const & possible :FillMode_enumerants) { -+ if (type == possible.enumerant) { -+ if (pretty) { -+ return possible.pretty; -+ } -+ else { -+ return possible.text; -+ } -+ } -+ } -+ -+ return pretty ? "Invalid" : "invalid"; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+static struct { -+ char const *text; -+ char const *pretty; -+ BlasMode enumerant; -+} -+BlasMode_enumerants[] = { -+ {"symmetric", "Symmetric", BlasMode::kSymmetric}, -+ {"hermitian", "Hermitian", BlasMode::kHermitian} -+}; -+ -+/// Converts a BlasMode enumerant to a string -+char const *to_string(BlasMode type, bool pretty) { -+ -+ for (auto const & possible :BlasMode_enumerants) { -+ if (type == possible.enumerant) { -+ if (pretty) { -+ return possible.pretty; -+ } -+ else { -+ return possible.text; -+ } -+ } -+ } -+ -+ return pretty ? "Invalid" : "invalid"; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+static struct { -+ char const *text; -+ char const *pretty; -+ DiagType enumerant; -+} -+DiagType_enumerants[] = { -+ {"nonunit", "NonUnit", DiagType::kNonUnit}, -+ {"unit", "Unit", DiagType::kUnit} -+}; -+ -+/// Converts a DiagType enumerant to a string -+char const *to_string(DiagType type, bool pretty) { -+ -+ for (auto const & possible :DiagType_enumerants) { -+ if (type == possible.enumerant) { -+ if (pretty) { -+ return possible.pretty; -+ } -+ else { -+ return possible.text; -+ } -+ } -+ } -+ -+ return pretty ? "Invalid" : "invalid"; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+static struct { -+ char const *text; -+ char const *pretty; -+ OperationKind enumerant; -+} -+OperationKind_enumerants[] = { -+ {"eq_gemm", "EqGemm", OperationKind::kEqGemm}, -+ {"gemm", "Gemm", OperationKind::kGemm}, -+ {"rank_k", "RankK", OperationKind::kRankK}, -+ {"rank_2k", "Rank2K", OperationKind::kRank2K}, -+ {"trmm", "Trmm", OperationKind::kTrmm}, -+ {"symm", "Symm", OperationKind::kSymm}, -+ {"conv2d", "Conv2d", OperationKind::kConv2d}, -+ {"conv3d", "Conv3d", OperationKind::kConv3d}, -+ {"spgemm", "SparseGemm", OperationKind::kSparseGemm}, -+}; -+ -+/// Converts a Status enumerant to a string -+char const *to_string(OperationKind enumerant, bool pretty) { -+ -+ for (auto const & possible : OperationKind_enumerants) { -+ if (enumerant == possible.enumerant) { -+ if (pretty) { -+ return possible.pretty; -+ } -+ else { -+ return possible.text; -+ } -+ } -+ } -+ -+ return pretty ? "Invalid" : "invalid"; -+} -+ -+/// Converts a Status enumerant from a string -+template <> -+OperationKind from_string(std::string const &str) { -+ -+ for (auto const & possible : OperationKind_enumerants) { -+ if ((str.compare(possible.text) == 0) || -+ (str.compare(possible.pretty) == 0)) { -+ return possible.enumerant; -+ } -+ } -+ -+ return OperationKind::kInvalid; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+static struct { -+ char const *text; -+ char const *pretty; -+ Status enumerant; -+} -+Status_enumerants[] = { -+ {"success", "Success", Status::kSuccess}, -+ {"misaligned_operand", "Error: misaligned operand", Status::kErrorMisalignedOperand}, -+ {"invalid_problem", "Error: invalid problem", Status::kErrorInvalidProblem}, -+ {"not_supported", "Error: not supported", Status::kErrorNotSupported}, -+ {"internal", "Error: internal", Status::kErrorInternal} -+}; -+ -+/// Converts a Status enumerant to a string -+char const *to_string(Status status, bool pretty) { -+ -+ for (auto const & possible : Status_enumerants) { -+ if (status == possible.enumerant) { -+ if (pretty) { -+ return possible.pretty; -+ } -+ else { -+ return possible.text; -+ } -+ } -+ } -+ -+ return pretty ? "Invalid" : "invalid"; -+} -+ -+/// Converts a Status enumerant from a string -+template <> -+Status from_string(std::string const &str) { -+ -+ for (auto const & possible : Status_enumerants) { -+ if ((str.compare(possible.text) == 0) || -+ (str.compare(possible.pretty) == 0)) { -+ return possible.enumerant; -+ } -+ } -+ -+ return Status::kInvalid; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+static struct { -+ char const *text; -+ char const *pretty; -+ NumericTypeID enumerant; -+} -+NumericTypeID_enumerants[] = { -+ {"unknown", "", NumericTypeID::kUnknown}, -+ {"void", "Void", NumericTypeID::kVoid}, -+ {"b1", "B1", NumericTypeID::kB1}, -+ {"u2", "U2", NumericTypeID::kU2}, -+ {"u4", "U4", NumericTypeID::kU4}, -+ {"u8", "U8", NumericTypeID::kU8}, -+ {"u16", "U16", NumericTypeID::kU16}, -+ {"u32", "U32", NumericTypeID::kU32}, -+ {"u64", "U64", NumericTypeID::kU64}, -+ {"s2", "S2", NumericTypeID::kS2}, -+ {"s4", "S4", NumericTypeID::kS4}, -+ {"s8", "S8", NumericTypeID::kS8}, -+ {"s16", "S16", NumericTypeID::kS16}, -+ {"s32", "S32", NumericTypeID::kS32}, -+ {"s64", "S64", NumericTypeID::kS64}, -+ {"f16", "F16", NumericTypeID::kF16}, -+ {"bf16", "BF16", NumericTypeID::kBF16}, -+ {"f32", "F32", NumericTypeID::kF32}, -+ {"tf32", "TF32", NumericTypeID::kTF32}, -+ {"f64", "F64", NumericTypeID::kF64}, -+ {"cf16", "CF16", NumericTypeID::kCF16}, -+ {"cbf16", "CBF16", NumericTypeID::kCBF16}, -+ {"cf32", "CF32", NumericTypeID::kCF32}, -+ {"ctf32", "CTF32", NumericTypeID::kCTF32}, -+ {"cf64", "CF64", NumericTypeID::kCF64}, -+ {"cu2", "CU2", NumericTypeID::kCU2}, -+ {"cu4", "CU4", NumericTypeID::kCU4}, -+ {"cu8", "CU8", NumericTypeID::kCU8}, -+ {"cu16", "CU16", NumericTypeID::kCU16}, -+ {"cu32", "CU32", NumericTypeID::kCU32}, -+ {"cu64", "CU64", NumericTypeID::kCU64}, -+ {"cs2", "CS2", NumericTypeID::kCS2}, -+ {"cs4", "CS4", NumericTypeID::kCS4}, -+ {"cs8", "CS8", NumericTypeID::kCS8}, -+ {"cs16", "CS16", NumericTypeID::kCS16}, -+ {"cs32", "CS32", NumericTypeID::kCS32}, -+ {"cs64", "CS64", NumericTypeID::kCS64}, -+ {"*", "", NumericTypeID::kUnknown} -+}; -+ -+/// Converts a NumericTypeID enumerant to a string -+char const *to_string(NumericTypeID type, bool pretty) { -+ -+ for (auto const & possible : NumericTypeID_enumerants) { -+ if (type == possible.enumerant) { -+ if (pretty) { -+ return possible.pretty; -+ } -+ else { -+ return possible.text; -+ } -+ } -+ } -+ -+ return pretty ? "Invalid" : "invalid"; -+} -+ -+/// Parses a NumericTypeID enumerant from a string -+template <> -+NumericTypeID from_string(std::string const &str) { -+ -+ for (auto const & possible : NumericTypeID_enumerants) { -+ if ((str.compare(possible.text) == 0) || -+ (str.compare(possible.pretty) == 0)) { -+ return possible.enumerant; -+ } -+ } -+ -+ return NumericTypeID::kInvalid; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Returns the size of a data type in bits -+int sizeof_bits(NumericTypeID type) { -+ switch (type) { -+ case NumericTypeID::kF16: return 16; -+ case NumericTypeID::kBF16: return 16; -+ case NumericTypeID::kTF32: return 32; -+ case NumericTypeID::kF32: return 32; -+ case NumericTypeID::kF64: return 64; -+ case NumericTypeID::kCF16: return 32; -+ case NumericTypeID::kCBF16: return 32; -+ case NumericTypeID::kCF32: return 64; -+ case NumericTypeID::kCTF32: return 64; -+ case NumericTypeID::kCF64: return 128; -+ case NumericTypeID::kS2: return 2; -+ case NumericTypeID::kS4: return 4; -+ case NumericTypeID::kS8: return 8; -+ case NumericTypeID::kS16: return 16; -+ case NumericTypeID::kS32: return 32; -+ case NumericTypeID::kS64: return 64; -+ case NumericTypeID::kU2: return 2; -+ case NumericTypeID::kU4: return 4; -+ case NumericTypeID::kU8: return 8; -+ case NumericTypeID::kU16: return 16; -+ case NumericTypeID::kU32: return 32; -+ case NumericTypeID::kU64: return 64; -+ case NumericTypeID::kB1: return 1; -+ default: break; -+ } -+ return 0; -+} -+ -+/// Returns true if the numeric type is a complex data type or false if real-valued. -+bool is_complex_type(NumericTypeID type) { -+ switch (type) { -+ case NumericTypeID::kCF16: return true; -+ case NumericTypeID::kCF32: return true; -+ case NumericTypeID::kCF64: return true; -+ case NumericTypeID::kCBF16: return true; -+ case NumericTypeID::kCTF32: return true; -+ default: break; -+ } -+ return false; -+} -+ -+/// Returns the field underlying a complex valued type -+NumericTypeID get_real_type(NumericTypeID type) { -+ switch (type) { -+ case NumericTypeID::kCF16: return NumericTypeID::kF16; -+ case NumericTypeID::kCF32: return NumericTypeID::kF32; -+ case NumericTypeID::kCF64: return NumericTypeID::kF64; -+ case NumericTypeID::kCBF16: return NumericTypeID::kBF16; -+ case NumericTypeID::kCTF32: return NumericTypeID::kTF32; -+ default: break; -+ } -+ return type; -+} -+ -+/// Returns true if numeric type is integer -+bool is_integer_type(NumericTypeID type) { -+ switch (type) { -+ case NumericTypeID::kS2: return true; -+ case NumericTypeID::kS4: return true; -+ case NumericTypeID::kS8: return true; -+ case NumericTypeID::kS16: return true; -+ case NumericTypeID::kS32: return true; -+ case NumericTypeID::kS64: return true; -+ case NumericTypeID::kU2: return true; -+ case NumericTypeID::kU4: return true; -+ case NumericTypeID::kU8: return true; -+ case NumericTypeID::kU16: return true; -+ case NumericTypeID::kU32: return true; -+ case NumericTypeID::kU64: return true; -+ default: break; -+ } -+ return false; -+} -+ -+/// Returns true if numeric type is signed -+bool is_signed_type(NumericTypeID type) { -+ switch (type) { -+ case NumericTypeID::kF16: return true; -+ case NumericTypeID::kBF16: return true; -+ case NumericTypeID::kTF32: return true; -+ case NumericTypeID::kF32: return true; -+ case NumericTypeID::kF64: return true; -+ case NumericTypeID::kS2: return true; -+ case NumericTypeID::kS4: return true; -+ case NumericTypeID::kS8: return true; -+ case NumericTypeID::kS16: return true; -+ case NumericTypeID::kS32: return true; -+ case NumericTypeID::kS64: return true; -+ default: break; -+ } -+ return false; -+} -+ -+/// Returns true if numeric type is a signed integer -+bool is_signed_integer(NumericTypeID type) { -+ return is_integer_type(type) && is_signed_type(type); -+} -+ -+/// returns true if numeric type is an unsigned integer -+bool is_unsigned_integer(NumericTypeID type) { -+ return is_integer_type(type) && !is_signed_type(type); -+} -+ -+/// Returns true if numeric type is floating-point type -+bool is_float_type(NumericTypeID type) { -+ switch (type) { -+ case NumericTypeID::kF16: return true; -+ case NumericTypeID::kBF16: return true; -+ case NumericTypeID::kTF32: return true; -+ case NumericTypeID::kF32: return true; -+ case NumericTypeID::kF64: return true; -+ case NumericTypeID::kCF16: return true; -+ case NumericTypeID::kCBF16: return true; -+ case NumericTypeID::kCTF32: return true; -+ case NumericTypeID::kCF32: return true; -+ case NumericTypeID::kCF64: return true; -+ default: break; -+ } -+ return false; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+static struct { -+ LayoutTypeID layout; -+ char const *alias; -+} -+layout_aliases[] = { -+ {LayoutTypeID::kUnknown, "unknown"}, -+ {LayoutTypeID::kRowMajor, "row"}, -+ {LayoutTypeID::kRowMajor, "t"}, -+ {LayoutTypeID::kColumnMajor, "column"}, -+ {LayoutTypeID::kColumnMajor, "col"}, -+ {LayoutTypeID::kColumnMajor, "n"}, -+ -+ {LayoutTypeID::kColumnMajorInterleavedK2, "nk2"}, -+ {LayoutTypeID::kRowMajorInterleavedK2, "tk2"}, -+ -+ {LayoutTypeID::kColumnMajorInterleavedK4, "nk4"}, -+ {LayoutTypeID::kRowMajorInterleavedK4, "tk4"}, -+ -+ {LayoutTypeID::kColumnMajorInterleavedK16, "nk16"}, -+ {LayoutTypeID::kRowMajorInterleavedK16, "tk16"}, -+ -+ {LayoutTypeID::kColumnMajorInterleavedK32, "nk32"}, -+ {LayoutTypeID::kRowMajorInterleavedK32, "tk32"}, -+ -+ {LayoutTypeID::kColumnMajorInterleavedK64, "nk64"}, -+ {LayoutTypeID::kRowMajorInterleavedK64, "tk64"}, -+ -+ {LayoutTypeID::kTensorNCHW, "nchw"}, -+ {LayoutTypeID::kTensorNCDHW, "ncdhw"}, -+ {LayoutTypeID::kTensorNHWC, "nhwc"}, -+ {LayoutTypeID::kTensorNDHWC, "ndhwc"}, -+ {LayoutTypeID::kTensorNC32HW32, "nc32hw32"}, -+ {LayoutTypeID::kTensorNC64HW64, "nc64hw64"}, -+ {LayoutTypeID::kTensorC32RSK32, "c32rsk32"}, -+ {LayoutTypeID::kTensorC64RSK64, "c64rsk64"}, -+ -+ {LayoutTypeID::kUnknown, "*"}, -+ {LayoutTypeID::kInvalid, nullptr} -+}; -+ -+/// Converts a LayoutTypeID enumerant to a string -+char const *to_string(LayoutTypeID layout, bool pretty) { -+ for (auto const & alias : layout_aliases) { -+ if (alias.layout == layout) { -+ return alias.alias; -+ } -+ } -+ return pretty ? "Invalid" : "invalid"; -+} -+ -+/// Parses a LayoutTypeID enumerant from a string -+template <> -+LayoutTypeID from_string(std::string const &str) { -+ for (auto const & alias : layout_aliases) { -+ if (str.compare(alias.alias) == 0) { -+ return alias.layout; -+ } -+ } -+ return LayoutTypeID::kInvalid; -+} -+ -+/// Gets stride rank for the layout_id (static function) -+int get_layout_stride_rank(LayoutTypeID layout_id) { -+ switch (layout_id) { -+ case LayoutTypeID::kColumnMajor: -+ return cutlass::layout::ColumnMajor::kStrideRank; -+ case LayoutTypeID::kRowMajor: -+ return cutlass::layout::RowMajor::kStrideRank; -+ case LayoutTypeID::kColumnMajorInterleavedK2: -+ return cutlass::layout::ColumnMajorInterleaved<2>::kStrideRank; -+ case LayoutTypeID::kRowMajorInterleavedK2: -+ return cutlass::layout::RowMajorInterleaved<2>::kStrideRank; -+ case LayoutTypeID::kColumnMajorInterleavedK4: -+ return cutlass::layout::ColumnMajorInterleaved<4>::kStrideRank; -+ case LayoutTypeID::kRowMajorInterleavedK4: -+ return cutlass::layout::RowMajorInterleaved<4>::kStrideRank; -+ case LayoutTypeID::kColumnMajorInterleavedK16: -+ return cutlass::layout::ColumnMajorInterleaved<16>::kStrideRank; -+ case LayoutTypeID::kRowMajorInterleavedK16: -+ return cutlass::layout::RowMajorInterleaved<16>::kStrideRank; -+ case LayoutTypeID::kColumnMajorInterleavedK32: -+ return cutlass::layout::ColumnMajorInterleaved<32>::kStrideRank; -+ case LayoutTypeID::kRowMajorInterleavedK32: -+ return cutlass::layout::RowMajorInterleaved<32>::kStrideRank; -+ case LayoutTypeID::kColumnMajorInterleavedK64: -+ return cutlass::layout::ColumnMajorInterleaved<64>::kStrideRank; -+ case LayoutTypeID::kRowMajorInterleavedK64: -+ return cutlass::layout::RowMajorInterleaved<64>::kStrideRank; -+ case LayoutTypeID::kTensorNCHW: -+ return cutlass::layout::TensorNCHW::kStrideRank; -+ case LayoutTypeID::kTensorNHWC: -+ return cutlass::layout::TensorNHWC::kStrideRank; -+ case LayoutTypeID::kTensorNDHWC: -+ return cutlass::layout::TensorNDHWC::kStrideRank; -+ case LayoutTypeID::kTensorNC32HW32: -+ return cutlass::layout::TensorNCxHWx<32>::kStrideRank; -+ case LayoutTypeID::kTensorNC64HW64: -+ return cutlass::layout::TensorNCxHWx<64>::kStrideRank; -+ case LayoutTypeID::kTensorC32RSK32: -+ return cutlass::layout::TensorCxRSKx<32>::kStrideRank; -+ case LayoutTypeID::kTensorC64RSK64: -+ return cutlass::layout::TensorCxRSKx<64>::kStrideRank; -+ default: -+ throw std::runtime_error("Unsupported LayoutTypeID in LayoutType::get_stride_rank"); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+static struct { -+ char const *text; -+ char const *pretty; -+ OpcodeClassID enumerant; -+} -+OpcodeClassID_enumerants[] = { -+ {"simt", "", OpcodeClassID::kSimt}, -+ {"tensorop", "", OpcodeClassID::kTensorOp}, -+ {"wmmatensorop", "", OpcodeClassID::kWmmaTensorOp}, -+ {"wmma", "", OpcodeClassID::kWmmaTensorOp}, -+}; -+ -+/// Converts a OpcodeClassID enumerant to a string -+char const *to_string(OpcodeClassID type, bool pretty) { -+ -+ for (auto const & possible : OpcodeClassID_enumerants) { -+ if (type == possible.enumerant) { -+ if (pretty) { -+ return possible.pretty; -+ } -+ else { -+ return possible.text; -+ } -+ } -+ } -+ -+ return pretty ? "Invalid" : "invalid"; -+} -+ -+/// Converts a OpcodeClassID enumerant from a string -+template <> -+OpcodeClassID from_string(std::string const &str) { -+ -+ for (auto const & possible : OpcodeClassID_enumerants) { -+ if ((str.compare(possible.text) == 0) || -+ (str.compare(possible.pretty) == 0)) { -+ return possible.enumerant; -+ } -+ } -+ -+ return OpcodeClassID::kInvalid; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+static struct { -+ char const *text; -+ char const *pretty; -+ ComplexTransform enumerant; -+} -+ComplexTransform_enumerants[] = { -+ {"n", "none", ComplexTransform::kNone}, -+ {"c", "conj", ComplexTransform::kConjugate} -+}; -+ -+/// Converts a ComplexTransform enumerant to a string -+char const *to_string(ComplexTransform type, bool pretty) { -+ -+ for (auto const & possible : ComplexTransform_enumerants) { -+ if (type == possible.enumerant) { -+ if (pretty) { -+ return possible.pretty; -+ } -+ else { -+ return possible.text; -+ } -+ } -+ } -+ -+ return pretty ? "Invalid" : "invalid"; -+} -+ -+/// Converts a ComplexTransform enumerant from a string -+template <> -+ComplexTransform from_string(std::string const &str) { -+ -+ for (auto const & possible : ComplexTransform_enumerants) { -+ if ((str.compare(possible.text) == 0) || -+ (str.compare(possible.pretty) == 0)) { -+ return possible.enumerant; -+ } -+ } -+ -+ return ComplexTransform::kInvalid; -+} -+ -+ -+static struct { -+ char const *text; -+ char const *pretty; -+ SplitKMode enumerant; -+} -+SplitKMode_enumerants[] = { -+ {"serial", "", SplitKMode::kSerial}, -+ {"parallel", "", SplitKMode::kParallel}, -+}; -+ -+/// Converts a SplitKMode enumerant to a string -+char const *to_string(SplitKMode type, bool pretty) { -+ -+ for (auto const & possible : SplitKMode_enumerants) { -+ if (type == possible.enumerant) { -+ if (pretty) { -+ return possible.pretty; -+ } -+ else { -+ return possible.text; -+ } -+ } -+ } -+ -+ return pretty ? "Invalid" : "invalid"; -+} -+ -+/// Converts a SplitKMode enumerant from a string -+template <> -+SplitKMode from_string(std::string const &str) { -+ -+ for (auto const & possible : SplitKMode_enumerants) { -+ if ((str.compare(possible.text) == 0) || -+ (str.compare(possible.pretty) == 0)) { -+ return possible.enumerant; -+ } -+ } -+ -+ return SplitKMode::kInvalid; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+static struct { -+ char const *text; -+ char const *pretty; -+ ConvModeID enumerant; -+} -+ConvModeID_enumerants[] = { -+ {"cross", "", ConvModeID::kCrossCorrelation}, -+ {"conv", "", ConvModeID::kConvolution}, -+}; -+ -+/// Converts a ConvModeID enumerant to a string -+char const *to_string(ConvModeID type, bool pretty) { -+ -+ for (auto const & possible : ConvModeID_enumerants) { -+ if (type == possible.enumerant) { -+ if (pretty) { -+ return possible.pretty; -+ } -+ else { -+ return possible.text; -+ } -+ } -+ } -+ -+ return pretty ? "Invalid" : "invalid"; -+} -+ -+/// Converts a ConvModeID enumerant from a string -+template <> -+ConvModeID from_string(std::string const &str) { -+ -+ for (auto const & possible : ConvModeID_enumerants) { -+ if ((str.compare(possible.text) == 0) || -+ (str.compare(possible.pretty) == 0)) { -+ return possible.enumerant; -+ } -+ } -+ -+ return ConvModeID::kInvalid; -+} -+ -+ -+static struct { -+ char const *text; -+ char const *pretty; -+ IteratorAlgorithmID enumerant; -+} -+IteratorAlgorithmID_enumerants[] = { -+ {"none", "", IteratorAlgorithmID::kNone}, -+ {"analytic", "", IteratorAlgorithmID::kAnalytic}, -+ {"optimized", "", IteratorAlgorithmID::kOptimized}, -+ {"fixed_channels", "", IteratorAlgorithmID::kFixedChannels}, -+ {"few_channels", "", IteratorAlgorithmID::kFewChannels}, -+}; -+ -+/// Converts a ConvModeID enumerant to a string -+char const *to_string(IteratorAlgorithmID type, bool pretty) { -+ -+ for (auto const & possible : IteratorAlgorithmID_enumerants) { -+ if (type == possible.enumerant) { -+ if (pretty) { -+ return possible.pretty; -+ } -+ else { -+ return possible.text; -+ } -+ } -+ } -+ -+ return pretty ? "Invalid" : "invalid"; -+} -+ -+/// Converts a ConvModeID enumerant from a string -+template <> -+IteratorAlgorithmID from_string(std::string const &str) { -+ -+ for (auto const & possible : IteratorAlgorithmID_enumerants) { -+ if ((str.compare(possible.text) == 0) || -+ (str.compare(possible.pretty) == 0)) { -+ return possible.enumerant; -+ } -+ } -+ -+ return IteratorAlgorithmID::kInvalid; -+} -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+static struct { -+ char const *text; -+ char const *pretty; -+ ConvKind enumerant; -+} -+ConvKind_enumerants[] = { -+ {"unknown", "", ConvKind::kUnknown}, -+ {"fprop", "", ConvKind::kFprop}, -+ {"dgrad", "", ConvKind::kDgrad}, -+ {"wgrad", "", ConvKind::kWgrad}, -+}; -+ -+/// Converts a ConvKind enumerant to a string -+char const *to_string(ConvKind type, bool pretty) { -+ -+ for (auto const & possible : ConvKind_enumerants) { -+ if (type == possible.enumerant) { -+ if (pretty) { -+ return possible.pretty; -+ } -+ else { -+ return possible.text; -+ } -+ } -+ } -+ -+ return pretty ? "Invalid" : "invalid"; -+} -+ -+ -+/// Converts a ConvKind enumerant from a string -+template <> -+ConvKind from_string(std::string const &str) { -+ -+ for (auto const & possible : ConvKind_enumerants) { -+ if ((str.compare(possible.text) == 0) || -+ (str.compare(possible.pretty) == 0)) { -+ return possible.enumerant; -+ } -+ } -+ -+ return ConvKind::kInvalid; -+} -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Lexical cast a string to a byte array. Returns true if cast is successful or false if invalid. -+bool lexical_cast(std::vector &bytes, NumericTypeID type, std::string const &str) { -+ int size_bytes = sizeof_bits(type) / 8; -+ if (!size_bytes) { -+ return false; -+ } -+ -+ bytes.resize(size_bytes, 0); -+ -+ std::stringstream ss; -+ ss << str; -+ -+ switch (type) { -+ case NumericTypeID::kU8: -+ { -+ ss >> *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kU16: -+ { -+ ss >> *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kU32: -+ { -+ ss >> *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kU64: -+ { -+ ss >> *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kS8: -+ { -+ ss >> *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kS16: -+ { -+ ss >> *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kS32: -+ { -+ ss >> *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kS64: -+ { -+ ss >> *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kF16: -+ { -+ float tmp; -+ ss >> tmp; -+ *reinterpret_cast(bytes.data()) = static_cast(tmp); -+ } -+ break; -+ case NumericTypeID::kBF16: -+ { -+ float tmp; -+ ss >> tmp; -+ *reinterpret_cast(bytes.data()) = static_cast(tmp); -+ } -+ break; -+ case NumericTypeID::kTF32: -+ { -+ float tmp; -+ ss >> tmp; -+ *reinterpret_cast(bytes.data()) = static_cast(tmp); -+ } -+ break; -+ case NumericTypeID::kF32: -+ { -+ ss >> *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kF64: -+ { -+ ss >> *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kCF16: -+ { -+ std::complex tmp; -+ ss >> tmp; -+ cutlass::complex *x = reinterpret_cast *>(bytes.data()); -+ x->real() = static_cast(std::real(tmp)); -+ x->imag() = static_cast(std::imag(tmp)); -+ } -+ break; -+ case NumericTypeID::kCBF16: -+ { -+ std::complex tmp; -+ ss >> tmp; -+ cutlass::complex *x = reinterpret_cast *>(bytes.data()); -+ x->real() = static_cast(std::real(tmp)); -+ x->imag() = static_cast(std::imag(tmp)); -+ } -+ break; -+ case NumericTypeID::kCF32: -+ { -+ ss >> *reinterpret_cast*>(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kCTF32: -+ { -+ std::complex tmp; -+ ss >> tmp; -+ cutlass::complex *x = reinterpret_cast *>(bytes.data()); -+ x->real() = static_cast(std::real(tmp)); -+ x->imag() = static_cast(std::imag(tmp)); -+ } -+ break; -+ case NumericTypeID::kCF64: -+ { -+ ss >> *reinterpret_cast*>(bytes.data()); -+ } -+ break; -+ default: -+ return false; -+ } -+ -+ return true; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+std::string lexical_cast(int64_t int_value) { -+ std::stringstream ss; -+ ss << int_value; -+ return ss.str(); -+} -+ -+/// Lexical cast TO a string FROM a byte array. Returns true if cast is successful or false if invalid. -+std::string lexical_cast(std::vector &bytes, NumericTypeID type) { -+ -+ int size_bytes = sizeof_bits(type) / 8; -+ -+ if (!size_bytes || size_bytes != bytes.size()) { -+ return ""; -+ } -+ -+ bytes.resize(size_bytes, 0); -+ -+ std::stringstream ss; -+ -+ switch (type) { -+ case NumericTypeID::kU8: -+ { -+ ss << *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kU16: -+ { -+ ss << *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kU32: -+ { -+ ss << *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kU64: -+ { -+ ss << *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kS8: -+ { -+ ss << *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kS16: -+ { -+ ss << *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kS32: -+ { -+ ss << *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kS64: -+ { -+ ss << *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kF16: -+ { -+ float tmp = *reinterpret_cast(bytes.data()); -+ ss << tmp; -+ } -+ break; -+ case NumericTypeID::kBF16: -+ { -+ float tmp = *reinterpret_cast(bytes.data()); -+ ss << tmp; -+ } -+ break; -+ case NumericTypeID::kTF32: -+ { -+ float tmp = *reinterpret_cast(bytes.data()); -+ ss << tmp; -+ } -+ break; -+ case NumericTypeID::kF32: -+ { -+ ss << *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kF64: -+ { -+ ss << *reinterpret_cast(bytes.data()); -+ } -+ break; -+ case NumericTypeID::kCF16: -+ { -+ cutlass::complex const *x = -+ reinterpret_cast const *>(bytes.data()); -+ -+ ss << float(x->real()); -+ -+ if (x->imag() != cutlass::half_t()) { -+ ss << "+i" << float(x->imag()); -+ } -+ } -+ break; -+ case NumericTypeID::kCBF16: -+ { -+ cutlass::complex const *x = -+ reinterpret_cast const *>(bytes.data()); -+ -+ ss << float(x->real()); -+ -+ if (x->imag() != cutlass::bfloat16_t()) { -+ ss << "+i" << float(x->imag()); -+ } -+ } -+ break; -+ case NumericTypeID::kCF32: -+ { -+ cutlass::complex const * x = reinterpret_cast const *>(bytes.data()); -+ -+ ss << x->real(); -+ -+ if (x->imag() != float()) { -+ ss << "+i" << x->imag(); -+ } -+ } -+ break; -+ case NumericTypeID::kCTF32: -+ { -+ cutlass::complex const * x = reinterpret_cast const *>(bytes.data()); -+ -+ ss << float(x->real()); -+ -+ if (x->imag() != tfloat32_t()) { -+ ss << "+i" << float(x->imag()); -+ } -+ } -+ break; -+ case NumericTypeID::kCF64: -+ { -+ cutlass::complex const * x = reinterpret_cast const *>(bytes.data()); -+ -+ ss << x->real(); -+ -+ if (x->imag() != double()) { -+ ss << "+i" << x->imag(); -+ } -+ } -+ break; -+ default: -+ return ""; -+ } -+ -+ return ss.str(); -+} -+ -+/// Casts from a signed int64 to the destination type. Returns true if successful. -+bool cast_from_int64(std::vector &bytes, NumericTypeID type, int64_t src) { -+ int size_bytes = sizeof_bits(type) / 8; -+ if (!size_bytes) { -+ return false; -+ } -+ -+ bytes.resize(size_bytes, 0); -+ -+ switch (type) { -+ case NumericTypeID::kU8: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kU16: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kU32: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kU64: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kS8: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kS16: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kS32: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kS64: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kF16: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(float(src)); -+ } -+ break; -+ case NumericTypeID::kBF16: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(float(src)); -+ } -+ break; -+ case NumericTypeID::kTF32: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(float(src)); -+ } -+ break; -+ case NumericTypeID::kF32: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kF64: -+ { -+ *reinterpret_cast(bytes.data()) = double(src); -+ } -+ break; -+ case NumericTypeID::kCF16: -+ { -+ cutlass::complex *x = reinterpret_cast *>(bytes.data()); -+ x->real() = static_cast(float(src)); -+ x->imag() = static_cast(float(0)); -+ } -+ break; -+ case NumericTypeID::kCF32: -+ { -+ *reinterpret_cast*>(bytes.data()) = cutlass::complex(float(src), float(0)); -+ } -+ break; -+ case NumericTypeID::kCF64: -+ { -+ *reinterpret_cast*>(bytes.data()) = cutlass::complex(double(src), double(0)); -+ } -+ break; -+ default: -+ return false; -+ } -+ -+ return true; -+ -+} -+ -+/// Casts from an unsigned int64 to the destination type. Returns true if successful. -+bool cast_from_uint64(std::vector &bytes, NumericTypeID type, uint64_t src) { -+ int size_bytes = sizeof_bits(type) / 8; -+ if (!size_bytes) { -+ return false; -+ } -+ -+ bytes.resize(size_bytes, 0); -+ -+ switch (type) { -+ case NumericTypeID::kU8: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kU16: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kU32: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kU64: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kS8: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kS16: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kS32: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kS64: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kF16: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(float(src)); -+ } -+ break; -+ case NumericTypeID::kBF16: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(float(src)); -+ } -+ break; -+ case NumericTypeID::kTF32: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(float(src)); -+ } -+ break; -+ case NumericTypeID::kF32: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kF64: -+ { -+ *reinterpret_cast(bytes.data()) = double(src); -+ } -+ break; -+ case NumericTypeID::kCF16: -+ { -+ cutlass::complex *x = reinterpret_cast *>(bytes.data()); -+ x->real() = static_cast(float(src)); -+ x->imag() = static_cast(float(0)); -+ } -+ break; -+ case NumericTypeID::kCF32: -+ { -+ *reinterpret_cast*>(bytes.data()) = std::complex(float(src), float(0)); -+ } -+ break; -+ case NumericTypeID::kCF64: -+ { -+ *reinterpret_cast*>(bytes.data()) = std::complex(double(src), double(0)); -+ } -+ break; -+ default: -+ return false; -+ } -+ -+ return true; -+ -+} -+ -+/// Lexical cast a string to a byte array. Returns true if cast is successful or false if invalid. -+bool cast_from_double(std::vector &bytes, NumericTypeID type, double src) { -+ -+ int size_bytes = sizeof_bits(type) / 8; -+ if (!size_bytes) { -+ return false; -+ } -+ -+ bytes.resize(size_bytes, 0); -+ -+ switch (type) { -+ case NumericTypeID::kU8: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kU16: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kU32: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kU64: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kS8: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kS16: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kS32: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kS64: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kF16: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(float(src)); -+ } -+ break; -+ case NumericTypeID::kBF16: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(float(src)); -+ } -+ break; -+ case NumericTypeID::kTF32: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(float(src)); -+ } -+ break; -+ case NumericTypeID::kF32: -+ { -+ *reinterpret_cast(bytes.data()) = static_cast(src); -+ } -+ break; -+ case NumericTypeID::kF64: -+ { -+ *reinterpret_cast(bytes.data()) = src; -+ } -+ break; -+ case NumericTypeID::kCF16: -+ { -+ cutlass::complex *x = reinterpret_cast *>(bytes.data()); -+ x->real() = static_cast(float(src)); -+ x->imag() = static_cast(float(0)); -+ } -+ break; -+ case NumericTypeID::kCBF16: -+ { -+ cutlass::complex *x = reinterpret_cast *>(bytes.data()); -+ x->real() = static_cast(bfloat16_t(src)); -+ x->imag() = static_cast(bfloat16_t(0)); -+ } -+ break; -+ case NumericTypeID::kCF32: -+ { -+ *reinterpret_cast*>(bytes.data()) = cutlass::complex(float(src), float()); -+ } -+ break; -+ case NumericTypeID::kCTF32: -+ { -+ *reinterpret_cast*>(bytes.data()) = cutlass::complex(tfloat32_t(src), tfloat32_t()); -+ } -+ break; -+ case NumericTypeID::kCF64: -+ { -+ *reinterpret_cast*>(bytes.data()) = cutlass::complex(src, double()); -+ } -+ break; -+ default: -+ return false; -+ } -+ -+ return true; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace library -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/profiler/src/conv2d_operation_profiler.cu b/3rdparty/cutlass/tools/profiler/src/conv2d_operation_profiler.cu -new file mode 100644 -index 0000000..0693058 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/conv2d_operation_profiler.cu -@@ -0,0 +1,1488 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Convolution 2D profiling -+*/ -+ -+#include -+#include -+#include -+#include -+ -+#include "cutlass/core_io.h" -+ -+#include "conv2d_operation_profiler.h" -+#include "gpu_timer.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+using namespace cutlass::library; -+ -+namespace cutlass { -+namespace profiler { -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Ctor -+Conv2dOperationProfiler::Conv2dOperationProfiler(Options const &options): -+ OperationProfiler( -+ options, -+ library::OperationKind::kConv2d, -+ { -+ {ArgumentTypeID::kEnumerated, {"conv_kind"}, "Convolutional operator (fprop, dgrad, wgrad)"}, -+ {ArgumentTypeID::kInteger, {"n", "input_n"}, "Input N dimension of the Conv2d problem space"}, -+ {ArgumentTypeID::kInteger, {"h", "input_h"}, "Input H dimension of the Conv2d problem space"}, -+ {ArgumentTypeID::kInteger, {"w", "input_w"}, "Input W dimension of the Conv2d problem space"}, -+ {ArgumentTypeID::kInteger, {"c", "input_c"}, "Input C dimension of the Conv2d problem space"}, -+ {ArgumentTypeID::kInteger, {"k", "filter_k"}, "Filter K dimension of the Conv2d problem space"}, -+ {ArgumentTypeID::kInteger, {"r", "filter_r"}, "Filter R dimension of the Conv2d problem space"}, -+ {ArgumentTypeID::kInteger, {"s", "filter_s"}, "Filter S dimension of the Conv2d problem space"}, -+ {ArgumentTypeID::kInteger, {"p", "output_p"}, "Output P dimension of the Conv2d problem space"}, -+ {ArgumentTypeID::kInteger, {"q", "output_q"}, "Output Q dimension of the Conv2d problem space"}, -+ {ArgumentTypeID::kInteger, {"g", "groups"}, "Number of convolution groups"}, -+ {ArgumentTypeID::kInteger, {"pad_h"}, "Padding in H direction"}, -+ {ArgumentTypeID::kInteger, {"pad_w"}, "Padding in W direction"}, -+ {ArgumentTypeID::kInteger, {"stride_h"}, "Stride in H direction"}, -+ {ArgumentTypeID::kInteger, {"stride_w"}, "Stride in W direction"}, -+ {ArgumentTypeID::kInteger, {"dilation_h"}, "Dilation in H direction"}, -+ {ArgumentTypeID::kInteger, {"dilation_w"}, "Dilation in W direction"}, -+ {ArgumentTypeID::kTensor, {"Activation"}, "Tensor storing the Activation operand"}, -+ {ArgumentTypeID::kTensor, {"Filter"}, "Tensor storing the Filter operand"}, -+ {ArgumentTypeID::kTensor, {"Output"}, "Tensor storing the Output operand"}, -+ {ArgumentTypeID::kEnumerated, {"conv_mode"}, "Convolution filter mode (conv, cross)"}, -+ {ArgumentTypeID::kEnumerated, {"iterator_algorithm", "iterator_algo"}, "Convolution iterator algorithm (analytic, optimized)"}, -+ {ArgumentTypeID::kScalar, {"alpha", "epilogue::alpha"}, "Epilogue scalar alpha"}, -+ {ArgumentTypeID::kScalar, {"beta", "epilogue::beta"}, "Epilogue scalar beta"}, -+ {ArgumentTypeID::kEnumerated, {"split_k_mode", "split-k-mode"}, "SplitK mode for serial or parallel reduction (serial, parallel)"}, -+ {ArgumentTypeID::kInteger, {"split_k_slices", "split-k-slices"}, "Number of partitions of K dimension"}, -+ {ArgumentTypeID::kEnumerated, {"eq_gemm_provider", "eq-gemm-provider"}, "Enable profiling equivalent gemm by the following providers (cutlass)"}, -+ }, -+ { library::Provider::kReferenceDevice, library::Provider::kReferenceHost, library::Provider::kCUDNN } -+ ) { -+ -+ description_ = " Conv2d operation. Output(Tensor4D) = alpha * Input(Tensor4D) * Filter(Tensor4D) + beta * Input(Tensor4D)"; -+ -+} -+ -+/// Destructor -+Conv2dOperationProfiler::~Conv2dOperationProfiler() { -+ -+} -+ -+ -+/// Prints usage statement for the math function -+void Conv2dOperationProfiler::print_usage(std::ostream &out) const { -+ out << "Conv2d" << "\n\n"; -+ -+ OperationProfiler::print_usage(out); -+} -+ -+/// Prints examples -+void Conv2dOperationProfiler::print_examples(std::ostream &out) const { -+ -+ out << "\nExamples:\n\n" -+ << "Profile a particular convolution (specify all the convolution parameters):\n" -+ << " $ cutlass_profiler --operation=Conv2d" -+ " --Activation=f16:nhwc --Filter=f16:nhwc --Output=f16 --accumulator-type=f32" -+ " --n=32 --h=14 --w=14 --c=8 --k=64 --r=3 --s=3" -+ " --pad_h=1 --pad_w=1" -+ " --stride_h=1 --stride_w=1" -+ " --dilation_h=1 --dilation_w=1\n\n"; -+} -+ -+#if 0 -+// used this for debugging -+static std::string byte_string(std::vector const &bytes) { -+ std::stringstream ss; -+ -+ ss << "0x"; -+ -+ for (size_t idx = bytes.size(); idx > 0; --idx) { -+ ss << std::hex << std::setw(2) << std::setfill('0') << uint32_t(bytes.at(idx - 1)); -+ } -+ -+ return ss.str(); -+} -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Total number of bytes loaded -+int64_t Conv2dOperationProfiler::Conv2dProblem::bytes( -+ library::ConvDescription const &operation_desc) const { -+ -+ cutlass::gemm::GemmCoord mnk = eq_gemm_size(operation_desc.conv_kind); -+ -+ // Input bytes read and Output bytes written for the gemm problem -+ int64_t bytes_ = -+ int64_t(library::sizeof_bits(operation_desc.A.element) * mnk.m() / 8) * mnk.k() + -+ int64_t(library::sizeof_bits(operation_desc.B.element) * mnk.n() / 8) * mnk.k() + -+ int64_t(library::sizeof_bits(operation_desc.C.element) * mnk.m() / 8) * mnk.n(); -+ -+ // Set is_beta_zero true if beta is zero -+ bool is_beta_zero = std::all_of(beta.begin(), beta.end(), [](uint8_t i) { return i==0; }); -+ -+ // Output bytes read for the gemm problem for non-zero beta values -+ if (!is_beta_zero) { -+ bytes_ += int64_t(library::sizeof_bits(operation_desc.C.element) * mnk.m() / 8) * mnk.n(); -+ } -+ -+ return bytes_; -+} -+ -+/// Total number of flops computed -+int64_t Conv2dOperationProfiler::Conv2dProblem::flops( -+ library::ConvDescription const &operation_desc) const { -+ -+ cutlass::gemm::GemmCoord mnk = eq_gemm_size(operation_desc.conv_kind); -+ -+ int64_t flops_mainloop_ = int64_t(mnk.m()) * mnk.n() * mnk.k() * 2; -+ int64_t flops_epilogue_ = int64_t(mnk.m()) * int64_t(mnk.n()) * 2; -+ -+ // Adjust mainloop flop for dgrad strided -+ if (operation_desc.conv_kind == library::ConvKind::kDgrad) { -+ flops_mainloop_ = flops_mainloop_ / (stride_h * stride_w); -+ } -+ int64_t flops_total_ = flops_mainloop_ + flops_epilogue_; -+ -+ //complex-valued support -+ switch (operation_desc.tile_description.math_instruction.math_operation) { -+ case library::MathOperationID::kMultiplyAddComplex: -+ flops_total_ *=4; -+ break; -+ -+ default: break; -+ } -+ -+ return flops_total_; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Extracts the problem dimensions -+Status Conv2dOperationProfiler::initialize_configuration( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ library::ConvDescription const &operation_desc = -+ static_cast(operation->description()); -+ -+ if (!arg_as_int(problem_.n, "n", problem_space, problem)) { -+ // default value -+ problem_.n = 1; -+ } -+ -+ if (!arg_as_int(problem_.h, "h", problem_space, problem)) { -+ // default value -+ problem_.h = 16; -+ } -+ -+ if (!arg_as_int(problem_.w, "w", problem_space, problem)) { -+ // default value -+ problem_.w = 16; -+ } -+ -+ if (!arg_as_int(problem_.c, "c", problem_space, problem)) { -+ // default value -+ problem_.c = 64; -+ } -+ -+ if (!arg_as_int(problem_.k, "k", problem_space, problem)) { -+ // default value -+ problem_.k = 64; -+ } -+ -+ if (!arg_as_int(problem_.r, "r", problem_space, problem)) { -+ // default value -+ problem_.r = 3; -+ } -+ -+ if (!arg_as_int(problem_.s, "s", problem_space, problem)) { -+ // default value -+ problem_.s = 3; -+ } -+ -+ if (!arg_as_int(problem_.groups, "g", problem_space, problem)) { -+ // default value -+ problem_.groups = 1; -+ } -+ -+ if (!arg_as_int(problem_.pad_h, "pad_h", problem_space, problem)) { -+ // default value -+ problem_.pad_h = 1; -+ } -+ -+ if (!arg_as_int(problem_.pad_w, "pad_w", problem_space, problem)) { -+ // default value -+ problem_.pad_w = 1; -+ } -+ -+ if (!arg_as_int(problem_.stride_h, "stride_h", problem_space, problem)) { -+ // default value -+ problem_.stride_h = 1; -+ } -+ -+ if (!arg_as_int(problem_.stride_w, "stride_w", problem_space, problem)) { -+ // default value -+ problem_.stride_w = 1; -+ } -+ -+ if (!arg_as_int(problem_.dilation_h, "dilation_h", problem_space, problem)) { -+ // default value -+ problem_.dilation_h = 1; -+ } -+ -+ if (!arg_as_int(problem_.dilation_w, "dilation_w", problem_space, problem)) { -+ // default value -+ problem_.dilation_w = 1; -+ } -+ -+ //////////////////////// Convolution output dimensions p and q //////////////////////// -+ // Cutlass convolutions support arbitrary output sizes and not constriant by // -+ // input, filter, padding, striding, dilation sizes. // -+ // cuDNN sets the output dimensions (p, q) using following equations: // -+ // // -+ // output = div_up(input + 2 * pad - ((filter - 1) * dilation + 1) + 1, stride) // -+ // where; div_up(a, b) : (a - 1)/b + 1 // -+ // // -+ // Thus, when output p and q dimensions are unspecified by the user // -+ // cutlass profiler sets p and q which are cuDNN compliant. // -+ // // -+ //////////////////////////////////////////////////////////////////////////////////////// -+ // set convolution output p -+ if (!arg_as_int(problem_.p, "p", problem_space, problem)) { -+ // default value (set using cudnn formula for output height, when p is not provided) -+ problem_.p = ( -+ problem_.h + -+ 2 * problem_.pad_h - -+ ((problem_.r - 1) * problem_.dilation_h + 1) -+ ) / (problem_.stride_h) -+ + 1; -+ } -+ -+ // set convolution output q -+ if (!arg_as_int(problem_.q, "q", problem_space, problem)) { -+ // default value (set using cudnn formula for output width, when q is not provided) -+ problem_.q = ( -+ problem_.w + -+ 2 * problem_.pad_w - -+ ((problem_.s - 1) * problem_.dilation_w + 1) -+ ) / (problem_.stride_w) -+ + 1; -+ } -+ ///////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+ if (!arg_as_SplitKModeID(problem_.split_k_mode, "split_k_mode", problem_space, problem)) { -+ // default value -+ problem_.split_k_mode = library::SplitKMode::kSerial; -+ } -+ -+ if (!arg_as_int(problem_.split_k_slices, "split_k_slices", problem_space, problem)) { -+ // default value -+ problem_.split_k_slices = 1; -+ } -+ -+ if (!arg_as_ConvModeID(problem_.conv_mode, "conv_mode", problem_space, problem)) { -+ // default value -+ problem_.conv_mode = library::ConvModeID::kCrossCorrelation; -+ } -+ -+ if (!arg_as_ProviderID(problem_.eq_gemm_provider, "eq_gemm_provider", problem_space, problem)) { -+ // default value -+ problem_.eq_gemm_provider = library::Provider::kNone; -+ } -+ -+ if (!conv_kind_satisfies(operation_desc.conv_kind, "conv_kind", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!iterator_algorithm_satisfies(operation_desc.iterator_algorithm, "iterator_algorithm", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.activation(), "Activation", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.filter(), "Filter", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.output(), "Output", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!arg_as_scalar( -+ problem_.alpha, -+ operation_desc.element_epilogue, -+ "alpha", -+ problem_space, -+ problem)) { -+ -+ if (!cast_from_double(problem_.alpha, operation_desc.element_epilogue, 1)) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ if (!arg_as_scalar( -+ problem_.beta, -+ operation_desc.element_epilogue, -+ "beta", -+ problem_space, -+ problem)) { -+ -+ if (!cast_from_double(problem_.beta, operation_desc.element_epilogue, 0)) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ // initialize library::Conv2dConfiguration -+ conv_workspace_.configuration.problem_size = conv::Conv2dProblemSize( -+ int(problem_.n), -+ int(problem_.h), -+ int(problem_.w), -+ int(problem_.c), -+ int(problem_.k), -+ int(problem_.r), -+ int(problem_.s), -+ int(problem_.p), -+ int(problem_.q), -+ int(problem_.pad_h), -+ int(problem_.pad_w), -+ int(problem_.stride_h), -+ int(problem_.stride_w), -+ int(problem_.dilation_h), -+ int(problem_.dilation_w), -+ static_cast(static_cast(problem_.conv_mode)), -+ int(problem_.split_k_slices), -+ int(problem_.groups) -+ ); -+ -+ conv_workspace_.configuration.split_k_mode = static_cast(static_cast(problem_.split_k_mode)); -+ -+ conv_workspace_.set_stride_vector( -+ problem_, operation_desc.conv_kind, operation_desc.A.layout, -+ operation_desc.B.layout, operation_desc.C.layout); -+ -+ // initialize library::ConvArguments -+ conv_workspace_.arguments.A = nullptr; -+ conv_workspace_.arguments.B = nullptr; -+ conv_workspace_.arguments.C = nullptr; -+ conv_workspace_.arguments.D = nullptr; -+ conv_workspace_.arguments.alpha = problem_.alpha.data(); -+ conv_workspace_.arguments.beta = problem_.beta.data(); -+ conv_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ // initialize reduction operation for parallel splitKMode -+ if(conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ if(!initialize_reduction_configuration_(options, report, device_context, operation, problem_space, problem)) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ initialize_result_(this->model_result_, options, operation_desc, problem_space); -+ -+ return operation->can_implement(&conv_workspace_.configuration, &conv_workspace_.arguments); -+} -+ -+/// Initializes the performance result -+void Conv2dOperationProfiler::initialize_result_( -+ PerformanceResult &result, -+ Options const &options, -+ library::ConvDescription const &operation_desc, -+ ProblemSpace const &problem_space) { -+ -+ result.provider = library::Provider::kCUTLASS; -+ result.disposition = Disposition::kNotRun; -+ result.status = Status::kSuccess; -+ result.operation_name = operation_desc.name; -+ -+ result.arguments.resize(problem_space.rank()); -+ -+ set_argument(result, "Activation", problem_space, -+ std::string(library::to_string(operation_desc.activation().element)) -+ + ":" + library::to_string(operation_desc.activation().layout)); -+ -+ set_argument(result, "Filter", problem_space, -+ std::string(library::to_string(operation_desc.filter().element)) -+ + ":" + library::to_string(operation_desc.filter().layout)); -+ -+ set_argument(result, "Output", problem_space, -+ std::string(library::to_string(operation_desc.output().element)) -+ + ":" + library::to_string(operation_desc.output().layout)); -+ -+ set_argument(result, "conv_kind", problem_space, library::to_string(operation_desc.conv_kind)); -+ -+ set_argument(result, "iterator_algorithm", problem_space, std::string(library::to_string(operation_desc.iterator_algorithm))); -+ -+ set_argument(result, "n", problem_space, problem_.n); -+ set_argument(result, "h", problem_space, problem_.h); -+ set_argument(result, "w", problem_space, problem_.w); -+ set_argument(result, "c", problem_space, problem_.c); -+ -+ set_argument(result, "k", problem_space, problem_.k); -+ set_argument(result, "r", problem_space, problem_.r); -+ set_argument(result, "s", problem_space, problem_.s); -+ -+ set_argument(result, "p", problem_space, problem_.p); -+ set_argument(result, "q", problem_space, problem_.q); -+ -+ set_argument(result, "g", problem_space, problem_.groups); -+ -+ set_argument(result, "pad_h", problem_space, problem_.pad_h); -+ set_argument(result, "pad_w", problem_space, problem_.pad_w); -+ -+ set_argument(result, "stride_h", problem_space, problem_.stride_h); -+ set_argument(result, "stride_w", problem_space, problem_.stride_w); -+ -+ set_argument(result, "dilation_h", problem_space, problem_.dilation_h); -+ set_argument(result, "dilation_w", problem_space, problem_.dilation_w); -+ -+ set_argument(result, "split_k_mode", problem_space, -+ std::string(library::to_string(problem_.split_k_mode))); -+ set_argument(result, "split_k_slices", problem_space, problem_.split_k_slices); -+ -+ set_argument(result, "conv_mode", problem_space, -+ std::string(library::to_string(problem_.conv_mode))); -+ -+ set_argument(result, "alpha", problem_space, -+ library::lexical_cast(problem_.alpha, operation_desc.element_epilogue)); -+ -+ set_argument(result, "beta", problem_space, -+ library::lexical_cast(problem_.beta, operation_desc.element_epilogue)); -+ -+ set_argument(result, "eq_gemm_provider", problem_space, -+ std::string(library::to_string(problem_.eq_gemm_provider))); -+ -+ OperationProfiler::initialize_result_(result, operation_desc, problem_space); -+ -+ // Bytes of activation, filter, and output tensors -+ int64_t activation_bytes = int64_t(library::sizeof_bits(operation_desc.activation().element) / 8) * -+ conv_workspace_.configuration.problem_size.activation_size(); -+ -+ int64_t filter_bytes = int64_t(library::sizeof_bits(operation_desc.filter().element) / 8) * -+ conv_workspace_.configuration.problem_size.filter_size(); -+ -+ int64_t output_bytes = int64_t(library::sizeof_bits(operation_desc.output().element) / 8) * -+ conv_workspace_.configuration.problem_size.output_size(); -+ -+ // Bytes of activation, filter, and output tensors -+ result.bytes = problem_.bytes(operation_desc); -+ -+ // Theoritical flops required for the computation -+ result.flops = problem_.flops(operation_desc); -+ -+ // Measured runtime -+ result.runtime = 0; -+ -+} -+ -+/// Initialize reduction problem dimenstions and library::Operation -+bool Conv2dOperationProfiler::initialize_reduction_configuration_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ library::ConvDescription const &conv_desc = -+ static_cast(operation->description()); -+ -+ library::ConvKind const &conv_kind = conv_desc.conv_kind; -+ -+ if (!cast_from_double(problem_.alpha_one, conv_desc.element_epilogue, 1)) { -+ return false; -+ } -+ -+ if (!cast_from_double(problem_.beta_zero, conv_desc.element_epilogue, 0)) { -+ return false; -+ } -+ -+ /// This chooses the appropriate stride element of the row-major C tensor. -+ int const & tensor_c_stride_idx = (conv_kind == library::ConvKind::kWgrad ? 2 : 0); -+ -+ /// intialize library::ReductionConfiguration -+ conv_workspace_.reduction_configuration.problem_size = problem_.eq_gemm_size(conv_kind).mn(); -+ conv_workspace_.reduction_configuration.partitions = int(problem_.split_k_slices); -+ conv_workspace_.reduction_configuration.partition_stride = problem_.eq_gemm_size(conv_kind).mn().product(); -+ conv_workspace_.reduction_configuration.ldw = -+ conv_workspace_.configuration.stride_c[tensor_c_stride_idx]; -+ conv_workspace_.reduction_configuration.lds = -+ conv_workspace_.configuration.stride_c[tensor_c_stride_idx]; -+ conv_workspace_.reduction_configuration.ldd = -+ conv_workspace_.configuration.stride_c[tensor_c_stride_idx]; -+ -+ // find reduction operation -+ library::ReductionFunctionalKey reduction_key( -+ library::Provider::kCUTLASS, -+ conv_desc.tile_description.math_instruction.element_accumulator, // element workspace -+ conv_desc.tile_description.math_instruction.element_accumulator, // element accumulator -+ conv_desc.C.element, // element output -+ conv_desc.element_epilogue // element compute -+ ); -+ -+#if 0// debug print to check which reduction instance is selected -+ std::cout << reduction_key << "\n"; -+#endif -+ auto reduction_it = Singleton::get().operation_table.reduction_operations.find(reduction_key); -+ -+ if(reduction_it == Singleton::get().operation_table.reduction_operations.end()) { -+ -+ return false; -+ } -+ -+ // initialize reduction operation required for parallel split-k conv2d operator -+ reduction_op_ = reduction_it->second; -+ -+ // reduction operation found and initialized -+ return true; -+} -+ -+ -+/// Initializes workspace -+Status Conv2dOperationProfiler::initialize_workspace( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ // initialize conv2d underlying operation to handle parallel reduction -+ library::Operation const* underlying_operation = operation; -+ -+ if(conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ if (!(underlying_operation = library::find_conv_operation_for_parallel_reduction(operation))) { -+ return Status::kErrorNotSupported; -+ } -+ } -+ -+ library::ConvDescription const &operation_desc = -+ static_cast(underlying_operation->description()); -+ -+ // Compute the number of copies of the problem to avoid L2 camping. -+ if (!options.profiling.workspace_count) { -+ int64_t bytes = problem_.bytes(operation_desc); -+ if (bytes < 3 * int64_t(options.device.properties.l2CacheSize)) { -+ conv_workspace_.problem_count = -+ 1 + int((3 * int64_t(options.device.properties.l2CacheSize)) / bytes); -+ } -+ else { -+ conv_workspace_.problem_count = 1; -+ } -+ } -+ else { -+ conv_workspace_.problem_count = options.profiling.workspace_count; -+ } -+ -+ -+ if (options.execution_mode != ExecutionMode::kDryRun) { -+ -+ conv_workspace_.A = device_context.allocate_tensor( -+ options, -+ "A", -+ operation_desc.A.element, -+ operation_desc.A.layout, -+ problem_.extent_a(operation_desc.conv_kind), -+ conv_workspace_.configuration.stride_a, -+ conv_workspace_.problem_count -+ ); -+ -+ conv_workspace_.B = device_context.allocate_tensor( -+ options, -+ "B", -+ operation_desc.B.element, -+ operation_desc.B.layout, -+ problem_.extent_b(operation_desc.conv_kind), -+ conv_workspace_.configuration.stride_b, -+ conv_workspace_.problem_count -+ ); -+ -+ if(problem_.groups == problem_.c && problem_.groups == problem_.k){ -+ // Depthwise direct conv kernel needs reorder the filter. -+ conv_workspace_.reordered_B = device_context.allocate_tensor( -+ options, -+ "B", -+ operation_desc.B.element, -+ operation_desc.B.layout, -+ problem_.extent_b(operation_desc.conv_kind), -+ conv_workspace_.configuration.stride_b, -+ conv_workspace_.problem_count -+ ); -+ } -+ -+ conv_workspace_.C = device_context.allocate_tensor( -+ options, -+ "C", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ problem_.extent_c(operation_desc.conv_kind), -+ conv_workspace_.configuration.stride_c, -+ conv_workspace_.problem_count -+ ); -+ -+ conv_workspace_.Computed = device_context.allocate_tensor( -+ "D", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ problem_.extent_c(operation_desc.conv_kind), -+ conv_workspace_.configuration.stride_c, -+ conv_workspace_.problem_count -+ ); -+ -+ conv_workspace_.Reference = device_context.allocate_tensor( -+ "Reference", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ problem_.extent_c(operation_desc.conv_kind), -+ conv_workspace_.configuration.stride_c, -+ conv_workspace_.problem_count -+ ); -+ } -+ -+ // -+ // Initialize the CUTLASS operation -+ // -+ Status status = Status::kSuccess; -+ -+ if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ -+ if (options.execution_mode != ExecutionMode::kDryRun) { -+ -+ uint64_t workspace_size = underlying_operation->get_host_workspace_size(&conv_workspace_.configuration); -+ conv_workspace_.host_workspace.resize(workspace_size, 0); -+ -+ workspace_size = underlying_operation->get_device_workspace_size(&conv_workspace_.configuration); -+ conv_workspace_.device_workspace.reset(library::NumericTypeID::kU8, workspace_size); -+ -+ status = underlying_operation->initialize( -+ &conv_workspace_.configuration, -+ conv_workspace_.host_workspace.data(), -+ conv_workspace_.device_workspace.data()); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ if (conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ workspace_size = reduction_op_->get_host_workspace_size(&conv_workspace_.reduction_configuration); -+ conv_workspace_.reduction_host_workspace.resize(workspace_size, 0); -+ -+ status = reduction_op_->initialize( -+ &conv_workspace_.reduction_configuration, -+ conv_workspace_.reduction_host_workspace.data(), -+ nullptr); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ } -+ } -+ -+ // -+ // If CUTLASS is enabled, generate a result for it -+ // -+ results_.push_back(model_result_); -+ results_.back().provider = library::Provider::kCUTLASS; -+ results_.back().op_kind = library::OperationKind::kConv2d; -+ results_.back().disposition = Disposition::kNotRun; -+ -+ for(auto provider : verification_providers_) { -+ results_.back().verification_map[provider] = Disposition::kNotRun; -+ } -+ } -+ -+ return status; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Verifies CUTLASS against references -+bool Conv2dOperationProfiler::verify_cutlass( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ if (!options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ return true; -+ } -+ -+ if (options.execution_mode == ExecutionMode::kDryRun) { -+ return true; -+ } -+ -+ cudaError_t result; -+ -+ // Initialize structure containing Conv2d arguments -+ conv_workspace_.arguments.A = conv_workspace_.A->data(); -+ conv_workspace_.arguments.B = conv_workspace_.B->data(); -+ conv_workspace_.arguments.C = conv_workspace_.C->data(); -+ conv_workspace_.arguments.D = conv_workspace_.Computed->data(); -+ conv_workspace_.arguments.alpha = problem_.alpha.data(); -+ conv_workspace_.arguments.beta = problem_.beta.data(); -+ conv_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ if (conv_workspace_.reordered_B != nullptr){ -+ conv_workspace_.arguments.reordered_B = conv_workspace_.reordered_B->data(); -+ }else{ -+ conv_workspace_.arguments.reordered_B = nullptr; -+ } -+ -+ conv_workspace_.Computed->copy_from_device(conv_workspace_.C->data()); -+ -+ if (conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ // update library::ConvArguments for parallel split-k reduction -+ conv_workspace_.arguments.D = conv_workspace_.device_workspace.data(); -+ conv_workspace_.arguments.alpha = problem_.alpha_one.data(); -+ conv_workspace_.arguments.beta = problem_.beta_zero.data(); -+ -+ /// intialize library::ReductionArguments -+ conv_workspace_.reduction_arguments.workspace = conv_workspace_.device_workspace.data(); -+ conv_workspace_.reduction_arguments.source = conv_workspace_.C->data(); -+ conv_workspace_.reduction_arguments.destination = conv_workspace_.Computed->data(); -+ conv_workspace_.reduction_arguments.alpha = problem_.alpha.data(); -+ conv_workspace_.reduction_arguments.beta = problem_.beta.data(); -+ conv_workspace_.reduction_arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ } -+ -+ // -+ // Run the CUTLASS operation -+ // -+ // initialize conv2d underlying operation to handle parallel reduction -+ library::Operation const* underlying_operation = operation; -+ -+ if(conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ if (!(underlying_operation = library::find_conv_operation_for_parallel_reduction(operation))) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ } -+ -+#if 0 -+ std::cout << "profiling : " << std::endl -+ << "conv2d : " << operation->description().name << std::endl -+ << "underlying conv2d : " << underlying_operation->description().name << std::endl -+ << "reduction : " << reduction_op_->description().name << std::endl; -+#endif -+ -+ // run cutlass conv2d operation -+ results_.back().status = underlying_operation->run( -+ &conv_workspace_.arguments, -+ conv_workspace_.host_workspace.data(), -+ conv_workspace_.device_workspace.data()); -+ -+ if (results_.back().status != Status::kSuccess) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ -+ // Run parallel reduction kernel for parallel split_k_mode -+ if (conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ -+ results_.back().status = reduction_op_->run( -+ &conv_workspace_.reduction_arguments, -+ conv_workspace_.reduction_host_workspace.data(), -+ nullptr); -+ -+ if (results_.back().status != Status::kSuccess) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ -+ } -+ -+ // Synchronize before running device reference -+ result = cudaDeviceSynchronize(); -+ if (result != cudaSuccess) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ -+ // CUTLASS op ran the but not yet verified against any verification provider -+ results_.back().disposition = Disposition::kNotVerified; -+ -+ // -+ // Run verification providers -+ // -+ -+ if (options.verification.enabled) { -+ -+#if CUTLASS_ENABLE_CUDNN -+ // Run verification cudnn reference -+ if (options.verification.provider_enabled(library::Provider::kCUDNN)) { -+ -+ // Guard against unsupported cases -+ auto const & conv_desc = static_cast(operation->description()); -+ -+ Status status = cudnn_satisfies(conv_desc, conv_workspace_.configuration); -+ -+ // Initialize reference data to the source data -+ conv_workspace_.Reference->copy_from_device(conv_workspace_.C->data()); -+ -+ if (status == Status::kSuccess) { -+ // call cudnn verification if supported -+ verify_with_cudnn_( -+ options, -+ report, -+ device_context, -+ operation, -+ problem_space, -+ problem); -+ } -+ -+ else if (status == Status::kErrorInvalidProblem) { -+ results_.back().verification_map[library::Provider::kCUDNN] = Disposition::kInvalidProblem; -+ } -+ -+ else { -+ // set verification map for cudnn to not supported -+ results_.back().verification_map[library::Provider::kCUDNN] = Disposition::kNotSupported; -+ } -+ } -+#endif // #if CUTLASS_ENABLE_CUDNN -+ -+ // Run verification device reference -+ if (options.verification.provider_enabled(library::Provider::kReferenceDevice)) { -+ -+ // Restore reference data back to initial source data -+ conv_workspace_.Reference->copy_from_device(conv_workspace_.C->data()); -+ -+ verify_with_device_reference_( -+ options, -+ report, -+ device_context, -+ operation, -+ problem_space, -+ problem); -+ } -+ -+ // Run verification host reference -+ if (options.verification.provider_enabled(library::Provider::kReferenceHost)) { -+ -+ // Restore reference data back to initial source data -+ conv_workspace_.Reference->copy_from_device(conv_workspace_.C->data()); -+ -+ verify_with_host_reference_( -+ options, -+ report, -+ device_context, -+ operation, -+ problem_space, -+ problem); -+ } -+ -+ // Update disposition to worst case verification outcome among all -+ // verification providers which are supported -+ bool is_any_verification_run_passed = false; -+ for(auto &m : results_.back().verification_map) { -+ if(m.second == Disposition::kFailed || m.second == Disposition::kIncorrect) { -+ results_.back().disposition = m.second; -+ return true; -+ } -+ if(!is_any_verification_run_passed && m.second == Disposition::kPassed) { -+ is_any_verification_run_passed = true; -+ } -+ } -+ -+ if(is_any_verification_run_passed) { -+ results_.back().disposition = Disposition::kPassed; -+ } -+ } -+ -+ // Return true means continue profiling -+ return true; -+} -+ -+ -+/// Verifies CUTLASS against host reference -+bool Conv2dOperationProfiler::verify_with_host_reference_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ Status status; -+ -+ // -+ // Find host reference operation using conv2d functional description key -+ // -+ library::OperationDescription const &desc = operation->description(); -+ -+ auto &conv_desc = static_cast(desc); -+ -+ library::ConvFunctionalKey conv2d_key( -+ library::Provider::kReferenceHost, -+ conv_desc.conv_kind, -+ conv_desc.A.element, -+ conv_desc.A.layout, -+ conv_desc.B.element, -+ conv_desc.B.layout, -+ conv_desc.C.element, -+ conv_desc.C.layout, -+ conv_desc.tile_description.math_instruction.element_accumulator, -+ conv_desc.element_epilogue); -+ -+#if 0 // debug print to check which host refererence instance is selected -+ std::cout << conv2d_key << "\n"; -+#endif -+ -+ auto operators_it = Singleton::get().operation_table.conv2d_operations.find(conv2d_key); -+ -+ if(operators_it == Singleton::get().operation_table.conv2d_operations.end()) { -+ -+ results_.back().verification_map[library::Provider::kReferenceHost] = Disposition::kNotRun; -+ return true; -+ } -+ -+ // conv2d host reference minimum cc is 0 (CPU) and no iterator algorithm -+ library::ConvPreferenceKey preference_key(0, library::IteratorAlgorithmID::kNone); -+ auto cc_it = operators_it->second.find(preference_key); -+ -+ if(cc_it == operators_it->second.end()) { -+ results_.back().verification_map[library::Provider::kReferenceHost] = Disposition::kNotRun; -+ return true; -+ } -+ -+ // host refernce has only one instances in Conv2dOperationVectorMap -+ library::Operation const *reference_op = cc_it->second[0]; -+ -+ // -+ // Copy input tensors A, B, and C from device to host buffers -+ // -+ conv_workspace_.host_tensor_a.resize(conv_workspace_.A->bytes()); -+ conv_workspace_.host_tensor_b.resize(conv_workspace_.B->bytes()); -+ conv_workspace_.host_tensor_c.resize(conv_workspace_.C->bytes()); -+ -+ conv_workspace_.A->copy_to_host(conv_workspace_.host_tensor_a.data()); -+ conv_workspace_.B->copy_to_host(conv_workspace_.host_tensor_b.data()); -+ conv_workspace_.C->copy_to_host(conv_workspace_.host_tensor_c.data()); -+ -+ // -+ // Initialize structure containing Conv2d arguments -+ // -+ conv_workspace_.arguments.A = conv_workspace_.host_tensor_a.data(); -+ conv_workspace_.arguments.B = conv_workspace_.host_tensor_b.data(); -+ conv_workspace_.arguments.C = conv_workspace_.host_tensor_c.data(); -+ conv_workspace_.arguments.D = conv_workspace_.host_tensor_c.data(); -+ -+ conv_workspace_.arguments.alpha = problem_.alpha.data(); -+ conv_workspace_.arguments.beta = problem_.beta.data(); -+ conv_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ // -+ // Intialize host reference operation -+ // -+ std::vector host_workspace_reference_op; -+ -+ uint64_t workspace_size = reference_op->get_host_workspace_size(&conv_workspace_.configuration); -+ host_workspace_reference_op.resize(workspace_size, 0); -+ -+ reference_op->initialize( -+ &conv_workspace_.configuration, -+ host_workspace_reference_op.data()); -+ -+ // -+ // Run host reference operation -+ // -+ status = reference_op->run( -+ &conv_workspace_.arguments, -+ host_workspace_reference_op.data()); -+ -+ // Handle errors -+ if (status != Status::kSuccess) { -+ results_.back().verification_map[library::Provider::kReferenceHost] = Disposition::kNotVerified; -+ return true; -+ } -+ -+ // -+ // Copy host reference output to device memory for equality check on device -+ // -+ conv_workspace_.Reference->copy_from_host(conv_workspace_.arguments.D); -+ -+ // -+ // Verify results -+ // -+ results_.back().verification_map[library::Provider::kReferenceHost] = compare_tensors( -+ options, -+ *conv_workspace_.Computed, -+ *conv_workspace_.Reference, -+ conv_workspace_.Computed->batch_stride() -+ ); -+ -+ // Save workspace if incorrect -+ if (options.verification.save_workspace == SaveWorkspace::kIncorrect && -+ results_.back().verification_map[library::Provider::kReferenceHost] == Disposition::kIncorrect) { -+ -+ save_workspace( -+ device_context, -+ options, -+ static_cast(operation->description()), -+ library::Provider::kCUTLASS, -+ library::Provider::kReferenceHost); -+ } -+ -+ // Return true means continue profiling -+ return true; -+} -+ -+ -+/// Verifies CUTLASS against host reference -+bool Conv2dOperationProfiler::verify_with_device_reference_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ Status status; -+ -+ // -+ // Find device reference operation using conv2d functional description key -+ // -+ library::OperationDescription const &desc = operation->description(); -+ -+ auto &conv_desc = static_cast(desc); -+ -+ library::ConvFunctionalKey conv2d_key( -+ library::Provider::kReferenceDevice, -+ conv_desc.conv_kind, -+ conv_desc.A.element, -+ conv_desc.A.layout, -+ conv_desc.B.element, -+ conv_desc.B.layout, -+ conv_desc.C.element, -+ conv_desc.C.layout, -+ conv_desc.tile_description.math_instruction.element_accumulator, -+ conv_desc.element_epilogue); -+ -+ auto operators_it = Singleton::get().operation_table.conv2d_operations.find(conv2d_key); -+ -+ if(operators_it == Singleton::get().operation_table.conv2d_operations.end()) { -+ -+ results_.back().verification_map[library::Provider::kReferenceDevice] = Disposition::kNotRun; -+ -+ return true; -+ } -+ -+ // conv2d device reference minimum cc is 50 and no iterator algorithm -+ library::ConvPreferenceKey preference_key(50, library::IteratorAlgorithmID::kNone); -+ auto cc_it = operators_it->second.find(preference_key); -+ -+ if(cc_it == operators_it->second.end()) { -+ results_.back().verification_map[library::Provider::kReferenceDevice] = Disposition::kNotRun; -+ -+ return true; -+ } -+ -+ // device refernce has only one instances in Conv2dOperationVectorMap -+ library::Operation const *reference_op = cc_it->second[0]; -+ -+ // -+ // Intialize device reference operation -+ // -+ std::vector host_workspace_reference_op; -+ -+ uint64_t workspace_size = reference_op->get_host_workspace_size(&conv_workspace_.configuration); -+ host_workspace_reference_op.resize(workspace_size, 0); -+ -+ reference_op->initialize( -+ &conv_workspace_.configuration, -+ host_workspace_reference_op.data()); -+ -+ // Initialize structure containing Conv2d arguments -+ conv_workspace_.arguments.A = conv_workspace_.A->data(); -+ conv_workspace_.arguments.B = conv_workspace_.B->data(); -+ conv_workspace_.arguments.C = conv_workspace_.C->data(); -+ conv_workspace_.arguments.D = conv_workspace_.Reference->data(); -+ conv_workspace_.arguments.alpha = problem_.alpha.data(); -+ conv_workspace_.arguments.beta = problem_.beta.data(); -+ conv_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ // -+ // Run device reference operation -+ // -+ status = reference_op->run( -+ &conv_workspace_.arguments, -+ host_workspace_reference_op.data()); -+ -+ -+ // Handle errors -+ if (status != Status::kSuccess) { -+ results_.back().verification_map[library::Provider::kReferenceDevice] = Disposition::kNotVerified; -+ return true; -+ } -+ -+ // -+ // Verify results -+ // -+ results_.back().verification_map[library::Provider::kReferenceDevice] = compare_tensors( -+ options, -+ *conv_workspace_.Computed, -+ *conv_workspace_.Reference, -+ conv_workspace_.Computed->batch_stride() -+ ); -+ -+ // Save workspace if incorrect -+ if (options.verification.save_workspace == SaveWorkspace::kIncorrect && -+ results_.back().verification_map[library::Provider::kReferenceDevice] == Disposition::kIncorrect) { -+ -+ save_workspace( -+ device_context, -+ options, -+ static_cast(operation->description()), -+ library::Provider::kCUTLASS, -+ library::Provider::kReferenceDevice); -+ } -+ -+ // Return true means continue profiling -+ return true; -+} -+ -+/// Measures performance results -+bool Conv2dOperationProfiler::profile( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ -+ if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ -+ // Initialize structure containing Conv2d arguments -+ conv_workspace_.arguments.A = conv_workspace_.A->data(); -+ conv_workspace_.arguments.B = conv_workspace_.B->data(); -+ conv_workspace_.arguments.C = conv_workspace_.C->data(); -+ conv_workspace_.arguments.D = conv_workspace_.Computed->data(); -+ conv_workspace_.arguments.alpha = problem_.alpha.data(); -+ conv_workspace_.arguments.beta = problem_.beta.data(); -+ conv_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ if (conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ // update library::ConvArguments for parallel split-k reduction -+ conv_workspace_.arguments.D = conv_workspace_.device_workspace.data(); -+ conv_workspace_.arguments.alpha = problem_.alpha_one.data(); -+ conv_workspace_.arguments.beta = problem_.beta_zero.data(); -+ -+ /// intialize library::ReductionArguments -+ conv_workspace_.reduction_arguments.workspace = conv_workspace_.device_workspace.data(); -+ conv_workspace_.reduction_arguments.source = conv_workspace_.C->data(); -+ conv_workspace_.reduction_arguments.destination = conv_workspace_.Computed->data(); -+ conv_workspace_.reduction_arguments.alpha = problem_.alpha.data(); -+ conv_workspace_.reduction_arguments.beta = problem_.beta.data(); -+ conv_workspace_.reduction_arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ } -+ -+ results_.back().status = profile_cutlass_( -+ results_.back().runtime, -+ options, -+ operation, -+ &conv_workspace_.arguments, -+ conv_workspace_.host_workspace.data(), -+ conv_workspace_.device_workspace.data() -+ ); -+ } -+ return true; -+ -+} -+ -+/// Method to profile a CUTLASS Operation -+Status Conv2dOperationProfiler::profile_cutlass_( -+ double &runtime, -+ Options const &options, -+ library::Operation const *operation, -+ void *arguments, -+ void *host_workspace, -+ void *device_workspace) { -+ -+ GpuTimer timer; -+ -+ // initialize conv2d underlying operation to handle parallel reduction -+ library::Operation const* underlying_operation = operation; -+ -+ library::ConvArguments *conv_arguments = static_cast(arguments); -+ -+ if(conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ if (!(underlying_operation = library::find_conv_operation_for_parallel_reduction(operation))) { -+ return Status::kErrorNotSupported; -+ } -+ } -+ -+ // -+ // Optional sleep to limit power consumption and thermals -+ // -+ -+ sleep(options.profiling.sleep_duration); -+ -+ // -+ // Warmup loop -+ // -+ -+ Status status; -+ -+ for (int iteration = 0; iteration < options.profiling.warmup_iterations; ++iteration) { -+ -+ // Setup rotating workspace -+ int workspace_idx = options.profiling.warmup_iterations + iteration; -+ int problem_idx = (workspace_idx % conv_workspace_.problem_count); -+ -+ conv_arguments->A = conv_workspace_.A->batch_data(problem_idx); -+ conv_arguments->B = conv_workspace_.B->batch_data(problem_idx); -+ conv_arguments->C = conv_workspace_.C->batch_data(problem_idx); -+ conv_arguments->D = conv_workspace_.Computed->batch_data(problem_idx); -+ -+ if (conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ // update library::ConvArguments for parallel split-k reduction -+ conv_arguments->D = conv_workspace_.device_workspace.data(); -+ -+ /// intialize library::ReductionArguments -+ conv_workspace_.reduction_arguments.workspace = conv_workspace_.device_workspace.data(); -+ conv_workspace_.reduction_arguments.source = conv_workspace_.C->batch_data(problem_idx); -+ conv_workspace_.reduction_arguments.destination = conv_workspace_.Computed->batch_data(problem_idx); -+ } -+ -+ // Run underlying conv2d operation -+ status = underlying_operation->run( -+ arguments, -+ host_workspace, -+ device_workspace); -+ -+ // Run parallel reduction kernel for parallel split_k_mode -+ if (conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ -+ status = reduction_op_->run( -+ &conv_workspace_.reduction_arguments, -+ conv_workspace_.reduction_host_workspace.data(), -+ nullptr); -+ } -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ } -+ -+ // -+ // Initialize GPU timer -+ // -+ -+ timer.start(); -+ -+ // -+ // Profiling loop -+ // -+ -+ int Iterations = options.profiling.iterations; -+ -+ int iteration = 0; -+ for (; iteration < Iterations; ++iteration) { -+ -+ // Setup rotating workspace -+ int problem_idx = (iteration % conv_workspace_.problem_count); -+ -+ conv_arguments->A = conv_workspace_.A->batch_data(problem_idx); -+ conv_arguments->B = conv_workspace_.B->batch_data(problem_idx); -+ conv_arguments->C = conv_workspace_.C->batch_data(problem_idx); -+ conv_arguments->D = conv_workspace_.Computed->batch_data(problem_idx); -+ -+ if (conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ // update library::ConvArguments for parallel split-k reduction -+ conv_arguments->D = conv_workspace_.device_workspace.data(); -+ -+ /// intialize library::ReductionArguments -+ conv_workspace_.reduction_arguments.workspace = conv_workspace_.device_workspace.data(); -+ conv_workspace_.reduction_arguments.source = conv_workspace_.C->batch_data(problem_idx); -+ conv_workspace_.reduction_arguments.destination = conv_workspace_.Computed->batch_data(problem_idx); -+ } -+ -+ // Run underlying conv2d operation -+ status = underlying_operation->run( -+ arguments, -+ host_workspace, -+ device_workspace); -+ -+ // Run parallel reduction kernel for parallel split_k_mode -+ if (conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ -+ status = reduction_op_->run( -+ &conv_workspace_.reduction_arguments, -+ conv_workspace_.reduction_host_workspace.data(), -+ nullptr); -+ } -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ } -+ -+ // -+ // Wait for completion -+ // -+ -+ timer.stop_and_wait(); -+ -+ // -+ // Update performance result -+ // -+ -+ runtime = timer.duration(iteration); -+ -+ return status; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#if CUTLASS_ENABLE_CUDNN -+ -+/// Verifies CUTLASS against cudnn reference -+bool Conv2dOperationProfiler::verify_with_cudnn_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ auto &conv_desc = static_cast(operation->description()); -+ -+ // -+ // Construct cudnn operators -+ // -+ -+ CudnnCreate handle; -+ cudnnStatus_t status = handle.get_cudnn_create_status(); -+ -+ if (status != CUDNN_STATUS_SUCCESS) { -+ -+ results_.back().verification_map[library::Provider::kCUDNN] = get_cutlass_disposition(status); -+ return true; -+ } -+ -+ // -+ // Initialize state -+ // -+ -+ // Initialize structure containing Conv2d arguments -+ conv_workspace_.arguments.A = conv_workspace_.A->data(); -+ conv_workspace_.arguments.B = conv_workspace_.B->data(); -+ conv_workspace_.arguments.D = conv_workspace_.Reference->data(); -+ conv_workspace_.arguments.alpha = problem_.alpha.data(); -+ conv_workspace_.arguments.beta = problem_.beta.data(); -+ conv_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ // cuDNN does not support four tensor arguments, so we copy the tensor C data into -+ // tensor D. -+ conv_workspace_.Reference->copy_from_device(conv_workspace_.C->data()); -+ conv_workspace_.arguments.C = conv_workspace_.arguments.D; -+ -+ try { -+ -+ // -+ // Construct dispatcher to cudnn operator -+ // -+ -+ detail::cudnnConvDispatcher conv_op( -+ conv_desc, -+ conv_workspace_.configuration, -+ conv_workspace_.arguments, -+ handle -+ ); -+ -+ if (conv_op.status != Status::kSuccess) { -+ if (conv_op.status == Status::kErrorNotSupported) { -+ results_.back().verification_map[library::Provider::kCUDNN] = Disposition::kNotSupported; -+ -+ } else { -+ results_.back().verification_map[library::Provider::kCUDNN] = Disposition::kFailed; -+ } -+ return true; -+ } -+ -+ -+ status = conv_op(handle); -+ -+ // Handle errors -+ if (status != CUDNN_STATUS_SUCCESS) { -+ -+ results_.back().verification_map[library::Provider::kCUDNN] = get_cutlass_disposition(status); -+ return true; -+ } -+ -+ // -+ // Verify results -+ // -+ -+ results_.back().verification_map[library::Provider::kCUDNN] = compare_tensors( -+ options, -+ *conv_workspace_.Computed, -+ *conv_workspace_.Reference, -+ conv_workspace_.Computed->batch_stride() -+ ); -+ -+ // Save workspace if incorrect -+ if (options.verification.save_workspace == SaveWorkspace::kIncorrect && -+ results_.back().verification_map[library::Provider::kCUDNN] == Disposition::kIncorrect) { -+ -+ save_workspace( -+ device_context, -+ options, -+ conv_desc, -+ library::Provider::kCUTLASS, -+ library::Provider::kCUDNN); -+ } -+ } -+ catch (...) { -+ results_.back().verification_map[library::Provider::kCUDNN] = Disposition::kFailed; -+ } -+ -+ // Return true means continue profiling -+ return true; -+} -+ -+#endif // #if CUTLASS_ENABLE_CUDNN -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/profiler/src/conv2d_operation_profiler.h b/3rdparty/cutlass/tools/profiler/src/conv2d_operation_profiler.h -new file mode 100644 -index 0000000..f432c7e ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/conv2d_operation_profiler.h -@@ -0,0 +1,493 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines profiling functionality for convolution -+ -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+#include -+#include -+ -+// CUTLASS Library includes -+#include "cutlass/library/library.h" -+#include "cutlass/library/util.h" -+#include "cutlass/library/handle.h" -+#include "cutlass/library/manifest.h" -+#include "cutlass/library/singleton.h" -+ -+// Profiler includes -+#include "options.h" -+#include "device_context.h" -+#include "operation_profiler.h" -+#include "performance_result.h" -+#include "problem_space.h" -+#include "reduction_operation_profiler.h" -+#if CUTLASS_ENABLE_CUDNN -+#include "cudnn_helpers.h" -+#endif //#if CUTLASS_ENABLE_CUDNN -+#include "debug.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Abstract base class for each math function -+class Conv2dOperationProfiler : public OperationProfiler { -+public: -+ -+ /// Problem structure obtained from problem space -+ struct Conv2dProblem { -+ -+ int64_t n, h, w, c, p, q, k, r, s; -+ int64_t groups; -+ int64_t pad_h, pad_w; -+ int64_t stride_h, stride_w; -+ int64_t dilation_h, dilation_w; -+ -+ std::vector alpha; -+ std::vector beta; -+ -+ library::SplitKMode split_k_mode; -+ int64_t split_k_slices; -+ -+ library::ConvModeID conv_mode; -+ -+ library::Provider eq_gemm_provider; -+ -+ // convolution with parallel interleaved reduction -+ // convolution epilogue (alpha, beta) = (1.0, 0.0) -+ // reduction epilogue (alpha, beta) = (Conv2dProblem::alpha, Conv2dProblem::beta) -+ std::vector alpha_one; -+ std::vector beta_zero; -+ -+ // -+ // Methods -+ // -+ -+ /// Total number of bytes loaded -+ int64_t bytes(library::ConvDescription const &operation_desc) const; -+ -+ /// Total number of flops computed -+ int64_t flops(library::ConvDescription const &operation_desc) const; -+ -+ void set_default_output_size() { -+ p = ((h + pad_h - r * dilation_h) / stride_h) + 1; -+ q = ((w + pad_w - s * dilation_w) / stride_w) + 1; -+ } -+ -+ // Returns equivalent gemm problem size for convolution -+ cutlass::gemm::GemmCoord eq_gemm_size(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: return cutlass::gemm::GemmCoord(int(n * p * q), int(k), int(r * s * c / groups)); -+ case library::ConvKind::kDgrad: return cutlass::gemm::GemmCoord(int(n * h * w), int(c), int(k * r * s)); -+ case library::ConvKind::kWgrad: return cutlass::gemm::GemmCoord(int(k), int(r * s * c), int(n * p * q)); -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns extent for tensor A -+ std::vector extent_a(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: return {int(n), int(h), int(w), int(c)}; -+ case library::ConvKind::kDgrad: return {int(n), int(p), int(q), int(k)}; -+ case library::ConvKind::kWgrad: return {int(n), int(p), int(q), int(k)}; -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns extent for tensor B -+ std::vector extent_b(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: return {int(k), int(r), int(s), int(c / groups)}; -+ case library::ConvKind::kDgrad: return {int(k), int(r), int(s), int(c)}; -+ case library::ConvKind::kWgrad: return {int(n), int(h), int(w), int(c)}; -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns extent for tensor C -+ std::vector extent_c(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: return {int(n), int(p), int(q), int(k)}; -+ case library::ConvKind::kDgrad: return {int(n), int(h), int(w), int(c)}; -+ case library::ConvKind::kWgrad: return {int(k), int(r), int(s), int(c)}; -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns layout for equivalent gemm matrix A -+ library::LayoutTypeID eq_gemm_layout_a(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: return library::LayoutTypeID::kRowMajor; // TN Gemm -+ case library::ConvKind::kDgrad: return library::LayoutTypeID::kRowMajor; // TT Gemm -+ case library::ConvKind::kWgrad: return library::LayoutTypeID::kColumnMajor; // NT Gemm -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns layout for equivalent gemm matrix B -+ library::LayoutTypeID eq_gemm_layout_b(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: return library::LayoutTypeID::kColumnMajor; // TN Gemm -+ case library::ConvKind::kDgrad: return library::LayoutTypeID::kRowMajor; // TT Gemm -+ case library::ConvKind::kWgrad: return library::LayoutTypeID::kRowMajor; // NT Gemm -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns layout for equivalent gemm matrix C -+ library::LayoutTypeID eq_gemm_layout_c(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ // Gemm operator assumes column-major output -+ case library::ConvKind::kFprop: -+ case library::ConvKind::kDgrad: -+ case library::ConvKind::kWgrad: return library::LayoutTypeID::kColumnMajor; -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns leading dimenstion for equivalent gemm matrix A -+ int64_t eq_gemm_lda(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: return eq_gemm_size(conv_kind).k(); -+ case library::ConvKind::kDgrad: return eq_gemm_size(conv_kind).k(); -+ case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).m(); -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns leading dimenstion for equivalent gemm matrix B -+ int64_t eq_gemm_ldb(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: return eq_gemm_size(conv_kind).k(); -+ case library::ConvKind::kDgrad: return eq_gemm_size(conv_kind).n(); -+ case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).n(); -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns leading dimenstion for equivalent gemm matrix C -+ int64_t eq_gemm_ldc(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: -+ case library::ConvKind::kDgrad: -+ case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).m(); -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ }; -+ -+ /// Workspace used -+ struct Conv2dWorkspace { -+ -+ /// Conv device allocations -+ DeviceAllocation *A; -+ DeviceAllocation *B; -+ DeviceAllocation *reordered_B; -+ DeviceAllocation *C; -+ DeviceAllocation *Computed; -+ DeviceAllocation *Reference; -+ -+ /// Library configuration and arguments for convolution operator -+ library::Conv2dConfiguration configuration; -+ library::ConvArguments arguments; -+ -+ /// Number of copies of the problem workspace which are visited sequentially during -+ /// profiling to avoid camping in the last level cache. -+ int problem_count; -+ -+ /// Buffer used for the cutlass conv2d operations' host workspace -+ std::vector host_workspace; -+ -+ /// Buffer used for the cutlass operations' device workspace -+ DeviceAllocation device_workspace; -+ -+ /// Library configuration and arguments for reduction operator -+ library::ReductionConfiguration reduction_configuration; -+ library::ReductionArguments reduction_arguments; -+ -+ /// Buffer used for the cutlass reduction operations' host workspace -+ std::vector reduction_host_workspace; -+ -+ /// Host data buffers for host reference operation -+ /// host buffer for tensor -+ std::vector host_tensor_a; -+ -+ /// host buffer for tensor b -+ std::vector host_tensor_b; -+ -+ /// host buffer for tensor c -+ std::vector host_tensor_c; -+ -+ // -+ // Methods -+ // -+ -+ Conv2dWorkspace() -+ : A(nullptr), -+ B(nullptr), -+ reordered_B(nullptr), -+ C(nullptr), -+ Computed(nullptr), -+ Reference(nullptr) {} -+ -+ // Set stride vector for tensor activations, filters, output -+ void set_stride_vector(Conv2dProblem const &problem, -+ library::ConvKind const &conv_kind, -+ library::LayoutTypeID const &layout_a, -+ library::LayoutTypeID const &layout_b, -+ library::LayoutTypeID const &layout_c) { -+ std::vector stride_activations; -+ std::vector stride_filters; -+ std::vector stride_output; -+ -+ // Strides for interleaved fprop -+ if (conv_kind == library::ConvKind::kFprop && -+ ((layout_a == library::LayoutTypeID::kTensorNC32HW32 && -+ layout_b == library::LayoutTypeID::kTensorC32RSK32 && -+ layout_c == library::LayoutTypeID::kTensorNC32HW32) || -+ (layout_a == library::LayoutTypeID::kTensorNC64HW64 && -+ layout_b == library::LayoutTypeID::kTensorC64RSK64 && -+ layout_c == library::LayoutTypeID::kTensorNC64HW64))) { -+ int interleave = -+ (layout_a == library::LayoutTypeID::kTensorNC32HW32) ? 32 : 64; -+ -+ stride_activations.push_back(int(problem.w) * interleave); -+ stride_activations.push_back(int(problem.w) * int(problem.h) * -+ interleave); -+ stride_activations.push_back(int(problem.h) * int(problem.w) * -+ int(problem.c)); -+ -+ stride_filters.push_back(int(problem.k) * interleave); -+ stride_filters.push_back(int(problem.k) * int(problem.s) * interleave); -+ stride_filters.push_back(int(problem.k) * int(problem.s) * -+ int(problem.r) * interleave); -+ -+ stride_output.push_back(int(problem.q) * interleave); -+ stride_output.push_back(int(problem.q) * int(problem.p) * interleave); -+ stride_output.push_back(int(problem.q) * int(problem.p) * -+ int(problem.k)); -+ } else { -+ // Strides for the rest cases -+ stride_activations.push_back(int(problem.c)); -+ stride_activations.push_back(int(problem.w) * int(problem.c)); -+ stride_activations.push_back(int(problem.h) * int(problem.w) * -+ int(problem.c)); -+ -+ stride_filters.push_back(int(problem.c / problem.groups)); -+ stride_filters.push_back(int(problem.s) * int(problem.c / problem.groups)); -+ stride_filters.push_back(int(problem.r) * int(problem.s) * -+ int(problem.c / problem.groups)); -+ -+ stride_output.push_back(int(problem.k)); -+ stride_output.push_back(int(problem.q) * int(problem.k)); -+ stride_output.push_back(int(problem.q) * int(problem.p) * -+ int(problem.k)); -+ } -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: -+ configuration.stride_a = stride_activations; -+ configuration.stride_b = stride_filters; -+ configuration.stride_c = stride_output; -+ -+ break; -+ case library::ConvKind::kDgrad: -+ configuration.stride_a = stride_output; -+ configuration.stride_b = stride_filters; -+ configuration.stride_c = stride_activations; -+ -+ break; -+ case library::ConvKind::kWgrad: -+ configuration.stride_a = stride_output; -+ configuration.stride_b = stride_activations; -+ configuration.stride_c = stride_filters; -+ -+ break; -+ default: -+ throw std::runtime_error( -+ "Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ }; -+ -+protected: -+ -+ // -+ // Data members -+ // -+ -+ /// CONV problem obtained from problem space -+ Conv2dProblem problem_; -+ -+ /// Device memory allocations -+ Conv2dWorkspace conv_workspace_; -+ -+ /// CUTLASS parallel reduction operation to follow this* conv2d operation -+ library::Operation const *reduction_op_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ Conv2dOperationProfiler(Options const &options); -+ -+ /// Destructor -+ virtual ~Conv2dOperationProfiler(); -+ -+ /// Prints usage statement for the math function -+ virtual void print_usage(std::ostream &out) const; -+ -+ /// Prints examples -+ virtual void print_examples(std::ostream &out) const; -+ -+ /// Extracts the problem dimensions -+ virtual Status initialize_configuration( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Initializes workspace -+ virtual Status initialize_workspace( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Verifies CUTLASS against references -+ virtual bool verify_cutlass( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Measures performance results -+ virtual bool profile( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+protected: -+ /// Method to profile an initialized CUTLASS operation -+ virtual Status profile_cutlass_( -+ double &runtime, -+ Options const &options, -+ library::Operation const *operation, -+ void *arguments, -+ void *host_workspace, -+ void *device_workspace); -+ -+ -+ /// Initialize reduction problem dimenstions and library::Operation -+ bool initialize_reduction_configuration_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Initializes the performance result -+ void initialize_result_( -+ PerformanceResult &result, -+ Options const &options, -+ library::ConvDescription const &operation_desc, -+ ProblemSpace const &problem_space); -+ -+ /// Verifies CUTLASS against host reference -+ bool verify_with_host_reference_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Verifies CUTLASS against device reference -+ bool verify_with_device_reference_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+#if CUTLASS_ENABLE_CUDNN -+ -+ /// Verifies CUTLASS against cudnn reference -+ bool verify_with_cudnn_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+#endif //#if CUTLASS_ENABLE_CUDNN -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/profiler/src/conv3d_operation_profiler.cu b/3rdparty/cutlass/tools/profiler/src/conv3d_operation_profiler.cu -new file mode 100644 -index 0000000..34fee85 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/conv3d_operation_profiler.cu -@@ -0,0 +1,1351 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Convolution 3D profiling -+ -+*/ -+ -+#include -+#include -+#include -+#include -+ -+#include "cutlass/core_io.h" -+ -+#include "conv3d_operation_profiler.h" -+#include "gpu_timer.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+using namespace cutlass::library; -+ -+namespace cutlass { -+namespace profiler { -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Ctor -+Conv3dOperationProfiler::Conv3dOperationProfiler(Options const &options): -+ OperationProfiler( -+ options, -+ library::OperationKind::kConv3d, -+ { -+ {ArgumentTypeID::kEnumerated, {"conv_kind"}, "Convolutional operator (fprop, dgrad, wgrad)"}, -+ {ArgumentTypeID::kInteger, {"n", "input_n"}, "Input N dimension of the Conv3d problem space"}, -+ {ArgumentTypeID::kInteger, {"d", "input_d"}, "Input D dimension of the Conv3d problem space"}, -+ {ArgumentTypeID::kInteger, {"h", "input_h"}, "Input H dimension of the Conv3d problem space"}, -+ {ArgumentTypeID::kInteger, {"w", "input_w"}, "Input W dimension of the Conv3d problem space"}, -+ {ArgumentTypeID::kInteger, {"c", "input_c"}, "Input C dimension of the Conv3d problem space"}, -+ {ArgumentTypeID::kInteger, {"k", "filter_k"}, "Filter K dimension of the Conv3d problem space"}, -+ {ArgumentTypeID::kInteger, {"t", "filter_t"}, "Filter T dimension of the Conv3d problem space"}, -+ {ArgumentTypeID::kInteger, {"r", "filter_r"}, "Filter R dimension of the Conv3d problem space"}, -+ {ArgumentTypeID::kInteger, {"s", "filter_s"}, "Filter S dimension of the Conv3d problem space"}, -+ {ArgumentTypeID::kInteger, {"z", "output_z"}, "Output Z dimension of the Conv3d problem space"}, -+ {ArgumentTypeID::kInteger, {"p", "output_p"}, "Output P dimension of the Conv3d problem space"}, -+ {ArgumentTypeID::kInteger, {"q", "output_q"}, "Output Q dimension of the Conv3d problem space"}, -+ {ArgumentTypeID::kInteger, {"pad_d"}, "Padding in D direction"}, -+ {ArgumentTypeID::kInteger, {"pad_h"}, "Padding in H direction"}, -+ {ArgumentTypeID::kInteger, {"pad_w"}, "Padding in W direction"}, -+ {ArgumentTypeID::kInteger, {"stride_d"}, "Stride in D direction"}, -+ {ArgumentTypeID::kInteger, {"stride_h"}, "Stride in H direction"}, -+ {ArgumentTypeID::kInteger, {"stride_w"}, "Stride in W direction"}, -+ {ArgumentTypeID::kInteger, {"dilation_d"}, "Dilation in D direction"}, -+ {ArgumentTypeID::kInteger, {"dilation_h"}, "Dilation in H direction"}, -+ {ArgumentTypeID::kInteger, {"dilation_w"}, "Dilation in W direction"}, -+ {ArgumentTypeID::kTensor, {"Activation"}, "Tensor storing the Activation operand"}, -+ {ArgumentTypeID::kTensor, {"Filter"}, "Tensor storing the Filter operand"}, -+ {ArgumentTypeID::kTensor, {"Output"}, "Tensor storing the Output operand"}, -+ {ArgumentTypeID::kEnumerated, {"conv_mode"}, "Convolution filter mode (conv, cross)"}, -+ {ArgumentTypeID::kEnumerated, {"iterator_algorithm", "iterator_algo"}, "Convolution iterator algorithm (analytic, optimized)"}, -+ {ArgumentTypeID::kScalar, {"alpha", "epilogue::alpha"}, "Epilogue scalar alpha"}, -+ {ArgumentTypeID::kScalar, {"beta", "epilogue::beta"}, "Epilogue scalar beta"}, -+ {ArgumentTypeID::kEnumerated, {"split_k_mode", "split-k-mode"}, "SplitK mode for serial or parallel reduction (serial, parallel)"}, -+ {ArgumentTypeID::kInteger, {"split_k_slices", "split-k-slices"}, "Number of partitions of K dimension"}, -+ {ArgumentTypeID::kEnumerated, {"eq_gemm_provider", "eq-gemm-provider"}, "Enable profiling equivalent gemm by the following providers (cutlass)"}, -+ }, -+ { library::Provider::kReferenceDevice, library::Provider::kReferenceHost, library::Provider::kCUDNN } -+ ) { -+ -+ description_ = " Conv3d operation. Output(Tensor5D) = alpha * Input(Tensor5D) * Filter(Tensor5D) + beta * Input(Tensor5D)"; -+ -+} -+ -+/// Destructor -+Conv3dOperationProfiler::~Conv3dOperationProfiler() { -+ -+} -+ -+ -+/// Prints usage statement for the math function -+void Conv3dOperationProfiler::print_usage(std::ostream &out) const { -+ out << "Conv3d" << "\n\n"; -+ -+ OperationProfiler::print_usage(out); -+} -+ -+/// Prints examples -+void Conv3dOperationProfiler::print_examples(std::ostream &out) const { -+ -+ out << "\nExamples:\n\n" -+ << "Profile a particular convolution (specify all the convolution parameters):\n" -+ << " $ cutlass_profiler --operation=Conv3d" -+ " --Activation=f16:ndhwc --Filter=f16:ndhwc --Output=f16 --accumulator-type=f32" -+ " --n=32 --d=16 --h=14 --w=14 --c=8 --k=64 --t=3 --r=3 --s=3" -+ " --pad_d=1 --pad_h=1 --pad_w=1" -+ " --stride_d=1 --stride::h=1 --stride::w=1" -+ " --dilation_d=1 --dilation::h=1 --dilation::w=1\n\n"; -+} -+ -+#if 0 -+// used this for debugging -+static std::string byte_string(std::vector const &bytes) { -+ std::stringstream ss; -+ -+ ss << "0x"; -+ -+ for (size_t idx = bytes.size(); idx > 0; --idx) { -+ ss << std::hex << std::setw(2) << std::setfill('0') << uint32_t(bytes.at(idx - 1)); -+ } -+ -+ return ss.str(); -+} -+#endif -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Total number of bytes loaded -+int64_t Conv3dOperationProfiler::Conv3dProblem::bytes(library::ConvDescription const &operation_desc) const { -+ cutlass::gemm::GemmCoord mnk = eq_gemm_size(operation_desc.conv_kind); -+ -+ // Input bytes read and Output bytes written for the gemm problem -+ int64_t bytes_ = -+ int64_t(library::sizeof_bits(operation_desc.A.element) * mnk.m() / 8) * mnk.k() + -+ int64_t(library::sizeof_bits(operation_desc.B.element) * mnk.n() / 8) * mnk.k() + -+ int64_t(library::sizeof_bits(operation_desc.C.element) * mnk.m() / 8) * mnk.n(); -+ -+ // Set is_beta_zero true if beta is zero -+ bool is_beta_zero = std::all_of(beta.begin(), beta.end(), [](uint8_t i) { return i==0; }); -+ -+ // Output bytes read for the gemm problem for non-zero beta values -+ if (!is_beta_zero) { -+ bytes_ += int64_t(library::sizeof_bits(operation_desc.C.element) * mnk.m() / 8) * mnk.n(); -+ } -+ -+ return bytes_; -+} -+ -+/// Total number of flops computed -+int64_t Conv3dOperationProfiler::Conv3dProblem::flops( -+ library::ConvDescription const &operation_desc) const { -+ -+ cutlass::gemm::GemmCoord mnk = eq_gemm_size(operation_desc.conv_kind); -+ -+ int64_t flops_mainloop_ = int64_t(mnk.m()) * mnk.n() * mnk.k() * 2; -+ int64_t flops_epilogue_ = int64_t(mnk.m()) * int64_t(mnk.n()) * 2; -+ -+ // Adjust mainloop flop for dgrad strided -+ if (operation_desc.conv_kind == library::ConvKind::kDgrad) { -+ flops_mainloop_ = flops_mainloop_ / ( stride_d * stride_h * stride_w); -+ } -+ -+ return (flops_mainloop_ + flops_epilogue_); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Extracts the problem dimensions -+Status Conv3dOperationProfiler::initialize_configuration( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ library::ConvDescription const &operation_desc = -+ static_cast(operation->description()); -+ -+ if (!arg_as_int(problem_.n, "n", problem_space, problem)) { -+ // default value -+ problem_.n = 1; -+ } -+ -+ if (!arg_as_int(problem_.d, "d", problem_space, problem)) { -+ // default value -+ problem_.d = 8; -+ } -+ -+ if (!arg_as_int(problem_.h, "h", problem_space, problem)) { -+ // default value -+ problem_.h = 14; -+ } -+ -+ if (!arg_as_int(problem_.w, "w", problem_space, problem)) { -+ // default value -+ problem_.w = 14; -+ } -+ -+ if (!arg_as_int(problem_.c, "c", problem_space, problem)) { -+ // default value -+ problem_.c = 32; -+ } -+ -+ if (!arg_as_int(problem_.k, "k", problem_space, problem)) { -+ // default value -+ problem_.k = 32; -+ } -+ -+ if (!arg_as_int(problem_.t, "t", problem_space, problem)) { -+ // default value -+ problem_.t = 3; -+ } -+ -+ if (!arg_as_int(problem_.r, "r", problem_space, problem)) { -+ // default value -+ problem_.r = 3; -+ } -+ -+ if (!arg_as_int(problem_.s, "s", problem_space, problem)) { -+ // default value -+ problem_.s = 3; -+ } -+ -+ if (!arg_as_int(problem_.pad_d, "pad_d", problem_space, problem)) { -+ // default value -+ problem_.pad_d = 1; -+ } -+ -+ if (!arg_as_int(problem_.pad_w, "pad_w", problem_space, problem)) { -+ // default value -+ problem_.pad_w = 1; -+ } -+ if (!arg_as_int(problem_.pad_h, "pad_h", problem_space, problem)) { -+ // default value -+ problem_.pad_h = 1; -+ } -+ -+ if (!arg_as_int(problem_.stride_d, "stride_d", problem_space, problem)) { -+ // default value -+ problem_.stride_d = 1; -+ } -+ -+ if (!arg_as_int(problem_.stride_h, "stride_h", problem_space, problem)) { -+ // default value -+ problem_.stride_h = 1; -+ } -+ -+ if (!arg_as_int(problem_.stride_w, "stride_w", problem_space, problem)) { -+ // default value -+ problem_.stride_w = 1; -+ } -+ -+ if (!arg_as_int(problem_.dilation_d, "dilation_d", problem_space, problem)) { -+ // default value -+ problem_.dilation_d = 1; -+ } -+ -+ if (!arg_as_int(problem_.dilation_h, "dilation_h", problem_space, problem)) { -+ // default value -+ problem_.dilation_h = 1; -+ } -+ -+ if (!arg_as_int(problem_.dilation_w, "dilation_w", problem_space, problem)) { -+ // default value -+ problem_.dilation_w = 1; -+ } -+ -+ //////////////////////// Convolution output dimensions p and q //////////////////////// -+ // Cutlass convolutions support arbitrary output sizes and not constriant by // -+ // input, filter, padding, striding, dilation sizes. // -+ // cuDNN sets the output dimensions (p, q) using following equations: // -+ // // -+ // output = div_up(input + 2 * pad - ((filter - 1) * dilation + 1) + 1, stride) // -+ // where; div_up(a, b) : (a - 1)/b + 1 // -+ // // -+ // Thus, when output p and q dimensions are unspecified by the user // -+ // cutlass profiler sets p and q which are cuDNN compliant. // -+ // // -+ //////////////////////////////////////////////////////////////////////////////////////// -+ // set convolution output z -+ if (!arg_as_int(problem_.z, "z", problem_space, problem)) { -+ // default value (set using cudnn formula for output height, when p is not provided) -+ problem_.z = ( -+ problem_.d + -+ 2 * problem_.pad_d - -+ ((problem_.t - 1) * problem_.dilation_d + 1) -+ ) / (problem_.stride_d) -+ + 1; -+ } -+ -+ // set convolution output p -+ if (!arg_as_int(problem_.p, "p", problem_space, problem)) { -+ // default value (set using cudnn formula for output height, when p is not provided) -+ problem_.p = ( -+ problem_.h + -+ 2 * problem_.pad_h - -+ ((problem_.r - 1) * problem_.dilation_h + 1) -+ ) / (problem_.stride_h) -+ + 1; -+ } -+ -+ // set convolution output q -+ if (!arg_as_int(problem_.q, "q", problem_space, problem)) { -+ // default value (set using cudnn formula for output width, when q is not provided) -+ problem_.q = ( -+ problem_.w + -+ 2 * problem_.pad_w - -+ ((problem_.s - 1) * problem_.dilation_w + 1) -+ ) / (problem_.stride_w) -+ + 1; -+ } -+ ///////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+ if (!arg_as_SplitKModeID(problem_.split_k_mode, "split_k_mode", problem_space, problem)) { -+ // default value -+ problem_.split_k_mode = library::SplitKMode::kSerial; -+ } -+ -+ if (!arg_as_int(problem_.split_k_slices, "split_k_slices", problem_space, problem)) { -+ // default value -+ problem_.split_k_slices = 1; -+ } -+ -+ if (!arg_as_ConvModeID(problem_.conv_mode, "conv_mode", problem_space, problem)) { -+ // default value -+ problem_.conv_mode = library::ConvModeID::kCrossCorrelation; -+ } -+ -+ if (!arg_as_ProviderID(problem_.eq_gemm_provider, "eq_gemm_provider", problem_space, problem)) { -+ // default value -+ problem_.eq_gemm_provider = library::Provider::kNone; -+ } -+ -+ if (!conv_kind_satisfies(operation_desc.conv_kind, "conv_kind", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!iterator_algorithm_satisfies(operation_desc.iterator_algorithm, "iterator_algorithm", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.activation(), "Activation", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.filter(), "Filter", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.output(), "Output", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!arg_as_scalar( -+ problem_.alpha, -+ operation_desc.element_epilogue, -+ "alpha", -+ problem_space, -+ problem)) { -+ -+ if (!cast_from_double(problem_.alpha, operation_desc.element_epilogue, 1)) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ if (!arg_as_scalar( -+ problem_.beta, -+ operation_desc.element_epilogue, -+ "beta", -+ problem_space, -+ problem)) { -+ -+ if (!cast_from_double(problem_.beta, operation_desc.element_epilogue, 0)) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ // initialize library::ConvConfiguration -+ conv_workspace_.configuration.problem_size = conv::Conv3dProblemSize( -+ int(problem_.n), -+ int(problem_.d), -+ int(problem_.h), -+ int(problem_.w), -+ int(problem_.c), -+ int(problem_.k), -+ int(problem_.t), -+ int(problem_.r), -+ int(problem_.s), -+ int(problem_.z), -+ int(problem_.p), -+ int(problem_.q), -+ int(problem_.pad_d), -+ int(problem_.pad_h), -+ int(problem_.pad_w), -+ int(problem_.stride_d), -+ int(problem_.stride_h), -+ int(problem_.stride_w), -+ int(problem_.dilation_d), -+ int(problem_.dilation_h), -+ int(problem_.dilation_w), -+ static_cast(static_cast(problem_.conv_mode)), -+ int(problem_.split_k_slices), -+ 1 // groups -+ ); -+ -+ conv_workspace_.configuration.split_k_mode = static_cast(static_cast(problem_.split_k_mode)); -+ -+ conv_workspace_.configuration.layout_activations.stride() = make_Coord( -+ int(problem_.c), -+ int(problem_.w) * int(problem_.c), -+ int(problem_.h) * int(problem_.w) * int(problem_.c), -+ int(problem_.d) * int(problem_.h) * int(problem_.w) * int(problem_.c) -+ ); -+ -+ conv_workspace_.configuration.layout_filters.stride() = make_Coord( -+ int(problem_.c), -+ int(problem_.s) * int(problem_.c), -+ int(problem_.r) * int(problem_.s) * int(problem_.c), -+ int(problem_.t) * int(problem_.r) * int(problem_.s) * int(problem_.c) -+ ); -+ -+ conv_workspace_.configuration.layout_output.stride() = make_Coord( -+ int(problem_.k), -+ int(problem_.q) * int(problem_.k), -+ int(problem_.q) * int(problem_.p) * int(problem_.k), -+ int(problem_.z) * int(problem_.q) * int(problem_.p) * int(problem_.k) -+ ); -+ -+ -+ // initialize library::ConvArguments -+ conv_workspace_.arguments.A = nullptr; -+ conv_workspace_.arguments.B = nullptr; -+ conv_workspace_.arguments.C = nullptr; -+ conv_workspace_.arguments.D = nullptr; -+ conv_workspace_.arguments.alpha = problem_.alpha.data(); -+ conv_workspace_.arguments.beta = problem_.beta.data(); -+ conv_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ // initialize reduction operation for parallel splitKMode not supported for conv3d -+ if(conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ if(!initialize_reduction_configuration_(options, report, device_context, operation, problem_space, problem)) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ initialize_result_(this->model_result_, options, operation_desc, problem_space); -+ -+ return operation->can_implement(&conv_workspace_.configuration, &conv_workspace_.arguments); -+} -+ -+/// Initializes the performance result -+void Conv3dOperationProfiler::initialize_result_( -+ PerformanceResult &result, -+ Options const &options, -+ library::ConvDescription const &operation_desc, -+ ProblemSpace const &problem_space) { -+ -+ result.provider = library::Provider::kCUTLASS; -+ result.disposition = Disposition::kNotRun; -+ result.status = Status::kSuccess; -+ result.operation_name = operation_desc.name; -+ -+ result.arguments.resize(problem_space.rank()); -+ -+ set_argument(result, "Activation", problem_space, -+ std::string(library::to_string(operation_desc.activation().element)) -+ + ":" + library::to_string(operation_desc.activation().layout)); -+ -+ set_argument(result, "Filter", problem_space, -+ std::string(library::to_string(operation_desc.filter().element)) -+ + ":" + library::to_string(operation_desc.filter().layout)); -+ -+ set_argument(result, "Output", problem_space, -+ std::string(library::to_string(operation_desc.output().element)) -+ + ":" + library::to_string(operation_desc.output().layout)); -+ -+ set_argument(result, "conv_kind", problem_space, library::to_string(operation_desc.conv_kind)); -+ -+ set_argument(result, "iterator_algorithm", problem_space, std::string(library::to_string(operation_desc.iterator_algorithm))); -+ -+ set_argument(result, "n", problem_space, problem_.n); -+ set_argument(result, "d", problem_space, problem_.d); -+ set_argument(result, "h", problem_space, problem_.h); -+ set_argument(result, "w", problem_space, problem_.w); -+ set_argument(result, "c", problem_space, problem_.c); -+ -+ set_argument(result, "k", problem_space, problem_.k); -+ set_argument(result, "t", problem_space, problem_.t); -+ set_argument(result, "r", problem_space, problem_.r); -+ set_argument(result, "s", problem_space, problem_.s); -+ -+ set_argument(result, "z", problem_space, problem_.z); -+ set_argument(result, "p", problem_space, problem_.p); -+ set_argument(result, "q", problem_space, problem_.q); -+ -+ set_argument(result, "pad_d", problem_space, problem_.pad_d); -+ set_argument(result, "pad_h", problem_space, problem_.pad_h); -+ set_argument(result, "pad_w", problem_space, problem_.pad_w); -+ -+ set_argument(result, "stride_d", problem_space, problem_.stride_d); -+ set_argument(result, "stride_h", problem_space, problem_.stride_h); -+ set_argument(result, "stride_w", problem_space, problem_.stride_w); -+ -+ set_argument(result, "dilation_d", problem_space, problem_.dilation_d); -+ set_argument(result, "dilation_h", problem_space, problem_.dilation_h); -+ set_argument(result, "dilation_w", problem_space, problem_.dilation_w); -+ -+ set_argument(result, "split_k_mode", problem_space, -+ std::string(library::to_string(problem_.split_k_mode))); -+ set_argument(result, "split_k_slices", problem_space, problem_.split_k_slices); -+ -+ set_argument(result, "conv_mode", problem_space, -+ std::string(library::to_string(problem_.conv_mode))); -+ -+ set_argument(result, "alpha", problem_space, -+ library::lexical_cast(problem_.alpha, operation_desc.element_epilogue)); -+ -+ set_argument(result, "beta", problem_space, -+ library::lexical_cast(problem_.beta, operation_desc.element_epilogue)); -+ -+ set_argument(result, "eq_gemm_provider", problem_space, -+ std::string(library::to_string(problem_.eq_gemm_provider))); -+ -+ OperationProfiler::initialize_result_(result, operation_desc, problem_space); -+ -+ // Bytes of activation, filter, and output tensors -+ result.bytes = problem_.bytes(operation_desc); -+ -+ // Theoritical flops required for the computation -+ result.flops = problem_.flops(operation_desc); -+ -+ // Measured runtime -+ result.runtime = 0; -+ -+} -+ -+/// Initialize reduction problem dimenstions and library::Operation -+bool Conv3dOperationProfiler::initialize_reduction_configuration_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ library::ConvDescription const &conv_desc = -+ static_cast(operation->description()); -+ -+ library::ConvKind const &conv_kind = conv_desc.conv_kind; -+ -+ if (!cast_from_double(problem_.alpha_one, conv_desc.element_epilogue, 1)) { -+ return false; -+ } -+ -+ if (!cast_from_double(problem_.beta_zero, conv_desc.element_epilogue, 0)) { -+ return false; -+ } -+ -+ /// This chooses the appropriate stride element of the row-major C tensor. -+ int const & tensor_c_stride_idx = (conv_kind == library::ConvKind::kWgrad ? 3 : 0); -+ -+ /// intialize library::ReductionConfiguration -+ conv_workspace_.reduction_configuration.problem_size = problem_.eq_gemm_size(conv_kind).mn(); -+ conv_workspace_.reduction_configuration.partitions = int(problem_.split_k_slices); -+ conv_workspace_.reduction_configuration.partition_stride = problem_.eq_gemm_size(conv_kind).mn().product(); -+ conv_workspace_.reduction_configuration.ldw = conv_workspace_.configuration.layout_c(conv_kind).stride()[tensor_c_stride_idx]; -+ conv_workspace_.reduction_configuration.lds = conv_workspace_.configuration.layout_c(conv_kind).stride()[tensor_c_stride_idx]; -+ conv_workspace_.reduction_configuration.ldd = conv_workspace_.configuration.layout_c(conv_kind).stride()[tensor_c_stride_idx]; -+ -+ // find reduction operation -+ library::ReductionFunctionalKey reduction_key( -+ library::Provider::kCUTLASS, -+ conv_desc.tile_description.math_instruction.element_accumulator, // element workspace -+ conv_desc.tile_description.math_instruction.element_accumulator, // element accumulator -+ conv_desc.C.element, // element output -+ conv_desc.element_epilogue // element compute -+ ); -+ -+#if 0// debug print to check which reduction instance is selected -+ std::cout << reduction_key << "\n"; -+#endif -+ auto reduction_it = Singleton::get().operation_table.reduction_operations.find(reduction_key); -+ -+ if(reduction_it == Singleton::get().operation_table.reduction_operations.end()) { -+ -+ return false; -+ } -+ -+ // initialize reduction operation required for parallel split-k conv2d operator -+ reduction_op_ = reduction_it->second; -+ -+ // reduction operation found and initialized -+ return true; -+} -+ -+ -+/// Initializes workspace -+Status Conv3dOperationProfiler::initialize_workspace( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ // initialize conv2d underlying operation to handle parallel reduction -+ library::Operation const* underlying_operation = operation; -+ -+ if(conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ if (!(underlying_operation = library::find_conv_operation_for_parallel_reduction(operation))) { -+ return Status::kErrorNotSupported; -+ } -+ } -+ -+ library::ConvDescription const &operation_desc = -+ static_cast(underlying_operation->description()); -+ -+ // Compute the number of copies of the problem to avoid L2 camping. -+ if (!options.profiling.workspace_count) { -+ int64_t bytes = problem_.bytes(operation_desc); -+ if (bytes < 3 * int64_t(options.device.properties.l2CacheSize)) { -+ conv_workspace_.problem_count = -+ 1 + int((3 * int64_t(options.device.properties.l2CacheSize)) / bytes); -+ } -+ else { -+ conv_workspace_.problem_count = 1; -+ } -+ } -+ else { -+ conv_workspace_.problem_count = options.profiling.workspace_count; -+ } -+ -+ -+ if (options.execution_mode != ExecutionMode::kDryRun) { -+ -+ conv_workspace_.A = device_context.allocate_tensor( -+ options, -+ "A", -+ operation_desc.A.element, -+ operation_desc.A.layout, -+ problem_.extent_a(operation_desc.conv_kind), -+ conv_workspace_.stride_a(operation_desc.conv_kind), -+ conv_workspace_.problem_count -+ ); -+ -+ conv_workspace_.B = device_context.allocate_tensor( -+ options, -+ "B", -+ operation_desc.B.element, -+ operation_desc.B.layout, -+ problem_.extent_b(operation_desc.conv_kind), -+ conv_workspace_.stride_b(operation_desc.conv_kind), -+ conv_workspace_.problem_count -+ ); -+ -+ conv_workspace_.C = device_context.allocate_tensor( -+ options, -+ "C", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ problem_.extent_c(operation_desc.conv_kind), -+ conv_workspace_.stride_c(operation_desc.conv_kind), -+ conv_workspace_.problem_count -+ ); -+ -+ conv_workspace_.Computed = device_context.allocate_tensor( -+ "D", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ problem_.extent_c(operation_desc.conv_kind), -+ conv_workspace_.stride_c(operation_desc.conv_kind), -+ conv_workspace_.problem_count -+ ); -+ -+ conv_workspace_.Reference = device_context.allocate_tensor( -+ "Reference", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ problem_.extent_c(operation_desc.conv_kind), -+ conv_workspace_.stride_c(operation_desc.conv_kind), -+ conv_workspace_.problem_count -+ ); -+ -+ } -+ -+ // -+ // Initialize the CUTLASS operation -+ // -+ Status status = Status::kSuccess; -+ -+ if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ -+ if (options.execution_mode != ExecutionMode::kDryRun) { -+ -+ uint64_t workspace_size = underlying_operation->get_host_workspace_size(&conv_workspace_.configuration); -+ conv_workspace_.host_workspace.resize(workspace_size, 0); -+ -+ workspace_size = underlying_operation->get_device_workspace_size(&conv_workspace_.configuration); -+ conv_workspace_.device_workspace.reset(library::NumericTypeID::kU8, workspace_size); -+ -+ status = underlying_operation->initialize( -+ &conv_workspace_.configuration, -+ conv_workspace_.host_workspace.data(), -+ conv_workspace_.device_workspace.data()); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ if (conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ workspace_size = reduction_op_->get_host_workspace_size(&conv_workspace_.reduction_configuration); -+ conv_workspace_.reduction_host_workspace.resize(workspace_size, 0); -+ -+ status = reduction_op_->initialize( -+ &conv_workspace_.reduction_configuration, -+ conv_workspace_.reduction_host_workspace.data(), -+ nullptr); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ } -+ } -+ -+ // -+ // If CUTLASS is enabled, generate a result for it -+ // -+ results_.push_back(model_result_); -+ results_.back().provider = library::Provider::kCUTLASS; -+ results_.back().op_kind = library::OperationKind::kConv3d; -+ results_.back().disposition = Disposition::kNotRun; -+ -+ for(auto provider : verification_providers_) { -+ results_.back().verification_map[provider] = Disposition::kNotRun; -+ } -+ } -+ -+ return status; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Verifies CUTLASS against references -+bool Conv3dOperationProfiler::verify_cutlass( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ if (!options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ return true; -+ } -+ -+ if (options.execution_mode == ExecutionMode::kDryRun) { -+ return true; -+ } -+ -+ cudaError_t result; -+ -+ // Initialize structure containing Conv arguments -+ set_cutlass_operator_arguments_(); -+ -+ conv_workspace_.Computed->copy_from_device(conv_workspace_.C->data()); -+ -+ // -+ // Run the CUTLASS operation -+ // -+ // initialize conv2d underlying operation to handle parallel reduction -+ library::Operation const* underlying_operation = operation; -+ -+ if(conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ if (!(underlying_operation = library::find_conv_operation_for_parallel_reduction(operation))) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ } -+ -+#if 0 -+ std::cout << "profiling : " << std::endl -+ << "conv2d : " << operation->description().name << std::endl -+ << "underlying conv2d : " << underlying_operation->description().name << std::endl -+ << "reduction : " << reduction_op_->description().name << std::endl; -+#endif -+ -+ // run cutlass conv2d operation -+ results_.back().status = underlying_operation->run( -+ &conv_workspace_.arguments, -+ conv_workspace_.host_workspace.data(), -+ conv_workspace_.device_workspace.data()); -+ -+ if (results_.back().status != Status::kSuccess) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ -+ // Run parallel reduction kernel for parallel split_k_mode -+ if (conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ -+ results_.back().status = reduction_op_->run( -+ &conv_workspace_.reduction_arguments, -+ conv_workspace_.reduction_host_workspace.data(), -+ nullptr); -+ -+ if (results_.back().status != Status::kSuccess) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ -+ } -+ -+ // Synchronize before running device reference -+ result = cudaDeviceSynchronize(); -+ if (result != cudaSuccess) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ -+ // CUTLASS op ran the but not yet verified against any verification provider -+ results_.back().disposition = Disposition::kNotVerified; -+ -+ // -+ // Run verification providers -+ // -+ -+ if (options.verification.enabled) { -+ -+#if CUTLASS_ENABLE_CUDNN -+ // Run verification cudnn reference -+ if (options.verification.provider_enabled(library::Provider::kCUDNN)) { -+ -+ // Guard against unsupported cases -+ auto const & conv_desc = static_cast(operation->description()); -+ -+ Status status = cudnn_satisfies(conv_desc, conv_workspace_.configuration); -+ -+ // Initialize reference data to the source data -+ conv_workspace_.Reference->copy_from_device(conv_workspace_.C->data()); -+ -+ if (status == Status::kSuccess) { -+ // call cudnn verification if supported -+ verify_with_cudnn_( -+ options, -+ report, -+ device_context, -+ operation, -+ problem_space, -+ problem); -+ } -+ -+ else if (status == Status::kErrorInvalidProblem) { -+ results_.back().verification_map[library::Provider::kCUDNN] = Disposition::kInvalidProblem; -+ } -+ -+ else { -+ // set verification map for cudnn to not supported -+ results_.back().verification_map[library::Provider::kCUDNN] = Disposition::kNotSupported; -+ } -+ } -+#endif // #if CUTLASS_ENABLE_CUDNN -+ -+ // Run verification host reference -+ if (options.verification.provider_enabled(library::Provider::kReferenceHost)) { -+ -+ // Restore reference data back to initial source data -+ conv_workspace_.Reference->copy_from_device(conv_workspace_.C->data()); -+ -+ verify_with_host_reference_( -+ options, -+ report, -+ device_context, -+ operation, -+ problem_space, -+ problem); -+ } -+ -+ // Update disposition to worst case verification outcome among all -+ // verification providers which are supported -+ bool is_any_verification_run_passed = false; -+ for(auto &m : results_.back().verification_map) { -+ if(m.second == Disposition::kFailed || m.second == Disposition::kIncorrect) { -+ results_.back().disposition = m.second; -+ return true; -+ } -+ if(!is_any_verification_run_passed && m.second == Disposition::kPassed) { -+ is_any_verification_run_passed = true; -+ } -+ } -+ -+ if(is_any_verification_run_passed) { -+ results_.back().disposition = Disposition::kPassed; -+ } -+ } -+ -+ // Return true means continue profiling -+ return true; -+} -+ -+ -+/// Verifies CUTLASS against host reference -+bool Conv3dOperationProfiler::verify_with_host_reference_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ Status status; -+ -+ // -+ // Find host reference operation using conv functional description key -+ // -+ library::OperationDescription const &desc = operation->description(); -+ -+ auto &conv_desc = static_cast(desc); -+ -+ library::ConvFunctionalKey conv_key( -+ library::Provider::kReferenceHost, -+ conv_desc.conv_kind, -+ conv_desc.A.element, -+ conv_desc.A.layout, -+ conv_desc.B.element, -+ conv_desc.B.layout, -+ conv_desc.C.element, -+ conv_desc.C.layout, -+ conv_desc.tile_description.math_instruction.element_accumulator, -+ conv_desc.element_epilogue); -+ -+#if 0 // debug print to check which host refererence instance is selected -+ std::cout << conv_key << "\n"; -+#endif -+ -+ auto operators_it = Singleton::get().operation_table.conv3d_operations.find(conv_key); -+ -+ if(operators_it == Singleton::get().operation_table.conv3d_operations.end()) { -+ -+ results_.back().verification_map[library::Provider::kReferenceHost] = Disposition::kNotRun; -+ return true; -+ } -+ -+ // conv3d host reference minimum cc is 0 (CPU) and no iterator algorithm -+ library::ConvPreferenceKey preference_key(0, library::IteratorAlgorithmID::kNone); -+ auto cc_it = operators_it->second.find(preference_key); -+ -+ if(cc_it == operators_it->second.end()) { -+ results_.back().verification_map[library::Provider::kReferenceHost] = Disposition::kNotRun; -+ return true; -+ } -+ -+ // host refernce has only one instances in ConvOperationVectorMap -+ library::Operation const *reference_op = cc_it->second[0]; -+ -+ // -+ // Copy input tensors A, B, and C from device to host buffers -+ // -+ conv_workspace_.host_tensor_a.resize(conv_workspace_.A->bytes()); -+ conv_workspace_.host_tensor_b.resize(conv_workspace_.B->bytes()); -+ conv_workspace_.host_tensor_c.resize(conv_workspace_.C->bytes()); -+ conv_workspace_.A->copy_to_host(conv_workspace_.host_tensor_a.data()); -+ conv_workspace_.B->copy_to_host(conv_workspace_.host_tensor_b.data()); -+ conv_workspace_.C->copy_to_host(conv_workspace_.host_tensor_c.data()); -+ -+ // -+ // Initialize structure containing Conv3d arguments -+ // -+ conv_workspace_.arguments.A = conv_workspace_.host_tensor_a.data(); -+ conv_workspace_.arguments.B = conv_workspace_.host_tensor_b.data(); -+ conv_workspace_.arguments.C = conv_workspace_.host_tensor_c.data(); -+ conv_workspace_.arguments.D = conv_workspace_.host_tensor_c.data(); -+ conv_workspace_.arguments.alpha = problem_.alpha.data(); -+ conv_workspace_.arguments.beta = problem_.beta.data(); -+ conv_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ // -+ // Intialize host reference operation -+ // -+ std::vector host_workspace_reference_op; -+ -+ uint64_t workspace_size = reference_op->get_host_workspace_size(&conv_workspace_.configuration); -+ host_workspace_reference_op.resize(workspace_size, 0); -+ -+ reference_op->initialize( -+ &conv_workspace_.configuration, -+ host_workspace_reference_op.data()); -+ -+ // -+ // Run host reference operation -+ // -+ status = reference_op->run( -+ &conv_workspace_.arguments, -+ host_workspace_reference_op.data()); -+ -+ // Handle errors -+ if (status != Status::kSuccess) { -+ results_.back().verification_map[library::Provider::kReferenceHost] = Disposition::kNotVerified; -+ return true; -+ } -+ -+ // -+ // Copy host reference output to device memory for equality check on device -+ // -+ conv_workspace_.Reference->copy_from_host(conv_workspace_.arguments.D); -+ -+ // -+ // Verify results -+ // -+ results_.back().verification_map[library::Provider::kReferenceHost] = compare_tensors( -+ options, -+ *conv_workspace_.Computed, -+ *conv_workspace_.Reference, -+ conv_workspace_.Computed->batch_stride() -+ ); -+ -+ // Save workspace if incorrect -+ if (options.verification.save_workspace == SaveWorkspace::kIncorrect && -+ results_.back().verification_map[library::Provider::kReferenceHost] == Disposition::kIncorrect) { -+ -+ save_workspace( -+ device_context, -+ options, -+ static_cast(operation->description()), -+ library::Provider::kCUTLASS, -+ library::Provider::kReferenceHost); -+ } -+ -+ // Return true means continue profiling -+ return true; -+} -+ -+ -+/// Verifies CUTLASS against host reference -+bool Conv3dOperationProfiler::verify_with_device_reference_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ // TODO: verify cutlass conv3d against device reference -+ -+ // Return true means continue profiling -+ return true; -+} -+ -+/// Measures performance results -+bool Conv3dOperationProfiler::profile( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ -+ if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ -+ set_cutlass_operator_arguments_(); -+ -+ results_.back().status = profile_cutlass_( -+ results_.back().runtime, -+ options, -+ operation, -+ &conv_workspace_.arguments, -+ conv_workspace_.host_workspace.data(), -+ conv_workspace_.device_workspace.data() -+ ); -+ } -+ return true; -+ -+} -+ -+/// Updates the arguments structure for the CUTLASS operator based on -+/// the problem index. -+void Conv3dOperationProfiler::set_cutlass_operator_arguments_(int problem_idx) { -+ // Initialize structure containing Conv3d arguments -+ conv_workspace_.arguments.A = conv_workspace_.A->batch_data(problem_idx); -+ conv_workspace_.arguments.B = conv_workspace_.B->batch_data(problem_idx); -+ conv_workspace_.arguments.C = conv_workspace_.C->batch_data(problem_idx); -+ conv_workspace_.arguments.D = conv_workspace_.Computed->batch_data(problem_idx); -+ conv_workspace_.arguments.alpha = problem_.alpha.data(); -+ conv_workspace_.arguments.beta = problem_.beta.data(); -+ conv_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ if (conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ // update library::ConvArguments for parallel split-k reduction -+ conv_workspace_.arguments.D = conv_workspace_.device_workspace.data(); -+ conv_workspace_.arguments.alpha = problem_.alpha_one.data(); -+ conv_workspace_.arguments.beta = problem_.beta_zero.data(); -+ -+ /// intialize library::ReductionArguments -+ conv_workspace_.reduction_arguments.workspace = conv_workspace_.device_workspace.data(); -+ conv_workspace_.reduction_arguments.source = conv_workspace_.C->batch_data(problem_idx); -+ conv_workspace_.reduction_arguments.destination = conv_workspace_.Computed->batch_data(problem_idx); -+ conv_workspace_.reduction_arguments.alpha = problem_.alpha.data(); -+ conv_workspace_.reduction_arguments.beta = problem_.beta.data(); -+ conv_workspace_.reduction_arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ } -+} -+ -+/// Method to profile a CUTLASS Operation -+Status Conv3dOperationProfiler::profile_cutlass_( -+ double &runtime, -+ Options const &options, -+ library::Operation const *operation, -+ void *arguments, -+ void *host_workspace, -+ void *device_workspace) { -+ -+ GpuTimer timer; -+ -+ // initialize conv2d underlying operation to handle parallel reduction -+ library::Operation const* underlying_operation = operation; -+ -+ if(conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ if (!(underlying_operation = library::find_conv_operation_for_parallel_reduction(operation))) { -+ return Status::kErrorNotSupported; -+ } -+ } -+ -+ // -+ // Optional sleep to limit power consumption and thermals -+ // -+ -+ sleep(options.profiling.sleep_duration); -+ -+ // -+ // Warmup loop -+ // -+ -+ Status status; -+ -+ for (int iteration = 0; iteration < options.profiling.warmup_iterations; ++iteration) { -+ -+ // Setup rotating workspace -+ int workspace_idx = options.profiling.warmup_iterations + iteration; -+ int problem_idx = (workspace_idx % conv_workspace_.problem_count); -+ -+ set_cutlass_operator_arguments_(problem_idx); -+ -+ // Run underlying conv2d operation -+ status = underlying_operation->run( -+ arguments, -+ host_workspace, -+ device_workspace); -+ -+ // Run parallel reduction kernel for parallel split_k_mode -+ if (conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ -+ status = reduction_op_->run( -+ &conv_workspace_.reduction_arguments, -+ conv_workspace_.reduction_host_workspace.data(), -+ nullptr); -+ } -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ } -+ -+ // -+ // Initialize GPU timer -+ // -+ -+ timer.start(); -+ -+ // -+ // Profiling loop -+ // -+ -+ int Iterations = options.profiling.iterations; -+ -+ int iteration = 0; -+ for (; iteration < Iterations; ++iteration) { -+ -+ // Setup rotating workspace -+ int problem_idx = (iteration % conv_workspace_.problem_count); -+ -+ set_cutlass_operator_arguments_(problem_idx); -+ -+ // Run underlying conv2d operation -+ status = underlying_operation->run( -+ arguments, -+ host_workspace, -+ device_workspace); -+ -+ // Run parallel reduction kernel for parallel split_k_mode -+ if (conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { -+ status = reduction_op_->run( -+ &conv_workspace_.reduction_arguments, -+ conv_workspace_.reduction_host_workspace.data(), -+ nullptr); -+ } -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ } -+ -+ // -+ // Wait for completion -+ // -+ -+ timer.stop_and_wait(); -+ -+ // -+ // Update performance result -+ // -+ -+ runtime = timer.duration(iteration); -+ -+ return status; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#if CUTLASS_ENABLE_CUDNN -+ -+/// Verifies CUTLASS against cudnn reference -+bool Conv3dOperationProfiler::verify_with_cudnn_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ auto &conv_desc = static_cast(operation->description()); -+ -+ // -+ // Construct cudnn operators -+ // -+ -+ CudnnCreate handle; -+ cudnnStatus_t status = handle.get_cudnn_create_status(); -+ -+ if (status != CUDNN_STATUS_SUCCESS) { -+ -+ results_.back().verification_map[library::Provider::kCUDNN] = get_cutlass_disposition(status); -+ return true; -+ } -+ -+ // -+ // Initialize state -+ // -+ -+ // Initialize structure containing Conv2d arguments -+ conv_workspace_.arguments.A = conv_workspace_.A->data(); -+ conv_workspace_.arguments.B = conv_workspace_.B->data(); -+ conv_workspace_.arguments.D = conv_workspace_.Reference->data(); -+ conv_workspace_.arguments.alpha = problem_.alpha.data(); -+ conv_workspace_.arguments.beta = problem_.beta.data(); -+ conv_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ // cuDNN does not support four tensor arguments, so we copy the tensor C data into -+ // tensor D. -+ conv_workspace_.Reference->copy_from_device(conv_workspace_.C->data()); -+ conv_workspace_.arguments.C = conv_workspace_.arguments.D; -+ -+ try { -+ -+ // -+ // Construct dispatcher to cudnn operator -+ // -+ -+ detail::cudnnConvDispatcher conv_op( -+ conv_desc, -+ conv_workspace_.configuration, -+ conv_workspace_.arguments, -+ handle -+ ); -+ -+ if (conv_op.status != Status::kSuccess) { -+ if (conv_op.status == Status::kErrorNotSupported) { -+ results_.back().verification_map[library::Provider::kCUDNN] = Disposition::kNotSupported; -+ -+ } else { -+ results_.back().verification_map[library::Provider::kCUDNN] = Disposition::kFailed; -+ } -+ return true; -+ } -+ -+ -+ status = conv_op(handle); -+ -+ // Handle errors -+ if (status != CUDNN_STATUS_SUCCESS) { -+ -+ results_.back().verification_map[library::Provider::kCUDNN] = get_cutlass_disposition(status); -+ return true; -+ } -+ -+ // -+ // Verify results -+ // -+ -+ results_.back().verification_map[library::Provider::kCUDNN] = compare_tensors( -+ options, -+ *conv_workspace_.Computed, -+ *conv_workspace_.Reference -+ ); -+ -+ // Save workspace if incorrect -+ if (options.verification.save_workspace == SaveWorkspace::kIncorrect && -+ results_.back().verification_map[library::Provider::kCUDNN] == Disposition::kIncorrect) { -+ -+ save_workspace( -+ device_context, -+ options, -+ conv_desc, -+ library::Provider::kCUTLASS, -+ library::Provider::kCUDNN); -+ } -+ } -+ catch (...) { -+ results_.back().verification_map[library::Provider::kCUDNN] = Disposition::kFailed; -+ } -+ -+ // Return true means continue profiling -+ return true; -+ -+} -+ -+#endif // #if CUTLASS_ENABLE_CUDNN -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/profiler/src/conv3d_operation_profiler.h b/3rdparty/cutlass/tools/profiler/src/conv3d_operation_profiler.h -new file mode 100644 -index 0000000..aba832e ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/conv3d_operation_profiler.h -@@ -0,0 +1,447 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines profiling functionality for convolution -+ -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+#include -+#include -+ -+// CUTLASS Library includes -+#include "cutlass/library/library.h" -+#include "cutlass/library/util.h" -+#include "cutlass/library/handle.h" -+#include "cutlass/library/manifest.h" -+#include "cutlass/library/singleton.h" -+ -+// Profiler includes -+#include "options.h" -+#include "device_context.h" -+#include "operation_profiler.h" -+#include "performance_result.h" -+#include "problem_space.h" -+#include "reduction_operation_profiler.h" -+#if CUTLASS_ENABLE_CUDNN -+#include "cudnn_helpers.h" -+#endif //#if CUTLASS_ENABLE_CUDNN -+#include "debug.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Abstract base class for each math function -+class Conv3dOperationProfiler : public OperationProfiler { -+public: -+ -+ /// Problem structure obtained from problem space -+ struct Conv3dProblem { -+ -+ int64_t n, d, h, w, c, z, p, q, k, t, r, s; -+ int64_t pad_d, pad_h, pad_w; -+ int64_t stride_d, stride_h, stride_w; -+ int64_t dilation_d, dilation_h, dilation_w; -+ -+ std::vector alpha; -+ std::vector beta; -+ -+ library::SplitKMode split_k_mode; -+ int64_t split_k_slices; -+ -+ library::ConvModeID conv_mode; -+ -+ library::Provider eq_gemm_provider; -+ -+ // convolution with parallel interleaved reduction -+ // convolution epilogue (alpha, beta) = (1.0, 0.0) -+ // reduction epilogue (alpha, beta) = (Conv3dProblem::alpha, Conv3dProblem::beta) -+ std::vector alpha_one; -+ std::vector beta_zero; -+ -+ // -+ // Methods -+ // -+ -+ /// Total number of bytes loaded -+ int64_t bytes(library::ConvDescription const &operation_desc) const; -+ -+ /// Total number of flops computed -+ int64_t flops(library::ConvDescription const &operation_desc) const; -+ -+ /// Infers output size from theinput size, padding, stride, and dilation -+ void set_default_output_size() { -+ z = ((d + pad_d - t * dilation_d) / stride_d) + 1; -+ p = ((h + pad_h - r * dilation_h) / stride_h) + 1; -+ q = ((w + pad_w - s * dilation_w) / stride_w) + 1; -+ } -+ -+ // Returns equivalent gemm problem size for convolution -+ cutlass::gemm::GemmCoord eq_gemm_size(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: return cutlass::gemm::GemmCoord(int(n * z * p * q), int(k), int(t * r * s * c)); -+ case library::ConvKind::kDgrad: return cutlass::gemm::GemmCoord(int(n * d * h * w), int(c), int(t * r * s * k)); -+ case library::ConvKind::kWgrad: return cutlass::gemm::GemmCoord(int(k), int(t * r * s * c), int(n * z * p * q)); -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns extent for tensor A -+ std::vector extent_a(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: return {int(n), int(d), int(h), int(w), int(c)}; -+ case library::ConvKind::kDgrad: return {int(n), int(z), int(p), int(q), int(k)}; -+ case library::ConvKind::kWgrad: return {int(n), int(z), int(p), int(q), int(k)}; -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns extent for tensor B -+ std::vector extent_b(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: return {int(k), int(t), int(r), int(s), int(c)}; -+ case library::ConvKind::kDgrad: return {int(k), int(t), int(r), int(s), int(c)}; -+ case library::ConvKind::kWgrad: return {int(n), int(d), int(h), int(w), int(c)}; -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns extent for tensor C -+ std::vector extent_c(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: return {int(n), int(z), int(p), int(q), int(k)}; -+ case library::ConvKind::kDgrad: return {int(n), int(d), int(h), int(w), int(c)}; -+ case library::ConvKind::kWgrad: return {int(k), int(t), int(r), int(s), int(c)}; -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns layout for equivalent gemm matrix A -+ library::LayoutTypeID eq_gemm_layout_a(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: return library::LayoutTypeID::kRowMajor; // TN Gemm -+ case library::ConvKind::kDgrad: return library::LayoutTypeID::kRowMajor; // TT Gemm -+ case library::ConvKind::kWgrad: return library::LayoutTypeID::kColumnMajor; // NT Gemm -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns layout for equivalent gemm matrix B -+ library::LayoutTypeID eq_gemm_layout_b(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: return library::LayoutTypeID::kColumnMajor; // TN Gemm -+ case library::ConvKind::kDgrad: return library::LayoutTypeID::kRowMajor; // TT Gemm -+ case library::ConvKind::kWgrad: return library::LayoutTypeID::kRowMajor; // NT Gemm -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns layout for equivalent gemm matrix C -+ library::LayoutTypeID eq_gemm_layout_c(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ // Gemm operator assumes column-major output -+ case library::ConvKind::kFprop: -+ case library::ConvKind::kDgrad: -+ case library::ConvKind::kWgrad: return library::LayoutTypeID::kColumnMajor; -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns leading dimenstion for equivalent gemm matrix A -+ int64_t eq_gemm_lda(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: return eq_gemm_size(conv_kind).k(); -+ case library::ConvKind::kDgrad: return eq_gemm_size(conv_kind).k(); -+ case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).m(); -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns leading dimenstion for equivalent gemm matrix B -+ int64_t eq_gemm_ldb(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: return eq_gemm_size(conv_kind).k(); -+ case library::ConvKind::kDgrad: return eq_gemm_size(conv_kind).n(); -+ case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).n(); -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns leading dimenstion for equivalent gemm matrix C -+ int64_t eq_gemm_ldc(library::ConvKind const &conv_kind) const { -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: -+ case library::ConvKind::kDgrad: -+ case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).m(); -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ }; -+ -+ /// Workspace used -+ struct Conv2dWorkspace { -+ -+ /// Conv device allocations -+ DeviceAllocation *A; -+ DeviceAllocation *B; -+ DeviceAllocation *C; -+ DeviceAllocation *Computed; -+ DeviceAllocation *Reference; -+ -+ /// Library configuration and arguments for convolution operator -+ library::Conv3dConfiguration configuration; -+ library::ConvArguments arguments; -+ -+ /// Number of copies of the problem workspace which are visited sequentially during -+ /// profiling to avoid camping in the last level cache. -+ int problem_count; -+ -+ /// Buffer used for the cutlass conv2d operations' host workspace -+ std::vector host_workspace; -+ -+ /// Buffer used for the cutlass operations' device workspace -+ DeviceAllocation device_workspace; -+ -+ /// Library configuration and arguments for reduction operator -+ library::ReductionConfiguration reduction_configuration; -+ library::ReductionArguments reduction_arguments; -+ -+ /// Buffer used for the cutlass reduction operations' host workspace -+ std::vector reduction_host_workspace; -+ -+ /// Host data buffers for host reference operation -+ /// host buffer for tensor -+ std::vector host_tensor_a; -+ -+ /// host buffer for tensor b -+ std::vector host_tensor_b; -+ -+ /// host buffer for tensor c -+ std::vector host_tensor_c; -+ -+ -+ // -+ // Methods -+ // -+ -+ Conv2dWorkspace(): -+ A(nullptr), B(nullptr), C(nullptr), Computed(nullptr), Reference(nullptr) { } -+ -+ // Returns stride vector for tensor A -+ std::vector stride_a(library::ConvKind const &conv_kind) { -+ return { -+ configuration.layout_a(conv_kind).stride()[0], -+ configuration.layout_a(conv_kind).stride()[1], -+ configuration.layout_a(conv_kind).stride()[2], -+ configuration.layout_a(conv_kind).stride()[3] -+ }; -+ } -+ -+ // Returns stride vector for tensor B -+ std::vector stride_b(library::ConvKind const &conv_kind) { -+ -+ return { -+ configuration.layout_b(conv_kind).stride()[0], -+ configuration.layout_b(conv_kind).stride()[1], -+ configuration.layout_b(conv_kind).stride()[2], -+ configuration.layout_b(conv_kind).stride()[3] -+ }; -+ } -+ -+ // Returns stride vector for tensor C -+ std::vector stride_c(library::ConvKind const &conv_kind) { -+ -+ return { -+ configuration.layout_c(conv_kind).stride()[0], -+ configuration.layout_c(conv_kind).stride()[1], -+ configuration.layout_c(conv_kind).stride()[2], -+ configuration.layout_c(conv_kind).stride()[3] -+ }; -+ } -+ }; -+ -+protected: -+ -+ // -+ // Data members -+ // -+ -+ /// CONV problem obtained from problem space -+ Conv3dProblem problem_; -+ -+ /// Device memory allocations -+ Conv2dWorkspace conv_workspace_; -+ -+ /// CUTLASS parallel reduction operation to follow this* conv2d operation -+ library::Operation const *reduction_op_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ Conv3dOperationProfiler(Options const &options); -+ -+ /// Destructor -+ virtual ~Conv3dOperationProfiler(); -+ -+ /// Prints usage statement for the math function -+ virtual void print_usage(std::ostream &out) const; -+ -+ /// Prints examples -+ virtual void print_examples(std::ostream &out) const; -+ -+ /// Extracts the problem dimensions -+ virtual Status initialize_configuration( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Initializes workspace -+ virtual Status initialize_workspace( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Verifies CUTLASS against references -+ virtual bool verify_cutlass( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Measures performance results -+ virtual bool profile( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+protected: -+ -+ /// Updates the arguments structure for the CUTLASS operator based on -+ /// the problem index. -+ void set_cutlass_operator_arguments_(int problem_idx = 0); -+ -+ /// Method to profile an initialized CUTLASS operation -+ virtual Status profile_cutlass_( -+ double &runtime, -+ Options const &options, -+ library::Operation const *operation, -+ void *arguments, -+ void *host_workspace, -+ void *device_workspace); -+ -+ /// Initialize reduction problem dimenstions and library::Operation -+ bool initialize_reduction_configuration_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Initializes the performance result -+ void initialize_result_( -+ PerformanceResult &result, -+ Options const &options, -+ library::ConvDescription const &operation_desc, -+ ProblemSpace const &problem_space); -+ -+ /// Verifies CUTLASS against host reference -+ bool verify_with_host_reference_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Verifies CUTLASS against device reference -+ bool verify_with_device_reference_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+#if CUTLASS_ENABLE_CUDNN -+ -+ /// Verifies CUTLASS against cudnn reference -+ bool verify_with_cudnn_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+#endif //#if CUTLASS_ENABLE_CUDNN -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/profiler/src/cublas_helpers.cu b/3rdparty/cutlass/tools/profiler/src/cublas_helpers.cu -new file mode 100644 -index 0000000..5f7354c ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/cublas_helpers.cu -@@ -0,0 +1,1159 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Helper functions for mapping CUTLASS concepts to cuBLAS. -+*/ -+ -+#include -+ -+#if CUTLASS_ENABLE_CUBLAS -+#include "cublas_helpers.h" -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Converts a cuBLAS status to cutlass::Status -+Status get_cutlass_status(cublasStatus_t cublas) { -+ -+ switch (cublas) { -+ case CUBLAS_STATUS_SUCCESS: -+ return Status::kSuccess; -+ case CUBLAS_STATUS_INVALID_VALUE: -+ return Status::kErrorInvalidProblem; -+ case CUBLAS_STATUS_NOT_SUPPORTED: -+ return Status::kErrorNotSupported; -+ default: break; -+ } -+ return Status::kErrorInternal; -+} -+ -+/// Converts a cuBLASS status to cutlass::profiler::Disposition -+Disposition get_cutlass_disposition(cublasStatus_t cublas_status) { -+ -+ if (cublas_status == CUBLAS_STATUS_INVALID_VALUE) { -+ return Disposition::kInvalidProblem; -+ } -+ else if (cublas_status == CUBLAS_STATUS_NOT_SUPPORTED) { -+ return Disposition::kNotSupported; -+ } -+ return Disposition::kFailed; -+} -+ -+/// Maps a CUTLASS tensor layout to a cuBLAS transpose operation -+bool get_cublas_transpose_operation( -+ cublasOperation_t &operation, -+ library::LayoutTypeID layout, -+ library::ComplexTransform transform) { -+ -+ switch (layout) { -+ case library::LayoutTypeID::kColumnMajor: -+ if (transform == library::ComplexTransform::kNone) { -+ operation = CUBLAS_OP_N; -+ return true; -+ } -+ else { -+ return false; -+ } -+ break; -+ case library::LayoutTypeID::kRowMajor: -+ if (transform == library::ComplexTransform::kNone) { -+ operation = CUBLAS_OP_T; -+ return true; -+ } -+ else if (transform == library::ComplexTransform::kConjugate) { -+ operation = CUBLAS_OP_C; -+ return true; -+ } -+ break; -+ default: break; -+ } -+ -+ return false; -+} -+ -+/// Maps a CUTLASS numeric type to a cuBLAS data type enumeration -+bool get_cublas_datatype(cublasDataType_t &data_type, library::NumericTypeID element_type) { -+ switch (element_type) { -+ case library::NumericTypeID::kF16: -+ data_type = CUDA_R_16F; -+ return true; -+ -+ case library::NumericTypeID::kBF16: -+ break; -+ -+ case library::NumericTypeID::kTF32: -+ break; -+ -+ case library::NumericTypeID::kF32: -+ data_type = CUDA_R_32F; -+ return true; -+ -+ case library::NumericTypeID::kF64: -+ data_type = CUDA_R_64F; -+ return true; -+ -+ case library::NumericTypeID::kS4: -+ break; -+ -+ case library::NumericTypeID::kS8: -+ data_type = CUDA_R_8I; -+ return true; -+ -+ case library::NumericTypeID::kS16: -+ break; -+ -+ case library::NumericTypeID::kS32: -+ data_type = CUDA_R_32I; -+ return true; -+ -+ case library::NumericTypeID::kS64: -+ break; -+ -+ case library::NumericTypeID::kU4: -+ break; -+ -+ case library::NumericTypeID::kU8: -+ data_type = CUDA_R_8U; -+ return true; -+ -+ case library::NumericTypeID::kU16: -+ break; -+ -+ case library::NumericTypeID::kU32: -+ data_type = CUDA_R_32U; -+ return true; -+ -+ case library::NumericTypeID::kU64: -+ break; -+ -+ case library::NumericTypeID::kB1: -+ break; -+ -+ case library::NumericTypeID::kCF32: -+ data_type = CUDA_C_32F; -+ return true; -+ -+ case library::NumericTypeID::kCF64: -+ data_type = CUDA_C_64F; -+ return true; -+ -+ case library::NumericTypeID::kInvalid: -+ -+ default: -+ break; -+ } -+ -+ return false; -+} -+ -+/// Maps a cutlass::SideMode to cuBLAS side mode -+bool get_cublas_side_mode(cublasSideMode_t& side, SideMode side_mode) { -+ -+ switch (side_mode) { -+ case SideMode::kLeft: -+ side = CUBLAS_SIDE_LEFT; -+ return true; -+ case SideMode::kRight: -+ side = CUBLAS_SIDE_RIGHT; -+ return true; -+ default: break; -+ } -+ -+ return false; -+} -+ -+/// Maps a cutlass::FillMode to cuBLAS fill mode -+bool get_cublas_fill_mode(cublasFillMode_t& uplo, FillMode fill_mode) { -+ -+ switch (fill_mode) { -+ case FillMode::kLower: -+ uplo = CUBLAS_FILL_MODE_LOWER; -+ return true; -+ case FillMode::kUpper: -+ uplo = CUBLAS_FILL_MODE_UPPER; -+ return true; -+ default: break; -+ } -+ -+ return false; -+} -+ -+/// Maps a cutlass::DiagType to cuBLAS diag type -+bool get_cublas_diag_type(cublasDiagType_t& diag, DiagType diag_type) { -+ -+ switch (diag_type) { -+ case DiagType::kNonUnit: -+ diag = CUBLAS_DIAG_NON_UNIT; -+ return true; -+ case DiagType::kUnit: -+ diag = CUBLAS_DIAG_UNIT; -+ return true; -+ default: break; -+ } -+ -+ return false; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Gets the cublas algorithm given threadblock tile dimensions and math opcode class -+cublasGemmAlgo_t get_cublas_gemm_algo(int cta_m, int cta_n, int cta_k, library::OpcodeClassID opcode_class) { -+ return (opcode_class == library::OpcodeClassID::kSimt ? -+ CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Returns a status if cuBLAS can satisfy a particular GEMM description -+Status cublas_satisfies(library::GemmDescription const &desc) { -+ auto const &math_instruction = desc.tile_description.math_instruction; -+ -+ if (math_instruction.element_accumulator == library::NumericTypeID::kS32 && -+ math_instruction.opcode_class == library::OpcodeClassID::kTensorOp) { -+ -+ return Status::kErrorNotSupported; -+ } -+ -+ // output type S4 and S8 not supported in cuBLAS -+ if (desc.C.element == library::NumericTypeID::kS4 || -+ desc.C.element == library::NumericTypeID::kS8) { -+ -+ return Status::kErrorNotSupported; -+ } -+ -+ return Status::kSuccess; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+cublasGemmExDispatcher::cublasGemmExDispatcher( -+ library::GemmDescription const &op_desc, -+ library::GemmUniversalConfiguration configuration_, -+ library::GemmUniversalArguments arguments_, -+ cublasGemmAlgo_t algorithm -+): -+ configuration(configuration_), arguments(arguments_), algo(algorithm), status(Status::kSuccess) { -+ -+ bool good = true; -+ -+ good = (good && get_cublas_transpose_operation(trans_A, op_desc.A.layout, op_desc.transform_A)); -+ good = (good && get_cublas_transpose_operation(trans_B, op_desc.B.layout, op_desc.transform_B)); -+ good = (good && get_cublas_datatype(data_type_A, op_desc.A.element)); -+ good = (good && get_cublas_datatype(data_type_B, op_desc.B.element)); -+ good = (good && get_cublas_datatype(data_type_C, op_desc.C.element)); -+ -+ good = (good && get_cublas_datatype( -+ compute_data_type, -+ op_desc.tile_description.math_instruction.element_accumulator)); -+ -+ // cuBLAS introduces a separate cublasComputeType enumerant to more precisely describe -+ // internal numerical data types used in the computation. -+#if (__CUDACC_VER_MAJOR__ >= 11) -+ library::OpcodeClassID const & opcode_class = -+ op_desc.tile_description.math_instruction.opcode_class; -+ -+ if (good && -+ op_desc.A.element == library::NumericTypeID::kF32 && -+ op_desc.B.element == library::NumericTypeID::kF32 && -+ opcode_class == library::OpcodeClassID::kTensorOp) { -+ -+ compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; -+ } -+ else if (good) { -+ bool const isPedantic = false; -+ switch (compute_data_type) { -+ case CUDA_R_32F: -+ case CUDA_C_32F: -+ compute_type = isPedantic ? CUBLAS_COMPUTE_32F_PEDANTIC : CUBLAS_COMPUTE_32F; -+ break; -+ case CUDA_R_64F: -+ case CUDA_C_64F: -+ compute_type = isPedantic ? CUBLAS_COMPUTE_64F_PEDANTIC : CUBLAS_COMPUTE_64F; -+ break; -+ case CUDA_R_16F: -+ compute_type = isPedantic ? CUBLAS_COMPUTE_16F_PEDANTIC : CUBLAS_COMPUTE_16F; -+ break; -+ case CUDA_R_32I: -+ compute_type = isPedantic ? CUBLAS_COMPUTE_32I_PEDANTIC : CUBLAS_COMPUTE_32I; -+ break; -+ default: -+ good = false; -+ break; -+ } -+ } -+#endif // __CUDACC_VER_MAJOR__ >= 11 -+ -+ if (!good) { -+ status = Status::kErrorNotSupported; -+ } -+} -+ -+/// Executes GEMM using these arguments -+cublasStatus_t cublasGemmExDispatcher::operator()(cublasHandle_t handle) { -+ -+ if (configuration.mode == library::GemmUniversalMode::kBatched) { -+ return cublasGemmStridedBatchedEx( -+ handle, -+ trans_A, -+ trans_B, -+ configuration.problem_size.m(), -+ configuration.problem_size.n(), -+ configuration.problem_size.k(), -+ arguments.alpha, -+ arguments.A, -+ data_type_A, -+ int(configuration.lda), -+ arguments.batch_stride_A, -+ arguments.B, -+ data_type_B, -+ int(configuration.ldb), -+ arguments.batch_stride_B, -+ arguments.beta, -+ arguments.D, -+ data_type_C, -+ int(configuration.ldc), -+ arguments.batch_stride_C, -+ configuration.batch_count, -+ #if (__CUDACC_VER_MAJOR__ >= 11) -+ compute_type, -+ #else -+ compute_data_type, -+ #endif -+ algo -+ ); -+ } -+ else { -+ return cublasGemmEx( -+ handle, -+ trans_A, -+ trans_B, -+ configuration.problem_size.m(), -+ configuration.problem_size.n(), -+ configuration.problem_size.k(), -+ arguments.alpha, -+ arguments.A, -+ data_type_A, -+ int(configuration.lda), -+ arguments.B, -+ data_type_B, -+ int(configuration.ldb), -+ arguments.beta, -+ arguments.D, -+ data_type_C, -+ int(configuration.ldc), -+ #if (__CUDACC_VER_MAJOR__ >= 11) -+ compute_type, -+ #else -+ compute_data_type, -+ #endif -+ algo -+ ); -+ } -+} -+ -+} // namespace detail -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Returns a status if cuBLAS can satisfy a particular RankK description -+Status cublas_satisfies(library::RankKDescription const &desc) { -+ auto const &math_instruction = desc.tile_description.math_instruction; -+ -+ if (math_instruction.element_accumulator == library::NumericTypeID::kS32 && -+ math_instruction.opcode_class == library::OpcodeClassID::kTensorOp) { -+ -+ return Status::kErrorNotSupported; -+ } -+ -+ // output type S4 and S8 not supported in cuBLAS -+ if (desc.C.element == library::NumericTypeID::kS4 || -+ desc.C.element == library::NumericTypeID::kS8) { -+ -+ return Status::kErrorNotSupported; -+ } -+ -+ // input type BF16 and TF32 not supported in cuBLAS -+ if (desc.A.element == library::NumericTypeID::kBF16 || -+ desc.A.element == library::NumericTypeID::kTF32) { -+ -+ return Status::kErrorNotSupported; -+ } -+ -+ return Status::kSuccess; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+cublasRankKDispatcher::cublasRankKDispatcher( -+ library::RankKDescription const &op_desc, -+ library::RankKConfiguration configuration_, -+ library::RankKArguments arguments_ -+): -+ configuration(configuration_), arguments(arguments_), status(Status::kSuccess) { -+ -+ blas_mode = op_desc.blas_mode; -+ num_ranks = op_desc.num_ranks; -+ -+ bool good = true; -+ -+ good = (good && get_cublas_transpose_operation(trans_A, op_desc.A.layout, op_desc.transform_A)); -+ good = (good && get_cublas_fill_mode(uplo, op_desc.fill_mode)); -+ good = (good && get_cublas_datatype(data_type_A, op_desc.A.element)); -+ good = (good && get_cublas_datatype(data_type_C, op_desc.C.element)); -+ -+ good = (good && get_cublas_datatype( -+ compute_data_type, -+ op_desc.tile_description.math_instruction.element_accumulator)); -+ -+ // cuBLAS introduces a separate cublasComputeType enumerant to more precisely describe -+ // internal numerical data types used in the computation. -+#if (__CUDACC_VER_MAJOR__ >= 11) -+ library::OpcodeClassID const & opcode_class = -+ op_desc.tile_description.math_instruction.opcode_class; -+ -+ if (good && -+ op_desc.A.element == library::NumericTypeID::kF32 && -+ opcode_class == library::OpcodeClassID::kTensorOp) { -+ -+ compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; -+ } -+ else if (good) { -+ bool const isPedantic = false; -+ switch (compute_data_type) { -+ case CUDA_R_32F: -+ case CUDA_C_32F: -+ compute_type = isPedantic ? CUBLAS_COMPUTE_32F_PEDANTIC : CUBLAS_COMPUTE_32F; -+ break; -+ case CUDA_R_64F: -+ case CUDA_C_64F: -+ compute_type = isPedantic ? CUBLAS_COMPUTE_64F_PEDANTIC : CUBLAS_COMPUTE_64F; -+ break; -+ case CUDA_R_16F: -+ compute_type = isPedantic ? CUBLAS_COMPUTE_16F_PEDANTIC : CUBLAS_COMPUTE_16F; -+ break; -+ case CUDA_R_32I: -+ compute_type = isPedantic ? CUBLAS_COMPUTE_32I_PEDANTIC : CUBLAS_COMPUTE_32I; -+ break; -+ default: -+ good = false; -+ break; -+ } -+ } -+#endif // __CUDACC_VER_MAJOR__ >= 11 -+ -+ if (!good) { -+ status = Status::kErrorNotSupported; -+ } -+} -+ -+/// Executes RankK using these arguments -+cublasStatus_t cublasRankKDispatcher::operator()(cublasHandle_t handle) { -+ -+ // SYRK and HERK -+ if (num_ranks == 1) { -+ if (data_type_A == data_type_C && data_type_A == CUDA_R_64F) { -+ return cublasDsyrk( -+ handle, -+ uplo, -+ trans_A, -+ configuration.problem_size.n(), -+ configuration.problem_size.k(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.beta), -+ static_cast(arguments.D), -+ int(configuration.ldc) -+ ); -+ } else if (data_type_A == data_type_C && data_type_A == CUDA_R_32F) { -+ -+ #if (__CUDACC_VER_MAJOR__ >= 11) -+ if (cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH) != CUBLAS_STATUS_SUCCESS) -+ return CUBLAS_STATUS_NOT_SUPPORTED; -+ #endif -+ -+ return cublasSsyrk( -+ handle, -+ uplo, -+ trans_A, -+ configuration.problem_size.n(), -+ configuration.problem_size.k(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.beta), -+ static_cast(arguments.D), -+ int(configuration.ldc) -+ ); -+ } else if (data_type_A == data_type_C && data_type_A == CUDA_C_64F) { -+ -+ if (blas_mode == BlasMode::kHermitian) { -+ return cublasZherk( -+ handle, -+ uplo, -+ trans_A, -+ configuration.problem_size.n(), -+ configuration.problem_size.k(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.beta), -+ static_cast(arguments.D), -+ int(configuration.ldc) -+ ); -+ } -+ else { -+ return cublasZsyrk( -+ handle, -+ uplo, -+ trans_A, -+ configuration.problem_size.n(), -+ configuration.problem_size.k(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.beta), -+ static_cast(arguments.D), -+ int(configuration.ldc) -+ ); -+ } -+ -+ } else if (data_type_A == data_type_C && data_type_A == CUDA_C_32F) { -+ -+ #if (__CUDACC_VER_MAJOR__ >= 11) -+ if (cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH) != CUBLAS_STATUS_SUCCESS) -+ return CUBLAS_STATUS_NOT_SUPPORTED; -+ #endif -+ -+ if (blas_mode == BlasMode::kHermitian) { -+ return cublasCherk( -+ handle, -+ uplo, -+ trans_A, -+ configuration.problem_size.n(), -+ configuration.problem_size.k(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.beta), -+ static_cast(arguments.D), -+ int(configuration.ldc) -+ ); -+ } -+ else { -+ return cublasCsyrk( -+ handle, -+ uplo, -+ trans_A, -+ configuration.problem_size.n(), -+ configuration.problem_size.k(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.beta), -+ static_cast(arguments.D), -+ int(configuration.ldc) -+ ); -+ } -+ } else { -+ return CUBLAS_STATUS_NOT_SUPPORTED; -+ } -+ } -+ -+ // SYR2K and HER2K -+ else if (num_ranks == 2) { -+ if (data_type_A == data_type_C && data_type_A == CUDA_R_64F) { -+ return cublasDsyr2k( -+ handle, -+ uplo, -+ trans_A, -+ configuration.problem_size.n(), -+ configuration.problem_size.k(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.B), -+ int(configuration.ldb), -+ static_cast(arguments.beta), -+ static_cast(arguments.D), -+ int(configuration.ldc) -+ ); -+ } else if (data_type_A == data_type_C && data_type_A == CUDA_R_32F) { -+ -+ #if (__CUDACC_VER_MAJOR__ >= 11) -+ if (cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH) != CUBLAS_STATUS_SUCCESS) -+ return CUBLAS_STATUS_NOT_SUPPORTED; -+ #endif -+ -+ return cublasSsyr2k( -+ handle, -+ uplo, -+ trans_A, -+ configuration.problem_size.n(), -+ configuration.problem_size.k(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.B), -+ int(configuration.ldb), -+ static_cast(arguments.beta), -+ static_cast(arguments.D), -+ int(configuration.ldc) -+ ); -+ } else if (data_type_A == data_type_C && data_type_A == CUDA_C_64F) { -+ -+ if (blas_mode == BlasMode::kHermitian) { -+ return cublasZher2k( -+ handle, -+ uplo, -+ trans_A, -+ configuration.problem_size.n(), -+ configuration.problem_size.k(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.B), -+ int(configuration.ldb), -+ static_cast(arguments.beta), -+ static_cast(arguments.D), -+ int(configuration.ldc) -+ ); -+ } -+ else { -+ return cublasZsyr2k( -+ handle, -+ uplo, -+ trans_A, -+ configuration.problem_size.n(), -+ configuration.problem_size.k(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.B), -+ int(configuration.ldb), -+ static_cast(arguments.beta), -+ static_cast(arguments.D), -+ int(configuration.ldc) -+ ); -+ } -+ -+ } else if (data_type_A == data_type_C && data_type_A == CUDA_C_32F) { -+ -+ #if (__CUDACC_VER_MAJOR__ >= 11) -+ if (cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH) != CUBLAS_STATUS_SUCCESS) -+ return CUBLAS_STATUS_NOT_SUPPORTED; -+ #endif -+ -+ if (blas_mode == BlasMode::kHermitian) { -+ return cublasCher2k( -+ handle, -+ uplo, -+ trans_A, -+ configuration.problem_size.n(), -+ configuration.problem_size.k(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.B), -+ int(configuration.ldb), -+ static_cast(arguments.beta), -+ static_cast(arguments.D), -+ int(configuration.ldc) -+ ); -+ } -+ else { -+ return cublasCsyr2k( -+ handle, -+ uplo, -+ trans_A, -+ configuration.problem_size.n(), -+ configuration.problem_size.k(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.B), -+ int(configuration.ldb), -+ static_cast(arguments.beta), -+ static_cast(arguments.D), -+ int(configuration.ldc) -+ ); -+ } -+ } else { -+ return CUBLAS_STATUS_NOT_SUPPORTED; -+ } -+ } -+ else { -+ return CUBLAS_STATUS_NOT_SUPPORTED; -+ } -+} -+ -+} // namespace detail -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Returns a status if cuBLAS can satisfy a particular TRMM description -+Status cublas_satisfies(library::TrmmDescription const &desc) { -+ auto const &math_instruction = desc.tile_description.math_instruction; -+ -+ if (math_instruction.element_accumulator == library::NumericTypeID::kS32 && -+ math_instruction.opcode_class == library::OpcodeClassID::kTensorOp) { -+ -+ return Status::kErrorNotSupported; -+ } -+ -+ // output type S4 and S8 not supported in cuBLAS -+ if (desc.D.element == library::NumericTypeID::kS4 || -+ desc.D.element == library::NumericTypeID::kS8) { -+ -+ return Status::kErrorNotSupported; -+ } -+ -+ // input type BF16 and TF32 not supported in cuBLAS -+ if (desc.A.element == library::NumericTypeID::kBF16 || -+ desc.A.element == library::NumericTypeID::kTF32) { -+ -+ return Status::kErrorNotSupported; -+ } -+ -+ return Status::kSuccess; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+cublasTrmmDispatcher::cublasTrmmDispatcher( -+ library::TrmmDescription const &op_desc, -+ library::TrmmConfiguration configuration_, -+ library::TrmmArguments arguments_ -+): -+ configuration(configuration_), arguments(arguments_), status(Status::kSuccess) { -+ -+ bool good = true; -+ -+ good = (good && get_cublas_transpose_operation(trans_A, op_desc.A.layout, op_desc.transform_A)); -+ good = (good && get_cublas_side_mode(side, op_desc.side_mode)); -+ good = (good && get_cublas_fill_mode(uplo, op_desc.fill_mode)); -+ good = (good && get_cublas_diag_type(diag, op_desc.diag_type)); -+ good = (good && get_cublas_datatype(data_type_A, op_desc.A.element)); -+ good = (good && get_cublas_datatype(data_type_B, op_desc.B.element)); -+ good = (good && get_cublas_datatype(data_type_D, op_desc.D.element)); -+ -+ // if A is Transposed, then for cuBLAS that is inverted Fill Mode. -+ if (trans_A == CUBLAS_OP_T || trans_A == CUBLAS_OP_C) { -+ if (uplo == CUBLAS_FILL_MODE_LOWER) -+ uplo = CUBLAS_FILL_MODE_UPPER; -+ else -+ uplo = CUBLAS_FILL_MODE_LOWER; -+ } -+ -+ good = (good && get_cublas_datatype( -+ compute_data_type, -+ op_desc.tile_description.math_instruction.element_accumulator)); -+ -+ // cuBLAS introduces a separate cublasComputeType enumerant to more precisely describe -+ // internal numerical data types used in the computation. -+#if (__CUDACC_VER_MAJOR__ >= 11) -+ library::OpcodeClassID const & opcode_class = -+ op_desc.tile_description.math_instruction.opcode_class; -+ -+ if (good && -+ op_desc.A.element == library::NumericTypeID::kF32 && -+ opcode_class == library::OpcodeClassID::kTensorOp) { -+ -+ compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; -+ } -+ else if (good) { -+ bool const isPedantic = false; -+ switch (compute_data_type) { -+ case CUDA_R_32F: -+ case CUDA_C_32F: -+ compute_type = isPedantic ? CUBLAS_COMPUTE_32F_PEDANTIC : CUBLAS_COMPUTE_32F; -+ break; -+ case CUDA_R_64F: -+ case CUDA_C_64F: -+ compute_type = isPedantic ? CUBLAS_COMPUTE_64F_PEDANTIC : CUBLAS_COMPUTE_64F; -+ break; -+ case CUDA_R_16F: -+ compute_type = isPedantic ? CUBLAS_COMPUTE_16F_PEDANTIC : CUBLAS_COMPUTE_16F; -+ break; -+ case CUDA_R_32I: -+ compute_type = isPedantic ? CUBLAS_COMPUTE_32I_PEDANTIC : CUBLAS_COMPUTE_32I; -+ break; -+ default: -+ good = false; -+ break; -+ } -+ } -+#endif // __CUDACC_VER_MAJOR__ >= 11 -+ -+ if (!good) { -+ status = Status::kErrorNotSupported; -+ } -+} -+ -+/// Executes TRMM using these arguments -+cublasStatus_t cublasTrmmDispatcher::operator()(cublasHandle_t handle) { -+ -+ if (data_type_A == data_type_D && data_type_A == CUDA_R_64F) { -+ return cublasDtrmm( -+ handle, -+ side, -+ uplo, -+ trans_A, -+ diag, -+ configuration.problem_size.m(), -+ configuration.problem_size.n(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.B), -+ int(configuration.ldb), -+ static_cast(arguments.D), -+ int(configuration.ldd) -+ ); -+ } else if (data_type_A == data_type_D && data_type_A == CUDA_R_32F) { -+ -+#if (__CUDACC_VER_MAJOR__ >= 11) -+ if (cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH) != CUBLAS_STATUS_SUCCESS) -+ return CUBLAS_STATUS_NOT_SUPPORTED; -+#endif -+ -+ return cublasStrmm( -+ handle, -+ side, -+ uplo, -+ trans_A, -+ diag, -+ configuration.problem_size.m(), -+ configuration.problem_size.n(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.B), -+ int(configuration.ldb), -+ static_cast(arguments.D), -+ int(configuration.ldd) -+ ); -+ } else if (data_type_A == data_type_D && data_type_A == CUDA_C_64F) { -+ return cublasZtrmm( -+ handle, -+ side, -+ uplo, -+ trans_A, -+ diag, -+ configuration.problem_size.m(), -+ configuration.problem_size.n(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.B), -+ int(configuration.ldb), -+ static_cast(arguments.D), -+ int(configuration.ldd) -+ ); -+ } else if (data_type_A == data_type_D && data_type_A == CUDA_C_32F) { -+ -+#if (__CUDACC_VER_MAJOR__ >= 11) -+ if (cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH) != CUBLAS_STATUS_SUCCESS) -+ return CUBLAS_STATUS_NOT_SUPPORTED; -+#endif -+ -+ return cublasCtrmm( -+ handle, -+ side, -+ uplo, -+ trans_A, -+ diag, -+ configuration.problem_size.m(), -+ configuration.problem_size.n(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.B), -+ int(configuration.ldb), -+ static_cast(arguments.D), -+ int(configuration.ldd) -+ ); -+ } else { -+ return CUBLAS_STATUS_NOT_SUPPORTED; -+ } -+} -+ -+} // namespace detail -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Returns a status if cuBLAS can satisfy a particular Symm description -+Status cublas_satisfies(library::SymmDescription const &desc) { -+ auto const &math_instruction = desc.tile_description.math_instruction; -+ -+ if (math_instruction.element_accumulator == library::NumericTypeID::kS32 && -+ math_instruction.opcode_class == library::OpcodeClassID::kTensorOp) { -+ -+ return Status::kErrorNotSupported; -+ } -+ -+ // output type S4 and S8 not supported in cuBLAS -+ if (desc.C.element == library::NumericTypeID::kS4 || -+ desc.C.element == library::NumericTypeID::kS8) { -+ -+ return Status::kErrorNotSupported; -+ } -+ -+ // input type BF16 and TF32 not supported in cuBLAS -+ if (desc.A.element == library::NumericTypeID::kBF16 || -+ desc.A.element == library::NumericTypeID::kTF32) { -+ -+ return Status::kErrorNotSupported; -+ } -+ -+ // input type BF16 and TF32 not supported in cuBLAS -+ if (desc.B.element == library::NumericTypeID::kBF16 || -+ desc.B.element == library::NumericTypeID::kTF32) { -+ -+ return Status::kErrorNotSupported; -+ } -+ -+ // only column major layout is supported in cuBLAS -+ if (desc.A.layout != library::LayoutTypeID::kColumnMajor || -+ desc.transform_A != library::ComplexTransform::kNone) { -+ -+ return Status::kErrorNotSupported; -+} -+ -+ return Status::kSuccess; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+cublasSymmDispatcher::cublasSymmDispatcher( -+ library::SymmDescription const &op_desc, -+ library::SymmConfiguration configuration_, -+ library::SymmArguments arguments_ -+): -+ configuration(configuration_), arguments(arguments_), status(Status::kSuccess) { -+ -+ blas_mode = op_desc.blas_mode; -+ -+ bool good = true; -+ -+ good = (good && get_cublas_side_mode(side, op_desc.side_mode)); -+ good = (good && get_cublas_fill_mode(uplo, op_desc.fill_mode)); -+ good = (good && get_cublas_datatype(data_type_A, op_desc.A.element)); -+ good = (good && get_cublas_datatype(data_type_C, op_desc.C.element)); -+ -+ good = (good && get_cublas_datatype( -+ compute_data_type, -+ op_desc.tile_description.math_instruction.element_accumulator)); -+ -+ // cuBLAS introduces a separate cublasComputeType enumerant to more precisely describe -+ // internal numerical data types used in the computation. -+#if (__CUDACC_VER_MAJOR__ >= 11) -+ library::OpcodeClassID const & opcode_class = -+ op_desc.tile_description.math_instruction.opcode_class; -+ -+ if (good && -+ op_desc.A.element == library::NumericTypeID::kF32 && -+ opcode_class == library::OpcodeClassID::kTensorOp) { -+ -+ compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; -+ } -+ else if (good) { -+ bool const isPedantic = false; -+ switch (compute_data_type) { -+ case CUDA_R_32F: -+ case CUDA_C_32F: -+ compute_type = isPedantic ? CUBLAS_COMPUTE_32F_PEDANTIC : CUBLAS_COMPUTE_32F; -+ break; -+ case CUDA_R_64F: -+ case CUDA_C_64F: -+ compute_type = isPedantic ? CUBLAS_COMPUTE_64F_PEDANTIC : CUBLAS_COMPUTE_64F; -+ break; -+ case CUDA_R_16F: -+ compute_type = isPedantic ? CUBLAS_COMPUTE_16F_PEDANTIC : CUBLAS_COMPUTE_16F; -+ break; -+ case CUDA_R_32I: -+ compute_type = isPedantic ? CUBLAS_COMPUTE_32I_PEDANTIC : CUBLAS_COMPUTE_32I; -+ break; -+ default: -+ good = false; -+ break; -+ } -+ } -+#endif // __CUDACC_VER_MAJOR__ >= 11 -+ -+ if (!good) { -+ status = Status::kErrorNotSupported; -+ } -+} -+ -+/// Executes Symm using these arguments -+cublasStatus_t cublasSymmDispatcher::operator()(cublasHandle_t handle) { -+ -+ // SYMM and HEMM -+ if (data_type_A == data_type_C && data_type_A == CUDA_R_64F) { -+ return cublasDsymm( -+ handle, -+ side, -+ uplo, -+ configuration.problem_size.m(), -+ configuration.problem_size.n(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.B), -+ int(configuration.ldb), -+ static_cast(arguments.beta), -+ static_cast(arguments.D), -+ int(configuration.ldc) -+ ); -+ } else if (data_type_A == data_type_C && data_type_A == CUDA_R_32F) { -+ -+#if (__CUDACC_VER_MAJOR__ >= 11) -+ if (cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH) != CUBLAS_STATUS_SUCCESS) -+ return CUBLAS_STATUS_NOT_SUPPORTED; -+#endif -+ -+ return cublasSsymm( -+ handle, -+ side, -+ uplo, -+ configuration.problem_size.m(), -+ configuration.problem_size.n(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.B), -+ int(configuration.ldb), -+ static_cast(arguments.beta), -+ static_cast(arguments.D), -+ int(configuration.ldc) -+ ); -+ } else if (data_type_A == data_type_C && data_type_A == CUDA_C_64F) { -+ -+ if (blas_mode == BlasMode::kHermitian) { -+ return cublasZhemm( -+ handle, -+ side, -+ uplo, -+ configuration.problem_size.m(), -+ configuration.problem_size.n(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.B), -+ int(configuration.ldb), -+ static_cast(arguments.beta), -+ static_cast(arguments.D), -+ int(configuration.ldc) -+ ); -+ } -+ else { -+ return cublasZsymm( -+ handle, -+ side, -+ uplo, -+ configuration.problem_size.m(), -+ configuration.problem_size.n(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.B), -+ int(configuration.ldb), -+ static_cast(arguments.beta), -+ static_cast(arguments.D), -+ int(configuration.ldc) -+ ); -+ } -+ -+ } else if (data_type_A == data_type_C && data_type_A == CUDA_C_32F) { -+ -+#if (__CUDACC_VER_MAJOR__ >= 11) -+ if (cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH) != CUBLAS_STATUS_SUCCESS) -+ return CUBLAS_STATUS_NOT_SUPPORTED; -+#endif -+ -+ if (blas_mode == BlasMode::kHermitian) { -+ return cublasChemm( -+ handle, -+ side, -+ uplo, -+ configuration.problem_size.m(), -+ configuration.problem_size.n(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.B), -+ int(configuration.ldb), -+ static_cast(arguments.beta), -+ static_cast(arguments.D), -+ int(configuration.ldc) -+ ); -+ } -+ else { -+ return cublasCsymm( -+ handle, -+ side, -+ uplo, -+ configuration.problem_size.m(), -+ configuration.problem_size.n(), -+ static_cast(arguments.alpha), -+ static_cast(arguments.A), -+ int(configuration.lda), -+ static_cast(arguments.B), -+ int(configuration.ldb), -+ static_cast(arguments.beta), -+ static_cast(arguments.D), -+ int(configuration.ldc) -+ ); -+ } -+ } else { -+ return CUBLAS_STATUS_NOT_SUPPORTED; -+ } -+} -+ -+} // namespace detail -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+#endif // #if CUTLASS_ENABLE_CUBLAS -diff --git a/3rdparty/cutlass/tools/profiler/src/cublas_helpers.h b/3rdparty/cutlass/tools/profiler/src/cublas_helpers.h -new file mode 100644 -index 0000000..8c36fb7 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/cublas_helpers.h -@@ -0,0 +1,358 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Helper functions for mapping CUTLASS concepts to cuBLAS. -+*/ -+ -+#pragma once -+ -+#if CUTLASS_ENABLE_CUBLAS -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/library/library.h" -+#include "cutlass/library/util.h" -+#include "cutlass/blas3.h" -+ -+#include "options.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Converts a cuBLAS status to cutlass::Status -+Status get_cutlass_status(cublasStatus_t cublas); -+ -+/// Converts a cuBLASS status to cutlass::profiler::Disposition -+Disposition get_cutlass_disposition(cublasStatus_t cublas_status); -+ -+/// Maps a CUTLASS tensor layout to a cuBLAS transpose operation -+bool get_cublas_transpose_operation( -+ cublasOperation_t &operation, -+ library::LayoutTypeID layout, -+ library::ComplexTransform transform = library::ComplexTransform::kNone); -+ -+/// Maps a CUTLASS numeric type to a cuBLAS data type enumeration -+bool get_cublas_datatype(cublasDataType_t &data_type, library::NumericTypeID element_type); -+ -+/// Gets the cublas algorithm given threadblock tile dimensions and math opcode class -+cublasGemmAlgo_t get_cublas_gemm_algo( -+ int cta_m, -+ int cta_n, -+ int cta_k, -+ library::OpcodeClassID opcode_class); -+ -+/// Returns a status if cuBLAS can satisfy a particular GEMM description -+Status cublas_satisfies(library::GemmDescription const &desc); -+ -+/// Returns a status if cuBLAS can satisfy a particular RankK description -+Status cublas_satisfies(library::RankKDescription const &desc); -+ -+/// Returns a status if cuBLAS can satisfy a particular TRMM description -+Status cublas_satisfies(library::TrmmDescription const &desc); -+ -+/// Returns a status if cuBLAS can satisfy a particular SYMM/HEMM description -+Status cublas_satisfies(library::SymmDescription const &desc); -+ -+/// This is a helper class to create cublasHandle_t automatically on CublasCreate object creation and -+/// to destroy cublasHandle_t on CublasCreate object destruction. -+/// Additionaly, it provides implicit cast from CublasCreate's object to cublasHandle_t's object -+class CublasCreate { -+private: -+ cublasHandle_t handle; -+ cublasStatus_t status; -+ -+public: -+ CublasCreate() { -+ status = cublasCreate(&handle); -+ } -+ -+ ~CublasCreate() { -+ cublasDestroy(handle); -+ } -+ -+ /// Implicit cast CublasCreate object to cublasHandle_t -+ operator cublasHandle_t() const { return handle; } -+ -+ /// returns cublasStatus_t for handle creation -+ cublasStatus_t get_cublas_create_status() { return status; } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+/// Selects one or more cuBLAS algorithms. -+static void select_cublas_algorithms( -+ std::vector &algorithms, -+ Options const &options, -+ library::GemmDescription const &op_desc) { -+ -+ library::OpcodeClassID const & opcode_class = -+ op_desc.tile_description.math_instruction.opcode_class; -+ -+ switch (options.library.algorithm_mode) { -+ case AlgorithmMode::kMatching: -+ { -+ algorithms.push_back(get_cublas_gemm_algo( -+ op_desc.tile_description.threadblock_shape.m(), -+ op_desc.tile_description.threadblock_shape.n(), -+ op_desc.tile_description.threadblock_shape.k(), -+ opcode_class)); -+ break; -+ } -+ -+ case AlgorithmMode::kBest: -+ { -+ // Choose first enumerated mode. If none are enumerated, choose based on opcode class -+ // and evaluate all of them. -+ -+ if (options.library.algorithms.empty()) { -+ // Enumerate all algorithms -+ if (opcode_class == library::OpcodeClassID::kSimt) { -+ -+ for (int algo = CUBLAS_GEMM_DEFAULT; -+ algo <= CUBLAS_GEMM_ALGO23; -+ ++algo) { -+ -+ algorithms.push_back(cublasGemmAlgo_t(algo)); -+ } -+ } -+ else { -+ -+ for (int algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP; -+ algo <= CUBLAS_GEMM_ALGO15_TENSOR_OP; -+ ++algo) { -+ -+ algorithms.push_back(cublasGemmAlgo_t(algo)); -+ } -+ } -+ } -+ else { -+ // Use the listed algorithms -+ algorithms.reserve(options.library.algorithms.size()); -+ -+ for (int algo : options.library.algorithms) { -+ algorithms.push_back(reinterpret_cast(algo)); -+ } -+ } -+ -+ break; -+ } -+ -+ case AlgorithmMode::kDefault: -+ { -+ -+ // Use the library's default algorithm -+ algorithms.push_back((opcode_class == library::OpcodeClassID::kSimt ? -+ CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP)); -+ -+ break; -+ } -+ default: -+ { -+ break; -+ } -+ } -+} -+ -+/// Dispatcher to cublasGemmEx() -+struct cublasGemmExDispatcher { -+ -+ // -+ // Data members -+ // -+ library::GemmUniversalConfiguration configuration; -+ library::GemmUniversalArguments arguments; -+ -+ // cublass-specific data structures to fill cublas API call arguments -+ cublasOperation_t trans_A; -+ cublasOperation_t trans_B; -+ cudaDataType_t data_type_A; -+ cudaDataType_t data_type_B; -+ cudaDataType_t data_type_C; -+ cudaDataType_t compute_data_type; -+ -+#if (__CUDACC_VER_MAJOR__ >= 11) -+ cublasComputeType_t compute_type; -+#endif -+ -+ cublasGemmAlgo_t algo; -+ Status status; -+ -+ // -+ // Methods -+ // -+ -+ cublasGemmExDispatcher( -+ library::GemmDescription const &op_desc, -+ library::GemmUniversalConfiguration configuration_, -+ library::GemmUniversalArguments arguments_, -+ cublasGemmAlgo_t algorithm = CUBLAS_GEMM_DFALT -+ ); -+ -+ /// Executes GEMM using these arguments -+ cublasStatus_t operator()(cublasHandle_t handle); -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Dispatcher to cublas rank k update kernels -+struct cublasRankKDispatcher { -+ -+ // -+ // Data members -+ // -+ library::RankKConfiguration configuration; -+ library::RankKArguments arguments; -+ -+ // cublass-specific data structures to fill cublas API call arguments -+ cublasOperation_t trans_A; -+ cublasFillMode_t uplo; -+ cudaDataType_t data_type_A; -+ cudaDataType_t data_type_C; -+ cudaDataType_t compute_data_type; -+ -+#if (__CUDACC_VER_MAJOR__ >= 11) -+ cublasComputeType_t compute_type; -+#endif -+ -+ int num_ranks; //(rank-k or rank-2k) -+ BlasMode blas_mode; //(symmetric or hermitian) -+ Status status; -+ -+ // -+ // Methods -+ // -+ -+ cublasRankKDispatcher( -+ library::RankKDescription const &op_desc, -+ library::RankKConfiguration configuration_, -+ library::RankKArguments arguments_ -+ ); -+ -+ /// Executes RankK using these arguments -+ cublasStatus_t operator()(cublasHandle_t handle); -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Dispatcher to cublasTrmm() -+struct cublasTrmmDispatcher { -+ -+ // -+ // Data members -+ // -+ library::TrmmConfiguration configuration; -+ library::TrmmArguments arguments; -+ -+ // cublass-specific data structures to fill cublas API call arguments -+ cublasOperation_t trans_A; -+ cublasSideMode_t side; -+ cublasFillMode_t uplo; -+ cublasDiagType_t diag; -+ cudaDataType_t data_type_A; -+ cudaDataType_t data_type_B; -+ cudaDataType_t data_type_D; -+ cudaDataType_t compute_data_type; -+ -+#if (__CUDACC_VER_MAJOR__ >= 11) -+ cublasComputeType_t compute_type; -+#endif -+ -+ Status status; -+ -+ // -+ // Methods -+ // -+ -+ cublasTrmmDispatcher( -+ library::TrmmDescription const &op_desc, -+ library::TrmmConfiguration configuration_, -+ library::TrmmArguments arguments_ -+ ); -+ -+ /// Executes TRMM using these arguments -+ cublasStatus_t operator()(cublasHandle_t handle); -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Dispatcher to cublas symm/hemm update kernels -+struct cublasSymmDispatcher { -+ -+ // -+ // Data members -+ // -+ library::SymmConfiguration configuration; -+ library::SymmArguments arguments; -+ -+ // cublass-specific data structures to fill cublas API call arguments -+ cublasSideMode_t side; -+ cublasFillMode_t uplo; -+ cudaDataType_t data_type_A; -+ cudaDataType_t data_type_B; -+ cudaDataType_t data_type_C; -+ cudaDataType_t compute_data_type; -+ -+#if (__CUDACC_VER_MAJOR__ >= 11) -+ cublasComputeType_t compute_type; -+#endif -+ -+ BlasMode blas_mode; //(symmetric or hermitian) -+ Status status; -+ -+ // -+ // Methods -+ // -+ -+ cublasSymmDispatcher( -+ library::SymmDescription const &op_desc, -+ library::SymmConfiguration configuration_, -+ library::SymmArguments arguments_ -+ ); -+ -+ /// Executes Symm using these arguments -+ cublasStatus_t operator()(cublasHandle_t handle); -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace detail -+ -+} // namespace profiler -+} // namespace cutlass -+ -+ -+#endif // #if CUTLASS_ENABLE_CUBLAS -diff --git a/3rdparty/cutlass/tools/profiler/src/cudnn_helpers.h b/3rdparty/cutlass/tools/profiler/src/cudnn_helpers.h -new file mode 100644 -index 0000000..2f02382 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/cudnn_helpers.h -@@ -0,0 +1,590 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Helper functions for mapping CUTLASS concepts to cuDNN. -+ -+*/ -+ -+#pragma once -+#if CUTLASS_ENABLE_CUDNN -+#include -+#include -+#include -+#include "cutlass/cutlass.h" -+#include "cutlass/util/device_memory.h" -+#include "cutlass/library/library.h" -+#include "enumerated_types.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+/// Converts a cuDNN status to cutlass::Status -+Status get_cutlass_status(cudnnStatus_t cudnn_status); -+ -+/// Converts a cuDNN status to cutlass::profiler::Disposition -+Disposition get_cutlass_disposition(cudnnStatus_t cudnn_status); -+ -+/// Checks cudnnStatus_t converts to cutlas status and returns if Status::kSuccess o.w. throws exception -+Status checkCudnnErr(cudnnStatus_t cudnn_status); -+ -+/// Maps a CUTLASS conv mode to a cuDNN conv mode enumeration -+bool get_cudnn_conv_mode(cudnnConvolutionMode_t &cudnn_conv_mode, conv::Mode conv_mode); -+ -+/// Maps a CUTLASS layout type to a cuDNN data type enumeration -+bool get_cudnn_layout(cudnnTensorFormat_t &cudnn_layout, library::LayoutTypeID layout); -+ -+/// Maps a CUTLASS numeric type to a cuDNN data type enumeration -+bool get_cudnn_datatype(cudnnDataType_t &cudnn_element_type, library::NumericTypeID element_type); -+ -+/// Maps CUTLASS math OpcodeClassID and MathOperationID to cuDNN math_type -+bool get_cudnn_mathtype(cudnnMathType_t &cudnn_math_type, library::ConvDescription const &conv_desc); -+ -+/// Returns a status if cudnn can satisfy a particular Conv2d description -+Status cudnn_satisfies(library::ConvDescription const &desc, library::Conv2dConfiguration const &configuration); -+ -+/// Returns a status if cudnn can satisfy a particular Conv3d description -+Status cudnn_satisfies(library::ConvDescription const &desc, library::Conv3dConfiguration const &configuration); -+ -+/// Cudnn compute type seems to be hardcoded to float (To handle a possible cudnn issue) -+float cast_cudnn_compute_type_to_float(library::NumericTypeID type, void const * src); -+ -+ -+/// This is a helper class to create cudnnHandle_t automatically on CudnnCreate object creation and -+/// to destroy cudnnHandle_t on CudnnCreate object destruction. -+/// Additionaly, it provides implicit cast from CudnnCreate's object to cudnnHandle_t's object -+class CudnnCreate { -+private: -+ cudnnHandle_t handle; -+ cudnnStatus_t status; -+ -+public: -+ CudnnCreate() { -+ status = cudnnCreate(&handle); -+ } -+ -+ ~CudnnCreate() { -+ cudnnDestroy(handle); -+ } -+ -+ /// Implicit cast CudnnCreate object to cudnnHandle_t -+ operator cudnnHandle_t() const { return handle; } -+ -+ /// returns cudnnStatus_t for handle creation -+ cudnnStatus_t get_cudnn_create_status() { return status; } -+}; -+ -+ -+namespace detail { -+ -+/// Dispatcher to cudnn convolution operators -+struct cudnnConvDispatcher { -+ -+ // -+ // Data members -+ // -+ //library::Conv2dConfiguration configuration; -+ library::ConvArguments arguments; -+ library::ConvKind conv_kind; -+ -+ // cudnn-specific data structures to fill cudnn API call arguments -+ // cudnn activation, filter, and output descriptors -+ cudnnTensorDescriptor_t activation_desc; -+ cudnnFilterDescriptor_t filter_desc; -+ cudnnTensorDescriptor_t output_desc; -+ cudnnConvolutionDescriptor_t conv_desc; -+ -+ // cudnn datatypes -+ cudnnDataType_t data_type_activation; -+ cudnnDataType_t data_type_filter; -+ cudnnDataType_t data_type_output; -+ -+ // cudnn layouts -+ cudnnTensorFormat_t layout_activation; -+ cudnnTensorFormat_t layout_filter; -+ cudnnTensorFormat_t layout_output; -+ -+ // cudnn convolution mode -+ cudnnConvolutionMode_t conv_mode; -+ -+ // cudnn math type (tensorop, tensorop with conversion, simt) -+ cudnnMathType_t math_type; -+ -+ // cudnn compute data type -+ cudnnDataType_t compute_type; -+ -+ // cudnn compute type seems to be hardcoded to float (to handle a possible a cudnn issue) -+ float alpha; -+ float beta; -+ -+ // cudnn workspace -+ size_t workspace_size_in_bytes = 0; -+ cutlass::device_memory::allocation workspace; -+ -+ // select cudnn's implicit gemm precomputed algorithm with tensor operations -+ static cudnnConvolutionFwdAlgo_t const fprop_algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; -+ static cudnnConvolutionBwdDataAlgo_t const dgrad_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; -+ static cudnnConvolutionBwdFilterAlgo_t const wgrad_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; -+ -+ Status status; -+ -+ // -+ // Methods -+ // -+ -+ // TODO: unify ctor cudnnConvDispatcher for conv2d and conv3d by unifying Conv2dConfigration -+ -+ // ctor for conv2d -+ cudnnConvDispatcher( -+ library::ConvDescription const &op_desc, -+ library::Conv2dConfiguration configuration, -+ library::ConvArguments arguments_, -+ cudnnHandle_t handle -+ ): -+ //configuration(configuration_), -+ arguments(arguments_), -+ conv_kind(op_desc.conv_kind), -+ status(Status::kSuccess) { -+ -+ bool good = true; -+ -+ // Get cudnn datatype, layout, and convolution mode from library::ConvDescription -+ good = (good && get_cudnn_datatype(data_type_activation, op_desc.A.element)); -+ good = (good && get_cudnn_datatype(data_type_filter, op_desc.B.element)); -+ good = (good && get_cudnn_datatype(data_type_output, op_desc.C.element)); -+ good = (good && get_cudnn_layout(layout_activation, op_desc.A.layout)); -+ good = (good && get_cudnn_layout(layout_filter, op_desc.B.layout)); -+ good = (good && get_cudnn_layout(layout_output, op_desc.C.layout)); -+ good = (good && get_cudnn_conv_mode(conv_mode, configuration.problem_size.mode)); -+ // Get cudnn mathtype (cudnnMathType_t) -+ good = (good && get_cudnn_mathtype(math_type, op_desc)); -+ good = (good && get_cudnn_datatype( -+ compute_type, -+ op_desc.tile_description.math_instruction.element_accumulator)); -+ // Check cutlass Conv2d description has equivalent operator in cudnn -+ if (!good) { -+ status = Status::kErrorNotSupported; -+ return; -+ } -+ // cudnn compute type seems to be hardcoded to float (to handle a possible a cudnn issue) -+ alpha = cast_cudnn_compute_type_to_float(op_desc.element_epilogue, arguments.alpha); -+ beta = cast_cudnn_compute_type_to_float(op_desc.element_epilogue, arguments.beta); -+ -+ // Create convolution descriptor object -+ status = get_cutlass_status(cudnnCreateConvolutionDescriptor(&conv_desc)); -+ -+ // Configure convolution operator -+ std::vector padding {configuration.problem_size.pad_h, configuration.problem_size.pad_w}; -+ std::vector stride {configuration.problem_size.stride_h, configuration.problem_size.stride_w}; -+ std::vector dilation {configuration.problem_size.dilation_h, configuration.problem_size.dilation_w}; -+ -+ status = get_cutlass_status( -+ cudnnSetConvolutionNdDescriptor( -+ conv_desc, -+ op_desc.conv_dim, -+ padding.data(), -+ stride.data(), -+ dilation.data(), -+ conv_mode, -+ compute_type -+ )); -+ -+ // Set groups -+ status = get_cutlass_status(cudnnSetConvolutionGroupCount(conv_desc, configuration.problem_size.groups)); -+ -+ // Create activation, filter, and output descriptor objects -+ status = get_cutlass_status(cudnnCreateTensorDescriptor(&activation_desc)); -+ status = get_cutlass_status(cudnnCreateFilterDescriptor(&filter_desc)); -+ status = get_cutlass_status(cudnnCreateTensorDescriptor(&output_desc)); -+ -+ // Set activation, filter, and output descriptor -+ status = get_cutlass_status( -+ cudnnSetTensor4dDescriptor( -+ activation_desc, -+ layout_activation, -+ data_type_activation, -+ configuration.problem_size.N, -+ configuration.problem_size.C, -+ configuration.problem_size.H, -+ configuration.problem_size.W -+ )); -+ -+ status = get_cutlass_status( -+ cudnnSetFilter4dDescriptor( -+ filter_desc, -+ data_type_filter, -+ layout_filter, -+ configuration.problem_size.K, -+ configuration.problem_size.C / configuration.problem_size.groups, -+ configuration.problem_size.R, -+ configuration.problem_size.S -+ )); -+ -+ status = get_cutlass_status( -+ cudnnSetTensor4dDescriptor( -+ output_desc, -+ layout_output, -+ data_type_output, -+ configuration.problem_size.N, -+ configuration.problem_size.K, -+ configuration.problem_size.P, -+ configuration.problem_size.Q -+ )); -+ -+ // Set math instruction to tensor op -+ status = get_cutlass_status( -+ cudnnSetConvolutionMathType(conv_desc, math_type)); -+ -+ // Initialize workspace -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: -+ status = get_cutlass_status( -+ cudnnGetConvolutionForwardWorkspaceSize( -+ handle, -+ activation_desc, -+ filter_desc, -+ conv_desc, -+ output_desc, -+ fprop_algo, -+ &workspace_size_in_bytes -+ )); break; -+ case library::ConvKind::kDgrad: -+ status = get_cutlass_status( -+ cudnnGetConvolutionBackwardDataWorkspaceSize( -+ handle, -+ filter_desc, -+ output_desc, -+ conv_desc, -+ activation_desc, -+ dgrad_algo, -+ &workspace_size_in_bytes -+ )); break; -+ case library::ConvKind::kWgrad: -+ status = get_cutlass_status( -+ cudnnGetConvolutionBackwardFilterWorkspaceSize( -+ handle, -+ activation_desc, -+ output_desc, -+ conv_desc, -+ filter_desc, -+ wgrad_algo, -+ &workspace_size_in_bytes -+ )); break; -+ -+ } -+ -+ workspace = cutlass::device_memory::allocation(workspace_size_in_bytes); -+ } -+ -+ -+ // ctor for conv3d -+ cudnnConvDispatcher( -+ library::ConvDescription const &op_desc, -+ library::Conv3dConfiguration configuration, -+ library::ConvArguments arguments_, -+ cudnnHandle_t handle -+ ): -+ //configuration(configuration_), -+ arguments(arguments_), -+ conv_kind(op_desc.conv_kind), -+ status(Status::kSuccess) { -+ -+ bool good = true; -+ -+ // Get cudnn datatype, layout, and convolution mode from library::ConvDescription -+ good = (good && get_cudnn_datatype(data_type_activation, op_desc.A.element)); -+ good = (good && get_cudnn_datatype(data_type_filter, op_desc.B.element)); -+ good = (good && get_cudnn_datatype(data_type_output, op_desc.C.element)); -+ -+ good = (good && get_cudnn_layout(layout_activation, op_desc.A.layout)); -+ good = (good && get_cudnn_layout(layout_filter, op_desc.B.layout)); -+ good = (good && get_cudnn_layout(layout_output, op_desc.C.layout)); -+ -+ good = (good && get_cudnn_conv_mode(conv_mode, configuration.problem_size.mode)); -+ -+ // cudnn compute type seems to be hardcoded to float (to handle a possible a cudnn issue) -+ alpha = cast_cudnn_compute_type_to_float(op_desc.element_epilogue, arguments.alpha); -+ beta = cast_cudnn_compute_type_to_float(op_desc.element_epilogue, arguments.beta); -+ -+ good = (good && get_cudnn_datatype( -+ compute_type, -+ op_desc.tile_description.math_instruction.element_accumulator)); -+ -+ // Check cutlass Conv2d description has equivalent operator in cudnn -+ if (!good) { -+ status = Status::kErrorNotSupported; -+ } -+ -+ // Create convolution descriptor object -+ status = get_cutlass_status(cudnnCreateConvolutionDescriptor(&conv_desc)); -+ -+ // Configure convolution operator -+ std::vector padding {configuration.problem_size.pad_d, configuration.problem_size.pad_h, configuration.problem_size.pad_w}; -+ std::vector stride {configuration.problem_size.stride_d, configuration.problem_size.stride_h, configuration.problem_size.stride_w}; -+ std::vector dilation {configuration.problem_size.dilation_d, configuration.problem_size.dilation_h, configuration.problem_size.dilation_w}; -+ -+ status = get_cutlass_status( -+ cudnnSetConvolutionNdDescriptor( -+ conv_desc, -+ op_desc.conv_dim, -+ padding.data(), -+ stride.data(), -+ dilation.data(), -+ conv_mode, -+ compute_type -+ )); -+ -+ // Set groups -+ status = get_cutlass_status(cudnnSetConvolutionGroupCount(conv_desc, configuration.problem_size.groups)); -+ -+ // Create activation, filter, and output descriptor objects -+ status = get_cutlass_status(cudnnCreateTensorDescriptor(&activation_desc)); -+ status = get_cutlass_status(cudnnCreateFilterDescriptor(&filter_desc)); -+ status = get_cutlass_status(cudnnCreateTensorDescriptor(&output_desc)); -+ -+ // Set activation descriptor -+ std::vector activation_extent { -+ configuration.problem_size.N, -+ configuration.problem_size.C, -+ configuration.problem_size.D, -+ configuration.problem_size.H, -+ configuration.problem_size.W -+ }; -+ -+ std::vector activation_stride { -+ configuration.layout_activations.stride()[3], -+ 1, -+ configuration.layout_activations.stride()[2], -+ configuration.layout_activations.stride()[1], -+ configuration.layout_activations.stride()[0] -+ }; -+ -+ status = get_cutlass_status( -+ cudnnSetTensorNdDescriptor( -+ activation_desc, -+ data_type_activation, -+ op_desc.conv_dim + 2, -+ activation_extent.data(), -+ activation_stride.data() -+ )); -+ -+ // Set filter descriptor -+ std::vector filter_extent { -+ configuration.problem_size.K, -+ configuration.problem_size.C, -+ configuration.problem_size.T, -+ configuration.problem_size.R, -+ configuration.problem_size.S -+ }; -+ -+ std::vector filter_stride { -+ configuration.layout_filters.stride()[3], -+ 1, -+ configuration.layout_filters.stride()[2], -+ configuration.layout_filters.stride()[1], -+ configuration.layout_filters.stride()[0] -+ }; -+ -+ status = get_cutlass_status( -+ cudnnSetFilterNdDescriptor( -+ filter_desc, -+ data_type_filter, -+ layout_filter, -+ op_desc.conv_dim + 2, -+ filter_extent.data() -+ )); -+ -+ -+ // Set output descriptor -+ std::vector output_extent { -+ configuration.problem_size.N, -+ configuration.problem_size.K, -+ configuration.problem_size.Z, -+ configuration.problem_size.P, -+ configuration.problem_size.Q -+ }; -+ -+ std::vector output_stride { -+ configuration.layout_output.stride()[3], -+ 1, -+ configuration.layout_output.stride()[2], -+ configuration.layout_output.stride()[1], -+ configuration.layout_output.stride()[0] -+ }; -+ -+ status = get_cutlass_status( -+ cudnnSetTensorNdDescriptor( -+ output_desc, -+ data_type_output, -+ op_desc.conv_dim + 2, -+ output_extent.data(), -+ output_stride.data() -+ )); -+ -+ // Set math instruction to tensor op -+ status = get_cutlass_status( -+ cudnnSetConvolutionMathType(conv_desc, math_type)); -+ -+ // Initialize workspace -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: -+ status = get_cutlass_status( -+ cudnnGetConvolutionForwardWorkspaceSize( -+ handle, -+ activation_desc, -+ filter_desc, -+ conv_desc, -+ output_desc, -+ fprop_algo, -+ &workspace_size_in_bytes -+ )); break; -+ case library::ConvKind::kDgrad: -+ status = get_cutlass_status( -+ cudnnGetConvolutionBackwardDataWorkspaceSize( -+ handle, -+ filter_desc, -+ output_desc, -+ conv_desc, -+ activation_desc, -+ dgrad_algo, -+ &workspace_size_in_bytes -+ )); break; -+ case library::ConvKind::kWgrad: -+ status = get_cutlass_status( -+ cudnnGetConvolutionBackwardFilterWorkspaceSize( -+ handle, -+ activation_desc, -+ output_desc, -+ conv_desc, -+ filter_desc, -+ wgrad_algo, -+ &workspace_size_in_bytes -+ )); break; -+ -+ } -+ -+ workspace = cutlass::device_memory::allocation(workspace_size_in_bytes); -+ } -+ -+ /// Executes Conv2d operater from cudnn library -+ cudnnStatus_t operator()(cudnnHandle_t handle) { -+ -+ switch (conv_kind) { -+ case library::ConvKind::kFprop: -+ return cudnnConvolutionForward( -+ handle, -+ &alpha, -+ activation_desc, -+ activation(), -+ filter_desc, -+ filter(), -+ conv_desc, -+ fprop_algo, -+ workspace.get(), -+ workspace_size_in_bytes, -+ &beta, -+ output_desc, -+ arguments.D -+ ); -+ case library::ConvKind::kDgrad: -+ return cudnnConvolutionBackwardData( -+ handle, -+ &alpha, -+ filter_desc, -+ filter(), -+ output_desc, -+ output(), -+ conv_desc, -+ dgrad_algo, -+ workspace.get(), -+ workspace_size_in_bytes, -+ &beta, -+ activation_desc, -+ arguments.D -+ ); -+ case library::ConvKind::kWgrad: -+ return cudnnConvolutionBackwardFilter( -+ handle, -+ &alpha, -+ activation_desc, -+ activation(), -+ output_desc, -+ output(), -+ conv_desc, -+ wgrad_algo, -+ workspace.get(), -+ workspace_size_in_bytes, -+ &beta, -+ filter_desc, -+ arguments.D -+ ); -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns Actviation Tensor -+ void const * activation() const { -+ switch(conv_kind) { -+ case library::ConvKind::kFprop : return arguments.A; -+ case library::ConvKind::kDgrad : return arguments.C; -+ case library::ConvKind::kWgrad : return arguments.B; -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns Filter Tensor -+ void const *filter() const { -+ switch(conv_kind) { -+ case library::ConvKind::kFprop : return arguments.B; -+ case library::ConvKind::kDgrad : return arguments.B; -+ case library::ConvKind::kWgrad : return arguments.C; -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+ -+ // Returns Output Tensor -+ void const *output() const { -+ switch(conv_kind) { -+ case library::ConvKind::kFprop : return arguments.C; -+ case library::ConvKind::kDgrad : return arguments.A; -+ case library::ConvKind::kWgrad : return arguments.A; -+ default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); -+ } -+ } -+}; -+ -+} // namespace detail -+///////////////////////////////////////////////////////////////////////////////////////////////// -+#endif //#if CUTLASS_ENABLE_CUDNN -+} // namespace profiler -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/profiler/src/cutlass_profiler.cu b/3rdparty/cutlass/tools/profiler/src/cutlass_profiler.cu -new file mode 100644 -index 0000000..026ffdf ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/cutlass_profiler.cu -@@ -0,0 +1,226 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Execution environment -+*/ -+ -+#include -+#include -+ -+// Profiler includes -+#include "cutlass_profiler.h" -+#include "gemm_operation_profiler.h" -+#include "rank_k_operation_profiler.h" -+#include "rank_2k_operation_profiler.h" -+#include "trmm_operation_profiler.h" -+#include "symm_operation_profiler.h" -+#include "conv2d_operation_profiler.h" -+#include "conv3d_operation_profiler.h" -+#include "sparse_gemm_operation_profiler.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+CutlassProfiler::CutlassProfiler( -+ Options const &options -+): -+ options_(options) { -+ -+ operation_profilers_.emplace_back(new GemmOperationProfiler(options)); -+ -+ operation_profilers_.emplace_back(new SparseGemmOperationProfiler(options)); -+ -+ operation_profilers_.emplace_back(new Conv2dOperationProfiler(options)); -+ -+ operation_profilers_.emplace_back(new Conv3dOperationProfiler(options)); -+ -+ operation_profilers_.emplace_back(new RankKOperationProfiler(options)); -+ -+ operation_profilers_.emplace_back(new Rank2KOperationProfiler(options)); -+ -+ operation_profilers_.emplace_back(new TrmmOperationProfiler(options)); -+ -+ operation_profilers_.emplace_back(new SymmOperationProfiler(options)); -+} -+ -+CutlassProfiler::~CutlassProfiler() { -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Execute the program -+int CutlassProfiler::operator()() { -+ -+ if (options_.cmdline.num_naked_args() > 0) { -+ std::cerr << "Unknown args: \n"; -+ options_.cmdline.print_naked_args(std::cerr); -+ std::cerr << "\n\n\n"; -+ -+ print_usage_(std::cout); -+ return 1; -+ } -+ -+ if (options_.about.help) { -+ if (options_.operation_kind == library::OperationKind::kInvalid) { -+ print_usage_(std::cout); -+ } -+ else { -+ for (auto & profiler : operation_profilers_) { -+ if (profiler->kind() == options_.operation_kind) { -+ profiler->print_usage(std::cout); -+ profiler->print_examples(std::cout); -+ return 0; -+ } -+ } -+ } -+ return 0; -+ } -+ else if (options_.about.version) { -+ options_.about.print_version(std::cout); -+ -+ std::cout << std::endl; -+ return 0; -+ } -+ else if (options_.about.device_info) { -+ options_.device.print_device_info(std::cout); -+ return 0; -+ } -+ -+ if (options_.execution_mode == ExecutionMode::kProfile || -+ options_.execution_mode == ExecutionMode::kDryRun || -+ options_.execution_mode == ExecutionMode::kTrace) { -+ -+ // Profiles all operations -+ profile_(); -+ } -+ else if (options_.execution_mode == ExecutionMode::kEnumerate) { -+ // Enumerates all operations -+ enumerate_(); -+ } -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Enumerates all operations -+void CutlassProfiler::enumerate_() { -+ -+} -+ -+/// Profiles all operations -+int CutlassProfiler::profile_() { -+ -+ int result = 0; -+ DeviceContext device_context; -+ -+ // For all profilers -+ for (auto & profiler : operation_profilers_) { -+ -+ if (options_.operation_kind == library::OperationKind::kInvalid || -+ options_.operation_kind == profiler->kind()) { -+ -+ result = profiler->profile_all(options_, library::Singleton::get().manifest, device_context); -+ -+ if (result) { -+ return result; -+ } -+ } -+ } -+ -+ return result; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Prints all options -+void CutlassProfiler::print_usage_(std::ostream &out) { -+ options_.print_usage(out); -+ -+ out << "\nOperations:\n\n"; -+ -+ // For all profilers -+ for (auto & profiler : operation_profilers_) { -+ -+ -+ std::string kind_str = library::to_string(profiler->kind()); -+ -+ size_t kAlignment = 40; -+ size_t columns = 0; -+ -+ if (kind_str.size() < kAlignment) { -+ columns = kAlignment - kind_str.size(); -+ } -+ -+ out << " " << kind_str << std::string(columns, ' ') << profiler->description() << "\n"; -+ -+ } -+ -+ out << "\n\nFor details about a particular function, specify the function name with --help.\n\nExample:\n\n" -+ << " $ cutlass_profiler --operation=Gemm --help\n\n" -+ << " $ cutlass_profiler --operation=RankK --help\n\n" -+ << " $ cutlass_profiler --operation=Trmm --help\n\n" -+ << " $ cutlass_profiler --operation=Symm --help\n\n" -+ << " $ cutlass_profiler --operation=Conv3d --help\n\n" -+ << " $ cutlass_profiler --operation=Conv2d --help\n\n" -+ << " $ cutlass_profiler --operation=SparseGemm --help\n\n" -+ ; -+} -+ -+/// Prints usage -+void CutlassProfiler::print_options_(std::ostream &out) { -+ options_.print_options(out); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Initializes the CUDA device -+void CutlassProfiler::initialize_device_() { -+ -+ cudaError_t result = cudaSetDevice(options_.device.device); -+ -+ if (result != cudaSuccess) { -+ std::cerr << "Failed to set device."; -+ throw std::runtime_error("Failed to set device"); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/profiler/src/cutlass_profiler.h b/3rdparty/cutlass/tools/profiler/src/cutlass_profiler.h -new file mode 100644 -index 0000000..a3b0640 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/cutlass_profiler.h -@@ -0,0 +1,96 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Execution environment -+*/ -+ -+#pragma once -+// CUTLASS Library includes -+#include "cutlass/library/library.h" -+#include "cutlass/library/manifest.h" -+#include "cutlass/library/singleton.h" -+ -+#include "options.h" -+#include "operation_profiler.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// CUTLASS Profiler application -+class CutlassProfiler { -+private: -+ -+ // -+ // Data members -+ // -+ -+ /// Performance testbench options -+ Options options_; -+ -+ /// Entry points for each operation -+ OperationProfilerVector operation_profilers_; -+ -+private: -+ -+ /// Prints usage -+ void print_usage_(std::ostream &); -+ -+ /// Prints usage -+ void print_options_(std::ostream &); -+ -+ /// Initializes the device -+ void initialize_device_(); -+ -+ /// Enumerates all operations -+ void enumerate_(); -+ -+ /// Profiles all operations -+ int profile_(); -+ -+public: -+ -+ CutlassProfiler(Options const &options); -+ ~CutlassProfiler(); -+ -+ /// Invokes profiling operations -+ int operator()(); -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/profiler/src/debug.h b/3rdparty/cutlass/tools/profiler/src/debug.h -new file mode 100644 -index 0000000..83e2c33 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/debug.h -@@ -0,0 +1,56 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief -+*/ -+ -+#pragma once -+ -+#include -+ -+//#define report(x) { std::cout << "\033[31m" << __FILE__ << ":" << __LINE__ << " " << x << "\033[0m" << std::endl; } -+//#define report(x) {} -+ -+// Enable/Disble Profiler debug prints -+//#define DEBUG_PROFILER -+ -+//RED 31m // profiler prints debug messages in red -+//YELLOW 33m // ir prints debug messages in yellow -+ -+#ifndef DEBUG_PROFILER -+#define debugprof(...) -+#else -+#define debugprof(...) do { \ -+ printf("\033[33m[DEBUG PROF] %s:%d | ", __FILE__, __LINE__); \ -+ printf(__VA_ARGS__); \ -+ printf("\033[0m\n"); \ -+ } while (0) -+#endif -diff --git a/3rdparty/cutlass/tools/profiler/src/device_allocation.cu b/3rdparty/cutlass/tools/profiler/src/device_allocation.cu -new file mode 100644 -index 0000000..e59c344 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/device_allocation.cu -@@ -0,0 +1,1681 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Execution environment -+*/ -+ -+#include -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/layout/matrix.h" -+#include "cutlass/layout/tensor.h" -+ -+#include "cutlass/util/reference/device/tensor_compare.h" -+#include "cutlass/util/reference/device/tensor_fill.h" -+#include "cutlass/util/reference/host/tensor_fill.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/tensor_view_io.h" -+ -+#include "cutlass/library/util.h" -+ -+#include "device_allocation.h" -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+size_t DeviceAllocation::bytes(library::NumericTypeID type, size_t capacity) { -+ return size_t(cutlass::library::sizeof_bits(type)) * capacity / 8; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+static std::vector get_packed_layout_stride(std::vector const &extent) { -+ -+ typename Layout::TensorCoord extent_coord; -+ typename Layout::Stride stride_coord; -+ -+ if (extent.size() != size_t(Layout::kRank)) { -+ throw std::runtime_error("Layout does not have same rank as extent vector."); -+ } -+ -+ for (int i = 0; i < Layout::kRank; ++i) { -+ extent_coord[i] = extent.at(i); -+ } -+ -+ std::vector stride; -+ stride.resize(Layout::kStrideRank, 0); -+ -+ Layout layout = Layout::packed(extent_coord); -+ stride_coord = layout.stride(); -+ -+ for (int i = 0; i < Layout::kStrideRank; ++i) { -+ stride.at(i) = (int64_t)stride_coord[i]; -+ } -+ -+ return stride; -+} -+ -+/// Returns the stride of a packed layout -+std::vector DeviceAllocation::get_packed_layout( -+ library::LayoutTypeID layout_id, -+ std::vector const &extent) { -+ -+ std::vector stride; -+ -+ switch (layout_id) { -+ case library::LayoutTypeID::kColumnMajor: -+ stride = get_packed_layout_stride(extent); -+ break; -+ case library::LayoutTypeID::kRowMajor: -+ stride = get_packed_layout_stride(extent); -+ break; -+ case library::LayoutTypeID::kColumnMajorInterleavedK2: -+ stride = get_packed_layout_stride>(extent); -+ break; -+ case library::LayoutTypeID::kRowMajorInterleavedK2: -+ stride = get_packed_layout_stride>(extent); -+ break; -+ case library::LayoutTypeID::kColumnMajorInterleavedK4: -+ stride = get_packed_layout_stride>(extent); -+ break; -+ case library::LayoutTypeID::kRowMajorInterleavedK4: -+ stride = get_packed_layout_stride>(extent); -+ break; -+ case library::LayoutTypeID::kColumnMajorInterleavedK16: -+ stride = get_packed_layout_stride>(extent); -+ break; -+ case library::LayoutTypeID::kRowMajorInterleavedK16: -+ stride = get_packed_layout_stride>(extent); -+ break; -+ case library::LayoutTypeID::kColumnMajorInterleavedK32: -+ stride = get_packed_layout_stride>(extent); -+ break; -+ case library::LayoutTypeID::kRowMajorInterleavedK32: -+ stride = get_packed_layout_stride>(extent); -+ break; -+ case library::LayoutTypeID::kColumnMajorInterleavedK64: -+ stride = get_packed_layout_stride>(extent); -+ break; -+ case library::LayoutTypeID::kRowMajorInterleavedK64: -+ stride = get_packed_layout_stride>(extent); -+ break; -+ case library::LayoutTypeID::kTensorNCHW: -+ stride = get_packed_layout_stride(extent); -+ break; -+ case library::LayoutTypeID::kTensorNHWC: -+ stride = get_packed_layout_stride(extent); -+ break; -+ case library::LayoutTypeID::kTensorNDHWC: -+ stride = get_packed_layout_stride(extent); -+ break; -+ case library::LayoutTypeID::kTensorNC32HW32: -+ stride = get_packed_layout_stride>(extent); -+ break; -+ case library::LayoutTypeID::kTensorNC64HW64: -+ stride = get_packed_layout_stride>(extent); -+ break; -+ case library::LayoutTypeID::kTensorC32RSK32: -+ stride = get_packed_layout_stride>(extent); -+ break; -+ case library::LayoutTypeID::kTensorC64RSK64: -+ stride = get_packed_layout_stride>(extent); -+ break; -+ default: break; -+ } -+ -+ return stride; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template to use CUTLASS Layout functions to -+template -+static size_t construct_layout_( -+ void *bytes, -+ library::LayoutTypeID layout_id, -+ std::vector const &extent, -+ std::vector &stride) { -+ -+ if (extent.size() != Layout::kRank) { -+ throw std::runtime_error( -+ "Layout must have same rank as extent vector."); -+ } -+ -+ if (Layout::kStrideRank && stride.empty()) { -+ -+ stride = get_packed_layout_stride(extent); -+ -+ return construct_layout_( -+ bytes, -+ layout_id, -+ extent, -+ stride); -+ } -+ else if (Layout::kStrideRank && stride.size() != Layout::kStrideRank) { -+ throw std::runtime_error( -+ "Layout requires either empty stride or stride vector matching Layout::kStrideRank"); -+ } -+ -+ typename Layout::Stride stride_coord; -+ for (int i = 0; i < Layout::kStrideRank; ++i) { -+ stride_coord[i] = (int)stride.at(i); -+ } -+ -+ typename Layout::TensorCoord extent_coord; -+ for (int i = 0; i < Layout::kRank; ++i) { -+ extent_coord[i] = extent.at(i); -+ } -+ -+ // Construct the CUTLASS layout object from the stride object -+ Layout layout(stride_coord); -+ -+ // Pack it into bytes -+ if (bytes) { -+ *reinterpret_cast(bytes) = layout; -+ } -+ -+ // Return capacity -+ size_t capacity_ = layout.capacity(extent_coord); -+ -+ return capacity_; -+} -+ -+/// returns the capacity needed -+size_t DeviceAllocation::construct_layout( -+ void *bytes, -+ library::LayoutTypeID layout_id, -+ std::vector const &extent, -+ std::vector &stride) { -+ -+ switch (layout_id) { -+ case library::LayoutTypeID::kColumnMajor: -+ return construct_layout_(bytes, layout_id, extent, stride); -+ -+ case library::LayoutTypeID::kRowMajor: -+ return construct_layout_(bytes, layout_id, extent, stride); -+ -+ case library::LayoutTypeID::kColumnMajorInterleavedK2: -+ return construct_layout_>(bytes, layout_id, extent, stride); -+ -+ case library::LayoutTypeID::kRowMajorInterleavedK2: -+ return construct_layout_>(bytes, layout_id, extent, stride); -+ -+ case library::LayoutTypeID::kColumnMajorInterleavedK4: -+ return construct_layout_>(bytes, layout_id, extent, stride); -+ -+ case library::LayoutTypeID::kRowMajorInterleavedK4: -+ return construct_layout_>(bytes, layout_id, extent, stride); -+ -+ case library::LayoutTypeID::kColumnMajorInterleavedK16: -+ return construct_layout_>(bytes, layout_id, extent, stride); -+ -+ case library::LayoutTypeID::kRowMajorInterleavedK16: -+ return construct_layout_>(bytes, layout_id, extent, stride); -+ -+ case library::LayoutTypeID::kColumnMajorInterleavedK32: -+ return construct_layout_>(bytes, layout_id, extent, stride); -+ -+ case library::LayoutTypeID::kRowMajorInterleavedK32: -+ return construct_layout_>(bytes, layout_id, extent, stride); -+ -+ case library::LayoutTypeID::kColumnMajorInterleavedK64: -+ return construct_layout_>(bytes, layout_id, extent, stride); -+ -+ case library::LayoutTypeID::kRowMajorInterleavedK64: -+ return construct_layout_>(bytes, layout_id, extent, stride); -+ -+ case library::LayoutTypeID::kTensorNCHW: -+ return construct_layout_(bytes, layout_id, extent, stride); -+ -+ case library::LayoutTypeID::kTensorNHWC: -+ return construct_layout_(bytes, layout_id, extent, stride); -+ -+ case library::LayoutTypeID::kTensorNDHWC: -+ return construct_layout_(bytes, layout_id, extent, stride); -+ -+ case library::LayoutTypeID::kTensorNC32HW32: -+ return construct_layout_>(bytes, layout_id, extent, stride); -+ -+ case library::LayoutTypeID::kTensorNC64HW64: -+ return construct_layout_>(bytes, layout_id, extent, stride); -+ -+ case library::LayoutTypeID::kTensorC32RSK32: -+ return construct_layout_>(bytes, layout_id, extent, stride); -+ -+ case library::LayoutTypeID::kTensorC64RSK64: -+ return construct_layout_>(bytes, layout_id, extent, stride); -+ -+ default: break; -+ } -+ -+ return 0; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+DeviceAllocation::DeviceAllocation(): -+ type_(library::NumericTypeID::kInvalid), -+ batch_stride_(0), -+ capacity_(0), -+ pointer_(nullptr), -+ layout_(library::LayoutTypeID::kUnknown), -+ batch_count_(1) { -+ -+} -+ -+DeviceAllocation::DeviceAllocation( -+ library::NumericTypeID type, -+ size_t capacity -+): -+ type_(type), batch_stride_(capacity), capacity_(capacity), pointer_(nullptr), -+ layout_(library::LayoutTypeID::kUnknown), batch_count_(1) { -+ -+ cudaError_t result = cudaMalloc((void **)&pointer_, bytes(type, capacity)); -+ -+ if (result != cudaSuccess) { -+ type_ = library::NumericTypeID::kInvalid; -+ capacity_ = 0; -+ pointer_ = nullptr; -+ throw std::bad_alloc(); -+ } -+} -+ -+DeviceAllocation::DeviceAllocation( -+ library::NumericTypeID type, -+ library::LayoutTypeID layout_id, -+ std::vector const &extent, -+ std::vector const &stride, -+ int batch_count -+): -+ type_(type), batch_stride_(size_t(0)), capacity_(size_t(0)), pointer_(nullptr), batch_count_(1) { -+ -+ reset(type, layout_id, extent, stride, batch_count); -+} -+ -+DeviceAllocation::~DeviceAllocation() { -+ if (pointer_) { -+ cudaFree(pointer_); -+ } -+} -+ -+DeviceAllocation &DeviceAllocation::reset() { -+ if (pointer_) { -+ cudaFree(pointer_); -+ } -+ -+ type_ = library::NumericTypeID::kInvalid; -+ batch_stride_ = 0; -+ capacity_ = 0; -+ pointer_ = nullptr; -+ layout_ = library::LayoutTypeID::kUnknown; -+ stride_.clear(); -+ extent_.clear(); -+ tensor_ref_buffer_.clear(); -+ batch_count_ = 1; -+ -+ return *this; -+} -+ -+DeviceAllocation &DeviceAllocation::reset(library::NumericTypeID type, size_t capacity) { -+ -+ reset(); -+ -+ type_ = type; -+ batch_stride_ = capacity; -+ capacity_ = capacity; -+ -+ cudaError_t result = cudaMalloc((void **)&pointer_, bytes(type_, capacity_)); -+ if (result != cudaSuccess) { -+ throw std::bad_alloc(); -+ } -+ -+ layout_ = library::LayoutTypeID::kUnknown; -+ stride_.clear(); -+ extent_.clear(); -+ batch_count_ = 1; -+ -+ tensor_ref_buffer_.resize(sizeof(pointer_), 0); -+ std::memcpy(tensor_ref_buffer_.data(), &pointer_, sizeof(pointer_)); -+ -+ return *this; -+} -+ -+/// Allocates memory for a given layout and tensor -+DeviceAllocation &DeviceAllocation::reset( -+ library::NumericTypeID type, -+ library::LayoutTypeID layout_id, -+ std::vector const &extent, -+ std::vector const &stride, -+ int batch_count) { -+ -+ reset(); -+ -+ tensor_ref_buffer_.resize(sizeof(pointer_) + (sizeof(int64_t) * library::get_layout_stride_rank(layout_id)), 0); -+ -+ type_ = type; -+ -+ layout_ = layout_id; -+ stride_ = stride; -+ extent_ = extent; -+ batch_count_ = batch_count; -+ -+ batch_stride_ = construct_layout( -+ tensor_ref_buffer_.data() + sizeof(pointer_), -+ layout_id, -+ extent, -+ stride_); -+ -+ capacity_ = batch_stride_ * batch_count_; -+ -+ cudaError_t result = cudaMalloc((void **)&pointer_, bytes(type, capacity_)); -+ if (result != cudaSuccess) { -+ throw std::bad_alloc(); -+ } -+ -+ std::memcpy(tensor_ref_buffer_.data(), &pointer_, sizeof(pointer_)); -+ -+ return *this; -+} -+ -+bool DeviceAllocation::good() const { -+ return (capacity_ && pointer_); -+} -+ -+library::NumericTypeID DeviceAllocation::type() const { -+ return type_; -+} -+ -+void *DeviceAllocation::data() const { -+ return pointer_; -+} -+ -+void *DeviceAllocation::batch_data(int batch_idx) const { -+ return static_cast(data()) + batch_stride_bytes() * batch_idx; -+} -+ -+library::LayoutTypeID DeviceAllocation::layout() const { -+ return layout_; -+} -+ -+std::vector const & DeviceAllocation::stride() const { -+ return stride_; -+} -+ -+/// Gets the extent vector -+std::vector const & DeviceAllocation::extent() const { -+ return extent_; -+} -+ -+/// Gets the number of adjacent tensors in memory -+int DeviceAllocation::batch_count() const { -+ return batch_count_; -+} -+ -+/// Gets the stride (in units of elements) beteween items -+int64_t DeviceAllocation::batch_stride() const { -+ return batch_stride_; -+} -+ -+/// Gets the stride (in units of bytes) beteween items -+int64_t DeviceAllocation::batch_stride_bytes() const { -+ return bytes(type_, batch_stride_); -+} -+ -+size_t DeviceAllocation::capacity() const { -+ return capacity_; -+} -+ -+size_t DeviceAllocation::bytes() const { -+ return bytes(type_, capacity_); -+} -+ -+/// Copies from an equivalent-sized tensor in device memory -+void DeviceAllocation::copy_from_device(void const *ptr) { -+ cudaError_t result = cudaMemcpy(data(), ptr, bytes(), cudaMemcpyDeviceToDevice); -+ if (result != cudaSuccess) { -+ throw std::runtime_error("Failed device-to-device copy"); -+ } -+} -+ -+/// Copies from an equivalent-sized tensor in device memory -+void DeviceAllocation::copy_from_host(void const *ptr) { -+ cudaError_t result = cudaMemcpy(data(), ptr, bytes(), cudaMemcpyHostToDevice); -+ if (result != cudaSuccess) { -+ throw std::runtime_error("Failed device-to-device copy"); -+ } -+} -+ -+/// Copies from an equivalent-sized tensor in device memory -+void DeviceAllocation::copy_to_host(void *ptr) { -+ cudaError_t result = cudaMemcpy(ptr, data(), bytes(), cudaMemcpyDeviceToHost); -+ if (result != cudaSuccess) { -+ throw std::runtime_error("Failed device-to-device copy"); -+ } -+} -+ -+void DeviceAllocation::initialize_random_device(int seed, Distribution dist) { -+ if (!good()) { -+ throw std::runtime_error("Attempting to initialize invalid allocation."); -+ } -+ -+ // Instantiate calls to CURAND here. This file takes a long time to compile for -+ // this reason. -+ -+ switch (type_) { -+ case library::NumericTypeID::kF16: -+ cutlass::reference::device::BlockFillRandom( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kBF16: -+ cutlass::reference::device::BlockFillRandom( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kTF32: -+ cutlass::reference::device::BlockFillRandom( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kF32: -+ cutlass::reference::device::BlockFillRandom( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kCBF16: -+ cutlass::reference::device::BlockFillRandom>( -+ reinterpret_cast *>(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kCTF32: -+ cutlass::reference::device::BlockFillRandom>( -+ reinterpret_cast *>(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kCF32: -+ cutlass::reference::device::BlockFillRandom>( -+ reinterpret_cast *>(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kF64: -+ cutlass::reference::device::BlockFillRandom( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kCF64: -+ cutlass::reference::device::BlockFillRandom>( -+ reinterpret_cast *>(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kS2: -+ cutlass::reference::device::BlockFillRandom( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kS4: -+ cutlass::reference::device::BlockFillRandom( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kS8: -+ cutlass::reference::device::BlockFillRandom( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kS16: -+ cutlass::reference::device::BlockFillRandom( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kS32: -+ cutlass::reference::device::BlockFillRandom( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kS64: -+ cutlass::reference::device::BlockFillRandom( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kB1: -+ cutlass::reference::device::BlockFillRandom( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kU2: -+ cutlass::reference::device::BlockFillRandom( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kU4: -+ cutlass::reference::device::BlockFillRandom( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kU8: -+ cutlass::reference::device::BlockFillRandom( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kU16: -+ cutlass::reference::device::BlockFillRandom( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kU32: -+ cutlass::reference::device::BlockFillRandom( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kU64: -+ cutlass::reference::device::BlockFillRandom( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ default: break; -+ } -+} -+ -+void DeviceAllocation::initialize_random_host(int seed, Distribution dist) { -+ if (!good()) { -+ throw std::runtime_error("Attempting to initialize invalid allocation."); -+ } -+ -+ std::vector host_data(bytes()); -+ -+ switch (type_) { -+ case library::NumericTypeID::kF16: -+ cutlass::reference::host::BlockFillRandom( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kBF16: -+ cutlass::reference::host::BlockFillRandom( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kTF32: -+ cutlass::reference::host::BlockFillRandom( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kF32: -+ cutlass::reference::host::BlockFillRandom( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kCF16: -+ cutlass::reference::host::BlockFillRandom>( -+ reinterpret_cast *>(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kCBF16: -+ cutlass::reference::host::BlockFillRandom>( -+ reinterpret_cast *>(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kCTF32: -+ cutlass::reference::host::BlockFillRandom>( -+ reinterpret_cast *>(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kCF32: -+ cutlass::reference::host::BlockFillRandom>( -+ reinterpret_cast *>(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kF64: -+ cutlass::reference::host::BlockFillRandom( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kCF64: -+ cutlass::reference::host::BlockFillRandom>( -+ reinterpret_cast *>(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kS2: -+ cutlass::reference::host::BlockFillRandom( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kS4: -+ cutlass::reference::host::BlockFillRandom( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kS8: -+ cutlass::reference::host::BlockFillRandom( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kS16: -+ cutlass::reference::host::BlockFillRandom( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kS32: -+ cutlass::reference::host::BlockFillRandom( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kS64: -+ cutlass::reference::host::BlockFillRandom( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kB1: -+ cutlass::reference::host::BlockFillRandom( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kU2: -+ cutlass::reference::host::BlockFillRandom( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kU4: -+ cutlass::reference::host::BlockFillRandom( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kU8: -+ cutlass::reference::host::BlockFillRandom( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kU16: -+ cutlass::reference::host::BlockFillRandom( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kU32: -+ cutlass::reference::host::BlockFillRandom( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ case library::NumericTypeID::kU64: -+ cutlass::reference::host::BlockFillRandom( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ dist -+ ); -+ break; -+ default: break; -+ } -+ -+ copy_from_host(host_data.data()); -+} -+ -+void DeviceAllocation::initialize_random_sparsemeta_device(int seed, int MetaSizeInBits) { -+ if (!good()) { -+ throw std::runtime_error("Attempting to initialize invalid allocation."); -+ } -+ -+ // Instantiate calls to CURAND here. This file takes a long time to compile for -+ // this reason. -+ -+ switch (type_) { -+ case library::NumericTypeID::kU16: -+ cutlass::reference::device::BlockFillRandomSparseMeta( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ MetaSizeInBits -+ ); -+ break; -+ case library::NumericTypeID::kU32: -+ cutlass::reference::device::BlockFillRandomSparseMeta( -+ reinterpret_cast(pointer_), -+ capacity_, -+ seed, -+ MetaSizeInBits -+ ); -+ break; -+ default: -+ break; -+ } -+} -+ -+void DeviceAllocation::initialize_random_sparsemeta_host(int seed, int MetaSizeInBits) { -+ if (!good()) { -+ throw std::runtime_error("Attempting to initialize invalid allocation."); -+ } -+ -+ std::vector host_data(bytes()); -+ -+ switch (type_) { -+ case library::NumericTypeID::kS16: -+ cutlass::reference::host::BlockFillRandomSparseMeta( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ MetaSizeInBits -+ ); -+ break; -+ case library::NumericTypeID::kS32: -+ cutlass::reference::host::BlockFillRandomSparseMeta( -+ reinterpret_cast(host_data.data()), -+ capacity_, -+ seed, -+ MetaSizeInBits -+ ); -+ break; -+ default: -+ break; -+ } -+ -+ copy_from_host(host_data.data()); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Returns true if two blocks have exactly the same value -+bool DeviceAllocation::block_compare_equal( -+ library::NumericTypeID numeric_type, -+ void const *ptr_A, -+ void const *ptr_B, -+ size_t capacity) { -+ -+ switch (numeric_type) { -+ case library::NumericTypeID::kF16: -+ return reference::device::BlockCompareEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kBF16: -+ return reference::device::BlockCompareEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kTF32: -+ return reference::device::BlockCompareEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kF32: -+ return reference::device::BlockCompareEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kCF32: -+ return reference::device::BlockCompareEqual >( -+ reinterpret_cast const *>(ptr_A), -+ reinterpret_cast const *>(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kCF16: -+ return reference::device::BlockCompareEqual>( -+ reinterpret_cast const *>(ptr_A), -+ reinterpret_cast const *>(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kCBF16: -+ return reference::device::BlockCompareEqual>( -+ reinterpret_cast const *>(ptr_A), -+ reinterpret_cast const *>(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kCTF32: -+ return reference::device::BlockCompareEqual>( -+ reinterpret_cast const *>(ptr_A), -+ reinterpret_cast const *>(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kF64: -+ return reference::device::BlockCompareEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kCF64: -+ return reference::device::BlockCompareEqual>( -+ reinterpret_cast const *>(ptr_A), -+ reinterpret_cast const *>(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kS2: -+ return reference::device::BlockCompareEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kS4: -+ return reference::device::BlockCompareEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kS8: -+ return reference::device::BlockCompareEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kS16: -+ return reference::device::BlockCompareEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kS32: -+ return reference::device::BlockCompareEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kS64: -+ return reference::device::BlockCompareEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kB1: -+ return reference::device::BlockCompareEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kU2: -+ return reference::device::BlockCompareEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kU4: -+ return reference::device::BlockCompareEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kU8: -+ return reference::device::BlockCompareEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kU16: -+ return reference::device::BlockCompareEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kU32: -+ return reference::device::BlockCompareEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kU64: -+ return reference::device::BlockCompareEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity); -+ -+ default: -+ throw std::runtime_error("Unsupported numeric type"); -+ } -+} -+ -+/// Returns true if two blocks have approximately the same value -+bool DeviceAllocation::block_compare_relatively_equal( -+ library::NumericTypeID numeric_type, -+ void const *ptr_A, -+ void const *ptr_B, -+ size_t capacity, -+ double epsilon, -+ double nonzero_floor) { -+ -+ switch (numeric_type) { -+ case library::NumericTypeID::kF16: -+ return reference::device::BlockCompareRelativelyEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity, -+ static_cast(epsilon), -+ static_cast(nonzero_floor)); -+ -+ case library::NumericTypeID::kBF16: -+ return reference::device::BlockCompareRelativelyEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity, -+ static_cast(epsilon), -+ static_cast(nonzero_floor)); -+ -+ case library::NumericTypeID::kTF32: -+ return reference::device::BlockCompareRelativelyEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity, -+ static_cast(epsilon), -+ static_cast(nonzero_floor)); -+ -+ case library::NumericTypeID::kF32: -+ return reference::device::BlockCompareRelativelyEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity, -+ static_cast(epsilon), -+ static_cast(nonzero_floor)); -+ -+ case library::NumericTypeID::kF64: -+ return reference::device::BlockCompareRelativelyEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity, -+ static_cast(epsilon), -+ static_cast(nonzero_floor)); -+ -+ case library::NumericTypeID::kS2: -+ return reference::device::BlockCompareRelativelyEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity, -+ static_cast(epsilon), -+ static_cast(nonzero_floor)); -+ -+ case library::NumericTypeID::kS4: -+ return reference::device::BlockCompareRelativelyEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity, -+ static_cast(epsilon), -+ static_cast(nonzero_floor)); -+ -+ case library::NumericTypeID::kS8: -+ return reference::device::BlockCompareRelativelyEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity, -+ static_cast(epsilon), -+ static_cast(nonzero_floor)); -+ -+ case library::NumericTypeID::kS16: -+ return reference::device::BlockCompareRelativelyEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity, -+ static_cast(epsilon), -+ static_cast(nonzero_floor)); -+ -+ case library::NumericTypeID::kS32: -+ return reference::device::BlockCompareRelativelyEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity, -+ static_cast(epsilon), -+ static_cast(nonzero_floor)); -+ -+ case library::NumericTypeID::kS64: -+ return reference::device::BlockCompareRelativelyEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity, -+ static_cast(epsilon), -+ static_cast(nonzero_floor)); -+ -+ case library::NumericTypeID::kB1: -+ return reference::device::BlockCompareRelativelyEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity, -+ static_cast(epsilon), -+ static_cast(nonzero_floor)); -+ -+ case library::NumericTypeID::kU2: -+ return reference::device::BlockCompareRelativelyEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity, -+ static_cast(epsilon), -+ static_cast(nonzero_floor)); -+ -+ case library::NumericTypeID::kU4: -+ return reference::device::BlockCompareRelativelyEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity, -+ static_cast(epsilon), -+ static_cast(nonzero_floor)); -+ -+ case library::NumericTypeID::kU8: -+ return reference::device::BlockCompareRelativelyEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity, -+ static_cast(epsilon), -+ static_cast(nonzero_floor)); -+ -+ case library::NumericTypeID::kU16: -+ return reference::device::BlockCompareRelativelyEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity, -+ static_cast(epsilon), -+ static_cast(nonzero_floor)); -+ -+ case library::NumericTypeID::kU32: -+ return reference::device::BlockCompareRelativelyEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity, -+ static_cast(epsilon), -+ static_cast(nonzero_floor)); -+ -+ case library::NumericTypeID::kU64: -+ return reference::device::BlockCompareRelativelyEqual( -+ reinterpret_cast(ptr_A), -+ reinterpret_cast(ptr_B), -+ capacity, -+ static_cast(epsilon), -+ static_cast(nonzero_floor)); -+ -+ // No relatively equal comparison for complex numbers. -+ // -+ // As a simplification, we can require bitwise equality. This avoids false positives. -+ // (i.e. "pass" really means passing. "Fail" may not actually mean failure given appropriate epsilon.) -+ // -+ case library::NumericTypeID::kCF16: -+ return reference::device::BlockCompareEqual >( -+ reinterpret_cast const *>(ptr_A), -+ reinterpret_cast const *>(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kCF32: -+ return reference::device::BlockCompareEqual >( -+ reinterpret_cast const *>(ptr_A), -+ reinterpret_cast const *>(ptr_B), -+ capacity); -+ -+ case library::NumericTypeID::kCF64: -+ return reference::device::BlockCompareEqual >( -+ reinterpret_cast const *>(ptr_A), -+ reinterpret_cast const *>(ptr_B), -+ capacity); -+ -+ default: -+ { -+ throw std::runtime_error(std::string("Unsupported numeric type: ") + to_string(numeric_type)); -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Permits copying dynamic vectors into static-length vectors -+template -+struct vector_to_coord { -+ -+ vector_to_coord(TensorCoord &coord, std::vector const &vec) { -+ -+ coord[Rank - 1] = vec.at(Rank - 1); -+ -+ if (Rank > 1) { -+ vector_to_coord(coord, vec); -+ } -+ } -+ -+ vector_to_coord(TensorCoord &coord, std::vector const &vec) { -+ -+ coord[Rank - 1] = (int)vec.at(Rank - 1); -+ -+ if (Rank > 1) { -+ vector_to_coord(coord, vec); -+ } -+ } -+}; -+ -+/// Permits copying dynamic vectors into static-length vectors -+template -+struct vector_to_coord { -+ -+ vector_to_coord(TensorCoord &coord, std::vector const &vec) { -+ -+ coord[0] = vec.at(0); -+ } -+ -+ vector_to_coord(TensorCoord &coord, std::vector const &vec) { -+ -+ coord[0] = (int)vec.at(0); -+ } -+}; -+ -+/// Permits copying dynamic vectors into static-length vectors -+template -+struct vector_to_coord { -+ -+ vector_to_coord(TensorCoord &coord, std::vector const &vec) { -+ -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+static void write_tensor_csv_static_tensor_view( -+ std::ostream &out, -+ DeviceAllocation &allocation) { -+ -+ Coord extent; -+ Coord stride; -+ -+ if (allocation.extent().size() != Layout::kRank) { -+ throw std::runtime_error("Allocation extent has invalid rank"); -+ } -+ -+ if (allocation.stride().size() != Layout::kStrideRank) { -+ throw std::runtime_error("Allocation stride has invalid rank"); -+ } -+ -+ vector_to_coord, Layout::kRank>(extent, allocation.extent()); -+ vector_to_coord, -+ Layout::kStrideRank>(stride, allocation.stride()); -+ -+ Layout layout(stride); -+ HostTensor host_tensor(extent, layout, false); -+ -+ if (host_tensor.capacity() != allocation.batch_stride()) { -+ throw std::runtime_error("Unexpected capacity to equal."); -+ } -+ -+ host_tensor.copy_in_device_to_host( -+ static_cast(allocation.data()), -+ allocation.batch_stride()); -+ -+ TensorViewWrite(out, host_tensor.host_view()); -+ -+ out << "\n\n"; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+static void write_tensor_csv_static_type( -+ std::ostream &out, -+ DeviceAllocation &allocation) { -+ -+ switch (allocation.layout()) { -+ case library::LayoutTypeID::kRowMajor: -+ write_tensor_csv_static_tensor_view(out, allocation); -+ break; -+ case library::LayoutTypeID::kColumnMajor: -+ write_tensor_csv_static_tensor_view(out, allocation); -+ break; -+ case library::LayoutTypeID::kRowMajorInterleavedK2: -+ write_tensor_csv_static_tensor_view>(out, allocation); -+ break; -+ case library::LayoutTypeID::kColumnMajorInterleavedK2: -+ write_tensor_csv_static_tensor_view>(out, allocation); -+ break; -+ case library::LayoutTypeID::kRowMajorInterleavedK4: -+ write_tensor_csv_static_tensor_view>(out, allocation); -+ break; -+ case library::LayoutTypeID::kColumnMajorInterleavedK4: -+ write_tensor_csv_static_tensor_view>(out, allocation); -+ break; -+ case library::LayoutTypeID::kRowMajorInterleavedK16: -+ write_tensor_csv_static_tensor_view>(out, allocation); -+ break; -+ case library::LayoutTypeID::kColumnMajorInterleavedK16: -+ write_tensor_csv_static_tensor_view>(out, allocation); -+ break; -+ case library::LayoutTypeID::kRowMajorInterleavedK32: -+ write_tensor_csv_static_tensor_view>(out, allocation); -+ break; -+ case library::LayoutTypeID::kColumnMajorInterleavedK32: -+ write_tensor_csv_static_tensor_view>(out, allocation); -+ break; -+ case library::LayoutTypeID::kRowMajorInterleavedK64: -+ write_tensor_csv_static_tensor_view>(out, allocation); -+ break; -+ case library::LayoutTypeID::kColumnMajorInterleavedK64: -+ write_tensor_csv_static_tensor_view>(out, allocation); -+ break; -+ case library::LayoutTypeID::kTensorNHWC: -+ write_tensor_csv_static_tensor_view(out, allocation); -+ break; -+ case library::LayoutTypeID::kTensorNDHWC: -+ write_tensor_csv_static_tensor_view(out, allocation); -+ break; -+ case library::LayoutTypeID::kTensorNC32HW32: -+ write_tensor_csv_static_tensor_view>(out, allocation); -+ break; -+ case library::LayoutTypeID::kTensorNC64HW64: -+ write_tensor_csv_static_tensor_view>(out, allocation); -+ break; -+ case library::LayoutTypeID::kTensorC32RSK32: -+ write_tensor_csv_static_tensor_view>(out, allocation); -+ break; -+ case library::LayoutTypeID::kTensorC64RSK64: -+ write_tensor_csv_static_tensor_view>(out, allocation); -+ break; -+ default: -+ throw std::runtime_error("Unhandled layout"); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Writes a tensor to csv -+void DeviceAllocation::write_tensor_csv( -+ std::ostream &out) { -+ -+ switch (this->type()) { -+ case library::NumericTypeID::kF16: -+ write_tensor_csv_static_type(out, *this); -+ break; -+ -+ case library::NumericTypeID::kBF16: -+ write_tensor_csv_static_type(out, *this); -+ break; -+ -+ case library::NumericTypeID::kTF32: -+ write_tensor_csv_static_type(out, *this); -+ break; -+ -+ case library::NumericTypeID::kF32: -+ write_tensor_csv_static_type(out, *this); -+ break; -+ -+ case library::NumericTypeID::kF64: -+ write_tensor_csv_static_type(out, *this); -+ break; -+ -+ case library::NumericTypeID::kS2: -+ write_tensor_csv_static_type(out, *this); -+ break; -+ -+ case library::NumericTypeID::kS4: -+ write_tensor_csv_static_type(out, *this); -+ break; -+ -+ case library::NumericTypeID::kS8: -+ write_tensor_csv_static_type(out, *this); -+ break; -+ -+ case library::NumericTypeID::kS16: -+ write_tensor_csv_static_type(out, *this); -+ break; -+ -+ case library::NumericTypeID::kS32: -+ write_tensor_csv_static_type(out, *this); -+ break; -+ -+ case library::NumericTypeID::kS64: -+ write_tensor_csv_static_type(out, *this); -+ break; -+ -+ case library::NumericTypeID::kB1: -+ write_tensor_csv_static_type(out, *this); -+ break; -+ -+ case library::NumericTypeID::kU2: -+ write_tensor_csv_static_type(out, *this); -+ break; -+ -+ case library::NumericTypeID::kU4: -+ write_tensor_csv_static_type(out, *this); -+ break; -+ -+ case library::NumericTypeID::kU8: -+ write_tensor_csv_static_type(out, *this); -+ break; -+ -+ case library::NumericTypeID::kU16: -+ write_tensor_csv_static_type(out, *this); -+ break; -+ -+ case library::NumericTypeID::kU32: -+ write_tensor_csv_static_type(out, *this); -+ break; -+ -+ case library::NumericTypeID::kU64: -+ write_tensor_csv_static_type(out, *this); -+ break; -+ -+ case library::NumericTypeID::kCF16: -+ write_tensor_csv_static_type >(out, *this); -+ break; -+ -+ case library::NumericTypeID::kCF32: -+ write_tensor_csv_static_type >(out, *this); -+ break; -+ -+ case library::NumericTypeID::kCF64: -+ write_tensor_csv_static_type >(out, *this); -+ break; -+ -+ default: -+ throw std::runtime_error("Unsupported numeric type"); -+ } -+} -+ -+template -+static void tensor_fill_tensor_view(DeviceAllocation &allocation, Element val = Element()) { -+ Coord extent; -+ Coord stride; -+ -+ if (allocation.extent().size() != Layout::kRank) { -+ throw std::runtime_error("Allocation extent has invalid rank"); -+ } -+ -+ if (allocation.stride().size() != Layout::kStrideRank) { -+ throw std::runtime_error("Allocation stride has invalid rank"); -+ } -+ -+ vector_to_coord, Layout::kRank>(extent, allocation.extent()); -+ vector_to_coord, -+ Layout::kStrideRank>(stride, allocation.stride()); -+ -+ TensorView view( -+ static_cast(allocation.data()), -+ Layout(stride), -+ extent -+ ); -+ -+ -+ cutlass::reference::device::TensorFill( -+ view, -+ val -+ ); -+} -+ -+template -+static void tensor_fill(DeviceAllocation &allocation, Element val = Element()) { -+ switch (allocation.layout()) { -+ case library::LayoutTypeID::kRowMajor: -+ tensor_fill_tensor_view(allocation, val); -+ break; -+ case library::LayoutTypeID::kColumnMajor: -+ tensor_fill_tensor_view(allocation, val); -+ break; -+ case library::LayoutTypeID::kTensorNHWC: -+ tensor_fill_tensor_view(allocation, val); -+ break; -+ case library::LayoutTypeID::kTensorNDHWC: -+ tensor_fill_tensor_view(allocation, val); -+ break; -+ case library::LayoutTypeID::kTensorNC32HW32: -+ tensor_fill_tensor_view>(allocation, val); -+ break; -+ case library::LayoutTypeID::kTensorNC64HW64: -+ tensor_fill_tensor_view>(allocation, val); -+ break; -+ case library::LayoutTypeID::kTensorC32RSK32: -+ tensor_fill_tensor_view>(allocation, val); -+ break; -+ case library::LayoutTypeID::kTensorC64RSK64: -+ tensor_fill_tensor_view>(allocation, val); -+ break; -+ default: -+ throw std::runtime_error("Unsupported layout"); -+ break; -+ } -+} -+ -+/// Fills a tensor uniformly with a value (most frequently used to clear the tensor) -+void DeviceAllocation::fill(double val = 0.0) { -+ -+ switch (this->type()) { -+ case library::NumericTypeID::kF16: -+ tensor_fill(*this, static_cast(val)); -+ break; -+ -+ case library::NumericTypeID::kBF16: -+ tensor_fill(*this, static_cast(val)); -+ break; -+ -+ case library::NumericTypeID::kTF32: -+ tensor_fill(*this, static_cast(val)); -+ break; -+ -+ case library::NumericTypeID::kF32: -+ tensor_fill(*this, static_cast(val)); -+ break; -+ -+ case library::NumericTypeID::kF64: -+ tensor_fill(*this, static_cast(val)); -+ break; -+ -+ case library::NumericTypeID::kS2: -+ tensor_fill(*this, static_cast(val)); -+ break; -+ -+ case library::NumericTypeID::kS4: -+ tensor_fill(*this, static_cast(val)); -+ break; -+ -+ case library::NumericTypeID::kS8: -+ tensor_fill(*this, static_cast(val)); -+ break; -+ -+ case library::NumericTypeID::kS16: -+ tensor_fill(*this, static_cast(val)); -+ break; -+ -+ case library::NumericTypeID::kS32: -+ tensor_fill(*this, static_cast(val)); -+ break; -+ -+ case library::NumericTypeID::kS64: -+ tensor_fill(*this, static_cast(val)); -+ break; -+ -+ case library::NumericTypeID::kB1: -+ tensor_fill(*this, static_cast(val)); -+ break; -+ -+ case library::NumericTypeID::kU2: -+ tensor_fill(*this, static_cast(val)); -+ break; -+ -+ case library::NumericTypeID::kU4: -+ tensor_fill(*this, static_cast(val)); -+ break; -+ -+ case library::NumericTypeID::kU8: -+ tensor_fill(*this, static_cast(val)); -+ break; -+ -+ case library::NumericTypeID::kU16: -+ tensor_fill(*this, static_cast(val)); -+ break; -+ -+ case library::NumericTypeID::kU32: -+ tensor_fill(*this, static_cast(val)); -+ break; -+ -+ case library::NumericTypeID::kU64: -+ tensor_fill(*this, static_cast(val)); -+ break; -+ -+ case library::NumericTypeID::kCF16: -+ tensor_fill >(*this, from_real(val)); -+ break; -+ -+ case library::NumericTypeID::kCF32: -+ tensor_fill >(*this, from_real(val)); -+ break; -+ -+ case library::NumericTypeID::kCF64: -+ tensor_fill >(*this, from_real(val)); -+ break; -+ -+ default: -+ throw std::runtime_error("Unsupported numeric type"); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/profiler/src/device_allocation.h b/3rdparty/cutlass/tools/profiler/src/device_allocation.h -new file mode 100644 -index 0000000..d0bdfd4 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/device_allocation.h -@@ -0,0 +1,226 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Execution environment -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "cutlass/library/library.h" -+#include "cutlass/util/distribution.h" -+ -+#include "enumerated_types.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Device memory allocation -+class DeviceAllocation { -+private: -+ -+ /// Data type of contained elements -+ library::NumericTypeID type_; -+ -+ /// Gets the stride between elements -+ size_t batch_stride_; -+ -+ /// Capacity in elements of device allocation -+ size_t capacity_; -+ -+ /// Pointer to device memory -+ void *pointer_; -+ -+ /// Layout type ID -+ library::LayoutTypeID layout_; -+ -+ /// Stride vector -+ std::vector stride_; -+ -+ /// Extent vector -+ std::vector extent_; -+ -+ /// Support allocating a 'batch' of non-overlapping tensors in contiguous memory -+ int batch_count_; -+ -+ /// Buffer holding TensorRef instance to recently allocated memory -+ std::vector tensor_ref_buffer_; -+ -+public: -+ // -+ // Static member functions -+ // -+ -+ /// Determines the number of bytes needed to represent this numeric type -+ static size_t bytes(library::NumericTypeID type, size_t capacity); -+ -+ /// Returns the stride of a packed layout -+ static std::vector get_packed_layout( -+ library::LayoutTypeID layout_id, -+ std::vector const &extent); -+ -+ /// returns the capacity needed -+ static size_t construct_layout( -+ void *bytes, -+ library::LayoutTypeID layout_id, -+ std::vector const &extent, -+ std::vector &stride); -+ -+ /// Returns true if two blocks have exactly the same value -+ static bool block_compare_equal( -+ library::NumericTypeID numeric_type, -+ void const *ptr_A, -+ void const *ptr_B, -+ size_t capacity); -+ -+ /// Returns true if two blocks have approximately the same value -+ static bool block_compare_relatively_equal( -+ library::NumericTypeID numeric_type, -+ void const *ptr_A, -+ void const *ptr_B, -+ size_t capacity, -+ double epsilon, -+ double nonzero_floor); -+ -+public: -+ // -+ // Methods -+ // -+ -+ DeviceAllocation(); -+ -+ DeviceAllocation(library::NumericTypeID type, size_t capacity); -+ -+ DeviceAllocation( -+ library::NumericTypeID type, -+ library::LayoutTypeID layout_id, -+ std::vector const &extent, -+ std::vector const &stride = std::vector(), -+ int batch_count = 1); -+ -+ ~DeviceAllocation(); -+ -+ DeviceAllocation &reset(); -+ -+ /// Allocates device memory of a given type and capacity -+ DeviceAllocation &reset(library::NumericTypeID type, size_t capacity); -+ -+ /// Allocates memory for a given layout and tensor -+ DeviceAllocation &reset( -+ library::NumericTypeID type, -+ library::LayoutTypeID layout_id, -+ std::vector const &extent, -+ std::vector const &stride = std::vector(), -+ int batch_count = 1); -+ -+ /// Returns a buffer owning the tensor reference -+ std::vector &tensor_ref() { -+ return tensor_ref_buffer_; -+ } -+ -+ bool good() const; -+ -+ /// Data type of contained elements -+ library::NumericTypeID type() const; -+ -+ /// Pointer to start of device memory allocation -+ void *data() const; -+ -+ /// Pointer to the first element of a batch -+ void *batch_data(int batch_idx) const; -+ -+ /// Gets the layout type -+ library::LayoutTypeID layout() const; -+ -+ /// Gets the stride vector -+ std::vector const & stride() const; -+ -+ /// Gets the extent vector -+ std::vector const & extent() const; -+ -+ /// Gets the number of adjacent tensors in memory -+ int batch_count() const; -+ -+ /// Gets the stride (in units of elements) beteween items -+ int64_t batch_stride() const; -+ -+ /// Gets the stride (in units of bytes) beteween items -+ int64_t batch_stride_bytes() const; -+ -+ /// Capacity of allocation in number of elements -+ size_t capacity() const; -+ -+ /// Capacity of allocation in bytes -+ size_t bytes() const; -+ -+ /// Initializes a device allocation to a random distribution using cuRAND -+ void initialize_random_device(int seed, Distribution dist); -+ -+ /// Initializes a host allocation to a random distribution using std::cout -+ void initialize_random_host(int seed, Distribution dist); -+ -+ /// Initializes a device allocation to a random distribution using cuRAND -+ void initialize_random_sparsemeta_device(int seed, int MetaSizeInBits); -+ -+ /// Initializes a host allocation to a random distribution using std::cout -+ void initialize_random_sparsemeta_host(int seed, int MetaSizeInBits); -+ -+ /// Uniformly fills a tensor with a value when provided o.w. zero -+ void fill(double value); -+ -+ /// Copies from an equivalent-sized tensor in device memory -+ void copy_from_device(void const *ptr); -+ -+ /// Copies from an equivalent-sized tensor in device memory -+ void copy_from_host(void const *ptr); -+ -+ /// Copies from an equivalent-sized tensor in device memory -+ void copy_to_host(void *ptr); -+ -+ /// Writes a tensor to csv -+ void write_tensor_csv(std::ostream &out); -+}; -+ -+using DeviceAllocationList = std::list; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/profiler/src/device_context.cu b/3rdparty/cutlass/tools/profiler/src/device_context.cu -new file mode 100644 -index 0000000..117f78b ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/device_context.cu -@@ -0,0 +1,197 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief -+*/ -+ -+#include "device_context.h" -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Allocates memory of a given type, capacity (elements), and name -+DeviceAllocation *DeviceContext::allocate_block( -+ std::string const &name, -+ library::NumericTypeID type, -+ size_t capacity) { -+ -+ device_memory_.emplace_back(type, capacity); -+ DeviceAllocation *allocation = &device_memory_.back(); -+ -+ allocations_[name] = allocation; -+ return allocation; -+} -+ -+/// Allocates memory of a given type, capacity (elements), and name -+DeviceAllocation *DeviceContext::allocate_tensor( -+ std::string const &name, -+ library::NumericTypeID type, -+ library::LayoutTypeID layout_id, -+ std::vector const &extent, -+ std::vector const &stride, -+ int batch_count) { -+ -+ device_memory_.emplace_back(type, layout_id, extent, stride, batch_count); -+ DeviceAllocation *allocation = &device_memory_.back(); -+ -+ allocations_[name] = allocation; -+ return allocation; -+} -+ -+/// Allocates memory of a given type, capacity (elements), and name -+DeviceAllocation *DeviceContext::allocate_tensor( -+ Options const &options, -+ std::string const &name, -+ library::NumericTypeID type, -+ library::LayoutTypeID layout_id, -+ std::vector const &extent, -+ std::vector const &stride, -+ int batch_count) { -+ -+ DeviceAllocation *allocation = -+ allocate_tensor(name, type, layout_id, extent, stride, batch_count); -+ -+ if (options.initialization.enabled) { -+ Distribution data_distribution = options.initialization.data_distribution; -+ -+ // check if data distribution is allowed to change -+ if(!options.initialization.fix_data_distribution) { -+ // change data distribution based on bit width -+ switch(type) { -+ case library::NumericTypeID::kF16: -+ data_distribution.set_uniform(-3, 3, 0); -+ break; -+ case library::NumericTypeID::kB1: -+ data_distribution.set_uniform(0, 1, 0); -+ break; -+ case library::NumericTypeID::kS2: -+ data_distribution.set_uniform(-1, 1, 0); -+ break; -+ case library::NumericTypeID::kS4: -+ data_distribution.set_uniform(-2, 2, 0); -+ break; -+ case library::NumericTypeID::kU2: -+ data_distribution.set_uniform(0, 2, 0); -+ break; -+ case library::NumericTypeID::kU4: -+ data_distribution.set_uniform(0, 2, 0); -+ break; -+ case library::NumericTypeID::kS8: -+ data_distribution.set_uniform(-3, 3, 0); -+ break; -+ case library::NumericTypeID::kU8: -+ data_distribution.set_uniform(0, 4, 0); -+ break; -+ default: break; -+ } -+ } -+ -+ if (options.initialization.provider == library::Provider::kReferenceDevice) { -+ allocation->initialize_random_device( -+ options.initialization.seed, -+ data_distribution); -+ } -+ else if (options.initialization.provider == library::Provider::kReferenceHost) { -+ allocation->initialize_random_host( -+ options.initialization.seed, -+ data_distribution); -+ } -+ } -+ -+ return allocation; -+} -+ -+/// Allocates memory for sparse meta data -+DeviceAllocation *DeviceContext::allocate_sparsemeta_tensor( -+ Options const &options, -+ std::string const &name, -+ library::NumericTypeID type, -+ library::LayoutTypeID layout_id, -+ library::NumericTypeID type_a, -+ std::vector const &extent, -+ std::vector const &stride, -+ int batch_count) { -+ -+ DeviceAllocation *allocation = -+ allocate_tensor(name, type, layout_id, extent, stride, batch_count); -+ -+ if (options.initialization.enabled) { -+ // TF32 has 4bit meta data. The rest has 2bit. -+ int MetaSizeInBits = (cutlass::library::sizeof_bits(type_a) == 32) ? 4 : 2; -+ -+ if (options.initialization.provider == library::Provider::kReferenceDevice) { -+ allocation->initialize_random_sparsemeta_device( -+ options.initialization.seed, -+ MetaSizeInBits); -+ } -+ else if (options.initialization.provider == library::Provider::kReferenceHost) { -+ allocation->initialize_random_sparsemeta_host( -+ options.initialization.seed, -+ MetaSizeInBits); -+ } -+ } -+ -+ return allocation; -+} -+/// Clears named allocations (but does not necessarily free memory) -+void DeviceContext::clear() { -+ allocations_.clear(); -+} -+ -+/// Frees all device memory allocations -+void DeviceContext::free() { -+ allocations_.clear(); -+ device_memory_.clear(); -+} -+ -+/// Gets the allocation by name -+DeviceAllocation &DeviceContext::at(std::string const &name) { -+ return *allocations_.at(name); -+} -+ -+size_t DeviceContext::size() const { -+ return allocations_.size(); -+} -+ -+DeviceContext::AllocationMap::iterator DeviceContext::begin() { -+ return allocations_.begin(); -+} -+ -+DeviceContext::AllocationMap::iterator DeviceContext::end() { -+ return allocations_.end(); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/profiler/src/device_context.h b/3rdparty/cutlass/tools/profiler/src/device_context.h -new file mode 100644 -index 0000000..16a72f9 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/device_context.h -@@ -0,0 +1,128 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief -+*/ -+ -+#pragma once -+ -+#include -+#include -+ -+ -+#include "cutlass/library/library.h" -+#include "cutlass/library/util.h" -+ -+#include "options.h" -+#include "device_allocation.h" -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Collection of allocations on the device -+class DeviceContext { -+public: -+ -+ // -+ // Type definitions -+ // -+ using AllocationMap = std::map; -+ -+private: -+ // -+ // Data members -+ // -+ -+ /// Memory allocations that exist (owning) -+ DeviceAllocationList device_memory_; -+ -+ /// Non-owning set of named allocations -+ AllocationMap allocations_; -+ -+public: -+ -+ /// Allocates memory of a given type, capacity (elements), and name -+ DeviceAllocation *allocate_block( -+ std::string const &name, -+ library::NumericTypeID type, -+ size_t capacity); -+ -+ /// Allocates memory of a given type, capacity (elements), and name -+ DeviceAllocation *allocate_tensor( -+ std::string const &name, -+ library::NumericTypeID type, -+ library::LayoutTypeID layout_id, -+ std::vector const &extent, -+ std::vector const &stride = std::vector(), -+ int batch_count = 1); -+ -+ /// Allocates memory of a given type, capacity (elements), and name -+ DeviceAllocation *allocate_tensor( -+ Options const &options, -+ std::string const &name, -+ library::NumericTypeID type, -+ library::LayoutTypeID layout_id, -+ std::vector const &extent, -+ std::vector const &stride = std::vector(), -+ int batch_count = 1); -+ -+ /// Allocates memory for sparse meta data -+ DeviceAllocation *allocate_sparsemeta_tensor( -+ Options const &options, -+ std::string const &name, -+ library::NumericTypeID type, -+ library::LayoutTypeID layout_id, -+ library::NumericTypeID type_a, -+ std::vector const &extent, -+ std::vector const &stride = std::vector(), -+ int batch_count = 1); -+ -+ /// Clears named allocations (but does not necessarily free memory) -+ void clear(); -+ -+ /// Frees all device memory allocations -+ void free(); -+ -+ /// Gets the allocation by name -+ DeviceAllocation &at(std::string const &name); -+ -+ size_t size() const; -+ -+ AllocationMap::iterator begin(); -+ AllocationMap::iterator end(); -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/profiler/src/enumerated_types.h b/3rdparty/cutlass/tools/profiler/src/enumerated_types.h -new file mode 100644 -index 0000000..4d91324 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/enumerated_types.h -@@ -0,0 +1,169 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Provides several functions for filling tensors with data. -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+#include -+#include "cutlass/library/library.h" -+ -+#define TRACE(x) { std::cout << __FILE__ << ":" << __LINE__ << " " << x << std::endl; } -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+T from_string(std::string const &); -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Enumerated type describing how the performance testbench evaluates kernels. -+enum class ExecutionMode { -+ kProfile, ///< regular verification and profiling -+ kDryRun, ///< no kernels are launched or workspaces allocated; used to assess what operators might be launched -+ kEnumerate, ///< no kernels launched or workspaces allocated; lists all operation kind and operations -+ kTrace, ///< executes a single device-side computation with no other kernel launches -+ kInvalid -+}; -+ -+/// Converts a ExecutionMode enumerant to a string -+char const *to_string(ExecutionMode mode, bool pretty = false); -+ -+/// Parses a ExecutionMode enumerant from a string -+template <> -+ExecutionMode from_string(std::string const &str); -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Library algorithm mode -+enum class AlgorithmMode { -+ kMatching, ///< compare against best matching algorithm -+ kBest, ///< evaluate all library algorithms and report best -+ kDefault, ///< use the library's default algorithm option -+ kInvalid -+}; -+ -+/// Converts a ExecutionMode enumerant to a string -+char const *to_string(AlgorithmMode mode, bool pretty = false); -+ -+/// Parses a ExecutionMode enumerant from a string -+template <> -+AlgorithmMode from_string(std::string const &str); -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Outcome of a performance test -+enum class Disposition { -+ kPassed, -+ kFailed, -+ kNotRun, -+ kIncorrect, -+ kNotVerified, -+ kInvalidProblem, -+ kNotSupported, -+ kInvalid -+}; -+ -+/// Converts a Disposition enumerant to a string -+char const *to_string(Disposition disposition, bool pretty = false); -+ -+/// Parses a Disposition enumerant from a string -+template <> -+Disposition from_string(std::string const &str); -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Indicates when to save -+enum class SaveWorkspace { -+ kNever, -+ kIncorrect, -+ kAlways, -+ kInvalid -+}; -+ -+/// Converts a SaveWorkspace enumerant to a string -+char const *to_string(SaveWorkspace save_option, bool pretty = false); -+ -+/// Parses a SaveWorkspace enumerant from a string -+template <> -+SaveWorkspace from_string(std::string const &str); -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Indicates the type of kernel argument -+// ArgumentType can be both ScalarType or NumericType. Thus, enums kScalar and kNumeric -+// 1) kScalar: e.g. of a Scalar ArgumentType is u32 is a Scalar type. -+// Its c++ equivalent as "type name = initializer" is "u32 m = 32" -+// 2) kNumeric: e.g. of a Numeric ArgumentType is NumericTypeID is a Numeric type. -+// Its c++ equivalent as "type name = initializer" is "NumericTypeID numeric_type = u32" -+enum class ArgumentTypeID { -+ kScalar, -+ kInteger, -+ kTensor, -+ kBatchedTensor, -+ kStructure, -+ kEnumerated, -+ kInvalid -+}; -+ -+/// Converts a ArgumentTypeID enumerant to a string -+char const *to_string(ArgumentTypeID type, bool pretty = false); -+ -+/// Parses a ArgumentTypeID enumerant from a string -+template <> -+ArgumentTypeID from_string(std::string const &str); -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+// Profiler typedefs -+using ProviderVector = std::vector; -+using DispositionMap = std::map; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Print vector for the report -+template -+std::ostream& operator<< (std::ostream& out, const std::vector& v) { -+ for(int i = 0; i < v.size(); ++i) { -+ out << to_string(v[i], true) << (i+1 != v.size() ? "," : ""); -+ } -+ return out; -+} -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/profiler/src/gemm_operation_profiler.cu b/3rdparty/cutlass/tools/profiler/src/gemm_operation_profiler.cu -new file mode 100644 -index 0000000..4b15fda ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/gemm_operation_profiler.cu -@@ -0,0 +1,1219 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Execution environment -+*/ -+ -+#include -+#include -+#include -+#include -+ -+#include "cutlass/core_io.h" -+ -+#include "cublas_helpers.h" -+#include "gemm_operation_profiler.h" -+#include "gpu_timer.h" -+ -+#include "cutlass/library/singleton.h" -+#include "cutlass/library/library.h" -+#include "cutlass/library/handle.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Ctor -+GemmOperationProfiler::GemmOperationProfiler(Options const &options): -+ OperationProfiler( -+ options, -+ library::OperationKind::kGemm, -+ { -+ {ArgumentTypeID::kEnumerated, {"gemm_kind"}, "Variant of GEMM (gemm, batched, array, universal, planar_complex, planar_complex_array)"}, -+ {ArgumentTypeID::kEnumerated, {"split_k_mode"}, "Variant of split K mode(serial, parallel)"}, -+ {ArgumentTypeID::kInteger, {"m", "problem-size::m"}, "M dimension of the GEMM problem space"}, -+ {ArgumentTypeID::kInteger, {"n", "problem-size::n"}, "N dimension of the GEMM problem space"}, -+ {ArgumentTypeID::kInteger, {"k", "problem-size::k"}, "K dimension of the GEMM problem space"}, -+ {ArgumentTypeID::kTensor, {"A"}, "Tensor storing the A operand"}, -+ {ArgumentTypeID::kTensor, {"B"}, "Tensor storing the B operand"}, -+ {ArgumentTypeID::kTensor, {"C"}, "Tensor storing the C operand"}, -+ {ArgumentTypeID::kScalar, {"alpha", "epilogue::alpha"}, "Epilogue scalar alpha"}, -+ {ArgumentTypeID::kScalar, {"beta", "epilogue::beta"}, "Epilogue scalar beta"}, -+ {ArgumentTypeID::kInteger, {"split_k_slices", "split-k-slices"}, "Number of partitions of K dimension"}, -+ {ArgumentTypeID::kInteger, {"batch_count", "batch-count"}, "Number of GEMMs computed in one batch"}, -+ }, -+ { library::Provider::kCUBLAS} -+ ) { -+ -+ description_ = " General matrix-matrix product. D = alpha * A*B + beta * C"; -+} -+ -+/// Destructor -+GemmOperationProfiler::~GemmOperationProfiler() { -+ -+} -+ -+/// Prints usage statement for the math function -+void GemmOperationProfiler::print_usage(std::ostream &out) const { -+ out << "GEMM" << "\n\n"; -+ -+ OperationProfiler::print_usage(out); -+} -+ -+/// Prints examples -+void GemmOperationProfiler::print_examples(std::ostream &out) const { -+ -+ out << "\nExamples:\n\n" -+ << "Profile a particular problem size:\n" -+ << " $ cutlass_profiler --operation=Gemm --m=1024 --n=1024 --k=128\n\n" -+ -+ << "Schmoo over problem size and beta:\n" -+ << " $ cutlass_profiler --operation=Gemm --m=1024:4096:256 --n=1024:4096:256 --k=128:8192:128 --beta=0,1,2.5\n\n" -+ -+ << "Schmoo over accumulator types:\n" -+ << " $ cutlass_profiler --operation=Gemm --accumulator-type=f16,f32\n\n" -+ -+ << "Run when A is f16 with column-major and B is any datatype with row-major (For column major, use column, col, or n. For row major use, row or t):\n" -+ << " $ cutlass_profiler --operation=Gemm --A=f16:column --B=*:row\n\n" -+ -+ << "Profile a particular problem size with split K and paralell reduction:\n" -+ << " $ cutlass_profiler --operation=Gemm --split_k_mode=parallel --split_k_slices=2 --m=1024 --n=1024 --k=128\n\n" -+ -+ << "Using various input value distribution:\n" -+ << " $ cutlass_profiler --operation=Gemm --dist=uniform,min:0,max:3\n" -+ << " $ cutlass_profiler --operation=Gemm --dist=gaussian,mean:0,stddev:3\n" -+ << " $ cutlass_profiler --operation=Gemm --dist=sequential,start:0,delta:1\n\n" -+ -+ << "Run a kernel with cta tile size of 256x128x32 and save workspace if results are incorrect (note that --cta-tile::k=32 is default cta-tile size):\n" -+ << " $ cutlass_profiler --operation=Gemm --cta_m=256 --cta_n=128 --cta_k=32 --save-workspace=incorrect\n\n" -+ -+ << "Test your changes to gemm kernels with a quick functional test and save results in functional-test.csv:\n" -+ << " $ cutlass_profiler --operation=Gemm \\ \n" -+ << " --m=8,56,120,136,256,264,512,520,1024,1032,4096,8192,16384 \\ \n" -+ << " --n=8,56,120,136,256,264,512,520,1024,1032,4096,8192,16384 \\ \n" -+ << " --k=8,16,32,64,128,256,288,384,504,512,520 \\ \n" -+ << " --beta=0,1,2 --profiling-iterations=1 \\ \n" -+ << " --providers=cutlass --output=functional-test.csv\n\n"; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if 0 -+// used this for debugging -+static std::string byte_string(std::vector const &bytes) { -+ std::stringstream ss; -+ -+ ss << "0x"; -+ -+ for (size_t idx = bytes.size(); idx > 0; --idx) { -+ ss << std::hex << std::setw(2) << std::setfill('0') << uint32_t(bytes.at(idx - 1)); -+ } -+ -+ return ss.str(); -+} -+#endif -+ -+Status GemmOperationProfiler::GemmProblem::parse( -+ library::GemmDescription const &operation_desc, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ this->mode = library::GemmUniversalMode::kGemm; -+ -+ if (!arg_as_int(this->m, "m", problem_space, problem)) { -+ // default value -+ this->m = 1024; -+ } -+ -+ if (!arg_as_int(this->n, "n", problem_space, problem)) { -+ // default value -+ this->n = 1024; -+ } -+ -+ if (!arg_as_int(this->k, "k", problem_space, problem)) { -+ // default value -+ this->k = 1024; -+ } -+ -+ if (!arg_as_SplitKModeID(this->split_k_mode, "split_k_mode", problem_space, problem)) { -+ // defualt value -+ this->split_k_mode = library::SplitKMode::kSerial; -+ } -+ -+ this->mode = library::GemmUniversalMode::kGemm; -+ if(this->split_k_mode == library::SplitKMode::kParallel) { -+ this->mode = library::GemmUniversalMode::kGemmSplitKParallel; -+ } -+ -+ if (!arg_as_int(this->split_k_slices, "split_k_slices", problem_space, problem)) { -+ // default value -+ this->split_k_slices = 1; -+ } -+ -+ if (!arg_as_int(this->batch_count, "batch_count", problem_space, problem)) { -+ // default value -+ this->batch_count = 1; -+ } else if (this->batch_count > 1) { -+ this->mode = library::GemmUniversalMode::kBatched; -+ } -+ -+ if (this->split_k_slices > 1 && this->batch_count > 1) { -+ // At least one of these must be one -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.A, "A", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.B, "B", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.C, "C", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!arg_as_scalar( -+ this->alpha, -+ operation_desc.element_epilogue, -+ "alpha", -+ problem_space, -+ problem)) { -+ -+ if (!cast_from_double(this->alpha, operation_desc.element_epilogue, 1)) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ if (!arg_as_scalar( -+ this->beta, -+ operation_desc.element_epilogue, -+ "beta", -+ problem_space, -+ problem)) { -+ -+ if (!cast_from_double(this->beta, operation_desc.element_epilogue, 0)) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ this->lda = DeviceAllocation::get_packed_layout( -+ operation_desc.A.layout, {int(this->m), int(this->k)}).front(); -+ -+ this->ldb = DeviceAllocation::get_packed_layout( -+ operation_desc.B.layout, {int(this->k), int(this->n)}).front(); -+ -+ this->ldc = DeviceAllocation::get_packed_layout( -+ operation_desc.C.layout, {int(this->m), int(this->n)}).front(); -+ -+ return Status::kSuccess; -+} -+ -+/// Total number of bytes loaded -+int64_t GemmOperationProfiler::GemmProblem::bytes(library::GemmDescription const &operation_desc) const { -+ // Input bytes read and Output bytes written for the gemm problem -+ int64_t bytes = -+ int64_t(library::sizeof_bits(operation_desc.A.element) * m / 8) * k + -+ int64_t(library::sizeof_bits(operation_desc.B.element) * n / 8) * k + -+ int64_t(library::sizeof_bits(operation_desc.C.element) * m / 8) * n; -+ -+ // Set is_beta_zero true if beta is zero -+ bool is_beta_zero = std::all_of(beta.begin(), beta.end(), [](uint8_t i) { return i==0; }); -+ -+ // Output bytes read for the gemm problem for non-zero beta values -+ if (!is_beta_zero) { -+ bytes += int64_t(library::sizeof_bits(operation_desc.C.element) * m / 8) * n; -+ } -+ -+ bytes *= batch_count; -+ -+ return bytes; -+} -+ -+/// Total number of flops computed -+int64_t GemmOperationProfiler::GemmProblem::flops(library::GemmDescription const &operation_desc) const { -+ int64_t flops_ = (int64_t(m) * n * k + m * n) * 2 * batch_count; -+ -+ // complex-valued support -+ switch (operation_desc.tile_description.math_instruction.math_operation) { -+ case library::MathOperationID::kMultiplyAddComplex: -+ flops_ *= 4; -+ break; -+ -+ case library::MathOperationID::kMultiplyAddComplexFastF32: -+ flops_ *= 4; -+ break; -+ -+ case library::MathOperationID::kMultiplyAddGaussianComplex: -+ flops_ *= 3; -+ break; -+ -+ default: break; -+ } -+ -+ return flops_; -+} -+ -+ -+/// Initializes a performance result -+void GemmOperationProfiler::GemmProblem::initialize_result( -+ PerformanceResult &result, -+ library::GemmDescription const &operation_desc, -+ ProblemSpace const &problem_space) { -+ -+ result.arguments.resize(problem_space.rank()); -+ -+ set_argument(result, "gemm_kind", problem_space, library::to_string(operation_desc.gemm_kind)); -+ -+ set_argument(result, "split_k_mode", problem_space, library::to_string(split_k_mode)); -+ -+ set_argument(result, "A", problem_space, -+ std::string(library::to_string(operation_desc.A.element)) + ":" + library::to_string(operation_desc.A.layout)); -+ -+ set_argument(result, "B", problem_space, -+ std::string(library::to_string(operation_desc.B.element)) + ":" + library::to_string(operation_desc.B.layout)); -+ -+ set_argument(result, "C", problem_space, -+ std::string(library::to_string(operation_desc.C.element)) + ":" + library::to_string(operation_desc.C.layout)); -+ -+ set_argument(result, "m", problem_space, m); -+ set_argument(result, "n", problem_space, n); -+ set_argument(result, "k", problem_space, k); -+ -+ set_argument(result, "split_k_slices", problem_space, split_k_slices); -+ set_argument(result, "batch_count", problem_space, batch_count); -+ -+ set_argument(result, "alpha", problem_space, -+ library::lexical_cast(alpha, operation_desc.element_epilogue)); -+ -+ set_argument(result, "beta", problem_space, -+ library::lexical_cast(beta, operation_desc.element_epilogue)); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Extracts the problem dimensions -+Status GemmOperationProfiler::initialize_configuration( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ library::GemmDescription const &operation_desc = -+ static_cast(operation->description()); -+ -+ if (operation_desc.gemm_kind != library::GemmKind::kUniversal) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ Status status = problem_.parse(operation_desc, problem_space, problem); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ gemm_workspace_.configuration.mode = problem_.mode; -+ gemm_workspace_.configuration.problem_size.m() = int(problem_.m); -+ gemm_workspace_.configuration.problem_size.n() = int(problem_.n); -+ gemm_workspace_.configuration.problem_size.k() = int(problem_.k); -+ gemm_workspace_.configuration.lda = problem_.lda; -+ gemm_workspace_.configuration.ldb = problem_.ldb; -+ gemm_workspace_.configuration.ldc = problem_.ldc; -+ gemm_workspace_.configuration.ldd = problem_.ldc; -+ -+ if (problem_.mode == library::GemmUniversalMode::kBatched) { -+ gemm_workspace_.configuration.batch_count = problem_.batch_count; -+ } -+ else { -+ gemm_workspace_.configuration.batch_count = problem_.split_k_slices; -+ } -+ -+ gemm_workspace_.arguments.A = nullptr; -+ gemm_workspace_.arguments.B = nullptr; -+ gemm_workspace_.arguments.C = nullptr; -+ gemm_workspace_.arguments.D = nullptr; -+ gemm_workspace_.arguments.alpha = problem_.alpha.data(); -+ gemm_workspace_.arguments.beta = problem_.beta.data(); -+ gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ // initialize reduction operation for parallel splitKMode -+ if (problem_.split_k_mode == library::SplitKMode::kParallel) { -+ if (!initialize_reduction_configuration_(operation, problem)) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ initialize_result_(this->model_result_, options, operation_desc, problem_space); -+ -+ return operation->can_implement(&gemm_workspace_.configuration, &gemm_workspace_.arguments); -+} -+ -+/// Initializes the performance result -+void GemmOperationProfiler::initialize_result_( -+ PerformanceResult &result, -+ Options const &options, -+ library::GemmDescription const &operation_desc, -+ ProblemSpace const &problem_space) { -+ -+ result.provider = library::Provider::kCUTLASS; -+ result.disposition = Disposition::kNotRun; -+ result.status = Status::kSuccess; -+ result.operation_name = operation_desc.name; -+ -+ problem_.initialize_result(result, operation_desc, problem_space); -+ -+ OperationProfiler::initialize_result_(result, operation_desc, problem_space); -+ -+ result.bytes = problem_.bytes(operation_desc); -+ result.flops = problem_.flops(operation_desc); -+ result.runtime = 0; -+ -+} -+ -+/// Initialize redution problem dimentions and library::Operation -+bool GemmOperationProfiler::initialize_reduction_configuration_( -+ library::Operation const *operation, -+ ProblemSpace::Problem const &problem) { -+ library::GemmDescription const &gemm_desc = -+ static_cast(operation->description()); -+ -+ if (!cast_from_double(problem_.alpha_one, gemm_desc.element_epilogue, 1)) { -+ return false; -+ } -+ -+ if (!cast_from_double(problem_.beta_zero, gemm_desc.element_epilogue, 0)) { -+ return false; -+ } -+ -+ /// initialize library::ReductionConfiguration -+ gemm_workspace_.reduction_configuration.problem_size = gemm::GemmCoord(int(problem_.n), int(problem_.m), int(problem_.k)).mn(); -+ gemm_workspace_.reduction_configuration.partitions = int(problem_.split_k_slices); -+ gemm_workspace_.reduction_configuration.partition_stride = gemm::GemmCoord(int(problem_.n), int(problem_.m), int(problem_.k)).mn().product(); -+ gemm_workspace_.reduction_configuration.ldw = problem_.ldc; -+ gemm_workspace_.reduction_configuration.lds = problem_.ldc; -+ gemm_workspace_.reduction_configuration.ldd = problem_.ldc; -+ -+ // find reduction operation -+ library::ReductionFunctionalKey reduction_key( -+ library::Provider::kCUTLASS, -+ gemm_desc.tile_description.math_instruction.element_accumulator, // element workspace -+ gemm_desc.tile_description.math_instruction.element_accumulator, // element accumulator -+ gemm_desc.C.element, // element output -+ gemm_desc.element_epilogue // element coumpute -+ ); -+ -+ auto reduction_it = library::Singleton::get().operation_table.reduction_operations.find(reduction_key); -+ -+ if (reduction_it == library::Singleton::get().operation_table.reduction_operations.end()) { -+ return false; -+ } -+ -+ // initialize reduction operation required for parallel split-k operator -+ reduction_op_ = reduction_it->second; -+ -+ // reduction operation found and initialized -+ return true; -+} -+ -+/// Initializes workspace -+Status GemmOperationProfiler::initialize_workspace( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ library::Operation const* underlying_operation = operation; -+ -+ if (problem_.split_k_mode == library::SplitKMode::kParallel) { -+ if (!(underlying_operation = library::find_gemm_operation_for_parallel_reduction(operation))) { -+ return Status::kErrorNotSupported; -+ } -+ } -+ -+ library::GemmDescription const &operation_desc = -+ static_cast(operation->description()); -+ -+ // Compute the number of copies of the problem to avoid L2 camping. -+ if (!options.profiling.workspace_count) { -+ int64_t bytes = problem_.bytes(operation_desc); -+ if (bytes < 3 * int64_t(options.device.properties.l2CacheSize)) { -+ gemm_workspace_.problem_count = -+ 1 + int((3 * int64_t(options.device.properties.l2CacheSize)) / bytes); -+ } -+ else { -+ gemm_workspace_.problem_count = 1; -+ } -+ } -+ else { -+ gemm_workspace_.problem_count = options.profiling.workspace_count; -+ } -+ -+ if (options.execution_mode != ExecutionMode::kDryRun) { -+ -+ gemm_workspace_.A = device_context.allocate_tensor( -+ options, -+ "A", -+ operation_desc.A.element, -+ operation_desc.A.layout, -+ {int(problem_.m), int(problem_.k)}, -+ {int(problem_.lda)}, -+ problem_.batch_count * gemm_workspace_.problem_count -+ ); -+ -+ gemm_workspace_.B = device_context.allocate_tensor( -+ options, -+ "B", -+ operation_desc.B.element, -+ operation_desc.B.layout, -+ {int(problem_.k), int(problem_.n)}, -+ {int(problem_.ldb)}, -+ problem_.batch_count * gemm_workspace_.problem_count -+ ); -+ -+ gemm_workspace_.C = device_context.allocate_tensor( -+ options, -+ "C", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ {int(problem_.m), int(problem_.n)}, -+ {int(problem_.ldc)}, -+ problem_.batch_count * gemm_workspace_.problem_count -+ ); -+ -+ gemm_workspace_.Computed = device_context.allocate_tensor( -+ "D", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ {int(problem_.m), int(problem_.n)}, -+ {int(problem_.ldc)}, -+ problem_.batch_count * gemm_workspace_.problem_count -+ ); -+ -+ gemm_workspace_.Reference = device_context.allocate_tensor( -+ "Reference", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ {int(problem_.m), int(problem_.n)}, -+ {int(problem_.ldc)}, -+ problem_.batch_count * gemm_workspace_.problem_count -+ ); -+ -+ gemm_workspace_.Reference->copy_from_device(gemm_workspace_.C->data()); -+ -+ // NOTE: the leading non-batch strides are duplicated here for 3.0 API kernels -+ gemm_workspace_.arguments.problem_size = {int(problem_.m), int(problem_.n), int(problem_.k)}; -+ gemm_workspace_.arguments.batch_count = problem_.batch_count; -+ gemm_workspace_.arguments.lda = problem_.lda; -+ gemm_workspace_.arguments.ldb = problem_.ldb; -+ gemm_workspace_.arguments.ldc = problem_.ldc; -+ gemm_workspace_.arguments.ldd = problem_.ldc; -+ gemm_workspace_.arguments.batch_stride_A = gemm_workspace_.A->batch_stride(); -+ gemm_workspace_.arguments.batch_stride_B = gemm_workspace_.B->batch_stride(); -+ gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride(); -+ gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Computed->batch_stride(); -+ } -+ -+ // -+ // Initialize the CUTLASS operation -+ // -+ Status status = Status::kSuccess; -+ -+ if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ -+ if (options.execution_mode != ExecutionMode::kDryRun) { -+ -+ uint64_t workspace_size = underlying_operation->get_host_workspace_size(&gemm_workspace_.configuration); -+ gemm_workspace_.host_workspace.resize(workspace_size, 0); -+ -+ workspace_size = underlying_operation->get_device_workspace_size(&gemm_workspace_.configuration, -+ &gemm_workspace_.arguments); -+ gemm_workspace_.device_workspace.reset(library::NumericTypeID::kU8, workspace_size); -+ -+ status = underlying_operation->initialize( -+ &gemm_workspace_.configuration, -+ gemm_workspace_.host_workspace.data(), -+ gemm_workspace_.device_workspace.data()); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ if (problem_.split_k_mode == library::SplitKMode::kParallel) { -+ workspace_size = reduction_op_->get_host_workspace_size(&gemm_workspace_.reduction_configuration); -+ gemm_workspace_.reduction_host_workspace.resize(workspace_size, 0); -+ -+ status = reduction_op_->initialize( -+ &gemm_workspace_.reduction_configuration, -+ gemm_workspace_.reduction_host_workspace.data(), -+ nullptr); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ } -+ } -+ -+ // -+ // If CUTLASS is enabled, generate a result for it -+ // -+ results_.push_back(model_result_); -+ results_.back().provider = library::Provider::kCUTLASS; -+ results_.back().op_kind = library::OperationKind::kGemm; -+ results_.back().disposition = Disposition::kNotRun; -+ -+ for(auto provider : verification_providers_) { -+ results_.back().verification_map[provider] = Disposition::kNotRun; -+ } -+ } -+ -+ return status; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Verifies CUTLASS against references -+bool GemmOperationProfiler::verify_cutlass( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ if (!options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ return true; -+ } -+ -+ if (options.execution_mode == ExecutionMode::kDryRun) { -+ return true; -+ } -+ -+ // Initialize structure containing GEMM arguments -+ gemm_workspace_.arguments.A = gemm_workspace_.A->data(); -+ gemm_workspace_.arguments.B = gemm_workspace_.B->data(); -+ gemm_workspace_.arguments.C = gemm_workspace_.C->data(); -+ gemm_workspace_.arguments.D = gemm_workspace_.Computed->data(); -+ gemm_workspace_.arguments.alpha = problem_.alpha.data(); -+ gemm_workspace_.arguments.beta = problem_.beta.data(); -+ gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ gemm_workspace_.arguments.batch_stride_A = gemm_workspace_.A->batch_stride(); -+ gemm_workspace_.arguments.batch_stride_B = gemm_workspace_.B->batch_stride(); -+ gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride(); -+ gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Computed->batch_stride(); -+ -+ if (problem_.split_k_mode == library::SplitKMode::kParallel) { -+ gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data(); -+ gemm_workspace_.arguments.alpha = problem_.alpha_one.data(); -+ gemm_workspace_.arguments.beta = problem_.beta_zero.data(); -+ -+ gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data(); -+ gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->data(); -+ gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->data(); -+ gemm_workspace_.reduction_arguments.alpha = problem_.alpha.data(); -+ gemm_workspace_.reduction_arguments.beta = problem_.beta.data(); -+ gemm_workspace_.reduction_arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ } -+ -+ // -+ // Run the CUTLASS operation -+ // -+ -+ // initialize gemm underlying operation to handle parallel reduction -+ library::Operation const * underlying_operation = operation; -+ -+ if (problem_.split_k_mode == library::SplitKMode::kParallel) { -+ if (!(underlying_operation = library::find_gemm_operation_for_parallel_reduction(operation))) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ } -+ -+ results_.back().status = underlying_operation->run( -+ &gemm_workspace_.arguments, -+ gemm_workspace_.host_workspace.data(), -+ gemm_workspace_.device_workspace.data()); -+ -+ if (results_.back().status != Status::kSuccess) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ -+ // Run parallel reduction kernel for parallel split_k_mode -+ if (problem_.split_k_mode == library::SplitKMode::kParallel) { -+ results_.back().status = reduction_op_->run( -+ &gemm_workspace_.reduction_arguments, -+ gemm_workspace_.reduction_host_workspace.data(), -+ nullptr); -+ -+ if (results_.back().status != Status::kSuccess) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ } -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ if (result != cudaSuccess) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ -+ // CUTLASS op ran the but not yet verified against any verification provider -+ results_.back().disposition = Disposition::kNotVerified; -+ -+ // -+ // Run verification providers -+ // -+ -+ if (options.verification.enabled) { -+ -+#if CUTLASS_ENABLE_CUBLAS -+ if (options.verification.provider_enabled(library::Provider::kCUBLAS)) { -+ -+ // Guard against unsupported cases -+ auto const & gemm_desc = static_cast(operation->description()); -+ -+ if (cublas_satisfies(gemm_desc) == Status::kSuccess) { -+ -+ // call cublas verification if supported -+ verify_with_cublas_( -+ options, -+ report, -+ device_context, -+ operation, -+ problem_space, -+ problem); -+ } -+ -+ else { -+ // set verification map for cublas to not supported -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kNotSupported; -+ } -+ } -+#endif // #if CUTLASS_ENABLE_CUBLAS -+ -+ verify_with_reference_(options, report, device_context, operation, problem_space, problem); -+ -+ // Update disposition to worst case verification outcome among all -+ // verification providers which are supported -+ bool is_any_verification_run_passed = false; -+ for(auto &m : results_.back().verification_map) { -+ if(m.second == Disposition::kFailed || m.second == Disposition::kIncorrect) { -+ results_.back().disposition = m.second; -+ return true; -+ } -+ if(!is_any_verification_run_passed && m.second == Disposition::kPassed) { -+ is_any_verification_run_passed = true; -+ } -+ } -+ -+ if(is_any_verification_run_passed) { -+ results_.back().disposition = Disposition::kPassed; -+ } -+ } -+ -+ // Return true means continue profiling -+ return true; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Verifies CUTLASS against references -+bool GemmOperationProfiler::verify_with_cublas_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ -+#if CUTLASS_ENABLE_CUBLAS -+ -+ library::GemmDescription const &gemm_desc = -+ static_cast(operation->description()); -+ -+ // -+ // Construct cuBLAS operators -+ // -+ -+ CublasCreate handle; -+ cublasStatus_t status = handle.get_cublas_create_status(); -+ -+ if (status != CUBLAS_STATUS_SUCCESS) { -+ -+ results_.back().verification_map[library::Provider::kCUBLAS] = get_cutlass_disposition(status); -+ return true; -+ } -+ -+ std::vector algorithms; -+ -+ detail::select_cublas_algorithms( -+ algorithms, -+ options, -+ gemm_desc); -+ -+ if (algorithms.empty()) { -+ // no algorithm selected -+ return true; -+ } -+ -+ // -+ // Initialize state -+ // -+ -+ try { -+ -+ // -+ // Construct dispatcher to cublasGemmEx() -+ // -+ -+ // Initialize structure containing GEMM arguments -+ gemm_workspace_.arguments.A = gemm_workspace_.A->data(); -+ gemm_workspace_.arguments.batch_stride_A = gemm_workspace_.A->batch_stride(); -+ gemm_workspace_.arguments.B = gemm_workspace_.B->data(); -+ gemm_workspace_.arguments.batch_stride_B = gemm_workspace_.B->batch_stride(); -+ gemm_workspace_.arguments.C = gemm_workspace_.Reference->data(); -+ gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.Reference->batch_stride(); -+ gemm_workspace_.arguments.D = gemm_workspace_.Reference->data(); -+ gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Reference->batch_stride(); -+ gemm_workspace_.arguments.alpha = problem_.alpha.data(); -+ gemm_workspace_.arguments.beta = problem_.beta.data(); -+ gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ detail::cublasGemmExDispatcher gemm_op( -+ gemm_desc, -+ gemm_workspace_.configuration, -+ gemm_workspace_.arguments, -+ algorithms.front() -+ ); -+ -+ if (gemm_op.status != Status::kSuccess) { -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kNotRun; -+ return true; -+ } -+ -+ results_.back().status = Status::kSuccess; -+ -+ status = gemm_op(handle); -+ -+ // Handle errors -+ if (status != CUBLAS_STATUS_SUCCESS) { -+ -+ results_.back().verification_map[library::Provider::kCUBLAS] = get_cutlass_disposition(status); -+ return true; -+ } -+ -+ // -+ // Verify results -+ // -+ -+ results_.back().verification_map[library::Provider::kCUBLAS] = compare_tensors( -+ options, -+ *gemm_workspace_.Computed, -+ *gemm_workspace_.Reference, -+ gemm_workspace_.Computed->batch_stride() -+ ); -+ -+ // Save workspace if incorrect -+ if (options.verification.save_workspace == SaveWorkspace::kIncorrect && -+ results_.back().verification_map[library::Provider::kCUBLAS] == Disposition::kIncorrect) { -+ -+ save_workspace( -+ device_context, -+ options, -+ gemm_desc, -+ library::Provider::kCUTLASS, -+ library::Provider::kCUBLAS); -+ } -+ } -+ catch (...) { -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed; -+ } -+ -+#endif -+ -+ // Return true means continue profiling -+ return true; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Verifies CUTLASS against host and device references -+bool GemmOperationProfiler::verify_with_reference_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ library::GemmDescription const &gemm_desc = -+ static_cast(operation->description()); -+ -+ // -+ // Initialize state -+ // -+ -+ library::Provider references[] = { -+ library::Provider::kReferenceDevice, -+ library::Provider::kReferenceHost -+ }; -+ -+ for (auto provider : references) { -+ -+ // Skip providers that are not enabled -+ if (!options.verification.provider_enabled(provider)) { -+ continue; -+ } -+ -+ void *ptr_A = gemm_workspace_.A->data(); -+ void *ptr_B = gemm_workspace_.B->data(); -+ void *ptr_C = gemm_workspace_.C->data(); -+ void *ptr_D = gemm_workspace_.Reference->data(); -+ -+ // To support the host-side reference, conditionally allocate and -+ // copy tensors to host memory. -+ std::vector host_data_A; -+ std::vector host_data_B; -+ std::vector host_data_C; -+ std::vector host_data_D; -+ -+ if (provider == library::Provider::kReferenceHost) { -+ -+ host_data_A.resize(gemm_workspace_.A->bytes()); -+ ptr_A = host_data_A.data(); -+ gemm_workspace_.A->copy_to_host(ptr_A); -+ -+ host_data_B.resize(gemm_workspace_.B->bytes()); -+ ptr_B = host_data_B.data(); -+ gemm_workspace_.B->copy_to_host(ptr_B); -+ -+ host_data_C.resize(gemm_workspace_.C->bytes()); -+ ptr_C = host_data_C.data(); -+ gemm_workspace_.C->copy_to_host(ptr_C); -+ -+ host_data_D.resize(gemm_workspace_.Reference->bytes()); -+ ptr_D = host_data_D.data(); -+ } -+ -+ // -+ // Launch -+ // -+ -+ library::Handle handle; -+ -+ handle.set_provider(provider); -+ -+ Status status = handle.gemm_universal( -+ problem_.mode, -+ gemm_workspace_.configuration.problem_size.m(), -+ gemm_workspace_.configuration.problem_size.n(), -+ gemm_workspace_.configuration.problem_size.k(), -+ gemm_desc.tile_description.math_instruction.element_accumulator, -+ gemm_desc.element_epilogue, -+ -+ problem_.alpha.data(), -+ -+ gemm_desc.A.element, -+ gemm_desc.A.layout, -+ gemm_desc.transform_A, -+ ptr_A, -+ int(gemm_workspace_.configuration.lda), -+ -+ gemm_desc.B.element, -+ gemm_desc.B.layout, -+ gemm_desc.transform_B, -+ ptr_B, -+ int(gemm_workspace_.configuration.ldb), -+ -+ problem_.beta.data(), -+ -+ gemm_desc.C.element, -+ ptr_C, -+ int(gemm_workspace_.configuration.ldc), -+ -+ ptr_D, -+ int(gemm_workspace_.configuration.ldd), -+ -+ gemm_workspace_.configuration.batch_count, -+ gemm_workspace_.A->batch_stride(), -+ gemm_workspace_.B->batch_stride(), -+ gemm_workspace_.C->batch_stride(), -+ gemm_workspace_.Reference->batch_stride() -+ ); -+ -+ if (status != Status::kSuccess) { -+ results_.back().verification_map[provider] = Disposition::kNotRun; -+ return true; -+ } -+ -+ results_.back().status = status; -+ -+ if (provider == library::Provider::kReferenceHost) { -+ gemm_workspace_.Reference->copy_from_host(ptr_D); -+ } -+ -+ // -+ // Verify results -+ // -+ -+ results_.back().verification_map[provider] = compare_tensors( -+ options, -+ *gemm_workspace_.Computed, -+ *gemm_workspace_.Reference, -+ gemm_workspace_.Computed->batch_stride() -+ ); -+ -+ // Save workspace if incorrect -+ if (options.verification.save_workspace == SaveWorkspace::kIncorrect && -+ results_.back().verification_map[provider] == Disposition::kIncorrect) { -+ -+ save_workspace( -+ device_context, -+ options, -+ gemm_desc, -+ library::Provider::kCUTLASS, -+ provider); -+ } -+ } -+ -+ return true; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Measures performance results -+bool GemmOperationProfiler::profile( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ -+ // Initialize structure containing GEMM arguments -+ gemm_workspace_.arguments.A = gemm_workspace_.A->data(); -+ gemm_workspace_.arguments.B = gemm_workspace_.B->data(); -+ gemm_workspace_.arguments.C = gemm_workspace_.C->data(); -+ gemm_workspace_.arguments.D = gemm_workspace_.Computed->data(); -+ gemm_workspace_.arguments.alpha = problem_.alpha.data(); -+ gemm_workspace_.arguments.beta = problem_.beta.data(); -+ gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ gemm_workspace_.arguments.batch_stride_A = gemm_workspace_.A->batch_stride(); -+ gemm_workspace_.arguments.batch_stride_B = gemm_workspace_.B->batch_stride(); -+ gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride(); -+ gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Computed->batch_stride(); -+ -+ if (problem_.split_k_mode == library::SplitKMode::kParallel) { -+ gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data(); -+ gemm_workspace_.arguments.alpha = problem_.alpha_one.data(); -+ gemm_workspace_.arguments.beta = problem_.beta_zero.data(); -+ -+ gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data(); -+ gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->data(); -+ gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->data(); -+ gemm_workspace_.reduction_arguments.alpha = problem_.alpha.data(); -+ gemm_workspace_.reduction_arguments.beta = problem_.beta.data(); -+ gemm_workspace_.reduction_arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ } -+ -+ results_.back().status = profile_cutlass_( -+ results_.back().runtime, -+ options, -+ operation, -+ &gemm_workspace_.arguments, -+ gemm_workspace_.host_workspace.data(), -+ gemm_workspace_.device_workspace.data() -+ ); -+ } -+ return true; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Method to profile a CUTLASS Operation -+Status GemmOperationProfiler::profile_cutlass_( -+ double &runtime, -+ Options const &options, -+ library::Operation const *operation, -+ void *arguments, -+ void *host_workspace, -+ void *device_workspace) { -+ -+ GpuTimer timer; -+ -+ // initialize gemm underlying operation to handle parallel reduction -+ library::Operation const * underlying_operation = operation; -+ -+ if (problem_.split_k_mode == library::SplitKMode::kParallel) { -+ if (!(underlying_operation = library::find_gemm_operation_for_parallel_reduction(operation))) { -+ return Status::kErrorNotSupported; -+ } -+ } -+ -+ // -+ // Optional sleep to limit power consumption and thermals -+ // -+ -+ sleep(options.profiling.sleep_duration); -+ -+ // -+ // Warmup loop -+ // -+ -+ Status status; -+ -+ for (int iteration = 0; iteration < options.profiling.warmup_iterations; ++iteration) { -+ -+ int problem_idx = (iteration % gemm_workspace_.problem_count) * problem_.batch_count; -+ -+ gemm_workspace_.arguments.A = gemm_workspace_.A->batch_data(problem_idx); -+ gemm_workspace_.arguments.B = gemm_workspace_.B->batch_data(problem_idx); -+ gemm_workspace_.arguments.C = gemm_workspace_.C->batch_data(problem_idx); -+ gemm_workspace_.arguments.D = gemm_workspace_.Computed->batch_data(problem_idx); -+ -+ if (problem_.split_k_mode == library::SplitKMode::kParallel) { -+ gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data(); -+ -+ gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data(); -+ gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->batch_data(problem_idx); -+ gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->batch_data(problem_idx); -+ } -+ -+ // Execute the CUTLASS operation -+ status = underlying_operation->run( -+ &gemm_workspace_.arguments, -+ host_workspace, -+ device_workspace); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ // Run parallel reduction kernel for parallel split_k_mode -+ if (problem_.split_k_mode == library::SplitKMode::kParallel) { -+ status = reduction_op_->run( -+ &gemm_workspace_.reduction_arguments, -+ gemm_workspace_.reduction_host_workspace.data(), -+ nullptr); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ } -+ } -+ -+ // -+ // Initialize GPU timer -+ // -+ -+ timer.start(); -+ -+ // -+ // Profiling loop -+ // -+ -+ int Iterations = options.profiling.iterations; -+ -+ int iteration = 0; -+ for (; iteration < Iterations; ++iteration) { -+ -+ // Iterate over copies of the problem in memory -+ int workspace_idx = options.profiling.warmup_iterations + iteration; -+ int problem_idx = (workspace_idx % gemm_workspace_.problem_count) * problem_.batch_count; -+ -+ gemm_workspace_.arguments.A = gemm_workspace_.A->batch_data(problem_idx); -+ gemm_workspace_.arguments.B = gemm_workspace_.B->batch_data(problem_idx); -+ gemm_workspace_.arguments.C = gemm_workspace_.C->batch_data(problem_idx); -+ gemm_workspace_.arguments.D = gemm_workspace_.Computed->batch_data(problem_idx); -+ -+ if (problem_.split_k_mode == library::SplitKMode::kParallel) { -+ gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data(); -+ -+ gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data(); -+ gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->batch_data(problem_idx); -+ gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->batch_data(problem_idx); -+ } -+ -+ status = underlying_operation->run( -+ arguments, -+ host_workspace, -+ device_workspace); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ // Run parallel reduction kernel for parallel split_k_mode -+ if (problem_.split_k_mode == library::SplitKMode::kParallel) { -+ status = reduction_op_->run( -+ &gemm_workspace_.reduction_arguments, -+ gemm_workspace_.reduction_host_workspace.data(), -+ nullptr); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ } -+ } -+ -+ // -+ // Wait for completion -+ // -+ -+ timer.stop_and_wait(); -+ -+ // -+ // Update performance result -+ // -+ -+ runtime = timer.duration(iteration); -+ -+ return status; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/profiler/src/gemm_operation_profiler.h b/3rdparty/cutlass/tools/profiler/src/gemm_operation_profiler.h -new file mode 100644 -index 0000000..efee650 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/gemm_operation_profiler.h -@@ -0,0 +1,269 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines a math function -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+#include -+#include -+ -+// CUTLASS Library includes -+#include "cutlass/library/library.h" -+#include "cutlass/library/util.h" -+#include "cutlass/library/manifest.h" -+ -+// Profiler includes -+#include "options.h" -+#include "device_context.h" -+#include "operation_profiler.h" -+#include "performance_result.h" -+#include "problem_space.h" -+#include "reduction_operation_profiler.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Abstract base class for each math function -+class GemmOperationProfiler : public OperationProfiler { -+public: -+ -+ /// Problem structure obtained from problem space -+ struct GemmProblem { -+ -+ cutlass::library::GemmUniversalMode mode; -+ cutlass::library::SplitKMode split_k_mode; -+ int64_t m; -+ int64_t n; -+ int64_t k; -+ int64_t lda; -+ int64_t ldb; -+ int64_t ldc; -+ std::vector alpha; -+ std::vector beta; -+ int split_k_slices; -+ int batch_count; -+ -+ // gemm with parallel interleaved reduction -+ // gemm epilogue (alpha, beta) = (1.0, 0.0) -+ // reduction epilogue (alpha, beta) = (GemmProblem::alpha, GemmProblem::beta) -+ std::vector alpha_one; -+ std::vector beta_zero; -+ -+ // -+ // Methods -+ // -+ -+ GemmProblem(): -+ mode(library::GemmUniversalMode::kGemm), -+ m(16), n(16), k(16), lda(0), ldb(0), ldc(0), split_k_slices(1), batch_count(1) { } -+ -+ /// Parses the problem -+ Status parse( -+ library::GemmDescription const &operation_desc, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Total number of bytes loaded -+ int64_t bytes(library::GemmDescription const &operation_desc) const; -+ -+ /// Total number of flops computed -+ int64_t flops(library::GemmDescription const &operation_desc) const; -+ -+ /// Initializes a performance result -+ void initialize_result( -+ PerformanceResult &result, -+ library::GemmDescription const &operation_desc, -+ ProblemSpace const &problem_space); -+ }; -+ -+ /// Workspace used -+ struct GemmWorkspace { -+ -+ DeviceAllocation *A; -+ DeviceAllocation *B; -+ DeviceAllocation *C; -+ DeviceAllocation *Computed; -+ DeviceAllocation *Reference; -+ -+ /// Number of copies of the problem workspace which are visited sequentially during -+ /// profiling to avoid camping in the last level cache. -+ int problem_count; -+ -+ library::GemmUniversalConfiguration configuration; -+ library::GemmUniversalArguments arguments; -+ -+ /// Buffer used for the operation's host workspace -+ std::vector host_workspace; -+ -+ /// Buffer used for the operations' device workspace -+ DeviceAllocation device_workspace; -+ -+ /// Library configuration and arguments for reduction operator -+ library::ReductionConfiguration reduction_configuration; -+ library::ReductionArguments reduction_arguments; -+ -+ /// Buffer used for the cutlass reduction operations' host workspace -+ std::vector reduction_host_workspace; -+ -+ // -+ // Methods -+ // -+ -+ GemmWorkspace(): -+ A(nullptr), B(nullptr), C(nullptr), Computed(nullptr), Reference(nullptr), problem_count(1) { } -+ }; -+ -+protected: -+ -+ // -+ // Data members -+ // -+ -+ /// GEMM problem obtained from problem space -+ GemmProblem problem_; -+ -+ /// Device memory allocations -+ GemmWorkspace gemm_workspace_; -+ -+ /// CUTLASS parallel reduction operation to follow this* gemm operation -+ library::Operation const *reduction_op_; -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ GemmOperationProfiler(Options const &options); -+ -+ /// Destructor -+ virtual ~GemmOperationProfiler(); -+ -+ /// Prints usage statement for the math function -+ virtual void print_usage(std::ostream &out) const; -+ -+ /// Prints examples -+ virtual void print_examples(std::ostream &out) const; -+ -+ /// Extracts the problem dimensions -+ virtual Status initialize_configuration( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Initializes workspace -+ virtual Status initialize_workspace( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Verifies CUTLASS against references -+ virtual bool verify_cutlass( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Measures performance results -+ virtual bool profile( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+protected: -+ -+ /// Initializes the performance result -+ void initialize_result_( -+ PerformanceResult &result, -+ Options const &options, -+ library::GemmDescription const &operation_desc, -+ ProblemSpace const &problem_space); -+ -+ /// Verifies CUTLASS against references -+ bool verify_with_cublas_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Verifies CUTLASS against host and device references -+ bool verify_with_reference_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Method to profile a CUTLASS Operation -+ Status profile_cutlass_( -+ double &runtime, -+ Options const &options, -+ library::Operation const *operation, -+ void *arguments, -+ void *host_workspace, -+ void *device_workspace); -+ -+ /// Initialize reduction problem dimensions and library::Operation -+ bool initialize_reduction_configuration_( -+ library::Operation const *operation, -+ ProblemSpace::Problem const &problem); -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/profiler/src/gpu_timer.h b/3rdparty/cutlass/tools/profiler/src/gpu_timer.h -new file mode 100644 -index 0000000..d8bce95 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/gpu_timer.h -@@ -0,0 +1,72 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines a math function -+*/ -+ -+#pragma once -+ -+#include -+#include "cutlass/cutlass.h" -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+struct GpuTimer { -+ -+ cudaEvent_t events[2]; -+ -+ // -+ // Methods -+ // -+ -+ GpuTimer(); -+ ~GpuTimer(); -+ -+ /// Records a start event in the stream -+ void start(cudaStream_t stream = nullptr); -+ -+ /// Records a stop event in the stream -+ void stop(cudaStream_t stream = nullptr); -+ -+ /// Records a stop event in the stream and synchronizes on the stream -+ void stop_and_wait(cudaStream_t stream = nullptr); -+ -+ /// Returns the duration in miliseconds -+ double duration(int iterations = 1) const; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/profiler/src/operation_profiler.cu b/3rdparty/cutlass/tools/profiler/src/operation_profiler.cu -new file mode 100644 -index 0000000..b2e8f9b ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/operation_profiler.cu -@@ -0,0 +1,691 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines a math function -+*/ -+ -+#include -+#include -+#include -+#include -+#include -+#include -+ -+#ifdef __unix__ -+#include -+#elif defined(_WIN32) || defined(WIN32) -+#include -+#else -+// sleep not supported -+#endif -+ -+#include "options.h" -+#include "operation_profiler.h" -+#include "gpu_timer.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+OperationProfiler::OperationProfiler(): kind_(library::OperationKind::kInvalid) { } -+ -+/// Ctor -+OperationProfiler::OperationProfiler( -+ Options const &options, -+ library::OperationKind kind, -+ ArgumentDescriptionVector const &arguments, -+ ProviderVector const & verification_providers -+): -+ kind_(kind), arguments_(arguments) { -+ -+ ArgumentDescriptionVector tile_description_arguments{ -+ {ArgumentTypeID::kEnumerated, {"op_class", "opcode-class"}, "Class of math instruction (simt, tensorop, wmmatensorop, wmma)"}, -+ {ArgumentTypeID::kEnumerated, {"accum", "accumulator-type"}, "Math instruction accumulator data type"}, -+ {ArgumentTypeID::kInteger, {"cta_m", "threadblock-shape::m"}, "Threadblock shape in the M dimension"}, -+ {ArgumentTypeID::kInteger, {"cta_n", "threadblock-shape::n"}, "Threadblock shape in the N dimension"}, -+ {ArgumentTypeID::kInteger, {"cta_k", "threadblock-shape::k"}, "Threadblock shape in the K dimension"}, -+ {ArgumentTypeID::kInteger, {"cluster_m", "cluster-shape::m"}, "Cluster shape in the M dimension"}, -+ {ArgumentTypeID::kInteger, {"cluster_n", "cluster-shape::n"}, "Cluster shape in the N dimension"}, -+ {ArgumentTypeID::kInteger, {"cluster_k", "cluster-shape::k"}, "Cluster shape in the K dimension"}, -+ {ArgumentTypeID::kInteger, {"stages", "threadblock-stages"}, "Number of stages of threadblock-scoped matrix multiply"}, -+ {ArgumentTypeID::kInteger, {"warps_m", "warp-count::m"}, "Number of warps within threadblock along the M dimension"}, -+ {ArgumentTypeID::kInteger, {"warps_n", "warp-count::n"}, "Number of warps within threadblock along the N dimension"}, -+ {ArgumentTypeID::kInteger, {"warps_k", "warp-count::k"}, "Number of warps within threadblock along the K dimension"}, -+ {ArgumentTypeID::kInteger, {"inst_m", "instruction-shape::m"}, "Math instruction shape in the M dimension"}, -+ {ArgumentTypeID::kInteger, {"inst_n", "instruction-shape::n"}, "Math instruction shape in the N dimension"}, -+ {ArgumentTypeID::kInteger, {"inst_k", "instruction-shape::k"}, "Math instruction shape in the K dimension"}, -+ {ArgumentTypeID::kInteger, {"min_cc", "minimum-compute-capability"}, "Minimum device compute capability"}, -+ {ArgumentTypeID::kInteger, {"max_cc", "maximum-compute-capability"}, "Maximum device compute capability"} -+ }; -+ -+ arguments_.insert(arguments_.end(), tile_description_arguments.begin(), tile_description_arguments.end()); -+ -+ for (auto provider : verification_providers) { -+ if (std::find( -+ options.verification.providers.begin(), -+ options.verification.providers.end(), -+ provider) != options.verification.providers.end()) { -+ -+ verification_providers_.push_back(provider); -+ } -+ } -+} -+ -+/// Destructor -+OperationProfiler::~OperationProfiler() { -+ -+} -+ -+/// Gets the schema description -+std::string const & OperationProfiler::description() const { -+ return description_; -+} -+ -+/// Prints usage statement for the math function -+void OperationProfiler::print_usage(std::ostream &out) const { -+ for (auto const & desc : arguments_) { -+ -+ size_t const kAliasStart = 10; -+ -+ size_t columns = 0; -+ -+ std::string type_str = to_string(desc.type); -+ columns += type_str.size(); -+ -+ out << " [" << type_str << "]"; -+ -+ if (columns < kAliasStart) { -+ out << std::string(kAliasStart - columns, ' '); -+ } -+ -+ columns = 0; -+ -+ int j = 0; -+ for (auto const & alias : desc.aliases) { -+ columns += alias.size() + (j ? 1 : 0) + 2; -+ -+ out << (j++ ? "," : "") << "--" << alias; -+ } -+ -+ size_t const kTotalColumns = 50; -+ -+ if (columns < kTotalColumns) { -+ out << std::string(kTotalColumns - columns, ' '); -+ } -+ -+ out << desc.description << "\n"; -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Returns true if the current operation description satisfies the problem space -+bool OperationProfiler::satisfies( -+ library::OperationDescription const &op_desc, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ library::OpcodeClassID opcode_class; -+ if (arg_as_OpcodeClassID(opcode_class, "op_class", problem_space, problem)) { -+ if (opcode_class != op_desc.tile_description.math_instruction.opcode_class) { -+ return false; -+ } -+ } -+ -+ int64_t int_value; -+ -+ if (arg_as_int(int_value, "inst_m", problem_space, problem)) { -+ if (int64_t(op_desc.tile_description.math_instruction.instruction_shape.m()) != int_value) { -+ return false; -+ } -+ } -+ -+ if (arg_as_int(int_value, "inst_n", problem_space, problem)) { -+ if (int64_t(op_desc.tile_description.math_instruction.instruction_shape.n()) != int_value) { -+ return false; -+ } -+ } -+ -+ if (arg_as_int(int_value, "inst_k", problem_space, problem)) { -+ if (int64_t(op_desc.tile_description.math_instruction.instruction_shape.k()) != int_value) { -+ return false; -+ } -+ } -+ -+ if (arg_as_int(int_value, "cta_m", problem_space, problem)) { -+ if (int64_t(op_desc.tile_description.threadblock_shape.m()) != int_value) { -+ return false; -+ } -+ } -+ -+ if (arg_as_int(int_value, "cta_n", problem_space, problem)) { -+ if (int64_t(op_desc.tile_description.threadblock_shape.n()) != int_value) { -+ return false; -+ } -+ } -+ -+ if (arg_as_int(int_value, "cta_k", problem_space, problem)) { -+ if (int64_t(op_desc.tile_description.threadblock_shape.k()) != int_value) { -+ return false; -+ } -+ } -+ -+ if (arg_as_int(int_value, "cluster_m", problem_space, problem)) { -+ if (int64_t(op_desc.tile_description.cluster_shape.m()) != int_value) { -+ return false; -+ } -+ } -+ -+ if (arg_as_int(int_value, "cluster_n", problem_space, problem)) { -+ if (int64_t(op_desc.tile_description.cluster_shape.n()) != int_value) { -+ return false; -+ } -+ } -+ -+ if (arg_as_int(int_value, "cluster_k", problem_space, problem)) { -+ if (int64_t(op_desc.tile_description.cluster_shape.k()) != int_value) { -+ return false; -+ } -+ } -+ -+ if (arg_as_int(int_value, "stages", problem_space, problem)) { -+ if (int64_t(op_desc.tile_description.threadblock_stages) != int_value) { -+ return false; -+ } -+ } -+ -+ if (arg_as_int(int_value, "warps_m", problem_space, problem)) { -+ if (int64_t(op_desc.tile_description.warp_count.m()) != int_value) { -+ return false; -+ } -+ } -+ -+ if (arg_as_int(int_value, "warps_n", problem_space, problem)) { -+ if (int64_t(op_desc.tile_description.warp_count.n()) != int_value) { -+ return false; -+ } -+ } -+ -+ if (arg_as_int(int_value, "warps_k", problem_space, problem)) { -+ if (int64_t(op_desc.tile_description.warp_count.k()) != int_value) { -+ return false; -+ } -+ } -+ -+ library::NumericTypeID numeric_type; -+ if (arg_as_NumericTypeID(numeric_type, "accum", problem_space, problem)) { -+ if (numeric_type != op_desc.tile_description.math_instruction.element_accumulator) { -+ return false; -+ } -+ } -+ -+ return true; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Entry point to profile all operations in the manifest -+int OperationProfiler::profile_all( -+ Options const &options, -+ library::Manifest const &manifest, -+ DeviceContext &device_context) { -+ -+ ProblemSpace problem_space(arguments_, options.cmdline); -+ -+ // 1. Construct performance report -+ PerformanceReport report(options, problem_space.argument_names(), kind_); -+ -+ // 2. For each problem in problem space -+ ProblemSpace::Iterator problem_it = problem_space.begin(); -+ ProblemSpace::Iterator problem_end = problem_space.end(); -+ -+ bool continue_profiling = true, internal_error = false; -+ -+ // For each problem in problem space -+ for (; continue_profiling && problem_it != problem_end; ++problem_it) { -+ -+ ProblemSpace::Problem problem = problem_it.at(); -+ -+ report.next_problem(); -+ -+ // For each operation in manifest -+ for (auto const & operation_ptr : manifest) { -+ -+ library::Operation const *operation = operation_ptr.get(); -+ -+ auto min_cc = operation->description().tile_description.minimum_compute_capability; -+ auto max_cc = operation->description().tile_description.maximum_compute_capability; -+ -+ // Clear named allocations -+ device_context.free(); -+ -+ // Execute compatible cutlass operations if they satisfy the current device's compute capability -+ if (operation->description().kind == kind_ && -+ operation->description().provider == library::Provider::kCUTLASS && -+ options.device.compute_capability() >= min_cc && -+ options.device.compute_capability() <= max_cc) { -+ -+ std::string operation_name(operation->description().name); -+ -+ // Filter kernels by name -+ bool filtered_by_name = options.operation_names.empty(); -+ if (!filtered_by_name) { -+ -+ for (auto const & op_name : options.operation_names) { -+ if (find_string_matches_(op_name, operation_name)) { -+ filtered_by_name = true; -+ break; -+ } -+ } -+ } -+ -+ for (auto const & op_name : options.excluded_operation_names) { -+ if (find_string_matches_(op_name, operation_name)) { -+ filtered_by_name = false; -+ break; -+ } -+ } -+ -+ if (!filtered_by_name || !satisfies(operation->description(), problem_space, problem)) { -+ continue; -+ } -+ -+ // A. Initialize configuration -+ Status status = this->initialize_configuration( -+ options, -+ report, -+ device_context, -+ operation, -+ problem_space, -+ problem); -+ -+ if (status == Status::kErrorInternal) { -+ -+ // If there was an internal error, consume the CUDA error and move to the next operation. -+ (void)cudaGetLastError(); -+ -+ report.append_results(results_); -+ continue; -+ } -+ else if (status != Status::kSuccess) { -+ // If the workspace could not be initialized for any other reason, continue to -+ // the next operation. -+ continue; -+ } -+ -+ if (continue_profiling) { -+ -+ status = this->initialize_workspace( -+ options, -+ report, -+ device_context, -+ operation, -+ problem_space, -+ problem); -+ -+ if (status == Status::kErrorInternal) { -+ -+ // If there was an internal error, consume the CUDA error and move to the next operation. -+ (void)cudaGetLastError(); -+ -+ report.append_results(results_); -+ continue; -+ } -+ else if (status != Status::kSuccess) { -+ // If the workspace could not be initialized for any other reason, continue to -+ // the next operation. -+ continue; -+ } -+ } -+ -+ // -+ // Profile CUTLASS if it is enabled -+ // -+ -+ // B. Verify CUTLASS -+ -+ if (continue_profiling && options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ -+ continue_profiling = this->verify_cutlass( -+ options, -+ report, -+ device_context, -+ operation, -+ problem_space, -+ problem); -+ } -+ -+ if (options.execution_mode == ExecutionMode::kDryRun) { -+ report.append_results(results_); -+ results_.clear(); -+ continue; -+ } -+ -+ // -+ // C. Optionally save workspace -+ // -+ -+ if (options.verification.save_workspace == SaveWorkspace::kAlways) { -+ save_workspace( -+ device_context, -+ options, -+ operation->description(), -+ library::Provider::kCUTLASS); -+ } -+ -+ // -+ // D. Profile -+ // -+ -+ if (continue_profiling && options.profiling.enabled) { -+ -+ continue_profiling = this->profile( -+ options, -+ report, -+ device_context, -+ operation, -+ problem_space, -+ problem); -+ } -+ -+ report.append_results(results_); -+ results_.clear(); -+ } -+ -+ if (!continue_profiling) { -+ break; -+ } -+ } -+ } -+ -+ return internal_error ? 1 : 0; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Sleep for a given duration in ms -+void OperationProfiler::sleep(int sleep_duration) { -+ if (sleep_duration) { -+ #ifdef __unix__ -+ usleep(sleep_duration * 1000); -+ #elif defined(_WIN32) || defined(WIN32) -+ SleepEx(sleep_duration, false); -+ #else -+ // sleep not supported -+ #endif -+ } -+} -+ -+ -+/// Compares tensors for equality -+Disposition OperationProfiler::compare_tensors( -+ Options const &options, -+ DeviceAllocation &experimental, -+ DeviceAllocation &reference, -+ int64_t count) { -+ -+ if (experimental.type() != reference.type()) { -+ return Disposition::kIncorrect; -+ } -+ -+ bool passed = false; -+ -+ if (count == 0) { -+ count = reference.capacity(); -+ } -+ -+ if (options.verification.epsilon == 0) { -+ -+ // bit-level equality -+ passed = DeviceAllocation::block_compare_equal( -+ experimental.type(), -+ experimental.data(), -+ reference.data(), -+ count); -+ } -+ else { -+ -+ // relative error function -+ passed = DeviceAllocation::block_compare_relatively_equal( -+ experimental.type(), -+ experimental.data(), -+ reference.data(), -+ count, -+ options.verification.epsilon, -+ options.verification.nonzero_floor); -+ } -+ -+ return passed ? Disposition::kPassed : Disposition::kIncorrect; -+} -+ -+/// Saves the workspace -+void OperationProfiler::save_workspace( -+ DeviceContext &device_context, -+ Options const &options, -+ library::OperationDescription const &desc, -+ library::Provider provider, -+ library::Provider verification_provider) { -+ -+ for (auto const & named_allocation : device_context) { -+ -+ DeviceAllocation *allocation = named_allocation.second; -+ -+ std::stringstream filename; -+ -+ filename << desc.name << "_" << library::to_string(provider) << "_"; -+ -+ if (verification_provider != library::Provider::kInvalid) { -+ filename << "verified_by_" << library::to_string(verification_provider) << "_"; -+ } -+ -+ filename << named_allocation.first + ".mat"; -+ -+ std::ofstream out(filename.str()); -+ -+ allocation->write_tensor_csv(out); -+ out << "\n"; -+ -+ if (options.report.verbose) { -+ std::cout << "wrote '" << filename.str() << "'" << std::endl; -+ } -+ } -+} -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Method to profile a CUTLASS Operation -+Status OperationProfiler::profile_cutlass_( -+ double &runtime, -+ Options const &options, -+ library::Operation const *operation, -+ void *arguments, -+ void *host_workspace, -+ void *device_workspace) { -+ -+ GpuTimer timer; -+ -+ // -+ // Optional sleep to limit power consumption and thermals -+ // -+ -+ sleep(options.profiling.sleep_duration); -+ -+ // -+ // Warmup loop -+ // -+ -+ Status status; -+ -+ for (int iteration = 0; iteration < options.profiling.warmup_iterations; ++iteration) { -+ -+ status = operation->run( -+ arguments, -+ host_workspace, -+ device_workspace); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ } -+ -+ // -+ // Initialize GPU timer -+ // -+ -+ timer.start(); -+ -+ // -+ // Profiling loop -+ // -+ -+ int Iterations = options.profiling.iterations; -+ -+ int iteration = 0; -+ for (; iteration < Iterations; ++iteration) { -+ -+ status = operation->run( -+ arguments, -+ host_workspace, -+ device_workspace); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ } -+ -+ // -+ // Wait for completion -+ // -+ -+ timer.stop_and_wait(); -+ -+ // -+ // Update performance result -+ // -+ -+ runtime = timer.duration(iteration); -+ -+ return status; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Sets operation description -+void OperationProfiler::initialize_result_( -+ PerformanceResult &result, -+ library::OperationDescription const &operation_desc, -+ ProblemSpace const &problem_space) { -+ -+ set_argument(result, "op_class", problem_space, -+ library::to_string(operation_desc.tile_description.math_instruction.opcode_class)); -+ -+ set_argument(result, "accum", problem_space, -+ library::to_string(operation_desc.tile_description.math_instruction.element_accumulator)); -+ -+ set_argument(result, "cta_m", problem_space, operation_desc.tile_description.threadblock_shape.m()); -+ set_argument(result, "cta_n", problem_space, operation_desc.tile_description.threadblock_shape.n()); -+ set_argument(result, "cta_k", problem_space, operation_desc.tile_description.threadblock_shape.k()); -+ set_argument(result, "cluster_m", problem_space, operation_desc.tile_description.cluster_shape.m()); -+ set_argument(result, "cluster_n", problem_space, operation_desc.tile_description.cluster_shape.n()); -+ set_argument(result, "cluster_k", problem_space, operation_desc.tile_description.cluster_shape.k()); -+ set_argument(result, "stages", problem_space, operation_desc.tile_description.threadblock_stages); -+ set_argument(result, "warps_m", problem_space, operation_desc.tile_description.warp_count.m()); -+ set_argument(result, "warps_n", problem_space, operation_desc.tile_description.warp_count.n()); -+ set_argument(result, "warps_k", problem_space, operation_desc.tile_description.warp_count.k()); -+ set_argument(result, "inst_m", problem_space, operation_desc.tile_description.math_instruction.instruction_shape.m()); -+ set_argument(result, "inst_n", problem_space, operation_desc.tile_description.math_instruction.instruction_shape.n()); -+ set_argument(result, "inst_k", problem_space, operation_desc.tile_description.math_instruction.instruction_shape.k()); -+ set_argument(result, "min_cc", problem_space, operation_desc.tile_description.minimum_compute_capability); -+ set_argument(result, "max_cc", problem_space, operation_desc.tile_description.maximum_compute_capability); -+} -+ -+/// Helper -+void OperationProfiler::set_argument( -+ PerformanceResult &result, -+ char const *name, -+ ProblemSpace const &problem_space, -+ std::string const &value) { -+ -+ result.arguments.at(problem_space.argument_index(name)) = make_pair(std::string(name), value); -+} -+ -+void OperationProfiler::set_argument( -+ PerformanceResult &result, -+ char const *name, -+ ProblemSpace const &problem_space, -+ int64_t value) { -+ -+ result.arguments.at(problem_space.argument_index(name)) = make_pair(std::string(name), library::lexical_cast(value)); -+} -+ -+ -+/// finds string matches filter_string in operation_name -+bool OperationProfiler::find_string_matches_( -+ std::string const &filter_string, -+ std::string const &operation_name) { -+ // Returns true if all substrings appear in the operation_name in order -+ -+ // Split filter_string of the format "gemm*f32*nt" to tokens ["gemm", "f32", "nt"] -+ std::string item; -+ std::istringstream iss(filter_string); -+ std::vector filter_tokens; -+ while (std::getline(iss, item, '*')) { -+ filter_tokens.push_back(item); -+ } -+ -+ // Search filter_tokens in operation_name in order -+ size_t start = 0, idx = 0; -+ for(auto & token : filter_tokens) { -+ // Check if characters left to be parsed in operation_name -+ if (start < operation_name.length()) { -+ // Find token in operation_name[start:] -+ idx = operation_name.substr(start).find(token); -+ if (idx == std::string::npos) { -+ return false; -+ } -+ } -+ start += (idx + token.length()); -+ } -+ -+ // All tokens in filter_string found in operation_name -+ return true; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/profiler/src/operation_profiler.h b/3rdparty/cutlass/tools/profiler/src/operation_profiler.h -new file mode 100644 -index 0000000..a2b0bdd ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/operation_profiler.h -@@ -0,0 +1,256 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines a math function -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+#include -+ -+// CUTLASS Library includes -+#include "cutlass/library/library.h" -+#include "cutlass/library/util.h" -+#include "cutlass/library/manifest.h" -+ -+// Profiler includes -+#include "options.h" -+#include "device_context.h" -+#include "performance_result.h" -+#include "performance_report.h" -+#include "problem_space.h" -+#include "debug.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Abstract base class for each math function -+class OperationProfiler { -+public: -+ -+ -+protected: -+ // -+ // Data members -+ // -+ -+ /// Top-level operation kind -+ library::OperationKind kind_; -+ -+ /// Human readable description -+ std::string description_; -+ -+ /// Arguments parsed from command line -+ ArgumentDescriptionVector arguments_; -+ -+ /// List of providers used to verify and compare each result -+ ProviderVector verification_providers_; -+ -+ /// Model performance result initailized by the operation profiler with workload statistics -+ /// and reasonable default state. -+ PerformanceResult model_result_; -+ -+ /// Performance result vector constructed by profiling the operation -+ PerformanceResultVector results_; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ OperationProfiler(); -+ -+ OperationProfiler( -+ Options const &options, -+ library::OperationKind kind, -+ ArgumentDescriptionVector const &arguments = ArgumentDescriptionVector(), -+ ProviderVector const & verification_providers = ProviderVector()); -+ -+ /// Destructor -+ virtual ~OperationProfiler(); -+ -+ /// Obtains the operation kind -+ library::OperationKind kind() const { return kind_; } -+ -+ /// Gets the schema description -+ std::string const &description() const; -+ -+ /// Returns a reference to the arguments -+ ArgumentDescriptionVector const &arguments() const { return arguments_; } -+ -+public: -+ -+ // -+ // Basic overrides -+ // -+ -+ -+ /// Prints usage statement for the math function -+ virtual void print_usage(std::ostream &out) const; -+ -+ /// Prints examples -+ virtual void print_examples(std::ostream &out) const =0; -+ -+ /// Entry point to profile all operations in the manifest -+ virtual int profile_all( -+ Options const &options, -+ library::Manifest const &manifest, -+ DeviceContext &device_context); -+ -+public: -+ -+ // -+ // Operation-specific phases of verification and profiling -+ // -+ -+ /// Extracts the problem dimensions -+ virtual Status initialize_configuration( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) = 0; -+ -+ /// Initializes workspace -+ virtual Status initialize_workspace( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) = 0; -+ -+ /// Verifies CUTLASS against references -+ virtual bool verify_cutlass( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) = 0; -+ -+ /// Measures performance results -+ virtual bool profile( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) = 0; -+ -+public: -+ -+ // -+ // Static helpers -+ // -+ -+ /// Sleep for a given duration in ms -+ static void sleep(int sleep_duration); -+ -+ /// Returns true if the current operation description satisfies the problem space -+ static bool satisfies( -+ library::OperationDescription const &op_desc, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Compares tensors for equality -+ static Disposition compare_tensors( -+ Options const &options, -+ DeviceAllocation &experimental, -+ DeviceAllocation &reference, -+ int64_t count = 0); -+ -+ static void save_workspace( -+ DeviceContext &device_context, -+ Options const &options, -+ library::OperationDescription const &desc, -+ library::Provider provider, -+ library::Provider verification_provider = library::Provider::kInvalid); -+ -+ /// Helper to set a performance result member -+ static void set_argument( -+ PerformanceResult &result, -+ char const *name, -+ ProblemSpace const &problem_space, -+ std::string const &value); -+ -+ /// Helper to set a performance result member -+ static void set_argument( -+ PerformanceResult &result, -+ char const *name, -+ ProblemSpace const &problem_space, -+ int64_t value); -+ -+protected: -+ -+ /// Sets operation description -+ static void initialize_result_( -+ PerformanceResult &result, -+ library::OperationDescription const &operation_desc, -+ ProblemSpace const &problem_space); -+ -+ /// Method to profile an initialized CUTLASS operation -+ virtual Status profile_cutlass_( -+ double &runtime, -+ Options const &options, -+ library::Operation const *operation, -+ void *arguments, -+ void *host_workspace, -+ void *device_workspace); -+ -+private: -+ /// finds string matches filter_string in operation_name -+ bool find_string_matches_( -+ std::string const &filter_string, -+ std::string const &operation_name); -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Vector of owning operation profilers -+using OperationProfilerVector = std::vector>; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/profiler/src/options.cu b/3rdparty/cutlass/tools/profiler/src/options.cu -new file mode 100644 -index 0000000..ea79a9d ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/options.cu -@@ -0,0 +1,815 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Command line options for performance test program -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/version.h" -+ -+#include "cutlass/library/util.h" -+ -+#include "options.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Newline and indent for help strings -+static char const *end_of_line = "\n "; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+Options::Device::Device(cutlass::CommandLine const &cmdline) { -+ -+ cmdline.get_cmd_line_argument("device", device, 0); -+ -+ cudaError_t result; -+ result = cudaGetDeviceProperties(&properties, device); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties() failed for given device"); -+ } -+ -+ result = cudaSetDevice(device); -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaSetDevice() failed for given device."); -+ } -+ -+ // Permit overriding the compute capability -+ if (cmdline.check_cmd_line_flag("compute-capability")) { -+ int cc = compute_capability(); -+ cmdline.get_cmd_line_argument("compute-capability", cc, cc); -+ properties.major = cc / 10; -+ properties.minor = cc % 10; -+ } -+ -+ // Permit overriding the L2 cache capacity -+ if (cmdline.check_cmd_line_flag("llc-capacity")) { -+ int llc_capacity = 0; -+ cmdline.get_cmd_line_argument("llc-capacity", llc_capacity, 0); -+ -+ if (llc_capacity >= 0) { -+ properties.l2CacheSize = (llc_capacity << 10); -+ } -+ } -+ -+} -+ -+void Options::Device::print_usage(std::ostream &out) const { -+ -+ out << "Device:\n" -+ << " --device= " -+ << " CUDA Device ID\n\n"; -+ -+ int device_count = 0; -+ cudaError_t result = cudaGetDeviceCount(&device_count); -+ -+ if (result != cudaSuccess) { -+ out << " \n"; -+ } -+ else { -+ -+ for (int idx = 0; idx < device_count; ++idx) { -+ cudaDeviceProp prop; -+ result = cudaGetDeviceProperties(&prop, idx); -+ if (result != cudaSuccess) { -+ out << " " << std::endl; -+ break; -+ } -+ else { -+ out << " [" << idx << "] - " -+ << prop.name << " - SM " << prop.major << "." << prop.minor << ", " -+ << prop.multiProcessorCount << " SMs @ " << (prop.clockRate / 1000.0) << " MHz, " -+ << "L2 cache: " << (prop.l2CacheSize >> 20) << " MB, Global Memory: " << (prop.totalGlobalMem >> 30) << " GB" -+ << std::endl; -+ } -+ } -+ out << "\n"; -+ } -+ -+ out -+ << " --compute-capability= " -+ << " Override the compute capability.\n\n" -+ -+ << " --llc-capacity= " -+ << " Capacity of last-level cache in kilobytes. If this is non-zero," << end_of_line -+ << " profiling phases cycle through different input tensors to induce" << end_of_line -+ << " capacity misses in the L2.\n\n"; -+ -+} -+ -+void Options::Device::print_device_info(std::ostream &out) const { -+ int num_devices; -+ cudaDeviceProp props; -+ -+ cudaError_t result; -+ result = cudaGetDeviceCount(&num_devices); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetNumDevices() failed"); -+ } -+ -+ out << "Device Name,SM,CUDA Device ID,Phy Device ID" << std::endl; -+ -+ for(int device = 0; device < num_devices; device++) { -+ result = cudaSetDevice(device); -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaSetDevice() failed for device"); -+ } -+ -+ result = cudaGetDeviceProperties(&props, device); -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProperties failed for device"); -+ } -+ -+ out << props.name << "," << props.major << props.minor << "," -+ << device << "," << props.multiGpuBoardGroupID << std::endl; -+ -+ } -+} -+ -+void Options::Device::print_options(std::ostream &out, int indent) const { -+ -+ out -+ << indent_str(indent) << "device: " << device << "\n" -+ << indent_str(indent) << "clock: " << int(double(properties.clockRate) / 1000.0) << "\n" -+ << indent_str(indent) << "compute-capability: " << compute_capability() << "\n"; -+} -+ -+/// Returns the compute capability of the listed device (e.g. 61, 60, 70, 75) -+int Options::Device::compute_capability() const { -+ return properties.major * 10 + properties.minor; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+Options::Initialization::Initialization(cutlass::CommandLine const &cmdline) { -+ -+ cmdline.get_cmd_line_argument("initialization-enabled", enabled, true); -+ -+ if (cmdline.check_cmd_line_flag("initialization-provider")) { -+ std::string str; -+ cmdline.get_cmd_line_argument("initialization-provider", str); -+ provider = library::from_string(str); -+ if (provider == library::Provider::kInvalid) { -+ enabled = false; -+ } -+ else if (provider != library::Provider::kReferenceHost && provider != library::Provider::kReferenceDevice) { -+ throw std::runtime_error("Unsupported intialization provider specified."); -+ } -+ } -+ else { -+ provider = library::Provider::kReferenceDevice; -+ } -+ -+ cmdline.get_cmd_line_argument("seed", seed, 2019); -+ -+ if (cmdline.check_cmd_line_flag("dist")) { -+ // user has set the data distribution (fix data distribution once set) -+ fix_data_distribution = true; -+ // set user provided data distribution -+ get_distribution(cmdline, "dist", data_distribution); -+ } -+ else { -+ // profiler choosen data distribution (allowed to change based on numeric types) -+ fix_data_distribution = false; -+ // set uniform data distribution with range [-4, 4] -+ data_distribution.set_uniform(-4, 4, 0); -+ } -+ -+ -+} -+ -+/// Gets the initial distribution -+void Options::Initialization::get_distribution( -+ cutlass::CommandLine const &args, -+ std::string const &arg, -+ cutlass::Distribution &dist) { -+ -+ struct { -+ const char *label; -+ cutlass::Distribution::Kind kind; -+ } distribution_kinds[] = { -+ {"uniform", cutlass::Distribution::Uniform}, -+ {"gaussian", cutlass::Distribution::Gaussian}, -+ {"identity", cutlass::Distribution::Identity}, -+ {"sequential", cutlass::Distribution::Sequential}, -+ {0, cutlass::Distribution::Invalid} -+ }; -+ -+ struct { -+ char const *label; -+ double *member; -+ } members[] = { -+ {"min", &dist.uniform.min}, -+ {"max", &dist.uniform.max}, -+ {"mean", &dist.gaussian.mean}, -+ {"stddev", &dist.gaussian.stddev}, -+ {"start", &dist.sequential.start}, -+ {"delta", &dist.sequential.delta}, -+ {0, 0} -+ }; -+ -+ using KeyValueVector = std::vector >; -+ -+ KeyValueVector values; -+ args.get_cmd_line_argument_pairs(arg.c_str(), values); -+ -+ // The parser expects the first token to be a string identifying the distribution type. -+ auto it = values.begin(); -+ if (it != values.end()) { -+ for (int i = 0; distribution_kinds[i].label; ++i) { -+ if (it->first == distribution_kinds[i].label) { -+ dist.kind = distribution_kinds[i].kind; -+ break; -+ } -+ } -+ ++it; -+ } -+ -+ // Subsequent key-value pairs update the named field of the distribution struct. -+ for (; it != values.end(); ++it) { -+ // Integer scaling factor - if < 0, no integer rounding is performed. -+ if ((it->first.compare("scale") == 0) && !it->second.empty()) { -+ std::stringstream ss; -+ ss << it->second; -+ ss >> dist.int_scale; -+ continue; // next token -+ } -+ -+ // Casts as integer without scaling -+ if (it->first.compare("integer") == 0) { -+ dist.int_scale = 0; -+ continue; // next token -+ } -+ -+ // initialize other members -+ for (int m = 0; members[m].label; ++m) { -+ if (it->first == members[m].label && !it->second.empty()) { -+ std::stringstream ss; -+ ss << it->second; -+ ss >> *(members[m].member); -+ } -+ } -+ } -+} -+ -+void Options::Initialization::print_usage(std::ostream &out) const { -+ -+ out << "Initialization:\n" -+ -+ << " --initialization= " -+ << " Enables initialization (default: true). If false, device memory is" << end_of_line -+ << " not initialized after allocation.\n\n" -+ -+ << " --initialization-provider= " -+ << " Selects initialization provider {host, device*}. (default: '*')\n\n" -+ -+ << " --dist= " -+ << " Data distribution of input tensors {uniform*, gaussian, identity, sequential}" << end_of_line -+ << " --dist=uniform,min:,max:,scale:" << end_of_line -+ << " --dist=gaussian,mean:,stddev:,scale:" << end_of_line -+ << " --dist=sequential,start:,delta:,scale:" << end_of_line -+ << " --dist=identity\n\n" -+ -+ << " --seed= " -+ << " Random number generator seed. Used to enforce deterministic" << end_of_line -+ << " initialization.\n\n"; -+ -+} -+ -+void Options::Initialization::print_options(std::ostream &out, int indent) const { -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+Options::Library::Library(cutlass::CommandLine const &cmdline) { -+ -+ algorithm_mode = AlgorithmMode::kDefault; -+ -+ if (cmdline.check_cmd_line_flag("library-algo-mode")) { -+ std::string mode = "default"; -+ cmdline.get_cmd_line_argument("library-algo-mode", mode); -+ algorithm_mode = from_string(mode); -+ } -+ -+ if (cmdline.check_cmd_line_flag("library-algos")) { -+ -+ // If algorithms are specified, override as kBest. -+ algorithm_mode = AlgorithmMode::kBest; -+ -+ std::vector tokens; -+ cmdline.get_cmd_line_arguments("library-algos", tokens); -+ -+ algorithms.reserve(tokens.size()); -+ -+ for (auto const & token : tokens) { -+ if (token.find(":")) { -+ // todo - tokenized range -+ } -+ else { -+ int algo; -+ std::stringstream ss; -+ -+ ss << token; -+ ss >> algo; -+ -+ algorithms.push_back(algo); -+ } -+ } -+ } -+} -+ -+void Options::Library::print_usage(std::ostream &out) const { -+ -+ out << "Library:\n" -+ -+ << " --library-algo-mode= " -+ << " Indicates algorithm mode used to call libraries such as cuBLAS and cuDNN.\n" -+ << " " -+ << " mode={default*,matching,best}\n\n" -+ -+ << " --library-algos= " -+ << " If --algorithm-mode=best, permits specifying a selection of algorithms.\n\n"; -+ -+} -+ -+void Options::Library::print_options(std::ostream &out, int indent) const { -+ -+ out -+ << indent_str(indent) << "library-algo-mode: " << to_string(algorithm_mode) << "\n" -+ << indent_str(indent) << "library-algos: "; -+ -+ int j = 0; -+ for (int x : algorithms) { -+ out << (j++ ? "," : "") << x; -+ } -+ -+ out << "\n\n"; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+Options::Profiling::Profiling(cutlass::CommandLine const &cmdline) { -+ -+ cmdline.get_cmd_line_argument("workspace-count", workspace_count, 0); -+ cmdline.get_cmd_line_argument("warmup-iterations", warmup_iterations, 10); -+ cmdline.get_cmd_line_argument("profiling-iterations", iterations, 100); -+ cmdline.get_cmd_line_argument("sleep-duration", sleep_duration, 50); -+ cmdline.get_cmd_line_argument("profiling-enabled", enabled, true); -+ -+ if (cmdline.check_cmd_line_flag("providers")) { -+ -+ std::vector tokens; -+ cmdline.get_cmd_line_arguments("providers", tokens); -+ -+ providers.clear(); -+ -+ for (auto const &token : tokens) { -+ providers.push_back(library::from_string(token)); -+ } -+ } -+ else { -+ providers.push_back(library::Provider::kCUTLASS); -+ providers.push_back(library::Provider::kCUBLAS); -+ providers.push_back(library::Provider::kCUDNN); -+ } -+} -+ -+void Options::Profiling::print_usage(std::ostream &out) const { -+ -+ out << "Profiling:\n" -+ -+ << " --workspace-count= " -+ << " Number of discrete workspaces maintained to avoid cache-resident " << end_of_line -+ << " If zero (default), the amount is chosen for each workload based on " << end_of_line -+ << " capacity of the last-level cache.\n\n" -+ -+ << " --profiling-iterations= " -+ << " Number of iterations to profile each kernel. If zero, kernels" << end_of_line -+ << " are launched up to the profiling duration.\n\n" -+ -+ << " --warmup-iterations= " -+ << " Number of iterations to execute each kernel prior to profiling.\n\n" -+ -+ << " --sleep-duration= " -+ << " Number of ms to sleep between profiling periods (ms).\n\n" -+ -+ << " --profiling-enabled= " -+ << " If true, profiling is actually conducted.\n\n" -+ -+ ; -+} -+ -+void Options::Profiling::print_options(std::ostream &out, int indent) const { -+ -+ out -+ << indent_str(indent) << "profiling_iterations: " << iterations << "\n" -+ << indent_str(indent) << "sleep_duration: " << sleep_duration << "\n" -+ << indent_str(indent) << "profiling_enabled: " << enabled << "\n" -+ << indent_str(indent) << "providers: ["; -+ -+ int j = 0; -+ for (auto const & provider : providers) { -+ out << (j++ ? ", " : "") << library::to_string(provider); -+ } -+ out << "]\n"; -+} -+ -+/// Returns true if a provider is enabled -+bool Options::Profiling::provider_enabled(library::Provider provider) const { -+ return std::find(providers.begin(), providers.end(), provider) != providers.end(); -+} -+ -+/// Returns the index of a provider if its enabled -+size_t Options::Profiling::index(library::Provider provider) const { -+ size_t idx = 0; -+ for (auto const & x : providers) { -+ if (x == provider) { -+ return idx; -+ } -+ ++idx; -+ } -+ return idx; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+Options::Verification::Verification(cutlass::CommandLine const &cmdline) { -+ -+ cmdline.get_cmd_line_argument("verification-enabled", enabled, true); -+ -+ cmdline.get_cmd_line_argument("epsilon", epsilon, 0.05); -+ -+ cmdline.get_cmd_line_argument("nonzero-floor", nonzero_floor, 1.0 / 256.0); -+ -+ if (cmdline.check_cmd_line_flag("save-workspace")) { -+ std::string value; -+ cmdline.get_cmd_line_argument("save-workspace", value); -+ save_workspace = from_string(value); -+ } -+ else { -+ save_workspace = SaveWorkspace::kNever; -+ } -+ -+ if (cmdline.check_cmd_line_flag("verification-providers")) { -+ -+ std::vector tokens; -+ cmdline.get_cmd_line_arguments("verification-providers", tokens); -+ -+ providers.clear(); -+ -+ for (auto const &token : tokens) { -+ library::Provider provider = library::from_string(token); -+ if (provider != library::Provider::kInvalid) { -+ providers.push_back(provider); -+ } -+ } -+ } -+ else { -+ providers.push_back(library::Provider::kCUBLAS); -+ providers.push_back(library::Provider::kReferenceDevice); -+ providers.push_back(library::Provider::kCUDNN); -+ } -+} -+ -+void Options::Verification::print_usage(std::ostream &out) const { -+ -+ out << "Verification:\n" -+ -+ << " --verification-enabled= " -+ << " Whether to perform verification checks.\n\n" -+ -+ << " --epsilon= " -+ << " Error threshold. Setting to zero (default) requires" << end_of_line -+ << " bit-level equivalence.\n\n" -+ -+ << " --nonzero-floor= " -+ << " Results whose absolute value is less than this quantity" << end_of_line -+ << " are treated as zero for comparisons.\n\n" -+ -+ << " --save-workspace= " -+ << " Specifies when to save the GEMM inputs and results to the filesystem." << end_of_line -+ << " --save-workspace=never never save workspace (default)" << end_of_line -+ << " --save-workspace=incorrect save workspace for incorrect results" << end_of_line -+ << " --save-workspace=always always save workspace\n\n" -+ -+ << " --verification-providers= " -+ << " List of providers used to verify result. (default: '*')" << end_of_line -+ << " Gemm verification-providers {cublas*}" << end_of_line -+ << " Conv2d verification-providers {cudnn*, device*, host}" -+ << "\n\n"; -+} -+ -+void Options::Verification::print_options(std::ostream &out, int indent) const { -+ -+ out -+ << indent_str(indent) << "verification_enabled: " << enabled << "\n" -+ << indent_str(indent) << "epsilon: " << epsilon << "\n" -+ << indent_str(indent) << "save_workspace: " << to_string(save_workspace) << "\n" -+ << indent_str(indent) << "verification_providers: ["; -+ -+ int j = 0; -+ for (auto const & provider : providers) { -+ out << (j++ ? ", " : "") << library::to_string(provider); -+ } -+ out << "]\n"; -+} -+ -+/// Returns true if a provider is enabled -+bool Options::Verification::provider_enabled(library::Provider provider) const { -+ return std::find(providers.begin(), providers.end(), provider) != providers.end(); -+} -+ -+/// Returns the index of a provider if its enabled -+size_t Options::Verification::index(library::Provider provider) const { -+ size_t idx = 0; -+ for (auto const & x : providers) { -+ if (x == provider) { -+ return idx; -+ } -+ ++idx; -+ } -+ return idx; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+Options::Report::Report(cutlass::CommandLine const &cmdline) { -+ -+ cmdline.get_cmd_line_argument("append", append, false); -+ cmdline.get_cmd_line_argument("output", output_path); -+ cmdline.get_cmd_line_argument("junit-output", junit_output_path); -+ -+ if (cmdline.check_cmd_line_flag("tags")) { -+ cmdline.get_cmd_line_argument_pairs("tags", pivot_tags); -+ } -+ -+ cmdline.get_cmd_line_argument("report-not-run", report_not_run, false); -+ -+ cmdline.get_cmd_line_argument("verbose", verbose, true); -+ -+ cmdline.get_cmd_line_argument("sort-results", sort_results, false); -+} -+ -+void Options::Report::print_usage(std::ostream &out) const { -+ -+ out << "Report:\n" -+ -+ << " --append= " -+ << " If true, result is appended to possibly existing file. Otherwise, " << end_of_line -+ << " any existing file is overwritten.\n\n" -+ -+ << " --output= " -+ << " Path to output file for machine readable results. Operation kind and '.csv' is appended.\n\n" -+ -+ << " --junit-output= " -+ << " Path to junit output file for result reporting. Operation kind and '.junit.xml' is appended.\n\n" -+ -+ << " --report-not-run= " -+ << " If true, reports the status of all kernels including those that" << end_of_line -+ << " do not satisfy the given arguments.\n\n" -+ -+ << " --tags= " -+ << " Inserts leading columns in output table and uniform values for each" << end_of_line -+ << " column. Useful for generating pivot tables.\n\n" -+ -+ << " --verbose= " -+ << " Prints human-readable text to stdout. If false, nothing is written to stdout.\n\n" -+ -+ << " --sort-results= " -+ << " Sorts results (by flops-per-byte).\n\n"; -+} -+ -+void Options::Report::print_options(std::ostream &out, int indent) const { -+ -+ out -+ << indent_str(indent) << "append: " << append << "\n" -+ << indent_str(indent) << "output: " << output_path << "\n" -+ << indent_str(indent) << "junit-output: " << junit_output_path << "\n" -+ << indent_str(indent) << "report_not_run: " << report_not_run << "\n" -+ << indent_str(indent) << "tags:\n"; -+ -+ for (auto const & tag : pivot_tags) { -+ out << indent_str(indent + 1) << tag.first << ": " << tag.second << "\n"; -+ } -+ -+ out -+ << indent_str(indent) << "verbose: " << verbose << "\n"; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+Options::About::About(cutlass::CommandLine const &cmdline) { -+ help = cmdline.check_cmd_line_flag("help"); -+ version = cmdline.check_cmd_line_flag("version"); -+ device_info = cmdline.check_cmd_line_flag("device-info"); -+} -+ -+void Options::About::print_usage(std::ostream &out) const { -+ -+ out << "About:\n" -+ << " --version "; -+ -+ print_version(out); -+ -+ out << "\n"; -+} -+ -+void Options::About::print_version(std::ostream &out) { -+ out << "CUTLASS " << cutlass::getVersionString() -+ << " built on " << __DATE__ << " at " << __TIME__; -+ if (!cutlass::getGitRevision().empty()) out << " with commit " << cutlass::getGitRevision() << ""; -+} -+ -+void Options::About::print_options(std::ostream &out, int indent) const { -+ -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+Options::Options(cutlass::CommandLine const &cmdline): -+ cmdline(cmdline), -+ device(cmdline), -+ initialization(cmdline), -+ library(cmdline), -+ profiling(cmdline), -+ verification(cmdline), -+ report(cmdline), -+ about(cmdline) { -+ -+ if (cmdline.check_cmd_line_flag("mode")) { -+ std::string token; -+ cmdline.get_cmd_line_argument("mode", token); -+ execution_mode = from_string(token); -+ } -+ else { -+ execution_mode = ExecutionMode::kProfile; -+ } -+ -+ // Enumerating kernels is equivalent to a dry run. -+ if (execution_mode == ExecutionMode::kEnumerate) { -+ execution_mode = ExecutionMode::kDryRun; -+ } -+ -+ if (cmdline.check_cmd_line_flag("operation")) { -+ std::string str; -+ cmdline.get_cmd_line_argument("operation", str); -+ operation_kind = library::from_string(str); -+ } -+ else if (cmdline.check_cmd_line_flag("function")) { -+ std::string str; -+ cmdline.get_cmd_line_argument("function", str); -+ operation_kind = library::from_string(str); -+ } -+ else { -+ operation_kind = library::OperationKind::kInvalid; -+ } -+ -+ if (cmdline.check_cmd_line_flag("operation_names")) { -+ cmdline.get_cmd_line_arguments("operation_names", operation_names); -+ } -+ else if (cmdline.check_cmd_line_flag("kernels")) { -+ cmdline.get_cmd_line_arguments("kernels", operation_names); -+ } -+ -+ if (cmdline.check_cmd_line_flag("ignore-kernels")) { -+ cmdline.get_cmd_line_arguments("ignore-kernels", excluded_operation_names); -+ } -+ -+ // Prevent launches on the device for anything other than CUTLASS operation -+ if (execution_mode == ExecutionMode::kTrace) { -+ initialization.provider = library::Provider::kReferenceHost; -+ verification.enabled = false; -+ profiling.enabled = false; -+ } -+} -+ -+void Options::print_usage(std::ostream &out) const { -+ -+ out -+ << "CUTLASS Profiler\n" -+ << "usage:\n\n" -+ << " cutlass_profiler [options]\n\n" -+ << " --help\n\n" -+ -+ << " --mode= " -+ << " Cutlass profiler execution mode." << end_of_line -+ << " --mode=profile regular verification and profiling (default)" << end_of_line -+ << " --mode=dry_run no kernels are launched or workspaces allocated" << end_of_line -+ << " --mode=enumerate lists all operation kind and operations" << end_of_line -+ << " --mode=trace executes a single device-side computation with" << end_of_line -+ << " no other kernel launches\n\n" -+ -+ << " --device-info " -+ << " Prints information on all GPUs present in the system\n\n" -+ -+ << " --operation= " -+ << " CUTLASS operation to profile.\n\n" -+ -+ << " --kernels= " -+ << " Filter operations by kernel names. For example, call all kernels with" << end_of_line -+ << " (\"s1688\" and \"nt\") or (\"s844\" and \"tn\" and \"align8\") in their" << end_of_line -+ << " operation name using --kernels=\"s1688*nt, s884*tn*align8\"\n\n" -+ -+ << " --ignore-kernels= " -+ << " Excludes kernels whose names match anything in this list.\n\n" -+ ; -+ -+ // -+ // Detailed options -+ // -+ -+ device.print_usage(out); -+ out << "\n"; -+ -+ initialization.print_usage(out); -+ out << "\n"; -+ -+ library.print_usage(out); -+ out << "\n"; -+ -+ profiling.print_usage(out); -+ out << "\n"; -+ -+ verification.print_usage(out); -+ out << "\n"; -+ -+ report.print_usage(out); -+ out << "\n"; -+ -+ about.print_usage(out); -+ out << "\n"; -+} -+ -+void Options::print_options(std::ostream &out) const { -+ -+ out -+ << "options:\n" -+ << " help: " << about.help << "\n" -+ << " mode: " << to_string(execution_mode) << "\n"; -+ -+ out -+ << " device:\n"; -+ device.print_options(out, 2); -+ -+ out -+ << " initialization:\n"; -+ initialization.print_options(out, 2); -+ -+ out -+ << " profiling:\n"; -+ profiling.print_options(out, 2); -+ -+ out -+ << " verification:\n"; -+ verification.print_options(out, 2); -+ -+ out -+ << " report:\n"; -+ report.print_options(out, 2); -+} -+ -+std::string Options::indent_str(int indent) { -+ return std::string(indent * 2, ' '); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/profiler/src/options.h b/3rdparty/cutlass/tools/profiler/src/options.h -new file mode 100644 -index 0000000..02edd9a ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/options.h -@@ -0,0 +1,323 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Command line options for performance test program -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include -+ -+#include "cutlass/util/command_line.h" -+#include "cutlass/util/distribution.h" -+#include "cutlass/library/library.h" -+ -+#include "enumerated_types.h" -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Global options -+class Options { -+public: -+ -+ /// Cublas and cuDNN options -+ struct Library { -+ -+ // -+ // Data members -+ // -+ -+ /// Algorithm mode -+ AlgorithmMode algorithm_mode; -+ -+ /// Algorithm enumerants -+ std::vector algorithms; -+ -+ // -+ // Methods -+ // -+ -+ Library(CommandLine const &cmdline); -+ -+ void print_usage(std::ostream &out) const; -+ void print_options(std::ostream &out, int indent = 0) const; -+ }; -+ -+ /// Options related to the selected device -+ struct Device { -+ -+ /// Device ID -+ int device; -+ -+ /// CUDA Device properties -+ cudaDeviceProp properties; -+ -+ /// Total memory allocation on device -+ size_t maximum_capacity; -+ -+ // -+ // Methods -+ // -+ -+ Device(CommandLine const &cmdline); -+ -+ void print_usage(std::ostream &out) const; -+ void print_options(std::ostream &out, int indent = 0) const; -+ void print_device_info(std::ostream &out) const; -+ -+ /// Returns the compute capability of the listed device (e.g. 61, 60, 70, 75) -+ int compute_capability() const; -+ }; -+ -+ /// Options related to initializing input tensors -+ struct Initialization { -+ -+ /// If true, data is initialized randomly. If false, no initialization is performed after -+ /// allocating tensors. -+ bool enabled; -+ -+ /// If true, data distribution is set by the user and is not allowed to change -+ /// If false, data distribution is allowed to change based on element_type (library::NumericTypeID) -+ bool fix_data_distribution; -+ -+ /// Data distribution for input tensors -+ Distribution data_distribution; -+ -+ /// Source of random tensor elements -+ library::Provider provider; -+ -+ /// Random number generator seed. -+ int seed; -+ -+ // -+ // Methods -+ // -+ -+ Initialization(CommandLine const &cmdline); -+ -+ void print_usage(std::ostream &out) const; -+ void print_options(std::ostream &out, int indent = 0) const; -+ -+ /// Helper to parse a Distribution object from the command line parser -+ static void get_distribution( -+ cutlass::CommandLine const &args, -+ std::string const &arg, -+ cutlass::Distribution &dist); -+ }; -+ -+ /// Options related to verification of the result -+ struct Verification { -+ -+ // -+ // Data members -+ // -+ -+ /// If true, kernels are verified before they are profiled -+ bool enabled; -+ -+ /// Relative error threshold - zero to require bit-level consistency -+ double epsilon; -+ -+ /// Values smaller than this are assumed to be zero -+ double nonzero_floor; -+ -+ /// List of providers used to verify each result -+ ProviderVector providers; -+ -+ /// Indicates when to save the workspace -+ SaveWorkspace save_workspace; -+ -+ // -+ // Methods -+ // -+ -+ Verification(CommandLine const &cmdline); -+ -+ void print_usage(std::ostream &out) const; -+ void print_options(std::ostream &out, int indent = 0) const; -+ -+ /// Returns true if a provider is enabled -+ bool provider_enabled(library::Provider provider) const; -+ -+ /// Returns the index of a provider if its enabled -+ size_t index(library::Provider provider) const; -+ }; -+ -+ /// Options related to profiling -+ struct Profiling { -+ -+ /// Number of workspaces to rotate through to avoid cache-resident working sets -+ int workspace_count; -+ -+ /// Number of iterations to warmup each kernel prior to profiling -+ int warmup_iterations; -+ -+ /// Number of iterations to profile each kernel - if 0, kernels are launched up to the profiling duration -+ int iterations; -+ -+ /// Number of ms to sleep between profiling periods (ms) -+ int sleep_duration; -+ -+ /// If true, profiling is actually conducted. -+ bool enabled; -+ -+ /// List of providers of each functionality to be profiled -+ ProviderVector providers; -+ -+ // -+ // Methods -+ // -+ -+ Profiling(CommandLine const &cmdline); -+ -+ void print_usage(std::ostream &out) const; -+ void print_options(std::ostream &out, int indent = 0) const; -+ -+ /// Returns true if a provider is enabled -+ bool provider_enabled(library::Provider provider) const; -+ -+ /// Returns the index of a provider if its enabled -+ size_t index(library::Provider provider) const; -+ }; -+ -+ /// Options related to reporting -+ struct Report { -+ -+ /// If true, result is appended to possibly existing file -+ bool append; -+ -+ /// Path to a file containing results -+ std::string output_path; -+ -+ /// Path to a file containing junit xml results -+ std::string junit_output_path; -+ -+ /// Sequence of tags to attach to each result -+ std::vector> pivot_tags; -+ -+ /// If true, reports status of all kernels including those that were -+ /// not run for the given argumetns -+ bool report_not_run; -+ -+ /// Prints human-readable text to stdout. If false, nothing is written to stdout -+ bool verbose; -+ -+ /// Sort results by (currently by flops-per-byte) -+ bool sort_results; -+ -+ // -+ // Methods -+ // -+ -+ Report(CommandLine const &cmdline); -+ -+ void print_usage(std::ostream &out) const; -+ void print_options(std::ostream &out, int indent = 0) const; -+ }; -+ -+ /// Options related to printing usage and version information -+ struct About { -+ -+ /// If true, usage is printed and the program ends. -+ bool help; -+ -+ /// Prints version string -+ bool version; -+ -+ /// Print information about devices -+ bool device_info; -+ -+ // -+ // Methods -+ // -+ -+ About(CommandLine const &cmdline); -+ -+ void print_usage(std::ostream &out) const; -+ void print_options(std::ostream &out, int indent = 0) const; -+ -+ static void print_version(std::ostream &out); -+ }; -+ -+public: -+ -+ // -+ // Data members -+ // -+ -+ /// Top-level execution mode -+ ExecutionMode execution_mode; -+ -+ /// Name of math function to profile -+ library::OperationKind operation_kind; -+ -+ /// Vector of operation name substrings -+ std::vector operation_names; -+ -+ /// Vector of operation name substrings -+ std::vector excluded_operation_names; -+ -+ -+ // -+ // Detailed configuration options -+ // -+ -+ /// Configuration -+ CommandLine cmdline; -+ Device device; -+ Initialization initialization; -+ Library library; -+ Verification verification; -+ Profiling profiling; -+ Report report; -+ About about; -+ -+public: -+ -+ Options(CommandLine const &cmdline); -+ -+ void print_usage(std::ostream &out) const; -+ void print_options(std::ostream &out) const; -+ -+ static std::string indent_str(int indent); -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/profiler/src/performance_report.h b/3rdparty/cutlass/tools/profiler/src/performance_report.h -new file mode 100644 -index 0000000..b74d069 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/performance_report.h -@@ -0,0 +1,127 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Class performing output during profiling -+*/ -+ -+#pragma once -+ -+#include -+#include -+ -+// CUTLASS Profiler includes -+#include "options.h" -+#include "enumerated_types.h" -+#include "performance_result.h" -+ -+// CUTLASS Library includes -+#include "cutlass/library/library.h" -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+class PerformanceReport { -+private: -+ -+ /// Reference to options -+ Options const &options_; -+ -+ /// Operation kind -+ library::OperationKind op_kind_; -+ -+ /// Operation file name containing performance report of op_kind -+ std::string op_file_name_; -+ -+ /// Output file containing results -+ std::ofstream output_file_; -+ -+ /// Operation file name containing junit performance report of op_kind -+ std::string op_junit_file_name_; -+ -+ /// Output file containing junit results -+ std::ofstream junit_output_file_; -+ -+ /// Flag indicating the performance report is valid -+ bool good_; -+ -+ /// Vector of argument names -+ std::vector argument_names_; -+ -+ /// Counter uniquely identifying problem within the report -+ size_t problem_index_; -+ -+ /// Collection of all results -+ PerformanceResultVector concatenated_results_; -+ -+public: -+ -+ PerformanceReport(Options const &options, std::vector const &argument_names, library::OperationKind const &op_kind); -+ ~PerformanceReport(); -+ -+ bool good() const { return good_; } -+ -+ void next_problem(); -+ void append_result(PerformanceResult result); -+ void sort_results(PerformanceResultVector &results); -+ void append_results(PerformanceResultVector const &results); -+ -+public: -+ -+ /// Prints the CSV header -+ std::ostream & print_csv_header_(std::ostream &out); -+ -+ /// Prints the CSV -+ std::ostream & print_result_csv_(std::ostream &out, PerformanceResult const &result); -+ -+ /// @defgroup jUnit Result Generation -+ /// Functions related to generation of the jUnit results -+ /// @{ -+ -+ std::ostream & print_junit_header_(std::ostream &out); -+ std::ostream & print_junit_result_(std::ostream &out, PerformanceResult const &result); -+ std::ostream & print_junit_footer_(std::ostream &out); -+ -+ /// @} -+ -+ /// Prints the result in human readable form -+ std::ostream & print_result_pretty_( -+ std::ostream &out, -+ PerformanceResult const &result, -+ bool use_shell_coloring = true); -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/tools/profiler/src/performance_result.cu b/3rdparty/cutlass/tools/profiler/src/performance_result.cu -new file mode 100644 -index 0000000..810e261 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/performance_result.cu -@@ -0,0 +1,61 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief -+*/ -+ -+#pragma once -+ -+#include -+ -+#include "cutlass/cutlass.h" -+ -+// CUTLASS Profiler includes -+#include "enumerated_types.h" -+#include "performance_result.h" -+ -+// CUTLASS Library includes -+#include "cutlass/library/library.h" -+#include "cutlass/library/util.h" -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/profiler/src/performance_result.h b/3rdparty/cutlass/tools/profiler/src/performance_result.h -new file mode 100644 -index 0000000..c714e02 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/performance_result.h -@@ -0,0 +1,128 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines a math function -+*/ -+ -+#pragma once -+ -+#include -+ -+#include "cutlass/cutlass.h" -+ -+// CUTLASS Profiler includes -+#include "enumerated_types.h" -+ -+// CUTLASS Library includes -+#include "cutlass/library/library.h" -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Performance result object -+struct PerformanceResult { -+ -+ /// Index of problem -+ size_t problem_index; -+ -+ /// library::Provider -+ library::Provider provider; -+ -+ /// Operation kind -+ library::OperationKind op_kind; -+ -+ /// CUTLASS status result from kernels (success or failure) -+ // Status does information on verification -+ Status status; -+ -+ /// Outcome of verification (worst case verification result) -+ Disposition disposition; -+ -+ /// Outcome of verification (all verification results) -+ DispositionMap verification_map; -+ -+ /// Operation name -+ std::string operation_name; -+ -+ /// Stringified vector of argument values -+ std::vector > arguments; -+ -+ /// Number of bytes read or written -+ int64_t bytes; -+ -+ /// Number of DL flops performed by the math function -+ int64_t flops; -+ -+ /// Average runtime in ms -+ double runtime; -+ -+ // -+ // Members -+ // -+ -+ /// Ctor -+ PerformanceResult(): -+ problem_index(0), -+ op_kind(library::OperationKind::kInvalid), -+ provider(library::Provider::kInvalid), -+ disposition(Disposition::kNotRun), -+ status(Status::kInvalid), -+ bytes(0), -+ flops(0), -+ runtime(0) -+ { } -+ -+ /// Returns true if the runtime is valid -+ bool good() const { -+ return runtime > 0; -+ } -+ -+ /// Math throughput in units of GFLOP/s -+ double gflops_per_sec() const { -+ return double(flops) / runtime / 1.0e6; -+ } -+ -+ /// memory bandwidth in units of GiB/s -+ double gbytes_per_sec() const { -+ return double(bytes) / double(1 << 30) / runtime * 1000.0; -+ } -+ -+}; -+ -+using PerformanceResultVector = std::vector; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/tools/profiler/src/problem_space.h b/3rdparty/cutlass/tools/profiler/src/problem_space.h -new file mode 100644 -index 0000000..4e102e6 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/problem_space.h -@@ -0,0 +1,1005 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief -+ -+ "Any sufficiently complicated C or Fortran program contains an ad-hoc, informally-specified, -+ bug-ridden, slow implementation of half of Common Lisp." -+ -+ - Greenspun's Tenth Rule of Programming -+ -+ -+ cutlass::profiler::ProblemSpace defines a set of data structures which represent the Cartesian -+ product of sequences defined by integer ranges, lists of scalars, and sets of enumerated types. -+ -+ These permit a single invocation of the CUTLASS Profiler to iterate over a large set of problems, -+ verify and profile various operations when they are compatible with the command line, and -+ construct data tables of results that are convenient inputs to post processing in Excel or Pandas. -+ -+ By executing multiple problems per invocation, startup overheads may be amortized across many -+ kernel launches. -+*/ -+ -+#pragma once -+ -+// Standard Library includes -+#include -+#include -+#include -+#include -+#include -+ -+// CUTLASS Utility includes -+#include "cutlass/util/command_line.h" -+ -+// CUTLASS Library includes -+#include "cutlass/library/library.h" -+ -+// Profiler includes -+#include "enumerated_types.h" -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines the argument schema -+struct ArgumentDescription { -+ -+ /// Type of argument -+ ArgumentTypeID type; -+ -+ /// Prioritized array of aliases used in command line parsing -+ std::vector aliases; -+ -+ /// Description of argument -+ std::string description; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ ArgumentDescription(): -+ type(ArgumentTypeID::kInvalid) { } -+ -+ /// Constructor with aliases -+ ArgumentDescription( -+ ArgumentTypeID type_, -+ std::vector const &aliases_, -+ std::string const &description_ -+ ): -+ type(type_), aliases(aliases_), description(description_) { } -+}; -+ -+/// Vector of arguments -+using ArgumentDescriptionVector = std::vector; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Base class for kernel arguments -+struct KernelArgument { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Value base class -+ struct Value { -+ -+ KernelArgument const *argument; -+ bool not_null; -+ -+ // -+ // Methods -+ // -+ -+ Value( -+ KernelArgument const *argument_ = nullptr, -+ bool not_null_ = true -+ ): argument(argument_), not_null(not_null_) { } -+ -+ virtual ~Value() { } -+ -+ virtual std::ostream &print(std::ostream &out) const =0; -+ }; -+ -+ /// Abstract base class to iterate over values within arguments -+ struct ValueIterator { -+ -+ /// Indicates type of kernel argument -+ KernelArgument const *argument; -+ -+ /// If the iterator points to an argument that is null, it needs to be distinguished -+ /// from end. -+ bool null_argument; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructs a value iterator - no methods are valid if argument_ == nullptr -+ ValueIterator( -+ KernelArgument const *argument_ = nullptr, -+ bool null_argument_ = false): -+ argument(argument_), null_argument(null_argument_) { -+ -+ if (!argument_->not_null()) { -+ null_argument = true; -+ } -+ } -+ -+ virtual ~ValueIterator() { } -+ -+ /// Advances to next point in range -+ virtual void operator++() = 0; -+ -+ /// Compares against another value iterator - must be of the same KernelArgument type -+ virtual bool operator==(ValueIterator const &it) const = 0; -+ -+ /// Returns a unique_ptr object pointing to a newly created value object -+ virtual std::unique_ptr at() const = 0; -+ -+ /// Gets the type of the iterator -+ ArgumentTypeID type() const { -+ return argument->description->type; -+ } -+ -+ /// Helper to compute inequality -+ bool operator!=(ValueIterator const &it) const { -+ return !(*this == it); -+ } -+ -+ std::ostream &print(std::ostream &out) const; -+ }; -+ -+ // -+ // Data members -+ // -+ -+ /// Describes the argument -+ ArgumentDescription const *description; -+ -+ /// Parent node -+ KernelArgument *parent; -+ -+ /// Sequence in which the kernel argument is to be iterated over. -+ /// Smaller means faster changing. -1 is don't care -+ int ordinal; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ KernelArgument( -+ ArgumentDescription const *description_ = nullptr, -+ KernelArgument *parent_ = nullptr, -+ int ordinal_ = -1 -+ ): description(description_), parent(parent_), ordinal(ordinal_) { } -+ -+ virtual ~KernelArgument(); -+ -+ /// Returns true if the kernel argument iself is empty -+ virtual bool not_null() const =0; -+ -+ /// Returns a string name for debugging -+ std::string qualified_name() const { -+ if (description) { -+ if (description->aliases.empty()) { -+ return ""; -+ } -+ return description->aliases.front(); -+ } -+ return ""; -+ } -+ -+ virtual std::unique_ptr begin() const =0; -+ virtual std::unique_ptr end() const =0; -+}; -+ -+using KernelArgumentVector = std::vector>; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines a scalar argument type as a string that is lexically cast to the appropriate kernel -+/// type. -+struct ScalarArgument : public KernelArgument { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Value type -+ struct ScalarValue : public KernelArgument::Value { -+ -+ std::string value; -+ -+ // -+ // Methods -+ // -+ -+ ScalarValue( -+ std::string const &value_ = "", -+ ScalarArgument const *argument = nullptr, -+ bool not_null_ = true -+ ); -+ -+ virtual std::ostream &print(std::ostream &out) const; -+ }; -+ -+ using ValueCollection = std::vector; -+ -+ /// Abstract base class to iterate over values within arguments -+ struct ScalarValueIterator : public KernelArgument::ValueIterator { -+ -+ // -+ // Data members -+ // -+ -+ ValueCollection::const_iterator value_it; -+ -+ // -+ // Methods -+ // -+ -+ ScalarValueIterator(ScalarArgument const *argument = nullptr); -+ -+ virtual void operator++(); -+ virtual bool operator==(ValueIterator const &it) const; -+ -+ /// Gets the value pointed to -+ virtual std::unique_ptr at() const; -+ }; -+ -+ // -+ // Data members -+ // -+ -+ /// Set of posible values -+ ValueCollection values; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ ScalarArgument( -+ ArgumentDescription const *description -+ ): -+ KernelArgument(description) { } -+ -+ virtual bool not_null() const { -+ return !values.empty(); -+ } -+ -+ virtual std::unique_ptr begin() const; -+ virtual std::unique_ptr end() const; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Closed range supporting additive increment -+struct Range { -+ -+ // -+ // Type definitions -+ // -+ -+ enum class Mode { -+ kSequence, -+ kRandom, -+ kRandomLog2, -+ kInvalid -+ }; -+ -+ struct Iterator { -+ -+ int64_t value; -+ int64_t increment; -+ Range const *range; -+ -+ // -+ // Methods -+ // -+ -+ Iterator( -+ int64_t value_ = 0, -+ int64_t increment_ = 1, -+ Range const *range_ = nullptr -+ ): -+ value(value_), increment(increment_), range(range_) { } -+ -+ Iterator & operator++() { -+ value += increment; -+ return *this; -+ } -+ -+ Iterator operator++(int) { -+ Iterator self(*this); -+ ++(*this); -+ return self; -+ } -+ -+ bool operator==(Iterator const &it) const { -+ return value == it.value; -+ } -+ -+ bool operator!=(Iterator const &it) const { -+ return !(*this == it); -+ } -+ -+ static int64_t round(int64_t value, int64_t divisible) { -+ int64_t rem = (value % divisible); -+ -+ // Round either up or down -+ if (rem > divisible / 2) { -+ value += (divisible - rem); -+ } -+ else { -+ value -= rem; -+ } -+ -+ return value; -+ } -+ -+ int64_t at() const { -+ if (!range) { -+ return value; -+ } -+ -+ switch (range->mode) { -+ case Mode::kSequence: return value; -+ -+ case Mode::kRandom: { -+ double rnd = double(range->minimum) + -+ double(std::rand()) / double(RAND_MAX) * (double(range->maximum) - double(range->minimum)); -+ -+ int64_t value = int64_t(rnd); -+ -+ return round(value, range->divisible); -+ } -+ break; -+ -+ case Mode::kRandomLog2: { -+ double lg2_minimum = std::log(double(range->minimum)) / std::log(2.0); -+ double lg2_maximum = std::log(double(range->maximum)) / std::log(2.0); -+ double rnd = lg2_minimum + double(std::rand()) / double(RAND_MAX) * (lg2_maximum - lg2_minimum); -+ -+ int64_t value = int64_t(std::pow(2.0, rnd)); -+ -+ return round(value, range->divisible); -+ } -+ break; -+ default: break; -+ } -+ return value; -+ } -+ -+ int64_t operator*() const { -+ return at(); -+ } -+ }; -+ -+ // -+ // Data members -+ // -+ -+ int64_t first; ///< first element in range -+ int64_t last; ///< last element in range -+ int64_t increment; ///< additive increment between values -+ -+ Mode mode; ///< mode selection enables alternative values -+ int64_t minimum; ///< minimum value to return -+ int64_t maximum; ///< maximum value to return -+ int64_t divisible; ///< rounds value down to an integer multiple of this value -+ -+ // -+ // Methods -+ // -+ -+ /// Default constructor - range acts as a scalar -+ Range(int64_t first_ = 0): first(first_), last(first_), increment(1), mode(Mode::kSequence), minimum(0), maximum(0), divisible(1) { } -+ -+ /// Range acts as a range -+ Range( -+ int64_t first_, -+ int64_t last_, -+ int64_t increment_ = 1, -+ Mode mode_ = Mode::kSequence, -+ int64_t minimum_ = 0, -+ int64_t maximum_ = 0, -+ int64_t divisible_ = 1 -+ ): first(first_), last(last_), increment(increment_), mode(mode_), minimum(minimum_), maximum(maximum_), divisible(divisible_) { -+ -+ // Helpers to avoid constructing invalid ranges -+ if (increment > 0) { -+ if (last < first) { -+ std::swap(last, first); -+ } -+ } -+ else if (increment < 0) { -+ if (first < last) { -+ std::swap(last, first); -+ } -+ } -+ else if (last != first) { -+ last = first; -+ increment = 1; -+ } -+ } -+ -+ /// Helper to construct a sequence range -+ static Range Sequence(int64_t first_, int64_t last_, int64_t increment_ = 1) { -+ return Range(first_, last_, increment_, Mode::kSequence); -+ } -+ -+ /// Helper to construct a range that is a random distribution -+ static Range Random(int64_t minimum_, int64_t maximum_, int64_t count_, int64_t divisible_ = 1) { -+ return Range(1, count_, 1, Mode::kRandom, minimum_, maximum_, divisible_); -+ } -+ -+ /// Helper to construct a range that is a random distribution over a log scale -+ static Range RandomLog2(int64_t minimum_, int64_t maximum_, int64_t count_, int64_t divisible_ = 1) { -+ return Range(1, count_, 1, Mode::kRandomLog2, minimum_, maximum_, divisible_); -+ } -+ -+ /// Returns an iterator to the first element within the range -+ Iterator begin() const { -+ return Iterator(first, increment, this); -+ } -+ -+ /// Returns an iterator to the first element *after* the range -+ Iterator end() const { -+ return Iterator(first + ((last - first)/increment + 1) * increment, increment, this); -+ } -+}; -+ -+/// Integer-valued argument - represented as a list of integer-valued ranges -+struct IntegerArgument : public KernelArgument { -+ -+ // -+ // Type definitions -+ // -+ -+ /// Value type -+ struct IntegerValue : public KernelArgument::Value { -+ -+ int64_t value; -+ -+ // -+ // Methods -+ // -+ -+ IntegerValue( -+ int64_t value_ = 0, -+ IntegerArgument const *argument_ = nullptr, -+ bool not_null_ = true -+ ); -+ -+ /// Pretty printer for debugging -+ virtual std::ostream &print(std::ostream &out) const; -+ }; -+ -+ /// Collection of ranges represent the IntegerArgument's state -+ using RangeCollection = std::vector; -+ -+ /// Abstract base class to iterate over values within arguments -+ struct IntegerValueIterator : public KernelArgument::ValueIterator { -+ -+ // -+ // Data members -+ // -+ -+ RangeCollection::const_iterator range_it; -+ Range::Iterator value_it; -+ -+ // -+ // Methods -+ // -+ -+ IntegerValueIterator(); -+ IntegerValueIterator(IntegerArgument const *argument); -+ -+ virtual void operator++(); -+ virtual bool operator==(ValueIterator const &it) const; -+ -+ /// Gets the value pointed to -+ virtual std::unique_ptr at() const; -+ }; -+ -+ // -+ // Data members -+ // -+ -+ /// Set of posible values -+ RangeCollection ranges; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ IntegerArgument( -+ ArgumentDescription const *description -+ ): -+ KernelArgument(description) { } -+ -+ virtual bool not_null() const { -+ bool _not_null = !ranges.empty(); -+ return _not_null; -+ } -+ -+ virtual std::unique_ptr begin() const; -+ virtual std::unique_ptr end() const; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Structure defining the data type of tensors -+struct TensorArgument : public KernelArgument { -+ -+ // -+ // Type definitions -+ // -+ -+ struct TensorDescription { -+ -+ /// Data type of elements -+ library::NumericTypeID element; -+ -+ /// Layout definition -+ library::LayoutTypeID layout; -+ -+ /// Computed extent -+ std::vector extent; -+ -+ /// Enables directly specifying stride value used to size tensor -+ std::vector stride; -+ -+ // -+ // Methods -+ // -+ -+ TensorDescription( -+ library::NumericTypeID element_ = library::NumericTypeID::kUnknown, -+ library::LayoutTypeID layout_ = library::LayoutTypeID::kUnknown, -+ std::vector extent_ = std::vector(), -+ std::vector stride_ = std::vector() -+ ): -+ element(element_), layout(layout_), extent(extent_), stride(stride_) {} -+ }; -+ -+ using ValueCollection = std::vector; -+ -+ /// Value structure -+ struct TensorValue : public KernelArgument::Value { -+ -+ TensorDescription desc; -+ -+ // -+ // Methods -+ // -+ -+ TensorValue( -+ TensorDescription const &desc_ = TensorDescription(), -+ TensorArgument const *argument_ = nullptr, -+ bool not_null_ = true -+ ); -+ -+ /// Pretty printer for debugging -+ virtual std::ostream &print(std::ostream &out) const; -+ }; -+ -+ /// Abstract base class to iterate over values within arguments -+ struct TensorValueIterator : public KernelArgument::ValueIterator { -+ -+ // -+ // Data members -+ // -+ -+ ValueCollection::const_iterator value_it; -+ -+ // -+ // Methods -+ // -+ -+ TensorValueIterator(TensorArgument const *argument_); -+ -+ virtual void operator++(); -+ virtual bool operator==(ValueIterator const &it) const; -+ -+ /// Gets the value pointed to -+ virtual std::unique_ptr at() const; -+ }; -+ -+ /// Set of possible values -+ ValueCollection values; -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ TensorArgument( -+ ArgumentDescription const *description -+ ): -+ KernelArgument(description) { } -+ -+ virtual bool not_null() const { -+ return !values.empty(); -+ } -+ -+ virtual std::unique_ptr begin() const; -+ virtual std::unique_ptr end() const; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Numeric data type -+struct EnumeratedTypeArgument : public KernelArgument { -+ -+ // -+ // Type definitions -+ // -+ -+ struct EnumeratedTypeValue : public KernelArgument::Value { -+ -+ /// Data type of element -+ std::string element; -+ -+ // -+ // Methods -+ // -+ -+ EnumeratedTypeValue( -+ std::string const &element_ = std::string(), -+ EnumeratedTypeArgument const *argument_ = nullptr, -+ bool not_null_ = true -+ ); -+ -+ /// Pretty printer for debugging -+ virtual std::ostream &print(std::ostream &out) const; -+ }; -+ -+ using ValueCollection = std::vector; -+ -+ /// Abstract base class to iterate over values within arguments -+ struct EnumeratedTypeValueIterator : public KernelArgument::ValueIterator { -+ -+ // -+ // Data members -+ // -+ -+ ValueCollection::const_iterator value_it; -+ -+ // -+ // Methods -+ // -+ -+ EnumeratedTypeValueIterator(EnumeratedTypeArgument const *argument_ = nullptr); -+ -+ virtual void operator++(); -+ virtual bool operator==(ValueIterator const &it) const; -+ -+ /// Gets the value pointed to -+ virtual std::unique_ptr at() const; -+ }; -+ -+ // -+ // Data members -+ // -+ -+ ValueCollection values; -+ -+ // -+ // Members -+ // -+ -+ /// Default ctor -+ EnumeratedTypeArgument(ArgumentDescription const *description): -+ KernelArgument(description) {} -+ -+ virtual bool not_null() const { -+ return !values.empty(); -+ } -+ -+ virtual std::unique_ptr begin() const; -+ virtual std::unique_ptr end() const; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Object storing the space argument values -+class ProblemSpace { -+public: -+ -+ /// Tuple of arguments -+ using Problem = std::vector>; -+ -+ /// Type used to iterator over things -+ using IteratorVector = std::vector>; -+ -+ /// Iterates over points in the design space -+ class Iterator { -+ private: -+ -+ /// One iterator per argument -+ IteratorVector iterators; -+ -+ public: -+ -+ // -+ // Methods -+ // -+ -+ explicit Iterator(); -+ Iterator(ProblemSpace const &problem_space); -+ Iterator(Iterator &&it); -+ -+ // Rule of three -+ Iterator(Iterator const &) = delete; -+ Iterator &operator=(Iterator const &it) = delete; -+ ~Iterator() = default; -+ -+ /// Pre-increment - advances to next point in argument range -+ void operator++(); -+ -+ /// Gets the current argument value -+ Problem at() const; -+ -+ /// Moves iterator to end -+ void move_to_end(); -+ -+ /// Equality operator -+ bool operator==(Iterator const &it) const; -+ -+ /// Inequality operator -+ bool operator!=(Iterator const &it) const { -+ return !(*this == it); -+ } -+ -+ /// Helper to call at() method -+ Problem operator*() const { -+ return at(); -+ } -+ -+ /// Helper to print iterator state -+ std::ostream & print(std::ostream &out) const; -+ -+ private: -+ -+ /// Helper for recursively constructing iterators -+ void construct_(KernelArgument const *argument); -+ }; -+ -+public: -+ -+ // -+ // Data members -+ // -+ -+ KernelArgumentVector arguments; -+ -+ /// Map of argument names to their position within the argument vector -+ std::unordered_map argument_index_map; -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Default ctor -+ ProblemSpace() {} -+ -+ /// Constructs a problem space from a vector of arguments. This vector must outlive -+ /// the ProblemSpace object, which stores pointers to objects within the -+ /// ArgumentDescriptionVector. -+ ProblemSpace(ArgumentDescriptionVector const &schema, CommandLine const &cmdline); -+ -+ Iterator begin() const; // returns an iterator to the first point in the range -+ Iterator end() const; // returns an iterator to the first point after the range -+ -+ /// Returns the index of an argument by name -+ size_t argument_index(char const *name) const; -+ -+ /// Gets all argument names as an ordered vector -+ std::vector argument_names() const; -+ -+ /// Returns the number of dimensions of the problem space -+ size_t rank() const { return arguments.size(); } -+ -+private: -+ -+ /// Helper for recursively cloning -+ void clone_( -+ KernelArgumentVector &kernel_args, -+ ArgumentDescription const *arg_desc); -+ -+ /// Parses command line argument -+ void parse_( -+ KernelArgument *arg, -+ CommandLine const &cmdline); -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Lexically casts an argument to an int if it is defined. Returns true if not null. -+bool arg_as_int(int &int_value, KernelArgument::Value const *value_ptr); -+ -+/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -+bool arg_as_int(int64_t &int_value, KernelArgument::Value const *value_ptr); -+ -+/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -+bool arg_as_int( -+ int &int_value, -+ char const *name, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -+bool arg_as_int( -+ int64_t &int_value, -+ char const *name, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -+bool arg_as_NumericTypeID(library::NumericTypeID &numeric_type, KernelArgument::Value const *value_ptr); -+ -+/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -+bool arg_as_NumericTypeID( -+ library::NumericTypeID &numeric_type, -+ char const *name, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -+bool arg_as_LayoutTypeID(library::LayoutTypeID &layout_type, KernelArgument::Value const *value_ptr); -+ -+/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -+bool arg_as_LayoutTypeID( -+ library::LayoutTypeID &layout_type, -+ char const *name, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ -+/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -+bool arg_as_OpcodeClassID(library::OpcodeClassID &opcode_class, KernelArgument::Value const *value_ptr); -+ -+/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -+bool arg_as_OpcodeClassID( -+ library::OpcodeClassID &opcode_class, -+ char const *name, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ -+/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -+bool arg_as_SplitKModeID(library::SplitKMode &split_k_mode, KernelArgument::Value const *value_ptr); -+ -+/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -+bool arg_as_SplitKModeID( -+ library::SplitKMode &split_k_mode, -+ char const *name, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -+bool arg_as_ConvModeID(library::ConvModeID &conv_mode, KernelArgument::Value const *value_ptr); -+ -+/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -+bool arg_as_ConvModeID( -+ library::ConvModeID &conv_mode, -+ char const *name, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -+bool arg_as_IteratorAlgorithmID(library::IteratorAlgorithmID &iterator_algorithm, KernelArgument::Value const *value_ptr); -+ -+/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -+bool arg_as_IteratorAlgorithmID( -+ library::IteratorAlgorithmID &iterator_algorithm, -+ char const *name, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ -+/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -+bool arg_as_ProviderID(library::Provider &provider, KernelArgument::Value const *value_ptr); -+ -+/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. -+bool arg_as_ProviderID( -+ library::Provider &provider, -+ char const *name, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+/// Lexically casts an argument to a given type stored in a byte array. Returns true if not null. -+bool arg_as_scalar( -+ std::vector &bytes, -+ library::NumericTypeID numeric_type, -+ KernelArgument::Value const *value_ptr); -+ -+/// Lexically casts an argument to a given type stored in a byte array. Returns true if not null. -+bool arg_as_scalar( -+ std::vector &bytes, -+ library::NumericTypeID numeric_type, -+ char const *name, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+/// Returns true if a tensor description satisfies a `tensor` value -+bool tensor_description_satisfies( -+ library::TensorDescription const &tensor_desc, -+ TensorArgument::TensorValue const *value_ptr); -+ -+/// Returns true if a tensor description satisfies a `tensor` value -+bool tensor_description_satisfies( -+ library::TensorDescription const &tensor_desc, -+ char const *name, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ -+/// Returns true if a conv kind satisfies the value -+bool conv_kind_satisfies( -+ library::ConvKind const &conv_kind, -+ EnumeratedTypeArgument::EnumeratedTypeValue const *value_ptr); -+ -+/// Returns true if a conv kind satisfies the value -+bool conv_kind_satisfies( -+ library::ConvKind const &conv_kind, -+ char const *name, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+/// Returns true if a iterator algorithm satisfies the value -+bool iterator_algorithm_satisfies( -+ library::IteratorAlgorithmID const &iterator_algorithm, -+ EnumeratedTypeArgument::EnumeratedTypeValue const *value_ptr); -+ -+/// Returns true if a iterator algorithm satisfies the value -+bool iterator_algorithm_satisfies( -+ library::IteratorAlgorithmID const &iterator_algorithm, -+ char const *name, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/profiler/src/rank_2k_operation_profiler.cu b/3rdparty/cutlass/tools/profiler/src/rank_2k_operation_profiler.cu -new file mode 100644 -index 0000000..2c2f236 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/rank_2k_operation_profiler.cu -@@ -0,0 +1,727 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Execution environment -+ -+ -+*/ -+ -+#include -+#include -+#include -+#include -+ -+#include "cutlass/core_io.h" -+ -+#include "cublas_helpers.h" -+#include "rank_2k_operation_profiler.h" -+#include "gpu_timer.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Ctor -+Rank2KOperationProfiler::Rank2KOperationProfiler(Options const &options): -+ OperationProfiler( -+ options, -+ library::OperationKind::kRank2K, -+ { -+ {ArgumentTypeID::kEnumerated, {"rank_k_kind"}, "Variant of RankK (universal)"}, -+ {ArgumentTypeID::kInteger, {"n", "problem-size::n"}, "N dimension of the RankK problem space"}, -+ {ArgumentTypeID::kInteger, {"k", "problem-size::k"}, "K dimension of the RankK problem space"}, -+ {ArgumentTypeID::kTensor, {"A"}, "Tensor storing the A operand"}, -+ {ArgumentTypeID::kTensor, {"B"}, "Tensor storing the B operand"}, -+ {ArgumentTypeID::kTensor, {"C"}, "Tensor storing the C operand"}, -+ {ArgumentTypeID::kEnumerated, {"fill_mode"}, "Fill Mode for RankK kernel (lower or upper)"}, -+ {ArgumentTypeID::kEnumerated, {"blas_mode"}, "Blas Mode for RankK kernel (symmetric or hermitian)"}, -+ {ArgumentTypeID::kScalar, {"alpha", "epilogue::alpha"}, "Epilogue scalar alpha"}, -+ {ArgumentTypeID::kScalar, {"beta", "epilogue::beta"}, "Epilogue scalar beta"}, -+ {ArgumentTypeID::kInteger, {"split_k_slices", "split-k-slices"}, "Number of partitions of K dimension"}, -+ {ArgumentTypeID::kInteger, {"batch_count", "batch-count"}, "Number of RankK computed in one batch"}, -+ }, -+ { library::Provider::kCUBLAS} -+ ) { -+ description_ = " Rank 2k Update. D = alpha * (A*B^T + B*A^T) + beta * C (symmetric) or D = alpha * (A*B^H+B*A^H) + beta * C (hermitian)"; -+} -+ -+/// Destructor -+Rank2KOperationProfiler::~Rank2KOperationProfiler() { -+ -+} -+ -+/// Prints usage statement for the math function -+void Rank2KOperationProfiler::print_usage(std::ostream &out) const { -+ out << "RankK" << "\n\n"; -+ -+ OperationProfiler::print_usage(out); -+} -+ -+/// Prints examples -+void Rank2KOperationProfiler::print_examples(std::ostream &out) const { -+ -+ out << "\nExamples:\n\n" -+ << "Profile a particular problem size Syrk kernel:\n" -+ << " $ cutlass_profiler --operation=rank_2k --blas_mode=symmetric --n=1024 --k=128\n\n" -+ -+ << "Profile a particular problem size Herk kernel:\n" -+ << " $ cutlass_profiler --operation=rank_2k --blas_mode=hermitian --n=1024 --k=128\n\n" -+ -+ << "Schmoo over problem size and beta:\n" -+ << " $ cutlass_profiler --operation=rank_2k --n=1024:4096:256 --k=128:8192:128 --beta=0,1,2.5\n\n" -+ -+ << "Schmoo over accumulator types:\n" -+ << " $ cutlass_profiler --operation=rank_2k --accumulator-type=f16,f32\n\n" -+ -+ << "Schmoo over fill modees:\n" -+ << " $ cutlass_profiler --operation=rank_2k --fill_mode=lower/upper\n\n" -+ -+ << "Run when A is f16 with column-major or A is any datatype with row-major (For column major, use column, col, or n. For row major use, row or t):\n" -+ << " $ cutlass_profiler --operation=rank_2k --A=f16:column or --A=*:row\n\n" -+ -+ << "Using various input value distribution:\n" -+ << " $ cutlass_profiler --operation=rank_2k --dist=uniform,min:0,max:3\n" -+ << " $ cutlass_profiler --operation=rank_2k --dist=gaussian,mean:0,stddev:3\n" -+ << " $ cutlass_profiler --operation=rank_2k --dist=sequential,start:0,delta:1\n\n" -+ -+ << "Run a kernel with cta tile size of 256x128x32 and save workspace if results are incorrect (note that --cta-tile::k=32 is default cta-tile size):\n" -+ << " $ cutlass_profiler --operation=rank_2k --cta_m=256 --cta_n=128 --cta_k=32 --save-workspace=incorrect\n\n" -+ -+ << "Test your changes to rank_2k kernels with a quick functional test and save results in functional-test.csv:\n" -+ << " $ cutlass_profiler --operation=rank_2k \\ \n" -+ << " --n=8,56,120,136,256,264,512,520,1024,1032,4096,8192,16384 \\ \n" -+ << " --k=8,16,32,64,128,256,288,384,504,512,520 \\ \n" -+ << " --beta=0,1,2 --profiling-iterations=1 \\ \n" -+ << " --providers=cutlass --output=functional-test.csv\n\n"; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if 0 -+// used this for debugging -+static std::string byte_string(std::vector const &bytes) { -+ std::stringstream ss; -+ -+ ss << "0x"; -+ -+ for (size_t idx = bytes.size(); idx > 0; --idx) { -+ ss << std::hex << std::setw(2) << std::setfill('0') << uint32_t(bytes.at(idx - 1)); -+ } -+ -+ return ss.str(); -+} -+#endif -+ -+Status Rank2KOperationProfiler::RankKProblem::parse( -+ library::RankKDescription const &operation_desc, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ if (!arg_as_int(this->n, "n", problem_space, problem)) { -+ // default value -+ this->n = 1024; -+ } -+ -+ if (!arg_as_int(this->k, "k", problem_space, problem)) { -+ // default value -+ this->k = 1024; -+ } -+ -+ if (!arg_as_int(this->split_k_slices, "split_k_slices", problem_space, problem)) { -+ // default value -+ this->split_k_slices = 1; -+ } -+ -+ if (!arg_as_int(this->batch_count, "batch_count", problem_space, problem)) { -+ // default value -+ this->batch_count = 1; -+ } -+ -+ if (this->split_k_slices > 1 && this->batch_count > 1) { -+ // At least one of these must be one -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.A, "A", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.B, "B", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.C, "C", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!arg_as_scalar( -+ this->alpha, -+ operation_desc.element_epilogue, -+ "alpha", -+ problem_space, -+ problem)) { -+ -+ if (!cast_from_double(this->alpha, operation_desc.element_epilogue, 1)) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ if (!arg_as_scalar( -+ this->beta, -+ operation_desc.element_epilogue, -+ "beta", -+ problem_space, -+ problem)) { -+ -+ if (!cast_from_double(this->beta, operation_desc.element_epilogue, 0)) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ this->lda = DeviceAllocation::get_packed_layout( -+ operation_desc.A.layout, {int(this->n), int(this->k)}).front(); -+ -+ this->ldb = DeviceAllocation::get_packed_layout( -+ operation_desc.B.layout, {int(this->n), int(this->k)}).front(); -+ -+ this->ldc = DeviceAllocation::get_packed_layout( -+ operation_desc.C.layout, {int(this->n), int(this->n)}).front(); -+ -+ return Status::kSuccess; -+} -+ -+/// Total number of bytes loaded -+int64_t Rank2KOperationProfiler::RankKProblem::bytes(library::RankKDescription const &operation_desc) const { -+ // Input bytes read and Output bytes written for the gemm problem -+ int64_t bytes = -+ 2 * int64_t(library::sizeof_bits(operation_desc.A.element) * n / 8) * k + -+ 2 * int64_t(library::sizeof_bits(operation_desc.B.element) * n / 8) * k + -+ // Half matrix including the diagonal will have (N*(N+1))/2 elements -+ int64_t(library::sizeof_bits(operation_desc.C.element) * n / 8) * (n+1) / 2; -+ -+ // Set is_beta_zero true if beta is zero -+ bool is_beta_zero = std::all_of(beta.begin(), beta.end(), [](uint8_t i) { return i==0; }); -+ -+ // Output bytes read for the gemm problem for non-zero beta values -+ if (!is_beta_zero) { -+ bytes += int64_t(library::sizeof_bits(operation_desc.C.element) * n / 8) * (n+1) / 2; -+ } -+ -+ bytes *= batch_count; -+ -+ return bytes; -+} -+ -+/// Total number of flops computed -+int64_t Rank2KOperationProfiler::RankKProblem::flops(library::RankKDescription const &operation_desc) const { -+ -+ // FLOPs = 2 * n(n+1)k/2 [mma1] + 2 * n(n+1)k/2 [mma2] + 2 * n(n+1)/2 [epilogue] -+ // FLOPs = n(n+1)(2k + 1) -+ int64_t flops_ = n * (n + 1) * (2*k + 1); -+ -+ // complex-valued support -+ switch (operation_desc.tile_description.math_instruction.math_operation) { -+ case library::MathOperationID::kMultiplyAddComplex: -+ flops_ *= 4; -+ break; -+ -+ case library::MathOperationID::kMultiplyAddComplexFastF32: -+ flops_ *= 4; -+ break; -+ -+ case library::MathOperationID::kMultiplyAddGaussianComplex: -+ flops_ *= 3; -+ break; -+ -+ default: break; -+ } -+ -+ return flops_; -+} -+ -+/// Initializes a performance result -+void Rank2KOperationProfiler::RankKProblem::initialize_result( -+ PerformanceResult &result, -+ library::RankKDescription const &operation_desc, -+ ProblemSpace const &problem_space) { -+ -+ result.arguments.resize(problem_space.rank()); -+ -+ set_argument(result, "rank_k_kind", problem_space, library::to_string(operation_desc.rank_k_kind)); -+ -+ set_argument(result, "A", problem_space, -+ std::string(library::to_string(operation_desc.A.element)) + ":" + library::to_string(operation_desc.A.layout)); -+ -+ set_argument(result, "B", problem_space, -+ std::string(library::to_string(operation_desc.B.element)) + ":" + library::to_string(operation_desc.B.layout)); -+ -+ set_argument(result, "C", problem_space, -+ std::string(library::to_string(operation_desc.C.element)) + ":" + library::to_string(operation_desc.C.layout)); -+ -+ set_argument(result, "fill_mode", problem_space, library::to_string(operation_desc.fill_mode)); -+ -+ set_argument(result, "blas_mode", problem_space, library::to_string(operation_desc.blas_mode)); -+ -+ set_argument(result, "n", problem_space, n); -+ set_argument(result, "k", problem_space, k); -+ -+ set_argument(result, "split_k_slices", problem_space, split_k_slices); -+ set_argument(result, "batch_count", problem_space, batch_count); -+ -+ set_argument(result, "alpha", problem_space, -+ library::lexical_cast(alpha, operation_desc.element_epilogue)); -+ -+ set_argument(result, "beta", problem_space, -+ library::lexical_cast(beta, operation_desc.element_epilogue)); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Extracts the problem dimensions -+Status Rank2KOperationProfiler::initialize_configuration( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ library::RankKDescription const &operation_desc = -+ static_cast(operation->description()); -+ -+ if (operation_desc.rank_k_kind != library::RankKKind::kUniversal) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ Status status = problem_.parse(operation_desc, problem_space, problem); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ rank_k_workspace_.configuration.problem_size.m() = int(problem_.n); -+ rank_k_workspace_.configuration.problem_size.n() = int(problem_.n); -+ rank_k_workspace_.configuration.problem_size.k() = int(problem_.k); -+ rank_k_workspace_.configuration.lda = problem_.lda; -+ rank_k_workspace_.configuration.ldb = problem_.ldb; -+ rank_k_workspace_.configuration.ldc = problem_.ldc; -+ rank_k_workspace_.configuration.ldd = problem_.ldc; -+ //rank_k_workspace_.configuration.split_k_slices = int(problem_.split_k_slices); -+ rank_k_workspace_.configuration.batch_count = int(problem_.split_k_slices); -+ -+ rank_k_workspace_.arguments.A = nullptr; -+ rank_k_workspace_.arguments.B = nullptr; -+ rank_k_workspace_.arguments.C = nullptr; -+ rank_k_workspace_.arguments.D = nullptr; -+ rank_k_workspace_.arguments.alpha = problem_.alpha.data(); -+ rank_k_workspace_.arguments.beta = problem_.beta.data(); -+ rank_k_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ initialize_result_(this->model_result_, options, operation_desc, problem_space); -+ -+ return operation->can_implement(&rank_k_workspace_.configuration, &rank_k_workspace_.arguments); -+} -+ -+/// Initializes the performance result -+void Rank2KOperationProfiler::initialize_result_( -+ PerformanceResult &result, -+ Options const &options, -+ library::RankKDescription const &operation_desc, -+ ProblemSpace const &problem_space) { -+ -+ result.provider = library::Provider::kCUTLASS; -+ result.disposition = Disposition::kNotRun; -+ result.status = Status::kSuccess; -+ result.operation_name = operation_desc.name; -+ -+ problem_.initialize_result(result, operation_desc, problem_space); -+ -+ OperationProfiler::initialize_result_(result, operation_desc, problem_space); -+ -+ -+ result.bytes = problem_.bytes(operation_desc); -+ result.flops = problem_.flops(operation_desc); -+ result.runtime = 0; -+ -+ -+} -+ -+/// Initializes workspace -+Status Rank2KOperationProfiler::initialize_workspace( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ library::RankKDescription const &operation_desc = -+ static_cast(operation->description()); -+ -+ if (options.execution_mode != ExecutionMode::kDryRun) { -+ -+ rank_k_workspace_.A = device_context.allocate_tensor( -+ options, -+ "A", -+ operation_desc.A.element, -+ operation_desc.A.layout, -+ {int(problem_.n), int(problem_.k)}, -+ {int(problem_.lda)} -+ ); -+ -+ rank_k_workspace_.B = device_context.allocate_tensor( -+ options, -+ "B", -+ operation_desc.B.element, -+ operation_desc.B.layout, -+ {int(problem_.n), int(problem_.k)}, -+ {int(problem_.ldb)} -+ ); -+ -+ rank_k_workspace_.C = device_context.allocate_tensor( -+ options, -+ "C", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ {int(problem_.n), int(problem_.n)}, -+ {int(problem_.ldc)}, -+ 1 // batch_count = 1, default -+ ); -+ -+ rank_k_workspace_.Computed = device_context.allocate_tensor( -+ "D", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ {int(problem_.n), int(problem_.n)}, -+ {int(problem_.ldc)} -+ ); -+ -+ rank_k_workspace_.Reference = device_context.allocate_tensor( -+ "Reference", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ {int(problem_.n), int(problem_.n)}, -+ {int(problem_.ldc)} -+ ); -+ -+ rank_k_workspace_.Computed->copy_from_device(rank_k_workspace_.C->data()); -+ rank_k_workspace_.Reference->copy_from_device(rank_k_workspace_.C->data()); -+ } -+ -+ -+ // -+ // Initialize the CUTLASS operation -+ // -+ Status status = Status::kSuccess; -+ -+ if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ -+ if (options.execution_mode != ExecutionMode::kDryRun) { -+ -+ uint64_t workspace_size = operation->get_host_workspace_size(&rank_k_workspace_.configuration); -+ rank_k_workspace_.host_workspace.resize(workspace_size, 0); -+ -+ workspace_size = operation->get_device_workspace_size(&rank_k_workspace_.configuration); -+ rank_k_workspace_.device_workspace.reset(library::NumericTypeID::kU8, workspace_size); -+ -+ status = operation->initialize( -+ &rank_k_workspace_.configuration, -+ rank_k_workspace_.host_workspace.data(), -+ rank_k_workspace_.device_workspace.data()); -+ } -+ -+ // -+ // If CUTLASS is enabled, generate a result for it -+ // -+ results_.push_back(model_result_); -+ results_.back().provider = library::Provider::kCUTLASS; -+ results_.back().op_kind = library::OperationKind::kRank2K; -+ results_.back().disposition = Disposition::kNotRun; -+ -+ for(auto provider : verification_providers_) { -+ results_.back().verification_map[provider] = Disposition::kNotRun; -+ } -+ } -+ -+ return status; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Verifies CUTLASS against references -+bool Rank2KOperationProfiler::verify_cutlass( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ if (!options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ return true; -+ } -+ -+ if (options.execution_mode == ExecutionMode::kDryRun) { -+ return true; -+ } -+ -+ // Initialize structure containing RankK arguments -+ rank_k_workspace_.arguments.A = rank_k_workspace_.A->data(); -+ rank_k_workspace_.arguments.B = rank_k_workspace_.B->data(); -+ rank_k_workspace_.arguments.C = rank_k_workspace_.C->data(); -+ rank_k_workspace_.arguments.D = rank_k_workspace_.Computed->data(); -+ rank_k_workspace_.arguments.alpha = problem_.alpha.data(); -+ rank_k_workspace_.arguments.beta = problem_.beta.data(); -+ rank_k_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ // -+ // Run the CUTLASS operation -+ // -+ -+ results_.back().status = operation->run( -+ &rank_k_workspace_.arguments, -+ rank_k_workspace_.host_workspace.data(), -+ rank_k_workspace_.device_workspace.data()); -+ -+ if (results_.back().status != Status::kSuccess) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ if (result != cudaSuccess) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ -+ // CUTLASS op ran the but not yet verified against any verification provider -+ results_.back().disposition = Disposition::kNotVerified; -+ -+ // -+ // Run verification providers -+ // -+ -+ if (options.verification.enabled) { -+ -+#if CUTLASS_ENABLE_CUBLAS -+ if (options.verification.provider_enabled(library::Provider::kCUBLAS)) { -+ -+ // Guard against unsupported cases -+ auto const & rank_k_desc = static_cast(operation->description()); -+ -+ if (cublas_satisfies(rank_k_desc) == Status::kSuccess) { -+ -+ // call cublas verification if supported -+ verify_with_cublas_( -+ options, -+ report, -+ device_context, -+ operation, -+ problem_space, -+ problem); -+ } -+ -+ else { -+ // set verification map for cublas to not supported -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kNotSupported; -+ } -+ } -+#endif // #if CUTLASS_ENABLE_CUBLAS -+ -+ // Update disposition to worst case verification outcome among all -+ // verification providers which are supported -+ bool is_any_verification_run_passed = false; -+ for(auto &m : results_.back().verification_map) { -+ if(m.second == Disposition::kFailed || m.second == Disposition::kIncorrect) { -+ results_.back().disposition = m.second; -+ return true; -+ } -+ if(!is_any_verification_run_passed && m.second == Disposition::kPassed) { -+ is_any_verification_run_passed = true; -+ } -+ } -+ -+ if(is_any_verification_run_passed) { -+ results_.back().disposition = Disposition::kPassed; -+ } -+ } -+ -+ // Return true means continue profiling -+ return true; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Verifies CUTLASS against references -+bool Rank2KOperationProfiler::verify_with_cublas_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ -+#if CUTLASS_ENABLE_CUBLAS -+ -+ library::RankKDescription const &rank_k_desc = -+ static_cast(operation->description()); -+ -+ // -+ // Construct cuBLAS operators -+ // -+ -+ CublasCreate handle; -+ cublasStatus_t status = handle.get_cublas_create_status(); -+ -+ if (status != CUBLAS_STATUS_SUCCESS) { -+ -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed; -+ return true; -+ } -+ -+ // -+ // Initialize state -+ // -+ -+ try { -+ -+ // -+ // Construct dispatcher to cublasSyr2k() -+ // -+ -+ // Initialize structure containing RankK arguments -+ rank_k_workspace_.arguments.A = rank_k_workspace_.A->data(); -+ rank_k_workspace_.arguments.B = rank_k_workspace_.B->data(); -+ rank_k_workspace_.arguments.C = rank_k_workspace_.Reference->data(); -+ rank_k_workspace_.arguments.D = rank_k_workspace_.Reference->data(); -+ rank_k_workspace_.arguments.alpha = problem_.alpha.data(); -+ rank_k_workspace_.arguments.beta = problem_.beta.data(); -+ rank_k_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ detail::cublasRankKDispatcher rank_k_op( -+ rank_k_desc, -+ rank_k_workspace_.configuration, -+ rank_k_workspace_.arguments -+ ); -+ -+ if (rank_k_op.status != Status::kSuccess) { -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kNotRun; -+ return true; -+ } -+ -+ results_.back().status = Status::kSuccess; -+ -+ status = rank_k_op(handle); -+ -+ // Handle errors -+ if (status != CUBLAS_STATUS_SUCCESS) { -+ -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed; -+ return true; -+ } -+ -+ // -+ // Verify results -+ // -+ -+ results_.back().verification_map[library::Provider::kCUBLAS] = compare_tensors( -+ options, -+ *rank_k_workspace_.Computed, -+ *rank_k_workspace_.Reference -+ ); -+ -+ // Save workspace if incorrect -+ if (options.verification.save_workspace == SaveWorkspace::kIncorrect && -+ results_.back().verification_map[library::Provider::kCUBLAS] == Disposition::kIncorrect) { -+ -+ save_workspace( -+ device_context, -+ options, -+ rank_k_desc, -+ library::Provider::kCUTLASS, -+ library::Provider::kCUBLAS); -+ } -+ } -+ catch (...) { -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed; -+ } -+ -+#endif -+ -+ // Return true means continue profiling -+ return true; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Measures performance results -+bool Rank2KOperationProfiler::profile( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ -+ // Initialize structure containing RankK arguments -+ rank_k_workspace_.arguments.A = rank_k_workspace_.A->data(); -+ rank_k_workspace_.arguments.B = rank_k_workspace_.B->data(); -+ rank_k_workspace_.arguments.C = rank_k_workspace_.C->data(); -+ rank_k_workspace_.arguments.D = rank_k_workspace_.Computed->data(); -+ rank_k_workspace_.arguments.alpha = problem_.alpha.data(); -+ rank_k_workspace_.arguments.beta = problem_.beta.data(); -+ rank_k_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ results_.back().status = profile_cutlass_( -+ results_.back().runtime, -+ options, -+ operation, -+ &rank_k_workspace_.arguments, -+ rank_k_workspace_.host_workspace.data(), -+ rank_k_workspace_.device_workspace.data() -+ ); -+ } -+ return true; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/profiler/src/rank_2k_operation_profiler.h b/3rdparty/cutlass/tools/profiler/src/rank_2k_operation_profiler.h -new file mode 100644 -index 0000000..6dbfc3f ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/rank_2k_operation_profiler.h -@@ -0,0 +1,229 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines a math function -+ -+ -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+#include -+#include -+ -+// CUTLASS Library includes -+#include "cutlass/blas3.h" -+#include "cutlass/library/library.h" -+#include "cutlass/library/util.h" -+#include "cutlass/library/manifest.h" -+ -+// Profiler includes -+#include "options.h" -+#include "device_context.h" -+#include "operation_profiler.h" -+#include "performance_result.h" -+#include "problem_space.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Abstract base class for each math function -+class Rank2KOperationProfiler : public OperationProfiler { -+public: -+ -+ /// Problem structure obtained from problem space -+ struct RankKProblem { -+ int64_t n; -+ int64_t k; -+ int64_t lda; -+ int64_t ldb; -+ int64_t ldc; -+ FillMode fill_mode; -+ BlasMode blas_mode; -+ std::vector alpha; -+ std::vector beta; -+ int64_t split_k_slices; -+ int64_t batch_count; -+ -+ // -+ // Methods -+ // -+ -+ RankKProblem(): -+ n(16), k(16), lda(0), ldc(0), -+ fill_mode(FillMode::kInvalid), blas_mode(BlasMode::kInvalid), -+ split_k_slices(1), batch_count(1) { } -+ -+ /// Parses the problem -+ Status parse( -+ library::RankKDescription const &operation_desc, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Total number of bytes loaded -+ int64_t bytes(library::RankKDescription const &operation_desc) const; -+ -+ /// Total number of flops computed -+ int64_t flops(library::RankKDescription const &operation_desc) const; -+ -+ /// Initializes a performance result -+ void initialize_result( -+ PerformanceResult &result, -+ library::RankKDescription const &operation_desc, -+ ProblemSpace const &problem_space); -+ }; -+ -+ /// Workspace used -+ struct RankKWorkspace { -+ -+ DeviceAllocation *A; -+ DeviceAllocation *B; -+ DeviceAllocation *C; -+ DeviceAllocation *Computed; -+ DeviceAllocation *Reference; -+ -+ library::RankKConfiguration configuration; -+ library::RankKArguments arguments; -+ -+ /// Buffer used for the operation's host workspace -+ std::vector host_workspace; -+ -+ /// Buffer used for the operations' device workspace -+ DeviceAllocation device_workspace; -+ -+ // -+ // Methods -+ // -+ -+ RankKWorkspace(): -+ A(nullptr), B(nullptr), C(nullptr), Computed(nullptr), Reference(nullptr) { } -+ }; -+ -+protected: -+ -+ // -+ // Data members -+ // -+ -+ /// GEMM problem obtained from problem space -+ RankKProblem problem_; -+ -+ /// Device memory allocations -+ RankKWorkspace rank_k_workspace_; -+ -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ Rank2KOperationProfiler(Options const &options); -+ -+ /// Destructor -+ virtual ~Rank2KOperationProfiler(); -+ -+ /// Prints usage statement for the math function -+ virtual void print_usage(std::ostream &out) const; -+ -+ /// Prints examples -+ virtual void print_examples(std::ostream &out) const; -+ -+ /// Extracts the problem dimensions -+ virtual Status initialize_configuration( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Initializes workspace -+ virtual Status initialize_workspace( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Verifies CUTLASS against references -+ virtual bool verify_cutlass( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Measures performance results -+ virtual bool profile( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+protected: -+ -+ /// Initializes the performance result -+ void initialize_result_( -+ PerformanceResult &result, -+ Options const &options, -+ library::RankKDescription const &operation_desc, -+ ProblemSpace const &problem_space); -+ -+ /// Verifies CUTLASS against references -+ bool verify_with_cublas_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/profiler/src/rank_k_operation_profiler.cu b/3rdparty/cutlass/tools/profiler/src/rank_k_operation_profiler.cu -new file mode 100644 -index 0000000..7e452e7 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/rank_k_operation_profiler.cu -@@ -0,0 +1,715 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Execution environment -+ -+ -+*/ -+ -+#include -+#include -+#include -+#include -+ -+#include "cutlass/core_io.h" -+ -+#include "cublas_helpers.h" -+#include "rank_k_operation_profiler.h" -+#include "gpu_timer.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Ctor -+RankKOperationProfiler::RankKOperationProfiler(Options const &options): -+ OperationProfiler( -+ options, -+ library::OperationKind::kRankK, -+ { -+ {ArgumentTypeID::kEnumerated, {"rank_k_kind"}, "Variant of RankK (universal)"}, -+ {ArgumentTypeID::kInteger, {"n", "problem-size::n"}, "N dimension of the RankK problem space"}, -+ {ArgumentTypeID::kInteger, {"k", "problem-size::k"}, "K dimension of the RankK problem space"}, -+ {ArgumentTypeID::kTensor, {"A"}, "Tensor storing the A operand"}, -+ {ArgumentTypeID::kTensor, {"C"}, "Tensor storing the C operand"}, -+ {ArgumentTypeID::kEnumerated, {"fill_mode"}, "Fill Mode for RankK kernel (lower or upper)"}, -+ {ArgumentTypeID::kEnumerated, {"blas_mode"}, "Blas Mode for RankK kernel (symmetric or hermitian)"}, -+ {ArgumentTypeID::kScalar, {"alpha", "epilogue::alpha"}, "Epilogue scalar alpha"}, -+ {ArgumentTypeID::kScalar, {"beta", "epilogue::beta"}, "Epilogue scalar beta"}, -+ {ArgumentTypeID::kInteger, {"split_k_slices", "split-k-slices"}, "Number of partitions of K dimension"}, -+ {ArgumentTypeID::kInteger, {"batch_count", "batch-count"}, "Number of RankK computed in one batch"}, -+ }, -+ { library::Provider::kCUBLAS} -+ ) { -+ description_ = " Rank-k Update. D = alpha * A*A^T + beta * C (symmetric) or D = alpha * A*A^H + beta * C (hermitian)"; -+} -+ -+/// Destructor -+RankKOperationProfiler::~RankKOperationProfiler() { -+ -+} -+ -+/// Prints usage statement for the math function -+void RankKOperationProfiler::print_usage(std::ostream &out) const { -+ out << "RankK" << "\n\n"; -+ -+ OperationProfiler::print_usage(out); -+} -+ -+/// Prints examples -+void RankKOperationProfiler::print_examples(std::ostream &out) const { -+ -+ out << "\nExamples:\n\n" -+ << "Profile a particular problem size Syrk kernel:\n" -+ << " $ cutlass_profiler --operation=rank_k --blas_mode=symmetric --n=1024 --k=128\n\n" -+ -+ << "Profile a particular problem size Herk kernel:\n" -+ << " $ cutlass_profiler --operation=rank_k --blas_mode=hermitian --n=1024 --k=128\n\n" -+ -+ << "Schmoo over problem size and beta:\n" -+ << " $ cutlass_profiler --operation=rank_k --n=1024:4096:256 --k=128:8192:128 --beta=0,1,2.5\n\n" -+ -+ << "Schmoo over accumulator types:\n" -+ << " $ cutlass_profiler --operation=rank_k --accumulator-type=f16,f32\n\n" -+ -+ << "Schmoo over fill modees:\n" -+ << " $ cutlass_profiler --operation=rank_k --fill_mode=lower/upper\n\n" -+ -+ << "Run when A is f16 with column-major or A is any datatype with row-major (For column major, use column, col, or n. For row major use, row or t):\n" -+ << " $ cutlass_profiler --operation=rank_k --A=f16:column or --A=*:row\n\n" -+ -+ << "Using various input value distribution:\n" -+ << " $ cutlass_profiler --operation=rank_k --dist=uniform,min:0,max:3\n" -+ << " $ cutlass_profiler --operation=rank_k --dist=gaussian,mean:0,stddev:3\n" -+ << " $ cutlass_profiler --operation=rank_k --dist=sequential,start:0,delta:1\n\n" -+ -+ << "Run a kernel with cta tile size of 256x128x32 and save workspace if results are incorrect (note that --cta-tile::k=32 is default cta-tile size):\n" -+ << " $ cutlass_profiler --operation=rank_k --cta_m=256 --cta_n=128 --cta_k=32 --save-workspace=incorrect\n\n" -+ -+ << "Test your changes to rank_k kernels with a quick functional test and save results in functional-test.csv:\n" -+ << " $ cutlass_profiler --operation=rank_k \\ \n" -+ << " --n=8,56,120,136,256,264,512,520,1024,1032,4096,8192,16384 \\ \n" -+ << " --k=8,16,32,64,128,256,288,384,504,512,520 \\ \n" -+ << " --beta=0,1,2 --profiling-iterations=1 \\ \n" -+ << " --providers=cutlass --output=functional-test.csv\n\n"; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if 0 -+// used this for debugging -+static std::string byte_string(std::vector const &bytes) { -+ std::stringstream ss; -+ -+ ss << "0x"; -+ -+ for (size_t idx = bytes.size(); idx > 0; --idx) { -+ ss << std::hex << std::setw(2) << std::setfill('0') << uint32_t(bytes.at(idx - 1)); -+ } -+ -+ return ss.str(); -+} -+#endif -+ -+Status RankKOperationProfiler::RankKProblem::parse( -+ library::RankKDescription const &operation_desc, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ if (!arg_as_int(this->n, "n", problem_space, problem)) { -+ // default value -+ this->n = 1024; -+ } -+ -+ if (!arg_as_int(this->k, "k", problem_space, problem)) { -+ // default value -+ this->k = 1024; -+ } -+ -+ if (!arg_as_int(this->split_k_slices, "split_k_slices", problem_space, problem)) { -+ // default value -+ this->split_k_slices = 1; -+ } -+ -+ if (!arg_as_int(this->batch_count, "batch_count", problem_space, problem)) { -+ // default value -+ this->batch_count = 1; -+ } -+ -+ if (this->split_k_slices > 1 && this->batch_count > 1) { -+ // At least one of these must be one -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.A, "A", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.C, "C", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!arg_as_scalar( -+ this->alpha, -+ operation_desc.element_epilogue, -+ "alpha", -+ problem_space, -+ problem)) { -+ -+ if (!cast_from_double(this->alpha, operation_desc.element_epilogue, 1)) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ if (!arg_as_scalar( -+ this->beta, -+ operation_desc.element_epilogue, -+ "beta", -+ problem_space, -+ problem)) { -+ -+ if (!cast_from_double(this->beta, operation_desc.element_epilogue, 0)) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ this->lda = DeviceAllocation::get_packed_layout( -+ operation_desc.A.layout, {int(this->n), int(this->k)}).front(); -+ -+ this->ldc = DeviceAllocation::get_packed_layout( -+ operation_desc.C.layout, {int(this->n), int(this->n)}).front(); -+ -+ return Status::kSuccess; -+} -+ -+/// Total number of bytes loaded -+int64_t RankKOperationProfiler::RankKProblem::bytes(library::RankKDescription const &operation_desc) const { -+ // Input bytes read and Output bytes written for the gemm problem -+ int64_t bytes = -+ int64_t(library::sizeof_bits(operation_desc.A.element) * n / 8) * k + -+ int64_t(library::sizeof_bits(operation_desc.A.element) * n / 8) * k + -+ // Half matrix including the diagonal will have (N*(N+1))/2 elements -+ int64_t(library::sizeof_bits(operation_desc.C.element) * n / 8) * (n+1) / 2; -+ -+ // Set is_beta_zero true if beta is zero -+ bool is_beta_zero = std::all_of(beta.begin(), beta.end(), [](uint8_t i) { return i==0; }); -+ -+ // Output bytes read for the gemm problem for non-zero beta values -+ if (!is_beta_zero) { -+ bytes += int64_t(library::sizeof_bits(operation_desc.C.element) * n / 8) * (n+1) / 2; -+ } -+ -+ bytes *= batch_count; -+ -+ return bytes; -+} -+ -+/// Total number of flops computed -+int64_t RankKOperationProfiler::RankKProblem::flops(library::RankKDescription const &operation_desc) const { -+ -+ // FLOPs = 2 * n(n+1)k/2 [mma] + 2 * n(n+1)/2 [epilogue] -+ // FLOPs = n(n+1)(k + 1) -+ int64_t flops_ = n * (n + 1) * (k + 1); -+ -+ // complex-valued support -+ switch (operation_desc.tile_description.math_instruction.math_operation) { -+ case library::MathOperationID::kMultiplyAddComplex: -+ flops_ *= 4; -+ break; -+ -+ case library::MathOperationID::kMultiplyAddComplexFastF32: -+ flops_ *= 4; -+ break; -+ -+ case library::MathOperationID::kMultiplyAddGaussianComplex: -+ flops_ *= 3; -+ break; -+ -+ default: break; -+ } -+ -+ return flops_; -+} -+ -+/// Initializes a performance result -+void RankKOperationProfiler::RankKProblem::initialize_result( -+ PerformanceResult &result, -+ library::RankKDescription const &operation_desc, -+ ProblemSpace const &problem_space) { -+ -+ result.arguments.resize(problem_space.rank()); -+ -+ set_argument(result, "rank_k_kind", problem_space, library::to_string(operation_desc.rank_k_kind)); -+ -+ set_argument(result, "A", problem_space, -+ std::string(library::to_string(operation_desc.A.element)) + ":" + library::to_string(operation_desc.A.layout)); -+ -+ set_argument(result, "C", problem_space, -+ std::string(library::to_string(operation_desc.C.element)) + ":" + library::to_string(operation_desc.C.layout)); -+ -+ set_argument(result, "fill_mode", problem_space, library::to_string(operation_desc.fill_mode)); -+ -+ set_argument(result, "blas_mode", problem_space, library::to_string(operation_desc.blas_mode)); -+ -+ set_argument(result, "n", problem_space, n); -+ set_argument(result, "k", problem_space, k); -+ -+ set_argument(result, "split_k_slices", problem_space, split_k_slices); -+ set_argument(result, "batch_count", problem_space, batch_count); -+ -+ set_argument(result, "alpha", problem_space, -+ library::lexical_cast(alpha, operation_desc.element_epilogue)); -+ -+ set_argument(result, "beta", problem_space, -+ library::lexical_cast(beta, operation_desc.element_epilogue)); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Extracts the problem dimensions -+Status RankKOperationProfiler::initialize_configuration( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ library::RankKDescription const &operation_desc = -+ static_cast(operation->description()); -+ -+ if (operation_desc.rank_k_kind != library::RankKKind::kUniversal) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ Status status = problem_.parse(operation_desc, problem_space, problem); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ rank_k_workspace_.configuration.problem_size.m() = int(problem_.n); -+ rank_k_workspace_.configuration.problem_size.n() = int(problem_.n); -+ rank_k_workspace_.configuration.problem_size.k() = int(problem_.k); -+ rank_k_workspace_.configuration.lda = problem_.lda; -+ rank_k_workspace_.configuration.ldc = problem_.ldc; -+ rank_k_workspace_.configuration.ldd = problem_.ldc; -+ //rank_k_workspace_.configuration.split_k_slices = int(problem_.split_k_slices); -+ rank_k_workspace_.configuration.batch_count = int(problem_.split_k_slices); -+ -+ rank_k_workspace_.arguments.A = nullptr; -+ rank_k_workspace_.arguments.C = nullptr; -+ rank_k_workspace_.arguments.D = nullptr; -+ rank_k_workspace_.arguments.alpha = problem_.alpha.data(); -+ rank_k_workspace_.arguments.beta = problem_.beta.data(); -+ rank_k_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ initialize_result_(this->model_result_, options, operation_desc, problem_space); -+ -+ return operation->can_implement(&rank_k_workspace_.configuration, &rank_k_workspace_.arguments); -+} -+ -+/// Initializes the performance result -+void RankKOperationProfiler::initialize_result_( -+ PerformanceResult &result, -+ Options const &options, -+ library::RankKDescription const &operation_desc, -+ ProblemSpace const &problem_space) { -+ -+ result.provider = library::Provider::kCUTLASS; -+ result.disposition = Disposition::kNotRun; -+ result.status = Status::kSuccess; -+ result.operation_name = operation_desc.name; -+ -+ problem_.initialize_result(result, operation_desc, problem_space); -+ -+ OperationProfiler::initialize_result_(result, operation_desc, problem_space); -+ -+ -+ result.bytes = problem_.bytes(operation_desc); -+ result.flops = problem_.flops(operation_desc); -+ -+ result.runtime = 0; -+ -+ // complex-valued support -+ switch (operation_desc.tile_description.math_instruction.math_operation) { -+ case library::MathOperationID::kMultiplyAddComplex: -+ result.flops *= 4; -+ break; -+ -+ case library::MathOperationID::kMultiplyAddComplexFastF32: -+ result.flops *= 4; -+ break; -+ -+ default: break; -+ } -+ -+} -+ -+/// Initializes workspace -+Status RankKOperationProfiler::initialize_workspace( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ library::RankKDescription const &operation_desc = -+ static_cast(operation->description()); -+ -+ if (options.execution_mode != ExecutionMode::kDryRun) { -+ -+ rank_k_workspace_.A = device_context.allocate_tensor( -+ options, -+ "A", -+ operation_desc.A.element, -+ operation_desc.A.layout, -+ {int(problem_.n), int(problem_.k)}, -+ {int(problem_.lda)} -+ ); -+ -+ rank_k_workspace_.C = device_context.allocate_tensor( -+ options, -+ "C", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ {int(problem_.n), int(problem_.n)}, -+ {int(problem_.ldc)}, -+ 1 // batch_count = 1, default -+ ); -+ -+ rank_k_workspace_.Computed = device_context.allocate_tensor( -+ "D", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ {int(problem_.n), int(problem_.n)}, -+ {int(problem_.ldc)} -+ ); -+ -+ rank_k_workspace_.Reference = device_context.allocate_tensor( -+ "Reference", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ {int(problem_.n), int(problem_.n)}, -+ {int(problem_.ldc)} -+ ); -+ -+ rank_k_workspace_.Computed->copy_from_device(rank_k_workspace_.C->data()); -+ rank_k_workspace_.Reference->copy_from_device(rank_k_workspace_.C->data()); -+ } -+ -+ -+ // -+ // Initialize the CUTLASS operation -+ // -+ Status status = Status::kSuccess; -+ -+ if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ -+ if (options.execution_mode != ExecutionMode::kDryRun) { -+ -+ uint64_t workspace_size = operation->get_host_workspace_size(&rank_k_workspace_.configuration); -+ rank_k_workspace_.host_workspace.resize(workspace_size, 0); -+ -+ workspace_size = operation->get_device_workspace_size(&rank_k_workspace_.configuration); -+ rank_k_workspace_.device_workspace.reset(library::NumericTypeID::kU8, workspace_size); -+ -+ status = operation->initialize( -+ &rank_k_workspace_.configuration, -+ rank_k_workspace_.host_workspace.data(), -+ rank_k_workspace_.device_workspace.data()); -+ } -+ -+ // -+ // If CUTLASS is enabled, generate a result for it -+ // -+ results_.push_back(model_result_); -+ results_.back().provider = library::Provider::kCUTLASS; -+ results_.back().op_kind = library::OperationKind::kRankK; -+ results_.back().disposition = Disposition::kNotRun; -+ -+ for(auto provider : verification_providers_) { -+ results_.back().verification_map[provider] = Disposition::kNotRun; -+ } -+ } -+ -+ return status; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Verifies CUTLASS against references -+bool RankKOperationProfiler::verify_cutlass( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ if (!options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ return true; -+ } -+ -+ if (options.execution_mode == ExecutionMode::kDryRun) { -+ return true; -+ } -+ -+ // Initialize structure containing RankK arguments -+ rank_k_workspace_.arguments.A = rank_k_workspace_.A->data(); -+ rank_k_workspace_.arguments.C = rank_k_workspace_.C->data(); -+ rank_k_workspace_.arguments.D = rank_k_workspace_.Computed->data(); -+ rank_k_workspace_.arguments.alpha = problem_.alpha.data(); -+ rank_k_workspace_.arguments.beta = problem_.beta.data(); -+ rank_k_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ // -+ // Run the CUTLASS operation -+ // -+ -+ results_.back().status = operation->run( -+ &rank_k_workspace_.arguments, -+ rank_k_workspace_.host_workspace.data(), -+ rank_k_workspace_.device_workspace.data()); -+ -+ if (results_.back().status != Status::kSuccess) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ if (result != cudaSuccess) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ -+ // CUTLASS op ran the but not yet verified against any verification provider -+ results_.back().disposition = Disposition::kNotVerified; -+ -+ // -+ // Run verification providers -+ // -+ -+ if (options.verification.enabled) { -+ -+#if CUTLASS_ENABLE_CUBLAS -+ if (options.verification.provider_enabled(library::Provider::kCUBLAS)) { -+ -+ // Guard against unsupported cases -+ auto const & rank_k_desc = static_cast(operation->description()); -+ -+ if (cublas_satisfies(rank_k_desc) == Status::kSuccess) { -+ -+ // call cublas verification if supported -+ verify_with_cublas_( -+ options, -+ report, -+ device_context, -+ operation, -+ problem_space, -+ problem); -+ } -+ -+ else { -+ // set verification map for cublas to not supported -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kNotSupported; -+ } -+ } -+#endif // #if CUTLASS_ENABLE_CUBLAS -+ -+ // Update disposition to worst case verification outcome among all -+ // verification providers which are supported -+ bool is_any_verification_run_passed = false; -+ for(auto &m : results_.back().verification_map) { -+ if(m.second == Disposition::kFailed || m.second == Disposition::kIncorrect) { -+ results_.back().disposition = m.second; -+ return true; -+ } -+ if(!is_any_verification_run_passed && m.second == Disposition::kPassed) { -+ is_any_verification_run_passed = true; -+ } -+ } -+ -+ if(is_any_verification_run_passed) { -+ results_.back().disposition = Disposition::kPassed; -+ } -+ } -+ -+ // Return true means continue profiling -+ return true; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Verifies CUTLASS against references -+bool RankKOperationProfiler::verify_with_cublas_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ -+#if CUTLASS_ENABLE_CUBLAS -+ -+ library::RankKDescription const &rank_k_desc = -+ static_cast(operation->description()); -+ -+ // -+ // Construct cuBLAS operators -+ // -+ -+ CublasCreate handle; -+ cublasStatus_t status = handle.get_cublas_create_status(); -+ -+ if (status != CUBLAS_STATUS_SUCCESS) { -+ -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed; -+ return true; -+ } -+ -+ // -+ // Initialize state -+ // -+ -+ try { -+ -+ // -+ // Construct dispatcher to cublasSyrk() -+ // -+ -+ // Initialize structure containing RankK arguments -+ rank_k_workspace_.arguments.A = rank_k_workspace_.A->data(); -+ rank_k_workspace_.arguments.C = rank_k_workspace_.Reference->data(); -+ rank_k_workspace_.arguments.D = rank_k_workspace_.Reference->data(); -+ rank_k_workspace_.arguments.alpha = problem_.alpha.data(); -+ rank_k_workspace_.arguments.beta = problem_.beta.data(); -+ rank_k_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ detail::cublasRankKDispatcher rank_k_op( -+ rank_k_desc, -+ rank_k_workspace_.configuration, -+ rank_k_workspace_.arguments -+ ); -+ -+ if (rank_k_op.status != Status::kSuccess) { -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kNotRun; -+ return true; -+ } -+ -+ results_.back().status = Status::kSuccess; -+ -+ status = rank_k_op(handle); -+ -+ // Handle errors -+ if (status != CUBLAS_STATUS_SUCCESS) { -+ -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed; -+ return true; -+ } -+ -+ // -+ // Verify results -+ // -+ -+ results_.back().verification_map[library::Provider::kCUBLAS] = compare_tensors( -+ options, -+ *rank_k_workspace_.Computed, -+ *rank_k_workspace_.Reference -+ ); -+ -+ // Save workspace if incorrect -+ if (options.verification.save_workspace == SaveWorkspace::kIncorrect && -+ results_.back().verification_map[library::Provider::kCUBLAS] == Disposition::kIncorrect) { -+ -+ save_workspace( -+ device_context, -+ options, -+ rank_k_desc, -+ library::Provider::kCUTLASS, -+ library::Provider::kCUBLAS); -+ } -+ } -+ catch (...) { -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed; -+ } -+ -+#endif -+ -+ // Return true means continue profiling -+ return true; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Measures performance results -+bool RankKOperationProfiler::profile( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ -+ // Initialize structure containing RankK arguments -+ rank_k_workspace_.arguments.A = rank_k_workspace_.A->data(); -+ rank_k_workspace_.arguments.C = rank_k_workspace_.C->data(); -+ rank_k_workspace_.arguments.D = rank_k_workspace_.Computed->data(); -+ rank_k_workspace_.arguments.alpha = problem_.alpha.data(); -+ rank_k_workspace_.arguments.beta = problem_.beta.data(); -+ rank_k_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ results_.back().status = profile_cutlass_( -+ results_.back().runtime, -+ options, -+ operation, -+ &rank_k_workspace_.arguments, -+ rank_k_workspace_.host_workspace.data(), -+ rank_k_workspace_.device_workspace.data() -+ ); -+ } -+ return true; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/profiler/src/rank_k_operation_profiler.h b/3rdparty/cutlass/tools/profiler/src/rank_k_operation_profiler.h -new file mode 100644 -index 0000000..779509a ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/rank_k_operation_profiler.h -@@ -0,0 +1,227 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines a math function -+ -+ -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+#include -+#include -+ -+// CUTLASS Library includes -+#include "cutlass/blas3.h" -+#include "cutlass/library/library.h" -+#include "cutlass/library/util.h" -+#include "cutlass/library/manifest.h" -+ -+// Profiler includes -+#include "options.h" -+#include "device_context.h" -+#include "operation_profiler.h" -+#include "performance_result.h" -+#include "problem_space.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Abstract base class for each math function -+class RankKOperationProfiler : public OperationProfiler { -+public: -+ -+ /// Problem structure obtained from problem space -+ struct RankKProblem { -+ int64_t n; -+ int64_t k; -+ int64_t lda; -+ int64_t ldc; -+ FillMode fill_mode; -+ BlasMode blas_mode; -+ std::vector alpha; -+ std::vector beta; -+ int64_t split_k_slices; -+ int64_t batch_count; -+ -+ // -+ // Methods -+ // -+ -+ RankKProblem(): -+ n(16), k(16), lda(0), ldc(0), -+ fill_mode(FillMode::kInvalid), blas_mode(BlasMode::kInvalid), -+ split_k_slices(1), batch_count(1) { } -+ -+ /// Parses the problem -+ Status parse( -+ library::RankKDescription const &operation_desc, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Total number of bytes loaded -+ int64_t bytes(library::RankKDescription const &operation_desc) const; -+ -+ /// Total number of flops computed -+ int64_t flops(library::RankKDescription const &operation_desc) const; -+ -+ /// Initializes a performance result -+ void initialize_result( -+ PerformanceResult &result, -+ library::RankKDescription const &operation_desc, -+ ProblemSpace const &problem_space); -+ }; -+ -+ /// Workspace used -+ struct RankKWorkspace { -+ -+ DeviceAllocation *A; -+ DeviceAllocation *C; -+ DeviceAllocation *Computed; -+ DeviceAllocation *Reference; -+ -+ library::RankKConfiguration configuration; -+ library::RankKArguments arguments; -+ -+ /// Buffer used for the operation's host workspace -+ std::vector host_workspace; -+ -+ /// Buffer used for the operations' device workspace -+ DeviceAllocation device_workspace; -+ -+ // -+ // Methods -+ // -+ -+ RankKWorkspace(): -+ A(nullptr), C(nullptr), Computed(nullptr), Reference(nullptr) { } -+ }; -+ -+protected: -+ -+ // -+ // Data members -+ // -+ -+ /// GEMM problem obtained from problem space -+ RankKProblem problem_; -+ -+ /// Device memory allocations -+ RankKWorkspace rank_k_workspace_; -+ -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ RankKOperationProfiler(Options const &options); -+ -+ /// Destructor -+ virtual ~RankKOperationProfiler(); -+ -+ /// Prints usage statement for the math function -+ virtual void print_usage(std::ostream &out) const; -+ -+ /// Prints examples -+ virtual void print_examples(std::ostream &out) const; -+ -+ /// Extracts the problem dimensions -+ virtual Status initialize_configuration( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Initializes workspace -+ virtual Status initialize_workspace( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Verifies CUTLASS against references -+ virtual bool verify_cutlass( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Measures performance results -+ virtual bool profile( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+protected: -+ -+ /// Initializes the performance result -+ void initialize_result_( -+ PerformanceResult &result, -+ Options const &options, -+ library::RankKDescription const &operation_desc, -+ ProblemSpace const &problem_space); -+ -+ /// Verifies CUTLASS against references -+ bool verify_with_cublas_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/profiler/src/reduction_operation_profiler.h b/3rdparty/cutlass/tools/profiler/src/reduction_operation_profiler.h -new file mode 100644 -index 0000000..eef7350 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/reduction_operation_profiler.h -@@ -0,0 +1,173 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines profiling functionality for reduction operation -+ -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+#include -+#include -+ -+// CUTLASS Library includes -+#include "cutlass/library/library.h" -+#include "cutlass/library/util.h" -+#include "cutlass/library/manifest.h" -+ -+// Profiler includes -+#include "options.h" -+#include "device_context.h" -+#include "operation_profiler.h" -+#include "performance_result.h" -+#include "problem_space.h" -+#if CUTLASS_ENABLE_CUDNN -+#include "cudnn_helpers.h" -+#endif //#if CUTLASS_ENABLE_CUDNN -+#include "debug.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Abstract base class for each math function -+class ReductionOperationProfiler : public OperationProfiler { -+public: -+ -+ -+ /// Workspace used -+ struct ReductionWorkspace { -+ -+ /// Conv device allocations -+ DeviceAllocation *Workspace; -+ DeviceAllocation *Source; -+ DeviceAllocation *Destination; -+ DeviceAllocation *Reference; -+ -+ /// Library configuration and arguments -+ library::ReductionConfiguration configuration; -+ library::ReductionArguments arguments; -+ -+ /// Buffer used for the cutlass operations' host workspace -+ std::vector host_workspace; -+ -+ /// Buffer used for the cutlass operations' device workspace -+ DeviceAllocation device_workspace; -+ -+ // -+ // Methods -+ // -+ -+ ReductionWorkspace(): -+ Workspace(nullptr), Source(nullptr), Destination(nullptr), Reference(nullptr) { } -+ }; -+ -+protected: -+ -+ // -+ // Data members -+ // -+ -+ /// Reduction problem obtained from problem space -+ MatrixCoord problem_; -+ -+ /// Device memory allocations -+ ReductionWorkspace conv_workspace_; -+ -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ ReductionOperationProfiler(Options const &options); -+ -+ /// Destructor -+ virtual ~ReductionOperationProfiler(); -+ -+ /// Prints usage statement for the math function -+ virtual void print_usage(std::ostream &out) const; -+ -+ /// Prints examples -+ virtual void print_examples(std::ostream &out) const; -+ -+ /// Extracts the problem dimensions -+ virtual Status initialize_configuration( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Initializes workspace -+ virtual Status initialize_workspace( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Verifies CUTLASS against references -+ virtual bool verify_cutlass( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Measures performance results -+ virtual bool profile( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/profiler/src/sparse_gemm_operation_profiler.cu b/3rdparty/cutlass/tools/profiler/src/sparse_gemm_operation_profiler.cu -new file mode 100644 -index 0000000..2caf5f0 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/sparse_gemm_operation_profiler.cu -@@ -0,0 +1,569 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Execution environment -+ -+*/ -+ -+#include -+#include -+#include -+#include -+ -+#include "cublas_helpers.h" -+#include "sparse_gemm_operation_profiler.h" -+#include "gpu_timer.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Ctor -+SparseGemmOperationProfiler::SparseGemmOperationProfiler(Options const &options): -+ OperationProfiler( -+ options, -+ library::OperationKind::kSparseGemm, -+ { -+ {ArgumentTypeID::kEnumerated, {"gemm_kind"}, "Variant of GEMM (e.g. gemm, planar complex, batched, ...)"}, -+ {ArgumentTypeID::kInteger, {"m", "problem-size::m"}, "M dimension of the GEMM problem space"}, -+ {ArgumentTypeID::kInteger, {"n", "problem-size::n"}, "N dimension of the GEMM problem space"}, -+ {ArgumentTypeID::kInteger, {"k", "problem-size::k"}, "K dimension of the GEMM problem space"}, -+ {ArgumentTypeID::kTensor, {"A"}, "Tensor storing the A operand"}, -+ {ArgumentTypeID::kTensor, {"B"}, "Tensor storing the B operand"}, -+ {ArgumentTypeID::kTensor, {"C"}, "Tensor storing the C operand"}, -+ {ArgumentTypeID::kTensor, {"E"}, "Tensor storing the E operand"}, -+ {ArgumentTypeID::kScalar, {"alpha", "epilogue::alpha"}, "Epilogue scalar alpha"}, -+ {ArgumentTypeID::kScalar, {"beta", "epilogue::beta"}, "Epilogue scalar beta"}, -+ {ArgumentTypeID::kInteger, {"split_k_slices"}, "Number of partitions of K dimension"}, -+ {ArgumentTypeID::kInteger, {"batch_count"}, "Number of GEMMs computed in one batch"}, -+ } -+ ) { -+ -+ description_ = " Structured sparse GEMM. D = alpha * A*B + beta * C"; -+} -+ -+/// Destructor -+SparseGemmOperationProfiler::~SparseGemmOperationProfiler() { -+ -+} -+ -+/// Prints usage statement for the math function -+void SparseGemmOperationProfiler::print_usage(std::ostream &out) const { -+ out << "Sparse GEMM" << "\n\n"; -+ -+ OperationProfiler::print_usage(out); -+} -+ -+/// Prints examples -+void SparseGemmOperationProfiler::print_examples(std::ostream &out) const { -+ -+ out << "\nExamples:\n\n" -+ << "Profile a particular problem size:\n" -+ << " $ cutlass_profiler --operation=SparseGemm --m=1024 --n=1024 --k=128\n\n" -+ -+ << "Schmoo over problem size and beta:\n" -+ << " $ cutlass_profiler --operation=SparseGemm --m=1024:4096:256 --n=1024:4096:256 --k=128:8192:128 --beta=0,1,2.5\n\n" -+ -+ << "Schmoo over accumulator types:\n" -+ << " $ cutlass_profiler --operation=SparseGemm --accumulator-type=f16,f32\n\n" -+ -+ << "Run when A is f16 with column-major and B is any datatype with row-major (For column major, use column, col, or n. For row major use, row or t):\n" -+ << " $ cutlass_profiler --operation=SparseGemm --A=f16:column --B=*:row\n\n" -+ -+ << "Using various input value distribution:\n" -+ << " $ cutlass_profiler --operation=SparseGemm --dist=uniform,min:0,max:3\n" -+ << " $ cutlass_profiler --operation=SparseGemm --dist=gaussian,mean:0,stddev:3\n" -+ << " $ cutlass_profiler --operation=SparseGemm --dist=sequential,start:0,delta:1\n\n" -+ -+ << "Run a kernel with cta tile size of 256x128x32 and save workspace if results are incorrect (note that --cta-tile::k=32 is default cta-tile size):\n" -+ << " $ cutlass_profiler --operation=SparseGemm --cta_m=256 --cta_n=128 --cta_k=32 --save-workspace=incorrect\n\n" -+ -+ << "Test your changes to gemm kernels with a quick functional test and save results in functional-test.csv:\n" -+ << " $ cutlass_profiler --operation=SparseGemm \\ \n" -+ << " --m=8,56,120,136,256,264,512,520,1024,1032,4096,8192,16384 \\ \n" -+ << " --n=8,56,120,136,256,264,512,520,1024,1032,4096,8192,16384 \\ \n" -+ << " --k=8,16,32,64,128,256,288,384,504,512,520 \\ \n" -+ << " --beta=0,1,2 --profiling-iterations=1 \\ \n" -+ << " --providers=cutlass --output=functional-test.csv\n\n"; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+Status SparseGemmOperationProfiler::SparseGemmProblem::parse( -+ library::SparseGemmDescription const &operation_desc, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ if (!arg_as_int(this->m, "m", problem_space, problem)) { -+ // default value -+ this->m = 1024; -+ } -+ -+ if (!arg_as_int(this->n, "n", problem_space, problem)) { -+ // default value -+ this->n = 1024; -+ } -+ -+ if (!arg_as_int(this->k, "k", problem_space, problem)) { -+ // default value -+ this->k = 1024; -+ } -+ -+ if (!arg_as_int(this->split_k_slices, "split_k_slices", problem_space, problem)) { -+ // default value -+ this->split_k_slices = 1; -+ } -+ -+ if (!arg_as_int(this->batch_count, "batch_count", problem_space, problem)) { -+ // default value -+ this->batch_count = 1; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.A, "A", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.B, "B", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.C, "C", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.E, "E", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!arg_as_scalar( -+ this->alpha, -+ operation_desc.element_epilogue, -+ "alpha", -+ problem_space, -+ problem)) { -+ -+ if (!cast_from_double(this->alpha, operation_desc.element_epilogue, 1)) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ if (!arg_as_scalar( -+ this->beta, -+ operation_desc.element_epilogue, -+ "beta", -+ problem_space, -+ problem)) { -+ -+ if (!cast_from_double(this->beta, operation_desc.element_epilogue, 0)) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ this->elements_per_128b = -+ 128 / library::sizeof_bits(operation_desc.A.element); -+ -+ this->lda = DeviceAllocation::get_packed_layout( -+ operation_desc.A.layout, -+ {int(this->m), int(this->k) / int(this->sparse)}) -+ .front(); -+ -+ this->ldb = DeviceAllocation::get_packed_layout( -+ operation_desc.B.layout, {int(this->k), int(this->n)}).front(); -+ -+ this->ldc = DeviceAllocation::get_packed_layout( -+ operation_desc.C.layout, {int(this->m), int(this->n)}).front(); -+ -+ this->lde = -+ DeviceAllocation::get_packed_layout( -+ operation_desc.E.layout, -+ {int(this->m), int(this->k / this->sparse / this->elements_per_128b)}) -+ .front(); -+ -+ return Status::kSuccess; -+} -+ -+/// Initializes a performance result -+void SparseGemmOperationProfiler::SparseGemmProblem::initialize_result( -+ PerformanceResult &result, -+ library::SparseGemmDescription const &operation_desc, -+ ProblemSpace const &problem_space) { -+ -+ result.arguments.resize(problem_space.rank()); -+ -+ set_argument(result, "gemm_kind", problem_space, library::to_string(operation_desc.gemm_kind)); -+ -+ set_argument(result, "A", problem_space, -+ std::string(library::to_string(operation_desc.A.element)) + ":" + library::to_string(operation_desc.A.layout)); -+ -+ set_argument(result, "B", problem_space, -+ std::string(library::to_string(operation_desc.B.element)) + ":" + library::to_string(operation_desc.B.layout)); -+ -+ set_argument(result, "C", problem_space, -+ std::string(library::to_string(operation_desc.C.element)) + ":" + library::to_string(operation_desc.C.layout)); -+ -+ set_argument(result, "E", problem_space, -+ std::string(library::to_string(operation_desc.E.element)) + ":" + library::to_string(operation_desc.E.layout)); -+ -+ set_argument(result, "m", problem_space, m); -+ set_argument(result, "n", problem_space, n); -+ set_argument(result, "k", problem_space, k); -+ -+ set_argument(result, "split_k_slices", problem_space, split_k_slices); -+ set_argument(result, "batch_count", problem_space, batch_count); -+ -+ set_argument(result, "alpha", problem_space, -+ library::lexical_cast(alpha, operation_desc.element_epilogue)); -+ -+ set_argument(result, "beta", problem_space, -+ library::lexical_cast(beta, operation_desc.element_epilogue)); -+} -+ -+/// Extracts the problem dimensions -+Status SparseGemmOperationProfiler::initialize_configuration( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ library::SparseGemmDescription const &operation_desc = -+ static_cast(operation->description()); -+ -+ if (operation_desc.gemm_kind != library::GemmKind::kSparse) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ Status status = problem_.parse(operation_desc, problem_space, problem); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ gemm_workspace_.configuration.problem_size.m() = int(problem_.m); -+ gemm_workspace_.configuration.problem_size.n() = int(problem_.n); -+ gemm_workspace_.configuration.problem_size.k() = int(problem_.k); -+ gemm_workspace_.configuration.lda = problem_.lda; -+ gemm_workspace_.configuration.ldb = problem_.ldb; -+ gemm_workspace_.configuration.ldc = problem_.ldc; -+ gemm_workspace_.configuration.ldd = problem_.ldc; -+ gemm_workspace_.configuration.lde = problem_.lde; -+ -+ gemm_workspace_.arguments.A = nullptr; -+ gemm_workspace_.arguments.B = nullptr; -+ gemm_workspace_.arguments.C = nullptr; -+ gemm_workspace_.arguments.D = nullptr; -+ gemm_workspace_.arguments.E = nullptr; -+ gemm_workspace_.arguments.alpha = problem_.alpha.data(); -+ gemm_workspace_.arguments.beta = problem_.beta.data(); -+ gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ initialize_result_(this->model_result_, options, operation_desc, problem_space); -+ -+ return operation->can_implement(&gemm_workspace_.configuration, &gemm_workspace_.arguments); -+} -+ -+/// Initializes the performance result -+void SparseGemmOperationProfiler::initialize_result_( -+ PerformanceResult &result, -+ Options const &options, -+ library::SparseGemmDescription const &operation_desc, -+ ProblemSpace const &problem_space) { -+ -+ result.provider = library::Provider::kCUTLASS; -+ result.disposition = Disposition::kNotRun; -+ result.status = Status::kSuccess; -+ result.operation_name = operation_desc.name; -+ -+ problem_.initialize_result(result, operation_desc, problem_space); -+ -+ OperationProfiler::initialize_result_(result, operation_desc, problem_space); -+ -+ // Input bytes read and Output bytes written for the gemm problem -+ result.bytes = -+ int64_t(library::sizeof_bits(operation_desc.A.element) * problem_.m / 8) * -+ problem_.k / problem_.sparse + -+ int64_t(library::sizeof_bits(operation_desc.B.element) * problem_.n / 8) * -+ problem_.k + -+ int64_t(library::sizeof_bits(operation_desc.C.element) * problem_.m / 8) * -+ problem_.n + -+ int64_t(library::sizeof_bits(operation_desc.E.element) * problem_.m / 8) * -+ problem_.k / problem_.sparse / problem_.elements_per_128b; -+ -+ // Set is_beta_zero true if beta is zero -+ bool is_beta_zero = std::all_of(problem_.beta.begin(), problem_.beta.end(), [](uint8_t i) { return i==0; }); -+ -+ // Output bytes read for the gemm problem for non-zero beta values -+ if (!is_beta_zero) { -+ result.bytes += int64_t(library::sizeof_bits(operation_desc.C.element) * problem_.m / 8) * problem_.n; -+ } -+ -+ result.flops = 2 * (problem_.m * problem_.n * problem_.k + problem_.m * problem_.n); -+ result.runtime = 0; -+ -+} -+ -+/// Initializes workspace -+Status SparseGemmOperationProfiler::initialize_workspace( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ library::SparseGemmDescription const &operation_desc = -+ static_cast(operation->description()); -+ -+ if (options.execution_mode != ExecutionMode::kDryRun) { -+ -+ gemm_workspace_.A = device_context.allocate_tensor( -+ options, -+ "A", -+ operation_desc.A.element, -+ operation_desc.A.layout, -+ {int(problem_.m), int(problem_.k) / int(problem_.sparse)}, -+ {int(problem_.lda)} -+ ); -+ -+ gemm_workspace_.B = device_context.allocate_tensor( -+ options, -+ "B", -+ operation_desc.B.element, -+ operation_desc.B.layout, -+ {int(problem_.k), int(problem_.n)}, -+ {int(problem_.ldb)} -+ ); -+ -+ gemm_workspace_.C = device_context.allocate_tensor( -+ options, -+ "C", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ {int(problem_.m), int(problem_.n)}, -+ {int(problem_.ldc)} -+ ); -+ -+ gemm_workspace_.Computed = device_context.allocate_tensor( -+ "D", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ {int(problem_.m), int(problem_.n)}, -+ {int(problem_.ldc)} -+ ); -+ -+ gemm_workspace_.E = device_context.allocate_sparsemeta_tensor( -+ options, -+ "E", -+ operation_desc.E.element, -+ operation_desc.E.layout, -+ operation_desc.A.element, -+ {int(problem_.m), int(problem_.k) / int(problem_.sparse) / int(problem_.elements_per_128b)}, -+ {int(problem_.lde)} -+ ); -+ -+ gemm_workspace_.Reference = device_context.allocate_tensor( -+ "Reference", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ {int(problem_.m), int(problem_.n)}, -+ {int(problem_.ldc)} -+ ); -+ -+ gemm_workspace_.Reference->copy_from_device(gemm_workspace_.C->data()); -+ } -+ -+ // -+ // Initialize the CUTLASS operation -+ // -+ -+ Status status = Status::kSuccess; -+ -+ if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ -+ if (options.execution_mode != ExecutionMode::kDryRun) { -+ -+ uint64_t workspace_size = operation->get_host_workspace_size(&gemm_workspace_.configuration); -+ gemm_workspace_.host_workspace.resize(workspace_size, 0); -+ -+ workspace_size = operation->get_device_workspace_size(&gemm_workspace_.configuration); -+ gemm_workspace_.device_workspace.reset(library::NumericTypeID::kU8, workspace_size); -+ -+ status = operation->initialize( -+ &gemm_workspace_.configuration, -+ gemm_workspace_.host_workspace.data(), -+ gemm_workspace_.device_workspace.data()); -+ } -+ -+ // -+ // If CUTLASS is enabled, generate a result for it -+ // -+ -+ results_.push_back(model_result_); -+ results_.back().provider = library::Provider::kCUTLASS; -+ results_.back().op_kind = library::OperationKind::kSparseGemm; -+ results_.back().disposition = Disposition::kNotRun; -+ -+ for(auto &verification_provider : options.verification.providers) { -+ results_.back().verification_map[verification_provider] = Disposition::kNotRun; -+ } -+ } -+ -+ return status; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Verifies CUTLASS against references -+bool SparseGemmOperationProfiler::verify_cutlass( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ if (!options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ return true; -+ } -+ -+ if (options.execution_mode == ExecutionMode::kDryRun) { -+ return true; -+ } -+ -+ // Initialize structure containing GEMM arguments -+ gemm_workspace_.arguments.A = gemm_workspace_.A->data(); -+ gemm_workspace_.arguments.B = gemm_workspace_.B->data(); -+ gemm_workspace_.arguments.C = gemm_workspace_.C->data(); -+ gemm_workspace_.arguments.D = gemm_workspace_.Computed->data(); -+ gemm_workspace_.arguments.E = gemm_workspace_.E->data(); -+ gemm_workspace_.arguments.alpha = problem_.alpha.data(); -+ gemm_workspace_.arguments.beta = problem_.beta.data(); -+ gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ // -+ // Run the CUTLASS operation -+ // -+ -+ results_.back().status = operation->run( -+ &gemm_workspace_.arguments, -+ gemm_workspace_.host_workspace.data(), -+ gemm_workspace_.device_workspace.data()); -+ -+ if (results_.back().status != Status::kSuccess) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ if (result != cudaSuccess) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ -+ // CUTLASS op ran the but not yet verified against any verification provider -+ results_.back().disposition = Disposition::kNotVerified; -+ -+ // -+ // Run verification providers -+ // -+ -+ if (options.verification.enabled) { -+ -+ // Update disposition to worst case verification outcome among all -+ // verification providers which are supported -+ bool is_any_verification_run_passed = false; -+ -+ for(auto &m : results_.back().verification_map) { -+ if(m.second == Disposition::kFailed || m.second == Disposition::kIncorrect) { -+ results_.back().disposition = m.second; -+ return true; -+ } -+ if(!is_any_verification_run_passed && m.second == Disposition::kPassed) { -+ is_any_verification_run_passed = true; -+ } -+ } -+ -+ if(is_any_verification_run_passed) { -+ results_.back().disposition = Disposition::kPassed; -+ } -+ } -+ -+ // Return true means continue profiling -+ return true; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Measures performance results -+bool SparseGemmOperationProfiler::profile( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ -+ // Initialize structure containing GEMM arguments -+ gemm_workspace_.arguments.A = gemm_workspace_.A->data(); -+ gemm_workspace_.arguments.B = gemm_workspace_.B->data(); -+ gemm_workspace_.arguments.C = gemm_workspace_.C->data(); -+ gemm_workspace_.arguments.D = gemm_workspace_.Computed->data(); -+ gemm_workspace_.arguments.E = gemm_workspace_.E->data(); -+ gemm_workspace_.arguments.alpha = problem_.alpha.data(); -+ gemm_workspace_.arguments.beta = problem_.beta.data(); -+ gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ results_.back().status = profile_cutlass_( -+ results_.back().runtime, -+ options, -+ operation, -+ &gemm_workspace_.arguments, -+ gemm_workspace_.host_workspace.data(), -+ gemm_workspace_.device_workspace.data() -+ ); -+ } -+ -+ return true; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/profiler/src/sparse_gemm_operation_profiler.h b/3rdparty/cutlass/tools/profiler/src/sparse_gemm_operation_profiler.h -new file mode 100644 -index 0000000..c1f11c9 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/sparse_gemm_operation_profiler.h -@@ -0,0 +1,214 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief -+ -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+#include -+#include -+ -+// CUTLASS Library includes -+#include "cutlass/library/library.h" -+#include "cutlass/library/util.h" -+#include "cutlass/library/manifest.h" -+ -+// Profiler includes -+#include "options.h" -+#include "device_context.h" -+#include "operation_profiler.h" -+#include "performance_result.h" -+#include "problem_space.h" -+#include "gemm_operation_profiler.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Abstract base class for each math function -+class SparseGemmOperationProfiler : public OperationProfiler { -+public: -+ -+ /// Problem structure obtained from problem space -+ struct SparseGemmProblem { -+ int64_t m; -+ int64_t n; -+ int64_t k; -+ int64_t lda; -+ int64_t ldb; -+ int64_t ldc; -+ int64_t lde; -+ std::vector alpha; -+ std::vector beta; -+ int64_t split_k_slices; -+ int64_t batch_count; -+ static int const sparse = 2; -+ // every 128b ElementA uses one elementE -+ int elements_per_128b; -+ -+ // -+ // Methods -+ // -+ -+ SparseGemmProblem(): -+ m(16), n(16), k(16), lda(0), ldb(0), ldc(0), lde(0), split_k_slices(1), batch_count(1) { } -+ -+ /// Parses the problem -+ Status parse( -+ library::SparseGemmDescription const &operation_desc, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Initializes a performance result -+ void initialize_result( -+ PerformanceResult &result, -+ library::SparseGemmDescription const &operation_desc, -+ ProblemSpace const &problem_space); -+ }; -+ -+ /// Workspace used -+ struct SparseGemmWorkspace { -+ -+ DeviceAllocation *A; -+ DeviceAllocation *B; -+ DeviceAllocation *C; -+ DeviceAllocation *E; -+ DeviceAllocation *Computed; -+ DeviceAllocation *Reference; -+ -+ library::SparseGemmConfiguration configuration; -+ library::SparseGemmArguments arguments; -+ -+ /// Buffer used for the operation's host workspace -+ std::vector host_workspace; -+ -+ /// Buffer used for the operations' device workspace -+ DeviceAllocation device_workspace; -+ -+ // -+ // Methods -+ // -+ -+ SparseGemmWorkspace(): -+ A(nullptr), B(nullptr), C(nullptr), E(nullptr), Computed(nullptr), Reference(nullptr) { } -+ }; -+ -+protected: -+ -+ // -+ // Data members -+ // -+ -+ // GEMM problem -+ SparseGemmProblem problem_; -+ -+ /// Device memory allocations -+ SparseGemmWorkspace gemm_workspace_; -+ -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ SparseGemmOperationProfiler(Options const &options); -+ -+ /// Destructor -+ virtual ~SparseGemmOperationProfiler(); -+ -+ /// Prints usage statement for the math function -+ virtual void print_usage(std::ostream &out) const; -+ -+ /// Prints examples -+ virtual void print_examples(std::ostream &out) const; -+ -+ /// Extracts the problem dimensions -+ virtual Status initialize_configuration( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Initializes workspace -+ virtual Status initialize_workspace( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Verifies CUTLASS against references -+ virtual bool verify_cutlass( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Measures performance results -+ virtual bool profile( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+protected: -+ -+ /// Initializes the performance result -+ void initialize_result_( -+ PerformanceResult &result, -+ Options const &options, -+ library::SparseGemmDescription const &operation_desc, -+ ProblemSpace const &problem_space); -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/profiler/src/symm_operation_profiler.cu b/3rdparty/cutlass/tools/profiler/src/symm_operation_profiler.cu -new file mode 100644 -index 0000000..97cb34a ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/symm_operation_profiler.cu -@@ -0,0 +1,764 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Execution environment -+ -+ -+*/ -+ -+#include -+#include -+#include -+#include -+ -+#include "cutlass/core_io.h" -+ -+#include "cublas_helpers.h" -+#include "symm_operation_profiler.h" -+#include "gpu_timer.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Ctor -+SymmOperationProfiler::SymmOperationProfiler(Options const &options): -+ OperationProfiler( -+ options, -+ library::OperationKind::kSymm, -+ { -+ {ArgumentTypeID::kEnumerated, {"symm_kind"}, "Variant of Symm (universal)"}, -+ {ArgumentTypeID::kInteger, {"m", "problem-size::m"}, "M dimension of the Symm problem space"}, -+ {ArgumentTypeID::kInteger, {"n", "problem-size::n"}, "N dimension of the Symm problem space"}, -+ {ArgumentTypeID::kTensor, {"A"}, "Tensor storing the A operand"}, -+ {ArgumentTypeID::kTensor, {"B"}, "Tensor storing the B operand"}, -+ {ArgumentTypeID::kTensor, {"C"}, "Tensor storing the C operand"}, -+ {ArgumentTypeID::kEnumerated, {"side_mode"}, "Side Mode for Symm kernel (left or right)"}, -+ {ArgumentTypeID::kEnumerated, {"fill_mode"}, "Fill Mode for Symm kernel (lower or upper)"}, -+ {ArgumentTypeID::kEnumerated, {"blas_mode"}, "Blas Mode for Symm kernel (symmetric or hermitian)"}, -+ {ArgumentTypeID::kScalar, {"alpha", "epilogue::alpha"}, "Epilogue scalar alpha"}, -+ {ArgumentTypeID::kScalar, {"beta", "epilogue::beta"}, "Epilogue scalar beta"}, -+ {ArgumentTypeID::kInteger, {"split_k_slices", "split-k-slices"}, "Number of partitions of K dimension"}, -+ {ArgumentTypeID::kInteger, {"batch_count", "batch-count"}, "Number of Symm computed in one batch"}, -+ }, -+ { library::Provider::kCUBLAS } -+ ) { -+ description_ = " Symmetric Matrix-Matrix Multiplication. D = alpha * A * B OR alpha * B * A + beta * C (where A is symmetric/hermitian)"; -+} -+ -+/// Destructor -+SymmOperationProfiler::~SymmOperationProfiler() { -+ -+} -+ -+/// Prints usage statement for the math function -+void SymmOperationProfiler::print_usage(std::ostream &out) const { -+ out << "Symm" << "\n\n"; -+ -+ OperationProfiler::print_usage(out); -+} -+ -+/// Prints examples -+void SymmOperationProfiler::print_examples(std::ostream &out) const { -+ -+ out << "\nExamples:\n\n" -+ << "Profile a particular problem size SYMM kernel:\n" -+ << " $ cutlass_profiler --operation=Symm --blas_mode=symmetric --m=1024 --n=128\n\n" -+ -+ << "Profile a particular problem size HEMM kernel:\n" -+ << " $ cutlass_profiler --operation=Symm --blas_mode=hermitian --m=1024 --n=128\n\n" -+ -+ << "Schmoo over problem size and beta:\n" -+ << " $ cutlass_profiler --operation=Symm --m=1024:4096:256 --n=128:8192:128 --beta=0,1,2.5\n\n" -+ -+ << "Schmoo over accumulator types:\n" -+ << " $ cutlass_profiler --operation=Symm --accumulator-type=f16,f32\n\n" -+ -+ << "Schmoo over side modees:\n" -+ << " $ cutlass_profiler --operation=Symm --side_mode=left/right\n\n" -+ -+ << "Schmoo over fill modees:\n" -+ << " $ cutlass_profiler --operation=Symm --fill_mode=lower/upper\n\n" -+ -+ << "Run when A is f16 with column-major or A is any datatype with row-major (For column major, use column, col, or n. For row major use, row or t):\n" -+ << " $ cutlass_profiler --operation=Symm --A=f16:column or --A=*:row\n\n" -+ -+ << "Using various input value distribution:\n" -+ << " $ cutlass_profiler --operation=Symm --dist=uniform,min:0,max:3\n" -+ << " $ cutlass_profiler --operation=Symm --dist=gaussian,mean:0,stddev:3\n" -+ << " $ cutlass_profiler --operation=Symm --dist=sequential,start:0,delta:1\n\n" -+ -+ << "Run a kernel with cta tile size of 256x128x32 and save workspace if results are incorrect (note that --cta-tile::k=32 is default cta-tile size):\n" -+ << " $ cutlass_profiler --operation=Symm --cta_m=256 --cta_n=128 --cta_k=32 --save-workspace=incorrect\n\n" -+ -+ << "Test your changes to symm kernels with a quick functional test and save results in functional-test.csv:\n" -+ << " $ cutlass_profiler --operation=Symm \\ \n" -+ << " --m=8,56,120,136,256,264,512,520,1024,1032,4096,8192,16384 \\ \n" -+ << " --n=8,16,32,64,128,256,288,384,504,512,520 \\ \n" -+ << " --beta=0,1,2 --profiling-iterations=1 \\ \n" -+ << " --providers=cutlass --output=functional-test.csv\n\n"; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if 0 -+// used this for debugging -+static std::string byte_string(std::vector const &bytes) { -+ std::stringstream ss; -+ -+ ss << "0x"; -+ -+ for (size_t idx = bytes.size(); idx > 0; --idx) { -+ ss << std::hex << std::setw(2) << std::setfill('0') << uint32_t(bytes.at(idx - 1)); -+ } -+ -+ return ss.str(); -+} -+#endif -+ -+Status SymmOperationProfiler::SymmProblem::parse( -+ library::SymmDescription const &operation_desc, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ if (!arg_as_int(this->m, "m", problem_space, problem)) { -+ // default value -+ this->m = 1024; -+ } -+ -+ if (!arg_as_int(this->n, "n", problem_space, problem)) { -+ // default value -+ this->n = 1024; -+ } -+ -+ if (!arg_as_int(this->split_k_slices, "split_k_slices", problem_space, problem)) { -+ // default value -+ this->split_k_slices = 1; -+ } -+ -+ if (!arg_as_int(this->batch_count, "batch_count", problem_space, problem)) { -+ // default value -+ this->batch_count = 1; -+ } -+ -+ if (this->split_k_slices > 1 && this->batch_count > 1) { -+ // At least one of these must be one -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.A, "A", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.B, "B", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.C, "C", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!arg_as_scalar( -+ this->alpha, -+ operation_desc.element_epilogue, -+ "alpha", -+ problem_space, -+ problem)) { -+ -+ if (!cast_from_double(this->alpha, operation_desc.element_epilogue, 1)) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ if (!arg_as_scalar( -+ this->beta, -+ operation_desc.element_epilogue, -+ "beta", -+ problem_space, -+ problem)) { -+ -+ if (!cast_from_double(this->beta, operation_desc.element_epilogue, 0)) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ if (operation_desc.side_mode == SideMode::kLeft) { -+ this->lda = DeviceAllocation::get_packed_layout( -+ operation_desc.A.layout, {int(this->m), int(this->m)}).front(); -+ } -+ else if (operation_desc.side_mode == SideMode::kRight) { -+ this->lda = DeviceAllocation::get_packed_layout( -+ operation_desc.A.layout, {int(this->n), int(this->n)}).front(); -+ } -+ -+ this->ldb = DeviceAllocation::get_packed_layout( -+ operation_desc.B.layout, {int(this->m), int(this->n)}).front(); -+ -+ this->ldc = DeviceAllocation::get_packed_layout( -+ operation_desc.C.layout, {int(this->m), int(this->n)}).front(); -+ -+ return Status::kSuccess; -+} -+ -+/// Total number of bytes loaded -+int64_t SymmOperationProfiler::SymmProblem::bytes(library::SymmDescription const &operation_desc) const { -+ int64_t bytes; -+ // Input bytes read and Output bytes written for the gemm problem -+ // Half matrix including the diagonal will have (X*(X+1))/2 elements -+ if (operation_desc.side_mode == SideMode::kLeft) { -+ bytes = -+ int64_t(library::sizeof_bits(operation_desc.A.element) * m / 8) * (m + 1) / 2 + -+ int64_t(library::sizeof_bits(operation_desc.B.element) * m / 8) * n + -+ int64_t(library::sizeof_bits(operation_desc.C.element) * m / 8) * n; -+ } else if (operation_desc.side_mode == SideMode::kRight) { -+ bytes = -+ int64_t(library::sizeof_bits(operation_desc.A.element) * n / 8) * (n + 1) / 2 + -+ int64_t(library::sizeof_bits(operation_desc.B.element) * m / 8) * n + -+ int64_t(library::sizeof_bits(operation_desc.C.element) * m / 8) * n; -+ } -+ // Set is_beta_zero true if beta is zero -+ bool is_beta_zero = std::all_of(beta.begin(), beta.end(), [](uint8_t i) { return i==0; }); -+ -+ // Output bytes read for the gemm problem for non-zero beta values -+ if (!is_beta_zero) { -+ bytes += int64_t(library::sizeof_bits(operation_desc.C.element) * m / 8) * n; -+ } -+ -+ bytes *= batch_count; -+ -+ return bytes; -+} -+ -+/// Total number of flops computed -+int64_t SymmOperationProfiler::SymmProblem::flops(library::SymmDescription const &operation_desc) const { -+ -+ // FLOPs for first TRMM kernel (with diagonal) = 2 * [ ( M * (M+1)/2 * N ) ] // Beta is zero -+ // FLOPs for second TRMM kernel (with diagonal) = 2 * [ ( M * (M-1)/2 * N ) ] // Beta is zero -+ // FLOPs = m*(m+1)*n [mma1] + m*(m-1)*n [mma2] + 2*m*n [epilogue] -+ // FLOPs = 2*m*n(m+1) for left side mode -+ // FLOPs can also be calculated to be same as GEMM with correct value for 'k' as below. -+ int64_t k = (operation_desc.side_mode == SideMode::kLeft) ? int64_t(m) : int64_t(n); -+ int64_t flops_ = (int64_t(m) * n * k + m * n) * 2; -+ -+ // complex-valued support -+ switch (operation_desc.tile_description.math_instruction.math_operation) { -+ case library::MathOperationID::kMultiplyAddComplex: -+ flops_ *= 4; -+ break; -+ -+ case library::MathOperationID::kMultiplyAddComplexFastF32: -+ flops_ *= 4; -+ break; -+ -+ case library::MathOperationID::kMultiplyAddGaussianComplex: -+ flops_ *= 3; -+ break; -+ -+ default: break; -+ } -+ -+ return flops_; -+} -+ -+/// Initializes a performance result -+void SymmOperationProfiler::SymmProblem::initialize_result( -+ PerformanceResult &result, -+ library::SymmDescription const &operation_desc, -+ ProblemSpace const &problem_space) { -+ -+ result.arguments.resize(problem_space.rank()); -+ -+ set_argument(result, "symm_kind", problem_space, library::to_string(operation_desc.symm_kind)); -+ -+ set_argument(result, "A", problem_space, -+ std::string(library::to_string(operation_desc.A.element)) + ":" + library::to_string(operation_desc.A.layout)); -+ -+ set_argument(result, "B", problem_space, -+ std::string(library::to_string(operation_desc.B.element)) + ":" + library::to_string(operation_desc.B.layout)); -+ -+ set_argument(result, "C", problem_space, -+ std::string(library::to_string(operation_desc.C.element)) + ":" + library::to_string(operation_desc.C.layout)); -+ -+ set_argument(result, "side_mode", problem_space, library::to_string(operation_desc.side_mode)); -+ -+ set_argument(result, "fill_mode", problem_space, library::to_string(operation_desc.fill_mode)); -+ -+ set_argument(result, "blas_mode", problem_space, library::to_string(operation_desc.blas_mode)); -+ -+ set_argument(result, "m", problem_space, m); -+ set_argument(result, "n", problem_space, n); -+ -+ set_argument(result, "split_k_slices", problem_space, split_k_slices); -+ set_argument(result, "batch_count", problem_space, batch_count); -+ -+ set_argument(result, "alpha", problem_space, -+ library::lexical_cast(alpha, operation_desc.element_epilogue)); -+ -+ set_argument(result, "beta", problem_space, -+ library::lexical_cast(beta, operation_desc.element_epilogue)); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Extracts the problem dimensions -+Status SymmOperationProfiler::initialize_configuration( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ library::SymmDescription const &operation_desc = -+ static_cast(operation->description()); -+ -+ if (operation_desc.symm_kind != library::SymmKind::kUniversal) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ Status status = problem_.parse(operation_desc, problem_space, problem); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ symm_workspace_.configuration.problem_size.m() = int(problem_.m); -+ symm_workspace_.configuration.problem_size.n() = int(problem_.n); -+ symm_workspace_.configuration.problem_size.k() = (operation_desc.side_mode == SideMode::kLeft) -+ ? int(problem_.m) : int(problem_.n); -+ symm_workspace_.configuration.lda = problem_.lda; -+ symm_workspace_.configuration.ldb = problem_.ldb; -+ symm_workspace_.configuration.ldc = problem_.ldc; -+ symm_workspace_.configuration.ldd = problem_.ldc; -+ //symm_workspace_.configuration.split_k_slices = int(problem_.split_k_slices); -+ symm_workspace_.configuration.batch_count = int(problem_.split_k_slices); -+ -+ symm_workspace_.arguments.A = nullptr; -+ symm_workspace_.arguments.B = nullptr; -+ symm_workspace_.arguments.C = nullptr; -+ symm_workspace_.arguments.D = nullptr; -+ symm_workspace_.arguments.alpha = problem_.alpha.data(); -+ symm_workspace_.arguments.beta = problem_.beta.data(); -+ symm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ initialize_result_(this->model_result_, options, operation_desc, problem_space); -+ -+ return operation->can_implement(&symm_workspace_.configuration, &symm_workspace_.arguments); -+} -+ -+/// Initializes the performance result -+void SymmOperationProfiler::initialize_result_( -+ PerformanceResult &result, -+ Options const &options, -+ library::SymmDescription const &operation_desc, -+ ProblemSpace const &problem_space) { -+ -+ result.provider = library::Provider::kCUTLASS; -+ result.disposition = Disposition::kNotRun; -+ result.status = Status::kSuccess; -+ result.operation_name = operation_desc.name; -+ -+ problem_.initialize_result(result, operation_desc, problem_space); -+ -+ OperationProfiler::initialize_result_(result, operation_desc, problem_space); -+ -+ -+ result.bytes = problem_.bytes(operation_desc); -+ result.flops = problem_.flops(operation_desc); -+ result.runtime = 0; -+ -+ -+} -+ -+/// Initializes workspace -+Status SymmOperationProfiler::initialize_workspace( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ library::SymmDescription const &operation_desc = -+ static_cast(operation->description()); -+ -+ if (options.execution_mode != ExecutionMode::kDryRun) { -+ -+ if (operation_desc.side_mode == SideMode::kLeft) { -+ symm_workspace_.A = device_context.allocate_tensor( -+ options, -+ "A", -+ operation_desc.A.element, -+ operation_desc.A.layout, -+ {int(problem_.m), int(problem_.m)}, -+ {int(problem_.lda)}, -+ 1 // batch_count = 1, default -+ ); -+ } else if (operation_desc.side_mode == SideMode::kRight) { -+ symm_workspace_.A = device_context.allocate_tensor( -+ options, -+ "A", -+ operation_desc.A.element, -+ operation_desc.A.layout, -+ {int(problem_.n), int(problem_.n)}, -+ {int(problem_.lda)}, -+ 1 // batch_count = 1, default -+ ); -+ } -+ -+ symm_workspace_.B = device_context.allocate_tensor( -+ options, -+ "B", -+ operation_desc.B.element, -+ operation_desc.B.layout, -+ {int(problem_.m), int(problem_.n)}, -+ {int(problem_.ldb)} -+ ); -+ -+ symm_workspace_.C = device_context.allocate_tensor( -+ options, -+ "C", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ {int(problem_.m), int(problem_.n)}, -+ {int(problem_.ldc)}, -+ 1 // batch_count = 1, default -+ ); -+ -+ symm_workspace_.Computed = device_context.allocate_tensor( -+ "D", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ {int(problem_.m), int(problem_.n)}, -+ {int(problem_.ldc)} -+ ); -+ -+ symm_workspace_.Reference = device_context.allocate_tensor( -+ "Reference", -+ operation_desc.C.element, -+ operation_desc.C.layout, -+ {int(problem_.m), int(problem_.n)}, -+ {int(problem_.ldc)} -+ ); -+ -+ symm_workspace_.Computed->copy_from_device(symm_workspace_.C->data()); -+ symm_workspace_.Reference->copy_from_device(symm_workspace_.C->data()); -+ } -+ -+ -+ // -+ // Initialize the CUTLASS operation -+ // -+ Status status = Status::kSuccess; -+ -+ if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ -+ if (options.execution_mode != ExecutionMode::kDryRun) { -+ -+ uint64_t workspace_size = operation->get_host_workspace_size(&symm_workspace_.configuration); -+ symm_workspace_.host_workspace.resize(workspace_size, 0); -+ -+ workspace_size = operation->get_device_workspace_size(&symm_workspace_.configuration); -+ symm_workspace_.device_workspace.reset(library::NumericTypeID::kU8, workspace_size); -+ -+ status = operation->initialize( -+ &symm_workspace_.configuration, -+ symm_workspace_.host_workspace.data(), -+ symm_workspace_.device_workspace.data()); -+ } -+ -+ // -+ // If CUTLASS is enabled, generate a result for it -+ // -+ results_.push_back(model_result_); -+ results_.back().provider = library::Provider::kCUTLASS; -+ results_.back().op_kind = library::OperationKind::kSymm; -+ results_.back().disposition = Disposition::kNotRun; -+ -+ for(auto provider : verification_providers_) { -+ results_.back().verification_map[provider] = Disposition::kNotRun; -+ } -+ } -+ -+ return status; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Verifies CUTLASS against references -+bool SymmOperationProfiler::verify_cutlass( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ if (!options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ return true; -+ } -+ -+ if (options.execution_mode == ExecutionMode::kDryRun) { -+ return true; -+ } -+ -+ // Initialize structure containing Symm arguments -+ symm_workspace_.arguments.A = symm_workspace_.A->data(); -+ symm_workspace_.arguments.B = symm_workspace_.B->data(); -+ symm_workspace_.arguments.C = symm_workspace_.C->data(); -+ symm_workspace_.arguments.D = symm_workspace_.Computed->data(); -+ symm_workspace_.arguments.alpha = problem_.alpha.data(); -+ symm_workspace_.arguments.beta = problem_.beta.data(); -+ symm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ // -+ // Run the CUTLASS operation -+ // -+ -+ results_.back().status = operation->run( -+ &symm_workspace_.arguments, -+ symm_workspace_.host_workspace.data(), -+ symm_workspace_.device_workspace.data()); -+ -+ if (results_.back().status != Status::kSuccess) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ if (result != cudaSuccess) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ -+ // CUTLASS op ran the but not yet verified against any verification provider -+ results_.back().disposition = Disposition::kNotVerified; -+ -+ // -+ // Run verification providers -+ // -+ -+ if (options.verification.enabled) { -+ -+#if CUTLASS_ENABLE_CUBLAS -+ if (options.verification.provider_enabled(library::Provider::kCUBLAS)) { -+ -+ // Guard against unsupported cases -+ auto const & symm_desc = static_cast(operation->description()); -+ -+ if (cublas_satisfies(symm_desc) == Status::kSuccess) { -+ -+ // call cublas verification if supported -+ verify_with_cublas_( -+ options, -+ report, -+ device_context, -+ operation, -+ problem_space, -+ problem); -+ } -+ -+ else { -+ // set verification map for cublas to not supported -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kNotSupported; -+ } -+ } -+#endif // #if CUTLASS_ENABLE_CUBLAS -+ -+ // Update disposition to worst case verification outcome among all -+ // verification providers which are supported -+ bool is_any_verification_run_passed = false; -+ for(auto &m : results_.back().verification_map) { -+ if(m.second == Disposition::kFailed || m.second == Disposition::kIncorrect) { -+ results_.back().disposition = m.second; -+ return true; -+ } -+ if(!is_any_verification_run_passed && m.second == Disposition::kPassed) { -+ is_any_verification_run_passed = true; -+ } -+ } -+ -+ if(is_any_verification_run_passed) { -+ results_.back().disposition = Disposition::kPassed; -+ } -+ } -+ -+ // Return true means continue profiling -+ return true; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Verifies CUTLASS against references -+bool SymmOperationProfiler::verify_with_cublas_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ -+#if CUTLASS_ENABLE_CUBLAS -+ -+ library::SymmDescription const &symm_desc = -+ static_cast(operation->description()); -+ -+ // -+ // Construct cuBLAS operators -+ // -+ -+ CublasCreate handle; -+ cublasStatus_t status = handle.get_cublas_create_status(); -+ -+ if (status != CUBLAS_STATUS_SUCCESS) { -+ -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed; -+ return true; -+ } -+ -+ // -+ // Initialize state -+ // -+ -+ try { -+ -+ // -+ // Construct dispatcher to cublasSymm() -+ // -+ -+ // Initialize structure containing Symm arguments -+ symm_workspace_.arguments.A = symm_workspace_.A->data(); -+ symm_workspace_.arguments.B = symm_workspace_.B->data(); -+ symm_workspace_.arguments.C = symm_workspace_.Reference->data(); -+ symm_workspace_.arguments.D = symm_workspace_.Reference->data(); -+ symm_workspace_.arguments.alpha = problem_.alpha.data(); -+ symm_workspace_.arguments.beta = problem_.beta.data(); -+ symm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ detail::cublasSymmDispatcher symm_op( -+ symm_desc, -+ symm_workspace_.configuration, -+ symm_workspace_.arguments -+ ); -+ -+ if (symm_op.status != Status::kSuccess) { -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kNotRun; -+ return true; -+ } -+ -+ results_.back().status = Status::kSuccess; -+ -+ status = symm_op(handle); -+ -+ // Handle errors -+ if (status != CUBLAS_STATUS_SUCCESS) { -+ -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed; -+ return true; -+ } -+ -+ // -+ // Verify results -+ // -+ -+ results_.back().verification_map[library::Provider::kCUBLAS] = compare_tensors( -+ options, -+ *symm_workspace_.Computed, -+ *symm_workspace_.Reference -+ ); -+ -+ // Save workspace if incorrect -+ if (options.verification.save_workspace == SaveWorkspace::kIncorrect && -+ results_.back().verification_map[library::Provider::kCUBLAS] == Disposition::kIncorrect) { -+ -+ save_workspace( -+ device_context, -+ options, -+ symm_desc, -+ library::Provider::kCUTLASS, -+ library::Provider::kCUBLAS); -+ } -+ } -+ catch (...) { -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed; -+ } -+ -+#endif -+ -+ // Return true means continue profiling -+ return true; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Measures performance results -+bool SymmOperationProfiler::profile( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ -+ // Initialize structure containing Symm arguments -+ symm_workspace_.arguments.A = symm_workspace_.A->data(); -+ symm_workspace_.arguments.B = symm_workspace_.B->data(); -+ symm_workspace_.arguments.C = symm_workspace_.C->data(); -+ symm_workspace_.arguments.D = symm_workspace_.Computed->data(); -+ symm_workspace_.arguments.alpha = problem_.alpha.data(); -+ symm_workspace_.arguments.beta = problem_.beta.data(); -+ symm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ results_.back().status = profile_cutlass_( -+ results_.back().runtime, -+ options, -+ operation, -+ &symm_workspace_.arguments, -+ symm_workspace_.host_workspace.data(), -+ symm_workspace_.device_workspace.data() -+ ); -+ } -+ return true; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/profiler/src/symm_operation_profiler.h b/3rdparty/cutlass/tools/profiler/src/symm_operation_profiler.h -new file mode 100644 -index 0000000..a0162b4 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/symm_operation_profiler.h -@@ -0,0 +1,230 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines a math function -+ -+ -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+#include -+#include -+ -+// CUTLASS Library includes -+#include "cutlass/blas3.h" -+#include "cutlass/library/library.h" -+#include "cutlass/library/util.h" -+#include "cutlass/library/manifest.h" -+ -+// Profiler includes -+#include "options.h" -+#include "device_context.h" -+#include "operation_profiler.h" -+#include "performance_result.h" -+#include "problem_space.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+/// Abstract base class for each math function -+class SymmOperationProfiler : public OperationProfiler { -+public: -+ -+ /// Problem structure obtained from problem space -+ struct SymmProblem { -+ int64_t m; -+ int64_t n; -+ int64_t lda; -+ int64_t ldb; -+ int64_t ldc; -+ SideMode side_mode; -+ FillMode fill_mode; -+ BlasMode blas_mode; -+ std::vector alpha; -+ std::vector beta; -+ int64_t split_k_slices; -+ int64_t batch_count; -+ -+ // -+ // Methods -+ // -+ -+ SymmProblem(): -+ m(16), n(16), lda(0), ldb(0), ldc(0), -+ side_mode(SideMode::kInvalid), fill_mode(FillMode::kInvalid), blas_mode(BlasMode::kInvalid), -+ split_k_slices(1), batch_count(1) { } -+ -+ /// Parses the problem -+ Status parse( -+ library::SymmDescription const &operation_desc, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Total number of bytes loaded -+ int64_t bytes(library::SymmDescription const &operation_desc) const; -+ -+ /// Total number of flops computed -+ int64_t flops(library::SymmDescription const &operation_desc) const; -+ -+ /// Initializes a performance result -+ void initialize_result( -+ PerformanceResult &result, -+ library::SymmDescription const &operation_desc, -+ ProblemSpace const &problem_space); -+ }; -+ -+ /// Workspace used -+ struct SymmWorkspace { -+ -+ DeviceAllocation *A; -+ DeviceAllocation *B; -+ DeviceAllocation *C; -+ DeviceAllocation *Computed; -+ DeviceAllocation *Reference; -+ -+ library::SymmConfiguration configuration; -+ library::SymmArguments arguments; -+ -+ /// Buffer used for the operation's host workspace -+ std::vector host_workspace; -+ -+ /// Buffer used for the operations' device workspace -+ DeviceAllocation device_workspace; -+ -+ // -+ // Methods -+ // -+ -+ SymmWorkspace(): -+ A(nullptr), B(nullptr), C(nullptr), Computed(nullptr), Reference(nullptr) { } -+ }; -+ -+protected: -+ -+ // -+ // Data members -+ // -+ -+ /// GEMM problem obtained from problem space -+ SymmProblem problem_; -+ -+ /// Device memory allocations -+ SymmWorkspace symm_workspace_; -+ -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ SymmOperationProfiler(Options const &options); -+ -+ /// Destructor -+ virtual ~SymmOperationProfiler(); -+ -+ /// Prints usage statement for the math function -+ virtual void print_usage(std::ostream &out) const; -+ -+ /// Prints examples -+ virtual void print_examples(std::ostream &out) const; -+ -+ /// Extracts the problem dimensions -+ virtual Status initialize_configuration( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Initializes workspace -+ virtual Status initialize_workspace( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Verifies CUTLASS against references -+ virtual bool verify_cutlass( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Measures performance results -+ virtual bool profile( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+protected: -+ -+ /// Initializes the performance result -+ void initialize_result_( -+ PerformanceResult &result, -+ Options const &options, -+ library::SymmDescription const &operation_desc, -+ ProblemSpace const &problem_space); -+ -+ /// Verifies CUTLASS against references -+ bool verify_with_cublas_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/profiler/src/trmm_operation_profiler.cu b/3rdparty/cutlass/tools/profiler/src/trmm_operation_profiler.cu -new file mode 100644 -index 0000000..19014d0 ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/trmm_operation_profiler.cu -@@ -0,0 +1,704 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Execution environment -+ -+ -+*/ -+ -+#include -+#include -+#include -+#include -+ -+#include "cutlass/core_io.h" -+ -+#include "cublas_helpers.h" -+#include "trmm_operation_profiler.h" -+#include "gpu_timer.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Ctor -+TrmmOperationProfiler::TrmmOperationProfiler(Options const &options): -+ OperationProfiler( -+ options, -+ library::OperationKind::kTrmm, -+ { -+ {ArgumentTypeID::kEnumerated, {"trmm_kind"}, "Variant of TRMM (universal)"}, -+ {ArgumentTypeID::kInteger, {"m", "problem-size::m"}, "M dimension of the TRMM problem space"}, -+ {ArgumentTypeID::kInteger, {"n", "problem-size::n"}, "N dimension of the TRMM problem space"}, -+ {ArgumentTypeID::kTensor, {"A"}, "Tensor storing the A operand"}, -+ {ArgumentTypeID::kEnumerated, {"side_mode"}, "Side Mode for TRMM (left, right)"}, -+ {ArgumentTypeID::kEnumerated, {"fill_mode"}, "Fill Mode for TRMM (lower, upper)"}, -+ {ArgumentTypeID::kEnumerated, {"diag_type"}, "Diag Type for TRMM (nonunit, unit)"}, -+ {ArgumentTypeID::kTensor, {"B"}, "Tensor storing the B operand"}, -+ {ArgumentTypeID::kTensor, {"D"}, "Tensor storing the D operand"}, -+ {ArgumentTypeID::kScalar, {"alpha", "epilogue::alpha"}, "Epilogue scalar alpha"}, -+ {ArgumentTypeID::kScalar, {"beta", "epilogue::beta"}, "Epilogue scalar beta"}, -+ {ArgumentTypeID::kInteger, {"split_k_slices", "split-k-slices"}, "Number of partitions of K dimension"}, -+ {ArgumentTypeID::kInteger, {"batch_count", "batch-count"}, "Number of TRMMs computed in one batch"}, -+ }, -+ { library::Provider::kCUBLAS} -+ ) { -+ description_ = " Triangular Matrix-Multiplication. D = alpha * A * B or alpha * B * A"; -+} -+ -+/// Destructor -+TrmmOperationProfiler::~TrmmOperationProfiler() { -+ -+} -+ -+/// Prints usage statement for the math function -+void TrmmOperationProfiler::print_usage(std::ostream &out) const { -+ out << "TRMM" << "\n\n"; -+ -+ OperationProfiler::print_usage(out); -+} -+ -+/// Prints examples -+void TrmmOperationProfiler::print_examples(std::ostream &out) const { -+ -+ out << "\nExamples:\n\n" -+ << "Profile a particular problem size:\n" -+ << " $ cutlass_profiler --operation=Trmm --n=1024 --m=128\n\n" -+ -+ << "Schmoo over problem size and beta:\n" -+ << " $ cutlass_profiler --operation=Trmm --n=1024:4096:256 --m=128:8192:128 --beta=0,1,2.5\n\n" -+ -+ << "Schmoo over accumulator types:\n" -+ << " $ cutlass_profiler --operation=Trmm --accumulator-type=f16,f32\n\n" -+ -+ << "Run when A is f16 with column-major or A is any datatype with row-major (For column major, use column, col, or n. For row major use, row or t):\n" -+ << " $ cutlass_profiler --operation=Trmm --A=f16:column or --A=*:row\n\n" -+ -+ << "Using various input value distribution:\n" -+ << " $ cutlass_profiler --operation=Trmm --dist=uniform,min:0,max:3\n" -+ << " $ cutlass_profiler --operation=Trmm --dist=gaussian,mean:0,stddev:3\n" -+ << " $ cutlass_profiler --operation=Trmm --dist=sequential,start:0,delta:1\n\n" -+ -+ << "Run a kernel with cta tile size of 256x128x32 and save workspace if results are incorrect (note that --cta-tile::k=32 is default cta-tile size):\n" -+ << " $ cutlass_profiler --operation=Trmm --cta_m=256 --cta_n=128 --cta_k=32 --save-workspace=incorrect\n\n" -+ -+ << "Test your changes to trmm kernels with a quick functional test and save results in functional-test.csv:\n" -+ << " $ cutlass_profiler --operation=Trmm \\ \n" -+ << " --n=8,56,120,136,256,264,512,520,1024,1032,4096,8192,16384 \\ \n" -+ << " --k=8,16,32,64,128,256,288,384,504,512,520 \\ \n" -+ << " --beta=0,1,2 --profiling-iterations=1 \\ \n" -+ << " --providers=cutlass --output=functional-test.csv\n\n"; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#if 0 -+// used this for debugging -+static std::string byte_string(std::vector const &bytes) { -+ std::stringstream ss; -+ -+ ss << "0x"; -+ -+ for (size_t idx = bytes.size(); idx > 0; --idx) { -+ ss << std::hex << std::setw(2) << std::setfill('0') << uint32_t(bytes.at(idx - 1)); -+ } -+ -+ return ss.str(); -+} -+#endif -+ -+Status TrmmOperationProfiler::TrmmProblem::parse( -+ library::TrmmDescription const &operation_desc, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ if (!arg_as_int(this->m, "m", problem_space, problem)) { -+ // default value -+ this->m = 1024; -+ } -+ -+ if (!arg_as_int(this->n, "n", problem_space, problem)) { -+ // default value -+ this->n = 1024; -+ } -+ -+ if (!arg_as_int(this->split_k_slices, "split_k_slices", problem_space, problem)) { -+ // default value -+ this->split_k_slices = 1; -+ } -+ -+ if (!arg_as_int(this->batch_count, "batch_count", problem_space, problem)) { -+ // default value -+ this->batch_count = 1; -+ } -+ -+ if (this->split_k_slices > 1 && this->batch_count > 1) { -+ // At least one of these must be one -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.A, "A", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.B, "B", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!tensor_description_satisfies(operation_desc.D, "D", problem_space, problem)) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ if (!arg_as_scalar( -+ this->alpha, -+ operation_desc.element_epilogue, -+ "alpha", -+ problem_space, -+ problem)) { -+ -+ if (!cast_from_double(this->alpha, operation_desc.element_epilogue, 1)) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ if (!arg_as_scalar( -+ this->beta, -+ operation_desc.element_epilogue, -+ "beta", -+ problem_space, -+ problem)) { -+ -+ if (!cast_from_double(this->beta, operation_desc.element_epilogue, 0)) { -+ return Status::kErrorInternal; -+ } -+ } -+ -+ if (operation_desc.side_mode == SideMode::kLeft) { -+ this->lda = DeviceAllocation::get_packed_layout( -+ operation_desc.A.layout, {int(this->m), int(this->m)}).front(); -+ } -+ else if (operation_desc.side_mode == SideMode::kRight) { -+ this->lda = DeviceAllocation::get_packed_layout( -+ operation_desc.A.layout, {int(this->n), int(this->n)}).front(); -+ } -+ -+ this->ldb = DeviceAllocation::get_packed_layout( -+ operation_desc.B.layout, {int(this->m), int(this->n)}).front(); -+ -+ this->ldd = DeviceAllocation::get_packed_layout( -+ operation_desc.D.layout, {int(this->m), int(this->n)}).front(); -+ -+ return Status::kSuccess; -+} -+ -+/// Initializes a performance result -+void TrmmOperationProfiler::TrmmProblem::initialize_result( -+ PerformanceResult &result, -+ library::TrmmDescription const &operation_desc, -+ ProblemSpace const &problem_space) { -+ -+ result.arguments.resize(problem_space.rank()); -+ -+ set_argument(result, "trmm_kind", problem_space, library::to_string(operation_desc.trmm_kind)); -+ -+ set_argument(result, "A", problem_space, -+ std::string(library::to_string(operation_desc.A.element)) + ":" + library::to_string(operation_desc.A.layout)); -+ -+ set_argument(result, "side_mode", problem_space, library::to_string(operation_desc.side_mode)); -+ -+ set_argument(result, "fill_mode", problem_space, library::to_string(operation_desc.fill_mode)); -+ -+ set_argument(result, "diag_type", problem_space, library::to_string(operation_desc.diag_type)); -+ -+ set_argument(result, "B", problem_space, -+ std::string(library::to_string(operation_desc.B.element)) + ":" + library::to_string(operation_desc.B.layout)); -+ -+ set_argument(result, "D", problem_space, -+ std::string(library::to_string(operation_desc.D.element)) + ":" + library::to_string(operation_desc.D.layout)); -+ -+ set_argument(result, "m", problem_space, m); -+ set_argument(result, "n", problem_space, n); -+ -+ set_argument(result, "split_k_slices", problem_space, split_k_slices); -+ set_argument(result, "batch_count", problem_space, batch_count); -+ -+ set_argument(result, "alpha", problem_space, -+ library::lexical_cast(alpha, operation_desc.element_epilogue)); -+ -+ set_argument(result, "beta", problem_space, -+ library::lexical_cast(beta, operation_desc.element_epilogue)); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Extracts the problem dimensions -+Status TrmmOperationProfiler::initialize_configuration( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ library::TrmmDescription const &operation_desc = -+ static_cast(operation->description()); -+ -+ if (operation_desc.trmm_kind != library::TrmmKind::kUniversal) { -+ return Status::kErrorInvalidProblem; -+ } -+ -+ Status status = problem_.parse(operation_desc, problem_space, problem); -+ -+ if (status != Status::kSuccess) { -+ return status; -+ } -+ -+ trmm_workspace_.configuration.problem_size.m() = int(problem_.m); -+ trmm_workspace_.configuration.problem_size.n() = int(problem_.n); -+ trmm_workspace_.configuration.problem_size.k() = (operation_desc.side_mode == SideMode::kLeft) -+ ? int(problem_.m) : int(problem_.n); -+ trmm_workspace_.configuration.lda = problem_.lda; -+ trmm_workspace_.configuration.ldb = problem_.ldb; -+ trmm_workspace_.configuration.ldd = problem_.ldd; -+ //trmm_workspace_.configuration.split_k_slices = int(problem_.split_k_slices); -+ trmm_workspace_.configuration.batch_count = int(problem_.split_k_slices); -+ -+ trmm_workspace_.arguments.A = nullptr; -+ trmm_workspace_.arguments.B = nullptr; -+ trmm_workspace_.arguments.D = nullptr; -+ trmm_workspace_.arguments.alpha = problem_.alpha.data(); -+ trmm_workspace_.arguments.beta = problem_.beta.data(); -+ trmm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ initialize_result_(this->model_result_, options, operation_desc, problem_space); -+ -+ return operation->can_implement(&trmm_workspace_.configuration, &trmm_workspace_.arguments); -+} -+ -+/// Initializes the performance result -+void TrmmOperationProfiler::initialize_result_( -+ PerformanceResult &result, -+ Options const &options, -+ library::TrmmDescription const &operation_desc, -+ ProblemSpace const &problem_space) { -+ -+ result.provider = library::Provider::kCUTLASS; -+ result.disposition = Disposition::kNotRun; -+ result.status = Status::kSuccess; -+ result.operation_name = operation_desc.name; -+ -+ problem_.initialize_result(result, operation_desc, problem_space); -+ -+ OperationProfiler::initialize_result_(result, operation_desc, problem_space); -+ -+ if (operation_desc.side_mode == SideMode::kLeft) { -+ // Input bytes read and Output bytes written for the trmm problem -+ result.bytes = -+ // Half matrix including the diagonal will have (M*(M+1))/2 elements -+ int64_t(library::sizeof_bits(operation_desc.A.element) * problem_.m / 8) * (problem_.m + 1) / 2 + -+ int64_t(library::sizeof_bits(operation_desc.B.element) * problem_.m / 8) * problem_.n + -+ int64_t(library::sizeof_bits(operation_desc.D.element) * problem_.m / 8) * problem_.n; -+ } else if (operation_desc.side_mode == SideMode::kRight) { -+ // Input bytes read and Output bytes written for the trmm problem -+ result.bytes = -+ // Half matrix including the diagonal will have (N*(N+1))/2 elements -+ int64_t(library::sizeof_bits(operation_desc.A.element) * problem_.n / 8) * (problem_.n + 1) / 2 + -+ int64_t(library::sizeof_bits(operation_desc.B.element) * problem_.m / 8) * problem_.n + -+ int64_t(library::sizeof_bits(operation_desc.D.element) * problem_.m / 8) * problem_.n; -+ } -+ -+ // FLOPs = 2 * [ ( M * (M+1)/2 * N ) ] // Beta is zero -+ result.flops = problem_.m * (problem_.m + 1) * problem_.n; -+ -+ result.runtime = 0; -+ -+ // complex-valued support -+ switch (operation_desc.tile_description.math_instruction.math_operation) { -+ case library::MathOperationID::kMultiplyAddComplex: -+ result.flops *= 4; -+ break; -+ -+ case library::MathOperationID::kMultiplyAddComplexFastF32: -+ result.flops *= 4; -+ break; -+ -+ default: break; -+ } -+ -+} -+ -+/// Initializes workspace -+Status TrmmOperationProfiler::initialize_workspace( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ library::TrmmDescription const &operation_desc = -+ static_cast(operation->description()); -+ -+ if (options.execution_mode != ExecutionMode::kDryRun) { -+ -+ if (operation_desc.side_mode == SideMode::kLeft) { -+ trmm_workspace_.A = device_context.allocate_tensor( -+ options, -+ "A", -+ operation_desc.A.element, -+ operation_desc.A.layout, -+ {int(problem_.m), int(problem_.m)}, -+ {int(problem_.lda)}, -+ 1 // batch_count = 1, default -+ ); -+ } else if (operation_desc.side_mode == SideMode::kRight) { -+ trmm_workspace_.A = device_context.allocate_tensor( -+ options, -+ "A", -+ operation_desc.A.element, -+ operation_desc.A.layout, -+ {int(problem_.n), int(problem_.n)}, -+ {int(problem_.lda)}, -+ 1 // batch_count = 1, default -+ ); -+ } -+ -+ trmm_workspace_.B = device_context.allocate_tensor( -+ options, -+ "B", -+ operation_desc.B.element, -+ operation_desc.B.layout, -+ {int(problem_.m), int(problem_.n)}, -+ {int(problem_.ldb)} -+ ); -+ -+ trmm_workspace_.Computed = device_context.allocate_tensor( -+ "D", -+ operation_desc.D.element, -+ operation_desc.D.layout, -+ {int(problem_.m), int(problem_.n)}, -+ {int(problem_.ldd)} -+ ); -+ -+ trmm_workspace_.Reference = device_context.allocate_tensor( -+ "Reference", -+ operation_desc.D.element, -+ operation_desc.D.layout, -+ {int(problem_.m), int(problem_.n)}, -+ {int(problem_.ldd)} -+ ); -+ -+ } -+ -+ // -+ // Initialize the CUTLASS operation -+ // -+ Status status = Status::kSuccess; -+ -+ if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ -+ if (options.execution_mode != ExecutionMode::kDryRun) { -+ -+ uint64_t workspace_size = operation->get_host_workspace_size(&trmm_workspace_.configuration); -+ trmm_workspace_.host_workspace.resize(workspace_size, 0); -+ -+ workspace_size = operation->get_device_workspace_size(&trmm_workspace_.configuration); -+ trmm_workspace_.device_workspace.reset(library::NumericTypeID::kU8, workspace_size); -+ -+ status = operation->initialize( -+ &trmm_workspace_.configuration, -+ trmm_workspace_.host_workspace.data(), -+ trmm_workspace_.device_workspace.data()); -+ } -+ -+ // -+ // If CUTLASS is enabled, generate a result for it -+ // -+ results_.push_back(model_result_); -+ results_.back().provider = library::Provider::kCUTLASS; -+ results_.back().op_kind = library::OperationKind::kTrmm; -+ results_.back().disposition = Disposition::kNotRun; -+ -+ for(auto provider : verification_providers_) { -+ results_.back().verification_map[provider] = Disposition::kNotRun; -+ } -+ } -+ -+ return status; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Verifies CUTLASS against references -+bool TrmmOperationProfiler::verify_cutlass( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ if (!options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ return true; -+ } -+ -+ if (options.execution_mode == ExecutionMode::kDryRun) { -+ return true; -+ } -+ -+ // Initialize structure containing TRMM arguments -+ trmm_workspace_.arguments.A = trmm_workspace_.A->data(); -+ trmm_workspace_.arguments.B = trmm_workspace_.B->data(); -+ trmm_workspace_.arguments.D = trmm_workspace_.Computed->data(); -+ trmm_workspace_.arguments.alpha = problem_.alpha.data(); -+ trmm_workspace_.arguments.beta = problem_.beta.data(); -+ trmm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ // -+ // Run the CUTLASS operation -+ // -+ -+ results_.back().status = operation->run( -+ &trmm_workspace_.arguments, -+ trmm_workspace_.host_workspace.data(), -+ trmm_workspace_.device_workspace.data()); -+ -+ if (results_.back().status != Status::kSuccess) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ -+ cudaError_t result = cudaDeviceSynchronize(); -+ if (result != cudaSuccess) { -+ results_.back().disposition = Disposition::kFailed; -+ return false; -+ } -+ -+ // CUTLASS op ran the but not yet verified against any verification provider -+ results_.back().disposition = Disposition::kNotVerified; -+ -+ // -+ // Run verification providers -+ // -+ -+ if (options.verification.enabled) { -+ -+#if CUTLASS_ENABLE_CUBLAS -+ if (options.verification.provider_enabled(library::Provider::kCUBLAS)) { -+ -+ // Guard against unsupported cases -+ auto const & trmm_desc = static_cast(operation->description()); -+ -+ if (cublas_satisfies(trmm_desc) == Status::kSuccess) { -+ -+ // call cublas verification if supported -+ verify_with_cublas_( -+ options, -+ report, -+ device_context, -+ operation, -+ problem_space, -+ problem); -+ } -+ -+ else { -+ // set verification map for cublas to not supported -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kNotSupported; -+ } -+ } -+#endif // #if CUTLASS_ENABLE_CUBLAS -+ -+ // Update disposition to worst case verification outcome among all -+ // verification providers which are supported -+ bool is_any_verification_run_passed = false; -+ for(auto &m : results_.back().verification_map) { -+ if(m.second == Disposition::kFailed || m.second == Disposition::kIncorrect) { -+ results_.back().disposition = m.second; -+ return true; -+ } -+ if(!is_any_verification_run_passed && m.second == Disposition::kPassed) { -+ is_any_verification_run_passed = true; -+ } -+ } -+ -+ if(is_any_verification_run_passed) { -+ results_.back().disposition = Disposition::kPassed; -+ } -+ } -+ -+ // Return true means continue profiling -+ return true; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Verifies CUTLASS against references -+bool TrmmOperationProfiler::verify_with_cublas_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ -+#if CUTLASS_ENABLE_CUBLAS -+ -+ library::TrmmDescription const &trmm_desc = -+ static_cast(operation->description()); -+ -+ // -+ // Construct cuBLAS operators -+ // -+ -+ CublasCreate handle; -+ cublasStatus_t status = handle.get_cublas_create_status(); -+ -+ if (status != CUBLAS_STATUS_SUCCESS) { -+ -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed; -+ return true; -+ } -+ -+ // -+ // Initialize state -+ // -+ -+ try { -+ -+ // -+ // Construct dispatcher to cublasTrmm() -+ // -+ -+ // Initialize structure containing TRMM arguments -+ trmm_workspace_.arguments.A = trmm_workspace_.A->data(); -+ trmm_workspace_.arguments.B = trmm_workspace_.B->data(); -+ trmm_workspace_.arguments.D = trmm_workspace_.Reference->data(); -+ trmm_workspace_.arguments.alpha = problem_.alpha.data(); -+ trmm_workspace_.arguments.beta = problem_.beta.data(); -+ trmm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ detail::cublasTrmmDispatcher trmm_op( -+ trmm_desc, -+ trmm_workspace_.configuration, -+ trmm_workspace_.arguments -+ ); -+ -+ if (trmm_op.status != Status::kSuccess) { -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kNotRun; -+ return true; -+ } -+ -+ results_.back().status = Status::kSuccess; -+ -+ status = trmm_op(handle); -+ -+ // Handle errors -+ if (status != CUBLAS_STATUS_SUCCESS) { -+ -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed; -+ return true; -+ } -+ -+ // -+ // Verify results -+ // -+ results_.back().verification_map[library::Provider::kCUBLAS] = compare_tensors( -+ options, -+ *trmm_workspace_.Computed, -+ *trmm_workspace_.Reference -+ ); -+ -+ // Save workspace if incorrect -+ if (options.verification.save_workspace == SaveWorkspace::kIncorrect && -+ results_.back().verification_map[library::Provider::kCUBLAS] == Disposition::kIncorrect) { -+ -+ save_workspace( -+ device_context, -+ options, -+ trmm_desc, -+ library::Provider::kCUTLASS, -+ library::Provider::kCUBLAS); -+ } -+ } -+ catch (...) { -+ results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed; -+ } -+ -+#endif -+ -+ // Return true means continue profiling -+ return true; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Measures performance results -+bool TrmmOperationProfiler::profile( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem) { -+ -+ if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { -+ -+ // Initialize structure containing TRMM arguments -+ trmm_workspace_.arguments.A = trmm_workspace_.A->data(); -+ trmm_workspace_.arguments.B = trmm_workspace_.B->data(); -+ trmm_workspace_.arguments.D = trmm_workspace_.Computed->data(); -+ trmm_workspace_.arguments.alpha = problem_.alpha.data(); -+ trmm_workspace_.arguments.beta = problem_.beta.data(); -+ trmm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; -+ -+ results_.back().status = profile_cutlass_( -+ results_.back().runtime, -+ options, -+ operation, -+ &trmm_workspace_.arguments, -+ trmm_workspace_.host_workspace.data(), -+ trmm_workspace_.device_workspace.data() -+ ); -+ } -+ return true; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/profiler/src/trmm_operation_profiler.h b/3rdparty/cutlass/tools/profiler/src/trmm_operation_profiler.h -new file mode 100644 -index 0000000..32ebcda ---- /dev/null -+++ b/3rdparty/cutlass/tools/profiler/src/trmm_operation_profiler.h -@@ -0,0 +1,222 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines a math function -+ -+ -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+#include -+#include -+ -+// CUTLASS Library includes -+#include "cutlass/blas3.h" -+#include "cutlass/library/library.h" -+#include "cutlass/library/util.h" -+#include "cutlass/library/manifest.h" -+ -+// Profiler includes -+#include "options.h" -+#include "device_context.h" -+#include "operation_profiler.h" -+#include "performance_result.h" -+#include "problem_space.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace profiler { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Abstract base class for each math function -+class TrmmOperationProfiler : public OperationProfiler { -+public: -+ -+ /// Problem structure obtained from problem space -+ struct TrmmProblem { -+ int64_t m; -+ int64_t n; -+ int64_t lda; -+ int64_t ldb; -+ int64_t ldd; -+ SideMode side_mode; -+ FillMode fill_mode; -+ DiagType diag_type; -+ std::vector alpha; -+ std::vector beta; -+ int64_t split_k_slices; -+ int64_t batch_count; -+ -+ // -+ // Methods -+ // -+ -+ TrmmProblem(): -+ m(16), n(16), lda(0), ldb(0), ldd(0), split_k_slices(1), batch_count(1) { } -+ -+ /// Parses the problem -+ Status parse( -+ library::TrmmDescription const &operation_desc, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Initializes a performance result -+ void initialize_result( -+ PerformanceResult &result, -+ library::TrmmDescription const &operation_desc, -+ ProblemSpace const &problem_space); -+ }; -+ -+ /// Workspace used -+ struct TrmmWorkspace { -+ -+ DeviceAllocation *A; -+ DeviceAllocation *B; -+ DeviceAllocation *D; -+ DeviceAllocation *Computed; -+ DeviceAllocation *Reference; -+ -+ library::TrmmConfiguration configuration; -+ library::TrmmArguments arguments; -+ -+ /// Buffer used for the operation's host workspace -+ std::vector host_workspace; -+ -+ /// Buffer used for the operations' device workspace -+ DeviceAllocation device_workspace; -+ -+ // -+ // Methods -+ // -+ -+ TrmmWorkspace(): -+ A(nullptr), B(nullptr), D(nullptr), Computed(nullptr), Reference(nullptr) { } -+ }; -+ -+protected: -+ -+ // -+ // Data members -+ // -+ -+ /// GEMM problem obtained from problem space -+ TrmmProblem problem_; -+ -+ /// Device memory allocations -+ TrmmWorkspace trmm_workspace_; -+ -+ -+public: -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ TrmmOperationProfiler(Options const &options); -+ -+ /// Destructor -+ virtual ~TrmmOperationProfiler(); -+ -+ /// Prints usage statement for the math function -+ virtual void print_usage(std::ostream &out) const; -+ -+ /// Prints examples -+ virtual void print_examples(std::ostream &out) const; -+ -+ /// Extracts the problem dimensions -+ virtual Status initialize_configuration( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Initializes workspace -+ virtual Status initialize_workspace( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Verifies CUTLASS against references -+ virtual bool verify_cutlass( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+ /// Measures performance results -+ virtual bool profile( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+protected: -+ -+ /// Initializes the performance result -+ void initialize_result_( -+ PerformanceResult &result, -+ Options const &options, -+ library::TrmmDescription const &operation_desc, -+ ProblemSpace const &problem_space); -+ -+ /// Verifies CUTLASS against references -+ bool verify_with_cublas_( -+ Options const &options, -+ PerformanceReport &report, -+ DeviceContext &device_context, -+ library::Operation const *operation, -+ ProblemSpace const &problem_space, -+ ProblemSpace::Problem const &problem); -+ -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace profiler -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/GPU_Clock.hpp b/3rdparty/cutlass/tools/util/include/cutlass/util/GPU_Clock.hpp -new file mode 100644 -index 0000000..5f2dd4b ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/GPU_Clock.hpp -@@ -0,0 +1,67 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include -+ -+struct GPU_Clock -+{ -+ GPU_Clock() { -+ cudaEventCreate(&start_); -+ cudaEventCreate(&stop_); -+ cudaEventRecord(start_); -+ } -+ -+ ~GPU_Clock() { -+ cudaEventDestroy(start_); -+ cudaEventDestroy(stop_); -+ } -+ -+ void start() { -+ cudaEventRecord(start_); -+ } -+ -+ float milliseconds() { -+ cudaEventRecord(stop_); -+ cudaEventSynchronize(stop_); -+ float time; -+ cudaEventElapsedTime(&time, start_, stop_); -+ return time; -+ } -+ -+ float seconds() { -+ return milliseconds() * float(1e-3); -+ } -+ -+ private: -+ cudaEvent_t start_, stop_; -+}; -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/command_line.h b/3rdparty/cutlass/tools/util/include/cutlass/util/command_line.h -new file mode 100644 -index 0000000..65cf9a1 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/command_line.h -@@ -0,0 +1,313 @@ -+/****************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ ******************************************************************************/ -+ -+#pragma once -+ -+/** -+ * \file -+ * Utility for parsing command line arguments -+ */ -+ -+#include -+#include -+#include -+#include -+#include -+ -+#include -+ -+#include "cutlass/cutlass.h" -+ -+namespace cutlass { -+ -+/****************************************************************************** -+ * command_line -+ ******************************************************************************/ -+ -+/** -+ * Utility for parsing command line arguments -+ */ -+struct CommandLine { -+ std::vector keys; -+ std::vector values; -+ std::vector args; -+ -+ /** -+ * Constructor -+ */ -+ CommandLine(int argc, const char** argv) { -+ using namespace std; -+ -+ for (int i = 1; i < argc; i++) { -+ string arg = argv[i]; -+ -+ if ((arg[0] != '-') || (arg[1] != '-')) { -+ args.push_back(arg); -+ continue; -+ } -+ -+ string::size_type pos; -+ string key, val; -+ if ((pos = arg.find('=')) == string::npos) { -+ key = string(arg, 2, arg.length() - 2); -+ val = ""; -+ } else { -+ key = string(arg, 2, pos - 2); -+ val = string(arg, pos + 1, arg.length() - 1); -+ } -+ -+ keys.push_back(key); -+ values.push_back(val); -+ } -+ } -+ -+ /** -+ * Checks whether a flag "--" is present in the commandline -+ */ -+ bool check_cmd_line_flag(const char* arg_name) const { -+ using namespace std; -+ -+ for (int i = 0; i < int(keys.size()); ++i) { -+ if (keys[i] == string(arg_name)) return true; -+ } -+ return false; -+ } -+ -+ /** -+ * Returns number of naked (non-flag and non-key-value) commandline parameters -+ */ -+ size_t num_naked_args() const { -+ return args.size(); -+ } -+ -+ /** -+ * Print naked (non-flag and non-key-value) commandline parameters -+ */ -+ void print_naked_args(std::ostream &out) const { -+ for (auto arg : args) { -+ out << " " << arg <<"\n"; -+ } -+ } -+ -+ /** -+ * Returns the commandline parameter for a given index (not including flags) -+ */ -+ template -+ void get_cmd_line_argument(int index, value_t& val) const { -+ using namespace std; -+ if (index < args.size()) { -+ istringstream str_stream(args[index]); -+ str_stream >> val; -+ } -+ } -+ -+ /** -+ * Obtains the boolean value specified for a given commandline parameter --= -+ */ -+ void get_cmd_line_argument(const char* arg_name, bool& val, bool _default) const { -+ val = _default; -+ if (check_cmd_line_flag(arg_name)) { -+ std::string value; -+ get_cmd_line_argument(arg_name, value); -+ -+ val = !(value == "0" || value == "false"); -+ } -+ } -+ -+ /** -+ * Obtains the value specified for a given commandline parameter --= -+ */ -+ template -+ void get_cmd_line_argument(const char* arg_name, -+ value_t& val) const { -+ -+ get_cmd_line_argument(arg_name, val, val); -+ } -+ -+ /** -+ * Obtains the value specified for a given commandline parameter --= -+ */ -+ template -+ void get_cmd_line_argument(const char* arg_name, -+ value_t& val, -+ value_t const& _default) const { -+ using namespace std; -+ -+ val = _default; -+ -+ for (int i = 0; i < int(keys.size()); ++i) { -+ if (keys[i] == string(arg_name)) { -+ istringstream str_stream(values[i]); -+ str_stream >> val; -+ } -+ } -+ } -+ -+ /** -+ * Returns the values specified for a given commandline parameter --=,* -+ */ -+ template -+ void get_cmd_line_arguments(const char* arg_name, -+ std::vector& vals, -+ char sep = ',') const { -+ using namespace std; -+ -+ if (check_cmd_line_flag(arg_name)) { -+ // Clear any default values -+ vals.clear(); -+ -+ // Recover from multi-value string -+ for (int i = 0; i < keys.size(); ++i) { -+ if (keys[i] == string(arg_name)) { -+ string val_string(values[i]); -+ seperate_string(val_string, vals, sep); -+ } -+ } -+ } -+ } -+ -+ /** -+ * Returns the values specified for a given commandline parameter -+ * --=,* -+ */ -+ void get_cmd_line_argument_pairs(const char* arg_name, -+ std::vector >& tokens, -+ char delim = ',', -+ char sep = ':') const { -+ if (check_cmd_line_flag(arg_name)) { -+ std::string value; -+ get_cmd_line_argument(arg_name, value); -+ -+ tokenize(tokens, value, delim, sep); -+ } -+ } -+ -+ /** -+ * Returns a list of ranges specified for a given commandline parameter -+ * --=,* -+ */ -+ void get_cmd_line_argument_ranges(const char* arg_name, -+ std::vector >& vals, -+ char delim = ',', -+ char sep = ':') const { -+ std::vector ranges; -+ get_cmd_line_arguments(arg_name, ranges, delim); -+ -+ for (std::vector::const_iterator range = ranges.begin(); -+ range != ranges.end(); ++range) { -+ -+ std::vector range_vals; -+ seperate_string(*range, range_vals, sep); -+ vals.push_back(range_vals); -+ } -+ } -+ -+ /** -+ * The number of pairs parsed -+ */ -+ int parsed_argc() const { return (int)keys.size(); } -+ -+ //------------------------------------------------------------------------- -+ // Utility functions -+ //------------------------------------------------------------------------- -+ -+ /// Tokenizes a comma-delimited list of string pairs delimited by ':' -+ static void tokenize(std::vector >& tokens, -+ std::string const& str, -+ char delim = ',', -+ char sep = ':') { -+ // Home-built to avoid Boost dependency -+ size_t s_idx = 0; -+ size_t d_idx = std::string::npos; -+ while (s_idx < str.size()) { -+ d_idx = str.find_first_of(delim, s_idx); -+ -+ size_t end_idx = (d_idx != std::string::npos ? d_idx : str.size()); -+ size_t sep_idx = str.find_first_of(sep, s_idx); -+ size_t offset = 1; -+ if (sep_idx == std::string::npos || sep_idx >= end_idx) { -+ sep_idx = end_idx; -+ offset = 0; -+ } -+ -+ std::pair item( -+ str.substr(s_idx, sep_idx - s_idx), -+ str.substr(sep_idx + offset, end_idx - sep_idx - offset)); -+ -+ tokens.push_back(item); -+ s_idx = end_idx + 1; -+ } -+ } -+ -+ /// Tokenizes a comma-delimited list of string pairs delimited by ':' -+ static void tokenize(std::vector& tokens, -+ std::string const& str, -+ char delim = ',', -+ char sep = ':') { -+ typedef std::vector > TokenVector; -+ typedef TokenVector::const_iterator token_iterator; -+ -+ std::vector > token_pairs; -+ tokenize(token_pairs, str, delim, sep); -+ for (token_iterator tok = token_pairs.begin(); tok != token_pairs.end(); ++tok) { -+ tokens.push_back(tok->first); -+ } -+ } -+ -+ template -+ static void seperate_string(std::string const& str, -+ std::vector& vals, -+ char sep = ',') { -+ std::istringstream str_stream(str); -+ std::string::size_type old_pos = 0; -+ std::string::size_type new_pos = 0; -+ -+ // Iterate -delimited values -+ value_t val; -+ while ((new_pos = str.find(sep, old_pos)) != std::string::npos) { -+ if (new_pos != old_pos) { -+ str_stream.width(new_pos - old_pos); -+ str_stream >> val; -+ vals.push_back(val); -+ } -+ -+ // skip over delimiter -+ str_stream.ignore(1); -+ old_pos = new_pos + 1; -+ } -+ -+ // Read last value -+ str_stream >> val; -+ vals.push_back(val); -+ } -+}; -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/cublas_wrappers.hpp b/3rdparty/cutlass/tools/util/include/cutlass/util/cublas_wrappers.hpp -new file mode 100644 -index 0000000..82d56fa ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/cublas_wrappers.hpp -@@ -0,0 +1,526 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include -+#include -+ -+//-- BLAM_DEBUG_OUT --------------------------------------------------------- -+#ifdef BLAM_DEBUG -+# include -+# ifndef BLAM_DEBUG_OUT -+# define BLAM_DEBUG_OUT(msg) std::cerr << "BLAM: " << msg << std::endl -+# define BLAM_DEBUG_OUT_2(msg) std::cerr << msg << std::endl -+# endif // BLAM_DEBUG_OUT -+#else -+# ifndef BLAM_DEBUG_OUT -+# define BLAM_DEBUG_OUT(msg) -+# define BLAM_DEBUG_OUT_2(msg) -+# endif // BLAM_DEBUG_OUT -+#endif // BLAM_DEBUG -+ -+// User could potentially define ComplexFloat/ComplexDouble instead of std:: -+#ifndef BLAM_COMPLEX_TYPES -+#define BLAM_COMPLEX_TYPES 1 -+#include -+namespace blam { -+template -+using Complex = cuda::std::complex; -+using ComplexFloat = cuda::std::complex; -+using ComplexDouble = cuda::std::complex; -+} -+#endif // BLAM_COMPLEX_TYPES -+ -+// User could potentially define Half instead of cute:: -+#ifndef BLAM_HALF_TYPE -+#define BLAM_HALF_TYPE 1 -+#include -+namespace blam { -+using Half = cute::half_t; -+} -+#endif // BLAM_HALF_TYPE -+ -+namespace blam -+{ -+namespace cublas -+{ -+ -+inline const char* -+cublas_get_error(cublasStatus_t status) -+{ -+ switch (status) { -+ case CUBLAS_STATUS_SUCCESS: -+ return "CUBLAS_STATUS_SUCCESS"; -+ case CUBLAS_STATUS_NOT_INITIALIZED: -+ return "CUBLAS_STATUS_NOT_INITIALIZED -- The cuBLAS library was not initialized."; -+ case CUBLAS_STATUS_ALLOC_FAILED: -+ return "CUBLAS_STATUS_ALLOC_FAILED -- Resource allocation failed inside the cuBLAS library."; -+ case CUBLAS_STATUS_INVALID_VALUE: -+ return "CUBLAS_STATUS_INVALID_VALUE -- An unsupported value or parameter was passed to the function."; -+ case CUBLAS_STATUS_ARCH_MISMATCH: -+ return "CUBLAS_STATUS_ARCH_MISMATCH -- The function requires a feature absent from the device architecture."; -+ case CUBLAS_STATUS_MAPPING_ERROR: -+ return "CUBLAS_STATUS_MAPPING_ERROR -- An access to GPU memory space failed."; -+ case CUBLAS_STATUS_EXECUTION_FAILED: -+ return "CUBLAS_STATUS_EXECUTION_FAILED -- The GPU program failed to execute."; -+ case CUBLAS_STATUS_INTERNAL_ERROR: -+ return "CUBLAS_STATUS_INTERNAL_ERROR -- An internal cuBLAS operation failed."; -+ case CUBLAS_STATUS_NOT_SUPPORTED: -+ return "CUBLAS_STATUS_NOT_SUPPORTED -- The functionality requested is not supported."; -+ case CUBLAS_STATUS_LICENSE_ERROR: -+ return "CUBLAS_STATUS_LICENSE_ERROR -- An error was detected when checking the current licensing."; -+ default: -+ return "CUBLAS_ERROR -- "; -+ } -+} -+ -+inline bool -+cublas_is_error(cublasStatus_t status) -+{ -+ return status != CUBLAS_STATUS_SUCCESS; -+} -+ -+ -+// hgemm -+inline cublasStatus_t -+gemm(cublasHandle_t handle, -+ cublasOperation_t transA, cublasOperation_t transB, -+ int m, int n, int k, -+ const Half* alpha, -+ const Half* A, int ldA, -+ const Half* B, int ldB, -+ const Half* beta, -+ Half* C, int ldC) -+{ -+ BLAM_DEBUG_OUT("cublasHgemm"); -+ -+ return cublasGemmEx(handle, transA, transB, -+ m, n, k, -+ reinterpret_cast(alpha), -+ reinterpret_cast(A), CUDA_R_16F, ldA, -+ reinterpret_cast(B), CUDA_R_16F, ldB, -+ reinterpret_cast(beta), -+ reinterpret_cast< __half*>(C), CUDA_R_16F, ldC, -+ CUDA_R_16F, CUBLAS_GEMM_DEFAULT_TENSOR_OP); -+} -+ -+// mixed hf gemm -+inline cublasStatus_t -+gemm(cublasHandle_t handle, -+ cublasOperation_t transA, cublasOperation_t transB, -+ int m, int n, int k, -+ const float* alpha, -+ const Half* A, int ldA, -+ const Half* B, int ldB, -+ const float* beta, -+ float* C, int ldC) -+{ -+ BLAM_DEBUG_OUT("cublasGemmEx mixed half-float"); -+ -+ return cublasGemmEx(handle, transA, transB, -+ m, n, k, -+ alpha, -+ reinterpret_cast(A), CUDA_R_16F, ldA, -+ reinterpret_cast(B), CUDA_R_16F, ldB, -+ beta, -+ C, CUDA_R_32F, ldC, -+ CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP); -+} -+ -+// igemm -+inline cublasStatus_t -+gemm(cublasHandle_t handle, -+ cublasOperation_t transA, cublasOperation_t transB, -+ int m, int n, int k, -+ const int32_t* alpha, -+ const int8_t* A, int ldA, -+ const int8_t* B, int ldB, -+ const int32_t* beta, -+ int32_t* C, int ldC) -+{ -+ BLAM_DEBUG_OUT("cublasIgemm"); -+ -+ return cublasGemmEx(handle, transA, transB, -+ m, n, k, -+ alpha, -+ A, CUDA_R_8I, ldA, -+ B, CUDA_R_8I, ldB, -+ beta, -+ C, CUDA_R_32I, ldC, -+ CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP); -+} -+ -+// sgemm -+inline cublasStatus_t -+gemm(cublasHandle_t handle, -+ cublasOperation_t transA, cublasOperation_t transB, -+ int m, int n, int k, -+ const float* alpha, -+ const float* A, int ldA, -+ const float* B, int ldB, -+ const float* beta, -+ float* C, int ldC) -+{ -+ BLAM_DEBUG_OUT("cublasSgemm"); -+ -+ return cublasSgemm(handle, transA, transB, -+ m, n, k, -+ alpha, -+ A, ldA, -+ B, ldB, -+ beta, -+ C, ldC); -+} -+ -+// dgemm -+inline cublasStatus_t -+gemm(cublasHandle_t handle, -+ cublasOperation_t transA, cublasOperation_t transB, -+ int m, int n, int k, -+ const double* alpha, -+ const double* A, int ldA, -+ const double* B, int ldB, -+ const double* beta, -+ double* C, int ldC) -+{ -+ BLAM_DEBUG_OUT("cublasDgemm"); -+ -+ return cublasDgemm(handle, transA, transB, -+ m, n, k, -+ alpha, -+ A, ldA, -+ B, ldB, -+ beta, -+ C, ldC); -+} -+ -+// cgemm -+inline cublasStatus_t -+gemm(cublasHandle_t handle, -+ cublasOperation_t transA, cublasOperation_t transB, -+ int m, int n, int k, -+ const ComplexFloat* alpha, -+ const ComplexFloat* A, int ldA, -+ const ComplexFloat* B, int ldB, -+ const ComplexFloat* beta, -+ ComplexFloat* C, int ldC) -+{ -+ BLAM_DEBUG_OUT("cublasCgemm"); -+ -+ return cublasCgemm(handle, transA, transB, -+ m, n, k, -+ reinterpret_cast(alpha), -+ reinterpret_cast(A), ldA, -+ reinterpret_cast(B), ldB, -+ reinterpret_cast(beta), -+ reinterpret_cast(C), ldC); -+} -+ -+// zgemm -+inline cublasStatus_t -+gemm(cublasHandle_t handle, -+ cublasOperation_t transA, cublasOperation_t transB, -+ int m, int n, int k, -+ const ComplexDouble* alpha, -+ const ComplexDouble* A, int ldA, -+ const ComplexDouble* B, int ldB, -+ const ComplexDouble* beta, -+ ComplexDouble* C, int ldC) -+{ -+ BLAM_DEBUG_OUT("cublasZgemm"); -+ -+ return cublasZgemm(handle, transA, transB, -+ m, n, k, -+ reinterpret_cast(alpha), -+ reinterpret_cast(A), ldA, -+ reinterpret_cast(B), ldB, -+ reinterpret_cast(beta), -+ reinterpret_cast(C), ldC); -+} -+ -+// hgemm -+inline cublasStatus_t -+gemm_batch(cublasHandle_t handle, -+ cublasOperation_t transA, cublasOperation_t transB, -+ int m, int n, int k, -+ const Half* alpha, -+ const Half* A, int ldA, int loA, -+ const Half* B, int ldB, int loB, -+ const Half* beta, -+ Half* C, int ldC, int loC, -+ int batch_size) -+{ -+ BLAM_DEBUG_OUT("cublasHgemmStridedBatched"); -+ -+ return cublasHgemmStridedBatched(handle, transA, transB, -+ m, n, k, -+ reinterpret_cast(alpha), -+ reinterpret_cast(A), ldA, loA, -+ reinterpret_cast(B), ldB, loB, -+ reinterpret_cast(beta), -+ reinterpret_cast<__half*>(C), ldC, loC, -+ batch_size); -+} -+ -+// sgemm -+inline cublasStatus_t -+gemm_batch(cublasHandle_t handle, -+ cublasOperation_t transA, cublasOperation_t transB, -+ int m, int n, int k, -+ const float* alpha, -+ const float* A, int ldA, int loA, -+ const float* B, int ldB, int loB, -+ const float* beta, -+ float* C, int ldC, int loC, -+ int batch_size) -+{ -+ BLAM_DEBUG_OUT("cublasSgemmStridedBatched"); -+ -+ return cublasSgemmStridedBatched(handle, transA, transB, -+ m, n, k, -+ alpha, -+ A, ldA, loA, -+ B, ldB, loB, -+ beta, -+ C, ldC, loC, -+ batch_size); -+} -+ -+// dgemm -+inline cublasStatus_t -+gemm_batch(cublasHandle_t handle, -+ cublasOperation_t transA, cublasOperation_t transB, -+ int m, int n, int k, -+ const double* alpha, -+ const double* A, int ldA, int loA, -+ const double* B, int ldB, int loB, -+ const double* beta, -+ double* C, int ldC, int loC, -+ int batch_size) -+{ -+ BLAM_DEBUG_OUT("cublasDgemmStridedBatched"); -+ -+ return cublasDgemmStridedBatched(handle, transA, transB, -+ m, n, k, -+ alpha, -+ A, ldA, loA, -+ B, ldB, loB, -+ beta, -+ C, ldC, loC, -+ batch_size); -+} -+ -+// cgemm -+inline cublasStatus_t -+gemm_batch(cublasHandle_t handle, -+ cublasOperation_t transA, cublasOperation_t transB, -+ int m, int n, int k, -+ const ComplexFloat* alpha, -+ const ComplexFloat* A, int ldA, int loA, -+ const ComplexFloat* B, int ldB, int loB, -+ const ComplexFloat* beta, -+ ComplexFloat* C, int ldC, int loC, -+ int batch_size) -+{ -+ BLAM_DEBUG_OUT("cublasCgemmStridedBatched"); -+ -+ return cublasCgemmStridedBatched(handle, transA, transB, -+ m, n, k, -+ reinterpret_cast(alpha), -+ reinterpret_cast(A), ldA, loA, -+ reinterpret_cast(B), ldB, loB, -+ reinterpret_cast(beta), -+ reinterpret_cast(C), ldC, loC, -+ batch_size); -+} -+ -+// zgemm -+inline cublasStatus_t -+gemm_batch(cublasHandle_t handle, -+ cublasOperation_t transA, cublasOperation_t transB, -+ int m, int n, int k, -+ const ComplexDouble* alpha, -+ const ComplexDouble* A, int ldA, int loA, -+ const ComplexDouble* B, int ldB, int loB, -+ const ComplexDouble* beta, -+ ComplexDouble* C, int ldC, int loC, -+ int batch_size) -+{ -+ BLAM_DEBUG_OUT("cublasZgemmStridedBatched"); -+ -+ return cublasZgemmStridedBatched(handle, transA, transB, -+ m, n, k, -+ reinterpret_cast(alpha), -+ reinterpret_cast(A), ldA, loA, -+ reinterpret_cast(B), ldB, loB, -+ reinterpret_cast(beta), -+ reinterpret_cast(C), ldC, loC, -+ batch_size); -+} -+ -+// hgemm -+inline cublasStatus_t -+gemm_batch(cublasHandle_t handle, -+ cublasOperation_t transA, cublasOperation_t transB, -+ int m, int n, int k, -+ const Half* alpha, -+ const Half* const A[], int ldA, -+ const Half* const B[], int ldB, -+ const Half* beta, -+ Half* const C[], int ldC, -+ int batch_size) -+{ -+ BLAM_DEBUG_OUT("cublasHgemmBatched"); -+ -+ return cublasHgemmBatched(handle, transA, transB, -+ m, n, k, -+ reinterpret_cast(alpha), -+ reinterpret_cast(const_cast(A)), ldA, -+ // A, ldA, // cuBLAS 9.2 -+ reinterpret_cast(const_cast(B)), ldB, -+ // B, ldB, // cuBLAS 9.2 -+ reinterpret_cast(beta), -+ reinterpret_cast<__half**>(const_cast(C)), ldC, -+ // C, ldC, // cuBLAS 9.2 -+ batch_size); -+} -+ -+// sgemm -+inline cublasStatus_t -+gemm_batch(cublasHandle_t handle, -+ cublasOperation_t transA, cublasOperation_t transB, -+ int m, int n, int k, -+ const float* alpha, -+ const float* const A[], int ldA, -+ const float* const B[], int ldB, -+ const float* beta, -+ float* const C[], int ldC, -+ int batch_size) -+{ -+ BLAM_DEBUG_OUT("cublasSgemmBatched"); -+ -+ return cublasSgemmBatched(handle, transA, transB, -+ m, n, k, -+ alpha, -+ const_cast(A), ldA, -+ // A, ldA, // cuBLAS 9.2 -+ const_cast(B), ldB, -+ // B, ldB, // cuBLAS 9.2 -+ beta, -+ const_cast(C), ldC, -+ // C, ldC, // cuBLAS 9.2 -+ batch_size); -+} -+ -+// dgemm -+inline cublasStatus_t -+gemm_batch(cublasHandle_t handle, -+ cublasOperation_t transA, cublasOperation_t transB, -+ int m, int n, int k, -+ const double* alpha, -+ const double* const A[], int ldA, -+ const double* const B[], int ldB, -+ const double* beta, -+ double* const C[], int ldC, -+ int batch_size) -+{ -+ BLAM_DEBUG_OUT("cublasDgemmBatched"); -+ -+ return cublasDgemmBatched(handle, transA, transB, -+ m, n, k, -+ alpha, -+ const_cast(A), ldA, -+ // A, ldA, // cuBLAS 9.2 -+ const_cast(B), ldB, -+ // B, ldB, // cuBLAS 9.2 -+ beta, -+ const_cast(C), ldC, -+ // C, ldC, // cuBLAS 9.2 -+ batch_size); -+} -+ -+// cgemm -+inline cublasStatus_t -+gemm_batch(cublasHandle_t handle, -+ cublasOperation_t transA, cublasOperation_t transB, -+ int m, int n, int k, -+ const ComplexFloat* alpha, -+ const ComplexFloat* const A[], int ldA, -+ const ComplexFloat* const B[], int ldB, -+ const ComplexFloat* beta, -+ ComplexFloat* const C[], int ldC, -+ int batch_size) -+{ -+ BLAM_DEBUG_OUT("cublasCgemmBatched"); -+ -+ return cublasCgemmBatched(handle, transA, transB, -+ m, n, k, -+ reinterpret_cast(alpha), -+ const_cast(reinterpret_cast(A)), ldA, -+ //reinterpret_cast(A), ldA, // cuBLAS 9.2 -+ const_cast(reinterpret_cast(B)), ldB, -+ //reinterpret_cast(B), ldB, // cuBLAS 9.2 -+ reinterpret_cast(beta), -+ const_cast(reinterpret_cast(C)), ldC, -+ //reinterpret_cast(C), ldC, // cuBLAS 9.2 -+ batch_size); -+} -+ -+// zgemm -+inline cublasStatus_t -+gemm_batch(cublasHandle_t handle, -+ cublasOperation_t transA, cublasOperation_t transB, -+ int m, int n, int k, -+ const ComplexDouble* alpha, -+ const ComplexDouble* const A[], int ldA, -+ const ComplexDouble* const B[], int ldB, -+ const ComplexDouble* beta, -+ ComplexDouble* const C[], int ldC, -+ int batch_size) -+{ -+ BLAM_DEBUG_OUT("cublasZgemmBatched"); -+ -+ return cublasZgemmBatched(handle, transA, transB, -+ m, n, k, -+ reinterpret_cast(alpha), -+ const_cast(reinterpret_cast(A)), ldA, -+ //reinterpret_cast(A), ldA, // cuBLAS 9.2 -+ const_cast(reinterpret_cast(B)), ldB, -+ //reinterpret_cast(B), ldB, // cuBLAS 9.2 -+ reinterpret_cast(beta), -+ const_cast(reinterpret_cast(C)), ldC, -+ //reinterpret_cast(C), ldC, // cuBLAS 9.2 -+ batch_size); -+} -+ -+} // end namespace cublas -+} // end namespace blam -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/debug.h b/3rdparty/cutlass/tools/util/include/cutlass/util/debug.h -new file mode 100644 -index 0000000..3a2480c ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/debug.h -@@ -0,0 +1,143 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Contains code for debugging cutlass code -+*/ -+ -+#pragma once -+ -+#include "device_dump.h" -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/****************************************************************************** -+ * Debug and logging macros -+ ******************************************************************************/ -+ -+/** -+ * Formats and prints the given message to stdout -+ */ -+#if !defined(CUDA_LOG) -+#if !defined(__CUDA_ARCH__) -+#define CUDA_LOG(format, ...) printf(format, __VA_ARGS__) -+#else -+#define CUDA_LOG(format, ...) \ -+ printf("[block (%d,%d,%d), thread (%d,%d,%d)]: " format, \ -+ blockIdx.x, \ -+ blockIdx.y, \ -+ blockIdx.z, \ -+ threadIdx.x, \ -+ threadIdx.y, \ -+ threadIdx.z, \ -+ __VA_ARGS__); -+#endif -+#endif -+ -+/** -+ * Formats and prints the given message to stdout only if DEBUG is defined -+ */ -+#if !defined(CUDA_LOG_DEBUG) -+#ifdef DEBUG -+#define CUDA_LOG_DEBUG(format, ...) CUDA_LOG(format, __VA_ARGS__) -+#else -+#define CUDA_LOG_DEBUG(format, ...) -+#endif -+#endif -+ -+/** -+ * \brief The corresponding error message is printed to \p stderr (or \p stdout in device code) -+ * along with the supplied source context. -+ * -+ * \return The CUDA error. -+ */ -+__host__ CUTLASS_DEVICE cudaError_t cuda_perror_impl(cudaError_t error, -+ const char* expression, -+ const char* filename, -+ int line) { -+ (void)filename; -+ (void)line; -+ if (error) { -+#if !defined(__CUDA_ARCH__) -+ fprintf( -+ stderr, "CUDA error %d [%s, %d] in expression '%s': %s\n", error, filename, line, expression, cudaGetErrorString(error)); -+ fflush(stderr); -+#else -+ printf("CUDA error %d [%s, %d] in expression '%s'\n", error, filename, line, expression); -+#endif -+ } -+ return error; -+} -+ -+/** -+ * \brief Perror macro -+ */ -+#ifndef CUDA_PERROR -+#define CUDA_PERROR(e) cuda_perror_impl((cudaError_t)(e), #e, __FILE__, __LINE__) -+#endif -+ -+/** -+ * \brief Perror macro with exit -+ */ -+#ifndef CUDA_PERROR_EXIT -+#define CUDA_PERROR_EXIT(e) \ -+ do { if (cuda_perror_impl((cudaError_t)(e), #e, __FILE__, __LINE__)) { \ -+ exit(1); \ -+ } } while (0) -+#endif -+ -+/** -+ * \brief Perror macro only if DEBUG is defined -+ */ -+#ifndef CUDA_PERROR_DEBUG -+#ifdef DEBUG -+#define CUDA_PERROR_DEBUG(e) CUDA_PERROR(e) -+#else -+#define CUDA_PERROR_DEBUG(e) (e) -+#endif -+#endif -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// A small helper class to dump a type at compile time -+// Usage:: DumpType::Class -+template -+struct DebugType {}; -+ -+template -+void DebugTypeFunc(T const& t) { -+ T::t; -+} -+ -+// A small helper class to dump a compile time constant at compile time -+// Usage: DumpValue::kConstant -+template -+struct DebugValue {}; -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/device_dump.h b/3rdparty/cutlass/tools/util/include/cutlass/util/device_dump.h -new file mode 100644 -index 0000000..7a3270d ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/device_dump.h -@@ -0,0 +1,187 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include -+#include "cutlass/cutlass.h" -+ -+/** -+ * \file -+ * \brief C++ interface to dump fragments and shared memory contents for -+ * debugging. -+ */ -+ -+namespace cutlass { -+namespace debug { -+ -+/****************************************************************************** -+ * Dump the fragments -+ ******************************************************************************/ -+ -+/// The first N threads dump the first M elements from their fragments with a -+/// stride of S elements. If N is not specified, dump the data of all the -+/// threads. If M is not specified, dump all the elements of the fragment. -+template -+CUTLASS_DEVICE void dump_fragment(Fragment const& frag, int N = 0, int M = 0, -+ int S = 1) { -+ int total_threads = blockDim.x * blockDim.y * blockDim.z; -+ int block_id = -+ blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z; -+ int thread_id = (threadIdx.z * (blockDim.x * blockDim.y)) + -+ (threadIdx.y * blockDim.x) + threadIdx.x; -+ -+ if (N < 0 || N > total_threads) { -+ if (thread_id == 0 && block_id == 0) -+ printf("Thread number N = %d should between [1, %d].\n", N, -+ total_threads); -+ -+ __syncthreads(); -+ -+ return; -+ } -+ -+ int total_elements = frag.size(); -+ -+ if (M < 0 || M > total_elements) { -+ if (thread_id == 0 && block_id == 0) -+ printf("Element number M = %d should between [1, %d].\n", M, -+ total_elements); -+ -+ __syncthreads(); -+ -+ return; -+ } -+ -+ if (N == 0) N = total_threads; -+ -+ if (M == 0) M = total_elements; -+ -+ if (S < 1 || S > M) { -+ if (thread_id == 0 && block_id == 0) -+ printf("Stride S = %d should between [1, %d].\n", S, M); -+ -+ __syncthreads(); -+ -+ return; -+ } -+ -+ if (thread_id == 0 && block_id == 0) -+ printf("\n*******************Dumping the fragments*******************\n\n"); -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int tid = 0; tid < N; ++tid) { -+ if (tid == thread_id) { -+ printf("TB%d W%d T%d: ", block_id, tid / 32, tid & 31); -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int i = 0; i < M; i += S) { -+ printf("%.0f ", float(typename Fragment::value_type(frag[i]))); -+ } -+ printf("\n"); -+ } -+ -+ __syncthreads(); -+ } -+ -+ if (thread_id == 0 && block_id == 0) -+ printf("\n***********************************************************\n\n"); -+ -+ __syncthreads(); -+ -+ return; -+} -+ -+/****************************************************************************** -+ * Dump the shared memory -+ ******************************************************************************/ -+ -+#define SHMEM_ROW_SIZE 128 -+ -+/// Dump the shared memory contents. ptr is the begin address, size specifies -+/// the number of elements that need to be dumped, and S specifies the stride. -+template -+CUTLASS_DEVICE void dump_shmem(Element const* ptr, size_t size, int S = 1) { -+ int block_id = -+ blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z; -+ int thread_id = (threadIdx.z * (blockDim.x * blockDim.y)) + -+ (threadIdx.y * blockDim.x) + threadIdx.x; -+ -+ if (ptr == nullptr) { -+ if (thread_id == 0 && block_id == 0) printf("ptr is null.\n"); -+ -+ __syncthreads(); -+ return; -+ } -+ -+ if (size < 1) { -+ if (thread_id == 0 && block_id == 0) -+ printf("Element size is less than 1\n"); -+ -+ __syncthreads(); -+ -+ return; -+ } -+ -+ int row_elements = SHMEM_ROW_SIZE / sizeof(Element); -+ -+ if (S < 1 || S > row_elements) { -+ if (thread_id == 0 && block_id == 0) -+ printf("Stride S = %d should between [1, %d].\n", S, row_elements); -+ -+ __syncthreads(); -+ -+ return; -+ } -+ -+ __syncthreads(); -+ -+ if (thread_id == 0) -+ printf("\n********Dumping the shared memory of TB %d*******\n\n", block_id); -+ -+ if (thread_id == 0) { -+ for (int i = 0; i < size; i += row_elements) { -+ for (int j = 0; j < row_elements; j += S) { -+ printf("%.0f ", float(ptr[i + j])); -+ } -+ -+ printf("\n"); -+ } -+ } -+ -+ if (thread_id == 0) -+ printf("\n***********************************************************\n\n"); -+ -+ __syncthreads(); -+ -+ return; -+} -+} // namespace debug -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/device_groupnorm.h b/3rdparty/cutlass/tools/util/include/cutlass/util/device_groupnorm.h -new file mode 100644 -index 0000000..aaa19b2 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/device_groupnorm.h -@@ -0,0 +1,402 @@ -+/****************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ ******************************************************************************/ -+ -+#pragma once -+ -+/** -+ * \file -+ * \brief cuda kernels to do group norm on a device memory tensor with NHWC layout. The tensor will be divided into [N, H, W, G, C'] and then we do normalization on [H, W, C']. -+ */ -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_coord.h" -+#include "cutlass/tensor_ref.h" -+#include "device_utils.h" -+#include -+ -+namespace cutlass { -+ -+/** \brief interface to do group norm on a device memory tensor with NHWC layout. -+ * \tparam T: data type -+ */ -+template -+void groupnorm(cutlass::Tensor4DCoord input_size, -+ const int num_groups, -+ const float eps, -+ TensorRef ref_output, -+ TensorRef ref_input, -+ TensorRef ref_gamma, -+ TensorRef ref_beta, -+ cudaStream_t stream); -+ -+extern __shared__ char groupnorm_shm[]; -+ -+// For small prod_dim1_to_last_dim/num_groups, to avoid multiple loads from global memory, -+// we store the input in the shared memory. -+// grid(num_groups, dim0) -+// block(BLOCKSIZE) -+// BLOCKSIZE * TVecs_PER_THREAD <= prod_dim1_to_last_dim/num_group -+template -+__global__ void groupnorm_twopass_store_locally(T* output, -+ const T* input, -+ const T* gamma, -+ const T* beta, -+ int num_groups, -+ int prod_dim1_to_last_dim, -+ int last_dim, -+ const float eps, -+ const int TVecs_PER_THREAD) -+{ -+ const int bid = blockIdx.y; // index of batch -+ const int gid = blockIdx.x; // index of group -+ const int tid = threadIdx.x; // index of thread -+ const int bdimx = blockDim.x; -+ const int s_reduce_elements = prod_dim1_to_last_dim / num_groups; -+ const int v_reduce_elements = s_reduce_elements / T_PER_TVec; -+ const int s_group_stride = last_dim / num_groups; -+ const int v_group_stride = s_group_stride / T_PER_TVec; -+ const int offset_of_group = (bid * prod_dim1_to_last_dim + gid * s_group_stride) / T_PER_TVec; -+ const TVec* input_TVec_ptr = (const TVec*)(input) + offset_of_group; -+ TVec* output_TVec_ptr = (TVec*)(output) + offset_of_group; -+ T* local_val = ((T*)groupnorm_shm) + TVecs_PER_THREAD * T_PER_TVec * tid; -+ float local_sum[1] = {0.0f}; -+ -+// load from global memory into shared memory -+#pragma unroll -+ for (int i = 0; i < TVecs_PER_THREAD; i += 1) { -+ const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; -+ const int offset_in_group = -+ ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride)) -+ / T_PER_TVec; -+ if (current_load_start_idx < s_reduce_elements) { -+ TVec tmp_vec = input_TVec_ptr[offset_in_group]; -+ T* tmp_vec_ptr = (T*)(&tmp_vec); -+ const int local_val_offset = i * T_PER_TVec; -+#pragma unroll -+ for (int j = 0; j < T_PER_TVec; j++) { -+ float tmp = static_cast(tmp_vec_ptr[j]); -+ local_sum[0] += tmp; -+ local_val[local_val_offset + j] = tmp_vec_ptr[j]; -+ } -+ } -+ } -+ __shared__ float s_mean, s_variance; -+ -+ // reduction for mean -+ if (bdimx <= 32) { -+ warpReduceSum(local_sum); -+ } -+ else { -+ blockReduceSum(local_sum); -+ } -+ if (tid == 0) { -+ s_mean = local_sum[0] / s_reduce_elements; -+ } -+ __syncthreads(); -+ -+ // reduction for std -+ local_sum[0] = 0.0f; -+#pragma unroll -+ for (int i = 0; i < TVecs_PER_THREAD; i += 1) { -+ const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; -+ if (current_load_start_idx < s_reduce_elements) { -+ const int local_val_offset = i * T_PER_TVec; -+#pragma unroll -+ for (int j = 0; j < T_PER_TVec; j++) { -+ float tmp = static_cast(local_val[local_val_offset + j]); -+ tmp -= s_mean; -+ local_sum[0] += tmp * tmp; -+ } -+ } -+ } -+ if (bdimx <= 32) { -+ warpReduceSum(local_sum); -+ } -+ else { -+ blockReduceSum(local_sum); -+ } -+ if (tid == 0) { -+ s_variance = rsqrtf(local_sum[0] / s_reduce_elements + eps); -+ } -+ __syncthreads(); -+ -+ // normalize -+ const int gamma_offset_of_group = gid * v_group_stride; -+ const TVec* gamma_TVec_ptr = (const TVec*)gamma + gamma_offset_of_group; -+ const TVec* beta_TVec_ptr = (const TVec*)beta + gamma_offset_of_group; -+#pragma unroll -+ for (int i = 0; i < TVecs_PER_THREAD; i += 1) { -+ const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; -+ const int offset_in_group = -+ ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride)) -+ / T_PER_TVec; -+ const int gamma_offset_in_group = (current_load_start_idx % s_group_stride) / T_PER_TVec; -+ const int local_val_offset = i * T_PER_TVec; -+ if (current_load_start_idx < s_reduce_elements) { -+ TVec gamma_val = gamma_TVec_ptr[gamma_offset_in_group]; -+ TVec beta_val = beta_TVec_ptr[gamma_offset_in_group]; -+ T* gamma_val_ptr = (T*)(&gamma_val); -+ T* beta_val_ptr = (T*)(&beta_val); -+ TVec tmp_vec; -+ T* tmp_vec_ptr = (T*)(&tmp_vec); -+#pragma unroll -+ for (int j = 0; j < T_PER_TVec; j++) { -+ float tmp = (static_cast(local_val[local_val_offset + j]) - s_mean) * s_variance -+ * static_cast(gamma_val_ptr[j]) -+ + static_cast(beta_val_ptr[j]); -+ if (sizeof(T) == sizeof(half)) { -+ tmp_vec_ptr[j] = T(__float2half_rn(tmp)); -+ } -+ else { -+ tmp_vec_ptr[j] = T(tmp); -+ } -+ } -+ output_TVec_ptr[offset_in_group] = tmp_vec; -+ } -+ } -+} -+ -+// For large prod_dim1_to_last_dim/num_groups, -+// in which the data cannot be stored locally, -+// we will load from global memory multiple times, -+// grid(num_groups, dim0) -+// block(BLOCKSIZE) -+// BLOCKSIZE * TVecs_PER_THREAD <= prod_dim1_to_last_dim/num_group -+template -+__global__ void groupnorm_twopass_multiple_load(T* output, -+ const T* input, -+ const T* gamma, -+ const T* beta, -+ int num_groups, -+ int prod_dim1_to_last_dim, -+ int last_dim, -+ const float eps, -+ const int TVecs_PER_THREAD) -+{ -+ const int bid = blockIdx.y; // index of batch -+ const int gid = blockIdx.x; // index of group -+ const int tid = threadIdx.x; // index of thread -+ const int bdimx = blockDim.x; -+ const int s_reduce_elements = prod_dim1_to_last_dim / num_groups; -+ const int v_reduce_elements = s_reduce_elements / T_PER_TVec; -+ const int s_group_stride = last_dim / num_groups; -+ const int v_group_stride = s_group_stride / T_PER_TVec; -+ const int offset_of_group = (bid * prod_dim1_to_last_dim + gid * s_group_stride) / T_PER_TVec; -+ const TVec* input_TVec_ptr = (const TVec*)(input) + offset_of_group; -+ TVec* output_TVec_ptr = (TVec*)(output) + offset_of_group; -+ float local_sum[1] = {0.0f}; -+ -+#pragma unroll -+ for (int i = 0; i < TVecs_PER_THREAD; i += 1) { -+ const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; -+ if (current_load_start_idx < s_reduce_elements) { -+ const int offset_in_group = -+ ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride)) -+ / T_PER_TVec; -+ TVec tmp_vec = input_TVec_ptr[offset_in_group]; -+ T* tmp_vec_ptr = (T*)(&tmp_vec); -+#pragma unroll -+ for (int j = 0; j < T_PER_TVec; j++) { -+ float tmp = static_cast(tmp_vec_ptr[j]); -+ local_sum[0] += tmp; -+ } -+ } -+ } -+ __shared__ float s_mean, s_variance; -+ -+ // reduction for mean -+ if (bdimx <= 32) { -+ warpReduceSum(local_sum); -+ } -+ else { -+ blockReduceSum(local_sum); -+ } -+ if (tid == 0) { -+ s_mean = local_sum[0] / s_reduce_elements; -+ } -+ __syncthreads(); -+ -+ // reduction for std -+ local_sum[0] = 0.0f; -+#pragma unroll -+ for (int i = 0; i < TVecs_PER_THREAD; i += 1) { -+ const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; -+ if (current_load_start_idx < s_reduce_elements) { -+ const int offset_in_group = -+ ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride)) -+ / T_PER_TVec; -+ TVec tmp_vec = input_TVec_ptr[offset_in_group]; -+ T* tmp_vec_ptr = (T*)(&tmp_vec); -+#pragma unroll -+ for (int j = 0; j < T_PER_TVec; j++) { -+ float tmp = static_cast(tmp_vec_ptr[j]); -+ tmp -= s_mean; -+ local_sum[0] += tmp * tmp; -+ } -+ } -+ } -+ if (bdimx <= 32) { -+ warpReduceSum(local_sum); -+ } -+ else { -+ blockReduceSum(local_sum); -+ } -+ if (tid == 0) { -+ s_variance = rsqrtf(local_sum[0] / s_reduce_elements + eps); -+ } -+ __syncthreads(); -+ -+ // normalize -+ const int gamma_offset_of_group = gid * v_group_stride; -+ const TVec* gamma_TVec_ptr = (const TVec*)gamma + gamma_offset_of_group; -+ const TVec* beta_TVec_ptr = (const TVec*)beta + gamma_offset_of_group; -+#pragma unroll -+ for (int i = 0; i < TVecs_PER_THREAD; i += 1) { -+ const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; -+ if (current_load_start_idx < s_reduce_elements) { -+ const int offset_in_group = -+ ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride)) -+ / T_PER_TVec; -+ const int gamma_offset_in_group = (current_load_start_idx % s_group_stride) / T_PER_TVec; -+ TVec gamma_val = gamma_TVec_ptr[gamma_offset_in_group]; -+ TVec beta_val = beta_TVec_ptr[gamma_offset_in_group]; -+ T* gamma_val_ptr = (T*)(&gamma_val); -+ T* beta_val_ptr = (T*)(&beta_val); -+ TVec tmp_vec = input_TVec_ptr[offset_in_group]; -+ T* tmp_vec_ptr = (T*)(&tmp_vec); -+ TVec output_tmp_vec; -+ T* output_tmp_vec_ptr = (T*)(&output_tmp_vec); -+#pragma unroll -+ for (int j = 0; j < T_PER_TVec; j++) { -+ float tmp = -+ (static_cast(tmp_vec_ptr[j]) - s_mean) * s_variance * static_cast(gamma_val_ptr[j]) -+ + static_cast(beta_val_ptr[j]); -+ if (sizeof(T) == sizeof(half)) { -+ output_tmp_vec_ptr[j] = T(__float2half_rn(tmp)); -+ } -+ else { -+ output_tmp_vec_ptr[j] = T(tmp); -+ } -+ } -+ output_TVec_ptr[offset_in_group] = output_tmp_vec; -+ } -+ } -+} -+ -+//ref_input & ref_output should be [N, H, W, C] -+//ref_gamma & ref_beta shoud be [1, 1, 1, C] -+template -+void groupnorm(cutlass::Tensor4DCoord input_size, -+ const int num_groups, -+ const float eps, -+ TensorRef ref_output, -+ TensorRef ref_input, -+ TensorRef ref_gamma, -+ TensorRef ref_beta, -+ cudaStream_t stream){ -+ const int N = input_size.n(); -+ const int H = input_size.h(); -+ const int W = input_size.w(); -+ const int C = input_size.c(); -+ if (C % num_groups != 0){ -+ printf("[ERROR] C should be a multiple of num_groups.\n"); -+ } -+ T* output = ref_output.data(); -+ const T* input = ref_input.data(); -+ const T* gamma = ref_gamma.data(); -+ const T* beta = ref_beta.data(); -+ -+ const int dim0 = N; -+ const int last_dim = C; -+ const int prod_dim1_to_last_dim = H*W*C; -+ const int s_reduce_elements = prod_dim1_to_last_dim / num_groups; -+ const int s_group_stride = last_dim / num_groups; -+ dim3 grid(num_groups, dim0); -+ int threadblock_size = 32; -+ if (s_group_stride % 2 == 0) { -+ const int T_PER_TVec = 2; -+ while (threadblock_size < 1024) { -+ if (s_reduce_elements / T_PER_TVec / threadblock_size <= 8) -+ break; -+ threadblock_size *= 2; -+ } -+ dim3 block(threadblock_size); -+ const int TVec_PER_THREAD = (s_reduce_elements / T_PER_TVec + threadblock_size - 1) / threadblock_size; -+ const int shm_size = T_PER_TVec * TVec_PER_THREAD * threadblock_size * sizeof(T); -+ // for small s_reduce_elements, specific case for H=W=22, C=1280, num_groups=32; -+ // the size of grid & block may have better choice for different cases. -+ // ensure shared memory is smaller than 48KB -+ if (std::is_same::value){ -+ if (shm_size < 48 * 1024) { -+ groupnorm_twopass_store_locally<<>>( -+ output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); -+ } -+ else { -+ groupnorm_twopass_multiple_load<<>>( -+ output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); -+ } -+ } -+ else{ -+ if (shm_size < 48 * 1024) { -+ groupnorm_twopass_store_locally<<>>( -+ output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); -+ } -+ else { -+ groupnorm_twopass_multiple_load<<>>( -+ output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); -+ } -+ } -+ } -+ else { -+ const int T_PER_TVec = 1; -+ while (threadblock_size < 1024) { -+ if (s_reduce_elements / T_PER_TVec / threadblock_size <= 8) -+ break; -+ threadblock_size *= 2; -+ } -+ dim3 block(threadblock_size); -+ const int TVec_PER_THREAD = (s_reduce_elements / T_PER_TVec + threadblock_size - 1) / threadblock_size; -+ const int shm_size = T_PER_TVec * TVec_PER_THREAD * threadblock_size * sizeof(T); -+ if (shm_size < 48 * 1024) { -+ groupnorm_twopass_store_locally<<>>( -+ output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); -+ } -+ else { -+ groupnorm_twopass_multiple_load<<>>( -+ output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); -+ } -+ } -+ -+} -+ -+} //namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/device_layernorm.h b/3rdparty/cutlass/tools/util/include/cutlass/util/device_layernorm.h -new file mode 100644 -index 0000000..c4ec925 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/device_layernorm.h -@@ -0,0 +1,644 @@ -+/****************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ ******************************************************************************/ -+ -+#pragma once -+ -+/** -+ * \file -+ * \brief cuda kernels to do layernorm on a device memory tensor with RowMajor layout. -+ */ -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_coord.h" -+#include "cutlass/tensor_ref.h" -+#include "device_utils.h" -+#include -+ -+namespace cutlass { -+ -+/** \brief interface to do layernorm on a device memory tensor with RowMajor layout. -+ * \tparam T: data type -+ */ -+template -+void layernorm(cutlass::MatrixCoord tensor_size, -+ TensorRef ref_output, -+ TensorRef ref_input, -+ TensorRef ref_gamma, -+ TensorRef ref_beta, -+ cudaStream_t stream); -+ -+/** -+ * output [m, n] row-major -+ * input [m, n] row-major -+ * gamma [n] -+ * beta [n] -+ * grid(m) -+ * block(block_size) -- each block deals with n elements ; each thread deals with ITEM_PER_THREAD elements -+*/ -+template -+__global__ void layernorm_twoPassAlgo_stored_locally_e1(T* output, -+ const T* input, -+ const T* gamma, -+ const T* beta, -+ const int m, -+ const int n) -+{ -+ const int m_idx = blockIdx.x; -+ const int tid = threadIdx.x; -+ const int bdimx = blockDim.x; -+ __shared__ float s_mean, s_variance; -+ T local_val[ITEM_PER_THREAD]; -+ float local_sums[1] = {0.0f}; -+ int offset = m_idx * n; -+ input += offset; -+ output += offset; -+ -+ const T zero = T(0.0f); -+ #pragma unroll -+ for (int i = 0 ; i < ITEM_PER_THREAD ; i++){ -+ int index = tid + i*bdimx; -+ local_val[i] = index < n ? input[index] : zero; -+ local_sums[0] += static_cast(local_val[i]); -+ } -+ if (blockDim.x <= 32) { -+ warpReduceSum(local_sums); -+ } -+ else { -+ blockReduceSum(local_sums); -+ } -+ if (threadIdx.x == 0) { -+ s_mean = local_sums[0] / n; -+ } -+ __syncthreads(); -+ -+ local_sums[0] = 0.0f; -+ #pragma unroll -+ for (int i = 0 ; i < ITEM_PER_THREAD ; i++){ -+ int index = tid + i*bdimx; -+ if (index < n){ -+ const float tmp = static_cast(local_val[i]) - s_mean; -+ local_sums[0] += tmp * tmp; -+ } -+ } -+ -+ if (blockDim.x <= 32) { -+ warpReduceSum(local_sums); -+ } -+ else { -+ blockReduceSum(local_sums); -+ } -+ if (threadIdx.x == 0) { -+ s_variance = rsqrtf(local_sums[0] / n + 1e-5); -+ } -+ __syncthreads(); -+ -+ #pragma unroll -+ for (int i = 0 ; i < ITEM_PER_THREAD ; i++){ -+ int index = tid + i*bdimx; -+ if (index < n) { -+ const T gamma_val = gamma[index]; -+ const T beta_val = beta[index]; -+ output[index] = T((static_cast(local_val[i]) - s_mean) * s_variance * static_cast(gamma_val) + static_cast(beta_val)); -+ } -+ } -+} -+ -+/** -+ * output [m, n] row-major -+ * input [m, n] row-major -+ * gamma [n] -+ * beta [n] -+ * grid(m) -+ * block(block_size) -- each block deals with block_size*ITEM_PER_THREAD*2 elements; -+*/ -+template -+__global__ void layernorm_twoPassAlgo_stored_locally_e2(T2* output, -+ const T2* input, -+ const T2* gamma, -+ const T2* beta, -+ const int m, -+ const int n) -+{ -+ const int m_idx = blockIdx.x; -+ const int tid = threadIdx.x; -+ const int bdimx = blockDim.x; -+ __shared__ float s_mean, s_variance; -+ float local_sums[1] = {0.0f}; -+ T2 local_val[ITEM_PER_THREAD]; -+ const int n_2 = n / 2; -+ int offset = m_idx * n_2; -+ input += offset; -+ output += offset; -+ -+ const T2 zero = {T(0.0f), T(0.0f)}; -+ #pragma UNROLL -+ for (int i = 0; i < ITEM_PER_THREAD; i += 1) { -+ const int index = i*bdimx + tid; -+ local_val[i] = index < n_2 ? input[index] : zero; -+ local_sums[0] += static_cast(local_val[i].x) + static_cast(local_val[i].y); -+ } -+ -+ if (blockDim.x <= 32) { -+ warpReduceSum(local_sums); -+ } -+ else { -+ blockReduceSum(local_sums); -+ } -+ if (threadIdx.x == 0) { -+ s_mean = local_sums[0] / n; -+ } -+ __syncthreads(); -+ -+ local_sums[0] = 0.0f; -+ #pragma UNROLL -+ for (int i = 0; i < ITEM_PER_THREAD; i += 1) { -+ const int index = i*bdimx + tid; -+ if (index < n_2){ -+ const float2 tmp = {static_cast(local_val[i].x) - s_mean, -+ static_cast(local_val[i].y) - s_mean}; -+ local_sums[0] += tmp.x * tmp.x + tmp.y * tmp.y; -+ } -+ } -+ if (blockDim.x <= 32) { -+ warpReduceSum(local_sums); -+ } -+ else { -+ blockReduceSum(local_sums); -+ } -+ if (threadIdx.x == 0) { -+ s_variance = rsqrtf(local_sums[0] / n + 1e-5); -+ } -+ __syncthreads(); -+ -+ #pragma UNROLL -+ for (int i = 0; i < ITEM_PER_THREAD; i += 1) { -+ const int index = i*bdimx + tid; -+ if (index < n_2){ -+ const T2 gamma_val = gamma[index]; -+ const T2 beta_val = beta[index]; -+ T2 tmp; -+ tmp.x = T((static_cast(local_val[i].x) - s_mean)*s_variance*static_cast(gamma_val.x) + static_cast(beta_val.x)); -+ tmp.y = T((static_cast(local_val[i].y) - s_mean)*s_variance*static_cast(gamma_val.y) + static_cast(beta_val.y)); -+ output[index] = tmp; -+ } -+ } -+} -+ -+/** -+ * output [m, n] row-major -+ * input [m, n] row-major -+ * gamma [n] -+ * beta [n] -+ * grid(m) -+ * block(block_size) -- each block deals with block_size*ITEM_PER_THREAD*4 elements; -+*/ -+template -+__global__ void layernorm_twoPassAlgo_stored_locally_e4(T4* output, -+ const T4* input, -+ const T4* gamma, -+ const T4* beta, -+ const int m, -+ const int n) -+{ -+ const int m_idx = blockIdx.x; -+ const int tid = threadIdx.x; -+ const int bdimx = blockDim.x; -+ __shared__ float s_mean, s_variance; -+ float local_sums[1] = {0.0f}; -+ T4 local_val[ITEM_PER_THREAD]; -+ const int n_4 = n / 4; -+ int offset = m_idx * n_4; -+ input += offset; -+ output += offset; -+ -+ const T4 zero = {T(0.0f), T(0.0f), T(0.0f), T(0.0f)}; -+ #pragma UNROLL -+ for (int i = 0; i < ITEM_PER_THREAD; i += 1) { -+ const int index = i*bdimx + tid; -+ local_val[i] = index < n_4 ? input[index] : zero; -+ local_sums[0] += static_cast(local_val[i].x) + static_cast(local_val[i].y) + -+ static_cast(local_val[i].z) + static_cast(local_val[i].w); -+ } -+ -+ if (blockDim.x <= 32) { -+ warpReduceSum(local_sums); -+ } -+ else { -+ blockReduceSum(local_sums); -+ } -+ if (threadIdx.x == 0) { -+ s_mean = local_sums[0] / n; -+ } -+ __syncthreads(); -+ -+ local_sums[0] = 0.0f; -+ #pragma UNROLL -+ for (int i = 0; i < ITEM_PER_THREAD; i += 1) { -+ const int index = i*bdimx + tid; -+ if (index < n_4){ -+ const float4 tmp = {static_cast(local_val[i].x) - s_mean, -+ static_cast(local_val[i].y) - s_mean, -+ static_cast(local_val[i].z) - s_mean, -+ static_cast(local_val[i].w) - s_mean}; -+ local_sums[0] += tmp.x * tmp.x + tmp.y * tmp.y + tmp.z * tmp.z + tmp.w * tmp.w; -+ } -+ } -+ if (blockDim.x <= 32) { -+ warpReduceSum(local_sums); -+ } -+ else { -+ blockReduceSum(local_sums); -+ } -+ if (threadIdx.x == 0) { -+ s_variance = rsqrtf(local_sums[0] / n + 1e-5); -+ } -+ __syncthreads(); -+ -+ #pragma UNROLL -+ for (int i = 0; i < ITEM_PER_THREAD; i += 1) { -+ const int index = i*bdimx + tid; -+ if (index < n_4){ -+ const T4 gamma_val = gamma[index]; -+ const T4 beta_val = beta[index]; -+ T4 tmp; -+ tmp.x = T((static_cast(local_val[i].x) - s_mean)*s_variance*static_cast(gamma_val.x) + static_cast(beta_val.x)); -+ tmp.y = T((static_cast(local_val[i].y) - s_mean)*s_variance*static_cast(gamma_val.y) + static_cast(beta_val.y)); -+ tmp.z = T((static_cast(local_val[i].z) - s_mean)*s_variance*static_cast(gamma_val.z) + static_cast(beta_val.z)); -+ tmp.w = T((static_cast(local_val[i].w) - s_mean)*s_variance*static_cast(gamma_val.w) + static_cast(beta_val.w)); -+ output[index] = tmp; -+ } -+ } -+} -+ -+/** -+ * output [m, n] row-major -+ * input [m, n] row-major -+ * gamma [n] -+ * beta [n] -+ * grid(m) -+ * block(block_size) -- each block deals with n elements ; each thread deals with ITEM_PER_THREAD elements -+*/ -+template -+__global__ void layernorm_twoPassAlgo_e1(T* output, -+ const T* input, -+ const T* gamma, -+ const T* beta, -+ const int m, -+ const int n) -+{ -+ const int m_idx = blockIdx.x; -+ const int tid = threadIdx.x; -+ const int bdimx = blockDim.x; -+ __shared__ float s_mean, s_variance; -+ float local_sums[1] = {0.0f}; -+ int offset = m_idx * n; -+ input += offset; -+ output += offset; -+ -+ for (int index = tid ; index < n ; index += bdimx){ -+ float local_val = static_cast(input[index]); -+ local_sums[0] += local_val; -+ } -+ if (blockDim.x <= 32) { -+ warpReduceSum(local_sums); -+ } -+ else { -+ blockReduceSum(local_sums); -+ } -+ if (threadIdx.x == 0) { -+ s_mean = local_sums[0] / n; -+ } -+ __syncthreads(); -+ -+ local_sums[0] = 0.0f; -+ for (int index = tid ; index < n ; index += bdimx){ -+ float local_val = static_cast(input[index]); -+ local_val = local_val - s_mean; -+ local_sums[0] += local_val * local_val; -+ } -+ -+ if (blockDim.x <= 32) { -+ warpReduceSum(local_sums); -+ } -+ else { -+ blockReduceSum(local_sums); -+ } -+ if (threadIdx.x == 0) { -+ s_variance = rsqrtf(local_sums[0] / n + 1e-5); -+ } -+ __syncthreads(); -+ -+ for (int index = tid ; index < n ; index += bdimx){ -+ const T gamma_val = gamma[index]; -+ const T beta_val = beta[index]; -+ const T local_val = input[index]; -+ output[index] = T((static_cast(local_val) - s_mean) * s_variance * static_cast(gamma_val) + static_cast(beta_val)); -+ } -+} -+ -+/** -+ * output [m, n] row-major -+ * input [m, n] row-major -+ * gamma [n] -+ * beta [n] -+ * grid(m) -+ * block(block_size) -- each block deals with block_size*ITEM_PER_THREAD*2 elements; -+*/ -+template -+__global__ void layernorm_twoPassAlgo_e2(T2* output, -+ const T2* input, -+ const T2* gamma, -+ const T2* beta, -+ const int m, -+ const int n) -+{ -+ const int m_idx = blockIdx.x; -+ const int tid = threadIdx.x; -+ const int bdimx = blockDim.x; -+ __shared__ float s_mean, s_variance; -+ float local_sums[1] = {0.0f}; -+ const int n_2 = n / 2; -+ int offset = m_idx * n_2; -+ input += offset; -+ output += offset; -+ -+ for (int index = tid; index < n_2; index += bdimx) { -+ const T2 local_val = input[index]; -+ local_sums[0] += static_cast(local_val.x) + static_cast(local_val.y); -+ } -+ -+ if (blockDim.x <= 32) { -+ warpReduceSum(local_sums); -+ } -+ else { -+ blockReduceSum(local_sums); -+ } -+ if (threadIdx.x == 0) { -+ s_mean = local_sums[0] / n; -+ } -+ __syncthreads(); -+ -+ local_sums[0] = 0.0f; -+ for (int index = tid; index < n_2; index += bdimx) { -+ const T2 local_val = input[index]; -+ const float2 tmp = {static_cast(local_val.x) - s_mean, -+ static_cast(local_val.y) - s_mean}; -+ local_sums[0] += tmp.x * tmp.x + tmp.y * tmp.y; -+ } -+ if (blockDim.x <= 32) { -+ warpReduceSum(local_sums); -+ } -+ else { -+ blockReduceSum(local_sums); -+ } -+ if (threadIdx.x == 0) { -+ s_variance = rsqrtf(local_sums[0] / n + 1e-5); -+ } -+ __syncthreads(); -+ -+ for (int index = tid; index < n_2; index += bdimx) { -+ const T2 local_val = input[index]; -+ const T2 gamma_val = gamma[index]; -+ const T2 beta_val = beta[index]; -+ T2 tmp; -+ tmp.x = T((static_cast(local_val.x) - s_mean)*s_variance*static_cast(gamma_val.x) + static_cast(beta_val.x)); -+ tmp.y = T((static_cast(local_val.y) - s_mean)*s_variance*static_cast(gamma_val.y) + static_cast(beta_val.y)); -+ output[index] = tmp; -+ } -+} -+ -+template -+void layernorm(cutlass::MatrixCoord tensor_size, -+ TensorRef ref_output, -+ TensorRef ref_input, -+ TensorRef ref_gamma, -+ TensorRef ref_beta, -+ cudaStream_t stream){ -+ const int m = tensor_size.row(); -+ const int n = tensor_size.column(); -+ T* output = ref_output.data(); -+ const T* input = ref_input.data(); -+ const T* gamma = ref_gamma.data(); -+ const T* beta = ref_beta.data(); -+ dim3 grid(m); -+ dim3 block((n + 31)/32*32); -+ if (block.x > 1024){ -+ block.x = 1024; -+ } -+ // TODO : There should be better configs for different cases, we only use several samples to show how to use here -+ // TODO : using registers to store values locally can reduce the loads from global memory and speedup the kernels. -+ if ((n % 4 == 0) && (n >= 128) && (n <= 4096)) { -+ block.x = (n/4 + 31)/32*32; -+ if (std::is_same::value) { -+ layernorm_twoPassAlgo_stored_locally_e4<<>>( -+ (float4*)output, -+ (const float4*)input, -+ (const float4*)gamma, -+ (const float4*)beta, -+ m, -+ n); -+ } // if (std::is_same::value) -+ else { -+ layernorm_twoPassAlgo_stored_locally_e4<<>>( -+ (half4*)output, -+ (const half4*)input, -+ (const half4*)gamma, -+ (const half4*)beta, -+ m, -+ n); -+ } -+ } //if ((n % 4 == 0) && (n >= 128) && (n <= 4096)) -+ else if (n % 2 == 0) { -+ if (n / 2 <= 1024) { -+ block.x = (n/2 + 31)/32*32; -+ if (std::is_same::value) { -+ layernorm_twoPassAlgo_stored_locally_e2<<>>( -+ (float2*)output, -+ (const float2*)input, -+ (const float2*)gamma, -+ (const float2*)beta, -+ m, -+ n); -+ } //if (std::is_same::value) -+ else { -+ layernorm_twoPassAlgo_stored_locally_e2<<>>( -+ (half2*)output, -+ (const half2*)input, -+ (const half2*)gamma, -+ (const half2*)beta, -+ m, -+ n); -+ } -+ } // if (n / 2 <= 1024) -+ else if (n <= 8192) { -+ block.x = ((n + 7)/8 + 31)/32*32; -+ if (std::is_same::value) { -+ layernorm_twoPassAlgo_stored_locally_e2<<>>( -+ (float2*)output, -+ (const float2*)input, -+ (const float2*)gamma, -+ (const float2*)beta, -+ m, -+ n); -+ } // if (std::is_same::value) -+ else { -+ layernorm_twoPassAlgo_stored_locally_e2<<>>( -+ (half2*)output, -+ (const half2*)input, -+ (const half2*)gamma, -+ (const half2*)beta, -+ m, -+ n); -+ } -+ } // if (n <= 8192) -+ else if (n <= 16384) { -+ block.x = ((n + 15)/ 16 + 31)/32*32; -+ if (std::is_same::value) { -+ layernorm_twoPassAlgo_stored_locally_e2<<>>( -+ (float2*)output, -+ (const float2*)input, -+ (const float2*)gamma, -+ (const float2*)beta, -+ m, -+ n); -+ } // if (std::is_same::value) -+ else { -+ layernorm_twoPassAlgo_stored_locally_e2<<>>( -+ (half2*)output, -+ (const half2*)input, -+ (const half2*)gamma, -+ (const half2*)beta, -+ m, -+ n); -+ } -+ } // if (n <= 16384) -+ else if (n <= 32768) { -+ block.x = ((n + 31)/32 + 31)/32*32; -+ if (std::is_same::value) { -+ layernorm_twoPassAlgo_stored_locally_e2<<>>( -+ (float2*)output, -+ (const float2*)input, -+ (const float2*)gamma, -+ (const float2*)beta, -+ m, -+ n); -+ } // if (std::is_same::value) -+ else { -+ layernorm_twoPassAlgo_stored_locally_e2<<>>( -+ (half2*)output, -+ (const half2*)input, -+ (const half2*)gamma, -+ (const half2*)beta, -+ m, -+ n); -+ } -+ } // if (n <= 32768) -+ else { -+ if (block.x > 512) -+ block.x = 512; -+ if (std::is_same::value) { -+ layernorm_twoPassAlgo_e2<<>>( -+ (float2 *)output, -+ (const float2 *)input, -+ (const float2 *)gamma, -+ (const float2 *)beta, -+ m, -+ n); -+ } // if (std::is_same::value) -+ else { -+ layernorm_twoPassAlgo_e2<<>>( -+ (half2 *)output, -+ (const half2 *)input, -+ (const half2 *)gamma, -+ (const half2 *)beta, -+ m, -+ n); -+ } -+ } -+ } // if (n % 2 == 0) -+ else { -+ if (n <= 1024) { -+ layernorm_twoPassAlgo_stored_locally_e1<<>>( -+ output, -+ input, -+ gamma, -+ beta, -+ m, -+ n); -+ } // if (n <= 1024) -+ else if (n <= 8192) { -+ block.x = ((n + 7)/8 + 31)/32*32; -+ layernorm_twoPassAlgo_stored_locally_e1<<>>( -+ output, -+ input, -+ gamma, -+ beta, -+ m, -+ n); -+ } // if (n <= 8192) -+ else if (n <= 16384) { -+ block.x = ((n + 15)/16 + 32)/32*32; -+ layernorm_twoPassAlgo_stored_locally_e1<<>>( -+ output, -+ input, -+ gamma, -+ beta, -+ m, -+ n); -+ } // if (n <= 16384) -+ else if (n <= 32768) { -+ block.x = ((n + 31)/32 + 31)/32*32; -+ layernorm_twoPassAlgo_stored_locally_e1<<>>( -+ output, -+ input, -+ gamma, -+ beta, -+ m, -+ n); -+ } // if (n <= 32768) -+ else{ -+ if (block.x > 512) { -+ block.x = 512; -+ } -+ layernorm_twoPassAlgo_e1<<>>( -+ output, -+ input, -+ gamma, -+ beta, -+ m, -+ n); -+ } -+ } -+} -+ -+} //namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/device_memory.h b/3rdparty/cutlass/tools/util/include/cutlass/util/device_memory.h -new file mode 100644 -index 0000000..67dfff5 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/device_memory.h -@@ -0,0 +1,338 @@ -+/****************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ ******************************************************************************/ -+ -+#pragma once -+ -+/** -+ * \file -+ * \brief C++ interface to CUDA device memory management functions. -+ */ -+ -+#include -+ -+#include "cutlass/platform/platform.h" -+#include "cutlass/numeric_types.h" -+#include "exceptions.h" -+ -+namespace cutlass { -+namespace device_memory { -+ -+/****************************************************************************** -+ * Allocation lifetime -+ ******************************************************************************/ -+ -+/// Allocate a buffer of \p count elements of type \p T on the current CUDA device -+template -+T* allocate(size_t count = 1) { -+ -+ T* ptr = 0; -+ size_t bytes = 0; -+ -+ bytes = count * sizeof(T); -+ -+ cudaError_t cuda_error = cudaMalloc((void**)&ptr, bytes); -+ -+ if (cuda_error != cudaSuccess) { -+ throw cuda_exception("Failed to allocate memory", cuda_error); -+ } -+ -+ return ptr; -+} -+ -+/// Free the buffer pointed to by \p ptr -+template -+void free(T* ptr) { -+ if (ptr) { -+ cudaError_t cuda_error = (cudaFree(ptr)); -+ if (cuda_error != cudaSuccess) { -+ throw cuda_exception("Failed to free device memory", cuda_error); -+ } -+ } -+} -+ -+/****************************************************************************** -+ * Data movement -+ ******************************************************************************/ -+ -+template -+void copy(T* dst, T const* src, size_t count, cudaMemcpyKind kind) { -+ size_t bytes = count * sizeof_bits::value / 8; -+ if (bytes == 0 && count > 0) -+ bytes = 1; -+ cudaError_t cuda_error = (cudaMemcpy(dst, src, bytes, kind)); -+ if (cuda_error != cudaSuccess) { -+ throw cuda_exception("cudaMemcpy() failed", cuda_error); -+ } -+} -+ -+template -+void copy_to_device(T* dst, T const* src, size_t count = 1) { -+ copy(dst, src, count, cudaMemcpyHostToDevice); -+} -+ -+template -+void copy_to_host(T* dst, T const* src, size_t count = 1) { -+ copy(dst, src, count, cudaMemcpyDeviceToHost); -+} -+ -+template -+void copy_device_to_device(T* dst, T const* src, size_t count = 1) { -+ copy(dst, src, count, cudaMemcpyDeviceToDevice); -+} -+ -+template -+void copy_host_to_host(T* dst, T const* src, size_t count = 1) { -+ copy(dst, src, count, cudaMemcpyHostToHost); -+} -+ -+/// Copies elements from device memory to host-side range -+template -+void insert_to_host(OutputIterator begin, OutputIterator end, T const* device_begin) { -+ size_t elements = end - begin; -+ copy_to_host(&*begin, device_begin, elements); -+} -+ -+/// Copies elements to device memory from host-side range -+template -+void insert_to_device(T* device_begin, InputIterator begin, InputIterator end) { -+ size_t elements = end - begin; -+ copy_to_device(device_begin, &*begin, elements); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device_memory -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+class DeviceAllocation { -+public: -+ -+ /// Delete functor for CUDA device memory -+ struct deleter { -+ void operator()(T* ptr) { -+ cudaError_t cuda_error = (cudaFree(ptr)); -+ if (cuda_error != cudaSuccess) { -+ // noexcept -+ // throw cuda_exception("cudaFree() failed", cuda_error); -+ return; -+ } -+ } -+ }; -+ -+public: -+ // -+ // Data members -+ // -+ -+ /// Number of elements of T allocated on the current CUDA device -+ size_t capacity; -+ -+ /// Smart pointer -+ platform::unique_ptr smart_ptr; -+ -+public: -+ -+ // -+ // Static methods -+ // -+ -+ /// Static member to compute the number of bytes needed for a given number of elements -+ static size_t bytes(size_t elements) { -+ if (sizeof_bits::value < 8) { -+ size_t const kElementsPerByte = 8 / sizeof_bits::value; -+ return elements / kElementsPerByte; -+ } -+ else { -+ size_t const kBytesPerElement = sizeof_bits::value / 8; -+ return elements * kBytesPerElement; -+ } -+ } -+ -+public: -+ -+ // -+ // Methods -+ // -+ -+ /// Constructor: allocates no memory -+ DeviceAllocation() : capacity(0) {} -+ -+ /// Constructor: allocates \p capacity elements on the current CUDA device -+ DeviceAllocation(size_t _capacity) : -+ smart_ptr(device_memory::allocate(_capacity)), capacity(_capacity) {} -+ -+ /// Constructor: allocates \p capacity elements on the current CUDA device taking ownership of the allocation -+ DeviceAllocation(T *ptr, size_t _capacity) : smart_ptr(ptr), capacity(_capacity) {} -+ -+ /// Copy constructor -+ DeviceAllocation(DeviceAllocation const &p): -+ smart_ptr(device_memory::allocate(p.capacity)), capacity(p.capacity) { -+ -+ device_memory::copy_device_to_device(smart_ptr.get(), p.get(), capacity); -+ } -+ -+ /// Move constructor -+ DeviceAllocation(DeviceAllocation &&p): capacity(0) { -+ std::swap(smart_ptr, p.smart_ptr); -+ std::swap(capacity, p.capacity); -+ } -+ -+ /// Destructor -+ ~DeviceAllocation() { reset(); } -+ -+ /// Returns a pointer to the managed object -+ T* get() const { return smart_ptr.get(); } -+ -+ /// Releases the ownership of the managed object (without deleting) and resets capacity to zero -+ T* release() { -+ capacity = 0; -+ return smart_ptr.release(); -+ } -+ -+ /// Deletes the managed object and resets capacity to zero -+ void reset() { -+ capacity = 0; -+ smart_ptr.reset(); -+ } -+ -+ /// Deletes managed object, if owned, and allocates a new object -+ void reset(size_t _capacity) { -+ reset(device_memory::allocate(_capacity), _capacity); -+ } -+ -+ /// Deletes managed object, if owned, and replaces its reference with a given pointer and capacity -+ void reset(T* _ptr, size_t _capacity) { -+ smart_ptr.reset(_ptr); -+ capacity = _capacity; -+ } -+ -+ /// Allocates a new buffer and copies the old buffer into it. The old buffer is then released. -+ void reallocate(size_t new_capacity) { -+ -+ platform::unique_ptr new_allocation(device_memory::allocate(new_capacity)); -+ -+ device_memory::copy_device_to_device( -+ new_allocation.get(), -+ smart_ptr.get(), -+ std::min(new_capacity, capacity)); -+ -+ std::swap(smart_ptr, new_allocation); -+ std::swap(new_capacity, capacity); -+ } -+ -+ /// Returns the number of elements -+ size_t size() const { -+ return capacity; -+ } -+ -+ /// Returns the number of bytes needed to store the allocation -+ size_t bytes() const { -+ return bytes(capacity); -+ } -+ -+ /// Returns a pointer to the object owned by *this -+ T* operator->() const { return smart_ptr.get(); } -+ -+ /// Returns the deleter object which would be used for destruction of the managed object. -+ deleter& get_deleter() { return smart_ptr.get_deleter(); } -+ -+ /// Returns the deleter object which would be used for destruction of the managed object (const) -+ const deleter& get_deleter() const { return smart_ptr.get_deleter(); } -+ -+ /// Copies a device-side memory allocation -+ DeviceAllocation & operator=(DeviceAllocation const &p) { -+ if (capacity != p.capacity) { -+ smart_ptr.reset(device_memory::allocate(p.capacity)); -+ capacity = p.capacity; -+ } -+ device_memory::copy_device_to_device(smart_ptr.get(), p.get(), capacity); -+ return *this; -+ } -+ -+ /// Move assignment -+ DeviceAllocation & operator=(DeviceAllocation && p) { -+ std::swap(smart_ptr, p.smart_ptr); -+ std::swap(capacity, p.capacity); -+ return *this; -+ } -+ -+ /// Copies the entire allocation from another location in device memory. -+ void copy_from_device(T const *ptr) const { -+ copy_from_device(ptr, capacity); -+ } -+ -+ /// Copies a given number of elements from device memory -+ void copy_from_device(T const *ptr, size_t elements) const { -+ device_memory::copy_device_to_device(get(), ptr, elements); -+ } -+ -+ void copy_to_device(T *ptr) const { -+ copy_to_device(ptr, capacity); -+ } -+ -+ void copy_to_device(T *ptr, size_t elements) const { -+ device_memory::copy_device_to_device(ptr, get(), elements); -+ } -+ -+ void copy_from_host(T const *ptr) const { -+ copy_from_host(ptr, capacity); -+ } -+ -+ void copy_from_host(T const *ptr, size_t elements) const { -+ device_memory::copy_to_device(get(), ptr, elements); -+ } -+ -+ void copy_to_host(T *ptr) const { -+ copy_to_host(ptr, capacity); -+ } -+ -+ void copy_to_host(T *ptr, size_t elements) const { -+ device_memory::copy_to_host(ptr, get(), elements); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace device_memory { -+ -+/// Device allocation abstraction that tracks size and capacity -+template -+using allocation = cutlass::DeviceAllocation; -+ -+} // namespace device_memory -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/device_nchw_to_nhwc.h b/3rdparty/cutlass/tools/util/include/cutlass/util/device_nchw_to_nhwc.h -new file mode 100644 -index 0000000..8628c7a ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/device_nchw_to_nhwc.h -@@ -0,0 +1,141 @@ -+/****************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ ******************************************************************************/ -+ -+#pragma once -+ -+/** -+ * \file -+ * \brief cuda kernels to transform a device memory tensor from NCHW layout to NHWC layout. -+ */ -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_coord.h" -+#include "cutlass/tensor_ref.h" -+ -+namespace cutlass { -+ -+/** \brief interface to transform a device memory tensor from NCHW layout to NHWC layout. -+ * \tparam T: data type -+ */ -+template -+void nchw_to_nhwc(cutlass::Tensor4DCoord input_tensor_size, -+ cutlass::Tensor4DCoord output_tensor_size, -+ TensorRef ref_input, -+ TensorRef ref_output, -+ cudaStream_t stream); -+ -+template -+__global__ void nchw_to_nhwc_kernel(T *output, -+ const T *input, -+ const int n, -+ const int h, -+ const int w, -+ const int c) { -+ const int hw = h*w; -+ const int chw = c*hw; -+ __shared__ T shbuf[32 * (32 + 1)]; -+ const int32_t tid = threadIdx.y*blockDim.x + threadIdx.x; -+ const int32_t wid = tid / 32; -+ const int32_t lid = tid % 32; -+ const int32_t ni = blockIdx.z; -+ const int32_t ci0 = blockIdx.y * 32; -+ const int32_t hwi0 = blockIdx.x * 32; -+ -+ const size_t input_idx = ni * chw + (ci0 + wid) * hw + hwi0; -+ const T *A = input + input_idx; -+ if (hwi0 + lid < hw) { -+ const int lid_x_33 = lid * 33; -+ if ((ci0 + 32) <= c) { -+ int ci = wid; // between 0 and 7 -+ CUTLASS_PRAGMA_UNROLL -+ for (int cLoopIdx = 0; cLoopIdx < 4; cLoopIdx++) { -+ shbuf[lid_x_33 + ci] = A[lid]; -+ A = &A[8 * hw]; -+ ci += 8; -+ } -+ } else { -+ for (int ci = wid; ci < 32; ci += 8) { -+ if ((ci + ci0) < c) { -+ shbuf[lid_x_33 + ci] = A[lid]; -+ } -+ A = &A[8 * hw]; -+ } -+ } -+ } -+ __syncthreads(); -+ -+ const int32_t ciOut = ci0 + lid; -+ output = &output[ni * chw + ciOut]; -+ if (ciOut < c) { -+ if (hwi0 + 32 < hw) { -+ int hwI = wid; -+ CUTLASS_PRAGMA_UNROLL -+ for (int hwLoopIdx = 0; hwLoopIdx < 4; ++hwLoopIdx) { -+ output[(hwi0 + hwI) * c] = shbuf[(hwI)*33 + lid]; -+ hwI += 8; -+ } -+ } else { -+ for (int hwI = wid; hwI < 32; hwI += 8) { -+ if (hwi0 + hwI < hw) { -+ output[(hwi0 + hwI) * c] = shbuf[(hwI)*33 + lid]; -+ } -+ } -+ } -+ } -+} -+ -+template -+void nchw_to_nhwc(cutlass::Tensor4DCoord input_tensor_size, -+ cutlass::Tensor4DCoord output_tensor_size, -+ TensorRef ref_input, -+ TensorRef ref_output, -+ cudaStream_t stream) { -+ -+ assert( -+ input_tensor_size.n() == output_tensor_size.n() && -+ input_tensor_size.c() == output_tensor_size.h() && -+ input_tensor_size.h() == output_tensor_size.w() && -+ input_tensor_size.w() == output_tensor_size.c()); -+ -+ int n = output_tensor_size.n(); -+ int h = output_tensor_size.h(); -+ int w = output_tensor_size.w(); -+ int c = output_tensor_size.c(); -+ -+ dim3 grid((h*w + 31)/32, (c + 31)/32, n); -+ dim3 block(32, 8); -+ nchw_to_nhwc_kernel<<>>(ref_output.data(), ref_input.data(), -+ n, h, w, c); -+} -+ -+} //namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/device_nhwc_padding.h b/3rdparty/cutlass/tools/util/include/cutlass/util/device_nhwc_padding.h -new file mode 100644 -index 0000000..86e5fa7 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/device_nhwc_padding.h -@@ -0,0 +1,276 @@ -+/****************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ ******************************************************************************/ -+ -+#pragma once -+ -+/** -+ * \file -+ * \brief cuda kernels for padding in device memory with NHWC layout. -+ */ -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_coord.h" -+#include "cutlass/tensor_ref.h" -+ -+namespace cutlass { -+ -+/** \brief interface for padding in a device memory tensor with NHWC layout -+ * \tparam T: data type -+ */ -+template -+void nhwc_padding(cutlass::Tensor4DCoord input_tensor_size, -+ cutlass::Tensor4DCoord output_tensor_size, -+ TensorRef ref_input, -+ TensorRef ref_output, -+ cudaStream_t stream); -+ -+ -+template -+__global__ void nhwc_padding_kernel(const int32_t n, -+ const int32_t h, -+ const int32_t w, -+ const int32_t c_in, -+ const int32_t c_out, -+ const T zero, -+ const T *input, -+ T *output){ -+ -+ const int32_t idx_jump = blockDim.x * gridDim.x; -+ const int32_t total_elements = n * h * w * c_out; -+ -+ int32_t c_idx, w_idx, h_idx, n_idx, resudial; -+ -+ T value; -+ for (int32_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total_elements; idx += idx_jump) { -+ -+ c_idx = idx%c_out; -+ if (c_idx >= c_in){ -+ value = zero; -+ } -+ else{ -+ resudial = idx/c_out; -+ w_idx = resudial%w; -+ resudial = resudial/w; -+ h_idx = resudial%h; -+ n_idx = resudial/h; -+ resudial = ((n_idx * h + h_idx) * w + w_idx) * c_in + c_idx; -+ value = input[resudial]; -+ } -+ output[idx] = value; -+ } -+} -+ -+ -+// fast kernel for c_in = 3 & c_out = 4 -+template -+__global__ void nhwc_padding_channel_3To4_kernel(const int32_t n, -+ const int32_t h, -+ const int32_t w, -+ const Tio *input, -+ Tio *output, -+ const int32_t max_output_element, -+ const int32_t max_input_element, -+ const Tio zero_io, -+ const Telement zero_element){ -+ __shared__ Tio shm[192]; -+ const int tidx = blockIdx.x * 192 + threadIdx.x; -+ const int threadidx = threadIdx.x; -+ -+ shm[threadIdx.x] = tidx >= max_input_element ? zero_io : input[tidx]; -+ __syncthreads(); -+ -+ const int ouput_offset = blockIdx.x * 256; -+ const int lower_bound = max_output_element < ouput_offset + 256 ? max_output_element : ouput_offset + 256; -+ for (int i = ouput_offset + threadidx, j = threadidx ; i < lower_bound ; i+=192, j+=192) -+ { -+ const Telement* shm_element = (const Telement*)shm + j*3*element_in_Tio/4; -+ Telement array[element_in_Tio]; -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0 ; k < element_in_Tio ; k++) -+ array[k] = ((k+1)%4 == 0) ? zero_element : shm_element[(k > 3) ? (k - 1) : k]; -+ output[i] = *((const Tio *)array); -+ } -+} -+ -+// fast kernel for c_in = 3 & c_out = 8 -+template -+__global__ void nhwc_padding_channel_3To8_kernel(const int32_t n, -+ const int32_t h, -+ const int32_t w, -+ const Tio *input, -+ Tio *output, -+ const int32_t max_output_element, -+ const int32_t max_input_element, -+ const Tio zero_io, -+ const Telement zero_element){ -+ __shared__ Tio shm[192]; -+ const int tidx = blockIdx.x * 192 + threadIdx.x; -+ const int threadidx = threadIdx.x; -+ -+ shm[threadIdx.x] = tidx >= max_input_element ? zero_io : input[tidx]; -+ __syncthreads(); -+ -+ const int ouput_offset = blockIdx.x * 512; -+ const int lower_bound = max_output_element < ouput_offset + 512 ? max_output_element : ouput_offset + 512; -+ for (int i = ouput_offset + threadidx, j = threadidx ; i < lower_bound ; i+=192, j+=192) -+ { -+ const Telement* shm_element = (const Telement*)shm + (element_in_Tio == 4 ? j/2 : j)*3; -+ Telement array[element_in_Tio]; -+ //float -+ if (element_in_Tio == 4){ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0 ; k < element_in_Tio ; k++) -+ array[k] = ((j % 2) == 1) ? zero_element : ((k >= 3) ? zero_element : shm_element[k]); -+ } -+ //half -+ else{ -+ CUTLASS_PRAGMA_UNROLL -+ for (int k = 0 ; k < element_in_Tio ; k++) -+ array[k] = (k >= 3) ? zero_element : shm_element[k]; -+ } -+ output[i] = *((const Tio *)array); -+ } -+} -+ -+template -+void nhwc_padding(cutlass::Tensor4DCoord input_tensor_size, -+ cutlass::Tensor4DCoord output_tensor_size, -+ TensorRef ref_input, -+ TensorRef ref_output, -+ cudaStream_t stream){ -+ assert( -+ input_tensor_size.n() == output_tensor_size.n() && -+ input_tensor_size.h() == output_tensor_size.h() && -+ input_tensor_size.w() == output_tensor_size.w() && -+ input_tensor_size.c() <= output_tensor_size.c()); -+ -+ int n = input_tensor_size.n(); -+ int h = input_tensor_size.h(); -+ int w = input_tensor_size.w(); -+ int c_in = input_tensor_size.c(); -+ int c_out = output_tensor_size.c(); -+ -+ //case 1 : channel == 3 padding to 4 or 8 -+ if ((c_out == 4 || c_out == 8) && c_in == 3 && (n*h*w % 8 == 0)){ -+ dim3 block(192); -+ const int nhw = n*h*w; -+ const int nhwc = nhw*c_in; -+ //for half_t -+ if (cutlass::sizeof_bits::value == 16){ -+ const int element_in_Tio = 8; -+ const int max_input_element = nhwc/element_in_Tio; -+ const int max_output_element = nhw*c_out/element_in_Tio; -+ const int4 zero_io = {0, 0, 0, 0}; -+ const half_t zero_element = static_cast(0.0f); -+ dim3 grid((nhwc + 192*element_in_Tio - 1)/(192*element_in_Tio)); -+ if (c_out == 4){ -+ nhwc_padding_channel_3To4_kernel<<>> -+ (n, h, w, -+ (const int4 *)ref_input.data(), -+ (int4 *)ref_output.data(), -+ max_output_element, -+ max_input_element, -+ zero_io, -+ zero_element); -+ } -+ else if (c_out == 8){ -+ nhwc_padding_channel_3To8_kernel<<>> -+ (n, h, w, -+ (const int4 *)ref_input.data(), -+ (int4 *)ref_output.data(), -+ max_output_element, -+ max_input_element, -+ zero_io, -+ zero_element); -+ } -+ } -+ //for float -+ else{ -+ const int element_in_Tio = 4; -+ const int max_input_element = nhwc/element_in_Tio; -+ const int max_output_element = nhw*c_out/element_in_Tio; -+ const float4 zero_io = {0.0f, 0.0f, 0.0f, 0.0f}; -+ const float zero_element = 0.0f; -+ dim3 grid((nhwc + 192*element_in_Tio - 1)/(192*element_in_Tio)); -+ if (c_out == 4){ -+ nhwc_padding_channel_3To4_kernel<<>> -+ (n, h, w, -+ (const float4 *)ref_input.data(), -+ (float4 *)ref_output.data(), -+ max_output_element, -+ max_input_element, -+ zero_io, -+ zero_element); -+ } -+ else if (c_out == 8){ -+ nhwc_padding_channel_3To8_kernel<<>> -+ (n, h, w, -+ (const float4 *)ref_input.data(), -+ (float4 *)ref_output.data(), -+ max_output_element, -+ max_input_element, -+ zero_io, -+ zero_element); -+ } -+ } -+ } -+ //case 2 : even channel -+ else if ((c_out % 2) == 0 && (c_in % 2) == 0){ -+ int32_t total_elements = n * h * w * c_out / 2; -+ int block_size = 256; -+ dim3 grid((total_elements + 255)/256); -+ dim3 block(block_size); -+ //for half_t -+ if (cutlass::sizeof_bits::value == 16){ -+ const __half2 zero = {0.0f, 0.0f}; -+ nhwc_padding_kernel<<>>(n, h, w, c_in/2, c_out/2, zero, (const __half2*)ref_input.data(), (__half2*)ref_output.data()); -+ } -+ //for float -+ else{ -+ const float2 zero = {0.0f, 0.0f}; -+ nhwc_padding_kernel<<>>(n, h, w, c_in/2, c_out/2, zero, (const float2*)ref_input.data(), (float2*)ref_output.data()); -+ } -+ } -+ //case 3 : odd channel -+ else{ -+ int32_t total_elements = n * h * w * c_out; -+ int block_size = 256; -+ dim3 grid((total_elements + 255)/256); -+ dim3 block(block_size); -+ const T zero = static_cast(0.0f); -+ nhwc_padding_kernel<<>>(n, h, w, c_in, c_out, zero, ref_input.data(), ref_output.data()); -+ } -+} -+ -+ -+} //namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/device_nhwc_pooling.h b/3rdparty/cutlass/tools/util/include/cutlass/util/device_nhwc_pooling.h -new file mode 100644 -index 0000000..6bdf866 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/device_nhwc_pooling.h -@@ -0,0 +1,576 @@ -+/****************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ ******************************************************************************/ -+ -+#pragma once -+ -+/** -+ * \file -+ * \brief cuda kernels to do avg/max pooling on a device memory tensor with NHWC layout. -+ */ -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_coord.h" -+#include "cutlass/tensor_ref.h" -+#include "device_utils.h" -+#include -+ -+namespace cutlass { -+ -+/** \brief interface to do avg/max pooling on a device memory tensor with NHWC layout. -+ * \tparam T: data type -+ */ -+template -+void pooling_nhwc(cutlass::Tensor4DCoord input_tensor_size, -+ cutlass::Tensor4DCoord filter_tensor_size, -+ cutlass::Tensor4DCoord output_tensor_size, -+ cutlass::MatrixCoord padding, -+ cutlass::MatrixCoord stride, -+ TensorRef ref_input, -+ TensorRef ref_output, -+ int poolingType, //0 for avg pooling ; 1 for max pooling -+ cudaStream_t stream); -+ -+/** get the output size of pooling -+ */ -+inline int getOutputSize(int H_W, int padding, int kernel_size, int stride) -+{ -+ return (H_W + 2 * padding - kernel_size) / stride + 1; -+} -+ -+/** -+ * input is [N, H, W, C] -+ * assume stride == kernel_size -+ * output_h = (H + 2*padding_H - kernel_H)/stride_H -+ * output_w = (W + 2*padding_W - kernel_W)/stride_W -+ * output is [N, output_h, output_w, C] -+ * grid(N, output_h, output_w) -+ * block(min(C, 256)) : -+ * each block deals with C elements of output when each thread deals with ((C + 255)/256 element of output) -+*/ -+template -+__global__ void pooling_nhwc_element1_kernel(T* output, -+ const T* input, -+ const int N, -+ const int H, -+ const int W, -+ const int C, -+ const int output_H, -+ const int output_W, -+ const int kernel_H, -+ const int kernel_W, -+ const int stride_H, -+ const int stride_W, -+ const int padding_H, -+ const int padding_W) -+{ -+ const int tid = threadIdx.x; -+ const int n_idx = blockIdx.x; -+ const int output_h_idx = blockIdx.y; -+ const int output_w_idx = blockIdx.z; -+ -+ int h_start_idx = output_h_idx * stride_H - padding_H; -+ int h_end_idx = h_start_idx + kernel_H; -+ h_start_idx = (h_start_idx < 0) ? 0 : h_start_idx; -+ h_end_idx = h_end_idx > H ? H : h_end_idx; -+ -+ int w_start_idx = output_w_idx * stride_W - padding_W; -+ int w_end_idx = w_start_idx + kernel_W; -+ w_start_idx = (w_start_idx < 0) ? 0 : w_start_idx; -+ w_end_idx = w_end_idx > W ? W : w_end_idx; -+ -+ input += n_idx * H * W * C; -+ output += ((n_idx * output_H + output_h_idx) * output_W + output_w_idx) * C; -+ const int kernel_size2 = kernel_H * kernel_W; -+ for (int c_idx = tid; c_idx < C; c_idx += blockDim.x) { -+ float pooling; -+ if (IS_AVG_POOLING){ -+ pooling = 0.0f; -+ } -+ else{ -+ pooling = -FLT_MAX; -+ } -+ for (int h = h_start_idx; h < h_end_idx; h++) { -+ for (int w = w_start_idx; w < w_end_idx; w++) { -+ const int idx = (h * W + w) * C; -+ const float tmp = static_cast(input[idx + c_idx]); -+ if (IS_AVG_POOLING){ -+ pooling = pooling + tmp; -+ } -+ else{ -+ pooling = pooling > tmp ? pooling : tmp; -+ } -+ } -+ } -+ -+ T output_val; -+ if (IS_AVG_POOLING){ -+ output_val = T(pooling/kernel_size2); -+ } -+ else{ -+ output_val = T(pooling); -+ } -+ output[c_idx] = output_val; -+ } -+} -+ -+template -+__global__ void pooling_nhwc_element2_kernel(T2* output, -+ const T2* input, -+ const int N, -+ const int H, -+ const int W, -+ const int C, -+ const int output_H, -+ const int output_W, -+ const int kernel_H, -+ const int kernel_W, -+ const int stride_H, -+ const int stride_W, -+ const int padding_H, -+ const int padding_W) -+{ -+ const int tid = threadIdx.x; -+ const int n_idx = blockIdx.x; -+ const int output_h_idx = blockIdx.y; -+ const int output_w_idx = blockIdx.z; -+ -+ int h_start_idx = output_h_idx * stride_H - padding_H; -+ int h_end_idx = h_start_idx + kernel_H; -+ h_start_idx = (h_start_idx < 0) ? 0 : h_start_idx; -+ h_end_idx = h_end_idx > H ? H : h_end_idx; -+ -+ int w_start_idx = output_w_idx * stride_W - padding_W; -+ int w_end_idx = w_start_idx + kernel_W; -+ w_start_idx = (w_start_idx < 0) ? 0 : w_start_idx; -+ w_end_idx = w_end_idx > W ? W : w_end_idx; -+ -+ input += n_idx * H * W * C; -+ output += ((n_idx * output_H + output_h_idx) * output_W + output_w_idx) * C; -+ const int kernel_size2 = kernel_H * kernel_W; -+ for (int c_idx = tid; c_idx < C; c_idx += blockDim.x) { -+ float2 pooling; -+ if (IS_AVG_POOLING) { -+ pooling = {0.0f, 0.0f}; -+ } -+ else { -+ pooling = {-FLT_MAX, -FLT_MAX}; -+ } -+ for (int h = h_start_idx; h < h_end_idx; h++) { -+ for (int w = w_start_idx; w < w_end_idx; w++) { -+ const int idx = (h * W + w) * C; -+ const T2 tmp = input[idx + c_idx]; -+ const float2 tmp_flt2 = {static_cast(tmp.x), static_cast(tmp.y)}; -+ if (IS_AVG_POOLING) { -+ pooling.x += tmp_flt2.x; -+ pooling.y += tmp_flt2.y; -+ } -+ else { -+ pooling.x = pooling.x > tmp_flt2.x ? pooling.x : tmp_flt2.x; -+ pooling.y = pooling.y > tmp_flt2.y ? pooling.y : tmp_flt2.y; -+ } -+ } -+ } -+ -+ T2 output_val; -+ if (IS_AVG_POOLING) { -+ output_val.x = T(pooling.x/kernel_size2); -+ output_val.y = T(pooling.y/kernel_size2); -+ } -+ else { -+ output_val.x = T(pooling.x); -+ output_val.y = T(pooling.y); -+ } -+ output[c_idx] = output_val; -+ } -+} -+ -+/** -+ * output [N, 1, 1, C] -+ * input [N, H, W, C] -+ * grid(C, N) -+ * block(block_size) -- each block deals with H*W/block_size elements; -+*/ -+template -+__global__ void pooling_nxhTo1x1_element1_kernel( -+ T* output, const T* input, const int N, const int HW, const int C) -+{ -+ const int c_idx = blockIdx.x; -+ const int n_idx = blockIdx.y; -+ float pooling[1]; -+ if (IS_AVG_POOLING) { -+ pooling[0] = 0.0f; -+ } -+ else { -+ pooling[0] = -FLT_MAX; -+ } -+ const size_t input_offset = n_idx * HW * C + c_idx; -+ input += input_offset; -+ const size_t output_offset = n_idx * C + c_idx; -+ output += output_offset; -+ int tid = threadIdx.x; -+ -+ for (int index = tid; index < HW; index += blockDim.x) { -+ float val = static_cast(input[index * C]); -+ if (IS_AVG_POOLING) { -+ pooling[0] += val; -+ } -+ else { -+ pooling[0] = pooling[0] > val ? pooling[0] : val; -+ } -+ } -+ if (blockDim.x <= 32) { -+ if (IS_AVG_POOLING) { -+ warpReduceSum(pooling); -+ } -+ else { -+ warpReduceMax(pooling); -+ } -+ } -+ else { -+ if (IS_AVG_POOLING) { -+ blockReduceSum(pooling); -+ } -+ else { -+ blockReduceMax(pooling); -+ } -+ } -+ __syncthreads(); -+ if (threadIdx.x == 0) { -+ T output_val; -+ if (IS_AVG_POOLING) { -+ output_val = T(pooling[0] / HW); -+ } -+ else { -+ output_val = T(pooling[0]); -+ } -+ output[0] = output_val; -+ } -+} -+ -+ -+/** -+ * output [N, 1, 1, C] -+ * input [N, H, W, C] -+ * grid(C/2, N) -+ * block(block_size) -- each thread deals with H*W/block_size * 2 elements; -+*/ -+template -+__global__ void pooling_nxhTo1x1_element2_kernel( -+ T2* output, const T2* input, const int N, const int HW, const int C) -+{ -+ const int c_idx = blockIdx.x; -+ const int n_idx = blockIdx.y; -+ float pooling[2]; -+ if (IS_AVG_POOLING) { -+ pooling[0] = pooling[1] = 0.0f; -+ } -+ else { -+ pooling[0] = pooling[1] = -FLT_MAX; -+ } -+ const int C_2 = C / 2; -+ const size_t input_offset = n_idx * HW * C_2 + c_idx; -+ input += input_offset; -+ const size_t output_offset = n_idx * C_2 + c_idx; -+ output += output_offset; -+ int tid = threadIdx.x; -+ -+ for (int index = tid; index < HW; index += blockDim.x) { -+ T2 val = input[index * C_2]; -+ float2 val_flt2 = {static_cast(val.x), static_cast(val.y)}; -+ if (IS_AVG_POOLING) { -+ pooling[0] += val_flt2.x; -+ pooling[1] += val_flt2.y; -+ } -+ else { -+ pooling[0] = pooling[0] > val_flt2.x ? pooling[0] : val_flt2.x; -+ pooling[1] = pooling[1] > val_flt2.y ? pooling[1] : val_flt2.y; -+ } -+ } -+ if (blockDim.x <= 32) { -+ if (IS_AVG_POOLING) { -+ warpReduceSum(pooling); -+ } -+ else { -+ warpReduceMax(pooling); -+ } -+ } -+ else { -+ if (IS_AVG_POOLING) { -+ blockReduceSum(pooling); -+ } -+ else { -+ blockReduceMax(pooling); -+ } -+ } -+ __syncthreads(); -+ if (threadIdx.x == 0) { -+ T2 output_val; -+ if (IS_AVG_POOLING) { -+ output_val.x = T(pooling[0] / HW); -+ output_val.y = T(pooling[1] / HW); -+ } -+ else { -+ output_val.x = T(pooling[0]); -+ output_val.y = T(pooling[1]); -+ } -+ output[0] = output_val; -+ } -+} -+ -+template -+void pooling_nhwc(cutlass::Tensor4DCoord input_tensor_size, -+ cutlass::Tensor4DCoord filter_tensor_size, -+ cutlass::Tensor4DCoord output_tensor_size, -+ cutlass::Tensor4DCoord padding, -+ cutlass::MatrixCoord stride, -+ TensorRef ref_input, -+ TensorRef ref_output, -+ int poolingType, //0 for avg pooling ; 1 for max pooling -+ cudaStream_t stream) { -+ -+ assert(input_tensor_size.n() == output_tensor_size.n() && -+ input_tensor_size.c() == output_tensor_size.c()); -+ -+ assert(filter_tensor_size.h() == stride.row() && -+ filter_tensor_size.w() == stride.column()); -+ -+ const int N = input_tensor_size.n(); -+ const int H = input_tensor_size.h(); -+ const int W = input_tensor_size.w(); -+ const int C = input_tensor_size.c(); -+ const int padding_H = padding.h(); -+ const int padding_W = padding.w(); -+ const int kernel_H = filter_tensor_size.h(); -+ const int kernel_W = filter_tensor_size.w(); -+ const int stride_H = stride.row(); -+ const int stride_W = stride.column(); -+ -+ const int output_H = getOutputSize(H, padding_H, kernel_H, stride_H); -+ const int output_W = getOutputSize(W, padding_W, kernel_W, stride_W); -+ -+ assert(output_tensor_size.h() == output_H && -+ output_tensor_size.w() == output_W); -+ -+ if (C % 2 != 0) { -+ if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0)) { -+ dim3 grid(C, N); -+ dim3 block(256); -+ if (H*W < block.x){ -+ block.x = (H*W + 31)/32*32; -+ } -+ if (poolingType == 0) { -+ pooling_nxhTo1x1_element1_kernel<<>>( -+ ref_output.data(), -+ ref_input.data(), -+ N, -+ H*W, -+ C); -+ } // if (poolingType == 0) -+ else { -+ pooling_nxhTo1x1_element1_kernel<<>>( -+ ref_output.data(), -+ ref_input.data(), -+ N, -+ H*W, -+ C); -+ } -+ } // if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0)) -+ else { -+ dim3 grid(N, output_H, output_W); -+ dim3 block(256); -+ if (C < block.x) { -+ block.x = C; -+ } -+ if (poolingType == 0) { -+ pooling_nhwc_element1_kernel<<>>( -+ ref_output.data(), -+ ref_input.data(), -+ N, -+ H, -+ W, -+ C, -+ output_H, -+ output_W, -+ kernel_H, -+ kernel_W, -+ stride_H, -+ stride_W, -+ padding_H, -+ padding_W); -+ } // if (poolingType == 0) -+ else { -+ pooling_nhwc_element1_kernel<<>>( -+ ref_output.data(), -+ ref_input.data(), -+ N, -+ H, -+ W, -+ C, -+ output_H, -+ output_W, -+ kernel_H, -+ kernel_W, -+ stride_H, -+ stride_W, -+ padding_H, -+ padding_W); -+ } -+ } -+ } // if (C % 2 != 0)) -+ else { -+ if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0)) { -+ dim3 grid(C/2, N); -+ dim3 block(256); -+ if (H*W < block.x){ -+ block.x = (H*W + 31)/32*32; -+ } -+ if (poolingType == 0) { -+ if (std::is_same::value) { -+ pooling_nxhTo1x1_element2_kernel<<>>( -+ (float2*)(ref_output.data()), -+ (const float2*)(ref_input.data()), -+ N, -+ H*W, -+ C); -+ } // if (std::is_same::value) -+ else { -+ pooling_nxhTo1x1_element2_kernel<<>>( -+ (half2*)(ref_output.data()), -+ (const half2*)(ref_input.data()), -+ N, -+ H*W, -+ C); -+ } -+ } // if (poolingType == 0) -+ else { -+ if (std::is_same::value) { -+ pooling_nxhTo1x1_element2_kernel<<>>( -+ (float2*)(ref_output.data()), -+ (const float2*)(ref_input.data()), -+ N, -+ H*W, -+ C); -+ } // if (std::is_same::value) -+ else { -+ pooling_nxhTo1x1_element2_kernel<<>>( -+ (half2*)(ref_output.data()), -+ (const half2*)(ref_input.data()), -+ N, -+ H*W, -+ C); -+ } -+ } -+ } // if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0)) -+ else { -+ dim3 grid(N, output_H, output_W); -+ dim3 block(256); -+ if (C/2 < block.x) { -+ block.x = C/2; -+ } -+ if (poolingType == 0) { -+ if (std::is_same::value) { -+ pooling_nhwc_element2_kernel<<>>( -+ (float2*)(ref_output.data()), -+ (const float2*)(ref_input.data()), -+ N, -+ H, -+ W, -+ C/2, -+ output_H, -+ output_W, -+ kernel_H, -+ kernel_W, -+ stride_H, -+ stride_W, -+ padding_H, -+ padding_W); -+ } // if (std::is_same::value) -+ else { -+ pooling_nhwc_element2_kernel<<>>( -+ (half2*)(ref_output.data()), -+ (const half2*)(ref_input.data()), -+ N, -+ H, -+ W, -+ C/2, -+ output_H, -+ output_W, -+ kernel_H, -+ kernel_W, -+ stride_H, -+ stride_W, -+ padding_H, -+ padding_W); -+ } -+ } // if (poolingType == 0) -+ else { -+ if (std::is_same::value) { -+ pooling_nhwc_element2_kernel<<>>( -+ (float2*)(ref_output.data()), -+ (const float2*)(ref_input.data()), -+ N, -+ H, -+ W, -+ C/2, -+ output_H, -+ output_W, -+ kernel_H, -+ kernel_W, -+ stride_H, -+ stride_W, -+ padding_H, -+ padding_W); -+ } // if (std::is_same::value) -+ else { -+ pooling_nhwc_element2_kernel<<>>( -+ (half2*)(ref_output.data()), -+ (const half2*)(ref_input.data()), -+ N, -+ H, -+ W, -+ C/2, -+ output_H, -+ output_W, -+ kernel_H, -+ kernel_W, -+ stride_H, -+ stride_W, -+ padding_H, -+ padding_W); -+ } -+ } -+ } -+ } -+} -+ -+} //namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/device_nhwc_to_nchw.h b/3rdparty/cutlass/tools/util/include/cutlass/util/device_nhwc_to_nchw.h -new file mode 100644 -index 0000000..d71fd1e ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/device_nhwc_to_nchw.h -@@ -0,0 +1,144 @@ -+/****************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ ******************************************************************************/ -+ -+#pragma once -+ -+/** -+ * \file -+ * \brief cuda kernels to transform a device memory tensor from NHWC layout to NCHW layout. -+ */ -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_coord.h" -+#include "cutlass/tensor_ref.h" -+ -+namespace cutlass { -+ -+/** \brief interface to transform a device memory tensor from NHWC layout to NCHW layout. -+ * \tparam T: data type -+ */ -+template -+void nhwc_to_nchw(cutlass::Tensor4DCoord input_tensor_size, -+ cutlass::Tensor4DCoord output_tensor_size, -+ TensorRef ref_input, -+ TensorRef ref_output, -+ cudaStream_t stream); -+ -+ -+template -+__global__ void nhwc_to_nchw_kernel(T *output, -+ const T *input, -+ const int n, -+ const int h, -+ const int w, -+ const int c) { -+ -+ const int hw = h*w; -+ const int hwc = hw*c; -+ __shared__ T shbuf[32 * (32 + 1)]; -+ const int32_t tid = threadIdx.y*blockDim.x + threadIdx.x; -+ const int32_t wid = tid / 32; -+ const int32_t lid = tid % 32; -+ const int32_t ni = blockIdx.z; -+ const int32_t hwi0 = blockIdx.y * 32; -+ const int32_t ci0 = blockIdx.x * 32; -+ -+ const size_t input_idx = ni * hwc + (hwi0 + wid) * c + ci0; -+ const T *A = input + input_idx; -+ if (ci0 + lid < c) { -+ const int lid_x_33 = lid * 33; -+ if ((hwi0 + 32) <= hw) { -+ int hwi = wid; // between 0 and 7 -+ CUTLASS_PRAGMA_UNROLL -+ for (int cLoopIdx = 0; cLoopIdx < 4; cLoopIdx++) { -+ shbuf[lid_x_33 + hwi] = A[lid]; -+ A = &A[8 * c]; -+ hwi += 8; -+ } -+ } else { -+ for (int hwi = wid; hwi < 32; hwi += 8) { -+ if ((hwi + hwi0) < hw) { -+ shbuf[lid_x_33 + hwi] = A[lid]; -+ } -+ A = &A[8 * c]; -+ } -+ } -+ } -+ __syncthreads(); -+ -+ const int32_t hwiOut = hwi0 + lid; -+ output = &output[ni * hwc + hwiOut]; -+ if (hwiOut < hw) { -+ if (ci0 + 32 < c) { -+ int cI = wid; -+ CUTLASS_PRAGMA_UNROLL -+ for (int hwLoopIdx = 0; hwLoopIdx < 4; ++hwLoopIdx) { -+ output[(ci0 + cI) * hw] = shbuf[(cI)*33 + lid]; -+ cI += 8; -+ } -+ } else { -+ for (int cI = wid; cI < 32; cI += 8) { -+ if (ci0 + cI < c) { -+ output[(ci0 + cI) * hw] = shbuf[(cI)*33 + lid]; -+ } -+ } -+ } -+ } -+} -+ -+template -+void nhwc_to_nchw(cutlass::Tensor4DCoord input_tensor_size, -+ cutlass::Tensor4DCoord output_tensor_size, -+ TensorRef ref_input, -+ TensorRef ref_output, -+ cudaStream_t stream) { -+ -+ assert( -+ input_tensor_size.n() == output_tensor_size.n() && -+ input_tensor_size.h() == output_tensor_size.c() && -+ input_tensor_size.w() == output_tensor_size.h() && -+ input_tensor_size.c() == output_tensor_size.w()); -+ -+ int n = input_tensor_size.n(); -+ int h = input_tensor_size.h(); -+ int w = input_tensor_size.w(); -+ int c = input_tensor_size.c(); -+ -+ dim3 grid((c + 31)/32, (h*w + 31)/32, n); -+ dim3 block(32, 8); -+ nhwc_to_nchw_kernel<<>>(ref_output.data(), ref_input.data(), -+ n, h, w, c); -+ -+} -+ -+} //namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/device_utils.h b/3rdparty/cutlass/tools/util/include/cutlass/util/device_utils.h -new file mode 100644 -index 0000000..00414a5 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/device_utils.h -@@ -0,0 +1,127 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief utils code for device cutlass code -+*/ -+ -+#pragma once -+ -+#include -+#include -+#define FINAL_MASK 0xffffffff -+ -+struct half4 { -+ half x, y, z, w; -+}; -+ -+template -+__inline__ __device__ T warpReduceSum(T* val) -+{ -+#pragma unroll -+ for (int i = 0; i < NUM; i++) { -+#pragma unroll -+ for (int mask = 16; mask > 0; mask >>= 1) -+ val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32); -+ } -+ return (T)(0.0f); -+} -+ -+template -+__inline__ __device__ T blockReduceSum(T* val) -+{ -+ __shared__ T shared[NUM][33]; -+ int lane = threadIdx.x & 0x1f; -+ int wid = threadIdx.x >> 5; -+ -+ warpReduceSum(val); -+ -+ if (lane == 0) { -+#pragma unroll -+ for (int i = 0; i < NUM; i++) { -+ shared[i][wid] = val[i]; -+ } -+ } -+ -+ __syncthreads(); -+ -+ bool is_mask = threadIdx.x < (blockDim.x / 32.f); -+#pragma unroll -+ for (int i = 0; i < NUM; i++) { -+ val[i] = is_mask ? shared[i][lane] : (T)(0.0f); -+ } -+ warpReduceSum(val); -+ return (T)0.0f; -+} -+ -+template -+__inline__ __device__ T warpReduceMax(T* val) -+{ -+#pragma unroll -+ for (int i = 0; i < NUM; i++) { -+#pragma unroll -+ for (int mask = 16; mask > 0; mask >>= 1) -+ val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32)); -+ } -+ return (T)(0.0f); -+} -+ -+template -+__inline__ __device__ T blockReduceMax(T* val) -+{ -+ static __shared__ T shared[32][NUM]; -+ int lane = threadIdx.x & 0x1f; // in-warp idx -+ int wid = threadIdx.x >> 5; // warp idx -+ -+ warpReduceMax(val); // get maxx in each warp -+ -+ if (lane == 0) // record in-warp maxx by warp Idx -+ { -+#pragma unroll -+ for (int i = 0; i < NUM; i++) { -+ shared[wid][i] = val[i]; -+ } -+ } -+ -+ __syncthreads(); -+ -+ // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent -+ // blockDim.x is not divided by 32 -+ bool is_mask = threadIdx.x < (blockDim.x / 32.f); -+#pragma unroll -+ for (int i = 0; i < NUM; i++) { -+ val[i] = is_mask ? shared[lane][i] : (T)(-FLT_MAX); -+ } -+ warpReduceMax(val); -+ -+ return (T)0.0f; -+} -+ -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/distribution.h b/3rdparty/cutlass/tools/util/include/cutlass/util/distribution.h -new file mode 100644 -index 0000000..7fee888 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/distribution.h -@@ -0,0 +1,143 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+/*! \file -+ \brief This header contains a class to parametrize a statistical distribution function. -+*/ -+ -+#include -+ -+namespace cutlass { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Distribution type -+struct Distribution { -+ /// Variant types -+ enum Kind { Invalid, Uniform, Gaussian, Identity, Sequential, AllZeros, AllOnes }; -+ -+ /// Distribution state -+ union { -+ /// Uniform distribution -+ struct { -+ double min; -+ double max; -+ } uniform; -+ -+ /// Gaussian distribution -+ struct { -+ double mean; -+ double stddev; -+ } gaussian; -+ -+ /// Elements are linear combination of row and column index -+ struct { -+ double start; -+ double delta; -+ } sequential; -+ }; -+ -+ /// Active variant kind -+ Kind kind; -+ -+ /// Random values are cast to integer after scaling by this power of two -+ int int_scale; -+ -+ // -+ // Methods -+ // -+ -+ Distribution() : kind(Invalid), int_scale(0) {} -+ -+ /// Configures distribution as uniform random -+ Distribution &set_uniform(double _min, double _max, int _int_scale = 0) { -+ kind = Uniform; -+ uniform.min = _min; -+ uniform.max = _max; -+ int_scale = _int_scale; -+ return *this; -+ } -+ -+ /// Configures distribution as Gaussian distribution -+ Distribution &set_gaussian(double _mean, double _stddev, int _int_scale = 0) { -+ kind = Gaussian; -+ gaussian.mean = _mean; -+ gaussian.stddev = _stddev; -+ int_scale = _int_scale; -+ return *this; -+ } -+ -+ /// Sets identity -+ Distribution &set_identity() { -+ kind = Identity; -+ return *this; -+ } -+ -+ /// Sets sequential -+ Distribution &set_sequential(double start, double delta, int _int_scale = 0) { -+ kind = Sequential; -+ sequential.start = start; -+ sequential.delta = delta; -+ int_scale = _int_scale; -+ return *this; -+ } -+}; -+ -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Prints a Distribution to ostream -+inline std::ostream &operator<<(std::ostream &out, cutlass::Distribution const &dist) { -+ switch (dist.kind) { -+ case cutlass::Distribution::Uniform: -+ out << "uniform, min: " << dist.uniform.min << ", max: " << dist.uniform.max; -+ break; -+ case cutlass::Distribution::Gaussian: -+ out << "gaussian, mean: " << dist.gaussian.mean << ", stddev: " << dist.gaussian.stddev; -+ break; -+ case cutlass::Distribution::Identity: -+ out << "identity"; -+ break; -+ case cutlass::Distribution::Sequential: -+ out << "sequential"; -+ break; -+ default: -+ out << "unknown"; -+ } -+ -+ out << ", int_scale: " << dist.int_scale; -+ -+ return out; -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/exceptions.h b/3rdparty/cutlass/tools/util/include/cutlass/util/exceptions.h -new file mode 100644 -index 0000000..a349d49 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/exceptions.h -@@ -0,0 +1,69 @@ -+/****************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ ******************************************************************************/ -+ -+#pragma once -+ -+/** -+ * \file -+ * \brief C++ exception semantics for CUDA error codes -+ */ -+ -+#include -+#include -+#include -+ -+#include "cutlass/platform/platform.h" -+ -+namespace cutlass { -+ -+/// C++ exception wrapper for CUDA \p cudaError_t -+class cuda_exception : public std::exception { -+ public: -+ /// Constructor -+ cuda_exception(const char* msg = "", cudaError_t err = cudaErrorUnknown) : msg(msg), err(err) {} -+ -+ /// Returns the underlying CUDA \p cudaError_t -+ cudaError_t cudaError() const { return err; } -+ -+ protected: -+ /// Explanatory string -+ const char* msg; -+ -+ /// Underlying CUDA \p cudaError_t -+ cudaError_t err; -+}; -+ -+/// Writes a cuda_exception instance to an output stream -+inline std::ostream& operator<<(std::ostream& out, cuda_exception const& e) { -+ return out << e.what() << ": " << cudaGetErrorString(e.cudaError()); -+} -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/helper_cuda.hpp b/3rdparty/cutlass/tools/util/include/cutlass/util/helper_cuda.hpp -new file mode 100644 -index 0000000..15e0bc8 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/helper_cuda.hpp -@@ -0,0 +1,116 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include -+ -+#include -+ -+namespace cute -+{ -+ -+void -+device_init(int device_id, bool quiet = false) -+{ -+ cudaDeviceProp device_prop; -+ std::size_t device_free_physmem; -+ std::size_t device_total_physmem; -+ -+ CUTE_CHECK_ERROR(cudaSetDevice(device_id)); -+ CUTE_CHECK_ERROR(cudaMemGetInfo(&device_free_physmem, &device_total_physmem)); -+ CUTE_CHECK_ERROR(cudaGetDeviceProperties(&device_prop, device_id)); -+ -+ if (device_prop.major < 1) { -+ fprintf(stderr, "Device does not support CUDA.\n"); -+ exit(1); -+ } -+ -+ //float device_giga_bandwidth = float(device_prop.memoryBusWidth) * device_prop.memoryClockRate * 2 / 8 / 1000 / 1000; -+ -+ if (!quiet) { -+ printf("Using device %d: %s (SM%d, %d SMs)\n", -+ device_id, device_prop.name, -+ device_prop.major * 10 + device_prop.minor, -+ device_prop.multiProcessorCount); -+ fflush(stdout); -+ } -+} -+ -+/** -+ * Convert the SM version (e.g. v7.0, v7.5) to the physical number of cores. -+ */ -+inline int -+_ConvertSMVer2Cores(int major, int minor) -+{ -+ // Defines for GPU Architecture types (using the SM version to determine -+ // the # of cores per SM -+ typedef struct { -+ int SM; // 0xMm (hexidecimal notation), M = SM Major version, -+ // and m = SM minor version -+ int Cores; -+ } sSMtoCores; -+ -+ sSMtoCores nGpuArchCoresPerSM[] = { -+ {0x30, 192}, -+ {0x32, 192}, -+ {0x35, 192}, -+ {0x37, 192}, -+ {0x50, 128}, -+ {0x52, 128}, -+ {0x53, 128}, -+ {0x60, 64}, -+ {0x61, 128}, -+ {0x62, 128}, -+ {0x70, 64}, -+ {0x72, 64}, -+ {0x75, 64}, -+ {-1, -1}}; -+ -+ int index = 0; -+ -+ while (nGpuArchCoresPerSM[index].SM != -1) { -+ if (nGpuArchCoresPerSM[index].SM == ((major << 4) + minor)) { -+ return nGpuArchCoresPerSM[index].Cores; -+ } -+ index++; -+ } -+ -+ // If we don't find the values, we default use the previous one -+ // to run properly -+ printf("MapSMtoCores for SM %d.%d is undefined." -+ " Default to use %d Cores/SM\n", -+ major, minor, nGpuArchCoresPerSM[index - 1].Cores); -+ -+ return nGpuArchCoresPerSM[index - 1].Cores; -+} -+ -+} // end namespace cute -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/host_reorder.h b/3rdparty/cutlass/tools/util/include/cutlass/util/host_reorder.h -new file mode 100644 -index 0000000..c17c0a2 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/host_reorder.h -@@ -0,0 +1,111 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief reorder data from the host side -+*/ -+ -+#pragma once -+ -+#include "cutlass/coord.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+namespace cutlass { -+ -+/// This is needed for the interleaved integer tensor core kernels. The purpose -+/// is to use skip the shared memory part in the epilogue. -+template -+void reorder_column(TensorRef dest, -+ TensorRef src, -+ cutlass::gemm::GemmCoord problem_size) { -+ const int InstructionShapeCol = 8; -+ // 4 threads per Quad -+ const int ElementsPerThread = InstructionShapeCol / 4; -+ // 4 threads per Quad -+ const int ReorderedElementsPerThread = -+ Interleaved / 4; -+ -+ for (int n = 0; n < problem_size.n(); n++) { -+ for (int k = 0; k < problem_size.k(); k++) { -+ dest.at({k, (n / Interleaved) * Interleaved + -+ ((n % ReorderedElementsPerThread) / ElementsPerThread) * -+ InstructionShapeCol + -+ ((n % Interleaved) / ReorderedElementsPerThread) * -+ ElementsPerThread + -+ (n % ElementsPerThread)}) = src.at({k, n}); -+ } -+ } -+} -+ -+template -+void reorder_convK(TensorRef dest, -+ TensorRef src, -+ cutlass::gemm::GemmCoord problem_size) { -+ -+ TensorRef> mappedDest(dest.data(), dest.stride(0)); -+ TensorRef> mappedSrc(src.data(), src.stride(0)); -+ -+ reorder_column( -+ mappedDest, mappedSrc, problem_size); -+} -+ -+/// This is needed for the sparse tensor core kernels. The purpose -+/// is to use ldmatrix to load from shared memory to the register file. -+template -+void reorder_meta(TensorRef dest, -+ TensorRef src, -+ cutlass::gemm::GemmCoord problem_size) { -+ for (int m = 0; m < problem_size.m(); m++) { -+ for (int k = 0; k < problem_size.k(); k++) { -+ // First reorder the rows. -+ int group = (sizeof(Element) == 2) ? 32 : 16; -+ int interweave = (sizeof(Element) == 2) ? 4 : 2; -+ -+ int dest_row = m / group * group + (m % 8) * interweave + (m % group) / 8; -+ int dest_col = k; -+ -+ // Next swizzle the 2x2 blocks from Z to N. -+ if (((dest_row % 2) == 0) && ((dest_col % 2) == 1)) { -+ ++dest_row; -+ --dest_col; -+ } else if (((dest_row % 2) == 1) && ((dest_col % 2) == 0)) { -+ --dest_row; -+ ++dest_col; -+ } -+ -+ dest.at({dest_row, dest_col}) = src.at({m, k}); -+ } -+ } -+} -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/host_tensor.h b/3rdparty/cutlass/tools/util/include/cutlass/util/host_tensor.h -new file mode 100644 -index 0000000..9909ee9 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/host_tensor.h -@@ -0,0 +1,507 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+/*! \file -+ \brief HostTensor contributes management for both host and device memory. -+ -+ HostTensor allocates host and device memory upon construction. Basic element-wise operations on -+ host memory synchronize device memory automatically. Explicit copy operations provide abstractions -+ for CUDA memcpy operations. -+ -+ Call {host, device}_{data, ref, view}() for accessing host or device memory. -+ -+ See cutlass/tensor_ref.h and cutlass/tensor_view.h for more details. -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+ -+#include "device_memory.h" -+ -+namespace cutlass { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Host tensor -+template < -+ /// Data type of element stored within tensor (concept: NumericType) -+ typename Element_, -+ /// Defines a mapping from logical coordinate to linear memory (concept: Layout) -+ typename Layout_ -+> -+class HostTensor { -+public: -+ -+ /// Data type of individual access -+ using Element = Element_; -+ -+ /// Mapping function from logical coordinate to linear memory -+ using Layout = Layout_; -+ -+ /// Logical rank of tensor index space -+ static int const kRank = Layout::kRank; -+ -+ /// Index type -+ using Index = typename Layout::Index; -+ -+ /// Long index used for pointer offsets -+ using LongIndex = typename Layout::LongIndex; -+ -+ /// Coordinate in logical tensor space -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ /// Layout's stride vector -+ using Stride = typename Layout::Stride; -+ -+ /// Tensor reference to device memory -+ using TensorRef = TensorRef; -+ -+ /// Tensor reference to constant device memory -+ using ConstTensorRef = typename TensorRef::ConstTensorRef; -+ -+ /// Tensor reference to device memory -+ using TensorView = TensorView; -+ -+ /// Tensor reference to constant device memory -+ using ConstTensorView = typename TensorView::ConstTensorView; -+ -+ /// Reference to element in tensor -+ using Reference = typename TensorRef::Reference; -+ -+ /// Constant reference to element in tensor -+ using ConstReference = typename ConstTensorRef::Reference; -+ -+ /// Used to handle packing of subbyte elements -+ static int const kElementsPerStoredItem = (sizeof_bits::value < 8 ? (8 / sizeof_bits::value) : 1); -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Extent of tensor in logical dimensions -+ TensorCoord extent_; -+ -+ /// Layout object -+ Layout layout_; -+ -+ /// Host-side memory allocation -+ std::vector host_; -+ -+ /// Device-side memory -+ device_memory::allocation device_; -+ -+ public: -+ // -+ // Device and Host Methods -+ // -+ -+ /// Default constructor -+ HostTensor() {} -+ -+ /// Constructs a tensor given an extent. Assumes a packed layout -+ HostTensor( -+ TensorCoord const &extent, -+ bool device_backed = true -+ ) { -+ -+ this->reset(extent, Layout::packed(extent), device_backed); -+ } -+ -+ /// Constructs a tensor given an extent and layout -+ HostTensor( -+ TensorCoord const &extent, -+ Layout const &layout, -+ bool device_backed = true -+ ) { -+ -+ this->reset(extent, layout, device_backed); -+ } -+ -+ ~HostTensor() { } -+ -+ /// Clears the HostTensor allocation to size/capacity = 0 -+ void reset() { -+ extent_ = TensorCoord(); -+ layout_ = Layout::packed(extent_); -+ -+ host_.clear(); -+ device_.reset(); -+ } -+ -+ /// Resizes internal memory allocations without affecting layout or extent -+ void reserve( -+ size_t count, ///< size of tensor in elements -+ bool device_backed_ = true) { ///< if true, device memory is also allocated -+ -+ device_.reset(); -+ host_.clear(); -+ -+ count /= kElementsPerStoredItem; -+ -+ host_.resize(count); -+ -+ // Allocate memory -+ Element* device_memory = nullptr; -+ if (device_backed_) { -+ device_memory = device_memory::allocate(count); -+ } -+ device_.reset(device_memory, device_backed_ ? count : 0); -+ } -+ -+ /// Updates the extent and layout of the HostTensor. Allocates memory according to the new -+ /// extent and layout. -+ void reset( -+ TensorCoord const &extent, ///< extent of logical tensor -+ Layout const &layout, ///< layout object of tensor -+ bool device_backed_ = true) { ///< if true, device memory is also allocated. -+ -+ extent_ = extent; -+ layout_ = layout; -+ -+ reserve(size_t(layout_.capacity(extent_)), device_backed_); -+ } -+ -+ /// Updates the extent and layout of the HostTensor. Allocates memory according to the new -+ /// extent and layout. Assumes a packed tensor configuration. -+ void reset( -+ TensorCoord const &extent, ///< extent of logical tensor -+ bool device_backed_ = true) { ///< if true, device memory is also allocated. -+ -+ reset(extent, Layout::packed(extent), device_backed_); -+ } -+ -+ /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity. -+ /// To force allocation, call reset(). -+ void resize( -+ TensorCoord const &extent, ///< extent of logical tensor -+ Layout const &layout, ///< layout object of tensor -+ bool device_backed_ = true) { ///< if true, device memory is also allocated. -+ -+ extent_ = extent; -+ layout_ = layout; -+ -+ LongIndex new_size = size_t(layout_.capacity(extent_)); -+ -+ if (static_cast(new_size) > host_.size()) { -+ reserve(new_size); -+ } -+ } -+ -+ /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity. -+ /// To force allocation, call reset(). Note, this form of resize() assumes a packed tensor configuration. -+ void resize( -+ TensorCoord const &extent, ///< extent of logical tensor -+ bool device_backed_ = true) { ///< if true, device memory is also allocated. -+ -+ resize(extent, Layout::packed(extent), device_backed_); -+ } -+ -+ /// Returns the number of elements stored in the host tensor -+ size_t size() const { -+ return host_.size() * kElementsPerStoredItem; -+ } -+ -+ /// Returns the logical capacity based on extent and layout. May differ from size(). -+ LongIndex capacity() const { -+ return layout_.capacity(extent_); -+ } -+ -+ /// Gets pointer to host data -+ Element * host_data() { return host_.data(); } -+ -+ /// Gets pointer to host data with a pointer offset -+ Element * host_data_ptr_offset(LongIndex ptr_element_offset) { return &ReferenceFactory::get(host_.data(), ptr_element_offset); } -+ -+ /// Gets a reference to an element in host memory -+ Reference host_data(LongIndex idx) { -+ return ReferenceFactory::get(host_data(), idx); -+ } -+ -+ /// Gets pointer to host data -+ Element const * host_data() const { return host_.data(); } -+ -+ /// Gets a constant reference to an element in host memory -+ ConstReference host_data(LongIndex idx) const { -+ return ReferenceFactory::get(host_data(), idx); -+ } -+ -+ /// Gets pointer to device data -+ Element * device_data() { return device_.get(); } -+ -+ /// Gets pointer to device data with a pointer offset -+ Element * device_data_ptr_offset(LongIndex ptr_element_offset) { return &ReferenceFactory::get(device_data(), ptr_element_offset); } -+ -+ /// Gets pointer to device data -+ Element const * device_data() const { return device_.get(); } -+ -+ /// Accesses the tensor reference pointing to data -+ TensorRef host_ref(LongIndex ptr_element_offset=0) { return TensorRef(host_data_ptr_offset(ptr_element_offset), layout_); } -+ -+ /// Accesses the tensor reference pointing to data -+ ConstTensorRef host_ref(LongIndex ptr_element_offset=0) const { return ConstTensorRef(host_data_ptr_offset(ptr_element_offset), layout_); } -+ -+ /// Accesses the tensor reference pointing to data -+ TensorRef device_ref(LongIndex ptr_element_offset=0) { -+ return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_); -+ } -+ -+ /// Accesses the tensor reference pointing to data -+ ConstTensorRef device_ref(LongIndex ptr_element_offset=0) const { -+ return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_); -+ } -+ -+ /// Accesses the tensor reference pointing to data -+ TensorView host_view(LongIndex ptr_element_offset=0) { -+ return TensorView(host_data_ptr_offset(ptr_element_offset), layout_, extent_); -+ } -+ -+ /// Accesses the tensor reference pointing to data -+ ConstTensorView host_view(LongIndex ptr_element_offset=0) const { -+ return ConstTensorView(host_data_ptr_offset(ptr_element_offset), layout_, extent_); -+ } -+ -+ /// Accesses the tensor reference pointing to data -+ TensorView device_view(LongIndex ptr_element_offset=0) { -+ return TensorView(device_data_ptr_offset(ptr_element_offset), layout_, extent_); -+ } -+ -+ /// Accesses the tensor reference pointing to data -+ ConstTensorView device_view(LongIndex ptr_element_offset=0) const { -+ return ConstTensorView(device_data_ptr_offset(ptr_element_offset), layout_, extent_); -+ } -+ -+ /// Returns true if device memory is allocated -+ bool device_backed() const { -+ return (device_.get() == nullptr) ? false : true; -+ } -+ -+ -+ /// Returns the layout object -+ Layout & layout() { -+ return layout_; -+ } -+ -+ /// Returns the layout object -+ Layout layout() const { -+ return layout_; -+ } -+ -+ /// Returns the layout object's stride vector -+ Stride stride() const { -+ return layout_.stride(); -+ } -+ -+ /// Returns the layout object's stride vector -+ Stride & stride() { -+ return layout_.stride(); -+ } -+ -+ /// Returns the layout object's stride in a given physical dimension -+ LongIndex stride(int dim) const { -+ return layout_.stride().at(dim); -+ } -+ -+ /// Returns the layout object's stride in a given physical dimension -+ LongIndex & stride(int dim) { -+ return layout_.stride().at(dim); -+ } -+ -+ /// Computes the offset of an index from the origin of the tensor -+ LongIndex offset(TensorCoord const& coord) const { -+ return layout_(coord); -+ } -+ -+ /// Returns a reference to the element at the logical Coord in host memory -+ Reference at(TensorCoord const& coord) { -+ return host_data(offset(coord)); -+ } -+ -+ /// Returns a const reference to the element at the logical Coord in host memory -+ ConstReference at(TensorCoord const& coord) const { -+ return host_data(offset(coord)); -+ } -+ -+ /// Returns the extent of the tensor -+ TensorCoord extent() const { -+ return extent_; -+ } -+ -+ /// Returns the extent of the tensor -+ TensorCoord & extent() { -+ return extent_; -+ } -+ -+ /// Copies data from device to host -+ void sync_host() { -+ if (device_backed()) { -+ device_memory::copy_to_host( -+ host_data(), device_data(), size()); -+ } -+ } -+ -+ /// Copies data from host to device -+ void sync_device() { -+ if (device_backed()) { -+ device_memory::copy_to_device( -+ device_data(), host_data(), size()); -+ } -+ } -+ -+ /// Copy data from a caller-supplied device pointer into host memory. -+ void copy_in_device_to_host( -+ Element const* ptr_device, ///< source device memory -+ LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. -+ -+ if (count < 0) { -+ count = capacity(); -+ } -+ else { -+ count = __NV_STD_MIN(capacity(), count); -+ } -+ device_memory::copy_to_host( -+ host_data(), ptr_device, count); -+ } -+ -+ /// Copy data from a caller-supplied device pointer into host memory. -+ void copy_in_device_to_device( -+ Element const* ptr_device, ///< source device memory -+ LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. -+ -+ if (count < 0) { -+ count = capacity(); -+ } -+ else { -+ count = __NV_STD_MIN(capacity(), count); -+ } -+ device_memory::copy_device_to_device( -+ device_data(), ptr_device, count); -+ } -+ -+ /// Copy data from a caller-supplied device pointer into host memory. -+ void copy_in_host_to_device( -+ Element const* ptr_host, ///< source host memory -+ LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. -+ -+ if (count < 0) { -+ count = capacity(); -+ } -+ else { -+ count = __NV_STD_MIN(capacity(), count); -+ } -+ device_memory::copy_to_device( -+ device_data(), ptr_host, count); -+ } -+ -+ /// Copy data from a caller-supplied device pointer into host memory. -+ void copy_in_host_to_host( -+ Element const* ptr_host, ///< source host memory -+ LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. -+ -+ if (count < 0) { -+ count = capacity(); -+ } -+ else { -+ count = __NV_STD_MIN(capacity(), count); -+ } -+ device_memory::copy_host_to_host( -+ host_data(), ptr_host, count); -+ } -+ -+ /// Copy data from a caller-supplied device pointer into host memory. -+ void copy_out_device_to_host( -+ Element * ptr_host, ///< source device memory -+ LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. -+ -+ if (count < 0) { -+ count = capacity(); -+ } -+ else { -+ count = __NV_STD_MIN(capacity(), count); -+ } -+ device_memory::copy_to_host( -+ ptr_host, device_data(), count); -+ } -+ -+ /// Copy data from a caller-supplied device pointer into host memory. -+ void copy_out_device_to_device( -+ Element * ptr_device, ///< source device memory -+ LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. -+ -+ if (count < 0) { -+ count = capacity(); -+ } -+ else { -+ count = __NV_STD_MIN(capacity(), count); -+ } -+ device_memory::copy_device_to_device( -+ ptr_device, device_data(), count); -+ } -+ -+ /// Copy data from a caller-supplied device pointer into host memory. -+ void copy_out_host_to_device( -+ Element * ptr_device, ///< source host memory -+ LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. -+ -+ if (count < 0) { -+ count = capacity(); -+ } -+ else { -+ count = __NV_STD_MIN(capacity(), count); -+ } -+ device_memory::copy_to_device( -+ ptr_device, host_data(), count); -+ } -+ -+ /// Copy data from a caller-supplied device pointer into host memory. -+ void copy_out_host_to_host( -+ Element * ptr_host, ///< source host memory -+ LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. -+ -+ if (count < 0) { -+ count = capacity(); -+ } -+ else { -+ count = __NV_STD_MIN(capacity(), count); -+ } -+ device_memory::copy_host_to_host( -+ ptr_host, host_data(), count); -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/host_tensor_planar_complex.h b/3rdparty/cutlass/tools/util/include/cutlass/util/host_tensor_planar_complex.h -new file mode 100644 -index 0000000..c548d9c ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/host_tensor_planar_complex.h -@@ -0,0 +1,591 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+/*! \file -+ \brief HostTensor contributes management for both host and device memory. -+ -+ HostTensor allocates host and device memory upon construction. Basic element-wise operations on -+ host memory synchronize device memory automatically. Explicit copy operations provide abstractions -+ for CUDA memcpy operations. -+ -+ Call {host, device}_{data, ref, view}() for accessing host or device memory. -+ -+ See cutlass/tensor_ref.h and cutlass/tensor_view.h for more details. -+*/ -+ -+#include -+ -+#include "cutlass/cutlass.h" -+ -+#include "cutlass/tensor_ref_planar_complex.h" -+#include "cutlass/tensor_view_planar_complex.h" -+ -+#include "device_memory.h" -+ -+namespace cutlass { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Host tensor -+template < -+ /// Data type of element stored within tensor (concept: NumericType) -+ typename Element_, -+ /// Defines a mapping from logical coordinate to linear memory (concept: Layout) -+ typename Layout_ -+> -+class HostTensorPlanarComplex { -+public: -+ -+ /// Data type of individual access -+ using Element = Element_; -+ -+ /// Mapping function from logical coordinate to linear memory -+ using Layout = Layout_; -+ -+ /// Logical rank of tensor index space -+ static int const kRank = Layout::kRank; -+ -+ /// Index type -+ using Index = typename Layout::Index; -+ -+ /// Long index used for pointer offsets -+ using LongIndex = typename Layout::LongIndex; -+ -+ /// Coordinate in logical tensor space -+ using TensorCoord = typename Layout::TensorCoord; -+ -+ /// Layout's stride vector -+ using Stride = typename Layout::Stride; -+ -+ /// Tensor reference to device memory -+ using TensorRef = TensorRefPlanarComplex; -+ -+ /// Tensor reference to constant device memory -+ using ConstTensorRef = typename TensorRef::ConstTensorRef; -+ -+ /// Tensor reference to device memory -+ using TensorView = TensorViewPlanarComplex; -+ -+ /// Tensor reference to constant device memory -+ using ConstTensorView = typename TensorView::ConstTensorView; -+ -+ /// Reference to element in tensor -+ using Reference = typename TensorRef::Reference; -+ -+ /// Constant reference to element in tensor -+ using ConstReference = typename ConstTensorRef::Reference; -+ -+ private: -+ -+ // -+ // Data members -+ // -+ -+ /// Extent of tensor in logical dimensions -+ TensorCoord extent_; -+ -+ /// Layout object -+ Layout layout_; -+ -+ /// Host-side memory allocation -+ std::vector host_; -+ -+ /// Device-side memory -+ device_memory::allocation device_; -+ -+ public: -+ // -+ // Device and Host Methods -+ // -+ -+ /// Default constructor -+ HostTensorPlanarComplex() {} -+ -+ /// Constructs a tensor given an extent. Assumes a packed layout -+ HostTensorPlanarComplex( -+ TensorCoord const &extent, -+ bool device_backed = true -+ ) { -+ -+ this->reset(extent, Layout::packed(extent), device_backed); -+ } -+ -+ /// Constructs a tensor given an extent and layout -+ HostTensorPlanarComplex( -+ TensorCoord const &extent, -+ Layout const &layout, -+ bool device_backed = true -+ ) { -+ -+ this->reset(extent, layout, device_backed); -+ } -+ -+ ~HostTensorPlanarComplex() { } -+ -+ /// Clears the HostTensor allocation to size/capacity = 0 -+ void reset() { -+ extent_ = TensorCoord(); -+ layout_ = Layout::packed(extent_); -+ -+ host_.clear(); -+ device_.reset(); -+ } -+ -+ /// Resizes internal memory allocations without affecting layout or extent -+ void reserve( -+ size_t count, ///< size of tensor in elements -+ bool device_backed_ = true) { ///< if true, device memory is also allocated -+ -+ device_.reset(); -+ host_.clear(); -+ -+ host_.resize(count * 2); -+ -+ // Allocate memory -+ Element* device_memory = nullptr; -+ if (device_backed_) { -+ device_memory = device_memory::allocate(count * 2); -+ } -+ device_.reset(device_memory, device_backed_ ? count * 2 : 0); -+ } -+ -+ /// Updates the extent and layout of the HostTensor. Allocates memory according to the new -+ /// extent and layout. -+ void reset( -+ TensorCoord const &extent, ///< extent of logical tensor -+ Layout const &layout, ///< layout object of tensor -+ bool device_backed_ = true) { ///< if true, device memory is also allocated. -+ -+ extent_ = extent; -+ layout_ = layout; -+ -+ reserve(size_t(layout_.capacity(extent_)), device_backed_); -+ } -+ -+ /// Updates the extent and layout of the HostTensor. Allocates memory according to the new -+ /// extent and layout. Assumes a packed tensor configuration. -+ void reset( -+ TensorCoord const &extent, ///< extent of logical tensor -+ bool device_backed_ = true) { ///< if true, device memory is also allocated. -+ -+ reset(extent, Layout::packed(extent), device_backed_); -+ } -+ -+ /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity. -+ /// To force allocation, call reset(). -+ void resize( -+ TensorCoord const &extent, ///< extent of logical tensor -+ Layout const &layout, ///< layout object of tensor -+ bool device_backed_ = true) { ///< if true, device memory is also allocated. -+ -+ extent_ = extent; -+ layout_ = layout; -+ -+ LongIndex new_size = size_t(layout_.capacity(extent_)); -+ -+ if (static_cast(new_size * 2) > host_.size()) { -+ reserve(new_size); -+ } -+ } -+ -+ /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity. -+ /// To force allocation, call reset(). Note, this form of resize() assumes a packed tensor configuration. -+ void resize( -+ TensorCoord const &extent, ///< extent of logical tensor -+ bool device_backed_ = true) { ///< if true, device memory is also allocated. -+ -+ resize(extent, Layout::packed(extent), device_backed_); -+ } -+ -+ /// Returns the number of elements stored in the host tensor -+ size_t size() const { -+ return host_.size() / 2; -+ } -+ -+ /// Returns the logical capacity based on extent and layout. May differ from size(). -+ LongIndex capacity() const { -+ return layout_.capacity(extent_); -+ } -+ -+ /// Stride between real and imaginary parts -+ LongIndex imaginary_stride() const { -+ return host_.size() / 2; -+ } -+ -+ /// Gets pointer to host data -+ Element * host_data() { return host_.data(); } -+ -+ /// Gets pointer to host data imaginary part -+ Element * host_data_imag() { return host_.data() + imaginary_stride(); } -+ -+ /// Gets pointer to host data with a pointer offset -+ Element * host_data_ptr_offset(LongIndex ptr_element_offset) { return host_data() + ptr_element_offset; } -+ -+ /// Gets pointer to host data with a pointer offset -+ Element * host_data_imag_ptr_offset(LongIndex ptr_element_offset) { return host_data_imag() + ptr_element_offset; } -+ -+ /// Gets a reference to an element in host memory -+ Reference host_data(LongIndex idx) { -+ return PlanarComplexReference(host_data() + idx, host_data_imag() + idx); -+ } -+ -+ /// Gets pointer to host data -+ Element const * host_data() const { return host_.data(); } -+ -+ /// Gets pointer to host data imaginary part -+ Element const * host_data_imag() const { return host_.data() + imaginary_stride(); } -+ -+ /// Gets a constant reference to an element in host memory -+ ConstReference host_data(LongIndex idx) const { -+ return PlanarComplexReference(host_data() + idx, host_data_imag() + idx); -+ } -+ -+ /// Gets pointer to device data -+ Element * device_data() { return device_.get(); } -+ -+ /// Gets pointer to device data with a pointer offset -+ Element * device_data_ptr_offset(LongIndex ptr_element_offset) { return device_.get() + ptr_element_offset; } -+ -+ /// Gets pointer to device data -+ Element const * device_data() const { return device_.get(); } -+ -+ /// Gets pointer to device data with a pointer offset -+ Element const * device_data_ptr_offset(LongIndex ptr_element_offset) const { return device_.get() + ptr_element_offset; } -+ -+ /// Gets a pointer to the device data imaginary part -+ Element * device_data_imag() { return device_.get() + imaginary_stride(); } -+ -+ /// Accesses the tensor reference pointing to data -+ TensorRef host_ref(LongIndex ptr_element_offset=0) { -+ return TensorRef(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride()); -+ } -+ -+ /// Returns a tensor reference to the real part of the tensor -+ cutlass::TensorRef host_ref_real() { -+ return cutlass::TensorRef(host_data(), layout_); -+ } -+ -+ /// Returns a tensor reference to the real part of the tensor -+ cutlass::TensorRef host_ref_imag() { -+ return cutlass::TensorRef(host_data_ptr_offset(imaginary_stride()), layout_); -+ } -+ -+ /// Accesses the tensor reference pointing to data -+ ConstTensorRef host_ref(LongIndex ptr_element_offset=0) const { -+ return ConstTensorRef(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride()); -+ } -+ -+ /// Accesses the tensor reference pointing to data -+ TensorRef device_ref(LongIndex ptr_element_offset=0) { -+ return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride()); -+ } -+ -+ /// Accesses the tensor reference pointing to data -+ ConstTensorRef device_ref(LongIndex ptr_element_offset=0) const { -+ return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride()); -+ } -+ -+ /// Returns a tensor reference to the real part of the tensor -+ cutlass::TensorRef device_ref_real() { -+ return cutlass::TensorRef(device_data(), layout_); -+ } -+ -+ /// Returns a tensor reference to the real part of the tensor -+ cutlass::TensorRef device_ref_imag() { -+ return cutlass::TensorRef(device_data_ptr_offset(imaginary_stride()), layout_); -+ } -+ -+ /// Accesses the tensor reference pointing to data -+ TensorView host_view(LongIndex ptr_element_offset=0) { -+ return TensorView(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_); -+ } -+ -+ /// Accesses the tensor reference pointing to data -+ ConstTensorView host_view(LongIndex ptr_element_offset=0) const { -+ return ConstTensorView(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_); -+ } -+ -+ /// Accesses the tensor reference pointing to data -+ cutlass::TensorView host_view_real() { -+ return cutlass::TensorView(host_data(), layout_, extent_); -+ } -+ -+ /// Accesses the tensor reference pointing to data -+ cutlass::TensorView host_view_imag() { -+ return cutlass::TensorView(host_data_ptr_offset(imaginary_stride()), layout_, extent_); -+ } -+ -+ /// Accesses the tensor reference pointing to data -+ TensorView device_view(LongIndex ptr_element_offset=0) { -+ return TensorView(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_); -+ } -+ -+ /// Accesses the tensor reference pointing to data -+ ConstTensorView device_view(LongIndex ptr_element_offset=0) const { -+ return ConstTensorView(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_); -+ } -+ -+ /// Accesses the tensor reference pointing to data -+ cutlass::TensorView device_view_real() { -+ return cutlass::TensorView(device_data(), layout_, extent_); -+ } -+ -+ /// Accesses the tensor reference pointing to data -+ cutlass::TensorView device_view_imag() { -+ return cutlass::TensorView(device_data_ptr_offset(imaginary_stride()), layout_, extent_); -+ } -+ -+ /// Returns true if device memory is allocated -+ bool device_backed() const { -+ return (device_.get() == nullptr) ? false : true; -+ } -+ -+ /// Returns the layout object -+ Layout layout() const { -+ return layout_; -+ } -+ -+ /// Returns the layout object's stride vector -+ Stride stride() const { -+ return layout_.stride(); -+ } -+ -+ /// Returns the layout object's stride in a given physical dimension -+ Index stride(int dim) const { -+ return layout_.stride().at(dim); -+ } -+ -+ /// Computes the offset of an index from the origin of the tensor -+ LongIndex offset(TensorCoord const& coord) const { -+ return layout_(coord); -+ } -+ -+ /// Returns a reference to the element at the logical Coord in host memory -+ Reference at(TensorCoord const& coord) { -+ return host_data(offset(coord)); -+ } -+ -+ /// Returns a const reference to the element at the logical Coord in host memory -+ ConstReference at(TensorCoord const& coord) const { -+ return host_data(offset(coord)); -+ } -+ -+ /// Returns the extent of the tensor -+ TensorCoord extent() const { -+ return extent_; -+ } -+ -+ /// Returns the extent of the tensor -+ TensorCoord & extent() { -+ return extent_; -+ } -+ -+ /// Copies data from device to host -+ void sync_host() { -+ if (device_backed()) { -+ device_memory::copy_to_host( -+ host_data(), device_data(), imaginary_stride() * 2); -+ } -+ } -+ -+ /// Copies data from host to device -+ void sync_device() { -+ if (device_backed()) { -+ device_memory::copy_to_device( -+ device_data(), host_data(), imaginary_stride() * 2); -+ } -+ } -+ -+ /// Copy data from a caller-supplied device pointer into host memory. -+ void copy_in_device_to_host( -+ Element const* ptr_device_real, ///< source device memory -+ Element const* ptr_device_imag, ///< source device memory -+ LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. -+ -+ if (count < 0) { -+ count = capacity(); -+ } -+ else { -+ count = __NV_STD_MIN(capacity(), count); -+ } -+ -+ device_memory::copy_to_host( -+ host_data(), ptr_device_real, count); -+ -+ device_memory::copy_to_host( -+ host_data_imag(), ptr_device_imag, count); -+ } -+ -+ /// Copy data from a caller-supplied device pointer into host memory. -+ void copy_in_device_to_device( -+ Element const* ptr_device_real, ///< source device memory -+ Element const* ptr_device_imag, ///< source device memory -+ LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. -+ -+ if (count < 0) { -+ count = capacity(); -+ } -+ else { -+ count = __NV_STD_MIN(capacity(), count); -+ } -+ -+ device_memory::copy_device_to_device( -+ device_data(), ptr_device_real, count); -+ -+ device_memory::copy_device_to_device( -+ device_data_imag(), ptr_device_imag, count); -+ } -+ -+ /// Copy data from a caller-supplied device pointer into host memory. -+ void copy_in_host_to_device( -+ Element const* ptr_host_real, ///< source host memory -+ Element const* ptr_host_imag, ///< source host memory -+ LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. -+ -+ if (count < 0) { -+ count = capacity(); -+ } -+ else { -+ count = __NV_STD_MIN(capacity(), count); -+ } -+ -+ device_memory::copy_to_device( -+ device_data(), ptr_host_real, count); -+ -+ device_memory::copy_to_device( -+ device_data_imag(), ptr_host_imag, count); -+ } -+ -+ /// Copy data from a caller-supplied device pointer into host memory. -+ void copy_in_host_to_host( -+ Element const* ptr_host_real, ///< source host memory -+ Element const* ptr_host_imag, ///< source host memory -+ LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. -+ -+ if (count < 0) { -+ count = capacity(); -+ } -+ else { -+ count = __NV_STD_MIN(capacity(), count); -+ } -+ -+ device_memory::copy_host_to_host( -+ host_data(), ptr_host_real, count); -+ -+ device_memory::copy_host_to_host( -+ host_data_imag(), ptr_host_imag, count); -+ } -+ -+ /// Copy data from a caller-supplied device pointer into host memory. -+ void copy_out_device_to_host( -+ Element * ptr_host_real, ///< source device memory -+ Element * ptr_host_imag, ///< source device memory -+ LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. -+ -+ if (count < 0) { -+ count = capacity(); -+ } -+ else { -+ count = __NV_STD_MIN(capacity(), count); -+ } -+ -+ device_memory::copy_to_host( -+ ptr_host_real, device_data(), count); -+ -+ device_memory::copy_to_host( -+ ptr_host_imag, device_data_imag(), count); -+ } -+ -+ /// Copy data from a caller-supplied device pointer into host memory. -+ void copy_out_device_to_device( -+ Element * ptr_device_real, ///< source device memory -+ Element * ptr_device_imag, ///< source device memory -+ LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. -+ -+ if (count < 0) { -+ count = capacity(); -+ } -+ else { -+ count = __NV_STD_MIN(capacity(), count); -+ } -+ -+ device_memory::copy_device_to_device( -+ ptr_device_real, device_data(), count); -+ -+ device_memory::copy_device_to_device( -+ ptr_device_imag, device_data_imag(), count); -+ } -+ -+ /// Copy data from a caller-supplied device pointer into host memory. -+ void copy_out_host_to_device( -+ Element * ptr_device_real, ///< source device memory -+ Element * ptr_device_imag, ///< source device memory -+ LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. -+ -+ if (count < 0) { -+ count = capacity(); -+ } -+ else { -+ count = __NV_STD_MIN(capacity(), count); -+ } -+ -+ device_memory::copy_to_device( -+ ptr_device_real, host_data(), count); -+ -+ device_memory::copy_to_device( -+ ptr_device_imag, host_data_imag(), count); -+ } -+ -+ /// Copy data from a caller-supplied device pointer into host memory. -+ void copy_out_host_to_host( -+ Element * ptr_host_real, ///< source host memory -+ Element * ptr_host_imag, ///< source host memory -+ LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. -+ -+ if (count < 0) { -+ count = capacity(); -+ } -+ else { -+ count = __NV_STD_MIN(capacity(), count); -+ } -+ -+ device_memory::copy_host_to_host( -+ ptr_host_real, host_data(), count); -+ -+ device_memory::copy_host_to_host( -+ ptr_host_imag, host_data_imag(), count); -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/host_uncompress.h b/3rdparty/cutlass/tools/util/include/cutlass/util/host_uncompress.h -new file mode 100644 -index 0000000..7028bf7 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/host_uncompress.h -@@ -0,0 +1,157 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief uncompress sparse matrix from the host side -+*/ -+#pragma once -+ -+#include "cutlass/coord.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/util/tensor_view_io.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+namespace cutlass { -+ -+// uncompress sparse tensor core A matrix -+template -+void uncompress(TensorRef uncompressed_tensor_a, -+ TensorRef tensor_a, -+ TensorRef tensor_e, int row, int col) { -+ // How many uncompressed data we can get with ElementE meta data -+ int DecompressedElementsPerElementE = -+ 256 / cutlass::sizeof_bits::value; -+ -+ // Process 4bit meta data a time -+ int step; -+ -+ // 1:2 or 2:4 or 4:8 -+ int a, b; -+ -+ if (cutlass::sizeof_bits::value == 4) { -+ step = 8; -+ a = 4; -+ b = 8; -+ } else if (cutlass::sizeof_bits::value == 8) { -+ step = 4; -+ a = 2; -+ b = 4; -+ } else if (cutlass::sizeof_bits::value == 16) { -+ step = 4; -+ a = 2; -+ b = 4; -+ } else if (cutlass::sizeof_bits::value == 32) { -+ step = 2; -+ a = 1; -+ b = 2; -+ } -+ -+ int ElementsPerE = (cutlass::sizeof_bits::value == 4) ? 2 : 1; -+ -+ for (int r = 0; r < row; ++r) { -+ for (int c = 0; c < (col / DecompressedElementsPerElementE); ++c) { -+ -+ ElementE meta = tensor_e.at(MatrixCoord(r, c)); -+ -+ for (int i = 0; i < DecompressedElementsPerElementE; i += step) { -+ int e = (meta >> (i / step * 4)) & 0xf; -+ int idx0 = e & 0x3; -+ int idx1 = e >> 2; -+ -+ if (a == 1) idx0 = idx0 / 2; -+ -+ for (int ii = 0; ii < step; ii += ElementsPerE) { -+ int real_col = -+ c * DecompressedElementsPerElementE + i + ii; -+ int compressed_col = (real_col / b) * a; -+ -+ if (ii == (idx0 * ElementsPerE)) { -+ uncompressed_tensor_a.at(MatrixCoord(r, real_col)) = -+ tensor_a.at(MatrixCoord(r, compressed_col)); -+ if (ElementsPerE == 2) -+ uncompressed_tensor_a.at(MatrixCoord(r, real_col + 1)) = -+ tensor_a.at(MatrixCoord(r, compressed_col + 1)); -+ } else if ((ii == (idx1 * ElementsPerE)) && (a != 1)) { -+ uncompressed_tensor_a.at(MatrixCoord(r, real_col)) = -+ tensor_a.at(MatrixCoord(r, compressed_col + ElementsPerE)); -+ if (ElementsPerE == 2) -+ uncompressed_tensor_a.at(MatrixCoord(r, real_col + 1)) = -+ tensor_a.at( -+ MatrixCoord(r, compressed_col + ElementsPerE + 1)); -+ } else { -+ uncompressed_tensor_a.at(MatrixCoord(r, real_col)) = -+ ElementA(0); -+ if (ElementsPerE == 2) -+ uncompressed_tensor_a.at(MatrixCoord(r, real_col + 1)) = -+ ElementA(0); -+ } -+ } -+ } -+ } -+ } -+} -+ -+// uncompress ELL block sparse matrix -+template -+void uncompress_ell_block_sparse( -+ TensorRef uncompressed_tensor_a, -+ TensorRef tensor_a, -+ TensorRef ell_idx, -+ int rows, int cols, -+ int ell_num_cols, int ell_blocksize) { -+ -+ for (int r = 0; r < rows / ell_blocksize; ++r) { -+ for (int c = 0; c < ell_num_cols / ell_blocksize; ++c) { -+ -+ ElementE idx = ell_idx.at(MatrixCoord(r, c)); -+ -+ if (idx != -1) { -+ int row_begin = r * ell_blocksize; -+ int col_begin_real = idx * ell_blocksize; -+ int col_begin = c * ell_blocksize; -+ -+ for (int i = 0; i < ell_blocksize; ++i) { -+ for (int j = 0; j < ell_blocksize; ++j) { -+ uncompressed_tensor_a.at(MatrixCoord(row_begin + i, col_begin_real + j)) = -+ tensor_a.at( -+ MatrixCoord(row_begin + i, col_begin +j)); -+ } -+ } -+ } -+ } -+ } -+} -+ -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/index_sequence.h b/3rdparty/cutlass/tools/util/include/cutlass/util/index_sequence.h -new file mode 100644 -index 0000000..846e02c ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/index_sequence.h -@@ -0,0 +1,38 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/numeric_types.h" -+ -+// integer_sequence moved to cutlass/numeric_types.h -+ -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/packed_stride.hpp b/3rdparty/cutlass/tools/util/include/cutlass/util/packed_stride.hpp -new file mode 100644 -index 0000000..7ecffaf ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/packed_stride.hpp -@@ -0,0 +1,101 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Utilities for packing a rank-X shape into a rank-(X-1) stride in CuTe. -+*/ -+ -+#pragma once -+ -+#include "cute/stride.hpp" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Strides without batch mode -+ -+template -+cute::Stride> -+make_cute_packed_stride(cute::Stride> s, cute::Shape shape_MKL) { -+ static_assert(std::is_integral_v, -+ "Stride must have an integral type so it can be set dynamically. Static strides not supported."); -+ auto s_copy = s; -+ cute::get<0>(s_copy) = static_cast(cute::get<1>(shape_MKL)); -+ return s_copy; -+} -+ -+template -+cute::Stride, StrideIntT> -+make_cute_packed_stride(cute::Stride, StrideIntT> s, cute::Shape shape_MKL) { -+ static_assert(std::is_integral_v, -+ "Stride must have an integral type so it can be set dynamically. Static strides not supported."); -+ auto s_copy = s; -+ cute::get<1>(s_copy) = static_cast(cute::get<0>(shape_MKL)); -+ return s_copy; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Strides with batch mode -+ -+template -+cute::Stride, int64_t> -+make_cute_packed_stride(cute::Stride, int64_t> s, cute::Shape shape_MKL) { -+ static_assert(std::is_integral_v, -+ "Stride must have an integral type so it can be set dynamically. Static strides not supported."); -+ auto s_copy = s; -+ cute::get<0>(s_copy) = static_cast(cute::get<1>(shape_MKL)); -+ int batch_count = cute::get<2>(shape_MKL); -+ if (batch_count > 1) { -+ cute::get<2>(s_copy) = static_cast(cute::get<0>(shape_MKL) * cute::get<1>(shape_MKL)); -+ } -+ else { -+ cute::get<2>(s_copy) = static_cast(0); -+ } -+ return s_copy; -+} -+ -+template -+cute::Stride, StrideIntT, int64_t> -+make_cute_packed_stride(cute::Stride, StrideIntT, int64_t> s, cute::Shape shape_MKL) { -+ static_assert(std::is_integral_v, -+ "Stride must have an integral type so it can be set dynamically. Static strides not supported."); -+ auto s_copy = s; -+ cute::get<1>(s_copy) = static_cast(cute::get<0>(shape_MKL)); -+ int batch_count = cute::get<2>(shape_MKL); -+ if (batch_count > 1) { -+ cute::get<2>(s_copy) = static_cast(cute::get<0>(shape_MKL) * cute::get<1>(shape_MKL)); -+ } -+ else { -+ cute::get<2>(s_copy) = static_cast(0); -+ } -+ return s_copy; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/print_error.hpp b/3rdparty/cutlass/tools/util/include/cutlass/util/print_error.hpp -new file mode 100644 -index 0000000..f867f88 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/print_error.hpp -@@ -0,0 +1,235 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include -+#include -+#include -+#include -+#include -+ -+#include -+#include -+ -+#include -+#include -+ -+#include -+ -+// The computed infinity norm does not include -+// any NaN column absolute-value sums. -+struct matrix_inf_norm_result { -+ // Accumulate errors in double, as this is generally -+ // the highest precision that the examples use. -+ double inf_norm = 0.0; -+ bool found_nan = false; -+}; -+ -+// In theory, cute::Tensor, T> could be treated as a view type, -+// and thus passed by value (as std::span or std::string_view would be). -+// However, generic cute::Tensor are more like containers -+// and thus are best passed by reference or const reference. -+template -+matrix_inf_norm_result -+matrix_inf_norm(const cute::Tensor& host_matrix) -+{ -+ using std::abs; -+ using error_type = decltype(std::declval().inf_norm); -+ -+ error_type inf_norm = 0.0; -+ bool found_nan = false; -+ -+ const auto shape = host_matrix.shape(); -+ using index_type = std::decay_t(shape))>; -+ // Computing the infinity norm requires that we be able -+ // to treat the input as a matrix, with rows and columns. -+ static_assert(std::is_integral_v); -+ const index_type num_rows = cute::get<0>(shape); -+ const index_type num_cols = cute::get<1>(shape); -+ -+ for(index_type i = 0; i < num_rows; ++i) { -+ error_type row_abs_sum = 0.0; -+ for(index_type j = 0; j < num_cols; ++j) { -+ row_abs_sum += abs(host_matrix(i, j)); -+ } -+ if(std::isnan(row_abs_sum)) { -+ found_nan = true; -+ } else { -+ inf_norm = row_abs_sum > inf_norm ? row_abs_sum : inf_norm; -+ } -+ } -+ -+ return {inf_norm, found_nan}; -+} -+ -+// Infinity norm of (X - Y). -+template -+matrix_inf_norm_result -+matrix_diff_inf_norm(const cute::Tensor& X, -+ const cute::Tensor& Y) -+{ -+ using std::abs; -+ using error_type = decltype(std::declval().inf_norm); -+ -+ const auto X_shape = X.shape(); -+ const auto Y_shape = Y.shape(); -+ -+ using index_type = std::decay_t(X_shape))>; -+ // Computing the infinity norm requires that we be able -+ // to treat the input as a matrix, with rows and columns. -+ static_assert(std::is_integral_v); -+ const index_type num_rows = cute::get<0>(X_shape); -+ const index_type num_cols = cute::get<1>(X_shape); -+ -+ assert(num_rows == cute::get<0>(Y_shape)); -+ assert(num_cols == cute::get<1>(Y_shape)); -+ -+ auto matrix_ij = [&](const auto& A, std::size_t i, std::size_t j) { -+ return A(i, j); -+ }; -+ auto diff_ij = [&](std::size_t i, std::size_t j) { -+ return matrix_ij(X, i, j) - matrix_ij(Y, i, j); -+ }; -+ -+ error_type inf_norm = 0.0; -+ bool found_nan = false; -+ -+ for(index_type i = 0; i < num_rows; ++i) { -+ error_type row_abs_sum = 0.0; -+ for(index_type j = 0; j < num_cols; ++j) { -+ row_abs_sum += abs(diff_ij(i, j)); -+ } -+ if(std::isnan(row_abs_sum)) { -+ found_nan = true; -+ } else { -+ inf_norm = row_abs_sum > inf_norm ? row_abs_sum : inf_norm; -+ } -+ } -+ -+ return {inf_norm, found_nan}; -+} -+ -+template -+void -+print_matrix_multiply_mollified_relative_error( -+ const char A_value_type_name[], -+ const cute::Tensor& A, -+ const char B_value_type_name[], -+ const cute::Tensor& B, -+ const char C_value_type_name[], -+ const cute::Tensor& C_computed, -+ const cute::Tensor& C_expected) -+{ -+ const auto [A_norm, A_has_nan] = matrix_inf_norm(A); -+ const auto [B_norm, B_has_nan] = matrix_inf_norm(B); -+ const auto [C_norm, C_has_nan] = matrix_inf_norm(C_expected); -+ const auto [diff_norm, diff_has_nan] = matrix_diff_inf_norm(C_computed, C_expected); -+ -+ const auto A_norm_times_B_norm = A_norm * B_norm; -+ const auto relative_error = A_norm_times_B_norm == 0.0 ? -+ diff_norm : (diff_norm / A_norm_times_B_norm); -+ -+ // For expected error bounds, please refer to the LAPACK Users' Guide, -+ // in particular https://netlib.org/lapack/lug/node108.html . -+ // Printing the infinity norm of C is a way to check -+ // that both the function being tested (C_computed) -+ // and the reference implementation (C_expected) -+ // don't just do nothing (or fill with zeros). -+ using std::cout; -+ cout << "Value type of A: " << A_value_type_name << '\n' -+ << std::scientific -+ << "Infinity norm of A: " << A_norm << '\n' -+ << "Value type of B: " << B_value_type_name << '\n' -+ << "Infinity norm of B: " << B_norm << '\n' -+ << "Value type of C: " << C_value_type_name << '\n' -+ << "Infinity norm of C_expected: " << C_norm << '\n' -+ << "Infinity norm of (C_computed - C_expected): " << diff_norm << '\n'; -+ -+ if(A_norm_times_B_norm == 0.0) { -+ cout << "Mollified relative error: " << relative_error << '\n'; -+ } else { -+ cout << "Relative error: " << relative_error << '\n'; -+ } -+ -+ cout << "Did we encounter NaN in A? " << (A_has_nan ? "yes" : "no") << '\n' -+ << "Did we encounter NaN in B? " << (B_has_nan ? "yes" : "no") << '\n' -+ << "Did we encounter NaN in C_expected? " << (C_has_nan ? "yes" : "no") << '\n' -+ << "Did we encounter NaN in (C_computed - C_expected)? " -+ << (diff_has_nan ? "yes" : "no") << '\n'; -+} -+ -+template -+void -+print_matrix_multiply_mollified_relative_error( -+ const char value_type_name[], -+ const cute::Tensor& A, -+ const cute::Tensor& B, -+ const cute::Tensor& C_computed, -+ const cute::Tensor& C_expected) -+{ -+ print_matrix_multiply_mollified_relative_error(value_type_name, A, value_type_name, B, -+ value_type_name, C_computed, C_expected); -+} -+ -+// Take a CUTLASS HostTensor (or the like) as input, -+// and return a const CuTe Tensor. -+// This is useful for use with the above error printing functions. -+// This implicitly "transposes" if the layout is RowMajor. -+// Note that the HostTensor must be captured by nonconst reference -+// in order for X.host_ref().data() to compile. -+// (CUTLASS is a bit more container-y than CuTe.) -+template -+auto host_matrix_to_const_cute_tensor(CutlassHostTensorType& X) -+{ -+ // The tensors were created with post-transposed extents. -+ const auto extents = X.extent(); -+ const auto shape = cute::Shape{extents[0], extents[1]}; -+ // Both RowMajor and ColumnMajor only store one stride. -+ const int LDX = X.stride(0); -+ const auto strides = [&]() { -+ using input_layout_type = typename std::decay_t::Layout; -+ if constexpr (std::is_same_v) { -+ return cute::Stride{1, LDX}; -+ } -+ else { -+ static_assert(std::is_same_v); -+ return cute::Stride{LDX, 1}; -+ } -+ }(); -+ const auto layout = cute::make_layout(shape, strides); -+ auto X_data = X.host_ref().data(); -+ auto X_data_const = const_cast >(X_data); -+ return cute::make_tensor(X_data_const, layout); -+}; -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/detail/inner_product.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/detail/inner_product.h -new file mode 100644 -index 0000000..b4bffa3 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/detail/inner_product.h -@@ -0,0 +1,135 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Reference implementation for GEMM in host-side code. -+*/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+ -+namespace cutlass { -+namespace reference { -+namespace detail { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Template function to compute an inner product. -+#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate with a -+ // host-only type -+template -+CUTLASS_HOST_DEVICE -+Ctype inner_product(Atype a, Btype b, Ctype c) { -+ return Ctype(a) * Ctype(b) + c; -+} -+ -+/// Specialization for matrix multiplication with binary operands -+template <> -+CUTLASS_HOST_DEVICE -+int inner_product, Array, int>( -+ Array a, -+ Array b, -+ int c) { -+ -+ int accum = 0; -+ for (int bit = 0; bit < 32; bit++) { -+ accum += a[bit] ^ b[bit]; -+ } -+ return accum + c; -+} -+ -+/* -+/// Specialization for matrix multiplication with signed 4-bit integer operands -+template <> -+CUTLASS_HOST_DEVICE -+int inner_product, Array, int>( -+ Array a, -+ Array b, -+ int c) { -+ -+ int accum = 0; -+ for (int k = 0; k < 8; k++) { -+ accum += a[k] * b[k]; -+ } -+ return accum + c; -+} -+ -+/// Specialization for matrix multiplication with unsigned 4-bit integer operands -+template <> -+CUTLASS_HOST_DEVICE -+int inner_product, Array, int>( -+ Array a, -+ Array b, -+ int c) { -+ -+ int accum = 0; -+ for (int k = 0; k < 8; k++) { -+ accum += a[k] * b[k]; -+ } -+ return accum + c; -+} -+*/ -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct Cast { -+ // Default behavior: convert to the destination type -+#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a -+ // host-only type -+ CUTLASS_HOST_DEVICE -+ static DstType apply(SrcType src) { return static_cast(src); }; -+}; -+ -+template <> -+struct Cast { -+ CUTLASS_HOST_DEVICE -+ static int8_t apply(float src) { -+ // Clamp to the range of signed 8-bit integers. -+ return static_cast(fmaxf(-128.f, fminf(127.f, src))); -+ }; -+}; -+ -+template <> -+struct Cast { -+ CUTLASS_HOST_DEVICE -+ static uint8_t apply(float src) { -+ // Clamp to the range of signed 8-bit integers. -+ return static_cast(fmaxf(0.f, fminf(255.f, src))); -+ }; -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace detail -+} // namespace reference -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/detail/linear_to_coordinate.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/detail/linear_to_coordinate.h -new file mode 100644 -index 0000000..ac22699 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/detail/linear_to_coordinate.h -@@ -0,0 +1,94 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Reference implementation for GEMM in host-side code. -+*/ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/coord.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace reference { -+namespace detail { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct LinearToCoordinateHelper { -+ -+ CUTLASS_HOST_DEVICE -+ void operator()(Coord &coord, int64_t idx, Coord const &extent) const { -+ -+ int64_t prod = 1; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = Rank - Index; i < Rank; ++i) { -+ prod *= int64_t(extent[i]); -+ } -+ -+ coord[Rank - Index - 1] = int(idx / prod); -+ -+ int64_t residual = idx % prod; -+ LinearToCoordinateHelper()(coord, residual, extent); -+ } -+}; -+ -+template -+struct LinearToCoordinateHelper { -+ -+ CUTLASS_HOST_DEVICE -+ void operator()(Coord &coord, int64_t idx, Coord const &extent) const { -+ coord[Rank - 1] = int(idx); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct LinearToCoordinate { -+ -+ CUTLASS_HOST_DEVICE -+ void operator()(Coord &coord, int64_t idx, Coord const &extent) const { -+ LinearToCoordinateHelper()(coord, idx, extent); -+ } -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace detail -+} // namespace reference -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/convolution.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/convolution.h -new file mode 100644 -index 0000000..fec0587 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/convolution.h -@@ -0,0 +1,1549 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Reference implementation for convolution in device-side code. -+*/ -+ -+#pragma once -+ -+#include "cutlass/coord.h" -+#include "cutlass/functional.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/matrix_shape.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+ -+namespace cutlass { -+namespace reference { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+/// Conv2d device reference kernel -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Conv2d Fprop kernel - y = fprop(x, w) -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add, -+ int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension -+ int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension -+ int kCtaShapeM = 16, // shape of a threadblock in units of threads -+ int kCtaShapeN = 8 // shape of a threadblock in units of threads -+> -+__global__ void Conv2dFprop( -+ conv::Conv2dProblemSize problem_size, -+ TensorRef tensor_x, -+ TensorRef tensor_w, -+ TensorRef tensor_y_in, -+ TensorRef tensor_y_out, -+ ElementCompute alpha, -+ ElementCompute beta -+ ) { -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ -+ ElementAccumulator element_A[kThreadM]; -+ ElementAccumulator element_B[kThreadN]; -+ ElementAccumulator accum[kThreadM][kThreadN]; -+ -+ int64_t npq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; -+ int k_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; -+ -+ int thread_n[kThreadM]; -+ int thread_p[kThreadM]; -+ int thread_q[kThreadM]; -+ -+ // Compute N, P, Q coordinates for each row of a thread's tile -+ int64_t PQ = int64_t(problem_size.P) * problem_size.Q; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ -+ int64_t npq = npq_start + m; -+ -+ thread_n[m] = int(npq / PQ); -+ -+ int64_t residual = npq % PQ; -+ thread_p[m] = int(residual / problem_size.Q); -+ thread_q[m] = int(residual % problem_size.Q); -+ } -+ -+ // Clear accumulators -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ accum[m][n] = ElementAccumulator(); -+ } -+ } -+ -+ int c_per_group = problem_size.C / problem_size.groups; -+ int k_per_group = problem_size.K / problem_size.groups; -+ -+ // Compute convolution -+ for (int R = 0; R < problem_size.R; ++R) { -+ for (int S = 0; S < problem_size.S; ++S) { -+ for (int C = 0; C < problem_size.C; ++C) { -+ -+ // Get group id of currnet channel -+ int c_group_idx = C / c_per_group; -+ -+ // Load from activations tensor -+ int filter_r = R; -+ int filter_s = S; -+ -+ if (problem_size.mode == cutlass::conv::Mode::kConvolution) { -+ filter_r = problem_size.R - 1 - R; -+ filter_s = problem_size.S - 1 - S; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ int h = thread_p[m] * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; -+ int w = thread_q[m] * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; -+ -+ if (thread_n[m] < problem_size.N && h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W) { -+ element_A[m] = ElementAccumulator(tensor_x.at({thread_n[m], h, w, C})); -+ } -+ else { -+ element_A[m] = ElementAccumulator(); -+ } -+ } -+ -+ // Load from filters tensor -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ int thread_k = k_start + n; -+ int k_group_idx = thread_k / k_per_group; -+ -+ if (thread_k < problem_size.K && k_group_idx == c_group_idx) { -+ element_B[n] = ElementAccumulator(tensor_w.at({thread_k, R, S, C % c_per_group})); -+ } -+ else { -+ element_B[n] = ElementAccumulator(); -+ } -+ } -+ -+ // Accumulate matrix product -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); -+ } -+ } -+ } -+ } -+ } -+ -+ // Write out the results -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ if (thread_n[m] < problem_size.N && thread_p[m] < problem_size.P && thread_q[m] < problem_size.Q) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ int thread_k = k_start + n; -+ if (thread_k < problem_size.K) { -+ -+ ElementCompute c_ref = ElementCompute(); -+ if (beta != ElementCompute()) { -+ c_ref = ElementCompute(tensor_y_in.at({thread_n[m], thread_p[m], thread_q[m], thread_k})); -+ } -+ -+ tensor_y_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op( -+ alpha * ElementCompute(accum[m][n]) + beta * c_ref); -+ } -+ } -+ } -+ } -+} -+ -+// Conv3d Fprop kernel - y = fprop(x, w) -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add, -+ int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension -+ int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension -+ int kCtaShapeM = 16, // shape of a threadblock in units of threads -+ int kCtaShapeN = 8 // shape of a threadblock in units of threads -+> -+__global__ void Conv3dFprop( -+ conv::Conv3dProblemSize problem_size, -+ TensorRef tensor_x, -+ TensorRef tensor_w, -+ TensorRef tensor_y_in, -+ TensorRef tensor_y_out, -+ ElementCompute alpha, -+ ElementCompute beta -+ ) { -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ -+ ElementAccumulator element_A[kThreadM]; -+ ElementAccumulator element_B[kThreadN]; -+ ElementAccumulator accum[kThreadM][kThreadN]; -+ -+ int64_t nzpq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; -+ int k_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; -+ -+ int thread_n[kThreadM]; -+ int thread_z[kThreadM]; -+ int thread_p[kThreadM]; -+ int thread_q[kThreadM]; -+ -+ // Compute N, Z, P, Q coordinates for each row of a thread's tile -+ int64_t PQ = int64_t(problem_size.P) * problem_size.Q; -+ int64_t ZPQ = PQ * problem_size.Z; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ -+ int64_t nzpq = nzpq_start + m; -+ -+ thread_n[m] = int(nzpq / ZPQ); -+ -+ int64_t residual = nzpq % ZPQ; -+ thread_z[m] = int(residual / PQ); -+ -+ residual = residual % PQ; -+ thread_p[m] = int(residual / problem_size.Q); -+ thread_q[m] = int(residual % problem_size.Q); -+ } -+ -+ // Clear accumulators -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ accum[m][n] = ElementAccumulator(); -+ } -+ } -+ -+ // Compute convolution -+ for (int T = 0; T < problem_size.T; ++T) { -+ for (int R = 0; R < problem_size.R; ++R) { -+ for (int S = 0; S < problem_size.S; ++S) { -+ for (int C = 0; C < problem_size.C; ++C) { -+ -+ // Load from activations tensor -+ int filter_t = T; -+ int filter_r = R; -+ int filter_s = S; -+ -+ if (problem_size.mode == cutlass::conv::Mode::kConvolution) { -+ filter_t = problem_size.T - 1 - T; -+ filter_r = problem_size.R - 1 - R; -+ filter_s = problem_size.S - 1 - S; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ int d = thread_z[m] * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d; -+ int h = thread_p[m] * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; -+ int w = thread_q[m] * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; -+ -+ if (thread_n[m] < problem_size.N && -+ d >= 0 && d < problem_size.D && -+ h >= 0 && h < problem_size.H && -+ w >= 0 && w < problem_size.W) { -+ -+ element_A[m] = ElementAccumulator(tensor_x.at({thread_n[m], d, h, w, C})); -+ } -+ else { -+ element_A[m] = ElementAccumulator(); -+ } -+ } -+ -+ // Load from filters tensor -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ int thread_k = k_start + n; -+ -+ if (thread_k < problem_size.K) { -+ element_B[n] = ElementAccumulator(tensor_w.at({thread_k, T, R, S, C})); -+ } -+ else { -+ element_B[n] = ElementAccumulator(); -+ } -+ } -+ -+ // Accumulate matrix product -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); -+ } -+ } -+ -+ } // for (C) -+ } // for (S) -+ } // for (R) -+ } // for (T) -+ -+ // Write out the results -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ -+ if (thread_n[m] < problem_size.N && -+ thread_z[m] < problem_size.Z && -+ thread_p[m] < problem_size.P && -+ thread_q[m] < problem_size.Q) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ int thread_k = k_start + n; -+ if (thread_k < problem_size.K) { -+ -+ ElementCompute c_ref = ElementCompute(); -+ if (beta != ElementCompute()) { -+ c_ref = ElementCompute(tensor_y_in.at({thread_n[m], thread_z[m], thread_p[m], thread_q[m], thread_k})); -+ } -+ -+ tensor_y_out.at({thread_n[m], thread_z[m], thread_p[m], thread_q[m], thread_k}) = convert_op( -+ alpha * ElementCompute(accum[m][n]) + beta * c_ref); -+ } -+ } // for (n) -+ -+ } -+ } // for (m) -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Conv2d dgrad kernel - dx = dgrad(dy, w) -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add, -+ int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension -+ int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension -+ int kCtaShapeM = 16, // shape of a threadblock in units of threads -+ int kCtaShapeN = 8 // shape of a threadblock in units of threads -+> -+__global__ void Conv2dDgrad( -+ conv::Conv2dProblemSize problem_size, -+ TensorRef tensor_dy, -+ TensorRef tensor_w, -+ TensorRef tensor_dx_in, -+ TensorRef tensor_dx_out, -+ ElementCompute alpha, -+ ElementCompute beta -+ ) { -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ -+ ElementAccumulator element_A[kThreadM]; -+ ElementAccumulator element_B[kThreadN]; -+ ElementAccumulator accum[kThreadM][kThreadN]; -+ -+ int64_t nhw_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; -+ int c_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; -+ -+ int thread_n[kThreadM]; -+ int thread_h[kThreadM]; -+ int thread_w[kThreadM]; -+ -+ // Compute N, H, W coordinates for each row of a thread's tile -+ int64_t HW = int64_t(problem_size.H) * problem_size.W; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ -+ int64_t nhw = nhw_start + m; -+ -+ thread_n[m] = int(nhw / HW); -+ -+ int64_t residual = nhw % HW; -+ thread_h[m] = int(residual / problem_size.W); -+ thread_w[m] = int(residual % problem_size.W); -+ } -+ -+ // Clear accumulators -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ accum[m][n] = ElementAccumulator(); -+ } -+ } -+ -+ // Compute convolution -+ for (int R = 0; R < problem_size.R; ++R) { -+ for (int S = 0; S < problem_size.S; ++S) { -+ for (int K = 0; K < problem_size.K; ++K) { -+ -+ // Load from activations tensor -+ int filter_r = R; -+ int filter_s = S; -+ -+ if (problem_size.mode == cutlass::conv::Mode::kConvolution) { -+ filter_r = problem_size.R - 1 - R; -+ filter_s = problem_size.S - 1 - S; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ -+ int p = thread_h[m] + problem_size.pad_h - filter_r * problem_size.dilation_h; -+ int q = thread_w[m] + problem_size.pad_w - filter_s * problem_size.dilation_w; -+ -+ element_A[m] = ElementAccumulator(); -+ -+ if (p >= 0 && !(p % problem_size.stride_h) && q >= 0 && !(q % problem_size.stride_w)) { -+ -+ p = p / problem_size.stride_h; -+ q = q / problem_size.stride_w; -+ -+ if (thread_n[m] < problem_size.N && p < problem_size.P && q < problem_size.Q) { -+ element_A[m] = ElementAccumulator(tensor_dy.at({thread_n[m], p, q, K})); -+ } -+ } -+ } -+ -+ // Load from filters tensor -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ int thread_c = c_start + n; -+ -+ if (thread_c < problem_size.C) { -+ element_B[n] = ElementAccumulator(tensor_w.at({K, R, S, thread_c})); -+ } -+ else { -+ element_B[n] = ElementAccumulator(); -+ } -+ } -+ -+ // Accumulate matrix product -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); -+ } -+ } -+ } -+ } -+ } -+ -+ // Write out the results -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ -+ if (thread_n[m] < problem_size.N && thread_h[m] < problem_size.H && thread_w[m] < problem_size.W) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ int thread_c = c_start + n; -+ if (thread_c < problem_size.C) { -+ -+ ElementCompute c_ref = ElementCompute(); -+ if (beta != ElementCompute()) { -+ c_ref = ElementCompute(tensor_dx_in.at({thread_n[m], thread_h[m], thread_w[m], thread_c})); -+ } -+ -+ tensor_dx_out.at({thread_n[m], thread_h[m], thread_w[m], thread_c}) = convert_op( -+ alpha * ElementCompute(accum[m][n]) + beta * c_ref); -+ } -+ } -+ } -+ } -+} -+ -+// Conv3d dgrad kernel - dx = dgrad(dy, w) -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add, -+ int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension -+ int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension -+ int kCtaShapeM = 16, // shape of a threadblock in units of threads -+ int kCtaShapeN = 8 // shape of a threadblock in units of threads -+> -+__global__ void Conv3dDgrad( -+ conv::Conv3dProblemSize problem_size, -+ TensorRef tensor_dy, -+ TensorRef tensor_w, -+ TensorRef tensor_dx_in, -+ TensorRef tensor_dx_out, -+ ElementCompute alpha, -+ ElementCompute beta -+ ) { -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ -+ ElementAccumulator element_A[kThreadM]; -+ ElementAccumulator element_B[kThreadN]; -+ ElementAccumulator accum[kThreadM][kThreadN]; -+ -+ int64_t ndhw_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; -+ int c_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; -+ -+ int thread_n[kThreadM]; -+ int thread_d[kThreadM]; -+ int thread_h[kThreadM]; -+ int thread_w[kThreadM]; -+ -+ // Compute N, H, W coordinates for each row of a thread's tile -+ int64_t HW = int64_t(problem_size.H) * problem_size.W; -+ int64_t DHW = HW * problem_size.D; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ -+ int64_t ndhw = ndhw_start + m; -+ -+ thread_n[m] = int(ndhw / DHW); -+ -+ int64_t residual = ndhw % DHW; -+ thread_d[m] = int(residual / HW); -+ -+ residual = residual % HW; -+ thread_h[m] = int(residual / problem_size.W); -+ thread_w[m] = int(residual % problem_size.W); -+ } -+ -+ // Clear accumulators -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ accum[m][n] = ElementAccumulator(); -+ } -+ } -+ -+ // Compute convolution -+ for (int T = 0; T < problem_size.T; ++T) { -+ for (int R = 0; R < problem_size.R; ++R) { -+ for (int S = 0; S < problem_size.S; ++S) { -+ for (int K = 0; K < problem_size.K; ++K) { -+ -+ // Load from activations tensor -+ int filter_t = T; -+ int filter_r = R; -+ int filter_s = S; -+ -+ if (problem_size.mode == cutlass::conv::Mode::kConvolution) { -+ filter_t = problem_size.T - 1 - T; -+ filter_r = problem_size.R - 1 - R; -+ filter_s = problem_size.S - 1 - S; -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ -+ int z = thread_d[m] + problem_size.pad_d - filter_t * problem_size.dilation_d; -+ int p = thread_h[m] + problem_size.pad_h - filter_r * problem_size.dilation_h; -+ int q = thread_w[m] + problem_size.pad_w - filter_s * problem_size.dilation_w; -+ -+ element_A[m] = ElementAccumulator(); -+ -+ if (z >= 0 && !(z % problem_size.stride_d) && -+ p >= 0 && !(p % problem_size.stride_h) && -+ q >= 0 && !(q % problem_size.stride_w)) { -+ -+ z = z / problem_size.stride_d; -+ p = p / problem_size.stride_h; -+ q = q / problem_size.stride_w; -+ -+ if (thread_n[m] < problem_size.N && z < problem_size.Z && p < problem_size.P && q < problem_size.Q) { -+ element_A[m] = ElementAccumulator(tensor_dy.at({thread_n[m], z, p, q, K})); -+ } -+ } -+ } -+ -+ // Load from filters tensor -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ int thread_c = c_start + n; -+ -+ if (thread_c < problem_size.C) { -+ element_B[n] = ElementAccumulator(tensor_w.at({K, T, R, S, thread_c})); -+ } -+ else { -+ element_B[n] = ElementAccumulator(); -+ } -+ } -+ -+ // Accumulate matrix product -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); -+ } -+ } -+ -+ } // for (C) -+ } // for (S) -+ } // for (R) -+ } // for (T) -+ -+ // Write out the results -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ -+ if (thread_n[m] < problem_size.N && -+ thread_d[m] < problem_size.D && -+ thread_h[m] < problem_size.H && -+ thread_w[m] < problem_size.W) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ int thread_c = c_start + n; -+ if (thread_c < problem_size.C) { -+ -+ ElementCompute c_ref = ElementCompute(); -+ if (beta != ElementCompute()) { -+ c_ref = ElementCompute(tensor_dx_in.at({thread_n[m], thread_d[m], thread_h[m], thread_w[m], thread_c})); -+ } -+ -+ tensor_dx_out.at({thread_n[m], thread_d[m], thread_h[m], thread_w[m], thread_c}) = convert_op( -+ alpha * ElementCompute(accum[m][n]) + beta * c_ref); -+ } -+ } -+ } -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+// Conv2d wgrad kernel - dw = wgrad(dy, x) -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add, -+ int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension -+ int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension -+ int kCtaShapeM = 8, // shape of a threadblock in units of threads -+ int kCtaShapeN = 16 // shape of a threadblock in units of threads -+> -+__global__ void Conv2dWgrad( -+ conv::Conv2dProblemSize problem_size, -+ TensorRef tensor_dy, -+ TensorRef tensor_x, -+ TensorRef tensor_dw_in, -+ TensorRef tensor_dw_out, -+ ElementCompute alpha, -+ ElementCompute beta -+ ) { -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ -+ ElementAccumulator element_A[kThreadM]; -+ ElementAccumulator element_B[kThreadN]; -+ ElementAccumulator accum[kThreadM][kThreadN]; -+ -+ int k_start = blockIdx.x * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; -+ int64_t rsc_start = int64_t(blockIdx.y) * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; -+ -+ int thread_r[kThreadN]; -+ int thread_s[kThreadN]; -+ int thread_c[kThreadN]; -+ -+ // Compute R, S, C coordinates for each row of a thread's tile -+ int64_t SC = int64_t(problem_size.S) * problem_size.C; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ -+ int64_t rsc = rsc_start + n; -+ int64_t residual = rsc % SC; -+ -+ thread_r[n] = int(rsc / SC); -+ thread_s[n] = int(residual / problem_size.C); -+ thread_c[n] = int(residual % problem_size.C); -+ } -+ -+ // Clear accumulators -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ accum[m][n] = ElementAccumulator(); -+ } -+ } -+ -+ // Compute convolution -+ for (int N = 0; N < problem_size.N; ++N) { -+ for (int P = 0; P < problem_size.P; ++P) { -+ for (int Q = 0; Q < problem_size.Q; ++Q) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ int thread_k = k_start + m; -+ -+ element_A[m] = ElementAccumulator(); -+ -+ if (thread_k < problem_size.K) { -+ element_A[m] = ElementAccumulator(tensor_dy.at({N, P, Q, thread_k})); -+ } -+ } -+ -+ // Load from filters tensor -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ -+ // Load from activations tensor -+ int filter_r = thread_r[n]; -+ int filter_s = thread_s[n]; -+ -+ if (problem_size.mode == cutlass::conv::Mode::kConvolution) { -+ filter_r = problem_size.R - 1 - filter_r; -+ filter_s = problem_size.S - 1 - filter_s; -+ } -+ -+ int h = P * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; -+ int w = Q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; -+ -+ element_B[n] = ElementAccumulator(); -+ -+ if (h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W && thread_c[n] < problem_size.C) { -+ element_B[n] = ElementAccumulator(tensor_x.at({N, h, w, thread_c[n]})); -+ } -+ } -+ -+ // Accumulate matrix product -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); -+ } -+ } -+ } -+ } -+ } -+ -+ // Write out the results -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ int thread_k = k_start + m; -+ -+ if (thread_k < problem_size.K) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ -+ if (thread_r[n] < problem_size.R && thread_s[n] < problem_size.S && thread_c[n] < problem_size.C) { -+ -+ ElementCompute c_ref = ElementCompute(); -+ -+ if (beta != ElementCompute()) { -+ c_ref = ElementCompute(tensor_dw_in.at({thread_k, thread_r[n], thread_s[n], thread_c[n]})); -+ } -+ -+ tensor_dw_out.at({thread_k, thread_r[n], thread_s[n], thread_c[n]}) = convert_op( -+ alpha * ElementCompute(accum[m][n]) + beta * c_ref); -+ } -+ } -+ } -+ } -+} -+ -+// Conv3d wgrad kernel - dw = wgrad(dy, x) -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add, -+ int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension -+ int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension -+ int kCtaShapeM = 8, // shape of a threadblock in units of threads -+ int kCtaShapeN = 16 // shape of a threadblock in units of threads -+> -+__global__ void Conv3dWgrad( -+ conv::Conv3dProblemSize problem_size, -+ TensorRef tensor_dy, -+ TensorRef tensor_x, -+ TensorRef tensor_dw_in, -+ TensorRef tensor_dw_out, -+ ElementCompute alpha, -+ ElementCompute beta -+ ) { -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ -+ ElementAccumulator element_A[kThreadM]; -+ ElementAccumulator element_B[kThreadN]; -+ ElementAccumulator accum[kThreadM][kThreadN]; -+ -+ int k_start = blockIdx.x * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; -+ int64_t trsc_start = int64_t(blockIdx.y) * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; -+ -+ int thread_t[kThreadN]; -+ int thread_r[kThreadN]; -+ int thread_s[kThreadN]; -+ int thread_c[kThreadN]; -+ -+ // Compute R, S, C coordinates for each row of a thread's tile -+ int64_t SC = int64_t(problem_size.S) * problem_size.C; -+ int64_t RSC = SC * problem_size.R; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ -+ int64_t trsc = trsc_start + n; -+ -+ thread_t[n] = int(trsc / RSC); -+ -+ int64_t residual = trsc % RSC; -+ thread_r[n] = int(residual / SC); -+ -+ residual = residual % SC; -+ thread_s[n] = int(residual / problem_size.C); -+ thread_c[n] = int(residual % problem_size.C); -+ } -+ -+ // Clear accumulators -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ accum[m][n] = ElementAccumulator(); -+ } -+ } -+ -+ // Compute convolution -+ for (int N = 0; N < problem_size.N; ++N) { -+ for (int Z = 0; Z < problem_size.Z; ++Z) { -+ for (int P = 0; P < problem_size.P; ++P) { -+ for (int Q = 0; Q < problem_size.Q; ++Q) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ int thread_k = k_start + m; -+ -+ element_A[m] = ElementAccumulator(); -+ -+ if (thread_k < problem_size.K) { -+ element_A[m] = ElementAccumulator(tensor_dy.at({N, Z, P, Q, thread_k})); -+ } -+ } -+ -+ // Load from filters tensor -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ -+ // Load from activations tensor -+ int filter_t = thread_t[n]; -+ int filter_r = thread_r[n]; -+ int filter_s = thread_s[n]; -+ -+ if (problem_size.mode == cutlass::conv::Mode::kConvolution) { -+ filter_t = problem_size.T - 1 - filter_t; -+ filter_r = problem_size.R - 1 - filter_r; -+ filter_s = problem_size.S - 1 - filter_s; -+ } -+ -+ int d = Z * problem_size.stride_d - problem_size.pad_w + filter_t * problem_size.dilation_d; -+ int h = P * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; -+ int w = Q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; -+ -+ element_B[n] = ElementAccumulator(); -+ -+ if (d >= 0 && d < problem_size.D && -+ h >= 0 && h < problem_size.H && -+ w >= 0 && w < problem_size.W && -+ thread_c[n] < problem_size.C) { -+ -+ element_B[n] = ElementAccumulator(tensor_x.at({N, d, h, w, thread_c[n]})); -+ } -+ } -+ -+ // Accumulate matrix product -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); -+ } -+ } -+ -+ } // for (Q) -+ } // for (P) -+ } // for (Z) -+ } // for (N) -+ -+ // Write out the results -+ CUTLASS_PRAGMA_UNROLL -+ for (int m = 0; m < kThreadM; ++m) { -+ int thread_k = k_start + m; -+ -+ if (thread_k < problem_size.K) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int n = 0; n < kThreadN; ++n) { -+ -+ if (thread_t[n] < problem_size.T && -+ thread_r[n] < problem_size.R && -+ thread_s[n] < problem_size.S && -+ thread_c[n] < problem_size.C) { -+ -+ ElementCompute c_ref = ElementCompute(); -+ -+ if (beta != ElementCompute()) { -+ c_ref = ElementCompute(tensor_dw_in.at({thread_k, thread_t[n], thread_r[n], thread_s[n], thread_c[n]})); -+ } -+ -+ tensor_dw_out.at({thread_k, thread_t[n], thread_r[n], thread_s[n], thread_c[n]}) = convert_op( -+ alpha * ElementCompute(accum[m][n]) + beta * c_ref); -+ } -+ } -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Conv2d Fprop dispatcher - y = fprop(x, w) -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+Status Conv2dFprop( -+ conv::Conv2dProblemSize problem_size, -+ TensorRef tensor_x, -+ TensorRef tensor_w, -+ TensorRef tensor_y_in, -+ TensorRef tensor_y_out, -+ ElementCompute alpha, -+ ElementCompute beta, -+ cudaStream_t stream = nullptr) { -+ -+ // -+ // Blocking factors improve performance of reference implementation -+ // -+ -+ int const kThreadM = 4; // shape of a thread's tile in the GEMM M dimension -+ int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension -+ int const kCtaShapeM = 16; // shape of a threadblock in units of threads -+ int const kCtaShapeN = 8; // shape of a threadblock in units of threads -+ -+ int64_t npq = int64_t(problem_size.N) * problem_size.P * problem_size.Q; -+ int64_t blocks_m = (npq + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM); -+ -+ dim3 block(kCtaShapeM, kCtaShapeN); -+ dim3 grid(uint32_t(blocks_m), (problem_size.K + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN)); -+ -+ kernel::Conv2dFprop< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, -+ InnerProductOp, -+ kThreadM, -+ kThreadN, -+ kCtaShapeM, -+ kCtaShapeN -+ ><<< grid, block, 0, stream >>>( -+ problem_size, -+ tensor_x, -+ tensor_w, -+ tensor_y_in, -+ tensor_y_out, -+ alpha, -+ beta -+ ); -+ -+ cudaError_t result = cudaPeekAtLastError(); -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ -+ return Status::kSuccess; -+} -+ -+/// Conv3d Fprop dispatcher - y = fprop(x, w) -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+Status Conv3dFprop( -+ conv::Conv3dProblemSize problem_size, -+ TensorRef tensor_x, -+ TensorRef tensor_w, -+ TensorRef tensor_y_in, -+ TensorRef tensor_y_out, -+ ElementCompute alpha, -+ ElementCompute beta, -+ cudaStream_t stream = nullptr) { -+ -+ // -+ // Blocking factors improve performance of reference implementation -+ // -+ -+ int const kThreadM = 4; // shape of a thread's tile in the GEMM M dimension -+ int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension -+ int const kCtaShapeM = 16; // shape of a threadblock in units of threads -+ int const kCtaShapeN = 8; // shape of a threadblock in units of threads -+ -+ int64_t nzpq = int64_t(problem_size.N) * problem_size.Z * problem_size.P * problem_size.Q; -+ int64_t blocks_m = (nzpq + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM); -+ -+ dim3 block(kCtaShapeM, kCtaShapeN); -+ dim3 grid(uint32_t(blocks_m), (problem_size.K + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN)); -+ -+ kernel::Conv3dFprop< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, -+ InnerProductOp, -+ kThreadM, -+ kThreadN, -+ kCtaShapeM, -+ kCtaShapeN -+ ><<< grid, block, 0, stream >>>( -+ problem_size, -+ tensor_x, -+ tensor_w, -+ tensor_y_in, -+ tensor_y_out, -+ alpha, -+ beta -+ ); -+ -+ cudaError_t result = cudaPeekAtLastError(); -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ -+ return Status::kSuccess; -+} -+ -+/// Conv2d Dgrad dispatcher - dx = dgrad(dy, w) -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+Status Conv2dDgrad( -+ conv::Conv2dProblemSize problem_size, -+ TensorRef tensor_dy, -+ TensorRef tensor_w, -+ TensorRef tensor_dx_in, -+ TensorRef tensor_dx_out, -+ ElementCompute alpha, -+ ElementCompute beta, -+ cudaStream_t stream = nullptr) { -+ -+ // -+ // Blocking factors improve performance of reference implementation -+ // -+ -+ int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension -+ int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension -+ int const kCtaShapeM = 16; // shape of a threadblock in units of threads -+ int const kCtaShapeN = 8; // shape of a threadblock in units of threads -+ -+ int64_t nhw = int64_t(problem_size.N) * problem_size.H * problem_size.W; -+ int64_t blocks_m = (nhw + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM); -+ -+ dim3 block(kCtaShapeM, kCtaShapeN); -+ dim3 grid(uint32_t(blocks_m), (problem_size.C + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN)); -+ -+ kernel::Conv2dDgrad< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, -+ InnerProductOp, -+ kThreadM, -+ kThreadN, -+ kCtaShapeM, -+ kCtaShapeN -+ ><<< grid, block, 0, stream >>>( -+ problem_size, -+ tensor_dy, -+ tensor_w, -+ tensor_dx_in, -+ tensor_dx_out, -+ alpha, -+ beta -+ ); -+ -+ cudaError_t result = cudaPeekAtLastError(); -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ -+ return Status::kSuccess; -+} -+ -+/// Conv3d Dgrad dispatcher - dx = dgrad(dy, w) -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+Status Conv3dDgrad( -+ conv::Conv3dProblemSize problem_size, -+ TensorRef tensor_dy, -+ TensorRef tensor_w, -+ TensorRef tensor_dx_in, -+ TensorRef tensor_dx_out, -+ ElementCompute alpha, -+ ElementCompute beta, -+ cudaStream_t stream = nullptr) { -+ -+ // -+ // Blocking factors improve performance of reference implementation -+ // -+ -+ int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension -+ int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension -+ int const kCtaShapeM = 16; // shape of a threadblock in units of threads -+ int const kCtaShapeN = 8; // shape of a threadblock in units of threads -+ -+ int64_t ndhw = int64_t(problem_size.N) * problem_size.D * problem_size.H * problem_size.W; -+ int64_t blocks_m = (ndhw + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM); -+ -+ dim3 block(kCtaShapeM, kCtaShapeN); -+ dim3 grid(uint32_t(blocks_m), (problem_size.C + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN)); -+ -+ kernel::Conv3dDgrad< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, -+ InnerProductOp, -+ kThreadM, -+ kThreadN, -+ kCtaShapeM, -+ kCtaShapeN -+ ><<< grid, block, 0, stream >>>( -+ problem_size, -+ tensor_dy, -+ tensor_w, -+ tensor_dx_in, -+ tensor_dx_out, -+ alpha, -+ beta -+ ); -+ -+ cudaError_t result = cudaPeekAtLastError(); -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ -+ return Status::kSuccess; -+} -+ -+/// Conv2d Wgrad dispatcher - dw = wgrad(dy, x) -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+Status Conv2dWgrad( -+ conv::Conv2dProblemSize problem_size, -+ TensorRef tensor_dy, -+ TensorRef tensor_x, -+ TensorRef tensor_dw_in, -+ TensorRef tensor_dw_out, -+ ElementCompute alpha, -+ ElementCompute beta, -+ cudaStream_t stream = nullptr) { -+ -+ // -+ // Blocking factors improve performance of reference implementation -+ // -+ -+ int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension -+ int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension -+ int const kCtaShapeM = 8; // shape of a threadblock in units of threads -+ int const kCtaShapeN = 16; // shape of a threadblock in units of threads -+ -+ int64_t rsc = int64_t(problem_size.R) * problem_size.S * problem_size.C; -+ int64_t blocks_n = (rsc + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN); -+ -+ dim3 block(kCtaShapeM, kCtaShapeN); -+ dim3 grid((problem_size.K + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM), uint32_t(blocks_n)); -+ -+ kernel::Conv2dWgrad< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, -+ InnerProductOp, -+ kThreadM, -+ kThreadN, -+ kCtaShapeM, -+ kCtaShapeN -+ ><<< grid, block, 0, stream >>>( -+ problem_size, -+ tensor_dy, -+ tensor_x, -+ tensor_dw_in, -+ tensor_dw_out, -+ alpha, -+ beta -+ ); -+ -+ cudaError_t result = cudaPeekAtLastError(); -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ -+ return Status::kSuccess; -+} -+ -+/// Conv3d Wgrad dispatcher - dw = wgrad(dy, x) -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+Status Conv3dWgrad( -+ conv::Conv3dProblemSize problem_size, -+ TensorRef tensor_dy, -+ TensorRef tensor_x, -+ TensorRef tensor_dw_in, -+ TensorRef tensor_dw_out, -+ ElementCompute alpha, -+ ElementCompute beta, -+ cudaStream_t stream = nullptr) { -+ -+ // -+ // Blocking factors improve performance of reference implementation -+ // -+ -+ int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension -+ int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension -+ int const kCtaShapeM = 8; // shape of a threadblock in units of threads -+ int const kCtaShapeN = 16; // shape of a threadblock in units of threads -+ -+ int64_t trsc = int64_t(problem_size.T) * problem_size.R * problem_size.S * problem_size.C; -+ int64_t blocks_n = (trsc + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN); -+ -+ dim3 block(kCtaShapeM, kCtaShapeN); -+ dim3 grid((problem_size.K + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM), uint32_t(blocks_n)); -+ -+ kernel::Conv3dWgrad< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, -+ InnerProductOp, -+ kThreadM, -+ kThreadN, -+ kCtaShapeM, -+ kCtaShapeN -+ ><<< grid, block, 0, stream >>>( -+ problem_size, -+ tensor_dy, -+ tensor_x, -+ tensor_dw_in, -+ tensor_dw_out, -+ alpha, -+ beta -+ ); -+ -+ cudaError_t result = cudaPeekAtLastError(); -+ if (result != cudaSuccess) { -+ return Status::kErrorInternal; -+ } -+ -+ return Status::kSuccess; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Generic 2D convolution targeting Conv2dFprop, Conv2dDgrad, and Conv2dWgrad. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+Status Conv2d( -+ conv::Operator convolutional_operator, -+ conv::Conv2dProblemSize problem_size, -+ TensorRef tensor_A, -+ TensorRef tensor_B, -+ TensorRef tensor_C, -+ TensorRef tensor_D, -+ ElementCompute alpha, -+ ElementCompute beta, -+ cudaStream_t stream = nullptr) { -+ -+ switch (convolutional_operator) { -+ case conv::Operator::kFprop: -+ return Conv2dFprop< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, InnerProductOp -+ >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); -+ break; -+ -+ case conv::Operator::kDgrad: -+ return Conv2dDgrad< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, InnerProductOp -+ >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); -+ break; -+ -+ case conv::Operator::kWgrad: -+ return Conv2dWgrad< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, InnerProductOp -+ >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); -+ break; -+ -+ default: break; -+ } -+ -+ return Status::kErrorNotSupported; -+} -+ -+/// Generic 3D convolution targeting Conv3dFprop, Conv3dDgrad, and Conv3dWgrad. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+Status Conv3d( -+ conv::Operator convolutional_operator, -+ conv::Conv3dProblemSize problem_size, -+ TensorRef tensor_A, -+ TensorRef tensor_B, -+ TensorRef tensor_C, -+ TensorRef tensor_D, -+ ElementCompute alpha, -+ ElementCompute beta, -+ cudaStream_t stream = nullptr) { -+ -+ switch (convolutional_operator) { -+ case conv::Operator::kFprop: -+ return Conv3dFprop< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, InnerProductOp -+ >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); -+ -+ case conv::Operator::kDgrad: -+ return Conv3dDgrad< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, InnerProductOp -+ >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); -+ -+ case conv::Operator::kWgrad: -+ return Conv3dWgrad< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, InnerProductOp -+ >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); -+ -+ default: break; -+ } -+ -+ return Status::kErrorNotSupported; -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace reference -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/gemm.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/gemm.h -new file mode 100644 -index 0000000..1850c2f ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/gemm.h -@@ -0,0 +1,385 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Reference implementation for GEMM in device-side code. -+*/ -+ -+#pragma once -+ -+#include "cutlass/coord.h" -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+ -+#include "cutlass/tensor_view.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/util/reference/device/kernel/gemm.h" -+ -+namespace cutlass { -+namespace reference { -+namespace device { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+/// -+/// Explicitly naming types needed by this template can be cumbersome, particularly for the -+/// accumulator type, so a function argument 'initial_accum' is exposed. Passing -+/// AccumulatorType(0) as the last function argument can be easier than naming all template -+/// arguments explicitly. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename AccumulatorType, -+ typename InnerProductOp = multiply_add, -+ typename ConvertOp = NumericConverter -+> -+void compute_gemm( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, -+ ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ AccumulatorType initial_accum) { -+ -+ static_assert( -+ LayoutA::kRank == 2 && -+ LayoutB::kRank == 2 && -+ LayoutC::kRank == 2, "Tensors must be of rank 2"); -+ -+ // Blocking structure potentially improves performance of reference implementation -+ // with a minor increase in complexity. -+ // -+ // Note, this reference implementation is NOT expected to approach peak performance. -+ using OutputTile = MatrixShape<4, 4>; -+ -+ dim3 block(16, 8); -+ -+ dim3 grid( -+ (problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow), -+ (problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn) -+ ); -+ -+ // Launch a GEMM kernel -+ kernel::Gemm< -+ TensorRef, -+ TensorRef, -+ TensorRef, -+ ScalarType, -+ AccumulatorType, -+ OutputTile, -+ InnerProductOp, -+ ConvertOp -+ ><<< grid, block >>>( -+ problem_size, -+ alpha, -+ tensor_a, -+ tensor_b, -+ beta, -+ tensor_c, -+ tensor_d, -+ initial_accum -+ ); -+} -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+/// -+/// This assumes the accumulator type is the same type as the scalars. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename AccumulatorType, -+ typename InnerProductOp = multiply_add, -+ typename ConvertOp = NumericConverter -+> -+void compute_gemm( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, -+ ScalarType beta, -+ TensorRef tensor_c, -+ AccumulatorType initial_accum) { -+ -+ compute_gemm( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c, -+ initial_accum); -+} -+ -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename AccumulatorType, -+ typename InnerProductOp = cutlass::arch::OpMultiplyAdd -+> -+struct Gemm; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for multiply-add -+template -+struct Gemm { -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ AccumulatorType initial_accum = AccumulatorType(0)) { -+ -+ static_assert( -+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_gemm>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); -+ } -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ AccumulatorType initial_accum = AccumulatorType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_gemm>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for multiply-add-saturate -+template -+struct Gemm { -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ AccumulatorType initial_accum = AccumulatorType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_gemm, -+ NumericConverterClamp>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); -+ } -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ AccumulatorType initial_accum = AccumulatorType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_gemm, -+ NumericConverterClamp>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Parital specialization for XOR-popc -+template -+struct Gemm { -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ AccumulatorType initial_accum = AccumulatorType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_gemm>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); -+ } -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ AccumulatorType initial_accum = AccumulatorType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_gemm>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); -+ } -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Batched GEMM -+// -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a batch of GEMMs over a set of matrices of common dimension. -+// -+// TensorRefCollection* is a type satisfying the TensorRefCollection concept. -+// -+template < -+ typename TensorRefCollectionA, -+ typename TensorRefCollectionB, -+ typename TensorRefCollectionC, -+ typename ScalarType, -+ typename AccumulatorType, -+ typename InnerProductOp, -+ typename ConvertOp -+> -+void BatchedGemm( -+ gemm::GemmCoord problem_size, -+ int batch_count, -+ ScalarType alpha, -+ TensorRefCollectionA const& tensor_a, -+ TensorRefCollectionB const& tensor_b, -+ ScalarType beta, -+ TensorRefCollectionC &tensor_c, -+ AccumulatorType initial_accum) { -+ -+ static_assert( -+ TensorRefCollectionA::kRank == 2 && -+ TensorRefCollectionB::kRank == 2 && -+ TensorRefCollectionC::kRank == 2, "Tensors must be of rank 2"); -+ -+ // Blocking structure potentially improves performance of reference implementation -+ // with a minor increase in complexity. -+ // -+ // Note, this reference implementation is NOT expected to approach peak performance. -+ using OutputTile = MatrixShape<4, 4>; -+ -+ dim3 block(16, 8); -+ dim3 grid( -+ (problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow), -+ (problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn), -+ batch_count -+ ); -+ -+ // Launch a GEMM kernel -+ kernel::BatchedGemm< -+ TensorRefCollectionA, -+ TensorRefCollectionB, -+ TensorRefCollectionC, -+ ScalarType, -+ AccumulatorType, -+ OutputTile, -+ InnerProductOp, -+ ConvertOp -+ ><<< grid, block >>>( -+ problem_size, -+ alpha, -+ tensor_a, -+ tensor_b, -+ beta, -+ tensor_c, -+ initial_accum -+ ); -+} -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+// -+// TensorRefCollection* is a type satisfying the TensorRefCollection concept. -+// -+template < -+ typename TensorRefCollectionA, -+ typename TensorRefCollectionB, -+ typename TensorRefCollectionC, -+ typename ScalarType, -+ typename AccumulatorType -+> -+void BatchedGemm( -+ gemm::GemmCoord problem_size, -+ int batch_count, -+ ScalarType alpha, -+ TensorRefCollectionA const& tensor_a, -+ TensorRefCollectionB const& tensor_b, -+ ScalarType beta, -+ TensorRefCollectionC &tensor_c) { -+ -+ BatchedGemm(problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, ScalarType(0)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/gemm_complex.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/gemm_complex.h -new file mode 100644 -index 0000000..0f3977b ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/gemm_complex.h -@@ -0,0 +1,345 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Reference implementation for complex-valued GEMM in device-side code. -+*/ -+ -+#pragma once -+ -+#include "cutlass/coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+ -+#include "cutlass/tensor_view.h" -+#include "cutlass/gemm/gemm.h" -+ -+namespace cutlass { -+namespace reference { -+namespace device { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace kernel { -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+/// -+/// Explicitly naming types needed by this template can be cumbersome, particularly for the -+/// accumulator type, so a function argument 'initial_accum' is exposed. Passing -+/// AccumulatorType(0) as the last function argument can be easier than naming all template -+/// arguments explicitly. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add, -+ int kMblock = 4, -+ int kNblock = 4 -+> -+__global__ void GemmComplex( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ ComplexTransform transform_a, -+ TensorRef tensor_b, -+ ComplexTransform transform_b, -+ ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ComputeType initial_accum, -+ int batch_count = 1, -+ int64_t batch_stride_A = 0, -+ int64_t batch_stride_B = 0, -+ int64_t batch_stride_C = 0, -+ int64_t batch_stride_D = 0) { -+ -+ static_assert( -+ LayoutA::kRank == 2 && -+ LayoutB::kRank == 2 && -+ LayoutC::kRank == 2, "Tensors must be of rank 2"); -+ -+ int const M = problem_size.m(); -+ int const N = problem_size.n(); -+ int const K = problem_size.k(); -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ -+ int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock; -+ int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock; -+ int batch_idx = blockIdx.z; -+ -+ tensor_a.add_pointer_offset(batch_idx * batch_stride_A); -+ tensor_b.add_pointer_offset(batch_idx * batch_stride_B); -+ tensor_c.add_pointer_offset(batch_idx * batch_stride_C); -+ tensor_d.add_pointer_offset(batch_idx * batch_stride_D); -+ -+ for (; batch_idx < batch_count; batch_idx += gridDim.z) { -+ -+ // Compute matrix product using blocks -+ ComputeType accum[kMblock][kNblock]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < kNblock; j++) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kMblock; i++) { -+ accum[i][j] = initial_accum; -+ } -+ } -+ -+ for (int k_block = 0; k_block < K; ++k_block) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < kNblock; j++) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kMblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ if (row < M && col < N) { -+ ElementA a = tensor_a.at(MatrixCoord(row, k_block)); -+ ElementB b = tensor_b.at(MatrixCoord(k_block, col)); -+ -+ ComputeType a_ik = ComputeType(a); -+ ComputeType b_kj = ComputeType(b); -+ -+ if (transform_a == ComplexTransform::kConjugate) { -+ a_ik = conj(a_ik); -+ } -+ -+ if (transform_b == ComplexTransform::kConjugate) { -+ b_kj = conj(b_kj); -+ } -+ -+ accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]); -+ } -+ } -+ } -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < kNblock; j++) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kMblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ MatrixCoord coord = MatrixCoord(row, col); -+ -+ if (row < M && col < N) { -+ -+ tensor_d.at(coord) = convert_op( -+ alpha * ScalarType(accum[i][j]) + -+ beta * ScalarType(tensor_c.at(coord))); -+ } -+ } -+ } -+ -+ tensor_a.add_pointer_offset(batch_stride_A * gridDim.z); -+ tensor_b.add_pointer_offset(batch_stride_B * gridDim.z); -+ tensor_c.add_pointer_offset(batch_stride_C * gridDim.z); -+ tensor_d.add_pointer_offset(batch_stride_D * gridDim.z); -+ -+ } // for (batch_idx) -+} -+ -+} // namespace kernel -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+/// -+/// Explicitly naming types needed by this template can be cumbersome, particularly for the -+/// accumulator type, so a function argument 'initial_accum' is exposed. Passing -+/// AccumulatorType(0) as the last function argument can be easier than naming all template -+/// arguments explicitly. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+void GemmComplex( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ ComplexTransform transform_a, -+ TensorRef tensor_b, -+ ComplexTransform transform_b, -+ ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ComputeType initial_accum, -+ int batch_count = 1, -+ int64_t batch_stride_A = 0, -+ int64_t batch_stride_B = 0, -+ int64_t batch_stride_C = 0, -+ int64_t batch_stride_D = 0) { -+ -+ static_assert( -+ LayoutA::kRank == 2 && -+ LayoutB::kRank == 2 && -+ LayoutC::kRank == 2, "Tensors must be of rank 2"); -+ -+ int const kMblock = 4; -+ int const kNblock = 4; -+ -+ dim3 block(16, 8); -+ dim3 grid( -+ (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock), -+ (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock), -+ batch_count % std::numeric_limits::max() -+ ); -+ -+ if (grid.y <= std::numeric_limits::max()) { -+ kernel::GemmComplex< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ScalarType, -+ ComputeType, -+ ConvertOp, -+ InnerProductOp, -+ kMblock, -+ kNblock -+ ><<< grid, block >>>( -+ problem_size, -+ alpha, -+ tensor_a, -+ transform_a, -+ tensor_b, -+ transform_b, -+ beta, -+ tensor_c, -+ tensor_d, -+ initial_accum, -+ batch_count, -+ batch_stride_A, -+ batch_stride_B, -+ batch_stride_C, -+ batch_stride_D -+ ); -+ } else { -+ // Using bigger thread tile size -+ int const kBigMblock = 4; -+ int const kBigNblock = 16; -+ -+ dim3 Bigblock(16, 8); -+ dim3 Biggrid( -+ (problem_size.m() + block.x * kBigMblock - 1) / (block.x * kBigMblock), -+ (problem_size.n() + block.y * kBigNblock - 1) / (block.y * kBigNblock), -+ batch_count % std::numeric_limits::max() -+ ); -+ -+ kernel::GemmComplex< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ScalarType, -+ ComputeType, -+ ConvertOp, -+ InnerProductOp, -+ kBigMblock, -+ kBigNblock -+ ><<< Biggrid, Bigblock >>>( -+ problem_size, -+ alpha, -+ tensor_a, -+ transform_a, -+ tensor_b, -+ transform_b, -+ beta, -+ tensor_c, -+ tensor_d, -+ initial_accum, -+ batch_count, -+ batch_stride_A, -+ batch_stride_B, -+ batch_stride_C, -+ batch_stride_D -+ ); -+ } -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+/// -+/// This assumes the accumulator type is the same type as the scalars. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType -+> -+void GemmComplex( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ ComplexTransform transform_a, -+ TensorRef tensor_b, -+ ComplexTransform transform_b, -+ ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d) { -+ -+ GemmComplex(problem_size, alpha, tensor_a, transform_a, tensor_b, transform_b, beta, tensor_c, tensor_d, ScalarType(0)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h -new file mode 100644 -index 0000000..baab696 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h -@@ -0,0 +1,311 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Reference implementation for complex-valued GEMM in device code. -+*/ -+ -+#pragma once -+ -+#include "cutlass/coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/matrix_coord.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/tensor_ref_planar_complex.h" -+ -+#include "cutlass/tensor_view.h" -+#include "cutlass/gemm/gemm.h" -+ -+namespace cutlass { -+namespace reference { -+namespace device { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+static int const kGemmPlanarComplexBlockSize = 4; -+ -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add> -+> -+__global__ void GemmPlanarComplex( -+ gemm::GemmCoord problem_size, -+ complex alpha, -+ TensorRefPlanarComplex tensor_a, -+ ComplexTransform transform_a, -+ TensorRefPlanarComplex tensor_b, -+ ComplexTransform transform_b, -+ complex beta, -+ TensorRefPlanarComplex tensor_c, -+ TensorRefPlanarComplex tensor_d, -+ complex initial_accum) { -+ -+ int const kMblock = kGemmPlanarComplexBlockSize; -+ int const kNblock = kGemmPlanarComplexBlockSize; -+ -+ using ComplexA = typename TensorRefPlanarComplex::ComplexElement; -+ using ComplexB = typename TensorRefPlanarComplex::ComplexElement; -+ using ComplexC = typename TensorRefPlanarComplex::ComplexElement; -+ -+ // Note: batch is ignored. -+ int const M = problem_size.m(); -+ int const N = problem_size.n(); -+ int const K = problem_size.k(); -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ -+ complex accum[kMblock][kNblock]; -+ -+ int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock; -+ int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < kNblock; j++) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kMblock; i++) { -+ accum[i][j] = initial_accum; -+ } -+ } -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int k_block = 0; k_block < K; ++k_block) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < kNblock; j++) { -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kMblock; i++) { -+ -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ if (row < M && col < N) { -+ -+ ComplexA a_ik = tensor_a.at(MatrixCoord(row, k_block)); -+ ComplexB b_kj = tensor_b.at(MatrixCoord(k_block, col)); -+ -+ complex a = complex{ -+ ComputeType(a_ik.real()), -+ ComputeType(a_ik.imag()) -+ }; -+ -+ complex b = complex{ -+ ComputeType(b_kj.real()), -+ ComputeType(b_kj.imag()) -+ }; -+ -+ if (transform_a == ComplexTransform::kConjugate) { -+ a = conj(a); -+ } -+ -+ if (transform_b == ComplexTransform::kConjugate) { -+ b = conj(b); -+ } -+ -+ accum[i][j] = inner_product_op(a, b, accum[i][j]); -+ } -+ } -+ } -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < kNblock; j++) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kMblock; i++) { -+ -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ MatrixCoord coord = MatrixCoord(row, col); -+ -+ if (row < M && col < N) { -+ -+ complex acc{ -+ ScalarType(accum[i][j].real()), -+ ScalarType(accum[i][j].imag()) -+ }; -+ -+ ComplexC c_ij = ComplexC(); -+ -+ if (beta.real() != ScalarType() || beta.imag() != ScalarType()) { -+ c_ij = tensor_c.at(coord); -+ } -+ -+ complex src{ -+ ScalarType(c_ij.real()), -+ ScalarType(c_ij.imag()) -+ }; -+ -+ complex result = alpha * acc + beta * src; -+ -+ ComplexC d_ij; -+ -+ d_ij.real() = convert_op(result.real()); -+ d_ij.imag() = convert_op(result.imag()); -+ -+ tensor_d.at(coord) = d_ij; -+ } -+ } -+ } -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+/// -+/// Explicitly naming types needed by this template can be cumbersome, particularly for the -+/// accumulator type, so a function argument 'initial_accum' is exposed. Passing -+/// AccumulatorType(0) as the last function argument can be easier than naming all template -+/// arguments explicitly. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add> -+> -+void GemmPlanarComplex( -+ gemm::GemmCoord problem_size, -+ complex alpha, -+ TensorRefPlanarComplex tensor_a, -+ ComplexTransform transform_a, -+ TensorRefPlanarComplex tensor_b, -+ ComplexTransform transform_b, -+ complex beta, -+ TensorRefPlanarComplex tensor_c, -+ TensorRefPlanarComplex tensor_d, -+ complex initial_accum) { -+ -+ static_assert( -+ LayoutA::kRank == 2 && -+ LayoutB::kRank == 2 && -+ LayoutC::kRank == 2, "Tensors must be of rank 2"); -+ -+ int const kMblock = kernel::kGemmPlanarComplexBlockSize; -+ int const kNblock = kernel::kGemmPlanarComplexBlockSize; -+ -+ dim3 block(16, 8); -+ -+ dim3 grid( -+ (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock), -+ (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock), -+ 1); -+ -+ kernel::GemmPlanarComplex< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ScalarType, -+ ComputeType, -+ ConvertOp, -+ InnerProductOp -+ ><<< grid, block >>>( -+ problem_size, -+ alpha, -+ tensor_a, -+ transform_a, -+ tensor_b, -+ transform_b, -+ beta, -+ tensor_c, -+ tensor_d, -+ initial_accum -+ ); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+/// -+/// This assumes the accumulator type is the same type as the scalars. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType -+> -+void GemmPlanarComplex( -+ gemm::GemmCoord problem_size, -+ complex alpha, -+ TensorRefPlanarComplex tensor_a, -+ ComplexTransform transform_a, -+ TensorRefPlanarComplex tensor_b, -+ ComplexTransform transform_b, -+ complex beta, -+ TensorRefPlanarComplex tensor_c, -+ TensorRefPlanarComplex tensor_d) { -+ -+ GemmPlanarComplex( -+ problem_size, -+ alpha, -+ tensor_a, transform_a, -+ tensor_b, transform_b, -+ beta, -+ tensor_c, -+ tensor_d, -+ complex()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace reference -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/kernel/gemm.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/kernel/gemm.h -new file mode 100644 -index 0000000..e917765 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/kernel/gemm.h -@@ -0,0 +1,162 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Reference implementation for GEMM in host-side code. -+*/ -+ -+#pragma once -+ -+#include "cutlass/coord.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/util/reference/device/thread/gemm.h" -+ -+namespace cutlass { -+namespace reference { -+namespace device { -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+template < -+ typename TensorRefA, -+ typename TensorRefB, -+ typename TensorRefC, -+ typename ScalarType, -+ typename AccumulatorType, -+ typename OutputTile, -+ typename InnerProductOp, -+ typename ConvertOp -+> -+__global__ void Gemm( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRefA tensor_a, -+ TensorRefB tensor_b, -+ ScalarType beta, -+ TensorRefC tensor_c, -+ TensorRefC tensor_d, -+ AccumulatorType initial_accum) { -+ -+ // Map each thread to a unique tile of the output matrix -+ MatrixCoord output_coord( -+ MatrixCoord::Index((threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kRow), -+ MatrixCoord::Index((threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kColumn) -+ ); -+ -+ // Compute the general matrix product -+ thread::Gemm< -+ TensorRefA, -+ TensorRefB, -+ TensorRefC, -+ ScalarType, -+ AccumulatorType, -+ OutputTile, -+ InnerProductOp, -+ ConvertOp -+ > gemm(initial_accum); -+ -+ gemm.multiply_add( -+ problem_size, -+ tensor_a, -+ tensor_b, -+ output_coord); -+ -+ gemm.epilogue(problem_size, alpha, beta, tensor_c, tensor_d, output_coord); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+template < -+ typename TensorRefCollectionA, -+ typename TensorRefCollectionB, -+ typename TensorRefCollectionC, -+ typename ScalarType, -+ typename AccumulatorType, -+ typename OutputTile, -+ typename InnerProductOp, -+ typename ConvertOp -+> -+__global__ void BatchedGemm( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRefCollectionA tensor_collection_a, -+ TensorRefCollectionB tensor_collection_b, -+ ScalarType beta, -+ TensorRefCollectionC tensor_collection_c, -+ AccumulatorType initial_accum) { -+ -+ // Obtain batch ID -+ int batch_id = blockIdx.z; -+ -+ // Dereference based on batch_id -+ typename TensorRefCollectionA::TensorRef tensor_a = tensor_collection_a.at(batch_id); -+ typename TensorRefCollectionB::TensorRef tensor_b = tensor_collection_b.at(batch_id); -+ typename TensorRefCollectionC::TensorRef tensor_c = tensor_collection_c.at(batch_id); -+ -+ // Map each thread to a unique tile of the output matrix -+ MatrixCoord output_coord( -+ (threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kColumn, -+ (threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kRow -+ ); -+ -+ // Compute the general matrix product -+ thread::Gemm< -+ typename TensorRefCollectionA::TensorRef, -+ typename TensorRefCollectionB::TensorRef, -+ typename TensorRefCollectionC::TensorRef, -+ ScalarType, -+ AccumulatorType, -+ OutputTile, -+ InnerProductOp, -+ ConvertOp -+ > gemm(initial_accum); -+ -+ gemm.multiply_add( -+ problem_size, -+ tensor_a, -+ tensor_b, -+ output_coord); -+ -+ gemm.epilogue(problem_size, alpha, beta, tensor_c, output_coord); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace device -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_elementwise.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_elementwise.h -new file mode 100644 -index 0000000..4850b98 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_elementwise.h -@@ -0,0 +1,168 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include -+ -+#include "cutlass/cutlass.h" -+ -+namespace cutlass { -+namespace reference { -+namespace device { -+namespace kernel { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Kernel to initialize tensor to uniform random distribution -+template -+__global__ void TensorInitializeUniform( -+ Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) { -+ __shared__ curandState_t rng_state[1024]; -+ -+ uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x; -+ -+ curand_init(seed, gtid, 0, &rng_state[threadIdx.x]); -+ -+ int c_idx = blockIdx.x * blockDim.x + threadIdx.x; -+ int s_idx = blockIdx.y * blockDim.x; -+ -+ tensor += s_idx * ldm + c_idx; -+ -+ for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) { -+ if (s_idx < dim_strided && c_idx < dim_contiguous) { -+ double range = dist.uniform.max - dist.uniform.min; -+ -+ double rnd = curand_uniform(&rng_state[threadIdx.x]); -+ -+ rnd = dist.uniform.min + range * rnd; -+ -+ // Random values are cast to integer after scaling by a power of two to facilitate error -+ // testing -+ if (dist.int_scale >= 0) { -+ rnd = double(int(rnd * double(1 << dist.int_scale))); -+ *tensor = T(rnd / double(1 << dist.int_scale)); -+ } else { -+ *tensor = T(rnd); -+ } -+ -+ tensor += ldm; -+ } -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Kernel to initialize tensor to uniform distribution -+template -+__global__ void TensorInitializeGaussian( -+ Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) { -+ __shared__ curandState_t rng_state[1024]; -+ -+ uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x; -+ -+ curand_init(seed, gtid, 0, &rng_state[threadIdx.x]); -+ -+ int c_idx = blockIdx.x * blockDim.x + threadIdx.x; -+ int s_idx = blockIdx.y * blockDim.x; -+ -+ tensor += s_idx * ldm + c_idx; -+ -+ for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) { -+ if (s_idx < dim_strided && c_idx < dim_contiguous) { -+ // Random values are cast to integer after scaling by a power of two to facilitate error -+ // testing -+ -+ double rnd = curand_normal(&rng_state[threadIdx.x]); -+ -+ rnd = dist.gaussian.mean + dist.gaussian.stddev * rnd; -+ -+ if (dist.int_scale >= 0) { -+ rnd = double(int(rnd * double(1 << dist.int_scale))); -+ *tensor = T(rnd / double(1 << dist.int_scale)); -+ } else { -+ *tensor = T(rnd); -+ } -+ } -+ } -+} -+ -+/// Kernel to initialize tensor to an identity matrix -+template -+__global__ void TensorInitializeLinear( -+ Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) { -+ __shared__ curandState_t rng_state[1024]; -+ -+ uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x; -+ -+ curand_init(seed, gtid, 0, &rng_state[threadIdx.x]); -+ -+ int c_idx = blockIdx.x * blockDim.x + threadIdx.x; -+ int s_idx = blockIdx.y * blockDim.x; -+ -+ tensor += s_idx * ldm + c_idx; -+ -+ for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) { -+ if (s_idx < dim_strided && c_idx < dim_contiguous) { -+ *tensor = -+ dist.linear.offset + dist.linear.delta_row * c_idx + dist.linear.delta_column * s_idx; -+ } -+ } -+} -+ -+/// Kernel to initialize tensor to an identity matrix -+template -+__global__ void TensorInitializeIdentity( -+ Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) { -+ __shared__ curandState_t rng_state[1024]; -+ -+ uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x; -+ -+ curand_init(seed, gtid, 0, &rng_state[threadIdx.x]); -+ -+ int c_idx = blockIdx.x * blockDim.x + threadIdx.x; -+ int s_idx = blockIdx.y * blockDim.x; -+ -+ tensor += s_idx * ldm + c_idx; -+ -+ for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) { -+ if (s_idx < dim_strided && c_idx < dim_contiguous) { -+ *tensor = (c_idx == s_idx ? T(1) : T(0)); -+ } -+ } -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace device -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_foreach.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_foreach.h -new file mode 100644 -index 0000000..ea5359f ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_foreach.h -@@ -0,0 +1,159 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+#pragma once -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/coord.h" -+#include "cutlass/subbyte_reference.h" -+#include "cutlass/fast_math.h" -+ -+namespace cutlass { -+namespace reference { -+namespace device { -+namespace kernel { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines several helpers -+namespace detail { -+ -+/// Helper to perform for-each operation -+template -+struct TensorForEachHelper { -+ -+ /// Constructor for general rank -+ __inline__ __device__ -+ TensorForEachHelper(Func &func, Coord const &size, Coord &coord, int64_t index) { -+ -+ int64_t product = 1; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = Rank - RankRemaining; i < Rank; ++i) { -+ product *= size[i]; -+ } -+ -+ coord[Rank - 1 - RankRemaining] = index / product; -+ int64_t remaining = index % product; -+ -+ TensorForEachHelper(func, size, coord, remaining); -+ } -+}; -+ -+/// Helper to perform for-each operation -+template -+struct TensorForEachHelper { -+ -+ /// Constructor for fastest chaning rank -+ __inline__ __device__ -+ TensorForEachHelper(Func &func, Coord const &size, Coord &coord, int64_t index) { -+ -+ coord[Rank - 1] = index; -+ -+ if (coord < size) { -+ func(coord); -+ } -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Kernel calls a functor for each element in a tensor's index space -+template -+__global__ void TensorForEach(Coord size, Params params = Params()) { -+ -+ Func func(params); -+ -+ int64_t index = threadIdx.x + blockIdx.x * blockDim.x; -+ int64_t max_index = 1; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Rank; ++i) { -+ max_index *= size[i]; -+ } -+ -+ CUTLASS_PRAGMA_NO_UNROLL -+ while (index < max_index) { -+ Coord coord; -+ -+ detail::TensorForEachHelper(func, size, coord, index); -+ index += blockDim.x * gridDim.x; -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Kernel calls a functor for each element along a tensor's diagonal -+template -+__global__ void TensorDiagonalForEach(Coord size, Params params, int start, int end) { -+ -+ Func func(params); -+ -+ int64_t index = threadIdx.x + blockIdx.x * blockDim.x + start; -+ -+ if (index < end) { -+ Coord coord; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Rank; ++i) { -+ coord[i] = index; -+ } -+ -+ func(coord); -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+__global__ void BlockForEach( -+ Element *ptr, -+ size_t capacity, -+ typename Func::Params params) { -+ -+ Func func(params); -+ -+ size_t index = threadIdx.x + blockIdx.x * blockDim.x; -+ -+ for (; index < capacity; index += blockDim.x * gridDim.x) { -+ ReferenceFactory::get(ptr, index) = func(); -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace kernel -+} // namespace device -+} // namespace reference -+} // namespace cutlass -+ -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/rank_2k_complex.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/rank_2k_complex.h -new file mode 100644 -index 0000000..357ca3c ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/rank_2k_complex.h -@@ -0,0 +1,355 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Reference implementation for complex-valued GEMM in device-side code. -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/complex.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/gemm/gemm.h" -+ -+namespace cutlass { -+namespace reference { -+namespace device { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace kernel { -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+/// -+/// Explicitly naming types needed by this template can be cumbersome, particularly for the -+/// accumulator type, so a function argument 'initial_accum' is exposed. Passing -+/// AccumulatorType(0) as the last function argument can be easier than naming all template -+/// arguments explicitly. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add, -+ int kMblock = 4, -+ int kNblock = 4 -+> -+__global__ void Rank2KComplex( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ ComplexTransform transform_a, -+ TensorRef tensor_b, -+ ComplexTransform transform_b, -+ ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ComputeType initial_accum, -+ FillMode fill_mode_c, -+ BlasMode blas_mode, -+ int batch_count = 1, -+ int64_t batch_stride_A = 0, -+ int64_t batch_stride_B = 0, -+ int64_t batch_stride_C = 0, -+ int64_t batch_stride_D = 0) { -+ -+ static_assert( -+ LayoutA::kRank == 2 && -+ LayoutB::kRank == 2 && -+ LayoutC::kRank == 2, "Tensors must be of rank 2"); -+ -+ int const M = problem_size.m(); -+ int const N = problem_size.n(); -+ int const K = problem_size.k(); -+ -+ assert(M=N); -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ -+ int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock; -+ int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock; -+ int batch_idx = blockIdx.z; -+ -+ tensor_a.add_pointer_offset(batch_idx * batch_stride_A); -+ tensor_b.add_pointer_offset(batch_idx * batch_stride_B); -+ tensor_c.add_pointer_offset(batch_idx * batch_stride_C); -+ tensor_d.add_pointer_offset(batch_idx * batch_stride_D); -+ -+ for (; batch_idx < batch_count; batch_idx += gridDim.z) { -+ -+ // Compute matrix product using blocks -+ ComputeType accum[kMblock][kNblock]; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < kNblock; j++) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kMblock; i++) { -+ accum[i][j] = initial_accum; -+ } -+ } -+ -+ for (int k_block = 0; k_block < K; ++k_block) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < kNblock; j++) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kMblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ if (row < M && col < N && -+ ( (fill_mode_c == FillMode::kLower && row >= col) || -+ (fill_mode_c == FillMode::kUpper && row <= col) ) -+ ) { -+ -+ // A x B^T (Symmetric) or A x B^H (Hermitian) -+ // complex conjugation on operandB (b_t) is function of blas3 computation -+ ElementA a = tensor_a.at(MatrixCoord(row, k_block)); -+ ElementB b_t = (blas_mode == BlasMode::kHermitian) ? -+ conj(tensor_b.at(MatrixCoord(col, k_block))) : -+ tensor_b.at(MatrixCoord(col, k_block)); -+ -+ ComputeType a_ik = ComputeType(a); -+ ComputeType b_jk = ComputeType(b_t); -+ -+ // complex conjugation is a function of operand layouts -+ if (transform_a == ComplexTransform::kConjugate) { -+ a_ik = conj(a_ik); -+ } -+ // complex conjugation is a function of operand layouts -+ if (transform_b == ComplexTransform::kConjugate) { -+ b_jk = conj(b_jk); -+ } -+ -+ accum[i][j] = inner_product_op(a_ik, b_jk, accum[i][j]); -+ -+ // B x A^T (Symmetric) or B x A^H (Hermitian) -+ // complex conjugation on operandB (a_t) is function of blas3 computation -+ ElementB b = tensor_b.at(MatrixCoord(row, k_block)); -+ ElementA a_t = (blas_mode == BlasMode::kHermitian) ? -+ conj(tensor_a.at(MatrixCoord(col, k_block))): -+ tensor_a.at(MatrixCoord(col, k_block)); -+ -+ ComputeType b_ik = ComputeType(b); -+ ComputeType a_jk = ComputeType(a_t); -+ -+ // complex conjugation here is a function of operand layouts -+ if (transform_b == ComplexTransform::kConjugate) { -+ b_ik = conj(b_ik); -+ } -+ // complex conjugation here is a function of operand layouts -+ if (transform_a == ComplexTransform::kConjugate) { -+ a_jk = conj(a_jk); -+ } -+ -+ accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]); -+ } -+ } -+ } -+ } -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < kNblock; j++) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < kMblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ MatrixCoord coord = MatrixCoord(row, col); -+ -+ if (row < M && col < N && -+ ((fill_mode_c == FillMode::kLower && row >= col) || -+ (fill_mode_c == FillMode::kUpper && row <= col)) -+ ) { -+ -+ ScalarType c = tensor_c.at(coord); -+ // The imaginary parts of the diagonal elements of -+ // a complex data type are assumed and set to zero -+ if (blas_mode == BlasMode::kHermitian) { -+ c = (row == col) ? real(c) : c; -+ } -+ -+ tensor_d.at(coord) = convert_op( -+ alpha * ScalarType(accum[i][j]) + -+ beta * c); -+ } -+ } -+ } -+ -+ tensor_a.add_pointer_offset(batch_stride_A * gridDim.z); -+ tensor_b.add_pointer_offset(batch_stride_B * gridDim.z); -+ tensor_c.add_pointer_offset(batch_stride_C * gridDim.z); -+ tensor_d.add_pointer_offset(batch_stride_D * gridDim.z); -+ -+ } // for (batch_idx) -+} -+ -+} // namespace kernel -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+/// -+/// Explicitly naming types needed by this template can be cumbersome, particularly for the -+/// accumulator type, so a function argument 'initial_accum' is exposed. Passing -+/// AccumulatorType(0) as the last function argument can be easier than naming all template -+/// arguments explicitly. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+void Rank2KComplex( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ ComplexTransform transform_a, -+ TensorRef tensor_b, -+ ComplexTransform transform_b, -+ ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ComputeType initial_accum, -+ FillMode fill_mode_c, -+ BlasMode blas_mode, -+ int batch_count = 1, -+ int64_t batch_stride_A = 0, -+ int64_t batch_stride_B = 0, -+ int64_t batch_stride_C = 0, -+ int64_t batch_stride_D = 0) { -+ -+ static_assert( -+ LayoutA::kRank == 2 && -+ LayoutB::kRank == 2 && -+ LayoutC::kRank == 2, "Tensors must be of rank 2"); -+ -+ int const kMblock = 4; -+ int const kNblock = 4; -+ -+ dim3 block(16, 8); -+ dim3 grid( -+ (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock), -+ (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock), -+ batch_count % std::numeric_limits::max() -+ ); -+ -+ kernel::Rank2KComplex< -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ ElementC, -+ LayoutC, -+ ScalarType, -+ ComputeType, -+ ConvertOp, -+ InnerProductOp, -+ kMblock, -+ kNblock -+ ><<< grid, block >>>( -+ problem_size, -+ alpha, -+ tensor_a, -+ transform_a, -+ tensor_b, -+ transform_b, -+ beta, -+ tensor_c, -+ tensor_d, -+ initial_accum, -+ fill_mode_c, -+ blas_mode, -+ batch_count, -+ batch_stride_A, -+ batch_stride_B, -+ batch_stride_C, -+ batch_stride_D -+ ); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+/// -+/// This assumes the accumulator type is the same type as the scalars. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType -+> -+void Rank2KComplex( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ ComplexTransform transform_a, -+ TensorRef tensor_b, -+ ComplexTransform transform_b, -+ ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ FillMode fill_mode_c, -+ BlasMode blas_mode) { -+ -+ Rank2KComplex( -+ problem_size, alpha, -+ tensor_a, transform_a, -+ tensor_b, transform_b, -+ beta, tensor_c, tensor_d, -+ ScalarType(0), -+ fill_mode_c, -+ blas_mode); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/tensor_compare.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/tensor_compare.h -new file mode 100644 -index 0000000..e29ad69 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/tensor_compare.h -@@ -0,0 +1,246 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines host-side elementwise operations on TensorView. -+*/ -+ -+#pragma once -+// Standard Library includes -+#include -+ -+// Cutlass includes -+#include "cutlass/cutlass.h" -+#include "cutlass/relatively_equal.h" -+ -+#include "cutlass/util/distribution.h" -+ -+#include "tensor_foreach.h" -+ -+namespace cutlass { -+namespace reference { -+namespace device { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace kernel { -+ -+template -+__global__ void BlockCompareEqual( -+ int *equal, -+ Element const *ptr_A, -+ Element const *ptr_B, -+ size_t capacity) { -+ -+ size_t idx = threadIdx.x + blockDim.x * blockIdx.x; -+ -+ for (; idx < capacity; idx += gridDim.x * blockDim.x) { -+ -+ Element a = cutlass::ReferenceFactory::get(ptr_A, idx); -+ Element b = cutlass::ReferenceFactory::get(ptr_B, idx); -+ -+ if (a != b) { -+ *equal = 0; -+ -+ return; -+ } -+ } -+} -+ -+template -+__global__ void BlockCompareRelativelyEqual( -+ int *equal, -+ Element const *ptr_A, -+ Element const *ptr_B, -+ size_t capacity, -+ Element epsilon, -+ Element nonzero_floor) { -+ -+ size_t idx = threadIdx.x + blockDim.x * blockIdx.x; -+ -+ for (; idx < capacity; idx += gridDim.x * blockDim.x) { -+ -+ Element a = cutlass::ReferenceFactory::get(ptr_A, idx); -+ Element b = cutlass::ReferenceFactory::get(ptr_B, idx); -+ -+ if (!relatively_equal(a, b, epsilon, nonzero_floor)) { -+ *equal = 0; -+ return; -+ } -+ } -+} -+ -+} // namespace kernel -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Performs a bit-level equality check between two blocks -+template -+bool BlockCompareEqual( -+ Element const *ptr_A, -+ Element const *ptr_B, -+ size_t capacity, -+ int grid_size = 0, -+ int block_size = 0) { -+ -+ int equal_flag = 1; -+ int *device_equal_flag = nullptr; -+ -+ if (cudaMalloc((void **)&device_equal_flag, sizeof(int)) != cudaSuccess) { -+ throw std::runtime_error("Failed to allocate device flag."); -+ } -+ -+ if (cudaMemcpy( -+ device_equal_flag, -+ &equal_flag, -+ sizeof(int), -+ cudaMemcpyHostToDevice) != cudaSuccess) { -+ -+ throw std::runtime_error("Failed to copy equality flag to device."); -+ } -+ -+ if (!grid_size || !block_size) { -+ -+ // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API -+ cudaError_t result = cudaOccupancyMaxPotentialBlockSize( -+ &grid_size, -+ &block_size, -+ reinterpret_cast(kernel::BlockCompareEqual)); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("Failed to query occupancy."); -+ } -+ -+ // Limit block size. This has the effect of increasing the number of items processed by a -+ // single thread and reduces the impact of initialization overhead. -+ block_size = (block_size < 128 ? block_size : 128); -+ } -+ -+ dim3 grid(grid_size, 1, 1); -+ dim3 block(block_size, 1, 1); -+ -+ kernel::BlockCompareEqual<<< grid, block >>>(device_equal_flag, ptr_A, ptr_B, capacity); -+ -+ if (cudaMemcpy( -+ &equal_flag, -+ device_equal_flag, -+ sizeof(int), -+ cudaMemcpyDeviceToHost) != cudaSuccess) { -+ -+ cudaFree(device_equal_flag); -+ -+ throw std::runtime_error("Failed to copy equality flag from device."); -+ } -+ -+ cudaFree(device_equal_flag); -+ -+ return equal_flag; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Performs a bit-level equality check between two blocks -+template -+bool BlockCompareRelativelyEqual( -+ Element const *ptr_A, -+ Element const *ptr_B, -+ size_t capacity, -+ Element epsilon, -+ Element nonzero_floor, -+ int grid_size = 0, -+ int block_size = 0) { -+ -+ int equal_flag = 1; -+ int *device_equal_flag = nullptr; -+ -+ if (cudaMalloc((void **)&device_equal_flag, sizeof(int)) != cudaSuccess) { -+ throw std::runtime_error("Failed to allocate device flag."); -+ } -+ -+ if (cudaMemcpy( -+ device_equal_flag, -+ &equal_flag, -+ sizeof(int), -+ cudaMemcpyHostToDevice) != cudaSuccess) { -+ -+ throw std::runtime_error("Failed to copy equality flag to device."); -+ } -+ -+ if (!grid_size || !block_size) { -+ -+ // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API -+ cudaError_t result = cudaOccupancyMaxPotentialBlockSize( -+ &grid_size, -+ &block_size, -+ reinterpret_cast(kernel::BlockCompareRelativelyEqual)); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("Failed to query occupancy."); -+ } -+ -+ // Limit block size. This has the effect of increasing the number of items processed by a -+ // single thread and reduces the impact of initialization overhead. -+ block_size = (block_size < 128 ? block_size : 128); -+ } -+ -+ dim3 grid(grid_size, 1, 1); -+ dim3 block(block_size, 1, 1); -+ -+ kernel::BlockCompareRelativelyEqual<<< grid, block >>>( -+ device_equal_flag, -+ ptr_A, -+ ptr_B, -+ capacity, -+ epsilon, -+ nonzero_floor -+ ); -+ -+ if (cudaMemcpy( -+ &equal_flag, -+ device_equal_flag, -+ sizeof(int), -+ cudaMemcpyDeviceToHost) != cudaSuccess) { -+ -+ cudaFree(device_equal_flag); -+ -+ throw std::runtime_error("Failed to copy equality flag from device."); -+ } -+ -+ cudaFree(device_equal_flag); -+ -+ return equal_flag; -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // device -+} // reference -+} // cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/tensor_fill.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/tensor_fill.h -new file mode 100644 -index 0000000..8568e47 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/tensor_fill.h -@@ -0,0 +1,1898 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines device-side elementwise operations on TensorView. Note, the operations defined -+ in this header are not specialized for any particular data layout and are therefore not -+ intended to offer the best possible performance. Rather, they are intended to be generic -+ reference implementations to support the CUTLASS unit tests. -+*/ -+ -+#pragma once -+ -+#if !defined(__CUDACC_RTC__) -+ -+// Standard Library includes -+#include -+#include -+#include -+#include -+#include -+ -+#endif -+ -+// CUDA includes -+#include -+ -+// Cutlass includes -+#include "cutlass/cutlass.h" -+#include "cutlass/array.h" -+#include "cutlass/complex.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/blas3.h" -+ -+#include "cutlass/util/reference/device/tensor_foreach.h" -+#include "cutlass/util/distribution.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace reference { -+namespace device { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template -+CUTLASS_DEVICE -+FloatType random_normal_float(curandState_t *state) { -+ return curand_normal(state); -+} -+ -+template <> -+CUTLASS_DEVICE -+double random_normal_float(curandState_t *state) { -+ return curand_normal_double(state); -+} -+ -+template -+CUTLASS_DEVICE -+FloatType random_uniform_float(curandState_t *state) { -+ return curand_uniform(state); -+} -+ -+template <> -+CUTLASS_DEVICE -+double random_uniform_float(curandState_t *state) { -+ return curand_uniform_double(state); -+} -+ -+template -+struct RandomGaussianFunc { -+ -+ using FloatType = typename std::conditional<(sizeof(Element) > 4), double, float>::type; -+ using IntType = typename std::conditional<(sizeof(Element) > 4), int64_t, int>::type; -+ -+ /// Parameters structure -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ uint64_t seed; -+ FloatType mean; -+ FloatType stddev; -+ int int_scale; -+ FloatType float_scale_up; -+ FloatType float_scale_down; -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of Gaussian RNG functor. -+ Params( -+ uint64_t seed_ = 0, -+ Element mean_ = 0, -+ Element stddev_ = 1, -+ int int_scale_ = -1 -+ ): -+ seed(seed_), -+ mean(static_cast(mean_)), -+ stddev(static_cast(stddev_)), -+ int_scale(int_scale_) { -+ -+ float_scale_up = FloatType(IntType(1) << int_scale); -+ float_scale_up += FloatType(0.5) * float_scale_up; -+ float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); -+ } -+ }; -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters object -+ Params params; -+ -+ /// RNG state object -+ curandState_t rng_state; -+ -+ // -+ // Methods -+ // -+ -+ /// Device-side initialization of RNG -+ CUTLASS_DEVICE -+ RandomGaussianFunc(Params const ¶ms): params(params) { -+ -+ uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; -+ -+ curand_init(params.seed, gtid, 0, &rng_state); -+ } -+ -+ /// Compute random value and update RNG state -+ CUTLASS_DEVICE -+ Element operator()() { -+ -+ FloatType rnd = random_normal_float(&rng_state); -+ rnd = params.mean + params.stddev * rnd; -+ -+ Element result; -+ if (params.int_scale >= 0) { -+ rnd = FloatType(IntType(rnd * params.float_scale_up)); -+ result = Element(rnd * params.float_scale_down); -+ } -+ else { -+ result = Element(rnd); -+ } -+ -+ return result; -+ } -+}; -+ -+ -+template -+struct RandomGaussianFunc> { -+ -+ using Element = complex; -+ using FloatType = typename std::conditional<(sizeof(Real) > 4), double, float>::type; -+ using IntType = typename std::conditional<(sizeof(Real) > 4), int64_t, int>::type; -+ -+ /// Parameters structure -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ uint64_t seed; -+ FloatType mean; -+ FloatType stddev; -+ int int_scale; -+ FloatType float_scale_up; -+ FloatType float_scale_down; -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of Gaussian RNG functor. -+ Params( -+ uint64_t seed_ = 0, -+ Real mean_ = 0, -+ Real stddev_ = 1, -+ int int_scale_ = -1 -+ ): -+ seed(seed_), -+ mean(static_cast(mean_)), -+ stddev(static_cast(stddev_)), -+ int_scale(int_scale_) { -+ -+ float_scale_up = FloatType(IntType(1) << int_scale); -+ float_scale_up += FloatType(0.5) * float_scale_up; -+ float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); -+ } -+ }; -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters object -+ Params params; -+ -+ /// RNG state object -+ curandState_t rng_state; -+ -+ // -+ // Methods -+ // -+ -+ /// Device-side initialization of RNG -+ CUTLASS_DEVICE -+ RandomGaussianFunc(Params const ¶ms): params(params) { -+ -+ uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; -+ -+ curand_init(params.seed, gtid, 0, &rng_state); -+ } -+ -+ /// Compute random value and update RNG state -+ CUTLASS_DEVICE -+ Element operator()() { -+ -+ FloatType rnd_r = random_normal_float(&rng_state); -+ FloatType rnd_i = random_normal_float(&rng_state); -+ rnd_r = params.mean + params.stddev * rnd_r; -+ rnd_i = params.mean + params.stddev * rnd_i; -+ -+ Element result; -+ if (params.int_scale >= 0) { -+ rnd_r = FloatType(IntType(rnd_r * params.float_scale_up)); -+ rnd_i = FloatType(IntType(rnd_i * params.float_scale_down)); -+ -+ result = { -+ Real(rnd_r * params.float_scale_down), -+ Real(rnd_i * params.float_scale_down) -+ }; -+ } -+ else { -+ result = Element(Real(rnd_r), Real(rnd_i)); -+ } -+ -+ return result; -+ } -+}; -+ -+/// Computes a random Gaussian distribution -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorFillRandomGaussianFunc { -+ -+ /// View type -+ using TensorView = TensorView; -+ -+ /// Scalar type -+ typedef typename TensorView::Element T; -+ -+ /// Coordinate in tensor's index space -+ typedef typename TensorView::TensorCoord TensorCoord; -+ -+ using RandomFunc = RandomGaussianFunc; -+ -+ /// Parameters structure -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ typename RandomFunc::Params random; -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of Gaussian RNG functor. -+ Params( -+ TensorView view_ = TensorView(), -+ typename RandomFunc::Params random_ = typename RandomFunc::Params() -+ ): -+ view(view_), random(random_) { -+ -+ } -+ }; -+ -+ // -+ // Data members -+ // -+ -+ Params params; -+ RandomFunc random; -+ -+ // -+ // Methods -+ // -+ -+ /// Device-side initialization of RNG -+ CUTLASS_DEVICE -+ TensorFillRandomGaussianFunc(Params const ¶ms): params(params), random(params.random) { -+ -+ } -+ -+ /// Compute random value and update RNG state -+ CUTLASS_DEVICE -+ void operator()(TensorCoord const &coord) { -+ -+ params.view.at(coord) = random(); -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor with random values with a Gaussian distribution. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillRandomGaussian( -+ TensorView view, ///< destination tensor -+ uint64_t seed, ///< seed for RNG -+ Element mean = Element(0), ///< Gaussian distribution's mean -+ Element stddev = Element(1), ///< Gaussian distribution's standard deviation -+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that -+ /// are not truncated to zero. Permits reducing precision of -+ /// data. -+ -+ using RandomFunc = detail::RandomGaussianFunc; -+ using Func = detail::TensorFillRandomGaussianFunc; -+ using Params = typename Func::Params; -+ -+ TensorForEach( -+ view.extent(), -+ Params(view, typename RandomFunc::Params(seed, mean, stddev, bits)) -+ ); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor with random values with a Gaussian distribution. -+template ///< Element type -+void BlockFillRandomGaussian( -+ Element *ptr, -+ size_t capacity, -+ uint64_t seed, ///< seed for RNG -+ typename RealType::Type mean, ///< Gaussian distribution's mean -+ typename RealType::Type stddev, ///< Gaussian distribution's standard deviation -+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that -+ /// are not truncated to zero. Permits reducing precision of -+ /// data. -+ -+ using RandomFunc = detail::RandomGaussianFunc; -+ -+ typename RandomFunc::Params params(seed, mean, stddev, bits); -+ -+ BlockForEach(ptr, capacity, params); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+/// Computes a random Gaussian distribution -+template ///< Element type -+struct RandomUniformFunc { -+ -+ using FloatType = typename std::conditional< -+ (sizeof(Element) > 4), -+ double, -+ float>::type; -+ -+ using IntType = typename std::conditional< -+ (sizeof(Element) > 4), -+ int64_t, -+ int>::type; -+ -+ /// Parameters structure -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ uint64_t seed; -+ FloatType range; -+ FloatType max; -+ int int_scale; -+ FloatType float_scale_up; -+ FloatType float_scale_down; -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of Gaussian RNG functor. -+ Params( -+ uint64_t seed_ = 0, -+ Element max_ = 1, -+ Element min = 0, -+ int int_scale_ = -1 -+ ): -+ seed(seed_), -+ range(static_cast(max_ - min)), -+ max(static_cast(max_)), -+ int_scale(int_scale_) { -+ -+ float_scale_up = FloatType(IntType(1) << int_scale); -+ float_scale_up += FloatType(0.5) * float_scale_up; -+ float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); -+ } -+ }; -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters object -+ Params params; -+ -+ /// RNG state object -+ curandState_t rng_state; -+ -+ // -+ // Methods -+ // -+ -+ /// Device-side initialization of RNG -+ CUTLASS_DEVICE -+ RandomUniformFunc(Params const ¶ms): params(params) { -+ -+ uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; -+ -+ curand_init(params.seed, gtid, 0, &rng_state); -+ } -+ -+ /// Compute random value and update RNG state -+ CUTLASS_DEVICE -+ Element operator()() { -+ -+ FloatType rnd = random_uniform_float(&rng_state); -+ rnd = params.max - params.range * rnd; -+ -+ // Random values are cast to integer after scaling by a power of two to facilitate error -+ // testing -+ Element result; -+ -+ if (params.int_scale >= 0) { -+ rnd = FloatType(IntType(rnd * params.float_scale_up)); -+ result = Element(rnd * params.float_scale_down); -+ } -+ else { -+ result = Element(rnd); -+ } -+ -+ return result; -+ } -+}; -+ -+/// Computes a random Gaussian distribution -+template -+struct RandomUniformFunc> { -+ -+ using Element = complex; -+ -+ using FloatType = typename std::conditional< -+ (sizeof(Real) > 4), -+ double, -+ float>::type; -+ -+ using IntType = typename std::conditional< -+ (sizeof(Real) > 4), -+ int64_t, -+ int>::type; -+ -+ /// Parameters structure -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ uint64_t seed; -+ FloatType range; -+ FloatType min; -+ int int_scale; -+ FloatType float_scale_up; -+ FloatType float_scale_down; -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of Gaussian RNG functor. -+ Params( -+ uint64_t seed_ = 0, -+ FloatType max = 1, -+ FloatType min_ = 0, -+ int int_scale_ = -1 -+ ): -+ seed(seed_), -+ range(static_cast(max - min_)), -+ min(static_cast(min_)), -+ int_scale(int_scale_) { -+ -+ float_scale_up = FloatType(IntType(1) << int_scale); -+ float_scale_up += FloatType(0.5) * float_scale_up; -+ float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); -+ } -+ }; -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters object -+ Params params; -+ -+ /// RNG state object -+ curandState_t rng_state; -+ -+ // -+ // Methods -+ // -+ -+ /// Device-side initialization of RNG -+ CUTLASS_DEVICE -+ RandomUniformFunc(Params const ¶ms): params(params) { -+ -+ uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; -+ -+ curand_init(params.seed, gtid, 0, &rng_state); -+ } -+ -+ /// Compute random value and update RNG state -+ CUTLASS_DEVICE -+ Element operator()() { -+ -+ FloatType rnd_r = random_uniform_float(&rng_state); -+ FloatType rnd_i = random_uniform_float(&rng_state); -+ -+ rnd_r = params.min + params.range * rnd_r; -+ rnd_i = params.min + params.range * rnd_i; -+ -+ // Random values are cast to integer after scaling by a power of two to facilitate error -+ // testing -+ Element result; -+ -+ if (params.int_scale >= 0) { -+ rnd_r = FloatType(IntType(rnd_r * params.float_scale_up)); -+ rnd_i = FloatType(IntType(rnd_i * params.float_scale_up)); -+ -+ result = { -+ Real(rnd_r * params.float_scale_down), -+ Real(rnd_i * params.float_scale_down) -+ }; -+ } -+ else { -+ result = Element(Real(rnd_r), Real(rnd_i)); -+ } -+ -+ return result; -+ } -+}; -+ -+/// Computes a random Gaussian distribution -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorFillRandomUniformFunc { -+ -+ /// View type -+ using TensorView = TensorView; -+ -+ /// Scalar type -+ typedef typename TensorView::Element T; -+ -+ /// Coordinate in tensor's index space -+ typedef typename TensorView::TensorCoord TensorCoord; -+ -+ using RandomFunc = RandomUniformFunc; -+ -+ /// Parameters structure -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ typename RandomFunc::Params random; -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of Gaussian RNG functor. -+ Params( -+ TensorView view_ = TensorView(), -+ typename RandomFunc::Params random_ = RandomFunc::Params() -+ ): -+ view(view_), random(random_) { -+ -+ } -+ }; -+ -+ // -+ // Data members -+ // -+ -+ Params params; -+ RandomFunc random; -+ -+ // -+ // Methods -+ // -+ -+ /// Device-side initialization of RNG -+ CUTLASS_DEVICE -+ TensorFillRandomUniformFunc(Params const ¶ms): params(params), random(params.random) { -+ } -+ -+ /// Compute random value and update RNG state -+ CUTLASS_DEVICE -+ void operator()(TensorCoord const &coord) { -+ -+ params.view.at(coord) = random(); -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor with random values with a uniform random distribution. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillRandomUniform( -+ TensorView view, ///< destination tensor -+ uint64_t seed, ///< seed for RNG -+ Element max = Element(1), ///< upper bound of distribution -+ Element min = Element(0), ///< lower bound for distribution -+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that -+ /// are not truncated to zero. Permits reducing precision of -+ /// data. -+ -+ using RandomFunc = detail::RandomUniformFunc; -+ using Func = detail::TensorFillRandomUniformFunc; -+ using Params = typename Func::Params; -+ -+ typename RandomFunc::Params random(seed, max, min, bits); -+ -+ TensorForEach( -+ view.extent(), -+ Params(view, random) -+ ); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor with random values with a uniform random distribution. -+template -+void BlockFillRandomUniform( -+ Element *ptr, -+ size_t capacity, -+ uint64_t seed, ///< seed for RNG -+ typename RealType::Type max, ///< upper bound of distribution -+ typename RealType::Type min, ///< lower bound for distribution -+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that -+ /// are not truncated to zero. Permits reducing precision of -+ /// data. -+ -+ using RandomFunc = detail::RandomUniformFunc; -+ -+ typename RandomFunc::Params params(seed, max, min, bits); -+ -+ BlockForEach(ptr, capacity, params); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+/// Computes a random sparse meta -+template ///< Element type -+struct RandomSparseMetaFunc { -+ -+ using FloatType = float; -+ -+ using IntType = int32_t; -+ -+ /// Parameters structure -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ uint64_t seed; -+ FloatType range; -+ int MetaSizeInBits; -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of Gaussian RNG functor. -+ Params( -+ uint64_t seed_ = 0, -+ int MetaSizeInBits_ = 2 -+ ): -+ seed(seed_), -+ MetaSizeInBits(MetaSizeInBits_) { -+ if (MetaSizeInBits_ == 2) { -+ range = 6; -+ } else if (MetaSizeInBits_ == 4) { -+ range = 2; -+ } -+ } -+ }; -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters object -+ Params params; -+ -+ /// RNG state object -+ curandState_t rng_state; -+ -+ // -+ // Methods -+ // -+ -+ /// Device-side initialization of RNG -+ CUTLASS_DEVICE -+ RandomSparseMetaFunc(Params const ¶ms): params(params) { -+ -+ uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; -+ -+ curand_init(params.seed, gtid, 0, &rng_state); -+ } -+ -+ /// Compute random value and update RNG state -+ CUTLASS_DEVICE -+ Element operator()() { -+ Element FourToTwoMeta[6] = {0x4, 0x8, 0x9, 0xc, 0xd, 0xe}; -+ Element TwoToOneMeta[2] = {0x4, 0xe}; -+ -+ Element *MetaArray = -+ (params.MetaSizeInBits == 2) ? FourToTwoMeta : TwoToOneMeta; -+ -+ Element result = 0x0; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < cutlass::sizeof_bits::value / 4; ++i) { -+ FloatType rnd = random_uniform_float(&rng_state); -+ rnd = params.range * rnd; -+ Element meta = MetaArray[(int)rnd]; -+ -+ result = (Element)(result | ((Element)(meta << (i * 4)))); -+ } -+ -+ return result; -+ } -+}; -+ -+/// Computes a random Gaussian distribution -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorFillRandomSparseMetaFunc { -+ -+ /// View type -+ using TensorView = TensorView; -+ -+ /// Scalar type -+ typedef typename TensorView::Element T; -+ -+ /// Coordinate in tensor's index space -+ typedef typename TensorView::TensorCoord TensorCoord; -+ -+ using RandomFunc = RandomSparseMetaFunc; -+ -+ /// Parameters structure -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ typename RandomFunc::Params random; -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of Gaussian RNG functor. -+ Params( -+ TensorView view_ = TensorView(), -+ typename RandomFunc::Params random_ = RandomFunc::Params() -+ ): -+ view(view_), random(random_) { -+ -+ } -+ }; -+ -+ // -+ // Data members -+ // -+ -+ Params params; -+ RandomFunc random; -+ -+ // -+ // Methods -+ // -+ -+ /// Device-side initialization of RNG -+ CUTLASS_DEVICE -+ TensorFillRandomSparseMetaFunc(Params const ¶ms): params(params), random(params.random) { -+ } -+ -+ /// Compute random value and update RNG state -+ CUTLASS_DEVICE -+ void operator()(TensorCoord const &coord) { -+ -+ params.view.at(coord) = random(); -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor with random values with a uniform random distribution. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillRandomSparseMeta( -+ TensorView view, ///< destination tensor -+ uint64_t seed, ///< seed for RNG -+ int MetaSizeInBits = 2) { ///< If non-negative, specifies number of fractional bits that -+ /// are not truncated to zero. Permits reducing precision of -+ /// data. -+ -+ using RandomFunc = detail::RandomSparseMetaFunc; -+ using Func = detail::TensorFillRandomUniformFunc; -+ using Params = typename Func::Params; -+ -+ typename RandomFunc::Params random(seed, MetaSizeInBits); -+ -+ TensorForEach( -+ view.extent(), -+ Params(view, random) -+ ); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor with random values with a uniform random distribution. -+template -+void BlockFillRandomSparseMeta( -+ Element *ptr, -+ size_t capacity, -+ uint64_t seed, ///< seed for RNG -+ int MetaSizeInBits = 2) { ///< meta data size -+ -+ using RandomFunc = detail::RandomSparseMetaFunc; -+ -+ typename RandomFunc::Params params(seed, MetaSizeInBits); -+ -+ BlockForEach(ptr, capacity, params); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+/// Functor to fill a tensor with zeros off the diagonal and a uniform value on the diagonal. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorFillDiagonalFunc { -+ -+ /// View type -+ using TensorView = TensorView; -+ -+ /// Scalar type -+ typedef typename TensorView::Element T; -+ -+ /// Coordinate in tensor's index space -+ typedef typename TensorView::TensorCoord TensorCoord; -+ -+ /// Parameters structure -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ Element diag; -+ Element other; -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ // -+ // Methods -+ // -+ -+ Params( -+ TensorView view_ = TensorView(), -+ Element diag_ = Element(1), -+ Element other_ = Element(0) -+ ): -+ view(view_), diag(diag_), other(other_) { -+ -+ } -+ }; -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters object -+ Params params; -+ -+ // -+ // Methods -+ // -+ -+ /// Device-side initialization of RNG -+ CUTLASS_DEVICE -+ TensorFillDiagonalFunc(Params const ¶ms): params(params) { -+ -+ } -+ -+ /// Updates the tensor -+ CUTLASS_DEVICE -+ void operator()(TensorCoord const &coord) { -+ -+ bool is_diag = true; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 1; i < Layout::kRank; ++i) { -+ if (coord[i] != coord[i - 1]) { -+ is_diag = false; -+ break; -+ } -+ } -+ -+ params.view.at(coord) = (is_diag ? params.diag : params.other); -+ } -+}; -+ -+// Overwrites the elements of a tensor with a uniform value depending on fill mode -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorFillPartialFunc { -+ -+ /// View type -+ using TensorView = TensorView; -+ -+ /// Scalar type -+ typedef typename TensorView::Element T; -+ -+ /// Coordinate in tensor's index space -+ typedef typename TensorView::TensorCoord TensorCoord; -+ -+ /// Parameters structure -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ Element element; -+ FillMode fill_mode; -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params(): fill_mode(FillMode::kNone) { } -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of Gaussian RNG functor. -+ Params( -+ TensorView view_, -+ Element element_, -+ FillMode fill_mode_ -+ ): -+ view(view_), element(element_), fill_mode(fill_mode_) { -+ -+ } -+ }; -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters object -+ Params params; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_DEVICE -+ TensorFillPartialFunc(Params const ¶ms): params(params) { -+ -+ } -+ -+ /// Overwrites the element if it is within the covered region. -+ CUTLASS_DEVICE -+ void operator()(TensorCoord const &coord) { -+ -+ bool predicate = true; -+ -+ switch (params.fill_mode) { -+ case FillMode::kFull: -+ predicate = true; -+ break; -+ -+ case FillMode::kLower: -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 1; i < Layout::kRank; ++i) { -+ if (coord[i - 1] < coord[i]) { -+ predicate = false; -+ break; -+ } -+ } -+ break; -+ -+ case FillMode::kUpper: -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 1; i < Layout::kRank; ++i) { -+ if (coord[i - 1] > coord[i]) { -+ predicate = false; -+ break; -+ } -+ } -+ break; -+ -+ case FillMode::kDiagonal: -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 1; i < Layout::kRank; ++i) { -+ if (coord[i - 1] != coord[i]) { -+ predicate = false; -+ break; -+ } -+ } -+ break; -+ -+ case FillMode::kNone: // fall-through -+ -+ default: -+ predicate = false; -+ break; -+ } -+ -+ if (predicate) { -+ params.view.at(coord) = params.element; -+ } -+ } -+}; -+ -+ -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorClearPartialFunc { -+ -+ /// View type -+ using TensorView = TensorView; -+ -+ /// Scalar type -+ typedef typename TensorView::Element T; -+ -+ /// Coordinate in tensor's index space -+ typedef typename TensorView::TensorCoord TensorCoord; -+ -+ /// -+ static_assert((Layout::kRank == 2), "TensorClearPartial is only supported for matrices"); -+ -+ /// Parameters structure -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ Element element; -+ FillMode fill_mode; -+ int alignment; -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params(): fill_mode(FillMode::kNone) { } -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of Gaussian RNG functor. -+ Params( -+ TensorView view_, -+ Element element_, -+ FillMode fill_mode_, -+ int alignment_ -+ ): -+ view(view_), element(element_), fill_mode(fill_mode_), alignment(alignment_) { -+ -+ } -+ }; -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters object -+ Params params; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_DEVICE -+ TensorClearPartialFunc(Params const ¶ms): params(params) { -+ -+ } -+ -+ /// Overwrites the element if it is within the covered region. -+ CUTLASS_DEVICE -+ void operator()(TensorCoord const &coord) { -+ -+ bool predicate = true; -+ -+ switch (params.fill_mode) { -+ -+ case FillMode::kLower: -+ if ((coord[0] >= coord[1]) || -+ ((coord[1] - coord[0]) >= params.alignment)) { -+ predicate = false; -+ break; -+ } -+ break; -+ -+ case FillMode::kUpper: -+ if ((coord[0] <= coord[1]) || -+ ((coord[0] - coord[1]) >= params.alignment)) { -+ predicate = false; -+ break; -+ } -+ break; -+ -+ case FillMode::kNone: // fall-through -+ -+ default: -+ predicate = false; -+ break; -+ } -+ -+ if (predicate) { -+ params.view.at(coord) = params.element; -+ } -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor everywhere with a unique value for its diagonal. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillDiagonal( -+ TensorView view, ///< destination tensor -+ Element diag = Element(1), ///< value to write in the diagonal -+ Element other = Element(0)) { ///< value to write off the diagonal -+ -+ typedef detail::TensorFillDiagonalFunc Func; -+ typedef typename Func::Params Params; -+ -+ TensorForEach( -+ view.extent(), -+ Params(view, diag, other) -+ ); -+} -+ -+/// Fills a tensor partially depending on fill mode. Elements not covered by the fillmode are -+/// not written. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillPartial( -+ TensorView view, ///< destination tensor -+ Element element, -+ FillMode fill_mode) { -+ -+ typedef detail::TensorFillPartialFunc Func; -+ typedef typename Func::Params Params; -+ -+ TensorForEach( -+ view.extent(), -+ Params(view, element, fill_mode) -+ ); -+} -+ -+/// Clears a tensor partially depending on fill mode and alignment. Elements on the wrong-side -+/// of fillmode (upto the alignment) are overwritten with the user supplied element (typically zeros) -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorClearPartial( -+ TensorView view, ///< destination tensor -+ Element element, -+ FillMode fill_mode, -+ int alignment) { -+ -+ typedef detail::TensorClearPartialFunc Func; -+ typedef typename Func::Params Params; -+ -+ TensorForEach( -+ view.extent(), -+ Params(view, element, fill_mode, alignment) -+ ); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor with a uniform value -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFill( -+ TensorView view, ///< destination tensor -+ Element val = Element(0)) { ///< value to uniformly fill it with -+ -+ TensorFillDiagonal(view, val, val); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor's digonal with 1 and 0 everywhere else. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillIdentity( -+ TensorView view) { ///< destination tensor -+ -+ TensorFillDiagonal(view, Element(1), Element(0)); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+/// Computes a random Gaussian distribution -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorUpdateDiagonalFunc { -+ -+ /// View type -+ using TensorView = TensorView; -+ -+ /// Scalar type -+ typedef typename TensorView::Element T; -+ -+ /// Coordinate in tensor's index space -+ typedef typename TensorView::TensorCoord TensorCoord; -+ -+ /// Parameters structure -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ Element diag; -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of Gaussian RNG functor. -+ Params( -+ TensorView view_ = TensorView(), -+ Element diag_ = Element(1) -+ ): -+ view(view_), diag(diag_) { -+ -+ } -+ }; -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters object -+ Params params; -+ -+ // -+ // Methods -+ // -+ -+ /// Device-side initialization of RNG -+ CUTLASS_DEVICE -+ TensorUpdateDiagonalFunc(Params const ¶ms): params(params) { -+ -+ } -+ -+ /// Compute random value and update RNG state -+ CUTLASS_DEVICE -+ void operator()(TensorCoord const &coord) { -+ -+ bool is_diag = true; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 1; i < Layout::kRank; ++i) { -+ if (coord[i] != coord[i - 1]) { -+ is_diag = false; -+ break; -+ } -+ } -+ -+ if (is_diag) { -+ params.view.at(coord) = params.diag; -+ } -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Writes a uniform value to the diagonal of a tensor without modifying off-diagonal elements. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorUpdateDiagonal( -+ TensorView view, ///< destination tensor -+ Element diag = Element(1)) { -+ -+ typedef detail::TensorUpdateDiagonalFunc Func; -+ typedef typename Func::Params Params; -+ -+ TensorForEach( -+ view.extent(), -+ Params(view, diag) -+ ); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+/// Computes a random Gaussian distribution -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorUpdateOffDiagonalFunc { -+ -+ /// View type -+ using TensorView = TensorView; -+ -+ /// Scalar type -+ typedef typename TensorView::Element T; -+ -+ /// Coordinate in tensor's index space -+ typedef typename TensorView::TensorCoord TensorCoord; -+ -+ /// Parameters structure -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ Element other; -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of Gaussian RNG functor. -+ Params( -+ TensorView view_ = TensorView(), -+ Element other_ = Element(0) -+ ): -+ view(view_), other(other_) { -+ -+ } -+ }; -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters object -+ Params params; -+ -+ // -+ // Methods -+ // -+ -+ /// Device-side initialization of RNG -+ CUTLASS_DEVICE -+ TensorUpdateOffDiagonalFunc(Params const ¶ms): params(params) { -+ -+ } -+ -+ /// Compute random value and update RNG state -+ CUTLASS_DEVICE -+ void operator()(TensorCoord const &coord) { -+ -+ bool is_diag = true; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 1; i < Layout::kRank; ++i) { -+ if (coord[i] != coord[i - 1]) { -+ is_diag = false; -+ break; -+ } -+ } -+ -+ if (!is_diag) { -+ params.view.at(coord) = params.other; -+ } -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Writes a uniform value to all elements in the tensor without modifying diagonal elements. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorUpdateOffDiagonal( -+ TensorView view, ///< destination tensor -+ Element other = Element(1)) { -+ -+ typedef detail::TensorUpdateOffDiagonalFunc Func; -+ typedef typename Func::Params Params; -+ -+ TensorForEach( -+ view.extent(), -+ Params(view, other) -+ ); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+/// Computes a random Gaussian distribution -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorFillLinearFunc { -+ -+ /// View type -+ using TensorView = TensorView; -+ -+ /// Scalar type -+ typedef typename TensorView::Element T; -+ -+ /// Coordinate in tensor's index space -+ typedef typename TensorView::TensorCoord TensorCoord; -+ -+ /// Parameters structure -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ Array v; -+ Element s; -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of Gaussian RNG functor. -+ Params( -+ TensorView view_, ///< destination tensor -+ Array const & v_, -+ Element s_ = Element(0) -+ ): -+ view(view_), v(v_), s(s_) { -+ -+ } -+ }; -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters object -+ Params params; -+ -+ // -+ // Methods -+ // -+ -+ /// Device-side initialization of RNG -+ CUTLASS_DEVICE -+ TensorFillLinearFunc(Params const ¶ms): params(params) { -+ -+ } -+ -+ /// Compute random value and update RNG state -+ CUTLASS_DEVICE -+ void operator()(TensorCoord const &coord) { -+ Element sum = params.s; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Layout::kRank; ++i) { -+ sum += params.v[i] * Element(coord[i]); -+ } -+ -+ params.view.at(coord) = sum; -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills tensor with a linear combination of its coordinate and another vector -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillLinear( -+ TensorView view, ///< destination tensor -+ Array const & v, -+ Element s = Element(0)) { -+ -+ using Func = detail::TensorFillLinearFunc; -+ using Params = typename Func::Params; -+ -+ TensorForEach( -+ view.extent(), -+ Params(view, v, s) -+ ); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a block of data with sequential elements -+template < -+ typename Element -+> -+void BlockFillSequential( -+ Element *ptr, -+ int64_t capacity, -+ Element v = Element(1), -+ Element s = Element(0)) { -+ -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a block of data with sequential elements -+template < -+ typename Element -+> -+void BlockFillRandom( -+ Element *ptr, -+ size_t capacity, -+ uint64_t seed, -+ Distribution dist) { -+ -+ using Real = typename RealType::Type; -+ -+ if (dist.kind == Distribution::Gaussian) { -+ BlockFillRandomGaussian( -+ ptr, -+ capacity, -+ seed, -+ static_cast(dist.gaussian.mean), -+ static_cast(dist.gaussian.stddev), -+ dist.int_scale); -+ } -+ else if (dist.kind == Distribution::Uniform) { -+ BlockFillRandomUniform( -+ ptr, -+ capacity, -+ seed, -+ static_cast(dist.uniform.max), -+ static_cast(dist.uniform.min), -+ dist.int_scale); -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+/// Computes a random Gaussian distribution -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorCopyDiagonalInFunc { -+ -+ /// View type -+ using TensorView = TensorView; -+ -+ /// Scalar type -+ typedef typename TensorView::Element T; -+ -+ /// Coordinate in tensor's index space -+ typedef typename TensorView::TensorCoord TensorCoord; -+ -+ /// Parameters structure -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ Element const *ptr; -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of Gaussian RNG functor. -+ Params( -+ TensorView view_, ///< destination tensor -+ Element const *ptr_ -+ ): -+ view(view_), ptr(ptr_) { -+ -+ } -+ }; -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters object -+ Params params; -+ -+ // -+ // Methods -+ // -+ -+ /// Device-side initialization of RNG -+ CUTLASS_DEVICE -+ TensorCopyDiagonalInFunc(Params const ¶ms): params(params) { -+ -+ } -+ -+ /// Only update the diagonal element -+ CUTLASS_DEVICE -+ void operator()(TensorCoord const &coord) { -+ bool is_diagonal = true; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 1; i < Layout::kRank; ++i) { -+ if (coord[i] != coord[0]) { -+ is_diagonal = false; -+ } -+ } -+ if (is_diagonal) { -+ params.view.at(coord) = params.ptr[coord[0]]; -+ } -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Copies a diagonal in from host memory without modifying off-diagonal elements. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorCopyDiagonalIn( -+ TensorView view, ///< destination tensor -+ Element const *ptr) { ///< dense buffer of elements -+ -+ using Func = detail::TensorCopyDiagonalInFunc; -+ using Params = typename Func::Params; -+ -+ TensorForEach( -+ view.extent(), -+ Params(view, ptr) -+ ); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -+namespace detail { -+ -+/// Computes a random Gaussian distribution -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorCopyDiagonalOutFunc { -+ -+ /// View type -+ using TensorView = TensorView; -+ -+ /// Scalar type -+ typedef typename TensorView::Element T; -+ -+ /// Coordinate in tensor's index space -+ typedef typename TensorView::TensorCoord TensorCoord; -+ -+ /// Parameters structure -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ Element *ptr; -+ -+ /// Default ctor -+ CUTLASS_HOST_DEVICE -+ Params() { } -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of Gaussian RNG functor. -+ Params( -+ TensorView view_, ///< destination tensor -+ Element *ptr_ -+ ): -+ view(view_), ptr(ptr_) { -+ -+ } -+ }; -+ -+ // -+ // Data members -+ // -+ -+ /// Parameters object -+ Params params; -+ -+ // -+ // Methods -+ // -+ -+ /// Device-side initialization of RNG -+ CUTLASS_DEVICE -+ TensorCopyDiagonalOutFunc(Params const ¶ms): params(params) { -+ -+ } -+ -+ /// Compute random value and update RNG state -+ CUTLASS_DEVICE -+ void operator()(TensorCoord const &coord) { -+ bool is_diagonal = true; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 1; i < Layout::kRank; ++i) { -+ if (coord[i] != coord[0]) { -+ is_diagonal = false; -+ } -+ } -+ if (is_diagonal) { -+ params.ptr[coord[0]] = params.view.at(coord); -+ } -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Copies the diagonal of a tensor into a dense buffer in host memory. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorCopyDiagonalOut( -+ Element *ptr, ///< dense buffer of elements -+ TensorView view) { ///< source tensor -+ -+ using Func = detail::TensorCopyDiagonalOutFunc; -+ using Params = typename Func::Params; -+ -+ TensorForEach( -+ view.extent(), -+ Params(view, ptr) -+ ); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/tensor_foreach.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/tensor_foreach.h -new file mode 100644 -index 0000000..cac558d ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/tensor_foreach.h -@@ -0,0 +1,136 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+#include "cutlass/cutlass.h" -+#include "cutlass/util/reference/device/kernel/tensor_foreach.h" -+ -+namespace cutlass { -+namespace reference { -+namespace device { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Launches a kernel calling a functor for each element in a tensor's index space. -+template -+struct TensorForEach { -+ -+ /// Constructor performs the operation. -+ TensorForEach(Coord size, Params params = Params(), int grid_size = 0, int block_size = 0) { -+ -+ if (!grid_size || !block_size) { -+ -+ // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API -+ cudaError_t result = cudaOccupancyMaxPotentialBlockSize( -+ &grid_size, -+ &block_size, -+ reinterpret_cast(kernel::TensorForEach)); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("Failed to query occupancy."); -+ } -+ -+ // Limit block size. This has the effect of increasing the number of items processed by a -+ // single thread and reduces the impact of initialization overhead. -+ block_size = (block_size < 128 ? block_size : 128); -+ } -+ -+ dim3 grid(grid_size, 1, 1); -+ dim3 block(block_size, 1, 1); -+ -+ kernel::TensorForEach<<< grid, block >>>(size, params); -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Launches a kernel calling a functor for each element along a tensor's diagonal -+template -+struct TensorDiagonalForEach { -+ -+ /// Constructor performs the operation -+ TensorDiagonalForEach(Coord size, Params params = Params(), int start = 0, int end = -1, int block_size = 128) { -+ -+ if (end < 0) { -+ end = size.min(); -+ } -+ -+ dim3 block(block_size, 1, 1); -+ dim3 grid((end - start + block_size - 1) / block_size, 1, 1); -+ -+ kernel::TensorDiagonalForEach<<< grid, block >>>(size, params, start, end); -+ } -+}; -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct BlockForEach { -+ -+ /// Constructor performs the operation. -+ BlockForEach( -+ Element *ptr, -+ size_t capacity, -+ typename Func::Params params = typename Func::Params(), -+ int grid_size = 0, -+ int block_size = 0) { -+ -+ if (!grid_size || !block_size) { -+ -+ // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API -+ cudaError_t result = cudaOccupancyMaxPotentialBlockSize( -+ &grid_size, -+ &block_size, -+ reinterpret_cast(kernel::BlockForEach)); -+ -+ if (result != cudaSuccess) { -+ throw std::runtime_error("Failed to query occupancy."); -+ } -+ -+ // Limit block size. This has the effect of increasing the number of items processed by a -+ // single thread and reduces the impact of initialization overhead. -+ block_size = (block_size < 128 ? block_size : 128); -+ } -+ -+ dim3 grid(grid_size, 1, 1); -+ dim3 block(block_size, 1, 1); -+ -+ kernel::BlockForEach<<< grid, block >>>(ptr, capacity, params); -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace reference -+} // namesace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/tensor_reduce.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/tensor_reduce.h -new file mode 100644 -index 0000000..09c11db ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/tensor_reduce.h -@@ -0,0 +1,510 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/complex.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/util/device_memory.h" -+#include "cutlass/util/reference/detail/linear_to_coordinate.h" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace reference { -+namespace device { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace kernel { -+ -+template < -+ typename Element, -+ typename Layout, -+ typename ComputeType, -+ typename ReduceOp, -+ typename TransformOp, -+ int kBlockSize = 128 -+> -+__global__ void TensorTransformReducePartial( -+ TensorView view, /// View of the tensor to reduce over -+ ComputeType identity, /// Identity element of the reduction operation -+ ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType -+ TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType -+ ComputeType *workspace) { /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0] -+ -+ int64_t idx = threadIdx.x + blockIdx.x * blockDim.x; -+ int64_t size = view.size(); -+ -+ __shared__ ComputeType scratchpad[kBlockSize]; -+ -+ for (; idx < size; idx += blockDim.x * gridDim.x) { -+ -+ // Map linear thread ID onto tensor coordinate -+ typename Layout::TensorCoord coord; -+ -+ cutlass::reference::detail::LinearToCoordinate()(coord, idx, view.extent()); -+ -+ if (view.contains(coord)) { -+ -+ // Fetch element -+ Element x = view.at(coord); -+ -+ // Transform -+ identity = reduce(identity, transform(x)); -+ } -+ } -+ -+ scratchpad[threadIdx.x] = identity; -+ -+ __syncthreads(); -+ -+ // One thread performs the final reduction and stores out. This could be enhanced via -+ // a tree reduction and pipelining. -+ if (threadIdx.x == 0) { -+ -+ for (int i = 1; i < kBlockSize; ++i) { -+ identity = reduce(identity, scratchpad[i]); -+ } -+ -+ workspace[blockIdx.x] = identity; -+ } -+} -+ -+template < -+ typename Element, -+ typename Layout, -+ typename ComputeType, -+ typename ReduceOp, -+ typename TransformOp, -+ int kBlockSize = 128 -+> -+__global__ void TensorTransformReducePartial( -+ TensorView view_A, /// View of the tensor to reduce over -+ TensorView view_B, /// View of the tensor to reduce over -+ ComputeType identity, /// Identity element of the reduction operation -+ ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType -+ TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType -+ ComputeType *workspace) { /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0] -+ -+ int64_t idx = threadIdx.x + blockIdx.x * blockDim.x; -+ int64_t size = view_A.size(); -+ -+ __shared__ ComputeType scratchpad[kBlockSize]; -+ -+ for (; idx < size; idx += blockDim.x * gridDim.x) { -+ -+ // Map linear thread ID onto tensor coordinate -+ typename Layout::TensorCoord coord; -+ -+ cutlass::reference::detail::LinearToCoordinate()(coord, idx, view_A.extent()); -+ -+ if (view_A.contains(coord)) { -+ -+ // Fetch element -+ Element a = view_A.at(coord); -+ Element b = view_B.at(coord); -+ -+ // Transform -+ identity = reduce(identity, transform(a, b)); -+ } -+ } -+ -+ scratchpad[threadIdx.x] = identity; -+ -+ __syncthreads(); -+ -+ // One thread performs the final reduction and stores out. This could be enhanced via -+ // a tree reduction and pipelining. -+ if (threadIdx.x == 0) { -+ -+ for (int i = 1; i < kBlockSize; ++i) { -+ identity = reduce(identity, scratchpad[i]); -+ } -+ -+ workspace[blockIdx.x] = identity; -+ } -+} -+ -+ -+template < -+ typename ComputeType, -+ typename ReduceOp, -+ int kBlockSize = 32 -+> -+__global__ void TensorTransformReduceFinalize( -+ ComputeType *workspace, -+ ComputeType identity, -+ int workspace_size, -+ ReduceOp reduce) { -+ -+ __shared__ ComputeType scratchpad[kBlockSize]; -+ -+ for (int idx = threadIdx.x; idx < workspace_size; idx += kBlockSize) { -+ identity = reduce(identity, workspace[idx]); -+ } -+ -+ scratchpad[threadIdx.x] = identity; -+ -+ __syncthreads(); -+ -+ if (threadIdx.x == 0) { -+ -+ for (int i = 1; i < kBlockSize; ++i) { -+ identity = reduce(identity, scratchpad[i]); -+ } -+ -+ workspace[0] = identity; -+ } -+} -+ -+} // namespace kernel -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Transform-reduce operation over the elements of a tensor -+template < -+ typename Element, -+ typename Layout, -+ typename ComputeType, -+ typename ReduceOp, -+ typename TransformOp -+> -+ComputeType TensorTransformReduce( -+ TensorView view, /// View of the tensor to reduce over -+ ComputeType identity, /// Identity element of the reduction operation -+ ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType -+ TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType -+ ComputeType *workspace, /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0] -+ int workspace_size, /// Number of elements in workspace -+ cudaStream_t stream = nullptr, /// CUDA stream to launch into -+ bool copy_out = true /// If true, the value of workspace[0] is copied to host and returned. Otherwise, `identity` is returned. -+) { -+ -+ int const kBlockSize = 128; -+ -+ dim3 block(kBlockSize, 1); -+ dim3 grid(workspace_size, 1); -+ -+ kernel::TensorTransformReducePartial< -+ Element, Layout, ComputeType, ReduceOp, TransformOp, kBlockSize -+ ><<< grid, block, 0, stream >>>( -+ view, identity, reduce, transform, workspace -+ ); -+ -+ int const kFinalizeBlockSize = 32; -+ -+ kernel::TensorTransformReduceFinalize< -+ ComputeType, ReduceOp, kFinalizeBlockSize -+ ><<< dim3(1, 1), dim3(kFinalizeBlockSize, 1), 0, stream >>>( -+ workspace, identity, workspace_size, reduce -+ ); -+ -+ if (copy_out) { -+ cudaError_t result = cudaMemcpy(&identity, workspace, sizeof(identity), cudaMemcpyDeviceToHost); -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaMemcpy() failed"); -+ } -+ } -+ -+ return identity; -+} -+ -+/// Transform-reduce operation over the elements of two tensors, zipped together -+template < -+ typename Element, -+ typename Layout, -+ typename ComputeType, -+ typename ReduceOp, -+ typename TransformOp -+> -+ComputeType TensorTransformReduce( -+ TensorView view_A, /// View of the tensor to reduce over -+ TensorView view_B, /// View of the tensor to reduce over -+ ComputeType identity, /// Identity element of the reduction operation -+ ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType -+ TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType -+ ComputeType *workspace, /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0] -+ int workspace_size, /// Number of elements in workspace -+ cudaStream_t stream = nullptr, /// CUDA stream to launch into -+ bool copy_out = true /// If true, the value of workspace[0] is copied to host and returned. Otherwise, `identity` is returned. -+) { -+ -+ if (view_A.extent() != view_B.extent()) { -+ throw std::runtime_error("Extents must be equal."); -+ } -+ -+ int const kBlockSize = 128; -+ -+ dim3 block(kBlockSize, 1); -+ dim3 grid(workspace_size, 1); -+ -+ kernel::TensorTransformReducePartial< -+ Element, Layout, ComputeType, ReduceOp, TransformOp, kBlockSize -+ ><<< grid, block, 0, stream >>>( -+ view_A, view_B, identity, reduce, transform, workspace -+ ); -+ -+ int const kFinalizeBlockSize = 32; -+ -+ kernel::TensorTransformReduceFinalize< -+ ComputeType, ReduceOp, kFinalizeBlockSize -+ ><<< dim3(1, 1), dim3(kFinalizeBlockSize, 1), 0, stream >>>( -+ workspace, identity, workspace_size, reduce -+ ); -+ -+ if (copy_out) { -+ cudaError_t result = cudaMemcpy(&identity, workspace, sizeof(identity), cudaMemcpyDeviceToHost); -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaMemcpy() failed"); -+ } -+ } -+ -+ return identity; -+} -+ -+/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side -+/// workspace -+template < -+ typename Element, -+ typename Layout, -+ typename ComputeType, -+ typename ReduceOp, -+ typename TransformOp -+> -+ComputeType TensorTransformReduce( -+ TensorView view, -+ ComputeType identity, -+ ReduceOp reduce, -+ TransformOp transform, -+ cudaStream_t stream = nullptr, -+ int workspace_size = 0 -+) { -+ -+ // Optionally query for the SM count to size the workspace. -+ if (!workspace_size) { -+ -+ int device_idx = 0; -+ cudaDeviceProp prop; -+ -+ cudaError_t result = cudaGetDevice(&device_idx); -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() failed"); -+ } -+ -+ result = cudaGetDeviceProperties(&prop, device_idx); -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProp() failed"); -+ } -+ -+ workspace_size = int(prop.multiProcessorCount); -+ } -+ -+ DeviceAllocation workspace(workspace_size); -+ -+ ComputeType output = TensorTransformReduce( -+ view, -+ identity, -+ reduce, -+ transform, -+ workspace.get(), -+ workspace_size, -+ stream, -+ true); -+ -+ return output; -+} -+ -+ -+/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side -+/// workspace -+template < -+ typename Element, -+ typename Layout, -+ typename ComputeType, -+ typename ReduceOp, -+ typename TransformOp -+> -+ComputeType TensorTransformReduce( -+ TensorView view_A, -+ TensorView view_B, -+ ComputeType identity, -+ ReduceOp reduce, -+ TransformOp transform, -+ cudaStream_t stream = nullptr, -+ int workspace_size = 0 -+) { -+ -+ // Optionally query for the SM count to size the workspace. -+ if (!workspace_size) { -+ -+ int device_idx = 0; -+ cudaDeviceProp prop; -+ -+ cudaError_t result = cudaGetDevice(&device_idx); -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDevice() failed"); -+ } -+ -+ result = cudaGetDeviceProperties(&prop, device_idx); -+ if (result != cudaSuccess) { -+ throw std::runtime_error("cudaGetDeviceProp() failed"); -+ } -+ -+ workspace_size = int(prop.multiProcessorCount); -+ } -+ -+ DeviceAllocation workspace(workspace_size); -+ -+ ComputeType output = TensorTransformReduce( -+ view_A, -+ view_B, -+ identity, -+ reduce, -+ transform, -+ workspace.get(), -+ workspace_size, -+ stream, -+ true); -+ -+ return output; -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Helper to compute the sum of the elements of a tensor -+template < -+ typename Element, -+ typename Layout, -+ typename ComputeType = Element -+> -+ComputeType TensorSum( -+ TensorView view, -+ ComputeType identity = ComputeType(), -+ cudaStream_t stream = nullptr, -+ int workspace_size = 0 -+) { -+ -+ plus reduce; -+ NumericConverter transform; -+ -+ return TensorTransformReduce( -+ view, identity, reduce, transform, stream, workspace_size); -+} -+ -+/// Helper to compute the sum of the squares of the elements of a tensor -+template < -+ typename Element, -+ typename Layout, -+ typename ComputeType = Element -+> -+ComputeType TensorSumSq( -+ TensorView view, -+ ComputeType identity = ComputeType(), -+ cudaStream_t stream = nullptr, -+ int workspace_size = 0 -+) { -+ -+ plus reduce; -+ magnitude_squared transform; -+ -+ return TensorTransformReduce( -+ view, identity, reduce, transform, stream, workspace_size); -+} -+ -+/// Helper to compute the norm of the elements of a tensor. -+template < -+ typename Element, -+ typename Layout, -+ typename ComputeType = double -+> -+ComputeType TensorNorm( -+ TensorView view, -+ ComputeType identity = ComputeType(), -+ cudaStream_t stream = nullptr, -+ int workspace_size = 0 -+) { -+ -+ return std::sqrt(TensorSumSq(view, identity, stream, workspace_size)); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Helper to compute the sum of the squares of the differences of two tensors -+template < -+ typename Element, -+ typename Layout, -+ typename ComputeType = double -+> -+ComputeType TensorSumSqDiff( -+ TensorView view_A, -+ TensorView view_B, -+ ComputeType identity = ComputeType(), -+ cudaStream_t stream = nullptr, -+ int workspace_size = 0 -+) { -+ -+ plus reduce; -+ magnitude_squared_difference transform; -+ -+ return TensorTransformReduce( -+ view_A, view_B, identity, reduce, transform, stream, workspace_size); -+} -+ -+ -+/// Helper to compute the norm of the tensor computed as the difference of two tensors in memory -+template < -+ typename Element, -+ typename Layout, -+ typename ComputeType = double -+> -+ComputeType TensorNormDiff( -+ TensorView view_A, -+ TensorView view_B, -+ ComputeType identity = ComputeType(), -+ cudaStream_t stream = nullptr, -+ int workspace_size = 0 -+) { -+ -+ return std::sqrt(TensorSumSqDiff(view_A, view_B, identity, stream, workspace_size)); -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace reference -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/tensor_relu.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/tensor_relu.h -new file mode 100644 -index 0000000..c78f1dc ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/tensor_relu.h -@@ -0,0 +1,141 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines device-side elementwise operations on TensorView. Note, the operations defined -+ in this header are not specialized for any particular data layout and are therefore not -+ intended to offer the best possible performance. Rather, they are intended to be generic -+ reference implementations to support the CUTLASS unit tests. -+*/ -+ -+#pragma once -+ -+// Cutlass includes -+#include "cutlass/cutlass.h" -+#include "cutlass/tensor_view.h" -+ -+#include "cutlass/util/reference/device/tensor_foreach.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace reference { -+namespace device { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorReLuFunc { -+ -+ /// View type -+ using TensorView = TensorView; -+ -+ /// Coordinate in tensor's index space -+ using TensorCoord = typename TensorView::TensorCoord; -+ -+ /// Parameters structure -+ struct Params { -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ Element threshold; -+ -+ -+ // -+ // Methods -+ // -+ -+ Params( -+ TensorView view_ = TensorView(), -+ Element threshold_ = Element(0) -+ ): -+ view(view_), threshold(threshold_) { -+ -+ } -+ }; -+ -+ // -+ // Data members -+ // -+ -+ Params params; -+ -+ // -+ // Methods -+ // -+ -+ CUTLASS_DEVICE -+ TensorReLuFunc(Params const ¶ms): params(params) { -+ -+ } -+ -+ CUTLASS_DEVICE -+ void operator()(TensorCoord const &coord) { -+ -+ Element const & value = params.view.at(coord); -+ params.view.at(coord) = (value < params.threshold) ? params.threshold : value; -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Apply ReLu on a tensor -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorReLu( -+ TensorView view, ///< destination tensor -+ Element threshold = Element(0)) { ///< ReLu threshold -+ -+ using Func = detail::TensorReLuFunc; -+ using Params = typename Func::Params; -+ -+ TensorForEach( -+ view.extent(), -+ Params(view, threshold) -+ ); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace device -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/thread/gemm.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/thread/gemm.h -new file mode 100644 -index 0000000..094f716 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/device/thread/gemm.h -@@ -0,0 +1,186 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Reference implementation for GEMM in host-side code. -+*/ -+ -+#pragma once -+ -+#include "cutlass/coord.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/gemm/gemm.h" -+ -+namespace cutlass { -+namespace reference { -+namespace device { -+namespace thread { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Thread-level blocked general matrix product. -+// -+// Note, this is a reference implementation. Performance is not expected to approach peak. -+// -+template < -+ typename TensorRefA, -+ typename TensorRefB, -+ typename TensorRefC, -+ typename ScalarType, -+ typename AccumulatorType, -+ typename OutputTile, -+ typename InnerProductOp = multiply_add, -+ typename ConvertOp = NumericConverter -+> -+struct Gemm { -+ -+ using ElementA = typename TensorRefA::Element; -+ using ElementB = typename TensorRefB::Element; -+ using ElementC = typename TensorRefC::Element; -+ -+ // -+ // Data members -+ // -+ -+ /// Tile for A operand -+ ElementA A_tile[OutputTile::kColumn]; -+ -+ /// Tile for B operand -+ ElementB B_tile[OutputTile::kRow]; -+ -+ /// Tile for Accumulator -+ AccumulatorType accum[OutputTile::kColumn][OutputTile::kRow]; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ CUTLASS_HOST_DEVICE -+ Gemm(AccumulatorType initial_accum = AccumulatorType(0)) { -+ -+ // Clear fetch registers -+ for (int i = 0; i < OutputTile::kColumn; ++i) { -+ A_tile[i] = ElementA(0); -+ } -+ -+ for (int j = 0; j < OutputTile::kColumn; ++j) { -+ B_tile[j] = ElementB(0); -+ } -+ -+ // Clear accumulators -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < OutputTile::kColumn; ++j) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < OutputTile::kRow; ++i) { -+ accum[j][i] = initial_accum; -+ } -+ } -+ } -+ -+ /// Computes a matrix product -+ CUTLASS_HOST_DEVICE -+ Gemm & multiply_add( -+ gemm::GemmCoord problem_size, -+ TensorRefA tensor_a, -+ TensorRefB tensor_b, -+ MatrixCoord output_coord = MatrixCoord()) { -+ -+ InnerProductOp inner_product_op; -+ -+ // Loop over the GEMM K dimension -+ CUTLASS_PRAGMA_NO_UNROLL -+ for (int k = 0; k < problem_size.k(); ++k) { -+ -+ // Fetch a slice of the A matrix -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < OutputTile::kColumn; ++i) { -+ if (output_coord.row() + i < problem_size.m()) { -+ A_tile[i] = tensor_a.at(make_Coord(output_coord.row() + i, k)); -+ } -+ } -+ -+ // Fetch a slice of the B matrix -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < OutputTile::kRow; ++j) { -+ if (output_coord.column() + j < problem_size.n()) { -+ B_tile[j] = tensor_b.at(make_Coord(k, output_coord.column() + j)); -+ } -+ } -+ -+ // Compute an accumulated matrix product -+ CUTLASS_PRAGMA_UNROLL -+ for (int j = 0; j < OutputTile::kRow; ++j) { -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < OutputTile::kColumn; ++i) { -+ accum[j][i] = inner_product_op(A_tile[i], B_tile[j], accum[j][i]); -+ } -+ } -+ } -+ -+ return *this; -+ } -+ -+ /// Performs linear scaling of matrix product and updates output tensor -+ CUTLASS_HOST_DEVICE -+ Gemm & epilogue( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ ScalarType beta, -+ TensorRefC tensor_c, -+ TensorRefC tensor_d, -+ MatrixCoord output_coord = MatrixCoord()) { -+ -+ ConvertOp convert_op; -+ -+ // Update the output tensor -+ for (int j = 0; j < OutputTile::kRow; ++j) { -+ for (int i = 0; i < OutputTile::kColumn; ++i) { -+ MatrixCoord coord = output_coord + MatrixCoord(i, j); -+ if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) { -+ -+ tensor_d.at(coord) = convert_op( -+ alpha * ScalarType(accum[j][i]) + -+ beta * ScalarType(tensor_c.at(coord)) -+ ); -+ } -+ } -+ } -+ -+ return *this; -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace thread -+} // namespace device -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/convolution.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/convolution.h -new file mode 100644 -index 0000000..4d8a7fc ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/convolution.h -@@ -0,0 +1,789 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+ -+/*! \file -+ \brief Reference implementation for convolution in host-side code. -+*/ -+ -+#pragma once -+ -+#include "cutlass/coord.h" -+#include "cutlass/functional.h" -+#include "cutlass/layout/tensor.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/tensor_ref.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/conv/convolution.h" -+#include "cutlass/conv/conv2d_problem_size.h" -+#include "cutlass/conv/conv3d_problem_size.h" -+#include -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+/// Forward propagation -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// y = conv2d(x, w) -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+void Conv2dFprop( -+ conv::Conv2dProblemSize problem_size, -+ TensorRef tensor_x, -+ TensorRef tensor_w, -+ TensorRef tensor_y_in, -+ TensorRef tensor_y_out, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ -+ // Apply MMA and accumulate ElementAccumulator -+ for (int n = 0; n < problem_size.N; ++n) { -+ for (int p = 0; p < problem_size.P; ++p) { -+ for (int q = 0; q < problem_size.Q; ++q) { -+ for (int k = 0; k < problem_size.K; ++k) { -+ -+ int group_idx = k / (problem_size.K / problem_size.groups); -+ int channels_per_group = problem_size.C / problem_size.groups; -+ -+ ElementAccumulator acc = ElementAccumulator(); -+ -+ for (int r = 0; r < problem_size.R; ++r) { -+ for (int s = 0; s < problem_size.S; ++s) { -+ for (int c = 0; c < channels_per_group; ++c) { -+ -+ int filter_r = r; -+ int filter_s = s; -+ -+ if (problem_size.mode == cutlass::conv::Mode::kConvolution) { -+ filter_r = problem_size.R - 1 - r; -+ filter_s = problem_size.S - 1 - s; -+ } -+ -+ int h = p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; -+ int w = q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; -+ -+ if (h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W) { -+ -+ ElementA a = tensor_x.at({n, h, w, c + group_idx * channels_per_group}); -+ ElementB b = tensor_w.at({k, r, s, c}); -+ -+ acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); -+ -+ } -+ } -+ } -+ } -+ -+ // Apply Epilogue, compute ElementCompute, convert and store ElementC -+ ElementC c_ref = ElementC(); -+ -+ if (beta != ElementCompute()) { -+ c_ref = tensor_y_in.at(cutlass::make_Coord(n, p, q, k)); -+ } -+ -+ tensor_y_out.at(cutlass::make_Coord(n, p, q, k)) = -+ convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); -+ } -+ } -+ } -+ } -+} -+ -+/// Depthwise-separable convolution -+template , -+ typename InnerProductOp = multiply_add > -+void Depsep_Fprop(cutlass::TensorView tensor_A, -+ cutlass::TensorView tensor_B, -+ cutlass::TensorView tensor_C, -+ cutlass::TensorView tensor_D, -+ ElementCompute alpha, -+ ElementCompute beta, -+ cutlass::Tensor4DCoord padding = cutlass::Tensor4DCoord(), -+ cutlass::Coord<2> conv_stride = cutlass::Coord<2>(), -+ cutlass::Coord<2> dilation = cutlass::Coord<2>(), -+ cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation) { -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ -+ // Apply MMA and accumulate ElementAccumulator -+ for (int n = 0; n < tensor_C.extent().n(); ++n) { -+ for (int p = 0; p < tensor_C.extent().h(); ++p) { -+ for (int q = 0; q < tensor_C.extent().w(); ++q) { -+ for (int g = 0; g < tensor_C.extent().c(); ++g) { -+ ElementAccumulator acc = ElementAccumulator(); -+ for (int r = 0; r < tensor_B.extent().h(); ++r) { -+ for (int s = 0; s < tensor_B.extent().w(); ++s) { -+ -+ // input activation H and W -+ int h = p * conv_stride[0] - padding[0] + r * dilation[0]; -+ int w = q * conv_stride[1] - padding[2] + s * dilation[1]; -+ -+ if (h < tensor_A.extent().h() && h >= 0 && w < tensor_A.extent().w() && w >= 0) { -+ ElementA a = tensor_A.at(cutlass::make_Coord(n, h, w, g)); -+ -+ ElementB b = (mode == cutlass::conv::Mode::kCrossCorrelation) -+ ? tensor_B.at(cutlass::make_Coord(g, r, s, 0)) -+ : tensor_B.at(cutlass::make_Coord( -+ g, tensor_B.extent().h() - r - 1, tensor_B.extent().w() - s - 1, 0)); -+ -+ acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); -+ } -+ } -+ } -+ -+ // Apply Epilogue, compute ElementCompute, convert and store ElementC -+ ElementC c_ref = tensor_C.at(cutlass::make_Coord(n, p, q, g)); -+ tensor_D.at(cutlass::make_Coord(n, p, q, g)) = -+ convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); -+ } -+ } -+ } -+ } -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+/// Dgrad -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// dx = dgrad(dy, w) -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+void Conv2dDgrad( -+ cutlass::conv::Conv2dProblemSize problem_size, -+ TensorRef tensor_dy, -+ TensorRef tensor_w, -+ TensorRef tensor_dx_in, -+ TensorRef tensor_dx_out, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ -+ // Apply MMA and accumulate ElementAccumulator -+ for (int n = 0; n < problem_size.N; ++n) { -+ for (int h = 0; h < problem_size.H; ++h) { -+ for (int w = 0; w < problem_size.W; ++w) { -+ for (int c = 0; c < problem_size.C; ++c) { -+ -+ ElementAccumulator acc = ElementAccumulator(); -+ -+ for (int r = 0; r < problem_size.R; ++r) { -+ for (int s = 0; s < problem_size.S; ++s) { -+ for (int k = 0; k < problem_size.K; ++k) { -+ -+ int filter_r = r; -+ int filter_s = s; -+ -+ if (problem_size.mode == cutlass::conv::Mode::kConvolution) { -+ filter_r = problem_size.R - 1 - r; -+ filter_s = problem_size.S - 1 - s; -+ } -+ -+ int p = h + problem_size.pad_h - filter_r * problem_size.dilation_h; -+ int q = w + problem_size.pad_w - filter_s * problem_size.dilation_w; -+ -+ if (p >= 0 && (p % problem_size.stride_h) == 0 && -+ q >= 0 && (q % problem_size.stride_w) == 0) { -+ -+ p = p / problem_size.stride_h; -+ q = q / problem_size.stride_w; -+#if 0 -+ std::cout << "row:" -+ << n * problem_size.H * problem_size.W + -+ h * problem_size.W + -+ w << " " -+ << "n, p, q: (" -+ << n << ", " -+ << p << ", " -+ << q << ") * " -+ << "r, s: (" -+ << r << ", " -+ << s << ") [" -+ << ((p < problem_size.P && q < problem_size.Q) ? "true":"false") << "]" -+ << std::endl; -+#endif -+ if (p < problem_size.P && q < problem_size.Q) { -+ -+ ElementA a = tensor_dy.at(cutlass::make_Coord(n, p, q, k)); -+ ElementB b = tensor_w.at(cutlass::make_Coord(k, r, s, c)); -+ -+ acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); -+ } -+ } -+ -+ } // for (K) -+ } // for (S) -+ } // for (R) -+ -+ // Apply Epilogue, compute ElementCompute, convert and store ElementC -+ ElementC c_ref = ElementC(); -+ -+ if (beta != ElementCompute()) { -+ c_ref = tensor_dx_in.at(cutlass::make_Coord(n, h, w, c)); -+ } -+ -+ tensor_dx_out.at(cutlass::make_Coord(n, h, w, c)) = -+ convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); -+ -+ } // for (C) -+ } // for (W) -+ } // for (H) -+ } // for (N) -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+/// Wgrad -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// dw = wgrad(dy, x) -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+void Conv2dWgrad( -+ cutlass::conv::Conv2dProblemSize problem_size, -+ TensorRef tensor_dy, -+ TensorRef tensor_x, -+ TensorRef tensor_dw_in, -+ TensorRef tensor_dw_out, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ InnerProductOp inner_product_op; -+ ConvertOp convert_op; -+ -+ // Apply MMA and accumulate ElementAccumulator -+ for (int k = 0; k < problem_size.K; ++k) { -+ for (int r = 0; r < problem_size.R; ++r) { -+ for (int s = 0; s < problem_size.S; ++s) { -+ for (int c = 0; c < problem_size.C; ++c) { -+ -+ ElementAccumulator acc = ElementAccumulator(); -+ -+ for (int n = 0; n < problem_size.N; ++n) { -+ for (int p = 0; p < problem_size.P; ++p) { -+ for (int q = 0; q < problem_size.Q; ++q) { -+ -+ cutlass::Tensor4DCoord b_coord; -+ -+ int filter_r = r; -+ int filter_s = s; -+ -+ if (problem_size.mode == cutlass::conv::Mode::kConvolution) { -+ filter_r = problem_size.R - 1 - r; -+ filter_s = problem_size.S - 1 - s; -+ } -+ -+ b_coord = make_Coord( -+ n, -+ p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h, -+ q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w, -+ c); -+ -+ if (b_coord.h() < problem_size.H && b_coord.h() >= 0 && -+ b_coord.w() < problem_size.W && b_coord.w() >= 0) { -+ -+ ElementAccumulator a = ElementAccumulator(tensor_dy.at(cutlass::make_Coord(n, p, q, k))); -+ ElementAccumulator b = ElementAccumulator(tensor_x.at(b_coord)); -+ acc = inner_product_op(a, b, acc); -+ } -+ } -+ } -+ } -+ -+ // Apply Epilogue, compute ElementCompute, convert and store ElementC -+ ElementC c_ref = ElementC(); -+ -+ if (beta != ElementCompute()) { -+ c_ref = tensor_dw_in.at(cutlass::make_Coord(k, r, s, c)); -+ } -+ -+ tensor_dw_out.at(cutlass::make_Coord(k, r, s, c)) = -+ convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); -+ -+ } // for (C) -+ } // for (S) -+ } // for (R) -+ } // for (K) -+} -+ -+/// Generic 2D convolution targeting Conv2dFprop, Conv2dDgrad, and Conv2dWgrad. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+void Conv2d( -+ conv::Operator convolutional_operator, -+ conv::Conv2dProblemSize problem_size, -+ TensorRef tensor_A, -+ TensorRef tensor_B, -+ TensorRef tensor_C, -+ TensorRef tensor_D, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ switch (convolutional_operator) { -+ case conv::Operator::kFprop: -+ Conv2dFprop< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, InnerProductOp -+ >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); -+ break; -+ -+ case conv::Operator::kDgrad: -+ Conv2dDgrad< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, InnerProductOp -+ >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); -+ break; -+ -+ case conv::Operator::kWgrad: -+ Conv2dWgrad< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, InnerProductOp -+ >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); -+ break; -+ -+ default: -+ break; -+ } -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+/// 3D convolution -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// y = conv3d(x, w) -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+void Conv3dFprop( -+ conv::Conv3dProblemSize problem_size, -+ TensorRef tensor_x, -+ TensorRef tensor_w, -+ TensorRef tensor_y_in, -+ TensorRef tensor_y_out, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ -+ // Apply MMA and accumulate ElementAccumulator -+ for (int n = 0; n < problem_size.N; ++n) { -+ for (int z = 0; z < problem_size.Z; ++z) { -+ for (int p = 0; p < problem_size.P; ++p) { -+ for (int q = 0; q < problem_size.Q; ++q) { -+ for (int k = 0; k < problem_size.K; ++k) { -+ -+ ElementAccumulator acc = ElementAccumulator(); -+ -+ for (int t = 0; t < problem_size.T; ++t) { -+ for (int r = 0; r < problem_size.R; ++r) { -+ for (int s = 0; s < problem_size.S; ++s) { -+ for (int c = 0; c < problem_size.C; ++c) { -+ -+ int filter_t = t; -+ int filter_r = r; -+ int filter_s = s; -+ -+ if (problem_size.mode == cutlass::conv::Mode::kConvolution) { -+ filter_t = problem_size.T - 1 - t; -+ filter_r = problem_size.R - 1 - r; -+ filter_s = problem_size.S - 1 - s; -+ } -+ -+ int d = z * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d; -+ int h = p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; -+ int w = q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; -+ -+ if (d >= 0 && d < problem_size.D && -+ h >=0 && h < problem_size.H && -+ w >= 0 && w < problem_size.W) { -+ -+ ElementA a = tensor_x.at({n, d, h, w, c}); -+ ElementB b = tensor_w.at({k, t, r, s, c}); -+ -+ acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); -+ } -+ } -+ } -+ } -+ } -+ -+ // Apply Epilogue, compute ElementCompute, convert and store ElementC -+ ElementC c_ref = ElementC(); -+ -+ if (beta != ElementCompute()) { -+ c_ref = tensor_y_in.at(cutlass::make_Coord(n, z, p, q, k)); -+ } -+ -+ tensor_y_out.at(cutlass::make_Coord(n, z, p, q, k)) = -+ convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); -+ } -+ } -+ } -+ } -+ } -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+/// Dgrad -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// dx = dgrad(dy, w) -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+void Conv3dDgrad( -+ cutlass::conv::Conv3dProblemSize problem_size, -+ TensorRef tensor_dy, -+ TensorRef tensor_w, -+ TensorRef tensor_dx_in, -+ TensorRef tensor_dx_out, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ -+ // Apply MMA and accumulate ElementAccumulator -+ for (int n = 0; n < problem_size.N; ++n) { -+ for (int d = 0; d < problem_size.D; ++d) { -+ for (int h = 0; h < problem_size.H; ++h) { -+ for (int w = 0; w < problem_size.W; ++w) { -+ for (int c = 0; c < problem_size.C; ++c) { -+ -+ ElementAccumulator acc = ElementAccumulator(); -+ -+ for (int t = 0; t < problem_size.T; ++t) { -+ for (int r = 0; r < problem_size.R; ++r) { -+ for (int s = 0; s < problem_size.S; ++s) { -+ for (int k = 0; k < problem_size.K; ++k) { -+ -+ int filter_t = t; -+ int filter_r = r; -+ int filter_s = s; -+ -+ if (problem_size.mode == cutlass::conv::Mode::kConvolution) { -+ filter_t = problem_size.T - 1 - t; -+ filter_r = problem_size.R - 1 - r; -+ filter_s = problem_size.S - 1 - s; -+ } -+ -+ int z = d + problem_size.pad_d - filter_t * problem_size.dilation_d; -+ int p = h + problem_size.pad_h - filter_r * problem_size.dilation_h; -+ int q = w + problem_size.pad_w - filter_s * problem_size.dilation_w; -+ -+ if (z >= 0 && (z % problem_size.stride_d) == 0 && -+ p >= 0 && (p % problem_size.stride_h) == 0 && -+ q >= 0 && (q % problem_size.stride_w) == 0) { -+ -+ z = z / problem_size.stride_d; -+ p = p / problem_size.stride_h; -+ q = q / problem_size.stride_w; -+ -+ if (z < problem_size.Z && p < problem_size.P && q < problem_size.Q) { -+ -+ ElementA a = tensor_dy.at(cutlass::make_Coord(n, z, p, q, k)); -+ ElementB b = tensor_w.at(cutlass::make_Coord(k, t, r, s, c)); -+ -+ acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); -+ } -+ } -+ -+ } // for (K) -+ } // for (S) -+ } // for (R) -+ } // for (T) -+ -+ // Apply Epilogue, compute ElementCompute, convert and store ElementC -+ ElementC c_ref = ElementC(); -+ -+ if (beta != ElementCompute()) { -+ c_ref = tensor_dx_in.at(cutlass::make_Coord(n, d, h, w, c)); -+ } -+ -+ tensor_dx_out.at(cutlass::make_Coord(n, d, h, w, c)) = -+ convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); -+ -+ } // for (C) -+ } // for (W) -+ } // for (H) -+ } // for (D) -+ } // for (N) -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+/// Wgrad -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// dw = wgrad(dy, x) -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+void Conv3dWgrad( -+ cutlass::conv::Conv3dProblemSize problem_size, -+ TensorRef tensor_dy, -+ TensorRef tensor_x, -+ TensorRef tensor_dw_in, -+ TensorRef tensor_dw_out, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ InnerProductOp inner_product_op; -+ ConvertOp convert_op; -+ -+ // Apply MMA and accumulate ElementAccumulator -+ for (int k = 0; k < problem_size.K; ++k) { -+ for (int t = 0; t < problem_size.T; ++t) { -+ for (int r = 0; r < problem_size.R; ++r) { -+ for (int s = 0; s < problem_size.S; ++s) { -+ for (int c = 0; c < problem_size.C; ++c) { -+ -+ ElementAccumulator acc = ElementAccumulator(); -+ -+ for (int n = 0; n < problem_size.N; ++n) { -+ for (int z = 0; z < problem_size.Z; ++z) { -+ for (int p = 0; p < problem_size.P; ++p) { -+ for (int q = 0; q < problem_size.Q; ++q) { -+ -+ int filter_t = t; -+ int filter_r = r; -+ int filter_s = s; -+ -+ if (problem_size.mode == cutlass::conv::Mode::kConvolution) { -+ filter_t = problem_size.T - 1 - t; -+ filter_r = problem_size.R - 1 - r; -+ filter_s = problem_size.S - 1 - s; -+ } -+ -+ Tensor5DCoord b_coord = make_Coord( -+ n, -+ z * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d, -+ p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h, -+ q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w, -+ c); -+ -+ if (b_coord.d() < problem_size.D && b_coord.d() >= 0 && -+ b_coord.h() < problem_size.H && b_coord.h() >= 0 && -+ b_coord.w() < problem_size.W && b_coord.w() >= 0) { -+ -+ ElementAccumulator a = ElementAccumulator(tensor_dy.at(cutlass::make_Coord(n, z, p, q, k))); -+ ElementAccumulator b = ElementAccumulator(tensor_x.at(b_coord)); -+ -+ acc = inner_product_op(a, b, acc); -+ } -+ } -+ } -+ } -+ } -+ -+ // Apply Epilogue, compute ElementCompute, convert and store ElementC -+ ElementC c_ref = ElementC(); -+ -+ if (beta != ElementCompute()) { -+ c_ref = tensor_dw_in.at(cutlass::make_Coord(k, t, r, s, c)); -+ } -+ -+ tensor_dw_out.at(cutlass::make_Coord(k, t, r, s, c)) = -+ convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); -+ -+ } // for (C) -+ } // for (S) -+ } // for (R) -+ } // for (T) -+ } // for (K) -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Generic 3D convolution targeting Conv2dFprop, Conv2dDgrad, and Conv2dWgrad. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ElementCompute, -+ typename ElementAccumulator = ElementCompute, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+void Conv3d( -+ conv::Operator convolutional_operator, -+ conv::Conv3dProblemSize problem_size, -+ TensorRef tensor_A, -+ TensorRef tensor_B, -+ TensorRef tensor_C, -+ TensorRef tensor_D, -+ ElementCompute alpha, -+ ElementCompute beta) { -+ -+ switch (convolutional_operator) { -+ case conv::Operator::kFprop: -+ Conv3dFprop< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, InnerProductOp -+ >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); -+ break; -+ -+ case conv::Operator::kDgrad: -+ Conv3dDgrad< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, InnerProductOp -+ >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); -+ break; -+ -+ case conv::Operator::kWgrad: -+ Conv3dWgrad< -+ ElementA, LayoutA, -+ ElementB, LayoutB, -+ ElementC, LayoutC, -+ ElementCompute, -+ ElementAccumulator, -+ ConvertOp, InnerProductOp -+ >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); -+ break; -+ -+ default: -+ break; -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/error_metrics.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/error_metrics.h -new file mode 100644 -index 0000000..0b4285c ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/error_metrics.h -@@ -0,0 +1,66 @@ -+ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/complex.h" -+#include "cutlass/util/reference/host/tensor_reduce.h" -+#include "cutlass/core_io.h" -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+/// Helper to compute the relative error metric for tensor A_computed w.r.t. to tensor A_reference -+template < -+ typename Element, -+ typename Layout, -+ typename ComputeType = double -+> -+ComputeType TensorRelativeErrorMetric( -+ TensorView view_A_computed, -+ TensorView view_B_reference, -+ ComputeType identity = ComputeType() -+) { -+ -+ return cutlass::reference::host::TensorNormDiff(view_A_computed, view_B_reference, identity) / -+ cutlass::reference::host::TensorNorm(view_B_reference, identity); -+} -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/gemm.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/gemm.h -new file mode 100644 -index 0000000..cd87e6f ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/gemm.h -@@ -0,0 +1,453 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Reference implementation for GEMM in host-side code. -+*/ -+ -+#pragma once -+ -+#include "cutlass/coord.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+ -+#include "cutlass/tensor_view.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/arch/mma.h" -+#include "cutlass/util/host_tensor.h" -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+template -+struct CastIfScalar { -+ static Out cast(In in) { -+ return Out(in); -+ } -+}; -+ -+template -+struct CastIfScalar, In> { -+ typedef cutlass::complex Out; -+ static Out cast(In in) { -+ return Out(static_cast(in)); -+ } -+}; -+ -+template -+struct CastIfScalar, cutlass::complex> { -+ typedef cutlass::complex Out; -+ typedef cutlass::complex In; -+ static Out cast(In in) { -+ return Out(in); -+ } -+}; -+ -+template -+Out cast_if_scalar(In in) { -+ return CastIfScalar::cast(in); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename InnerProductOp = multiply_add, -+ typename ConvertOp = NumericConverter -+> -+void compute_gemm( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, -+ ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ComputeType initial_accum) { -+ -+ static_assert( -+ LayoutA::kRank == 2 && -+ LayoutB::kRank == 2 && -+ LayoutC::kRank == 2, "Tensors must be of rank 2"); -+ -+ -+ // Note: batch is ignored. -+ int const M = problem_size.m(); -+ int const N = problem_size.n(); -+ int const K = problem_size.k(); -+ -+ // Blocking necessary to speedup reference implementation -+ int const Mblock = 16; -+ int const Nblock = 16; -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ -+ for (int row_block = 0; row_block < M; row_block += Mblock) { -+ for (int col_block = 0; col_block < N; col_block += Nblock) { -+ -+ ComputeType accum[Mblock][Nblock]; -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ accum[i][j] = initial_accum; -+ } -+ } -+ -+ for (int k_block = 0; k_block < K; ++k_block) { -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ if (row < M && col < N) { -+ ElementA a = tensor_a.at(MatrixCoord(row, k_block)); -+ ElementB b = tensor_b.at(MatrixCoord(k_block, col)); -+ -+ ComputeType compute_a(cast_if_scalar(a)); -+ ComputeType compute_b(cast_if_scalar(b)); -+ -+ accum[i][j] = inner_product_op(compute_a, compute_b, accum[i][j]); -+ } -+ } -+ } -+ } -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ MatrixCoord coord = MatrixCoord(row, col); -+ -+ if (row < M && col < N) { -+ tensor_d.at(coord) = convert_op( -+ alpha * ScalarType(accum[i][j]) + -+ beta * ScalarType(tensor_c.at(coord))); -+ } -+ } -+ } -+ } -+ } -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename InnerProductOp = multiply_add, -+ typename ConvertOp = NumericConverter -+> -+void compute_gemm( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, -+ ScalarType beta, -+ TensorRef tensor_c, -+ ComputeType initial_accum) { -+ compute_gemm( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c, -+ initial_accum); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename InnerProductOp = cutlass::arch::OpMultiplyAdd -+> -+struct Gemm; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for multiply-add -+template -+struct Gemm { -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ ComputeType initial_accum = ComputeType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_gemm>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); -+ } -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ComputeType initial_accum = ComputeType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_gemm>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for multiply-add -+template -+struct Gemm { -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ ComputeType initial_accum = ComputeType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_gemm>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); -+ } -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ComputeType initial_accum = ComputeType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_gemm>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for multiply-add-saturate -+template -+struct Gemm { -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ ComputeType initial_accum = ComputeType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_gemm, -+ NumericConverterClamp>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); -+ } -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ComputeType initial_accum = ComputeType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_gemm, -+ NumericConverterClamp>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Parital specialization for XOR-popc -+template -+struct Gemm { -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ ComputeType initial_accum = ComputeType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_gemm>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); -+ } -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ComputeType initial_accum = ComputeType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_gemm>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Batched GEMM -+// -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a batch of GEMMs over a set of matrices of common dimension. -+// -+// TensorRefCollection* is a type satisfying the TensorRefCollection concept. -+// -+template < -+ typename TensorRefCollectionA, -+ typename TensorRefCollectionB, -+ typename TensorRefCollectionC, -+ typename ScalarType, -+ typename AccumulatorType -+> -+void BatchedGemm( -+ gemm::GemmCoord problem_size, -+ int batch_count, -+ ScalarType alpha, -+ TensorRefCollectionA const& tensor_a, -+ TensorRefCollectionB const& tensor_b, -+ ScalarType beta, -+ TensorRefCollectionC &tensor_c, -+ AccumulatorType initial_accum) { -+ -+ typename TensorRefCollectionA::ConstIterator tensor_a_it = tensor_a.begin(); -+ typename TensorRefCollectionB::ConstIterator tensor_b_it = tensor_b.begin(); -+ typename TensorRefCollectionC::ConstIterator tensor_c_it = tensor_c.begin(); -+ -+ for (int batch = 0; -+ batch < batch_count; -+ ++batch, ++tensor_a_it, ++tensor_b_it, ++tensor_c_it) { -+ -+ Gemm -+ gemm; -+ -+ gemm(problem_size, alpha, *tensor_a_it, *tensor_b_it, beta, *tensor_c_it, -+ initial_accum); -+ } -+} -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+// -+// TensorRefCollection* is a type satisfying the TensorRefCollection concept. -+// -+template < -+ typename TensorRefCollectionA, -+ typename TensorRefCollectionB, -+ typename TensorRefCollectionC, -+ typename ScalarType, -+ typename AccumulatorType -+> -+void BatchedGemm( -+ gemm::GemmCoord problem_size, -+ int batch_count, -+ ScalarType alpha, -+ TensorRefCollectionA const& tensor_a, -+ TensorRefCollectionB const& tensor_b, -+ ScalarType beta, -+ TensorRefCollectionC &tensor_c) { -+ -+ BatchedGemm(problem_size, batch_count, alpha, tensor_a, tensor_b, beta, tensor_c, ScalarType(0)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/gemm_complex.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/gemm_complex.h -new file mode 100644 -index 0000000..f16e19c ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/gemm_complex.h -@@ -0,0 +1,208 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Reference implementation for complex-valued GEMM in host-side code. -+*/ -+ -+#pragma once -+ -+#include "cutlass/coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/matrix_coord.h" -+ -+#include "cutlass/tensor_view.h" -+ -+#include "cutlass/gemm/gemm.h" -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+/// -+/// Explicitly naming types needed by this template can be cumbersome, particularly for the -+/// accumulator type, so a function argument 'initial_accum' is exposed. Passing -+/// AccumulatorType(0) as the last function argument can be easier than naming all template -+/// arguments explicitly. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+void GemmComplex( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ ComplexTransform transform_a, -+ TensorRef tensor_b, -+ ComplexTransform transform_b, -+ ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ComputeType initial_accum, -+ int batch_count = 1, -+ int64_t batch_stride_A = 0, -+ int64_t batch_stride_B = 0, -+ int64_t batch_stride_C = 0, -+ int64_t batch_stride_D = 0) { -+ -+ static_assert( -+ LayoutA::kRank == 2 && -+ LayoutB::kRank == 2 && -+ LayoutC::kRank == 2, "Tensors must be of rank 2"); -+ -+ // Note: batch is ignored. -+ int const M = problem_size.m(); -+ int const N = problem_size.n(); -+ int const K = problem_size.k(); -+ -+ // Blocking necessary to speedup reference implementation -+ int const Mblock = 16; -+ int const Nblock = 16; -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ -+ for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) { -+ -+ // Compute matrix product using blocks -+ for (int row_block = 0; row_block < M; row_block += Mblock) { -+ for (int col_block = 0; col_block < N; col_block += Nblock) { -+ -+ ComputeType accum[Mblock][Nblock]; -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ accum[i][j] = initial_accum; -+ } -+ } -+ -+ for (int k_block = 0; k_block < K; ++k_block) { -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ if (row < M && col < N) { -+ ElementA a = tensor_a.at(MatrixCoord(row, k_block)); -+ ElementB b = tensor_b.at(MatrixCoord(k_block, col)); -+ -+ ComputeType a_ik = ComputeType(a); -+ ComputeType b_kj = ComputeType(b); -+ -+ if (transform_a == ComplexTransform::kConjugate) { -+ a_ik = conj(a_ik); -+ } -+ -+ if (transform_b == ComplexTransform::kConjugate) { -+ b_kj = conj(b_kj); -+ } -+ -+ accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]); -+ } -+ } -+ } -+ } -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ MatrixCoord coord = MatrixCoord(row, col); -+ -+ if (row < M && col < N) { -+ -+ tensor_d.at(coord) = convert_op( -+ alpha * ScalarType(accum[i][j]) + -+ beta * ScalarType(tensor_c.at(coord))); -+ } -+ } -+ } -+ -+ } // for (col_block) -+ } // for (row_block) -+ -+ tensor_a.add_pointer_offset(batch_stride_A); -+ tensor_b.add_pointer_offset(batch_stride_B); -+ tensor_c.add_pointer_offset(batch_stride_C); -+ tensor_d.add_pointer_offset(batch_stride_D); -+ -+ } // for (batch_idx) -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+/// -+/// This assumes the accumulator type is the same type as the scalars. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType -+> -+void GemmComplex( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ ComplexTransform transform_a, -+ TensorRef tensor_b, -+ ComplexTransform transform_b, -+ ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d) { -+ -+ GemmComplex(problem_size, alpha, tensor_a, transform_a, tensor_b, transform_b, beta, tensor_c, tensor_d, ScalarType(0)); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h -new file mode 100644 -index 0000000..7e94210 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h -@@ -0,0 +1,228 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Reference implementation for complex-valued GEMM in host-side code. -+*/ -+ -+#pragma once -+ -+#include "cutlass/coord.h" -+#include "cutlass/complex.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/tensor_ref_planar_complex.h" -+ -+#include "cutlass/tensor_view.h" -+#include "cutlass/gemm/gemm.h" -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+/// -+/// Explicitly naming types needed by this template can be cumbersome, particularly for the -+/// accumulator type, so a function argument 'initial_accum' is exposed. Passing -+/// AccumulatorType(0) as the last function argument can be easier than naming all template -+/// arguments explicitly. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add> -+> -+void GemmPlanarComplex( -+ gemm::GemmCoord problem_size, -+ complex alpha, -+ TensorRefPlanarComplex tensor_a, -+ ComplexTransform transform_a, -+ TensorRefPlanarComplex tensor_b, -+ ComplexTransform transform_b, -+ complex beta, -+ TensorRefPlanarComplex tensor_c, -+ TensorRefPlanarComplex tensor_d, -+ complex initial_accum) { -+ -+ static_assert( -+ LayoutA::kRank == 2 && -+ LayoutB::kRank == 2 && -+ LayoutC::kRank == 2, "Tensors must be of rank 2"); -+ -+ using ComplexA = typename TensorRefPlanarComplex::ComplexElement; -+ using ComplexB = typename TensorRefPlanarComplex::ComplexElement; -+ using ComplexC = typename TensorRefPlanarComplex::ComplexElement; -+ -+ // Note: batch is ignored. -+ int const M = problem_size.m(); -+ int const N = problem_size.n(); -+ int const K = problem_size.k(); -+ -+ // Blocking necessary to speedup reference implementation -+ int const Mblock = 16; -+ int const Nblock = 16; -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ -+ for (int row_block = 0; row_block < M; row_block += Mblock) { -+ for (int col_block = 0; col_block < N; col_block += Nblock) { -+ -+ complex accum[Mblock][Nblock]; -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ accum[i][j] = initial_accum; -+ } -+ } -+ -+ for (int k_block = 0; k_block < K; ++k_block) { -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ if (row < M && col < N) { -+ -+ ComplexA a_ik = tensor_a.at(MatrixCoord(row, k_block)); -+ ComplexB b_kj = tensor_b.at(MatrixCoord(k_block, col)); -+ -+ complex a = complex{ -+ ComputeType(a_ik.real()), -+ ComputeType(a_ik.imag()) -+ }; -+ -+ complex b = complex{ -+ ComputeType(b_kj.real()), -+ ComputeType(b_kj.imag()) -+ }; -+ -+ if (transform_a == ComplexTransform::kConjugate) { -+ a = conj(a); -+ } -+ -+ if (transform_b == ComplexTransform::kConjugate) { -+ b = conj(b); -+ } -+ -+ accum[i][j] = inner_product_op(a, b, accum[i][j]); -+ } -+ } -+ } -+ } -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ MatrixCoord coord = MatrixCoord(row, col); -+ -+ if (row < M && col < N) { -+ -+ complex acc{ -+ ScalarType(accum[i][j].real()), -+ ScalarType(accum[i][j].imag()) -+ }; -+ -+ ComplexC d_ij = tensor_c.at(coord); -+ -+ complex src{ -+ ScalarType(d_ij.real()), -+ ScalarType(d_ij.imag()) -+ }; -+ -+ complex result = alpha * acc + beta * src; -+ -+ d_ij.real() = convert_op(result.real()); -+ d_ij.imag() = convert_op(result.imag()); -+ -+ tensor_d.at(coord) = d_ij; -+ } -+ } -+ } -+ } -+ } -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+/// -+/// This assumes the accumulator type is the same type as the scalars. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType -+> -+void GemmPlanarComplex( -+ gemm::GemmCoord problem_size, -+ complex alpha, -+ TensorRefPlanarComplex tensor_a, -+ ComplexTransform transform_a, -+ TensorRefPlanarComplex tensor_b, -+ ComplexTransform transform_b, -+ complex beta, -+ TensorRefPlanarComplex tensor_c, -+ TensorRefPlanarComplex tensor_d) { -+ -+ GemmPlanarComplex( -+ problem_size, -+ alpha, -+ tensor_a, transform_a, -+ tensor_b, transform_b, -+ beta, -+ tensor_c, -+ tensor_d, -+ complex()); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/gett.hpp b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/gett.hpp -new file mode 100644 -index 0000000..64a0600 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/gett.hpp -@@ -0,0 +1,311 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Reference implementation for GETT in host-side code. -+*/ -+ -+#pragma once -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+#include "cutlass/complex.h" -+#include "cutlass/numeric_conversion.h" -+ -+#include "cute/tensor.hpp" -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass::reference::host { -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template< -+ class ElementAccumulator_, -+ class TensorA_, // (M, K, L) -+ class TensorB_ // (N, K, L) -+> -+struct GettMainloopParams { -+ using ElementAccumulator = ElementAccumulator_; -+ using TensorA = TensorA_; -+ using TensorB = TensorB_; -+ using EngineA = typename TensorA::engine_type; -+ using LayoutA = typename TensorA::layout_type; -+ using EngineB = typename TensorB::engine_type; -+ using LayoutB = typename TensorB::layout_type; -+ -+ TensorA A{}; -+ TensorB B{}; -+ -+ ComplexTransform transform_A = ComplexTransform::kNone; -+ ComplexTransform transform_B = ComplexTransform::kNone; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template< -+ class ElementScalar_, -+ class ElementAccumulator_, -+ class ElementCompute_, -+ class TensorC_, // (M, N, L) -+ class TensorD_ // (M, N, L) -+> -+struct GettEpilogueParams { -+ using ElementScalar = ElementScalar_; -+ using ElementAccumulator = ElementAccumulator_; -+ using ElementCompute = ElementCompute_; -+ using TensorC = TensorC_; -+ using TensorD = TensorD_; -+ using EngineC = typename TensorC::engine_type; -+ using LayoutC = typename TensorC::layout_type; -+ using EngineD = typename TensorD::engine_type; -+ using LayoutD = typename TensorD::layout_type; -+ ElementScalar alpha = ElementScalar(1); -+ ElementScalar beta = ElementScalar(0); -+ -+ TensorC C{}; -+ TensorD D{}; -+}; -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// GETT - General Tensor-Tensor contraction reference kernel -+template < -+ class MainloopParams, -+ class EpilogueParams -+> -+void Gett( -+ MainloopParams const& mainloop_params, -+ EpilogueParams const& epilogue_params) -+{ -+ -+ static int constexpr kBlockM = 64; -+ static int constexpr kBlockN = 64; -+ -+ #pragma omp parallel for collapse(3) -+ for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) { -+ for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) { -+ for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) { -+ typename MainloopParams::ElementAccumulator acc[kBlockM][kBlockN]; -+ gett_mainloop(mainloop_params, m, n, l, acc); -+ gett_epilogue(epilogue_params, m, n, l, acc); -+ } -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// GETT - Mainloop -+template -+void gett_mainloop( -+ MainloopParams const& mainloop_params, -+ int64_t m, -+ int64_t n, -+ int64_t l, -+ ElementAccumulator (&acc)[kBlockM][kBlockN]) -+{ -+ -+ static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "M, K, B"); -+ static_assert(cute::rank(typename MainloopParams::LayoutB{}) == 3, "N, K, B"); -+ -+ using ElementA = typename MainloopParams::EngineA::value_type; -+ using ElementB = typename MainloopParams::EngineB::value_type; -+ -+ using RingOp = multiply_add; -+ RingOp fma_op; -+ -+ // Zero out accumulators -+ for (int m_b = 0; m_b < kBlockM; ++m_b) { -+ for (int n_b = 0; n_b < kBlockN; ++n_b) { -+ acc[m_b][n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity -+ } -+ } -+ -+ // Compute on this k-block -+ for (int64_t k = 0; k < cute::size<1>(mainloop_params.A.layout()); ++k) { -+ // Load A -+ ElementAccumulator a_frag[kBlockM]; -+ for (int m_b = 0; m_b < kBlockM; ++m_b) { -+ if (m + m_b < cute::size<0>(mainloop_params.A.layout())) { -+ a_frag[m_b] = static_cast(mainloop_params.A(m + m_b, k, l)); -+ if (mainloop_params.transform_A == ComplexTransform::kConjugate) { -+ a_frag[m_b] = conj(a_frag[m_b]); -+ } -+ } else { -+ a_frag[m_b] = ElementAccumulator(0); // RingOp::AdditionIdentity -+ } -+ } -+ -+ // Load B -+ ElementAccumulator b_frag[kBlockN]; -+ for (int n_b = 0; n_b < kBlockN; ++n_b) { -+ if (n + n_b < cute::size<0>(mainloop_params.B.layout())) { -+ b_frag[n_b] = static_cast(mainloop_params.B(n + n_b, k, l)); -+ if (mainloop_params.transform_B == ComplexTransform::kConjugate) { -+ b_frag[n_b] = conj(b_frag[n_b]); -+ } -+ } else { -+ b_frag[n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity -+ } -+ } -+ -+ // do compute -+ for (int m_b = 0; m_b < kBlockM; ++m_b) { -+ for (int n_b = 0; n_b < kBlockN; ++n_b) { -+ acc[m_b][n_b] = fma_op(a_frag[m_b], b_frag[n_b], acc[m_b][n_b]); -+ } -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// GETT - Epilogue -+template -+void gett_epilogue( -+ EpilogueParams const& epilogue_params, -+ int64_t m, -+ int64_t n, -+ int64_t l, -+ ElementAccumulator (&acc)[kBlockM][kBlockN]) -+{ -+ static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == 3, "M, K, B"); -+ static_assert(cute::rank(typename EpilogueParams::LayoutD{}) == 3, "N, K, B"); -+ -+ using ElementCompute = typename EpilogueParams::ElementCompute; -+ using ElementC = typename EpilogueParams::EngineC::value_type; -+ -+ using ElementD = typename EpilogueParams::EngineD::value_type; -+ using ElementScalar = typename EpilogueParams::ElementScalar; -+ // Input related converter -+ NumericConverter accumulator_converter; -+ NumericConverter source_converter; -+ -+ // Scale related converter -+ NumericConverter scale_converter; -+ // Output related converter -+ NumericConverter destination_converter; -+ // Epilogue operations -+ multiply_add epilogue_fma; -+ multiplies mul; -+ -+ // Do conversion -+ ElementCompute converted_alpha = scale_converter(epilogue_params.alpha); -+ ElementCompute converted_beta = scale_converter(epilogue_params.beta); -+ for (int n_b = 0; n_b < kBlockN; ++n_b) { -+ for (int m_b = 0; m_b < kBlockM; ++m_b) { -+ if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) { -+ // Convert every type to ElementCompute first, do compute, convert to output type, write it out -+ ElementCompute converted_acc = accumulator_converter(acc[m_b][n_b]); -+ ElementCompute converted_src = source_converter(epilogue_params.C(m + m_b, n + n_b, l)); -+ -+ ElementScalar output = epilogue_fma(converted_alpha, converted_acc, ElementCompute(0)); -+ output = epilogue_fma(converted_beta, converted_src, output); -+ -+ epilogue_params.D(m + m_b, n + n_b, l) = destination_converter(output); -+ } -+ } -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// GEMM - General Matrix-Matrix contraction without conjugation options -+template < -+ class MainloopParams, -+ class EpilogueParams -+> -+void Gemm3x( -+ MainloopParams const& mainloop_params, -+ EpilogueParams const& epilogue_params) -+{ -+ using namespace cute; -+ -+ static_assert(rank(typename MainloopParams::LayoutA{}) == rank(typename MainloopParams::LayoutB{})); -+ static_assert(rank(typename EpilogueParams::LayoutC{}) == rank(typename EpilogueParams::LayoutD{})); -+ static_assert(rank(typename MainloopParams::LayoutA{}) == rank(typename EpilogueParams::LayoutC{})); -+ -+ if constexpr (rank(typename MainloopParams::LayoutA{}) == 2) { -+ // append a batch mode of size 1 if we do not have tensors that are rank 3 -+ Layout layout_A = make_layout( -+ make_shape(get<0>(mainloop_params.A.shape()), get<1>(mainloop_params.A.shape()), Int<1>{}), -+ make_stride(get<0>(mainloop_params.A.stride()), get<1>(mainloop_params.A.stride()), int64_t(cosize(mainloop_params.A.layout())))); -+ -+ Layout layout_B = make_layout( -+ make_shape(get<0>(mainloop_params.B.shape()), get<1>(mainloop_params.B.shape()), Int<1>{}), -+ make_stride(get<0>(mainloop_params.B.stride()), get<1>(mainloop_params.B.stride()), int64_t(cosize(mainloop_params.B.layout())))); -+ -+ Layout layout_C = make_layout( -+ make_shape(get<0>(epilogue_params.C.shape()), get<1>(epilogue_params.C.shape()), Int<1>{}), -+ make_stride(get<0>(epilogue_params.C.stride()), get<1>(epilogue_params.C.stride()), int64_t(cosize(epilogue_params.C.layout())))); -+ -+ Layout layout_D = make_layout( -+ make_shape(get<0>(epilogue_params.D.shape()), get<1>(epilogue_params.D.shape()), Int<1>{}), -+ make_stride(get<0>(epilogue_params.D.stride()), get<1>(epilogue_params.D.stride()), int64_t(cosize(epilogue_params.D.layout())))); -+ auto TensorA = make_tensor(mainloop_params.A.data(), layout_A); -+ auto TensorB = make_tensor(mainloop_params.B.data(), layout_B); -+ auto TensorC = make_tensor(epilogue_params.C.data(), layout_C); -+ auto TensorD = make_tensor(epilogue_params.D.data(), layout_D); -+ // Reconstruct mainloop params -+ GettMainloopParams -+ mainloop_params_converted{TensorA, -+ TensorB, -+ mainloop_params.transform_A, -+ mainloop_params.transform_B}; -+ -+ // Reconstruct epilogue params -+ GettEpilogueParams -+ epilogue_params_converted{epilogue_params.alpha, -+ epilogue_params.beta, -+ TensorC, -+ TensorD -+ }; -+ -+ Gett(mainloop_params_converted, epilogue_params_converted); -+ } -+ else { -+ // if we already have a batch mode, just pass it through -+ Gett(mainloop_params, epilogue_params); -+ } -+} -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // cutlass::reference::host -+ -+///////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k.h -new file mode 100644 -index 0000000..5b34260 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k.h -@@ -0,0 +1,261 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Reference implementation for Rank 2k update in host-side code. -+ -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/arch/mma.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ FillMode FillModeC, -+ typename ScalarType, -+ typename ComputeType, -+ typename InnerProductOp = multiply_add, -+ typename ConvertOp = NumericConverter -+> -+void compute_rank2k( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, -+ ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ComputeType initial_accum) { -+ -+ static_assert( -+ LayoutA::kRank == 2 && -+ LayoutB::kRank == 2 && -+ LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ static_assert( -+ FillModeC == FillMode::kLower || -+ FillModeC == FillMode::kUpper, -+ "Fill Mode can either be Lower or Upper."); -+ -+ using CompareOp = typename platform::conditional<(FillModeC == FillMode::kLower), -+ std::greater_equal, -+ std::less_equal>::type; -+ -+ // Note: batch is ignored. -+ // Note: M is same as N for Rank 2k update -+ int const N = problem_size.n(); -+ int const K = problem_size.k(); -+ -+ // Blocking necessary to speedup reference implementation -+ int const Nblock = 16; -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ CompareOp compare_op; -+ -+ for (int row_block = 0; row_block < N; row_block += Nblock) { -+ for (int col_block = 0; col_block < N; col_block += Nblock) { -+ -+ ComputeType accum[Nblock][Nblock]; -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Nblock; i++) { -+ accum[i][j] = initial_accum; -+ } -+ } -+ -+ for (int k_block = 0; k_block < K; ++k_block) { -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Nblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ if (row < N && col < N && compare_op(row, col)) -+ { -+ -+ // A x B^T -+ ElementA a = tensor_a.at(MatrixCoord(row, k_block)); -+ ElementB b_t = tensor_b.at(MatrixCoord(col, k_block)); -+ -+ ComputeType compute_a(cast_if_scalar(a)); -+ ComputeType compute_b_t(cast_if_scalar(b_t)); -+ -+ accum[i][j] = inner_product_op(compute_a, compute_b_t, accum[i][j]); -+ -+ // B x A^T -+ ElementB b = tensor_b.at(MatrixCoord(row, k_block)); -+ ElementA a_t = tensor_a.at(MatrixCoord(col, k_block)); -+ -+ ComputeType compute_b(cast_if_scalar(b)); -+ ComputeType compute_a_t(cast_if_scalar(a_t)); -+ -+ accum[i][j] = inner_product_op(compute_b, compute_a_t, accum[i][j]); -+ } -+ } -+ } -+ } -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Nblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ MatrixCoord coord = MatrixCoord(row, col); -+ -+ if (row < N && col < N && -+ ( (FillModeC == FillMode::kLower && row >= col) || -+ (FillModeC == FillMode::kUpper && row <= col) ) -+ ) { -+ tensor_d.at(coord) = convert_op( -+ alpha * ScalarType(accum[i][j]) + -+ beta * ScalarType(tensor_c.at(coord))); -+ } -+ } -+ } -+ } -+ } -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general Rank 2k update (tensors of rank=2) pointed to by TensorRef -+/// objects. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ FillMode FillModeC, -+ typename ScalarType, -+ typename ComputeType, -+ typename InnerProductOp = multiply_add, -+ typename ConvertOp = NumericConverter -+> -+void compute_rank2k( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, -+ ScalarType beta, -+ TensorRef tensor_c, -+ ComputeType initial_accum) { -+ compute_rank2k( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c, -+ initial_accum); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ FillMode FillModeC, -+ typename ScalarType, -+ typename ComputeType, -+ typename InnerProductOp = cutlass::arch::OpMultiplyAdd -+> -+struct Rank2K; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for multiply-add -+template -+struct Rank2K { -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ ComputeType initial_accum = ComputeType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_rank2k>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); -+ } -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ComputeType initial_accum = ComputeType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_rank2k>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); -+ } -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k_complex.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k_complex.h -new file mode 100644 -index 0000000..519379c ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k_complex.h -@@ -0,0 +1,318 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Reference implementation for complex-valued Rank 2K update in host-side code. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/complex.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/gemm/gemm.h" -+#include -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+/// -+/// Explicitly naming types needed by this template can be cumbersome, particularly for the -+/// accumulator type, so a function argument 'initial_accum' is exposed. Passing -+/// AccumulatorType(0) as the last function argument can be easier than naming all template -+/// arguments explicitly. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+void Rank2KComplex( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ ComplexTransform transform_a, -+ TensorRef tensor_b, -+ ComplexTransform transform_b, -+ ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ComputeType initial_accum, -+ FillMode fill_mode_c, -+ BlasMode blas_mode, -+ int batch_count = 1, -+ int64_t batch_stride_A = 0, -+ int64_t batch_stride_B = 0, -+ int64_t batch_stride_C = 0, -+ int64_t batch_stride_D = 0) { -+ -+ static_assert( -+ LayoutA::kRank == 2 && -+ LayoutB::kRank == 2 && -+ LayoutC::kRank == 2, "Tensors must be of rank 2"); -+ -+ // Note: batch is ignored. -+ int const M = problem_size.m(); -+ int const N = problem_size.n(); -+ int const K = problem_size.k(); -+ -+ // Rank2K update operates on A=NxK, B=NxK, and C=NxN -+ assert(M==N); -+ -+ // Blocking necessary to speedup reference implementation -+ int const Mblock = 16; -+ int const Nblock = 16; -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ -+ for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) { -+ -+ // Compute matrix product using blocks -+ for (int row_block = 0; row_block < M; row_block += Mblock) { -+ for (int col_block = 0; col_block < N; col_block += Nblock) { -+ -+ ComputeType accum[Mblock][Nblock]; -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ accum[i][j] = initial_accum; -+ } -+ } -+ -+ for (int k_block = 0; k_block < K; ++k_block) { -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ if (row < M && col < N && -+ ( (fill_mode_c == FillMode::kLower && row >= col) || -+ (fill_mode_c == FillMode::kUpper && row <= col) ) -+ ) { -+ -+ // A x B^T (Symmetric) or A x B^H (Hermitian) -+ // complex conjugation on operandB (b_t) is function of blas3 computation -+ ElementA a = tensor_a.at(MatrixCoord(row, k_block)); -+ ElementB b_t = (blas_mode == BlasMode::kHermitian) ? -+ conj(tensor_b.at(MatrixCoord(col, k_block))) : -+ tensor_b.at(MatrixCoord(col, k_block)); -+ -+ ComputeType a_ik = ComputeType(a); -+ ComputeType b_jk = ComputeType(b_t); -+ -+ // complex conjugation is a function of operand layouts -+ if (transform_a == ComplexTransform::kConjugate) { -+ a_ik = conj(a_ik); -+ } -+ // complex conjugation is a function of operand layouts -+ if (transform_b == ComplexTransform::kConjugate) { -+ b_jk = conj(b_jk); -+ } -+ -+ accum[i][j] = inner_product_op(a_ik, b_jk, accum[i][j]); -+ } -+ } -+ } -+ } -+ -+ /* HER2K need two epilogues to handle complex alpha value */ -+ if ( blas_mode == BlasMode::kHermitian ) { -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ MatrixCoord coord = MatrixCoord(row, col); -+ -+ if (row < M && col < N && -+ ((fill_mode_c == FillMode::kLower && row >= col) || -+ (fill_mode_c == FillMode::kUpper && row <= col)) -+ ) { -+ -+ ScalarType c = tensor_c.at(coord); -+ // The imaginary parts of the diagonal elements of -+ // a complex data type are assumed and set to zero -+ if (blas_mode == BlasMode::kHermitian) { -+ c = (row == col) ? real(c) : c; -+ } -+ -+ tensor_d.at(coord) = convert_op(alpha * -+ ScalarType(accum[i][j]) + -+ beta * c); -+ } -+ } -+ } -+ -+ /* Zeoring out accum for second HERK */ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ accum[i][j] = initial_accum; -+ } -+ } -+ } -+ -+ for (int k_block = 0; k_block < K; ++k_block) { -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ if (row < M && col < N && -+ ( (fill_mode_c == FillMode::kLower && row >= col) || -+ (fill_mode_c == FillMode::kUpper && row <= col) ) -+ ) { -+ -+ // B x A^T (Symmetric) or B x A^H (Hermitian) -+ // complex conjugation on operandB (a_t) is function of blas3 computation -+ ElementB b = tensor_b.at(MatrixCoord(row, k_block)); -+ ElementA a_t = (blas_mode == BlasMode::kHermitian) ? -+ conj(tensor_a.at(MatrixCoord(col, k_block))): -+ tensor_a.at(MatrixCoord(col, k_block)); -+ -+ ComputeType b_ik = ComputeType(b); -+ ComputeType a_jk = ComputeType(a_t); -+ -+ // complex conjugation here is a function of operand layouts -+ if (transform_b == ComplexTransform::kConjugate) { -+ b_ik = conj(b_ik); -+ } -+ // complex conjugation here is a function of operand layouts -+ if (transform_a == ComplexTransform::kConjugate) { -+ a_jk = conj(a_jk); -+ } -+ -+ accum[i][j] = inner_product_op(b_ik, a_jk, accum[i][j]); -+ } -+ } -+ } -+ } -+ -+ ScalarType alpha_hermitian = (blas_mode == BlasMode::kHermitian) ? -+ conj(alpha) : alpha; -+ ScalarType beta_hermitian = (blas_mode == BlasMode::kHermitian) ? -+ 1 : beta; -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ MatrixCoord coord = MatrixCoord(row, col); -+ -+ if (row < M && col < N && -+ ((fill_mode_c == FillMode::kLower && row >= col) || -+ (fill_mode_c == FillMode::kUpper && row <= col)) -+ ) { -+ -+ ScalarType d = (blas_mode == BlasMode::kHermitian) ? -+ tensor_d.at(coord) : tensor_c.at(coord); -+ -+ ScalarType tmp_d = convert_op( -+ alpha_hermitian * ScalarType(accum[i][j]) + -+ beta_hermitian * d); -+ -+ if (blas_mode == BlasMode::kHermitian && row == col ) { -+ tensor_d.at(coord) = real(tmp_d); -+ } else { -+ tensor_d.at(coord) = tmp_d; -+ } -+ } -+ } -+ } -+ -+ } // for (col_block) -+ } // for (row_block) -+ -+ tensor_a.add_pointer_offset(batch_stride_A); -+ tensor_b.add_pointer_offset(batch_stride_B); -+ tensor_c.add_pointer_offset(batch_stride_C); -+ tensor_d.add_pointer_offset(batch_stride_D); -+ -+ } // for (batch_idx) -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+/// -+/// This assumes the accumulator type is the same type as the scalars. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType -+> -+void Rank2KComplex( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ ComplexTransform transform_a, -+ TensorRef tensor_b, -+ ComplexTransform transform_b, -+ ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ FillMode fill_mode_c, -+ BlasMode blas_mode) { -+ -+ Rank2KComplex( -+ problem_size, alpha, -+ tensor_a, transform_a, -+ tensor_b, transform_b, -+ beta, tensor_c, tensor_d, -+ ScalarType(0), -+ fill_mode_c, -+ blas_mode); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/rank_k_complex.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/rank_k_complex.h -new file mode 100644 -index 0000000..d5f3f2e ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/rank_k_complex.h -@@ -0,0 +1,234 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Reference implementation for complex-valued Rank 2K update in host-side code. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/complex.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/gemm/gemm.h" -+#include -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+/// -+/// Explicitly naming types needed by this template can be cumbersome, particularly for the -+/// accumulator type, so a function argument 'initial_accum' is exposed. Passing -+/// AccumulatorType(0) as the last function argument can be easier than naming all template -+/// arguments explicitly. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename ConvertOp = NumericConverter, -+ typename InnerProductOp = multiply_add -+> -+void Rank2KComplex( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ ComplexTransform transform_a, -+ ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ComputeType initial_accum, -+ FillMode fill_mode_c, -+ BlasMode blas_mode, -+ int batch_count = 1, -+ int64_t batch_stride_A = 0, -+ int64_t batch_stride_C = 0, -+ int64_t batch_stride_D = 0) { -+ -+ static_assert( -+ LayoutA::kRank == 2 && -+ LayoutC::kRank == 2, "Tensors must be of rank 2"); -+ -+ // Note: batch is ignored. -+ int const M = problem_size.m(); -+ int const N = problem_size.n(); -+ int const K = problem_size.k(); -+ -+ // Rank2K update operates on A=NxK, B=NxK, and C=NxN -+ assert(M==N); -+ -+ // Blocking necessary to speedup reference implementation -+ int const Mblock = 16; -+ int const Nblock = 16; -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ -+ for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) { -+ -+ // Compute matrix product using blocks -+ for (int row_block = 0; row_block < M; row_block += Mblock) { -+ for (int col_block = 0; col_block < N; col_block += Nblock) { -+ -+ ComputeType accum[Mblock][Nblock]; -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ accum[i][j] = initial_accum; -+ } -+ } -+ -+ for (int k_block = 0; k_block < K; ++k_block) { -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ if (row < M && col < N && -+ ( (fill_mode_c == FillMode::kLower && row >= col) || -+ (fill_mode_c == FillMode::kUpper && row <= col) ) -+ ) { -+ -+ // A x A^T (Symmetric) or A x A^H (Hermitian) -+ // complex conjugation on operandB (a_t) (function of blas3 computation) -+ ElementA a = tensor_a.at(MatrixCoord(row, k_block)); -+ ElementA a_t = (blas_mode == BlasMode::kHermitian) ? -+ conj(tensor_a.at(MatrixCoord(col, k_block))) : -+ tensor_a.at(MatrixCoord(col, k_block)); -+ -+ ComputeType a_ik = ComputeType(a); -+ ComputeType b_jk = ComputeType(a_t); -+ -+ // complex conjugation (function of input layouts) -+ if (transform_a == ComplexTransform::kConjugate) { -+ a_ik = conj(a_ik); -+ } -+ // complex conjugation (function of input layouts) -+ if (transform_a == ComplexTransform::kConjugate) { -+ b_jk = conj(b_jk); -+ } -+ -+ accum[i][j] = inner_product_op(a_ik, b_jk, accum[i][j]); -+ -+ } -+ } -+ } -+ } -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ MatrixCoord coord = MatrixCoord(row, col); -+ -+ if (row < M && col < N && -+ ((fill_mode_c == FillMode::kLower && row >= col) || -+ (fill_mode_c == FillMode::kUpper && row <= col)) -+ ) { -+ -+ ScalarType c = tensor_c.at(coord); -+ // The imaginary parts of the diagonal elements of -+ // a complex data type are assumed and set to zero -+ if (blas_mode == BlasMode::kHermitian) { -+ c = (row == col) ? real(c) : c; -+ } -+ -+ ScalarType tmp_d = convert_op( -+ alpha * ScalarType(accum[i][j]) + -+ beta * c); -+ -+ if (blas_mode == BlasMode::kHermitian && row == col ) { -+ tensor_d.at(coord) = real(tmp_d); -+ } else { -+ tensor_d.at(coord) = tmp_d; -+ } -+ } -+ } -+ } -+ -+ } // for (col_block) -+ } // for (row_block) -+ -+ tensor_a.add_pointer_offset(batch_stride_A); -+ tensor_c.add_pointer_offset(batch_stride_C); -+ tensor_d.add_pointer_offset(batch_stride_D); -+ -+ } // for (batch_idx) -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+/// -+/// This assumes the accumulator type is the same type as the scalars. -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType -+> -+void RankKComplex( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ ComplexTransform transform_a, -+ ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ FillMode fill_mode_c, -+ BlasMode blas_mode) { -+ -+ Rank2KComplex( -+ problem_size, alpha, -+ tensor_a, transform_a, -+ beta, tensor_c, tensor_d, -+ ScalarType(0), -+ fill_mode_c, -+ blas_mode); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/symm.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/symm.h -new file mode 100644 -index 0000000..736107a ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/symm.h -@@ -0,0 +1,285 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Reference implementation for SYMM update in host-side code. -+ -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/numeric_conversion.h" -+ -+#include "cutlass/tensor_view.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/arch/mma.h" -+#include "cutlass/util/host_tensor.h" -+#include "cutlass/util/reference/host/gemm.h" -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+template < -+ typename ElementA, -+ typename LayoutA, -+ SideMode SideModeA, -+ FillMode FillModeA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename InnerProductOp = multiply_add, -+ typename ConvertOp = NumericConverter -+> -+void compute_symm( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, -+ ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ComputeType initial_accum) { -+ -+ static_assert( -+ LayoutA::kRank == 2 && -+ LayoutB::kRank == 2 && -+ LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ static_assert(SideModeA != SideMode::kInvalid -+ , "Side Mode can either be Left or Right."); -+ -+ static_assert( -+ FillModeA == FillMode::kLower || -+ FillModeA == FillMode::kUpper, -+ "Fill Mode can either be Lower or Upper."); -+ -+ using CompareOp_w_diag = typename TrMatrixCompareOp::Type; -+ using CompareOp_wo_diag = typename TrMatrixCompareOp::Type; -+ -+ // Note: batch is ignored. -+ int const M = problem_size.m(); -+ int const N = problem_size.n(); -+ // Assuming correct k-dimension value is passed -+ int const K = problem_size.k(); -+ -+ // Blocking necessary to speedup reference implementation -+ int const Mblock = 16; -+ int const Nblock = 16; -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ CompareOp_w_diag compare_op_1; -+ CompareOp_wo_diag compare_op_2; -+ -+ for (int row_block = 0; row_block < M; row_block += Mblock) { -+ for (int col_block = 0; col_block < N; col_block += Nblock) { -+ -+ ComputeType accum[Mblock][Nblock]; -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ accum[i][j] = initial_accum; -+ } -+ } -+ -+ for (int k_block = 0; k_block < K; ++k_block) { -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ if (row < M && col < N) { -+ ElementA a_1 = ElementA(); -+ ElementB b_1 = ElementB(); -+ ElementA a_2 = ElementA(); -+ ElementB b_2 = ElementB(); -+ -+ // A x B or B x A (with diagonal) -+ if (SideModeA == SideMode::kLeft) { -+ a_1 = (compare_op_1(row, k_block)) ? -+ (tensor_a.at(MatrixCoord(row, k_block))) : ElementA(); -+ b_1 = tensor_b.at(MatrixCoord(k_block, col)); -+ } else if (SideModeA == SideMode::kRight) { -+ a_1 = tensor_b.at(MatrixCoord(row, k_block)); -+ b_1 = (compare_op_1(k_block, col)) ? -+ tensor_a.at(MatrixCoord(k_block, col)) : ElementA(); -+ } -+ -+ ComputeType compute_a_1(cast_if_scalar(a_1)); -+ ComputeType compute_b_1(cast_if_scalar(b_1)); -+ -+ accum[i][j] = inner_product_op(compute_a_1, compute_b_1, accum[i][j]); -+ -+ // A^T x B or B x A^T (without diagonal) -+ if (SideModeA == SideMode::kLeft) { -+ a_2 = (compare_op_2(k_block, row)) ? -+ (tensor_a.at(MatrixCoord(k_block, row))) : ElementA(); -+ b_2 = tensor_b.at(MatrixCoord(k_block, col)); -+ } else if (SideModeA == SideMode::kRight) { -+ a_2 = tensor_b.at(MatrixCoord(row, k_block)); -+ b_2 = (compare_op_2(col, k_block)) ? -+ tensor_a.at(MatrixCoord(col, k_block)) : ElementA(); -+ } -+ -+ ComputeType compute_a_2(cast_if_scalar(a_2)); -+ ComputeType compute_b_2(cast_if_scalar(b_2)); -+ -+ accum[i][j] = inner_product_op(compute_a_2, compute_b_2, accum[i][j]); -+ } -+ } -+ } -+ } -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ MatrixCoord coord = MatrixCoord(row, col); -+ -+ if (row < M && col < N) { -+ tensor_d.at(coord) = convert_op( -+ alpha * ScalarType(accum[i][j]) + -+ beta * ScalarType(tensor_c.at(coord))); -+ } -+ } -+ } -+ } -+ } -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general Symm update (tensors of rank=2) pointed to by TensorRef -+/// objects. -+template < -+ typename ElementA, -+ typename LayoutA, -+ SideMode SideModeA, -+ FillMode FillModeA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename InnerProductOp = multiply_add, -+ typename ConvertOp = NumericConverter -+> -+void compute_symm( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, -+ ScalarType beta, -+ TensorRef tensor_c, -+ ComputeType initial_accum) { -+ compute_symm( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c, -+ initial_accum); -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA, -+ typename LayoutA, -+ SideMode SideModeA, -+ FillMode FillModeA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename InnerProductOp = cutlass::arch::OpMultiplyAdd -+> -+struct Symm; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for multiply-add -+template -+struct Symm { -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ ComputeType initial_accum = ComputeType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_symm>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); -+ } -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ComputeType initial_accum = ComputeType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_symm>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); -+ } -+}; -+ -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/symm_complex.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/symm_complex.h -new file mode 100644 -index 0000000..aa46891 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/symm_complex.h -@@ -0,0 +1,319 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Reference implementation for complex-valued SYMM update in host-side code. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/complex.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/gemm/gemm.h" -+#include -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef -+/// objects. -+/// -+/// Explicitly naming types needed by this template can be cumbersome, particularly for the -+/// accumulator type, so a function argument 'initial_accum' is exposed. Passing -+/// AccumulatorType(0) as the last function argument can be easier than naming all template -+/// arguments explicitly. -+template < -+ typename ElementA, -+ typename LayoutA, -+ SideMode SideModeA, -+ FillMode FillModeA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ BlasMode BlasMode_ = BlasMode::kSymmetric, -+ typename InnerProductOp = multiply_add, -+ typename ConvertOp = NumericConverter -+> -+void compute_symm_complex( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, -+ ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ComputeType initial_accum, -+ int batch_count = 1, -+ int64_t batch_stride_A = 0, -+ int64_t batch_stride_B = 0, -+ int64_t batch_stride_C = 0, -+ int64_t batch_stride_D = 0) { -+ -+ static SideMode const kSideModeA = SideModeA; -+ static FillMode const kFillModeA = FillModeA; -+ static BlasMode const kBlasMode = BlasMode_; -+ -+ static_assert( -+ LayoutA::kRank == 2 && -+ LayoutB::kRank == 2 && -+ LayoutC::kRank == 2, "Tensors must be of rank 2"); -+ -+ static_assert(kSideModeA != SideMode::kInvalid -+ , "Side Mode can either be Left or Right."); -+ -+ static_assert( -+ kFillModeA == FillMode::kLower || -+ kFillModeA == FillMode::kUpper, -+ "Fill Mode can either be Lower or Upper."); -+ -+ using CompareOp_w_diag = typename TrMatrixCompareOp::Type; -+ using CompareOp_wo_diag = typename TrMatrixCompareOp::Type; -+ -+ // Note: batch is ignored. -+ int const M = problem_size.m(); -+ int const N = problem_size.n(); -+ // Assuming correct k-dimension value is passed -+ int const K = problem_size.k(); -+ -+ // Blocking necessary to speedup reference implementation -+ int const Mblock = 16; -+ int const Nblock = 16; -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ CompareOp_w_diag compare_op_1; -+ CompareOp_wo_diag compare_op_2; -+ -+ for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) { -+ -+ // Compute matrix product using blocks -+ for (int row_block = 0; row_block < M; row_block += Mblock) { -+ for (int col_block = 0; col_block < N; col_block += Nblock) { -+ -+ ComputeType accum[Mblock][Nblock]; -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ accum[i][j] = initial_accum; -+ } -+ } -+ -+ for (int k_block = 0; k_block < K; ++k_block) { -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ if (row < M && col < N) -+ { -+ ElementA a_1 = ElementA(); -+ ElementB b_1 = ElementB(); -+ ElementA a_2 = ElementA(); -+ ElementB b_2 = ElementB(); -+ -+ // A x B or B x A (with diagonal) -+ if (kSideModeA == SideMode::kLeft) { -+ a_1 = (compare_op_1(row, k_block)) ? -+ (tensor_a.at(MatrixCoord(row, k_block))) : ElementA(); -+ b_1 = tensor_b.at(MatrixCoord(k_block, col)); -+ } else if (kSideModeA == SideMode::kRight) { -+ a_1 = tensor_b.at(MatrixCoord(row, k_block)); -+ b_1 = (compare_op_1(k_block, col)) ? -+ tensor_a.at(MatrixCoord(k_block, col)) : ElementA(); -+ } -+ ComputeType compute_a_1 = ComputeType(a_1); -+ ComputeType compute_b_1 = ComputeType(b_1); -+ -+ // The imaginary parts of the diagonal elements of -+ // a complex data type are assumed and set to zero -+ if (kBlasMode == BlasMode::kHermitian && kSideModeA == SideMode::kLeft && row == k_block) { -+ compute_a_1 = real(compute_a_1); -+ } else if (kBlasMode == BlasMode::kHermitian && kSideModeA == SideMode::kRight && k_block == col) { -+ compute_b_1 = real(compute_b_1); -+ } -+ -+ accum[i][j] = inner_product_op(compute_a_1, compute_b_1, accum[i][j]); -+ -+ // A^T x B or B x A^T (without diagonal) -+ if (kSideModeA == SideMode::kLeft) { -+ a_2 = (compare_op_2(k_block, row)) ? -+ (tensor_a.at(MatrixCoord(k_block, row))) : ElementA(); -+ b_2 = tensor_b.at(MatrixCoord(k_block, col)); -+ if (kBlasMode == BlasMode::kHermitian) -+ a_2 = conj(a_2); -+ } else if (kSideModeA == SideMode::kRight) { -+ a_2 = tensor_b.at(MatrixCoord(row, k_block)); -+ b_2 = (compare_op_2(col, k_block)) ? -+ tensor_a.at(MatrixCoord(col, k_block)) : ElementA(); -+ if (kBlasMode == BlasMode::kHermitian) -+ b_2 = conj(b_2); -+ } -+ -+ ComputeType compute_a_2 = ComputeType(a_2); -+ ComputeType compute_b_2 = ComputeType(b_2); -+ -+ accum[i][j] = inner_product_op(compute_a_2, compute_b_2, accum[i][j]); -+ } -+ } -+ } -+ } -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ MatrixCoord coord = MatrixCoord(row, col); -+ -+ if (row < M && col < N) { -+ -+ ScalarType c = tensor_c.at(coord); -+ -+ tensor_d.at(coord) = convert_op( -+ alpha * ScalarType(accum[i][j]) + -+ beta * c); -+ } -+ } -+ } -+ -+ } // for (col_block) -+ } // for (row_block) -+ -+ tensor_a.add_pointer_offset(batch_stride_A); -+ tensor_b.add_pointer_offset(batch_stride_B); -+ tensor_c.add_pointer_offset(batch_stride_C); -+ tensor_d.add_pointer_offset(batch_stride_D); -+ -+ } // for (batch_idx) -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA, -+ typename LayoutA, -+ SideMode SideModeA, -+ FillMode FillModeA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ BlasMode BlasMode_ = cutlass::BlasMode::kSymmetric, -+ typename InnerProductOp = cutlass::arch::OpMultiplyAddComplex -+> -+struct SymmComplex; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for multiply-add -+template -+struct SymmComplex { -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ComputeType initial_accum = ComputeType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_symm_complex>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for gaussian multiply-add -+template -+struct SymmComplex { -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, ScalarType beta, -+ TensorRef tensor_c, -+ TensorRef tensor_d, -+ ComputeType initial_accum = ComputeType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_symm_complex>( -+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.h -new file mode 100644 -index 0000000..f9a362e ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.h -@@ -0,0 +1,305 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines host-side elementwise operations on TensorView. -+*/ -+ -+#pragma once -+ -+// Standard Library includes -+#include -+ -+// Cutlass includes -+#include "cutlass/cutlass.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/tensor_view_planar_complex.h" -+ -+#include "cutlass/util/distribution.h" -+//#include "cutlass/util/type_traits.h" -+#include "tensor_foreach.h" -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorEqualsFunc { -+ -+ // -+ // Data members -+ // -+ -+ TensorView lhs; -+ TensorView rhs; -+ bool result; -+ -+ /// Ctor -+ TensorEqualsFunc(): result(true) { } -+ -+ /// Ctor -+ TensorEqualsFunc( -+ TensorView const &lhs_, -+ TensorView const &rhs_ -+ ) : -+ lhs(lhs_), rhs(rhs_), result(true) { } -+ -+ /// Visits a coordinate -+ void operator()(Coord const &coord) { -+ -+ Element lhs_ = lhs.at(coord); -+ Element rhs_ = rhs.at(coord); -+ -+ if (lhs_ != rhs_) { -+ result = false; -+ } -+ } -+ -+ /// Returns true if equal -+ operator bool() const { -+ return result; -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Returns true if two tensor views are equal. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+bool TensorEquals( -+ TensorView const &lhs, -+ TensorView const &rhs) { -+ -+ // Extents must be identical -+ if (lhs.extent() != rhs.extent()) { -+ return false; -+ } -+ -+ detail::TensorEqualsFunc func(lhs, rhs); -+ TensorForEach( -+ lhs.extent(), -+ func -+ ); -+ -+ return bool(func); -+} -+ -+/// Returns true if two tensor views are equal. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+bool TensorEquals( -+ TensorViewPlanarComplex const &lhs, -+ TensorViewPlanarComplex const &rhs) { -+ -+ // Extents must be identical -+ if (lhs.extent() != rhs.extent()) { -+ return false; -+ } -+ -+ detail::TensorEqualsFunc real_func( -+ {lhs.data(), lhs.layout(), lhs.extent()}, -+ {rhs.data(), rhs.layout(), rhs.extent()} -+ ); -+ -+ TensorForEach( -+ lhs.extent(), -+ real_func -+ ); -+ -+ if (!bool(real_func)) { -+ return false; -+ } -+ -+ detail::TensorEqualsFunc imag_func( -+ {lhs.data() + lhs.imaginary_stride(), lhs.layout(), lhs.extent()}, -+ {rhs.data() + rhs.imaginary_stride(), rhs.layout(), rhs.extent()} -+ ); -+ -+ TensorForEach( -+ lhs.extent(), -+ imag_func -+ ); -+ -+ return bool(imag_func); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Returns true if two tensor views are NOT equal. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+bool TensorNotEquals( -+ TensorView const &lhs, -+ TensorView const &rhs) { -+ -+ // Extents must be identical -+ if (lhs.extent() != rhs.extent()) { -+ return true; -+ } -+ -+ detail::TensorEqualsFunc func(lhs, rhs); -+ TensorForEach( -+ lhs.extent(), -+ func -+ ); -+ -+ return !bool(func); -+} -+ -+/// Returns true if two tensor views are equal. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+bool TensorNotEquals( -+ TensorViewPlanarComplex const &lhs, -+ TensorViewPlanarComplex const &rhs) { -+ -+ return !TensorEquals(lhs, rhs); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorContainsFunc { -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ Element value; -+ bool contains; -+ Coord location; -+ -+ // -+ // Methods -+ // -+ -+ /// Ctor -+ TensorContainsFunc(): contains(false) { } -+ -+ /// Ctor -+ TensorContainsFunc( -+ TensorView const &view_, -+ Element value_ -+ ) : -+ view(view_), value(value_), contains(false) { } -+ -+ /// Visits a coordinate -+ void operator()(Coord const &coord) { -+ -+ if (view.at(coord) == value) { -+ if (!contains) { -+ location = coord; -+ } -+ contains = true; -+ } -+ } -+ -+ /// Returns true if equal -+ operator bool() const { -+ return contains; -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Returns true if a value is present in a tensor -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+bool TensorContains( -+ TensorView const & view, -+ Element value) { -+ -+ detail::TensorContainsFunc func( -+ view, -+ value -+ ); -+ -+ TensorForEach( -+ view.extent(), -+ func -+ ); -+ -+ return bool(func); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Returns a pair containing a boolean of whether a value exists in a tensor and the location of -+/// of the first occurrence. If the value is not contained in the tensor, the second element of the -+/// pair is undefined. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+std::pair > TensorFind( -+ TensorView const & view, -+ Element value) { -+ -+ detail::TensorContainsFunc func( -+ view, -+ value -+ ); -+ -+ TensorForEach( -+ view.extent(), -+ func -+ ); -+ -+ return std::make_pair(bool(func), func.location); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.hpp b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.hpp -new file mode 100644 -index 0000000..a4a5b4e ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.hpp -@@ -0,0 +1,101 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Provides several functions for filling tensors with data. -+*/ -+ -+#pragma once -+ -+// Standard Library includes -+#include -+#include -+#include -+ -+// Cute includes -+#include "cute/tensor.hpp" -+ -+// Cutlass includes -+#include "cutlass/cutlass.h" -+#include "cutlass/complex.h" -+#include "cutlass/quaternion.h" -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Returns true if two tensor views are equal. -+template < -+ typename TensorL, -+ typename TensorR -+> -+bool TensorEquals( -+ TensorL lhs, -+ TensorR rhs) { -+ -+ // Extents must be identical -+ if (cute::size(lhs) != cute::size(rhs)) { -+ return false; -+ } -+ -+ for (int64_t idx = 0; idx < cute::size(lhs); ++idx) { -+ if (lhs(idx) != rhs(idx)) { -+ return false; -+ } -+ } -+ -+ return true; -+} -+ -+/// Returns true if two tensor views are NOT equal. -+template < -+ typename TensorL, -+ typename TensorR -+> -+bool TensorNotEquals( -+ TensorL lhs, -+ TensorR rhs) { -+ -+ return TensorEquals(lhs, rhs); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_copy.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_copy.h -new file mode 100644 -index 0000000..053511c ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_copy.h -@@ -0,0 +1,256 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines host-side elementwise operations on TensorView. -+*/ -+ -+#pragma once -+ -+// Standard Library includes -+#include -+ -+// Cutlass includes -+#include "cutlass/cutlass.h" -+#include "tensor_foreach.h" -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+/// Helper to convert between types -+template < -+ typename DstElement, -+ typename SrcElement -+> -+struct TrivialConvert { -+ -+ TrivialConvert() { } -+ -+ DstElement operator()(SrcElement src) const { -+ return DstElement(src); -+ } -+}; -+ -+/// Helper to conditionally copy between tensor views. -+template < -+ typename DstElement, -+ typename DstLayout, -+ typename SrcElement, -+ typename SrcLayout, -+ typename F -+> -+struct TensorCopyIf { -+ -+ using DstTensorView = TensorView; -+ using SrcTensorView = TensorView; -+ -+ // -+ // Data members -+ // -+ -+ DstTensorView dst; -+ SrcTensorView src; -+ F convert; -+ -+ // -+ // Methods -+ // -+ -+ TensorCopyIf() { } -+ -+ TensorCopyIf( -+ DstTensorView const &dst_, -+ SrcTensorView const &src_, -+ F const &convert_): dst(dst_), src(src_), convert(convert_) {} -+ -+ /// Copies based on destination and source bounds -+ void operator()(Coord const &coord) { -+ if (dst.contains(coord) && src.contains(coord)) { -+ dst.at(coord) = convert(src.at(coord)); -+ } -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Copies elements from one tensor view into another, satisfying bounds of each tensor. -+template < -+ typename DstElement, /// Destination tensor's element type -+ typename DstLayout, /// Destination tensor's layout -+ typename SrcElement, /// Source tensor's element type -+ typename SrcLayout, /// Source tensor's layout -+ typename F /// Transformation functor -+> -+void TensorCopy( -+ TensorView dst, -+ TensorView src, -+ F const &transform) { -+ -+ using CopyIf = detail::TensorCopyIf< -+ DstElement, -+ DstLayout, -+ SrcElement, -+ SrcLayout, -+ F>; -+ -+ CopyIf copy_if(dst, src, transform); -+ -+ TensorForEach(dst.extent(), copy_if); -+} -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Copies elements from a TensorRef into a TensorView. Assumes source tensor has sufficient extent -+/// to avoid out of bounds accesses. -+template < -+ typename DstElement, /// Destination tensor's element type -+ typename DstLayout, /// Destination tensor's layout -+ typename SrcElement, /// Source tensor's element type -+ typename SrcLayout, /// Source tensor's layout -+ typename F /// Transformation functor -+> -+void TensorCopy( -+ TensorView dst, -+ TensorRef src, -+ F const &transform) { -+ -+ using CopyIf = detail::TensorCopyIf< -+ DstElement, -+ DstLayout, -+ SrcElement, -+ SrcLayout, -+ F>; -+ -+ TensorView src_view(src, dst.extent()); -+ -+ CopyIf copy_if(dst, src_view, transform); -+ -+ TensorForEach(dst.extent(), copy_if); -+} -+ -+/// Copies elements from a TensorRef into a TensorView. Assumes source tensor has sufficient extent -+/// to avoid out of bounds accesses. -+template < -+ typename DstElement, /// Destination tensor's element type -+ typename DstLayout, /// Destination tensor's layout -+ typename SrcElement, /// Source tensor's element type -+ typename SrcLayout, /// Source tensor's layout -+ typename F /// Transformation functor -+> -+void TensorCopy( -+ TensorRef dst, -+ TensorView src, -+ F const &transform) { -+ -+ using CopyIf = detail::TensorCopyIf< -+ DstElement, -+ DstLayout, -+ SrcElement, -+ SrcLayout, -+ F>; -+ -+ TensorView dst_view(dst, src.extent()); -+ -+ CopyIf copy_if(dst_view, src, transform); -+ -+ TensorForEach(src.extent(), copy_if); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Copies elements from one tensor view into another, satisfying bounds of each tensor. Succeeds -+/// if SrcElement can be converted to DstElement. -+template < -+ typename DstElement, /// Destination tensor's element type -+ typename DstLayout, /// Destination tensor's layout -+ typename SrcElement, /// Source tensor's element type -+ typename SrcLayout /// Source tensor's layout -+> -+void TensorCopy( -+ TensorView dst, -+ TensorView src) { -+ -+ detail::TrivialConvert convert; -+ -+ TensorCopy(dst, src, convert); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Copies elements from one tensor view into another, satisfying bounds of each tensor. Succeeds -+/// if SrcElement can be converted to DstElement. -+template < -+ typename DstElement, /// Destination tensor's element type -+ typename DstLayout, /// Destination tensor's layout -+ typename SrcElement, /// Source tensor's element type -+ typename SrcLayout, /// Source tensor's layout -+ typename F /// Transformation functor -+> -+void TensorCopy( -+ TensorView dst, -+ TensorRef src) { -+ -+ detail::TrivialConvert convert; -+ -+ TensorCopy(dst, src, convert); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Copies elements from one tensor view into another, satisfying bounds of each tensor. Succeeds -+/// if SrcElement can be converted to DstElement. -+template < -+ typename DstElement, /// Destination tensor's element type -+ typename DstLayout, /// Destination tensor's layout -+ typename SrcElement, /// Source tensor's element type -+ typename SrcLayout /// Source tensor's layout -+> -+void TensorCopy( -+ TensorRef dst, -+ TensorView src) { -+ -+ detail::TrivialConvert convert; -+ -+ TensorCopy(dst, src, convert); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_elementwise.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_elementwise.h -new file mode 100644 -index 0000000..72f5f24 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_elementwise.h -@@ -0,0 +1,341 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Defines host-side elementwise operations on TensorView. -+*/ -+ -+#pragma once -+ -+// Cutlass includes -+#include "cutlass/cutlass.h" -+#include "cutlass/functional.h" -+ -+#include "tensor_foreach.h" -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Helper to apply a binary operator in place -+template < -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementD, -+ typename LayoutD, -+ typename BinaryFunc> -+struct TensorFuncBinaryOp { -+ -+ // -+ // Data members -+ // -+ -+ /// View of left-hand-side tensor -+ TensorView view_d; -+ TensorRef view_a; -+ TensorRef view_b; -+ BinaryFunc func; -+ -+ // -+ // Methods -+ // -+ -+ /// Constructor -+ TensorFuncBinaryOp() { } -+ -+ /// Constructor -+ TensorFuncBinaryOp( -+ TensorView const & view_d_, -+ TensorRef const & view_a_, -+ TensorRef const & view_b_, -+ BinaryFunc func = BinaryFunc() -+ ): -+ view_d(view_d_), view_a(view_a_), view_b(view_b_), func(func) { } -+ -+ /// Equality check -+ void operator()(Coord const &coord) const { -+ view_d.at(coord) = func( -+ ElementD(view_a.at(coord)), -+ ElementD(view_b.at(coord)) -+ ); -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Adds two tensors and stores in the destination tensor: d = a + b -+template < -+ typename ElementD, -+ typename LayoutD, -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB -+> -+void TensorAdd( -+ TensorView d, ///< destination tensor view -+ TensorRef a, ///< A tensor reference -+ TensorRef b ///< B tensor reference -+) { -+ -+ detail::TensorFuncBinaryOp< -+ ElementD, -+ LayoutD, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ cutlass::plus -+ > func(d, a, b); -+ -+ TensorForEach( -+ d.extent(), -+ func); -+} -+ -+/// Adds a tensor in place: d = d .+ a -+template < -+ typename ElementD, -+ typename LayoutD, -+ typename ElementA, -+ typename LayoutA -+> -+void TensorAdd( -+ TensorView d, ///< destination tensor view -+ TensorRef a ///< A tensor reference -+) { -+ TensorAdd(d, d, a); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Subtracts two tensors and stores in the destination tensor: d = a - b -+template < -+ typename ElementD, -+ typename LayoutD, -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB -+> -+void TensorSub( -+ TensorView d, ///< destination tensor view -+ TensorRef a, ///< A tensor reference -+ TensorRef b ///< B tensor reference -+ ) { -+ -+ detail::TensorFuncBinaryOp< -+ ElementD, -+ LayoutD, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ cutlass::minus -+ > func(d, a, b); -+ -+ TensorForEach( -+ d.extent(), -+ func); -+} -+ -+/// Subtracts two tensors in place: d = d .- a -+template < -+ typename ElementD, -+ typename LayoutD, -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB -+> -+void TensorSub( -+ TensorView d, ///< destination tensor view -+ TensorRef a ///< A tensor reference -+ ) { -+ -+ TensorSub(d, d, a); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Multiplies two tensors and stores in the destination tensor: d = a .* b -+template < -+ typename ElementD, -+ typename LayoutD, -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB -+> -+void TensorMul( -+ TensorView d, ///< destination tensor view -+ TensorRef a, ///< A tensor reference -+ TensorRef b ///< B tensor reference -+) { -+ -+ detail::TensorFuncBinaryOp< -+ ElementD, -+ LayoutD, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ cutlass::multiplies -+ > func(d, a, b); -+ -+ TensorForEach( -+ d.extent(), -+ func); -+} -+ -+/// Multiplies tensors in place: d = d .* a -+template < -+ typename ElementD, -+ typename LayoutD, -+ typename ElementA, -+ typename LayoutA -+> -+void TensorMul( -+ TensorView d, ///< destination tensor view -+ TensorRef a ///< A tensor reference -+) { -+ TensorMul(d, d, a); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Divides two tensors and stores in the destination tensor: d = a ./ b -+template < -+ typename ElementD, -+ typename LayoutD, -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB -+> -+void TensorDiv( -+ TensorView d, ///< destination tensor view -+ TensorRef a, ///< A tensor reference -+ TensorRef b ///< B tensor reference -+) { -+ -+ detail::TensorFuncBinaryOp< -+ ElementD, -+ LayoutD, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ cutlass::divides -+ > func(d, a, b); -+ -+ TensorForEach( -+ d.extent(), -+ func); -+} -+ -+/// Divides tensors in place: d = d ./ a -+template < -+ typename ElementD, -+ typename LayoutD, -+ typename ElementA, -+ typename LayoutA -+> -+void TensorDiv( -+ TensorView d, ///< destination tensor view -+ TensorRef a ///< A tensor reference -+) { -+ TensorDiv(d, d, a); -+} -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Divides two tensors and stores in the destination tensor: d = a ./ b -+template < -+ typename ElementD, -+ typename LayoutD, -+ typename ElementA, -+ typename LayoutA, -+ typename ElementB, -+ typename LayoutB -+> -+void TensorModulus( -+ TensorView d, ///< destination tensor view -+ TensorRef a, ///< A tensor reference -+ TensorRef b ///< B tensor reference -+) { -+ -+ detail::TensorFuncBinaryOp< -+ ElementD, -+ LayoutD, -+ ElementA, -+ LayoutA, -+ ElementB, -+ LayoutB, -+ cutlass::divides -+ > func(d, a, b); -+ -+ TensorForEach( -+ d.extent(), -+ func); -+} -+ -+/// Divides tensors in place: d = d ./ a -+template < -+ typename ElementD, -+ typename LayoutD, -+ typename ElementA, -+ typename LayoutA -+> -+void TensorModulus( -+ TensorView d, ///< destination tensor view -+ TensorRef a ///< A tensor reference -+) { -+ TensorDiv(d, d, a); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.h -new file mode 100644 -index 0000000..a8b938d ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.h -@@ -0,0 +1,1468 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Provides several functions for filling tensors with data. -+*/ -+ -+#pragma once -+ -+// Standard Library includes -+#include -+#include -+#include -+ -+// Cutlass includes -+#include "cutlass/cutlass.h" -+#include "cutlass/complex.h" -+#include "cutlass/quaternion.h" -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+#include "cutlass/subbyte_reference.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/tensor_view_planar_complex.h" -+#include "cutlass/blas3.h" -+ -+#include "cutlass/util/distribution.h" -+#include "tensor_foreach.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorFillFunc { -+ -+ using TensorView = TensorView; -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ Element value; -+ -+ // -+ // Methods -+ // -+ -+ TensorFillFunc( -+ TensorView const &view_ = TensorView(), -+ Element value_ = Element(0) -+ ): view(view_), value(value_) { } -+ -+ void operator()(Coord const & coord) const { -+ view.at(coord) = value; -+ } -+}; -+ -+/// Returns a pair of values of the Gaussian distribution generated by the Box Muller method -+struct BoxMullerFunc { -+ -+ BoxMullerFunc() {} -+ -+ void operator()( -+ double* rnd, ///< Size-2 vector to be filled with random values -+ double mean = 0, ///< Mean of the Gaussian distribution -+ double stddev = 1, ///< Standard deviation of the Gaussian distribution -+ double pi = std::acos(-1)) const { -+ -+ double u1 = double(std::rand()) / double(RAND_MAX); -+ double u2 = double(std::rand()) / double(RAND_MAX); -+ rnd[0] = std::sqrt(-2 * std::log(u1)) * std::cos(2 * pi * u2); -+ rnd[1] = std::sqrt(-2 * std::log(u1)) * std::sin(2 * pi * u2); -+ rnd[0] = mean + stddev * rnd[0]; -+ rnd[1] = mean + stddev * rnd[1]; -+ } -+}; -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor with a uniform value -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFill( -+ TensorView dst, ///< destination tensor -+ Element val = Element(0)) { ///< value to uniformly fill it with -+ -+ detail::TensorFillFunc func(dst, val); -+ -+ TensorForEach( -+ dst.extent(), -+ func -+ ); -+} -+ -+/// Fills a tensor with a uniform value -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFill( -+ TensorViewPlanarComplex dst, ///< destination tensor -+ cutlass::complex val = cutlass::complex(0)) { ///< value to uniformly fill it with -+ -+ TensorFill(dst.view_real(), val.real()); -+ TensorFill(dst.view_imag(), val.imag()); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template -+struct RandomGaussianFunc { -+ -+ uint64_t seed; -+ double mean; -+ double stddev; -+ int int_scale; -+ double pi; -+ -+ // -+ // Methods -+ // -+ RandomGaussianFunc( -+ uint64_t seed_ = 0, -+ double mean_ = 0, -+ double stddev_ = 1, -+ int int_scale_ = -1 -+ ): -+ seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)) { -+ std::srand((unsigned)seed); -+ } -+ -+ /// Compute random value and update RNG state -+ Element operator()() const { -+ -+ // Box-Muller transform to generate random numbers with Normal distribution -+ double u1 = double(std::rand()) / double(RAND_MAX); -+ double u2 = double(std::rand()) / double(RAND_MAX); -+ -+ // Compute Gaussian random value -+ double rnd = std::sqrt(-2 * std::log(u1)) * std::cos(2 * pi * u2); -+ rnd = mean + stddev * rnd; -+ -+ // Scale and convert final result -+ Element result; -+ -+ if (int_scale >= 0) { -+ rnd = double(int64_t(rnd * double(1 << int_scale))) / double(1 << int_scale); -+ result = static_cast(rnd); -+ } -+ else { -+ result = static_cast(rnd); -+ } -+ -+ return result; -+ } -+}; -+ -+/// Partial specialization for initializing a complex value. -+template -+struct RandomGaussianFunc > { -+ -+ uint64_t seed; -+ double mean; -+ double stddev; -+ int int_scale; -+ double pi; -+ -+ // -+ // Methods -+ // -+ RandomGaussianFunc( -+ uint64_t seed_ = 0, -+ double mean_ = 0, -+ double stddev_ = 1, -+ int int_scale_ = -1 -+ ): -+ seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)) { -+ std::srand((unsigned)seed); -+ } -+ -+ /// Compute random value and update RNG state -+ complex operator()() const { -+ -+ Element reals[2]; -+ -+ double rnd[2]; -+ detail::BoxMullerFunc func; -+ func(rnd, mean, stddev, pi); -+ -+ if (int_scale >= 0) { -+ rnd[0] = double(int(rnd[0] * double(1 << int_scale))); -+ rnd[1] = double(int(rnd[1] * double(1 << int_scale))); -+ reals[0] = from_real(rnd[0] / double(1 << int_scale)); -+ reals[1] = from_real(rnd[1] / double(1 << int_scale)); -+ } else { -+ reals[0] = from_real(rnd[0]); -+ reals[1] = from_real(rnd[1]); -+ } -+ -+ return complex(reals[0], reals[1]); -+ } -+}; -+ -+/// Partial specialization for initializing a complex value. -+template -+struct RandomGaussianFunc > { -+ -+ uint64_t seed; -+ double mean; -+ double stddev; -+ int int_scale; -+ double pi; -+ -+ // -+ // Methods -+ // -+ RandomGaussianFunc( -+ uint64_t seed_ = 0, -+ double mean_ = 0, -+ double stddev_ = 1, -+ int int_scale_ = -1 -+ ): -+ seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)) { -+ std::srand((unsigned)seed); -+ } -+ -+ /// Compute random value and update RNG state -+ Quaternion operator()() const { -+ -+ Element reals[4]; -+ -+ double rnd1[2]; -+ double rnd2[2]; -+ detail::BoxMullerFunc func; -+ func(rnd1, mean, stddev, pi); -+ func(rnd2, mean, stddev, pi); -+ -+ if (int_scale >= 0) { -+ rnd1[0] = double(int(rnd1[0] * double(1 << int_scale))); -+ rnd1[1] = double(int(rnd1[1] * double(1 << int_scale))); -+ rnd2[0] = double(int(rnd2[0] * double(1 << int_scale))); -+ rnd2[1] = double(int(rnd2[1] * double(1 << int_scale))); -+ -+ reals[0] = from_real(rnd1[0] / double(1 << int_scale)); -+ reals[1] = from_real(rnd1[1] / double(1 << int_scale)); -+ reals[2] = from_real(rnd2[0] / double(1 << int_scale)); -+ reals[3] = from_real(rnd2[1] / double(1 << int_scale)); -+ } else { -+ reals[0] = from_real(rnd1[0]); -+ reals[1] = from_real(rnd1[1]); -+ reals[2] = from_real(rnd2[0]); -+ reals[3] = from_real(rnd2[1]); -+ } -+ -+ return Quaternion(reals[0], reals[1], reals[2], reals[3]); -+ } -+}; -+ -+/// Computes a random Gaussian distribution -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorFillGaussianFunc { -+ -+ using TensorView = TensorView; -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ RandomGaussianFunc func; -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of Gaussian RNG functor. -+ TensorFillGaussianFunc( -+ TensorView view_ = TensorView(), -+ RandomGaussianFunc func_ = RandomGaussianFunc() -+ ): -+ view(view_), func(func_) { -+ -+ } -+ -+ /// Compute random value and update RNG state -+ void operator()(Coord const &coord) const { -+ view.at(coord) = func(); -+ } -+}; -+ -+/// Computes a random Gaussian distribution for a rank-2 tensor -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorFillSymmetricGaussianFunc { -+ -+ using TensorView = TensorView; -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ RandomGaussianFunc func; -+ cutlass::FillMode fill_mode; -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of Gaussian RNG functor. -+ TensorFillSymmetricGaussianFunc( -+ TensorView view_ = TensorView(), -+ RandomGaussianFunc func_ = RandomGaussianFunc(), -+ cutlass::FillMode fill_mode_ = cutlass::FillMode::kInvalid -+ ): -+ view(view_), func(func_), fill_mode(fill_mode_) { -+ -+ } -+ -+ /// Compute random value and update RNG state -+ void operator()(Coord const &coord) const { -+ // Fill half of matrix based on FillMode -+ if (Layout::kRank == 2 && -+ fill_mode == cutlass::FillMode::kLower && -+ coord[0] >= coord[1]) { -+ view.at(coord) = func(); -+ } else if (Layout::kRank == 2 && -+ fill_mode == cutlass::FillMode::kUpper && -+ coord[0] <= coord[1]) { -+ view.at(coord) = func(); -+ } -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor with random values with a Gaussian distribution. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillRandomGaussian( -+ TensorView dst, ///< destination tensor -+ uint64_t seed, ///< seed for RNG -+ double mean = 0, ///< Gaussian distribution's mean -+ double stddev = 1, ///< Gaussian distribution's standard deviation -+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that -+ /// are not truncated to zero. Permits reducing precision of -+ /// data. -+ -+ detail::RandomGaussianFunc random_func(seed, mean, stddev, bits); -+ -+ detail::TensorFillGaussianFunc func( -+ dst, -+ random_func -+ ); -+ -+ TensorForEach( -+ dst.extent(), -+ func -+ ); -+} -+ -+/// Fills a tensor with random values with a Gaussian distribution. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillRandomGaussian( -+ TensorViewPlanarComplex dst, ///< destination tensor -+ uint64_t seed, ///< seed for RNG -+ double mean = 0, ///< Gaussian distribution's mean -+ double stddev = 1, ///< Gaussian distribution's standard deviation -+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that -+ /// are not truncated to zero. Permits reducing precision of -+ /// data. -+ -+ TensorFillRandomGaussian(dst.view_real(), seed, mean, stddev, bits); -+ TensorFillRandomGaussian(dst.view_imag(), ~seed, mean, stddev, bits); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/// Fills the upper or lower part of a symmetric rank-2 tensor with random values of a Gaussian distribution. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillSymmetricRandomGaussian( -+ TensorView dst, ///< destination tensor -+ uint64_t seed, ///< seed for RNG -+ cutlass::FillMode fill_mode, ///< FillMode for symmetric matrices -+ double mean = 0, ///< Gaussian distribution's mean -+ double stddev = 1, ///< Gaussian distribution's standard deviation -+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that -+ /// are not truncated to zero. Permits reducing precision of -+ /// data. -+ -+ detail::RandomGaussianFunc random_func(seed, mean, stddev, bits); -+ -+ detail::TensorFillSymmetricGaussianFunc func( -+ dst, -+ random_func, -+ fill_mode -+ ); -+ -+ TensorForEach( -+ dst.extent(), -+ func -+ ); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor with random values of a Gaussian distribution. -+template < -+ typename Element ///< Element type -+> -+void BlockFillRandomGaussian( -+ Element *ptr, ///< destination buffer -+ size_t capacity, ///< number of elements -+ uint64_t seed, ///< seed for RNG -+ double mean = 0, ///< Gaussian distribution's mean -+ double stddev = 1, ///< Gaussian distribution's standard deviation -+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that -+ /// are not truncated to zero. Permits reducing precision of -+ /// data. -+ -+ -+ detail::RandomGaussianFunc random_func(seed, mean, stddev, bits); -+ -+ for (size_t i = 0; i < capacity; ++i) { -+ ReferenceFactory::get(ptr, i) = random_func(); -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template -+struct RandomUniformFunc { -+ -+ using Real = typename RealType::Type; -+ -+ uint64_t seed; -+ double range; -+ double min; -+ int int_scale; -+ -+ // -+ // Methods -+ // -+ -+ RandomUniformFunc( -+ uint64_t seed_ = 0, -+ double max = 1, -+ double min_ = 0, -+ int int_scale_ = -1 -+ ): -+ seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) { -+ std::srand((unsigned)seed); -+ } -+ -+ -+ /// Compute random value and update RNG state -+ Element operator()() const { -+ -+ double rnd = double(std::rand()) / double(RAND_MAX); -+ -+ rnd = min + range * rnd; -+ -+ // Random values are cast to integer after scaling by a power of two to facilitate error -+ // testing -+ Element result; -+ -+ if (int_scale >= 0) { -+ rnd = double(int64_t(rnd * double(1 << int_scale))) / double(1 << int_scale); -+ result = static_cast(Real(rnd)); -+ } -+ else { -+ result = static_cast(Real(rnd)); -+ } -+ -+ return result; -+ } -+}; -+ -+/// Partial specialization for initializing a complex value. -+template -+struct RandomUniformFunc > { -+ -+ using Real = typename RealType::Type; -+ -+ uint64_t seed; -+ double range; -+ double min; -+ int int_scale; -+ -+ // -+ // Methods -+ // -+ -+ RandomUniformFunc( -+ uint64_t seed_ = 0, -+ double max = 1, -+ double min_ = 0, -+ int int_scale_ = -1 -+ ): -+ seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) { -+ std::srand((unsigned)seed); -+ } -+ -+ -+ /// Compute random value and update RNG state -+ complex operator()() const { -+ -+ Element reals[2]; -+ -+ for (int i = 0; i < 2; ++i) { -+ double rnd = double(std::rand()) / double(RAND_MAX); -+ -+ rnd = min + range * rnd; -+ -+ // Random values are cast to integer after scaling by a power of two to facilitate error -+ // testing -+ -+ if (int_scale >= 0) { -+ rnd = double(int(rnd * double(1 << int_scale))); -+ reals[i] = from_real(Real(rnd / double(1 << int_scale))); -+ } -+ else { -+ reals[i] = from_real(Real(rnd)); -+ } -+ } -+ -+ return complex(reals[0], reals[1]); -+ } -+}; -+ -+/// Partial specialization for initializing a Quaternion value. -+template -+struct RandomUniformFunc > { -+ -+ using Real = typename RealType::Type; -+ -+ uint64_t seed; -+ double range; -+ double min; -+ int int_scale; -+ -+ // -+ // Methods -+ // -+ -+ RandomUniformFunc( -+ uint64_t seed_ = 0, -+ double max = 1, -+ double min_ = 0, -+ int int_scale_ = -1 -+ ): -+ seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) { -+ std::srand((unsigned)seed); -+ } -+ -+ -+ /// Compute random value and update RNG state -+ Quaternion operator()() const { -+ -+ Element reals[4]; -+ -+ for (int i = 0; i < 4; ++i) { -+ double rnd = double(std::rand()) / double(RAND_MAX); -+ -+ rnd = min + range * rnd; -+ -+ // Random values are cast to integer after scaling by a power of two to facilitate error -+ // testing -+ -+ if (int_scale >= 0) { -+ rnd = double(int(rnd * double(1 << int_scale))); -+ reals[i] = from_real(Real(rnd / double(1 << int_scale))); -+ } -+ else { -+ reals[i] = from_real(Real(rnd)); -+ } -+ } -+ -+ return make_Quaternion(reals[0], reals[1], reals[2], reals[3]); -+ } -+}; -+ -+/// Computes a random uniform distribution -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorFillRandomUniformFunc { -+ -+ using TensorView = TensorView; -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ RandomUniformFunc func; -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of uniform RNG functor. -+ TensorFillRandomUniformFunc( -+ TensorView view_ = TensorView(), -+ RandomUniformFunc func_ = RandomUniformFunc() -+ ): -+ view(view_), func(func_) { -+ -+ } -+ -+ /// Compute random value and update RNG state -+ void operator()(Coord const &coord) const { -+ -+ view.at(coord) = func(); -+ } -+}; -+ -+/// Fills the upper or lower part of a symmetric rank-2 tensor with random values of a uniform distribution. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorFillSymmetricRandomUniformFunc { -+ -+ using TensorView = TensorView; -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ RandomUniformFunc func; -+ cutlass::FillMode fill_mode; -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of uniform RNG functor. -+ TensorFillSymmetricRandomUniformFunc( -+ TensorView view_ = TensorView(), -+ RandomUniformFunc func_ = RandomUniformFunc(), -+ cutlass::FillMode fill_mode_ = cutlass::FillMode::kInvalid -+ ): -+ view(view_), func(func_), fill_mode(fill_mode_) { -+ -+ } -+ -+ /// Compute random value and update RNG state -+ void operator()(Coord const &coord) const { -+ // Fill half of matrix based on FillMode -+ if (Layout::kRank == 2 && -+ fill_mode == cutlass::FillMode::kLower && -+ coord[0] >= coord[1]) { -+ view.at(coord) = func(); -+ } else if (Layout::kRank == 2 && -+ fill_mode == cutlass::FillMode::kUpper && -+ coord[0] <= coord[1]) { -+ view.at(coord) = func(); -+ } -+ } -+}; -+ -+/// Computes a random Uniform distribution and pads diagonal with zeros -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorFillPadDiagonalRandomUniformFunc { -+ -+ using TensorView = TensorView; -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ RandomUniformFunc func; -+ cutlass::FillMode fill_mode; -+ int alignment; -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of uniform RNG functor. -+ TensorFillPadDiagonalRandomUniformFunc( -+ TensorView view_ = TensorView(), -+ RandomUniformFunc func_ = RandomUniformFunc(), -+ cutlass::FillMode fill_mode_ = cutlass::FillMode::kInvalid, -+ int alignment_ = 1 -+ ): -+ view(view_), func(func_), fill_mode(fill_mode_), alignment(alignment_) { -+ -+ } -+ -+ /// Compute random value and update RNG state -+ void operator()(Coord const &coord) const { -+ // Fill half of matrix based on FillMode -+ if (Layout::kRank == 2 && -+ (fill_mode == cutlass::FillMode::kLower) && -+ (coord[0] >= coord[1]) || -+ ((coord[1] - coord[0]) >= alignment)) { -+ view.at(coord) = func(); -+ } else if (Layout::kRank == 2 && -+ fill_mode == cutlass::FillMode::kUpper && -+ (coord[0] <= coord[1]) || -+ ((coord[0] - coord[1]) >= alignment)) { -+ view.at(coord) = func(); -+ } -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor with random values of a uniform random distribution. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillRandomUniform( -+ TensorView dst, ///< destination tensor -+ uint64_t seed, ///< seed for RNG -+ double max = 1, ///< upper bound of distribution -+ double min = 0, ///< lower bound for distribution -+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that -+ /// are not truncated to zero. Permits reducing precision of -+ /// data. -+ detail::RandomUniformFunc random_func(seed, max, min, bits); -+ -+ detail::TensorFillRandomUniformFunc func( -+ dst, -+ random_func -+ ); -+ -+ TensorForEach( -+ dst.extent(), -+ func -+ ); -+} -+ -+/// Fills a tensor with random values of a uniform random distribution. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillRandomUniform( -+ TensorViewPlanarComplex dst, ///< destination tensor -+ uint64_t seed, ///< seed for RNG -+ double max = 1, ///< upper bound of distribution -+ double min = 0, ///< lower bound for distribution -+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that -+ /// are not truncated to zero. Permits reducing precision of -+ /// data. -+ -+ TensorFillRandomUniform(dst.view_real(), seed, max, min, bits); -+ TensorFillRandomUniform(dst.view_imag(), ~seed, max, min, bits); -+} -+ -+ -+/// Fills a tensor with random values with a uniform random distribution. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillRandomUniform( -+ TensorView, Layout> dst, ///< destination tensor -+ uint64_t seed, ///< seed for RNG -+ double max = 1, ///< upper bound of distribution -+ double min = 0, ///< lower bound for distribution -+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that -+ /// are not truncated to zero. Permits reducing precision of -+ /// data. -+ detail::RandomUniformFunc> random_func(seed, max, min, bits); -+ -+ detail::TensorFillRandomUniformFunc, Layout> func( -+ dst, -+ random_func -+ ); -+ -+ TensorForEach( -+ dst.extent(), -+ func -+ ); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor with random values with a uniform random distribution. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillSymmetricRandomUniform( -+ TensorView dst, ///< destination tensor -+ uint64_t seed, ///< seed for RNG -+ cutlass::FillMode fill_mode, ///< FillMode for symmetric matrices -+ double max = 1, ///< upper bound of distribution -+ double min = 0, ///< lower bound for distribution -+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that -+ /// are not truncated to zero. Permits reducing precision of -+ /// data. -+ -+ detail::RandomUniformFunc random_func(seed, max, min, bits); -+ -+ detail::TensorFillSymmetricRandomUniformFunc func( -+ dst, -+ random_func, -+ fill_mode -+ ); -+ -+ TensorForEach( -+ dst.extent(), -+ func -+ ); -+} -+ -+/// Fills a tensor with random values with a uniform random distribution pads zeros along diagonal -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillPadDiagonalRandomUniform( -+ TensorView dst, ///< destination tensor -+ uint64_t seed, ///< seed for RNG -+ cutlass::FillMode fill_mode, ///< FillMode for symmetric matrices -+ double max = 1, ///< upper bound of distribution -+ double min = 0, ///< lower bound for distribution -+ int bits = -1, ///< If non-negative, specifies number of fractional bits that -+ /// are not truncated to zero. Permits reducing precision of -+ /// data. -+ int alignment = 1 -+) { -+ -+ detail::RandomUniformFunc random_func(seed, max, min, bits); -+ -+ detail::TensorFillPadDiagonalRandomUniformFunc func( -+ dst, -+ random_func, -+ fill_mode, -+ alignment -+ ); -+ -+ TensorForEach( -+ dst.extent(), -+ func -+ ); -+} -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor with random values with a uniform random distribution. -+template < -+ typename Element ///< Element type -+> -+void BlockFillRandomUniform( -+ Element *ptr, -+ size_t capacity, -+ uint64_t seed, ///< seed for RNG -+ double max = 1, ///< upper bound of distribution -+ double min = 0, ///< lower bound for distribution -+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that -+ /// are not truncated to zero. Permits reducing precision of -+ /// data. -+ detail::RandomUniformFunc random_func(seed, max, min, bits); -+ -+ for (size_t i = 0; i < capacity; ++i) { -+ ReferenceFactory::get(ptr, i) = random_func(); -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorFillDiagonalFunc { -+ -+ using TensorView = TensorView; -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ Element diag; -+ Element other; -+ -+ // -+ // Methods -+ // -+ -+ TensorFillDiagonalFunc( -+ TensorView const &view_ = TensorView(), -+ Element diag_ = Element(1), -+ Element other_ = Element(0) -+ ): -+ view(view_), diag(diag_), other(other_) { } -+ -+ void operator()(Coord const & coord) const { -+ bool is_diag = true; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 1; i < Layout::kRank; ++i) { -+ if (coord[i] != coord[i - 1]) { -+ is_diag = false; -+ break; -+ } -+ } -+ -+ view.at(coord) = (is_diag ? diag : other); -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor everywhere with a unique value for its diagonal. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillDiagonal( -+ TensorView dst, ///< destination tensor -+ Element diag = Element(1), ///< value to write in the diagonal -+ Element other = Element(0)) { ///< value to write off the diagonal -+ -+ detail::TensorFillDiagonalFunc func( -+ dst, -+ diag, -+ other -+ ); -+ -+ TensorForEach( -+ dst.extent(), -+ func -+ ); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Helper to fill a tensor's digonal with 1 and 0 everywhere else. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillIdentity( -+ TensorView dst) { ///< destination tensor -+ -+ TensorFillDiagonal(dst, Element(1), Element(0)); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Writes a uniform value to the diagonal of a tensor without modifying off-diagonal elements. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorUpdateDiagonal( -+ TensorView dst, ///< destination tensor -+ Element val = Element(1)) { -+ -+ typename Layout::Index extent = dst.extent().min(); -+ -+ for (typename Layout::Index i = 0; i < extent; ++i) { -+ Coord coord(i); -+ dst.at(coord) = val; -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorUpdateOffDiagonalFunc { -+ -+ using TensorView = TensorView; -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ Element other; -+ -+ // -+ // Methods -+ // -+ -+ TensorUpdateOffDiagonalFunc( -+ TensorView const &view_ = TensorView(), -+ Element other_ = Element(0) -+ ): -+ view(view_), other(other_) { } -+ -+ void operator()(Coord const & coord) const { -+ bool is_diag = true; -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 1; i < Layout::kRank; ++i) { -+ if (coord[i] != coord[i - 1]) { -+ is_diag = false; -+ break; -+ } -+ } -+ -+ if (!is_diag) { -+ view.at(coord) = other; -+ } -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Writes a uniform value to all elements in the tensor without modifying diagonal elements. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorUpdateOffDiagonal( -+ TensorView dst, ///< destination tensor -+ Element other = Element(1)) { -+ -+ detail::TensorUpdateOffDiagonalFunc func( -+ dst, -+ other -+ ); -+ -+ TensorForEach( -+ dst.extent(), -+ func -+ ); -+} -+ -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorFillLinearFunc { -+ -+ using TensorView = TensorView; -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ Array v; -+ Element s; -+ -+ // -+ // Methods -+ // -+ -+ TensorFillLinearFunc() { } -+ -+ /// Constructs functor -+ TensorFillLinearFunc( -+ TensorView const &view_, -+ Array const & v_, -+ Element s_ = Element(0) -+ ): -+ view(view_), v(v_), s(s_) { } -+ -+ /// Updates the tensor -+ void operator()(Coord const & coord) const { -+ -+ Element sum(s); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 0; i < Layout::kRank; ++i) { -+ sum += Element(coord[i]) * v[i]; -+ } -+ -+ view.at(coord) = sum; -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills tensor with a linear combination of its coordinate and another vector -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillLinear( -+ TensorView dst, ///< destination tensor -+ Array const & v, -+ Element s = Element(0)) { -+ -+ detail::TensorFillLinearFunc func( -+ dst, -+ v, -+ s -+ ); -+ -+ TensorForEach( -+ dst.extent(), -+ func -+ ); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills tensor with a linear combination of its coordinate and another vector -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillSequential( -+ TensorView dst, ///< destination tensor -+ Element s = Element(0)) { -+ -+ Array stride; -+ -+ stride[0] = Element(1); -+ -+ CUTLASS_PRAGMA_UNROLL -+ for (int i = 1; i < Layout::kRank; ++i) { -+ stride[i] = stride[i - 1] * Element(dst.extent()[i - 1]); -+ } -+ -+ TensorFillLinear(dst, stride, s); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a block of data with sequential elements -+template < -+ typename Element -+> -+void BlockFillSequential( -+ Element *ptr, -+ int64_t capacity, -+ Element v = Element(1), -+ Element s = Element(0)) { -+ int i = 0; -+ -+ while (i < capacity) { -+ cutlass::ReferenceFactory::value < -+ 8)>::get(ptr, i) = s; -+ -+ s = Element(s + v); -+ ++i; -+ } -+} -+ -+/// Fills a block of data with sequential elements -+template < -+ typename Element -+> -+void BlockFillSequentialModN( -+ Element *ptr, -+ int64_t capacity, -+ int64_t mod, -+ int64_t v = int64_t(1), -+ int64_t s = int64_t(0)) { -+ int i = 0; -+ -+ while (i < capacity) { -+ cutlass::ReferenceFactory::value < -+ 8)>::get(ptr, i) = Element(s); -+ -+ s = int64_t(s + v) % mod; -+ ++i; -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a block of data with sequential elements -+template < -+ typename Element -+> -+void BlockFillRandom( -+ Element *ptr, -+ size_t capacity, -+ uint64_t seed, -+ Distribution dist) { -+ -+ if (dist.kind == Distribution::Gaussian) { -+ BlockFillRandomGaussian( -+ ptr, -+ capacity, -+ seed, -+ dist.gaussian.mean, -+ dist.gaussian.stddev, -+ dist.int_scale); -+ } -+ else if (dist.kind == Distribution::Uniform) { -+ BlockFillRandomUniform( -+ ptr, -+ capacity, -+ seed, -+ dist.uniform.max, -+ dist.uniform.min, -+ dist.int_scale); -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template -+struct RandomSparseMetaFunc { -+ -+ uint64_t seed; -+ int range; -+ int MetaSizeInBits; -+ -+ // -+ // Methods -+ // -+ -+ RandomSparseMetaFunc( -+ uint64_t seed_ = 0, -+ int MetaSizeInBits_ = 2 -+ ): -+ seed(seed_), MetaSizeInBits(MetaSizeInBits_) { -+ std::srand((unsigned)seed); -+ if (MetaSizeInBits_ == 2) { -+ range = 6; -+ } else if (MetaSizeInBits_ == 4) { -+ range = 2; -+ } -+ } -+ -+ /// Compute random value and update RNG state -+ Element operator()() const { -+ Element FourToTwoMeta[6] = {0x4, 0x8, 0x9, 0xc, 0xd, 0xe}; -+ Element TwoToOneMeta[2] = {0x4, 0xe}; -+ -+ Element * MetaArray = (MetaSizeInBits == 2) ? FourToTwoMeta : TwoToOneMeta; -+ -+ Element result = 0x0; -+ -+ for (int i = 0; i < cutlass::sizeof_bits::value / 4; ++i) { -+ int rnd = std::rand() % range; -+ Element meta = MetaArray[rnd]; -+ -+ result = (Element)(result | ((Element)(meta << (i * 4)))); -+ } -+ -+ return result; -+ } -+}; -+ -+/// Computes a random sparse meta -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+struct TensorFillRandomSparseMetaFunc { -+ -+ using TensorView = TensorView; -+ -+ // -+ // Data members -+ // -+ -+ TensorView view; -+ RandomSparseMetaFunc func; -+ -+ // -+ // Methods -+ // -+ -+ /// Construction of Gaussian RNG functor. -+ TensorFillRandomSparseMetaFunc( -+ TensorView view_ = TensorView(), -+ RandomSparseMetaFunc func_ = RandomSparseMetaFunc() -+ ): -+ view(view_), func(func_) { -+ -+ } -+ -+ /// Compute random value and update RNG state -+ void operator()(Coord const &coord) const { -+ -+ view.at(coord) = func(); -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor with random values with a uniform random distribution. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillRandomSparseMeta( -+ TensorView dst, ///< destination tensor -+ uint64_t seed, ///< seed for RNG -+ int MetaSizeInBits) { ///< 2 bit or 4 bit -+ -+ detail::RandomSparseMetaFunc random_func(seed, MetaSizeInBits); -+ -+ detail::TensorFillRandomSparseMetaFunc func( -+ dst, -+ random_func -+ ); -+ -+ TensorForEach( -+ dst.extent(), -+ func -+ ); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor with random values with a uniform random distribution. -+template < -+ typename Element ///< Element type -+> -+void BlockFillRandomSparseMeta( -+ Element *ptr, -+ size_t capacity, -+ uint64_t seed, ///< seed for RNG -+ int MetaSizeInBits) { ///< 2 bit or 4bit -+ -+ detail::RandomSparseMetaFunc random_func(seed, MetaSizeInBits); -+ -+ for (size_t i = 0; i < capacity; ++i) { -+ ptr[i] = random_func(); -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a ell block index matrix with random values with a uniform random distribution. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorFillRandomEllIdx( -+ TensorView dst, ///< destination tensor -+ uint64_t seed, ///< seed for RNG -+ int rows, int ell_cols, int cols) { ///< dimension of the matrix -+ -+ std::srand((unsigned)seed); -+ -+ for (int i = 0; i < rows; ++i) { -+ int col_idx = std::rand() % cols; -+ -+ for (int j = 0; j < ell_cols; ++j) { -+ dst.at({i, j}) = col_idx; -+ -+ if (col_idx != -1) { -+ if (col_idx == (cols - 1)) { -+ col_idx = -1; -+ } else { -+ col_idx = std::rand() % (cols - col_idx - 1) + col_idx + 1; -+ } -+ } -+ } -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Copies a diagonal in from host memory without modifying off-diagonal elements. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorCopyDiagonalIn( -+ TensorView dst, ///< destination tensor -+ Element const *ptr) { ///< dense buffer of elements -+ -+ typename Layout::Index extent = dst.extent().min(); -+ -+ for (typename Layout::Index i = 0; i < extent; ++i) { -+ Coord coord(i); -+ dst.at(coord) = ReferenceFactory::get(ptr, i); -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Copies the diagonal of a tensor into a dense buffer in host memory. -+template < -+ typename Element, ///< Element type -+ typename Layout> ///< Layout function -+void TensorCopyDiagonalOut( -+ Element *ptr, ///< dense buffer of elements -+ TensorView src) { ///< source tensor -+ -+ typename Layout::Index extent = src.extent().min(); -+ -+ for (typename Layout::Index i = 0; i < extent; ++i) { -+ Coord coord(i); -+ ReferenceFactory::get(ptr, i) = src.at(coord); -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.hpp b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.hpp -new file mode 100644 -index 0000000..3262c53 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.hpp -@@ -0,0 +1,432 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Provides several functions for filling tensors with data. -+*/ -+ -+#pragma once -+ -+// Standard Library includes -+#include -+#include -+#include -+ -+// Cute includes -+#include "cute/tensor.hpp" -+ -+// Cutlass includes -+#include "cutlass/cutlass.h" -+#include "cutlass/complex.h" -+#include "cutlass/quaternion.h" -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Uniform and procedural tensor fills -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor with a scalar element -+template -+void TensorFill(Tensor dst, typename Tensor::value_type element) { -+ -+ for (int64_t idx = 0; idx < cute::size(dst); ++idx) { -+ dst(idx) = element; -+ } -+} -+ -+/// Fills a tensor with the contents of its layout -+template -+void TensorFillSequential(Tensor dst) { -+ -+ auto layout = dst.layout(); -+ -+ for (int64_t idx = 0; idx < cute::size(dst); ++idx) { -+ dst(idx) = layout(idx); -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Random uniform values -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template -+struct RandomUniformFunc { -+ -+ using Real = typename RealType::Type; -+ -+ uint64_t seed; -+ double range; -+ double min; -+ int int_scale; -+ -+ // -+ // Methods -+ // -+ -+ RandomUniformFunc( -+ uint64_t seed_ = 0, -+ double max = 1, -+ double min_ = 0, -+ int int_scale_ = -1 -+ ): -+ seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) { -+ std::srand((unsigned)seed); -+ } -+ -+ -+ /// Compute random value and update RNG state -+ Element operator()() const { -+ -+ double rnd = double(std::rand()) / double(RAND_MAX); -+ -+ rnd = min + range * rnd; -+ -+ // Random values are cast to integer after scaling by a power of two to facilitate error -+ // testing -+ Element result; -+ -+ if (int_scale >= 0) { -+ rnd = double(int64_t(rnd * double(1 << int_scale))) / double(1 << int_scale); -+ result = static_cast(Real(rnd)); -+ } -+ else { -+ result = static_cast(Real(rnd)); -+ } -+ -+ return result; -+ } -+}; -+ -+/// Partial specialization for initializing a complex value. -+template -+struct RandomUniformFunc > { -+ -+ using Real = typename RealType::Type; -+ -+ uint64_t seed; -+ double range; -+ double min; -+ int int_scale; -+ -+ // -+ // Methods -+ // -+ -+ RandomUniformFunc( -+ uint64_t seed_ = 0, -+ double max = 1, -+ double min_ = 0, -+ int int_scale_ = -1 -+ ): -+ seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) { -+ std::srand((unsigned)seed); -+ } -+ -+ -+ /// Compute random value and update RNG state -+ complex operator()() const { -+ -+ Element reals[2]; -+ -+ for (int i = 0; i < 2; ++i) { -+ double rnd = double(std::rand()) / double(RAND_MAX); -+ -+ rnd = min + range * rnd; -+ -+ // Random values are cast to integer after scaling by a power of two to facilitate error -+ // testing -+ -+ if (int_scale >= 0) { -+ rnd = double(int(rnd * double(1 << int_scale))); -+ reals[i] = from_real(Real(rnd / double(1 << int_scale))); -+ } -+ else { -+ reals[i] = from_real(Real(rnd)); -+ } -+ } -+ -+ return complex(reals[0], reals[1]); -+ } -+}; -+ -+/// Partial specialization for initializing a Quaternion value. -+template -+struct RandomUniformFunc > { -+ -+ using Real = typename RealType::Type; -+ -+ uint64_t seed; -+ double range; -+ double min; -+ int int_scale; -+ -+ // -+ // Methods -+ // -+ -+ RandomUniformFunc( -+ uint64_t seed_ = 0, -+ double max = 1, -+ double min_ = 0, -+ int int_scale_ = -1 -+ ): -+ seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) { -+ std::srand((unsigned)seed); -+ } -+ -+ -+ /// Compute random value and update RNG state -+ Quaternion operator()() const { -+ -+ Element reals[4]; -+ -+ for (int i = 0; i < 4; ++i) { -+ double rnd = double(std::rand()) / double(RAND_MAX); -+ -+ rnd = min + range * rnd; -+ -+ // Random values are cast to integer after scaling by a power of two to facilitate error -+ // testing -+ -+ if (int_scale >= 0) { -+ rnd = double(int(rnd * double(1 << int_scale))); -+ reals[i] = from_real(Real(rnd / double(1 << int_scale))); -+ } -+ else { -+ reals[i] = from_real(Real(rnd)); -+ } -+ } -+ -+ return make_Quaternion(reals[0], reals[1], reals[2], reals[3]); -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor with random values with a uniform random distribution. -+template ///< Tensor object -+void TensorFillRandomUniform( -+ Tensor dst, ///< destination tensor -+ uint64_t seed, ///< seed for RNG -+ double max = 1, ///< upper bound of distribution -+ double min = 0, ///< lower bound for distribution -+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that -+ /// are not truncated to zero. Permits reducing precision of -+ /// data. -+ -+ detail::RandomUniformFunc random_func(seed, max, min, bits); -+ -+ for (int64_t idx = 0; idx < cute::size(dst); ++idx) { -+ dst(idx) = random_func(); -+ } -+} -+ -+/// Fills a block with random values with a uniform random distribution. -+template < -+ typename Element ///< Element type -+> -+void BlockFillRandomUniform( -+ Element *ptr, -+ size_t capacity, -+ uint64_t seed, ///< seed for RNG -+ double max = 1, ///< upper bound of distribution -+ double min = 0, ///< lower bound for distribution -+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that -+ /// are not truncated to zero. Permits reducing precision of -+ /// data. -+ detail::RandomUniformFunc random_func(seed, max, min, bits); -+ -+ for (size_t i = 0; i < capacity; ++i) { -+ ptr[i] = random_func(); -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Random Gaussian -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+template -+struct RandomGaussianFunc { -+ -+ uint64_t seed; -+ double mean; -+ double stddev; -+ int int_scale; -+ double pi; -+ -+ // -+ // Methods -+ // -+ RandomGaussianFunc( -+ uint64_t seed_ = 0, -+ double mean_ = 0, -+ double stddev_ = 1, -+ int int_scale_ = -1 -+ ): -+ seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)) { -+ std::srand((unsigned)seed); -+ } -+ -+ /// Compute random value and update RNG state -+ Element operator()() const { -+ -+ // Box-Muller transform to generate random numbers with Normal distribution -+ double u1 = double(std::rand()) / double(RAND_MAX); -+ double u2 = double(std::rand()) / double(RAND_MAX); -+ -+ // Compute Gaussian random value -+ double rnd = std::sqrt(-2 * std::log(u1)) * std::cos(2 * pi * u2); -+ rnd = mean + stddev * rnd; -+ -+ // Scale and convert final result -+ Element result; -+ -+ if (int_scale >= 0) { -+ rnd = double(int64_t(rnd * double(1 << int_scale))) / double(1 << int_scale); -+ result = static_cast(rnd); -+ } -+ else { -+ result = static_cast(rnd); -+ } -+ -+ return result; -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a tensor with random values with a Gaussian distribution. -+template < -+ typename Tensor -+> -+void TensorFillRandomGaussian( -+ Tensor dst, ///< destination tensor -+ uint64_t seed, ///< seed for RNG -+ double mean = 0, ///< Gaussian distribution's mean -+ double stddev = 1, ///< Gaussian distribution's standard deviation -+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that -+ /// are not truncated to zero. Permits reducing precision of -+ /// data. -+ -+ detail::RandomGaussianFunc random_func(seed, mean, stddev, bits); -+ -+ for (int64_t idx = 0; idx < cute::size(dst); ++idx) { -+ dst(idx) = random_func(); -+ } -+} -+ -+/// Fills a block with random values with a Gaussian distribution. -+template < -+ typename Element ///< Element type -+> -+void BlockFillRandomGaussian( -+ Element *ptr, ///< destination buffer -+ size_t capacity, ///< number of elements -+ uint64_t seed, ///< seed for RNG -+ double mean = 0, ///< Gaussian distribution's mean -+ double stddev = 1, ///< Gaussian distribution's standard deviation -+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that -+ /// are not truncated to zero. Permits reducing precision of -+ /// data. -+ -+ detail::RandomGaussianFunc random_func(seed, mean, stddev, bits); -+ -+ for (size_t i = 0; i < capacity; ++i) { -+ ptr[i] = random_func(); -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Fills a block of data with sequential elements -+template < -+ typename Element -+> -+void BlockFillSequential( -+ Element *ptr, -+ int64_t capacity, -+ Element v = Element(1), -+ Element s = Element(0)) { -+ int i = 0; -+ -+ while (i < capacity) { -+ -+ ptr[i] = Element(s + v); -+ ++i; -+ } -+} -+ -+/// Fills a block of data with sequential elements -+template < -+ typename Element -+> -+void BlockFillSequentialModN( -+ Element *ptr, -+ int64_t capacity, -+ int64_t mod, -+ int64_t v = int64_t(1), -+ int64_t s = int64_t(0)) { -+ int i = 0; -+ -+ while (i < capacity) { -+ -+ ptr[i] = static_cast(int32_t(int64_t(s + v) % mod)); -+ ++i; -+ } -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_foreach.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_foreach.h -new file mode 100644 -index 0000000..a195893 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_foreach.h -@@ -0,0 +1,134 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+#include "cutlass/cutlass.h" -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Defines several helpers -+namespace detail { -+ -+/// Helper to perform for-each operation -+template -+struct TensorForEachHelper { -+ -+ /// Index of the active rank -+ static int const kActiveRank = Rank - RankRemaining - 1; -+ -+ /// Constructor for general rank -+ TensorForEachHelper( -+ Func &func, -+ Coord const &extent, -+ Coord &coord) { -+ -+ for (int i = 0; i < extent.at(kActiveRank); ++i) { -+ coord[kActiveRank] = i; -+ TensorForEachHelper(func, extent, coord); -+ } -+ } -+}; -+ -+/// Helper to perform for-each operation -+template -+struct TensorForEachHelper { -+ -+ /// Index of the active rank -+ static int const kActiveRank = Rank - 1; -+ -+ /// Constructor for fastest chaning rank -+ TensorForEachHelper( -+ Func &func, -+ Coord const &extent, -+ Coord &coord) { -+ -+ for (int i = 0; i < extent.at(kActiveRank); ++i) { -+ coord[kActiveRank] = i; -+ func(coord); -+ } -+ } -+}; -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Iterates over the index space of a tensor -+template < -+ typename Func, ///< function applied to each point in a tensor's index space -+ int Rank> ///< rank of index space -+void TensorForEach(Coord extent, Func & func) { -+ Coord coord; -+ detail::TensorForEachHelper(func, extent, coord); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Iterates over the index space of a tensor and calls a C++ lambda -+template < -+ typename Func, ///< function applied to each point in a tensor's index space -+ int Rank> ///< rank of index space -+void TensorForEachLambda(Coord extent, Func func) { -+ Coord coord; -+ detail::TensorForEachHelper(func, extent, coord); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+struct BlockForEach { -+ -+ /// Constructor performs the operation. -+ BlockForEach( -+ Element *ptr, -+ size_t capacity, -+ typename Func::Params params = typename Func::Params()) { -+ -+ Func func(params); -+ -+ for (size_t index = 0; index < capacity; ++index) { -+ ptr[index] = func(); -+ } -+ } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_norm.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_norm.h -new file mode 100644 -index 0000000..9d52b08 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_norm.h -@@ -0,0 +1,42 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+ -+#include "cutlass/cutlass.h" -+ -+// The contents of this file have been moved to 'tensor_reduce' to cover other types of reductions. -+ -+#include "cutlass/util/reference/host/tensor_reduce.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+ -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.h -new file mode 100644 -index 0000000..672e4d5 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.h -@@ -0,0 +1,203 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+#pragma once -+ -+#include -+ -+#include "cutlass/cutlass.h" -+#include "cutlass/complex.h" -+#include "cutlass/tensor_ref.h" -+ -+#include "cutlass/util/reference/detail/linear_to_coordinate.h" -+#include "cutlass/core_io.h" -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side -+/// workspace -+template < -+ typename Element, -+ typename Layout, -+ typename ComputeType, -+ typename ReduceOp, -+ typename TransformOp -+> -+ComputeType TensorTransformReduce( -+ TensorView view, -+ ComputeType identity, -+ ReduceOp reduce, -+ TransformOp transform -+) { -+ -+ for (int64_t idx = 0; idx < view.size(); ++idx) { -+ typename Layout::TensorCoord coord; -+ cutlass::reference::detail::LinearToCoordinate()(coord, idx, view.extent()); -+ -+ if (view.contains(coord)) { -+ Element x = view.at(coord); -+ identity = reduce(identity, transform(x)); -+ } -+ } -+ -+ return identity; -+} -+ -+/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side -+/// workspace -+template < -+ typename Element, -+ typename Layout, -+ typename ComputeType, -+ typename ReduceOp, -+ typename TransformOp -+> -+ComputeType TensorTransformReduce( -+ TensorView view_A, -+ TensorView view_B, -+ ComputeType identity, -+ ReduceOp reduce, -+ TransformOp transform) { -+ -+ if (view_A.extent() != view_B.extent()) { -+ throw std::runtime_error("Tensor extents must match."); -+ } -+ -+ for (int64_t idx = 0; idx < view_A.size(); ++idx) { -+ -+ typename Layout::TensorCoord coord; -+ cutlass::reference::detail::LinearToCoordinate()(coord, idx, view_A.extent()); -+ -+ if (view_A.contains(coord)) { -+ Element a = view_A.at(coord); -+ Element b = view_B.at(coord); -+ identity = reduce(identity, transform(a, b)); -+ } -+ } -+ -+ return identity; -+} -+ -+/// Helper to compute the sum of the elements of a tensor -+template < -+ typename Element, -+ typename Layout, -+ typename ComputeType = Element -+> -+ComputeType TensorSum( -+ TensorView view, -+ ComputeType identity = ComputeType() -+) { -+ -+ plus reduce; -+ NumericConverter transform; -+ -+ return TensorTransformReduce( -+ view, identity, reduce, transform); -+} -+ -+/// Helper to compute the sum of the squares of the elements of a tensor -+template < -+ typename Element, -+ typename Layout, -+ typename ComputeType = Element -+> -+ComputeType TensorSumSq( -+ TensorView view, -+ ComputeType identity = ComputeType() -+) { -+ -+ plus reduce; -+ magnitude_squared transform; -+ -+ return TensorTransformReduce( -+ view, identity, reduce, transform); -+} -+ -+/// Helper to compute the norm of the elements of a tensor. -+template < -+ typename Element, -+ typename Layout, -+ typename ComputeType = double -+> -+ComputeType TensorNorm( -+ TensorView view, -+ ComputeType identity = ComputeType() -+) { -+ -+ return std::sqrt(TensorSumSq(view, identity)); -+} -+ -+/// Helper to compute the sum of the squares of the differences of two tensors -+template < -+ typename Element, -+ typename Layout, -+ typename ComputeType = double -+> -+ComputeType TensorSumSqDiff( -+ TensorView view_A, -+ TensorView view_B, -+ ComputeType identity = ComputeType() -+) { -+ -+ plus reduce; -+ magnitude_squared_difference transform; -+ -+ return TensorTransformReduce( -+ view_A, view_B, identity, reduce, transform); -+} -+ -+ -+/// Helper to compute the norm of the tensor computed as the difference of two tensors in memory -+template < -+ typename Element, -+ typename Layout, -+ typename ComputeType = double -+> -+ComputeType TensorNormDiff( -+ TensorView view_A, -+ TensorView view_B, -+ ComputeType identity = ComputeType() -+) { -+ -+ return std::sqrt(TensorSumSqDiff(view_A, view_B, identity)); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.hpp b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.hpp -new file mode 100644 -index 0000000..aadf60a ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.hpp -@@ -0,0 +1,203 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/* \file -+ \brief Provides several functions for filling tensors with data. -+*/ -+ -+#pragma once -+ -+// Standard Library includes -+#include -+#include -+#include -+ -+// Cute includes -+#include "cute/tensor.hpp" -+ -+// Cutlass includes -+#include "cutlass/cutlass.h" -+#include "cutlass/complex.h" -+#include "cutlass/functional.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/quaternion.h" -+#include "cutlass/array.h" -+#include "cutlass/numeric_types.h" -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Tensor reductions -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side -+/// workspace -+template < -+ typename Tensor, -+ typename ComputeType, -+ typename ReduceOp, -+ typename TransformOp -+> -+ComputeType TensorTransformReduce( -+ Tensor view, -+ ComputeType identity, -+ ReduceOp reduce, -+ TransformOp transform -+) { -+ -+ for (int64_t idx = 0; idx < cute::size(view); ++idx) { -+ identity = reduce(identity, transform(view(idx))); -+ } -+ -+ return identity; -+} -+ -+/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side -+/// workspace -+template < -+ typename TensorA, -+ typename TensorB, -+ typename ComputeType, -+ typename ReduceOp, -+ typename TransformOp -+> -+ComputeType TensorTransformReduce( -+ TensorA view_A, -+ TensorB view_B, -+ ComputeType identity, -+ ReduceOp reduce, -+ TransformOp transform) { -+ -+ if (cute::size(view_A) != cute::size(view_B)) { -+ throw std::runtime_error("Tensor sizes must match."); -+ } -+ -+ for (int64_t idx = 0; idx < cute::size(view_A); ++idx) { -+ identity = reduce(identity, transform(view_A(idx), view_B(idx))); -+ } -+ -+ return identity; -+} -+ -+/// Helper to compute the sum of the elements of a tensor -+template < -+ typename Tensor, -+ typename ComputeType = typename Tensor::value_type -+> -+ComputeType TensorSum( -+ Tensor view, -+ ComputeType identity = ComputeType() -+) { -+ -+ plus reduce; -+ NumericConverter transform; -+ -+ return TensorTransformReduce( -+ view, identity, reduce, transform); -+} -+ -+/// Helper to compute the sum of the squares of the elements of a tensor -+template < -+ typename Tensor, -+ typename ComputeType = typename Tensor::value_type -+> -+ComputeType TensorSumSq( -+ Tensor view, -+ ComputeType identity = ComputeType() -+) { -+ -+ plus reduce; -+ magnitude_squared transform; -+ -+ return TensorTransformReduce( -+ view, identity, reduce, transform); -+} -+ -+/// Helper to compute the norm of the elements of a tensor. -+template < -+ typename Tensor, -+ typename ComputeType = double -+> -+ComputeType TensorNorm( -+ Tensor view, -+ ComputeType identity = ComputeType() -+) { -+ -+ return std::sqrt(TensorSumSq(view, identity)); -+} -+ -+/// Helper to compute the sum of the squares of the differences of two tensors -+template < -+ typename TensorA, -+ typename TensorB, -+ typename ComputeType = double -+> -+ComputeType TensorSumSqDiff( -+ TensorA view_A, -+ TensorB view_B, -+ ComputeType identity = ComputeType() -+) { -+ -+ plus reduce; -+ magnitude_squared_difference transform; -+ -+ return TensorTransformReduce( -+ view_A, view_B, identity, reduce, transform); -+} -+ -+ -+/// Helper to compute the norm of the tensor computed as the difference of two tensors in memory -+template < -+ typename TensorA, -+ typename TensorB, -+ typename ComputeType = double -+> -+ComputeType TensorNormDiff( -+ TensorA view_A, -+ TensorB view_B, -+ ComputeType identity = ComputeType() -+) { -+ -+ return std::sqrt(TensorSumSqDiff(view_A, view_B, identity)); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/trmm.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/trmm.h -new file mode 100644 -index 0000000..0c931ee ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/trmm.h -@@ -0,0 +1,215 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Reference implementation for TRMM in host-side code. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/gemm/gemm.h" -+#include "cutlass/arch/mma.h" -+#include "cutlass/util/host_tensor.h" -+ -+#include "cutlass/util/reference/host/gemm.h" -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+/// Computes a Triangular Matrix Multiplication (tensors of rank=2) pointed to by TensorRef -+/// objects. -+template < -+ typename ElementA, -+ typename LayoutA, -+ SideMode SideModeA, -+ FillMode FillModeA, -+ DiagType DiagTypeA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename InnerProductOp = multiply_add, -+ typename ConvertOp = NumericConverter -+> -+void compute_trmm( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, -+ TensorRef tensor_d, -+ ComputeType initial_accum) { -+ -+ static_assert( -+ LayoutA::kRank == 2 && -+ LayoutC::kRank == 2, "Tensors must be of rank 2"); -+ -+ static_assert(SideModeA != SideMode::kInvalid -+ , "Side Mode can either be Left or Right."); -+ -+ static_assert(FillModeA == FillMode::kLower || FillModeA == FillMode::kUpper -+ , "Fill Mode can either be Lower or Upper."); -+ -+ using CompareOp = typename TrMatrixCompareOp::Type; -+ -+ // Note: batch is ignored. -+ int const M = problem_size.m(); -+ int const N = problem_size.n(); -+ // Assuming correct k-dimension value is passed -+ int const K = problem_size.k(); -+ -+ // Blocking necessary to speedup reference implementation -+ int const Mblock = 16; -+ int const Nblock = 16; -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ CompareOp compare_op; -+ -+ for (int row_block = 0; row_block < M; row_block += Mblock) { -+ for (int col_block = 0; col_block < N; col_block += Nblock) { -+ -+ ComputeType accum[Mblock][Nblock]; -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ accum[i][j] = initial_accum; -+ } -+ } -+ -+ for (int k_block = 0; k_block < K; ++k_block) { -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ if (row < M && col < N) { -+ ElementA a = ElementA(); -+ ElementB b = ElementB(); -+ -+ if (SideModeA == SideMode::kLeft) { -+ a = (compare_op(row, k_block)) ? -+ (tensor_a.at(MatrixCoord(row, k_block))) : ElementA(0); -+ if (row == k_block && DiagTypeA == DiagType::kUnit) { -+ a = ElementA(1); -+ } -+ b = tensor_b.at(MatrixCoord(k_block, col)); -+ } else if (SideModeA == SideMode::kRight) { -+ a = tensor_b.at(MatrixCoord(row, k_block)); -+ b = (compare_op(k_block, col)) ? -+ tensor_a.at(MatrixCoord(k_block, col)) : ElementA(0); -+ if (k_block == col && DiagTypeA == DiagType::kUnit) { -+ b = ElementA(1); -+ } -+ } -+ -+ ComputeType compute_a(cast_if_scalar(a)); -+ ComputeType compute_b(cast_if_scalar(b)); -+ -+ accum[i][j] = inner_product_op(compute_a, compute_b, accum[i][j]); -+ } -+ } -+ } -+ } -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ MatrixCoord coord = MatrixCoord(row, col); -+ -+ if (row < M && col < N) { -+ tensor_d.at(coord) = convert_op( -+ alpha * ScalarType(accum[i][j])); -+ } -+ } -+ } -+ } -+ } -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA, -+ typename LayoutA, -+ SideMode SideModeA, -+ FillMode FillModeA, -+ DiagType DiagTypeA, -+ typename ElementB, -+ typename LayoutB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename InnerProductOp = cutlass::arch::OpMultiplyAdd -+> -+struct Trmm; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for multiply-add -+template -+struct Trmm { -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, -+ TensorRef tensor_d, -+ ComputeType initial_accum = ComputeType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_trmm>( -+ problem_size, alpha, tensor_a, tensor_b, tensor_d, initial_accum); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/trmm_complex.h b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/trmm_complex.h -new file mode 100644 -index 0000000..455c8a9 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/reference/host/trmm_complex.h -@@ -0,0 +1,262 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Reference implementation for complex-valued TRMM in host-side code. -+ -+ -+*/ -+ -+#pragma once -+ -+#include "cutlass/blas3.h" -+#include "cutlass/complex.h" -+#include "cutlass/numeric_conversion.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/gemm/gemm.h" -+ -+#include "cutlass/util/reference/host/gemm.h" -+ -+namespace cutlass { -+namespace reference { -+namespace host { -+ -+/// Computes a Triangular Matrix Multiplication (tensors of rank=2) pointed to by TensorRef -+/// objects. -+template < -+ typename ElementA, -+ typename LayoutA, -+ ComplexTransform TransformA, -+ SideMode SideModeA, -+ FillMode FillModeA, -+ DiagType DiagTypeA, -+ typename ElementB, -+ typename LayoutB, -+ ComplexTransform TransformB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename InnerProductOp = multiply_add, -+ typename ConvertOp = NumericConverter -+> -+void compute_trmm_complex( -+ gemm::GemmCoord problem_size, -+ ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, -+ TensorRef tensor_d, -+ ComputeType initial_accum) { -+ -+ static_assert( -+ LayoutA::kRank == 2 && -+ LayoutC::kRank == 2, "Tensors must be of rank 2"); -+ -+ static_assert(SideModeA != SideMode::kInvalid -+ , "Side Mode can either be Left or Right."); -+ -+ static_assert(FillModeA == FillMode::kLower || FillModeA == FillMode::kUpper -+ , "Fill Mode can either be Lower or Upper."); -+ -+ using CompareOp = typename TrMatrixCompareOp::Type; -+ -+ // Note: batch is ignored. -+ int const M = problem_size.m(); -+ int const N = problem_size.n(); -+ // Assuming correct k-dimension value is passed -+ int const K = problem_size.k(); -+ -+ // Blocking necessary to speedup reference implementation -+ int const Mblock = 16; -+ int const Nblock = 16; -+ -+ ConvertOp convert_op; -+ InnerProductOp inner_product_op; -+ CompareOp compare_op; -+ -+ for (int row_block = 0; row_block < M; row_block += Mblock) { -+ for (int col_block = 0; col_block < N; col_block += Nblock) { -+ -+ ComputeType accum[Mblock][Nblock]; -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ accum[i][j] = initial_accum; -+ } -+ } -+ -+ for (int k_block = 0; k_block < K; ++k_block) { -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ if (row < M && col < N) { -+ ElementA a = ElementA(); -+ ElementB b = ElementB(); -+ -+ if (SideModeA == SideMode::kLeft) { -+ a = (compare_op(row, k_block)) ? -+ (tensor_a.at(MatrixCoord(row, k_block))) : ElementA(0); -+ if (row == k_block && DiagTypeA == DiagType::kUnit) { -+ a = ElementA(1); -+ } -+ b = tensor_b.at(MatrixCoord(k_block, col)); -+ } else if (SideModeA == SideMode::kRight) { -+ a = tensor_b.at(MatrixCoord(row, k_block)); -+ b = (compare_op(k_block, col)) ? -+ tensor_a.at(MatrixCoord(k_block, col)) : ElementA(0); -+ if (k_block == col && DiagTypeA == DiagType::kUnit) { -+ b = ElementA(1); -+ } -+ } -+ -+ ComputeType a_ik = ComputeType(a); -+ ComputeType b_kj = ComputeType(b); -+ -+ // Conjugate, and hence hermitian, is only allowed for the triangular matrix -+ if (SideModeA == SideMode::kLeft && TransformA == ComplexTransform::kConjugate) { -+ a_ik = conj(a_ik); -+ } else if (SideModeA == SideMode::kRight && TransformA == ComplexTransform::kConjugate) { -+ b_kj = conj(b_kj); -+ } -+ -+ accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]); -+ } -+ } -+ } -+ } -+ -+ for (int j = 0; j < Nblock; j++) { -+ for (int i = 0; i < Mblock; i++) { -+ int row = row_block + i; -+ int col = col_block + j; -+ -+ MatrixCoord coord = MatrixCoord(row, col); -+ -+ if (row < M && col < N) { -+ tensor_d.at(coord) = convert_op( -+ alpha * ScalarType(accum[i][j])); -+ } -+ } -+ } -+ } -+ } -+} -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template < -+ typename ElementA, -+ typename LayoutA, -+ ComplexTransform TransformA, -+ SideMode SideModeA, -+ FillMode FillModeA, -+ DiagType DiagTypeA, -+ typename ElementB, -+ typename LayoutB, -+ ComplexTransform TransformB, -+ typename ElementC, -+ typename LayoutC, -+ typename ScalarType, -+ typename ComputeType, -+ typename InnerProductOp = cutlass::arch::OpMultiplyAddComplex -+> -+struct TrmmComplex; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for multiply-add -+template -+struct TrmmComplex { -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, -+ TensorRef tensor_d, -+ ComputeType initial_accum = ComputeType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_trmm_complex>( -+ problem_size, alpha, tensor_a, tensor_b, tensor_d, initial_accum); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Partial specialization for gaussian multiply-add -+template -+struct TrmmComplex { -+ -+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha, -+ TensorRef tensor_a, -+ TensorRef tensor_b, -+ TensorRef tensor_d, -+ ComputeType initial_accum = ComputeType(0)) { -+ static_assert( -+ LayoutA::kRank == 2 && LayoutC::kRank == 2, -+ "Tensors must be of rank 2"); -+ -+ compute_trmm_complex>( -+ problem_size, alpha, tensor_a, tensor_b, tensor_d, initial_accum); -+ } -+}; -+ -+//////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace host -+} // namespace reference -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/tensor_view_io.h b/3rdparty/cutlass/tools/util/include/cutlass/util/tensor_view_io.h -new file mode 100644 -index 0000000..6a352df ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/tensor_view_io.h -@@ -0,0 +1,262 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+* -+**************************************************************************************************/ -+#pragma once -+ -+#include "cutlass/core_io.h" -+#include "cutlass/tensor_view.h" -+#include "cutlass/tensor_view_planar_complex.h" -+#include "cutlass/complex.h" -+ -+namespace cutlass { -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+namespace detail { -+ -+/// Helper to write the least significant rank of a TensorView -+template < -+ typename Element, -+ typename Layout -+> -+inline std::ostream & TensorView_WriteLeastSignificantRank( -+ std::ostream& out, -+ TensorView const& view, -+ Coord const &start_coord, -+ int rank, -+ std::streamsize width) { -+ -+ for (int idx = 0; idx < view.extent(rank); ++idx) { -+ -+ Coord coord(start_coord); -+ coord[rank] = idx; -+ -+ if (idx) { -+ out.width(0); -+ out << ", "; -+ } -+ if (idx || coord) { -+ out.width(width); -+ } -+ out << ScalarIO(view.at(coord)); -+ } -+ -+ return out; -+} -+ -+/// Helper to write a rank of a TensorView -+template < -+ typename Element, -+ typename Layout -+> -+inline std::ostream & TensorView_WriteRank( -+ std::ostream& out, -+ TensorView const& view, -+ Coord const &start_coord, -+ int rank, -+ std::streamsize width) { -+ -+ // If called on the least significant rank, write the result as a row -+ if (rank + 1 == Layout::kRank) { -+ return TensorView_WriteLeastSignificantRank(out, view, start_coord, rank, width); -+ } -+ -+ // Otherwise, write a sequence of rows and newlines -+ for (int idx = 0; idx < view.extent(rank); ++idx) { -+ -+ Coord coord(start_coord); -+ coord[rank] = idx; -+ -+ if (rank + 2 == Layout::kRank) { -+ // Write least significant ranks asa matrix with rows delimited by "\n" -+ out << (idx ? ",\n" : ""); -+ TensorView_WriteLeastSignificantRank(out, view, coord, rank + 1, width); -+ } -+ else { -+ // Higher ranks are separated by newlines -+ out << (idx ? ",\n\n" : ""); -+ TensorView_WriteRank(out, view, coord, rank + 1, width); -+ } -+ } -+ -+ return out; -+} -+ -+/// Helper to write the least significant rank of a TensorView -+template < -+ typename Element, -+ typename Layout -+> -+inline std::ostream & TensorViewPlanarComplex_WriteLeastSignificantRank( -+ std::ostream& out, -+ TensorViewPlanarComplex const& view, -+ Coord const &start_coord, -+ int rank, -+ std::streamsize width) { -+ -+ for (int idx = 0; idx < view.extent(rank); ++idx) { -+ -+ Coord coord(start_coord); -+ coord[rank] = idx; -+ -+ if (idx) { -+ out.width(0); -+ out << ", "; -+ } -+ if (idx || coord) { -+ out.width(width); -+ } -+ -+ complex x = view.at(coord); -+ out << x; -+ } -+ -+ return out; -+} -+ -+/// Helper to write a rank of a TensorView -+template < -+ typename Element, -+ typename Layout -+> -+inline std::ostream & TensorViewPlanarComplex_WriteRank( -+ std::ostream& out, -+ TensorViewPlanarComplex const& view, -+ Coord const &start_coord, -+ int rank, -+ std::streamsize width) { -+ -+ // If called on the least significant rank, write the result as a row -+ if (rank + 1 == Layout::kRank) { -+ return TensorViewPlanarComplex_WriteLeastSignificantRank(out, view, start_coord, rank, width); -+ } -+ -+ // Otherwise, write a sequence of rows and newlines -+ for (int idx = 0; idx < view.extent(rank); ++idx) { -+ -+ Coord coord(start_coord); -+ coord[rank] = idx; -+ -+ if (rank + 2 == Layout::kRank) { -+ // Write least significant ranks asa matrix with rows delimited by ";\n" -+ out << (idx ? ";\n" : ""); -+ TensorViewPlanarComplex_WriteLeastSignificantRank(out, view, coord, rank + 1, width); -+ } -+ else { -+ // Higher ranks are separated by newlines -+ out << (idx ? "\n" : ""); -+ TensorViewPlanarComplex_WriteRank(out, view, coord, rank + 1, width); -+ } -+ } -+ -+ return out; -+} -+ -+} // namespace detail -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Prints human-readable representation of a TensorView to an ostream -+template < -+ typename Element, -+ typename Layout -+> -+inline std::ostream& TensorViewWrite( -+ std::ostream& out, -+ TensorView const& view) { -+ -+ // Prints a TensorView according to the following conventions: -+ // - least significant rank is printed as rows separated by ";\n" -+ // - all greater ranks are delimited with newlines -+ // -+ // The result is effectively a whitespace-delimited series of 2D matrices. -+ -+ return detail::TensorView_WriteRank(out, view, Coord(), 0, out.width()); -+} -+ -+/// Prints human-readable representation of a TensorView to an ostream -+template < -+ typename Element, -+ typename Layout -+> -+inline std::ostream& operator<<( -+ std::ostream& out, -+ TensorView const& view) { -+ -+ // Prints a TensorView according to the following conventions: -+ // - least significant rank is printed as rows separated by ";\n" -+ // - all greater ranks are delimited with newlines -+ // -+ // The result is effectively a whitespace-delimited series of 2D matrices. -+ -+ return TensorViewWrite(out, view); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+/// Prints human-readable representation of a TensorView to an ostream -+template < -+ typename Element, -+ typename Layout -+> -+inline std::ostream& TensorViewWrite( -+ std::ostream& out, -+ TensorViewPlanarComplex const& view) { -+ -+ // Prints a TensorView according to the following conventions: -+ // - least significant rank is printed as rows separated by ";\n" -+ // - all greater ranks are delimited with newlines -+ // -+ // The result is effectively a whitespace-delimited series of 2D matrices. -+ -+ return detail::TensorViewPlanarComplex_WriteRank(out, view, Coord(), 0, out.width()); -+} -+ -+/// Prints human-readable representation of a TensorView to an ostream -+template < -+ typename Element, -+ typename Layout -+> -+inline std::ostream& operator<<( -+ std::ostream& out, -+ TensorViewPlanarComplex const& view) { -+ -+ // Prints a TensorView according to the following conventions: -+ // - least significant rank is printed as rows separated by ";\n" -+ // - all greater ranks are delimited with newlines -+ // -+ // The result is effectively a whitespace-delimited series of 2D matrices. -+ -+ return TensorViewWrite(out, view); -+} -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -diff --git a/3rdparty/cutlass/tools/util/include/cutlass/util/type_traits.h b/3rdparty/cutlass/tools/util/include/cutlass/util/type_traits.h -new file mode 100644 -index 0000000..f187b97 ---- /dev/null -+++ b/3rdparty/cutlass/tools/util/include/cutlass/util/type_traits.h -@@ -0,0 +1,238 @@ -+/*************************************************************************************************** -+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+ * SPDX-License-Identifier: BSD-3-Clause -+ * -+ * Redistribution and use in source and binary forms, with or without -+ * modification, are permitted provided that the following conditions are met: -+ * -+ * 1. Redistributions of source code must retain the above copyright notice, this -+ * list of conditions and the following disclaimer. -+ * -+ * 2. Redistributions in binary form must reproduce the above copyright notice, -+ * this list of conditions and the following disclaimer in the documentation -+ * and/or other materials provided with the distribution. -+ * -+ * 3. Neither the name of the copyright holder nor the names of its -+ * contributors may be used to endorse or promote products derived from -+ * this software without specific prior written permission. -+ * -+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -+ * -+ **************************************************************************************************/ -+/*! \file -+ \brief Type traits for common CUDA types -+*/ -+ -+#pragma once -+ -+#include -+#include -+#include -+ -+#include "cutlass/numeric_types.h" -+#include "cutlass/complex.h" -+ -+namespace cutlass { -+struct half_t; -+ -+template -+struct TypeTraits { -+ typedef T host_type; -+ typedef T device_type; -+ static inline T remove_negative_zero(T x) { return x; } -+ static inline T to_print(T x) { return x; } -+ static inline device_type to_device(host_type x) { return x; } -+}; -+ -+template <> -+struct TypeTraits { -+ static cudaDataType_t const cublas_type = CUDA_R_8I; -+ typedef int8_t host_type; -+ typedef int8_t device_type; -+ typedef int8_t integer_type; -+ typedef uint8_t unsigned_type; -+ static inline int8_t remove_negative_zero(int8_t x) { return x; } -+ static inline int to_print(int8_t x) { return (int)x; } -+ static inline device_type to_device(host_type x) { return x; } -+}; -+ -+template <> -+struct TypeTraits { -+ static cudaDataType_t const cublas_type = CUDA_R_8I; -+ typedef uint8_t host_type; -+ typedef uint8_t device_type; -+ typedef uint8_t integer_type; -+ typedef uint8_t unsigned_type; -+ static inline uint8_t remove_negative_zero(uint8_t x) { return x; } -+ static inline uint32_t to_print(uint8_t x) { return (uint32_t)x; } -+ static inline device_type to_device(host_type x) { return x; } -+}; -+ -+template <> -+struct TypeTraits { -+ static cudaDataType_t const cublas_type = CUDA_R_32I; -+ typedef int host_type; -+ typedef int device_type; -+ typedef int32_t integer_type; -+ typedef uint32_t unsigned_type; -+ static inline int32_t remove_negative_zero(int32_t x) { return x; } -+ static inline int to_print(int x) { return x; } -+ static inline device_type to_device(host_type x) { return x; } -+}; -+ -+template <> -+struct TypeTraits { -+ static cudaDataType_t const cublas_type = CUDA_R_32I; -+ typedef unsigned host_type; -+ typedef unsigned device_type; -+ typedef uint32_t integer_type; -+ typedef uint32_t unsigned_type; -+ static inline uint32_t remove_negative_zero(uint32_t x) { return x; } -+ static inline uint32_t to_print(uint32_t x) { return x; } -+ static inline device_type to_device(host_type x) { return x; } -+}; -+ -+template <> -+struct TypeTraits { -+ static cudaDataType_t const cublas_type = CUDA_R_8I; -+ typedef int64_t host_type; -+ typedef int64_t device_type; -+ typedef int64_t integer_type; -+ typedef uint64_t unsigned_type; -+ static inline int64_t remove_negative_zero(int64_t x) { return x; } -+ static inline int64_t to_print(int64_t x) { return x; } -+ static inline device_type to_device(host_type x) { return x; } -+}; -+ -+template <> -+struct TypeTraits { -+ static cudaDataType_t const cublas_type = CUDA_R_8I; -+ typedef uint64_t host_type; -+ typedef uint64_t device_type; -+ typedef uint64_t integer_type; -+ typedef uint64_t unsigned_type; -+ static inline uint64_t remove_negative_zero(uint64_t x) { return x; } -+ static inline uint64_t to_print(uint64_t x) { return x; } -+ static inline device_type to_device(host_type x) { return x; } -+}; -+ -+template <> -+struct TypeTraits { -+ static cudaDataType_t const cublas_type = CUDA_R_16F; -+ typedef half_t host_type; -+ typedef half_t device_type; -+ typedef int16_t integer_type; -+ typedef uint16_t unsigned_type; -+ static inline half_t remove_negative_zero(half_t x) { -+ return (x.raw() == 0x8000 ? half_t::bitcast(0) : x); -+ } -+ static inline half_t to_print(half_t x) { return x; } -+ static inline device_type to_device(half_t x) { return reinterpret_cast(x); } -+}; -+ -+template <> -+struct TypeTraits { -+ static cudaDataType_t const cublas_type = CUDA_R_32F; -+ typedef float host_type; -+ typedef float device_type; -+ typedef int32_t integer_type; -+ typedef uint32_t unsigned_type; -+ static inline float remove_negative_zero(float x) { return x == -0.f ? 0.f : x; } -+ static inline float to_print(float x) { return x; } -+ static inline device_type to_device(host_type x) { return x; } -+}; -+ -+template <> -+struct TypeTraits { -+ static cudaDataType_t const cublas_type = CUDA_R_64F; -+ typedef double host_type; -+ typedef double device_type; -+ typedef int64_t integer_type; -+ typedef uint64_t unsigned_type; -+ static inline double remove_negative_zero(double x) { return x == -0.0 ? 0.0 : x; } -+ static inline double to_print(double x) { return x; } -+ static inline device_type to_device(host_type x) { return x; } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+// -+// Complex types -+// -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template <> -+struct TypeTraits > { -+ static cudaDataType_t const cublas_type = CUDA_C_16F; -+ typedef complex host_type; -+ typedef complex device_type; -+ typedef int16_t integer_type; -+ typedef uint16_t unsigned_type; -+ static inline device_type to_device(complex x) { return reinterpret_cast(x); } -+}; -+ -+template <> -+struct TypeTraits > { -+ static cudaDataType_t const cublas_type = CUDA_C_16F; -+ typedef complex host_type; -+ typedef complex device_type; -+ typedef int16_t integer_type; -+ typedef uint16_t unsigned_type; -+ static inline complex remove_negative_zero(complex x) { -+ return complex( -+ real(x) == -0_hf ? 0_hf : real(x), -+ imag(x) == -0_hf ? 0_hf : imag(x) -+ ); -+ } -+ static inline complex to_print(complex x) { return x; } -+ static inline device_type to_device(complex x) { return reinterpret_cast(x); } -+}; -+ -+template <> -+struct TypeTraits > { -+ -+ static cudaDataType_t const cublas_type = CUDA_C_32F; -+ typedef complex host_type; -+ typedef complex device_type; -+ typedef int64_t integer_type; -+ typedef uint64_t unsigned_type; -+ -+ static inline complex remove_negative_zero(complex x) { -+ return complex( -+ real(x) == -0.f ? 0.f : real(x), -+ imag(x) == -0.f ? 0.f : imag(x) -+ ); -+ } -+ -+ static inline complex to_print(complex x) { return x; } -+ static inline device_type to_device(complex x) { return reinterpret_cast(x); } -+}; -+ -+template <> -+struct TypeTraits > { -+ static cudaDataType_t const cublas_type = CUDA_C_64F; -+ typedef complex host_type; -+ typedef complex device_type; -+ struct integer_type { int64_t real, imag; }; -+ struct unsigned_type { uint64_t real, imag; }; -+ static inline complex remove_negative_zero(complex x) { -+ return complex( -+ real(x) == -0.0 ? 0.0 : real(x), -+ imag(x) == -0.0 ? 0.0 : imag(x) -+ ); -+ } -+ static inline complex to_print(complex x) { return x; } -+ static inline device_type to_device(complex x) { return reinterpret_cast(x); } -+}; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+} // namespace cutlass -diff --git a/3rdparty/trt_fused_multihead_attention/CMakeLists.txt b/3rdparty/trt_fused_multihead_attention/CMakeLists.txt -index 8707220..c9369e0 100644 ---- a/3rdparty/trt_fused_multihead_attention/CMakeLists.txt -+++ b/3rdparty/trt_fused_multihead_attention/CMakeLists.txt -@@ -21,7 +21,10 @@ set(trt_fused_multi_head_attention_files - ) - - file(GLOB trt_fused_multi_head_attention_files ${trt_fused_multi_head_attention_files} *.sm*.cpp) -- -+if(${CUDA_VERSION_STRING} VERSION_LESS_EQUAL "10.1.105" ) -+#this cuda don't support sm80 -+ list(REMOVE_ITEM trt_fused_multi_head_attention_files fused_mha_with_relPosBias_fp16_64_32_kernel.sm80.cpp) -+endif() - add_library(trt_fused_multi_head_attention STATIC ${trt_fused_multi_head_attention_files}) - target_link_libraries(trt_fused_multi_head_attention PUBLIC -lcublas -lcudart) - set_property(TARGET trt_fused_multi_head_attention PROPERTY POSITION_INDEPENDENT_CODE ON) -diff --git a/CMakeLists.txt b/CMakeLists.txt -index ea21014..66cf2af 100644 ---- a/CMakeLists.txt -+++ b/CMakeLists.txt -@@ -14,7 +14,9 @@ - cmake_minimum_required(VERSION 3.8 FATAL_ERROR) # for PyTorch extensions, version should be greater than 3.13 - project(FasterTransformer LANGUAGES CXX CUDA) - --find_package(CUDA 10.2 REQUIRED) -+find_package(CUDA 10.1 REQUIRED) -+ -+option(EXAMPLES "build examples" on) - - if(${CUDA_VERSION_MAJOR} VERSION_GREATER_EQUAL "11") - add_definitions("-DENABLE_BF16") -@@ -61,7 +63,7 @@ if(USE_TRITONSERVER_DATATYPE) - add_definitions("-DUSE_TRITONSERVER_DATATYPE") - endif() - --set(CXX_STD "14" CACHE STRING "C++ standard") -+set(CXX_STD "17" CACHE STRING "C++ standard") - - set(CUDA_PATH ${CUDA_TOOLKIT_ROOT_DIR}) - -@@ -85,7 +85,7 @@ endif() - - # setting compiler flags - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") --set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") -+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fstack-protector-strong -D_FORTIFY_SOURCE=2") - set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -Wall -ldl") - - set(SM_SETS 52 60 61 70 75 80 86) -@@ -92,13 +94,15 @@ set(FIND_SM False) - - foreach(SM_NUM IN LISTS SM_SETS) - string(FIND "${SM}" "${SM_NUM}" SM_POS) -+ message("find ${SM} in ${SM_NUM}") - if(SM_POS GREATER -1) - if(FIND_SM STREQUAL False) - set(ENV{TORCH_CUDA_ARCH_LIST} "") - endif() - set(FIND_SM True) - set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -gencode=arch=compute_${SM_NUM},code=\\\"sm_${SM_NUM},compute_${SM_NUM}\\\"") -- -+ math(EXPR CUDA_ARCH "${SM_NUM}*10") -+ add_definitions("-D__CUDA_ARCH_HOST__=${CUDA_ARCH}") - if (SM_NUM STREQUAL 70 OR SM_NUM STREQUAL 75 OR SM_NUM STREQUAL 80 OR SM_NUM STREQUAL 86) - set(USING_WMMA True) - endif() -@@ -125,8 +129,6 @@ if(NOT (FIND_SM STREQUAL True)) - set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} \ - -gencode=arch=compute_70,code=\\\"sm_70,compute_70\\\" \ - -gencode=arch=compute_75,code=\\\"sm_75,compute_75\\\" \ -- -gencode=arch=compute_80,code=\\\"sm_80,compute_80\\\" \ -- -gencode=arch=compute_86,code=\\\"sm_86,compute_86\\\" \ - ") - # -rdc=true") - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DWMMA") -@@ -136,7 +138,13 @@ if(NOT (FIND_SM STREQUAL True)) - set(ENV{TORCH_CUDA_ARCH_LIST} "7.0;7.5;8.0;8.6") - endif() - set(CMAKE_CUDA_ARCHITECTURES 70 75 80 86) -- message("-- Assign GPU architecture (sm=70,75,80,86)") -+ add_definitions("-D__CUDA_ARCH_HOST__=800") -+ if(${CUDA_VERSION_STRING} VERSION_LESS_EQUAL "10.1" ) -+ message("${CUDA_VERSION_STRING} removing unsupported sm 80 & 86") -+ list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES 80 86) -+endif() -+ message("-- Assign GPU architectures (sm=${CMAKE_CUDA_ARCHITECTURES})") -+ set(SM 70) - endif() - - if(BUILD_PYT) -@@ -152,8 +160,9 @@ set(CMAKE_CXX_STANDARD "${CXX_STD}") - set(CMAKE_CXX_STANDARD_REQUIRED ON) - set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda") - set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") --set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --std=c++${CXX_STD}") -- -+if(${CUDA_VERSION_STRING} VERSION_GREATER "10.1.105" ) -+ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --std=c++${CXX_STD}") -+endif() - set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3") - # set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler -O3 --ptxas-options=--verbose") - set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler -O3") -@@ -230,9 +239,10 @@ link_directories( - - add_subdirectory(3rdparty) - add_subdirectory(src) --add_subdirectory(examples) --add_subdirectory(tests) -- -+if(EXAMPLES) -+ add_subdirectory(examples) -+ add_subdirectory(tests) -+endif() - ######################################## - - if(BUILD_MULTI_GPU) -@@ -249,6 +259,7 @@ add_library(transformer-static STATIC - $ - $ - $ -+ $ - $ - $ - $ -@@ -313,8 +324,9 @@ add_library(transformer-static STATIC - set_property(TARGET transformer-static PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET transformer-static PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) - target_link_libraries(transformer-static PUBLIC -lcudart -lnccl -lmpi -lcublas -lcublasLt -lcurand) -+endif() - --add_library(transformer-shared SHARED -+set(transformer_objects - $ - $ - $ -@@ -324,29 +336,10 @@ add_library(transformer-shared SHARED - $ - $ - $ -- $ -- $ -- $ -+ $ - $ -- $ - $ - $ -- $ -- $ -- $ -- $ -- $ -- $ -- $ -- $ -- $ -- $ -- $ -- $ -- $ -- $ -- $ -- $ - $ - $ - $ -@@ -373,9 +366,7 @@ add_library(transformer-shared SHARED - $ - $ - $ -- $ - $ -- $ - $ - $ - $ -@@ -387,14 +378,23 @@ add_library(transformer-shared SHARED - $ - $ - $) -+ -+if(${SM} GREATER_EQUAL 70) -+ set(transformer_objects ${transformer_objects} $) -+endif() -+ -+add_library(transformer-shared SHARED ${transformer_objects}) - set_target_properties(transformer-shared PROPERTIES POSITION_INDEPENDENT_CODE ON) - set_target_properties(transformer-shared PROPERTIES CUDA_RESOLVE_DEVICE_SYMBOLS ON) - set_target_properties(transformer-shared PROPERTIES LINKER_LANGUAGE CXX) --target_link_libraries(transformer-shared PUBLIC -lcudart -lnccl -lmpi -lcublas -lcublasLt -lcurand) -+target_link_libraries(transformer-shared PUBLIC -lcudart -lcublas -lcublasLt -lcurand) -+target_link_options(transformer-shared PUBLIC -Wl,-z,now,-s,-fstack-protector-strong) - --include(GNUInstallDirs) -+#include(GNUInstallDirs) - set(INSTALL_CONFIGDIR ${CMAKE_INSTALL_LIBDIR}/cmake/FasterTransformer) - -+ -+ - include(CMakePackageConfigHelpers) - configure_package_config_file( - ${CMAKE_CURRENT_LIST_DIR}/cmake/FasterTransformerConfig.cmake.in -@@ -402,52 +401,23 @@ configure_package_config_file( - INSTALL_DESTINATION ${INSTALL_CONFIGDIR} - ) - --install( -- FILES -- ${CMAKE_CURRENT_BINARY_DIR}/FasterTransformerConfig.cmake -- DESTINATION ${INSTALL_CONFIGDIR} --) - - install( - TARGETS - transformer-shared - EXPORT - transformer-shared-targets -- LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/fastertransformer -- ARCHIVE DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/fastertransformer --) -- --install( -- EXPORT -- transformer-shared-targets -- FILE -- FasterTransformerTargets.cmake -- DESTINATION -- ${INSTALL_CONFIGDIR} -+ LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/lib -+ ARCHIVE DESTINATION ${CMAKE_INSTALL_PREFIX}/lib - ) - - file(GLOB_RECURSE HEADER_FILES "*.h" "*.hpp" "*.cuh") - foreach ( file ${HEADER_FILES} ) - file( RELATIVE_PATH rfile ${CMAKE_CURRENT_SOURCE_DIR} ${file} ) - get_filename_component( dir ${rfile} DIRECTORY ) -- install( FILES ${file} DESTINATION ${CMAKE_INSTALL_PREFIX}/include/${dir} ) -+ install( FILES ${file} DESTINATION ${CMAKE_INSTALL_PREFIX}/include/${dir}) - endforeach() - - --################################################################################ --# add_executable(gpt sample/cpp/gpt_sample.cc ) --# target_link_libraries(gpt PUBLIC -lcublas -lcublasLt -lcudart -lcurand -lnccl -lmpi transformer-static) --# target_link_libraries(gpt PUBLIC -lcublas -lcublasLt -lcudart -lcurand -lnccl -lmpi decoder decoding) -- --export( -- EXPORT -- transformer-shared-targets -- FILE -- ${CMAKE_CURRENT_BINARY_DIR}/FasterTransformerTargets.cmake -- NAMESPACE -- TritonCore:: --) - --export(PACKAGE FasterTransformer) - --endif() # BUILD_MULTI_GPU -diff --git a/README.md b/README.md -index a60983c..45b5374 100644 ---- a/README.md -+++ b/README.md -@@ -52,7 +52,7 @@ FasterTransformer is built on top of CUDA, cuBLAS, cuBLASLt and C++. We provide - | Swin Transformer | PyTorch | Yes | Yes | - | - | - | - | Swin Transformer | TensorRT | Yes | Yes | - | - | - | - | ViT | PyTorch | Yes | Yes | - | - | - | --| ViT | TensorRT | Yes | Yes | - | - | - | -+| ViT | TensorRT | Yes | - | - | - | - | - - * Note that the FasterTransformer supports the models above on C++ because all source codes are built on C++. - -diff --git a/deploy.ch b/deploy.ch -new file mode 100644 -index 0000000..9df51e9 ---- /dev/null -+++ b/deploy.ch -@@ -0,0 +1,125 @@ -+ms_benchmark -+ -+sent 15,086 bytes received 22,535 bytes 25,080.67 bytes/sec -+total size is 14,006,280 speedup is 372.30 -+file=/home/shira/git-proj/FasterTransformer/build/bin/ms_benchmark -+T5_input1.fp32 -+T5_input2.fp32 -+T5_input3.fp32 -+T5_input4.fp32 -+T5_input5.fp32 -+T5_output1.fp32 -+mha_T5_cross_input1.fp32 -+mha_T5_cross_input2.fp32 -+mha_T5_cross_input3.fp32 -+mha_T5_cross_input4.fp32 -+mha_T5_cross_output1.fp32 -+mha_T5_cross_weight1.fp32 -+mha_T5_cross_weight2.fp32 -+mha_T5_cross_weight3.fp32 -+mha_T5_input1.fp32 -+mha_T5_input2.fp32 -+mha_T5_input3.fp32 -+mha_T5_output1.fp32 -+mha_T5_weight1.fp32 -+mha_T5_weight2.fp32 -+mha_cross_input1.fp32 -+mha_cross_input2.fp32 -+mha_cross_input3.fp32 -+mha_cross_output1.fp32 -+mha_cross_output2.fp32 -+mha_cross_output3.fp32 -+mha_cross_weight1.fp32 -+mha_cross_weight2.fp32 -+mha_cross_weight3.fp32 -+mha_cross_weight4.fp32 -+mha_cross_weight5.fp32 -+mha_x1_input1.fp32 -+mha_x1_input2.fp32 -+mha_x1_output1.fp32 -+mha_x1_output2.fp32 -+mha_x1_output3.fp32 -+mha_x1_weight1.fp32 -+mha_x1_weight2.fp32 -+mha_x1_weight3.fp32 -+mha_x1_weight4.fp32 -+test_input1.fp32 -+transformer_decoder_layer_input1.fp32 -+transformer_decoder_layer_input2.fp32 -+transformer_decoder_layer_input3.fp32 -+transformer_decoder_layer_input4.fp32 -+transformer_decoder_layer_output1.fp32 -+transformer_decoder_layer_t5_input1.fp32 -+transformer_decoder_layer_t5_input2.fp32 -+transformer_decoder_layer_t5_input3.fp32 -+transformer_decoder_layer_t5_input4.fp32 -+transformer_decoder_layer_t5_input5.fp32 -+transformer_decoder_layer_t5_input6.fp32 -+transformer_decoder_layer_t5_output1.fp32 -+transformer_decoder_layer_t5_weight1.fp32 -+transformer_decoder_layer_t5_weight10.fp16 -+transformer_decoder_layer_t5_weight2.fp32 -+transformer_decoder_layer_t5_weight3.fp32 -+transformer_decoder_layer_t5_weight4.fp32 -+transformer_decoder_layer_t5_weight5.fp32 -+transformer_decoder_layer_t5_weight6.fp32 -+transformer_decoder_layer_t5_weight7.fp32 -+transformer_decoder_layer_t5_weight8.fp32 -+transformer_decoder_layer_t5_weight9.fp16 -+transformer_decoder_layer_weight1.fp32 -+transformer_decoder_layer_weight10.fp32 -+transformer_decoder_layer_weight11.fp32 -+transformer_decoder_layer_weight12.fp32 -+transformer_decoder_layer_weight13.fp32 -+transformer_decoder_layer_weight14.fp32 -+transformer_decoder_layer_weight15.fp32 -+transformer_decoder_layer_weight16.fp16 -+transformer_decoder_layer_weight17.fp16 -+transformer_decoder_layer_weight18.fp16 -+transformer_decoder_layer_weight19.fp32 -+transformer_decoder_layer_weight2.fp32 -+transformer_decoder_layer_weight3.fp32 -+transformer_decoder_layer_weight4.fp32 -+transformer_decoder_layer_weight5.fp32 -+transformer_decoder_layer_weight6.fp32 -+transformer_decoder_layer_weight7.fp32 -+transformer_decoder_layer_weight8.fp32 -+transformer_decoder_layer_weight9.fp32 -+transformer_encoder_layer_input1.fp32 -+transformer_encoder_layer_input2.fp32 -+transformer_encoder_layer_output1.fp32 -+transformer_encoder_layer_t5_input1.fp32 -+transformer_encoder_layer_t5_input2.fp32 -+transformer_encoder_layer_t5_input3.fp32 -+transformer_encoder_layer_t5_output1.fp32 -+transformer_encoder_layer_t5_weight1.fp32 -+transformer_encoder_layer_t5_weight2.fp32 -+transformer_encoder_layer_t5_weight3.fp32 -+transformer_encoder_layer_t5_weight4.fp32 -+transformer_encoder_layer_t5_weight5.fp16 -+transformer_encoder_layer_t5_weight6.fp16 -+transformer_encoder_layer_weight1.fp32 -+transformer_encoder_layer_weight10.fp16 -+transformer_encoder_layer_weight11.fp16 -+transformer_encoder_layer_weight12.fp32 -+transformer_encoder_layer_weight2.fp32 -+transformer_encoder_layer_weight3.fp32 -+transformer_encoder_layer_weight4.fp32 -+transformer_encoder_layer_weight5.fp32 -+transformer_encoder_layer_weight6.fp32 -+transformer_encoder_layer_weight7.fp32 -+transformer_encoder_layer_weight8.fp32 -+transformer_encoder_layer_weight9.fp16 -+ -+sent 224,270 bytes received 328,739 bytes 368,672.67 bytes/sec -+total size is 100,832,432 speedup is 182.33 -+libtransformer-shared.so -+ -+sent 40,407 bytes received 70,578 bytes 73,990.00 bytes/sec -+total size is 101,408,056 speedup is 913.71 -+command= CUDA_VISIBLE_DEVICES=5 LD_LIBRARY_PATH=/home/shira/git-proj/FasterTransformer/../FasterTransformer:/usr/local/cuda-11.7/lib64 /home/shira/git-proj/FasterTransformer/build/bin/ms_benchmark -b 1 -l 12 -H 2 -S 8 -s 20 -f 32 -x 1 -P 0 -m transformer_edeode_layer_t5 -+[INFO] Device: NVIDIA A100-PCIE-40GB -+[WARNING] gemm_config.in is not found; using default GEMM algo -+model_nametransformer_edeode_layer_t5 -+model num=-1TDL_T59 -+batch_size 1 seq_len 20 layer 12 AVG FT-CPP-time 0.00 ms (1000 iterations) Total Time 0.01 ms -diff --git a/deploy.sh b/deploy.sh -new file mode 100755 -index 0000000..5b0ed1b ---- /dev/null -+++ b/deploy.sh -@@ -0,0 +1,36 @@ -+#copy cuda folder (once) -+base=/home/batya/git-proj/FasterTransformer -+#`git rev-parse --show-toplevel` -+#debug="gdb --args" -+ -+server=pick -+while getopts "d" opt -+do -+case "${opt}" in -+ "d" ) -+ debug="gdb --args" -+ shift -+ ;; -+esac -+done -+file=/home/batya/git-proj/FasterTransformer/build/bin/ms_benchmark -+#`realpath $1` -+shift -+rsync -v ${file} ${server}:${file} -+echo "file=${file}" -+# rsync -Iv ${base}/../mindspore/trc/pangu/*.fp* ${server}:${base}/build/bin -+rm -f ${server}:${base}/build/bin/*decoder* -+ -+# rsync -v ${base}/../mindspore/trc/pangu/transformer_decoder_layer_weight*.fp16 ${server}:${base}/build/bin -+rsync -v ${base}/../mindspore/trc/transformer/*transformer_decoder_layer* ${server}:${base}/build/bin -+rsync -v ${base}/build/lib/*.so ${server}:${base}/build/lib -+# echo "cd ${base}/build/bin/" -+ -+command=$(cat <<-ENDM -+ CUDA_VISIBLE_DEVICES=5 \ -+ LD_LIBRARY_PATH=${base}/../FasterTransformer:/usr/local/cuda-11.7/lib64 \ -+ ${debug} ${file} $@ -+ENDM -+) -+echo "command=${command}" -+ssh ${server} "cd ${base}/build/bin ;${command}" -diff --git a/deploy.trc b/deploy.trc -new file mode 100644 -index 0000000..023cccf ---- /dev/null -+++ b/deploy.trc -@@ -0,0 +1,303 @@ -+ms_benchmark -+ -+sent 848,395 bytes received 22,547 bytes 580,628.00 bytes/sec -+total size is 14,046,536 speedup is 16.13 -+file=/home/shira/git-proj/FasterTransformer/build/bin/ms_benchmark -+T5_input1.fp32 -+T5_input2.fp32 -+T5_input3.fp32 -+T5_input4.fp32 -+T5_input5.fp32 -+T5_output1.fp32 -+mha_T5_input1.fp32 -+mha_T5_input2.fp32 -+mha_T5_input3.fp32 -+mha_T5_output1.fp32 -+mha_T5_weight1.fp32 -+mha_T5_weight2.fp32 -+mha_cross_input1.fp32 -+mha_cross_input2.fp32 -+mha_cross_input3.fp32 -+mha_cross_output1.fp32 -+mha_cross_output2.fp32 -+mha_cross_output3.fp32 -+mha_cross_weight1.fp32 -+mha_cross_weight2.fp32 -+mha_cross_weight3.fp32 -+mha_cross_weight4.fp32 -+mha_cross_weight5.fp32 -+mha_x1_input1.fp32 -+mha_x1_input2.fp32 -+mha_x1_output1.fp32 -+mha_x1_output2.fp32 -+mha_x1_output3.fp32 -+mha_x1_weight1.fp32 -+mha_x1_weight2.fp32 -+mha_x1_weight3.fp32 -+mha_x1_weight4.fp32 -+test_input1.fp32 -+transformer_decoder_layer_input1.fp32 -+transformer_decoder_layer_input2.fp32 -+transformer_decoder_layer_input3.fp32 -+transformer_decoder_layer_input4.fp32 -+transformer_decoder_layer_output1.fp32 -+transformer_decoder_layer_t5_input1.fp32 -+transformer_decoder_layer_t5_input2.fp32 -+transformer_decoder_layer_t5_input3.fp32 -+transformer_decoder_layer_t5_input4.fp32 -+transformer_decoder_layer_t5_input5.fp32 -+transformer_decoder_layer_t5_input6.fp32 -+transformer_decoder_layer_t5_output1.fp32 -+transformer_decoder_layer_t5_weight1.fp32 -+transformer_decoder_layer_t5_weight10.fp16 -+transformer_decoder_layer_t5_weight2.fp32 -+transformer_decoder_layer_t5_weight3.fp32 -+transformer_decoder_layer_t5_weight4.fp32 -+transformer_decoder_layer_t5_weight5.fp32 -+transformer_decoder_layer_t5_weight6.fp32 -+transformer_decoder_layer_t5_weight7.fp32 -+transformer_decoder_layer_t5_weight8.fp32 -+transformer_decoder_layer_t5_weight9.fp16 -+transformer_decoder_layer_weight1.fp32 -+transformer_decoder_layer_weight10.fp32 -+transformer_decoder_layer_weight11.fp32 -+transformer_decoder_layer_weight12.fp32 -+transformer_decoder_layer_weight13.fp32 -+transformer_decoder_layer_weight14.fp32 -+transformer_decoder_layer_weight15.fp32 -+transformer_decoder_layer_weight16.fp32 -+transformer_decoder_layer_weight17.fp32 -+transformer_decoder_layer_weight18.fp32 -+transformer_decoder_layer_weight19.fp32 -+transformer_decoder_layer_weight2.fp32 -+transformer_decoder_layer_weight3.fp32 -+transformer_decoder_layer_weight4.fp32 -+transformer_decoder_layer_weight5.fp32 -+transformer_decoder_layer_weight6.fp32 -+transformer_decoder_layer_weight7.fp32 -+transformer_decoder_layer_weight8.fp32 -+transformer_decoder_layer_weight9.fp32 -+transformer_encoder_layer_input1.fp32 -+transformer_encoder_layer_input2.fp32 -+transformer_encoder_layer_output1.fp32 -+transformer_encoder_layer_t5_input1.fp32 -+transformer_encoder_layer_t5_input2.fp32 -+transformer_encoder_layer_t5_input3.fp32 -+transformer_encoder_layer_t5_output1.fp32 -+transformer_encoder_layer_t5_weight1.fp32 -+transformer_encoder_layer_t5_weight2.fp32 -+transformer_encoder_layer_t5_weight3.fp32 -+transformer_encoder_layer_t5_weight4.fp32 -+transformer_encoder_layer_t5_weight5.fp32 -+transformer_encoder_layer_t5_weight6.fp32 -+transformer_encoder_layer_weight1.fp32 -+transformer_encoder_layer_weight10.fp32 -+transformer_encoder_layer_weight11.fp32 -+transformer_encoder_layer_weight12.fp32 -+transformer_encoder_layer_weight2.fp32 -+transformer_encoder_layer_weight3.fp32 -+transformer_encoder_layer_weight4.fp32 -+transformer_encoder_layer_weight5.fp32 -+transformer_encoder_layer_weight6.fp32 -+transformer_encoder_layer_weight7.fp32 -+transformer_encoder_layer_weight8.fp32 -+transformer_encoder_layer_weight9.fp32 -+ -+sent 99,079 bytes received 141,615 bytes 481,388.00 bytes/sec -+total size is 41,715,056 speedup is 173.31 -+libtransformer-shared.so -+ -+sent 2,297,603 bytes received 70,578 bytes 947,272.40 bytes/sec -+total size is 101,408,056 speedup is 42.82 -+command= CUDA_VISIBLE_DEVICES=5 LD_LIBRARY_PATH=/home/shira/git-proj/FasterTransformer/../FasterTransformer:/usr/local/cuda-11.7/lib64 /home/shira/git-proj/FasterTransformer/build/bin/ms_benchmark -b 1 -l 12 -H 2 -S 8 -s 20 -f 32 -x 0 -P 1 -m transformer_decoder_layer_t5 -+[INFO] Device: NVIDIA A100-PCIE-40GB -+[WARNING] gemm_config.in is not found; using default GEMM algo -+model_nametransformer_decoder_layer_t5 -+model num=9TDL_T59 -+ffn hidden size= 32hidden_units= 8opt_a->hidden_size= 8InitWeight -+model_nametransformer_decoder_layer_t5 -+forward -+0 -+i: 0.458147, -+ -+-0.0708623,0.428798,0.185046,0.057154,-0.203372,0.412952,0.149334, -+1 -+i: 1, -+ -+1,1,1,1,1,1,1, -+2 -+i: -0.00567709, -+ -+0.00275152,-0.0183454,-0.00663061,0.00240306,-0.000401073,-0.0296616,-0.00454295, -+3 -+i: 0.042776, -+ -+-0.184324,0.159665,0.0819166,-0.551958,0.424301,-0.107998,0.906672, -+4 -+i: -0.509908, -+ -+-0.200033,0.462621,0.674143,0.119592,-0.104946,0.62515,-0.689601, -+5 -+i: -0.00660441, -+ -+0.000136423,0.00216104,-0.0194914,-0.0169052,-0.0184758,0.0166858,-0.00725936, -+6 -+i: 1, -+ -+1,1,1,1,1,1,1, -+7 -+i: -0.979686, -+ -+-0.655586,0.619562,-0.208414,-0.406135,-0.363263,0.452679,0.711803, -+8 -+i: 0.00186535, -+ -+-0.000453893,-0.00868625,0.0019144,0.00439803,0.00553658,-0.0059722,0.00984536, -+9 -+i: -0.0151223, -+ -+0.00973828,0.00486344,0.00208442,-0.00624479,0.0143989,0.00552422,0.00341345, -+10 -+i: -0.411828, -+ -+0.689115,0.670348,-0.657296,0.389433,-0.099893,0.639807,-0.0361588, -+11 -+i: 0.140165, -+ -+0.147417,0.453423,0.41165,0.410601,0.353489,0.748643,0.34843, -+12 -+i: 0.0161316, -+ -+0.00175404,-0.00734864,-0.0212876,-0.0123986,-0.00419559,0.00115276,0.00576639, -+13 -+i: 1, -+ -+1,1,1,1,1,1,1, -+14 -+i: 1.38904e-28, -+ -+7.80033e-34,2.25142e-20,-3.8139e-20,-8.21938e-19,-2.85581e-23,-8.682e-19,-1.04875e-19, -+15 -+i: 9.3225e-21, -+ -+3.67653e-27,8.89747e-23,5.8367e-25,4.46003e-20,-1.02178e-23,1.33927e-18,-1.14603e-18, -+0001001 -+tensor 1.22719, -+ -+-1.08314,1.09902,0.0344878,-0.524054,-1.66184,1.02982,-0.121479, -+tensor -0.00567709, -+ -+0.00275152,-0.0183454,-0.00663061,0.00240306,-0.000401073,-0.0296616,-0.00454295, -+tensor 0.042776, -+ -+-0.184324,0.159665,0.0819166,-0.551958,0.424301,-0.107998,0.906672, -+tensor -0.509908, -+ -+-0.200033,0.462621,0.674143,0.119592,-0.104946,0.62515,-0.689601, -+tensor -0.00660441, -+ -+0.000136423,0.00216104,-0.0194914,-0.0169052,-0.0184758,0.0166858,-0.00725936, -+tensor 1, -+ -+1,1,1,1,1,1,1, -+not cross -+weight_qkv -0.00567709, -+ -+0.00275152,-0.0183454,-0.00663061,0.00240306,-0.000401073,-0.0296616,-0.00454295, -+from_tensor 1.22719, -+ -+-1.08314,1.09902,0.0344878,-0.524054,-1.66184,1.02982,-0.121479, -+qkv_buf -0.034644, -+ -+-0.0221119,-0.0255647,-0.0296132,0.039853,-0.0505666,-0.0563698,-0.0394761, -+output1 -0.0189542, -+ -+-0.030236,-0.00238584,0.0115228,-0.00333089,-0.0369255,-0.0199997,0.0257213, -+q_buf_2 -0.034644, -+ -+-0.0221119,-0.0255647,-0.0296132,-0.0132158,0.0133059,-0.0213238,-0.0197066, -+output2 0.0174091, -+ -+-0.0174263,-0.00708923,-0.0546997,0.0247199,0.0332275,0.00747891,-0.0340745, -+output1 -0.0189542, -+ -+-0.030236,-0.00238584,0.0115228,-0.00333089,-0.0369255,-0.0199997,0.0257213, -+qk_buf 0.00104488, -+ -+0.000681124,-0.000605213,0.00103154,0.00169038,-0.00193581,-0.00120881,0.00293503, -+attention_mask 0.042776, -+ -+-0.184324,0.159665,0.0819166,-0.551958,0.424301,-0.107998,0.906672, -+position_bias -0.509908, -+ -+-0.200033,0.462621,0.674143,0.119592,-0.104946,0.62515,-0.689601, -+qk_buf 0, -+ -+0,0,0,0,0,0,0, -+qkv_buf_2 0.0560303, -+ -+-0.0145187,0.00715256,-0.0731201,0.0189056,-0.00694656,-0.00557709,0.0466919, -+qkv_buf_3 0.0560303, -+ -+-0.0145187,0.00715256,-0.0731201,0.00761032,-0.0180664,0.0290527,-0.0121078, -+param->in_idx5 -+output[0] 0.000853663, -+ -+-0.00158517,0.000399898,-0.00118085,-0.000752181,-0.000945039,0.000822433,0.00073974, -+tensor 1.22736, -+ -+-1.08373,1.09882,0.0334885,-0.52422,-1.66137,1.0301,-0.120446, -+tensor -0.979686, -+ -+-0.655586,0.619562,-0.208414,-0.406135,-0.363263,0.452679,0.711803, -+tensor 0.00186535, -+ -+-0.000453893,-0.00868625,0.0019144,0.00439803,0.00553658,-0.0059722,0.00984536, -+tensor -0.0151223, -+ -+0.00973828,0.00486344,0.00208442,-0.00624479,0.0143989,0.00552422,0.00341345, -+tensor -0.411828, -+ -+0.689115,0.670348,-0.657296,0.389433,-0.099893,0.639807,-0.0361588, -+tensor 0.140165, -+ -+0.147417,0.453423,0.41165,0.410601,0.353489,0.748643,0.34843, -+is cross -+output1 0.0290729, -+ -+-0.0301254,0.023922,0.0158201,-0.0092989,0.00192484,-0.00540288,-0.0149312, -+qk_buf 0.00079844, -+ -+-0.000482788,-0.000210089,0.000353139,0.000915573,0.000227468,0.000344165,-0.000868972, -+attention_mask -0.411828, -+ -+0.689115,0.670348,-0.657296,0.389433,-0.099893,0.639807,-0.0361588, -+position_bias 0.140165, -+ -+0.147417,0.453423,0.41165,0.410601,0.353489,0.748643,0.34843, -+qk_buf 0, -+ -+0,0,0,0,0,0,0, -+qkv_buf_2 -0.00668335, -+ -+0.0517578,0.00294304,-0.0211945,0.0103607,-0.0372925,0.012619,0.0177155, -+qkv_buf_3 -0.00668335, -+ -+0.0517578,0.00294304,-0.0211945,0.00642776,0.0122375,0.022171,-0.00510406, -+param->in_idx7 -+output[0] -0.000111274, -+ -+-0.00114325,-0.000122436,0.000859971,0.000143449,-0.000135836,-0.000207918,-0.000205797, -+gamma3 1, -+ -+1,1,1,1,1,1,1, -+normed_attn2_out 1.22722, -+ -+-1.08463,1.09869,0.034459,-0.523898,-1.6612,1.02988,-0.120521, -+attn2_out 1.22794, -+ -+-1.08587,1.09929,0.0341669,-0.524663,-1.66292,1.03043,-0.120945, -+attn_out 1.22805, -+ -+-1.08472,1.09942,0.0333069,-0.524806,-1.66279,1.03064,-0.120739, -+13 -diff --git a/docs/gpt_guide.md b/docs/gpt_guide.md -index afcba9a..71c4fab 100644 ---- a/docs/gpt_guide.md -+++ b/docs/gpt_guide.md -@@ -312,7 +312,7 @@ python tools/checkpoint_util.py --model-type GPT --loader megatron --saver faste - To convert the Megatron GPT model to binary, FasterTransformer provides a tool `examples/onnx/multi_gpu_gpt/onnx_ckpt_convert.py` to convert the checkpoint. - - ```bash --wget https://github.com/onnx/models/raw/master/text/machine_comprehension/gpt-2/model/gpt2-10.onnx -+wget https://github.com/onnx/models/raw/main/text/machine_comprehension/gpt-2/model/gpt2-10.onnx - python ../examples/onnx/multi_gpu_gpt/onnx_ckpt_convert.py -i gpt2-10.onnx -o ../models/onnx-models/c-model/124m/ -i_g 1 - python ../examples/onnx/multi_gpu_gpt/onnx_ckpt_convert.py -i gpt2-10.onnx -o ../models/onnx-models/c-model/124m/ -i_g 4 - ``` -diff --git a/examples/cpp/CMakeLists.txt b/examples/cpp/CMakeLists.txt -index b67cd01..3cc4155 100644 ---- a/examples/cpp/CMakeLists.txt -+++ b/examples/cpp/CMakeLists.txt -@@ -13,6 +13,7 @@ - # limitations under the License. - - add_subdirectory(bert) -+add_subdirectory(ms) - add_subdirectory(bert_int8) - add_subdirectory(decoding) - add_subdirectory(gpt) -diff --git a/examples/cpp/gpt/gpt_example.cc b/examples/cpp/gpt/gpt_example.cc -index cacb09e..5fec0c9 100644 ---- a/examples/cpp/gpt/gpt_example.cc -+++ b/examples/cpp/gpt/gpt_example.cc -@@ -236,7 +236,7 @@ void gpt_example(const INIReader reader) - #endif - - if (std::is_same::value) { -- cublas_wrapper.setGemmConfig(CUDA_R_16F, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F); -+ cublas_wrapper.setGemmConfig(CUDA_R_16F, CUDA_R_16F, CUDA_R_16F, CUBLAS_COMPUTE_32F_FAST_TF32); - } - #ifdef ENABLE_BF16 - else if (std::is_same::value) { -diff --git a/examples/cpp/ms/CMakeLists.txt b/examples/cpp/ms/CMakeLists.txt -new file mode 100644 -index 0000000..33e562b ---- /dev/null -+++ b/examples/cpp/ms/CMakeLists.txt -@@ -0,0 +1,22 @@ -+# Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. -+ -+add_executable(ms_benchmark ms.cc) -+if (SPARSITY_SUPPORT) -+# target_link_libraries(ms_benchmark PUBLIC -lcublas -lcublasLt -lcudart -lcusparse -lcusparseLt transformer-shared) -+target_link_libraries(ms_benchmark PUBLIC -lcublas -lcublasLt -lcudart -lcusparse -lcusparseLt GptContextAttentionLayer MSLayer) -+else() -+# target_link_libraries(ms_benchmark PUBLIC -lcublas -lcublasLt -lcudart transformer-shared) -+target_link_libraries(ms_benchmark PUBLIC -lcublas -lcublasLt -lcudart GptContextAttentionLayer MSLayer) -+endif() -diff --git a/examples/cpp/ms/initialize.h b/examples/cpp/ms/initialize.h -new file mode 100644 -index 0000000..9e72838 ---- /dev/null -+++ b/examples/cpp/ms/initialize.h -@@ -0,0 +1,746 @@ -+#pragma once -+ -+#include "src/fastertransformer/layers/ms_layers/MSAttentionLayer.h" -+#include "src/fastertransformer/layers/ms_layers/MSBaseLayer.h" -+#include "src/fastertransformer/layers/ms_layers/MSDecoderLayer.h" -+#include "src/fastertransformer/layers/ms_layers/MSEncoderLayer.h" -+#include "src/fastertransformer/layers/ms_layers/MSLayerWeight.h" -+using namespace fastertransformer; -+ -+template -+struct Decriptor { -+ std::vector input_tensors; // GPU -+ std::vector input_python_tensors; // CPU -+ std::vector output_tensors; // GPU -+ std::vector output_python_tensors; // CPU -+ std::vector w_tensors; -+ MSBaseLayer* MSLayer; -+}; -+ -+template -+void InitializeAttn(opt_arg* opt_a, -+ Decriptor& desc, -+ cudaStream_t stream, -+ cublasMMWrapper* cublas_wrapper, -+ cublasHandle_t cublas_handle, -+ Allocator* allocator) -+{ -+ const size_t hidden_units = opt_a->head_num * opt_a->size_per_head; -+ // TODO Nizzan - check if need to be -+ desc.MSLayer = new MSMHALayer(opt_a->batch_size, -+ opt_a->seq_len, -+ opt_a->tgt_seq_len, -+ opt_a->head_num, -+ opt_a->size_per_head, -+ stream, -+ cublas_wrapper, -+ cublas_handle, -+ allocator, -+ false, // free buffer after fwd -+ true, // is_qk_buf_float_ -+ false, // is_cross -+ false, // sparse -+ false); // is_position_bias -+ -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size * opt_a->seq_len, hidden_units}, 0}); -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, 1, opt_a->seq_len, opt_a->seq_len}, 0}); -+ -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size * opt_a->seq_len, hidden_units}, 0}); -+ -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, 1, opt_a->seq_len, opt_a->seq_len}, 0}); -+ desc.output_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, hidden_units}, 0}); -+ -+ desc.output_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, hidden_units}, 0}); -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, 3 * hidden_units}, 0}); -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, hidden_units}, 0}); -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{3 * hidden_units}, 0}); -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units}, 0}); -+} -+template -+void InitializeAttnX2(opt_arg* opt_a, -+ Decriptor& desc, -+ cudaStream_t stream, -+ cublasMMWrapper* cublas_wrapper, -+ cublasHandle_t cublas_handle, -+ Allocator* allocator) -+{ -+ const size_t hidden_units = opt_a->head_num * opt_a->size_per_head; -+ -+ desc.MSLayer = new MSMHALayer(opt_a->batch_size, -+ opt_a->seq_len, -+ opt_a->tgt_seq_len, -+ opt_a->head_num, -+ opt_a->size_per_head, -+ stream, -+ cublas_wrapper, -+ cublas_handle, -+ allocator, -+ false, // free buffer after fwd -+ true, // is_qk_buf_float_ -+ false, // is_cross -+ false, // sparse -+ false); // is_position_bias -+ -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size * opt_a->seq_len, hidden_units}, 0}); -+ -+ // GPU RESULTS -+ desc.output_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, hidden_units}, 0}); -+ desc.output_tensors.push_back( -+ Tensor{MEMORY_GPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, opt_a->size_per_head}, -+ 0}); -+ desc.output_tensors.push_back( -+ Tensor{MEMORY_GPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, opt_a->size_per_head}, -+ 0}); -+ -+ desc.output_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, hidden_units}, 0}); -+ desc.output_python_tensors.push_back( -+ Tensor{MEMORY_CPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, opt_a->size_per_head}, -+ 0}); -+ desc.output_python_tensors.push_back( -+ Tensor{MEMORY_CPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, opt_a->size_per_head}, -+ 0}); -+ -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, hidden_units}, 0}); -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{3 * hidden_units}, 0}); -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, 2 * hidden_units}, 0}); -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, hidden_units}, 0}); -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units}, 0}); -+} -+ -+template -+void InitializeAttnCross(opt_arg* opt_a, -+ Decriptor& desc, -+ cudaStream_t stream, -+ cublasMMWrapper* cublas_wrapper, -+ cublasHandle_t cublas_handle, -+ Allocator* allocator) -+{ -+ const size_t hidden_units = opt_a->head_num * opt_a->size_per_head; -+ -+ desc.MSLayer = new MSMHALayer(opt_a->batch_size, -+ opt_a->seq_len, -+ opt_a->tgt_seq_len, -+ opt_a->head_num, -+ opt_a->size_per_head, -+ stream, -+ cublas_wrapper, -+ cublas_handle, -+ allocator, -+ false, // free buffer after fwd -+ true, // is_qk_buf_float_ -+ true, // is_cross -+ false, // sparse -+ false); // is_position_bias -+ -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size * opt_a->seq_len, hidden_units}, 0}); -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size * opt_a->tgt_seq_len, hidden_units}, 0}); -+ -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->tgt_seq_len}, 0}); -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size * opt_a->seq_len, hidden_units}, 0}); -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size * opt_a->tgt_seq_len, hidden_units}, 0}); -+ -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->tgt_seq_len}, 0}); -+ -+ // GPU RESULTS -+ -+ desc.output_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, hidden_units}, 0}); -+ // desc.output_tensors.push_back(Tensor{ -+ // MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, -+ // opt_a->size_per_head}, 0}); -+ // desc.output_tensors.push_back(Tensor{ -+ // MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, -+ // opt_a->size_per_head}, 0}); -+ -+ desc.output_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, hidden_units}, 0}); -+ // desc.output_python_tensors.push_back(Tensor{ -+ // MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, -+ // opt_a->size_per_head}, 0}); -+ // desc.output_python_tensors.push_back(Tensor{ -+ // MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, -+ // opt_a->size_per_head}, 0}); -+ -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, hidden_units}, 0}); -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{3 * hidden_units}, 0}); -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, 2 * hidden_units}, 0}); -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, hidden_units}, 0}); -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units}, 0}); -+} -+template -+void InitializeAttnT5(opt_arg* opt_a, -+ Decriptor& desc, -+ cudaStream_t stream, -+ cublasMMWrapper* cublas_wrapper, -+ cublasHandle_t cublas_handle, -+ Allocator* allocator) -+{ -+ const size_t hidden_units = opt_a->head_num * opt_a->size_per_head; -+ -+ desc.MSLayer = new MSMHALayer(opt_a->batch_size, -+ opt_a->seq_len, -+ opt_a->tgt_seq_len, -+ opt_a->head_num, -+ opt_a->size_per_head, -+ stream, -+ cublas_wrapper, -+ cublas_handle, -+ allocator, -+ false, // free buffer after fwd -+ true, // is_qk_buf_float_ -+ false, // is_cross -+ false, // sparse -+ true); // is_position_bias -+ -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size * opt_a->seq_len, hidden_units}, 0}); -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, 1, opt_a->seq_len, opt_a->seq_len}, 0}); -+ -+ desc.input_tensors.push_back( -+ Tensor{MEMORY_GPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, opt_a->tgt_seq_len}, -+ 0}); -+ -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size * opt_a->seq_len, hidden_units}, 0}); -+ -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, 1, opt_a->seq_len, opt_a->seq_len}, 0}); -+ -+ desc.input_python_tensors.push_back( -+ Tensor{MEMORY_CPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, opt_a->tgt_seq_len}, -+ 0}); -+ -+ // GPU RESULTS -+ -+ desc.output_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, hidden_units}, 0}); -+ // desc.output_tensors.push_back(Tensor{ -+ // MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, -+ // opt_a->size_per_head}, 0}); -+ // desc.output_tensors.push_back(Tensor{ -+ // MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, -+ // opt_a->size_per_head}, 0}); -+ // desc.output_tensors.push_back(Tensor{ -+ // MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, -+ // opt_a->tgt_seq_len},0}); -+ -+ desc.output_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, hidden_units}, 0}); -+ // desc.output_python_tensors.push_back(Tensor{ -+ // MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, -+ // opt_a->size_per_head}, 0}); -+ // desc.output_python_tensors.push_back(Tensor{ -+ // MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, -+ // opt_a->size_per_head}, 0}); -+ // desc.output_python_tensors.push_back(Tensor{ -+ // MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, -+ // opt_a->tgt_seq_len}, 0}); -+ -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, 3 * hidden_units}, 0}); -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, hidden_units}, 0}); -+} -+ -+template -+void InitializeAttnT5Cross(opt_arg* opt_a, -+ Decriptor& desc, -+ cudaStream_t stream, -+ cublasMMWrapper* cublas_wrapper, -+ cublasHandle_t cublas_handle, -+ Allocator* allocator) -+{ -+ const size_t hidden_units = opt_a->head_num * opt_a->size_per_head; -+ -+ desc.MSLayer = new MSMHALayer(opt_a->batch_size, -+ opt_a->seq_len, -+ opt_a->tgt_seq_len, -+ opt_a->head_num, -+ opt_a->size_per_head, -+ stream, -+ cublas_wrapper, -+ cublas_handle, -+ allocator, -+ false, // free buffer after fwd -+ true, // is_qk_buf_float_ -+ true, // is_cross -+ false, // sparse -+ true); // is_position_bias -+ -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size * opt_a->seq_len, hidden_units}, 0}); -+ -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size * opt_a->tgt_seq_len, hidden_units}, 0}); -+ -+ desc.input_tensors.push_back(Tensor{MEMORY_GPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, 1, opt_a->seq_len, opt_a->tgt_seq_len}, -+ 0}); -+ -+ desc.input_tensors.push_back( -+ Tensor{MEMORY_GPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, 1, opt_a->seq_len, opt_a->tgt_seq_len}, -+ 0}); -+ -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size * opt_a->seq_len, hidden_units}, 0}); -+ -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size * opt_a->tgt_seq_len, hidden_units}, 0}); -+ -+ desc.input_python_tensors.push_back( -+ Tensor{MEMORY_CPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, 1, opt_a->seq_len, opt_a->tgt_seq_len}, -+ 0}); -+ -+ desc.input_python_tensors.push_back( -+ Tensor{MEMORY_CPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, 1, opt_a->seq_len, opt_a->tgt_seq_len}, -+ 0}); -+ -+ // GPU RESULTS -+ -+ desc.output_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, hidden_units}, 0}); -+ // desc.output_tensors.push_back(Tensor{ -+ // MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, -+ // opt_a->size_per_head}, 0}); -+ // desc.output_tensors.push_back(Tensor{ -+ // MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, -+ // opt_a->size_per_head}, 0}); -+ // desc.output_tensors.push_back(Tensor{ -+ // MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, -+ // opt_a->tgt_seq_len},0}); -+ -+ desc.output_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, hidden_units}, 0}); -+ // desc.output_python_tensors.push_back(Tensor{ -+ // MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, -+ // opt_a->size_per_head}, 0}); -+ // desc.output_python_tensors.push_back(Tensor{ -+ // MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, -+ // opt_a->size_per_head}, 0}); -+ // desc.output_python_tensors.push_back(Tensor{ -+ // MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, -+ // opt_a->tgt_seq_len}, 0}); -+ -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, hidden_units}, 0}); -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, 2 * hidden_units}, 0}); -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, hidden_units}, 0}); -+} -+ -+template -+void InitializeEncoder(opt_arg* opt_a, -+ Decriptor& desc, -+ cudaStream_t stream, -+ cublasMMWrapper* cublas_wrapper, -+ cublasHandle_t cublas_handle, -+ Allocator* allocator) -+{ -+ // const size_t hidden_units = opt_a->head_num * opt_a->size_per_head; -+ const size_t hidden_units = opt_a->hidden_size; -+ // TODO Nizzan - check if need to be -+ desc.MSLayer = new MSELayer(opt_a->batch_size, -+ opt_a->seq_len, -+ opt_a->tgt_seq_len, -+ opt_a->head_num, -+ opt_a->size_per_head, -+ opt_a->ffn_hidden_size, -+ opt_a->eps1, -+ opt_a->eps2, -+ opt_a->post_layernorm_residual, -+ false, -+ opt_a->is_ffn_fp16, -+ stream, -+ cublas_wrapper, -+ cublas_handle, -+ allocator, -+ false, // free buffer after fwd -+ true, // is_qk_buf_float_ -+ false); // sparse -+ bool compress=false; -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->seq_len}, 0}); -+ if(compress) -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->seq_len}, 0}); -+ if(compress) -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); -+ -+ -+ desc.output_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); -+ -+ desc.output_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); -+ -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size}, 0}); -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size}, 0}); -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, 3 * hidden_units}, 0}); -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{3 * hidden_units}, 0}); -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, hidden_units}, 0}); -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units}, 0}); -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size}, 0}); -+ -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size}, 0}); -+ -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size, opt_a->ffn_hidden_size}, 0}); -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->ffn_hidden_size}, 0}); -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->ffn_hidden_size, opt_a->hidden_size}, 0}); -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size}, 0}); -+} -+template -+void InitializeEncoderT5(opt_arg* opt_a, -+ Decriptor& desc, -+ cudaStream_t stream, -+ cublasMMWrapper* cublas_wrapper, -+ cublasHandle_t cublas_handle, -+ Allocator* allocator) -+{ -+ // const size_t hidden_units = opt_a->head_num * opt_a->size_per_head; -+ const size_t hidden_units = opt_a->hidden_size; -+ // TODO Nizzan - check if need to be -+ desc.MSLayer = new MSELayer(opt_a->batch_size, -+ opt_a->seq_len, -+ opt_a->tgt_seq_len, -+ opt_a->head_num, -+ opt_a->size_per_head, -+ opt_a->ffn_hidden_size, -+ opt_a->eps1, -+ opt_a->eps2, -+ opt_a->post_layernorm_residual, -+ true, -+ opt_a->is_ffn_fp16, -+ stream, -+ cublas_wrapper, -+ cublas_handle, -+ allocator, -+ false, // free buffer after fwd -+ true, // is_qk_buf_float_ -+ false); // sparse -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->seq_len}, 0}); -+ desc.input_tensors.push_back( -+ Tensor{MEMORY_GPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, opt_a->tgt_seq_len}, -+ 0}); -+ desc.input_tensors.push_back( -+ Tensor{MEMORY_GPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, opt_a->seq_len}, -+ 0}); -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->seq_len}, 0}); -+ desc.input_python_tensors.push_back( -+ Tensor{MEMORY_CPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, opt_a->tgt_seq_len}, -+ 0}); -+ desc.input_python_tensors.push_back( -+ Tensor{MEMORY_CPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, opt_a->seq_len}, -+ 0}); -+ desc.output_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); -+ -+ desc.output_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); -+ -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size}, 0}); // g1 -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, 3 * hidden_units}, 0}); // wt -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, hidden_units}, 0}); // wp -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size}, 0}); // g2 -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size, opt_a->ffn_hidden_size}, 0}); -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->ffn_hidden_size, opt_a->hidden_size}, 0}); -+} -+ -+template -+void InitializeDecoder(opt_arg* opt_a, -+ Decriptor& desc, -+ cudaStream_t stream, -+ cublasMMWrapper* cublas_wrapper, -+ cublasHandle_t cublas_handle, -+ Allocator* allocator) -+{ -+ const size_t hidden_units = opt_a->head_num * opt_a->size_per_head; -+ // TODO Nizzan - check if need to be -+ desc.MSLayer = new MSDLayer(opt_a->batch_size, -+ opt_a->seq_len, -+ opt_a->tgt_seq_len, -+ opt_a->head_num, -+ opt_a->size_per_head, -+ opt_a->ffn_hidden_size, -+ opt_a->eps1, -+ opt_a->eps2, -+ opt_a->eps3, -+ opt_a->post_layernorm_residual, -+ false, -+ false, -+ opt_a->is_ffn_fp16, -+ stream, -+ cublas_wrapper, -+ cublas_handle, -+ allocator, -+ false, // free buffer after fwd -+ true, // is_qk_buf_float_ -+ false); // sparse -+ desc.input_tensors.push_back(Tensor{MEMORY_GPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, opt_a->tgt_seq_len, opt_a->hidden_size}, -+ 0}); -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->seq_len}, 0}); -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->tgt_seq_len, opt_a->seq_len}, 0}); -+ -+ desc.input_python_tensors.push_back( -+ Tensor{MEMORY_CPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, opt_a->tgt_seq_len, opt_a->hidden_size}, -+ 0}); -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->seq_len}, 0}); -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->tgt_seq_len, opt_a->seq_len}, 0}); -+ -+ desc.output_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); -+ -+ desc.output_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); -+ -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size}, 0}); // G1 -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size}, 0}); // B1 -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, 3 * hidden_units}, 0}); // wt -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{3 * hidden_units}, 0}); // bt -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, hidden_units}, 0}); // wp -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units}, 0}); // bp -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size}, 0}); // g1 -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size}, 0}); // b2 -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, hidden_units}, 0}); -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, hidden_units * 2}, 0}); // bt2 -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units * 3}, 0}); -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, hidden_units}, 0}); // wp2 -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units}, 0}); // bp2 -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size}, 0}); // g3 -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size}, 0}); // b3 -+ desc.w_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size, opt_a->ffn_hidden_size}, 0}); // wm -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->ffn_hidden_size}, 0}); // bm -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size, opt_a->ffn_hidden_size}, 0}); -+ ; // wp -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size}, 0}); // bp -+} -+template -+void InitializeDecoderT5(opt_arg* opt_a, -+ Decriptor& desc, -+ cudaStream_t stream, -+ cublasMMWrapper* cublas_wrapper, -+ cublasHandle_t cublas_handle, -+ Allocator* allocator) -+{ -+ const size_t hidden_units = opt_a->head_num * opt_a->size_per_head; -+ // TODO Nizzan - check if need to be -+ desc.MSLayer = new MSDLayer(opt_a->batch_size, -+ opt_a->seq_len, -+ opt_a->tgt_seq_len, -+ opt_a->head_num, -+ opt_a->size_per_head, -+ opt_a->ffn_hidden_size, -+ opt_a->eps1, -+ opt_a->eps2, -+ opt_a->eps3, -+ opt_a->post_layernorm_residual, -+ true, -+ true, -+ opt_a->is_ffn_fp16, -+ stream, -+ cublas_wrapper, -+ cublas_handle, -+ allocator, -+ false, // free buffer after fwd -+ true, // is_qk_buf_float_ -+ false); // sparse -+ desc.input_tensors.push_back(Tensor{MEMORY_GPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, opt_a->tgt_seq_len, opt_a->hidden_size}, -+ 0}); -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->seq_len}, 0}); -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); -+ desc.input_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->tgt_seq_len, opt_a->seq_len}, 0}); -+ desc.input_tensors.push_back( -+ Tensor{MEMORY_GPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, opt_a->tgt_seq_len}, -+ 0}); -+ desc.input_tensors.push_back( -+ Tensor{MEMORY_GPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, opt_a->tgt_seq_len}, -+ 0}); -+ -+ desc.input_python_tensors.push_back( -+ Tensor{MEMORY_CPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, opt_a->tgt_seq_len, opt_a->hidden_size}, -+ 0}); -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->seq_len}, 0}); -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); -+ desc.input_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->tgt_seq_len, opt_a->seq_len}, 0}); -+ desc.input_python_tensors.push_back( -+ Tensor{MEMORY_GPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, opt_a->tgt_seq_len}, -+ 0}); -+ desc.input_python_tensors.push_back( -+ Tensor{MEMORY_GPU, -+ getTensorType(), -+ std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, opt_a->tgt_seq_len}, -+ 0}); -+ -+ desc.output_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); -+ -+ desc.output_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size}, 0}); // G1 -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, 3 * hidden_units}, 0}); // wt -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, hidden_units}, 0}); // wp -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size}, 0}); // g1 -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, hidden_units}, 0}); -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, 2 * hidden_units}, 0}); -+ desc.w_tensors.push_back( -+ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, hidden_units}, 0}); // wp2 -+ desc.w_tensors.push_back(Tensor{MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size}, 0}); // g3 -+ desc.w_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size, opt_a->ffn_hidden_size}, 0}); // wm -+ desc.w_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->hidden_size, opt_a->ffn_hidden_size}, 0}); // wp -+} -+ -+template -+void Init(opt_arg* opt_a, -+ Decriptor& desc, -+ cudaStream_t stream, -+ cublasMMWrapper* cublas_wrapper, -+ cublasHandle_t cublas_handle, -+ Allocator* allocator) -+{ -+ int model_num = ModelNum(opt_a->model_name); -+ switch (model_num) { -+ case MHA_X1: -+ InitializeAttn(opt_a, desc, stream, cublas_wrapper, cublas_handle, allocator); -+ break; -+ case MHA_X2: -+ InitializeAttnX2(opt_a, desc, stream, cublas_wrapper, cublas_handle, allocator); -+ break; -+ case MHA_CROSS: -+ InitializeAttnCross(opt_a, desc, stream, cublas_wrapper, cublas_handle, allocator); -+ break; -+ case MHA_T5: -+ InitializeAttnT5(opt_a, desc, stream, cublas_wrapper, cublas_handle, allocator); -+ break; -+ case MHA_T5_CROSS: -+ InitializeAttnT5Cross(opt_a, desc, stream, cublas_wrapper, cublas_handle, allocator); -+ break; -+ case TEL: -+ InitializeEncoder(opt_a, desc, stream, cublas_wrapper, cublas_handle, allocator); -+ break; -+ case TEL_T5: -+ InitializeEncoderT5(opt_a, desc, stream, cublas_wrapper, cublas_handle, allocator); -+ break; -+ case TDL: -+ InitializeDecoder(opt_a, desc, stream, cublas_wrapper, cublas_handle, allocator); -+ break; -+ case TDL_T5: -+ InitializeDecoderT5(opt_a, desc, stream, cublas_wrapper, cublas_handle, allocator); -+ break; -+ default: -+ break; -+ } -+} -diff --git a/examples/cpp/ms/ms.cc b/examples/cpp/ms/ms.cc -new file mode 100644 -index 0000000..085407f ---- /dev/null -+++ b/examples/cpp/ms/ms.cc -@@ -0,0 +1,494 @@ -+/* -+ * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. -+ * -+ * Licensed under the Apache License, Version 2.0 (the "License"); -+ * you may not use this file except in compliance with the License. -+ * You may obtain a copy of the License at -+ * -+ * http://www.apache.org/licenses/LICENSE-2.0 -+ * -+ * Unless required by applicable law or agreed to in writing, software -+ * distributed under the License is distributed on an "AS IS" BASIS, -+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+ * See the License for the specific language governing permissions and -+ * limitations under the License. -+ */ -+#include "examples/cpp/ms/initialize.h" -+#include "src/fastertransformer/utils/logger.h" -+#include -+#include -+#include -+using namespace fastertransformer; -+ -+template -+int MsExample(opt_arg* opt_a); -+void usage() -+{ -+ std::cout << "Usage: ms_benchmark -b -l " -+ << "-s -H -S -p " -+ << "-T -W -F " -+ << "-m \n"; -+} -+ -+bool read_args(int argc, char* argv[], opt_arg* opt_a) -+{ -+ int opt; -+ while ((opt = getopt(argc, argv, "b:l:s:t:H:S:p:m:T:W:F:i:w:f:P:x:1:2:3")) != -1) { -+ switch (opt) { -+ case 'b': -+ opt_a->batch_size = atoi(optarg); -+ break; -+ case 'l': -+ opt_a->num_layers = atoi(optarg); -+ break; -+ case 's': -+ opt_a->seq_len = atoi(optarg); -+ break; -+ case 't': -+ opt_a->tgt_seq_len = atoi(optarg); -+ break; -+ case 'H': -+ opt_a->head_num = atoi(optarg); -+ break; -+ case 'S': -+ opt_a->hidden_size = atoi(optarg); -+ break; -+ case 'm': -+ opt_a->model_name = std::string(optarg); -+ break; -+ case 'T': -+ opt_a->compute_type = std::string(optarg); -+ break; -+ case 'W': -+ opt_a->w_compute_type = std::string(optarg); -+ break; -+ case 'F': -+ opt_a->s_compute_type = std::string(optarg); -+ break; -+ case 'f': -+ opt_a->ffn_hidden_size = atoi(optarg); -+ break; -+ case '1': -+ opt_a->eps1 = atoi(optarg); -+ break; -+ case '2': -+ opt_a->eps2 = atoi(optarg); -+ break; -+ case '3': -+ opt_a->eps3 = atoi(optarg); -+ break; -+ case 'P': -+ if (atoi(optarg) == 1) -+ opt_a->post_layernorm_residual = true; -+ else if (atoi(optarg) == 0) -+ opt_a->post_layernorm_residual = false; -+ break; -+ case 'p': -+ opt_a->is_remove_padding = bool(optarg); -+ break; -+ case 'x': -+ if (atoi(optarg) == 1) -+ opt_a->is_ffn_fp16 = true; -+ else if (atoi(optarg) == 0) -+ opt_a->is_ffn_fp16 = false; -+ break; -+ case 'i': -+ case 'w': -+ break; -+ case 'h': -+ default: -+ usage(); -+ return false; -+ } -+ } -+ opt_a->size_per_head = opt_a->hidden_size / opt_a->head_num; -+ opt_a->tgt_seq_len = (opt_a->tgt_seq_len == -1) ? opt_a->seq_len : opt_a->tgt_seq_len; -+ if (opt_a->ffn_hidden_size == -1) { -+ opt_a->ffn_hidden_size = opt_a->hidden_size * opt_a->expand_ratio; -+ } -+ return true; -+} -+ -+int main(int argc, char** argv) -+{ -+ opt_arg opt_a; -+ opt_a.batch_size = 1; -+ opt_a.num_layers = 1; -+ opt_a.seq_len = 1; -+ opt_a.tgt_seq_len = -1; -+ opt_a.head_num = 1; -+ opt_a.hidden_size = 1; -+ opt_a.size_per_head = 1; -+ opt_a.expand_ratio = 4; -+ opt_a.ffn_hidden_size = -1; -+ opt_a.eps1 = 1e-6f; -+ opt_a.eps2 = 1e-6f; -+ opt_a.eps3 = 1e-6f; -+ opt_a.post_layernorm_residual = true; -+ opt_a.is_remove_padding = false; -+ opt_a.model_name = ""; -+ opt_a.compute_type = "fp32"; -+ opt_a.w_compute_type = "fp32"; -+ opt_a.s_compute_type = "fp32"; -+ opt_a.is_ffn_fp16 = false; -+ -+ if (read_args(argc, argv, &opt_a)) { -+ bool c_type_fp32 = (opt_a.compute_type.compare("fp32") == 0); -+ bool w_type_fp32 = (opt_a.w_compute_type.compare("fp32") == 0); -+ bool s_type_fp32 = (opt_a.s_compute_type.compare("fp32") == 0); -+ -+ s_type_fp32 = c_type_fp32; // Do softmax compute type as compute type -+ if (c_type_fp32 && w_type_fp32 && s_type_fp32) { -+ return MsExample(&opt_a); -+ } -+ else if (c_type_fp32 && w_type_fp32 && !s_type_fp32) { -+ return MsExample(&opt_a); -+ } -+ else if (c_type_fp32 && !w_type_fp32 && s_type_fp32) { -+ return MsExample(&opt_a); -+ } -+ else if (c_type_fp32 && !w_type_fp32 && !s_type_fp32) { -+ return MsExample(&opt_a); -+ } -+ else if (!c_type_fp32 && w_type_fp32 && s_type_fp32) { -+ return MsExample(&opt_a); -+ } -+ else if (!c_type_fp32 && w_type_fp32 && !s_type_fp32) { -+ return MsExample(&opt_a); -+ } -+ else if (!c_type_fp32 && !w_type_fp32 && s_type_fp32) { -+ return MsExample(&opt_a); -+ } -+ else { // (!c_type_fp32 && !w_type_fp32 && !s_type_fp32) -+ return MsExample(&opt_a); -+ } -+ } -+} -+ -+template -+int ReadFileBuf(const std::string file, T* buf, size_t size_buff) -+{ -+ if (file.empty()) { -+ FT_LOG_ERROR("file is nullptr\n"); -+ return -1; -+ } -+ -+ std::ifstream ifs(file); -+ if (!ifs.good()) { -+ FT_LOG_ERROR("file: %s does not exist\n", file.c_str()); -+ return -1; -+ } -+ -+ if (!ifs.is_open()) { -+ FT_LOG_ERROR("file: open failed\n"); -+ return -1; -+ } -+ -+ ifs.seekg(0, std::ios::end); -+ size_t file_size = ifs.tellg(); -+ if (file_size != size_buff) { -+ ifs.close(); -+ FT_LOG_ERROR("file: %s size is %d desc size is %d\n", file.c_str(), file_size, size_buff); -+ return -1; -+ } -+ // return 0; -+ ifs.seekg(0, std::ios::beg); -+ ifs.read(reinterpret_cast(buf), size_buff); -+ ifs.close(); -+ return 0; -+} -+ -+template -+int CalcTensorsSize(std::vector& tensors) -+{ -+ int total = 0; -+ for (size_t i = 0; i < tensors.size(); i++) { -+ float size = 1; -+ for (size_t j = 0; j < tensors[i].shape.size(); j++) { -+ size *= tensors[i].shape[j]; -+ } -+ total += size; -+ } -+ -+ return total * sizeof(T); -+} -+ -+template -+int ReadTensors(std::vector& tensors, std::string post, opt_arg* opt_a, bool cpy = true) -+{ -+ for (size_t i = 0; i < tensors.size(); i++) { -+ // if (tensors[i].type != TYPE_FP32) { -+ // FT_LOG_ERROR("Type not supported, exiting "); -+ // return -1; -+ // } -+ float size = 1; -+ for (size_t j = 0; j < tensors[i].shape.size(); j++) { -+ size *= tensors[i].shape[j]; -+ } -+ std::string suffix = post.compare("weight") == 0 ? opt_a->w_compute_type : opt_a->compute_type; -+ std::string fn = opt_a->model_name + "_" + post + std::to_string(i + 1) + "." + suffix; -+ T* input; -+ T* input_host = (T*)malloc(size * sizeof(T)); -+ int res = ReadFileBuf(fn, input_host, size * sizeof(T)); -+ if (res) { -+ fn = opt_a->model_name + "_" + post + std::to_string(i + 1) + "." + "fp16"; -+ res = ReadFileBuf(fn, input_host, size * 2); -+ } -+ FT_CHECK(!res); -+ if (tensors[i].where == MEMORY_GPU) { -+ deviceMalloc(&input, size, false); -+ if (cpy) -+ cudaH2Dcpy(input, input_host, size); -+ else -+ deviceMemSetZero(input, size); -+ tensors[i].data = input; -+ free(input_host); -+ input_host = 0; -+ } -+ else if (tensors[i].where == MEMORY_CPU) { -+ tensors[i].data = input_host; -+ } -+ } -+ return 0; -+} -+ -+template -+static float CompareData(const T* refOutput, int size, const T* msTensorData) -+{ -+ constexpr float relativeTolerance = 1e-5; -+ constexpr float absoluteTolerance = 1e-8; -+ size_t errorCount = 0; -+ float meanError = 0; -+ std::cout << "Out tensor size is: " << size << std::endl; -+ std::cout << "Data of model output: "; -+ static int x = 0; -+ int s = std::min(10, size); -+ if (x == 0) { -+ for (int j = 0; j < std::min(50, size); j++) { -+ std::cout << static_cast(msTensorData[j]) << " "; -+ } -+ std::cout << std::endl; -+ std::cout << "Data of Ref output : "; -+ for (int j = 0; j < std::min(50, size); j++) { -+ std::cout << static_cast(refOutput[j]) << " "; -+ } -+ std::cout << std::endl; -+ } -+ x++; -+ int nan_cnt = 0; -+ for (int j = 0; j < size; j++) { -+ if (std::isnan(msTensorData[j]) || std::isinf(msTensorData[j])) { -+ // std::cerr << "Output tensor has nan or inf data, compare fail" << std::endl; -+ // FT_LOG_ERROR("Output tensor has nan or inf data, compare fail\n"); -+ // return RET_ERROR; -+ // return -1; -+ nan_cnt++; -+ continue; -+ } -+ -+ auto tolerance = absoluteTolerance + relativeTolerance * fabs(refOutput[j]); -+ auto absoluteError = std::fabs(static_cast(msTensorData[j]) - static_cast(refOutput[j])); -+ if (absoluteError > tolerance) { -+ if (fabs(refOutput[j]) == 0) { -+ if (absoluteError > 1e-5) { -+ meanError += absoluteError; -+ errorCount++; -+ } -+ else { -+ continue; -+ } -+ } -+ else { -+ // if (absoluteError > 1e-2) std::cout << "idx=" < 0.0f) { -+ meanError /= errorCount; -+ } -+ if (meanError <= 0.0000001) { -+ std::cout << "Mean bias of tensor: 0%" << std::endl; -+ } -+ else { -+ std::cout << "Mean bias of tensor: " << meanError * 100 << "%" << std::endl; -+ } -+ std::cout << std::endl; -+ return meanError; -+} -+ -+template -+int CompareOutput(std::vector output_python_tensors, std::vector output_tensors) -+{ -+ float total_bias = 0; -+ int total_size = 0; -+ float accuracy_threshold_ = 0.5f; -+ bool has_error = false; -+ for (size_t i = 0; i < output_tensors.size(); i++) { -+ float size = 1; -+ for (size_t j = 0; j < output_tensors[i].shape.size(); j++) { -+ size *= output_tensors[i].shape[j]; -+ } -+ T* output_device = (T*)output_tensors[i].data; -+ T* output_host = (T*)malloc(size * sizeof(T)); -+ cudaD2Hcpy(output_host, output_device, size); -+ float bias = CompareData((T*)output_python_tensors[i].data, size, output_host); -+ free(output_host); -+ if (bias >= 0) { -+ total_bias += bias; -+ total_size++; -+ } -+ else { -+ has_error = true; -+ break; -+ } -+ } -+ if (!has_error) { -+ float mean_bias; -+ if (total_size != 0) { -+ mean_bias = total_bias / total_size * 100; -+ } -+ else { -+ mean_bias = 0; -+ } -+ -+ std::cout << "Mean bias of all nodes/tensors: " << mean_bias << "%" -+ << " threshold is:" << accuracy_threshold_ << std::endl; -+ std::cout << "=======================================================" << std::endl << std::endl; -+ -+ if (mean_bias > accuracy_threshold_) { -+ FT_LOG_INFO("Mean bias of all nodes/tensors is too big: %f %", mean_bias); -+ std::cout << "Mean bias of all nodes/tensors is too big: " << mean_bias << "%" << std::endl; -+ return -9; -+ } -+ else { -+ return 0; -+ } -+ } -+ else { -+ FT_LOG_ERROR("Error in CompareData"); -+ std::cerr << "Error in CompareData" << std::endl; -+ std::cout << "=======================================================" << std::endl << std::endl; -+ return -1; -+ } -+} -+ -+void FreeDesc(std::vector& desc) -+{ -+ for (size_t i = 0; i < desc.size(); i++) { -+ if (desc[i].where == MEMORY_GPU) { -+ cudaFree((float*)desc[i].data); -+ } -+ else if (desc[i].where == MEMORY_CPU) { -+ free((float*)desc[i].data); -+ } -+ } -+} -+ -+uint64_t GetTimeUs() -+{ -+ const int USEC = 1000000; -+ const int MSEC = 1000; -+ struct timespec ts = {0, 0}; -+ if (clock_gettime(CLOCK_MONOTONIC, &ts) != 0) { -+ return 0; -+ } -+ uint64_t retval = (uint64_t)((ts.tv_sec * USEC) + (ts.tv_nsec / MSEC)); -+ return retval; -+} -+ -+template -+int MsExample(opt_arg* opt_a) -+{ -+ printf("[INFO] Device: %s \n", getDeviceName().c_str()); -+ -+ cudaStream_t stream; -+ cublasHandle_t cublas_handle; -+ cublasLtHandle_t cublaslt_handle; -+ cudaStreamCreate(&stream); -+ cublasCreate(&cublas_handle); -+ cublasLtCreate(&cublaslt_handle); -+#ifdef SPARSITY_ENABLED -+ cusparseLtHandle_t cusparselt_handle; -+ CHECK_CUSPARSE(cusparseLtInit(&cusparselt_handle)); -+#endif -+ cublasSetStream(cublas_handle, stream); -+ cublasAlgoMap* cublas_algo_map = new cublasAlgoMap("gemm_config.in", ""); -+ -+ Allocator allocator(getDevice()); -+ -+ std::mutex* cublas_wrapper_mutex = new std::mutex(); -+#ifdef SPARSITY_ENABLED -+ cublasMMWrapper cublas_wrapper = cublasMMWrapper( -+ cublas_handle, cublaslt_handle, cusparselt_handle, stream, cublas_algo_map, cublas_wrapper_mutex, &allocator); -+#else -+ cublasMMWrapper cublas_wrapper = -+ cublasMMWrapper(cublas_handle, cublaslt_handle, stream, cublas_algo_map, cublas_wrapper_mutex, &allocator); -+#endif -+ if (std::is_same::value) { -+ if (std::is_same::value) { -+ cublas_wrapper.setFP16MixedGemmConfig(); -+ } -+ else { -+ cublas_wrapper.setFP16GemmConfig(); -+ } -+ } -+ else if (std::is_same::value) { -+ if (std::is_same::value) { -+ cublas_wrapper.setFP32MixedGemmConfig(); -+ } -+ else { -+ cublas_wrapper.setFP32GemmConfig(); -+ } -+ } -+ Decriptor desc; -+ Init(opt_a, desc, stream, &cublas_wrapper, cublas_handle, &allocator); -+ int res = ReadTensors(desc.input_tensors, std::string("input"), opt_a); -+ FT_CHECK(!res); -+ res = ReadTensors(desc.input_python_tensors, std::string("input"), opt_a); -+ FT_CHECK(!res); -+ res = ReadTensors(desc.output_tensors, std::string("output"), opt_a, false); -+ FT_CHECK(!res); -+ res = ReadTensors(desc.output_python_tensors, std::string("output"), opt_a); -+ FT_CHECK(!res); -+ res = ReadTensors(desc.w_tensors, std::string("weight"), opt_a); -+ FT_CHECK(!res); -+ desc.MSLayer->InitWeight(opt_a, desc.MSLayer->ms_weights, desc.w_tensors); -+ desc.MSLayer->forward(&desc.output_tensors, &desc.input_tensors, desc.MSLayer->ms_weights); -+ CompareOutput(desc.output_python_tensors, desc.output_tensors); -+#define DO_TIME1 -+#ifdef DO_TIME -+ // warmup -+ for (int i = 0; i < 10; i++) { -+ desc.MSLayer->forward(&desc.output_tensors, &desc.input_tensors, desc.MSLayer->ms_weights); -+ } -+ // profile time -+ const int ite = 1000; -+ CudaTimer cuda_timer(stream); -+ cuda_timer.start(); -+ float total_time = cuda_timer.stop(); -+ printf("batch_size %ld seq_len %ld layer %ld " -+ "AVG FT-CPP-time %.2f ms (%d iterations) " -+ "Total Time %.2f ms\n", -+ opt_a->batch_size, -+ opt_a->seq_len, -+ opt_a->num_layers, -+ total_time / ite, -+ ite, -+ total_time); -+#endif -+ -+#ifdef SPARSITY_ENABLED -+ cusparseLtDestroy(&cusparselt_handle); -+#endif -+ delete cublas_algo_map; -+ delete cublas_wrapper_mutex; -+ FreeDesc(desc.output_tensors); -+ FreeDesc(desc.input_tensors); -+ FreeDesc(desc.output_python_tensors); -+ FreeDesc(desc.w_tensors); -+ return 0; -+} -diff --git a/examples/pytorch/swin/Swin-Transformer-Quantization/SwinTransformer b/examples/pytorch/swin/Swin-Transformer-Quantization/SwinTransformer -new file mode 160000 -index 0000000..cbaa0d8 ---- /dev/null -+++ b/examples/pytorch/swin/Swin-Transformer-Quantization/SwinTransformer -@@ -0,0 +1 @@ -+Subproject commit cbaa0d8707db403d85ad0e13c59f2f71cd6db425 -diff --git a/examples/pytorch/vit/ViT-quantization/ViT-pytorch b/examples/pytorch/vit/ViT-quantization/ViT-pytorch -new file mode 160000 -index 0000000..460a162 ---- /dev/null -+++ b/examples/pytorch/vit/ViT-quantization/ViT-pytorch -@@ -0,0 +1 @@ -+Subproject commit 460a162767de1722a014ed2261463dbbc01196b6 -diff --git a/path.sh b/path.sh -new file mode 100755 -index 0000000..53f5ca6 ---- /dev/null -+++ b/path.sh -@@ -0,0 +1 @@ -+export PATH=/usr/local/cuda-11/bin:/home/yoni/.vscode-server/bin/4af164ea3a06f701fe3e89a2bcbb421d2026b68f/bin/remote-cli:/home/yoni/.local/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/snap/bin -diff --git a/src/fastertransformer/kernels/CMakeLists.txt b/src/fastertransformer/kernels/CMakeLists.txt -index 3db0830..3dd4210 100644 ---- a/src/fastertransformer/kernels/CMakeLists.txt -+++ b/src/fastertransformer/kernels/CMakeLists.txt -@@ -159,9 +159,12 @@ add_library(matrix_vector_multiplication STATIC matrix_vector_multiplication.cu) - set_property(TARGET matrix_vector_multiplication PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET matrix_vector_multiplication PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) - --add_library(custom_ar_kernels STATIC custom_ar_kernels.cu) --set_property(TARGET custom_ar_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) --set_property(TARGET custom_ar_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -+if(${SM} GREATER_EQUAL 70) -+ message("-- Making custom kernels") -+ add_library(custom_ar_kernels STATIC custom_ar_kernels.cu) -+ set_property(TARGET custom_ar_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) -+ set_property(TARGET custom_ar_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -+endif() - - add_library(vit_kernels STATIC vit_kernels.cu) - set_property(TARGET vit_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) -diff --git a/src/fastertransformer/kernels/activation_kernels.cu b/src/fastertransformer/kernels/activation_kernels.cu -index 7ff8e0f..abe7634 100644 ---- a/src/fastertransformer/kernels/activation_kernels.cu -+++ b/src/fastertransformer/kernels/activation_kernels.cu -@@ -19,6 +19,82 @@ - #include "src/fastertransformer/utils/cuda_utils.h" - namespace fastertransformer { - -+template -+__inline__ __device__ T fastGelu(T x) -+{ -+ T abs_x = fabsf((T)x); -+ float numerator = expf(0.851f * (x - abs_x)); -+ float denominator = 1 + expf(-1.702f * abs_x); -+ return (T)(x / denominator * numerator); -+} -+ -+template<> -+__inline__ __device__ half fastGelu(half x) -+{ -+ half abs_x = (half)(fabsf(__half2float(x))); -+ half numerator = hexp((half)(0.851f) * (x - abs_x)); -+ half denominator = (half)1 + hexp(half(-1.702f) * abs_x); -+ return (x / denominator * numerator); -+} -+ -+template<> -+__inline__ __device__ half2 fastGelu(half2 x) -+{ -+ half2 half2_x_abs = __habs2(x); -+ half2 numerator = h2exp(half2(0.851, 0.851) * (x - half2_x_abs)); -+ half2 denominator = half2(1, 1) + h2exp(half2(-1.702, -1.702) * half2_x_abs); -+ return (x / denominator * numerator); -+} -+ -+template -+__global__ void addBiasFastGelu(T* out, const T* __restrict bias, int m, int n) -+{ -+ for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { -+ T val = out[id]; -+ if (bias != nullptr) { -+ T reg_bias = __ldg(&bias[id % n]); -+ val = val + reg_bias; -+ } -+ out[id] = (fastGelu(val)); -+ } -+} -+ -+template<> -+__global__ void addBiasFastGelu(half* out, const half* __restrict bias, int m, int n) -+{ -+ half2* out_ptr = (half2*)out; -+ const half2* bias_ptr = (half2*)bias; -+ -+ for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { -+ half2 val = out_ptr[id]; -+ if (bias != nullptr) { -+ half2 reg_bias = __ldg(&bias_ptr[id % n]); -+ val = __hadd2(val, reg_bias); -+ } -+ out_ptr[id] = fastGelu(val); -+ } -+} -+ -+template -+void invokeAddBiasFastGelu(T* out, const T* bias, const int m, const int n, cudaStream_t stream) -+{ -+ const int data_type_factor = 4 / sizeof(T); // 1 for fp32, 2 for fp16 and bf16 -+ dim3 block, grid; -+ if (n / 4 / data_type_factor <= 1024) { -+ block.x = n / 4 / data_type_factor; -+ grid.x = m; -+ } -+ else { -+ block.x = 1024; -+ grid.x = ceil(m * n / 1024.); -+ } -+ addBiasFastGelu<<>>(out, bias, m, n / data_type_factor); -+} -+ -+template void invokeAddBiasFastGelu(float* out, const float* bias, const int m, const int n, cudaStream_t stream); -+template void invokeAddBiasFastGelu(half* out, const half* bias, const int m, const int n, cudaStream_t stream); -+ -+ - template - __inline__ __device__ T gelu(T x) - { -@@ -201,12 +277,21 @@ __global__ void add_bias(H_T* out, const B_T* __restrict bias, int m, int n) - } - } - -+template -+__global__ void add_bias_basic(H_T* out, const B_T* __restrict bias, int m, int n) -+{ -+ for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { -+ out[id] = out[id] + (H_T)ldg(&bias[id % n]); -+ } -+} -+ - template<> - __global__ void add_bias(half* out, const half* __restrict bias, int m, int n) - { - half2* out_ptr = (half2*)out; - const half2* bias_ptr = (half2*)bias; -- for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { -+ int id = blockIdx.x * blockDim.x + threadIdx.x; -+ for (; id < m * n; id += blockDim.x * gridDim.x) { - out_ptr[id] = out_ptr[id] + __ldg(&bias_ptr[id % n]); - } - } -@@ -228,15 +313,29 @@ void invokeAddBias(H_T* out, const B_T* bias, const int m, const int n, cudaStre - { - const int data_type_factor = 4 / sizeof(H_T); // 1 for fp32, 2 for fp16 and bf16 - dim3 block, grid; -- if (n / 4 / data_type_factor <= 1024) { -- block.x = n / 4 / data_type_factor; -- grid.x = m; -- } -- else { -- block.x = 1024; -- grid.x = ceil(m * n / 1024.); -+ -+ bool reminder = (data_type_factor != 1) ? (n % data_type_factor) : false; -+ if (reminder) { -+ if (n / 4 <= 1024) { -+ block.x = n / 4; -+ grid.x = m; -+ } -+ else { -+ block.x = 1024; -+ grid.x = ceil(m * n / 1024.); -+ } -+ add_bias_basic<<>>(out, bias, m, n); -+ } else { -+ if (n / 4 / data_type_factor <= 1024) { -+ block.x = n / 4 / data_type_factor; -+ grid.x = m; -+ } -+ else { -+ block.x = 1024; -+ grid.x = ceil(m * n / 1024.); -+ } -+ add_bias<<>>(out, bias, m, (n / data_type_factor)); - } -- add_bias<<>>(out, bias, m, n / data_type_factor); - } - - template void invokeAddBias(float* out, const float* bias, const int m, const int n, cudaStream_t stream); -diff --git a/src/fastertransformer/kernels/activation_kernels.h b/src/fastertransformer/kernels/activation_kernels.h -index 6600457..f8c379a 100644 ---- a/src/fastertransformer/kernels/activation_kernels.h -+++ b/src/fastertransformer/kernels/activation_kernels.h -@@ -25,6 +25,9 @@ namespace fastertransformer { - template - void invokeAddBiasGelu(T* out, const T* bias, const int m, const int n, cudaStream_t stream); - -+template -+void invokeAddBiasFastGelu(T* out, const T* bias, const int m, const int n, cudaStream_t stream); -+ - template - void invokeAddBiasRelu(T* out, const T* bias, const int m, const int n, cudaStream_t stream); - -diff --git a/src/fastertransformer/kernels/add_residual_kernels.cu b/src/fastertransformer/kernels/add_residual_kernels.cu -index 4cd9f0f..42c9216 100644 ---- a/src/fastertransformer/kernels/add_residual_kernels.cu -+++ b/src/fastertransformer/kernels/add_residual_kernels.cu -@@ -29,6 +29,30 @@ __global__ void addBiasResidual(T* output, const T* input, const T* bias, const - } - } - -+template -+__global__ void addBiasResidualCast(U* output, const T* input, T* out, const T* bias, const int m, const int n) -+{ -+ S *out_cast = (S*)out; -+ const int col_index = blockIdx.y * blockDim.x + threadIdx.x; -+ if (col_index < n) { -+ T bias_val = (bias == nullptr) ? (T)(0.0f) : bias[col_index]; -+ out_cast[blockIdx.x * n + col_index] = -+ (S)((T)output[blockIdx.x * n + col_index] + (T)input[blockIdx.x * n + col_index] + bias_val); -+ } -+} -+ -+template -+__global__ void addBiasResidualSameTypeCast(U* output, const U* input, T* out, const T* bias, const int m, const int n) -+{ -+ S *out_cast = (S*)out; -+ const int col_index = blockIdx.y * blockDim.x + threadIdx.x; -+ if (col_index < n) { -+ T bias_val = (bias == nullptr) ? (T)(0.0f) : bias[col_index]; -+ out_cast[blockIdx.x * n + col_index] = -+ (S)((T)output[blockIdx.x * n + col_index] + (T)input[blockIdx.x * n + col_index] + bias_val); -+ } -+} -+ - template - void invokeAddBiasResidual(T* output, const T* input, const T* bias, const int m, const int n, cudaStream_t stream) - { -@@ -38,6 +62,31 @@ void invokeAddBiasResidual(T* output, const T* input, const T* bias, const int m - addBiasResidual<<>>(output, input, bias, m, n); - } - -+template -+void invokeAddBiasResidualCast(U* output, const T* input, T* out, const T* bias, const int m, const int n, cudaStream_t stream) -+{ -+ int blocks_per_row = ceil(float(n) / 1024); -+ dim3 grid(m, blocks_per_row); -+ dim3 block(min(n, 1024)); -+ addBiasResidualCast<<>>(output, input, out, bias, m, n); -+} -+ -+template -+void invokeAddBiasResidualSameTypeCast(U* output, const U* input, T* out, const T* bias, const int m, const int n, cudaStream_t stream) -+{ -+ int blocks_per_row = ceil(float(n) / 1024); -+ dim3 grid(m, blocks_per_row); -+ dim3 block(min(n, 1024)); -+ addBiasResidualSameTypeCast<<>>(output, input, out, bias, m, n); -+} -+ -+template void invokeAddBiasResidualCast(half* output, const float* input, float* out, const float* bias, const int m, const int n, cudaStream_t stream); -+template void invokeAddBiasResidualCast(float* output, const float* input, float* out, const float* bias, const int m, const int n, cudaStream_t stream); -+template void invokeAddBiasResidualCast(float* output, const float* input, float* out, const float* bias, const int m, const int n, cudaStream_t stream); -+template void invokeAddBiasResidualCast(half* output, const float* input, float* out, const float* bias, const int m, const int n, cudaStream_t stream); -+ -+template void invokeAddBiasResidualSameTypeCast(half* output, const half* input, float* out, const float* bias, const int m, const int n, cudaStream_t stream); -+ - template - __global__ void addBiasAttentionFfnResidual(T* block_output, - const T* ffn_output, -@@ -88,11 +137,9 @@ void invokeAddBiasAttentionFfnResidual(T* block_output, - } - } - --template void invokeAddBiasResidual( -- float* output, const float* input, const float* bias, const int m, const int n, cudaStream_t stream); -+template void invokeAddBiasResidual(float *output, const float *input, const float *bias, int m, int n, cudaStream_t stream); -+template void invokeAddBiasResidual(half *output, const half *input, const half *bias, int m, int n, cudaStream_t stream); - --template void --invokeAddBiasResidual(half* output, const half* input, const half* bias, const int m, const int n, cudaStream_t stream); - - #ifdef ENABLE_BF16 - template void invokeAddBiasResidual(__nv_bfloat16* output, -diff --git a/src/fastertransformer/kernels/add_residual_kernels.h b/src/fastertransformer/kernels/add_residual_kernels.h -index edd8179..afa5a77 100644 ---- a/src/fastertransformer/kernels/add_residual_kernels.h -+++ b/src/fastertransformer/kernels/add_residual_kernels.h -@@ -27,6 +27,9 @@ namespace fastertransformer { - template - void invokeAddBiasResidual(T* output, const T* input, const T* bias, const int m, const int n, cudaStream_t stream); - -+template -+void invokeAddBiasResidual(T* output, const T* input, const T* bias, const int m, const int n, int max_seq, const int *sequent_len, cudaStream_t stream); -+ - template - void invokeT5AddResidual(T* output, const T* input, const int m, const int n, cudaStream_t stream); - -@@ -65,4 +68,11 @@ void invokeAddBiasResidualCol32(T* output, - const float* input1_amax_ptr, - const int scale_is_vector = 0); - -+template -+void invokeAddBiasResidualCast(U* output, const T* input, T* out, const T* bias, const int m, const int n, cudaStream_t stream); -+ -+template -+void invokeAddBiasResidualSameTypeCast(U* output, const U* input, T* out, const T* bias, const int m, const int n, cudaStream_t stream); -+ - } // namespace fastertransformer -+ -diff --git a/src/fastertransformer/kernels/bert_preprocess_kernels.cu b/src/fastertransformer/kernels/bert_preprocess_kernels.cu -index c855fa1..19e29bc 100644 ---- a/src/fastertransformer/kernels/bert_preprocess_kernels.cu -+++ b/src/fastertransformer/kernels/bert_preprocess_kernels.cu -@@ -14,10 +14,13 @@ - * limitations under the License. - */ - -+#include "reduce_kernel_utils.cuh" - #include "bert_preprocess_kernels.h" -+#include "src/fastertransformer/utils/cuda_utils.h" - - namespace fastertransformer { - -+ - __global__ void getPaddingOffsetKernel(size_t* valid_word_num, - int* tmp_mask_offset, - const int* sequence_length, -@@ -29,7 +32,7 @@ __global__ void getPaddingOffsetKernel(size_t* valid_word_num, - int cum_offset = 0; - int index = 0; - for (int i = 0; i < batch_size; i++) { -- const int seq_len = sequence_length[i]; -+ const int seq_len = (sequence_length[i] == -1) ? 0 : sequence_length[i]; - for (int j = 0; j < seq_len; j++) { - tmp_mask_offset[index] = cum_offset; - index++; -@@ -50,50 +53,315 @@ void invokeGetPaddingOffset(size_t* h_token_num, - { - getPaddingOffsetKernel<<<1, 1, 0, stream>>>( - d_token_num, tmp_mask_offset, sequence_lengths, batch_size, max_seq_len); -- sync_check_cuda_error(); -- check_cuda_error(cudaMemcpyAsync(h_token_num, d_token_num, sizeof(size_t), cudaMemcpyDeviceToHost, stream)); -- sync_check_cuda_error(); -+ if (h_token_num != nullptr) { -+ cudaMemcpyAsync(h_token_num, d_token_num, sizeof(size_t), cudaMemcpyDeviceToHost, stream); -+ } - } - - template --__global__ void buildEncoderAttentionMaskKernel(T* attention_mask, const int* sequence_lengths, const int max_seq_len) -+__global__ void buildSequnceLength(const T * input, int *sequence_length, const int max_seq_length, const int hidden_size) { -+ __shared__ int s_max_val; -+ int bid = blockIdx.x; -+ const T * seq_base = input + bid* max_seq_length * hidden_size; -+ const T zero = static_cast(0.f); -+ int last = -max_seq_length; -+ for (int i=max_seq_length - 1 - threadIdx.x; i >= 0; i -= blockDim.x) { -+ const T * seq_ptr = seq_base + i * hidden_size; -+ if ((seq_ptr[0] == zero) && (seq_ptr[1] == zero)) { -+ last = -i; -+ } -+ } -+ int max_val = blockReduceMax(last); -+ if (threadIdx.x == 0) { -+ s_max_val = max_val; -+ } -+ __syncthreads(); -+ sequence_length[bid] = -s_max_val; -+} -+ -+__global__ void buildSequnceLength(const int *input, int *sequence_length, const int max_seq_length) { -+ __shared__ int s_max_val; -+ int bid = blockIdx.x; -+ int last = 0; -+ const int *base = input + bid * max_seq_length; -+ for (int i=threadIdx.x ; i < max_seq_length; i += blockDim.x) { -+ const int *ptr = base + i; -+ if (*ptr != 0){ -+ last = i; -+ } -+ } -+ int max_val = blockReduceMax(last); -+ if (threadIdx.x == 0) { -+ s_max_val = max_val + 1; -+ } -+ __syncthreads(); -+ sequence_length[bid] = s_max_val; -+} -+ -+__global__ void buildSequnceOffset(int *sequence_length, int *sequence_offset, int batch_size) { -+ for (int i = 0; i < batch_size ; i += 1) { -+ if (i == 0) {sequence_offset[i] = 0;} -+ else { -+ sequence_offset[i] = sequence_offset[i-1] + sequence_length[i-1] + 1; -+ } -+ } -+} -+__global__ void findStartPoint(int *sequence_lengths, int *start_points, int batch_size,int hidden_size) { -+ int i = blockDim.x * blockIdx.x + threadIdx.x; -+ if (i <= batch_size) { -+ int sum = 0; -+ for (int j = 0; j < i; j++) { -+ sum += sequence_lengths[j]*hidden_size; -+ } -+ start_points[i] = sum; -+ } -+} -+ -+ -+ -+template -+__global__ void buildEncoderAttentionMaskKernel(T* attention_mask, const int* q_sequence_lengths, const int* kv_sequence_lengths, const int src_seq_len, const int tgt_seq_len, const bool incremental_mode) - { - // sequence_lengths: [batch_size] - // attention_mask: [batch_size, 1, max_seq_len, max_seq_len] -- attention_mask += blockIdx.x * max_seq_len * max_seq_len; -- const int length = sequence_lengths[blockIdx.x]; -- for (int i = threadIdx.x; i < max_seq_len * max_seq_len; i += blockDim.x) { -- // int row_id = i / max_seq_len; -- int col_id = i % max_seq_len; -- // if (row_id < length && col_id < length) { -+ attention_mask += blockIdx.x * src_seq_len * tgt_seq_len; -+ const int q_length = q_sequence_lengths[blockIdx.x]; -+ const int kv_length = kv_sequence_lengths[blockIdx.x]; -+ for (int i = threadIdx.x; i < src_seq_len * tgt_seq_len; i += blockDim.x) { -+ int row_id = i / tgt_seq_len; -+ int col_id = i % tgt_seq_len; - // TODO (bhsueh) check this modification is ok or not on other rmodel -- if (col_id < length) { -- attention_mask[i] = (T)(1.0f); -- } -- else { -+ if (col_id >= q_length || row_id >= kv_length) { - attention_mask[i] = (T)(0.0f); - } - } - } - -+ - template - void invokeBuildEncoderAttentionMask( -- T* attention_mask, const int* sequence_lengths, const int batch_size, const int max_seq_len, cudaStream_t stream) -+ T* attention_mask, const int* q_sequence_lengths, const int* kv_sequence_lengths, const int batch_size, const int src_seq_len, const int tgt_seq_len, const bool incremental_mode, cudaStream_t stream) - { -- buildEncoderAttentionMaskKernel<<>>(attention_mask, sequence_lengths, max_seq_len); -+ buildEncoderAttentionMaskKernel<<>>(attention_mask, q_sequence_lengths, kv_sequence_lengths, src_seq_len, tgt_seq_len, incremental_mode); - } - -+ - template void invokeBuildEncoderAttentionMask(float* attention_mask, -- const int* sequence_lengths, -+ const int* q_sequence_lengths, -+ const int* kv_sequence_lengths, - const int batch_size, -- const int max_seq_len, -+ const int src_seq_len, -+ const int tgt_seq_len, -+ const bool incremental_mode, - cudaStream_t stream); - template void invokeBuildEncoderAttentionMask(half* attention_mask, -- const int* sequence_lengths, -+ const int* q_sequence_lengths, -+ const int* kv_sequence_lengths, - const int batch_size, -- const int max_seq_len, -+ const int src_seq_len, -+ const int tgt_seq_len, -+ const bool incremental_mode, - cudaStream_t stream); - -+__global__ void buildUsePastSeqLenKernel(int *sequence_length_src, int *sequence_offset_dst, int batch_size, bool incremental_mode)//inc fix -+{ -+ // sequence_lengths: [batch_size] -+ // sequence_lengths2: [batch_size] -+ if (sequence_length_src[blockIdx.x] == -1) -+ sequence_offset_dst[blockIdx.x] = -1; -+ else if (!incremental_mode) { -+ sequence_offset_dst[blockIdx.x] = sequence_length_src[blockIdx.x] + 1; -+ } else { -+ sequence_offset_dst[blockIdx.x] = 1; -+ } -+} -+ -+ -+void buildUsePastSeqLenKernel(int *sequence_length_src, int *sequence_offset_dst, int batch_size, bool incremental_mode, cudaStream_t stream) -+{ -+ buildUsePastSeqLenKernel<<>>(sequence_length_src, sequence_offset_dst, batch_size, incremental_mode); -+} -+ -+template -+__global__ void buildUsePastAttentionMaskKernel(T* attention_mask, const int tgt_seq_len) -+{ -+ // attention_mask: [1, 1, tgt_seq_len] -+ for (int i = threadIdx.x; i < tgt_seq_len; i += blockDim.x) { -+ attention_mask[i] = (T)(1.0f); -+ } -+} -+ -+ -+template -+void invokeBuildUsePastAttentionMask( -+ T* attention_mask, const int tgt_seq_len, cudaStream_t stream) -+{ -+ buildUsePastAttentionMaskKernel<<<1, 256, 0, stream>>>(attention_mask, tgt_seq_len); -+} -+ -+template void invokeBuildUsePastAttentionMask(float* attention_mask, -+ const int tgt_seq_len, -+ cudaStream_t stream); -+template void invokeBuildUsePastAttentionMask(half* attention_mask, -+ const int tgt_seq_len, -+ cudaStream_t stream); -+ -+ template -+__global__ void EmbeddingPanguSigma(const int *input, -+ const int *input_position, -+ const T *emmbeding_table, -+ const T *emmbeding_pos_table, -+ T *output, -+ int h_token_num, -+ int hidden_size) { -+ -+ for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < h_token_num * hidden_size; -+ index += gridDim.x * blockDim.x) { -+ // Gather the value from the input array -+ // Store the value in the output array -+ -+ int h_token_idx = index / hidden_size; -+ T val = (T)(emmbeding_pos_table[input_position[h_token_idx] * hidden_size + index % hidden_size]); -+ output[index] = val + emmbeding_table[input[h_token_idx] * hidden_size + index % hidden_size]; -+ } -+} -+__global__ void EmbeddingPanguSigmaHalf(const int *input, -+ const int *input_position, -+ const half *emmbeding_table, -+ const half *emmbeding_pos_table, -+ half *output, -+ int h_token_num, -+ int hidden_size) { -+ half2* output_ptr = (half2*)output; -+ const half2* emmbeding_table_ptr = (half2*)emmbeding_table; -+ const half2* emmbeding_pos_table_ptr = (half2*)emmbeding_pos_table; -+ -+ for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < h_token_num * (hidden_size / 2); -+ index += gridDim.x * blockDim.x) { -+ // Gather the value from the input array -+ // Store the value in the output array -+ int h_token_idx = (index * 2) / hidden_size; -+ half2 val = emmbeding_pos_table_ptr[(input_position[h_token_idx] * hidden_size + (index * 2) % hidden_size) / 2]; -+ half2 val2 = emmbeding_table_ptr[(input[h_token_idx] * hidden_size + (index * 2) % hidden_size) / 2]; -+ output_ptr[index] = __hadd2(val, val2); -+ } -+} -+template -+void invokeEmbeddingPanguSigma(const int *input, -+ const int *input_position, -+ const T *emmbeding_table, -+ const T *emmbeding_pos_table, -+ T *output, -+ int h_token_num, -+ int hidden_size, -+ cudaStream_t stream) { -+ const int m = h_token_num; -+ const int n = hidden_size; -+ const int data_type_factor = (hidden_size % 2 == 1) ? 1 : 4 / sizeof(T); // 1 for fp32, 2 for fp16 and bf16 -+ dim3 block, grid; -+ if (n / 4 / data_type_factor <= 1024) { -+ block.x = n / 4 / data_type_factor + n % data_type_factor; -+ grid.x = m; -+ } -+ else { -+ block.x = 1024; -+ grid.x = ceil(m * (n / data_type_factor) / 1024.); -+ } -+ if (data_type_factor == 1) { -+ EmbeddingPanguSigma<<>>( -+ input, input_position, emmbeding_table, emmbeding_pos_table, output, h_token_num, hidden_size); -+ } else { -+ EmbeddingPanguSigmaHalf<<>>( -+ input, input_position, (const half*)emmbeding_table, (const half*)emmbeding_pos_table, (half*)output, h_token_num, hidden_size); -+ } -+} -+template void invokeEmbeddingPanguSigma(const int *input, -+ const int *input_position, -+ const float *emmbeding_table, -+ const float *emmbeding_pos_table, -+ float *output, -+ int h_token_num, -+ int hidden_size, -+ cudaStream_t stream); -+template void invokeEmbeddingPanguSigma(const int *input, -+ const int *input_position, -+ const half *emmbeding_table, -+ const half *emmbeding_pos_table, -+ half *output, -+ int h_token_num, -+ int hidden_size, -+ cudaStream_t stream); -+ -+ -+template -+__global__ void VocabEmbedding(const int *input, -+ const T *emmbeding_table, -+ T *output, -+ int h_token_num, -+ int hidden_size) { -+ for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < h_token_num * hidden_size; -+ index += gridDim.x * blockDim.x) { -+ int h_token_idx = index / hidden_size; -+ // Gather the value from the input array -+ // Store the value in the output array -+ output[index] = emmbeding_table[input[h_token_idx] * hidden_size + index % hidden_size]; -+ } -+} -+ -+__global__ void VocabEmbeddingHalf(const int *input, -+ const half *emmbeding_table, -+ half *output, -+ int h_token_num, -+ int hidden_size) { -+ half2* output_ptr = (half2*)output; -+ const half2* emmbeding_table_ptr = (half2*)emmbeding_table; -+ for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < h_token_num * (hidden_size / 2); -+ index += gridDim.x * blockDim.x) { -+ int h_token_idx = (index * 2) / hidden_size; -+ // Gather the value from the input array -+ // Store the value in the output array -+ output_ptr[index] = emmbeding_table_ptr[(input[h_token_idx] * hidden_size + (index * 2) % hidden_size) / 2]; -+ } -+} -+template -+void invokeVocabEmbedding(const int *input, -+ const T *emmbeding_table, -+ T *output, -+ int h_token_num, -+ int hidden_size, -+ cudaStream_t stream) { -+ const int m = h_token_num; -+ int n = hidden_size; -+ const int data_type_factor = (hidden_size % 2 == 1) ? 1 : 4 / sizeof(T); // 1 for fp32, 2 for fp16 and bf16 -+ dim3 block, grid; -+ if (n / 4 / data_type_factor <= 1024) { -+ block.x = n / 4 / data_type_factor; -+ grid.x = m; -+ } -+ else { -+ block.x = 1024; -+ grid.x = ceil(m * (n / data_type_factor) / 1024.); -+ } -+ if (data_type_factor == 1) { -+ VocabEmbedding<<>>( -+ input, emmbeding_table, output, h_token_num, hidden_size); -+ } else { -+ VocabEmbeddingHalf<<>>( -+ input, (const half*)emmbeding_table, (half*)output, h_token_num, hidden_size); -+ } -+} -+template void invokeVocabEmbedding(const int *input, -+ const float *emmbeding_table, -+ float *output, -+ int h_token_num, -+ int hidden_size, -+ cudaStream_t stream); -+template void invokeVocabEmbedding(const int *input, -+ const half *emmbeding_table, -+ half *output, -+ int h_token_num, -+ int hidden_size, -+ cudaStream_t stream); - __global__ void getTrtPaddingOffsetKernel(int* trt_mha_padding_offset, const int* sequence_length, const int batch_size) - { - // use for get tensorrt fused mha padding offset -@@ -113,6 +381,26 @@ __global__ void getTrtPaddingOffsetKernel(int* trt_mha_padding_offset, const int - } - } - -+ -+ -+template -+void invokeBuildSequenceLength(const T * input, int batch_size, int *sequnce_length, int max_seq_length, int hidden_size,cudaStream_t stream) { -+ buildSequnceLength<<>>(input,sequnce_length, max_seq_length,hidden_size); -+} -+ -+void invokeBuildSequenceLength(const int * input, int batch_size, int *sequnce_length, int max_seq_length,cudaStream_t stream) { -+ buildSequnceLength<<>>(input,sequnce_length, max_seq_length); -+} -+void invokeBuildSequnceOffset(int batch_size, int *sequnce_length, int* sequnce_offset,int hidden_size,cudaStream_t stream) { -+ findStartPoint<<<1, batch_size+1, 0, stream>>>(sequnce_length, sequnce_offset, batch_size+1,hidden_size); -+ -+} -+ -+ -+ -+ -+ -+ - void invokeGetTrtPaddingOffset(int* trt_mha_padding_offset, - const int* sequence_length, - const int batch_size, -@@ -176,7 +464,52 @@ void invokeRebuildPadding( - // dst: [batch_size*max_seq_len, hidden_dim] - rebuild_sequence_length_padding<<>>(src, dst, padding_offset, n); - } -- -+template -+__global__ void rebuild_query_padding(const T* src, T* dst, const int* d_seq_len, const int batch, const int n) -+{ -+ // const int tid = threadIdx.x; -+ for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < batch * n; -+ index += gridDim.x * blockDim.x) { -+ int dst_index = index; -+ int batch_id = index / n; -+ int token_idx = 0; -+ int i{0}; -+ while (i <= batch_id) -+ { -+ if (d_seq_len[i] == -1) { -+ dst_index += n; -+ batch_id++; -+ } -+ else { -+ token_idx += d_seq_len[i]; -+ } -+ i++; -+ } -+ dst[dst_index] = src[(token_idx - 1) * n + (index % n)]; -+ } -+} -+template -+void invokeRebuildQuery( -+ T* dst, const T* src, const int* d_seq_len, const int batch, const int n, cudaStream_t stream) -+{ -+ // src: [token_num, hidden_dim] -+ // dst: [batch_size, hidden_dim] -+ dim3 grid((int)(ceil(1.0 * batch * n / 512))); -+ dim3 block(512); -+ rebuild_query_padding<<>>(src, dst, d_seq_len, batch, n); -+} -+template void invokeRebuildQuery(float* dst, -+ const float* src, -+ const int* d_seq_len, -+ const int batch, -+ const int n, -+ cudaStream_t stream); -+template void invokeRebuildQuery(half* dst, -+ const half* src, -+ const int* d_seq_len, -+ const int batch, -+ const int n, -+ cudaStream_t stream); - template - void invokeRebuildPadding( - T* dst, const T* src, const int* padding_offset, const int token_num, const int hidden_dim, cudaStream_t stream); -@@ -226,6 +559,12 @@ template void invokeRemovePadding(half* dst, - const int token_num, - const int hidden_dim, - cudaStream_t stream); -+template void invokeRemovePadding(int* dst, -+ const int* src, -+ const int* padding_offset, -+ const int token_num, -+ const int hidden_dim, -+ cudaStream_t stream); - - template - __global__ void buildRelativeAttentionBias(T* relative_attention_bias, -@@ -300,6 +639,8 @@ void invokeBuildRelativeAttentionBias(T* relative_attention_bias, - is_bidirectional, - max_distance); - } -+template void invokeBuildSequenceLength(const float * input, int batch_size, int *sequnce_length, int max_seq_length, int hidden_size,cudaStream_t stream); -+template void invokeBuildSequenceLength(const half * input, int batch_size, int *sequnce_length, int max_seq_length, int hidden_size,cudaStream_t stream); - - template void invokeBuildRelativeAttentionBias(float* relative_attention_bias, - const float* relative_attention_bias_table, -diff --git a/src/fastertransformer/kernels/bert_preprocess_kernels.h b/src/fastertransformer/kernels/bert_preprocess_kernels.h -index dcb8f85..f444b53 100644 ---- a/src/fastertransformer/kernels/bert_preprocess_kernels.h -+++ b/src/fastertransformer/kernels/bert_preprocess_kernels.h -@@ -19,6 +19,8 @@ - #include "src/fastertransformer/utils/cuda_utils.h" - #include - #include -+#include -+#include - - namespace fastertransformer { - -@@ -32,7 +34,17 @@ void invokeGetPaddingOffset(size_t* h_token_num, - - template - void invokeBuildEncoderAttentionMask( -- T* attention_mask, const int* sequence_lengths, const int batch_size, const int max_seq_len, cudaStream_t stream); -+ T* attention_mask, const int* q_sequence_lengths, const int* kv_sequence_lengths, const int batch_size, const int src_seq_len, const int tgt_seq_len, const bool incremental_mode, cudaStream_t stream); -+ -+template -+void invokeBuildUsePastAttentionMask( -+ T* attention_mask, const int tgt_seq_len, cudaStream_t stream); -+ -+template -+void invokeBuildSequenceLength(const T * input, int batch_size, int *sequnce_length, int max_seq_length, int hidden_size,cudaStream_t stream); -+ -+void invokeBuildSequenceLength(const int* input, int batch_size, int *sequnce_length, int max_seq_length,cudaStream_t stream); -+void invokeBuildSequnceOffset(int batch_size, int *sequnce_length, int* sequnce_offset,int hidden_size,cudaStream_t stream); - - void invokeGetTrtPaddingOffset(int* trt_mha_padding_offset, - const int* sequence_length, -@@ -46,6 +58,25 @@ void invokeGetTrtPaddingOffset(int* trt_mha_padding_offset, - cudaStream_t stream); - - template -+void invokeEmbeddingPanguSigma(const int *input, -+ const int *input_position, -+ const T *emmbeding_table, -+ const T *emmbeding_pos_table, -+ T *output, -+ int h_token_num, -+ int hidden_size, -+ cudaStream_t stream); -+template -+void invokeVocabEmbedding(const int *input, -+ const T *emmbeding_table, -+ T *output, -+ int h_token_num, -+ int hidden_size, -+ cudaStream_t stream); -+template -+void invokeRebuildQuery( -+ T* dst, const T* src, const int* d_seq_len, const int batch, const int n, cudaStream_t stream); -+template - void invokeRebuildPadding( - T* dst, const T* src, const int* padding_offset, const int token_num, const int hidden_dim, cudaStream_t stream); - -@@ -63,5 +94,6 @@ void invokeBuildRelativeAttentionBias(T* relative_attention_bias, - const int max_distance, - const PositionEmbeddingType position_embedding_type, - cudaStream_t stream); -- -+void buildUsePastSeqLenKernel( -+ int *sequence_length_src, int *sequence_offset_dst, int batch_size, bool incremental_mode, cudaStream_t stream); - } // namespace fastertransformer -diff --git a/src/fastertransformer/kernels/layernorm_kernels.cu b/src/fastertransformer/kernels/layernorm_kernels.cu -index 96a090e..e7bfec4 100644 ---- a/src/fastertransformer/kernels/layernorm_kernels.cu -+++ b/src/fastertransformer/kernels/layernorm_kernels.cu -@@ -13,6 +13,8 @@ - * See the License for the specific language governing permissions and - * limitations under the License. - */ -+#include -+ - - #include "src/fastertransformer/kernels/bfloat16_fallback_kenrels.cuh" - #include "src/fastertransformer/kernels/layernorm_kernels.h" -@@ -29,7 +31,8 @@ __global__ void generalAddBiasResidualLayerNormOpt(T* normed_output, - const T* __restrict gamma, - const T* __restrict beta, - int m, -- int n) -+ int n, -+ float eps) - { - __shared__ float s_mean; - __shared__ float s_variance; -@@ -74,7 +77,7 @@ __global__ void generalAddBiasResidualLayerNormOpt(T* normed_output, - variance = blockReduceSum(local_var_sum); - - if (threadIdx.x == 0) { -- s_variance = rsqrtf(variance / n / 2 + 1e-6f); -+ s_variance = rsqrtf(variance / n / 2 + eps); - } - __syncthreads(); - -@@ -93,14 +96,15 @@ __global__ void generalAddBiasResidualLayerNormOpt(T* normed_output, - - // * Note that typename T is half2 or bfloat2 type - template --__global__ void generalAddBiasResidualLayerNormOpt2(T* normed_output, -- T* output, -+__global__ void generalAddBiasResidualLayerNormOpt2(T* normed_output,//out -+ T* output,//out - const T* __restrict bias, -- const T* __restrict residual, -+ const T* __restrict residual,//input - const T* __restrict gamma, - const T* __restrict beta, - int m, -- int n) -+ int n, -+ float eps) - { - __shared__ float s_mean; - __shared__ float s_variance; -@@ -108,7 +112,6 @@ __global__ void generalAddBiasResidualLayerNormOpt2(T* normed_output, - float x2_sum = 0.0f; - const int b_offset = blockIdx.x * n; - using T1 = typename TypeConverter::Type; -- - #pragma unroll UNROLL_FACTOR - for (int i = threadIdx.x; i < n; i += blockDim.x) { - const int index = b_offset + i; -@@ -145,7 +148,7 @@ __global__ void generalAddBiasResidualLayerNormOpt2(T* normed_output, - - if (threadIdx.x == 0) { - s_mean = sums[0] / n / 2; -- s_variance = rsqrtf(sums[1] / n / 2 - s_mean * s_mean + 1e-6f); -+ s_variance = rsqrtf(sums[1] / n / 2 - s_mean * s_mean + eps); - } - __syncthreads(); - -@@ -166,7 +169,7 @@ __global__ void generalAddBiasResidualLayerNormOpt2(T* normed_output, - // TODO(bhsueh) add half2 implementation - template - __global__ void --addBiasResidualPostLayerNorm(T* out, const T* input, const T* bias, const T* gamma, const T* beta, int m, int n) -+addBiasResidualPostLayerNorm(T* out, const T* input, const T* bias, const T* gamma, const T* beta, int m, int n, float eps) - { - __shared__ float s_mean; - __shared__ float s_variance; -@@ -197,7 +200,7 @@ addBiasResidualPostLayerNorm(T* out, const T* input, const T* bias, const T* gam - } - variance = blockReduceSum(variance); - if (threadIdx.x == 0) { -- s_variance = variance / n + 1e-6f; -+ s_variance = variance / n + eps; - } - __syncthreads(); - -@@ -209,10 +212,62 @@ addBiasResidualPostLayerNorm(T* out, const T* input, const T* bias, const T* gam - idx += blockDim.x; - } - } -+template -+__global__ void addBiasResidualPostLayerNormCast(S* attn_output, -+ D* norm_attn_out, -+ const S* __restrict input, -+ const T* __restrict bias, -+ const T* __restrict gamma, -+ const T* __restrict beta, -+ int m, -+ int n, -+ float eps) -+{ -+ __shared__ float s_mean; -+ __shared__ float s_variance; -+ float mean = 0.0f; -+ float variance = 0.0f; -+ float local_out_cache[N]; -+ -+#pragma unroll N -+ for (int idx = threadIdx.x, i = 0; idx < n && i < N; ++i) { -+ float local_out = (float)((T)attn_output[blockIdx.x * n + idx] + (T)input[blockIdx.x * n + idx] + (T)__ldg(&bias[idx])); -+ mean += local_out; -+ // save local_out to local_out_cache to save some recompute -+ local_out_cache[i] = local_out; -+ idx += blockDim.x; -+ } -+ -+ mean = blockReduceSum(mean); -+ if (threadIdx.x == 0) { -+ s_mean = mean / n; -+ } -+ __syncthreads(); -+ -+#pragma unroll N -+ for (int idx = threadIdx.x, i = 0; idx < n && i < N; ++i) { -+ float local_out = local_out_cache[i]; -+ variance += (local_out - s_mean) * (local_out - s_mean); -+ idx += blockDim.x; -+ } -+ variance = blockReduceSum(variance); -+ if (threadIdx.x == 0) { -+ s_variance = variance / n + eps; -+ } -+ __syncthreads(); -+ -+#pragma unroll N -+ for (int idx = threadIdx.x, i = 0; idx < n && i < N; ++i) { -+ float local_out = local_out_cache[i]; -+ norm_attn_out[blockIdx.x * n + idx] = -+ (D)(((local_out - s_mean) * rsqrtf(s_variance)) * (float)(__ldg(&gamma[idx])) + (float)(__ldg(&beta[idx]))); -+ idx += blockDim.x; -+ } -+} - - template - __global__ void addBiasResidualPostLayerNormHalf( -- half* out, const half* input, const half* bias, const half* gamma, const half* beta, int m, int n) -+ half* out, const half* input, const half* bias, const half* gamma, const half* beta, int m, int n, float eps) - { - __shared__ float s_mean; - __shared__ float s_variance; -@@ -255,7 +310,7 @@ __global__ void addBiasResidualPostLayerNormHalf( - - variance = blockReduceSum(variance); - if (threadIdx.x == 0) { -- s_variance = rsqrtf(variance / n + 1e-6f); -+ s_variance = rsqrtf(variance / n + eps); - } - __syncthreads(); - -@@ -274,7 +329,7 @@ __global__ void addBiasResidualPostLayerNormHalf( - - template - __global__ void --generalAddBiasResidualPostLayerNorm(T* out, const T* input, const T* bias, const T* gamma, const T* beta, int m, int n) -+generalAddBiasResidualPostLayerNorm(T* out, const T* input, const T* bias, const T* gamma, const T* beta, int m, int n, float eps) - { - __shared__ float s_mean; - __shared__ float s_variance; -@@ -300,7 +355,7 @@ generalAddBiasResidualPostLayerNorm(T* out, const T* input, const T* bias, const - } - variance = blockReduceSum(variance); - if (threadIdx.x == 0) { -- s_variance = variance / n + 1e-6f; -+ s_variance = variance / n + eps; - } - __syncthreads(); - -@@ -311,9 +366,55 @@ generalAddBiasResidualPostLayerNorm(T* out, const T* input, const T* bias, const - } - } - -+template -+__global__ void generalAddBiasResidualPostLayerNormCast(S* attn_output, -+ D* norm_attn_out, -+ const S* __restrict input, -+ const T* __restrict bias, -+ const T* __restrict gamma, -+ const T* __restrict beta, -+ int m, -+ int n, -+ float eps) -+{ -+ __shared__ float s_mean; -+ __shared__ float s_variance; -+ float mean = 0.0f; -+ float variance = 0.0f; -+ -+ for (int idx = threadIdx.x; idx < n; idx += blockDim.x) { -+ float local_out = (float)((T)attn_output[blockIdx.x * n + idx] + (T)input[blockIdx.x * n + idx] + (T)__ldg(&bias[idx])); -+ mean += local_out; -+ // save local_out to out to save some recompute -+ attn_output[blockIdx.x * n + idx] = (T)local_out; -+ } -+ -+ mean = blockReduceSum(mean); -+ if (threadIdx.x == 0) { -+ s_mean = mean / n; -+ } -+ __syncthreads(); -+ -+ for (int idx = threadIdx.x; idx < n; idx += blockDim.x) { -+ float local_out = (T)attn_output[blockIdx.x * n + idx]; -+ variance += (local_out - s_mean) * (local_out - s_mean); -+ } -+ variance = blockReduceSum(variance); -+ if (threadIdx.x == 0) { -+ s_variance = variance / n + eps; -+ } -+ __syncthreads(); -+ -+ for (int idx = threadIdx.x; idx < n; idx += blockDim.x) { -+ float local_out = attn_output[blockIdx.x * n + idx]; -+ norm_attn_out[blockIdx.x * n + idx] = -+ (D)(((local_out - s_mean) * rsqrtf(s_variance)) * (float)(__ldg(&gamma[idx])) + (float)(__ldg(&beta[idx]))); -+ } -+} -+ - template<> - __global__ void generalAddBiasResidualPostLayerNorm( -- half* out, const half* input, const half* bias, const half* gamma, const half* beta, int m, int n) -+ half* out, const half* input, const half* bias, const half* gamma, const half* beta, int m, int n, float eps) - { - __shared__ float s_mean; - __shared__ float s_variance; -@@ -352,7 +453,7 @@ __global__ void generalAddBiasResidualPostLayerNorm( - - variance = blockReduceSum(variance); - if (threadIdx.x == 0) { -- s_variance = rsqrtf(variance / n + 1e-6f); -+ s_variance = rsqrtf(variance / n + eps); - } - __syncthreads(); - -@@ -373,7 +474,8 @@ __global__ void addBiasResidualPostLayerNormV2(T* out, - const T* __restrict bias, - const T* __restrict gamma, - const T* __restrict beta, -- int n) -+ int n, -+ float eps) - { - const int ite = 4; - const int tid = threadIdx.x; -@@ -409,7 +511,7 @@ __global__ void addBiasResidualPostLayerNormV2(T* out, - - variance = blockReduceSum(var); - if (tid == 0) { -- s_variance = rsqrtf(variance / n + 1e-6f); -+ s_variance = rsqrtf(variance / n + eps); - } - __syncthreads(); - -@@ -428,7 +530,8 @@ __global__ void addBiasResidualPostLayerNormV2(half* out, - const half* __restrict bias, - const half* __restrict gamma, - const half* __restrict beta, -- int n) -+ int n, -+ float eps) - { - const int ite = 4; - const int tid = threadIdx.x; -@@ -473,7 +576,7 @@ __global__ void addBiasResidualPostLayerNormV2(half* out, - - variance = blockReduceSum(var); - if (threadIdx.x == 0) { -- s_variance = rsqrtf(variance / n + 1e-6f); -+ s_variance = rsqrtf(variance / n + eps); - } - __syncthreads(); - -@@ -486,26 +589,154 @@ __global__ void addBiasResidualPostLayerNormV2(half* out, - } - } - -+template -+__global__ void addBiasResidualPostLayerNormV2Cast(S* attn_output, -+ D* norm_attn_out, -+ const S* __restrict input, -+ const T* __restrict bias, -+ const T* __restrict gamma, -+ const T* __restrict beta, -+ int n, -+ float eps) -+{ -+ const int ite = 4; -+ const int tid = threadIdx.x; -+ const int bid = blockIdx.x; -+ -+ __shared__ float s_mean; -+ __shared__ float s_variance; -+ float mean = 0.0f; -+ float variance = 0.0f; -+ float local_out[ite]; -+ -+ float sum = 0.0f; -+#pragma unroll -+ for (int i = 0; i < ite; i++) { -+ int col_id = i * blockDim.x + tid; -+ int id = bid * n + col_id; -+ local_out[i] = (float)((T)(attn_output[id]) + (T)__ldg(&input[id]) + (T)__ldg(&bias[col_id])); -+ sum += local_out[i]; -+ } -+ -+ mean = blockReduceSum(sum); -+ if (tid == 0) { -+ s_mean = mean / n; -+ } -+ __syncthreads(); -+ -+ float var = 0.0f; -+#pragma unroll -+ for (int i = 0; i < ite; i++) { -+ float diff = local_out[i] - s_mean; -+ var += diff * diff; -+ } -+ -+ variance = blockReduceSum(var); -+ if (tid == 0) { -+ s_variance = rsqrtf(variance / n + eps); -+ } -+ __syncthreads(); -+ -+#pragma unroll -+ for (int i = 0; i < ite; i++) { -+ int col_id = i * blockDim.x + tid; -+ int id = bid * n + col_id; -+ norm_attn_out[id] = -+ (D)((local_out[i] - s_mean) * s_variance * (float)__ldg(&gamma[col_id]) + (float)__ldg(&beta[col_id])); -+ } -+} -+ -+template -+void invokeAddBiasResidualLayerNormCast( -+ S* attn_output, D* norm_attn_out, const S* input, const T* bias, const T* gamma, const T* beta, int m, int n, cudaStream_t stream, float eps) -+{ -+ dim3 grid(m); -+ dim3 block(std::min(n, 1024)); -+ if (n == 768 || n == 1024) { -+ addBiasResidualPostLayerNormV2Cast<<>>(attn_output, norm_attn_out, input, bias, gamma, beta, n, eps); -+ } -+ else { -+ block.x = std::min(n, 1024); -+ int num_trips = (n + block.x - 1) / block.x; -+ if (num_trips == 1) { -+ addBiasResidualPostLayerNormCast<<>>(attn_output, norm_attn_out, input, bias, gamma, beta, m, n, eps); -+ } -+ else if (num_trips == 2) { -+ addBiasResidualPostLayerNormCast<<>>(attn_output, norm_attn_out, input, bias, gamma, beta, m, n, eps); -+ } -+ else { -+ generalAddBiasResidualPostLayerNormCast<<>>(attn_output, norm_attn_out, input, bias, gamma, beta, m, n, eps); -+ } -+ } -+} -+ -+ -+template void invokeAddBiasResidualLayerNormCast(float* out, half* norm_attn_out, -+ const float* input, -+ const float* bias, -+ const float* gamma, -+ const float* beta, -+ int m, -+ int n, -+ cudaStream_t stream, -+ float eps); -+ -+template void invokeAddBiasResidualLayerNormCast(half* out, float* norm_attn_out, -+ const half* input, -+ const float* bias, -+ const float* gamma, -+ const float* beta, -+ int m, -+ int n, -+ cudaStream_t stream, -+ float eps); -+ -+ -+template void invokeGeneralAddBiasResidualPreLayerNormCast( -+ float* attn_output, -+ half* norm_output, -+ const float* from_tensor, -+ const float* gamma, -+ const float* beta, -+ const float* bias, -+ int m, -+ int n, -+ cudaStream_t stream, -+ float eps, -+ int opt_version); -+ -+template void invokeGeneralAddBiasResidualT5PreLayerNormCast( -+ float* attn_output, -+ half* norm_output, -+ const float* from_tensor, -+ const float* gamma, -+ const float* beta, -+ const float* bias, -+ int m, -+ int n, -+ cudaStream_t stream, -+ float eps, -+ int opt_version); - template - void invokeAddBiasResidualLayerNorm( -- T* out, const T* input, const T* bias, const T* gamma, const T* beta, int m, int n, cudaStream_t stream) -+ T* out, const T* input, const T* bias, const T* gamma, const T* beta, int m, int n, cudaStream_t stream, float eps) - { - dim3 grid(m); - dim3 block(std::min(n, 1024)); - if (n == 768 || n == 1024) { -- addBiasResidualPostLayerNormV2<<>>(out, input, bias, gamma, beta, n); -+ addBiasResidualPostLayerNormV2<<>>(out, input, bias, gamma, beta, n, eps); - } - else { - block.x = std::min(n, 1024); - int num_trips = (n + block.x - 1) / block.x; - if (num_trips == 1) { -- addBiasResidualPostLayerNorm<<>>(out, input, bias, gamma, beta, m, n); -+ addBiasResidualPostLayerNorm<<>>(out, input, bias, gamma, beta, m, n, eps); - } - else if (num_trips == 2) { -- addBiasResidualPostLayerNorm<<>>(out, input, bias, gamma, beta, m, n); -+ addBiasResidualPostLayerNorm<<>>(out, input, bias, gamma, beta, m, n, eps); - } - else { -- generalAddBiasResidualPostLayerNorm<<>>(out, input, bias, gamma, beta, m, n); -+ generalAddBiasResidualPostLayerNorm<<>>(out, input, bias, gamma, beta, m, n, eps); - } - } - } -@@ -518,25 +749,26 @@ void invokeAddBiasResidualLayerNorm(half* out, - const half* beta, - int m, - int n, -- cudaStream_t stream) -+ cudaStream_t stream, -+ float eps) - { - dim3 grid(m); - dim3 block(std::min(n, 1024)); - - if (m >= 512 && (n == 768 || n == 1024)) { -- addBiasResidualPostLayerNormV2<<>>(out, input, bias, gamma, beta, n); -+ addBiasResidualPostLayerNormV2<<>>(out, input, bias, gamma, beta, n, eps); - } - else { - block.x = std::min(n, 1024); - int num_trips = (n + block.x - 1) / block.x; - if (num_trips == 1) { -- addBiasResidualPostLayerNorm<<>>(out, input, bias, gamma, beta, m, n); -+ addBiasResidualPostLayerNorm<<>>(out, input, bias, gamma, beta, m, n, eps); - } - else if (num_trips == 2) { -- addBiasResidualPostLayerNorm<<>>(out, input, bias, gamma, beta, m, n); -+ addBiasResidualPostLayerNorm<<>>(out, input, bias, gamma, beta, m, n, eps); - } - else { -- generalAddBiasResidualPostLayerNorm<<>>(out, input, bias, gamma, beta, m, n); -+ generalAddBiasResidualPostLayerNorm<<>>(out, input, bias, gamma, beta, m, n, eps); - } - } - } -@@ -548,7 +780,8 @@ template void invokeAddBiasResidualLayerNorm(float* out, - const float* beta, - int m, - int n, -- cudaStream_t stream); -+ cudaStream_t stream, -+ float eps); - template void invokeAddBiasResidualLayerNorm(half* out, - const half* input, - const half* bias, -@@ -556,7 +789,8 @@ template void invokeAddBiasResidualLayerNorm(half* out, - const half* beta, - int m, - int n, -- cudaStream_t stream); -+ cudaStream_t stream, -+ float eps); - - template - __global__ void generalAddBiasResidualLayerNorm(const T* __restrict input, -@@ -566,7 +800,8 @@ __global__ void generalAddBiasResidualLayerNorm(const T* __restrict input, - T* output, - T* norm_output, - int m, -- int n) -+ int n, -+ float eps) - { - int tid = threadIdx.x; - -@@ -601,7 +836,7 @@ __global__ void generalAddBiasResidualLayerNorm(const T* __restrict input, - variance = blockReduceSum(local_var_sum); - - if (threadIdx.x == 0) { -- s_variance = rsqrtf(variance / n + 1e-6f); -+ s_variance = rsqrtf(variance / n + eps); - } - __syncthreads(); - -@@ -612,6 +847,89 @@ __global__ void generalAddBiasResidualLayerNorm(const T* __restrict input, - } - } - -+template -+__global__ void generalAddBiasResidualLayerNormCast(const T* __restrict input, -+ const T* __restrict gamma, -+ const T* __restrict beta, -+ const T* __restrict bias, -+ T* output, -+ S* norm_output, -+ int m, -+ int n, -+ float eps) -+{ -+ int tid = threadIdx.x; -+ -+ __shared__ float s_mean; -+ __shared__ float s_variance; -+ float mean = 0.0f; -+ float variance = 0.0f; -+ -+ float local_sum = 0.0f; -+ for (int i = tid; i < n; i += blockDim.x) { -+ float local_out = (float)(ldg(&input[blockIdx.x * n + i])); -+ local_out += (float)((T)output[blockIdx.x * n + i]); -+ if (bias != nullptr) { -+ local_out += (float)(ldg(&bias[i])); -+ } -+ output[blockIdx.x * n + i] = (T)local_out; -+ local_sum += local_out; -+ } -+ -+ mean = blockReduceSum(local_sum); -+ -+ if (threadIdx.x == 0) { -+ s_mean = mean / n; -+ } -+ __syncthreads(); -+ -+ float local_var_sum = 0.0f; -+ for (int i = tid; i < n; i += blockDim.x) { -+ float diff = (float)(output[blockIdx.x * n + i]) - s_mean; -+ local_var_sum += diff * diff; -+ } -+ variance = blockReduceSum(local_var_sum); -+ -+ if (threadIdx.x == 0) { -+ s_variance = rsqrtf(variance / n + eps); -+ } -+ __syncthreads(); -+ -+ for (int i = tid; i < n; i += blockDim.x) { -+ float beta_val = (beta == nullptr) ? 0.0f : (float)(ldg(&beta[i])); -+ norm_output[blockIdx.x * n + i] = -+ (S)((((float)output[blockIdx.x * n + i] - s_mean) * s_variance) * (float)(ldg(&gamma[i])) + beta_val); -+ } -+} -+template -+__global__ void generalAddBiasResidualT5LayerNormCast(const T* __restrict input, -+ const T* __restrict gamma, -+ T* output, -+ S* norm_output, -+ int m, -+ int n, -+ float eps) -+{ -+ __shared__ float s_variance; -+ float variance = 0.0f; -+ float local_var_sum = 0.0f; -+ for (int i = threadIdx.x; i < n; i += blockDim.x) { -+ output[blockIdx.x * n + i] = -+ clamp_inf_for_half((float)__ldg(&input[blockIdx.x * n + i]) + (float)output[blockIdx.x * n + i]); -+ float diff = (float)(output[blockIdx.x * n + i]); -+ local_var_sum += diff * diff; -+ } -+ variance = blockReduceSum(local_var_sum); -+ -+ if (threadIdx.x == 0) { -+ s_variance = rsqrtf(variance / n + eps); -+ } -+ __syncthreads(); -+ for (int i = threadIdx.x; i < n; i += blockDim.x) { -+ norm_output[blockIdx.x * n + i] = -+ (S)(clamp_inf_for_half((((float)output[blockIdx.x * n + i]) * s_variance) * (float)(__ldg(&gamma[i])))); -+ } -+} - #define HALF_LAYERNORM_OPT(UNROLL_FACTOR) \ - generalAddBiasResidualLayerNormOpt \ - <<>>((T2*)norm_output, \ -@@ -621,7 +939,8 @@ __global__ void generalAddBiasResidualLayerNorm(const T* __restrict input, - (const T2*)gamma, \ - (const T2*)beta, \ - m, \ -- half_n); -+ half_n, \ -+ eps); - - #define HALF_LAYERNORM_OPT2(UNROLL_FACTOR) \ - generalAddBiasResidualLayerNormOpt2 \ -@@ -632,7 +951,8 @@ __global__ void generalAddBiasResidualLayerNorm(const T* __restrict input, - (const T2*)gamma, \ - (const T2*)beta, \ - m, \ -- half_n); -+ half_n, \ -+ eps); - - template - void invokeGeneralAddBiasResidualPreLayerNorm(T* output, -@@ -644,6 +964,7 @@ void invokeGeneralAddBiasResidualPreLayerNorm(T* output, - int m, - int n, - cudaStream_t stream, -+ float eps, - int opt_version) - { - if (opt_version > 0 && sizeof(T) == 2 && n % 2 == 0) { -@@ -709,8 +1030,65 @@ void invokeGeneralAddBiasResidualPreLayerNorm(T* output, - - /* should pay attention to the rsqrt precision*/ - generalAddBiasResidualLayerNorm -- <<>>(input, gamma, beta, bias, output, norm_output, m, n); // For gpt-3 -+ <<>>(input, gamma, beta, bias, output, norm_output, m, n, eps); // For gpt-3 -+ } -+} -+ -+template -+void invokeGeneralAddBiasResidualPreLayerNormCast(T* attn_output, -+ S* norm_output, -+ const T* from_tensor, -+ const T* gamma, -+ const T* beta, -+ const T* bias, -+ int m, -+ int n, -+ cudaStream_t stream, -+ float eps, -+ int opt_version) -+{ -+ dim3 grid(m); -+ dim3 block(min(n, 1024)); -+ -+ /* For general cases, n is equal to hidden_units, e.g., 512/1024. -+ Since we have warp shuffle inside the code, block.x % 32 should be 0. -+ */ -+ -+ if (n % 32 != 0) { -+ block.x = 1024; -+ } -+ -+ /* should pay attention to the rsqrt precision*/ -+ generalAddBiasResidualLayerNormCast -+ <<>>(from_tensor, gamma, beta, bias, attn_output, norm_output, m, n, eps); // For gpt-3 -+} -+ -+template -+void invokeGeneralAddBiasResidualT5PreLayerNormCast(T* attn_output, -+ S* norm_output, -+ const T* from_tensor, -+ const T* gamma, -+ int m, -+ int n, -+ cudaStream_t stream, -+ float eps, -+ int opt_version) -+{ -+ -+ dim3 grid(m); -+ dim3 block(min(n, 1024)); -+ -+ /* For general cases, n is equal to hidden_units, e.g., 512/1024. -+ Since we have warp shuffle inside the code, block.x % 32 should be 0. -+ */ -+ -+ if (n % 32 != 0) { -+ block.x = 1024; - } -+ -+ /* should pay attention to the rsqrt precision*/ -+ generalAddBiasResidualT5LayerNormCast -+ <<>>(from_tensor, gamma, attn_output, norm_output, m, n, eps); // For gpt-3 - } - - #undef HALF_LAYERNORM_OPT -@@ -725,6 +1103,7 @@ template void invokeGeneralAddBiasResidualPreLayerNorm(float* output, - int m, - int n, - cudaStream_t stream, -+ float eps, - int opt_version); - - template void invokeGeneralAddBiasResidualPreLayerNorm(half* output, -@@ -736,6 +1115,7 @@ template void invokeGeneralAddBiasResidualPreLayerNorm(half* output, - int m, - int n, - cudaStream_t stream, -+ float eps, - int opt_version); - - #ifdef ENABLE_BF16 -@@ -748,12 +1128,13 @@ template void invokeGeneralAddBiasResidualPreLayerNorm(__nv_bfloat16* output, - int m, - int n, - cudaStream_t stream, -+ float eps, - int opt_version); - #endif - - template - __global__ void generalAddResidualT5LayerNorm( -- const T* __restrict input, const T* __restrict gamma, T* output, T* norm_output, int m, int n) -+ const T* __restrict input, const T* __restrict gamma, T* output, T* norm_output, int m, int n, float eps) - { - // layernorm module in the T5 style No bias and no subtraction of mean. - __shared__ float s_variance; -@@ -770,7 +1151,7 @@ __global__ void generalAddResidualT5LayerNorm( - variance = blockReduceSum(local_var_sum); - - if (threadIdx.x == 0) { -- s_variance = rsqrtf(variance / n + 1e-6f); -+ s_variance = rsqrtf(variance / n + eps); - } - __syncthreads(); - -@@ -783,7 +1164,7 @@ __global__ void generalAddResidualT5LayerNorm( - - template - void invokeGeneralAddResidualT5PreLayerNorm( -- T* output, T* norm_output, const T* input, const T* gamma, int m, int n, cudaStream_t stream) -+ T* output, T* norm_output, const T* input, const T* gamma, int m, int n, cudaStream_t stream, float eps) - { - dim3 grid(m); - dim3 block(min(n, 1024)); -@@ -799,14 +1180,14 @@ void invokeGeneralAddResidualT5PreLayerNorm( - block.x = block.x / (4 / sizeof(T)); // if using half, only need half of block.x - - /* should pay attention to the rsqrt precision*/ -- generalAddResidualT5LayerNorm<<>>(input, gamma, output, norm_output, m, n); -+ generalAddResidualT5LayerNorm<<>>(input, gamma, output, norm_output, m, n, eps); - } - - template void invokeGeneralAddResidualT5PreLayerNorm( -- float* output, float* norm_output, const float* input, const float* gamma, int m, int n, cudaStream_t stream); -+ float* output, float* norm_output, const float* input, const float* gamma, int m, int n, cudaStream_t stream, float eps); - - template void invokeGeneralAddResidualT5PreLayerNorm( -- half* output, half* norm_output, const half* input, const half* gamma, int m, int n, cudaStream_t stream); -+ half* output, half* norm_output, const half* input, const half* gamma, int m, int n, cudaStream_t stream, float eps); - - template - void invokeGeneralAddBiasResidualT5PreLayerNorm(T* output, -@@ -817,17 +1198,39 @@ void invokeGeneralAddBiasResidualT5PreLayerNorm(T* output, - const T* bias, - int m, - int n, -- cudaStream_t stream) -+ cudaStream_t stream, -+ float eps) - { -- if (beta != nullptr && bias != nullptr) { -- invokeGeneralAddBiasResidualPreLayerNorm(output, norm_output, input, gamma, beta, bias, m, n, stream); -+ if (beta != nullptr || bias != nullptr) { -+ invokeGeneralAddBiasResidualPreLayerNorm(output, norm_output, input, gamma, beta, bias, m, n, stream, eps); - } - else { -- invokeGeneralAddResidualT5PreLayerNorm(output, norm_output, input, gamma, m, n, stream); -+ invokeGeneralAddResidualT5PreLayerNorm(output, norm_output, input, gamma, m, n, stream, eps); - } - return; - } - -+template -+void invokeGeneralAddBiasResidualT5PreLayerNormCast(T* attn_output, -+ S* norm_output, -+ const T* from_tensor, -+ const T* gamma, -+ const T* beta, -+ const T* bias, -+ int m, -+ int n, -+ cudaStream_t stream, -+ float eps, -+ int opt_version) -+{ -+ if (beta != nullptr || bias != nullptr) { -+ invokeGeneralAddBiasResidualPreLayerNormCast(attn_output, norm_output, from_tensor, gamma, beta, bias, m, n, stream, eps); -+ } -+ else { -+ invokeGeneralAddBiasResidualT5PreLayerNormCast(attn_output, norm_output, from_tensor, gamma, m, n, stream, eps, opt_version); -+ } -+ return; -+} - template void invokeGeneralAddBiasResidualT5PreLayerNorm(float* output, - float* norm_output, - const float* input, -@@ -836,7 +1239,8 @@ template void invokeGeneralAddBiasResidualT5PreLayerNorm(float* output, - const float* bias, - int m, - int n, -- cudaStream_t stream); -+ cudaStream_t stream, -+ float eps); - - template void invokeGeneralAddBiasResidualT5PreLayerNorm(half* output, - half* norm_output, -@@ -846,11 +1250,12 @@ template void invokeGeneralAddBiasResidualT5PreLayerNorm(half* output, - const half* bias, - int m, - int n, -- cudaStream_t stream); -+ cudaStream_t stream, -+ float eps); - - template - __global__ void generalLayerNorm( -- const T* __restrict input, const T* __restrict gamma, const T* __restrict beta, T* output, int m, int n) -+ const T* __restrict input, const T* __restrict gamma, const T* __restrict beta, T* output, int m, int n, float eps) - { - const int tid = threadIdx.x; - -@@ -879,10 +1284,9 @@ __global__ void generalLayerNorm( - variance = blockReduceSum(local_var_sum); - - if (threadIdx.x == 0) { -- s_variance = rsqrtf(variance / n + 1e-6f); -+ s_variance = rsqrtf(variance / n + eps); - } - __syncthreads(); -- - for (int i = tid; i < n; i += blockDim.x) { - float beta_val = (beta == nullptr) ? 0.0f : (float)ldg(&beta[i]); - output[blockIdx.x * n + i] = -@@ -892,11 +1296,11 @@ __global__ void generalLayerNorm( - - #define HALF_LAYERNORM_OPT(UNROLL_FACTOR) \ - generalAddBiasResidualLayerNormOpt<<>>( \ -- (T2*)out, (T2*)out, nullptr, (const T2*)input, (const T2*)gamma, (const T2*)beta, m, half_n); -+ (T2*)out, (T2*)out, nullptr, (const T2*)input, (const T2*)gamma, (const T2*)beta, m, half_n, eps); - - #define HALF_LAYERNORM_OPT2(UNROLL_FACTOR) \ - generalAddBiasResidualLayerNormOpt2<<>>( \ -- (T2*)out, (T2*)out, nullptr, (const T2*)input, (const T2*)gamma, (const T2*)beta, m, half_n); -+ (T2*)out, (T2*)out, nullptr, (const T2*)input, (const T2*)gamma, (const T2*)beta, m, half_n, eps); - - template - void invokeGeneralLayerNorm(T* out, -@@ -906,6 +1310,7 @@ void invokeGeneralLayerNorm(T* out, - const int m, - const int n, - cudaStream_t stream, -+ float eps, - int opt_version) - { - dim3 grid(m); -@@ -965,7 +1370,7 @@ void invokeGeneralLayerNorm(T* out, - } - - /* should pay attention to the rsqrt precision*/ -- generalLayerNorm<<>>(input, gamma, beta, out, m, n); // For gpt-3 -+ generalLayerNorm<<>>(input, gamma, beta, out, m, n, eps); // For gpt-3 - } - } - -@@ -979,6 +1384,7 @@ template void invokeGeneralLayerNorm(float* out, - const int m, - const int n, - cudaStream_t stream, -+ float eps, - int opt_version); - template void invokeGeneralLayerNorm(half* out, - const half* input, -@@ -987,6 +1393,7 @@ template void invokeGeneralLayerNorm(half* out, - const int m, - const int n, - cudaStream_t stream, -+ float eps, - int opt_version); - #ifdef ENABLE_BF16 - template void invokeGeneralLayerNorm(__nv_bfloat16* out, -@@ -996,11 +1403,12 @@ template void invokeGeneralLayerNorm(__nv_bfloat16* out, - const int m, - const int n, - cudaStream_t stream, -+ float eps, - int opt_version); - #endif - - template --__global__ void generalT5LayerNorm(const T* __restrict input, const T* __restrict gamma, T* output, int m, int n) -+__global__ void generalT5LayerNorm(const T* __restrict input, const T* __restrict gamma, T* output, int m, int n, float eps) - { - // layernorm module in the T5 style No bias and no subtraction of mean. - const int tid = threadIdx.x; -@@ -1016,7 +1424,7 @@ __global__ void generalT5LayerNorm(const T* __restrict input, const T* __restric - variance = blockReduceSum(local_var_sum); - - if (threadIdx.x == 0) { -- s_variance = rsqrtf(variance / n + 1e-6f); -+ s_variance = rsqrtf(variance / n + eps); - } - __syncthreads(); - -@@ -1028,10 +1436,10 @@ __global__ void generalT5LayerNorm(const T* __restrict input, const T* __restric - - template - void invokeGeneralT5LayerNorm( -- T* out, const T* input, const T* gamma, const T* beta, const int m, const int n, cudaStream_t stream) -+ T* out, const T* input, const T* gamma, const T* beta, const int m, const int n, cudaStream_t stream, float eps) - { - if (beta != nullptr) { -- invokeGeneralLayerNorm(out, input, gamma, beta, m, n, stream); -+ invokeGeneralLayerNorm(out, input, gamma, beta, m, n, stream, eps); - return; - } - -@@ -1048,7 +1456,7 @@ void invokeGeneralT5LayerNorm( - block.x = block.x / (4 / sizeof(T)); // if using half, only need half of block.x - - /* should pay attention to the rsqrt precision*/ -- generalT5LayerNorm<<>>(input, gamma, out, m, n); // For gpt-3 -+ generalT5LayerNorm<<>>(input, gamma, out, m, n, eps); // For gpt-3 - } - - template void invokeGeneralT5LayerNorm(float* out, -@@ -1057,9 +1465,10 @@ template void invokeGeneralT5LayerNorm(float* out, - const float* beta, - const int m, - const int n, -- cudaStream_t stream); -+ cudaStream_t stream, -+ float eps); - template void invokeGeneralT5LayerNorm( -- half* out, const half* input, const half* gamma, const half* beta, const int m, const int n, cudaStream_t stream); -+ half* out, const half* input, const half* gamma, const half* beta, const int m, const int n, cudaStream_t stream, float eps); - - /******************* invokeLayernormShiftPartition ***********************/ - -@@ -1073,7 +1482,8 @@ __global__ void layernorm_shift_partition(T* out, - int W, - int n, - int shift_size, -- int window_size) -+ int window_size, -+ float eps) - { - int tid = threadIdx.x; - const int batch_offset = blockIdx.z * gridDim.y * gridDim.x; -@@ -1102,7 +1512,7 @@ __global__ void layernorm_shift_partition(T* out, - float diff = (tid < n) ? (local_out - s_mean) : 0.0f; - variance = blockReduceSum(diff * diff); - if (threadIdx.x == 0) { -- s_variance = variance / n + 1e-6f; -+ s_variance = variance / n + eps; - } - __syncthreads(); - -@@ -1122,7 +1532,8 @@ __global__ void layernorm_shift_partition(half2* out_ptr, - int W, - int n, - int shift_size, -- int window_size) -+ int window_size, -+ float eps) - { - const int batch_offset = blockIdx.z * gridDim.y * gridDim.x; - const int bid = batch_offset + blockIdx.y * gridDim.x + blockIdx.x; -@@ -1161,7 +1572,7 @@ __global__ void layernorm_shift_partition(half2* out_ptr, - } - variance = blockReduceSum(variance); - if (threadIdx.x == 0) { -- s_variance = rsqrtf(variance / (n * 2) + 1e-6f); -+ s_variance = rsqrtf(variance / (n * 2) + eps); - } - __syncthreads(); - -@@ -1184,7 +1595,8 @@ __global__ void layernorm_shift_partition_v2(T* out, - int W, - int n, - int shift_size, -- int window_size) -+ int window_size, -+ float eps) - { - const int ite = 4; - const int tid = threadIdx.x; -@@ -1236,7 +1648,7 @@ __global__ void layernorm_shift_partition_v2(T* out, - - variance = blockReduceSum(var); - if (tid == 0) { -- s_variance = rsqrtf(variance / n + 1e-6f); -+ s_variance = rsqrtf(variance / n + eps); - } - __syncthreads(); - -@@ -1260,7 +1672,8 @@ __global__ void layernorm_shift_partition_v2(half2* out_ptr, - int W, - int n, - int shift_size, -- int window_size) -+ int window_size, -+ float eps) - { - const int ite = 4; - const int tid = threadIdx.x; -@@ -1315,7 +1728,7 @@ __global__ void layernorm_shift_partition_v2(half2* out_ptr, - - variance = blockReduceSum(var); - if (threadIdx.x == 0) { -- s_variance = rsqrtf(variance / (n * 2) + 1e-6f); -+ s_variance = rsqrtf(variance / (n * 2) + eps); - } - __syncthreads(); - -@@ -1341,18 +1754,19 @@ void invokeLayernormShiftPartition(T* out, - int n, - int shift_size, - int window_size, -- cudaStream_t stream) -+ cudaStream_t stream, -+ float eps) - { - dim3 grid(W, H, batch); - int blockSize = (n + 31) / 32 * 32; - if (blockSize >= 768) { - blockSize = ((blockSize / 4) + 31) / 32 * 32; - layernorm_shift_partition_v2 -- <<>>(out, input, gamma, beta, batch, H, W, n, shift_size, window_size); -+ <<>>(out, input, gamma, beta, batch, H, W, n, shift_size, window_size, eps); - } - else { - layernorm_shift_partition -- <<>>(out, input, gamma, beta, batch, H, W, n, shift_size, window_size); -+ <<>>(out, input, gamma, beta, batch, H, W, n, shift_size, window_size, eps); - } - } - -@@ -1367,7 +1781,8 @@ void invokeLayernormShiftPartition(half* out, - int n, - int shift_size, - int window_size, -- cudaStream_t stream) -+ cudaStream_t stream, -+ float eps) - { - dim3 grid(W, H, batch); - int blockSize = n / 2; -@@ -1384,7 +1799,8 @@ void invokeLayernormShiftPartition(half* out, - W, - n / 2, - shift_size, -- window_size); -+ window_size, -+ eps); - } - else { - layernorm_shift_partition<<>>((half2*)out, -@@ -1396,7 +1812,8 @@ void invokeLayernormShiftPartition(half* out, - W, - n / 2, - shift_size, -- window_size); -+ window_size, -+ eps); - } - } - -@@ -1410,7 +1827,8 @@ template void invokeLayernormShiftPartition(float* out, - int n, - int shift_size, - int window_size, -- cudaStream_t stream); -+ cudaStream_t stream, -+ float eps); - - template void invokeLayernormShiftPartition(half* out, - const half* input, -@@ -1422,12 +1840,13 @@ template void invokeLayernormShiftPartition(half* out, - int n, - int shift_size, - int window_size, -- cudaStream_t stream); -+ cudaStream_t stream, -+ float eps); - - /******************* invokeAddBiasLayernorm ***********************/ - - template --__global__ void add_bias_layernorm(T* out, const T* bias, const T* gamma, const T* beta, int n) -+__global__ void add_bias_layernorm(T* out, const T* bias, const T* gamma, const T* beta, int n, float eps) - { - int tid = threadIdx.x; - const int bid = blockIdx.x; -@@ -1447,7 +1866,7 @@ __global__ void add_bias_layernorm(T* out, const T* bias, const T* gamma, const - float diff = (tid < n) ? (local_out - s_mean) : 0.0f; - variance = blockReduceSum(diff * diff); - if (threadIdx.x == 0) { -- s_variance = variance / n + 1e-6f; -+ s_variance = variance / n + eps; - } - __syncthreads(); - -@@ -1459,7 +1878,7 @@ __global__ void add_bias_layernorm(T* out, const T* bias, const T* gamma, const - - template - __global__ void --add_bias_layernorm_v2(T* out, const T* __restrict bias, const T* __restrict gamma, const T* __restrict beta, int n) -+add_bias_layernorm_v2(T* out, const T* __restrict bias, const T* __restrict gamma, const T* __restrict beta, int n, float eps) - { - const int ite = 4; - const int tid = threadIdx.x; -@@ -1496,7 +1915,7 @@ add_bias_layernorm_v2(T* out, const T* __restrict bias, const T* __restrict gamm - - variance = blockReduceSum(var); - if (tid == 0) { -- s_variance = rsqrtf(variance / n + 1e-6f); -+ s_variance = rsqrtf(variance / n + eps); - } - __syncthreads(); - -@@ -1512,15 +1931,15 @@ add_bias_layernorm_v2(T* out, const T* __restrict bias, const T* __restrict gamm - - #define HALF_LAYERNORM_OPT(UNROLL_FACTOR) \ - generalAddBiasResidualLayerNormOpt<<>>( \ -- (T2*)out, (T2*)out, (const T2*)bias, (const T2*)out, (const T2*)gamma, (const T2*)beta, m, half_n); -+ (T2*)out, (T2*)out, (const T2*)bias, (const T2*)out, (const T2*)gamma, (const T2*)beta, m, half_n, eps); - - #define HALF_LAYERNORM_OPT2(UNROLL_FACTOR) \ - generalAddBiasResidualLayerNormOpt2<<>>( \ -- (T2*)out, (T2*)out, (const T2*)bias, (const T2*)out, (const T2*)gamma, (const T2*)beta, m, half_n); -+ (T2*)out, (T2*)out, (const T2*)bias, (const T2*)out, (const T2*)gamma, (const T2*)beta, m, half_n, eps); - - template - void invokeAddBiasLayernorm( -- T* out, const T* bias, const T* gamma, const T* beta, int m, int n, cudaStream_t stream, int opt_version) -+ T* out, const T* bias, const T* gamma, const T* beta, int m, int n, cudaStream_t stream, float eps, int opt_version) - { - dim3 grid(m); - if (n % 2 == 0 && std::is_same::value && opt_version > 0) { -@@ -1572,10 +1991,10 @@ void invokeAddBiasLayernorm( - int blockSize = (n + 31) / 32 * 32; - if (blockSize >= 768) { - blockSize = ((blockSize / 4) + 31) / 32 * 32; -- add_bias_layernorm_v2<<>>(out, bias, gamma, beta, n); -+ add_bias_layernorm_v2<<>>(out, bias, gamma, beta, n, eps); - } - else { -- add_bias_layernorm<<>>(out, bias, gamma, beta, n); -+ add_bias_layernorm<<>>(out, bias, gamma, beta, n, eps); - } - } - } -@@ -1590,6 +2009,7 @@ template void invokeAddBiasLayernorm(float* out, - int m, - int n, - cudaStream_t stream, -+ float eps, - int opt_version); - - template void invokeAddBiasLayernorm(half* out, -@@ -1599,6 +2019,7 @@ template void invokeAddBiasLayernorm(half* out, - int m, - int n, - cudaStream_t stream, -+ float eps, - int opt_version); - #ifdef ENABLE_BF16 - template void invokeAddBiasLayernorm<__nv_bfloat16>(__nv_bfloat16* out, -@@ -1608,6 +2029,7 @@ template void invokeAddBiasLayernorm<__nv_bfloat16>(__nv_bfloat16* out, - int m, - int n, - cudaStream_t stream, -+ float eps, - int opt_version); - #endif - -@@ -1625,7 +2047,8 @@ __global__ void merge_layernorm_v2(T* out, - int batch, - int H, - int W, -- int n) -+ int n, -+ float eps) - { - const int ite = 4; - const int tid = threadIdx.x; -@@ -1675,7 +2098,7 @@ __global__ void merge_layernorm_v2(T* out, - - variance = blockReduceSum(var); - if (tid == 0) { -- s_variance = rsqrtf(variance / n + 1e-6f); -+ s_variance = rsqrtf(variance / n + eps); - } - __syncthreads(); - -@@ -1693,7 +2116,7 @@ __global__ void merge_layernorm_v2(T* out, - // TODO : accelerate with half2 - template - void invokeMergeLayernorm( -- T* output, const T* input, const T* gamma, const T* beta, int batch, int H, int W, int n, cudaStream_t stream) -+ T* output, const T* input, const T* gamma, const T* beta, int batch, int H, int W, int n, cudaStream_t stream, float eps) - { - if ((W % 2 != 0) || (H % 2 != 0)) { - printf("[ERROR][invokeMergeLayernorm] H(W) should be a multiple of 2.\n"); -@@ -1706,7 +2129,7 @@ void invokeMergeLayernorm( - // if (blockSize >= 768) - { - blockSize = ((blockSize / 4) + 31) / 32 * 32; -- merge_layernorm_v2<<>>(output, input, gamma, beta, batch, H / 2, W / 2, n * 4); -+ merge_layernorm_v2<<>>(output, input, gamma, beta, batch, H / 2, W / 2, n * 4, eps); - } - /* - else -@@ -1722,7 +2145,8 @@ template void invokeMergeLayernorm(float* output, - int H, - int W, - int n, -- cudaStream_t stream); -+ cudaStream_t stream, -+ float eps); - - template void invokeMergeLayernorm(half* output, - const half* input, -@@ -1732,6 +2156,45 @@ template void invokeMergeLayernorm(half* output, - int H, - int W, - int n, -- cudaStream_t stream); -+ cudaStream_t stream, -+ float eps); -+ -+ -+ -+ -+ -+ -+__global__ void ToFloat(half* src, float* dst, int element_cnt) { -+ for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < element_cnt; pos += blockDim.x * gridDim.x) { -+ dst[pos] = (float)(src[pos]); -+ } -+} -+ -+__global__ void ToHalf(float* src, half* dst, int element_cnt) { -+ for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < element_cnt; pos += blockDim.x * gridDim.x) { -+ dst[pos] = (half)(src[pos]); -+ } -+} -+ -+__global__ void ToFlaotFromFloat(float* src, float* dst, int element_cnt) { -+ for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < element_cnt; pos += blockDim.x * gridDim.x) { -+ dst[pos] = (src[pos]); -+ -+ } -+} -+ -+void InvokeCast(void* src, void* dst, int element_cnt, int dir, cudaStream_t stream) { -+ dim3 block, grid; -+ -+ block.x = 1024; -+ grid.x = ceil(element_cnt / 1024.); -+ if (dir) { -+ ToFloat<<>>((half*)src, (float*)dst, element_cnt); -+ } else { -+ ToHalf<<>>((float*)src, (half*)dst, element_cnt); -+ } -+ return; -+} -+ - - } // namespace fastertransformer -\ No newline at end of file -diff --git a/src/fastertransformer/kernels/layernorm_kernels.h b/src/fastertransformer/kernels/layernorm_kernels.h -index e8319de..22e8b94 100644 ---- a/src/fastertransformer/kernels/layernorm_kernels.h -+++ b/src/fastertransformer/kernels/layernorm_kernels.h -@@ -42,7 +42,19 @@ void invokeAddBiasResidualLayerNorm(T* out, - const T* beta, - const int m, - const int n, -- cudaStream_t stream); -+ cudaStream_t stream, -+ float eps = 1e-6f); -+ -+template -+void invokeAddBiasResidualT5LayerNorm(T* out, -+ const T* input, -+ const T* bias, -+ const T* gamma, -+ const T* beta, -+ const int m, -+ const int n, -+ cudaStream_t stream, -+ float eps = 1e-6f); - - template - void invokeGeneralAddBiasResidualPreLayerNorm(T* output, -@@ -54,6 +66,7 @@ void invokeGeneralAddBiasResidualPreLayerNorm(T* output, - int m, - int n, - cudaStream_t stream, -+ float eps = 1e-6f, - int opt_version = 2); - - template -@@ -64,15 +77,16 @@ void invokeGeneralLayerNorm(T* out, - const int m, - const int n, - cudaStream_t stream, -+ float eps = 1e-6f, - int opt_version = 2); - - template - void invokeGeneralT5LayerNorm( -- T* out, const T* input, const T* gamma, const T* beta, const int m, const int n, cudaStream_t stream); -+ T* out, const T* input, const T* gamma, const T* beta, const int m, const int n, cudaStream_t stream, float eps = 1e-6f); - - template - void invokeGeneralAddResidualT5PreLayerNorm( -- T* output, T* norm_output, const T* input, const T* gamma, int m, int n, cudaStream_t stream); -+ T* output, T* norm_output, const T* input, const T* gamma, int m, int n, cudaStream_t stream, float eps = 1e-6f); - - template - void invokeGeneralAddBiasResidualT5PreLayerNorm(T* output, -@@ -83,7 +97,8 @@ void invokeGeneralAddBiasResidualT5PreLayerNorm(T* output, - const T* bias, - int m, - int n, -- cudaStream_t stream); -+ cudaStream_t stream, -+ float eps = 1e-6f); - - template - void invokeLayernormShiftPartition(T* out, -@@ -96,14 +111,49 @@ void invokeLayernormShiftPartition(T* out, - int n, - int shift_size, - int window_size, -- cudaStream_t stream); -+ cudaStream_t stream, -+ float eps = 1e-6f); -+ -+template -+void invokeGeneralAddBiasResidualPreLayerNormCast(T* attn_output, -+ S* norm_output, -+ const T* from_tensor, -+ const T* gamma, -+ const T* beta, -+ const T* bias, -+ int m, -+ int n, -+ cudaStream_t stream, -+ float eps = 1e-6f, -+ int opt_version = 2); -+ -+template -+void invokeGeneralAddBiasResidualT5PreLayerNormCast(T* attn_output, -+ S* norm_output, -+ const T* from_tensor, -+ const T* gamma, -+ const T* beta, -+ const T* bias, -+ int m, -+ int n, -+ cudaStream_t stream, -+ float eps = 1e-6f, -+ int opt_version = 2); -+ -+template -+void invokeAddBiasResidualLayerNormCast( -+ S* attn_output, D* norm_output, const S* input, const T* bias, const T* gamma, const T* beta, int m, int n, cudaStream_t stream, float eps = 1e-6f); -+template -+void invokeAddBiasResidualT5LayerNormCast( -+ S* attn_output, D* norm_output, const S* input, const T* bias, const T* gamma, const T* beta, int m, int n, cudaStream_t stream, float eps = 1e-6f); - - template - void invokeAddBiasLayernorm( -- T* out, const T* bias, const T* gamma, const T* beta, int m, int n, cudaStream_t stream, int opt_version = 2); -+ T* out, const T* bias, const T* gamma, const T* beta, int m, int n, cudaStream_t stream, float eps = 1e-6f, int opt_version = 2); - - template - void invokeMergeLayernorm( -- T* output, const T* input, const T* gamma, const T* beta, int batch, int H, int W, int n, cudaStream_t stream); -+ T* output, const T* input, const T* gamma, const T* beta, int batch, int H, int W, int n, cudaStream_t stream, float eps = 1e-6f); - -+void InvokeCast(void* src, void* dst, int element_cnt, int dir, cudaStream_t stream); - } // namespace fastertransformer -diff --git a/src/fastertransformer/kernels/unfused_attention_kernels.cu b/src/fastertransformer/kernels/unfused_attention_kernels.cu -index f951e71..f404f45 100644 ---- a/src/fastertransformer/kernels/unfused_attention_kernels.cu -+++ b/src/fastertransformer/kernels/unfused_attention_kernels.cu -@@ -15,6 +15,14 @@ - * limitations under the License. - */ - -+#ifndef CUDART_VERSION -+#error CUDART_VERSION Undefined! -+#elif (CUDART_VERSION >= 11050) -+#include -+#else -+#include "3rdparty/cub/cub.cuh" -+#endif -+ - #include "src/fastertransformer/kernels/bfloat16_fallback_kenrels.cuh" - #include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" - #include "src/fastertransformer/kernels/reduce_kernel_utils.cuh" -@@ -23,6 +31,24 @@ - - namespace fastertransformer { - -+const int WARP_SIZE = 32; -+const bool ATTENION_OPT = true; -+const int ATTENTION_BLOCK_SIZE = 256; -+ -+/////////////////////////////////////////////////////////////////////////////////////////////////// -+ -+template -+using Copy_half_t = typename std::conditional< -+ HALF_ELEMENTS_PER_WARP_LOAD == 32, -+ half, -+ typename std::conditional::type>::type>:: -+ type; -+ -+template -+using Copy_t = Copy_half_t; -+ - __inline__ __device__ int target_index(int id1, int id2, int id3, int id4, int dim_1, int dim_2, int dim_3, int dim_4) - { - return id1 * (dim_2 * dim_3 * dim_4) + id3 * (dim_2 * dim_4) + id2 * dim_4 + id4; -@@ -243,6 +269,77 @@ __global__ void softmax_kernel_v4(T* qk_buf_, - } - } - -+template -+__global__ void softmax_mix_kernel_bias_v4(T* qk_buf_, -+ const T_M* attr_mask, -+ const T* position_bias, -+ const int* d_seq_len, -+ const int* d_seq_len2, -+ const int batch_size, -+ const int head_num, -+ const int seq_len, -+ const int seq_stride, -+ const int trgt_seq_len, -+ const int trgt_seq_stride, -+ const int position_bias_head_num, -+ const T scalar, -+ const bool causal) -+{ -+ T* qk_buf_src = qk_buf_; -+ for (int seq_id = blockIdx.x; seq_id < seq_len; seq_id += gridDim.x) { -+ float data[ITEMS_PER_THREAD]; -+ int qk_offset; -+ __shared__ float s_mean, s_max; -+ float local_max = -1e20f; -+ for (int i = 0; blockDim.x * i + threadIdx.x < trgt_seq_len; i++) { -+ qk_offset = -+ ((blockIdx.y * head_num + blockIdx.z) * seq_len + seq_id) * trgt_seq_len + blockDim.x * i + threadIdx.x; -+ int pos_offset = -+ ((blockIdx.z) * seq_stride + seq_id) * trgt_seq_stride + blockDim.x * i + threadIdx.x; -+ int mask_offset = (blockIdx.y * seq_stride + seq_id) * trgt_seq_stride + blockDim.x * i + threadIdx.x; -+ -+ int pos_offset2 = (seq_id) * trgt_seq_stride + blockDim.x * i + threadIdx.x; -+ int bias_offset = (position_bias_head_num == 1) ? pos_offset2 : pos_offset; -+ float qk = static_cast(qk_buf_src[qk_offset]); -+ float mask_val = (attr_mask != nullptr) ? static_cast(ldg(&attr_mask[mask_offset])) : 1.0f; -+ if (causal) { -+ mask_val = (blockDim.x * i + threadIdx.x <= seq_id && seq_id < d_seq_len[blockIdx.y]) ? 1.0f : 0.0f; -+ } else if (d_seq_len != nullptr) { -+ mask_val = (seq_id < d_seq_len[blockIdx.y] && blockDim.x * i + threadIdx.x < d_seq_len2[blockIdx.y]) ? mask_val : 0.0f; -+ } -+ float bias_val = (position_bias == nullptr) ? 0.0f : static_cast(ldg(&position_bias[bias_offset])); -+ mask_val = (1.0f - mask_val) * -10000.0f; -+ -+ data[i] = qk * static_cast(scalar) + mask_val + bias_val; -+ local_max = fmax(local_max, data[i]); -+ } -+ -+ float max_val = blockDim.x <= 32 ? warpReduceMax(local_max) : blockReduceMax(local_max); -+ if (threadIdx.x == 0) { -+ s_max = max_val; -+ } -+ __syncthreads(); -+ -+ float local_sum = 0; -+ for (int i = 0; blockDim.x * i + threadIdx.x < trgt_seq_len; i++) { -+ data[i] = __expf(data[i] - s_max); -+ local_sum += data[i]; -+ } -+ float sum_val = blockDim.x <= 32 ? warpReduceSum(local_sum) : blockReduceSum(local_sum); -+ if (threadIdx.x == 0) { -+ s_mean = sum_val + 1e-6f; -+ s_mean = __fdividef(1.0f, s_mean); -+ } -+ __syncthreads(); -+ -+ for (int i = 0; blockDim.x * i + threadIdx.x < trgt_seq_len; i++) { -+ qk_offset = -+ ((blockIdx.y * head_num + blockIdx.z) * seq_len + seq_id) * trgt_seq_len + blockDim.x * i + threadIdx.x; -+ qk_buf_[qk_offset] = (T)(data[i] * s_mean); -+ } -+ } -+} -+ - template - __global__ void softmax_kernel_v4_half2( - T* qk_buf_, const T* attr_mask, const int batch_size, const int head_num, const int seq_len, const T scalar) -@@ -298,6 +395,89 @@ __global__ void softmax_kernel_v4_half2( - } - } - -+template -+__global__ void softmax_cross_kernel_bias_v4_half2(T* qk_buf_, -+ const T* attr_mask, -+ const T* position_bias, -+ const int* d_seq_len, -+ const int* d_seq_len2, -+ const int batch_size, -+ const int head_num, -+ const int seq_len, -+ const int seq_stride, -+ const int trgt_seq_len, -+ const int trgt_seq_stride, -+ const int position_bias_head_num, -+ const T scalar, -+ const bool causal) -+{ -+ using T2 = typename TypeConverter::Type; -+ T2* qk_buf_half2 = (T2*)qk_buf_; -+ const T2* attr_mask_half2 = (const T2*)attr_mask; -+ const T2* position_bias_half2 = (position_bias == nullptr) ? nullptr : (const T2*)position_bias; -+ const T2 zero = {0, 0}; -+ const T2 one = {1, 1}; -+ for (int seq_id = blockIdx.x; seq_id < seq_len; seq_id += gridDim.x) { -+ T2 data[ITEMS_PER_THREAD]; -+ int qk_offset; -+ int pos_offset; -+ int pos_offset2; -+ __shared__ float s_mean, s_max; -+ float local_max = -1e20f; -+ for (int i = 0; blockDim.x * i + threadIdx.x < (trgt_seq_len / 2) && i < ITEMS_PER_THREAD; i++) { -+ qk_offset = ((blockIdx.y * head_num + blockIdx.z) * seq_len + seq_id) * (trgt_seq_len / 2) + blockDim.x * i -+ + threadIdx.x; -+ pos_offset = ((blockIdx.z) * seq_stride + seq_id) * (trgt_seq_stride / 2) + blockDim.x * i -+ + threadIdx.x; -+ pos_offset2 = (seq_id) * (trgt_seq_stride / 2) + blockDim.x * i + threadIdx.x; -+ int mask_offset = (blockIdx.y * seq_stride + seq_id) * (trgt_seq_stride / 2) + blockDim.x * i + threadIdx.x; -+ int bias_offset = (position_bias_head_num == 1) ? pos_offset2 : pos_offset; -+ -+ T2 qk = qk_buf_half2[qk_offset]; -+ T2 mask_val = (attr_mask_half2!= nullptr) ? ldg(&attr_mask_half2[mask_offset]) : one; -+ if (causal) { -+ mask_val.x = ((mask_offset % (trgt_seq_stride / 2)) * 2 <= seq_id && seq_id < d_seq_len[blockIdx.y]) ? 1 : 0; -+ mask_val.y = (((mask_offset % (trgt_seq_stride / 2)) * 2 + 1) <= seq_id && seq_id < d_seq_len[blockIdx.y]) ? 1 : 0; -+ } else if (d_seq_len != nullptr) { -+ mask_val.x = (seq_id < d_seq_len[blockIdx.y] && (mask_offset % (trgt_seq_stride / 2)) * 2 < d_seq_len2[blockIdx.y]) ? mask_val.x : (T)0; -+ mask_val.y = (seq_id < d_seq_len[blockIdx.y] && ((mask_offset % (trgt_seq_stride / 2)) * 2 + 1) < d_seq_len2[blockIdx.y]) ? mask_val.y : (T)0; -+ } -+ mask_val = hmul2(hsub2(float2type2(1.0f), mask_val), float2type2(-10000.0f)); -+ T2 bias_val = (position_bias_half2 == nullptr) ? zero : (ldg(&position_bias_half2[bias_offset])); -+ -+ data[i] = hadd2(hadd2(hmul2(qk, type2type2(scalar)), mask_val), bias_val); -+ -+ local_max = fmax(local_max, fmax((float)data[i].x, (float)data[i].y)); -+ } -+ -+ float max_val = blockDim.x <= 32 ? warpReduceMax(local_max) : blockReduceMax(local_max); -+ if (threadIdx.x == 0) { -+ s_max = max_val; -+ } -+ __syncthreads(); -+ -+ float local_sum = 0; -+ for (int i = 0; blockDim.x * i + threadIdx.x < (trgt_seq_len / 2) && i < ITEMS_PER_THREAD; i++) { -+ data[i] = hexp2(hsub2(data[i], float2type2(s_max))); -+ local_sum += (float)(data[i].x + data[i].y); -+ } -+ -+ float sum_val = blockDim.x <= 32 ? warpReduceSum(local_sum) : blockReduceSum(local_sum); -+ -+ if (threadIdx.x == 0) { -+ s_mean = sum_val + 1e-6f; -+ s_mean = __fdividef(1.0f, s_mean); -+ } -+ __syncthreads(); -+ -+ for (int i = 0; blockDim.x * i + threadIdx.x < (trgt_seq_len / 2) && i < ITEMS_PER_THREAD; i++) { -+ qk_offset = ((blockIdx.y * head_num + blockIdx.z) * seq_len + seq_id) * (trgt_seq_len / 2) + blockDim.x * i -+ + threadIdx.x; -+ qk_buf_half2[qk_offset] = hmul2(data[i], float2type2(s_mean)); -+ } -+ } -+} -+ - template - __global__ void softmax_kernel_v5_half2( - T* qk_buf_, const T* attr_mask, const int batch_size, const int head_num, const int seq_len, const T scalar) -@@ -415,6 +595,162 @@ __global__ void softmax_kernel_v5_half2( - } - } - -+template -+__global__ void softmax_cross_kernel_bias_v5_half2(T* qk_buf_, -+ const T* attr_mask, -+ const T* position_bias, -+ const int* d_seq_len, -+ const int* d_seq_len2, -+ const int batch_size, -+ const int head_num, -+ const int seq_len, -+ const int seq_stride, -+ const int trgt_seq_len, -+ const int trgt_seq_stride, -+ const int position_bias_head_num, -+ const T scalar, -+ const bool causal) -+{ -+ using T2 = typename TypeConverter::Type; -+ T2* qk_buf_half2 = (T2*)qk_buf_; -+ const T2* attr_mask_half2 = (const T2*)attr_mask; -+ const T2* position_bias_half2 = (position_bias == nullptr) ? nullptr : (const T2*)position_bias; -+ const T2 zero = {0, 0}; -+ const T2 one = {1, 1}; -+ for (int seq_id = blockIdx.x; seq_id < seq_len; seq_id += gridDim.x * NUM) { -+ T2 data[NUM][ITEMS_PER_THREAD]; -+ -+ int qk_offset[NUM]; -+ int pos_offset[NUM]; -+ int pos_offset2[NUM]; -+ int pos_bias_offset[NUM]; -+ -+ __shared__ float s_sum[NUM], s_max[NUM]; -+ float local_max[NUM]; -+#pragma unroll -+ for (int j = 0; j < NUM; j++) { -+ local_max[j] = -1e20f; -+ } -+ -+ for (int i = 0; blockDim.x * i + threadIdx.x < (trgt_seq_len / 2) && i < ITEMS_PER_THREAD; i++) { -+ int mask_offset[NUM]; -+#pragma unroll -+ for (int j = 0; j < NUM; j++) { -+ qk_offset[j] = -+ ((blockIdx.y * head_num + blockIdx.z) * seq_len + seq_id + j * gridDim.x) * (trgt_seq_len / 2) -+ + blockDim.x * i + threadIdx.x; -+ pos_offset[j] = -+ ((blockIdx.z) * seq_stride + seq_id + j * gridDim.x) * (trgt_seq_stride / 2) -+ + blockDim.x * i + threadIdx.x; -+ pos_offset2[j] = (seq_id + j * gridDim.x) * (trgt_seq_stride / 2) + blockDim.x * i + threadIdx.x; -+ mask_offset[j] = (blockIdx.y * seq_stride + seq_id + j * gridDim.x) * (trgt_seq_stride / 2) + blockDim.x * i + threadIdx.x; -+ -+ pos_bias_offset[j] = (position_bias_head_num == 1) ? pos_offset2[j] : pos_offset[j]; -+ } -+ -+ T2 mask_val[NUM]; -+#pragma unroll -+ for (int j = 0; j < NUM; j++) { -+ mask_val[j] = (attr_mask_half2 != 0) ? ldg(&attr_mask_half2[mask_offset[j]]) : one; -+ if (causal) { -+ mask_val[j].x = ((mask_offset[j] % (trgt_seq_stride / 2)) * 2 <= seq_id && seq_id < d_seq_len[blockIdx.y]) ? 1 : 0; -+ mask_val[j].y = (((mask_offset[j] % (trgt_seq_stride / 2)) * 2 + 1) <= seq_id && seq_id < d_seq_len[blockIdx.y]) ? 1 : 0; -+ } else if (d_seq_len != nullptr) { -+ mask_val[j].x = (seq_id < d_seq_len[blockIdx.y] && (mask_offset[j] % (trgt_seq_stride / 2)) * 2 < d_seq_len2[blockIdx.y]) ? mask_val[j].x : (T)0; -+ mask_val[j].y = (seq_id < d_seq_len[blockIdx.y] && ((mask_offset[j] % (trgt_seq_stride / 2)) * 2 + 1) < d_seq_len2[blockIdx.y]) ? mask_val[j].y : (T)0; -+ } -+ } -+ -+ T2 qk[NUM]; -+#pragma unroll -+ for (int j = 0; j < NUM; j++) { -+ qk[j] = qk_buf_half2[qk_offset[j]]; -+ } -+ -+ T2 pos_bias_val[NUM]; -+#pragma unroll -+ for (int j = 0; j < NUM; j++) { -+ pos_bias_val[j] = -+ (position_bias_half2 == nullptr) ? zero : ldg(&position_bias_half2[pos_bias_offset[j]]); -+ } -+ -+#pragma unroll -+ for (int j = 0; j < NUM; j++) { -+ mask_val[j] = hmul2(hsub2(float2type2(1.0f), mask_val[j]), float2type2(-10000.0f)); -+ } -+ -+#pragma unroll -+ for (int j = 0; j < NUM; j++) { -+ data[j][i] = -+ hadd2(hadd2(hmul2(qk[j], type2type2(scalar)), mask_val[j]), pos_bias_val[j]); -+ local_max[j] = fmax(local_max[j], fmax((float)data[j][i].x, (float)data[j][i].y)); -+ } -+ } -+ -+ if (blockDim.x <= 32) { -+ warpReduceMaxV2(local_max); -+ } -+ else { -+ blockReduceMaxV2(local_max); -+ } -+ -+ if (threadIdx.x == 0) { -+#pragma unroll -+ for (int j = 0; j < NUM; j++) { -+ s_max[j] = local_max[j]; -+ } -+ } -+ __syncthreads(); -+ -+ float local_sum[NUM]; -+#pragma unroll -+ for (int j = 0; j < NUM; j++) { -+ local_sum[j] = {0.f}; -+ } -+ -+ for (int i = 0; blockDim.x * i + threadIdx.x < (trgt_seq_len / 2) && i < ITEMS_PER_THREAD; i++) { -+#pragma unroll -+ for (int j = 0; j < NUM; j++) { -+ data[j][i] = hexp2(hsub2(data[j][i], float2type2(s_max[j]))); -+ } -+ -+#pragma unroll -+ for (int j = 0; j < NUM; j++) { -+ local_sum[j] += (float)(data[j][i].x + data[j][i].y); -+ } -+ } -+ -+ if (blockDim.x <= 32) { -+ warpReduceSumV2(local_sum); -+ } -+ else { -+ blockReduceSumV2(local_sum); -+ } -+ -+ if (threadIdx.x == 0) { -+#pragma unroll -+ for (int j = 0; j < NUM; j++) { -+ s_sum[j] = __fdividef(1.0f, local_sum[j] + 1e-6f); -+ } -+ } -+ __syncthreads(); -+ -+ for (int i = 0; blockDim.x * i + threadIdx.x < (trgt_seq_len / 2) && i < ITEMS_PER_THREAD; i++) { -+#pragma unroll -+ for (int j = 0; j < NUM; j++) { -+ qk_offset[j] = -+ ((blockIdx.y * head_num + blockIdx.z) * seq_len + seq_id + j * gridDim.x) * (trgt_seq_len / 2) -+ + blockDim.x * i + threadIdx.x; -+ } -+ -+#pragma unroll -+ for (int j = 0; j < NUM; j++) { -+ qk_buf_half2[qk_offset[j]] = hmul2(data[j][i], float2type2(s_sum[j])); -+ } -+ } -+ } -+} -+ - #define SOFTMAX_KERNEL(ITEMS_PER_THREAD) \ - block.x /= ITEMS_PER_THREAD; \ - assert(block.x <= 1024); \ -@@ -434,6 +770,63 @@ __global__ void softmax_kernel_v5_half2( - <<>>(buffer, buffer_src, attr_mask, batch_size, head_num, seq_len, scalar); \ - } - -+#define SOFTMAX_MIX_KERNEL_BIAS(ITEMS_PER_THREAD) \ -+ block.x /= ITEMS_PER_THREAD; \ -+ assert(block.x <= 1024); \ -+ if (is_half2) { \ -+ if (grid.x % 4 == 0) { \ -+ grid.x /= 4; \ -+ softmax_cross_kernel_bias_v5_half2 \ -+ <<>>((half*)io_buffer, \ -+ (const half*)attr_mask, \ -+ (const half*)position_bias, \ -+ (const int*)d_seq_len, \ -+ (const int*)d_seq_len2, \ -+ batch_size, \ -+ head_num, \ -+ seq_len, \ -+ src_seq_stride, \ -+ trgt_seq_len, \ -+ tgt_seq_stride, \ -+ position_bias_head_num, \ -+ (const half)scalar, \ -+ causal); \ -+ } \ -+ else { \ -+ softmax_cross_kernel_bias_v4_half2 \ -+ <<>>((half*)io_buffer, \ -+ (const half*)attr_mask, \ -+ (const half*)position_bias, \ -+ (const int*)d_seq_len, \ -+ (const int*)d_seq_len2, \ -+ batch_size, \ -+ head_num, \ -+ seq_len, \ -+ src_seq_stride, \ -+ trgt_seq_len, \ -+ tgt_seq_stride, \ -+ position_bias_head_num, \ -+ (const half)scalar, \ -+ causal); \ -+ } \ -+ } \ -+ else { \ -+ softmax_mix_kernel_bias_v4<<>>(io_buffer, \ -+ attr_mask, \ -+ position_bias, \ -+ d_seq_len, \ -+ d_seq_len2, \ -+ batch_size, \ -+ head_num, \ -+ seq_len, \ -+ src_seq_stride, \ -+ trgt_seq_len, \ -+ tgt_seq_stride, \ -+ position_bias_head_num, \ -+ scalar, \ -+ causal); \ -+ } -+ - #ifdef ENABLE_BF16 - #define SOFTMAX_KERNEL_BF16(ITEMS_PER_THREAD) \ - block.x /= ITEMS_PER_THREAD; \ -@@ -501,6 +894,48 @@ void invokeMaskedSoftMax(T* buffer, - } - } - -+template -+void invokeMixMaskedSoftMax(T* io_buffer, -+ const T_M* attr_mask, -+ const T* position_bias, -+ const int* d_seq_len, -+ const int* d_seq_len2, -+ const int batch_size, -+ const int seq_len, -+ const int src_seq_stride, -+ const int trgt_seq_len, -+ const int tgt_seq_stride, -+ const int head_num, -+ const int position_bias_head_num, -+ const T scalar, -+ const bool causal, -+ cudaStream_t stream) -+{ -+ dim3 grid(seq_len, batch_size, head_num); -+ if (batch_size * head_num > 360) { -+ grid.x = ceil(float(seq_len) / 32.0f); -+ } -+ -+ bool is_half2 = sizeof(T) == 2 && sizeof(T_M) == 2 && trgt_seq_len % 2 == 0; -+ dim3 block((trgt_seq_len / (is_half2 ? 2 : 1) + 31) / 32 * 32); -+ -+ if (block.x > 3072 && block.x <= 4096) { -+ SOFTMAX_MIX_KERNEL_BIAS(4) -+ } -+ if (block.x > 2048) { -+ SOFTMAX_MIX_KERNEL_BIAS(3) -+ } -+ else if (block.x > 1024) { -+ SOFTMAX_MIX_KERNEL_BIAS(2) -+ } -+ else if (block.x > 0) { -+ SOFTMAX_MIX_KERNEL_BIAS(1) -+ } -+ else { -+ FT_CHECK(trgt_seq_len <= 4096 || seq_len <= 4096); -+ } -+} -+ - #ifdef ENABLE_BF16 - template<> - void invokeMaskedSoftMax(__nv_bfloat16* buffer, -@@ -574,8 +1009,73 @@ void invokeMaskedSoftMax(__nv_bfloat16* buffer, - FT_CHECK(seq_len <= 4096); - } - } -+ - #endif // ENABLE_BF16 - -+template void invokeMixMaskedSoftMax(float* io_buffer, -+ const float* attr_mask, -+ const float* position_bias, -+ const int* d_seq_len, -+ const int* d_seq_len2, -+ const int batch_size, -+ const int seq_len, -+ const int seq_len_stride, -+ const int tgt_seq_len, -+ const int tgt_seq_stride, -+ const int head_num, -+ const int position_bias_head_num, -+ const float scalar, -+ const bool causal, -+ cudaStream_t stream); -+ -+template void invokeMixMaskedSoftMax(half* io_buffer, -+ const half* attr_mask, -+ const half* position_bias, -+ const int* d_seq_len, -+ const int* d_seq_len2, -+ const int batch_size, -+ const int seq_len, -+ const int seq_len_stride, -+ const int tgt_seq_len, -+ const int tgt_seq_stride, -+ const int head_num, -+ const int position_bias_head_num, -+ const half scalar, -+ const bool causal, -+ cudaStream_t stream); -+ -+template void invokeMixMaskedSoftMax(float* io_buffer, -+ const half* attr_mask, -+ const float* position_bias, -+ const int* d_seq_len, -+ const int* d_seq_len2, -+ const int batch_size, -+ const int seq_len, -+ const int seq_len_stride, -+ const int tgt_seq_len, -+ const int tgt_seq_stride, -+ const int head_num, -+ const int position_bias_head_num, -+ const float scalar, -+ const bool causal, -+ cudaStream_t stream); -+ -+template void invokeMixMaskedSoftMax(half* io_buffer, -+ const float* attr_mask, -+ const half* position_bias, -+ const int* d_seq_len, -+ const int* d_seq_len2, -+ const int batch_size, -+ const int seq_len, -+ const int seq_len_stride, -+ const int tgt_seq_len, -+ const int tgt_seq_stride, -+ const int head_num, -+ const int position_bias_head_num, -+ const half scalar, -+ const bool causal, -+ cudaStream_t stream); -+ - template void invokeMaskedSoftMax(float* buffer, - const float* buffer_src, - const float* attr_mask, -@@ -621,6 +1121,7 @@ template void invokeMaskedSoftMax(__nv_bfloat16* buffer, - const int head_num, - const __nv_bfloat16 scalar, - cudaStream_t stream); -+ - #endif // ENABLE_BF16 - - template -@@ -726,9 +1227,9 @@ void invokeTransposeQKV(T* dst, - seq_per_block *= 2; - } - -- FT_CHECK(grid.x * seq_per_block == batch_size * head_num * seq_len); -+ FT_CHECK((int)(grid.x * seq_per_block) == batch_size * head_num * seq_len); - -- if (seq_per_block * size_per_head % 2 == 0) { -+ if (size_per_head % 2 == 0) { - block.x = seq_per_block * size_per_head / 2; - if (std::is_same::value) { - transpose<<>>( -@@ -778,6 +1279,7 @@ template void invokeTransposeQKV(__nv_bfloat16* src, - const int head_num, - const int size_per_head, - cudaStream_t stream); -+ - #endif - - template -@@ -993,12 +1495,14 @@ __global__ void transpose_remove_padding(const T* src, - - const int dst_seq_id = bid; - -+ const int src_offset_base = src_batch_id * seq_len * head_num * size_per_head + src_seq_id * size_per_head; -+ const int dst_offset_base = dst_seq_id * head_num * size_per_head; -+ - for (int idx = threadIdx.x; idx < head_num * size_per_head; idx += blockDim.x) { - const int head_id = idx / size_per_head; - const int hidden_id = idx % size_per_head; -- dst[dst_seq_id * head_num * size_per_head + idx] = -- __ldg(&src[src_batch_id * head_num * seq_len * size_per_head + head_id * seq_len * size_per_head -- + src_seq_id * size_per_head + hidden_id]); -+ const T src_elem = ldg(&src[src_offset_base + head_id * seq_len * size_per_head + hidden_id]); -+ dst[dst_offset_base + idx] = src_elem; - } - } - -@@ -1061,12 +1565,12 @@ template void invokeTransposeAttentionOutRemovePadding(half* src, - const int* mask_offset, - cudaStream_t stream); - --template -+template - __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, - T* k_buf, - T* v_buf, - const T* __restrict QKV, -- const T* __restrict qkv_bias, -+ const U* __restrict qkv_bias, - const int batch_size, - const int seq_len, - const int head_num, -@@ -1081,8 +1585,7 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, - for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < batch_size * seq_len * 3 * n; - index += gridDim.x * blockDim.x) { - int bias_id = index % (3 * n); -- T val = ldg(&QKV[index]) + ldg(&qkv_bias[bias_id]); -- -+ T val = ldg(&QKV[index]) + (T)ldg(&qkv_bias[bias_id]); - int tmp_index = index; - const int target_batch_id = tmp_index / (seq_len * 3 * n); - tmp_index -= target_batch_id * seq_len * 3 * n; -@@ -1097,15 +1600,217 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, - + seq_id * size_per_head + size_id] = val; - } - } -- --template --struct Vec_t {}; --template<> --struct Vec_t { -- using Type = float2; --}; --template<> --struct Vec_t { -+template -+__global__ void add_fusedQKV_bias_transpose_kernel_mb_vsl(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ const T* __restrict QKV, -+ const U* __restrict qkv_bias, -+ const int* padding_offset, -+ const int* d_sequence_length, -+ const int h_token_num, -+ const int batch_size, -+ const int max_seq_len, -+ const int head_num, -+ const int size_per_head) -+{ -+ // QKV: [m, 3, n] -+ // qkv_bias: [3, n] -+ // q_buf, k_buf, v_buf: [batch, head_num, seq_len, size_per_head] -+ -+ T* qkv_ptr[3] = {q_buf, k_buf, v_buf}; -+ const int n = head_num * size_per_head; -+ for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < h_token_num * 3 * n; -+ index += gridDim.x * blockDim.x) { -+ int bias_id = index % (3 * n); -+ T bias_val = (qkv_bias == nullptr) ? (T)0.0f : (T)ldg(&qkv_bias[bias_id]); -+ T val = ldg(&QKV[index]) + bias_val; -+ int tmp_index = index; -+ int h_token_idx = (index) / (3 * n); -+ int batch_id = (padding_offset[h_token_idx] + h_token_idx) / max_seq_len; -+ int seq = (padding_offset[h_token_idx] + h_token_idx) % max_seq_len; -+ h_token_idx -= seq; -+ tmp_index -= h_token_idx * 3 * n; -+ const int seq_id = tmp_index / (3 * n); -+ tmp_index -= seq_id * 3 * n; -+ const int qkv_id = tmp_index / n; -+ tmp_index -= qkv_id * n; -+ const int head_id = tmp_index / size_per_head; -+ const int size_id = tmp_index - head_id * size_per_head; -+ qkv_ptr[qkv_id][h_token_idx * head_num * size_per_head + head_id * d_sequence_length[batch_id] * size_per_head -+ + seq_id * size_per_head + size_id] = val; -+ } -+} -+template -+__global__ void add_fusedQKV_bias_transpose_kernel_use_past(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ const T* __restrict QKV, -+ const T* __restrict qkv_bias, -+ const int seq_len, -+ const int actual_seq_len, -+ const int head_num, -+ const int size_per_head) -+{ -+ // QKV: [m, 3, n] -+ // qkv_bias: [3, n] -+ // q_buf, k_buf, v_buf: [batch, head_num, seq_len, size_per_head] -+ T* qkv_ptr[3] = {q_buf, k_buf, v_buf}; -+ const int n = head_num * size_per_head; -+ for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < actual_seq_len * 3 * n; -+ index += gridDim.x * blockDim.x) { -+ int bias_id = index % (3 * n); -+ T val = ldg(&QKV[index]) + (T)ldg(&qkv_bias[bias_id]); -+ int tmp_index = index; -+ const int seq_id = tmp_index / (3 * n); -+ tmp_index -= seq_id * 3 * n; -+ const int qkv_id = tmp_index / n; -+ tmp_index -= qkv_id * n; -+ const int head_id = tmp_index / size_per_head; -+ const int size_id = tmp_index - head_id * size_per_head; -+ const int offset = head_id * seq_len * size_per_head + seq_id * size_per_head + size_id; -+ qkv_ptr[qkv_id][offset] = val; -+ } -+} -+template -+__global__ void add_fusedQKV_bias_transpose_kernel_use_past_mb(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ const T* __restrict QKV, -+ const T* __restrict qkv_bias, -+ const int* padding_offset, -+ const int* d_sequence_length, -+ const int* d_sequence_length2, -+ const int batch_size, -+ const int h_token_num, -+ const int actual_seq_len, -+ const int head_num, -+ const int size_per_head, -+ const bool incremental_mode, -+ const bool padding) -+{ -+ // QKV: [m, 3, n] -+ // qkv_bias: [3, n] -+ // q_buf, k_buf, v_buf: [batch, head_num, seq_len, size_per_head] -+ -+ T* qkv_ptr[3] = {q_buf, k_buf, v_buf}; -+ const int n = head_num * size_per_head; -+ for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < h_token_num * 3 * n; -+ index += gridDim.x * blockDim.x) { -+ int bias_id = index % (3 * n); -+ T bias_val = (qkv_bias == nullptr) ? (T)0.0f : (T)ldg(&qkv_bias[bias_id]); -+ T val = ldg(&QKV[index]) + bias_val; -+ int tmp_index = index; -+ int h_token_idx = (index) / (3 * n); -+ int batch_id = (padding_offset[h_token_idx] + h_token_idx) / actual_seq_len; -+ int seq_id = (padding_offset[h_token_idx] + h_token_idx) % actual_seq_len; -+ tmp_index -= h_token_idx * 3 * n; -+ h_token_idx -= seq_id; -+ -+ const int qkv_id = tmp_index / n; -+ tmp_index -= qkv_id * n; -+ const int head_id = tmp_index / size_per_head; -+ const int size_id = tmp_index - head_id * size_per_head; -+ int offset = seq_id * size_per_head + size_id; -+ if ((!padding || incremental_mode) && (qkv_id == 0)) { -+ offset += h_token_idx * n + head_id * d_sequence_length[batch_id] * size_per_head; -+ } else if (qkv_id != 1 || !padding) { -+ offset += batch_id * actual_seq_len * n + head_id * actual_seq_len * size_per_head; -+ } else { -+ offset = batch_id * actual_seq_len * n + head_id * actual_seq_len * size_per_head + size_id * actual_seq_len + seq_id; -+ } -+ if (incremental_mode && !(qkv_id == 0)) { -+ if (qkv_id == 1 && padding) -+ offset += (d_sequence_length2[batch_id] - 1); -+ else { -+ offset += (d_sequence_length2[batch_id] - 1) * size_per_head; -+ } -+ -+ } -+ qkv_ptr[qkv_id][offset] = val; -+ } -+} -+template -+__global__ void transposeQKV_kernel(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ const T* __restrict QKV, -+ const int batch_size, -+ const int seq_len, -+ const int head_num, -+ const int size_per_head) -+{ -+ // QKV: [m, 3, n] -+ // qkv_bias: [3, n] -+ // q_buf, k_buf, v_buf: [batch, head_num, seq_len, size_per_head] -+ -+ T* qkv_ptr[3] = {q_buf, k_buf, v_buf}; -+ const int n = head_num * size_per_head; -+ for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < batch_size * seq_len * 3 * n; -+ index += gridDim.x * blockDim.x) { -+ T val = ldg(&QKV[index]); -+ int tmp_index = index; -+ const int target_batch_id = tmp_index / (seq_len * 3 * n); -+ tmp_index -= target_batch_id * seq_len * 3 * n; -+ const int seq_id = tmp_index / (3 * n); -+ tmp_index -= seq_id * 3 * n; -+ const int qkv_id = tmp_index / n; -+ tmp_index -= qkv_id * n; -+ const int head_id = tmp_index / size_per_head; -+ const int size_id = tmp_index - head_id * size_per_head; -+ qkv_ptr[qkv_id][target_batch_id * head_num * seq_len * size_per_head + head_id * seq_len * size_per_head -+ + seq_id * size_per_head + size_id] = val; -+ } -+} -+ -+template -+__global__ void add_fusedQKV_ZP_bias_transpose_kernel(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ const T* __restrict QKV, -+ const U* __restrict qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int head_num, -+ const int size_per_head, -+ const int token_num, -+ int* mask_offset) -+{ -+ // QKV: [m, 3, n] -+ // qkv_bias: [3, n] -+ // q_buf, k_buf, v_buf: [batch, head_num, seq_len, size_per_head] -+ -+ T* qkv_ptr[3] = {q_buf, k_buf, v_buf}; -+ const int n = head_num * size_per_head; -+ for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < token_num * 3 * n; -+ index += gridDim.x * blockDim.x) { -+ int bias_id = index % (3 * n); // 0 - 160 -+ T val = ldg(&QKV[index]); -+ if (qkv_bias != nullptr) -+ val += (T)ldg(&qkv_bias[bias_id]); -+ int tmp_index = index; // 0 -160 * 3 * n -+ int token_id = tmp_index / (3 * n); -+ int batch_id = (token_id + ldg(&mask_offset[token_id])) / seq_len; -+ int seq_id = (token_id + ldg(&mask_offset[token_id])) % seq_len; -+ tmp_index -= token_id * 3 * n; -+ int qkv_id = tmp_index / n; -+ tmp_index -= qkv_id * n; -+ int head_id = tmp_index / size_per_head; -+ int size_id = tmp_index - head_id * size_per_head; -+ int dst_id = batch_id * head_num * seq_len * size_per_head + head_id * seq_len * size_per_head -+ + seq_id * size_per_head + size_id; -+ qkv_ptr[qkv_id][dst_id] = val; -+ } -+} -+ -+template -+struct Vec_t {}; -+template<> -+struct Vec_t { -+ using Type = float2; -+}; -+template<> -+struct Vec_t { - using Type = uint32_t; - }; - -@@ -1116,12 +1821,12 @@ struct Vec_t<__nv_bfloat16> { - }; - #endif - --template -+template - __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, - T* k_buf, - T* v_buf, - const T* __restrict QKV, -- const T* __restrict qkv_bias, -+ const U* __restrict qkv_bias, - const int batch_size, - const int seq_len, - const int head_num, -@@ -1174,8 +1879,21 @@ template - void invokeAddFusedQKVBiasTranspose(T* q_buf, - T* k_buf, - T* v_buf, -- T* QKV, -+ const T* QKV, - const T* qkv_bias, -+ const int max_seq_len, -+ const int actual_seq_len, -+ const int head_num, -+ const int size_per_head, -+ const bool inc, -+ cudaStream_t stream); -+ -+template -+void invokeAddFusedQKVBiasTranspose(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ T* QKV, -+ const U* qkv_bias, - const int batch_size, - const int seq_len, - const int head_num, -@@ -1183,23 +1901,714 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf, - const int rotary_embedding_dim, - cudaStream_t stream) - { -- if (rotary_embedding_dim == 0) { -+ if (qkv_bias != nullptr) { -+ if (rotary_embedding_dim == 0) { -+ const int m = batch_size * seq_len; -+ const int n = head_num * size_per_head; -+ dim3 block(384); -+ dim3 grid((int)(ceil(1.0 * m * n / 384))); -+ add_fusedQKV_bias_transpose_kernel<<>>( -+ q_buf, k_buf, v_buf, QKV, qkv_bias, batch_size, seq_len, head_num, size_per_head); -+ } -+ else { -+ // To implement rotary embeddings, each thread processes two QKV elems: -+ dim3 block((size_per_head / 2 + 31) / 32 * 32); -+ dim3 grid(seq_len, head_num, batch_size); -+ add_fusedQKV_bias_transpose_kernel<<>>( -+ q_buf, k_buf, v_buf, QKV, qkv_bias, batch_size, seq_len, head_num, size_per_head, rotary_embedding_dim); -+ } -+ } -+ else { - const int m = batch_size * seq_len; - const int n = head_num * size_per_head; - dim3 block(384); - dim3 grid((int)(ceil(1.0 * m * n / 384))); -- add_fusedQKV_bias_transpose_kernel<<>>( -- q_buf, k_buf, v_buf, QKV, qkv_bias, batch_size, seq_len, head_num, size_per_head); -+ transposeQKV_kernel<<>>( -+ q_buf, k_buf, v_buf, QKV, batch_size, seq_len, head_num, size_per_head); - } -- else { -+} -+template -+void invokeAddFusedQKVBiasTransposeMBVSL(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ T* QKV, -+ const U* qkv_bias, -+ const int* padding_offset, -+ const int* d_sequence_length, -+ const int batch_size, -+ const int max_seq_len, -+ const int h_token_num, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream) -+{ -+ const int m = h_token_num; -+ const int n = head_num * size_per_head; -+ dim3 block(384); -+ dim3 grid((int)(ceil(1.0 * m * n / 384))); -+ add_fusedQKV_bias_transpose_kernel_mb_vsl<<>>( -+ q_buf, k_buf, v_buf, QKV, qkv_bias, padding_offset, d_sequence_length, h_token_num, batch_size, max_seq_len, head_num, size_per_head); -+} -+template -+void invokeAddFusedQKVBiasTransposeUsePast(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ const T* QKV, -+ const T* qkv_bias, -+ const int seq_len, -+ const int actual_seq_len, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream) -+{ -+ if (qkv_bias != nullptr) { - // To implement rotary embeddings, each thread processes two QKV elems: - dim3 block((size_per_head / 2 + 31) / 32 * 32); -- dim3 grid(seq_len, head_num, batch_size); -- add_fusedQKV_bias_transpose_kernel<<>>( -- q_buf, k_buf, v_buf, QKV, qkv_bias, batch_size, seq_len, head_num, size_per_head, rotary_embedding_dim); -+ dim3 grid(actual_seq_len, head_num); -+ add_fusedQKV_bias_transpose_kernel_use_past<<>>( -+ q_buf, k_buf, v_buf, QKV, qkv_bias, seq_len, actual_seq_len, head_num, size_per_head); -+ } -+ else { -+ std::cout << "null qkv bias not supported" << std::endl; -+ } -+} -+template -+void invokeAddFusedQKVBiasTransposeUsePastMB(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ const T* QKV, -+ const T* qkv_bias, -+ const int* padding_offset, -+ const int* d_sequence_length, -+ const int* d_sequence_length2, -+ const int batch_size, -+ const int h_token_num, -+ const int actual_seq_len, -+ const int head_num, -+ const int size_per_head, -+ const bool incremental_mode, -+ const bool padding, -+ cudaStream_t stream) -+{ -+ const int m = h_token_num; -+ const int n = head_num * size_per_head; -+ dim3 block(384); -+ dim3 grid((int)(ceil(1.0 * m * n / 384))); -+ add_fusedQKV_bias_transpose_kernel_use_past_mb<<>>( -+ q_buf, k_buf, v_buf, QKV, qkv_bias, padding_offset, d_sequence_length, -+ d_sequence_length2, batch_size, h_token_num, -+ actual_seq_len, head_num, size_per_head, -+ incremental_mode, padding); -+} -+template -+void invokeAddFusedZP_QKVBiasTranspose(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ T* QKV, -+ const U* qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int head_num, -+ const int size_per_head, -+ const int h_token, -+ int* padding_mask, -+ cudaStream_t stream) -+{ -+ -+ const int m = h_token; -+ const int n = head_num * size_per_head; -+ cudaMemsetAsync(q_buf, 0, batch_size * seq_len * n * sizeof(T), stream); -+ cudaMemsetAsync(k_buf, 0, batch_size * seq_len * n * sizeof(T), stream); -+ cudaMemsetAsync(v_buf, 0, batch_size * seq_len * n * sizeof(T), stream); -+ dim3 block(384); -+ dim3 grid((int)(ceil(1.0 * m * n / 384))); -+ add_fusedQKV_ZP_bias_transpose_kernel<<>>( -+ q_buf, k_buf, v_buf, QKV, qkv_bias, batch_size, seq_len, head_num, size_per_head, h_token, padding_mask); -+} -+ -+template void invokeAddFusedZP_QKVBiasTranspose(float* q_buf, -+ float* k_buf, -+ float* v_buf, -+ float* QKV, -+ const float* qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int head_num, -+ const int size_per_head, -+ const int h_token, -+ int* padding_mask, -+ cudaStream_t stream); -+ -+template void invokeAddFusedZP_QKVBiasTranspose(half* q_buf, -+ half* k_buf, -+ half* v_buf, -+ half* QKV, -+ const half* qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int head_num, -+ const int size_per_head, -+ const int h_token, -+ int* padding_mask, -+ cudaStream_t stream); -+ -+template void invokeCrossAddFusedZP_QKVBiasTranspose(float* q_buf, -+ float* k_buf, -+ float* v_buf, -+ float* QKV, -+ const float* qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int tgt_seq_len, -+ const int head_num, -+ const int size_per_head, -+ const int h_token, -+ const int h_token2, -+ int* padding_mask, -+ int* padding_mask2, -+ cudaStream_t stream); -+ -+template void invokeCrossAddFusedZP_QKVBiasTranspose(half* q_buf, -+ half* k_buf, -+ half* v_buf, -+ half* QKV, -+ const half* qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int tgt_seq_len, -+ const int head_num, -+ const int size_per_head, -+ const int h_token, -+ const int h_token2, -+ int* padding_mask, -+ int* padding_mask2, -+ cudaStream_t stream); -+ -+template -+__global__ void invokeCrossAddFusedQKVBiasTransposeQ(T* q_buf, -+ const T* __restrict QKV, -+ const U* __restrict qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int head_num, -+ const int size_per_head) -+{ -+ // QKV: [m, 1, n] -+ // qkv_bias: [1, n] -+ // q_buf: [batch, head_num, seq_len, size_per_head] -+ -+ T* qkv_ptr[1] = {q_buf}; -+ const int n = head_num * size_per_head; -+ for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < batch_size * seq_len * 1 * n; -+ index += gridDim.x * blockDim.x) { -+ int bias_id = index % (1 * n); -+ T val = ldg(&QKV[index]) + (T)ldg(&qkv_bias[bias_id]); -+ -+ int tmp_index = index; -+ const int target_batch_id = tmp_index / (seq_len * 1 * n); -+ tmp_index -= target_batch_id * seq_len * 1 * n; -+ const int seq_id = tmp_index / (1 * n); -+ tmp_index -= seq_id * 1 * n; -+ const int qkv_id = tmp_index / n; -+ tmp_index -= qkv_id * n; -+ const int head_id = tmp_index / size_per_head; -+ const int size_id = tmp_index - head_id * size_per_head; -+ -+ qkv_ptr[qkv_id][target_batch_id * head_num * seq_len * size_per_head + head_id * seq_len * size_per_head -+ + seq_id * size_per_head + size_id] = val; -+ } -+} -+ -+template -+__global__ void invokeCrossAddFusedQKVBiasTransposeQMBVSL(T* q_buf, -+ const T* __restrict QKV, -+ const U* __restrict qkv_bias, -+ const int* padding_offset, -+ const int* d_sequence_length, -+ const int h_token_num, -+ const int batch_size, -+ const int seq_len, -+ const int head_num, -+ const int size_per_head) -+{ -+ // QKV: [m, 1, n] -+ // qkv_bias: [1, n] -+ // q_buf: [batch, head_num, seq_len, size_per_head] -+ -+ T* qkv_ptr[1] = {q_buf}; -+ const int n = head_num * size_per_head; -+ for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < h_token_num * 1 * n; -+ index += gridDim.x * blockDim.x) { -+ int bias_id = index % (1 * n); -+ T bias_val = (qkv_bias == nullptr) ? (T)0.0f : (T)ldg(&qkv_bias[bias_id]); -+ T val = ldg(&QKV[index]) + bias_val; -+ -+ int tmp_index = index; -+ int h_token_idx = (index) / (1 * n); -+ int batch_id = (padding_offset[h_token_idx] + h_token_idx) / seq_len; -+ int seq = (padding_offset[h_token_idx] + h_token_idx) % seq_len; -+ h_token_idx -= seq; -+ tmp_index -= h_token_idx * 1 * n; -+ -+ const int seq_id = tmp_index / (1 * n); -+ tmp_index -= seq_id * 1 * n; -+ const int qkv_id = tmp_index / n; -+ tmp_index -= qkv_id * n; -+ const int head_id = tmp_index / size_per_head; -+ const int size_id = tmp_index - head_id * size_per_head; -+ -+ qkv_ptr[qkv_id][h_token_idx * head_num * size_per_head + head_id * d_sequence_length[batch_id] * size_per_head -+ + seq_id * size_per_head + size_id] = val; -+ } -+} -+template -+__global__ void add_fusedQKV_ZP_bias_transpose_kernel_q(T* q_buf, -+ const T* __restrict QKV, -+ const U* __restrict qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int head_num, -+ const int size_per_head, -+ const int token_num, -+ int* mask_offset) -+{ -+ // QKV: [m, 1, n] -+ // qkv_bias: [1, n] -+ // q_buf: [batch, head_num, seq_len, size_per_head] -+ -+ T* qkv_ptr[1] = {q_buf}; -+ const int n = head_num * size_per_head; -+ for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < token_num * 1 * n; -+ index += gridDim.x * blockDim.x) { -+ int bias_id = index % (1 * n); -+ T val = ldg(&QKV[index]); -+ if (qkv_bias != nullptr) -+ val += (T)ldg(&qkv_bias[bias_id]); -+ int tmp_index = index; -+ // const int target_batch_id = tmp_index / (seq_len * 1 * n); -+ int token_id = tmp_index / (1 * n); -+ int batch_id = (token_id + ldg(&mask_offset[token_id])) / seq_len; -+ int seq_id = (token_id + ldg(&mask_offset[token_id])) % seq_len; -+ tmp_index -= token_id * 1 * n; -+ int qkv_id = tmp_index / n; -+ tmp_index -= qkv_id * n; -+ int head_id = tmp_index / size_per_head; -+ int size_id = tmp_index - head_id * size_per_head; -+ int dst_id = batch_id * head_num * seq_len * size_per_head + head_id * seq_len * size_per_head -+ + seq_id * size_per_head + size_id; -+ qkv_ptr[qkv_id][dst_id] = val; -+ } -+} -+ -+template -+__global__ void invokeCrossTransposeQ(T* q_buf, -+ const T* __restrict QKV, -+ const int batch_size, -+ const int seq_len, -+ const int head_num, -+ const int size_per_head) -+{ -+ // QKV: [m, 1, n] -+ // q_buf: [batch, head_num, seq_len, size_per_head] -+ -+ T* qkv_ptr[1] = {q_buf}; -+ const int n = head_num * size_per_head; -+ for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < batch_size * seq_len * 1 * n; -+ index += gridDim.x * blockDim.x) { -+ T val = ldg(&QKV[index]); -+ -+ int tmp_index = index; -+ const int target_batch_id = tmp_index / (seq_len * 1 * n); -+ tmp_index -= target_batch_id * seq_len * 1 * n; -+ const int seq_id = tmp_index / (1 * n); -+ tmp_index -= seq_id * 1 * n; -+ const int qkv_id = tmp_index / n; -+ tmp_index -= qkv_id * n; -+ const int head_id = tmp_index / size_per_head; -+ const int size_id = tmp_index - head_id * size_per_head; -+ -+ qkv_ptr[qkv_id][target_batch_id * head_num * seq_len * size_per_head + head_id * seq_len * size_per_head -+ + seq_id * size_per_head + size_id] = val; -+ } -+} -+ -+template -+__global__ void invokeCrossAddFusedQKVBiasTransposeKV(T* k_buf, -+ T* v_buf, -+ const T* __restrict QKV, -+ const U* __restrict qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int head_num, -+ const int size_per_head) -+{ -+ // QKV: [m, 2, n] -+ // qkv_bias: [2, n] -+ // k_buf, v_buf: [batch, head_num, seq_len, size_per_head] -+ -+ T* qkv_ptr[2] = {k_buf, v_buf}; -+ const int n = head_num * size_per_head; -+ for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < batch_size * seq_len * 2 * n; -+ index += gridDim.x * blockDim.x) { -+ int bias_id = index % (2 * n); -+ T val = ldg(&QKV[index]) + (T)ldg(&qkv_bias[bias_id]); -+ -+ int tmp_index = index; -+ const int target_batch_id = tmp_index / (seq_len * 2 * n); -+ tmp_index -= target_batch_id * seq_len * 2 * n; -+ const int seq_id = tmp_index / (2 * n); -+ tmp_index -= seq_id * 2 * n; -+ const int qkv_id = tmp_index / n; -+ tmp_index -= qkv_id * n; -+ const int head_id = tmp_index / size_per_head; -+ const int size_id = tmp_index - head_id * size_per_head; -+ qkv_ptr[qkv_id][target_batch_id * head_num * seq_len * size_per_head + head_id * seq_len * size_per_head -+ + seq_id * size_per_head + size_id] = val; -+ } -+} -+ -+template -+__global__ void invokeCrossAddFusedQKVBiasTransposeKVMBVSL(T* k_buf, -+ T* v_buf, -+ const T* __restrict QKV, -+ const U* __restrict qkv_bias, -+ const int* padding_offset, -+ const int* d_sequence_length, -+ const int h_token_num, -+ const int batch_size, -+ const int seq_len, -+ const int head_num, -+ const int size_per_head) -+{ -+ // QKV: [m, 2, n] -+ // qkv_bias: [2, n] -+ // k_buf, v_buf: [batch, head_num, seq_len, size_per_head] -+ -+ T* qkv_ptr[2] = {k_buf, v_buf}; -+ const int n = head_num * size_per_head; -+ for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < h_token_num * 2 * n; -+ index += gridDim.x * blockDim.x) { -+ int bias_id = index % (2 * n); -+ T bias_val = (qkv_bias == nullptr) ? (T)0.0f : (T)ldg(&qkv_bias[bias_id]); -+ T val = ldg(&QKV[index]) + bias_val; -+ int tmp_index = index; -+ int h_token_idx = (index) / (2 * n); -+ int batch_id = (padding_offset[h_token_idx] + h_token_idx) / seq_len; -+ int seq = (padding_offset[h_token_idx] + h_token_idx) % seq_len; -+ h_token_idx -= seq; -+ tmp_index -= h_token_idx * 2 * n; -+ const int seq_id = tmp_index / (2 * n); -+ tmp_index -= seq_id * 2 * n; -+ const int qkv_id = tmp_index / n; -+ tmp_index -= qkv_id * n; -+ const int head_id = tmp_index / size_per_head; -+ const int size_id = tmp_index - head_id * size_per_head; -+ qkv_ptr[qkv_id][h_token_idx * head_num * size_per_head + head_id * d_sequence_length[batch_id] * size_per_head -+ + seq_id * size_per_head + size_id] = val; -+ } -+} -+ -+template -+__global__ void add_fusedQKV_ZP_bias_transpose_kernel_kv(T* k_buf, -+ T* v_buf, -+ const T* __restrict QKV, -+ const U* __restrict qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int head_num, -+ const int size_per_head, -+ const int token_num, -+ int* mask_offset) -+{ -+ // QKV: [m, 2, n] -+ // qkv_bias: [2, n] -+ // k_buf, v_buf: [batch, head_num, seq_len, size_per_head] -+ -+ T* qkv_ptr[2] = {k_buf, v_buf}; -+ const int n = head_num * size_per_head; -+ for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < token_num * 2 * n; -+ index += gridDim.x * blockDim.x) { -+ int bias_id = index % (2 * n); -+ T val = ldg(&QKV[index]); -+ if (qkv_bias != nullptr) -+ val += (T)ldg(&qkv_bias[bias_id]); -+ int tmp_index = index; -+ int token_id = tmp_index / (2 * n); -+ int batch_id = (token_id + ldg(&mask_offset[token_id])) / seq_len; -+ int seq_id = (token_id + ldg(&mask_offset[token_id])) % seq_len; -+ tmp_index -= token_id * 2 * n; -+ int qkv_id = tmp_index / n; -+ tmp_index -= qkv_id * n; -+ int head_id = tmp_index / size_per_head; -+ int size_id = tmp_index - head_id * size_per_head; -+ int dst_id = batch_id * head_num * seq_len * size_per_head + head_id * seq_len * size_per_head -+ + seq_id * size_per_head + size_id; -+ qkv_ptr[qkv_id][dst_id] = val; - } - } - -+template -+__global__ void invokeCrossTransposeKV(T* k_buf, -+ T* v_buf, -+ const T* __restrict QKV, -+ const int batch_size, -+ const int seq_len, -+ const int head_num, -+ const int size_per_head) -+{ -+ // QKV: [m, 2, n] -+ // k_buf, v_buf: [batch, head_num, seq_len, size_per_head] -+ -+ T* qkv_ptr[2] = {k_buf, v_buf}; -+ const int n = head_num * size_per_head; -+ for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < batch_size * seq_len * 2 * n; -+ index += gridDim.x * blockDim.x) { -+ T val = ldg(&QKV[index]); -+ -+ int tmp_index = index; -+ const int target_batch_id = tmp_index / (seq_len * 2 * n); -+ tmp_index -= target_batch_id * seq_len * 2 * n; -+ const int seq_id = tmp_index / (2 * n); -+ tmp_index -= seq_id * 2 * n; -+ const int qkv_id = tmp_index / n; -+ tmp_index -= qkv_id * n; -+ const int head_id = tmp_index / size_per_head; -+ const int size_id = tmp_index - head_id * size_per_head; -+ qkv_ptr[qkv_id][target_batch_id * head_num * seq_len * size_per_head + head_id * seq_len * size_per_head -+ + seq_id * size_per_head + size_id] = val; -+ } -+} -+ -+template -+void invokeCrossAddFusedQKVBiasTranspose(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ T* QKV, -+ const U* qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int tgt_seq_len, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream) -+{ -+ if (qkv_bias != nullptr) { -+ const int m = batch_size * seq_len; -+ const int n = head_num * size_per_head; -+ dim3 block(384); -+ dim3 grid((int)(ceil(1.0 * m * n / 384))); -+ invokeCrossAddFusedQKVBiasTransposeQ<<>>( -+ q_buf, QKV, qkv_bias, batch_size, seq_len, head_num, size_per_head); -+ -+ const int m2 = batch_size * tgt_seq_len; -+ const int n2 = head_num * size_per_head; -+ dim3 block2(384); -+ dim3 grid2((int)(ceil(1.0 * m2 * n2 / 384))); -+ invokeCrossAddFusedQKVBiasTransposeKV<<>>( -+ k_buf, v_buf, QKV + m * n, qkv_bias + n2, batch_size, tgt_seq_len, head_num, size_per_head); -+ } -+ else { -+ const int m = batch_size * seq_len; -+ const int n = head_num * size_per_head; -+ dim3 block(384); -+ dim3 grid((int)(ceil(1.0 * m * n / 384))); -+ invokeCrossTransposeQ<<>>(q_buf, QKV, batch_size, seq_len, head_num, size_per_head); -+ -+ const int m2 = batch_size * tgt_seq_len; -+ const int n2 = head_num * size_per_head; -+ dim3 block2(384); -+ dim3 grid2((int)(ceil(1.0 * m2 * n2 / 384))); -+ invokeCrossTransposeKV<<>>( -+ k_buf, v_buf, QKV + m * n, batch_size, tgt_seq_len, head_num, size_per_head); -+ } -+} -+template -+void invokeCrossAddFusedQKVBiasTransposeMBVSL(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ T* QKV, -+ const U* qkv_bias, -+ const int* padding_offset, -+ const int* padding_offset2, -+ const int* d_sequence_length, -+ const int* d_sequence_length2, -+ const int h_token_num, -+ const int h_token_num2, -+ const int batch_size, -+ const int seq_len, -+ const int tgt_seq_len, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream) -+{ -+ const int m = h_token_num; -+ const int n = head_num * size_per_head; -+ dim3 block(384); -+ dim3 grid((int)(ceil(1.0 * m * n / 384))); -+ invokeCrossAddFusedQKVBiasTransposeQMBVSL<<>>( -+ q_buf, QKV, qkv_bias, padding_offset, d_sequence_length, h_token_num, batch_size, seq_len, head_num, size_per_head); -+ const int m2 = h_token_num2; -+ const int n2 = head_num * size_per_head; -+ const U* kv_bias = (qkv_bias == nullptr) ? qkv_bias : qkv_bias + n2; -+ dim3 block2(384); -+ dim3 grid2((int)(ceil(1.0 * m2 * n2 / 384))); -+ invokeCrossAddFusedQKVBiasTransposeKVMBVSL<<>>( -+ k_buf, v_buf, QKV + h_token_num * n, kv_bias, padding_offset2, d_sequence_length2, h_token_num2, batch_size, tgt_seq_len, head_num, size_per_head); -+} -+template -+void invokeCrossAddFusedZP_QKVBiasTranspose(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ T* QKV, -+ const U* qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int tgt_seq_len, -+ const int head_num, -+ const int size_per_head, -+ const int h_token, -+ const int h_token2, -+ int* padding_mask, -+ int* padding_mask2, -+ cudaStream_t stream) -+{ -+ const int m = h_token; -+ const int n = head_num * size_per_head; -+ cudaMemsetAsync(q_buf, 0, batch_size * seq_len * n * sizeof(T), stream); -+ dim3 block(384); -+ dim3 grid((int)(ceil(1.0 * m * n / 384))); -+ add_fusedQKV_ZP_bias_transpose_kernel_q<<>>( -+ q_buf, QKV, qkv_bias, batch_size, seq_len, head_num, size_per_head, h_token, padding_mask); -+ -+ const int m2 = h_token2; -+ const int n2 = head_num * size_per_head; -+ cudaMemsetAsync(k_buf, 0, batch_size * tgt_seq_len * n2 * sizeof(T), stream); -+ cudaMemsetAsync(v_buf, 0, batch_size * tgt_seq_len * n2 * sizeof(T), stream); -+ dim3 block2(384); -+ dim3 grid2((int)(ceil(1.0 * m2 * n2 / 384))); -+ qkv_bias = (qkv_bias == nullptr) ? nullptr : qkv_bias + n2; -+ add_fusedQKV_ZP_bias_transpose_kernel_kv<<>>( -+ k_buf, v_buf, QKV + m * n, qkv_bias, batch_size, tgt_seq_len, head_num, size_per_head, h_token2, padding_mask2); -+} -+template void invokeCrossAddFusedQKVBiasTranspose(float* q_buf, -+ float* k_buf, -+ float* v_buf, -+ float* QKV, -+ const float* qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int tgt_seq_len, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+ -+template void invokeCrossAddFusedQKVBiasTranspose(half* q_buf, -+ half* k_buf, -+ half* v_buf, -+ half* QKV, -+ const half* qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int tgt_seq_len, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+ -+template void invokeCrossAddFusedQKVBiasTranspose(float* q_buf, -+ float* k_buf, -+ float* v_buf, -+ float* QKV, -+ const half* qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int tgt_seq_len, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+ -+template void invokeCrossAddFusedQKVBiasTranspose(half* q_buf, -+ half* k_buf, -+ half* v_buf, -+ half* QKV, -+ const float* qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int tgt_seq_len, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+ -+template void invokeCrossAddFusedQKVBiasTransposeMBVSL(float* q_buf, -+ float* k_buf, -+ float* v_buf, -+ float* QKV, -+ const float* qkv_bias, -+ const int* padding_offset, -+ const int* padding_offset2, -+ const int* d_sequence_length, -+ const int* d_sequence_length2, -+ const int h_token_num, -+ const int h_token_num2, -+ const int batch_size, -+ const int seq_len, -+ const int tgt_seq_len, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+ -+template void invokeCrossAddFusedQKVBiasTransposeMBVSL(half* q_buf, -+ half* k_buf, -+ half* v_buf, -+ half* QKV, -+ const half* qkv_bias, -+ const int* padding_offset, -+ const int* padding_offset2, -+ const int* d_sequence_length, -+ const int* d_sequence_length2, -+ const int h_token_num, -+ const int h_token_num2, -+ const int batch_size, -+ const int seq_len, -+ const int tgt_seq_len, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+ -+template void invokeCrossAddFusedQKVBiasTransposeMBVSL(float* q_buf, -+ float* k_buf, -+ float* v_buf, -+ float* QKV, -+ const half* qkv_bias, -+ const int* padding_offset, -+ const int* padding_offset2, -+ const int* d_sequence_length, -+ const int* d_sequence_length2, -+ const int h_token_num, -+ const int h_token_num2, -+ const int batch_size, -+ const int seq_len, -+ const int tgt_seq_len, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+ -+template void invokeCrossAddFusedQKVBiasTransposeMBVSL(half* q_buf, -+ half* k_buf, -+ half* v_buf, -+ half* QKV, -+ const float* qkv_bias, -+ const int* padding_offset, -+ const int* padding_offset2, -+ const int* d_sequence_length, -+ const int* d_sequence_length2, -+ const int h_token_num, -+ const int h_token_num2, -+ const int batch_size, -+ const int seq_len, -+ const int tgt_seq_len, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); - template void invokeAddFusedQKVBiasTranspose(float* q_buf, - float* k_buf, - float* v_buf, -@@ -1224,6 +2633,87 @@ template void invokeAddFusedQKVBiasTranspose(half* q_buf, - const int rotary_embedding_dim, - cudaStream_t stream); - -+template void invokeAddFusedQKVBiasTranspose(float* q_buf, -+ float* k_buf, -+ float* v_buf, -+ float* QKV, -+ const half* qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int head_num, -+ const int size_per_head, -+ const int rotary_embedding_dim, -+ cudaStream_t stream); -+ -+template void invokeAddFusedQKVBiasTranspose(half* q_buf, -+ half* k_buf, -+ half* v_buf, -+ half* QKV, -+ const float* qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int head_num, -+ const int size_per_head, -+ const int rotary_embedding_dim, -+ cudaStream_t stream); -+ -+template void invokeAddFusedQKVBiasTransposeMBVSL(float* q_buf, -+ float* k_buf, -+ float* v_buf, -+ float* QKV, -+ const float* qkv_bias, -+ const int* padding_offset, -+ const int* d_sequence_length, -+ const int batch_size, -+ const int max_seq_len, -+ const int h_token_num, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+ -+template void invokeAddFusedQKVBiasTransposeMBVSL(half* q_buf, -+ half* k_buf, -+ half* v_buf, -+ half* QKV, -+ const half* qkv_bias, -+ const int* padding_offset, -+ const int* d_sequence_length, -+ const int batch_size, -+ const int max_seq_len, -+ const int h_token_num, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+ -+template void invokeAddFusedQKVBiasTransposeMBVSL(float* q_buf, -+ float* k_buf, -+ float* v_buf, -+ float* QKV, -+ const half* qkv_bias, -+ const int* padding_offset, -+ const int* d_sequence_length, -+ const int batch_size, -+ const int max_seq_len, -+ const int h_token_num, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+ -+template void invokeAddFusedQKVBiasTransposeMBVSL(half* q_buf, -+ half* k_buf, -+ half* v_buf, -+ half* QKV, -+ const float* qkv_bias, -+ const int* padding_offset, -+ const int* d_sequence_length, -+ const int batch_size, -+ const int max_seq_len, -+ const int h_token_num, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+ -+ - #ifdef ENABLE_BF16 - template void invokeAddFusedQKVBiasTranspose(__nv_bfloat16* q_buf, - __nv_bfloat16* k_buf, -@@ -1236,6 +2726,49 @@ template void invokeAddFusedQKVBiasTranspose(__nv_bfloat16* q_buf, - const int size_per_head, - const int rotary_embedding_dim, - cudaStream_t stream); -+ -+template void invokeAddFusedQKVBiasTransposeMBVSL(__nv_bfloat16* q_buf, -+ __nv_bfloat16* k_buf, -+ __nv_bfloat16* v_buf, -+ __nv_bfloat16* QKV, -+ const __nv_bfloat16* qkv_bias, -+ const int* padding_offset, -+ const int* d_sequence_length, -+ const int batch_size, -+ const int max_seq_len, -+ const int h_token_num, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+template void invokeCrossAddFusedQKVBiasTranspose(__nv_bfloat16* q_buf, -+ __nv_bfloat16* k_buf, -+ __nv_bfloat16* v_buf, -+ __nv_bfloat16* QKV, -+ const __nv_bfloat16* qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int tgt_seq_len, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+template void invokeCrossAddFusedQKVBiasTransposeMBVSL(__nv_bfloat16* q_buf, -+ __nv_bfloat16* k_buf, -+ __nv_bfloat16* v_buf, -+ __nv_bfloat16* QKV, -+ const __nv_bfloat16* qkv_bias, -+ const int* padding_offset, -+ const int* padding_offset2, -+ const int* d_sequence_length, -+ const int* d_sequence_length2, -+ const int h_token_num, -+ const int h_token_num2, -+ const int batch_size, -+ const int seq_len, -+ const int tgt_seq_len, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+ - #endif - - template -@@ -1860,4 +3393,419 @@ template void invokeMaskedSoftMaxWithRelPosBias(half* qk_buf, - const float qk_scale, - cudaStream_t stream); - -+template -+__global__ void attention_kernel(T* query_buf, -+ const T* Q_bias, -+ T* key_cache, -+ const T* K_bias, -+ T* value_cache, -+ const T* V_bias, -+ const int* length_per_sample, -+ T* context_buf, -+ const bool* finished, -+ int batch_size, -+ int head_num, -+ int size_per_head, -+ int step, -+ const int seq_len, -+ const T scalar) -+{ -+ if (finished != nullptr && finished[blockIdx.x / head_num] == true) { -+ return; -+ } -+ int tid = threadIdx.x; -+ int bid = blockIdx.x / head_num; -+ int head_id = blockIdx.x % head_num; -+ -+ extern __shared__ __align__(sizeof(T)) unsigned s_buf[]; -+ T* sq = reinterpret_cast(s_buf); -+ T* logits = reinterpret_cast(&sq[size_per_head]); -+ -+ int length = __ldg(&length_per_sample[bid]); -+ -+ int qkv_id = bid * head_num * size_per_head + head_id * size_per_head + tid; -+ int qkv_bias_id = head_id * size_per_head + tid; -+ -+ if (tid < size_per_head) { -+ sq[tid] = query_buf[qkv_id] + Q_bias[qkv_bias_id]; -+ } -+ __syncthreads(); -+ -+ for (int ite = 0; ite < length; ++ite) { -+ int key_id = bid * (seq_len * head_num * size_per_head) + ite * (head_num * size_per_head) -+ + head_id * size_per_head + tid; -+ -+ T key = tid < size_per_head ? key_cache[key_id] : (T)(0.0f); -+ -+ // For the first step, we should add bias to key memory cache. -+ // The KV memory cache only need to be updated at the first step. -+ if (step == 1 && tid < size_per_head) { -+ key += K_bias[head_id * size_per_head + tid]; -+ key_cache[key_id] = key; -+ } -+ -+ T val = (tid < size_per_head) ? key * sq[tid] * scalar : (T)(0.0f); -+ T qk = blockReduceSum(val); -+ if (threadIdx.x == 0) { -+ logits[ite] = qk; -+ } -+ __syncthreads(); // try to remove -+ } -+ __syncthreads(); -+ -+ __shared__ float s_max_val, s_sum; -+ -+ float local_i = tid < length ? (float)logits[tid] : -1e20f; -+ float max_val = blockReduceMax(local_i); -+ if (tid == 0) { -+ s_max_val = max_val; -+ } -+ __syncthreads(); -+ -+ local_i -= s_max_val; -+ float local_o = tid < length ? __expf(local_i) : 0.0f; -+ float val = blockReduceSum(local_o); -+ -+ if (tid == 0) { -+ s_sum = val + 1e-6; -+ } -+ __syncthreads(); -+ if (tid < length) { -+ logits[tid] = local_o / s_sum; -+ } -+ __syncthreads(); -+ -+ if (tid < size_per_head) { -+ T sum = (T)0.0f; -+ for (int ite = 0; ite < length; ++ite) { -+ int value_id = bid * seq_len * head_num * size_per_head + ite * head_num * size_per_head -+ + head_id * size_per_head + tid; -+ -+ T value = value_cache[value_id]; -+ -+ // for the first step, we should add bias to key memory cache -+ if (step == 1) { -+ value += V_bias[head_id * size_per_head + tid]; -+ value_cache[value_id] = value; -+ } -+ sum += value * logits[ite]; -+ } -+ context_buf[bid * head_num * size_per_head + head_id * size_per_head + tid] = sum; -+ } -+} -+ -+template -+__global__ void attention_kernel_opt(const T* __restrict qkv_buf, -+ const T* __restrict qkv_bias, -+ const T* __restrict attr_mask, -+ T* __restrict out_buf, -+ T* __restrict key_cache_output, -+ T* __restrict value_cache_output, -+ int batch_size, -+ int head_num, -+ const int seq_len, -+ const float scalar) -+{ -+ typedef Copy_t copy_t; -+ const int elems_per_thread = size_per_head / WARP_SIZE; -+ union Access_t { -+ copy_t v; -+ T x[elems_per_thread]; // supported size 1,2,4 -+ }; -+ typedef struct Float_n_t { -+ float x[elems_per_thread]; // supported size 1,2,4 -+ } float_n_t; -+ -+ __shared__ float_n_t sq[block_sz]; -+ extern __shared__ float logits[]; // use to store the logits from [0~step] -+ -+ const int warp_id = threadIdx.x / WARP_SIZE; -+ const int warp_num = block_sz / WARP_SIZE; -+ -+ typedef cub::BlockReduce MaxValBlockReduce; -+ typedef cub::BlockReduce BlockReduce; -+ __shared__ typename MaxValBlockReduce::TempStorage max_val_block_temp_storage; -+ __shared__ typename BlockReduce::TempStorage block_temp_storage; -+ -+ __shared__ typename cub::WarpReduce::TempStorage temp_storage[warp_num]; -+ -+ const int tid = threadIdx.x; -+ const int bid = blockIdx.x / head_num; -+ const int head_id = blockIdx.x % head_num; -+ int seq_id = blockIdx.y; -+ -+ int length = seq_len; -+ const int lane_id = tid % WARP_SIZE; -+ -+ // QKV [m 3 n] shape -+ int qkv_id = bid * (3 * seq_len * head_num * size_per_head) + seq_id * (3 * head_num * size_per_head) -+ + head_id * size_per_head; -+ int q_id = -+ bid * (seq_len * head_num * size_per_head) + seq_id * (head_num * size_per_head) + head_id * size_per_head; -+ int qkv_bias_id = head_id * size_per_head; -+ int key_id = bid * (3 * seq_len * head_num * size_per_head) + head_num * size_per_head + head_id * size_per_head; -+ int value_id = -+ bid * (3 * seq_len * head_num * size_per_head) + 2 * head_num * size_per_head + head_id * size_per_head; -+ -+ int key_trn_id = bid * (seq_len * head_num * size_per_head) + head_id * (size_per_head * seq_len); -+ int value_trn_id = bid * (seq_len * head_num * size_per_head) + head_id * (size_per_head * seq_len); -+ int mask_offset = bid * (seq_len * seq_len) + seq_id * seq_len; -+ -+ // get pointers -+ const T* query_buf = qkv_buf + qkv_id; -+ const T* Q_bias = qkv_bias + qkv_bias_id; -+ T* context_buf = out_buf + q_id; -+ -+ const T* key_cache = qkv_buf + key_id; -+ const T* K_bias = qkv_bias + head_num * size_per_head + qkv_bias_id; -+ T* key_cache_out = key_cache_output + key_trn_id; -+ -+ const T* value_cache = qkv_buf + value_id; -+ const T* V_bias = qkv_bias + 2 * head_num * size_per_head + qkv_bias_id; -+ T* value_cache_out = value_cache_output + value_trn_id; -+ -+ Access_t bias_r, key_val_r, query_buf_r; -+ // offset inside head -+ int minor_offset = lane_id; // offset in copy_t elements -+ // each warp will have its own copy of sq -+ query_buf_r.v = *((copy_t*)query_buf + minor_offset); -+ -+ bias_r.v = *((copy_t*)Q_bias + minor_offset); -+ float qb_r[elems_per_thread]; -+#pragma unroll -+ for (int i = 0; i < elems_per_thread; ++i) { -+ qb_r[i] = (float)query_buf_r.x[i] + (float)bias_r.x[i]; -+ } -+ -+ // offset for each step -+ int offset = 3 * head_num * size_per_head; -+ bias_r.v = *((copy_t*)K_bias + minor_offset); -+ for (int ite = warp_id; ite < length; ite += warp_num) { -+ key_val_r.v = *((copy_t*)&key_cache[ite * offset] + minor_offset); -+ -+ if (seq_id == 0) { -+ for (int i = 0; i < elems_per_thread; i++) { -+ key_val_r.x[i] = (float)key_val_r.x[i] + (float)bias_r.x[i]; -+ key_cache_out[ite + seq_len * (minor_offset * elems_per_thread + i)] = key_val_r.x[i]; -+ } -+ } -+ else { -+ for (int i = 0; i < elems_per_thread; i++) { -+ key_val_r.x[i] = (float)key_val_r.x[i] + (float)bias_r.x[i]; -+ } -+ } -+ float val = 0; -+ for (int i = 0; i < elems_per_thread; i++) { -+ val = val + (float)key_val_r.x[i] * qb_r[i]; -+ } -+ float qk = cub::WarpReduce(temp_storage[warp_id]).Sum(val); -+ -+ if (lane_id == 0) { -+ T mask_val = attr_mask[mask_offset + ite]; -+ mask_val = (1.0f - mask_val) * -10000.0f; -+ logits[ite] = qk * scalar + mask_val; -+ } -+ } -+ -+ __syncthreads(); -+ -+ __shared__ float s_max_val, s_sum; -+ float local_i = -1e20f; -+ for (int i = tid; i < length; i += blockDim.x) { -+ local_i = max(local_i, logits[i]); -+ } -+ -+ float max_val = MaxValBlockReduce(max_val_block_temp_storage).Reduce(local_i, cub::Max()); -+ if (tid == 0) { -+ s_max_val = max_val; -+ } -+ __syncthreads(); -+ -+ float local_o = 0.0f; -+ for (int i = tid; i < length; i += blockDim.x) { -+ logits[i] = __expf(logits[i] - s_max_val); -+ local_o += logits[i]; -+ } -+ float val = BlockReduce(block_temp_storage).Sum(local_o); -+ -+ if (tid == 0) { -+ s_sum = val + 1e-6; -+ } -+ __syncthreads(); -+ -+ float s_sum_inverse = __fdividef(1.0f, s_sum); -+ for (int i = tid; i < length; i += blockDim.x) { -+ logits[i] = logits[i] * s_sum_inverse; -+ } -+ __syncthreads(); -+ -+ // This optimization introduces discrepancy because of different order in FP32 summation -+ float sum_r[elems_per_thread] = {0.f}; -+ bias_r.v = *((copy_t*)V_bias + minor_offset); -+ for (int ite = warp_id; ite < length; ite += warp_num) { -+ key_val_r.v = *((copy_t*)&value_cache[ite * offset] + minor_offset); -+#pragma unroll -+ for (int i = 0; i < elems_per_thread; i++) { -+ key_val_r.x[i] = (float)key_val_r.x[i] + (float)bias_r.x[i]; -+ } -+ if (seq_id == 0) -+ *((copy_t*)&value_cache_out[ite * size_per_head] + minor_offset) = key_val_r.v; -+#pragma unroll -+ for (int i = 0; i < elems_per_thread; ++i) { -+ sum_r[i] += (float)key_val_r.x[i] * logits[ite]; -+ } -+ } -+ for (int i = 0; i < elems_per_thread; i++) { -+ sq[warp_id * WARP_SIZE + lane_id].x[i] = sum_r[i]; -+ } -+ __syncthreads(); -+ if (threadIdx.x < WARP_SIZE) { -+#pragma unroll -+ for (int j = 1; j < warp_num; j++) { -+ for (int i = 0; i < elems_per_thread; ++i) { -+ sum_r[i] = sum_r[i] + (float)sq[j * WARP_SIZE + threadIdx.x].x[i]; -+ } -+ } -+ } -+ __syncthreads(); -+#pragma unroll -+ for (int i = 0; i < elems_per_thread; i++) { -+ key_val_r.x[i] = sum_r[i]; -+ } -+ if (threadIdx.x < WARP_SIZE) { -+ *((copy_t*)context_buf + minor_offset) = key_val_r.v; -+ } -+} -+ -+template -+void myAttnention(const T* qkv_buf, -+ const T* qkv_bias, -+ const T* attr_mask, -+ T* context_buf, -+ T* key_cache_out, -+ T* value_cache_out, -+ const int inference_batch_size, -+ const int head_num, -+ const int size_per_head, -+ const int seq_len, -+ const float q_scaling, -+ cudaStream_t stream) -+{ -+ const int block_sz = ATTENTION_BLOCK_SIZE; // blockDim.x -+ float scalar = 1.f / (sqrtf(size_per_head * 1.0f) * q_scaling); -+ -+ dim3 grid(inference_batch_size * head_num, seq_len); // gridDim.x gridDim.y -+ int cond = size_per_head * ((ATTENION_OPT) ? 1 : 0); -+ switch (cond) { -+ case 32: -+ attention_kernel_opt -+ <<>>(qkv_buf, -+ qkv_bias, -+ attr_mask, -+ context_buf, -+ key_cache_out, -+ value_cache_out, -+ inference_batch_size, -+ head_num, -+ seq_len, -+ scalar); -+ break; -+ case 64: -+ attention_kernel_opt -+ <<>>(qkv_buf, -+ qkv_bias, -+ attr_mask, -+ context_buf, -+ key_cache_out, -+ value_cache_out, -+ inference_batch_size, -+ head_num, -+ seq_len, -+ scalar); -+ break; -+ case 128: -+ attention_kernel_opt -+ <<>>(qkv_buf, -+ qkv_bias, -+ attr_mask, -+ context_buf, -+ key_cache_out, -+ value_cache_out, -+ inference_batch_size, -+ head_num, -+ seq_len, -+ scalar); -+ break; -+ default:; -+ } -+} -+ -+template void myAttnention(const float* qkv_buf, -+ const float* qkv_bias, -+ const float* attr_mask, -+ float* context_buf, -+ float* key_cache_out, -+ float* value_cache_out, -+ const int inference_batch_size, -+ const int head_num, -+ const int size_per_head, -+ const int seq_len, -+ const float q_scaling, -+ cudaStream_t stream); -+ -+template void invokeAddFusedQKVBiasTransposeUsePast(float* q_buf, -+ float* k_buf, -+ float* v_buf, -+ const float* QKV, -+ const float* qkv_bias, -+ const int seq_len, -+ const int actual_seq_len, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+ -+template void invokeAddFusedQKVBiasTransposeUsePast(half* q_buf, -+ half* k_buf, -+ half* v_buf, -+ const half* QKV, -+ const half* qkv_bias, -+ const int seq_len, -+ const int actual_seq_len, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+ -+template void invokeAddFusedQKVBiasTransposeUsePastMB(float* q_buf, -+ float* k_buf, -+ float* v_buf, -+ const float* QKV, -+ const float* qkv_bias, -+ const int* padding_offset, -+ const int* d_sequence_length, -+ const int* d_sequence_length2, -+ const int batch_size, -+ const int h_token_num, -+ const int actual_seq_len, -+ const int head_num, -+ const int size_per_head, -+ const bool incremental_mode, -+ const bool padding, -+ cudaStream_t stream); -+ -+template void invokeAddFusedQKVBiasTransposeUsePastMB(half* q_buf, -+ half* k_buf, -+ half* v_buf, -+ const half* QKV, -+ const half* qkv_bias, -+ const int* padding_offset, -+ const int* d_sequence_length, -+ const int* d_sequence_length2, -+ const int batch_size, -+ const int h_token_num, -+ const int actual_seq_len, -+ const int head_num, -+ const int size_per_head, -+ const bool incremental_mode, -+ const bool padding, -+ cudaStream_t stream); - } // namespace fastertransformer -diff --git a/src/fastertransformer/kernels/unfused_attention_kernels.h b/src/fastertransformer/kernels/unfused_attention_kernels.h -index be8b178..f418eab 100644 ---- a/src/fastertransformer/kernels/unfused_attention_kernels.h -+++ b/src/fastertransformer/kernels/unfused_attention_kernels.h -@@ -43,6 +43,23 @@ void invokeMaskedSoftMax(T* buffer, - const T scalar, - cudaStream_t stream); - -+template -+void invokeMixMaskedSoftMax(T* io_buffer, -+ const T_M* attr_mask, -+ const T* position_bias, -+ const int* d_seq_len, -+ const int* d_seq_len2, -+ const int batch_size, -+ const int seq_len, -+ const int seq_len_stride, -+ const int tgt_seq_len, -+ const int tgt_seq_stride, -+ const int head_num, -+ const int position_bias_head_num, -+ const T scalar, -+ const bool causal, -+ cudaStream_t stream); -+ - template - void invokeTransposeQKV(T* dst, - T* src, -@@ -81,12 +98,12 @@ void invokeTransposeAttentionOutRemovePadding(T* src, - const int* mask_offset, - cudaStream_t stream); - --template -+template - void invokeAddFusedQKVBiasTranspose(T* q_buf, - T* k_buf, - T* v_buf, - T* QKV, -- const T* qkv_bias, -+ const U* qkv_bias, - const int batch_size, - const int seq_len, - const int head_num, -@@ -97,12 +114,132 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf, - q_buf, k_buf, v_buf, QKV, qkv_bias, batch_size, seq_len, head_num, size_per_head, 0, stream); - } - -+ - template -+void invokeAddFusedQKVBiasTransposeUsePast(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ const T* QKV, -+ const T* qkv_bias, -+ const int seq_len, -+ const int actual_seq_len, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+template -+void invokeAddFusedQKVBiasTransposeUsePastMB(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ const T* QKV, -+ const T* qkv_bias, -+ const int* padding_offset, -+ const int* d_sequence_length, -+ const int* d_sequence_length2, -+ const int batch_size, -+ const int h_token_num, -+ const int actual_seq_len, -+ const int head_num, -+ const int size_per_head, -+ const bool incremental_mode, -+ const bool padding, -+ cudaStream_t stream); -+ -+template -+ void invokeAddFusedQKVBiasTranspose (T *q_buf, -+ T *k_buf, -+ T *v_buf, -+ const T *QKV, -+ const T *qkv_bias, -+ const int max_seq_len, -+ const int actual_seq_len, -+ const int head_num, -+ const int size_per_head, -+ const bool inc, -+ cudaStream_t stream) { -+ -+ } -+template -+void invokeAddFusedQKVBiasTransposeMBVSL(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ T* QKV, -+ const U* qkv_bias, -+ const int* padding_offset, -+ const int* d_sequence_length, -+ const int batch_size, -+ const int max_seq_len, -+ const int h_token_num, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+ -+template -+void invokeAddFusedZP_QKVBiasTranspose(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ T* QKV, -+ const U* qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int head_num, -+ const int size_per_head, -+ const int h_token, -+ int *padding_mask, -+ cudaStream_t stream); -+template -+void invokeCrossAddFusedZP_QKVBiasTranspose(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ T* QKV, -+ const U* qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int tgt_seq_len, -+ const int head_num, -+ const int size_per_head, -+ const int h_token, -+ const int h_token2, -+ int *padding_mask, -+ int *padding_mask2, -+ cudaStream_t stream); -+ -+template -+void invokeCrossAddFusedQKVBiasTranspose(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ T* QKV, -+ const U* qkv_bias, -+ const int batch_size, -+ const int seq_len, -+ const int tgt_seq_len, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+ -+template -+void invokeCrossAddFusedQKVBiasTransposeMBVSL(T* q_buf, -+ T* k_buf, -+ T* v_buf, -+ T* QKV, -+ const U* qkv_bias, -+ const int* padding_offset, -+ const int* padding_offset2, -+ const int* d_sequence_length, -+ const int* d_sequence_length2, -+ const int h_token_num, -+ const int h_token_num2, -+ const int batch_size, -+ const int seq_len, -+ const int tgt_seq_len, -+ const int head_num, -+ const int size_per_head, -+ cudaStream_t stream); -+template - void invokeAddFusedQKVBiasTranspose(T* q_buf, - T* k_buf, - T* v_buf, - T* QKV, -- const T* qkv_bias, -+ const U* qkv_bias, - const int batch_size, - const int seq_len, - const int head_num, -@@ -166,4 +303,21 @@ void invokeMaskedSoftMaxWithRelPosBias(T* qk_buf, - const float qk_scale, - cudaStream_t stream); - -+ -+template -+void myAttnention(const T* qkv_buf, -+ const T* qkv_bias, -+ const T *attr_mask, -+ T* context_buf, -+ T* key_cache_out, -+ T* value_cache_out, -+ const int inference_batch_size, -+ const int head_num, -+ const int size_per_head, -+ const int seq_len, -+ const float q_scaling, -+ cudaStream_t stream); -+ -+ -+ - } // namespace fastertransformer -diff --git a/src/fastertransformer/layers/CMakeLists.txt b/src/fastertransformer/layers/CMakeLists.txt -index cbaf4fa..49779bf 100644 ---- a/src/fastertransformer/layers/CMakeLists.txt -+++ b/src/fastertransformer/layers/CMakeLists.txt -@@ -14,6 +14,7 @@ - - cmake_minimum_required(VERSION 3.8) - -+add_subdirectory(ms_layers) - add_subdirectory(attention_layers) - add_subdirectory(attention_layers_int8) - add_subdirectory(xlnet_attention_layers) -@@ -30,15 +31,18 @@ set_property(TARGET FfnLayerINT8 PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET FfnLayerINT8 PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) - target_link_libraries(FfnLayerINT8 PUBLIC -lcublasLt -lcublas -lcudart cublasMMWrapper cublasINT8MMWrapper activation_int8_kernels memory_utils) - -+if(EXAMPLES) - add_library(TensorParallelGeluFfnLayer STATIC TensorParallelGeluFfnLayer.cc) - set_property(TARGET TensorParallelGeluFfnLayer PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET TensorParallelGeluFfnLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) - target_link_libraries(TensorParallelGeluFfnLayer PUBLIC -lcudart FfnLayer nccl_utils) - -+ - add_library(TensorParallelReluFfnLayer STATIC TensorParallelReluFfnLayer.cc) - set_property(TARGET TensorParallelReluFfnLayer PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET TensorParallelReluFfnLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) - target_link_libraries(TensorParallelReluFfnLayer PUBLIC -lcudart FfnLayer nccl_utils) -+endif() - - add_library(DynamicDecodeLayer STATIC DynamicDecodeLayer.cc) - set_property(TARGET DynamicDecodeLayer PROPERTY POSITION_INDEPENDENT_CODE ON) -diff --git a/src/fastertransformer/layers/DenseWeight.h b/src/fastertransformer/layers/DenseWeight.h -index 5a5eb6a..c95b97c 100644 ---- a/src/fastertransformer/layers/DenseWeight.h -+++ b/src/fastertransformer/layers/DenseWeight.h -@@ -28,4 +28,5 @@ struct DenseWeight { - const float* scale = nullptr; - }; - -+ - } // namespace fastertransformer -\ No newline at end of file -diff --git a/src/fastertransformer/layers/attention_layers/BaseAttentionLayer.h b/src/fastertransformer/layers/attention_layers/BaseAttentionLayer.h -index b21e3a7..746cb71 100644 ---- a/src/fastertransformer/layers/attention_layers/BaseAttentionLayer.h -+++ b/src/fastertransformer/layers/attention_layers/BaseAttentionLayer.h -@@ -62,13 +62,13 @@ AttentionType getAttentionTypeINT8( - } - } - --template -+template - class BaseAttentionLayer: public BaseLayer { - - public: - virtual void forward(std::vector* output_tensors, - const std::vector* input_tensors, -- const AttentionWeight* attention_weights) = 0; -+ const AttentionWeight* attention_weights) = 0; - BaseAttentionLayer(cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, -diff --git a/src/fastertransformer/layers/attention_layers/CMakeLists.txt b/src/fastertransformer/layers/attention_layers/CMakeLists.txt -index 9cef315..7170af4 100644 ---- a/src/fastertransformer/layers/attention_layers/CMakeLists.txt -+++ b/src/fastertransformer/layers/attention_layers/CMakeLists.txt -@@ -42,8 +42,8 @@ target_link_libraries(DecoderSelfAttentionLayer PUBLIC -lcublas -lcudart cublasM - add_library(GptContextAttentionLayer STATIC GptContextAttentionLayer.cc) - set_property(TARGET GptContextAttentionLayer PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET GptContextAttentionLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) --target_link_libraries(GptContextAttentionLayer PUBLIC -lcublas -lcudart cublasMMWrapper memory_utils unfused_attention_kernels) -- -+target_link_libraries(GptContextAttentionLayer PUBLIC -lcublas -lcudart cublasMMWrapper memory_utils unfused_attention_kernels activation_kernels) -+if(EXAMPLES) - add_library(TensorParallelDecoderSelfAttentionLayer STATIC TensorParallelDecoderSelfAttentionLayer.cc) - set_property(TARGET TensorParallelDecoderSelfAttentionLayer PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET TensorParallelDecoderSelfAttentionLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -@@ -63,6 +63,7 @@ add_library(TensorParallelUnfusedAttentionLayer STATIC TensorParallelUnfusedAtte - set_property(TARGET TensorParallelUnfusedAttentionLayer PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET TensorParallelUnfusedAttentionLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) - target_link_libraries(TensorParallelUnfusedAttentionLayer PUBLIC -lcudart UnfusedAttentionLayer nccl_utils) -+endif() - - add_library(WindowAttention STATIC WindowAttention.cc) - set_property(TARGET WindowAttention PROPERTY POSITION_INDEPENDENT_CODE ON) -diff --git a/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.cc -old mode 100644 -new mode 100755 -index bada640..4a48c48 ---- a/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.cc -+++ b/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.cc -@@ -16,10 +16,39 @@ - */ - - #include "src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.h" -+#include "src/fastertransformer/kernels/activation_kernels.h" - #include "src/fastertransformer/kernels/unfused_attention_kernels.h" - - namespace fastertransformer { - -+template -+cublasComputeType_t getCublasComputeType() -+{ -+ if (std::is_same::value) -+ return CUBLAS_COMPUTE_16F; -+ -+ else -+ return CUBLAS_COMPUTE_32F_FAST_TF32; -+} -+template -+static void printTensor(char* str, T* input, int size) -+{ -+ printf("%s ", str); -+ T* input_device = input; -+ T* input_host = (T*)malloc(size * sizeof(T)); -+ -+ fastertransformer::cudaD2Hcpy(input_host, input_device, size); -+ -+ for (int k = 0; k < (int)size; k++) { -+ std::cout << input_host[k] << ","; -+ if (k % 10 == 0) -+ std::cout << std::endl; -+ } -+ -+ std::cout << std::endl; -+ -+ free(input_host); -+} - template - void GptContextAttentionLayer::forward(std::vector* output_tensors, - const std::vector* input_tensors, -@@ -34,7 +63,6 @@ void GptContextAttentionLayer::forward(std::vector - // attention_out [batch_size * seq_len, hidden_dimension] - // key_cache [batch, local_head_num, size_per_head // x, max_seq_len, x] - // value_cache [batch, local_head_num, max_seq_len, size_per_head] -- - FT_CHECK(input_tensors->size() == 3); - FT_CHECK(output_tensors->size() == 3); - FT_CHECK(output_tensors->at(1).shape.size() == 5); -@@ -49,7 +77,7 @@ void GptContextAttentionLayer::forward(std::vector - T* attention_out = (T*)output_tensors->at(0).data; - const T* attention_input = (const T*)input_tensors->at(0).data; - const T* attention_mask = (const T*)input_tensors->at(1).data; -- const bool is_final = *((bool*)(input_tensors->at(2).data)); -+ const bool is_final = false; // *((bool*)(input_tensors->at(2).data)); - - const int m = input_tensors->at(0).shape[0]; - -@@ -134,7 +162,7 @@ void GptContextAttentionLayer::forward(std::vector - request_seq_len, - request_seq_len * request_seq_len, - request_batch_size * local_head_num_, -- CUDA_R_32F); -+ getCublasComputeType()); - sync_check_cuda_error(); - T scalar = 1 / sqrtf(size_per_head_ * 1.0f); - invokeMaskedSoftMax(qk_buf_, -diff --git a/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.h b/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.h -old mode 100644 -new mode 100755 -index 92e2175..9e90e09 ---- a/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.h -+++ b/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.h -@@ -18,7 +18,6 @@ - #pragma once - - #include "src/fastertransformer/layers/attention_layers/BaseAttentionLayer.h" -- - namespace fastertransformer { - - template -diff --git a/src/fastertransformer/layers/ms_layers/BaseLayer.h b/src/fastertransformer/layers/ms_layers/BaseLayer.h -new file mode 100644 -index 0000000..a4078d1 ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/BaseLayer.h -@@ -0,0 +1,117 @@ -+/* -+ * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved. -+ * -+ * Licensed under the Apache License, Version 2.0 (the "License"); -+ * you may not use this file except in compliance with the License. -+ * You may obtain a copy of the License at -+ * -+ * http://www.apache.org/licenses/LICENSE-2.0 -+ * -+ * Unless required by applicable law or agreed to in writing, software -+ * distributed under the License is distributed on an "AS IS" BASIS, -+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+ * See the License for the specific language governing permissions and -+ * limitations under the License. -+ */ -+ -+#pragma once -+#include -+#include "src/fastertransformer/kernels/activation_kernels.h" -+#include "src/fastertransformer/layers/ms_layers/MSBaseLayer.h" -+#include "src/fastertransformer/layers/ms_layers/BaseLayer.h" -+#include "src/fastertransformer/layers/ms_layers/debug_utils.h" -+#include -+ -+namespace fastertransformer { -+ -+class BaseLayerMS{ -+public: -+ typedef int (*allGatherFunc)(const void *input_addr, void *output_addr, size_t count, -+ nvinfer1::DataType data_type, cudaStream_t stream); -+ typedef int (*allReduceSumFunc)(const void *input_addr, void *output_addr, size_t count, -+ nvinfer1::DataType data_type, cudaStream_t stream); -+protected: -+ cublasGemmAlgo_t algo_; -+ size_t ws_offset_{0}; -+ int in_idx_; -+ size_t batch_size_; -+ size_t src_seq_len_; -+ size_t tgt_seq_len_; -+ size_t head_num_; -+ size_t head_size_; -+ size_t hidden_size_; -+ -+ int rank_num_{0}; -+ int rank_id_{0}; -+ BaseLayerMS::allGatherFunc all_gather_func_{nullptr}; -+ BaseLayerMS::allReduceSumFunc all_reduce_sum_func_{nullptr}; -+public: -+ template -+ T* GetBuf(void* ws, size_t buf) -+ { -+ return reinterpret_cast(static_cast(ws) + buf); -+ } -+ virtual void SetWSOffset(size_t ws_offset) -+ { -+ ws_offset_ = ws_offset; -+ } -+ size_t GetWSOffset() -+ { -+ return ws_offset_; -+ } -+ void SetIdx(int idx) -+ { -+ in_idx_ = idx; -+ } -+ int GetIdx() -+ { -+ return in_idx_; -+ } -+ virtual void SetAlgo(cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP) -+ { -+ algo_ = algo; -+ } -+ virtual void SetParallelFunc(BaseLayerMS::allGatherFunc all_gather_func, BaseLayerMS::allReduceSumFunc all_reduce_sum_func) -+ { -+ all_gather_func_ = all_gather_func; -+ all_reduce_sum_func_ = all_reduce_sum_func; -+ } -+ BaseLayerMS(size_t batch_size, -+ size_t src_seq_len, -+ size_t tgt_seq_len, -+ size_t head_num, -+ size_t head_size, -+ size_t hidden_size, -+ int rank_num, -+ cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP): -+ batch_size_(batch_size), -+ src_seq_len_(src_seq_len), -+ tgt_seq_len_(tgt_seq_len), -+ head_num_(head_num), -+ head_size_(head_size), -+ hidden_size_(hidden_size), -+ rank_num_(rank_num), -+ algo_(algo), -+ in_idx_(0){} -+ virtual ~BaseLayerMS() = default; -+ virtual int GetRankNum() -+ { -+ return rank_num_; -+ } -+ virtual void SetRankNum(int rank_num) -+ { -+ rank_num_ = rank_num; -+ } -+ virtual int GetRankId() -+ { -+ return rank_id_; -+ } -+ virtual void SetRankId(int rank_id) -+ { -+ rank_id_ = rank_id; -+ } -+ virtual size_t GetWorkspaceSize() {return 0;} -+ virtual void forward(std::vector &inputs, const std::vector&outputs, void *ws, cublasHandle_t cublas_handle = nullptr, cudaStream_t stream = 0) = 0; -+}; -+ -+} // namespace fastertransformer -diff --git a/src/fastertransformer/layers/ms_layers/CMakeLists.txt b/src/fastertransformer/layers/ms_layers/CMakeLists.txt -new file mode 100644 -index 0000000..bd486d1 ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/CMakeLists.txt -@@ -0,0 +1,42 @@ -+# Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. -+ -+cmake_minimum_required(VERSION 3.8) -+ -+set(CUTLASS_INSTALL_DIR "${CMAKE_SOURCE_DIR}/3rdparty/cutlass") -+ -+add_library(MSLayer STATIC -+ MSDecoderLayer.cc -+ MSEncoderLayer.cc -+ MSAttentionLayer.cc -+ decoder.cc -+ encoder.cc -+ ffn.cc -+ gemm.cc -+ attention.cc -+ layer_norm.cc -+ opt_allocator.cc -+ debug_utils.cc -+ fmha_cutlass.cu -+ MoeFfnLayer.cu) -+set_property(TARGET MSLayer PROPERTY POSITION_INDEPENDENT_CODE ON) -+set_property(TARGET MSLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -+target_include_directories(MSLayer PUBLIC ${CUTLASS_INSTALL_DIR}/include/ -+ PUBLIC ${CUTLASS_INSTALL_DIR}/tools/util/include/ -+ PUBLIC ${CUTLASS_INSTALL_DIR}/examples/41_fused_multi_head_attention/ -+ PUBLIC ${CUTLASS_INSTALL_DIR}/examples/13_two_tensor_op_fusion/ -+ PUBLIC ${CUTLASS_INSTALL_DIR}/examples/39_gemm_permute/ -+ PUBLIC ${CUTLASS_INSTALL_DIR}/examples/common) -+target_link_libraries(MSLayer PUBLIC -lcublas -lcudart -lnvinfer unfused_attention_kernels activation_kernels -+ layernorm_kernels add_residual_kernels bert_preprocess_kernels) -diff --git a/src/fastertransformer/layers/ms_layers/MSAttentionLayer.cc b/src/fastertransformer/layers/ms_layers/MSAttentionLayer.cc -new file mode 100755 -index 0000000..51ba2e9 ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/MSAttentionLayer.cc -@@ -0,0 +1,218 @@ -+/* -+ * Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. -+ * Copyright (c) 2021, NAVER Corp. Authored by CLOVA. -+ * -+ * Licensed under the Apache License, Version 2.0 (the "License"); -+ * you may not use this file except in compliance with the License. -+ * You may obtain a copy of the License at -+ * -+ * http://www.apache.org/licenses/LICENSE-2.0 -+ * -+ * Unless required by applicable law or agreed to in writing, software -+ * distributed under the License is distributed on an "AS IS" BASIS, -+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+ * See the License for the specific language governing permissions and -+ * limitations under the License. -+ */ -+ -+#include "src/fastertransformer/layers/ms_layers/MSAttentionLayer.h" -+#include "src/fastertransformer/layers/ms_layers/debug_utils.h" -+ -+namespace fastertransformer { -+ -+template -+MSMHALayer::MSMHALayer(size_t max_batch_size, -+ size_t max_src_seq_len, -+ size_t max_tgt_seq_len, -+ size_t head_num, -+ size_t size_per_head, -+ cudaStream_t stream, -+ cublasMMWrapper* cublas_wrapper, -+ cublasHandle_t cublas_handle, -+ IAllocator* allocator, -+ bool is_free_buffer_after_forward, -+ bool is_qk_buf_float, -+ bool is_cross, -+ bool sparse, -+ bool is_position_bias): -+ MSBaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, sparse) -+{ -+ common_param_.stream = stream_; -+ common_param_.cublas_handle = cublas_handle; -+ common_param_.batch_size = max_batch_size; -+ common_param_.src_seq_len = max_src_seq_len; -+ common_param_.tgt_seq_len = max_tgt_seq_len; -+ common_param_.head_num = head_num; -+ common_param_.head_size = size_per_head; -+ common_param_.hidden_size = head_num * size_per_head; -+ common_param_.in_idx = 0; -+ common_param_.algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP; -+ common_param_.use_past = false; -+ common_param_.query_layer = false; -+ common_param_.h_token_num = max_batch_size * max_src_seq_len; -+ attn_param_.common_param = &common_param_; -+ attn_param_.attn.qkv_bias = !is_position_bias; -+ attn_param_.attn.projection_bias = !is_position_bias; -+ attn_param_.attn.is_cross = is_cross; -+ attn_param_.attn.position_bias = is_position_bias; -+ attn_param_.attn.mask = true; -+ attn_param_.attn.scale = is_position_bias ? 1.0f : 1.0f / sqrtf(size_per_head * 1.0f); -+ attn_param_.attn.padding_offset = nullptr; -+ this->ms_weights = new AttentionLayerWeight(); -+ -+ attn_layer_ = std::make_shared>( -+ common_param_.batch_size, -+ common_param_.src_seq_len, -+ common_param_.tgt_seq_len, -+ common_param_.head_num, -+ common_param_.head_size, -+ common_param_.hidden_size, -+ attn_param_.attn.qkv_bias, -+ attn_param_.attn.projection_bias, -+ attn_param_.attn.is_cross, -+ attn_param_.attn.position_bias, -+ attn_param_.attn.scale, -+ attn_param_.attn.mask, -+ common_param_.use_past, -+ common_param_.algo); -+ attn_layer_->SetAlgo(common_param_.algo); -+ attn_layer_->SetHTokenNum(common_param_.h_token_num); -+ attn_layer_->SetScale(attn_param_.attn.scale); -+ attn_layer_->SetCross(is_cross); -+ attn_layer_->SetOption(attn_param_.attn.qkv_bias, attn_param_.attn.projection_bias, attn_param_.attn.position_bias, attn_param_.attn.mask); -+} -+template -+void MSMHALayer::allocateBuffer() -+{ -+ if (buf_ == nullptr) { -+ size_t buff_size_allocator = attn_layer_->GetWorkspaceSize(); -+ buf_ = reinterpret_cast(allocator_->reMalloc(buf_, buff_size_allocator, true)); -+ attn_layer_->SetWSOffset(0); -+ } -+} -+template -+int MSMHALayer::forward(std::vector* output_tensors, -+ const std::vector* input_tensors, -+ const MSLayerWeight* weights) -+{ -+ const AttentionLayerWeight* attention_weights = dynamic_cast*>(this->ms_weights); -+ if (attention_weights == NULL) { -+ FT_LOG_ERROR("cast AttentionLayerWeight not sucsses\n"); -+ return -1; -+ } -+ allocateBuffer(); // only once -+ if (attn_param_.attn.position_bias) { -+ if (attn_param_.attn.is_cross) { -+ std::vector outputs = {(void*)output_tensors->at(0).data}; -+ std::vector inputs = {(void*)input_tensors->at(0).data, -+ (void*)input_tensors->at(1).data, -+ (void*)attention_weights->query_weight.kernel, -+ (void*)attention_weights->key_weight.kernel, -+ (void*)input_tensors->at(2).data, -+ (void*)input_tensors->at(3).data, -+ (void*)attention_weights->attention_output_weight.kernel}; -+ attn_layer_->forward(inputs, outputs, buf_, common_param_.cublas_handle, common_param_.stream); -+ } -+ else { -+ std::vector outputs = {(void*)output_tensors->at(0).data}; -+ std::vectorinputs = {(void*)input_tensors->at(0).data, -+ (void*)attention_weights->query_weight.kernel, -+ (void*)input_tensors->at(1).data, -+ (void*)input_tensors->at(2).data, -+ (void*)attention_weights->attention_output_weight.kernel}; -+ attn_layer_->forward(inputs, outputs, buf_, common_param_.cublas_handle, common_param_.stream); -+ } -+ } -+ else { -+ if (attn_param_.attn.is_cross) { -+ std::vector outputs= {(void*)output_tensors->at(0).data}; -+ std::vector inputs = {(void*)input_tensors->at(0).data, -+ (void*)input_tensors->at(1).data, -+ (void*)attention_weights->query_weight.kernel, -+ (void*)attention_weights->key_weight.kernel, -+ (void*)attention_weights->query_weight.bias, -+ (void*)input_tensors->at(2).data, -+ (void*)attention_weights->attention_output_weight.kernel, -+ (void*)attention_weights->attention_output_weight.bias}; -+ attn_layer_->forward(inputs, outputs,buf_); -+ } -+ else { -+ std::vector outputs = {(void*)output_tensors->at(0).data}; -+ std::vector inputs = {(void*)input_tensors->at(0).data, -+ (void*)attention_weights->query_weight.kernel, -+ (void*)attention_weights->query_weight.bias, -+ (void*)input_tensors->at(1).data, -+ (void*)attention_weights->attention_output_weight.kernel, -+ (void*)attention_weights->attention_output_weight.bias}; -+ attn_layer_->forward(inputs, outputs, buf_, common_param_.cublas_handle, common_param_.stream); -+ } -+ } -+ return 0; -+} -+template -+MSMHALayer::~MSMHALayer() -+{ -+ cublas_wrapper_ = nullptr; -+ freeBuffer(); -+} -+ -+template -+void MSMHALayer::freeBuffer() -+{ -+ if (buf_ != nullptr) { -+ allocator_->free(buf_); -+ buf_ = nullptr; -+ } -+} -+template -+int MSMHALayer::InitWeight(opt_arg* opt_a, MSLayerWeight* ms_weights, std::vector w_tensors) -+{ -+ AttentionLayerWeight* attn_weights = dynamic_cast*>(this->ms_weights); -+ if (attn_weights == NULL) { -+ FT_LOG_ERROR("cast AttentionLayerWeight not sucsses\n"); -+ return -1; -+ } -+ int modelId = ModelNum(opt_a->model_name); -+ if (modelId == MHA_X1) { -+ attn_weights->query_weight.kernel = reinterpret_cast(w_tensors[0].data); -+ attn_weights->attention_output_weight.kernel = reinterpret_cast(w_tensors[1].data); -+ attn_weights->query_weight.bias = reinterpret_cast(w_tensors[2].data); -+ attn_weights->attention_output_weight.bias = reinterpret_cast(w_tensors[3].data); -+ } -+ else if (modelId == MHA_X2 || modelId == MHA_CROSS) { -+ attn_weights->query_weight.kernel = reinterpret_cast(w_tensors[0].data); -+ attn_weights->query_weight.bias = reinterpret_cast(w_tensors[1].data); -+ attn_weights->key_weight.kernel = reinterpret_cast(w_tensors[2].data); -+ attn_weights->attention_output_weight.kernel = reinterpret_cast(w_tensors[3].data); -+ attn_weights->attention_output_weight.bias = reinterpret_cast(w_tensors[4].data); -+ } -+ else if (modelId == MHA_T5) { -+ attn_weights->query_weight.kernel = reinterpret_cast(w_tensors[0].data); -+ attn_weights->query_weight.bias = nullptr; -+ attn_weights->attention_output_weight.kernel = reinterpret_cast(w_tensors[1].data); -+ attn_weights->attention_output_weight.bias = nullptr; -+ } -+ else if (modelId == MHA_T5_CROSS) { -+ attn_weights->query_weight.kernel = reinterpret_cast(w_tensors[0].data); -+ attn_weights->query_weight.bias = nullptr; -+ attn_weights->key_weight.kernel = reinterpret_cast(w_tensors[1].data); -+ attn_weights->attention_output_weight.kernel = reinterpret_cast(w_tensors[2].data); -+ attn_weights->attention_output_weight.bias = nullptr; -+ } -+ else { -+ FT_LOG_ERROR("illegal model !\n"); -+ return -1; -+ } -+ return 0; -+} -+ -+template class MSMHALayer; -+template class MSMHALayer; -+template class MSMHALayer; -+template class MSMHALayer; -+template class MSMHALayer; -+template class MSMHALayer; -+template class MSMHALayer; -+template class MSMHALayer; -+ -+} // namespace fastertransformer -diff --git a/src/fastertransformer/layers/ms_layers/MSAttentionLayer.h b/src/fastertransformer/layers/ms_layers/MSAttentionLayer.h -new file mode 100755 -index 0000000..b8f1ce0 ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/MSAttentionLayer.h -@@ -0,0 +1,68 @@ -+/* -+ * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -+ * Copyright (c) 2021, NAVER Corp. Authored by CLOVA. -+ * -+ * Licensed under the Apache License, Version 2.0 (the "License"); -+ * you may not use this file except in compliance with the License. -+ * You may obtain a copy of the License at -+ * -+ * http://www.apache.org/licenses/LICENSE-2.0 -+ * -+ * Unless required by applicable law or agreed to in writing, software -+ * distributed under the License is distributed on an "AS IS" BASIS, -+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+ * See the License for the specific language governing permissions and -+ * limitations under the License. -+ */ -+ -+#pragma once -+#include "src/fastertransformer/layers/ms_layers/MSBaseLayer.h" -+#include "src/fastertransformer/layers/ms_layers/attention.h" -+#include "src/fastertransformer/layers/ms_layers/ffn.h" -+#include "src/fastertransformer/layers/ms_layers/param.h" -+namespace fastertransformer { -+ -+// TODO(haim): Add template according to "mix" compute type (fp32, fp16) -+template -+class MSMHALayer: public MSBaseLayer{ -+private: -+ void allocateBuffer() override; -+ void freeBuffer() override; -+ -+ attentionParamRun attn_param_; -+ CommonParam common_param_; -+ std::shared_ptr> attn_layer_; -+ using MSBaseLayer::is_free_buffer_after_forward_; -+ using MSBaseLayer::is_allocate_buffer_; -+ using MSBaseLayer::cublas_wrapper_; -+ using MSBaseLayer::allocator_; -+ -+protected: -+ using MSBaseLayer::stream_; -+ using MSBaseLayer::sparse_; -+ T* buf_ = nullptr; -+ -+public: -+ MSMHALayer(size_t batch_size, -+ size_t src_seq_len, -+ size_t tgt_seq_len, -+ size_t head_num, -+ size_t size_per_head, -+ cudaStream_t stream, -+ cublasMMWrapper* cublas_wrapper, -+ cublasHandle_t cublas_handle, -+ IAllocator* allocator, -+ bool is_free_buffer_after_forward, -+ bool is_qk_buf_float, -+ bool is_cross, -+ bool sparse = false, -+ bool is_position_bias=false); -+ MSMHALayer(MSMHALayer const& attention_layer); -+ virtual ~MSMHALayer(); -+ int forward(std::vector* output_tensors, -+ const std::vector* input_tensors, -+ const MSLayerWeight* weights) override; -+ int InitWeight(opt_arg* opt_a, MSLayerWeight* ms_weights, std::vector w_tensors) override; -+}; -+ -+} // namespace fastertransformer -diff --git a/src/fastertransformer/layers/ms_layers/MSBaseLayer.h b/src/fastertransformer/layers/ms_layers/MSBaseLayer.h -new file mode 100644 -index 0000000..b1f4d56 ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/MSBaseLayer.h -@@ -0,0 +1,147 @@ -+/* -+ * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved. -+ * -+ * Licensed under the Apache License, Version 2.0 (the "License"); -+ * you may not use this file except in compliance with the License. -+ * You may obtain a copy of the License at -+ * -+ * http://www.apache.org/licenses/LICENSE-2.0 -+ * -+ * Unless required by applicable law or agreed to in writing, software -+ * distributed under the License is distributed on an "AS IS" BASIS, -+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+ * See the License for the specific language governing permissions and -+ * limitations under the License. -+ */ -+ -+#pragma once -+ -+#include -+#include -+ -+#include "3rdparty/trt_fused_multihead_attention/fused_multihead_attention_common.h" -+#include "src/fastertransformer/layers/BaseLayer.h" -+#include "src/fastertransformer/layers/ms_layers/MSLayerWeight.h" -+#include "src/fastertransformer/utils/Tensor.h" -+#include "src/fastertransformer/utils/allocator.h" -+#include "src/fastertransformer/utils/cublasMMWrapper.h" -+#include "src/fastertransformer/utils/memory_utils.h" -+namespace fastertransformer { -+ -+enum class MSLayerType { -+ UNFUSED_MS_LAYER, -+ FUSED_MS_LAYER -+}; -+struct opt_arg { -+ size_t batch_size; -+ size_t num_layers; -+ size_t seq_len; // source seq len -+ size_t tgt_seq_len; -+ size_t head_num; -+ size_t hidden_size; -+ size_t size_per_head; -+ float eps1; -+ float eps2; -+ float eps3; -+ bool position_bias1; -+ bool position_bias2; -+ bool post_layernorm_residual; -+ bool is_ffn_fp16; -+ bool is_remove_padding; -+ std::string model_name; -+ std::string compute_type; -+ std::string w_compute_type; -+ std::string s_compute_type; -+ size_t ffn_hidden_size; -+ size_t expand_ratio; -+}; -+typedef enum { -+ MHA_X1 = 1, // AttnIn + AttnMask -+ MHA_X2, // AttnIn + EncOut -- same seq size + AttnMask -+ MHA_CROSS, // AttnIn + EncOut + AttnMAsk -+ MHA_T5, // AttnIn + EncOut + AttnMAsk + position_bias -+ MHA_T5_CROSS, // AttnIn + EncOut + AttnMAsk + position_bias -+ TEL, // transformer encoder layer -+ TEL_T5, // transformer encoder layer -+ TDL, -+ TDL_T5, -+} MODEL_TEST_ID_E; -+template -+MSLayerType getMSLayerType( -+ size_t size_per_head, const int sm, const bool remove_padding, const int max_seq_len, const bool is_fuse = true) -+{ -+ if (std::is_same::value && (sm == kSM_70 || sm == kSM_86 || sm == kSM_80 || sm == kSM_75 || sm == kSM_72) -+ && size_per_head == 64 && max_seq_len <= 384 && is_fuse == true) { -+ return remove_padding ? MSLayerType::FUSED_MS_LAYER : MSLayerType::FUSED_MS_LAYER; -+ } -+ else { -+ return remove_padding ? MSLayerType::FUSED_MS_LAYER : MSLayerType::FUSED_MS_LAYER; -+ } -+} -+ -+template -+MSLayerType getMSLayerTypeINT8( -+ size_t size_per_head, const int sm, const bool remove_padding, const int max_seq_len, const int int8_mode) -+{ -+ if ((int8_mode == 1 || int8_mode == 2) && (sm == kSM_86 || sm == kSM_80 || sm == kSM_75) && size_per_head == 64 -+ && max_seq_len <= 384) { -+ return remove_padding ? MSLayerType::FUSED_MS_LAYER : MSLayerType::FUSED_MS_LAYER; -+ } -+ else { -+ return remove_padding ? MSLayerType::FUSED_MS_LAYER : MSLayerType::FUSED_MS_LAYER; -+ } -+} -+ -+template -+class MSBaseLayer: public BaseLayer { -+protected: -+public: -+ MSLayerWeight* ms_weights; -+ virtual int forward(std::vector* output_tensors, -+ const std::vector* input_tensors, -+ const MSLayerWeight* layer_weights) = 0; -+ MSBaseLayer(cudaStream_t stream, -+ cublasMMWrapper* cublas_wrapper, -+ IAllocator* allocator, -+ bool is_free_buffer_after_forward, -+ bool sparse = false): -+ BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, nullptr, sparse) -+ { -+ } -+ virtual int InitWeight(opt_arg* opt_a, MSLayerWeight* weights, std::vector w_tensors) = 0; -+ virtual ~MSBaseLayer() = default; -+}; -+static int ModelNum(std::string model_name) -+{ -+ if (model_name == "mha_x1") { -+ return MHA_X1; -+ } -+ else if (model_name == "mha_x2") { -+ return MHA_X2; -+ } -+ else if (model_name == "mha_cross") { -+ return MHA_CROSS; -+ } -+ else if (model_name == "mha_T5") { -+ return MHA_T5; -+ } -+ else if (model_name == "mha_T5_cross") { -+ return MHA_T5_CROSS; -+ } -+ else if (model_name == "transformer_encoder_layer") { -+ return TEL; -+ } -+ else if (model_name == "transformer_encoder_layer_t5") { -+ return TEL_T5; -+ } -+ else if (model_name == "transformer_decoder_layer") { -+ return TDL; -+ } -+ else if (model_name == "transformer_decoder_layer_t5") { -+ return TDL_T5; -+ } -+ else { -+ return -1; -+ } -+} -+} // namespace fastertransformer -diff --git a/src/fastertransformer/layers/ms_layers/MSDecoderLayer.cc b/src/fastertransformer/layers/ms_layers/MSDecoderLayer.cc -new file mode 100644 -index 0000000..7250635 ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/MSDecoderLayer.cc -@@ -0,0 +1,248 @@ -+/* -+ * Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. -+ * Copyright (c) 2021, NAVER Corp. Authored by CLOVA. -+ * -+ * Licensed under the Apache License, Version 2.0 (the "License"); -+ * you may not use this file except in compliance with the License. -+ * You may obtain a copy of the License at -+ * -+ * http://www.apache.org/licenses/LICENSE-2.0 -+ * -+ * Unless required by applicable law or agreed to in writing, software -+ * distributed under the License is distributed on an "AS IS" BASIS, -+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+ * See the License for the specific language governing permissions and -+ * limitations under the License. -+ */ -+ -+#include "src/fastertransformer/layers/ms_layers/MSDecoderLayer.h" -+#include "src/fastertransformer/layers/ms_layers/debug_utils.h" -+ -+namespace fastertransformer { -+ -+template -+MSDLayer::MSDLayer(size_t max_batch_size, -+ size_t max_src_seq_len, -+ size_t max_tgt_seq_len, -+ size_t head_num, -+ size_t size_per_head, -+ size_t ffn_hidden_size, -+ float eps1, -+ float eps2, -+ float eps3, -+ bool post_layernorm, -+ bool position_bias1, -+ bool position_bias2, -+ bool is_ffn_fp16, -+ cudaStream_t stream, -+ cublasMMWrapper* cublas_wrapper, -+ cublasHandle_t cublas_handle, -+ IAllocator* allocator, -+ bool is_free_buffer_after_forward, -+ bool is_qk_buf_float, -+ bool sparse): -+ -+ MSBaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, sparse), buf_(nullptr) -+{ //update commonparam -+ decoder_param_.common_param.stream = stream_; -+ decoder_param_.common_param.cublas_handle = cublas_handle; -+ decoder_param_.common_param.batch_size = max_batch_size; -+ decoder_param_.common_param.src_seq_len = max_src_seq_len; -+ decoder_param_.common_param.tgt_seq_len = max_tgt_seq_len; -+ decoder_param_.common_param.head_num = head_num; -+ decoder_param_.common_param.head_size = size_per_head; -+ decoder_param_.common_param.hidden_size = head_num * size_per_head; -+ decoder_param_.common_param.in_idx = 0; -+ decoder_param_.common_param.algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP; -+ decoder_param_.common_param.h_token_num = max_src_seq_len * max_batch_size; -+ decoder_param_.common_param.h_token_num2 = max_tgt_seq_len * max_batch_size; -+ //connect commonparam to attention and ffn -+ decoder_param_.attn1.common_param = &decoder_param_.common_param; -+ decoder_param_.attn2.common_param = &decoder_param_.common_param; -+ decoder_param_.ffn_param.common_param = &decoder_param_.common_param; -+ -+ decoder_param_.ffn_param.ffn_param.ffn_hidden_size = ffn_hidden_size; -+ decoder_param_.ffn_param.ffn_param.ffn_bias = !position_bias1; -+ decoder_param_.ffn_param.ffn_param.ffn_fp16 = is_ffn_fp16; -+ decoder_param_.ffn_param.ffn_param.act_type = !position_bias1 ? ActType::ActType_Gelu : ActType::ActType_Relu; // true; -+ decoder_param_.decoder.eps1 = eps1; -+ decoder_param_.decoder.eps2 = eps2; -+ decoder_param_.decoder.eps3 = eps3; -+ decoder_param_.decoder.layernorm_post = post_layernorm; -+ decoder_param_.decoder.has_beta = !position_bias1; -+ decoder_param_.attn1.attn.qkv_bias = !position_bias1; -+ decoder_param_.attn1.attn.projection_bias = !position_bias1; -+ decoder_param_.attn1.attn.is_cross = false; -+ decoder_param_.attn1.attn.position_bias = position_bias1; -+ decoder_param_.attn1.attn.scale = position_bias1 ? 1.0f : 1.0f / sqrtf(size_per_head * 1.0f); -+ decoder_param_.attn1.attn.mask = true; -+ decoder_param_.attn2.attn.position_bias = position_bias2; -+ decoder_param_.attn2.attn.qkv_bias = !position_bias2; -+ decoder_param_.attn2.attn.projection_bias = !position_bias2; -+ decoder_param_.attn2.attn.is_cross = true; -+ decoder_param_.attn2.attn.scale = position_bias2 ? 1.0f : 1.0f / sqrtf(size_per_head * 1.0f); -+ decoder_param_.attn2.attn.mask = true; -+ this->ms_weights = new DecoderLayerWeight(); -+ decoder_layer_ = std::make_shared>( -+ decoder_param_.common_param.batch_size, -+ decoder_param_.common_param.src_seq_len, -+ decoder_param_.common_param.tgt_seq_len, -+ decoder_param_.common_param.head_num, -+ decoder_param_.common_param.head_size, -+ decoder_param_.common_param.hidden_size -+ ); -+ decoder_layer_->SetEps(decoder_param_.decoder.eps1, decoder_param_.decoder.eps2, decoder_param_.decoder.eps3, decoder_param_.decoder.eps3); -+ decoder_layer_->SetIsCross(decoder_param_.attn2.attn.is_cross); -+ decoder_layer_->SetScaleAttn(decoder_param_.attn1.attn.scale); -+ decoder_layer_->SetIsLayerNorm(false, 1e-6f); -+ decoder_layer_->SetFfnParam(decoder_param_.ffn_param.ffn_param.ffn_fp16, decoder_param_.ffn_param.ffn_param.ffn_hidden_size, (FfnBase::ActType)decoder_param_.ffn_param.ffn_param.act_type, decoder_param_.ffn_param.ffn_param.ffn_bias); -+ decoder_layer_->SetT5(position_bias1); -+ decoder_layer_->SetHTokenNum(decoder_param_.common_param.h_token_num, decoder_param_.common_param.h_token_num2); -+ decoder_layer_->SetAlgo(CUBLAS_GEMM_DEFAULT_TENSOR_OP); -+} -+ -+template -+void MSDLayer::allocateBuffer() -+{ -+ if (buf_ == nullptr) { -+ size_t buff_size_allocator = decoder_layer_->GetWorkspaceSize(); -+ buf_ = reinterpret_cast(allocator_->reMalloc(buf_, buff_size_allocator, true)); -+ } -+} -+ -+template -+void MSDLayer::freeBuffer() -+{ -+ if (buf_ != nullptr) { -+ allocator_->free(buf_); -+ buf_ = nullptr; -+ } -+} -+ -+template -+MSDLayer::~MSDLayer() -+{ -+ cublas_wrapper_ = nullptr; -+ freeBuffer(); -+} -+ -+template -+int MSDLayer::forward(std::vector* output_tensors, -+ const std::vector* input_tensors, -+ const MSLayerWeight* weights) -+{ -+ const DecoderLayerWeight* decoder_weights = dynamic_cast*>(this->ms_weights); -+ if (decoder_weights == NULL) { -+ FT_LOG_ERROR("cast DecoderLayerWeight not sucsses\n"); -+ return -1; -+ } -+ allocateBuffer(); // only once -+ std::vector outputs= {(void*)output_tensors->at(0).data}; -+ if (decoder_param_.attn1.attn.qkv_bias && decoder_param_.attn2.attn.qkv_bias && !decoder_param_.attn1.attn.position_bias -+ && !decoder_param_.attn2.attn.position_bias) { -+ std::vector inputs = {(void*)input_tensors->at(0).data, -+ (void*)decoder_weights->layernorm1.gamma, -+ (void*)decoder_weights->layernorm1.beta, -+ (void*)decoder_weights->attention.query_weight.kernel, -+ (void*)decoder_weights->attention.query_weight.bias, -+ (void*)input_tensors->at(1).data, -+ (void*)decoder_weights->attention.attention_output_weight.kernel, -+ (void*)decoder_weights->attention.attention_output_weight.bias, -+ (void*)decoder_weights->layernorm2.gamma, -+ (void*)decoder_weights->layernorm2.beta, -+ (void*)input_tensors->at(2).data, -+ (void*)decoder_weights->cross_attention.query_weight.kernel, -+ (void*)decoder_weights->cross_attention.key_weight.kernel, -+ (void*)decoder_weights->cross_attention.query_weight.bias, -+ (void*)input_tensors->at(3).data, -+ (void*)decoder_weights->cross_attention.attention_output_weight.kernel, -+ (void*)decoder_weights->cross_attention.attention_output_weight.bias, -+ (void*)decoder_weights->layernorm3.gamma, -+ (void*)decoder_weights->layernorm3.beta, -+ (void*)decoder_weights->decoder_output_mapping.kernel, -+ (void*)decoder_weights->decoder_output_mapping.bias, -+ (void*)decoder_weights->decoder_output_projection.kernel, -+ (void*)decoder_weights->decoder_output_projection.bias}; -+ decoder_layer_->forward(inputs, outputs, buf_, decoder_param_.common_param.cublas_handle, decoder_param_.common_param.stream); -+ } -+ else if (decoder_param_.attn1.attn.position_bias && decoder_param_.attn2.attn.position_bias) { -+ std::vector inputs = {(void*)input_tensors->at(0).data, -+ (void*)decoder_weights->layernorm1.gamma, -+ (void*)decoder_weights->attention.query_weight.kernel, -+ (void*)input_tensors->at(1).data, -+ (void*)input_tensors->at(4).data, -+ (void*)decoder_weights->attention.attention_output_weight.kernel, -+ (void*)decoder_weights->layernorm2.gamma, -+ (void*)input_tensors->at(2).data, -+ (void*)decoder_weights->cross_attention.query_weight.kernel, -+ (void*)decoder_weights->cross_attention.key_weight.kernel, -+ (void*)input_tensors->at(3).data, -+ (void*)input_tensors->at(5).data, -+ (void*)decoder_weights->cross_attention.attention_output_weight.kernel, -+ (void*)decoder_weights->layernorm3.gamma, -+ (void*)decoder_weights->decoder_output_mapping.kernel, -+ (void*)decoder_weights->decoder_output_projection.kernel}; -+ decoder_layer_->forward(inputs, outputs, buf_, decoder_param_.common_param.cublas_handle, decoder_param_.common_param.stream); -+ } -+ return 0; -+} -+ -+template -+int MSDLayer::InitWeight(opt_arg* opt_a, MSLayerWeight* ms_weights, std::vector w_tensors) -+{ -+ DecoderLayerWeight* decoder_weights = dynamic_cast*>(this->ms_weights); -+ if (decoder_weights == NULL) { -+ FT_LOG_ERROR("cast DecoderLayerWeight not sucsses\n"); -+ return -1; -+ } -+ int modelId = ModelNum(opt_a->model_name); -+ if (modelId == TDL) { -+ decoder_weights->layernorm1.gamma = (const U*)w_tensors[0].data; -+ decoder_weights->layernorm1.beta = (const U*)w_tensors[1].data; -+ decoder_weights->attention.query_weight.kernel = (const U*)w_tensors[2].data; -+ decoder_weights->attention.query_weight.bias = (const U*)w_tensors[3].data; -+ decoder_weights->attention.attention_output_weight.kernel = (const U*)w_tensors[4].data; -+ decoder_weights->attention.attention_output_weight.bias = (const U*)w_tensors[5].data; -+ decoder_weights->layernorm2.gamma = (const U*)w_tensors[6].data; -+ decoder_weights->layernorm2.beta = (const U*)w_tensors[7].data; -+ decoder_weights->cross_attention.query_weight.kernel = (const U*)w_tensors[8].data; -+ decoder_weights->cross_attention.key_weight.kernel = (const U*)w_tensors[9].data; -+ decoder_weights->cross_attention.query_weight.bias = (const U*)w_tensors[10].data; -+ decoder_weights->cross_attention.key_weight.bias = (const U*)w_tensors[10].data; -+ decoder_weights->cross_attention.attention_output_weight.kernel = (const U*)w_tensors[11].data; -+ decoder_weights->cross_attention.attention_output_weight.bias = (const U*)w_tensors[12].data; -+ decoder_weights->layernorm3.gamma = (const U*)w_tensors[13].data; -+ decoder_weights->layernorm3.beta = (const U*)w_tensors[14].data; -+ decoder_weights->decoder_output_mapping.kernel = (const U*)w_tensors[15].data; -+ decoder_weights->decoder_output_mapping.bias = (const U*)w_tensors[16].data; -+ decoder_weights->decoder_output_projection.kernel = (const U*)w_tensors[17].data; -+ decoder_weights->decoder_output_projection.bias = (const U*)w_tensors[18].data; -+ } -+ else if (modelId == TDL_T5) { -+ decoder_weights->layernorm1.gamma = (const U*)w_tensors[0].data; -+ decoder_weights->attention.query_weight.kernel = (const U*)w_tensors[1].data; -+ decoder_weights->attention.attention_output_weight.kernel = (const U*)w_tensors[2].data; -+ decoder_weights->layernorm2.gamma = (const U*)w_tensors[3].data; -+ decoder_weights->cross_attention.query_weight.kernel = (const U*)w_tensors[4].data; -+ decoder_weights->cross_attention.key_weight.kernel = (const U*)w_tensors[5].data; -+ decoder_weights->cross_attention.attention_output_weight.kernel = (const U*)w_tensors[6].data; -+ decoder_weights->layernorm3.gamma = (const U*)w_tensors[7].data; -+ decoder_weights->decoder_output_mapping.kernel = (const U*)w_tensors[8].data; -+ decoder_weights->decoder_output_projection.kernel = (const U*)w_tensors[9].data; -+ } -+ else { -+ FT_LOG_ERROR("illegal model !\n"); -+ return -1; -+ } -+ return 0; -+} -+template class MSDLayer; -+template class MSDLayer; -+template class MSDLayer; -+template class MSDLayer; -+template class MSDLayer; -+template class MSDLayer; -+template class MSDLayer; -+template class MSDLayer; -+ -+} // namespace fastertransformer -diff --git a/src/fastertransformer/layers/ms_layers/MSDecoderLayer.h b/src/fastertransformer/layers/ms_layers/MSDecoderLayer.h -new file mode 100644 -index 0000000..b8f870c ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/MSDecoderLayer.h -@@ -0,0 +1,74 @@ -+/* -+ * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -+ * Copyright (c) 2021, NAVER Corp. Authored by CLOVA. -+ * -+ * Licensed under the Apache License, Version 2.0 (the "License"); -+ * you may not use this file except in compliance with the License. -+ * You may obtain a copy of the License at -+ * -+ * http://www.apache.org/licenses/LICENSE-2.0 -+ * -+ * Unless required by applicable law or agreed to in writing, software -+ * distributed under the License is distributed on an "AS IS" BASIS, -+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+ * See the License for the specific language governing permissions and -+ * limitations under the License. -+ */ -+ -+#pragma once -+ -+#include "src/fastertransformer/layers/ms_layers/MSBaseLayer.h" -+#include "src/fastertransformer/layers/ms_layers/decoder.h" -+#include "src/fastertransformer/layers/ms_layers/param.h" -+ -+namespace fastertransformer { -+ -+// TODO(haim): Add template according to "mix" compute type (fp32, fp16) -+template -+class MSDLayer: public MSBaseLayer { -+private: -+ mutable decoderParamRun decoder_param_; -+ void allocateBuffer() override; -+ void freeBuffer() override; -+ void* buf_; -+ using MSBaseLayer::is_free_buffer_after_forward_; -+ using MSBaseLayer::is_allocate_buffer_; -+ using MSBaseLayer::cublas_wrapper_; -+ using MSBaseLayer::allocator_; -+ std::shared_ptr> decoder_layer_; -+protected: -+ using MSBaseLayer::stream_; -+ using MSBaseLayer::sparse_; -+ -+public: -+ MSDLayer(size_t max_batch_size, -+ size_t max_src_seq_len, -+ size_t max_tgt_seq_len, -+ size_t head_num, -+ size_t size_per_head, -+ size_t ffn_hidden_size, -+ float eps1, -+ float eps2, -+ float eps3, -+ bool post_layernorm, -+ bool position_bias1, -+ bool position_bias2, -+ bool is_ffn_fp16, -+ cudaStream_t stream, -+ cublasMMWrapper* cublas_wrapper, -+ cublasHandle_t cublas_handle, -+ IAllocator* allocator, -+ bool is_free_buffer_after_forward, -+ bool is_qk_buf_float, -+ bool sparse); -+ -+ MSDLayer(MSDLayer const& decoder_layer); -+ -+ virtual ~MSDLayer(); -+ -+ int forward(std::vector* output_tensors, -+ const std::vector* input_tensors, -+ const MSLayerWeight* weights) override; -+ int InitWeight(opt_arg* opt_a, MSLayerWeight* ms_weights, std::vector w_tensors) override; -+}; -+} // namespace fastertransformer -diff --git a/src/fastertransformer/layers/ms_layers/MSEncoderLayer.cc b/src/fastertransformer/layers/ms_layers/MSEncoderLayer.cc -new file mode 100644 -index 0000000..9b18049 ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/MSEncoderLayer.cc -@@ -0,0 +1,250 @@ -+/* -+ * Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. -+ * Copyright (c) 2021, NAVER Corp. Authored by CLOVA. -+ * -+ * Licensed under the Apache License, Version 2.0 (the "License"); -+ * you may not use this file except in compliance with the License. -+ * You may obtain a copy of the License at -+ * -+ * http://www.apache.org/licenses/LICENSE-2.0 -+ * -+ * Unless required by applicable law or agreed to in writing, software -+ * distributed under the License is distributed on an "AS IS" BASIS, -+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+ * See the License for the specific language governing permissions and -+ * limitations under the License. -+ */ -+ -+#include "src/fastertransformer/layers/ms_layers/MSEncoderLayer.h" -+#include "src/fastertransformer/layers/ms_layers/debug_utils.h" -+#include "src/fastertransformer/layers/ms_layers/BaseLayer.h" -+ -+namespace fastertransformer { -+ -+template -+MSELayer::MSELayer(size_t max_batch_size, -+ size_t max_src_seq_len, -+ size_t max_tgt_seq_len, -+ size_t head_num, -+ size_t size_per_head, -+ size_t ffn_hidden_size, -+ float eps1, -+ float eps2, -+ bool post_layernorm, -+ bool position_bias, -+ bool is_ffn_fp16, -+ cudaStream_t stream, -+ cublasMMWrapper* cublas_wrapper, -+ cublasHandle_t cublas_handle, -+ IAllocator* allocator, -+ bool is_free_buffer_after_forward, -+ bool is_qk_buf_float, -+ bool sparse): -+ -+ MSBaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, sparse), buf_(nullptr) -+{ -+ // update commonparam -+ encoder_param_.common_param.stream = stream_; -+ encoder_param_.common_param.cublas_handle = cublas_handle; -+ encoder_param_.common_param.batch_size = max_batch_size; -+ encoder_param_.common_param.src_seq_len = max_src_seq_len; -+ encoder_param_.common_param.tgt_seq_len = max_tgt_seq_len; -+ encoder_param_.common_param.head_num = head_num; -+ encoder_param_.common_param.head_size = size_per_head; -+ encoder_param_.common_param.hidden_size = head_num * size_per_head; -+ encoder_param_.common_param.algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP; -+ encoder_param_.common_param.in_idx = 0; -+ // connect commonparam to attention and ffn -+ encoder_param_.attn.common_param = &encoder_param_.common_param; -+ encoder_param_.ffn_param.common_param = &encoder_param_.common_param; -+ encoder_param_.common_param.h_token_num = max_src_seq_len * max_batch_size; -+ // update encoder_param_ -+ encoder_param_.encoder.layernorm_post = post_layernorm; -+ encoder_param_.encoder.eps1 = eps1; -+ encoder_param_.encoder.eps2 = eps2; -+ encoder_param_.ffn_param.ffn_param.ffn_hidden_size = ffn_hidden_size; -+ encoder_param_.ffn_param.ffn_param.ffn_fp16 = is_ffn_fp16; -+ encoder_param_.attn.attn.projection_bias = !position_bias; -+ encoder_param_.attn.attn.is_cross = false; -+ encoder_param_.attn.attn.position_bias = position_bias; -+ encoder_param_.attn.attn.qkv_bias = !position_bias; -+ encoder_param_.encoder.has_beta = !position_bias; -+ encoder_param_.ffn_param.ffn_param.ffn_bias = !position_bias; -+ encoder_param_.ffn_param.ffn_param.act_type = -+ !position_bias ? ActType::ActType_Gelu : ActType::ActType_Relu; // true; -+ encoder_param_.attn.attn.scale = position_bias ? 1.0f : 1.0f / sqrtf(size_per_head * 1.0f); -+ encoder_param_.attn.attn.mask = true; -+ this->ms_weights = new EncoderLayerWeight(); -+ encoder_layer_ = std::make_shared>( -+ encoder_param_.common_param.batch_size, -+ encoder_param_.common_param.src_seq_len, -+ encoder_param_.common_param.head_num, -+ encoder_param_.common_param.head_size, -+ encoder_param_.common_param.hidden_size -+ ); -+ -+ encoder_layer_->SetT5(encoder_param_.attn.attn.position_bias); -+ encoder_layer_->SetScaleAttn(encoder_param_.attn.attn.scale); -+ encoder_layer_->SetUsePast(false); -+ encoder_layer_->SetIsLayerNorm(false, 1e-6f); -+ encoder_layer_->SetHTokenNum(encoder_param_.common_param.h_token_num, encoder_param_.common_param.h_token_num); -+ encoder_layer_->SetFfnParam(encoder_param_.ffn_param.ffn_param.ffn_fp16, encoder_param_.ffn_param.ffn_param.ffn_hidden_size, (FfnBase::ActType)encoder_param_.ffn_param.ffn_param.act_type, encoder_param_.ffn_param.ffn_param.ffn_bias); -+ encoder_layer_->SetQueryLayer(false); -+ encoder_layer_->SetAlgo(CUBLAS_GEMM_DEFAULT_TENSOR_OP); -+} -+ -+template -+void MSELayer::allocateBuffer() -+{ -+ if (buf_ == nullptr) { -+ size_t buff_size_allocator = encoder_layer_->GetWorkspaceSize(); -+ buf_ = reinterpret_cast(allocator_->reMalloc(buf_, sizeof(T) * buff_size_allocator, true)); -+ encoder_layer_->SetWSOffset(0); -+ } -+} -+ -+template -+void MSELayer::freeBuffer() -+{ -+ if (buf_ != nullptr) { -+ allocator_->free(buf_); -+ buf_ = nullptr; -+ } -+} -+ -+template -+MSELayer::~MSELayer() -+{ -+ cublas_wrapper_ = nullptr; -+ freeBuffer(); -+} -+ -+template -+int MSELayer::forward(std::vector* output_tensors, -+ const std::vector* input_tensors, -+ const MSLayerWeight* weights) -+{ -+ const EncoderLayerWeight* encoder_weights = dynamic_cast*>(this->ms_weights); -+ if (encoder_weights == NULL) { -+ FT_LOG_ERROR("cast EncoderLayerWeight not sucsses\n"); -+ return -1; -+ } -+ allocateBuffer(); // only once -+ std::vector outputs= {(void*)output_tensors->at(0).data}; -+ if (!encoder_param_.encoder.layernorm_post) { -+ if (encoder_param_.attn.attn.position_bias) { -+ std::vector inputs = {(void*)input_tensors->at(0).data, -+ (void*)encoder_weights->layernorm1.gamma, -+ (void*)encoder_weights->attention.query_weight.kernel, -+ (void*)input_tensors->at(1).data, -+ (void*)input_tensors->at(2).data, -+ (void*)encoder_weights->attention.attention_output_weight.kernel, -+ (void*)encoder_weights->layernorm2.gamma, -+ (void*)encoder_weights->encoder_output_mapping.kernel, -+ (void*)encoder_weights->encoder_output_projection.kernel -+ -+ }; -+ encoder_layer_->forward(inputs,outputs,buf_, encoder_param_.common_param.cublas_handle, encoder_param_.common_param.stream); -+ } -+ else { -+ std::vector inputs = {(void*)input_tensors->at(0).data, -+ (void*)encoder_weights->layernorm1.gamma, -+ (void*)encoder_weights->layernorm1.beta, -+ (void*)encoder_weights->attention.query_weight.kernel, -+ (void*)encoder_weights->attention.query_weight.bias, -+ (void*)input_tensors->at(1).data, -+ (void*)encoder_weights->attention.attention_output_weight.kernel, -+ (void*)encoder_weights->attention.attention_output_weight.bias, -+ (void*)encoder_weights->layernorm2.gamma, -+ (void*)encoder_weights->layernorm2.beta, -+ (void*)encoder_weights->encoder_output_mapping.kernel, -+ (void*)encoder_weights->encoder_output_mapping.bias, -+ (void*)encoder_weights->encoder_output_projection.kernel, -+ (void*)encoder_weights->encoder_output_projection.bias}; -+ encoder_layer_->forward(inputs, outputs,buf_, encoder_param_.common_param.cublas_handle, encoder_param_.common_param.stream); -+ } -+ } -+ else { -+ if (encoder_param_.attn.attn.position_bias) { -+ std::vector inputs = {(void*)input_tensors->at(0).data, -+ (void*)encoder_weights->layernorm1.gamma, -+ (void*)encoder_weights->attention.query_weight.kernel, -+ (void*)input_tensors->at(1).data, -+ (void*)input_tensors->at(2).data, -+ (void*)encoder_weights->attention.attention_output_weight.kernel, -+ (void*)encoder_weights->layernorm2.gamma, -+ (void*)encoder_weights->encoder_output_mapping.kernel, -+ (void*)encoder_weights->encoder_output_projection.kernel -+ -+ }; -+ encoder_layer_->forward(inputs, outputs,buf_, encoder_param_.common_param.cublas_handle, encoder_param_.common_param.stream); -+ } -+ else { -+ std::vector inputs = {(void*)input_tensors->at(0).data, -+ (void*)encoder_weights->attention.query_weight.kernel, -+ (void*)encoder_weights->attention.query_weight.bias, -+ (void*)input_tensors->at(1).data, -+ (void*)encoder_weights->attention.attention_output_weight.kernel, -+ (void*)encoder_weights->attention.attention_output_weight.bias, -+ (void*)encoder_weights->layernorm1.gamma, -+ (void*)encoder_weights->layernorm1.beta, -+ (void*)encoder_weights->encoder_output_mapping.kernel, -+ (void*)encoder_weights->encoder_output_mapping.bias, -+ (void*)encoder_weights->encoder_output_projection.kernel, -+ (void*)encoder_weights->encoder_output_projection.bias, -+ (void*)encoder_weights->layernorm2.gamma, -+ (void*)encoder_weights->layernorm2.beta}; -+ encoder_layer_->forward(inputs, outputs, buf_, encoder_param_.common_param.cublas_handle, encoder_param_.common_param.stream); -+ } -+ } -+ -+ return 0; -+} -+ -+template -+int MSELayer::InitWeight(opt_arg* opt_a, MSLayerWeight* ms_weights, std::vector w_tensors) -+{ -+ EncoderLayerWeight* encoder_weights = dynamic_cast*>(this->ms_weights); -+ if (encoder_weights == NULL) { -+ FT_LOG_ERROR("cast EncoderLayerWeight not sucsses\n"); -+ return -1; -+ } -+ int modelId = ModelNum(opt_a->model_name); -+ if (modelId == TEL) { -+ encoder_weights->attention.query_weight.kernel = reinterpret_cast(w_tensors[2].data); -+ encoder_weights->attention.query_weight.bias = reinterpret_cast(w_tensors[3].data); -+ encoder_weights->attention.attention_output_weight.kernel = reinterpret_cast(w_tensors[4].data); -+ encoder_weights->attention.attention_output_weight.bias = reinterpret_cast(w_tensors[5].data); -+ encoder_weights->layernorm1.gamma = reinterpret_cast(w_tensors[0].data); -+ encoder_weights->layernorm1.beta = reinterpret_cast(w_tensors[1].data); -+ encoder_weights->layernorm2.gamma = reinterpret_cast(w_tensors[6].data); -+ encoder_weights->layernorm2.beta = reinterpret_cast(w_tensors[7].data); -+ encoder_weights->encoder_output_mapping.kernel = reinterpret_cast(w_tensors[8].data); -+ encoder_weights->encoder_output_projection.kernel = reinterpret_cast(w_tensors[10].data); -+ encoder_weights->encoder_output_mapping.bias = reinterpret_cast(w_tensors[9].data); -+ encoder_weights->encoder_output_projection.bias = reinterpret_cast(w_tensors[11].data); -+ } -+ else if (modelId == TEL_T5) { -+ encoder_weights->attention.query_weight.kernel = reinterpret_cast(w_tensors[1].data); -+ encoder_weights->attention.attention_output_weight.kernel = reinterpret_cast(w_tensors[2].data); -+ encoder_weights->layernorm1.gamma = reinterpret_cast(w_tensors[0].data); -+ encoder_weights->layernorm2.gamma = reinterpret_cast(w_tensors[3].data); -+ encoder_weights->encoder_output_mapping.kernel = reinterpret_cast(w_tensors[4].data); -+ encoder_weights->encoder_output_projection.kernel = reinterpret_cast(w_tensors[5].data); -+ } -+ else { -+ FT_LOG_ERROR("illegal model !\n"); -+ return -1; -+ } -+ return 0; -+} -+template class MSELayer; -+template class MSELayer; -+template class MSELayer; -+template class MSELayer; -+template class MSELayer; -+template class MSELayer; -+template class MSELayer; -+template class MSELayer; -+ -+} // namespace fastertransformer -diff --git a/src/fastertransformer/layers/ms_layers/MSEncoderLayer.h b/src/fastertransformer/layers/ms_layers/MSEncoderLayer.h -new file mode 100644 -index 0000000..358a3ca ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/MSEncoderLayer.h -@@ -0,0 +1,72 @@ -+/* -+ * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -+ * Copyright (c) 2021, NAVER Corp. Authored by CLOVA. -+ * -+ * Licensed under the Apache License, Version 2.0 (the "License"); -+ * you may not use this file except in compliance with the License. -+ * You may obtain a copy of the License at -+ * -+ * http://www.apache.org/licenses/LICENSE-2.0 -+ * -+ * Unless required by applicable law or agreed to in writing, software -+ * distributed under the License is distributed on an "AS IS" BASIS, -+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+ * See the License for the specific language governing permissions and -+ * limitations under the License. -+ */ -+ -+#pragma once -+ -+#include "src/fastertransformer/layers/ms_layers/MSBaseLayer.h" -+#include "src/fastertransformer/layers/ms_layers/encoder.h" -+#include "src/fastertransformer/layers/ms_layers/param.h" -+namespace fastertransformer { -+ -+// TODO(haim): Add template according to "mix" compute type (fp32, fp16) -+template -+class MSELayer: public MSBaseLayer { -+private: -+ void allocateBuffer() override; -+ void freeBuffer() override; -+ void* buf_; -+ using MSBaseLayer::is_free_buffer_after_forward_; -+ using MSBaseLayer::is_allocate_buffer_; -+ using MSBaseLayer::cublas_wrapper_; -+ using MSBaseLayer::allocator_; -+ std::shared_ptr> encoder_layer_; -+protected: -+ using MSBaseLayer::stream_; -+ using MSBaseLayer::sparse_; -+ -+public: -+ encoderParamRun encoder_param_; -+ MSELayer(size_t max_batch_size, -+ size_t max_src_seq_len, -+ size_t max_tgt_seq_len, -+ size_t head_num, -+ size_t size_per_head, -+ size_t ffn_hidden_size, -+ float eps1, -+ float eps2, -+ bool post_layernorm, -+ bool position_bias, -+ bool is_ffn_fp16, -+ cudaStream_t stream, -+ cublasMMWrapper* cublas_wrapper, -+ cublasHandle_t cublas_handle, -+ IAllocator* allocator, -+ bool is_free_buffer_after_forward, -+ bool is_qk_buf_float, -+ bool sparse); -+ -+ MSELayer(MSELayer const& encoder_layer); -+ -+ virtual ~MSELayer(); -+ -+ int forward(std::vector* output_tensors, -+ const std::vector* input_tensors, -+ const MSLayerWeight* weights) override; -+int InitWeight(opt_arg* opt_a, MSLayerWeight* ms_weights, std::vector w_tensors) override; -+}; -+ -+} // namespace fastertransformer -diff --git a/src/fastertransformer/layers/ms_layers/MSLayerWeight.h b/src/fastertransformer/layers/ms_layers/MSLayerWeight.h -new file mode 100644 -index 0000000..d4db37d ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/MSLayerWeight.h -@@ -0,0 +1,55 @@ -+/* -+ * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved. -+ * -+ * Licensed under the Apache License, Version 2.0 (the "License"); -+ * you may not use this file except in compliance with the License. -+ * You may obtain a copy of the License at -+ * -+ * http://www.apache.org/licenses/LICENSE-2.0 -+ * -+ * Unless required by applicable law or agreed to in writing, software -+ * distributed under the License is distributed on an "AS IS" BASIS, -+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+ * See the License for the specific language governing permissions and -+ * limitations under the License. -+ */ -+ -+#pragma once -+ -+#include "src/fastertransformer/kernels/layernorm_kernels.h" -+#include "src/fastertransformer/layers/DenseWeight.h" -+namespace fastertransformer { -+ -+template -+struct MSLayerWeight { -+ virtual ~MSLayerWeight() {} -+}; -+ -+template -+struct AttentionLayerWeight: MSLayerWeight { -+ DenseWeight query_weight; -+ DenseWeight key_weight; -+ DenseWeight value_weight; -+ DenseWeight attention_output_weight; -+}; -+ -+template -+struct DecoderLayerWeight: MSLayerWeight { -+ AttentionLayerWeight attention; -+ AttentionLayerWeight cross_attention; -+ DenseWeight decoder_output_mapping; -+ DenseWeight decoder_output_projection; -+ LayerNormWeight layernorm1; -+ LayerNormWeight layernorm2; -+ LayerNormWeight layernorm3; -+}; -+ -+template -+struct EncoderLayerWeight: MSLayerWeight { -+ AttentionLayerWeight attention; -+ DenseWeight encoder_output_mapping; -+ DenseWeight encoder_output_projection; -+ LayerNormWeight layernorm1; -+ LayerNormWeight layernorm2; -+}; -+} // namespace fastertransformer -diff --git a/src/fastertransformer/layers/ms_layers/MoeFfnLayer.cu b/src/fastertransformer/layers/ms_layers/MoeFfnLayer.cu -new file mode 100644 -index 0000000..ccf9189 ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/MoeFfnLayer.cu -@@ -0,0 +1,629 @@ -+ -+#include "MoeFfnLayer.h" -+#include "cublas_api.h" -+#include "cuda_kernels.h" -+#include "src/fastertransformer/kernels/activation_kernels.h" -+#include "src/fastertransformer/layers/ms_layers/debug_utils.h" -+#include -+#include -+ -+ -+namespace fastertransformer { -+ -+cublasStatus_t cublasGemmExRowMajor(cublasHandle_t handle, -+ cublasOperation_t transa, -+ cublasOperation_t transb, -+ int m, -+ int n, -+ int k, -+ const void* alpha, -+ const void* A, -+ cudaDataType_t Atype, -+ const void* B, -+ cudaDataType_t Btype, -+ const void* beta, -+ void* C, -+ cudaDataType_t Ctype, -+ cublasComputeType_t computeType, -+ cublasGemmAlgo_t algo) -+{ -+ const void* B_ = A; -+ const void* A_ = B; -+ int lda = (transa == CUBLAS_OP_T) ? m : k; -+ int ldb = (transb == CUBLAS_OP_T) ? k : n; -+ int lda_ = ldb; -+ int ldb_ = lda; -+ int ldc = n; -+ int m_ = n; -+ int n_ = m; -+ cudaDataType_t Atype_ = Btype; -+ cudaDataType_t Btype_ = Atype; -+ cublasOperation_t transa_ = transb; -+ cublasOperation_t transb_ = transa; -+ -+ return cublasGemmEx(handle, -+ transa_, -+ transb_, -+ m_, -+ n_, -+ k, -+ alpha, -+ A_, -+ Atype_, -+ lda_, -+ B_, -+ Btype_, -+ ldb_, -+ beta, -+ C, -+ Ctype, -+ ldc, -+ computeType, -+ algo); -+} -+ -+cublasStatus_t cublasGemmStridedBatchedExRowMajor(cublasHandle_t handle, -+ cublasOperation_t transa, -+ cublasOperation_t transb, -+ int m, -+ int n, -+ int k, -+ const void* alpha, -+ const void* A, -+ cudaDataType_t Atype, -+ const void* B, -+ cudaDataType_t Btype, -+ const void* beta, -+ void* C, -+ cudaDataType_t Ctype, -+ int batchCount, -+ cublasComputeType_t computeType, -+ cublasGemmAlgo_t algo) -+{ -+ const void* B_ = A; -+ const void* A_ = B; -+ int lda = (transa == CUBLAS_OP_T) ? m : k; -+ int ldb = (transb == CUBLAS_OP_T) ? k : n; -+ int lda_ = ldb; -+ int ldb_ = lda; -+ int ldc = n; -+ int m_ = n; -+ int n_ = m; -+ cudaDataType_t Atype_ = Btype; -+ cudaDataType_t Btype_ = Atype; -+ cublasOperation_t transa_ = transb; -+ cublasOperation_t transb_ = transa; -+ long long int stride_a = m * k; -+ long long int stride_b = n * k; -+ long long int stride_c = n * m; -+ long long int stride_a_ = stride_b; -+ long long int stride_b_ = stride_a; -+ -+ return cublasGemmStridedBatchedEx(handle, -+ transa_, -+ transb_, -+ m_, -+ n_, -+ k, -+ alpha, -+ A_, -+ Atype_, -+ lda_, -+ stride_a_, -+ B_, -+ Btype_, -+ ldb_, -+ stride_b_, -+ beta, -+ C, -+ Ctype, -+ ldc, -+ stride_c, -+ batchCount, -+ computeType, -+ algo); -+} -+ -+cublasStatus_t cublasGemmArrBatchedExRowMajor(cublasHandle_t handle, -+ cublasOperation_t transa, -+ cublasOperation_t transb, -+ int m, -+ int n, -+ int k, -+ const void* alpha, -+ const void* A[], -+ cudaDataType_t Atype, -+ const void* B[], -+ cudaDataType_t Btype, -+ const void* beta, -+ void* C[], -+ cudaDataType_t Ctype, -+ int batchCount, -+ cublasComputeType_t computeType, -+ cublasGemmAlgo_t algo) -+{ -+ const void** B_ = A; -+ const void** A_ = B; -+ int lda = (transa == CUBLAS_OP_T) ? m : k; -+ int ldb = (transb == CUBLAS_OP_T) ? k : n; -+ int lda_ = ldb; -+ int ldb_ = lda; -+ int ldc = n; -+ int m_ = n; -+ int n_ = m; -+ cudaDataType_t Atype_ = Btype; -+ cudaDataType_t Btype_ = Atype; -+ cublasOperation_t transa_ = transb; -+ cublasOperation_t transb_ = transa; -+ -+ return cublasGemmBatchedEx(handle, -+ transa_, -+ transb_, -+ m_, -+ n_, -+ k, -+ alpha, -+ A_, -+ Atype_, -+ lda_, -+ B_, -+ Btype_, -+ ldb_, -+ beta, -+ C, -+ Ctype, -+ ldc, -+ batchCount, -+ computeType, -+ algo); -+} -+ -+PanguMoeFfnLayer::PanguMoeFfnLayer(int hidden_size, -+ int expert_num, -+ int ffn_hidden_size, -+ int rank_num, -+ int seq_len, -+ float expert_capability, -+ int batch_size): -+ expert_num_(expert_num), -+ ffn_hidden_size_(ffn_hidden_size), -+ expert_capability_(expert_capability), -+ BaseLayerMS(batch_size, seq_len, seq_len, 0, 0, hidden_size, rank_num) -+{ -+ max_capacity_ = static_cast(std::ceil(expert_capability_ * src_seq_len_ / expert_num_) + 0.01f); -+} -+ -+size_t PanguMoeFfnLayer::GetWorkspaceSize() -+{ -+ size_t size = ALIGN(sizeof(half) * expert_num_ * gather_stride(), ALIGN_SIZE) + // gather tokens (sort by expert) -+ ALIGN(sizeof(int) * expert_num_ * (router_stride() + 1), ALIGN_SIZE) + // router -+ 4 * ALIGN(expert_num_ * sizeof(half*), ALIGN_SIZE); // Group/Batch matmul arrays -+ -+ size_t s1 = ALIGN(sizeof(int) * expert_num_ * batch_size_, ALIGN_SIZE); // capcity per batch and expert -+ size_t s2 = ALIGN(sizeof(half) * ffn_hidden_size_ * max_capacity_, ALIGN_SIZE); // size of mapping mm -+ size_t s3 = ALIGN(sizeof(half) * hidden_size_, ALIGN_SIZE); // size of projection bias (AllGather) -+ size_t s4 = ALIGN(sizeof(half) * expert_num_ * ffn_hidden_size_ * experimental_threshold() * batch_size_, ALIGN_SIZE); // every expert (in incremental can have up to batch token allocated) -+ size_t mx = max(s1, s2); -+ mx = max(mx, s3); -+ mx = max(mx, s4); -+ return size + mx; -+} -+ -+__global__ void HashRouter(const int* expert_id, -+ int capacity, -+ const int* padding_offset, -+ const int* seq_length, -+ int* ws, -+ int* router, -+ int token_num, -+ int seq_len, -+ int batch, -+ int router_stride) -+{ -+ int e_id = blockIdx.x; -+ int* r = router + gridDim.x + router_stride * e_id; -+ int w_id = 0; -+ -+ // zero actual capacity per batch -+ ws = ws + e_id * batch; -+ for (int i = 0; i < batch; i++) { -+ ws[i] = (capacity < seq_length[i]) ? capacity : seq_length[i]; -+ } -+ // route tokens to expert (priority for earlier tokens) -+ for (int i = 0; i < token_num; i++) { -+ int element_offset = padding_offset[i] + i; -+ int b_id = element_offset / seq_len; -+ int cur_expert_id = (expert_id[i] > 0) ? expert_id[i] : 0; // do Relu -+ if ((cur_expert_id == e_id) && (ws[b_id] >= 0)) { -+ r[w_id++] = i; -+ ws[b_id]--; -+ } -+ } -+ w_id = (w_id == 1) ? (r[0] | (1 << 31)) : w_id; -+ // Total tokens per expert -+ router[e_id] = w_id; -+} -+ -+__global__ void -+HashRouterGather(int* router, const half* in, half* gather, int router_stride, int gather_stride, int hidden_size) -+{ -+ int e_id = blockIdx.x; -+ int* r = router + gridDim.x + router_stride * e_id; -+ half* g = gather + gather_stride * e_id; -+ int tokens_per_expert = router[e_id]; -+ if (tokens_per_expert & (1 << 31)) return; -+ for (int index = threadIdx.x; index < hidden_size * tokens_per_expert; index += blockDim.x) { -+ int token_idx = index / hidden_size; -+ int hid_idx = index % hidden_size; -+ int token_id = r[token_idx]; -+ int src_offset = token_id * hidden_size + hid_idx; -+ int dst_offset = token_idx * hidden_size + hid_idx; -+ -+ g[dst_offset] = in[src_offset]; -+ } -+} -+ -+int GatherByExpert(const int* expert_id, -+ int expert_num, -+ int max_capacity, -+ const int* padding_offset, -+ const int* seq_length, -+ void* ws, -+ const half* in, -+ int* router, -+ half* gather, -+ int token_num, -+ int seq_len, -+ int batch, -+ int router_stride, -+ int gather_stride, -+ int hidden_size, -+ int* expert_per_token_h, -+ cudaStream_t stream) -+{ -+ // Step I - Build router -+ // ARR[EXPERT#] # of tokens per expert -+ // 0 token # [list of tokens] -+ // 1 token # [list of tokens] -+ // . -+ // . -+ // . -+ // 15 token # [list of tokens] -+ dim3 grid(expert_num); -+ dim3 block(1); -+ HashRouter<<>>( -+ expert_id, max_capacity, padding_offset, seq_length, (int*)ws, router, token_num, seq_len, batch, router_stride); -+ // Step II - Gather tokens -+ // 0 [List of token data] -+ // 1 [List of token data] -+ // . -+ // . -+ // . -+ // 15 [list of tokens] -+ -+ dim3 grid1(expert_num); -+ dim3 block1(1024); -+ HashRouterGather<<>>(router, in, gather, router_stride, gather_stride, hidden_size); -+ // step III - Copy token# per expert to host device -+ cudaMemcpyAsync(expert_per_token_h, router, sizeof(int) * expert_num, cudaMemcpyDeviceToHost, stream); -+ return 0; -+} -+ -+__global__ void -+HashRouterScatter(int* router, const half* gather, half* scater, int router_stride, int gather_stride, int hidden_size) -+{ -+ int e_id = blockIdx.x; -+ int* r = router + gridDim.x + router_stride * e_id; -+ const half* g = gather + gather_stride * e_id; -+ int tokens_per_expert = router[e_id]; -+ if (tokens_per_expert & (1 << 31)) return; -+ for (int index = threadIdx.x; index < hidden_size * tokens_per_expert; index += blockDim.x) { -+ int token_idx = index / hidden_size; -+ int hid_idx = index % hidden_size; -+ int token_id = r[token_idx]; -+ int src_offset = token_idx * hidden_size + hid_idx; -+ int dst_offset = token_id * hidden_size + hid_idx; -+ scater[dst_offset] = g[src_offset]; -+ } -+} -+ -+int ScatterByExpert(int* router, -+ half* gather, -+ half* scatter, -+ int expert_num, -+ int router_stride, -+ int gather_stride, -+ int hidden_size, -+ cudaStream_t stream) -+{ -+ dim3 grid1(expert_num); -+ dim3 block1(1024); -+ HashRouterScatter<<>>(router, gather, scatter, router_stride, gather_stride, hidden_size); -+ return 0; -+} -+ -+void PanguMoeFfnLayer::forward(std::vector& inputs, -+ const std::vector& outputs, -+ void* ws, -+ cublasHandle_t cublas_handle, -+ cudaStream_t stream) -+{ -+ in_idx_ = 0; -+ int moe_id = expert_offset_; -+ half* layernorm = reinterpret_cast(inputs[in_idx_++]); -+ const int* expert_ids_ptr = reinterpret_cast(inputs[in_idx_++]); -+ half* weight1_ptr = reinterpret_cast(inputs[in_idx_++]); -+ half* bias1_ptr = reinterpret_cast(inputs[in_idx_++]); -+ half* weight2_ptr = reinterpret_cast(inputs[in_idx_++]); -+ half* bias2_ptr = reinterpret_cast(inputs[in_idx_++]); -+ const int* padding_offset_d = padding_offset_d_; -+ int* seq_length_d = seq_len_d_; -+ half* output = reinterpret_cast(outputs[0]); -+ // Get WS pointer -+ ws = GetBuf(ws, ws_offset_); -+ bool incremental = batch_size_ >= h_token_num_; -+ int token_number = h_token_num_; -+ expert_ids_ptr += moe_id * token_number; -+ if (token_number == 1) { // special handle when token# is 1 -+ forward_single_token(layernorm, expert_ids_ptr, weight1_ptr, bias1_ptr, weight2_ptr, bias2_ptr, output, ws, cublas_handle, stream); -+ return; -+ } -+ // router & gather are used till end of process - malloc first -+ int* router = reinterpret_cast(ws); -+ half* gather = reinterpret_cast(router + ALIGN(expert_num_ * (router_stride() + 1), ALIGN_SIZE)); -+ void* workspace = reinterpret_cast(gather + ALIGN(expert_num_ * gather_stride(), ALIGN_SIZE)); -+ -+ // Step I - zero output (in incremental mode all tokens are set) -+ if (!incremental) cudaMemsetAsync(output, 0, sizeof(half) * token_number * hidden_size_, stream); -+ -+ // Step II - gather tokens according to expert id -+ int expert_per_token_h[expert_num_]; -+ GatherByExpert(expert_ids_ptr, -+ expert_num_, -+ max_capacity_, -+ padding_offset_d, -+ seq_length_d, -+ workspace, -+ layernorm, -+ router, -+ gather, -+ token_number, -+ src_seq_len_, -+ batch_size_, -+ router_stride(), -+ gather_stride(), -+ hidden_size_, -+ expert_per_token_h, -+ stream); -+ -+ cudaStreamSynchronize(stream); // make sure expert_per_token_h is update to host -+ -+ // Step III - Run FFN per expert -+ for (int ei = 0; ei < expert_num_; ei++) { -+ int expert_token_num = expert_per_token_h[ei]; -+ if (!(expert_token_num & (1 << 31)) && expert_token_num > experimental_threshold()) { -+ int g_offset = ei * gather_stride(); -+ half* g = gather + g_offset; -+ forward_expert(g, weight1_ptr, bias1_ptr, weight2_ptr, bias2_ptr, g, workspace, ei, expert_token_num, cublas_handle, stream); -+ expert_per_token_h[ei] = 0; -+ } -+ } -+ forward_expert_experimental(layernorm, expert_per_token_h, weight1_ptr, bias1_ptr, weight2_ptr, bias2_ptr, output, workspace, cublas_handle, stream); -+ -+ // step IV - Scatter tokens into output according to router -+ ScatterByExpert(router, gather, output, expert_num_, router_stride(), gather_stride(), hidden_size_, stream); -+} -+ -+void PanguMoeFfnLayer::forward_single_token(half *in, const int *expert_in_ids, half *weight1_ptr, half *bias1_ptr, half *weight2_ptr, half *bias2_ptr, half* output, void *workspace, cublasHandle_t cublas_handle, cudaStream_t stream) { -+ int expert_token_num = 1; -+ int expert_id; -+ cudaMemcpyAsync(&expert_id, expert_in_ids, sizeof(int), cudaMemcpyDeviceToHost, stream); -+ forward_expert(in, weight1_ptr, bias1_ptr, weight2_ptr, bias2_ptr, output, workspace, expert_id, expert_token_num, cublas_handle, stream); -+} -+ -+void PanguMoeFfnLayer::forward_expert(half* in, -+ half* weight1_ptr, -+ half* bias1_ptr, -+ half* weight2_ptr, -+ half* bias2_ptr, -+ half* output, -+ void *workspace, -+ int expert_id, -+ int expert_token_num, -+ cublasHandle_t cublas_handle, -+ cudaStream_t stream) -+{ -+ float alpha = (float)(1.0f); -+ float beta = (float)(0.0f); -+ cublasGemmAlgo_t algo = (cublasGemmAlgo_t)CUBLAS_GEMM_DEFAULT_TENSOR_OP; -+ cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; -+ -+ int w_offset = expert_id * hidden_size_ * ffn_hidden_size_; -+ int b1_offset = expert_id * ffn_hidden_size_; -+ int b2_size = hidden_size_;// / rank_num_; -+ int b2_offset = expert_id * b2_size; -+ half* w1 = weight1_ptr + w_offset; -+ half* w2 = weight2_ptr + w_offset; -+ half* bias1 = bias1_ptr + b1_offset; -+ half* bias2 = bias2_ptr + b2_offset; -+ half* mm1 = reinterpret_cast(workspace); -+ cublasGemmExRowMajor(cublas_handle, -+ CUBLAS_OP_N, -+ CUBLAS_OP_N, -+ expert_token_num, -+ ffn_hidden_size_, -+ hidden_size_, -+ &alpha, -+ in, -+ CUDA_R_16F, -+ w1, -+ CUDA_R_16F, -+ &beta, -+ mm1, -+ CUDA_R_16F, -+ compute_type, -+ algo); -+ if (act_type_ == FfnBase::ActType::Gelu) { -+ invokeAddBiasGelu(mm1, bias1, expert_token_num, ffn_hidden_size_, stream); -+ } -+ else if (act_type_ == FfnBase::ActType::FastGelu) { -+ invokeAddBiasFastGelu(mm1, bias1, expert_token_num, ffn_hidden_size_, stream); -+ } -+ cublasGemmExRowMajor(cublas_handle, -+ CUBLAS_OP_N, -+ CUBLAS_OP_N, -+ expert_token_num, -+ hidden_size_, -+ ffn_hidden_size_, -+ &alpha, -+ mm1, -+ CUDA_R_16F, -+ w2, -+ CUDA_R_16F, -+ &beta, -+ output, -+ CUDA_R_16F, -+ compute_type, -+ algo); -+ if (all_reduce_sum_func_ != nullptr) { -+ (all_reduce_sum_func_)(output, output, hidden_size_ * expert_token_num, nvinfer1::DataType::kHALF, stream); -+ } -+ -+ invokeAddBias(output, bias2, expert_token_num, hidden_size_, stream); -+} -+ -+ -+void PanguMoeFfnLayer::forward_expert_experimental(half* input, int *expert_per_token_h, -+ half* weight1_ptr, -+ half* bias1_ptr, -+ half* weight2_ptr, -+ half* bias2_ptr, -+ half* output, -+ void *workspace, -+ cublasHandle_t cublas_handle, -+ cudaStream_t stream) -+{ -+ float alpha = (float)(1.0f); -+ float beta = (float)(0.0f); -+ cublasGemmAlgo_t algo = (cublasGemmAlgo_t)CUBLAS_GEMM_DEFAULT_TENSOR_OP; -+ cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; -+ -+ int cnt = 0; -+ half* a1_arr_h[expert_num_]; -+ half* b1_arr_h[expert_num_]; -+ half* b2_arr_h[expert_num_]; -+ half* c1_arr_h[expert_num_]; -+ -+ uint8_t* align = reinterpret_cast(workspace); -+ a1_arr_ = reinterpret_cast(align); -+ b1_arr_ = a1_arr_ + expert_num_; -+ b2_arr_ = b1_arr_ + expert_num_; -+ c1_arr_ = b2_arr_ + expert_num_; -+ -+ half* mm1 = reinterpret_cast(c1_arr_ + ALIGN(expert_num_, ALIGN_SIZE)); -+ size_t mapping_stride = ffn_hidden_size_ * 1 * batch_size_; -+ // prepare GEMM arrays -+ int mx_token = 0; -+ for (int ei = 0; ei < expert_num_; ei++) { -+ bool out_flag = expert_per_token_h[ei] & (1 << 31); -+ int token_num = out_flag ? 1 : expert_per_token_h[ei]; -+ mx_token = max(mx_token, token_num); -+ if (token_num) { -+ a1_arr_h[cnt] = out_flag ? input + (expert_per_token_h[ei] & ~ (1 << 31)) * hidden_size_ : input + ei * gather_stride(); -+ b1_arr_h[cnt] = weight1_ptr + ei * hidden_size_ * ffn_hidden_size_; -+ b2_arr_h[cnt] = weight2_ptr + ei * hidden_size_ * ffn_hidden_size_; -+ c1_arr_h[cnt] = mm1 + ei * mapping_stride; -+ cnt++; -+ } -+ } -+ if (cnt == 0) return; -+ // copy array to device -+ cudaMemcpyAsync(a1_arr_, a1_arr_h, cnt * sizeof(half*), cudaMemcpyHostToDevice, stream); -+ cudaMemcpyAsync(b1_arr_, b1_arr_h, cnt * sizeof(half*), cudaMemcpyHostToDevice, stream); -+ cudaMemcpyAsync(b2_arr_, b2_arr_h, cnt * sizeof(half*), cudaMemcpyHostToDevice, stream); -+ cudaMemcpyAsync(c1_arr_, c1_arr_h, cnt * sizeof(half*), cudaMemcpyHostToDevice, stream); -+ -+ cublasGemmArrBatchedExRowMajor(cublas_handle, -+ CUBLAS_OP_N, -+ CUBLAS_OP_N, -+ mx_token, -+ ffn_hidden_size_, -+ hidden_size_, -+ &alpha, -+ (const void**)a1_arr_, -+ CUDA_R_16F, -+ (const void**)b1_arr_, -+ CUDA_R_16F, -+ &beta, -+ (void**)c1_arr_, -+ CUDA_R_16F, -+ cnt, -+ compute_type, -+ algo); -+ int arr_idx = 0; -+ for (int ei = 0; ei < expert_num_; ei++) { -+ bool out_flag = expert_per_token_h[ei] & (1 << 31); -+ int token_num = out_flag ? 1 : expert_per_token_h[ei]; -+ if (token_num) { -+ int b1_offset = ei * ffn_hidden_size_; -+ half* bias1 = bias1_ptr + b1_offset; -+ if (act_type_ == FfnBase::ActType::Gelu) { -+ invokeAddBiasGelu(c1_arr_h[arr_idx++], bias1, token_num, ffn_hidden_size_, stream); -+ } -+ else if (act_type_ == FfnBase::ActType::FastGelu) { -+ invokeAddBiasFastGelu(c1_arr_h[arr_idx++], bias1, token_num, ffn_hidden_size_, stream); -+ } -+ } -+ } -+ -+ cnt = 0; -+ if (input != output) { -+ for (int ei = 0; ei < expert_num_; ei++) { -+ bool out_flag = expert_per_token_h[ei] & (1 << 31); -+ int token_num = out_flag ? 1 : expert_per_token_h[ei]; -+ mx_token = max(mx_token, token_num); -+ if (token_num) { -+ a1_arr_h[cnt++] = out_flag ? output + (expert_per_token_h[ei] & ~ (1 << 31)) * hidden_size_ : output + ei * gather_stride(); -+ } -+ cudaMemcpyAsync(a1_arr_, a1_arr_h, cnt * sizeof(half*), cudaMemcpyHostToDevice, stream); -+ } -+ } -+ -+ cublasGemmArrBatchedExRowMajor(cublas_handle, -+ CUBLAS_OP_N, -+ CUBLAS_OP_N, -+ mx_token, -+ hidden_size_, -+ ffn_hidden_size_, -+ &alpha, -+ (const void**)c1_arr_, -+ CUDA_R_16F, -+ (const void**)b2_arr_, -+ CUDA_R_16F, -+ &beta, -+ (void**)a1_arr_, -+ CUDA_R_16F, -+ cnt, -+ compute_type, -+ algo); -+ arr_idx = 0; -+ for (int ei = 0; ei < expert_num_; ei++) { -+ bool out_flag = expert_per_token_h[ei] & (1 << 31); -+ int token_num = out_flag ? 1 : expert_per_token_h[ei]; -+ if (token_num) { -+ if (all_reduce_sum_func_ != nullptr) { -+ (all_reduce_sum_func_)(a1_arr_h[arr_idx], -+ a1_arr_h[arr_idx], -+ hidden_size_ * token_num, -+ nvinfer1::DataType::kHALF, -+ stream); -+ } -+ -+ int b2_offset = ei * hidden_size_; -+ half* bias2 = bias2_ptr + b2_offset; -+ invokeAddBias(a1_arr_h[arr_idx], bias2, token_num, hidden_size_, stream); -+ arr_idx++; -+ } -+ } -+} -+} // namespace fastertransformer -diff --git a/src/fastertransformer/layers/ms_layers/MoeFfnLayer.h b/src/fastertransformer/layers/ms_layers/MoeFfnLayer.h -new file mode 100644 -index 0000000..e3afe9c ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/MoeFfnLayer.h -@@ -0,0 +1,82 @@ -+#ifndef MOE_FFN_LAYER_H_ -+#define MOE_FFN_LAYER_H_ -+#pragma once -+ -+#include -+#include -+#include -+#include -+#include "src/fastertransformer/layers/ms_layers/ffn.h" -+#include "src/fastertransformer/layers/ms_layers/BaseLayer.h" -+namespace fastertransformer { -+ -+ -+class PanguMoeFfnLayer : public BaseLayerMS { -+ public: -+ PanguMoeFfnLayer(int hidden_size, -+ int expert_num, -+ int ffn_hidden_size, -+ int rank_num, -+ int seq_len, -+ float expert_capability, -+ int batch_size); -+ size_t GetWorkspaceSize() override; -+ void SetExpertNum(int expert_num) { -+ expert_num_ = expert_num; -+ max_capacity_ = static_cast(std::ceil(expert_capability_ * src_seq_len_ / expert_num_) + 0.01f); -+ } -+ void SetFfnHiddenSize(int ffn_hidden_size) {ffn_hidden_size_ = ffn_hidden_size;} -+ void SetExpertOffset(size_t expert_offset) {expert_offset_ = expert_offset;} -+ void SetHTokenNum(size_t h_token_num) -+ { -+ h_token_num_ = h_token_num; -+ } -+ void SetPaddingOffsetDevice(int* padding_offset) {padding_offset_d_ = padding_offset;} -+ void SetSeqLenDevice(int* seq_len) {seq_len_d_ = seq_len;} -+ void SetSeqLenHost(int* seq_len) {seq_len_h_ = seq_len;} -+ void SetActType(FfnBase::ActType act_type) { act_type_ = act_type; } -+ void forward(std::vector &inputs, const std::vector &outputs, void *ws, cublasHandle_t cublas_handle = nullptr, cudaStream_t stream = 0) override; -+ private: -+ size_t gather_stride() {return max_capacity_ * batch_size_ * hidden_size_;}; -+ size_t router_stride() {return max_capacity_ * batch_size_;}; -+ -+ void forward_single_token(half *in, const int *expert_in_ids, half *weight1_ptr, half *bias1_ptr, half *weight2_ptr, half *bias2_ptr, half* output, void *workspace, cublasHandle_t cublas_handle = nullptr, cudaStream_t stream = 0); -+ void forward_expert(half *in, half *weight1_ptr, half *bias1_ptr, half *weight2_ptr, half *bias2_ptr, half* output, void *workspace, int expert_id, int expert_token_num, cublasHandle_t cublas_handle = nullptr, cudaStream_t stream = 0); -+ void forward_expert_experimental(half *gather, int *token_per_expert, half *weight1_ptr, half *bias1_ptr, half *weight2_ptr, half *bias2_ptr, half* output, void *workspace, cublasHandle_t cublas_handle = nullptr, cudaStream_t stream = 0); -+ int experimental_threshold() {return 1;}; -+ void PanguMoeLayer(const int* d_in, -+ const half* layernorm, -+ half* out, -+ const half* weight1, -+ const half* weight2, -+ const half* bias1, -+ const half* bias2, -+ int batch_size, -+ int expert_num, -+ int length, -+ int hidden_size, -+ int hidden_size2, -+ int onehot_size, -+ void* workspace, -+ cublasHandle_t cublas_handle, -+ cudaStream_t stream); -+ int expert_num_; -+ int ffn_hidden_size_; -+ float expert_capability_; -+ int max_capacity_; -+ size_t expert_offset_; -+ int *capacity_ = nullptr; -+ int *padding_offset_d_ = nullptr; -+ int *seq_len_d_ = nullptr; -+ int *seq_len_h_ = nullptr; -+ FfnBase::ActType act_type_; -+ size_t h_token_num_; -+ half** a1_arr_ = nullptr; -+ half** b1_arr_ = nullptr; -+ half** b2_arr_ = nullptr; -+ half** c1_arr_ = nullptr; -+}; -+ -+} // pangumoe -+ -+#endif // MOE_FFN_LAYER_H_ -diff --git a/src/fastertransformer/layers/ms_layers/attention.cc b/src/fastertransformer/layers/ms_layers/attention.cc -new file mode 100644 -index 0000000..0659eab ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/attention.cc -@@ -0,0 +1,773 @@ -+ -+#include "src/fastertransformer/layers/ms_layers/attention.h" -+#include "fmha_cutlass.h" -+#include "src/fastertransformer/kernels/activation_kernels.h" -+#include "src/fastertransformer/kernels/add_residual_kernels.h" -+#include "src/fastertransformer/kernels/bert_preprocess_kernels.h" -+#include "src/fastertransformer/kernels/unfused_attention_kernels.h" -+#include "src/fastertransformer/layers/ms_layers/debug_utils.h" -+#include "src/fastertransformer/layers/ms_layers/opt_allocator.h" -+#include -+namespace fastertransformer { -+ -+template -+size_t UnfusedMhaDispatch::GetWorkspaceSize() -+{ -+ size_t attn_out_size = batch_size_ * head_num_ * head_size_ * tgt_seq_len_; -+ size_t size_q = batch_size_ * src_seq_len_ * head_num_ * head_size_; -+ size_t size_k = batch_size_ * tgt_seq_len_ * head_num_ * head_size_; -+ size_t size_v = size_k; -+ size_t qkv_len = size_q + size_k + size_v; -+ size_t q_buf_2_len = size_q; -+ size_t qk_buf_len = batch_size_ * head_num_ * src_seq_len_ * tgt_seq_len_; -+ size_t qkv_buf_2_len = batch_size_ * src_seq_len_ * head_num_ * head_size_; -+ size_t qkv_buf_3_len = qkv_buf_2_len; -+ OptAllocator allocator(ALIGN_SIZE); -+ qkv_buf_ = allocator.Malloc(qkv_len * sizeof(T)); -+ q_buf_2_ = allocator.Malloc(q_buf_2_len * sizeof(T)); -+ if (use_past_) { -+ output1_ = 0; -+ output2_ = 0; -+ } else { -+ output1_ = allocator.Malloc(attn_out_size * sizeof(T)); -+ output2_ = allocator.Malloc(attn_out_size * sizeof(T)); -+ } -+ allocator.Free(qkv_buf_); -+ qk_buf_ = allocator.Malloc(qk_buf_len * sizeof(T)); -+ allocator.Free(q_buf_2_); -+ if (!use_past_) -+ allocator.Free(output1_); -+ qkv_buf_2_ = allocator.Malloc(qkv_buf_2_len * sizeof(T)); -+ allocator.Free(output2_); -+ allocator.Free(qk_buf_); -+ qkv_buf_3_ = allocator.Malloc(qkv_buf_3_len * sizeof(T)); -+ allocator.Free(qkv_buf_2_); -+ allocator.Free(qkv_buf_3_); -+ return allocator.total_size(); -+} -+ -+template -+void UnfusedMhaDispatch::forward(std::vector& inputs, -+ const std::vector& output, -+ void* ws, -+ cublasHandle_t cublas_handle, -+ cudaStream_t stream) -+{ -+ // setup inputs -+ T* q_buf_2 = reinterpret_cast(inputs[0]); -+ T* output1 = reinterpret_cast(inputs[1]); -+ T* output2 = reinterpret_cast(inputs[2]); -+ T* attention_mask = reinterpret_cast(inputs[3]); -+ T* position_bias = reinterpret_cast(inputs[4]); -+ // setup inner buffers -+ T* qk_buf = GetBuf(ws, qk_buf_); -+ T* qkv_buf_2 = GetBuf(ws, qkv_buf_2_); -+ int src_seq_len = src_seq_len_; // len of q tensor -+ int tgt_seq_len = tgt_seq_len_; // len of K, V tensors -+ int max = d_sequence_length_host_[0], min = d_sequence_length_host_[0]; -+ for (int i = 0; i < batch_size_; i++) -+ { -+ if (d_sequence_length2_host_[i] < min) -+ min = d_sequence_length2_host_[i]; -+ if (d_sequence_length2_host_[i] > max) -+ max = d_sequence_length2_host_[i]; -+ } -+ tgt_seq_len = max; -+ if (use_past_ && incremental_mode_) { -+ src_seq_len = 1; -+ } -+ // run unfused attention -+ int gemm_dims[] = {tgt_seq_len, src_seq_len, (int)head_size_}; -+ int gemm_lds[] = {(int)tgt_seq_len_, -+ (int)head_size_, -+ tgt_seq_len}; -+ cublasOperation_t gemm_ops[] = {CUBLAS_OP_N, CUBLAS_OP_N}; -+ cudaDataType gemm_data_types[] = {CUDA_R_32F, CUDA_R_32F, CUDA_R_32F}; -+ if (std::is_same::value) { -+ gemm_data_types[0] = CUDA_R_16F; -+ gemm_data_types[1] = CUDA_R_16F; -+ gemm_data_types[2] = CUDA_R_16F; -+ } -+ float alpha = 1.0f; -+ float beta = 0.0f; -+ int gemm_strides[] = {(int)(tgt_seq_len_ * head_size_), -+ (int)(src_seq_len * head_size_), -+ (src_seq_len * tgt_seq_len)}; -+ //If batch valid != batch_size - run loop to skip not valid batches -+ if (padding_offset_ != nullptr && h_token_num_ < batch_size_) { -+ int offset1 = 0; -+ int offset2 = 0; -+ int offset3 = 0; -+ int head = (is_cross_ && position_bias_) ? int(1) : int(head_num_); -+ for (int i = 0; i < batch_size_; i++) { -+ src_seq_len = (int)(d_sequence_length_host_[i]); -+ tgt_seq_len = (int)(d_sequence_length2_host_[i]); -+ -+ if (src_seq_len == -1) { -+ offset1 += head_num_ * head_size_ * tgt_seq_len_; -+ continue; -+ } -+ gemm_dims[0] = tgt_seq_len; -+ gemm_dims[1] = src_seq_len; -+ gemm_dims[2] = head_size_; -+ gemm_lds[0] = tgt_seq_len_; -+ gemm_lds[1] = head_size_; -+ gemm_lds[2] = tgt_seq_len; -+ -+ gemm_strides[0] = tgt_seq_len_ * head_size_; -+ gemm_strides[1] = src_seq_len * head_size_; -+ gemm_strides[2] = src_seq_len * tgt_seq_len; -+ gemm_ops[0] = CUBLAS_OP_N; -+ gemm_ops[1] = CUBLAS_OP_N; -+ CublasGemmStridedBatchedWrapper(output1 + offset1, -+ q_buf_2 + offset2, -+ qk_buf + offset3, -+ gemm_dims, -+ gemm_lds, -+ gemm_ops, -+ gemm_strides, -+ const_cast(gemm_data_types), -+ &alpha, -+ &beta, -+ head_num_, -+ cublas_handle, -+ algo_); -+ invokeMixMaskedSoftMax(static_cast(qk_buf) + offset3, -+ (use_past_ ) ? nullptr : (attention_mask == nullptr) ? nullptr : attention_mask + i * src_seq_len_ * tgt_seq_len_, -+ position_bias, -+ d_sequence_length_ + i, -+ d_sequence_length2_ + i, -+ 1, -+ src_seq_len, -+ src_seq_len_, -+ tgt_seq_len, -+ tgt_seq_len_, -+ head_num_, -+ head, -+ (T)(scale_), -+ (use_past_ && !incremental_mode_), -+ stream); -+ gemm_ops[0] = CUBLAS_OP_N; -+ gemm_ops[1] = CUBLAS_OP_N; -+ -+ gemm_dims[0] = head_size_; -+ gemm_dims[1] = src_seq_len; -+ gemm_dims[2] = tgt_seq_len; -+ -+ gemm_lds[0] = head_size_; -+ gemm_lds[1] = tgt_seq_len; -+ gemm_lds[2] = head_size_; -+ -+ gemm_strides[0] = tgt_seq_len_ * head_size_; -+ gemm_strides[1] = src_seq_len * tgt_seq_len; -+ gemm_strides[2] = src_seq_len * head_size_; -+ CublasGemmStridedBatchedWrapper(output2 + offset1, -+ qk_buf + offset3, -+ qkv_buf_2 + offset2, -+ gemm_dims, -+ gemm_lds, -+ gemm_ops, -+ gemm_strides, -+ const_cast(gemm_data_types), -+ &alpha, -+ &beta, -+ head_num_, -+ cublas_handle, -+ algo_); -+ offset1 += head_num_ * head_size_ * tgt_seq_len_; -+ offset2 += head_num_ * head_size_ * src_seq_len; -+ offset3 += head_num_ * src_seq_len * tgt_seq_len_; -+ } -+ offset1 = 0; -+ for (int i = 0; i < batch_size_; i++) { -+ src_seq_len = (int)(d_sequence_length_host_[i]); -+ if (src_seq_len == -1) continue; -+ invokeTransposeQKV(static_cast(output[0]) + offset1, -+ static_cast(qkv_buf_2) + offset1, -+ 1, -+ src_seq_len, -+ head_num_, -+ head_size_, -+ stream); -+ offset1 += head_num_ * head_size_ * src_seq_len; -+ } -+ } -+ -+ else { -+ CublasGemmStridedBatchedWrapper(output1, -+ q_buf_2, -+ qk_buf, -+ gemm_dims, -+ gemm_lds, -+ gemm_ops, -+ gemm_strides, -+ const_cast(gemm_data_types), -+ &alpha, -+ &beta, -+ batch_size_ * head_num_, -+ cublas_handle, -+ algo_); -+ invokeMixMaskedSoftMax(static_cast(qk_buf), -+ (use_past_ ) ? nullptr : attention_mask, -+ position_bias, -+ d_sequence_length_, -+ d_sequence_length2_, -+ batch_size_, -+ src_seq_len, -+ src_seq_len_, -+ tgt_seq_len, -+ tgt_seq_len_, -+ head_num_, -+ (is_cross_ && position_bias_) ? int(1) : int(head_num_), -+ (T)(scale_), -+ (use_past_ && !incremental_mode_), -+ stream); -+ gemm_ops[0] = CUBLAS_OP_N; -+ gemm_ops[1] = CUBLAS_OP_N; -+ -+ gemm_dims[0] = head_size_; -+ gemm_dims[1] = src_seq_len; -+ gemm_dims[2] = tgt_seq_len; -+ -+ gemm_lds[0] = head_size_; -+ gemm_lds[1] = tgt_seq_len; -+ gemm_lds[2] = head_size_; -+ -+ gemm_strides[0] = tgt_seq_len_ * head_size_; -+ gemm_strides[1] = src_seq_len * tgt_seq_len; -+ gemm_strides[2] = src_seq_len * head_size_; -+ CublasGemmStridedBatchedWrapper(output2, -+ qk_buf, -+ qkv_buf_2, -+ gemm_dims, -+ gemm_lds, -+ gemm_ops, -+ gemm_strides, -+ const_cast(gemm_data_types), -+ &alpha, -+ &beta, -+ batch_size_ * head_num_, -+ cublas_handle, -+ algo_); -+ if (padding_offset_ == nullptr || incremental_mode_) { -+ invokeTransposeQKV(static_cast(output[0]), -+ static_cast(qkv_buf_2), -+ batch_size_, -+ src_seq_len, -+ head_num_, -+ head_size_, -+ stream); -+ } else { -+ invokeTransposeAttentionOutRemovePadding(qkv_buf_2, -+ reinterpret_cast(output[0]), -+ h_token_num_, -+ batch_size_, -+ src_seq_len_, -+ head_num_, -+ head_size_, -+ padding_offset_, -+ stream); -+ } -+ } -+ return; -+} -+ -+template -+bool FusedCutlassMhaDispatch::isSupport() -+{ -+ return fuse_mha_->isSupport(); -+} -+template -+size_t FusedCutlassMhaDispatch::GetWorkspaceSize() -+{ -+ size_t attn_out_size = batch_size_ * head_num_ * head_size_ * tgt_seq_len_; -+ size_t size_q = batch_size_ * src_seq_len_ * head_num_ * head_size_; -+ size_t size_k = batch_size_ * tgt_seq_len_ * head_num_ * head_size_; -+ size_t size_v = size_k; -+ size_t qkv_len = size_q + size_k + size_v; -+ size_t q_buf_2_len = size_q; -+ size_t qk_buf_len = batch_size_ * head_num_ * src_seq_len_ * tgt_seq_len_; -+ size_t qkv_buf_2_len = batch_size_ * src_seq_len_ * head_num_ * head_size_; -+ size_t qkv_buf_3_len = qkv_buf_2_len; -+ OptAllocator allocator(ALIGN_SIZE); -+ qkv_buf_ = allocator.Malloc(qkv_len * sizeof(T)); -+ q_buf_2_ = allocator.Malloc(q_buf_2_len * sizeof(T)); -+ if (use_past_) { -+ output1_ = 0; -+ output2_ = 0; -+ } else { -+ output1_ = allocator.Malloc(attn_out_size * sizeof(T)); -+ output2_ = allocator.Malloc(attn_out_size * sizeof(T)); -+ } -+ allocator.Free(qkv_buf_); -+ qk_buf_ = 0; // not in use -+ qkv_buf_2_ = allocator.Malloc(qkv_buf_2_len * sizeof(T)); -+ qkv_buf_3_ = allocator.Malloc(qkv_buf_3_len * sizeof(T)); -+ size_t size = 0; -+ size = fuse_mha_->GetWorkspaceSize(); -+ if (size > 0) { -+ mha_ = allocator.Malloc(size); -+ } -+ else { -+ mha_ = 0; // not in use -+ } -+ fuse_mha_->SetWSOffset(mha_); -+ allocator.Free(qkv_buf_3_); -+ size_t total = allocator.total_size(); -+ return total; -+} -+template -+void FusedCutlassMhaDispatch::forward(std::vector& inputs, -+ const std::vector& outputs, -+ void* ws, -+ cublasHandle_t cublas_handle, -+ cudaStream_t stream) -+{ -+ fuse_mha_->forward(inputs, outputs, ws, cublas_handle, stream); -+} -+ -+template -+size_t Attention::GetWorkspaceSize() -+{ -+ size_t size = dispatch_->GetWorkspaceSize(); -+ qkv_buf_ = dispatch_->qkv_buf_; -+ q_buf_2_ = dispatch_->q_buf_2_; -+ output1_ = dispatch_->output1_; -+ output2_ = dispatch_->output2_; -+ qk_buf_ = dispatch_->qk_buf_; -+ qkv_buf_2_ = dispatch_->qkv_buf_2_; -+ qkv_buf_3_ = dispatch_->qkv_buf_3_; -+ mha_ = dispatch_->mha_; -+ return size; -+} -+ -+template -+Attention::Attention(size_t batch_size, -+ size_t src_seq_len, -+ size_t tgt_seq_len, -+ size_t head_num, -+ size_t head_size, -+ size_t hidden_size, -+ bool qkv_bias, -+ bool projection_bias, -+ bool is_cross, -+ bool position_bias, -+ float scale, -+ bool mask, -+ bool use_past, -+ int rank_num, -+ cublasGemmAlgo_t algo): -+ qkv_bias_(qkv_bias), -+ projection_bias_(projection_bias), -+ is_cross_(is_cross), -+ position_bias_(position_bias), -+ mask_(mask), -+ use_past_(use_past), -+ BaseLayerMS(batch_size, src_seq_len, tgt_seq_len, head_num, head_size, hidden_size, rank_num, algo) -+{ -+ std::shared_ptr> fuse = std::make_shared>(batch_size, -+ src_seq_len, -+ tgt_seq_len, -+ head_num, -+ head_size, -+ hidden_size, -+ is_cross, -+ position_bias, -+ scale, -+ mask, -+ use_past, -+ algo); -+ if (fuse->isSupport()) { -+ fmha_type_ = MhaDispatch::Type::CutlassFix; -+ dispatch_ = fuse; -+ } -+ else { -+ std::shared_ptr> unfuse = std::make_shared>(batch_size, -+ src_seq_len, -+ tgt_seq_len, -+ head_num, -+ head_size, -+ hidden_size, -+ is_cross, -+ position_bias, -+ scale, -+ mask, -+ use_past, -+ algo); -+ fmha_type_ = MhaDispatch::Type::UnFused; -+ dispatch_ = unfuse; -+ } -+} -+ -+MhaDispatch::MhaDispatch(size_t batch_size, -+ size_t src_seq_len, -+ size_t tgt_seq_len, -+ size_t head_num, -+ size_t head_size, -+ size_t hidden_size, -+ bool is_cross, -+ bool position_bias, -+ float scale, -+ bool mask, -+ bool use_past, -+ int rank_num, -+ cublasGemmAlgo_t algo): -+ is_cross_(is_cross), -+ position_bias_(position_bias), -+ scale_(scale), -+ mask_(mask), -+ use_past_(use_past), -+ BaseLayerMS(batch_size, src_seq_len, tgt_seq_len, head_num, head_size, hidden_size, rank_num, algo) { -+ } -+template -+FusedCutlassMhaDispatch::FusedCutlassMhaDispatch(size_t batch_size, -+ size_t src_seq_len, -+ size_t tgt_seq_len, -+ size_t head_num, -+ size_t head_size, -+ size_t hidden_size, -+ bool is_cross, -+ bool position_bias, -+ float scale, -+ bool mask, -+ bool use_past, -+ int rank_num, -+ cublasGemmAlgo_t algo): -+ MhaDispatch(batch_size, -+ src_seq_len, -+ tgt_seq_len, -+ head_num, -+ head_size, -+ hidden_size, -+ is_cross, -+ position_bias, -+ scale, -+ mask, -+ use_past, -+ rank_num, -+ algo) -+{ -+ typedef typename std::conditional::value, cutlass::half_t, float>::type Type; -+ fuse_mha_ = std::make_shared>(batch_size, -+ src_seq_len, -+ tgt_seq_len, -+ head_num, -+ head_size, -+ hidden_size, -+ is_cross, -+ position_bias, -+ scale, -+ mask, -+ use_past, -+ rank_num, -+ algo); -+} -+ -+template -+void Attention::forward(std::vector& inputs, -+ const std::vector& outputs, -+ void* ws, -+ cublasHandle_t cublas_handle, -+ cudaStream_t stream) -+{dispatch_->SetRankId(rank_id_); -+ ws = GetBuf(ws, ws_offset_); -+ in_idx_ = 0; -+ T* qkv_buf = GetBuf(ws, qkv_buf_); -+ T* q_buf_2 = GetBuf(ws, q_buf_2_); -+ T* qkv_buf_3 = GetBuf(ws, qkv_buf_3_); -+ -+ T* output1 = nullptr; -+ T* output2 = nullptr; -+ if (use_past_) { -+ output1 = reinterpret_cast(k_cache_); -+ output2 = reinterpret_cast(v_cache_); -+ } else { -+ output1 = GetBuf(ws, output1_); -+ output2 = GetBuf(ws, output2_); -+ } -+ int actual_hidden_size = head_size_ * head_num_; -+ int gemm_dims[] = { -+ 3 * (int)actual_hidden_size, (int)h_token_num_, (int)hidden_size_}; -+ int gemm_lds[] = {3 * (int)actual_hidden_size, (int)hidden_size_, 3 * (int)actual_hidden_size}; -+ T* from_tensor = reinterpret_cast(inputs[in_idx_++]); -+ cublasOperation_t gemm_ops[] = {CUBLAS_OP_N, CUBLAS_OP_N}; -+ cudaDataType gemm_data_types[] = {CUDA_R_32F, CUDA_R_32F, CUDA_R_32F}; -+ if (std::is_same::value) { -+ gemm_data_types[0] = CUDA_R_16F; -+ gemm_data_types[1] = CUDA_R_16F; -+ gemm_data_types[2] = CUDA_R_16F; -+ } -+ float alpha = 1.0f; -+ float beta = 0.0f; -+ if (is_cross_) { -+ gemm_dims[0] = actual_hidden_size; -+ gemm_dims[1] = h_token_num_; -+ gemm_dims[2] = hidden_size_; -+ gemm_lds[0] = actual_hidden_size; -+ gemm_lds[1] = hidden_size_; -+ gemm_lds[2] = actual_hidden_size; -+ T* encoder_output = reinterpret_cast(inputs[in_idx_++]); -+ T* weight_q = reinterpret_cast(inputs[in_idx_++]); -+ if (use_past_) { -+ gemm_lds[2] = 3 * actual_hidden_size; -+ } -+ CublasGemmWrapper(weight_q, -+ from_tensor, -+ qkv_buf, -+ gemm_dims, -+ gemm_lds, -+ gemm_ops, -+ gemm_data_types, -+ &alpha, -+ &beta, -+ cublas_handle, -+ algo_); -+ -+ T *kv = qkv_buf + h_token_num_ * hidden_size_; -+ gemm_dims[0] = 2 * actual_hidden_size; -+ gemm_dims[1] = h_token_num2_; -+ gemm_dims[2] = hidden_size_; -+ -+ gemm_lds[0] = 2 * actual_hidden_size; -+ gemm_lds[1] = hidden_size_; -+ gemm_lds[2] = 2 * actual_hidden_size; -+ if (use_past_) { -+ gemm_lds[2] = 3 * actual_hidden_size; -+ kv = qkv_buf + actual_hidden_size; -+ } -+ T* weight_kv = reinterpret_cast(inputs[in_idx_++]); -+ CublasGemmWrapper(weight_kv, -+ encoder_output, -+ kv, -+ gemm_dims, -+ gemm_lds, -+ gemm_ops, -+ gemm_data_types, -+ &alpha, -+ &beta, -+ cublas_handle, -+ algo_); -+ T* bias_qkv = (qkv_bias_) ? reinterpret_cast(inputs[in_idx_++]) : nullptr; -+ if (padding_offset_ == nullptr) { -+ if (use_past_) { -+ if (incremental_mode_) { -+ output1 += (cur_token_id_) * head_size_; -+ output2 += (cur_token_id_) * head_size_; -+ } -+ invokeAddFusedQKVBiasTransposeUsePast(static_cast(q_buf_2), -+ static_cast(output1), -+ static_cast(output2), -+ static_cast(qkv_buf), -+ bias_qkv, -+ src_seq_len_, -+ h_token_num_, -+ head_num_, -+ head_size_, -+ stream); -+ // restore cache to pointer start after concat -+ output1 = reinterpret_cast(k_cache_); -+ output2 = reinterpret_cast(v_cache_); -+ } -+ else { -+ invokeCrossAddFusedQKVBiasTranspose(q_buf_2, -+ output1, -+ output2, -+ qkv_buf, -+ bias_qkv, -+ batch_size_, -+ src_seq_len_, -+ tgt_seq_len_, -+ head_num_, -+ head_size_, -+ stream); -+ } -+ } else { -+ if (use_past_) { -+ invokeAddFusedQKVBiasTransposeUsePastMB(q_buf_2, -+ output1, -+ output2, -+ qkv_buf, -+ bias_qkv, -+ padding_offset_, -+ d_sequence_length_, -+ d_sequence_length2_, -+ batch_size_, -+ h_token_num_, -+ src_seq_len_, -+ head_num_, -+ head_size_, -+ incremental_mode_, -+ !(typeid(FusedCutlassMhaDispatch) == typeid(*dispatch_)), -+ stream); -+ -+ } -+ else if (typeid(FusedCutlassMhaDispatch) == typeid(*dispatch_)) { -+ invokeCrossAddFusedQKVBiasTransposeMBVSL(q_buf_2, -+ output1, -+ output2, -+ qkv_buf, -+ bias_qkv, -+ padding_offset_, -+ padding_offset2_, -+ d_sequence_length_, -+ d_sequence_length2_, -+ h_token_num_, -+ h_token_num2_, -+ batch_size_, -+ src_seq_len_, -+ tgt_seq_len_, -+ head_num_, -+ head_size_, -+ stream); -+ } else { -+ invokeCrossAddFusedZP_QKVBiasTranspose(q_buf_2, -+ output1, -+ output2, -+ qkv_buf, -+ bias_qkv, -+ batch_size_, -+ src_seq_len_, -+ tgt_seq_len_, -+ head_num_, -+ head_size_, -+ h_token_num_, -+ h_token_num2_, -+ padding_offset_, -+ padding_offset2_, -+ stream); -+ } -+ } -+ } else { // end of is_cross -+ T* weight_qkv = reinterpret_cast(inputs[in_idx_++]); -+ CublasGemmWrapper(weight_qkv, -+ from_tensor, -+ qkv_buf, -+ gemm_dims, -+ gemm_lds, -+ gemm_ops, -+ const_cast(gemm_data_types), -+ &alpha, -+ &beta, -+ cublas_handle, -+ algo_); -+ T* bias_qkv = (qkv_bias_) ? reinterpret_cast(inputs[in_idx_++]) : nullptr; -+ if (padding_offset_ == nullptr) { -+ if (use_past_) { -+ if (incremental_mode_) { -+ output1 += (cur_token_id_) * head_size_; -+ output2 += (cur_token_id_) * head_size_; -+ } -+ invokeAddFusedQKVBiasTransposeUsePast(static_cast(q_buf_2), -+ static_cast(output1), -+ static_cast(output2), -+ static_cast(qkv_buf), -+ bias_qkv, -+ src_seq_len_, -+ h_token_num_, -+ head_num_, -+ head_size_, -+ stream); -+ // restore cache to pointer start after concat -+ output1 = reinterpret_cast(k_cache_); -+ output2 = reinterpret_cast(v_cache_); -+ } else { -+ invokeAddFusedQKVBiasTranspose(static_cast(q_buf_2), -+ static_cast(output1), -+ static_cast(output2), -+ static_cast(qkv_buf), -+ bias_qkv, -+ batch_size_, -+ src_seq_len_, -+ head_num_, -+ head_size_, -+ 0, -+ stream); -+ } -+ } else { -+ if (use_past_) { -+ invokeAddFusedQKVBiasTransposeUsePastMB(q_buf_2, -+ output1, -+ output2, -+ qkv_buf, -+ bias_qkv, -+ padding_offset_, -+ d_sequence_length_, -+ d_sequence_length2_, -+ batch_size_, -+ h_token_num_, -+ src_seq_len_, -+ head_num_, -+ head_size_, -+ incremental_mode_, -+ !(typeid(FusedCutlassMhaDispatch) == typeid(*dispatch_)), -+ stream); -+ } else if (typeid(FusedCutlassMhaDispatch) == typeid(*dispatch_)) { -+ invokeAddFusedQKVBiasTransposeMBVSL(static_cast(q_buf_2), -+ static_cast(output1), -+ static_cast(output2), -+ static_cast(qkv_buf), -+ bias_qkv, -+ padding_offset_, -+ d_sequence_length_, -+ batch_size_, -+ src_seq_len_, -+ h_token_num_, -+ head_num_, -+ head_size_, -+ stream); -+ } else { -+ invokeAddFusedZP_QKVBiasTranspose(static_cast(q_buf_2), -+ static_cast(output1), -+ static_cast(output2), -+ static_cast(qkv_buf), -+ bias_qkv, -+ batch_size_, -+ src_seq_len_, -+ head_num_, -+ head_size_, -+ h_token_num_, -+ padding_offset_, -+ stream); -+ } -+ } -+ } -+ // Do Softmax(Q*Kt + Bias + Mask) -+ T* attention_mask = (mask_) ? reinterpret_cast(inputs[in_idx_++]) : nullptr; -+ T* position_bias = (position_bias_) ? reinterpret_cast(inputs[in_idx_++]) : nullptr; -+ if (attention_mask && padding_offset_ != nullptr && typeid(FusedCutlassMhaDispatch) != typeid(dispatch_) && !(use_past_)) { -+ invokeBuildEncoderAttentionMask( -+ attention_mask, d_sequence_length2_, d_sequence_length_, batch_size_, src_seq_len_, tgt_seq_len_, incremental_mode_, stream); -+ } -+ std::vector dispatch_in = {q_buf_2, output1, output2, attention_mask, position_bias}; -+ std::vector dispatch_out = {qkv_buf_3}; -+ dispatch_->forward(dispatch_in, dispatch_out, ws, cublas_handle, stream); -+ gemm_ops[0] = CUBLAS_OP_N; -+ gemm_ops[1] = CUBLAS_OP_N; -+ gemm_dims[0] = hidden_size_; -+ gemm_dims[1] = h_token_num_; -+ gemm_dims[2] = actual_hidden_size; -+ -+ gemm_lds[0] = hidden_size_; -+ gemm_lds[1] = actual_hidden_size; -+ gemm_lds[2] = hidden_size_; -+ -+ CublasGemmWrapper(reinterpret_cast(inputs[in_idx_++]), -+ qkv_buf_3, -+ static_cast(outputs[0]), -+ gemm_dims, -+ gemm_lds, -+ gemm_ops, -+ const_cast(gemm_data_types), -+ &alpha, -+ &beta, -+ cublas_handle, -+ algo_); -+ if (projection_bias_) { -+ int len = h_token_num_; -+ invokeAddBias( -+ static_cast(outputs[0]), (const T*)(inputs[in_idx_++]), len, hidden_size_, stream); -+ } -+ return; -+} -+ -+} // namespace fastertransformer -diff --git a/src/fastertransformer/layers/ms_layers/attention.h b/src/fastertransformer/layers/ms_layers/attention.h -new file mode 100644 -index 0000000..567947c ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/attention.h -@@ -0,0 +1,411 @@ -+#pragma once -+ -+#include "src/fastertransformer/kernels/activation_kernels.h" -+#include "src/fastertransformer/layers/ms_layers/MSBaseLayer.h" -+#include "src/fastertransformer/layers/ms_layers/gemm.h" -+#include "src/fastertransformer/layers/ms_layers/BaseLayer.h" -+#include -+#include -+ -+namespace fastertransformer { -+class MhaDispatch : public BaseLayerMS{ -+protected: -+ bool position_bias_; -+ float scale_; -+ bool mask_; -+ bool is_cross_; -+ int* padding_offset_{nullptr}; -+ int* d_sequence_length_{nullptr}; -+ int* padding_offset2_{nullptr}; -+ int* d_sequence_length2_{nullptr}; -+ size_t data_parallel_{false}; -+ int cur_token_id_{0}; // current token id id -+ size_t h_token_num_; -+ size_t h_token_num2_; -+ bool incremental_mode_{false}; -+ bool use_past_{false}; // use past mode -+ int* d_sequence_length_host_; -+ int* d_sequence_length2_host_; -+ -+public: -+ typedef enum Type { -+ UnFused, -+ CutlassFix -+ } Type; -+ size_t qkv_buf_{0}; -+ size_t q_buf_2_{0}; -+ size_t output1_{0}; -+ size_t output2_{0}; -+ size_t qk_buf_{0}; -+ size_t qkv_buf_2_{0}; -+ size_t qkv_buf_3_{0}; -+ size_t mha_{0}; -+ size_t GetBatchSize() {return batch_size_;} -+ size_t GetSrcSeqLen() {return src_seq_len_;} -+ size_t GetTgtSeqLen() {return tgt_seq_len_;} -+ size_t GetHeadNum() {return head_num_;} -+ size_t GetHeadSize() {return head_size_;} -+ size_t GetHiddenSize() {return hidden_size_;} -+ bool GetIsCross() {return is_cross_;} -+ float GetScale() {return scale_;} -+ size_t GetHTokenNum() {return h_token_num_;} -+ bool GetPositionBias() {return position_bias_;} -+ bool GetIncrementalMode() {return incremental_mode_;} -+ int GetCurTokenId() {return cur_token_id_;} -+ bool GetUsePast() {return use_past_;} -+ virtual void SetVslParam(int* padding_offset = nullptr, int* padding_offset2 = nullptr, int* d_sequence_length = nullptr, int* d_sequence_length2 = nullptr) -+ { -+ padding_offset_ = padding_offset; -+ padding_offset2_ = padding_offset2; -+ d_sequence_length_ = d_sequence_length; -+ d_sequence_length2_ = d_sequence_length2; -+ } -+ virtual void SetCurTokenId(int cur_token_id) -+ { -+ cur_token_id_ = cur_token_id; -+ } -+ virtual void SetHTokenNum(size_t h_token_num, size_t h_token_num2 = -1) -+ { -+ h_token_num_ = h_token_num; -+ h_token_num2_ = h_token_num2; -+ } -+ virtual void SetCross(bool cross) {is_cross_ = cross;} -+ virtual void SetIncrementalMode(bool incremental_mode) -+ { -+ incremental_mode_ = incremental_mode; -+ } -+ virtual void SetScale(float scale) {scale_ = scale;} -+ virtual void SetUsePast(bool use_past) -+ { -+ use_past_ = use_past; -+ } -+ void SetBuffers(size_t qkv_buf = 0, -+ size_t q_buf_2 = 0, -+ size_t output1 = 0, -+ size_t output2 = 0, -+ size_t qk_buf = 0, -+ size_t qkv_buf_2 = 0, -+ size_t qkv_buf_3 = 0, -+ size_t mha = 0) -+ { -+ qkv_buf_ = qkv_buf; -+ q_buf_2_ = q_buf_2; -+ output1_ = output1; -+ output2_ = output2; -+ qk_buf_ = qk_buf; -+ qkv_buf_2_ = qkv_buf_2; -+ qkv_buf_3_ = qkv_buf_3; -+ mha_ = mha; -+ } -+ virtual void SetOption(bool position_bias = false, bool mask = true) -+ { -+ position_bias_ = position_bias; -+ mask_ = mask; -+ } -+ MhaDispatch(size_t batch_size, -+ size_t src_seq_len, -+ size_t tgt_seq_len, -+ size_t head_num, -+ size_t head_size, -+ size_t hidden_size, -+ bool is_cross_, -+ bool position_bias = false, -+ float scale = 1.0f, -+ bool mask = true, -+ bool use_past = false, -+ int rank_num = 1, -+ cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP); -+ virtual bool isSupport() -+ { -+ return true; -+ } -+ virtual void SetFuseWS(void* ws){} -+ virtual size_t GetWorkspaceSize() override {return 0;} -+}; -+template -+class FusedCutlassMhaDispatch: public MhaDispatch { -+private: -+ std::shared_ptr fuse_mha_; -+public: -+ FusedCutlassMhaDispatch(size_t batch_size, -+ size_t src_seq_len, -+ size_t tgt_seq_len, -+ size_t head_num, -+ size_t head_size, -+ size_t hidden_size, -+ bool is_cross_, -+ bool position_bias = false, -+ float scale = 1.0f, -+ bool mask = true, -+ bool use_past = false, -+ int rank_num = 1, -+ cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP); -+ void SetVslParam(int* padding_offset = nullptr, int* padding_offset2 = nullptr, int* d_sequence_length = nullptr, int* d_sequence_length2 = nullptr) override -+ { -+ padding_offset_ = padding_offset; -+ padding_offset2_ = padding_offset2; -+ d_sequence_length_ = d_sequence_length; -+ d_sequence_length2_ = d_sequence_length2; -+ fuse_mha_->SetVslParam(padding_offset, padding_offset2, d_sequence_length, d_sequence_length2); -+ } -+ void SetOption(bool position_bias = false, bool mask = true) override -+ { -+ mask_ = mask; -+ position_bias_ = position_bias; -+ fuse_mha_->SetOption(position_bias, mask); -+ } -+ void SetCross(bool cross) override -+ { -+ is_cross_ = cross; -+ fuse_mha_->SetCross(cross); -+ } -+ void SetHTokenNum(size_t h_token_num, size_t h_token_num2 = -1) override -+ { -+ h_token_num_ = h_token_num; -+ h_token_num2_ = h_token_num2; -+ fuse_mha_->SetHTokenNum(h_token_num, h_token_num2); -+ } -+ void SetIncrementalMode(bool incremental_mode) override -+ { -+ incremental_mode_ = incremental_mode; -+ fuse_mha_->SetIncrementalMode(incremental_mode); -+ } -+ void SetScale(float scale) override -+ { -+ scale_ = scale; -+ fuse_mha_->SetScale(scale); -+ } -+ void SetUsePast(bool use_past) override -+ { -+ use_past_ = use_past; -+ fuse_mha_->SetUsePast(use_past); -+ } -+ void SetCurTokenId(int cur_token_id) override -+ { -+ cur_token_id_ = cur_token_id; -+ fuse_mha_->SetCurTokenId(cur_token_id); -+ } -+ bool isSupport() override; -+ size_t GetWorkspaceSize() override; -+ void forward(std::vector &inputs, const std::vector &outputs, void *ws, cublasHandle_t cublas_handle = nullptr, cudaStream_t stream = 0) override; -+}; -+template -+class UnfusedMhaDispatch: public MhaDispatch { -+public: -+ UnfusedMhaDispatch(size_t batch_size, -+ size_t src_seq_len, -+ size_t tgt_seq_len, -+ size_t head_num, -+ size_t head_size, -+ size_t hidden_size, -+ bool is_cross = false, -+ bool position_bias = false, -+ float scale = 1.0f, -+ bool mask = true, -+ bool use_past = false, -+ int rank_num = 1, -+ cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP) : -+ MhaDispatch(batch_size, -+ src_seq_len, -+ tgt_seq_len, -+ head_num, -+ head_size, -+ hidden_size, -+ is_cross, -+ position_bias, -+ scale, -+ mask, -+ use_past, -+ rank_num, -+ algo) { -+ d_sequence_length_host_ = (int*)malloc(batch_size_ * sizeof(int)); -+ d_sequence_length2_host_ = (int*)malloc(batch_size_ * sizeof(int)); -+ -+ } -+ void SetVslParam(int* padding_offset = nullptr, int* padding_offset2 = nullptr, int* d_sequence_length = nullptr, int* d_sequence_length2 = nullptr) override -+ { -+ padding_offset_ = padding_offset; -+ padding_offset2_ = padding_offset2; -+ d_sequence_length_ = d_sequence_length; -+ d_sequence_length2_ = d_sequence_length2; -+ if (d_sequence_length_ != nullptr) { -+ cudaD2Hcpy(d_sequence_length_host_, d_sequence_length_, batch_size_); -+ } -+ if (d_sequence_length2_ != nullptr) { -+ cudaD2Hcpy(d_sequence_length2_host_, d_sequence_length2_, batch_size_); -+ } -+ } -+ size_t GetWorkspaceSize() override; -+ void forward(std::vector &inputs, const std::vector &outputs, void *ws, cublasHandle_t cublas_handle = nullptr, cudaStream_t stream = 0) override; -+}; -+template -+class Attention : public BaseLayerMS { -+private: -+ std::shared_ptr dispatch_; -+ -+ bool qkv_bias_; // ture -+ bool projection_bias_; // ture -+ bool is_cross_; // false -+ bool position_bias_; -+ bool mask_; -+ bool use_past_; // use past mode -+ MhaDispatch::Type fmha_type_; -+ size_t data_parallel_{false}; -+ int cur_token_id_{0}; // current token id id -+ size_t h_token_num_; -+ size_t h_token_num2_; -+ void* k_cache_{nullptr}; -+ void* v_cache_{nullptr}; -+ bool incremental_mode_{false}; -+ size_t qkv_buf_{0}; -+ size_t q_buf_2_{0}; -+ size_t output1_{0}; -+ size_t output2_{0}; -+ size_t qk_buf_{0}; -+ size_t qkv_buf_2_{0}; -+ size_t qkv_buf_3_{0}; -+ size_t mha_{0}; -+ int* padding_offset_{nullptr}; -+ int* d_sequence_length_{nullptr}; -+ int* padding_offset2_{nullptr}; -+ int* d_sequence_length2_{nullptr}; -+public: -+ void printParam() -+ { -+ std::cout<<"attn param\n"; -+ std::cout<<"batch_size = "<"; -+} -+ -+template -+void check(T result, char const* const func, const char* const file, int const line) -+{ -+ if (result) { -+ throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: ") + (_cudaKernalGetErrorEnum(result)) + " " -+ + file + ":" + std::to_string(line) + " \n"); -+ } -+} -+ -+#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__) -+ -+template -+__inline__ __device__ -+T gelu(T x) -+{ -+ float cdf = 0.5f * (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x)))); -+ return x * cdf; -+} -+ -+template <> -+__inline__ __device__ -+half2 gelu(half2 val) -+{ -+ half2 val_pow3 = __hmul2(val, __hmul2(val, val)); -+ float2 tmp_pow = __half22float2(val_pow3); -+ float2 tmp = __half22float2(val); -+ -+ tmp.x = 0.5f * (1.0f + tanhf((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x)))); -+ tmp.y = 0.5f * (1.0f + tanhf((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y)))); -+ return __hmul2(val, __float22half2_rn(tmp)); -+ -+} -+ -+template -+__global__ -+void add_bias_act(T* out, const T* bias, int m, int k, int n) -+{ -+ T val, reg_bias; -+ -+ int ite = n / blockDim.x; -+ int tid = threadIdx.x; -+ -+ for(int i = 0; i < ite; ++i) -+ { -+ int row_id = blockIdx.x; -+ while(row_id < m){ -+ reg_bias = __ldg(&bias[row_id / k * n + i * blockDim.x + tid]); -+ -+ val = out[tid + i * blockDim.x + row_id * n]+ reg_bias; -+ out[tid + i * blockDim.x + row_id * n] = gelu(val); -+ row_id += gridDim.x; -+ } -+ } -+} -+ -+template <> -+__global__ -+void add_bias_act(half* out, const half* bias, int m, int k, int n) -+{ -+ half2 val, reg_bias; -+ int ite = n / blockDim.x / 2; -+ int tid = threadIdx.x; -+ -+ half2* out_ptr = (half2*) out; -+ const half2* bias_ptr = (half2*) bias; -+ for(int i = 0; i < ite; ++i) -+ { -+ int row_id = blockIdx.x; -+ while(row_id < m){ -+ reg_bias = __ldg(&bias_ptr[row_id / k * n + i * blockDim.x + tid]); -+ val = out_ptr[tid + i * blockDim.x + row_id * n / 2]; -+ val = __hadd2(val, reg_bias); -+ out_ptr[tid + i * blockDim.x + row_id * n / 2] = gelu(val); -+ row_id += gridDim.x; -+ } -+ } -+} -+ -+template -+void add_bias_act_kernelLauncher(T* out, const T* bias, int m, int k, int n, cudaStream_t stream) -+{ -+ dim3 grid(m * k / 4.); -+ dim3 block(n / 16); -+ assert(block.x <= 1024); -+ add_bias_act<<>>(out, bias, m * k, k, n); -+} -+ -+__device__ void ScanWarp(int32_t* shm_data) { -+ volatile int32_t* vshm_data = shm_data; -+ vshm_data[0] += vshm_data[-1]; -+ vshm_data[0] += vshm_data[-2]; -+ vshm_data[0] += vshm_data[-4]; -+ vshm_data[0] += vshm_data[-8]; -+ vshm_data[0] += vshm_data[-16]; -+} -+ -+__device__ void ScanBlock(int32_t* shm_data) { -+ int32_t warp_id = threadIdx.x >> 5; -+ int32_t lane = threadIdx.x & 31; -+ extern __shared__ int32_t warp_sum[]; // 16 zero padding -+ // scan each warp -+ ScanWarp(shm_data); -+ __syncthreads(); -+ // write sum of each warp to warp_sum -+ if (lane == 31) { -+ warp_sum[16 + warp_id] = *shm_data; -+ } -+ __syncthreads(); -+ // use a single warp to scan warp_sum -+ if (warp_id == 0) { -+ ScanWarp(warp_sum + 16 + lane); -+ } -+ __syncthreads(); -+ // add base -+ if (warp_id > 0) { -+ *shm_data += warp_sum[16 + warp_id - 1]; -+ } -+ __syncthreads(); -+} -+ -+__global__ void ScanAndWritePartSumKernel(const int32_t* input, -+ int32_t* output, size_t n, -+ size_t part_num, size_t shared_num) { -+ // the first 16 + 32 is used to save warp sum -+ extern __shared__ int32_t shm[]; -+ int32_t warp_id = threadIdx.x >> 5; -+ int32_t lane = threadIdx.x & 31; -+ for (int tid = threadIdx.x; tid < shared_num; tid += blockDim.x) { -+ shm[tid] = 0; -+ } -+ __syncthreads(); -+ // process each part -+ for (size_t part_i = blockIdx.x; part_i < part_num; part_i += gridDim.x) { -+ // store this part input to shm -+ size_t index = part_i * blockDim.x + threadIdx.x; -+ int32_t* myshm = shm + (16 + 32) + warp_id * (16 + 32) + 16 + lane; -+ *myshm = index < n ? input[index] : 0; -+ __syncthreads(); -+ // scan on shared memory -+ ScanBlock(myshm); -+ __syncthreads(); -+ // write result -+ if (index < n) { -+ output[index] = *myshm; -+ } -+ } -+} -+ -+__global__ void ScanAndWritePartSumKernel2(const int32_t* input, int32_t* output, size_t n, -+ size_t part_size) { -+ size_t part_begin = part_size * blockIdx.x; -+ size_t part_end = min(part_size * (blockIdx.x + 1), n); -+ int32_t acc = 0; -+ for (size_t i = part_begin; i < part_end; ++i) { -+ acc += input[i]; -+ output[i] = acc; -+ } -+} -+ -+void ScanThenFan(int32_t* input, int32_t* buffer, int32_t* output, -+ size_t n, size_t length, cudaStream_t stream) { -+ size_t part_size = length; -+ size_t part_num = (n + part_size - 1) / part_size; -+ size_t block_num = std::min(part_num, 128); -+ size_t warp_num = (part_size + 31) / 32; -+ size_t shm_num = 16 + 32 + warp_num * (16 + 32); -+ size_t shm_size = shm_num * sizeof(int32_t); -+ ScanAndWritePartSumKernel<<>>(input, output, n, part_num, shm_num); -+} -+ -+template -+__global__ void OneHotTransposeFusionKernel(const T* in, T* out1, T* out2, int batch_size, int length, int expert_num) { -+ //__shared__ T s_mem[2048]; -+ for (int tid = threadIdx.x + blockIdx.x * blockDim.x; tid < batch_size * length * expert_num; tid += blockDim.x * gridDim.x) { -+ int batchlength_idx = tid / expert_num; -+ int expert_idx = tid % expert_num; -+ int batch_idx = batchlength_idx / length; -+ int length_idx = batchlength_idx % length; -+ bool is_on = (in[batch_idx * length + length_idx] == expert_idx); -+ out1[tid] = static_cast(is_on); -+ out2[length_idx + expert_idx * length + batch_idx * expert_num * length] = static_cast(is_on); -+ } -+} -+ -+template -+void OneHotTransposeFusionKernelLaunch(const T* in, T* out1, T* out2, int batch_size, int length, int expert_num, cudaStream_t stream) { -+ if (length != 1) { -+ dim3 grid(batch_size * expert_num); -+ dim3 block(1024); -+ OneHotTransposeFusionKernel<<>>(in, out1, out2, batch_size, length, expert_num); -+ } else { -+ dim3 grid(batch_size); -+ dim3 block(batch_size * expert_num); -+ OneHotTransposeFusionKernel<<>>(in, out1, out2, batch_size, length, expert_num); -+ } -+} -+ -+template -+__global__ void MulLessCastMulReduceMulOnehotMulFusionKernel(T* in1, T* in2, S* out, int batch_size, int length, int expert_num, int max_expert_num, S threshold) { -+ __shared__ T s_mem[512][16 + 1]; -+ for (int tid = threadIdx.x + blockIdx.x * blockDim.x; tid < batch_size * length * expert_num; tid += blockDim.x * gridDim.x) { -+ int batch_idx = tid / length / expert_num; -+ int length_idx = tid / expert_num % length; -+ int expert_idx = tid % expert_num; -+ T val_in1 = in1[tid]; -+ // fuse transpose -+ T val_in2 = in2[batch_idx * length * expert_num + expert_idx * length + length_idx]; -+ T mul1 = val_in1 * val_in2; -+ T mul2 = static_cast(mul1 < threshold) * val_in1; -+ s_mem[(batch_idx * length + length_idx) % 512][expert_idx] = mul2; -+ __syncthreads(); -+ if (expert_idx < 8) { -+ s_mem[(batch_idx * length + length_idx) % 512][expert_idx] += s_mem[(batch_idx * length + length_idx) % 512][expert_idx + 8]; -+ } -+ __syncthreads(); -+ if (expert_idx < 4) { -+ s_mem[(batch_idx * length + length_idx) % 512][expert_idx] += s_mem[(batch_idx * length + length_idx) % 512][expert_idx + 4]; -+ } -+ __syncthreads(); -+ if (expert_idx < 2) { -+ s_mem[(batch_idx * length + length_idx) % 512][expert_idx] += s_mem[(batch_idx * length + length_idx) % 512][expert_idx + 2]; -+ } -+ __syncthreads(); -+ if (expert_idx < 1) { -+ s_mem[(batch_idx * length + length_idx) % 512][expert_idx] += s_mem[(batch_idx * length + length_idx) % 512][expert_idx + 1]; -+ } -+ __syncthreads(); -+ -+ for (int i = 0; i != max_expert_num; ++i) { -+ out[tid * max_expert_num + i] = static_cast(0); -+ } -+ out[tid * max_expert_num + mul1] = static_cast(s_mem[length_idx % 512][0]); -+ __syncthreads(); -+ } -+} -+ -+template <> -+__global__ void MulLessCastMulReduceMulOnehotMulFusionKernel(int* in1, int* in2, half* out, int batch_size, int length, int expert_num, int max_expert_num, half threshold) { -+ __shared__ int s_mem[512][16 + 1]; -+ for (int tid = threadIdx.x + blockIdx.x * blockDim.x; tid < batch_size * length * expert_num; tid += blockDim.x * gridDim.x) { -+ int batch_idx = tid / length / expert_num; -+ int length_idx = tid / expert_num % length; -+ int expert_idx = tid % expert_num; -+ int val_in1 = in1[tid]; -+ // fuse transpose -+ int val_in2 = in2[batch_idx * length * expert_num + expert_idx * length + length_idx]; -+ int mul1 = val_in1 * val_in2; -+ int mul2 = static_cast(mul1 < __half2int_rn(threshold)) * val_in1; -+ s_mem[(batch_idx * length + length_idx) % 512][expert_idx] = mul2; -+ __syncthreads(); -+ if (expert_idx < 8) { -+ s_mem[(batch_idx * length + length_idx) % 512][expert_idx] += s_mem[(batch_idx * length + length_idx) % 512][expert_idx + 8]; -+ } -+ __syncthreads(); -+ if (expert_idx < 4) { -+ s_mem[(batch_idx * length + length_idx) % 512][expert_idx] += s_mem[(batch_idx * length + length_idx) % 512][expert_idx + 4]; -+ } -+ __syncthreads(); -+ if (expert_idx < 2) { -+ s_mem[(batch_idx * length + length_idx) % 512][expert_idx] += s_mem[(batch_idx * length + length_idx) % 512][expert_idx + 2]; -+ } -+ __syncthreads(); -+ if (expert_idx < 1) { -+ s_mem[(batch_idx * length + length_idx) % 512][expert_idx] += s_mem[(batch_idx * length + length_idx) % 512][expert_idx + 1]; -+ } -+ __syncthreads(); -+ -+ for (int i = 0; i != max_expert_num; ++i) { -+ out[tid * max_expert_num + i] = (half)0.f; -+ } -+ out[tid * max_expert_num + mul1] = __int2half_rd(s_mem[length_idx % 512][0]); -+ } -+} -+ -+template -+void MulLessCastMulReduceMulOnehotMulFusionKernelLaunch(T* in1, T* in2, S* out, int batch_size, int length, int expert_num, int onehot_size, S threshold, cudaStream_t stream) { -+ if (length != 1) { -+ dim3 grid(expert_num); -+ dim3 block(512); -+ MulLessCastMulReduceMulOnehotMulFusionKernel<<>>(in1, in2, out, batch_size, length, expert_num, onehot_size, threshold); -+ } else { -+ dim3 grid(expert_num); -+ dim3 block(512); -+ MulLessCastMulReduceMulOnehotMulFusionKernel<<>>(in1, in2, out, batch_size, length, expert_num, onehot_size, threshold); -+ } -+} -+ -+template -+__global__ void AddBiasTransposeTransposeFusionKernel(const T* in, const T* bias, T* out, int ori_s1, int ori_s2, int ori_s3, int trans_s1, int trans_s2, int trans2_s1, int trans2_s2) { -+ for (int tid = threadIdx.x + blockIdx.x * blockDim.x; tid < trans_s1 * trans_s2; tid += blockDim.x * gridDim.x) { -+ int expert_idx = tid / (ori_s2 * ori_s3); -+ int hidden_idx = tid % ori_s3; -+ T in_val = in[tid] + bias[expert_idx * ori_s3 + hidden_idx]; -+ -+ int ori_idx1 = tid / trans_s1; -+ int ori_idx2 = tid % trans_s1; -+ -+ int trans1_idx = ori_idx2 * trans_s2 + ori_idx1; -+ -+ int trans2_idx1 = trans1_idx / trans2_s1; -+ int trans2_idx2 = trans1_idx % trans2_s1; -+ int trans2_idx = trans2_idx2 * trans2_s2 + trans2_idx1; -+ out[trans2_idx] = in_val; -+ } -+} -+ -+template <> -+__global__ void AddBiasTransposeTransposeFusionKernel(const half* in, const half* bias, half* out, int ori_s1, int ori_s2, int ori_s3, int trans_s1, int trans_s2, int trans2_s1, int trans2_s2) { -+ for (int tid = threadIdx.x + blockIdx.x * blockDim.x; tid < trans_s1 * trans_s2; tid += blockDim.x * gridDim.x) { -+ int expert_idx = tid / (ori_s2 * ori_s3); -+ int hidden_idx = tid % ori_s3; -+ half in_val = __hadd(in[tid], bias[expert_idx * ori_s3 + hidden_idx]); -+ -+ int ori_idx1 = tid / trans_s1; -+ int ori_idx2 = tid % trans_s1; -+ -+ int trans1_idx = ori_idx2 * trans_s2 + ori_idx1; -+ -+ int trans2_idx1 = trans1_idx / trans2_s1; -+ int trans2_idx2 = trans1_idx % trans2_s1; -+ int trans2_idx = trans2_idx2 * trans2_s2 + trans2_idx1; -+ out[trans2_idx] = in_val; -+ } -+} -+ -+template -+void AddBiasTransposeTransposeFusionKernelLaunch(const T* in, const T* bias, T* out, int ori_s1, int ori_s2, int ori_s3, int trans_s1, int trans_s2, int trans2_s1, int trans2_s2, cudaStream_t stream) { -+ dim3 grid(trans_s1 * trans_s2 / 4); -+ dim3 block(1024); -+ AddBiasTransposeTransposeFusionKernel<<>>(in, bias, out, ori_s1, ori_s2, ori_s3, trans_s1, trans_s2, trans2_s1, trans2_s2); -+} -+ -+template <> -+void AddBiasTransposeTransposeFusionKernelLaunch(const half* in, const half* bias, half* out, int ori_s1, int ori_s2, int ori_s3, int trans_s1, int trans_s2, int trans2_s1, int trans2_s2, cudaStream_t stream) { -+ dim3 grid(trans_s1 * trans_s2 / 4); -+ dim3 block(1024); -+ AddBiasTransposeTransposeFusionKernel<<>>(in, bias, out, ori_s1, ori_s2, ori_s3, trans_s1, trans_s2, trans2_s1, trans2_s2); -+} -+ -+#endif -diff --git a/src/fastertransformer/layers/ms_layers/debug_utils.cc b/src/fastertransformer/layers/ms_layers/debug_utils.cc -new file mode 100644 -index 0000000..6fc1330 ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/debug_utils.cc -@@ -0,0 +1,140 @@ -+#include -+#include -+#include -+#include "src/fastertransformer/utils/memory_utils.h" -+#include "src/fastertransformer/layers/ms_layers/debug_utils.h" -+#include -+namespace fastertransformer { -+ -+ -+template -+void printTensor(char* str, T* input, int size) -+{ -+ printf("%s ", str); -+ T* input_device = input; -+ T* input_host = (T*)malloc(size * sizeof(T)); -+ -+ cudaD2Hcpy(input_host, input_device, size); -+ -+ for (int k = 0; k < (int)size; k++) { -+ std::cout << input_host[k] << ","; -+ if (k % 16 == 0 && k != 0) -+ std::cout << std::endl; -+ } -+ -+ std::cout << std::endl; -+ -+ free(input_host); -+} -+int GetSeq(int* d_seq_len, int idx, int batch) -+{ -+ int* input_device = d_seq_len; -+ int* input_host = (int*)malloc(batch * sizeof(int)); -+ -+ cudaD2Hcpy(input_host, input_device, batch); -+ int num = input_host[idx]; -+ free(input_host); -+ return num; -+} -+template -+void isNan(char* str, T* input, int size) -+{ -+ std::cout << str << " " -+ << " size is " << size; -+ T* input_device = input; -+ T* input_host = (T*)malloc(size * sizeof(T)); -+ -+ cudaD2Hcpy(input_host, input_device, size); -+ -+ for (int k = 0; k < (int)size; k++) { -+ if (std::isnan((T)input_host[k]) || std ::isinf((T)input_host[k])) { -+ std::cout << "found NAN or INF " << k; -+ break; -+ } -+ } -+ -+ std::cout << std::endl; -+ free(input_host); -+} -+template -+T checksum(const T* tensor, int size) -+{ -+ auto tensor_host =(T*)malloc(size * sizeof(T)); -+ double sum = 0.; -+ cudaD2Hcpy(tensor_host, tensor, size); -+ for (int i = 0; i < size; i++) { -+ sum += (double)tensor_host[i]; -+ } -+ return static_cast(sum); -+} -+template -+double checksum2(char* str, const T* tensor, int size) -+ -+{ -+ double sum = 0.; -+ T* ptr = (T*)malloc(size * sizeof(T)); -+ -+ cudaD2Hcpy(ptr, tensor, size); -+ -+ for (int i = 0; i < size; i++) { -+ -+ sum += ptr[i]; -+ -+ } -+ std::cout << "checksum of "<< str << "is " << sum << std::endl; -+ free(ptr); -+ return sum; -+ -+} -+template -+void saveTensor(const std::string& name, T* tensor, int size) -+{ -+ auto tensor_host = std::make_unique(size); -+ T* ptr = tensor_host.get(); -+ cudaD2Hcpy(ptr, tensor, size); -+ std::ofstream wf(name + ".bin", std::ofstream::out | std::ofstream::binary); -+ wf.write(reinterpret_cast(ptr), size * sizeof(T)); -+ wf.close(); -+} -+ -+template -+void saveTensorFile(const std::string& name, T* tensor, int size) { -+ T* input_host = (T*)malloc(size * sizeof(T)); -+ cudaD2Hcpy(input_host, tensor, size); -+ std::ofstream wf(name+ ".bin", std::ios::out | std::ios::binary); -+ if(!wf) { -+ std::cout << "Cannot open file!" << std::endl; -+ return; -+ } -+ wf.write((char *)input_host,sizeof(T)*size); -+ wf.close(); -+} -+uint64_t GetTimeUs() -+{ -+ const int USEC = 1000000; -+ const int MSEC = 1000; -+ struct timespec ts = {0, 0}; -+ if (clock_gettime(CLOCK_MONOTONIC, &ts) != 0) { -+ return 0; -+ } -+ uint64_t retval = (uint64_t)((ts.tv_sec * USEC) + (ts.tv_nsec / MSEC)); -+ return retval; -+} -+template void saveTensorFile(const std::string& name, half* tensor, int size); -+template void saveTensorFile(const std::string& name, float* tensor, int size); -+template void saveTensorFile(const std::string& name, int* tensor, int size); -+template void printTensor(char* str, float* input, int size); -+template void isNan(char* str, float* input, int size); -+template float checksum(const float* tensor, int size); -+template void saveTensor(const std::string& name, float* tensor, int size); -+ -+template void printTensor(char* str, half* input, int size); -+template void isNan(char* str, half* input, int size); -+template half checksum(const half* tensor, int size); -+template void saveTensor(const std::string& name, half* tensor, int size); -+ -+template void printTensor(char* str, int* input, int size); -+template void saveTensor(const std::string& name, int* tensor, int size); -+template double checksum2(char* str, const float* tensor, int size); -+template double checksum2(char* str, const half* tensor, int size); -+} -\ No newline at end of file -diff --git a/src/fastertransformer/layers/ms_layers/debug_utils.h b/src/fastertransformer/layers/ms_layers/debug_utils.h -new file mode 100644 -index 0000000..68ee331 ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/debug_utils.h -@@ -0,0 +1,41 @@ -+#pragma once -+#include -+#include -+#include -+#include -+#include -+#if __has_include("NvInferRuntimeCommon.h") -+#include "NvInferRuntimeCommon.h" -+#else -+namespace nvinfer1 { -+enum class DataType : int32_t -+{ -+ kFLOAT = 0, -+ kHALF = 1, -+ kINT8 = 2, -+ kINT32 = 3, -+ kBOOL = 4 -+}; -+} -+#endif -+namespace fastertransformer { -+ -+#define UP_DIV(x, y) (((x) + (y) - (1)) / (y)) -+#define ALIGN(x, y) (UP_DIV(x, y) * (y)) -+#define ALIGN_SIZE 16 -+ -+template -+void printTensor(char* str, T* input, int size); -+template -+void isNan(char* str, T* input, int size); -+template -+T checksum(const T* tensor, int size); -+template -+void saveTensor(const std::string& name, T* tensor, int size); -+int GetSeq(int* d_seq_len, int idx, int batch); -+template -+double checksum2(char* str, const T* tensor, int size); -+template -+void saveTensorFile(const std::string& name, T* tensor, int size); -+uint64_t GetTimeUs(); -+} -\ No newline at end of file -diff --git a/src/fastertransformer/layers/ms_layers/decoder.cc b/src/fastertransformer/layers/ms_layers/decoder.cc -new file mode 100644 -index 0000000..16cd1e9 ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/decoder.cc -@@ -0,0 +1,375 @@ -+ -+#include "src/fastertransformer/layers/ms_layers/decoder.h" -+#include "src/fastertransformer/kernels/activation_kernels.h" -+#include "src/fastertransformer/kernels/add_residual_kernels.h" -+#include "src/fastertransformer/kernels/bert_preprocess_kernels.h" -+#include "src/fastertransformer/kernels/layernorm_kernels.h" -+#include "src/fastertransformer/kernels/unfused_attention_kernels.h" -+#include "src/fastertransformer/layers/ms_layers/attention.h" -+#include "src/fastertransformer/layers/ms_layers/debug_utils.h" -+#include "src/fastertransformer/layers/ms_layers/ffn.h" -+#include "src/fastertransformer/layers/ms_layers/opt_allocator.h" -+#include -+namespace fastertransformer { -+ -+template -+Decoder::Decoder(size_t batch_size, -+ size_t src_seq_len, -+ size_t tgt_seq_len, -+ size_t head_num, -+ size_t head_size, -+ size_t hidden_size, -+ float eps1, -+ float eps2, -+ float eps3, -+ float eps4, -+ bool layernorm_post, -+ bool has_beta, -+ bool is_layernorm, -+ bool ffn_fp16, -+ bool qkv_bias, -+ bool projection_bias, -+ bool is_cross, -+ bool position_bias, -+ float scale, -+ bool mask, -+ bool ffn_bias, -+ size_t ffn_hidden_size, -+ FfnBase::ActType act_type, -+ cublasGemmAlgo_t algo): -+ layernorm_post_(layernorm_post), -+ has_beta_(has_beta), -+ is_layernorm_(is_layernorm), -+ ffn_fp16_(ffn_fp16), -+ DecoderBase(batch_size, src_seq_len, tgt_seq_len, head_num, head_size, hidden_size, algo) -+{ -+ attention_layer1_ = std::make_shared>(batch_size, -+ src_seq_len, -+ src_seq_len, -+ head_num, -+ head_size, -+ hidden_size, -+ qkv_bias, -+ projection_bias, -+ false, -+ position_bias, -+ scale, -+ mask, -+ false, -+ algo); -+ attention_layer2_ = std::make_shared>(batch_size, -+ src_seq_len, -+ tgt_seq_len, -+ head_num, -+ head_size, -+ hidden_size, -+ qkv_bias, -+ projection_bias, -+ true, -+ position_bias, -+ scale, -+ mask, -+ false, -+ algo); -+ is_ffn_fp16_ = (std::is_same::value && ffn_fp16_ == true); -+ if (is_ffn_fp16_) { -+ ffn_layer_ = std::make_shared>(batch_size, -+ src_seq_len, -+ tgt_seq_len, -+ head_num, -+ head_size, -+ hidden_size, -+ ffn_bias, -+ ffn_hidden_size, -+ act_type, -+ algo); -+ } -+ else { -+ ffn_layer_ = std::make_shared>(batch_size, -+ src_seq_len, -+ tgt_seq_len, -+ head_num, -+ head_size, -+ hidden_size, -+ ffn_bias, -+ ffn_hidden_size, -+ act_type, -+ algo); -+ } -+ layer_norm1_ = std::make_shared>(has_beta_, false, LayerNorm::Type::T5, eps1, algo); -+ layer_norm2_ = -+ std::make_shared>(has_beta_, true, LayerNorm::Type::ADD_BIAS_RESIDUAL_T5_PRE, eps2, algo); -+ if (std::is_same::value || !is_ffn_fp16_) -+ layer_norm3_ = -+ std::make_shared>(has_beta_, true, LayerNorm::Type::ADD_BIAS_RESIDUAL_T5_PRE, eps3, algo); -+ else -+ layer_norm3_ = std::make_shared>( -+ has_beta_, true, LayerNorm::Type::ADD_BIAS_RESIDUAL_T5_PRE_CAST, eps3, algo); -+ if (is_layernorm) -+ layer_norm4_ = std::make_shared>(has_beta_, false, LayerNorm::Type::T5, eps4, algo); -+} -+template -+size_t Decoder::GetWorkspaceSize() -+{ -+ size_t attn_out_size = batch_size_ * src_seq_len_ * hidden_size_; -+ size_t tmp_out_size = attn_out_size; -+ size_t compress_buffer_len = batch_size_ * src_seq_len_ * hidden_size_; -+ size_t compress_buffer_len2 = (src_seq_len_ > tgt_seq_len_) ? batch_size_ * src_seq_len_ * hidden_size_ : -+ batch_size_ * tgt_seq_len_ * hidden_size_; -+ size_t padding_len = batch_size_ * src_seq_len_; -+ size_t padding_len2 = batch_size_ * tgt_seq_len_; -+ -+ OptAllocator allocator(ALIGN_SIZE); -+ d_sequence_lengths_offset_buf_ = allocator.Malloc(batch_size_ * sizeof(int)); -+ padding_offset_buf_ = allocator.Malloc(padding_len * sizeof(int)); -+ d_token_num_buf_ = allocator.Malloc(1 * sizeof(size_t)); -+ d_sequence_lengths_offset_buf2_ = allocator.Malloc(batch_size_ * sizeof(int)); -+ padding_offset_buf2_ = allocator.Malloc(padding_len2 * sizeof(int)); -+ d_token_num_buf2_ = allocator.Malloc(1 * sizeof(size_t)); -+ compress_buf_ = allocator.Malloc(compress_buffer_len * sizeof(T)); -+ compress_buf2_ = allocator.Malloc(compress_buffer_len2 * sizeof(T)); -+ normed_from_tensor_buf_ = allocator.Malloc(attn_out_size * sizeof(T)); -+ attn_ws_buf_ = allocator.Malloc(attention_layer1_->GetWorkspaceSize()); -+ attention_layer1_->SetWSOffset(attn_ws_buf_); -+ attn_out_buf_ = allocator.Malloc(attn_out_size * sizeof(T)); -+ allocator.Free(attn_ws_buf_); -+ if (!layernorm_post_) -+ allocator.Free(normed_from_tensor_buf_); -+ normed_attn_out_buf_ = allocator.Malloc(attn_out_size * sizeof(T)); -+ if (layernorm_post_) -+ allocator.Free(normed_from_tensor_buf_); -+ attn2_ws_buf_ = allocator.Malloc(attention_layer2_->GetWorkspaceSize()); -+ attention_layer2_->SetWSOffset(attn2_ws_buf_); -+ attn2_out_buf_ = allocator.Malloc(attn_out_size * sizeof(T)); -+ allocator.Free(attn2_ws_buf_); -+ normed_attn2_out_buf_ = -+ is_ffn_fp16_ ? allocator.Malloc(attn_out_size * sizeof(half)) : allocator.Malloc(attn_out_size * sizeof(T)); -+ allocator.Free(attn_out_buf_); -+ tmp_out_buf_ = -+ is_ffn_fp16_ ? allocator.Malloc(tmp_out_size * sizeof(half)) : allocator.Malloc(tmp_out_size * sizeof(T)); -+ -+ ffn_ws_buf_ = allocator.Malloc(ffn_layer_->GetWorkspaceSize()); -+ ffn_layer_->SetWSOffset(ffn_ws_buf_); -+ return allocator.total_size(); -+} -+ -+template -+void Decoder::GetCompressBuffer(T* compress_buffer, -+ T* from_tensor, -+ int* d_sequence_lengths, -+ int* padding_offset, -+ size_t& h_token_num, -+ size_t* d_token_num, -+ size_t seq_len, -+ cudaStream_t stream) -+{ -+ invokeGetPaddingOffset(&h_token_num, d_token_num, padding_offset, d_sequence_lengths, batch_size_, seq_len, stream); -+ if (h_token_num * 2 <= batch_size_ * seq_len) { -+ invokeRemovePadding( -+ compress_buffer, (const T*)from_tensor, padding_offset, h_token_num, head_num_ * head_size_, stream); -+ } -+} -+ -+template -+void Decoder::ForwardAttention(std::shared_ptr> attention_layer_, -+ std::vector& inputs, -+ std::vector& from_tensor, -+ std::vector& output, -+ void* ws, -+ cublasHandle_t cublas_handle, -+ cudaStream_t stream) -+{ -+ inputs[--in_idx_] = from_tensor[0]; -+ bool is_projection_bias = attention_layer_->GetProjectionBias(); -+ attention_layer_->SetProjectionBias(false); -+ std::vector attn_in_vector(inputs.begin() + in_idx_, inputs.end()); -+ attention_layer_->forward(attn_in_vector, output, ws, cublas_handle, stream); -+ attention_layer_->SetProjectionBias(is_projection_bias); -+ in_idx_ = attention_layer_->GetIdx() + in_idx_; -+} -+template -+void Decoder::AddBiasResidual(std::vector& inputs, const std::vector& output, cudaStream_t stream) -+{ -+ if (std::is_same::value || !is_ffn_fp16_) { -+ invokeAddBiasResidual(reinterpret_cast(output[0]), -+ reinterpret_cast(inputs[0]), -+ reinterpret_cast(inputs[1]), -+ h_token_num_, -+ hidden_size_, -+ stream); -+ } -+ else { -+ if (layernorm_post_) { -+ invokeAddBiasResidualSameTypeCast(reinterpret_cast(output[0]), -+ reinterpret_cast(inputs[0]), -+ reinterpret_cast(output[1]), -+ reinterpret_cast(inputs[1]), -+ h_token_num_, -+ hidden_size_, -+ stream); -+ } -+ else { -+ invokeAddBiasResidualCast(reinterpret_cast(output[0]), -+ reinterpret_cast(inputs[0]), -+ reinterpret_cast(output[1]), -+ reinterpret_cast(inputs[1]), -+ h_token_num_, -+ hidden_size_, -+ stream); -+ } -+ } -+} -+template -+void Decoder::forward(std::vector& inputs, -+ const std::vector& output, -+ void* ws, -+ cublasHandle_t cublas_handle, -+ cudaStream_t stream) -+{ -+ -+ int in_len = inputs.size(); -+ ws = GetBuf(ws, ws_offset_); -+ in_idx_ = 0; -+ std::vector decoder_in = inputs; -+ size_t h_token_num = h_token_num_ = batch_size_ * src_seq_len_; -+ size_t h_token_num2 = h_token_num2_ = batch_size_ * tgt_seq_len_; -+ SetHTokenNum(h_token_num, h_token_num2); -+ int* padding_offset = nullptr; -+ int* padding_offset2 = nullptr; -+ T* input_tensor = reinterpret_cast(inputs[in_idx_++]); -+ T* from_tensor = input_tensor; -+ -+ int idx_encoder_out = attention_layer1_->GetPositionBias() ? 7 : 10; -+ T* encoder_output = reinterpret_cast(inputs[idx_encoder_out]); -+ int* d_sequence_lengths2 = reinterpret_cast(inputs[in_len - 2]); -+ int* d_sequence_lengths1 = reinterpret_cast(inputs[in_len - 1]); -+ T* compress_buffer = GetBuf(ws, compress_buf_); -+ T* compress_buffer2 = GetBuf(ws, +compress_buf2_); -+ size_t* d_token_num = GetBuf(ws, +d_token_num_buf_); -+ size_t* d_token_num2 = GetBuf(ws, +d_token_num_buf2_); -+ attention_layer1_->SetVslParam(nullptr, nullptr, nullptr, nullptr); -+ attention_layer2_->SetVslParam(nullptr, nullptr, nullptr, nullptr); -+ if (eft_) { -+ padding_offset = GetBuf(ws, padding_offset_buf_); -+ GetCompressBuffer(compress_buffer, -+ from_tensor, -+ d_sequence_lengths1, -+ padding_offset, -+ h_token_num, -+ d_token_num, -+ src_seq_len_, -+ stream); -+ if (batch_size_ > 1) { -+ if (h_token_num * 2 <= batch_size_ * src_seq_len_) { -+ h_token_num_ = h_token_num; -+ from_tensor = compress_buffer; -+ attention_layer1_->SetVslParam( -+ padding_offset, padding_offset, d_sequence_lengths1, d_sequence_lengths1); -+ attention_layer2_->SetVslParam( -+ padding_offset, padding_offset, d_sequence_lengths1, d_sequence_lengths1); -+ } -+ } -+ else { -+ SetHTokenNum(h_token_num, h_token_num2); -+ } -+ padding_offset2 = GetBuf(ws, padding_offset_buf2_); -+ GetCompressBuffer(compress_buffer2, -+ encoder_output, -+ d_sequence_lengths2, -+ padding_offset2, -+ h_token_num2, -+ d_token_num2, -+ tgt_seq_len_, -+ stream); -+ if (h_token_num2 * 2 <= batch_size_ * tgt_seq_len_) { -+ h_token_num2_ = h_token_num2; -+ decoder_in[idx_encoder_out] = compress_buffer2; -+ attention_layer2_->SetVslParam(padding_offset, padding_offset2, d_sequence_lengths1, d_sequence_lengths2); -+ } -+ } -+ SetHTokenNum(h_token_num_, h_token_num2_); -+ h_token_num = h_token_num_; -+ h_token_num2 = h_token_num2_; -+ T* attn_out = GetBuf(ws, attn_out_buf_); -+ T* normed_from_tensor = GetBuf(ws, normed_from_tensor_buf_); -+ -+ T* normed_attn_out = GetBuf(ws, normed_attn_out_buf_); -+ T* attn2_out = GetBuf(ws, attn2_out_buf_); -+ T* normed_attn2_out = GetBuf(ws, normed_attn2_out_buf_); -+ T* tmp_out = reinterpret_cast(output[0]); -+ if (attention_layer1_->GetPaddingOffset() != nullptr || is_ffn_fp16_ == true || is_layernorm_) { -+ tmp_out = GetBuf(ws, tmp_out_buf_); -+ } -+ T* tmp_out1 = reinterpret_cast(output[0]); -+ T* tmp_out2 = reinterpret_cast(output[0]); -+ -+ if (is_layernorm_ && (attention_layer1_->GetPaddingOffset() != nullptr || is_ffn_fp16_ == true)) { -+ tmp_out1 = compress_buffer2; -+ if (attention_layer1_->GetPaddingOffset() != nullptr) { -+ tmp_out2 = compress_buffer; -+ } -+ } -+ else if (attention_layer1_->GetPaddingOffset() != nullptr && is_ffn_fp16_ == true) { -+ tmp_out1 = compress_buffer; -+ } -+ T* out_buf = is_ffn_fp16_ ? tmp_out1 : tmp_out; -+ layer_norm1_->SetParams(h_token_num_, hidden_size_); -+ layer_norm2_->SetParams(h_token_num_, hidden_size_); -+ layer_norm3_->SetParams(h_token_num_, hidden_size_); -+ -+ T* gamma1 = reinterpret_cast(inputs[in_idx_++]); -+ T* beta1 = (has_beta_) ? reinterpret_cast(inputs[in_idx_++]) : nullptr; -+ std::vector in = {from_tensor, gamma1, beta1}; -+ std::vector out = {normed_from_tensor}; -+ layer_norm1_->forward(in, out, ws, cublas_handle, stream); -+ std::vector attn_from_vector{normed_from_tensor}; -+ std::vector attn_out_vector{attn_out}; -+ ForwardAttention(attention_layer1_, decoder_in, attn_from_vector, attn_out_vector, ws, cublas_handle, stream); -+ -+ T* projection_bias = attention_layer1_->GetProjectionBias() ? reinterpret_cast(inputs[in_idx_++]) : nullptr; -+ T* gamma2 = reinterpret_cast(inputs[in_idx_++]); -+ T* beta2 = (has_beta_) ? reinterpret_cast(inputs[in_idx_++]) : nullptr; -+ from_tensor = (layernorm_post_) ? normed_from_tensor : from_tensor; -+ std::vector in2 = {from_tensor, gamma2, beta2, projection_bias}; -+ std::vector out2 = {attn_out, normed_attn_out}; -+ layer_norm2_->forward(in2, out2, ws, cublas_handle, stream); -+ std::vector attn2_from_vector{normed_attn_out}; -+ std::vector attn2_out_vector{attn2_out}; -+ ForwardAttention(attention_layer2_, decoder_in, attn2_from_vector, attn2_out_vector, ws, cublas_handle, stream); -+ -+ T* projection_bias2 = (attention_layer2_->GetProjectionBias()) ? reinterpret_cast(inputs[in_idx_++]) : nullptr; -+ T* gamma3 = reinterpret_cast(inputs[in_idx_++]); -+ T* beta3 = (has_beta_) ? reinterpret_cast(inputs[in_idx_++]) : nullptr; -+ attn_out = (layernorm_post_) ? normed_attn_out : attn_out; -+ std::vector in3 = {attn_out, gamma3, beta3, projection_bias2}; -+ std::vector out3 = {attn2_out, normed_attn2_out}; -+ layer_norm3_->forward(in3, out3, ws, cublas_handle, stream); -+ -+ inputs[--in_idx_] = normed_attn2_out; -+ std::vector ffn_in_vector(inputs.begin() + in_idx_, inputs.end()); -+ std::vector ffn_out_vector{tmp_out}; -+ ffn_layer_->forward(ffn_in_vector, ffn_out_vector, ws, cublas_handle, stream); -+ in_idx_ = ffn_layer_->GetIdx() + in_idx_; -+ -+ attn2_out = (layernorm_post_) ? normed_attn2_out : attn2_out; -+ T* ffn_bias = (ffn_layer_->GetffnBias()) ? reinterpret_cast(inputs[in_idx_++]) : nullptr; -+ std::vector add_residual_in_vector{attn2_out, ffn_bias}; -+ std::vector add_residual_out_vector{tmp_out, tmp_out1}; -+ AddBiasResidual(add_residual_in_vector, add_residual_out_vector, stream); -+ if (is_layernorm_) { -+ T* gamma4 = reinterpret_cast(inputs[in_idx_++]); -+ T* beta4 = (has_beta_) ? reinterpret_cast(inputs[in_idx_++]) : nullptr; -+ std::vector in4 = {out_buf, gamma4, beta4}; -+ std::vector out4 = {tmp_out2}; -+ layer_norm4_->SetParams(h_token_num_, hidden_size_); -+ layer_norm4_->forward(in4, out4, ws, cublas_handle, stream); -+ out_buf = tmp_out2; -+ } -+ if (attention_layer1_->GetPaddingOffset() != nullptr) { -+ cudaMemsetAsync(output[0], 0, batch_size_ * src_seq_len_ * head_size_ * head_num_ * sizeof(T), stream); -+ invokeRebuildPadding( -+ (T*)output[0], out_buf, attention_layer1_->GetPaddingOffset(), h_token_num, hidden_size_, stream); -+ } -+ return; -+} -+} // namespace fastertransformer -diff --git a/src/fastertransformer/layers/ms_layers/decoder.h b/src/fastertransformer/layers/ms_layers/decoder.h -new file mode 100644 -index 0000000..b33b42b ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/decoder.h -@@ -0,0 +1,243 @@ -+#pragma once -+ -+#include "src/fastertransformer/kernels/activation_kernels.h" -+#include "src/fastertransformer/layers/ms_layers/BaseLayer.h" -+#include "src/fastertransformer/layers/ms_layers/MSBaseLayer.h" -+#include "src/fastertransformer/layers/ms_layers/attention.h" -+#include "src/fastertransformer/layers/ms_layers/ffn.h" -+#include "src/fastertransformer/layers/ms_layers/layer_norm.h" -+ -+#include -+#include -+ -+namespace fastertransformer { -+ -+class DecoderBase: public BaseLayerMS { -+public: -+ DecoderBase(size_t batch_size, -+ size_t src_seq_len, -+ size_t tgt_seq_len, -+ size_t head_num, -+ size_t head_size, -+ size_t hidden_size, -+ cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP):BaseLayerMS(batch_size, src_seq_len, tgt_seq_len, head_num, head_size, hidden_size, 1, algo){} -+ bool GetEft(); -+ virtual void SetFfnParam(bool ffn_fp16, size_t ffn_hidden_size, FfnBase::ActType act_type = FfnBase::ActType::Gelu, bool ffn_bias = true) = 0; -+ virtual void SetScaleAttn(float scale = 1.0f) = 0; -+ virtual void SetIsLayerNorm(bool is_layernorm, float eps = 1e-6f) = 0; -+ virtual void SetLayerNormPost(bool layernorm_post) = 0; -+ virtual void SetVSL(bool eft) = 0; -+ virtual void SetT5(bool t5) = 0; -+ virtual void SetEps(float eps1, float eps2, float eps3, float eps4 = 1e-6f) = 0; -+ virtual void SetAlgo(cublasGemmAlgo_t algo) override = 0; -+ virtual void SetHTokenNum(size_t h_token_num, size_t h_token_num2) = 0; -+}; -+template -+class Decoder : public DecoderBase{ -+private: -+ std::shared_ptr> attention_layer1_; -+ std::shared_ptr> attention_layer2_; -+ std::shared_ptr ffn_layer_; -+ std::shared_ptr> layer_norm1_; -+ std::shared_ptr> layer_norm2_; -+ std::shared_ptr> layer_norm3_; -+ std::shared_ptr> layer_norm4_; -+ void ForwardAttention(std::shared_ptr> attention_layer_, std::vector &inputs, std::vector &from_tensor, std::vector &output, void *ws, cublasHandle_t cublas_handle, cudaStream_t stream); -+ void AddBiasResidual(std::vector &inputs, const std::vector &output, cudaStream_t stream); -+ -+ void GetCompressBuffer(T* compress_buffer, -+ T* from_tensor, -+ int* d_sequence_lengths, -+ int* padding_offset, -+ size_t& h_token_num, -+ size_t* d_token_num, -+ size_t seq_len, -+ cudaStream_t stream); -+ -+ bool eft_{false}; -+ size_t h_token_num_; -+ size_t h_token_num2_; -+ bool layernorm_post_; -+ bool has_beta_; -+ bool is_layernorm_; -+ bool ffn_fp16_; -+ bool is_ffn_fp16_{false}; -+ -+ size_t normed_from_tensor_buf_; -+ size_t attn_out_buf_; -+ size_t attn_ws_buf_; -+ size_t attn2_out_buf_; -+ size_t attn2_ws_buf_; -+ size_t tmp_out_buf_; -+ size_t ffn_ws_buf_; -+ size_t normed_attn_out_buf_; -+ size_t normed_attn2_out_buf_; -+ size_t compress_buf_; -+ size_t d_token_num_buf_; -+ size_t padding_offset_buf_; -+ size_t d_sequence_lengths_offset_buf_; -+ size_t compress_buf2_; -+ size_t d_token_num_buf2_; -+ size_t padding_offset_buf2_; -+ size_t d_sequence_lengths_offset_buf2_; -+public: -+ void printParam() -+ { -+ std::cout<<"batch_size = "< -+ -+namespace fastertransformer { -+ -+template -+Encoder::Encoder(size_t batch_size, -+ size_t seq_len, -+ size_t head_num, -+ size_t head_size, -+ size_t hidden_size, -+ float eps1, -+ float eps2, -+ float eps3, -+ bool layernorm_post, -+ bool has_beta, -+ bool is_layernorm, -+ int embedding_size, -+ bool ffn_fp16, -+ bool qkv_bias, -+ bool projection_bias, -+ bool is_cross, -+ bool position_bias, -+ float scale, -+ bool mask, -+ bool ffn_bias, -+ size_t ffn_hidden_size, -+ FfnBase::ActType act_type, -+ bool use_past, -+ bool query_layer, -+ size_t expert_num, -+ int rank_num, -+ cublasGemmAlgo_t algo): -+ layernorm_post_(layernorm_post), -+ has_beta_(has_beta), -+ is_layernorm_(is_layernorm), -+ embedding_size_(embedding_size), -+ ffn_fp16_(ffn_fp16), -+ use_past_(use_past), -+ query_layer_(query_layer), -+ expert_num_(expert_num), -+ EncoderBase(batch_size, seq_len, head_num, head_size, hidden_size, rank_num, algo) -+{ -+ attention_layer_ = std::make_shared>(batch_size, -+ seq_len, -+ seq_len, -+ head_num, -+ head_size, -+ hidden_size, -+ qkv_bias, -+ projection_bias, -+ is_cross, -+ position_bias, -+ scale, -+ mask, -+ use_past, -+ rank_num, -+ algo); -+ is_ffn_fp16_ = (std::is_same::value && ffn_fp16_ == true); -+ if (is_ffn_fp16_) { -+ ffn_layer_ = std::make_shared>(batch_size, -+ seq_len, -+ seq_len, -+ head_num, -+ head_size, -+ hidden_size, -+ ffn_bias, -+ ffn_hidden_size, -+ act_type, -+ rank_num, -+ algo); -+ } -+ else { -+ ffn_layer_ = std::make_shared>(batch_size, -+ seq_len, -+ seq_len, -+ head_num, -+ head_size, -+ hidden_size, -+ ffn_bias, -+ ffn_hidden_size, -+ act_type, -+ rank_num, -+ algo); -+ } -+ if (layernorm_post == false || position_bias) { -+ layer_norm1_ = std::make_shared>(has_beta, false, LayerNorm::Type::T5, eps1, algo); -+ if (std::is_same::value || !is_ffn_fp16_) -+ layer_norm2_ = std::make_shared>( -+ has_beta, true, LayerNorm::Type::ADD_BIAS_RESIDUAL_T5_PRE, eps2, algo); -+ else -+ layer_norm2_ = std::make_shared>( -+ has_beta, true, LayerNorm::Type::ADD_BIAS_RESIDUAL_T5_PRE_CAST, eps2, algo); -+ } -+ else { -+ if (std::is_same::value || !is_ffn_fp16_) -+ layer_norm1_ = -+ std::make_shared>(has_beta, true, LayerNorm::Type::ADD_BIAS_RESIDUAL, eps1, algo); -+ else -+ layer_norm1_ = -+ std::make_shared>(has_beta, true, LayerNorm::Type::ADD_BIAS_RESIDUAL_CAST, eps1, algo); -+ if (std::is_same::value || !is_ffn_fp16_) -+ layer_norm2_ = -+ std::make_shared>(has_beta, true, LayerNorm::Type::ADD_BIAS_RESIDUAL, eps2, algo); -+ else -+ layer_norm2_ = std::make_shared>( -+ has_beta, true, LayerNorm::Type::ADD_BIAS_RESIDUAL_CAST_FFN, eps2, algo); -+ } -+ if (is_layernorm) -+ layer_norm3_ = std::make_shared>(has_beta, false, LayerNorm::Type::T5, eps3, algo); -+ moe_ = std::make_shared( -+ hidden_size, expert_num, ffn_hidden_size, rank_num, seq_len, 1.1, batch_size); -+ seq_len_host_ = (int*)malloc(batch_size * sizeof(int)); -+ first_layer_ = false; -+} -+template -+size_t Encoder::GetWorkspaceSize() -+{ -+ size_t attn_out_len = batch_size_ * src_seq_len_ * hidden_size_; -+ size_t normed_from_tensor_len = batch_size_ * src_seq_len_ * hidden_size_; -+ size_t normed_attn_out_len = batch_size_ * src_seq_len_ * hidden_size_; -+ size_t tmp_out_size = attn_out_len; -+ size_t compress_buffer_len = batch_size_ * src_seq_len_ * hidden_size_; -+ -+ size_t padding_len = batch_size_ * src_seq_len_; -+ OptAllocator allocator(ALIGN_SIZE); -+ if (use_past_) -+ d_sequence_lengths_offset_buf_ = allocator.Malloc(batch_size_ * sizeof(int)); -+ if (use_past_) -+ d_sequence_lengths_offset_buf2_ = allocator.Malloc(batch_size_ * sizeof(int)); -+ -+ padding_offset_buf_ = allocator.Malloc(padding_len * sizeof(int)); -+ padding_offset_buf2_ = allocator.Malloc(padding_len * sizeof(int)); -+ -+ d_token_num_buf_ = allocator.Malloc(1 * sizeof(size_t)); -+ compress_buf_ = allocator.Malloc(compress_buffer_len * sizeof(T)); -+ -+ normed_from_tensor_buf_ = (!layernorm_post_ || attention_layer_->GetPositionBias()) ? -+ allocator.Malloc(normed_from_tensor_len * sizeof(T)) : -+ 0; -+ attn_ws_buf_ = allocator.Malloc(attention_layer_->GetWorkspaceSize()); -+ attention_layer_->SetWSOffset(attn_ws_buf_); -+ attn_out_buf_ = allocator.Malloc(attn_out_len * sizeof(T)); -+ -+ allocator.Free(d_token_num_buf_); -+ if (use_past_) -+ allocator.Free(d_sequence_lengths_offset_buf_); -+ if (use_past_) -+ allocator.Free(d_sequence_lengths_offset_buf2_); -+ -+ allocator.Free(attn_ws_buf_); -+ if (!layernorm_post_ || attention_layer_->GetPositionBias()) -+ allocator.Free(normed_from_tensor_buf_); -+ normed_attn_out_buf_ = ((!layernorm_post_ || attention_layer_->GetPositionBias()) || is_ffn_fp16_) ? -+ is_ffn_fp16_ ? allocator.Malloc(normed_attn_out_len * sizeof(half)) : -+ allocator.Malloc(normed_attn_out_len * sizeof(T)) : -+ 0; -+ if (is_moe_) { -+ size_t moe_ws_size = moe_->GetWorkspaceSize(); -+ size_t moe_offset = allocator.Malloc(moe_ws_size); -+ moe_->SetWSOffset(moe_offset); -+ } -+ else { -+ moe_ = nullptr; -+ ffn_ws_buf_ = allocator.Malloc(ffn_layer_->GetWorkspaceSize()); -+ ffn_layer_->SetWSOffset(ffn_ws_buf_); -+ } -+ tmp_out_buf_ = -+ is_ffn_fp16_ ? allocator.Malloc(tmp_out_size * sizeof(half)) : allocator.Malloc(tmp_out_size * sizeof(T)); -+ if (is_ffn_fp16_) -+ tmp_out1_buf_ = allocator.Malloc(tmp_out_size * sizeof(T)); -+ size_t size = allocator.total_size(); -+ return size; -+} -+template -+void Encoder::GetCompressBuffer(T* compress_buffer, -+ T* from_tensor, -+ int* d_sequence_lengths, -+ int* padding_offset, -+ size_t& h_token_num, -+ size_t* d_token_num, -+ cudaStream_t stream) -+{ -+ int old_h_token_num = h_token_num; -+ invokeGetPaddingOffset( -+ &h_token_num, d_token_num, padding_offset, d_sequence_lengths, batch_size_, src_seq_len_, stream); -+ if (h_token_num * 2 <= batch_size_ * src_seq_len_ || use_past_) { -+ invokeRemovePadding(compress_buffer, (const T*)from_tensor, padding_offset, h_token_num, hidden_size_, stream); -+ } -+ else -+ h_token_num = old_h_token_num; -+} -+template -+void Encoder::InitUsePast(std::vector& inputs, T*& from_tensor, void* ws, cudaStream_t stream) -+{ -+ int in_len = inputs.size(); -+ int k_cache_idx = in_idx_++; -+ int v_cache_idx = in_idx_++; -+ int vsl_inputs_idx = in_len - 4; -+ int position_idx = in_len - 5; -+ int emmbeding_pos_idx = in_len - 6; -+ int emmbeding_idx = in_len - 7; -+ -+ k_cache_ = reinterpret_cast(inputs[k_cache_idx]); -+ v_cache_ = reinterpret_cast(inputs[v_cache_idx]); -+ -+ size_t* d_token_num = GetBuf(ws, d_token_num_buf_); -+ T* compress_buffer = GetBuf(ws, compress_buf_); -+ int* d_sequence_lengths = reinterpret_cast(inputs[vsl_inputs_idx]); -+ int* d_sequence_lengths2 = reinterpret_cast(inputs[vsl_inputs_idx + 1]); -+ int* padding_offset = reinterpret_cast(inputs[vsl_inputs_idx + 2]); -+ size_t* input_h_token_num = reinterpret_cast(inputs[vsl_inputs_idx + 3]); -+ -+ int h_input_position = 0; -+ int* input_position = reinterpret_cast(inputs[position_idx]); -+ cudaMemcpyAsync(&h_input_position, input_position, sizeof(h_input_position), cudaMemcpyDeviceToHost, stream); -+ cudaStreamSynchronize(stream); -+ -+ cudaMemcpyAsync(&h_token_num_, input_h_token_num, sizeof(size_t), cudaMemcpyDeviceToHost, stream); -+ cudaStreamSynchronize(stream); -+ if (h_input_position == 0) { -+ const size_t size = batch_size_ * src_seq_len_ * head_num_ * head_size_; -+ T* k_cache = reinterpret_cast(k_cache_); -+ T* v_cache = reinterpret_cast(v_cache_); -+ cudaMemsetAsync(k_cache, 0, size * sizeof(T), stream); -+ cudaMemsetAsync(v_cache, 0, size * sizeof(T), stream); -+ attention_layer_->SetCache(k_cache_, v_cache_); -+ attention_layer_->SetIncrementalMode(false); -+ } -+ else { -+ attention_layer_->SetIncrementalMode(true); -+ } -+ -+ attention_layer_->SetVslParam(padding_offset, padding_offset, d_sequence_lengths, d_sequence_lengths2); -+ -+ if (first_layer_) { -+ T* input_after_emmbeding = compress_buffer; -+ T* emmbeding_table = reinterpret_cast(inputs[emmbeding_idx]); -+ T* emmbeding_pos_table = reinterpret_cast(inputs[emmbeding_pos_idx]); -+ int* input_position = reinterpret_cast(inputs[position_idx]); -+ invokeEmbeddingPanguSigma(const_cast(reinterpret_cast(from_tensor)), -+ const_cast(input_position), -+ const_cast(emmbeding_table), -+ const_cast(emmbeding_pos_table), -+ input_after_emmbeding, -+ h_token_num_, -+ hidden_size_, -+ stream); -+ from_tensor = input_after_emmbeding; -+ } -+ if (query_layer_) { -+ T* emmbeding_table = reinterpret_cast(inputs[in_len - 4]); -+ invokeVocabEmbedding(const_cast(reinterpret_cast(inputs[5])), -+ const_cast(emmbeding_table), -+ reinterpret_cast(compress_buffer), -+ h_token_num_, -+ hidden_size_, -+ stream); -+ inputs[5] = compress_buffer; -+ } -+ -+ SetHTokenNum(h_token_num_, h_token_num_); -+} -+ -+template -+void Encoder::ForwardAttention(std::vector& inputs, -+ std::vector& from_tensor, -+ std::vector& output, -+ void* ws, -+ cublasHandle_t cublas_handle, -+ cudaStream_t stream) -+{ -+ if (query_layer_) { -+ inputs[in_idx_ - 1] = inputs[in_idx_]; -+ inputs[in_idx_] = from_tensor[0]; -+ in_idx_--; -+ } -+ else { -+ inputs[--in_idx_] = from_tensor[0]; -+ } -+ bool is_projection_bias = attention_layer_->GetProjectionBias(); -+ attention_layer_->SetProjectionBias(false); -+ std::vector attn_in_vector(inputs.begin() + in_idx_, inputs.end()); -+ -+ attention_layer_->forward(attn_in_vector, output, ws, cublas_handle, stream); -+ in_idx_ = attention_layer_->GetIdx() + in_idx_; -+ attention_layer_->SetProjectionBias(is_projection_bias); -+ -+ nvinfer1::DataType type = (std::is_same::value) ? nvinfer1::DataType::kFLOAT : nvinfer1::DataType::kHALF; -+ if (all_reduce_sum_func_ != nullptr) { -+ all_reduce_sum_func_(output[0], output[0], h_token_num_ * hidden_size_, type, stream); -+ } -+} -+ -+template -+void Encoder::ForwardFfn(std::vector& inputs, -+ const std::vector& output, -+ void* ws, -+ cublasHandle_t cublas_handle, -+ cudaStream_t stream) -+{ -+ nvinfer1::DataType type = (std::is_same::value) ? nvinfer1::DataType::kFLOAT : nvinfer1::DataType::kHALF; -+ std::vector ffn_in_vector{inputs.begin() + in_idx_, inputs.end()}; -+ std::vector ffn_output = output; -+ if (is_moe_) { -+ if constexpr (std::is_same::value) { -+ moe_->SetPaddingOffsetDevice(attention_layer_->GetPaddingOffset()); -+ moe_->SetSeqLenHost(seq_len_host_); -+ moe_->SetSeqLenDevice(attention_layer_->GetSequenceLength()); -+ moe_->SetParallelFunc(all_gather_func_, all_reduce_sum_func_); -+ moe_->forward(ffn_in_vector, output, ws, cublas_handle, stream); -+ in_idx_ = in_idx_ + moe_->GetIdx(); -+ } -+ else { -+ std::cout << "moe support only half" << std::endl; -+ } -+ } -+ else { -+ std::vector ffn_in_vector(inputs.begin() + in_idx_, inputs.end()); -+ ffn_layer_->forward(ffn_in_vector, output, ws, cublas_handle, stream); -+ in_idx_ = in_idx_ + ffn_layer_->GetIdx(); -+ if (all_reduce_sum_func_ != nullptr) { -+ all_reduce_sum_func_(output[0], output[0], h_token_num_ * hidden_size_, type, stream); -+ } -+ } -+} -+template -+void Encoder::AddBiasResidual(std::vector& inputs, const std::vector& output, cudaStream_t stream) -+{ -+ if (std::is_same::value || !is_ffn_fp16_) { -+ invokeAddBiasResidual(reinterpret_cast(output[0]), -+ reinterpret_cast(inputs[0]), -+ reinterpret_cast(inputs[1]), -+ h_token_num_, -+ hidden_size_, -+ stream); -+ } -+ else { -+ if (layernorm_post_) { -+ invokeAddBiasResidualSameTypeCast(reinterpret_cast(output[0]), -+ reinterpret_cast(inputs[0]), -+ reinterpret_cast(output[1]), -+ reinterpret_cast(inputs[1]), -+ h_token_num_, -+ hidden_size_, -+ stream); -+ } -+ else { -+ invokeAddBiasResidualCast(reinterpret_cast(output[0]), -+ reinterpret_cast(inputs[0]), -+ reinterpret_cast(output[1]), -+ reinterpret_cast(inputs[1]), -+ h_token_num_, -+ hidden_size_, -+ stream); -+ } -+ } -+} -+template -+void Encoder::MulEmbeddingTable(std::vector& inputs, -+ const std::vector& output, -+ cublasHandle_t cublas_handle) -+{ -+ float alpha = 1.0f; -+ float beta = 0.0f; -+ int gemm_dims[] = {(int)embedding_size_, (int)batch_size_, (int)hidden_size_}; -+ int gemm_lds[] = {(int)hidden_size_, (int)hidden_size_, (int)embedding_size_}; -+ cublasOperation_t gemm_ops[] = {CUBLAS_OP_T, CUBLAS_OP_N}; -+ cudaDataType gemm_data_types[] = {CUDA_R_32F, CUDA_R_32F, CUDA_R_32F}; -+ if (std::is_same::value) { -+ gemm_data_types[0] = CUDA_R_16F; -+ gemm_data_types[1] = CUDA_R_16F; -+ gemm_data_types[2] = CUDA_R_16F; -+ } -+ CublasGemmWrapper(reinterpret_cast(inputs[in_idx_++]), -+ reinterpret_cast(output[0]), -+ output[1], -+ gemm_dims, -+ gemm_lds, -+ gemm_ops, -+ gemm_data_types, -+ &alpha, -+ &beta, -+ cublas_handle, -+ algo_); -+} -+template -+void Encoder::forward(std::vector& inputs, -+ const std::vector& output, -+ void* ws, -+ cublasHandle_t cublas_handle, -+ cudaStream_t stream) -+{ -+ ws = GetBuf(ws, ws_offset_); -+ int in_len = inputs.size(); -+ in_idx_ = 0; -+ std::vector encoder_in = inputs; -+ size_t h_token_num = batch_size_ * src_seq_len_; -+ SetHTokenNum(h_token_num, h_token_num); -+ T* input_tensor = reinterpret_cast(encoder_in[in_idx_++]); -+ -+ T* from_tensor = input_tensor; -+ int* d_sequence_lengths = reinterpret_cast(inputs[in_len - 1]); -+ int* d_sequence_lengths2 = d_sequence_lengths; -+ T* compress_buffer = GetBuf(ws, compress_buf_); -+ int* padding_offset = GetBuf(ws, padding_offset_buf_); -+ size_t* d_token_num = GetBuf(ws, d_token_num_buf_); -+ const int batch = batch_size_; -+ -+ bool is_T5 = attention_layer_->GetPositionBias(); -+ SetT5(is_T5); -+ k_cache_ = nullptr; -+ v_cache_ = nullptr; -+ attention_layer_->SetVslParam(nullptr, nullptr, nullptr, nullptr); -+ if (use_past_) { -+ InitUsePast(encoder_in, from_tensor, ws, stream); -+ } -+ else { -+ if (eft_) { -+ GetCompressBuffer( -+ compress_buffer, from_tensor, d_sequence_lengths, padding_offset, h_token_num, d_token_num, stream); -+ if (batch > 1) { -+ if (h_token_num * 2 <= batch_size_ * src_seq_len_) { -+ SetHTokenNum(h_token_num, h_token_num); -+ from_tensor = compress_buffer; -+ attention_layer_->SetVslParam( -+ padding_offset, padding_offset, d_sequence_lengths, d_sequence_lengths); -+ } -+ } -+ else { -+ SetHTokenNum(h_token_num, h_token_num); -+ } -+ } -+ } -+ h_token_num = h_token_num_; -+ bool is_ffn_write_to_output = !(is_ffn_fp16_ || is_layernorm_ || query_layer_); -+ T* attn_out = GetBuf(ws, attn_out_buf_); -+ T* normed_attn_out = GetBuf(ws, normed_attn_out_buf_); -+ T* tmp_out = reinterpret_cast(output[0]); -+ if (!is_ffn_write_to_output) { -+ // in the case that ffn not write directly to output, allocate a diffrent tensor -+ tmp_out = GetBuf(ws, tmp_out_buf_); -+ } -+ T* tmp_out1 = reinterpret_cast(output[0]); -+ T* tmp_out2 = reinterpret_cast(output[0]); -+ if (is_layernorm_ -+ && (attention_layer_->GetPaddingOffset() != nullptr || (std::is_same::value && is_ffn_fp16_)) -+ && !use_past_) { -+ tmp_out1 = GetBuf(ws, tmp_out1_buf_); -+ if (attention_layer_->GetPaddingOffset() != nullptr) { -+ tmp_out2 = compress_buffer; -+ } -+ } -+ else if (attention_layer_->GetPaddingOffset() != nullptr && is_ffn_fp16_) { -+ tmp_out1 = compress_buffer; -+ } -+ T* query_out = reinterpret_cast(output[0]); -+ if (attention_layer_->GetPaddingOffset() != nullptr && query_layer_ -+ && (!attention_layer_->GetIncrementalMode() || h_token_num_ < batch_size_)) { -+ query_out = GetBuf(ws, tmp_out1_buf_); -+ } -+ T* out_buf = tmp_out; -+ -+ // Step I - Do Pre Layer Norm -+ T* normed_from_tensor = from_tensor; -+ layer_norm1_->SetParams(h_token_num_, hidden_size_); -+ layer_norm2_->SetParams(h_token_num_, hidden_size_); -+ if (layernorm_post_ == false || is_T5) { -+ T* gamma1 = reinterpret_cast(encoder_in[in_idx_++]); -+ T* beta1 = (has_beta_) ? reinterpret_cast(encoder_in[in_idx_++]) : nullptr; -+ normed_from_tensor = GetBuf(ws, normed_from_tensor_buf_); -+ std::vector in = {from_tensor, gamma1, beta1}; -+ std::vector normed_out = {normed_from_tensor}; -+ layer_norm1_->forward(in, normed_out, ws, cublas_handle, stream); -+ } -+ -+ std::vector attn_from_vector{normed_from_tensor}; -+ std::vector attn_out_vector{attn_out}; -+ ForwardAttention(encoder_in, attn_from_vector, attn_out_vector, ws, cublas_handle, stream); -+ T* projection_bias = -+ (attention_layer_->GetProjectionBias()) ? reinterpret_cast(encoder_in[in_idx_++]) : nullptr; -+ T* gamma2 = reinterpret_cast(encoder_in[in_idx_++]); -+ T* beta2 = (has_beta_) ? reinterpret_cast(encoder_in[in_idx_++]) : nullptr; -+ if (layernorm_post_ == false || is_T5) { -+ // setup skip connection -+ from_tensor = (layernorm_post_) ? normed_from_tensor : from_tensor; -+ std::vector in2 = {from_tensor, gamma2, beta2, projection_bias}; -+ std::vector out2 = {attn_out, normed_attn_out}; -+ layer_norm2_->forward(in2, out2, ws, cublas_handle, stream); -+ } -+ else { -+ std::vector in2 = {from_tensor, gamma2, beta2, projection_bias}; -+ std::vector out2 = {attn_out, normed_attn_out}; -+ layer_norm1_->forward(in2, out2, ws, cublas_handle, stream); -+ if (!is_ffn_fp16_) { -+ normed_attn_out = attn_out; -+ } -+ } -+ -+ encoder_in[--in_idx_] = normed_attn_out; -+ T* ffn_out_tensor = tmp_out; -+ std::vector ffn_out_vector{ffn_out_tensor}; -+ ForwardFfn(encoder_in, ffn_out_vector, ws, cublas_handle, stream); -+ T* ffn_bias = (ffn_layer_->GetffnBias() && !is_moe_) ? reinterpret_cast(encoder_in[in_idx_++]) : nullptr; -+ T* ffn_norm_tensor = (is_ffn_fp16_) ? tmp_out1 : ffn_out_tensor; -+ if (layernorm_post_ == true && !is_T5) { -+ T* gamma3 = reinterpret_cast(encoder_in[in_idx_++]); -+ T* beta3 = (has_beta_) ? reinterpret_cast(encoder_in[in_idx_++]) : nullptr; -+ attn_out = (is_ffn_fp16_) ? normed_attn_out : attn_out; -+ std::vector in3 = {attn_out, gamma3, beta3, ffn_bias}; -+ std::vector out3 = {ffn_out_tensor, tmp_out1}; -+ layer_norm2_->forward(in3, out3, ws, cublas_handle, stream); -+ } -+ else { -+ attn_out = (layernorm_post_) ? normed_attn_out : attn_out; -+ std::vector add_residual_in_vector{attn_out, ffn_bias}; -+ std::vector add_residual_out_vector{ffn_out_tensor, tmp_out1}; -+ AddBiasResidual(add_residual_in_vector, add_residual_out_vector, stream); -+ } -+ T* last_norm_Tensor = ffn_norm_tensor; -+ if (is_layernorm_) { -+ T* gamma4 = reinterpret_cast(encoder_in[in_idx_++]); -+ T* beta4 = (has_beta_) ? reinterpret_cast(encoder_in[in_idx_++]) : nullptr; -+ std::vector in4 = {ffn_norm_tensor, gamma4, beta4}; -+ std::vector out4 = {tmp_out2}; -+ layer_norm3_->SetParams(h_token_num_, hidden_size_); -+ layer_norm3_->forward(in4, out4, ws, cublas_handle, stream); -+ last_norm_Tensor = tmp_out2; -+ } -+ if (query_layer_) { -+ if (attention_layer_->GetIncrementalMode() && h_token_num_ == batch_size_) { -+ query_out = last_norm_Tensor; -+ } -+ else { -+ invokeRebuildQuery(query_out, -+ last_norm_Tensor, -+ attention_layer_->GetSequenceLength(), -+ attention_layer_->GetIncrementalMode() ? h_token_num_ : batch_size_, -+ hidden_size_, -+ stream); -+ } -+ std::vector mul_out = {query_out, output[0]}; -+ MulEmbeddingTable(encoder_in, mul_out, cublas_handle); -+ } -+ if (attention_layer_->GetPaddingOffset() != nullptr && !use_past_) { -+ int size = batch_size_ * src_seq_len_ * hidden_size_; -+ cudaMemsetAsync(output[0], 0, size * sizeof(T), stream); -+ invokeRebuildPadding( -+ (T*)output[0], last_norm_Tensor, attention_layer_->GetPaddingOffset(), h_token_num_, hidden_size_, stream); -+ } -+ return; -+} -+ -+} // namespace fastertransformer -diff --git a/src/fastertransformer/layers/ms_layers/encoder.h b/src/fastertransformer/layers/ms_layers/encoder.h -new file mode 100644 -index 0000000..881fc48 ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/encoder.h -@@ -0,0 +1,402 @@ -+#pragma once -+ -+#include "src/fastertransformer/kernels/activation_kernels.h" -+#include "src/fastertransformer/layers/ms_layers/BaseLayer.h" -+#include "src/fastertransformer/layers/ms_layers/attention.h" -+#include "src/fastertransformer/layers/ms_layers/ffn.h" -+#include "src/fastertransformer/layers/ms_layers/layer_norm.h" -+#include "src/fastertransformer/layers/ms_layers/MoeFfnLayer.h" -+ -+#include -+#include -+ -+namespace fastertransformer { -+class EncoderBase: public BaseLayerMS { -+public: -+ EncoderBase(size_t batch_size, -+ size_t seq_len, -+ size_t head_num, -+ size_t head_size, -+ size_t hidden_size, -+ int rank_num = 1, -+ cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP):BaseLayerMS(batch_size, seq_len, seq_len, head_num, head_size, hidden_size, rank_num, algo){} -+ -+ virtual void SetFfnParam(bool ffn_fp16, size_t ffn_hidden_size, FfnBase::ActType act_type = FfnBase::ActType::Gelu, bool ffn_bias = true) = 0; -+ virtual void SetMoeParam(bool is_moe, size_t expert_num = 0, size_t expert_offset = 0, size_t capacity_factor = 0, FfnBase::ActType act_type = FfnBase::ActType::Gelu) = 0; -+ virtual void SetParallelFunc(BaseLayerMS::allGatherFunc all_gather_func, BaseLayerMS::allReduceSumFunc all_reduce_sum_func) override = 0; -+ virtual void SetRankNum(int rank_num) override = 0; -+ virtual void SetRankId(int rank_id) override = 0; -+ virtual void SetScaleAttn(float scale = 1.0f) = 0; -+ virtual void SetIsLayerNorm(bool is_layernorm, float eps = 1e-6f) = 0; -+ virtual void SetLayerNormPost(bool layernorm_post) = 0; -+ virtual void SetVSL(bool eft) = 0; -+ virtual void SetT5(bool t5) = 0; -+ virtual void SetEps(float eps1, float eps2, float eps3) = 0; -+ virtual void SetUsePast(bool use_past) = 0; -+ virtual void SetAlgo(cublasGemmAlgo_t algo) override = 0; -+ virtual void SetHTokenNum(size_t h_token_num, size_t h_token_num2 = -1) = 0; -+ virtual void SetCache(void* k_cache, void* v_cache) = 0; -+ virtual void SetQueryLayer(bool query_layer) = 0; -+ virtual void SetEmmbedingSize(size_t embedding_size) = 0; -+ virtual void SetFirstLayer(bool first_layer) = 0; -+ virtual void SetRankParam(int rank_num = 0, int rank_id = 0) = 0; -+}; -+template -+class Encoder: public EncoderBase { -+private: -+ std::shared_ptr> attention_layer_; -+ std::shared_ptr ffn_layer_; -+ std::shared_ptr> layer_norm1_; -+ std::shared_ptr> layer_norm2_; -+ std::shared_ptr> layer_norm3_; -+void GetCompressBuffer(T* compress_buffer, -+ T* from_tensor, -+ int* d_sequence_lengths, -+ int* padding_offset, -+ size_t& h_token_num, -+ size_t* d_token_num, -+ cudaStream_t stream); -+void InitUsePast(std::vector &inputs, T* &from_tensor, void *ws, cudaStream_t stream); -+void ForwardAttention(std::vector &inputs, std::vector &from_tensor, std::vector &output, void *ws, cublasHandle_t cublas_handle, cudaStream_t stream); -+void ForwardFfn(std::vector &inputs, const std::vector &output, void *ws, cublasHandle_t cublas_handle, cudaStream_t stream); -+void AddBiasResidual(std::vector &inputs, const std::vector &output, cudaStream_t stream); -+void MulEmbeddingTable(std::vector &inputs, const std::vector &output, cublasHandle_t cublas_handle); -+ -+ -+ bool use_past_; // use past mode -+ bool query_layer_; // check if quary layer -+ -+ size_t data_parallel_{false}; -+ int cur_token_id_{0}; // current token id id -+ bool eft_{false}; -+ size_t h_token_num_; -+ size_t h_token_num2_; -+ void* k_cache_{nullptr}; -+ void* v_cache_{nullptr}; -+ int* seq_len_host_{nullptr}; -+ bool layernorm_post_; -+ bool has_beta_; -+ bool is_layernorm_; -+ bool ffn_fp16_; -+ int embedding_size_; -+ size_t normed_from_tensor_buf_{0}; -+ size_t attn_ws_buf_{0}; -+ size_t attn_out_buf_{0}; -+ size_t normed_attn_out_buf_{0}; -+ size_t ffn_ws_buf_{0}; -+ size_t tmp_out_buf_{0}; -+ size_t tmp_out1_buf_{0}; -+ size_t compress_buf_{0}; -+ size_t compress_buf2_{0}; -+ size_t d_token_num_buf_{0}; -+ size_t padding_offset_buf_{0}; -+ size_t padding_offset_buf2_{0}; -+ size_t d_sequence_lengths_offset_buf_{0}; -+ size_t d_sequence_lengths_offset_buf2_{0}; -+ -+ size_t norm_out_buf_{0}; -+ bool is_ffn_fp16_{false}; -+ bool is_moe_{0}; -+ std::shared_ptr moe_; -+ size_t expert_num_{0}; -+ size_t expert_offset_{0}; -+ size_t capacity_factor_{0}; -+ -+ bool first_layer_{0}; -+ -+public: -+ void printParam() -+ { -+ std::cout<<"batch_size = "< -+#include "src/fastertransformer/layers/ms_layers/debug_utils.h" -+#include "src/fastertransformer/utils/gemm_test/gemm_func.cc" -+#include "src/fastertransformer/layers/ms_layers/opt_allocator.h" -+namespace fastertransformer { -+ -+template -+size_t Ffn::GetWorkspaceSize() -+{ -+ size_t ffn_len = -+ batch_size_ * src_seq_len_ * ffn_hidden_size_; -+ OptAllocator allocator(ALIGN_SIZE); -+ allocator.Malloc(ffn_len * sizeof(T)); -+ return allocator.total_size(); -+} -+ -+template -+void Ffn::forward(std::vector &inputs, const std::vector &outputs, void *ws, cublasHandle_t cublas_handle, cudaStream_t stream) -+{ -+ in_idx_ = 0; -+ ws = GetBuf(ws, ws_offset_); -+ size_t inter_size = ffn_hidden_size_; -+ size_t h_token_num = h_token_num_; -+ cublasOperation_t gemm_ops[] = {CUBLAS_OP_N, CUBLAS_OP_N}; -+ cudaDataType gemm_data_types[] = {CUDA_R_32F, CUDA_R_32F, CUDA_R_32F}; -+ if (std::is_same::value) { -+ gemm_data_types[0] = CUDA_R_16F; -+ gemm_data_types[1] = CUDA_R_16F; -+ gemm_data_types[2] = CUDA_R_16F; -+ } -+ float alpha = 1.0f; -+ float beta = 0.0f; -+ -+ int gemm_dims[] = {(int)inter_size, (int)h_token_num, (int)hidden_size_}; -+ int gemm_lds[] = {(int)inter_size, (int)hidden_size_, (int)inter_size}; -+ T* normed_attn_out = reinterpret_cast(inputs[in_idx_++]); -+ T* from = reinterpret_cast(inputs[in_idx_++]); -+ T* bias = (ffn_bias_) ? reinterpret_cast(inputs[in_idx_++]) : nullptr; -+ -+ CublasGemmWrapper(from, -+ normed_attn_out, -+ ws, -+ gemm_dims, -+ gemm_lds, -+ gemm_ops, -+ gemm_data_types, -+ &alpha, -+ &beta, -+ cublas_handle, -+ algo_); -+ if (act_type_ == ActType::Gelu) { -+ invokeAddBiasGelu(reinterpret_cast(ws), bias, h_token_num, inter_size, stream); -+ } else if (act_type_ == ActType::FastGelu) { -+ invokeAddBiasFastGelu(reinterpret_cast(ws), bias, h_token_num, inter_size, stream); -+ } else if (act_type_ == ActType::Relu) { -+ invokeAddBiasRelu(reinterpret_cast(ws), bias, h_token_num, inter_size, stream); -+ } -+ else if (ffn_bias_ && act_type_ == ActType::No) { -+ invokeAddBias(reinterpret_cast(ws), bias, h_token_num, inter_size, stream); -+ } -+ gemm_dims[0] = hidden_size_; -+ gemm_dims[1] = h_token_num; -+ gemm_dims[2] = inter_size; -+ gemm_lds[0] = hidden_size_; -+ gemm_lds[1] = inter_size; -+ gemm_lds[2] = hidden_size_; -+ CublasGemmWrapper(reinterpret_cast(inputs[in_idx_++]), -+ ws, -+ outputs[0], -+ gemm_dims, -+ gemm_lds, -+ gemm_ops, -+ gemm_data_types, -+ &alpha, -+ &beta, -+ cublas_handle, -+ algo_); -+ -+ -+} -+ -+} // namespace fastertransformer -diff --git a/src/fastertransformer/layers/ms_layers/ffn.h b/src/fastertransformer/layers/ms_layers/ffn.h -new file mode 100644 -index 0000000..10bf6dc ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/ffn.h -@@ -0,0 +1,132 @@ -+#pragma once -+ -+#include "src/fastertransformer/kernels/activation_kernels.h" -+#include "src/fastertransformer/layers/ms_layers/MSBaseLayer.h" -+#include "src/fastertransformer/layers/ms_layers/BaseLayer.h" -+#include "src/fastertransformer/layers/ms_layers/debug_utils.h" -+#include -+#include -+ -+namespace fastertransformer { -+class FfnBase : public BaseLayerMS { -+public: -+ enum ActType { -+ No = 0, -+ Relu = 1, -+ Sigmoid = 2, -+ Relu6 = 3, -+ Elu = 4, -+ LeakyRelu = 5, -+ Abs = 6, -+ Relu1 = 7, -+ Softsign = 8, -+ Softplus = 9, -+ Tanh = 10, -+ Selu = 11, -+ HSwish = 12, -+ HSigmoid = 13, -+ ThresholdRelu = 14, -+ Linear = 15, -+ HardTanh = 16, -+ Sign = 17, -+ Swish = 18, -+ Gelu = 19, -+ FastGelu = 20, -+ Unknown = 21 -+ }; -+protected: -+ bool ffn_bias_; -+ size_t ffn_hidden_size_; -+ ActType act_type_; -+ size_t h_token_num_; -+ -+public: -+ void SetFfnHiddenSize(size_t ffn_hidden_size) -+ { -+ ffn_hidden_size_ = ffn_hidden_size; -+ } -+ void SetActType(ActType act_type) -+ { -+ act_type_ = act_type; -+ } -+ void SetffnBias(bool ffn_bias) -+ { -+ ffn_bias_ = ffn_bias; -+ } -+ size_t GetFfnHiddenSize() -+ { -+ return ffn_hidden_size_; -+ } -+ ActType GetActType() -+ { -+ return act_type_; -+ } -+ bool GetffnBias() -+ { -+ return ffn_bias_; -+ } -+ void SetHTokenNum(size_t h_token_num) -+ { -+ h_token_num_ = h_token_num; -+ } -+ void printParam() -+ { -+ std::cout<<"ffn param:\n"; -+ std::cout<<"batch_size = "< -+ -+namespace fastertransformer { -+ -+#define __ARCH__ 80 -+#ifdef __CUDA_ARCH__ -+#undef __CUDA_ARCH_HOST__ -+#define __CUDA_ARCH_HOST__ __CUDA_ARCH__ -+#endif -+#ifndef __CUDA_ARCH_HOST__ -+ #ifndef __CUDA_ARCH__ -+ #error "Need cuda arch at least 5.0" -+ #else -+ #define __CUDA_ARCH_HOST__ __CUDA_ARCH__ -+ #endif -+#endif -+#if __CUDA_ARCH_HOST__ < 750 -+ #undef __ARCH__ -+ #define __ARCH__ 70 -+#elif __CUDA_ARCH_HOST__ < 800 -+ #undef __ARCH__ -+ #define __ARCH__ 75 -+#endif -+ -+#define CONCAT(_x) cutlass::arch::Sm##_x -+ -+#define INSTANTIATE_ATTENTION_KERNEL_INSTANCE_ENABLE( \ -+ ARCH, SCALAR_T, IS_ALIGNED, QUERIES_PER_BLOCK, KEYS_PER_BLOCK, SINGLE_VALUE_ITER) \ -+ AttentionKernel -+ -+#define INSTANTIATE_ATTENTION_KERNEL_PARAM(__p__, __in0__, __in1__, __in2__, __in3__, __in4__, __out__) \ -+ { \ -+ p.query_ptr = __in0__; \ -+ p.key_ptr = __in1__; \ -+ p.value_ptr = __in2__; \ -+ p.attn_mask_ptr = __in3__; \ -+ p.attn_bias_ptr = __in4__; \ -+ p.cu_seqlens_q_ptr = d_sequence_length_; \ -+ p.cu_seqlens_k_ptr = d_sequence_length2_; \ -+ p.logsumexp_ptr = nullptr; \ -+ p.output_accum_ptr = nullptr; \ -+ p.output_ptr = __out__; \ -+ p.num_heads = head_num_; \ -+ p.num_batches = batch_size_; \ -+ p.head_dim = head_size_; \ -+ p.head_dim_value = head_size_; \ -+ p.num_queries =src_seq_len_; \ -+ p.num_keys = tgt_seq_len_; \ -+ p.scale = scale_; \ -+ p.causal = false; \ -+ p.no_bias_head_dim = (is_cross_ && position_bias_); \ -+ p.q_strideM = head_size_; \ -+ p.k_strideM = head_size_; \ -+ p.v_strideM = head_size_; \ -+ p.attn_mask_strideM = tgt_seq_len_; \ -+ p.attn_bias_strideM = tgt_seq_len_; \ -+ p.q_strideH = p.q_strideM *src_seq_len_; \ -+ p.k_strideH = p.k_strideM * tgt_seq_len_; \ -+ p.v_strideH = p.v_strideM * tgt_seq_len_; \ -+ p.o_strideH = head_size_; \ -+ p.attn_mask_strideH = p.attn_mask_strideM *src_seq_len_; \ -+ p.attn_bias_strideH = (p.no_bias_head_dim) ? 0 : p.attn_mask_strideH; \ -+ p.q_strideB = p.q_strideH * head_num_; \ -+ p.k_strideB = p.k_strideH * head_num_; \ -+ p.v_strideB = p.v_strideH * head_num_; \ -+ p.o_strideB = \ -+ src_seq_len_ * head_num_ * head_size_; \ -+ p.attn_mask_strideB = p.attn_mask_strideH; \ -+ p.attn_bias_strideB = (p.no_bias_head_dim) ? (p.attn_mask_strideH) : p.attn_bias_strideH * p.head_dim; \ -+ if (use_past_) { \ -+ p.num_queries = h_token_num_; \ -+ p.num_keys = cur_token_id_ + 1; \ -+ p.attn_mask_ptr = nullptr; \ -+ if (!incremental_mode_) { \ -+ p.causal = true; \ -+ } \ -+ } \ -+ p.use_past = use_past_; \ -+ } -+template -+template -+void FusedCutlassMha::forward_fmha__(const std::vector &inputs, const std::vector &output, void *ws, cudaStream_t stream) -+{ -+ ws = GetBuf(ws, ws_offset_); -+ const bool isAligned = true; -+ using Attention = INSTANTIATE_ATTENTION_KERNEL_INSTANCE_ENABLE( -+ __ARCH__, T, isAligned, kQueriesPerBlock, kKeysPerBlock, kSingleValueIteration); -+ typename Attention::Params p; -+ INSTANTIATE_ATTENTION_KERNEL_PARAM(p, inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], output[0]); -+ constexpr auto kernel_fn = attention_kernel_batched_impl; -+ int smem_bytes = sizeof(typename Attention::SharedStorage); -+ if (smem_bytes > 0xc000) { -+ cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); -+ } -+ if (Attention::kNeedsOutputAccumulatorBuffer) { -+ p.output_accum_ptr = reinterpret_cast(ws); -+ } -+ kernel_fn<<>>(p); -+} -+ -+template -+template -+bool FusedCutlassMha::is_fmha_support__() -+{ -+ const bool isAligned = true; -+ using Attention = INSTANTIATE_ATTENTION_KERNEL_INSTANCE_ENABLE( -+ __ARCH__, T, isAligned, kQueriesPerBlock, kKeysPerBlock, kSingleValueIteration); -+ typename Attention::Params p; -+ INSTANTIATE_ATTENTION_KERNEL_PARAM(p, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr); -+ return Attention::check_supported(p); -+} -+template -+bool FusedCutlassMha::isSupport() -+{ -+ // Determine kernel configuration based on head size. -+ // If head size is less than or equal to 64, each block operates over 64 queries and -+ // 64 keys, and parital results can be stored in the register file. -+ // If head size is greater than 64, each block operates over 32 queries and 128 keys, -+ // and partial results are stored in shared memory. -+ -+ if (head_size_ > 64) { -+ static int const kQueriesPerBlock = 32; -+ static int const kKeysPerBlock = 128; -+ if (head_size_ <= kKeysPerBlock) { -+ return is_fmha_support__(); -+ } -+ else { -+ return is_fmha_support__(); -+ } -+ } -+ else { -+ static int const kQueriesPerBlock = 64; -+ static int const kKeysPerBlock = 64; -+ return is_fmha_support__(); -+ } -+} -+ -+template -+template -+size_t FusedCutlassMha::get_fmha_workspace__() -+{ -+ const bool isAligned = true; -+ -+ using Attention = INSTANTIATE_ATTENTION_KERNEL_INSTANCE_ENABLE( -+ __ARCH__, T, isAligned, kQueriesPerBlock, kKeysPerBlock, kSingleValueIteration); -+ typename Attention::Params p; -+ INSTANTIATE_ATTENTION_KERNEL_PARAM(p, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr); -+ size_t size = 0; -+ if (Attention::kNeedsOutputAccumulatorBuffer) { -+ size += batch_size_ * src_seq_len_ * head_num_ * head_size_; -+ } -+ return size * sizeof(T); -+} -+ -+template -+size_t FusedCutlassMha::GetWorkspaceSize() -+{ -+ // Determine kernel configuration based on head size. -+ // If head size is less than or equal to 64, each block operates over 64 queries and -+ // 64 keys, and parital results can be stored in the register file. -+ // If head size is greater than 64, each block operates over 32 queries and 128 keys, -+ // and partial results are stored in shared memory. -+ if (head_size_ > 64) { -+ static int const kQueriesPerBlock = 32; -+ static int const kKeysPerBlock = 128; -+ if (head_size_ <= kKeysPerBlock) { -+ return get_fmha_workspace__(); -+ } -+ else { -+ return get_fmha_workspace__(); -+ } -+ } -+ else { -+ static int const kQueriesPerBlock = 64; -+ static int const kKeysPerBlock = 64; -+ return get_fmha_workspace__(); -+ } -+} -+ -+template -+void FusedCutlassMha::forward(std::vector &inputs, const std::vector &outputs, void *ws, cublasHandle_t cublas_handle, cudaStream_t stream) -+{ -+ // Determine kernel configuration based on head size. -+ // If head size is less than or equal to 64, each block operates over 64 queries and -+ // 64 keys, and parital results can be stored in the register file. -+ // If head size is greater than 64, each block operates over 32 queries and 128 keys, -+ // and partial results are stored in shared memory. -+ std::vectormha_in(inputs.size()); -+ std::transform(inputs.begin(), inputs.end(), mha_in.begin(), [](void *x) { return reinterpret_cast(x);}); -+ std::vectormha_out(outputs.size()); -+ std::transform(outputs.begin(), outputs.end(), mha_out.begin(), [](void *x) { return reinterpret_cast(x);}); -+ -+ if (head_size_ > 64) { -+ static int const kQueriesPerBlock = 32; -+ static int const kKeysPerBlock = 128; -+ if (head_size_ <= kKeysPerBlock) { -+ return forward_fmha__(mha_in,mha_out,ws, stream); -+ } -+ else { -+ return forward_fmha__(mha_in, mha_out, ws, stream); -+ } -+ } -+ else { -+ static int const kQueriesPerBlock = 64; -+ static int const kKeysPerBlock = 64; -+ return forward_fmha__(mha_in, mha_out ,ws, stream); -+ } -+} -+ -+} // namespace fastertransformer -\ No newline at end of file -diff --git a/src/fastertransformer/layers/ms_layers/fmha_cutlass.h b/src/fastertransformer/layers/ms_layers/fmha_cutlass.h -new file mode 100644 -index 0000000..1b7eb78 ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/fmha_cutlass.h -@@ -0,0 +1,64 @@ -+/* -+ * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved. -+ * -+ * Licensed under the Apache License, Version 2.0 (the "License"); -+ * you may not use this file except in compliance with the License. -+ * You may obtain a copy of the License at -+ * -+ * http://www.apache.org/licenses/LICENSE-2.0 -+ * -+ * Unless required by applicable law or agreed to in writing, software -+ * distributed under the License is distributed on an "AS IS" BASIS, -+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+ * See the License for the specific language governing permissions and -+ * limitations under the License. -+ */ -+ -+#pragma once -+ -+#include "param.h" -+#include "cutlass/half.h" -+#include "src/fastertransformer/layers/ms_layers/attention.h" -+namespace fastertransformer { -+template -+class FusedCutlassMha: public MhaDispatch { -+ template -+ bool is_fmha_support__(); -+ template -+ size_t get_fmha_workspace__(); -+ template -+ void forward_fmha__(const std::vector &inputs, const std::vector &output, void *ws, cudaStream_t stream); -+public: -+ FusedCutlassMha(size_t batch_size, -+ size_t src_seq_len, -+ size_t tgt_seq_len, -+ size_t head_num, -+ size_t head_size, -+ size_t hidden_size, -+ bool is_cross = false, -+ bool position_bias = false, -+ float scale = 1.0f, -+ bool mask = true, -+ bool use_past = false, -+ int rank_num = 1, -+ cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP) : -+ MhaDispatch(batch_size, -+ src_seq_len, -+ tgt_seq_len, -+ head_num, -+ head_size, -+ hidden_size, -+ is_cross, -+ position_bias, -+ scale, -+ mask, -+ use_past, -+ rank_num, -+ algo) {} -+ void forward(std::vector &inputs, const std::vector &output, void *ws, cublasHandle_t cublas_handle = nullptr, cudaStream_t stream = 0) override; -+ bool isSupport(); -+ size_t GetWorkspaceSize(); -+}; -+template class FusedCutlassMha; -+template class FusedCutlassMha; -+} -\ No newline at end of file -diff --git a/src/fastertransformer/layers/ms_layers/gemm.cc b/src/fastertransformer/layers/ms_layers/gemm.cc -new file mode 100644 -index 0000000..f249fa8 ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/gemm.cc -@@ -0,0 +1,117 @@ -+ -+#include "src/fastertransformer/layers/ms_layers/gemm.h" -+#include "src/fastertransformer/kernels/activation_kernels.h" -+#include "src/fastertransformer/kernels/unfused_attention_kernels.h" -+#include "src/fastertransformer/layers/ms_layers/debug_utils.h" -+#include -+namespace fastertransformer { -+ -+void CublasGemmWrapper(const void* a_addr, -+ const void* b_addr, -+ void* c_addr, -+ const int* params, -+ const int* lds, -+ const cublasOperation_t* operations, -+ const cudaDataType* data_types, -+ void* alpha, -+ void* beta, -+ cublasHandle_t cublas_handle, -+ cublasGemmAlgo_t algo) -+{ -+ const int m = params[0]; -+ const int n = params[1]; -+ const int k = params[2]; -+ cublasOperation_t trans_a = operations[0]; -+ cublasOperation_t trans_b = operations[1]; -+ const int lda = lds[0]; -+ const int ldb = lds[1]; -+ const int ldc = lds[2]; -+ cudaDataType type_a = data_types[0]; -+ cudaDataType type_b = data_types[1]; -+ cudaDataType type_c = data_types[2]; -+ cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; -+ -+ if ((type_a == CUDA_R_16F) && (type_b == CUDA_R_16F) && (type_c == CUDA_R_16F)) { -+ compute_type = CUBLAS_COMPUTE_32F; -+ } -+ -+ cublasGemmEx(cublas_handle, -+ trans_a, -+ trans_b, -+ m, -+ n, -+ k, -+ alpha, -+ a_addr, -+ type_a, -+ lda, -+ b_addr, -+ type_b, -+ ldb, -+ beta, -+ c_addr, -+ type_c, -+ ldc, -+ compute_type, -+ algo); -+} -+ -+void CublasGemmStridedBatchedWrapper(const void* a_addr, -+ const void* b_addr, -+ void* c_addr, -+ const int* params, -+ const int* lds, -+ const cublasOperation_t* operations, -+ const int* strides, -+ const cudaDataType* data_types, -+ void* alpha, -+ void* beta, -+ int batch, -+ cublasHandle_t cublas_handle, -+ cublasGemmAlgo_t algo) -+{ -+ const int m = params[0]; -+ const int n = params[1]; -+ const int k = params[2]; -+ cublasOperation_t trans_a = operations[0]; -+ cublasOperation_t trans_b = operations[1]; -+ const int lda = lds[0]; -+ const int ldb = lds[1]; -+ const int ldc = lds[2]; -+ cudaDataType type_a = data_types[0]; -+ cudaDataType type_b = data_types[1]; -+ cudaDataType type_c = data_types[2]; -+ cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; -+ -+ if ((type_a == CUDA_R_16F) && (type_b == CUDA_R_16F) && (type_c == CUDA_R_16F)) { -+ compute_type = CUBLAS_COMPUTE_32F; -+ } -+ const int stride_a = strides[0]; -+ const int stride_b = strides[1]; -+ const int stride_c = strides[2]; -+ cublasGemmStridedBatchedEx(cublas_handle, -+ trans_a, -+ trans_b, -+ m, -+ n, -+ k, -+ alpha, -+ a_addr, -+ type_a, -+ lda, -+ stride_a, -+ b_addr, -+ type_b, -+ ldb, -+ stride_b, -+ beta, -+ c_addr, -+ type_c, -+ ldc, -+ stride_c, -+ batch, -+ compute_type, -+ algo); -+} -+ -+} // namespace fastertransformer -diff --git a/src/fastertransformer/layers/ms_layers/gemm.h b/src/fastertransformer/layers/ms_layers/gemm.h -new file mode 100644 -index 0000000..8c25ea9 ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/gemm.h -@@ -0,0 +1,14 @@ -+#pragma once -+ -+#include "src/fastertransformer/kernels/activation_kernels.h" -+#include "src/fastertransformer/layers/ms_layers/MSBaseLayer.h" -+#include -+#include -+#include "src/fastertransformer/layers/ms_layers/ffn.h" -+#include "src/fastertransformer/layers/ms_layers/BaseLayer.h" -+namespace fastertransformer { -+ -+void CublasGemmWrapper(const void* a_addr, const void* b_addr, void* c_addr, const int* params, const int* lds, const cublasOperation_t* operations, const cudaDataType* data_types, void* alpha, void* beta, cublasHandle_t cublas_handle, cublasGemmAlgo_t algo); -+void CublasGemmStridedBatchedWrapper(const void* a_addr, const void* b_addr, void* c_addr, const int* params, const int* lds, const cublasOperation_t* operations, const int* strides, const cudaDataType* data_types, void* alpha, void* beta, int batch, cublasHandle_t cublas_handle, cublasGemmAlgo_t algo); -+ -+} // namespace fastertransformer -diff --git a/src/fastertransformer/layers/ms_layers/layer_norm.cc b/src/fastertransformer/layers/ms_layers/layer_norm.cc -new file mode 100644 -index 0000000..dcb8385 ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/layer_norm.cc -@@ -0,0 +1,91 @@ -+ -+#include "src/fastertransformer/kernels/layernorm_kernels.h" -+#include "src/fastertransformer/layers/ms_layers/debug_utils.h" -+#include "src/fastertransformer/kernels/add_residual_kernels.h" -+#include "src/fastertransformer/layers/ms_layers/layer_norm.h" -+#include -+namespace fastertransformer { -+ template -+ void LayerNorm::forward(std::vector &inputs, const std::vector &output, void *ws, cublasHandle_t cublas_handle, cudaStream_t stream) -+ { -+ in_idx_ = 0; -+ T* input = reinterpret_cast(inputs[in_idx_++]); -+ T* gamma = reinterpret_cast(inputs[in_idx_++]); -+ T* beta = (has_beta_) ? reinterpret_cast(inputs[in_idx_++]) : nullptr; -+ T* bias = (has_bias_) ? reinterpret_cast(inputs[in_idx_++]) : nullptr; -+ switch (type_) -+ { -+ case T5: -+ invokeGeneralT5LayerNorm(reinterpret_cast(output[0]), -+ input, -+ gamma, -+ beta, -+ m_, -+ n_, -+ stream, -+ eps_); -+ break; -+ case ADD_BIAS_RESIDUAL: -+ invokeAddBiasResidualLayerNorm(reinterpret_cast(output[0]), -+ input, -+ bias, -+ gamma, -+ beta, -+ m_, -+ n_, -+ stream, -+ eps_); -+ break; -+ case ADD_BIAS_RESIDUAL_CAST: -+ if(!std::is_same::value) invokeAddBiasResidualLayerNormCast(reinterpret_cast(output[0]), -+ reinterpret_cast(output[1]), -+ reinterpret_cast(input), -+ bias, -+ gamma, // gamma -+ beta, // beta -+ m_, -+ n_, -+ stream, -+ eps_); -+ break; -+ case ADD_BIAS_RESIDUAL_CAST_FFN: -+ if(!std::is_same::value) invokeAddBiasResidualLayerNormCast(reinterpret_cast(output[0]), -+ reinterpret_cast(output[1]), -+ reinterpret_cast(input), -+ bias, -+ gamma, // gamma -+ beta, // beta -+ m_, -+ n_, -+ stream, -+ eps_); -+ break; -+ case ADD_BIAS_RESIDUAL_T5_PRE: -+ invokeGeneralAddBiasResidualT5PreLayerNorm(reinterpret_cast(output[0]), -+ reinterpret_cast(output[1]), -+ input, -+ gamma, // gamma -+ beta, // beta -+ bias, -+ m_, -+ n_, -+ stream, -+ eps_); -+ break; -+ case ADD_BIAS_RESIDUAL_T5_PRE_CAST: -+ if(!std::is_same::value) invokeGeneralAddBiasResidualT5PreLayerNormCast(reinterpret_cast(output[0]), -+ reinterpret_cast(output[1]), -+ input, -+ gamma, // gamma -+ beta, // beta -+ bias, -+ m_, -+ n_, -+ stream, -+ eps_); -+ break; -+ default: -+ break; -+ } -+ } -+} -\ No newline at end of file -diff --git a/src/fastertransformer/layers/ms_layers/layer_norm.h b/src/fastertransformer/layers/ms_layers/layer_norm.h -new file mode 100644 -index 0000000..df88423 ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/layer_norm.h -@@ -0,0 +1,60 @@ -+#pragma once -+ -+#include "src/fastertransformer/layers/ms_layers/BaseLayer.h" -+namespace fastertransformer { -+ -+ template -+ class LayerNorm: public BaseLayerMS { -+ public: -+ enum Type { -+ T5, -+ ADD_BIAS_RESIDUAL, -+ ADD_BIAS_RESIDUAL_CAST, -+ ADD_BIAS_RESIDUAL_T5_PRE, -+ ADD_BIAS_RESIDUAL_T5_PRE_CAST, -+ ADD_BIAS_RESIDUAL_CAST_FFN -+ }; -+ static Type type; -+ private: -+ size_t m_; -+ size_t n_; -+ bool has_beta_; -+ bool has_bias_; -+ Type type_; -+ float eps_; -+ public: -+ void SetParams(size_t m, size_t n) -+ { -+ m_ = m; -+ n_ = n; -+ } -+ -+ void SetType(Type type) -+ { -+ type_ = type; -+ if (type != T5) has_bias_ = true; -+ else has_bias_ = false; -+ } -+ void SetEps(float eps) -+ { -+ eps_ = eps; -+ } -+ void SetBeta(bool has_beta) -+ { -+ has_beta_ = has_beta; -+ } -+ void forward(std::vector &inputs, const std::vector &outputs, void *ws, cublasHandle_t cublas_handle = nullptr, cudaStream_t stream = 0) override; -+ LayerNorm(bool has_beta = true, -+ bool has_bias = false, -+ Type type = 0, -+ float eps = 1e-6f, -+ cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP):BaseLayerMS(0, 0, 0, 0, 0, 0, 0, algo){ -+ has_beta_ = has_beta; -+ has_bias_ = has_bias; -+ type_ = type; -+ eps_ = eps; -+ } -+ }; -+ template class LayerNorm; -+ template class LayerNorm; -+} -\ No newline at end of file -diff --git a/src/fastertransformer/layers/ms_layers/opt_allocator.cc b/src/fastertransformer/layers/ms_layers/opt_allocator.cc -new file mode 100644 -index 0000000..560b5ba ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/opt_allocator.cc -@@ -0,0 +1,89 @@ -+/** -+ * Copyright 2021 Huawei Technologies Co., Ltd -+ * -+ * Licensed under the Apache License, Version 2.0 (the "License"); -+ * you may not use this file except in compliance with the License. -+ * You may obtain a copy of the License at -+ * -+ * http://www.apache.org/licenses/LICENSE-2.0 -+ * -+ * Unless required by applicable law or agreed to in writing, software -+ * distributed under the License is distributed on an "AS IS" BASIS, -+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+ * See the License for the specific language governing permissions and -+ * limitations under the License. -+ */ -+#include "src/fastertransformer/layers/ms_layers/opt_allocator.h" -+#include -+#define UP_DIV(x, y) (((x) + (y) - (1)) / (y)) -+ -+namespace fastertransformer { -+size_t OptAllocator::FindFree(size_t size) { -+ size_t min_size = std::numeric_limits::max(); -+ size_t min_addr = std::numeric_limits::max(); -+ for (auto const &itr : arena_) { -+ // best fit -+ if (itr.second >= size) { -+ if (min_size > itr.second) { -+ min_size = itr.second; -+ min_addr = itr.first; -+ } -+ } -+ } -+ return min_addr; -+} -+ -+void OptAllocator::Reorder(size_t addr) { -+ size_t length = arena_[addr]; -+ size_t post = addr + length; -+ // connect to upper block -+ auto it = arena_.find(post); -+ if (it != arena_.end()) { -+ size_t post_size = it->second; -+ arena_[addr] = length + post_size; -+ arena_.erase(post); -+ } -+ // connect to lower block -+ auto itr = arena_.lower_bound(addr); -+ if (itr != arena_.begin()) { -+ itr--; -+ size_t last = itr->first; -+ if ((last + arena_[last]) == addr) { -+ arena_[last] = arena_[last] + arena_[addr]; -+ arena_.erase(addr); -+ } -+ } -+} -+ -+size_t OptAllocator::Malloc(size_t size) { -+ size = UP_DIV(size, align_size_) * align_size_; -+ size_t addr = FindFree(size); -+ // free block not found -+ if (addr == std::numeric_limits::max()) { -+ if (!arena_.empty()) { -+ addr = arena_.rbegin()->first; -+ if (addr + arena_[addr] < heap_) { -+ addr = heap_; -+ } else { -+ arena_.erase(addr); -+ } -+ } else { -+ addr = heap_; -+ } -+ heap_ = addr + size; -+ } else { -+ if (arena_[addr] > size) { -+ arena_[addr + size] = arena_[addr] - size; -+ } -+ arena_.erase(addr); -+ } -+ alloc_[addr] = size; -+ return addr; -+} -+ -+void OptAllocator::Free(size_t addr) { -+ arena_[addr] = alloc_[addr]; -+ alloc_.erase(addr); -+ Reorder(addr); -+} -+} // namespace fastertransformer -diff --git a/src/fastertransformer/layers/ms_layers/opt_allocator.h b/src/fastertransformer/layers/ms_layers/opt_allocator.h -new file mode 100644 -index 0000000..13a539e ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/opt_allocator.h -@@ -0,0 +1,40 @@ -+/** -+ * Copyright 2020 Huawei Technologies Co., Ltd -+ * -+ * Licensed under the Apache License, Version 2.0 (the "License"); -+ * you may not use this file except in compliance with the License. -+ * You may obtain a copy of the License at -+ * -+ * http://www.apache.org/licenses/LICENSE-2.0 -+ * -+ * Unless required by applicable law or agreed to in writing, software -+ * distributed under the License is distributed on an "AS IS" BASIS, -+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+ * See the License for the specific language governing permissions and -+ * limitations under the License. -+ */ -+ -+#ifndef FASTERTRANSFOMER_LITE_SRC_TRAIN_OPT_ALLOCATOR_H_ -+#define FASTERTRANSFOMER_LITE_SRC_TRAIN_OPT_ALLOCATOR_H_ -+ -+#include -+ -+namespace fastertransformer { -+class OptAllocator { -+ public: -+ explicit OptAllocator(size_t aligned_size = 32) : align_size_(aligned_size) {} -+ ~OptAllocator() {} -+ size_t Malloc(size_t size); -+ void Free(size_t offset); -+ size_t total_size() { return heap_; } -+ -+ private: -+ size_t FindFree(size_t size); -+ void Reorder(size_t addr); -+ std::map arena_; -+ std::map alloc_; -+ size_t heap_ = 0; -+ size_t align_size_; -+}; -+}; // namespace fastertransformer -+#endif // FASTERTRANSFOMER_LITE_SRC_TRAIN_OPT_ALLOCATOR_H_ -diff --git a/src/fastertransformer/layers/ms_layers/param.h b/src/fastertransformer/layers/ms_layers/param.h -new file mode 100644 -index 0000000..c8cb149 ---- /dev/null -+++ b/src/fastertransformer/layers/ms_layers/param.h -@@ -0,0 +1,183 @@ -+#pragma once -+#include "src/fastertransformer/layers/ms_layers/debug_utils.h" -+#include -+#include -+#include -+#include -+#include -+#include "src/fastertransformer/layers/ms_layers/ffn.h" -+ -+ -+ -+namespace fastertransformer { -+ -+ -+typedef enum ActType { -+ ActType_No = 0, -+ ActType_Relu = 1, -+ ActType_Sigmoid = 2, -+ ActType_Relu6 = 3, -+ ActType_Elu = 4, -+ ActType_LeakyRelu = 5, -+ ActType_Abs = 6, -+ ActType_Relu1 = 7, -+ ActType_Softsign = 8, -+ ActType_Softplus = 9, -+ ActType_Tanh = 10, -+ ActType_Selu = 11, -+ ActType_HSwish = 12, -+ ActType_HSigmoid = 13, -+ ActType_ThresholdRelu = 14, -+ ActType_Linear = 15, -+ ActType_HardTanh = 16, -+ ActType_Sign = 17, -+ ActType_Swish = 18, -+ ActType_Gelu = 19, -+ ActType_FastGelu = 20, -+ ActType_Unknown = 21 -+} ActType; -+typedef enum FmhaType { -+ FmhaType_UnFused, -+ FmhaType_CutlassFix -+} FmhaType; -+ -+typedef struct { -+ size_t batch_size; -+ size_t src_seq_len; -+ size_t tgt_seq_len; -+ size_t head_num; -+ size_t head_size; -+ size_t data_parallel; -+ size_t hidden_size; -+ int rank_num; -+ int rank_id; -+ size_t h_token_num; -+ size_t h_token_num2; -+ cublasGemmAlgo_t algo; -+ cublasHandle_t cublas_handle; -+ cudaStream_t stream; -+ int in_idx; -+ bool eft; -+ FfnBase::allGatherFunc all_gather_func; -+ FfnBase::allReduceSumFunc all_reduce_sum_func; -+ -+ int embedding_size; -+ bool use_past; // use past mode -+ bool query_layer; // check if quary layer -+ int cur_token_id; // current token id id -+ bool incremental_mode; // mode of inference -+ void* k_cache; -+ void* v_cache; -+} CommonParam; -+ -+typedef struct { -+ bool ffn_bias; -+ bool has_beta; -+ bool ffn_fp16; -+ size_t ffn_hidden_size; -+ ActType act_type; -+ size_t expert_num; -+ size_t expert_offset; -+ size_t capacity_factor; -+ bool load_weights; -+ size_t weight_mapping; -+ size_t weight_projection; -+} ffnParam; -+ -+typedef struct { -+ CommonParam* common_param; -+ ffnParam ffn_param; -+ bool is_moe; -+ std::shared_ptr moe; -+} ffnParamRun; -+ -+typedef struct { -+ bool qkv_bias; // ture -+ bool projection_bias; // ture -+ bool is_cross; // false -+ bool position_bias; -+ float scale; -+ size_t qkv_buf; -+ size_t q_buf_2; -+ size_t output1; -+ size_t output2; -+ size_t qk_buf; -+ size_t qkv_buf_2; -+ size_t qkv_buf_3; -+ size_t mha; -+ bool mask; -+ FmhaType fmha_type; -+ int* padding_offset; -+ int* d_sequence_length; -+ int* padding_offset2; -+ int* d_sequence_length2; -+} attentionParam; -+ -+typedef struct { -+ CommonParam* common_param; -+ attentionParam attn; -+} attentionParamRun; -+ -+typedef struct { -+ float eps1; -+ float eps2; -+ float eps3; -+ float eps4; -+ bool layernorm_post; -+ bool has_beta; -+ size_t normed_from_tensor_buf; -+ size_t attn_out_buf; -+ size_t attn_ws_buf; -+ size_t attn2_out_buf; -+ size_t attn2_ws_buf; -+ size_t tmp_out_buf; -+ size_t ffn_ws_buf; -+ size_t normed_attn_out_buf; -+ size_t normed_attn2_out_buf; -+ size_t compress_buf; -+ size_t d_token_num_buf; -+ size_t padding_offset_buf; -+ size_t d_sequence_lengths_offset_buf; -+ size_t compress_buf2; -+ size_t d_token_num_buf2; -+ size_t padding_offset_buf2; -+ size_t d_sequence_lengths_offset_buf2; -+ bool is_layernorm; -+} decoderParam; -+ -+typedef struct { -+ CommonParam common_param; -+ attentionParamRun attn1; -+ attentionParamRun attn2; -+ ffnParamRun ffn_param; -+ decoderParam decoder; -+} decoderParamRun; -+ -+typedef struct { -+ float eps1; -+ float eps2; -+ float eps3; -+ bool layernorm_post; -+ bool has_beta; -+ size_t normed_from_tensor_buf; -+ size_t attn_ws_buf; -+ size_t attn_out_buf; -+ size_t normed_attn_out_buf; -+ size_t ffn_ws_buf; -+ size_t tmp_out_buf; -+ size_t tmp_out1_buf; -+ size_t compress_buf; -+ size_t d_token_num_buf; -+ size_t padding_offset_buf; -+ size_t d_sequence_lengths_offset_buf; -+ size_t norm_out_buf; -+ bool is_layernorm; -+} encoderParam; -+ -+typedef struct { -+ CommonParam common_param; -+ ffnParamRun ffn_param; -+ attentionParamRun attn; -+ encoderParam encoder; -+} encoderParamRun; -+} // namespace fastertransformer -\ No newline at end of file -diff --git a/src/fastertransformer/models/CMakeLists.txt b/src/fastertransformer/models/CMakeLists.txt -index af33e76..97fc471 100644 ---- a/src/fastertransformer/models/CMakeLists.txt -+++ b/src/fastertransformer/models/CMakeLists.txt -@@ -21,8 +21,11 @@ add_subdirectory(xlnet) - - add_subdirectory(t5) - add_subdirectory(gptj) --add_subdirectory(multi_gpu_gpt) -+if(EXAMPLES) -+ add_subdirectory(multi_gpu_gpt) -+endif() - add_subdirectory(swin) - add_subdirectory(swin_int8) - add_subdirectory(vit) --add_subdirectory(vit_int8) -\ No newline at end of file -+add_subdirectory(vit_int8) -+add_subdirectory(ms) -\ No newline at end of file -diff --git a/src/fastertransformer/models/bert/Bert.cc b/src/fastertransformer/models/bert/Bert.cc -index ac727df..0682288 100644 ---- a/src/fastertransformer/models/bert/Bert.cc -+++ b/src/fastertransformer/models/bert/Bert.cc -@@ -255,7 +255,7 @@ void Bert::forward(std::vector* output_tensors, - switch (attention_type_) { - case AttentionType::UNFUSED_MHA: { - invokeBuildEncoderAttentionMask( -- attention_mask_, sequence_lengths, request_batch_size, request_seq_len, stream_); -+ attention_mask_, sequence_lengths, sequence_lengths, request_batch_size, request_seq_len, request_seq_len, 0, stream_); - sync_check_cuda_error(); - invokeGetPaddingOffset(&h_token_num, - token_num_, -@@ -281,7 +281,7 @@ void Bert::forward(std::vector* output_tensors, - } - case AttentionType::UNFUSED_PADDED_MHA: { - invokeBuildEncoderAttentionMask( -- attention_mask_, sequence_lengths, request_batch_size, request_seq_len, stream_); -+ attention_mask_, sequence_lengths, sequence_lengths, request_batch_size, request_seq_len, request_seq_len, 0, stream_); - sync_check_cuda_error(); - h_token_num = request_batch_size * request_seq_len; - bert_input_ptr = (T*)input_tensors->at(0).data; -diff --git a/src/fastertransformer/models/bert_int8/BertINT8.cc b/src/fastertransformer/models/bert_int8/BertINT8.cc -index 7c6347b..5f374ee 100644 ---- a/src/fastertransformer/models/bert_int8/BertINT8.cc -+++ b/src/fastertransformer/models/bert_int8/BertINT8.cc -@@ -180,7 +180,7 @@ void BertINT8::forward(std::vector* output_tensors, - switch (attention_type_) { - case AttentionType::UNFUSED_MHA: { - invokeBuildEncoderAttentionMask( -- attention_mask_, sequence_lengths, request_batch_size, request_seq_len, stream_); -+ attention_mask_, sequence_lengths, sequence_lengths, request_batch_size, request_seq_len, request_seq_len, 0, stream_); - sync_check_cuda_error(); - invokeGetPaddingOffset(&h_token_num, - token_num_, -@@ -206,7 +206,7 @@ void BertINT8::forward(std::vector* output_tensors, - } - case AttentionType::UNFUSED_PADDED_MHA: { - invokeBuildEncoderAttentionMask( -- attention_mask_, sequence_lengths, request_batch_size, request_seq_len, stream_); -+ attention_mask_, sequence_lengths, sequence_lengths, request_batch_size, request_seq_len, request_seq_len, 0,stream_); - sync_check_cuda_error(); - h_token_num = request_batch_size * request_seq_len; - bert_input_ptr = (T*)input_tensors->at(0).data; -diff --git a/src/fastertransformer/models/gptj/CMakeLists.txt b/src/fastertransformer/models/gptj/CMakeLists.txt -index d7d9d3e..e69a988 100644 ---- a/src/fastertransformer/models/gptj/CMakeLists.txt -+++ b/src/fastertransformer/models/gptj/CMakeLists.txt -@@ -19,6 +19,7 @@ set_property(TARGET GptJDecoderLayerWeight PROPERTY POSITION_INDEPENDENT_CODE O - set_property(TARGET GptJDecoderLayerWeight PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) - target_link_libraries(GptJDecoderLayerWeight PUBLIC memory_utils) - -+if(off) - add_library(GptJDecoder STATIC GptJDecoder.cc) - set_property(TARGET GptJDecoder PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET GptJDecoder PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -@@ -40,12 +41,14 @@ target_link_libraries(GptJContextDecoder PUBLIC -lcudart cublasMMWrapper - add_residual_kernels - gpt_kernels - nccl_utils) -+endif() - - add_library(GptJWeight STATIC GptJWeight.cc) - set_property(TARGET GptJWeight PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET GptJWeight PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) - target_link_libraries(GptJWeight PUBLIC GptJDecoderLayerWeight) - -+if(off) - add_library(GptJ STATIC GptJ.cc) - set_property(TARGET GptJ PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET GptJ PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -@@ -58,3 +61,4 @@ target_link_libraries(GptJ PUBLIC -lcudart - BaseBeamSearchLayer - bert_preprocess_kernels - GptJWeight) -+endif() -\ No newline at end of file -diff --git a/src/fastertransformer/models/gptj/GptJ.cc b/src/fastertransformer/models/gptj/GptJ.cc -index 0829e0d..fe41d4b 100644 ---- a/src/fastertransformer/models/gptj/GptJ.cc -+++ b/src/fastertransformer/models/gptj/GptJ.cc -@@ -665,7 +665,7 @@ void GptJ::forward(std::unordered_map* output_tensors, - logits_buf_ + vocab_size_units_offset, - CUDA_R_32F, - vocab_size_padded_, /* n */ -- CUDA_R_32F, -+ CUBLAS_COMPUTE_32F_FAST_TF32, - cublasGemmAlgo_t(-1)); - } - else { -@@ -691,7 +691,7 @@ void GptJ::forward(std::unordered_map* output_tensors, - + tensor_para_.rank_ * local_batch_size * beam_width * local_vocab_size, - CUDA_R_32F, - local_vocab_size, /* n */ -- CUDA_R_32F, -+ CUBLAS_COMPUTE_32F_FAST_TF32, - cublasGemmAlgo_t(-1)); - ftNcclAllGather(nccl_logits_buf_ + vocab_size_units_offset, - nccl_logits_buf_ + vocab_size_units_offset, -@@ -928,7 +928,7 @@ void GptJ::forward(std::unordered_map* output_tensors, - if (output_tensors->count("output_log_probs") > 0 - && output_tensors->at("output_log_probs").data != nullptr) { - ftNcclSend(output_tensors->at("output_log_probs").getPtr(), -- batch_size * beam_width * input_tensors->at("max_output_seq_len").getVal(), -+ output_tensors->at("output_log_probs").size(), - 0, - pipeline_para_, - stream_); -@@ -958,7 +958,7 @@ void GptJ::forward(std::unordered_map* output_tensors, - if (output_tensors->count("output_log_probs") > 0 - && output_tensors->at("output_log_probs").data != nullptr) { - ftNcclRecv(output_tensors->at("output_log_probs").getPtr(), -- batch_size * beam_width * input_tensors->at("max_output_seq_len").getVal(), -+ output_tensors->at("output_log_probs").size(), - pipeline_para_.world_size_ - 1, - pipeline_para_, - stream_); -diff --git a/src/fastertransformer/models/ms/CMakeLists.txt b/src/fastertransformer/models/ms/CMakeLists.txt -new file mode 100644 -index 0000000..8a99ce4 ---- /dev/null -+++ b/src/fastertransformer/models/ms/CMakeLists.txt -@@ -0,0 +1,19 @@ -+# Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. -+ -+cmake_minimum_required(VERSION 3.8) -+ -+add_executable(ms_gemm main.cc) -+# target_link_libraries(ms_gemm PUBLIC -lcudart encoder_gemm_func encoder_igemm_func memory_utils) -+target_link_libraries(ms_gemm PUBLIC -lcudart ms_gemm_func memory_utils) -diff --git a/src/fastertransformer/models/ms/main.cc b/src/fastertransformer/models/ms/main.cc -new file mode 100644 -index 0000000..cd5844f ---- /dev/null -+++ b/src/fastertransformer/models/ms/main.cc -@@ -0,0 +1,179 @@ -+/* -+ * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. -+ * -+ * Licensed under the Apache License, Version 2.0 (the "License"); -+ * you may not use this file except in compliance with the License. -+ * You may obtain a copy of the License at -+ * -+ * http://www.apache.org/licenses/LICENSE-2.0 -+ * -+ * Unless required by applicable law or agreed to in writing, software -+ * distributed under the License is distributed on an "AS IS" BASIS, -+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+ * See the License for the specific language governing permissions and -+ * limitations under the License. -+ */ -+ -+#include "src/fastertransformer/utils/gemm_test/ms_gemm_func.h" -+#include "src/fastertransformer/utils/memory_utils.h" -+ -+namespace ft = fastertransformer; -+ -+struct ms_opt_arg { -+ size_t batch_size; -+ size_t num_layers; -+ size_t seq_len; // source seq len -+ size_t tgt_seq_len; -+ size_t head_num; -+ size_t hidden_size; -+ size_t size_per_head; -+ bool is_remove_padding; -+ int m; -+ int n; -+ int k; -+ std::string model_name; -+ std::string compute_type; -+ std::string w_compute_type; -+ std::string s_compute_type; -+}; -+ -+void usage() { -+ std::cout << "Usage: ms_benchmark -b -l -t " -+ << "-s -H -S -p " -+ << "-T -W -F " -+ << "-m -c -M -N -K \n"; -+} -+ -+bool read_args(int argc, char* argv[], ms_opt_arg* opt_a) { -+ int opt; -+ while ((opt = getopt(argc, argv, "b:l:s:t:H:S:p:m:T:W:F:i:w:M:N:K:")) != -1) { -+ switch (opt) { -+ case 'b': -+ opt_a->batch_size = atoi(optarg); -+ break; -+ case 'l': -+ opt_a->num_layers = atoi(optarg); -+ break; -+ case 's': -+ opt_a->seq_len = atoi(optarg); -+ break; -+ case 't': -+ opt_a->tgt_seq_len = atoi(optarg); -+ break; -+ case 'H': -+ opt_a->head_num = atoi(optarg); -+ break; -+ case 'S': -+ opt_a->hidden_size = atoi(optarg); -+ break; -+ case 'p': -+ opt_a->is_remove_padding = static_cast(atoi(optarg)); -+ break; -+ case 'm': -+ opt_a->model_name = std::string(optarg); -+ break; -+ case 'T': -+ opt_a->compute_type = std::string(optarg); -+ break; -+ case 'W': -+ opt_a->w_compute_type = std::string(optarg); -+ break; -+ case 'F': -+ opt_a->s_compute_type = std::string(optarg); -+ break; -+ case 'M': -+ opt_a->m = atoi(optarg); -+ break; -+ case 'N': -+ opt_a->n = atoi(optarg); -+ break; -+ case 'K': -+ opt_a->k = atoi(optarg); -+ break; -+ case 'i': -+ case 'w': -+ break; -+ case 'h': -+ default: -+ usage(); -+ return false; -+ } -+ } -+ opt_a->size_per_head = opt_a->hidden_size / opt_a->head_num; -+ opt_a->tgt_seq_len = (opt_a->tgt_seq_len == -1) ? opt_a->seq_len : opt_a->tgt_seq_len; -+ return true; -+} -+ -+int main(int argc, char* argv[]) -+{ -+ ms_opt_arg opt_a; -+ opt_a.batch_size = 1; -+ opt_a.num_layers = 1; -+ opt_a.seq_len = 1; -+ opt_a.tgt_seq_len = -1; -+ opt_a.head_num = 1; -+ opt_a.hidden_size = 1; -+ opt_a.size_per_head = 1; -+ opt_a.is_remove_padding = false; -+ opt_a.m = 1; -+ opt_a.n = 1; -+ opt_a.k = 1; -+ opt_a.model_name = ""; -+ opt_a.compute_type = "fp32"; -+ opt_a.w_compute_type = "fp32"; -+ opt_a.s_compute_type = "fp32"; -+ -+ if (!read_args(argc, argv, &opt_a)) { -+ printf("[ERROR] Failed to read arguments. \n"); -+ usage(); -+ return 0; -+ } -+ -+ bool c_type_fp32 = (opt_a.compute_type.compare("fp32") == 0); -+ std::cout << "[INFO] arguments: " << std::endl; -+ std::cout << " batch_size: " << opt_a.batch_size << std::endl; -+ std::cout << " num of layers: " << opt_a.num_layers << std::endl; -+ std::cout << " seq len:" << opt_a.seq_len << std::endl; -+ std::cout << " target seq len: " << opt_a.tgt_seq_len << std::endl; -+ std::cout << " head_num: " << opt_a.head_num << std::endl; -+ std::cout << " size_per_head: " << opt_a.size_per_head << std::endl; -+ // std::cout << " compute_type: " << c_type_fp32 << std::endl; -+ -+ std::cout << std::endl; -+ -+ const int inter_size = 4 * opt_a.head_num * opt_a.size_per_head; -+ const ft::CublasDataType data_type = static_cast(0); // 0 FP32, 1 FP16, 2 BF 16 -+ void* gemm_test_buf; -+ size_t buf_size_in_byte = ft::calGemmTestBufSizeInByte(opt_a.batch_size, -+ opt_a.seq_len, -+ opt_a.head_num, -+ opt_a.size_per_head, -+ inter_size, -+ 0, // default -+ 0, // default -+ data_type); -+ -+ size_t total, free; -+ ft::check_cuda_error(cudaMemGetInfo(&free, &total)); -+ if (free < buf_size_in_byte + 10 * 1024 * 1024) { -+ printf("[ERROR] There is no enough device memory for gemm test!\n" -+ " %ld Bytes is needed, but only %ld Bytes is free.\n", -+ buf_size_in_byte, -+ free); -+ gemm_test_buf = NULL; -+ return -1; -+ } else { -+ ft::deviceMalloc(reinterpret_cast(&gemm_test_buf), buf_size_in_byte, false); -+ } -+ // int fast_algo = 0; -+ if (data_type == ft::FLOAT_DATATYPE) { -+ ft::generate_ms_gemm_config(opt_a.batch_size, opt_a.seq_len, opt_a.tgt_seq_len, opt_a.head_num, opt_a.size_per_head, gemm_test_buf, -+ false); -+ } else { -+ printf("[ERROR] data type only supports fp32(0). \n"); -+ return -1; -+ } -+ // std::cout << "main fast algo: " << fast_algo << std::endl; -+ ft::check_cuda_error(cudaFree(gemm_test_buf)); -+ return 0; -+} -\ No newline at end of file -diff --git a/src/fastertransformer/models/multi_gpu_gpt/CMakeLists.txt b/src/fastertransformer/models/multi_gpu_gpt/CMakeLists.txt -index 10b9e0b..86d733f 100644 ---- a/src/fastertransformer/models/multi_gpu_gpt/CMakeLists.txt -+++ b/src/fastertransformer/models/multi_gpu_gpt/CMakeLists.txt -@@ -37,7 +37,7 @@ set_property(TARGET ParallelGptDecoder PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) - target_link_libraries(ParallelGptDecoder PUBLIC -lcudart TensorParallelGeluFfnLayer - TensorParallelDecoderSelfAttentionLayer layernorm_kernels - add_residual_kernels nccl_utils) -- -+ - add_library(ParallelGpt STATIC ParallelGpt.cc) - set_property(TARGET ParallelGpt PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET ParallelGpt PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -diff --git a/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc b/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc -index 17f9099..d171b4b 100644 ---- a/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc -+++ b/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc -@@ -345,7 +345,7 @@ void ParallelGpt::computeContextCumLogProbs(float* cum_log_probs, - lp_logits_buf_, - CUDA_R_32F, - vocab_size_padded_, /* n */ -- CUDA_R_32F, -+ CUBLAS_COMPUTE_32F_FAST_TF32, - cublasGemmAlgo_t(-1)); - sync_check_cuda_error(); - } -@@ -370,7 +370,7 @@ void ParallelGpt::computeContextCumLogProbs(float* cum_log_probs, - lp_nccl_logits_buf_ + tensor_para_.rank_ * n_hidden_states * local_vocab_size, - CUDA_R_32F, - local_vocab_size, /* n */ -- CUDA_R_32F, -+ CUBLAS_COMPUTE_32F_FAST_TF32, - cublasGemmAlgo_t(-1)); - sync_check_cuda_error(); - ftNcclAllGather(lp_nccl_logits_buf_, -@@ -803,7 +803,7 @@ void ParallelGpt::forward(std::unordered_map* output_ten - logits_buf_ + vocab_size_units_offset, - CUDA_R_32F, - vocab_size_padded_, /* n */ -- CUDA_R_32F, -+ CUBLAS_COMPUTE_32F_FAST_TF32, - cublasGemmAlgo_t(-1)); - } - else { -@@ -829,7 +829,7 @@ void ParallelGpt::forward(std::unordered_map* output_ten - + tensor_para_.rank_ * local_batch_size * beam_width * local_vocab_size, - CUDA_R_32F, - local_vocab_size, /* n */ -- CUDA_R_32F, -+ CUBLAS_COMPUTE_32F_FAST_TF32, - cublasGemmAlgo_t(-1)); - ftNcclAllGather(nccl_logits_buf_ + vocab_size_units_offset, - nccl_logits_buf_ + vocab_size_units_offset, -@@ -1057,7 +1057,7 @@ void ParallelGpt::forward(std::unordered_map* output_ten - if (output_tensors->count("output_log_probs") > 0 - && output_tensors->at("output_log_probs").data != nullptr) { - ftNcclSend(output_tensors->at("output_log_probs").getPtr(), -- batch_size * beam_width * input_tensors->at("max_output_seq_len").getVal(), -+ output_tensors->at("output_log_probs").size(), - 0, - pipeline_para_, - stream_); -@@ -1087,7 +1087,7 @@ void ParallelGpt::forward(std::unordered_map* output_ten - if (output_tensors->count("output_log_probs") > 0 - && output_tensors->at("output_log_probs").data != nullptr) { - ftNcclRecv(output_tensors->at("output_log_probs").getPtr(), -- batch_size * beam_width * input_tensors->at("max_output_seq_len").getVal(), -+ output_tensors->at("output_log_probs").size(), - pipeline_para_.world_size_ - 1, - pipeline_para_, - stream_); -diff --git a/src/fastertransformer/models/t5/CMakeLists.txt b/src/fastertransformer/models/t5/CMakeLists.txt -index 9f3455d..e75bbbd 100644 ---- a/src/fastertransformer/models/t5/CMakeLists.txt -+++ b/src/fastertransformer/models/t5/CMakeLists.txt -@@ -14,6 +14,7 @@ - - cmake_minimum_required(VERSION 3.8) - -+if(False) - add_library(T5Decoder STATIC T5Decoder.cc T5DecoderLayerWeight.cc) - set_property(TARGET T5Decoder PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET T5Decoder PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -@@ -21,6 +22,7 @@ target_link_libraries(T5Decoder PUBLIC -lcudart cublasMMWrapper TensorParallelDe - TensorParallelDecoderCrossAttentionLayer TensorParallelReluFfnLayer - layernorm_kernels add_residual_kernels nccl_utils memory_utils) - -+ - add_library(T5Decoding STATIC T5Decoding.cc T5DecodingWeight.cc) - set_property(TARGET T5Decoding PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET T5Decoding PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -@@ -28,6 +30,8 @@ target_link_libraries(T5Decoding PUBLIC -lcudart cublasMMWrapper T5Decoder bert_ - decoding_kernels DynamicDecodeLayer BaseBeamSearchLayer - beam_search_topk_kernels gpt_kernels) - -+ -+ - add_library(T5Encoder STATIC T5Encoder.cc T5EncoderWeight.cc T5EncoderLayerWeight.cc) - set_property(TARGET T5Encoder PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET T5Encoder PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -@@ -36,4 +40,5 @@ target_link_libraries(T5Encoder PUBLIC -lcudart bert_preprocess_kernels cublasMM - TensorParallelGeluFfnLayer layernorm_kernels add_residual_kernels nccl_utils) - - add_executable(t5_gemm t5_gemm.cc) --target_link_libraries(t5_gemm PUBLIC -lcudart t5_gemm_func memory_utils) -\ No newline at end of file -+target_link_libraries(t5_gemm PUBLIC -lcudart t5_gemm_func memory_utils) -+endif() -\ No newline at end of file -diff --git a/src/fastertransformer/models/t5/T5Encoder.cc b/src/fastertransformer/models/t5/T5Encoder.cc -index 698e3d6..db989ff 100644 ---- a/src/fastertransformer/models/t5/T5Encoder.cc -+++ b/src/fastertransformer/models/t5/T5Encoder.cc -@@ -380,7 +380,7 @@ void T5Encoder::forward(std::unordered_map* output_tenso - request_seq_len, - request_seq_len, - local_batch_size, -- hidden_units_, -+ d_model_, - stream_); - } - else { -diff --git a/src/fastertransformer/models/vit/ViT.cc b/src/fastertransformer/models/vit/ViT.cc -index e785f2b..9a967e4 100644 ---- a/src/fastertransformer/models/vit/ViT.cc -+++ b/src/fastertransformer/models/vit/ViT.cc -@@ -415,7 +415,7 @@ bool ViTTransformer::setSeqLenVec(size_t batch_size) - template - void ViTTransformer::setDefaultMask(size_t batch_size) - { -- invokeBuildEncoderAttentionMask(mask_buf_, seq_len_vec_, batch_size, max_seq_len_, stream_); -+ invokeBuildEncoderAttentionMask(mask_buf_, seq_len_vec_, seq_len_vec_, batch_size, max_seq_len_, max_seq_len_, 0, stream_); - } - - template -diff --git a/src/fastertransformer/models/vit_int8/ViTINT8.cc b/src/fastertransformer/models/vit_int8/ViTINT8.cc -index f610785..44fc5fc 100644 ---- a/src/fastertransformer/models/vit_int8/ViTINT8.cc -+++ b/src/fastertransformer/models/vit_int8/ViTINT8.cc -@@ -462,7 +462,7 @@ bool ViTTransformerINT8::setSeqLenVec(size_t batch_size) - template - void ViTTransformerINT8::setDefaultMask(size_t batch_size) - { -- invokeBuildEncoderAttentionMask(mask_buf_, seq_len_vec_, batch_size, max_seq_len_, stream_); -+ invokeBuildEncoderAttentionMask(mask_buf_, seq_len_vec_, seq_len_vec_, batch_size, max_seq_len_, max_seq_len_, 0, stream_); - } - - template -diff --git a/src/fastertransformer/utils/CMakeLists.txt b/src/fastertransformer/utils/CMakeLists.txt -index 3d0f28a..3d2efbd 100644 ---- a/src/fastertransformer/utils/CMakeLists.txt -+++ b/src/fastertransformer/utils/CMakeLists.txt -@@ -44,10 +44,12 @@ set_property(TARGET memory_utils PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET memory_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) - target_link_libraries(memory_utils PUBLIC -lnvToolsExt) - -+if(EXAMPLES) - add_library(nccl_utils STATIC nccl_utils.cc) - set_property(TARGET nccl_utils PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET nccl_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) - target_link_libraries(nccl_utils PUBLIC -lnccl) -+endif() - - add_library(cublasINT8MMWrapper STATIC cublasINT8MMWrapper.cc) - set_property(TARGET cublasINT8MMWrapper PROPERTY POSITION_INDEPENDENT_CODE ON) -diff --git a/src/fastertransformer/utils/cublasMMWrapper.cc b/src/fastertransformer/utils/cublasMMWrapper.cc -index e291151..e0c6d20 100644 ---- a/src/fastertransformer/utils/cublasMMWrapper.cc -+++ b/src/fastertransformer/utils/cublasMMWrapper.cc -@@ -99,7 +99,7 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa, - void* C, - cudaDataType_t Ctype, - int ldc, -- cudaDataType_t computeType, -+ cublasComputeType_t computeType, - cublasGemmAlgo_t algo) - { - mu_->lock(); -@@ -160,7 +160,7 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa, - - mu_->lock(); - // TODO: default cublas libs -- int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0; -+ int is_fp16_computeType = computeType_ == CUBLAS_COMPUTE_16F ? 1 : 0; - bool using_cublasLt = (Atype_ == CUDA_R_16F) ? true : false; - int batch_count = 1; - // fp32 use cublas as default -@@ -187,14 +187,14 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa, - #if (CUDART_VERSION >= 11000) - cublasComputeType_t computeType; - #else -- cudaDataType_t computeType; -+ cublasComputeType_t computeType; - #endif - - if (is_fp16_computeType) { - #if (CUDART_VERSION >= 11000) - computeType = CUBLAS_COMPUTE_16F; - #else -- computeType = CUDA_R_16F; -+ computeType = CUBLAS_COMPUTE_16F; - #endif - scaleType = CUDA_R_16F; - } -@@ -302,7 +302,7 @@ void cublasMMWrapper::setFP32GemmConfig() - Atype_ = CUDA_R_32F; - Btype_ = CUDA_R_32F; - Ctype_ = CUDA_R_32F; -- computeType_ = CUDA_R_32F; -+ computeType_ = CUBLAS_COMPUTE_32F_FAST_TF32; - } - - void cublasMMWrapper::setFP16GemmConfig() -@@ -310,7 +310,23 @@ void cublasMMWrapper::setFP16GemmConfig() - Atype_ = CUDA_R_16F; - Btype_ = CUDA_R_16F; - Ctype_ = CUDA_R_16F; -- computeType_ = CUDA_R_32F; -+ computeType_ = CUBLAS_COMPUTE_16F; -+} -+ -+void cublasMMWrapper::setFP32MixedGemmConfig() -+{ -+ Atype_ = CUDA_R_32F; -+ Btype_ = CUDA_R_16F; -+ Ctype_ = CUDA_R_32F; -+ computeType_ = CUBLAS_COMPUTE_32F_FAST_TF32; -+} -+ -+void cublasMMWrapper::setFP16MixedGemmConfig() -+{ -+ Atype_ = CUDA_R_16F; -+ Btype_ = CUDA_R_32F; -+ Ctype_ = CUDA_R_32F; -+ computeType_ = CUBLAS_COMPUTE_32F_FAST_TF32; - } - - #ifdef ENABLE_BF16 -@@ -319,14 +335,14 @@ void cublasMMWrapper::setBF16GemmConfig() - Atype_ = CUDA_R_16BF; - Btype_ = CUDA_R_16BF; - Ctype_ = CUDA_R_16BF; -- computeType_ = CUDA_R_32F; -+ computeType_ = CUBLAS_COMPUTE_16F; - } - #endif - - void cublasMMWrapper::setGemmConfig(cudaDataType_t aType, - cudaDataType_t bType, - cudaDataType_t cType, -- cudaDataType_t computeType) -+ cublasComputeType_t computeType) - { - Atype_ = aType; - Btype_ = bType; -@@ -451,7 +467,7 @@ void cublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, - half h_beta = (half)f_beta; - - mu_->lock(); -- int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0; -+ int is_fp16_computeType = computeType_ == CUBLAS_COMPUTE_16F ? 1 : 0; - const void* alpha = - is_fp16_computeType ? reinterpret_cast(&h_alpha) : reinterpret_cast(&f_alpha); - const void* beta = is_fp16_computeType ? reinterpret_cast(&h_beta) : reinterpret_cast(&f_beta); -@@ -504,13 +520,13 @@ void cublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, - const int ldc, - const int64_t strideC, - const int batch_count, -- cudaDataType_t computeType) -+ cublasComputeType_t computeType) - { - half h_alpha = (half)f_alpha; - half h_beta = (half)f_beta; - - mu_->lock(); -- int is_fp16_computeType = computeType == CUDA_R_16F ? 1 : 0; -+ int is_fp16_computeType = computeType == CUBLAS_COMPUTE_16F ? 1 : 0; - const void* alpha = - is_fp16_computeType ? reinterpret_cast(&h_alpha) : reinterpret_cast(&f_alpha); - const void* beta = is_fp16_computeType ? reinterpret_cast(&h_beta) : reinterpret_cast(&f_beta); -@@ -563,7 +579,7 @@ void cublasMMWrapper::batchedGemm(cublasOperation_t transa, - half h_beta = (half)0.0f; - - mu_->lock(); -- int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0; -+ int is_fp16_computeType = computeType_ == CUBLAS_COMPUTE_16F ? 1 : 0; - const void* alpha = is_fp16_computeType ? reinterpret_cast(&h_alpha) : reinterpret_cast(&f_alpha); - const void* beta = is_fp16_computeType ? reinterpret_cast(&h_beta) : reinterpret_cast(&f_beta); - cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo(batch_count, m, n, k, getCublasDataType(Atype_)); -diff --git a/src/fastertransformer/utils/cublasMMWrapper.h b/src/fastertransformer/utils/cublasMMWrapper.h -index 6f410ab..21a8ea8 100644 ---- a/src/fastertransformer/utils/cublasMMWrapper.h -+++ b/src/fastertransformer/utils/cublasMMWrapper.h -@@ -41,7 +41,7 @@ private: - cudaDataType_t Atype_; - cudaDataType_t Btype_; - cudaDataType_t Ctype_; -- cudaDataType_t computeType_; -+ cublasComputeType_t computeType_; - - cudaStream_t stream_; - cublasAlgoMap* cublas_algo_map_; -@@ -90,7 +90,7 @@ public: - void* C, - cudaDataType_t Ctype, - int ldc, -- cudaDataType_t computeType, -+ cublasComputeType_t computeType, - cublasGemmAlgo_t algo); - - void Gemm(cublasOperation_t transa, -@@ -121,12 +121,14 @@ public: - - void setFP32GemmConfig(); - void setFP16GemmConfig(); -+ void setFP32MixedGemmConfig(); -+ void setFP16MixedGemmConfig(); - #ifdef ENABLE_BF16 - void setBF16GemmConfig(); - #endif - void setStream(cudaStream_t stream); - -- void setGemmConfig(cudaDataType_t aType, cudaDataType_t bType, cudaDataType_t cType, cudaDataType_t computeType); -+ void setGemmConfig(cudaDataType_t aType, cudaDataType_t bType, cudaDataType_t cType, cublasComputeType_t computeType); - - CublasDataType getCublasDataType(cudaDataType_t data_type); - -@@ -183,7 +185,7 @@ public: - const int ldc, - const int64_t strideC, - const int batch_count, -- cudaDataType_t computeType); -+ cublasComputeType_t computeType); - - void batchedGemm(cublasOperation_t transa, - cublasOperation_t transb, -diff --git a/src/fastertransformer/utils/cuda_utils.h b/src/fastertransformer/utils/cuda_utils.h -index 5d73c87..aef6ab9 100644 ---- a/src/fastertransformer/utils/cuda_utils.h -+++ b/src/fastertransformer/utils/cuda_utils.h -@@ -382,7 +382,7 @@ public: - - static double diffTime(timeval start, timeval end) - { -- return (end.tv_sec - start.tv_sec) * 1000 + (end.tv_usec - start.tv_usec) * 0.001; -+ return (end.tv_sec - start.tv_sec) * 1000000 + (end.tv_usec - start.tv_usec); - } - - /* ***************************** common utils ****************************** */ -diff --git a/src/fastertransformer/utils/custom_ar_comm.cc b/src/fastertransformer/utils/custom_ar_comm.cc -index ded1e58..159faaf 100644 ---- a/src/fastertransformer/utils/custom_ar_comm.cc -+++ b/src/fastertransformer/utils/custom_ar_comm.cc -@@ -54,6 +54,7 @@ void CustomAllReduceComm::customAllReduce(size_t elts, cudaStream_t stream) - output_tensor_->at(0).data = (const void*)tmp_tensor_data_; - } - -+ - template - void CustomAllReduceComm::allocateAndExchangePeerAccessPointer( - std::vector>* custom_all_reduce_comms) -diff --git a/src/fastertransformer/utils/gemm_test/CMakeLists.txt b/src/fastertransformer/utils/gemm_test/CMakeLists.txt -index 223b85d..ab48356 100644 ---- a/src/fastertransformer/utils/gemm_test/CMakeLists.txt -+++ b/src/fastertransformer/utils/gemm_test/CMakeLists.txt -@@ -49,6 +49,10 @@ set(swin_gemm_func_files - swin_gemm_func.cc - ) - -+set(ms_gemm_func_files -+ ms_gemm_func.cc -+) -+ - add_library(gemm_func STATIC ${gemm_func_files}) - target_link_libraries(gemm_func PUBLIC -lcublas -lcublasLt -lcudart) - set_property(TARGET gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON) -@@ -109,3 +113,12 @@ add_library(swin_gemm_func STATIC ${swin_gemm_func_files}) - target_link_libraries(swin_gemm_func PUBLIC -lcublas -lcublasLt -lcudart gemm_func) - set_property(TARGET swin_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET swin_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -+ -+add_library(ms_gemm_func STATIC ${ms_gemm_func_files}) -+if (SPARSITY_SUPPORT) -+target_link_libraries(ms_gemm_func PUBLIC -lcublas -lcublasLt -lcudart gemm_func -lcusparse -lcusparseLt) -+else() -+target_link_libraries(ms_gemm_func PUBLIC -lcublas -lcublasLt -lcudart gemm_func) -+endif() -+set_property(TARGET ms_gemm_func PROPERTY POSITION_INDEPENDENT_CODE ON) -+set_property(TARGET ms_gemm_func PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -\ No newline at end of file -diff --git a/src/fastertransformer/utils/gemm_test/encoder_gemm_func.cc b/src/fastertransformer/utils/gemm_test/encoder_gemm_func.cc -index 03c6947..00f8ca0 100644 ---- a/src/fastertransformer/utils/gemm_test/encoder_gemm_func.cc -+++ b/src/fastertransformer/utils/gemm_test/encoder_gemm_func.cc -@@ -26,11 +26,11 @@ void generate_encoder_gemm_config( - void* buffer; - int workSpaceSize; - --#ifdef ENABLE_BF16 -- if (std::is_same::value || std::is_same::value) { --#else -+// #ifdef ENABLE_BF16 -+// if (std::is_same::value || std::is_same::value) { -+// #else - if (std::is_same::value) { --#endif // ENABLE_BF16 -+// #endif // ENABLE_BF16 - // cublas_workspace_ should be the start pointer of cudaMalloc() - // to ensure 16B alignemnet - cublas_workspace = buffer_in; -diff --git a/src/fastertransformer/utils/gemm_test/encoder_gemm_func.h b/src/fastertransformer/utils/gemm_test/encoder_gemm_func.h -index fd067b9..4bf3d6c 100644 ---- a/src/fastertransformer/utils/gemm_test/encoder_gemm_func.h -+++ b/src/fastertransformer/utils/gemm_test/encoder_gemm_func.h -@@ -36,5 +36,4 @@ namespace fastertransformer { - template - void generate_encoder_gemm_config( - int batch_size, int seq_len, int head_num, int size_per_head, void* buffer, bool isAppend = true); -- - } // namespace fastertransformer -diff --git a/src/fastertransformer/utils/gemm_test/gemm_func.cc b/src/fastertransformer/utils/gemm_test/gemm_func.cc -index edbfc40..6187d45 100644 ---- a/src/fastertransformer/utils/gemm_test/gemm_func.cc -+++ b/src/fastertransformer/utils/gemm_test/gemm_func.cc -@@ -534,7 +534,6 @@ int LtHgemmCustomFind(cublasLtHandle_t ltHandle, - } - } - } -- - // workspacesize==0 - printf("workspacesize==0, run %d algos\n", AlgoCountRestrict); - for (int i = 0; i < AlgoCountRestrict && i < (maxNumTraversal - nbAlgoIds); i++) { -@@ -594,7 +593,8 @@ CLEANUP: - if (stopEvent) { - cudaEventDestroy(stopEvent); - } -- return status == CUBLAS_STATUS_SUCCESS ? 0 : 1; -+ return AlgoCount; -+ // return status == CUBLAS_STATUS_SUCCESS ? 0 : 1; - } - - template int LtHgemmCustomFind(cublasLtHandle_t ltHandle, -@@ -634,7 +634,6 @@ template int LtHgemmCustomFind(cublasLtHandle_t ltHandle, - FILE* fout, - customMatmulPerf_t perfResults[], - int AlgoCombinations); -- - #ifdef ENABLE_BF16 - template int LtHgemmCustomFind(cublasLtHandle_t ltHandle, - int batch_size, -diff --git a/src/fastertransformer/utils/gemm_test/ms_gemm_func.cc b/src/fastertransformer/utils/gemm_test/ms_gemm_func.cc -new file mode 100644 -index 0000000..e8f88fe ---- /dev/null -+++ b/src/fastertransformer/utils/gemm_test/ms_gemm_func.cc -@@ -0,0 +1,364 @@ -+/* -+ * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. -+ * -+ * Licensed under the Apache License, Version 2.0 (the "License"); -+ * you may not use this file except in compliance with the License. -+ * You may obtain a copy of the License at -+ * -+ * http://www.apache.org/licenses/LICENSE-2.0 -+ * -+ * Unless required by applicable law or agreed to in writing, software -+ * distributed under the License is distributed on an "AS IS" BASIS, -+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+ * See the License for the specific language governing permissions and -+ * limitations under the License. -+ */ -+ -+#include "src/fastertransformer/utils/gemm_test/ms_gemm_func.h" -+ -+namespace fastertransformer { -+ -+template -+void generate_ms_gemm_config( -+ int batch_size, int seq_len, int tgt_seq_len, int head_num, int size_per_head, void* buffer_in, bool isAppend) -+{ -+ void* cublas_workspace; -+ void* buffer; -+ int workSpaceSize; -+ -+#ifdef ENABLE_BF16 -+ if (std::is_same::value || std::is_same::value) { -+#else -+ if (std::is_same::value) { -+#endif // ENABLE_BF16 -+ // cublas_workspace_ should be the start pointer of cudaMalloc() -+ // to ensure 16B alignemnet -+ cublas_workspace = buffer_in; -+ buffer = (void*)((char*)cublas_workspace + CUBLAS_WORKSPACE_SIZE); -+ workSpaceSize = CUBLAS_WORKSPACE_SIZE; -+ } -+ else { -+ cublas_workspace = nullptr; -+ buffer = buffer_in; -+ workSpaceSize = 0; -+ } -+ -+ struct cudaDeviceProp prop; -+ check_cuda_error(cudaGetDeviceProperties(&prop, 0)); -+ printf("Device %s\n", prop.name); -+ -+ // check config -+ FILE* fd; -+ int line_count = 0; -+ if (!isAppend) { -+ fd = fopen(GEMM_CONFIG, "w+"); -+ } -+ else { -+ fd = fopen(GEMM_CONFIG, "a+"); -+ std::vector config; -+ char line[1024]; -+ while (fgets(line, 1024, fd) != NULL) { -+ config.push_back(std::string(line)); -+ } -+ line_count = config.size(); -+ if (config.size() >= (MAX_CONFIG_NUM * GEMM_NUM + 1)) // 6 cublas/cublasLt, first row is not included -+ { -+ int startIdx = config.size() - ((MAX_CONFIG_NUM - 1) * GEMM_NUM); -+ fclose(fd); -+ fd = fopen(GEMM_CONFIG, "w+"); -+ fprintf(fd, "%s", config[0].c_str()); -+ for (uint i = startIdx; i < config.size(); i++) { -+ fprintf(fd, "%s", config[i].c_str()); -+ } -+ line_count = config.size() - (GEMM_NUM + 3); -+ } -+ } -+ -+ const int gemm_num = 4; -+ int M[gemm_num]; -+ int N[gemm_num]; -+ int K[gemm_num]; -+ int batchCount[gemm_num] = {1, 1, 1, 1}; -+ char mess[gemm_num][256]; -+ float exec_times[gemm_num]; -+ int gemm_lds[gemm_num][3]; // = {3 * hidden_size, hidden_size, 3 * hidden_size}; -+ cublasOperation_t gemm_ops[gemm_num][2]; // = {CUBLAS_OP_N, CUBLAS_OP_N}; -+ int gemm_strides[2][3]; -+ -+ // gemm1 -+ // int gemm_dims[] = {3 * hidden_size, request_batch_size * request_src_seq_len, hidden_size}; -+ int hidden_size = head_num * size_per_head; -+ M[0] = 3 * hidden_size; -+ N[0] = batch_size * seq_len; -+ K[0] = hidden_size; -+ gemm_lds[0][0] = 3 * hidden_size; -+ gemm_lds[0][1] = hidden_size; -+ gemm_lds[0][2] = 3 * hidden_size; -+ gemm_ops[0][0] = CUBLAS_OP_N; -+ gemm_ops[0][1] = CUBLAS_OP_N; -+ strcpy(mess[0], "cublasGemmEx "); -+ -+ // gemm2 -+ M[1] = tgt_seq_len; -+ N[1] = seq_len; -+ K[1] = size_per_head; -+ gemm_ops[1][0] = CUBLAS_OP_T; -+ gemm_ops[1][1] = CUBLAS_OP_N; -+ -+ gemm_lds[1][0] = size_per_head; -+ gemm_lds[1][1] = size_per_head; -+ gemm_lds[1][2] = tgt_seq_len; -+ -+ gemm_strides[0][0] = tgt_seq_len * size_per_head; -+ gemm_strides[0][1] = seq_len * size_per_head; -+ gemm_strides[0][2] = seq_len * tgt_seq_len; -+ strcpy(mess[1], "cublasGemmStridedBatchedEx"); -+ -+ // gemm3 -+ M[2] = size_per_head; -+ N[2] = seq_len; -+ K[2] = tgt_seq_len; -+ gemm_ops[2][0] = CUBLAS_OP_N; -+ gemm_ops[2][1] = CUBLAS_OP_N; -+ -+ gemm_lds[2][0] = size_per_head; -+ gemm_lds[2][1] = tgt_seq_len; -+ gemm_lds[2][2] = size_per_head; -+ -+ gemm_strides[1][0] = tgt_seq_len * size_per_head; -+ gemm_strides[1][1] = seq_len * tgt_seq_len; -+ gemm_strides[1][2] = seq_len * size_per_head; -+ strcpy(mess[2], "cublasGemmStridedBatchedEx"); -+ -+ // gemm4 -+ M[3] = hidden_size; -+ N[3] = batch_size * seq_len; -+ K[3] = hidden_size; -+ gemm_ops[3][0] = CUBLAS_OP_N; -+ gemm_ops[3][1] = CUBLAS_OP_N; -+ -+ gemm_lds[3][0] = hidden_size; -+ gemm_lds[3][1] = hidden_size; -+ gemm_lds[3][2] = hidden_size; -+ strcpy(mess[3], "cublasGemmEx"); -+ -+ cublasHandle_t cublas_handle; -+ check_cuda_error(cublasCreate(&cublas_handle)); -+ cublasLtHandle_t ltHandle; -+ check_cuda_error(cublasLtCreate(<Handle)); -+ -+ cudaDataType_t AType; -+ cudaDataType_t BType; -+ cudaDataType_t CType; -+ cublasComputeType_t computeType; -+ int startAlgo, endAlgo; -+ const int ites = 10000; -+ const int warmup_ites = 10000; -+ struct timeval start, end; -+ -+ CublasDataType data_type; -+ if (std::is_same::value) { -+ data_type = FLOAT_DATATYPE; -+ AType = CUDA_R_32F; -+ BType = CUDA_R_32F; -+ CType = CUDA_R_32F; -+ computeType = CUBLAS_COMPUTE_32F_FAST_TF32; -+ startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; -+ endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; -+ } -+ else if (std::is_same::value) { -+ data_type = HALF_DATATYPE; -+ AType = CUDA_R_16F; -+ BType = CUDA_R_16F; -+ CType = CUDA_R_16F; -+ computeType = CUBLAS_COMPUTE_16F; -+ startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; -+ endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; -+ } -+#ifdef ENABLE_BF16 -+ else if (std::is_same::value) { -+ data_type = BFLOAT16_DATATYPE; -+ AType = CUDA_R_16BF; -+ BType = CUDA_R_16BF; -+ CType = CUDA_R_16BF; -+ computeType = CUBLAS_COMPUTE_32F; -+ startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; -+ endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; -+ } -+#endif -+ using scaleT = typename ScaleTypeConverter::Type; -+ -+ scaleT alpha = (scaleT)1.0f; -+ scaleT beta = (scaleT)0.0f; -+ -+ printf("***Encoder Gemm Testing Begin***\n"); -+ printf("***Cublas Gemm Testing Begin***\n"); -+ if (line_count == 0) { -+ fprintf(fd, -+ "batch_size, seq_len, head_num, size_per_head dataType ### batchCount, n, m, k, algoId, " -+ "customOption, tile, numSplitsK, swizzle, reductionScheme, workspaceSize, stages, exec_time\n"); -+ } -+ for (int i = 0; i < gemm_num; ++i) { -+ // if(i != 0 && i != 5) continue; -+ -+ int m = M[i], n = N[i], k = K[i]; -+ printf("\n-----------------------------\n"); -+ printf("GEMM test %d: [M: %d, K: %d, N: %d] %s\n", i, m, k, n, mess[i]); -+ // printf("GEMM test %d: [M: %d, K: %d, N: %d] \n", i, m, k, n); -+ T* d_A = (T*)buffer; -+ T* d_B = d_A + m * k * batchCount[i]; -+ T* d_C = d_B + k * n * batchCount[i]; -+ -+ // array of pointer for batchedGemm -+ T* harray[12]; -+ harray[0] = (T*)buffer; -+ harray[1] = (T*)((char*)buffer + sizeof(T) * m * k); -+ harray[2] = (T*)((char*)buffer + 2 * sizeof(T) * m * k); -+ harray[4] = (T*)((char*)buffer + 3 * sizeof(T) * m * k); -+ harray[5] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + sizeof(T) * k * n); -+ harray[6] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + 2 * sizeof(T) * k * n); -+ harray[8] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + 3 * sizeof(T) * k * n); -+ harray[9] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + 3 * sizeof(T) * k * n + sizeof(T) * m * n); -+ harray[10] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + 3 * sizeof(T) * k * n + 2 * sizeof(T) * m * n); -+ -+ T** darray = 0; -+ check_cuda_error(cudaMalloc((void**)&darray, sizeof(T*) * 12)); -+ cudaMemcpy((void*)darray, (void*)harray, sizeof(T*) * 12, cudaMemcpyHostToDevice); -+ T** dAarray = darray; -+ T** dBarray = darray + 4; -+ T** dCarray = darray + 8; -+ -+ float exec_time = 99999.0f; -+ int fast_algo = 0; -+ -+ // warmup -+ // for (int j = 0; j < ites*10; j++) { -+ // cublasGemmEx(cublas_handle, gemm_ops[i][0], gemm_ops[i][1], m, n, k, &alpha, d_A, AType, gemm_lds[i][0], d_B, BType, -+ // gemm_lds[i][1], &beta, d_C, CType, gemm_lds[i][2], computeType, static_cast(0)); -+ // } -+ -+ for (int algo = startAlgo; algo <= endAlgo; algo++) { -+ cublasStatus_t status; -+ //warmup -+ for (int ite = 0; ite < warmup_ites; ++ite) { -+ if ((i == 0) || (i == 3)) { -+ status = cublasGemmEx(cublas_handle, gemm_ops[i][0], gemm_ops[i][1], m, n, k, &alpha, d_A, AType, gemm_lds[i][0], d_B, BType, -+ gemm_lds[i][1], &beta, d_C, CType, gemm_lds[i][2], computeType, static_cast(algo)); -+ } else { -+ status = cublasGemmStridedBatchedEx(cublas_handle, gemm_ops[i][0], gemm_ops[i][1], m, n, k, &alpha, d_A, AType, gemm_lds[i][0], -+ gemm_strides[i-1][0], d_B, BType, gemm_lds[i][1], gemm_strides[i-1][1], &beta, d_C, CType, -+ gemm_lds[i][2], gemm_strides[i-1][2], batch_size, computeType, static_cast(algo)); -+ } -+ } -+ cudaDeviceSynchronize(); -+ gettimeofday(&start, NULL); -+ if ((i == 0) || (i == 3)) { -+ for (int ite = 0; ite < ites; ++ite) { -+ status = cublasGemmEx(cublas_handle, gemm_ops[i][0], gemm_ops[i][1], m, n, k, &alpha, d_A, AType, gemm_lds[i][0], d_B, BType, -+ gemm_lds[i][1], &beta, d_C, CType, gemm_lds[i][2], computeType, static_cast(algo)); -+ } -+ } else { -+ for (int ite = 0; ite < ites; ++ite) { -+ status = cublasGemmStridedBatchedEx(cublas_handle, gemm_ops[i][0], gemm_ops[i][1], m, n, k, &alpha, d_A, AType, gemm_lds[i][0], -+ gemm_strides[i-1][0], d_B, BType, gemm_lds[i][1], gemm_strides[i-1][1], &beta, d_C, CType, -+ gemm_lds[i][2], gemm_strides[i-1][2], batch_size, computeType, static_cast(algo)); -+ } -+ } -+ -+ if (status != CUBLAS_STATUS_SUCCESS) { -+ break; -+ } -+ // } -+ cudaDeviceSynchronize(); -+ gettimeofday(&end, NULL); -+ if (status == CUBLAS_STATUS_SUCCESS) { -+ printf("algo_%d costs %.6fms \n", algo, diffTime(start, end) / ites); -+ if (diffTime(start, end) / ites < exec_time) { -+ exec_time = diffTime(start, end) / ites; -+ fast_algo = algo; -+ } -+ } -+ } -+ printf("fast_algo %d costs %.6f ms \n", fast_algo, exec_time); -+ -+ // for fp16 and bf16, we compare cublasLt -+ if (i < 3 && data_type != FLOAT_DATATYPE) { -+ printf("***cublasLt Gemm Testing Beign***\n"); -+ // Let try a fixed number of combinations -+ int ALGO_COMBINATIONS = 5000; -+ customMatmulPerf_t perfResults[ALGO_COMBINATIONS]; -+ LtHgemmCustomFind(ltHandle, -+ batch_size, -+ seq_len, -+ head_num, -+ size_per_head, -+ n, -+ m, -+ k, -+ &alpha, -+ d_B, -+ d_A, -+ &beta, -+ d_C, -+ cublas_workspace, -+ workSpaceSize, -+ fd, -+ perfResults, -+ ALGO_COMBINATIONS); -+ if (perfResults[0].time < exec_time) { -+ printPerfStructure( -+ batch_size, seq_len, head_num, size_per_head, n, m, k, perfResults[0], fd, data_type, 0); -+ exec_time = perfResults[0].time; -+ } -+ else { -+ fprintf(fd, -+ "%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 %f\n", -+ batch_size, -+ seq_len, -+ head_num, -+ size_per_head, -+ data_type, -+ batchCount[i], -+ n, -+ m, -+ k, -+ fast_algo, -+ exec_time); -+ } -+ printf("***cublasLt Gemm Testing End***\n"); -+ } -+ else { -+ fprintf(fd, -+ "%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 %f\n", -+ batch_size, -+ seq_len, -+ head_num, -+ size_per_head, -+ data_type, -+ batchCount[i], -+ n, -+ m, -+ k, -+ fast_algo, -+ exec_time); -+ } -+ exec_times[i] = exec_time; -+ cudaFree(darray); -+ } -+ printf("***cublas Gemm Testing End***\n\n"); -+ fclose(fd); -+ printf("***Encoder Gemm Testing End***\n"); -+ -+ return; -+} -+ -+template void generate_ms_gemm_config( -+ int batch_size, int seq_len, int tgt_seq_len, int head_num, int size_per_head, void* buffer, bool isAppend); -+template void generate_ms_gemm_config( -+ int batch_size, int seq_len, int tgt_seq_len, int head_num, int size_per_head, void* buffer, bool isAppend); -+#ifdef ENABLE_BF16 -+template void generate_ms_gemm_config<__nv_bfloat16>( -+ int batch_size, int seq_len, int tgt_seq_len, int head_num, int size_per_head, void* buffer, bool isAppend); -+#endif -+ -+} // namespace fastertransformer -diff --git a/src/fastertransformer/utils/gemm_test/ms_gemm_func.h b/src/fastertransformer/utils/gemm_test/ms_gemm_func.h -new file mode 100644 -index 0000000..c6f68ca ---- /dev/null -+++ b/src/fastertransformer/utils/gemm_test/ms_gemm_func.h -@@ -0,0 +1,40 @@ -+/* -+ * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. -+ * -+ * Licensed under the Apache License, Version 2.0 (the "License"); -+ * you may not use this file except in compliance with the License. -+ * You may obtain a copy of the License at -+ * -+ * http://www.apache.org/licenses/LICENSE-2.0 -+ * -+ * Unless required by applicable law or agreed to in writing, software -+ * distributed under the License is distributed on an "AS IS" BASIS, -+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+ * See the License for the specific language governing permissions and -+ * limitations under the License. -+ */ -+ -+#pragma once -+ -+#include "src/fastertransformer/utils/cublasAlgoMap.h" -+#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" -+#include "src/fastertransformer/utils/cuda_utils.h" -+#include "src/fastertransformer/utils/gemm_test/gemm_func.h" -+ -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+ -+namespace fastertransformer { -+ -+template -+void generate_ms_gemm_config( -+ int batch_size, int seq_len, int tgt_seq_len, int head_num, int size_per_head, void* buffer, bool isAppend = true); -+ -+} // namespace fastertransformer -diff --git a/src/fastertransformer/utils/logger.h b/src/fastertransformer/utils/logger.h -index bcdf8fa..e3e7007 100644 ---- a/src/fastertransformer/utils/logger.h -+++ b/src/fastertransformer/utils/logger.h -@@ -65,7 +65,7 @@ private: - #else - const Level DEFAULT_LOG_LEVEL = INFO; - #endif -- Level level_ = DEFAULT_LOG_LEVEL; -+ Level level_ = ERROR; // DEFAULT_LOG_LEVEL; - - Logger() - { -diff --git a/tests/unittests/test_gemm.cu b/tests/unittests/test_gemm.cu -index 13719f7..4ecf0bd 100644 ---- a/tests/unittests/test_gemm.cu -+++ b/tests/unittests/test_gemm.cu -@@ -157,7 +157,7 @@ void computeReference(GemmOp transa, - cudaDataType_t atype = (A.type == TYPE_FP16) ? CUDA_R_16F : CUDA_R_32F; - cudaDataType_t btype = (B.type == TYPE_FP16) ? CUDA_R_16F : CUDA_R_32F; - cudaDataType_t ctype = (C.type == TYPE_FP16) ? CUDA_R_16F : CUDA_R_32F; -- cudaDataType_t compute_type = (computeType == TYPE_FP16) ? CUDA_R_16F : CUDA_R_32F; -+ cublasComputeType_t compute_type = (computeType == TYPE_FP16) ? CUBLAS_COMPUTE_16F : CUBLAS_COMPUTE_32F_FAST_TF32; - - cublasHandle_t cublas_handle; - check_cuda_error(cublasCreate(&cublas_handle)); -@@ -391,7 +391,11 @@ void testGemmConsistencyMatmul(size_t m, size_t n, size_t k) { - - cudaDataType_t cuda_dtype = std::is_same::value ? CUDA_R_32F : CUDA_R_16F; - cudaDataType_t cuda_ctype = (DataType::TYPE_FP32 == computeType) ? CUDA_R_32F : CUDA_R_16F; -- cublas_wrapper.setGemmConfig(cuda_dtype, cuda_dtype, cuda_dtype, cuda_ctype); -+ // add culab type -+ cublasComputeType_t cublasComputeType = (DataType::TYPE_FP32 == computeType) ? CUBLAS_COMPUTE_32F_FAST_TF32 : CUBLAS_COMPUTE_16F; -+ cublas_wrapper.setGemmConfig(cuda_dtype, cuda_dtype, cuda_dtype, cublasComputeType); -+ //before change -+ // cublas_wrapper.setGemmConfig(cuda_dtype, cuda_dtype, cuda_dtype, cuda_ctype); - - std::shared_ptr gemm = createGemm(&allocator, stream, false, false); - gemm->setTypes(a_tensor.type, b_tensor.type, c_tensor.type, computeType); -@@ -506,8 +510,12 @@ void testGemmConsistencyBatchedMatmul(size_t m, size_t n, size_t k) { - &allocator); - - cudaDataType_t dtype = std::is_same::value ? CUDA_R_32F : CUDA_R_16F; -- cudaDataType_t ctype = (computeType == DataType::TYPE_FP32) ? CUDA_R_32F : CUDA_R_16F; -+ // cudaDataType_t ctype = (computeType == DataType::TYPE_FP32) ? CUDA_R_32F : CUDA_R_16F; -+ // add culab type -+ cublasComputeType_t ctype = (computeType == DataType::TYPE_FP32) ? CUBLAS_COMPUTE_32F_FAST_TF32 : CUBLAS_COMPUTE_16F; - cublas_wrapper.setGemmConfig(dtype, dtype, dtype, ctype); -+ //before change -+ // cublas_wrapper.setGemmConfig(dtype, dtype, dtype, ctype); - - std::shared_ptr gemm = createGemm(&allocator, stream, false, false); - gemm->setTypes(a_type, b_type, c_type, computeType); -@@ -606,8 +614,12 @@ void testGemmConsistencyStridedBatchedMatmul(size_t batch_size, size_t m, size_t - &allocator); - - cudaDataType_t dtype = std::is_same::value ? CUDA_R_32F : CUDA_R_16F; -- cudaDataType_t ctype = (computeType == DataType::TYPE_FP32) ? CUDA_R_32F : CUDA_R_16F; -+ // add culab type -+ cublasComputeType_t ctype = (computeType == DataType::TYPE_FP32) ? CUBLAS_COMPUTE_32F_FAST_TF32 : CUBLAS_COMPUTE_16F; - cublas_wrapper.setGemmConfig(dtype, dtype, dtype, ctype); -+ //before change -+ // cudaDataType_t ctype = (computeType == DataType::TYPE_FP32) ? CUDA_R_32F : CUDA_R_16F; -+ // cublas_wrapper.setGemmConfig(dtype, dtype, dtype, ctype); - - std::shared_ptr gemm = createGemm(&allocator, stream, false, false); - gemm->setTypes(a_tensor.type, b_tensor.type, c_tensor.type, computeType); -@@ -647,7 +659,7 @@ void testGemmConsistencyStridedBatchedMatmul(size_t batch_size, size_t m, size_t - ldc, - stridec, - batch_size, -- getCublasDataType(computeType)); -+ getCublasComputeType(computeType)); - - c_tensor.setInvalidValues(); // to guarantee C has invalid data - gemm->stridedBatchedGemm(op_pair.transa, op_pair.transb, m, n, k, diff --git a/third_party/patch/ffmpeg/CVE-2022-3964.patch b/third_party/patch/ffmpeg/CVE-2022-3964.patch deleted file mode 100644 index 493bf023e10..00000000000 --- a/third_party/patch/ffmpeg/CVE-2022-3964.patch +++ /dev/null @@ -1,72 +0,0 @@ -diff -Npur ffmpeg-5.1.2/libavcodec/rpzaenc.c ffmpeg-5.1.2-change/libavcodec/rpzaenc.c ---- ffmpeg-5.1.2/libavcodec/rpzaenc.c 2022-07-23 01:58:39.000000000 +0800 -+++ ffmpeg-5.1.2-change/libavcodec/rpzaenc.c 2024-06-25 15:56:07.594394836 +0800 -@@ -205,7 +205,7 @@ static void get_max_component_diff(Block - - // loop thru and compare pixels - for (y = 0; y < bi->block_height; y++) { -- for (x = 0; x < bi->block_width; x++){ -+ for (x = 0; x < bi->block_width; x++) { - // TODO: optimize - min_r = FFMIN(R(block_ptr[x]), min_r); - min_g = FFMIN(G(block_ptr[x]), min_g); -@@ -277,7 +277,7 @@ static int leastsquares(uint16_t *block_ - return -1; - - for (i = 0; i < bi->block_height; i++) { -- for (j = 0; j < bi->block_width; j++){ -+ for (j = 0; j < bi->block_width; j++) { - x = GET_CHAN(block_ptr[j], xchannel); - y = GET_CHAN(block_ptr[j], ychannel); - sumx += x; -@@ -324,7 +324,7 @@ static int calc_lsq_max_fit_error(uint16 - int max_err = 0; - - for (i = 0; i < bi->block_height; i++) { -- for (j = 0; j < bi->block_width; j++){ -+ for (j = 0; j < bi->block_width; j++) { - int x_inc, lin_y, lin_x; - x = GET_CHAN(block_ptr[j], xchannel); - y = GET_CHAN(block_ptr[j], ychannel); -@@ -419,7 +419,9 @@ static void update_block_in_prev_frame(c - uint16_t *dest_pixels, - const BlockInfo *bi, int block_counter) - { -- for (int y = 0; y < 4; y++) { -+ const int y_size = FFMIN(4, bi->image_height - bi->row * 4); -+ -+ for (int y = 0; y < y_size; y++) { - memcpy(dest_pixels, src_pixels, 8); - dest_pixels += bi->rowstride; - src_pixels += bi->rowstride; -@@ -729,14 +731,15 @@ post_skip : - - if (err > s->sixteen_color_thresh) { // DO SIXTEEN COLOR BLOCK - uint16_t *row_ptr; -- int rgb555; -+ int y_size, rgb555; - - block_offset = get_block_info(&bi, block_counter); - - row_ptr = &src_pixels[block_offset]; -+ y_size = FFMIN(4, bi.image_height - bi.row * 4); - -- for (int y = 0; y < 4; y++) { -- for (int x = 0; x < 4; x++){ -+ for (int y = 0; y < y_size; y++) { -+ for (int x = 0; x < 4; x++) { - rgb555 = row_ptr[x] & ~0x8000; - - put_bits(&s->pb, 16, rgb555); -@@ -744,6 +747,11 @@ post_skip : - row_ptr += bi.rowstride; - } - -+ for (int y = y_size; y < 4; y++) { -+ for (int x = 0; x < 4; x++) -+ put_bits(&s->pb, 16, 0); -+ } -+ - block_counter++; - } else { // FOUR COLOR BLOCK - block_counter += encode_four_color_block(min_color, max_color, diff --git a/third_party/patch/ffmpeg/CVE-2022-3965.patch b/third_party/patch/ffmpeg/CVE-2022-3965.patch deleted file mode 100644 index 155aad1524b..00000000000 --- a/third_party/patch/ffmpeg/CVE-2022-3965.patch +++ /dev/null @@ -1,91 +0,0 @@ -diff -Npur ffmpeg-5.1.2/libavcodec/smcenc.c ffmpeg-5.1.2-change/libavcodec/smcenc.c ---- ffmpeg-5.1.2/libavcodec/smcenc.c 2022-07-23 01:58:39.000000000 +0800 -+++ ffmpeg-5.1.2-change/libavcodec/smcenc.c 2024-06-25 17:07:00.100640653 +0800 -@@ -61,6 +61,7 @@ typedef struct SMCContext { - { \ - row_ptr += stride * 4; \ - pixel_ptr = row_ptr; \ -+ cur_y += 4; \ - } \ - } \ - } -@@ -117,6 +118,7 @@ static void smc_encode_stream(SMCContext - const uint8_t *prev_pixels = (const uint8_t *)s->prev_frame->data[0]; - uint8_t *distinct_values = s->distinct_values; - const uint8_t *pixel_ptr, *row_ptr; -+ const int height = frame->height; - const int width = frame->width; - uint8_t block_values[16]; - int block_counter = 0; -@@ -125,13 +127,14 @@ static void smc_encode_stream(SMCContext - int color_octet_index = 0; - int color_table_index; /* indexes to color pair, quad, or octet tables */ - int total_blocks; -+ int cur_y = 0; - - memset(s->color_pairs, 0, sizeof(s->color_pairs)); - memset(s->color_quads, 0, sizeof(s->color_quads)); - memset(s->color_octets, 0, sizeof(s->color_octets)); - - /* Number of 4x4 blocks in frame. */ -- total_blocks = ((frame->width + 3) / 4) * ((frame->height + 3) / 4); -+ total_blocks = ((width + 3) / 4) * ((height + 3) / 4); - - pixel_ptr = row_ptr = src_pixels; - -@@ -145,11 +148,13 @@ static void smc_encode_stream(SMCContext - int cache_index; - int distinct = 0; - int blocks = 0; -+ int frame_y = cur_y; - - while (prev_pixels && s->key_frame == 0 && block_counter + inter_skip_blocks < total_blocks) { -+ const int y_size = FFMIN(4, height - cur_y); - int compare = 0; - -- for (int y = 0; y < 4; y++) { -+ for (int y = 0; y < y_size; y++) { - const ptrdiff_t offset = pixel_ptr - src_pixels; - const uint8_t *prev_pixel_ptr = prev_pixels + offset; - -@@ -170,8 +175,10 @@ static void smc_encode_stream(SMCContext - - pixel_ptr = xpixel_ptr; - row_ptr = xrow_ptr; -+ cur_y = frame_y; - - while (block_counter > 0 && block_counter + intra_skip_blocks < total_blocks) { -+ const int y_size = FFMIN(4, height - cur_y); - const ptrdiff_t offset = pixel_ptr - src_pixels; - const int sy = offset / stride; - const int sx = offset % stride; -@@ -180,7 +187,7 @@ static void smc_encode_stream(SMCContext - const uint8_t *old_pixel_ptr = src_pixels + nx + ny * stride; - int compare = 0; - -- for (int y = 0; y < 4; y++) { -+ for (int y = 0; y < y_size; y++) { - compare |= memcmp(old_pixel_ptr + y * stride, pixel_ptr + y * stride, 4); - if (compare) - break; -@@ -197,9 +204,11 @@ static void smc_encode_stream(SMCContext - - pixel_ptr = xpixel_ptr; - row_ptr = xrow_ptr; -+ cur_y = frame_y; - - while (block_counter + coded_blocks < total_blocks && coded_blocks < 256) { -- for (int y = 0; y < 4; y++) -+ const int y_size = FFMIN(4, height - cur_y); -+ for (int y = 0; y < y_size; y++) - memcpy(block_values + y * 4, pixel_ptr + y * stride, 4); - - qsort(block_values, 16, sizeof(block_values[0]), smc_cmp_values); -@@ -224,6 +233,7 @@ static void smc_encode_stream(SMCContext - - pixel_ptr = xpixel_ptr; - row_ptr = xrow_ptr; -+ cur_y = frame_y; - - blocks = coded_blocks; - distinct = coded_distinct; diff --git a/third_party/patch/ffmpeg/CVE-2023-47342.patch b/third_party/patch/ffmpeg/CVE-2023-47342.patch deleted file mode 100644 index 0eace366db0..00000000000 --- a/third_party/patch/ffmpeg/CVE-2023-47342.patch +++ /dev/null @@ -1,12 +0,0 @@ -diff -Npur ffmpeg-5.1.2/libavformat/rtsp.c ffmpeg-5.1.2-change/libavformat/rtsp.c ---- ffmpeg-5.1.2/libavformat/rtsp.c 2022-07-23 01:58:39.000000000 +0800 -+++ ffmpeg-5.1.2-change/libavformat/rtsp.c 2024-06-25 16:37:03.333689422 +0800 -@@ -409,7 +409,7 @@ static void parse_fmtp(AVFormatContext * - if (rtsp_st->sdp_payload_type == payload_type && - rtsp_st->dynamic_handler && - rtsp_st->dynamic_handler->parse_sdp_a_line) { -- rtsp_st->dynamic_handler->parse_sdp_a_line(s, i, -+ rtsp_st->dynamic_handler->parse_sdp_a_line(s, rtsp_st->stream_index, - rtsp_st->dynamic_protocol_context, line); - } - } diff --git a/third_party/patch/jpeg_turbo/CVE-2020-35538.patch b/third_party/patch/jpeg_turbo/CVE-2020-35538.patch deleted file mode 100644 index 2cbd1cd4275..00000000000 --- a/third_party/patch/jpeg_turbo/CVE-2020-35538.patch +++ /dev/null @@ -1,452 +0,0 @@ -From 9120a247436e84c0b4eea828cb11e8f665fcde30 Mon Sep 17 00:00:00 2001 -From: DRC -Date: Thu, 23 Jul 2020 21:24:38 -0500 -Subject: [PATCH] Fix jpeg_skip_scanlines() segfault w/merged upsamp - -The additional segfault mentioned in #244 was due to the fact that -the merged upsamplers use a different private structure than the -non-merged upsamplers. jpeg_skip_scanlines() was assuming the latter, so -when merged upsampling was enabled, jpeg_skip_scanlines() clobbered one -of the IDCT method pointers in the merged upsampler's private structure. - -For reasons unknown, the test image in #441 did not encounter this -segfault (too small?), but it encountered an issue similar to the one -fixed in 5bc43c7821df982f65aa1c738f67fbf7cba8bd69, whereby it was -necessary to set up a dummy postprocessing function in -read_and_discard_scanlines() when merged upsampling was enabled. -Failing to do so caused either a segfault in merged_2v_upsample() (due -to a NULL pointer being passed to jcopy_sample_rows()) or an error -("Corrupt JPEG data: premature end of data segment"), depending on the -number of scanlines skipped and whether the first scanline skipped was -an odd- or even-numbered row. - -Fixes #441 -Fixes #244 (for real this time) ---- - ChangeLog.md | 6 +++++ - jdapistd.c | 72 ++++++++++++++++++++++++++++++++++++++++++++++------ - jdmerge.c | 46 +++++++-------------------------- - jdmerge.h | 47 ++++++++++++++++++++++++++++++++++ - jdmrg565.c | 10 ++++---- - jdmrgext.c | 6 ++--- - 6 files changed, 134 insertions(+), 53 deletions(-) - create mode 100644 jdmerge.h - -diff --git a/ChangeLog.md b/ChangeLog.md -index 5f4ab7b8..2e4a3d88 100644 ---- a/ChangeLog.md -+++ b/ChangeLog.md -@@ -7,6 +7,12 @@ - platforms when using any of the YUV encoding/compression/decompression/decoding - methods in the TurboJPEG Java API. - -+2. Fixed segfaults or "Corrupt JPEG data: premature end of data segment" errors -+in `jpeg_skip_scanlines()` that occurred when decompressing 4:2:2 or 4:2:0 JPEG -+images using the merged (non-fancy) upsampling algorithms (that is, when -+setting `cinfo.do_fancy_upsampling` to `FALSE`.) 2.0.0[6] was a similar fix, -+but it did not cover all cases. -+ - - 2.0.5 - ===== -diff --git a/jdapistd.c b/jdapistd.c -index 2c808fa5..91da642d 100644 ---- a/jdapistd.c -+++ b/jdapistd.c -@@ -4,7 +4,7 @@ - * This file was part of the Independent JPEG Group's software: - * Copyright (C) 1994-1996, Thomas G. Lane. - * libjpeg-turbo Modifications: -- * Copyright (C) 2010, 2015-2018, D. R. Commander. -+ * Copyright (C) 2010, 2015-2018, 2020, D. R. Commander. - * Copyright (C) 2015, Google, Inc. - * For conditions of distribution and use, see the accompanying README.ijg - * file. -@@ -21,6 +21,8 @@ - #include "jinclude.h" - #include "jdmainct.h" - #include "jdcoefct.h" -+#include "jdmaster.h" -+#include "jdmerge.h" - #include "jdsample.h" - #include "jmemsys.h" - -@@ -304,6 +306,16 @@ noop_quantize(j_decompress_ptr cinfo, JSAMPARRAY input_buf, - } - - -+/* Dummy postprocessing function used by jpeg_skip_scanlines() */ -+LOCAL(void) -+noop_post_process (j_decompress_ptr cinfo, JSAMPIMAGE input_buf, -+ JDIMENSION *in_row_group_ctr, -+ JDIMENSION in_row_groups_avail, JSAMPARRAY output_buf, -+ JDIMENSION *out_row_ctr, JDIMENSION out_rows_avail) -+{ -+} -+ -+ - /* - * In some cases, it is best to call jpeg_read_scanlines() and discard the - * output, rather than skipping the scanlines, because this allows us to -@@ -316,11 +328,17 @@ LOCAL(void) - read_and_discard_scanlines(j_decompress_ptr cinfo, JDIMENSION num_lines) - { - JDIMENSION n; -+ my_master_ptr master = (my_master_ptr)cinfo->master; - void (*color_convert) (j_decompress_ptr cinfo, JSAMPIMAGE input_buf, - JDIMENSION input_row, JSAMPARRAY output_buf, - int num_rows) = NULL; - void (*color_quantize) (j_decompress_ptr cinfo, JSAMPARRAY input_buf, - JSAMPARRAY output_buf, int num_rows) = NULL; -+ void (*post_process_data) (j_decompress_ptr cinfo, JSAMPIMAGE input_buf, -+ JDIMENSION *in_row_group_ctr, -+ JDIMENSION in_row_groups_avail, -+ JSAMPARRAY output_buf, JDIMENSION *out_row_ctr, -+ JDIMENSION out_rows_avail) = NULL; - - if (cinfo->cconvert && cinfo->cconvert->color_convert) { - color_convert = cinfo->cconvert->color_convert; -@@ -332,6 +350,12 @@ read_and_discard_scanlines(j_decompress_ptr cinfo, JDIMENSION num_lines) - cinfo->cquantize->color_quantize = noop_quantize; - } - -+ if (master->using_merged_upsample && cinfo->post && -+ cinfo->post->post_process_data) { -+ post_process_data = cinfo->post->post_process_data; -+ cinfo->post->post_process_data = noop_post_process; -+ } -+ - for (n = 0; n < num_lines; n++) - jpeg_read_scanlines(cinfo, NULL, 1); - -@@ -340,6 +364,9 @@ read_and_discard_scanlines(j_decompress_ptr cinfo, JDIMENSION num_lines) - - if (color_quantize) - cinfo->cquantize->color_quantize = color_quantize; -+ -+ if (post_process_data) -+ cinfo->post->post_process_data = post_process_data; - } - - -@@ -382,7 +409,7 @@ jpeg_skip_scanlines(j_decompress_ptr cinfo, JDIMENSION num_lines) - { - my_main_ptr main_ptr = (my_main_ptr)cinfo->main; - my_coef_ptr coef = (my_coef_ptr)cinfo->coef; -- my_upsample_ptr upsample = (my_upsample_ptr)cinfo->upsample; -+ my_master_ptr master = (my_master_ptr)cinfo->master; - JDIMENSION i, x; - int y; - JDIMENSION lines_per_iMCU_row, lines_left_in_iMCU_row, lines_after_iMCU_row; -@@ -445,8 +472,16 @@ jpeg_skip_scanlines(j_decompress_ptr cinfo, JDIMENSION num_lines) - main_ptr->buffer_full = FALSE; - main_ptr->rowgroup_ctr = 0; - main_ptr->context_state = CTX_PREPARE_FOR_IMCU; -- upsample->next_row_out = cinfo->max_v_samp_factor; -- upsample->rows_to_go = cinfo->output_height - cinfo->output_scanline; -+ if (master->using_merged_upsample) { -+ my_merged_upsample_ptr upsample = -+ (my_merged_upsample_ptr)cinfo->upsample; -+ upsample->spare_full = FALSE; -+ upsample->rows_to_go = cinfo->output_height - cinfo->output_scanline; -+ } else { -+ my_upsample_ptr upsample = (my_upsample_ptr)cinfo->upsample; -+ upsample->next_row_out = cinfo->max_v_samp_factor; -+ upsample->rows_to_go = cinfo->output_height - cinfo->output_scanline; -+ } - } - - /* Skipping is much simpler when context rows are not required. */ -@@ -458,8 +493,16 @@ jpeg_skip_scanlines(j_decompress_ptr cinfo, JDIMENSION num_lines) - cinfo->output_scanline += lines_left_in_iMCU_row; - main_ptr->buffer_full = FALSE; - main_ptr->rowgroup_ctr = 0; -- upsample->next_row_out = cinfo->max_v_samp_factor; -- upsample->rows_to_go = cinfo->output_height - cinfo->output_scanline; -+ if (master->using_merged_upsample) { -+ my_merged_upsample_ptr upsample = -+ (my_merged_upsample_ptr)cinfo->upsample; -+ upsample->spare_full = FALSE; -+ upsample->rows_to_go = cinfo->output_height - cinfo->output_scanline; -+ } else { -+ my_upsample_ptr upsample = (my_upsample_ptr)cinfo->upsample; -+ upsample->next_row_out = cinfo->max_v_samp_factor; -+ upsample->rows_to_go = cinfo->output_height - cinfo->output_scanline; -+ } - } - } - -@@ -494,7 +537,14 @@ jpeg_skip_scanlines(j_decompress_ptr cinfo, JDIMENSION num_lines) - cinfo->output_iMCU_row += lines_to_skip / lines_per_iMCU_row; - increment_simple_rowgroup_ctr(cinfo, lines_to_read); - } -- upsample->rows_to_go = cinfo->output_height - cinfo->output_scanline; -+ if (master->using_merged_upsample) { -+ my_merged_upsample_ptr upsample = -+ (my_merged_upsample_ptr)cinfo->upsample; -+ upsample->rows_to_go = cinfo->output_height - cinfo->output_scanline; -+ } else { -+ my_upsample_ptr upsample = (my_upsample_ptr)cinfo->upsample; -+ upsample->rows_to_go = cinfo->output_height - cinfo->output_scanline; -+ } - return num_lines; - } - -@@ -535,7 +585,13 @@ jpeg_skip_scanlines(j_decompress_ptr cinfo, JDIMENSION num_lines) - * bit odd, since "rows_to_go" seems to be redundantly keeping track of - * output_scanline. - */ -- upsample->rows_to_go = cinfo->output_height - cinfo->output_scanline; -+ if (master->using_merged_upsample) { -+ my_merged_upsample_ptr upsample = (my_merged_upsample_ptr)cinfo->upsample; -+ upsample->rows_to_go = cinfo->output_height - cinfo->output_scanline; -+ } else { -+ my_upsample_ptr upsample = (my_upsample_ptr)cinfo->upsample; -+ upsample->rows_to_go = cinfo->output_height - cinfo->output_scanline; -+ } - - /* Always skip the requested number of lines. */ - return num_lines; -diff --git a/jdmerge.c b/jdmerge.c -index dff5a350..833ad675 100644 ---- a/jdmerge.c -+++ b/jdmerge.c -@@ -5,7 +5,7 @@ - * Copyright (C) 1994-1996, Thomas G. Lane. - * libjpeg-turbo Modifications: - * Copyright 2009 Pierre Ossman for Cendio AB -- * Copyright (C) 2009, 2011, 2014-2015, D. R. Commander. -+ * Copyright (C) 2009, 2011, 2014-2015, 2020, D. R. Commander. - * Copyright (C) 2013, Linaro Limited. - * For conditions of distribution and use, see the accompanying README.ijg - * file. -@@ -40,41 +40,13 @@ - #define JPEG_INTERNALS - #include "jinclude.h" - #include "jpeglib.h" -+#include "jdmerge.h" - #include "jsimd.h" - #include "jconfigint.h" - - #ifdef UPSAMPLE_MERGING_SUPPORTED - - --/* Private subobject */ -- --typedef struct { -- struct jpeg_upsampler pub; /* public fields */ -- -- /* Pointer to routine to do actual upsampling/conversion of one row group */ -- void (*upmethod) (j_decompress_ptr cinfo, JSAMPIMAGE input_buf, -- JDIMENSION in_row_group_ctr, JSAMPARRAY output_buf); -- -- /* Private state for YCC->RGB conversion */ -- int *Cr_r_tab; /* => table for Cr to R conversion */ -- int *Cb_b_tab; /* => table for Cb to B conversion */ -- JLONG *Cr_g_tab; /* => table for Cr to G conversion */ -- JLONG *Cb_g_tab; /* => table for Cb to G conversion */ -- -- /* For 2:1 vertical sampling, we produce two output rows at a time. -- * We need a "spare" row buffer to hold the second output row if the -- * application provides just a one-row buffer; we also use the spare -- * to discard the dummy last row if the image height is odd. -- */ -- JSAMPROW spare_row; -- boolean spare_full; /* T if spare buffer is occupied */ -- -- JDIMENSION out_row_width; /* samples per output row */ -- JDIMENSION rows_to_go; /* counts rows remaining in image */ --} my_upsampler; -- --typedef my_upsampler *my_upsample_ptr; -- - #define SCALEBITS 16 /* speediest right-shift on some machines */ - #define ONE_HALF ((JLONG)1 << (SCALEBITS - 1)) - #define FIX(x) ((JLONG)((x) * (1L << SCALEBITS) + 0.5)) -@@ -189,7 +161,7 @@ typedef my_upsampler *my_upsample_ptr; - LOCAL(void) - build_ycc_rgb_table(j_decompress_ptr cinfo) - { -- my_upsample_ptr upsample = (my_upsample_ptr)cinfo->upsample; -+ my_merged_upsample_ptr upsample = (my_merged_upsample_ptr)cinfo->upsample; - int i; - JLONG x; - SHIFT_TEMPS -@@ -232,7 +204,7 @@ build_ycc_rgb_table(j_decompress_ptr cinfo) - METHODDEF(void) - start_pass_merged_upsample(j_decompress_ptr cinfo) - { -- my_upsample_ptr upsample = (my_upsample_ptr)cinfo->upsample; -+ my_merged_upsample_ptr upsample = (my_merged_upsample_ptr)cinfo->upsample; - - /* Mark the spare buffer empty */ - upsample->spare_full = FALSE; -@@ -254,7 +226,7 @@ merged_2v_upsample(j_decompress_ptr cinfo, JSAMPIMAGE input_buf, - JDIMENSION *out_row_ctr, JDIMENSION out_rows_avail) - /* 2:1 vertical sampling case: may need a spare row. */ - { -- my_upsample_ptr upsample = (my_upsample_ptr)cinfo->upsample; -+ my_merged_upsample_ptr upsample = (my_merged_upsample_ptr)cinfo->upsample; - JSAMPROW work_ptrs[2]; - JDIMENSION num_rows; /* number of rows returned to caller */ - -@@ -305,7 +277,7 @@ merged_1v_upsample(j_decompress_ptr cinfo, JSAMPIMAGE input_buf, - JDIMENSION *out_row_ctr, JDIMENSION out_rows_avail) - /* 1:1 vertical sampling case: much easier, never need a spare row. */ - { -- my_upsample_ptr upsample = (my_upsample_ptr)cinfo->upsample; -+ my_merged_upsample_ptr upsample = (my_merged_upsample_ptr)cinfo->upsample; - - /* Just do the upsampling. */ - (*upsample->upmethod) (cinfo, input_buf, *in_row_group_ctr, -@@ -566,11 +538,11 @@ h2v2_merged_upsample_565D(j_decompress_ptr cinfo, JSAMPIMAGE input_buf, - GLOBAL(void) - jinit_merged_upsampler(j_decompress_ptr cinfo) - { -- my_upsample_ptr upsample; -+ my_merged_upsample_ptr upsample; - -- upsample = (my_upsample_ptr) -+ upsample = (my_merged_upsample_ptr) - (*cinfo->mem->alloc_small) ((j_common_ptr)cinfo, JPOOL_IMAGE, -- sizeof(my_upsampler)); -+ sizeof(my_merged_upsampler)); - cinfo->upsample = (struct jpeg_upsampler *)upsample; - upsample->pub.start_pass = start_pass_merged_upsample; - upsample->pub.need_context_rows = FALSE; -diff --git a/jdmerge.h b/jdmerge.h -new file mode 100644 -index 00000000..b583396b ---- /dev/null -+++ b/jdmerge.h -@@ -0,0 +1,47 @@ -+/* -+ * jdmerge.h -+ * -+ * This file was part of the Independent JPEG Group's software: -+ * Copyright (C) 1994-1996, Thomas G. Lane. -+ * libjpeg-turbo Modifications: -+ * Copyright (C) 2020, D. R. Commander. -+ * For conditions of distribution and use, see the accompanying README.ijg -+ * file. -+ */ -+ -+#define JPEG_INTERNALS -+#include "jpeglib.h" -+ -+#ifdef UPSAMPLE_MERGING_SUPPORTED -+ -+ -+/* Private subobject */ -+ -+typedef struct { -+ struct jpeg_upsampler pub; /* public fields */ -+ -+ /* Pointer to routine to do actual upsampling/conversion of one row group */ -+ void (*upmethod) (j_decompress_ptr cinfo, JSAMPIMAGE input_buf, -+ JDIMENSION in_row_group_ctr, JSAMPARRAY output_buf); -+ -+ /* Private state for YCC->RGB conversion */ -+ int *Cr_r_tab; /* => table for Cr to R conversion */ -+ int *Cb_b_tab; /* => table for Cb to B conversion */ -+ JLONG *Cr_g_tab; /* => table for Cr to G conversion */ -+ JLONG *Cb_g_tab; /* => table for Cb to G conversion */ -+ -+ /* For 2:1 vertical sampling, we produce two output rows at a time. -+ * We need a "spare" row buffer to hold the second output row if the -+ * application provides just a one-row buffer; we also use the spare -+ * to discard the dummy last row if the image height is odd. -+ */ -+ JSAMPROW spare_row; -+ boolean spare_full; /* T if spare buffer is occupied */ -+ -+ JDIMENSION out_row_width; /* samples per output row */ -+ JDIMENSION rows_to_go; /* counts rows remaining in image */ -+} my_merged_upsampler; -+ -+typedef my_merged_upsampler *my_merged_upsample_ptr; -+ -+#endif /* UPSAMPLE_MERGING_SUPPORTED */ -diff --git a/jdmrg565.c b/jdmrg565.c -index 1b87e371..53f1e167 100644 ---- a/jdmrg565.c -+++ b/jdmrg565.c -@@ -5,7 +5,7 @@ - * Copyright (C) 1994-1996, Thomas G. Lane. - * libjpeg-turbo Modifications: - * Copyright (C) 2013, Linaro Limited. -- * Copyright (C) 2014-2015, 2018, D. R. Commander. -+ * Copyright (C) 2014-2015, 2018, 2020, D. R. Commander. - * For conditions of distribution and use, see the accompanying README.ijg - * file. - * -@@ -19,7 +19,7 @@ h2v1_merged_upsample_565_internal(j_decompress_ptr cinfo, JSAMPIMAGE input_buf, - JDIMENSION in_row_group_ctr, - JSAMPARRAY output_buf) - { -- my_upsample_ptr upsample = (my_upsample_ptr)cinfo->upsample; -+ my_merged_upsample_ptr upsample = (my_merged_upsample_ptr)cinfo->upsample; - register int y, cred, cgreen, cblue; - int cb, cr; - register JSAMPROW outptr; -@@ -90,7 +90,7 @@ h2v1_merged_upsample_565D_internal(j_decompress_ptr cinfo, - JDIMENSION in_row_group_ctr, - JSAMPARRAY output_buf) - { -- my_upsample_ptr upsample = (my_upsample_ptr)cinfo->upsample; -+ my_merged_upsample_ptr upsample = (my_merged_upsample_ptr)cinfo->upsample; - register int y, cred, cgreen, cblue; - int cb, cr; - register JSAMPROW outptr; -@@ -163,7 +163,7 @@ h2v2_merged_upsample_565_internal(j_decompress_ptr cinfo, JSAMPIMAGE input_buf, - JDIMENSION in_row_group_ctr, - JSAMPARRAY output_buf) - { -- my_upsample_ptr upsample = (my_upsample_ptr)cinfo->upsample; -+ my_merged_upsample_ptr upsample = (my_merged_upsample_ptr)cinfo->upsample; - register int y, cred, cgreen, cblue; - int cb, cr; - register JSAMPROW outptr0, outptr1; -@@ -259,7 +259,7 @@ h2v2_merged_upsample_565D_internal(j_decompress_ptr cinfo, - JDIMENSION in_row_group_ctr, - JSAMPARRAY output_buf) - { -- my_upsample_ptr upsample = (my_upsample_ptr)cinfo->upsample; -+ my_merged_upsample_ptr upsample = (my_merged_upsample_ptr)cinfo->upsample; - register int y, cred, cgreen, cblue; - int cb, cr; - register JSAMPROW outptr0, outptr1; -diff --git a/jdmrgext.c b/jdmrgext.c -index b1c27df5..c9a44d82 100644 ---- a/jdmrgext.c -+++ b/jdmrgext.c -@@ -4,7 +4,7 @@ - * This file was part of the Independent JPEG Group's software: - * Copyright (C) 1994-1996, Thomas G. Lane. - * libjpeg-turbo Modifications: -- * Copyright (C) 2011, 2015, D. R. Commander. -+ * Copyright (C) 2011, 2015, 2020, D. R. Commander. - * For conditions of distribution and use, see the accompanying README.ijg - * file. - * -@@ -25,7 +25,7 @@ h2v1_merged_upsample_internal(j_decompress_ptr cinfo, JSAMPIMAGE input_buf, - JDIMENSION in_row_group_ctr, - JSAMPARRAY output_buf) - { -- my_upsample_ptr upsample = (my_upsample_ptr)cinfo->upsample; -+ my_merged_upsample_ptr upsample = (my_merged_upsample_ptr)cinfo->upsample; - register int y, cred, cgreen, cblue; - int cb, cr; - register JSAMPROW outptr; -@@ -97,7 +97,7 @@ h2v2_merged_upsample_internal(j_decompress_ptr cinfo, JSAMPIMAGE input_buf, - JDIMENSION in_row_group_ctr, - JSAMPARRAY output_buf) - { -- my_upsample_ptr upsample = (my_upsample_ptr)cinfo->upsample; -+ my_merged_upsample_ptr upsample = (my_merged_upsample_ptr)cinfo->upsample; - register int y, cred, cgreen, cblue; - int cb, cr; - register JSAMPROW outptr0, outptr1; --- -2.17.1 - diff --git a/third_party/patch/jpeg_turbo/CVE-2021-46822.patch b/third_party/patch/jpeg_turbo/CVE-2021-46822.patch deleted file mode 100644 index 5fd5ef45837..00000000000 --- a/third_party/patch/jpeg_turbo/CVE-2021-46822.patch +++ /dev/null @@ -1,63 +0,0 @@ -diff -Npur libjpeg-turbo-2.0.4/rdppm.c libjpeg-turbo-2.0.4-change/rdppm.c ---- libjpeg-turbo-2.0.4/rdppm.c 2019-12-31 15:10:30.000000000 +0800 -+++ libjpeg-turbo-2.0.4-change/rdppm.c 2022-07-28 05:33:19.254229939 +0800 -@@ -526,6 +526,11 @@ get_word_rgb_row(j_compress_ptr cinfo, c - register JSAMPLE *rescale = source->rescale; - JDIMENSION col; - unsigned int maxval = source->maxval; -+ register int rindex = rgb_red[cinfo->in_color_space]; -+ register int gindex = rgb_green[cinfo->in_color_space]; -+ register int bindex = rgb_blue[cinfo->in_color_space]; -+ register int aindex = alpha_index[cinfo->in_color_space]; -+ register int ps = rgb_pixelsize[cinfo->in_color_space]; - - if (!ReadOK(source->pub.input_file, source->iobuffer, source->buffer_width)) - ERREXIT(cinfo, JERR_INPUT_EOF); -@@ -537,17 +542,20 @@ get_word_rgb_row(j_compress_ptr cinfo, c - temp |= UCH(*bufferptr++); - if (temp > maxval) - ERREXIT(cinfo, JERR_PPM_OUTOFRANGE); -- *ptr++ = rescale[temp]; -+ ptr[rindex] = rescale[temp]; - temp = UCH(*bufferptr++) << 8; - temp |= UCH(*bufferptr++); - if (temp > maxval) - ERREXIT(cinfo, JERR_PPM_OUTOFRANGE); -- *ptr++ = rescale[temp]; -+ ptr[gindex] = rescale[temp]; - temp = UCH(*bufferptr++) << 8; - temp |= UCH(*bufferptr++); - if (temp > maxval) - ERREXIT(cinfo, JERR_PPM_OUTOFRANGE); -- *ptr++ = rescale[temp]; -+ ptr[bindex] = rescale[temp]; -+ if (aindex >= 0) -+ ptr[aindex] = 0xFF; -+ ptr += ps; - } - return 1; - } -@@ -634,7 +642,10 @@ start_input_ppm(j_compress_ptr cinfo, cj - cinfo->in_color_space = JCS_GRAYSCALE; - TRACEMS2(cinfo, 1, JTRC_PGM, w, h); - if (maxval > 255) { -- source->pub.get_pixel_rows = get_word_gray_row; -+ if (cinfo->in_color_space == JCS_GRAYSCALE) -+ source->pub.get_pixel_rows = get_word_gray_row; -+ else -+ ERREXIT(cinfo, JERR_BAD_IN_COLORSPACE); - } else if (maxval == MAXJSAMPLE && sizeof(JSAMPLE) == sizeof(U_CHAR) && - cinfo->in_color_space == JCS_GRAYSCALE) { - source->pub.get_pixel_rows = get_raw_row; -@@ -657,7 +668,10 @@ start_input_ppm(j_compress_ptr cinfo, cj - cinfo->in_color_space = JCS_EXT_RGB; - TRACEMS2(cinfo, 1, JTRC_PPM, w, h); - if (maxval > 255) { -- source->pub.get_pixel_rows = get_word_rgb_row; -+ if (IsExtRGB(cinfo->in_color_space)) -+ source->pub.get_pixel_rows = get_word_rgb_row; -+ else -+ ERREXIT(cinfo, JERR_BAD_IN_COLORSPACE); - } else if (maxval == MAXJSAMPLE && sizeof(JSAMPLE) == sizeof(U_CHAR) && - (cinfo->in_color_space == JCS_EXT_RGB - #if RGB_RED == 0 && RGB_GREEN == 1 && RGB_BLUE == 2 && RGB_PIXELSIZE == 3 diff --git a/third_party/patch/mockcpp/mockcpp_support_arm64.patch b/third_party/patch/mockcpp/mockcpp_support_arm64.patch deleted file mode 100644 index fc62c02ec0a..00000000000 --- a/third_party/patch/mockcpp/mockcpp_support_arm64.patch +++ /dev/null @@ -1,299 +0,0 @@ -From c050505bbee806f7389d1c7360eef08d6a39aad6 Mon Sep 17 00:00:00 2001 -From: Zhu Guodong -Date: Fri, 10 May 2024 18:33:50 +0800 -Subject: [PATCH] mockcpp patch - ---- - include/mockcpp/JmpCode.h | 1 + - include/mockcpp/mockcpp.h | 6 +++- - src/JmpCode.cpp | 16 ++++++++- - src/JmpCodeAARCH64.h | 69 +++++++++++++++++++++++++++++++++++++++ - src/JmpCodeARM32.h | 36 ++++++++++++++++++++ - src/JmpCodeArch.h | 20 +++++++++++- - src/JmpCodeX64.h | 3 +- - src/JmpCodeX86.h | 3 +- - src/JmpOnlyApiHook.cpp | 1 + - src/UnixCodeModifier.cpp | 2 ++ - 10 files changed, 152 insertions(+), 5 deletions(-) - create mode 100644 src/JmpCodeAARCH64.h - create mode 100644 src/JmpCodeARM32.h - -diff --git a/include/mockcpp/JmpCode.h b/include/mockcpp/JmpCode.h -index 26f77b0..ed0fac6 100644 ---- a/include/mockcpp/JmpCode.h -+++ b/include/mockcpp/JmpCode.h -@@ -33,6 +33,7 @@ struct JmpCode - - void* getCodeData() const; - size_t getCodeSize() const; -+ void flushCache() const; - private: - JmpCodeImpl* This; - }; -diff --git a/include/mockcpp/mockcpp.h b/include/mockcpp/mockcpp.h -index 306bc7a..8cc385d 100644 ---- a/include/mockcpp/mockcpp.h -+++ b/include/mockcpp/mockcpp.h -@@ -39,7 +39,11 @@ - #endif - - --#if ( defined (__LP64__) \ -+#if defined (__aarch64__) -+#define BUILD_FOR_AARCH64 -+#elif defined (__arm__) -+#define BUILD_FOR_ARM32 -+#elif ( defined (__LP64__) \ - || defined (__64BIT__) \ - || defined (_LP64) \ - || ((defined(__WORDSIZE)) && (__WORDSIZE == 64)) \ -diff --git a/src/JmpCode.cpp b/src/JmpCode.cpp -index 35794fb..6aac228 100644 ---- a/src/JmpCode.cpp -+++ b/src/JmpCode.cpp -@@ -30,6 +30,7 @@ struct JmpCodeImpl - //////////////////////////////////////////////// - JmpCodeImpl(const void* from, const void* to) - { -+ m_from = from; - ::memcpy(m_code, jmpCodeTemplate, JMP_CODE_SIZE); - SET_JMP_CODE(m_code, from, to); - } -@@ -47,7 +48,14 @@ struct JmpCodeImpl - } - - //////////////////////////////////////////////// -+ void flushCache() const -+ { -+ FLUSH_CACHE((const char *)m_from, JMP_CODE_SIZE); -+ } - -+ //////////////////////////////////////////////// -+ -+ const void *m_from; - unsigned char m_code[JMP_CODE_SIZE]; - }; - -@@ -77,5 +85,11 @@ JmpCode::getCodeSize() const - return This->getCodeSize(); - } - --MOCKCPP_NS_END -+/////////////////////////////////////////////////// -+void -+JmpCode::flushCache() const -+{ -+ return This->flushCache(); -+} - -+MOCKCPP_NS_END -diff --git a/src/JmpCodeAARCH64.h b/src/JmpCodeAARCH64.h -new file mode 100644 -index 0000000..4f5e90b ---- /dev/null -+++ b/src/JmpCodeAARCH64.h -@@ -0,0 +1,69 @@ -+/*** -+ mockcpp is a C/C++ mock framework. -+ Copyright [2008] [Darwin Yuan ] -+ [Chen Guodong ] -+ Licensed under the Apache License, Version 2.0 (the "License"); -+ you may not use this file except in compliance with the License. -+ You may obtain a copy of the License at -+ -+ http://www.apache.org/licenses/LICENSE-2.0 -+ -+ Unless required by applicable law or agreed to in writing, software -+ distributed under the License is distributed on an "AS IS" BASIS, -+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+ See the License for the specific language governing permissions and -+ limitations under the License. -+***/ -+#ifndef __MOCKCPP_JMP_CODE_AARCH64_H__ -+#define __MOCKCPP_JMP_CODE_AARCH64_H__ -+ -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+ -+MOCKCPP_NS_START -+ -+struct l2cache_addr_range { -+ uintptr_t start; -+ uintptr_t end; -+}; -+ -+MOCKCPP_NS_END -+ -+#define ADDR_ALIGN_UP(addr) ((((addr) + ((4096) - 1)) & (~((4096) - 1))) & 0xffffffffffffffff) -+#define ADDR_ALIGN_DOWN(addr) (((addr) & (~((4096) - 1))) & 0xffffffffffffffff) -+#define OUTER_CACHE_INV_RANGE _IOWR('S', 0x00, struct l2cache_addr_range) -+#define OUTER_CACHE_CLEAN_RANGE _IOWR('S', 0x01, struct l2cache_addr_range) -+#define OUTER_CACHE_FLUSH_RANGE _IOWR('S', 0x02, struct l2cache_addr_range) -+#define L1_INV_I_CACHE _IOWR('S', 0x03, struct l2cache_addr_range) -+#define D_TO_I_CACHE_FLUSH_RANGE _IOWR('S', 0x04, struct l2cache_addr_range) -+ -+const unsigned char jmpCodeTemplate[] = -+ { 0x57, 0x00, 0x00, 0x58, 0xe0, 0x02, 0x1f, 0xd6, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }; -+ -+#define SET_SJMP_CODE(base, from, to) do { \ -+ using instruct_t = signed int; \ -+ instruct_t offset = (intptr_t)to - (intptr_t)from; \ -+ offset = ((offset >> 2) & 0x03FFFFFF) | 0x14000000; \ -+ *(instruct_t *)(base) = offset; \ -+ } while(0) -+ -+#define SET_JMP_CODE(base, from, to) do { \ -+ *(void **)(base + 8) = (void *)to; \ -+ } while(0) -+ -+#define FLUSH_CACHE(from, length) do { \ -+ struct l2cache_addr_range usr_data; \ -+ usr_data.start = ADDR_ALIGN_DOWN((unsigned long long)from); \ -+ usr_data.end = ADDR_ALIGN_UP((unsigned long long)from) + length; \ -+ __builtin___clear_cache((char *)usr_data.start, (char *)usr_data.end); \ -+} while (0) -+ -+#endif -diff --git a/src/JmpCodeARM32.h b/src/JmpCodeARM32.h -new file mode 100644 -index 0000000..1eec42b ---- /dev/null -+++ b/src/JmpCodeARM32.h -@@ -0,0 +1,36 @@ -+/*** -+ mockcpp is a C/C++ mock framework. -+ Copyright [2008] [Darwin Yuan ] -+ [Chen Guodong ] -+ Licensed under the Apache License, Version 2.0 (the "License"); -+ you may not use this file except in compliance with the License. -+ You may obtain a copy of the License at -+ -+ http://www.apache.org/licenses/LICENSE-2.0 -+ -+ Unless required by applicable law or agreed to in writing, software -+ distributed under the License is distributed on an "AS IS" BASIS, -+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+ See the License for the specific language governing permissions and -+ limitations under the License. -+***/ -+#ifndef __MOCKCPP_JMP_CODE_ARM32_H__ -+#define __MOCKCPP_JMP_CODE_ARM32_H__ -+ -+#include -+ -+const unsigned char jmpCodeTemplate[] = -+ { 0xEA, 0x00, 0x00, 0x00 }; -+ -+#define SET_JMP_CODE(base, from, to) do { \ -+ int offset = (int)to - (int)from - 8; \ -+ offset = (offset >> 2) & 0x00FFFFFF; \ -+ int code = *(int *)(base) | offset; \ -+ *(int *)(base) = changeByteOrder(code); \ -+ } while(0) -+ -+#define FLUSH_CACHE(from, length) do { \ -+ ::system("echo 3 > /proc/sys/vm/drop_caches"); \ -+} while (0) -+ -+#endif -diff --git a/src/JmpCodeArch.h b/src/JmpCodeArch.h -index 26abd73..53353eb 100644 ---- a/src/JmpCodeArch.h -+++ b/src/JmpCodeArch.h -@@ -19,11 +19,29 @@ - - #include - -+template -+inline T changeByteOrder(const T v) { -+ enum { S = sizeof(T) }; -+ T rst = v; -+ char *p = (char *)&rst; -+ char tmp = 0; -+ for (unsigned int i = 0; i < S / 2; ++i) { -+ tmp = p[i]; -+ p[i] = p[S - i - 1]; -+ p [S - i - 1] = tmp; -+ } -+ -+ return rst; -+} -+ - #if BUILD_FOR_X64 - # include "JmpCodeX64.h" - #elif BUILD_FOR_X86 - # include "JmpCodeX86.h" -+#elif defined(BUILD_FOR_ARM32) -+# include "JmpCodeARM32.h" -+#elif defined(BUILD_FOR_AARCH64) -+# include "JmpCodeAARCH64.h" - #endif - - #endif -- -diff --git a/src/JmpCodeX64.h b/src/JmpCodeX64.h -index 198507a..e5b4f31 100644 ---- a/src/JmpCodeX64.h -+++ b/src/JmpCodeX64.h -@@ -27,5 +27,6 @@ const unsigned char jmpCodeTemplate[] = - *(uintptr_t *)(base + 6) = (uintptr_t)to; \ - } while(0) - --#endif -+#define FLUSH_CACHE(from, length) ((void)0) - -+#endif -diff --git a/src/JmpCodeX86.h b/src/JmpCodeX86.h -index ebdc526..a06a02e 100644 ---- a/src/JmpCodeX86.h -+++ b/src/JmpCodeX86.h -@@ -23,5 +23,6 @@ const unsigned char jmpCodeTemplate[] = { 0xE9, 0x00, 0x00, 0x00, 0x00 }; - (unsigned long long)to - (unsigned long long)from - sizeof(jmpCodeTemplate); \ - } while(0) - --#endif -+#define FLUSH_CACHE(from, length) ((void)0) - -+#endif -diff --git a/src/JmpOnlyApiHook.cpp b/src/JmpOnlyApiHook.cpp -index d4cfa68..964828f 100644 ---- a/src/JmpOnlyApiHook.cpp -+++ b/src/JmpOnlyApiHook.cpp -@@ -68,6 +68,7 @@ struct JmpOnlyApiHookImpl - void changeCode(const void* data) - { - CodeModifier::modify(const_cast(m_api), data, m_jmpCode.getCodeSize()); -+ m_jmpCode.flushCache(); - } - - ///////////////////////////////////////////////////// -diff --git a/src/UnixCodeModifier.cpp b/src/UnixCodeModifier.cpp -index ab4014e..8e7dde9 100644 ---- a/src/UnixCodeModifier.cpp -+++ b/src/UnixCodeModifier.cpp -@@ -20,6 +20,7 @@ - #include - - #include -+#include "JmpCodeArch.h" - - #define PAGE_ALIGN_BITS 12 - -@@ -39,6 +40,7 @@ bool CodeModifier::modify(void *dest, const void *src, size_t size) - - ::memcpy(dest, src, size); - -+ FLUSH_CACHE(dest, size); - - #if 0 - #if BUILD_FOR_X86 --- -2.34.1 diff --git a/third_party/patch/onednn/0001-fix-user-threadpool-bug.patch b/third_party/patch/onednn/0001-fix-user-threadpool-bug.patch deleted file mode 100644 index f5256517751..00000000000 --- a/third_party/patch/onednn/0001-fix-user-threadpool-bug.patch +++ /dev/null @@ -1,20 +0,0 @@ -diff --git a/src/common/dnnl_thread.hpp b/src/common/dnnl_thread.hpp -index 342bc3b00..0b9190f9c 100644 ---- a/src/common/dnnl_thread.hpp -+++ b/src/common/dnnl_thread.hpp -@@ -104,10 +104,11 @@ inline int dnnl_get_max_threads() { - def_max_threads - = (int)dnnl::impl::cpu::platform::get_max_threads_to_use(); - assert(def_max_threads > 0); -- // Use the default value if the threadpool-provided is outside the range -- // [1, def_max_threads] -- return tp ? std::min(std::max(1, tp->get_num_threads()), def_max_threads) -- : def_max_threads; -+ -+ // Make user responsible for number of threads provided at execution time. -+ // This relates to the fact that the library may identify `def_max_threads` -+ // incorrectly for a platform. -+ return tp ? std::max(1, tp->get_num_threads()) : def_max_threads; - } - inline int dnnl_in_parallel() { - using namespace dnnl::impl::threadpool_utils; diff --git a/third_party/patch/onednn/0002-fix-pool-nthr-bug.patch b/third_party/patch/onednn/0002-fix-pool-nthr-bug.patch deleted file mode 100644 index d0ecb2f0cfe..00000000000 --- a/third_party/patch/onednn/0002-fix-pool-nthr-bug.patch +++ /dev/null @@ -1,334 +0,0 @@ -diff --git a/src/cpu/nchw_pooling.cpp b/src/cpu/nchw_pooling.cpp -index b678200a1..09736ccae 100644 ---- a/src/cpu/nchw_pooling.cpp -+++ b/src/cpu/nchw_pooling.cpp -@@ -609,10 +609,12 @@ status_t nchw_pooling_bwd_t::execute_backward( - int od_end = min(OD, 1 + (padF + ID - 1) / SD); - - dim_t c_blk = pd()->channel_block_size_; -- int c_blk_tail = C % c_blk; -+ dim_t c_blk_tail = C % c_blk; -+ const int nthr = pd()->nthr_; -+ - if (alg == alg_kind::pooling_max) { -- parallel_nd_ext(0, MB, utils::div_up(C, c_blk), -- [&](int ithr, int, int mb, int cb) { -+ parallel_nd_ext(nthr, MB, utils::div_up(C, c_blk), -+ [&](int ithr, int, dim_t mb, dim_t cb) { - bool is_last_c_block - = c_blk_tail > 0 && (cb + 1) * c_blk > C; - int curr_c_block = is_last_c_block ? c_blk_tail : c_blk; -@@ -649,8 +651,8 @@ status_t nchw_pooling_bwd_t::execute_backward( - diff_src_fp32, src_sp_size * curr_c_block); - }); - } else { -- parallel_nd_ext(0, MB, utils::div_up(C, c_blk), -- [&](int ithr, int, int mb, int cb) { -+ parallel_nd_ext(nthr, MB, utils::div_up(C, c_blk), -+ [&](int ithr, int, dim_t mb, dim_t cb) { - bool is_last_c_block - = c_blk_tail > 0 && (cb + 1) * c_blk > C; - int curr_c_block = is_last_c_block ? c_blk_tail : c_blk; -diff --git a/src/cpu/nchw_pooling.hpp b/src/cpu/nchw_pooling.hpp -index 9d649f3f5..2a73f6ae6 100644 ---- a/src/cpu/nchw_pooling.hpp -+++ b/src/cpu/nchw_pooling.hpp -@@ -139,6 +139,7 @@ struct nchw_pooling_bwd_t : public primitive_t { - ws_md_ = *hint_fwd_pd_->workspace_md(); - } - -+ nthr_ = dnnl_get_max_threads(); - calculate_channel_block_size(); - init_scratchpad(); - -@@ -146,6 +147,7 @@ struct nchw_pooling_bwd_t : public primitive_t { - } - - dim_t channel_block_size_; -+ int nthr_; // To not exceed the limit in execute used for set up. - - private: - void init_scratchpad() { -@@ -153,13 +155,12 @@ struct nchw_pooling_bwd_t : public primitive_t { - if (diff_dst_md()->data_type == data_type::bf16) { - size_t dst_sz_ = OD() * OH() * OW(); - size_t src_sz_ = ID() * IH() * IW(); -- size_t nthrs = dnnl_get_max_threads(); - auto scratchpad = scratchpad_registry().registrar(); - - scratchpad.template book(key_pool_src_bf16cvt, -- src_sz_ * nthrs * channel_block_size_); -+ src_sz_ * nthr_ * channel_block_size_); - scratchpad.template book(key_pool_dst_bf16cvt, -- dst_sz_ * nthrs * channel_block_size_); -+ dst_sz_ * nthr_ * channel_block_size_); - } - } - -@@ -169,8 +170,7 @@ struct nchw_pooling_bwd_t : public primitive_t { - // spatial - dim_t dst_sz_ = OD() * OH() * OW(); - dim_t src_sz_ = ID() * IH() * IW(); -- dim_t nthrs = dnnl_get_max_threads(); -- dim_t C_per_thr = nstl::min(MB() * C() / nthrs, C()); -+ dim_t C_per_thr = nstl::min(MB() * C() / nthr_, C()); - const dim_t max_block_size - = platform::get_per_core_cache_size(1) / 2; - dim_t data_size_per_ch = (dst_sz_ + src_sz_) * 6; // f32 + bf16 -diff --git a/src/cpu/nhwc_pooling.cpp b/src/cpu/nhwc_pooling.cpp -index 48d9e1240..efe3083f7 100644 ---- a/src/cpu/nhwc_pooling.cpp -+++ b/src/cpu/nhwc_pooling.cpp -@@ -378,8 +378,9 @@ status_t nhwc_pooling_fwd_t::execute_forward( - return OSP * OC * mb + OSP * oc + SP * od + OW * oh + ow; - }; - const bool are_postops_set = !(pd()->attr()->post_ops_.entry_.empty()); -+ const int nthr = pd()->nthr_; - -- parallel_nd_ext(0, MB, OD, OH, OW, -+ parallel_nd_ext(nthr, MB, OD, OH, OW, - [&](int ithr, int, int mb, int od, int oh, int ow) { - const size_t dst_offset_init = strided_offset(mb, dst_n_stride, - od, dst_d_stride, oh, dst_h_stride, ow, dst_w_stride); -@@ -682,8 +683,9 @@ status_t nhwc_pooling_bwd_t::execute_backward( - auto apply_offset = [=](int index, int offset) { - return (index > offset) ? index - offset : 0; - }; -+ const int nthr = pd()->nthr_; - -- parallel_nd_ext(0, MB, ID, IH, IW, -+ parallel_nd_ext(nthr, MB, ID, IH, IW, - [&](int ithr, int, int mb, int id, int ih, int iw) { - size_t src_offset_init = strided_offset(mb, diff_src_n_stride, - id, diff_src_d_stride, ih, diff_src_h_stride, iw, -diff --git a/src/cpu/nhwc_pooling.hpp b/src/cpu/nhwc_pooling.hpp -index c65196a94..c16e840a2 100644 ---- a/src/cpu/nhwc_pooling.hpp -+++ b/src/cpu/nhwc_pooling.hpp -@@ -73,16 +73,19 @@ struct nhwc_pooling_fwd_t : public primitive_t { - init_default_ws(); - } - -+ nthr_ = dnnl_get_max_threads(); - init_scratchpad(); - - return status::success; - } - -+ int nthr_; // To not exceed the limit in execute used for set up. -+ - private: - void init_scratchpad() { - using namespace memory_tracking::names; - if (src_md()->data_type == data_type::bf16) { -- const size_t bf16cvt_sz_ = C() * dnnl_get_max_threads(); -+ const size_t bf16cvt_sz_ = C() * nthr_; - auto scratchpad = scratchpad_registry().registrar(); - scratchpad.template book( - key_pool_src_bf16cvt, bf16cvt_sz_); -@@ -148,16 +151,19 @@ struct nhwc_pooling_bwd_t : public primitive_t { - if (!compare_ws(hint_fwd_pd_)) return status::unimplemented; - } - -+ nthr_ = dnnl_get_max_threads(); - init_scratchpad(); - - return status::success; - } - -+ int nthr_; // To not exceed the limit in execute used for set up. -+ - private: - void init_scratchpad() { - using namespace memory_tracking::names; - if (diff_src_md()->data_type == data_type::bf16) { -- size_t bf16cvt_sz_ = C() * dnnl_get_max_threads(); -+ size_t bf16cvt_sz_ = C() * nthr_; - auto scratchpad = scratchpad_registry().registrar(); - scratchpad.template book( - key_pool_src_bf16cvt, bf16cvt_sz_); -diff --git a/src/cpu/x64/jit_primitive_conf.hpp b/src/cpu/x64/jit_primitive_conf.hpp -index a2a181cfa..5befb81ac 100644 ---- a/src/cpu/x64/jit_primitive_conf.hpp -+++ b/src/cpu/x64/jit_primitive_conf.hpp -@@ -672,6 +672,7 @@ struct jit_pool_conf_t { - bool with_postops; - bool with_eltwise; - bool with_binary; -+ int nthr; - }; - - struct jit_pool_call_s { -diff --git a/src/cpu/x64/jit_uni_pool_kernel.cpp b/src/cpu/x64/jit_uni_pool_kernel.cpp -index 36d129e6d..ebd4f3af1 100644 ---- a/src/cpu/x64/jit_uni_pool_kernel.cpp -+++ b/src/cpu/x64/jit_uni_pool_kernel.cpp -@@ -76,8 +76,7 @@ jit_uni_pool_kernel::jit_uni_pool_kernel( - - template - status_t jit_uni_pool_kernel::init_conf(jit_pool_conf_t &jpp, -- memory_tracking::registrar_t &scratchpad, const pooling_pd_t *ppd, -- int nthreads) { -+ memory_tracking::registrar_t &scratchpad, const pooling_pd_t *ppd) { - - const auto &pd = *ppd->desc(); - const memory_desc_wrapper src_d( -@@ -87,6 +86,7 @@ status_t jit_uni_pool_kernel::init_conf(jit_pool_conf_t &jpp, - - const int ndims = src_d.ndims(); - -+ jpp.nthr = dnnl_get_max_threads(); - jpp.is_training = pd.prop_kind == prop_kind::forward_training; - jpp.is_backward = pd.prop_kind == prop_kind::backward_data; - -@@ -248,7 +248,7 @@ status_t jit_uni_pool_kernel::init_conf(jit_pool_conf_t &jpp, - ? (ndims == 5 && jpp.simple_alg ? jpp.od : 1) - : (ndims == 5 ? jpp.od : jpp.oh); - work *= jpp.mb * nb2_c; -- auto eff = (float)work / utils::rnd_up(work, nthreads); -+ auto eff = (float)work / utils::rnd_up(work, jpp.nthr); - if (eff > best_eff) { - - best_eff = eff; -diff --git a/src/cpu/x64/jit_uni_pool_kernel.hpp b/src/cpu/x64/jit_uni_pool_kernel.hpp -index d5d5f25a2..57ce6f43d 100644 ---- a/src/cpu/x64/jit_uni_pool_kernel.hpp -+++ b/src/cpu/x64/jit_uni_pool_kernel.hpp -@@ -46,8 +46,7 @@ struct jit_uni_pool_kernel : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_pool_kernel) - - static status_t init_conf(jit_pool_conf_t &jbp, -- memory_tracking::registrar_t &scratchpad, const pooling_pd_t *ppd, -- int nthreads); -+ memory_tracking::registrar_t &scratchpad, const pooling_pd_t *ppd); - - private: - using Xmm = Xbyak::Xmm; -diff --git a/src/cpu/x64/jit_uni_pooling.cpp b/src/cpu/x64/jit_uni_pooling.cpp -index b2055f2a9..29987f70c 100644 ---- a/src/cpu/x64/jit_uni_pooling.cpp -+++ b/src/cpu/x64/jit_uni_pooling.cpp -@@ -612,6 +612,8 @@ void jit_uni_pooling_fwd_t::execute_forward(const data_t *src, - (*kernel_)(&arg); - }; - -+ const int nthr = jpp.nthr; -+ - if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) { - const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc); - parallel_nd(jpp.mb, jpp.oh, nb2_c, [&](int n, int oh, int b2_c) { -@@ -622,7 +624,7 @@ void jit_uni_pooling_fwd_t::execute_forward(const data_t *src, - } else { - if (trans_src || trans_dst) { - // ncsp format -- parallel_nd_ext(0, jpp.mb, jpp.nb_c, -+ parallel_nd_ext(nthr, jpp.mb, jpp.nb_c, - [&](int ithr, int nthr, int n, int b_c) { - if (trans_src) - transpose_facade.execute_transpose_input( -@@ -635,7 +637,7 @@ void jit_uni_pooling_fwd_t::execute_forward(const data_t *src, - }); - } else { - // nChw16c, nChw8c format -- parallel(0, [&](std::size_t ithr, std::size_t nthr) { -+ parallel(nthr, [&](int ithr, int nthr) { - const std::size_t work_amount - = static_cast(jpp.mb) * jpp.nb_c * jpp.oh; - if (ithr >= work_amount) return; -@@ -739,6 +741,8 @@ void jit_uni_pooling_fwd_t::execute_forward_3d(const data_t *src, - (*kernel_)(&arg); - }; - -+ const int nthr = jpp.nthr; -+ - if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) { - const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc); - parallel_nd(jpp.mb, jpp.od, nb2_c, [&](int n, int od, int b2_c) { -@@ -757,7 +761,7 @@ void jit_uni_pooling_fwd_t::execute_forward_3d(const data_t *src, - }); - } else { - if (trans_src || trans_dst) { -- parallel_nd_ext(0, jpp.mb, jpp.nb_c, -+ parallel_nd_ext(nthr, jpp.mb, jpp.nb_c, - [&](int ithr, int nthr, int n, int b_c) { - if (trans_src) - transpose_facade.execute_transpose_input( -@@ -948,7 +952,9 @@ void jit_uni_pooling_bwd_t::execute_backward( - transpose_facade.execute_transpose_output(ithr, n, b_c); - }; - -- parallel(0, [&](int ithr, int nthr) { -+ const int nthr = jpp.nthr; -+ -+ parallel(nthr, [&](int ithr, int nthr) { - const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc); - const std::size_t work_amount - = static_cast(jpp.mb) * nb2_c; -@@ -1098,6 +1104,8 @@ void jit_uni_pooling_bwd_t::execute_backward_3d( - } - }; - -+ const int nthr = jpp.nthr; -+ - if (jpp.simple_alg) { - if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) { - const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc); -@@ -1109,7 +1117,7 @@ void jit_uni_pooling_bwd_t::execute_backward_3d( - } else { - assert(jpp.ur_bc == 1); - if (trans_src || trans_dst) { -- parallel_nd_ext(0, jpp.mb, jpp.nb_c, -+ parallel_nd_ext(nthr, jpp.mb, jpp.nb_c, - [&](int ithr, int nthr, int n, int b_c) { - if (trans_src) - transpose_facade.execute_transpose_input( -@@ -1142,7 +1150,7 @@ void jit_uni_pooling_bwd_t::execute_backward_3d( - if (!trans_src) { - const size_t chunk_size - = (size_t)jpp.id * jpp.ih * jpp.iw * jpp.c_block; -- parallel_nd_ext(0, jpp.mb, jpp.nb_c, -+ parallel_nd_ext(nthr, jpp.mb, jpp.nb_c, - [&](int ithr, int nthr, int n, int b_c) { - const size_t offset - = ((size_t)n * jpp.nb_c + b_c) * chunk_size; -@@ -1155,8 +1163,8 @@ void jit_uni_pooling_bwd_t::execute_backward_3d( - - const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc); - if (trans_src || trans_dst) { -- parallel_nd_ext( -- 0, jpp.mb, nb2_c, [&](int ithr, int nthr, int n, int b2_c) { -+ parallel_nd_ext(nthr, jpp.mb, nb2_c, -+ [&](int ithr, int nthr, int n, int b2_c) { - const auto b_c = b2_c * jpp.ur_bc; - - if (trans_dst) { -diff --git a/src/cpu/x64/jit_uni_pooling.hpp b/src/cpu/x64/jit_uni_pooling.hpp -index ec4b04a2b..e25d9ce05 100644 ---- a/src/cpu/x64/jit_uni_pooling.hpp -+++ b/src/cpu/x64/jit_uni_pooling.hpp -@@ -66,8 +66,9 @@ struct jit_uni_pooling_fwd_t : public primitive_t { - init_default_ws(); - - auto scratchpad = scratchpad_registry().registrar(); -- return jit_uni_pool_kernel::init_conf( -- jpp_, scratchpad, this, dnnl_get_max_threads()); -+ CHECK(jit_uni_pool_kernel::init_conf(jpp_, scratchpad, this)); -+ -+ return status::success; - } - - jit_pool_conf_t jpp_; -@@ -130,9 +131,11 @@ struct jit_uni_pooling_bwd_t : public primitive_t { - init_default_ws(); - if (!compare_ws(hint_fwd_pd_)) return status::unimplemented; - } -+ - auto scratchpad = scratchpad_registry().registrar(); -- return jit_uni_pool_kernel::init_conf( -- jpp_, scratchpad, this, dnnl_get_max_threads()); -+ CHECK(jit_uni_pool_kernel::init_conf(jpp_, scratchpad, this)); -+ -+ return status::success; - } - - jit_pool_conf_t jpp_; diff --git a/third_party/patch/onednn/0003-fix-zero-threads-identified-on-AMD.patch b/third_party/patch/onednn/0003-fix-zero-threads-identified-on-AMD.patch deleted file mode 100644 index 0c3b6a76ed2..00000000000 --- a/third_party/patch/onednn/0003-fix-zero-threads-identified-on-AMD.patch +++ /dev/null @@ -1,13 +0,0 @@ -diff --git a/src/cpu/platform.cpp b/src/cpu/platform.cpp -index 1397073ba..041a3436f 100644 ---- a/src/cpu/platform.cpp -+++ b/src/cpu/platform.cpp -@@ -154,6 +154,8 @@ unsigned get_num_cores() { - // function supports process affinity. - unsigned get_max_threads_to_use() { - int num_cores_per_socket = (int)dnnl::impl::cpu::platform::get_num_cores(); -+ if (num_cores_per_socket <= 1) -+ num_cores_per_socket = std::thread::hardware_concurrency(); - #if defined(_WIN32) - DWORD_PTR proc_affinity_mask; - DWORD_PTR sys_affinity_mask; diff --git a/third_party/patch/onednn/0004-fix-dnnl-limits.patch b/third_party/patch/onednn/0004-fix-dnnl-limits.patch deleted file mode 100644 index 7638e4ae651..00000000000 --- a/third_party/patch/onednn/0004-fix-dnnl-limits.patch +++ /dev/null @@ -1,10 +0,0 @@ ---- a/src/cpu/aarch64/xbyak_aarch64/xbyak_aarch64/xbyak_aarch64.h -+++ b/src/cpu/aarch64/xbyak_aarch64/xbyak_aarch64/xbyak_aarch64.h -@@ -28,6 +28,7 @@ - #include - #include - #include -+#include - #include - #include - #include diff --git a/third_party/patch/opencv/Fix_Binary.patch b/third_party/patch/opencv/Fix_Binary.patch deleted file mode 100644 index 50b31fb1cf3..00000000000 --- a/third_party/patch/opencv/Fix_Binary.patch +++ /dev/null @@ -1,20 +0,0 @@ -diff -Npur opencv-4.5.2/CMakeLists.txt opencv-4.5.2-change/CMakeLists.txt ---- opencv-4.5.2/CMakeLists.txt 2021-04-02 19:23:54.000000000 +0800 -+++ opencv-4.5.2-change/CMakeLists.txt 2023-02-08 03:40:02.807178015 +0800 -@@ -1050,7 +1050,7 @@ endif() - if(OPENCV_TIMESTAMP) - status(" Timestamp:" ${OPENCV_TIMESTAMP}) - endif() --status(" Host:" ${CMAKE_HOST_SYSTEM_NAME} ${CMAKE_HOST_SYSTEM_VERSION} ${CMAKE_HOST_SYSTEM_PROCESSOR}) -+status(" Host:" ${CMAKE_HOST_SYSTEM_NAME} ${CMAKE_HOST_SYSTEM_PROCESSOR}) - if(CMAKE_CROSSCOMPILING) - status(" Target:" ${CMAKE_SYSTEM_NAME} ${CMAKE_SYSTEM_VERSION} ${CMAKE_SYSTEM_PROCESSOR}) - endif() -@@ -1639,7 +1639,6 @@ status(" Python (for build):" PYTHON_D - if(BUILD_JAVA) - status("") - status(" Java:" BUILD_FAT_JAVA_LIB THEN "export all functions" ELSE "") -- status(" ant:" ANT_EXECUTABLE THEN "${ANT_EXECUTABLE} (ver ${ANT_VERSION})" ELSE NO) - if(NOT ANDROID) - status(" JNI:" JNI_INCLUDE_DIRS THEN "${JNI_INCLUDE_DIRS}" ELSE NO) - endif() diff --git a/third_party/patch/opencv/libtiff/CVE-2022-3970.patch b/third_party/patch/opencv/libtiff/CVE-2022-3970.patch deleted file mode 100644 index 2f7695435e5..00000000000 --- a/third_party/patch/opencv/libtiff/CVE-2022-3970.patch +++ /dev/null @@ -1,23 +0,0 @@ -diff -Npur opencv-4.5.2/3rdparty/libtiff/tif_getimage.c opencv-4.5.2-change/3rdparty/libtiff/tif_getimage.c ---- opencv-4.5.2/3rdparty/libtiff/tif_getimage.c 2021-04-02 19:23:54.000000000 +0800 -+++ opencv-4.5.2-change/3rdparty/libtiff/tif_getimage.c 2023-01-05 04:35:54.050388130 +0800 -@@ -3058,15 +3058,15 @@ TIFFReadRGBATileExt(TIFF* tif, uint32 co - return( ok ); - - for( i_row = 0; i_row < read_ysize; i_row++ ) { -- memmove( raster + (tile_ysize - i_row - 1) * tile_xsize, -- raster + (read_ysize - i_row - 1) * read_xsize, -+ memmove( raster + (size_t)(tile_ysize - i_row - 1) * tile_xsize, -+ raster + (size_t)(read_ysize - i_row - 1) * read_xsize, - read_xsize * sizeof(uint32) ); -- _TIFFmemset( raster + (tile_ysize - i_row - 1) * tile_xsize+read_xsize, -+ _TIFFmemset( raster + (size_t)(tile_ysize - i_row - 1) * tile_xsize+read_xsize, - 0, sizeof(uint32) * (tile_xsize - read_xsize) ); - } - - for( i_row = read_ysize; i_row < tile_ysize; i_row++ ) { -- _TIFFmemset( raster + (tile_ysize - i_row - 1) * tile_xsize, -+ _TIFFmemset( raster + (size_t)(tile_ysize - i_row - 1) * tile_xsize, - 0, sizeof(uint32) * tile_xsize ); - } - diff --git a/third_party/patch/opencv/libtiff/CVE-2023-3316.patch b/third_party/patch/opencv/libtiff/CVE-2023-3316.patch deleted file mode 100644 index 93db41ae64b..00000000000 --- a/third_party/patch/opencv/libtiff/CVE-2023-3316.patch +++ /dev/null @@ -1,24 +0,0 @@ -diff -Npur opencv-4.5.2/3rdparty/libtiff/tif_close.c opencv-4.5.2-change/3rdparty/libtiff/tif_close.c ---- opencv-4.5.2/3rdparty/libtiff/tif_close.c 2021-04-02 19:23:54.000000000 +0800 -+++ opencv-4.5.2-change/3rdparty/libtiff/tif_close.c 2023-07-29 15:15:10.175435233 +0800 -@@ -120,11 +120,14 @@ TIFFCleanup(TIFF* tif) - void - TIFFClose(TIFF* tif) - { -- TIFFCloseProc closeproc = tif->tif_closeproc; -- thandle_t fd = tif->tif_clientdata; -+ if (tif != NULL) -+ { -+ TIFFCloseProc closeproc = tif->tif_closeproc; -+ thandle_t fd = tif->tif_clientdata; - -- TIFFCleanup(tif); -- (void) (*closeproc)(fd); -+ TIFFCleanup(tif); -+ (void) (*closeproc)(fd); -+ } - } - - /* vim: set ts=8 sts=8 sw=8 noet: */ - - diff --git a/third_party/patch/openssl/CVE-2021-3711.patch b/third_party/patch/openssl/CVE-2021-3711.patch deleted file mode 100644 index 790e10f8807..00000000000 --- a/third_party/patch/openssl/CVE-2021-3711.patch +++ /dev/null @@ -1,81 +0,0 @@ -diff --git a/crypto/sm2/sm2_crypt.c b/crypto/sm2/sm2_crypt.c -index ef505f6441..1188abfc6b 100644 ---- a/crypto/sm2/sm2_crypt.c -+++ b/crypto/sm2/sm2_crypt.c -@@ -61,29 +61,20 @@ static size_t ec_field_size(const EC_GROUP *group) - return field_size; - } - --int sm2_plaintext_size(const EC_KEY *key, const EVP_MD *digest, size_t msg_len, -- size_t *pt_size) -+int sm2_plaintext_size(const unsigned char *ct, size_t ct_size, size_t *pt_size) - { -- const size_t field_size = ec_field_size(EC_KEY_get0_group(key)); -- const int md_size = EVP_MD_size(digest); -- size_t overhead; -+ struct SM2_Ciphertext_st *sm2_ctext = NULL; - -- if (md_size < 0) { -- SM2err(SM2_F_SM2_PLAINTEXT_SIZE, SM2_R_INVALID_DIGEST); -- return 0; -- } -- if (field_size == 0) { -- SM2err(SM2_F_SM2_PLAINTEXT_SIZE, SM2_R_INVALID_FIELD); -- return 0; -- } -+ sm2_ctext = d2i_SM2_Ciphertext(NULL, &ct, ct_size); - -- overhead = 10 + 2 * field_size + (size_t)md_size; -- if (msg_len <= overhead) { -+ if (sm2_ctext == NULL) { - SM2err(SM2_F_SM2_PLAINTEXT_SIZE, SM2_R_INVALID_ENCODING); - return 0; - } - -- *pt_size = msg_len - overhead; -+ *pt_size = sm2_ctext->C2->length; -+ SM2_Ciphertext_free(sm2_ctext); -+ - return 1; - } - -diff --git a/crypto/sm2/sm2_pmeth.c b/crypto/sm2/sm2_pmeth.c -index b42a14c32f..27025fbf3a 100644 ---- a/crypto/sm2/sm2_pmeth.c -+++ b/crypto/sm2/sm2_pmeth.c -@@ -151,7 +151,7 @@ static int pkey_sm2_decrypt(EVP_PKEY_CTX *ctx, - const EVP_MD *md = (dctx->md == NULL) ? EVP_sm3() : dctx->md; - - if (out == NULL) { -- if (!sm2_plaintext_size(ec, md, inlen, outlen)) -+ if (!sm2_plaintext_size(in, inlen, outlen)) - return -1; - else - return 1; -diff --git a/include/crypto/sm2.h b/include/crypto/sm2.h -index 76ee80baff..50851a83ce 100644 ---- a/include/crypto/sm2.h -+++ b/include/crypto/sm2.h -@@ -60,8 +60,7 @@ int sm2_verify(const unsigned char *dgst, int dgstlen, - int sm2_ciphertext_size(const EC_KEY *key, const EVP_MD *digest, size_t msg_len, - size_t *ct_size); - --int sm2_plaintext_size(const EC_KEY *key, const EVP_MD *digest, size_t msg_len, -- size_t *pt_size); -+int sm2_plaintext_size(const unsigned char *ct, size_t ct_size, size_t *pt_size); - - int sm2_encrypt(const EC_KEY *key, - const EVP_MD *digest, -diff --git a/test/sm2_internal_test.c b/test/sm2_internal_test.c -index 2bb73947ff..41827bb82f 100644 ---- a/test/sm2_internal_test.c -+++ b/test/sm2_internal_test.c -@@ -185,7 +185,7 @@ static int test_sm2_crypt(const EC_GROUP *group, - if (!TEST_mem_eq(ctext, ctext_len, expected, ctext_len)) - goto done; - -- if (!TEST_true(sm2_plaintext_size(key, digest, ctext_len, &ptext_len)) -+ if (!TEST_true(sm2_plaintext_size(ctext, ctext_len, &ptext_len)) - || !TEST_int_eq(ptext_len, msg_len)) - goto done; - diff --git a/third_party/patch/openssl/CVE-2021-3712.patch b/third_party/patch/openssl/CVE-2021-3712.patch deleted file mode 100644 index 1e07534d00a..00000000000 --- a/third_party/patch/openssl/CVE-2021-3712.patch +++ /dev/null @@ -1,17 +0,0 @@ -diff --git a/crypto/ec/ec_asn1.c b/crypto/ec/ec_asn1.c -index 7b7c75ce84..e497a25909 100644 ---- a/crypto/ec/ec_asn1.c -+++ b/crypto/ec/ec_asn1.c -@@ -761,7 +761,10 @@ EC_GROUP *EC_GROUP_new_from_ecparameters(const ECPARAMETERS *params) - ret->seed_len = params->curve->seed->length; - } - -- if (!params->order || !params->base || !params->base->data) { -+ if (params->order == NULL -+ || params->base == NULL -+ || params->base->data == NULL -+ || params->base->length == 0) { - ECerr(EC_F_EC_GROUP_NEW_FROM_ECPARAMETERS, EC_R_ASN1_ERROR); - goto err; - } - diff --git a/third_party/patch/openssl/CVE-2021-4160.patch b/third_party/patch/openssl/CVE-2021-4160.patch deleted file mode 100644 index c5773f5b9f1..00000000000 --- a/third_party/patch/openssl/CVE-2021-4160.patch +++ /dev/null @@ -1,78 +0,0 @@ -diff --git a/crypto/bn/asm/mips.pl b/crypto/bn/asm/mips.pl -index 95cb227dc5..91b7aac6e7 100644 ---- a/crypto/bn/asm/mips.pl -+++ b/crypto/bn/asm/mips.pl -@@ -1986,6 +1986,8 @@ $code.=<<___; - sltu $at,$c_2,$t_1 - $ADDU $c_3,$t_2,$at - $ST $c_2,$BNSZ($a0) -+ sltu $at,$c_3,$t_2 -+ $ADDU $c_1,$at - mflo ($t_1,$a_2,$a_0) - mfhi ($t_2,$a_2,$a_0) - ___ -@@ -2196,6 +2198,8 @@ $code.=<<___; - sltu $at,$c_2,$t_1 - $ADDU $c_3,$t_2,$at - $ST $c_2,$BNSZ($a0) -+ sltu $at,$c_3,$t_2 -+ $ADDU $c_1,$at - mflo ($t_1,$a_2,$a_0) - mfhi ($t_2,$a_2,$a_0) - ___ -diff --git a/test/bntest.c b/test/bntest.c -index 87e5c4065b..fa9fc07cef 100644 ---- a/test/bntest.c -+++ b/test/bntest.c -@@ -630,6 +630,51 @@ static int test_modexp_mont5(void) - if (!TEST_BN_eq(c, d)) - goto err; - -+ /* -+ * Regression test for overflow bug in bn_sqr_comba4/8 for -+ * mips-linux-gnu and mipsel-linux-gnu 32bit targets. -+ */ -+ { -+ static const char *ehex[] = { -+ "95564994a96c45954227b845a1e99cb939d5a1da99ee91acc962396ae999a9ee", -+ "38603790448f2f7694c242a875f0cad0aae658eba085f312d2febbbd128dd2b5", -+ "8f7d1149f03724215d704344d0d62c587ae3c5939cba4b9b5f3dc5e8e911ef9a", -+ "5ce1a5a749a4989d0d8368f6e1f8cdf3a362a6c97fb02047ff152b480a4ad985", -+ "2d45efdf0770542992afca6a0590d52930434bba96017afbc9f99e112950a8b1", -+ "a359473ec376f329bdae6a19f503be6d4be7393c4e43468831234e27e3838680", -+ "b949390d2e416a3f9759e5349ab4c253f6f29f819a6fe4cbfd27ada34903300e", -+ "da021f62839f5878a36f1bc3085375b00fd5fa3e68d316c0fdace87a97558465", -+ NULL}; -+ static const char *phex[] = { -+ "f95dc0f980fbd22e90caa5a387cc4a369f3f830d50dd321c40db8c09a7e1a241", -+ "a536e096622d3280c0c1ba849c1f4a79bf490f60006d081e8cf69960189f0d31", -+ "2cd9e17073a3fba7881b21474a13b334116cb2f5dbf3189a6de3515d0840f053", -+ "c776d3982d391b6d04d642dda5cc6d1640174c09875addb70595658f89efb439", -+ "dc6fbd55f903aadd307982d3f659207f265e1ec6271b274521b7a5e28e8fd7a5", -+ "5df089292820477802a43cf5b6b94e999e8c9944ddebb0d0e95a60f88cb7e813", -+ "ba110d20e1024774107dd02949031864923b3cb8c3f7250d6d1287b0a40db6a4", -+ "7bd5a469518eb65aa207ddc47d8c6e5fc8e0c105be8fc1d4b57b2e27540471d5", -+ NULL}; -+ static const char *mhex[] = { -+ "fef15d5ce4625f1bccfbba49fc8439c72bf8202af039a2259678941b60bb4a8f", -+ "2987e965d58fd8cf86a856674d519763d0e1211cc9f8596971050d56d9b35db3", -+ "785866cfbca17cfdbed6060be3629d894f924a89fdc1efc624f80d41a22f1900", -+ "9503fcc3824ef62ccb9208430c26f2d8ceb2c63488ec4c07437aa4c96c43dd8b", -+ "9289ed00a712ff66ee195dc71f5e4ead02172b63c543d69baf495f5fd63ba7bc", -+ "c633bd309c016e37736da92129d0b053d4ab28d21ad7d8b6fab2a8bbdc8ee647", -+ "d2fbcf2cf426cf892e6f5639e0252993965dfb73ccd277407014ea784aaa280c", -+ "b7b03972bc8b0baa72360bdb44b82415b86b2f260f877791cd33ba8f2d65229b", -+ NULL}; -+ -+ if (!TEST_true(parse_bigBN(&e, ehex)) -+ || !TEST_true(parse_bigBN(&p, phex)) -+ || !TEST_true(parse_bigBN(&m, mhex)) -+ || !TEST_true(BN_mod_exp_mont_consttime(d, e, p, m, ctx, NULL)) -+ || !TEST_true(BN_mod_exp_simple(a, e, p, m, ctx)) -+ || !TEST_BN_eq(a, d)) -+ goto err; -+ } -+ - /* Zero input */ - if (!TEST_true(BN_bntest_rand(p, 1024, 0, 0))) - goto err; \ No newline at end of file diff --git a/third_party/patch/openssl/CVE-2022-0778.patch b/third_party/patch/openssl/CVE-2022-0778.patch deleted file mode 100644 index e384dac255e..00000000000 --- a/third_party/patch/openssl/CVE-2022-0778.patch +++ /dev/null @@ -1,49 +0,0 @@ -diff --git a/crypto/bn/bn_sqrt.c b/crypto/bn/bn_sqrt.c -index 1723d5ded5..53b0f55985 100644 ---- a/crypto/bn/bn_sqrt.c -+++ b/crypto/bn/bn_sqrt.c -@@ -14,7 +14,8 @@ BIGNUM *BN_mod_sqrt(BIGNUM *in, const BIGNUM *a, const BIGNUM *p, BN_CTX *ctx) - /* - * Returns 'ret' such that ret^2 == a (mod p), using the Tonelli/Shanks - * algorithm (cf. Henri Cohen, "A Course in Algebraic Computational Number -- * Theory", algorithm 1.5.1). 'p' must be prime! -+ * Theory", algorithm 1.5.1). 'p' must be prime, otherwise an error or -+ * an incorrect "result" will be returned. - */ - { - BIGNUM *ret = in; -@@ -301,18 +302,23 @@ BIGNUM *BN_mod_sqrt(BIGNUM *in, const BIGNUM *a, const BIGNUM *p, BN_CTX *ctx) - goto vrfy; - } - -- /* find smallest i such that b^(2^i) = 1 */ -- i = 1; -- if (!BN_mod_sqr(t, b, p, ctx)) -- goto end; -- while (!BN_is_one(t)) { -- i++; -- if (i == e) { -- BNerr(BN_F_BN_MOD_SQRT, BN_R_NOT_A_SQUARE); -- goto end; -+ /* Find the smallest i, 0 < i < e, such that b^(2^i) = 1. */ -+ for (i = 1; i < e; i++) { -+ if (i == 1) { -+ if (!BN_mod_sqr(t, b, p, ctx)) -+ goto end; -+ -+ } else { -+ if (!BN_mod_mul(t, t, t, p, ctx)) -+ goto end; - } -- if (!BN_mod_mul(t, t, t, p, ctx)) -- goto end; -+ if (BN_is_one(t)) -+ break; -+ } -+ /* If not found, a is not a square or p is not prime. */ -+ if (i >= e) { -+ BNerr(BN_F_BN_MOD_SQRT, BN_R_NOT_A_SQUARE); -+ goto end; - } - - /* t := y^2^(e - i - 1) */ \ No newline at end of file diff --git a/third_party/patch/openssl/CVE-2022-1292.patch b/third_party/patch/openssl/CVE-2022-1292.patch deleted file mode 100644 index d07162220b5..00000000000 --- a/third_party/patch/openssl/CVE-2022-1292.patch +++ /dev/null @@ -1,58 +0,0 @@ -diff --git a/tools/c_rehash.in b/tools/c_rehash.in -index fa7c6c9fef..83c1cc80e0 100644 ---- a/tools/c_rehash.in -+++ b/tools/c_rehash.in -@@ -152,6 +152,23 @@ sub check_file { - return ($is_cert, $is_crl); - } - -+sub compute_hash { -+ my $fh; -+ if ( $^O eq "VMS" ) { -+ # VMS uses the open through shell -+ # The file names are safe there and list form is unsupported -+ if (!open($fh, "-|", join(' ', @_))) { -+ print STDERR "Cannot compute hash on '$fname'\n"; -+ return; -+ } -+ } else { -+ if (!open($fh, "-|", @_)) { -+ print STDERR "Cannot compute hash on '$fname'\n"; -+ return; -+ } -+ } -+ return (<$fh>, <$fh>); -+} - - # Link a certificate to its subject name hash value, each hash is of - # the form . where n is an integer. If the hash value already exists -@@ -161,10 +178,12 @@ sub check_file { - - sub link_hash_cert { - my $fname = $_[0]; -- $fname =~ s/\"/\\\"/g; -- my ($hash, $fprint) = `"$openssl" x509 $x509hash -fingerprint -noout -in "$fname"`; -+ my ($hash, $fprint) = compute_hash($openssl, "x509", $x509hash, -+ "-fingerprint", "-noout", -+ "-in", $fname); - chomp $hash; - chomp $fprint; -+ return if !$hash; - $fprint =~ s/^.*=//; - $fprint =~ tr/://d; - my $suffix = 0; -@@ -202,10 +221,12 @@ sub link_hash_cert { - - sub link_hash_crl { - my $fname = $_[0]; -- $fname =~ s/'/'\\''/g; -- my ($hash, $fprint) = `"$openssl" crl $crlhash -fingerprint -noout -in '$fname'`; -+ my ($hash, $fprint) = compute_hash($openssl, "crl", $crlhash, -+ "-fingerprint", "-noout", -+ "-in", $fname); - chomp $hash; - chomp $fprint; -+ return if !$hash; - $fprint =~ s/^.*=//; - $fprint =~ tr/://d; - my $suffix = 0; \ No newline at end of file diff --git a/third_party/patch/openssl/CVE-2022-2068.patch b/third_party/patch/openssl/CVE-2022-2068.patch deleted file mode 100644 index fde78b92782..00000000000 --- a/third_party/patch/openssl/CVE-2022-2068.patch +++ /dev/null @@ -1,241 +0,0 @@ -diff --git a/tools/c_rehash.in b/tools/c_rehash.in -index cfd18f5da1..9d2a6f6db7 100644 ---- a/tools/c_rehash.in -+++ b/tools/c_rehash.in -@@ -104,52 +104,78 @@ foreach (@dirlist) { - } - exit($errorcount); - -+sub copy_file { -+ my ($src_fname, $dst_fname) = @_; -+ -+ if (open(my $in, "<", $src_fname)) { -+ if (open(my $out, ">", $dst_fname)) { -+ print $out $_ while (<$in>); -+ close $out; -+ } else { -+ warn "Cannot open $dst_fname for write, $!"; -+ } -+ close $in; -+ } else { -+ warn "Cannot open $src_fname for read, $!"; -+ } -+} -+ - sub hash_dir { -- my %hashlist; -- print "Doing $_[0]\n"; -- chdir $_[0]; -- opendir(DIR, "."); -- my @flist = sort readdir(DIR); -- closedir DIR; -- if ( $removelinks ) { -- # Delete any existing symbolic links -- foreach (grep {/^[\da-f]+\.r{0,1}\d+$/} @flist) { -- if (-l $_) { -- print "unlink $_" if $verbose; -- unlink $_ || warn "Can't unlink $_, $!\n"; -- } -- } -- } -- FILE: foreach $fname (grep {/\.(pem)|(crt)|(cer)|(crl)$/} @flist) { -- # Check to see if certificates and/or CRLs present. -- my ($cert, $crl) = check_file($fname); -- if (!$cert && !$crl) { -- print STDERR "WARNING: $fname does not contain a certificate or CRL: skipping\n"; -- next; -- } -- link_hash_cert($fname) if ($cert); -- link_hash_crl($fname) if ($crl); -- } -+ my $dir = shift; -+ my %hashlist; -+ -+ print "Doing $dir\n"; -+ -+ if (!chdir $dir) { -+ print STDERR "WARNING: Cannot chdir to '$dir', $!\n"; -+ return; -+ } -+ -+ opendir(DIR, ".") || print STDERR "WARNING: Cannot opendir '.', $!\n"; -+ my @flist = sort readdir(DIR); -+ closedir DIR; -+ if ( $removelinks ) { -+ # Delete any existing symbolic links -+ foreach (grep {/^[\da-f]+\.r{0,1}\d+$/} @flist) { -+ if (-l $_) { -+ print "unlink $_\n" if $verbose; -+ unlink $_ || warn "Can't unlink $_, $!\n"; -+ } -+ } -+ } -+ FILE: foreach $fname (grep {/\.(pem)|(crt)|(cer)|(crl)$/} @flist) { -+ # Check to see if certificates and/or CRLs present. -+ my ($cert, $crl) = check_file($fname); -+ if (!$cert && !$crl) { -+ print STDERR "WARNING: $fname does not contain a certificate or CRL: skipping\n"; -+ next; -+ } -+ link_hash_cert($fname) if ($cert); -+ link_hash_crl($fname) if ($crl); -+ } -+ -+ chdir $pwd; - } - - sub check_file { -- my ($is_cert, $is_crl) = (0,0); -- my $fname = $_[0]; -- open IN, $fname; -- while() { -- if (/^-----BEGIN (.*)-----/) { -- my $hdr = $1; -- if ($hdr =~ /^(X509 |TRUSTED |)CERTIFICATE$/) { -- $is_cert = 1; -- last if ($is_crl); -- } elsif ($hdr eq "X509 CRL") { -- $is_crl = 1; -- last if ($is_cert); -- } -- } -- } -- close IN; -- return ($is_cert, $is_crl); -+ my ($is_cert, $is_crl) = (0,0); -+ my $fname = $_[0]; -+ -+ open(my $in, "<", $fname); -+ while(<$in>) { -+ if (/^-----BEGIN (.*)-----/) { -+ my $hdr = $1; -+ if ($hdr =~ /^(X509 |TRUSTED |)CERTIFICATE$/) { -+ $is_cert = 1; -+ last if ($is_crl); -+ } elsif ($hdr eq "X509 CRL") { -+ $is_crl = 1; -+ last if ($is_cert); -+ } -+ } -+ } -+ close $in; -+ return ($is_cert, $is_crl); - } - - sub compute_hash { -@@ -177,76 +203,48 @@ sub compute_hash { - # certificate fingerprints - - sub link_hash_cert { -- my $fname = $_[0]; -- my ($hash, $fprint) = compute_hash($openssl, "x509", $x509hash, -- "-fingerprint", "-noout", -- "-in", $fname); -- chomp $hash; -- chomp $fprint; -- return if !$hash; -- $fprint =~ s/^.*=//; -- $fprint =~ tr/://d; -- my $suffix = 0; -- # Search for an unused hash filename -- while(exists $hashlist{"$hash.$suffix"}) { -- # Hash matches: if fingerprint matches its a duplicate cert -- if ($hashlist{"$hash.$suffix"} eq $fprint) { -- print STDERR "WARNING: Skipping duplicate certificate $fname\n"; -- return; -- } -- $suffix++; -- } -- $hash .= ".$suffix"; -- if ($symlink_exists) { -- print "link $fname -> $hash\n" if $verbose; -- symlink $fname, $hash || warn "Can't symlink, $!"; -- } else { -- print "copy $fname -> $hash\n" if $verbose; -- if (open($in, "<", $fname)) { -- if (open($out,">", $hash)) { -- print $out $_ while (<$in>); -- close $out; -- } else { -- warn "can't open $hash for write, $!"; -- } -- close $in; -- } else { -- warn "can't open $fname for read, $!"; -- } -- } -- $hashlist{$hash} = $fprint; -+ link_hash($_[0], 'cert'); - } - - # Same as above except for a CRL. CRL links are of the form .r - - sub link_hash_crl { -- my $fname = $_[0]; -- my ($hash, $fprint) = compute_hash($openssl, "crl", $crlhash, -- "-fingerprint", "-noout", -- "-in", $fname); -- chomp $hash; -- chomp $fprint; -- return if !$hash; -- $fprint =~ s/^.*=//; -- $fprint =~ tr/://d; -- my $suffix = 0; -- # Search for an unused hash filename -- while(exists $hashlist{"$hash.r$suffix"}) { -- # Hash matches: if fingerprint matches its a duplicate cert -- if ($hashlist{"$hash.r$suffix"} eq $fprint) { -- print STDERR "WARNING: Skipping duplicate CRL $fname\n"; -- return; -- } -- $suffix++; -- } -- $hash .= ".r$suffix"; -- if ($symlink_exists) { -- print "link $fname -> $hash\n" if $verbose; -- symlink $fname, $hash || warn "Can't symlink, $!"; -- } else { -- print "cp $fname -> $hash\n" if $verbose; -- system ("cp", $fname, $hash); -- warn "Can't copy, $!" if ($? >> 8) != 0; -- } -- $hashlist{$hash} = $fprint; -+ link_hash($_[0], 'crl'); -+} -+ -+sub link_hash { -+ my ($fname, $type) = @_; -+ my $is_cert = $type eq 'cert'; -+ -+ my ($hash, $fprint) = compute_hash($openssl, -+ $is_cert ? "x509" : "crl", -+ $is_cert ? $x509hash : $crlhash, -+ "-fingerprint", "-noout", -+ "-in", $fname); -+ chomp $hash; -+ chomp $fprint; -+ return if !$hash; -+ $fprint =~ s/^.*=//; -+ $fprint =~ tr/://d; -+ my $suffix = 0; -+ # Search for an unused hash filename -+ my $crlmark = $is_cert ? "" : "r"; -+ while(exists $hashlist{"$hash.$crlmark$suffix"}) { -+ # Hash matches: if fingerprint matches its a duplicate cert -+ if ($hashlist{"$hash.$crlmark$suffix"} eq $fprint) { -+ my $what = $is_cert ? 'certificate' : 'CRL'; -+ print STDERR "WARNING: Skipping duplicate $what $fname\n"; -+ return; -+ } -+ $suffix++; -+ } -+ $hash .= ".$crlmark$suffix"; -+ if ($symlink_exists) { -+ print "link $fname -> $hash\n" if $verbose; -+ symlink $fname, $hash || warn "Can't symlink, $!"; -+ } else { -+ print "copy $fname -> $hash\n" if $verbose; -+ copy_file($fname, $hash); -+ } -+ $hashlist{$hash} = $fprint; - } \ No newline at end of file diff --git a/third_party/patch/openssl/CVE-2022-2097.patch b/third_party/patch/openssl/CVE-2022-2097.patch deleted file mode 100644 index 0bc4a3ab17b..00000000000 --- a/third_party/patch/openssl/CVE-2022-2097.patch +++ /dev/null @@ -1,22 +0,0 @@ -diff --git a/crypto/aes/asm/aesni-x86.pl b/crypto/aes/asm/aesni-x86.pl -index fe2b26542a..812758e02e 100644 ---- a/crypto/aes/asm/aesni-x86.pl -+++ b/crypto/aes/asm/aesni-x86.pl -@@ -2027,7 +2027,7 @@ my ($l_,$block,$i1,$i3,$i5) = ($rounds_,$key_,$rounds,$len,$out); - &movdqu (&QWP(-16*2,$out,$inp),$inout4); - &movdqu (&QWP(-16*1,$out,$inp),$inout5); - &cmp ($inp,$len); # done yet? -- &jb (&label("grandloop")); -+ &jbe (&label("grandloop")); - - &set_label("short"); - &add ($len,16*6); -@@ -2453,7 +2453,7 @@ my ($l_,$block,$i1,$i3,$i5) = ($rounds_,$key_,$rounds,$len,$out); - &pxor ($rndkey1,$inout5); - &movdqu (&QWP(-16*1,$out,$inp),$inout5); - &cmp ($inp,$len); # done yet? -- &jb (&label("grandloop")); -+ &jbe (&label("grandloop")); - - &set_label("short"); - &add ($len,16*6); \ No newline at end of file diff --git a/third_party/patch/openssl/CVE-2022-4304.patch b/third_party/patch/openssl/CVE-2022-4304.patch deleted file mode 100644 index b898a5073be..00000000000 --- a/third_party/patch/openssl/CVE-2022-4304.patch +++ /dev/null @@ -1,771 +0,0 @@ -diff --git a/crypto/bn/bn_blind.c b/crypto/bn/bn_blind.c -index 76fc7ebcff..6e9d239321 100644 ---- a/crypto/bn/bn_blind.c -+++ b/crypto/bn/bn_blind.c -@@ -13,20 +13,6 @@ - - #define BN_BLINDING_COUNTER 32 - --struct bn_blinding_st { -- BIGNUM *A; -- BIGNUM *Ai; -- BIGNUM *e; -- BIGNUM *mod; /* just a reference */ -- CRYPTO_THREAD_ID tid; -- int counter; -- unsigned long flags; -- BN_MONT_CTX *m_ctx; -- int (*bn_mod_exp) (BIGNUM *r, const BIGNUM *a, const BIGNUM *p, -- const BIGNUM *m, BN_CTX *ctx, BN_MONT_CTX *m_ctx); -- CRYPTO_RWLOCK *lock; --}; -- - BN_BLINDING *BN_BLINDING_new(const BIGNUM *A, const BIGNUM *Ai, BIGNUM *mod) - { - BN_BLINDING *ret = NULL; -diff --git a/crypto/bn/bn_err.c b/crypto/bn/bn_err.c -index dd87c152cf..3dd8d9a568 100644 ---- a/crypto/bn/bn_err.c -+++ b/crypto/bn/bn_err.c -@@ -73,6 +73,8 @@ static const ERR_STRING_DATA BN_str_functs[] = { - {ERR_PACK(ERR_LIB_BN, BN_F_BN_SET_WORDS, 0), "bn_set_words"}, - {ERR_PACK(ERR_LIB_BN, BN_F_BN_STACK_PUSH, 0), "BN_STACK_push"}, - {ERR_PACK(ERR_LIB_BN, BN_F_BN_USUB, 0), "BN_usub"}, -+ {ERR_PACK(ERR_LIB_BN, BN_F_OSSL_BN_RSA_DO_UNBLIND, 0), -+ "ossl_bn_rsa_do_unblind"}, - {0, NULL} - }; - -diff --git a/crypto/bn/bn_local.h b/crypto/bn/bn_local.h -index 62a969b134..4d8cb64675 100644 ---- a/crypto/bn/bn_local.h -+++ b/crypto/bn/bn_local.h -@@ -283,6 +283,20 @@ struct bn_gencb_st { - } cb; - }; - -+struct bn_blinding_st { -+ BIGNUM *A; -+ BIGNUM *Ai; -+ BIGNUM *e; -+ BIGNUM *mod; /* just a reference */ -+ CRYPTO_THREAD_ID tid; -+ int counter; -+ unsigned long flags; -+ BN_MONT_CTX *m_ctx; -+ int (*bn_mod_exp) (BIGNUM *r, const BIGNUM *a, const BIGNUM *p, -+ const BIGNUM *m, BN_CTX *ctx, BN_MONT_CTX *m_ctx); -+ CRYPTO_RWLOCK *lock; -+}; -+ - /*- - * BN_window_bits_for_exponent_size -- macro for sliding window mod_exp functions - * -diff --git a/crypto/bn/build.info b/crypto/bn/build.info -index b9ed5322fa..c9fe2fdada 100644 ---- a/crypto/bn/build.info -+++ b/crypto/bn/build.info -@@ -5,7 +5,8 @@ SOURCE[../../libcrypto]=\ - bn_kron.c bn_sqrt.c bn_gcd.c bn_prime.c bn_err.c bn_sqr.c \ - {- $target{bn_asm_src} -} \ - bn_recp.c bn_mont.c bn_mpi.c bn_exp2.c bn_gf2m.c bn_nist.c \ -- bn_depr.c bn_const.c bn_x931p.c bn_intern.c bn_dh.c bn_srp.c -+ bn_depr.c bn_const.c bn_x931p.c bn_intern.c bn_dh.c bn_srp.c \ -+ rsa_sup_mul.c - - INCLUDE[bn_exp.o]=.. - -diff --git a/crypto/bn/rsa_sup_mul.c b/crypto/bn/rsa_sup_mul.c -new file mode 100644 -index 0000000000..acafefd5fe ---- /dev/null -+++ b/crypto/bn/rsa_sup_mul.c -@@ -0,0 +1,614 @@ -+#include -+#include -+#include -+#include -+#include -+#include -+#include -+#include "internal/numbers.h" -+#include "internal/constant_time.h" -+#include "bn_local.h" -+ -+# if BN_BYTES == 8 -+typedef uint64_t limb_t; -+# if defined(__SIZEOF_INT128__) && __SIZEOF_INT128__ == 16 -+/* nonstandard; implemented by gcc on 64-bit platforms */ -+typedef __uint128_t limb2_t; -+# define HAVE_LIMB2_T -+# endif -+# define LIMB_BIT_SIZE 64 -+# define LIMB_BYTE_SIZE 8 -+# elif BN_BYTES == 4 -+typedef uint32_t limb_t; -+typedef uint64_t limb2_t; -+# define LIMB_BIT_SIZE 32 -+# define LIMB_BYTE_SIZE 4 -+# define HAVE_LIMB2_T -+# else -+# error "Not supported" -+# endif -+ -+/* -+ * For multiplication we're using schoolbook multiplication, -+ * so if we have two numbers, each with 6 "digits" (words) -+ * the multiplication is calculated as follows: -+ * A B C D E F -+ * x I J K L M N -+ * -------------- -+ * N*F -+ * N*E -+ * N*D -+ * N*C -+ * N*B -+ * N*A -+ * M*F -+ * M*E -+ * M*D -+ * M*C -+ * M*B -+ * M*A -+ * L*F -+ * L*E -+ * L*D -+ * L*C -+ * L*B -+ * L*A -+ * K*F -+ * K*E -+ * K*D -+ * K*C -+ * K*B -+ * K*A -+ * J*F -+ * J*E -+ * J*D -+ * J*C -+ * J*B -+ * J*A -+ * I*F -+ * I*E -+ * I*D -+ * I*C -+ * I*B -+ * + I*A -+ * ========================== -+ * N*B N*D N*F -+ * + N*A N*C N*E -+ * + M*B M*D M*F -+ * + M*A M*C M*E -+ * + L*B L*D L*F -+ * + L*A L*C L*E -+ * + K*B K*D K*F -+ * + K*A K*C K*E -+ * + J*B J*D J*F -+ * + J*A J*C J*E -+ * + I*B I*D I*F -+ * + I*A I*C I*E -+ * -+ * 1+1 1+3 1+5 -+ * 1+0 1+2 1+4 -+ * 0+1 0+3 0+5 -+ * 0+0 0+2 0+4 -+ * -+ * 0 1 2 3 4 5 6 -+ * which requires n^2 multiplications and 2n full length additions -+ * as we can keep every other result of limb multiplication in two separate -+ * limbs -+ */ -+ -+#if defined HAVE_LIMB2_T -+static ossl_inline void _mul_limb(limb_t *hi, limb_t *lo, limb_t a, limb_t b) -+{ -+ limb2_t t; -+ /* -+ * this is idiomatic code to tell compiler to use the native mul -+ * those three lines will actually compile to single instruction -+ */ -+ -+ t = (limb2_t)a * b; -+ *hi = t >> LIMB_BIT_SIZE; -+ *lo = (limb_t)t; -+} -+#elif (BN_BYTES == 8) && (defined _MSC_VER) -+/* https://learn.microsoft.com/en-us/cpp/intrinsics/umul128?view=msvc-170 */ -+#pragma intrinsic(_umul128) -+static ossl_inline void _mul_limb(limb_t *hi, limb_t *lo, limb_t a, limb_t b) -+{ -+ *lo = _umul128(a, b, hi); -+} -+#else -+/* -+ * if the compiler doesn't have either a 128bit data type nor a "return -+ * high 64 bits of multiplication" -+ */ -+static ossl_inline void _mul_limb(limb_t *hi, limb_t *lo, limb_t a, limb_t b) -+{ -+ limb_t a_low = (limb_t)(uint32_t)a; -+ limb_t a_hi = a >> 32; -+ limb_t b_low = (limb_t)(uint32_t)b; -+ limb_t b_hi = b >> 32; -+ -+ limb_t p0 = a_low * b_low; -+ limb_t p1 = a_low * b_hi; -+ limb_t p2 = a_hi * b_low; -+ limb_t p3 = a_hi * b_hi; -+ -+ uint32_t cy = (uint32_t)(((p0 >> 32) + (uint32_t)p1 + (uint32_t)p2) >> 32); -+ -+ *lo = p0 + (p1 << 32) + (p2 << 32); -+ *hi = p3 + (p1 >> 32) + (p2 >> 32) + cy; -+} -+#endif -+ -+/* add two limbs with carry in, return carry out */ -+static ossl_inline limb_t _add_limb(limb_t *ret, limb_t a, limb_t b, limb_t carry) -+{ -+ limb_t carry1, carry2, t; -+ /* -+ * `c = a + b; if (c < a)` is idiomatic code that makes compilers -+ * use add with carry on assembly level -+ */ -+ -+ *ret = a + carry; -+ if (*ret < a) -+ carry1 = 1; -+ else -+ carry1 = 0; -+ -+ t = *ret; -+ *ret = t + b; -+ if (*ret < t) -+ carry2 = 1; -+ else -+ carry2 = 0; -+ -+ return carry1 + carry2; -+} -+ -+/* -+ * add two numbers of the same size, return overflow -+ * -+ * add a to b, place result in ret; all arrays need to be n limbs long -+ * return overflow from addition (0 or 1) -+ */ -+static ossl_inline limb_t add(limb_t *ret, limb_t *a, limb_t *b, size_t n) -+{ -+ limb_t c = 0; -+ ossl_ssize_t i; -+ -+ for(i = n - 1; i > -1; i--) -+ c = _add_limb(&ret[i], a[i], b[i], c); -+ -+ return c; -+} -+ -+/* -+ * return number of limbs necessary for temporary values -+ * when multiplying numbers n limbs large -+ */ -+static ossl_inline size_t mul_limb_numb(size_t n) -+{ -+ return 2 * n * 2; -+} -+ -+/* -+ * multiply two numbers of the same size -+ * -+ * multiply a by b, place result in ret; a and b need to be n limbs long -+ * ret needs to be 2*n limbs long, tmp needs to be mul_limb_numb(n) limbs -+ * long -+ */ -+static void limb_mul(limb_t *ret, limb_t *a, limb_t *b, size_t n, limb_t *tmp) -+{ -+ limb_t *r_odd, *r_even; -+ size_t i, j, k; -+ -+ r_odd = tmp; -+ r_even = &tmp[2 * n]; -+ -+ memset(ret, 0, 2 * n * sizeof(limb_t)); -+ -+ for (i = 0; i < n; i++) { -+ for (k = 0; k < i + n + 1; k++) { -+ r_even[k] = 0; -+ r_odd[k] = 0; -+ } -+ for (j = 0; j < n; j++) { -+ /* -+ * place results from even and odd limbs in separate arrays so that -+ * we don't have to calculate overflow every time we get individual -+ * limb multiplication result -+ */ -+ if (j % 2 == 0) -+ _mul_limb(&r_even[i + j], &r_even[i + j + 1], a[i], b[j]); -+ else -+ _mul_limb(&r_odd[i + j], &r_odd[i + j + 1], a[i], b[j]); -+ } -+ /* -+ * skip the least significant limbs when adding multiples of -+ * more significant limbs (they're zero anyway) -+ */ -+ add(ret, ret, r_even, n + i + 1); -+ add(ret, ret, r_odd, n + i + 1); -+ } -+} -+ -+/* modifies the value in place by performing a right shift by one bit */ -+static ossl_inline void rshift1(limb_t *val, size_t n) -+{ -+ limb_t shift_in = 0, shift_out = 0; -+ size_t i; -+ -+ for (i = 0; i < n; i++) { -+ shift_out = val[i] & 1; -+ val[i] = shift_in << (LIMB_BIT_SIZE - 1) | (val[i] >> 1); -+ shift_in = shift_out; -+ } -+} -+ -+/* extend the LSB of flag to all bits of limb */ -+static ossl_inline limb_t mk_mask(limb_t flag) -+{ -+ flag |= flag << 1; -+ flag |= flag << 2; -+ flag |= flag << 4; -+ flag |= flag << 8; -+ flag |= flag << 16; -+#if (LIMB_BYTE_SIZE == 8) -+ flag |= flag << 32; -+#endif -+ return flag; -+} -+ -+/* -+ * copy from either a or b to ret based on flag -+ * when flag == 0, then copies from b -+ * when flag == 1, then copies from a -+ */ -+static ossl_inline void cselect(limb_t flag, limb_t *ret, limb_t *a, limb_t *b, size_t n) -+{ -+ /* -+ * would be more efficient with non volatile mask, but then gcc -+ * generates code with jumps -+ */ -+ volatile limb_t mask; -+ size_t i; -+ -+ mask = mk_mask(flag); -+ for (i = 0; i < n; i++) { -+#if (LIMB_BYTE_SIZE == 8) -+ ret[i] = constant_time_select_64(mask, a[i], b[i]); -+#else -+ ret[i] = constant_time_select_32(mask, a[i], b[i]); -+#endif -+ } -+} -+ -+static limb_t _sub_limb(limb_t *ret, limb_t a, limb_t b, limb_t borrow) -+{ -+ limb_t borrow1, borrow2, t; -+ /* -+ * while it doesn't look constant-time, this is idiomatic code -+ * to tell compilers to use the carry bit from subtraction -+ */ -+ -+ *ret = a - borrow; -+ if (*ret > a) -+ borrow1 = 1; -+ else -+ borrow1 = 0; -+ -+ t = *ret; -+ *ret = t - b; -+ if (*ret > t) -+ borrow2 = 1; -+ else -+ borrow2 = 0; -+ -+ return borrow1 + borrow2; -+} -+ -+/* -+ * place the result of a - b into ret, return the borrow bit. -+ * All arrays need to be n limbs long -+ */ -+static limb_t sub(limb_t *ret, limb_t *a, limb_t *b, size_t n) -+{ -+ limb_t borrow = 0; -+ ossl_ssize_t i; -+ -+ for (i = n - 1; i > -1; i--) -+ borrow = _sub_limb(&ret[i], a[i], b[i], borrow); -+ -+ return borrow; -+} -+ -+/* return the number of limbs necessary to allocate for the mod() tmp operand */ -+static ossl_inline size_t mod_limb_numb(size_t anum, size_t modnum) -+{ -+ return (anum + modnum) * 3; -+} -+ -+/* -+ * calculate a % mod, place the result in ret -+ * size of a is defined by anum, size of ret and mod is modnum, -+ * size of tmp is returned by mod_limb_numb() -+ */ -+static void mod(limb_t *ret, limb_t *a, size_t anum, limb_t *mod, -+ size_t modnum, limb_t *tmp) -+{ -+ limb_t *atmp, *modtmp, *rettmp; -+ limb_t res; -+ size_t i; -+ -+ memset(tmp, 0, mod_limb_numb(anum, modnum) * LIMB_BYTE_SIZE); -+ -+ atmp = tmp; -+ modtmp = &tmp[anum + modnum]; -+ rettmp = &tmp[(anum + modnum) * 2]; -+ -+ for (i = modnum; i 0; i--, rp--) { -+ v = _mul_add_limb(rp, mod, modnum, rp[modnum - 1] * ni0, tmp2); -+ v = v + carry + rp[-1]; -+ carry |= (v != rp[-1]); -+ carry &= (v <= rp[-1]); -+ rp[-1] = v; -+ } -+ -+ /* perform the final reduction by mod... */ -+ carry -= sub(ret, rp, mod, modnum); -+ -+ /* ...conditionally */ -+ cselect(carry, ret, rp, ret, modnum); -+} -+ -+/* allocated buffer should be freed afterwards */ -+static void BN_to_limb(const BIGNUM *bn, limb_t *buf, size_t limbs) -+{ -+ int i; -+ int real_limbs = (BN_num_bytes(bn) + LIMB_BYTE_SIZE - 1) / LIMB_BYTE_SIZE; -+ limb_t *ptr = buf + (limbs - real_limbs); -+ -+ for (i = 0; i < real_limbs; i++) -+ ptr[i] = bn->d[real_limbs - i - 1]; -+} -+ -+#if LIMB_BYTE_SIZE == 8 -+static ossl_inline uint64_t be64(uint64_t host) -+{ -+ const union { -+ long one; -+ char little; -+ } is_endian = { 1 }; -+ -+ if (is_endian.little) { -+ uint64_t big = 0; -+ -+ big |= (host & 0xff00000000000000) >> 56; -+ big |= (host & 0x00ff000000000000) >> 40; -+ big |= (host & 0x0000ff0000000000) >> 24; -+ big |= (host & 0x000000ff00000000) >> 8; -+ big |= (host & 0x00000000ff000000) << 8; -+ big |= (host & 0x0000000000ff0000) << 24; -+ big |= (host & 0x000000000000ff00) << 40; -+ big |= (host & 0x00000000000000ff) << 56; -+ return big; -+ } else { -+ return host; -+ } -+} -+ -+#else -+/* Not all platforms have htobe32(). */ -+static ossl_inline uint32_t be32(uint32_t host) -+{ -+ const union { -+ long one; -+ char little; -+ } is_endian = { 1 }; -+ -+ if (is_endian.little) { -+ uint32_t big = 0; -+ -+ big |= (host & 0xff000000) >> 24; -+ big |= (host & 0x00ff0000) >> 8; -+ big |= (host & 0x0000ff00) << 8; -+ big |= (host & 0x000000ff) << 24; -+ return big; -+ } else { -+ return host; -+ } -+} -+#endif -+ -+/* -+ * We assume that intermediate, possible_arg2, blinding, and ctx are used -+ * similar to BN_BLINDING_invert_ex() arguments. -+ * to_mod is RSA modulus. -+ * buf and num is the serialization buffer and its length. -+ * -+ * Here we use classic/Montgomery multiplication and modulo. After the calculation finished -+ * we serialize the new structure instead of BIGNUMs taking endianness into account. -+ */ -+int ossl_bn_rsa_do_unblind(const BIGNUM *intermediate, -+ const BN_BLINDING *blinding, -+ const BIGNUM *possible_arg2, -+ const BIGNUM *to_mod, BN_CTX *ctx, -+ unsigned char *buf, int num) -+{ -+ limb_t *l_im = NULL, *l_mul = NULL, *l_mod = NULL; -+ limb_t *l_ret = NULL, *l_tmp = NULL, l_buf; -+ size_t l_im_count = 0, l_mul_count = 0, l_size = 0, l_mod_count = 0; -+ size_t l_tmp_count = 0; -+ int ret = 0; -+ size_t i; -+ unsigned char *tmp; -+ const BIGNUM *arg1 = intermediate; -+ const BIGNUM *arg2 = (possible_arg2 == NULL) ? blinding->Ai : possible_arg2; -+ -+ l_im_count = (BN_num_bytes(arg1) + LIMB_BYTE_SIZE - 1) / LIMB_BYTE_SIZE; -+ l_mul_count = (BN_num_bytes(arg2) + LIMB_BYTE_SIZE - 1) / LIMB_BYTE_SIZE; -+ l_mod_count = (BN_num_bytes(to_mod) + LIMB_BYTE_SIZE - 1) / LIMB_BYTE_SIZE; -+ -+ l_size = l_im_count > l_mul_count ? l_im_count : l_mul_count; -+ l_im = OPENSSL_zalloc(l_size * LIMB_BYTE_SIZE); -+ l_mul = OPENSSL_zalloc(l_size * LIMB_BYTE_SIZE); -+ l_mod = OPENSSL_zalloc(l_mod_count * LIMB_BYTE_SIZE); -+ -+ if ((l_im == NULL) || (l_mul == NULL) || (l_mod == NULL)) -+ goto err; -+ -+ BN_to_limb(arg1, l_im, l_size); -+ BN_to_limb(arg2, l_mul, l_size); -+ BN_to_limb(to_mod, l_mod, l_mod_count); -+ -+ l_ret = OPENSSL_malloc(2 * l_size * LIMB_BYTE_SIZE); -+ -+ if (blinding->m_ctx != NULL) { -+ l_tmp_count = mul_limb_numb(l_size) > mod_montgomery_limb_numb(l_mod_count) ? -+ mul_limb_numb(l_size) : mod_montgomery_limb_numb(l_mod_count); -+ l_tmp = OPENSSL_malloc(l_tmp_count * LIMB_BYTE_SIZE); -+ } else { -+ l_tmp_count = mul_limb_numb(l_size) > mod_limb_numb(2 * l_size, l_mod_count) ? -+ mul_limb_numb(l_size) : mod_limb_numb(2 * l_size, l_mod_count); -+ l_tmp = OPENSSL_malloc(l_tmp_count * LIMB_BYTE_SIZE); -+ } -+ -+ if ((l_ret == NULL) || (l_tmp == NULL)) -+ goto err; -+ -+ if (blinding->m_ctx != NULL) { -+ limb_mul(l_ret, l_im, l_mul, l_size, l_tmp); -+ mod_montgomery(l_ret, l_ret, 2 * l_size, l_mod, l_mod_count, -+ blinding->m_ctx->n0[0], l_tmp); -+ } else { -+ limb_mul(l_ret, l_im, l_mul, l_size, l_tmp); -+ mod(l_ret, l_ret, 2 * l_size, l_mod, l_mod_count, l_tmp); -+ } -+ -+ /* modulus size in bytes can be equal to num but after limbs conversion it becomes bigger */ -+ if (num < BN_num_bytes(to_mod)) { -+ BNerr(BN_F_OSSL_BN_RSA_DO_UNBLIND, ERR_R_PASSED_INVALID_ARGUMENT); -+ goto err; -+ } -+ -+ memset(buf, 0, num); -+ tmp = buf + num - BN_num_bytes(to_mod); -+ for (i = 0; i < l_mod_count; i++) { -+#if LIMB_BYTE_SIZE == 8 -+ l_buf = be64(l_ret[i]); -+#else -+ l_buf = be32(l_ret[i]); -+#endif -+ if (i == 0) { -+ int delta = LIMB_BYTE_SIZE - ((l_mod_count * LIMB_BYTE_SIZE) - num); -+ -+ memcpy(tmp, ((char *)&l_buf) + LIMB_BYTE_SIZE - delta, delta); -+ tmp += delta; -+ } else { -+ memcpy(tmp, &l_buf, LIMB_BYTE_SIZE); -+ tmp += LIMB_BYTE_SIZE; -+ } -+ } -+ ret = num; -+ -+ err: -+ OPENSSL_free(l_im); -+ OPENSSL_free(l_mul); -+ OPENSSL_free(l_mod); -+ OPENSSL_free(l_tmp); -+ OPENSSL_free(l_ret); -+ -+ return ret; -+} -diff --git a/crypto/err/openssl.txt b/crypto/err/openssl.txt -index 9f91a4a811..ba3a46d5b9 100644 ---- a/crypto/err/openssl.txt -+++ b/crypto/err/openssl.txt -@@ -1,4 +1,4 @@ --# Copyright 1999-2021 The OpenSSL Project Authors. All Rights Reserved. -+# Copyright 1999-2023 The OpenSSL Project Authors. All Rights Reserved. - # - # Licensed under the OpenSSL license (the "License"). You may not use - # this file except in compliance with the License. You can obtain a copy -@@ -232,6 +232,7 @@ BN_F_BN_RSHIFT:146:BN_rshift - BN_F_BN_SET_WORDS:144:bn_set_words - BN_F_BN_STACK_PUSH:148:BN_STACK_push - BN_F_BN_USUB:115:BN_usub -+BN_F_OSSL_BN_RSA_DO_UNBLIND:151:ossl_bn_rsa_do_unblind - BUF_F_BUF_MEM_GROW:100:BUF_MEM_grow - BUF_F_BUF_MEM_GROW_CLEAN:105:BUF_MEM_grow_clean - BUF_F_BUF_MEM_NEW:101:BUF_MEM_new -diff --git a/crypto/rsa/rsa_ossl.c b/crypto/rsa/rsa_ossl.c -index b52a66f6a6..6c3c0cf78d 100644 ---- a/crypto/rsa/rsa_ossl.c -+++ b/crypto/rsa/rsa_ossl.c -@@ -465,11 +465,20 @@ static int rsa_ossl_private_decrypt(int flen, const unsigned char *from, - BN_free(d); - } - -- if (blinding) -- if (!rsa_blinding_invert(blinding, ret, unblind, ctx)) -+ if (blinding) { -+ /* -+ * ossl_bn_rsa_do_unblind() combines blinding inversion and -+ * 0-padded BN BE serialization -+ */ -+ j = ossl_bn_rsa_do_unblind(ret, blinding, unblind, rsa->n, ctx, -+ buf, num); -+ if (j == 0) - goto err; -- -- j = BN_bn2binpad(ret, buf, num); -+ } else { -+ j = BN_bn2binpad(ret, buf, num); -+ if (j < 0) -+ goto err; -+ } - - switch (padding) { - case RSA_PKCS1_PADDING: -diff --git a/include/crypto/bn.h b/include/crypto/bn.h -index 60afda1dad..b5f36fb25a 100644 ---- a/include/crypto/bn.h -+++ b/include/crypto/bn.h -@@ -86,5 +86,10 @@ int bn_lshift_fixed_top(BIGNUM *r, const BIGNUM *a, int n); - int bn_rshift_fixed_top(BIGNUM *r, const BIGNUM *a, int n); - int bn_div_fixed_top(BIGNUM *dv, BIGNUM *rem, const BIGNUM *m, - const BIGNUM *d, BN_CTX *ctx); -+int ossl_bn_rsa_do_unblind(const BIGNUM *intermediate, -+ const BN_BLINDING *blinding, -+ const BIGNUM *possible_arg2, -+ const BIGNUM *to_mod, BN_CTX *ctx, -+ unsigned char *buf, int num); - - #endif -diff --git a/include/openssl/bnerr.h b/include/openssl/bnerr.h -index 9f3c7cfaab..a0752cea52 100644 ---- a/include/openssl/bnerr.h -+++ b/include/openssl/bnerr.h -@@ -72,6 +72,7 @@ int ERR_load_BN_strings(void); - # define BN_F_BN_SET_WORDS 144 - # define BN_F_BN_STACK_PUSH 148 - # define BN_F_BN_USUB 115 -+# define BN_F_OSSL_BN_RSA_DO_UNBLIND 151 - - /* - * BN reason codes. \ No newline at end of file diff --git a/third_party/patch/openssl/CVE-2022-4450.patch b/third_party/patch/openssl/CVE-2022-4450.patch deleted file mode 100644 index 3364ea75a61..00000000000 --- a/third_party/patch/openssl/CVE-2022-4450.patch +++ /dev/null @@ -1,14 +0,0 @@ -diff --git a/crypto/pem/pem_lib.c b/crypto/pem/pem_lib.c -index d416d939ea..328c30cdbb 100644 ---- a/crypto/pem/pem_lib.c -+++ b/crypto/pem/pem_lib.c -@@ -957,7 +957,9 @@ int PEM_read_bio_ex(BIO *bp, char **name_out, char **header, - *data = pem_malloc(len, flags); - if (*header == NULL || *data == NULL) { - pem_free(*header, flags, 0); -+ *header = NULL; - pem_free(*data, flags, 0); -+ *data = NULL; - goto end; - } - BIO_read(headerB, *header, headerlen); \ No newline at end of file diff --git a/third_party/patch/openssl/CVE-2023-0215.patch b/third_party/patch/openssl/CVE-2023-0215.patch deleted file mode 100644 index 26b5cde0a9e..00000000000 --- a/third_party/patch/openssl/CVE-2023-0215.patch +++ /dev/null @@ -1,144 +0,0 @@ -diff --git a/crypto/asn1/bio_ndef.c b/crypto/asn1/bio_ndef.c -index 760e4846a4..f8d4b1b9aa 100644 ---- a/crypto/asn1/bio_ndef.c -+++ b/crypto/asn1/bio_ndef.c -@@ -49,12 +49,19 @@ static int ndef_suffix(BIO *b, unsigned char **pbuf, int *plen, void *parg); - static int ndef_suffix_free(BIO *b, unsigned char **pbuf, int *plen, - void *parg); - -+/* -+ * On success, the returned BIO owns the input BIO as part of its BIO chain. -+ * On failure, NULL is returned and the input BIO is owned by the caller. -+ * -+ * Unfortunately cannot constify this due to CMS_stream() and PKCS7_stream() -+ */ - BIO *BIO_new_NDEF(BIO *out, ASN1_VALUE *val, const ASN1_ITEM *it) - { - NDEF_SUPPORT *ndef_aux = NULL; - BIO *asn_bio = NULL; - const ASN1_AUX *aux = it->funcs; - ASN1_STREAM_ARG sarg; -+ BIO *pop_bio = NULL; - - if (!aux || !aux->asn1_cb) { - ASN1err(ASN1_F_BIO_NEW_NDEF, ASN1_R_STREAMING_NOT_SUPPORTED); -@@ -69,21 +76,39 @@ BIO *BIO_new_NDEF(BIO *out, ASN1_VALUE *val, const ASN1_ITEM *it) - out = BIO_push(asn_bio, out); - if (out == NULL) - goto err; -+ pop_bio = asn_bio; - -- BIO_asn1_set_prefix(asn_bio, ndef_prefix, ndef_prefix_free); -- BIO_asn1_set_suffix(asn_bio, ndef_suffix, ndef_suffix_free); -+ if (BIO_asn1_set_prefix(asn_bio, ndef_prefix, ndef_prefix_free) <= 0 -+ || BIO_asn1_set_suffix(asn_bio, ndef_suffix, ndef_suffix_free) <= 0 -+ || BIO_ctrl(asn_bio, BIO_C_SET_EX_ARG, 0, ndef_aux) <= 0) -+ goto err; - - /* -- * Now let callback prepends any digest, cipher etc BIOs ASN1 structure -- * needs. -+ * Now let the callback prepend any digest, cipher, etc., that the BIO's -+ * ASN1 structure needs. - */ - - sarg.out = out; - sarg.ndef_bio = NULL; - sarg.boundary = NULL; - -- if (aux->asn1_cb(ASN1_OP_STREAM_PRE, &val, it, &sarg) <= 0) -+ /* -+ * The asn1_cb(), must not have mutated asn_bio on error, leaving it in the -+ * middle of some partially built, but not returned BIO chain. -+ */ -+ if (aux->asn1_cb(ASN1_OP_STREAM_PRE, &val, it, &sarg) <= 0) { -+ /* -+ * ndef_aux is now owned by asn_bio so we must not free it in the err -+ * clean up block -+ */ -+ ndef_aux = NULL; - goto err; -+ } -+ -+ /* -+ * We must not fail now because the callback has prepended additional -+ * BIOs to the chain -+ */ - - ndef_aux->val = val; - ndef_aux->it = it; -@@ -91,11 +116,11 @@ BIO *BIO_new_NDEF(BIO *out, ASN1_VALUE *val, const ASN1_ITEM *it) - ndef_aux->boundary = sarg.boundary; - ndef_aux->out = out; - -- BIO_ctrl(asn_bio, BIO_C_SET_EX_ARG, 0, ndef_aux); -- - return sarg.ndef_bio; - - err: -+ /* BIO_pop() is NULL safe */ -+ (void)BIO_pop(pop_bio); - BIO_free(asn_bio); - OPENSSL_free(ndef_aux); - return NULL; -diff --git a/test/recipes/80-test_cms.t b/test/recipes/80-test_cms.t -index 5dc6a3aebe..ec11bfc253 100644 ---- a/test/recipes/80-test_cms.t -+++ b/test/recipes/80-test_cms.t -@@ -13,7 +13,7 @@ use warnings; - use POSIX; - use File::Spec::Functions qw/catfile/; - use File::Compare qw/compare_text/; --use OpenSSL::Test qw/:DEFAULT srctop_dir srctop_file/; -+use OpenSSL::Test qw/:DEFAULT srctop_dir srctop_file with/; - use OpenSSL::Test::Utils; - - setup("test_cms"); -@@ -27,7 +27,7 @@ my $smcont = srctop_file("test", "smcont.txt"); - my ($no_des, $no_dh, $no_dsa, $no_ec, $no_ec2m, $no_rc2, $no_zlib) - = disabled qw/des dh dsa ec ec2m rc2 zlib/; - --plan tests => 6; -+plan tests => 7; - - my @smime_pkcs7_tests = ( - -@@ -584,3 +584,14 @@ sub check_availability { - - return ""; - } -+ -+# Check that we get the expected failure return code -+with({ exit_checker => sub { return shift == 6; } }, -+ sub { -+ ok(run(app(['openssl', 'cms', '-encrypt', -+ '-in', srctop_file("test", "smcont.txt"), -+ '-stream', '-recip', -+ srctop_file("test/smime-certs", "badrsa.pem"), -+ ])), -+ "Check failure during BIO setup with -stream is handled correctly"); -+ }); -diff --git a/test/smime-certs/badrsa.pem b/test/smime-certs/badrsa.pem -new file mode 100644 -index 0000000000..f824fc2267 ---- /dev/null -+++ b/test/smime-certs/badrsa.pem -@@ -0,0 +1,18 @@ -+-----BEGIN CERTIFICATE----- -+MIIDbTCCAlWgAwIBAgIToTV4Z0iuK08vZP20oTh//hC8BDANBgkqhkiG9w0BAQ0FADAtMSswKQYD -+VfcDEyJTYW1wbGUgTEFNUFMgQ2VydGlmaWNhdGUgQXV0aG9yaXR5MCAXDTE5MTEyMDA2NTQxOFoY -+DzIwNTIwOTI3MDY1NDE4WjAZMRcwFQYDVQQDEw5BbGljZSBMb3ZlbGFjZTCCASIwDQYJKoZIhvcN -+AQEBBQADggEPADCCAQoCggEBALT0iehYOBY+TZp/T5K2KNI05Hwr+E3wP6XTvyi6WWyTgBK9LCOw -+I2juwdRrjFBmXkk7pWpjXwsA3A5GOtz0FpfgyC7OxsVcF7q4WHWZWleYXFKlQHJD73nQwXP968+A -+/3rBX7PhO0DBbZnfitOLPgPEwjTtdg0VQQ6Wz+CRQ/YbHPKaw7aRphZO63dKvIKp4cQVtkWQHi6s -+yTjGsgkLcLNau5LZDQUdsGV+SAo3nBdWCRYV+I65x8Kf4hCxqqmjV3d/2NKRu0BXnDe/N+iDz3X0 -+zEoj0fqXgq4SWcC0nsG1lyyXt1TL270I6ATKRGJWiQVCCpDtc0NT6vdJ45bCSxgCAwEAAaOBlzCB -+lDAMBgNVHRMBAf8EAjAAMB4GA1UdEQQXMBWBE2FsaWNlQHNtaW1lLmV4YW1wbGUwEwYDVR0lBAww -+CgYIKwYBBQUHAwQwDwYDVR0PAQH/BAUDAwfAADAdBgNVHQ4EFgQUu/bMsi0dBhIcl64papAQ0yBm -+ZnMwHwYDVR0jBBgwFoAUeF8OWnjYa+RUcD2z3ez38fL6wEcwDQYJKoZIhvcNAQENBQADggEBABbW -+eonR6TMTckehDKNOabwaCIcekahAIL6l9tTzUX5ew6ufiAPlC6I/zQlmUaU0iSyFDG1NW14kNbFt -+5CAokyLhMtE4ASHBIHbiOp/ZSbUBTVYJZB61ot7w1/ol5QECSs08b8zrxIncf+t2DHGuVEy/Qq1d -+rBz8d4ay8zpqAE1tUyL5Da6ZiKUfWwZQXSI/JlbjQFzYQqTRDnzHWrg1xPeMTO1P2/cplFaseTiv -+yk4cYwOp/W9UAWymOZXF8WcJYCIUXkdcG/nEZxr057KlScrJmFXOoh7Y+8ON4iWYYcAfiNgpUFo/ -+j8BAwrKKaFvdlZS9k1Ypb2+UQY75mKJE9Bg= -+-----END CERTIFICATE----- \ No newline at end of file diff --git a/third_party/patch/openssl/CVE-2023-0286.patch b/third_party/patch/openssl/CVE-2023-0286.patch deleted file mode 100644 index c3915edf128..00000000000 --- a/third_party/patch/openssl/CVE-2023-0286.patch +++ /dev/null @@ -1,44 +0,0 @@ -diff --git a/crypto/x509v3/v3_genn.c b/crypto/x509v3/v3_genn.c -index 87a5eff47c..e54ddc55c9 100644 ---- a/crypto/x509v3/v3_genn.c -+++ b/crypto/x509v3/v3_genn.c -@@ -98,7 +98,7 @@ int GENERAL_NAME_cmp(GENERAL_NAME *a, GENERAL_NAME *b) - return -1; - switch (a->type) { - case GEN_X400: -- result = ASN1_TYPE_cmp(a->d.x400Address, b->d.x400Address); -+ result = ASN1_STRING_cmp(a->d.x400Address, b->d.x400Address); - break; - - case GEN_EDIPARTY: -diff --git a/include/openssl/x509v3.h b/include/openssl/x509v3.h -index 90fa3592ce..e61c0f29d4 100644 ---- a/include/openssl/x509v3.h -+++ b/include/openssl/x509v3.h -@@ -136,7 +136,7 @@ typedef struct GENERAL_NAME_st { - OTHERNAME *otherName; /* otherName */ - ASN1_IA5STRING *rfc822Name; - ASN1_IA5STRING *dNSName; -- ASN1_TYPE *x400Address; -+ ASN1_STRING *x400Address; - X509_NAME *directoryName; - EDIPARTYNAME *ediPartyName; - ASN1_IA5STRING *uniformResourceIdentifier; -diff --git a/test/v3nametest.c b/test/v3nametest.c -index d1852190b8..37819da8fd 100644 ---- a/test/v3nametest.c -+++ b/test/v3nametest.c -@@ -646,6 +646,14 @@ static struct gennamedata { - 0xb7, 0x09, 0x02, 0x02 - }, - 15 -+ }, { -+ /* -+ * Regression test for CVE-2023-0286. -+ */ -+ { -+ 0xa3, 0x00 -+ }, -+ 2 - } - }; diff --git a/third_party/patch/openssl/CVE-2023-0464.patch b/third_party/patch/openssl/CVE-2023-0464.patch deleted file mode 100644 index f87f8f58884..00000000000 --- a/third_party/patch/openssl/CVE-2023-0464.patch +++ /dev/null @@ -1,222 +0,0 @@ -From 879f7080d7e141f415c79eaa3a8ac4a3dad0348b Mon Sep 17 00:00:00 2001 -From: Pauli -Date: Wed, 8 Mar 2023 15:28:20 +1100 -Subject: [PATCH] x509: excessive resource use verifying policy constraints - -A security vulnerability has been identified in all supported versions -of OpenSSL related to the verification of X.509 certificate chains -that include policy constraints. Attackers may be able to exploit this -vulnerability by creating a malicious certificate chain that triggers -exponential use of computational resources, leading to a denial-of-service -(DoS) attack on affected systems. - -Fixes CVE-2023-0464 - -Reviewed-by: Tomas Mraz -Reviewed-by: Shane Lontis -(Merged from https://github.com/openssl/openssl/pull/20569) ---- - crypto/x509v3/pcy_local.h | 8 +++++++- - crypto/x509v3/pcy_node.c | 12 +++++++++--- - crypto/x509v3/pcy_tree.c | 37 +++++++++++++++++++++++++++---------- - 3 files changed, 43 insertions(+), 14 deletions(-) - -diff --git a/crypto/x509v3/pcy_local.h b/crypto/x509v3/pcy_local.h -index 5daf78de45..344aa06765 100644 ---- a/crypto/x509v3/pcy_local.h -+++ b/crypto/x509v3/pcy_local.h -@@ -111,6 +111,11 @@ struct X509_POLICY_LEVEL_st { - }; - - struct X509_POLICY_TREE_st { -+ /* The number of nodes in the tree */ -+ size_t node_count; -+ /* The maximum number of nodes in the tree */ -+ size_t node_maximum; -+ - /* This is the tree 'level' data */ - X509_POLICY_LEVEL *levels; - int nlevel; -@@ -159,7 +164,8 @@ X509_POLICY_NODE *tree_find_sk(STACK_OF(X509_POLICY_NODE) *sk, - X509_POLICY_NODE *level_add_node(X509_POLICY_LEVEL *level, - X509_POLICY_DATA *data, - X509_POLICY_NODE *parent, -- X509_POLICY_TREE *tree); -+ X509_POLICY_TREE *tree, -+ int extra_data); - void policy_node_free(X509_POLICY_NODE *node); - int policy_node_match(const X509_POLICY_LEVEL *lvl, - const X509_POLICY_NODE *node, const ASN1_OBJECT *oid); -diff --git a/crypto/x509v3/pcy_node.c b/crypto/x509v3/pcy_node.c -index e2d7b15322..d574fb9d66 100644 ---- a/crypto/x509v3/pcy_node.c -+++ b/crypto/x509v3/pcy_node.c -@@ -59,10 +59,15 @@ X509_POLICY_NODE *level_find_node(const X509_POLICY_LEVEL *level, - X509_POLICY_NODE *level_add_node(X509_POLICY_LEVEL *level, - X509_POLICY_DATA *data, - X509_POLICY_NODE *parent, -- X509_POLICY_TREE *tree) -+ X509_POLICY_TREE *tree, -+ int extra_data) - { - X509_POLICY_NODE *node; - -+ /* Verify that the tree isn't too large. This mitigates CVE-2023-0464 */ -+ if (tree->node_maximum > 0 && tree->node_count >= tree->node_maximum) -+ return NULL; -+ - node = OPENSSL_zalloc(sizeof(*node)); - if (node == NULL) { - X509V3err(X509V3_F_LEVEL_ADD_NODE, ERR_R_MALLOC_FAILURE); -@@ -70,7 +75,7 @@ X509_POLICY_NODE *level_add_node(X509_POLICY_LEVEL *level, - } - node->data = data; - node->parent = parent; -- if (level) { -+ if (level != NULL) { - if (OBJ_obj2nid(data->valid_policy) == NID_any_policy) { - if (level->anyPolicy) - goto node_error; -@@ -90,7 +95,7 @@ X509_POLICY_NODE *level_add_node(X509_POLICY_LEVEL *level, - } - } - -- if (tree) { -+ if (extra_data) { - if (tree->extra_data == NULL) - tree->extra_data = sk_X509_POLICY_DATA_new_null(); - if (tree->extra_data == NULL){ -@@ -103,6 +108,7 @@ X509_POLICY_NODE *level_add_node(X509_POLICY_LEVEL *level, - } - } - -+ tree->node_count++; - if (parent) - parent->nchild++; - -diff --git a/crypto/x509v3/pcy_tree.c b/crypto/x509v3/pcy_tree.c -index 6e8322cbc5..6c7fd35405 100644 ---- a/crypto/x509v3/pcy_tree.c -+++ b/crypto/x509v3/pcy_tree.c -@@ -13,6 +13,18 @@ - - #include "pcy_local.h" - -+/* -+ * If the maximum number of nodes in the policy tree isn't defined, set it to -+ * a generous default of 1000 nodes. -+ * -+ * Defining this to be zero means unlimited policy tree growth which opens the -+ * door on CVE-2023-0464. -+ */ -+ -+#ifndef OPENSSL_POLICY_TREE_NODES_MAX -+# define OPENSSL_POLICY_TREE_NODES_MAX 1000 -+#endif -+ - /* - * Enable this to print out the complete policy tree at various point during - * evaluation. -@@ -168,6 +180,9 @@ static int tree_init(X509_POLICY_TREE **ptree, STACK_OF(X509) *certs, - return X509_PCY_TREE_INTERNAL; - } - -+ /* Limit the growth of the tree to mitigate CVE-2023-0464 */ -+ tree->node_maximum = OPENSSL_POLICY_TREE_NODES_MAX; -+ - /* - * http://tools.ietf.org/html/rfc5280#section-6.1.2, figure 3. - * -@@ -184,7 +199,7 @@ static int tree_init(X509_POLICY_TREE **ptree, STACK_OF(X509) *certs, - level = tree->levels; - if ((data = policy_data_new(NULL, OBJ_nid2obj(NID_any_policy), 0)) == NULL) - goto bad_tree; -- if (level_add_node(level, data, NULL, tree) == NULL) { -+ if (level_add_node(level, data, NULL, tree, 1) == NULL) { - policy_data_free(data); - goto bad_tree; - } -@@ -243,7 +258,8 @@ static int tree_init(X509_POLICY_TREE **ptree, STACK_OF(X509) *certs, - * Return value: 1 on success, 0 otherwise - */ - static int tree_link_matching_nodes(X509_POLICY_LEVEL *curr, -- X509_POLICY_DATA *data) -+ X509_POLICY_DATA *data, -+ X509_POLICY_TREE *tree) - { - X509_POLICY_LEVEL *last = curr - 1; - int i, matched = 0; -@@ -253,13 +269,13 @@ static int tree_link_matching_nodes(X509_POLICY_LEVEL *curr, - X509_POLICY_NODE *node = sk_X509_POLICY_NODE_value(last->nodes, i); - - if (policy_node_match(last, node, data->valid_policy)) { -- if (level_add_node(curr, data, node, NULL) == NULL) -+ if (level_add_node(curr, data, node, tree, 0) == NULL) - return 0; - matched = 1; - } - } - if (!matched && last->anyPolicy) { -- if (level_add_node(curr, data, last->anyPolicy, NULL) == NULL) -+ if (level_add_node(curr, data, last->anyPolicy, tree, 0) == NULL) - return 0; - } - return 1; -@@ -272,7 +288,8 @@ static int tree_link_matching_nodes(X509_POLICY_LEVEL *curr, - * Return value: 1 on success, 0 otherwise. - */ - static int tree_link_nodes(X509_POLICY_LEVEL *curr, -- const X509_POLICY_CACHE *cache) -+ const X509_POLICY_CACHE *cache, -+ X509_POLICY_TREE *tree) - { - int i; - -@@ -280,7 +297,7 @@ static int tree_link_nodes(X509_POLICY_LEVEL *curr, - X509_POLICY_DATA *data = sk_X509_POLICY_DATA_value(cache->data, i); - - /* Look for matching nodes in previous level */ -- if (!tree_link_matching_nodes(curr, data)) -+ if (!tree_link_matching_nodes(curr, data, tree)) - return 0; - } - return 1; -@@ -311,7 +328,7 @@ static int tree_add_unmatched(X509_POLICY_LEVEL *curr, - /* Curr may not have anyPolicy */ - data->qualifier_set = cache->anyPolicy->qualifier_set; - data->flags |= POLICY_DATA_FLAG_SHARED_QUALIFIERS; -- if (level_add_node(curr, data, node, tree) == NULL) { -+ if (level_add_node(curr, data, node, tree, 1) == NULL) { - policy_data_free(data); - return 0; - } -@@ -373,7 +390,7 @@ static int tree_link_any(X509_POLICY_LEVEL *curr, - } - /* Finally add link to anyPolicy */ - if (last->anyPolicy && -- level_add_node(curr, cache->anyPolicy, last->anyPolicy, NULL) == NULL) -+ level_add_node(curr, cache->anyPolicy, last->anyPolicy, tree, 0) == NULL) - return 0; - return 1; - } -@@ -555,7 +572,7 @@ static int tree_calculate_user_set(X509_POLICY_TREE *tree, - extra->qualifier_set = anyPolicy->data->qualifier_set; - extra->flags = POLICY_DATA_FLAG_SHARED_QUALIFIERS - | POLICY_DATA_FLAG_EXTRA_NODE; -- node = level_add_node(NULL, extra, anyPolicy->parent, tree); -+ node = level_add_node(NULL, extra, anyPolicy->parent, tree, 1); - } - if (!tree->user_policies) { - tree->user_policies = sk_X509_POLICY_NODE_new_null(); -@@ -582,7 +599,7 @@ static int tree_evaluate(X509_POLICY_TREE *tree) - - for (i = 1; i < tree->nlevel; i++, curr++) { - cache = policy_cache_set(curr->cert); -- if (!tree_link_nodes(curr, cache)) -+ if (!tree_link_nodes(curr, cache, tree)) - return X509_PCY_TREE_INTERNAL; - - if (!(curr->flags & X509_V_FLAG_INHIBIT_ANY) --- -2.34.1 - diff --git a/third_party/patch/openssl/CVE-2023-0465.patch b/third_party/patch/openssl/CVE-2023-0465.patch deleted file mode 100644 index 441e1d7be11..00000000000 --- a/third_party/patch/openssl/CVE-2023-0465.patch +++ /dev/null @@ -1,54 +0,0 @@ -From b013765abfa80036dc779dd0e50602c57bb3bf95 Mon Sep 17 00:00:00 2001 -From: Matt Caswell -Date: Tue, 7 Mar 2023 16:52:55 +0000 -Subject: [PATCH] Ensure that EXFLAG_INVALID_POLICY is checked even in leaf - certs - -Even though we check the leaf cert to confirm it is valid, we -later ignored the invalid flag and did not notice that the leaf -cert was bad. - -Fixes: CVE-2023-0465 - -Reviewed-by: Hugo Landau -Reviewed-by: Tomas Mraz -(Merged from https://github.com/openssl/openssl/pull/20588) ---- - crypto/x509/x509_vfy.c | 11 +++++++++-- - 1 file changed, 9 insertions(+), 2 deletions(-) - -diff --git a/crypto/x509/x509_vfy.c b/crypto/x509/x509_vfy.c -index 925fbb5412..1dfe4f9f31 100644 ---- a/crypto/x509/x509_vfy.c -+++ b/crypto/x509/x509_vfy.c -@@ -1649,18 +1649,25 @@ static int check_policy(X509_STORE_CTX *ctx) - } - /* Invalid or inconsistent extensions */ - if (ret == X509_PCY_TREE_INVALID) { -- int i; -+ int i, cbcalled = 0; - - /* Locate certificates with bad extensions and notify callback. */ -- for (i = 1; i < sk_X509_num(ctx->chain); i++) { -+ for (i = 0; i < sk_X509_num(ctx->chain); i++) { - X509 *x = sk_X509_value(ctx->chain, i); - - if (!(x->ex_flags & EXFLAG_INVALID_POLICY)) - continue; -+ cbcalled = 1; - if (!verify_cb_cert(ctx, x, i, - X509_V_ERR_INVALID_POLICY_EXTENSION)) - return 0; - } -+ if (!cbcalled) { -+ /* Should not be able to get here */ -+ X509err(X509_F_CHECK_POLICY, ERR_R_INTERNAL_ERROR); -+ return 0; -+ } -+ /* The callback ignored the error so we return success */ - return 1; - } - if (ret == X509_PCY_TREE_FAILURE) { --- -2.34.1 - diff --git a/third_party/patch/openssl/CVE-2023-0466.patch b/third_party/patch/openssl/CVE-2023-0466.patch deleted file mode 100644 index 81de1dcfa34..00000000000 --- a/third_party/patch/openssl/CVE-2023-0466.patch +++ /dev/null @@ -1,27 +0,0 @@ -diff --git a/doc/man3/X509_VERIFY_PARAM_set_flags.pod b/doc/man3/X509_VERIFY_PARAM_set_flags.pod -index f6f304bf7b..aa292f9336 100644 ---- a/doc/man3/X509_VERIFY_PARAM_set_flags.pod -+++ b/doc/man3/X509_VERIFY_PARAM_set_flags.pod -@@ -92,8 +92,9 @@ B. - X509_VERIFY_PARAM_set_time() sets the verification time in B to - B. Normally the current time is used. - --X509_VERIFY_PARAM_add0_policy() enables policy checking (it is disabled --by default) and adds B to the acceptable policy set. -+X509_VERIFY_PARAM_add0_policy() adds B to the acceptable policy set. -+Contrary to preexisting documentation of this function it does not enable -+policy checking. - - X509_VERIFY_PARAM_set1_policies() enables policy checking (it is disabled - by default) and sets the acceptable policy set to B. Any existing -@@ -377,6 +378,10 @@ and has no effect. - - The X509_VERIFY_PARAM_get_hostflags() function was added in OpenSSL 1.1.0i. - -+The function X509_VERIFY_PARAM_add0_policy() was historically documented as -+enabling policy checking however the implementation has never done this. -+The documentation was changed to align with the implementation. -+ - =head1 COPYRIGHT - - Copyright 2009-2020 The OpenSSL Project Authors. All Rights Reserved. diff --git a/third_party/patch/openssl/CVE-2023-2650.patch b/third_party/patch/openssl/CVE-2023-2650.patch deleted file mode 100644 index aef8f9ea0fc..00000000000 --- a/third_party/patch/openssl/CVE-2023-2650.patch +++ /dev/null @@ -1,63 +0,0 @@ -From 9e209944b35cf82368071f160a744b6178f9b098 Mon Sep 17 00:00:00 2001 -From: Richard Levitte -Date: Fri, 12 May 2023 10:00:13 +0200 -Subject: [PATCH] Restrict the size of OBJECT IDENTIFIERs that OBJ_obj2txt will - translate - -OBJ_obj2txt() would translate any size OBJECT IDENTIFIER to canonical -numeric text form. For gigantic sub-identifiers, this would take a very -long time, the time complexity being O(n^2) where n is the size of that -sub-identifier. - -To mitigate this, a restriction on the size that OBJ_obj2txt() will -translate to canonical numeric text form is added, based on RFC 2578 -(STD 58), which says this: - -> 3.5. OBJECT IDENTIFIER values -> -> An OBJECT IDENTIFIER value is an ordered list of non-negative numbers. -> For the SMIv2, each number in the list is referred to as a sub-identifier, -> there are at most 128 sub-identifiers in a value, and each sub-identifier -> has a maximum value of 2^32-1 (4294967295 decimal). - -Fixes otc/security#96 -Fixes CVE-2023-2650 - -Reviewed-by: Matt Caswell -Reviewed-by: Tomas Mraz ---- - crypto/objects/obj_dat.c | 19 +++++++++++++++++++ - -diff --git a/crypto/objects/obj_dat.c b/crypto/objects/obj_dat.c -index 7e8de727f3..d699915b20 100644 ---- a/crypto/objects/obj_dat.c -+++ b/crypto/objects/obj_dat.c -@@ -428,6 +428,25 @@ int OBJ_obj2txt(char *buf, int buf_len, const ASN1_OBJECT *a, int no_name) - first = 1; - bl = NULL; - -+ /* -+ * RFC 2578 (STD 58) says this about OBJECT IDENTIFIERs: -+ * -+ * > 3.5. OBJECT IDENTIFIER values -+ * > -+ * > An OBJECT IDENTIFIER value is an ordered list of non-negative -+ * > numbers. For the SMIv2, each number in the list is referred to as a -+ * > sub-identifier, there are at most 128 sub-identifiers in a value, -+ * > and each sub-identifier has a maximum value of 2^32-1 (4294967295 -+ * > decimal). -+ * -+ * So a legitimate OID according to this RFC is at most (32 * 128 / 7), -+ * i.e. 586 bytes long. -+ * -+ * Ref: https://datatracker.ietf.org/doc/html/rfc2578#section-3.5 -+ */ -+ if (len > 586) -+ goto err; -+ - while (len > 0) { - l = 0; - use_bn = 0; --- -2.34.1 - diff --git a/third_party/patch/openssl/CVE-2023-3446.patch b/third_party/patch/openssl/CVE-2023-3446.patch deleted file mode 100644 index 6804e674528..00000000000 --- a/third_party/patch/openssl/CVE-2023-3446.patch +++ /dev/null @@ -1,124 +0,0 @@ -From 8780a896543a654e757db1b9396383f9d8095528 Mon Sep 17 00:00:00 2001 -From: Matt Caswell -Date: Thu, 6 Jul 2023 16:36:35 +0100 -Subject: [PATCH] Fix DH_check() excessive time with over sized modulus - -The DH_check() function checks numerous aspects of the key or parameters -that have been supplied. Some of those checks use the supplied modulus -value even if it is excessively large. - -There is already a maximum DH modulus size (10,000 bits) over which -OpenSSL will not generate or derive keys. DH_check() will however still -perform various tests for validity on such a large modulus. We introduce a -new maximum (32,768) over which DH_check() will just fail. - -An application that calls DH_check() and supplies a key or parameters -obtained from an untrusted source could be vulnerable to a Denial of -Service attack. - -The function DH_check() is itself called by a number of other OpenSSL -functions. An application calling any of those other functions may -similarly be affected. The other functions affected by this are -DH_check_ex() and EVP_PKEY_param_check(). - -CVE-2023-3446 - -Reviewed-by: Paul Dale -Reviewed-by: Tom Cosgrove -Reviewed-by: Bernd Edlinger -Reviewed-by: Tomas Mraz -(Merged from https://github.com/openssl/openssl/pull/21452) ---- - crypto/dh/dh_check.c | 6 ++++++ - crypto/dh/dh_err.c | 3 ++- - crypto/err/openssl.txt | 1 + - include/openssl/dh.h | 3 +++ - include/openssl/dherr.h | 3 ++- - 5 files changed, 15 insertions(+), 3 deletions(-) - -diff --git a/crypto/dh/dh_check.c b/crypto/dh/dh_check.c -index 4ac169e75c..e5f9dd5030 100644 ---- a/crypto/dh/dh_check.c -+++ b/crypto/dh/dh_check.c -@@ -101,6 +101,12 @@ int DH_check(const DH *dh, int *ret) - BN_CTX *ctx = NULL; - BIGNUM *t1 = NULL, *t2 = NULL; - -+ /* Don't do any checks at all with an excessively large modulus */ -+ if (BN_num_bits(dh->p) > OPENSSL_DH_CHECK_MAX_MODULUS_BITS) { -+ DHerr(DH_F_DH_CHECK, DH_R_MODULUS_TOO_LARGE); -+ return 0; -+ } -+ - if (!DH_check_params(dh, ret)) - return 0; - -diff --git a/crypto/dh/dh_err.c b/crypto/dh/dh_err.c -index 7285587b4a..92800d3fcc 100644 ---- a/crypto/dh/dh_err.c -+++ b/crypto/dh/dh_err.c -@@ -1,6 +1,6 @@ - /* - * Generated by util/mkerr.pl DO NOT EDIT -- * Copyright 1995-2018 The OpenSSL Project Authors. All Rights Reserved. -+ * Copyright 1995-2023 The OpenSSL Project Authors. All Rights Reserved. - * - * Licensed under the OpenSSL license (the "License"). You may not use - * this file except in compliance with the License. You can obtain a copy -@@ -18,6 +18,7 @@ static const ERR_STRING_DATA DH_str_functs[] = { - {ERR_PACK(ERR_LIB_DH, DH_F_DHPARAMS_PRINT_FP, 0), "DHparams_print_fp"}, - {ERR_PACK(ERR_LIB_DH, DH_F_DH_BUILTIN_GENPARAMS, 0), - "dh_builtin_genparams"}, -+ {ERR_PACK(ERR_LIB_DH, DH_F_DH_CHECK, 0), "DH_check"}, - {ERR_PACK(ERR_LIB_DH, DH_F_DH_CHECK_EX, 0), "DH_check_ex"}, - {ERR_PACK(ERR_LIB_DH, DH_F_DH_CHECK_PARAMS_EX, 0), "DH_check_params_ex"}, - {ERR_PACK(ERR_LIB_DH, DH_F_DH_CHECK_PUB_KEY_EX, 0), "DH_check_pub_key_ex"}, -diff --git a/crypto/err/openssl.txt b/crypto/err/openssl.txt -index 9f91a4a811..c0a3cd720b 100644 ---- a/crypto/err/openssl.txt -+++ b/crypto/err/openssl.txt -@@ -401,6 +401,7 @@ CT_F_SCT_SET_VERSION:104:SCT_set_version - DH_F_COMPUTE_KEY:102:compute_key - DH_F_DHPARAMS_PRINT_FP:101:DHparams_print_fp - DH_F_DH_BUILTIN_GENPARAMS:106:dh_builtin_genparams -+DH_F_DH_CHECK:126:DH_check - DH_F_DH_CHECK_EX:121:DH_check_ex - DH_F_DH_CHECK_PARAMS_EX:122:DH_check_params_ex - DH_F_DH_CHECK_PUB_KEY_EX:123:DH_check_pub_key_ex -diff --git a/include/openssl/dh.h b/include/openssl/dh.h -index 3527540cdd..892e31559d 100644 ---- a/include/openssl/dh.h -+++ b/include/openssl/dh.h -@@ -29,6 +29,9 @@ extern "C" { - # ifndef OPENSSL_DH_MAX_MODULUS_BITS - # define OPENSSL_DH_MAX_MODULUS_BITS 10000 - # endif -+# ifndef OPENSSL_DH_CHECK_MAX_MODULUS_BITS -+# define OPENSSL_DH_CHECK_MAX_MODULUS_BITS 32768 -+# endif - - # define OPENSSL_DH_FIPS_MIN_MODULUS_BITS 1024 - -diff --git a/include/openssl/dherr.h b/include/openssl/dherr.h -index 916b3bed0b..528c819856 100644 ---- a/include/openssl/dherr.h -+++ b/include/openssl/dherr.h -@@ -1,6 +1,6 @@ - /* - * Generated by util/mkerr.pl DO NOT EDIT -- * Copyright 1995-2019 The OpenSSL Project Authors. All Rights Reserved. -+ * Copyright 1995-2023 The OpenSSL Project Authors. All Rights Reserved. - * - * Licensed under the OpenSSL license (the "License"). You may not use - * this file except in compliance with the License. You can obtain a copy -@@ -30,6 +30,7 @@ int ERR_load_DH_strings(void); - # define DH_F_COMPUTE_KEY 102 - # define DH_F_DHPARAMS_PRINT_FP 101 - # define DH_F_DH_BUILTIN_GENPARAMS 106 -+# define DH_F_DH_CHECK 126 - # define DH_F_DH_CHECK_EX 121 - # define DH_F_DH_CHECK_PARAMS_EX 122 - # define DH_F_DH_CHECK_PUB_KEY_EX 123 --- -2.34.1 - diff --git a/third_party/patch/openssl/CVE-2023-3817.patch b/third_party/patch/openssl/CVE-2023-3817.patch deleted file mode 100644 index 9d6ab499dad..00000000000 --- a/third_party/patch/openssl/CVE-2023-3817.patch +++ /dev/null @@ -1,57 +0,0 @@ -From 91ddeba0f2269b017dc06c46c993a788974b1aa5 Mon Sep 17 00:00:00 2001 -From: Tomas Mraz -Date: Fri, 21 Jul 2023 11:39:41 +0200 -Subject: [PATCH] DH_check(): Do not try checking q properties if it is - obviously invalid - -If |q| >= |p| then the q value is obviously wrong as q -is supposed to be a prime divisor of p-1. - -We check if p is overly large so this added test implies that -q is not large either when performing subsequent tests using that -q value. - -Otherwise if it is too large these additional checks of the q value -such as the primality test can then trigger DoS by doing overly long -computations. - -Fixes CVE-2023-3817 - -Reviewed-by: Paul Dale -Reviewed-by: Matt Caswell -(Merged from https://github.com/openssl/openssl/pull/21551) ---- - crypto/dh/dh_check.c | 11 +++++++++-- - 1 file changed, 9 insertions(+), 2 deletions(-) - -diff --git a/crypto/dh/dh_check.c b/crypto/dh/dh_check.c -index 2001d2e7cb..9ae96991eb 100644 ---- a/crypto/dh/dh_check.c -+++ b/crypto/dh/dh_check.c -@@ -97,7 +97,7 @@ int DH_check_ex(const DH *dh) - - int DH_check(const DH *dh, int *ret) - { -- int ok = 0, r; -+ int ok = 0, r, q_good = 0; - BN_CTX *ctx = NULL; - BIGNUM *t1 = NULL, *t2 = NULL; - -@@ -120,7 +120,14 @@ int DH_check(const DH *dh, int *ret) - if (t2 == NULL) - goto err; - -- if (dh->q) { -+ if (dh->q != NULL) { -+ if (BN_ucmp(dh->p, dh->q) > 0) -+ q_good = 1; -+ else -+ *ret |= DH_CHECK_INVALID_Q_VALUE; -+ } -+ -+ if (q_good) { - if (BN_cmp(dh->g, BN_value_one()) <= 0) - *ret |= DH_NOT_SUITABLE_GENERATOR; - else if (BN_cmp(dh->g, dh->p) >= 0) --- -2.34.1 \ No newline at end of file diff --git a/third_party/patch/openssl/CVE-2023-4807.patch b/third_party/patch/openssl/CVE-2023-4807.patch deleted file mode 100644 index 8e5791cde36..00000000000 --- a/third_party/patch/openssl/CVE-2023-4807.patch +++ /dev/null @@ -1,47 +0,0 @@ -From a632d534c73eeb3e3db8c7540d811194ef7c79ff Mon Sep 17 00:00:00 2001 -From: Bernd Edlinger -Date: Tue, 22 Aug 2023 16:07:30 +0200 -Subject: [PATCH] Avoid clobbering non-volatile XMM registers - -This affects some Poly1305 assembler functions -which are only used for certain CPU types. - -Remove those functions for Windows targets, -as a simple interim solution. - -Fixes #21522 - -Reviewed-by: Tomas Mraz -Reviewed-by: Paul Dale -(Merged from https://github.com/openssl/openssl/pull/21808) - -(cherry picked from commit 7b8e27bc2e02238986d89ef0ece067ec1b48e165) ---- - crypto/poly1305/asm/poly1305-x86_64.pl | 4 ++-- - 1 file changed, 2 insertions(+), 2 deletions(-) - -diff --git a/crypto/poly1305/asm/poly1305-x86_64.pl b/crypto/poly1305/asm/poly1305-x86_64.pl -index 5f834d8faf..801455c639 100755 ---- a/crypto/poly1305/asm/poly1305-x86_64.pl -+++ b/crypto/poly1305/asm/poly1305-x86_64.pl -@@ -193,7 +193,7 @@ $code.=<<___ if ($avx>1); - bt \$`5+32`,%r9 # AVX2? - cmovc %rax,%r10 - ___ --$code.=<<___ if ($avx>3); -+$code.=<<___ if ($avx>3 && !$win64); - mov \$`(1<<31|1<<21|1<<16)`,%rax - shr \$32,%r9 - and %rax,%r9 -@@ -2722,7 +2722,7 @@ $code.=<<___; - .cfi_endproc - .size poly1305_blocks_avx512,.-poly1305_blocks_avx512 - ___ --if ($avx>3) { -+if ($avx>3 && !$win64) { - ######################################################################## - # VPMADD52 version using 2^44 radix. - # --- -2.34.1 - diff --git a/third_party/patch/openssl/CVE-2023-5678.patch b/third_party/patch/openssl/CVE-2023-5678.patch deleted file mode 100644 index 02f8761ed53..00000000000 --- a/third_party/patch/openssl/CVE-2023-5678.patch +++ /dev/null @@ -1,113 +0,0 @@ -diff --git a/crypto/dh/dh_check.c b/crypto/dh/dh_check.c -index 4ac169e..9cb4482 100644 ---- a/crypto/dh/dh_check.c -+++ b/crypto/dh/dh_check.c -@@ -184,6 +184,20 @@ int DH_check_pub_key(const DH *dh, const BIGNUM *pub_key, int *ret) - BN_CTX *ctx = NULL; - - *ret = 0; -+ -+ /* Don't do any checks at all with an excessively large modulus */ -+ if (BN_num_bits(dh->p) > OPENSSL_DH_CHECK_MAX_MODULUS_BITS) { -+ DHerr(DH_F_DH_CHECK_EX, DH_R_MODULUS_TOO_LARGE); -+ *ret = DH_MODULUS_TOO_LARGE | DH_CHECK_PUBKEY_INVALID; -+ return 0; -+ } -+ -+ if (dh->q != NULL && BN_ucmp(dh->p, dh->q) < 0) { -+ *ret |= DH_CHECK_INVALID_Q_VALUE | DH_CHECK_PUBKEY_INVALID; -+ return 1; -+ } -+ -+ - ctx = BN_CTX_new(); - if (ctx == NULL) - goto err; -diff --git a/crypto/dh/dh_err.c b/crypto/dh/dh_err.c -index 7285587..85f1e51 100644 ---- a/crypto/dh/dh_err.c -+++ b/crypto/dh/dh_err.c -@@ -81,6 +81,7 @@ static const ERR_STRING_DATA DH_str_reasons[] = { - {ERR_PACK(ERR_LIB_DH, 0, DH_R_PARAMETER_ENCODING_ERROR), - "parameter encoding error"}, - {ERR_PACK(ERR_LIB_DH, 0, DH_R_PEER_KEY_ERROR), "peer key error"}, -+ {ERR_PACK(ERR_LIB_DH, 0, DH_R_Q_TOO_LARGE), "q too large"}, - {ERR_PACK(ERR_LIB_DH, 0, DH_R_SHARED_INFO_ERROR), "shared info error"}, - {ERR_PACK(ERR_LIB_DH, 0, DH_R_UNABLE_TO_CHECK_GENERATOR), - "unable to check generator"}, -diff --git a/crypto/dh/dh_key.c b/crypto/dh/dh_key.c -index 117f2fa..b4c789d 100644 ---- a/crypto/dh/dh_key.c -+++ b/crypto/dh/dh_key.c -@@ -109,6 +109,12 @@ static int generate_key(DH *dh) - BN_MONT_CTX *mont = NULL; - BIGNUM *pub_key = NULL, *priv_key = NULL; - -+ if (dh->q != NULL -+ && BN_num_bits(dh->q) > OPENSSL_DH_MAX_MODULUS_BITS) { -+ DHerr(DH_F_GENERATE_KEY, DH_R_Q_TOO_LARGE); -+ return 0; -+ } -+ - if (BN_num_bits(dh->p) > OPENSSL_DH_MAX_MODULUS_BITS) { - DHerr(DH_F_GENERATE_KEY, DH_R_MODULUS_TOO_LARGE); - return 0; -@@ -202,6 +208,12 @@ static int compute_key(unsigned char *key, const BIGNUM *pub_key, DH *dh) - int ret = -1; - int check_result; - -+ if (dh->q != NULL -+ && BN_num_bits(dh->q) > OPENSSL_DH_MAX_MODULUS_BITS) { -+ DHerr(DH_F_COMPUTE_KEY, DH_R_Q_TOO_LARGE); -+ goto err; -+ } -+ - if (BN_num_bits(dh->p) > OPENSSL_DH_MAX_MODULUS_BITS) { - DHerr(DH_F_COMPUTE_KEY, DH_R_MODULUS_TOO_LARGE); - goto err; -diff --git a/crypto/err/openssl.txt b/crypto/err/openssl.txt -index 7e17763..405c116 100644 ---- a/crypto/err/openssl.txt -+++ b/crypto/err/openssl.txt -@@ -2100,6 +2100,7 @@ DH_R_NO_PARAMETERS_SET:107:no parameters set - DH_R_NO_PRIVATE_VALUE:100:no private value - DH_R_PARAMETER_ENCODING_ERROR:105:parameter encoding error - DH_R_PEER_KEY_ERROR:111:peer key error -+DH_R_Q_TOO_LARGE:130:q too large - DH_R_SHARED_INFO_ERROR:113:shared info error - DH_R_UNABLE_TO_CHECK_GENERATOR:121:unable to check generator - DSA_R_BAD_Q_VALUE:102:bad q value -diff --git a/include/openssl/dh.h b/include/openssl/dh.h -index 3527540..a50ad96 100644 ---- a/include/openssl/dh.h -+++ b/include/openssl/dh.h -@@ -68,14 +68,15 @@ DECLARE_ASN1_ITEM(DHparams) - /* #define DH_GENERATOR_3 3 */ - # define DH_GENERATOR_5 5 - --/* DH_check error codes */ -+/* DH_check error codes, some of them shared with DH_check_pub_key */ - # define DH_CHECK_P_NOT_PRIME 0x01 - # define DH_CHECK_P_NOT_SAFE_PRIME 0x02 - # define DH_UNABLE_TO_CHECK_GENERATOR 0x04 - # define DH_NOT_SUITABLE_GENERATOR 0x08 - # define DH_CHECK_Q_NOT_PRIME 0x10 --# define DH_CHECK_INVALID_Q_VALUE 0x20 -+# define DH_CHECK_INVALID_Q_VALUE 0x20 /* +DH_check_pub_key */ - # define DH_CHECK_INVALID_J_VALUE 0x40 -+# define DH_MODULUS_TOO_LARGE 0x100 - - /* DH_check_pub_key error codes */ - # define DH_CHECK_PUBKEY_TOO_SMALL 0x01 -diff --git a/include/openssl/dherr.h b/include/openssl/dherr.h -index 916b3be..88c3a6c 100644 ---- a/include/openssl/dherr.h -+++ b/include/openssl/dherr.h -@@ -81,6 +81,7 @@ int ERR_load_DH_strings(void); - # define DH_R_NO_PRIVATE_VALUE 100 - # define DH_R_PARAMETER_ENCODING_ERROR 105 - # define DH_R_PEER_KEY_ERROR 111 -+# define DH_R_Q_TOO_LARGE 130 - # define DH_R_SHARED_INFO_ERROR 113 - # define DH_R_UNABLE_TO_CHECK_GENERATOR 121 - diff --git a/third_party/patch/openssl/CVE-2024-0727.patch b/third_party/patch/openssl/CVE-2024-0727.patch deleted file mode 100644 index c17bbfa4920..00000000000 --- a/third_party/patch/openssl/CVE-2024-0727.patch +++ /dev/null @@ -1,109 +0,0 @@ -From 09015a582baa980dc04f635504b16fe95dc3790b Mon Sep 17 00:00:00 2001 -From: l00511027 -Date: Fri, 26 Jan 2024 18:45:28 +0800 -Subject: [PATCH 1/2] fix CVE-2024-0727 - -Add NULL checks where ContentInfo data can be NULL ---- - crypto/pkcs12/p12_add.c | 16 ++++++++++++++ - crypto/pkcs12/p12_mutl.c | 5 +++++ - crypto/pkcs12/p12_npas.c | 5 +++-- - crypto/pkcs7/pk7_mime.c | 8 +++++-- - 4 files changed, 53 insertions(+), 6 deletions(-) - -diff --git a/crypto/pkcs12/p12_add.c b/crypto/pkcs12/p12_add.c -index af184c86af..9b40e5384e 100644 ---- a/crypto/pkcs12/p12_add.c -+++ b/crypto/pkcs12/p12_add.c -@@ -76,6 +76,12 @@ STACK_OF(PKCS12_SAFEBAG) *PKCS12_unpack_p7data(PKCS7 *p7) - PKCS12_R_CONTENT_TYPE_NOT_DATA); - return NULL; - } -+ -+ if (p7->d.data == NULL) { -+ PKCS12err(PKCS12_F_PKCS12_UNPACK_P7DATA, PKCS12_R_DECODE_ERROR); -+ return NULL; -+ } -+ - return ASN1_item_unpack(p7->d.data, ASN1_ITEM_rptr(PKCS12_SAFEBAGS)); - } - -@@ -132,6 +138,11 @@ STACK_OF(PKCS12_SAFEBAG) *PKCS12_unpack_p7encdata(PKCS7 *p7, const char *pass, - { - if (!PKCS7_type_is_encrypted(p7)) - return NULL; -+ -+ if (p7->d.encrypted == NULL) { -+ return NULL; -+ } -+ - return PKCS12_item_decrypt_d2i(p7->d.encrypted->enc_data->algorithm, - ASN1_ITEM_rptr(PKCS12_SAFEBAGS), - pass, passlen, -@@ -159,6 +170,11 @@ STACK_OF(PKCS7) *PKCS12_unpack_authsafes(const PKCS12 *p12) - PKCS12_R_CONTENT_TYPE_NOT_DATA); - return NULL; - } -+ if (p12->authsafes->d.data == NULL) { -+ PKCS12err(PKCS12_F_PKCS12_UNPACK_AUTHSAFES, PKCS12_R_DECODE_ERROR); -+ return NULL; -+ } -+ - return ASN1_item_unpack(p12->authsafes->d.data, - ASN1_ITEM_rptr(PKCS12_AUTHSAFES)); - } -diff --git a/crypto/pkcs12/p12_mutl.c b/crypto/pkcs12/p12_mutl.c -index 3658003fe5..766c9c1e9d 100644 ---- a/crypto/pkcs12/p12_mutl.c -+++ b/crypto/pkcs12/p12_mutl.c -@@ -93,6 +93,11 @@ static int pkcs12_gen_mac(PKCS12 *p12, const char *pass, int passlen, - return 0; - } - -+ if (p12->authsafes->d.data == NULL) { -+ PKCS12err(PKCS12_F_PKCS12_GEN_MAC, PKCS12_R_DECODE_ERROR); -+ return 0; -+ } -+ - salt = p12->mac->salt->data; - saltlen = p12->mac->salt->length; - if (!p12->mac->iter) -diff --git a/crypto/pkcs12/p12_npas.c b/crypto/pkcs12/p12_npas.c -index 0334289a89..130337638d 100644 ---- a/crypto/pkcs12/p12_npas.c -+++ b/crypto/pkcs12/p12_npas.c -@@ -78,8 +78,9 @@ static int newpass_p12(PKCS12 *p12, const char *oldpass, const char *newpass) - bags = PKCS12_unpack_p7data(p7); - } else if (bagnid == NID_pkcs7_encrypted) { - bags = PKCS12_unpack_p7encdata(p7, oldpass, -1); -- if (!alg_get(p7->d.encrypted->enc_data->algorithm, -- &pbe_nid, &pbe_iter, &pbe_saltlen)) -+ if (p7->d.encrypted == NULL -+ || !alg_get(p7->d.encrypted->enc_data->algorithm, -+ &pbe_nid, &pbe_iter, &pbe_saltlen)) - goto err; - } else { - continue; -diff --git a/crypto/pkcs7/pk7_mime.c b/crypto/pkcs7/pk7_mime.c -index 19e6868148..b457108c94 100644 ---- a/crypto/pkcs7/pk7_mime.c -+++ b/crypto/pkcs7/pk7_mime.c -@@ -30,10 +30,14 @@ int SMIME_write_PKCS7(BIO *bio, PKCS7 *p7, BIO *data, int flags) - { - STACK_OF(X509_ALGOR) *mdalgs; - int ctype_nid = OBJ_obj2nid(p7->type); -- if (ctype_nid == NID_pkcs7_signed) -+ -+ if (ctype_nid == NID_pkcs7_signed) { -+ if (p7->d.sign == NULL) -+ return 0; - mdalgs = p7->d.sign->md_algs; -- else -+ } else { - mdalgs = NULL; -+ } - - flags ^= SMIME_OLDMIME; - --- -2.17.1 diff --git a/third_party/patch/openssl/CVE-2024-2511.patch b/third_party/patch/openssl/CVE-2024-2511.patch deleted file mode 100644 index 8be177e5ae0..00000000000 --- a/third_party/patch/openssl/CVE-2024-2511.patch +++ /dev/null @@ -1,487 +0,0 @@ -From fc43b2b1abae58c1b261962299d2bbeee770810a Mon Sep 17 00:00:00 2001 -From: jxlang910 -Date: Thu, 11 Apr 2024 17:24:44 +0800 -Subject: [PATCH] fix CVE-2024-2511 - ---- - include/openssl/sslerr.h | 4 +- - ssl/ssl_err.c | 5 +- - ssl/ssl_lib.c | 5 +- - ssl/ssl_sess.c | 36 ++++- - ssl/statem/statem_srvr.c | 5 +- - test/sslapitest.c | 300 +++++++++++++++++++++++++++++++++++++++ - 6 files changed, 339 insertions(+), 16 deletions(-) - -diff --git a/include/openssl/sslerr.h b/include/openssl/sslerr.h -index aa5f56a482..3e99ffc27f 100644 ---- a/include/openssl/sslerr.h -+++ b/include/openssl/sslerr.h -@@ -1,6 +1,6 @@ - /* - * Generated by util/mkerr.pl DO NOT EDIT -- * Copyright 1995-2020 The OpenSSL Project Authors. All Rights Reserved. -+ * Copyright 1995-2024 The OpenSSL Project Authors. All Rights Reserved. - * - * Licensed under the OpenSSL license (the "License"). You may not use - * this file except in compliance with the License. You can obtain a copy -@@ -224,7 +224,7 @@ int ERR_load_SSL_strings(void); - # define SSL_F_SSL_RENEGOTIATE_ABBREVIATED 546 - # define SSL_F_SSL_SCAN_CLIENTHELLO_TLSEXT 320 - # define SSL_F_SSL_SCAN_SERVERHELLO_TLSEXT 321 --# define SSL_F_SSL_SESSION_DUP 348 -+# define SSL_F_SSL_SESSION_DUP_INTERN 668 - # define SSL_F_SSL_SESSION_NEW 189 - # define SSL_F_SSL_SESSION_PRINT_FP 190 - # define SSL_F_SSL_SESSION_SET1_ID 423 -diff --git a/ssl/ssl_err.c b/ssl/ssl_err.c -index 5a7c42a88c..c4144bb8b4 100644 ---- a/ssl/ssl_err.c -+++ b/ssl/ssl_err.c -@@ -1,6 +1,6 @@ - /* - * Generated by util/mkerr.pl DO NOT EDIT -- * Copyright 1995-2019 The OpenSSL Project Authors. All Rights Reserved. -+ * Copyright 1995-2024 The OpenSSL Project Authors. All Rights Reserved. - * - * Licensed under the OpenSSL license (the "License"). You may not use - * this file except in compliance with the License. You can obtain a copy -@@ -325,7 +325,8 @@ static const ERR_STRING_DATA SSL_str_functs[] = { - "SSL_renegotiate_abbreviated"}, - {ERR_PACK(ERR_LIB_SSL, SSL_F_SSL_SCAN_CLIENTHELLO_TLSEXT, 0), ""}, - {ERR_PACK(ERR_LIB_SSL, SSL_F_SSL_SCAN_SERVERHELLO_TLSEXT, 0), ""}, -- {ERR_PACK(ERR_LIB_SSL, SSL_F_SSL_SESSION_DUP, 0), "ssl_session_dup"}, -+ {ERR_PACK(ERR_LIB_SSL, SSL_F_SSL_SESSION_DUP_INTERN, 0), -+ "ssl_session_dup_intern"}, - {ERR_PACK(ERR_LIB_SSL, SSL_F_SSL_SESSION_NEW, 0), "SSL_SESSION_new"}, - {ERR_PACK(ERR_LIB_SSL, SSL_F_SSL_SESSION_PRINT_FP, 0), - "SSL_SESSION_print_fp"}, -diff --git a/ssl/ssl_lib.c b/ssl/ssl_lib.c -index 618549a2ca..2a44960fac 100644 ---- a/ssl/ssl_lib.c -+++ b/ssl/ssl_lib.c -@@ -3541,9 +3541,10 @@ void ssl_update_cache(SSL *s, int mode) - - /* - * If the session_id_length is 0, we are not supposed to cache it, and it -- * would be rather hard to do anyway :-) -+ * would be rather hard to do anyway :-). Also if the session has already -+ * been marked as not_resumable we should not cache it for later reuse. - */ -- if (s->session->session_id_length == 0) -+ if (s->session->session_id_length == 0 || s->session->not_resumable) - return; - - /* -diff --git a/ssl/ssl_sess.c b/ssl/ssl_sess.c -index 1b4c85b60c..5cc816b0fc 100644 ---- a/ssl/ssl_sess.c -+++ b/ssl/ssl_sess.c -@@ -94,16 +94,11 @@ SSL_SESSION *SSL_SESSION_new(void) - return ss; - } - --SSL_SESSION *SSL_SESSION_dup(SSL_SESSION *src) --{ -- return ssl_session_dup(src, 1); --} -- - /* - * Create a new SSL_SESSION and duplicate the contents of |src| into it. If - * ticket == 0 then no ticket information is duplicated, otherwise it is. - */ --SSL_SESSION *ssl_session_dup(SSL_SESSION *src, int ticket) -+static SSL_SESSION *ssl_session_dup_intern(SSL_SESSION *src, int ticket) - { - SSL_SESSION *dest; - -@@ -221,11 +216,32 @@ SSL_SESSION *ssl_session_dup(SSL_SESSION *src, int ticket) - - return dest; - err: -- SSLerr(SSL_F_SSL_SESSION_DUP, ERR_R_MALLOC_FAILURE); -+ SSLerr(SSL_F_SSL_SESSION_DUP_INTERN, ERR_R_MALLOC_FAILURE); - SSL_SESSION_free(dest); - return NULL; - } - -+SSL_SESSION *SSL_SESSION_dup(SSL_SESSION *src) -+{ -+ return ssl_session_dup_intern(src, 1); -+} -+ -+/* -+ * Used internally when duplicating a session which might be already shared. -+ * We will have resumed the original session. Subsequently we might have marked -+ * it as non-resumable (e.g. in another thread) - but this copy should be ok to -+ * resume from. -+ */ -+SSL_SESSION *ssl_session_dup(SSL_SESSION *src, int ticket) -+{ -+ SSL_SESSION *sess = ssl_session_dup_intern(src, ticket); -+ -+ if (sess != NULL) -+ sess->not_resumable = 0; -+ -+ return sess; -+} -+ - const unsigned char *SSL_SESSION_get_id(const SSL_SESSION *s, unsigned int *len) - { - if (len) -@@ -455,6 +471,12 @@ SSL_SESSION *lookup_sess_in_cache(SSL *s, const unsigned char *sess_id, - ret = s->session_ctx->get_session_cb(s, sess_id, sess_id_len, ©); - - if (ret != NULL) { -+ if (ret->not_resumable) { -+ /* If its not resumable then ignore this session */ -+ if (!copy) -+ SSL_SESSION_free(ret); -+ return NULL; -+ } - tsan_counter(&s->session_ctx->stats.sess_cb_hit); - - /* -diff --git a/ssl/statem/statem_srvr.c b/ssl/statem/statem_srvr.c -index 1b3b8002ee..d242e98024 100644 ---- a/ssl/statem/statem_srvr.c -+++ b/ssl/statem/statem_srvr.c -@@ -2418,9 +2418,8 @@ int tls_construct_server_hello(SSL *s, WPACKET *pkt) - * so the following won't overwrite an ID that we're supposed - * to send back. - */ -- if (s->session->not_resumable || -- (!(s->ctx->session_cache_mode & SSL_SESS_CACHE_SERVER) -- && !s->hit)) -+ if (!(s->ctx->session_cache_mode & SSL_SESS_CACHE_SERVER) -+ && !s->hit) - s->session->session_id_length = 0; - - if (usetls13) { -diff --git a/test/sslapitest.c b/test/sslapitest.c -index 5ee982ab06..395b1e5457 100644 ---- a/test/sslapitest.c -+++ b/test/sslapitest.c -@@ -6669,6 +6669,128 @@ static int test_ca_names(int tst) - return testresult; - } - -+/* -+ * Test that a session cache overflow works as expected -+ * Test 0: TLSv1.3, timeout on new session later than old session -+ * Test 1: TLSv1.2, timeout on new session later than old session -+ * Test 2: TLSv1.3, timeout on new session earlier than old session -+ * Test 3: TLSv1.2, timeout on new session earlier than old session -+ */ -+#if !defined(OPENSSL_NO_TLS1_3) || !defined(OPENSSL_NO_TLS1_2) -+static int test_session_cache_overflow(int idx) -+{ -+ SSL_CTX *sctx = NULL, *cctx = NULL; -+ SSL *serverssl = NULL, *clientssl = NULL; -+ int testresult = 0; -+ SSL_SESSION *sess = NULL; -+ -+#ifdef OPENSSL_NO_TLS1_3 -+ /* If no TLSv1.3 available then do nothing in this case */ -+ if (idx % 2 == 0) -+ TEST_info("No TLSv1.3 available"); -+ return 1; -+#endif -+#ifdef OPENSSL_NO_TLS1_2 -+ /* If no TLSv1.2 available then do nothing in this case */ -+ if (idx % 2 == 1) -+ TEST_info("No TLSv1.2 available"); -+ return 1; -+#endif -+ -+ if (!TEST_true(create_ssl_ctx_pair(TLS_server_method(), -+ TLS_client_method(), TLS1_VERSION, -+ (idx % 2 == 0) ? TLS1_3_VERSION -+ : TLS1_2_VERSION, -+ &sctx, &cctx, cert, privkey)) -+ || !TEST_true(SSL_CTX_set_options(sctx, SSL_OP_NO_TICKET))) -+ goto end; -+ -+ SSL_CTX_sess_set_get_cb(sctx, get_session_cb); -+ get_sess_val = NULL; -+ -+ SSL_CTX_sess_set_cache_size(sctx, 1); -+ -+ if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl, &clientssl, -+ NULL, NULL))) -+ goto end; -+ -+ if (!TEST_true(create_ssl_connection(serverssl, clientssl, SSL_ERROR_NONE))) -+ goto end; -+ -+ if (idx > 1) { -+ sess = SSL_get_session(serverssl); -+ if (!TEST_ptr(sess)) -+ goto end; -+ -+ /* -+ * Cause this session to have a longer timeout than the next session to -+ * be added. -+ */ -+ if (!TEST_true(SSL_SESSION_set_timeout(sess, LONG_MAX / 2))) { -+ sess = NULL; -+ goto end; -+ } -+ sess = NULL; -+ } -+ -+ SSL_shutdown(serverssl); -+ SSL_shutdown(clientssl); -+ SSL_free(serverssl); -+ SSL_free(clientssl); -+ serverssl = clientssl = NULL; -+ -+ /* -+ * Session cache size is 1 and we already populated the cache with a session -+ * so the next connection should cause an overflow. -+ */ -+ -+ if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl, &clientssl, -+ NULL, NULL))) -+ goto end; -+ -+ if (!TEST_true(create_ssl_connection(serverssl, clientssl, SSL_ERROR_NONE))) -+ goto end; -+ -+ /* -+ * The session we just negotiated may have been already removed from the -+ * internal cache - but we will return it anyway from our external cache. -+ */ -+ get_sess_val = SSL_get_session(serverssl); -+ if (!TEST_ptr(get_sess_val)) -+ goto end; -+ sess = SSL_get1_session(clientssl); -+ if (!TEST_ptr(sess)) -+ goto end; -+ -+ SSL_shutdown(serverssl); -+ SSL_shutdown(clientssl); -+ SSL_free(serverssl); -+ SSL_free(clientssl); -+ serverssl = clientssl = NULL; -+ -+ if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl, &clientssl, -+ NULL, NULL))) -+ goto end; -+ -+ if (!TEST_true(SSL_set_session(clientssl, sess))) -+ goto end; -+ -+ if (!TEST_true(create_ssl_connection(serverssl, clientssl, SSL_ERROR_NONE))) -+ goto end; -+ -+ testresult = 1; -+ -+ end: -+ SSL_free(serverssl); -+ SSL_free(clientssl); -+ SSL_CTX_free(sctx); -+ SSL_CTX_free(cctx); -+ SSL_SESSION_free(sess); -+ -+ return testresult; -+} -+#endif /* !defined(OPENSSL_NO_TLS1_3) || !defined(OPENSSL_NO_TLS1_2) */ -+ - /* - * Test 0: Client sets servername and server acknowledges it (TLSv1.2) - * Test 1: Client sets servername and server does not acknowledge it (TLSv1.2) -@@ -7288,6 +7410,180 @@ static int test_inherit_verify_param(void) - return testresult; - } - -+struct resume_servername_cb_data { -+ int i; -+ SSL_CTX *cctx; -+ SSL_CTX *sctx; -+ SSL_SESSION *sess; -+ int recurse; -+}; -+ -+/* -+ * Servername callback. We use it here to run another complete handshake using -+ * the same session - and mark the session as not_resuamble at the end -+ */ -+static int resume_servername_cb(SSL *s, int *ad, void *arg) -+{ -+ struct resume_servername_cb_data *cbdata = arg; -+ SSL *serverssl = NULL, *clientssl = NULL; -+ int ret = SSL_TLSEXT_ERR_ALERT_FATAL; -+ -+ if (cbdata->recurse) -+ return SSL_TLSEXT_ERR_ALERT_FATAL; -+ -+ if ((cbdata->i % 3) != 1) -+ return SSL_TLSEXT_ERR_OK; -+ -+ cbdata->recurse = 1; -+ -+ if (!TEST_true(create_ssl_objects(cbdata->sctx, cbdata->cctx, &serverssl, -+ &clientssl, NULL, NULL)) -+ || !TEST_true(SSL_set_session(clientssl, cbdata->sess))) -+ goto end; -+ -+ ERR_set_mark(); -+ /* -+ * We expect this to fail - because the servername cb will fail. This will -+ * mark the session as not_resumable. -+ */ -+ if (!TEST_false(create_ssl_connection(serverssl, clientssl, SSL_ERROR_NONE))) { -+ ERR_clear_last_mark(); -+ goto end; -+ } -+ ERR_pop_to_mark(); -+ -+ ret = SSL_TLSEXT_ERR_OK; -+ end: -+ SSL_free(serverssl); -+ SSL_free(clientssl); -+ cbdata->recurse = 0; -+ return ret; -+} -+ -+/* -+ * Test multiple resumptions and cache size handling -+ * Test 0: TLSv1.3 (max_early_data set) -+ * Test 1: TLSv1.3 (SSL_OP_NO_TICKET set) -+ * Test 2: TLSv1.3 (max_early_data and SSL_OP_NO_TICKET set) -+ * Test 3: TLSv1.3 (SSL_OP_NO_TICKET, simultaneous resumes) -+ * Test 4: TLSv1.2 -+ */ -+static int test_multi_resume(int idx) -+{ -+ SSL_CTX *sctx = NULL, *cctx = NULL; -+ SSL *serverssl = NULL, *clientssl = NULL; -+ SSL_SESSION *sess = NULL; -+ int max_version = TLS1_3_VERSION; -+ int i, testresult = 0; -+ struct resume_servername_cb_data cbdata; -+ -+#if defined(OPENSSL_NO_TLS1_2) -+ if (idx == 4) -+ TEST_info("TLSv1.2 is disabled in this build"); -+ return 1; -+#else -+ if (idx == 4) -+ max_version = TLS1_2_VERSION; -+#endif -+#if defined(OPENSSL_NO_TLS1_3) -+ if (idx != 4) -+ TEST_info("No usable TLSv1.3 in this build"); -+ return 1; -+#endif -+ -+ if (!TEST_true(create_ssl_ctx_pair(TLS_server_method(), -+ TLS_client_method(), TLS1_VERSION, -+ max_version, &sctx, &cctx, cert, -+ privkey))) -+ goto end; -+ -+ /* -+ * TLSv1.3 only uses a session cache if either max_early_data > 0 (used for -+ * replay protection), or if SSL_OP_NO_TICKET is in use -+ */ -+ if (idx == 0 || idx == 2) { -+ if (!TEST_true(SSL_CTX_set_max_early_data(sctx, 1024))) -+ goto end; -+ } -+ if (idx == 1 || idx == 2 || idx == 3) -+ SSL_CTX_set_options(sctx, SSL_OP_NO_TICKET); -+ -+ SSL_CTX_sess_set_cache_size(sctx, 5); -+ -+ if (idx == 3) { -+ SSL_CTX_set_tlsext_servername_callback(sctx, resume_servername_cb); -+ SSL_CTX_set_tlsext_servername_arg(sctx, &cbdata); -+ cbdata.cctx = cctx; -+ cbdata.sctx = sctx; -+ cbdata.recurse = 0; -+ } -+ -+ for (i = 0; i < 30; i++) { -+ if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl, &clientssl, -+ NULL, NULL)) -+ || !TEST_true(SSL_set_session(clientssl, sess))) -+ goto end; -+ -+ /* -+ * Check simultaneous resumes. We pause the connection part way through -+ * the handshake by (mis)using the servername_cb. The pause occurs after -+ * session resumption has already occurred, but before any session -+ * tickets have been issued. While paused we run another complete -+ * handshake resuming the same session. -+ */ -+ if (idx == 3) { -+ cbdata.i = i; -+ cbdata.sess = sess; -+ } -+ -+ /* -+ * Recreate a bug where dynamically changing the max_early_data value -+ * can cause sessions in the session cache which cannot be deleted. -+ */ -+ if ((idx == 0 || idx == 2) && (i % 3) == 2) -+ SSL_set_max_early_data(serverssl, 0); -+ -+ if (!TEST_true(create_ssl_connection(serverssl, clientssl, SSL_ERROR_NONE))) -+ goto end; -+ -+ if (sess == NULL || (idx == 0 && (i % 3) == 2)) { -+ if (!TEST_false(SSL_session_reused(clientssl))) -+ goto end; -+ } else { -+ if (!TEST_true(SSL_session_reused(clientssl))) -+ goto end; -+ } -+ SSL_SESSION_free(sess); -+ -+ /* Do a full handshake, followed by two resumptions */ -+ if ((i % 3) == 2) { -+ sess = NULL; -+ } else { -+ if (!TEST_ptr((sess = SSL_get1_session(clientssl)))) -+ goto end; -+ } -+ -+ SSL_shutdown(clientssl); -+ SSL_shutdown(serverssl); -+ SSL_free(serverssl); -+ SSL_free(clientssl); -+ serverssl = clientssl = NULL; -+ } -+ -+ /* We should never exceed the session cache size limit */ -+ if (!TEST_long_le(SSL_CTX_sess_number(sctx), 5)) -+ goto end; -+ -+ testresult = 1; -+ end: -+ SSL_free(serverssl); -+ SSL_free(clientssl); -+ SSL_CTX_free(sctx); -+ SSL_CTX_free(cctx); -+ SSL_SESSION_free(sess); -+ return testresult; -+} -+ - int setup_tests(void) - { - if (!TEST_ptr(certsdir = test_get_argument(0)) -@@ -7422,6 +7718,10 @@ int setup_tests(void) - #if !defined(OPENSSL_NO_TLS1_2) && !defined(OPENSSL_NO_TLS1_3) - ADD_ALL_TESTS(test_serverinfo_custom, 4); - #endif -+#if !defined(OPENSSL_NO_TLS1_2) || !defined(OPENSSL_NO_TLS1_3) -+ ADD_ALL_TESTS(test_session_cache_overflow, 4); -+#endif -+ ADD_ALL_TESTS(test_multi_resume, 5); - return 1; - } - --- -2.43.0.windows.1 - diff --git a/third_party/patch/protobuf/CVE-2021-22570.patch b/third_party/patch/protobuf/CVE-2021-22570.patch deleted file mode 100644 index ac560239135..00000000000 --- a/third_party/patch/protobuf/CVE-2021-22570.patch +++ /dev/null @@ -1,163 +0,0 @@ -diff --git a/src/google/protobuf/descriptor.cc b/src/google/protobuf/descriptor.cc -index 9a448ffc8..40510b46c 100644 ---- a/src/google/protobuf/descriptor.cc -+++ b/src/google/protobuf/descriptor.cc -@@ -1090,7 +1090,7 @@ inline void DescriptorPool::Tables::FindAllExtensions( - - bool DescriptorPool::Tables::AddSymbol(const std::string& full_name, - Symbol symbol) { -- if (InsertIfNotPresent(&symbols_by_name_, full_name.c_str(), symbol)) { -+ if (InsertIfNotPresent(&symbols_by_name_, full_name, symbol)) { - symbols_after_checkpoint_.push_back(full_name.c_str()); - return true; - } else { -@@ -1106,7 +1106,7 @@ bool FileDescriptorTables::AddAliasUnderParent(const void* parent, - } - - bool DescriptorPool::Tables::AddFile(const FileDescriptor* file) { -- if (InsertIfNotPresent(&files_by_name_, file->name().c_str(), file)) { -+ if (InsertIfNotPresent(&files_by_name_, file->name(), file)) { - files_after_checkpoint_.push_back(file->name().c_str()); - return true; - } else { -@@ -2628,6 +2628,8 @@ void Descriptor::DebugString(int depth, std::string* contents, - const Descriptor::ReservedRange* range = reserved_range(i); - if (range->end == range->start + 1) { - strings::SubstituteAndAppend(contents, "$0, ", range->start); -+ } else if (range->end > FieldDescriptor::kMaxNumber) { -+ strings::SubstituteAndAppend(contents, "$0 to max, ", range->start); - } else { - strings::SubstituteAndAppend(contents, "$0 to $1, ", range->start, - range->end - 1); -@@ -2831,6 +2833,8 @@ void EnumDescriptor::DebugString( - const EnumDescriptor::ReservedRange* range = reserved_range(i); - if (range->end == range->start) { - strings::SubstituteAndAppend(contents, "$0, ", range->start); -+ } else if (range->end == INT_MAX) { -+ strings::SubstituteAndAppend(contents, "$0 to max, ", range->start); - } else { - strings::SubstituteAndAppend(contents, "$0 to $1, ", range->start, - range->end); -@@ -4022,6 +4026,12 @@ bool DescriptorBuilder::AddSymbol(const std::string& full_name, - // Use its file as the parent instead. - if (parent == nullptr) parent = file_; - -+ if (full_name.find('\0') != std::string::npos) { -+ AddError(full_name, proto, DescriptorPool::ErrorCollector::NAME, -+ "\"" + full_name + "\" contains null character."); -+ return false; -+ } -+ - if (tables_->AddSymbol(full_name, symbol)) { - if (!file_tables_->AddAliasUnderParent(parent, name, symbol)) { - // This is only possible if there was already an error adding something of -@@ -4061,6 +4071,11 @@ bool DescriptorBuilder::AddSymbol(const std::string& full_name, - void DescriptorBuilder::AddPackage(const std::string& name, - const Message& proto, - const FileDescriptor* file) { -+ if (name.find('\0') != std::string::npos) { -+ AddError(name, proto, DescriptorPool::ErrorCollector::NAME, -+ "\"" + name + "\" contains null character."); -+ return; -+ } - if (tables_->AddSymbol(name, Symbol(file))) { - // Success. Also add parent package, if any. - std::string::size_type dot_pos = name.find_last_of('.'); -@@ -4374,6 +4389,12 @@ FileDescriptor* DescriptorBuilder::BuildFileImpl( - } - result->pool_ = pool_; - -+ if (result->name().find('\0') != std::string::npos) { -+ AddError(result->name(), proto, DescriptorPool::ErrorCollector::NAME, -+ "\"" + result->name() + "\" contains null character."); -+ return nullptr; -+ } -+ - // Add to tables. - if (!tables_->AddFile(result)) { - AddError(proto.name(), proto, DescriptorPool::ErrorCollector::OTHER, -diff --git a/src/google/protobuf/descriptor_unittest.cc b/src/google/protobuf/descriptor_unittest.cc -index 6085a122a..56c180aa4 100644 ---- a/src/google/protobuf/descriptor_unittest.cc -+++ b/src/google/protobuf/descriptor_unittest.cc -@@ -3786,6 +3786,45 @@ TEST_F(ValidationErrorTest, InvalidPackageName) { - "foo.proto: foo.$: NAME: \"$\" is not a valid identifier.\n"); - } - -+// 'str' is a static C-style string that may contain '\0' -+#define STATIC_STR(str) std::string((str), sizeof(str) - 1) -+ -+TEST_F(ValidationErrorTest, NullCharSymbolName) { -+ BuildFileWithErrors( -+ "name: \"bar.proto\" " -+ "package: \"foo\"" -+ "message_type { " -+ " name: '\\000\\001\\013.Bar' " -+ " field { name: \"foo\" number: 9 label:LABEL_OPTIONAL type:TYPE_INT32 " -+ "} " -+ "}", -+ STATIC_STR("bar.proto: foo.\0\x1\v.Bar: NAME: \"\0\x1\v.Bar\" is not a " -+ "valid identifier.\nbar.proto: foo.\0\x1\v.Bar: NAME: " -+ "\"\0\x1\v.Bar\" is not a valid identifier.\nbar.proto: " -+ "foo.\0\x1\v.Bar: NAME: \"\0\x1\v.Bar\" is not a valid " -+ "identifier.\nbar.proto: foo.\0\x1\v.Bar: NAME: " -+ "\"\0\x1\v.Bar\" is not a valid identifier.\nbar.proto: " -+ "foo.\0\x1\v.Bar.foo: NAME: \"foo.\0\x1\v.Bar.foo\" contains " -+ "null character.\nbar.proto: foo.\0\x1\v.Bar: NAME: " -+ "\"foo.\0\x1\v.Bar\" contains null character.\n")); -+} -+ -+TEST_F(ValidationErrorTest, NullCharFileName) { -+ BuildFileWithErrors( -+ "name: \"bar\\000\\001\\013.proto\" " -+ "package: \"outer.foo\"", -+ STATIC_STR("bar\0\x1\v.proto: bar\0\x1\v.proto: NAME: " -+ "\"bar\0\x1\v.proto\" contains null character.\n")); -+} -+ -+TEST_F(ValidationErrorTest, NullCharPackageName) { -+ BuildFileWithErrors( -+ "name: \"bar.proto\" " -+ "package: \"\\000\\001\\013.\"", -+ STATIC_STR("bar.proto: \0\x1\v.: NAME: \"\0\x1\v.\" contains null " -+ "character.\n")); -+} -+ - TEST_F(ValidationErrorTest, MissingFileName) { - BuildFileWithErrors("", - -@@ -4001,6 +4040,32 @@ TEST_F(ValidationErrorTest, ReservedFieldsDebugString) { - file->DebugString()); - } - -+TEST_F(ValidationErrorTest, DebugStringReservedRangeMax) { -+ const FileDescriptor* file = BuildFile(strings::Substitute( -+ "name: \"foo.proto\" " -+ "enum_type { " -+ " name: \"Bar\"" -+ " value { name:\"BAR\" number:1 }" -+ " reserved_range { start: 5 end: $0 }" -+ "}" -+ "message_type {" -+ " name: \"Foo\"" -+ " reserved_range { start: 5 end: $1 }" -+ "}", -+ std::numeric_limits::max(), FieldDescriptor::kMaxNumber + 1)); -+ -+ ASSERT_EQ( -+ "syntax = \"proto2\";\n\n" -+ "enum Bar {\n" -+ " BAR = 1;\n" -+ " reserved 5 to max;\n" -+ "}\n\n" -+ "message Foo {\n" -+ " reserved 5 to max;\n" -+ "}\n\n", -+ file->DebugString()); -+} -+ - TEST_F(ValidationErrorTest, EnumReservedFieldError) { - BuildFileWithErrors( - "name: \"foo.proto\" " - - diff --git a/third_party/patch/protobuf/CVE-2022-1941.patch b/third_party/patch/protobuf/CVE-2022-1941.patch deleted file mode 100644 index 7fb4c95e6ed..00000000000 --- a/third_party/patch/protobuf/CVE-2022-1941.patch +++ /dev/null @@ -1,204 +0,0 @@ -diff --git a/src/google/protobuf/extension_set_inl.h b/src/google/protobuf/extension_set_inl.h -index 074784b96..aff050a81 100644 ---- a/src/google/protobuf/extension_set_inl.h -+++ b/src/google/protobuf/extension_set_inl.h -@@ -206,16 +206,22 @@ const char* ExtensionSet::ParseMessageSetItemTmpl( - const char* ptr, const Msg* containing_type, - internal::InternalMetadata* metadata, internal::ParseContext* ctx) { - std::string payload; -- uint32 type_id = 0; -- bool payload_read = false; -+ -+ uint32_t type_id; -+ enum class State { kNoTag, kHasType, kHasPayload, kDone }; -+ State state = State::kNoTag; -+ - while (!ctx->Done(&ptr)) { - uint32 tag = static_cast(*ptr++); - if (tag == WireFormatLite::kMessageSetTypeIdTag) { - uint64 tmp; - ptr = ParseBigVarint(ptr, &tmp); - GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); -- type_id = tmp; -- if (payload_read) { -+ if (state == State::kNoTag) { -+ type_id = tmp; -+ state = State::kHasType; -+ } else if (state == State::kHasPayload) { -+ type_id = tmp; - ExtensionInfo extension; - bool was_packed_on_wire; - if (!FindExtension(2, type_id, containing_type, ctx, &extension, -@@ -241,20 +247,26 @@ const char* ExtensionSet::ParseMessageSetItemTmpl( - GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) && - tmp_ctx.EndedAtLimit()); - } -- type_id = 0; -+ state = State::kDone; - } - } else if (tag == WireFormatLite::kMessageSetMessageTag) { -- if (type_id != 0) { -- ptr = ParseFieldMaybeLazily(static_cast(type_id) * 8 + 2, ptr, -- containing_type, metadata, ctx); -+ -+ if (state == State::kHasType) { -+ ptr = ParseFieldMaybeLazily(static_cast(type_id) * 8 + 2, ptr, -+ containing_type, metadata, ctx); - GOOGLE_PROTOBUF_PARSER_ASSERT(ptr != nullptr); -- type_id = 0; -+ state = State::kDone; - } else { -- int32 size = ReadSize(&ptr); -+ -+ std::string tmp; -+ int32_t size = ReadSize(&ptr); - GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); -- ptr = ctx->ReadString(ptr, size, &payload); -+ ptr = ctx->ReadString(ptr, size, &tmp); - GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); -- payload_read = true; -+ if (state == State::kNoTag) { -+ payload = std::move(tmp); -+ state = State::kHasPayload; -+ } - } - } else { - ptr = ReadTag(ptr - 1, &tag); -diff --git a/src/google/protobuf/wire_format.cc b/src/google/protobuf/wire_format.cc -index 16edf2ce3..88fb09169 100644 ---- a/src/google/protobuf/wire_format.cc -+++ b/src/google/protobuf/wire_format.cc -@@ -659,9 +659,11 @@ struct WireFormat::MessageSetParser { - const char* _InternalParse(const char* ptr, internal::ParseContext* ctx) { - // Parse a MessageSetItem - auto metadata = reflection->MutableInternalMetadata(msg); -+ enum class State { kNoTag, kHasType, kHasPayload, kDone }; -+ State state = State::kNoTag; -+ - std::string payload; -- uint32 type_id = 0; -- bool payload_read = false; -+ uint32_t type_id = 0; - while (!ctx->Done(&ptr)) { - // We use 64 bit tags in order to allow typeid's that span the whole - // range of 32 bit numbers. -@@ -670,8 +672,11 @@ struct WireFormat::MessageSetParser { - uint64 tmp; - ptr = ParseBigVarint(ptr, &tmp); - GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); -- type_id = tmp; -- if (payload_read) { -+ if (state == State::kNoTag) { -+ type_id = tmp; -+ state = State::kHasType; -+ } else if (state == State::kHasPayload) { -+ type_id = tmp; - const FieldDescriptor* field; - if (ctx->data().pool == nullptr) { - field = reflection->FindKnownExtensionByNumber(type_id); -@@ -698,17 +703,18 @@ struct WireFormat::MessageSetParser { - GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) && - tmp_ctx.EndedAtLimit()); - } -- type_id = 0; -+ state = State::kDone; - } - continue; - } else if (tag == WireFormatLite::kMessageSetMessageTag) { -- if (type_id == 0) { -- int32 size = ReadSize(&ptr); -+ -+ if (state == State::kNoTag) { -+ int32_t size = ReadSize(&ptr); - GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); - ptr = ctx->ReadString(ptr, size, &payload); - GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); -- payload_read = true; -- } else { -+ state = State::kHasPayload; -+ } else if (state == State::kHasType) { - // We're now parsing the payload - const FieldDescriptor* field = nullptr; - if (descriptor->IsExtensionNumber(type_id)) { -@@ -722,7 +728,12 @@ struct WireFormat::MessageSetParser { - ptr = WireFormat::_InternalParseAndMergeField( - msg, ptr, ctx, static_cast(type_id) * 8 + 2, reflection, - field); -- type_id = 0; -+ state = State::kDone; -+ } else { -+ int32_t size = ReadSize(&ptr); -+ GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); -+ ptr = ctx->Skip(ptr, size); -+ GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); - } - } else { - // An unknown field in MessageSetItem. -diff --git a/src/google/protobuf/wire_format_lite.h b/src/google/protobuf/wire_format_lite.h -index c742fe869..4130bc531 100644 ---- a/src/google/protobuf/wire_format_lite.h -+++ b/src/google/protobuf/wire_format_lite.h -@@ -1798,6 +1798,9 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) { - // we can parse it later. - std::string message_data; - -+ enum class State { kNoTag, kHasType, kHasPayload, kDone }; -+ State state = State::kNoTag; -+ - while (true) { - const uint32 tag = input->ReadTagNoLastTag(); - if (tag == 0) return false; -@@ -1806,26 +1809,34 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) { - case WireFormatLite::kMessageSetTypeIdTag: { - uint32 type_id; - if (!input->ReadVarint32(&type_id)) return false; -- last_type_id = type_id; -- -- if (!message_data.empty()) { -+ if (state == State::kNoTag) { -+ last_type_id = type_id; -+ state = State::kHasType; -+ } else if (state == State::kHasPayload) { - // We saw some message data before the type_id. Have to parse it - // now. - io::CodedInputStream sub_input( - reinterpret_cast(message_data.data()), - static_cast(message_data.size())); - sub_input.SetRecursionLimit(input->RecursionBudget()); -- if (!ms.ParseField(last_type_id, &sub_input)) { -+ if (!ms.ParseField(type_id, &sub_input)) { - return false; - } - message_data.clear(); -+ state = State::kDone; - } - - break; - } - - case WireFormatLite::kMessageSetMessageTag: { -- if (last_type_id == 0) { -+ if (state == State::kHasType) { -+ // Already saw type_id, so we can parse this directly. -+ if (!ms.ParseField(last_type_id, input)) { -+ return false; -+ } -+ state = State::kDone; -+ } else if (state == State::kNoTag) { - // We haven't seen a type_id yet. Append this data to message_data. - uint32 length; - if (!input->ReadVarint32(&length)) return false; -@@ -1836,11 +1847,9 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) { - auto ptr = reinterpret_cast(&message_data[0]); - ptr = io::CodedOutputStream::WriteVarint32ToArray(length, ptr); - if (!input->ReadRaw(ptr, length)) return false; -+ state = State::kHasPayload; - } else { -- // Already saw type_id, so we can parse this directly. -- if (!ms.ParseField(last_type_id, input)) { -- return false; -- } -+ if (!ms.SkipField(tag, input)) return false; - } - - break; diff --git a/third_party/patch/sqlite/CVE-2021-36690.patch b/third_party/patch/sqlite/CVE-2021-36690.patch deleted file mode 100644 index 1e35aadd179..00000000000 --- a/third_party/patch/sqlite/CVE-2021-36690.patch +++ /dev/null @@ -1,44 +0,0 @@ -diff -Npur sqlite-version-3.36.0/ext/expert/sqlite3expert.c sqlite-version-3.36.0-change/ext/expert/sqlite3expert.c ---- sqlite-version-3.36.0/ext/expert/sqlite3expert.c 2021-06-19 02:36:39.000000000 +0800 -+++ sqlite-version-3.36.0-change/ext/expert/sqlite3expert.c 2022-09-14 23:12:47.929831193 +0800 -@@ -690,11 +690,13 @@ static int idxGetTableInfo( - rc = idxPrintfPrepareStmt(db, &p1, pzErrmsg, "PRAGMA table_xinfo=%Q", zTab); - while( rc==SQLITE_OK && SQLITE_ROW==sqlite3_step(p1) ){ - const char *zCol = (const char*)sqlite3_column_text(p1, 1); -+ const char *zColSeq = 0; - nByte += 1 + STRLEN(zCol); - rc = sqlite3_table_column_metadata( -- db, "main", zTab, zCol, 0, &zCol, 0, 0, 0 -+ db, "main", zTab, zCol, 0, &zColSeq, 0, 0, 0 - ); -- nByte += 1 + STRLEN(zCol); -+ if( zColSeq==0 ) zColSeq = "binary"; -+ nByte += 1 + STRLEN(zColSeq); - nCol++; - nPk += (sqlite3_column_int(p1, 5)>0); - } -@@ -714,6 +716,7 @@ static int idxGetTableInfo( - nCol = 0; - while( rc==SQLITE_OK && SQLITE_ROW==sqlite3_step(p1) ){ - const char *zCol = (const char*)sqlite3_column_text(p1, 1); -+ const char *zColSeq = 0; - int nCopy = STRLEN(zCol) + 1; - pNew->aCol[nCol].zName = pCsr; - pNew->aCol[nCol].iPk = (sqlite3_column_int(p1, 5)==1 && nPk==1); -@@ -721,12 +724,13 @@ static int idxGetTableInfo( - pCsr += nCopy; - - rc = sqlite3_table_column_metadata( -- db, "main", zTab, zCol, 0, &zCol, 0, 0, 0 -+ db, "main", zTab, zCol, 0, &zColSeq, 0, 0, 0 - ); - if( rc==SQLITE_OK ){ -- nCopy = STRLEN(zCol) + 1; -+ if( zColSeq==0 ) zColSeq = "binary"; -+ nCopy = STRLEN(zColSeq) + 1; - pNew->aCol[nCol].zColl = pCsr; -- memcpy(pCsr, zCol, nCopy); -+ memcpy(pCsr, zColSeq, nCopy); - pCsr += nCopy; - } - diff --git a/third_party/patch/sqlite/CVE-2022-35737.patch b/third_party/patch/sqlite/CVE-2022-35737.patch deleted file mode 100644 index 723eacf90ed..00000000000 --- a/third_party/patch/sqlite/CVE-2022-35737.patch +++ /dev/null @@ -1,14 +0,0 @@ -diff -Npur sqlite-version-3.32.2/src/printf.c sqlite-version-3.32.2-change/src/printf.c ---- sqlite-version-3.32.2/src/printf.c 2020-06-04 20:58:43.000000000 +0800 -+++ sqlite-version-3.32.2-change/src/printf.c 2022-08-17 05:52:17.869214453 +0800 -@@ -798,8 +798,8 @@ void sqlite3_str_vappendf( - case etSQLESCAPE: /* %q: Escape ' characters */ - case etSQLESCAPE2: /* %Q: Escape ' and enclose in '...' */ - case etSQLESCAPE3: { /* %w: Escape " characters */ -- int i, j, k, n, isnull; -- int needQuote; -+ i64 i, j, k, n; -+ int needQuote, isnull; - char ch; - char q = ((xtype==etSQLESCAPE3)?'"':'\''); /* Quote character */ - char *escarg; diff --git a/third_party/patch/sqlite/CVE-2023-7104.patch b/third_party/patch/sqlite/CVE-2023-7104.patch deleted file mode 100644 index b773c92f24f..00000000000 --- a/third_party/patch/sqlite/CVE-2023-7104.patch +++ /dev/null @@ -1,30 +0,0 @@ -diff -Npur sqlite-version-3.36.0/ext/session/sqlite3session.c sqlite-version-3.36.0-change/ext/session/sqlite3session.c ---- sqlite-version-3.36.0/ext/session/sqlite3session.c 2021-06-19 02:36:39.000000000 +0800 -+++ sqlite-version-3.36.0-change/ext/session/sqlite3session.c 2024-01-02 14:28:11.354114191 +0800 -@@ -3020,15 +3020,19 @@ static int sessionReadRecord( - } - } - if( eType==SQLITE_INTEGER || eType==SQLITE_FLOAT ){ -- sqlite3_int64 v = sessionGetI64(aVal); -- if( eType==SQLITE_INTEGER ){ -- sqlite3VdbeMemSetInt64(apOut[i], v); -+ if( (pIn->nData-pIn->iNext)<8 ){ -+ rc = SQLITE_CORRUPT_BKPT - }else{ -- double d; -- memcpy(&d, &v, 8); -- sqlite3VdbeMemSetDouble(apOut[i], d); -+ sqlite3_int64 v = sessionGetI64(aVal); -+ if( eType==SQLITE_INTEGER ){ -+ sqlite3VdbeMemSetInt64(apOut[i], v); -+ }else{ -+ double d; -+ memcpy(&d, &v, 8); -+ sqlite3VdbeMemSetDouble(apOut[i], d); -+ } -+ pIn->iNext += 8; - } -- pIn->iNext += 8; - } - } - } diff --git a/third_party/patch/zlib/CVE-2018-25032.patch b/third_party/patch/zlib/CVE-2018-25032.patch deleted file mode 100644 index 25fee45a8fa..00000000000 --- a/third_party/patch/zlib/CVE-2018-25032.patch +++ /dev/null @@ -1,309 +0,0 @@ -diff -Npur zlib-1.2.11/deflate.c zlib-1.2.11-change/deflate.c ---- zlib-1.2.11/deflate.c 2017-01-16 01:29:40.000000000 +0800 -+++ zlib-1.2.11-change/deflate.c 2022-07-28 04:48:30.310281281 +0800 -@@ -252,10 +252,6 @@ int ZEXPORT deflateInit2_(strm, level, m - int wrap = 1; - static const char my_version[] = ZLIB_VERSION; - -- ushf *overlay; -- /* We overlay pending_buf and d_buf+l_buf. This works since the average -- * output size for (length,distance) codes is <= 24 bits. -- */ - - if (version == Z_NULL || version[0] != my_version[0] || - stream_size != sizeof(z_stream)) { -@@ -326,9 +322,47 @@ int ZEXPORT deflateInit2_(strm, level, m - - s->lit_bufsize = 1 << (memLevel + 6); /* 16K elements by default */ - -- overlay = (ushf *) ZALLOC(strm, s->lit_bufsize, sizeof(ush)+2); -- s->pending_buf = (uchf *) overlay; -- s->pending_buf_size = (ulg)s->lit_bufsize * (sizeof(ush)+2L); -+ /* We overlay pending_buf and sym_buf. This works since the average size -+ * for length/distance pairs over any compressed block is assured to be 31 -+ * bits or less. -+ * -+ * Analysis: The longest fixed codes are a length code of 8 bits plus 5 -+ * extra bits, for lengths 131 to 257. The longest fixed distance codes are -+ * 5 bits plus 13 extra bits, for distances 16385 to 32768. The longest -+ * possible fixed-codes length/distance pair is then 31 bits total. -+ * -+ * sym_buf starts one-fourth of the way into pending_buf. So there are -+ * three bytes in sym_buf for every four bytes in pending_buf. Each symbol -+ * in sym_buf is three bytes -- two for the distance and one for the -+ * literal/length. As each symbol is consumed, the pointer to the next -+ * sym_buf value to read moves forward three bytes. From that symbol, up to -+ * 31 bits are written to pending_buf. The closest the written pending_buf -+ * bits gets to the next sym_buf symbol to read is just before the last -+ * code is written. At that time, 31*(n-2) bits have been written, just -+ * after 24*(n-2) bits have been consumed from sym_buf. sym_buf starts at -+ * 8*n bits into pending_buf. (Note that the symbol buffer fills when n-1 -+ * symbols are written.) The closest the writing gets to what is unread is -+ * then n+14 bits. Here n is lit_bufsize, which is 16384 by default, and -+ * can range from 128 to 32768. -+ * -+ * Therefore, at a minimum, there are 142 bits of space between what is -+ * written and what is read in the overlain buffers, so the symbols cannot -+ * be overwritten by the compressed data. That space is actually 139 bits, -+ * due to the three-bit fixed-code block header. -+ * -+ * That covers the case where either Z_FIXED is specified, forcing fixed -+ * codes, or when the use of fixed codes is chosen, because that choice -+ * results in a smaller compressed block than dynamic codes. That latter -+ * condition then assures that the above analysis also covers all dynamic -+ * blocks. A dynamic-code block will only be chosen to be emitted if it has -+ * fewer bits than a fixed-code block would for the same set of symbols. -+ * Therefore its average symbol length is assured to be less than 31. So -+ * the compressed data for a dynamic block also cannot overwrite the -+ * symbols from which it is being constructed. -+ */ -+ -+ s->pending_buf = (uchf *) ZALLOC(strm, s->lit_bufsize, 4); -+ s->pending_buf_size = (ulg)s->lit_bufsize * 4; - - if (s->window == Z_NULL || s->prev == Z_NULL || s->head == Z_NULL || - s->pending_buf == Z_NULL) { -@@ -337,8 +371,12 @@ int ZEXPORT deflateInit2_(strm, level, m - deflateEnd (strm); - return Z_MEM_ERROR; - } -- s->d_buf = overlay + s->lit_bufsize/sizeof(ush); -- s->l_buf = s->pending_buf + (1+sizeof(ush))*s->lit_bufsize; -+ s->sym_buf = s->pending_buf + s->lit_bufsize; -+ s->sym_end = (s->lit_bufsize - 1) * 3; -+ /* We avoid equality with lit_bufsize*3 because of wraparound at 64K -+ * on 16 bit machines and because stored blocks are restricted to -+ * 64K-1 bytes. -+ */ - - s->level = level; - s->strategy = strategy; -@@ -549,7 +587,7 @@ int ZEXPORT deflatePrime (strm, bits, va - - if (deflateStateCheck(strm)) return Z_STREAM_ERROR; - s = strm->state; -- if ((Bytef *)(s->d_buf) < s->pending_out + ((Buf_size + 7) >> 3)) -+ if (s->sym_buf < s->pending_out + ((Buf_size + 7) >> 3)) - return Z_BUF_ERROR; - do { - put = Buf_size - s->bi_valid; -@@ -1108,7 +1146,6 @@ int ZEXPORT deflateCopy (dest, source) - #else - deflate_state *ds; - deflate_state *ss; -- ushf *overlay; - - - if (deflateStateCheck(source) || dest == Z_NULL) { -@@ -1128,8 +1165,7 @@ int ZEXPORT deflateCopy (dest, source) - ds->window = (Bytef *) ZALLOC(dest, ds->w_size, 2*sizeof(Byte)); - ds->prev = (Posf *) ZALLOC(dest, ds->w_size, sizeof(Pos)); - ds->head = (Posf *) ZALLOC(dest, ds->hash_size, sizeof(Pos)); -- overlay = (ushf *) ZALLOC(dest, ds->lit_bufsize, sizeof(ush)+2); -- ds->pending_buf = (uchf *) overlay; -+ ds->pending_buf = (uchf *) ZALLOC(dest, ds->lit_bufsize, 4); - - if (ds->window == Z_NULL || ds->prev == Z_NULL || ds->head == Z_NULL || - ds->pending_buf == Z_NULL) { -@@ -1143,8 +1179,7 @@ int ZEXPORT deflateCopy (dest, source) - zmemcpy(ds->pending_buf, ss->pending_buf, (uInt)ds->pending_buf_size); - - ds->pending_out = ds->pending_buf + (ss->pending_out - ss->pending_buf); -- ds->d_buf = overlay + ds->lit_bufsize/sizeof(ush); -- ds->l_buf = ds->pending_buf + (1+sizeof(ush))*ds->lit_bufsize; -+ ds->sym_buf = ds->pending_buf + ds->lit_bufsize; - - ds->l_desc.dyn_tree = ds->dyn_ltree; - ds->d_desc.dyn_tree = ds->dyn_dtree; -@@ -1912,7 +1947,7 @@ local block_state deflate_fast(s, flush) - FLUSH_BLOCK(s, 1); - return finish_done; - } -- if (s->last_lit) -+ if (s->sym_next) - FLUSH_BLOCK(s, 0); - return block_done; - } -@@ -2043,7 +2078,7 @@ local block_state deflate_slow(s, flush) - FLUSH_BLOCK(s, 1); - return finish_done; - } -- if (s->last_lit) -+ if (s->sym_next) - FLUSH_BLOCK(s, 0); - return block_done; - } -@@ -2118,7 +2153,7 @@ local block_state deflate_rle(s, flush) - FLUSH_BLOCK(s, 1); - return finish_done; - } -- if (s->last_lit) -+ if (s->sym_next) - FLUSH_BLOCK(s, 0); - return block_done; - } -@@ -2157,7 +2192,7 @@ local block_state deflate_huff(s, flush) - FLUSH_BLOCK(s, 1); - return finish_done; - } -- if (s->last_lit) -+ if (s->sym_next) - FLUSH_BLOCK(s, 0); - return block_done; - } -diff -Npur zlib-1.2.11/deflate.h zlib-1.2.11-change/deflate.h ---- zlib-1.2.11/deflate.h 2017-01-01 15:37:10.000000000 +0800 -+++ zlib-1.2.11-change/deflate.h 2022-07-28 04:42:55.134287681 +0800 -@@ -217,7 +217,7 @@ typedef struct internal_state { - /* Depth of each subtree used as tie breaker for trees of equal frequency - */ - -- uchf *l_buf; /* buffer for literals or lengths */ -+ uchf *sym_buf; /* buffer for distances and literals/lengths */ - - uInt lit_bufsize; - /* Size of match buffer for literals/lengths. There are 4 reasons for -@@ -239,13 +239,8 @@ typedef struct internal_state { - * - I can't count above 4 - */ - -- uInt last_lit; /* running index in l_buf */ -- -- ushf *d_buf; -- /* Buffer for distances. To simplify the code, d_buf and l_buf have -- * the same number of elements. To use different lengths, an extra flag -- * array would be necessary. -- */ -+ uInt sym_next; /* running index in sym_buf */ -+ uInt sym_end; /* symbol table full when sym_next reaches this */ - - ulg opt_len; /* bit length of current block with optimal trees */ - ulg static_len; /* bit length of current block with static trees */ -@@ -325,20 +320,22 @@ void ZLIB_INTERNAL _tr_stored_block OF(( - - # define _tr_tally_lit(s, c, flush) \ - { uch cc = (c); \ -- s->d_buf[s->last_lit] = 0; \ -- s->l_buf[s->last_lit++] = cc; \ -+ s->sym_buf[s->sym_next++] = 0; \ -+ s->sym_buf[s->sym_next++] = 0; \ -+ s->sym_buf[s->sym_next++] = cc; \ - s->dyn_ltree[cc].Freq++; \ -- flush = (s->last_lit == s->lit_bufsize-1); \ -+ flush = (s->sym_next == s->sym_end); \ - } - # define _tr_tally_dist(s, distance, length, flush) \ - { uch len = (uch)(length); \ - ush dist = (ush)(distance); \ -- s->d_buf[s->last_lit] = dist; \ -- s->l_buf[s->last_lit++] = len; \ -+ s->sym_buf[s->sym_next++] = dist; \ -+ s->sym_buf[s->sym_next++] = dist >> 8; \ -+ s->sym_buf[s->sym_next++] = len; \ - dist--; \ - s->dyn_ltree[_length_code[len]+LITERALS+1].Freq++; \ - s->dyn_dtree[d_code(dist)].Freq++; \ -- flush = (s->last_lit == s->lit_bufsize-1); \ -+ flush = (s->sym_next == s->sym_end); \ - } - #else - # define _tr_tally_lit(s, c, flush) flush = _tr_tally(s, 0, c) -diff -Npur zlib-1.2.11/trees.c zlib-1.2.11-change/trees.c ---- zlib-1.2.11/trees.c 2017-01-16 01:07:14.000000000 +0800 -+++ zlib-1.2.11-change/trees.c 2022-07-28 05:00:04.094268034 +0800 -@@ -416,7 +416,7 @@ local void init_block(s) - - s->dyn_ltree[END_BLOCK].Freq = 1; - s->opt_len = s->static_len = 0L; -- s->last_lit = s->matches = 0; -+ s->sym_next = s->matches = 0; - } - - #define SMALLEST 1 -@@ -947,7 +947,7 @@ void ZLIB_INTERNAL _tr_flush_block(s, bu - - Tracev((stderr, "\nopt %lu(%lu) stat %lu(%lu) stored %lu lit %u ", - opt_lenb, s->opt_len, static_lenb, s->static_len, stored_len, -- s->last_lit)); -+ s->sym_next / 3)); - - if (static_lenb <= opt_lenb) opt_lenb = static_lenb; - -@@ -1016,8 +1016,9 @@ int ZLIB_INTERNAL _tr_tally (s, dist, lc - unsigned dist; /* distance of matched string */ - unsigned lc; /* match length-MIN_MATCH or unmatched char (if dist==0) */ - { -- s->d_buf[s->last_lit] = (ush)dist; -- s->l_buf[s->last_lit++] = (uch)lc; -+ s->sym_buf[s->sym_next++] = dist; -+ s->sym_buf[s->sym_next++] = dist >> 8; -+ s->sym_buf[s->sym_next++] = lc; - if (dist == 0) { - /* lc is the unmatched char */ - s->dyn_ltree[lc].Freq++; -@@ -1032,30 +1033,7 @@ int ZLIB_INTERNAL _tr_tally (s, dist, lc - s->dyn_ltree[_length_code[lc]+LITERALS+1].Freq++; - s->dyn_dtree[d_code(dist)].Freq++; - } -- --#ifdef TRUNCATE_BLOCK -- /* Try to guess if it is profitable to stop the current block here */ -- if ((s->last_lit & 0x1fff) == 0 && s->level > 2) { -- /* Compute an upper bound for the compressed length */ -- ulg out_length = (ulg)s->last_lit*8L; -- ulg in_length = (ulg)((long)s->strstart - s->block_start); -- int dcode; -- for (dcode = 0; dcode < D_CODES; dcode++) { -- out_length += (ulg)s->dyn_dtree[dcode].Freq * -- (5L+extra_dbits[dcode]); -- } -- out_length >>= 3; -- Tracev((stderr,"\nlast_lit %u, in %ld, out ~%ld(%ld%%) ", -- s->last_lit, in_length, out_length, -- 100L - out_length*100L/in_length)); -- if (s->matches < s->last_lit/2 && out_length < in_length/2) return 1; -- } --#endif -- return (s->last_lit == s->lit_bufsize-1); -- /* We avoid equality with lit_bufsize because of wraparound at 64K -- * on 16 bit machines and because stored blocks are restricted to -- * 64K-1 bytes. -- */ -+ return (s->sym_next == s->sym_end); - } - - /* =========================================================================== -@@ -1068,13 +1046,14 @@ local void compress_block(s, ltree, dtre - { - unsigned dist; /* distance of matched string */ - int lc; /* match length or unmatched char (if dist == 0) */ -- unsigned lx = 0; /* running index in l_buf */ -+ unsigned sx = 0; /* running index in sym_buf */ - unsigned code; /* the code to send */ - int extra; /* number of extra bits to send */ - -- if (s->last_lit != 0) do { -- dist = s->d_buf[lx]; -- lc = s->l_buf[lx++]; -+ if (s->sym_next != 0) do { -+ dist = s->sym_buf[sx++] & 0xff; -+ dist += (unsigned)(s->sym_buf[sx++] & 0xff) << 8; -+ lc = s->sym_buf[sx++]; - if (dist == 0) { - send_code(s, lc, ltree); /* send a literal byte */ - Tracecv(isgraph(lc), (stderr," '%c' ", lc)); -@@ -1099,11 +1078,10 @@ local void compress_block(s, ltree, dtre - } - } /* literal or match pair ? */ - -- /* Check that the overlay between pending_buf and d_buf+l_buf is ok: */ -- Assert((uInt)(s->pending) < s->lit_bufsize + 2*lx, -- "pendingBuf overflow"); -+ /* Check that the overlay between pending_buf and sym_buf is ok: */ -+ Assert(s->pending < s->lit_bufsize + sx, "pendingBuf overflow"); - -- } while (lx < s->last_lit); -+ } while (sx < s->sym_next); - - send_code(s, END_BLOCK, ltree); - } diff --git a/third_party/patch/zlib/CVE-2022-37434.patch b/third_party/patch/zlib/CVE-2022-37434.patch deleted file mode 100644 index 186e5e3e684..00000000000 --- a/third_party/patch/zlib/CVE-2022-37434.patch +++ /dev/null @@ -1,15 +0,0 @@ -diff -Npur zlib-1.2.11/inflate.c zlib-1.2.11-change/inflate.c ---- zlib-1.2.11/inflate.c 2017-01-01 15:37:10.000000000 +0800 -+++ zlib-1.2.11-change/inflate.c 2022-08-17 06:25:06.033176873 +0800 -@@ -759,8 +759,9 @@ int flush; - if (copy > have) copy = have; - if (copy) { - if (state->head != Z_NULL && -- state->head->extra != Z_NULL) { -- len = state->head->extra_len - state->length; -+ state->head->extra != Z_NULL && -+ (len = state->head->extra_len - state->length) < -+ state->head->extra_max) { - zmemcpy(state->head->extra + len, next, - len + copy > state->head->extra_max ? - state->head->extra_max - len : copy); diff --git a/third_party/patch/zlib/CVE-2023-45853.patch b/third_party/patch/zlib/CVE-2023-45853.patch deleted file mode 100644 index bfc93a2d97f..00000000000 --- a/third_party/patch/zlib/CVE-2023-45853.patch +++ /dev/null @@ -1,21 +0,0 @@ -diff -Npur zlib-v1.2.11/contrib/minizip/zip.c zlib-v1.2.11-change/contrib/minizip/zip.c ---- zlib-v1.2.11/contrib/minizip/zip.c 2017-01-16 01:29:40.000000000 +0800 -+++ zlib-v1.2.11-change/contrib/minizip/zip.c 2023-10-16 15:13:00.436760625 +0800 -@@ -1083,6 +1083,17 @@ extern int ZEXPORT zipOpenNewFileInZip4_ - return ZIP_PARAMERROR; - #endif - -+ // The filename and comment length must fit in 16 bits. -+ if ((filename!=NULL) && (strlen(filename)>0xffff)) -+ return ZIP_PARAMERROR; -+ if ((comment!=NULL) && (strlen(comment)>0xffff)) -+ return ZIP_PARAMERROR; -+ // The extra field length must fit in 16 bits. If the member also requires -+ // a Zip64 extra block, that will also need to fit within that 16-bit -+ // length, but that will be checked for later. -+ if ((size_extrafield_local>0xffff) || (size_extrafield_global>0xffff)) -+ return ZIP_PARAMERROR; -+ - zi = (zip64_internal*)file; - - if (zi->in_opened_file_inzip == 1) diff --git a/third_party/proto/caffe/caffe.proto b/third_party/proto/caffe/caffe.proto old mode 100755 new mode 100644